KV Cache in LLMs From Zero

How KV cache changes LLM decoding, why it saves compute, and where memory becomes the bottleneck.

KV cache is one of the most important ideas in LLM inference, but it is also one of the easiest ideas to misunderstand. At a high level, it means:

During generation, do not recompute the attention keys and values for old tokens. Store them once, then reuse them.

This small idea explains a large part of why modern LLM inference works at all. It also explains why LLM serving is so memory-hungry: the cache grows with the number of active requests and with the length of their contexts.

This tutorial starts with the basic motivation, builds the mental model, then goes deeper into shapes, memory, paged KV cache, prefix reuse, quantization, eviction, offloading, batching, and multi-GPU inference.

1. Why KV cache exists

The easiest way to understand KV cache is to start from the naive generation loop and then remove the wasted work.

1.1 The naive generation loop

Suppose the prompt is:

KV cache is important because

The model predicts one token:

KV cache is important because it

Then another:

KV cache is important because it avoids

Then another:

KV cache is important because it avoids recomputing

Autoregressive generation works like this:

input tokens -> model -> next token
input tokens + next token -> model -> next token
input tokens + next token + next token -> model -> next token
...

A very naive implementation would run the full model on the full sequence every time:

def very_naive_generate(model, tokens, num_new_tokens):
    for _ in range(num_new_tokens):
        logits = model.forward(tokens)
        next_token = logits[-1].argmax()
        tokens.append(next_token)
    return tokens

This works conceptually, but it repeats most of the work. At step 50, the first 49 tokens have already been processed many times. The model keeps recomputing information about tokens that have not changed.
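
To put a rough number on the repeated work, the sketch below counts how many token positions pass through the model. The function names and the example lengths are illustrative assumptions, and the count ignores the cost of attention itself; it only compares the naive loop against an idealized loop that touches every position once.

```python
def naive_positions_processed(prompt_len: int, num_new_tokens: int) -> int:
    """Token positions pushed through the model by the naive loop."""
    total = 0
    for step in range(num_new_tokens):
        # Each step re-reads the prompt plus everything generated so far.
        total += prompt_len + step
    return total


def single_pass_positions(prompt_len: int, num_new_tokens: int) -> int:
    """Idealized count: the prompt once, then one position per later step."""
    # The last generated token is never fed back in, hence the "- 1".
    return prompt_len + num_new_tokens - 1


# Hypothetical sizes, chosen only for illustration.
naive = naive_positions_processed(prompt_len=256, num_new_tokens=128)
ideal = single_pass_positions(prompt_len=256, num_new_tokens=128)
print(f"naive loop:  {naive} positions")
print(f"single pass: {ideal} positions")
print(f"ratio:       {naive / ideal:.1f}x")
```

The naive count grows quadratically with the number of generated tokens, while the single-pass count grows linearly.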

1.2 What changes with KV cache

With KV cache, generation is split into two ideas:

1. Process the existing prompt once.
2. During generation, process only the newest token and reuse cached information about previous tokens.

A simplified cached loop looks like this:

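def cached_generate(model, tokens, num_new_tokens):
    # Prefill: run the whole prompt once and keep the logits of its last position.
    logits, kv_cache = model.prefill(tokens)

    for _ in range(num_new_tokens):
        next_token = logits.argmax()
        tokens.append(next_token)
        # Decode: feed only the newest token; the cache covers everything before it.
        logits, kv_cache = model.decode_one_token(next_token, kv_cache)

    return tokens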

This is not real framework code, but it captures the central tradeoff: KV cache saves computation by spending memory. Instead of recomputing old tokens through every Transformer layer, the model stores the key and value tensors it will need later.

1.3 Prefill and decode

The two phases of generation are usually called prefill and decode.

Prefill processes the prompt:

Input:
  "KV cache is important because"

The model computes:
  keys for all prompt tokens
  values for all prompt tokens

Then it stores them in the KV cache.

Decode generates new tokens one by one:

Decode step 1:
  Input only the newest token.
  Reuse old K/V from the cache.
  Append the new token's K/V to the cache.

Decode step 2:
  Input only the newest token.
  Reuse the larger cache.
  Append again.

A compact mental model:

Prompt tokens
    |
    v
Prefill
    |
    v
KV cache contains K/V for prompt
    |
    v
Generate token 1 and append its K/V
    |
    v
Generate token 2 and append its K/V
    |
    v
Generate token 3 and append its K/V
    |
    v
...

The cache grows as generation continues. That growth is useful because it avoids repeated computation, but it is also expensive because the cache usually lives in GPU memory.


2. The attention objects behind the cache

KV cache is easier to understand once the attention computation is clear. This section is a short refresher on queries, keys, and values, then connects them directly to the cache.

2.1 Queries, keys, and values

A Transformer layer receives hidden states X. You can think of X as a matrix where each row is one token’s current hidden vector. Attention uses three learned projection matrices, W_Q, W_K, and W_V, to make three different views of those same token vectors:

Q = X @ W_Q
K = X @ W_K
V = X @ W_V

Here @ means matrix multiplication. The W_* matrices are model weights learned during training. They are not new inputs; they are part of the Transformer layer.

A useful intuition is:

Q: query   -> What am I looking for?
K: key     -> What information do I match against?
V: value   -> What information do I read if I attend here?

Why use these multiplications? Because the same hidden vector has to play different roles. Multiplying by W_Q creates the representation used by the current token to search. Multiplying by W_K creates the representation other tokens can be searched by. Multiplying by W_V creates the content that will be mixed into the output if that token is attended to.

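The attention operation is roughly as follows (this omits the 1/sqrt(d) scaling and the causal mask, both of which appear in the code in section 3):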

scores = Q @ K^T
weights = softmax(scores)
output = weights @ V

The score between one query and one key is a dot product. If they point in a similar direction, the score is high. The softmax turns those scores into weights, and weights @ V mixes together the value vectors according to those weights.

In a causal LLM, token i can attend only to tokens 0..i because generation is left-to-right. Future tokens have not been generated yet, and during training the model is masked so it cannot cheat by looking at the answer. For example:

Token 0: KV
Token 1: cache
Token 2: is
Token 3: useful

The causal attention pattern is:

Token 0 attends to: token 0
Token 1 attends to: token 0, token 1
Token 2 attends to: token 0, token 1, token 2
Token 3 attends to: token 0, token 1, token 2, token 3

2.2 What exactly is cached?

The KV cache stores the K and V tensors from previous tokens. It usually does not store queries, attention scores, attention probabilities, logits, sampled tokens, or the full hidden states and intermediate activations for every layer.

The reason is simple: once the model has computed the key and value for an old token at a given layer, those tensors do not change during generation. A new token needs to attend to them, but the old tokens themselves do not need to be recomputed.

For one attention layer, the simplified update is:

Past cache:
  K_past: [tokens already processed]
  V_past: [tokens already processed]

New token:
  K_new
  V_new

Updated cache:
  K_cache = concat(K_past, K_new)
  V_cache = concat(V_past, V_new)

During the next decode step, the model reuses this updated cache.

2.3 What the real KV cache shape looks like

In a real decoder-only LLM, the cache exists for every Transformer layer. A common conceptual shape is:

key cache:
  [batch_size, num_kv_heads, sequence_length, head_dim]

value cache:
  [batch_size, num_kv_heads, sequence_length, head_dim]

And this is repeated layer by layer:

KV cache:
  layer 0:
    K: [B, H_kv, S, D]
    V: [B, H_kv, S, D]
  layer 1:
    K: [B, H_kv, S, D]
    V: [B, H_kv, S, D]
  ...

Where:

B: batch size
H_kv: number of key/value heads
S: sequence length
D: head dimension

For a single request, B = 1. In a server, many requests are active at the same time, and the runtime has to manage many growing caches together.


3. Code: from toy attention to cached decoding

The next examples make the idea concrete with tiny attention code. These scripts are not real LLM implementations, but they isolate the part that matters for KV cache: full recomputation versus prefill plus one-token decode.

3.1 Toy attention without KV cache

This script computes causal attention over the whole sequence. It is not a real LLM, but it shows the basic computation that would be repeated in a naive generation loop.

Create toy_attention_no_cache.py:

#!/usr/bin/env python3
import argparse
import math

import torch


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Toy causal attention without KV cache."
    )
    parser.add_argument("--seq-len", type=int, default=6)
    parser.add_argument("--hidden-size", type=int, default=8)
    parser.add_argument("--seed", type=int, default=0)
    return parser.parse_args()


def causal_attention(x: torch.Tensor) -> torch.Tensor:
    seq_len, hidden_size = x.shape

    # No reseeding here: main() already seeds the RNG, so --seed controls the weights too.
    wq = torch.randn(hidden_size, hidden_size)
    wk = torch.randn(hidden_size, hidden_size)
    wv = torch.randn(hidden_size, hidden_size)

    q = x @ wq
    k = x @ wk
    v = x @ wv

    scores = q @ k.T / math.sqrt(hidden_size)

    causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    scores = scores.masked_fill(~causal_mask, float("-inf"))

    weights = torch.softmax(scores, dim=-1)
    output = weights @ v

    return output


def main() -> None:
    args = parse_args()
    torch.manual_seed(args.seed)

    x = torch.randn(args.seq_len, args.hidden_size)

    output = causal_attention(x)

    print("Input shape:")
    print(tuple(x.shape))
    print("Output shape:")
    print(tuple(output.shape))
    print("Last token output:")
    print(output[-1])


if __name__ == "__main__":
    main()

Run:

uv pip install torch
python toy_attention_no_cache.py

For generation without KV cache, code like this would be called again and again on longer and longer sequences.

3.2 Toy attention with KV cache

Now let us compare two methods:

Method 1:
  Run full causal attention on the whole sequence.

Method 2:
  Run prefill on the prefix.
  Decode one token using cached K/V.

The last-token output should match.

Create toy_attention_with_cache.py:

#!/usr/bin/env python3
import argparse
import math

import torch


class ToyAttention:
    def __init__(self, hidden_size: int, seed: int) -> None:
        torch.manual_seed(seed)
        self.hidden_size = hidden_size
        self.wq = torch.randn(hidden_size, hidden_size)
        self.wk = torch.randn(hidden_size, hidden_size)
        self.wv = torch.randn(hidden_size, hidden_size)

    def project(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        q = x @ self.wq
        k = x @ self.wk
        v = x @ self.wv
        return q, k, v

    def full_forward(self, x: torch.Tensor) -> torch.Tensor:
        seq_len, hidden_size = x.shape
        q, k, v = self.project(x)

        scores = q @ k.T / math.sqrt(hidden_size)
        causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
        scores = scores.masked_fill(~causal_mask, float("-inf"))

        weights = torch.softmax(scores, dim=-1)
        return weights @ v

    def prefill(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        _, k, v = self.project(x)
        return k, v

    def decode_one(
        self,
        x_new: torch.Tensor,
        past_k: torch.Tensor,
        past_v: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        q_new, k_new, v_new = self.project(x_new)

        all_k = torch.cat([past_k, k_new], dim=0)
        all_v = torch.cat([past_v, v_new], dim=0)

        scores = q_new @ all_k.T / math.sqrt(self.hidden_size)
        weights = torch.softmax(scores, dim=-1)
        output = weights @ all_v

        return output, all_k, all_v


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Toy causal attention with KV cache."
    )
    parser.add_argument("--prefix-len", type=int, default=5)
    parser.add_argument("--hidden-size", type=int, default=8)
    parser.add_argument("--seed", type=int, default=0)
    return parser.parse_args()


def main() -> None:
    args = parse_args()
    torch.manual_seed(args.seed)

    attention = ToyAttention(hidden_size=args.hidden_size, seed=args.seed)

    prefix = torch.randn(args.prefix_len, args.hidden_size)
    new_token = torch.randn(1, args.hidden_size)

    full_input = torch.cat([prefix, new_token], dim=0)
    full_output = attention.full_forward(full_input)
    full_last_output = full_output[-1:]

    past_k, past_v = attention.prefill(prefix)
    cached_output, updated_k, updated_v = attention.decode_one(
        x_new=new_token,
        past_k=past_k,
        past_v=past_v,
    )

    print("Full last-token output:")
    print(full_last_output)
    print()
    print("Cached decode output:")
    print(cached_output)
    print()
    print("Max absolute difference:")
    print((full_last_output - cached_output).abs().max().item())
    print()
    print("Updated K cache shape:")
    print(tuple(updated_k.shape))
    print("Updated V cache shape:")
    print(tuple(updated_v.shape))


if __name__ == "__main__":
    main()

Run:

python toy_attention_with_cache.py

The max difference should be tiny, usually just floating-point noise. For the newest token, cached attention gives the same result as full causal attention, but avoids recomputing the old K/V tensors.

3.3 Benchmarking the toy cache

Now we can compare the two loops directly. The benchmark below uses the ToyAttention class from toy_attention_with_cache.py.

Create benchmark_toy_kv_cache.py:

#!/usr/bin/env python3
import argparse
import time

import torch

from toy_attention_with_cache import ToyAttention


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Benchmark toy decoding with and without KV cache.")
    parser.add_argument("--prompt-len", type=int, default=256)
    parser.add_argument("--new-tokens", type=int, default=128)
    parser.add_argument("--hidden-size", type=int, default=128)
    parser.add_argument("--repeats", type=int, default=5)
    parser.add_argument("--seed", type=int, default=0)
    return parser.parse_args()


def mean_time(fn, repeats: int) -> float:
    times = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        times.append(time.perf_counter() - start)
    return sum(times) / len(times)


def main() -> None:
    args = parse_args()
    torch.set_num_threads(1)
    torch.manual_seed(args.seed)

    attention = ToyAttention(hidden_size=args.hidden_size, seed=args.seed)
    prompt = torch.randn(args.prompt_len, args.hidden_size)
    new_tokens = torch.randn(args.new_tokens, args.hidden_size)

    def without_cache() -> None:
        sequence = prompt.clone()
        for i in range(args.new_tokens):
            sequence = torch.cat([sequence, new_tokens[i : i + 1]], dim=0)
            _ = attention.full_forward(sequence)[-1]

    def with_cache() -> None:
        past_k, past_v = attention.prefill(prompt)
        for i in range(args.new_tokens):
            _, past_k, past_v = attention.decode_one(new_tokens[i : i + 1], past_k, past_v)

    # Warmup
    without_cache()
    with_cache()

    no_cache_time = mean_time(without_cache, args.repeats)
    cache_time = mean_time(with_cache, args.repeats)

    print("Toy KV cache benchmark")
    print("----------------------")
    print(f"Prompt length: {args.prompt_len}")
    print(f"Generated tokens: {args.new_tokens}")
    print(f"Hidden size: {args.hidden_size}")
    print(f"Repeats: {args.repeats}")
    print()
    print(f"Without KV cache: {no_cache_time:.4f} s")
    print(f"With KV cache:    {cache_time:.4f} s")
    print(f"Speedup:          {no_cache_time / cache_time:.2f}x")


if __name__ == "__main__":
    main()

Run:

python benchmark_toy_kv_cache.py

On an Apple M2 MacBook Air CPU, with the defaults above, I measured:

Toy KV cache benchmark
----------------------
Prompt length: 256
Generated tokens: 128
Hidden size: 128
Repeats: 5

Without KV cache: 0.0471 s
With KV cache:    0.0036 s
Speedup:          13.07x

Do not treat this as a production benchmark. It is a tiny CPU-only demonstration. The useful takeaway is the direction: recomputing the whole prefix again and again is much more expensive than computing old K/V once and reusing it.


4. Memory math and architectural choices

The performance benefit of KV cache is easy to like; the memory cost is the painful part. This section gives the formulas and the main architecture choices that affect cache size.

4.1 KV cache memory formula

The rough memory formula is:

KV cache bytes =
    batch_size
  * sequence_length
  * num_layers
  * 2
  * num_kv_heads
  * head_dim
  * bytes_per_element

The 2 appears because the cache stores both K and V.

Create kv_cache_size.py:

#!/usr/bin/env python3
import argparse


DTYPE_BYTES = {
    "fp32": 4.0,
    "bf16": 2.0,
    "fp16": 2.0,
    "fp8": 1.0,
    "int8": 1.0,
    "fp4": 0.5,
    "int4": 0.5,
}


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Estimate KV cache memory for decoder-only LLM inference."
    )
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--sequence-length", type=int, default=8192)
    parser.add_argument("--num-layers", type=int, default=32)
    parser.add_argument("--num-kv-heads", type=int, default=8)
    parser.add_argument("--head-dim", type=int, default=128)
    parser.add_argument(
        "--dtype",
        choices=sorted(DTYPE_BYTES),
        default="bf16",
        help="KV cache dtype.",
    )
    parser.add_argument(
        "--tp-size",
        type=int,
        default=1,
        help="Tensor parallel size for a rough per-rank estimate.",
    )
    return parser.parse_args()


def format_gib(num_bytes: float) -> str:
    return f"{num_bytes / (1024 ** 3):.2f} GiB"


def main() -> None:
    args = parse_args()

    bytes_per_element = DTYPE_BYTES[args.dtype]

    total_bytes = (
        args.batch_size
        * args.sequence_length
        * args.num_layers
        * 2
        * args.num_kv_heads
        * args.head_dim
        * bytes_per_element
    )

    bytes_per_token_per_sequence = (
        args.num_layers
        * 2
        * args.num_kv_heads
        * args.head_dim
        * bytes_per_element
    )

    # Rough estimate: assumes the KV heads divide evenly across tensor-parallel ranks.
    per_rank_bytes = total_bytes / args.tp_size

    print("KV cache estimate")
    print("-----------------")
    print(f"Batch size: {args.batch_size}")
    print(f"Sequence length: {args.sequence_length}")
    print(f"Layers: {args.num_layers}")
    print(f"KV heads: {args.num_kv_heads}")
    print(f"Head dim: {args.head_dim}")
    print(f"Dtype: {args.dtype}")
    print()
    print(f"Bytes per token per sequence: {bytes_per_token_per_sequence:,.0f}")
    print(f"Total KV cache: {total_bytes:,.0f} bytes")
    print(f"Total KV cache: {format_gib(total_bytes)}")
    print()
    print(f"Rough per-rank estimate with tp_size={args.tp_size}:")
    print(f"{format_gib(per_rank_bytes)}")


if __name__ == "__main__":
    main()

Run:

python kv_cache_size.py

With the defaults, the output is:

KV cache estimate
-----------------
Batch size: 64
Sequence length: 8192
Layers: 32
KV heads: 8
Head dim: 128
Dtype: bf16

Bytes per token per sequence: 131,072
Total KV cache: 68,719,476,736 bytes
Total KV cache: 64.00 GiB

Rough per-rank estimate with tp_size=1:
64.00 GiB

The KV cache is about 64 GiB. That is just the cache; it does not include model weights, temporary buffers, CUDA workspace, attention workspaces, fragmentation, allocator overhead, or server overhead.

4.2 Memory per token

The most useful number is often the cache cost of one token for one sequence:

bytes_per_token =
    num_layers
  * 2
  * num_kv_heads
  * head_dim
  * bytes_per_element

For example:

num_layers = 32
num_kv_heads = 8
head_dim = 128
dtype = bf16, so bytes_per_element = 2

Then:

bytes_per_token =
    32 * 2 * 8 * 128 * 2
  = 131,072 bytes
  = 128 KiB

One 8,192-token request therefore uses roughly:

8192 * 128 KiB = 1 GiB

Here, a request means one active sequence being generated by the server, such as one chat completion or API call. If the server is handling many active sequences at the same time, the KV cache scales roughly with that count. For example, 64 active 8,192-token requests at about 1 GiB each gives about 64 GiB of KV cache.
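
The same per-token number can be turned around to ask how many tokens fit in a given cache budget. This is a minimal sketch: the 40 GiB budget is a hypothetical figure, not a measurement from any GPU or server, and the helper names are made up for illustration.

```python
def kv_bytes_per_token(
    num_layers: int, num_kv_heads: int, head_dim: int, bytes_per_element: float
) -> int:
    """Cache bytes for one token of one sequence (K and V, all layers)."""
    return int(num_layers * 2 * num_kv_heads * head_dim * bytes_per_element)


def max_cached_tokens(budget_bytes: int, bytes_per_token: int) -> int:
    """How many tokens of KV cache fit in the budget, across all sequences."""
    return budget_bytes // bytes_per_token


per_token = kv_bytes_per_token(
    num_layers=32, num_kv_heads=8, head_dim=128, bytes_per_element=2
)
budget = 40 * 1024**3  # hypothetical memory left for KV cache after weights
tokens = max_cached_tokens(budget, per_token)
print(f"bytes per token:        {per_token:,}")
print(f"tokens that fit:        {tokens:,}")
print(f"full 8192-token slots:  {tokens // 8192}")
```

Thinking in total cached tokens, rather than in requests, is convenient because short and long sequences draw from the same pool.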

4.3 KV cache does not make decoding constant-time

A common misunderstanding is that KV cache makes decoding one token O(1). It does not. The new token still attends to all previous tokens, so the attention read grows with sequence length.

Without KV cache:

Recompute old tokens through all layers.

With KV cache:

Compute only the new token through the layers,
but read old K/V tensors during attention.

This is why decode can be memory-bandwidth-heavy. The GPU may spend a lot of time reading cached K/V tensors, especially for long contexts and large batches.
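
A rough way to see the growth is to estimate how many cache bytes one decode step has to read for one sequence. The sketch assumes every cached position is read exactly once per step and reuses the 128 KiB-per-token figure from the example above; real kernels may differ, and the function name is hypothetical.

```python
def kv_bytes_read_per_step(context_len: int, bytes_per_token: int) -> int:
    """Approximate cache bytes read when decoding one token for one sequence."""
    # The new token attends to every cached position in every layer.
    return context_len * bytes_per_token


BYTES_PER_TOKEN = 131_072  # 32 layers, 8 KV heads, head_dim 128, bf16

for context_len in (1024, 8192, 32768):
    read_bytes = kv_bytes_read_per_step(context_len, BYTES_PER_TOKEN)
    print(f"context {context_len:>6}: about {read_bytes / 1024**2:,.0f} MiB read per step")
```

The work per step is linear in context length rather than constant, which is the point of this section.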

4.4 MHA, MQA, and GQA

The formula depends heavily on num_kv_heads. A query head is one parallel attention channel that produces its own queries. A key/value head is the corresponding channel that stores keys and values in the cache. Older multi-head attention usually gives each query head its own K/V head, but newer layouts often let multiple query heads share fewer K/V heads.

There are three common attention layouts.

MHA: multi-head attention

num_query_heads = num_kv_heads

Example:

32 query heads
32 key/value heads

Each query head has its own K/V head.

MQA: multi-query attention

num_query_heads = many
num_kv_heads = 1

Example:

32 query heads
1 key/value head

Many query heads share one K/V head, which drastically reduces cache memory.

GQA: grouped-query attention

num_query_heads = many
num_kv_heads = fewer than query heads

Example:

32 query heads
8 key/value heads

This is a compromise between MHA and MQA.

For memory, the effect is direct:

MHA:
  Q heads: 32
  KV heads: 32

GQA:
  Q heads: 32
  KV heads: 8

MQA:
  Q heads: 32
  KV heads: 1

Going from 32 KV heads to 8 KV heads makes the KV cache 4 times smaller. Going from 32 to 1 makes it 32 times smaller.
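
The sharing can be written as a mapping from query-head index to K/V-head index. The sketch below assumes consecutive query heads form a group, which is a common convention but not something this tutorial specifies for any particular model; the function name is illustrative.

```python
def kv_head_for_query_head(
    q_head: int, num_query_heads: int, num_kv_heads: int
) -> int:
    """Return the K/V head that a given query head reads from."""
    if num_query_heads % num_kv_heads != 0:
        raise ValueError("num_query_heads must be a multiple of num_kv_heads")
    group_size = num_query_heads // num_kv_heads
    # Assumption: query heads 0..group_size-1 share K/V head 0, and so on.
    return q_head // group_size


for name, num_kv_heads in (("MHA", 8), ("GQA", 2), ("MQA", 1)):
    mapping = [kv_head_for_query_head(q, 8, num_kv_heads) for q in range(8)]
    print(f"{name}: query head -> KV head {mapping}")
```

Only the K/V heads are cached, so the cache size follows the number of distinct values in the mapping, not the number of query heads.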

4.5 Comparing MHA, GQA, and MQA memory

Create compare_kv_heads.py:

#!/usr/bin/env python3
import argparse


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Compare KV cache memory for MHA, GQA, and MQA."
    )
    parser.add_argument("--batch-size", type=int, default=16)
    parser.add_argument("--sequence-length", type=int, default=8192)
    parser.add_argument("--num-layers", type=int, default=32)
    parser.add_argument("--query-heads", type=int, default=32)
    parser.add_argument("--head-dim", type=int, default=128)
    parser.add_argument("--bytes-per-element", type=float, default=2.0)
    return parser.parse_args()


def kv_gib(
    batch_size: int,
    sequence_length: int,
    num_layers: int,
    num_kv_heads: int,
    head_dim: int,
    bytes_per_element: float,
) -> float:
    total_bytes = (
        batch_size
        * sequence_length
        * num_layers
        * 2
        * num_kv_heads
        * head_dim
        * bytes_per_element
    )
    return total_bytes / (1024 ** 3)


def main() -> None:
    args = parse_args()

    configs = [
        ("MHA", args.query_heads),
        ("GQA", max(1, args.query_heads // 4)),
        ("MQA", 1),
    ]

    print("KV cache memory comparison")
    print("--------------------------")
    print(f"Batch size: {args.batch_size}")
    print(f"Sequence length: {args.sequence_length}")
    print(f"Layers: {args.num_layers}")
    print(f"Query heads: {args.query_heads}")
    print(f"Head dim: {args.head_dim}")
    print()

    for name, num_kv_heads in configs:
        memory = kv_gib(
            batch_size=args.batch_size,
            sequence_length=args.sequence_length,
            num_layers=args.num_layers,
            num_kv_heads=num_kv_heads,
            head_dim=args.head_dim,
            bytes_per_element=args.bytes_per_element,
        )
        print(f"{name}:")
        print(f"  KV heads: {num_kv_heads}")
        print(f"  KV cache: {memory:.2f} GiB")


if __name__ == "__main__":
    main()

Run:

python compare_kv_heads.py

With the defaults, the output is:

KV cache memory comparison
--------------------------
Batch size: 16
Sequence length: 8192
Layers: 32
Query heads: 32
Head dim: 128

MHA:
  KV heads: 32
  KV cache: 64.00 GiB
GQA:
  KV heads: 8
  KV cache: 16.00 GiB
MQA:
  KV heads: 1
  KV cache: 2.00 GiB

The difference comes directly from the number of K/V heads. In this setup, MHA stores K/V for all 32 heads, so it uses 64 GiB. GQA stores K/V for 8 heads, so it uses one quarter of that: 16 GiB. MQA stores one shared K/V head, so it drops to 2 GiB. This is why many modern LLMs use GQA: it keeps many query heads for attention quality while making the KV cache much smaller than full MHA.


5. KV cache in serving systems

With one request, KV cache is a tuple of tensors. With many concurrent requests, it becomes a full memory-management problem.

5.1 Why serving makes KV cache hard

A production server handles many requests with different lengths:

Request A:
  prompt length 100
  output length 50

Request B:
  prompt length 3000
  output length 200

Request C:
  prompt length 20
  output length 10

Request D:
  prompt length 8000
  output length 1000

Each request has its own growing cache. Some finish early, some are cancelled, some share prefixes, and some stream for a long time. The runtime must answer questions such as:

Where does each request's KV cache live?
How much memory should be reserved?
Can old blocks be reused?
Can completed prefixes stay cached?
Can some cache move to CPU?
Which requests should run next?

This is why engines such as TensorRT-LLM, vLLM, SGLang, TGI, and llama.cpp spend so much engineering effort on cache management. LLM serving is not just matrix multiplication; it is also scheduling and memory allocation.

5.2 The naive KV cache layout

A naive server might allocate one large KV tensor per request at the maximum sequence length:

max_seq_len = 8192

Request A:
  allocate space for 8192 tokens

Request B:
  allocate space for 8192 tokens

Request C:
  allocate space for 8192 tokens

If request A only uses 200 tokens, most of its allocation is wasted:

Request A:
  [used used used free free free free free]

Request B:
  [used used used used used used free free]

Request C:
  [used free free free free free free free]

This wastes memory and makes variable-length batching awkward.
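
The waste is easy to quantify. This sketch compares reserved and used token slots under max-length preallocation; the used lengths are taken loosely from the examples above (200 tokens for A, prompt plus output for B and C) and are only illustrative.

```python
def reserved_vs_used(used_tokens: list[int], max_seq_len: int) -> tuple[int, int]:
    """Token slots reserved and actually used when every request gets max_seq_len."""
    reserved = len(used_tokens) * max_seq_len
    used = sum(used_tokens)
    return reserved, used


# Illustrative lengths: A uses 200 tokens, B uses 3000 + 200, C uses 20 + 10.
reserved, used = reserved_vs_used([200, 3200, 30], max_seq_len=8192)
print(f"reserved slots: {reserved}")
print(f"used slots:     {used}")
print(f"wasted:         {1 - used / reserved:.1%}")
```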

5.3 Paged KV cache

Paged KV cache splits memory into fixed-size blocks. Instead of giving each request one giant contiguous allocation, the runtime gives each request a list of blocks.

KV memory pool:
  [block 0] [block 1] [block 2] [block 3] [block 4] [block 5]

Request A:
  block 0 -> block 3

Request B:
  block 1 -> block 2 -> block 5

Request C:
  block 4

The request still has a logical sequence of tokens, but physically its cache can be scattered across blocks. The runtime keeps a block table:

Request B logical blocks:
  logical block 0 -> physical block 1
  logical block 1 -> physical block 2
  logical block 2 -> physical block 5

This is similar in spirit to virtual memory paging. The benefits are less wasted memory, easier allocation/freeing, better support for variable-length requests, and better support for prefix sharing. The cost is more complex attention kernels, block-table indirection, and a new tuning parameter: block size.
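
A minimal sketch of the bookkeeping, assuming a single pool of equally sized blocks and one block table per request. The class and method names are hypothetical and do not correspond to any engine's API; the sketch tracks only where tokens would live, not the K/V tensors themselves.

```python
class PagedKVAllocator:
    """Toy block allocator: hands out physical blocks on demand."""

    def __init__(self, num_blocks: int, block_size: int) -> None:
        self.block_size = block_size
        self.free_blocks = list(range(num_blocks))  # physical block ids
        self.block_tables: dict[str, list[int]] = {}  # logical -> physical
        self.num_tokens: dict[str, int] = {}

    def append_token(self, request_id: str) -> tuple[int, int]:
        """Reserve a slot for one more token; return (physical_block, offset)."""
        table = self.block_tables.setdefault(request_id, [])
        count = self.num_tokens.get(request_id, 0)
        offset = count % self.block_size
        if offset == 0:  # first token, or the previous block is full
            if not self.free_blocks:
                raise MemoryError("no free KV blocks")
            table.append(self.free_blocks.pop(0))
        self.num_tokens[request_id] = count + 1
        return table[-1], offset

    def free(self, request_id: str) -> None:
        """Return all of a request's blocks to the pool."""
        self.free_blocks.extend(self.block_tables.pop(request_id, []))
        self.num_tokens.pop(request_id, None)


allocator = PagedKVAllocator(num_blocks=6, block_size=4)
for _ in range(5):
    allocator.append_token("A")
for _ in range(9):
    allocator.append_token("B")
print("A block table:", allocator.block_tables["A"])
print("B block table:", allocator.block_tables["B"])
allocator.free("A")
print("free blocks after A finishes:", allocator.free_blocks)
```

Because a request only holds the blocks it has actually filled, a finished request's blocks become available to others immediately.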

5.4 Block size tradeoff

If block_size = 16, a request with 35 tokens needs:

ceil(35 / 16) = 3 blocks

The last block is partially empty:

block 0: 16 tokens used
block 1: 16 tokens used
block 2: 3 tokens used, 13 token slots unused

Small blocks reduce wasted space in the last block and enable finer-grained sharing, but they require larger block tables and more metadata. Large blocks have less metadata and can be friendlier to memory locality, but waste more space when sequences do not fill the last block. The best block size depends on the model, workload, kernels, and hardware.
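
The arithmetic for a few block sizes can be tabulated with a short helper. The function name is illustrative, and the sketch covers only the internal waste in the last block, not metadata or kernel effects.

```python
import math


def blocks_and_unused_slots(num_tokens: int, block_size: int) -> tuple[int, int]:
    """Blocks needed for a sequence and the token slots left empty in the last one."""
    num_blocks = math.ceil(num_tokens / block_size)
    unused = num_blocks * block_size - num_tokens
    return num_blocks, unused


for block_size in (8, 16, 32, 64):
    num_blocks, unused = blocks_and_unused_slots(35, block_size)
    print(f"block_size={block_size:>2}: {num_blocks} blocks, {unused} unused slots")
```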

5.5 Continuous batching and KV cache

KV cache is tightly connected to batching. A naive server might form a batch and wait until every request in it finishes:

Batch 1:
  Request A
  Request B
  Request C

Wait until A, B, and C all finish.

Batch 2:
  Request D
  Request E
  Request F

This wastes GPU time because requests have different output lengths. Modern LLM servers use continuous batching, where the active batch can change at every iteration:

Step 0:
  A, B, C

Step 10:
  B, C, D
  A finished, D joined

Step 30:
  B, D, E
  C finished, E joined

Step 100:
  D, E, F
  B finished, F joined

The scheduler can change the batch, but each request keeps its own KV cache:

Request A:
  tokens generated: 10
  KV blocks: [0, 3]

Request B:
  tokens generated: 100
  KV blocks: [1, 2, 5, 8, 9]

Request C:
  tokens generated: 30
  KV blocks: [4, 7]

This is one of the central runtime problems in LLM serving: the batch is dynamic, while the cache state for each request must remain correct.
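
The scheduling idea can be sketched as a small simulation. It assumes every request needs at least one output token, ignores prefill, memory limits, and preemption, and uses a hypothetical function name; it only shows how the active batch changes from step to step.

```python
from collections import deque


def simulate_continuous_batching(
    output_lengths: dict[str, int], max_batch: int
) -> list[list[str]]:
    """Return the list of active request ids at each decode step."""
    waiting = deque(output_lengths)  # request ids in arrival order
    remaining = dict(output_lengths)  # tokens each request still has to generate
    active: list[str] = []
    history: list[list[str]] = []
    while waiting or active:
        # Admit waiting requests into any free batch slots.
        while waiting and len(active) < max_batch:
            active.append(waiting.popleft())
        history.append(list(active))
        # One decode step: every active request generates one token.
        for request_id in active:
            remaining[request_id] -= 1
        # Finished requests leave right away, freeing slots for the next step.
        active = [r for r in active if remaining[r] > 0]
    return history


steps = simulate_continuous_batching({"A": 1, "B": 3, "C": 2, "D": 2}, max_batch=3)
for step, batch in enumerate(steps):
    print(f"step {step}: {batch}")
```

In a real server, each id in the batch would also carry its block table, so the attention kernel can find that request's K/V regardless of which other requests share the step.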

5.6 Chunked prefill and KV cache

Long prompts can monopolize the GPU. For example:

Request A:
  100,000-token prompt

Request B:
  short chat request

Request C:
  short chat request

If the server processes all of A’s prefill at once, B and C may wait too long. Chunked prefill splits the long prompt into pieces:

A prefill chunk 0:
  tokens 0..4095

Decode some other requests.

A prefill chunk 1:
  tokens 4096..8191

Decode some other requests.

A prefill chunk 2:
  tokens 8192..12287

The cache is built incrementally. After chunk 0, it contains tokens 0..4095; after chunk 1, it contains tokens 0..8191, and so on. Chunked prefill improves scheduling fairness, but it also makes the scheduler more complex because partial-prefill requests and decode requests must be interleaved correctly.
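
The chunk boundaries follow directly from the prompt length and the chunk size. A minimal sketch, with a hypothetical helper name and the 4,096-token chunk size from the example above:

```python
def prefill_chunks(prompt_len: int, chunk_size: int) -> list[tuple[int, int]]:
    """Inclusive (first_token, last_token) ranges for each prefill chunk."""
    return [
        (start, min(start + chunk_size, prompt_len) - 1)
        for start in range(0, prompt_len, chunk_size)
    ]


chunks = prefill_chunks(prompt_len=100_000, chunk_size=4096)
print(f"number of chunks: {len(chunks)}")
print(f"first three: {chunks[:3]}")
print(f"last: {chunks[-1]}")
```

The last chunk is usually shorter than the others, and the scheduler is free to run decode steps for other requests between any two chunks.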

5.7 The lifecycle of a request’s KV cache

A request’s cache usually goes through this lifecycle:

1. Request arrives.
2. Tokens are produced by the tokenizer.
3. Scheduler accepts the request.
4. KV blocks are allocated.
5. Prefill writes prompt K/V.
6. Decode appends generated-token K/V.
7. Request finishes, errors, or is cancelled.
8. KV blocks are freed or kept for prefix reuse.
9. Reusable blocks may later be evicted.

A single server may run this lifecycle for thousands of requests over time, with many of them active simultaneously.


6. Reuse, eviction, offloading, and quantization

Once the cache is represented as blocks, the runtime can do more than allocate and free. It can reuse common prefixes, protect cache sharing between tenants, evict old blocks, offload cold blocks, or store K/V in lower precision.

6.1 Prefix caching

Many requests share the same prefix. For example:

System prompt:
  You are a helpful assistant. Answer carefully.

User request A:
  You are a helpful assistant. Answer carefully.
  Explain KV cache.

User request B:
  You are a helpful assistant. Answer carefully.
  Explain attention.

User request C:
  You are a helpful assistant. Answer carefully.
  Explain quantization.

Without prefix caching, each request computes KV for the same system prompt. With prefix caching, the first request computes it and later requests reuse it.

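This can improve time to first token, throughput, and GPU utilization, and it reduces prefill cost. The match must happen at the token level, not just the text level: cached K/V can be reused only where the token IDs are identical from the start of the sequence. For example, a trailing space may change tokenization: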

"You are helpful."
"You are helpful. "

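One way to implement token-level matching on top of a paged cache is to hash each full block of token IDs together with everything before it, then look those hashes up. The sketch below is an assumption about how such a lookup could work, not a description of any specific engine; the token IDs, block size, and function names are all made up.

```python
import hashlib


def block_hashes(token_ids: list[int], block_size: int) -> list[str]:
    """One hash per full block; each hash also covers all earlier blocks."""
    hashes = []
    running = hashlib.sha256()
    for b in range(len(token_ids) // block_size):
        block = token_ids[b * block_size : (b + 1) * block_size]
        running.update(repr(block).encode())
        hashes.append(running.hexdigest())
    return hashes


def count_reusable_tokens(
    cached: set[str], token_ids: list[int], block_size: int
) -> int:
    """Number of leading tokens whose K/V blocks are already cached."""
    reused_blocks = 0
    for block_hash in block_hashes(token_ids, block_size):
        if block_hash not in cached:
            break  # a prefix match must be contiguous from the start
        reused_blocks += 1
    return reused_blocks * block_size


system_prompt = [11, 12, 13, 14, 15, 16, 17, 18]  # hypothetical token ids
request_a = system_prompt + [21, 22, 23]
request_b = system_prompt + [31, 32, 33, 34]

cached = set(block_hashes(request_a, block_size=4))
print("tokens B can reuse:", count_reusable_tokens(cached, request_b, block_size=4))
```

Because each hash depends on all earlier tokens, a block with the same contents at a different position, or after a different prefix, does not match.
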
6.3 Cache eviction

GPU memory is limited. When the runtime needs a new KV block and no free blocks remain, it must evict something.

A simple policy is LRU:

Evict the least recently used reusable block.

Production systems may also consider priority, prefix length, tenant, expected reuse probability, block age, memory pressure, and offload cost.

A useful distinction is:

Active request cache:
  Needed right now for ongoing generation.

Reusable prefix cache:
  Belongs to a completed or shared prefix.
  Useful, but not strictly required.

Evicting reusable cache is usually fine. Evicting active request cache is harder; the system either cannot evict it or must pause/offload the request.
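
A minimal LRU sketch that respects this distinction: only reusable blocks are candidates, and active blocks are never evicted. The class is hypothetical and tracks block ids only; real policies weigh the additional factors listed above.

```python
from collections import OrderedDict


class ReusableBlockLRU:
    """Tracks which reusable blocks may be evicted, oldest first."""

    def __init__(self) -> None:
        self._reusable: OrderedDict[int, None] = OrderedDict()
        self._active: set[int] = set()

    def mark_active(self, block_id: int) -> None:
        """A running request uses this block, so it is not evictable."""
        self._reusable.pop(block_id, None)
        self._active.add(block_id)

    def release(self, block_id: int) -> None:
        """The request is done; keep the block around as most recently used."""
        self._active.discard(block_id)
        self._reusable[block_id] = None
        self._reusable.move_to_end(block_id)

    def evict(self) -> int:
        """Remove and return the least recently used reusable block."""
        if not self._reusable:
            raise MemoryError("only active blocks remain; nothing can be evicted")
        block_id, _ = self._reusable.popitem(last=False)
        return block_id


lru = ReusableBlockLRU()
for block_id in (1, 2, 3):
    lru.release(block_id)
lru.mark_active(1)  # a new request hits the prefix stored in block 1
lru.release(1)      # it finishes, so block 1 is now the most recent
print("evicted:", lru.evict())
```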

6.4 KV cache offloading

Offloading means moving some KV cache from GPU memory to CPU memory. GPU memory is fast but limited; CPU memory is larger but slower.

A typical strategy is:

Active hot KV cache:
  keep on GPU

Reusable but cold KV cache:
  move to CPU

If needed later:
  copy back to GPU

This can help when prefix reuse is common, GPU memory is tight, prompts are long, and CPU-GPU transfer overhead is acceptable. It can hurt when active cache moves too often, because transfers between CPU and GPU are expensive. Offloading is a memory-latency tradeoff.

6.5 Quantized KV cache

The cache is often stored in fp16 or bf16, where each value uses 2 bytes. Some systems can store K/V in lower precision:

fp8: 1 byte per value
int8: 1 byte per value
fp4/int4: 0.5 bytes per value

For example:

BF16 KV cache: 64 GiB
FP8 KV cache: 32 GiB
INT4 KV cache: 16 GiB

This is approximate because real quantized formats may need scales, metadata, alignment, or special layouts.

KV quantization is especially useful for long context, large batch size, high concurrency, and memory-limited serving. It can affect quality, attention accuracy, kernel support, latency, and memory bandwidth, so it should be measured rather than assumed to be free.
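
To make the scale metadata and the accuracy cost concrete, here is a sketch of one simple scheme: symmetric int8 quantization with a single scale per vector. This is an illustrative choice, not the format used by any particular engine, and the function names are hypothetical. It expects a non-empty vector.

```python
def quantize_int8(values: list[float]) -> tuple[list[int], float]:
    """Map floats to integers in [-127, 127] using one shared scale."""
    max_abs = max(abs(v) for v in values)
    scale = max_abs / 127 if max_abs > 0 else 1.0
    quantized = [max(-127, min(127, round(v / scale))) for v in values]
    return quantized, scale


def dequantize_int8(quantized: list[int], scale: float) -> list[float]:
    """Recover approximate floats; the rounding error does not come back."""
    return [q * scale for q in quantized]


key_vector = [0.5, -1.0, 0.25, 0.0]  # made-up values standing in for one K vector
quantized, scale = quantize_int8(key_vector)
restored = dequantize_int8(quantized, scale)
max_error = max(abs(a - b) for a, b in zip(key_vector, restored))
print("quantized:", quantized)
print("scale:", scale)
print("max round-trip error:", max_error)
```

The scale has to be stored next to the quantized values, which is one reason the real saving is slightly less than the raw byte count suggests.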


7. Conclusion

KV cache is the memory of the generation process. For each previous token, in each Transformer layer, the model stores the attention keys and values. When generating the next token, it reuses those tensors instead of recomputing the whole prefix.

That is what makes autoregressive generation practical. The tradeoff is that the cache grows with context length and with the number of active requests, so serving becomes a memory-management problem as much as a model-compute problem.

The core idea to remember is simple: KV cache saves repeated computation by keeping reusable attention tensors around. It makes decoding fast enough to be useful, and it also creates one of the main bottlenecks in LLM serving.