#!/usr/bin/env python3
"""Run local inference for the Colab-trained arithmetic transformer.

Install the runtime dependencies on macOS with:

    python3 -m pip install "jax[cpu]" flax

Example:

    python3 scripts/test_arithmetic_model.py '123+45' '45-123' '123*45' '5//4' '5%4' --check-random 1000
"""

from __future__ import annotations

import argparse
import re
import sys
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any

try:
    import jax
    import jax.numpy as jnp
    import numpy as np
    from flax import linen as nn
    from flax.serialization import from_bytes
except ModuleNotFoundError as exc:
    missing = exc.name or "a required package"
    print(
        f"Missing dependency: {missing}\n\n"
        "Install local inference dependencies with:\n"
        '  python3 -m pip install "jax[cpu]" flax\n',
        file=sys.stderr,
    )
    raise SystemExit(1) from exc


# Character vocabulary: a token's id is its index in this string.
VOCAB_CHARS = "0123456789 +-*%/="
CHAR_TO_ID = {ch: i for i, ch in enumerate(VOCAB_CHARS)}
ID_TO_CHAR = {i: ch for ch, i in CHAR_TO_ID.items()}
# Supported operators. "//" has no token of its own; it is encoded as two "/" tokens.
OPS = ("+", "-", "*", "//", "%")
# Operators whose right operand (the divisor) must be non-zero.
DIVMOD_OPS = ("//", "%")

VOCAB_SIZE = len(VOCAB_CHARS)
# Prompt layout, e.g. "123+  45=": left operand right-aligned to OPERAND_WIDTH,
# operator left-aligned and space-padded to OP_WIDTH (so "+" and "//" take the
# same room), right operand right-aligned to OPERAND_WIDTH, then "=".
OPERAND_WIDTH = 3
OP_WIDTH = 2
# Answer layout: one sign character followed by MAGNITUDE_WIDTH chunks. Each
# chunk is one result digit (least significant digit first) plus CARRY_WIDTH
# extra characters that decode_result_field does not read (presumably carry /
# scratch tokens the model was trained to emit -- confirm against the notebook).
# Six digits cover the largest product, 999 * 999 = 998001.
MAGNITUDE_WIDTH = 6
CARRY_WIDTH = 2
ANSWER_CHUNK_WIDTH = 1 + CARRY_WIDTH
PROMPT_LEN = OPERAND_WIDTH + OP_WIDTH + OPERAND_WIDTH + 1  # 9 characters, ending in "="
ANSWER_LEN = 1 + MAGNITUDE_WIDTH * ANSWER_CHUNK_WIDTH  # 19 characters
SEQ_LEN = PROMPT_LEN + ANSWER_LEN  # 28 tokens: prompt followed by the generated answer

# Default checkpoint: <parent of this script's directory>/models/arithmetic_transformer_params.msgpack
DEFAULT_MODEL_PATH = (
    Path(__file__).resolve().parents[1] / "models" / "arithmetic_transformer_params.msgpack"
)


@dataclass
class TransformerConfig:
    """Hyperparameters for ArithmeticTransformer.

    load_params restores the checkpoint into parameters initialised from this
    config, so these defaults presumably must match the configuration the
    checkpoint was trained with.
    """

    vocab_size: int = VOCAB_SIZE  # embedding rows / output logits, one per vocabulary character
    max_seq_len: int = SEQ_LEN  # rows in the learned position-embedding table
    d_model: int = 384  # residual-stream width
    n_heads: int = 6  # attention heads; CausalSelfAttention asserts this divides d_model
    n_layers: int = 6  # number of stacked TransformerBlocks
    mlp_dim: int = 1536  # hidden width of each block's MLP


def normal_init(stddev: float = 0.02):
    """Return a flax normal-distribution initializer with standard deviation ``stddev``.

    Used for every Dense kernel, the token embedding and the position
    embedding of the model in this file.
    """
    initializer = nn.initializers.normal(stddev=stddev)
    return initializer


