KV Cache Design and Management
Every transformer-based language model has a dirty secret: the attention mechanism is fundamentally stateful. Each time the model generates a new token, it needs to look back at every previous token’s key and value projections. Without caching, you would re-compute K and V for the entire prompt on every single decode step, turning each O(n) decode step into an O(n^2) one. The KV cache is what prevents that – it stores previously computed key and value tensors so the model only computes the new token’s K and V, then attends over the full cached history.
In this chapter, we will walk through exactly how Akunu designs, allocates, and manages its KV cache. The design philosophy is relentlessly simple: no virtual calls, no optionals, no reference counting. A flat POD struct with contiguous GPU buffers.
What the KV Cache Actually Stores
At every transformer layer, the attention mechanism projects the hidden state into three matrices: Q (query), K (key), and V (value). During decode, we only compute Q/K/V for the current token, but we need the K and V from all previous tokens to compute attention scores. So we keep a running buffer of K and V per layer.
Here is what that looks like conceptually:
For each layer l in [0, n_layers):
K cache[l] = all past key vectors for layer l
V cache[l] = all past value vectors for layer l
When generating token at position t, the model:
- Computes Q_t, K_t, V_t from the current hidden state
- Writes K_t, V_t into the cache at position t
- Reads the full K[0..t], V[0..t] from the cache
- Computes attention: softmax(Q_t @ K[0..t]^T / sqrt(d)) @ V[0..t]
The cache grows by one position per generated token. Without it, you would need to re-run the entire prompt through every layer on every step.
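To make the decode step concrete, here is a minimal CPU reference in plain C++ – a single head, FP32, with illustrative names (attend_decode is not part of Akunu's API; the GPU kernel works very differently, but the math is the same):

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

// One decode step: attend the current token's query over the cached K/V.
// k_cache/v_cache hold rows [0..t] (K_t/V_t already written), each of size d.
std::vector<float> attend_decode(const std::vector<float>& q,        // [d]
                                 const std::vector<float>& k_cache,  // [t+1, d]
                                 const std::vector<float>& v_cache,  // [t+1, d]
                                 int t_plus_1, int d) {
    // scores[p] = (q . k_p) / sqrt(d)
    std::vector<float> scores(t_plus_1);
    float max_s = -1e30f;
    for (int p = 0; p < t_plus_1; p++) {
        float s = 0.f;
        for (int i = 0; i < d; i++) s += q[i] * k_cache[p * d + i];
        scores[p] = s / std::sqrt((float)d);
        max_s = std::max(max_s, scores[p]);
    }
    // softmax over positions 0..t (subtract max for numerical stability)
    float denom = 0.f;
    for (int p = 0; p < t_plus_1; p++) {
        scores[p] = std::exp(scores[p] - max_s);
        denom += scores[p];
    }
    // output = sum_p softmax(p) * v_p
    std::vector<float> out(d, 0.f);
    for (int p = 0; p < t_plus_1; p++)
        for (int i = 0; i < d; i++)
            out[i] += (scores[p] / denom) * v_cache[p * d + i];
    return out;
}
```

The point to notice is the access pattern: the only per-step inputs are the new token's Q and the two cache buffers; everything else is history that was computed exactly once.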
The KVCache Struct
Let us look at Akunu’s actual implementation. The struct lives in
src/cache/kv_cache.h and it is remarkably compact:
struct KVCache {
    int n_layers;
    int n_kv_heads;
    int head_dim;
    int max_length;
    int current_length;

    std::vector<Buffer> k_buffers;  // one per layer
    std::vector<Buffer> v_buffers;  // one per layer

    int kv_stride;  // max_length * head_dim
};
That is it. Five integers, two vectors of GPU buffers, and a pre-computed stride. No inheritance hierarchy. No smart pointers. No allocators or memory pools. Just the data you need and nothing else.
Let us break each field down:
n_layers: The number of transformer layers. For LLaMA 3-8B, this is 32.
For a 70B model, 80. Each layer gets its own independent K buffer and V buffer.
n_kv_heads: The number of key/value heads. In grouped-query attention (GQA),
this is smaller than the number of query heads. LLaMA 3-8B has 32 query heads
but only 8 KV heads – a 4:1 ratio that saves 75% of the cache memory.
head_dim: The dimension of each attention head. Typically 128 for modern
models (e.g., 4096 / 32 = 128 for LLaMA 3-8B).
max_length: The maximum sequence length the cache can hold. Typically set
to the model’s context window (e.g., 4096 or 8192).
current_length: The stateful counter – how many positions have been written
so far. Starts at 0, incremented by advance(), never exceeds max_length.
kv_stride: Pre-computed as max_length * head_dim. This is the number of
FP16 elements between consecutive KV heads in memory.
Memory Layout: Head-Major Ordering
Each K and V buffer has the following shape:
[n_kv_heads, max_seq_len, head_dim]
All stored in FP16 (half-precision, 2 bytes per element). Let us draw what this looks like in memory for a single layer:
Buffer: k_buffers[layer]
+------------------------------------------------------------------+
| KV Head 0 |
| +-----------------------------------------------------------+ |
| | pos 0: [d0, d1, d2, ..., d127] (head_dim=128, FP16) | |
| | pos 1: [d0, d1, d2, ..., d127] | |
| | pos 2: [d0, d1, d2, ..., d127] | |
| | ... | |
| | pos max_length-1: [d0, d1, ..., d127] | |
| +-----------------------------------------------------------+ |
| KV Head 1 |
| +-----------------------------------------------------------+ |
| | pos 0: [d0, d1, d2, ..., d127] | |
| | pos 1: [d0, d1, d2, ..., d127] | |
| | ... | |
| +-----------------------------------------------------------+ |
| ... |
| KV Head (n_kv_heads-1) |
| +-----------------------------------------------------------+ |
| | pos 0: [d0, d1, d2, ..., d127] | |
| | ... | |
| +-----------------------------------------------------------+ |
+------------------------------------------------------------------+
The linear address of element [head h, position p, dimension d] is:
offset = (h * kv_stride + p * head_dim + d) * sizeof(FP16)
= (h * max_length * head_dim + p * head_dim + d) * 2
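That formula is worth a quick sanity check. The helper name below is ours, not Akunu's; it just transcribes the offset arithmetic:

```cpp
#include <cassert>
#include <cstddef>

// Byte offset of element [head h, position p, dimension d] in a head-major
// FP16 K/V buffer, following the formula above.
size_t kv_offset_bytes(int h, int p, int d, int max_length, int head_dim) {
    size_t kv_stride = (size_t)max_length * head_dim;  // elements between heads
    return (h * kv_stride + (size_t)p * head_dim + d) * 2;  // 2 bytes per FP16
}
```

For max_length = 4096 and head_dim = 128, stepping `d` moves 2 bytes, stepping `p` moves 256 bytes, and stepping `h` moves a full megabyte – which is exactly the head-major locality argument made below.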
Why Head-Major?
You might ask: why not position-major [max_seq_len, n_kv_heads, head_dim]?
The answer comes down to how attention kernels access memory.
During decode, each attention head operates independently. The flash attention
decode kernel processes one head at a time. For a given head h, it needs to
read all positions [0..current_length] for that head in sequence. With
head-major layout, these positions are contiguous in memory:
Head-major (what Akunu uses):
Reading head h, positions 0..T = contiguous read of T*head_dim elements
Great for coalesced GPU memory access!
Head 0: [pos0][pos1][pos2]...[posT][---padding---]
Head 1: [pos0][pos1][pos2]...[posT][---padding---]
...
Position-major (alternative):
Reading head h, positions 0..T = strided read with stride n_kv_heads*head_dim
Terrible for coalesced access!
Pos 0: [head0][head1]...[headN]
Pos 1: [head0][head1]...[headN]
...
With head-major layout, each SIMD group in the attention kernel reads a nice
contiguous chunk of memory. The kv_stride field (= max_length * head_dim)
gives the distance between head 0’s data and head 1’s data, which the kernel
uses to index into the right region.
Memory Budget Calculation
Let us work through a concrete example. Take LLaMA 3.1-8B with a 4096 context:
Parameters:
n_layers = 32
n_kv_heads = 8 (GQA: 32 query heads, 8 KV heads)
head_dim = 128
max_length = 4096
dtype = FP16 (2 bytes)
Per-layer buffer size:
size = n_kv_heads * max_length * head_dim * sizeof(FP16)
= 8 * 4096 * 128 * 2
= 8,388,608 bytes
= 8 MB
Total KV cache (K + V, all layers):
total = 2 * n_layers * size
= 2 * 32 * 8 MB
= 512 MB
That is 512 MB just for the KV cache on a 4096 context. Scale up to 8192 context and you are at 1 GB. For a 70B model with 80 layers and 8 KV heads:
Per-layer: 8 * 8192 * 128 * 2 = 16 MB
Total: 2 * 80 * 16 MB = 2,560 MB = 2.5 GB
This is why GQA is so important – without it (i.e., multi-head attention where
n_kv_heads = n_heads = 64), the 70B model would need:
64 * 8192 * 128 * 2 = 128 MB per layer
2 * 80 * 128 MB = 20,480 MB = 20 GB
GQA with 8 KV heads reduces cache memory by 8x. That is the difference between fitting on a MacBook Pro and not fitting at all.
Here is a summary table:
+----------------+--------+---------+--------+---------+---------+
| Model | Layers | KV Heads| Head D | Ctx Len | KV Size |
+----------------+--------+---------+--------+---------+---------+
| LLaMA 3-8B | 32 | 8 | 128 | 4096 | 512 MB |
| LLaMA 3-8B | 32 | 8 | 128 | 8192 | 1 GB |
| Qwen 2.5-7B | 28 | 4 | 128 | 4096 | 224 MB |
| Gemma 3-4B | 34 | 4 | 256 | 4096 | 544 MB |
| LLaMA 3-70B | 80 | 8 | 128 | 4096 | 1.25 GB |
| LLaMA 3-70B | 80 | 8 | 128 | 8192 | 2.5 GB |
+----------------+--------+---------+--------+---------+---------+
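The table's arithmetic is easy to verify in a few lines. This is a hypothetical helper that just encodes the formula worked through above:

```cpp
#include <cassert>
#include <cstddef>

// Total KV cache footprint in bytes: K + V buffers for every layer, FP16.
size_t kv_cache_bytes(int n_layers, int n_kv_heads, int head_dim, int max_length) {
    size_t per_buffer = (size_t)n_kv_heads * max_length * head_dim * 2;  // one K or V buffer
    return 2 * (size_t)n_layers * per_buffer;                            // K and V, all layers
}
```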
Pre-Allocation: No Malloc in the Hot Path
The KVCache::create() factory allocates everything upfront:
static KVCache create(Device& device, int n_layers, int n_kv_heads,
                      int head_dim, int max_length) {
    KVCache cache;
    // ... set fields ...
    size_t buf_size = (size_t)n_kv_heads * max_length * head_dim * sizeof(uint16_t);
    cache.k_buffers.resize(n_layers);
    cache.v_buffers.resize(n_layers);
    for (int i = 0; i < n_layers; i++) {
        cache.k_buffers[i] = device.allocate(buf_size);
        cache.v_buffers[i] = device.allocate(buf_size);
        memset(cache.k_buffers[i].contents, 0, buf_size);
        memset(cache.v_buffers[i].contents, 0, buf_size);
    }
    return cache;
}
Note several things:

- All buffers are allocated at once. No lazy allocation, no on-demand growth. You pay the memory cost at model load time, not during generation.

- Buffers are zero-filled. The memset ensures that unwritten positions have zero values. This matters because the attention kernel might read past current_length due to SIMD alignment, and we do not want garbage affecting softmax scores.

- The sizeof(uint16_t) is FP16. Akunu stores cache values in half precision. There is no option for FP32 or INT8 cache – keeping it simple.

- Device::allocate() returns a Buffer. On Metal, this is an MTLBuffer allocated in shared memory (Apple Silicon UMA), meaning both CPU and GPU can access it without explicit copies.
Stateful Tracking
The KV cache tracks how many positions have been filled via current_length. The
API provides four operations for managing this state:
advance(count)
After computing K and V for count new tokens (1 during decode, N during prefill),
call advance() to update the position:
void advance(int count) {
    current_length += count;
    if (current_length > max_length)
        current_length = max_length;
}
The clamping to max_length is a safety measure. In practice, the caller should
check would_overflow() before adding tokens, but the clamp prevents buffer
overruns if something goes wrong.
would_overflow(additional)
bool would_overflow(int additional) const {
    return current_length + additional > max_length;
}
The caller checks this before prefill or decode to avoid writing past the buffer. If it returns true, the inference engine must either truncate the input or refuse the request.
rollback(to_length)
void rollback(int to_length) {
    if (to_length < current_length)
        current_length = to_length;
}
Rollback moves the cursor backwards. The actual data in the buffers is not
erased – only the position counter changes. This is safe because future writes
at positions >= to_length will overwrite the stale data before it is read.
This is used for prefix caching in the server: if the new prompt shares a prefix with the previous one, we rollback to the shared prefix length and only re-compute the divergent tokens.
reset()
void reset() { current_length = 0; }
The nuclear option. Resets to the beginning without touching the actual buffer contents. The next prefill will overwrite everything.
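Taken together, the four operations form a tiny state machine. The sketch below models just the cursor – buffers omitted, struct name ours – and mirrors the advance/would_overflow/rollback/reset semantics shown above:

```cpp
#include <cassert>

// Standalone model of the KV cache position counter (no GPU buffers).
struct CacheCursor {
    int current_length = 0;
    int max_length;

    explicit CacheCursor(int max_len) : max_length(max_len) {}

    void advance(int count) {
        current_length += count;
        if (current_length > max_length) current_length = max_length;  // clamp
    }
    bool would_overflow(int additional) const {
        return current_length + additional > max_length;
    }
    void rollback(int to_length) {
        if (to_length < current_length) current_length = to_length;  // backwards only
    }
    void reset() { current_length = 0; }
};
```

Note that rollback with a target at or beyond current_length is a no-op, which is what makes it safe to call unconditionally with a computed shared-prefix length.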
The Lifecycle of a Cache During Generation
Here is the full flow from prompt to generation:
1. User sends prompt: "What is the capital of France?"
Tokenized: [BOS, 1724, 338, 278, 7483, 310, 3444, 29973] (8 tokens)
2. PREFILL: process all 8 tokens in one batch
For each layer:
Compute K[0..7], V[0..7] <-- batch of 8 KV vectors
Write into cache at pos 0..7
cache.advance(8) <-- current_length = 8
3. DECODE step 1: generate token "The"
For each layer:
Compute K[8], V[8] <-- single new KV vector
Write into cache at pos 8
Attend Q[8] over K[0..8], V[0..8]
cache.advance(1) <-- current_length = 9
4. DECODE step 2: generate token " capital"
For each layer:
Compute K[9], V[9]
Write into cache at pos 9
Attend Q[9] over K[0..9], V[0..9]
cache.advance(1) <-- current_length = 10
5. ... continue until EOS or max_tokens ...
6. NEXT CONVERSATION:
New prompt: "What is the capital of Germany?"
Shares prefix: "What is the capital of " = 7 tokens
Option A: cache.reset() + full prefill (simple)
Option B: rollback to shared prefix + incremental prefill (efficient)
With prefix caching (Option B):
shared = 7 tokens match
cache.rollback(7) <-- current_length = 7
Prefill only tokens 7..N <-- "Germany?" = 2 tokens
cache.advance(2) <-- current_length = 9
Here is a timeline diagram:
Position in KV cache:
0 1 2 3 4 5 6 7 8 9 10 11
| | | | | | | | | | | |
[BOS][What][ is][ the][cap][ital][ of][ Fr][The][ ca][pit]
|----- prefill (8 tokens) -----| |-- decode step by step--|
^
current_length advances: 8 -> 9 -> 10 -> 11
After rollback(7) for new conversation:
0 1 2 3 4 5 6 7 8
| | | | | | | | |
[BOS][What][ is][ the][cap][ital][ of][Ger][many]
|-- preserved prefix (7) --| ^
incremental prefill: 2 tokens
Prefix Caching in the Server
The HTTP server (Chapter 50) maintains a ModelEntry per loaded model that tracks
the last prompt’s tokens:
struct ModelEntry {
    std::vector<uint32_t> cached_tokens;
    int cached_position = 0;

    int shared_prefix(const uint32_t *tokens, int n_tokens) const {
        int shared = 0;
        int limit = std::min((int)cached_tokens.size(), n_tokens);
        for (int i = 0; i < limit; i++) {
            if (cached_tokens[i] != tokens[i]) break;
            shared++;
        }
        return shared;
    }
};
When a new request arrives:

- Encode the new prompt to tokens
- Compare with cached_tokens to find the shared prefix length
- If shared > 0 && shared <= cached_position, use rollback + incremental prefill
- Otherwise, full reset + prefill from scratch
This gives you “free” prefix caching with zero extra infrastructure. In a chatbot scenario where the system prompt is the same across turns, you skip re-processing hundreds of tokens. For a 2048-token system prompt, that can save 50-100ms of prefill time on each request.
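To see the matching logic in isolation, here is the same loop lifted out of ModelEntry into a free function (purely for illustration):

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Length of the longest common prefix between the previously cached token
// sequence and a new prompt's tokens.
int shared_prefix(const std::vector<uint32_t>& cached,
                  const uint32_t* tokens, int n_tokens) {
    int limit = std::min((int)cached.size(), n_tokens);
    int shared = 0;
    while (shared < limit && cached[shared] == tokens[shared]) shared++;
    return shared;
}
```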
How the Attention Kernel Uses the Cache
During decode, the flash attention kernel receives the cache buffers as arguments. Here is a simplified view of the kernel parameters:
flash_attention_decode_fast_f16(
Q: [n_heads, 1, head_dim] <-- current token's query
K_cache: [n_kv_heads, max_seq, head_dim] <-- full K cache for this layer
V_cache: [n_kv_heads, max_seq, head_dim] <-- full V cache for this layer
output: [n_heads, 1, head_dim] <-- attention output
kv_seq_len: current_length + 1 <-- how far to read
kv_stride: max_length * head_dim <-- stride between heads
scale: 1.0 / sqrt(head_dim)
)
The kernel uses kv_stride to jump between heads and kv_seq_len to know how
many positions to attend over. It reads K[h, 0..kv_seq_len-1, :] and
V[h, 0..kv_seq_len-1, :] contiguously for each head h.
For GQA (grouped-query attention), where multiple Q heads share a single KV head,
the kernel maps Q head index to KV head index with integer division:
kv_head = q_head / (n_heads / n_kv_heads).
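The mapping is worth pinning down with numbers. For LLaMA 3-8B (32 query heads, 8 KV heads), each group of 4 consecutive query heads shares one KV head; the helper name below is ours:

```cpp
#include <cassert>

// GQA head mapping: query head -> KV head, via the integer division above.
// Assumes n_heads is a multiple of n_kv_heads, as in all GQA models.
int kv_head_for(int q_head, int n_heads, int n_kv_heads) {
    int group_size = n_heads / n_kv_heads;  // query heads per KV head (4 for LLaMA 3-8B)
    return q_head / group_size;
}
```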
Cleanup
The destroy() method frees all GPU buffers:
void destroy(Device& device) {
    for (auto& b : k_buffers) device.free_buffer(b);
    for (auto& b : v_buffers) device.free_buffer(b);
    k_buffers.clear();
    v_buffers.clear();
}
No destructor magic. The caller is responsible for calling destroy() before
the Device goes away. This is deliberate – GPU resource lifetimes must be
explicit in a system without garbage collection.
What Akunu Does NOT Do
It is worth noting what this KV cache design deliberately omits:
- No paged attention. Systems like vLLM use virtual memory paging to efficiently share cache across sequences. Akunu allocates one flat buffer per layer, trading memory efficiency for simplicity and zero fragmentation.

- No multi-sequence support. The cache tracks a single current_length. There is no batch dimension or per-sequence tracking. For serving multiple concurrent conversations, you would need multiple model instances.

- No quantized cache. Some inference engines store KV in INT8 or INT4 to reduce memory. Akunu keeps everything in FP16 for maximum quality and simplicity.

- No sliding window. Some architectures (Mistral) use a sliding window where old positions are evicted. Akunu’s cache is a simple grow-only buffer with a hard maximum.

- No speculative decoding cache management. Systems that do speculative decoding need to speculatively advance and then roll back the cache. Akunu’s rollback() could support this, but the current codebase does not implement speculative decoding.
These are all conscious trade-offs. For a single-user inference engine targeting Apple Silicon, the simple flat-buffer approach gives you maximum GPU throughput (contiguous memory access) and zero overhead (no bookkeeping, no page tables).
Summary
+-------------------------------+----------------------------------+
| Design Decision | Rationale |
+-------------------------------+----------------------------------+
| POD struct, no inheritance | Cache-line friendly, no vtable |
| Head-major layout | Contiguous reads per attention |
| | head in flash attention kernel |
| FP16 storage | 2x smaller than FP32, native |
| | Metal half-precision support |
| Pre-allocate max_length | Zero allocation in hot path |
| Zero-fill on creation | Safe reads past current_length |
| Single current_length counter | Simple state machine, no locks |
| Pre-computed kv_stride | One less multiply per kernel |
| Explicit destroy() | Deterministic GPU resource mgmt |
+-------------------------------+----------------------------------+
The KV cache is the single largest runtime memory consumer after the model weights themselves. Understanding its layout is essential for reasoning about memory budgets, context window limits, and the performance characteristics of the attention kernel. Next, we will look at the other half of the runtime memory story: scratch buffers.