KV Cache Design and Management

Every transformer-based language model has a dirty secret: the attention mechanism is fundamentally stateful. Each time the model generates a new token, it needs to look back at every previous token’s key and value projections. Without caching, you would re-compute K and V for the entire sequence on every single decode step, turning the O(n) cost of one decode step into O(n^2) – and the cost of generating n tokens from O(n^2) into O(n^3) overall. The KV cache is what prevents that – it stores previously computed key and value tensors so the model only computes the new token’s K and V, then attends over the full cached history.

In this chapter, we will walk through exactly how Akunu designs, allocates, and manages its KV cache. The design philosophy is relentlessly simple: no virtual calls, no optionals, no reference counting. A flat POD struct with contiguous GPU buffers.

What the KV Cache Actually Stores

At every transformer layer, the attention mechanism projects the hidden state into three matrices: Q (query), K (key), and V (value). During decode, we only compute Q/K/V for the current token, but we need the K and V from all previous tokens to compute attention scores. So we keep a running buffer of K and V per layer.

Here is what that looks like conceptually:

For each layer l in [0, n_layers):
    K cache[l] = all past key vectors for layer l
    V cache[l] = all past value vectors for layer l

When generating token at position t, the model:

  1. Computes Q_t, K_t, V_t from the current hidden state
  2. Writes K_t, V_t into the cache at position t
  3. Reads the full K[0..t], V[0..t] from the cache
  4. Computes attention: softmax(Q_t @ K[0..t]^T / sqrt(d)) @ V[0..t]

The cache grows by one position per generated token. Without it, you would need to re-run the entire prompt through every layer on every step.
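The four decode steps above can be sketched as a plain CPU reference – single head, FP32 instead of FP16 for clarity, and hypothetical names (attend_decode is not Akunu’s kernel API), but step-for-step the same computation:

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

// One decode step of attention for a single head over a KV cache.
// k_cache / v_cache hold `len` past positions of `dim` floats each,
// position-contiguous, exactly as step 3 above reads them.
std::vector<float> attend_decode(const std::vector<float>& q,
                                 const std::vector<float>& k_cache,
                                 const std::vector<float>& v_cache,
                                 int len, int dim) {
    // Score each cached position: q . k_p / sqrt(dim)
    std::vector<float> scores(len);
    float scale = 1.0f / std::sqrt((float)dim);
    for (int p = 0; p < len; p++) {
        float dot = 0.0f;
        for (int d = 0; d < dim; d++) dot += q[d] * k_cache[p * dim + d];
        scores[p] = dot * scale;
    }
    // Softmax over positions (subtract max for numerical stability)
    float max_s = scores[0];
    for (int p = 1; p < len; p++) max_s = std::max(max_s, scores[p]);
    float sum = 0.0f;
    for (int p = 0; p < len; p++) {
        scores[p] = std::exp(scores[p] - max_s);
        sum += scores[p];
    }
    for (int p = 0; p < len; p++) scores[p] /= sum;
    // Weighted sum of cached values
    std::vector<float> out(dim, 0.0f);
    for (int p = 0; p < len; p++)
        for (int d = 0; d < dim; d++) out[d] += scores[p] * v_cache[p * dim + d];
    return out;
}
```

The GPU kernel fuses these loops and tiles them across SIMD groups, but the data dependency is identical: the full cached K and V history is read on every step.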

The KVCache Struct

Let us look at Akunu’s actual implementation. The struct lives in src/cache/kv_cache.h and it is remarkably compact:

struct KVCache {
    int n_layers;
    int n_kv_heads;
    int head_dim;
    int max_length;
    int current_length;

    std::vector<Buffer> k_buffers;  // one per layer
    std::vector<Buffer> v_buffers;  // one per layer

    int kv_stride;  // max_length * head_dim
};

That is it. Five integers, two vectors of GPU buffers, and a pre-computed stride. No inheritance hierarchy. No smart pointers. No allocators or memory pools. Just the data you need and nothing else.

Let us break each field down:

n_layers: The number of transformer layers. For LLaMA 3-8B, this is 32. For a 70B model, 80. Each layer gets its own independent K buffer and V buffer.

n_kv_heads: The number of key/value heads. In grouped-query attention (GQA), this is smaller than the number of query heads. LLaMA 3-8B has 32 query heads but only 8 KV heads – a 4:1 ratio that saves 75% of the cache memory.

head_dim: The dimension of each attention head. Typically 128 for modern models (e.g., 4096 / 32 = 128 for LLaMA 3-8B).

max_length: The maximum sequence length the cache can hold. Typically set to the model’s context window (e.g., 4096 or 8192).

current_length: The stateful counter – how many positions have been written so far. Starts at 0, incremented by advance(), never exceeds max_length.

kv_stride: Pre-computed as max_length * head_dim. This is the number of FP16 elements between consecutive KV heads in memory.

Memory Layout: Head-Major Ordering

Each K and V buffer has the following shape:

[n_kv_heads, max_seq_len, head_dim]

All stored in FP16 (half-precision, 2 bytes per element). Let us draw what this looks like in memory for a single layer:

Buffer: k_buffers[layer]

+------------------------------------------------------------------+
|                          KV Head 0                                |
|  +-----------------------------------------------------------+   |
|  | pos 0: [d0, d1, d2, ..., d127]  (head_dim=128, FP16)     |   |
|  | pos 1: [d0, d1, d2, ..., d127]                            |   |
|  | pos 2: [d0, d1, d2, ..., d127]                            |   |
|  | ...                                                        |   |
|  | pos max_length-1: [d0, d1, ..., d127]                     |   |
|  +-----------------------------------------------------------+   |
|                          KV Head 1                                |
|  +-----------------------------------------------------------+   |
|  | pos 0: [d0, d1, d2, ..., d127]                            |   |
|  | pos 1: [d0, d1, d2, ..., d127]                            |   |
|  | ...                                                        |   |
|  +-----------------------------------------------------------+   |
|  ...                                                              |
|                          KV Head (n_kv_heads-1)                   |
|  +-----------------------------------------------------------+   |
|  | pos 0: [d0, d1, d2, ..., d127]                            |   |
|  | ...                                                        |   |
|  +-----------------------------------------------------------+   |
+------------------------------------------------------------------+

The linear address of element [head h, position p, dimension d] is:

offset = (h * kv_stride + p * head_dim + d) * sizeof(FP16)
       = (h * max_length * head_dim + p * head_dim + d) * 2
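A tiny helper makes the addressing concrete. The name kv_offset is illustrative (not from Akunu’s source); the arithmetic mirrors the formula above:

```cpp
#include <cassert>
#include <cstddef>

// Byte offset of element [head h, position p, dimension d] in a
// head-major K or V buffer. FP16 elements are 2 bytes.
size_t kv_offset(int h, int p, int d, int max_length, int head_dim) {
    size_t kv_stride = (size_t)max_length * head_dim;  // elements per head
    return (h * kv_stride + (size_t)p * head_dim + d) * 2;
}
```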

Why Head-Major?

You might ask: why not position-major [max_seq_len, n_kv_heads, head_dim]? The answer comes down to how attention kernels access memory.

During decode, each attention head operates independently. The flash attention decode kernel processes one head at a time. For a given head h, it needs to read all positions [0..current_length] for that head in sequence. With head-major layout, these positions are contiguous in memory:

Head-major (what Akunu uses):
  Reading head h, positions 0..T = contiguous read of T*head_dim elements
  Great for coalesced GPU memory access!

  Head 0: [pos0][pos1][pos2]...[posT][---padding---]
  Head 1: [pos0][pos1][pos2]...[posT][---padding---]
  ...

Position-major (alternative):
  Reading head h, positions 0..T = strided read with stride n_kv_heads*head_dim
  Terrible for coalesced access!

  Pos 0: [head0][head1]...[headN]
  Pos 1: [head0][head1]...[headN]
  ...

With head-major layout, each SIMD group in the attention kernel reads a nice contiguous chunk of memory. The kv_stride field (= max_length * head_dim) gives the distance between head 0’s data and head 1’s data, which the kernel uses to index into the right region.
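The stride difference is easy to verify numerically. These two illustrative helpers (not Akunu code) compute element offsets under each layout; note how far apart two consecutive positions of the same head land:

```cpp
#include <cassert>

// Element index of [h, p, d] under head-major layout
// [n_kv_heads, max_length, head_dim].
int head_major_offset(int h, int p, int d, int max_length, int head_dim) {
    return h * max_length * head_dim + p * head_dim + d;
}

// Element index of [h, p, d] under the alternative position-major
// layout [max_length, n_kv_heads, head_dim].
int pos_major_offset(int h, int p, int d, int n_kv_heads, int head_dim) {
    return p * n_kv_heads * head_dim + h * head_dim + d;
}
```

For one head, consecutive positions are head_dim elements apart in head-major layout but n_kv_heads * head_dim apart in position-major layout – the kernel’s sequential read is contiguous in the former and strided in the latter.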

Memory Budget Calculation

Let us work through a concrete example. Take LLaMA 3.1-8B with a 4096 context:

Parameters:
  n_layers    = 32
  n_kv_heads  = 8    (GQA: 32 query heads, 8 KV heads)
  head_dim    = 128
  max_length  = 4096
  dtype       = FP16 (2 bytes)

Per-layer buffer size:
  size = n_kv_heads * max_length * head_dim * sizeof(FP16)
       = 8 * 4096 * 128 * 2
       = 8,388,608 bytes
       = 8 MB

Total KV cache (K + V, all layers):
  total = 2 * n_layers * size
        = 2 * 32 * 8 MB
        = 512 MB

That is 512 MB just for the KV cache on a 4096 context. Scale up to 8192 context and you are at 1 GB. For a 70B model with 80 layers and 8 KV heads:

Per-layer: 8 * 8192 * 128 * 2 = 16 MB
Total: 2 * 80 * 16 MB = 2,560 MB = 2.5 GB

This is why GQA is so important – without it (i.e., multi-head attention where n_kv_heads = n_heads = 64), the 70B model would need:

64 * 8192 * 128 * 2 = 128 MB per layer
2 * 80 * 128 MB = 20,480 MB = 20 GB

GQA with 8 KV heads reduces cache memory by 8x. That is the difference between fitting on a MacBook Pro and not fitting at all.

Here is a summary table:

+----------------+--------+---------+--------+---------+---------+
| Model          | Layers | KV Heads| Head D | Ctx Len | KV Size |
+----------------+--------+---------+--------+---------+---------+
| LLaMA 3-8B     |   32   |    8    |  128   |  4096   |  512 MB |
| LLaMA 3-8B     |   32   |    8    |  128   |  8192   |   1 GB  |
| Qwen 2.5-7B    |   28   |    4    |  128   |  4096   |  224 MB |
| Gemma 3-4B     |   34   |    4    |  256   |  4096   |  544 MB |
| LLaMA 3-70B    |   80   |    8    |  128   |  4096   | 1.25 GB |
| LLaMA 3-70B    |   80   |    8    |  128   |  8192   |  2.5 GB |
+----------------+--------+---------+--------+---------+---------+
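The arithmetic behind the table fits in one function. This is a sketch for checking budgets (kv_cache_bytes is not part of Akunu’s API):

```cpp
#include <cassert>
#include <cstddef>

// Total KV cache size in bytes: K and V buffers for every layer,
// each [n_kv_heads, max_length, head_dim] in FP16 (2 bytes/element).
size_t kv_cache_bytes(int n_layers, int n_kv_heads, int head_dim,
                      int max_length) {
    size_t per_buffer = (size_t)n_kv_heads * max_length * head_dim * 2;
    return 2 * (size_t)n_layers * per_buffer;  // x2 for K and V
}
```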

Pre-Allocation: No Malloc in the Hot Path

The KVCache::create() factory allocates everything upfront:

static KVCache create(Device& device, int n_layers, int n_kv_heads,
                      int head_dim, int max_length) {
    KVCache cache;
    // ... set fields ...

    size_t buf_size = (size_t)n_kv_heads * max_length * head_dim * sizeof(uint16_t);

    cache.k_buffers.resize(n_layers);
    cache.v_buffers.resize(n_layers);
    for (int i = 0; i < n_layers; i++) {
        cache.k_buffers[i] = device.allocate(buf_size);
        cache.v_buffers[i] = device.allocate(buf_size);
        memset(cache.k_buffers[i].contents, 0, buf_size);
        memset(cache.v_buffers[i].contents, 0, buf_size);
    }
    return cache;
}

Note several things:

  1. All buffers are allocated at once. No lazy allocation, no on-demand growth. You pay the memory cost at model load time, not during generation.

  2. Buffers are zero-filled. The memset ensures that unwritten positions have zero values. This matters because the attention kernel might read past current_length due to SIMD alignment, and we do not want garbage affecting softmax scores.

  3. The sizeof(uint16_t) reflects FP16 storage. Akunu stores cache values in half precision. There is no option for FP32 or INT8 cache – keeping it simple.

  4. Device::allocate() returns a Buffer. On Metal, this is an MTLBuffer allocated in shared memory (Apple Silicon UMA), meaning both CPU and GPU can access it without explicit copies.

Stateful Tracking

The KV cache tracks how many positions have been filled via current_length. The API provides four operations for managing this state:

advance(count)

After computing K and V for count new tokens (1 during decode, N during prefill), call advance() to update the position:

void advance(int count) {
    current_length += count;
    if (current_length > max_length)
        current_length = max_length;
}

The clamping to max_length is a safety measure. In practice, the caller should check would_overflow() before adding tokens, but the clamp prevents buffer overruns if something goes wrong.

would_overflow(additional)

bool would_overflow(int additional) const {
    return current_length + additional > max_length;
}

The caller checks this before prefill or decode to avoid writing past the buffer. If it returns true, the inference engine must either truncate the input or refuse the request.

rollback(to_length)

void rollback(int to_length) {
    if (to_length < current_length)
        current_length = to_length;
}

Rollback moves the cursor backwards. The actual data in the buffers is not erased – only the position counter changes. This is safe because future writes at positions >= to_length will overwrite the stale data before it is read.

This is used for prefix caching in the server: if the new prompt shares a prefix with the previous one, we rollback to the shared prefix length and only re-compute the divergent tokens.

reset()

void reset() { current_length = 0; }

The nuclear option. Resets to the beginning without touching the actual buffer contents. The next prefill will overwrite everything.
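Stripped of the GPU buffers, the four operations form a small state machine that can be exercised on its own. This standalone sketch (not the actual KVCache struct) reproduces just the counter logic from the listings above:

```cpp
#include <algorithm>
#include <cassert>

// Position tracking only: the same advance/would_overflow/rollback/reset
// semantics as KVCache, minus the per-layer buffers.
struct CacheCursor {
    int current_length = 0;
    int max_length = 0;

    bool would_overflow(int additional) const {
        return current_length + additional > max_length;
    }
    void advance(int count) {
        // clamp: safety net, callers should check would_overflow() first
        current_length = std::min(current_length + count, max_length);
    }
    void rollback(int to_length) {
        // only ever moves the cursor backwards
        if (to_length < current_length) current_length = to_length;
    }
    void reset() { current_length = 0; }
};
```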

The Lifecycle of a Cache During Generation

Here is the full flow from prompt to generation:

1. User sends prompt: "What is the capital of France?"
   Tokenized: [BOS, 1724, 338, 278, 7483, 310, 3444, 29973]  (8 tokens)

2. PREFILL: process all 8 tokens in one batch
   For each layer:
     Compute K[0..7], V[0..7]        <-- batch of 8 KV vectors
     Write into cache at pos 0..7
   cache.advance(8)                   <-- current_length = 8

3. DECODE step 1: generate token "The"
   For each layer:
     Compute K[8], V[8]              <-- single new KV vector
     Write into cache at pos 8
     Attend Q[8] over K[0..8], V[0..8]
   cache.advance(1)                   <-- current_length = 9

4. DECODE step 2: generate token " capital"
   For each layer:
     Compute K[9], V[9]
     Write into cache at pos 9
     Attend Q[9] over K[0..9], V[0..9]
   cache.advance(1)                   <-- current_length = 10

5. ... continue until EOS or max_tokens ...

6. NEXT CONVERSATION:
   New prompt: "What is the capital of Germany?"
   Shares prefix: "What is the capital of " = 7 tokens

   Option A: cache.reset() + full prefill (simple)
   Option B: rollback to shared prefix + incremental prefill (efficient)

   With prefix caching (Option B):
     shared = 7 tokens match
     cache.rollback(7)                <-- current_length = 7
     Prefill only tokens 7..N         <-- "Germany?" = 2 tokens
     cache.advance(2)                 <-- current_length = 9

Here is a timeline diagram:

Position in KV cache:
     0    1    2    3    4    5    6    7    8    9    10   11
     |    |    |    |    |    |    |    |    |    |    |    |
     [BOS][What][ is][ the][cap][ital][ of][ Fr][The][ ca][pit]
     |----- prefill (8 tokens) -----| |-- decode step by step--|
                                      ^
                              current_length advances: 8 -> 9 -> 10 -> 11

After rollback(7) for new conversation:
     0    1    2    3    4    5    6    7    8
     |    |    |    |    |    |    |    |    |
     [BOS][What][ is][ the][cap][ital][ of][Ger][many]
     |-- preserved prefix (7) --|     ^
                                  incremental prefill: 2 tokens

Prefix Caching in the Server

The HTTP server (Chapter 50) maintains a ModelEntry per loaded model that tracks the last prompt’s tokens:

struct ModelEntry {
    std::vector<uint32_t> cached_tokens;
    int cached_position = 0;

    int shared_prefix(const uint32_t *tokens, int n_tokens) const {
        int shared = 0;
        int limit = std::min((int)cached_tokens.size(), n_tokens);
        for (int i = 0; i < limit; i++) {
            if (cached_tokens[i] != tokens[i]) break;
            shared++;
        }
        return shared;
    }
};

When a new request arrives:

  1. Encode the new prompt to tokens
  2. Compare with cached_tokens to find the shared prefix length
  3. If shared > 0 && shared <= cached_position, use rollback + incremental prefill
  4. Otherwise, full reset + prefill from scratch

This gives you “free” prefix caching with zero extra infrastructure. In a chatbot scenario where the system prompt is the same across turns, you skip re-processing hundreds of tokens. For a 2048-token system prompt, that can save 50-100ms of prefill time on each request.
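The request-handling decision above can be condensed into one function. This is a hypothetical sketch (tokens_to_prefill is not the server’s actual code) of combining the prefix comparison with the prefill decision:

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Given the previous request's token sequence and the incoming one,
// return how many tokens still need prefilling. The caller would
// rollback the cache to the shared prefix length, then prefill the rest.
int tokens_to_prefill(const std::vector<uint32_t>& cached,
                      const std::vector<uint32_t>& incoming) {
    size_t limit = std::min(cached.size(), incoming.size());
    size_t shared = 0;
    while (shared < limit && cached[shared] == incoming[shared]) shared++;
    // cache.rollback((int)shared) would happen here
    return (int)(incoming.size() - shared);
}
```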

How the Attention Kernel Uses the Cache

During decode, the flash attention kernel receives the cache buffers as arguments. Here is a simplified view of the kernel parameters:

flash_attention_decode_fast_f16(
    Q:           [n_heads, 1, head_dim]        <-- current token's query
    K_cache:     [n_kv_heads, max_seq, head_dim]  <-- full K cache for this layer
    V_cache:     [n_kv_heads, max_seq, head_dim]  <-- full V cache for this layer
    output:      [n_heads, 1, head_dim]        <-- attention output
    kv_seq_len:  current_length + 1            <-- how far to read
    kv_stride:   max_length * head_dim         <-- stride between heads
    scale:       1.0 / sqrt(head_dim)
)

The kernel uses kv_stride to jump between heads and kv_seq_len to know how many positions to attend over. It reads K[h, 0..kv_seq_len-1, :] and V[h, 0..kv_seq_len-1, :] contiguously for each head h.

For GQA (grouped-query attention), where multiple Q heads share a single KV head, the kernel maps Q head index to KV head index with integer division: kv_head = q_head / (n_heads / n_kv_heads).
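That mapping is a one-liner; as a standalone sketch:

```cpp
#include <cassert>

// GQA head mapping: consecutive groups of query heads share one KV head.
int kv_head_for(int q_head, int n_heads, int n_kv_heads) {
    int group = n_heads / n_kv_heads;  // e.g. 32 / 8 = 4 Q heads per KV head
    return q_head / group;
}
```

For LLaMA 3-8B (32 query heads, 8 KV heads), query heads 0-3 all read KV head 0, heads 4-7 read KV head 1, and so on.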

Cleanup

The destroy() method frees all GPU buffers:

void destroy(Device& device) {
    for (auto& b : k_buffers) device.free_buffer(b);
    for (auto& b : v_buffers) device.free_buffer(b);
    k_buffers.clear();
    v_buffers.clear();
}

No destructor magic. The caller is responsible for calling destroy() before the Device goes away. This is deliberate – GPU resource lifetimes must be explicit in a system without garbage collection.

What Akunu Does NOT Do

It is worth noting what this KV cache design deliberately omits:

  1. No paged attention. Systems like vLLM use virtual memory paging to efficiently share cache across sequences. Akunu allocates one flat buffer per layer, trading memory efficiency for simplicity and zero fragmentation.

  2. No multi-sequence support. The cache tracks a single current_length. There is no batch dimension or per-sequence tracking. For serving multiple concurrent conversations, you would need multiple model instances.

  3. No quantized cache. Some inference engines store KV in INT8 or INT4 to reduce memory. Akunu keeps everything in FP16 for maximum quality and simplicity.

  4. No sliding window. Some architectures (Mistral) use a sliding window where old positions are evicted. Akunu’s cache is a simple grow-only buffer with a hard maximum.

  5. No speculative decoding cache management. Systems that do speculative decoding need to speculatively advance and then rollback the cache. Akunu’s rollback() could support this, but the current codebase does not implement speculative decoding.

These are all conscious trade-offs. For a single-user inference engine targeting Apple Silicon, the simple flat-buffer approach gives you maximum GPU throughput (contiguous memory access) and zero overhead (no bookkeeping, no page tables).

Summary

+-------------------------------+----------------------------------+
| Design Decision               | Rationale                        |
+-------------------------------+----------------------------------+
| POD struct, no inheritance    | Cache-line friendly, no vtable   |
| Head-major layout             | Contiguous reads per attention   |
|                               | head in flash attention kernel   |
| FP16 storage                  | Half the size of FP32; native    |
|                               | Metal half-precision support     |
| Pre-allocate max_length       | Zero allocation in hot path      |
| Zero-fill on creation         | Safe reads past current_length   |
| Single current_length counter | Simple state machine, no locks   |
| Pre-computed kv_stride        | One less multiply per kernel     |
| Explicit destroy()            | Deterministic GPU resource mgmt  |
+-------------------------------+----------------------------------+

The KV cache is the single largest runtime memory consumer after the model weights themselves. Understanding its layout is essential for reasoning about memory budgets, context window limits, and the performance characteristics of the attention kernel. Next, we will look at the other half of the runtime memory story: scratch buffers.