#!/usr/bin/env python3
from __future__ import annotations

import argparse
import os
import sys
import time
import traceback
from collections.abc import Callable
from dataclasses import asdict
from pathlib import Path
from typing import Any


# This script lives one directory below the project root (ROOT); put the
# project's ``src`` directory at the front of the import path so the
# ``arithmetic_transformer`` package can be imported without installing it.
ROOT = Path(__file__).resolve().parents[1]
SRC = ROOT / "src"
if os.fspath(SRC) not in sys.path:
    sys.path.insert(0, os.fspath(SRC))


def early_backend_from_argv(argv: list[str]) -> None:
    """Export ``--backend`` to the environment before the full CLI is parsed.

    Only ``--backend`` is recognised here; every other argument is ignored so
    the full parser can validate it later. When the backend is anything other
    than ``"auto"`` (or empty), it is written to ``JAX_PLATFORM_NAME``, but only
    if that variable is not already set. A value from the environment
    therefore takes precedence over the command-line flag.

    NOTE(review): newer JAX releases document ``JAX_PLATFORMS`` for platform
    selection; confirm the pinned JAX version still honours
    ``JAX_PLATFORM_NAME``.
    """
    pre_parser = argparse.ArgumentParser(add_help=False)
    pre_parser.add_argument("--backend", default="auto")
    requested = pre_parser.parse_known_args(argv)[0].backend
    if not requested or requested == "auto":
        return
    os.environ.setdefault("JAX_PLATFORM_NAME", requested)


# Must stay above the ``arithmetic_transformer`` imports below: the backend
# choice is passed through the JAX_PLATFORM_NAME environment variable, which
# presumably only takes effect if it is set before those modules import and
# initialise JAX (TODO confirm).
early_backend_from_argv(sys.argv[1:])

from arithmetic_transformer.benchmarking import append_jsonl, environment_record, write_json
from arithmetic_transformer.diagnostics import (
    DiagnosticConfig,
    benchmark_generation_grid,
    build_diagnostic_summary,
    run_attention_microbenchmarks,
    run_compilation_behavior,
    run_cpu_gpu_crossover,
    run_memory_behavior,
    run_prefix_recomputation_analysis,
    run_production_inspired_simulations,
    run_profiling_experiments,
    write_diagnostic_notes,
    write_optimization_note,
)
from arithmetic_transformer.inference_variants import make_variant_runtime
from arithmetic_transformer.model import (
    DEFAULT_MODEL_PATH,
    ArithmeticTransformer,
    TransformerConfig,
    cast_params,
    count_params,
    load_params,
    make_prompt_tokens,
)
from arithmetic_transformer.prompt_sets import build_prompt_sets


def comma_list(text: str) -> list[str]:
    """Split ``text`` on commas, trimming whitespace and dropping empty items.

    Used as an argparse ``type=``; unlike ``int_comma_list`` it never raises
    and may return an empty list.
    """
    items: list[str] = []
    for raw in text.split(","):
        cleaned = raw.strip()
        if cleaned:
            items.append(cleaned)
    return items


def int_comma_list(text: str) -> list[int]:
    """Parse a comma-separated list of positive integers (argparse ``type=``).

    Blank items are skipped. A non-integer item raises ``ValueError`` (which
    argparse reports as an invalid value); an empty result or any value below
    1 raises ``argparse.ArgumentTypeError``.
    """
    numbers = [int(token) for token in map(str.strip, text.split(",")) if token]
    if not numbers or min(numbers) < 1:
        raise argparse.ArgumentTypeError("expected comma-separated positive integers")
    return numbers


