
Threadgroups, SIMD Groups, and Dispatch

So far we have talked about individual threads and what they can do. But a single GPU thread is weak – it is only useful because there are thousands of them running simultaneously. The way you organize those thousands of threads matters enormously for performance. Get the geometry wrong and you leave half the GPU idle. Get it right and your kernel saturates every execution unit.

This chapter is about the organizational hierarchy of GPU threads in Metal, and how to dispatch them effectively.

The Thread Hierarchy

Metal organizes threads into a three-level hierarchy:

Grid (all threads in the dispatch)
+------------------------------------------------------------------+
|                                                                  |
|  Threadgroup 0         Threadgroup 1         Threadgroup 2       |
|  +-----------------+   +-----------------+   +-----------------+ |
|  | SIMD Group 0    |   | SIMD Group 0    |   | SIMD Group 0    | |
|  | [32 threads]    |   | [32 threads]    |   | [32 threads]    | |
|  +-----------------+   +-----------------+   +-----------------+ |
|  | SIMD Group 1    |   | SIMD Group 1    |   | SIMD Group 1    | |
|  | [32 threads]    |   | [32 threads]    |   | [32 threads]    | |
|  +-----------------+   +-----------------+   +-----------------+ |
|  | SIMD Group 2    |   | SIMD Group 2    |   | SIMD Group 2    | |
|  | [32 threads]    |   | [32 threads]    |   | [32 threads]    | |
|  +-----------------+   +-----------------+   +-----------------+ |
|  | SIMD Group 3    |   | SIMD Group 3    |   | SIMD Group 3    | |
|  | [32 threads]    |   | [32 threads]    |   | [32 threads]    | |
|  +-----------------+   +-----------------+   +-----------------+ |
|                                                                  |
+------------------------------------------------------------------+

Let us define each level.

SIMD Group (32 threads)

A SIMD group is 32 threads that execute in lockstep – the same instruction, at the same clock cycle. This is the fundamental unit of execution on Apple GPUs. You do not choose the SIMD group size; it is always 32.

Threads in a SIMD group can communicate through SIMD intrinsics (simd_sum, simd_shuffle, etc.) with zero overhead – no barriers, no shared memory, just direct register reads. This is the fastest form of inter-thread communication available.

Threadgroup (up to 1024 threads)

A threadgroup is a collection of SIMD groups that are co-scheduled on the same GPU compute unit. All threads in a threadgroup:

  • Share access to threadgroup memory (fast on-chip SRAM)
  • Can synchronize with threadgroup_barrier()
  • Are guaranteed to execute concurrently (they are all resident on the same compute unit)

A threadgroup can contain up to 1024 threads, which means up to 32 SIMD groups. In practice, common sizes are 128 (4 SIMD groups), 256 (8 SIMD groups), or sometimes 512 or 1024.

Grid (the entire dispatch)

The grid is the collection of all threadgroups in a dispatch. Threadgroups in the grid are independent – there is no way for threads in different threadgroups to communicate during a single dispatch (except through device memory, but there are no cross-threadgroup barriers).

Two Ways to Dispatch: Threadgroups vs Threads

Metal gives you two dispatch methods, and the difference is subtle but important.

dispatchThreadgroups

You specify how many threadgroups to launch and how many threads per threadgroup:

// Launch a grid of threadgroups
let threadgroupsPerGrid = MTLSize(width: 16, height: 8, depth: 1)
let threadsPerThreadgroup = MTLSize(width: 256, height: 1, depth: 1)

encoder.dispatchThreadgroups(threadgroupsPerGrid,
                             threadsPerThreadgroup: threadsPerThreadgroup)
// Total threads = 16 * 8 * 256 = 32,768

With this method, you are responsible for computing the grid dimensions. If your problem size is not a perfect multiple of the threadgroup size, you need to handle the boundary condition in the kernel.
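A small host-side helper makes the grid math explicit. This is a sketch (the helper name is ours, not a Metal API):

```swift
// Hypothetical helper: how many threadgroups cover `elements`?
func threadgroupsToCover(_ elements: Int, threadgroupSize: Int) -> Int {
    (elements + threadgroupSize - 1) / threadgroupSize  // ceiling division
}

// 4000 elements with 256-thread threadgroups -> 16 threadgroups (4096 threads),
// so the kernel must bounds-check the final 96 threads.
let groups = threadgroupsToCover(4000, threadgroupSize: 256)  // 16
```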

dispatchThreads

You specify the total number of threads you want, and Metal figures out the threadgroup count:

// Launch exactly this many threads
let threadsPerGrid = MTLSize(width: 4000, height: 3000, depth: 1)
let threadsPerThreadgroup = MTLSize(width: 16, height: 16, depth: 1)

encoder.dispatchThreads(threadsPerGrid,
                        threadsPerThreadgroup: threadsPerThreadgroup)
// Metal launches enough threadgroups to cover 4000 x 3000,
// and threads outside that range are automatically disabled

With dispatchThreads, Metal handles the boundary for you. If you request 4000 x 3000 threads with 16x16 threadgroups, Metal launches 250 x 188 threadgroups (250 × 16 = 4000, 188 × 16 = 3008), but the extra 8 rows of threads simply do not execute.

Which Should You Use?

For ML kernels, dispatchThreadgroups is almost always the right choice. Here is why:

  1. Explicit control: You know exactly how many threadgroups you are launching, which makes reasoning about resource usage straightforward.
  2. Kernel design: ML kernels are typically designed around threadgroup-level cooperation (tiled GEMM, reductions). The kernel code already assumes a specific threadgroup structure.
  3. No wasted threads: With dispatchThreadgroups, you design your kernel so that every thread does useful work. There are no “out-of-bounds” threads.

dispatchThreads is nice for simple element-wise kernels where you just want one thread per element and do not care about threadgroup structure. But for anything involving shared memory, reductions, or tiling, use dispatchThreadgroups.

Choosing Threadgroup Sizes

The threadgroup size is one of the most important performance decisions you make. Here are the rules:

Rule 1: Always Use Multiples of 32

Since SIMD groups are 32 threads, your threadgroup size should always be a multiple of 32. If you use, say, 100 threads per threadgroup, Metal will still allocate 4 SIMD groups (128 threads worth of resources), but 28 threads in the last SIMD group will be idle. That is 22% waste.

Threadgroup size = 100 (bad):
+----------------+ +----------------+ +----------------+ +----------------+
| SIMD Group 0   | | SIMD Group 1   | | SIMD Group 2   | | SIMD Group 3   |
| 32 active      | | 32 active      | | 32 active      | | 4 active       |
| 0 idle         | | 0 idle         | | 0 idle         | | 28 IDLE        |
+----------------+ +----------------+ +----------------+ +----------------+
                                                            ^^^^^^^^^^^^
                                                            22% waste!

Threadgroup size = 128 (good):
+----------------+ +----------------+ +----------------+ +----------------+
| SIMD Group 0   | | SIMD Group 1   | | SIMD Group 2   | | SIMD Group 3   |
| 32 active      | | 32 active      | | 32 active      | | 32 active      |
| 0 idle         | | 0 idle         | | 0 idle         | | 0 idle         |
+----------------+ +----------------+ +----------------+ +----------------+
                                0% waste
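If a threadgroup size comes from a computed value, round it up to a SIMD-group boundary yourself. A minimal sketch (the helper name is ours):

```swift
// Round a requested thread count up to a whole number of SIMD groups.
func roundUpToSIMDGroups(_ threads: Int) -> Int {
    ((threads + 31) / 32) * 32
}
// roundUpToSIMDGroups(100) == 128 -- the hardware allocates that much anyway,
// so you may as well put the extra 28 threads to work.
```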

Rule 2: Common Sizes and When to Use Them

Threadgroup Size   SIMD Groups   Use Case
32                 1             Simple per-element ops, tiny reductions
64                 2             Light cooperation needed
128                4             Default for GEMM (4 SIMD groups = 4 tiles)
256                8             Default for reductions, normalization
512                16            Large reductions, high occupancy kernels
1024               32            Maximum, rarely needed

Rule 3: Balance Occupancy and Resources

Here is the tension: larger threadgroups use more resources (registers, threadgroup memory), which means fewer threadgroups can be resident on a compute unit at the same time. Smaller threadgroups use fewer resources but may not have enough threads for efficient cooperation.

Compute Unit (simplified)
+----------------------------------------------------------+
| Register File: 64KB      Threadgroup Memory: 32KB        |
|                                                          |
| Option A: 4 threadgroups of 128 threads (512 threads)    |
|   Each TG gets: 16KB registers, 8KB shared memory        |
|   Good latency hiding (many threads to switch between)   |
|                                                          |
| Option B: 2 threadgroups of 256 threads (512 threads)    |
|   Each TG gets: 32KB registers, 16KB shared memory       |
|   Same total threads, but more resources per TG          |
|                                                          |
| Option C: 1 threadgroup of 1024 threads (1024 threads)   |
|   Gets all: 64KB registers, 32KB shared memory           |
|   Lots of threads, but if this TG stalls, nothing else   |
|   can run on this compute unit                           |
+----------------------------------------------------------+

The sweet spot for ML kernels is usually 128-256 threads. This gives you enough SIMD groups for cooperation (4-8) while leaving room for multiple threadgroups per compute unit (good for latency hiding).

Rule 4: Consult maxTotalThreadsPerThreadgroup

Every PSO has a maxTotalThreadsPerThreadgroup property that tells you the maximum threadgroup size the kernel supports, based on its resource usage:

let pso = try device.makeComputePipelineState(function: function)
print(pso.maxTotalThreadsPerThreadgroup)  // e.g., 1024, or maybe 512

If your kernel uses a lot of registers or threadgroup memory, this number may be less than 1024. Always check.
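A common pattern is to clamp your preferred size against this limit, keeping the result a multiple of the SIMD width (exposed as threadExecutionWidth). A sketch:

```swift
let pso = try device.makeComputePipelineState(function: function)

// Clamp the preferred size to the PSO limit, rounded down to a SIMD multiple.
let preferred = 256
let simdWidth = pso.threadExecutionWidth        // 32 on Apple GPUs
let clamped = min(preferred, pso.maxTotalThreadsPerThreadgroup)
let width = (clamped / simdWidth) * simdWidth
let threadsPerThreadgroup = MTLSize(width: width, height: 1, depth: 1)
```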

SIMD Group Cooperation Patterns

Now let us look at how SIMD groups work together to solve problems. These patterns show up over and over in ML kernels.

Pattern 1: SIMD Reduction

The simplest and most common pattern. Each thread has a value, and you want the sum (or max, or min) across all threads in the SIMD group.

float my_value = data[tid];
float total = simd_sum(my_value);
// All 32 lanes now hold the same total

Lane:  0    1    2    3    ...  31
Val:   3.0  1.0  4.0  1.5  ...  2.7
                    |
              simd_sum
                    |
             total = 87.3 (in all lanes)
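Conceptually, simd_sum is equivalent to a tree of shuffles that halves the distance each step -- five steps for 32 lanes. You would never write this by hand (just call simd_sum), but the sketch shows why the reduction costs only a handful of instructions:

```metal
float v = my_value;
v += simd_shuffle_down(v, 16);   // lanes 0-15 accumulate lanes 16-31
v += simd_shuffle_down(v, 8);
v += simd_shuffle_down(v, 4);
v += simd_shuffle_down(v, 2);
v += simd_shuffle_down(v, 1);    // lane 0 now holds the full sum
v = simd_broadcast_first(v);     // share it with all 32 lanes
```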

Pattern 2: SIMD Broadcast

One lane has a value that all other lanes need.

float value;
if (simd_lane == 0) {
    value = compute_something_expensive();
}
// Broadcast lane 0's value to all lanes
value = simd_broadcast_first(value);
// Or from a specific lane:
value = simd_shuffle(value, source_lane);

Pattern 3: SIMD Prefix Sum (Inclusive Scan)

Each lane gets the sum of all values from lane 0 through its own lane. Useful for compaction, histograms, and stream processing.

float val = data[tid];
float prefix = simd_prefix_inclusive_sum(val);

Lane:    0    1    2    3    4    5    ...
Input:   3    1    4    1    5    9    ...
Output:  3    4    8    9   14   23    ...
         ^    ^    ^
         |    |    +-- 3+1+4
         |    +-- 3+1
         +-- 3
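As an example of the compaction use case, the exclusive variant (simd_prefix_exclusive_sum, which sums only the lanes before this one) gives each surviving lane a dense output slot. A hedged sketch -- compacted and base are assumed from context:

```metal
// Keep only positive values, packed densely within the SIMD group.
bool keep = (val > 0.0f);
uint slot = simd_prefix_exclusive_sum(keep ? 1u : 0u);
if (keep) {
    compacted[base + slot] = val;
}
```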

Pattern 4: SIMD Shuffle for Data Reuse

When adjacent threads need overlapping data (like in convolution), shuffles avoid redundant memory loads:

// Each lane loads one element
float my_val = data[base + simd_lane];

// Now lane i can access neighboring values without memory loads:
float left  = simd_shuffle_up(my_val, 1);   // lane i gets lane (i-1)'s value
float right = simd_shuffle_down(my_val, 1); // lane i gets lane (i+1)'s value
// (Lane 0's `left` and lane 31's `right` are not meaningful -- the edge
// lanes of each SIMD group need separate handling in a real stencil.)

// Stencil computation (e.g., 1D convolution with kernel [0.25, 0.5, 0.25])
float result = 0.25f * left + 0.5f * my_val + 0.25f * right;

Before shuffle:
Lane:   0    1    2    3    4    ...
Val:    A    B    C    D    E    ...

After simd_shuffle_down(val, 1):
Lane:   0    1    2    3    4    ...
Val:    B    C    D    E    F    ...
        ^    ^    ^
        Each lane got the next lane's value

Pattern 5: Two-Stage Reduction (Cross-SIMD-Group)

When you need a reduction across the entire threadgroup (not just one SIMD group), you do it in two stages:

Stage 1: SIMD-level reduction (fast, no barriers)
+------------------+  +------------------+  +------------------+  +------------------+
| SIMD Group 0     |  | SIMD Group 1     |  | SIMD Group 2     |  | SIMD Group 3     |
| 32 values        |  | 32 values        |  | 32 values        |  | 32 values        |
| --> simd_sum     |  | --> simd_sum     |  | --> simd_sum     |  | --> simd_sum     |
| Result: S0       |  | Result: S1       |  | Result: S2       |  | Result: S3       |
+------------------+  +------------------+  +------------------+  +------------------+

Stage 2: Write partial sums to threadgroup memory, then reduce
+-------------------------------------------+
| threadgroup float partials[4];            |
| partials = [S0, S1, S2, S3]               |
|                                           |
| threadgroup_barrier(...)                  |
|                                           |
| SIMD Group 0 reads all 4 values           |
| --> simd_sum(partials[simd_lane])         |
| Result: S0 + S1 + S2 + S3 = TOTAL         |
+-------------------------------------------+

Here is the code:

// Assume threadgroup size = 128 (4 SIMD groups)
threadgroup float partials[4];

// Stage 1: Each SIMD group reduces its 32 values
float local_sum = simd_sum(my_value);

// One thread per SIMD group writes to shared memory
if (simd_lane == 0) {
    partials[simd_id] = local_sum;
}
threadgroup_barrier(mem_flags::mem_threadgroup);

// Stage 2: First SIMD group reduces the partial sums
// (simd_sum only sums the active lanes, so the < 4 guard is safe)
float total;
if (simd_id == 0 && simd_lane < 4) {
    total = simd_sum(partials[simd_lane]);
}

// Broadcast to all threads if needed
threadgroup float shared_total;
if (lid == 0) {
    shared_total = total;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
total = shared_total;

This two-stage pattern is the workhorse of ML kernels. You will find it in normalization (RMSNorm, LayerNorm), softmax, and anywhere else you need a global reduction within a threadgroup.
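The snippet above hardcodes 4 SIMD groups. A variant generalized to any threadgroup size (up to 32 SIMD groups) looks like this sketch, where num_simd_groups is assumed bound to [[simdgroups_per_threadgroup]]:

```metal
threadgroup float partials[32];

float local_sum = simd_sum(my_value);
if (simd_lane == 0) {
    partials[simd_id] = local_sum;
}
threadgroup_barrier(mem_flags::mem_threadgroup);

float total = 0.0f;
if (simd_id == 0) {
    // Lanes beyond the actual SIMD-group count contribute zero.
    float p = (simd_lane < num_simd_groups) ? partials[simd_lane] : 0.0f;
    total = simd_sum(p);
}
```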

Thread Mapping: From Grid to Data

One of the most important design decisions is how you map threads to data. There are several common patterns.

Pattern A: One Thread Per Element

The simplest mapping. Each thread processes exactly one element:

kernel void scale(
    device half* data [[buffer(0)]],
    constant float& factor [[buffer(1)]],
    uint tid [[thread_position_in_grid]]
) {
    data[tid] = half(float(data[tid]) * factor);
}

Thread:  0    1    2    3    4    5    6    7  ...
Data:   [0]  [1]  [2]  [3]  [4]  [5]  [6]  [7] ...
         |    |    |    |    |    |    |    |
         v    v    v    v    v    v    v    v

Dispatch: one thread per element, trivial.

Pattern B: One Thread Per Vector (Vectorized)

Each thread processes a vector of elements:

kernel void scale_vec4(
    device half4* data [[buffer(0)]],
    constant float& factor [[buffer(1)]],
    uint tid [[thread_position_in_grid]]
) {
    half4 v = data[tid];
    data[tid] = half4(float4(v) * factor);
}

Thread:  0         1         2         3      ...
Data:   [0,1,2,3] [4,5,6,7] [8,9,10,11] [12,13,14,15] ...
         |         |         |            |
         v         v         v            v

Dispatch: one thread per 4 elements. Total threads = N / 4.

Pattern C: Stride Loop (One Threadgroup Per Row)

Each threadgroup processes an entire row, with threads striding through the elements:

kernel void row_sum(
    device const float* matrix [[buffer(0)]],
    device float* sums [[buffer(1)]],
    constant uint& cols [[buffer(2)]],
    uint gid [[threadgroup_position_in_grid]],
    uint lid [[thread_position_in_threadgroup]]
) {
    uint row = gid;
    float sum = 0.0f;
    
    // Each thread strides through the row
    // (stride 256 assumes a 256-thread threadgroup)
    for (uint col = lid; col < cols; col += 256) {
        sum += matrix[row * cols + col];
    }
    
    // Reduce within threadgroup...
    sum = simd_sum(sum);
    // ... (two-stage reduction as before)
    
    if (lid == 0) {
        sums[row] = sum;
    }
}

Row 0 (threadgroup 0):
Thread 0 reads:  [0], [256], [512], [768], ...
Thread 1 reads:  [1], [257], [513], [769], ...
Thread 2 reads:  [2], [258], [514], [770], ...
...
Thread 255 reads: [255], [511], [767], [1023], ...

Dispatch: one threadgroup per row.
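The matching host-side dispatch is one threadgroup per row. A sketch, assuming encoder and matrixRows from context:

```swift
encoder.dispatchThreadgroups(
    MTLSize(width: matrixRows, height: 1, depth: 1),       // one TG per row
    threadsPerThreadgroup: MTLSize(width: 256, height: 1, depth: 1))
```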

Pattern D: 2D Tiling (GEMM)

For matrix multiplication, threads are organized in a 2D grid of tiles:

kernel void gemm_tiled(
    device const half* A [[buffer(0)]],
    device const half* B [[buffer(1)]],
    device half* C       [[buffer(2)]],
    constant uint& M     [[buffer(3)]],
    constant uint& N     [[buffer(4)]],
    constant uint& K     [[buffer(5)]],
    uint2 gid [[threadgroup_position_in_grid]],
    uint  lid [[thread_position_in_threadgroup]],
    uint  simd_id [[simdgroup_index_in_threadgroup]]
) {
    // Each threadgroup computes a TM x TN tile of C
    // gid.y = tile row, gid.x = tile column
    uint tile_row = gid.y;  // Which TM-row block
    uint tile_col = gid.x;  // Which TN-column block
    
    // Each SIMD group within the threadgroup computes a sub-tile
    // ...
}

Matrix C (M x N):
+--------+--------+--------+--------+
| TG(0,0)| TG(0,1)| TG(0,2)| TG(0,3)|  <-- threadgroup row 0
| 32x64  | 32x64  | 32x64  | 32x64  |
+--------+--------+--------+--------+
| TG(1,0)| TG(1,1)| TG(1,2)| TG(1,3)|  <-- threadgroup row 1
| 32x64  | 32x64  | 32x64  | 32x64  |
+--------+--------+--------+--------+
| TG(2,0)| TG(2,1)| TG(2,2)| TG(2,3)|
| 32x64  | 32x64  | 32x64  | 32x64  |
+--------+--------+--------+--------+

Each threadgroup computes a TM x TN tile of C (e.g., 32 x 64).
Grid dimensions: (N/TN, M/TM, 1) threadgroups.

This is the tiling pattern used in akunu’s GEMM kernels. We will cover it in detail in the SIMD matrix operations chapter.

Dispatch Geometry Calculation

Let us put this together with concrete dispatch calculations for common ML operations.

Element-wise Operations (ReLU, SiLU, Add)

Problem: Apply activation to N elements
Approach: One thread per element (or per half4)

Vectorized (half4):
  threads_needed = N / 4
  threadgroup_size = 256
  threadgroups = ceil(threads_needed / 256)

Example: N = 4096
  threads_needed = 1024
  threadgroup_size = 256
  threadgroups = 4
  Grid: MTLSize(4, 1, 1)
  TG:   MTLSize(256, 1, 1)
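In Swift, the same calculation might look like this sketch (encoder assumed from context):

```swift
let n = 4096
let threadsNeeded = n / 4                                   // half4: 1024 threads
let tgWidth = 256
let threadgroups = (threadsNeeded + tgWidth - 1) / tgWidth  // ceil -> 4

encoder.dispatchThreadgroups(
    MTLSize(width: threadgroups, height: 1, depth: 1),
    threadsPerThreadgroup: MTLSize(width: tgWidth, height: 1, depth: 1))
```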

Row-wise Operations (Softmax, RMSNorm)

Problem: Process M rows of N elements each
Approach: One threadgroup per row

  threadgroup_size = 256
  threadgroups = M

Example: M = 32 (batch), N = 4096 (hidden dim)
  threadgroups = 32
  Grid: MTLSize(32, 1, 1)
  TG:   MTLSize(256, 1, 1)
  
  Each of the 256 threads in a threadgroup strides through
  4096 / 256 = 16 elements.

Matrix Multiplication (GEMM)

Problem: C[M,N] = A[M,K] * B[K,N]
Approach: 2D grid, each threadgroup computes a TM x TN tile

  TM = 32, TN = 64 (tile sizes)
  threadgroup_size = 128 (4 SIMD groups)
  threadgroups_x = ceil(N / TN)
  threadgroups_y = ceil(M / TM)

Example: M = 4096, N = 4096, K = 4096
  threadgroups_x = 4096 / 64 = 64
  threadgroups_y = 4096 / 32 = 128
  Grid: MTLSize(64, 128, 1)
  TG:   MTLSize(128, 1, 1)
  Total threadgroups: 8192
  Total threads: 8192 * 128 = 1,048,576
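Host-side, the tile arithmetic above might be written as (a sketch; encoder assumed from context):

```swift
let (m, n) = (4096, 4096)
let (tm, tn) = (32, 64)                          // tile sizes TM x TN
let grid = MTLSize(width: (n + tn - 1) / tn,     // 64 tile columns
                   height: (m + tm - 1) / tm,    // 128 tile rows
                   depth: 1)
encoder.dispatchThreadgroups(
    grid, threadsPerThreadgroup: MTLSize(width: 128, height: 1, depth: 1))
```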

Batched Operations

For batched operations, you can use the third grid dimension:

Problem: Batch of B matrices, each M x N, apply softmax per row
Approach: 3D grid -- batch x rows x 1

  Grid: MTLSize(M, B, 1)
  TG:   MTLSize(256, 1, 1)

Kernel sees:
  batch_idx = threadgroup_position_in_grid.y
  row_idx   = threadgroup_position_in_grid.x

Real-World Example: Dispatch for Multi-Head Attention

Let us trace through the dispatch geometry for multi-head attention, a critical component of transformer inference.

Suppose we have:

  • Batch size B = 1 (single sequence, typical for inference)
  • Number of heads H = 32
  • Sequence length S = 2048
  • Head dimension D = 128

The attention computation involves:

  1. Q * K^T -> scores [H, 1, S] (for single-token generation)
  2. softmax(scores / sqrt(D)) -> weights [H, 1, S]
  3. weights * V -> output [H, 1, D]

Step 1: Score computation (GEMV -- each head is a vector-matrix multiply)
  For each head: q[1, D] * K[S, D]^T -> scores[1, S]
  
  Grid: MTLSize(ceil(S/256), H, 1)   -- one TG row per head
  TG:   MTLSize(256, 1, 1)
  
  Each threadgroup computes 256 elements of the score vector
  for one attention head.

Step 2: Softmax over scores
  For each head: softmax(scores[1, S])
  
  Grid: MTLSize(H, 1, 1)   -- one TG per head
  TG:   MTLSize(256, 1, 1)
  
  Each threadgroup processes one row (one head's scores).

Step 3: Weighted sum (another GEMV)
  For each head: weights[1, S] * V[S, D] -> output[1, D]
  
  Grid: MTLSize(ceil(D/32), H, 1)
  TG:   MTLSize(256, 1, 1)

In practice, akunu fuses some of these steps together to reduce memory traffic, but the dispatch geometry follows this general pattern.

Common Pitfalls

Pitfall 1: Threadgroup Size Not a Multiple of 32

// BAD: wastes 22% of GPU resources
encoder.dispatchThreadgroups(grid, threadsPerThreadgroup: MTLSizeMake(100, 1, 1))

// GOOD: full utilization
encoder.dispatchThreadgroups(grid, threadsPerThreadgroup: MTLSizeMake(128, 1, 1))

Pitfall 2: Too Few Threadgroups

If you dispatch only 2 threadgroups on a GPU with 10 compute units, 8 compute units sit idle. You want at least as many threadgroups as compute units, and ideally several times more for latency hiding.

M2 Ultra: 76 compute units
Minimum threadgroups for full utilization: 76
Better: 76 * 4 = 304+ threadgroups (multiple waves)

Pitfall 3: Forgetting Bounds Checks

When using dispatchThreadgroups, your total thread count may exceed your data size. Always check bounds:

kernel void safe_scale(
    device half* data [[buffer(0)]],
    constant uint& count [[buffer(1)]],
    uint tid [[thread_position_in_grid]]
) {
    if (tid >= count) return;  // Bounds check!
    data[tid] = data[tid] * 2.0h;
}

Pitfall 4: Divergent Execution Within a SIMD Group

All 32 threads in a SIMD group execute together. If some threads take one branch and others take a different branch, the SIMD group executes both branches (masking inactive threads). This is called divergence and it wastes cycles:

// BAD: divergent branches within a SIMD group
if (tid % 2 == 0) {
    // Even threads do expensive path A
    result = expensive_computation_A(data[tid]);
} else {
    // Odd threads do expensive path B
    result = expensive_computation_B(data[tid]);
}
// Both paths execute for every SIMD group -- 2x the cost!

// BETTER: organize work so entire SIMD groups take the same branch
uint group = tid / 32;
if (group % 2 == 0) {
    // Entire SIMD groups do path A
    result = expensive_computation_A(data[tid]);
} else {
    // Entire SIMD groups do path B
    result = expensive_computation_B(data[tid]);
}

Pitfall 5: Not Accounting for Non-Uniform Threadgroups

When using dispatchThreads, the last threadgroup in each dimension may be smaller than requested. If your kernel assumes a fixed threadgroup size (e.g., uses threadgroup float shared[256]), it will still work, but some of those shared memory slots will contain garbage. Be careful when reducing – only reduce over the actual active thread count.
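One defensive sketch: read the actual threadgroup size from [[threads_per_threadgroup]] and zero the shared-memory tail before reducing, so a full-width reduction never reads garbage. (The kernel and buffer names here are illustrative.)

```metal
kernel void safe_row_reduce(
    device const float* data   [[buffer(0)]],
    device float*       out    [[buffer(1)]],
    uint tid     [[thread_position_in_grid]],
    uint lid     [[thread_position_in_threadgroup]],
    uint tg_size [[threads_per_threadgroup]])   // actual size, may be < 256
{
    threadgroup float shared[256];
    shared[lid] = data[tid];
    // Zero the slots a smaller last threadgroup never writes.
    for (uint i = lid + tg_size; i < 256; i += tg_size) {
        shared[i] = 0.0f;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    // ... reduce over all 256 slots as usual ...
}
```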

Visualizing Execution on Apple Silicon

Let us trace how a dispatch actually executes on an Apple Silicon GPU.

An M1 Pro has 16 compute units. Each compute unit can execute multiple threadgroups concurrently (depending on resource usage). Here is a simplified timeline:

Dispatch: 64 threadgroups, 128 threads each (4 SIMD groups per TG)

Compute Unit 0:  [TG 0][TG 16][TG 32][TG 48]  (4 TGs across time)
Compute Unit 1:  [TG 1][TG 17][TG 33][TG 49]
Compute Unit 2:  [TG 2][TG 18][TG 34][TG 50]
...
Compute Unit 15: [TG 15][TG 31][TG 47][TG 63]

Wave 0: TG 0-15  (one TG per compute unit)
Wave 1: TG 16-31 (launched as wave 0 TGs finish)
Wave 2: TG 32-47
Wave 3: TG 48-63

Total: 4 waves to process all 64 threadgroups

In reality, the scheduler is more dynamic than this – it can have multiple threadgroups resident per compute unit if there are enough resources. But the wave model gives you the right intuition. More threadgroups = more waves = more opportunity for the GPU to hide latency.

Choosing Grid Dimensions for ML Workloads

Here is a decision flowchart for common ML operations:

What kind of operation?
    |
    +-- Element-wise (ReLU, add, scale)?
    |     Threads per element: 1 (or 1 per half4 for vectorization)
    |     Grid: (ceil(N/4/TG_SIZE), 1, 1)
    |     TG: (256, 1, 1)
    |
    +-- Row-wise reduction (softmax, norm)?
    |     One threadgroup per row
    |     Grid: (num_rows, 1, 1)
    |     TG: (256, 1, 1)
    |
    +-- Matrix multiply (GEMM)?
    |     2D grid of tiles
    |     Grid: (ceil(N/TN), ceil(M/TM), 1)
    |     TG: (128, 1, 1)   -- 4 SIMD groups
    |
    +-- Batched operation?
    |     Use 3rd dimension for batch
    |     Grid: (spatial_x, spatial_y, batch)
    |     TG: depends on inner operation
    |
    +-- Vector-matrix multiply (GEMV)?
          Grid: (ceil(M/TG_SIZE), 1, 1)   -- one TG per chunk of rows
          TG: (256, 1, 1)

Summary

  1. Thread hierarchy: Grid -> Threadgroups -> SIMD Groups -> Threads. Each level has different communication capabilities.
  2. SIMD groups (32 threads) communicate via intrinsics – fast, no barriers.
  3. Threadgroups (up to 1024 threads) communicate via threadgroup memory + barriers.
  4. Grid threadgroups cannot communicate during a dispatch.
  5. Always use multiples of 32 for threadgroup sizes.
  6. Common sizes: 128 for GEMM, 256 for reductions, 32 for trivial ops.
  7. dispatchThreadgroups for kernels with threadgroup cooperation (most ML kernels).
  8. dispatchThreads for simple per-element kernels.
  9. Two-stage reduction (SIMD reduce + threadgroup reduce) is the fundamental pattern.
  10. Dispatch enough threadgroups to keep all compute units busy (at least as many as there are compute units).

The next chapter covers the memory model – how data flows between CPU and GPU, and why memory bandwidth is the critical bottleneck for ML inference.