The Prefill Phase

When you send a prompt to an LLM, something interesting happens before the model starts generating text: every single prompt token gets processed in parallel. This is the prefill phase, and it is architecturally distinct from the token-by-token decode phase that follows. In Akunu, the prefill phase lives in src/core/prefill.cpp and is responsible for transforming a sequence of token IDs into a populated KV cache, a set of logits for the last position, and the first predicted token.

This chapter will walk through the entire prefill pipeline, from the moment token IDs hit the GPU to the moment we get our first generated token back. Along the way, we will see why prefill uses GEMM (matrix-matrix multiply) instead of GEMV (matrix-vector multiply), how chunked prefill works, and why the whole thing is fundamentally different from decode.

Why Prefill Is Not Just “Decode but Faster”

At first glance, you might think: “Well, prefill is just running the model on all the prompt tokens. Can’t we just call decode N times?” Technically yes, but doing so would be absurdly slow. The key insight is that prefill processes all tokens simultaneously through the transformer layers, whereas decode processes one token at a time.1

The difference comes down to arithmetic intensity:

| Phase   | Operation | Matrix Shape            | Bottleneck    |
|---------|-----------|-------------------------|---------------|
| Prefill | GEMM      | [seq_len, K] @ [N, K]^T | Compute-bound |
| Decode  | GEMV      | [1, K] @ [N, K]^T       | Memory-bound  |

During prefill with a prompt of length S, the activation tensor is [S, dim] rather than [1, dim]. This means the weight matrix gets reused across S rows of activations, giving us an arithmetic intensity that scales with S. A 2048-token prefill reuses every weight element 2048 times instead of once. This is why prefill can achieve throughput measured in thousands of tokens per second, while decode is measured in tens of tokens per second.2

In Akunu’s code, this shows up directly. The encode_prefill function signature tells the story:

uint32_t encode_prefill(Device& device, WeightProvider& weights,
                        const AkunuModelConfig& cfg, const ArchDescriptor& arch,
                        KVCache& kv_cache, ScratchBuffers& scratch,
                        const uint32_t *token_ids, int seq_len,
                        int start_position = 0);

It takes the full array of token_ids and seq_len, processes everything in one shot, and returns a single uint32_t: the predicted next token.

The Prefill Pipeline at a Glance

Before diving into the code, let’s lay out the full pipeline. Every token in the prompt passes through these stages:

                    ┌────────────────────────────────────────────┐
                    │              PREFILL PIPELINE              │
                    ├────────────────────────────────────────────┤
                    │                                            │
   token_ids[S] ───>│  1. BATCH EMBEDDING  [S] → [S, dim]        │
                    │  2. EMBEDDING NORM   (if needed)           │
                    │  3. FIRST LAYER NORM (RMSNorm/LayerNorm)   │
                    │                                            │
                    │  ┌─── LAYER LOOP (×N layers) ───────────┐  │
                    │  │  a. QKV GEMM projections             │  │
                    │  │  b. QK-norm (if model has it)        │  │
                    │  │  c. Fused RoPE + KV cache write      │  │
                    │  │  d. Flash Attention (prefill kernel) │  │
                    │  │  e. Output GEMM projection           │  │
                    │  │  f. Fused residual + FFN norm        │  │
                    │  │  g. Gate GEMM + Up GEMM              │  │
                    │  │  h. SiLU activation                  │  │
                    │  │  i. Down GEMM projection             │  │
                    │  │  j. Fused residual + next attn norm  │  │
                    │  └──────────────────────────────────────┘  │
                    │                                            │
                    │  4. OUTPUT NORM (final RMSNorm)            │
                    │  5. LOGIT GEMV (last position only)        │
                    │  6. ARGMAX → next token                    │
                    │                                            │
                    └────────────────────────────────────────────┘

Notice step 5: even though the entire prompt flows through the transformer in parallel, we only need logits for the last position. This is because autoregressive generation predicts the next token, and the “next” token after the prompt corresponds to the last position’s output. Akunu exploits this by using a GEMV (single-row projection) for the logit computation instead of a full GEMM.3

Chunked Prefill

Long prompts cannot always be processed in a single batch. GPU memory for scratch buffers is proportional to seq_len * dim, and for a 128K-token prompt on a model with dim=4096, that is over 1 GB just for one activation buffer. Akunu addresses this with chunked prefill: the prompt is split into chunks of at most max_prefill_chunk tokens (default: 4096), and each chunk is processed separately.

