#!/usr/bin/env python3
from __future__ import annotations

import argparse
import json
import os
import sys
import time
from dataclasses import asdict
from pathlib import Path
from typing import Any


# Repository root (one level above this script's directory) and its ``src`` tree;
# the ``src`` tree is put first on sys.path so the in-repo package can be imported
# without being installed.
ROOT = Path(__file__).resolve().parents[1]
SRC = ROOT.joinpath("src")
if str(SRC) not in sys.path:
    sys.path.insert(0, str(SRC))


def early_backend_from_argv(argv: list[str]) -> str:
    """Export the requested ``--backend`` to the environment before JAX is imported.

    Only ``--backend`` is parsed here (every other argument is ignored) so the
    value can be placed in ``JAX_PLATFORM_NAME`` ahead of the module-level
    ``import jax``; the full parser runs later in ``main``.

    An explicit backend (anything other than ``"auto"`` or an empty string)
    overrides a ``JAX_PLATFORM_NAME`` already present in the environment, so the
    backend recorded in the run config is the one actually requested.

    Args:
        argv: Command-line arguments without the program name.

    Returns:
        The requested backend string (``"auto"`` when the flag is absent).
    """
    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument("--backend", default="auto")
    known, _unknown = parser.parse_known_args(argv)
    backend = known.backend
    if backend and backend != "auto":
        # An explicit CLI choice must win over an inherited environment value.
        os.environ["JAX_PLATFORM_NAME"] = backend
    return backend


# Must stay above ``import jax`` below: this exports the requested --backend to
# the environment before JAX is imported, so do not move or reorder it.
early_backend_from_argv(sys.argv[1:])

import jax

from arithmetic_transformer.benchmarking import (
    BenchmarkConfig,
    append_jsonl,
    benchmark_runtime,
    build_summary,
    environment_record,
    evaluate_variant,
    failure_record,
    prediction_records,
    summarize_correctness,
    write_json,
    write_notes,
)
from arithmetic_transformer.inference_variants import AVAILABLE_VARIANTS, make_variant_runtime
from arithmetic_transformer.model import (
    DEFAULT_MODEL_PATH,
    ArithmeticTransformer,
    TransformerConfig,
    cast_params,
    count_params,
    load_params,
)
from arithmetic_transformer.prompt_sets import build_prompt_sets


def comma_list(text: str) -> list[str]:
    """Split *text* on commas, trimming whitespace and dropping empty items.

    Used as an argparse ``type=``; e.g. ``" a, b,,c "`` -> ``["a", "b", "c"]``.
    """
    trimmed = (piece.strip() for piece in text.split(","))
    return [piece for piece in trimmed if piece]


def int_comma_list(text: str) -> list[int]:
    """Parse a comma-separated list of positive integers (argparse ``type=``).

    Whitespace around items and empty items are ignored, e.g. ``"1, 8,"`` ->
    ``[1, 8]``.

    Raises:
        argparse.ArgumentTypeError: if an item is not an integer, any value is
            below 1, or no values are given. argparse turns this into a usage
            error that shows the message.
    """
    items = [part.strip() for part in text.split(",") if part.strip()]
    try:
        values = [int(item) for item in items]
    except ValueError as exc:
        # Without this, argparse only reports "invalid int_comma_list value".
        raise argparse.ArgumentTypeError(
            f"expected comma-separated positive integers, got {text!r}"
        ) from exc
    if not values or any(value < 1 for value in values):
        raise argparse.ArgumentTypeError("expected comma-separated positive integers")
    return values


