The Metal Shading Language

The Metal Shading Language – MSL for short – is what you write GPU kernels in. If you already know C++, you are going to feel right at home, because MSL is essentially C++14 with some things removed (no exceptions, no virtual functions, no RTTI, no recursive function calls) and a bunch of GPU-specific features added. It is not a weird domain-specific language. It is not a visual graph editor. It is just C++ with extra address spaces, vector types, and thread-coordination primitives bolted on.

In this chapter, we are going to cover the MSL features that matter most for ML kernel development. We will write several complete kernels along the way, starting simple and building up to real patterns used in inference engines.

The Basics: Kernel Functions

A compute kernel in MSL is declared with the kernel keyword (or equivalently, [[kernel]]):

#include <metal_stdlib>
using namespace metal;

kernel void add_arrays(
    device const float* a [[buffer(0)]],
    device const float* b [[buffer(1)]],
    device float* c       [[buffer(2)]],
    uint id               [[thread_position_in_grid]]
) {
    c[id] = a[id] + b[id];
}

Let us break down every piece of this.

  • kernel void: This is a GPU entry point. Kernels always return void – they communicate results by writing to buffers.
  • device const float* a: A pointer into device memory (more on address spaces below). The [[buffer(0)]] attribute tells Metal which buffer binding index this parameter corresponds to.
  • uint id [[thread_position_in_grid]]: A built-in variable that gives this thread its unique ID in the dispatch grid.

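That is about as minimal as a kernel gets: read two values, add them, write the result. Every thread processes one element, so a thousand threads process a thousand elements, and the GPU runs as many of them in parallel as its cores allow. Note that the kernel has no bounds check. It is only correct if the grid contains exactly as many threads as the arrays have elements (for example, when dispatched with dispatchThreads and a grid size equal to the array length). If the grid is rounded up to a whole number of threadgroups, pass the length in as well and return early when id is out of range.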
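A CPU reference makes the mapping concrete and is handy for checking GPU output. In this minimal sketch (add_arrays_reference and threads_per_group are illustrative names, not Metal API), the loop variable id stands in for [[thread_position_in_grid]]. The grid is assumed to be rounded up to a whole number of threadgroups, so the last few "threads" fall past the end of the arrays and must skip the write.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// CPU stand-in for the add_arrays kernel. `id` plays the role of
// [[thread_position_in_grid]]; `threads_per_group` is a hypothetical
// threadgroup size used to round the grid up, as a padded dispatch would.
std::vector<float> add_arrays_reference(const std::vector<float>& a,
                                        const std::vector<float>& b,
                                        std::size_t threads_per_group) {
    assert(a.size() == b.size() && threads_per_group > 0);
    const std::size_t n = a.size();
    const std::size_t groups = (n + threads_per_group - 1) / threads_per_group;
    const std::size_t grid_size = groups * threads_per_group;  // >= n
    std::vector<float> c(n, 0.0f);
    for (std::size_t id = 0; id < grid_size; ++id) {
        if (id >= n) continue;  // the guard a padded kernel needs
        c[id] = a[id] + b[id];  // one "thread" per element
    }
    return c;
}
```

With 5 elements and 4-thread groups, the grid has 8 threads and the last 3 do nothing.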

Address Spaces

In regular C++, all pointers live in one big flat address space. On the GPU, memory is divided into distinct address spaces with different performance characteristics and access rules. Every pointer in MSL must be qualified with its address space.

+------------------------------------------------------------------+
|                        GPU Memory Map                            |
+------------------------------------------------------------------+
|                                                                  |
|  device          Large, slow-ish (DRAM / unified memory)         |
|  [read/write]    All threads can access. This is where your      |
|                  weight matrices and activation tensors live.    |
|                  Roughly 70-800 GB/s, depending on the chip.     |
|                                                                  |
+------------------------------------------------------------------+
|                                                                  |
|  constant        Same physical memory as device, but goes        |
|  [read-only]     through the constant cache. Best for small,     |
|                  uniform data read by all threads (like params). |
|                  Limited to 64KB per argument.                   |
|                                                                  |
+------------------------------------------------------------------+
|                                                                  |
|  threadgroup     Fast on-chip SRAM (like CUDA shared memory).    |
|  [read/write]    Only threads in the same threadgroup see it.    |
|                  32KB per threadgroup on Apple Silicon.          |
|                  Much faster than device memory.                 |
|                                                                  |
+------------------------------------------------------------------+
|                                                                  |
|  thread          Registers. Private to each thread.              |
|  [read/write]    Fastest possible access.                        |
|                  Limited by register file size.                  |
|                                                                  |
+------------------------------------------------------------------+

Here is how they look in practice:

#include <metal_stdlib>
using namespace metal;

struct Params {
    float scale;
};

kernel void address_space_demo(
    device float* data         [[buffer(0)]],      // device address space
    constant Params& params    [[buffer(1)]],      // constant address space
    threadgroup float* shared  [[threadgroup(0)]], // threadgroup; length set on the host with setThreadgroupMemoryLength (unused here)
    uint tid [[thread_position_in_grid]]
) {
    // 'thread' address space -- local variable, lives in registers
    float local_val = data[tid];
    
    // Read from constant (cached, good for broadcast reads)
    float scale = params.scale;
    
    // Write to device (main memory)
    data[tid] = local_val * scale;
}

When to Use Each Address Space

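+---------------+-----------------------------------------------+---------------------------------------+
| Address Space | Use For                                       | Example                               |
+---------------+-----------------------------------------------+---------------------------------------+
| device        | Large read/write buffers                      | Weight matrices, activations, outputs |
| constant      | Small read-only data broadcast to all threads | Kernel parameters, lookup tables      |
| threadgroup   | Shared scratchpad for threads in a group      | Reduction accumulators, tiled data    |
| thread        | Per-thread temporaries                        | Loop variables, accumulators          |
+---------------+-----------------------------------------------+---------------------------------------+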

The most common mistake newcomers make is putting everything in device. If you have a small struct of parameters (dimensions, strides, scaling factors), put it in constant – the constant cache will broadcast it to all threads efficiently.

Vector and Matrix Types

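MSL provides built-in vector types such as half4 and float4. They are part of the language, not library abstractions. Apple GPUs execute arithmetic on scalar ALUs within each thread, so adding two float4 values typically compiles to four scalar operations; the big win from vector types is in memory access, where a single vector load or store moves 8 (half4) or 16 (float4) bytes at once.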

Scalar Types

half   h = 1.0h;    // 16-bit float (FP16)
float  f = 1.0f;    // 32-bit float (FP32)
int    i = 42;      // 32-bit signed integer
uint   u = 42u;     // 32-bit unsigned integer
short  s = 42;      // 16-bit signed integer
ushort us = 42u;    // 16-bit unsigned integer
bool   b = true;    // boolean

Vector Types

Vectors come in sizes 2, 3, and 4:

half2  h2 = half2(1.0h, 2.0h);          // 2 x FP16
half4  h4 = half4(1.0h, 2.0h, 3.0h, 4.0h);  // 4 x FP16
float2 f2 = float2(1.0f, 2.0f);         // 2 x FP32
float4 f4 = float4(1.0f, 2.0f, 3.0f, 4.0f); // 4 x FP32
uint2  u2 = uint2(10, 20);              // 2 x uint32

You can swizzle components (rearrange them) using .xyzw or .rgba:

float4 v = float4(1.0, 2.0, 3.0, 4.0);
float2 xy = v.xy;      // (1.0, 2.0)
float  z  = v.z;       // 3.0
float4 rev = v.wzyx;   // (4.0, 3.0, 2.0, 1.0)

Why Vectors Matter for ML

Loading a half4 from memory is a single 64-bit load instead of four separate 16-bit loads. This is huge for memory-bound kernels (which most ML kernels are):

Single loads (bad):                 Vectorized load (good):
  load h[0]  -- 16 bits              load h4  -- 64 bits, one instruction
  load h[1]  -- 16 bits              
  load h[2]  -- 16 bits              4x fewer load instructions
  load h[3]  -- 16 bits              Better memory bus utilization
  4 instructions total                1 instruction total

We will come back to vectorized loads in the performance chapter. For now, just remember: prefer half4 and float4 over scalar loads whenever possible.
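A CPU sketch of the same idea, using a plain 4-float struct as a stand-in for float4 (Vec4 and sum_vectorized are illustrative names): the bulk of the array is read four elements at a time, and leftover elements need a scalar tail when the length is not a multiple of 4, a case GPU kernels that read half4 or float4 must also handle or rule out. For 10 elements that is 2 wide loads plus 2 scalar loads, 4 in total instead of 10.

```cpp
#include <cassert>
#include <cstddef>
#include <cstring>
#include <vector>

// Stand-in for a 4-wide vector type such as float4.
struct Vec4 { float x, y, z, w; };

struct SumResult {
    float sum;
    std::size_t loads;  // number of load operations issued
};

// Sums `data` with 4-wide loads for the bulk and scalar loads for the tail.
SumResult sum_vectorized(const std::vector<float>& data) {
    SumResult r{0.0f, 0};
    const std::size_t n = data.size();
    const std::size_t body = n - n % 4;  // largest multiple of 4 <= n
    for (std::size_t i = 0; i < body; i += 4) {
        Vec4 v;
        std::memcpy(&v, &data[i], sizeof v);  // one 16-byte load
        r.sum += v.x + v.y + v.z + v.w;
        ++r.loads;
    }
    for (std::size_t i = body; i < n; ++i) {  // scalar tail
        r.sum += data[i];
        ++r.loads;
    }
    return r;
}
```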

Thread Identification

Every thread in a compute dispatch has several built-in identifiers. Understanding which one to use is crucial.

kernel void thread_id_demo(
    uint tid    [[thread_position_in_grid]],
    uint tpg    [[threads_per_grid]],
    uint gid    [[threadgroup_position_in_grid]],
    uint tgpg   [[threadgroups_per_grid]],
    uint lid    [[thread_position_in_threadgroup]],
    uint tptg   [[threads_per_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id   [[simdgroup_index_in_threadgroup]]
) {
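    // tid  = global thread ID (0, 1, 2, ..., total_threads - 1)
    // tpg  = total number of threads in the grid
    // gid  = which threadgroup this thread belongs to
    // tgpg = total number of threadgroups in the grid
    // lid  = local thread ID within the threadgroup (0, 1, ..., tptg - 1)
    // tptg = number of threads in each threadgroup
    // simd_lane = lane within the SIMD group (0 - 31)
    // simd_id   = which SIMD group within the threadgroup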
}

Here is how they relate:

Grid (all threads)
+-----------------------------------------------------------------+
| Threadgroup 0 (gid=0)      | Threadgroup 1 (gid=1)      | ...   |
| +------------------------+ | +------------------------+ |       |
| | SIMD Group 0 (simd=0)  | | | SIMD Group 0           | |       |
| | lanes 0..31            | | | lanes 0..31            | |       |
| +------------------------+ | +------------------------+ |       |
| | SIMD Group 1 (simd=1)  | | | SIMD Group 1           | |       |
| | lanes 0..31            | | | lanes 0..31            | |       |
| +------------------------+ | +------------------------+ |       |
| | SIMD Group 2 (simd=2)  | | | SIMD Group 2           | |       |
| | lanes 0..31            | | | lanes 0..31            | |       |
| +------------------------+ | +------------------------+ |       |
| | SIMD Group 3 (simd=3)  | | | SIMD Group 3           | |       |
| | lanes 0..31            | | | lanes 0..31            | |       |
| +------------------------+ | +------------------------+ |       |
+-----------------------------------------------------------------+

For threadgroup size = 128:
  lid = simd_id * 32 + simd_lane
  tid = gid * 128 + lid
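The two formulas can be checked by round-tripping IDs on the CPU. This sketch assumes a 1D dispatch, a threadgroup size of 128 and a SIMD width of 32, with SIMD groups filled in order of lid as the lid formula implies; the struct and function names are ours.

```cpp
#include <cassert>
#include <cstdint>

constexpr std::uint32_t kSimdWidth = 32;
constexpr std::uint32_t kThreadgroupSize = 128;  // the example size above

struct ThreadIds {
    std::uint32_t gid;        // threadgroup_position_in_grid
    std::uint32_t lid;        // thread_position_in_threadgroup
    std::uint32_t simd_id;    // simdgroup_index_in_threadgroup
    std::uint32_t simd_lane;  // thread_index_in_simdgroup
};

// Split a 1D global thread ID into the other built-in IDs.
ThreadIds decompose(std::uint32_t tid) {
    ThreadIds ids;
    ids.gid = tid / kThreadgroupSize;
    ids.lid = tid % kThreadgroupSize;
    ids.simd_id = ids.lid / kSimdWidth;
    ids.simd_lane = ids.lid % kSimdWidth;
    return ids;
}

// Rebuild the global ID with lid = simd_id * 32 + simd_lane and
// tid = gid * 128 + lid.
std::uint32_t recompose(const ThreadIds& ids) {
    const std::uint32_t lid = ids.simd_id * kSimdWidth + ids.simd_lane;
    return ids.gid * kThreadgroupSize + lid;
}
```

For example, global thread 200 is lane 8 of SIMD group 2 in threadgroup 1 (lid 72).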

Which ID Should You Use?

  • thread_position_in_grid: When each thread processes one element of an array. By far the most common.
  • threadgroup_position_in_grid + thread_position_in_threadgroup: When threads cooperate within a threadgroup (reductions, tiled algorithms).
  • thread_index_in_simdgroup + simdgroup_index_in_threadgroup: When you are doing SIMD-level tricks (shuffles, reductions, matrix ops).

The half Type: FP16 for ML

The half type is a 16-bit floating-point number. It has:

  • 1 sign bit
  • 5 exponent bits (largest finite value 65504; smallest normal ~6.1 x 10^-5; smallest subnormal ~6 x 10^-8)
  • 10 mantissa bits (~3.3 decimal digits of precision)
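To make the bit layout concrete, here is a small C++ decoder for IEEE 754 binary16 bit patterns, the format half uses (decode_half_bits is our own name). It covers normal numbers, subnormals (exponent field 0, which is where the smallest values come from), and infinities/NaN.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// Decode an IEEE 754 binary16 bit pattern into a float.
float decode_half_bits(std::uint16_t bits) {
    const int sign = (bits >> 15) & 0x1;
    const int exponent = (bits >> 10) & 0x1F;  // 5 bits, bias 15
    const int mantissa = bits & 0x3FF;         // 10 bits
    float value;
    if (exponent == 0) {
        // Subnormal: no implicit leading 1; value = mantissa/1024 * 2^-14
        value = std::ldexp(static_cast<float>(mantissa), -24);
    } else if (exponent == 0x1F) {
        value = (mantissa == 0) ? INFINITY : NAN;
    } else {
        // Normal: implicit leading 1, unbiased exponent = exponent - 15
        value = std::ldexp(1.0f + mantissa / 1024.0f, exponent - 15);
    }
    return sign ? -value : value;
}
```

For instance, 0x7BFF decodes to 65504, the largest finite half.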

For ML inference, half is the workhorse type. Here is why:

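  1. Lower register pressure: a half occupies half the register space of a float, which can let more threads stay resident. Some Apple GPUs also execute FP16 arithmetic faster than FP32, but that varies by generation, so measure rather than assume.
  2. 2x memory bandwidth: A half is 2 bytes vs 4 bytes for float. Since most ML kernels are memory-bound, halving the bytes moved can come close to doubling throughput.
  3. Sufficient precision: Neural network weights and activations rarely need more than about 3 significant digits, so FP16 is usually enough for inference. Range is the bigger risk: anything above 65504 overflows to infinity, which is one reason accumulations are often kept in float.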

// FP16 partial dot product: moves half the bytes of an FP32 version
kernel void dot_product_fp16(
    device const half4* a [[buffer(0)]],
    device const half4* b [[buffer(1)]],
    device half* result   [[buffer(2)]],
    uint id [[thread_position_in_grid]]
) {
    // Each thread computes the dot product of one half4 pair:
    // 4 multiply-adds, with one 8-byte load per input
    half4 va = a[id];
    half4 vb = b[id];
    result[id] = dot(va, vb);  // va.x*vb.x + va.y*vb.y + ...
}

Mixed Precision

Sometimes you want FP16 for storage and memory transfer but FP32 for accumulation to avoid precision loss:

kernel void gemv_mixed_precision(
    device const half4* W [[buffer(0)]],   // Weights in FP16, row-major
    device const half4* x [[buffer(1)]],   // Input in FP16
    device float* y       [[buffer(2)]],   // Output in FP32
    constant uint& K      [[buffer(3)]],   // Row length; assumed to be a multiple of 4
    uint row [[thread_position_in_grid]]   // One thread per output row
) {
    float sum = 0.0f;  // Accumulate in FP32 for precision
    
    uint k4 = K / 4;   // Process 4 elements at a time via half4
    for (uint i = 0; i < k4; i++) {
        // Load in FP16 (half the bytes), then widen to FP32
        float4 w = float4(W[row * k4 + i]);
        float4 v = float4(x[i]);
        // Multiply and accumulate in FP32
        sum += dot(w, v);
    }
    
    y[row] = sum;
}

This is a common pattern in ML inference: load data as FP16 (saving bandwidth), accumulate in FP32 (preserving precision), and write the result back as whatever the downstream consumer needs.
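Standard C++ has no portable half type, so this sketch emulates FP16 rounding with a hypothetical helper, round_to_half, that keeps 11 significant bits with round-half-to-even (it ignores FP16 overflow and subnormals, which these inputs never reach). Adding 1.0 to a running total 4096 times shows why the accumulator should be FP32: FP16 represents every integer only up to 2048, so at 2049 the sum rounds back to 2048 and gets stuck, while an FP32 accumulator reaches 4096.

```cpp
#include <cassert>
#include <cmath>

// Round a float to FP16's precision (11 significant bits), ties to even.
// Hypothetical helper: ignores FP16 overflow (> 65504) and subnormals.
float round_to_half(float x) {
    int e2 = 0;
    const float m = std::frexp(x, &e2);          // x = m * 2^e2, 0.5 <= |m| < 1
    const float r = std::nearbyint(m * 2048.0f); // keep 11 significant bits
    return std::ldexp(r, e2 - 11);
}

// Sum `count` copies of `value` twice: once rounding the running total to
// FP16 precision after every add, once with a plain FP32 accumulator.
void accumulate(float value, int count, float& half_acc, float& float_acc) {
    half_acc = 0.0f;
    float_acc = 0.0f;
    for (int i = 0; i < count; ++i) {
        half_acc = round_to_half(half_acc + value);  // FP16 accumulation
        float_acc += value;                          // FP32 accumulation
    }
}
```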

SIMD Group Intrinsics

Now we get to the really fun stuff. SIMD group intrinsics let threads within a 32-thread SIMD group communicate directly through the register file – no shared memory, no barriers, no synchronization overhead.

On Apple GPUs, a SIMD group (the counterpart of a warp on NVIDIA or a wavefront on AMD) is 32 threads wide; a compute pipeline's threadExecutionWidth property reports this width at runtime. The threads of a SIMD group execute the same instruction together on different data, and SIMD intrinsics let them share values with each other.

simd_sum: Reduction Within a SIMD Group

// Sum 32 values (one from each lane) in a single operation
float lane_value = data[tid];
float total = simd_sum(lane_value);
// Now ALL 32 lanes have the same total

Without simd_sum, you would need a tree reduction through threadgroup memory with barriers, or a hand-written sequence of shuffles. With it, a single call performs the whole reduction within the SIMD group's registers.

Before simd_sum:          After simd_sum:
Lane 0:  3.0              Lane 0:  sum of all 32
Lane 1:  1.0              Lane 1:  sum of all 32
Lane 2:  4.0              Lane 2:  sum of all 32
...                       ...
Lane 31: 2.0              Lane 31: sum of all 32

simd_max / simd_min: Finding Extremes

float lane_value = data[tid];
float max_val = simd_max(lane_value);  // Max across all 32 lanes
float min_val = simd_min(lane_value);  // Min across all 32 lanes

This is essential for softmax, which needs the maximum value before computing exponentials.

simd_shuffle: Direct Lane-to-Lane Communication

// Read the value from a specific lane
float val = simd_shuffle(my_value, target_lane);

simd_shuffle lets any lane read any other lane’s value. It is like having a crossbar switch between the registers of all 32 lanes.

simd_shuffle_down: Shift Values Down

float val = simd_shuffle_down(my_value, delta);
// Lane i gets the value from lane (i + delta)
Before simd_shuffle_down(val, 1):    After:
Lane 0: A                           Lane 0: B  (got lane 1's value)
Lane 1: B                           Lane 1: C  (got lane 2's value)
Lane 2: C                           Lane 2: D  (got lane 3's value)
Lane 3: D                           Lane 3: E  (got lane 4's value)
...                                  ...

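This is a common building block for tree reductions: shift by 16, 8, 4, 2, and 1, adding at each step, and lane 0 ends up with the total. Its mirror image, simd_shuffle_up, is the usual building block for prefix sums.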
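The reduction use of simd_shuffle_down can be emulated on the CPU with an array of 32 lane values (shuffle_down and reduce_to_lane0 are our names). In this model, lanes whose source would lie past lane 31 keep their own value; those lanes end up with meaningless partial sums, but lane 0's result never depends on them.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

constexpr std::size_t kSimdWidth = 32;
using Lanes = std::array<float, kSimdWidth>;

// Emulates simd_shuffle_down: lane i receives the value of lane i + delta.
// Lanes whose source would be past lane 31 keep their own value here.
Lanes shuffle_down(const Lanes& v, std::size_t delta) {
    Lanes out = v;
    for (std::size_t i = 0; i + delta < kSimdWidth; ++i) {
        out[i] = v[i + delta];
    }
    return out;
}

// Tree reduction: add the value `delta` lanes up, halving delta each step.
// After deltas 16, 8, 4, 2, 1, lane 0 holds the sum of all 32 inputs.
float reduce_to_lane0(Lanes v) {
    for (std::size_t delta = kSimdWidth / 2; delta > 0; delta /= 2) {
        const Lanes shifted = shuffle_down(v, delta);
        for (std::size_t i = 0; i < kSimdWidth; ++i) {
            v[i] += shifted[i];
        }
    }
    return v[0];
}
```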

simd_shuffle_xor: Butterfly Pattern

float val = simd_shuffle_xor(my_value, mask);
// Lane i gets the value from lane (i XOR mask)
simd_shuffle_xor(val, 1):    // Swap adjacent pairs
Lane 0 <-> Lane 1
Lane 2 <-> Lane 3
Lane 4 <-> Lane 5
...

simd_shuffle_xor(val, 2):    // Swap pairs of pairs
Lane 0 <-> Lane 2
Lane 1 <-> Lane 3
Lane 4 <-> Lane 6
...

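This butterfly pattern is used in parallel reductions and FFTs. For a sum, five rounds with masks 16, 8, 4, 2, and 1 leave the total in every lane, not just lane 0.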
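The butterfly can be emulated the same way (shuffle_xor and butterfly_sum are our names). Because the XOR pairing is symmetric, both partners end each round with the same partial sum, which is why every lane finishes with the total. This illustrates the pattern only; it is not a claim about how the compiler lowers simd_sum.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

constexpr std::size_t kSimdWidth = 32;
using Lanes = std::array<float, kSimdWidth>;

// Emulates simd_shuffle_xor: lane i receives the value of lane (i XOR mask).
Lanes shuffle_xor(const Lanes& v, std::size_t mask) {
    Lanes out{};
    for (std::size_t i = 0; i < kSimdWidth; ++i) {
        out[i] = v[i ^ mask];
    }
    return out;
}

// Butterfly all-reduce: after masks 16, 8, 4, 2, 1 every lane holds the total.
Lanes butterfly_sum(Lanes v) {
    for (std::size_t mask = kSimdWidth / 2; mask > 0; mask /= 2) {
        const Lanes partner = shuffle_xor(v, mask);
        for (std::size_t i = 0; i < kSimdWidth; ++i) {
            v[i] += partner[i];
        }
    }
    return v;
}
```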

Putting It Together: SIMD-Level Softmax

Here is a practical example – computing softmax over 32 elements using only SIMD intrinsics (no shared memory):

kernel void softmax_simd(
    device const float* input  [[buffer(0)]],
    device float* output       [[buffer(1)]],
    uint lane  [[thread_index_in_simdgroup]],
    uint group [[simdgroup_index_in_threadgroup]],
    uint gid   [[threadgroup_position_in_grid]],
    uint tptg  [[threads_per_threadgroup]]
) {
    // Each SIMD group processes one row of 32 elements
    // (tptg is assumed to be a multiple of 32)
    uint row = gid * (tptg / 32) + group;
    
    float val = input[row * 32 + lane];
    
    // Step 1: Find max across all 32 lanes
    float max_val = simd_max(val);
    
    // Step 2: Subtract max and exponentiate (for numerical stability)
    float exp_val = exp(val - max_val);
    
    // Step 3: Sum all exponentials
    float sum_exp = simd_sum(exp_val);
    
    // Step 4: Normalize
    output[row * 32 + lane] = exp_val / sum_exp;
}

That is a complete softmax in about 10 lines, with no threadgroup memory and no barriers. Each SIMD group handles one 32-element row entirely in registers. For rows of exactly 32 elements it is hard to do much better; longer rows need the threadgroup-level version developed later in this chapter.
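A CPU reference for softmax_simd, with a 32-element array standing in for the lanes of one SIMD group, is handy for checking GPU output (softmax_32 is our name). It follows the same four steps. Subtracting the maximum first is what keeps exp from overflowing: with every input at 1000, exp(1000) would overflow a float, but after the shift every exponent is exp(0) = 1 and each output is exactly 1/32.

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <cmath>
#include <cstddef>

constexpr std::size_t kLanes = 32;
using Row = std::array<float, kLanes>;

Row softmax_32(const Row& in) {
    // Step 1: max across lanes (the value simd_max gives every lane)
    const float max_val = *std::max_element(in.begin(), in.end());
    // Step 2: exponentiate the shifted values
    Row e{};
    for (std::size_t i = 0; i < kLanes; ++i) e[i] = std::exp(in[i] - max_val);
    // Step 3: sum across lanes (the value simd_sum gives every lane)
    float sum = 0.0f;
    for (float v : e) sum += v;
    // Step 4: normalize
    Row out{};
    for (std::size_t i = 0; i < kLanes; ++i) out[i] = e[i] / sum;
    return out;
}
```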

Threadgroup Memory and Barriers

When you need cooperation between threads in different SIMD groups (but within the same threadgroup), you use threadgroup memory and barriers.

Threadgroup memory is a fast on-chip SRAM shared by all threads in a threadgroup. Think of it as a programmer-managed cache.

kernel void reduction_with_shared(
    device const float* input [[buffer(0)]],
    device float* output      [[buffer(1)]],
    uint tid  [[thread_position_in_grid]],
    uint lid  [[thread_position_in_threadgroup]],
    uint gid  [[threadgroup_position_in_grid]],
    uint tptg [[threads_per_threadgroup]]
) {
    // Declare threadgroup memory: one slot per thread
    // (assumes at most 256 threads per threadgroup)
    threadgroup float shared[256];
    
    // Each thread loads one element into shared memory
    shared[lid] = input[tid];
    
    // BARRIER: Wait for all threads to finish writing
    threadgroup_barrier(mem_flags::mem_threadgroup);
    
    // Tree reduction in shared memory (assumes tptg is a power of two)
    for (uint stride = tptg / 2; stride > 0; stride /= 2) {
        if (lid < stride) {
            shared[lid] += shared[lid + stride];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }
    
    // Thread 0 writes the result
    if (lid == 0) {
        output[gid] = shared[0];
    }
}

The threadgroup_barrier(mem_flags::mem_threadgroup) call is critical. It does two things:

  1. Execution barrier: All threads in the threadgroup must reach this point before any can proceed.
  2. Memory fence: All writes to threadgroup memory by threads before the barrier are visible to all threads after the barrier.

Without the barrier, you would read stale data – thread 5 might read shared[10] before thread 10 has written to it. Race conditions on the GPU are just as dangerous as on the CPU, and they are much harder to debug.
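A sequential sketch makes the role of the barrier visible (tree_reduce is our name). Each pass of the outer loop is one barrier-separated phase: the active threads (lid < stride) read only slots at lid + stride, which no thread writes in that phase, and the next phase may only start once every write has landed. Like the kernel, it assumes the threadgroup size is a power of two.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Sequential sketch of reduction_with_shared's tree loop. `shared` holds one
// value per thread; its size plays the role of tptg.
float tree_reduce(std::vector<float> shared) {
    const std::size_t tptg = shared.size();
    assert(tptg > 0 && (tptg & (tptg - 1)) == 0);  // power of two
    for (std::size_t stride = tptg / 2; stride > 0; stride /= 2) {
        // One phase: each thread with lid < stride combines two slots.
        for (std::size_t lid = 0; lid < stride; ++lid) {
            shared[lid] += shared[lid + stride];
        }
        // On the GPU, threadgroup_barrier(mem_flags::mem_threadgroup) sits here.
    }
    return shared[0];
}
```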

When to Use Threadgroup Memory vs SIMD Intrinsics

+---------------------------------------+---------------------------------------+
| SIMD Intrinsics                       | Threadgroup Memory                    |
+---------------------------------------+---------------------------------------+
| 32 threads only (one SIMD group)      | Up to 1024 threads (one threadgroup)  |
| No barriers needed (lockstep)         | Barriers required                     |
| Fastest (register-to-register)        | Fast (on-chip SRAM)                   |
| Simple patterns (reduce, broadcast)   | Arbitrary access patterns             |
| Use FIRST if possible                 | Use when SIMD is not enough           |
+---------------------------------------+---------------------------------------+

The rule of thumb: if you can do it with SIMD intrinsics alone, do that. Fall back to threadgroup memory only when you need cross-SIMD-group communication.

Example: RMSNorm Kernel

Let us write a real kernel used in transformer inference. RMSNorm (Root Mean Square Normalization) is used in LLaMA and many modern models:

RMSNorm(x) = x * weight / sqrt(mean(x^2) + epsilon)

Here is the full kernel:

#include <metal_stdlib>
using namespace metal;

constant uint N [[function_constant(0)]];  // hidden dimension

kernel void rmsnorm(
    device const half* input   [[buffer(0)]],
    device const half* weight  [[buffer(1)]],
    device half* output        [[buffer(2)]],
    constant float& epsilon    [[buffer(3)]],
    uint row  [[threadgroup_position_in_grid]],
    uint lid  [[thread_position_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id   [[simdgroup_index_in_threadgroup]]
) {
    // Each threadgroup processes one row of N elements.
    // Must be dispatched with 256 threads per threadgroup.
    const uint threads = 256;
    device const half* x = input + row * N;
    device half* y = output + row * N;
    
    // Step 1: Compute sum of squares (each thread handles about N/threads elements)
    float sum_sq = 0.0f;
    for (uint i = lid; i < N; i += threads) {
        float val = float(x[i]);
        sum_sq += val * val;
    }
    
    // Step 2: Reduce within SIMD group
    sum_sq = simd_sum(sum_sq);
    
    // Step 3: Reduce across SIMD groups using threadgroup memory
    threadgroup float simd_sums[8];  // 256 / 32 = 8 SIMD groups
    threadgroup float rms_scale;
    if (simd_lane == 0) {
        simd_sums[simd_id] = sum_sq;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    
    // First SIMD group reduces the partial sums and computes the scale
    if (simd_id == 0) {
        float partial = (simd_lane < threads / 32) ? simd_sums[simd_lane] : 0.0f;
        float total = simd_sum(partial);
        if (simd_lane == 0) {
            rms_scale = rsqrt(total / float(N) + epsilon);
        }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    
    // Step 4: Apply normalization
    float scale = rms_scale;
    for (uint i = lid; i < N; i += threads) {
        float val = float(x[i]);
        y[i] = half(val * scale * float(weight[i]));
    }
}

This kernel demonstrates several important MSL patterns:

  1. Function constants (N) for compile-time specialization.
  2. Mixed precision: Load as half, compute in float, store as half.
  3. Two-stage reduction: First within SIMD groups (simd_sum), then across SIMD groups via threadgroup memory.
  4. Stride loop pattern: Each thread processes multiple elements (for (uint i = lid; i < N; i += threads)).
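When testing a kernel like this, a plain CPU reference is useful. This sketch computes the same formula for one row in float (rmsnorm_reference is our name); comparing GPU output against it, with a tolerance that allows for the half-precision stores, is a simple way to catch indexing bugs. For x = (3, 4) and unit weights, mean(x^2) = 12.5, so the outputs are 3/sqrt(12.5) ≈ 0.8485 and 4/sqrt(12.5) ≈ 1.1314.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Reference RMSNorm for one row:
// out[i] = x[i] * weight[i] / sqrt(mean(x^2) + eps), computed in float.
std::vector<float> rmsnorm_reference(const std::vector<float>& x,
                                     const std::vector<float>& weight,
                                     float eps) {
    assert(x.size() == weight.size() && !x.empty());
    float sum_sq = 0.0f;
    for (float v : x) sum_sq += v * v;
    const float scale = 1.0f / std::sqrt(sum_sq / float(x.size()) + eps);
    std::vector<float> out(x.size());
    for (std::size_t i = 0; i < x.size(); ++i) out[i] = x[i] * scale * weight[i];
    return out;
}
```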

Example: Vectorized Element-Wise Operations

For simple element-wise operations (add, multiply, activation functions), vectorized loads are the key optimization:

kernel void silu_activation(
    device const half4* input  [[buffer(0)]],
    device half4* output       [[buffer(1)]],
    uint id [[thread_position_in_grid]]
) {
    half4 x = input[id];
    // SiLU(x) = x * sigmoid(x) = x / (1 + exp(-x))
    half4 sigmoid_x = half4(1.0h) / (half4(1.0h) + exp(-x));
    output[id] = x * sigmoid_x;
}

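By processing half4 (4 elements per thread), each thread moves 8 bytes per load. With a threadgroup size of 256, one threadgroup processes 1024 elements. Wide loads like this help keep the memory system busy. The kernel does assume the element count is a multiple of 4 and that the grid has exactly count/4 threads; otherwise the tail needs separate handling.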
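A CPU reference that mirrors the kernel's four-elements-per-thread layout is sketched below (silu and silu_by_fours are our names, and float stands in for half). The assert encodes the multiple-of-4 precondition.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// SiLU(x) = x * sigmoid(x) = x / (1 + exp(-x))
float silu(float x) { return x / (1.0f + std::exp(-x)); }

// CPU mirror of silu_activation: "thread" id handles elements 4*id .. 4*id+3,
// like one half4.
std::vector<float> silu_by_fours(const std::vector<float>& in) {
    assert(in.size() % 4 == 0);
    std::vector<float> out(in.size());
    const std::size_t threads = in.size() / 4;
    for (std::size_t id = 0; id < threads; ++id) {
        for (std::size_t j = 0; j < 4; ++j) {
            out[4 * id + j] = silu(in[4 * id + j]);
        }
    }
    return out;
}
```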

Example: Fused RoPE (Rotary Position Embeddings)

RoPE is used in most modern LLMs. It rotates pairs of dimensions by an angle that depends on the token's position in the sequence:

constant uint HEAD_DIM [[function_constant(0)]];
constant float ROPE_BASE [[function_constant(1)]];

kernel void rope(
    device half* q          [[buffer(0)]],  // query tensor
    device half* k          [[buffer(1)]],  // key tensor
    constant uint& seq_pos  [[buffer(2)]],  // position in sequence
    uint tid [[thread_position_in_grid]]
) {
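    // Each thread rotates one pair of dimensions in one head.
    // Assumes q and k are laid out as [num_heads][HEAD_DIM] with the same
    // number of heads, and pairs dimension d with d + HEAD_DIM/2 ("rotate half").
    uint pair_idx = tid;                    // global pair index
    uint d = pair_idx % (HEAD_DIM / 2);     // pair index within the head
    uint head = pair_idx / (HEAD_DIM / 2);  // which head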
    
    // Compute rotation angle: theta = pos * base^(-2d/dim)
    float freq = 1.0f / pow(ROPE_BASE, float(2 * d) / float(HEAD_DIM));
    float angle = float(seq_pos) * freq;
    float cos_angle = cos(angle);
    float sin_angle = sin(angle);
    
    // Apply rotation to Q
    uint base_idx = head * HEAD_DIM;
    float q0 = float(q[base_idx + d]);
    float q1 = float(q[base_idx + d + HEAD_DIM / 2]);
    q[base_idx + d]              = half(q0 * cos_angle - q1 * sin_angle);
    q[base_idx + d + HEAD_DIM/2] = half(q0 * sin_angle + q1 * cos_angle);
    
    // Apply same rotation to K
    float k0 = float(k[base_idx + d]);
    float k1 = float(k[base_idx + d + HEAD_DIM / 2]);
    k[base_idx + d]              = half(k0 * cos_angle - k1 * sin_angle);
    k[base_idx + d + HEAD_DIM/2] = half(k0 * sin_angle + k1 * cos_angle);
}

Notice the function constants: HEAD_DIM and ROPE_BASE. These are fixed per model (e.g., 128 and 10000.0 for LLaMA 2), so the compiler can treat them as constants and fold or simplify parts of the frequency computation.
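A CPU reference for one head helps pin down the pairing convention (rope_rotate_half is our name). This sketch uses the same "rotate half" pairing as the kernel, dimension d with d + HEAD_DIM/2; some implementations instead rotate adjacent pairs (2i, 2i+1), and the two conventions are generally not interchangeable without permuting the weights. Two quick checks: at position 0 every angle is 0, so the vector is unchanged, and since each pair is only rotated, the vector's length is preserved.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Rotate-half RoPE applied in place to one head vector at position `pos`.
void rope_rotate_half(std::vector<float>& v, unsigned pos, float base) {
    const std::size_t head_dim = v.size();
    assert(head_dim % 2 == 0);
    const std::size_t half_dim = head_dim / 2;
    for (std::size_t d = 0; d < half_dim; ++d) {
        // theta = pos * base^(-2d/dim), as in the kernel
        const float freq = 1.0f / std::pow(base, float(2 * d) / float(head_dim));
        const float angle = float(pos) * freq;
        const float c = std::cos(angle);
        const float s = std::sin(angle);
        const float x0 = v[d];
        const float x1 = v[d + half_dim];
        v[d] = x0 * c - x1 * s;
        v[d + half_dim] = x0 * s + x1 * c;
    }
}
```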

Built-in Math Functions

MSL provides a comprehensive set of math functions. Here are the ones you will use most often in ML kernels:

// Assumed inputs: x, a, b, c, lo, hi, base, exponent are floats,
// h is a half, and v1, v2 are float4 vectors.
void math_cheatsheet(float x, half h, float a, float b, float c,
                     float lo, float hi, float base, float exponent,
                     float4 v1, float4 v2) {
    // Exponential and logarithm
    float e  = exp(x);            // e^x
    float l  = log(x);            // natural log
    half  eh = exp(h);            // FP16 overload of exp

    // Trigonometric
    float sn = sin(x);
    float cs = cos(x);

    // Power and roots
    float p  = pow(base, exponent);
    float sq = sqrt(x);
    float r  = rsqrt(x);          // 1/sqrt(x) -- cheaper than 1.0/sqrt(x)

    // Min, max, clamp
    float mn = min(a, b);
    float mx = max(a, b);
    float cl = clamp(x, lo, hi);  // max(lo, min(x, hi)) when lo <= hi

    // Absolute value
    float ab = abs(x);

    // Fused multiply-add (one rounding instead of two)
    float f  = fma(a, b, c);      // a*b + c

    // Dot product of vectors
    float d  = dot(v1, v2);       // v1.x*v2.x + v1.y*v2.y + ...

    // Type conversion (explicit)
    half  hf = half(x);           // float -> half (may lose precision)
    float fh = float(h);          // half -> float (lossless)
}

The rsqrt function deserves special mention. You will see it everywhere in ML kernels (normalization layers, attention scaling). It computes 1/sqrt(x) directly, which is typically cheaper than computing sqrt(x) and then dividing.

Structs and Parameter Passing

For kernels with many parameters, pack them into a struct:

struct GEMMParams {
    uint M;        // rows of A, rows of C
    uint N;        // cols of B, cols of C
    uint K;        // cols of A, rows of B
    float alpha;   // scaling factor
};

kernel void gemm(
    device const half* A  [[buffer(0)]],
    device const half* B  [[buffer(1)]],
    device half* C        [[buffer(2)]],
    constant GEMMParams& params [[buffer(3)]],
    uint2 gid [[threadgroup_position_in_grid]],
    uint2 lid [[thread_position_in_threadgroup]]  // vector width must match gid's
) {
    uint M = params.M;
    uint N = params.N;
    uint K = params.K;
    // ...
}

On the CPU side, this struct gets passed via setBytes() (for data under 4KB) or via an MTLBuffer. The struct layout must match between MSL and the host language, so you need to be careful about alignment. For scalar members MSL follows the usual C layout rules, and your Rust/Swift struct must match. Vector types are the common trap: float3, for example, is 16 bytes with 16-byte alignment in MSL (packed_float3 is the 12-byte alternative).

// Rust side -- must match the MSL struct layout
#[repr(C)]
struct GEMMParams {
    m: u32,
    n: u32,
    k: u32,
    alpha: f32,
}

The #[repr(C)] attribute ensures Rust uses C-compatible layout, which matches MSL’s struct layout.
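The same layout contract can be checked at compile time on a C++ host. This sketch mirrors GEMMParams and uses static_assert so that a mismatch fails the build instead of producing garbage on the GPU. It also shows the vector-type trap: HostFloat3 and WithVec are hypothetical host types mirroring MSL's float3 (16 bytes, 16-byte aligned) and a struct { float s; float3 v; }.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Host-side mirror of the MSL GEMMParams struct: four 4-byte scalars,
// laid out back to back with no padding.
struct GEMMParams {
    std::uint32_t M;  // rows of A, rows of C
    std::uint32_t N;  // cols of B, cols of C
    std::uint32_t K;  // cols of A, rows of B
    float alpha;      // scaling factor
};
static_assert(sizeof(GEMMParams) == 16, "GEMMParams must be 16 bytes");
static_assert(offsetof(GEMMParams, alpha) == 12, "alpha must start at byte 12");

// Hypothetical host mirror of MSL's float3: 16 bytes with 16-byte alignment,
// unlike a plain struct of three floats (12 bytes).
struct alignas(16) HostFloat3 {
    float x, y, z;
};
static_assert(sizeof(HostFloat3) == 16, "float3 mirror must be padded to 16 bytes");

// Hypothetical mirror of an MSL struct { float s; float3 v; }:
// v starts at byte 16 and the whole struct is 32 bytes.
struct WithVec {
    float s;
    HostFloat3 v;
};
static_assert(offsetof(WithVec, v) == 16, "v must start at byte 16");
```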

Conditional Compilation with Function Constants

We touched on function constants in the previous chapter. Here is a deeper look at how they enable conditional compilation in MSL:

constant bool HAS_BIAS     [[function_constant(0)]];
constant bool HAS_RESIDUAL [[function_constant(1)]];
constant uint BLOCK_SIZE   [[function_constant(2)]];

kernel void linear_layer(
    device const half* input    [[buffer(0)]],
    device const half* weights  [[buffer(1)]],
    device half* output         [[buffer(2)]],
    // These arguments exist only in variants where their constant is true
    device const half* bias     [[buffer(3), function_constant(HAS_BIAS)]],
    device const half* residual [[buffer(4), function_constant(HAS_RESIDUAL)]],
    constant uint& N            [[buffer(5)]],
    uint tid [[thread_position_in_grid]]
) {
    half result = /* ... compute matmul ... */ ;
    
    // Function constants are fixed when the pipeline is built, so these
    // branches are resolved at compile time: no runtime branch penalty.
    if (HAS_BIAS) {
        result += bias[tid];
    }
    
    if (HAS_RESIDUAL) {
        result += residual[tid];
    }
    
    output[tid] = result;
}

When HAS_BIAS is false, the compiler removes the bias addition, and because the bias argument is declared with function_constant(HAS_BIAS), the argument itself is absent from that variant, so there is no buffer to bind. The generated code is equivalent to a kernel that never had bias support in the first place. This lets you write one kernel source that generates many specialized variants.
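Function constants have a rough CPU-side analogy in C++ templates. The sketch below is only that, an analogy and not Metal API: HasBias and HasResidual are template parameters standing in for HAS_BIAS and HAS_RESIDUAL, each instantiation is a separately compiled variant, and because the conditions are compile-time constants an optimizing compiler typically drops the disabled branches.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Epilogue applied after a (not shown) matmul. Each <HasBias, HasResidual>
// combination is its own compiled function, like a specialized pipeline.
template <bool HasBias, bool HasResidual>
void epilogue(std::vector<float>& out,
              const std::vector<float>& bias,
              const std::vector<float>& residual) {
    for (std::size_t i = 0; i < out.size(); ++i) {
        float r = out[i];
        if (HasBias) r += bias[i];          // never executed when HasBias is false
        if (HasResidual) r += residual[i];  // never executed when HasResidual is false
        out[i] = r;
    }
}
```

Usage: epilogue<true, false>(out, bias, empty) adds only the bias; a variant with both flags false leaves out untouched and never reads the other vectors.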

Atomic Operations

Sometimes multiple threads need to update the same memory location. MSL provides atomic operations for this:

kernel void histogram(
    device const uint* data       [[buffer(0)]],
    device atomic_uint* bins      [[buffer(1)]],
    uint tid [[thread_position_in_grid]]
) {
    uint value = data[tid];
    uint bin = value % 256;
    atomic_fetch_add_explicit(&bins[bin], 1, memory_order_relaxed);
}

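Atomics are slow when many threads hit the same address, because those updates are serialized. Avoid them in hot paths. If you find yourself using atomics in a performance-critical kernel, there is usually a better algorithm, such as building a private histogram per threadgroup in threadgroup memory and merging it into the global bins once at the end.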
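The per-threadgroup idea can be sketched sequentially in C++ (the names are ours). Each chunk of group_size inputs plays the role of one threadgroup: it fills a private 256-bin histogram, which on the GPU would live in threadgroup memory, then merges only the bins it touched into the global result. The global_updates counter stands in for device-memory atomics; for 1024 inputs that all land in one bin, with 256-element groups, that is 4 global updates instead of 1024.

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

constexpr std::size_t kBins = 256;

struct HistogramResult {
    std::array<std::uint32_t, kBins> bins{};
    std::size_t global_updates = 0;  // stands in for device-memory atomics
};

// Sequential sketch of a privatized histogram.
HistogramResult privatized_histogram(const std::vector<std::uint32_t>& data,
                                     std::size_t group_size) {
    assert(group_size > 0);
    HistogramResult result;
    for (std::size_t start = 0; start < data.size(); start += group_size) {
        std::array<std::uint32_t, kBins> local{};  // threadgroup-memory stand-in
        const std::size_t end = std::min(start + group_size, data.size());
        for (std::size_t i = start; i < end; ++i) local[data[i] % kBins] += 1;
        for (std::size_t b = 0; b < kBins; ++b) {
            if (local[b] != 0) {  // one "atomic add" per touched bin
                result.bins[b] += local[b];
                ++result.global_updates;
            }
        }
    }
    return result;
}
```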

Putting It All Together: A Complete Softmax Kernel

Let us write a more complete softmax kernel that handles arbitrary row sizes using the MSL features we have covered:

#include <metal_stdlib>
using namespace metal;

constant uint COLS [[function_constant(0)]];

kernel void softmax(
    device const half* input  [[buffer(0)]],
    device half* output       [[buffer(1)]],
    uint gid  [[threadgroup_position_in_grid]],
    uint lid  [[thread_position_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id   [[simdgroup_index_in_threadgroup]]
) {
    // Each threadgroup processes one row.
    // Must be dispatched with 256 threads per threadgroup (8 SIMD groups).
    uint row = gid;
    uint threads = 256;
    
    // Threadgroup memory for cross-SIMD-group reduction
    threadgroup float simd_max_vals[8];
    threadgroup float simd_sum_vals[8];
    
    // ---- Pass 1: Find max ----
    float local_max = -INFINITY;
    for (uint i = lid; i < COLS; i += threads) {
        local_max = max(local_max, float(input[row * COLS + i]));
    }
    
    // Reduce within SIMD group
    local_max = simd_max(local_max);
    
    // Reduce across SIMD groups
    if (simd_lane == 0) {
        simd_max_vals[simd_id] = local_max;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    
    float row_max = -INFINITY;  // only meaningful in SIMD group 0 until broadcast
    if (simd_id == 0) {
        float v = (simd_lane < threads / 32) ? 
                  simd_max_vals[simd_lane] : -INFINITY;
        row_max = simd_max(v);
    }
    
    // Broadcast max to all threads
    threadgroup float shared_max;
    if (lid == 0) shared_max = row_max;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    row_max = shared_max;
    
    // ---- Pass 2: Compute exp and sum ----
    float local_sum = 0.0f;
    for (uint i = lid; i < COLS; i += threads) {
        float val = exp(float(input[row * COLS + i]) - row_max);
        local_sum += val;
    }
    
    // Reduce sum (same two-stage pattern)
    local_sum = simd_sum(local_sum);
    if (simd_lane == 0) {
        simd_sum_vals[simd_id] = local_sum;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    
    float row_sum = 0.0f;       // only meaningful in SIMD group 0 until broadcast
    if (simd_id == 0) {
        float v = (simd_lane < threads / 32) ? 
                  simd_sum_vals[simd_lane] : 0.0f;
        row_sum = simd_sum(v);
    }
    
    threadgroup float shared_sum;
    if (lid == 0) shared_sum = row_sum;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    row_sum = shared_sum;
    
    // ---- Pass 3: Normalize and write output ----
    float inv_sum = 1.0f / row_sum;
    for (uint i = lid; i < COLS; i += threads) {
        float val = exp(float(input[row * COLS + i]) - row_max);
        output[row * COLS + i] = half(val * inv_sum);
    }
}

This kernel makes three passes over the data (find max, compute exponentials and sum, normalize). The first two passes each use the two-stage reduction pattern (SIMD-level reduction followed by a threadgroup-level reduction); the third recomputes the exponentials rather than storing them. The function constant COLS lets the compiler optimize the loop bounds and index arithmetic.
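To check the kernel's logic on the CPU, here is a sketch that mirrors its structure for one row (the helper names are ours). 256 "threads" each reduce a strided slice, partial results are combined first per 32-lane SIMD group and then across the 8 groups, and pass 3 recomputes the exponentials. It assumes the same fixed threadgroup size of 256 as the kernel.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

constexpr std::size_t kThreads = 256;                       // threadgroup size
constexpr std::size_t kSimdWidth = 32;
constexpr std::size_t kSimdGroups = kThreads / kSimdWidth;  // 8

// Two-stage reduction of one value per thread: first within each 32-lane
// SIMD group, then across the 8 SIMD-group results.
template <typename Op>
float two_stage_reduce(const std::vector<float>& per_thread, float identity, Op op) {
    std::vector<float> simd_partials(kSimdGroups, identity);
    for (std::size_t lid = 0; lid < kThreads; ++lid) {
        float& p = simd_partials[lid / kSimdWidth];
        p = op(p, per_thread[lid]);
    }
    float total = identity;
    for (float p : simd_partials) total = op(total, p);
    return total;
}

// CPU mirror of the softmax kernel for one row of any length.
std::vector<float> softmax_row(const std::vector<float>& row) {
    const float neg_inf = -std::numeric_limits<float>::infinity();
    std::vector<float> partial(kThreads);

    // Pass 1: per-thread max over a strided slice, then reduce.
    for (std::size_t lid = 0; lid < kThreads; ++lid) {
        float m = neg_inf;
        for (std::size_t i = lid; i < row.size(); i += kThreads) m = std::max(m, row[i]);
        partial[lid] = m;
    }
    const float row_max = two_stage_reduce(partial, neg_inf,
        [](float a, float b) { return std::max(a, b); });

    // Pass 2: per-thread sum of exp(x - max), then reduce.
    for (std::size_t lid = 0; lid < kThreads; ++lid) {
        float s = 0.0f;
        for (std::size_t i = lid; i < row.size(); i += kThreads) s += std::exp(row[i] - row_max);
        partial[lid] = s;
    }
    const float row_sum = two_stage_reduce(partial, 0.0f,
        [](float a, float b) { return a + b; });

    // Pass 3: recompute the exponentials and normalize.
    std::vector<float> out(row.size());
    const float inv_sum = 1.0f / row_sum;
    for (std::size_t i = 0; i < row.size(); ++i) out[i] = std::exp(row[i] - row_max) * inv_sum;
    return out;
}
```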

Summary

MSL is a practical language for writing GPU kernels. Here are the key takeaways:

  1. MSL is C++14 with GPU extensions and a few C++ features removed. Few surprises if you know C++.
  2. Four address spaces: device (main memory), constant (cached read-only), threadgroup (fast shared SRAM), thread (registers).
  3. Vector types (half4, float4) are essential for memory throughput.
  4. The half type halves memory traffic and register use compared to float. Use it for storage in ML kernels, and accumulate in float where precision matters.
  5. SIMD intrinsics (simd_sum, simd_max, simd_shuffle, etc.) enable fast intra-SIMD-group communication without shared memory or barriers.
  6. Threadgroup memory + barriers enable cross-SIMD-group communication within a threadgroup.
  7. Function constants create specialized kernel variants at compile time.
  8. The two-stage reduction (SIMD reduce, then threadgroup reduce) is one of the most common patterns in ML kernels.

In the next chapter, we will zoom out from individual threads to look at how threadgroups and SIMD groups are organized and dispatched.