Here is the chunking logic from run_decode_loop:

int chunk_size = state.scratch.max_prefill_chunk;
uint32_t next_token = 0;
int prefill_pos = 0;
while (prefill_pos < n_prompt) {
    int chunk = std::min(chunk_size, n_prompt - prefill_pos);
    next_token = encode_prefill(*state.device, *state.weights, state.config,
                                state.arch, state.kv_cache, state.scratch,
                                prompt_tokens + prefill_pos, chunk,
                                start_pos + prefill_pos);
    prefill_pos += chunk;
}

Each call to encode_prefill processes chunk tokens starting at position start_pos + prefill_pos. The start_position parameter is critical: it tells the RoPE kernel where in the sequence these tokens actually belong, and it tells the KV cache where to write the K and V entries. Without correct position tracking, chunked prefill would produce garbage because the positional embeddings would be wrong.

The chunk size of 4096 is configured in ChipConfig:

c.max_prefill_chunk = 4096;

This value balances GPU utilization (larger chunks mean more parallelism and better GEMM efficiency) against memory pressure (scratch buffers must be allocated for the maximum chunk size).

Stage 1: Batch Embedding

The first step converts token IDs to dense vectors. Akunu dispatches an embedding lookup kernel that reads from the embedding weight table and writes to the batch_h0 scratch buffer:

Pipeline pso = device.get_pipeline(device.embedding_kernel_for(emb_dtype));
device.set_pipeline(pso);
device.set_buffer(scratch.token_ids, 0, 0);    // input: token IDs
device.set_buffer(emb_weight, 0, 1);            // weights: [vocab, dim]
device.set_buffer(scratch.batch_h0, 0, 2);      // output: [seq_len, dim]
device.dispatch_threads(Dim3(dim, seq_len), Dim3(std::min(dim, 256)));

The dispatch is a simple 2D grid: one thread per (dimension, token) pair. Each thread copies one element from the embedding table, so the entire [seq_len, dim] output matrix is filled in a single dispatch.

For models with embedding scaling (like Gemma), an additional kernel multiplies every element by the scale factor:

if (arch.embedding_scale > 0.0f) {
    Pipeline scale_pso = device.get_pipeline("temperature_scale_f16");
    // ... applies: batch_h0[i] *= embedding_scale for all elements
}

This reuses the temperature_scale_f16 kernel (originally written for sampling) – a nice example of kernel reuse. The same “multiply every element by a scalar” operation appears in multiple contexts.

MLX Quantized Embeddings

When the model uses MLX-format quantized weights, even the embedding table is quantized. The embedding lookup kernel (embedding_lookup_mlx_q4) must dequantize on-the-fly:

const uint32_t word = W_u32[token_id * n_u32_per_row + u32_idx];
const uint qval = (word >> (within * bits)) & mask;
output[token_idx * K + d_idx] = half(s * float(qval) + b);

Each 32-bit word packs 32/bits quantized values. The kernel extracts the relevant nibble (or bit group), multiplies by the per-group scale, adds the per-group bias, and writes the dequantized FP16 value. The cost is negligible – embedding lookup is always memory-bound, and the few extra ALU ops for dequantization are effectively free.

Stage 2: Normalization

Before entering the layer loop, the activations must be normalized. Akunu supports two paths:

| Model Type                        | Normalization | Has Bias? |
|-----------------------------------|---------------|-----------|
| Standard LLM (Llama, Gemma, etc.) | RMSNorm       | No        |
| BERT/Encoder models               | LayerNorm     | Yes       |

The choice is driven by whether the layer’s norm weight has an associated bias tensor:

Buffer norm_b = weights.get_tensor("layers.0.attention_norm.bias");
if (norm_b.handle) {
    // LayerNorm path (BERT)
    Pipeline pso = device.get_pipeline("layernorm_f16");
} else {
    // RMSNorm path (standard LLM)
    Pipeline pso = device.get_pipeline("rmsnorm_f16");
}

Both kernels operate on batch_h0 (shape [seq_len, dim]) and produce batch_residual. The dispatch is Dim3(seq_len) threadgroups, each with up to 1024 threads – one threadgroup per row (token position), with threads cooperatively computing the norm statistics.

