The Metal Shading Language
The Metal Shading Language – MSL for short – is what you write GPU kernels in. If you already know C++, you are going to feel right at home, because MSL is essentially C++14 with some things removed (no exceptions, no virtual functions, no RTTI, no recursive function calls) and a bunch of GPU-specific features added. It is not a weird domain-specific language. It is not a visual graph editor. It is just C++ with extra address spaces, vector types, and thread-coordination primitives bolted on.
In this chapter, we are going to cover the MSL features that matter most for ML kernel development. We will write several complete kernels along the way, starting simple and building up to real patterns used in inference engines.
The Basics: Kernel Functions
A compute kernel in MSL is declared with the `kernel` keyword (or equivalently, the `[[kernel]]` attribute):
#include <metal_stdlib>
using namespace metal;
kernel void add_arrays(
    device const float* a [[buffer(0)]],
    device const float* b [[buffer(1)]],
    device float* c [[buffer(2)]],
    uint id [[thread_position_in_grid]]
) {
    c[id] = a[id] + b[id];
}
Let us break down every piece of this.
- `kernel void`: This is a GPU entry point. Kernels always return `void` – they communicate results by writing to buffers.
- `device const float* a`: A pointer into device memory (more on address spaces below). The `[[buffer(0)]]` attribute tells Metal which buffer binding index this parameter corresponds to.
- `uint id [[thread_position_in_grid]]`: A built-in variable that gives this thread its unique ID in the dispatch grid.
That is the most minimal kernel possible: read two values, add them, write the result. Every thread processes one element. A thousand threads process a thousand elements. The GPU runs them all in parallel.
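If it helps to see the execution model without a GPU at hand, here is a plain-Python sketch (purely illustrative, not a Metal API): the loop stands in for the parallel grid, and `thread_id` plays the role of `[[thread_position_in_grid]]`.

```python
def add_arrays(a, b, grid_size):
    """Serial stand-in for the add_arrays kernel above."""
    c = [0.0] * grid_size
    for thread_id in range(grid_size):  # the GPU runs every iteration in parallel
        c[thread_id] = a[thread_id] + b[thread_id]
    return c

c = add_arrays([1.0, 2.0, 3.0], [10.0, 20.0, 30.0], 3)
```

The GPU executes all iterations of that loop at once; nothing in the kernel body may depend on the order in which threads run.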
Address Spaces
In regular C++, all pointers live in one big flat address space. On the GPU, memory is divided into distinct address spaces with different performance characteristics and access rules. Every pointer in MSL must be qualified with its address space.
+------------------------------------------------------------------+
|                          GPU Memory Map                          |
+------------------------------------------------------------------+
| |
| device Large, slow-ish (DRAM / unified memory) |
| [read/write] All threads can access. This is where your |
| weight matrices and activation tensors live. |
| Hundreds of GB/s bandwidth on Apple Silicon. |
| |
+------------------------------------------------------------------+
| |
| constant Same physical memory as device, but goes |
| [read-only] through the constant cache. Best for small, |
| uniform data read by all threads (like params). |
| Limited to 64KB per argument. |
| |
+------------------------------------------------------------------+
| |
| threadgroup Fast on-chip SRAM (like CUDA shared memory). |
| [read/write] Only threads in the same threadgroup can access. |
| 32KB per threadgroup on Apple Silicon. |
| Much faster than device memory. |
| |
+------------------------------------------------------------------+
| |
| thread Registers. Private to each thread. |
| [read/write] Fastest possible access. |
| Limited by register file size. |
| |
+------------------------------------------------------------------+
Here is how they look in practice:
kernel void address_space_demo(
    device float* data [[buffer(0)]],              // device address space
    constant Params& params [[buffer(1)]],         // constant address space
    threadgroup float* shared [[threadgroup(0)]],  // threadgroup (rarely used this way)
    uint tid [[thread_position_in_grid]]
) {
    // 'thread' address space -- local variable, lives in registers
    float local_val = data[tid];

    // Read from constant (cached, good for broadcast reads)
    float scale = params.scale;

    // Write to device (main memory)
    data[tid] = local_val * scale;
}
When to Use Each Address Space
| Address Space | Use For | Example |
|---|---|---|
| `device` | Large read/write buffers | Weight matrices, activations, outputs |
| `constant` | Small read-only data broadcast to all threads | Kernel parameters, lookup tables |
| `threadgroup` | Shared scratchpad for threads in a group | Reduction accumulators, tiled data |
| `thread` | Per-thread temporaries | Loop variables, accumulators |
The most common mistake newcomers make is putting everything in device. If you have a small struct of parameters (dimensions, strides, scaling factors), put it in constant – the constant cache will broadcast it to all threads efficiently.
Vector and Matrix Types
MSL provides built-in vector types that map directly to the GPU’s SIMD hardware. These are not library abstractions – they are native types that the hardware processes in a single instruction.
Scalar Types
half h = 1.0h; // 16-bit float (FP16)
float f = 1.0f; // 32-bit float (FP32)
int i = 42; // 32-bit signed integer
uint u = 42u; // 32-bit unsigned integer
short s = 42; // 16-bit signed integer
ushort us = 42u; // 16-bit unsigned integer
bool b = true; // boolean
Vector Types
Vectors come in sizes 2, 3, and 4:
half2 h2 = half2(1.0h, 2.0h); // 2 x FP16
half4 h4 = half4(1.0h, 2.0h, 3.0h, 4.0h); // 4 x FP16
float2 f2 = float2(1.0f, 2.0f); // 2 x FP32
float4 f4 = float4(1.0f, 2.0f, 3.0f, 4.0f); // 4 x FP32
uint2 u2 = uint2(10, 20); // 2 x uint32
You can swizzle components (rearrange them) using .xyzw or .rgba:
float4 v = float4(1.0, 2.0, 3.0, 4.0);
float2 xy = v.xy; // (1.0, 2.0)
float z = v.z; // 3.0
float4 rev = v.wzyx; // (4.0, 3.0, 2.0, 1.0)
Why Vectors Matter for ML
Loading a half4 from memory is a single 64-bit load instead of four separate 16-bit loads. This is huge for memory-bound kernels (which most ML kernels are):
Single loads (bad): Vectorized load (good):
load h[0] -- 16 bits load h4 -- 64 bits, one instruction
load h[1] -- 16 bits
load h[2] -- 16 bits 4x fewer load instructions
load h[3] -- 16 bits Better memory bus utilization
4 instructions total 1 instruction total
We will come back to vectorized loads in the performance chapter. For now, just remember: prefer half4 and float4 over scalar loads whenever possible.
Thread Identification
Every thread in a compute dispatch has several built-in identifiers. Understanding which one to use is crucial.
kernel void thread_id_demo(
    uint tid [[thread_position_in_grid]],
    uint tpg [[threads_per_grid]],
    uint gid [[threadgroup_position_in_grid]],
    uint tgpg [[threadgroups_per_grid]],
    uint lid [[thread_position_in_threadgroup]],
    uint tptg [[threads_per_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]]
) {
    // tid = global thread ID (0, 1, 2, ..., total_threads - 1)
    // gid = which threadgroup this thread belongs to
    // lid = local thread ID within the threadgroup (0, 1, ..., tptg - 1)
    // simd_lane = lane within the SIMD group (0 - 31)
    // simd_id = which SIMD group within the threadgroup
}
Here is how they relate:
Grid (all threads)
+------------------------------------------------------------------+
| Threadgroup 0 (gid=0) | Threadgroup 1 (gid=1) | ... |
| +------------------------+ | +------------------------+ | |
| | SIMD Group 0 (simd=0) | | | SIMD Group 0 | | |
| | lanes 0..31 | | | lanes 0..31 | | |
| +------------------------+ | +------------------------+ | |
| | SIMD Group 1 (simd=1) | | | SIMD Group 1 | | |
| | lanes 0..31 | | | lanes 0..31 | | |
| +------------------------+ | +------------------------+ | |
| | SIMD Group 2 (simd=2) | | | SIMD Group 2 | | |
| | lanes 0..31 | | | lanes 0..31 | | |
| +------------------------+ | +------------------------+ | |
| | SIMD Group 3 (simd=3) | | | SIMD Group 3 | | |
| | lanes 0..31 | | | lanes 0..31 | | |
| +------------------------+ | +------------------------+ | |
+------------------------------------------------------------------+
For threadgroup size = 128:
lid = simd_id * 32 + simd_lane
tid = gid * 128 + lid
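These relations are easy to sanity-check on the host. Here is a small Python sketch (assuming a 1-D dispatch with 128-thread threadgroups) that derives every ID from the global thread ID and confirms the two identities above:

```python
SIMD_WIDTH = 32
THREADGROUP_SIZE = 128

def ids_for(tid):
    """Derive the built-in IDs from a global thread ID (1-D dispatch)."""
    gid = tid // THREADGROUP_SIZE       # threadgroup_position_in_grid
    lid = tid % THREADGROUP_SIZE        # thread_position_in_threadgroup
    simd_id = lid // SIMD_WIDTH         # simdgroup_index_in_threadgroup
    simd_lane = lid % SIMD_WIDTH        # thread_index_in_simdgroup
    return gid, lid, simd_id, simd_lane

# The two identities from the diagram hold for every thread:
for tid in range(4096):
    gid, lid, simd_id, simd_lane = ids_for(tid)
    assert lid == simd_id * 32 + simd_lane
    assert tid == gid * 128 + lid
```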
Which ID Should You Use?
- `thread_position_in_grid`: When each thread processes one element of an array. By far the most common.
- `threadgroup_position_in_grid` + `thread_position_in_threadgroup`: When threads cooperate within a threadgroup (reductions, tiled algorithms).
- `thread_index_in_simdgroup` + `simdgroup_index_in_threadgroup`: When you are doing SIMD-level tricks (shuffles, reductions, matrix ops).
The half Type: FP16 for ML
The half type is a 16-bit floating-point number. It has:
- 1 sign bit
- 5 exponent bits (range: ~6 x 10^-8 to 65504)
- 10 mantissa bits (~3.3 decimal digits of precision)
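You can poke at this format from Python: the `struct` module's `'e'` format is IEEE 754 binary16, so a pack/unpack round trip rounds a value to the nearest representable FP16 number.

```python
import struct

def to_half(x):
    """Round x to the nearest FP16 value via a binary16 round trip."""
    return struct.unpack('<e', struct.pack('<e', x))[0]

largest = to_half(65504.0)   # the largest finite FP16 value survives exactly
pi_h = to_half(3.14159265)   # only ~3.3 decimal digits survive: 3.140625
tiny = to_half(6e-8)         # rounds to the subnormal floor, 2^-24 (~5.96e-8)
```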
For ML inference, half is the workhorse type. Here is why:
- **2x throughput**: Apple GPUs can process twice as many `half` operations per clock compared to `float`.
- **2x memory bandwidth**: A `half` is 2 bytes vs 4 bytes for `float`. Since most ML kernels are memory-bound, this effectively doubles your throughput.
- **Sufficient precision**: Neural network weights and activations rarely need more than 3 digits of precision. FP16 is more than enough for inference.
// FP16 dot product -- 2x the throughput of FP32
kernel void dot_product_fp16(
    device const half4* a [[buffer(0)]],
    device const half4* b [[buffer(1)]],
    device half* result [[buffer(2)]],
    uint id [[thread_position_in_grid]]
) {
    // Each thread computes dot product of one half4 pair
    // That is 4 multiply-adds in a single vector instruction
    half4 va = a[id];
    half4 vb = b[id];
    result[id] = dot(va, vb); // va.x*vb.x + va.y*vb.y + ...
}
Mixed Precision
Sometimes you want FP16 for storage and memory transfer but FP32 for accumulation to avoid precision loss:
kernel void gemv_mixed_precision(
    device const half4* W [[buffer(0)]],  // Weights in FP16
    device const half4* x [[buffer(1)]],  // Input in FP16
    device float* y [[buffer(2)]],        // Output in FP32
    constant uint& K [[buffer(3)]],
    uint row [[thread_position_in_grid]]
) {
    float sum = 0.0f; // Accumulate in FP32 for precision
    uint k4 = K / 4;  // Process 4 elements at a time via half4
    for (uint i = 0; i < k4; i++) {
        half4 w = W[row * k4 + i];
        half4 v = x[i];
        // Widen to FP32, then multiply and accumulate
        sum += float(w.x) * float(v.x);
        sum += float(w.y) * float(v.y);
        sum += float(w.z) * float(v.z);
        sum += float(w.w) * float(v.w);
    }
    y[row] = sum;
}
This is a common pattern in ML inference: load data as FP16 (saving bandwidth), accumulate in FP32 (preserving precision), and write the result back as whatever the downstream consumer needs.
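The effect is easy to demonstrate on the host. This Python sketch rounds every partial sum back to FP16 (as a naive all-`half` accumulator would) and compares it with a wide accumulator: the FP16 sum stalls once the running total is so large that adding 0.1 no longer changes the nearest representable half value.

```python
import struct

def half(x):
    """Round x to the nearest FP16 value."""
    return struct.unpack('<e', struct.pack('<e', x))[0]

data = [half(0.1)] * 4096  # values as they would sit in an FP16 buffer

sum_fp16 = 0.0
for v in data:
    sum_fp16 = half(sum_fp16 + v)  # accumulator rounded to FP16 every step

sum_fp32 = 0.0
for v in data:
    sum_fp32 += v                  # wide accumulator (like the kernel's float)
```

The FP16 accumulator gets stuck at 256, because the FP16 spacing there (0.25) is larger than twice the increment; the wide accumulator reaches the correct 409.5.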
SIMD Group Intrinsics
Now we get to the really fun stuff. SIMD group intrinsics let threads within a 32-thread SIMD group communicate directly through the register file – no shared memory, no barriers, no synchronization overhead.
On Apple GPUs, a SIMD group (also called a warp on NVIDIA or a wavefront on AMD) is always 32 threads. These 32 threads execute in lockstep – the same instruction, at the same time, on adjacent data. SIMD intrinsics let these threads share values with each other.
simd_sum: Reduction Within a SIMD Group
// Sum 32 values (one from each lane) in a single operation
float lane_value = data[tid];
float total = simd_sum(lane_value);
// Now ALL 32 lanes have the same total
Without simd_sum, you would need a tree reduction with shared memory and barriers. With it, the hardware does it for you in a few cycles.
Before simd_sum: After simd_sum:
Lane 0: 3.0 Lane 0: sum of all 32
Lane 1: 1.0 Lane 1: sum of all 32
Lane 2: 4.0 Lane 2: sum of all 32
... ...
Lane 31: 2.0 Lane 31: sum of all 32
simd_max / simd_min: Finding Extremes
float lane_value = data[tid];
float max_val = simd_max(lane_value); // Max across all 32 lanes
float min_val = simd_min(lane_value); // Min across all 32 lanes
This is essential for softmax, which needs the maximum value before computing exponentials.
simd_shuffle: Direct Lane-to-Lane Communication
// Read the value from a specific lane
float val = simd_shuffle(my_value, target_lane);
simd_shuffle lets any lane read any other lane’s value. It is like having a crossbar switch between all 32 registers.
simd_shuffle_down: Shift Values Down
float val = simd_shuffle_down(my_value, delta);
// Lane i gets the value from lane (i + delta)
Before simd_shuffle_down(val, 1): After:
Lane 0: A Lane 0: B (got lane 1's value)
Lane 1: B Lane 1: C (got lane 2's value)
Lane 2: C Lane 2: D (got lane 3's value)
Lane 3: D Lane 3: E (got lane 4's value)
... ...
This is the building block for prefix sums and sequential reductions.
simd_shuffle_xor: Butterfly Pattern
float val = simd_shuffle_xor(my_value, mask);
// Lane i gets the value from lane (i XOR mask)
simd_shuffle_xor(val, 1): // Swap adjacent pairs
Lane 0 <-> Lane 1
Lane 2 <-> Lane 3
Lane 4 <-> Lane 5
...
simd_shuffle_xor(val, 2): // Swap pairs of pairs
Lane 0 <-> Lane 2
Lane 1 <-> Lane 3
Lane 4 <-> Lane 6
...
This butterfly pattern is used in parallel reductions and FFTs.
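To see why the butterfly works, here is a Python emulation of a 32-lane SIMD group (an illustrative model, not a Metal API): five `shuffle_xor` rounds with masks 1, 2, 4, 8, 16 leave every lane holding the sum of all 32 values – which is one way a `simd_sum` can be built from shuffles.

```python
SIMD_WIDTH = 32

def simd_shuffle_xor(lanes, mask):
    """Each lane i reads the value held by lane i ^ mask."""
    return [lanes[i ^ mask] for i in range(SIMD_WIDTH)]

def butterfly_sum(lanes):
    vals = list(lanes)
    mask = 1
    while mask < SIMD_WIDTH:  # 5 rounds: masks 1, 2, 4, 8, 16
        partner = simd_shuffle_xor(vals, mask)
        vals = [v + p for v, p in zip(vals, partner)]
        mask *= 2
    return vals               # every lane now holds the full sum

result = butterfly_sum(list(range(32)))
```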
Putting It Together: SIMD-Level Softmax
Here is a practical example – computing softmax over 32 elements using only SIMD intrinsics (no shared memory):
kernel void softmax_simd(
    device const float* input [[buffer(0)]],
    device float* output [[buffer(1)]],
    uint lane [[thread_index_in_simdgroup]],
    uint group [[simdgroup_index_in_threadgroup]],
    uint gid [[threadgroup_position_in_grid]],
    uint tptg [[threads_per_threadgroup]]
) {
    // Each SIMD group processes one row of 32 elements
    uint row = gid * (tptg / 32) + group;
    float val = input[row * 32 + lane];

    // Step 1: Find max across all 32 lanes
    float max_val = simd_max(val);

    // Step 2: Subtract max and exponentiate (for numerical stability)
    float exp_val = exp(val - max_val);

    // Step 3: Sum all exponentials
    float sum_exp = simd_sum(exp_val);

    // Step 4: Normalize
    output[row * 32 + lane] = exp_val / sum_exp;
}
That is a complete softmax in about 10 lines, with no shared memory and no barriers. Each SIMD group handles one 32-element row entirely within its register file. For rows of this size, it is hard to beat for either brevity or speed.
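A plain-Python reference implementation is handy for testing kernels like this one – the same four steps, minus the parallelism. Note the shift invariance that makes the max-subtraction safe: subtracting a constant from every input leaves the output unchanged, so large inputs no longer overflow `exp`.

```python
import math

def softmax(row):
    m = max(row)                           # Step 1: the kernel's simd_max
    exps = [math.exp(v - m) for v in row]  # Step 2: stable exponentials
    total = sum(exps)                      # Step 3: the kernel's simd_sum
    return [e / total for e in exps]       # Step 4: normalize

probs = softmax([1.0, 2.0, 3.0, 4.0])
shifted = softmax([1001.0, 1002.0, 1003.0, 1004.0])  # no overflow, same result
```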
Threadgroup Memory and Barriers
When you need cooperation between threads in different SIMD groups (but within the same threadgroup), you use threadgroup memory and barriers.
Threadgroup memory is a fast on-chip SRAM shared by all threads in a threadgroup. Think of it as a programmer-managed cache.
kernel void reduction_with_shared(
    device const float* input [[buffer(0)]],
    device float* output [[buffer(1)]],
    uint tid [[thread_position_in_grid]],
    uint lid [[thread_position_in_threadgroup]],
    uint gid [[threadgroup_position_in_grid]],
    uint tptg [[threads_per_threadgroup]]
) {
    // Declare threadgroup memory
    threadgroup float shared[256]; // One slot per thread

    // Each thread loads one element into shared memory
    shared[lid] = input[tid];

    // BARRIER: Wait for all threads to finish writing
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Tree reduction in shared memory
    for (uint stride = tptg / 2; stride > 0; stride /= 2) {
        if (lid < stride) {
            shared[lid] += shared[lid + stride];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Thread 0 writes the result
    if (lid == 0) {
        output[gid] = shared[0];
    }
}
The threadgroup_barrier(mem_flags::mem_threadgroup) call is critical. It does two things:
- Execution barrier: All threads in the threadgroup must reach this point before any can proceed.
- Memory fence: All writes to threadgroup memory by threads before the barrier are visible to all threads after the barrier.
Without the barrier, you would read stale data – thread 5 might read shared[10] before thread 10 has written to it. Race conditions on the GPU are just as dangerous as on the CPU, and they are much harder to debug.
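The tree reduction itself is worth internalizing. Here is the same stride-halving loop in serial Python; serial execution makes the barrier implicit, but on the GPU it is what guarantees each round sees the previous round's writes.

```python
def tree_reduce(values):
    """Serial model of the kernel's stride-halving reduction."""
    shared = list(values)          # stands in for threadgroup memory
    stride = len(shared) // 2
    while stride > 0:
        for lid in range(stride):  # threads with lid < stride are active
            shared[lid] += shared[lid + stride]
        stride //= 2               # a threadgroup_barrier sits here on the GPU
    return shared[0]               # what "thread 0" writes out

total = tree_reduce([1.0] * 256)
```

The active-thread count halves each round, so 256 values reduce in 8 rounds rather than 255 serial additions.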
When to Use Threadgroup Memory vs SIMD Intrinsics
+---------------------------------------+---------------------------------------+
| SIMD Intrinsics | Threadgroup Memory |
+---------------------------------------+---------------------------------------+
| 32 threads only (one SIMD group) | Up to 1024 threads (one threadgroup) |
| No barriers needed (lockstep) | Barriers required |
| Fastest (register-to-register) | Fast (on-chip SRAM) |
| Simple patterns (reduce, broadcast) | Arbitrary access patterns |
| Use FIRST if possible | Use when SIMD is not enough |
+---------------------------------------+---------------------------------------+
The rule of thumb: if you can do it with SIMD intrinsics alone, do that. Fall back to threadgroup memory only when you need cross-SIMD-group communication.
Example: RMSNorm Kernel
Let us write a real kernel used in transformer inference. RMSNorm (Root Mean Square Normalization) is used in LLaMA and many modern models:
RMSNorm(x) = x * weight / sqrt(mean(x^2) + epsilon)
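Before reading the kernel, here is the formula as a small Python reference, useful for validating the GPU output. With unit weights and `eps = 0`, the output row always has an RMS of exactly 1.

```python
import math

def rmsnorm(x, weight, eps=1e-5):
    mean_sq = sum(v * v for v in x) / len(x)
    scale = 1.0 / math.sqrt(mean_sq + eps)  # the kernel computes this via rsqrt
    return [v * scale * w for v, w in zip(x, weight)]

out = rmsnorm([1.0, 2.0, 3.0, 4.0], [1.0] * 4, eps=0.0)
rms_out = math.sqrt(sum(v * v for v in out) / len(out))
```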
Here is the full kernel:
#include <metal_stdlib>
using namespace metal;
constant uint N [[function_constant(0)]]; // hidden dimension
kernel void rmsnorm(
    device const half* input [[buffer(0)]],
    device const half* weight [[buffer(1)]],
    device half* output [[buffer(2)]],
    constant float& epsilon [[buffer(3)]],
    uint gid [[threadgroup_position_in_grid]],
    uint lid [[thread_position_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]]
) {
    // Each threadgroup processes one row of N elements
    device const half* row_in = input + gid * N;
    device half* row_out = output + gid * N;

    // Step 1: Compute sum of squares (each thread handles N/tptg elements)
    uint threads = 256; // threadgroup size (must match the dispatch)
    float sum_sq = 0.0f;
    for (uint i = lid; i < N; i += threads) {
        float val = float(row_in[i]);
        sum_sq += val * val;
    }

    // Step 2: Reduce within SIMD group
    sum_sq = simd_sum(sum_sq);

    // Step 3: Reduce across SIMD groups using threadgroup memory
    threadgroup float simd_sums[8]; // max 8 SIMD groups (256/32)
    if (simd_lane == 0) {
        simd_sums[simd_id] = sum_sq;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // First SIMD group reduces the partial sums
    float total = 0.0f;
    if (simd_id == 0 && simd_lane < (threads / 32)) {
        total = simd_sum(simd_sums[simd_lane]);
    }

    // Broadcast the RMS scale factor
    threadgroup float rms_scale;
    if (lid == 0) {
        rms_scale = rsqrt(total / float(N) + epsilon);
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Step 4: Apply normalization
    float scale = rms_scale;
    for (uint i = lid; i < N; i += threads) {
        float val = float(row_in[i]);
        row_out[i] = half(val * scale * float(weight[i]));
    }
}
This kernel demonstrates several important MSL patterns:
- **Function constants** (`N`) for compile-time specialization.
- **Mixed precision**: Load as `half`, compute in `float`, store as `half`.
- **Two-stage reduction**: First within SIMD groups (`simd_sum`), then across SIMD groups via threadgroup memory.
- **Stride loop pattern**: Each thread processes multiple elements (`for (uint i = lid; i < N; i += threads)`).
Example: Vectorized Element-Wise Operations
For simple element-wise operations (add, multiply, activation functions), vectorized loads are the key optimization:
kernel void silu_activation(
    device const half4* input [[buffer(0)]],
    device half4* output [[buffer(1)]],
    uint id [[thread_position_in_grid]]
) {
    half4 x = input[id];

    // SiLU(x) = x * sigmoid(x) = x / (1 + exp(-x))
    half4 sigmoid_x = half4(1.0h) / (half4(1.0h) + exp(-x));
    output[id] = x * sigmoid_x;
}
By processing half4 (4 elements per thread), each thread does useful work on 8 bytes of data per load. With a threadgroup size of 256, one threadgroup processes 1024 elements. This keeps the memory system saturated.
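For checking the kernel's output, here is a scalar Python reference of the same formula (the kernel just applies it to four elements at once):

```python
import math

def silu(x):
    """SiLU(x) = x * sigmoid(x) = x / (1 + exp(-x))."""
    return x / (1.0 + math.exp(-x))

# SiLU passes large positive values through almost unchanged and
# squashes large negative ones toward zero:
samples = [silu(-5.0), silu(0.0), silu(5.0)]
```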
Example: Fused RoPE (Rotary Position Embeddings)
RoPE is used in virtually every modern LLM. It applies a rotation to pairs of dimensions based on their position in the sequence:
constant uint HEAD_DIM [[function_constant(0)]];
constant float ROPE_BASE [[function_constant(1)]];
kernel void rope(
    device half* q [[buffer(0)]],          // query tensor
    device half* k [[buffer(1)]],          // key tensor
    constant uint& seq_pos [[buffer(2)]],  // position in sequence
    uint tid [[thread_position_in_grid]]
) {
    // Each thread processes one pair of dimensions
    uint pair_idx = tid; // Which (cos, sin) pair
    uint d = pair_idx % (HEAD_DIM / 2);
    uint head = pair_idx / (HEAD_DIM / 2);

    // Compute rotation angle: theta = pos * base^(-2d/dim)
    float freq = 1.0f / pow(ROPE_BASE, float(2 * d) / float(HEAD_DIM));
    float angle = float(seq_pos) * freq;
    float cos_angle = cos(angle);
    float sin_angle = sin(angle);

    // Apply rotation to Q
    uint base_idx = head * HEAD_DIM;
    float q0 = float(q[base_idx + d]);
    float q1 = float(q[base_idx + d + HEAD_DIM / 2]);
    q[base_idx + d] = half(q0 * cos_angle - q1 * sin_angle);
    q[base_idx + d + HEAD_DIM / 2] = half(q0 * sin_angle + q1 * cos_angle);

    // Apply same rotation to K
    float k0 = float(k[base_idx + d]);
    float k1 = float(k[base_idx + d + HEAD_DIM / 2]);
    k[base_idx + d] = half(k0 * cos_angle - k1 * sin_angle);
    k[base_idx + d + HEAD_DIM / 2] = half(k0 * sin_angle + k1 * cos_angle);
}
Notice the function constants: HEAD_DIM and ROPE_BASE. These are fixed per model (e.g., 128 and 10000.0 for LLaMA), so the compiler can optimize the frequency computation and potentially pre-compute parts of it.
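Two properties of RoPE are worth verifying with a host-side reference: position 0 is the identity, and the rotation never changes a pair's length – only its angle. Here is a Python sketch of the per-thread math:

```python
import math

def rope_pair(x0, x1, pos, d, head_dim, base=10000.0):
    """Rotate one (x0, x1) dimension pair, as each kernel thread does."""
    freq = 1.0 / (base ** (2.0 * d / head_dim))
    angle = pos * freq
    c, s = math.cos(angle), math.sin(angle)
    return (x0 * c - x1 * s, x0 * s + x1 * c)

identity = rope_pair(3.0, 4.0, pos=0, d=5, head_dim=128)  # unchanged
rotated = rope_pair(3.0, 4.0, pos=7, d=5, head_dim=128)   # same length, new angle
```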
Built-in Math Functions
MSL provides a comprehensive set of math functions. Here are the ones you will use most often in ML kernels:
// Exponential and logarithm
float e = exp(x);            // e^x
float l = log(x);            // natural log
half e_fast = exp(h);        // FP16 exp -- 2x throughput

// Trigonometric
float s = sin(x);
float c = cos(x);

// Power and roots
float p = pow(base, exponent);
float sq = sqrt(x);
float r = rsqrt(x);          // 1/sqrt(x) -- faster than 1.0/sqrt(x)

// Min, max, clamp
float m = min(a, b);
float M = max(a, b);
float cl = clamp(x, lo, hi); // max(lo, min(x, hi))

// Absolute value
float av = abs(x);

// Fused multiply-add (one rounding instead of two)
float f = fma(a, b, c);      // a*b + c

// Dot product of vectors
float d = dot(v1, v2);       // v1.x*v2.x + v1.y*v2.y + ...

// Type conversion (explicit)
half h16 = half(f);          // float -> half (may lose precision)
float f32 = float(h16);      // half -> float (lossless)
The rsqrt function deserves special mention. You will see it everywhere in ML kernels (normalization layers, attention scaling). It computes 1/sqrt(x) in a single instruction, which is faster than computing sqrt(x) and then dividing.
Structs and Parameter Passing
For kernels with many parameters, pack them into a struct:
struct GEMMParams {
    uint M;      // rows of A, rows of C
    uint N;      // cols of B, cols of C
    uint K;      // cols of A, rows of B
    float alpha; // scaling factor
};

kernel void gemm(
    device const half* A [[buffer(0)]],
    device const half* B [[buffer(1)]],
    device half* C [[buffer(2)]],
    constant GEMMParams& params [[buffer(3)]],
    uint2 gid [[threadgroup_position_in_grid]],
    uint lid [[thread_position_in_threadgroup]]
) {
    uint M = params.M;
    uint N = params.N;
    uint K = params.K;
    // ...
}
On the CPU side, this struct gets passed via setBytes() (for structs under 4KB) or via a MTLBuffer. The struct layout must match between MSL and the host language. This means you need to be careful about alignment – MSL follows standard C++ alignment rules, and your Rust/Swift struct must match.
// Rust side -- must match the MSL struct layout
#[repr(C)]
struct GEMMParams {
    m: u32,
    n: u32,
    k: u32,
    alpha: f32,
}
The #[repr(C)] attribute ensures Rust uses C-compatible layout, which matches MSL’s struct layout.
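The Python `struct` module makes a quick cross-check for layouts like this: the format `'<IIIf'` (three `u32` fields then an `f32`, little-endian) describes the same 16 bytes the GEMMParams struct occupies, with no padding between fields. The values below are arbitrary examples.

```python
import struct

FMT = '<IIIf'                # M, N, K as u32, then alpha as f32
size = struct.calcsize(FMT)  # total struct size in bytes

# These are the exact bytes a setBytes() call would hand to the GPU:
blob = struct.pack(FMT, 1024, 4096, 512, 1.0)
fields = struct.unpack(FMT, blob)
```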
Conditional Compilation with Function Constants
We touched on function constants in the previous chapter. Here is a deeper look at how they enable conditional compilation in MSL:
constant bool HAS_BIAS [[function_constant(0)]];
constant bool HAS_RESIDUAL [[function_constant(1)]];
constant uint BLOCK_SIZE [[function_constant(2)]];

kernel void linear_layer(
    device const half* input [[buffer(0)]],
    device const half* weights [[buffer(1)]],
    device half* output [[buffer(2)]],
    device const half* bias [[buffer(3)]],
    device const half* residual [[buffer(4)]],
    constant uint& N [[buffer(5)]],
    uint tid [[thread_position_in_grid]]
) {
    half result = /* ... compute matmul ... */;

    // These branches are resolved at compile time!
    // No runtime branch penalty.
    if (HAS_BIAS) {
        result += bias[tid];
    }
    if (HAS_RESIDUAL) {
        result += residual[tid];
    }
    output[tid] = result;
}
When HAS_BIAS is false, the compiler completely removes the bias addition and the bias buffer binding. The generated code is identical to a kernel that never had bias support in the first place. This lets you write one kernel source that generates many specialized variants.
Atomic Operations
Sometimes multiple threads need to update the same memory location. MSL provides atomic operations for this:
kernel void histogram(
    device const uint* data [[buffer(0)]],
    device atomic_uint* bins [[buffer(1)]],
    uint tid [[thread_position_in_grid]]
) {
    uint value = data[tid];
    uint bin = value % 256;
    atomic_fetch_add_explicit(&bins[bin], 1, memory_order_relaxed);
}
Atomics are slow – they serialize access. Avoid them in hot paths. If you find yourself using atomics in a performance-critical kernel, there is almost certainly a better algorithm (like per-threadgroup histograms followed by a merge).
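The per-threadgroup-then-merge idea looks like this in serial Python (an algorithm sketch, not Metal code): each group fills a private histogram with no contention, and a single merge pass at the end replaces all of the contended atomic increments.

```python
def histogram_merged(data, num_bins=256, group_size=64):
    chunks = [data[i:i + group_size] for i in range(0, len(data), group_size)]
    partials = []
    for chunk in chunks:              # one private histogram per "threadgroup"
        h = [0] * num_bins
        for value in chunk:
            h[value % num_bins] += 1  # no atomics needed: h is private
        partials.append(h)
    # One merge pass instead of len(data) contended atomic adds
    return [sum(p[b] for p in partials) for b in range(num_bins)]

bins = histogram_merged([i % 10 for i in range(1000)])
```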
Putting It All Together: A Complete Softmax Kernel
Let us write a production-quality softmax kernel that handles arbitrary row sizes using all the MSL features we have covered:
#include <metal_stdlib>
using namespace metal;
constant uint COLS [[function_constant(0)]];
kernel void softmax(
    device const half* input [[buffer(0)]],
    device half* output [[buffer(1)]],
    uint gid [[threadgroup_position_in_grid]],
    uint lid [[thread_position_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]]
) {
    // Each threadgroup processes one row
    uint row = gid;
    uint threads = 256; // threadgroup size

    // Threadgroup memory for cross-SIMD-group reduction
    threadgroup float simd_max_vals[8];
    threadgroup float simd_sum_vals[8];

    // ---- Pass 1: Find max ----
    float local_max = -INFINITY;
    for (uint i = lid; i < COLS; i += threads) {
        local_max = max(local_max, float(input[row * COLS + i]));
    }

    // Reduce within SIMD group
    local_max = simd_max(local_max);

    // Reduce across SIMD groups
    if (simd_lane == 0) {
        simd_max_vals[simd_id] = local_max;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    float row_max;
    if (simd_id == 0) {
        float v = (simd_lane < threads / 32) ? simd_max_vals[simd_lane]
                                             : -INFINITY;
        row_max = simd_max(v);
    }

    // Broadcast max to all threads
    threadgroup float shared_max;
    if (lid == 0) shared_max = row_max;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    row_max = shared_max;

    // ---- Pass 2: Compute exp and sum ----
    float local_sum = 0.0f;
    for (uint i = lid; i < COLS; i += threads) {
        float val = exp(float(input[row * COLS + i]) - row_max);
        local_sum += val;
    }

    // Reduce sum (same two-stage pattern)
    local_sum = simd_sum(local_sum);
    if (simd_lane == 0) {
        simd_sum_vals[simd_id] = local_sum;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    float row_sum;
    if (simd_id == 0) {
        float v = (simd_lane < threads / 32) ? simd_sum_vals[simd_lane]
                                             : 0.0f;
        row_sum = simd_sum(v);
    }
    threadgroup float shared_sum;
    if (lid == 0) shared_sum = row_sum;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    row_sum = shared_sum;

    // ---- Pass 3: Normalize and write output ----
    float inv_sum = 1.0f / row_sum;
    for (uint i = lid; i < COLS; i += threads) {
        float val = exp(float(input[row * COLS + i]) - row_max);
        output[row * COLS + i] = half(val * inv_sum);
    }
}
This kernel has three passes over the data (find max, compute exponentials and sum, normalize). Each pass uses the two-stage reduction pattern (SIMD-level reduction followed by threadgroup-level reduction). The function constant COLS lets the compiler optimize the loop bounds.
Summary
MSL is a practical language for writing GPU kernels. Here are the key takeaways:
- MSL is C++14 with GPU extensions. No surprises if you know C++.
- Four address spaces: `device` (main memory), `constant` (cached read-only), `threadgroup` (fast shared SRAM), `thread` (registers).
- Vector types (`half4`, `float4`) are essential for memory throughput.
- The `half` type gives you 2x throughput and 2x bandwidth – use it everywhere in ML kernels.
- SIMD intrinsics (`simd_sum`, `simd_max`, `simd_shuffle`, etc.) enable fast intra-SIMD-group communication without shared memory or barriers.
- Threadgroup memory + barriers enable cross-SIMD-group communication within a threadgroup.
- Function constants create specialized kernel variants at compile time.
- The two-stage reduction (SIMD reduce, then threadgroup reduce) is the most common pattern in ML kernels.
In the next chapter, we will zoom out from individual threads to look at how threadgroups and SIMD groups are organized and dispatched.