#!/usr/bin/env python3
"""Train the mixed arithmetic transformer outside the notebook.

This script is intentionally close to the Colab notebook, but adds practical
experiment logging for longer runs:

    python scripts/train_arithmetic_model.py --run-name carry_curriculum

It writes metrics, summaries, predictions, and params to
``<output-dir>/<run-name>/`` (``runs/<run-name>/`` by default). The saved params
are compatible with scripts/test_arithmetic_model.py.
"""

from __future__ import annotations

import argparse
import json
import os
import platform
import subprocess
import sys
import time
from dataclasses import asdict
from pathlib import Path
from typing import Any

try:
    import jax
    import jax.numpy as jnp
    import numpy as np
    import optax
    from flax.serialization import from_bytes, to_bytes
    from flax.training import train_state
except ModuleNotFoundError as exc:
    missing = exc.name or "a required package"
    print(
        f"Missing dependency: {missing}\n\n"
        "Install training dependencies with:\n"
        '  python -m pip install -U "jax[cuda12]" flax optax\n',
        file=sys.stderr,
    )
    raise SystemExit(1) from exc

from test_arithmetic_model import (
    ANSWER_LEN,
    CHAR_TO_ID,
    DIVMOD_OPS,
    ID_TO_CHAR,
    MAGNITUDE_WIDTH,
    OPS,
    OP_WIDTH,
    OPERAND_WIDTH,
    PROMPT_LEN,
    SEQ_LEN,
    VOCAB_SIZE,
    ArithmeticTransformer,
    TransformerConfig,
    compute_result,
    count_params,
    decode_ids,
    decode_result_field,
    encode_text,
    make_prompt_tokens,
    safe_decode_result_field,
)


# Operand ranges enumerated for the full dataset. Division and modulo omit
# b == 0, so their right-hand operands come from NONZERO_OPERAND_VALUES.
OPERAND_VALUES = range(1000)
NONZERO_OPERAND_VALUES = range(1, 1000)
EXAMPLES_PER_STANDARD_OPERATION = len(OPERAND_VALUES) * len(OPERAND_VALUES)
EXAMPLES_PER_DIVMOD_OPERATION = len(OPERAND_VALUES) * len(NONZERO_OPERAND_VALUES)
# Derived from OPS/DIVMOD_OPS instead of hard-coding "3 standard + 2 divmod"
# operations, so it always matches the rows build_full_dataset enumerates.
TOTAL_EXAMPLES = sum(
    EXAMPLES_PER_DIVMOD_OPERATION if op in DIVMOD_OPS else EXAMPLES_PER_STANDARD_OPERATION
    for op in OPS
)
# Position of each operation in OPS; used as the per-row operation id.
OP_TO_ID = {op: i for i, op in enumerate(OPS)}


def safe_command(command: list[str]) -> str:
    try:
        return subprocess.check_output(command, text=True, stderr=subprocess.STDOUT).strip()
    except Exception as exc:
        return f"unavailable: {type(exc).__name__}: {exc}"