Stage 3: The Layer Loop

The layer loop is where the bulk of prefill computation happens. For each of the n_layers transformer layers, we execute approximately 10 GPU dispatches. Let’s walk through each.

3a. QKV Projections

The attention mechanism needs three projections: Q (query), K (key), and V (value). Akunu supports two approaches:

Fused QKV (BERT-style): A single GEMM produces the concatenated [Q|K|V] output, which is then split with a GPU kernel:

dispatch_gemm(device, scratch.batch_residual, qkv_w, scratch.batch_gate,
              seq_len, qkv_dim, dim, qkv_dtype, ...);
// Then split on GPU:
Pipeline split_pso = device.get_pipeline("qkv_split_f16");

This is more efficient when Q, K, and V share the same weight dtype and the combined projection fits nicely into a single GEMM dispatch.

Separate Q/K/V (standard LLM): Three independent GEMMs:

dispatch_gemm(device, scratch.batch_residual, q_w, scratch.batch_q,
              seq_len, q_dim, dim, q_dtype, ...);
dispatch_gemm(device, scratch.batch_residual, k_w, scratch.batch_k,
              seq_len, kv_dim, dim, k_dtype, ...);
dispatch_gemm(device, scratch.batch_residual, v_w, scratch.batch_v,
              seq_len, kv_dim, dim, v_dtype, ...);

For GQA models where kv_dim < q_dim, separate projections are actually better because the K and V GEMMs are smaller and can be dispatched with appropriately sized grids.

3b. The GEMM Dispatch Function

Every projection in prefill goes through dispatch_gemm, which is the central GEMM dispatcher. It handles:

  1. Kernel selection: The dtype descriptor tells it which kernel to use (e.g., simd_gemm_f16, simd_gemm_q4_0).
  2. Small-M optimization: For M between 2 and 8, it uses the “small” GEMM variant with TM=8 instead of TM=32, avoiding wasted computation on padding.
  3. Function constant specialization: When K is a multiple of 32, the kernel is specialized with K baked in as a compile-time constant. This eliminates a register and enables loop unrolling.
  4. MLX format handling: MLX quantized weights need extra params (group_size, bits, weight_bytes).

The tile geometry is fixed:

| Parameter   | Value                 | Meaning                           |
|-------------|-----------------------|-----------------------------------|
| TM          | 32 (or 8 for small M) | Activation rows per tile          |
| TN          | 64                    | Weight rows per tile              |
| TK          | 32                    | K-dimension per accumulation step |
| Threadgroup | (32, 4) = 128 threads | 4 SIMD groups                     |

The grid is computed as:

int gridX = (N + TN - 1) / TN;
int gridY = (M + TM - 1) / TM;
device.dispatch(Dim3(gridX, gridY), Dim3(32, 4));

Threadgroup memory is allocated for the cooperative tile loading:

int loadBytes = (TN * TK + TM * TK) * 2;   // weight + activation tiles in FP16
int storeBytes = TN * TM * 4;                // output tile in FP32
int tgMem = std::max(loadBytes, storeBytes);  // reuse same memory

The max here is clever: during the accumulation phase, the memory holds the input tiles; during the output phase, it holds the result tile. They never overlap in time, so the same memory region serves both purposes.4

3c. QK-Norm

Some models (notably DeepSeek, Gemma 2) apply RMSNorm to the Q and K projections per head before attention. This ensures that the dot product scores stay in a reasonable range regardless of head dimension:

if (arch.has_qk_norm) {
    Pipeline hn_pso = device.get_pipeline("head_rmsnorm_f16");
    // Q norm: grid = (n_heads, seq_len), threads = head_dim
    device.dispatch(Dim3(n_heads, seq_len), Dim3(head_dim));
    // K norm: grid = (n_kv_heads, seq_len), threads = head_dim
    device.dispatch(Dim3(n_kv_heads, seq_len), Dim3(head_dim));
}

The dispatch geometry is interesting: one threadgroup per (head, position) pair, with head_dim threads per group. Each threadgroup normalizes exactly one head’s worth of data.

3d. Fused RoPE + KV Cache Write

This is one of the most performance-sensitive dispatches in prefill. A single kernel handles three operations simultaneously:

  1. Apply RoPE to Q (in-place)
  2. Apply RoPE to K and write to the K cache
  3. Copy V to the V cache