class CausalSelfAttention(nn.Module):
    """Multi-head causal (masked) self-attention with bias-free projections.

    NOTE(review): the Dense submodules below are unnamed, so flax presumably
    names them by creation order; reordering them would likely stop an
    existing trained checkpoint from loading.
    """

    d_model: int  # input feature width and output width
    n_heads: int  # number of attention heads; must divide d_model (asserted below)

    @nn.compact
    def __call__(self, x):
        # x must be rank 3 with last axis d_model: (batch, seq_len, d_model).
        # NOTE: these asserts are skipped when Python runs with -O.
        batch, seq_len, width = x.shape
        assert width == self.d_model
        assert self.d_model % self.n_heads == 0
        head_dim = self.d_model // self.n_heads

        # One fused projection yields queries, keys and values, split on the feature axis.
        qkv = nn.Dense(3 * self.d_model, use_bias=False, kernel_init=normal_init())(x)
        q, k, v = jnp.split(qkv, 3, axis=-1)

        def split_heads(tensor):
            # (batch, seq_len, d_model) -> (batch, n_heads, seq_len, head_dim)
            tensor = tensor.reshape(batch, seq_len, self.n_heads, head_dim)
            return tensor.transpose(0, 2, 1, 3)

        q = split_heads(q)
        k = split_heads(k)
        v = split_heads(v)

        # Scaled dot-product scores: (batch, n_heads, query_pos, key_pos).
        scale = head_dim**-0.5
        scores = jnp.einsum("bhqd,bhkd->bhqk", q, k) * scale
        # Lower-triangular mask: query position i may only attend to key positions <= i.
        causal_mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))[None, None, :, :]
        # Masked-out scores get a large negative value so softmax gives them ~zero weight.
        scores = jnp.where(causal_mask, scores, -1e10)
        weights = nn.softmax(scores, axis=-1)

        # Weighted sum of values, then merge heads back to (batch, seq_len, d_model).
        attended = jnp.einsum("bhqk,bhkd->bhqd", weights, v)
        attended = attended.transpose(0, 2, 1, 3).reshape(batch, seq_len, self.d_model)
        # Output projection back to the residual-stream width.
        return nn.Dense(self.d_model, use_bias=False, kernel_init=normal_init())(attended)


class TransformerBlock(nn.Module):
    """Pre-LayerNorm transformer block: causal attention, then an MLP, each with a residual add.

    Input and output both have shape (batch, seq_len, d_model).

    NOTE(review): submodules are unnamed, so their checkpoint names presumably
    depend on creation order; keep that order stable.
    """

    config: TransformerConfig  # supplies d_model, n_heads and mlp_dim

    @nn.compact
    def __call__(self, x):
        # Attention sub-layer: normalize, attend, add back into the residual stream.
        x = x + CausalSelfAttention(self.config.d_model, self.config.n_heads)(
            nn.LayerNorm()(x)
        )
        # MLP sub-layer: normalize, expand to mlp_dim, approximate GELU,
        # project back to d_model, add into the residual stream.
        y = nn.LayerNorm()(x)
        y = nn.Dense(self.config.mlp_dim, use_bias=False, kernel_init=normal_init())(y)
        y = nn.gelu(y, approximate=True)
        y = nn.Dense(self.config.d_model, use_bias=False, kernel_init=normal_init())(y)
        return x + y


class ArithmeticTransformer(nn.Module):
    """Decoder-only character transformer that returns next-token logits.

    Takes integer token ids of shape (batch, seq_len) and returns logits of
    shape (batch, seq_len, vocab_size). seq_len must not exceed
    config.max_seq_len, because the position-embedding table is sliced to it.
    """

    config: TransformerConfig  # model hyperparameters

    @nn.compact
    def __call__(self, input_ids):
        seq_len = input_ids.shape[1]
        token_embed = nn.Embed(
            num_embeddings=self.config.vocab_size,
            features=self.config.d_model,
            embedding_init=normal_init(),
            name="token_embedding",
        )(input_ids)
        # Learned absolute position embeddings, one row per position up to max_seq_len.
        pos_embed = self.param(
            "position_embedding",
            normal_init(),
            (self.config.max_seq_len, self.config.d_model),
        )
        # Add the first seq_len position rows, broadcast over the batch axis.
        x = token_embed + pos_embed[None, :seq_len, :]

        for _ in range(self.config.n_layers):
            x = TransformerBlock(self.config)(x)

        # Final normalization, then a bias-free projection to vocabulary logits
        # (a separate weight from token_embedding, i.e. not tied).
        x = nn.LayerNorm()(x)
        return nn.Dense(
            self.config.vocab_size,
            use_bias=False,
            kernel_init=normal_init(),
            name="lm_head",
        )(x)


def compute_result(a: int, op: str, b: int) -> int:
    """Return the exact value of ``a op b`` using Python integer semantics.

    ``//`` and ``%`` follow Python's floor-division rules.

    Raises:
        ZeroDivisionError: If ``op`` is ``//`` or ``%`` and ``b`` is zero.
        ValueError: If ``op`` is not a supported operator.
    """
    zero_divisor_messages = {
        "//": "integer division by zero",
        "%": "modulo by zero",
    }
    if op in zero_divisor_messages and b == 0:
        raise ZeroDivisionError(zero_divisor_messages[op])

    evaluators = {
        "+": lambda x, y: x + y,
        "-": lambda x, y: x - y,
        "*": lambda x, y: x * y,
        "//": lambda x, y: x // y,
        "%": lambda x, y: x % y,
    }
    evaluate = evaluators.get(op)
    if evaluate is None:
        raise ValueError(f"unknown operator: {op}")
    return evaluate(a, b)


def count_params(params: Any) -> int:
    """Return the total number of scalar entries across every array leaf in ``params``."""
    total = 0
    for leaf in jax.tree_util.tree_leaves(params):
        total += leaf.size
    return int(total)


def encode_text(text: str) -> np.ndarray:
    """Map each character of ``text`` to its vocabulary id, as a 1-D int32 array.

    Raises KeyError for a character that has no entry in CHAR_TO_ID.
    """
    ids = (CHAR_TO_ID[ch] for ch in text)
    return np.fromiter(ids, dtype=np.int32, count=len(text))


def decode_ids(ids) -> str:
    """Turn a sequence of token ids back into text (the inverse of encode_text).

    Each id goes through ``int`` first, so numpy integer scalars are accepted.
    """
    chars = [ID_TO_CHAR[int(token_id)] for token_id in ids]
    return "".join(chars)


def decode_result_field(field: str) -> int:
    """Convert a generated answer field back into an integer.

    The field is ANSWER_LEN characters: a sign character (``+`` or ``-``)
    followed by MAGNITUDE_WIDTH chunks of ANSWER_CHUNK_WIDTH characters. Only
    the first character of each chunk is read, as a result digit, least
    significant digit first; the other characters of each chunk are not
    inspected.

    Raises:
        ValueError: If the field's length, sign, or any digit character is malformed.
    """
    # Check the length before indexing so an empty or truncated field raises
    # ValueError (like every other malformed input) instead of IndexError.
    if len(field) != ANSWER_LEN:
        raise ValueError(f"bad result length: {field!r}")
    sign = field[0]
    if sign not in ("+", "-"):
        raise ValueError(f"bad result sign: {field!r}")
    # With the length fixed, every ANSWER_CHUNK_WIDTH-th character after the
    # sign is a digit (exactly MAGNITUDE_WIDTH of them); reverse so the most
    # significant digit comes first.
    digits = field[1::ANSWER_CHUNK_WIDTH][::-1]
    # str.isdigit alone also accepts non-ASCII digits such as superscripts.
    if not (digits.isascii() and digits.isdigit()):
        raise ValueError(f"bad result digits: {field!r}")
    value = int(digits)
    return -value if sign == "-" else value


def safe_decode_result_field(field: str) -> str:
    """Decode ``field`` to a decimal string, or return an ``<invalid:...>`` marker.

    Any exception raised while decoding is turned into the marker, so a
    malformed model output is reported instead of aborting the run.
    """
    try:
        rendered = str(decode_result_field(field))
    except Exception:
        rendered = f"<invalid:{field}>"
    return rendered


def parse_problem(problem: str) -> tuple[int, str, int]:
    """Parse a command-line problem such as ``"123+45"`` into ``(a, op, b)``.

    Whitespace around the operands and operator is allowed. Used as an
    argparse ``type=`` callback, so every rejection is an ArgumentTypeError.

    Raises:
        argparse.ArgumentTypeError: On malformed text, out-of-range operands,
            or a zero divisor for ``//`` / ``%``.
    """
    found = re.fullmatch(r"\s*(\d{1,3})\s*(//|[+\-*%])\s*(\d{1,3})\s*", problem)
    if found is None:
        raise argparse.ArgumentTypeError(
            f"{problem!r} is not valid. Use a form like 123+45, 45-123, 123*45, 5//4, or 5%4."
        )
    left_text, op, right_text = found.groups()
    a = int(left_text)
    b = int(right_text)
    if a not in range(1000) or b not in range(1000):
        raise argparse.ArgumentTypeError("operands must be between 0 and 999")
    if b == 0 and op in DIVMOD_OPS:
        raise argparse.ArgumentTypeError(f"{op} by zero is undefined; use a divisor from 1 to 999")
    return a, op, b


def make_prompt_tokens(problems: list[tuple[int, str, int]]) -> np.ndarray:
    """Build a (len(problems), SEQ_LEN) int32 batch holding each encoded prompt.

    Row i begins with the PROMPT_LEN-character prompt for ``problems[i]``
    (e.g. ``"123+  45="``); every remaining position holds the space token
    and is overwritten during generation.
    """
    space_id = CHAR_TO_ID[" "]
    batch = np.full((len(problems), SEQ_LEN), space_id, dtype=np.int32)
    for row, (left, op_sym, right) in enumerate(problems):
        prompt_text = f"{left:>{OPERAND_WIDTH}}{op_sym:<{OP_WIDTH}}{right:>{OPERAND_WIDTH}}="
        batch[row, :PROMPT_LEN] = encode_text(prompt_text)
    return batch


def make_generate_fn(model: ArithmeticTransformer, use_jit: bool):
    """Build a greedy decoder that fills positions PROMPT_LEN..SEQ_LEN-1 one token at a time.

    The returned callable takes ``(params, prompt_tokens)`` and returns the
    completed int32 token array. At each step the model sees only the
    already-filled prefix, and the arg-max of the last position's logits is
    written into the next slot. Optionally wrapped in ``jax.jit``.
    """

    def generate_tokens(params, prompt_tokens):
        sequence = jnp.asarray(prompt_tokens, dtype=jnp.int32)
        position = PROMPT_LEN
        while position < SEQ_LEN:
            step_logits = model.apply({"params": params}, sequence[:, :position])
            greedy_ids = step_logits[:, -1, :].argmax(axis=-1).astype(jnp.int32)
            sequence = sequence.at[:, position].set(greedy_ids)
            position += 1
        return sequence

    if use_jit:
        return jax.jit(generate_tokens)
    return generate_tokens


def predict(
    params: Any,
    model: ArithmeticTransformer,
    problems: list[tuple[int, str, int]],
    *,
    batch_size: int,
    use_jit: bool,
) -> list[str]:
    """Greedily decode the full output text for every problem.

    Args:
        params: Parameter tree for ``model`` (e.g. as returned by ``load_params``).
        model: The transformer to run.
        problems: ``(a, op, b)`` triples, already validated.
        batch_size: Maximum number of problems decoded per generator call.
        use_jit: Whether to compile the generation loop with ``jax.jit``.

    Returns:
        One SEQ_LEN-character string per problem, in input order: the prompt
        followed by the generated answer field.
    """
    generate_tokens = make_generate_fn(model, use_jit)
    n_problems = len(problems)
    # A jitted function is re-traced and re-compiled for every new input
    # shape, so a shorter final batch would pay for a second compilation of
    # the whole unrolled generation loop. Padding every batch to one common
    # height lets all batches reuse a single compilation. Without JIT there
    # is nothing to save, so no padding is done.
    padded_rows = min(batch_size, n_problems) if use_jit else 0
    outputs: list[str] = []
    for start in range(0, n_problems, batch_size):
        batch_problems = problems[start : start + batch_size]
        prompt_tokens = make_prompt_tokens(batch_problems)
        n_rows = len(batch_problems)
        if n_rows < padded_rows:
            # Filler rows repeat the last real prompt. The model processes
            # each row independently, and filler outputs are dropped below.
            filler = np.repeat(prompt_tokens[-1:], padded_rows - n_rows, axis=0)
            prompt_tokens = np.concatenate([prompt_tokens, filler], axis=0)
        generated = np.asarray(generate_tokens(params, prompt_tokens))
        outputs.extend(decode_ids(row) for row in generated[:n_rows])
    return outputs


def _check_param_shapes(expected: Any, loaded: Any) -> None:
    """Raise ValueError unless ``loaded`` has as many array leaves as ``expected``, with equal shapes.

    Both trees are flattened with ``jax.tree_util.tree_leaves``; for the same
    pytree structure, leaf ``i`` of one corresponds to leaf ``i`` of the other.
    """
    expected_leaves = jax.tree_util.tree_leaves(expected)
    loaded_leaves = jax.tree_util.tree_leaves(loaded)
    if len(expected_leaves) != len(loaded_leaves):
        raise ValueError(
            f"checkpoint has {len(loaded_leaves)} parameter arrays, "
            f"model expects {len(expected_leaves)}"
        )
    for index, (want, got) in enumerate(zip(expected_leaves, loaded_leaves)):
        if np.shape(want) != np.shape(got):
            raise ValueError(
                f"parameter array {index} has shape {np.shape(got)}, "
                f"model expects {np.shape(want)}"
            )


def load_params(model_path: Path, model: ArithmeticTransformer):
    """Load trained parameters from ``model_path`` into ``model``'s parameter structure.

    Args:
        model_path: Path to a flax msgpack params file.
        model: Model whose freshly initialised parameters serve as the template.

    Returns:
        The restored parameter tree, verified by one forward pass.

    Raises:
        FileNotFoundError: If ``model_path`` does not exist.
        RuntimeError: If the file cannot be deserialised, its parameter shapes
            do not match ``model``, or the forward pass fails. The message
            includes the underlying error.
    """
    if not model_path.exists():
        raise FileNotFoundError(
            f"Model file not found: {model_path}\n"
            "Train the mixed arithmetic notebook first, then place the downloaded params at "
            "models/arithmetic_transformer_params.msgpack."
        )

    # A freshly initialised tree is the template that from_bytes fills in.
    variables = model.init(
        jax.random.PRNGKey(0),
        jnp.ones((1, SEQ_LEN - 1), dtype=jnp.int32),
    )
    template = variables["params"]
    try:
        params = from_bytes(template, model_path.read_bytes())
        # flax's from_bytes does not appear to compare stored array shapes
        # with the template, so a checkpoint from a differently sized config
        # (e.g. a larger vocabulary or position table) could otherwise load
        # silently and misbehave later.
        _check_param_shapes(template, params)
        # Smoke test: one forward pass on a prompt-length input.
        model.apply({"params": params}, jnp.ones((1, PROMPT_LEN), dtype=jnp.int32))
    except Exception as exc:
        # Include the cause in the message: callers may print only str(error).
        raise RuntimeError(
            f"Could not load {model_path} into the mixed arithmetic model. "
            "If this is the old addition-only or three-operation params file, retrain the updated notebook "
            "and download arithmetic_transformer_params.msgpack instead.\n"
            f"Details: {exc}"
        ) from exc
    return params


def print_predictions(problems: list[tuple[int, str, int]], generated_texts: list[str]) -> None:
    """Print a table comparing each decoded model output with the exact answer.

    Args:
        problems: ``(a, op, b)`` triples, in the same order as ``generated_texts``.
        generated_texts: Decoded model outputs (prompt followed by the answer
            field), e.g. from ``predict``.
    """
    # Size the "model text" column from the data so header and rows line up:
    # the repr of a full model output is wider than a fixed 16-character column.
    text_width = max([len("model text"), *(len(repr(text)) for text in generated_texts)])
    header = (
        f"{'problem':>10} | {'model text':>{text_width}} | "
        f"{'pred':>8} | {'expected':>8} | correct"
    )
    print(header)
    print("-" * len(header))
    for (a, op, b), model_text in zip(problems, generated_texts):
        # The answer field is the text after the prompt's "=" (up to any later "=").
        prediction = safe_decode_result_field(model_text.split("=")[1])
        expected = str(compute_result(a, op, b))
        problem = f"{a}{op}{b}"
        print(
            f"{problem:>10} | {model_text!r:>{text_width}} | "
            f"{prediction:>8} | {expected:>8} | {prediction == expected}"
        )


def random_problems(n_examples: int, seed: int) -> list[tuple[int, str, int]]:
    """Draw ``n_examples`` reproducible random problems from ``seed``.

    Operands are drawn uniformly from [0, 999] and operators uniformly from
    OPS. A zero divisor drawn for ``//`` or ``%`` is replaced by 1, so every
    problem is defined.
    """
    generator = np.random.default_rng(seed)
    # The draw order (left operands, right operands, operators) fixes the
    # sequence produced for a given seed.
    lefts = generator.integers(0, 1000, size=n_examples)
    rights = generator.integers(0, 1000, size=n_examples)
    op_indices = generator.integers(0, len(OPS), size=n_examples)

    def pick_divisor(op: str, value: int) -> int:
        return 1 if value == 0 and op in DIVMOD_OPS else value

    chosen_ops = [OPS[int(index)] for index in op_indices]
    return [
        (int(left), op, pick_divisor(op, int(right)))
        for left, op, right in zip(lefts, chosen_ops, rights)
    ]


