
RoPE Kernels

Transformers have no built-in sense of order. Without positional information, the sentence “the cat sat on the mat” is indistinguishable from “mat the on sat cat the.” Rotary Position Embeddings (RoPE) solve this by rotating pairs of elements in the Q and K vectors by position-dependent angles. It is mathematically elegant, it extrapolates to longer sequences than training saw, and it composes beautifully with the attention dot product.

In this chapter we will walk through akunu’s seven RoPE-related Metal kernels: the basic standard and NeoX variants, the fused RoPE+KV-cache-write kernels for single-token decode, and the batch (prefill) versions. We will finish with the ultimate fusion: per-head RMSNorm + NeoX RoPE + KV write in a single dispatch.

The Math Behind RoPE

RoPE treats each pair of elements in a head dimension as a 2D vector and rotates it by an angle proportional to the token’s position. For a pair at dimension index d, the rotation angle is:

theta_d = position / (base_freq ^ (2d / head_dim))

Where base_freq (typically 10000.0) controls the frequency spectrum. Low-index pairs rotate quickly (high frequency), high-index pairs rotate slowly (low frequency). This gives each position a unique “fingerprint” of rotations.

The rotation itself is just a 2D rotation matrix:

[x0']   [cos(theta)  -sin(theta)] [x0]
[x1'] = [sin(theta)   cos(theta)] [x1]

  x0' = x0 * cos(theta) - x1 * sin(theta)
  x1' = x0 * sin(theta) + x1 * cos(theta)
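The two equations above are easy to check with a tiny Python sketch (illustrative helper, not akunu code):

```python
import math

def rotate_pair(x0, x1, pos, d, head_dim, base=10000.0):
    """Rotate one (x0, x1) pair by the position-dependent RoPE angle."""
    theta = pos / base ** (2 * d / head_dim)
    c, s = math.cos(theta), math.sin(theta)
    return (x0 * c - x1 * s, x0 * s + x1 * c)
```

At position 0 the angle is zero, so the rotation is the identity; and because it is a pure rotation, the pair's length is preserved, which is why RoPE never changes vector norms.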

Now, different model families pair up elements differently. This is where the “standard” vs “NeoX” distinction comes in.

Standard (Interleaved) vs NeoX (Split-Half) Pairing

The two RoPE styles differ in which elements they pair for rotation. In both, low-index pairs spin fast through positions and high-index pairs spin slow. For head_dim = 8:

Standard (Interleaved):  pairs (0,1) (2,3) (4,5) (6,7)
NeoX (Split-Half):       pairs (0,4) (1,5) (2,6) (3,7)

Standard pairing comes from the original LLaMA GGUF format. NeoX pairing (named after GPT-NeoX) is used by HuggingFace models including Qwen3, and matches the rotate_half function in the Python reference code:

# HuggingFace-style rotate_half (PyTorch):
def rotate_half(x):
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

output = x * cos + rotate_half(x) * sin

Both styles are mathematically equivalent in terms of expressiveness – the model just needs to learn weights that match the pairing convention. But you must use the same convention as the model was trained with, or the positions will be garbage.
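A minimal Python model of the two conventions makes the difference concrete (pure-Python lists for clarity; illustrative helpers, not akunu code):

```python
import math

def rope_standard(x, pos, base=10000.0):
    """Interleaved pairing: rotate (x[2d], x[2d+1])."""
    hd = len(x)
    out = list(x)
    for d in range(hd // 2):
        theta = pos / base ** (2 * d / hd)
        c, s = math.cos(theta), math.sin(theta)
        a, b = x[2 * d], x[2 * d + 1]
        out[2 * d], out[2 * d + 1] = a * c - b * s, a * s + b * c
    return out

def rope_neox(x, pos, base=10000.0):
    """Split-half pairing: rotate (x[d], x[d + hd/2])."""
    hd = len(x)
    half = hd // 2
    out = list(x)
    for d in range(half):
        theta = pos / base ** (2 * d / hd)
        c, s = math.cos(theta), math.sin(theta)
        a, b = x[d], x[d + half]
        out[d], out[d + half] = a * c - b * s, a * s + b * c
    return out
```

Both are norm-preserving rotations and both are the identity at position 0, but for any later position they produce different element values: the same weights cannot be used under both conventions.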

The Standard RoPE Kernel

Let us start with the simplest kernel, rope_f16:

kernel void rope_f16(
    device half              *x       [[buffer(0)]],
    constant RoPEParams      &params  [[buffer(1)]],
    device const float       *freqs   [[buffer(2)]],
    uint3 tid [[thread_position_in_grid]]
) {
    const uint pair_idx = tid.x;   // [0 .. head_dim/2)
    const uint head_idx = tid.y;   // [0 .. n_heads)
    const uint seq_idx  = tid.z;   // [0 .. seq_len)

    const uint half_dim = params.head_dim / 2;
    if (pair_idx >= half_dim || head_idx >= params.n_heads
        || seq_idx >= params.seq_len) return;

    const uint pos = seq_idx + params.pos_offset;
    float freq_divisor = (freqs != nullptr) ? freqs[pair_idx]
                       : pow(params.theta, float(2 * pair_idx) / float(params.head_dim));
    float freq = float(pos) / freq_divisor;
    float cos_f = cos(freq);
    float sin_f = sin(freq);

    // row stride for one token's activations across all heads
    const uint stride = params.n_heads * params.head_dim;
    uint base = seq_idx * stride + head_idx * params.head_dim + 2 * pair_idx;

    float x0 = float(x[base]);
    float x1 = float(x[base + 1]);

    x[base]     = half(x0 * cos_f - x1 * sin_f);
    x[base + 1] = half(x0 * sin_f + x1 * cos_f);
}

Dispatch: One Thread per Rotation

The dispatch grid is 3D:

Grid: (head_dim/2, n_heads, seq_len)

For head_dim=128, n_heads=32, seq_len=1 (decode):
  Total threads = 64 * 32 * 1 = 2,048

For head_dim=128, n_heads=32, seq_len=512 (prefill):
  Total threads = 64 * 32 * 512 = 1,048,576

Each thread handles exactly one pair of elements. No shared memory, no reductions, no barriers. This is a pure embarrassingly-parallel kernel – every thread reads two elements, computes a rotation, and writes two elements back. The only shared state is the params constant buffer.
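As a sanity check, the thread totals above follow directly from the grid shape:

```python
def rope_grid_threads(head_dim, n_heads, seq_len):
    """Total threads in the standalone RoPE dispatch: one per rotation pair."""
    return (head_dim // 2) * n_heads * seq_len
```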

Precomputed Frequencies

Notice the frequency computation has two paths:

float freq_divisor = (freqs != nullptr) ? freqs[pair_idx]
                   : pow(params.theta, float(2 * pair_idx) / float(params.head_dim));

If a precomputed frequency table is provided in buffer(2), the kernel reads from it directly. Otherwise it computes theta^(2d/head_dim) on the fly using pow. The precomputed path avoids a transcendental function (pow) per thread. For head_dim=128, that is 64 pow calls saved per head per position. The host precomputes these once at initialization:

freqs[d] = theta^(2*d / head_dim)    for d in [0, head_dim/2)

Then the kernel just does pos / freqs[d], which is a simple float division.
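The host-side precomputation might look like this (a sketch; akunu's actual host code may differ):

```python
def precompute_freq_divisors(head_dim, theta=10000.0):
    """Host-side table: freqs[d] = theta^(2d / head_dim), one entry per pair."""
    return [theta ** (2 * d / head_dim) for d in range(head_dim // 2)]
```

The table has head_dim/2 entries; entry 0 is always 1.0 (the fastest-rotating pair) and the last entry approaches theta (the slowest).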

In-Place Operation

The kernel modifies x in place. This is safe because each thread works on a unique pair of elements:

Thread (pair=0, head=0, seq=0):  reads/writes x[0], x[1]
Thread (pair=1, head=0, seq=0):  reads/writes x[2], x[3]
Thread (pair=0, head=1, seq=0):  reads/writes x[128], x[129]
...

No two threads ever touch the same memory location. No synchronization needed.

The NeoX RoPE Kernel

The NeoX variant pairs elements at distance head_dim/2:

kernel void rope_neox_f16(
    device half              *x       [[buffer(0)]],
    constant RoPEParams      &params  [[buffer(1)]],
    device const uint        *position_ids [[buffer(2)]],
    device const float       *freqs   [[buffer(3)]],
    uint3 tid [[thread_position_in_grid]]
) {
    ...
    // NeoX-style: pair element i with element i + head_dim/2
    uint idx0 = base + pair_idx;
    uint idx1 = base + pair_idx + half_dim;

    float x0 = float(x[idx0]);
    float x1 = float(x[idx1]);

    x[idx0] = half(x0 * cos_f - x1 * sin_f);
    x[idx1] = half(x1 * cos_f + x0 * sin_f);
}

The indexing difference visualized:

Standard:  base + 2*pair_idx,  base + 2*pair_idx + 1
           (adjacent elements)

NeoX:      base + pair_idx,    base + pair_idx + half_dim
           (elements half a head apart)

For head_dim=128, pair_idx=5:
  Standard: elements 10, 11
  NeoX:     elements 5, 69

Function Constants for Position IDs

The NeoX kernel has an interesting feature – Metal function constants for per-token position IDs:

constant bool FC_USE_POSITION_IDS [[function_constant(0)]];
constant bool FC_HAS_POSITION_IDS = is_function_constant_defined(FC_USE_POSITION_IDS)
                                    && FC_USE_POSITION_IDS;
...
const uint pos = FC_HAS_POSITION_IDS ? position_ids[seq_idx]
               : (seq_idx + params.pos_offset);

When FC_USE_POSITION_IDS is true, each token reads its position from a separate buffer. This is needed for tree speculation, where tokens do not have sequential positions – you might be evaluating multiple speculative continuations in parallel, each at a different position in the sequence.

When false, positions are simply seq_idx + offset, which is the common case for normal sequential generation. The function constant lets Metal compile out the branch entirely, so there is zero overhead in the non-speculative path.
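A hypothetical host-side example shows why per-token positions are needed for tree speculation (token strings and lengths are purely illustrative):

```python
# Two speculative draft branches of length 2, both starting right
# after a 10-token prompt, laid out flat in the batch.
prompt_len = 10
branches = [["he", "said"], ["she", "went"]]
position_ids = [prompt_len + depth
                for branch in branches
                for depth, _tok in enumerate(branch)]
# Positions repeat across branches, so seq_idx + pos_offset would be
# wrong; the kernel must read each token's position from this buffer.
```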

Fused RoPE + KV Cache Write: Eliminating Four Dispatches

Now we get to the real workhorses. During single-token decode (the hot path for autoregressive generation), each transformer layer needs to:

  1. Apply RoPE to Q (all heads)
  2. Apply RoPE to K (KV heads only)
  3. Write rotated K to the KV cache
  4. Write V to the KV cache (no rotation)

Naively, that is four kernel dispatches. Akunu fuses all four into one.

Here is the standard-pairing fused kernel:

kernel void rope_qkv_write_f16(
    device half                  *qkv      [[buffer(0)]],
    device half                  *k_cache  [[buffer(1)]],
    device half                  *v_cache  [[buffer(2)]],
    constant RoPEQKVWriteParams  &params   [[buffer(3)]],
    device const float           *freqs    [[buffer(4)]],
    uint2 tid [[thread_position_in_grid]]
) {
    const uint pair_idx = tid.x;  // [0, head_dim/2)
    const uint head_idx = tid.y;  // [0, n_heads)
    ...

The dispatch is 2D: (head_dim/2, n_heads). Each thread:

  1. Computes cos/sin for this position and dimension pair
  2. Rotates Q in-place for its head
  3. If head_idx < n_kv_heads, also rotates K and writes both K and V to cache

Let us trace the data flow:

QKV Buffer Layout (contiguous):
+-----------------------------------------------+
| Q: [n_heads * head_dim]                       |
+-----------------------------------------------+
| K: [n_kv_heads * head_dim]                    |  <-- at k_elem_offset
+-----------------------------------------------+
| V: [n_kv_heads * head_dim]                    |  <-- at v_elem_offset
+-----------------------------------------------+

Thread (pair=d, head=h):

  1. Compute freq = pos * freq_scale / freq_divisor
     cos_f = cos(freq), sin_f = sin(freq)

  2. Q rotation (all heads):
     q_src = h * head_dim + 2*d
     q0' = q0 * cos_f - q1 * sin_f  --> qkv[q_src]
     q1' = q0 * sin_f + q1 * cos_f  --> qkv[q_src + 1]

  3. K rotation + cache write (head h < n_kv_heads only):
     k_src = k_elem_offset + h * head_dim + 2*d
     k0' = k0 * cos_f - k1 * sin_f  --> k_cache[cache_base + 2*d]
     k1' = k0 * sin_f + k1 * cos_f  --> k_cache[cache_base + 2*d + 1]

  4. V cache write (straight copy, no rotation):
     v_cache[cache_base + 2*d] = qkv[v_src + 2*d]
     v_cache[cache_base + 2*d + 1] = qkv[v_src + 2*d + 1]

The KV Cache Layout

The cache stores K and V in head-major order:

K Cache: [n_kv_heads, max_seq_len, head_dim]
V Cache: [n_kv_heads, max_seq_len, head_dim]

cache_base = head_idx * max_seq_len * head_dim + pos * head_dim

For head 0, position 42, head_dim 128:
  cache_base = 0 * max_seq_len * 128 + 42 * 128 = 42 * 128 = 5376

For head 3, position 42, head_dim 128:
  cache_base = 3 * max_seq_len * 128 + 42 * 128

This layout means that for a given head, all positions are contiguous in memory. During attention, when you need to load K/V for all past positions of one head, the access pattern is a simple sequential scan – perfect for GPU memory throughput.
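The offset arithmetic is easy to mirror on the host (illustrative helper name):

```python
def kv_cache_offset(head_idx, pos, max_seq_len, head_dim):
    """Element offset of (head, position) in a [n_kv_heads, max_seq_len, head_dim] cache."""
    return head_idx * max_seq_len * head_dim + pos * head_dim
```

Note that consecutive positions of the same head are exactly head_dim elements apart, which is what makes the attention-time scan sequential.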

freq_scale: Dynamic Frequency Scaling

Notice the frequency computation includes a freq_scale parameter:

float freq = float(params.pos) * params.freq_scale / freq_divisor;

This enables dynamic RoPE scaling techniques like Linear Scaling (divide frequencies by a factor to extend context length) and NTK-aware scaling. If freq_scale = 1.0, you get standard RoPE. If freq_scale = 0.5, all rotation angles are halved, so position 8000 produces the same angles that position 4000 did unscaled – compressing positions into the trained range and roughly doubling the usable context window, at some cost in positional resolution.
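The linear-scaling behavior is a one-liner to verify (illustrative helper, not akunu code):

```python
def rope_angle(pos, d, head_dim, theta=10000.0, freq_scale=1.0):
    """Rotation angle for a given position and pair index, with scaling."""
    return pos * freq_scale / theta ** (2 * d / head_dim)

# With freq_scale = 0.5, position 8000 lands on exactly the angles
# position 4000 produced unscaled -- linear position interpolation.
```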

GQA-Aware Thread Assignment

The kernel elegantly handles Grouped Query Attention (GQA), where n_heads > n_kv_heads. All threads rotate Q (for all n_heads heads), but only threads where head_idx < n_kv_heads do the K/V work:

n_heads=32, n_kv_heads=8 (GQA ratio 4:1):

Threads with head_idx 0-7:   Rotate Q + Rotate K + Write K + Write V
Threads with head_idx 8-31:  Rotate Q only

  head:  0  1  2  3  4  5  6  7  8  9  10 ... 31
  Q:     *  *  *  *  *  *  *  *  *  *  *  ... *
  K/V:   *  *  *  *  *  *  *  *  .  .  .  ... .

This is a minor load imbalance – some threads do more work than others. But since the extra work (K rotation + 2 cache writes) is just a few memory ops, the imbalance is negligible compared to the cost of an extra kernel dispatch.

NeoX Fused Variant

The NeoX fused kernel (rope_neox_qkv_write_f16) is structurally identical to the standard variant, but with split-half indexing:

// Standard:  q_src = q_base + 2 * pair_idx,  q_src + 1
// NeoX:      q_lo  = q_base + pair_idx
//            q_hi  = q_base + pair_idx + half_dim

uint q_lo = q_base + pair_idx;
uint q_hi = q_base + pair_idx + half_dim;
float q0 = float(qkv[q_lo]);
float q1 = float(qkv[q_hi]);
qkv[q_lo] = half(q0 * cos_f - q1 * sin_f);
qkv[q_hi] = half(q1 * cos_f + q0 * sin_f);

Same thread count, same dispatch shape, same GQA handling. The only difference is which pairs of memory locations get rotated together.

Batch (Prefill) Variants

During prefill (processing the initial prompt), there are multiple tokens to process simultaneously. The single-token fused kernels assume seq_len = 1. The batch variants add a third grid dimension for sequence position:

Single-token dispatch:  (head_dim/2, n_heads, 1)
Batch dispatch:         (head_dim/2, n_heads, seq_len)

The batch NeoX kernel (rope_neox_batch_kv_write_f16) handles a key difference: Q, K, and V are separate buffers (already split by a preceding QKV split kernel), rather than one contiguous QKV buffer:

Single-token:
  buffer(0) = qkv [contiguous Q|K|V]

Batch:
  buffer(0) = batchQ [seq_len, n_heads * head_dim]
  buffer(1) = batchK [seq_len, n_kv_heads * head_dim]
  buffer(2) = batchV [seq_len, n_kv_heads * head_dim]

Each thread computes its position as:

const uint position = params.pos + seq_idx;

So if you are prefilling tokens starting at position 0 with seq_len=512, thread with seq_idx=100 processes position 100. The rotated K and V are written directly into the cache at the correct positions.

An interesting implementation detail: the k_elem_offset parameter field is repurposed to carry seq_len in the batch variants. This avoids changing the params struct and keeps buffer layouts compatible.

The standard-pairing batch variant (rope_batch_kv_write_f16) does the same thing but with interleaved (2i, 2i+1) pairs instead of split-half (i, i+half_dim).

The Ultimate Fusion: Head Norm + NeoX RoPE + KV Write

For Qwen3, which applies per-head RMSNorm to Q and K before RoPE, akunu offers the ultimate fused kernel: head_norm_rope_neox_kv_write_f16. This replaces three separate dispatches per layer:

  1. head_rmsnorm_f16 on Q
  2. head_rmsnorm_f16 on K
  3. rope_neox_qkv_write_f16 for RoPE + KV cache write

All fused into a single dispatch.

kernel void head_norm_rope_neox_kv_write_f16(
    device half                  *qkv          [[buffer(0)]],
    device half                  *k_cache      [[buffer(1)]],
    device half                  *v_cache      [[buffer(2)]],
    constant RoPEQKVWriteParams  &params       [[buffer(3)]],
    device const half            *q_norm_weight [[buffer(4)]],
    device const half            *k_norm_weight [[buffer(5)]],
    constant float               &norm_eps     [[buffer(6)]],
    uint2 tgpig [[threadgroup_position_in_grid]],
    uint  tid_in_tg [[thread_index_in_threadgroup]],
    uint  sgid  [[simdgroup_index_in_threadgroup]],
    uint  slid  [[thread_index_in_simdgroup]]
) {

Dispatch Model

This kernel uses threadgroups rather than a flat grid, because the RMSNorm requires a reduction across the head dimension:

Dispatch: threadgroups = (1, n_heads, 1)
          threads_per_threadgroup = (head_dim/2, 1, 1)

For head_dim=128, n_heads=32:
  32 threadgroups, each with 64 threads
  Total: 2,048 threads (same as the non-fused approach)

Each threadgroup handles one head. The 64 threads within the threadgroup cooperatively compute the RMSNorm reduction, then each thread does its rotation.

Data Flow Per Thread

Let me trace the complete flow for a single thread:

Thread (pair_idx=d, head_idx=h):

  PHASE 1: Q RMSNorm
  +---------------------------------------------------------+
  | 1a. Load Q elements: q0 = qkv[h*hd + d],               |
  |                       q1 = qkv[h*hd + d + hd/2]        |
  | 1b. Compute q_sq = q0*q0 + q1*q1                       |
  | 1c. SIMD reduce: simd_sum(q_sq)                         |
  | 1d. Write to shared[sgid], barrier                      |
  | 1e. Sum shared[], compute q_rms = rsqrt(sum/hd + eps)   |
  | 1f. q0 *= q_rms * q_norm_weight[d]                      |
  |     q1 *= q_rms * q_norm_weight[d + hd/2]               |
  +---------------------------------------------------------+

  PHASE 2: Q NeoX RoPE
  +---------------------------------------------------------+
  | 2a. Compute freq = pos / pow(theta, 2*d/hd)            |
  |     (always computed inline — no freq_scale or          |
  |      precomputed freqs buffer in this fused kernel)     |
  |     cos_f = cos(freq), sin_f = sin(freq)               |
  | 2b. qkv[q_lo] = q0 * cos_f - q1 * sin_f               |
  |     qkv[q_hi] = q1 * cos_f + q0 * sin_f               |
  +---------------------------------------------------------+

  PHASE 3: K RMSNorm + RoPE + Cache Write (if h < n_kv_heads)
  +---------------------------------------------------------+
  | 3a. Load K elements: k0, k1                             |
  | 3b. Barrier (reuse shared memory from Q norm)           |
  | 3c. SIMD reduce k_sq, compute k_rms                    |
  | 3d. k0 *= k_rms * k_norm_weight[d]                     |
  |     k1 *= k_rms * k_norm_weight[d + hd/2]              |
  | 3e. Rotate + write to cache:                            |
  |     k_cache[...] = k0*cos_f - k1*sin_f                 |
  |     k_cache[...] = k1*cos_f + k0*sin_f                 |
  +---------------------------------------------------------+

  PHASE 4: V Cache Write (if h < n_kv_heads)
  +---------------------------------------------------------+
  | 4a. Copy V directly to cache (no norm, no rotation)     |
  +---------------------------------------------------------+
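The four phases can be summarized as a sequential Python reference for a single head (a sketch of the assumed semantics, ignoring the SIMD reduction mechanics and cache writes):

```python
import math

def head_rmsnorm_rope_neox(head, weight, pos, eps=1e-6, base=10000.0):
    """One head: per-head RMSNorm, then NeoX-style RoPE."""
    hd = len(head)
    # RMSNorm: scale by rsqrt(mean of squares + eps), apply norm weights
    rms = math.sqrt(sum(v * v for v in head) / hd + eps)
    normed = [v / rms * w for v, w in zip(head, weight)]
    # NeoX RoPE: rotate (d, d + hd/2) pairs
    half = hd // 2
    out = list(normed)
    for d in range(half):
        theta = pos / base ** (2 * d / hd)
        c, s = math.cos(theta), math.sin(theta)
        a, b = normed[d], normed[d + half]
        out[d], out[d + half] = a * c - b * s, a * s + b * c
    return out
```

The kernel computes the same thing with rsqrt and a threadgroup reduction; the sequential form just makes the norm-then-rotate ordering explicit.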

Shared Memory Reuse

The kernel uses only 4 floats of shared memory (for up to 4 SIMD groups with head_dim=256):

threadgroup float shared_sq[4];  // max 4 SIMD groups for head_dim=256

This memory is reused between the Q and K normalization phases. The barrier before the K phase ensures the Q reads from shared_sq are complete before K overwrites them:

// RMSNorm for K (reuse shared memory -- barrier ensures Q reads are done)
threadgroup_barrier(mem_flags::mem_threadgroup);

Reduction Without tg_reduce_sum

Since head_dim/2 threads is at most 128 (4 SIMD groups), the reduction is tiny. The kernel sums across SIMD groups with a simple loop rather than the full tg_reduce_sum helper:

float q_total_sq = 0;
for (uint s = 0; s < n_simd; s++) q_total_sq += shared_sq[s];

For 4 SIMD groups, this is 4 additions – negligible. The simd_sum within each SIMD group does the heavy lifting.

The Performance Impact

Let us quantify the fusion benefit. For a Qwen3 model with 32 heads, head_dim=128, 8 KV heads:

Unfused (per layer, per token):
  Dispatch 1: head_rmsnorm on Q   (2048 threads, 2 barriers)
  Dispatch 2: head_rmsnorm on K   (512 threads, 2 barriers)
  Dispatch 3: rope_neox_qkv_write (RoPE + KV cache write)
  = 3 dispatches per layer

Fused (per layer, per token):
  Dispatch 1: head_norm_rope_neox_kv_write (2048 threads, ~4 barriers)
  = 1 dispatch per layer

For 32 layers: 96 dispatches --> 32 dispatches (saves 64 dispatch overheads)

At maybe 5-10 microseconds per dispatch overhead, that is 320-640 microseconds saved per generated token. For a model generating at 100 tokens/second (10 ms per token), that is roughly 3-6% of total time. Fusion matters.
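The saving is straightforward to compute (the per-dispatch overhead range is a rough assumption):

```python
layers = 32
dispatches_saved_per_layer = 3 - 1        # three dispatches fused into one
overhead_us = (5.0, 10.0)                 # assumed per-dispatch overhead range
total_saved_us = [layers * dispatches_saved_per_layer * o for o in overhead_us]
# microseconds saved per generated token, across all 32 layers
```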

Why RoPE Generalizes to Unseen Lengths

The dot product between two rotated vectors depends only on the relative position difference, not the absolute positions:

<RoPE(x, pos_i), RoPE(y, pos_j)> = f(x, y, pos_i - pos_j)

This is the mathematical property that makes RoPE so powerful: a model trained on 4096-token sequences can attend to tokens at positions 8000 and 8005 just as well as positions 0 and 5, because the rotation difference is identical. The multi-scale frequency spectrum – fast-spinning low-index pairs for local patterns, slow-spinning high-index pairs for global structure – gives the model a rich positional encoding at every scale.
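The relative-position property can be checked numerically with a 2D example (illustrative helpers, not akunu code):

```python
import math

def rope2(v, pos, d=0, head_dim=2, base=10000.0):
    """Rotate a 2-element vector by the RoPE angle for (pos, d)."""
    theta = pos / base ** (2 * d / head_dim)
    c, s = math.cos(theta), math.sin(theta)
    return (v[0] * c - v[1] * s, v[0] * s + v[1] * c)

def dot(a, b):
    return a[0] * b[0] + a[1] * b[1]

x, y = (0.3, 0.8), (-0.5, 0.2)
d1 = dot(rope2(x, 0), rope2(y, 5))        # positions 0 and 5
d2 = dot(rope2(x, 8000), rope2(y, 8005))  # positions 8000 and 8005
# d1 and d2 agree to floating-point precision: only the relative
# offset (5) enters the dot product, not the absolute positions.
```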

Summary of RoPE Kernels

+-------------------------------------------+--------+----------+-----------+
| Kernel                                    | Style  | Tokens   | Fusion    |
+-------------------------------------------+--------+----------+-----------+
| rope_f16                                  | Std    | Any      | RoPE only |
| rope_neox_f16                             | NeoX   | Any      | RoPE only |
| rope_qkv_write_f16                        | Std    | 1        | +KV write |
| rope_neox_qkv_write_f16                   | NeoX   | 1        | +KV write |
| rope_batch_kv_write_f16                   | Std    | Batch    | +KV write |
| rope_neox_batch_kv_write_f16              | NeoX   | Batch    | +KV write |
| head_norm_rope_neox_kv_write_f16          | NeoX   | 1        | +Norm+KV  |
+-------------------------------------------+--------+----------+-----------+

Key takeaways:

  1. One thread per rotation pair – embarrassingly parallel, no reductions needed for standalone RoPE.
  2. Two pairing styles – standard (interleaved) and NeoX (split-half), chosen to match the model’s training convention.
  3. Fused dispatches – the single-token decode path fuses RoPE for Q and K with KV cache writes into one dispatch. The Qwen3 path further fuses per-head RMSNorm.
  4. GQA-aware – threads for heads beyond n_kv_heads only rotate Q; those within n_kv_heads also handle K rotation and K/V cache writes.
  5. Precomputed frequencies – avoids per-thread pow() calls during inference.
  6. freq_scale – a single multiplier enables dynamic context length extension.
  7. Function constants – Metal specialization constants compile out the position-ID branch for the common (non-speculative) case.

Next, we will look at the kernels that happen before and after normalization and RoPE: embedding lookups and activation functions.