def multiplication_carry_trace(a: int, b: int) -> list[int]:
    """Return the carry leaving each column of schoolbook multiplication ``a * b``.

    There are MAGNITUDE_WIDTH columns, least-significant first. Column ``k`` adds
    every digit product ``a[i] * b[k - i]`` to the incoming carry. The value
    recorded for the column, which is also the carry into the next column, is
    that total ``// 10``.
    """

    def digits_of(number: int) -> list[int]:
        # Least-significant digit first, zero-padded to MAGNITUDE_WIDTH digits.
        return [number // 10**power % 10 for power in range(MAGNITUDE_WIDTH)]

    lhs = digits_of(a)
    rhs = digits_of(b)
    trace: list[int] = []
    running = 0
    for col in range(MAGNITUDE_WIDTH):
        running += sum(lhs[k] * rhs[col - k] for k in range(col + 1))
        running //= 10
        trace.append(running)
    return trace


def encode_result_value(value: int, a: int, op: str, b: int, carry_width: int = 2) -> str:
    """Encode an answer as a sign followed by (digit, carry) pairs, LSB first.

    The layout is ``<sign><d0><c0><d1><c1>...``:

    * ``d_k`` is the k-th decimal digit of ``abs(value)``, counted from the
      least-significant end and zero-padded to MAGNITUDE_WIDTH digits.
    * ``c_k`` is that column's carry, zero-padded to *carry_width* digits.

    For ``op == "*"`` the carries come from ``multiplication_carry_trace(a, b)``;
    every other operation uses zeros.

    Raises:
        ValueError: if ``abs(value)`` needs more than MAGNITUDE_WIDTH digits, or a
            carry needs more than *carry_width* digits. The first would silently
            drop high digits (``zip`` truncation). The second would change the
            encoded length. Either corrupts the fixed-width sequence.
    """
    magnitude = f"{abs(value):0{MAGNITUDE_WIDTH}d}"
    if len(magnitude) > MAGNITUDE_WIDTH:
        raise ValueError(
            f"|{value}| does not fit in MAGNITUDE_WIDTH={MAGNITUDE_WIDTH} digits"
        )
    carries = multiplication_carry_trace(a, b) if op == "*" else [0] * MAGNITUDE_WIDTH
    carry_limit = 10**carry_width
    if any(carry >= carry_limit for carry in carries):
        raise ValueError(
            f"carry trace {carries} for {a}{op}{b} does not fit in {carry_width} digits"
        )
    sign = "-" if value < 0 else "+"
    chunks = [sign]
    for digit, carry in zip(reversed(magnitude), carries):
        chunks.append(digit)
        chunks.append(f"{carry:0{carry_width}d}")
    return "".join(chunks)


def rhs_values_for_op(op: str) -> range:
    """Return the right-hand operands to enumerate for *op*.

    Operations in DIVMOD_OPS skip zero (NONZERO_OPERAND_VALUES); all others use
    the full OPERAND_VALUES range.
    """
    if op in DIVMOD_OPS:
        return NONZERO_OPERAND_VALUES
    return OPERAND_VALUES


def format_example(a: int, op: str, b: int) -> str:
    """Render one fixed-width training sequence for the problem ``a op b``.

    Layout: *a* right-aligned in OPERAND_WIDTH, *op* left-aligned in OP_WIDTH,
    *b* right-aligned in OPERAND_WIDTH, then ``=`` and the encoded result.
    """
    lhs = f"{a:>{OPERAND_WIDTH}}"
    operator = f"{op:<{OP_WIDTH}}"
    rhs = f"{b:>{OPERAND_WIDTH}}"
    encoded = encode_result_value(compute_result(a, op, b), a, op, b)
    return lhs + operator + rhs + "=" + encoded


def human_readable(model_text: str) -> str:
    """Convert fixed-width model text into compact ``a<op>b=result`` form.

    Parsing is delegated to ``parse_model_text`` so the two helpers cannot drift
    apart on the field layout.

    Raises:
        ValueError: if *model_text* does not contain exactly one ``=``, or an
            operand field is not an integer.
    """
    a, op, b, result = parse_model_text(model_text)
    return f"{a}{op}{b}={result}"


def build_full_dataset() -> tuple[np.ndarray, np.ndarray]:
    """Enumerate every valid problem as encoded token rows plus operation ids.

    Rows are ordered by operation (OPS order), then ``a``, then ``b``.

    Returns:
        tokens: int32 array of shape ``(TOTAL_EXAMPLES, SEQ_LEN)``; each row is
            ``encode_text(format_example(a, op, b))``.
        op_ids: int8 array of shape ``(TOTAL_EXAMPLES,)`` holding each row's
            index into OPS.

    Raises:
        RuntimeError: if the enumeration fills fewer than TOTAL_EXAMPLES rows.
            (More rows would already fail with IndexError while filling.)
    """
    tokens = np.empty((TOTAL_EXAMPLES, SEQ_LEN), dtype=np.int32)
    op_ids = np.empty((TOTAL_EXAMPLES,), dtype=np.int8)
    row = 0
    for op_index, op in enumerate(OPS):
        for a in OPERAND_VALUES:
            for b in rhs_values_for_op(op):
                tokens[row] = encode_text(format_example(a, op, b))
                op_ids[row] = op_index
                row += 1
    # Explicit check rather than `assert`: np.empty leaves unfilled rows as
    # garbage, and asserts are stripped under `python -O`.
    if row != TOTAL_EXAMPLES:
        raise RuntimeError(
            f"Dataset enumeration produced {row} rows, expected {TOTAL_EXAMPLES}"
        )
    return tokens, op_ids


def build_edge_dataset() -> np.ndarray:
    """Build a small pool of boundary-case sequences for optional fine-tuning.

    The pool is the union of three groups:

    * every operation over all pairs of hand-picked "interesting" operands, with
      a zero right-hand side dropped for operations in DIVMOD_OPS;
    * every multiplication with one operand in 900..999, in both orders;
    * every DIVMOD_OPS problem whose divisor is one of a few special values.

    Problems are deduplicated, ordered by (operation index, a, b), and encoded
    as int32 rows of width SEQ_LEN.
    """
    interesting = sorted(
        (0, 1, 2, 3, 4, 5, 9, 10, 11, 12, 45, 99, 100, 101, 123,
         250, 499, 500, 501, 750, 900, 990, 998, 999)
    )
    special_divisors = (1, 2, 3, 4, 5, 9, 10, 99, 100, 999)
    problems: set[tuple[int, str, int]] = set()

    for op in OPS:
        rhs_pool = [v for v in interesting if v != 0] if op in DIVMOD_OPS else interesting
        problems.update((lhs, op, rhs) for lhs in interesting for rhs in rhs_pool)

    for big in range(900, 1000):
        for other in OPERAND_VALUES:
            problems.add((big, "*", other))
            problems.add((other, "*", big))

    problems.update(
        (lhs, op, divisor)
        for op in DIVMOD_OPS
        for lhs in OPERAND_VALUES
        for divisor in special_divisors
    )

    ordered = sorted(problems, key=lambda p: (OP_TO_ID[p[1]], p[0], p[2]))
    pool = np.empty((len(ordered), SEQ_LEN), dtype=np.int32)
    for idx, (lhs, op, rhs) in enumerate(ordered):
        pool[idx] = encode_text(format_example(lhs, op, rhs))
    return pool


def make_loss_mask() -> np.ndarray:
    """Return float32 loss weights for the SEQ_LEN - 1 next-token labels of a row.

    Entry ``i`` belongs to the label at sequence position ``i + 1``. Labels at
    positions >= PROMPT_LEN get weight 1.0; earlier (prompt) labels get 0.0.
    """
    weights = np.zeros(SEQ_LEN - 1, dtype=np.float32)
    weights[max(PROMPT_LEN - 1, 0) :] = 1.0
    return weights


def make_batch_from_rows(
    split_tokens: np.ndarray,
    rows: np.ndarray,
    loss_mask_template: np.ndarray,
) -> dict[str, np.ndarray]:
    """Gather *rows* of *split_tokens* into a teacher-forcing batch.

    Args:
        split_tokens: 2-D token array ``(num_examples, seq_width)``.
        rows: integer indices into *split_tokens*.
        loss_mask_template: per-label weights of length ``seq_width - 1``.

    Returns:
        A dict with three entries:

        * ``inputs``: every token but the last.
        * ``labels``: every token but the first (next-token targets).
        * ``loss_mask``: the template repeated for each gathered row, as a
          writable copy.

        The label width comes from *split_tokens* itself rather than the global
        SEQ_LEN, so this works for any sequence width that matches the template.
    """
    seq = split_tokens[rows]
    label_width = seq.shape[1] - 1
    return {
        "inputs": seq[:, :-1],
        "labels": seq[:, 1:],
        "loss_mask": np.broadcast_to(loss_mask_template, (len(seq), label_width)).copy(),
    }


def make_uniform_batch(
    rng: np.random.Generator,
    split_tokens: np.ndarray,
    batch_size: int,
    loss_mask_template: np.ndarray,
) -> dict[str, np.ndarray]:
    """Draw *batch_size* rows of *split_tokens* uniformly, with replacement.

    The sampled rows are turned into inputs/labels/loss_mask by
    ``make_batch_from_rows``.
    """
    n_available = len(split_tokens)
    sampled_rows = rng.integers(0, n_available, size=batch_size)
    return make_batch_from_rows(
        split_tokens=split_tokens,
        rows=sampled_rows,
        loss_mask_template=loss_mask_template,
    )


def make_weighted_op_batch(
    rng: np.random.Generator,
    split_tokens: np.ndarray,
    indices_by_op: list[np.ndarray],
    batch_size: int,
    op_probs: np.ndarray,
    loss_mask_template: np.ndarray,
) -> dict[str, np.ndarray]:
    """Sample a batch where each slot first picks an operation, then a row.

    Each slot's operation is drawn with probabilities *op_probs*. Its row is then
    drawn uniformly, with replacement, from that operation's entry in
    *indices_by_op*.
    """
    slot_ops = rng.choice(len(OPS), size=batch_size, p=op_probs)
    picked = np.empty(batch_size, dtype=np.int64)
    for op_index in range(len(OPS)):
        selected = slot_ops == op_index
        n_selected = int(np.count_nonzero(selected))
        if n_selected == 0:
            continue
        picked[selected] = rng.choice(
            indices_by_op[op_index], size=n_selected, replace=True
        )
    return make_batch_from_rows(split_tokens, picked, loss_mask_template)


def make_training_batch(
    rng: np.random.Generator,
    split_tokens: np.ndarray,
    indices_by_op: list[np.ndarray],
    edge_tokens: np.ndarray | None,
    batch_size: int,
    op_probs: np.ndarray,
    edge_case_prob: float,
    loss_mask_template: np.ndarray,
) -> dict[str, np.ndarray]:
    edge_count = 0
    if edge_tokens is not None and edge_case_prob > 0:
        edge_count = min(batch_size, int(round(batch_size * edge_case_prob)))
    main_count = batch_size - edge_count

    parts = []
    if main_count:
        chosen_ops = rng.choice(len(OPS), size=main_count, p=op_probs)
        rows = np.empty(main_count, dtype=np.int64)
        for op_index in range(len(OPS)):
            mask = chosen_ops == op_index
            count = int(mask.sum())
            if count:
                rows[mask] = rng.choice(indices_by_op[op_index], size=count, replace=True)
        parts.append(split_tokens[rows])
    if edge_count:
        edge_rows = rng.integers(0, len(edge_tokens), size=edge_count)
        parts.append(edge_tokens[edge_rows])

    seq = np.concatenate(parts, axis=0) if len(parts) > 1 else parts[0]
    seq = seq[rng.permutation(len(seq))]
    return {
        "inputs": seq[:, :-1],
        "labels": seq[:, 1:],
        "loss_mask": np.broadcast_to(loss_mask_template, (len(seq), SEQ_LEN - 1)).copy(),
    }


def parse_model_text(model_text: str) -> tuple[int, str, int, int]:
    """Split fixed-width model text into ``(a, op, b, decoded_result)``.

    The text left of ``=`` holds *a* in OPERAND_WIDTH characters, the operator
    in OP_WIDTH characters (surrounding spaces stripped), then *b*. The text
    right of ``=`` is passed to ``decode_result_field``.

    Raises:
        ValueError: if there is not exactly one ``=``, or an operand field is not
            an integer.
    """
    prompt, result_field = model_text.split("=")
    op_end = OPERAND_WIDTH + OP_WIDTH
    lhs = int(prompt[:OPERAND_WIDTH])
    operator = prompt[OPERAND_WIDTH:op_end].strip()
    rhs = int(prompt[op_end:])
    return lhs, operator, rhs, decode_result_field(result_field)


def random_eval_problems(
    rng: np.random.Generator,
    n_examples: int,
    op: str | None = None,
) -> list[tuple[int, str, int]]:
    """Draw *n_examples* random ``(a, op, b)`` problems with operands in 0..999.

    When *op* is None, each problem's operation is chosen uniformly from OPS.
    Operations in DIVMOD_OPS draw ``b`` from 1..999 so they never divide by zero.
    """
    candidates = [op] if op is not None else list(OPS)

    def draw_one() -> tuple[int, str, int]:
        # The draw order (operation, then a, then b) fixes the problem set for a
        # given seed, so it must stay stable.
        chosen = candidates[int(rng.integers(0, len(candidates)))]
        lhs = int(rng.integers(0, 1000))
        rhs = int(rng.integers(1 if chosen in DIVMOD_OPS else 0, 1000))
        return lhs, chosen, rhs

    return [draw_one() for _ in range(n_examples)]


def parse_op_probs(text: str) -> np.ndarray:
    """Parse ``--op-probs`` into a normalised probability vector over OPS.

    *text* must hold exactly ``len(OPS)`` comma-separated numbers that are
    finite and non-negative and have a positive sum. The returned vector is
    *text*'s values divided by that sum.

    Raises:
        argparse.ArgumentTypeError: on any malformed value, so argparse reports
            it as a usage error instead of failing later inside ``rng.choice``.
    """
    try:
        values = np.array([float(part.strip()) for part in text.split(",")], dtype=np.float64)
    except ValueError as exc:
        raise argparse.ArgumentTypeError(
            f"--op-probs must be {len(OPS)} comma-separated numbers for {OPS}"
        ) from exc
    if len(values) != len(OPS):
        raise argparse.ArgumentTypeError(
            f"--op-probs must provide {len(OPS)} values, one for each operation: {OPS}"
        )
    # float() accepts "nan"/"inf". Both slip past the sign and sum checks below
    # but yield a NaN probability vector, so reject them here.
    if not np.all(np.isfinite(values)):
        raise argparse.ArgumentTypeError("--op-probs values must be finite numbers")
    if np.any(values < 0):
        raise argparse.ArgumentTypeError("--op-probs values must be non-negative")
    total = values.sum()
    if total <= 0:
        raise argparse.ArgumentTypeError("--op-probs must have a positive sum")
    return values / total


def make_generate_fn(model: ArithmeticTransformer):
    """Return a jitted greedy decoder bound to *model*.

    The decoder, ``generate_tokens(params, prompt_tokens)``, fills positions
    PROMPT_LEN..SEQ_LEN-1 one at a time. Each position gets the argmax of the
    model's logits for the last prefix token. The model is re-run on the whole
    prefix at every step. *prompt_tokens* is presumably a
    ``(batch, SEQ_LEN)`` array whose first PROMPT_LEN columns hold the prompt
    (as built by make_prompt_tokens) — TODO confirm.
    """

    def generate_tokens(params: Any, prompt_tokens: np.ndarray):
        seq = jnp.asarray(prompt_tokens, dtype=jnp.int32)
        variables = {"params": params}
        for position in range(PROMPT_LEN, SEQ_LEN):
            next_logits = model.apply(variables, seq[:, :position])[:, -1, :]
            predicted = jnp.argmax(next_logits, axis=-1).astype(jnp.int32)
            seq = seq.at[:, position].set(predicted)
        return seq

    return jax.jit(generate_tokens)


def predict(
    params: Any,
    generate_tokens,
    problems: list[tuple[int, str, int]],
    batch_size: int,
) -> list[str]:
    """Greedy-decode *problems* in chunks of *batch_size*; return decoded texts.

    *generate_tokens* is called as ``generate_tokens(params, prompt_tokens)``
    (presumably the function returned by make_generate_fn). Each returned row is
    converted to text with ``decode_ids``, in the same order as *problems*.
    """
    decoded: list[str] = []
    for offset in range(0, len(problems), batch_size):
        chunk = problems[offset : offset + batch_size]
        rows = np.asarray(generate_tokens(params, make_prompt_tokens(chunk)))
        decoded += [decode_ids(row) for row in rows]
    return decoded


def exact_answer_accuracy(
    params: Any,
    generate_tokens,
    problems: list[tuple[int, str, int]],
    batch_size: int,
    max_examples: int = 12,
) -> dict[str, Any]:
    """Score greedy decodes of *problems* by exact match of the decoded answer.

    A problem is correct when ``safe_decode_result_field`` applied to the answer
    field equals ``str(compute_result(a, op, b))``. The answer field is
    everything after the prompt's ``=``. ``safe_decode_result_field`` is
    presumably tolerant of malformed model output — confirm.

    Args:
        max_examples: how many example records to keep. Records are taken from
            the first problems, regardless of correctness.

    Returns:
        A dict with these keys:

        * ``correct``, ``total``: counts over all problems.
        * ``accuracy``: correct / total, or 0.0 when there are no problems.
        * ``by_operation``: per-OPS ``correct``, ``total`` and ``accuracy``;
          ``accuracy`` is None for operations with no samples.
        * ``examples``: up to *max_examples* per-problem records.
    """
    generated = predict(params, generate_tokens, problems, batch_size)
    correct = 0
    by_op = {op: {"correct": 0, "total": 0} for op in OPS}
    examples: list[dict[str, Any]] = []
    for (a, op, b), model_text in zip(problems, generated):
        # partition keeps the whole answer even if the model emits a stray "=".
        # split("=")[1] would truncate the answer there (and raise IndexError
        # when there is no "=" at all).
        pred = safe_decode_result_field(model_text.partition("=")[2])
        expected = str(compute_result(a, op, b))
        is_correct = pred == expected
        correct += int(is_correct)
        by_op[op]["correct"] += int(is_correct)
        by_op[op]["total"] += 1
        if len(examples) < max_examples:
            examples.append(
                {
                    "problem": f"{a}{op}{b}",
                    "model_text": model_text,
                    "prediction": pred,
                    "expected": expected,
                    "correct": is_correct,
                }
            )
    for stats in by_op.values():
        stats["accuracy"] = (
            stats["correct"] / stats["total"] if stats["total"] else None
        )
    return {
        "correct": correct,
        "total": len(problems),
        "accuracy": correct / len(problems) if problems else 0.0,
        "by_operation": by_op,
        "examples": examples,
    }


def jsonable(value: Any) -> Any:
    """Recursively convert *value* into something ``json.dumps`` accepts.

    Conversions:

    * Paths become strings, and mapping keys are stringified.
    * Lists, tuples, sets and frozensets become lists.
    * NumPy arrays and scalars become native Python values.
    * Any other array-like exposing ``__array__`` (for example a JAX array) is
      converted through ``np.asarray``.

    Anything else is returned unchanged.
    """
    if isinstance(value, Path):
        return str(value)
    if isinstance(value, dict):
        return {str(key): jsonable(item) for key, item in value.items()}
    if isinstance(value, (list, tuple, set, frozenset)):
        return [jsonable(item) for item in value]
    if isinstance(value, np.ndarray):
        return value.tolist()
    if isinstance(value, np.generic):
        return value.item()
    if hasattr(value, "__array__"):
        # Device arrays (e.g. jax.Array) are not np.ndarray subclasses and would
        # otherwise make json.dumps raise TypeError.
        return np.asarray(value).tolist()
    return value


def save_json(path: Path, value: Any) -> None:
    """Write *value* to *path* as indented UTF-8 JSON with a trailing newline.

    *value* is first passed through ``jsonable``. The text is serialised before
    anything touches disk and written to a sibling ``<name>.tmp`` file. It is
    then moved over *path* with ``os.replace``. As a result, an interrupted or
    killed run never leaves a truncated JSON artifact (e.g. ``best_record.json``)
    behind.
    """
    text = json.dumps(jsonable(value), indent=2) + "\n"
    tmp_path = path.with_name(path.name + ".tmp")
    tmp_path.write_text(text, encoding="utf-8")
    os.replace(tmp_path, path)


def create_train_state(args: argparse.Namespace):
    """Build the model, its config, a fresh TrainState, and the LR schedule.

    Parameters are initialised from ``args.init_seed`` using a dummy
    ``(1, SEQ_LEN - 1)`` int32 input. The optimizer chain has two parts:

    * global-norm gradient clipping (``args.grad_clip``);
    * AdamW (``args.weight_decay``) with a warmup + cosine schedule.

    The schedule rises from 0 to ``args.learning_rate`` over
    ``args.warmup_steps``. It then decays toward
    ``learning_rate * end_lr_ratio``, with ``args.max_steps`` as the total
    schedule length.

    Returns:
        ``(model, config, state, lr_schedule)``.
    """
    config = TransformerConfig(
        d_model=args.d_model,
        n_heads=args.n_heads,
        n_layers=args.n_layers,
        mlp_dim=args.mlp_dim,
    )
    model = ArithmeticTransformer(config)

    init_key = jax.random.PRNGKey(args.init_seed)
    dummy_inputs = jnp.ones((1, SEQ_LEN - 1), dtype=jnp.int32)
    params = model.init(init_key, dummy_inputs)["params"]

    peak_lr = args.learning_rate
    lr_schedule = optax.warmup_cosine_decay_schedule(
        init_value=0.0,
        peak_value=peak_lr,
        warmup_steps=args.warmup_steps,
        decay_steps=args.max_steps,
        end_value=peak_lr * args.end_lr_ratio,
    )
    tx = optax.chain(
        optax.clip_by_global_norm(args.grad_clip),
        optax.adamw(learning_rate=lr_schedule, weight_decay=args.weight_decay),
    )
    state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)
    return model, config, state, lr_schedule


