Matrix Multiplication Strategies

If there is one operation you need to understand deeply when running neural networks on GPUs, it is matrix multiplication. Not because it is conceptually hard – you learned it in linear algebra class – but because virtually everything in a transformer reduces to it. Every linear layer, every projection, every feed-forward network… all matmuls. If your matmul is slow, your model is slow. Period.

In this chapter we are going to build up from the textbook definition of matrix multiplication, examine why it matters so much for inference, and then dive into the different strategies you will encounter when targeting Metal GPUs. We will cover the two main regimes – GEMV and GEMM – plus all the messy in-between cases, and finally discuss how quantization changes the picture.

Why Matmul Is the Center of the Universe

Let us start with a concrete example. Consider a single linear layer in a transformer:

y = x * W + b

Here x is your input (say, a vector of dimension 4096), W is the weight matrix (say, 4096 x 11008), and b is a bias vector. That multiplication x * W is a matrix multiplication. Now count how many linear layers exist in a single transformer block:

  • Q projection: matmul
  • K projection: matmul
  • V projection: matmul
  • Output projection: matmul
  • Gate projection (FFN): matmul
  • Up projection (FFN): matmul
  • Down projection (FFN): matmul

That is seven matmuls per layer. A 32-layer model has 224 matmuls per forward pass. If each one takes 0.5ms, you are already at 112ms just for the matmuls – and that is before attention, normalization, or anything else.

  One Transformer Block
  =====================

  Input
    |
    v
  [LayerNorm]
    |
    +---> Q Projection  (matmul #1)
    +---> K Projection  (matmul #2)
    +---> V Projection  (matmul #3)
    |
    v
  [Attention Mechanism]
    |
    v
  [Output Projection]    (matmul #4)
    |
    v
  [Residual Add]
    |
    v
  [LayerNorm]
    |
    +---> Gate Projection (matmul #5)
    +---> Up Projection   (matmul #6)
    |
    v
  [SiLU + Elementwise Mul]
    |
    v
  [Down Projection]       (matmul #7)
    |
    v
  [Residual Add]
    |
    v
  Output

So when someone says “optimizing inference,” they mostly mean “optimizing matmul.”

The Two Regimes: GEMV vs GEMM

Here is the critical insight that changes everything about how you write GPU kernels for LLM inference: the shape of the matmul depends on what phase of inference you are in.

Token Generation (Decode): GEMV

During autoregressive decoding, you generate one token at a time. Your input x is a single vector of dimension d_model. The weight matrix W has shape [d_model, d_out]. So the multiplication is:

x:  [1, 4096]
W:  [4096, 11008]
y:  [1, 11008]

This is a matrix-vector multiplication, or GEMV (General Matrix-Vector multiply). In BLAS terminology, M=1 (only one row in the output). The defining characteristic of GEMV is that every element of the weight matrix is read exactly once. There is no data reuse. This makes GEMV fundamentally memory bandwidth bound – your performance is limited by how fast you can stream the weight matrix from memory.

  GEMV: One vector times a matrix
  ================================

  x = [x0, x1, x2, x3]    (1 x K)

       W = [ w00  w01  w02 ]
           [ w10  w11  w12 ]    (K x N)
           [ w20  w21  w22 ]
           [ w30  w31  w32 ]

  y = [y0, y1, y2]         (1 x N)

  y0 = x0*w00 + x1*w10 + x2*w20 + x3*w30
  y1 = x0*w01 + x1*w11 + x2*w21 + x3*w31
  y2 = x0*w02 + x1*w12 + x2*w22 + x3*w32

  Each weight is touched exactly once.
  Bottleneck = memory bandwidth.

Prompt Processing (Prefill): GEMM

During prefill, you process the entire input prompt at once. If the prompt has 512 tokens, your input is a matrix of shape [512, 4096]. Now the multiplication becomes:

X:  [512, 4096]
W:  [4096, 11008]
Y:  [512, 11008]

This is a full matrix-matrix multiplication, or GEMM (General Matrix-Matrix multiply). M=512, and now every element of W gets reused across 512 different input rows. This changes the arithmetic intensity dramatically – GEMM can be compute bound if you organize the data access well.

  GEMM: A matrix times a matrix
  ===============================

  X = [ x00 x01 x02 x03 ]
      [ x10 x11 x12 x13 ]    (M x K)
      [ x20 x21 x22 x23 ]

       W = [ w00  w01 ]
           [ w10  w11 ]       (K x N)
           [ w20  w21 ]
           [ w30  w31 ]

  Y = [ y00 y01 ]
      [ y10 y11 ]             (M x N)
      [ y20 y21 ]

  Each weight wij is used M times (once per row of X).
  With good tiling, this becomes compute-bound.

The Arithmetic Intensity Spectrum

Let us quantify this. Arithmetic intensity is the ratio of compute operations to memory bytes transferred:

  GEMV (M=1):
    Operations = 2 * K * N  (multiply + add for each element)
    Bytes read = K * N * sizeof(weight) + K * sizeof(input)
    Intensity  ~ 2 / sizeof(weight)
    For FP16:    2 / 2 = 1 FLOP/byte    --> bandwidth bound

  GEMM (M=512):
    Operations = 2 * M * K * N
    Bytes read = K * N * sizeof(weight) + M * K * sizeof(input)
    Intensity  ~ 2 * M / sizeof(weight)  (when K*N >> M*K)
    For FP16:    2 * 512 / 2 = 512 FLOP/byte  --> compute bound

Metal GPUs typically have 200-400 GB/s of memory bandwidth and 10-20 TFLOPS of FP16 compute. The crossover point – where you shift from bandwidth bound to compute bound – is roughly around M=16-64, depending on the specific hardware and data types.

  Arithmetic Intensity vs M
  =========================

  Intensity
  (FLOP/byte)
      |
  512 |                                          *  GEMM (M=512)
      |                                       *
      |                                    *
      |                                 *
      |                              *
   64 |                           *  <-- Compute bound above this line
      |                        *
      | - - - - - - - - - - -*- - - - - - HW Compute/BW ratio
   16 |                  *
      |               *
    4 |            *
    2 |         *
    1 |  *   *    <-- GEMV (M=1)
      +--+---+---+---+---+---+---+---+--->  M
         1   2   4   8  16  32  64  512
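To make the crossover concrete, here is a back-of-envelope sketch (Python, using the illustrative 200 GB/s and 10 TFLOPS figures from above; actual Apple Silicon numbers vary by chip):

```python
# Back-of-envelope roofline calculator for the GEMV/GEMM crossover.
# Hardware numbers are illustrative assumptions, not measurements.

def arithmetic_intensity(m, k, n, bytes_per_elem=2):
    """FLOPs per byte for an [m,k] x [k,n] matmul (FP16 weights + inputs)."""
    flops = 2 * m * k * n                                   # multiply + add
    bytes_read = k * n * bytes_per_elem + m * k * bytes_per_elem
    return flops / bytes_read

BANDWIDTH = 200e9   # bytes/s  (assumed)
COMPUTE   = 10e12   # FLOP/s   (assumed)
ridge = COMPUTE / BANDWIDTH  # 50 FLOP/byte: intensity needed to be compute bound

k, n = 4096, 11008
for m in (1, 16, 64, 512):
    ai = arithmetic_intensity(m, k, n)
    regime = "compute bound" if ai > ridge else "bandwidth bound"
    print(f"M={m:4d}: intensity {ai:7.1f} FLOP/byte -> {regime}")
```

With these numbers the ridge point sits at 50 FLOP/byte, which lands the crossover in the M=16-64 range quoted above.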

GEMV: Thread and SIMD Parallelism

Let us start with GEMV since it is the bread and butter of token generation. Remember, M=1, so we are computing y = x * W where x is a vector and W is a matrix. The output y has N elements, and each output element is a dot product of length K.

The Naive Approach

The simplest approach: assign one thread per output element. Each thread computes a full dot product of length K:

// Naive GEMV: one thread per output element
kernel void gemv_naive(
    device const float* x      [[buffer(0)]],   // [K]
    device const float* W      [[buffer(1)]],   // [K x N], row-major
    device float*       y      [[buffer(2)]],   // [N]
    constant uint&      K      [[buffer(3)]],
    constant uint&      N      [[buffer(4)]],
    uint                tid    [[thread_position_in_grid]])
{
    if (tid >= N) return;
    float sum = 0.0f;
    for (uint k = 0; k < K; k++) {
        sum += x[k] * W[k * N + tid];  // Column tid of W, stride-N per thread
    }
    y[tid] = sum;
}

This works but has terrible performance. Every one of the N threads independently reads all K elements of x (massively redundant reads), and each thread's walk down its column of W strides by N elements, so a single thread gets no spatial locality from one iteration to the next.

Coalesced Access with Transposed Weights

First improvement: store W in column-major order (or equivalently, store W^T in row-major order). Each thread now walks a contiguous row of Wt. On its own this mainly improves per-thread locality; the real payoff comes when a whole SIMD group shares one row of Wt, so that its 32 lanes read consecutive addresses – fully coalesced:

// GEMV with transposed weights
kernel void gemv_transposed(
    device const float* x      [[buffer(0)]],   // [K]
    device const float* Wt     [[buffer(1)]],   // [N x K], W transposed
    device float*       y      [[buffer(2)]],   // [N]
    constant uint&      K      [[buffer(3)]],
    constant uint&      N      [[buffer(4)]],
    uint                tid    [[thread_position_in_grid]])
{
    if (tid >= N) return;
    // Thread tid computes output element tid
    // Reads row tid of Wt, which is column tid of W
    float sum = 0.0f;
    for (uint k = 0; k < K; k++) {
        sum += x[k] * Wt[tid * K + k];
    }
    y[tid] = sum;
}

Better memory access pattern, but we still have N threads each independently reading the entire x vector. Let us fix that.

SIMD-Level Parallelism: Splitting the K Dimension

Here is the key idea for efficient GEMV on Metal: instead of one thread per output, use a group of threads (a SIMD group) to cooperatively compute one output element. Each thread in the SIMD group handles a slice of the K dimension, then we use simd_sum to reduce:

  SIMD GEMV: 32 threads cooperate on one dot product
  ====================================================

  x = [x0, x1, x2, ..., x31, x32, ..., x63, ...]

  Thread 0 handles: x[0]*w[0] + x[32]*w[32] + x[64]*w[64] + ...
  Thread 1 handles: x[1]*w[1] + x[33]*w[33] + x[65]*w[65] + ...
  ...
  Thread 31 handles: x[31]*w[31] + x[63]*w[63] + x[95]*w[95] + ...

  Then: simd_sum() adds all 32 partial sums in hardware
  Result: one output element computed cooperatively

// SIMD GEMV: one SIMD group per output element
constant uint SIMD_GROUPS_PER_TG = 8;  // 8 SIMD groups x 32 lanes = 256 threads/TG

kernel void gemv_simd(
    device const float* x      [[buffer(0)]],
    device const float* Wt     [[buffer(1)]],   // [N x K]
    device float*       y      [[buffer(2)]],
    constant uint&      K      [[buffer(3)]],
    constant uint&      N      [[buffer(4)]],
    uint                simd_lane  [[thread_index_in_simdgroup]],
    uint                simd_gid   [[simdgroup_index_in_threadgroup]],
    uint                tg_id      [[threadgroup_position_in_grid]])
{
    // Each SIMD group computes one output element
    uint n = tg_id * SIMD_GROUPS_PER_TG + simd_gid;
    if (n >= N) return;

    float partial = 0.0f;
    // Each lane processes every 32nd element; consecutive lanes
    // read consecutive addresses of row n (coalesced)
    for (uint k = simd_lane; k < K; k += 32) {
        partial += x[k] * Wt[n * K + k];
    }

    // Hardware reduction across the SIMD group
    float total = simd_sum(partial);

    // Only lane 0 writes the result
    if (simd_lane == 0) {
        y[n] = total;
    }
}

The simd_sum intrinsic is a hardware-level reduction. On Apple GPUs, a SIMD group (the counterpart of a CUDA “warp” or an AMD “wavefront”) is 32 threads that execute in lockstep. The simd_sum operation sums a value across all 32 lanes in just a few clock cycles – no shared memory, no barriers, no atomics. This is enormously powerful.

  simd_sum reduction (hardware butterfly)
  ========================================

  Lane:  0    1    2    3   ...  30   31
  Val:  2.1  0.5  1.3  0.8 ...  0.2  1.1

  Step 1: swap with neighbor, add
         2.6  2.6  2.1  2.1 ...  1.3  1.3

  Step 2: swap with stride-2 neighbor, add
         4.7  4.7  4.7  4.7 ...  ...  ...

  Step 3: stride 4 ...
  Step 4: stride 8 ...
  Step 5: stride 16 ...

  After 5 steps: all lanes hold the total sum.
  Cost: ~5 cycles. No memory traffic.
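The butterfly can be emulated on the CPU to see why five steps suffice (a Python sketch; the real simd_sum exchanges values through lane-shuffle hardware, not through an array):

```python
# Emulate the 5-step butterfly reduction that simd_sum performs across
# 32 lanes. At each step, lane i exchanges with lane (i XOR stride) and
# adds; after log2(32) = 5 steps every lane holds the full sum.

def simd_sum_butterfly(vals):
    assert len(vals) == 32
    lanes = list(vals)
    for step in range(5):                 # strides 1, 2, 4, 8, 16
        stride = 1 << step
        lanes = [lanes[i] + lanes[i ^ stride] for i in range(32)]
    return lanes

vals = [0.1 * i for i in range(32)]
result = simd_sum_butterfly(vals)
# After the loop, every lane holds the same value: the total sum
```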

Multiple Output Elements per SIMD Group

We can do even better. Instead of one output element per SIMD group, compute several. This amortizes the cost of reading x:

// SIMD GEMV: one SIMD group computes 4 output elements
kernel void gemv_simd_multi(
    device const float* x      [[buffer(0)]],
    device const float* Wt     [[buffer(1)]],   // [N x K]
    device float*       y      [[buffer(2)]],
    constant uint&      K      [[buffer(3)]],
    constant uint&      N      [[buffer(4)]],
    uint                simd_lane  [[thread_index_in_simdgroup]],
    uint                simd_gid   [[simdgroup_index_in_threadgroup]],
    uint                tg_id      [[threadgroup_position_in_grid]])
{
    uint n_base = (tg_id * SIMD_GROUPS_PER_TG + simd_gid) * 4;
    if (n_base + 3 >= N) return;  // Assumes N is a multiple of 4

    float4 partial = float4(0.0f);

    for (uint k = simd_lane; k < K; k += 32) {
        float xk = x[k];  // Load x[k] once, use for all 4 outputs
        partial[0] += xk * Wt[(n_base + 0) * K + k];
        partial[1] += xk * Wt[(n_base + 1) * K + k];
        partial[2] += xk * Wt[(n_base + 2) * K + k];
        partial[3] += xk * Wt[(n_base + 3) * K + k];
    }

    // Reduce each output element
    float y0 = simd_sum(partial[0]);
    float y1 = simd_sum(partial[1]);
    float y2 = simd_sum(partial[2]);
    float y3 = simd_sum(partial[3]);

    if (simd_lane == 0) {
        y[n_base + 0] = y0;
        y[n_base + 1] = y1;
        y[n_base + 2] = y2;
        y[n_base + 3] = y3;
    }
}

Now each SIMD group loads x[k] once and uses it four times. The weight reads are still one-shot (each weight used exactly once), but we have reduced the x reads by 4x.

GEMM: Tiling for Compute Efficiency

When M is large (prefill phase), we enter the world of GEMM. The key technique here is tiling: dividing the output matrix into tiles and assigning each tile to a threadgroup. Within the threadgroup, we further divide work among SIMD groups.

The Tiling Concept

  GEMM Tiling Overview
  =====================

  Output Y [M x N] is divided into tiles of size [Bm x Bn]

       N
  <----------->
  +----+----+----+----+  ^
  | TG | TG | TG | TG |  |
  | 00 | 01 | 02 | 03 |  |
  +----+----+----+----+  | M
  | TG | TG | TG | TG |  |
  | 10 | 11 | 12 | 13 |  |
  +----+----+----+----+  v

  Each TG computes a [Bm x Bn] tile of the output.

  To compute its tile, TG needs:
  - A strip of X: [Bm x K]
  - A strip of W: [K x Bn]

  But K can be huge (4096+), so we tile along K too:

       K
  <----------->
  +====+====+====+  <- X rows for this TG's tile
  | Bk | Bk | Bk |
  +====+====+====+

  Process K in chunks of Bk, accumulating partial results.
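The chunked accumulation is easy to sanity-check on the CPU. The sketch below (Python, not the Metal kernel) mirrors the same loop structure – one output tile at a time, K walked in chunks of Bk – and produces exactly the same result as a direct matmul:

```python
# CPU reference for the tiling scheme: the output is computed tile by
# tile, and each [Bm x Bn] tile accumulates partial products over
# K-chunks of size Bk -- the same loop structure the GPU kernel uses.

def matmul_tiled(X, W, Bm=4, Bn=4, Bk=4):
    M, K = len(X), len(X[0])
    N = len(W[0])
    Y = [[0.0] * N for _ in range(M)]
    for m0 in range(0, M, Bm):              # one "threadgroup" per output tile
        for n0 in range(0, N, Bn):
            for k0 in range(0, K, Bk):      # walk K in chunks of Bk
                for m in range(m0, min(m0 + Bm, M)):
                    for n in range(n0, min(n0 + Bn, N)):
                        for k in range(k0, min(k0 + Bk, K)):
                            Y[m][n] += X[m][k] * W[k][n]
    return Y
```

Reordering the loops this way changes nothing mathematically; the point on the GPU is that each K-chunk's operands fit in threadgroup memory.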

Cooperative Loading into Threadgroup Memory

The key insight is that threads within a threadgroup can cooperatively load a tile of data into fast threadgroup memory (shared memory), then all threads read from it. This converts slow device memory reads into fast threadgroup memory reads:

// Simplified tiled GEMM
kernel void gemm_tiled(
    device const half* X        [[buffer(0)]],  // [M x K]
    device const half* W        [[buffer(1)]],  // [K x N]
    device half*       Y        [[buffer(2)]],  // [M x N]
    constant uint&     M        [[buffer(3)]],
    constant uint&     N        [[buffer(4)]],
    constant uint&     K        [[buffer(5)]],
    uint2              tg_pos   [[threadgroup_position_in_grid]],
    uint               tid      [[thread_index_in_threadgroup]])
{
    // Tile dimensions
    const uint Bm = 32;  // Tile rows
    const uint Bn = 32;  // Tile cols
    const uint Bk = 16;  // K-tile size

    // Threadgroup memory for tiles
    threadgroup half X_tile[Bm * Bk];
    threadgroup half W_tile[Bk * Bn];

    // This threadgroup computes output tile at (tg_pos.y * Bm, tg_pos.x * Bn)
    uint m_base = tg_pos.y * Bm;
    uint n_base = tg_pos.x * Bn;

    // Accumulator for this thread's output elements
    float acc[4] = {0, 0, 0, 0};  // Each thread computes 4 elements

    // Walk along K in steps of Bk
    for (uint k_base = 0; k_base < K; k_base += Bk) {

        // === Cooperative load: all threads load tiles into shared memory ===
        // Each tile holds Bm*Bk = Bk*Bn = 512 halves; with 256 threads
        // per threadgroup, each thread loads two elements of each tile
        for (uint i = tid; i < Bm * Bk; i += 256) {
            uint row = i / Bk;
            uint col = i % Bk;
            X_tile[i] = X[(m_base + row) * K + (k_base + col)];
        }
        for (uint i = tid; i < Bk * Bn; i += 256) {
            uint row = i / Bn;
            uint col = i % Bn;
            W_tile[i] = W[(k_base + row) * N + (n_base + col)];
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // === Compute: multiply tiles ===
        // Each thread computes its assigned output elements
        // using data from threadgroup memory (fast!)
        uint my_m = tid / 8;   // Which row within tile
        uint my_n = (tid % 8) * 4;  // Which 4 columns within tile

        for (uint kk = 0; kk < Bk; kk++) {
            half x_val = X_tile[my_m * Bk + kk];
            for (uint j = 0; j < 4; j++) {
                acc[j] += float(x_val) * float(W_tile[kk * Bn + my_n + j]);
            }
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Write results
    uint my_m = tid / 8;
    uint my_n = (tid % 8) * 4;
    for (uint j = 0; j < 4; j++) {
        Y[(m_base + my_m) * N + (n_base + my_n + j)] = half(acc[j]);
    }
}

Let us trace through what happens for one K-tile:

  One iteration of the K-loop (k_base = 0, Bk = 16)
  ===================================================

  Step 1: Cooperative Load
  ~~~~~~~~~~~~~~~~~~~~~~~~
  All 256 threads in the threadgroup work together to load:

  X_tile [32 x 16]:                    W_tile [16 x 32]:
  Rows m_base..m_base+31               Rows 0..15
  Cols 0..15 of X                      Cols n_base..n_base+31 of W

  The 256 threads split the 512 elements of each tile so that
  every element of X_tile and W_tile is loaded exactly once
  (two elements per thread per tile).

  Step 2: Barrier
  ~~~~~~~~~~~~~~~
  Wait for all loads to complete.

  Step 3: Compute
  ~~~~~~~~~~~~~~~
  Each thread reads from the fast threadgroup memory tiles
  and accumulates partial results.

  Thread 0 computes: Y[m_base+0][n_base+0..3] += X_tile[0][k] * W_tile[k][0..3]
  Thread 1 computes: Y[m_base+0][n_base+4..7] += X_tile[0][k] * W_tile[k][4..7]
  ...
  Thread 8 computes: Y[m_base+1][n_base+0..3] += X_tile[1][k] * W_tile[k][0..3]
  ...

  Step 4: Barrier
  ~~~~~~~~~~~~~~~
  Wait before overwriting tiles with next K-chunk.

SIMD Group Matrix Multiply-Accumulate (MMA)

Metal on Apple Silicon supports SIMD group matrix operations (available since MSL 2.3 on Apple-family GPUs) that are much faster than manually computing the multiply-accumulate. These map to dedicated matrix-multiply hardware inside the GPU cores – distinct from the CPU-side AMX coprocessor:

#include <metal_simdgroup_matrix>

// Using SIMD group MMA for tiled GEMM
kernel void gemm_simdgroup_mma(
    device const half* X        [[buffer(0)]],
    device const half* W        [[buffer(1)]],
    device half*       Y        [[buffer(2)]],
    constant uint&     M        [[buffer(3)]],
    constant uint&     N        [[buffer(4)]],
    constant uint&     K        [[buffer(5)]],
    uint2              tg_pos   [[threadgroup_position_in_grid]],
    uint               sgid     [[simdgroup_index_in_threadgroup]],
    uint               lane     [[thread_index_in_simdgroup]])
{
    // Each SIMD group computes an 8x8 tile of the output
    // using 8x8 matrix operations

    // Declare SIMD group matrices
    simdgroup_matrix<half, 8, 8> x_mat;
    simdgroup_matrix<half, 8, 8> w_mat;

    // Accumulator, initialized to zero
    simdgroup_matrix<half, 8, 8> acc_mat =
        make_filled_simdgroup_matrix<half, 8, 8>(half(0));

    uint m_base = tg_pos.y * 32 + (sgid / 4) * 8;
    uint n_base = tg_pos.x * 32 + (sgid % 4) * 8;

    for (uint k = 0; k < K; k += 8) {
        // Load 8x8 tiles from X and W
        simdgroup_load(x_mat, X + m_base * K + k, K);
        simdgroup_load(w_mat, W + k * N + n_base, N);

        // 8x8 matrix multiply-accumulate in hardware
        simdgroup_multiply_accumulate(acc_mat, x_mat, w_mat, acc_mat);
    }

    // Store the 8x8 result tile
    simdgroup_store(acc_mat, Y + m_base * N + n_base, N);
}

The simdgroup_multiply_accumulate function computes an 8x8 matrix multiply in hardware. A single SIMD group of 32 threads cooperatively holds the 8x8 matrices (each thread holds 2 elements) and the multiplication happens in the dedicated matrix hardware. This is dramatically faster than doing the multiply-add manually.

  SIMD Group MMA: 32 threads own an 8x8 matrix
  ===============================================

  The 8x8 = 64 elements are distributed across 32 threads,
  two per thread. The exact layout is opaque to the
  programmer; one illustrative mapping:
  Thread 0:  elements [0,0] and [0,1]
  Thread 1:  elements [0,2] and [0,3]
  ...
  Thread 4:  elements [1,0] and [1,1]
  ...
  Thread 31: elements [7,6] and [7,7]

  simdgroup_multiply_accumulate(C, A, B, C):
    C += A * B    (a full 8x8 MMA in a few cycles)

  This maps to dedicated matrix hardware in the GPU cores.

Wide GEMV: The Vocabulary Projection Problem

There is one particular GEMV that deserves special attention: the final vocabulary projection. In a typical LLM, this multiplies the hidden state (dimension 4096) by the vocabulary embedding matrix to produce logits over the entire vocabulary:

hidden: [1, 4096]
vocab_weights: [4096, 32000]    (or 128256 for Llama 3!)
logits: [1, 32000]

This is still a GEMV (M=1), but N is enormous – 32K to 128K output elements. The challenge is that you need enough parallelism to saturate the GPU, and you need to read a very large weight matrix.

The strategy is to parallelize aggressively across the N dimension:

  Wide GEMV for Vocab Projection
  ===============================

  N = 32000 output logits
  K = 4096

  Approach: assign SIMD groups to output elements

  Threadgroup 0:     SIMD groups compute outputs [0..7]
  Threadgroup 1:     SIMD groups compute outputs [8..15]
  ...
  Threadgroup 3999:  SIMD groups compute outputs [31992..31999]

  4000 threadgroups x 8 SIMD groups/TG x 32 lanes = 1.02M threads
  Each SIMD group: 32 lanes reduce one dot product of length 4096
  That is 4096/32 = 128 multiply-adds per lane

  Total weight data: 4096 * 32000 * 2 bytes = 262 MB (FP16)
  At 200 GB/s: ~1.3ms to stream all weights
  This sets the floor for GEMV performance.
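The floor in the box above is just division; a small sketch (Python, same assumed 200 GB/s figure):

```python
# Bandwidth floor for a GEMV: the kernel cannot finish faster than the
# time needed to stream the weight matrix from memory once.

def gemv_floor_ms(k, n, bytes_per_weight, bandwidth_gbs):
    weight_bytes = k * n * bytes_per_weight
    return weight_bytes / (bandwidth_gbs * 1e9) * 1e3  # milliseconds

# FP16 vocab projection, K=4096, N=32000, at an assumed 200 GB/s
fp16_ms = gemv_floor_ms(4096, 32000, 2, 200)     # ~1.3 ms
# The Llama 3 vocabulary (N=128256) raises the floor roughly 4x
llama3_ms = gemv_floor_ms(4096, 128256, 2, 200)  # ~5.3 ms
```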

The key insight: for wide GEMV, you want each SIMD group to handle one (or a few) output columns, with the 32 lanes splitting the K-dimension reduction. The x vector is small enough to fit in threadgroup memory or even registers, so you load it once and reuse it for all output columns in the threadgroup.

// Wide GEMV: optimized for large N (vocab projection)
kernel void gemv_wide(
    device const half*  x       [[buffer(0)]],   // [K]
    device const half*  Wt      [[buffer(1)]],   // [N x K]
    device half*        y       [[buffer(2)]],   // [N]
    constant uint&      K       [[buffer(3)]],
    constant uint&      N       [[buffer(4)]],
    uint                sgid    [[simdgroup_index_in_threadgroup]],
    uint                lane    [[thread_index_in_simdgroup]],
    uint                tg_id   [[threadgroup_position_in_grid]])
{
    // Load x into threadgroup memory (cooperative load)
    threadgroup half x_shared[4096];  // Assumes K <= 4096
    uint tid = sgid * 32 + lane;
    for (uint i = tid; i < K; i += 256) {  // 256 threads per TG
        x_shared[i] = x[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Each SIMD group handles one output element
    uint n = tg_id * 8 + sgid;  // 8 SIMD groups per TG
    if (n >= N) return;

    float sum = 0.0f;
    device const half* w_row = Wt + n * K;

    for (uint k = lane; k < K; k += 32) {
        sum += float(x_shared[k]) * float(w_row[k]);
    }

    sum = simd_sum(sum);
    if (lane == 0) {
        y[n] = half(sum);
    }
}

Batched GEMV: The Awkward Middle Ground

In practice, inference is not always purely M=1 or M=512. There are several scenarios where M is small but greater than 1:

  1. Batch decoding: serving multiple users simultaneously (M = batch_size, often 2-16)
  2. Speculative decoding: verifying multiple candidate tokens at once (M = num_candidates, often 4-8)
  3. Small prompts: short prefill with just a few tokens

For these cases, M is too small for efficient GEMM (not enough reuse to be compute bound) but too large for pure GEMV (we want some reuse of the weight data).

The solution is batched GEMV: treat it as M independent GEMVs but share weight loads across them:

  Batched GEMV (M=4)
  ===================

  X = [ x0 ]     (4 rows, each a vector of length K)
      [ x1 ]
      [ x2 ]
      [ x3 ]

  Strategy: Each SIMD group handles multiple output columns
  and ALL M rows simultaneously.

  SIMD Group 0 processing output column n=0:

  Lane:    0            1            2          ...
  Row0:  x0[0]*w[0]   x0[1]*w[1]   x0[2]*w[2]   ...
  Row1:  x1[0]*w[0]   x1[1]*w[1]   x1[2]*w[2]   ...
  Row2:  x2[0]*w[0]   x2[1]*w[1]   x2[2]*w[2]   ...
  Row3:  x3[0]*w[0]   x3[1]*w[1]   x3[2]*w[2]   ...

  Each weight w[k] is loaded once and used for M=4 rows.
  Weight reuse factor: 4x compared to pure GEMV.

// Batched GEMV: M rows share weight loads
kernel void gemv_batched(
    device const half*  X       [[buffer(0)]],   // [M x K]
    device const half*  Wt      [[buffer(1)]],   // [N x K]
    device half*        Y       [[buffer(2)]],   // [M x N]
    constant uint&      M       [[buffer(3)]],
    constant uint&      K       [[buffer(4)]],
    constant uint&      N       [[buffer(5)]],
    uint                sgid    [[simdgroup_index_in_threadgroup]],
    uint                lane    [[thread_index_in_simdgroup]],
    uint                tg_id   [[threadgroup_position_in_grid]])
{
    uint n = tg_id * 8 + sgid;  // Output column
    if (n >= N) return;

    // Accumulators for each row
    float sums[8] = {0};  // Support up to M=8

    device const half* w_row = Wt + n * K;

    for (uint k = lane; k < K; k += 32) {
        half w_val = w_row[k];  // Load weight ONCE

        // Apply to all M rows
        for (uint m = 0; m < M; m++) {
            sums[m] += float(X[m * K + k]) * float(w_val);
        }
    }

    // Reduce each row's sum
    for (uint m = 0; m < M; m++) {
        float total = simd_sum(sums[m]);
        if (lane == 0) {
            Y[m * N + n] = half(total);
        }
    }
}

The performance gain over running M separate GEMVs is significant because:

  1. Weight data is loaded once instead of M times
  2. Kernel launch overhead is reduced
  3. The GPU can be better saturated with the additional work

  Performance: M separate GEMVs vs Batched GEMV
  ===============================================

  M=4, K=4096, N=4096, FP16 weights

  M separate GEMVs:
    Weight reads: 4 * 4096 * 4096 * 2 bytes = 134 MB
    Time at 200 GB/s: ~0.67 ms

  Batched GEMV:
    Weight reads: 1 * 4096 * 4096 * 2 bytes = 33.6 MB
    Time at 200 GB/s: ~0.17 ms

  4x improvement from weight reuse!
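The same reasoning can be written down as a tiny traffic model (a Python sketch; sizes in decimal MB):

```python
# Weight-traffic model: batched GEMV reads the weight matrix once,
# M separate GEMVs read it M times.

def weight_traffic_mb(m, k, n, bytes_per_weight=2, batched=True):
    reads = 1 if batched else m
    return reads * k * n * bytes_per_weight / 1e6

separate = weight_traffic_mb(4, 4096, 4096, batched=False)  # ~134 MB
batched  = weight_traffic_mb(4, 4096, 4096, batched=True)   # ~33.6 MB
speedup  = separate / batched                                # exactly M = 4
```

Because GEMV is bandwidth bound, the traffic ratio translates almost directly into the runtime ratio.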

How Quantization Changes the Game

Everything we have discussed so far assumes FP16 weights. But in practice, most inference deployments use quantized weights – 4-bit or even 2-bit. This changes the matmul kernels significantly.

The Core Challenge: Dequantize on the Fly

Quantized weights are stored in a compressed format. You cannot directly multiply them with the input – you first need to dequantize them back to a floating-point representation. The question is: where and when do you dequantize?

The answer is on-the-fly dequantization: each thread dequantizes just the weight values it needs, right before multiplying them with the input. The dequantized values live only in registers, never written back to memory.

  Dequantize-on-the-fly GEMV
  ===========================

  Memory:  [...Q4 packed weights...]  (4 bits per weight)
                    |
                    v  (load)
  Registers: [packed_byte]
                    |
                    v  (unpack + scale)
  Registers: [fp16_weight_a, fp16_weight_b]
                    |
                    v  (multiply with input)
  Registers: [partial_sum]
                    |
                    v  (simd_sum)
  Memory:   [output]

  Key: dequantized weights never touch memory.
  We trade compute (dequantize) for bandwidth (smaller reads).

Q4_0 GEMV Example

Let us work through a concrete example with Q4_0 quantization (the simplest format). In Q4_0, every block of 32 weights shares a single FP16 scale factor. Each weight is stored as a 4-bit integer (0-15), representing the range [-8, 7] after subtracting 8:

// Q4_0 block structure
struct block_q4_0 {
    half    scale;         // 2 bytes: shared scale for 32 weights
    uchar   packed[16];   // 16 bytes: 32 x 4-bit values packed into pairs
};
// Total: 18 bytes for 32 weights (4.5 bits/weight effective)

// Q4_0 GEMV: dequantize on the fly
kernel void gemv_q4_0(
    device const half*         x       [[buffer(0)]],  // [K]
    device const block_q4_0*   W       [[buffer(1)]],  // [N x K/32] blocks
    device half*               y       [[buffer(2)]],  // [N]
    constant uint&             K       [[buffer(3)]],
    constant uint&             N       [[buffer(4)]],
    uint                       sgid    [[simdgroup_index_in_threadgroup]],
    uint                       lane    [[thread_index_in_simdgroup]],
    uint                       tg_id   [[threadgroup_position_in_grid]])
{
    uint n = tg_id * 8 + sgid;  // 8 SIMD groups per threadgroup
    if (n >= N) return;
    uint num_blocks = K / 32;

    float sum = 0.0f;

    // Each lane processes every 32nd block
    for (uint b = lane; b < num_blocks; b += 32) {
        // Load the quantized block
        device const block_q4_0& blk = W[n * num_blocks + b];
        half scale = blk.scale;

        // Dequantize and multiply all 32 weights in this block
        for (uint j = 0; j < 16; j++) {
            uchar packed = blk.packed[j];

            // Unpack two 4-bit values
            int8_t w0 = (packed & 0x0F) - 8;  // Low nibble
            int8_t w1 = (packed >> 4)   - 8;   // High nibble

            // Dequantize: weight = scale * quantized_value
            float dq0 = float(scale) * float(w0);
            float dq1 = float(scale) * float(w1);

            // Multiply with input
            uint k_idx = b * 32 + j * 2;
            sum += dq0 * float(x[k_idx]);
            sum += dq1 * float(x[k_idx + 1]);
        }
    }

    sum = simd_sum(sum);
    if (lane == 0) {
        y[n] = half(sum);
    }
}
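A CPU reference of the block format is useful for verifying the unpacking logic. This Python sketch follows the nibble order used by the kernel above (byte j holds weights 2j and 2j+1); note that llama.cpp's actual Q4_0 packs a block's first 16 weights in the low nibbles and the last 16 in the high nibbles, so real files need the matching order:

```python
# Reference Q4_0-style quantize/dequantize for one block of 32 weights.
# A simple symmetric scheme: codes 0..15 represent (code - 8) * scale.
# Nibble order matches the kernel above, not llama.cpp's file layout.

def q4_0_quantize(weights):
    """Quantize 32 floats into (scale, 16 packed bytes)."""
    assert len(weights) == 32
    amax = max(abs(w) for w in weights) or 1.0
    scale = amax / 8.0
    codes = [max(0, min(15, round(w / scale) + 8)) for w in weights]
    packed = bytes(codes[2 * j] | (codes[2 * j + 1] << 4) for j in range(16))
    return scale, packed

def q4_0_dequantize(scale, packed):
    out = []
    for byte in packed:
        out.append(scale * ((byte & 0x0F) - 8))   # low nibble: weight 2j
        out.append(scale * ((byte >> 4) - 8))     # high nibble: weight 2j+1
    return out
```

Round-tripping a block shows the quantization error is bounded by roughly one scale step, which is the price paid for the 3.6x bandwidth reduction.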

The Bandwidth Win

The reason quantization is so important for GEMV performance is straightforward arithmetic. Since GEMV is bandwidth bound, reducing the data size directly speeds things up:

  Bandwidth Savings with Quantization
  ====================================

  Weight matrix: 4096 x 4096 = 16.7M weights

  FP16:  16.7M * 2 bytes   = 33.6 MB    (baseline)
  Q8_0:  16.7M * 1.06 bytes = 17.7 MB   (1.9x faster)
  Q4_0:  16.7M * 0.56 bytes =  9.4 MB   (3.6x faster)
  Q4_K:  16.7M * 0.57 bytes =  9.5 MB   (3.5x faster)
  Q2_K:  16.7M * 0.34 bytes =  5.7 MB   (5.9x faster)

  At 200 GB/s memory bandwidth:
  FP16:  0.168 ms per layer
  Q4_0:  0.047 ms per layer  <-- 3.6x speedup!

  Decode speed for 7B model:
  FP16:  14.0 GB / 200 GB/s = 70 ms/token  -> 14 tok/s
  Q4_0:   3.9 GB / 200 GB/s = 19 ms/token  -> 52 tok/s

The compute cost of dequantization is negligible compared to the bandwidth savings. A few shifts and multiplies per weight value are cheap when the alternative is waiting for twice as many bytes to arrive from memory.
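The effective bytes-per-weight figures fall straight out of the block layouts (a Python sketch; 34 and 18 bytes per 32-weight block for Q8_0 and Q4_0 respectively):

```python
# Effective bytes-per-weight for block-quantized formats: each block
# stores a small header (scale factors) plus packed integer codes.

def bytes_per_weight(block_bytes, weights_per_block):
    return block_bytes / weights_per_block

q8_0 = bytes_per_weight(34, 32)   # 2-byte scale + 32 x int8  -> 1.0625
q4_0 = bytes_per_weight(18, 32)   # 2-byte scale + 16 bytes   -> 0.5625

def layer_mb(k, n, bpw):
    return k * n * bpw / 1e6

fp16_mb = layer_mb(4096, 4096, 2.0)   # ~33.6 MB (the table's baseline)
q4_mb   = layer_mb(4096, 4096, q4_0)  # ~9.4 MB
```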

Quantized GEMM: A Different Story

For GEMM (prefill), the situation is more nuanced. GEMM is compute bound, so reducing weight size does not automatically speed things up – you are not bandwidth limited in the first place. However, quantized GEMM is still useful because:

  1. It reduces memory footprint, letting larger models fit in memory
  2. For moderate M values (16-64), you might still be bandwidth bound
  3. The GPU’s SIMD group matrix units operate on FP16/BF16 tiles, with no direct support for packed 4-bit operands

The typical approach is to dequantize tiles of weights into threadgroup memory in FP16, then use the standard SIMD group MMA operations on the dequantized tiles:

  Quantized GEMM: Dequantize into Threadgroup Memory
  ====================================================

  Device Memory:          Threadgroup Memory:       Registers:
  +-----------+           +----------+              +--------+
  | Q4 Weights|  --load-->| FP16 Tile|  --MMA-->   | Accum  |
  | (compact) |           | (32x32)  |              | (8x8)  |
  +-----------+           +----------+              +--------+

  1. Cooperatively load Q4 blocks from device memory
  2. Dequantize to FP16 in threadgroup memory
  3. Use simdgroup_multiply_accumulate on FP16 tiles
  4. Repeat for all K-tiles

Worked Example: Full Thread Assignment

Let us trace through a complete example to see how threads are assigned for a real GEMV. Consider:

  Problem: y = x * W
  x:  [1, 4096]   (one token's hidden state)
  W:  [4096, 4096] (Q projection weight)
  y:  [1, 4096]   (query vector)

  Configuration:
  - Threadgroup size: 256 threads = 8 SIMD groups x 32 lanes
  - Each SIMD group: 1 output element
  - Grid: 4096 / 8 = 512 threadgroups

  Memory layout: W stored transposed as Wt [4096 x 4096]

Let us follow Threadgroup 0, SIMD group 3, Lane 17:

  Thread Identity
  ===============
  Threadgroup:  0
  SIMD group:   3
  Lane:         17

  Output assignment
  =================
  output index n = threadgroup * 8 + simd_group = 0 * 8 + 3 = 3
  This thread helps compute y[3].

  Work assignment
  ===============
  Lane 17 processes every 32nd element along K:
    k = 17, 49, 81, 113, ..., 4081

  Total iterations: 4096 / 32 = 128

  For each iteration (e.g., k=17):
    partial_sum += x[17] * Wt[3 * 4096 + 17]

  After all 128 iterations:
    partial_sum = x[17]*Wt[3,17] + x[49]*Wt[3,49] + ... + x[4081]*Wt[3,4081]

  Reduction
  =========
  simd_sum(partial_sum) adds all 32 lanes' partial sums:
    y[3] = sum over all k: x[k] * Wt[3, k]
         = dot(x, column 3 of W)

  Only lane 0 of SIMD group 3 writes y[3] to memory.

  Summary for entire kernel
  =========================
  512 threadgroups * 8 SIMD groups = 4096 dot products
  Each SIMD group: 32 threads * 128 iterations = 4096 multiply-adds
  Total: 4096 * 4096 = 16.8M multiply-adds (= K*N; equivalently 2*K*N = 33.6M FLOPs)

Choosing the Right Strategy

Here is a decision tree for choosing the right matmul strategy:

  Matmul Strategy Decision Tree
  ==============================

                     What is M?
                        |
          +-------------+-------------+
          |             |             |
        M = 1        M = 2-16      M > 16
          |             |             |
        GEMV      Batched GEMV      GEMM
          |             |             |
     +----+----+     Share       SIMD group
     |         |     weight      MMA with
   Small N  Large N  loads       tiling
   (4096)   (32K+)   across         |
     |         |     M rows    +----+----+
   1 SIMD   Many               |         |
   group    thread-          FP16    Quantized
   per out  groups           tiles   (dequant
     |                       + MMA   into TG
   Standard                          memory)
   SIMD GEMV
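
The tree can also be written as a function. The thresholds below (M = 16, the ~32K cutoff for "large N") come straight from the figure; treat them as heuristics to tune per device, not hard rules.

```python
# Encode the decision tree above as a strategy picker.

def pick_matmul_strategy(m, n, quantized=False):
    if m == 1:
        # GEMV regime: large-N outputs (e.g. vocab projection) want many
        # threadgroups; otherwise one SIMD group per output element.
        return "GEMV: many threadgroups" if n >= 32_000 else "GEMV: 1 SIMD group per output"
    if m <= 16:
        return "Batched GEMV: share weight loads across M rows"
    # GEMM regime: tiled SIMD group MMA, dequantizing into threadgroup
    # memory first when the weights are quantized.
    if quantized:
        return "GEMM: dequantize into threadgroup memory, then MMA"
    return "GEMM: FP16 tiles + simdgroup MMA"
```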

And here are rough performance expectations on Apple M-series GPUs:

  Performance Expectations (M2 Pro, ~200 GB/s)
  =============================================

  Operation                  Size                      FP16      Q4_0
  ---------                  ----                      ----      ----
  GEMV (decode, 1 layer)     [1,4096]*[4096,4096]      0.17ms    0.05ms
  GEMV (vocab projection)    [1,4096]*[4096,32000]     1.28ms    0.36ms
  Batched GEMV (M=4)         [4,4096]*[4096,4096]      0.17ms    0.06ms
  GEMM (prefill, M=512)      [512,4096]*[4096,4096]    ~2ms      ~2ms

  Note: Batched GEMV with M=4 is barely slower than M=1 GEMV because
  the same weight data is read -- only the compute increases.
  GEMM times are similar for FP16 and Q4 because GEMM is compute-bound.
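
The FP16 GEMV rows follow directly from the bandwidth model. The sketch below assumes ~200 GB/s and 2 bytes per FP16 weight, and ignores activation traffic (negligible next to the weight matrix); small differences from the table come from rounding.

```python
# Bandwidth-bound GEMV time: bytes of weights streamed / bandwidth.

def gemv_ms(k, n, bytes_per_weight, bw_gb_s=200.0):
    return k * n * bytes_per_weight / (bw_gb_s * 1e9) * 1e3

print(f"decode layer:     {gemv_ms(4096, 4096, 2.0):.2f} ms")
print(f"vocab projection: {gemv_ms(4096, 32000, 2.0):.2f} ms")
```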

Summary

Matrix multiplication is the heart of neural network inference. On Metal GPUs, you need different strategies depending on the shape:

  1. GEMV (M=1, decode): Bandwidth bound. Split the K dimension across SIMD lanes, reduce with simd_sum. Optimize for coalesced memory access and maximize bandwidth utilization.

  2. Batched GEMV (M=2-16): Still bandwidth bound but with some weight reuse. Load weights once, apply to all M rows. Significant speedup over M separate GEMVs.

  3. GEMM (M>16, prefill): Can be compute bound with proper tiling. Use cooperative loading into threadgroup memory and SIMD group MMA for maximum throughput.

  4. Quantized variants: For GEMV, quantization directly translates to speedup via reduced bandwidth. Dequantize on-the-fly in registers. For GEMM, dequantize into threadgroup memory tiles, then use standard MMA.

The next chapter will look at how these matmuls fit into attention: Q, K, and V are themselves produced by matmuls, and the attention computation then introduces a different kind of matrix multiplication – one between activations rather than against fixed weights.