Normalization Kernels

Every transformer layer starts with a normalization step. Before attention, before the FFN, before practically anything interesting happens, you need to normalize your hidden states. Without it, activations drift, gradients explode, and your model produces gibberish. The math is simple. The GPU implementation? That is where things get interesting.

In this chapter we will walk through akunu’s Metal normalization kernels: RMSNorm, LayerNorm, fused Residual+RMSNorm, per-head RMSNorm, and the Gemma variants. We will trace every threadgroup reduction, every rsqrt call, and every clamping trick that keeps FP16 from blowing up.

The Normalization Zoo

Modern LLMs use two main normalization flavors:

LayerNorm (GPT-2, Whisper):    y[i] = ((x[i] - mean) / sqrt(var + eps)) * weight[i] + bias[i]

RMSNorm  (LLaMA, Qwen, etc):  y[i] = (x[i] / sqrt(mean(x^2) + eps)) * weight[i]

RMSNorm is cheaper: it skips the mean subtraction entirely. You only need the root mean square (RMS) of the input. No centering, no bias. That saves one reduction pass over the data, and in practice model quality is comparable, which is why most recent LLMs (LLaMA, Qwen, Mistral and others) use it.

Let us see how akunu implements both.

RMSNorm: One Threadgroup per Row

The dispatch model is beautifully simple: one threadgroup handles one row of the input tensor. If you have a batch of 8 tokens with dimension 4096, you launch 8 threadgroups. Each threadgroup has enough threads to cover the dimension with striding.

Dispatch Grid:
  threadgroups = (num_rows, 1, 1)
  threads_per_threadgroup = (tg_size, 1, 1)    // e.g., 256 or 512

  Row 0: TG 0   -->  threads 0..tg_size-1 stride over dim elements
  Row 1: TG 1   -->  threads 0..tg_size-1 stride over dim elements
  ...
  Row N: TG N   -->  threads 0..tg_size-1 stride over dim elements

Here is the actual kernel from rmsnorm.metal:

kernel void rmsnorm_f16(
    device const half       *input   [[buffer(0)]],
    device const half       *weight  [[buffer(1)]],
    device half             *output  [[buffer(2)]],
    constant RMSNormParams  &params  [[buffer(3)]],
    uint3 tgid_v  [[threadgroup_position_in_grid]],
    uint3 tid_v   [[thread_position_in_threadgroup]],
    uint  sgid    [[simdgroup_index_in_threadgroup]],
    uint  slid    [[thread_index_in_simdgroup]],
    uint3 tpg     [[threads_per_threadgroup]]
) {
    const uint dim = params.dim;
    const float eps = params.eps;
    const uint row = tgid_v.x;
    const uint tid = tid_v.x;
    const uint tg_size = tpg.x;

    device const half *row_in  = input  + row * dim;
    device half       *row_out = output + row * dim;

    threadgroup float shared[32];

    float local_sum_sq = 0.0f;
    for (uint i = tid; i < dim; i += tg_size) {
        float v = float(row_in[i]);
        local_sum_sq += v * v;
    }

    float total_sum_sq = tg_reduce_sum(local_sum_sq, sgid, slid, tg_size, shared);
    float rms = rsqrt(total_sum_sq / float(dim) + eps);

    for (uint i = tid; i < dim; i += tg_size) {
        row_out[i] = half(float(row_in[i]) * rms) * weight[i];
    }
}

Let us break this down step by step.

Step 1: Accumulate Sum of Squares

Each thread strides over the row, accumulating x[i]^2 into a local float register. For a 4096-dimensional row with 256 threads, each thread processes 16 elements:

Thread 0:   x[0]^2 + x[256]^2 + x[512]^2 + ... + x[3840]^2
Thread 1:   x[1]^2 + x[257]^2 + x[513]^2 + ... + x[3841]^2
Thread 2:   x[2]^2 + x[258]^2 + x[514]^2 + ... + x[3842]^2
...
Thread 255: x[255]^2 + x[511]^2 + x[767]^2 + ... + x[4095]^2

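Notice the promotion to float: float v = float(row_in[i]). The input is FP16, but all accumulation happens in FP32. This is critical. FP16 tops out at 65504, so for dim = 4096 a sum of squares accumulated in FP16 overflows as soon as the row's RMS exceeds about 4 (4096 * 4^2 = 65536). Long FP16 sums also lose precision as the running total grows.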
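A CPU emulation of this strided loop makes the numbers easy to check. The sketch below is C++ standing in for the Metal code, and the helper names (strided_partials, elems_per_thread) are hypothetical, not akunu functions. It reproduces the per-thread accumulation pattern and checks the case where every element is 4.0.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

constexpr float kF16Max = 65504.0f;  // largest finite FP16 value

// Emulates Step 1 on the CPU: thread `tid` of a `tg_size`-thread group
// accumulates x[tid]^2 + x[tid + tg_size]^2 + ... into a float "register".
std::vector<float> strided_partials(const std::vector<float>& row, std::size_t tg_size) {
    std::vector<float> partial(tg_size, 0.0f);
    for (std::size_t tid = 0; tid < tg_size; ++tid)
        for (std::size_t i = tid; i < row.size(); i += tg_size)
            partial[tid] += row[i] * row[i];
    return partial;
}

// Number of elements each thread visits: ceil(dim / tg_size).
std::size_t elems_per_thread(std::size_t dim, std::size_t tg_size) {
    return (dim + tg_size - 1) / tg_size;
}
```

With dim = 4096 and 256 threads, each thread visits 16 elements and contributes 16 * 16 = 256; the 256 partials add up to 65536, which a float holds exactly but which is already past kF16Max.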

Step 2: Two-Stage Threadgroup Reduction

This is where the magic happens. We need to sum 256 (or however many) partial sums into a single total. The kernel calls tg_reduce_sum, which lives in KernelCommon.h:

inline float tg_reduce_sum(float val, uint sgid, uint slid,
                            uint tg_size, threadgroup float *shared) {
    float simd_val = simd_sum(val);
    uint n_sg = (tg_size + SIMD_WIDTH - 1) / SIMD_WIDTH;
    if (slid == 0)
        shared[sgid] = simd_val;
    threadgroup_barrier(mem_flags::mem_threadgroup);

    if (sgid == 0) {
        float v = (slid < n_sg) ? shared[slid] : 0.0f;
        float total = simd_sum(v);
        if (slid == 0)
            shared[0] = total;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    return shared[0];
}

This is a classic two-stage SIMD reduction. Let me draw it out for a threadgroup of 256 threads (8 SIMD groups of 32 threads each):

Stage 1: SIMD-level reduction (hardware shuffle, no barriers needed)
+------------------------------------------------------------------+
| SIMD Group 0:  t0..t31  --simd_sum-->  partial_0   (in t0)      |
| SIMD Group 1:  t32..t63 --simd_sum-->  partial_1   (in t32)     |
| SIMD Group 2:  t64..t95 --simd_sum-->  partial_2   (in t64)     |
| ...                                                               |
| SIMD Group 7: t224..t255 --simd_sum-->  partial_7  (in t224)    |
+------------------------------------------------------------------+
                              |
                     lane 0 of each SG writes to shared[]
                              |
                    [threadgroup_barrier]
                              |
                              v
Stage 2: Cross-SIMD reduction (only SIMD Group 0 participates)
+------------------------------------------------------------------+
| SG 0, lane 0: reads shared[0] = partial_0                       |
| SG 0, lane 1: reads shared[1] = partial_1                       |
| SG 0, lane 2: reads shared[2] = partial_2                       |
| ...                                                               |
| SG 0, lane 7: reads shared[7] = partial_7                       |
| SG 0, lanes 8-31: use 0.0f (padding)                            |
|                                                                   |
| --simd_sum-->  TOTAL   (written to shared[0] by lane 0)         |
+------------------------------------------------------------------+
                              |
                    [threadgroup_barrier]
                              |
                              v
All threads read shared[0] = total_sum_sq

Why only 32 floats of shared memory? Because Apple GPUs have 32 threads per SIMD group, and you can have at most 32 SIMD groups in a threadgroup (32 * 32 = 1024 threads max). So shared[32] is always enough for the cross-SIMD exchange.

The beauty here is that simd_sum compiles to hardware shuffle instructions. No shared memory access, no barriers. It is a register-to-register operation within the SIMD group. The only shared memory access is the handoff between stage 1 and stage 2, requiring just two barriers for the entire reduction.
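The two stages can be modeled sequentially on the CPU. In the sketch below, two_stage_reduce is a hypothetical name and C++ stands in for Metal. It loops over SIMD groups where the GPU runs them in parallel, but it follows the same data movement: per-group sums land in a 32-entry shared array, and group 0 folds them, padding unused lanes with zero.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

constexpr std::size_t kSimdWidth = 32;  // threads per SIMD group on Apple GPUs

// CPU model of the two-stage reduction in tg_reduce_sum (a sketch, not the
// Metal code). `vals` holds one value per thread; its size plays tg_size.
float two_stage_reduce(const std::vector<float>& vals) {
    const std::size_t tg_size = vals.size();
    const std::size_t n_sg = (tg_size + kSimdWidth - 1) / kSimdWidth;
    assert(n_sg <= 32 && "at most 1024 threads per threadgroup");

    // Stage 1: each SIMD group sums its lanes (simd_sum), and lane 0 stores
    // the group's result in shared[sgid].
    float shared[32] = {};
    for (std::size_t sg = 0; sg < n_sg; ++sg) {
        float simd_val = 0.0f;
        for (std::size_t lane = 0; lane < kSimdWidth; ++lane) {
            const std::size_t t = sg * kSimdWidth + lane;
            if (t < tg_size) simd_val += vals[t];
        }
        shared[sg] = simd_val;
    }
    // (threadgroup_barrier here on the GPU)

    // Stage 2: lane l of SIMD group 0 loads shared[l] if l < n_sg, else 0.0f,
    // and one more simd_sum yields the total.
    float total = 0.0f;
    for (std::size_t lane = 0; lane < kSimdWidth; ++lane)
        total += (lane < n_sg) ? shared[lane] : 0.0f;
    return total;
}
```

The sketch also handles threadgroup sizes that are not a multiple of 32, which is why Stage 1 skips lanes past tg_size.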

Step 3: Compute the Inverse RMS

float rms = rsqrt(total_sum_sq / float(dim) + eps);

This single line does three things:

  1. Divides the sum of squares by the dimension to get the mean
  2. Adds epsilon (typically 1e-5 or 1e-6) for numerical stability
  3. Calls rsqrt – the reciprocal square root

The rsqrt call is typically compiled to a dedicated reciprocal-square-root instruction on Apple GPUs, which is cheaper than computing sqrt(x) and then dividing 1.0f by the result. This is why the kernel computes x * rsqrt(mean_sq + eps) rather than x / sqrt(mean_sq + eps): the reciprocal is computed once per row, and each element then needs only a multiply, which is cheaper than a divide.

Step 4: Normalize and Scale

for (uint i = tid; i < dim; i += tg_size) {
    row_out[i] = half(float(row_in[i]) * rms) * weight[i];
}

Each thread walks the row again, multiplying each element by the inverse RMS and then by the learned weight. The float(row_in[i]) * rms product is computed in FP32, then cast down to FP16 for the multiply with weight[i] (which is already FP16), so the weight multiply itself happens in half precision. The two phases cannot be merged: no output can be written until the RMS of the whole row is known, so this kernel reads the input row twice.
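When validating a GPU kernel, it helps to have a CPU reference with the same shape. The following C++ sketch uses a hypothetical function, rmsnorm_rows_ref. It walks rows the way the dispatch grid assigns threadgroups, then performs the reduction pass, the inverse-RMS step, and the scaling pass. It stays in float and does not emulate the kernel's FP16 rounding, so comparisons against GPU output need a tolerance.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// CPU sketch mirroring rmsnorm_f16's structure in float (no FP16 rounding).
// `input` is a row-major [num_rows, dim] buffer; each row is handled the way
// one threadgroup handles it: pass 1 sums squares, pass 2 scales.
void rmsnorm_rows_ref(const std::vector<float>& input, const std::vector<float>& weight,
                      std::vector<float>& output, std::size_t dim, float eps) {
    const std::size_t num_rows = input.size() / dim;
    output.resize(input.size());
    for (std::size_t row = 0; row < num_rows; ++row) {   // one "threadgroup" per row
        const float* in = input.data() + row * dim;
        float* out = output.data() + row * dim;

        float sum_sq = 0.0f;                              // pass 1: sum of squares
        for (std::size_t i = 0; i < dim; ++i) sum_sq += in[i] * in[i];

        // Inverse RMS, so that pass 2 multiplies instead of dividing.
        const float inv_rms = 1.0f / std::sqrt(sum_sq / static_cast<float>(dim) + eps);

        for (std::size_t i = 0; i < dim; ++i)             // pass 2: normalize and scale
            out[i] = in[i] * inv_rms * weight[i];
    }
}
```

With unit weights and eps = 0, every output row has a mean square of exactly 1 (up to rounding), which makes a convenient sanity check.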

LayerNorm: Two Reductions Instead of One

LayerNorm is the older, more expensive cousin. Used in GPT-2, Whisper, and other pre-LLaMA architectures, it subtracts the mean before dividing by standard deviation:

y[i] = ((x[i] - mean) / sqrt(var + eps)) * weight[i] + bias[i]

The extra mean subtraction means an extra reduction pass. Here is the kernel:

kernel void layernorm_f16(
    device const half        *input   [[buffer(0)]],
    device const half        *weight  [[buffer(1)]],
    device const half        *bias    [[buffer(2)]],
    device half              *output  [[buffer(3)]],
    constant LayerNormParams &params  [[buffer(4)]],
    ...
) {
    // Pass 1: compute mean
    float local_sum = 0.0f;
    for (uint i = tid; i < dim; i += tg_size)
        local_sum += float(row_in[i]);

    float mean = tg_reduce_sum(local_sum, sgid, slid, tg_size, shared)
                 / float(dim);

    // Pass 2: compute variance
    float local_var = 0.0f;
    for (uint i = tid; i < dim; i += tg_size) {
        float diff = float(row_in[i]) - mean;
        local_var += diff * diff;
    }

    float variance = tg_reduce_sum(local_var, sgid, slid, tg_size, shared)
                     / float(dim);
    float inv_std = rsqrt(variance + eps);

    // Pass 3: normalize, scale, shift
    for (uint i = tid; i < dim; i += tg_size) {
        float val = (float(row_in[i]) - mean) * inv_std;
        row_out[i] = half(val) * weight[i] + bias[i];
    }
}

The data flow looks like this:

Pass 1: Compute mean
  row_in --> [sum elements] --> [tg_reduce_sum] --> mean = sum / dim

Pass 2: Compute variance  
  row_in --> [(x - mean)^2] --> [tg_reduce_sum] --> var = sum / dim
                                                     inv_std = rsqrt(var + eps)

Pass 3: Output
  row_in --> [(x - mean) * inv_std * weight + bias] --> row_out

That is three passes over the input versus two for RMSNorm. For a 4096-dim model, the input reads alone come to 3 * 4096 * 2 = 24,576 bytes per row instead of 2 * 4096 * 2 = 16,384 bytes. The extra pass (and the extra tg_reduce_sum) is a large part of why RMSNorm displaced LayerNorm in modern architectures.

Also note the + bias[i] at the end. LayerNorm has both learnable scale (weight) and shift (bias) parameters. RMSNorm drops the bias entirely.
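The three passes translate directly into a CPU reference. The C++ sketch below uses a hypothetical layernorm_ref helper and works in float throughout. With unit weight and zero bias, its output has mean 0 and mean square 1, which distinguishes it from RMSNorm on any input whose mean is not already zero.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// CPU sketch of the three passes in layernorm_f16, in float throughout.
std::vector<float> layernorm_ref(const std::vector<float>& x, const std::vector<float>& w,
                                 const std::vector<float>& b, float eps) {
    const float n = static_cast<float>(x.size());

    float sum = 0.0f;                                   // pass 1: mean
    for (float v : x) sum += v;
    const float mean = sum / n;

    float sq = 0.0f;                                    // pass 2: variance
    for (float v : x) sq += (v - mean) * (v - mean);
    const float inv_std = 1.0f / std::sqrt(sq / n + eps);

    std::vector<float> y(x.size());                     // pass 3: normalize, scale, shift
    for (std::size_t i = 0; i < x.size(); ++i)
        y[i] = (x[i] - mean) * inv_std * w[i] + b[i];
    return y;
}
```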

Fused Residual + RMSNorm: The Performance Killer Feature

Here is where akunu gets clever. In a transformer, the pattern looks like this:

hidden = attention_output + residual       // residual add
normalized = RMSNorm(hidden, weight)       // normalize for FFN

The naive approach dispatches two kernels: one for the addition, one for the normalization. But look at what that means in terms of memory:

Naive (2 kernels):
  Kernel 1 (residual_add):  read a[], read b[]  --> write hidden[]
  Kernel 2 (rmsnorm):       read hidden[]        --> write norm[]
  
  Total memory traffic: 4 * dim * sizeof(half) reads + 2 * dim * sizeof(half) writes

Fused (1 kernel):
  read a[], read b[] --> compute hidden, accumulate sum_sq, write hidden[]
  read hidden[]      --> write norm[]
  
  Total memory traffic: 3 * dim * sizeof(half) reads + 2 * dim * sizeof(half) writes
  (one full read of hidden[], 2 * dim bytes, fewer than the naive version)
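The byte counts above can be written as a tiny model. This C++ sketch is a simplification: the Traffic struct and both functions are hypothetical. It assumes 2-byte FP16 elements, ignores the weight reads, and assumes the standalone RMSNorm reads its input twice (once for the reduction and once for the output pass).

```cpp
#include <cassert>
#include <cstddef>

// Per-row global-memory traffic in bytes, under the assumptions stated above.
struct Traffic { std::size_t read_bytes, write_bytes; };

constexpr std::size_t kHalfBytes = 2;

Traffic naive_residual_then_rmsnorm(std::size_t dim) {
    // residual_add: read a, b; write hidden.  rmsnorm: read hidden twice; write norm.
    return {(2 + 2) * dim * kHalfBytes, (1 + 1) * dim * kHalfBytes};
}

Traffic fused_residual_rmsnorm(std::size_t dim) {
    // pass 1: read a, b; write hidden.  pass 2: read hidden once; write norm.
    return {3 * dim * kHalfBytes, 2 * dim * kHalfBytes};
}
```

For dim = 4096, the naive pair reads 32,768 bytes per row and the fused kernel reads 24,576. Both write 16,384 bytes, so the fused kernel saves exactly one 8 KB read of hidden[].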

But the real win is not just bandwidth. It is kernel launch overhead. Each Metal compute dispatch has a non-trivial fixed cost – command buffer encoding, GPU scheduling, pipeline state switches. Eliminating a dispatch is free performance.

Here is the fused kernel:

kernel void residual_rmsnorm_f16(
    device const half       *a         [[buffer(0)]],
    device const half       *b         [[buffer(1)]],
    device const half       *weight    [[buffer(2)]],
    device half             *res_out   [[buffer(3)]],
    device half             *norm_out  [[buffer(4)]],
    constant RMSNormParams  &params    [[buffer(5)]],
    ...
) {
    constexpr float F16_MAX = 65504.0f;
    float local_sum_sq = 0.0f;
    for (uint i = tid; i < dim; i += tg_size) {
        float val = clamp(float(row_a[i]) + float(row_b[i]), -F16_MAX, F16_MAX);
        row_res[i] = half(val);
        local_sum_sq += val * val;
    }

    float total_sum_sq = tg_reduce_sum(local_sum_sq, sgid, slid, tg_size, shared);
    float rms = rsqrt(total_sum_sq / float(dim) + eps);

    for (uint i = tid; i < dim; i += tg_size) {
        row_norm[i] = half(float(row_res[i]) * rms) * weight[i];
    }
}
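For testing, the fused kernel's two outputs can be reproduced on the CPU. The sketch below is C++ with a hypothetical residual_rmsnorm_ref function. It works in float and does not emulate the FP16 stores. Pass 1 forms the clamped residual sum, stores it, and accumulates squares. Pass 2 re-reads the stored sum, as the kernel does with row_res, and writes the normalized output.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

constexpr float kF16Max = 65504.0f;  // largest finite FP16 value

// CPU sketch of residual_rmsnorm_f16 for one row (float, FP16 stores not emulated).
void residual_rmsnorm_ref(const std::vector<float>& a, const std::vector<float>& b,
                          const std::vector<float>& weight, float eps,
                          std::vector<float>& res_out, std::vector<float>& norm_out) {
    const std::size_t dim = a.size();
    res_out.assign(dim, 0.0f);
    norm_out.assign(dim, 0.0f);

    float sum_sq = 0.0f;
    for (std::size_t i = 0; i < dim; ++i) {
        // Clamp to the FP16 range so the stored residual stays finite.
        const float val = std::clamp(a[i] + b[i], -kF16Max, kF16Max);
        res_out[i] = val;                 // output 1: residual for the next layer
        sum_sq += val * val;
    }

    const float inv_rms = 1.0f / std::sqrt(sum_sq / static_cast<float>(dim) + eps);
    for (std::size_t i = 0; i < dim; ++i)
        norm_out[i] = res_out[i] * inv_rms * weight[i];  // output 2: normalized
}
```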

The F16_MAX Clamp

Notice this critical line:

float val = clamp(float(row_a[i]) + float(row_b[i]), -F16_MAX, F16_MAX);

FP16 can represent finite values up to 65504.0. If the residual sum exceeds this range (with round-to-nearest, any magnitude of 65520 or more rounds to infinity), the cast back to half produces infinity, which then poisons everything downstream. The clamp prevents this:

Without clamp:
  a[i] = 40000.0h,  b[i] = 30000.0h
  sum = 70000.0f   (fine in float)
  half(70000.0f) = inf   (OVERFLOW -- poisons entire row)

With clamp:
  sum = clamp(70000.0f, -65504.0f, 65504.0f) = 65504.0f
  half(65504.0f) = 65504.0h  (saturated but finite)

This is a practical concern. During inference with long contexts, residual accumulation can push values near the FP16 boundary. The clamp costs essentially nothing (it typically compiles to a min/max pair), but it prevents an infinity from turning into NaNs downstream (for example through inf - inf or inf * 0).

Dual Output Buffers

The fused kernel writes two outputs: res_out (the residual sum, needed for the next residual connection) and norm_out (the normalized result, fed to the FFN or attention). Together with the second input b, this is why it binds five data buffers (plus params) where rmsnorm_f16 binds three.

Per-Head RMSNorm: Qwen3’s QK Normalization

Some models (notably Qwen3) apply RMSNorm independently to each attention head’s Q and K projections. Instead of normalizing each full model_dim-wide row of a [seq_len, model_dim] tensor, you normalize individual head_dim-wide slices.

The input layout is [seq_len, n_heads, head_dim], and each threadgroup handles one (seq_position, head) pair:

Dispatch Grid:
  threadgroups = (n_heads, seq_len, 1)
  threads_per_threadgroup = (min(head_dim, 1024), 1, 1)

  TG(0,0): head 0, pos 0   -->  normalize row[0 * n_heads + 0]
  TG(1,0): head 1, pos 0   -->  normalize row[0 * n_heads + 1]
  ...
  TG(0,1): head 0, pos 1   -->  normalize row[1 * n_heads + 0]
  ...

The reduction is slightly different here. Instead of using the tg_reduce_sum helper, the kernel manually implements the two-stage SIMD reduction:

float sum_sq = 0.0f;
for (uint d = tid; d < head_dim; d += tg_size) {
    float v = float(row[d]);
    sum_sq += v * v;
}

// Stage 1: SIMD-level sum
sum_sq = simd_sum(sum_sq);
if (lane == 0) shared[warp] = sum_sq;
threadgroup_barrier(mem_flags::mem_threadgroup);

// Stage 2: Cross-SIMD sum (only warp 0)
if (warp == 0) {
    float v = (lane < (tg_size + 31) / 32) ? shared[lane] : 0.0f;
    v = simd_sum(v);
    shared[0] = v;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
float total_sq = shared[0];

This is the same algorithm as tg_reduce_sum, just inlined. For head_dim = 128 (common in Qwen3), you only need 128 threads and 4 SIMD groups. The reduction is tiny.

The normalization is in place: the kernel overwrites each head’s slice (row) directly:

for (uint d = tid; d < head_dim; d += tg_size) {
    row[d] = half(float(row[d]) * rms * float(weight[d]));
}

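This is safe because each threadgroup works on a different (head, position) pair, and within a threadgroup each element is read and overwritten only by the thread whose stride lands on it. No data races.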
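A CPU reference makes the [seq_len, n_heads, head_dim] indexing explicit. In this C++ sketch (head_rmsnorm_ref is a hypothetical name), the slice for TG(head, pos) starts at (pos * n_heads + head) * head_dim, matching row[pos * n_heads + head] in the dispatch diagram. As in the kernel excerpt, weight has head_dim entries indexed by d and shared by every head.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// CPU sketch of per-head RMSNorm, normalizing a [seq_len, n_heads, head_dim]
// buffer in place. Each (pos, head) pair plays the role of one threadgroup.
void head_rmsnorm_ref(std::vector<float>& x, const std::vector<float>& weight,
                      std::size_t seq_len, std::size_t n_heads, std::size_t head_dim,
                      float eps) {
    for (std::size_t pos = 0; pos < seq_len; ++pos) {
        for (std::size_t head = 0; head < n_heads; ++head) {
            float* row = x.data() + (pos * n_heads + head) * head_dim;

            float sum_sq = 0.0f;
            for (std::size_t d = 0; d < head_dim; ++d) sum_sq += row[d] * row[d];
            const float inv_rms =
                1.0f / std::sqrt(sum_sq / static_cast<float>(head_dim) + eps);

            // In place: every element is overwritten only after its slice's sum is known.
            for (std::size_t d = 0; d < head_dim; ++d) row[d] = row[d] * inv_rms * weight[d];
        }
    }
}
```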

The Gemma Variants: weight’ = 1 + weight

Google’s Gemma model has an unusual normalization quirk. The norm weights are initialized to zero and the effective scale is (1 + weight) rather than just weight. This means a freshly initialized model has identity-like normalization (scale = 1 everywhere), which improves training stability.

Akunu has dedicated Gemma variants for both RMSNorm and the fused residual version. The only difference is in the final output line:

// Standard RMSNorm:
row_out[i] = half(float(row_in[i]) * rms) * weight[i];

// Gemma RMSNorm:
row_out[i] = half(float(row_in[i]) * rms * (1.0f + float(weight[i])));

Notice the Gemma variant does the entire computation in float32 before casting to FP16. The comment in the source says it all: “Compute entirely in float to avoid F16 overflow in normalized * (1+weight).” Normalized values are not bounded by 1. A row dominated by a single large element normalizes that element to about sqrt(dim), or 64 for dim = 4096, and a large (1 + weight) factor eats further into the headroom below 65504. Keeping the product in float32 means the value is rounded to FP16 only once, at the final store.
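The Gemma output step can be sketched on the CPU as follows. The sketch is C++, and gemma_rmsnorm_ref is a hypothetical name. With all-zero weights the (1 + weight) factor is exactly 1, which gives the identity-like scaling described above. A row with a single nonzero element shows that a normalized value can reach sqrt(dim).

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// CPU sketch of a Gemma-style RMSNorm: scale by (1 + weight) instead of
// weight, keeping the whole product in float.
std::vector<float> gemma_rmsnorm_ref(const std::vector<float>& x,
                                     const std::vector<float>& weight, float eps) {
    float sum_sq = 0.0f;
    for (float v : x) sum_sq += v * v;
    const float inv_rms = 1.0f / std::sqrt(sum_sq / static_cast<float>(x.size()) + eps);

    std::vector<float> y(x.size());
    for (std::size_t i = 0; i < x.size(); ++i)
        y[i] = x[i] * inv_rms * (1.0f + weight[i]);
    return y;
}
```

For a 4096-element row that is zero except for one entry of 1000.0, the RMS is 1000 / 64 = 15.625, so that entry normalizes to 64.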

There is also a Gemma variant of the fused residual+RMSNorm kernel (residual_rmsnorm_gemma_f16) that combines both the F16_MAX clamping and the (1 + weight) scaling. Same dispatch, same reduction, just a different final multiply.

Memory Access Patterns

Let us think about how these kernels interact with the GPU memory hierarchy:

Kernel Memory Access Pattern
+-------------------------------------------------------------------+
|                                                                   |
|  Global Memory (Device)                                           |
|  +-------------------------------------------------------------+ |
|  | input[row * dim + 0..dim-1]    <-- read once (pass 1)       | |
|  | input[row * dim + 0..dim-1]    <-- read again (pass 2)      | |
|  | weight[0..dim-1]              <-- read once (pass 2)         | |
|  | output[row * dim + 0..dim-1]  <-- write once (pass 2)       | |
|  +-------------------------------------------------------------+ |
|                                                                   |
|  Threadgroup Memory (32 KB max)                                   |
|  +-------------------------------------------------------------+ |
|  | shared[32]  <-- 128 bytes for reduction intermediates        | |
|  +-------------------------------------------------------------+ |
|                                                                   |
|  Registers (per thread)                                           |
|  +-------------------------------------------------------------+ |
|  | local_sum_sq (float) -- accumulator                          | |
|  | rms (float)          -- broadcast to all threads via shared  | |
|  +-------------------------------------------------------------+ |
|                                                                   |
+-------------------------------------------------------------------+

The strided access pattern (for i = tid; i < dim; i += tg_size) gives coalesced memory access. On each iteration, thread 0 reads element k, thread 1 reads element k + 1, and so on, so the 32 threads of a SIMD group read 64 contiguous bytes of FP16 data. Assuming 128-byte memory transactions, one transaction covers the data for 64 consecutive threads, which is two SIMD groups’ worth. Excellent utilization.

The weight vector is also read with perfect coalescing. And since weights are the same for every row, they likely stay in the GPU L2 cache after the first row is processed. For a 4096-dim model, the weight vector is only 8 KB – easily cached.

Performance Characteristics

Let us estimate the arithmetic intensity (operations per byte transferred) for RMSNorm on a 4096-dim row with 256 threads:

Pass 1 (sum of squares):
  Reads:  dim * 2 bytes = 8,192 bytes
  FLOPs:  dim * 2 (cast + multiply) + dim (add) = 12,288 FLOPs
  + reduction: ~500 FLOPs (negligible)

Pass 2 (normalize):
  Reads:  dim * 2 (input) + dim * 2 (weight) = 16,384 bytes
  Writes: dim * 2 (output) = 8,192 bytes
  FLOPs:  dim * 3 (cast, multiply by rms, multiply by weight) = 12,288 FLOPs

Total:
  Memory: 32,768 bytes
  FLOPs:  ~24,576
  Arithmetic intensity: 0.75 FLOPs/byte

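This is firmly in the memory-bound regime. Apple M-series GPUs deliver several TFLOPS of FP32 compute (roughly 10 TFLOPS on an M1 Max) against a few hundred GB/s of memory bandwidth (400 GB/s on the same chip). A kernel therefore needs on the order of 25 FLOPs per byte before compute becomes the limit. RMSNorm barely scratches the ALU: the bottleneck is waiting for memory loads.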
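The estimate above is a simple roofline comparison, sketched here in C++. The names rmsnorm_row_cost and memory_bound are hypothetical. The byte and FLOP counts follow the text, including its convention of counting conversions as operations and ignoring the reduction. The machine figures used below are illustrative round numbers, not measurements.

```cpp
#include <cassert>
#include <cstddef>

// Back-of-envelope cost of one RMSNorm row (FP16 = 2 bytes).
struct RowCost { double bytes, flops; };

RowCost rmsnorm_row_cost(std::size_t dim) {
    const double d = static_cast<double>(dim);
    const double bytes = d * 2            // pass 1: read input
                       + d * 2 + d * 2    // pass 2: read input and weight
                       + d * 2;           // pass 2: write output
    const double flops = 3 * d + 3 * d;   // pass 1 + pass 2, reduction ignored
    return {bytes, flops};
}

// A kernel is memory-bound when its arithmetic intensity (FLOPs per byte)
// is below the machine balance point: peak FLOP/s divided by peak bytes/s.
bool memory_bound(const RowCost& c, double peak_flops, double peak_bytes_per_s) {
    return c.flops / c.bytes < peak_flops / peak_bytes_per_s;
}
```

For dim = 4096 this reproduces 32,768 bytes, 24,576 FLOPs and 0.75 FLOPs/byte. Against a machine with 10 TFLOPS and 400 GB/s (balance point 25 FLOPs/byte), the kernel is memory-bound by a wide margin.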

This is actually fine. Normalization is a tiny fraction of total inference time. The GEMM (matrix multiply) kernels dominate runtime. Normalization exists in the cracks between GEMMs, and the main optimization goal is to avoid unnecessary memory round-trips – which is exactly what the fused residual+RMSNorm achieves.

The Complete Reduction Tree

To really nail down the reduction, let us trace an example with 8 SIMD groups (256 threads) reducing the values [10, 20, 30, 40, 50, 60, 70, 80] (one partial sum per SIMD group):

                        SIMD Reduction Tree
                        ===================

Level 0: Each SIMD group does simd_sum() on 32 per-thread values
         (hardware shuffle -- no threadgroup memory, no barriers)

  SG0: threads 0-31   -> simd_sum -> 10.0   (lane 0 writes shared[0])
  SG1: threads 32-63  -> simd_sum -> 20.0   (lane 0 writes shared[1])
  SG2: threads 64-95  -> simd_sum -> 30.0   (lane 0 writes shared[2])
  SG3: threads 96-127 -> simd_sum -> 40.0   (lane 0 writes shared[3])
  SG4: threads 128-159-> simd_sum -> 50.0   (lane 0 writes shared[4])
  SG5: threads 160-191-> simd_sum -> 60.0   (lane 0 writes shared[5])
  SG6: threads 192-223-> simd_sum -> 70.0   (lane 0 writes shared[6])
  SG7: threads 224-255-> simd_sum -> 80.0   (lane 0 writes shared[7])

                     --- BARRIER ---

Level 1: SIMD Group 0 reads shared[], does one more simd_sum()

  SG0, lane 0: v = shared[0] = 10.0
  SG0, lane 1: v = shared[1] = 20.0
  SG0, lane 2: v = shared[2] = 30.0
  SG0, lane 3: v = shared[3] = 40.0
  SG0, lane 4: v = shared[4] = 50.0
  SG0, lane 5: v = shared[5] = 60.0
  SG0, lane 6: v = shared[6] = 70.0
  SG0, lane 7: v = shared[7] = 80.0
  SG0, lanes 8-31:  v = 0.0  (padding)

  simd_sum -> 360.0   (lane 0 writes shared[0])

                     --- BARRIER ---

All 256 threads read: shared[0] = 360.0

Two barriers, two simd_sum calls. That is all it takes to reduce 256 values. Compare this to a naive shared-memory reduction tree, which would need log2(256) = 8 barrier-synchronize-and-halve steps. The SIMD-first approach is dramatically more efficient on Apple hardware.

Summary

Akunu’s normalization kernels follow a clear pattern:

Kernel                      Model           Passes  Outputs         Special
rmsnorm_f16                 LLaMA, Mistral  2       1 (norm)        Basic RMSNorm
layernorm_f16               GPT-2, Whisper  3       1 (norm)        Mean + variance
residual_rmsnorm_f16        LLaMA, Mistral  2       2 (res + norm)  F16_MAX clamp
rmsnorm_gemma_f16           Gemma           2       1 (norm)        (1+w) scaling
residual_rmsnorm_gemma_f16  Gemma           2       2 (res + norm)  Both tricks
head_rmsnorm_f16            Qwen3           2       1 (in-place)    Per-head QK norm
head_rmsnorm_gemma_f16      Qwen3+Gemma     2       1 (in-place)    Per-head + (1+w)

The key takeaways:

  1. One threadgroup per row – simple, parallel, no cross-row communication.
  2. Two-stage SIMD reduction – simd_sum within each SIMD group, then one more simd_sum across SIMD groups. Only two barriers.
  3. FP32 accumulation – reductions and normalization run in float32; rmsnorm_f16 and residual_rmsnorm_f16 apply the final weight multiply in FP16, while the Gemma and per-head variants finish in float32 before the store.
  4. Fused kernels – the residual+RMSNorm fusion saves an entire kernel dispatch and one full read of the residual sum.
  5. rsqrt – one reciprocal square root per row, so each element needs a multiply instead of a division.
  6. F16_MAX clamping – a cheap safety net that prevents overflow in the residual sum from poisoning downstream computation.

Next up, we will see how akunu applies positional information to these normalized vectors using Rotary Position Embeddings.