def build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the inference benchmark.

    Options whose default is ``None`` are filled in per ``--mode`` by
    ``resolve_args``; their help text lists the quick/full defaults. Options
    are registered in the order they should appear in ``--help``.
    """
    parser = argparse.ArgumentParser(
        description="Benchmark inference variants for the arithmetic transformer."
    )
    add = parser.add_argument

    # Inputs, outputs and sweep size.
    add(
        "--params-path",
        type=Path,
        default=DEFAULT_MODEL_PATH,
        help=f"Path to msgpack params. Default: {DEFAULT_MODEL_PATH}",
    )
    add("--run-name", default="", help="Run directory name. Default: inference_<unix timestamp>.")
    add(
        "--output-dir",
        type=Path,
        default=ROOT / "runs" / "inference_benchmarks",
        help="Directory where benchmark run folders are written.",
    )
    add(
        "--mode",
        choices=("quick", "full"),
        default="quick",
        help="quick uses small debug settings; full uses article-grade defaults.",
    )

    # Sweep axes (comma-separated lists).
    add(
        "--batch-sizes",
        type=int_comma_list,
        default=None,
        help="Comma-separated batch sizes. Defaults: quick=1,8; full=1,8,32,128,512,1024.",
    )
    add(
        "--dtypes",
        type=comma_list,
        default=None,
        help="Comma-separated dtypes. Defaults: quick=float32; full=float32,bfloat16.",
    )
    add(
        "--variants",
        type=comma_list,
        default=None,
        help=(
            f"Comma-separated variants. Available: {','.join(AVAILABLE_VARIANTS)}"
            ". Defaults: quick=eager,jit_step,jit_full,kv_cache; full=all."
        ),
    )

    # Workload size and timing loop.
    for flag, help_text in (
        ("--num-prompts", "Number of random balanced benchmark prompts. Defaults: quick=128; full=10000."),
        ("--warmup-iters", "Warmup batches per block. Defaults: quick=1; full=5."),
        ("--measure-iters", "Measured batches per block. Defaults: quick=3; full=20."),
    ):
        add(flag, type=int, default=None, help=help_text)

    add("--seed", type=int, default=0, help="Random seed for deterministic prompt sets.")
    add(
        "--require-correctness",
        action="store_true",
        help="Mark speed results invalid unless every correctness prompt is correct against Python arithmetic.",
    )
    add(
        "--backend",
        default="auto",
        help="Requested JAX backend, for example auto, cpu, gpu, or metal. Default: auto.",
    )

    # Correctness evaluation.
    for flag, help_text in (
        ("--correctness-random-limit", "Random prompts used for correctness. Defaults: quick=64; full=10000."),
        (
            "--correctness-edge-limit",
            "Edge prompts used for correctness and samples. Defaults: quick=32; full=all edge prompts.",
        ),
    ):
        add(flag, type=int, default=None, help=help_text)
    add(
        "--reuse-correctness-across-batches",
        action="store_true",
        help="Evaluate correctness once per variant/dtype and reuse it for all batch-size timing rows.",
    )
    add(
        "--correctness-eval-batch-size",
        type=int,
        default=None,
        help="Batch size used for the reusable correctness pass. Default: largest requested batch size.",
    )

    # Run management.
    add(
        "--resume",
        action="store_true",
        help="Append to an existing run directory and skip completed dtype/variant/batch rows.",
    )
    return parser


def resolve_args(args: argparse.Namespace) -> BenchmarkConfig:
    """Resolve parsed CLI arguments into a validated ``BenchmarkConfig``.

    Any option left unset (``None``) takes the default of the selected
    ``--mode``: "quick" is a small debug sweep and "full" is the complete one.
    List options are de-duplicated, keeping the first occurrence, so a repeated
    entry such as ``--batch-sizes 8,8`` does not benchmark the same row twice.
    ``run_name`` defaults to ``inference_<unix timestamp>``. Paths are expanded
    and made absolute.

    Raises:
        SystemExit: if a resolved value is unknown, empty or out of range.
    """
    if args.mode == "quick":
        defaults: dict[str, Any] = {
            "batch_sizes": [1, 8],
            "dtypes": ["float32"],
            "variants": ["eager", "jit_step", "jit_full", "kv_cache"],
            "num_prompts": 128,
            "warmup_iters": 1,
            "measure_iters": 3,
            "correctness_random_limit": 64,
            "correctness_edge_limit": 32,
        }
    else:
        defaults = {
            "batch_sizes": [1, 8, 32, 128, 512, 1024],
            "dtypes": ["float32", "bfloat16"],
            "variants": list(AVAILABLE_VARIANTS),
            "num_prompts": 10000,
            "warmup_iters": 5,
            "measure_iters": 20,
            "correctness_random_limit": 10000,
            # None means "use every edge prompt".
            "correctness_edge_limit": None,
        }

    def pick(name: str) -> Any:
        # An explicit CLI value always wins over the mode default.
        value = getattr(args, name)
        return defaults[name] if value is None else value

    # dict.fromkeys is an order-preserving de-duplication.
    batch_sizes = list(dict.fromkeys(pick("batch_sizes")))
    dtypes = list(dict.fromkeys(pick("dtypes")))
    variants = list(dict.fromkeys(pick("variants")))
    num_prompts = pick("num_prompts")
    warmup_iters = pick("warmup_iters")
    measure_iters = pick("measure_iters")
    correctness_random_limit = pick("correctness_random_limit")
    correctness_edge_limit = pick("correctness_edge_limit")

    # comma_list turns e.g. "--dtypes ," into [], which would silently run nothing.
    if not dtypes:
        raise SystemExit("--dtypes must name at least one dtype")
    if not variants:
        raise SystemExit("--variants must name at least one variant")
    unknown = [variant for variant in variants if variant not in AVAILABLE_VARIANTS]
    if unknown:
        raise SystemExit(f"Unknown variants: {unknown}. Available: {AVAILABLE_VARIANTS}")
    for dtype in dtypes:
        if dtype not in ("float32", "bfloat16", "float16"):
            raise SystemExit(f"Unsupported dtype: {dtype}")
    if num_prompts < 1:
        raise SystemExit("--num-prompts must be at least 1")
    if warmup_iters < 0 or measure_iters < 1:
        raise SystemExit("--warmup-iters must be >= 0 and --measure-iters must be >= 1")
    if correctness_random_limit < 0:
        raise SystemExit("--correctness-random-limit must be >= 0")
    if correctness_edge_limit is not None and correctness_edge_limit < 0:
        raise SystemExit("--correctness-edge-limit must be >= 0")
    if args.correctness_eval_batch_size is not None and args.correctness_eval_batch_size < 1:
        raise SystemExit("--correctness-eval-batch-size must be >= 1")

    run_name = args.run_name or f"inference_{int(time.time())}"
    return BenchmarkConfig(
        params_path=str(args.params_path.expanduser().resolve()),
        run_name=run_name,
        output_dir=str(args.output_dir.expanduser().resolve()),
        mode=args.mode,
        batch_sizes=batch_sizes,
        dtypes=dtypes,
        variants=variants,
        num_prompts=num_prompts,
        warmup_iters=warmup_iters,
        measure_iters=measure_iters,
        seed=args.seed,
        require_correctness=args.require_correctness,
        backend=args.backend,
        correctness_random_limit=correctness_random_limit,
        correctness_edge_limit=correctness_edge_limit,
        reuse_correctness_across_batches=args.reuse_correctness_across_batches,
        correctness_eval_batch_size=args.correctness_eval_batch_size,
    )


def print_environment(env: dict[str, Any]) -> None:
    """Print a short human-readable summary of an environment record to stdout.

    Keys missing from *env* are shown as ``None``.
    """
    lookup = env.get
    lines = (
        "Environment summary",
        f"  backend: {lookup('backend')}",
        f"  devices: {lookup('visible_devices')}",
        f"  device names: {lookup('device_names')}",
        f"  jax: {lookup('jax_version')} flax: {lookup('flax_version')}",
    )
    print("\n".join(lines))


def _read_jsonl(path: Path) -> list[dict[str, Any]]:
    """Return the JSON objects stored one per line in *path*.

    Blank lines, lines that are not valid JSON (e.g. a row truncated by an
    interrupted run) and lines that do not decode to a JSON object are skipped
    with a message instead of aborting the resume. A missing file yields ``[]``.
    """
    if not path.exists():
        return []
    rows: list[dict[str, Any]] = []
    for line_number, line in enumerate(path.read_text(encoding="utf-8").splitlines(), start=1):
        if not line.strip():
            continue
        try:
            row = json.loads(line)
        except json.JSONDecodeError as exc:
            print(f"Skipping unreadable line {line_number} in {path}: {exc}")
            continue
        if isinstance(row, dict):
            rows.append(row)
        else:
            print(f"Skipping non-object line {line_number} in {path}")
    return rows


def _result_key(row: dict[str, Any]) -> tuple[str, str, int] | None:
    """Return the ``(dtype, implementation_name, batch_size)`` key of a result row.

    Returns ``None`` when ``batch_size`` is missing or not an integer, so a
    malformed row is kept in the results but does not mark anything completed.
    """
    try:
        batch_size = int(row.get("batch_size"))
    except (TypeError, ValueError):
        return None
    return str(row.get("dtype")), str(row.get("implementation_name")), batch_size


def _load_sample_predictions(path: Path) -> list[dict[str, Any]]:
    """Load sample predictions written by a previous run, or ``[]`` if unavailable or not a list."""
    if not path.exists():
        return []
    try:
        payload = json.loads(path.read_text(encoding="utf-8"))
    except (OSError, ValueError) as exc:
        print(f"Could not load previous sample predictions {path}: {exc}")
        return []
    return payload if isinstance(payload, list) else []


def _select_correctness_problems(
    prompt_sets: Any,
    edge_limit: int | None,
    random_limit: int,
) -> tuple[list[Any], list[Any]]:
    """Pick the correctness and sample prompts from *prompt_sets*.

    Returns ``(correctness_problems, sample_problems)``. The correctness list is
    the sanity prompts, then up to *edge_limit* edge prompts (all of them when
    ``None``), then up to *random_limit* random prompts, with duplicates
    removed (first occurrence kept). The sample list is the sanity prompts
    followed by the selected edge prompts, duplicates included, so every
    sample prompt also appears in the correctness list.
    """
    edge = prompt_sets.edge if edge_limit is None else prompt_sets.edge[:edge_limit]
    random_subset = prompt_sets.random_balanced[:random_limit]
    # dict.fromkeys keeps insertion order, i.e. an order-preserving dedupe.
    correctness_problems = list(dict.fromkeys([*prompt_sets.sanity, *edge, *random_subset]))
    sample_problems = [*prompt_sets.sanity, *edge]
    return correctness_problems, sample_problems


def _sample_records_for(
    correctness_problems: list[Any],
    sample_problems: list[Any],
    correctness_records: list[Any],
    generated_texts: list[str],
) -> list[Any]:
    """Return prediction records for *sample_problems* from a correctness pass.

    ``correctness_records`` and ``generated_texts`` are expected to be aligned
    with ``correctness_problems``. When the sample prompts are exactly the
    prefix of the correctness prompts, the existing records are reused.
    Otherwise the dedupe in the correctness list has shifted positions. In that
    case each sample prompt is paired with the text generated for that same
    prompt, not with whatever text sits at the same index.
    """
    count = len(sample_problems)
    if correctness_problems[:count] == sample_problems:
        return correctness_records[:count]
    text_by_problem = dict(zip(correctness_problems, generated_texts))
    return prediction_records(
        sample_problems,
        [text_by_problem.get(problem, "") for problem in sample_problems],
    )


def _print_final_report(
    summary: dict[str, Any],
    results: list[dict[str, Any]],
    run_dir: Path,
) -> None:
    """Print the best valid variant and its speedup over the matching eager row.

    *results* should hold every result row of the run, including rows loaded
    on ``--resume``, so the eager baseline is found even when an earlier
    invocation measured it. The newest matching row wins.
    """
    best = summary.get("best_valid_implementation_by_examples_per_second")
    print(f"Saved results: {run_dir}")
    print("Final report")
    if not best:
        print("  best valid variant: none")
        print(f"  output directory: {run_dir}")
        return
    baseline = next(
        (
            item
            for item in reversed(results)
            if item.get("implementation_name") == "eager"
            and item.get("dtype") == best["dtype"]
            and item.get("batch_size") == best["batch_size"]
        ),
        None,
    )
    baseline_eps = baseline.get("examples_per_second") if baseline else None
    speedup = best["examples_per_second"] / baseline_eps if baseline_eps else None
    print(f"  best valid variant: {best['implementation_name']} dtype={best['dtype']} batch={best['batch_size']}")
    print(f"  baseline throughput: {baseline_eps}")
    print(f"  optimized throughput: {best['examples_per_second']}")
    print(f"  speedup: {speedup}")
    print(f"  accuracy: {best['exact_accuracy']}")
    print(f"  output directory: {run_dir}")


def main(argv: list[str] | None = None) -> int:
    """Run the inference benchmark sweep and write its artifacts.

    Each (dtype, batch size, variant) row is handled as follows:
      - build a runtime;
      - when ``--reuse-correctness-across-batches`` is set, reuse one
        correctness pass per (variant, dtype), computed or loaded from
        ``correctness_<variant>_<dtype>.json``;
      - time the runtime with ``benchmark_runtime`` and append the row to
        ``results.jsonl``.
    A per-row exception is appended to ``failures.jsonl`` instead of aborting
    the sweep. At the end the run writes ``sample_predictions.json``,
    ``summary.json`` and ``notes.md`` and prints a short report.

    With ``--resume`` existing rows are kept and completed
    (dtype, variant, batch) combinations are skipped.

    Args:
        argv: Command-line arguments without the program name; ``None`` means
            ``sys.argv[1:]`` (argparse default).

    Returns:
        Process exit code; always 0 (per-row failures are recorded, not raised).
    """
    args = build_parser().parse_args(argv)
    config = resolve_args(args)
    # Validate before touching the filesystem: an empty name would make the
    # run directory collapse onto --output-dir itself.
    if not config.run_name.strip():
        raise SystemExit("--run-name must not be empty after config resolution")

    run_dir = Path(config.output_dir) / config.run_name
    run_dir.mkdir(parents=True, exist_ok=True)
    results_path = run_dir / "results.jsonl"
    failures_path = run_dir / "failures.jsonl"
    sample_predictions_path = run_dir / "sample_predictions.json"
    if args.resume:
        results_path.touch()
        failures_path.touch()
    else:
        # A fresh run discards rows left behind by an earlier run with the same name.
        failures_path.write_text("", encoding="utf-8")
        results_path.write_text("", encoding="utf-8")

    config_path = run_dir / "config.json"
    if args.resume and config_path.exists():
        # Keep the first invocation's config.json and log this invocation next to it.
        append_jsonl(run_dir / "resume_invocations.jsonl", asdict(config))
    else:
        write_json(config_path, asdict(config))
    env = environment_record(ROOT.parent, config.backend)
    write_json(run_dir / "environment.json", env)

    print_environment(env)
    print(f"Loaded benchmark config: {run_dir}")
    print(f"Selected variants: {config.variants}")
    print(f"Selected batch sizes: {config.batch_sizes}")
    print(f"Selected dtypes: {config.dtypes}")

    prompt_sets = build_prompt_sets(config.num_prompts, config.seed)
    correctness_problems, sample_problems = _select_correctness_problems(
        prompt_sets,
        config.correctness_edge_limit,
        config.correctness_random_limit,
    )
    print(f"Prompt set size: random={len(prompt_sets.random_balanced)} correctness={len(correctness_problems)}")

    model = ArithmeticTransformer(TransformerConfig())
    params_path = Path(config.params_path)
    print(f"Loading model params: {params_path}")
    load_start = time.perf_counter()
    params = load_params(params_path, model)
    print(f"Loaded model in {time.perf_counter() - load_start:.3f}s")
    print(f"Parameter count: {count_params(params)}")

    all_results: list[dict[str, Any]] = []
    all_failures: list[dict[str, Any]] = []
    sample_predictions: list[dict[str, Any]] = []
    # Texts generated by the eager variant, keyed by (dtype, batch_size); later
    # variants are compared against them.
    reference_texts_by_key: dict[tuple[str, int], list[str]] = {}
    correctness_cache: dict[tuple[str, str], dict[str, Any]] = {}

    existing_results: list[dict[str, Any]] = []
    existing_failures: list[dict[str, Any]] = []
    existing_sample_predictions: list[dict[str, Any]] = []
    completed_keys: set[tuple[str, str, int]] = set()
    if args.resume:
        existing_results = _read_jsonl(results_path)
        completed_keys = {key for key in map(_result_key, existing_results) if key is not None}
        if completed_keys:
            print(f"Resume mode: skipping {len(completed_keys)} completed result rows")
        existing_failures = _read_jsonl(failures_path)
        existing_sample_predictions = _load_sample_predictions(sample_predictions_path)

    def load_correctness_override(
        dtype: str,
        variant_name: str,
    ) -> dict[str, Any] | None:
        """Return a correctness override cached on disk by an earlier pass, or None.

        The cache is ignored, and the caller recomputes it, in three cases: it
        cannot be read, it has no correctness summary, or its per-prompt record
        count differs from the current correctness prompt set. In the last case
        its outputs would not line up with the prompts.
        """
        correctness_path = run_dir / f"correctness_{variant_name}_{dtype}.json"
        if not correctness_path.exists():
            return None
        try:
            payload = json.loads(correctness_path.read_text(encoding="utf-8"))
        except (OSError, ValueError) as exc:
            print(f"Could not load cached correctness {correctness_path}: {exc}")
            return None
        if not isinstance(payload, dict) or not isinstance(payload.get("correctness"), dict):
            print(f"Could not load cached correctness {correctness_path}: missing correctness summary")
            return None
        records = payload.get("records") or []
        if records and len(records) != len(correctness_problems):
            print(
                f"Ignoring cached correctness {correctness_path}: "
                f"{len(records)} records for {len(correctness_problems)} correctness prompts"
            )
            return None
        # One text per record keeps the texts aligned with correctness_problems.
        generated_texts = [
            record.get("model_text", "") if isinstance(record, dict) else ""
            for record in records
        ]
        if records:
            sample_records = _sample_records_for(
                correctness_problems, sample_problems, records, generated_texts
            )
        else:
            # No per-prompt records cached: fall back to the samples saved in
            # sample_predictions.json for this variant/dtype, if any.
            sample_records = next(
                (
                    entry.get("samples") or []
                    for entry in existing_sample_predictions
                    if isinstance(entry, dict)
                    and entry.get("implementation_name") == variant_name
                    and entry.get("dtype") == dtype
                ),
                [],
            )
        correctness = payload["correctness"]
        print(
            "Loaded cached correctness: "
            f"variant={variant_name} dtype={dtype} "
            f"examples={correctness.get('number_total')} "
            f"path={correctness_path}"
        )
        return {
            "correctness": correctness,
            "sample_records": sample_records,
            "generated_texts": generated_texts,
            "correctness_eval_batch_size": payload.get("correctness_eval_batch_size"),
            "correctness_path": str(correctness_path),
        }

    def compute_correctness_override(
        runtime,
        dtype: str,
        variant_name: str,
        batch_size: int,
        reference_texts: list[str] | None,
    ) -> dict[str, Any]:
        """Run a correctness pass over all correctness prompts and cache it on disk."""
        print(
            "Correctness pass: "
            f"variant={variant_name} dtype={dtype} batch={batch_size} "
            f"examples={len(correctness_problems)}"
        )
        correctness_records, generated_texts = evaluate_variant(runtime, correctness_problems, batch_size)
        correctness = summarize_correctness(correctness_records, reference_texts, generated_texts)
        sample_records = _sample_records_for(
            correctness_problems, sample_problems, correctness_records, generated_texts
        )
        correctness_path = run_dir / f"correctness_{variant_name}_{dtype}.json"
        write_json(
            correctness_path,
            {
                "implementation_name": variant_name,
                "dtype": dtype,
                "correctness_eval_batch_size": batch_size,
                "correctness": correctness,
                "records": correctness_records,
            },
        )
        return {
            "correctness": correctness,
            "sample_records": sample_records,
            "generated_texts": generated_texts,
            "correctness_eval_batch_size": batch_size,
            "correctness_path": str(correctness_path),
        }

    def correctness_override_for(runtime, dtype: str, variant_name: str) -> dict[str, Any]:
        """Return the shared correctness result for (dtype, variant), obtaining it at most once.

        Checks the in-memory cache first, then the on-disk cache. Otherwise it
        runs a correctness pass at ``--correctness-eval-batch-size`` (default:
        the largest requested batch size). An eager result becomes the
        reference that later variants are compared against.
        """
        cache_key = (dtype, variant_name)
        override = correctness_cache.get(cache_key)
        if override is not None:
            return override
        correctness_batch_size = config.correctness_eval_batch_size or max(config.batch_sizes)
        override = load_correctness_override(dtype, variant_name)
        if override is None:
            override = compute_correctness_override(
                runtime,
                dtype,
                variant_name,
                correctness_batch_size,
                reference_texts_by_key.get((dtype, correctness_batch_size)),
            )
        if variant_name == "eager":
            # Eager is the reference implementation, so it matches itself by definition.
            stats = override["correctness"]
            total = stats.get("number_total")
            stats["reference_match_accuracy"] = 1.0
            stats["reference_matches"] = total
            stats["reference_total"] = total
            stats["matches_reference"] = True
            if override["generated_texts"]:
                eval_batch_size = override.get("correctness_eval_batch_size") or correctness_batch_size
                reference_texts_by_key[(dtype, eval_batch_size)] = override["generated_texts"]
        correctness_cache[cache_key] = override
        return override

    for dtype in config.dtypes:
        print(f"Preparing dtype: {dtype}")
        try:
            dtype_params = cast_params(params, dtype)
        except Exception as exc:
            # Every pending row of this dtype fails for the same reason.
            print(f"Failure: could not prepare dtype={dtype}: {exc}")
            for variant_name in config.variants:
                for batch_size in config.batch_sizes:
                    if (dtype, variant_name, batch_size) in completed_keys:
                        continue
                    failure = failure_record(config.run_name, dtype, variant_name, batch_size, exc)
                    all_failures.append(failure)
                    append_jsonl(failures_path, failure)
            continue

        for batch_size in config.batch_sizes:
            for variant_name in config.variants:
                if (dtype, variant_name, batch_size) in completed_keys:
                    print(f"Skip completed: variant={variant_name} dtype={dtype} batch={batch_size}")
                    continue
                print(f"Start benchmark: variant={variant_name} dtype={dtype} batch={batch_size}")
                try:
                    runtime = make_variant_runtime(
                        variant_name,
                        dtype_params,
                        model,
                        TransformerConfig(),
                        dtype,
                    )
                    # Read the reference before the correctness pass: an eager pass
                    # may register its own texts under this key, and eager must not
                    # be compared against itself here.
                    reference = reference_texts_by_key.get((dtype, batch_size))
                    correctness_override = (
                        correctness_override_for(runtime, dtype, variant_name)
                        if config.reuse_correctness_across_batches
                        else None
                    )
                    result, samples, generated_texts = benchmark_runtime(
                        runtime=runtime,
                        run_name=config.run_name,
                        params_path=config.params_path,
                        dtype=dtype,
                        batch_size=batch_size,
                        warmup_iters=config.warmup_iters,
                        measure_iters=config.measure_iters,
                        random_problems=prompt_sets.random_balanced,
                        correctness_problems=correctness_problems,
                        sample_problems=sample_problems,
                        environment=env,
                        require_correctness=config.require_correctness,
                        reference_texts=reference,
                        correctness_override=correctness_override,
                    )
                    if variant_name == "eager":
                        reference_texts_by_key[(dtype, batch_size)] = generated_texts
                        result["reference_match_accuracy"] = 1.0
                        result["reference_matches"] = result["number_total"]
                        result["reference_total"] = result["number_total"]
                        result["valid_for_speed_comparison"] = bool(result["ground_truth_all_correct"])
                    all_results.append(result)
                    append_jsonl(results_path, result)
                    sample_predictions.append(
                        {
                            "implementation_name": variant_name,
                            "dtype": dtype,
                            "batch_size": batch_size,
                            "samples": samples,
                        }
                    )
                    print(
                        "End benchmark: "
                        f"compile={result['compile_time']} "
                        f"accuracy={result['exact_accuracy']:.6f} "
                        f"eps={result['examples_per_second']:.6f} "
                        f"tps={result['generated_tokens_per_second']:.6f}"
                    )
                except Exception as exc:
                    # Record and continue so one broken row does not abort the sweep.
                    failure = failure_record(config.run_name, dtype, variant_name, batch_size, exc)
                    all_failures.append(failure)
                    append_jsonl(failures_path, failure)
                    print(f"Failure: variant={variant_name} dtype={dtype} batch={batch_size}: {exc}")

    combined_results = [*existing_results, *all_results]
    combined_failures = [*existing_failures, *all_failures]
    summary = build_summary(combined_results, combined_failures, config.require_correctness)
    write_json(sample_predictions_path, [*existing_sample_predictions, *sample_predictions])
    write_json(run_dir / "summary.json", summary)
    write_notes(run_dir / "notes.md", summary, combined_results)
    _print_final_report(summary, combined_results, run_dir)
    return 0


if __name__ == "__main__":
    # sys.exit raises SystemExit with main()'s return value as the exit status.
    sys.exit(main())