def _build_arg_parser() -> argparse.ArgumentParser:
    """Return the command-line interface for this training script."""
    parser = argparse.ArgumentParser(description="Train the arithmetic transformer.")
    parser.add_argument("--run-name", default=f"run_{int(time.time())}")
    parser.add_argument("--output-dir", type=Path, default=Path("runs"))
    parser.add_argument("--batch-size", type=int, default=2048)
    parser.add_argument("--eval-batch-size", type=int, default=256)
    parser.add_argument("--max-steps", type=int, default=8000)
    parser.add_argument("--min-steps", type=int, default=1500)
    parser.add_argument("--eval-every", type=int, default=500)
    parser.add_argument("--eval-exact-n", type=int, default=1000)
    parser.add_argument("--final-exact-n", type=int, default=10000)
    parser.add_argument("--target-exact-accuracy", type=float, default=0.90)
    parser.add_argument("--learning-rate", type=float, default=8e-4)
    parser.add_argument("--end-lr-ratio", type=float, default=0.1)
    parser.add_argument("--weight-decay", type=float, default=1e-4)
    parser.add_argument("--warmup-steps", type=int, default=400)
    parser.add_argument("--grad-clip", type=float, default=1.0)
    parser.add_argument("--multiplication-focus-steps", type=int, default=2500)
    parser.add_argument("--multiplication-focus-prob", type=float, default=0.70)
    parser.add_argument(
        "--op-probs",
        type=parse_op_probs,
        default=None,
        help=(
            "Optional comma-separated sampling probabilities for "
            f"{OPS}. Example: 0.05,0.10,0.20,0.30,0.35"
        ),
    )
    parser.add_argument("--seed", type=int, default=1234)
    parser.add_argument("--init-seed", type=int, default=0)
    parser.add_argument("--d-model", type=int, default=384)
    parser.add_argument("--n-heads", type=int, default=6)
    parser.add_argument("--n-layers", type=int, default=6)
    parser.add_argument("--mlp-dim", type=int, default=1536)
    parser.add_argument("--val-size", type=int, default=100_000)
    parser.add_argument("--save-every-eval", action="store_true")
    parser.add_argument("--require-gpu", action="store_true")
    parser.add_argument(
        "--init-params",
        type=Path,
        default=None,
        help="Optional msgpack params to continue training from.",
    )
    parser.add_argument(
        "--edge-case-prob",
        type=float,
        default=0.0,
        help="Fraction of each training batch drawn from a boundary-case pool.",
    )
    return parser


