
Tensors and Linear Algebra on GPU

If you have spent any time reading ML papers or browsing inference codebases, you have seen the word “tensor” more times than you can count. It gets thrown around so casually that you might suspect it means something terribly complicated. It does not. At its heart, a tensor in the machine learning context is just a multi-dimensional array of numbers. That is it. A 1D tensor is a vector. A 2D tensor is a matrix. A 3D tensor is… well, a 3D array. And so on.

But here is the thing: while the concept is simple, the implementation on a GPU is where all the interesting engineering lives. How do you lay out a 4-dimensional array in a flat slab of GPU memory? How do you make sure that when 3,000 threads all try to read different elements at once, the memory system does not choke? How do you express broadcasting rules in terms of pointer arithmetic? These are the questions that determine whether your inference engine runs at 10 tokens per second or 100.

In this chapter, we will build your intuition for tensors from the ground up, then show exactly how they map to Metal GPU buffers. By the end, you will understand the memory layout decisions that dominate every kernel we write in later chapters.

What Is a Tensor, Really?

Let us start with the basics and build up.

Scalars, Vectors, Matrices, and Beyond

Rank 0 (Scalar):     42.0

Rank 1 (Vector):     [1.0, 2.0, 3.0, 4.0]

Rank 2 (Matrix):     [[1.0, 2.0, 3.0],
                       [4.0, 5.0, 6.0]]

Rank 3 (3D Tensor):  [[[1.0, 2.0],
                        [3.0, 4.0],
                        [5.0, 6.0]],

                       [[7.0, 8.0],
                        [9.0, 10.0],
                        [11.0, 12.0]]]

The rank (or order) of a tensor is simply how many indices you need to pick out a single element:

  • Rank 0: zero indices. It is just a number.
  • Rank 1: one index. v[i] gives you element i of a vector.
  • Rank 2: two indices. M[i][j] gives you the element at row i, column j.
  • Rank 3: three indices. T[i][j][k] picks one element from a 3D block.

The shape describes how many elements exist along each dimension. For the rank-3 tensor above, the shape is (2, 3, 2) – two “slices,” each containing 3 rows of 2 elements.

Tensors in Transformers

In a typical transformer model, you encounter tensors of various ranks constantly:

Token embeddings:       shape (seq_len, d_model)         Rank 2
Weight matrix:          shape (d_model, d_out)            Rank 2
Attention scores:       shape (n_heads, seq_len, seq_len) Rank 3
Batched activations:    shape (batch, seq_len, d_model)   Rank 3
Multi-head QKV:         shape (batch, n_heads, seq, d_k)  Rank 4

For inference on a single sequence (batch size 1), we can often drop the batch dimension. But even without batching, we regularly work with rank-2 and rank-3 tensors.

Here is a concrete example. Suppose we have a small model with:

  • d_model = 512 (embedding dimension)
  • n_heads = 8 (number of attention heads)
  • d_k = 64 (per-head dimension, since 512 / 8 = 64)
  • seq_len = 128 (sequence length)

The query tensor Q after projection has shape (128, 512). Reshaping it to (128, 8, 64) and then transposing the first two axes gives the per-head layout (8, 128, 64).

Memory Layouts: How Multi-Dimensional Data Lives in Flat Memory

Here is the fundamental problem: GPU memory (and CPU memory, for that matter) is a flat, one-dimensional address space. You have addresses 0, 1, 2, 3, and so on. But your tensor is multi-dimensional. So you need a rule for mapping multi-dimensional indices to flat memory addresses.

Row-Major Order (C-Style)

The most common layout, and the one used by virtually all ML frameworks and inference engines, is row-major order. The idea is simple: you lay out the last dimension contiguously, then the second-to-last, and so on.

For a 2D matrix with shape (3, 4):

Logical view:           Memory layout (row-major):

     col 0  col 1  col 2  col 3
row 0 [ a      b      c      d  ]     addr 0:  a
row 1 [ e      f      g      h  ]     addr 1:  b
row 2 [ i      j      k      l  ]     addr 2:  c
                                       addr 3:  d
                                       addr 4:  e
                                       addr 5:  f
                                       addr 6:  g
                                       addr 7:  h
                                       addr 8:  i
                                       addr 9:  j
                                       addr 10: k
                                       addr 11: l

The elements of row 0 are contiguous in memory (a, b, c, d), followed by row 1 (e, f, g, h), followed by row 2 (i, j, k, l). The last index changes fastest as you walk through memory.

To find element M[i][j] in a (R, C) matrix:

address(i, j) = base + (i * C + j) * element_size

Column-Major Order (Fortran-Style)

The alternative is column-major order, where the first dimension is contiguous. BLAS libraries and MATLAB use this. The same (3, 4) matrix in column-major:

Memory layout (column-major):

addr 0:  a   (row 0, col 0)
addr 1:  e   (row 1, col 0)
addr 2:  i   (row 2, col 0)
addr 3:  b   (row 0, col 1)
addr 4:  f   (row 1, col 1)
addr 5:  j   (row 2, col 1)
addr 6:  c   (row 0, col 2)
addr 7:  g   (row 1, col 2)
addr 8:  k   (row 2, col 2)
addr 9:  d   (row 0, col 3)
addr 10: h   (row 1, col 3)
addr 11: l   (row 2, col 3)

Here the first index changes fastest. To find element M[i][j]:

address(i, j) = base + (j * R + i) * element_size
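Both formulas are easy to sanity-check. The following illustrative Python sketch (not part of the book's Metal code) reproduces the two diagrams above for the 3x4 matrix of letters a through l:

```python
# Flat element indices for M[i][j] under both layouts (base = 0,
# element_size = 1, so addresses are in units of elements).

def addr_row_major(i, j, R, C):
    """Row-major: last index changes fastest. R is unused but kept for symmetry."""
    return i * C + j

def addr_col_major(i, j, R, C):
    """Column-major: first index changes fastest. C is unused but kept for symmetry."""
    return j * R + i

# Row-major: row 1 (e, f, g, h) occupies consecutive addresses 4..7.
assert [addr_row_major(1, j, 3, 4) for j in range(4)] == [4, 5, 6, 7]

# Column-major: the same row is strided R = 3 elements apart.
assert [addr_col_major(1, j, 3, 4) for j in range(4)] == [1, 4, 7, 10]
```

The strided column-major result (1, 4, 7, 10) is exactly the access pattern that defeats coalescing, as the next section explains.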

Why Does This Matter for GPU Performance?

Consider what happens when 32 threads in a SIMD group all need to read one element each. If thread t reads element M[row][t] and the matrix is row-major, all 32 threads read from consecutive memory addresses. The GPU’s memory system can satisfy this with a single coalesced memory transaction.

Row-major, threads reading M[row][0..31]:

Memory: [ M[r][0] | M[r][1] | M[r][2] | ... | M[r][31] | M[r][32] | ... ]
           ^          ^          ^                ^
         thread 0   thread 1   thread 2       thread 31

All addresses are contiguous --> ONE memory transaction (coalesced)

But if the matrix were column-major and threads tried to read M[row][0..31], each thread would be reading from addresses that are R elements apart:

Column-major, threads reading M[row][0..31]:

Memory: [ M[0][0] | M[1][0] | M[2][0] | ... | M[0][1] | M[1][1] | ... ]

Thread 0 reads M[row][0] at offset row
Thread 1 reads M[row][1] at offset R + row
Thread 2 reads M[row][2] at offset 2R + row
...
Addresses are R elements apart --> MANY memory transactions (strided)

This can be 10-30x slower depending on the stride. Coalesced access is one of the most important performance considerations on any GPU.

Strides: The General Case

The concept of strides generalizes memory layout. A stride tells you how many elements to skip in memory when you increment one index by one.

For a row-major matrix of shape (R, C):

  • Stride along dimension 0 (rows): C (skip one whole row of C elements)
  • Stride along dimension 1 (columns): 1 (elements are adjacent)

For a row-major 3D tensor of shape (D0, D1, D2):

  • Stride along dimension 0: D1 * D2
  • Stride along dimension 1: D2
  • Stride along dimension 2: 1

The general address formula for any tensor:

address(i0, i1, ..., in) = base + sum(ik * stride_k) for k in 0..n

Here is a concrete example. For a tensor of shape (2, 3, 4) in row-major:

Strides: (12, 4, 1)

Because:
  stride[2] = 1
  stride[1] = shape[2] = 4
  stride[0] = shape[1] * shape[2] = 3 * 4 = 12

Element T[1][2][3]:
  address = 0 + 1*12 + 2*4 + 3*1 = 12 + 8 + 3 = 23

Let us verify by counting:
  Slice 0: [[ 0  1  2  3]    elements 0-11
            [ 4  5  6  7]
            [ 8  9 10 11]]

  Slice 1: [[12 13 14 15]    elements 12-23
            [16 17 18 19]
            [20 21 22 23]]   <-- T[1][2][3] = element 23. Correct!

Strides also let you express interesting things like transposition without copying data. If you have a matrix with strides (C, 1), its transpose has strides (1, C) – same data, just different stride interpretation.
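Here is a small Python sketch (illustrative only; the helper names are ours) of the stride machinery above, including the zero-copy transpose trick:

```python
# Row-major strides from a shape, plus the general address formula.

def row_major_strides(shape):
    """stride[k] = product of all dimensions after k (last stride is 1)."""
    strides = [1] * len(shape)
    for k in range(len(shape) - 2, -1, -1):
        strides[k] = strides[k + 1] * shape[k + 1]
    return tuple(strides)

def flat_index(index, strides):
    """address = sum(index[k] * stride[k]), in units of elements."""
    return sum(i * s for i, s in zip(index, strides))

assert row_major_strides((2, 3, 4)) == (12, 4, 1)
assert flat_index((1, 2, 3), (12, 4, 1)) == 23     # matches the count above

# Transpose without copying: swap the strides, not the data.
# M has shape (3, 4) and strides (4, 1); M^T uses strides (1, 4).
strides_T = row_major_strides((3, 4))[::-1]
assert flat_index((2, 1), strides_T) == flat_index((1, 2), (4, 1))  # same element
```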

How Tensors Map to GPU Buffers

On Metal, a tensor lives in an MTLBuffer – a contiguous block of GPU-accessible memory. There is no built-in “tensor” type in Metal. The buffer is just bytes. Your shader code interprets those bytes as a tensor by computing offsets from indices using the stride formula.

The Buffer Layout

MTLBuffer (contiguous memory):
+----+----+----+----+----+----+----+----+----+----+----+----+
| e0 | e1 | e2 | e3 | e4 | e5 | e6 | e7 | e8 | e9 |e10 |e11 |
+----+----+----+----+----+----+----+----+----+----+----+----+
  ^                                                         ^
  buffer.contents()                                   buffer end

Each element is sizeof(element_type) bytes:
  float  = 4 bytes
  half   = 2 bytes
  int8_t = 1 byte

For a half-precision (2, 3, 2) tensor:
  Total elements = 2 * 3 * 2 = 12
  Total bytes = 12 * 2 = 24 bytes

Computing Offsets in a Metal Shader

In a Metal compute shader, you typically receive the buffer pointer and the tensor’s metadata (shape, strides) as arguments. Here is what the pattern looks like:

kernel void elementwise_relu(
    device const half* input  [[buffer(0)]],
    device half*       output [[buffer(1)]],
    constant uint3&    shape  [[buffer(2)]],   // (D0, D1, D2)
    constant uint3&    strides [[buffer(3)]],  // (S0, S1, S2)
    uint3 gid [[thread_position_in_grid]]
) {
    // Bounds check
    if (gid.x >= shape.x || gid.y >= shape.y || gid.z >= shape.z) return;

    // Compute flat index from multi-dimensional position
    uint idx = gid.x * strides.x + gid.y * strides.y + gid.z * strides.z;

    // Read, compute, write
    half val = input[idx];
    output[idx] = val > 0 ? val : 0;
}

The key insight: the GPU hardware knows nothing about tensors. It only knows about flat memory addresses. All the multi-dimensional indexing is just arithmetic that we do in the shader.

Alignment Considerations

Metal has specific alignment requirements for buffer access:

Type        Size    Alignment
float       4B      4B
half        2B      2B
float4      16B     16B
half4       8B      8B

Rule of thumb: access addresses that are multiples of the type size.

When packing tensor data into buffers, you generally want each row to start at an aligned address. For a matrix of half values with 127 columns, each row takes 254 bytes. The next row starts at byte 254, which is 2-byte aligned (fine for half). But if you wanted to read rows as half4 vectors (common optimization), you would want 8-byte alignment, which means padding rows to 128 elements (256 bytes).
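The padding arithmetic can be sketched in a few lines of Python (an illustrative helper, not an API from Metal):

```python
# Round a row up to the next multiple of the required alignment.

def padded_row_elems(cols, elem_bytes, align_bytes):
    """Smallest element count >= cols whose byte size is a multiple of align_bytes."""
    row_bytes = cols * elem_bytes
    padded_bytes = (row_bytes + align_bytes - 1) // align_bytes * align_bytes
    return padded_bytes // elem_bytes

# 127 half values = 254 bytes; padding to half4 (8 B) alignment gives 256 B.
assert padded_row_elems(127, 2, 8) == 128
assert padded_row_elems(128, 2, 8) == 128   # already aligned
```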

Multiple Tensors in One Buffer (Offset Packing)

In practice, inference engines often pack multiple tensors into a single large buffer with offsets:

Single MTLBuffer:
+------------------+------------------+------------------+
|  Weight matrix   |   Bias vector    |  LayerNorm scale |
|  (4096 x 4096)   |   (4096)         |  (4096)          |
|  offset: 0       |   offset: 33MB   |  offset: 33MB+8K |
+------------------+------------------+------------------+

This reduces the number of buffer bindings needed (Metal has a limit on how many buffers you can bind to a single kernel invocation) and can improve memory allocation efficiency.

Common Tensor Operations

Now let us look at the operations that transformers actually perform on tensors and how they parallelize on a GPU.

Element-wise Operations

These apply a function independently to each element. The output has the same shape as the input. Examples:

  • ReLU: output[i] = max(0, input[i])
  • SiLU (Swish): output[i] = input[i] * sigmoid(input[i])
  • GELU: output[i] = input[i] * 0.5 * (1 + erf(input[i] / sqrt(2)))
  • Addition: output[i] = a[i] + b[i]
  • Scalar multiply: output[i] = alpha * input[i]

Parallelization is trivial – assign one thread per element:

Input tensor:  [a0, a1, a2, a3, a4, a5, a6, a7]

Thread 0 --> processes a0
Thread 1 --> processes a1
Thread 2 --> processes a2
...
Thread 7 --> processes a7

Each thread:
  1. Read input[thread_id]
  2. Apply function
  3. Write output[thread_id]

No synchronization needed! Each thread is independent.

For a tensor with N elements, you dispatch ceil(N / threads_per_threadgroup) threadgroups, each with threads_per_threadgroup threads (commonly 256 or 1024).

Dispatch for N = 10000, threads_per_group = 256:

Threadgroups needed = ceil(10000 / 256) = 40 threadgroups

Threadgroup  0: threads 0-255     --> elements 0-255
Threadgroup  1: threads 256-511   --> elements 256-511
...
Threadgroup 38: threads 9728-9983 --> elements 9728-9983
Threadgroup 39: threads 9984-10239 --> elements 9984-9999 (rest out of bounds)
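The dispatch arithmetic above can be checked with a quick Python sketch (illustrative; real dispatches go through MTLComputeCommandEncoder):

```python
# How many threadgroups cover N elements, and which element range each touches.

def dispatch(n, group_size):
    groups = (n + group_size - 1) // group_size     # ceil(n / group_size)
    ranges = [(g * group_size, min((g + 1) * group_size, n))
              for g in range(groups)]
    return groups, ranges

groups, ranges = dispatch(10000, 256)
assert groups == 40
assert ranges[0] == (0, 256)
assert ranges[39] == (9984, 10000)   # last group is partially out of bounds
```

The partial final group is exactly why the bounds check at the top of every kernel matters.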

Reduction Operations

Reductions collapse one or more dimensions by aggregating elements. Examples:

  • Sum along a dimension: output[i] = sum(input[i][j] for all j)
  • Max along a dimension: output[i] = max(input[i][j] for all j)
  • Mean: sum divided by count

These are trickier to parallelize because threads need to cooperate to produce the result.

Strategy 1: One threadgroup per reduction

Matrix shape (4, 8), reduce along columns (dim 1):

Row 0: [a0  a1  a2  a3  a4  a5  a6  a7]  --> sum = a0+a1+...+a7
Row 1: [b0  b1  b2  b3  b4  b5  b6  b7]  --> sum = b0+b1+...+b7
Row 2: [c0  c1  c2  c3  c4  c5  c6  c7]  --> sum = c0+c1+...+c7
Row 3: [d0  d1  d2  d3  d4  d5  d6  d7]  --> sum = d0+d1+...+d7

Assign one threadgroup per row.
Within each threadgroup, threads cooperate to sum the row:

Threadgroup 0 (8 threads for row 0):
  Step 1: Each thread loads one element
    t0=a0, t1=a1, t2=a2, t3=a3, t4=a4, t5=a5, t6=a6, t7=a7

  Step 2: Parallel reduction tree
    t0 = a0+a4   t1 = a1+a5   t2 = a2+a6   t3 = a3+a7
    t0 = (a0+a4)+(a2+a6)      t1 = (a1+a5)+(a3+a7)
    t0 = total sum

  Step 3: Thread 0 writes result

Strategy 2: SIMD group reduction

On Metal, SIMD groups of 32 threads have dedicated hardware for fast reductions. The simd_sum() function sums a value across all threads in the SIMD group in a single hardware-accelerated operation:

SIMD group reduction (32 threads):

Before simd_sum:
  t0=3.0  t1=1.0  t2=4.0  t3=1.0  ...  t31=2.0

After val = simd_sum(val):
  t0=sum  t1=sum  t2=sum  t3=sum  ...  t31=sum
  (all threads have the same total sum)

For reductions over dimensions larger than 32, you use a combination: each SIMD group reduces its portion, then results are combined using threadgroup memory.

Broadcasting

Broadcasting lets you operate on tensors of different shapes by “stretching” the smaller tensor to match the larger one, conceptually. No data is actually copied.

Example: Add a bias vector to every row of a matrix

Matrix A shape: (1024, 512)
Bias b shape:   (512,)

Result C[i][j] = A[i][j] + b[j]    for all i, j

The bias is "broadcast" along dimension 0:

    A:                      b:                  C:
    [[a00 a01 ... a0,511]   [b0 b1 ... b511]    [[a00+b0  a01+b1  ...]
     [a10 a11 ... a1,511] +                   =   [a10+b0  a11+b1  ...]
     ...                                           ...
     [a1023,0 ...       ]]                         [a1023,0+b0 ...   ]]

In the shader, broadcasting is implemented by simply not advancing the index along the broadcast dimension:

// Adding bias (shape: [N]) to matrix (shape: [M, N])
kernel void add_bias(
    device const half* matrix [[buffer(0)]],
    device const half* bias   [[buffer(1)]],
    device half*       output [[buffer(2)]],
    constant uint&     N      [[buffer(3)]],  // number of columns
    uint2 gid [[thread_position_in_grid]]     // (row, col)
) {
    uint row = gid.x;
    uint col = gid.y;
    uint idx = row * N + col;

    output[idx] = matrix[idx] + bias[col];  // bias indexed only by col
}

The bias[col] access ignores the row entirely. That is broadcasting – the same bias element b[j] is used for every row.
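For reference, here is the same broadcast-add written out in plain Python (an illustrative reference implementation, not part of the engine):

```python
# Broadcast-add: bias (shape [N]) added to every row of a row-major
# matrix (shape [M, N]) stored as a flat list.

def add_bias(matrix, bias, M, N):
    out = [0.0] * (M * N)
    for row in range(M):
        for col in range(N):
            idx = row * N + col
            out[idx] = matrix[idx] + bias[col]   # bias indexed only by col
    return out

mat = [1.0, 2.0, 3.0, 4.0]                       # shape (2, 2), row-major
assert add_bias(mat, [10.0, 20.0], 2, 2) == [11.0, 22.0, 13.0, 24.0]
```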

Parallelizing Tensor Operations

The general strategies for mapping tensor operations to GPU threads fall into a few patterns. Let us formalize them.

Pattern 1: One Thread Per Element

Used for element-wise operations.

Tensor shape: (M, N)
Grid: (M, N)
Each thread (i, j) processes element [i][j]

+-------+-------+-------+-------+
| t(0,0)| t(0,1)| t(0,2)| t(0,3)|
+-------+-------+-------+-------+
| t(1,0)| t(1,1)| t(1,2)| t(1,3)|
+-------+-------+-------+-------+
| t(2,0)| t(2,1)| t(2,2)| t(2,3)|
+-------+-------+-------+-------+

Dispatch: grid_size = (M, N, 1)
          threadgroup_size = (16, 16, 1)  // 256 threads per group

Pattern 2: One Threadgroup Per Row (or Column)

Used for reductions, softmax, layer normalization – anything that operates across one dimension.

Tensor shape: (M, N)
Grid: M threadgroups, each with T threads

Threadgroup 0 --> Row 0:  [e0  e1  e2 ... eN-1]
Threadgroup 1 --> Row 1:  [e0  e1  e2 ... eN-1]
...
Threadgroup M-1 --> Row M-1

Within each threadgroup, T threads cooperate:
  Thread t handles elements: t, t+T, t+2T, ...

For N=1024, T=256:
  Thread 0: elements 0, 256, 512, 768
  Thread 1: elements 1, 257, 513, 769
  ...

Pattern 3: Tiled Processing

Used for matrix multiplication and attention. Threads in a threadgroup cooperatively load tiles into fast threadgroup memory, then compute.

Matrix multiply C = A * B
A: (M, K), B: (K, N), C: (M, N)

Tile size: (Tm, Tn) per threadgroup, iterate over K in tiles of Tk

Threadgroup (gi, gj) computes C[gi*Tm..(gi+1)*Tm][gj*Tn..(gj+1)*Tn]

       B (K x N)
       +---+---+---+---+
       |   | Bj|   |   |      Bj = tile column j
       +---+---+---+---+
       |   | Bj|   |   |
  A    +---+---+---+---+
(M x K)
+--+--+   +---+
|  |Ai| x | Bj| = Cij (Tm x Tn tile of output)
+--+--+   +---+
|  |  |
+--+--+

Ai = tile row i of A
For each k-tile: load Ai[:,k:k+Tk] and Bj[k:k+Tk,:] into shared memory
                 compute partial products
                 accumulate into Cij registers

We will explore tiled matrix multiplication in enormous detail in the next chapter.

Pattern 4: SIMD Group per Work Unit

Metal’s SIMD groups of 32 threads (the equivalent of NVIDIA’s warps or AMD’s wavefronts) have hardware support for fast intra-group communication. Many kernels assign one SIMD group to one logical unit of work:

GEMV: One SIMD group per output element (or a few elements)

Output vector y = W * x    (W is M x K, x is K x 1, y is M x 1)

SIMD group 0 --> y[0] = dot(W[0,:], x)
SIMD group 1 --> y[1] = dot(W[1,:], x)
...

Within each SIMD group:
  32 threads divide the K-length dot product
  Thread t handles indices: t, t+32, t+64, ...
  Each thread accumulates a partial sum
  simd_sum() gives the total

The Transformer’s Core Operations

A transformer model is built from a surprisingly small set of tensor operations, repeated many times. Let us catalog them and understand their computational characteristics.

Linear Projection (GEMM / GEMV)

The workhorse of the transformer. Every layer has multiple linear projections:

In one transformer block:
  Q = X * Wq + bq       (query projection)
  K = X * Wk + bk       (key projection)
  V = X * Wv + bv       (value projection)
  O = Attn * Wo + bo    (output projection)
  H = X * W1 + b1       (FFN first layer)
  Y = H * W2 + b2       (FFN second layer)

That's 6 matrix multiplications per layer!
A 32-layer model does 192 matmuls per forward pass.

The computation is Y = X * W where:

  • During prefill: X is (seq_len, d_model), W is (d_model, d_out) – a full GEMM
  • During decode: X is (1, d_model), W is (d_model, d_out) – a GEMV

Prefill (GEMM):                    Decode (GEMV):

 X (128 x 4096)                     x (1 x 4096)
 +----------+                        +----------+
 |          |                        |          |
 |          |    W (4096 x 4096)     +----------+    W (4096 x 4096)
 |          |    +----------+                        +----------+
 |          |  x |          |                      x |          |
 |          |    |          |                        |          |
 |          |    |          |                        |          |
 +----------+    +----------+                        +----------+
       =                                   =
 Y (128 x 4096)                     y (1 x 4096)
 +----------+                        +----------+
 |          |                        |          |
 |          |                        +----------+
 |          |
 +----------+

The GEMV case (decode) is memory-bandwidth bound because you read the entire weight matrix just to produce one output row. The GEMM case (prefill) is compute bound because you amortize the weight read across many input rows.
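A rough arithmetic-intensity estimate makes this concrete. The sketch below (back-of-the-envelope Python, counting only the FP16 weight read, which dominates) computes FLOPs per byte for both shapes:

```python
# Arithmetic intensity = FLOPs / bytes of weights read.
# Each output term costs a multiply + add = 2 FLOPs.

def intensity(rows, d_model, d_out, elem_bytes=2):
    flops = 2 * rows * d_model * d_out
    weight_bytes = d_model * d_out * elem_bytes
    return flops / weight_bytes

assert intensity(1, 4096, 4096) == 1.0      # decode GEMV: 1 FLOP per weight byte
assert intensity(128, 4096, 4096) == 128.0  # prefill GEMM: 128x more reuse
```

At 1 FLOP/byte, a ~100 GB/s GPU can sustain only ~100 GFLOPS of GEMV work, far below its compute peak, which is exactly why decode is bandwidth bound.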

Attention (Batched Dot Products + Softmax)

The attention mechanism computes, for each query position, a weighted average of all value positions:

scores = Q * K^T / sqrt(d_k)     Matrix multiply: (seq, d_k) x (d_k, seq) = (seq, seq)
weights = softmax(scores)         Element-wise + reduction
output = weights * V              Matrix multiply: (seq, seq) x (seq, d_k) = (seq, d_k)

The middle step (softmax) is a reduction along the last dimension of each row, which requires special handling. We will devote Chapters 15 and 16 to attention.

Layer Normalization (RMSNorm)

Modern transformers use RMSNorm, which normalizes each row by its root-mean-square:

RMSNorm(x) = x / sqrt(mean(x^2) + eps) * gamma

For a row x of length d_model:
  1. Compute sum of squares:  ss = sum(x[i]^2 for i in 0..d_model)
  2. Compute RMS:             rms = sqrt(ss / d_model + eps)
  3. Normalize and scale:     output[i] = x[i] / rms * gamma[i]

Step 1 is a reduction (sum).
Step 3 is element-wise.

This follows the “one threadgroup per row” pattern. The reduction in step 1 uses SIMD group sums.
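The three steps above translate directly into a reference implementation. This Python sketch is the sequential equivalent of the kernel, not the kernel itself:

```python
import math

def rmsnorm(x, gamma, eps=1e-5):
    ss = sum(v * v for v in x)                       # step 1: reduction
    rms = math.sqrt(ss / len(x) + eps)               # step 2: RMS
    return [v / rms * g for v, g in zip(x, gamma)]   # step 3: element-wise

x = [1.0, 2.0, 3.0, 4.0]
out = rmsnorm(x, [1.0] * 4)
# RMS of x is sqrt(30/4 + eps) ~= 2.739; each output is x[i] / 2.739.
assert abs(out[3] - 4.0 / math.sqrt(30.0 / 4.0 + 1e-5)) < 1e-9
```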

Activation Functions

Applied element-wise between the FFN layers:

SiLU(x) = x * sigmoid(x) = x * (1 / (1 + exp(-x)))

GELU(x) = x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))

These are purely element-wise, one thread per element. Trivially parallel.
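For reference, both formulas in plain Python (the GELU here is the tanh approximation written above):

```python
import math

def silu(x):
    return x * (1.0 / (1.0 + math.exp(-x)))

def gelu(x):
    return x * 0.5 * (1.0 + math.tanh(math.sqrt(2.0 / math.pi)
                                      * (x + 0.044715 * x**3)))

assert silu(0.0) == 0.0
assert abs(silu(1.0) - 0.7310585786300049) < 1e-12   # 1 * sigmoid(1)
assert abs(gelu(0.0)) < 1e-12
```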

Residual Connections

Just element-wise addition:

output = sublayer_output + input

Every transformer block has two residual connections:
  x = attention(norm(x)) + x      <-- residual
  x = ffn(norm(x)) + x            <-- residual

Putting It All Together: One Transformer Block

Input: x (seq_len, d_model)

Step 1: RMSNorm          --> x_norm (seq_len, d_model)         [reduction per row]
Step 2: Q = x_norm * Wq  --> Q (seq_len, d_model)             [GEMM or GEMV]
Step 3: K = x_norm * Wk  --> K (seq_len, d_kv)                [GEMM or GEMV]
Step 4: V = x_norm * Wv  --> V (seq_len, d_kv)                [GEMM or GEMV]
Step 5: Reshape Q/K/V for multi-head
Step 6: Attention         --> attn (seq_len, d_model)          [batched matmul + softmax]
Step 7: O = attn * Wo     --> O (seq_len, d_model)            [GEMM or GEMV]
Step 8: residual          --> x = O + x                        [element-wise add]
Step 9: RMSNorm           --> x_norm                           [reduction per row]
Step 10: FFN up           --> h (seq_len, d_ff)                [GEMM or GEMV]
Step 11: FFN gate         --> g (seq_len, d_ff)                [GEMM or GEMV]
Step 12: SiLU + multiply  --> h = SiLU(g) * h                 [element-wise]
Step 13: FFN down         --> f (seq_len, d_model)             [GEMM or GEMV]
Step 14: residual         --> x = f + x                        [element-wise add]

Output: x (seq_len, d_model)

Count the operations: 6 matrix multiplications, 2 normalizations, 2 residual adds, 1 activation, and the attention computation. The matrix multiplications dominate compute time (90%+ of total FLOPS).

Floating-Point Precision: FP16, FP32, and BF16¹

The choice of numeric precision has enormous impact on both performance and quality.

FP32 (Single Precision)

FP32: 1 sign bit + 8 exponent bits + 23 mantissa bits = 32 bits

+---+----------+-------------------------+
| S | EEEEEEEE | MMMMMMMMMMMMMMMMMMMMMMM |
+---+----------+-------------------------+
  1      8                 23

Range: ~1.18e-38 to ~3.4e+38
Precision: ~7 decimal digits
Size: 4 bytes per value

FP32 is the “safe” choice. It has plenty of range and precision for any ML computation. But it uses twice the memory and bandwidth of FP16, making it 2x slower for memory-bound operations.

FP16 (Half Precision)

FP16: 1 sign bit + 5 exponent bits + 10 mantissa bits = 16 bits

+---+-------+------------+
| S | EEEEE | MMMMMMMMMM |
+---+-------+------------+
 1     5          10

Range: ~6.1e-5 to ~65504
Precision: ~3.3 decimal digits
Size: 2 bytes per value

FP16 halves memory usage and doubles effective bandwidth. Apple GPUs have native FP16 support and can process FP16 values at 2x the rate of FP32 in many operations.

The downside is the limited range. Values smaller than 6.1e-5 become zero (underflow), and values larger than 65504 become infinity (overflow). During inference, weights are usually small enough that this is not a problem. But intermediate computations (especially in attention) can overflow if not handled carefully.
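You can observe these limits without a GPU: Python's struct module supports IEEE FP16 via the "e" format. A small sketch:

```python
import struct

def to_fp16(x):
    """Round-trip a Python float through IEEE FP16."""
    return struct.unpack('<e', struct.pack('<e', x))[0]

assert to_fp16(65504.0) == 65504.0   # largest finite FP16 value survives
assert to_fp16(1e-8) == 0.0          # below the subnormal range: underflows to zero
assert to_fp16(0.01) != 0.01         # rounds to the nearest representable (~0.010002)
```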

BF16 (Brain Float 16)

BF16: 1 sign bit + 8 exponent bits + 7 mantissa bits = 16 bits

+---+----------+---------+
| S | EEEEEEEE | MMMMMMM |
+---+----------+---------+
 1      8           7

Range: ~1.18e-38 to ~3.4e+38  (same as FP32!)
Precision: ~2.4 decimal digits
Size: 2 bytes per value

BF16 has the same exponent range as FP32 (so no overflow/underflow issues) but less precision than FP16. It was designed specifically for ML workloads where range matters more than precision.

Apple Silicon support: M1/M2 GPUs do not have native BF16 support in Metal. M3+ has some BF16 capabilities. In practice, most Apple GPU inference uses FP16 for computation with FP32 for accumulation.

The Mixed-Precision Strategy²

The standard approach in inference:

Weights:        stored in FP16 (or quantized, e.g., 4-bit)
Activations:    computed in FP16
Accumulations:  done in FP32 (dot products, reductions)
Final output:   cast back to FP16

Why FP32 for accumulations?

Consider summing 4096 FP16 values, each around 0.01:
  True sum: ~40.96
  FP16 sum: after adding a few hundred values, each new addition
            changes the sum by less than FP16 precision can represent.
            You lose significant accuracy.

  FP32 sum: no problem. 7 decimal digits of precision is plenty.
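The accumulator drift described above can be simulated by rounding the running sum to FP16 after every addition (using Python's struct "e" format as a stand-in for hardware FP16):

```python
import struct

def to_fp16(x):
    return struct.unpack('<e', struct.pack('<e', x))[0]

vals = [0.01] * 4096
acc_sum = 0.0          # Python float: at least FP32-precise
fp16_sum = 0.0         # rounded to FP16 after every step
for v in vals:
    acc_sum += v
    fp16_sum = to_fp16(fp16_sum + to_fp16(v))

assert abs(acc_sum - 40.96) < 1e-6   # high-precision sum is accurate
assert fp16_sum < 35.0               # FP16 sum stalls near 32: each 0.01 is
                                     # smaller than half the spacing at that magnitude
```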

In Metal shader code, this looks like:

// Dot product with mixed precision
float sum = 0.0f;    // FP32 accumulator
for (uint i = 0; i < K; i++) {
    sum += float(A[i]) * float(B[i]);   // Cast FP16 inputs to FP32
}
output[idx] = half(sum);   // Cast result back to FP16

Precision Comparison Table

+--------+------+--------+-----------+-------------+------------------+
| Format | Bits | Range  | Precision | Memory/elem | Apple GPU Native |
+--------+------+--------+-----------+-------------+------------------+
| FP32   |  32  | 1e38   | ~7 digits | 4 bytes     | Yes              |
| FP16   |  16  | 65504  | ~3 digits | 2 bytes     | Yes (2x rate)    |
| BF16   |  16  | 1e38   | ~2 digits | 2 bytes     | M3+ partial      |
| INT8   |   8  | -128   | exact     | 1 byte      | Yes              |
|        |      | to 127 | integers  |             |                  |
| INT4   |   4  | -8     | exact     | 0.5 bytes   | No (emulated)    |
|        |      | to 7   | integers  |             |                  |
+--------+------+--------+-----------+-------------+------------------+

Fused Operations: Saving Bandwidth by Combining Kernels

Here is one of the most important optimization ideas in GPU programming: kernel fusion.

The Problem: Bandwidth Waste

Consider computing SiLU activation followed by element-wise multiply (as in the LLaMA FFN):

Step 1: gate = SiLU(x * Wg)
Step 2: up = x * Wu
Step 3: hidden = gate * up

If each step is a separate kernel:

Kernel 1 (SiLU):
  Read gate_raw from global memory     (N elements read)
  Compute SiLU
  Write gate to global memory          (N elements written)

Kernel 2 (element-wise multiply):
  Read gate from global memory          (N elements read - AGAIN!)
  Read up from global memory            (N elements read)
  Compute gate * up
  Write hidden to global memory         (N elements written)

Total memory traffic: (N read + N write) + (2N reads + N write) = 3N reads + 2N writes

The Solution: Fused Kernel

Combine both operations into a single kernel:

Fused kernel (SiLU + multiply):
  Read gate_raw from global memory     (N elements read)
  Read up from global memory           (N elements read)
  Compute SiLU(gate_raw) * up in registers
  Write hidden to global memory        (N elements written)

Total memory traffic: 2N reads + N writes

Savings: N reads + N writes eliminated!

The fused kernel avoids the round trip through global memory between operations. The intermediate value (the SiLU output) lives only in registers, which are essentially free to access.

Common Fused Operations in Transformers

+------------------------------------+---------------------------------+
| Fused Operation                    | What It Combines                |
+------------------------------------+---------------------------------+
| Fused SiLU gate                    | SiLU(x) * y                    |
| Fused RMSNorm                      | norm + scale (gamma multiply)   |
| Fused attention                    | QK^T + scale + mask + softmax   |
|   (FlashAttention)                 |   + V multiply (Ch 16)          |
| Fused dequant + matmul             | Dequantize weights + GEMV/GEMM  |
| Fused residual + norm              | x + sublayer + RMSNorm          |
| Fused RoPE + reshape               | Rotary embedding + head reshape |
+------------------------------------+---------------------------------+

When Fusion Matters and When It Does Not

Fusion helps MOST when:
  - Operations are memory-bandwidth bound (small compute per element)
  - Intermediate tensors are large
  - Operations are sequential (output of one feeds input of next)

Fusion helps LEAST when:
  - Operations are compute bound (GEMM with large matrices)
  - Intermediate tensors are small (fit in cache anyway)
  - Operations have different parallelism patterns
    (e.g., can't easily fuse a reduction with an element-wise op)

The Bandwidth Equation

To understand why fusion matters, consider the numbers for an Apple M2 GPU:

M2 GPU specs:
  Memory bandwidth: ~100 GB/s
  Compute (FP16):   ~3.6 TFLOPS

For a SiLU activation on 4096 FP16 elements:
  Data to read:  4096 * 2 bytes = 8 KB
  Data to write: 4096 * 2 bytes = 8 KB
  Compute: 4096 * ~5 FLOPs (exp, add, div, mul) = ~20 KFLOPs

  Time limited by bandwidth: 16 KB / 100 GB/s = 0.16 microseconds
  Time limited by compute:   20 KFLOPs / 3.6 TFLOPS = 0.006 microseconds

  This operation is 27x more bandwidth bound than compute bound!

Every byte you can avoid reading from or writing to global memory is a direct performance win for bandwidth-bound operations. And in transformers, the majority of non-GEMM operations are bandwidth bound.
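The estimate above is just arithmetic, so it can be written out explicitly (same rough figures as in the text):

```python
# Bandwidth vs. compute time for a SiLU over 4096 FP16 elements.

elements    = 4096
bytes_moved = elements * 2 * 2        # FP16 (2 B) read + write per element
flops       = elements * 5            # rough FLOPs per SiLU

bandwidth = 100e9                     # ~100 GB/s (M2-class GPU)
compute   = 3.6e12                    # ~3.6 FP16 TFLOPS

t_mem  = bytes_moved / bandwidth      # ~0.16 microseconds
t_comp = flops / compute              # ~0.006 microseconds

assert t_mem / t_comp > 25            # heavily bandwidth bound
```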

A Complete Example: Processing One Token

Let us trace the memory traffic for processing a single token through one LLaMA-7B layer, showing where fusion helps.

Model parameters (LLaMA 7B):
  d_model = 4096
  n_heads = 32
  d_k = 128
  d_ff = 11008 (intermediate FFN dimension)
  Weights in FP16

Step 1: RMSNorm
  Read: x (4096 * 2B = 8KB) + gamma (8KB) = 16KB
  Write: x_norm (8KB)
  Total: 24KB

Step 2: Q projection (GEMV: 1 x 4096 times 4096 x 4096)
  Read: x_norm (8KB) + Wq (4096*4096*2B = 32MB)
  Write: Q (8KB)
  Total: ~32MB   <-- Weight read dominates!

Step 3: K projection
  Read: x_norm (8KB) + Wk (4096*128*32*2B = ... depends on GQA)
  ... similar to Q

Step 4-6: V projection, attention, output projection
  ... similar pattern

Step 7-8: FFN up + gate + SiLU + down
  Read: Wu (4096*11008*2B = 86MB) + Wg (86MB) + Wd (86MB)
  Total: ~258MB of weight reads for FFN alone

Grand total weight reads per layer: ~400MB
For 32 layers: ~12.8GB of weight reads PER TOKEN

At 100 GB/s bandwidth: 12.8GB / 100 GB/s = 128ms per token = ~8 tokens/sec

This is why:

  1. Quantization is essential – 4-bit weights cut reads by 4x.
  2. Fusion matters for the non-GEMM ops – saving even a few MB per layer adds up.
  3. Decode is bandwidth-bound – the GPU spends most of its time waiting for memory.
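The headline numbers from this trace can be reproduced with a few lines of arithmetic (LLaMA-7B-like shapes, FP16 weights, attention projections without GQA, as above):

```python
# Per-token weight traffic for a LLaMA-7B-style decode step.

d_model, d_ff, n_layers = 4096, 11008, 32
fp16 = 2  # bytes per weight

attn_weights = 4 * d_model * d_model * fp16   # Wq, Wk, Wv, Wo
ffn_weights  = 3 * d_model * d_ff * fp16      # Wu, Wg, Wd
per_layer    = attn_weights + ffn_weights     # ~400 MB
total        = per_layer * n_layers           # ~12.8 GB per token

bandwidth = 100e9                             # ~100 GB/s
tokens_per_sec = bandwidth / total

assert 380e6 < per_layer < 420e6
assert 7 < tokens_per_sec < 9                 # bandwidth ceiling: ~8 tokens/sec
```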

Summary

In this chapter, we have established the foundations:

  • Tensors are multi-dimensional arrays. In ML, we work with rank-2 to rank-4 tensors constantly.
  • Memory layout (row-major with strides) determines how multi-dimensional indices map to flat GPU buffer addresses.
  • Coalesced access – adjacent threads reading adjacent memory – is critical for GPU performance.
  • Common operations include element-wise (trivially parallel), reductions (require thread cooperation), and matrix multiplications (require tiling).
  • Mixed precision (FP16 values, FP32 accumulation) gives us the best of both worlds.
  • Kernel fusion eliminates wasteful memory round-trips between operations.

The next chapter zooms in on the single most important operation: matrix multiplication. It consumes 90%+ of inference compute, and getting it right is the difference between a usable inference engine and a toy.



  1. Apple. “Metal Shading Language Specification, v3.1.” developer.apple.com. Covers the half type, vector types, and precision guarantees for FP16 arithmetic on Apple GPUs. See https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf.

  2. Micikevicius, P., et al. “Mixed Precision Training.” ICLR 2018. Establishes the practice of using FP16 for computation with FP32 accumulation to maintain numerical stability. See https://arxiv.org/abs/1710.03740.