from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Callable

import jax
import jax.numpy as jnp
import numpy as np
from flax import linen as nn

from .model import (
    ANSWER_LEN,
    CHAR_TO_ID,
    PROMPT_LEN,
    SEQ_LEN,
    ArithmeticTransformer,
    TransformerConfig,
)


# Every variant's entry point: a NumPy batch of prompt token ids goes in and the
# variant's result (a JAX token array for the generators in this file) comes out.
GenerateFn = Callable[[np.ndarray], Any]
# Same input contract; the lower hooks in this file return ``jax.jit(...).lower(...)``.
LowerFn = Callable[[np.ndarray], Any]


@dataclass(eq=True, frozen=True)
class VariantRuntime:
    """Immutable description of one benchmarked inference variant.

    Instances are produced by the ``make_*_runtime`` factories and cannot be
    modified after construction because the dataclass is frozen.
    """

    # Short variant identifier, e.g. "eager" or "kv_cache".
    name: str
    # Runs greedy generation for a batch of prompt tokens.
    generate: GenerateFn
    # True for the jax.jit-based variants in this file, False for the eager one.
    compile_applicable: bool
    # Human-readable description of what the variant does.
    notes: str
    # Reference entries; the factories in this file use "paper", "url" and "implementation" keys.
    paper_ideas: list[dict[str, str]]
    # Optional hook exposing the lowered computation; None when not provided.
    lower: LowerFn | None = None


def block_until_ready(value: Any) -> Any:
    """Wait for every leaf of the pytree ``value`` that supports it.

    Each leaf exposing a ``block_until_ready`` method (e.g. JAX arrays) has it
    called; other leaves such as plain Python numbers are skipped. ``value``
    itself is returned so the call can wrap an expression.
    """
    pending = [
        leaf
        for leaf in jax.tree_util.tree_leaves(value)
        if hasattr(leaf, "block_until_ready")
    ]
    for array in pending:
        array.block_until_ready()
    return value


def layer_norm(x: jnp.ndarray, params: dict[str, Any]) -> jnp.ndarray:
    """Normalize ``x`` over its last axis, then apply a learned affine map.

    ``params`` must contain ``"scale"`` and ``"bias"`` that broadcast against
    ``x``. A fixed epsilon of 1e-6 is added to the variance before the
    reciprocal square root.
    """
    mu = jnp.mean(x, axis=-1, keepdims=True)
    centered = x - mu
    sigma_sq = jnp.mean(jnp.square(centered), axis=-1, keepdims=True)
    normalized = centered * jax.lax.rsqrt(sigma_sq + 1e-6)
    return normalized * params["scale"] + params["bias"]


def dense(x: jnp.ndarray, kernel: Any) -> jnp.ndarray:
    """Apply a bias-free linear projection to the last axis of ``x``.

    ``kernel`` is a 2-D ``(in, out)`` weight or a quantized dict understood by
    ``materialize_weight``.
    """
    weight = materialize_weight(kernel)
    return jnp.einsum("...d,df->...f", x, weight)


def materialize_weight(value: Any) -> jnp.ndarray:
    """Return a usable weight array, dequantizing ``{"q", "scale"}`` dicts.

    A dict holding both ``"q"`` and ``"scale"`` is rebuilt as
    ``q.astype(scale.dtype) * scale``; every other value is returned as-is.
    """
    quantized = isinstance(value, dict) and "q" in value and "scale" in value
    if not quantized:
        return value
    scale = value["scale"]
    return value["q"].astype(scale.dtype) * scale


def get_layer(params: dict[str, Any], index: int) -> dict[str, Any]:
    """Return the parameter subtree of transformer block number ``index``."""
    layer_key = f"TransformerBlock_{index}"
    return params[layer_key]


def split_heads(x: jnp.ndarray, config: TransformerConfig) -> jnp.ndarray:
    """Reshape ``(batch, seq, d_model)`` into ``(batch, n_heads, seq, head_dim)``."""
    batch, seq_len, _ = x.shape
    per_head = config.d_model // config.n_heads
    grouped = x.reshape(batch, seq_len, config.n_heads, per_head)
    # Move the head axis in front of the sequence axis.
    return grouped.swapaxes(1, 2)


def manual_full_forward(
    params: dict[str, Any],
    input_ids: jnp.ndarray,
    config: TransformerConfig,
) -> jnp.ndarray:
    """Run the transformer over a whole token batch and return per-position logits.

    This is a hand-written forward pass that reads weights directly from the
    Flax parameter tree. Embeddings go through ``materialize_weight`` and all
    projections through ``dense``, so trees from ``quantize_params`` work too.

    Args:
        params: the model's ``params`` collection.
        input_ids: integer token ids of shape ``(batch, seq_len)``.
        config: hyper-parameters (``d_model``, ``n_heads``, ``n_layers``).

    Returns:
        Logits of shape ``(batch, seq_len, out_dim)``, where ``out_dim`` is the
        output width of the ``lm_head`` kernel; position ``i`` only attends to
        positions ``<= i``.
    """
    n_batch, n_tokens = input_ids.shape
    head_dim = config.d_model // config.n_heads
    inv_sqrt_head_dim = head_dim**-0.5
    tok_table = materialize_weight(params["token_embedding"]["embedding"])
    pos_table = materialize_weight(params["position_embedding"])
    hidden = tok_table[input_ids] + pos_table[None, :n_tokens, :]
    # Same lower-triangular mask for every layer.
    causal = jnp.tril(jnp.ones((n_tokens, n_tokens), dtype=bool))[None, None, :, :]

    def attention(normed: jnp.ndarray, attn: dict[str, Any]) -> jnp.ndarray:
        """Causal multi-head self-attention including its output projection."""
        qkv = dense(normed, attn["Dense_0"]["kernel"])
        q, k, v = (split_heads(part, config) for part in jnp.split(qkv, 3, axis=-1))
        raw = jnp.einsum("bhqd,bhkd->bhqk", q, k) * inv_sqrt_head_dim
        probs = nn.softmax(jnp.where(causal, raw, -1e10), axis=-1)
        mixed = jnp.einsum("bhqk,bhkd->bhqd", probs, v)
        merged = mixed.transpose(0, 2, 1, 3).reshape(n_batch, n_tokens, config.d_model)
        return dense(merged, attn["Dense_1"]["kernel"])

    def feed_forward(normed: jnp.ndarray, block: dict[str, Any]) -> jnp.ndarray:
        """Two-layer MLP with tanh-approximated GELU."""
        widened = nn.gelu(dense(normed, block["Dense_0"]["kernel"]), approximate=True)
        return dense(widened, block["Dense_1"]["kernel"])

    for block in (get_layer(params, i) for i in range(config.n_layers)):
        # Pre-norm residual blocks: attention first, then the MLP.
        hidden = hidden + attention(layer_norm(hidden, block["LayerNorm_0"]), block["CausalSelfAttention_0"])
        hidden = hidden + feed_forward(layer_norm(hidden, block["LayerNorm_1"]), block)

    hidden = layer_norm(hidden, params["LayerNorm_0"])
    return dense(hidden, params["lm_head"]["kernel"])


