Training a 10M-Parameter Transformer to Learn 3-Digit Arithmetic
A code-first experiment that builds a character-level decoder-only transformer with JAX, Flax, and Optax, then trains it on generated addition, subtraction, multiplication, integer division, and modulo problems.
I trained a 10.65M-parameter decoder-only transformer on five 3-digit arithmetic operations. The goal was not to build a calculator. It was to make every part of a small language-model training run visible and auditable: data generation, tokenization, masking, model definition, optimization, evaluation, serialization, and local inference.
Addition alone is too narrow for this kind of experiment, so the task includes addition, subtraction, multiplication, integer division, and modulo over operands from 0 to 999. Division and modulo by zero are omitted. This is a bounded synthetic task, not evidence of human-like arithmetic or symbolic reasoning.
For orientation, here is the headline result before getting into the implementation. The final run scored 19,858 / 20,000 = 99.29% exact integer accuracy on random validation prompts:
| Operation | Correct | Total | Accuracy |
|---|---|---|---|
| `+` | 4,022 | 4,022 | 100.00% |
| `-` | 3,965 | 3,965 | 100.00% |
| `*` | 4,084 | 4,084 | 100.00% |
| `//` | 4,007 | 4,028 | 99.48% |
| `%` | 3,780 | 3,901 | 96.90% |
Modulo remained the weakest operation in this run. I treat that as an observed failure pattern worth testing further, not as a deep claim about transformers.
The score is only meaningful once the sequence task is precise. The next sections start with the ordinary arithmetic examples, then show why the exact string representation matters for a causal language model.
The Task
The human-readable task is ordinary integer arithmetic:
```text
123+45=168
45-123=-78
123*45=5535
5//4=1
5%4=1
999*999=998001
```
The model never receives Python integers. It receives characters. The operator is part of the prompt, and the answer is generated one token at a time by a GPT-style decoder-only transformer.
The useful question is not whether this replaces a calculator. It does not. The useful question is whether a small, fully visible training setup can be used to study how sequence models learn algorithmic-looking behavior under controlled assumptions.
That last phrase, “sequence models”, is the key constraint. Arithmetic on paper is usually carried out from low-order digits to high-order digits. A decoder-only transformer generates from left to right. The representation below is the bridge between those two directions.
Why The Representation Matters
A decoder-only transformer can only condition on tokens to its left. When it predicts the first answer token after =, it has seen the prompt but none of the answer tokens it will generate later. That is natural for prose. It is awkward for arithmetic.
In ordinary decimal arithmetic, the low-order columns decide the carries used by later columns. For example, the ones column of 123 * 45 creates a carry that affects the tens column, which creates another carry, and so on. The mechanical procedure runs from right to left:
```text
ones column -> tens column -> hundreds column -> thousands column
```
A normal answer string is emitted in the opposite direction. For 123 * 45, the final answer is written as 5535, so direct left-to-right generation asks for the thousands digit before the model has emitted the ones digit or any carry information.
For 123 * 45, direct generation asks for this:
```text
123*45=5535
```
This is not impossible for a transformer. The prompt contains both operands, so the network can in principle compute internal features that represent the whole product before emitting any answer token. But it makes the learned algorithm less visible and puts more pressure on hidden activations: the model has to learn arithmetic and also invent its own scratch space for carries.
The representation I used makes the supervised sequence closer to the arithmetic procedure:
```text
human-readable arithmetic  ->  model training format           ->  decoded result
123*45=5535                ->  "123*  45=+501302501500000000"  ->  5535
```
The change has three practical effects:
- Fixed-width fields keep the same semantic information at the same token positions.
- Reversed result digits make the answer run from low-order digits to high-order digits.
- Carry-trace digits expose multiplication carries as explicit scratchpad tokens.
I did not run a separate ablation for every format choice, so I would not claim a measured causal improvement for each one. The important claim is narrower: this representation deliberately makes the task easier for an autoregressive model by matching the output order to the dependency structure of column arithmetic.
The fixed format is:
```text
AAA??BBB=SDCCDCCDCCDCCDCCDCC
```

- `AAA`: first operand, right-aligned to width 3
- `??`: two-character operator field, so `//` and one-character operators share positions
- `BBB`: second operand, right-aligned to width 3
- `=`: delimiter at a fixed position
- `S`: result sign, `+` or `-`
- `DCC`: one reversed result digit followed by two carry-trace digits
Here is the construction recipe for one example:
- Format operand `a` in a 3-character right-aligned field.
- Format the operator in a 2-character left-aligned field, so `+` and `//` occupy the same positions.
- Format operand `b` in a 3-character right-aligned field.
- Add `=`, then the result sign.
- Convert the absolute result to a fixed 6-digit decimal string.
- Reverse those 6 result digits.
- After each reversed digit, add a 2-character carry trace.
The carry trace is easiest to understand for multiplication. A carry-trace value is the carry that remains after processing one base-10 multiplication column. It is computed as:
```text
carry_after_column = floor(column_total / 10)
```

where column_total includes the previous carry plus all digit products that land in the current column, and the emitted result digit is column_total mod 10. The field is two characters wide because carries such as 17 and 26 occur in this 3-digit multiplication range (both appear in 999 * 999).
For 123 * 45, the multiplication columns are:
| Column | Products and incoming carry | Column total | Result digit | Carry trace |
|---|---|---|---|---|
| 0 | 3*5 | 15 | 5 | 01 |
| 1 | 2*5 + 3*4 + 1 | 23 | 3 | 02 |
| 2 | 1*5 + 2*4 + 2 | 15 | 5 | 01 |
| 3 | 1*4 + 1 | 5 | 5 | 00 |
| 4 | no remaining product terms | 0 | 0 | 00 |
| 5 | no remaining product terms | 0 | 0 | 00 |
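The table above can be regenerated mechanically. The following is a minimal standalone sketch (the helper name `multiplication_columns` is hypothetical, not taken from the experiment's code) that returns the column total, result digit, and outgoing carry for each base-10 column, lowest column first:

```python
def multiplication_columns(a: int, b: int, width: int = 6) -> list[tuple[int, int, int, int]]:
    """Return (column, column_total, result_digit, carry_after_column) for each column."""
    # Little-endian digits: index 0 is the ones digit.
    a_digits = [(a // 10**i) % 10 for i in range(width)]
    b_digits = [(b // 10**i) % 10 for i in range(width)]
    rows = []
    carry = 0
    for column in range(width):
        # Incoming carry plus every digit product whose place values sum to this column.
        total = carry + sum(a_digits[i] * b_digits[column - i] for i in range(column + 1))
        digit, carry = total % 10, total // 10
        rows.append((column, total, digit, carry))
    return rows


for row in multiplication_columns(123, 45):
    print(row)
```

For 123 * 45 this reproduces the six table rows. Running it on 999 * 999 gives carries 8, 17, 26, 18, 9, 0, which is where the two-character carry field becomes necessary.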
That gives:
```text
sign   reversed result digits   carry traces
+      5 3 5 5 0 0              01 02 01 00 00 00

interleaved answer field
+ 5 01 3 02 5 01 5 00 0 00 0 00

compact string
+501302501500000000
```
The D positions are the answer digits. The CC positions are scaffolding. During decoding, I read only the D positions, reverse them back to normal order, and apply the sign. The carry traces are not part of the final integer answer.
For addition, subtraction, integer division, and modulo, the carry-trace digits are set to 00. That keeps every operation at the same sequence length without pretending those operations use the same multiplication scratchpad.
```text
123+45=168      ->  "123+  45=+800600100000000000"
45-123=-78      ->  " 45- 123=-800700000000000000"
123*45=5535     ->  "123*  45=+501302501500000000"
5//4=1          ->  "  5//  4=+100000000000000000"
5%4=1           ->  "  5%   4=+100000000000000000"
999*999=998001  ->  "999* 999=+108017026818909900"
```
Without the fixed-width layout, short examples like 5%4=1 and long examples like 999*999=998001 would shift important tokens into different positions. Without reversed digits, the model would have to emit the most significant result digit first, even though column arithmetic naturally discovers it last. Without carry traces, multiplication would still be learnable in principle, but the carries would have to live entirely inside hidden activations instead of being supervised as visible intermediate tokens.
The representation is therefore doing real work. It is not a neutral serialization detail, and it is not a symbolic calculator. It is a task design choice that makes the sequence model’s job more aligned with the arithmetic procedure. At inference time, the model still generates characters. The decoding step only reads the generated field; it does not call Python arithmetic to solve the prompt.
The constants are small and easy to audit:
| Quantity | Value |
|---|---|
| Vocabulary size | 17 |
| Sequence length | 28 |
| Prompt length | 9 |
| Answer length | 19 |
| Operand range | 0..999 |
| Operations | +, -, *, //, % |
Once the representation is fixed, tokenization becomes deliberately uninteresting: every input is just a sequence of known characters.
Tokenization And Formatting
No tokenizer library is needed. The vocabulary is a fixed character set:
```python
import numpy as np

VOCAB_CHARS = "0123456789 +-*%/="
char_to_id = {ch: i for i, ch in enumerate(VOCAB_CHARS)}
id_to_char = {i: ch for ch, i in char_to_id.items()}


def encode_text(text: str) -> np.ndarray:
    return np.array([char_to_id[ch] for ch in text], dtype=np.int32)
```
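The evaluation code later calls `decode_ids`, which is not shown. A plausible inverse of `encode_text` is sketched below; it is a hypothetical version that accepts any sequence of integer ids, so a NumPy row or a plain list both work:

```python
VOCAB_CHARS = "0123456789 +-*%/="
char_to_id = {ch: i for i, ch in enumerate(VOCAB_CHARS)}
id_to_char = {i: ch for ch, i in char_to_id.items()}


def decode_ids(ids) -> str:
    """Map token ids back to characters (hypothetical inverse of encode_text)."""
    # int() also accepts NumPy integer scalars, so generated rows can be passed directly.
    return "".join(id_to_char[int(i)] for i in ids)


example = "123*  45=+501302501500000000"
ids = [char_to_id[ch] for ch in example]
round_trip = decode_ids(ids)
```

Because every character in the format is in the 17-symbol vocabulary, encoding followed by decoding is lossless, including the spaces that pad the operand fields.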
The formatting code is the most important part of the experiment because it defines the task the transformer actually sees. The constants below are the shape of that task:
```python
OPS = ("+", "-", "*", "//", "%")
OPERAND_WIDTH = 3
OP_WIDTH = 2
MAGNITUDE_WIDTH = 6
CARRY_WIDTH = 2
ANSWER_CHUNK_WIDTH = 1 + CARRY_WIDTH
PROMPT_LEN = OPERAND_WIDTH + OP_WIDTH + OPERAND_WIDTH + 1
ANSWER_LEN = 1 + MAGNITUDE_WIDTH * ANSWER_CHUNK_WIDTH
SEQ_LEN = PROMPT_LEN + ANSWER_LEN
```
The core representation logic is small: compute multiplication carries, interleave each reversed result digit with a carry trace, and decode only the digit positions later.
```python
# formatting_core.py
# compute_result is defined in the full formatting code.
def multiplication_carry_trace(a: int, b: int) -> list[int]:
    # Little-endian digits: index 0 is the ones digit.
    a_digits = [(a // (10**i)) % 10 for i in range(MAGNITUDE_WIDTH)]
    b_digits = [(b // (10**i)) % 10 for i in range(MAGNITUDE_WIDTH)]
    carry = 0
    carries = []
    for column in range(MAGNITUDE_WIDTH):
        # Incoming carry plus every digit product that lands in this column.
        column_total = carry
        for i in range(column + 1):
            column_total += a_digits[i] * b_digits[column - i]
        carry = column_total // 10
        carries.append(carry)
    return carries


def encode_result_value(value: int, a: int, op: str, b: int) -> str:
    sign = "-" if value < 0 else "+"
    reversed_digits = f"{abs(value):0{MAGNITUDE_WIDTH}d}"[::-1]
    carries = multiplication_carry_trace(a, b) if op == "*" else [0] * MAGNITUDE_WIDTH
    chunks = [sign]
    for digit, carry in zip(reversed_digits, carries):
        chunks.extend([digit, f"{carry:0{CARRY_WIDTH}d}"])
    return "".join(chunks)


def format_example(a: int, op: str, b: int) -> str:
    op_field = f"{op:<{OP_WIDTH}}"
    value = compute_result(a, op, b)  # label generation, not model inference
    return f"{a:>{OPERAND_WIDTH}}{op_field}{b:>{OPERAND_WIDTH}}={encode_result_value(value, a, op, b)}"


def decode_result_field(field: str) -> int:
    sign = field[0]
    # Read only the D positions; the CC carry traces are skipped.
    digits = "".join(field[1 + i * ANSWER_CHUNK_WIDTH] for i in range(MAGNITUDE_WIDTH))
    value = int(digits[::-1])
    return -value if sign == "-" else value
```
```python
# formatting.py
import numpy as np

VOCAB_CHARS = "0123456789 +-*%/="
char_to_id = {ch: i for i, ch in enumerate(VOCAB_CHARS)}
id_to_char = {i: ch for ch, i in char_to_id.items()}

OPS = ("+", "-", "*", "//", "%")
DIVMOD_OPS = ("//", "%")
VOCAB_SIZE = len(VOCAB_CHARS)

OPERAND_WIDTH = 3
OP_WIDTH = 2
MAGNITUDE_WIDTH = 6
CARRY_WIDTH = 2
ANSWER_CHUNK_WIDTH = 1 + CARRY_WIDTH
PROMPT_LEN = OPERAND_WIDTH + OP_WIDTH + OPERAND_WIDTH + 1
ANSWER_LEN = 1 + MAGNITUDE_WIDTH * ANSWER_CHUNK_WIDTH
SEQ_LEN = PROMPT_LEN + ANSWER_LEN


def compute_result(a: int, op: str, b: int) -> int:
    if op == "+":
        return a + b
    if op == "-":
        return a - b
    if op == "*":
        return a * b
    if op == "//":
        if b == 0:
            raise ZeroDivisionError("integer division by zero")
        return a // b
    if op == "%":
        if b == 0:
            raise ZeroDivisionError("modulo by zero")
        return a % b
    raise ValueError(f"unknown operator: {op}")


def multiplication_carry_trace(a: int, b: int) -> list[int]:
    a_digits = [(a // (10**i)) % 10 for i in range(MAGNITUDE_WIDTH)]
    b_digits = [(b // (10**i)) % 10 for i in range(MAGNITUDE_WIDTH)]
    carry = 0
    carries = []
    for column in range(MAGNITUDE_WIDTH):
        column_total = carry
        for i in range(column + 1):
            j = column - i
            column_total += a_digits[i] * b_digits[j]
        carry = column_total // 10
        carries.append(carry)
    return carries


def encode_result_value(value: int, a: int, op: str, b: int) -> str:
    sign = "-" if value < 0 else "+"
    magnitude = f"{abs(value):0{MAGNITUDE_WIDTH}d}"
    result_digits = magnitude[::-1]
    carries = multiplication_carry_trace(a, b) if op == "*" else [0] * MAGNITUDE_WIDTH
    chunks = [sign]
    for digit, carry in zip(result_digits, carries):
        chunks.append(digit)
        chunks.append(f"{carry:0{CARRY_WIDTH}d}")
    return "".join(chunks)


def decode_result_field(field: str) -> int:
    sign = field[0]
    digits = "".join(
        field[1 + i * ANSWER_CHUNK_WIDTH]
        for i in range(MAGNITUDE_WIDTH)
    )[::-1]
    value = int(digits)
    return -value if sign == "-" else value


def format_example(a: int, op: str, b: int) -> str:
    op_field = f"{op:<{OP_WIDTH}}"
    result = encode_result_value(compute_result(a, op, b), a, op, b)
    return f"{a:>{OPERAND_WIDTH}}{op_field}{b:>{OPERAND_WIDTH}}={result}"


def encode_text(text: str) -> np.ndarray:
    return np.array([char_to_id[ch] for ch in text], dtype=np.int32)
```

The sanity checks catch silent format drift before training:
```python
assert VOCAB_SIZE == 17
assert PROMPT_LEN == 9
assert ANSWER_LEN == 19
assert SEQ_LEN == 28
assert len(format_example(999, "*", 999)) == SEQ_LEN
assert len(format_example(5, "//", 4)) == SEQ_LEN
assert set(format_example(0, "-", 999)).issubset(set(VOCAB_CHARS))
```
With the format locked down, dataset generation becomes a deterministic enumeration problem rather than a data-cleaning problem.
Dataset
The full generated corpus has 4,998,000 valid examples:
3 operations * 1,000 first operands * 1,000 second operands = 3,000,000
2 operations * 1,000 first operands * 999 second operands = 1,998,000
total = 4,998,000
The // and % rows skip b=0 because division and modulo by zero are undefined. I used a 100,000-example validation split, leaving 4,898,000 training rows. The saved run metadata reports about 559.78 MB for the tokenized full dataset.
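These counts follow from simple arithmetic. A small sketch, assuming the reported size counts only the int32 token array (4 bytes per token) and megabytes of 10^6 bytes:

```python
OPS = ("+", "-", "*", "//", "%")
SEQ_LEN = 28
BYTES_PER_TOKEN = 4  # int32 token ids
VAL_SIZE = 100_000


def examples_for_op(op: str) -> int:
    # // and % skip b = 0, so they have 999 right-hand operands instead of 1,000.
    rhs_count = 999 if op in ("//", "%") else 1000
    return 1000 * rhs_count


total = sum(examples_for_op(op) for op in OPS)
train_rows = total - VAL_SIZE
token_bytes = total * SEQ_LEN * BYTES_PER_TOKEN
print(total, train_rows, token_bytes / 1e6)
```

This gives 559,776,000 bytes for the tokens alone, which matches the 559.78 MB figure; the int8 `op_ids` array would add roughly another 5 MB.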
```python
import numpy as np

from formatting import DIVMOD_OPS, OPS, SEQ_LEN, encode_text, format_example

OPERAND_VALUES = range(1000)
NONZERO_OPERAND_VALUES = range(1, 1000)
TOTAL_EXAMPLES = 3 * 1000 * 1000 + 2 * 1000 * 999


def rhs_values_for_op(op: str):
    # Skip b = 0 for // and %, where the result is undefined.
    return NONZERO_OPERAND_VALUES if op in DIVMOD_OPS else OPERAND_VALUES


def build_full_dataset() -> tuple[np.ndarray, np.ndarray]:
    tokens = np.empty((TOTAL_EXAMPLES, SEQ_LEN), dtype=np.int32)
    op_ids = np.empty((TOTAL_EXAMPLES,), dtype=np.int8)
    row = 0
    for op_index, op in enumerate(OPS):
        for a in OPERAND_VALUES:
            for b in rhs_values_for_op(op):
                tokens[row] = encode_text(format_example(a, op, b))
                op_ids[row] = op_index
                row += 1
    return tokens, op_ids
```
The validation split is a reproducible permutation of that enumerated table:
```python
all_tokens, op_ids = build_full_dataset()

split_rng = np.random.default_rng(42)
indices = split_rng.permutation(len(all_tokens))

VAL_SIZE = 100_000
val_tokens = all_tokens[indices[:VAL_SIZE]]
train_tokens = all_tokens[indices[VAL_SIZE:]]
val_ops = op_ids[indices[:VAL_SIZE]]
train_ops = op_ids[indices[VAL_SIZE:]]
```

Fixed length keeps batching simple. Spaces are tokens, not padding to be ignored. For example, 5%4=1 becomes:
```text
"  5%   4=+100000000000000000"
```
That gives the model stable positions: operand A at positions 0..2, operator at 3..4, operand B at 5..7, equals at 8, and the answer at 9..27.
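Those fixed positions can be written down as slices. A hypothetical helper (the names `FIELDS`, `DIGIT_POSITIONS`, and `split_fields` are illustrative, not from the experiment) that cuts one example into its fields and lists where the six result digits sit:

```python
# Field boundaries of the fixed 28-character layout.
FIELDS = {
    "a": slice(0, 3),
    "op": slice(3, 5),
    "b": slice(5, 8),
    "equals": slice(8, 9),
    "answer": slice(9, 28),
}
# Absolute positions of the six D (result digit) tokens: prompt (9) + sign (1) + 3 * i.
DIGIT_POSITIONS = [9 + 1 + 3 * i for i in range(6)]


def split_fields(line: str) -> dict[str, str]:
    """Cut one 28-character example into its fixed-position fields."""
    if len(line) != 28:
        raise ValueError(f"expected 28 characters, got {len(line)}")
    return {name: line[s] for name, s in FIELDS.items()}


fields = split_fields("  5%   4=+100000000000000000")
```

The digit positions are identical for every operation, which is the point of the fixed layout: position 10 always holds the ones digit of the result.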
At this point the training data is just a fixed-length character prediction task. The model can therefore stay close to a standard small GPT-style decoder.
Model
The model is a compact decoder-only transformer:
- 6 transformer layers
- width 384
- 6 attention heads
- MLP width 1536
- learned positional embeddings
- causal self-attention
- 10,650,624 parameters
```python
from dataclasses import dataclass

from flax import linen as nn

# VOCAB_SIZE and SEQ_LEN come from formatting.py; TransformerBlock is in the full model code.


@dataclass
class TransformerConfig:
    vocab_size: int = VOCAB_SIZE
    max_seq_len: int = SEQ_LEN
    d_model: int = 384
    n_heads: int = 6
    n_layers: int = 6
    mlp_dim: int = 1536


class ArithmeticTransformer(nn.Module):
    config: TransformerConfig

    @nn.compact
    def __call__(self, input_ids):
        token_embed = nn.Embed(
            num_embeddings=self.config.vocab_size,
            features=self.config.d_model,
        )(input_ids)
        pos_embed = self.param(
            "position_embedding",
            nn.initializers.normal(0.02),
            (self.config.max_seq_len, self.config.d_model),
        )
        x = token_embed + pos_embed[None, : input_ids.shape[1], :]
        for _ in range(self.config.n_layers):
            x = TransformerBlock(self.config)(x)
        return nn.Dense(self.config.vocab_size, use_bias=False, name="lm_head")(nn.LayerNorm()(x))
```
That inline snippet is the model shape. The full implementation below expands the attention block, MLP block, initialization, and parameter counter.
```python
# model.py
from dataclasses import dataclass

import jax
import jax.numpy as jnp
from flax import linen as nn

from formatting import SEQ_LEN, VOCAB_SIZE


@dataclass
class TransformerConfig:
    vocab_size: int = VOCAB_SIZE
    max_seq_len: int = SEQ_LEN
    d_model: int = 384
    n_heads: int = 6
    n_layers: int = 6
    mlp_dim: int = 1536


def normal_init(stddev: float = 0.02):
    return nn.initializers.normal(stddev=stddev)


class CausalSelfAttention(nn.Module):
    d_model: int
    n_heads: int

    @nn.compact
    def __call__(self, x):
        batch, seq_len, width = x.shape
        head_dim = self.d_model // self.n_heads
        qkv = nn.Dense(
            3 * self.d_model,
            use_bias=False,
            kernel_init=normal_init(),
        )(x)
        q, k, v = jnp.split(qkv, 3, axis=-1)

        def split_heads(tensor):
            tensor = tensor.reshape(batch, seq_len, self.n_heads, head_dim)
            return tensor.transpose(0, 2, 1, 3)

        q, k, v = split_heads(q), split_heads(k), split_heads(v)
        scores = jnp.einsum("bhqd,bhkd->bhqk", q, k) * (head_dim**-0.5)
        causal_mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))[None, None]
        scores = jnp.where(causal_mask, scores, -1e10)
        weights = nn.softmax(scores, axis=-1)
        attended = jnp.einsum("bhqk,bhkd->bhqd", weights, v)
        attended = attended.transpose(0, 2, 1, 3).reshape(batch, seq_len, self.d_model)
        return nn.Dense(
            self.d_model,
            use_bias=False,
            kernel_init=normal_init(),
        )(attended)


class TransformerBlock(nn.Module):
    config: TransformerConfig

    @nn.compact
    def __call__(self, x):
        x = x + CausalSelfAttention(self.config.d_model, self.config.n_heads)(
            nn.LayerNorm()(x)
        )
        y = nn.LayerNorm()(x)
        y = nn.Dense(
            self.config.mlp_dim,
            use_bias=False,
            kernel_init=normal_init(),
        )(y)
        y = nn.gelu(y, approximate=True)
        y = nn.Dense(
            self.config.d_model,
            use_bias=False,
            kernel_init=normal_init(),
        )(y)
        return x + y


class ArithmeticTransformer(nn.Module):
    config: TransformerConfig

    @nn.compact
    def __call__(self, input_ids):
        seq_len = input_ids.shape[1]
        token_embed = nn.Embed(
            num_embeddings=self.config.vocab_size,
            features=self.config.d_model,
            embedding_init=normal_init(),
            name="token_embedding",
        )(input_ids)
        pos_embed = self.param(
            "position_embedding",
            normal_init(),
            (self.config.max_seq_len, self.config.d_model),
        )
        x = token_embed + pos_embed[None, :seq_len, :]
        for _ in range(self.config.n_layers):
            x = TransformerBlock(self.config)(x)
        x = nn.LayerNorm()(x)
        return nn.Dense(
            self.config.vocab_size,
            use_bias=False,
            kernel_init=normal_init(),
            name="lm_head",
        )(x)


def count_params(params) -> int:
    return int(sum(leaf.size for leaf in jax.tree_util.tree_leaves(params)))
```

This is intentionally small enough to inspect and run on a single accelerator, while still using the same building blocks as a language model training loop.
The architecture is ordinary on purpose. The experiment is mostly about the representation, the mask, and the evaluation protocol, not architectural novelty.
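The 10,650,624 figure can be checked by hand. A small sketch, assuming each Flax `LayerNorm` carries a scale and a bias of width `d_model` (its default) and every `Dense` layer is bias-free, as in the full model code:

```python
def transformer_param_count(
    vocab_size: int = 17,
    seq_len: int = 28,
    d_model: int = 384,
    n_layers: int = 6,
    mlp_dim: int = 1536,
) -> int:
    """Analytic parameter count for the bias-free model defined in model.py."""
    layer_norm = 2 * d_model  # scale + bias
    attention = d_model * 3 * d_model + d_model * d_model  # fused QKV + output projection
    mlp = d_model * mlp_dim + mlp_dim * d_model  # up and down projections
    per_block = 2 * layer_norm + attention + mlp
    embeddings = vocab_size * d_model + seq_len * d_model  # token + learned position tables
    final = layer_norm + d_model * vocab_size  # final LayerNorm + lm_head
    return embeddings + n_layers * per_block + final


print(f"{transformer_param_count():,}")
```

About 99.8% of the parameters sit in the six blocks, and within each block the MLP (1,179,648 weights) is twice the size of the attention projections (589,824).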
Training
The optimizer is AdamW with warmup cosine decay, weight decay, and global-norm gradient clipping. The model is trained as a next-character predictor, but the supervised loss is applied only to the answer positions. The prompt is context.
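The shift-by-one bookkeeping is easy to get wrong. This plain-Python sketch mirrors the mask logic in `make_batch` (lists instead of NumPy arrays): the model sees 27 input positions, and only the 19 labels that fall on answer tokens carry loss.

```python
SEQ_LEN = 28
PROMPT_LEN = 9

example = "123*  45=+501302501500000000"
inputs, labels = example[:-1], example[1:]  # labels[j] is the token at position j + 1

label_positions = list(range(1, SEQ_LEN))
loss_mask = [1.0 if pos >= PROMPT_LEN else 0.0 for pos in label_positions]
supervised = [j for j, m in enumerate(loss_mask) if m]

# The first supervised prediction is the sign, made while the model sits on "=".
print(len(loss_mask), sum(loss_mask), inputs[supervised[0]], labels[supervised[0]])
```

The carry-trace positions are inside the masked region, so they are supervised exactly like the result digits even though decoding later ignores them.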
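```python
# training.py
import jax
import jax.numpy as jnp
import numpy as np
import optax
from flax.training import train_state

from formatting import PROMPT_LEN, SEQ_LEN
from model import ArithmeticTransformer, TransformerConfig

BATCH_SIZE = 2048
LEARNING_RATE = 2e-4
WEIGHT_DECAY = 1e-4
WARMUP_STEPS = 50
GRAD_CLIP = 1.0

# labels[:, j] is the token at sequence position j + 1, so this keeps the 19 answer positions.
label_positions = np.arange(1, SEQ_LEN)
loss_mask_template = (label_positions >= PROMPT_LEN).astype(np.float32)


def make_batch(tokens: np.ndarray, rows: np.ndarray) -> dict[str, np.ndarray]:
    seq = tokens[rows]
    return {
        "inputs": seq[:, :-1],  # positions 0..26
        "labels": seq[:, 1:],  # positions 1..27
        "loss_mask": np.broadcast_to(
            loss_mask_template,
            (len(rows), SEQ_LEN - 1),
        ).copy(),
    }


# In optax, decay_steps is the total schedule length, warmup included.
schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=LEARNING_RATE,
    warmup_steps=WARMUP_STEPS,
    decay_steps=1500,
    end_value=LEARNING_RATE * 0.1,
)
tx = optax.chain(
    optax.clip_by_global_norm(GRAD_CLIP),
    optax.adamw(schedule, weight_decay=WEIGHT_DECAY),
)

# Assumed setup: `model` and `params` are created elsewhere in the original script.
model = ArithmeticTransformer(TransformerConfig())
params = model.init(jax.random.PRNGKey(0), jnp.zeros((1, SEQ_LEN - 1), dtype=jnp.int32))["params"]

state = train_state.TrainState.create(
    apply_fn=model.apply,
    params=params,
    tx=tx,
)
```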
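To see what the schedule does step by step, here is a pure-Python sketch of the same warmup-then-cosine shape. It assumes optax's documented behavior (linear warmup from `init_value` to the peak, then cosine decay to `end_value` over `decay_steps - warmup_steps` steps, then constant); it is an illustration, not a replacement for `optax.warmup_cosine_decay_schedule`:

```python
import math

LEARNING_RATE = 2e-4
WARMUP_STEPS = 50
DECAY_STEPS = 1500  # total schedule length, warmup included
END_VALUE = LEARNING_RATE * 0.1


def lr_at(step: int) -> float:
    """Linear warmup from 0 to the peak, then cosine decay to END_VALUE."""
    if step < WARMUP_STEPS:
        return LEARNING_RATE * step / WARMUP_STEPS
    # Clamp so the learning rate stays at END_VALUE after the schedule ends.
    t = min(step - WARMUP_STEPS, DECAY_STEPS - WARMUP_STEPS)
    cosine = 0.5 * (1.0 + math.cos(math.pi * t / (DECAY_STEPS - WARMUP_STEPS)))
    return END_VALUE + (LEARNING_RATE - END_VALUE) * cosine


for step in (0, 25, 50, 775, 1500, 2000):
    print(step, lr_at(step))
```

With these settings the rate peaks at 2e-4 at step 50, is 1.1e-4 halfway through the decay (step 775), and stays at 2e-5 from step 1500 onward.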
Aggregate accuracy is useful, but per-operation accuracy is more informative here. Addition and subtraction became reliable earlier than multiplication, division, and modulo. The final run used operation-aware sampling:
- `+`: 5%
- `-`: 10%
- `*`: 20%
- `//`: 30%
- `%`: 35%
That does not change the generated dataset. It changes how often each operation is sampled into a batch.
```python
import numpy as np

# OPS comes from formatting.py and make_batch from the training code above.
# Sampling weights, in OPS order: +, -, *, //, %.
OP_PROBS = np.array([0.05, 0.10, 0.20, 0.30, 0.35], dtype=np.float64)


def make_weighted_op_batch(
    rng: np.random.Generator,
    split_tokens: np.ndarray,
    indices_by_op: list[np.ndarray],
    batch_size: int,
) -> dict[str, np.ndarray]:
    # Pick an operation for every batch row, then a random row of that operation.
    chosen_ops = rng.choice(len(OPS), size=batch_size, p=OP_PROBS)
    rows = np.empty(batch_size, dtype=np.int64)
    for op_index in range(len(OPS)):
        mask = chosen_ops == op_index
        count = int(mask.sum())
        if count:
            rows[mask] = rng.choice(indices_by_op[op_index], size=count, replace=True)
    return make_batch(split_tokens, rows)
```
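Two details are left implicit above: how `indices_by_op` is built, and how strongly the weights skew the batches relative to the table. The sketch below uses hypothetical helper names; with NumPy, `np.flatnonzero(train_ops == i)` per operation would produce the same grouping as `group_rows_by_op`:

```python
OPS = ("+", "-", "*", "//", "%")
OP_PROBS = (0.05, 0.10, 0.20, 0.30, 0.35)
BATCH_SIZE = 2048


def group_rows_by_op(op_ids) -> list[list[int]]:
    """Row indices for each operation id, the shape make_weighted_op_batch expects."""
    groups = [[] for _ in OPS]
    for row, op_index in enumerate(op_ids):
        groups[int(op_index)].append(row)
    return groups


def sampling_summary(batch_size: int = BATCH_SIZE) -> dict[str, tuple[float, float, float]]:
    """Per operation: natural share of the table, expected rows per batch, oversampling factor."""
    # // and % skip b = 0, so they have 999 right-hand operands instead of 1,000.
    rows_per_op = {op: 1000 * (999 if op in ("//", "%") else 1000) for op in OPS}
    total_rows = sum(rows_per_op.values())
    summary = {}
    for op, p in zip(OPS, OP_PROBS):
        natural = rows_per_op[op] / total_rows
        summary[op] = (natural, p * batch_size, p / natural)
    return summary


for op, (natural, per_batch, factor) in sampling_summary().items():
    print(f"{op:>2}  natural {natural:.2%}  ~{per_batch:.1f} rows/batch  x{factor:.2f}")
```

Under these weights a modulo example is drawn about 1.75 times as often as uniform sampling over the table would draw it, while an addition example is drawn about a quarter as often.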
The training step is standard next-token cross-entropy, averaged only over the masked answer positions:
```python
import jax
import jax.numpy as jnp
import optax


@jax.jit
def train_step(state, batch):
    def loss_fn(params):
        logits = state.apply_fn({"params": params}, batch["inputs"])
        per_token_loss = optax.softmax_cross_entropy_with_integer_labels(
            logits,
            batch["labels"],
        )
        mask = batch["loss_mask"]
        loss = (per_token_loss * mask).sum() / mask.sum()
        predictions = jnp.argmax(logits, axis=-1)
        token_accuracy = ((predictions == batch["labels"]) * mask).sum() / mask.sum()
        return loss, token_accuracy

    (loss, token_accuracy), grads = jax.value_and_grad(loss_fn, has_aux=True)(
        state.params
    )
    state = state.apply_gradients(grads=grads)
    return state, {"loss": loss, "token_accuracy": token_accuracy}
```
I also used a short boundary-case fine-tune that mixed a small pool of examples such as 999*999, 999//1, and 998%999 into the batches. This was a practical response to visible boundary failures, not a claim that the model had generalized outside the synthetic setup.
```shell
python scripts/train_arithmetic_model.py \
  --run-name gpu2_b2048_edge_finetune \
  --init-params runs/gpu3_b2048_divmod5000/arithmetic_transformer_params.msgpack \
  --edge-case-prob 0.25 \
  --op-probs 0.05,0.10,0.20,0.30,0.35 \
  --learning-rate 0.0002 \
  --warmup-steps 50 \
  --max-steps 1500 \
  --min-steps 500 \
  --eval-every 250 \
  --final-exact-n 20000
```
After training, token accuracy is not the metric I care about most. The useful check is whether the decoded integer is exactly right for each operation.
Evaluation
Generation uses greedy decoding. Each prompt has one intended answer, so sampling is unnecessary for this evaluation.
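```python
# inference.py
import jax.numpy as jnp
import numpy as np

from formatting import OP_WIDTH, OPERAND_WIDTH, PROMPT_LEN, SEQ_LEN, char_to_id, encode_text

# `model` is the ArithmeticTransformer instance used for training.


def make_prompt_tokens(problems):
    # Answer positions start as spaces and are overwritten during generation.
    tokens = np.full((len(problems), SEQ_LEN), char_to_id[" "], dtype=np.int32)
    for i, (a, op, b) in enumerate(problems):
        prompt = f"{a:>{OPERAND_WIDTH}}{op:<{OP_WIDTH}}{b:>{OPERAND_WIDTH}}="
        tokens[i, :PROMPT_LEN] = encode_text(prompt)
    return tokens


def generate(params, problems):
    tokens = jnp.asarray(make_prompt_tokens(problems), dtype=jnp.int32)
    for pos in range(PROMPT_LEN, SEQ_LEN):
        # Greedy decoding: run the prefix, take the argmax at its last position.
        logits = model.apply({"params": params}, tokens[:, :pos])
        next_ids = jnp.argmax(logits[:, -1, :], axis=-1).astype(jnp.int32)
        tokens = tokens.at[:, pos].set(next_ids)
    return np.asarray(tokens)
```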
Exact accuracy compares decoded integers, not raw strings. This matters because the model format includes sign, reversed digits, and carry traces.
```python
def exact_answer_accuracy(params, problems):
    generated = generate(params, problems)
    correct = 0
    by_op = {op: {"correct": 0, "total": 0} for op in OPS}
    for row, (a, op, b) in zip(generated, problems):
        model_text = decode_ids(row)
        # Assumed to return the decoded int, or None if the generated field is malformed.
        prediction = safe_decode_result_field(model_text.split("=")[1])
        expected = compute_result(a, op, b)  # reference label only, never fed to the model
        is_correct = prediction == expected
        correct += int(is_correct)
        by_op[op]["correct"] += int(is_correct)
        by_op[op]["total"] += 1
    return correct, by_op
```
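`safe_decode_result_field` is not shown either. A hypothetical version is sketched below: it treats a malformed answer field (wrong length, bad sign, or a non-digit in a D position) as a wrong answer by returning `None` instead of raising. A stray `=` generated inside the answer would also shorten `split("=")[1]` and fail the length check.

```python
from typing import Optional

MAGNITUDE_WIDTH = 6
ANSWER_CHUNK_WIDTH = 3  # one result digit + two carry-trace digits
ANSWER_LEN = 1 + MAGNITUDE_WIDTH * ANSWER_CHUNK_WIDTH  # 19


def safe_decode_result_field(field: str) -> Optional[int]:
    """Decode a generated answer field, or return None if it is malformed."""
    if len(field) != ANSWER_LEN or field[0] not in "+-":
        return None
    # Only the D positions matter; carry-trace positions are ignored.
    digits = [field[1 + i * ANSWER_CHUNK_WIDTH] for i in range(MAGNITUDE_WIDTH)]
    if not all(d in "0123456789" for d in digits):
        return None
    value = int("".join(reversed(digits)))
    return -value if field[0] == "-" else value
```

Because garbage in the carry-trace positions does not affect the decoded integer, exact accuracy can be slightly higher than a strict full-string match would be.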
The local inference script can be run against the saved Flax parameters:
```shell
python scripts/test_arithmetic_model.py \
  '123+45' '45-123' '123*45' '5//4' '5%4' '999*999' \
  --check-random 1000
```
Quote expressions like '123*45' in the shell so * is not expanded as a glob.
Results
Final exact accuracy on 20,000 random validation prompts was:
19,858 / 20,000 = 99.29%
The per-operation split is more useful than the aggregate:
| Operation | Accuracy |
|---|---|
| `+` | 100.00% |
| `-` | 100.00% |
| `*` | 100.00% |
| `//` | 99.48% |
| `%` | 96.90% |
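The percentages can be recomputed from the per-operation counts in the headline table. A short sketch that also shows where the remaining errors are concentrated:

```python
# (correct, total) per operation, copied from the headline results table.
COUNTS = {
    "+": (4022, 4022),
    "-": (3965, 3965),
    "*": (4084, 4084),
    "//": (4007, 4028),
    "%": (3780, 3901),
}


def accuracy_report(counts: dict[str, tuple[int, int]]):
    correct = sum(c for c, _ in counts.values())
    total = sum(t for _, t in counts.values())
    per_op = {op: 100.0 * c / t for op, (c, t) in counts.items()}
    errors = {op: t - c for op, (c, t) in counts.items() if t > c}
    return correct, total, 100.0 * correct / total, per_op, errors


correct, total, overall, per_op, errors = accuracy_report(COUNTS)
print(f"{correct:,} / {total:,} = {overall:.2f}%")
for op, pct in per_op.items():
    print(f"{op:>2} {pct:6.2f}%")
print("wrong answers by operation:", errors)
```

Of the 142 wrong answers, 121 come from modulo and 21 from integer division, so modulo accounts for about 85% of the remaining errors.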
Sample predictions from the final run:
| Problem | Prediction | Expected |
|---|---|---|
| `0+0` | 0 | 0 |
| `7+35` | 42 | 42 |
| `45-123` | -78 | -78 |
| `12*89` | 1068 | 1068 |
| `123*45` | 5535 | 5535 |
| `5//4` | 1 | 1 |
| `999//1` | 999 | 999 |
| `5%4` | 1 | 1 |
| `998%999` | 998 | 998 |
| `501*499` | 249999 | 249999 |
| `999*999` | 998001 | 998001 |
This is not a state-of-the-art result. It is a controlled result for one bounded distribution, one representation, and one training recipe.
Limitations
The operand range is bounded to 0..999, and I make no claim about extrapolation outside that range.
The train and validation examples come from the same generator and the same distribution. A high validation score here does not imply robustness under a changed distribution.
The representation deliberately helps the model, and I have not measured by how much. Reversed result digits align the generated sequence with low-to-high column arithmetic, and carry traces give multiplication explicit scaffolding.
The system is not a symbolic calculator. The inference script computes expected answers only for evaluation. The model itself generates a character sequence from learned parameters.
Modulo was the weakest operation in the final run, but the result only says that modulo produced the most remaining errors under this setup.
What I Would Test Next
- Remove multiplication carry traces and measure the accuracy drop.
- Compare normal answer order against reversed answer order.
- Train on operands from `0..99` and test on `100..999`.
- Analyze modulo errors by divisor size, quotient boundary, and remainder range.
- Compare model sizes under the same representation and evaluation protocol.
- Compare uniform sampling, operation-aware sampling, and targeted boundary-case sampling.
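The first two ablations only change the answer field, so they could share one generator. Below is a hypothetical sketch of an answer-field builder with switches for carry traces and digit order. When digits are not reversed, the carry trace is reversed too, so each carry still sits beside the digit of its own column:

```python
MAGNITUDE_WIDTH = 6


def carry_trace(a: int, b: int) -> list[int]:
    """Carry left after each base-10 column of a * b, lowest column first."""
    a_digits = [(a // 10**i) % 10 for i in range(MAGNITUDE_WIDTH)]
    b_digits = [(b // 10**i) % 10 for i in range(MAGNITUDE_WIDTH)]
    carries, carry = [], 0
    for column in range(MAGNITUDE_WIDTH):
        total = carry + sum(a_digits[i] * b_digits[column - i] for i in range(column + 1))
        carry = total // 10
        carries.append(carry)
    return carries


def answer_field(value: int, a: int, op: str, b: int, *, reverse: bool = True, carries: bool = True) -> str:
    """Answer field for the baseline format (defaults) and two ablation variants."""
    sign = "-" if value < 0 else "+"
    digits = f"{abs(value):0{MAGNITUDE_WIDTH}d}"  # most significant digit first
    trace = carry_trace(a, b) if op == "*" else [0] * MAGNITUDE_WIDTH  # lowest column first
    if reverse:
        digits = digits[::-1]
    else:
        trace = trace[::-1]  # keep each carry beside its own column's digit
    if not carries:
        return sign + digits
    return sign + "".join(f"{d}{c:02d}" for d, c in zip(digits, trace))
```

Dropping the carry traces shrinks the answer field from 19 to 7 characters and the sequence from 28 to 16 tokens, so a carry-trace ablation also changes sequence length and the number of supervised positions; that confound would need to be noted when comparing runs.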
Conclusion
This setup does not prove symbolic arithmetic. It gives a controlled environment where the full training stack is visible: generated data, representation design, character tokenization, answer masking, transformer definition, optimizer setup, exact-answer evaluation, model serialization, and local inference.
Synthetic tasks are useful when they are treated as small laboratories: narrow, auditable, and honest about what was measured.