def build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the diagnostics runner.

    Grid and iteration options default to ``None``; ``resolve_config`` replaces
    ``None`` with the defaults of the selected ``--mode``.
    """
    cli = argparse.ArgumentParser(
        description="Run diagnostic inference optimization experiments."
    )
    add = cli.add_argument
    add("--params-path", type=Path, default=DEFAULT_MODEL_PATH)
    add("--run-name", default="")
    add("--output-dir", type=Path, default=ROOT / "runs" / "inference_diagnostics")
    add("--mode", choices=("quick", "full"), default="quick")
    # Optional grid overrides; None means "use the mode's default grid".
    grid_options = (
        ("--batch-sizes", int_comma_list),
        ("--sequence-lengths", int_comma_list),
        ("--dtypes", comma_list),
    )
    for flag, converter in grid_options:
        add(flag, type=converter, default=None)
    add("--seed", type=int, default=0)
    # Optional iteration-count overrides; None means "use the mode's default".
    for flag in ("--warmup-iters", "--measure-iters"):
        add(flag, type=int, default=None)
    add("--backend", default="auto")
    return cli


def resolve_config(args: argparse.Namespace) -> DiagnosticConfig:
    """Turn parsed CLI arguments into a validated ``DiagnosticConfig``.

    Options left as ``None`` (``batch_sizes``, ``sequence_lengths``,
    ``dtypes``, ``warmup_iters``, ``measure_iters``) take the defaults for
    ``args.mode``. ``"quick"`` uses a small grid with no warmup and a single
    measurement; any other mode (``"full"``) uses the larger grid with warmup.
    Paths are expanded and resolved to absolute strings. An empty
    ``run_name`` is replaced by a timestamp-based one.

    Raises:
        SystemExit: if a grid list is empty, a dtype is not one of
            float32/bfloat16/float16, ``warmup_iters`` is negative, or
            ``measure_iters`` is below 1.
    """
    supported_dtypes = ("float32", "bfloat16", "float16")
    quick_defaults: dict[str, Any] = {
        "batch_sizes": [1, 8, 32],
        "sequence_lengths": [19, 32, 64, 128, 256],
        "dtypes": ["float32", "bfloat16"],
        "warmup_iters": 0,
        "measure_iters": 1,
    }
    full_defaults: dict[str, Any] = {
        "batch_sizes": [1, 8, 32, 128, 512, 1024],
        "sequence_lengths": [19, 32, 64, 128, 256],
        "dtypes": ["float32", "bfloat16", "float16"],
        "warmup_iters": 3,
        "measure_iters": 10,
    }
    mode_defaults = quick_defaults if args.mode == "quick" else full_defaults

    def _pick(name: str) -> Any:
        # An explicit value (even 0) wins; only None falls back to the default.
        value = getattr(args, name)
        return mode_defaults[name] if value is None else value

    batch_sizes = _pick("batch_sizes")
    sequence_lengths = _pick("sequence_lengths")
    dtypes = _pick("dtypes")
    warmup_iters = _pick("warmup_iters")
    measure_iters = _pick("measure_iters")

    # ``--dtypes`` is parsed by comma_list, which (unlike int_comma_list)
    # turns input such as "" or "," into an empty list. main() indexes
    # dtypes[0], so reject empty grids here with a clear message instead.
    grids = (
        ("--batch-sizes", batch_sizes),
        ("--sequence-lengths", sequence_lengths),
        ("--dtypes", dtypes),
    )
    for flag, values in grids:
        if not values:
            raise SystemExit(f"{flag} must not be empty")
    for dtype in dtypes:
        if dtype not in supported_dtypes:
            raise SystemExit(f"Unsupported dtype: {dtype}")
    if warmup_iters < 0 or measure_iters < 1:
        raise SystemExit("--warmup-iters must be >= 0 and --measure-iters must be >= 1")

    run_name = args.run_name or f"diagnostics_{int(time.time())}"
    return DiagnosticConfig(
        params_path=str(args.params_path.expanduser().resolve()),
        run_name=run_name,
        output_dir=str(args.output_dir.expanduser().resolve()),
        mode=args.mode,
        batch_sizes=batch_sizes,
        sequence_lengths=sequence_lengths,
        dtypes=dtypes,
        seed=args.seed,
        warmup_iters=warmup_iters,
        measure_iters=measure_iters,
        backend=args.backend,
    )


def failure(run_name: str, experiment: str, exc: BaseException) -> dict[str, Any]:
    """Build a JSON-friendly record describing a failed experiment.

    Args:
        run_name: Name of the current diagnostics run.
        experiment: Identifier of the experiment that failed.
        exc: The exception that was raised.

    Returns:
        A dict with the run name, a wall-clock timestamp, the experiment id,
        the exception type name and message, and the formatted traceback.
    """
    return {
        "run_name": run_name,
        "timestamp": time.time(),
        "experiment": experiment,
        "error_type": type(exc).__name__,
        "error_message": str(exc),
        # Format from ``exc`` itself. ``traceback.format_exc()`` reports
        # whichever exception is *currently being handled*: "NoneType: None"
        # outside an ``except`` block, or a different exception than ``exc``.
        "traceback": "".join(
            traceback.format_exception(type(exc), exc, exc.__traceback__)
        ),
    }


def _run_stage(
    *,
    run_name: str,
    failures: list[dict[str, Any]],
    failures_path: Path,
    notes_path: Path,
    experiment: str,
    start_message: str,
    failure_label: str,
    action: Callable[[], Any],
    note: tuple[str, str, str, str, str],
    default: Any = None,
) -> Any:
    """Run one diagnostic stage and record its outcome without propagating errors.

    Prints ``start_message`` and calls ``action()``. If the call succeeds, the
    stage writes ``note`` via ``write_optimization_note(notes_path, *note)``.
    If either step raises an ``Exception``, a ``failure`` record is appended
    to ``failures`` and to ``failures_path``, and ``"<failure_label> failed:
    <exc>"`` is printed. Later stages therefore still run.

    Returns:
        The value returned by ``action()``, or ``default`` if ``action`` raised.
        If ``action`` succeeded but writing the note failed, the action's result
        is still returned (and the failure is recorded).
    """
    result = default
    try:
        print(start_message)
        result = action()
        write_optimization_note(notes_path, *note)
    except Exception as exc:
        item = failure(run_name, experiment, exc)
        failures.append(item)
        append_jsonl(failures_path, item)
        print(f"{failure_label} failed: {exc}")
    return result


def main(argv: list[str] | None = None) -> int:
    """Run every diagnostic experiment and write results under the run directory.

    Each experiment runs as an isolated stage (see ``_run_stage``). A failing
    stage is logged to ``failures.jsonl`` and the remaining stages still run.
    A summary is then built from whatever rows were produced.

    Args:
        argv: Command-line arguments without the program name; ``None`` lets
            argparse read ``sys.argv[1:]``.

    Returns:
        Always 0. Experiment failures are reported through ``failures.jsonl``
        and the printed failure count, not through the exit status.
    """
    parser = build_parser()
    config = resolve_config(parser.parse_args(argv))
    run_dir = Path(config.output_dir) / config.run_name
    run_dir.mkdir(parents=True, exist_ok=True)
    # Truncate the per-run logs so reusing a run name starts from empty files.
    failures_path = run_dir / "failures.jsonl"
    failures_path.write_text("", encoding="utf-8")
    notes_path = run_dir / "optimization_notes.jsonl"
    notes_path.write_text("", encoding="utf-8")

    write_json(run_dir / "config.json", asdict(config))
    # NOTE(review): this passes the directory *above* ROOT; confirm that is the
    # path environment_record expects.
    env = environment_record(ROOT.parent, config.backend)
    write_json(run_dir / "environment.json", env)

    print("Environment summary")
    print(f"  backend: {env.get('backend')}")
    print(f"  devices: {env.get('visible_devices')}")
    print(f"  backend probe: {env.get('backend_probe')}")
    print(f"  run dir: {run_dir}")

    print("Loading model params")
    model = ArithmeticTransformer(TransformerConfig())
    params = load_params(Path(config.params_path), model)
    print(f"  params: {count_params(params)}")

    # Generation runtimes use only the first configured dtype; the attention
    # microbenchmarks below sweep every dtype in config.dtypes.
    dtype_name = config.dtypes[0]
    params_for_generation = cast_params(params, dtype_name)
    # Request at least as many prompts as the largest batch, and never fewer than 64.
    prompt_sets = build_prompt_sets(max(max(config.batch_sizes), 64), config.seed)
    correctness_problems = prompt_sets.sanity[:]

    prompt_batches = {}
    for batch_size in config.batch_sizes:
        problems = prompt_sets.random_balanced[:batch_size]
        if len(problems) < batch_size:
            # Top up from the sanity set. NOTE(review): the batch can still be
            # smaller than batch_size if both sets together are too short.
            problems = (problems + prompt_sets.sanity)[:batch_size]
        prompt_batches[batch_size] = make_prompt_tokens(problems)

    runtimes = {
        name: make_variant_runtime(name, params_for_generation, model, TransformerConfig(), dtype_name)
        for name in ("jit_step", "jit_full", "kv_cache", "quant_int8")
    }

    # Compute the crossover verdict defensively. The old inline expression
    # raised AttributeError when "backend_probe" existed but was not a dict.
    # That made the crossover stage look failed even when the experiment ran.
    backend_probe = env.get("backend_probe")
    gpu_probe = backend_probe.get("gpu") if isinstance(backend_probe, dict) else None
    crossover_verdict = "useful" if isinstance(gpu_probe, list) else "blocked on hardware"

    failures: list[dict[str, Any]] = []
    stage_context: dict[str, Any] = {
        "run_name": config.run_name,
        "failures": failures,
        "failures_path": failures_path,
        "notes_path": notes_path,
    }

    _run_stage(
        **stage_context,
        experiment="compilation_behavior",
        start_message="Running compilation behavior experiments",
        failure_label="Compilation behavior",
        action=lambda: run_compilation_behavior(
            run_dir,
            model,
            params_for_generation,
            runtimes,
            prompt_batches,
            config.warmup_iters,
            config.measure_iters,
        ),
        note=(
            "compilation_behavior",
            "First calls and new shapes should be slower due to compilation.",
            "Static-shape first-vs-steady rows and dynamic-prefix shape rows were saved.",
            "useful",
            "JAX compilation is shape-specialized; varying token buffer shapes is expected to trigger separate compilations.",
        ),
    )

    attention_rows: list[dict[str, Any]] = _run_stage(
        **stage_context,
        experiment="attention_sequence_scaling",
        start_message="Running attention and sequence-length scaling experiments",
        failure_label="Attention scaling",
        action=lambda: run_attention_microbenchmarks(
            run_dir,
            config.sequence_lengths,
            config.dtypes,
            config.warmup_iters,
            config.measure_iters,
            config.seed,
        ),
        note=(
            "sequence_length_scaling_and_attention_microbenchmarks",
            "Full-prefix attention should scale much worse than one-token KV attention as sequence length increases.",
            "Measured attention rows and theoretical FLOP estimates were saved for every configured sequence length.",
            "useful",
            "This isolates attention cost beyond the trained model's max sequence length without pretending the model is valid at those lengths.",
        ),
        default=[],
    )

    _run_stage(
        **stage_context,
        experiment="profiling_experiments",
        start_message="Running profiling experiments",
        failure_label="Profiling",
        action=lambda: run_profiling_experiments(
            run_dir, runtimes, prompt_batches, config.warmup_iters, config.measure_iters
        ),
        note=(
            "profiling_experiments",
            "Compiled variants should spend less Python time than Python-loop generation, while first calls expose compilation and dispatch costs.",
            "Profiling artifacts were saved. CPU backend does not expose separate kernel-launch timing.",
            "useful",
            "cProfile captures host overhead, and trivial JIT timing gives a dispatch/synchronization proxy.",
        ),
    )

    memory_rows: list[dict[str, Any]] = _run_stage(
        **stage_context,
        experiment="memory_behavior",
        start_message="Running memory behavior experiments",
        failure_label="Memory behavior",
        action=lambda: run_memory_behavior(
            run_dir,
            params_for_generation,
            TransformerConfig(),
            runtimes,
            prompt_batches,
            dtype_name,
        ),
        note=(
            "memory_behavior",
            "KV-cache should add predictable cache memory while avoiding recomputation.",
            "Process RSS and estimated parameter/KV-cache memory rows were saved.",
            "partially useful",
            "CPU process RSS is coarse, but the estimated cache sizes are deterministic.",
        ),
        default=[],
    )

    batch_rows: list[dict[str, Any]] = _run_stage(
        **stage_context,
        experiment="throughput_latency_batch_scaling",
        start_message="Running throughput/latency and batch-scaling experiments",
        failure_label="Batch scaling",
        action=lambda: benchmark_generation_grid(
            run_dir,
            runtimes,
            prompt_batches,
            correctness_problems,
            config.warmup_iters,
            config.measure_iters,
        ),
        note=(
            "throughput_latency_and_batch_scaling",
            "Larger batches should improve throughput but increase batch latency.",
            "Batch scaling and efficiency rows were saved.",
            "useful",
            "The same prompt buffers and variants are used across batch sizes.",
        ),
        default=[],
    )

    _run_stage(
        **stage_context,
        experiment="prefix_recomputation_cost_analysis",
        start_message="Running prefix recomputation cost analysis",
        failure_label="Prefix analysis",
        action=lambda: run_prefix_recomputation_analysis(
            run_dir, batch_rows, config.sequence_lengths, TransformerConfig()
        ),
        note=(
            "prefix_recomputation_cost_analysis",
            "The theoretical attention work should show large wasted computation for full-prefix recomputation.",
            "Theoretical FLOP rows and measured KV-vs-full comparison were saved.",
            "useful",
            "The estimate isolates attention, while measured latency includes all model and runtime costs.",
        ),
    )

    _run_stage(
        **stage_context,
        experiment="production_inspired_simulations",
        start_message="Running production-inspired simulations",
        failure_label="Production simulation",
        action=lambda: run_production_inspired_simulations(
            run_dir, batch_rows, config.sequence_lengths, config.seed
        ),
        note=(
            "production_inspired_simulations",
            "Continuous batching, paged KV allocation, and prefix caching should expose serving tradeoffs even without production kernels.",
            "Scheduling and cache-allocation simulations were saved.",
            "article context only",
            "These are intentionally labeled simulations rather than faithful vLLM/SGLang implementations.",
        ),
    )

    _run_stage(
        **stage_context,
        experiment="cpu_gpu_crossover",
        start_message="Running CPU/GPU crossover check",
        failure_label="CPU/GPU crossover",
        action=lambda: run_cpu_gpu_crossover(run_dir, batch_rows),
        note=(
            "cpu_gpu_crossover",
            "If a GPU backend is available, large batches or long sequences may justify launch overhead.",
            "The available backend check was saved.",
            crossover_verdict,
            "This local run only had the active backend available unless environment.json reports GPU devices.",
        ),
    )

    summary = build_diagnostic_summary(run_dir, batch_rows, attention_rows, memory_rows, failures)
    write_diagnostic_notes(run_dir, summary)
    print("Saved diagnostics")
    print(f"  run dir: {run_dir}")
    if summary.get("highest_throughput_configuration"):
        best = summary["highest_throughput_configuration"]
        print(
            "  highest throughput: "
            f"{best['implementation_name']} batch={best['batch_size']} "
            f"eps={best['examples_per_second']}"
        )
    if summary.get("lowest_per_example_latency_configuration"):
        best = summary["lowest_per_example_latency_configuration"]
        print(
            "  lowest per-example latency: "
            f"{best['implementation_name']} batch={best['batch_size']} "
            f"seconds={best['per_example_latency']}"
        )
    print(f"  failures: {len(failures)}")
    return 0


if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    sys.exit(main())
