{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# Character-level transformer learns 3-digit arithmetic\n",
        "\n",
        "This notebook trains a decoder-only transformer with about 10 million\n",
        "parameters to solve fixed-format arithmetic problems with addition,\n",
        "subtraction, multiplication, integer division, and modulo.\n",
        "\n",
        "Human-readable examples:\n",
        "\n",
        "- `123+45=168`\n",
        "- `123-45=78`\n",
        "- `123*45=5535`\n",
        "- `5//4=1`\n",
        "- `5%4=1`\n",
        "- `0-999=-999`\n",
        "- `999*999=998001`\n",
        "\n",
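        "In the model format, the answer is a sign followed by the six zero-padded\n",
        "result digits in reverse order (least significant first). Each result digit\n",
        "is followed by two carry-trace digits, which are always `00` except for\n",
        "multiplication:\n",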
        "\n",
        "- `123+  45=+800600100000000000`\n",
        "- `123-  45=+800700000000000000`\n",
        "- `123*  45=+501302501500000000`\n",
        "- `  5//  4=+100000000000000000`\n",
        "- `  5%   4=+100000000000000000`\n",
        "- `  0- 999=-900900900000000000`\n",
        "- `999* 999=+108017026818909900`\n",
        "\n",
        "The model sees each example as characters, not as Python integers. Spaces\n",
        "are real tokens used for fixed-width formatting.\n",
        "\n",
        "Runtime policy:\n",
        "\n",
        "- Use free Colab only.\n",
        "- Prefer a free TPU if one is available.\n",
        "- Otherwise use a free T4 GPU.\n",
        "- Do not enable paid resources.\n",
        "- If no TPU/GPU accelerator is attached, the first code cell stops before training."
      ]
    },
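    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The hypothetical helper `read_answer` below is a minimal sketch of how to read the model format.\n",
        "It keeps the sign, takes the first character of each three-character chunk, and\n",
        "reverses those digits. It is for illustration only. The notebook's own\n",
        "decoder is `decode_result_field` in section 2.\n",
        "\n",
        "```python\n",
        "ANSWER_LEN = 19  # 1 sign character + 6 chunks of (digit, carry, carry)\n",
        "CHUNK = 3\n",
        "\n",
        "def read_answer(field):\n",
        "    # Reject anything that does not match the fixed answer layout.\n",
        "    if len(field) != ANSWER_LEN or field[0] not in '+-':\n",
        "        raise ValueError('malformed answer field: ' + repr(field))\n",
        "    # Result digits sit at offsets 1, 4, 7, ...; they are stored least\n",
        "    # significant first, so reverse them before parsing.\n",
        "    digits = field[1::CHUNK][::-1]\n",
        "    value = int(digits)\n",
        "    return -value if field[0] == '-' else value\n",
        "\n",
        "for field in ['+800600100000000000', '-900900900000000000', '+108017026818909900']:\n",
        "    print(field, '->', read_answer(field))\n",
        "```\n",
        "\n",
        "The three fields decode to 168, -999, and 998001, matching the human-readable\n",
        "examples above."
      ]
    },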
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## 1. Runtime and package audit\n",
        "\n",
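        "The model and training code use JAX, Flax, and Optax, with NumPy for data\n",
        "generation. This cell installs Flax and Optax if they are missing. It then records the\n",
        "runtime, package versions, and attached devices so the run can be\n",
        "reproduced later, and it stops with an error if no GPU or TPU is attached."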
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import importlib.util\n",
        "import json\n",
        "import os\n",
        "import platform\n",
        "import subprocess\n",
        "import sys\n",
        "import time\n",
        "import gc\n",
        "\n",
        "missing = [pkg for pkg in (\"flax\", \"optax\") if importlib.util.find_spec(pkg) is None]\n",
        "if missing:\n",
        "    print(\"Installing missing packages:\", missing)\n",
        "    subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"-q\", *missing])\n",
        "    importlib.invalidate_caches()\n",
        "\n",
        "import jax\n",
        "import jax.numpy as jnp\n",
        "import jaxlib\n",
        "import flax\n",
        "from flax import linen as nn\n",
        "from flax.training import train_state\n",
        "from flax.serialization import to_bytes\n",
        "import numpy as np\n",
        "import optax\n",
        "from dataclasses import dataclass\n",
        "from functools import partial\n",
        "from typing import Any\n",
        "from IPython.display import Markdown, display\n",
        "\n",
        "# When rerunning inside one Colab runtime, release large objects from a\n",
        "# previous interrupted experiment before allocating the new model.\n",
        "for _name in [\n",
        "    \"all_tokens\", \"train_tokens\", \"val_tokens\", \"state\", \"params\",\n",
        "    \"initial_variables\", \"model\", \"optimizer\", \"lr_schedule\",\n",
        "    \"train_indices_by_op\", \"val_indices_by_op\",\n",
        "]:\n",
        "    globals().pop(_name, None)\n",
        "gc.collect()\n",
        "try:\n",
        "    jax.clear_caches()\n",
        "except Exception:\n",
        "    pass\n",
        "\n",
        "run_notes = []\n",
        "\n",
        "def safe_command(command):\n",
        "    try:\n",
        "        return subprocess.check_output(command, text=True, stderr=subprocess.STDOUT).strip()\n",
        "    except Exception as exc:\n",
        "        return f\"unavailable: {type(exc).__name__}: {exc}\"\n",
        "\n",
        "runtime_info = {\n",
        "    \"timestamp\": time.strftime(\"%Y-%m-%d %H:%M:%S %Z\"),\n",
        "    \"python\": platform.python_version(),\n",
        "    \"jax\": jax.__version__,\n",
        "    \"jaxlib\": jaxlib.__version__,\n",
        "    \"flax\": flax.__version__,\n",
        "    \"optax\": optax.__version__,\n",
        "    \"backend\": jax.default_backend(),\n",
        "    \"devices\": [str(device) for device in jax.devices()],\n",
        "    \"gpu_name\": safe_command([\"nvidia-smi\", \"--query-gpu=name\", \"--format=csv,noheader\"]),\n",
        "    \"nvidia_smi\": safe_command([\"nvidia-smi\"]),\n",
        "    \"platform\": platform.platform(),\n",
        "}\n",
        "\n",
        "print(json.dumps(runtime_info, indent=2))\n",
        "\n",
        "if runtime_info[\"backend\"] not in (\"gpu\", \"tpu\"):\n",
        "    raise RuntimeError(\n",
        "        \"No free TPU/GPU accelerator is attached. In Colab, open Runtime > Change runtime type, \"\n",
        "        \"choose a free TPU if available or a free T4 GPU otherwise, then rerun. Do not enable paid resources.\"\n",
        "    )"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## 2. Tokenization and fixed-width formatting\n",
        "\n",
        "The complete vocabulary is hard-coded:\n",
        "\n",
        "- digits `0` to `9`\n",
        "- space\n",
        "- `+`\n",
        "- `-`\n",
        "- `*`\n",
        "- `/`\n",
        "- `%`\n",
        "- `=`\n",
        "\n",
        "Every example has length 28:\n",
        "\n",
        "```text\n",
        "AAA??BBB=SDCCDCCDCCDCCDCCDCC\n",
        "```\n",
        "\n",
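        "where `AAA` and `BBB` are the right-aligned operands and `??` is a\n",
        "two-character operator field. Single-character operators are left-aligned\n",
        "and padded with one trailing space (`\"+ \"`, `\"- \"`, `\"* \"`, `\"% \"`), and\n",
        "integer division uses `\"//\"`. The answer field is one sign character `S`\n",
        "(`+` or `-`) followed by six chunks `DCC`. Each chunk is one result digit `D`,\n",
        "least significant digit first, followed by a two-digit carry `CC`. The\n",
        "local inference script reads only the result digits. The carry trace is\n",
        "trained anyway because in multiplication each result digit depends on the\n",
        "carry out of the previous column. Writing that carry explicitly is meant to\n",
        "let later digits condition on it. For the other operations the carry\n",
        "digits are always `00`."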
      ]
    },
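    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A worked example makes the carry trace concrete. This sketch redoes schoolbook\n",
        "multiplication column by column, least significant column first, the same\n",
        "way `multiplication_carry_trace` does in the code cell below. The helper name\n",
        "`column_carries` is hypothetical. For `123*45` the carries are 1, 2, 1, 0, 0, 0,\n",
        "which produce the `01`, `02`, `01` in `+501302501500000000`.\n",
        "\n",
        "```python\n",
        "def column_carries(a, b, width=6):\n",
        "    # Digits of each operand, least significant first, zero-padded to width.\n",
        "    a_digits = [(a // 10 ** i) % 10 for i in range(width)]\n",
        "    b_digits = [(b // 10 ** i) % 10 for i in range(width)]\n",
        "    carry = 0\n",
        "    rows = []\n",
        "    for column in range(width):\n",
        "        # Sum every digit product a_i * b_j with i + j == column, plus the\n",
        "        # carry coming in from the previous column.\n",
        "        total = carry + sum(a_digits[i] * b_digits[column - i] for i in range(column + 1))\n",
        "        # Keep the ones digit; everything above it is carried onward.\n",
        "        digit, carry = total % 10, total // 10\n",
        "        rows.append((column, total, digit, carry))\n",
        "    return rows\n",
        "\n",
        "for column, total, digit, carry in column_carries(123, 45):\n",
        "    print(f'column {column}: total={total:3d} digit={digit} carry={carry}')\n",
        "```\n",
        "\n",
        "For `999*999` the same helper gives carries 8, 17, 26, 18, 9, 0. Some\n",
        "carries therefore need two digits, which is why `CARRY_WIDTH = 2`."
      ]
    },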
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "VOCAB_CHARS = \"0123456789 +-*%/=\"\n",
        "char_to_id = {ch: i for i, ch in enumerate(VOCAB_CHARS)}\n",
        "id_to_char = {i: ch for ch, i in char_to_id.items()}\n",
        "\n",
        "OPS = (\"+\", \"-\", \"*\", \"//\", \"%\")\n",
        "DIVMOD_OPS = (\"//\", \"%\")\n",
        "VOCAB_SIZE = len(VOCAB_CHARS)\n",
        "OPERAND_WIDTH = 3\n",
        "OP_WIDTH = 2\n",
        "MAGNITUDE_WIDTH = 6\n",
        "CARRY_WIDTH = 2\n",
        "ANSWER_CHUNK_WIDTH = 1 + CARRY_WIDTH\n",
        "PROMPT_LEN = OPERAND_WIDTH + OP_WIDTH + OPERAND_WIDTH + 1\n",
        "ANSWER_LEN = 1 + MAGNITUDE_WIDTH * ANSWER_CHUNK_WIDTH\n",
        "SEQ_LEN = PROMPT_LEN + ANSWER_LEN\n",
        "OPERAND_VALUES = range(1000)\n",
        "NONZERO_OPERAND_VALUES = range(1, 1000)\n",
        "EXAMPLES_PER_STANDARD_OPERATION = 1000 * 1000\n",
        "EXAMPLES_PER_DIVMOD_OPERATION = 1000 * 999\n",
        "TOTAL_EXAMPLES = 3 * EXAMPLES_PER_STANDARD_OPERATION + 2 * EXAMPLES_PER_DIVMOD_OPERATION\n",
        "\n",
        "def compute_result(a: int, op: str, b: int) -> int:\n",
        "    if op == \"+\":\n",
        "        return a + b\n",
        "    if op == \"-\":\n",
        "        return a - b\n",
        "    if op == \"*\":\n",
        "        return a * b\n",
        "    if op == \"//\":\n",
        "        if b == 0:\n",
        "            raise ZeroDivisionError(\"integer division by zero\")\n",
        "        return a // b\n",
        "    if op == \"%\":\n",
        "        if b == 0:\n",
        "            raise ZeroDivisionError(\"modulo by zero\")\n",
        "        return a % b\n",
        "    raise ValueError(f\"unknown operator: {op}\")\n",
        "\n",
        "def rhs_values_for_op(op: str):\n",
        "    # Division and modulo by zero are undefined, so those operations skip\n",
        "    # b=0 instead of substituting or duplicating other examples.\n",
        "    return NONZERO_OPERAND_VALUES if op in DIVMOD_OPS else OPERAND_VALUES\n",
        "\n",
        "def multiplication_carry_trace(a: int, b: int) -> list[int]:\n",
        "    a_digits = [(a // (10 ** i)) % 10 for i in range(MAGNITUDE_WIDTH)]\n",
        "    b_digits = [(b // (10 ** i)) % 10 for i in range(MAGNITUDE_WIDTH)]\n",
        "    carry = 0\n",
        "    carries = []\n",
        "    for column in range(MAGNITUDE_WIDTH):\n",
        "        column_total = carry\n",
        "        for i in range(column + 1):\n",
        "            j = column - i\n",
        "            if i < len(a_digits) and j < len(b_digits):\n",
        "                column_total += a_digits[i] * b_digits[j]\n",
        "        carries.append(column_total // 10)\n",
        "        carry = column_total // 10\n",
        "    return carries\n",
        "\n",
        "def encode_result_value(value: int, a: int, op: str, b: int) -> str:\n",
        "    sign = \"-\" if value < 0 else \"+\"\n",
        "    magnitude = f\"{abs(value):0{MAGNITUDE_WIDTH}d}\"\n",
        "    result_digits = magnitude[::-1]\n",
        "    carries = multiplication_carry_trace(a, b) if op == \"*\" else [0] * MAGNITUDE_WIDTH\n",
        "    chunks = [sign]\n",
        "    for digit, carry in zip(result_digits, carries):\n",
        "        chunks.append(digit)\n",
        "        chunks.append(f\"{carry:0{CARRY_WIDTH}d}\")\n",
        "    return \"\".join(chunks)\n",
        "\n",
        "def decode_result_field(field: str) -> int:\n",
        "    if len(field) != ANSWER_LEN:\n",
        "        raise ValueError(f\"bad result length: {field!r}\")\n",
        "    sign = field[0]\n",
        "    if sign not in \"+-\":\n",
        "        raise ValueError(f\"bad result sign: {field!r}\")\n",
        "    digits = \"\".join(\n",
        "        field[1 + i * ANSWER_CHUNK_WIDTH]\n",
        "        for i in range(MAGNITUDE_WIDTH)\n",
        "    )[::-1]\n",
        "    if not digits.isdigit():\n",
        "        raise ValueError(f\"bad result digits: {field!r}\")\n",
        "    value = int(digits)\n",
        "    return -value if sign == \"-\" else value\n",
        "\n",
        "def safe_decode_result_field(field: str) -> str:\n",
        "    try:\n",
        "        return str(decode_result_field(field))\n",
        "    except Exception:\n",
        "        return f\"<invalid:{field}>\"\n",
        "\n",
        "def format_example(a: int, op: str, b: int) -> str:\n",
        "    # Right-align operands and emit sign + reversed digits + carry trace.\n",
        "    op_field = f\"{op:<{OP_WIDTH}}\"\n",
        "    result = encode_result_value(compute_result(a, op, b), a, op, b)\n",
        "    return f\"{a:>{OPERAND_WIDTH}}{op_field}{b:>{OPERAND_WIDTH}}={result}\"\n",
        "\n",
        "def encode_text(text: str) -> np.ndarray:\n",
        "    return np.array([char_to_id[ch] for ch in text], dtype=np.int32)\n",
        "\n",
        "def decode_ids(ids) -> str:\n",
        "    return \"\".join(id_to_char[int(i)] for i in ids)\n",
        "\n",
        "def human_readable(model_text: str) -> str:\n",
        "    left, answer = model_text.split(\"=\")\n",
        "    a = int(left[:OPERAND_WIDTH])\n",
        "    op = left[OPERAND_WIDTH:OPERAND_WIDTH + OP_WIDTH].strip()\n",
        "    b = int(left[OPERAND_WIDTH + OP_WIDTH:])\n",
        "    return f\"{a}{op}{b}={decode_result_field(answer)}\"\n",
        "\n",
        "assert VOCAB_SIZE == 17\n",
        "assert SEQ_LEN == 28\n",
        "assert PROMPT_LEN == 9\n",
        "assert ANSWER_LEN == 19\n",
        "assert len(format_example(999, \"*\", 999)) == SEQ_LEN\n",
        "assert len(format_example(5, \"//\", 4)) == SEQ_LEN\n",
        "assert len(format_example(5, \"%\", 4)) == SEQ_LEN\n",
        "assert set(format_example(0, \"-\", 999)).issubset(set(VOCAB_CHARS))\n",
        "\n",
        "for a, op, b in [(123, \"+\", 45), (123, \"-\", 45), (0, \"-\", 999), (123, \"*\", 45), (5, \"//\", 4), (5, \"%\", 4), (999, \"*\", 999)]:\n",
        "    text = format_example(a, op, b)\n",
        "    ids = encode_text(text)\n",
        "    print(f\"human: {human_readable(text):>16} | model text: {text!r} | token ids: {ids.tolist()}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## 3. Dataset\n",
        "\n",
        "We generate all valid ordered operand pairs for each operation:\n",
        "\n",
        "```text\n",
        "3 operations * 1000 * 1000 pairs = 3,000,000 rows\n",
        "2 operations * 1000 * 999 pairs  = 1,998,000 rows\n",
        "total                            = 4,998,000 rows\n",
        "```\n",
        "\n",
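        "Division and modulo by zero are undefined, so examples with `b=0` are\n",
        "omitted for `//` and `%`. The dataset therefore has 4,998,000 rows, each a\n",
        "unique valid arithmetic triple. A seeded random permutation holds out\n",
        "100,000 rows for validation and leaves 4,898,000 for training. Because\n",
        "every problem appears exactly once, no validation problem also appears in\n",
        "the training split."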
      ]
    },
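    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "One way to check the row count is to compare the closed form with brute-force\n",
        "enumeration on a smaller operand range. `count_rows` and `enumerate_rows` are\n",
        "hypothetical helpers for this check only. The notebook builds the real\n",
        "arrays in `build_full_dataset`.\n",
        "\n",
        "```python\n",
        "from itertools import product\n",
        "\n",
        "def count_rows(n_values):\n",
        "    # +, -, * use every ordered pair (n * n); // and % drop b = 0,\n",
        "    # which leaves n * (n - 1) pairs for each of them.\n",
        "    return 3 * n_values * n_values + 2 * n_values * (n_values - 1)\n",
        "\n",
        "def enumerate_rows(n_values):\n",
        "    # Brute force over a small operand range 0..n_values-1.\n",
        "    rows = []\n",
        "    for op in ('+', '-', '*', '//', '%'):\n",
        "        for a, b in product(range(n_values), repeat=2):\n",
        "            if op in ('//', '%') and b == 0:\n",
        "                continue\n",
        "            rows.append((a, op, b))\n",
        "    return rows\n",
        "\n",
        "small = enumerate_rows(10)\n",
        "print(len(small), count_rows(10), len(set(small)))  # all three agree\n",
        "print(count_rows(1000), count_rows(1000) - 100_000)  # total rows, training rows\n",
        "```\n",
        "\n",
        "For 1000 operand values the closed form gives 4,998,000 rows. Holding out\n",
        "100,000 of them leaves 4,898,000 training rows."
      ]
    },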
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "def build_full_dataset() -> tuple[np.ndarray, np.ndarray]:\n",
        "    tokens = np.empty((TOTAL_EXAMPLES, SEQ_LEN), dtype=np.int32)\n",
        "    op_ids = np.empty((TOTAL_EXAMPLES,), dtype=np.int8)\n",
        "    row = 0\n",
        "    for op_index, op in enumerate(OPS):\n",
        "        rhs_values = rhs_values_for_op(op)\n",
        "        for a in OPERAND_VALUES:\n",
        "            for b in rhs_values:\n",
        "                tokens[row] = encode_text(format_example(a, op, b))\n",
        "                op_ids[row] = op_index\n",
        "                row += 1\n",
        "    assert row == TOTAL_EXAMPLES\n",
        "    return tokens, op_ids\n",
        "\n",
        "dataset_start = time.perf_counter()\n",
        "all_tokens, op_ids = build_full_dataset()\n",
        "dataset_seconds = time.perf_counter() - dataset_start\n",
        "\n",
        "split_rng = np.random.default_rng(42)\n",
        "indices = split_rng.permutation(len(all_tokens))\n",
        "VAL_SIZE = 100_000\n",
        "val_tokens = all_tokens[indices[:VAL_SIZE]]\n",
        "train_tokens = all_tokens[indices[VAL_SIZE:]]\n",
        "val_ops = op_ids[indices[:VAL_SIZE]]\n",
        "train_ops = op_ids[indices[VAL_SIZE:]]\n",
        "train_indices_by_op = [np.flatnonzero(train_ops == op_index) for op_index in range(len(OPS))]\n",
        "val_indices_by_op = [np.flatnonzero(val_ops == op_index) for op_index in range(len(OPS))]\n",
        "\n",
        "# Labels are next-token positions. Answer labels begin at PROMPT_LEN.\n",
        "label_positions = np.arange(1, SEQ_LEN)\n",
        "loss_mask_template = (label_positions >= PROMPT_LEN).astype(np.float32)\n",
        "\n",
        "def make_batch_from_rows(split_tokens: np.ndarray, rows: np.ndarray, batch_size: int) -> dict[str, np.ndarray]:\n",
        "    seq = split_tokens[rows]\n",
        "    return {\n",
        "        \"inputs\": seq[:, :-1],\n",
        "        \"labels\": seq[:, 1:],\n",
        "        \"loss_mask\": np.broadcast_to(loss_mask_template, (batch_size, SEQ_LEN - 1)).copy(),\n",
        "    }\n",
        "\n",
        "def make_uniform_batch(rng: np.random.Generator, split_tokens: np.ndarray, batch_size: int) -> dict[str, np.ndarray]:\n",
        "    rows = rng.integers(0, len(split_tokens), size=batch_size)\n",
        "    return make_batch_from_rows(split_tokens, rows, batch_size)\n",
        "\n",
        "def make_weighted_op_batch(\n",
        "    rng: np.random.Generator,\n",
        "    split_tokens: np.ndarray,\n",
        "    indices_by_op: list[np.ndarray],\n",
        "    batch_size: int,\n",
        "    op_probs: np.ndarray,\n",
        ") -> dict[str, np.ndarray]:\n",
        "    chosen_ops = rng.choice(len(OPS), size=batch_size, p=op_probs)\n",
        "    rows = np.empty(batch_size, dtype=np.int64)\n",
        "    for op_index in range(len(OPS)):\n",
        "        mask = chosen_ops == op_index\n",
        "        count = int(mask.sum())\n",
        "        if count:\n",
        "            rows[mask] = rng.choice(indices_by_op[op_index], size=count, replace=True)\n",
        "    return make_batch_from_rows(split_tokens, rows, batch_size)\n",
        "\n",
        "dataset_info = {\n",
        "    \"operations\": list(OPS),\n",
        "    \"total_examples\": int(len(all_tokens)),\n",
        "    \"unique_valid_examples\": int(TOTAL_EXAMPLES),\n",
        "    \"train_examples\": int(len(train_tokens)),\n",
        "    \"validation_examples\": int(len(val_tokens)),\n",
        "    \"sequence_length\": SEQ_LEN,\n",
        "    \"prompt_length\": PROMPT_LEN,\n",
        "    \"answer_length\": ANSWER_LEN,\n",
        "    \"dataset_build_seconds\": round(dataset_seconds, 2),\n",
        "    \"dataset_memory_mb\": round(all_tokens.nbytes / 1e6, 2),\n",
        "    \"train_examples_by_operation\": {op: int(len(train_indices_by_op[i])) for i, op in enumerate(OPS)},\n",
        "    \"validation_examples_by_operation\": {op: int(len(val_indices_by_op[i])) for i, op in enumerate(OPS)},\n",
        "    \"loss_positions_in_original_sequence\": list(range(PROMPT_LEN, SEQ_LEN)),\n",
        "    \"division_modulo_zero_policy\": \"b=0 is undefined for // and %, so those examples are omitted rather than duplicated.\",\n",
        "}\n",
        "\n",
        "print(json.dumps(dataset_info, indent=2))\n",
        "print(\"Sample validation rows:\")\n",
        "for row in val_tokens[:8]:\n",
        "    text = decode_ids(row)\n",
        "    print(f\"  {text!r} -> {human_readable(text)}\")"
      ]
    },
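    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "`make_weighted_op_batch` samples in two stages. It first draws an operation for\n",
        "each batch slot from `op_probs`, then draws a row of that operation with\n",
        "replacement. The sketch below repeats that logic on a toy index table so the\n",
        "resulting batch composition can be compared with the expected counts. The\n",
        "function `sample_rows` and the toy data are hypothetical, and the weights are\n",
        "copied from `OP_PROBS` in section 5.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "OPS = ('+', '-', '*', '//', '%')\n",
        "OP_PROBS = np.array([0.05, 0.10, 0.20, 0.30, 0.35])  # same weights as section 5\n",
        "\n",
        "def sample_rows(rng, indices_by_op, batch_size, op_probs):\n",
        "    # Stage 1: pick an operation for every slot in the batch.\n",
        "    chosen_ops = rng.choice(len(indices_by_op), size=batch_size, p=op_probs)\n",
        "    rows = np.empty(batch_size, dtype=np.int64)\n",
        "    # Stage 2: fill those slots with rows of the chosen operation.\n",
        "    for op_index, op_rows in enumerate(indices_by_op):\n",
        "        mask = chosen_ops == op_index\n",
        "        if mask.any():\n",
        "            rows[mask] = rng.choice(op_rows, size=int(mask.sum()), replace=True)\n",
        "    return rows, chosen_ops\n",
        "\n",
        "# Toy split: 100 rows per operation, stored contiguously.\n",
        "indices_by_op = [np.arange(i * 100, (i + 1) * 100) for i in range(len(OPS))]\n",
        "rng = np.random.default_rng(0)\n",
        "rows, chosen_ops = sample_rows(rng, indices_by_op, 1024, OP_PROBS)\n",
        "print('expected per op:', dict(zip(OPS, (OP_PROBS * 1024).round(1).tolist())))\n",
        "print('sampled per op: ', dict(zip(OPS, np.bincount(chosen_ops, minlength=len(OPS)).tolist())))\n",
        "```\n",
        "\n",
        "Because every slot draws its operation independently, the per-operation\n",
        "counts vary from batch to batch around `OP_PROBS * batch_size`."
      ]
    },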
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## 4. Model architecture\n",
        "\n",
        "The model is a small GPT-style decoder-only transformer:\n",
        "\n",
        "- token embedding\n",
        "- learned position embedding\n",
        "- pre-LayerNorm blocks, each with causal self-attention and a GELU MLP on\n",
        "  residual connections\n",
        "- final LayerNorm and a next-character prediction head\n",
        "\n",
        "With the default configuration (`d_model=384`, 6 heads, 6 layers, MLP width\n",
        "1536), the model has about 10.65 million parameters. The mixed-operation\n",
        "task is harder than addition alone, but a model of this size is still\n",
        "small enough for a free accelerator session."
      ]
    },
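    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The parameter count can be checked by hand. Every `nn.Dense` in the model\n",
        "below has `use_bias=False`, and each Flax `nn.LayerNorm` has a scale vector\n",
        "and a bias vector by default. Under those assumptions, this sketch gives\n",
        "10,650,624 parameters for the default configuration, about 1.77 million per\n",
        "block. The helper `transformer_param_count` is hypothetical. Its result should\n",
        "match the `parameter_count` printed by the next cell.\n",
        "\n",
        "```python\n",
        "def transformer_param_count(vocab=17, seq_len=28, d_model=384, n_layers=6, mlp_dim=1536):\n",
        "    layer_norm = 2 * d_model                                # scale + bias\n",
        "    attention = d_model * 3 * d_model + d_model * d_model   # fused QKV + output projection\n",
        "    mlp = d_model * mlp_dim + mlp_dim * d_model             # up and down projections\n",
        "    block = 2 * layer_norm + attention + mlp\n",
        "    embeddings = vocab * d_model + seq_len * d_model        # token + learned position tables\n",
        "    head = layer_norm + d_model * vocab                     # final LayerNorm + lm_head\n",
        "    return embeddings + n_layers * block + head\n",
        "\n",
        "print(transformer_param_count())\n",
        "```\n",
        "\n",
        "Almost all of the parameters sit in the six blocks. The MLP alone accounts\n",
        "for two thirds of each block."
      ]
    },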
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "@dataclass\n",
        "class TransformerConfig:\n",
        "    vocab_size: int = VOCAB_SIZE\n",
        "    max_seq_len: int = SEQ_LEN\n",
        "    d_model: int = 384\n",
        "    n_heads: int = 6\n",
        "    n_layers: int = 6\n",
        "    mlp_dim: int = 1536\n",
        "\n",
        "def normal_init(stddev=0.02):\n",
        "    return nn.initializers.normal(stddev=stddev)\n",
        "\n",
        "class CausalSelfAttention(nn.Module):\n",
        "    d_model: int\n",
        "    n_heads: int\n",
        "\n",
        "    @nn.compact\n",
        "    def __call__(self, x):\n",
        "        batch, seq_len, width = x.shape\n",
        "        assert width == self.d_model\n",
        "        assert self.d_model % self.n_heads == 0\n",
        "        head_dim = self.d_model // self.n_heads\n",
        "\n",
        "        qkv = nn.Dense(3 * self.d_model, use_bias=False, kernel_init=normal_init())(x)\n",
        "        q, k, v = jnp.split(qkv, 3, axis=-1)\n",
        "\n",
        "        def split_heads(tensor):\n",
        "            tensor = tensor.reshape(batch, seq_len, self.n_heads, head_dim)\n",
        "            return tensor.transpose(0, 2, 1, 3)\n",
        "\n",
        "        q = split_heads(q)\n",
        "        k = split_heads(k)\n",
        "        v = split_heads(v)\n",
        "\n",
        "        scale = head_dim ** -0.5\n",
        "        scores = jnp.einsum(\"bhqd,bhkd->bhqk\", q, k) * scale\n",
        "        causal_mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))[None, None, :, :]\n",
        "        scores = jnp.where(causal_mask, scores, -1e10)\n",
        "        weights = nn.softmax(scores, axis=-1)\n",
        "\n",
        "        attended = jnp.einsum(\"bhqk,bhkd->bhqd\", weights, v)\n",
        "        attended = attended.transpose(0, 2, 1, 3).reshape(batch, seq_len, self.d_model)\n",
        "        return nn.Dense(self.d_model, use_bias=False, kernel_init=normal_init())(attended)\n",
        "\n",
        "class TransformerBlock(nn.Module):\n",
        "    config: TransformerConfig\n",
        "\n",
        "    @nn.compact\n",
        "    def __call__(self, x):\n",
        "        x = x + CausalSelfAttention(self.config.d_model, self.config.n_heads)(nn.LayerNorm()(x))\n",
        "        y = nn.LayerNorm()(x)\n",
        "        y = nn.Dense(self.config.mlp_dim, use_bias=False, kernel_init=normal_init())(y)\n",
        "        y = nn.gelu(y, approximate=True)\n",
        "        y = nn.Dense(self.config.d_model, use_bias=False, kernel_init=normal_init())(y)\n",
        "        return x + y\n",
        "\n",
        "class ArithmeticTransformer(nn.Module):\n",
        "    config: TransformerConfig\n",
        "\n",
        "    @nn.compact\n",
        "    def __call__(self, input_ids):\n",
        "        seq_len = input_ids.shape[1]\n",
        "        token_embed = nn.Embed(\n",
        "            num_embeddings=self.config.vocab_size,\n",
        "            features=self.config.d_model,\n",
        "            embedding_init=normal_init(),\n",
        "            name=\"token_embedding\",\n",
        "        )(input_ids)\n",
        "        pos_embed = self.param(\n",
        "            \"position_embedding\",\n",
        "            normal_init(),\n",
        "            (self.config.max_seq_len, self.config.d_model),\n",
        "        )\n",
        "        x = token_embed + pos_embed[None, :seq_len, :]\n",
        "\n",
        "        for _ in range(self.config.n_layers):\n",
        "            x = TransformerBlock(self.config)(x)\n",
        "\n",
        "        x = nn.LayerNorm()(x)\n",
        "        return nn.Dense(self.config.vocab_size, use_bias=False, kernel_init=normal_init(), name=\"lm_head\")(x)\n",
        "\n",
        "def count_params(params: Any) -> int:\n",
        "    return int(sum(leaf.size for leaf in jax.tree_util.tree_leaves(params)))\n",
        "\n",
        "config = TransformerConfig()\n",
        "model = ArithmeticTransformer(config)\n",
        "init_key = jax.random.PRNGKey(0)\n",
        "initial_variables = model.init(init_key, jnp.ones((1, SEQ_LEN - 1), dtype=jnp.int32))\n",
        "params = initial_variables[\"params\"]\n",
        "param_count = count_params(params)\n",
        "\n",
        "model_info = {\n",
        "    \"d_model\": config.d_model,\n",
        "    \"n_heads\": config.n_heads,\n",
        "    \"n_layers\": config.n_layers,\n",
        "    \"mlp_dim\": config.mlp_dim,\n",
        "    \"parameter_count\": param_count,\n",
        "    \"parameter_count_millions\": round(param_count / 1e6, 2),\n",
        "}\n",
        "\n",
        "print(json.dumps(model_info, indent=2))"
      ]
    },
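    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "`CausalSelfAttention` above lets each position attend only to itself and to\n",
        "earlier positions. It does this by replacing the upper triangle of the score\n",
        "matrix with `-1e10` before the softmax. The NumPy sketch below shows the effect\n",
        "for a single head: every weight above the diagonal is zero, and each row still\n",
        "sums to 1. The helper `causal_softmax` is hypothetical. During training, this\n",
        "masking stops an answer character from looking at later answer characters.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "def causal_softmax(scores):\n",
        "    # scores: (seq_len, seq_len) attention logits for one head.\n",
        "    seq_len = scores.shape[-1]\n",
        "    mask = np.tril(np.ones((seq_len, seq_len), dtype=bool))\n",
        "    # Future positions (k > q) get a huge negative logit, so exp() makes them 0.\n",
        "    masked = np.where(mask, scores, -1e10)\n",
        "    masked = masked - masked.max(axis=-1, keepdims=True)  # numerical stability\n",
        "    weights = np.exp(masked)\n",
        "    return weights / weights.sum(axis=-1, keepdims=True)\n",
        "\n",
        "rng = np.random.default_rng(0)\n",
        "weights = causal_softmax(rng.normal(size=(5, 5)))\n",
        "print(np.round(weights, 3))\n",
        "```"
      ]
    },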
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## 5. Loss, optimizer, and train state\n",
        "\n",
        "The model predicts the next character at every position, but the loss is\n",
        "applied only to the nineteen answer characters after the equals sign.\n",
        "\n",
        "Addition and subtraction are learned quickly. Multiplication, integer\n",
        "division, and modulo take longer, and modulo is usually the last operation\n",
        "to become reliable. Training batches therefore oversample the harder\n",
        "operations using `OP_PROBS` (5% `+`, 10% `-`, 20% `*`, 30% `//`, 35% `%`).\n",
        "The data itself does not change. Every row still comes from the same\n",
        "4,898,000-row training split; only the sampling frequency of each operation\n",
        "differs."
      ]
    },
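    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The loss mask from section 3 can be checked directly. Label index `t` holds\n",
        "the character at sequence position `t + 1`. The answer labels are therefore\n",
        "the labels whose target position is at least `PROMPT_LEN`. This NumPy sketch\n",
        "rebuilds the mask and confirms that it covers 19 positions. It also shows that\n",
        "even very large losses at prompt positions do not affect the masked average.\n",
        "The helper `masked_mean` is hypothetical.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "PROMPT_LEN, SEQ_LEN = 9, 28\n",
        "\n",
        "# Targets are positions 1..27; answer targets are positions 9..27.\n",
        "label_positions = np.arange(1, SEQ_LEN)\n",
        "loss_mask = (label_positions >= PROMPT_LEN).astype(np.float32)\n",
        "\n",
        "def masked_mean(per_token_loss, mask):\n",
        "    # Average only over unmasked (answer) positions, as loss_and_metrics does.\n",
        "    return (per_token_loss * mask).sum() / max(mask.sum(), 1.0)\n",
        "\n",
        "per_token_loss = np.ones(SEQ_LEN - 1, dtype=np.float32)\n",
        "per_token_loss[:PROMPT_LEN - 1] = 100.0  # large prompt losses are ignored\n",
        "print(int(loss_mask.sum()), masked_mean(per_token_loss, loss_mask))\n",
        "```"
      ]
    },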
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "BATCH_SIZE = 1024 if runtime_info[\"backend\"] == \"gpu\" else 2048\n",
        "MAX_STEPS = 8000\n",
        "MIN_STEPS = 5000\n",
        "EVAL_EVERY = 500\n",
        "EVAL_EXACT_N = 1000\n",
        "FINAL_EXACT_N = 20000\n",
        "LEARNING_RATE = 8e-4\n",
        "WEIGHT_DECAY = 1e-4\n",
        "WARMUP_STEPS = 400\n",
        "TARGET_EXACT_ACCURACY = 0.99\n",
        "OP_PROBS = np.array([0.05, 0.10, 0.20, 0.30, 0.35], dtype=np.float64)\n",
        "\n",
        "lr_schedule = optax.warmup_cosine_decay_schedule(\n",
        "    init_value=0.0,\n",
        "    peak_value=LEARNING_RATE,\n",
        "    warmup_steps=WARMUP_STEPS,\n",
        "    decay_steps=MAX_STEPS,\n",
        "    end_value=LEARNING_RATE * 0.1,\n",
        ")\n",
        "\n",
        "optimizer = optax.chain(\n",
        "    optax.clip_by_global_norm(1.0),\n",
        "    optax.adamw(learning_rate=lr_schedule, weight_decay=WEIGHT_DECAY),\n",
        ")\n",
        "\n",
        "state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=optimizer)\n",
        "\n",
        "def loss_and_metrics(params, batch):\n",
        "    logits = model.apply({\"params\": params}, batch[\"inputs\"])\n",
        "    per_token_loss = optax.softmax_cross_entropy_with_integer_labels(logits, batch[\"labels\"])\n",
        "    mask = batch[\"loss_mask\"]\n",
        "    denom = jnp.maximum(mask.sum(), 1.0)\n",
        "    loss = (per_token_loss * mask).sum() / denom\n",
        "    predictions = jnp.argmax(logits, axis=-1)\n",
        "    token_accuracy = ((predictions == batch[\"labels\"]) * mask).sum() / denom\n",
        "    return loss, {\"loss\": loss, \"token_accuracy\": token_accuracy}\n",
        "\n",
        "@jax.jit\n",
        "def train_step(state, batch):\n",
        "    (loss, metrics), grads = jax.value_and_grad(loss_and_metrics, has_aux=True)(state.params, batch)\n",
        "    metrics = dict(metrics)\n",
        "    # Record the learning rate used for this update (before step is incremented).\n",
        "    metrics[\"learning_rate\"] = lr_schedule(state.step)\n",
        "    state = state.apply_gradients(grads=grads)\n",
        "    return state, metrics\n",
        "\n",
        "@jax.jit\n",
        "def eval_step(state, batch):\n",
        "    _, metrics = loss_and_metrics(state.params, batch)\n",
        "    return metrics\n",
        "\n",
        "hyperparameters = {\n",
        "    \"batch_size\": BATCH_SIZE,\n",
        "    \"max_steps\": MAX_STEPS,\n",
        "    \"min_steps\": MIN_STEPS,\n",
        "    \"eval_every\": EVAL_EVERY,\n",
        "    \"learning_rate\": LEARNING_RATE,\n",
        "    \"weight_decay\": WEIGHT_DECAY,\n",
        "    \"warmup_steps\": WARMUP_STEPS,\n",
        "    \"target_exact_accuracy\": TARGET_EXACT_ACCURACY,\n",
        "    \"op_probs\": {op: float(OP_PROBS[i]) for i, op in enumerate(OPS)},\n",
        "}\n",
        "\n",
        "print(json.dumps(hyperparameters, indent=2))"
      ]
    },
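    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "For reference, this sketch shows the shape we expect from\n",
        "`optax.warmup_cosine_decay_schedule` with the settings above, based on the\n",
        "Optax documentation. The schedule ramps linearly from 0 to the peak over\n",
        "`WARMUP_STEPS`, then follows a half-cosine from the peak down to `end_value`\n",
        "over the remaining steps. Optax counts the warmup inside `decay_steps`.\n",
        "`warmup_cosine` is a hypothetical re-implementation for illustration only;\n",
        "the value that `lr_schedule(step)` returns is authoritative.\n",
        "\n",
        "```python\n",
        "import math\n",
        "\n",
        "def warmup_cosine(step, peak=8e-4, warmup=400, total=8000, end=8e-5):\n",
        "    # Linear warmup from 0 to peak.\n",
        "    if step < warmup:\n",
        "        return peak * step / warmup\n",
        "    # Half-cosine from peak to end over the remaining total - warmup steps.\n",
        "    progress = min((step - warmup) / (total - warmup), 1.0)\n",
        "    return end + (peak - end) * 0.5 * (1.0 + math.cos(math.pi * progress))\n",
        "\n",
        "for step in (0, 200, 400, 4200, 8000):\n",
        "    print(step, f'{warmup_cosine(step):.2e}')\n",
        "```\n",
        "\n",
        "Halfway through the decay (step 4200), the rate is midway between the peak\n",
        "and the floor, 4.4e-4."
      ]
    },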
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## 6. Generation and exact-answer checks\n",
        "\n",
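        "Greedy generation starts from a prompt like `123*  45=` or `  5//  4=`\n",
        "and fills the nineteen answer characters one at a time. `generate_position`\n",
        "is JIT-compiled with the position as a static argument. JAX therefore compiles\n",
        "one version per answer position (and per batch shape) and reuses it on later\n",
        "calls, which matters because exact-answer evaluation calls it many times.\n",
        "There is no key/value cache: each step reruns the model on the whole prefix,\n",
        "which is acceptable for 28-character sequences."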
      ]
    },
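    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The decoding loop itself is simple. The sketch below uses plain Python and a\n",
        "toy scorer in place of the transformer. `greedy_fill`, `oracle_scores`, and\n",
        "`TARGET` are hypothetical, and the oracle simply puts all of its score on the\n",
        "correct next character. The JAX code below does the same thing with\n",
        "`argmax` over real logits.\n",
        "\n",
        "```python\n",
        "VOCAB = '0123456789 +-*%/='\n",
        "TARGET = '123*  45=+501302501500000000'  # prompt + correct 19-character answer\n",
        "\n",
        "def greedy_fill(prompt, answer_len, next_char_scores, vocab):\n",
        "    # Append the highest-scoring character one position at a time; each new\n",
        "    # character becomes part of the prefix for the next step.\n",
        "    text = prompt\n",
        "    for _ in range(answer_len):\n",
        "        scores = next_char_scores(text)  # one score per vocab entry\n",
        "        text += vocab[scores.index(max(scores))]\n",
        "    return text\n",
        "\n",
        "def oracle_scores(prefix):\n",
        "    # Stand-in for the model: score 1.0 for the correct next character only.\n",
        "    wanted = TARGET[len(prefix)]\n",
        "    return [1.0 if ch == wanted else 0.0 for ch in VOCAB]\n",
        "\n",
        "print(greedy_fill('123*  45=', 19, oracle_scores, VOCAB))\n",
        "```"
      ]
    },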
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "def make_prompt_tokens(problems):\n",
        "    tokens = np.full((len(problems), SEQ_LEN), char_to_id[\" \"], dtype=np.int32)\n",
        "    for i, (a, op, b) in enumerate(problems):\n",
        "        prompt = f\"{a:>{OPERAND_WIDTH}}{op:<{OP_WIDTH}}{b:>{OPERAND_WIDTH}}=\"\n",
        "        tokens[i, :PROMPT_LEN] = encode_text(prompt)\n",
        "    return tokens\n",
        "\n",
        "@partial(jax.jit, static_argnums=(2,))\n",
        "def generate_position(params, tokens, pos: int):\n",
        "    logits = model.apply({\"params\": params}, tokens[:, :pos])\n",
        "    next_ids = jnp.argmax(logits[:, -1, :], axis=-1).astype(jnp.int32)\n",
        "    return tokens.at[:, pos].set(next_ids)\n",
        "\n",
        "def greedy_generate(params, problems):\n",
        "    tokens = make_prompt_tokens(problems)\n",
        "    tokens = jnp.asarray(tokens, dtype=jnp.int32)\n",
        "    for pos in range(PROMPT_LEN, SEQ_LEN):\n",
        "        tokens = generate_position(params, tokens, pos)\n",
        "    tokens = np.asarray(tokens)\n",
        "    return [decode_ids(row) for row in tokens]\n",
        "\n",
        "def parse_model_text(model_text):\n",
        "    left, answer = model_text.split(\"=\")\n",
        "    a = int(left[:OPERAND_WIDTH])\n",
        "    op = left[OPERAND_WIDTH:OPERAND_WIDTH + OP_WIDTH].strip()\n",
        "    b = int(left[OPERAND_WIDTH + OP_WIDTH:])\n",
        "    return a, op, b, decode_result_field(answer)\n",
        "\n",
        "def prediction_rows(params, problems):\n",
        "    generated = greedy_generate(params, problems)\n",
        "    rows = []\n",
        "    for (a, op, b), model_text in zip(problems, generated):\n",
        "        pred_text = safe_decode_result_field(model_text.split(\"=\")[1])\n",
        "        expected = str(compute_result(a, op, b))\n",
        "        rows.append({\n",
        "            \"problem\": f\"{a}{op}{b}\",\n",
        "            \"model_text\": model_text,\n",
        "            \"prediction\": pred_text,\n",
        "            \"expected\": expected,\n",
        "            \"correct\": pred_text == expected,\n",
        "        })\n",
        "    return rows\n",
        "\n",
        "def print_prediction_rows(rows):\n",
        "    print(f\"{'problem':>10} | {'model text':>30} | {'pred':>8} | {'expected':>8} | correct\")\n",
        "    print(\"-\" * 75)\n",
        "    for row in rows:\n",
        "        print(\n",
        "            f\"{row['problem']:>10} | {row['model_text']!r:>30} | \"\n",
        "            f\"{row['prediction']:>8} | {row['expected']:>8} | {row['correct']}\"\n",
        "        )\n",
        "\n",
        "def exact_answer_accuracy(params, split_tokens, n_examples=1000, seed=0):\n",
        "    rng = np.random.default_rng(seed)\n",
        "    n_examples = min(n_examples, len(split_tokens))\n",
        "    rows = rng.choice(len(split_tokens), size=n_examples, replace=False)\n",
        "    problems = []\n",
        "    expected = []\n",
        "    op_expected = {op: [] for op in OPS}\n",
        "    op_correct = {op: [] for op in OPS}\n",
        "    for row in split_tokens[rows]:\n",
        "        a, op, b, answer = parse_model_text(decode_ids(row))\n",
        "        problems.append((a, op, b))\n",
        "        expected_text = str(answer)\n",
        "        expected.append(expected_text)\n",
        "        op_expected[op].append(expected_text)\n",
        "\n",
        "    generated = greedy_generate(params, problems)\n",
        "    predicted = [safe_decode_result_field(text.split(\"=\")[1]) for text in generated]\n",
        "    correct = [pred == exp for pred, exp in zip(predicted, expected)]\n",
        "\n",
        "    for (_, op, _), pred, exp in zip(problems, predicted, expected):\n",
        "        op_correct[op].append(pred == exp)\n",
        "\n",
        "    by_operation = {}\n",
        "    for op in OPS:\n",
        "        total = len(op_expected[op])\n",
        "        good = int(np.sum(op_correct[op])) if total else 0\n",
        "        by_operation[op] = {\n",
        "            \"correct\": good,\n",
        "            \"total\": total,\n",
        "            \"accuracy\": float(good / total) if total else None,\n",
        "        }\n",
        "\n",
        "    return float(np.mean(correct)), int(np.sum(correct)), int(n_examples), by_operation\n",
        "\n",
        "sample_problems = [\n",
        "    (0, \"+\", 0),\n",
        "    (7, \"+\", 35),\n",
        "    (123, \"-\", 45),\n",
        "    (45, \"-\", 123),\n",
        "    (12, \"*\", 89),\n",
        "    (123, \"*\", 45),\n",
        "    (5, \"//\", 4),\n",
        "    (999, \"//\", 1),\n",
        "    (5, \"%\", 4),\n",
        "    (998, \"%\", 999),\n",
        "    (501, \"*\", 499),\n",
        "    (999, \"*\", 999),\n",
        "]\n",
        "\n",
        "print(\"Before training:\")\n",
        "before_training_rows = prediction_rows(state.params, sample_problems)\n",
        "print_prediction_rows(before_training_rows)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## 7. Training loop\n",
        "\n",
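        "Each step samples a training batch with the per-operation weights `OP_PROBS`.\n",
        "At step 1 and every `EVAL_EVERY` steps, the loop logs training and validation\n",
        "loss and validation token accuracy. It also logs exact-answer accuracy on a sample\n",
        "of `EVAL_EXACT_N` validation problems, overall and per operation.\n",
        "Multiplication, integer division, and modulo usually lag behind addition and\n",
        "subtraction, so the per-operation numbers are more informative than one\n",
        "aggregate. Training stops early once the sampled exact-answer accuracy reaches\n",
        "`TARGET_EXACT_ACCURACY`, but never before `MIN_STEPS` steps."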
      ]
    },
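    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Here is a minimal plain-Python sketch of the per-operation bookkeeping and the early-stop test.\n",
        "The helpers `accuracy_by_operation` and `should_stop` are hypothetical stand-ins for\n",
        "logic that the notebook writes inline. The predictions are toy inputs, not results\n",
        "from this run. On these inputs the aggregate is 4/5 correct, while multiplication\n",
        "alone is only 1/2.\n",
        "\n",
        "```python\n",
        "from collections import defaultdict\n",
        "\n",
        "\n",
        "def accuracy_by_operation(ops, predictions, expected):\n",
        "    # op -> [number correct, number seen]\n",
        "    counts = defaultdict(lambda: [0, 0])\n",
        "    for op, pred, exp in zip(ops, predictions, expected):\n",
        "        counts[op][0] += int(pred == exp)\n",
        "        counts[op][1] += 1\n",
        "    return {\n",
        "        op: {'correct': good, 'total': total, 'accuracy': good / total}\n",
        "        for op, (good, total) in counts.items()\n",
        "    }\n",
        "\n",
        "\n",
        "def should_stop(step, exact_acc, min_steps, target):\n",
        "    # Same condition as the training loop: minimum step count AND target accuracy.\n",
        "    return step >= min_steps and exact_acc >= target\n",
        "\n",
        "\n",
        "# Toy inputs: 12*89 should be 1068, so the second '*' prediction is wrong.\n",
        "ops = ['+', '+', '*', '*', '//']\n",
        "predictions = ['168', '42', '5535', '1000', '1']\n",
        "expected = ['168', '42', '5535', '1068', '1']\n",
        "\n",
        "stats = accuracy_by_operation(ops, predictions, expected)\n",
        "print(stats['*'])  # {'correct': 1, 'total': 2, 'accuracy': 0.5}\n",
        "print(should_stop(step=500, exact_acc=0.99, min_steps=1000, target=0.98))  # False\n",
        "```\n",
        "\n",
        "Both conditions must hold, so a lucky early evaluation sample cannot end training\n",
        "before the minimum step count."
      ]
    },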
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "batch_rng = np.random.default_rng(1234)\n",
        "history = []\n",
        "train_start = time.perf_counter()\n",
        "\n",
        "for step in range(1, MAX_STEPS + 1):\n",
        "    batch_np = make_weighted_op_batch(batch_rng, train_tokens, train_indices_by_op, BATCH_SIZE, OP_PROBS)\n",
        "    batch = {name: jnp.asarray(value) for name, value in batch_np.items()}\n",
        "    state, train_metrics = train_step(state, batch)\n",
        "\n",
        "    if step == 1 or step % EVAL_EVERY == 0:\n",
        "        val_batch_np = make_uniform_batch(batch_rng, val_tokens, BATCH_SIZE)\n",
        "        val_batch = {name: jnp.asarray(value) for name, value in val_batch_np.items()}\n",
        "        val_metrics = eval_step(state, val_batch)\n",
        "\n",
        "        train_metrics = jax.device_get(train_metrics)\n",
        "        val_metrics = jax.device_get(val_metrics)\n",
        "        exact_acc, exact_correct, exact_total, by_operation = exact_answer_accuracy(\n",
        "            state.params, val_tokens, n_examples=EVAL_EXACT_N, seed=step\n",
        "        )\n",
        "        elapsed = time.perf_counter() - train_start\n",
        "        record = {\n",
        "            \"step\": step,\n",
        "            \"elapsed_seconds\": round(elapsed, 2),\n",
        "            \"train_loss\": float(train_metrics[\"loss\"]),\n",
        "            \"train_token_accuracy\": float(train_metrics[\"token_accuracy\"]),\n",
        "            \"val_loss\": float(val_metrics[\"loss\"]),\n",
        "            \"val_token_accuracy\": float(val_metrics[\"token_accuracy\"]),\n",
        "            \"val_exact_accuracy_sample\": exact_acc,\n",
        "            \"val_exact_correct\": exact_correct,\n",
        "            \"val_exact_total\": exact_total,\n",
        "            \"val_exact_by_operation\": by_operation,\n",
        "            \"learning_rate\": float(train_metrics[\"learning_rate\"]),\n",
        "            \"op_sampling\": {\n",
        "                op: float(OP_PROBS[i])\n",
        "                for i, op in enumerate(OPS)\n",
        "            },\n",
        "        }\n",
        "        history.append(record)\n",
        "        per_op = \" | \".join(\n",
        "            f\"{op}: {stats['correct']}/{stats['total']} ({stats['accuracy']:.3f})\"\n",
        "            for op, stats in by_operation.items()\n",
        "            if stats[\"total\"]\n",
        "        )\n",
        "        print(\n",
        "            f\"step {step:5d} | \"\n",
        "            f\"train loss {record['train_loss']:.4f} | \"\n",
        "            f\"val loss {record['val_loss']:.4f} | \"\n",
        "            f\"val token acc {record['val_token_accuracy']:.3f} | \"\n",
        "            f\"exact {exact_correct}/{exact_total} ({exact_acc:.3f}) | \"\n",
        "            f\"{per_op} | elapsed {elapsed:.1f}s\"\n",
        "        )\n",
        "\n",
        "        if step >= MIN_STEPS and exact_acc >= TARGET_EXACT_ACCURACY:\n",
        "            print(f\"Early stop: exact-answer sample accuracy reached {exact_acc:.3f}.\")\n",
        "            break\n",
        "\n",
        "train_seconds = time.perf_counter() - train_start\n",
        "print(f\"Training finished in {train_seconds:.1f} seconds.\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## 8. After-training predictions and validation accuracy\n",
        "\n",
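        "The final check reruns the sample problems shown before training. It then\n",
        "measures exact-answer accuracy on a larger validation sample of `FINAL_EXACT_N`\n",
        "problems (fixed seed 2026), overall and per operation."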
      ]
    },
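    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Exact-answer accuracy compares decoded answers, not raw model text. Below is a minimal\n",
        "sketch of that decoding. It assumes the answer-field layout described at the top of the\n",
        "notebook: one sign character, then six 3-character groups. Each group holds a result digit\n",
        "(least significant first) followed by two carry-trace digits.\n",
        "`decode_answer_field_sketch` is a hypothetical name. The notebook's own\n",
        "`safe_decode_result_field` may treat malformed text differently.\n",
        "\n",
        "```python\n",
        "def decode_answer_field_sketch(field):\n",
        "    # Expect one sign character plus 6 groups of 3 digits (19 characters total).\n",
        "    if len(field) != 19 or field[0] not in '+-' or not field[1:].isdigit():\n",
        "        return None  # malformed generation\n",
        "    sign, body = field[0], field[1:]\n",
        "    # The first character of each group is a result digit; the two\n",
        "    # carry-trace digits after it are ignored here.\n",
        "    digits_lsb_first = body[0::3]\n",
        "    value = int(digits_lsb_first[::-1])  # reverse to most-significant-first\n",
        "    return str(-value if sign == '-' else value)\n",
        "\n",
        "\n",
        "for field in ['+800600100000000000',   # 123+45\n",
        "              '-900900900000000000',   # 0-999\n",
        "              '+108017026818909900']:  # 999*999\n",
        "    print(field, '->', decode_answer_field_sketch(field))\n",
        "```\n",
        "\n",
        "Returning `None` instead of raising keeps one malformed generation from aborting\n",
        "a whole accuracy pass. The malformed answer is simply counted as wrong."
      ]
    },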
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "print(\"After training:\")\n",
        "after_training_rows = prediction_rows(state.params, sample_problems)\n",
        "print_prediction_rows(after_training_rows)\n",
        "\n",
        "final_exact_accuracy, final_exact_correct, final_exact_total, final_by_operation = exact_answer_accuracy(\n",
        "    state.params, val_tokens, n_examples=FINAL_EXACT_N, seed=2026\n",
        ")\n",
        "print(\n",
        "    f\"Final exact-answer validation accuracy: \"\n",
        "    f\"{final_exact_correct}/{final_exact_total} = {final_exact_accuracy:.3f}\"\n",
        ")\n",
        "print(json.dumps(final_by_operation, indent=2))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## 9. Save the trained model\n",
        "\n",
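        "This cell serializes the trained parameters with `to_bytes` (Flax's msgpack\n",
        "format) and writes them to `/content/arithmetic_transformer_params.msgpack`.\n",
        "The name distinguishes this multi-operation model from the earlier\n",
        "addition-only artifact. Files under `/content` are deleted when the Colab\n",
        "runtime is recycled, so download the file if you need to keep it."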
      ]
    },
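    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To reuse the file later, read the bytes back with `flax.serialization.from_bytes`.\n",
        "This sketch assumes the save cell's `to_bytes` is `flax.serialization.to_bytes`.\n",
        "`from_bytes` needs a template pytree with the same structure. One example is freshly\n",
        "initialized parameters from the same model configuration. The sketch round-trips a\n",
        "small toy parameter tree in memory. The tree and its key names are assumptions, not\n",
        "this model's real parameters.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "from flax.serialization import from_bytes, to_bytes\n",
        "\n",
        "# Toy stand-in for state.params (an assumption for this sketch).\n",
        "params = {\n",
        "    'dense': {\n",
        "        'kernel': np.arange(6, dtype=np.float32).reshape(2, 3),\n",
        "        'bias': np.zeros(3, dtype=np.float32),\n",
        "    }\n",
        "}\n",
        "blob = to_bytes(params)  # msgpack bytes, as written by the save cell\n",
        "\n",
        "# from_bytes fills a template that has the same tree structure.\n",
        "template = {\n",
        "    'dense': {\n",
        "        'kernel': np.zeros((2, 3), dtype=np.float32),\n",
        "        'bias': np.ones(3, dtype=np.float32),\n",
        "    }\n",
        "}\n",
        "restored = from_bytes(template, blob)\n",
        "print(np.array_equal(restored['dense']['kernel'], params['dense']['kernel']))  # True\n",
        "```\n",
        "\n",
        "In the notebook, the same pattern is to open `model_path` in `'rb'` mode and call\n",
        "`from_bytes(template_params, f.read())`. Here `template_params` is a hypothetical\n",
        "name for freshly initialized parameters of the same model."
      ]
    },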
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "model_path = \"/content/arithmetic_transformer_params.msgpack\"\n",
        "with open(model_path, \"wb\") as f:\n",
        "    f.write(to_bytes(state.params))\n",
        "\n",
        "print(f\"Saved trained parameters to: {model_path}\")\n",
        "print(f\"File size: {os.path.getsize(model_path) / 1e6:.2f} MB\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## 10. Run summary\n",
        "\n",
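        "The summary gathers the runtime, package versions, and dataset and model details.\n",
        "It also records the hyperparameters, training metrics, sample predictions, and\n",
        "observations, all in one JSON record for a future tutorial or blog post."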
      ]
    },
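    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Two facts in the observations recorded below can be checked with plain arithmetic.\n",
        "The first is the dataset size after dropping division and modulo by zero. The second\n",
        "is that six digits are enough for the largest product. The check assumes both\n",
        "operands range over 0 to 999 and uses the five operators listed at the top of the\n",
        "notebook.\n",
        "\n",
        "```python\n",
        "n_operands = 1000  # a and b each range over 0..999\n",
        "ops = ['+', '-', '*', '//', '%']\n",
        "\n",
        "all_rows = len(ops) * n_operands * n_operands  # every (a, op, b) triple\n",
        "zero_divisor_rows = 2 * n_operands             # b == 0 for '//' and '%'\n",
        "valid_rows = all_rows - zero_divisor_rows\n",
        "\n",
        "max_product = 999 * 999\n",
        "print(f'{valid_rows:,}')                   # 4,998,000\n",
        "print(max_product, len(str(max_product)))  # 998001 6\n",
        "```"
      ]
    },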
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "final_history = history[-1] if history else {}\n",
        "summary = {\n",
        "    \"runtime\": runtime_info,\n",
        "    \"packages\": {\n",
        "        \"jax\": runtime_info[\"jax\"],\n",
        "        \"jaxlib\": runtime_info[\"jaxlib\"],\n",
        "        \"flax\": runtime_info[\"flax\"],\n",
        "        \"optax\": runtime_info[\"optax\"],\n",
        "    },\n",
        "    \"dataset\": dataset_info,\n",
        "    \"model\": model_info,\n",
        "    \"hyperparameters\": hyperparameters,\n",
        "    \"training\": {\n",
        "        \"actual_steps\": int(step),  # loop counter; the last logged step can be earlier\n",
        "        \"training_seconds\": round(train_seconds, 2),\n",
        "        \"final_logged_train_loss\": final_history.get(\"train_loss\"),\n",
        "        \"final_logged_val_loss\": final_history.get(\"val_loss\"),\n",
        "        \"final_logged_val_token_accuracy\": final_history.get(\"val_token_accuracy\"),\n",
        "        \"final_sample_exact_accuracy\": final_history.get(\"val_exact_accuracy_sample\"),\n",
        "        \"final_sample_exact_by_operation\": final_history.get(\"val_exact_by_operation\"),\n",
        "        \"final_validation_exact_accuracy\": final_exact_accuracy,\n",
        "        \"final_validation_exact_correct\": final_exact_correct,\n",
        "        \"final_validation_exact_total\": final_exact_total,\n",
        "        \"final_validation_exact_by_operation\": final_by_operation,\n",
        "    },\n",
        "    \"before_training_predictions\": before_training_rows,\n",
        "    \"after_training_predictions\": after_training_rows,\n",
        "    \"errors_encountered\": run_notes if run_notes else [\"None recorded in this notebook run.\"],\n",
        "    \"model_artifact\": model_path,\n",
        "    \"observations\": [\n",
        "        \"The same character-level transformer can model multiple operations when the operator is part of the prompt.\",\n",
        "        \"Multiplication, integer division, and modulo are materially harder than addition and subtraction, so per-operation accuracy matters.\",\n",
        "        \"Operation-aware sampling helped the harder operations without changing the generated dataset.\",\n",
        "        \"The slash character appears twice for integer division because the model is character-level.\",\n",
        "        \"The answer field is sign plus reversed result digits, with two carry-trace digits after each digit.\",\n",
        "        \"The minus character is overloaded: it can be an operator in the prompt or a sign in the result field.\",\n",
        "        \"The six result digits handle both negative subtraction results and 999*999=998001.\",\n",
        "        \"Division and modulo by zero are omitted, so the generated dataset has 4,998,000 unique valid rows.\",\n",
        "    ],\n",
        "}\n",
        "\n",
        "print(json.dumps(summary, indent=2))\n",
        "\n",
        "# The Markdown text is left-aligned on purpose: lines indented by four or more\n",
        "# spaces would render as a code block instead of a heading and a list.\n",
        "display(Markdown(\n",
        "    f\"\"\"\n",
        "### Concise run summary\n",
        "\n",
        "- Runtime backend: `{runtime_info['backend']}`\n",
        "- Devices: `{runtime_info['devices']}`\n",
        "- Parameters: `{model_info['parameter_count_millions']}M`\n",
        "- Dataset: `{dataset_info['train_examples']}` train / `{dataset_info['validation_examples']}` validation examples\n",
        "- Operations: `{dataset_info['operations']}`\n",
        "- Sequence format: fixed length `{SEQ_LEN}`, prompt length `{PROMPT_LEN}`, answer length `{ANSWER_LEN}`\n",
        "- Training: `{summary['training']['actual_steps']}` steps in `{summary['training']['training_seconds']}` seconds\n",
        "- Final validation exact accuracy: `{final_exact_correct}/{final_exact_total}` = `{final_exact_accuracy:.3f}`\n",
        "- Per-operation final accuracy: `{final_by_operation}`\n",
        "- Model artifact: `{model_path}`\n",
        "\"\"\"\n",
        "))"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "name": "arithmetic_transformer_jax_flax_colab.ipynb",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}