def _validate_args(parser: argparse.ArgumentParser, args: argparse.Namespace) -> None:
    """Reject argument values that would otherwise fail only mid-run (or after it).

    Calls ``parser.error`` (which exits with a usage message) on the first
    problem found.
    """
    if not 0.0 <= args.edge_case_prob <= 1.0:
        parser.error("--edge-case-prob must be between 0 and 1")
    # The focused distribution is only used without --op-probs and for a
    # positive number of focus steps. An out-of-range value there yields
    # negative sampling probabilities and crashes rng.choice at step 1.
    uses_focus = args.op_probs is None and args.multiplication_focus_steps > 0
    if uses_focus and not 0.0 <= args.multiplication_focus_prob <= 1.0:
        parser.error("--multiplication-focus-prob must be between 0 and 1")
    # eval_every == 0 divides by zero in the eval check. A non-positive
    # eval batch size only fails in predict(), which runs after training.
    for flag, value in (
        ("--batch-size", args.batch_size),
        ("--eval-batch-size", args.eval_batch_size),
        ("--eval-every", args.eval_every),
    ):
        if value <= 0:
            parser.error(f"{flag} must be positive")


def _collect_runtime_info() -> dict[str, Any]:
    """Snapshot interpreter, JAX backend/devices, and GPU details for the run log."""
    return {
        "timestamp": time.strftime("%Y-%m-%d %H:%M:%S %Z"),
        "python": platform.python_version(),
        "platform": platform.platform(),
        "jax": jax.__version__,
        "backend": jax.default_backend(),
        "devices": [str(device) for device in jax.devices()],
        "cuda_visible_devices": os.environ.get("CUDA_VISIBLE_DEVICES"),
        "nvidia_smi": safe_command(["nvidia-smi"]),
    }


