
SIMD Group Matrix Operations

Matrix multiplication is the beating heart of every transformer. In a single forward pass of a Llama-style 7B-parameter model, each of the 32 layers performs seven weight matrix multiplications: the Q, K, V, and output projections plus the gate, up, and down projections of the FFN. That is over two hundred large matmuls per token, before counting the attention products. If your matmul is slow, everything is slow. On Apple Silicon, the key to fast matmuls is the simdgroup_matrix API: a set of hardware-accelerated operations that let the 32 threads of a SIMD group cooperatively multiply 8x8 matrix tiles far faster than equivalent scalar code.[1]

This chapter explains how SIMD group matrix operations work, how they compose into larger GEMM/GEMV kernels, and how akunu uses them for both prefill and decode.

What Is a SIMD Group?

Before diving into matrix operations, let’s establish the execution model. An Apple GPU organizes threads into SIMD groups, the equivalent of warps on NVIDIA or wavefronts on AMD. A SIMD group on Apple Silicon is always 32 threads that execute in lockstep: every thread runs the same instruction in the same cycle.[2]

┌───────────────────────────────────────────────────┐
│             Threadgroup (128 threads)             │
│  ┌─────────┐ ┌─────────┐ ┌─────────┐ ┌─────────┐  │
│  │ SIMD 0  │ │ SIMD 1  │ │ SIMD 2  │ │ SIMD 3  │  │
│  │ 32 thds │ │ 32 thds │ │ 32 thds │ │ 32 thds │  │
│  └─────────┘ └─────────┘ └─────────┘ └─────────┘  │
└───────────────────────────────────────────────────┘

Within a SIMD group, threads can exchange data through SIMD shuffle instructions. A shuffle reads a value from any other thread in the group without going through memory. This is much cheaper than a round trip through threadgroup memory. SIMD group matrix operations build on the same kind of register-level cooperation between lanes.
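Here is a CPU-side C++ sketch that makes the shuffle idea concrete. It models a 32-lane SIMD group as an array and sums across it with shuffle-down steps. This is the classic tree pattern for cross-lane reductions; how Metal implements built-ins such as simd_sum internally is not documented. The Lanes type and the shuffle_down helper are hypothetical stand-ins invented for this illustration. On the GPU, each step would be one shuffle plus one add, executed by all 32 lanes at once.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// Hypothetical model of one SIMD group: one value per lane.
constexpr std::size_t kSimdWidth = 32;
using Lanes = std::array<float, kSimdWidth>;

// Emulated shuffle-down: lane i receives the value held by lane i + delta.
// Lanes whose source would fall off the end keep their own value in this model;
// the reduction below never reads those lanes, so the choice does not matter.
Lanes shuffle_down(const Lanes& v, std::size_t delta) {
    Lanes out = v;
    for (std::size_t i = 0; i + delta < kSimdWidth; ++i) {
        out[i] = v[i + delta];
    }
    return out;
}

// Tree reduction: after log2(32) = 5 shuffle+add steps, lane 0 holds the total.
float simd_reduce_sum(Lanes v, int* steps_out = nullptr) {
    int steps = 0;
    for (std::size_t delta = kSimdWidth / 2; delta > 0; delta /= 2) {
        Lanes shifted = shuffle_down(v, delta);
        for (std::size_t i = 0; i < kSimdWidth; ++i) {
            v[i] += shifted[i];  // every lane adds in lockstep
        }
        ++steps;
    }
    if (steps_out) *steps_out = steps;
    return v[0];
}
```

Five dependent steps replace 31 sequential additions, and no step touches memory.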

The simdgroup_matrix API

Starting with Apple GPU family 7 (M1 and later), Metal Shading Language provides the simdgroup_matrix type and associated operations.[3] The fundamental unit is an 8x8 matrix distributed across the 32 threads of a SIMD group:

#include <metal_simdgroup_matrix>

simdgroup_half8x8   mat_a;   // 8x8 matrix of float16
simdgroup_float8x8  mat_b;   // 8x8 matrix of float32

Each thread in the SIMD group holds a portion of the 8x8 matrix. You do not control which thread holds which element – the hardware distributes the 64 elements across 32 threads (2 elements per thread) in a hardware-defined layout. This layout is opaque; you interact with the matrix through three operations:

load

simdgroup_half8x8 ma;
simdgroup_load(ma, src_ptr, stride);

Load an 8x8 tile from device or threadgroup memory. The stride is the number of elements between consecutive rows. All 32 threads participate in the load cooperatively – each thread reads its assigned 2 elements.

store

simdgroup_store(mc, dst_ptr, stride);

Store an 8x8 tile back to memory. Same cooperative pattern as load.

multiply_accumulate

simdgroup_multiply_accumulate(mc, ma, mb, mc);
// mc += ma * mb   (8x8 += 8x8 * 8x8)

This is the core operation. It multiplies two 8x8 matrices and accumulates the result into a third. That is 512 multiply-accumulate operations (8 * 8 * 8), issued as a single intrinsic call across the SIMD group.[4] Apple does not publish the cycle count. It is nonetheless substantially faster than doing the same work with hand-written scalar or vector instructions.
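The semantics of simdgroup_multiply_accumulate, though not its speed, fit in a few lines of scalar C++. This reference model uses the hypothetical names Mat8 and mac8x8. It treats the opaque matrix as a plain row-major 8x8 array and counts the multiply-adds, which confirms the 8 * 8 * 8 = 512 figure.

```cpp
#include <array>
#include <cassert>

// Plain row-major 8x8 matrix: a stand-in for the opaque simdgroup matrix type.
using Mat8 = std::array<std::array<float, 8>, 8>;

// Reference semantics of simdgroup_multiply_accumulate(d, a, b, c): d = a * b + c.
// Also counts the scalar multiply-adds performed.
Mat8 mac8x8(const Mat8& a, const Mat8& b, const Mat8& c, long* fma_count) {
    Mat8 d = c;
    for (int i = 0; i < 8; ++i)
        for (int j = 0; j < 8; ++j)
            for (int k = 0; k < 8; ++k) {
                d[i][j] += a[i][k] * b[k][j];
                ++*fma_count;
            }
    return d;
}
```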

How 32 Threads Hold One 8x8 Matrix

Let’s be precise about the data distribution. An 8x8 matrix has 64 elements. A SIMD group has 32 threads. So each thread holds exactly 2 elements. The exact mapping is hardware-defined and opaque, but conceptually:

32 Threads Holding an 8x8 Matrix (each thread holds 2 elements):

     col0  col1  col2  col3  col4  col5  col6  col7
    ┌─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┐
r0  │ T0  │ T0  │ T1  │ T1  │ T2  │ T2  │ T3  │ T3  │
    ├─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┤
r1  │ T4  │ T4  │ T5  │ T5  │ T6  │ T6  │ T7  │ T7  │
    ├─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┤
r2  │ T8  │ T8  │ T9  │ T9  │ T10 │ T10 │ T11 │ T11 │
    ├─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┤
r3  │ T12 │ T12 │ T13 │ T13 │ T14 │ T14 │ T15 │ T15 │
    ├─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┤
r4  │ T16 │ T16 │ T17 │ T17 │ T18 │ T18 │ T19 │ T19 │
    ├─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┤
r5  │ T20 │ T20 │ T21 │ T21 │ T22 │ T22 │ T23 │ T23 │
    ├─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┤
r6  │ T24 │ T24 │ T25 │ T25 │ T26 │ T26 │ T27 │ T27 │
    ├─────┼─────┼─────┼─────┼─────┼─────┼─────┼─────┤
r7  │ T28 │ T28 │ T29 │ T29 │ T30 │ T30 │ T31 │ T31 │
    └─────┴─────┴─────┴─────┴─────┴─────┴─────┴─────┘
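The conceptual layout in the diagram can be written as a mapping from thread index to the two cells that thread holds. This sketch covers the illustrative layout only; Apple’s real layout is undisclosed. elements_of_thread and layout_covers_matrix_once are hypothetical helpers. Together they check that 32 threads x 2 elements cover all 64 cells exactly once.

```cpp
#include <array>
#include <cassert>
#include <utility>

// Illustrative layout from the diagram (NOT Apple's real layout):
// thread t holds row t / 4, columns 2 * (t % 4) and 2 * (t % 4) + 1.
std::array<std::pair<int, int>, 2> elements_of_thread(int t) {
    int row = t / 4;
    int col = 2 * (t % 4);
    return {{{row, col}, {row, col + 1}}};
}

// True if 32 threads x 2 elements hit every cell of the 8x8 matrix exactly once.
bool layout_covers_matrix_once() {
    int hits[8][8] = {};
    for (int t = 0; t < 32; ++t)
        for (auto [r, c] : elements_of_thread(t)) ++hits[r][c];
    for (auto& row : hits)
        for (int h : row)
            if (h != 1) return false;
    return true;
}
```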

The key insight: in normal use you never access individual elements of a simdgroup_matrix. You load, you multiply-accumulate, you store, and the hardware handles the internal layout. Extracting element [3][5] would mean finding the lane that holds it and shuffling the value out, which defeats the purpose.

Building Larger Matmuls from 8x8 Tiles

An 8x8 tile is too small for real work. A typical LLM linear layer might be [4096, 4096] – that is 512 x 512 tiles. The strategy is to:

  1. Assign tiles to threadgroups. Each threadgroup computes a rectangular block of the output matrix.
  2. Within a threadgroup, each SIMD group accumulates its own block of 8x8 output sub-tiles.
  3. Loop over the K dimension in steps of 8 (or larger), accumulating partial sums.

Akunu’s GEMM kernels follow the llama.cpp tile geometry, which uses TM=32 and TN=64:

Output tile per threadgroup: 32 rows (M) x 64 columns (N)
Threadgroup size: 128 threads = 4 SIMD groups
Each SIMD group handles: 32 rows x 16 columns (using 8x8 sub-tiles)

K-loop: stride 32 elements (NK=32)
  - Load A tile [32 x 32] into threadgroup memory
  - Load B tile [64 x 32] into threadgroup memory
  - Each SIMD group does: 4 x 2 sub-tiles x 4 K-steps = 32 8x8 MACs per K-iteration

Here is how this maps to akunu’s simd_gemm_f16.metal:

// Dispatch: grid=(ceil(N/64), ceil(M/32), 1), threads=(32,4,1)
// TG memory: 4096 + 2048 = 6144 bytes

constexpr short NR0 = 64;   // N tile (weight rows)
constexpr short NR1 = 32;   // M tile (activation rows)
constexpr short NK  = 32;   // K tile (reduction dimension)

simdgroup_half8x8 mc[8];    // 8 accumulator tiles per SIMD group
for (short i = 0; i < 8; i++) {
    mc[i] = make_filled_simdgroup_matrix<half, 8>(0.h);
}

for (uint loop_k = 0; loop_k < K; loop_k += NK) {
    // Load A and B tiles into threadgroup memory
    // ...
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Multiply: 8x8 tiles along K
    for (short k_step = 0; k_step < NK; k_step += 8) {
        simdgroup_half8x8 ma, mb;
        simdgroup_load(ma, sa + ..., ...);
        for (short tile = 0; tile < 8; tile++) {
            simdgroup_load(mb, sb + ..., ...);
            simdgroup_multiply_accumulate(mc[tile], ma, mb, mc[tile]);
        }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
}

// Store accumulated results
for (short tile = 0; tile < 8; tile++) {
    simdgroup_store(mc[tile], C + ..., ldc);
}

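The accumulator array mc[8] holds 8 output sub-tiles (8 x 64 = 512 elements per SIMD group). Four SIMD groups x 512 elements = 2,048 = 32 x 64, which is exactly one threadgroup tile. Each SIMD group might own a 32x16 block (4 x 2 sub-tiles, as in the geometry above) or an 8x64 strip (1 x 8 sub-tiles, as the simplified loop suggests). Which one depends on how the kernel indexes sa and sb. The K-loop consumes 32 elements of K per iteration in four 8-wide steps. Each step performs one 8x8x8 MAC per accumulator.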
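The tiling loop can be checked on the CPU. This sketch is not akunu’s kernel. It computes C[M,N] = A[M,K] * B^T[N,K] one 8x8 output tile at a time, stepping K by 8 exactly as the accumulator loop does, and compares the result against a naive triple loop. It assumes M, N, and K are multiples of 8 and accumulates in float; tiled_gemm and naive_gemm are illustrative names.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// C[M,N] = A[M,K] * B^T[N,K], one 8x8 output tile at a time, K stepped by 8.
// Assumes M, N, K are multiples of 8 (real kernels also handle ragged edges).
std::vector<float> tiled_gemm(const std::vector<float>& A, const std::vector<float>& B,
                              int M, int N, int K) {
    std::vector<float> C(static_cast<std::size_t>(M) * N, 0.0f);
    for (int tm = 0; tm < M; tm += 8)          // output tile row
        for (int tn = 0; tn < N; tn += 8) {    // output tile column
            float acc[8][8] = {};              // like one zero-filled simdgroup accumulator
            for (int k0 = 0; k0 < K; k0 += 8)  // one 8x8x8 multiply-accumulate per step
                for (int i = 0; i < 8; ++i)
                    for (int j = 0; j < 8; ++j)
                        for (int k = 0; k < 8; ++k)
                            acc[i][j] += A[(tm + i) * K + k0 + k] * B[(tn + j) * K + k0 + k];
            for (int i = 0; i < 8; ++i)        // "store" the finished tile
                for (int j = 0; j < 8; ++j)
                    C[(tm + i) * N + tn + j] = acc[i][j];
        }
    return C;
}

// Naive reference: one dot product per output element.
std::vector<float> naive_gemm(const std::vector<float>& A, const std::vector<float>& B,
                              int M, int N, int K) {
    std::vector<float> C(static_cast<std::size_t>(M) * N, 0.0f);
    for (int m = 0; m < M; ++m)
        for (int n = 0; n < N; ++n)
            for (int k = 0; k < K; ++k)
                C[m * N + n] += A[m * K + k] * B[n * K + k];
    return C;
}
```

With small integer inputs, both versions produce bit-identical results. With general floating-point data the summation order differs, so the results match only to rounding.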

The Tiling Hierarchy

Let’s trace how a large matmul decomposes:

Full matrix:  C[M,N] = A[M,K] * B^T[N,K]
Example:      C[4096, 4096] = A[4096, 4096] * B^T[4096, 4096]

Grid level:   ceil(4096/64) x ceil(4096/32) = 64 x 128 threadgroups
TG level:     Each TG computes C[32, 64] using 128 threads (4 SIMDs)
SIMD level:   Each SIMD group accumulates 8 sub-tiles (512 elements) of that C[32, 64] tile
Tile level:   Each 8x8 MAC produces C[8,8] += A[8,8] * B[8,8]
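| Level            | Size        | Unit              | Count                 |
|------------------|-------------|-------------------|-----------------------|
| Full matrix      | 4096 x 4096 | Elements          | 16.8M (16,777,216)    |
| Threadgroup tile | 32 x 64     | Elements          | 8,192 tiles           |
| SIMD accumulator | 8 x 8       | Elements per tile | 8 tiles/SIMD          |
| Single MAC       | 8 x 8 x 8   | FMA operations    | 512 ops               |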
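The counts in the table follow from a few integer divisions. This sketch assumes the 32 x 64 threadgroup tile, 4 SIMD groups, and 8x8 sub-tiles described above. GemmGrid and plan_gemm are hypothetical names, not akunu APIs.

```cpp
#include <cassert>
#include <cstdint>

constexpr std::uint64_t ceil_div(std::uint64_t a, std::uint64_t b) { return (a + b - 1) / b; }

// Geometry from the text: 32 (M) x 64 (N) output tile per threadgroup,
// 4 SIMD groups per threadgroup, 8x8 sub-tiles.
struct GemmGrid {
    std::uint64_t tg_x, tg_y;         // threadgroups along N and along M
    std::uint64_t threadgroups;       // total threadgroups in the grid
    std::uint64_t subtiles_per_simd;  // 8x8 accumulators per SIMD group
    std::uint64_t macs_per_subtile;   // 8x8x8 MACs to finish one output sub-tile
};

constexpr GemmGrid plan_gemm(std::uint64_t M, std::uint64_t N, std::uint64_t K) {
    GemmGrid g{};
    g.tg_x = ceil_div(N, 64);
    g.tg_y = ceil_div(M, 32);
    g.threadgroups = g.tg_x * g.tg_y;
    g.subtiles_per_simd = (32 * 64) / (8 * 8) / 4;  // 32 sub-tiles shared by 4 SIMD groups
    g.macs_per_subtile = ceil_div(K, 8);
    return g;
}
```

For the 4096-cubed example, the totals multiply back to 4096^3 FMAs, matching the naive count.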

Akunu’s GEMM Dispatch Geometry

Looking at simd_gemm_f16.metal, the dispatch parameters are:

grid    = (ceil(N/64), ceil(M/32), 1)
threads = (32, 4, 1)   // 128 threads = 4 SIMD groups
TG memory = 6144 bytes  // 4096 for B (weight) tile [64 x 32] + 2048 for A (activation) tile [32 x 32], FP16

The threadgroup of 128 threads breaks down as:

  • 4 SIMD groups (sgitg = 0..3)
  • Each SIMD group computes its own 512-element share (8 sub-tiles of 8x8) of the 32 x 64 output tile
  • All 4 SIMD groups share the same B-tile data (loaded cooperatively into threadgroup memory)

For the K-loop:

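  • NL0 = NK/16 = 2: each 32-element row of the B tile is split into 2 chunks of 16, so its 64 rows give 128 chunks, one per thread
  • NL1 = NK/8 = 4: each 32-element row of the A tile is split into 4 chunks of 8, so its 32 rows also give 128 chunks, one per thread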
  • Data flows: device memory -> threadgroup memory -> SIMD registers -> accumulate -> device memory
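The threadgroup memory budget, and one load assignment consistent with NL0 and NL1, can be checked at compile time. The FP16 tile shapes come from the text. The row/chunk mapping in b_chunk and a_chunk is an assumption modeled on llama.cpp’s kernel, not read from akunu’s source.

```cpp
#include <cassert>
#include <cstddef>

// Tile shapes from the text: A tile [NR1 x NK] = [32 x 32], B tile [NR0 x NK] = [64 x 32], FP16.
constexpr int NR0 = 64, NR1 = 32, NK = 32, kThreads = 128;
constexpr std::size_t kHalfBytes = 2;

constexpr std::size_t a_tile_bytes = NR1 * NK * kHalfBytes;  // 2048
constexpr std::size_t b_tile_bytes = NR0 * NK * kHalfBytes;  // 4096
static_assert(a_tile_bytes + b_tile_bytes == 6144, "matches the dispatch's TG memory");

// Assumed assignment: NL0 threads share each B row, NL1 threads share each A row.
constexpr int NL0 = NK / 16;  // 2 chunks of 16 elements per B row
constexpr int NL1 = NK / 8;   // 4 chunks of 8 elements per A row

struct Chunk { int row; int col; int len; };

constexpr Chunk b_chunk(int tid) { return {tid / NL0, (tid % NL0) * 16, 16}; }
constexpr Chunk a_chunk(int tid) { return {tid / NL1, (tid % NL1) * 8, 8}; }

static_assert(NR0 * NL0 == kThreads, "every thread loads exactly one B chunk");
static_assert(NR1 * NL1 == kThreads, "every thread loads exactly one A chunk");
```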

Small-M Variant: simd_gemm_small_f16

During decode (M between 1 and about 8), the full GEMM is wasteful because most of the 32-row M-tile is empty. Akunu provides simd_gemm_small_f16, which is optimized for 2 <= M <= 8:

// From dtype_descriptor.h:
return (M >= 2 && M <= 8) ? d.gemm_small_kernel : d.gemm_kernel;

The small variant uses a reduced M-tile size and fewer SIMD groups, trading parallelism for reduced overhead. For M=1 (single-token decode), akunu uses GEMV kernels instead of GEMM entirely – the dispatch geometry is fundamentally different.
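The selection logic reduces to a small function. Only the ternary is quoted from dtype_descriptor.h. The KernelNames struct and the explicit M == 1 branch are a hypothetical sketch of how the GEMV path could sit alongside it.

```cpp
#include <cassert>
#include <string>

// Hypothetical mirror of the relevant DTypeDescriptor fields.
struct KernelNames {
    std::string gemv_kernel;
    std::string gemm_small_kernel;
    std::string gemm_kernel;
};

std::string select_matmul_kernel(const KernelNames& d, int M) {
    if (M == 1) return d.gemv_kernel;  // single-token decode: matrix-vector path
    return (M >= 2 && M <= 8) ? d.gemm_small_kernel : d.gemm_kernel;  // ternary from the text
}
```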

GEMV: The Decode Workhorse

During autoregressive decode, M=1. This is a matrix-vector product, not a matrix-matrix product. The access pattern changes dramatically:

  • GEMM (prefill): Each weight element is reused across M activation rows. Arithmetic intensity is O(M). With M=4096, you get excellent compute utilization.
  • GEMV (decode): Each weight element is used exactly once. Arithmetic intensity is O(1). You are completely bandwidth-bound.[5]
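The O(M) versus O(1) contrast is easy to quantify. This sketch counts only weight bytes, a simplification that ignores activation and output traffic. It shows that intensity grows linearly with M. The 18/32 bytes-per-weight figure for Q4_0 follows from its 18-byte, 32-weight block. The function name is illustrative.

```cpp
#include <cassert>
#include <cmath>

// FLOPs per byte of weight traffic for C[M,N] = A[M,K] * W^T[N,K].
// Simplification: only weight bytes are counted; each weight is read once.
double weight_arithmetic_intensity(double M, double N, double K, double bytes_per_weight) {
    double flops = 2.0 * M * N * K;                  // one multiply + one add per MAC
    double weight_bytes = N * K * bytes_per_weight;  // whole weight matrix streamed once
    return flops / weight_bytes;                     // simplifies to 2 * M / bytes_per_weight
}
```

With FP16 weights, M=1 gives 1 FLOP/byte. Quantizing to Q4_0 raises that to about 3.6 FLOP/byte, which is still far below what it takes to keep the matrix hardware busy.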

Akunu’s GEMV kernels do not use simdgroup_matrix at all – they use vectorized loads and SIMD reduction instead. The GEMV kernel for Q4_0 quantized weights, for example, dequantizes blocks of 32 weights, multiplies them by the input vector, and reduces across the SIMD group using simd_sum.
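The reduction structure can be sketched on the CPU. Each of 32 lanes accumulates a partial dot product over a strided slice of the row, and a simd_sum-style reduction combines the partials. This sketch uses plain float weights to isolate the lane split, so dequantization is omitted. gemv_rows and simd_sum_emulated are hypothetical names, and the real kernel’s per-lane assignment may differ.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <vector>

constexpr std::size_t kLanes = 32;

// Emulated simd_sum: adds one partial per lane (on the GPU, every lane receives the total).
float simd_sum_emulated(const std::array<float, kLanes>& partial) {
    float total = 0.0f;
    for (float p : partial) total += p;
    return total;
}

// y[row] = dot(W[row, :], x). Lane l handles columns l, l + 32, l + 64, ...
// (a strided split chosen for clarity; real kernels use vectorized loads).
std::vector<float> gemv_rows(const std::vector<float>& W, const std::vector<float>& x,
                             std::size_t N, std::size_t K) {
    std::vector<float> y(N, 0.0f);
    for (std::size_t row = 0; row < N; ++row) {
        std::array<float, kLanes> partial{};
        for (std::size_t lane = 0; lane < kLanes; ++lane)
            for (std::size_t col = lane; col < K; col += kLanes)
                partial[lane] += W[row * K + col] * x[col];
        y[row] = simd_sum_emulated(partial);
    }
    return y;
}
```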

The DTypeDescriptor table maps each quantization format to its GEMV kernel and dispatch geometry:

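| Dtype  | GEMV kernel | Rows/TG | TG size | Wide variant                            |
|--------|-------------|---------|---------|-----------------------------------------|
| F16    | gemv_f16    | 16      | 128     | gemv_wide_f16 (64 rows, 256 threads)    |
| Q4_0   | gemv_q4_0   | 16      | 128     | gemv_wide_q4_0 (64 rows, 256 threads)   |
| Q4_K   | gemv_q4_k   | 16      | 256     | gemv_wide_q4_k (32 rows, 256 threads)   |
| Q8_0   | gemv_q8_0   | 32      | 256     | gemv_wide_q8_0 (64 rows, 256 threads)   |
| MLX Q4 | gemv_mlx_q4 | 16      | 128     | gemv_wide_mlx_q4 (32 rows, 256 threads) |
| BF16   | gemv_bf16   | 16      | 128     | N/A                                     |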

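The “wide” variants process more rows per threadgroup (32 or 64) using 256-thread threadgroups (8 SIMD groups). The goal is higher occupancy on Pro/Max chips with many GPU cores. For Q4_K and Q8_0, the standard variants already use 256 threads, so only the row count grows. The ChipConfig controls when to switch: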

c.wide_gemv_threshold = 32768;  // use wide GEMV when N exceeds this
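Putting the table and the threshold together for Q4_0 gives a small selection function. The rows/threads values are copied from the table, and the rule follows the threshold comment. Treating the grid as ceil(N / rows_per_tg) threadgroups is an assumption of this sketch. GemvConfig, pick_q4_0_gemv, and gemv_threadgroups are hypothetical names.

```cpp
#include <cassert>
#include <cstdint>

struct GemvConfig { std::uint32_t rows_per_tg; std::uint32_t tg_threads; bool wide; };

// Q4_0 row from the table; "wide" only when N exceeds the threshold.
GemvConfig pick_q4_0_gemv(std::uint32_t N, std::uint32_t wide_gemv_threshold = 32768) {
    if (N > wide_gemv_threshold) return {64, 256, true};  // gemv_wide_q4_0
    return {16, 128, false};                              // gemv_q4_0
}

// Assumed grid size: enough threadgroups to cover all N output rows.
std::uint32_t gemv_threadgroups(std::uint32_t N, const GemvConfig& c) {
    return (N + c.rows_per_tg - 1) / c.rows_per_tg;  // ceil(N / rows_per_tg)
}
```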

The Accumulation Flow: From Tiles to Output

                    8x8 Tile Accumulation Loop

   K=0..7       K=8..15      K=16..23              Result
  ┌──────┐     ┌──────┐     ┌──────┐              ┌──────┐
  │A[8x8]│     │A[8x8]│     │A[8x8]│              │      │
  └──┬───┘     └──┬───┘     └──┬───┘              │C[8x8]│
     ×            ×            ×          ...      │      │
  ┌──┴───┐     ┌──┴───┐     ┌──┴───┐              │Accum │
  │B[8x8]│     │B[8x8]│     │B[8x8]│              └──────┘
  └──┬───┘     └──┬───┘     └──┬───┘
     +            +            +  ──────────────►

The accumulation loop for a single output tile works as follows:

  1. Initialize mc = 0 (zero-filled 8x8 accumulator)
  2. For each K-step (stride 8 or 32):
    • Load an 8x8 slice of A into ma
    • Load an 8x8 slice of B into mb
    • mc += ma * mb (hardware multiply-accumulate)
  3. After K-loop completes, store mc to output

For K=4096, this loop runs 4096/8 = 512 times, or 4096/32 = 128 times if the K-stride is 32 with 4 sub-steps. Each 8-wide step performs 512 FMA operations. Total: 512 * 512 = 262,144 FMA operations per 8x8 output tile. That is exactly 8 * 8 * 4096, the correct number.

Quantized GEMM: Dequantize into Threadgroup Memory

For quantized weights (Q4_0, Q4_K, Q8_0, etc.), the GEMM kernels follow the same tiling strategy but add a dequantization step. The weight data is loaded in its quantized format and dequantized into threadgroup memory before the SIMD matrix operations:

For each K-step:
  1. Load quantized weight block from device memory
  2. Dequantize to FP16 in threadgroup memory
  3. Load A tile into threadgroup memory (already FP16)
  4. barrier()
  5. SIMD matrix multiply-accumulate from threadgroup memory
  6. barrier()

Akunu has separate GEMM kernels for each quantization format (simd_gemm_q4_0, simd_gemm_q4_k, simd_gemm_q8_0, etc.) because the dequantization logic is tightly coupled with the load pattern. A Q4_0 block is 18 bytes: 32 four-bit values packed into 16 bytes, plus one FP16 scale. A Q4_K super-block is 144 bytes with nested sub-scales. The memory access pattern and threadgroup memory layout differ for each.
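This sketch assumes akunu’s Q4_0 follows the GGML block layout; the 18-byte size matches. Each of the 16 payload bytes carries two 4-bit codes, and each weight is (code - 8) * d. BlockQ4_0, half_to_float, and dequantize_q4_0 are illustrative names. A real kernel would write the 32 results to threadgroup memory rather than return them.

```cpp
#include <array>
#include <cassert>
#include <cmath>
#include <cstdint>

// Assumed GGML-style Q4_0 block: 32 weights share one FP16 scale d.
// Byte j holds weight j (low nibble) and weight j + 16 (high nibble).
struct BlockQ4_0 {
    std::uint16_t d_bits;  // raw FP16 bits of the scale
    std::uint8_t qs[16];   // 32 four-bit codes
};
static_assert(sizeof(BlockQ4_0) == 18, "2-byte scale + 16 bytes of nibbles");

// IEEE half -> float (handles zero, subnormals, inf, NaN).
float half_to_float(std::uint16_t h) {
    int sign = (h >> 15) & 1;
    int exp = (h >> 10) & 0x1F;
    int mant = h & 0x3FF;
    float v;
    if (exp == 0) v = std::ldexp(static_cast<float>(mant), -24);             // zero/subnormal
    else if (exp == 31) v = mant ? NAN : INFINITY;                           // inf/NaN
    else v = std::ldexp(static_cast<float>(mant | 0x400), exp - 25);         // normal
    return sign ? -v : v;
}

std::array<float, 32> dequantize_q4_0(const BlockQ4_0& b) {
    const float d = half_to_float(b.d_bits);
    std::array<float, 32> out{};
    for (int j = 0; j < 16; ++j) {
        out[j]      = static_cast<float>((b.qs[j] & 0x0F) - 8) * d;  // low nibble
        out[j + 16] = static_cast<float>((b.qs[j] >> 4) - 8) * d;    // high nibble
    }
    return out;
}
```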

MLX Quantized Formats

Akunu also supports MLX SafeTensors quantized models. These use a different quantization format (affine quantization with configurable group size and bit width) and require function constants to specialize the kernel at pipeline creation time:

// From emit_gemv() in table_builder.cpp:
if (dt.is_mlx) {
    uint32_t fc_indices[] = {0, 1};
    uint32_t fc_values[] = {(uint32_t)quant_group_size, (uint32_t)K};
    pso = device.get_pipeline(kernel_name, cache_key, fc_indices, fc_values, 2);
}

Function constants are Metal’s mechanism for compile-time specialization. Instead of branching on group_size inside the kernel, you bake the value into the pipeline state object. The compiler can then optimize the kernel for that specific group size by unrolling loops, eliminating dead branches, and computing strides at compile time.[6]
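C++ has no function constants, but templates express the same idea: a value fixed at compile time lets the compiler specialize the code. This sketch contrasts a runtime group size with a template parameter for MLX-style affine dequantization (w = scale * q + bias per group). The codes are stored one per byte for clarity, which is an assumption; MLX packs several codes per 32-bit word. The function names are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Runtime version: group_size is only known when the function runs.
float dequant_runtime(const std::uint8_t* codes, const float* scales, const float* biases,
                      std::size_t i, std::size_t group_size) {
    std::size_t g = i / group_size;           // division by a runtime value
    return scales[g] * codes[i] + biases[g];  // affine: w = scale * q + bias
}

// Specialized version: GroupSize is a compile-time constant, so the compiler can
// turn i / GroupSize into a shift (for powers of two) and unroll per-group loops.
template <std::size_t GroupSize>
float dequant_specialized(const std::uint8_t* codes, const float* scales, const float* biases,
                          std::size_t i) {
    static_assert(GroupSize > 0, "group size must be positive");
    std::size_t g = i / GroupSize;
    return scales[g] * codes[i] + biases[g];
}
```

The results are identical. Only the information available to the optimizer differs, which mirrors what passing quant_group_size as a function constant buys in the Metal pipeline.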

Comparison to NVIDIA Tensor Cores

It is instructive to compare Apple’s SIMD group matrix operations to NVIDIA’s Tensor Cores:

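| Feature             | Apple simdgroup_matrix                          | NVIDIA Tensor Cores (Ampere+)                              |
|---------------------|-------------------------------------------------|------------------------------------------------------------|
| Tile size           | 8x8                                             | 16x8x16 (varies by precision)                              |
| Thread group        | 32 threads (SIMD group)                         | 32 threads (warp)                                          |
| Precision           | FP16, FP32, BF16 (M4+)                          | FP16, BF16, TF32, INT8; FP8 from Hopper/Ada                |
| Programming model   | MSL intrinsics                                  | WMMA API / PTX mma instructions                            |
| Throughput per core | ~0.3 TFLOPS (M1: 2.6 TFLOPS FP32 / 8 cores)     | ~2.9 TFLOPS FP16 per SM (A100: 312 TFLOPS dense / 108 SMs) |
| Memory model        | Unified memory (UMA), shared with the CPU       | Discrete HBM                                               |
| Matrix hardware     | Integrated in GPU cores                         | Dedicated hardware units                                   |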

The key architectural difference: NVIDIA Tensor Cores are dedicated hardware units separate from the CUDA cores. Apple’s SIMD matrix operations are executed by the same ALUs that do regular floating-point math; they are an instruction set extension, not a separate unit.[7] This means:

  1. No mode switching. You can freely interleave matrix operations with scalar/vector code. On NVIDIA, switching between Tensor Core and CUDA core work can cause pipeline bubbles.

  2. Lower peak throughput per core. Apple’s matrix multiply throughput is lower than NVIDIA’s dedicated Tensor Cores. But Apple compensates with higher memory bandwidth per FLOP (critical for inference) and the zero-copy UMA advantage.

  3. Simpler programming model. The simdgroup_matrix API is genuinely easier to use than NVIDIA’s WMMA or inline PTX MMA instructions. Load, store, multiply-accumulate – that’s it.

BF16 Support on M4

Starting with Apple GPU family 9 (M4), Metal supports native bfloat (BF16) as a first-class type.[8] Akunu has dedicated BF16 kernels:

// From dtype_descriptor.h:
{31, "gemv_bf16", nullptr, nullptr,
 "simd_gemm_bf16", "simd_gemm_small_bf16",
 "embedding_lookup_bf16", ...}

BF16 has the same exponent range as FP32 (8 exponent bits) but only 7 bits of mantissa, versus FP16’s 5 exponent and 10 mantissa bits. This makes it better suited to values with a wide dynamic range, such as values that would overflow FP16, where range matters more than precision. The simdgroup_matrix API supports simdgroup_bfloat8x8 on M4, enabling hardware-accelerated BF16 matrix multiply.
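The bit-level relationship between FP32 and BF16 makes the range/precision trade-off concrete: a BF16 value is the top 16 bits of an FP32 value. This sketch rounds to nearest-even and skips NaN handling. float_to_bf16 and bf16_to_float are illustrative names.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// FP32 -> BF16: keep sign, 8-bit exponent, and top 7 mantissa bits,
// rounding to nearest-even. NaN inputs are not special-cased here.
std::uint16_t float_to_bf16(float f) {
    std::uint32_t bits;
    std::memcpy(&bits, &f, sizeof bits);
    std::uint32_t lsb = (bits >> 16) & 1u;  // tie-break toward an even result
    bits += 0x7FFFu + lsb;
    return static_cast<std::uint16_t>(bits >> 16);
}

// BF16 -> FP32: exact, the dropped mantissa bits become zero.
float bf16_to_float(std::uint16_t b) {
    std::uint32_t bits = static_cast<std::uint32_t>(b) << 16;
    float f;
    std::memcpy(&f, &bits, sizeof f);
    return f;
}
```

FP16’s largest finite value is 65504, so it cannot represent 70000 at all. BF16 keeps the magnitude but stores it as 70144: the range is kept, while the 7-bit mantissa costs precision.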

Fused SiLU + GEMV

One of akunu’s more aggressive optimizations is the fused SiLU+down GEMV kernel. In a standard SwiGLU FFN block:

gate = GEMV(gate_weight, x)
up   = GEMV(up_weight, x)
act  = SiLU(gate) * up
down = GEMV(down_weight, act)

That’s 3 GEMV dispatches + 1 activation dispatch = 4 kernel launches. The fused kernel combines the last two steps:

down[i] = sum_j( SiLU(gate[j]) * up[j] * down_weight[i,j] )

This reads the gate and up vectors, applies SiLU element-wise, and immediately multiplies by the down weight – all in one kernel, one pass over the down weight matrix. The DTypeDescriptor tracks which formats have fused kernels:

const char *fused_silu_kernel;        // "gemv_q4_0_silu", etc.
const char *fused_silu_large_kernel;  // Wide variant for Pro+

Not all formats have fused kernels (Q4_K does not, for example), so the table builder falls back to separate activation + GEMV when unavailable.
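Fusion changes where SiLU(gate) * up is computed, not the result. This CPU sketch uses hypothetical names and float math. It compares the unfused path, which materializes the act vector, against a single pass that computes it on the fly inside the down GEMV.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

float silu(float x) { return x / (1.0f + std::exp(-x)); }  // x * sigmoid(x)

// Unfused: a separate step writes act = SiLU(gate) * up, then the down GEMV reads it.
std::vector<float> down_unfused(const std::vector<float>& Wd, const std::vector<float>& gate,
                                const std::vector<float>& up, std::size_t N, std::size_t K) {
    std::vector<float> act(K);
    for (std::size_t j = 0; j < K; ++j) act[j] = silu(gate[j]) * up[j];
    std::vector<float> out(N, 0.0f);
    for (std::size_t i = 0; i < N; ++i)
        for (std::size_t j = 0; j < K; ++j) out[i] += Wd[i * K + j] * act[j];
    return out;
}

// Fused: SiLU(gate[j]) * up[j] is computed inside the single pass over Wd.
std::vector<float> down_fused(const std::vector<float>& Wd, const std::vector<float>& gate,
                              const std::vector<float>& up, std::size_t N, std::size_t K) {
    std::vector<float> out(N, 0.0f);
    for (std::size_t i = 0; i < N; ++i)
        for (std::size_t j = 0; j < K; ++j) out[i] += Wd[i * K + j] * (silu(gate[j]) * up[j]);
    return out;
}
```

This fused version recomputes SiLU for every output row. That extra arithmetic is the price of never writing and re-reading act, and in a bandwidth-bound GEMV, arithmetic is the cheaper resource.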

Practical Performance Numbers

To put all of this in perspective, here is what these kernel choices mean for actual throughput. On an M4 Pro (20 GPU cores, 273 GB/s bandwidth):

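| Operation                  | Kernel type                         | Time (7B Q4_0)            | Bottleneck       |
|----------------------------|-------------------------------------|---------------------------|------------------|
| Prefill GEMM (M=512)       | simd_gemm_q4_0                      | ~15 ms/layer              | Compute-bound    |
| Decode GEMV (M=1)          | gemv_q4_0                           | ~0.35 ms/layer            | Bandwidth-bound  |
| Fused SiLU+down GEMV       | gemv_q4_0_silu                      | ~0.30 ms (saves ~0.05 ms) | Bandwidth-bound  |
| Flash attention (seq=2048) | flash_attention_decode_parallel_f16 | ~0.08 ms/layer            | Compute/BW mixed |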

The GEMV kernels dominate decode time because they read the most data. Everything else – norms, activations, attention – is comparatively cheap. This is why akunu spends so much effort on GEMV kernel variants, wide vs. standard selection, and fused operations.

Summary

Apple Silicon’s simdgroup_matrix API provides hardware-accelerated 8x8 matrix multiplication that serves as the building block for larger GEMM operations. Akunu uses the llama.cpp tile geometry (TM=32, TN=64) for prefill GEMM and specialized GEMV kernels for decode. The choice between GEMM and GEMV, and between standard and wide variants, is driven by the DTypeDescriptor table and ChipConfig thresholds – all resolved at dispatch table build time, not at runtime.

The next chapter explores the broader set of performance optimization patterns that make these kernels – and the overall inference pipeline – fast.



  1. Apple, “Metal Shading Language Specification,” Version 3.2, Section 6.9, “SIMD-group Matrix Functions.” Available at https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf.

  2. Apple, “Metal Best Practices Guide: SIMD-groups.” A SIMD group on Apple GPU is always 32 threads. This is analogous to NVIDIA’s warp size. See https://developer.apple.com/documentation/metal/compute_passes/creating_threads_and_threadgroups.

  3. The simdgroup_matrix API was introduced in Metal 2.4 (iOS 15, macOS 12) and requires Apple GPU family 7 or later (M1+). Earlier Apple GPUs (A-series before A14) have smaller SIMD groups (8 or 16 threads) and do not support matrix operations. See https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf.

  4. Apple does not disclose how simdgroup_multiply_accumulate is executed in hardware. The programmer-visible behavior is 32 threads cooperatively computing an 8x8x8 multiply-accumulate. The reverse-engineering work cited in note 7 indicates this runs on the regular shader ALUs rather than on a separate systolic array or matrix engine.

  5. This is the fundamental insight behind the roofline model applied to LLM inference. For M=1 GEMV, the arithmetic intensity is approximately 1 FLOP/byte (2 FLOPs per weight element, ~2 bytes per weight for FP16). The roofline crossover point on Apple Silicon is typically at M=16-64, depending on the chip. See Williams et al., “Roofline: An Insightful Visual Performance Model for Multicore Architectures,” CACM 2009, https://doi.org/10.1145/1498765.1498785.

  6. Apple, “Using Function Specialization to Build Pipeline Variants.” Function constants allow creating specialized versions of a shader function, enabling the compiler to optimize based on known constant values. See https://developer.apple.com/documentation/metal/using-function-specialization-to-build-pipeline-variants.

  7. Dougall Johnson, “Apple GPU Architecture” (reverse-engineering documentation), 2022-2024. Johnson’s analysis shows that Apple GPU ALUs execute SIMD matrix instructions directly, without dedicated tensor hardware. See https://dougallj.github.io/applegpu/.

  8. Apple, “What’s New in Metal,” WWDC24. M4 introduces native bfloat16 support in Metal, including simdgroup_bfloat8x8 matrix operations. See https://developer.apple.com/videos/play/wwdc2024/10220/.