Scratch Buffer Architecture

If the KV cache is the long-term memory of inference, scratch buffers are the working memory – the scratchpad where every intermediate computation happens. Matrix multiplications, attention outputs, FFN activations, logits – all of these need temporary GPU buffers, and if you allocate them on the fly during generation, you are dead. GPU memory allocation is slow. Fragmentation is real. And the last thing you want in a tight decode loop running at 100+ tokens per second is a call to MTLDevice.makeBuffer().

Akunu’s solution is dead simple: allocate every scratch buffer at model load time, reuse them every forward pass, and never touch the allocator again. This chapter walks through the ScratchBuffers struct, the ping-pong pattern, the QKV sub-offset trick, and the full memory budget.

The Core Idea: Zero Allocation in the Hot Path

Here is the principle, stated bluntly: after ScratchBuffers::create() returns, zero bytes of GPU memory are ever allocated during inference. Not one buffer. Not one resize. Nothing.

Every temporary result – the hidden state after embedding lookup, the Q/K/V projections, the attention output, the FFN gate and up projections, the activated intermediate, the final logits – all live in pre-allocated buffers that get overwritten on every forward pass.

This is possible because the sizes of all intermediates are known at model load time. The model config tells us dim, q_dim, kv_dim, ffn_dim, and vocab_size. Every buffer size is a simple function of these constants.
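To make "a simple function of these constants" concrete, here is a minimal sketch. The ModelDims struct and the helper names are hypothetical stand-ins for illustration, not Akunu's actual API; only the field names and the 2-bytes-per-FP16-element rule come from the text.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical mirror of the config fields named above.
struct ModelDims {
    std::size_t dim, q_dim, kv_dim, ffn_dim, vocab_size;
};

// FP16 stores each element in 2 bytes.
constexpr std::size_t fp16_bytes(std::size_t elements) { return elements * 2; }

// A few decode buffer sizes, derived purely from the config.
constexpr std::size_t hidden_bytes(const ModelDims& m) { return fp16_bytes(m.dim); }
constexpr std::size_t qkv_bytes(const ModelDims& m)    { return fp16_bytes(m.q_dim + 2 * m.kv_dim); }
constexpr std::size_t logits_bytes(const ModelDims& m) { return fp16_bytes(m.vocab_size); }
```

Because nothing here depends on the prompt or the generated text, all of these values can be computed once, before the first token is processed.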

The ScratchBuffers Struct

The full struct from src/cache/scratch.h:

struct ScratchBuffers {
    // === Decode buffers (single token) ===
    Buffer h0;         // [dim] FP16
    Buffer h1;         // [dim] FP16
    Buffer residual;   // [dim] FP16
    Buffer qkv;        // [q_dim + 2*kv_dim] FP16
    Buffer attn_out;   // [max(q_dim, dim)] FP16
    Buffer post_norm;  // [dim] FP16
    Buffer ffn_gate;   // [ffn_dim] FP16
    Buffer ffn_up;     // [ffn_dim] FP16
    Buffer ffn_act;    // [ffn_dim] FP16
    Buffer logits;     // [vocab_size] FP16
    Buffer token_ids;  // [max_chain] U32

    int qkv_q_offset;  // byte offset of Q within qkv
    int qkv_k_offset;  // byte offset of K within qkv
    int qkv_v_offset;  // byte offset of V within qkv

    // === Prefill buffers (batch) ===
    Buffer batch_h0, batch_h1, batch_residual;
    Buffer batch_q, batch_k, batch_v;
    Buffer batch_attn_out;
    Buffer batch_gate, batch_up, batch_act;
    Buffer batch_post_norm;

    int max_prefill_chunk;
};

There are two sets of buffers: decode buffers (single-token inference) and prefill buffers (batch processing of the prompt). Let us go through each one.

Decode Buffers: The Single-Token Pipeline

During decode, exactly one token flows through the transformer per step. The buffers are sized for a single-row computation:

Forward pass data flow (one decode step):

  token_ids ──> [embedding lookup] ──> h0 [dim]
                                        |
                          +─────────────+
                          |
                   [layer loop x N_layers]
                          |
       h0 ──> [RMSNorm] ──> residual ──> [QKV GEMV] ──> qkv [q_dim+2*kv_dim]
                                                          |
                                          +───────────────+──────────────+
                                          |               |              |
                                     Q [q_dim]       K [kv_dim]    V [kv_dim]
                                          |               |              |
                                          |          [KV cache write]    |
                                          |               |              |
                                          +──> [Flash Attention] <──────+
                                                     |
                                              attn_out [dim]
                                                     |
                                              [O projection]
                                                     |
                                              h1 [dim]  ──+──> h0 = h0 + h1
                                                          |    (residual add)
                          +───────────────────────────────+
                          |
       h0 ──> [RMSNorm] ──> residual ──> [Gate GEMV] ──> ffn_gate [ffn_dim]
                                    |──> [Up GEMV]   ──> ffn_up [ffn_dim]
                                                          |
                                          [SiLU(gate) * up] ──> ffn_act [ffn_dim]
                                                          |
                                              [Down GEMV] ──> h1 [dim]
                                                          |
                                              h0 = h0 + h1  (residual add)
                          |
                   [end layer loop]
                          |
       h0 ──> [RMSNorm] ──> residual ──> [Logit GEMV] ──> logits [vocab_size]
                                                            |
                                                     [argmax/sample]
                                                            |
                                                     token_ids (next token)

h0 and h1: The Ping-Pong Pair

These are the two primary hidden state buffers, each sized [dim] in FP16. They implement a ping-pong pattern: the output of one operation goes into one buffer, the next operation reads from that buffer and writes to the other.

Layer start:  h0 holds the current hidden state
RMSNorm:      reads h0, writes residual
QKV GEMV:     reads residual, writes qkv
Attention:    reads qkv + KV cache, writes attn_out
O Projection: reads attn_out, writes h1
Residual add: h0 = h0 + h1  (h0 updated in-place)

FFN start:    h0 holds the updated hidden state
RMSNorm:      reads h0, writes residual
Gate GEMV:    reads residual, writes ffn_gate
Up GEMV:      reads residual, writes ffn_up
Activation:   reads ffn_gate + ffn_up, writes ffn_act
Down GEMV:    reads ffn_act, writes h1
Residual add: h0 = h0 + h1  (h0 updated in-place)

Next layer:   h0 holds the result

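Why ping-pong instead of in-place? Because the threads of a GPU dispatch run concurrently and in no guaranteed order. If a kernel such as a GEMV needs the whole input vector while other threads are already writing outputs into the same buffer, some threads read clobbered inputs. Keeping the read source and the write destination in different physical buffers removes that hazard. The one exception is a purely element-wise operation, where each thread reads and writes only its own element.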

Ping-pong pattern across operations:

  Op 1:  READ h0 ────> WRITE h1
  Op 2:  READ h1 ────> WRITE h0
  Op 3:  READ h0 ────> WRITE h1
  ...

  The two buffers alternate roles as "source" and "destination"
  so we never have a read-write hazard on the same buffer.
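The hazard is easy to reproduce on the CPU. The sketch below is an illustration only (the matvec function is hypothetical, not an Akunu kernel): it runs the same matrix-vector product once with separate buffers and once in place. Even a sequential loop corrupts the in-place result, and on a GPU the outcome would additionally depend on thread scheduling.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Row-major matrix-vector product: out[r] = sum over c of w[r*n + c] * in[c].
// `in` and `out` are allowed to alias on purpose, to show what goes wrong.
void matvec(const std::vector<float>& w, const float* in, float* out, std::size_t n) {
    for (std::size_t r = 0; r < n; ++r) {
        float acc = 0.0f;
        for (std::size_t c = 0; c < n; ++c) acc += w[r * n + c] * in[c];
        out[r] = acc;  // if out == in, later rows now read a clobbered input
    }
}
```

With the 2x2 swap matrix {0, 1, 1, 0} and input {1, 2}, separate buffers give {2, 1}, while the in-place call overwrites the first element before the second row reads it and produces {2, 2}.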

The residual Buffer

Sized [dim] FP16. Holds the output of RMSNorm/LayerNorm before it gets projected into Q/K/V or FFN inputs. This is separate from h0/h1 because we need the un-normed h0 for the residual connection: h0 = h0 + f(norm(h0)).
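A CPU reference makes the buffer roles explicit. This is a sketch under stated assumptions: it uses float instead of FP16, omits the learned norm weights, and the function names are hypothetical. The normed parameter plays the role of the residual scratch buffer described above.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// RMSNorm without learned weights: out[i] = x[i] / sqrt(mean(x^2) + eps).
// Writes to a separate buffer so the input survives for the residual add.
void rms_norm(const std::vector<float>& x, std::vector<float>& out, float eps = 1e-5f) {
    assert(x.size() == out.size() && !x.empty());
    float sum_sq = 0.0f;
    for (float v : x) sum_sq += v * v;
    const float scale = 1.0f / std::sqrt(sum_sq / static_cast<float>(x.size()) + eps);
    for (std::size_t i = 0; i < x.size(); ++i) out[i] = x[i] * scale;
}

// One sublayer step: h0 = h0 + f(norm(h0)), with f supplied by the caller.
template <typename F>
void residual_step(std::vector<float>& h0, std::vector<float>& normed,
                   std::vector<float>& h1, F f) {
    rms_norm(h0, normed);  // h0 is left untouched here
    f(normed, h1);         // the sublayer reads the normed copy and writes h1
    for (std::size_t i = 0; i < h0.size(); ++i) h0[i] += h1[i];
}
```

If rms_norm wrote its result back into h0, the final loop would add the sublayer output to the normalized values rather than to the original hidden state.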

The qkv Buffer: Contiguous Q|K|V

This is one of the cleverer design decisions. Instead of three separate buffers for Q, K, and V, Akunu packs them into a single contiguous buffer:

qkv buffer layout (bytes):

  |<-------- q_dim*2 -------->|<--- kv_dim*2 --->|<--- kv_dim*2 --->|
  |          Q region         |     K region      |     V region      |
  ^                           ^                   ^
  qkv_q_offset = 0            qkv_k_offset        qkv_v_offset

The byte offsets are pre-computed at creation time:

s.qkv_q_offset = 0;
s.qkv_k_offset = q_dim * 2;      // q_dim FP16 elements = q_dim*2 bytes
s.qkv_v_offset = (q_dim + kv_dim) * 2;

Why pack them together? Two reasons:

  1. Fewer buffer bindings. The QKV GEMV kernel can write Q, K, and V with different buffer offsets into the same Metal buffer. This means one buffer binding instead of three, which reduces Metal command encoder overhead.

  2. The KV cache write kernel reads K and V from known offsets. It binds the qkv buffer with the appropriate offset to read just the K or V portion.

For GQA models where q_dim != kv_dim (e.g., LLaMA 3-8B has q_dim=4096, kv_dim=1024), this packing is especially efficient – Q is 4x larger than K or V, and they all fit in one allocation.

Example: LLaMA 3.1-8B
  q_dim  = 32 heads * 128 dim = 4096
  kv_dim =  8 heads * 128 dim = 1024

  qkv buffer: (4096 + 1024 + 1024) * 2 = 12,288 bytes = 12 KB

  Offsets:
    Q starts at byte 0
    K starts at byte 8192
    V starts at byte 10240

ffn_gate, ffn_up, ffn_act

Three buffers for the feed-forward network, each [ffn_dim] FP16. The FFN in LLaMA-style models is:

gate = W_gate @ x          --> ffn_gate
up   = W_up   @ x          --> ffn_up
act  = SiLU(gate) * up     --> ffn_act
out  = W_down @ act         --> h1

One exception to the [ffn_dim] sizing: ffn_gate is allocated as ffn_dim * 2 * 2 bytes (double-sized) to support fused gate+up computation in some kernel variants. The FFN dimension for LLaMA 3-8B is 14336, so ffn_gate is 14336 * 2 * 2 = 57,344 bytes = 56 KB.
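For reference, the activation step can be written as a short CPU loop. This is a float sketch with hypothetical function names, not the Metal kernel; it only illustrates the act = SiLU(gate) * up line above.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// SiLU(x) = x * sigmoid(x) = x / (1 + exp(-x)).
inline float silu(float x) { return x / (1.0f + std::exp(-x)); }

// act[i] = SiLU(gate[i]) * up[i], element-wise over ffn_dim entries.
void swiglu_activation(const std::vector<float>& gate, const std::vector<float>& up,
                       std::vector<float>& act) {
    assert(gate.size() == up.size() && up.size() == act.size());
    for (std::size_t i = 0; i < gate.size(); ++i) act[i] = silu(gate[i]) * up[i];
}
```

Each output element depends only on the matching elements of gate and up, which is why this step needs no extra scratch space beyond ffn_act.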

logits Buffer

Sized [vocab_size] FP16. For LLaMA 3, vocab_size = 128256, so this buffer is 128256 * 2 = 256,512 bytes (~250 KB). The logit projection (hidden state times the un-embedding matrix) writes here, and the sampler reads from here.

token_ids Buffer

Sized [max_chain] U32 (4 bytes per token). This holds token IDs for both the input (prefill) and the output (chain decode). The chain decode loop writes the next token ID here, then the embedding lookup reads it on the next iteration.

Prefill Buffers: Batch Processing

During prefill, we process up to max_prefill_chunk tokens simultaneously (default 4096). Every buffer needs a batch dimension:

Decode buffer:    [dim]                    -- 1 token
Prefill buffer:   [prefill_chunk, dim]     -- up to 4096 tokens

The prefill buffers mirror the decode buffers with an added batch dimension:

+------------------+----------------------------------+
| Decode Buffer    | Prefill Buffer                   |
+------------------+----------------------------------+
| h0  [dim]        | batch_h0  [chunk * dim]          |
| h1  [dim]        | batch_h1  [chunk * dim]          |
| residual [dim]   | batch_residual [chunk * dim]     |
| qkv [qkv_dim]    | batch_q [chunk * q_dim]          |
|                  | batch_k [chunk * kv_dim]         |
|                  | batch_v [chunk * kv_dim]         |
| attn_out [dim]   | batch_attn_out [chunk * dim]     |
| ffn_gate [ffn]   | batch_gate [chunk * ffn_dim]     |
| ffn_up [ffn]     | batch_up [chunk * ffn_dim]       |
| ffn_act [ffn]    | batch_act [chunk * ffn_dim]      |
| post_norm [dim]  | batch_post_norm [chunk * dim]    |
+------------------+----------------------------------+

Note that prefill uses separate Q, K, V buffers instead of the packed qkv layout. This is because the prefill attention kernel (flash attention prefill) expects separate Q, K, V inputs in the shape [seq, n_heads, head_dim], which is more natural for batched GEMM operations.
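As an illustration of the [seq, n_heads, head_dim] shape, the helper below computes where one element lives in such a buffer. It assumes a plain row-major layout, which the text does not state explicitly, and qkv_index is a hypothetical name.

```cpp
#include <cassert>
#include <cstddef>

// Element index of (token t, head h, lane d) in a row-major
// [seq, n_heads, head_dim] tensor. Multiply by 2 for a byte offset in FP16.
constexpr std::size_t qkv_index(std::size_t t, std::size_t h, std::size_t d,
                                std::size_t n_heads, std::size_t head_dim) {
    return (t * n_heads + h) * head_dim + d;
}
```

Under this layout each token occupies n_heads * head_dim consecutive elements, so for batch_q on LLaMA 3.1-8B token 1 starts at element 32 * 128 = 4096.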

Memory Budget Calculation

Let us compute the total scratch memory for LLaMA 3.1-8B:

Model parameters:
  dim        = 4096
  q_dim      = 4096  (32 heads * 128)
  kv_dim     = 1024  (8 heads * 128)
  ffn_dim    = 14336
  vocab_size = 128256
  chunk      = 4096  (prefill chunk size)

Decode buffers (single token):
  h0:        4096 * 2          =     8,192 bytes
  h1:        4096 * 2          =     8,192 bytes
  residual:  4096 * 2          =     8,192 bytes
  qkv:       (4096+2*1024) * 2 =    12,288 bytes
  attn_out:  4096 * 2          =     8,192 bytes
  post_norm: 4096 * 2          =     8,192 bytes
  ffn_gate:  14336 * 2 * 2     =    57,344 bytes  (2x for fused gate+up)
  ffn_up:    14336 * 2         =    28,672 bytes
  ffn_act:   14336 * 2         =    28,672 bytes
  logits:    128256 * 2        =   256,512 bytes
  token_ids: 4096 * 4          =    16,384 bytes  (assumes 4096 entries; max_chain defaults to 128)
  ──────────────────────────────────────────────
  Decode total:                   440,832 bytes (~430 KB)

Prefill buffers (4096 tokens):
  batch_h0:       4096 * 4096 * 2   =    33,554,432 bytes
  batch_h1:       4096 * 4096 * 2   =    33,554,432 bytes
  batch_residual: 4096 * 4096 * 2   =    33,554,432 bytes
  batch_q:        4096 * 4096 * 2   =    33,554,432 bytes
  batch_k:        4096 * 1024 * 2   =     8,388,608 bytes
  batch_v:        4096 * 1024 * 2   =     8,388,608 bytes
  batch_attn_out: 4096 * 4096 * 2   =    33,554,432 bytes
  batch_gate:     4096 * 14336 * 2  =   117,440,512 bytes
  batch_up:       4096 * 14336 * 2  =   117,440,512 bytes
  batch_act:      4096 * 14336 * 2  =   117,440,512 bytes
  batch_post_norm:4096 * 4096 * 2   =    33,554,432 bytes
  ──────────────────────────────────────────────
  Prefill total:                      570,425,344 bytes (~544 MB)

Grand total scratch:  ~544 MB

The prefill buffers dominate – each one scales with the prefill chunk size times a hidden or FFN dimension. For a model with ffn_dim = 14336 and a prefill chunk of 4096, each FFN buffer alone is 112 MB, so the three FFN buffers account for 336 MB of the total.
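The prefill rows of the budget can be summed in code rather than by hand. The Dims struct and function name are hypothetical; the buffer list and shapes follow the prefill table earlier in this chapter.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical config mirror; field names follow the text.
struct Dims { std::size_t dim, q_dim, kv_dim, ffn_dim; };

// Total bytes for the eleven prefill buffers, all FP16 (2 bytes per element).
constexpr std::size_t prefill_scratch_bytes(const Dims& m, std::size_t chunk) {
    return chunk * 2 * (5 * m.dim          // batch_h0, batch_h1, batch_residual,
                                           // batch_attn_out, batch_post_norm
                        + m.q_dim          // batch_q
                        + 2 * m.kv_dim     // batch_k, batch_v
                        + 3 * m.ffn_dim);  // batch_gate, batch_up, batch_act
}
```

The total is linear in the chunk size, so halving prefill_chunk halves the prefill scratch footprint, at the cost of processing long prompts in more chunks.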

Here is the full memory picture for LLaMA 3.1-8B at 4096 context:

+-----------------------------------+-----------+
| Component                         | Memory    |
+-----------------------------------+-----------+
| Model weights (Q4_0)              |  ~4.3 GB  |
| KV cache (32 layers, 4096 ctx)    |   512 MB  |
| Scratch decode buffers            |  ~0.4 MB  |
| Scratch prefill buffers           |  ~544 MB  |
+-----------------------------------+-----------+
| Total                             |  ~5.3 GB  |
+-----------------------------------+-----------+
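The KV cache row can be reproduced with one formula. This sketch assumes the cache stores K and V as FP16 for every layer and position, which matches the 512 MB figure in the table; the function name is hypothetical.

```cpp
#include <cassert>
#include <cstddef>

// Bytes for a KV cache holding K and V for every layer and position.
// Assumption: FP16 storage, 2 bytes per element.
constexpr std::size_t kv_cache_bytes(std::size_t n_layers, std::size_t max_context,
                                     std::size_t kv_dim) {
    return 2 /* K and V */ * n_layers * max_context * kv_dim * 2 /* bytes per FP16 */;
}
```

For 32 layers, 4096 positions and kv_dim = 1024 this gives 536,870,912 bytes, which is 512 MB.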

Buffer Reuse Within a Forward Pass

A key insight is that these buffers are reused within a single forward pass, not just across forward passes. Within the layer loop:

Layer L:
  residual: written by RMSNorm, read by QKV projection
  qkv:      written by QKV projection, read by attention + KV write
  attn_out: written by attention, read by O projection
  h1:       written by O projection, read by residual add
  ffn_gate: written by Gate GEMV, read by activation
  ffn_up:   written by Up GEMV, read by activation
  ffn_act:  written by activation, read by Down GEMV

Layer L+1:
  Same buffers, completely overwritten!

The transformer processes layers sequentially. Layer L’s ffn_gate output is consumed within layer L, then layer L+1 overwrites the same buffer with its own ffn_gate output. No per-layer scratch is needed – one set of buffers serves all layers.

The only per-layer storage is the KV cache (Chapter 45), which must retain values across the entire sequence.

The Ping-Pong Pattern in Detail

Let us trace the exact read/write pattern through two consecutive layers:

Layer 0:
  READ  h0         WRITE residual     (RMSNorm)
  READ  residual   WRITE qkv          (QKV GEMV)
  READ  qkv+cache  WRITE attn_out     (Attention)
  READ  attn_out   WRITE h1           (O GEMV)
  READ  h0, h1     WRITE h0           (Residual add: h0 += h1)
  READ  h0         WRITE residual     (RMSNorm)
  READ  residual   WRITE ffn_gate     (Gate GEMV)
  READ  residual   WRITE ffn_up       (Up GEMV)
  READ  gate, up   WRITE ffn_act      (SiLU * mul)
  READ  ffn_act    WRITE h1           (Down GEMV)
  READ  h0, h1     WRITE h0           (Residual add: h0 += h1)

Layer 1:
  READ  h0         WRITE residual     (RMSNorm)
  ... same pattern, same buffers, completely safe because
      each buffer is fully consumed before being overwritten ...

Notice that h0 is both read and written in the residual add step. This is safe because it is an element-wise operation (h0[i] += h1[i]): each thread reads and writes only its own index, so no thread can observe another thread's partial result.

Allocation: All at Once, All FP16

The create() factory method allocates everything in one shot:

static ScratchBuffers create(Device& device, const AkunuModelConfig& cfg,
                             int max_context = 4096,
                             int prefill_chunk = 4096,
                             int max_chain = 128) {
    ScratchBuffers s;
    int dim = cfg.dim;
    int q_dim = cfg.q_dim;
    int kv_dim = cfg.kv_dim;
    int ffn_dim = cfg.ffn_dim;
    int vocab = cfg.vocab_size;

    // Decode buffers
    s.h0       = device.allocate(dim * 2);
    s.h1       = device.allocate(dim * 2);
    s.residual = device.allocate(dim * 2);
    s.qkv      = device.allocate((q_dim + 2 * kv_dim) * 2);
    s.attn_out = device.allocate((q_dim > dim ? q_dim : dim) * 2);
    // ... etc ...

    // Prefill buffers
    s.batch_h0 = device.allocate(prefill_chunk * dim * 2);
    // ... etc ...

    return s;
}

Every size is count * 2 because FP16 is 2 bytes per element. The token_ids buffer uses count * 4 because token IDs are 32-bit unsigned integers.

Notice the max(q_dim, dim) for attn_out – this handles the case where the attention output dimension might differ from the model dimension (though in practice they are usually equal).

Cleanup

Like the KV cache, cleanup is explicit:

void destroy(Device& device) {
    for (Buffer *b : {&h0, &h1, &residual, &qkv, &attn_out, &post_norm,
                      &ffn_gate, &ffn_up, &ffn_act, &logits, &token_ids,
                      &batch_h0, &batch_h1, &batch_residual,
                      &batch_q, &batch_k, &batch_v, &batch_attn_out,
                      &batch_gate, &batch_up, &batch_act, &batch_post_norm}) {
        device.free_buffer(*b);
    }
}

Every buffer gets freed. The initializer list is a convenient C++ trick for iterating over all the member buffers without repeating the free logic.
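The same idiom can be tried in isolation. The Buffer and Device types below are hypothetical minimal stand-ins backed by host memory, not the real Metal wrappers, and the null-handle check in free_buffer is an addition for this sketch.

```cpp
#include <cassert>
#include <cstddef>
#include <initializer_list>
#include <new>

// Hypothetical stand-ins for the real Buffer and Device types.
struct Buffer { void* handle; std::size_t size; };

struct Device {
    int live = 0;  // number of buffers currently allocated
    Buffer allocate(std::size_t bytes) {
        ++live;
        return Buffer{::operator new(bytes), bytes};
    }
    void free_buffer(Buffer& b) {
        if (!b.handle) return;   // never allocated, or already freed
        ::operator delete(b.handle);
        b = Buffer{};            // null the handle so a second free is a no-op
        --live;
    }
};

struct Scratch {
    Buffer h0{}, h1{}, residual{};
    void destroy(Device& device) {
        // The braced list of pointers deduces std::initializer_list<Buffer*>.
        for (Buffer* b : {&h0, &h1, &residual}) device.free_buffer(*b);
    }
};
```

Iterating over pointers rather than copies matters here: free_buffer resets the member itself, so calling destroy twice does not free anything twice.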

Post-Norm Buffer (Gemma Compatibility)

The post_norm and batch_post_norm buffers are specifically for Gemma-style architectures that use post-attention and post-FFN normalization:

Standard LLaMA:  x = x + Attn(Norm(x))
Gemma:           x = x + Norm(Attn(Norm(x)))
                             ^^^^
                          post-norm needs its own buffer

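For models that do not use post-norm (most LLaMA variants), this buffer is allocated but never written to. The wasted memory is dim * 2 = 8 KB for the decode version, which is negligible. The prefill version, batch_post_norm, is larger: chunk * dim * 2 = 32 MB at the default chunk size.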

Why Not Use a Memory Pool?

You might wonder: instead of individual named buffers, why not allocate one big slab and carve it up? Memory pools are common in GPU programming.

The answer is debuggability. With named buffers:

  • Metal GPU debugger shows “h0”, “ffn_gate”, “logits” etc. in the buffer list
  • Each buffer has a known size that matches its semantic purpose
  • There is no offset arithmetic to get wrong
  • Adding a new buffer is trivial – just add a field and an allocate call

The overhead of having 22 separate MTLBuffer objects instead of 1 is negligible: buffer creation happens once at load time and never in the decode loop.
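For comparison, here is roughly what the slab alternative would involve. This is a hypothetical sketch of a bump allocator, not code from Akunu, and the 256-byte default alignment is an assumed value chosen for illustration.

```cpp
#include <cassert>
#include <cstddef>

// Bump allocator over one slab: hands out aligned byte offsets, never frees.
struct SlabCarver {
    std::size_t capacity;
    std::size_t used;

    explicit SlabCarver(std::size_t cap) : capacity(cap), used(0) {}

    // Returns the offset of a new region; `align` must be a power of two.
    std::size_t carve(std::size_t bytes, std::size_t align = 256) {
        assert((align & (align - 1)) == 0);
        const std::size_t start = (used + align - 1) & ~(align - 1);
        assert(start + bytes <= capacity);
        used = start + bytes;
        return start;
    }
};
```

Every consumer of the slab then has to carry an offset alongside the buffer handle, and a mistake in the rounding or the running total silently overlaps two regions. That is the offset arithmetic the named-buffer design avoids.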

Summary

Key design principles:
  1. Pre-allocate ALL buffers at model load
  2. Zero allocation during inference
  3. Ping-pong (h0/h1) avoids read-write hazards
  4. Contiguous QKV with byte sub-offsets
  5. Separate decode (1-token) and prefill (N-token) buffer sets
  6. Every buffer is reused across all transformer layers
  7. Explicit create/destroy lifecycle

Memory hierarchy during inference:
  +─────────────────────────────────────+
  |  Model Weights (read-only, ~GB)     |  Largest
  +─────────────────────────────────────+
  |  Prefill Scratch (~544 MB)          |
  +─────────────────────────────────────+
  |  KV Cache (~512 MB for 4K ctx)      |
  +─────────────────────────────────────+
  |  Decode Scratch (~0.4 MB)           |  Smallest
  +─────────────────────────────────────+

With the KV cache and scratch buffers understood, we have covered the complete runtime memory picture. Every byte of GPU memory used during inference is accounted for: model weights, KV cache, and scratch buffers. No hidden allocations, no surprises, no fragmentation. This is what makes it possible to predict exactly whether a given model will fit in memory before loading it.
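A pre-load check along these lines might look like the sketch below. MemoryPlan and fits are hypothetical names, and where the budget comes from is left to the caller; on Metal, one candidate is the device's recommendedMaxWorkingSetSize.

```cpp
#include <cassert>
#include <cstdint>

// Byte counts for the three components named in the text.
struct MemoryPlan {
    std::uint64_t weights, kv_cache, scratch;
    std::uint64_t total() const { return weights + kv_cache + scratch; }
};

// True when the whole plan fits inside the given budget.
inline bool fits(const MemoryPlan& plan, std::uint64_t budget_bytes) {
    return plan.total() <= budget_bytes;
}
```

Since all three inputs are known from the model file and the configured context and chunk sizes, the check can run before a single buffer is allocated.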