GEMM: Batched Matrix Multiplication
During prefill, every linear projection is a full matrix-matrix multiply: C[M,N] = A[M,K] @ B^T[N,K]. Unlike GEMV (which is memory-bound), GEMM can be compute-bound when M is large enough, because each weight element is reused across M activation rows. This chapter covers Akunu’s GEMM kernels, which use Apple Silicon’s SIMD group matrix multiply-accumulate (MMA) instructions to achieve near-peak throughput.
The kernels live in backend/metal/kernels/metal/kernel/matmul/simd_gemm_*.metal. We will focus on two representative variants: the FP16 GEMM and the Q4_0 GEMM, which together illustrate the key design patterns.
Tile Geometry: The 32x64 Layout
Both GEMM kernels use the same tile geometry, inherited from llama.cpp’s kernel_mul_mm:
| Parameter | Symbol | Value | Meaning |
|---|---|---|---|
| Tile M (activation rows) | TM / NR1 | 32 | Rows of A processed per threadgroup |
| Tile N (weight rows) | TN / NR0 | 64 | Rows of B (columns of output) per threadgroup |
| Tile K (accumulation) | TK / NK | 32 | K-dimension per accumulation step |
| Threads per TG | – | 128 | 4 SIMD groups x 32 lanes |
| Dispatch grid | – | (ceil(N/64), ceil(M/32)) | One TG per output tile |
Why 32x64 and not 64x64 or 32x32? The answer lies in the SIMD group MMA instruction, which operates on 8x8 half-precision matrices. The 32x64 tile decomposes into:
Output tile [32, 64] as 8x8 sub-tiles:
┌────┬────┬────┬────┬────┬────┬────┬────┐
│8×8 │8×8 │8×8 │8×8 │8×8 │8×8 │8×8 │8×8 │ row 0-7
├────┼────┼────┼────┼────┼────┼────┼────┤
│8×8 │8×8 │8×8 │8×8 │8×8 │8×8 │8×8 │8×8 │ row 8-15
├────┼────┼────┼────┼────┼────┼────┼────┤
│8×8 │8×8 │8×8 │8×8 │8×8 │8×8 │8×8 │8×8 │ row 16-23
├────┼────┼────┼────┼────┼────┼────┼────┤
│8×8 │8×8 │8×8 │8×8 │8×8 │8×8 │8×8 │8×8 │ row 24-31
└────┴────┴────┴────┴────┴────┴────┴────┘
col col col col col col col col
0-7 8-15 16-23 24-31 32-39 40-47 48-55 56-63
4 SIMD groups split this into 4 quadrants:
SG0: rows 0-15, cols 0-31 (2×4 = 8 sub-tiles)
SG1: rows 0-15, cols 32-63 (2×4 = 8 sub-tiles)
SG2: rows 16-31, cols 0-31 (2×4 = 8 sub-tiles)
SG3: rows 16-31, cols 32-63 (2×4 = 8 sub-tiles)
Each SIMD group maintains mc[8] accumulators (8 simdgroup_half8x8 matrices), covering its 8 sub-tiles.
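The quadrant and sub-tile assignment can be checked with a short host-side sketch. The helper name is hypothetical; the offsets mirror the quadrant split above and the kernel's store indexing (rows offset by 16 * (sgitg >> 1), columns by 32 * (sgitg & 1), sub-tiles by 8 * (i/4) and 8 * (i%4)):

```cpp
#include <cassert>
#include <set>
#include <utility>

// Host-side model of the tile decomposition: each of the 4 SIMD groups
// (sgitg = 0..3) owns 8 accumulators mc[0..7]. sub_tile_origin returns the
// (row, col) of accumulator i's 8x8 sub-tile inside the 32x64 output tile.
std::pair<int, int> sub_tile_origin(int sgitg, int i) {
    int row = 16 * (sgitg >> 1) + 8 * (i / 4);  // SG0/SG1: rows 0-15, SG2/SG3: rows 16-31
    int col = 32 * (sgitg & 1) + 8 * (i % 4);   // SG0/SG2: cols 0-31, SG1/SG3: cols 32-63
    return {row, col};
}
```

Enumerating all (sgitg, i) pairs yields 32 distinct 8x8 sub-tile origins that exactly cover the 32x64 output tile, confirming that the 4 quadrants partition the tile with no overlap.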
Interactive: GEMM Tiled Execution on the GPU
This animation shows how one threadgroup computes a 32x64 output tile. Watch 4 SIMD groups cooperatively load weight and activation tiles into threadgroup memory, then execute 8x8 MMA operations. The K-dimension sweeps left to right, and the output tile fills as accumulators grow. Step through to see the cooperative loading, the MMA compute, and the final store.
FP16 GEMM (simd_gemm_f16)
Threadgroup Memory Layout
threadgroup half *sa = shmem; // Weight tile: 4096 bytes
threadgroup half *sb = shmem + 4096 / sizeof(half); // Activation tile: 2048 bytes
// Total: 6144 bytes
The weight tile is larger (4096 bytes for 64 rows x 32 K-cols) because TN=64 > TM=32. The activation tile is 2048 bytes (32 rows x 32 K-cols).
Cooperative Loading
Each of the 128 threads loads a portion of the weight and activation tiles into threadgroup memory:
Weight loading (sa):
const short lr0 = ((short)tiitg / NL0) < nr0 ? ((short)tiitg / NL0) : nr0 - 1;
const short il0 = (tiitg % NL0);
// F16: just read 16 halves per thread
half4x4 temp_a;
for (int i = 0; i < 16; i++) {
temp_a[i/4][i%4] = x[i];
}
For FP16, the load is a simple copy from device memory to registers, then a scatter to threadgroup memory in the sub-block layout that the MMA instructions expect.
Activation loading (sb):
const short lr1 = ((short)tiitg / NL1) < nr1 ? ((short)tiitg / NL1) : nr1 - 1;
const short iy = 8 * (tiitg % NL1);
*(threadgroup half2x4 *)(sb + 64 * ib + 8 * ly) = *((device const half2x4 *)y);
The activation tile uses half2x4 (16-byte) vector stores for efficient threadgroup memory writes.
The Scatter Pattern
The threadgroup memory layout is not a simple row-major matrix. Instead, it uses an 8x8 sub-block interleaved layout that aligns with the MMA instruction’s expected input format:
for (short i = 0; i < 16; i++) {
const short sx = 2 * il0 + i / 8;
const short sy = lr0 / 8;
const short lx = lr0 % 8;
const short ly = i % 8;
const short ib = 8 * sx + sy;
*(sa + 64 * ib + 8 * ly + lx) = temp_a[i/4][i%4];
}
This scatter writes 16 elements per thread into the correct positions for efficient simdgroup_load. The layout ensures that each 8x8 sub-block is contiguous in memory, with a stride of 8 between columns and 64 between rows of sub-blocks.1
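The index math can be verified on the host with a small sketch. NL0 = 2 is inferred from the geometry (64 rows x 32 K-cols = 2048 halves across 128 threads loading 16 each, i.e. two 16-element K-chunks per weight row); scatter_addr is a hypothetical helper reproducing the kernel's arithmetic:

```cpp
#include <cassert>
#include <set>

// Host-side emulation of the FP16 weight scatter into threadgroup memory.
// Each thread (tiitg = 0..127) writes 16 elements (i = 0..15); the returned
// address is the half-element offset into the 2048-element sa tile.
int scatter_addr(int tiitg, int i) {
    const int NL0 = 2;               // inferred: 16-element chunks per 32-wide row
    const int lr0 = tiitg / NL0;     // weight row 0..63
    const int il0 = tiitg % NL0;     // which 16-wide K-chunk in the row
    const int sx  = 2 * il0 + i / 8; // K sub-block index 0..3
    const int sy  = lr0 / 8;         // row sub-block index 0..7
    const int lx  = lr0 % 8;         // column within the 8x8 sub-block
    const int ly  = i % 8;           // row within the 8x8 sub-block
    const int ib  = 8 * sx + sy;     // linear sub-block id 0..31
    return 64 * ib + 8 * ly + lx;    // each sub-block is 64 contiguous halves
}
```

The 128 x 16 writes form a bijection onto offsets 0..2047: every element of the 64x32 tile is written exactly once, and each 8x8 sub-block occupies one contiguous 64-half (128-byte) span, which is what makes the subsequent simdgroup_load calls contiguous.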
The MMA Accumulation Loop
for (uint loop_k = 0; loop_k < K_dim; loop_k += NK) {
// Load weight and activation tiles (shown above)
threadgroup_barrier(mem_flags::mem_threadgroup);
threadgroup const half *lsma = (sa + 4 * 64 * (sgitg % 2));
threadgroup const half *lsmb = (sb + 2 * 64 * (sgitg / 2));
for (short ik = 0; ik < NK / 8; ik++) {
simdgroup_barrier(mem_flags::mem_none);
simdgroup_half8x8 ma[4];
for (short i = 0; i < 4; i++) {
simdgroup_load(ma[i], lsma + 64 * i, 8, 0, false);
}
simdgroup_half8x8 mb[2];
for (short i = 0; i < 2; i++) {
simdgroup_load(mb[i], lsmb + 64 * i, 8, 0, false);
}
for (short i = 0; i < 8; i++) {
simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]);
}
lsma += 8 * 64;
lsmb += 4 * 64;
}
}
Let’s break down what happens per K-step (8 elements of K):
- Load weight sub-tiles: 4 simdgroup_half8x8 matrices (ma[0..3]) are loaded from sa. These represent a 32x8 slice of the weight tile (4 sub-tiles of 8x8).
- Load activation sub-tiles: 2 simdgroup_half8x8 matrices (mb[0..1]) are loaded from sb. These represent a 16x8 slice of the activation tile (2 sub-tiles of 8x8).
- MMA: 8 multiply-accumulate operations, one per output sub-tile. Each computes mc[i] += mb[i/4] * ma[i%4], an 8x8 @ 8x8 -> 8x8 matrix multiply-accumulate.
The simdgroup_barrier(mem_flags::mem_none) is a lightweight barrier that synchronizes execution within the SIMD group without requiring memory ordering. This is cheaper than a full threadgroup_barrier.
Function Constant K Specialization
constant uint FC_GEMM_K [[function_constant(10)]];
constant bool FC_GEMM_K_SPECIALIZED = is_function_constant_defined(FC_GEMM_K);
const uint K_dim = FC_GEMM_K_SPECIALIZED ? FC_GEMM_K : K;
When K is known at pipeline creation time and is a multiple of 32, the host passes it as a function constant. The Metal compiler can then:
- Generate a fixed-count loop (or fully unrolled for small K)
- Eliminate the remainder check
- Optimize memory access patterns for the known stride
Output Store with Alpha/Beta
The FP16 GEMM supports the full BLAS-style interface C = alpha * A @ B^T + beta * C:
const half alpha_h = half(params.alpha);
const half beta_h = half(params.beta);
// ...
const bool has_alphabeta = (alpha_h != half(1) || beta_h != half(0));
if (has_alphabeta) {
for (int i = 0; i < nr0; i++) {
D[i] = alpha_h * S[i] + beta_h * D[i];
}
} else {
// Fast path: direct copy with half4 stores
device half4 *D4 = (device half4 *)D;
threadgroup half4 *S4 = (threadgroup half4 *)S;
for (int i = 0; i < nr0 / 4; i++) *(D4 + i) = *(S4 + i);
}
When alpha=1, beta=0 (the common case), the output is stored directly with half4 vector stores, avoiding the multiply-add overhead.
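A scalar model of this epilogue (plain C++ on floats, with an illustrative function name, standing in for the per-row half4 device code) makes the two paths explicit:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Scalar model of the output epilogue: the general BLAS-style path computes
// C = alpha * S + beta * C; the common alpha=1, beta=0 case degenerates to a
// straight copy (which the kernel performs with half4 vector stores).
void store_epilogue(std::vector<float> &C, const std::vector<float> &S,
                    float alpha, float beta) {
    const bool has_alphabeta = (alpha != 1.0f || beta != 0.0f);
    for (std::size_t i = 0; i < C.size(); i++) {
        C[i] = has_alphabeta ? alpha * S[i] + beta * C[i] : S[i];
    }
}
```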
Q4_0 GEMM (simd_gemm_q4_0)
The Q4_0 GEMM is structurally identical to the FP16 GEMM – same tile geometry, same MMA loop, same output store. The only difference is how the weight tile is loaded: instead of a simple copy, the quantized data must be dequantized into FP16.
Inline Dequantization
inline void dequantize_q4_0_half4x4(device const block_q4_0 *xb,
short il, thread half4x4 &reg) {
device const uint16_t *qs = ((device const uint16_t *)xb + 1);
const float d1 = il ? (xb->d / 16.h) : xb->d;
const float d2 = d1 / 256.f;
const float md = -8.h * xb->d;
const ushort mask0 = il ? 0x00F0 : 0x000F;
const ushort mask1 = mask0 << 8;
float4x4 reg_f;
for (int i = 0; i < 8; i++) {
reg_f[i/2][2*(i%2) + 0] = d1 * (qs[i] & mask0) + md;
reg_f[i/2][2*(i%2) + 1] = d2 * (qs[i] & mask1) + md;
}
reg = (half4x4)reg_f;
}
This dequantizes one block (32 elements) into a half4x4 (16 elements). The il parameter selects which half of the block to dequantize (low nibbles or high nibbles). The two calls to this function per thread cover all 32 elements.
The key insight: dequantization happens into registers, not into a separate buffer. The dequantized values go directly into the scatter pattern, and from there into the MMA pipeline. No intermediate buffer is ever allocated for dequantized weights.
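The scale-folding trick can be checked against a naive nibble-by-nibble dequantizer with a host-side sketch. This is not the kernel itself: float stands in for Metal's half, the struct is a simplified model of block_q4_0 (a scale followed by 16 packed bytes), and the memcpy assumes the same little-endian byte order as the kernel's uint16_t pointer cast:

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>

// Simplified host-side model of a Q4_0 block: element k < 16 lives in the low
// nibble of qs[k], element k >= 16 in the high nibble of qs[k-16].
struct BlockQ4_0 {
    float d;         // scale (stored as half on device)
    uint8_t qs[16];  // 32 packed 4-bit quants
};

// Mirrors dequantize_q4_0_half4x4: masks a uint16 (two packed bytes) and folds
// the nibble/byte shifts into the scales d1 and d2 instead of shifting data.
void dequant_masked(const BlockQ4_0 &b, int il, float out[16]) {
    uint16_t q16[8];
    std::memcpy(q16, b.qs, sizeof(q16));        // little-endian, like the device cast
    const float d1 = il ? b.d / 16.0f : b.d;    // il=1: scale absorbs the >>4
    const float d2 = d1 / 256.0f;               // high byte of the uint16: absorbs >>8
    const float md = -8.0f * b.d;               // zero-point offset
    const uint16_t mask0 = il ? 0x00F0 : 0x000F;
    const uint16_t mask1 = mask0 << 8;
    for (int i = 0; i < 8; i++) {
        out[2 * i + 0] = d1 * (q16[i] & mask0) + md;
        out[2 * i + 1] = d2 * (q16[i] & mask1) + md;
    }
}

// Naive reference: extract each nibble, apply value = d * (nibble - 8).
void dequant_naive(const BlockQ4_0 &b, float out[32]) {
    for (int k = 0; k < 32; k++) {
        uint8_t byte = b.qs[k % 16];
        int nib = (k < 16) ? (byte & 0x0F) : (byte >> 4);
        out[k] = b.d * (nib - 8);
    }
}
```

The il=0 call reproduces elements 0..15 (low nibbles) and il=1 reproduces elements 16..31 (high nibbles), matching the naive reference exactly.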
Threadgroup Swizzling
The Q4_0 GEMM includes an optimization not present in the FP16 version: threadgroup swizzling for cache locality:
constexpr uint SWIZZLE_LOG = 3;
constexpr uint SWIZZLE_WIDTH = 1u << SWIZZLE_LOG; // 8
uint tg_x = tgpig.x;
uint tg_y = tgpig.y;
uint tiles_x = (N + NR0 - 1) / NR0;
if (tiles_x >= SWIZZLE_WIDTH) {
uint group = tg_x >> SWIZZLE_LOG;
uint within = tg_x & (SWIZZLE_WIDTH - 1);
tg_x = (group << SWIZZLE_LOG) + ((within + tg_y) & (SWIZZLE_WIDTH - 1));
}
Without swizzling, threadgroups are dispatched in row-major order: (0,0), (1,0), (2,0), .... Adjacent threadgroups in the X direction access different weight columns but the same activation rows. Swizzling rotates the column index by the row index within strips of 8 tiles, so that adjacent threadgroups access overlapping weight columns:
Without swizzling (row 0): TG(0,0) TG(1,0) TG(2,0) TG(3,0) TG(4,0) ...
Without swizzling (row 1): TG(0,1) TG(1,1) TG(2,1) TG(3,1) TG(4,1) ...
With swizzling (row 0): TG(0,0) TG(1,0) TG(2,0) TG(3,0) TG(4,0) ...
With swizzling (row 1): TG(1,1) TG(2,1) TG(3,1) TG(4,1) TG(5,1) ...
The effect: TG(1,0) and TG(1,1) (which are likely to execute on neighboring GPU cores) now access weight tiles that are only 64 columns apart instead of the full N-stride. This keeps weight data hot in the System Level Cache (SLC).2
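A host-side sketch (with a hypothetical helper name mirroring the kernel's index math) confirms the swizzle is a pure permutation within each 8-wide strip, so every output tile is still computed exactly once:

```cpp
#include <cassert>
#include <set>

// Host-side model of the threadgroup swizzle: within each strip of 8 tiles in
// X, the column index is rotated by the row index; small grids are untouched.
unsigned swizzle_x(unsigned tg_x, unsigned tg_y, unsigned tiles_x) {
    const unsigned SWIZZLE_LOG = 3;
    const unsigned SWIZZLE_WIDTH = 1u << SWIZZLE_LOG;  // 8
    if (tiles_x < SWIZZLE_WIDTH) return tg_x;          // grid too small to swizzle
    unsigned group  = tg_x >> SWIZZLE_LOG;
    unsigned within = tg_x & (SWIZZLE_WIDTH - 1);
    return (group << SWIZZLE_LOG) + ((within + tg_y) & (SWIZZLE_WIDTH - 1));
}
```

For the FFN grid (14336/64 = 224 tiles in X), each row's mapping is a bijection of tile indices, and row 1 starts at tile x=1 as in the diagram above.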
Full Tile Fast Path
When the output tile is fully covered (no edge padding needed), the Q4_0 GEMM uses a direct device memory store:
if (nr0 == NR0 && nr1 == NR1) {
device half *D = C
+ (uint)(r1 + 16 * (sgitg >> 1)) * ldc
+ (uint)(r0 + 32 * (sgitg & 1));
for (short i = 0; i < 8; i++) {
simdgroup_store(mc[i], D + 8 * (i/4) * ldc + 8 * (i%4), ldc, 0, false);
}
}
Each SIMD group writes its 8 sub-tiles (8x8 each) directly to the output matrix using simdgroup_store. The ldc stride tells the store instruction the row pitch of the output matrix.
For edge tiles (where the tile extends beyond the matrix boundary), a staging area in threadgroup memory is used, and only the valid elements are copied to device memory.
Tile Accumulation Visualization
The following shows how a single output tile accumulates over K-steps:
Tile Accumulation Loop (one threadgroup computes C[32,64]):
K=0..31 K=32..63 K=64..95 K=4064..4095
┌─────────┐ ┌─────────┐ ┌─────────┐ ┌─────────┐ ┌─────────┐
│Load A,B │ │Load A,B │ │Load A,B │ │Load A,B │ │ Store C │
│Dequant B │→ │Dequant B │→ │Dequant B │→ ... → │Dequant B │ → │ [32,64] │
│C += A@B │ │C += A@B │ │C += A@B │ │C += A@B │ │to device│
└─────────┘ └─────────┘ └─────────┘ └─────────┘ └─────────┘
128 iterations total (K=4096, stride=32)
For a model with K=4096, there are 4096/32 = 128 accumulation steps per tile. Each step loads 2KB of activation data and 4KB of weight data (for Q4_0, the raw quantized data is ~1KB but dequantizes to 4KB in registers), performs 8 MMA operations (each 8x8 @ 8x8), and accumulates into the 8 output sub-tiles.
Small GEMM Variants
For very small M (2-8 rows), Akunu provides “small” GEMM variants with TM=8 instead of TM=32:
simd_gemm_small_f16, simd_gemm_small_q4_0, simd_gemm_small_q4_k, ...
These use less threadgroup memory (fewer activation rows to store) and produce smaller output tiles, avoiding wasted computation on padding rows. The dispatch threshold is:
bool small = (M >= 2 && M <= 8);
int TM = small ? 8 : 32;
When M=4 (e.g., a 4-token speculative verification batch), the small variant processes all 4 rows in one 8-row tile (with 4 padding rows), while the standard variant would use a 32-row tile with 28 wasted rows.
Memory Requirements
| Resource | FP16 GEMM | Q4_0 GEMM |
|---|---|---|
| Threadgroup memory | 6144 bytes | 6144 bytes |
| Registers per SG (accumulators) | 8 x simdgroup_half8x8 | 8 x simdgroup_half8x8 |
| Weight tile bandwidth | 64 * 32 * 2 = 4096 bytes/step | 64 * 32 * 0.5625 = 1152 bytes/step |
| Activation tile bandwidth | 32 * 32 * 2 = 2048 bytes/step | 32 * 32 * 2 = 2048 bytes/step |
The Q4_0 GEMM reads only ~1152 bytes of quantized weight data per K-step (Q4_0 stores 18 bytes per 32-element block, i.e. 0.5625 bytes per weight), compared to 4096 bytes for FP16 – a ~3.6x reduction. This is why quantized GEMMs achieve higher effective throughput than FP16 GEMMs on the same hardware – the memory subsystem is the bottleneck for both, and Q4_0 moves less data per FLOP.
The Full GEMM Kernel Zoo
Akunu provides GEMM kernels for every supported weight format:
| Format | Standard Kernel | Small Kernel | Notes |
|---|---|---|---|
| FP16 | simd_gemm_f16 | simd_gemm_small_f16 | No dequant |
| BF16 | simd_gemm_bf16 | simd_gemm_small_bf16 | BF16->FP16 convert |
| Q4_0 | simd_gemm_q4_0 | simd_gemm_small_q4_0 | 4-bit, group=32, swizzle |
| Q4_1 | simd_gemm_q4_1 | simd_gemm_small_q4_1 | 4-bit with min |
| Q5_0 | simd_gemm_q5_0 | simd_gemm_small_q5_0 | 5-bit |
| Q5_K | simd_gemm_q5_k | simd_gemm_small_q5_k | 5-bit, K-quant |
| Q8_0 | simd_gemm_q8_0 | simd_gemm_small_q8_0 | 8-bit |
| Q4_K | simd_gemm_q4_k | simd_gemm_small_q4_k | 4-bit, K-quant |
| Q6_K | simd_gemm_q6_k | simd_gemm_small_q6_k | 6-bit, K-quant |
| Q2_K | simd_gemm_q2_k | simd_gemm_small_q2_k | 2-bit, K-quant |
| Q3_K | simd_gemm_q3_k | simd_gemm_small_q3_k | 3-bit, K-quant |
| MLX Q3 | simd_gemm_mlx_q3 | simd_gemm_small_mlx_q3 | MLX 3-bit |
| MLX Q4 | simd_gemm_mlx_q4 | simd_gemm_small_mlx_q4 | MLX 4-bit |
| MLX Q6 | simd_gemm_mlx_q6 | simd_gemm_small_mlx_q6 | MLX 6-bit |
| MLX Q8 | simd_gemm_mlx_q8 | simd_gemm_small_mlx_q8 | MLX 8-bit |
| MLX Gen | simd_gemm_mlx_gen | simd_gemm_small_mlx_gen | MLX arbitrary bits |
That is 30+ kernel variants, all sharing the same tile geometry and MMA loop, differing only in the dequantization path.
Performance Characteristics
GEMM performance on Apple Silicon depends primarily on the tile utilization and memory bandwidth:
| Factor | Impact | How Akunu Handles It |
|---|---|---|
| M too small | Wasted rows in activation tile | Small GEMM variant (TM=8) |
| N not multiple of 64 | Edge tile with partial store | Staging through TG memory |
| K not multiple of 32 | Remainder loop needed | FC_GEMM_K specialization |
| Cache thrashing | Weight tile eviction | Threadgroup swizzling |
| Register pressure | Spill to local memory | 8 accumulators fits in 128 registers |
The theoretical peak for an Apple M4 Pro (20 GPU cores) at FP16 MMA is approximately 14 TFLOPS. A well-optimized 4096x4096 GEMM achieves roughly 80-90% of peak, limited by threadgroup memory bandwidth and barrier synchronization overhead.
Pipeline State Object Caching
Each GEMM variant requires a compiled Pipeline State Object (PSO) before it can be dispatched. Akunu caches these PSOs aggressively:
std::string cache_key = std::string(kernel) + "_k" + std::to_string(K);
pso = device.get_pipeline(kernel, cache_key, fc_indices, fc_values, 1);
The cache key includes the kernel name and any function constant values, ensuring that different K-specializations produce separate PSOs. The first call to get_pipeline compiles the MSL kernel into GPU machine code (which can take 10-50ms), but subsequent calls return the cached PSO instantly.
For a typical model, there are approximately 10-15 unique GEMM PSOs (one per unique K dimension per weight format). These are compiled during model loading and never recompiled during inference.
GEMM vs GEMV: The Crossover Point
An important question: when should the engine use GEMM instead of GEMV? The answer depends on M (the number of activation rows):
| M | Optimal Kernel | Why |
|---|---|---|
| 1 | GEMV | No tile overhead, direct reduction |
| 2-8 | Small GEMM (TM=8) | Some row reuse, minimal padding |
| 9-32 | Standard GEMM (TM=32) | Good tile utilization |
| 33+ | Standard GEMM (TM=32) | Multiple tiles in M dimension |
The crossover between GEMV and GEMM is at M=2. Even with just 2 activation rows, the GEMM kernel’s weight tile loading (shared between both rows) provides better memory efficiency than two separate GEMV dispatches. However, for M=1, the GEMM kernel wastes 31 out of 32 rows in the activation tile, so GEMV is always faster.
Akunu’s dispatch_gemm function makes this decision:
bool small = (M >= 2 && M <= 8);
int TM = small ? 8 : 32;
The small GEMM variant (TM=8) wastes at most 6 rows (for M=2) instead of 30 rows (with TM=32), providing a good compromise for very small batch sizes.
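The selection logic and the padding argument can be sketched as follows (the enum and helper names are illustrative; the M thresholds follow the text):

```cpp
#include <cassert>

// Kernel choice around the GEMV/GEMM crossover.
enum class Kernel { Gemv, SmallGemm, Gemm };

Kernel pick_kernel(int M) {
    if (M == 1) return Kernel::Gemv;                  // no weight-tile reuse to exploit
    if (M >= 2 && M <= 8) return Kernel::SmallGemm;   // TM = 8
    return Kernel::Gemm;                              // TM = 32
}

// Padding rows wasted by the chosen tile height for a given M.
int padding_rows(int M) {
    int TM = (pick_kernel(M) == Kernel::SmallGemm) ? 8 : 32;
    return (M + TM - 1) / TM * TM - M;
}
```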
The MMA Instruction in Detail
Apple Silicon’s simdgroup_multiply_accumulate is the hardware primitive that makes efficient GEMM possible. Let’s understand exactly how it works.
Lane-to-Element Mapping
In an 8x8 SIMD matrix, the 32 lanes of a SIMD group each hold 2 elements. The mapping follows Apple’s proprietary layout:3
For an 8x8 matrix stored in a simdgroup_half8x8:
Lane 0: elements (0,0) and (0,1)
Lane 1: elements (0,2) and (0,3)
Lane 2: elements (1,0) and (1,1)
Lane 3: elements (1,2) and (1,3)
...
The thread_elements() accessor returns a vec<T, 2> containing the calling thread’s two elements. This is used by the V2 attention kernel to perform per-element operations directly on MMA results without going through threadgroup memory.
MMA Throughput
Each simdgroup_multiply_accumulate(C, A, B, C) computes:
C[8,8] += A[8,8] @ B[8,8]
This performs 8 * 8 * 8 = 512 multiply-accumulate operations, or 1024 FLOPs per instruction at FP16 precision. Issuing one MMA per SIMD group per cycle at 1.4 GHz would yield 4 SG * 1024 FLOP * 1.4 GHz ≈ 5.7 TFLOPS for a single threadgroup's four SIMD groups – far more than one GPU core can sustain, since the MMA pipeline has multi-cycle latency and shares issue bandwidth with loads and barriers. Reaching the chip-level peak of ~14 TFLOPS across 20 cores requires each core to sustain roughly 700 GFLOPS, i.e. about one MMA instruction every other cycle; in practice, memory bandwidth and barrier overhead reduce achieved throughput to ~60-80% of that peak.
Register Accumulator Precision
The MMA instruction accumulates in the same precision as the operands. For simdgroup_half8x8, accumulation is in FP16. For long K-dimensions (K > 4096), this can lead to precision loss from repeated half-precision additions.
Akunu mitigates this by using the simdgroup_float8x8 accumulator type for attention scores (where precision matters more) while keeping simdgroup_half8x8 for GEMM output (where the subsequent operations, norm + activation, tolerate half-precision).
Quantized GEMM Performance Analysis
For a 7B model prefilling 2048 tokens with Q4_0 weights:
The Q projection GEMM: C[2048, 4096] = A[2048, 4096] @ B^T[4096, 4096]
| Metric | Value |
|---|---|
| Output elements | 2048 * 4096 = 8.4M |
| FLOPs | 2 * 2048 * 4096 * 4096 = 68.7 GFLOP |
| Weight data read | 4096 * 4096 * 0.56 bytes = 9.4 MB |
| Activation data read | 2048 * 4096 * 2 bytes = 16.8 MB |
| Output data written | 2048 * 4096 * 2 bytes = 16.8 MB |
| Total memory traffic | ~43 MB |
| Arithmetic intensity | 68.7 GFLOP / 43 MB ≈ 1598 FLOP/byte |
At 1598 FLOP/byte, this is firmly in the compute-bound regime. The M4 Pro’s 14 TFLOPS of FP16 throughput would complete this in ~4.9ms, and memory bandwidth (200 GB/s) would complete the data transfer in ~0.2ms. The GEMM is compute-bound by a factor of ~24x.
This is the fundamental reason prefill is so much faster per-token than decode: the same weight data is reused across 2048 activation rows, amortizing the memory transfer cost.
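The table's arithmetic can be reproduced with a back-of-envelope roofline sketch. The struct and function names are illustrative; the hardware figures (14 TFLOPS FP16 MMA, 200 GB/s) are the estimates used in the text, not measured values:

```cpp
#include <cassert>
#include <cmath>

// Roofline estimate for a Q4_0 GEMM: C[M,N] = A[M,K] @ B^T[N,K].
struct Roofline {
    double flops, bytes, ai, t_compute_ms, t_memory_ms;
};

Roofline q4_0_gemm_roofline(double M, double N, double K) {
    Roofline r;
    r.flops = 2.0 * M * N * K;              // multiply + add per output term
    double w = N * K * 0.5625;              // Q4_0 weights: 18 bytes / 32 elements
    double a = M * K * 2.0;                 // FP16 activations read
    double o = M * N * 2.0;                 // FP16 output written
    r.bytes = w + a + o;
    r.ai = r.flops / r.bytes;               // arithmetic intensity, FLOP/byte
    r.t_compute_ms = r.flops / 14e12 * 1e3; // vs ~14 TFLOPS FP16 MMA
    r.t_memory_ms  = r.bytes / 200e9 * 1e3; // vs ~200 GB/s
    return r;
}
```

For the Q projection above (M=2048, N=K=4096) this reproduces the ~1598 FLOP/byte intensity, the ~4.9ms compute time, and the ~24x compute-bound margin.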
Handling Non-Standard Architectures
Akunu’s GEMM dispatch supports several architectural variations through the descriptor system:
BERT/Encoder models: Use the same GEMM kernels but with different weight names and optional bias addition (dispatched as a separate kernel after the GEMM).
Gemma models: Have post-attention and post-FFN norms that require extra GEMM dispatch passes. The dispatch_gemm function is architecture-agnostic – it just computes C = alpha * A @ B^T + beta * C.
MLX quantized models: The GEMM kernels receive MLX-specific parameters (group_size, bits, weight_bytes) via a secondary parameter buffer, enabling the same tile geometry with different dequantization logic.
Tied embeddings: The logit projection in some models reuses the embedding table as the output weight. dispatch_gemm does not care about the semantic meaning of the weight – it just needs the buffer, dimensions, and dtype.
Threadgroup Memory Bandwidth
Threadgroup memory on Apple Silicon GPUs has significantly higher bandwidth than device memory – roughly 10-20x, depending on the chip generation. This is why the GEMM kernel’s performance depends heavily on the TG memory access pattern.
The weight tile scatter pattern places data in 8x8 sub-blocks with stride 64 between sub-block rows and stride 8 between columns within a sub-block. This layout is not arbitrary – it matches the simdgroup_load access pattern, ensuring that each MMA instruction’s operand load reads a contiguous 64-byte chunk from threadgroup memory.
For each K-step (32 elements of K):
| Access | Pattern | Bytes | Bandwidth Required |
|---|---|---|---|
| Load weight tile from device | 128 threads, 16 elements each | 4096 bytes | Device BW |
| Scatter weight to TG | 128 threads, indexed writes | 4096 bytes | TG BW |
| Load activation tile from device | 128 threads, 8 elements each | 2048 bytes | Device BW |
| Store activation to TG | 128 threads, vector stores | 2048 bytes | TG BW |
| MMA loads from TG | 4 SG x (4+2) loads per K/8 step | ~6144 bytes per step | TG BW |
The TG memory acts as a software-managed L1 cache, giving the programmer explicit control over data reuse that would otherwise depend on hardware caching behavior.
End-to-End Prefill GEMM Flow
For a complete understanding, let’s trace the GEMMs in a single transformer layer during prefill of a 7B model with seq_len=2048:
| GEMM | M | N | K | Weight Shape | Time (est.) |
|---|---|---|---|---|---|
| Q projection | 2048 | 4096 | 4096 | [4096, 4096] | ~5ms |
| K projection | 2048 | 1024 | 4096 | [1024, 4096] | ~1.5ms |
| V projection | 2048 | 1024 | 4096 | [1024, 4096] | ~1.5ms |
| O projection | 2048 | 4096 | 4096 | [4096, 4096] | ~5ms |
| Gate projection | 2048 | 14336 | 4096 | [14336, 4096] | ~16ms |
| Up projection | 2048 | 14336 | 4096 | [14336, 4096] | ~16ms |
| Down projection | 2048 | 4096 | 14336 | [4096, 14336] | ~16ms |
Total per layer: ~61ms. For 32 layers: ~1.95 seconds. Plus attention, norms, and activations: roughly 2.5 seconds total for 2048 tokens. That is about 820 tokens/sec prefill throughput, which matches real-world measurements on M4 Pro hardware.
The FFN GEMMs (Gate, Up, Down) dominate because ffn_dim (14336) is ~3.5x larger than dim (4096). This is characteristic of modern LLMs that use SwiGLU activation, which requires a wider intermediate dimension.
The Barrier Budget
Threadgroup barriers are a significant cost in GEMM kernels. Each threadgroup_barrier(mem_flags::mem_threadgroup) call synchronizes all threads in the threadgroup and flushes the threadgroup memory. On Apple Silicon, a barrier takes approximately 0.2-0.5 microseconds.
For each K-step (32 elements of K), the GEMM kernel requires 2 barriers:
- After the cooperative tile load (ensure all threads have written their portion)
- After the MMA loop (ensure all SIMD groups have finished reading)
For K=4096, there are 128 K-steps, requiring 256 barriers. At 0.3us per barrier, this is ~77us of pure barrier overhead per tile, or roughly 10-15% of the total tile computation time. This is one of the reasons GEMM does not achieve 100% of peak MMA throughput.
The V2 attention kernel’s approach of keeping data in registers (avoiding the MMA-barrier-MMA cycle) provides a hint at how future GEMM kernels might reduce barrier overhead, though the GEMM’s much larger tile sizes make this approach more challenging.
Comparison with llama.cpp’s GEMM
Akunu’s GEMM kernels are derived from llama.cpp’s kernel_mul_mm family but include several improvements:
| Feature | llama.cpp | Akunu |
|---|---|---|
| Tile geometry | 32x64 (same) | 32x64 + 8x64 small variant |
| Threadgroup swizzling | No | Yes (Q4_0, other quantized) |
| Function constant K | No | Yes (FC_GEMM_K) |
| Alpha/beta support | No | Yes (FP16 GEMM) |
| MLX format support | No | Yes (6 MLX variants) |
| Small M variant | No | Yes (TM=8 for M=2-8) |
| BF16 support | Partial | Full |
The most impactful difference is the function constant K specialization, which allows the Metal compiler to generate tighter loops with known bounds, often resulting in 5-10% speedup for common K dimensions.
The threadgroup swizzling provides another 3-8% improvement at large grid sizes by improving SLC hit rates for weight tiles. This is most noticeable during the FFN GEMMs where the grid is large (14336/64 = 224 tiles in the weight dimension).
Future Directions
Apple’s Metal 3.2 (introduced with the M4 family) provides enhanced simdgroup matrix operations, including support for larger tile sizes and new data types. Future GEMM kernels may benefit from:
- Larger MMA tiles: 16x16 or 32x32 sub-tiles would reduce the number of MMA instructions per output element, improving throughput.
- BF16 MMA: Native BF16 matrix operations would eliminate the conversion overhead for BF16 models.
- Cooperative groups: Finer-grained synchronization primitives could reduce barrier overhead.
- Persistent kernels: A single long-running kernel that processes all tiles sequentially could eliminate inter-tile overhead.
However, the current 8x8 MMA-based approach is well-proven and delivers near-peak performance. The 32x64 tile geometry will likely remain optimal for Apple Silicon’s current generation of GPU architectures.
Debugging GEMM Correctness
GEMM bugs are notoriously difficult to debug because the output is a large matrix where each element depends on the full K-dimension accumulation. Akunu uses several strategies:
- Alpha/Beta support: Setting alpha=1, beta=0 for production and alpha=0, beta=1 for "identity" (output = input C) enables isolating GEMM output from existing data.
- PSO validation: The dispatch_gemm function includes a fatal error if the PSO fails to compile, catching kernel bugs early.
- Dimension checks: The scratch buffer sizes are validated at model load time to ensure no GEMM dispatch will write out of bounds.
- Profiling labels: Each GEMM dispatch in the dispatch table carries a label like "L5.ffn.gate", making it easy to identify which GEMM produced incorrect output in a GPU debugger.
Summary
Akunu’s GEMM kernels are the workhorses of prefill. The key design decisions are:
- 32x64 tile geometry with 4 SIMD groups per threadgroup, maximizing MMA instruction utilization.
- Inline dequantization for quantized formats, converting directly from packed format to registers without intermediate buffers.
- Cooperative loading where all 128 threads participate in loading both weight and activation tiles.
- Threadgroup swizzling for cache-friendly access patterns across the grid.
- Small GEMM variants for low-M cases to avoid wasted padding computation.
- Function constant specialization for K-dimension to enable compiler optimizations.
1. Apple. "Metal Shading Language Specification, Version 3.1." Section 6.10, Simdgroup Matrix Functions. The simdgroup_load and simdgroup_store functions operate on 8x8 matrices distributed across the 32 threads of a SIMD group, with each thread holding two elements (the thread_elements() accessor). See https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf.
2. The swizzling technique is adapted from NVIDIA's CUTLASS library. See: Thakkar, V., et al. "CUTLASS: Fast Linear Algebra in CUDA C++." NVIDIA Technical Blog, 2017. The Apple Silicon SLC acts similarly to NVIDIA's L2 cache for this optimization.
3. Apple. "Metal Shading Language Specification, Version 3.1." Section 6.10.3, Simdgroup Matrix Thread Elements. The thread_elements() accessor returns a vec<T, 2> containing the thread's owned elements, following Apple's proprietary lane mapping for 8x8 matrices. See https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf.