Threadgroups, SIMD Groups, and Dispatch
So far we have talked about individual threads and what they can do. But a single GPU thread is weak – it is only useful because there are thousands of them running simultaneously. The way you organize those thousands of threads matters enormously for performance. Get the geometry wrong and you leave half the GPU idle. Get it right and your kernel saturates every execution unit.
This chapter is about the organizational hierarchy of GPU threads in Metal, and how to dispatch them effectively.
The Thread Hierarchy
Metal organizes threads into a three-level hierarchy:
Grid (all threads in the dispatch)
+------------------------------------------------------------------+
| |
| Threadgroup 0 Threadgroup 1 Threadgroup 2 |
| +-----------------+ +-----------------+ +-----------------+ |
| | SIMD Group 0 | | SIMD Group 0 | | SIMD Group 0 | |
| | [32 threads] | | [32 threads] | | [32 threads] | |
| +-----------------+ +-----------------+ +-----------------+ |
| | SIMD Group 1 | | SIMD Group 1 | | SIMD Group 1 | |
| | [32 threads] | | [32 threads] | | [32 threads] | |
| +-----------------+ +-----------------+ +-----------------+ |
| | SIMD Group 2 | | SIMD Group 2 | | SIMD Group 2 | |
| | [32 threads] | | [32 threads] | | [32 threads] | |
| +-----------------+ +-----------------+ +-----------------+ |
| | SIMD Group 3 | | SIMD Group 3 | | SIMD Group 3 | |
| | [32 threads] | | [32 threads] | | [32 threads] | |
| +-----------------+ +-----------------+ +-----------------+ |
| |
+------------------------------------------------------------------+
Let us define each level.
SIMD Group (32 threads)
A SIMD group is 32 threads that execute in lockstep – the same instruction, at the same clock cycle. This is the fundamental unit of execution on Apple GPUs. You do not choose the SIMD group size; it is always 32.
Threads in a SIMD group can communicate through SIMD intrinsics (simd_sum, simd_shuffle, etc.) with zero overhead – no barriers, no shared memory, just direct register reads. This is the fastest form of inter-thread communication available.
Threadgroup (up to 1024 threads)
A threadgroup is a collection of SIMD groups that are co-scheduled on the same GPU compute unit. All threads in a threadgroup:
- Share access to threadgroup memory (fast on-chip SRAM)
- Can synchronize with threadgroup_barrier()
- Are guaranteed to execute concurrently (they are all resident on the same compute unit)
A threadgroup can contain up to 1024 threads, which means up to 32 SIMD groups. In practice, common sizes are 128 (4 SIMD groups), 256 (8 SIMD groups), or sometimes 512 or 1024.
Grid (the entire dispatch)
The grid is the collection of all threadgroups in a dispatch. Threadgroups in the grid are independent – there is no way for threads in different threadgroups to communicate during a single dispatch (except through device memory, but there are no cross-threadgroup barriers).
Two Ways to Dispatch: Threadgroups vs Threads
Metal gives you two dispatch methods, and the difference is subtle but important.
dispatchThreadgroups
You specify how many threadgroups to launch and how many threads per threadgroup:
// Launch a grid of threadgroups
let threadgroupsPerGrid = MTLSize(width: 16, height: 8, depth: 1)
let threadsPerThreadgroup = MTLSize(width: 256, height: 1, depth: 1)
encoder.dispatchThreadgroups(threadgroupsPerGrid,
                             threadsPerThreadgroup: threadsPerThreadgroup)
// Total threads = 16 * 8 * 256 = 32,768
With this method, you are responsible for computing the grid dimensions. If your problem size is not a perfect multiple of the threadgroup size, you need to handle the boundary condition in the kernel.
dispatchThreads
You specify the total number of threads you want, and Metal figures out the threadgroup count:
// Launch exactly this many threads
let threadsPerGrid = MTLSize(width: 4000, height: 3000, depth: 1)
let threadsPerThreadgroup = MTLSize(width: 16, height: 16, depth: 1)
encoder.dispatchThreads(threadsPerGrid,
                        threadsPerThreadgroup: threadsPerThreadgroup)
// Metal launches enough threadgroups to cover 4000 x 3000,
// and threads outside that range are automatically disabled
With dispatchThreads, Metal handles the boundary for you. If you request 4000 x 3000 threads with 16x16 threadgroups, Metal launches 250 x 188 threadgroups (250 * 16 = 4000, 188 * 16 = 3008), but the extra 8 rows of threads simply do not execute.
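The threadgroup counts Metal derives here are just ceiling divisions, which you can sanity-check on the CPU. A minimal C++ sketch (illustrative host-side math, not Metal API):

```cpp
#include <cassert>

// Ceiling division: how many threadgroups of size `tg` are needed
// to cover `n` threads along one dimension.
constexpr unsigned ceil_div(unsigned n, unsigned tg) {
    return (n + tg - 1) / tg;
}
```

For the 4000 x 3000 example with 16x16 threadgroups, `ceil_div(4000, 16)` gives 250 and `ceil_div(3000, 16)` gives 188, so the last row of threadgroups carries 3008 - 3000 = 8 rows of disabled threads.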
Which Should You Use?
For ML kernels, dispatchThreadgroups is almost always the right choice. Here is why:
- Explicit control: You know exactly how many threadgroups you are launching, which makes reasoning about resource usage straightforward.
- Kernel design: ML kernels are typically designed around threadgroup-level cooperation (tiled GEMM, reductions). The kernel code already assumes a specific threadgroup structure.
- No wasted threads: With dispatchThreadgroups, you design your kernel so that every thread does useful work. There are no “out-of-bounds” threads.
dispatchThreads is nice for simple element-wise kernels where you just want one thread per element and do not care about threadgroup structure. But for anything involving shared memory, reductions, or tiling, use dispatchThreadgroups.
Choosing Threadgroup Sizes
The threadgroup size is one of the most important performance decisions you make. Here are the rules:
Rule 1: Always Use Multiples of 32
Since SIMD groups are 32 threads, your threadgroup size should always be a multiple of 32. If you use, say, 100 threads per threadgroup, Metal will still allocate 4 SIMD groups (128 threads worth of resources), but 28 threads in the last SIMD group will be idle. That is 22% waste.
Threadgroup size = 100 (bad):
+----------------+ +----------------+ +----------------+ +----------------+
| SIMD Group 0 | | SIMD Group 1 | | SIMD Group 2 | | SIMD Group 3 |
| 32 active | | 32 active | | 32 active | | 4 active |
| 0 idle | | 0 idle | | 0 idle | | 28 IDLE |
+----------------+ +----------------+ +----------------+ +----------------+
^^^^^^^^^^^^
22% waste!
Threadgroup size = 128 (good):
+----------------+ +----------------+ +----------------+ +----------------+
| SIMD Group 0 | | SIMD Group 1 | | SIMD Group 2 | | SIMD Group 3 |
| 32 active | | 32 active | | 32 active | | 32 active |
| 0 idle | | 0 idle | | 0 idle | | 0 idle |
+----------------+ +----------------+ +----------------+ +----------------+
0% waste
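The waste figure above falls out of simple arithmetic: resources are allocated per whole SIMD group, so you pay for the next multiple of 32. A small C++ sketch of that calculation (helper names are mine):

```cpp
#include <cassert>

// SIMD groups allocated for a threadgroup of `size` threads (SIMD width 32).
unsigned simd_groups_allocated(unsigned size) {
    return (size + 31) / 32;
}

// Fraction of allocated lanes that sit idle for this threadgroup size.
double idle_fraction(unsigned size) {
    unsigned lanes = simd_groups_allocated(size) * 32;
    return double(lanes - size) / double(lanes);
}
```

A threadgroup size of 100 allocates 4 SIMD groups (128 lanes), leaving 28 lanes idle: 28/128 = 21.875%, the "22% waste" in the diagram. A size of 128 wastes nothing.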
Rule 2: Common Sizes and When to Use Them
| Threadgroup Size | SIMD Groups | Use Case |
|---|---|---|
| 32 | 1 | Simple per-element ops, tiny reductions |
| 64 | 2 | Light cooperation needed |
| 128 | 4 | Default for GEMM (4 SIMD groups = 4 tiles) |
| 256 | 8 | Default for reductions, normalization |
| 512 | 16 | Large reductions, high occupancy kernels |
| 1024 | 32 | Maximum, rarely needed |
Rule 3: Balance Occupancy and Resources
Here is the tension: larger threadgroups use more resources (registers, threadgroup memory), which means fewer threadgroups can be resident on a compute unit at the same time. Smaller threadgroups use fewer resources but may not have enough threads for efficient cooperation.
Compute Unit (simplified)
+----------------------------------------------------------+
| Register File: 64KB Threadgroup Memory: 32KB |
| |
| Option A: 4 threadgroups of 128 threads (512 threads) |
| Each TG gets: 16KB registers, 8KB shared memory |
| Good latency hiding (many threads to switch between) |
| |
| Option B: 2 threadgroups of 256 threads (512 threads) |
| Each TG gets: 32KB registers, 16KB shared memory |
| Same total threads, but more resources per TG |
| |
| Option C: 1 threadgroup of 1024 threads (1024 threads) |
| Gets all: 64KB registers, 32KB shared memory |
| Lots of threads, but if this TG stalls, nothing else |
| can run on this compute unit |
+----------------------------------------------------------+
The sweet spot for ML kernels is usually 128-256 threads. This gives you enough SIMD groups for cooperation (4-8) while leaving room for multiple threadgroups per compute unit (good for latency hiding).
Rule 4: Consult maxTotalThreadsPerThreadgroup
Every PSO has a maxTotalThreadsPerThreadgroup property that tells you the maximum threadgroup size the kernel supports, based on its resource usage:
let pso = try device.makeComputePipelineState(function: function)
print(pso.maxTotalThreadsPerThreadgroup) // e.g., 1024, or maybe 512
If your kernel uses a lot of registers or threadgroup memory, this number may be less than 1024. Always check.
SIMD Group Cooperation Patterns
Now let us look at how SIMD groups work together to solve problems. These patterns show up over and over in ML kernels.
Pattern 1: SIMD Reduction
The simplest and most common pattern. Each thread has a value, and you want the sum (or max, or min) across all threads in the SIMD group.
float my_value = data[tid];
float total = simd_sum(my_value);
// All 32 lanes now hold the same total
Lane: 0 1 2 3 ... 31
Val: 3.0 1.0 4.0 1.5 ... 2.7
|
simd_sum
|
total = 87.3 (in all lanes)
Pattern 2: SIMD Broadcast
One lane has a value that all other lanes need.
float value;
if (simd_lane == 0) {
    value = compute_something_expensive();
}
// Broadcast lane 0's value to all lanes
value = simd_broadcast_first(value);

// Or from a specific lane:
value = simd_shuffle(value, source_lane);
Pattern 3: SIMD Prefix Sum (Inclusive Scan)
Each lane gets the sum of all values from lane 0 through its own lane. Useful for compaction, histograms, and stream processing.
float val = data[tid];
float prefix = simd_prefix_inclusive_sum(val);
Lane: 0 1 2 3 4 5 ...
Input: 3 1 4 1 5 9 ...
Output: 3 4 8 9 14 23 ...
^ ^ ^
| | +-- 3+1+4
| +-- 3+1
+-- 3
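The semantics of an inclusive scan are easy to pin down with a CPU reference model. A sketch in plain C++ of what simd_prefix_inclusive_sum computes across one 32-lane SIMD group (the GPU does this in O(log n) shuffle steps, not a serial loop):

```cpp
#include <array>
#include <cassert>

// CPU reference for an inclusive prefix sum over one 32-lane SIMD group:
// lane i receives the sum of lanes 0 through i.
std::array<float, 32> prefix_inclusive(const std::array<float, 32>& lanes) {
    std::array<float, 32> out{};
    float running = 0.0f;
    for (int i = 0; i < 32; ++i) {
        running += lanes[i];
        out[i] = running;
    }
    return out;
}
```

Feeding in the lane values from the diagram (3, 1, 4, 1, 5, 9, ...) reproduces the outputs shown: lane 2 holds 8, lane 5 holds 23.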
Pattern 4: SIMD Shuffle for Data Reuse
When adjacent threads need overlapping data (like in convolution), shuffles avoid redundant memory loads:
// Each lane loads one element
float my_val = data[base + simd_lane];
// Now lane i can access neighboring values without memory loads:
float left = simd_shuffle_up(my_val, 1); // lane i gets lane (i-1)'s value
float right = simd_shuffle_down(my_val, 1); // lane i gets lane (i+1)'s value
// Stencil computation (e.g., 1D convolution with kernel [0.25, 0.5, 0.25])
float result = 0.25f * left + 0.5f * my_val + 0.25f * right;
Before shuffle:
Lane: 0 1 2 3 4 ...
Val: A B C D E ...
After simd_shuffle_down(val, 1):
Lane: 0 1 2 3 4 ...
Val: B C D E F ...
^ ^ ^
Each lane got the next lane's value
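A CPU model makes the shuffle semantics concrete. This C++ sketch mimics simd_shuffle_down for one SIMD group; note that in Metal the result for lanes whose source index falls outside the group is undefined, so here (as an assumption of this model) those lanes simply keep their own value:

```cpp
#include <array>
#include <cassert>

// CPU model of simd_shuffle_down(val, delta): lane i reads lane i + delta.
// Edge lanes (source out of range) keep their own value in this model;
// in Metal their result is undefined and kernels must not rely on it.
std::array<float, 32> shuffle_down(const std::array<float, 32>& lanes, int delta) {
    std::array<float, 32> out = lanes;
    for (int i = 0; i + delta < 32; ++i) {
        out[i] = lanes[i + delta];
    }
    return out;
}
```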
Pattern 5: Two-Stage Reduction (Cross-SIMD-Group)
When you need a reduction across the entire threadgroup (not just one SIMD group), you do it in two stages:
Stage 1: SIMD-level reduction (fast, no barriers)
+------------------+ +------------------+ +------------------+ +------------------+
| SIMD Group 0 | | SIMD Group 1 | | SIMD Group 2 | | SIMD Group 3 |
| 32 values | | 32 values | | 32 values | | 32 values |
| --> simd_sum | | --> simd_sum | | --> simd_sum | | --> simd_sum |
| Result: S0 | | Result: S1 | | Result: S2 | | Result: S3 |
+------------------+ +------------------+ +------------------+ +------------------+
Stage 2: Write partial sums to threadgroup memory, then reduce
+-------------------------------------------+
| threadgroup float partials[4]; |
| partials = [S0, S1, S2, S3] |
| |
| threadgroup_barrier(...) |
| |
| SIMD Group 0 reads all 4 values |
| --> simd_sum(partials[simd_lane]) |
| Result: S0 + S1 + S2 + S3 = TOTAL |
+-------------------------------------------+
Here is the code:
// Assume threadgroup size = 128 (4 SIMD groups)
threadgroup float partials[4];

// Stage 1: Each SIMD group reduces its 32 values
float local_sum = simd_sum(my_value);

// One thread per SIMD group writes to shared memory
if (simd_lane == 0) {
    partials[simd_id] = local_sum;
}
threadgroup_barrier(mem_flags::mem_threadgroup);

// Stage 2: First SIMD group reduces the partial sums
float total;
if (simd_id == 0 && simd_lane < 4) {
    total = simd_sum(partials[simd_lane]);
}

// Broadcast to all threads if needed
threadgroup float shared_total;
if (lid == 0) {
    shared_total = total;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
total = shared_total;
This two-stage pattern is the workhorse of ML kernels. You will find it in normalization (RMSNorm, LayerNorm), softmax, and anywhere else you need a global reduction within a threadgroup.
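The arithmetic of the two stages is worth verifying once on the CPU. A C++ model of the reduction for a 128-thread threadgroup (function name is mine; the real kernel does stage 1 with simd_sum and stage 2 through threadgroup memory):

```cpp
#include <array>
#include <cassert>
#include <numeric>

// CPU model of the two-stage reduction for a 128-thread threadgroup:
// stage 1 sums each 32-lane SIMD group, stage 2 sums the four partials.
float two_stage_sum(const std::array<float, 128>& values) {
    std::array<float, 4> partials{};
    for (int g = 0; g < 4; ++g) {
        // Stage 1: per-SIMD-group partial sums (simd_sum on the GPU)
        partials[g] = std::accumulate(values.begin() + g * 32,
                                      values.begin() + (g + 1) * 32, 0.0f);
    }
    // Stage 2: reduce the partials (done by SIMD group 0 on the GPU)
    return std::accumulate(partials.begin(), partials.end(), 0.0f);
}
```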
Thread Mapping: From Grid to Data
One of the most important design decisions is how you map threads to data. There are several common patterns.
Pattern A: One Thread Per Element
The simplest mapping. Each thread processes exactly one element:
kernel void scale(
    device half* data [[buffer(0)]],
    constant float& factor [[buffer(1)]],
    uint tid [[thread_position_in_grid]]
) {
    data[tid] = half(float(data[tid]) * factor);
}
Thread: 0 1 2 3 4 5 6 7 ...
Data: [0] [1] [2] [3] [4] [5] [6] [7] ...
| | | | | | | |
v v v v v v v v
Dispatch: one thread per element, trivial.
Pattern B: One Thread Per Vector (Vectorized)
Each thread processes a vector of elements:
kernel void scale_vec4(
    device half4* data [[buffer(0)]],
    constant float& factor [[buffer(1)]],
    uint tid [[thread_position_in_grid]]
) {
    half4 v = data[tid];
    data[tid] = half4(float4(v) * factor);
}
Thread: 0 1 2 3 ...
Data: [0,1,2,3] [4,5,6,7] [8,9,10,11] [12,13,14,15] ...
| | | |
v v v v
Dispatch: one thread per 4 elements. Total threads = N / 4.
Pattern C: Stride Loop (One Threadgroup Per Row)
Each threadgroup processes an entire row, with threads striding through the elements:
kernel void row_sum(
    device const float* matrix [[buffer(0)]],
    device float* sums [[buffer(1)]],
    constant uint& cols [[buffer(2)]],
    uint gid [[threadgroup_position_in_grid]],
    uint lid [[thread_position_in_threadgroup]]
) {
    uint row = gid;
    float sum = 0.0f;

    // Each thread strides through the row (stride = threadgroup size = 256)
    for (uint col = lid; col < cols; col += 256) {
        sum += matrix[row * cols + col];
    }

    // Reduce within threadgroup...
    sum = simd_sum(sum);
    // ... then combine the per-SIMD-group partials (two-stage reduction as before)

    if (lid == 0) {
        sums[row] = sum;
    }
}
Row 0 (threadgroup 0):
Thread 0 reads: [0], [256], [512], [768], ...
Thread 1 reads: [1], [257], [513], [769], ...
Thread 2 reads: [2], [258], [514], [770], ...
...
Thread 255 reads: [255], [511], [767], [1023], ...
Dispatch: one threadgroup per row.
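The stride loop's coverage pattern is easy to check on the CPU: across all 256 threads, every column is visited exactly once, and each thread touches columns 256 apart. A C++ sketch (helper name is mine):

```cpp
#include <cassert>
#include <vector>

// Columns visited by thread `lid` of a 256-thread threadgroup
// striding through a row of `cols` elements.
std::vector<unsigned> columns_for_thread(unsigned lid, unsigned cols) {
    std::vector<unsigned> visited;
    for (unsigned col = lid; col < cols; col += 256) {
        visited.push_back(col);
    }
    return visited;
}
```

For cols = 1024, thread 1 visits columns 1, 257, 513, 769 — exactly the pattern in the diagram. Consecutive threads read consecutive addresses on each iteration, which keeps the loads coalesced.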
Pattern D: 2D Tiling (GEMM)
For matrix multiplication, threads are organized in a 2D grid of tiles:
kernel void gemm_tiled(
    device const half* A [[buffer(0)]],
    device const half* B [[buffer(1)]],
    device half* C [[buffer(2)]],
    constant uint& M [[buffer(3)]],
    constant uint& N [[buffer(4)]],
    constant uint& K [[buffer(5)]],
    uint2 gid [[threadgroup_position_in_grid]],
    uint lid [[thread_position_in_threadgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]]
) {
    // Each threadgroup computes a TM x TN tile of C
    // gid.y = tile row, gid.x = tile column
    uint tile_row = gid.y;  // Which TM-row block
    uint tile_col = gid.x;  // Which TN-column block

    // Each SIMD group within the threadgroup computes a sub-tile
    // ...
}
Matrix C (M x N):
+--------+--------+--------+--------+
| TG(0,0)| TG(0,1)| TG(0,2)| TG(0,3)| <-- threadgroup row 0
| 32x64 | 32x64 | 32x64 | 32x64 |
+--------+--------+--------+--------+
| TG(1,0)| TG(1,1)| TG(1,2)| TG(1,3)| <-- threadgroup row 1
| 32x64 | 32x64 | 32x64 | 32x64 |
+--------+--------+--------+--------+
| TG(2,0)| TG(2,1)| TG(2,2)| TG(2,3)|
| 32x64 | 32x64 | 32x64 | 32x64 |
+--------+--------+--------+--------+
Each threadgroup computes a TM x TN tile of C (e.g., 32 x 64).
Grid dimensions: (N/TN, M/TM, 1) threadgroups.
This is the tiling pattern used in akunu’s GEMM kernels. We will cover it in detail in the SIMD matrix operations chapter.
Dispatch Geometry Calculation
Let us put this together with concrete dispatch calculations for common ML operations.
Element-wise Operations (ReLU, SiLU, Add)
Problem: Apply activation to N elements
Approach: One thread per element (or per half4)
Vectorized (half4):
threads_needed = N / 4
threadgroup_size = 256
threadgroups = ceil(threads_needed / 256)
Example: N = 4096
threads_needed = 1024
threadgroup_size = 256
threadgroups = 4
Grid: MTLSize(4, 1, 1)
TG: MTLSize(256, 1, 1)
Row-wise Operations (Softmax, RMSNorm)
Problem: Process M rows of N elements each
Approach: One threadgroup per row
threadgroup_size = 256
threadgroups = M
Example: M = 32 (batch), N = 4096 (hidden dim)
threadgroups = 32
Grid: MTLSize(32, 1, 1)
TG: MTLSize(256, 1, 1)
Each of the 256 threads in a threadgroup strides through
4096 / 256 = 16 elements.
Matrix Multiplication (GEMM)
Problem: C[M,N] = A[M,K] * B[K,N]
Approach: 2D grid, each threadgroup computes a TM x TN tile
TM = 32, TN = 64 (tile sizes)
threadgroup_size = 128 (4 SIMD groups)
threadgroups_x = ceil(N / TN)
threadgroups_y = ceil(M / TM)
Example: M = 4096, N = 4096, K = 4096
threadgroups_x = 4096 / 64 = 64
threadgroups_y = 4096 / 32 = 128
Grid: MTLSize(64, 128, 1)
TG: MTLSize(128, 1, 1)
Total threadgroups: 8192
Total threads: 8192 * 128 = 1,048,576
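These three geometries can be captured in a few lines of host-side arithmetic. A C++ sketch (the struct and helper names are mine, and the tile sizes are the example values above, not fixed constants):

```cpp
#include <cassert>

struct Geometry { unsigned x, y, z; };  // threadgroup counts per grid axis

// Element-wise, vectorized over half4: one thread per 4 elements.
Geometry elementwise_vec4(unsigned n, unsigned tg = 256) {
    return { (n / 4 + tg - 1) / tg, 1, 1 };
}

// Row-wise reduction: one threadgroup per row.
Geometry rowwise(unsigned rows) {
    return { rows, 1, 1 };
}

// Tiled GEMM: one threadgroup per TM x TN output tile.
Geometry gemm(unsigned M, unsigned N, unsigned TM = 32, unsigned TN = 64) {
    return { (N + TN - 1) / TN, (M + TM - 1) / TM, 1 };
}
```

Plugging in the numbers from the examples: `elementwise_vec4(4096)` gives 4 threadgroups, `rowwise(32)` gives 32, and `gemm(4096, 4096)` gives a 64 x 128 grid — 8192 threadgroups.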
Batched Operations
For batched operations, you can use the third grid dimension:
Problem: Batch of B matrices, each M x N, apply softmax per row
Approach: 3D grid -- batch x rows x 1
Grid: MTLSize(M, B, 1)
TG: MTLSize(256, 1, 1)
Kernel sees:
batch_idx = threadgroup_position_in_grid.y
row_idx = threadgroup_position_in_grid.x
Real-World Example: Dispatch for Multi-Head Attention
Let us trace through the dispatch geometry for multi-head attention, a critical component of transformer inference.
Suppose we have:
- Batch size B = 1 (single sequence, typical for inference)
- Number of heads H = 32
- Sequence length S = 2048
- Head dimension D = 128
The attention computation involves:
- Q * K^T -> scores [H, 1, S] (for single-token generation)
- softmax(scores / sqrt(D)) -> weights [H, 1, S]
- weights * V -> output [H, 1, D]
Step 1: Score computation (GEMV -- each head is a vector-matrix multiply)
For each head: q[1, D] * K[S, D]^T -> scores[1, S]
Grid: MTLSize(ceil(S/256), H, 1) -- one TG row per head
TG: MTLSize(256, 1, 1)
Each threadgroup computes 256 elements of the score vector
for one attention head.
Step 2: Softmax over scores
For each head: softmax(scores[1, S])
Grid: MTLSize(H, 1, 1) -- one TG per head
TG: MTLSize(256, 1, 1)
Each threadgroup processes one row (one head's scores).
Step 3: Weighted sum (another GEMV)
For each head: weights[1, S] * V[S, D] -> output[1, D]
Grid: MTLSize(ceil(D/32), H, 1)
TG: MTLSize(256, 1, 1)
In practice, akunu fuses some of these steps together to reduce memory traffic, but the dispatch geometry follows this general pattern.
Common Pitfalls
Pitfall 1: Threadgroup Size Not a Multiple of 32
// BAD: wastes 22% of GPU resources
encoder.dispatchThreadgroups(grid,
                             threadsPerThreadgroup: MTLSize(width: 100, height: 1, depth: 1))
// GOOD: full utilization
encoder.dispatchThreadgroups(grid,
                             threadsPerThreadgroup: MTLSize(width: 128, height: 1, depth: 1))
Pitfall 2: Too Few Threadgroups
If you dispatch only 2 threadgroups on a GPU with 10 compute units, 8 compute units sit idle. You want at least as many threadgroups as compute units, and ideally several times more for latency hiding.
M2 Ultra: 76 compute units
Minimum threadgroups for full utilization: 76
Better: 76 * 4 = 304+ threadgroups (multiple waves)
Pitfall 3: Forgetting Bounds Checks
When using dispatchThreadgroups, your total thread count may exceed your data size. Always check bounds:
kernel void safe_scale(
    device half* data [[buffer(0)]],
    constant uint& count [[buffer(1)]],
    uint tid [[thread_position_in_grid]]
) {
    if (tid >= count) return;  // Bounds check!
    data[tid] = data[tid] * 2.0h;
}
Pitfall 4: Divergent Execution Within a SIMD Group
All 32 threads in a SIMD group execute together. If some threads take one branch and others take a different branch, the SIMD group executes both branches (masking inactive threads). This is called divergence and it wastes cycles:
// BAD: divergent branches within a SIMD group
if (tid % 2 == 0) {
    // Even threads do expensive path A
    result = expensive_computation_A(data[tid]);
} else {
    // Odd threads do expensive path B
    result = expensive_computation_B(data[tid]);
}
// Both paths execute for every SIMD group -- 2x the cost!

// BETTER: organize work so entire SIMD groups take the same branch
uint group = tid / 32;
if (group % 2 == 0) {
    // Entire SIMD groups do path A
    result = expensive_computation_A(data[tid]);
} else {
    // Entire SIMD groups do path B
    result = expensive_computation_B(data[tid]);
}
Pitfall 5: Not Accounting for Non-Uniform Threadgroups
When using dispatchThreads, the last threadgroup in each dimension may be smaller than requested. If your kernel assumes a fixed threadgroup size (e.g., uses threadgroup float shared[256]), it will still work, but some of those shared memory slots will contain garbage. Be careful when reducing – only reduce over the actual active thread count.
Visualizing Execution on Apple Silicon
Let us trace how a dispatch actually executes on an Apple Silicon GPU.
An M1 Pro has 16 compute units. Each compute unit can execute multiple threadgroups concurrently (depending on resource usage). Here is a simplified timeline:
Dispatch: 64 threadgroups, 128 threads each (4 SIMD groups per TG)
Compute Unit 0: [TG 0][TG 16][TG 32][TG 48] (4 TGs across time)
Compute Unit 1: [TG 1][TG 17][TG 33][TG 49]
Compute Unit 2: [TG 2][TG 18][TG 34][TG 50]
...
Compute Unit 15: [TG 15][TG 31][TG 47][TG 63]
Wave 0: TG 0-15 (one TG per compute unit)
Wave 1: TG 16-31 (launched as wave 0 TGs finish)
Wave 2: TG 32-47
Wave 3: TG 48-63
Total: 4 waves to process all 64 threadgroups
In reality, the scheduler is more dynamic than this – it can have multiple threadgroups resident per compute unit if there are enough resources. But the wave model gives you the right intuition. More threadgroups = more waves = more opportunity for the GPU to hide latency.
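Under the simplified one-threadgroup-per-compute-unit model, the wave count is another ceiling division. A C++ sketch (assuming that simplified model; real schedulers overlap threadgroups):

```cpp
#include <cassert>

// Waves needed to run `threadgroups` across `compute_units`, assuming
// exactly one resident threadgroup per compute unit (simplified model).
unsigned waves(unsigned threadgroups, unsigned compute_units) {
    return (threadgroups + compute_units - 1) / compute_units;
}
```

The M1 Pro example above works out to `waves(64, 16)` = 4, and the earlier M2 Ultra recommendation of 304 threadgroups on 76 compute units is exactly 4 waves.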
Choosing Grid Dimensions for ML Workloads
Here is a decision flowchart for common ML operations:
What kind of operation?
|
+-- Element-wise (ReLU, add, scale)?
| Threads per element: 1 (or 1 per half4 for vectorization)
| Grid: (ceil(N/4/TG_SIZE), 1, 1)
| TG: (256, 1, 1)
|
+-- Row-wise reduction (softmax, norm)?
| One threadgroup per row
| Grid: (num_rows, 1, 1)
| TG: (256, 1, 1)
|
+-- Matrix multiply (GEMM)?
| 2D grid of tiles
| Grid: (ceil(N/TN), ceil(M/TM), 1)
| TG: (128, 1, 1) -- 4 SIMD groups
|
+-- Batched operation?
| Use 3rd dimension for batch
| Grid: (spatial_x, spatial_y, batch)
| TG: depends on inner operation
|
+-- Vector-matrix multiply (GEMV)?
Grid: (ceil(M/TG_SIZE), 1, 1) -- one TG per chunk of rows
TG: (256, 1, 1)
Summary
- Thread hierarchy: Grid -> Threadgroups -> SIMD Groups -> Threads. Each level has different communication capabilities.
- SIMD groups (32 threads) communicate via intrinsics – fast, no barriers.
- Threadgroups (up to 1024 threads) communicate via threadgroup memory + barriers.
- Grid threadgroups cannot communicate during a dispatch.
- Always use multiples of 32 for threadgroup sizes.
- Common sizes: 128 for GEMM, 256 for reductions, 32 for trivial ops.
- dispatchThreadgroups for kernels with threadgroup cooperation (most ML kernels).
- dispatchThreads for simple per-element kernels.
- Two-stage reduction (SIMD reduce + threadgroup reduce) is the fundamental pattern.
- Dispatch enough threadgroups to keep all compute units busy (at least as many as there are compute units).
The next chapter covers the memory model – how data flows between CPU and GPU, and why memory bandwidth is the critical bottleneck for ML inference.