def _check_example_encoding() -> None:
    """Sanity-check that known problems encode to SEQ_LEN and round-trip.

    Raises:
        RuntimeError: on a wrong length or a round-trip mismatch. An explicit
            raise is used instead of ``assert`` so the check still runs under
            ``python -O``.
    """
    for a, op, b in [
        (123, "+", 45),
        (45, "-", 123),
        (123, "*", 45),
        (5, "//", 4),
        (5, "%", 4),
        (999, "*", 999),
    ]:
        text = format_example(a, op, b)
        if len(text) != SEQ_LEN:
            raise RuntimeError(
                f"Encoded example has length {len(text)}, expected {SEQ_LEN}: {text!r}"
            )
        expected_suffix = f"={compute_result(a, op, b)}"
        if not human_readable(text).endswith(expected_suffix):
            raise RuntimeError(
                f"Encoding round-trip failed for {a}{op}{b}: {human_readable(text)!r}"
            )


def _train(args: argparse.Namespace, run_dir: Path, log_event: Any) -> int:
    """Run the full experiment for already-validated *args*.

    *log_event* is called with one JSON-serialisable dict per event (runtime,
    dataset, model, edge_dataset, eval, early_stop, final). Artifacts are
    written under *run_dir*. Returns 0.

    Raises:
        RuntimeError: if --require-gpu is set and JAX's default backend is not
            "gpu", or if the encoding self-check fails.
    """
    runtime_info = _collect_runtime_info()
    save_json(run_dir / "runtime.json", runtime_info)
    log_event({"event": "runtime", **runtime_info})
    if args.require_gpu and jax.default_backend() != "gpu":
        raise RuntimeError(f"Expected GPU backend, got {jax.default_backend()}")

    _check_example_encoding()

    dataset_start = time.perf_counter()
    all_tokens, op_ids = build_full_dataset()
    dataset_seconds = time.perf_counter() - dataset_start

    # The split has its own fixed seed, so the validation set is identical for
    # every run regardless of --seed.
    split_rng = np.random.default_rng(42)
    indices = split_rng.permutation(len(all_tokens))
    val_size = min(args.val_size, len(all_tokens) // 10)
    val_indices, train_indices = indices[:val_size], indices[val_size:]
    val_tokens = all_tokens[val_indices]
    train_tokens = all_tokens[train_indices]
    val_ops = op_ids[val_indices]
    train_ops = op_ids[train_indices]
    train_indices_by_op = [np.flatnonzero(train_ops == op_index) for op_index in range(len(OPS))]
    val_indices_by_op = [np.flatnonzero(val_ops == op_index) for op_index in range(len(OPS))]
    loss_mask_template = make_loss_mask()

    dataset_info = {
        "operations": list(OPS),
        "total_examples": int(len(all_tokens)),
        "unique_valid_examples": int(TOTAL_EXAMPLES),
        "train_examples": int(len(train_tokens)),
        "validation_examples": int(len(val_tokens)),
        "train_examples_by_operation": {
            op: int(len(train_indices_by_op[i])) for i, op in enumerate(OPS)
        },
        "validation_examples_by_operation": {
            op: int(len(val_indices_by_op[i])) for i, op in enumerate(OPS)
        },
        "sequence_length": SEQ_LEN,
        "prompt_length": PROMPT_LEN,
        "answer_length": ANSWER_LEN,
        "vocab_size": VOCAB_SIZE,
        "vocab": "".join(ID_TO_CHAR[i] for i in range(VOCAB_SIZE)),
        "dataset_build_seconds": round(dataset_seconds, 2),
        "dataset_memory_mb": round(all_tokens.nbytes / 1e6, 2),
        "division_modulo_zero_policy": "b=0 is omitted for // and %",
    }
    save_json(run_dir / "dataset.json", dataset_info)
    log_event({"event": "dataset", **dataset_info})

    model, config, state, lr_schedule = create_train_state(args)
    if args.init_params is not None:
        # Only parameters are restored; the TrainState (optimizer state, step
        # count) comes fresh from create_train_state.
        state = state.replace(
            params=from_bytes(state.params, args.init_params.expanduser().read_bytes())
        )
    param_count = count_params(state.params)
    model_info = {
        **asdict(config),
        "parameter_count": param_count,
        "parameter_count_millions": round(param_count / 1e6, 3),
    }
    save_json(run_dir / "model.json", model_info)
    save_json(run_dir / "hyperparameters.json", vars(args))
    log_event({"event": "model", **model_info})

    def loss_and_metrics(params: Any, batch: dict[str, jnp.ndarray]):
        logits = model.apply({"params": params}, batch["inputs"])
        per_token_loss = optax.softmax_cross_entropy_with_integer_labels(
            logits, batch["labels"]
        )
        mask = batch["loss_mask"]
        denom = jnp.maximum(mask.sum(), 1.0)
        loss = (per_token_loss * mask).sum() / denom
        predictions = jnp.argmax(logits, axis=-1)
        token_accuracy = ((predictions == batch["labels"]) * mask).sum() / denom
        return loss, {"loss": loss, "token_accuracy": token_accuracy}

    @jax.jit
    def train_step(state, batch):
        # optax evaluates the schedule at the optimizer's pre-update count,
        # which equals state.step *before* apply_gradients. Reading it after
        # the update would log the next step's rate instead.
        learning_rate = lr_schedule(state.step)
        (_, metrics), grads = jax.value_and_grad(loss_and_metrics, has_aux=True)(
            state.params, batch
        )
        state = state.apply_gradients(grads=grads)
        metrics = dict(metrics)
        metrics["learning_rate"] = learning_rate
        return state, metrics

    @jax.jit
    def eval_step(state, batch):
        _, metrics = loss_and_metrics(state.params, batch)
        return metrics

    generate_tokens = make_generate_fn(model)
    batch_rng = np.random.default_rng(args.seed)
    focused_op_probs = np.array(
        [
            (1.0 - args.multiplication_focus_prob) / (len(OPS) - 1)
            if op != "*"
            else args.multiplication_focus_prob
            for op in OPS
        ],
        dtype=np.float64,
    )
    balanced_op_probs = np.ones(len(OPS), dtype=np.float64) / len(OPS)
    explicit_op_probs = args.op_probs
    edge_tokens = build_edge_dataset() if args.edge_case_prob > 0 else None
    if edge_tokens is not None:
        log_event(
            {
                "event": "edge_dataset",
                "edge_examples": int(len(edge_tokens)),
                "edge_case_prob": args.edge_case_prob,
            }
        )
    best_accuracy = -1.0
    best_step = 0
    history: list[dict[str, Any]] = []
    train_start = time.perf_counter()

    for step in range(1, args.max_steps + 1):
        # Explicit --op-probs wins; otherwise multiplication is emphasised for
        # the first --multiplication-focus-steps steps, then sampling is uniform.
        op_probs = (
            explicit_op_probs
            if explicit_op_probs is not None
            else focused_op_probs
            if step <= args.multiplication_focus_steps
            else balanced_op_probs
        )
        batch_np = make_training_batch(
            batch_rng,
            train_tokens,
            train_indices_by_op,
            edge_tokens,
            args.batch_size,
            op_probs,
            args.edge_case_prob,
            loss_mask_template,
        )
        batch = {name: jnp.asarray(value) for name, value in batch_np.items()}
        state, train_metrics = train_step(state, batch)

        if step == 1 or step % args.eval_every == 0:
            # NOTE(review): the validation batch draws from batch_rng, so the
            # training batch sequence depends on --eval-every. Kept as-is to
            # reproduce earlier runs.
            val_batch_np = make_uniform_batch(
                batch_rng, val_tokens, args.batch_size, loss_mask_template
            )
            val_batch = {name: jnp.asarray(value) for name, value in val_batch_np.items()}
            val_metrics = eval_step(state, val_batch)
            train_metrics = jax.device_get(train_metrics)
            val_metrics = jax.device_get(val_metrics)

            eval_rng = np.random.default_rng(100_000 + step)
            eval_problems = random_eval_problems(eval_rng, args.eval_exact_n)
            exact = exact_answer_accuracy(
                state.params, generate_tokens, eval_problems, args.eval_batch_size
            )
            elapsed = time.perf_counter() - train_start
            record = {
                "event": "eval",
                "step": step,
                "elapsed_seconds": round(elapsed, 2),
                "train_loss": float(train_metrics["loss"]),
                "train_token_accuracy": float(train_metrics["token_accuracy"]),
                "val_loss": float(val_metrics["loss"]),
                "val_token_accuracy": float(val_metrics["token_accuracy"]),
                "exact_accuracy": exact["accuracy"],
                "exact_correct": exact["correct"],
                "exact_total": exact["total"],
                "exact_by_operation": exact["by_operation"],
                "learning_rate": float(train_metrics["learning_rate"]),
                "op_sampling": {op: float(op_probs[i]) for i, op in enumerate(OPS)},
            }
            history.append(record)
            log_event(record)
            save_json(run_dir / "latest_examples.json", exact["examples"])

            if exact["accuracy"] > best_accuracy:
                best_accuracy = exact["accuracy"]
                best_step = step
                (run_dir / "best_params.msgpack").write_bytes(to_bytes(state.params))
                save_json(run_dir / "best_record.json", record)
            if args.save_every_eval:
                (run_dir / f"params_step_{step:05d}.msgpack").write_bytes(
                    to_bytes(state.params)
                )
            if step >= args.min_steps and exact["accuracy"] >= args.target_exact_accuracy:
                log_event(
                    {
                        "event": "early_stop",
                        "step": step,
                        "accuracy": exact["accuracy"],
                    }
                )
                break

    train_seconds = time.perf_counter() - train_start
    final_rng = np.random.default_rng(2026)
    final_problems = random_eval_problems(final_rng, args.final_exact_n)
    final_exact = exact_answer_accuracy(
        state.params, generate_tokens, final_problems, args.eval_batch_size
    )
    sample_problems = [
        (0, "+", 0),
        (7, "+", 35),
        (123, "-", 45),
        (45, "-", 123),
        (12, "*", 89),
        (123, "*", 45),
        (5, "//", 4),
        (999, "//", 1),
        (5, "%", 4),
        (998, "%", 999),
        (501, "*", 499),
        (999, "*", 999),
    ]
    sample_predictions = exact_answer_accuracy(
        state.params, generate_tokens, sample_problems, args.eval_batch_size
    )["examples"]

    final_path = run_dir / "arithmetic_transformer_params.msgpack"
    final_path.write_bytes(to_bytes(state.params))
    summary = {
        "run_name": args.run_name,
        "training_seconds": round(train_seconds, 2),
        "best_step": best_step,
        "best_eval_exact_accuracy": best_accuracy,
        "final_exact_accuracy": final_exact["accuracy"],
        "final_exact_correct": final_exact["correct"],
        "final_exact_total": final_exact["total"],
        "final_exact_by_operation": final_exact["by_operation"],
        "model_path": str(final_path),
        "best_model_path": str(run_dir / "best_params.msgpack"),
        "runtime": runtime_info,
        "dataset": dataset_info,
        "model": model_info,
        "hyperparameters": vars(args),
        "sample_predictions": sample_predictions,
        "history": history,
    }
    save_json(run_dir / "summary.json", summary)
    save_json(run_dir / "final_examples.json", final_exact["examples"])
    log_event({"event": "final", **summary})
    return 0


def main() -> int:
    """Parse arguments, train, evaluate, and write all run artifacts.

    Artifacts go to ``<output-dir>/<run-name>/``:

    * ``metrics.jsonl`` (appended, one JSON object per event);
    * ``runtime.json``, ``dataset.json``, ``model.json``,
      ``hyperparameters.json``;
    * ``latest_examples.json``, ``best_params.msgpack``, ``best_record.json``,
      and optional ``params_step_*.msgpack`` checkpoints;
    * ``arithmetic_transformer_params.msgpack``, ``summary.json``,
      ``final_examples.json``.

    Returns 0, used as the process exit status.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()
    _validate_args(parser, args)

    run_dir = args.output_dir / args.run_name
    run_dir.mkdir(parents=True, exist_ok=True)
    metrics_path = run_dir / "metrics.jsonl"
    # `with` guarantees metrics.jsonl is closed even when training raises
    # (e.g. a --require-gpu failure, OOM, or KeyboardInterrupt).
    with metrics_path.open("a", encoding="utf-8") as metrics_file:

        def log_event(event: dict[str, Any]) -> None:
            line = json.dumps(jsonable(event), sort_keys=True)
            print(line, flush=True)
            metrics_file.write(line + "\n")
            metrics_file.flush()

        return _train(args, run_dir, log_event)


if __name__ == "__main__":
    # main() returns the process exit status; sys.exit raises SystemExit with it.
    sys.exit(main())
