Matrix Multiplication Strategies
If there is one operation you need to understand deeply when running neural networks on GPUs, it is matrix multiplication. Not because it is conceptually hard – you learned it in linear algebra class – but because virtually everything in a transformer reduces to it. Every linear layer, every projection, every feed-forward network… all matmuls. If your matmul is slow, your model is slow. Period.
In this chapter we are going to build up from the textbook definition of matrix multiplication, examine why it matters so much for inference, and then dive into the different strategies you will encounter when targeting Metal GPUs. We will cover the two main regimes – GEMV and GEMM – plus all the messy in-between cases, and finally discuss how quantization changes the picture.
Why Matmul Is the Center of the Universe
Let us start with a concrete example. Consider a single linear layer in a transformer:
y = x * W + b
Here x is your input (say, a vector of dimension 4096), W is the weight matrix (say,
4096 x 11008), and b is a bias vector. That multiplication x * W is a matrix
multiplication. Now count how many linear layers exist in a single transformer block:
- Q projection: matmul
- K projection: matmul
- V projection: matmul
- Output projection: matmul
- Gate projection (FFN): matmul
- Up projection (FFN): matmul
- Down projection (FFN): matmul
That is seven matmuls per layer. A 32-layer model has 224 matmuls per forward pass. If each one takes 0.5ms, you are already at 112ms just for the matmuls – and that is before attention, normalization, or anything else.
One Transformer Block
=====================
Input
|
v
[LayerNorm]
|
+---> Q Projection (matmul #1)
+---> K Projection (matmul #2)
+---> V Projection (matmul #3)
|
v
[Attention Mechanism]
|
v
[Output Projection] (matmul #4)
|
v
[Residual Add]
|
v
[LayerNorm]
|
+---> Gate Projection (matmul #5)
+---> Up Projection (matmul #6)
|
v
[SiLU + Elementwise Mul]
|
v
[Down Projection] (matmul #7)
|
v
[Residual Add]
|
v
Output
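The back-of-the-envelope numbers above are easy to check. A quick host-side sketch in plain Python (illustrative arithmetic only, using the assumed 0.5ms per-matmul cost from the text):

```python
# Matmuls per transformer block: Q, K, V, output, gate, up, down projections
MATMULS_PER_LAYER = 7
LAYERS = 32
MS_PER_MATMUL = 0.5   # assumed per-matmul cost from the text, in milliseconds

total_matmuls = MATMULS_PER_LAYER * LAYERS
matmul_time_ms = total_matmuls * MS_PER_MATMUL
print(total_matmuls, matmul_time_ms)   # 224 112.0
```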
So when someone says “optimizing inference,” they mostly mean “optimizing matmul.”
The Two Regimes: GEMV vs GEMM
Here is the critical insight that changes everything about how you write GPU kernels for LLM inference: the shape of the matmul depends on what phase of inference you are in.
Token Generation (Decode): GEMV
During autoregressive decoding, you generate one token at a time. Your input x is a
single vector of dimension d_model. The weight matrix W has shape [d_model, d_out].
So the multiplication is:
x: [1, 4096]
W: [4096, 11008]
y: [1, 11008]
This is a matrix-vector multiplication, or GEMV (General Matrix-Vector multiply). In BLAS terminology, M=1 (only one row in the output). The defining characteristic of GEMV is that every element of the weight matrix is read exactly once. There is no data reuse. This makes GEMV fundamentally memory bandwidth bound – your performance is limited by how fast you can stream the weight matrix from memory.
GEMV: One vector times a matrix
================================
x = [x0, x1, x2, x3] (1 x K)
W = [ w00 w01 w02 ]
[ w10 w11 w12 ] (K x N)
[ w20 w21 w22 ]
[ w30 w31 w32 ]
y = [y0, y1, y2] (1 x N)
y0 = x0*w00 + x1*w10 + x2*w20 + x3*w30
y1 = x0*w01 + x1*w11 + x2*w21 + x3*w31
y2 = x0*w02 + x1*w12 + x2*w22 + x3*w32
Each weight is touched exactly once.
Bottleneck = memory bandwidth.
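A tiny CPU reference makes the "every weight read once" property concrete. This is plain Python (illustrative, not the Metal kernel), instrumented to count weight reads:

```python
def gemv(x, W):
    """y = x * W for a 1 x K vector and a K x N matrix; counts weight reads."""
    K, N = len(W), len(W[0])
    reads = 0
    y = [0.0] * N
    for n in range(N):
        for k in range(K):
            y[n] += x[k] * W[k][n]   # each W[k][n] is touched exactly once
            reads += 1
    return y, reads

x = [1.0, 2.0, 3.0, 4.0]                          # 1 x K, K = 4
W = [[1, 0, 2], [0, 1, 0], [1, 1, 1], [2, 0, 1]]  # K x N, N = 3
y, reads = gemv(x, W)
print(y)       # [12.0, 5.0, 9.0] -- dot(x, column n) for each column
print(reads)   # 12 = K * N: no weight is ever reused
```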
Prompt Processing (Prefill): GEMM
During prefill, you process the entire input prompt at once. If the prompt has 512 tokens,
your input is a matrix of shape [512, 4096]. Now the multiplication becomes:
X: [512, 4096]
W: [4096, 11008]
Y: [512, 11008]
This is a full matrix-matrix multiplication, or GEMM (General Matrix-Matrix multiply).
M=512, and now every element of W gets reused across 512 different input rows. This
changes the arithmetic intensity dramatically – GEMM can be compute bound if you
organize the data access well.
GEMM: A matrix times a matrix
===============================
X = [ x00 x01 x02 x03 ]
[ x10 x11 x12 x13 ] (M x K)
[ x20 x21 x22 x23 ]
W = [ w00 w01 ]
[ w10 w11 ] (K x N)
[ w20 w21 ]
[ w30 w31 ]
Y = [ y00 y01 ]
[ y10 y11 ] (M x N)
[ y20 y21 ]
Each weight wij is used M times (once per row of X).
With good tiling, this becomes compute-bound.
The Arithmetic Intensity Spectrum
Let us quantify this. Arithmetic intensity is the ratio of compute operations to memory bytes transferred:
GEMV (M=1):
Operations = 2 * K * N (multiply + add for each element)
Bytes read = K * N * sizeof(weight) + K * sizeof(input)
Intensity ~ 2 / sizeof(weight)
For FP16: 2 / 2 = 1 FLOP/byte --> bandwidth bound
GEMM (M=512):
Operations = 2 * M * K * N
Bytes read = K * N * sizeof(weight) + M * K * sizeof(input)
Intensity ~ 2 * M / sizeof(weight) (when K*N >> M*K)
For FP16: 2 * 512 / 2 = 512 FLOP/byte --> compute bound
Metal GPUs typically have 200-400 GB/s of memory bandwidth and 10-20 TFLOPS of FP16 compute. The crossover point – where you shift from bandwidth bound to compute bound – is roughly around M=16-64, depending on the specific hardware and data types.
Arithmetic Intensity vs M
=========================
Intensity
(FLOP/byte)
|
512 | * GEMM (M=512)
| *
| *
| *
| *
64 | * <-- Compute bound above this line
| *
| - - - - - - - - - - -*- - - - - - HW Compute/BW ratio
16 | *
| *
4 | *
2 | *
1 | * * <-- GEMV (M=1)
+--+---+---+---+---+---+---+---+---> M
1 2 4 8 16 32 64 512
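The crossover can be computed directly from the intensity formulas above. A sketch in plain Python, using 200 GB/s and 10 TFLOPS as the representative (assumed) hardware figures from the text:

```python
def intensity(M, K, N, weight_bytes, input_bytes=2):
    """Arithmetic intensity (FLOP/byte) of Y[M,N] = X[M,K] @ W[K,N]."""
    flops = 2 * M * K * N                                    # multiply + add
    bytes_read = K * N * weight_bytes + M * K * input_bytes
    return flops / bytes_read

BANDWIDTH = 200e9   # bytes/s -- assumed, from the text
COMPUTE   = 10e12   # FLOP/s  -- assumed, from the text
ridge = COMPUTE / BANDWIDTH   # FLOP/byte needed to become compute bound

# Smallest M at which an FP16 [M,4096] x [4096,11008] matmul crosses the ridge
M = 1
while intensity(M, 4096, 11008, 2) < ridge:
    M += 1
print(ridge, M)   # 50.0 51 -- consistent with the "roughly M=16-64" estimate
```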
GEMV: Thread and SIMD Parallelism
Let us start with GEMV since it is the bread and butter of token generation. Remember,
M=1, so we are computing y = x * W where x is a vector and W is a matrix. The
output y has N elements, and each output element is a dot product of length K.
The Naive Approach
The simplest approach: assign one thread per output element. Each thread computes a full dot product of length K:
// Naive GEMV: one thread per output element
kernel void gemv_naive(
    device const float* x [[buffer(0)]],   // [K]
    device const float* W [[buffer(1)]],   // [K x N], row-major
    device float*       y [[buffer(2)]],   // [N]
    constant uint&      K [[buffer(3)]],
    constant uint&      N [[buffer(4)]],
    uint tid [[thread_position_in_grid]])
{
    if (tid >= N) return;
    float sum = 0.0f;
    for (uint k = 0; k < K; k++) {
        sum += x[k] * W[k * N + tid];   // Column tid of W
    }
    y[tid] = sum;
}
This works but has terrible performance. Each thread redundantly reads all K elements of x, and walks down a column of W with a stride of N elements between consecutive reads – poor locality that wastes most of every cache line it touches.
Coalesced Access with Transposed Weights
First improvement: store W in column-major order (or equivalently, store W^T in
row-major order). Now adjacent threads read adjacent memory locations:
// GEMV with transposed weights: coalesced reads
kernel void gemv_transposed(
    device const float* x  [[buffer(0)]],   // [K]
    device const float* Wt [[buffer(1)]],   // [N x K], W transposed
    device float*       y  [[buffer(2)]],   // [N]
    constant uint&      K  [[buffer(3)]],
    constant uint&      N  [[buffer(4)]],
    uint tid [[thread_position_in_grid]])
{
    if (tid >= N) return;
    // Thread tid computes output element tid
    // Reads row tid of Wt, which is column tid of W
    float sum = 0.0f;
    for (uint k = 0; k < K; k++) {
        sum += x[k] * Wt[tid * K + k];
    }
    y[tid] = sum;
}
Better memory access pattern, but we still have N threads each independently reading the
entire x vector. Let us fix that.
SIMD-Level Parallelism: Splitting the K Dimension
Here is the key idea for efficient GEMV on Metal: instead of one thread per output, use a
group of threads (a SIMD group) to cooperatively compute one output element. Each
thread in the SIMD group handles a slice of the K dimension, then we use simd_sum to
reduce:
SIMD GEMV: 32 threads cooperate on one dot product
====================================================
x = [x0, x1, x2, ..., x31, x32, ..., x63, ...]
Thread 0 handles: x[0]*w[0] + x[32]*w[32] + x[64]*w[64] + ...
Thread 1 handles: x[1]*w[1] + x[33]*w[33] + x[65]*w[65] + ...
...
Thread 31 handles: x[31]*w[31] + x[63]*w[63] + x[95]*w[95] + ...
Then: simd_sum() adds all 32 partial sums in hardware
Result: one output element computed cooperatively
// SIMD GEMV: one SIMD group per output element
kernel void gemv_simd(
    device const float* x  [[buffer(0)]],
    device const float* Wt [[buffer(1)]],   // [N x K]
    device float*       y  [[buffer(2)]],
    constant uint&      K  [[buffer(3)]],
    constant uint&      N  [[buffer(4)]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_gid  [[simdgroup_index_in_threadgroup]],
    uint tg_id     [[threadgroup_position_in_grid]])
{
    const uint SIMD_GROUPS_PER_TG = 8;   // 8 SIMD groups = 256 threads per TG

    // Each SIMD group computes one output element
    uint n = tg_id * SIMD_GROUPS_PER_TG + simd_gid;
    if (n >= N) return;

    float partial = 0.0f;
    // Each lane processes every 32nd element
    for (uint k = simd_lane; k < K; k += 32) {
        partial += x[k] * Wt[n * K + k];
    }

    // Hardware reduction across the SIMD group
    float total = simd_sum(partial);

    // Only lane 0 writes the result
    if (simd_lane == 0) {
        y[n] = total;
    }
}
The simd_sum intrinsic is a hardware-level reduction. On Apple GPUs, a SIMD group (a “warp” in CUDA terms, a “wavefront” on AMD) is 32 threads that execute in lockstep. The simd_sum operation sums a value across all 32 lanes in just a few clock cycles – no threadgroup memory, no barriers, no atomics. This is enormously powerful.
simd_sum reduction (hardware butterfly)
========================================
Lane: 0 1 2 3 ... 30 31
Val: 2.1 0.5 1.3 0.8 ... 0.2 1.1
Step 1: swap with neighbor, add
2.6 2.6 2.1 2.1 ... 1.3 1.3
Step 2: swap with stride-2 neighbor, add
4.7 4.7 4.7 4.7 ... ... ...
Step 3: stride 4 ...
Step 4: stride 8 ...
Step 5: stride 16 ...
After 5 steps: all lanes hold the total sum.
Cost: ~5 cycles. No memory traffic.
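The butterfly can be emulated on the CPU to see why it takes exactly log2(32) = 5 steps. A Python sketch (the hardware does this in registers; this just mirrors the data movement):

```python
def simd_sum(vals):
    """Butterfly reduction: after log2(n) steps every lane holds the total."""
    lanes = list(vals)
    n = len(lanes)
    stride = 1
    steps = 0
    while stride < n:
        # each lane adds the value of its partner at XOR-distance `stride`
        lanes = [lanes[i] + lanes[i ^ stride] for i in range(n)]
        stride *= 2
        steps += 1
    return lanes, steps

lanes, steps = simd_sum(range(32))   # 32 partial sums: 0, 1, ..., 31
print(lanes[0], steps)               # 496 5 -- every lane holds sum(0..31)
assert all(v == lanes[0] for v in lanes)
```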
Multiple Output Elements per SIMD Group
We can do even better. Instead of one output element per SIMD group, compute several. This
amortizes the cost of reading x:
// SIMD GEMV: one SIMD group computes 4 output elements
kernel void gemv_simd_multi(
    device const float* x  [[buffer(0)]],
    device const float* Wt [[buffer(1)]],
    device float*       y  [[buffer(2)]],
    constant uint&      K  [[buffer(3)]],
    constant uint&      N  [[buffer(4)]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_gid  [[simdgroup_index_in_threadgroup]],
    uint tg_id     [[threadgroup_position_in_grid]])
{
    const uint SIMD_GROUPS_PER_TG = 8;
    uint n_base = (tg_id * SIMD_GROUPS_PER_TG + simd_gid) * 4;
    if (n_base + 3 >= N) return;   // Assumes N is a multiple of 4

    float4 partial = float4(0.0f);
    for (uint k = simd_lane; k < K; k += 32) {
        float xk = x[k];   // Load x[k] once, use it for all 4 outputs
        partial[0] += xk * Wt[(n_base + 0) * K + k];
        partial[1] += xk * Wt[(n_base + 1) * K + k];
        partial[2] += xk * Wt[(n_base + 2) * K + k];
        partial[3] += xk * Wt[(n_base + 3) * K + k];
    }

    // Reduce each output element across the SIMD group
    float y0 = simd_sum(partial[0]);
    float y1 = simd_sum(partial[1]);
    float y2 = simd_sum(partial[2]);
    float y3 = simd_sum(partial[3]);

    if (simd_lane == 0) {
        y[n_base + 0] = y0;
        y[n_base + 1] = y1;
        y[n_base + 2] = y2;
        y[n_base + 3] = y3;
    }
}
Now each SIMD group loads x[k] once and uses it four times. The weight reads are still
one-shot (each weight used exactly once), but we have reduced the x reads by 4x.
GEMM: Tiling for Compute Efficiency
When M is large (prefill phase), we enter the world of GEMM. The key technique here is tiling: dividing the output matrix into tiles and assigning each tile to a threadgroup. Within the threadgroup, we further divide work among SIMD groups.
The Tiling Concept
GEMM Tiling Overview
=====================
Output Y [M x N] is divided into tiles of size [Bm x Bn]
N
<----------->
+----+----+----+----+ ^
| TG | TG | TG | TG | |
| 00 | 01 | 02 | 03 | |
+----+----+----+----+ | M
| TG | TG | TG | TG | |
| 10 | 11 | 12 | 13 | |
+----+----+----+----+ v
Each TG computes a [Bm x Bn] tile of the output.
To compute its tile, TG needs:
- A strip of X: [Bm x K]
- A strip of W: [K x Bn]
But K can be huge (4096+), so we tile along K too:
K
<----------->
+====+====+====+ <- X rows for this TG's tile
| Bk | Bk | Bk |
+====+====+====+
Process K in chunks of Bk, accumulating partial results.
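Tiling is easiest to sanity-check on the CPU. The sketch below (plain Python, illustrative) computes one output tile by walking K in Bk-sized chunks and accumulating, exactly as the kernel's K-loop does, then checks the tile against a direct dot product:

```python
def matmul_tile(X, W, m_base, n_base, Bm, Bn, Bk):
    """Compute the [Bm x Bn] output tile at (m_base, n_base) in K-chunks."""
    K = len(W)
    acc = [[0.0] * Bn for _ in range(Bm)]
    for k_base in range(0, K, Bk):   # walk along K, chunk by chunk
        # "load" the two tiles for this K-chunk
        X_tile = [[X[m_base + i][k_base + kk] for kk in range(Bk)] for i in range(Bm)]
        W_tile = [[W[k_base + kk][n_base + j] for j in range(Bn)] for kk in range(Bk)]
        # multiply the tiles, accumulating partial results
        for i in range(Bm):
            for j in range(Bn):
                for kk in range(Bk):
                    acc[i][j] += X_tile[i][kk] * W_tile[kk][j]
    return acc

# Toy sizes (assumed divisible by the tile sizes): 4x8 @ 8x4, Bm=Bn=2, Bk=4
X = [[(i + j) % 5 for j in range(8)] for i in range(4)]
W = [[(i * j) % 3 for j in range(4)] for i in range(8)]
tile = matmul_tile(X, W, m_base=0, n_base=0, Bm=2, Bn=2, Bk=4)
direct = [[sum(X[i][k] * W[k][j] for k in range(8)) for j in range(2)]
          for i in range(2)]
assert tile == direct   # chunked accumulation matches the full dot product
```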
Cooperative Loading into Threadgroup Memory
The key insight is that threads within a threadgroup can cooperatively load a tile of data into fast threadgroup memory (shared memory), then all threads read from it. This converts slow device memory reads into fast threadgroup memory reads:
// Simplified tiled GEMM (256 threads per threadgroup assumed)
kernel void gemm_tiled(
    device const half* X [[buffer(0)]],   // [M x K]
    device const half* W [[buffer(1)]],   // [K x N]
    device half*       Y [[buffer(2)]],   // [M x N]
    constant uint&     M [[buffer(3)]],
    constant uint&     N [[buffer(4)]],
    constant uint&     K [[buffer(5)]],
    uint2 tg_pos [[threadgroup_position_in_grid]],
    uint  tid    [[thread_index_in_threadgroup]])
{
    // Tile dimensions
    const uint Bm = 32;   // Tile rows
    const uint Bn = 32;   // Tile cols
    const uint Bk = 16;   // K-tile size

    // Threadgroup memory for tiles
    threadgroup half X_tile[Bm * Bk];
    threadgroup half W_tile[Bk * Bn];

    // This threadgroup computes the output tile at (tg_pos.y * Bm, tg_pos.x * Bn)
    uint m_base = tg_pos.y * Bm;
    uint n_base = tg_pos.x * Bn;

    // Accumulators: each thread computes 4 output elements
    float acc[4] = {0, 0, 0, 0};

    // Walk along K in steps of Bk
    for (uint k_base = 0; k_base < K; k_base += Bk) {
        // === Cooperative load: 256 threads strip-load both tiles ===
        // Each tile holds 32*16 = 512 elements, so each thread loads two.
        for (uint i = tid; i < Bm * Bk; i += 256) {
            X_tile[i] = X[(m_base + i / Bk) * K + (k_base + i % Bk)];
        }
        for (uint i = tid; i < Bk * Bn; i += 256) {
            W_tile[i] = W[(k_base + i / Bn) * N + (n_base + i % Bn)];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // === Compute: multiply tiles ===
        // Each thread computes its assigned output elements
        // using data from threadgroup memory (fast!)
        uint my_m = tid / 8;         // Which row within the tile
        uint my_n = (tid % 8) * 4;   // Which 4 columns within the tile
        for (uint kk = 0; kk < Bk; kk++) {
            half x_val = X_tile[my_m * Bk + kk];
            for (uint j = 0; j < 4; j++) {
                acc[j] += float(x_val) * float(W_tile[kk * Bn + my_n + j]);
            }
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Write results
    uint my_m = tid / 8;
    uint my_n = (tid % 8) * 4;
    for (uint j = 0; j < 4; j++) {
        Y[(m_base + my_m) * N + (n_base + my_n + j)] = half(acc[j]);
    }
}
Let us trace through what happens for one K-tile:
One iteration of the K-loop (k_base = 0, Bk = 16)
===================================================
Step 1: Cooperative Load
~~~~~~~~~~~~~~~~~~~~~~~~
All 256 threads in the threadgroup work together to load:

X_tile [32 x 16]:                    W_tile [16 x 32]:
  Rows m_base..m_base+31 of X          Rows 0..15 of W
  Cols 0..15 of X                      Cols n_base..n_base+31 of W

Each tile holds 512 elements, so with 256 threads
each thread fetches two elements per tile.
Step 2: Barrier
~~~~~~~~~~~~~~~
Wait for all loads to complete.
Step 3: Compute
~~~~~~~~~~~~~~~
Each thread reads from the fast threadgroup memory tiles
and accumulates partial results.
Thread 0 computes: Y[m_base+0][n_base+0..3] += X_tile[0][k] * W_tile[k][0..3]
Thread 1 computes: Y[m_base+0][n_base+4..7] += X_tile[0][k] * W_tile[k][4..7]
...
Thread 8 computes: Y[m_base+1][n_base+0..3] += X_tile[1][k] * W_tile[k][0..3]
...
Step 4: Barrier
~~~~~~~~~~~~~~~
Wait before overwriting tiles with next K-chunk.
SIMD Group Matrix Multiply-Accumulate (MMA)
Metal on modern Apple Silicon supports SIMD group matrix operations that are much faster than computing the multiply-accumulate by hand. These map to the matrix-multiply data paths in the GPU's SIMD units (not to be confused with the CPU-side AMX coprocessor):
#include <metal_simdgroup_matrix>

// Using SIMD group MMA for tiled GEMM
// (16 SIMD groups = 512 threads per threadgroup, covering a 32x32 output
//  tile; assumes M, N, and K are multiples of 8)
kernel void gemm_simdgroup_mma(
    device const half* X [[buffer(0)]],
    device const half* W [[buffer(1)]],
    device half*       Y [[buffer(2)]],
    constant uint&     M [[buffer(3)]],
    constant uint&     N [[buffer(4)]],
    constant uint&     K [[buffer(5)]],
    uint2 tg_pos [[threadgroup_position_in_grid]],
    uint  sgid   [[simdgroup_index_in_threadgroup]],
    uint  lane   [[thread_index_in_simdgroup]])
{
    // Each SIMD group computes an 8x8 tile of the output
    // using 8x8 matrix operations
    simdgroup_matrix<half, 8, 8> x_mat;
    simdgroup_matrix<half, 8, 8> w_mat;

    // Initialize the accumulator to zero
    simdgroup_matrix<half, 8, 8> acc_mat =
        make_filled_simdgroup_matrix<half, 8, 8>(half(0));

    uint m_base = tg_pos.y * 32 + (sgid / 4) * 8;
    uint n_base = tg_pos.x * 32 + (sgid % 4) * 8;

    for (uint k = 0; k < K; k += 8) {
        // Load 8x8 tiles from X and W (last argument = row stride in elements)
        simdgroup_load(x_mat, X + m_base * K + k, K);
        simdgroup_load(w_mat, W + k * N + n_base, N);
        // 8x8 matrix multiply-accumulate in hardware
        simdgroup_multiply_accumulate(acc_mat, x_mat, w_mat, acc_mat);
    }

    // Store the 8x8 result tile
    simdgroup_store(acc_mat, Y + m_base * N + n_base, N);
}
The simdgroup_multiply_accumulate function computes an 8x8 matrix multiply-accumulate in hardware. A single SIMD group of 32 threads cooperatively holds each 8x8 matrix (two elements per thread, with a mapping that is opaque to the programmer), and the multiplication runs on the GPU's matrix-multiply pipeline. This is dramatically faster than issuing the multiply-adds manually.
SIMD Group MMA: 32 threads own an 8x8 matrix
===============================================
The 8x8 = 64 elements are distributed across 32 threads,
two per thread (illustrative layout -- the actual mapping
is implementation-defined and opaque to the programmer):

Thread 0:  elements [0,0] and [0,1]
Thread 1:  elements [0,2] and [0,3]
Thread 2:  elements [0,4] and [0,5]
Thread 3:  elements [0,6] and [0,7]
Thread 4:  elements [1,0] and [1,1]
...
Thread 31: elements [7,6] and [7,7]

simdgroup_multiply_accumulate(C, A, B, C):
    C += A * B (a full 8x8 MMA in a handful of cycles)

This maps to the matrix-multiply hardware in the GPU's SIMD units.
Wide GEMV: The Vocabulary Projection Problem
There is one particular GEMV that deserves special attention: the final vocabulary projection. In a typical LLM, this multiplies the hidden state (dimension 4096) by the vocabulary embedding matrix to produce logits over the entire vocabulary:
hidden: [1, 4096]
vocab_weights: [4096, 32000] (or 128256 for Llama 3!)
logits: [1, 32000]
This is still a GEMV (M=1), but N is enormous – 32K to 128K output elements. The challenge is that you need enough parallelism to saturate the GPU, and you need to read a very large weight matrix.
The strategy is to parallelize aggressively across the N dimension:
Wide GEMV for Vocab Projection
===============================
N = 32000 output logits
K = 4096

Approach: one SIMD group per output element,
          8 SIMD groups (256 threads) per threadgroup

Threadgroup 0:    computes outputs [0..7]
Threadgroup 1:    computes outputs [8..15]
...
Threadgroup 3999: computes outputs [31992..31999]

4000 threadgroups x 8 SIMD groups/TG x 32 threads/SIMD = 1,024,000 threads

Each SIMD group: 32 lanes reduce a dot product of length 4096
That is 4096/32 = 128 multiply-adds per lane.

Total weight data: 4096 * 32000 * 2 bytes = ~262 MB (FP16)
At 200 GB/s: ~1.3ms to stream all weights
This sets the floor for GEMV performance.
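The bandwidth floor quoted above is just a division; a quick check in plain Python (assuming the 200 GB/s figure used throughout):

```python
K, N = 4096, 32000
weight_bytes = K * N * 2                 # FP16: 2 bytes per weight
floor_ms = weight_bytes / 200e9 * 1e3    # assumed 200 GB/s bandwidth
print(weight_bytes / 1e6, round(floor_ms, 2))   # 262.144 1.31
```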
The key insight: for wide GEMV, you want each SIMD group to handle one (or a few) output
columns, with the 32 lanes splitting the K-dimension reduction. The x vector is small
enough to fit in threadgroup memory or even registers, so you load it once and reuse it
for all output columns in the threadgroup.
// Wide GEMV: optimized for large N (vocab projection)
kernel void gemv_wide(
    device const half* x  [[buffer(0)]],   // [K]
    device const half* Wt [[buffer(1)]],   // [N x K]
    device half*       y  [[buffer(2)]],   // [N]
    constant uint&     K  [[buffer(3)]],
    constant uint&     N  [[buffer(4)]],
    uint sgid  [[simdgroup_index_in_threadgroup]],
    uint lane  [[thread_index_in_simdgroup]],
    uint tg_id [[threadgroup_position_in_grid]])
{
    // Load x into threadgroup memory (cooperative load)
    threadgroup half x_shared[4096];   // Assumes K <= 4096
    uint tid = sgid * 32 + lane;
    for (uint i = tid; i < K; i += 256) {   // 256 threads per TG
        x_shared[i] = x[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Each SIMD group handles one output element
    uint n = tg_id * 8 + sgid;   // 8 SIMD groups per TG
    if (n >= N) return;

    float sum = 0.0f;
    device const half* w_row = Wt + n * K;
    for (uint k = lane; k < K; k += 32) {
        sum += float(x_shared[k]) * float(w_row[k]);
    }

    sum = simd_sum(sum);
    if (lane == 0) {
        y[n] = half(sum);
    }
}
Batched GEMV: The Awkward Middle Ground
In practice, inference is not always purely M=1 or M=512. There are several scenarios where M is small but greater than 1:
- Batch decoding: serving multiple users simultaneously (M = batch_size, often 2-16)
- Speculative decoding: verifying multiple candidate tokens at once (M = num_candidates, often 4-8)
- Small prompts: short prefill with just a few tokens
For these cases, M is too small for efficient GEMM (not enough reuse to be compute bound) but too large for pure GEMV (we want some reuse of the weight data).
The solution is batched GEMV: treat it as M independent GEMVs but share weight loads across them:
Batched GEMV (M=4)
===================
X = [ x0 ] (4 rows, each a vector of length K)
[ x1 ]
[ x2 ]
[ x3 ]
Strategy: Each SIMD group handles multiple output columns
and ALL M rows simultaneously.
SIMD Group 0 processing output column n=0:
+-----+-----+-----+-----+-----+
| Lane| 0 | 1 | 2 | ... |
+-----+-----+-----+-----+-----+
| Row0| x0[0]*w[0] | x0[1]*w[1] | x0[2]*w[2] | ... |
| Row1| x1[0]*w[0] | x1[1]*w[1] | x1[2]*w[2] | ... |
| Row2| x2[0]*w[0] | x2[1]*w[1] | x2[2]*w[2] | ... |
| Row3| x3[0]*w[0] | x3[1]*w[1] | x3[2]*w[2] | ... |
+-----+-----+-----+-----+-----+
Each weight w[k] is loaded once and used for M=4 rows.
Weight reuse factor: 4x compared to pure GEMV.
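A CPU sketch (illustrative Python, not the Metal kernel) of the batched strategy: load each weight once, apply it to all M rows, and confirm the result matches M independent GEMVs while weight loads drop by a factor of M:

```python
def gemv(x, Wt):
    # one output per row of Wt [N x K] (each row is a column of W)
    return [sum(xv * wv for xv, wv in zip(x, w_row)) for w_row in Wt]

def batched_gemv(X, Wt):
    """M rows share each weight load; returns (Y, weight_loads)."""
    M, N = len(X), len(Wt)
    Y = [[0.0] * N for _ in range(M)]
    loads = 0
    for n, w_row in enumerate(Wt):
        for k, w_val in enumerate(w_row):
            loads += 1                 # each weight loaded ONCE...
            for m in range(M):         # ...and used for all M rows
                Y[m][n] += X[m][k] * w_val
    return Y, loads

X = [[1, 2], [3, 4], [5, 6], [7, 8]]   # M=4 rows, K=2
Wt = [[1, 0], [0, 1], [1, 1]]          # N=3 rows of length K
Y, loads = batched_gemv(X, Wt)
assert Y == [gemv(x, Wt) for x in X]   # same answer as 4 separate GEMVs
assert loads == 3 * 2                  # K*N loads, vs M*K*N when run separately
```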
// Batched GEMV: M rows share weight loads
kernel void gemv_batched(
    device const half* X  [[buffer(0)]],   // [M x K]
    device const half* Wt [[buffer(1)]],   // [N x K]
    device half*       Y  [[buffer(2)]],   // [M x N]
    constant uint&     M  [[buffer(3)]],
    constant uint&     K  [[buffer(4)]],
    constant uint&     N  [[buffer(5)]],
    uint sgid  [[simdgroup_index_in_threadgroup]],
    uint lane  [[thread_index_in_simdgroup]],
    uint tg_id [[threadgroup_position_in_grid]])
{
    uint n = tg_id * 8 + sgid;   // Output column (8 SIMD groups per TG)
    if (n >= N) return;

    // Accumulators for each row
    float sums[8] = {0.0f};   // Supports up to M = 8

    device const half* w_row = Wt + n * K;
    for (uint k = lane; k < K; k += 32) {
        half w_val = w_row[k];   // Load the weight ONCE...
        // ...and apply it to all M rows
        for (uint m = 0; m < M; m++) {
            sums[m] += float(X[m * K + k]) * float(w_val);
        }
    }

    // Reduce each row's sum across the SIMD group
    for (uint m = 0; m < M; m++) {
        float total = simd_sum(sums[m]);
        if (lane == 0) {
            Y[m * N + n] = half(total);
        }
    }
}
The performance gain over running M separate GEMVs is significant because:
- Weight data is loaded once instead of M times
- Kernel launch overhead is reduced
- The GPU can be better saturated with the additional work
Performance: M separate GEMVs vs Batched GEMV
===============================================
M=4, K=4096, N=4096, FP16 weights
M separate GEMVs:
  Weight reads: 4 * 4096 * 4096 * 2 bytes = 134 MB
  Time at 200 GB/s: ~0.67 ms

Batched GEMV:
  Weight reads: 1 * 4096 * 4096 * 2 bytes = 33.6 MB
  Time at 200 GB/s: ~0.17 ms

4x improvement from weight reuse!
How Quantization Changes the Game
Everything we have discussed so far assumes FP16 weights. But in practice, most inference deployments use quantized weights – 4-bit or even 2-bit. This changes the matmul kernels significantly.
The Core Challenge: Dequantize on the Fly
Quantized weights are stored in a compressed format. You cannot directly multiply them with the input – you first need to dequantize them back to a floating-point representation. The question is: where and when do you dequantize?
The answer is on-the-fly dequantization: each thread dequantizes just the weight values it needs, right before multiplying them with the input. The dequantized values live only in registers, never written back to memory.
Dequantize-on-the-fly GEMV
===========================
Memory: [...Q4 packed weights...] (4 bits per weight)
|
v (load)
Registers: [packed_byte]
|
v (unpack + scale)
Registers: [fp16_weight_a, fp16_weight_b]
|
v (multiply with input)
Registers: [partial_sum]
|
v (simd_sum)
Memory: [output]
Key: dequantized weights never touch memory.
We trade compute (dequantize) for bandwidth (smaller reads).
Q4_0 GEMV Example
Let us work through a concrete example with Q4_0 quantization (the simplest format). In Q4_0, every block of 32 weights shares a single FP16 scale factor. Each weight is stored as a 4-bit integer (0-15), representing the range [-8, 7] after subtracting 8:
// Q4_0 block structure
struct block_q4_0 {
    half  scale;       // 2 bytes: shared scale for 32 weights
    uchar packed[16];  // 16 bytes: 32 x 4-bit values packed into pairs
};
// Total: 18 bytes for 32 weights (4.5 bits/weight effective)
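The block layout is easy to exercise on the CPU. The Python sketch below (helper names `q4_0_pack` and `q4_0_dequant` are hypothetical, for illustration) packs 32 values into the 18-byte layout described above and dequantizes them back, mirroring the nibble unpacking the kernel performs:

```python
import struct

def q4_0_pack(weights, scale):
    """Quantize 32 floats to a Q4_0 block: round(w/scale)+8, clamped to 0..15."""
    q = [max(0, min(15, round(w / scale) + 8)) for w in weights]
    packed = bytes((q[2 * j] & 0x0F) | ((q[2 * j + 1] & 0x0F) << 4)
                   for j in range(16))
    return struct.pack('<e', scale) + packed   # 2-byte FP16 scale + 16 bytes

def q4_0_dequant(block):
    scale = struct.unpack('<e', block[:2])[0]   # shared FP16 scale
    out = []
    for byte in block[2:]:
        out.append(scale * ((byte & 0x0F) - 8))  # low nibble
        out.append(scale * ((byte >> 4) - 8))    # high nibble
    return out

# 32 weights chosen to be exactly representable at scale 0.25
weights = [0.25 * ((i % 16) - 8) for i in range(32)]
blk = q4_0_pack(weights, scale=0.25)
assert len(blk) == 18                    # 4.5 bits/weight effective
assert q4_0_dequant(blk) == weights      # lossless for in-range grid values
```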
// Q4_0 GEMV: dequantize on the fly
kernel void gemv_q4_0(
    device const half*       x [[buffer(0)]],   // [K]
    device const block_q4_0* W [[buffer(1)]],   // [N x K/32] blocks
    device half*             y [[buffer(2)]],   // [N]
    constant uint&           K [[buffer(3)]],
    constant uint&           N [[buffer(4)]],
    uint sgid  [[simdgroup_index_in_threadgroup]],
    uint lane  [[thread_index_in_simdgroup]],
    uint tg_id [[threadgroup_position_in_grid]])
{
    uint n = tg_id * 8 + sgid;   // 8 SIMD groups per threadgroup
    if (n >= N) return;

    uint num_blocks = K / 32;
    float sum = 0.0f;

    // Each lane processes every 32nd block
    for (uint b = lane; b < num_blocks; b += 32) {
        // Load the quantized block
        device const block_q4_0& blk = W[n * num_blocks + b];
        float scale = float(blk.scale);

        // Dequantize and multiply all 32 weights in this block
        for (uint j = 0; j < 16; j++) {
            uchar packed = blk.packed[j];
            // Unpack two 4-bit values; [0,15] maps to [-8,7]
            int w0 = int(packed & 0x0F) - 8;   // Low nibble
            int w1 = int(packed >> 4) - 8;     // High nibble
            // Dequantize: weight = scale * quantized_value
            float dq0 = scale * float(w0);
            float dq1 = scale * float(w1);
            // Multiply with input
            uint k_idx = b * 32 + j * 2;
            sum += dq0 * float(x[k_idx]);
            sum += dq1 * float(x[k_idx + 1]);
        }
    }

    sum = simd_sum(sum);
    if (lane == 0) {
        y[n] = half(sum);
    }
}
The Bandwidth Win
The reason quantization is so important for GEMV performance is straightforward arithmetic. Since GEMV is bandwidth bound, reducing the data size directly speeds things up:
Bandwidth Savings with Quantization
====================================
Weight matrix: 4096 x 4096 = 16.7M weights
FP16: 16.7M * 2 bytes = 33.6 MB (baseline)
Q8_0: 16.7M * 1.06 bytes = 17.7 MB (1.9x faster)
Q4_0: 16.7M * 0.56 bytes = 9.4 MB (3.6x faster)
Q4_K: 16.7M * 0.57 bytes = 9.5 MB (3.5x faster)
Q2_K: 16.7M * 0.34 bytes = 5.7 MB (5.9x faster)
At 200 GB/s memory bandwidth:
FP16: 0.168 ms per layer
Q4_0: 0.047 ms per layer <-- 3.6x speedup!
Decode speed for 7B model:
FP16: 14.0 GB / 200 GB/s = 70 ms/token -> 14 tok/s
Q4_0: 3.9 GB / 200 GB/s = 19 ms/token -> 52 tok/s
The compute cost of dequantization is negligible compared to the bandwidth savings. A few shifts and multiplies per weight value are cheap when the alternative is waiting for twice as many bytes to arrive from memory.
Quantized GEMM: A Different Story
For GEMM (prefill), the situation is more nuanced. GEMM is compute bound, so reducing weight size does not automatically speed things up – you are not bandwidth limited in the first place. However, quantized GEMM is still useful because:
- It reduces memory footprint, letting larger models fit in memory
- For moderate M values (16-64), you might still be bandwidth bound
- The GPU's SIMD group matrix hardware operates on floating-point tiles and has no direct support for 4-bit operands, so weights must be expanded before the MMA anyway
The typical approach is to dequantize tiles of weights into threadgroup memory in FP16, then use the standard SIMD group MMA operations on the dequantized tiles:
Quantized GEMM: Dequantize into Threadgroup Memory
====================================================
Device Memory: Threadgroup Memory: Registers:
+-----------+ +----------+ +--------+
| Q4 Weights| --load-->| FP16 Tile| --MMA--> | Accum |
| (compact) | | (32x32) | | (8x8) |
+-----------+ +----------+ +--------+
1. Cooperatively load Q4 blocks from device memory
2. Dequantize to FP16 in threadgroup memory
3. Use simdgroup_multiply_accumulate on FP16 tiles
4. Repeat for all K-tiles
Worked Example: Full Thread Assignment
Let us trace through a complete example to see how threads are assigned for a real GEMV. Consider:
Problem: y = x * W
x: [1, 4096] (one token's hidden state)
W: [4096, 4096] (Q projection weight)
y: [1, 4096] (query vector)
Configuration:
- Threadgroup size: 256 threads = 8 SIMD groups x 32 lanes
- Each SIMD group: 1 output element
- Grid: 4096 / 8 = 512 threadgroups
Memory layout: W stored transposed as Wt [4096 x 4096]
Let us follow Threadgroup 0, SIMD group 3, Lane 17:
Thread Identity
===============
Threadgroup: 0
SIMD group: 3
Lane: 17
Output assignment
=================
output index n = threadgroup * 8 + simd_group = 0 * 8 + 3 = 3
This thread helps compute y[3].
Work assignment
===============
Lane 17 processes every 32nd element along K:
k = 17, 49, 81, 113, ..., 4081
Total iterations: 4096 / 32 = 128
For each iteration (e.g., k=17):
partial_sum += x[17] * Wt[3 * 4096 + 17]
After all 128 iterations:
partial_sum = x[17]*Wt[3,17] + x[49]*Wt[3,49] + ... + x[4081]*Wt[3,4081]
Reduction
=========
simd_sum(partial_sum) adds all 32 lanes' partial sums:
y[3] = sum over all k: x[k] * Wt[3, k]
= dot(x, column 3 of W)
Only lane 0 of SIMD group 3 writes y[3] to memory.
Summary for entire kernel
=========================
512 threadgroups * 8 SIMD groups = 4096 dot products
Each SIMD group: 32 threads * 128 iterations = 4096 multiply-adds
Total: 4096 * 4096 = 16.7M multiply-adds (= K * N, i.e. 33.6M FLOPs
at 2 FLOPs per multiply-add)
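The trace can be replayed in miniature. A Python sketch (illustrative; the input values are arbitrary) that simulates SIMD group 3 of threadgroup 0 computing y[3], with each of the 32 lanes striding the K dimension:

```python
K = 4096
x  = [((k * 7) % 13) / 13.0 for k in range(K)]   # arbitrary input vector
w3 = [((k * 5) % 11) / 11.0 for k in range(K)]   # row 3 of Wt (= column 3 of W)

# Threadgroup 0, SIMD group 3 -> this group owns output index n = 0*8 + 3 = 3
partials = []
for lane in range(32):
    s = 0.0
    for k in range(lane, K, 32):   # each lane handles every 32nd element of K
        s += x[k] * w3[k]
    partials.append(s)

y3 = sum(partials)                 # the simd_sum reduction across 32 lanes
direct = sum(xv * wv for xv, wv in zip(x, w3))
assert abs(y3 - direct) < 1e-6     # matches dot(x, column 3 of W)
assert len(range(17, K, 32)) == 128   # lane 17 performs 4096/32 = 128 mul-adds
```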
Choosing the Right Strategy
Here is a decision tree for choosing the right matmul strategy:
Matmul Strategy Decision Tree
==============================
What is M?
|
+-------------+-------------+
| | |
M = 1 2-16 > 16
| | |
GEMV Batched GEMV GEMM
| | |
+----+----+ +----+----+ +----+----+
| | | | | |
Small N Large N | Use SIMD
(4096) (32K+) Share Group
| | weight MMA
1 SIMD Lots of loads with tiling
group threadgroups |
per out +----+----+
| | |
Standard FP16 Quantized
SIMD GEMV tiles (dequant
+ MMA into TG
memory)
And here are rough performance expectations on Apple M-series GPUs:
Performance Expectations (M2 Pro, ~200 GB/s)
=============================================
Operation Size FP16 Q4_0
--------- ---- ---- ----
GEMV (decode, 1 layer) [1,4096]*[4096,4096] 0.17ms 0.05ms
GEMV (vocab projection) [1,4096]*[4096,32000] 1.28ms 0.36ms
Batched GEMV (M=4) [4,4096]*[4096,4096] 0.17ms 0.06ms
GEMM (prefill, M=512) [512,4096]*[4096,4096] ~2ms ~2ms
Note: Batched GEMV with M=4 is barely slower than M=1 GEMV because
the same weight data is read -- only the compute increases.
GEMM times are similar for FP16 and Q4 because GEMM is compute-bound.
Summary
Matrix multiplication is the heart of neural network inference. On Metal GPUs, you need different strategies depending on the shape:
- GEMV (M=1, decode): Bandwidth bound. Split the K dimension across SIMD lanes, reduce with simd_sum. Optimize for coalesced memory access and maximize bandwidth utilization.
- Batched GEMV (M=2-16): Still bandwidth bound but with some weight reuse. Load weights once, apply to all M rows. Significant speedup over M separate GEMVs.
- GEMM (M>16, prefill): Can be compute bound with proper tiling. Use cooperative loading into threadgroup memory and SIMD group MMA for maximum throughput.
- Quantized variants: For GEMV, quantization directly translates to speedup via reduced bandwidth. Dequantize on the fly in registers. For GEMM, dequantize into threadgroup memory tiles, then use standard MMA.
The next chapter will look at where these matmuls are used in the context of attention – where Q, K, and V come from matmuls, and then the attention computation itself introduces a different kind of matrix multiplication.