def manual_incremental_logits(
    params: dict[str, Any],
    token_ids: jnp.ndarray,
    pos: jnp.ndarray,
    k_cache: jnp.ndarray,
    v_cache: jnp.ndarray,
    config: TransformerConfig,
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Run one decode step for a batch of single tokens against a KV cache.

    The step writes its key/value projections into slot ``pos`` of every
    layer's cache, attends over slots ``0..pos``, and returns next-token logits.
    Weights are read through ``materialize_weight``/``dense``, as in
    ``manual_full_forward``, so quantized parameter trees are accepted as well.

    Args:
        params: the model's ``params`` collection (plain or quantized).
        token_ids: token ids of shape ``(batch,)`` located at position ``pos``.
        pos: scalar position index (may be a traced value).
        k_cache: key cache of shape ``(n_layers, batch, n_heads, cache_len, head_dim)``.
        v_cache: value cache with the same shape as ``k_cache``.
        config: hyper-parameters (``d_model``, ``n_heads``, ``n_layers``).

    Returns:
        ``(logits, k_cache, v_cache)``: logits of shape ``(batch, out_dim)`` and
        the updated caches.
    """
    batch = token_ids.shape[0]
    head_dim = config.d_model // config.n_heads
    scale = head_dim**-0.5
    token_embedding = materialize_weight(params["token_embedding"]["embedding"])
    position_embedding = materialize_weight(params["position_embedding"])
    x = token_embedding[token_ids] + position_embedding[pos]
    # Cache slots after ``pos`` have not been written for this sequence yet; mask them.
    # The mask length follows the cache (SEQ_LEN - 1 for the callers in this file).
    visible = (jnp.arange(k_cache.shape[-2]) <= pos)[None, None, :]

    for layer_index in range(config.n_layers):
        block = get_layer(params, layer_index)
        attn_params = block["CausalSelfAttention_0"]
        y = layer_norm(x, block["LayerNorm_0"])
        qkv = dense(y, attn_params["Dense_0"]["kernel"])
        q, k, v = (
            part.reshape(batch, config.n_heads, head_dim)
            for part in jnp.split(qkv, 3, axis=-1)
        )
        k_cache = k_cache.at[layer_index, :, :, pos, :].set(k)
        v_cache = v_cache.at[layer_index, :, :, pos, :].set(v)
        scores = jnp.einsum("bhd,bhkd->bhk", q, k_cache[layer_index]) * scale
        weights = nn.softmax(jnp.where(visible, scores, -1e10), axis=-1)
        attended = jnp.einsum("bhk,bhkd->bhd", weights, v_cache[layer_index])
        x = x + dense(attended.reshape(batch, config.d_model), attn_params["Dense_1"]["kernel"])

        y = layer_norm(x, block["LayerNorm_1"])
        y = dense(y, block["Dense_0"]["kernel"])
        y = nn.gelu(y, approximate=True)
        y = dense(y, block["Dense_1"]["kernel"])
        x = x + y

    x = layer_norm(x, params["LayerNorm_0"])
    logits = dense(x, params["lm_head"]["kernel"])
    return logits, k_cache, v_cache


def manual_parallel_prefill(
    params: dict[str, Any],
    prompt_ids: jnp.ndarray,
    config: TransformerConfig,
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Process a whole prompt in one causal pass and seed fresh KV caches.

    Weights are read through ``materialize_weight``/``dense``, as in
    ``manual_full_forward``, so quantized parameter trees are accepted. Without
    this, reading ``.dtype`` from a quantized embedding dict would fail.

    Args:
        params: the model's ``params`` collection (plain or quantized).
        prompt_ids: prompt token ids of shape ``(batch, prompt_len)``.
        config: hyper-parameters (``d_model``, ``n_heads``, ``n_layers``).

    Returns:
        ``(logits, k_cache, v_cache)``. ``logits`` has shape ``(batch, out_dim)``
        and is taken at the last prompt position. The caches have shape
        ``(n_layers, batch, n_heads, SEQ_LEN - 1, head_dim)``; the first
        ``prompt_len`` slots are filled and the remaining slots are zero.
    """
    batch, prompt_len = prompt_ids.shape
    head_dim = config.d_model // config.n_heads
    scale = head_dim**-0.5
    token_embedding = materialize_weight(params["token_embedding"]["embedding"])
    position_embedding = materialize_weight(params["position_embedding"])
    x = token_embedding[prompt_ids] + position_embedding[None, :prompt_len, :]
    k_cache = jnp.zeros(
        (config.n_layers, batch, config.n_heads, SEQ_LEN - 1, head_dim),
        dtype=token_embedding.dtype,
    )
    v_cache = jnp.zeros_like(k_cache)
    # Layer-invariant causal mask over the prompt positions.
    causal = jnp.tril(jnp.ones((prompt_len, prompt_len), dtype=bool))[None, None, :, :]

    for layer_index in range(config.n_layers):
        block = get_layer(params, layer_index)
        attn_params = block["CausalSelfAttention_0"]
        y = layer_norm(x, block["LayerNorm_0"])
        qkv = dense(y, attn_params["Dense_0"]["kernel"])
        q, k, v = (split_heads(part, config) for part in jnp.split(qkv, 3, axis=-1))
        k_cache = k_cache.at[layer_index, :, :, :prompt_len, :].set(k)
        v_cache = v_cache.at[layer_index, :, :, :prompt_len, :].set(v)
        scores = jnp.einsum("bhqd,bhkd->bhqk", q, k) * scale
        weights = nn.softmax(jnp.where(causal, scores, -1e10), axis=-1)
        attended = jnp.einsum("bhqk,bhkd->bhqd", weights, v)
        attended = attended.transpose(0, 2, 1, 3).reshape(batch, prompt_len, config.d_model)
        x = x + dense(attended, attn_params["Dense_1"]["kernel"])

        y = layer_norm(x, block["LayerNorm_1"])
        y = dense(y, block["Dense_0"]["kernel"])
        y = nn.gelu(y, approximate=True)
        y = dense(y, block["Dense_1"]["kernel"])
        x = x + y

    x = layer_norm(x, params["LayerNorm_0"])
    logits = dense(x[:, -1, :], params["lm_head"]["kernel"])
    return logits, k_cache, v_cache


def manual_incremental_logits_dpa(
    params: dict[str, Any],
    token_ids: jnp.ndarray,
    pos: jnp.ndarray,
    k_cache: jnp.ndarray,
    v_cache: jnp.ndarray,
    config: TransformerConfig,
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """One KV-cache decode step with ``jax.nn.dot_product_attention`` (XLA impl).

    It has the same contract as ``manual_incremental_logits``: write slot ``pos``
    of each layer's cache, attend over slots ``0..pos``, and return next-token
    logits. Weights are read through ``materialize_weight``/``dense``, so
    quantized parameter trees are accepted.

    Args:
        params: the model's ``params`` collection (plain or quantized).
        token_ids: token ids of shape ``(batch,)`` located at position ``pos``.
        pos: scalar position index (may be a traced value).
        k_cache: key cache of shape ``(n_layers, batch, n_heads, cache_len, head_dim)``.
        v_cache: value cache with the same shape as ``k_cache``.
        config: hyper-parameters (``d_model``, ``n_heads``, ``n_layers``).

    Returns:
        ``(logits, k_cache, v_cache)``: logits of shape ``(batch, out_dim)`` and
        the updated caches.
    """
    batch = token_ids.shape[0]
    head_dim = config.d_model // config.n_heads
    token_embedding = materialize_weight(params["token_embedding"]["embedding"])
    position_embedding = materialize_weight(params["position_embedding"])
    x = token_embedding[token_ids] + position_embedding[pos]
    # The mask broadcasts over (batch, heads, query); only cache slots <= pos are visible.
    visible = (jnp.arange(k_cache.shape[-2]) <= pos)[None, None, None, :]

    for layer_index in range(config.n_layers):
        block = get_layer(params, layer_index)
        attn_params = block["CausalSelfAttention_0"]
        y = layer_norm(x, block["LayerNorm_0"])
        qkv = dense(y, attn_params["Dense_0"]["kernel"])
        q, k, v = (
            part.reshape(batch, config.n_heads, head_dim)
            for part in jnp.split(qkv, 3, axis=-1)
        )
        k_cache = k_cache.at[layer_index, :, :, pos, :].set(k)
        v_cache = v_cache.at[layer_index, :, :, pos, :].set(v)
        # dot_product_attention takes (batch, seq, heads, head_dim) operands.
        attended = jax.nn.dot_product_attention(
            q[:, None, :, :],
            k_cache[layer_index].transpose(0, 2, 1, 3),
            v_cache[layer_index].transpose(0, 2, 1, 3),
            mask=visible,
            implementation="xla",
        )
        attended = attended[:, 0, :, :].reshape(batch, config.d_model)
        x = x + dense(attended, attn_params["Dense_1"]["kernel"])

        y = layer_norm(x, block["LayerNorm_1"])
        y = dense(y, block["Dense_0"]["kernel"])
        y = nn.gelu(y, approximate=True)
        y = dense(y, block["Dense_1"]["kernel"])
        x = x + y

    x = layer_norm(x, params["LayerNorm_0"])
    logits = dense(x, params["lm_head"]["kernel"])
    return logits, k_cache, v_cache


def manual_parallel_prefill_dpa(
    params: dict[str, Any],
    prompt_ids: jnp.ndarray,
    config: TransformerConfig,
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Parallel causal prompt prefill using ``jax.nn.dot_product_attention``.

    It has the same contract as ``manual_parallel_prefill``; only the attention
    computation differs (``is_causal=True``, XLA implementation). Weights are
    read through ``materialize_weight``/``dense``, so quantized parameter trees
    are accepted.

    Args:
        params: the model's ``params`` collection (plain or quantized).
        prompt_ids: prompt token ids of shape ``(batch, prompt_len)``.
        config: hyper-parameters (``d_model``, ``n_heads``, ``n_layers``).

    Returns:
        ``(logits, k_cache, v_cache)``: last-prompt-position logits of shape
        ``(batch, out_dim)`` and caches of shape
        ``(n_layers, batch, n_heads, SEQ_LEN - 1, head_dim)`` whose first
        ``prompt_len`` slots are filled.
    """
    batch, prompt_len = prompt_ids.shape
    head_dim = config.d_model // config.n_heads
    token_embedding = materialize_weight(params["token_embedding"]["embedding"])
    position_embedding = materialize_weight(params["position_embedding"])
    x = token_embedding[prompt_ids] + position_embedding[None, :prompt_len, :]
    k_cache = jnp.zeros(
        (config.n_layers, batch, config.n_heads, SEQ_LEN - 1, head_dim),
        dtype=token_embedding.dtype,
    )
    v_cache = jnp.zeros_like(k_cache)

    for layer_index in range(config.n_layers):
        block = get_layer(params, layer_index)
        attn_params = block["CausalSelfAttention_0"]
        y = layer_norm(x, block["LayerNorm_0"])
        qkv = dense(y, attn_params["Dense_0"]["kernel"])
        q, k, v = (split_heads(part, config) for part in jnp.split(qkv, 3, axis=-1))
        k_cache = k_cache.at[layer_index, :, :, :prompt_len, :].set(k)
        v_cache = v_cache.at[layer_index, :, :, :prompt_len, :].set(v)
        # split_heads yields (batch, heads, seq, head_dim); DPA wants (batch, seq, heads, head_dim).
        attended = jax.nn.dot_product_attention(
            q.transpose(0, 2, 1, 3),
            k.transpose(0, 2, 1, 3),
            v.transpose(0, 2, 1, 3),
            is_causal=True,
            implementation="xla",
        )
        attended = attended.reshape(batch, prompt_len, config.d_model)
        x = x + dense(attended, attn_params["Dense_1"]["kernel"])

        y = layer_norm(x, block["LayerNorm_1"])
        y = dense(y, block["Dense_0"]["kernel"])
        y = nn.gelu(y, approximate=True)
        y = dense(y, block["Dense_1"]["kernel"])
        x = x + y

    x = layer_norm(x, params["LayerNorm_0"])
    logits = dense(x[:, -1, :], params["lm_head"]["kernel"])
    return logits, k_cache, v_cache


def quantize_params(params: Any, bits: int, dtype: str) -> Any:
    """Apply naive symmetric per-tensor weight-only quantization to a pytree.

    Leaf handling:
      * non-floating leaves (and non-arrays) are returned unchanged;
      * floating leaves with ``ndim < 2`` are cast to ``dtype``;
      * floating leaves with ``ndim >= 2`` become
        ``{"q": int8 codes, "scale": scalar in dtype, "bits": int32}``, with
        codes clipped to ``[-qmax, qmax]`` where ``qmax = 2**(bits-1) - 1``.

    The codes are computed against the scale *as stored* (after the cast to
    ``dtype``), so ``q * scale`` (see ``materialize_weight``) is the best
    reconstruction for that stored scale. The scale is floored at the larger of
    1e-8 and the smallest normal value of ``dtype``. A plain 1e-8 floor rounds
    to zero in float16, which made all-zero matrices divide 0/0 and store a
    zero scale.

    Raises:
        ValueError: if ``bits`` is not 2, 4 or 8, or ``dtype`` is not one of
            ``"float32"``, ``"bfloat16"``, ``"float16"``.
    """
    if bits not in (2, 4, 8):
        raise ValueError(f"unsupported quantization bits: {bits}")
    dtype_table = {
        "float32": jnp.float32,
        "bfloat16": jnp.bfloat16,
        "float16": jnp.float16,
    }
    if dtype not in dtype_table:
        raise ValueError(f"unsupported quantization dtype: {dtype}")
    target_dtype = dtype_table[dtype]
    qmax = (2 ** (bits - 1)) - 1
    # Smallest scale that survives the cast to target_dtype without underflow.
    min_scale = max(1e-8, float(jnp.finfo(target_dtype).tiny))

    def quantize_leaf(leaf):
        if not hasattr(leaf, "dtype") or not jnp.issubdtype(leaf.dtype, jnp.floating):
            return leaf
        if leaf.ndim < 2:
            return leaf.astype(target_dtype)
        values = leaf.astype(jnp.float32)
        max_abs = jnp.max(jnp.abs(values))
        scale = jnp.maximum(max_abs / qmax, min_scale).astype(target_dtype)
        q = jnp.clip(jnp.round(values / scale.astype(jnp.float32)), -qmax, qmax).astype(jnp.int8)
        return {"q": q, "scale": scale, "bits": jnp.asarray(bits, dtype=jnp.int32)}

    return jax.tree_util.tree_map(quantize_leaf, params)


def make_eager_runtime(
    params: Any,
    model: ArithmeticTransformer,
    _config: TransformerConfig,
) -> VariantRuntime:
    """Build the un-jitted reference variant.

    For every generated position the whole prefix ``tokens[:, :pos]`` is run
    through ``model.apply`` (no jit, no cache), and the greedy choice at the
    prefix's last position is written into slot ``pos``.
    """
    variables = {"params": params}

    def generate(prompt_tokens: np.ndarray) -> jnp.ndarray:
        seq = jnp.asarray(prompt_tokens, dtype=jnp.int32)
        for pos in range(PROMPT_LEN, SEQ_LEN):
            prefix_logits = model.apply(variables, seq[:, :pos])
            chosen = jnp.argmax(prefix_logits[:, -1, :], axis=-1).astype(jnp.int32)
            seq = seq.at[:, pos].set(chosen)
        return seq

    return VariantRuntime(
        name="eager",
        generate=generate,
        compile_applicable=False,
        notes="Reference Python loop with full prefix recomputation at each generated token.",
        paper_ideas=[],
    )


def make_jit_step_runtime(
    params: Any,
    _model: ArithmeticTransformer,
    config: TransformerConfig,
) -> VariantRuntime:
    """Build the variant that jits a single greedy decode step.

    Each step reruns ``manual_full_forward`` over the full fixed-width token
    buffer and reads the logits at ``pos - 1``. ``pos`` is passed as an int32
    array, so it is traced rather than baked in as a constant. The loop over
    positions stays in Python.
    """

    def decode_step(weights: Any, seq: jnp.ndarray, pos: jnp.ndarray) -> jnp.ndarray:
        all_logits = manual_full_forward(weights, seq, config)
        # Logits at position pos - 1 predict the token at pos.
        prev_logits = jax.lax.dynamic_index_in_dim(all_logits, pos - 1, axis=1, keepdims=False)
        return seq.at[:, pos].set(jnp.argmax(prev_logits, axis=-1).astype(jnp.int32))

    decode_step_jit = jax.jit(decode_step)

    def generate(prompt_tokens: np.ndarray) -> jnp.ndarray:
        seq = jnp.asarray(prompt_tokens, dtype=jnp.int32)
        for position in range(PROMPT_LEN, SEQ_LEN):
            seq = decode_step_jit(params, seq, jnp.asarray(position, dtype=jnp.int32))
        return seq

    idea = {
        "paper": "Compiler graph capture and fusion, similar in spirit to XLA or torch.compile style inference.",
        "url": "https://openxla.org/xla",
        "implementation": "JIT a single decode step with static token shape.",
    }
    return VariantRuntime(
        name="jit_step",
        generate=generate,
        compile_applicable=True,
        notes="One static-shape token step is jitted, while the autoregressive loop stays in Python.",
        paper_ideas=[idea],
    )


def make_jit_full_runtime(
    params: Any,
    _model: ArithmeticTransformer,
    config: TransformerConfig,
) -> VariantRuntime:
    """Build the variant that compiles the entire greedy decode loop.

    A ``jax.lax.fori_loop`` over positions ``PROMPT_LEN..SEQ_LEN-1`` reruns
    ``manual_full_forward`` on the full token buffer each iteration (no KV
    cache), and the whole loop is wrapped in a single ``jax.jit``.
    """

    def generate_impl(params_arg: Any, prompt_tokens: jnp.ndarray) -> jnp.ndarray:
        def fill_position(pos: jnp.ndarray, buffer: jnp.ndarray) -> jnp.ndarray:
            all_logits = manual_full_forward(params_arg, buffer, config)
            prev_logits = jax.lax.dynamic_index_in_dim(all_logits, pos - 1, axis=1, keepdims=False)
            return buffer.at[:, pos].set(jnp.argmax(prev_logits, axis=-1).astype(jnp.int32))

        initial = jnp.asarray(prompt_tokens, dtype=jnp.int32)
        return jax.lax.fori_loop(PROMPT_LEN, SEQ_LEN, fill_position, initial)

    compiled = jax.jit(generate_impl)

    def generate(prompt_tokens: np.ndarray) -> jnp.ndarray:
        return compiled(params, prompt_tokens)

    def lower(prompt_tokens: np.ndarray) -> Any:
        return compiled.lower(params, prompt_tokens)

    idea = {
        "paper": "Compiler graph capture and loop lowering, similar in spirit to XLA or torch.compile style inference.",
        "url": "https://openxla.org/xla",
        "implementation": "JIT the full fixed-length generation loop.",
    }
    return VariantRuntime(
        name="jit_full",
        generate=generate,
        compile_applicable=True,
        notes="The complete greedy generation loop is captured in one XLA computation with static shapes.",
        paper_ideas=[idea],
        lower=lower,
    )


def make_kv_cache_runtime(
    params: Any,
    _model: ArithmeticTransformer,
    config: TransformerConfig,
) -> VariantRuntime:
    """Build the KV-cache variant with a sequential (token-by-token) prefill.

    Prompt tokens ``0..PROMPT_LEN-1`` are fed one at a time through
    ``manual_incremental_logits`` inside a ``fori_loop``. The answer positions
    are then generated greedily, and each new token except the last is fed
    back through the same function. Everything runs in one ``jax.jit``.
    """

    def generate_impl(params_arg: Any, prompt_tokens: jnp.ndarray) -> jnp.ndarray:
        seq = jnp.asarray(prompt_tokens, dtype=jnp.int32)
        n_batch = seq.shape[0]
        cache_dtype = params_arg["token_embedding"]["embedding"].dtype
        cache_shape = (
            config.n_layers,
            n_batch,
            config.n_heads,
            SEQ_LEN - 1,
            config.d_model // config.n_heads,
        )
        empty_k = jnp.zeros(cache_shape, dtype=cache_dtype)
        empty_v = jnp.zeros_like(empty_k)
        no_logits = jnp.zeros((n_batch, config.vocab_size), dtype=cache_dtype)

        def prefill_step(pos: jnp.ndarray, carry: tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]):
            k_state, v_state, _unused = carry
            step_logits, k_state, v_state = manual_incremental_logits(
                params_arg, seq[:, pos], pos, k_state, v_state, config
            )
            return k_state, v_state, step_logits

        k_cache, v_cache, logits = jax.lax.fori_loop(
            0, PROMPT_LEN, prefill_step, (empty_k, empty_v, no_logits)
        )

        def decode_step(pos: jnp.ndarray, carry: tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]):
            cur_seq, cur_k, cur_v, cur_logits = carry
            cur_seq = cur_seq.at[:, pos].set(jnp.argmax(cur_logits, axis=-1).astype(jnp.int32))

            def advance(operands):
                s, kc, vc = operands
                nxt_logits, nxt_k, nxt_v = manual_incremental_logits(
                    params_arg, s[:, pos], pos, kc, vc, config
                )
                return s, nxt_k, nxt_v, nxt_logits

            def hold(operands):
                s, kc, vc = operands
                return s, kc, vc, cur_logits

            # The token written at the final position is never fed back through the model.
            return jax.lax.cond(pos < SEQ_LEN - 1, advance, hold, (cur_seq, cur_k, cur_v))

        final_seq, _k, _v, _logits = jax.lax.fori_loop(
            PROMPT_LEN, SEQ_LEN, decode_step, (seq, k_cache, v_cache, logits)
        )
        return final_seq

    compiled = jax.jit(generate_impl)

    def generate(prompt_tokens: np.ndarray) -> jnp.ndarray:
        return compiled(params, prompt_tokens)

    def lower(prompt_tokens: np.ndarray) -> Any:
        return compiled.lower(params, prompt_tokens)

    ideas = [
        {
            "paper": "Efficient Memory Management for Large Language Model Serving with PagedAttention",
            "url": "https://arxiv.org/abs/2309.06180",
            "implementation": "Use preallocated K/V tensors and append one token at a time, without paging.",
        },
        {
            "paper": "Fast and Expressive LLM Inference with RadixAttention and SGLang",
            "url": "https://arxiv.org/abs/2312.07104",
            "implementation": "Reuse the prompt prefill inside a single request; cross-request radix reuse is logged as out of scope.",
        },
    ]
    return VariantRuntime(
        name="kv_cache",
        generate=generate,
        compile_applicable=True,
        notes=(
            "Manual incremental decode with static preallocated key/value caches. "
            "This is a small-model analogue of KV-cache serving, not a paged cache implementation."
        ),
        paper_ideas=ideas,
        lower=lower,
    )


def make_kv_cache_prefill_runtime(
    params: Any,
    _model: ArithmeticTransformer,
    config: TransformerConfig,
) -> VariantRuntime:
    """Build the variant with a parallel prompt prefill followed by cached decode.

    ``manual_parallel_prefill`` handles all prompt tokens in one causal pass and
    seeds the caches. Answer positions are then generated greedily, one at a
    time, with ``manual_incremental_logits`` inside a ``fori_loop``. The whole
    thing is wrapped in a single ``jax.jit``.
    """

    def generate_impl(params_arg: Any, prompt_tokens: jnp.ndarray) -> jnp.ndarray:
        seq = jnp.asarray(prompt_tokens, dtype=jnp.int32)
        logits, k_cache, v_cache = manual_parallel_prefill(params_arg, seq[:, :PROMPT_LEN], config)

        def decode_step(pos: jnp.ndarray, carry: tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]):
            cur_seq, cur_k, cur_v, cur_logits = carry
            cur_seq = cur_seq.at[:, pos].set(jnp.argmax(cur_logits, axis=-1).astype(jnp.int32))

            def advance(operands):
                s, kc, vc = operands
                nxt_logits, nxt_k, nxt_v = manual_incremental_logits(
                    params_arg, s[:, pos], pos, kc, vc, config
                )
                return s, nxt_k, nxt_v, nxt_logits

            def hold(operands):
                s, kc, vc = operands
                return s, kc, vc, cur_logits

            # The token written at the final position is never fed back through the model.
            return jax.lax.cond(pos < SEQ_LEN - 1, advance, hold, (cur_seq, cur_k, cur_v))

        final_seq, _k, _v, _logits = jax.lax.fori_loop(
            PROMPT_LEN, SEQ_LEN, decode_step, (seq, k_cache, v_cache, logits)
        )
        return final_seq

    compiled = jax.jit(generate_impl)

    def generate(prompt_tokens: np.ndarray) -> jnp.ndarray:
        return compiled(params, prompt_tokens)

    def lower(prompt_tokens: np.ndarray) -> Any:
        return compiled.lower(params, prompt_tokens)

    idea = {
        "paper": "Production LLM serving prefill/decode separation, as used in systems such as vLLM and SGLang.",
        "url": "https://arxiv.org/abs/2309.06180",
        "implementation": "Parallel causal prefill for the fixed prompt, then incremental KV-cache decode for answer tokens.",
    }
    return VariantRuntime(
        name="kv_cache_prefill",
        generate=generate,
        compile_applicable=True,
        notes=(
            "KV-cache decode with the prompt prefilled in one parallel causal forward pass. "
            "This tests the common serving split between prompt prefill and token-by-token decode."
        ),
        paper_ideas=[idea],
        lower=lower,
    )


def make_kv_cache_prefill_dpa_runtime(
    params: Any,
    _model: ArithmeticTransformer,
    config: TransformerConfig,
) -> VariantRuntime:
    """Build the prefill/decode KV-cache variant built on dot_product_attention.

    The structure matches the ``kv_cache_prefill`` variant. Both the prefill
    and the per-token decode use the ``*_dpa`` helpers, which route attention
    through ``jax.nn.dot_product_attention`` with the XLA implementation.
    """

    def generate_impl(params_arg: Any, prompt_tokens: jnp.ndarray) -> jnp.ndarray:
        seq = jnp.asarray(prompt_tokens, dtype=jnp.int32)
        logits, k_cache, v_cache = manual_parallel_prefill_dpa(params_arg, seq[:, :PROMPT_LEN], config)

        def decode_step(pos: jnp.ndarray, carry: tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]):
            cur_seq, cur_k, cur_v, cur_logits = carry
            cur_seq = cur_seq.at[:, pos].set(jnp.argmax(cur_logits, axis=-1).astype(jnp.int32))

            def advance(operands):
                s, kc, vc = operands
                nxt_logits, nxt_k, nxt_v = manual_incremental_logits_dpa(
                    params_arg, s[:, pos], pos, kc, vc, config
                )
                return s, nxt_k, nxt_v, nxt_logits

            def hold(operands):
                s, kc, vc = operands
                return s, kc, vc, cur_logits

            # The token written at the final position is never fed back through the model.
            return jax.lax.cond(pos < SEQ_LEN - 1, advance, hold, (cur_seq, cur_k, cur_v))

        final_seq, _k, _v, _logits = jax.lax.fori_loop(
            PROMPT_LEN, SEQ_LEN, decode_step, (seq, k_cache, v_cache, logits)
        )
        return final_seq

    compiled = jax.jit(generate_impl)

    def generate(prompt_tokens: np.ndarray) -> jnp.ndarray:
        return compiled(params, prompt_tokens)

    def lower(prompt_tokens: np.ndarray) -> Any:
        return compiled.lower(params, prompt_tokens)

    idea = {
        "paper": "Scaled dot-product attention kernel fusion as used by modern transformer runtimes.",
        "url": "https://jax.readthedocs.io/en/latest/_autosummary/jax.nn.dot_product_attention.html",
        "implementation": "Replace explicit QK, mask, softmax, and AV attention code with jax.nn.dot_product_attention.",
    }
    return VariantRuntime(
        name="kv_cache_prefill_dpa",
        generate=generate,
        compile_applicable=True,
        notes=(
            "KV-cache prefill/decode using jax.nn.dot_product_attention with the XLA implementation. "
            "This tests whether exposing the attention pattern to JAX's attention primitive creates useful fused attention code on this backend."
        ),
        paper_ideas=[idea],
        lower=lower,
    )


def make_kv_cache_prefill_static_runtime(
    params: Any,
    _model: ArithmeticTransformer,
    config: TransformerConfig,
) -> VariantRuntime:
    """Build the prefill/decode KV-cache variant with weights closed over by jit.

    The parameter tree is converted to JAX arrays once and captured by the
    traced function rather than passed as an argument, so only the prompt
    tokens are runtime inputs. Decoding matches the ``kv_cache_prefill`` variant.
    """
    static_params = jax.tree_util.tree_map(jnp.asarray, params)

    def advance_one(pos: jnp.ndarray, carry: tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]):
        cur_seq, cur_k, cur_v, cur_logits = carry
        cur_seq = cur_seq.at[:, pos].set(jnp.argmax(cur_logits, axis=-1).astype(jnp.int32))

        def advance(operands):
            s, kc, vc = operands
            nxt_logits, nxt_k, nxt_v = manual_incremental_logits(
                static_params, s[:, pos], pos, kc, vc, config
            )
            return s, nxt_k, nxt_v, nxt_logits

        def hold(operands):
            s, kc, vc = operands
            return s, kc, vc, cur_logits

        # The token written at the final position is never fed back through the model.
        return jax.lax.cond(pos < SEQ_LEN - 1, advance, hold, (cur_seq, cur_k, cur_v))

    def generate_impl(prompt_tokens: jnp.ndarray) -> jnp.ndarray:
        seq = jnp.asarray(prompt_tokens, dtype=jnp.int32)
        logits, k_cache, v_cache = manual_parallel_prefill(static_params, seq[:, :PROMPT_LEN], config)
        final_seq, _k, _v, _logits = jax.lax.fori_loop(
            PROMPT_LEN, SEQ_LEN, advance_one, (seq, k_cache, v_cache, logits)
        )
        return final_seq

    compiled = jax.jit(generate_impl)

    def generate(prompt_tokens: np.ndarray) -> jnp.ndarray:
        return compiled(prompt_tokens)

    def lower(prompt_tokens: np.ndarray) -> Any:
        return compiled.lower(prompt_tokens)

    idea = {
        "paper": "Static-graph inference deployment, similar in spirit to AOT-compiled serving graphs.",
        "url": "https://openxla.org/xla",
        "implementation": "Close the parameter tree over the jitted generation graph and pass only prompt tokens at runtime.",
    }
    return VariantRuntime(
        name="kv_cache_prefill_static",
        generate=generate,
        compile_applicable=True,
        notes=(
            "KV-cache prefill/decode with model parameters closed over by the compiled function. "
            "This tests whether treating weights as compile-time constants reduces host argument overhead for this small CPU workload."
        ),
        paper_ideas=[idea],
        lower=lower,
    )


def make_kv_cache_prefill_unrolled_runtime(
    params: Any,
    _model: ArithmeticTransformer,
    config: TransformerConfig,
) -> VariantRuntime:
    """Build the prefill/decode KV-cache variant with a trace-time unrolled decode.

    After ``manual_parallel_prefill``, a plain Python ``for`` loop over
    ``PROMPT_LEN..SEQ_LEN-1`` runs inside ``jax.jit``, so every decode step is
    emitted separately into the compiled graph. The notes string derives the
    step count from ``SEQ_LEN - PROMPT_LEN`` instead of hard-coding it, so it
    cannot drift from the actual loop length.
    """
    decode_steps = SEQ_LEN - PROMPT_LEN

    def generate_impl(params_arg: Any, prompt_tokens: jnp.ndarray) -> jnp.ndarray:
        tokens = jnp.asarray(prompt_tokens, dtype=jnp.int32)
        logits, k_cache, v_cache = manual_parallel_prefill(
            params_arg,
            tokens[:, :PROMPT_LEN],
            config,
        )

        for pos in range(PROMPT_LEN, SEQ_LEN):
            next_ids = jnp.argmax(logits, axis=-1).astype(jnp.int32)
            tokens = tokens.at[:, pos].set(next_ids)
            # The token written at the final position is never fed back through the model.
            if pos < SEQ_LEN - 1:
                logits, k_cache, v_cache = manual_incremental_logits(
                    params_arg,
                    tokens[:, pos],
                    jnp.asarray(pos, dtype=jnp.int32),
                    k_cache,
                    v_cache,
                    config,
                )
        return tokens

    generate_jit = jax.jit(generate_impl)

    def generate(prompt_tokens: np.ndarray) -> jnp.ndarray:
        return generate_jit(params, prompt_tokens)

    def lower(prompt_tokens: np.ndarray) -> Any:
        return generate_jit.lower(params, prompt_tokens)

    return VariantRuntime(
        name="kv_cache_prefill_unrolled",
        generate=generate,
        compile_applicable=True,
        notes=(
            f"KV-cache prefill/decode with the fixed {decode_steps}-token decode loop unrolled at JIT trace time. "
            "This trades larger compile graphs for less loop/control overhead in steady-state inference."
        ),
        paper_ideas=[
            {
                "paper": "Static graph specialization for fixed-shape autoregressive decoding.",
                "url": "https://openxla.org/xla",
                "implementation": "Use Python-level unrolling inside JIT for the known answer length.",
            }
        ],
        lower=lower,
    )


def make_quant_runtime(
    params: Any,
    _model: ArithmeticTransformer,
    config: TransformerConfig,
    bits: int,
    dtype: str,
) -> VariantRuntime:
    """Build the weight-only quantized variant.

    ``params`` is quantized once with ``quantize_params(bits, dtype)``. The
    resulting tree is captured by a jitted ``fori_loop`` that reruns
    ``manual_full_forward`` (which dequantizes through ``materialize_weight``)
    for every generated position. Only the prompt tokens are runtime inputs.
    """
    qparams = quantize_params(params, bits=bits, dtype=dtype)

    def fill_position(pos: jnp.ndarray, buffer: jnp.ndarray) -> jnp.ndarray:
        all_logits = manual_full_forward(qparams, buffer, config)
        prev_logits = jax.lax.dynamic_index_in_dim(all_logits, pos - 1, axis=1, keepdims=False)
        return buffer.at[:, pos].set(jnp.argmax(prev_logits, axis=-1).astype(jnp.int32))

    def generate_impl(prompt_tokens: jnp.ndarray) -> jnp.ndarray:
        initial = jnp.asarray(prompt_tokens, dtype=jnp.int32)
        return jax.lax.fori_loop(PROMPT_LEN, SEQ_LEN, fill_position, initial)

    compiled = jax.jit(generate_impl)

    def generate(prompt_tokens: np.ndarray) -> jnp.ndarray:
        return compiled(prompt_tokens)

    def lower(prompt_tokens: np.ndarray) -> Any:
        return compiled.lower(prompt_tokens)

    ideas = [
        {
            "paper": "LLM.int8(): 8-bit Matrix Multiplication for Transformers at Scale",
            "url": "https://arxiv.org/abs/2208.07339",
            "implementation": "For int8, store large matrices as int8 with a scale and dequantize inside matmul.",
        },
        {
            "paper": "GPTQ: Accurate Post-Training Quantization for Generative Pre-trained Transformers",
            "url": "https://arxiv.org/abs/2210.17323",
            "implementation": "Benchmark a low-bit post-training quantization baseline; Hessian compensation is not implemented.",
        },
        {
            "paper": "SmoothQuant: Accurate and Efficient Post-Training Quantization for Large Language Models",
            "url": "https://arxiv.org/abs/2211.10438",
            "implementation": "Benchmark W-only quantization as a contrast; activation smoothing is logged as not implemented.",
        },
        {
            "paper": "AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration",
            "url": "https://arxiv.org/abs/2306.00978",
            "implementation": "Benchmark low-bit W-only quantization as a contrast; activation-aware scaling is not implemented.",
        },
    ]
    return VariantRuntime(
        name=f"quant_int{bits}",
        generate=generate,
        compile_applicable=True,
        notes=(
            f"Naive symmetric per-tensor int{bits} weight-only quantization with dequantization in the JAX graph. "
            "Quantized weights are closed over by the compiled function so the benchmark is not dominated by passing them as dynamic inputs. "
            "This is not a production GPTQ, AWQ, SmoothQuant, or LLM.int8 kernel."
        ),
        paper_ideas=ideas,
        lower=lower,
    )


def zero_format_draft(batch_size: int, start_pos: int, chunk_size: int) -> np.ndarray:
    """Build a fixed draft block for speculative decoding.

    The returned int32 array has shape ``(batch_size, chunk_size)`` and covers
    absolute sequence positions ``start_pos .. start_pos + chunk_size - 1``.
    Every entry is the id of ``"0"``, except that when ``PROMPT_LEN`` lies in
    that window the corresponding column holds the id of ``"+"``.
    """
    draft = np.full((batch_size, chunk_size), CHAR_TO_ID["0"], dtype=np.int32)
    # Column (relative to start_pos) that corresponds to absolute PROMPT_LEN.
    plus_column = PROMPT_LEN - start_pos
    if 0 <= plus_column < chunk_size:
        draft[:, plus_column] = CHAR_TO_ID["+"]
    return draft


def make_speculative_zero_runtime(
    params: Any,
    model: ArithmeticTransformer,
    _config: TransformerConfig,
    chunk_size: int = 4,
) -> VariantRuntime:
    """Build a greedy decoder that verifies a fixed "zero" draft with the target model.

    Each speculative step writes ``chunk_size`` draft tokens (from
    ``zero_format_draft``) into the sequence and runs ``model`` once. It then
    keeps the longest prefix of drafts that matches the model's argmax
    predictions, plus the model's own token at the first mismatch. The whole
    batch advances by the smallest per-row advance. When fewer than
    ``chunk_size`` positions remain, decoding falls back to one argmax token per
    full forward pass.

    NOTE(review): per-position verification from a single forward pass assumes
    the model is causal (logits at position t depend only on tokens <= t).
    Confirm against ArithmeticTransformer.

    Args:
        params: Parameters passed to ``model.apply`` as ``{"params": params}``.
        model: Target model used both to verify drafts and to produce tokens.
        _config: Unused here; accepted so the signature matches the other
            factories called from ``make_variant_runtime``.
        chunk_size: Number of draft tokens verified per forward pass (>= 1).

    Returns:
        A ``VariantRuntime`` named ``"speculative_zero"``.

    Raises:
        ValueError: If ``chunk_size`` is less than 1.
    """
    if chunk_size < 1:
        # With chunk_size == 0 the verify step writes no tokens, yet the host
        # loop still advances ``pos``. Generation would silently return
        # unfilled sequences, so reject it up front.
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")

    def verify_chunk(
        params_arg: Any,
        tokens: jnp.ndarray,
        pos: jnp.ndarray,
        draft_ids: jnp.ndarray,
    ) -> tuple[jnp.ndarray, jnp.ndarray]:
        """Verify ``draft_ids`` at positions ``pos .. pos + chunk_size - 1``.

        Returns ``(next_tokens, advance)``. ``advance[b]`` is the number of
        positions written for row ``b``: accepted drafts, plus one corrected
        token if a mismatch occurred.
        """

        def draft_body(i: jnp.ndarray, current_tokens: jnp.ndarray) -> jnp.ndarray:
            return current_tokens.at[:, pos + i].set(draft_ids[:, i])

        # Write the whole draft so one forward pass scores every draft position.
        drafted_tokens = jax.lax.fori_loop(0, chunk_size, draft_body, tokens)
        logits = model.apply({"params": params_arg}, drafted_tokens)
        # The token at position p is predicted from the logits at p - 1.
        positions = pos + jnp.arange(chunk_size) - 1
        verify_logits = jnp.take(logits, positions, axis=1)
        target_ids = jnp.argmax(verify_logits, axis=-1).astype(jnp.int32)
        matches = target_ids == draft_ids
        # cumprod zeroes everything after the first mismatch: only the
        # leading run of matching drafts is accepted.
        accepted = jnp.cumprod(matches.astype(jnp.int32), axis=1).astype(bool)
        accepted_count = jnp.sum(accepted.astype(jnp.int32), axis=1)
        has_reject = accepted_count < chunk_size
        # On a mismatch the model's own token is also kept, so every row
        # advances by at least one position.
        advance = accepted_count + has_reject.astype(jnp.int32)

        def apply_body(i: jnp.ndarray, current_tokens: jnp.ndarray) -> jnp.ndarray:
            accept_i = i < accepted_count
            reject_i = has_reject & (i == accepted_count)
            should_set = accept_i | reject_i
            chosen = jnp.where(accept_i, draft_ids[:, i], target_ids[:, i])
            old = current_tokens[:, pos + i]
            value = jnp.where(should_set, chosen, old)
            return current_tokens.at[:, pos + i].set(value)

        # Apply onto the original tokens (not the drafted copy) so positions
        # past the accepted/corrected prefix keep their previous values.
        next_tokens = jax.lax.fori_loop(0, chunk_size, apply_body, tokens)
        return next_tokens, advance

    def one_step(params_arg: Any, tokens: jnp.ndarray, pos: jnp.ndarray) -> jnp.ndarray:
        """Write the argmax prediction for position ``pos`` (from logits at ``pos - 1``)."""
        logits = model.apply({"params": params_arg}, tokens)
        step_logits = jax.lax.dynamic_index_in_dim(logits, pos - 1, axis=1, keepdims=False)
        next_ids = jnp.argmax(step_logits, axis=-1).astype(jnp.int32)
        return tokens.at[:, pos].set(next_ids)

    verify_jit = jax.jit(verify_chunk)
    step_jit = jax.jit(one_step)

    def generate(prompt_tokens: np.ndarray) -> jnp.ndarray:
        # Assumes prompt_tokens is (batch, SEQ_LEN) with the prompt in the first
        # PROMPT_LEN columns. TODO confirm with the caller.
        tokens = jnp.asarray(prompt_tokens, dtype=jnp.int32)
        batch_size = int(prompt_tokens.shape[0])
        pos = PROMPT_LEN
        while pos < SEQ_LEN:
            remaining = SEQ_LEN - pos
            if remaining >= chunk_size:
                draft = zero_format_draft(batch_size, pos, chunk_size)
                tokens, advance = verify_jit(params, tokens, jnp.asarray(pos, dtype=jnp.int32), draft)
                block_until_ready(tokens)
                # Reduce on the host: the value is needed there anyway to drive
                # the Python loop, so avoid dispatching an extra device op.
                # The batch moves in lockstep by its slowest row. Faster rows are
                # re-drafted and re-verified next step, which with argmax decoding
                # reproduces the same tokens.
                host_advance = int(np.asarray(advance).min())
                pos += max(1, min(host_advance, remaining))
            else:
                tokens = step_jit(params, tokens, jnp.asarray(pos, dtype=jnp.int32))
                pos += 1
        return tokens

    return VariantRuntime(
        name="speculative_zero",
        generate=generate,
        compile_applicable=True,
        notes=(
            "Speculative decoding with a fixed cheap draft that proposes '+' then '0' tokens and verifies with the target model. "
            "The batch advances by the smallest accepted token count, so divergent rows reduce the benefit."
        ),
        paper_ideas=[
            {
                "paper": "Fast Inference from Transformers via Speculative Decoding",
                "url": "https://arxiv.org/abs/2211.17192",
                "implementation": "Draft several tokens and verify them with the target model before accepting.",
            },
            {
                "paper": "Medusa: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads",
                "url": "https://arxiv.org/abs/2401.10774",
                "implementation": "Test multi-token verification mechanics without training Medusa heads.",
            },
        ],
    )


def make_variant_runtime(
    name: str,
    params: Any,
    model: ArithmeticTransformer,
    config: TransformerConfig,
    dtype: str,
) -> VariantRuntime:
    """Construct the inference runtime registered under ``name``.

    Args:
        name: Variant identifier; the accepted names are the keys of the
            table below (the same names listed in ``AVAILABLE_VARIANTS``).
        params: Model parameters forwarded to the selected factory.
        model: Model instance forwarded to the selected factory.
        config: Model config forwarded to the selected factory.
        dtype: Forwarded only to the ``quant_int*`` factories.

    Returns:
        The ``VariantRuntime`` built by the matching factory.

    Raises:
        ValueError: If ``name`` matches no known variant.
    """
    # Ordered (name, deferred factory) pairs. Lambdas delay construction so
    # only the factory for the matching name is ever invoked.
    registry = (
        ("eager", lambda: make_eager_runtime(params, model, config)),
        ("jit_step", lambda: make_jit_step_runtime(params, model, config)),
        ("jit_full", lambda: make_jit_full_runtime(params, model, config)),
        ("kv_cache", lambda: make_kv_cache_runtime(params, model, config)),
        ("kv_cache_prefill", lambda: make_kv_cache_prefill_runtime(params, model, config)),
        ("kv_cache_prefill_dpa", lambda: make_kv_cache_prefill_dpa_runtime(params, model, config)),
        ("kv_cache_prefill_static", lambda: make_kv_cache_prefill_static_runtime(params, model, config)),
        ("kv_cache_prefill_unrolled", lambda: make_kv_cache_prefill_unrolled_runtime(params, model, config)),
        ("quant_int8", lambda: make_quant_runtime(params, model, config, bits=8, dtype=dtype)),
        ("quant_int4", lambda: make_quant_runtime(params, model, config, bits=4, dtype=dtype)),
        ("quant_int2", lambda: make_quant_runtime(params, model, config, bits=2, dtype=dtype)),
        ("speculative_zero", lambda: make_speculative_zero_runtime(params, model, config)),
    )
    for variant_name, build in registry:
        if name == variant_name:
            return build()
    raise ValueError(f"unknown inference variant: {name}")


# Every variant name accepted by make_variant_runtime, in a fixed order.
AVAILABLE_VARIANTS: tuple[str, ...] = (
    # Eager and jit baselines.
    "eager",
    "jit_step",
    "jit_full",
    # kv_cache* variants.
    "kv_cache",
    "kv_cache_prefill",
    "kv_cache_prefill_dpa",
    "kv_cache_prefill_static",
    "kv_cache_prefill_unrolled",
    # Weight-only quantization at 8, 4 and 2 bits.
    "quant_int8",
    "quant_int4",
    "quant_int2",
    # Speculative decoding with a fixed draft.
    "speculative_zero",
)