const char *fused_kernel = is_neox
    ? "rope_neox_batch_kv_write_f16"
    : "rope_batch_kv_write_f16";

Two RoPE variants are supported: the original “interleaved” layout (pairs of adjacent elements are rotated) and the “neox” layout (first and second halves are rotated). The fused kernel processes all seq_len positions in parallel, writing each K/V vector to its correct position in the KV cache based on start_position + position_within_batch.

The dispatch grid is:

device.dispatch_threads(Dim3(head_dim / 2, n_heads, seq_len),
                        Dim3(std::min(head_dim / 2, 32)));

Each thread handles one complex pair (two elements) of one head at one position.

3e. Flash Attention (Prefill)

Prefill attention is where things get really interesting. Unlike decode attention (where each query has one row), prefill attention has seq_len query rows, all attending to kv_seq_len key-value positions. Akunu selects between three attention kernels based on sequence length:

                        ┌─────────────────────┐
                        │    seq_len check    │
                        └──────┬──────────────┘
                               │
              ┌────────────────┼────────────────┐
              │                │                │
         seq_len >= 1024  seq_len >= thresh  otherwise
              │                │                │
              ▼                ▼                ▼
    ┌─────────────────┐ ┌──────────┐  ┌────────────────┐
    │ Prefill V2      │ │ Prefill  │  │ Decode kernel  │
    │ BQ=32, register │ │ V1       │  │ (per-query TG) │
    │ output, exp2    │ │ simd MMA │  │                │
    └─────────────────┘ └──────────┘  └────────────────┘

Prefill V2 (flash_attention_prefill_v2_f16): For long sequences (>= 1024), this kernel processes 32 query rows per threadgroup using simdgroup matrix multiply-accumulate (MMA). Output stays in registers rather than threadgroup memory, saving 16KB per threadgroup. It uses exp2 instead of exp for faster softmax computation.5

Prefill V1 (flash_attention_prefill_f16): For medium sequences, this kernel processes 8 or 16 query rows per threadgroup with simdgroup MMA and threadgroup memory for the output accumulator.

Decode fallback: For very short sequences (< threshold), it simply dispatches one threadgroup per query position using the decode attention kernel. The threshold is:

int nq_rows = (head_dim <= 64) ? 16 : (head_dim <= 128) ? 8 : 0;
int v2_threshold = (nq_rows > 0) ? std::max(nq_rows * 2, 16) : INT_MAX;

All three support non-causal attention for encoder models (BERT) via function constant specialization:

constant bool FC_NON_CAUSAL [[function_constant(3)]];

When FC_NON_CAUSAL is true, the causal mask is skipped, allowing every position to attend to every other position.

3f-3j. FFN Block

After attention, the FFN block follows the standard SwiGLU pattern:

  1. Fused residual + FFN norm: Adds the attention output to the residual stream and normalizes. Uses residual_rmsnorm_f16 for standard LLMs or decomposed vector_add_f16 + layernorm_f16 for BERT.

  2. Gate GEMM: batch_attn_out[S, dim] -> batch_gate[S, ffn_dim]

  3. Up GEMM: batch_attn_out[S, dim] -> batch_up[S, ffn_dim]

  4. SiLU activation: batch_act = SiLU(batch_gate) * batch_up element-wise

  5. Down GEMM: batch_act[S, ffn_dim] -> batch_residual[S, dim]

  6. Fused residual + next attn norm: Prepares the residual stream for the next layer.

The Gate and Up GEMMs read from the same input buffer, which is great for cache locality on the GPU side – the activation data loaded for Gate is still in the SLC when Up runs.

Post-Attention and Post-FFN Norms (Gemma 3)

Some architectures add extra normalization after attention and/or FFN outputs. Akunu handles these via descriptor flags:

if (arch.has_post_attn_norm) {
    // RMSNorm on batch_residual → batch_post_norm
}
if (arch.has_post_ffn_norm) {
    // RMSNorm on batch_residual → batch_post_norm
}

This is driven by ArchDescriptor, so adding support for a new model that uses post-norms requires zero code changes – just setting the descriptor flags.

Stage 4: Output Norm

After the layer loop, the final hidden states need one more normalization. For standard LLMs, this is a fused residual + RMSNorm:

