Kernel Architecture Overview

If you have made it this far in the book, you understand how akunu’s Swift host code orchestrates inference – how buffers are allocated, how the command encoder sequences dispatches, how the KV cache grows and shifts. But we have been treating each computeEncoder.dispatchThreadgroups(...) call as a black box. Time to open every single one of those boxes.

Akunu ships roughly 135 Metal shader files, spread across 12 subdirectories, each one a hand-tuned GPU program that runs on Apple Silicon. This chapter is the map. We will cover the directory layout, the shared header infrastructure, the parameter struct conventions, the function-constant specialization system, the buffer binding conventions, and finally a taxonomy of every kernel category. By the end you will know where to look for any particular kernel and why it was written the way it was.


33.1 The 10,000-Foot View: Directory Layout

All Metal source lives under a single tree:

backend/metal/kernels/
  ShaderTypes.h            <-- Parameter structs (shared with Swift)
  KernelCommon.h           <-- Block types, helpers, constants
  metal/kernel/
    activation/            <-- 2 files   (silu, gelu)
    attention/             <-- 6 files   (decode, fast-decode, parallel, prefill, softmax, logit_cap)
    common/                <-- 6 files   (bias_add, residual_add, vector_add, transpose, head_rearrange, qkv_split)
    conv/                  <-- 1 file    (conv1d_f16)
    convert/               <-- 12 files  (dequant for every quant format + f16<->f32)
    embedding/             <-- 12 files  (lookup for every quant format + bf16 + pos_embed_add)
    fused/                 <-- 2 files   (gemv_q8_0_head_rmsnorm, whisper_gemv_fused)
    kv_cache/              <-- 2 files   (kv_cache_write, kv_cache_shift)
    matmul/                <-- 72 files  (gemv, gemv_wide, gemv_batched, simd_gemm, simd_gemm_small)
    norm/                  <-- 5 files   (rmsnorm, rmsnorm_gemma, layernorm, residual_rmsnorm, head_rmsnorm)
    rope/                  <-- 7 files   (rope, rope_neox, fused rope+kv variants)
    sampling/              <-- 7 files   (argmax, topk, temperature, repetition_penalty, gumbel, grammar, whisper_suppress)

Let us count that up:

+---------------------+-------+
| Category            | Files |
+---------------------+-------+
| matmul (GEMV/GEMM)  |    72 |
| embedding           |    12 |
| convert (dequant)   |    12 |
| sampling            |     7 |
| rope                |     7 |
| attention           |     6 |
| common              |     6 |
| norm                |     5 |
| activation          |     2 |
| fused               |     2 |
| kv_cache            |     2 |
| conv                |     1 |
+---------------------+-------+
| Total               |  ~135 |
+---------------------+-------+

The overwhelming majority – 72 out of 135 – are in matmul/. That is not a surprise. The single hottest operation in LLM inference is the matrix-vector multiply (GEMV for decode) and matrix-matrix multiply (GEMM for prefill). When you multiply those two kernel shapes (GEMV + GEMM) by the number of quantization formats akunu supports (Q4_0, Q4_1, Q4_K, Q5_0, Q5_K, Q6_K, Q8_0, Q2_K, Q3_K, F16, BF16, MLX-Q4, MLX-Q3, MLX-Q6, MLX-Q8), and then add the variant axes (standard, wide, batched, small-M, fused-silu), you get a combinatorial explosion of kernel files.

Here is how the matmul/ directory decomposes:

matmul/
  gemv_*.metal           (18 files)  -- Single-row GEMV (M=1 decode)
  gemv_wide_*.metal      ( 7 files)  -- Wide GEMV (large N, e.g. vocab projection)
  gemv_batched_*.metal   (12 files)  -- Batched GEMV (M=2..16, speculative decode)
  simd_gemm_*.metal      (17 files)  -- Tiled GEMM (M>=32, prefill)
  simd_gemm_small_*.metal(16 files)  -- Small-M GEMM (M=2..8, small prefill)
  gemv_*_silu.metal      ( 5 files)  -- Fused SiLU(gate)*up + GEMV

That is the landscape. Let us now zoom in on the two shared header files that every kernel includes.


33.2 ShaderTypes.h: The Contract Between Swift and Metal

Every Metal kernel receives its parameters via a constant buffer. On the Swift side, that buffer is a MTLBuffer filled with a C struct. On the GPU side, the kernel reads that same struct. The contract between the two is ShaderTypes.h.

There is a critical comment at the top of this file:

/*
 * Shared type definitions used by both Metal kernels (.metal) and Swift host code.
 *
 * CRITICAL: Any change here MUST be mirrored in Sources/KernelStore/MetalTypes.swift.
 * All structs are padded to 16-byte boundaries for Metal argument buffer alignment.
 */

Metal requires that constant buffer offsets and struct sizes be aligned to 16 bytes. Every struct in ShaderTypes.h is manually padded with _pad fields to guarantee this. Let us walk through every parameter struct.

GEMMParams (32 bytes)

struct GEMMParams {
    uint32_t M;      // Rows of A / rows of C
    uint32_t N;      // Columns of B / columns of C
    uint32_t K;      // Columns of A / rows of B
    uint32_t lda;    // Leading dimension of A
    uint32_t ldb;    // Leading dimension of B
    uint32_t ldc;    // Leading dimension of C
    float alpha;     // C = alpha * A @ B + beta * C
    float beta;
};

This is the workhorse parameter struct. It is used by every GEMV and GEMM kernel. The M, N, K triple defines the matrix dimensions, and lda/ldb/ldc are the leading dimensions (row strides), allowing matrices that are sub-views of larger buffers. The alpha/beta fields enable BLAS-style C = alpha*A*B + beta*C semantics – useful for residual connections and accumulation.

Here is how it maps onto the GEMV case (M=1):

    x [1, K] @ W[N, K]^T --> y [1, N]

    +------+------+------+------+------+------+------+------+   GEMMParams
    |  M   |  N   |  K   | lda  | ldb  | ldc  |alpha | beta |
    |  1   | 4096 | 4096 | 4096 | 4096 | 4096 | 1.0  | 0.0  |
    +------+------+------+------+------+------+------+------+
       0      4      8      12     16     20     24     28     byte offset
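
One way to keep both sides of the contract honest is to assert the layout at compile time on the host. The sketch below is a plain C++ mirror of GEMMParams and RMSNormParams, with the field order copied from the listings above. The real host mirror is Swift, so treat this as an illustration only. make_gemv_params is a hypothetical helper, and its choice of lda = ldb = K and ldc = N assumes dense row-major buffers.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Host-side mirror of GEMMParams (field order taken from ShaderTypes.h as quoted above).
struct GEMMParams {
    uint32_t M, N, K;
    uint32_t lda, ldb, ldc;
    float alpha, beta;
};

// Host-side mirror of RMSNormParams, including the explicit padding.
struct RMSNormParams {
    uint32_t dim;
    float eps;
    uint32_t _pad0, _pad1;
};

// If any of these fire, the host and GPU views of the buffer disagree.
static_assert(sizeof(GEMMParams) == 32, "GEMMParams must be 32 bytes");
static_assert(offsetof(GEMMParams, ldc) == 20, "ldc must sit at byte 20");
static_assert(offsetof(GEMMParams, alpha) == 24, "alpha must sit at byte 24");
static_assert(sizeof(RMSNormParams) == 16, "RMSNormParams must be 16 bytes");
static_assert(sizeof(GEMMParams) % 16 == 0, "size must be a multiple of 16");

// Hypothetical helper: parameters for the M=1 GEMV case x[1,K] @ W[N,K]^T.
// Assumes dense row-major storage, so the leading dimensions equal the row lengths.
inline GEMMParams make_gemv_params(uint32_t N, uint32_t K) {
    assert(N > 0 && K > 0);
    return GEMMParams{1u, N, K, /*lda=*/K, /*ldb=*/K, /*ldc=*/N, 1.0f, 0.0f};
}
```

With N = K = 4096 this reproduces the values in the diagram above.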

AttentionParams (32 bytes)

struct AttentionParams {
    uint32_t seq_len;
    uint32_t kv_seq_len;
    uint32_t head_dim;
    uint32_t n_heads;
    uint32_t n_kv_heads;
    float scale;           // 1.0 / sqrt(head_dim)
    uint32_t kv_stride;    // elements between KV heads
    uint32_t q_stride;     // elements between Q/O rows
};

The attention params carry everything the flash-attention kernels need. The kv_seq_len can differ from seq_len during decode (where seq_len=1 but kv_seq_len could be thousands). The kv_stride and q_stride fields allow flexible memory layouts – if zero, the kernel falls back to kv_seq_len * head_dim or n_heads * head_dim respectively. This is what lets the same kernel handle both contiguous and interleaved head layouts.
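
The zero-means-default rule for the two stride fields can be sketched as follows. This is a host-side C++ illustration of the fallback described above, not the kernel's actual code. The helper names effective_kv_stride and effective_q_stride are made up for this example.

```cpp
#include <cassert>
#include <cstdint>

// Mirror of AttentionParams as listed above.
struct AttentionParams {
    uint32_t seq_len;
    uint32_t kv_seq_len;
    uint32_t head_dim;
    uint32_t n_heads;
    uint32_t n_kv_heads;
    float scale;
    uint32_t kv_stride;
    uint32_t q_stride;
};

// Elements between consecutive KV heads: explicit stride if set, else the dense layout.
inline uint32_t effective_kv_stride(const AttentionParams& p) {
    assert(p.head_dim > 0);
    return p.kv_stride != 0 ? p.kv_stride : p.kv_seq_len * p.head_dim;
}

// Elements between consecutive Q/O rows: explicit stride if set, else n_heads * head_dim.
inline uint32_t effective_q_stride(const AttentionParams& p) {
    assert(p.head_dim > 0);
    return p.q_stride != 0 ? p.q_stride : p.n_heads * p.head_dim;
}
```

A cache preallocated for 4096 positions but holding only 2048 would set kv_stride to 4096 * head_dim, while a tightly packed cache can leave it at zero.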

RMSNormParams (16 bytes)

struct RMSNormParams {
    uint32_t dim;
    float eps;
    uint32_t _pad0;
    uint32_t _pad1;
};

Minimal. Just the dimension and epsilon. The two pad fields bring it to exactly 16 bytes. The kernel figures out which row to process from its threadgroup position in the grid.

RoPEParams (32 bytes)

struct RoPEParams {
    uint32_t seq_len;
    uint32_t head_dim;
    uint32_t n_heads;
    uint32_t pos_offset;    // global position for decode step
    float theta;            // base frequency (default 10000.0)
    uint32_t row_stride;    // elements between rows
    uint32_t _pad0;
    uint32_t _pad1;
};

The pos_offset is the key field here – during decode, each token’s position is pos_offset, not derived from the sequence index. The theta field (default 10000.0) is the RoPE base frequency, configurable per model.
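
To make the roles of pos_offset and theta concrete, here is a minimal C++ sketch of the standard RoPE angle formula, angle = pos * theta^(-2i / head_dim) for rotation pair i. The helper names are hypothetical, and which two elements form a pair is left to the caller, since that differs between RoPE variants.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// Rotation angle for pair index `pair` at absolute position `pos`.
// During decode, `pos` is simply RoPEParams::pos_offset.
inline float rope_angle(uint32_t pos, uint32_t pair, uint32_t head_dim, float theta) {
    assert(head_dim % 2 == 0 && pair < head_dim / 2);
    const float inv_freq = std::pow(theta, -2.0f * float(pair) / float(head_dim));
    return float(pos) * inv_freq;
}

// Apply the 2-D rotation to one (x0, x1) pair in place.
inline void rotate_pair(float& x0, float& x1, float angle) {
    const float c = std::cos(angle);
    const float s = std::sin(angle);
    const float r0 = x0 * c - x1 * s;
    const float r1 = x0 * s + x1 * c;
    x0 = r0;
    x1 = r1;
}
```

Pair 0 always rotates by exactly pos radians, and higher pairs rotate progressively more slowly, which is why a larger theta stretches the usable context.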

MLXParams (32 bytes)

struct MLXParams {
    uint32_t M;
    uint32_t N;
    uint32_t K;
    uint32_t group_size;     // quantization group size (typically 64)
    uint32_t bits;           // bits per value (3, 4, 6, or 8)
    uint32_t weight_bytes;   // byte offset to scales section
    uint32_t _pad0;
    uint32_t _pad1;
};

MLX-format weights pack everything into a single contiguous buffer: [packed_weights | scales | biases]. The weight_bytes field tells the kernel where the scales section starts, and from there the biases follow immediately at scales + N * (K / group_size). The bits field selects between 3-bit, 4-bit, 6-bit, and 8-bit dequantization paths.
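
The section offsets can be worked out with a few lines of arithmetic. The sketch below rests on two assumptions that the text does not confirm: that weights are tightly packed at exactly bits per value, and that scales and biases are stored as 2-byte (FP16) values. mlx_section_offsets is a hypothetical helper.

```cpp
#include <cassert>
#include <cstdint>

struct MLXSectionOffsets {
    uint64_t scales_byte_offset;  // corresponds to MLXParams::weight_bytes
    uint64_t biases_byte_offset;  // scales section start + groups * element size
    uint64_t groups;              // N * (K / group_size) scale/bias pairs
};

// Assumption: tightly packed weights and 2-byte scales/biases.
inline MLXSectionOffsets mlx_section_offsets(uint32_t N, uint32_t K,
                                             uint32_t group_size, uint32_t bits) {
    assert(group_size > 0 && K % group_size == 0);
    const uint64_t weight_bytes = uint64_t(N) * K * bits / 8;
    const uint64_t groups = uint64_t(N) * (K / group_size);
    const uint64_t scale_elem_bytes = 2;  // assumed FP16
    return {weight_bytes, weight_bytes + groups * scale_elem_bytes, groups};
}
```

For a 4096x4096 matrix at 4 bits with group size 64, the packed weights take 8 MiB and each of the scales and biases sections adds 262,144 entries.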

Fused Parameter Structs

Akunu has several fused kernels that combine two or more operations into one dispatch, alongside standalone KV-cache, per-head norm, and convolution kernels. Each has its own parameter struct:

RoPEQKVWriteParams (32 bytes)   -- Fused Q/K-RoPE + KV cache write
KVCacheWriteParams (32 bytes)   -- KV cache write (separate K and V)
KVCacheShiftParams (32 bytes)   -- KV cache left-shift (ring-buffer eviction)
GEMVHeadNormParams (32 bytes)   -- Fused GEMV + per-head RMSNorm
GEMVKVParams       (32 bytes)   -- Fused GEMV + KV cache write
HeadNormParams     (32 bytes)   -- Per-head RMSNorm (standalone)
Conv1DParams       (32 bytes)   -- Conv1D parameters (Whisper)

And simpler ones:

ElementwiseParams          (16 bytes)  -- Just a count
SoftmaxParams              (16 bytes)  -- rows, cols
EmbeddingParams            (16 bytes)  -- num_tokens, dim
LayerNormParams            (16 bytes)  -- dim, eps
TemperatureScaleParams     (16 bytes)  -- inv_temperature, count
RepetitionPenaltyParams    (16 bytes)  -- penalty, n_tokens

The pattern is consistent: every struct is either 16 or 32 bytes, always 16-byte aligned, always with explicit padding.


33.3 KernelCommon.h: Shared Infrastructure

Every .metal file includes KernelCommon.h. This header defines:

Hardware Constants

constant constexpr uint SIMD_WIDTH = 32;
constant constexpr uint MAX_TG_MEMORY = 32768;    // 32 KB
constant constexpr uint SIMD_TILE = 8;             // native simdgroup_matrix dimension

Apple Silicon GPUs have a SIMD width of 32 threads (the same as an NVIDIA warp, and half the 64-wide wavefront of AMD’s GCN GPUs). The SIMD_TILE = 8 is the native dimension of Apple’s simdgroup_matrix operations: all simdgroup matrix operations work on 8x8 tiles.

GEMM Tiling Constants

constant constexpr uint TILE_M = 64;
constant constexpr uint TILE_N = 64;
constant constexpr uint TILE_K = 32;
constant constexpr uint GEMM_TG_WIDTH = 32;
constant constexpr uint GEMM_TG_HEIGHT = 4;
constant constexpr uint GEMM_TG_SIZE = 128;       // 4 SIMD groups

The make_uniform() Helper

inline int make_uniform(int val) {
    return simd_broadcast_first(val);
}

This is a surprisingly important optimization. When you write for (int i = 0; i < N; i++), the Metal compiler cannot always prove that N is the same across all threads. If it might differ, the compiler must generate code that handles divergent loop exits. Because simd_broadcast_first() hands the first active lane’s value to every lane, wrapping the loop bound in make_uniform() explicitly tells the compiler “this value is identical across all threads in the SIMD group,” so it can treat the loop branch as uniform and skip per-lane divergence handling.

You will see make_uniform() wrapped around virtually every loop bound in every kernel.

Quantized Block Types

#define QK4_0  32     // elements per Q4_0 block
#define QK8_0  32     // elements per Q8_0 block
#define QK_K  256     // elements per K-quant superblock

struct block_q4_0 {
    half d;                   // scale factor (2 bytes)
    uint8_t qs[QK4_0 / 2];   // 16 bytes of nibble-packed values
};                            // Total: 18 bytes per 32 elements = 4.5 bits/element

struct block_q8_0 {
    half d;                   // scale factor (2 bytes)
    int8_t qs[QK8_0];        // 32 bytes of 8-bit values
};                            // Total: 34 bytes per 32 elements = 8.5 bits/element

struct block_q4_K {
    half d;                         // super-block scale (2 bytes)
    half dmin;                      // super-block min scale (2 bytes)
    uint8_t scales[K_SCALE_SIZE];   // 12 bytes of packed 6-bit scales
    uint8_t qs[QK_K / 2];          // 128 bytes of nibble-packed values
};                                  // Total: 144 bytes per 256 elements = 4.5 bits/element

Here is a visual layout of block_q4_0:

  block_q4_0: 18 bytes total, 32 elements
  +---------+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
  | d (f16) |q0|q1|q2|q3|q4|q5|q6|q7|q8|q9|qA|qB|qC|qD|qE|qF|
  +---------+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
  | 2 bytes |                   16 bytes                    |
  +---------+-----------------------------------------------+

  Each byte qs[j] holds two 4-bit values:
    element j      = (qs[j] & 0x0F) - 8     (low nibble)
    element j + 16 = (qs[j] >> 4) - 8       (high nibble)

  Dequantization: value = d * (nibble - 8)
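
The nibble layout above translates directly into a short CPU reference. This is a sketch for checking intuition, not the kernel itself: it uses float for the scale because standard C++ has no half type, and BlockQ4_0Host and dequantize_q4_0 are names invented for this example.

```cpp
#include <cassert>
#include <cstdint>

constexpr int QK4_0 = 32;  // elements per Q4_0 block

// Host stand-in for block_q4_0; `d` is float here instead of Metal's half.
struct BlockQ4_0Host {
    float d;
    uint8_t qs[QK4_0 / 2];
};

// Expand one block into 32 floats following the layout in the diagram:
// low nibble of qs[j] is element j, high nibble is element j + 16.
inline void dequantize_q4_0(const BlockQ4_0Host& b, float* out) {
    assert(out != nullptr);
    for (int j = 0; j < QK4_0 / 2; ++j) {
        const int lo = (b.qs[j] & 0x0F) - 8;
        const int hi = (b.qs[j] >> 4) - 8;
        out[j] = b.d * float(lo);
        out[j + QK4_0 / 2] = b.d * float(hi);
    }
}
```

The representable range per block is therefore -8*d to +7*d, which is why a single outlier in a block costs precision for its 31 neighbours.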

And block_q4_K, the “K-quant” superblock:

  block_q4_K: 144 bytes total, 256 elements
  +---------+---------+------------+---------------------------------------+
  | d (f16) |  dmin   | scales[12] |                qs[128]                |
  | 2 bytes | 2 bytes |  12 bytes  |               128 bytes               |
  +---------+---------+------------+---------------------------------------+

  The 256 elements are divided into 8 sub-blocks of 32 elements each.
  Each sub-block has its own 6-bit scale and 6-bit min, packed into
  the 12-byte scales array.

  Dequantization:
    value = d * sub_scale * nibble - dmin * sub_min
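
The 6-bit packing is the fiddly part. The sketch below assumes akunu follows the GGML/llama.cpp convention for the 12-byte scales array (the struct and field names suggest so, but the text does not confirm it): sub-blocks 0..3 keep their low 6 bits in bytes 0..7, and sub-blocks 4..7 are split between the nibbles of bytes 8..11 and the top two bits of bytes 0..7. The names unpack_scale_min and dequant_q4_k are hypothetical.

```cpp
#include <cassert>
#include <cstdint>

constexpr int K_SCALE_SIZE = 12;

struct ScaleMin {
    uint8_t scale;  // 6-bit sub-block scale
    uint8_t min;    // 6-bit sub-block min
};

// Unpack the scale/min pair for sub-block j (0..7).
// Assumption: same bit layout as GGML's Q4_K.
inline ScaleMin unpack_scale_min(int j, const uint8_t* q) {
    assert(j >= 0 && j < 8);
    if (j < 4) {
        return {uint8_t(q[j] & 63), uint8_t(q[j + 4] & 63)};
    }
    return {uint8_t((q[j + 4] & 0x0F) | ((q[j - 4] >> 6) << 4)),
            uint8_t((q[j + 4] >> 4) | ((q[j] >> 6) << 4))};
}

// The dequantization formula from the text, with the scales already unpacked.
inline float dequant_q4_k(float d, float dmin, ScaleMin sm, uint8_t nibble) {
    return d * float(sm.scale) * float(nibble) - dmin * float(sm.min);
}
```

Unlike Q4_0, the nibble here is unsigned (0..15) and the offset comes from the separate dmin * sub_min term.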

Threadgroup Reduction Helpers

inline float tg_reduce_sum(float val, uint sgid, uint slid,
                           uint tg_size, threadgroup float *shared);

inline float tg_reduce_max(float val, uint sgid, uint slid,
                           uint tg_size, threadgroup float *shared);

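These implement the classic two-phase reduction pattern (shown here for the sum; tg_reduce_max presumably follows the same shape with a max in place of the sum):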

  Phase 1: simd_sum() within each SIMD group
  Phase 2: Write lane-0 results to shared[], barrier, then simd_sum() in SG 0

  Example: 256 threads = 8 SIMD groups

  SG0: simd_sum → shared[0]
  SG1: simd_sum → shared[1]
  ...
  SG7: simd_sum → shared[7]
       --- barrier ---
  SG0: reads shared[0..7], simd_sum → shared[0]
       --- barrier ---
  All threads read shared[0]
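
The same two phases can be simulated on the CPU to see the data flow. This is a sequential C++ model of the pattern, not GPU code: the inner loop stands in for simd_sum(), the shared vector stands in for threadgroup memory, and two_phase_sum is a name made up for this sketch.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Sequential model of the two-phase threadgroup reduction.
// `lanes` holds one value per thread, laid out SIMD group by SIMD group.
inline float two_phase_sum(const std::vector<float>& lanes, std::size_t simd_width = 32) {
    assert(simd_width > 0 && lanes.size() % simd_width == 0);
    const std::size_t n_groups = lanes.size() / simd_width;
    assert(n_groups <= simd_width);  // phase 2 must fit in a single SIMD group

    // Phase 1: each SIMD group reduces its own lanes and lane 0 writes to shared[].
    std::vector<float> shared(n_groups, 0.0f);
    for (std::size_t g = 0; g < n_groups; ++g) {
        float partial = 0.0f;
        for (std::size_t lane = 0; lane < simd_width; ++lane) {
            partial += lanes[g * simd_width + lane];
        }
        shared[g] = partial;
    }

    // (threadgroup barrier would go here)

    // Phase 2: SIMD group 0 reduces the per-group partials.
    float total = 0.0f;
    for (std::size_t g = 0; g < n_groups; ++g) {
        total += shared[g];
    }
    return total;
}
```

The n_groups <= simd_width check mirrors a real constraint: with 32-wide SIMD groups the second phase can absorb at most 32 partials, which matches a 1024-thread threadgroup.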

Activation Functions

inline half act_silu(half x) {
    return x / (half(1) + exp(-x));
}

inline float act_gelu_f32(half x) {
    float xf = float(x);
    constexpr float SQRT_2_OVER_PI = 0.7978845608f;
    constexpr float GELU_COEF_A = 0.044715f;
    return 0.5f * xf * (1.0f + precise::tanh(SQRT_2_OVER_PI * xf * (1.0f + GELU_COEF_A * xf * xf)));
}

Note that act_gelu_f32 returns float, not half. The comment explains why: “returns float to avoid F16 overflow in gate*up multiplication.” When you compute gelu(gate) * up, both operands can be large in FP16, and the product can overflow. Computing in FP32 avoids this.
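
The overflow is easy to reproduce with ordinary floats. The sketch below ports the two activation formulas to C++ and checks a product against the largest finite FP16 value, 65504. The function names are hypothetical, and std::tanh stands in for Metal's precise::tanh.

```cpp
#include <cassert>
#include <cmath>

// SiLU, same formula as act_silu but in float.
inline float silu_ref(float x) {
    return x / (1.0f + std::exp(-x));
}

// Tanh-approximated GELU, same constants as act_gelu_f32.
inline float gelu_tanh_ref(float x) {
    constexpr float SQRT_2_OVER_PI = 0.7978845608f;
    constexpr float GELU_COEF_A = 0.044715f;
    return 0.5f * x * (1.0f + std::tanh(SQRT_2_OVER_PI * x * (1.0f + GELU_COEF_A * x * x)));
}

constexpr float HALF_MAX = 65504.0f;  // largest finite FP16 value

// True when gelu(gate) * up would not fit in FP16.
inline bool gate_up_overflows_fp16(float gate, float up) {
    assert(std::isfinite(gate) && std::isfinite(up));
    return std::fabs(gelu_tanh_ref(gate) * up) > HALF_MAX;
}
```

A gate and up value of 300 each are individually fine in FP16, yet their product is about 90,000, well past the FP16 limit.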


33.4 Function Constant Specialization

Metal’s function constants are compile-time values that can vary between pipeline states. Think of them as C++ template parameters that you set at pipeline creation time, not at shader compile time. The Metal compiler generates specialized variants with the constants inlined, enabling dead-code elimination, loop unrolling, and strength reduction.

Akunu uses function constants extensively. Here are the key ones:

FC_HEAD_DIM       [[function_constant(0)]]   -- Attention head dimension (64/128/256)
FC_GROUP_SIZE     [[function_constant(0)]]   -- MLX quantization group size
FC_K_DIM          [[function_constant(1)]]   -- MLX K dimension
FC_NQ_ROWS        [[function_constant(2)]]   -- Adaptive query rows for prefill
FC_NON_CAUSAL     [[function_constant(3)]]   -- Non-causal attention flag
FC_GEMM_K         [[function_constant(10)]]  -- GEMM K dimension
FC_BATCH_M        [[function_constant(11)]]  -- Batched GEMV M dimension

The pattern for using them looks like this:

constant uint FC_HEAD_DIM [[function_constant(0)]];
constant bool FC_ATTN_SPECIALIZED = is_function_constant_defined(FC_HEAD_DIM);

// In the kernel:
const uint head_dim = FC_ATTN_SPECIALIZED ? FC_HEAD_DIM : params.head_dim;

When the host creates a pipeline state with FC_HEAD_DIM = 128, the compiler replaces head_dim with the literal 128 everywhere, enabling:

  • Loop unrolling: for (uint d = 0; d < head_dim; d += 8) becomes a known-count loop
  • Dead code elimination: if (head_dim == 256) branches can be removed
  • Register allocation: the compiler knows exactly how many registers to allocate
  • Shift/mask optimization: idx / head_dim becomes idx >> 7 for head_dim=128

This is why you see patterns like:

const uint hd_shift = (head_dim == 128) ? 7 : (head_dim == 64) ? 6 : (head_dim == 256) ? 8 : 0;
const uint hd_mask  = head_dim - 1;

When FC_HEAD_DIM = 128, the compiler collapses this to hd_shift = 7 and hd_mask = 127, and every subsequent idx / head_dim becomes idx >> 7.
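
C++ templates give a fair analogy for what the specialized pipeline looks like after constant folding. In this sketch the template parameter plays the role of FC_HEAD_DIM; HeadIndex is a hypothetical name, and the static_asserts only show that shift and mask agree with division and modulo for power-of-two head dimensions.

```cpp
#include <cstdint>

// Compile-time analogue of a pipeline specialized on FC_HEAD_DIM.
template <uint32_t HEAD_DIM>
struct HeadIndex {
    static_assert(HEAD_DIM == 64 || HEAD_DIM == 128 || HEAD_DIM == 256,
                  "sketch covers the three head dimensions named in the text");

    static constexpr uint32_t shift = (HEAD_DIM == 128) ? 7 : (HEAD_DIM == 64) ? 6 : 8;
    static constexpr uint32_t mask = HEAD_DIM - 1;

    // idx / HEAD_DIM, computed as a shift.
    static constexpr uint32_t head(uint32_t idx) { return idx >> shift; }
    // idx % HEAD_DIM, computed as a mask.
    static constexpr uint32_t lane(uint32_t idx) { return idx & mask; }
};

static_assert(HeadIndex<128>::head(1000) == 1000 / 128, "shift matches division");
static_assert(HeadIndex<128>::lane(1000) == 1000 % 128, "mask matches modulo");
static_assert(HeadIndex<64>::shift == 6 && HeadIndex<256>::shift == 8, "shift table");
```

The unspecialized fallback, where head_dim comes from params at run time, cannot do this folding and has to emit a real integer divide.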


33.5 Buffer Binding Convention

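Most kernels follow a consistent baseline buffer binding pattern. Kernels that take extra input tensors (attention, fused kernels) insert them before the output, which shifts the later indices up, as the per-family listings below show:

Buffer 0:  Input data     (activations, Q, token_ids, etc.)
Buffer 1:  Weights        (W, K, packed MLX buffer, etc.)
Buffer 2:  Output         (y, etc.)
Buffer 3:  Parameters     (GEMMParams, RMSNormParams, etc.)
Buffer 4+: Optional       (tree_mask, etc.)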

For matmul kernels specifically:

  buffer(0) = x         -- activation vector [K] or [M, K]
  buffer(1) = W         -- weight matrix [N, K] (possibly quantized)
  buffer(2) = y         -- output vector [N] or [M, N]
  buffer(3) = params    -- GEMMParams struct

For attention kernels:

  buffer(0) = Q         -- query   [n_heads, seq_len, head_dim]
  buffer(1) = K         -- key     [n_kv_heads, kv_seq_len, head_dim]
  buffer(2) = V         -- value   [n_kv_heads, kv_seq_len, head_dim]
  buffer(3) = O         -- output  [n_heads, seq_len, head_dim]
  buffer(4) = params    -- AttentionParams struct
  buffer(5) = tree_mask -- optional speculative decoding mask

For fused kernels that combine two operations (like gemv_q4_0_silu), the input buffer splits:

  buffer(0) = gate      -- gate projection output
  buffer(1) = up        -- up projection output
  buffer(2) = W         -- down projection weights
  buffer(3) = y         -- output
  buffer(4) = params    -- GEMMParams

This convention is enforced by the Swift host code in KernelStore, which maps each buffer to its binding index at pipeline creation time.


33.6 Kernel Naming Conventions

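The naming follows a {operation}[_{shape}]_{quantformat}[_{variant}] pattern, where the optional shape qualifier (wide, batched, small) sits between the operation and the quantization format: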

  gemv_f16            -- GEMV, F16 weights, standard
  gemv_q4_0           -- GEMV, Q4_0 weights, K < 2048
  gemv_q4_0_l         -- GEMV, Q4_0 weights, K >= 2048 (large)
  gemv_q4_0_silu      -- GEMV, Q4_0 weights, fused SiLU activation
  gemv_wide_q4_0      -- GEMV, Q4_0 weights, wide variant (large N)
  gemv_batched_q4_0   -- GEMV, Q4_0 weights, batched (M=2..16)
  simd_gemm_q4_0      -- GEMM, Q4_0 weights, standard tile
  simd_gemm_small_q4_0 -- GEMM, Q4_0 weights, small M tile
  gemv_mlx_q4         -- GEMV, MLX 4-bit format
  gemv_mlx_q4_l       -- GEMV, MLX 4-bit format, large (8 SGs)

The _l suffix denotes the “large K” variant. The _silu suffix denotes fused SiLU activation. The small qualifier in the GEMM names denotes the TM=8 tile geometry optimized for small batch sizes.
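
To make the naming scheme concrete, here is a small C++ sketch that assembles names from their parts. Both helpers are hypothetical and are not akunu's actual lookup code. The only rule taken from the text is the K >= 2048 threshold for the _l variant of gemv_q4_0.

```cpp
#include <cassert>
#include <cstdint>
#include <string>

// Hypothetical helper: join operation, optional shape qualifier, quant format,
// and optional variant suffix into a kernel name.
inline std::string compose_kernel_name(const std::string& op, const std::string& shape,
                                       const std::string& quant, const std::string& suffix) {
    assert(!op.empty() && !quant.empty());
    std::string name = op;
    if (!shape.empty()) name += "_" + shape;
    name += "_" + quant;
    if (!suffix.empty()) name += "_" + suffix;
    return name;
}

// Hypothetical helper: pick the Q4_0 GEMV variant from K, using the
// K < 2048 / K >= 2048 split listed above.
inline std::string gemv_q4_0_kernel_for_k(uint32_t K) {
    return compose_kernel_name("gemv", "", "q4_0", K >= 2048 ? "l" : "");
}
```

For example, compose_kernel_name("simd_gemm", "small", "q4_0", "") yields simd_gemm_small_q4_0, and a 4096-wide hidden dimension selects gemv_q4_0_l.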


33.7 Kernel Category Deep Dive

Let us briefly survey each category before we dive deep in subsequent chapters.

Matmul (72 files)

This is the heart of akunu. Three major kernel families:

GEMV (M=1 decode): The single-token decode path. One activation vector multiplied by the weight matrix. This is memory-bandwidth-bound: every weight is read once and used for a single multiply-accumulate. The key optimization is computing multiple output rows per SIMD group (NR=4 typically), so each chunk of the activation vector is loaded once and reused across those rows.

GEMM (M>=32 prefill): The multi-token prefill path. Uses Apple’s simdgroup_multiply_accumulate for native 8x8 matrix operations. Tile geometry is TM=32, TN=64, NK=32, with 4 SIMD groups cooperatively loading tiles into threadgroup memory.

Batched GEMV (M=2..16): Bridges the gap between GEMV and GEMM. For speculative decoding or small batches, neither pure GEMV (reads weights M times) nor GEMM (wastes tile space when M<32) is optimal. Batched GEMV reads weights once and computes all M activation rows simultaneously.

Attention (6 files)

Four distinct flash-attention implementations, each tuned for a different scenario:

  1. Standard Decode (flash_attention.metal: flash_attention_decode_f16): Multi-SG, register-local V accumulation, 2 barriers per KV entry.

  2. Fast Decode (flash_attention_decode_fast.metal): Single SIMD group, zero threadgroup barriers, each thread handles head_dim/32 elements.

  3. Parallel Decode (flash_attention_decode_parallel.metal): 32 SIMD groups (1024 threads), strided KV access for maximum memory parallelism, cross-SG reduction.

  4. Prefill V2 (flash_attention_prefill_v2.metal): BQ=32 query rows, simdgroup register accumulators, exp2 trick, row-level reduce via simd_shuffle_xor.

Plus a standalone softmax kernel and a Gemma logit soft-capping kernel.

Norm (5 files)

rmsnorm.metal           -- Standard RMSNorm: y = (x / rms) * weight
rmsnorm_gemma.metal     -- Gemma variant: y = (x / rms) * (1 + weight)
layernorm.metal         -- Full LayerNorm with mean subtraction
residual_rmsnorm.metal  -- Fused residual add + RMSNorm
head_rmsnorm.metal      -- Per-head RMSNorm (for architectures like Cohere)

All follow the same pattern: one threadgroup per row, threads stride over the dimension, two-phase reduction for sum-of-squares.
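
For reference, the math that one threadgroup computes for its row can be written as a short CPU function. This is a sketch under one assumption: eps is added to the mean of squares inside the square root, which is the common convention but is not spelled out in the text. rmsnorm_ref is a hypothetical name, and the gemma flag switches to the (1 + weight) scaling listed above.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// CPU reference for one row of RMSNorm.
// Assumption: rms = sqrt(mean(x^2) + eps).
inline std::vector<float> rmsnorm_ref(const std::vector<float>& x,
                                      const std::vector<float>& weight,
                                      float eps, bool gemma = false) {
    assert(!x.empty() && x.size() == weight.size());

    // This sum is what the GPU kernel computes with the two-phase reduction.
    float sum_sq = 0.0f;
    for (float v : x) sum_sq += v * v;

    const float inv_rms = 1.0f / std::sqrt(sum_sq / float(x.size()) + eps);

    std::vector<float> y(x.size());
    for (std::size_t i = 0; i < x.size(); ++i) {
        const float w = gemma ? 1.0f + weight[i] : weight[i];
        y[i] = x[i] * inv_rms * w;
    }
    return y;
}
```

A CPU reference like this is handy for validating a kernel's output on a few rows before trusting it in the full pipeline.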

RoPE (7 files)

rope.metal              -- Standard rotary position embeddings
rope_neox.metal         -- GPT-NeoX-style RoPE (pairs element i with i + head_dim/2)
rope_kv_write.metal     -- Fused RoPE + KV cache write
rope_neox_kv_write.metal -- Fused NeoX RoPE + KV cache write
*_batch.metal           -- Batch (prefill) variants of the above
head_norm_rope_neox_kv_write.metal -- Triple fused: head norm + RoPE + KV write

The fused variants are critical for performance – they eliminate intermediate buffer writes between RoPE and KV cache insertion.

Embedding (12 files)

One embedding lookup kernel per quantization format. Each dequantizes on the fly during the lookup, converting the quantized embedding table directly to FP16 output. The MLX variants handle the packed [weights|scales|biases] buffer layout.

Sampling (7 files)

argmax.metal            -- Simple argmax (greedy decoding)
topk_select.metal       -- Top-K selection for sampling
temperature_scale.metal -- Temperature scaling (logits *= 1/T)
repetition_penalty.metal -- Repetition penalty application
gumbel_topk.metal       -- Gumbel-max trick for stochastic top-K
grammar_bitmask.metal   -- Grammar-constrained decoding mask
whisper_suppress.metal  -- Whisper-specific token suppression

Convert/Dequant (12 files)

Standalone dequantization kernels that convert quantized buffers to FP16. These are used when a kernel does not have a native quantized variant, or for debugging.

Common Utilities (6 files)

Elementwise operations: bias_add, residual_add, vector_add, transpose, head_rearrange (permute between [batch, seq, heads, dim] and [batch, heads, seq, dim]), qkv_split (split a fused QKV projection output into separate Q, K, V buffers).


33.8 The Threadgroup Geometry Taxonomy

One of the most confusing aspects of reading akunu’s kernels is that different kernel families use radically different threadgroup geometries. Here is a reference card:

+-------------------------------+---------+--------+---------+---------+
| Kernel                        | TG Size | # SGs  | Rows/TG | Notes   |
+-------------------------------+---------+--------+---------+---------+
| gemv_f16                      | 128     | 4      | 16      | NR=4    |
| gemv_q4_0 (small K)           | 128     | 4      | 16      | NQ=16   |
| gemv_q4_0_l (large K)         | 256     | 8      | 32      | NQ=16   |
| gemv_q8_0                     | 256     | 8      | 32      | NR=4    |
| gemv_q4_k                     | 256     | 8      | 16      | nr0=2   |
| gemv_wide_*                   | 256     | 8      | 64      | NCOLS=8 |
| gemv_batched_*                | 128/256 | 4/8    | 16/32   | M<=16   |
| gemv_mlx_q4                   | 128     | 4      | 16      | NR=4    |
| gemv_mlx_q4_l                 | 256     | 8      | 32      | NR=4    |
| simd_gemm_*                   | 128     | 4      | 32x64   | Tiled   |
| simd_gemm_small_*             | 128     | 4      | 8x64    | Tiled   |
| flash_attention_decode        | 128     | 4      | 1 head  | per-TG  |
| flash_attention_decode_fast   | 32      | 1      | 1 head  | no bar  |
| flash_attention_decode_par    | 1024    | 32     | 1 head  | max BW  |
| flash_attention_prefill_v2    | 128     | 4      | 32 Q    | BQ=32   |
| rmsnorm, softmax              | varies  | varies | 1 row   | per-TG  |
+-------------------------------+---------+--------+---------+---------+

The pattern: GEMV kernels use 128-256 threads with 4-8 SIMD groups, each SG computing 2-8 output rows. GEMM kernels use 128 threads with 4 SGs in a 2D grid. Attention has three completely different geometries depending on the decode scenario.
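
The table entries relate to each other through simple arithmetic, sketched below in C++. GemvGeometry and both helpers are hypothetical names, and the grid-size formula assumes a 1-D grid with one threadgroup per block of rows_per_tg output rows, which the text does not state explicitly.

```cpp
#include <cassert>
#include <cstdint>

// One row of the reference card.
struct GemvGeometry {
    uint32_t tg_size;       // threads per threadgroup
    uint32_t n_simdgroups;  // SIMD groups per threadgroup
    uint32_t rows_per_tg;   // output rows produced per threadgroup
};

constexpr GemvGeometry GEMV_Q4_0_L{256, 8, 32};  // values from the table above

// Output rows handled by each SIMD group.
inline uint32_t rows_per_simdgroup(const GemvGeometry& g) {
    assert(g.n_simdgroups > 0 && g.tg_size == g.n_simdgroups * 32);
    return g.rows_per_tg / g.n_simdgroups;
}

// Threadgroups needed to cover N output rows (assumed 1-D grid, rounded up).
inline uint32_t threadgroups_for(uint32_t N, const GemvGeometry& g) {
    assert(g.rows_per_tg > 0);
    return (N + g.rows_per_tg - 1) / g.rows_per_tg;
}
```

Under these assumptions, a 4096-row projection with gemv_q4_0_l launches 128 threadgroups of 256 threads, with each SIMD group producing 4 rows.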


33.9 Memory Access Patterns: Why It All Matters

Apple Silicon’s GPU shares unified memory with the CPU, but bandwidth is still the limiting factor for LLM inference. M2 Pro provides about 200 GB/s, M3 Max up to 400 GB/s. A 7B model at Q4_0 is about 3.5 GB of weights if you count 4 bits per weight (closer to 3.9 GB once the per-block scales bring it to 4.5 bits per weight). At 200 GB/s, you can read 3.5 GB in ~17.5 ms, giving a theoretical ceiling of ~57 tokens/second for decode (one full weight read per token); at 3.9 GB the ceiling drops to ~51 tokens/second.
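
The ceiling is just bandwidth divided by bytes read per token, which is easy to tabulate for other chips and quantization formats. The two helpers below are hypothetical and model only the weight traffic. KV cache reads and activation traffic are ignored, so real throughput will be lower.

```cpp
#include <cassert>

// Bytes of weights for a model with n_params parameters at a given bit width.
inline double weight_bytes(double n_params, double bits_per_weight) {
    assert(n_params > 0 && bits_per_weight > 0);
    return n_params * bits_per_weight / 8.0;
}

// Upper bound on decode tokens/second when every token reads all weights once.
inline double decode_ceiling_tok_per_s(double model_bytes, double bandwidth_bytes_per_s) {
    assert(model_bytes > 0);
    return bandwidth_bytes_per_s / model_bytes;
}
```

Calling decode_ceiling_tok_per_s(3.5e9, 200e9) gives roughly 57, and feeding it weight_bytes(7e9, 4.5) instead gives roughly 51, since Q4_0 blocks carry 4.5 bits per element once the scale is included.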

This means every wasted byte of memory bandwidth directly costs throughput. The kernels are designed around three principles:

  1. Read weights once, compute multiple outputs. GEMV kernels process NR=4 output rows per SIMD group, amortizing the activation vector read.

  2. Vectorized loads. Using half4 (8 bytes) or float4 (16 bytes) loads instead of scalar loads gives 4x-8x better memory throughput.

  3. Dequantize on-the-fly. Never materialize the full FP16 weight matrix. Read the quantized blocks, dequantize in registers, multiply, accumulate.

The next three chapters will dive deep into the actual implementations: GEMV (Chapter 34), GEMM (Chapter 35), and FlashAttention (Chapter 36).


33.10 How a Single Inference Step Maps to Kernels

To tie it all together, here is the kernel sequence for a single decode step of a typical Llama-style model:

  1. embedding_lookup_q4_0        -- Token embedding dequant + lookup
  2. For each transformer layer:
     a. rmsnorm_f16               -- Attention norm
     b. gemv_q4_0  (x3)          -- Q, K, V projections
     c. rope_neox_kv_write        -- Fused RoPE + KV cache write
     d. flash_attention_decode_*  -- Attention (variant depends on kv_seq_len)
     e. gemv_q4_0                 -- Output projection
     f. residual_add              -- Residual connection
     g. rmsnorm_f16               -- FFN norm
     h. gemv_q4_0  (x2)          -- Gate and Up projections
     i. silu                      -- SiLU activation (or fused into gemv_q4_0_silu)
     j. gemv_q4_0                 -- Down projection
     k. residual_add              -- Residual connection
  3. rmsnorm_f16                  -- Final norm
  4. gemv_wide_q4_0               -- LM head (vocab projection, large N)
  5. temperature_scale_f16        -- Temperature scaling
  6. argmax / topk_select         -- Sampling

That is roughly 13 kernel dispatches per layer (14 in the listing above, 13 once the SiLU is fused into the down projection), plus 4 for the head. For a 32-layer model, that is 420+ kernel dispatches per token. Each one is a separate computeEncoder.dispatchThreadgroups() call. The overhead per dispatch on Apple Silicon is about 1-3 microseconds, so dispatch overhead alone is roughly 0.4-1.3 ms per token, which is why kernel fusion (like gemv_q4_0_silu and rope_neox_kv_write) matters so much.
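
The dispatch budget is simple enough to recompute for other model depths. The helpers below are hypothetical, and the per-layer and per-dispatch figures are the estimates quoted above, not measurements.

```cpp
#include <cassert>
#include <cstdint>

// Total dispatches per decoded token: per-layer work times depth, plus the head.
constexpr uint32_t dispatches_per_token(uint32_t layers, uint32_t per_layer, uint32_t head) {
    return layers * per_layer + head;
}

// Dispatch overhead in milliseconds for a given per-dispatch cost in microseconds.
inline double dispatch_overhead_ms(uint32_t dispatches, double us_per_dispatch) {
    assert(us_per_dispatch >= 0.0);
    return double(dispatches) * us_per_dispatch / 1000.0;
}

static_assert(dispatches_per_token(32, 13, 4) == 420, "32 layers x 13 + 4 head dispatches");
```

At 1 to 3 microseconds per dispatch, 420 dispatches cost between 0.42 ms and 1.26 ms, and every fused kernel removes 32 dispatches from that total on a 32-layer model.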

With this map in hand, let us dive into the actual kernel implementations. Chapter 34 starts with the GEMV kernels – the single hottest code path in decode.