def _build_arg_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for this script.

    argparse %-formats help strings when rendering ``--help``, so a literal
    percent sign must be written ``%%``. The default model path is shown via
    ``%(default)s`` rather than an f-string, so a ``%`` in the path cannot
    break help output either.
    """
    parser = argparse.ArgumentParser(
        description="Test the Colab-trained character transformer on local arithmetic prompts."
    )
    parser.add_argument(
        "problems",
        nargs="*",
        type=parse_problem,
        help="Problems like 123+45, 45-123, 123*45, 5//4, or 5%%4.",
    )
    parser.add_argument(
        "--model",
        type=Path,
        default=DEFAULT_MODEL_PATH,
        help="Path to msgpack params. Default: %(default)s",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=256,
        help="Batch size for local generation.",
    )
    parser.add_argument(
        "--check-random",
        type=int,
        default=0,
        metavar="N",
        help="Also test exact accuracy on N random arithmetic problems.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=0,
        help="Random seed for --check-random.",
    )
    parser.add_argument(
        "--no-jit",
        action="store_true",
        help="Disable JAX JIT. Useful for debugging, usually slower.",
    )
    return parser


def _report_random_accuracy(
    params: Any, model: ArithmeticTransformer, args: argparse.Namespace
) -> None:
    """Decode ``args.check_random`` seeded random problems and print exact-match accuracy.

    Prints overall accuracy (the elapsed time covers generation and scoring),
    then a per-operator breakdown for operators that occurred.
    """
    check_problems = random_problems(args.check_random, args.seed)
    start_check = time.perf_counter()
    check_outputs = predict(
        params,
        model,
        check_problems,
        batch_size=args.batch_size,
        use_jit=not args.no_jit,
    )
    correct = 0
    per_op = {op: {"correct": 0, "total": 0} for op in OPS}
    for (a, op, b), model_text in zip(check_problems, check_outputs):
        # The answer field is the text after the prompt's "=".
        is_correct = safe_decode_result_field(model_text.split("=")[1]) == str(compute_result(a, op, b))
        correct += is_correct
        per_op[op]["correct"] += is_correct
        per_op[op]["total"] += 1
    elapsed = time.perf_counter() - start_check
    print(
        f"Random exact accuracy: {correct}/{args.check_random} = "
        f"{correct / args.check_random:.3f} ({elapsed:.2f}s)"
    )
    for op, stats in per_op.items():
        total = stats["total"]
        if total:
            print(f"  {op}: {stats['correct']}/{total} = {stats['correct'] / total:.3f}")


def main() -> int:
    """Parse arguments, load the model, print predictions, and optionally score random problems.

    Returns:
        Process exit status: 0 on success, 1 if the parameters could not be
        loaded. Invalid arguments make argparse exit with status 2 instead.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    if args.batch_size < 1:
        parser.error("--batch-size must be at least 1")
    if args.check_random < 0:
        parser.error("--check-random must be non-negative")

    config = TransformerConfig()
    model = ArithmeticTransformer(config)
    model_path = args.model.expanduser()

    start_load = time.perf_counter()
    try:
        params = load_params(model_path, model)
    except Exception as exc:
        # Top-level boundary: report the failure without a traceback.
        print(exc, file=sys.stderr)
        return 1
    load_seconds = time.perf_counter() - start_load

    print(f"JAX backend: {jax.default_backend()}")
    print(f"JAX devices: {[str(device) for device in jax.devices()]}")
    print(f"Loaded model: {model_path}")
    print(f"File size: {model_path.stat().st_size / 1e6:.2f} MB")
    print(f"Parameters: {count_params(params) / 1e6:.2f}M")
    print(f"Load time: {load_seconds:.2f}s")
    print()

    problems = args.problems or [
        (0, "+", 0),
        (7, "+", 35),
        (123, "-", 45),
        (45, "-", 123),
        (12, "*", 89),
        (123, "*", 45),
        (5, "//", 4),
        (999, "//", 1),
        (5, "%", 4),
        (998, "%", 999),
        (999, "*", 999),
    ]
    start_predict = time.perf_counter()
    generated = predict(
        params,
        model,
        problems,
        batch_size=args.batch_size,
        use_jit=not args.no_jit,
    )
    predict_seconds = time.perf_counter() - start_predict
    print_predictions(problems, generated)
    print(f"\nPrediction time: {predict_seconds:.2f}s")

    if args.check_random:
        _report_random_accuracy(params, model, args)

    return 0


if __name__ == "__main__":
    # main() returns the process exit status (0 on success, 1 on load failure).
    sys.exit(main())