Pipeline pso = device.get_pipeline("residual_rmsnorm_f16");
device.set_buffer(last_fused_input_ffn, 0, 0);  // last layer's output
device.set_buffer(scratch.batch_h1, 0, 1);       // residual stream
device.set_buffer(output_norm_w, 0, 2);           // norm weights
device.set_buffer(scratch.batch_h0, 0, 3);        // updated residual
device.set_buffer(scratch.batch_residual, 0, 4);   // normed output

For embedding models (BERT), there is no output norm – the final residual is directly used for mean-pooling:

if (arch.is_embedding_model) {
    // vector_add: batch_residual = last_fused_input_ffn + batch_h1
    device.end_encoding_sync();
    kv_cache.advance(seq_len);
    return 0;  // no logit token for embedding models
}

Stage 5: Logit Projection (Last Position Only)

Here is where prefill gets clever. We have batch_residual of shape [seq_len, dim], but we only need logits for the last token. Instead of running a full GEMM ([seq_len, dim] @ [vocab_size, dim]^T), Akunu extracts just the last row and runs a GEMV:

int last_row_offset = (seq_len - 1) * dim * 2;  // byte offset
device.set_buffer(scratch.batch_residual, last_row_offset, 0);
device.set_buffer(logit_w, 0, 1);
device.set_buffer(scratch.logits, 0, 2);
device.dispatch(Dim3(n_groups), Dim3(logit_dt.gemv_tg_size));

The last_row_offset is used as a buffer offset, so the GEMV kernel sees only the last row’s data starting at buffer index 0. This converts a potentially massive [seq_len, vocab_size] GEMM into a single [1, vocab_size] GEMV – a huge savings when vocab_size is 128K+ tokens.

For models with tied embeddings (where the output projection reuses the embedding table), the same weight buffer is used:

const char *logit_name = arch.tie_embeddings
    ? "token_embedding.weight"
    : "output.weight";

Stage 6: Argmax

The final step finds the most probable next token:

Pipeline pso = device.get_pipeline("argmax_f16");
device.set_pipeline(pso);
device.set_buffer(scratch.logits, 0, 0);
device.set_buffer(scratch.token_ids, 0, 1);
uint32_t vocab = cfg.vocab_size;
device.set_bytes(&vocab, sizeof(vocab), 2);
device.dispatch(Dim3(1), Dim3(1024));

One threadgroup of 1024 threads. Each thread scans a strided portion of the logits, finds its local maximum, then a two-level SIMD reduction finds the global maximum. The winning token ID is written to scratch.token_ids[0].

After end_encoding_sync() (which waits for GPU completion), the result is read back:

device.end_encoding_sync();
kv_cache.advance(seq_len);
return ((uint32_t *)scratch.token_ids.contents)[0];

Note kv_cache.advance(seq_len): this updates the cache’s write pointer so that subsequent decode operations know where the cached K/V data ends.

Scratch Buffer Layout

Prefill requires a constellation of temporary GPU buffers. Here is what each one holds:

| Buffer          | Shape         | Used For                                 |
|-----------------|---------------|------------------------------------------|
| batch_h0        | [S, dim]      | Embedding output, residual stream        |
| batch_h1        | [S, dim]      | Residual stream (second copy)            |
| batch_residual  | [S, dim]      | Normed activations, GEMM input           |
| batch_q         | [S, q_dim]    | Query projections                        |
| batch_k         | [S, kv_dim]   | Key projections                          |
| batch_v         | [S, kv_dim]   | Value projections                        |
| batch_attn_out  | [S, q_dim]    | Attention output, FFN norm output        |
| batch_gate      | [S, ffn_dim]  | Gate projection, also used for fused QKV |
| batch_up        | [S, ffn_dim]  | Up projection                            |
| batch_act       | [S, ffn_dim]  | SiLU(gate) * up                          |
| batch_post_norm | [S, dim]      | Post-attn/FFN norm scratch               |
| logits          | [vocab_size]  | Final logit buffer                       |
| token_ids       | [max_chunk+1] | Input token IDs + argmax result          |

All buffers are allocated at model load time for S = max_prefill_chunk. The total memory is roughly:

S * (5*dim + 2*q_dim + 2*kv_dim + 3*ffn_dim) * 2 bytes

For a typical 7B model (dim=4096, q_dim=4096, kv_dim=1024, ffn_dim=14336) with S=4096, this comes to about 600 MB.

The BERT/Encoder Path

Akunu also supports encoder-only models (BERT, nomic-bert) through encode_prefill_bert. The key differences from the LLM path:

| Aspect              | LLM Prefill             | BERT Prefill                                  |
|---------------------|-------------------------|-----------------------------------------------|
| Positional encoding | RoPE (during attention) | Learned absolute embeddings (added to tokens) |
| Normalization       | RMSNorm                 | LayerNorm (with bias)                         |
| Attention mask      | Causal                  | Non-causal (all-to-all)                       |
| FFN activation      | SiLU/SwiGLU             | GELU (no gate)                                |
| Linear bias         | No                      | Yes                                           |
| Output              | Logits + argmax         | Raw hidden states for pooling                 |

The BERT path is activated when arch.is_embedding_model is true. Most of the code is shared – the layer loop is identical, just using different kernel variants selected by the descriptor system.

Timing and Statistics

Prefill timing is captured for performance reporting:

auto prefill_start = std::chrono::high_resolution_clock::now();
// ... prefill ...
auto prefill_end = std::chrono::high_resolution_clock::now();
double prefill_ms = std::chrono::duration<double, std::milli>(
    prefill_end - prefill_start).count();
stats.prefill_time_ms = (float)prefill_ms;
stats.prefill_tokens_per_sec = (float)(n_prompt * 1000.0 / prefill_ms);

Note that prefill_tokens_per_sec divides the total number of prompt tokens by the wall-clock time, including all chunks. This gives the user-facing throughput metric that represents actual prompt processing speed.

Performance Characteristics

Prefill performance is dominated by GEMM efficiency, which in turn depends on quantization format, chunk size, and GPU core count. Typical prefill throughput on Apple Silicon:

| Hardware          | Model        | Quant | Prompt Tokens/sec |
|-------------------|--------------|-------|-------------------|
| M1 Max (32 GPU)   | Llama 3.1 8B | Q4_0  | ~800-1200         |
| M2 Ultra (76 GPU) | Llama 3.1 8B | Q4_0  | ~2500-3500        |
| M4 Pro (20 GPU)   | Llama 3.1 8B | Q4_0  | ~1000-1500        |

These numbers are heavily dominated by GEMM throughput, which scales roughly linearly with GPU core count.

Summary

The prefill phase is conceptually simple – run the full transformer on all prompt tokens – but the implementation details matter enormously for performance:

  1. GEMM not GEMV: Processing all tokens simultaneously gives orders-of-magnitude better arithmetic intensity.
  2. Chunked execution: Bounded scratch memory via configurable chunk sizes.
  3. Last-position-only logits: A GEMV instead of a GEMM for the output projection, saving (seq_len - 1) * vocab_size unnecessary computations.
  4. Fused operations: RoPE + KV write, residual + norm, and activation + gating are all fused to minimize GPU dispatch overhead and memory traffic.
  5. Architecture-driven dispatch: The same code handles LLMs and BERT-style encoders through the descriptor system.

In the next chapter, we will see what happens after prefill completes: the decode loop that generates tokens one at a time.



  1. Vaswani, A., et al. “Attention Is All You Need.” NeurIPS 2017. The autoregressive property means each token depends on all previous tokens, which is why generation must be sequential while prompt processing can be parallel. See https://arxiv.org/abs/1706.03762.

  2. Pope, R., et al. “Efficiently Scaling Transformer Inference.” MLSys 2023. This paper provides an excellent analysis of the memory-bound vs compute-bound regimes of transformer inference. See https://arxiv.org/abs/2211.05102.

  3. This “last position only” optimization is standard in all LLM inference engines. During prefill, all positions produce hidden states, but only the last position’s logits determine the next token. Some engines (e.g., vLLM) allow returning all logits for perplexity evaluation.

  4. This is a common GPU programming pattern called “ping-pong buffering” or “buffer aliasing.” The Metal runtime does not enforce temporal aliasing rules for threadgroup memory, so as long as barriers are placed correctly, the same memory can serve different purposes at different times.

  5. The exp2 trick replaces exp(x) with exp2(x * log2(e)). On Apple Silicon, fast::exp2 uses the hardware transcendental unit and is faster than exp. The scale factor is pre-multiplied into Q during loading.