
Appendix A: Metal Shader Parameter Reference

This appendix is a complete reference for every parameter struct defined in backend/metal/kernels/ShaderTypes.h. These structs are shared between Metal shader code (MSL) and the C++/Swift host code. They are the contract between CPU and GPU: the host fills in the fields, binds the struct as a setBytes argument, and the shader reads from it.

Every struct in this file is padded to 16-byte alignment boundaries for Metal argument buffer compatibility. The comment at the top of ShaderTypes.h puts it plainly:

All structs are padded to 16-byte boundaries for Metal argument buffer alignment.

Any change to these structs must be mirrored in Sources/KernelStore/MetalTypes.swift to keep the Swift binding in sync.

How to Read the Tables

Each struct is documented with:

  • Total size: The sizeof the struct in bytes
  • Alignment: The alignment requirement (always 16 bytes for Metal compatibility)
  • Field table: Name, C type, byte offset, size, and notes

Byte offsets are calculated from the struct layout assuming standard C packing rules with the explicit padding fields (_pad0, _pad1, etc.) that akunu includes. The padding fields exist to ensure the struct is a multiple of 16 bytes and that fields after padding land on natural alignment boundaries.

Fields marked with “patched per-token” are dynamically modified during chain decode – the DispatchCmd::patch_type mechanism overwrites these fields at specific byte offsets for each token in the batch.


GEMMParams

General matrix-matrix multiplication parameters. Used by all simd_gemm_* and simd_gemm_small_* kernels during prefill.

Total size: 32 bytes | Alignment: 16 bytes

| Field | Type | Offset | Size | Description |
|-------|------|--------|------|-------------|
| M | uint32_t | 0 | 4 | Rows of A / rows of output C |
| N | uint32_t | 4 | 4 | Columns of B / columns of output C |
| K | uint32_t | 8 | 4 | Columns of A / rows of B (contraction dimension) |
| lda | uint32_t | 12 | 4 | Leading dimension of A (typically K for row-major) |
| ldb | uint32_t | 16 | 4 | Leading dimension of B (typically N for row-major) |
| ldc | uint32_t | 20 | 4 | Leading dimension of C (typically N for row-major) |
| alpha | float | 24 | 4 | Scale factor: C = alpha * A @ B + beta * C |
| beta | float | 28 | 4 | Accumulation factor: C = alpha * A @ B + beta * C |

Notes: The leading dimension fields (lda, ldb, ldc) allow non-contiguous matrix views. When matrices are contiguous row-major, lda = K, ldb = N, ldc = N. The alpha/beta fields support BLAS-style C = alpha*A@B + beta*C but in practice akunu always uses alpha=1.0, beta=0.0 (pure multiply, no accumulation).
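As a sanity check on the layout above, the struct can be restated in host C with compile-time assertions on its size and offsets. This is a sketch; the authoritative definition lives in backend/metal/kernels/ShaderTypes.h.

```c
#include <stdint.h>
#include <stddef.h>

/* Restatement of the documented GEMMParams layout (illustrative). */
typedef struct {
    uint32_t M, N, K;        /* problem shape */
    uint32_t lda, ldb, ldc;  /* leading dimensions */
    float alpha, beta;       /* BLAS-style scale/accumulate factors */
} GEMMParams;

/* Verify the offsets match the field table above. */
_Static_assert(sizeof(GEMMParams) == 32, "GEMMParams must be 32 bytes");
_Static_assert(offsetof(GEMMParams, lda) == 12, "lda at offset 12");
_Static_assert(offsetof(GEMMParams, alpha) == 24, "alpha at offset 24");
```

Because the asserts fire at compile time, a layout drift between the table and the header is caught before any kernel is dispatched.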


ElementwiseParams

Parameters for element-wise kernels (add, multiply, activation functions applied to flat buffers).

Total size: 16 bytes | Alignment: 16 bytes

| Field | Type | Offset | Size | Description |
|-------|------|--------|------|-------------|
| count | uint32_t | 0 | 4 | Total number of elements to process |
| _pad0 | uint32_t | 4 | 4 | Padding (unused) |
| _pad1 | uint32_t | 8 | 4 | Padding (unused) |
| _pad2 | uint32_t | 12 | 4 | Padding (unused) |

Notes: The 12 bytes of explicit padding bring the struct up to the 16-byte minimum size. The kernel dispatches ceil(count / threadgroup_size) threadgroups; each thread processes one element at index thread_position_in_grid.
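The threadgroup count for that dispatch is a ceiling division; a host-side sketch:

```c
#include <stdint.h>

/* Number of threadgroups so that every element gets one thread. */
static uint32_t elementwise_threadgroups(uint32_t count, uint32_t tg_size) {
    return (count + tg_size - 1) / tg_size;  /* ceil(count / tg_size) */
}
```

Threads in the final threadgroup whose index exceeds count - 1 must bounds-check and return early in the shader.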


AttentionParams

Parameters for the attention kernel. Handles both prefill (multi-token) and decode (single token) attention, including GQA (Grouped Query Attention) where n_kv_heads < n_heads.

Total size: 32 bytes | Alignment: 16 bytes

| Field | Type | Offset | Size | Description |
|-------|------|--------|------|-------------|
| seq_len | uint32_t | 0 | 4 | Query sequence length (1 for decode, N for prefill) |
| kv_seq_len | uint32_t | 4 | 4 | KV cache length (may differ from seq_len during decode). Patched per-token in chain decode. |
| head_dim | uint32_t | 8 | 4 | Dimension per attention head |
| n_heads | uint32_t | 12 | 4 | Number of query heads |
| n_kv_heads | uint32_t | 16 | 4 | Number of key/value heads (GQA: n_kv_heads <= n_heads) |
| scale | float | 20 | 4 | Attention scale factor: 1.0 / sqrt(head_dim) |
| kv_stride | uint32_t | 24 | 4 | Elements between KV heads in cache: max_seq_len * head_dim. 0 = use kv_seq_len * head_dim |
| q_stride | uint32_t | 28 | 4 | Elements between Q/O rows. 0 = n_heads * head_dim (contiguous) |

Notes: The kv_stride field encodes the head-major KV cache layout. For a cache shaped [n_kv_heads, max_seq_len, head_dim], the stride between heads is max_seq_len * head_dim elements. When kv_stride = 0, the kernel computes the stride from kv_seq_len * head_dim, which is the dense (no padding) case. The kv_seq_len field is patched per-token during chain decode using PATCH_KV_SEQ_LEN or PATCH_POS_AND_KV.
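The stride fallback described above can be expressed as a small helper (illustrative; it mirrors what the shader computes):

```c
#include <stdint.h>

/* Resolve the element stride between KV heads: an explicit kv_stride wins,
   otherwise fall back to the dense kv_seq_len * head_dim layout. */
static uint32_t kv_head_stride(uint32_t kv_stride,
                               uint32_t kv_seq_len, uint32_t head_dim) {
    return kv_stride != 0 ? kv_stride : kv_seq_len * head_dim;
}
```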


RMSNormParams

Parameters for RMSNorm (Root Mean Square Layer Normalization).1

Total size: 16 bytes | Alignment: 16 bytes

| Field | Type | Offset | Size | Description |
|-------|------|--------|------|-------------|
| dim | uint32_t | 0 | 4 | Vector dimension to normalize |
| eps | float | 4 | 4 | Epsilon for numerical stability (typically 1e-5 or 1e-6) |
| _pad0 | uint32_t | 8 | 4 | Padding (unused) |
| _pad1 | uint32_t | 12 | 4 | Padding (unused) |

Notes: The kernel computes rms = sqrt(sum(x[i]^2) / dim + eps) then out[i] = x[i] / rms * weight[i]. The dim field must match the weight buffer length. The threadgroup reduces to compute the sum-of-squares, then each thread normalizes its element.


LayerNormParams

Parameters for standard LayerNorm (used by Whisper).

Total size: 16 bytes | Alignment: 16 bytes

| Field | Type | Offset | Size | Description |
|-------|------|--------|------|-------------|
| dim | uint32_t | 0 | 4 | Vector dimension to normalize |
| eps | float | 4 | 4 | Epsilon for numerical stability |
| _pad0 | uint32_t | 8 | 4 | Padding (unused) |
| _pad1 | uint32_t | 12 | 4 | Padding (unused) |

Notes: Identical layout to RMSNormParams. The kernel is different: it computes mean and variance, then normalizes as (x - mean) / sqrt(var + eps) * weight + bias. The bias buffer is an additional binding not captured in the params struct.


SoftmaxParams

Parameters for the standalone softmax kernel.

Total size: 16 bytes | Alignment: 16 bytes

| Field | Type | Offset | Size | Description |
|-------|------|--------|------|-------------|
| rows | uint32_t | 0 | 4 | Number of rows (independent softmax operations) |
| cols | uint32_t | 4 | 4 | Number of columns (softmax dimension per row) |
| _pad0 | uint32_t | 8 | 4 | Padding (unused) |
| _pad1 | uint32_t | 12 | 4 | Padding (unused) |

Notes: Each row is an independent softmax: out[r][c] = exp(x[r][c] - max_r) / sum_r(exp(x[r][c] - max_r)). Used for the final softmax in attention during prefill. During decode, the softmax is typically fused into the attention kernel.


RoPEParams

Parameters for the standalone RoPE (Rotary Position Embedding) kernel. Used during prefill when the fused RoPE+QKV+KV-write kernel is not applicable.

Total size: 32 bytes | Alignment: 16 bytes

| Field | Type | Offset | Size | Description |
|-------|------|--------|------|-------------|
| seq_len | uint32_t | 0 | 4 | Number of positions to rotate |
| head_dim | uint32_t | 4 | 4 | Dimension per head (rotation applies to pairs) |
| n_heads | uint32_t | 8 | 4 | Number of heads in the input tensor |
| pos_offset | uint32_t | 12 | 4 | Global position offset for decode step. Patched per-token. |
| theta | float | 16 | 4 | RoPE base frequency (default 10000.0) |
| row_stride | uint32_t | 20 | 4 | Elements between rows. 0 = n_heads * head_dim (contiguous) |
| _pad0 | uint32_t | 24 | 4 | Padding (unused) |
| _pad1 | uint32_t | 28 | 4 | Padding (unused) |

Notes: RoPE rotates dimension pairs (2i, 2i+1) by angle pos * theta^(-2i/head_dim). Two kernel variants exist: rope_f16 (interleaved, LLaMA-style) and rope_neox_f16 (split-half, NeoX-style where the first head_dim/2 elements are the “real” part and the second half is the “imaginary” part). The ArchDescriptor::rope_standalone field selects which variant to use.


EmbeddingParams

Parameters for the token embedding lookup kernel.

Total size: 16 bytes | Alignment: 16 bytes

| Field | Type | Offset | Size | Description |
|-------|------|--------|------|-------------|
| num_tokens | uint32_t | 0 | 4 | Number of tokens to look up (1 for decode, N for prefill) |
| dim | uint32_t | 4 | 4 | Embedding dimension per token |
| _pad0 | uint32_t | 8 | 4 | Padding (unused) |
| _pad1 | uint32_t | 12 | 4 | Padding (unused) |

Notes: The kernel reads token IDs from a uint32 buffer and writes the corresponding embedding rows to the output buffer. For quantized embeddings (Q4_0, Q4_K, Q6_K, Q8_0, MLX formats), a specialized dequantizing embedding kernel is used that dequantizes on the fly and outputs FP16. The kernel name is selected by dtype_descriptor.h::embedding_kernel_for().


KVCacheWriteParams

Parameters for writing new key/value vectors into the KV cache.

Total size: 32 bytes | Alignment: 16 bytes

| Field | Type | Offset | Size | Description |
|-------|------|--------|------|-------------|
| n_kv_heads | uint32_t | 0 | 4 | Number of KV heads |
| head_dim | uint32_t | 4 | 4 | Dimension per head |
| max_seq_len | uint32_t | 8 | 4 | Maximum sequence length (cache dimension) |
| pos | uint32_t | 12 | 4 | Write position (single-token) or batch offset. Patched per-token. |
| src_stride | uint32_t | 16 | 4 | Elements between rows in source. 0 = n_kv_heads * head_dim |
| seq_len | uint32_t | 20 | 4 | Batch sequence length (1 for decode, N for prefill batch) |
| _pad0 | uint32_t | 24 | 4 | Padding (unused) |
| _pad1 | uint32_t | 28 | 4 | Padding (unused) |

Notes: Writes into the head-major KV cache at cache[head][pos][0..head_dim-1]. The destination offset is computed as head * max_seq_len * head_dim + pos * head_dim. For batch writes during prefill, seq_len > 1 and the kernel writes seq_len consecutive positions starting at pos.
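The destination offset formula from the note above, as a host-side helper:

```c
#include <stdint.h>
#include <stddef.h>

/* Element offset of cache[head][pos][0] in the head-major KV cache. */
static size_t kv_write_offset(uint32_t head, uint32_t pos,
                              uint32_t max_seq_len, uint32_t head_dim) {
    return (size_t)head * max_seq_len * head_dim + (size_t)pos * head_dim;
}
```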


RoPEQKVWriteParams

Parameters for the fused kernel that applies RoPE rotation to Q and K, then writes K and V into the KV cache. This is the most complex parameter struct and the workhorse of the decode path.

Total size: 36 bytes of declared fields | Alignment: 16 bytes (sizeof rounds up to 48 bytes, the next 16-byte multiple)

| Field | Type | Offset | Size | Description |
|-------|------|--------|------|-------------|
| n_kv_heads | uint32_t | 0 | 4 | Number of KV heads |
| head_dim | uint32_t | 4 | 4 | Dimension per head |
| max_seq_len | uint32_t | 8 | 4 | KV cache max sequence length |
| pos | uint32_t | 12 | 4 | Current sequence position. Patched per-token. |
| theta | float | 16 | 4 | RoPE base frequency |
| n_heads | uint32_t | 20 | 4 | Number of Q heads |
| k_elem_offset | uint32_t | 24 | 4 | Element offset to K section in QKV buffer |
| v_elem_offset | uint32_t | 28 | 4 | Element offset to V section in QKV buffer |
| freq_scale | float | 32 | 4 | Linear RoPE scaling: 1/factor (1.0 = no scaling) |
freq_scalefloat324Linear RoPE scaling: 1/factor (1.0 = no scaling)

Notes: This struct drives the fused rope_qkv_write_f16 and rope_neox_qkv_write_f16 kernels. The input is the contiguous QKV buffer [q_dim + 2*kv_dim] output by the QKV GEMV. The kernel:

  1. Reads Q elements, applies RoPE rotation, writes back to Q section (in-place)
  2. Reads K elements, applies RoPE rotation, writes to KV cache K buffer at pos
  3. Reads V elements (no rotation), writes to KV cache V buffer at pos

The k_elem_offset and v_elem_offset fields tell the kernel where K and V start in the QKV buffer. For a model with q_dim=4096, kv_dim=1024: k_elem_offset = 4096, v_elem_offset = 4096 + 1024 = 5120. The freq_scale field supports extended context via linear RoPE scaling (e.g., freq_scale = 0.25 for 4x context extension).
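The section offsets follow directly from the head counts; a host-side sketch reproducing the worked example above:

```c
#include <stdint.h>

/* Offsets of the K and V sections in the packed [Q | K | V] buffer. */
static void qkv_offsets(uint32_t n_heads, uint32_t n_kv_heads,
                        uint32_t head_dim,
                        uint32_t *k_elem_offset, uint32_t *v_elem_offset) {
    uint32_t q_dim  = n_heads * head_dim;
    uint32_t kv_dim = n_kv_heads * head_dim;
    *k_elem_offset = q_dim;           /* K starts right after Q */
    *v_elem_offset = q_dim + kv_dim;  /* V starts right after K */
}
```

With n_heads = 32, n_kv_heads = 8, head_dim = 128 this yields q_dim = 4096, k_elem_offset = 4096, v_elem_offset = 5120, matching the example.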


KVCacheShiftParams

Parameters for shifting the KV cache contents (sliding window eviction).

Total size: 32 bytes | Alignment: 16 bytes

| Field | Type | Offset | Size | Description |
|-------|------|--------|------|-------------|
| n_kv_heads | uint32_t | 0 | 4 | Number of KV heads |
| head_dim | uint32_t | 4 | 4 | Dimension per head |
| max_seq_len | uint32_t | 8 | 4 | Cache max sequence length |
| shift | uint32_t | 12 | 4 | Number of positions to shift left (evict oldest) |
| new_len | uint32_t | 16 | 4 | New sequence length after shift |
| _pad0 | uint32_t | 20 | 4 | Padding (unused) |
| _pad1 | uint32_t | 24 | 4 | Padding (unused) |
| _pad2 | uint32_t | 28 | 4 | Padding (unused) |

Notes: Shifts cache contents by shift positions. Entries at positions [shift, shift+new_len) are moved to [0, new_len). This is a per-head memmove operation. Used when the KV cache fills up and the oldest tokens need to be evicted.


HeadNormParams

Parameters for per-head RMSNorm, used by architectures with QK normalization (Qwen3, Gemma).

Total size: 32 bytes | Alignment: 16 bytes

| Field | Type | Offset | Size | Description |
|-------|------|--------|------|-------------|
| head_dim | uint32_t | 0 | 4 | Elements per head |
| n_heads | uint32_t | 4 | 4 | Number of heads to normalize |
| seq_len | uint32_t | 8 | 4 | Sequence length (1 for decode, N for prefill) |
| eps | float | 12 | 4 | Norm epsilon |
| _pad0 | uint32_t | 16 | 4 | Padding (unused) |
| _pad1 | uint32_t | 20 | 4 | Padding (unused) |
| _pad2 | uint32_t | 24 | 4 | Padding (unused) |
| _pad3 | uint32_t | 28 | 4 | Padding (unused) |

Notes: Unlike standard RMSNorm which normalizes the entire hidden vector, per-head norm applies RMSNorm independently to each head_dim-sized slice. For Q normalization: Q[h] = rmsnorm(Q[h], q_norm_weight) for each head h. Each threadgroup handles one head.


GEMVHeadNormParams

Parameters for the fused GEMV + per-head RMSNorm kernel. Combines a matrix-vector multiply with per-head normalization in a single dispatch.

Total size: 32 bytes | Alignment: 16 bytes

| Field | Type | Offset | Size | Description |
|-------|------|--------|------|-------------|
| N | uint32_t | 0 | 4 | Total output dimension (n_heads * head_dim) |
| K | uint32_t | 4 | 4 | Input dimension |
| head_dim | uint32_t | 8 | 4 | Elements per head |
| n_heads | uint32_t | 12 | 4 | Number of heads |
| eps | float | 16 | 4 | Norm epsilon |
| _pad0 | uint32_t | 20 | 4 | Padding (unused) |
| _pad1 | uint32_t | 24 | 4 | Padding (unused) |
| _pad2 | uint32_t | 28 | 4 | Padding (unused) |

Notes: Each threadgroup computes one output row (one head’s worth of elements) via GEMV, then applies RMSNorm to the result. This fuses Q = W_q @ x; Q[h] = rmsnorm(Q[h]) into a single kernel, eliminating the intermediate write of the un-normalized Q vector.


TemperatureScaleParams

Parameters for applying temperature scaling to the logits buffer.

Total size: 16 bytes | Alignment: 16 bytes

| Field | Type | Offset | Size | Description |
|-------|------|--------|------|-------------|
| inv_temperature | float | 0 | 4 | Inverse temperature: 1.0 / temperature. Logits are multiplied by this value. |
| count | uint32_t | 4 | 4 | Number of logit elements (vocabulary size) |
| _pad0 | uint32_t | 8 | 4 | Padding (unused) |
| _pad1 | uint32_t | 12 | 4 | Padding (unused) |

Notes: Applies logits[i] *= inv_temperature in-place. Using inverse temperature (multiply instead of divide) avoids a per-element division in the shader. Called via akunu_gpu_temperature_scale() in the C API.


RepetitionPenaltyParams

Parameters for applying repetition penalty to logits.

Total size: 16 bytes | Alignment: 16 bytes

| Field | Type | Offset | Size | Description |
|-------|------|--------|------|-------------|
| penalty | float | 0 | 4 | Repetition penalty factor (>1.0 penalizes, <1.0 encourages) |
| n_tokens | uint32_t | 4 | 4 | Number of token IDs in the penalty list |
| _pad0 | uint32_t | 8 | 4 | Padding (unused) |
| _pad1 | uint32_t | 12 | 4 | Padding (unused) |

Notes: For each token in the penalty list: if logit > 0, divide by penalty; if logit < 0, multiply by penalty. The token ID list is passed as a separate buffer binding. Called via akunu_gpu_repetition_penalty() in the C API.
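A CPU sketch of that sign-dependent update:

```c
#include <stdint.h>

/* Penalize previously seen tokens: shrink positive logits,
   push negative logits further down. */
static void repetition_penalty_ref(float *logits, const uint32_t *tokens,
                                   uint32_t n_tokens, float penalty) {
    for (uint32_t t = 0; t < n_tokens; t++) {
        uint32_t id = tokens[t];
        if (logits[id] > 0.0f) logits[id] /= penalty;
        else                   logits[id] *= penalty;
    }
}
```

Dividing positives and multiplying negatives means the penalty always moves the logit away from being selected, regardless of sign.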


MLXParams

Parameters for MLX-format quantized GEMV and embedding kernels. MLX uses group quantization: weights are packed with bits-per-value in groups of group_size, with FP16 scale and bias per group.

Total size: 32 bytes | Alignment: 16 bytes

| Field | Type | Offset | Size | Description |
|-------|------|--------|------|-------------|
| M | uint32_t | 0 | 4 | Batch size (1 for GEMV, num_tokens for batch embedding) |
| N | uint32_t | 4 | 4 | Output dimension (weight rows / vocab_size) |
| K | uint32_t | 8 | 4 | Input dimension (unpacked element count) |
| group_size | uint32_t | 12 | 4 | Quantization group size (typically 64) |
| bits | uint32_t | 16 | 4 | Bits per quantized value (3, 4, 6, or 8) |
| weight_bytes | uint32_t | 20 | 4 | Byte offset to scales section within weight buffer |
| _pad0 | uint32_t | 24 | 4 | Padding (unused) |
| _pad1 | uint32_t | 28 | 4 | Padding (unused) |

Notes: The MLX weight buffer layout is [packed_weights | scales | biases]. The weight_bytes field gives the byte offset where scales begin. Biases follow immediately after scales. The packed weight format differs by bit-width: 4-bit packs 8 values per uint32, 3-bit packs values with a more complex scheme, 8-bit uses one byte per value. The dequantization formula is: value = scale * (packed_int - bias) per group.
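For the 4-bit case, extracting one of the eight values packed in a uint32 looks like this. The low-nibble-first lane order is an assumption for illustration; the actual packing is defined by the MLX kernels.

```c
#include <stdint.h>

/* Extract quantized value `lane` (0..7) from a 4-bit packed word,
   assuming low-nibble-first order (illustrative). */
static uint32_t mlx_unpack4(uint32_t word, uint32_t lane) {
    return (word >> (4u * lane)) & 0xFu;
}
```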


Conv1DParams

Parameters for 1D convolution, used by Whisper’s audio frontend.

Total size: 32 bytes | Alignment: 16 bytes

| Field | Type | Offset | Size | Description |
|-------|------|--------|------|-------------|
| in_channels | uint32_t | 0 | 4 | Number of input channels |
| out_channels | uint32_t | 4 | 4 | Number of output channels (filters) |
| kernel_size | uint32_t | 8 | 4 | Convolution kernel width |
| stride | uint32_t | 12 | 4 | Stride between convolution windows |
| in_length | uint32_t | 16 | 4 | Input sequence length |
| out_length | uint32_t | 20 | 4 | Output sequence length: (in_length + 2*padding - kernel_size) / stride + 1 |
| padding | uint32_t | 24 | 4 | Zero-padding on each side of input |
| _pad0 | uint32_t | 28 | 4 | Padding (unused) |

Notes: Used for Whisper’s two-layer Conv1D frontend that processes mel spectrograms. The first conv has kernel_size=3, stride=1, padding=1; the second has kernel_size=3, stride=2, padding=1, downsampling the time dimension by 2x.
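Plugging the two layers into the out_length formula confirms the behavior: the first layer preserves length, the second halves it.

```c
#include <stdint.h>

/* out_length = (in_length + 2*padding - kernel_size) / stride + 1 */
static uint32_t conv1d_out_length(uint32_t in_length, uint32_t padding,
                                  uint32_t kernel_size, uint32_t stride) {
    return (in_length + 2 * padding - kernel_size) / stride + 1;
}
```

With kernel_size=3, padding=1, a stride of 1 maps in_length to in_length, and a stride of 2 maps it to in_length/2.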


GEMVKVParams

Parameters for the fused GEMV + KV cache write kernel. Combines a projection GEMV with direct KV cache insertion.

Total size: 32 bytes | Alignment: 16 bytes

| Field | Type | Offset | Size | Description |
|-------|------|--------|------|-------------|
| N | uint32_t | 0 | 4 | Output dimension (n_kv_heads * head_dim) |
| K | uint32_t | 4 | 4 | Input dimension |
| head_dim | uint32_t | 8 | 4 | Dimension per head |
| max_seq_len | uint32_t | 12 | 4 | KV cache max sequence length |
| pos | uint32_t | 16 | 4 | Write position in cache. Patched per-token. |
| _pad0 | uint32_t | 20 | 4 | Padding (unused) |
| _pad1 | uint32_t | 24 | 4 | Padding (unused) |
| _pad2 | uint32_t | 28 | 4 | Padding (unused) |

Notes: Each threadgroup computes a slice of the GEMV output and writes it directly to the KV cache buffer at the correct position, bypassing the intermediate scratch buffer. The cache write offset is (row / head_dim) * max_seq_len * head_dim + pos * head_dim + (row % head_dim), implementing the head-major layout.
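The per-element cache offset from the note above, as a helper:

```c
#include <stdint.h>
#include <stddef.h>

/* Head-major cache offset for output element `row` written at `pos`:
   row / head_dim selects the head, row % head_dim the slot within it. */
static size_t gemv_kv_offset(uint32_t row, uint32_t head_dim,
                             uint32_t max_seq_len, uint32_t pos) {
    return (size_t)(row / head_dim) * max_seq_len * head_dim
         + (size_t)pos * head_dim
         + (row % head_dim);
}
```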


Quick Reference Summary

| Struct | Size (bytes) | Used By | Patched Fields |
|--------|--------------|---------|----------------|
| GEMMParams | 32 | simd_gemm_* (prefill) | None |
| ElementwiseParams | 16 | Elementwise ops (add, residual) | None |
| AttentionParams | 32 | Attention kernels | kv_seq_len |
| RMSNormParams | 16 | rmsnorm_* | None |
| LayerNormParams | 16 | layernorm_* (Whisper) | None |
| SoftmaxParams | 16 | softmax_* | None |
| RoPEParams | 32 | rope_f16, rope_neox_f16 | pos_offset |
| EmbeddingParams | 16 | embedding_lookup_* | None |
| KVCacheWriteParams | 32 | KV cache write kernels | pos |
| RoPEQKVWriteParams | 36 | rope_qkv_write_f16, rope_neox_qkv_write_f16 | pos |
| KVCacheShiftParams | 32 | KV cache shift kernel | None |
| HeadNormParams | 32 | Per-head RMSNorm | None |
| GEMVHeadNormParams | 32 | Fused GEMV + head norm | None |
| TemperatureScaleParams | 16 | Temperature scaling | None |
| RepetitionPenaltyParams | 16 | Repetition penalty | None |
| MLXParams | 32 | MLX quantized GEMV/embedding | None |
| Conv1DParams | 32 | Conv1D (Whisper frontend) | None |
| GEMVKVParams | 32 | Fused GEMV + KV write | pos |

Alignment and Padding Rules

All structs follow these conventions:

  1. 16-byte total size minimum: Every struct is at least 16 bytes. Metal’s setBytes requires 16-byte aligned data for argument buffer compatibility.

  2. Explicit padding fields: Rather than relying on compiler-inserted padding (which varies across compilers and platforms), akunu includes explicit _pad0, _pad1, etc. fields. This makes the layout identical whether compiled as C, C++, Objective-C++, or MSL.

  3. Natural alignment: All uint32_t fields are at 4-byte-aligned offsets. All float fields are at 4-byte-aligned offsets. No field crosses a natural alignment boundary.

  4. Cross-language compatibility: The same header (ShaderTypes.h) is included by both Metal shaders (via #ifdef __METAL_VERSION__) and host code. The #ifdef guard switches between <metal_stdlib> and <simd/simd.h> for type compatibility, but the struct layouts are identical in both compilation contexts.

When adding a new parameter struct, follow the pattern: use uint32_t and float fields only, pad to a multiple of 16 bytes with explicit _padN fields, and add a // MARK: section comment with the struct size.
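A hypothetical new struct following those rules, with a compile-time size check (the static assert is an extra safeguard suggested here, not part of the existing header):

```c
#include <stdint.h>

// MARK: - ExampleParams (16 bytes)
/* Hypothetical params struct following the conventions above. */
typedef struct {
    uint32_t count;   /* payload fields: uint32_t and float only */
    float    scale;
    uint32_t _pad0;   /* explicit padding to a 16-byte multiple */
    uint32_t _pad1;
} ExampleParams;

_Static_assert(sizeof(ExampleParams) % 16 == 0,
               "params structs must be padded to 16-byte multiples");
```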


  1. Zhang, B. & Sennrich, R. (2019). “Root Mean Square Layer Normalization.” NeurIPS 2019. RMSNorm computes x / sqrt(mean(x^2) + eps) * weight, omitting the mean subtraction of standard LayerNorm. See https://arxiv.org/abs/1910.07467.