Appendix A: Metal Shader Parameter Reference
This appendix is a complete reference for every parameter struct defined in backend/metal/kernels/ShaderTypes.h. These structs are shared between Metal shader code (MSL) and the C++/Swift host code. They are the contract between CPU and GPU: the host fills in the fields, binds the struct as a setBytes argument, and the shader reads from it.
Every struct in this file is padded to 16-byte alignment boundaries for Metal argument buffer compatibility. The comment at the top of ShaderTypes.h puts it plainly:
> All structs are padded to 16-byte boundaries for Metal argument buffer alignment.
> Any change to these structs must be mirrored in Sources/KernelStore/MetalTypes.swift to keep the Swift binding in sync.
How to Read the Tables
Each struct is documented with:
- Total size: The `sizeof` of the struct in bytes
- Alignment: The alignment requirement (always 16 bytes for Metal compatibility)
- Field table: Name, C type, byte offset, size, and notes
Byte offsets are calculated from the struct layout assuming standard C packing rules with the explicit padding fields (_pad0, _pad1, etc.) that akunu includes. The padding fields exist to ensure the struct is a multiple of 16 bytes and that fields after padding land on natural alignment boundaries.
Fields marked with “patched per-token” are dynamically modified during chain decode – the DispatchCmd::patch_type mechanism overwrites these fields at specific byte offsets for each token in the batch.
GEMMParams
General matrix-matrix multiplication parameters. Used by all simd_gemm_* and simd_gemm_small_* kernels during prefill.
Total size: 32 bytes | Alignment: 16 bytes
| Field | Type | Offset | Size | Description |
|---|---|---|---|---|
M | uint32_t | 0 | 4 | Rows of A / rows of output C |
N | uint32_t | 4 | 4 | Columns of B / columns of output C |
K | uint32_t | 8 | 4 | Columns of A / rows of B (contraction dimension) |
lda | uint32_t | 12 | 4 | Leading dimension of A (typically K for row-major) |
ldb | uint32_t | 16 | 4 | Leading dimension of B (typically N for row-major) |
ldc | uint32_t | 20 | 4 | Leading dimension of C (typically N for row-major) |
alpha | float | 24 | 4 | Scale factor: C = alpha * A @ B + beta * C |
beta | float | 28 | 4 | Accumulation factor: C = alpha * A @ B + beta * C |
Notes: The leading dimension fields (lda, ldb, ldc) allow non-contiguous matrix views. When matrices are contiguous row-major, lda = K, ldb = N, ldc = N. The alpha/beta fields support BLAS-style C = alpha*A@B + beta*C but in practice akunu always uses alpha=1.0, beta=0.0 (pure multiply, no accumulation).
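The field table can be mirrored as a host-side C struct. This is a hypothetical reconstruction for illustration only (the authoritative definition lives in backend/metal/kernels/ShaderTypes.h); the static assertions encode the layout contract from the table:

```c
#include <stddef.h>
#include <stdint.h>

/* Hypothetical reconstruction of GEMMParams from the field table;
   the real definition is in backend/metal/kernels/ShaderTypes.h. */
typedef struct {
    uint32_t M, N, K;        /* output rows, output cols, contraction dim */
    uint32_t lda, ldb, ldc;  /* leading dimensions of A, B, C */
    float    alpha, beta;    /* C = alpha * A @ B + beta * C */
} GEMMParams;

/* The layout contract from the table: 32 bytes total, alpha at offset 24. */
_Static_assert(sizeof(GEMMParams) == 32, "GEMMParams must be 32 bytes");
_Static_assert(offsetof(GEMMParams, alpha) == 24, "alpha at offset 24");
```

A struct like this is what the host fills in and passes to `setBytes` before the dispatch.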
ElementwiseParams
Parameters for element-wise kernels (add, multiply, activation functions applied to flat buffers).
Total size: 16 bytes | Alignment: 16 bytes
| Field | Type | Offset | Size | Description |
|---|---|---|---|---|
count | uint32_t | 0 | 4 | Total number of elements to process |
_pad0 | uint32_t | 4 | 4 | Padding (unused) |
_pad1 | uint32_t | 8 | 4 | Padding (unused) |
_pad2 | uint32_t | 12 | 4 | Padding (unused) |
Notes: 12 bytes of explicit padding bring the struct to the 16-byte minimum size. The host dispatches ceil(count / threadgroup_size) threadgroups; each thread processes one element at index thread_position_in_grid.
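The threadgroup-count arithmetic is just a ceiling division. A minimal sketch (hypothetical helper name; the threadgroup size of 256 in the test is an illustrative assumption, not a value taken from the codebase):

```c
#include <stdint.h>

/* Threadgroup count for an elementwise dispatch: ceil(count / tg_size),
   computed in integer arithmetic. */
static uint32_t elementwise_groups(uint32_t count, uint32_t tg_size) {
    return (count + tg_size - 1) / tg_size;
}
```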
AttentionParams
Parameters for the attention kernel. Handles both prefill (multi-token) and decode (single token) attention, including GQA (Grouped Query Attention) where n_kv_heads < n_heads.
Total size: 32 bytes | Alignment: 16 bytes
| Field | Type | Offset | Size | Description |
|---|---|---|---|---|
seq_len | uint32_t | 0 | 4 | Query sequence length (1 for decode, N for prefill) |
kv_seq_len | uint32_t | 4 | 4 | KV cache length (may differ from seq_len during decode). Patched per-token in chain decode. |
head_dim | uint32_t | 8 | 4 | Dimension per attention head |
n_heads | uint32_t | 12 | 4 | Number of query heads |
n_kv_heads | uint32_t | 16 | 4 | Number of key/value heads (GQA: n_kv_heads <= n_heads) |
scale | float | 20 | 4 | Attention scale factor: 1.0 / sqrt(head_dim) |
kv_stride | uint32_t | 24 | 4 | Elements between KV heads in cache: max_seq_len * head_dim. 0 = use kv_seq_len * head_dim |
q_stride | uint32_t | 28 | 4 | Elements between Q/O rows. 0 = n_heads * head_dim (contiguous) |
Notes: The kv_stride field encodes the head-major KV cache layout. For a cache shaped [n_kv_heads, max_seq_len, head_dim], the stride between heads is max_seq_len * head_dim elements. When kv_stride = 0, the kernel computes the stride from kv_seq_len * head_dim, which is the dense (no padding) case. The kv_seq_len field is patched per-token during chain decode using PATCH_KV_SEQ_LEN or PATCH_POS_AND_KV.
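The stride resolution described above can be sketched as a small helper (hypothetical name; it mirrors what the notes say the kernel computes when kv_stride is zero):

```c
#include <stdint.h>

/* Effective element stride between KV heads: an explicit kv_stride wins;
   0 means the dense case, kv_seq_len * head_dim. */
static uint32_t effective_kv_stride(uint32_t kv_stride, uint32_t kv_seq_len,
                                    uint32_t head_dim) {
    return kv_stride ? kv_stride : kv_seq_len * head_dim;
}
```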
RMSNormParams
Parameters for RMSNorm (Root Mean Square Layer Normalization).[^1]
Total size: 16 bytes | Alignment: 16 bytes
| Field | Type | Offset | Size | Description |
|---|---|---|---|---|
dim | uint32_t | 0 | 4 | Vector dimension to normalize |
eps | float | 4 | 4 | Epsilon for numerical stability (typically 1e-5 or 1e-6) |
_pad0 | uint32_t | 8 | 4 | Padding (unused) |
_pad1 | uint32_t | 12 | 4 | Padding (unused) |
Notes: The kernel computes rms = sqrt(sum(x[i]^2) / dim + eps) then out[i] = x[i] / rms * weight[i]. The dim field must match the weight buffer length. The threadgroup reduces to compute the sum-of-squares, then each thread normalizes its element.
LayerNormParams
Parameters for standard LayerNorm (used by Whisper).
Total size: 16 bytes | Alignment: 16 bytes
| Field | Type | Offset | Size | Description |
|---|---|---|---|---|
dim | uint32_t | 0 | 4 | Vector dimension to normalize |
eps | float | 4 | 4 | Epsilon for numerical stability |
_pad0 | uint32_t | 8 | 4 | Padding (unused) |
_pad1 | uint32_t | 12 | 4 | Padding (unused) |
Notes: Identical layout to RMSNormParams. The kernel is different: it computes mean and variance, then normalizes as (x - mean) / sqrt(var + eps) * weight + bias. The bias buffer is an additional binding not captured in the params struct.
SoftmaxParams
Parameters for the standalone softmax kernel.
Total size: 16 bytes | Alignment: 16 bytes
| Field | Type | Offset | Size | Description |
|---|---|---|---|---|
rows | uint32_t | 0 | 4 | Number of rows (independent softmax operations) |
cols | uint32_t | 4 | 4 | Number of columns (softmax dimension per row) |
_pad0 | uint32_t | 8 | 4 | Padding (unused) |
_pad1 | uint32_t | 12 | 4 | Padding (unused) |
Notes: Each row is an independent softmax: out[r][c] = exp(x[r][c] - max_r) / sum_r(exp(x[r][c] - max_r)). Used for the final softmax in attention during prefill. During decode, the softmax is typically fused into the attention kernel.
RoPEParams
Parameters for the standalone RoPE (Rotary Position Embedding) kernel. Used during prefill when the fused RoPE+QKV+KV-write kernel is not applicable.
Total size: 32 bytes | Alignment: 16 bytes
| Field | Type | Offset | Size | Description |
|---|---|---|---|---|
seq_len | uint32_t | 0 | 4 | Number of positions to rotate |
head_dim | uint32_t | 4 | 4 | Dimension per head (rotation applies to pairs) |
n_heads | uint32_t | 8 | 4 | Number of heads in the input tensor |
pos_offset | uint32_t | 12 | 4 | Global position offset for decode step. Patched per-token. |
theta | float | 16 | 4 | RoPE base frequency (default 10000.0) |
row_stride | uint32_t | 20 | 4 | Elements between rows. 0 = n_heads * head_dim (contiguous) |
_pad0 | uint32_t | 24 | 4 | Padding (unused) |
_pad1 | uint32_t | 28 | 4 | Padding (unused) |
Notes: RoPE rotates dimension pairs (2i, 2i+1) by angle pos * theta^(-2i/head_dim). Two kernel variants exist: rope_f16 (interleaved, LLaMA-style) and rope_neox_f16 (split-half, NeoX-style where the first head_dim/2 elements are the “real” part and the second half is the “imaginary” part). The ArchDescriptor::rope_standalone field selects which variant to use.
EmbeddingParams
Parameters for the token embedding lookup kernel.
Total size: 16 bytes | Alignment: 16 bytes
| Field | Type | Offset | Size | Description |
|---|---|---|---|---|
num_tokens | uint32_t | 0 | 4 | Number of tokens to look up (1 for decode, N for prefill) |
dim | uint32_t | 4 | 4 | Embedding dimension per token |
_pad0 | uint32_t | 8 | 4 | Padding (unused) |
_pad1 | uint32_t | 12 | 4 | Padding (unused) |
Notes: The kernel reads token IDs from a uint32 buffer and writes the corresponding embedding rows to the output buffer. For quantized embeddings (Q4_0, Q4_K, Q6_K, Q8_0, MLX formats), a specialized dequantizing embedding kernel is used that dequantizes on the fly and outputs FP16. The kernel name is selected by dtype_descriptor.h::embedding_kernel_for().
KVCacheWriteParams
Parameters for writing new key/value vectors into the KV cache.
Total size: 32 bytes | Alignment: 16 bytes
| Field | Type | Offset | Size | Description |
|---|---|---|---|---|
n_kv_heads | uint32_t | 0 | 4 | Number of KV heads |
head_dim | uint32_t | 4 | 4 | Dimension per head |
max_seq_len | uint32_t | 8 | 4 | Maximum sequence length (cache dimension) |
pos | uint32_t | 12 | 4 | Write position (single-token) or batch offset. Patched per-token. |
src_stride | uint32_t | 16 | 4 | Elements between rows in source. 0 = n_kv_heads * head_dim |
seq_len | uint32_t | 20 | 4 | Batch sequence length (1 for decode, N for prefill batch) |
_pad0 | uint32_t | 24 | 4 | Padding (unused) |
_pad1 | uint32_t | 28 | 4 | Padding (unused) |
Notes: Writes into the head-major KV cache at cache[head][pos][0..head_dim-1]. The destination offset is computed as head * max_seq_len * head_dim + pos * head_dim. For batch writes during prefill, seq_len > 1 and the kernel writes seq_len consecutive positions starting at pos.
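The destination-offset formula from the notes, as a checked helper (hypothetical name, element units):

```c
#include <stddef.h>
#include <stdint.h>

/* Element offset of cache[head][pos][0] in the head-major KV cache:
   head * max_seq_len * head_dim + pos * head_dim. */
static size_t kv_write_offset(uint32_t head, uint32_t pos,
                              uint32_t max_seq_len, uint32_t head_dim) {
    return (size_t)head * max_seq_len * head_dim + (size_t)pos * head_dim;
}
```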
RoPEQKVWriteParams
Parameters for the fused kernel that applies RoPE rotation to Q and K, then writes K and V into the KV cache. This is the most complex parameter struct and the workhorse of the decode path.
Total size: 36 bytes | Alignment: 16 bytes (padded to 48 bytes in practice due to Metal alignment)
| Field | Type | Offset | Size | Description |
|---|---|---|---|---|
n_kv_heads | uint32_t | 0 | 4 | Number of KV heads |
head_dim | uint32_t | 4 | 4 | Dimension per head |
max_seq_len | uint32_t | 8 | 4 | KV cache max sequence length |
pos | uint32_t | 12 | 4 | Current sequence position. Patched per-token. |
theta | float | 16 | 4 | RoPE base frequency |
n_heads | uint32_t | 20 | 4 | Number of Q heads |
k_elem_offset | uint32_t | 24 | 4 | Element offset to K section in QKV buffer |
v_elem_offset | uint32_t | 28 | 4 | Element offset to V section in QKV buffer |
freq_scale | float | 32 | 4 | Linear RoPE scaling: 1/factor (1.0 = no scaling) |
Notes: This struct drives the fused rope_qkv_write_f16 and rope_neox_qkv_write_f16 kernels. The input is the contiguous QKV buffer [q_dim + 2*kv_dim] output by the QKV GEMV. The kernel:
- Reads Q elements, applies RoPE rotation, writes back to Q section (in-place)
- Reads K elements, applies RoPE rotation, writes to KV cache K buffer at `pos`
- Reads V elements (no rotation), writes to KV cache V buffer at `pos`
The k_elem_offset and v_elem_offset fields tell the kernel where K and V start in the QKV buffer. For a model with q_dim=4096, kv_dim=1024: k_elem_offset = 4096, v_elem_offset = 4096 + 1024 = 5120. The freq_scale field supports extended context via linear RoPE scaling (e.g., freq_scale = 0.25 for 4x context extension).
KVCacheShiftParams
Parameters for shifting the KV cache contents (sliding window eviction).
Total size: 32 bytes | Alignment: 16 bytes
| Field | Type | Offset | Size | Description |
|---|---|---|---|---|
n_kv_heads | uint32_t | 0 | 4 | Number of KV heads |
head_dim | uint32_t | 4 | 4 | Dimension per head |
max_seq_len | uint32_t | 8 | 4 | Cache max sequence length |
shift | uint32_t | 12 | 4 | Number of positions to shift left (evict oldest) |
new_len | uint32_t | 16 | 4 | New sequence length after shift |
_pad0 | uint32_t | 20 | 4 | Padding (unused) |
_pad1 | uint32_t | 24 | 4 | Padding (unused) |
_pad2 | uint32_t | 28 | 4 | Padding (unused) |
Notes: Shifts cache contents by shift positions. Entries at positions [shift, shift+new_len) are moved to [0, new_len). This is a per-head memmove operation. Used when the KV cache fills up and the oldest tokens need to be evicted.
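The per-head memmove can be sketched on the CPU as follows (illustrative only: the element type is float here for simplicity, whereas the real cache holds FP16):

```c
#include <stdint.h>
#include <string.h>

/* Per head, move positions [shift, shift + new_len) down to [0, new_len),
   evicting the oldest `shift` entries. */
static void kv_shift_ref(float *cache, uint32_t n_kv_heads, uint32_t head_dim,
                         uint32_t max_seq_len, uint32_t shift, uint32_t new_len) {
    for (uint32_t h = 0; h < n_kv_heads; h++) {
        float *base = cache + (size_t)h * max_seq_len * head_dim;
        memmove(base, base + (size_t)shift * head_dim,
                (size_t)new_len * head_dim * sizeof(float));
    }
}
```

memmove (not memcpy) is required because the source and destination ranges overlap whenever shift < new_len.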
HeadNormParams
Parameters for per-head RMSNorm, used by architectures with QK normalization (Qwen3, Gemma).
Total size: 32 bytes | Alignment: 16 bytes
| Field | Type | Offset | Size | Description |
|---|---|---|---|---|
head_dim | uint32_t | 0 | 4 | Elements per head |
n_heads | uint32_t | 4 | 4 | Number of heads to normalize |
seq_len | uint32_t | 8 | 4 | Sequence length (1 for decode, N for prefill) |
eps | float | 12 | 4 | Norm epsilon |
_pad0 | uint32_t | 16 | 4 | Padding (unused) |
_pad1 | uint32_t | 20 | 4 | Padding (unused) |
_pad2 | uint32_t | 24 | 4 | Padding (unused) |
_pad3 | uint32_t | 28 | 4 | Padding (unused) |
Notes: Unlike standard RMSNorm which normalizes the entire hidden vector, per-head norm applies RMSNorm independently to each head_dim-sized slice. For Q normalization: Q[h] = rmsnorm(Q[h], q_norm_weight) for each head h. Each threadgroup handles one head.
GEMVHeadNormParams
Parameters for the fused GEMV + per-head RMSNorm kernel. Combines a matrix-vector multiply with per-head normalization in a single dispatch.
Total size: 32 bytes | Alignment: 16 bytes
| Field | Type | Offset | Size | Description |
|---|---|---|---|---|
N | uint32_t | 0 | 4 | Total output dimension (n_heads * head_dim) |
K | uint32_t | 4 | 4 | Input dimension |
head_dim | uint32_t | 8 | 4 | Elements per head |
n_heads | uint32_t | 12 | 4 | Number of heads |
eps | float | 16 | 4 | Norm epsilon |
_pad0 | uint32_t | 20 | 4 | Padding (unused) |
_pad1 | uint32_t | 24 | 4 | Padding (unused) |
_pad2 | uint32_t | 28 | 4 | Padding (unused) |
Notes: Each threadgroup computes one output row (one head’s worth of elements) via GEMV, then applies RMSNorm to the result. This fuses Q = W_q @ x; Q[h] = rmsnorm(Q[h]) into a single kernel, eliminating the intermediate write of the un-normalized Q vector.
TemperatureScaleParams
Parameters for applying temperature scaling to the logits buffer.
Total size: 16 bytes | Alignment: 16 bytes
| Field | Type | Offset | Size | Description |
|---|---|---|---|---|
inv_temperature | float | 0 | 4 | Inverse temperature: 1.0 / temperature. Logits are multiplied by this value. |
count | uint32_t | 4 | 4 | Number of logit elements (vocabulary size) |
_pad0 | uint32_t | 8 | 4 | Padding (unused) |
_pad1 | uint32_t | 12 | 4 | Padding (unused) |
Notes: Applies logits[i] *= inv_temperature in-place. Using inverse temperature (multiply instead of divide) avoids a per-element division in the shader. Called via akunu_gpu_temperature_scale() in the C API.
RepetitionPenaltyParams
Parameters for applying repetition penalty to logits.
Total size: 16 bytes | Alignment: 16 bytes
| Field | Type | Offset | Size | Description |
|---|---|---|---|---|
penalty | float | 0 | 4 | Repetition penalty factor (>1.0 penalizes, <1.0 encourages) |
n_tokens | uint32_t | 4 | 4 | Number of token IDs in the penalty list |
_pad0 | uint32_t | 8 | 4 | Padding (unused) |
_pad1 | uint32_t | 12 | 4 | Padding (unused) |
Notes: For each token in the penalty list: if logit > 0, divide by penalty; if logit < 0, multiply by penalty. The token ID list is passed as a separate buffer binding. Called via akunu_gpu_repetition_penalty() in the C API.
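The divide-or-multiply rule reads naturally as a tiny CPU sketch (hypothetical helper; for penalty > 1, both branches push the logit toward lower probability):

```c
#include <stdint.h>

/* Apply the repetition penalty rule to each listed token id:
   positive logits are divided by penalty, negative ones multiplied. */
static void apply_penalty_ref(float *logits, const uint32_t *ids,
                              uint32_t n_tokens, float penalty) {
    for (uint32_t i = 0; i < n_tokens; i++) {
        float v = logits[ids[i]];
        logits[ids[i]] = (v > 0.0f) ? v / penalty : v * penalty;
    }
}
```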
MLXParams
Parameters for MLX-format quantized GEMV and embedding kernels. MLX uses group quantization: weights are packed with bits-per-value in groups of group_size, with FP16 scale and bias per group.
Total size: 32 bytes | Alignment: 16 bytes
| Field | Type | Offset | Size | Description |
|---|---|---|---|---|
M | uint32_t | 0 | 4 | Batch size (1 for GEMV, num_tokens for batch embedding) |
N | uint32_t | 4 | 4 | Output dimension (weight rows / vocab_size) |
K | uint32_t | 8 | 4 | Input dimension (unpacked element count) |
group_size | uint32_t | 12 | 4 | Quantization group size (typically 64) |
bits | uint32_t | 16 | 4 | Bits per quantized value (3, 4, 6, or 8) |
weight_bytes | uint32_t | 20 | 4 | Byte offset to scales section within weight buffer |
_pad0 | uint32_t | 24 | 4 | Padding (unused) |
_pad1 | uint32_t | 28 | 4 | Padding (unused) |
Notes: The MLX weight buffer layout is [packed_weights | scales | biases]. The weight_bytes field gives the byte offset where scales begin. Biases follow immediately after scales. The packed weight format differs by bit-width: 4-bit packs 8 values per uint32, 3-bit packs values with a more complex scheme, 8-bit uses one byte per value. The dequantization formula is: value = scale * (packed_int - bias) per group.
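For the 4-bit case, unpacking one value and applying the dequantization formula stated above looks roughly like this (illustrative sketch only; the real kernel additionally walks the group structure to locate each group's scale and bias in the [packed_weights | scales | biases] layout):

```c
#include <stdint.h>

/* Extract the `lane`-th 4-bit value from a packed uint32 (8 values per word)
   and dequantize it per the formula in the notes: scale * (q - bias). */
static float mlx_dequant4(uint32_t word, uint32_t lane, float scale, float bias) {
    uint32_t q = (word >> (4u * lane)) & 0xFu;
    return scale * ((float)q - bias);
}
```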
Conv1DParams
Parameters for 1D convolution, used by Whisper’s audio frontend.
Total size: 32 bytes | Alignment: 16 bytes
| Field | Type | Offset | Size | Description |
|---|---|---|---|---|
in_channels | uint32_t | 0 | 4 | Number of input channels |
out_channels | uint32_t | 4 | 4 | Number of output channels (filters) |
kernel_size | uint32_t | 8 | 4 | Convolution kernel width |
stride | uint32_t | 12 | 4 | Stride between convolution windows |
in_length | uint32_t | 16 | 4 | Input sequence length |
out_length | uint32_t | 20 | 4 | Output sequence length: (in_length + 2*padding - kernel_size) / stride + 1 |
padding | uint32_t | 24 | 4 | Zero-padding on each side of input |
_pad0 | uint32_t | 28 | 4 | Padding (unused) |
Notes: Used for Whisper’s two-layer Conv1D frontend that processes mel spectrograms. The first conv has kernel_size=3, stride=1, padding=1; the second has kernel_size=3, stride=2, padding=1, downsampling the time dimension by 2x.
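The out_length formula from the table, as a helper exercised with the two Whisper layer configurations described above (hypothetical function name; the 3000-frame input length in the test is the usual 30-second mel spectrogram, stated here as an assumption):

```c
#include <stdint.h>

/* Standard 1D convolution output length:
   (in_length + 2*padding - kernel_size) / stride + 1. */
static uint32_t conv1d_out_length(uint32_t in_length, uint32_t kernel_size,
                                  uint32_t stride, uint32_t padding) {
    return (in_length + 2 * padding - kernel_size) / stride + 1;
}
```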
GEMVKVParams
Parameters for the fused GEMV + KV cache write kernel. Combines a projection GEMV with direct KV cache insertion.
Total size: 32 bytes | Alignment: 16 bytes
| Field | Type | Offset | Size | Description |
|---|---|---|---|---|
N | uint32_t | 0 | 4 | Output dimension (n_kv_heads * head_dim) |
K | uint32_t | 4 | 4 | Input dimension |
head_dim | uint32_t | 8 | 4 | Dimension per head |
max_seq_len | uint32_t | 12 | 4 | KV cache max sequence length |
pos | uint32_t | 16 | 4 | Write position in cache. Patched per-token. |
_pad0 | uint32_t | 20 | 4 | Padding (unused) |
_pad1 | uint32_t | 24 | 4 | Padding (unused) |
_pad2 | uint32_t | 28 | 4 | Padding (unused) |
Notes: Each threadgroup computes a slice of the GEMV output and writes it directly to the KV cache buffer at the correct position, bypassing the intermediate scratch buffer. The cache write offset is (row / head_dim) * max_seq_len * head_dim + pos * head_dim + (row % head_dim), implementing the head-major layout.
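That offset expression, written out as a helper (hypothetical name; `row` is the GEMV output row index, and the div/mod pair recovers the head index and the position within the head):

```c
#include <stddef.h>
#include <stdint.h>

/* Head-major cache offset for GEMV output row `row`:
   (row / head_dim) * max_seq_len * head_dim + pos * head_dim + (row % head_dim). */
static size_t gemv_kv_offset(uint32_t row, uint32_t head_dim,
                             uint32_t max_seq_len, uint32_t pos) {
    return (size_t)(row / head_dim) * max_seq_len * head_dim
         + (size_t)pos * head_dim
         + (row % head_dim);
}
```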
Quick Reference Summary
| Struct | Size (bytes) | Used By | Patched Fields |
|---|---|---|---|
GEMMParams | 32 | simd_gemm_* (prefill) | None |
ElementwiseParams | 16 | Elementwise ops (add, residual) | None |
AttentionParams | 32 | Attention kernels | kv_seq_len |
RMSNormParams | 16 | rmsnorm_* | None |
LayerNormParams | 16 | layernorm_* (Whisper) | None |
SoftmaxParams | 16 | softmax_* | None |
RoPEParams | 32 | rope_f16, rope_neox_f16 | pos_offset |
EmbeddingParams | 16 | embedding_lookup_* | None |
KVCacheWriteParams | 32 | KV cache write kernels | pos |
RoPEQKVWriteParams | 36 | rope_qkv_write_f16, rope_neox_qkv_write_f16 | pos |
KVCacheShiftParams | 32 | KV cache shift kernel | None |
HeadNormParams | 32 | Per-head RMSNorm | None |
GEMVHeadNormParams | 32 | Fused GEMV + head norm | None |
TemperatureScaleParams | 16 | Temperature scaling | None |
RepetitionPenaltyParams | 16 | Repetition penalty | None |
MLXParams | 32 | MLX quantized GEMV/embedding | None |
Conv1DParams | 32 | Conv1D (Whisper frontend) | None |
GEMVKVParams | 32 | Fused GEMV + KV write | pos |
Alignment and Padding Rules
All structs follow these conventions:
- 16-byte total size minimum: Every struct is at least 16 bytes. Metal's `setBytes` requires 16-byte aligned data for argument buffer compatibility.
- Explicit padding fields: Rather than relying on compiler-inserted padding (which varies across compilers and platforms), akunu includes explicit `_pad0`, `_pad1`, etc. fields. This makes the layout identical whether compiled as C, C++, Objective-C++, or MSL.
- Natural alignment: All `uint32_t` and `float` fields are at 4-byte-aligned offsets. No field crosses a natural alignment boundary.
- Cross-language compatibility: The same header (`ShaderTypes.h`) is included by both Metal shaders and host code. An `#ifdef __METAL_VERSION__` guard switches between `<metal_stdlib>` and `<simd/simd.h>` for type compatibility, but the struct layouts are identical in both compilation contexts.
When adding a new parameter struct, follow the pattern: use uint32_t and float fields only, pad to a multiple of 16 bytes with explicit _padN fields, and add a // MARK: section comment with the struct size.
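A hypothetical new struct following that pattern, with a compile-time check of the size rule (the name and fields are invented for illustration):

```c
#include <stdint.h>

// MARK: ExampleParams (16 bytes)
/* Hypothetical example of a new parameter struct: uint32_t/float fields
   only, explicit _padN fields to reach a 16-byte multiple. */
typedef struct {
    uint32_t count;   /* elements to process */
    float    scale;   /* per-element multiplier */
    uint32_t _pad0;   /* padding (unused) */
    uint32_t _pad1;   /* padding (unused) */
} ExampleParams;

_Static_assert(sizeof(ExampleParams) % 16 == 0,
               "params structs must be a multiple of 16 bytes");
```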
[^1]: Zhang, B. & Sennrich, R. (2019). "Root Mean Square Layer Normalization." NeurIPS 2019. RMSNorm computes `x / sqrt(mean(x^2) + eps) * weight`, omitting the mean subtraction of standard LayerNorm. See https://arxiv.org/abs/1910.07467.