Scratch Buffer Architecture
If the KV cache is the long-term memory of inference, scratch buffers are the
working memory – the scratchpad where every intermediate computation happens.
Matrix multiplications, attention outputs, FFN activations, logits – all of these
need temporary GPU buffers, and if you allocate them on the fly during generation,
you are dead. GPU memory allocation is slow. Fragmentation is real. And the last
thing you want in a tight decode loop running at 100+ tokens per second is a call
to MTLDevice.makeBuffer().
Akunu’s solution is dead simple: allocate every scratch buffer at model load
time, reuse them every forward pass, and never touch the allocator again. This
chapter walks through the ScratchBuffers struct, the ping-pong pattern, the
QKV sub-offset trick, and the full memory budget.
The Core Idea: Zero Allocation in the Hot Path
Here is the principle, stated bluntly: after ScratchBuffers::create() returns,
zero bytes of GPU memory are ever allocated during inference. Not one buffer.
Not one resize. Nothing.
Every temporary result – the hidden state after embedding lookup, the Q/K/V projections, the attention output, the FFN gate and up projections, the activated intermediate, the final logits – all live in pre-allocated buffers that get overwritten on every forward pass.
This is possible because the sizes of all intermediates are known at model load
time. The model config tells us dim, q_dim, kv_dim, ffn_dim, and
vocab_size. Every buffer size is a simple function of these constants.
The ScratchBuffers Struct
The full struct from src/cache/scratch.h:
struct ScratchBuffers {
// === Decode buffers (single token) ===
Buffer h0; // [dim] FP16
Buffer h1; // [dim] FP16
Buffer residual; // [dim] FP16
Buffer qkv; // [q_dim + 2*kv_dim] FP16
Buffer attn_out; // [max(q_dim, dim)] FP16
Buffer post_norm; // [dim] FP16
Buffer ffn_gate; // [ffn_dim] FP16
Buffer ffn_up; // [ffn_dim] FP16
Buffer ffn_act; // [ffn_dim] FP16
Buffer logits; // [vocab_size] FP16
Buffer token_ids; // [max_chain] U32
int qkv_q_offset; // byte offset of Q within qkv
int qkv_k_offset; // byte offset of K within qkv
int qkv_v_offset; // byte offset of V within qkv
// === Prefill buffers (batch) ===
Buffer batch_h0, batch_h1, batch_residual;
Buffer batch_q, batch_k, batch_v;
Buffer batch_attn_out;
Buffer batch_gate, batch_up, batch_act;
Buffer batch_post_norm;
int max_prefill_chunk;
};
There are two sets of buffers: decode buffers (single-token inference) and prefill buffers (batch processing of the prompt). Let us go through each one.
Decode Buffers: The Single-Token Pipeline
During decode, exactly one token flows through the transformer per step. The buffers are sized for a single-row computation:
Forward pass data flow (one decode step):
token_ids ──> [embedding lookup] ──> h0 [dim]
|
+─────────────+
|
[layer loop x N_layers]
|
h0 ──> [RMSNorm] ──> residual ──> [QKV GEMV] ──> qkv [q_dim+2*kv_dim]
|
+──────────────+──────────────+
| | |
Q [q_dim] K [kv_dim] V [kv_dim]
| | |
| [KV cache write] |
| | |
+──> [Flash Attention] <──────+
|
attn_out [dim]
|
[O projection]
|
h1 [dim] ──+──> h0 = h0 + h1
| (residual add)
+───────────────────────────────+
|
h0 ──> [RMSNorm] ──> residual ──> [Gate GEMV] ──> ffn_gate [ffn_dim]
|──> [Up GEMV] ──> ffn_up [ffn_dim]
|
[SiLU(gate) * up] ──> ffn_act [ffn_dim]
|
[Down GEMV] ──> h1 [dim]
|
h0 = h0 + h1 (residual add)
|
[end layer loop]
|
h0 ──> [RMSNorm] ──> residual ──> [Logit GEMV] ──> logits [vocab_size]
|
[argmax/sample]
|
token_ids (next token)
h0 and h1: The Ping-Pong Pair
These are the two primary hidden state buffers, each sized [dim] in FP16.
They implement a ping-pong pattern: the output of one operation goes into
one buffer, the next operation reads from that buffer and writes to the other.
Layer start: h0 holds the current hidden state
RMSNorm: reads h0, writes residual
QKV GEMV: reads residual, writes qkv
Attention: reads qkv + KV cache, writes attn_out
O Projection: reads attn_out, writes h1
Residual add: h0 = h0 + h1 (h0 updated in-place)
FFN start: h0 holds the updated hidden state
RMSNorm: reads h0, writes residual
Gate GEMV: reads residual, writes ffn_gate
Up GEMV: reads residual, writes ffn_up
Activation: reads ffn_gate + ffn_up, writes ffn_act
Down GEMV: reads ffn_act, writes h1
Residual add: h0 = h0 + h1 (h0 updated in-place)
Next layer: h0 holds the result
Why ping-pong instead of in-place? Because GPU kernels read and write concurrently. You cannot safely read and write the same buffer in a single dispatch. The ping-pong pattern ensures the read source and write destination are always different physical buffers.
Ping-pong pattern across operations:
Op 1: READ h0 ────> WRITE h1
Op 2: READ h1 ────> WRITE h0
Op 3: READ h0 ────> WRITE h1
...
The two buffers alternate roles as "source" and "destination"
so we never have a read-write hazard on the same buffer.
The residual Buffer
Sized [dim] FP16. Holds the output of RMSNorm/LayerNorm before it gets projected
into Q/K/V or FFN inputs. This is separate from h0/h1 because we need the
un-normed h0 for the residual connection: h0 = h0 + f(norm(h0)).
The qkv Buffer: Contiguous Q|K|V
This is one of the cleverer design decisions. Instead of three separate buffers for Q, K, and V, Akunu packs them into a single contiguous buffer:
qkv buffer layout (bytes):
|<-------- q_dim*2 -------->|<--- kv_dim*2 --->|<--- kv_dim*2 --->|
| Q region | K region | V region |
^ ^ ^
qkv_q_offset = 0 qkv_k_offset qkv_v_offset
The byte offsets are pre-computed at creation time:
s.qkv_q_offset = 0;
s.qkv_k_offset = q_dim * 2; // q_dim FP16 elements = q_dim*2 bytes
s.qkv_v_offset = (q_dim + kv_dim) * 2;
Why pack them together? Two reasons:
- Fewer buffer bindings. The QKV GEMV kernel can write Q, K, and V with different buffer offsets into the same Metal buffer. This means one buffer binding instead of three, which reduces Metal command encoder overhead.
- The KV cache write kernel reads K and V from known offsets. It binds the qkv buffer with the appropriate offset to read just the K or V portion.
For GQA models where q_dim != kv_dim (e.g., LLaMA 3-8B has q_dim=4096,
kv_dim=1024), this packing is especially efficient – Q is 4x larger than K or V,
and they all fit in one allocation.
Example: LLaMA 3.1-8B
q_dim = 32 heads * 128 dim = 4096
kv_dim = 8 heads * 128 dim = 1024
qkv buffer: (4096 + 1024 + 1024) * 2 = 12,288 bytes = 12 KB
Offsets:
Q starts at byte 0
K starts at byte 8192
V starts at byte 10240
ffn_gate, ffn_up, ffn_act
Three buffers for the feed-forward network, each [ffn_dim] FP16. The FFN in
LLaMA-style models is:
gate = W_gate @ x --> ffn_gate
up = W_up @ x --> ffn_up
act = SiLU(gate) * up --> ffn_act
out = W_down @ act --> h1
Note that ffn_gate is allocated as ffn_dim * 2 * 2 bytes (double-sized) to
support fused gate+up computation in some kernel variants. The actual FFN
dimension for LLaMA 3-8B is 14336, so ffn_gate is 56 KB.
logits Buffer
Sized [vocab_size] FP16. For LLaMA 3, vocab_size = 128256, so this buffer is
128256 * 2 = ~250 KB. The logit projection (hidden state times the un-embedding
matrix) writes here, and the sampler reads from here.
token_ids Buffer
Sized [max_chain] U32 (4 bytes per token). This holds token IDs for both the
input (prefill) and the output (chain decode). The chain decode loop writes
the next token ID here, then the embedding lookup reads it on the next iteration.
Prefill Buffers: Batch Processing
During prefill, we process up to max_prefill_chunk tokens simultaneously (default
4096). Every buffer needs a batch dimension:
Decode buffer: [dim] -- 1 token
Prefill buffer: [prefill_chunk, dim] -- up to 4096 tokens
The prefill buffers mirror the decode buffers with an added batch dimension:
+------------------+----------------------------------+
| Decode Buffer | Prefill Buffer |
+------------------+----------------------------------+
| h0 [dim] | batch_h0 [chunk * dim] |
| h1 [dim] | batch_h1 [chunk * dim] |
| residual [dim] | batch_residual [chunk * dim] |
| qkv [qkv_dim] | batch_q [chunk * q_dim] |
| | batch_k [chunk * kv_dim] |
| | batch_v [chunk * kv_dim] |
| attn_out [dim] | batch_attn_out [chunk * dim] |
| ffn_gate [ffn] | batch_gate [chunk * ffn_dim] |
| ffn_up [ffn] | batch_up [chunk * ffn_dim] |
| ffn_act [ffn] | batch_act [chunk * ffn_dim] |
| post_norm [dim] | batch_post_norm [chunk * dim] |
+------------------+----------------------------------+
Note that prefill uses separate Q, K, V buffers instead of the packed qkv
layout. This is because the prefill attention kernel (flash attention prefill)
expects separate Q, K, V inputs in the shape [seq, n_heads, head_dim], which
is more natural for batched GEMM operations.
Memory Budget Calculation
Let us compute the total scratch memory for LLaMA 3.1-8B:
Model parameters:
dim = 4096
q_dim = 4096 (32 heads * 128)
kv_dim = 1024 (8 heads * 128)
ffn_dim = 14336
vocab_size = 128256
chunk = 4096 (prefill chunk size)
Decode buffers (single token):
h0: 4096 * 2 = 8,192 bytes
h1: 4096 * 2 = 8,192 bytes
residual: 4096 * 2 = 8,192 bytes
qkv: (4096+2*1024) * 2 = 12,288 bytes
attn_out: 4096 * 2 = 8,192 bytes
post_norm: 4096 * 2 = 8,192 bytes
ffn_gate: 14336 * 2 * 2 = 57,344 bytes (2x for fused gate+up)
ffn_up: 14336 * 2 = 28,672 bytes
ffn_act: 14336 * 2 = 28,672 bytes
logits: 128256 * 2 = 256,512 bytes
token_ids: 4096 * 4 = 16,384 bytes
──────────────────────────────────────────────
Decode total: ~431 KB
Prefill buffers (4096 tokens):
batch_h0: 4096 * 4096 * 2 = 33,554,432 bytes
batch_h1: 4096 * 4096 * 2 = 33,554,432 bytes
batch_residual: 4096 * 4096 * 2 = 33,554,432 bytes
batch_q: 4096 * 4096 * 2 = 33,554,432 bytes
batch_k: 4096 * 1024 * 2 = 8,388,608 bytes
batch_v: 4096 * 1024 * 2 = 8,388,608 bytes
batch_attn_out: 4096 * 4096 * 2 = 33,554,432 bytes
batch_gate: 4096 * 14336 * 2 = 117,440,512 bytes
batch_up: 4096 * 14336 * 2 = 117,440,512 bytes
batch_act: 4096 * 14336 * 2 = 117,440,512 bytes
batch_post_norm:4096 * 4096 * 2 = 33,554,432 bytes
──────────────────────────────────────────────
Prefill total: ~544 MB
Grand total scratch: ~544 MB
The prefill buffers dominate – they are the batch dimension multiplied by the
hidden and FFN dimensions. For a model with ffn_dim = 14336 and a prefill
chunk of 4096, each FFN buffer alone is 112 MB.
Here is the full memory picture for LLaMA 3.1-8B at 4096 context:
+-----------------------------------+-----------+
| Component | Memory |
+-----------------------------------+-----------+
| Model weights (Q4_0) | ~4.3 GB |
| KV cache (32 layers, 4096 ctx) | 512 MB |
| Scratch decode buffers | ~0.4 MB |
| Scratch prefill buffers | ~544 MB |
+-----------------------------------+-----------+
| Total | ~5.3 GB |
+-----------------------------------+-----------+
Buffer Reuse Within a Forward Pass
A key insight is that these buffers are reused within a single forward pass, not just across forward passes. Within the layer loop:
Layer L:
residual: written by RMSNorm, read by QKV projection
qkv: written by QKV projection, read by attention + KV write
attn_out: written by attention, read by O projection
h1: written by O projection, read by residual add
ffn_gate: written by Gate GEMV, read by activation
ffn_up: written by Up GEMV, read by activation
ffn_act: written by activation, read by Down GEMV
Layer L+1:
Same buffers, completely overwritten!
The transformer processes layers sequentially. Layer L’s ffn_gate output is
consumed within layer L, then layer L+1 overwrites the same buffer with its own
ffn_gate output. No per-layer scratch is needed – one set of buffers serves
all layers.
The only per-layer storage is the KV cache (Chapter 45), which must retain values across the entire sequence.
The Ping-Pong Pattern in Detail
Let us trace the exact read/write pattern through two consecutive layers:
Layer 0:
READ h0 WRITE residual (RMSNorm)
READ residual WRITE qkv (QKV GEMV)
READ qkv+cache WRITE attn_out (Attention)
READ attn_out WRITE h1 (O GEMV)
READ h0, h1 WRITE h0 (Residual add: h0 += h1)
READ h0 WRITE residual (RMSNorm)
READ residual WRITE ffn_gate (Gate GEMV)
READ residual WRITE ffn_up (Up GEMV)
READ gate, up WRITE ffn_act (SiLU * mul)
READ ffn_act WRITE h1 (Down GEMV)
READ h0, h1 WRITE h0 (Residual add: h0 += h1)
Layer 1:
READ h0 WRITE residual (RMSNorm)
... same pattern, same buffers, completely safe because
each buffer is fully consumed before being overwritten ...
Notice that h0 is both read and written in the residual add step. This is
safe because it is an element-wise operation (h0[i] += h1[i]), implemented
as a fused add kernel that handles the in-place update correctly.
Allocation: All at Once, All FP16
The create() factory method allocates everything in one shot:
static ScratchBuffers create(Device& device, const AkunuModelConfig& cfg,
int max_context = 4096,
int prefill_chunk = 4096,
int max_chain = 128) {
ScratchBuffers s;
int dim = cfg.dim;
int q_dim = cfg.q_dim;
int kv_dim = cfg.kv_dim;
int ffn_dim = cfg.ffn_dim;
int vocab = cfg.vocab_size;
// Decode buffers
s.h0 = device.allocate(dim * 2);
s.h1 = device.allocate(dim * 2);
s.residual = device.allocate(dim * 2);
s.qkv = device.allocate((q_dim + 2 * kv_dim) * 2);
s.attn_out = device.allocate((q_dim > dim ? q_dim : dim) * 2);
// ... etc ...
// Prefill buffers
s.batch_h0 = device.allocate(prefill_chunk * dim * 2);
// ... etc ...
return s;
}
Every size is count * 2 because FP16 is 2 bytes per element. The token_ids
buffer uses count * 4 because token IDs are 32-bit unsigned integers.
Notice the max(q_dim, dim) for attn_out – this handles the case where the
attention output dimension might differ from the model dimension (though in
practice they are usually equal).
Cleanup
Like the KV cache, cleanup is explicit:
void destroy(Device& device) {
for (Buffer *b : {&h0, &h1, &residual, &qkv, &attn_out, &post_norm,
&ffn_gate, &ffn_up, &ffn_act, &logits, &token_ids,
&batch_h0, &batch_h1, &batch_residual,
&batch_q, &batch_k, &batch_v, &batch_attn_out,
&batch_gate, &batch_up, &batch_act, &batch_post_norm}) {
device.free_buffer(*b);
}
}
Every buffer gets freed. The initializer list is a convenient C++ trick for iterating over all the member buffers without repeating the free logic.
Post-Norm Buffer (Gemma Compatibility)
The post_norm and batch_post_norm buffers are specifically for Gemma-style
architectures that use post-attention and post-FFN normalization:
Standard LLaMA: x = x + Attn(Norm(x))
Gemma: x = x + Norm(Attn(Norm(x)))
^^^^
post-norm needs its own buffer
For models that do not use post-norm (most LLaMA variants), this buffer is
allocated but never written to. The wasted memory is dim * 2 = 8 KB for the
decode version – negligible.
Why Not Use a Memory Pool?
You might wonder: instead of individual named buffers, why not allocate one big slab and carve it up? Memory pools are common in GPU programming.
The answer is debuggability. With named buffers:
- Metal GPU debugger shows “h0”, “ffn_gate”, “logits” etc. in the buffer list
- Each buffer has a known size that matches its semantic purpose
- There is no offset arithmetic to get wrong
- Adding a new buffer is trivial – just add a field and an allocate call
The overhead of having ~22 separate MTLBuffer objects instead of 1 is
negligible. Metal’s buffer creation is fast, and we only do it once at load time.
Summary
Key design principles:
1. Pre-allocate ALL buffers at model load
2. Zero allocation during inference
3. Ping-pong (h0/h1) avoids read-write hazards
4. Contiguous QKV with byte sub-offsets
5. Separate decode (1-token) and prefill (N-token) buffer sets
6. Every buffer is reused across all transformer layers
7. Explicit create/destroy lifecycle
Memory hierarchy during inference:
+─────────────────────────────────────+
| Model Weights (read-only, ~GB) | Largest
+─────────────────────────────────────+
| Prefill Scratch (~500 MB) |
+─────────────────────────────────────+
| KV Cache (~512 MB for 4K ctx) |
+─────────────────────────────────────+
| Decode Scratch (~0.4 MB) | Smallest
+─────────────────────────────────────+
With the KV cache and scratch buffers understood, we have covered the complete runtime memory picture. Every byte of GPU memory used during inference is accounted for: model weights, KV cache, and scratch buffers. No hidden allocations, no surprises, no fragmentation. This is what makes it possible to predict exactly whether a given model will fit in memory before loading it.