Kernel Architecture Overview
If you have made it this far in the book, you understand how akunu’s Swift host code
orchestrates inference – how buffers are allocated, how the command encoder sequences
dispatches, how the KV cache grows and shifts. But we have been treating each
computeEncoder.dispatchThreadgroups(...) call as a black box. Time to open every
single one of those boxes.
Akunu ships roughly 135 Metal shader files, spread across 12 subdirectories, each one a hand-tuned GPU program that runs on Apple Silicon. This chapter is the map. We will cover the directory layout, the shared header infrastructure, the parameter struct conventions, the function-constant specialization system, the buffer binding conventions, and finally a taxonomy of every kernel category. By the end you will know where to look for any particular kernel and why it was written the way it was.
33.1 The 10,000-Foot View: Directory Layout
All Metal source lives under a single tree:
backend/metal/kernels/
    ShaderTypes.h   <-- Parameter structs (shared with Swift)
    KernelCommon.h  <-- Block types, helpers, constants
    activation/     <-- 2 files (silu, gelu)
    attention/      <-- 6 files (decode, fast-decode, parallel, prefill, softmax, logit_cap)
    common/         <-- 6 files (bias_add, residual_add, vector_add, transpose, head_rearrange, qkv_split)
    conv/           <-- 1 file (conv1d_f16)
    convert/        <-- 12 files (dequant for every quant format + f16<->f32)
    embedding/      <-- 12 files (lookup for every quant format + bf16 + pos_embed_add)
    fused/          <-- 2 files (gemv_q8_0_head_rmsnorm, whisper_gemv_fused)
    kv_cache/       <-- 2 files (kv_cache_write, kv_cache_shift)
    matmul/         <-- 72 files (gemv, gemv_wide, gemv_batched, simd_gemm, simd_gemm_small)
    norm/           <-- 5 files (rmsnorm, rmsnorm_gemma, layernorm, residual_rmsnorm, head_rmsnorm)
    rope/           <-- 7 files (rope, rope_neox, fused rope+kv variants)
    sampling/       <-- 7 files (argmax, topk, temperature, repetition_penalty, gumbel, grammar, whisper_suppress)
Let us count that up:
+---------------------+-------+
| Category | Files |
+---------------------+-------+
| matmul (GEMV/GEMM) | 72 |
| embedding | 12 |
| convert (dequant) | 12 |
| sampling | 7 |
| rope | 7 |
| attention | 6 |
| common | 6 |
| norm | 5 |
| activation | 2 |
| fused | 2 |
| kv_cache | 2 |
| conv | 1 |
+---------------------+-------+
| Total | ~135 |
+---------------------+-------+
The overwhelming majority – 72 out of 135 – are in matmul/. That is not a
surprise. The single hottest operation in LLM inference is the matrix-vector
multiply (GEMV for decode) and matrix-matrix multiply (GEMM for prefill). When
you multiply those two kernel shapes (GEMV + GEMM) by the number of quantization
formats akunu supports (Q4_0, Q4_1, Q4_K, Q5_0, Q5_K, Q6_K, Q8_0, Q2_K, Q3_K,
F16, BF16, MLX-Q4, MLX-Q3, MLX-Q6, MLX-Q8), and then add the variant axes
(standard, wide, batched, small-M, fused-silu), you get a combinatorial explosion
of kernel files.
Here is how the matmul/ directory decomposes:
matmul/
    gemv_*.metal            (18 files) -- Single-row GEMV (M=1 decode)
    gemv_wide_*.metal       ( 7 files) -- Wide GEMV (large N, e.g. vocab projection)
    gemv_batched_*.metal    (12 files) -- Batched GEMV (M=2..16, speculative decode)
    simd_gemm_*.metal       (17 files) -- Tiled GEMM (M>=32, prefill)
    simd_gemm_small_*.metal (16 files) -- Small-M GEMM (M=2..8, small prefill)
    gemv_*_silu.metal       ( 5 files) -- Fused SiLU(gate)*up + GEMV
That is the landscape. Let us now zoom in on the two shared header files that every kernel includes.
33.2 ShaderTypes.h: The Contract Between Swift and Metal
Every Metal kernel receives its parameters via a constant buffer. On the Swift side,
that buffer is a MTLBuffer filled with a C struct. On the GPU side, the kernel
reads that same struct. The contract between the two is ShaderTypes.h.
There is a critical comment at the top of this file:
/*
* Shared type definitions used by both Metal kernels (.metal) and Swift host code.
*
* CRITICAL: Any change here MUST be mirrored in Sources/KernelStore/MetalTypes.swift.
* All structs are padded to 16-byte boundaries for Metal argument buffer alignment.
*/
Metal requires that constant buffer offsets and struct sizes be aligned to 16 bytes.
Every struct in ShaderTypes.h is manually padded with _pad fields to guarantee
this. Let us walk through every parameter struct.
GEMMParams (32 bytes)
struct GEMMParams {
    uint32_t M;      // Rows of A / rows of C
    uint32_t N;      // Columns of B / columns of C
    uint32_t K;      // Columns of A / rows of B
    uint32_t lda;    // Leading dimension of A
    uint32_t ldb;    // Leading dimension of B
    uint32_t ldc;    // Leading dimension of C
    float alpha;     // C = alpha * A @ B + beta * C
    float beta;
};
This is the workhorse parameter struct. It is used by every GEMV and GEMM kernel.
The M, N, K triple defines the matrix dimensions, and lda/ldb/ldc are
the leading dimensions (row strides), allowing matrices that are sub-views of larger
buffers. The alpha/beta fields enable BLAS-style C = alpha*A*B + beta*C
semantics – useful for residual connections and accumulation.
Here is how it maps onto the GEMV case (M=1):
x [1, K] @ W^T [N, K] --> y [1, N]

GEMMParams
+---+------+------+------+------+------+-----+-----+
| M |  N   |  K   | lda  | ldb  | ldc  |  a  |  b  |
| 1 | 4096 | 4096 | 4096 | 4096 | 4096 | 1.0 | 0.0 |
+---+------+------+------+------+------+-----+-----+
  0    4      8      12     16     20     24    28    byte offset
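To make that layout concrete, here is a hedged C sketch of the host-side mirror of this struct, with compile-time checks that the field offsets match the diagram. The field names follow the text; the real `MetalTypes.swift` mirror may differ in detail.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

typedef struct {
    uint32_t M, N, K;          /* matrix dimensions */
    uint32_t lda, ldb, ldc;    /* leading dimensions (row strides) */
    float    alpha, beta;      /* C = alpha * A @ B + beta * C */
} GEMMParams;

/* 8 x 4-byte fields = 32 bytes, a multiple of Metal's 16-byte alignment. */
_Static_assert(sizeof(GEMMParams) == 32, "GEMMParams must be 32 bytes");
_Static_assert(offsetof(GEMMParams, alpha) == 24, "alpha at byte 24");
_Static_assert(offsetof(GEMMParams, beta) == 28, "beta at byte 28");
```

Because the struct is exactly 32 bytes with no compiler-inserted padding, filling an `MTLBuffer` with its bytes on the Swift side and reading it in the kernel gives both sides the same view.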
AttentionParams (32 bytes)
struct AttentionParams {
    uint32_t seq_len;
    uint32_t kv_seq_len;
    uint32_t head_dim;
    uint32_t n_heads;
    uint32_t n_kv_heads;
    float scale;          // 1.0 / sqrt(head_dim)
    uint32_t kv_stride;   // elements between KV heads
    uint32_t q_stride;    // elements between Q/O rows
};
The attention params carry everything the flash-attention kernels need. The
kv_seq_len can differ from seq_len during decode (where seq_len=1 but
kv_seq_len could be thousands). The kv_stride and q_stride fields allow
flexible memory layouts – if zero, the kernel falls back to kv_seq_len * head_dim
or n_heads * head_dim respectively. This is what lets the same kernel handle both
contiguous and interleaved head layouts.
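The fallback rule can be sketched in a few lines of C (illustrative helper names, not from the real source):

```c
#include <stdint.h>

/* If kv_stride is zero, fall back to the contiguous layout
 * kv_seq_len * head_dim elements between KV heads. */
static uint32_t effective_kv_stride(uint32_t kv_stride,
                                    uint32_t kv_seq_len, uint32_t head_dim) {
    return kv_stride ? kv_stride : kv_seq_len * head_dim;
}

/* If q_stride is zero, fall back to n_heads * head_dim elements
 * between Q/O rows. */
static uint32_t effective_q_stride(uint32_t q_stride,
                                   uint32_t n_heads, uint32_t head_dim) {
    return q_stride ? q_stride : n_heads * head_dim;
}
```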
RMSNormParams (16 bytes)
struct RMSNormParams {
    uint32_t dim;
    float eps;
    uint32_t _pad0;
    uint32_t _pad1;
};
Minimal. Just the dimension and epsilon. The two pad fields bring it to exactly 16 bytes. The kernel figures out which row to process from its threadgroup position in the grid.
RoPEParams (32 bytes)
struct RoPEParams {
    uint32_t seq_len;
    uint32_t head_dim;
    uint32_t n_heads;
    uint32_t pos_offset;  // global position for decode step
    float theta;          // base frequency (default 10000.0)
    uint32_t row_stride;  // elements between rows
    uint32_t _pad0;
    uint32_t _pad1;
};
The pos_offset is the key field here – during decode, each token’s position is
pos_offset, not derived from the sequence index. The theta field (default
10000.0) is the RoPE base frequency, configurable per model.
MLXParams (32 bytes)
struct MLXParams {
    uint32_t M;
    uint32_t N;
    uint32_t K;
    uint32_t group_size;    // quantization group size (typically 64)
    uint32_t bits;          // bits per value (3, 4, 6, or 8)
    uint32_t weight_bytes;  // byte offset to scales section
    uint32_t _pad0;
    uint32_t _pad1;
};
MLX-format weights pack everything into a single contiguous buffer:
[packed_weights | scales | biases]. The weight_bytes field tells the kernel
where the scales section starts, and from there the biases follow immediately at
scales + N * (K / group_size). The bits field selects between 3-bit, 4-bit,
6-bit, and 8-bit dequantization paths.
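The section addressing can be sketched in C (a hedged sketch, assuming 2-byte f16 scale and bias entries; the helper name is illustrative):

```c
#include <stddef.h>
#include <stdint.h>

typedef struct {
    const uint8_t *scales;
    const uint8_t *biases;
} MLXSections;

/* Locate the scales and biases sections inside the packed
 * [packed_weights | scales | biases] buffer. */
static MLXSections mlx_sections(const uint8_t *buf, uint32_t N, uint32_t K,
                                uint32_t group_size, uint32_t weight_bytes) {
    MLXSections s;
    s.scales = buf + weight_bytes;                         /* scales start here */
    s.biases = s.scales + (size_t)N * (K / group_size) * 2; /* f16 = 2 bytes each */
    return s;
}
```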
Fused Parameter Structs
Akunu has several fused kernels that combine two operations into one dispatch. Each has its own parameter struct:
RoPEQKVWriteParams (32 bytes) -- Fused Q/K-RoPE + KV cache write
KVCacheWriteParams (32 bytes) -- KV cache write (separate K and V)
KVCacheShiftParams (32 bytes) -- KV cache left-shift (ring-buffer eviction)
GEMVHeadNormParams (32 bytes) -- Fused GEMV + per-head RMSNorm
GEMVKVParams (32 bytes) -- Fused GEMV + KV cache write
HeadNormParams (32 bytes) -- Per-head RMSNorm (standalone)
Conv1DParams (32 bytes) -- Conv1D parameters (Whisper)
And simpler ones:
ElementwiseParams (16 bytes) -- Just a count
SoftmaxParams (16 bytes) -- rows, cols
EmbeddingParams (16 bytes) -- num_tokens, dim
LayerNormParams (16 bytes) -- dim, eps
TemperatureScaleParams (16 bytes) -- inv_temperature, count
RepetitionPenaltyParams (16 bytes) -- penalty, n_tokens
The pattern is consistent: every struct is either 16 or 32 bytes, always 16-byte aligned, always with explicit padding.
33.3 KernelCommon.h: Shared Infrastructure
Every .metal file includes KernelCommon.h. This header defines:
Hardware Constants
constant constexpr uint SIMD_WIDTH = 32;
constant constexpr uint MAX_TG_MEMORY = 32768; // 32 KB
constant constexpr uint SIMD_TILE = 8; // native simdgroup_matrix dimension
Apple Silicon GPUs have a SIMD width of 32 threads (the same as NVIDIA's warp size of 32, unlike AMD's wavefront of 64). The SIMD_TILE = 8 is the native dimension of Apple's simdgroup_matrix operations – all simdgroup matrix operations work on 8x8 tiles.
GEMM Tiling Constants
constant constexpr uint TILE_M = 64;
constant constexpr uint TILE_N = 64;
constant constexpr uint TILE_K = 32;
constant constexpr uint GEMM_TG_WIDTH = 32;
constant constexpr uint GEMM_TG_HEIGHT = 4;
constant constexpr uint GEMM_TG_SIZE = 128; // 4 SIMD groups
The make_uniform() Helper
inline int make_uniform(int val) {
    return simd_broadcast_first(val);
}
This is a surprisingly important optimization. When you write for (int i = 0; i < N; i++),
the Metal compiler does not know whether N is the same across all threads. If it
might differ, the compiler must generate code that handles per-lane divergence. By
wrapping the loop bound in make_uniform(), you explicitly tell the compiler "this
value is identical across all threads in the SIMD group," letting it emit a single
uniform branch and skip the divergence handling entirely.
You will see make_uniform() wrapped around virtually every loop bound in every
kernel.
Quantized Block Types
#define QK4_0 32   // elements per Q4_0 block
#define QK8_0 32   // elements per Q8_0 block
#define QK_K 256   // elements per K-quant superblock

struct block_q4_0 {
    half d;                  // scale factor (2 bytes)
    uint8_t qs[QK4_0 / 2];   // 16 bytes of nibble-packed values
};                           // Total: 18 bytes per 32 elements = 4.5 bits/element

struct block_q8_0 {
    half d;                  // scale factor (2 bytes)
    int8_t qs[QK8_0];        // 32 bytes of 8-bit values
};                           // Total: 34 bytes per 32 elements = 8.5 bits/element

struct block_q4_K {
    half d;                        // super-block scale (2 bytes)
    half dmin;                     // super-block min scale (2 bytes)
    uint8_t scales[K_SCALE_SIZE];  // 12 bytes of packed 6-bit scales
    uint8_t qs[QK_K / 2];          // 128 bytes of nibble-packed values
};                                 // Total: 144 bytes per 256 elements = 4.5 bits/element
Here is a visual layout of block_q4_0:
block_q4_0: 18 bytes total, 32 elements
+---------+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
| d (f16)|q0|q1|q2|q3|q4|q5|q6|q7|q8|q9|qA|qB|qC|qD|qE|qF|
+---------+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
| 2 bytes | 16 bytes |
+---------+-------------------------------------------------+
Each byte qs[j] holds two 4-bit values:
element j = (qs[j] & 0x0F) - 8 (low nibble)
element j + 16 = (qs[j] >> 4) - 8 (high nibble)
Dequantization: value = d * (nibble - 8)
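The unpacking rule above runs fine on the CPU as a reference; here is a sketch using float in place of Metal's half for the scale (the nibble packing follows the diagram exactly):

```c
#include <stdint.h>

#define QK4_0 32

typedef struct {
    float   d;              /* scale (half in the real GPU block) */
    uint8_t qs[QK4_0 / 2];  /* 16 bytes of nibble-packed values */
} block_q4_0_f;

/* Low nibbles hold elements 0..15, high nibbles hold elements 16..31,
 * each biased by 8: value = d * (nibble - 8). */
static void dequant_q4_0(const block_q4_0_f *b, float out[QK4_0]) {
    for (int j = 0; j < QK4_0 / 2; j++) {
        out[j]      = b->d * (float)((b->qs[j] & 0x0F) - 8); /* low nibble  */
        out[j + 16] = b->d * (float)((b->qs[j] >> 4) - 8);   /* high nibble */
    }
}
```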
And block_q4_K, the “K-quant” superblock:
block_q4_K: 144 bytes total, 256 elements
+---------+---------+------------+-----------------------------------+
| d (f16) | dmin    | scales[12] | qs[128]                           |
| 2 bytes | 2 bytes | 12 bytes   | 128 bytes                         |
+---------+---------+------------+-----------------------------------+
The 256 elements are divided into 8 sub-blocks of 32 elements each.
Each sub-block has its own 6-bit scale and 6-bit min, packed into
the 12-byte scales array.
Dequantization:
value = d * sub_scale * nibble - dmin * sub_min
Threadgroup Reduction Helpers
inline float tg_reduce_sum(float val, uint sgid, uint slid,
uint tg_size, threadgroup float *shared);
inline float tg_reduce_max(float val, uint sgid, uint slid,
uint tg_size, threadgroup float *shared);
These implement the classic two-phase reduction pattern:
Phase 1: simd_sum() within each SIMD group
Phase 2: Write lane-0 results to shared[], barrier, then simd_sum() in SG 0
Example: 256 threads = 8 SIMD groups
SG0: simd_sum → shared[0]
SG1: simd_sum → shared[1]
...
SG7: simd_sum → shared[7]
--- barrier ---
SG0: reads shared[0..7], simd_sum → shared[0]
--- barrier ---
All threads read shared[0]
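The dataflow of that diagram can be simulated on the CPU. In this hedged sketch the sequential loops stand in for simd_sum and the threadgroup barrier, and the helper name is illustrative:

```c
#define SIMD_WIDTH 32

/* CPU simulation of the two-phase threadgroup sum:
 * phase 1 reduces each 32-lane SIMD group, phase 2 reduces the
 * per-group partials (done by SG 0 on the GPU, after a barrier). */
static float tg_reduce_sum_sim(const float *vals, int tg_size) {
    float shared[32] = {0};                  /* one slot per SIMD group */
    int n_groups = tg_size / SIMD_WIDTH;
    for (int sg = 0; sg < n_groups; sg++) {  /* phase 1 */
        float acc = 0.0f;
        for (int lane = 0; lane < SIMD_WIDTH; lane++)
            acc += vals[sg * SIMD_WIDTH + lane];
        shared[sg] = acc;                    /* lane 0 writes its partial */
    }
    float total = 0.0f;                      /* phase 2 */
    for (int sg = 0; sg < n_groups; sg++)
        total += shared[sg];
    return total;
}
```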
Activation Functions
inline half act_silu(half x) {
    return x / (half(1) + exp(-x));
}

inline float act_gelu_f32(half x) {
    float xf = float(x);
    constexpr float SQRT_2_OVER_PI = 0.7978845608f;
    constexpr float GELU_COEF_A = 0.044715f;
    return 0.5f * xf * (1.0f + precise::tanh(SQRT_2_OVER_PI * xf * (1.0f + GELU_COEF_A * xf * xf)));
}
Note that act_gelu_f32 returns float, not half. The comment explains why:
“returns float to avoid F16 overflow in gate*up multiplication.” When you compute
gelu(gate) * up, both operands can be large in FP16, and the product can overflow.
Computing in FP32 avoids this.
33.4 Function Constant Specialization
Metal’s function constants are compile-time values that can vary between pipeline states. Think of them as C++ template parameters that you set at pipeline creation time, not at shader compile time. The Metal compiler generates specialized variants with the constants inlined, enabling dead-code elimination, loop unrolling, and strength reduction.
Akunu uses function constants extensively. Here are the key ones:
FC_HEAD_DIM [[function_constant(0)]] -- Attention head dimension (64/128/256)
FC_GROUP_SIZE [[function_constant(0)]] -- MLX quantization group size
FC_K_DIM [[function_constant(1)]] -- MLX K dimension
FC_NQ_ROWS [[function_constant(2)]] -- Adaptive query rows for prefill
FC_NON_CAUSAL [[function_constant(3)]] -- Non-causal attention flag
FC_GEMM_K [[function_constant(10)]] -- GEMM K dimension
FC_BATCH_M [[function_constant(11)]] -- Batched GEMV M dimension
The pattern for using them looks like this:
constant uint FC_HEAD_DIM [[function_constant(0)]];
constant bool FC_ATTN_SPECIALIZED = is_function_constant_defined(FC_HEAD_DIM);
// In the kernel:
const uint head_dim = FC_ATTN_SPECIALIZED ? FC_HEAD_DIM : params.head_dim;
When the host creates a pipeline state with FC_HEAD_DIM = 128, the compiler
replaces head_dim with the literal 128 everywhere, enabling:
- Loop unrolling: for (uint d = 0; d < head_dim; d += 8) becomes a known-count loop
- Dead code elimination: if (head_dim == 256) branches can be removed
- Register allocation: the compiler knows exactly how many registers to allocate
- Shift/mask optimization: idx / head_dim becomes idx >> 7 for head_dim=128
This is why you see patterns like:
const uint hd_shift = (head_dim == 128) ? 7 : (head_dim == 64) ? 6 : (head_dim == 256) ? 8 : 0;
const uint hd_mask = head_dim - 1;
When FC_HEAD_DIM = 128, the compiler collapses this to hd_shift = 7 and
hd_mask = 127, and every subsequent idx / head_dim becomes idx >> 7.
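The strength reduction is straightforward to verify on the CPU: for any power-of-two head_dim, division is a right shift and modulo is a mask.

```c
#include <stdint.h>

/* Same ternary chain as the kernel snippet above. */
static uint32_t hd_shift_for(uint32_t head_dim) {
    return head_dim == 128 ? 7 : head_dim == 64 ? 6
         : head_dim == 256 ? 8 : 0;
}
```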
33.5 Buffer Binding Convention
Every kernel follows a consistent buffer binding pattern:
Buffer 0: Input data (activations, Q, token_ids, etc.)
Buffer 1: Weights (W, K, packed MLX buffer, etc.)
Buffer 2: Output (y, O, etc.)
Buffer 3: Parameters (GEMMParams, AttentionParams, etc.)
Buffer 4+: Optional (V for attention, tree_mask, etc.)
For matmul kernels specifically:
buffer(0) = x -- activation vector [K] or [M, K]
buffer(1) = W -- weight matrix [N, K] (possibly quantized)
buffer(2) = y -- output vector [N] or [M, N]
buffer(3) = params -- GEMMParams struct
For attention kernels:
buffer(0) = Q -- query [n_heads, seq_len, head_dim]
buffer(1) = K -- key [n_kv_heads, kv_seq_len, head_dim]
buffer(2) = V -- value [n_kv_heads, kv_seq_len, head_dim]
buffer(3) = O -- output [n_heads, seq_len, head_dim]
buffer(4) = params -- AttentionParams struct
buffer(5) = tree_mask -- optional speculative decoding mask
For fused kernels that combine two operations (like gemv_q4_0_silu), the input
buffer splits:
buffer(0) = gate -- gate projection output
buffer(1) = up -- up projection output
buffer(2) = W -- down projection weights
buffer(3) = y -- output
buffer(4) = params -- GEMMParams
This convention is enforced by the Swift host code in KernelStore, which maps
each buffer to its binding index at pipeline creation time.
33.6 Kernel Naming Conventions
The naming follows a strict {operation}_{quantformat}[_{variant}] pattern:
gemv_f16             -- GEMV, F16 weights, standard
gemv_q4_0            -- GEMV, Q4_0 weights, K < 2048
gemv_q4_0_l          -- GEMV, Q4_0 weights, K >= 2048 (large)
gemv_q4_0_silu       -- GEMV, Q4_0 weights, fused SiLU activation
gemv_wide_q4_0       -- GEMV, Q4_0 weights, wide variant (large N)
gemv_batched_q4_0    -- GEMV, Q4_0 weights, batched (M=2..16)
simd_gemm_q4_0       -- GEMM, Q4_0 weights, standard tile
simd_gemm_small_q4_0 -- GEMM, Q4_0 weights, small M tile
gemv_mlx_q4          -- GEMV, MLX 4-bit format
gemv_mlx_q4_l        -- GEMV, MLX 4-bit format, large (8 SGs)
The _l suffix denotes the "large K" variant. The _silu suffix denotes fused
SiLU activation. The _small infix in GEMM names denotes the TM=8 tile geometry
optimized for small batch sizes.
33.7 Kernel Category Deep Dive
Let us briefly survey each category before we dive deep in subsequent chapters.
Matmul (72 files)
This is the heart of akunu. Three major kernel families:
GEMV (M=1 decode): The single-token decode path. One activation vector multiplied by the weight matrix. This is memory-bandwidth-bound – you read the entire weight matrix for a single dot product per row. The key optimization is reading the weight once and computing multiple output rows per SIMD group (NR=4 typically).
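As a correctness reference, the math every GEMV kernel implements is just one dot product per output row: y[n] = dot(x, row n of W), with W stored row-major [N, K] per the binding convention. The NR=4 amortization changes the scheduling, not this math. A plain CPU sketch (real kernels read quantized W and dequantize in registers):

```c
#include <stddef.h>

/* Reference GEMV: y [N] = x [K] @ W^T, W row-major [N, K]. */
static void gemv_ref(const float *x, const float *W, float *y,
                     int N, int K) {
    for (int n = 0; n < N; n++) {
        float acc = 0.0f;
        for (int k = 0; k < K; k++)
            acc += x[k] * W[(size_t)n * K + k];
        y[n] = acc;
    }
}
```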
GEMM (M>=32 prefill): The multi-token prefill path. Uses Apple’s
simdgroup_multiply_accumulate for native 8x8 matrix operations. Tile geometry
is TM=32, TN=64, TK=32, with 4 SIMD groups cooperatively loading tiles into
threadgroup memory.
Batched GEMV (M=2..16): Bridges the gap between GEMV and GEMM. For speculative decoding or small batches, neither pure GEMV (reads weights M times) nor GEMM (wastes tile space when M<32) is optimal. Batched GEMV reads weights once and computes all M activation rows simultaneously.
Attention (6 files)
Four distinct flash-attention implementations, each tuned for a different scenario:
- Standard Decode (flash_attention.metal: flash_attention_decode_f16): Multi-SG, register-local V accumulation, 2 barriers per KV entry.
- Fast Decode (flash_attention_decode_fast.metal): Single SIMD group, zero threadgroup barriers, each thread handles head_dim/32 elements.
- Parallel Decode (flash_attention_decode_parallel.metal): 32 SIMD groups (1024 threads), strided KV access for maximum memory parallelism, cross-SG reduction.
- Prefill V2 (flash_attention_prefill_v2.metal): BQ=32 query rows, simdgroup register accumulators, exp2 trick, row-level reduce via simd_shuffle_xor.
Plus a standalone softmax kernel and a Gemma logit soft-capping kernel.
Norm (5 files)
rmsnorm.metal -- Standard RMSNorm: y = (x / rms) * weight
rmsnorm_gemma.metal -- Gemma variant: y = (x / rms) * (1 + weight)
layernorm.metal -- Full LayerNorm with mean subtraction
residual_rmsnorm.metal -- Fused residual add + RMSNorm
head_rmsnorm.metal -- Per-head RMSNorm (for architectures like Cohere)
All follow the same pattern: one threadgroup per row, threads stride over the dimension, two-phase reduction for sum-of-squares.
RoPE (7 files)
rope.metal               -- Standard rotary position embeddings
rope_neox.metal          -- GPT-NeoX interleaved RoPE
rope_kv_write.metal      -- Fused RoPE + KV cache write
rope_neox_kv_write.metal -- Fused NeoX RoPE + KV cache write
*_batch.metal            -- Batch (prefill) variants of the above
head_norm_rope_neox_kv_write.metal -- Triple fused: head norm + RoPE + KV write
The fused variants are critical for performance – they eliminate intermediate buffer writes between RoPE and KV cache insertion.
Embedding (12 files)
One embedding lookup kernel per quantization format. Each dequantizes on the fly
during the lookup, converting the quantized embedding table directly to FP16 output.
The MLX variants handle the packed [weights|scales|biases] buffer layout.
Sampling (7 files)
argmax.metal             -- Simple argmax (greedy decoding)
topk_select.metal        -- Top-K selection for sampling
temperature_scale.metal  -- Temperature scaling (logits *= 1/T)
repetition_penalty.metal -- Repetition penalty application
gumbel_topk.metal        -- Gumbel-max trick for stochastic top-K
grammar_bitmask.metal    -- Grammar-constrained decoding mask
whisper_suppress.metal   -- Whisper-specific token suppression
Convert/Dequant (12 files)
Standalone dequantization kernels that convert quantized buffers to FP16. These are used when a kernel does not have a native quantized variant, or for debugging.
Common Utilities (6 files)
Elementwise operations: bias_add, residual_add, vector_add, transpose,
head_rearrange (permute between [batch, seq, heads, dim] and
[batch, heads, seq, dim]), qkv_split (split a fused QKV projection output
into separate Q, K, V buffers).
33.8 The Threadgroup Geometry Taxonomy
One of the most confusing aspects of reading akunu’s kernels is that different kernel families use radically different threadgroup geometries. Here is a reference card:
+-------------------------------+--------+-------+--------+--------+
| Kernel | TG Size| # SGs | Rows/TG| Notes |
+-------------------------------+--------+-------+--------+--------+
| gemv_f16 | 128 | 4 | 16 | NR=4 |
| gemv_q4_0 (small K) | 128 | 4 | 16 | NQ=16 |
| gemv_q4_0_l (large K) | 256 | 8 | 32 | NQ=16 |
| gemv_q8_0 | 256 | 8 | 32 | NR=4 |
| gemv_q4_k | 256 | 8 | 16 | nr0=2 |
| gemv_wide_* | 256 | 8 | 64 | NCOLS=8|
| gemv_batched_* | 128/256| 4/8 | 16/32 | M<=16 |
| gemv_mlx_q4 | 128 | 4 | 16 | NR=4 |
| gemv_mlx_q4_l | 256 | 8 | 32 | NR=4 |
| simd_gemm_* | 128 | 4 | 32x64 | Tiled |
| simd_gemm_small_* | 128 | 4 | 8x64 | Tiled |
| flash_attention_decode | 128 | 4 | 1 head | per-TG |
| flash_attention_decode_fast | 32 | 1 | 1 head | no bar |
| flash_attention_decode_par | 1024 | 32 | 1 head | max BW |
| flash_attention_prefill_v2 | 128 | 4 | 32 Q | BQ=32 |
| rmsnorm, softmax | varies |varies | 1 row | per-TG |
+-------------------------------+--------+-------+--------+--------+
The pattern: GEMV kernels use 128-256 threads with 4-8 SIMD groups, each SG computing 2-8 output rows. GEMM kernels use 128 threads with 4 SGs in a 2D grid. Attention has three completely different geometries depending on the decode scenario.
33.9 Memory Access Patterns: Why It All Matters
Apple Silicon’s GPU shares unified memory with the CPU, but bandwidth is still the limiting factor for LLM inference. M2 Pro provides about 200 GB/s, M3 Max about 400 GB/s. A 7B model at Q4_0 is about 3.5 GB of weights. At 200 GB/s, you can read the entire model in ~17.5 ms, giving a theoretical ceiling of ~57 tokens/second for decode (one full weight read per token).
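That back-of-envelope estimate is just one division: weight bytes over bandwidth gives the minimum seconds per decode token, since each token reads the full weight set once. The numbers below are the chapter's own examples.

```c
/* Theoretical decode throughput ceiling, assuming one full weight
 * read per token and no other traffic. */
static double decode_ceiling_tok_s(double weight_gb, double bandwidth_gb_s) {
    double seconds_per_token = weight_gb / bandwidth_gb_s;
    return 1.0 / seconds_per_token;
}
```

For 3.5 GB of Q4_0 weights at 200 GB/s this gives about 57 tokens/second; doubling bandwidth to 400 GB/s doubles the ceiling.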
This means every wasted byte of memory bandwidth directly costs throughput. The kernels are designed around three principles:
1. Read weights once, compute multiple outputs. GEMV kernels process NR=4 output rows per SIMD group, amortizing the activation vector read.
2. Vectorized loads. Using half4 (8 bytes) or float4 (16 bytes) loads instead of scalar loads gives 4x-8x better memory throughput.
3. Dequantize on-the-fly. Never materialize the full FP16 weight matrix. Read the quantized blocks, dequantize in registers, multiply, accumulate.
The next three chapters will dive deep into the actual implementations: GEMV (Chapter 34), GEMM (Chapter 35), and FlashAttention (Chapter 36).
33.10 How a Single Inference Step Maps to Kernels
To tie it all together, here is the kernel sequence for a single decode step of a typical Llama-style model:
1. embedding_lookup_q4_0 -- Token embedding dequant + lookup
2. For each transformer layer:
a. rmsnorm_f16 -- Attention norm
b. gemv_q4_0 (x3) -- Q, K, V projections
c. rope_neox_kv_write -- Fused RoPE + KV cache write
d. flash_attention_decode_* -- Attention (variant depends on kv_seq_len)
e. gemv_q4_0 -- Output projection
f. residual_add -- Residual connection
g. rmsnorm_f16 -- FFN norm
h. gemv_q4_0 (x2) -- Gate and Up projections
i. silu -- SiLU activation (or fused into gemv_q4_0_silu)
j. gemv_q4_0 -- Down projection
k. residual_add -- Residual connection
3. rmsnorm_f16 -- Final norm
4. gemv_wide_q4_0 -- LM head (vocab projection, large N)
5. temperature_scale_f16 -- Temperature scaling
6. argmax / topk_select -- Sampling
That is roughly 13 kernel dispatches per layer, plus 4 for the head. For a
32-layer model, that is 420+ kernel dispatches per token. Each one is a
separate computeEncoder.dispatchThreadgroups() call. The overhead per dispatch
on Apple Silicon is about 1-3 microseconds, so dispatch overhead alone is
0.5-1.3 ms – which is why kernel fusion (like gemv_q4_0_silu and
rope_neox_kv_write) matters so much.
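The dispatch arithmetic above is worth making explicit (13 per layer assumes SiLU is fused into the gate/up GEMV; the helper name is illustrative):

```c
/* One embedding lookup, per_layer dispatches per transformer layer,
 * plus 4 for the final norm, LM head, temperature, and sampling. */
static int dispatches_per_token(int n_layers, int per_layer) {
    return 1 + n_layers * per_layer + 4;
}
```

For a 32-layer model this gives 421 dispatches, and at 3 microseconds each the overhead tops out near 1.3 ms per token.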
With this map in hand, let us dive into the actual kernel implementations. Chapter 34 starts with the GEMV kernels – the single hottest code path in decode.