What Actually Speeds Up Transformer Inference?
Profiling and optimizing a small autoregressive transformer with JAX, KV caching, batching, graph compilation, and low-bit inference.
I recently trained a small 10.65M-parameter decoder-only transformer on 3-digit arithmetic. The goal was not to build a useful calculator. It was to create a fully visible transformer training stack: data generation, tokenization, masking, optimization, evaluation, serialization, and local inference.
Once the model was working, I became more interested in a different question:
What actually speeds up autoregressive transformer inference?
That sounds simple, but inference optimization quickly becomes messy. There are Python overheads, graph compilation overheads, synchronization costs, prefix recomputation, batching tradeoffs, correctness pitfalls, and hardware-dependent behavior.
Large language models make these effects difficult to isolate because too many things change simultaneously: model size, serving framework, distributed execution, memory pressure, quantization strategy, scheduling, and hardware topology. So instead of starting with a frontier-scale model, I used the arithmetic transformer as a controlled laboratory.
The model is small enough that implementation details become visible. Python loops matter. Compilation overhead matters. Kernel dispatch and synchronization matter. Some large-model optimizations are almost invisible at this scale. Others reduce latency while changing outputs, so I treat them as failed optimizations rather than speedups.
This article is about what happened when I benchmarked and optimized that model. It is not a serving benchmark for production LLMs. It is a small, auditable inference experiment.
Here is the headline CPU result from the final cumulative stack:
| Batch | Eager tokens/s | Final tokens/s | Same-batch speedup | Eager latency | Final latency | Final compile |
|---|---|---|---|---|---|---|
| 1 | 31.4 | 509.3 | 16.23x | 0.606s | 0.037s | 7.96s |
| 8 | 222.2 | 2,252.3 | 10.14x | 0.684s | 0.067s | 8.10s |
| 32 | 594.8 | 4,076.9 | 6.85x | 1.022s | 0.149s | 8.07s |
| 128 | 624.7 | 6,225.1 | 9.97x | 3.893s | 0.391s | 8.49s |
The final stack was JIT compilation, a manual KV cache, parallel prompt prefill, fixed-length unrolled decoding, default XLA fusion, and batching. All final-stack rows matched the eager model’s generated text on the cached 3,277-prompt correctness set. Ground-truth exact accuracy stayed at 3,249 / 3,277 = 99.15% because the trained model itself still makes some arithmetic errors.
That distinction matters: the optimization preserved model behavior. It did not make the model more accurate.
Baseline: Naive Autoregressive Inference
The original inference loop was intentionally simple. At every generated token:
- Run the transformer on the entire prefix.
- Take the final logits.
- Greedily decode the next token.
- Append it to the sequence.
- Repeat.
Conceptually:
    for position in range(prompt_len, seq_len):
        logits = model(prefix[:, :position])
        next_token = argmax(logits[:, -1])
        prefix[:, position] = next_token
This is correct, readable, and easy to debug. It is also wasteful.
At every decoding step, the model recomputes attention over the entire prefix, even though most of that prefix was already processed during previous steps. This is the classic inefficiency of naive autoregressive decoding.
For this arithmetic model the sequence is only 28 tokens long, so the inefficiency is bounded. But it is still measurable, and the short sequence makes it easier to see which optimizations target real work and which mostly move overhead around.
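The conceptual loop can be made runnable with a stand-in model. This is a minimal sketch, not the benchmark code: `toy_model`, `embed`, `proj`, and the hidden size of 32 are hypothetical placeholders, and only the loop structure and the 28/9/17 sizes follow the setup described in this article. Because JAX arrays are immutable, the in-place assignment becomes an `.at[...].set(...)` update.

```python
import jax
import jax.numpy as jnp

VOCAB, SEQ_LEN, PROMPT_LEN, HIDDEN = 17, 28, 9, 32  # HIDDEN is illustrative

key_e, key_o, key_p = jax.random.split(jax.random.PRNGKey(0), 3)
embed = jax.random.normal(key_e, (VOCAB, HIDDEN))
proj = jax.random.normal(key_o, (HIDDEN, VOCAB))

def toy_model(tokens):
    """Hypothetical stand-in: causal running mean of embeddings -> logits."""
    x = embed[tokens]                                    # (batch, length, hidden)
    counts = jnp.arange(1, tokens.shape[1] + 1)[None, :, None]
    causal_mean = jnp.cumsum(x, axis=1) / counts         # position t sees tokens <= t
    return causal_mean @ proj                            # (batch, length, vocab)

def greedy_decode(prompt):
    batch = prompt.shape[0]
    prefix = jnp.zeros((batch, SEQ_LEN), dtype=jnp.int32)
    prefix = prefix.at[:, :PROMPT_LEN].set(prompt)
    for position in range(PROMPT_LEN, SEQ_LEN):
        logits = toy_model(prefix[:, :position])         # reruns the whole prefix
        next_token = jnp.argmax(logits[:, -1], axis=-1)  # greedy choice
        prefix = prefix.at[:, position].set(next_token)  # functional update
    return prefix

prompt = jax.random.randint(key_p, (2, PROMPT_LEN), 0, VOCAB)
generated = greedy_decode(prompt)
```

Each of the 19 iterations slices a prefix of a different length, which is exactly the shape churn and repeated work that the later sections target.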
Experimental Setup
The model is the same small decoder-only transformer from the arithmetic article:
| Component | Value |
|---|---|
| Parameters | 10.65M |
| Layers | 6 |
| Hidden size | 384 |
| Attention heads | 6 |
| MLP dimension | 1536 |
| Sequence length | 28 |
| Vocabulary size | 17 |
The main benchmark suite compared:
- eager inference,
- JIT-compiled token-step inference,
- full-generation graph compilation,
- KV-cached decoding,
- prompt prefill plus cached decoding,
- fixed-length unrolled decoding,
- dtype variants,
- simple low-bit weight-only quantization experiments,
- batching behavior,
- graph-fusion settings,
- and dot-product-attention rewrites.
The primary final-stack run used:
| Setting | Value |
|---|---|
| Backend | CPU |
| CPU | Apple M2 |
| JAX | 0.6.2 |
| Prompt count | 10,000 |
| Warmup iterations | 3 |
| Measured iterations | 10 |
| Batch sizes | 1, 8, 32, 128 |
| Dtype | float32 |
| Correctness examples | 3,277 |
| Generated tokens | 19 answer tokens per prompt |
Every variant used identical prompt buffers, saved structured benchmark artifacts, and separated compile time from steady-state timing when compilation was applicable.
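The warmup/measure split can be sketched as a small timing helper. This is an illustrative harness, not the one used for the numbers in this article: the name `benchmark`, its `sync` parameter, and the choice to report the first call separately are assumptions. The first-call time is only a proxy for compile cost, because it also includes one execution.

```python
import time
import statistics

def benchmark(fn, *args, warmup=3, iters=10, sync=lambda out: out):
    """Time fn(*args); returns (first_call_seconds, steady_mean_seconds)."""
    start = time.perf_counter()
    sync(fn(*args))                    # first call: includes any compilation
    first_call = time.perf_counter() - start

    for _ in range(warmup):            # untimed warmup iterations
        sync(fn(*args))

    samples = []
    for _ in range(iters):             # measured steady-state iterations
        start = time.perf_counter()
        sync(fn(*args))
        samples.append(time.perf_counter() - start)
    return first_call, statistics.mean(samples)
```

With JAX, passing `sync=jax.block_until_ready` matters: dispatch is asynchronous, so without a synchronization point the timer can stop before the computation has finished.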
The benchmark harness also verified generated text against a reference path and checked decoded arithmetic answers against Python arithmetic. That dual check is important because this model is not perfectly accurate. A speed result is only meaningful if it preserves the model’s behavior, and preferably if it also preserves exact-answer accuracy.
Why Small Models Are Interesting
Large LLM inference is usually dominated by matrix multiplication throughput, memory bandwidth, KV-cache memory pressure, and distributed scheduling. A 10M-parameter transformer behaves differently.
With this model:
- Python overhead is visible.
- Shape-specialized compilation is visible.
- Dispatch and synchronization overhead are visible.
- Batch-size effects show up quickly.
- Some graph-level rewrites barely matter.
That makes the setup useful as a microscope. The conclusions do not automatically generalize to frontier-scale LLM serving, but the bottlenecks are easier to isolate.
JIT Compilation
The first optimization was straightforward: compile more of the generation path with JAX.
JIT means just-in-time compilation. In JAX, jax.jit traces a Python function for concrete array shapes, lowers it to an XLA computation graph, compiles that graph, and reuses the compiled version on later calls. It can be faster because repeated Python control flow, small array operations, and dispatch overhead get replaced by a smaller number of compiled graph executions.
The price is that the first call pays compilation cost, and new input shapes can trigger new compilations.
In this article:
- eager means the generation loop runs directly without JIT compilation around the decoding path.
- jit_step means the per-token decoding step is compiled, but Python still drives the autoregressive loop.
- jit_full means the fixed-shape generation function is compiled as a larger graph.
The eager baseline spends time in Python loop logic, array slicing, dispatch, synchronization, and repeated graph execution. JIT compilation can remove a large fraction of that steady-state overhead after the initial compile.
In the full benchmark run, batch-1 float32 throughput improved from 31.37 generated tokens/sec in eager mode to:
| Variant | Tokens/sec | Batch latency | Compile time |
|---|---|---|---|
| eager | 31.4 | 0.606s | n/a |
| jit_step | 230.1 | 0.083s | 0.416s |
| jit_full | 215.7 | 0.088s | 0.458s |
For batch size 1, JIT was clearly useful. But the same idea did not help in every setting. At batch size 128, jit_step and jit_full were roughly 22-24% slower than eager in the final run:
| Variant | Tokens/sec | Same-batch ratio vs eager |
|---|---|---|
| eager | 624.7 | 1.00x |
| jit_step | 477.1 | 0.76x |
| jit_full | 486.8 | 0.78x |
That does not mean JIT is bad. It means the graph you compile, the shapes you specialize, the amount of Python still left in the loop, and the batch size all matter.
The diagnostics also made compile cost visible. In one first-vs-steady diagnostic, batch-1 kv_cache took 2.57s on the first run and then averaged 0.097s steady-state. Shape changes also triggered separate compile work: prefix lengths 9, 10, and 11 each took about one second on first use, then repeated shapes dropped to milliseconds.
That distinction matters in real systems. Faster steady-state execution does not necessarily mean lower user-visible latency for short-lived processes, serverless workloads, or shape-heavy traffic.
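Shape-triggered recompilation is easy to observe directly. The sketch below is a toy, assuming nothing about the benchmark code: `step` is a hypothetical function, and the list append is a Python side effect that only runs while `jax.jit` traces the function, so it records one entry per distinct input shape.

```python
import jax
import jax.numpy as jnp

traced_shapes = []

@jax.jit
def step(prefix):
    traced_shapes.append(prefix.shape)   # Python side effect: runs only while tracing
    return jnp.tanh(prefix.astype(jnp.float32)).sum(axis=-1)

# Prefix lengths 9 and 11 are repeated; only the first use of each shape traces.
for length in (9, 9, 10, 11, 11):
    step(jnp.zeros((1, length), dtype=jnp.int32)).block_until_ready()

print(traced_shapes)  # [(1, 9), (1, 10), (1, 11)]
```

Five calls produce three traces. In a decoding loop that grows the prefix by one token per step, every step is a new shape, which is one reason fixed-shape buffers matter so much for compiled generation.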
KV Caching
The clearest structural optimization was KV caching.
In self-attention, each token is projected into three vectors: a query (Q), a key (K), and a value (V). The current token’s query is compared with previous keys to decide where to attend, then it reads information from the corresponding values.
During autoregressive decoding, old tokens do not change. Their keys and values therefore do not need to be recomputed every time a new token is generated. A KV cache stores those previous key/value tensors layer by layer and appends only the new token’s key/value at each step.
Without KV caching, decoding does this:
token 1 -> compute prefix length 9
token 2 -> recompute prefix length 10
token 3 -> recompute prefix length 11
...
With KV caching:
prompt tokens -> compute keys and values once
new token -> update cache and attend to cached keys/values
next token -> update cache and attend again
...
The model still computes the same kind of attention. The difference is that old keys and values are reused instead of recomputed from the full prefix every step.
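The equivalence between the two paths can be checked on a toy single-head, single-layer attention. This is a sketch under stated assumptions: the weights `Wq`, `Wk`, `Wv`, the width `D = 16`, and the random `x` standing in for per-token hidden states are all hypothetical; only the prompt length 9 and sequence length 28 come from this article.

```python
import jax
import jax.numpy as jnp

D, PROMPT_LEN, SEQ_LEN = 16, 9, 28   # D is illustrative; lengths follow the article
kq, kk, kv, kx = jax.random.split(jax.random.PRNGKey(0), 4)
Wq, Wk, Wv = (jax.random.normal(k, (D, D)) / jnp.sqrt(D) for k in (kq, kk, kv))
x = jax.random.normal(kx, (SEQ_LEN, D))   # stand-in for per-token hidden states

def attend(q, K, V):
    """One query row attending over all rows of K and V."""
    weights = jax.nn.softmax(q @ K.T / jnp.sqrt(D))
    return weights @ V

def full_prefix_step(t):
    """Naive path: recompute K and V for the whole prefix x[:t+1]."""
    prefix = x[: t + 1]
    return attend(prefix[-1] @ Wq, prefix @ Wk, prefix @ Wv)

# Cached path: fill the cache from the prompt once, then append one row per step.
K_cache = jnp.zeros((SEQ_LEN, D)).at[:PROMPT_LEN].set(x[:PROMPT_LEN] @ Wk)
V_cache = jnp.zeros((SEQ_LEN, D)).at[:PROMPT_LEN].set(x[:PROMPT_LEN] @ Wv)

max_diff = 0.0
for t in range(PROMPT_LEN, SEQ_LEN):
    K_cache = K_cache.at[t].set(x[t] @ Wk)      # only the new token's key
    V_cache = V_cache.at[t].set(x[t] @ Wv)      # only the new token's value
    cached = attend(x[t] @ Wq, K_cache[: t + 1], V_cache[: t + 1])
    max_diff = max(max_diff, float(jnp.max(jnp.abs(cached - full_prefix_step(t)))))
```

`max_diff` stays at floating-point noise level: the cached path does one row of projection work per step, while the naive path redoes the projections for every earlier position. The `[: t + 1]` slice keeps the sketch short; a fixed-shape compiled version would typically attend over the full buffer with a mask instead.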
For the article-grade CPU run (a separate run from the final-stack table, so absolute throughput differs), float32 KV caching produced the strongest single-path speedup among the initial variants:
| Batch | Eager tokens/s | KV-cache tokens/s | Speedup |
|---|---|---|---|
| 1 | 11.3 | 218.1 | 19.28x |
| 128 | 141.1 | 1,628.6 | 11.54x |
The final cumulative stack went further by adding prompt prefill and unrolled fixed-length decoding.
Prompt prefill means processing the known prompt tokens in one parallel causal forward pass, then starting token-by-token generation from the filled KV cache. In this experiment the prompt length is fixed at 9, so prefill avoids running the incremental decoder nine separate times before the first generated answer token.
Unrolled fixed-length decoding uses another property of this task: every answer sequence has exactly 19 generated tokens. Instead of compiling a loop that repeatedly checks loop state, the benchmark traces those 19 decode steps as a fixed graph. That gives XLA a more static computation to optimize, but it also makes compilation much more expensive.
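The difference between a compiled loop and an unrolled one can be sketched with a toy step function. This is illustrative only: `decode_step` is a hypothetical stand-in for one decode step, and using `lax.fori_loop` for the looped variant is an assumption, not a description of the benchmark code.

```python
import jax
import jax.numpy as jnp
from jax import lax

NUM_STEPS = 19  # fixed answer length from the article

def decode_step(i, state):
    """Hypothetical stand-in for one decode step."""
    return jnp.tanh(state * 1.01 + i)

@jax.jit
def decode_loop(state):
    # Keeps a loop construct in the compiled graph.
    return lax.fori_loop(0, NUM_STEPS, decode_step, state)

@jax.jit
def decode_unrolled(state):
    # A Python loop with a static trip count is unrolled while tracing,
    # so the compiled graph contains NUM_STEPS copies of the step.
    for i in range(NUM_STEPS):
        state = decode_step(i, state)
    return state

state0 = jnp.linspace(-1.0, 1.0, 8)
looped = decode_loop(state0)
unrolled = decode_unrolled(state0)
```

Both functions compute the same result. The unrolled graph is roughly 19 times larger, which is consistent with the much higher compile times reported for the unrolled variant.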
In the final cumulative run, the incremental gains looked like this:
| Batch | Variant | Tokens/sec | Batch latency | Compile time |
|---|---|---|---|---|
| 1 | eager full-prefix | 31.4 | 0.606s | n/a |
| 1 | KV cache | 284.3 | 0.067s | 1.22s |
| 1 | KV cache + prefill | 482.2 | 0.039s | 0.81s |
| 1 | KV cache + prefill + unrolled | 509.3 | 0.037s | 7.96s |
| 128 | eager full-prefix | 624.7 | 3.893s | n/a |
| 128 | KV cache | 4,578.7 | 0.531s | 1.43s |
| 128 | KV cache + prefill | 5,222.7 | 0.466s | 1.37s |
| 128 | KV cache + prefill + unrolled | 6,225.1 | 0.391s | 8.49s |
All rows in that table matched the eager model’s generated text on the 3,277-prompt correctness set. The last row is fastest in steady state, but the compile-time tradeoff is real: unrolling is attractive here because the sequence length is tiny and fixed. It would need a different evaluation for variable-length generation.
A small attention-only calculation explains why KV caching should help. In the naive full-prefix path, each new token reruns attention over the whole prefix seen so far. That repeats two expensive attention operations: computing query-key scores (QK) and multiplying attention weights by values (AV). For this task’s 28 total tokens and 19 generated tokens, counting only those QK and AV matrix multiplications gives about 10.33M units of work for full-prefix recomputation versus 0.53M for one-token cached decoding, or 19.67x more attention work.
That is a simplified estimate because it ignores the MLP, input/output projections, embedding lookups, compiler behavior, and runtime overhead. It is not a latency prediction. It is just a way to isolate the repeated attention work that KV caching removes.
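The 10.33M, 0.53M, and 19.67x figures can be reproduced with a simple cost model. The model below is an assumption chosen because it matches those numbers, not a formula stated elsewhere in this article: one layer, QK and AV only, 2 FLOPs per multiply-add, hidden size 384, and a dense (unmasked) n x n score matrix on the full-prefix path.

```python
PROMPT_LEN, SEQ_LEN, HIDDEN = 9, 28, 384

# Assumed cost model: each query-key pair costs HIDDEN multiply-adds for QK
# and HIDDEN for AV, at 2 FLOPs per multiply-add.
UNITS_PER_PAIR = 2 * 2 * HIDDEN

prefix_lengths = range(PROMPT_LEN, SEQ_LEN)   # 9, 10, ..., 27 (19 decode steps)

# Full-prefix recomputation: every position attends to every position.
full_units = sum(n * n for n in prefix_lengths) * UNITS_PER_PAIR
# Cached decoding: one new query attends to n cached keys/values.
cached_units = sum(prefix_lengths) * UNITS_PER_PAIR

print(full_units, cached_units, round(full_units / cached_units, 2))
# 10331136 525312 19.67
```

The ratio is just the sum of n² over the sum of n for prefix lengths 9 through 27, so it grows roughly linearly with sequence length.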
What The Benchmarks Showed
Several patterns were consistent across the runs.
Python overhead matters more than expected
For small models and short sequences, Python overhead is not hidden under giant matrix multiplications. The eager loop spends a nontrivial amount of time dispatching operations, slicing prefixes, synchronizing execution, and managing autoregressive control flow.
This is most visible at batch size 1, short sequence lengths, and CPU inference. It is one reason JIT and fixed-shape generation helped so much at small batch sizes.
Compile time can dominate
JIT improved steady-state performance, but compilation itself was expensive enough to change the interpretation. The final stack compiled in about 8s for the measured batch sizes:
| Batch | Final steady latency | Final compile time |
|---|---|---|
| 1 | 0.037s | 7.96s |
| 8 | 0.067s | 8.10s |
| 32 | 0.149s | 8.07s |
| 128 | 0.391s | 8.49s |
If a compiled function is reused many times, this is a good trade. If a process starts cold and serves only a few requests, compile-inclusive latency can dominate everything else.
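A rough break-even count makes that trade concrete. The helper below is a back-of-the-envelope sketch: `break_even_calls` is a hypothetical name, and it assumes the eager path pays no compile cost and that per-call latencies stay constant. The inputs are the eager latency, final latency, and final compile time from the headline table.

```python
import math

def break_even_calls(eager_latency, compiled_latency, compile_time):
    """Smallest call count at which compile + compiled calls beat eager calls."""
    saved_per_call = eager_latency - compiled_latency
    if saved_per_call <= 0:
        return None                      # compiled path never catches up
    return math.floor(compile_time / saved_per_call) + 1

print(break_even_calls(0.606, 0.037, 7.96))   # batch 1   -> 14
print(break_even_calls(3.893, 0.391, 8.49))   # batch 128 -> 3
```

Under those assumptions the final stack pays for itself after 14 batch-1 calls or 3 batch-128 calls. A process that serves fewer requests than that before exiting is better off without the unrolled compile.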
KV cache changed the scaling behavior
KV caching helped even with only 19 generated answer tokens. It reduced redundant prefix work and made larger batches much more effective.
In the diagnostics run, the batch-size tradeoff was direct:
| Batch | Variant | Batch latency | Per-example latency | Tokens/sec |
|---|---|---|---|---|
| 1 | kv_cache | 0.148s | 0.148s | 128.6 |
| 8 | kv_cache | 0.288s | 0.036s | 527.7 |
| 32 | kv_cache | 0.514s | 0.016s | 1,184.0 |
Batching improved throughput and per-example latency, but it increased per-batch latency. That is the serving tradeoff: interactive latency wants small batches; offline throughput wants large batches.
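The three latency columns are related by simple arithmetic. The sketch below assumes throughput is defined as batch size times 19 generated tokens divided by batch latency, which is consistent with the tables in this article; `batch_metrics` is a hypothetical helper name.

```python
TOKENS_PER_PROMPT = 19  # generated answer tokens per prompt

def batch_metrics(batch_size, batch_latency_s):
    """Derive per-example latency and throughput from one batch timing."""
    per_example = batch_latency_s / batch_size
    tokens_per_s = batch_size * TOKENS_PER_PROMPT / batch_latency_s
    return per_example, tokens_per_s

for batch_size, latency in [(1, 0.148), (8, 0.288), (32, 0.514)]:
    per_example, tokens_per_s = batch_metrics(batch_size, latency)
    print(f"batch={batch_size:>3}  per-example={per_example:.3f}s  tokens/s={tokens_per_s:,.1f}")
```

The derived throughput differs from the table in the last digits because the latencies shown there are rounded to three decimals.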
Output equivalence is not the same as model accuracy
The final stack matched eager generated text on all 3,277 cached correctness prompts. But exact arithmetic accuracy was 99.15%, because the original trained model was not perfect on that correctness set.
That is the right way to read these numbers. The optimization stack preserved behavior; it did not fix model errors.
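The two checks can be kept separate in code. This is a toy sketch: `score` is a hypothetical helper, and the strings are made up for illustration and do not reflect the model's real output format.

```python
def score(outputs, reference_outputs, ground_truth):
    """Return (equivalence rate vs. reference path, exact accuracy vs. truth)."""
    n = len(outputs)
    matches_reference = sum(o == r for o, r in zip(outputs, reference_outputs))
    matches_truth = sum(o == t for o, t in zip(outputs, ground_truth))
    return matches_reference / n, matches_truth / n

# Toy strings for illustration only.
reference = ["579", "1000", "42"]     # eager path (last answer is wrong)
optimized = ["579", "1000", "42"]     # optimized path reproduces it exactly
truth     = ["579", "1000", "43"]

equivalence, accuracy = score(optimized, reference, truth)
```

Here `equivalence` is 1.0 while `accuracy` is 2/3: the optimized path faithfully reproduces the reference path, including its mistake.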
Quantization Experiments
I ran two focused follow-up checks on the final cumulative stack:
- float32 versus float16,
- and float32 versus naive int8 weight-only quantization.
Both checks used the same structural inference path: KV cache, prompt prefill, and fixed-length unrolled decoding on the CPU backend. They also used the same cached 3,277-prompt correctness set. The rows below should be read as paired comparisons within each table, not as replacements for the earlier final-stack table; separate CPU benchmark runs can have different absolute throughput.
The first useful result was float16. It was not uniformly faster:
| Batch | Float32 tokens/s | Float16 tokens/s | FP16/FP32 | Correctness |
|---|---|---|---|---|
| 1 | 287.1 | 101.8 | 0.35x | matched fp32 |
| 8 | 1,253.3 | 1,335.1 | 1.07x | matched fp32 |
| 32 | 1,301.1 | 1,527.9 | 1.17x | matched fp32 |
| 128 | 1,751.5 | 2,133.3 | 1.22x | matched fp32 |
All float16 rows matched the float32 generated text on 3,277 / 3,277 prompts, and exact arithmetic accuracy remained 3,249 / 3,277 = 99.15%. But the performance pattern is the interesting part: at batch 1, float16 was much slower (0.187s batch latency versus 0.066s for float32). At batch 128, it became faster (1.140s versus 1.389s).
That is consistent with a small-model systems effect. At very small batch sizes, runtime overhead, dispatch overhead, and less efficient tiny kernels can dominate. As the batch grows, the matrix operations become large enough that lower precision appears to help throughput. I did not isolate the exact CPU kernel behavior, so I would not claim a universal reason. The safe conclusion is narrower: in this setup, float16 only became useful once the workload was large enough.
The second check compared the float32 final stack with the same stack using int8 weight-only quantization:
| Batch | Float32 tokens/s | Int8 tokens/s | Int8/FP32 | Correctness |
|---|---|---|---|---|
| 1 | 331.6 | 343.3 | 1.04x | matched float32 |
| 8 | 1,316.4 | 1,299.1 | 0.99x | matched float32 |
| 32 | 2,099.4 | 2,415.1 | 1.15x | matched float32 |
| 128 | 3,007.3 | 2,445.5 | 0.81x | matched float32 |
The int8 path matched the float32 stack on 3,277 / 3,277 generated outputs, with the same 99.15% exact arithmetic accuracy. Correctness was not the problem. Throughput was mixed: slightly better at batch 1, slightly worse at batch 8, better at batch 32, and clearly worse at batch 128. Batch latency moved with that pattern: at batch 32, int8 improved latency from 0.290s to 0.252s; at batch 128, it regressed from 0.809s to 0.995s.
Figure: Focused CPU checks on the final KV-cache + prefill + unrolled stack. Compare lines within each panel; the two panels come from separate benchmark runs.
Why Quantization Did Not Behave The Way I Expected
The theoretical story is simple: lower precision should reduce memory traffic and can expose faster arithmetic. The measured story was more complicated.
The int8 implementation here is deliberately simple. It is naive weight-only quantization: large weight matrices are stored as integer values with a scale, then materialized or dequantized back inside the JAX graph before the floating-point computation. It is not a production int8 GEMM kernel, not GPTQ, not AWQ, and not a packed-weight serving runtime.
That implementation detail matters. If the runtime does not execute the core matmuls as efficient packed int8 operations, the cost of dequantization, graph complexity, memory movement, and compile overhead can cancel the theoretical savings. In the focused check, int8 compile time was also higher at every measured batch size; at batch 128, it was 28.18s versus 18.61s for the float32 stack.
The result is a useful systems lesson: a lower-bit representation is not the same thing as a faster inference kernel. The representation, compiler lowering, backend kernels, batch size, and model size all matter.
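A naive weight-only int8 path of the kind described here can be sketched in a few lines. This is an illustration under assumptions, not the benchmark implementation: symmetric per-output-column scaling and the helper names `quantize_int8` and `int8_linear` are choices made for the sketch. The 384 x 1536 shape matches the hidden and MLP dimensions from the setup table.

```python
import jax
import jax.numpy as jnp

def quantize_int8(w):
    """Symmetric per-output-column int8 quantization: w ~= q * scale."""
    scale = jnp.max(jnp.abs(w), axis=0, keepdims=True) / 127.0
    scale = jnp.where(scale == 0, 1.0, scale)        # avoid divide-by-zero
    q = jnp.clip(jnp.round(w / scale), -127, 127).astype(jnp.int8)
    return q, scale

@jax.jit
def int8_linear(x, q, scale):
    # Weights are dequantized to float32 inside the graph, so the matmul
    # itself still runs in floating point: storage shrinks, the kernel does not.
    w = q.astype(jnp.float32) * scale
    return x @ w

kw, kx = jax.random.split(jax.random.PRNGKey(0))
w = jax.random.normal(kw, (384, 1536))               # hidden -> MLP shape
x = jax.random.normal(kx, (4, 384))

q, scale = quantize_int8(w)
y = int8_linear(x, q, scale)
max_weight_error = float(jnp.max(jnp.abs(q.astype(jnp.float32) * scale - w)))
```

The stored weights are four times smaller, but every call adds a cast and a multiply before the same float32 matmul. Unless the backend lowers this to a real int8 kernel, there is extra work on the critical path and no cheaper arithmetic to pay for it.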
Optimizations That Sounded Good But Barely Helped
A few follow-up results are worth calling out because they are easy to get wrong from intuition alone:
- float16 sounded like an obvious win, but at batch 1 it was only 0.35x as fast as float32.
- int8 preserved outputs, but it was slower than the float32 stack at batch 8 and batch 128.
- int8 reduced weight precision, but this implementation still paid materialization/dequantization cost inside the graph.
- Both float16 and int8 increased compile time in these focused CPU checks.
The broader takeaway is that theoretical arithmetic reductions do not automatically translate into real-world inference speedups. Runtime implementation matters. Backend support matters. Workload scale matters. Correctness validation matters. For this small autoregressive model, the best optimization was not “use fewer bits”; it was measuring where the time actually went.
Optimizations That Did Not Help Much
One useful outcome of the project was identifying things that sounded promising but did not matter much in this setup.
Dot-product attention rewrite
I compared explicit QK/mask/softmax/AV attention against jax.nn.dot_product_attention(..., implementation="xla").
On this CPU backend, the rewrite did not produce a fused attention custom call and did not meaningfully improve speed:
| Dtype | Batch | Baseline tokens/s | DPA tokens/s | DPA/base |
|---|---|---|---|---|
| float32 | 1 | 512.8 | 502.7 | 0.98x |
| float32 | 8 | 2,076.3 | 2,085.2 | 1.00x |
| float32 | 128 | 6,011.7 | 5,989.5 | 1.00x |
| bfloat16 | 128 | 3,308.7 | 3,200.6 | 0.97x |
The float32 outputs matched exactly. The bfloat16 path changed generated text on two examples in the correctness set, so I did not treat it as a lossless optimization.
Disabling XLA fusion
I also compared the normal XLA CPU compiler pipeline against the same benchmark with HLO fusion disabled.
For float32, default fusion was consistently better. For kv_cache_prefill_unrolled batch 128, default fusion reached 7,331.3 tokens/sec versus 5,687.2 tokens/sec with fusion disabled, a 1.29x ratio. Compile time was also much lower with default fusion: 7.27s versus 19.55s.
The conclusion is not surprising, but it is useful: on this backend, the best lossless fusion strategy was simply to keep XLA's default fusion enabled and avoid graph changes that block it.
JIT alone at large batch sizes
JIT alone helped batch size 1, but at batch size 128 it reached only 0.76x to 0.78x of eager throughput in the final stack comparison. The large gains came from changing the structure of decoding, not merely compiling the same prefix-recomputing algorithm.
That is a recurring theme in inference work: compiler improvements help, but algorithmic waste usually needs to be removed directly.
Limits Of This Experiment
This is a 10M-parameter transformer with sequence length 28, measured on a CPU backend. It is not a frontier model, and it is not a production serving stack.
Large LLM inference is often dominated by:
- memory bandwidth,
- GPU utilization,
- communication overhead,
- KV-cache memory pressure,
- scheduling,
- and large matrix multiplication throughput.
This small model behaves differently. Launch and Python overheads are proportionally larger. Compilation costs are easier to see. Some GPU-oriented optimizations are not represented. The CPU backend does not expose kernel launch timing in the same way a GPU profiler would.
The value of the experiment is narrower: it makes inference mechanics visible.
What I Would Test Next
The next useful experiments are mostly about scale and realism:
- Increase sequence length and measure when the benefit of KV caching becomes overwhelming.
- Run the same stack on a GPU backend.
- Measure first-token latency separately from full-answer latency.
- Add fused kernels or Triton kernels for the small decode path.
- Test a larger model, such as a 1B-parameter decoder, with the same measurement discipline.
- Measure when memory bandwidth becomes dominant.
- Compare persistent graph execution against cold compile-heavy execution.
- Add speculative decoding.
- Replace the simple low-bit baselines with GPTQ/AWQ-style methods.
- Test continuous batching and paged KV caches with a real serving loop.
The useful constraint is that each experiment should keep correctness checks in the loop. Fast invalid output is not an optimization.
Conclusion
This started as a small arithmetic transformer experiment. It eventually became a controlled inference optimization laboratory.
The project does not matter because the model is useful. It matters because every part of the inference stack is visible: autoregressive decoding, graph compilation, batching, caching, synchronization, quantization, correctness validation, and performance tradeoffs.
The biggest lesson was not “KV caching is faster.” That was expected.
The more important lesson was that inference optimization is mostly about understanding where time is actually spent. Some optimizations interact badly with compilation. Some improve throughput while hurting latency. Some silently break correctness. Some help large models much more than small ones. And many optimizations that sound important barely matter once measured carefully.
Small controlled systems are valuable because they make those mechanics easier to isolate. In practice, understanding the mechanics is often more important than applying optimizations blindly.