
GEMM: Batched Matrix Multiplication

During prefill, every linear projection is a full matrix-matrix multiply: C[M,N] = A[M,K] @ B^T, where the weight matrix B is stored as [N,K]. Unlike GEMV (which is memory-bound), GEMM can be compute-bound when M is large enough, because each weight element is reused across M activation rows. This chapter covers Akunu’s GEMM kernels, which use Apple Silicon’s SIMD group matrix multiply-accumulate (MMA) instructions to achieve near-peak throughput.

The kernels live in backend/metal/kernels/metal/kernel/matmul/simd_gemm_*.metal. We will focus on two representative variants: the FP16 GEMM and the Q4_0 GEMM, which together illustrate the key design patterns.

Tile Geometry: The 32x64 Layout

Both GEMM kernels use the same tile geometry, inherited from llama.cpp’s kernel_mul_mm:

| Parameter | Symbol | Value | Meaning |
|---|---|---|---|
| Tile M (activation rows) | TM / NR1 | 32 | Rows of A processed per threadgroup |
| Tile N (weight rows) | TN / NR0 | 64 | Rows of B (columns of output) per threadgroup |
| Tile K (accumulation) | TK / NK | 32 | K-dimension per accumulation step |
| Threads per TG | | 128 | 4 SIMD groups x 32 lanes |
| Dispatch grid | | (ceil(N/64), ceil(M/32)) | One TG per output tile |

Why 32x64 and not 64x64 or 32x32? The answer lies in the SIMD group MMA instruction, which operates on 8x8 half-precision matrices. The 32x64 tile decomposes into:

Output tile [32, 64] as 8x8 sub-tiles:
┌────┬────┬────┬────┬────┬────┬────┬────┐
│8×8 │8×8 │8×8 │8×8 │8×8 │8×8 │8×8 │8×8 │  row 0-7
├────┼────┼────┼────┼────┼────┼────┼────┤
│8×8 │8×8 │8×8 │8×8 │8×8 │8×8 │8×8 │8×8 │  row 8-15
├────┼────┼────┼────┼────┼────┼────┼────┤
│8×8 │8×8 │8×8 │8×8 │8×8 │8×8 │8×8 │8×8 │  row 16-23
├────┼────┼────┼────┼────┼────┼────┼────┤
│8×8 │8×8 │8×8 │8×8 │8×8 │8×8 │8×8 │8×8 │  row 24-31
└────┴────┴────┴────┴────┴────┴────┴────┘
 col  col  col  col  col  col  col  col
 0-7  8-15 16-23 24-31 32-39 40-47 48-55 56-63

4 SIMD groups split this into 4 quadrants:
  SG0: rows 0-15,  cols 0-31   (2×4 = 8 sub-tiles)
  SG1: rows 0-15,  cols 32-63  (2×4 = 8 sub-tiles)
  SG2: rows 16-31, cols 0-31   (2×4 = 8 sub-tiles)
  SG3: rows 16-31, cols 32-63  (2×4 = 8 sub-tiles)

Each SIMD group maintains mc[8] accumulators (8 simdgroup_half8x8 matrices), covering its 8 sub-tiles.

[Figure: GEMM tiled execution on the GPU. Panels: A [M=32, K]; output C [32, 64] split into 4 SIMD-group quadrants; B^T [N=64, K].]

FP16 GEMM (simd_gemm_f16)

Threadgroup Memory Layout

threadgroup half *sa = shmem;                    // Weight tile: 4096 bytes
threadgroup half *sb = shmem + 4096 / sizeof(half); // Activation tile: 2048 bytes
// Total: 6144 bytes

The weight tile is larger (4096 bytes for 64 rows x 32 K-cols) because TN=64 > TM=32. The activation tile is 2048 bytes (32 rows x 32 K-cols).

Cooperative Loading

Each of the 128 threads loads a portion of the weight and activation tiles into threadgroup memory:

Weight loading (sa):

const short lr0 = ((short)tiitg / NL0) < nr0 ? ((short)tiitg / NL0) : nr0 - 1;
const short il0 = (tiitg % NL0);

// F16: just read 16 halves per thread
half4x4 temp_a;
for (int i = 0; i < 16; i++) {
    temp_a[i/4][i%4] = x[i];
}

For FP16, the load is a simple copy from device memory to registers, then a scatter to threadgroup memory in the sub-block layout that the MMA instructions expect.

Activation loading (sb):

const short lr1 = ((short)tiitg / NL1) < nr1 ? ((short)tiitg / NL1) : nr1 - 1;
const short iy = 8 * (tiitg % NL1);

*(threadgroup half2x4 *)(sb + 64 * ib + 8 * ly) = *((device const half2x4 *)y);

The activation tile uses half2x4 (16-byte) vector stores for efficient threadgroup memory writes.

The Scatter Pattern

The threadgroup memory layout is not a simple row-major matrix. Instead, it uses an 8x8 sub-block interleaved layout that aligns with the MMA instruction’s expected input format:

for (short i = 0; i < 16; i++) {
    const short sx = 2 * il0 + i / 8;
    const short sy = lr0 / 8;
    const short lx = lr0 % 8;
    const short ly = i % 8;
    const short ib = 8 * sx + sy;
    *(sa + 64 * ib + 8 * ly + lx) = temp_a[i/4][i%4];
}

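This scatter writes 16 elements per thread into the positions that simdgroup_load expects. Each 8x8 sub-block occupies 64 contiguous halves and is stored transposed: the element at K-offset ly and weight row lx sits at offset 8 * ly + lx, so consecutive K values are 8 halves apart. Sub-blocks covering successive groups of 8 weight rows are 64 halves apart, and sub-blocks covering successive 8-wide K slices are 8 * 64 = 512 halves apart, which is why the MMA loop below advances lsma by 64 per sub-tile and by 8 * 64 per K slice.[1]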
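The index arithmetic is easy to get wrong, so a host-side model helps. The sketch below reimplements the scatter formula for the FP16 weight tile and checks that every (row, k) pair of the 64x32 tile lands in a distinct slot inside its own 64-half sub-block. It assumes NL0 = 2 (two threads per weight row, each owning 16 consecutive K values), which follows from 128 threads covering 64 rows; scatter_offset and scatter_is_bijective_and_blocked are hypothetical helper names.

```cpp
#include <array>
#include <cassert>

// Offset (in halves) inside sa for tile weight row `row` (0..63) and
// tile K column `k` (0..31), following the scatter formula. Assumes the
// thread with il0 = k / 16 holds K values 16*il0 .. 16*il0 + 15.
inline int scatter_offset(int row, int k) {
    const int il0 = k / 16;          // which 16-wide half of the K-step
    const int i   = k % 16;          // index within the thread's 16 values
    const int sx  = 2 * il0 + i / 8; // 8-wide K slice (0..3)
    const int sy  = row / 8;         // 8-row weight sub-block (0..7)
    const int lx  = row % 8;
    const int ly  = i % 8;
    const int ib  = 8 * sx + sy;
    return 64 * ib + 8 * ly + lx;
}

// True if the mapping is one-to-one over all 2048 slots and every element
// of sub-block (sy, sx) stays inside the 64 halves at 64 * (8 * sx + sy).
inline bool scatter_is_bijective_and_blocked() {
    std::array<int, 64 * 32> hits{};
    for (int row = 0; row < 64; ++row) {
        for (int k = 0; k < 32; ++k) {
            const int off = scatter_offset(row, k);
            if (off < 0 || off >= 64 * 32) return false;
            ++hits[off];
            const int block_base = 64 * (8 * (k / 8) + row / 8);
            if (off < block_base || off >= block_base + 64) return false;
        }
    }
    for (int h : hits)
        if (h != 1) return false;
    return true;
}
```

For example, scatter_offset(1, 0) is 1 and scatter_offset(0, 1) is 8, the transposed in-block layout, while scatter_offset(0, 8) is 512, the jump to the next K slice.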

The MMA Accumulation Loop

for (uint loop_k = 0; loop_k < K_dim; loop_k += NK) {
    // Load weight and activation tiles (shown above)
    threadgroup_barrier(mem_flags::mem_threadgroup);  // tiles fully written

    threadgroup const half *lsma = (sa + 4 * 64 * (sgitg % 2));
    threadgroup const half *lsmb = (sb + 2 * 64 * (sgitg / 2));

    for (short ik = 0; ik < NK / 8; ik++) {
        simdgroup_barrier(mem_flags::mem_none);

        simdgroup_half8x8 ma[4];
        for (short i = 0; i < 4; i++) {
            simdgroup_load(ma[i], lsma + 64 * i, 8, 0, false);
        }

        simdgroup_half8x8 mb[2];
        for (short i = 0; i < 2; i++) {
            simdgroup_load(mb[i], lsmb + 64 * i, 8, 0, false);
        }

        for (short i = 0; i < 8; i++) {
            simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]);
        }

        lsma += 8 * 64;
        lsmb += 4 * 64;
    }

    // Every SIMD group must finish reading sa/sb before the next
    // iteration's cooperative load overwrites them.
    threadgroup_barrier(mem_flags::mem_threadgroup);
}

Let’s break down what happens in each inner iteration ik (8 elements of K; there are NK/8 = 4 of them per 32-wide K-step):

  1. Load weight sub-tiles: 4 simdgroup_half8x8 matrices (ma[0..3]) are loaded from sa. These represent a 32x8 slice of the weight tile (4 sub-tiles of 8x8).

  2. Load activation sub-tiles: 2 simdgroup_half8x8 matrices (mb[0..1]) from sb. These represent a 16x8 slice of the activation tile (2 sub-tiles of 8x8).

  3. MMA: 8 multiply-accumulate operations, one per output sub-tile. Each computes mc[i] += mb[i/4] * ma[i%4], which is an 8x8 @ 8x8 -> 8x8 matrix multiply-accumulate.

The simdgroup_barrier(mem_flags::mem_none) is a lightweight barrier that synchronizes execution within the SIMD group without requiring memory ordering. This is cheaper than a full threadgroup_barrier.
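To see that the quadrant split and the mc[i] += mb[i/4] * ma[i%4] indexing really cover the whole 32x64 output, here is a CPU model of one threadgroup. Plain float loops stand in for simdgroup_load and simdgroup_multiply_accumulate, and A and B are read straight from row-major arrays, so this checks only the block decomposition, not the transposed threadgroup-memory layout. Mat8, mma8, tile_gemm_sg and naive_gemm are hypothetical names.

```cpp
#include <array>
#include <cassert>
#include <vector>

using Mat8 = std::array<std::array<float, 8>, 8>;

// d += a * b on 8x8 matrices: the work of one simdgroup_multiply_accumulate.
inline void mma8(Mat8& d, const Mat8& a, const Mat8& b) {
    for (int r = 0; r < 8; ++r)
        for (int c = 0; c < 8; ++c) {
            float acc = d[r][c];
            for (int k = 0; k < 8; ++k) acc += a[r][k] * b[k][c];
            d[r][c] = acc;
        }
}

// C[32,64] = A[32,K] * B[64,K]^T computed by 4 simulated SIMD groups, each
// owning a 16x32 quadrant as 2x4 accumulators. K must be a multiple of 8.
inline std::vector<float> tile_gemm_sg(const std::vector<float>& A,
                                       const std::vector<float>& B, int K) {
    std::vector<float> C(32 * 64, 0.0f);
    for (int sg = 0; sg < 4; ++sg) {
        const int row0 = 16 * (sg >> 1);  // activation rows of the quadrant
        const int col0 = 32 * (sg & 1);   // weight rows = output columns
        Mat8 mc[8] = {};
        for (int k0 = 0; k0 < K; k0 += 8) {
            Mat8 mb[2], ma[4];
            for (int j = 0; j < 2; ++j)        // 16x8 activation slice
                for (int r = 0; r < 8; ++r)
                    for (int kk = 0; kk < 8; ++kk)
                        mb[j][r][kk] = A[(row0 + 8 * j + r) * K + k0 + kk];
            for (int j = 0; j < 4; ++j)        // 8x32 slice of B^T
                for (int kk = 0; kk < 8; ++kk)
                    for (int n = 0; n < 8; ++n)
                        ma[j][kk][n] = B[(col0 + 8 * j + n) * K + k0 + kk];
            for (int i = 0; i < 8; ++i) mma8(mc[i], mb[i / 4], ma[i % 4]);
        }
        for (int i = 0; i < 8; ++i)            // scatter like simdgroup_store
            for (int r = 0; r < 8; ++r)
                for (int c = 0; c < 8; ++c)
                    C[(row0 + 8 * (i / 4) + r) * 64 + col0 + 8 * (i % 4) + c] =
                        mc[i][r][c];
    }
    return C;
}

// Reference: C[m][n] = sum_k A[m][k] * B[n][k].
inline std::vector<float> naive_gemm(const std::vector<float>& A,
                                     const std::vector<float>& B, int K) {
    std::vector<float> C(32 * 64, 0.0f);
    for (int m = 0; m < 32; ++m)
        for (int n = 0; n < 64; ++n) {
            float acc = 0.0f;
            for (int k = 0; k < K; ++k) acc += A[m * K + k] * B[n * K + k];
            C[m * 64 + n] = acc;
        }
    return C;
}
```

With small integer inputs both versions are exact in float, so their outputs can be compared with ==. Per 8-wide K slice each simulated SIMD group issues 8 MMAs, matching the inner loop of the kernel.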

Function Constant K Specialization

constant uint FC_GEMM_K [[function_constant(10)]];
constant bool FC_GEMM_K_SPECIALIZED = is_function_constant_defined(FC_GEMM_K);

const uint K_dim = FC_GEMM_K_SPECIALIZED ? FC_GEMM_K : K;

When K is known at pipeline creation time and is a multiple of 32, the host passes it as a function constant. The Metal compiler can then:

  • Generate a fixed-count loop (or fully unrolled for small K)
  • Eliminate the remainder check
  • Optimize memory access patterns for the known stride

Output Store with Alpha/Beta

The FP16 GEMM supports the full BLAS-style interface C = alpha * A @ B^T + beta * C:

const half alpha_h = half(params.alpha);
const half beta_h  = half(params.beta);

// ...
const bool has_alphabeta = (alpha_h != half(1) || beta_h != half(0));
if (has_alphabeta) {
    for (int i = 0; i < nr0; i++) {
        D[i] = alpha_h * S[i] + beta_h * D[i];
    }
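} else {
    // Fast path: direct copy with half4 stores
    device half4 *D4 = (device half4 *)D;
    threadgroup half4 *S4 = (threadgroup half4 *)S;
    int i = 0;
    for (; i < nr0 / 4; i++) *(D4 + i) = *(S4 + i);
    // Scalar tail for edge tiles where nr0 is not a multiple of 4
    for (i *= 4; i < nr0; i++) D[i] = S[i];
}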

When alpha=1, beta=0 (the common case), the output is stored directly with half4 vector stores, avoiding the multiply-add overhead.
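The epilogue logic can be sketched on the CPU for a single output row. This is an illustration, not Akunu's code: floats replace halves, four scalar copies stand in for one half4 store, and store_row is a hypothetical name. It shows both paths and the scalar tail needed when the number of valid columns is not a multiple of 4.

```cpp
#include <cassert>
#include <vector>

// D = alpha * S + beta * D for one row of n valid columns, with a copy-only
// fast path when alpha == 1 and beta == 0.
inline void store_row(float* D, const float* S, int n, float alpha, float beta) {
    const bool has_alphabeta = (alpha != 1.0f || beta != 0.0f);
    if (has_alphabeta) {
        for (int i = 0; i < n; ++i) D[i] = alpha * S[i] + beta * D[i];
        return;
    }
    int i = 0;
    for (; i + 4 <= n; i += 4) {  // body: 4 values at a time (one "half4")
        D[i + 0] = S[i + 0];
        D[i + 1] = S[i + 1];
        D[i + 2] = S[i + 2];
        D[i + 3] = S[i + 3];
    }
    for (; i < n; ++i) D[i] = S[i];  // tail: remaining n % 4 values
}
```

With n = 6, the fast path copies four values in the body and two in the tail; with alpha = 2 and beta = 1, each output becomes 2 * S[i] + D[i].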

Q4_0 GEMM (simd_gemm_q4_0)

The Q4_0 GEMM shares the FP16 GEMM’s tile geometry and MMA loop. The main difference is how the weight tile is loaded: instead of a simple copy, the quantized data must be dequantized into FP16. It also adds threadgroup swizzling and a direct full-tile store path, both described below.

Inline Dequantization

inline void dequantize_q4_0_half4x4(device const block_q4_0 *xb,
                                      short il, thread half4x4 &reg) {
    device const uint16_t *qs = ((device const uint16_t *)xb + 1);
    const float d1 = il ? (xb->d / 16.h) : xb->d;
    const float d2 = d1 / 256.f;
    const float md = -8.h * xb->d;
    const ushort mask0 = il ? 0x00F0 : 0x000F;
    const ushort mask1 = mask0 << 8;

    float4x4 reg_f;
    for (int i = 0; i < 8; i++) {
        reg_f[i/2][2*(i%2) + 0] = d1 * (qs[i] & mask0) + md;
        reg_f[i/2][2*(i%2) + 1] = d2 * (qs[i] & mask1) + md;
    }
    reg = (half4x4)reg_f;
}

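Each call dequantizes half of one 32-element block into a half4x4 (16 elements). The il parameter selects which half: il = 0 takes the low nibbles (elements 0-15) and il = 1 the high nibbles (elements 16-31), so one call with each value covers the whole block. The d1 and d2 factors (d, d/16, d/256, d/4096) undo the bit position of the masked nibble, and md = -8 * d applies the Q4_0 offset, so every weight becomes d * (q - 8).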

The key insight: dequantization happens into registers, not into a separate buffer. The dequantized values go directly into the scatter pattern, and from there into the MMA pipeline. No intermediate buffer is ever allocated for dequantized weights.
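A CPU reference makes the masking trick easy to verify. The sketch mirrors dequantize_q4_0_half4x4 and assumes the GGML block_q4_0 layout visible in the kernel (a scale followed by 16 bytes of nibbles, element j in the low nibble of byte j and element j + 16 in its high nibble). BlockQ4_0 and dequant_q4_0_half are hypothetical names, and the scale is stored as a float rather than FP16 so no half type is needed.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Simplified block: the real format stores d as fp16.
struct BlockQ4_0 {
    float d;
    std::uint8_t qs[16];
};

// il = 0 -> elements 0..15 (low nibbles), il = 1 -> elements 16..31.
// Like the kernel, read byte pairs as one little-endian uint16 and mask
// both bytes; the 1/16 and 1/256 factors undo each nibble's bit position.
inline std::array<float, 16> dequant_q4_0_half(const BlockQ4_0& b, int il) {
    const float d1 = il ? b.d / 16.0f : b.d;
    const float d2 = d1 / 256.0f;
    const float md = -8.0f * b.d;
    const std::uint16_t mask0 = il ? 0x00F0 : 0x000F;
    const std::uint16_t mask1 = static_cast<std::uint16_t>(mask0 << 8);
    std::array<float, 16> out{};
    for (int i = 0; i < 8; ++i) {
        const std::uint16_t q = static_cast<std::uint16_t>(
            b.qs[2 * i] | (b.qs[2 * i + 1] << 8));
        out[2 * i + 0] = d1 * (q & mask0) + md;  // byte 2i
        out[2 * i + 1] = d2 * (q & mask1) + md;  // byte 2i + 1
    }
    return out;
}
```

For every byte j, output j of the il = 0 call equals d * (low nibble - 8) and output j of the il = 1 call equals d * (high nibble - 8).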

Threadgroup Swizzling

The Q4_0 GEMM includes an optimization not present in the FP16 version: threadgroup swizzling for cache locality:

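constexpr uint SWIZZLE_LOG = 3;
constexpr uint SWIZZLE_WIDTH = 1u << SWIZZLE_LOG;  // 8

uint tg_x = tgpig.x;
uint tg_y = tgpig.y;
uint tiles_x = (N + NR0 - 1) / NR0;
uint group = tg_x >> SWIZZLE_LOG;
uint within = tg_x & (SWIZZLE_WIDTH - 1);
// Rotate only inside complete 8-tile strips. A trailing partial strip
// (tiles_x not a multiple of 8) keeps the identity mapping, so the
// remapped tg_x always stays below tiles_x.
if (group < tiles_x / SWIZZLE_WIDTH) {
    tg_x = (group << SWIZZLE_LOG) + ((within + tg_y) & (SWIZZLE_WIDTH - 1));
}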

Without swizzling, threadgroup (x, y) computes output tile (x, y), and threadgroups are launched in row-major order: (0,0), (1,0), (2,0), .... Threadgroups adjacent in X read the same activation rows but different weight tiles. Swizzling rotates the column index by the row index within strips of 8 tiles, so every row of threadgroups covers the same 8 weight tiles of a strip, each row starting at a different offset. The lines below list the tile each launched threadgroup computes:

Without swizzling (row 0):    TG(0,0) TG(1,0) TG(2,0) TG(3,0) TG(4,0) ...
Without swizzling (row 1):    TG(0,1) TG(1,1) TG(2,1) TG(3,1) TG(4,1) ...

With swizzling (row 0):       TG(0,0) TG(1,0) TG(2,0) TG(3,0) TG(4,0) ...
With swizzling (row 1):       TG(1,1) TG(2,1) TG(3,1) TG(4,1) TG(5,1) ...

The effect: within each strip, row y starts at tile (y mod 8), so vertically adjacent threadgroups work on neighboring weight tiles (64 weight rows apart), while every row still reuses the same 8 weight tiles of the strip. The goal is to keep weight data hot in the System Level Cache (SLC).[2]
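A swizzle is only safe if each row of launched threadgroups still computes every tile exactly once. The host-side model below, a sketch rather than Akunu's code, applies the rotation only inside complete 8-tile strips and checks the permutation property. Without that guard, tiles_x = 10 would send (tg_x = 8, tg_y = 3) to tile 11, which does not exist. swizzle_tile_x and swizzle_is_permutation are hypothetical names.

```cpp
#include <cassert>
#include <vector>

// Output-tile column computed by launched threadgroup (tg_x, tg_y).
inline unsigned swizzle_tile_x(unsigned tg_x, unsigned tg_y, unsigned tiles_x) {
    constexpr unsigned SWIZZLE_LOG = 3;
    constexpr unsigned SWIZZLE_WIDTH = 1u << SWIZZLE_LOG;  // 8
    const unsigned group = tg_x >> SWIZZLE_LOG;
    const unsigned within = tg_x & (SWIZZLE_WIDTH - 1);
    if (group < tiles_x / SWIZZLE_WIDTH)  // complete strip: rotate by row
        return (group << SWIZZLE_LOG) + ((within + tg_y) & (SWIZZLE_WIDTH - 1));
    return tg_x;                          // partial strip: identity
}

// True if, in every row, each tile column is produced exactly once.
inline bool swizzle_is_permutation(unsigned tiles_x, unsigned tiles_y) {
    for (unsigned y = 0; y < tiles_y; ++y) {
        std::vector<int> seen(tiles_x, 0);
        for (unsigned x = 0; x < tiles_x; ++x) {
            const unsigned t = swizzle_tile_x(x, y, tiles_x);
            if (t >= tiles_x || seen[t]++) return false;
        }
    }
    return true;
}
```

For instance, swizzle_tile_x(0, 1, 64) is 1 and swizzle_tile_x(7, 1, 64) wraps to 0, matching the row-1 pattern shown above. A grid of 500 tiles (such as a 32000-wide output at 64 columns per tile) keeps its last 4 tiles unswizzled.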

Full Tile Fast Path

When the output tile is fully covered (no edge padding needed), the Q4_0 GEMM uses a direct device memory store:

if (nr0 == NR0 && nr1 == NR1) {
    device half *D = C
        + (uint)(r1 + 16 * (sgitg >> 1)) * ldc
        + (uint)(r0 + 32 * (sgitg & 1));
    for (short i = 0; i < 8; i++) {
        simdgroup_store(mc[i], D + 8 * (i/4) * ldc + 8 * (i%4), ldc, 0, false);
    }
}

Each SIMD group writes its 8 sub-tiles (8x8 each) directly to the output matrix using simdgroup_store. The ldc stride tells the store instruction the row pitch of the output matrix.

For edge tiles (where the tile extends beyond the matrix boundary), a staging area in threadgroup memory is used, and only the valid elements are copied to device memory.
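The split between the full-tile and edge paths comes down to each tile's valid extent. This sketch uses the kernel's naming (r0/nr0 along N with 64 columns per tile, r1/nr1 along M with 32 rows per tile) and a row-major float staging buffer standing in for threadgroup memory. TileExtent, tile_extent and store_edge_tile are hypothetical names.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

struct TileExtent {
    int r0, r1;    // first output column / row of the tile
    int nr0, nr1;  // valid columns / rows (64 and 32 for a full tile)
    bool full;
};

inline TileExtent tile_extent(int tile_x, int tile_y, int M, int N) {
    const int r0 = tile_x * 64;
    const int r1 = tile_y * 32;
    const int nr0 = std::min(64, N - r0);
    const int nr1 = std::min(32, M - r1);
    return {r0, r1, nr0, nr1, nr0 == 64 && nr1 == 32};
}

// Edge path: copy only the valid nr1 x nr0 corner of a 32x64 staging tile
// into row-major C with row pitch ldc.
inline void store_edge_tile(const std::vector<float>& staging,
                            std::vector<float>& C, int ldc,
                            const TileExtent& t) {
    for (int i = 0; i < t.nr1; ++i)
        for (int j = 0; j < t.nr0; ++j)
            C[(t.r1 + i) * ldc + t.r0 + j] = staging[i * 64 + j];
}
```

For M = 40 and N = 100, tile (1, 1) has only 8 valid rows and 36 valid columns, so 288 of its 2048 staged values reach device memory.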

Tile Accumulation Visualization

The following shows how a single output tile accumulates over K-steps:

Tile Accumulation Loop (one threadgroup computes C[32,64]):

  K=0..31       K=32..63      K=64..95            K=4064..4095
 ┌──────────┐  ┌──────────┐  ┌──────────┐          ┌──────────┐   ┌─────────┐
 │Load A,B  │  │Load A,B  │  │Load A,B  │          │Load A,B  │   │ Store C │
 │Dequant B │→ │Dequant B │→ │Dequant B │→  ...  → │Dequant B │ → │ [32,64] │
 │C += A@B  │  │C += A@B  │  │C += A@B  │          │C += A@B  │   │to device│
 └──────────┘  └──────────┘  └──────────┘          └──────────┘   └─────────┘
    128 iterations total (K=4096, stride=32)

For a model with K=4096, there are 4096/32 = 128 accumulation steps per tile. Each step moves 2KB of activation data and 4KB of FP16 weight data into threadgroup memory (for Q4_0, only 1152 bytes of quantized data are read from device memory and dequantized into those 4KB), then performs 32 MMA operations per SIMD group (8 output sub-tiles x 4 slices of 8 along K), or 128 per threadgroup, each an 8x8 @ 8x8 multiply-accumulate.

Small GEMM Variants

For very small M (2-8 rows), Akunu provides “small” GEMM variants with TM=8 instead of TM=32:

simd_gemm_small_f16, simd_gemm_small_q4_0, simd_gemm_small_q4_k, ...

These use less threadgroup memory (fewer activation rows to store) and produce smaller output tiles, avoiding wasted computation on padding rows. The dispatch threshold is:

bool small = (M >= 2 && M <= 8);
int TM = small ? 8 : 32;

When M=4 (e.g., a 4-token speculative verification batch), the small variant processes all 4 rows in one 8-row tile (with 4 padding rows), while the standard variant would use a 32-row tile with 28 wasted rows.

Memory Requirements

| Resource | FP16 GEMM | Q4_0 GEMM |
|---|---|---|
| Threadgroup memory | 6144 bytes | 6144 bytes |
| Accumulators per SG | 8 x simdgroup_half8x8 | 8 x simdgroup_half8x8 |
| Weight bytes read per K-step | 64 * 32 * 2 = 4096 | 64 blocks * 18 = 1152 |
| Activation bytes read per K-step | 32 * 32 * 2 = 2048 | 32 * 32 * 2 = 2048 |

The Q4_0 GEMM reads only 1152 bytes of quantized weight data per K-step (each 32-weight block is 16 bytes of 4-bit values plus a 2-byte FP16 scale), compared to 4096 bytes for FP16: a ~3.6x reduction. This lets quantized GEMMs reach higher effective throughput than FP16 GEMMs whenever the kernel is limited by device-memory bandwidth, as it is at small M; at large M the GEMM is compute-bound (see the performance analysis below), and the saving matters less.
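The per-step byte counts follow directly from the tile shape and the block size. This compile-time sketch assumes the 18-byte Q4_0 block implied by the dequantization code (a 2-byte scale followed by 16 bytes of nibbles); the function names are hypothetical.

```cpp
#include <cassert>

constexpr int kTileN = 64, kTileM = 32, kTileK = 32;

// FP16 weights: 2 bytes per element.
constexpr int fp16_weight_bytes() { return kTileN * kTileK * 2; }
// Q4_0 weights: one 18-byte block per 32 weights of each tile row.
constexpr int q4_0_weight_bytes() { return kTileN * (kTileK / 32) * 18; }
// FP16 activations: 2 bytes per element.
constexpr int activation_bytes() { return kTileM * kTileK * 2; }

static_assert(fp16_weight_bytes() == 4096, "64 x 32 halves");
static_assert(q4_0_weight_bytes() == 1152, "64 blocks x 18 bytes");
static_assert(activation_bytes() == 2048, "32 x 32 halves");
```

The ratio 4096 / 1152 is about 3.56, which is where the ~3.6x figure comes from.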

The Full GEMM Kernel Zoo

Akunu provides GEMM kernels for every supported weight format:

| Format | Standard kernel | Small kernel | Notes |
|---|---|---|---|
| FP16 | simd_gemm_f16 | simd_gemm_small_f16 | No dequant |
| BF16 | simd_gemm_bf16 | simd_gemm_small_bf16 | BF16->FP16 convert |
| Q4_0 | simd_gemm_q4_0 | simd_gemm_small_q4_0 | 4-bit, group=32, swizzle |
| Q4_1 | simd_gemm_q4_1 | simd_gemm_small_q4_1 | 4-bit with min |
| Q5_0 | simd_gemm_q5_0 | simd_gemm_small_q5_0 | 5-bit |
| Q5_K | simd_gemm_q5_k | simd_gemm_small_q5_k | 5-bit, K-quant |
| Q8_0 | simd_gemm_q8_0 | simd_gemm_small_q8_0 | 8-bit |
| Q4_K | simd_gemm_q4_k | simd_gemm_small_q4_k | 4-bit, K-quant |
| Q6_K | simd_gemm_q6_k | simd_gemm_small_q6_k | 6-bit, K-quant |
| Q2_K | simd_gemm_q2_k | simd_gemm_small_q2_k | 2-bit, K-quant |
| Q3_K | simd_gemm_q3_k | simd_gemm_small_q3_k | 3-bit, K-quant |
| MLX Q3 | simd_gemm_mlx_q3 | simd_gemm_small_mlx_q3 | MLX 3-bit |
| MLX Q4 | simd_gemm_mlx_q4 | simd_gemm_small_mlx_q4 | MLX 4-bit |
| MLX Q6 | simd_gemm_mlx_q6 | simd_gemm_small_mlx_q6 | MLX 6-bit |
| MLX Q8 | simd_gemm_mlx_q8 | simd_gemm_small_mlx_q8 | MLX 8-bit |
| MLX Gen | simd_gemm_mlx_gen | simd_gemm_small_mlx_gen | MLX arbitrary bits |

That is 32 kernel variants (16 formats, each with a standard and a small kernel), all sharing the same tile geometry and MMA loop and differing only in the dequantization path.

Performance Characteristics

GEMM performance on Apple Silicon depends primarily on the tile utilization and memory bandwidth:

| Factor | Impact | How Akunu handles it |
|---|---|---|
| M too small | Wasted rows in activation tile | Small GEMM variant (TM=8) |
| N not a multiple of 64 | Edge tile with partial store | Staging through TG memory |
| K unknown at compile time | Loop bounds and remainder checks at run time | FC_GEMM_K specialization (when K is a multiple of 32) |
| Cache thrashing | Weight tile eviction | Threadgroup swizzling |
| Register pressure | Spill to local memory | 8 accumulators need only 16 halves per thread |

The theoretical peak for an Apple M4 Pro (20 GPU cores) at FP16 MMA is approximately 14 TFLOPS. A well-optimized 4096x4096 GEMM achieves roughly 80-90% of peak, limited by threadgroup memory bandwidth and barrier synchronization overhead.

Pipeline State Object Caching

Each GEMM variant requires a compiled Pipeline State Object (PSO) before it can be dispatched. Akunu caches these PSOs aggressively:

std::string cache_key = std::string(kernel) + "_k" + std::to_string(K);
pso = device.get_pipeline(kernel, cache_key, fc_indices, fc_values, 1);

The cache key includes the kernel name and any function constant values, ensuring that different K-specializations produce separate PSOs. The first call to get_pipeline compiles the MSL kernel into GPU machine code (which can take 10-50ms), but subsequent calls return the cached PSO instantly.

For a typical model, there are approximately 10-15 unique GEMM PSOs (one per unique K dimension per weight format). These are compiled during model loading and never recompiled during inference.
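The caching behavior described above can be modeled in a few lines. This is a sketch, not Akunu's get_pipeline: Pipeline, PipelineCache and the compile callback are hypothetical stand-ins, and falling back to a single unspecialized pipeline (keyed "_k0") when K is not a multiple of 32 is an assumption based on the FC_GEMM_K condition.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <string>
#include <unordered_map>
#include <utility>

// Hypothetical stand-in for a compiled pipeline state object.
struct Pipeline {
    std::string kernel;
    std::uint32_t k;  // value baked in as FC_GEMM_K, or 0 if unspecialized
};

// PSO cache keyed as kernel name + "_k" + K. `compile` stands in for the
// expensive MSL-to-GPU compilation and runs only on a cache miss.
class PipelineCache {
public:
    using CompileFn = std::function<Pipeline(const std::string&, std::uint32_t)>;

    explicit PipelineCache(CompileFn compile) : compile_(std::move(compile)) {}

    const Pipeline& get(const std::string& kernel, std::uint32_t K) {
        const std::uint32_t fc_k = (K % 32 == 0) ? K : 0;
        const std::string key = kernel + "_k" + std::to_string(fc_k);
        auto it = cache_.find(key);
        if (it == cache_.end())
            it = cache_.emplace(key, compile_(kernel, fc_k)).first;
        return it->second;
    }

    std::size_t size() const { return cache_.size(); }

private:
    CompileFn compile_;
    std::unordered_map<std::string, Pipeline> cache_;
};
```

Repeated requests for the same kernel and K reuse one entry, so a model with a handful of distinct (format, K) pairs compiles only that many pipelines, typically during loading.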

GEMM vs GEMV: The Crossover Point

An important question: when should the engine use GEMM instead of GEMV? The answer depends on M (the number of activation rows):

| M | Optimal kernel | Why |
|---|---|---|
| 1 | GEMV | No tile overhead, direct reduction |
| 2-8 | Small GEMM (TM=8) | Some row reuse, minimal padding |
| 9-32 | Standard GEMM (TM=32) | Good tile utilization |
| 33+ | Standard GEMM (TM=32) | Multiple tiles in M dimension |

The crossover between GEMV and GEMM is at M=2. Even with just 2 activation rows, the GEMM kernel’s weight tile loading (shared between both rows) provides better memory efficiency than two separate GEMV dispatches. However, for M=1, the GEMM kernel wastes 31 out of 32 rows in the activation tile, so GEMV is always faster.

Akunu’s dispatch_gemm function picks the tile height from M:

bool small = (M >= 2 && M <= 8);
int TM = small ? 8 : 32;

The small GEMM variant (TM=8) wastes at most 6 rows (for M=2) instead of 30 rows (with TM=32), providing a good compromise for very small batch sizes.
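The table and the TM rule can be combined into one selection function that also reports how many rows are wasted on padding. Only the M >= 2 && M <= 8 test comes from dispatch_gemm; routing M = 1 to GEMV follows the table, and KernelChoice and choose_kernel are hypothetical names.

```cpp
#include <cassert>
#include <string>

struct KernelChoice {
    const char* kind;  // "gemv", "small_gemm" or "gemm"
    int tm;            // activation rows per tile (0 for GEMV)
    int tiles_m;       // threadgroups along M
    int padding_rows;  // rows computed but discarded
};

inline KernelChoice choose_kernel(int M) {
    if (M == 1) return {"gemv", 0, 0, 0};
    const bool small = (M >= 2 && M <= 8);
    const int TM = small ? 8 : 32;
    const int tiles = (M + TM - 1) / TM;
    return {small ? "small_gemm" : "gemm", TM, tiles, tiles * TM - M};
}
```

M = 2 pads 6 rows and M = 4 pads 4, as described above, while M = 9 falls through to the standard kernel and pads 23 of its 32 rows; M = 2048 fills 64 tiles with no padding.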

The MMA Instruction in Detail

Apple Silicon’s simdgroup_multiply_accumulate is the hardware primitive that makes efficient GEMM possible. Let’s understand exactly how it works.

Lane-to-Element Mapping

In an 8x8 simdgroup matrix, each of the 32 lanes of a SIMD group holds 2 of the 64 elements. The mapping follows Apple’s proprietary layout:[3]

For an 8x8 matrix stored in a simdgroup_half8x8:
  Lane 0: elements (0,0) and (0,1)
  Lane 1: elements (0,2) and (0,3)
  Lane 2: elements (1,0) and (1,1)
  Lane 3: elements (1,2) and (1,3)
  ...

The thread_elements() accessor returns a vec<T, 2> containing the calling thread’s two elements. This is used by the V2 attention kernel to perform per-element operations directly on MMA results without going through threadgroup memory.
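Apple does not publish the full lane layout, but a closed form consistent with the listing above is used in MLX's Metal GEMM code. The sketch below reproduces that formula (an assumption here, not something taken from Akunu) and checks that 32 lanes with 2 adjacent elements each cover the 8x8 matrix exactly once. lane_to_coord and lane_mapping_covers_matrix are hypothetical names.

```cpp
#include <array>
#include <cassert>
#include <utility>

// Lane -> (row, first column); the lane also owns (row, column + 1).
inline std::pair<int, int> lane_to_coord(int lane) {
    const int qid = lane / 4;
    const int row = (qid & 4) + ((lane / 2) % 4);
    const int col = (qid & 2) * 2 + (lane % 2) * 2;
    return {row, col};
}

// True if the 32 lanes x 2 elements hit all 64 positions exactly once.
inline bool lane_mapping_covers_matrix() {
    std::array<int, 64> hits{};
    for (int lane = 0; lane < 32; ++lane) {
        const auto [row, col] = lane_to_coord(lane);
        ++hits[row * 8 + col];
        ++hits[row * 8 + col + 1];
    }
    for (int h : hits)
        if (h != 1) return false;
    return true;
}
```

Lanes 0-3 map to (0,0), (0,2), (1,0) and (1,2), matching the first entries listed above.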

MMA Throughput

Each simdgroup_multiply_accumulate(C, A, B, C) computes:

C[8,8] += A[8,8] @ B[8,8]

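This performs 8 * 8 * 8 = 512 multiply-accumulate operations, or 1024 FLOPs per instruction. A naive estimate assumes each of the 4 SIMD groups in a threadgroup issues one MMA per cycle at a typical M4 Pro clock of 1.4 GHz:

4 SG * 1024 FLOP/instruction * 1 instruction/cycle * 1.4 GHz
= ~5.7 TFLOPS per threadgroup

That is over a third of the whole chip’s ~14 TFLOPS peak, so the assumption cannot hold: all SIMD groups resident on a GPU core share that core’s ALUs, and per-core throughput, not per-threadgroup issue rate, sets the limit. Working back from the chip-level figure, 14 TFLOPS / (20 cores * 1.4 GHz) ≈ 500 FLOPs per core per cycle, or roughly one 8x8x8 MMA every 2 cycles per core, however many threadgroups the core is running. In practice, memory bandwidth and barrier overhead keep achieved throughput below that peak.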

Register Accumulator Precision

The accumulator’s precision is set by the type of the accumulator matrix. With simdgroup_half8x8 accumulators, every partial sum is rounded to FP16, so for long K-dimensions (K > 4096, such as the 14336-wide down projection) repeated half-precision additions can lose precision.

Akunu mitigates this by using the simdgroup_float8x8 accumulator type for attention scores (where precision matters more) while keeping simdgroup_half8x8 for GEMM output (where the subsequent operations, norm + activation, tolerate half-precision).
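A toy experiment shows how an FP16 running sum loses precision. The sketch rounds the accumulator to an 11-bit significand (the FP16 precision) after every add; it ignores FP16's limited exponent range, which does not matter at these magnitudes. round_to_half, accumulate_half and accumulate_float are hypothetical names, and real MMA hardware may round differently.

```cpp
#include <cassert>
#include <cmath>

// Nearest value with an 11-bit significand (fp16 precision, no range limits).
inline float round_to_half(float x) {
    if (x == 0.0f) return x;
    int e = 0;
    const float m = std::frexp(x, &e);                  // x = m * 2^e
    const float q = std::nearbyint(std::ldexp(m, 11));  // keep 11 bits
    return std::ldexp(q, e - 11);
}

// Add v to a running total n times, rounding the total to fp16 each time.
inline float accumulate_half(float v, int n) {
    float s = 0.0f;
    for (int i = 0; i < n; ++i) s = round_to_half(s + v);
    return s;
}

inline float accumulate_float(float v, int n) {
    float s = 0.0f;
    for (int i = 0; i < n; ++i) s += v;
    return s;
}
```

Summing 0.1 a total of 4096 times gives about 409.6 in float, but the FP16 total stalls at exactly 256: above 256, FP16 values are 0.25 apart, so adding 0.1 rounds back to the same number.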


Quantized GEMM Performance Analysis

For a 7B model prefilling 2048 tokens with Q4_0 weights:

The Q projection GEMM: C[2048, 4096] = A[2048, 4096] @ B^T[4096, 4096]

| Metric | Value |
|---|---|
| Output elements | 2048 * 4096 = 8.4M |
| FLOPs | 2 * 2048 * 4096 * 4096 = 68.7 GFLOP |
| Weight data read | 4096 * 4096 * 0.56 bytes = 9.4 MB |
| Activation data read | 2048 * 4096 * 2 bytes = 16.8 MB |
| Output data written | 2048 * 4096 * 2 bytes = 16.8 MB |
| Total memory traffic | ~43 MB |
| Arithmetic intensity | 68.7 GFLOP / 43 MB ≈ 1598 FLOP/byte |

At 1598 FLOP/byte (counting each operand’s trip to or from device memory once), this is firmly in the compute-bound regime. The M4 Pro’s ~14 TFLOPS of FP16 throughput would complete this in ~4.9ms, while its 273 GB/s of memory bandwidth would move the data in ~0.16ms. The GEMM is compute-bound by a factor of roughly 30x.
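The same back-of-the-envelope roofline can be written as a small function so other shapes are easy to check. It assumes Q4_0 weights at 18 bytes per 32 weights, FP16 activations and output, and each operand crossing device memory exactly once; peak compute and bandwidth are parameters (the check below uses 14 TFLOPS and Apple's published M4 Pro bandwidth of 273 GB/s). Roofline and q4_0_gemm_roofline are hypothetical names.

```cpp
#include <cassert>
#include <cmath>

struct Roofline {
    double flops, bytes, intensity, compute_ms, memory_ms;
};

// C[M,N] = A[M,K] * B[N,K]^T with Q4_0 weights.
inline Roofline q4_0_gemm_roofline(double M, double N, double K,
                                   double peak_tflops, double bw_gbs) {
    Roofline r{};
    r.flops = 2.0 * M * N * K;
    const double weight_bytes = N * K * 18.0 / 32.0;
    const double act_bytes = M * K * 2.0;
    const double out_bytes = M * N * 2.0;
    r.bytes = weight_bytes + act_bytes + out_bytes;
    r.intensity = r.flops / r.bytes;
    r.compute_ms = r.flops / (peak_tflops * 1e12) * 1e3;
    r.memory_ms = r.bytes / (bw_gbs * 1e9) * 1e3;
    return r;
}
```

For the Q projection (M = 2048, N = K = 4096) this gives about 1598 FLOP/byte, ~4.9 ms of compute and ~0.16 ms of memory traffic; rerunning it with M = 1 shows how quickly the balance flips toward bandwidth.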

This is the fundamental reason prefill is so much faster per-token than decode: the same weight data is reused across 2048 activation rows, amortizing the memory transfer cost.

Handling Non-Standard Architectures

Akunu’s GEMM dispatch supports several architectural variations through the descriptor system:

BERT/Encoder models: Use the same GEMM kernels but with different weight names and optional bias addition (dispatched as a separate kernel after the GEMM).

Gemma models: Have post-attention and post-FFN norms, which add extra norm dispatches between the projections rather than new kinds of GEMM. The dispatch_gemm function is architecture-agnostic: it just computes C = alpha * A @ B^T + beta * C.

MLX quantized models: The GEMM kernels receive MLX-specific parameters (group_size, bits, weight_bytes) via a secondary parameter buffer, enabling the same tile geometry with different dequantization logic.

Tied embeddings: The logit projection in some models reuses the embedding table as the output weight. dispatch_gemm does not care about the semantic meaning of the weight – it just needs the buffer, dimensions, and dtype.

Threadgroup Memory Bandwidth

Threadgroup memory on Apple Silicon GPUs has significantly higher bandwidth than device memory – roughly 10-20x, depending on the chip generation. This is why the GEMM kernel’s performance depends heavily on the TG memory access pattern.

The weight tile scatter pattern places data in 8x8 sub-blocks with stride 64 between sub-blocks and stride 8 between K positions within a sub-block. This layout is not arbitrary: it matches the simdgroup_load access pattern, ensuring that each MMA operand load reads a contiguous 128-byte chunk (64 halves) from threadgroup memory.

For each K-step (32 elements of K):

| Access | Pattern | Bytes per K-step | Bandwidth used |
|---|---|---|---|
| Load weight tile from device (FP16) | 128 threads, 16 elements each | 4096 | Device |
| Scatter weight to TG | 128 threads, indexed writes | 4096 | TG |
| Load activation tile from device | 128 threads, 8 elements each | 2048 | Device |
| Store activation to TG | 128 threads, vector stores | 2048 | TG |
| MMA loads from TG | 4 SG x (4+2) 8x8 loads per 8-wide K slice, 4 slices | 4 * 4 * 6 * 128 = 12288 | TG |

The TG memory acts as a software-managed L1 cache, giving the programmer explicit control over data reuse that would otherwise depend on hardware caching behavior.

End-to-End Prefill GEMM Flow

For a complete understanding, let’s trace the GEMMs in a single transformer layer during prefill of a 7B model with seq_len=2048:

| GEMM | M | N | K | Weight shape | Time at 14 TFLOPS peak |
|---|---|---|---|---|---|
| Q projection | 2048 | 4096 | 4096 | [4096, 4096] | ~4.9ms |
| K projection | 2048 | 1024 | 4096 | [1024, 4096] | ~1.2ms |
| V projection | 2048 | 1024 | 4096 | [1024, 4096] | ~1.2ms |
| O projection | 2048 | 4096 | 4096 | [4096, 4096] | ~4.9ms |
| Gate projection | 2048 | 14336 | 4096 | [14336, 4096] | ~17.2ms |
| Up projection | 2048 | 14336 | 4096 | [14336, 4096] | ~17.2ms |
| Down projection | 2048 | 4096 | 14336 | [4096, 14336] | ~17.2ms |

Total per layer: ~64ms at peak, or ~2.0 seconds for 32 layers. Running at the 80-90% of peak quoted earlier and adding attention, norms, and activations brings this to roughly 2.5 seconds for 2048 tokens. That is about 820 tokens/sec prefill throughput, which matches real-world measurements on M4 Pro hardware.
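The per-layer numbers are plain FLOP counts divided by a peak rate, which the sketch below reproduces. It assumes 100% efficiency at an assumed 14 TFLOPS peak, so its times are lower bounds on real kernel time; GemmShape, kLayerGemms and the helper functions are hypothetical names.

```cpp
#include <array>
#include <cassert>
#include <cmath>

struct GemmShape {
    const char* name;
    double M, N, K;
};

// The seven prefill GEMMs of one layer (seq_len = 2048, dim = 4096,
// KV dim = 1024, FFN dim = 14336).
inline constexpr std::array<GemmShape, 7> kLayerGemms = {{
    {"q", 2048, 4096, 4096},     {"k", 2048, 1024, 4096},
    {"v", 2048, 1024, 4096},     {"o", 2048, 4096, 4096},
    {"gate", 2048, 14336, 4096}, {"up", 2048, 14336, 4096},
    {"down", 2048, 4096, 14336},
}};

// Milliseconds for one GEMM at the given peak (2 FLOPs per multiply-add).
inline double gemm_ms_at_peak(const GemmShape& g, double peak_tflops) {
    return 2.0 * g.M * g.N * g.K / (peak_tflops * 1e12) * 1e3;
}

inline double layer_ms_at_peak(double peak_tflops) {
    double total = 0.0;
    for (const auto& g : kLayerGemms) total += gemm_ms_at_peak(g, peak_tflops);
    return total;
}
```

The three FFN GEMMs each take ~17.2ms at this peak, about 81% of the ~63.8ms layer total, which is the dominance discussed next.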

The FFN GEMMs (Gate, Up, Down) dominate because ffn_dim (14336) is ~3.5x larger than dim (4096). This is characteristic of modern LLMs that use SwiGLU activation, which requires a wider intermediate dimension.

The Barrier Budget

Threadgroup barriers are a significant cost in GEMM kernels. Each threadgroup_barrier(mem_flags::mem_threadgroup) call stalls every thread in the threadgroup until all have arrived and makes their earlier threadgroup-memory writes visible to one another. On Apple Silicon, a barrier takes approximately 0.2-0.5 microseconds.

For each K-step (32 elements of K), the GEMM kernel requires 2 barriers:

  1. After the cooperative tile load (ensure all threads have written their portion)
  2. After the MMA loop (ensure all SIMD groups have finished reading)

For K=4096, there are 128 K-steps, requiring 256 barriers. At 0.3us per barrier, this is ~77us of pure barrier overhead per tile, or roughly 10-15% of the total tile computation time. This is one of the reasons GEMM does not achieve 100% of peak MMA throughput.

The V2 attention kernel’s approach of keeping data in registers (avoiding the MMA-barrier-MMA cycle) provides a hint at how future GEMM kernels might reduce barrier overhead, though the GEMM’s much larger tile sizes make this approach more challenging.

Comparison with llama.cpp’s GEMM

Akunu’s GEMM kernels are derived from llama.cpp’s kernel_mul_mm family but include several improvements:

| Feature | llama.cpp | Akunu |
|---|---|---|
| Tile geometry | 32x64 | 32x64, plus 8x64 small variant |
| Threadgroup swizzling | No | Yes (Q4_0, other quantized) |
| Function constant K | No | Yes (FC_GEMM_K) |
| Alpha/beta support | No | Yes (FP16 GEMM) |
| MLX format support | No | Yes (MLX Q3, Q4, Q6, Q8, and arbitrary-bit) |
| Small M variant | No | Yes (TM=8 for M=2-8) |
| BF16 support | Partial | Full |

The most impactful difference is the function constant K specialization, which allows the Metal compiler to generate tighter loops with known bounds, often resulting in 5-10% speedup for common K dimensions.

The threadgroup swizzling provides another 3-8% improvement at large grid sizes by improving SLC hit rates for weight tiles. This is most noticeable during the FFN GEMMs where the grid is large (14336/64 = 224 tiles in the weight dimension).

Future Directions

Apple’s Metal 3.2 (introduced with the M4 family) provides enhanced simdgroup matrix operations, including support for larger tile sizes and new data types. Future GEMM kernels may benefit from:

  • Larger MMA tiles: 16x16 or 32x32 sub-tiles would reduce the number of MMA instructions per output element, improving throughput.
  • BF16 MMA: Native BF16 matrix operations would eliminate the conversion overhead for BF16 models.
  • Cooperative groups: Finer-grained synchronization primitives could reduce barrier overhead.
  • Persistent kernels: A single long-running kernel that processes all tiles sequentially could eliminate inter-tile overhead.

However, the current 8x8 MMA-based approach is well-proven and delivers near-peak performance. The 32x64 tile geometry will likely remain optimal for Apple Silicon’s current generation of GPU architectures.

Debugging GEMM Correctness

GEMM bugs are notoriously difficult to debug because the output is a large matrix where each element depends on the full K-dimension accumulation. Akunu uses several strategies:

  1. Alpha/Beta support: Setting alpha=1, beta=0 for production and alpha=0, beta=1 for “identity” (output = input C) enables isolating GEMM output from existing data.
  2. PSO validation: The dispatch_gemm function includes a fatal error if the PSO fails to compile, catching kernel bugs early.
  3. Dimension checks: The scratch buffer sizes are validated at model load time to ensure no GEMM dispatch will write out of bounds.
  4. Profiling labels: Each GEMM dispatch in the dispatch table carries a label like "L5.ffn.gate", making it easy to identify which GEMM produced incorrect output in a GPU debugger.

Summary

Akunu’s GEMM kernels are the workhorses of prefill. The key design decisions are:

  1. 32x64 tile geometry with 4 SIMD groups per threadgroup, maximizing MMA instruction utilization.
  2. Inline dequantization for quantized formats, converting directly from packed format to registers without intermediate buffers.
  3. Cooperative loading where all 128 threads participate in loading both weight and activation tiles.
  4. Threadgroup swizzling for cache-friendly access patterns across the grid.
  5. Small GEMM variants for low-M cases to avoid wasted padding computation.
  6. Function constant specialization for K-dimension to enable compiler optimizations.


  1. Apple. “Metal Shading Language Specification, Version 3.1.” Section 6.10, Simdgroup Matrix Functions. The simdgroup_load and simdgroup_store functions operate on 8x8 matrices distributed across the 32 threads of a SIMD group, with each thread holding two elements (the thread_elements() accessor). See https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf.

  2. The swizzling technique is adapted from NVIDIA’s CUTLASS library. See: Thakkar, V., et al. “CUTLASS: Fast Linear Algebra in CUDA C++.” NVIDIA Technical Blog, 2017. The Apple Silicon SLC acts similarly to NVIDIA’s L2 cache for this optimization.

  3. Apple. “Metal Shading Language Specification, Version 3.1.” Section 6.10.3, Simdgroup Matrix Thread Elements. The thread_elements() accessor returns a vec<T, 2> containing the thread’s owned elements, following Apple’s proprietary lane mapping for 8x8 matrices. See https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf.