Tensors and Linear Algebra on GPU
If you have spent any time reading ML papers or browsing inference codebases, you have seen the word “tensor” more times than you can count. It gets thrown around so casually that you might suspect it means something terribly complicated. It does not. At its heart, a tensor in the machine learning context is just a multi-dimensional array of numbers. That is it. A 1D tensor is a vector. A 2D tensor is a matrix. A 3D tensor is… well, a 3D array. And so on.
But here is the thing: while the concept is simple, the implementation on a GPU is where all the interesting engineering lives. How do you lay out a 4-dimensional array in a flat slab of GPU memory? How do you make sure that when 3,000 threads all try to read different elements at once, the memory system does not choke? How do you express broadcasting rules in terms of pointer arithmetic? These are the questions that determine whether your inference engine runs at 10 tokens per second or 100.
In this chapter, we will build your intuition for tensors from the ground up, then show exactly how they map to Metal GPU buffers. By the end, you will understand the memory layout decisions that dominate every kernel we write in later chapters.
What Is a Tensor, Really?
Let us start with the basics and build up.
Scalars, Vectors, Matrices, and Beyond
Rank 0 (Scalar): 42.0
Rank 1 (Vector): [1.0, 2.0, 3.0, 4.0]
Rank 2 (Matrix): [[1.0, 2.0, 3.0],
[4.0, 5.0, 6.0]]
Rank 3 (3D Tensor): [[[1.0, 2.0],
[3.0, 4.0],
[5.0, 6.0]],
[[7.0, 8.0],
[9.0, 10.0],
[11.0, 12.0]]]
The rank (or order) of a tensor is simply how many indices you need to pick out a single element:
- Rank 0: zero indices. It is just a number.
- Rank 1: one index. v[i] gives you element i of a vector.
- Rank 2: two indices. M[i][j] gives you the element at row i, column j.
- Rank 3: three indices. T[i][j][k] picks one element from a 3D block.
The shape describes how many elements exist along each dimension. For the rank-3 tensor above, the shape is (2, 3, 2) – two “slices,” each containing 3 rows of 2 elements.
Tensors in Transformers
In a typical transformer model, you encounter tensors of various ranks constantly:
- Token embeddings: shape (seq_len, d_model) – rank 2
- Weight matrix: shape (d_model, d_out) – rank 2
- Attention scores: shape (n_heads, seq_len, seq_len) – rank 3
- Batched activations: shape (batch, seq_len, d_model) – rank 3
- Multi-head QKV: shape (batch, n_heads, seq, d_k) – rank 4
For inference on a single sequence (batch size 1), we can often drop the batch dimension. But even without batching, we regularly work with rank-2 and rank-3 tensors.
Here is a concrete example. Suppose we have a small model with:
- d_model = 512 (embedding dimension)
- n_heads = 8 (number of attention heads)
- d_k = 64 (per-head dimension, since 512 / 8 = 64)
- seq_len = 128 (sequence length)
The query tensor Q after projection would have shape (128, 512), and after reshaping for multi-head attention, it becomes (8, 128, 64).
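This reshape is easy to check in a few lines of NumPy (a host-side sketch, not Metal; the array contents are illustrative):

```python
import numpy as np

seq_len, d_model, n_heads, d_k = 128, 512, 8, 64

# Stand-in for Q after projection: one row of d_model features per token.
Q = np.arange(seq_len * d_model, dtype=np.float32).reshape(seq_len, d_model)

# Split the feature dimension into heads: (seq, d_model) -> (seq, heads, d_k),
# then move heads to the front: (heads, seq, d_k).
Q_heads = Q.reshape(seq_len, n_heads, d_k).transpose(1, 0, 2)

print(Q_heads.shape)  # (8, 128, 64)
```

Note that head h of token s is the contiguous slice Q[s, h*d_k : (h+1)*d_k] – the reshape merely reinterprets memory, and only the transpose changes the logical index order.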
Memory Layouts: How Multi-Dimensional Data Lives in Flat Memory
Here is the fundamental problem: GPU memory (and CPU memory, for that matter) is a flat, one-dimensional address space. You have addresses 0, 1, 2, 3, and so on. But your tensor is multi-dimensional. So you need a rule for mapping multi-dimensional indices to flat memory addresses.
Row-Major Order (C-Style)
The most common layout, and the one used by virtually all ML frameworks and inference engines, is row-major order. The idea is simple: you lay out the last dimension contiguously, then the second-to-last, and so on.
For a 2D matrix with shape (3, 4):
Logical view: Memory layout (row-major):
col 0 col 1 col 2 col 3
row 0 [ a b c d ] addr 0: a
row 1 [ e f g h ] addr 1: b
row 2 [ i j k l ] addr 2: c
addr 3: d
addr 4: e
addr 5: f
addr 6: g
addr 7: h
addr 8: i
addr 9: j
addr 10: k
addr 11: l
The elements of row 0 are contiguous in memory (a, b, c, d), followed by row 1 (e, f, g, h), followed by row 2 (i, j, k, l). The last index changes fastest as you walk through memory.
To find element M[i][j] in a (R, C) matrix:
address(i, j) = base + (i * C + j) * element_size
Column-Major Order (Fortran-Style)
The alternative is column-major order, where the first dimension is contiguous. BLAS libraries and MATLAB use this. The same (3, 4) matrix in column-major:
Memory layout (column-major):
addr 0: a (row 0, col 0)
addr 1: e (row 1, col 0)
addr 2: i (row 2, col 0)
addr 3: b (row 0, col 1)
addr 4: f (row 1, col 1)
addr 5: j (row 2, col 1)
addr 6: c (row 0, col 2)
addr 7: g (row 1, col 2)
addr 8: k (row 2, col 2)
addr 9: d (row 0, col 3)
addr 10: h (row 1, col 3)
addr 11: l (row 2, col 3)
Here the first index changes fastest. To find element M[i][j]:
address(i, j) = base + (j * R + i) * element_size
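Both formulas can be checked against the diagrams above with a short Python sketch (the function names are ours, purely for illustration; they return element offsets, so multiply by element_size and add base for byte addresses):

```python
def addr_row_major(i, j, R, C):
    # Last index varies fastest: row i starts at element offset i*C.
    return i * C + j

def addr_col_major(i, j, R, C):
    # First index varies fastest: column j starts at element offset j*R.
    return j * R + i

# The (3, 4) matrix from the diagrams: element 'f' sits at row 1, col 1.
print(addr_row_major(1, 1, 3, 4))  # 5 -- addr 5 in the row-major diagram
print(addr_col_major(1, 1, 3, 4))  # 4 -- addr 4 in the column-major diagram
```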
Why Does This Matter for GPU Performance?
Consider what happens when 32 threads in a SIMD group all need to read one element each. If thread t reads element M[row][t] and the matrix is row-major, all 32 threads read from consecutive memory addresses. The GPU’s memory system can satisfy this with a single coalesced memory transaction.
Row-major, threads reading M[row][0..31]:
Memory: [ M[r][0] | M[r][1] | M[r][2] | ... | M[r][31] | M[r][32] | ... ]
^ ^ ^ ^
thread 0 thread 1 thread 2 thread 31
All addresses are contiguous --> ONE memory transaction (coalesced)
But if the matrix were column-major and threads tried to read M[row][0..31], each thread would be reading from addresses that are R elements apart:
Column-major, threads reading M[row][0..31]:
Memory: [ M[0][0] | M[1][0] | M[2][0] | ... | M[0][1] | M[1][1] | ... ]
Thread 0 reads M[row][0] at offset row
Thread 1 reads M[row][1] at offset R + row
Thread 2 reads M[row][2] at offset 2R + row
...
Addresses are R elements apart --> MANY memory transactions (strided)
This can be 10-30x slower depending on the stride. Coalesced access is one of the most important performance considerations on any GPU.
Strides: The General Case
The concept of strides generalizes memory layout. A stride tells you how many elements to skip in memory when you increment one index by one.
For a row-major matrix of shape (R, C):
- Stride along dimension 0 (rows): C (skip one whole row of C elements)
- Stride along dimension 1 (columns): 1 (elements are adjacent)
For a row-major 3D tensor of shape (D0, D1, D2):
- Stride along dimension 0: D1 * D2
- Stride along dimension 1: D2
- Stride along dimension 2: 1
The general address formula for any tensor:
address(i0, i1, ..., in) = base + sum(ik * stride_k) for k in 0..n
Here is a concrete example. For a tensor of shape (2, 3, 4) in row-major:
Strides: (12, 4, 1)
Because:
stride[2] = 1
stride[1] = shape[2] = 4
stride[0] = shape[1] * shape[2] = 3 * 4 = 12
Element T[1][2][3]:
address = 0 + 1*12 + 2*4 + 3*1 = 12 + 8 + 3 = 23
Let us verify by counting:
Slice 0: [[ 0 1 2 3] elements 0-11
[ 4 5 6 7]
[ 8 9 10 11]]
Slice 1: [[12 13 14 15] elements 12-23
[16 17 18 19]
[20 21 22 23]] <-- T[1][2][3] = element 23. Correct!
Strides also let you express interesting things like transposition without copying data. If you have a matrix with strides (C, 1), its transpose has strides (1, C) – same data, just different stride interpretation.
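The stride machinery fits in a few lines of Python, which also demonstrates the zero-copy transpose trick (helper names are ours, for illustration):

```python
def row_major_strides(shape):
    # stride[k] = product of all dimensions to the right of k.
    strides = [1] * len(shape)
    for k in range(len(shape) - 2, -1, -1):
        strides[k] = strides[k + 1] * shape[k + 1]
    return tuple(strides)

def flat_index(idx, strides):
    return sum(i * s for i, s in zip(idx, strides))

shape = (2, 3, 4)
strides = row_major_strides(shape)
print(strides)                           # (12, 4, 1)
print(flat_index((1, 2, 3), strides))    # 23, matching the count above

# Transpose without copying: reverse shape and strides together.
t_shape, t_strides = shape[::-1], strides[::-1]
print(flat_index((3, 2, 1), t_strides))  # 23 -- same memory element
```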
How Tensors Map to GPU Buffers
On Metal, a tensor lives in an MTLBuffer – a contiguous block of GPU-accessible memory. There is no built-in “tensor” type in Metal. The buffer is just bytes. Your shader code interprets those bytes as a tensor by computing offsets from indices using the stride formula.
The Buffer Layout
MTLBuffer (contiguous memory):
+----+----+----+----+----+----+----+----+----+----+----+----+
| e0 | e1 | e2 | e3 | e4 | e5 | e6 | e7 | e8 | e9 |e10 |e11 |
+----+----+----+----+----+----+----+----+----+----+----+----+
^ ^
buffer.contents() buffer end
Each element is sizeof(element_type) bytes:
float = 4 bytes
half = 2 bytes
int8_t = 1 byte
For a half-precision (2, 3, 2) tensor:
Total elements = 2 * 3 * 2 = 12
Total bytes = 12 * 2 = 24 bytes
Computing Offsets in a Metal Shader
In a Metal compute shader, you typically receive the buffer pointer and the tensor’s metadata (shape, strides) as arguments. Here is what the pattern looks like:
kernel void elementwise_relu(
    device const half* input [[buffer(0)]],
    device half* output [[buffer(1)]],
    constant uint3& shape [[buffer(2)]],    // (D0, D1, D2)
    constant uint3& strides [[buffer(3)]],  // (S0, S1, S2)
    uint3 gid [[thread_position_in_grid]]
) {
    // Bounds check
    if (gid.x >= shape.x || gid.y >= shape.y || gid.z >= shape.z) return;

    // Compute flat index from multi-dimensional position
    uint idx = gid.x * strides.x + gid.y * strides.y + gid.z * strides.z;

    // Read, compute, write
    half val = input[idx];
    output[idx] = val > 0.0h ? val : 0.0h;
}
The key insight: the GPU hardware knows nothing about tensors. It only knows about flat memory addresses. All the multi-dimensional indexing is just arithmetic that we do in the shader.
Alignment Considerations
Metal has specific alignment requirements for buffer access:
Type Size Alignment
float 4B 4B
half 2B 2B
float4 16B 16B
half4 8B 8B
Rule of thumb: access addresses that are multiples of the type size.
When packing tensor data into buffers, you generally want each row to start at an aligned address. For a matrix of half values with 127 columns, each row takes 254 bytes. The next row starts at byte 254, which is 2-byte aligned (fine for half). But if you wanted to read rows as half4 vectors (common optimization), you would want 8-byte alignment, which means padding rows to 128 elements (256 bytes).
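The padding arithmetic is simple enough to sketch (hypothetical helper, Python for illustration):

```python
def padded_row_elems(cols, vec_width=4):
    # Round the row length up to a multiple of the vector width so every
    # row starts at an address aligned for vectorized (e.g. half4) loads.
    return (cols + vec_width - 1) // vec_width * vec_width

cols, elem_size = 127, 2           # half = 2 bytes
padded = padded_row_elems(cols)
print(padded, padded * elem_size)  # 128 256 -- each row now 8-byte aligned
```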
Multiple Tensors in One Buffer (Offset Packing)
In practice, inference engines often pack multiple tensors into a single large buffer with offsets:
Single MTLBuffer:
+------------------+------------------+------------------+
| Weight matrix | Bias vector | LayerNorm scale |
| (4096 x 4096) | (4096) | (4096) |
| offset: 0 | offset: 33MB | offset: 33MB+8K |
+------------------+------------------+------------------+
This reduces the number of buffer bindings needed (Metal has a limit on how many buffers you can bind to a single kernel invocation) and can improve memory allocation efficiency.
Common Tensor Operations
Now let us look at the operations that transformers actually perform on tensors and how they parallelize on a GPU.
Element-wise Operations
These apply a function independently to each element. The output has the same shape as the input. Examples:
- ReLU: output[i] = max(0, input[i])
- SiLU (Swish): output[i] = input[i] * sigmoid(input[i])
- GELU: output[i] = input[i] * 0.5 * (1 + erf(input[i] / sqrt(2)))
- Addition: output[i] = a[i] + b[i]
- Scalar multiply: output[i] = alpha * input[i]
Parallelization is trivial – assign one thread per element:
Input tensor: [a0, a1, a2, a3, a4, a5, a6, a7]
Thread 0 --> processes a0
Thread 1 --> processes a1
Thread 2 --> processes a2
...
Thread 7 --> processes a7
Each thread:
1. Read input[thread_id]
2. Apply function
3. Write output[thread_id]
No synchronization needed! Each thread is independent.
For a tensor with N elements, you dispatch ceil(N / threads_per_threadgroup) threadgroups, each with threads_per_threadgroup threads (commonly 256 or 1024).
Dispatch for N = 10000, threads_per_group = 256:
Threadgroups needed = ceil(10000 / 256) = 40 threadgroups
Threadgroup 0: threads 0-255 --> elements 0-255
Threadgroup 1: threads 256-511 --> elements 256-511
...
Threadgroup 38: threads 9728-9983 --> elements 9728-9983
Threadgroup 39: threads 9984-10239 --> elements 9984-9999 (rest out of bounds)
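The same dispatch arithmetic, as a quick Python check (the function is a stand-in for what host-side dispatch code computes):

```python
import math

def dispatch(n_elems, group_size=256):
    groups = math.ceil(n_elems / group_size)
    # Threads in the last group past n_elems rely on the bounds check.
    tail_valid = n_elems - (groups - 1) * group_size
    return groups, tail_valid

groups, tail = dispatch(10000)
print(groups, tail)  # 40 16 -- 40 threadgroups, 16 valid threads in the last
```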
Reduction Operations
Reductions collapse one or more dimensions by aggregating elements. Examples:
- Sum along a dimension: output[i] = sum(input[i][j] for all j)
- Max along a dimension: output[i] = max(input[i][j] for all j)
- Mean: sum divided by count
These are trickier to parallelize because threads need to cooperate to produce the result.
Strategy 1: One threadgroup per reduction
Matrix shape (4, 8), reduce along columns (dim 1):
Row 0: [a0 a1 a2 a3 a4 a5 a6 a7] --> sum = a0+a1+...+a7
Row 1: [b0 b1 b2 b3 b4 b5 b6 b7] --> sum = b0+b1+...+b7
Row 2: [c0 c1 c2 c3 c4 c5 c6 c7] --> sum = c0+c1+...+c7
Row 3: [d0 d1 d2 d3 d4 d5 d6 d7] --> sum = d0+d1+...+d7
Assign one threadgroup per row.
Within each threadgroup, threads cooperate to sum the row:
Threadgroup 0 (8 threads for row 0):
Step 1: Each thread loads one element
t0=a0, t1=a1, t2=a2, t3=a3, t4=a4, t5=a5, t6=a6, t7=a7
Step 2: Parallel reduction tree
t0 = a0+a4 t1 = a1+a5 t2 = a2+a6 t3 = a3+a7
t0 = (a0+a4)+(a2+a6) t1 = (a1+a5)+(a3+a7)
t0 = total sum
Step 3: Thread 0 writes result
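A plain-Python simulation of this reduction tree (assuming a power-of-two element count, as real kernels typically arrange by padding):

```python
def threadgroup_sum(values):
    # At each step, thread t adds the value held by thread t + stride,
    # halving the number of active threads until thread 0 holds the total.
    vals = list(values)
    stride = len(vals) // 2
    while stride > 0:
        for t in range(stride):
            vals[t] += vals[t + stride]
        stride //= 2
    return vals[0]

row = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
print(threadgroup_sum(row))  # 36.0
```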
Strategy 2: SIMD group reduction
On Metal, SIMD groups of 32 threads have special fast operations for reduction. The simd_sum() function sums a value across all threads in the SIMD group with a single built-in call, using dedicated lane-to-lane hardware rather than a round trip through threadgroup memory:
SIMD group reduction (32 threads):
Before simd_sum:
t0=3.0 t1=1.0 t2=4.0 t3=1.0 ... t31=2.0
After val = simd_sum(val):
t0=sum t1=sum t2=sum t3=sum ... t31=sum
(all threads have the same total sum)
For reductions over dimensions larger than 32, you use a combination: each SIMD group reduces its portion, then results are combined using threadgroup memory.
Broadcasting
Broadcasting lets you operate on tensors of different shapes by “stretching” the smaller tensor to match the larger one, conceptually. No data is actually copied.
Example: Add a bias vector to every row of a matrix
Matrix A shape: (1024, 512)
Bias b shape: (512,)
Result C[i][j] = A[i][j] + b[j] for all i, j
The bias is "broadcast" along dimension 0:
A: b: C:
[[a00 a01 ... a0,511] [b0 b1 ... b511] [[a00+b0 a01+b1 ...]
[a10 a11 ... a1,511] + = [a10+b0 a11+b1 ...]
... ...
[a1023,0 ... ]] [a1023,0+b0 ... ]]
In the shader, broadcasting is implemented by simply not advancing the index along the broadcast dimension:
// Adding bias (shape: [N]) to matrix (shape: [M, N])
kernel void add_bias(
    device const half* matrix [[buffer(0)]],
    device const half* bias [[buffer(1)]],
    device half* output [[buffer(2)]],
    constant uint2& dims [[buffer(3)]],   // (M, N)
    uint2 gid [[thread_position_in_grid]] // (row, col)
) {
    uint row = gid.x;
    uint col = gid.y;
    if (row >= dims.x || col >= dims.y) return;
    uint idx = row * dims.y + col;
    output[idx] = matrix[idx] + bias[col]; // bias indexed only by col
}
The bias[col] access ignores the row entirely. That is broadcasting – the same bias element b[j] is used for every row.
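NumPy implements exactly these broadcasting semantics, so it makes a convenient host-side check of what the kernel computes (stand-in values for illustration):

```python
import numpy as np

A = np.ones((1024, 512), dtype=np.float32)  # stand-in matrix
b = np.arange(512, dtype=np.float32)        # stand-in bias

C = A + b   # b is broadcast along dimension 0: no copy of b is made

print(C.shape)              # (1024, 512)
print(C[0, 3], C[1023, 3])  # 4.0 4.0 -- same bias element in every row
```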
Parallelizing Tensor Operations
The general strategies for mapping tensor operations to GPU threads fall into a few patterns. Let us formalize them.
Pattern 1: One Thread Per Element
Used for element-wise operations.
Tensor shape: (M, N)
Grid: (M, N)
Each thread (i, j) processes element [i][j]
+-------+-------+-------+-------+
| t(0,0)| t(0,1)| t(0,2)| t(0,3)|
+-------+-------+-------+-------+
| t(1,0)| t(1,1)| t(1,2)| t(1,3)|
+-------+-------+-------+-------+
| t(2,0)| t(2,1)| t(2,2)| t(2,3)|
+-------+-------+-------+-------+
Dispatch: grid_size = (M, N, 1)
threadgroup_size = (16, 16, 1) // 256 threads per group
Pattern 2: One Threadgroup Per Row (or Column)
Used for reductions, softmax, layer normalization – anything that operates across one dimension.
Tensor shape: (M, N)
Grid: M threadgroups, each with T threads
Threadgroup 0 --> Row 0: [e0 e1 e2 ... eN-1]
Threadgroup 1 --> Row 1: [e0 e1 e2 ... eN-1]
...
Threadgroup M-1 --> Row M-1
Within each threadgroup, T threads cooperate:
Thread t handles elements: t, t+T, t+2T, ...
For N=1024, T=256:
Thread 0: elements 0, 256, 512, 768
Thread 1: elements 1, 257, 513, 769
...
Pattern 3: Tiled Processing
Used for matrix multiplication and attention. Threads in a threadgroup cooperatively load tiles into fast threadgroup memory, then compute.
Matrix multiply C = A * B
A: (M, K), B: (K, N), C: (M, N)
Tile size: (Tm, Tn) per threadgroup, iterate over K in tiles of Tk
Threadgroup (gi, gj) computes C[gi*Tm..(gi+1)*Tm][gj*Tn..(gj+1)*Tn]
B (K x N)
+---+---+---+---+
| | Bj| | | Bj = tile column j
+---+---+---+---+
| | Bj| | |
A +---+---+---+---+
(M x K)
+--+--+ +---+
| |Ai| x | Bj| = Cij (Tm x Tn tile of output)
+--+--+ +---+
| | |
+--+--+
Ai = tile row i of A
For each k-tile: load Ai[:,k:k+Tk] and Bj[k:k+Tk,:] into shared memory
compute partial products
accumulate into Cij registers
We will explore tiled matrix multiplication in enormous detail in the next chapter.
Pattern 4: SIMD Group per Work Unit
Metal’s SIMD groups (wavefronts) of 32 threads have hardware support for fast intra-group communication. Many kernels assign one SIMD group to one logical unit of work:
GEMV: One SIMD group per output element (or a few elements)
Output vector y = W * x (W is M x K, x is K x 1, y is M x 1)
SIMD group 0 --> y[0] = dot(W[0,:], x)
SIMD group 1 --> y[1] = dot(W[1,:], x)
...
Within each SIMD group:
32 threads divide the K-length dot product
Thread t handles indices: t, t+32, t+64, ...
Each thread accumulates a partial sum
simd_sum() gives the total
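Modeling this in Python (with simd_sum() replaced by an ordinary sum over the per-lane partials) shows the strided work split:

```python
def simd_dot(w_row, x, simd_width=32):
    # Lane t accumulates elements t, t+32, t+64, ... of the dot product;
    # the final sum over partials stands in for simd_sum().
    partials = [0.0] * simd_width
    for t in range(simd_width):
        for i in range(t, len(x), simd_width):
            partials[t] += w_row[i] * x[i]
    return sum(partials)

K = 128
w = [1.0] * K
x = [2.0] * K
print(simd_dot(w, x))  # 256.0
```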
The Transformer’s Core Operations
A transformer model is built from a surprisingly small set of tensor operations, repeated many times. Let us catalog them and understand their computational characteristics.
Linear Projection (GEMM / GEMV)
The workhorse of the transformer. Every layer has multiple linear projections:
In one transformer block:
Q = X * Wq + bq (query projection)
K = X * Wk + bk (key projection)
V = X * Wv + bv (value projection)
O = Attn * Wo + bo (output projection)
H = X * W1 + b1 (FFN first layer)
Y = H * W2 + b2 (FFN second layer)
That's 6 matrix multiplications per layer!
A 32-layer model does 192 matmuls per forward pass.
The computation is Y = X * W where:
- During prefill: X is (seq_len, d_model), W is (d_model, d_out) – a full GEMM
- During decode: X is (1, d_model), W is (d_model, d_out) – a GEMV
Prefill (GEMM): Decode (GEMV):
X (128 x 4096) x (1 x 4096)
+----------+ +----------+
| | | |
| | W (4096 x 4096) +----------+ W (4096 x 4096)
| | +----------+ +----------+
| | x | | x | |
| | | | | |
| | | | | |
+----------+ +----------+ +----------+
= =
Y (128 x 4096) y (1 x 4096)
+----------+ +----------+
| | | |
| | +----------+
| |
+----------+
The GEMV case (decode) is memory-bandwidth bound because you read the entire weight matrix just to produce one output row. The GEMM case (prefill) is compute bound because you amortize the weight read across many input rows.
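This compute-vs-bandwidth difference comes down to arithmetic intensity – FLOPs per byte moved. A rough model (our own helper, counting one read of each operand and one write of the result):

```python
def gemm_intensity(M, K, N, elem_size=2):
    # 2*M*K*N FLOPs: one multiply and one add per (m, k, n) triple.
    flops = 2 * M * K * N
    # Bytes: read X (M*K) and W (K*N), write Y (M*N), elem_size bytes each.
    bytes_moved = (M * K + K * N + M * N) * elem_size
    return flops / bytes_moved

d = 4096
print(gemm_intensity(128, d, d))  # prefill: ~120 FLOPs per byte moved
print(gemm_intensity(1, d, d))    # decode: ~1 FLOP per byte -- bandwidth bound
```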
Attention (Batched Dot Products + Softmax)
The attention mechanism computes, for each query position, a weighted average of all value positions:
scores = Q * K^T / sqrt(d_k) Matrix multiply: (seq, d_k) x (d_k, seq) = (seq, seq)
weights = softmax(scores) Element-wise + reduction
output = weights * V Matrix multiply: (seq, seq) x (seq, d_k) = (seq, d_k)
The middle step (softmax) is a reduction along the last dimension of each row, which requires special handling. We will devote Chapters 15 and 16 to attention.
Layer Normalization (RMSNorm)
Modern transformers use RMSNorm, which normalizes each row by its root-mean-square:
RMSNorm(x) = x / sqrt(mean(x^2) + eps) * gamma
For a row x of length d_model:
1. Compute sum of squares: ss = sum(x[i]^2 for i in 0..d_model)
2. Compute RMS: rms = sqrt(ss / d_model + eps)
3. Normalize and scale: output[i] = x[i] / rms * gamma[i]
Step 1 is a reduction (sum).
Step 3 is element-wise.
This follows the “one threadgroup per row” pattern. The reduction in step 1 uses SIMD group sums.
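A NumPy reference implementation is handy for validating a Metal kernel's output (names and eps follow the formula above; this is a host-side sketch, not the GPU code):

```python
import numpy as np

def rmsnorm(x, gamma, eps=1e-5):
    # Steps 1-2: reduction to the row's root-mean-square.
    rms = np.sqrt(np.mean(x.astype(np.float32) ** 2) + eps)
    # Step 3: element-wise normalize and scale.
    return x / rms * gamma

x = np.array([3.0, 4.0], dtype=np.float32)
gamma = np.ones(2, dtype=np.float32)
print(rmsnorm(x, gamma))  # ~[0.849, 1.131]; RMS of [3, 4] is ~3.536
```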
Activation Functions
Applied element-wise between the FFN layers:
SiLU(x) = x * sigmoid(x) = x * (1 / (1 + exp(-x)))
GELU(x) = x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
These are purely element-wise, one thread per element. Trivially parallel.
Residual Connections
Just element-wise addition:
output = sublayer_output + input
Every transformer block has two residual connections:
x = attention(norm(x)) + x <-- residual
x = ffn(norm(x)) + x <-- residual
Putting It All Together: One Transformer Block
Input: x (seq_len, d_model)
Step 1: RMSNorm --> x_norm (seq_len, d_model) [reduction per row]
Step 2: Q = x_norm * Wq --> Q (seq_len, d_model) [GEMM or GEMV]
Step 3: K = x_norm * Wk --> K (seq_len, d_kv) [GEMM or GEMV]
Step 4: V = x_norm * Wv --> V (seq_len, d_kv) [GEMM or GEMV]
Step 5: Reshape Q/K/V for multi-head
Step 6: Attention --> attn (seq_len, d_model) [batched matmul + softmax]
Step 7: O = attn * Wo --> O (seq_len, d_model) [GEMM or GEMV]
Step 8: residual --> x = O + x [element-wise add]
Step 9: RMSNorm --> x_norm [reduction per row]
Step 10: FFN up --> h (seq_len, d_ff) [GEMM or GEMV]
Step 11: FFN gate --> g (seq_len, d_ff) [GEMM or GEMV]
Step 12: SiLU + multiply --> h = SiLU(g) * h [element-wise]
Step 13: FFN down --> f (seq_len, d_model) [GEMM or GEMV]
Step 14: residual --> x = f + x [element-wise add]
Output: x (seq_len, d_model)
Count the operations: 6 matrix multiplications, 2 normalizations, 2 residual adds, 1 activation, and the attention computation. The matrix multiplications dominate compute time (90%+ of total FLOPs).
Floating-Point Precision: FP16, FP32, and BF16 [1]
The choice of numeric precision has enormous impact on both performance and quality.
FP32 (Single Precision)
FP32: 1 sign bit + 8 exponent bits + 23 mantissa bits = 32 bits
+---+----------+-----------------------+
| S | EEEEEEEE | MMMMMMMMMMMMMMMMMMMMMMM|
+---+----------+-----------------------+
1 8 23
Range: ~1.18e-38 to ~3.4e+38
Precision: ~7 decimal digits
Size: 4 bytes per value
FP32 is the “safe” choice. It has plenty of range and precision for any ML computation. But it uses twice the memory and bandwidth of FP16, making it 2x slower for memory-bound operations.
FP16 (Half Precision)
FP16: 1 sign bit + 5 exponent bits + 10 mantissa bits = 16 bits
+---+-------+------------+
| S | EEEEE | MMMMMMMMMM |
+---+-------+------------+
1 5 10
Range: ~6.1e-5 to ~65504
Precision: ~3.3 decimal digits
Size: 2 bytes per value
FP16 halves memory usage and doubles effective bandwidth. Apple GPUs have native FP16 support and can process FP16 values at 2x the rate of FP32 in many operations.
The downside is the limited range. Values smaller than 6.1e-5 become zero (underflow), and values larger than 65504 become infinity (overflow). During inference, weights are usually small enough that this is not a problem. But intermediate computations (especially in attention) can overflow if not handled carefully.
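NumPy's float16 follows the same IEEE half-precision format, so the limits are easy to demonstrate:

```python
import numpy as np

print(np.float16(65504.0))  # 65504.0 -- largest finite half value
print(np.float16(70000.0))  # inf     -- overflow past 65504
print(np.float16(1e-8))     # 0.0     -- underflow below the subnormal range
```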
BF16 (Brain Float 16)
BF16: 1 sign bit + 8 exponent bits + 7 mantissa bits = 16 bits
+---+----------+---------+
| S | EEEEEEEE | MMMMMMM |
+---+----------+---------+
1 8 7
Range: ~1.18e-38 to ~3.4e+38 (same as FP32!)
Precision: ~2.4 decimal digits
Size: 2 bytes per value
BF16 has the same exponent range as FP32 (so no overflow/underflow issues) but less precision than FP16. It was designed specifically for ML workloads where range matters more than precision.
Apple Silicon support: M1/M2 GPUs do not have native BF16 support in Metal. M3+ has some BF16 capabilities. In practice, most Apple GPU inference uses FP16 for computation with FP32 for accumulation.
The Mixed-Precision Strategy
The standard approach in inference:
Weights: stored in FP16 (or quantized, e.g., 4-bit)
Activations: computed in FP16
Accumulations: done in FP32 (dot products, reductions)
Final output: cast back to FP16
Why FP32 for accumulations?
Consider summing 4096 FP16 values, each around 0.01:
True sum: ~40.96
FP16 sum: once the running total passes 32, the gap between adjacent
          representable FP16 values (1/32) exceeds each 0.01 addend,
          so new additions round away entirely and the sum stalls
          far short of the true value.
FP32 sum: no problem. 7 decimal digits of precision is plenty.
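The effect is easy to reproduce with NumPy's float16 (a host-side demonstration of the same rounding the GPU would perform):

```python
import numpy as np

addend = np.float16(0.01)

s16 = np.float16(0.0)
for _ in range(4096):
    s16 = np.float16(s16 + addend)  # every partial sum rounds back to FP16

s32 = np.float32(0.0)
for _ in range(4096):
    s32 += np.float32(addend)       # FP32 accumulator, same FP16 inputs

print(s16)  # stalls around 32 -- additions below the FP16 spacing vanish
print(s32)  # ~40.97, close to the true sum
```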
In Metal shader code, this looks like:
// Dot product with mixed precision
float sum = 0.0f; // FP32 accumulator
for (uint i = 0; i < K; i++) {
    sum += float(A[i]) * float(B[i]); // Cast FP16 inputs to FP32
}
output[idx] = half(sum); // Cast result back to FP16
Precision Comparison Table
+--------+------+--------+-----------+-------------+------------------+
| Format | Bits | Range | Precision | Memory/elem | Apple GPU Native |
+--------+------+--------+-----------+-------------+------------------+
| FP32 | 32 | 1e38 | ~7 digits | 4 bytes | Yes |
| FP16 | 16 | 65504 | ~3 digits | 2 bytes | Yes (2x rate) |
| BF16 | 16 | 1e38 | ~2 digits | 2 bytes | M3+ partial |
| INT8 | 8 | -128 | exact | 1 byte | Yes |
| | | to 127 | integers | | |
| INT4 | 4 | -8 | exact | 0.5 bytes | No (emulated) |
| | | to 7 | integers | | |
+--------+------+--------+-----------+-------------+------------------+
Fused Operations: Saving Bandwidth by Combining Kernels
Here is one of the most important optimization ideas in GPU programming: kernel fusion. [2]
The Problem: Bandwidth Waste
Consider computing SiLU activation followed by element-wise multiply (as in the LLaMA FFN):
Step 1: gate = SiLU(x * Wg)
Step 2: up = x * Wu
Step 3: hidden = gate * up
If each step is a separate kernel:
Kernel 1 (SiLU):
Read gate_raw from global memory (N elements read)
Compute SiLU
Write gate to global memory (N elements written)
Kernel 2 (element-wise multiply):
Read gate from global memory (N elements read - AGAIN!)
Read up from global memory (N elements read)
Compute gate * up
Write hidden to global memory (N elements written)
Total memory traffic: 2N reads + N writes + 2N reads + N writes = 4N reads + 2N writes
The Solution: Fused Kernel
Combine both operations into a single kernel:
Fused kernel (SiLU + multiply):
Read gate_raw from global memory (N elements read)
Read up from global memory (N elements read)
Compute SiLU(gate_raw) * up in registers
Write hidden to global memory (N elements written)
Total memory traffic: 2N reads + N writes
Savings: 2N reads + N writes eliminated!
The fused kernel avoids the round trip through global memory between operations. The intermediate value (the SiLU output) lives only in registers, which are essentially free to access.
Common Fused Operations in Transformers
+------------------------------------+---------------------------------+
| Fused Operation | What It Combines |
+------------------------------------+---------------------------------+
| Fused SiLU gate | SiLU(x) * y |
| Fused RMSNorm | norm + scale (gamma multiply) |
| Fused attention | QK^T + scale + mask + softmax |
| (FlashAttention) | + V multiply (Ch 16) |
| Fused dequant + matmul | Dequantize weights + GEMV/GEMM |
| Fused residual + norm | x + sublayer + RMSNorm |
| Fused RoPE + reshape | Rotary embedding + head reshape |
+------------------------------------+---------------------------------+
When Fusion Matters and When It Does Not
Fusion helps MOST when:
- Operations are memory-bandwidth bound (small compute per element)
- Intermediate tensors are large
- Operations are sequential (output of one feeds input of next)
Fusion helps LEAST when:
- Operations are compute bound (GEMM with large matrices)
- Intermediate tensors are small (fit in cache anyway)
- Operations have different parallelism patterns
(e.g., can't easily fuse a reduction with an element-wise op)
The Bandwidth Equation
To understand why fusion matters, consider the numbers for an Apple M2 GPU:
M2 GPU specs:
Memory bandwidth: ~100 GB/s
Compute (FP16): ~3.6 TFLOPS
For a SiLU activation on 4096 FP16 elements:
Data to read: 4096 * 2 bytes = 8 KB
Data to write: 4096 * 2 bytes = 8 KB
Compute: 4096 * ~5 FLOPs (exp, add, div, mul) = ~20 KFLOPs
Time limited by bandwidth: 16 KB / 100 GB/s = 0.16 microseconds
Time limited by compute: 20 KFLOPs / 3.6 TFLOPS = 0.006 microseconds
This operation is ~27x more bandwidth bound than compute bound!
Every byte you can avoid reading from or writing to global memory is a direct performance win for bandwidth-bound operations. And in transformers, the majority of non-GEMM operations are bandwidth bound.
A Complete Example: Processing One Token
Let us trace the memory traffic for processing a single token through one LLaMA-7B layer, showing where fusion helps.
Model parameters (LLaMA 7B):
d_model = 4096
n_heads = 32
d_k = 128
d_ff = 11008 (intermediate FFN dimension)
Weights in FP16
Step 1: RMSNorm
Read: x (4096 * 2B = 8KB) + gamma (8KB) = 16KB
Write: x_norm (8KB)
Total: 24KB
Step 2: Q projection (GEMV: 1 x 4096 times 4096 x 4096)
Read: x_norm (8KB) + Wq (4096*4096*2B = 32MB)
Write: Q (8KB)
Total: ~32MB <-- Weight read dominates!
Step 3: K projection
Read: x_norm (8KB) + Wk (4096*128*32*2B = ... depends on GQA)
... similar to Q
Step 4-6: V projection, attention, output projection
... similar pattern
Step 7-8: FFN up + gate + SiLU + down
Read: Wu (4096*11008*2B = 86MB) + Wg (86MB) + Wd (86MB)
Total: ~258MB of weight reads for FFN alone
Grand total weight reads per layer: ~400MB
For 32 layers: ~12.8GB of weight reads PER TOKEN
At 100 GB/s bandwidth: 12.8GB / 100 GB/s = 128ms per token = ~8 tokens/sec
This is why:
- Quantization is essential – 4-bit weights cut reads by 4x.
- Fusion matters for the non-GEMM ops – saving even a few MB per layer adds up.
- Decode is bandwidth-bound – the GPU spends most of its time waiting for memory.
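The back-of-the-envelope bound above can be captured in one line (numbers from this example; real throughput will be lower due to activations, KV cache traffic, and imperfect bandwidth utilization):

```python
def decode_tokens_per_sec(weight_bytes_per_token, bandwidth_bytes_per_sec):
    # Every weight byte must cross the memory bus once per generated token,
    # so bandwidth / bytes-per-token is an upper bound on decode speed.
    return bandwidth_bytes_per_sec / weight_bytes_per_token

print(decode_tokens_per_sec(12.8e9, 100e9))  # 7.8125 tokens/sec at best
```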
Summary
In this chapter, we have established the foundations:
- Tensors are multi-dimensional arrays. In ML, we work with rank-2 to rank-4 tensors constantly.
- Memory layout (row-major with strides) determines how multi-dimensional indices map to flat GPU buffer addresses.
- Coalesced access – adjacent threads reading adjacent memory – is critical for GPU performance.
- Common operations include element-wise (trivially parallel), reductions (require thread cooperation), and matrix multiplications (require tiling).
- Mixed precision (FP16 values, FP32 accumulation) gives us the best of both worlds.
- Kernel fusion eliminates wasteful memory round-trips between operations.
The next chapter zooms in on the single most important operation: matrix multiplication. It consumes 90%+ of inference compute, and getting it right is the difference between a usable inference engine and a toy.
[1] Apple. “Metal Shading Language Specification, v3.1.” developer.apple.com. Covers the half type, vector types, and precision guarantees for FP16 arithmetic on Apple GPUs. See https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf.
[2] Micikevicius, P., et al. “Mixed Precision Training.” ICLR 2018. Establishes the practice of using FP16 for computation with FP32 accumulation to maintain numerical stability. See https://arxiv.org/abs/1710.03740.