RoPE Kernels
Transformers have no built-in sense of order. Without positional information, the sentence “the cat sat on the mat” is indistinguishable from “mat the on sat cat the.” Rotary Position Embeddings (RoPE) solve this by rotating pairs of elements in the Q and K vectors by position-dependent angles. It is mathematically elegant, it extrapolates to longer sequences than training saw, and it composes beautifully with the attention dot product.
In this chapter we will walk through akunu’s seven RoPE-related Metal kernels: the basic standard and NeoX variants, the fused RoPE+KV-cache-write kernels for single-token decode, and the batch (prefill) versions. We will finish with the ultimate fusion: per-head RMSNorm + NeoX RoPE + KV write in a single dispatch.
The Math Behind RoPE
RoPE treats each pair of elements in a head dimension as a 2D vector and rotates
it by an angle proportional to the token’s position. For a pair at dimension index
d, the rotation angle is:
theta_d = position / (base_freq ^ (2d / head_dim))
Where base_freq (typically 10000.0) controls the frequency spectrum. Low-index
pairs rotate quickly (high frequency), high-index pairs rotate slowly (low
frequency). This gives each position a unique “fingerprint” of rotations.
The rotation itself is just a 2D rotation matrix:
[x0']   [cos(theta)  -sin(theta)] [x0]
[x1'] = [sin(theta)   cos(theta)] [x1]
x0' = x0 * cos(theta) - x1 * sin(theta)
x1' = x0 * sin(theta) + x1 * cos(theta)
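The two scalar equations can be checked with a tiny reference function (plain Python, hypothetical helper name):

```python
import math

def rotate_pair(x0, x1, pos, d, head_dim, base_freq=10000.0):
    """Rotate one (x0, x1) pair by the position-dependent RoPE angle."""
    theta = pos / (base_freq ** (2 * d / head_dim))
    c, s = math.cos(theta), math.sin(theta)
    return (x0 * c - x1 * s, x0 * s + x1 * c)

# Position 0 is always the identity rotation: theta = 0.
print(rotate_pair(1.0, 2.0, pos=0, d=3, head_dim=128))  # (1.0, 2.0)
```

Note that the rotation preserves the pair's length, which is why RoPE does not change the magnitude of Q or K.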
Now, different model families pair up elements differently. This is where the “standard” vs “NeoX” distinction comes in.
Standard (Interleaved) vs NeoX (Split-Half) Pairing
The two RoPE styles differ in which elements they pair for rotation. The animation below shows both. Use the position slider to see how vectors rotate — low-index pairs spin fast, high-index pairs spin slow.
Standard pairing comes from the original LLaMA GGUF format. NeoX pairing (named
after GPT-NeoX) is used by HuggingFace models including Qwen3, and matches the
rotate_half function in the Python reference code:
# HuggingFace rotate_half:
x1 = x[..., :head_dim//2]
x2 = x[..., head_dim//2:]
rotated = cat(-x2, x1, dim=-1)
output = x * cos + rotated * sin
Both styles are mathematically equivalent in terms of expressiveness – the model just needs to learn weights that match the pairing convention. But you must use the same convention as the model was trained with, or the positions will be garbage.
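The equivalence is easy to see in code: NeoX rotation is exactly "permute to interleaved layout, apply standard rotation, permute back." A minimal pure-Python sketch (illustrative helper names, not the kernel code):

```python
import math

def rope_standard(x, pos, base=10000.0):
    """Standard/interleaved pairing: rotate (x[2i], x[2i+1])."""
    d = len(x)
    out = list(x)
    for i in range(d // 2):
        theta = pos / (base ** (2 * i / d))
        c, s = math.cos(theta), math.sin(theta)
        out[2*i]     = x[2*i] * c - x[2*i+1] * s
        out[2*i + 1] = x[2*i] * s + x[2*i+1] * c
    return out

def rope_neox(x, pos, base=10000.0):
    """NeoX/split-half pairing: rotate (x[i], x[i + d/2])."""
    d = len(x)
    h = d // 2
    out = list(x)
    for i in range(h):
        theta = pos / (base ** (2 * i / d))
        c, s = math.cos(theta), math.sin(theta)
        out[i]     = x[i] * c - x[i + h] * s
        out[i + h] = x[i] * s + x[i + h] * c
    return out

# NeoX == interleave -> standard rotation -> de-interleave
x = [0.1 * i for i in range(8)]
h = len(x) // 2
perm = [x[i // 2] if i % 2 == 0 else x[i // 2 + h] for i in range(len(x))]
std = rope_standard(perm, pos=5)
unperm = [std[2*i] for i in range(h)] + [std[2*i + 1] for i in range(h)]
assert all(abs(a - b) < 1e-12 for a, b in zip(unperm, rope_neox(x, pos=5)))
```

Because the two styles differ only by a fixed permutation, a model can absorb the permutation into its weights — which is exactly why the convention must match training.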
The Standard RoPE Kernel
Let us start with the simplest kernel, rope_f16:
kernel void rope_f16(
    device half *x [[buffer(0)]],
    constant RoPEParams &params [[buffer(1)]],
    device const float *freqs [[buffer(2)]],
    uint3 tid [[thread_position_in_grid]]
) {
    const uint pair_idx = tid.x; // [0 .. head_dim/2)
    const uint head_idx = tid.y; // [0 .. n_heads)
    const uint seq_idx  = tid.z; // [0 .. seq_len)
    const uint half_dim = params.head_dim / 2;
    if (pair_idx >= half_dim || head_idx >= params.n_heads
        || seq_idx >= params.seq_len) return;

    const uint pos = seq_idx + params.pos_offset;
    float freq_divisor = (freqs != nullptr) ? freqs[pair_idx]
        : pow(params.theta, float(2 * pair_idx) / float(params.head_dim));
    float freq  = float(pos) / freq_divisor;
    float cos_f = cos(freq);
    float sin_f = sin(freq);

    // Row stride: one token's worth of elements (n_heads * head_dim)
    const uint stride = params.n_heads * params.head_dim;
    uint base = seq_idx * stride + head_idx * params.head_dim + 2 * pair_idx;
    float x0 = float(x[base]);
    float x1 = float(x[base + 1]);
    x[base]     = half(x0 * cos_f - x1 * sin_f);
    x[base + 1] = half(x0 * sin_f + x1 * cos_f);
}
Dispatch: One Thread per Rotation
The dispatch grid is 3D:
Grid: (head_dim/2, n_heads, seq_len)
For head_dim=128, n_heads=32, seq_len=1 (decode):
Total threads = 64 * 32 * 1 = 2,048
For head_dim=128, n_heads=32, seq_len=512 (prefill):
Total threads = 64 * 32 * 512 = 1,048,576
Each thread handles exactly one pair of elements. No shared memory, no reductions,
no barriers. This is a pure embarrassingly-parallel kernel – every thread reads
two elements, computes a rotation, and writes two elements back. The only shared
state is the params constant buffer.
Precomputed Frequencies
Notice the frequency computation has two paths:
float freq_divisor = (freqs != nullptr) ? freqs[pair_idx]
: pow(params.theta, float(2 * pair_idx) / float(params.head_dim));
If a precomputed frequency table is provided in buffer(2), the kernel reads from
it directly. Otherwise it computes theta^(2d/head_dim) on the fly using pow.
The precomputed path avoids a transcendental function (pow) per thread. For
head_dim=128, that is 64 pow calls saved per head per position. The host
precomputes these once at initialization:
freqs[d] = theta^(2*d / head_dim) for d in [0, head_dim/2)
Then the kernel just does pos / freqs[d], which is a simple float division.
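The host-side precompute is a one-liner. A sketch in plain Python (hypothetical function name):

```python
def precompute_freqs(head_dim, theta=10000.0):
    """Host-side table: freqs[d] = theta^(2d/head_dim), computed once."""
    return [theta ** (2 * d / head_dim) for d in range(head_dim // 2)]

freqs = precompute_freqs(128)
# In the kernel, the angle is then just a division: pos / freqs[d]
angle = 42 / freqs[5]
assert freqs[0] == 1.0  # lowest-index pair rotates at full speed
```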
In-Place Operation
The kernel modifies x in place. This is safe because each thread works on a
unique pair of elements:
Thread (pair=0, head=0, seq=0): reads/writes x[0], x[1]
Thread (pair=1, head=0, seq=0): reads/writes x[2], x[3]
Thread (pair=0, head=1, seq=0): reads/writes x[128], x[129]
...
No two threads ever touch the same memory location. No synchronization needed.
The NeoX RoPE Kernel
The NeoX variant pairs elements at distance head_dim/2:
kernel void rope_neox_f16(
device half *x [[buffer(0)]],
constant RoPEParams &params [[buffer(1)]],
device const uint *position_ids [[buffer(2)]],
device const float *freqs [[buffer(3)]],
uint3 tid [[thread_position_in_grid]]
) {
...
// NeoX-style: pair element i with element i + head_dim/2
uint idx0 = base + pair_idx;
uint idx1 = base + pair_idx + half_dim;
float x0 = float(x[idx0]);
float x1 = float(x[idx1]);
x[idx0] = half(x0 * cos_f - x1 * sin_f);
x[idx1] = half(x1 * cos_f + x0 * sin_f);
}
The indexing difference visualized:
Standard: base + 2*pair_idx, base + 2*pair_idx + 1
(adjacent elements)
NeoX: base + pair_idx, base + pair_idx + half_dim
(elements half a head apart)
For head_dim=128, pair_idx=5:
Standard: elements 10, 11
NeoX: elements 5, 69
Function Constants for Position IDs
The NeoX kernel has an interesting feature – Metal function constants for per-token position IDs:
constant bool FC_USE_POSITION_IDS [[function_constant(0)]];
constant bool FC_HAS_POSITION_IDS = is_function_constant_defined(FC_USE_POSITION_IDS)
&& FC_USE_POSITION_IDS;
...
const uint pos = FC_HAS_POSITION_IDS ? position_ids[seq_idx]
: (seq_idx + params.pos_offset);
When FC_USE_POSITION_IDS is true, each token reads its position from a separate
buffer. This is needed for tree speculation, where tokens do not have sequential
positions – you might be evaluating multiple speculative continuations in parallel,
each at a different position in the sequence.
When false, positions are simply seq_idx + offset, which is the common case for
normal sequential generation. The function constant lets Metal compile out the
branch entirely, so there is zero overhead in the non-speculative path.
Fused RoPE + KV Cache Write: Eliminating Four Dispatches
Now we get to the real workhorses. During single-token decode (the hot path for autoregressive generation), each transformer layer needs to:
- Apply RoPE to Q (all heads)
- Apply RoPE to K (KV heads only)
- Write rotated K to the KV cache
- Write V to the KV cache (no rotation)
Naively, that is four kernel dispatches. Akunu fuses all four into one.
Here is the standard-pairing fused kernel:
kernel void rope_qkv_write_f16(
device half *qkv [[buffer(0)]],
device half *k_cache [[buffer(1)]],
device half *v_cache [[buffer(2)]],
constant RoPEQKVWriteParams &params [[buffer(3)]],
device const float *freqs [[buffer(4)]],
uint2 tid [[thread_position_in_grid]]
) {
const uint pair_idx = tid.x; // [0, head_dim/2)
const uint head_idx = tid.y; // [0, n_heads)
...
The dispatch is 2D: (head_dim/2, n_heads). Each thread:
- Computes cos/sin for this position and dimension pair
- Rotates Q in-place for its head
- If head_idx < n_kv_heads, also rotates K and writes both K and V to cache
Let us trace the data flow:
QKV Buffer Layout (contiguous):
+-----------------------------------------------+
| Q: [n_heads * head_dim] |
+-----------------------------------------------+
| K: [n_kv_heads * head_dim] | <-- at k_elem_offset
+-----------------------------------------------+
| V: [n_kv_heads * head_dim] | <-- at v_elem_offset
+-----------------------------------------------+
Thread (pair=d, head=h):
1. Compute freq = pos * freq_scale / freq_divisor
cos_f = cos(freq), sin_f = sin(freq)
2. Q rotation (all heads):
q_src = h * head_dim + 2*d
q0' = q0 * cos_f - q1 * sin_f --> qkv[q_src]
q1' = q0 * sin_f + q1 * cos_f --> qkv[q_src + 1]
3. K rotation + cache write (head h < n_kv_heads only):
k_src = k_elem_offset + h * head_dim + 2*d
k0' = k0 * cos_f - k1 * sin_f --> k_cache[cache_base + 2*d]
k1' = k0 * sin_f + k1 * cos_f --> k_cache[cache_base + 2*d + 1]
4. V cache write (straight copy, no rotation):
v_cache[cache_base + 2*d] = qkv[v_src + 2*d]
v_cache[cache_base + 2*d + 1] = qkv[v_src + 2*d + 1]
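A sequential Python sketch of the same data flow (toy sizes, hypothetical helper name; the real kernel does this with one thread per (pair, head)):

```python
import math

def fused_rope_kv_write(qkv, k_cache, v_cache, pos, n_heads, n_kv_heads,
                        head_dim, max_seq_len, theta=10000.0):
    """Reference for the fused decode path (standard pairing):
    rotate Q for every head; for KV heads, rotate K and write K/V to cache."""
    hd = head_dim
    k_off = n_heads * hd             # K starts after Q in the qkv buffer
    v_off = k_off + n_kv_heads * hd  # V starts after K
    for h in range(n_heads):
        for d in range(hd // 2):
            ang = pos / (theta ** (2 * d / hd))
            c, s = math.cos(ang), math.sin(ang)
            # Q rotation (all heads), in place
            q = h * hd + 2 * d
            q0, q1 = qkv[q], qkv[q + 1]
            qkv[q], qkv[q + 1] = q0 * c - q1 * s, q0 * s + q1 * c
            if h < n_kv_heads:
                # K rotation + cache write, V straight copy
                base = h * max_seq_len * hd + pos * hd + 2 * d
                k = k_off + h * hd + 2 * d
                k0, k1 = qkv[k], qkv[k + 1]
                k_cache[base], k_cache[base + 1] = k0 * c - k1 * s, k0 * s + k1 * c
                v = v_off + h * hd + 2 * d
                v_cache[base], v_cache[base + 1] = qkv[v], qkv[v + 1]

qkv = [float(i) for i in range(16)]   # n_heads=2, n_kv_heads=1, head_dim=4
k_cache = [0.0] * 32
v_cache = [0.0] * 32
fused_rope_kv_write(qkv, k_cache, v_cache, pos=0, n_heads=2, n_kv_heads=1,
                    head_dim=4, max_seq_len=8)
assert k_cache[:4] == [8.0, 9.0, 10.0, 11.0]   # pos 0 = identity rotation
assert v_cache[:4] == [12.0, 13.0, 14.0, 15.0]
```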
The KV Cache Layout
The cache stores K and V in head-major order:
K Cache: [n_kv_heads, max_seq_len, head_dim]
V Cache: [n_kv_heads, max_seq_len, head_dim]
cache_base = head_idx * max_seq_len * head_dim + pos * head_dim
For head 0, position 42, head_dim 128:
cache_base = 0 * max_seq_len * 128 + 42 * 128 = 42 * 128 = 5376
For head 3, position 42, head_dim 128:
cache_base = 3 * max_seq_len * 128 + 42 * 128
This layout means that for a given head, all positions are contiguous in memory. During attention, when you need to load K/V for all past positions of one head, the access pattern is a simple sequential scan – perfect for GPU memory throughput.
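The indexing is simple enough to state as a one-line function (Python sketch; max_seq_len=4096 below is just an illustrative value):

```python
def cache_base(head_idx, pos, head_dim, max_seq_len):
    """Offset of one (head, position) row in the
    [n_kv_heads, max_seq_len, head_dim] cache layout."""
    return head_idx * max_seq_len * head_dim + pos * head_dim

assert cache_base(0, 42, 128, 4096) == 5376
assert cache_base(3, 42, 128, 4096) == 3 * 4096 * 128 + 42 * 128
# Consecutive positions within one head are head_dim apart -> sequential scan
assert cache_base(0, 43, 128, 4096) - cache_base(0, 42, 128, 4096) == 128
```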
freq_scale: Dynamic Frequency Scaling
Notice the frequency computation includes a freq_scale parameter:
float freq = float(params.pos) * params.freq_scale / freq_divisor;
This enables dynamic RoPE scaling techniques like Linear Scaling (divide
frequencies by a factor to extend context length) and NTK-aware scaling. If
freq_scale = 1.0, you get standard RoPE. If freq_scale = 0.5, all frequencies
are halved, effectively doubling the model’s positional resolution and extending
its context window.
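The stretching effect is easy to verify numerically (Python sketch with a hypothetical helper name):

```python
import math

def rope_angle(pos, pair_idx, head_dim, theta=10000.0, freq_scale=1.0):
    """Angle used by the fused kernel: pos * freq_scale / theta^(2d/head_dim)."""
    return pos * freq_scale / (theta ** (2 * pair_idx / head_dim))

# With freq_scale = 0.5, position 200 lands on exactly the same angles as
# position 100 at scale 1.0 -- the positional axis is stretched 2x.
assert math.isclose(rope_angle(200, 7, 128, freq_scale=0.5),
                    rope_angle(100, 7, 128, freq_scale=1.0))
```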
GQA-Aware Thread Assignment
The kernel elegantly handles Grouped Query Attention (GQA), where n_heads > n_kv_heads. All threads rotate Q (for all n_heads heads), but only threads
where head_idx < n_kv_heads do the K/V work:
n_heads=32, n_kv_heads=8 (GQA ratio 4:1):
Threads with head_idx 0-7: Rotate Q + Rotate K + Write K + Write V
Threads with head_idx 8-31: Rotate Q only
head: 0 1 2 3 4 5 6 7 8 9 10 ... 31
Q: * * * * * * * * * * * ... *
K/V: * * * * * * * * . . . ... .
This is a minor load imbalance – some threads do more work than others. But since the extra work (K rotation + 2 cache writes) is just a few memory ops, the imbalance is negligible compared to the cost of an extra kernel dispatch.
NeoX Fused Variant
The NeoX fused kernel (rope_neox_qkv_write_f16) is structurally identical to the
standard variant, but with split-half indexing:
// Standard: q_src = q_base + 2 * pair_idx, q_src + 1
// NeoX: q_lo = q_base + pair_idx
// q_hi = q_base + pair_idx + half_dim
uint q_lo = q_base + pair_idx;
uint q_hi = q_base + pair_idx + half_dim;
float q0 = float(qkv[q_lo]);
float q1 = float(qkv[q_hi]);
qkv[q_lo] = half(q0 * cos_f - q1 * sin_f);
qkv[q_hi] = half(q1 * cos_f + q0 * sin_f);
Same thread count, same dispatch shape, same GQA handling. The only difference is which pairs of memory locations get rotated together.
Batch (Prefill) Variants
During prefill (processing the initial prompt), there are multiple tokens to
process simultaneously. The single-token fused kernels assume seq_len = 1.
The batch variants add a third grid dimension for sequence position:
Single-token dispatch: (head_dim/2, n_heads, 1)
Batch dispatch: (head_dim/2, n_heads, seq_len)
The batch NeoX kernel (rope_neox_batch_kv_write_f16) handles a key difference:
Q, K, and V are separate buffers (already split by a preceding QKV split kernel),
rather than one contiguous QKV buffer:
Single-token:
buffer(0) = qkv [contiguous Q|K|V]
Batch:
buffer(0) = batchQ [seq_len, n_heads * head_dim]
buffer(1) = batchK [seq_len, n_kv_heads * head_dim]
buffer(2) = batchV [seq_len, n_kv_heads * head_dim]
Each thread computes its position as:
const uint position = params.pos + seq_idx;
So if you are prefilling tokens starting at position 0 with seq_len=512, thread
with seq_idx=100 processes position 100. The rotated K and V are written directly
into the cache at the correct positions.
An interesting implementation detail: the k_elem_offset parameter field is
repurposed to carry seq_len in the batch variants. This avoids changing the
params struct and keeps buffer layouts compatible.
The standard-pairing batch variant (rope_batch_kv_write_f16) does the same
thing but with interleaved (2i, 2i+1) pairs instead of split-half (i, i+half_dim).
The Ultimate Fusion: Head Norm + NeoX RoPE + KV Write
For Qwen3, which applies per-head RMSNorm to Q and K before RoPE, akunu offers
the ultimate fused kernel: head_norm_rope_neox_kv_write_f16. This replaces
three separate dispatches per layer:
- head_rmsnorm_f16 on Q
- head_rmsnorm_f16 on K
- rope_neox_qkv_write_f16 for RoPE + KV cache write
All fused into a single dispatch.
kernel void head_norm_rope_neox_kv_write_f16(
device half *qkv [[buffer(0)]],
device half *k_cache [[buffer(1)]],
device half *v_cache [[buffer(2)]],
constant RoPEQKVWriteParams &params [[buffer(3)]],
device const half *q_norm_weight [[buffer(4)]],
device const half *k_norm_weight [[buffer(5)]],
constant float &norm_eps [[buffer(6)]],
uint2 tgpig [[threadgroup_position_in_grid]],
uint tid_in_tg [[thread_index_in_threadgroup]],
uint sgid [[simdgroup_index_in_threadgroup]],
uint slid [[thread_index_in_simdgroup]]
) {
Dispatch Model
This kernel uses threadgroups rather than a flat grid, because the RMSNorm requires a reduction across the head dimension:
Dispatch: threadgroups = (1, n_heads, 1)
threads_per_threadgroup = (head_dim/2, 1, 1)
For head_dim=128, n_heads=32:
32 threadgroups, each with 64 threads
Total: 2,048 threads (same as the non-fused approach)
Each threadgroup handles one head. The 64 threads within the threadgroup cooperatively compute the RMSNorm reduction, then each thread does its rotation.
Data Flow Per Thread
Let me trace the complete flow for a single thread:
Thread (pair_idx=d, head_idx=h):
PHASE 1: Q RMSNorm
+---------------------------------------------------------+
| 1a. Load Q elements: q0 = qkv[h*hd + d], |
| q1 = qkv[h*hd + d + hd/2] |
| 1b. Compute q_sq = q0*q0 + q1*q1 |
| 1c. SIMD reduce: simd_sum(q_sq) |
| 1d. Write to shared[sgid], barrier |
| 1e. Sum shared[], compute q_rms = rsqrt(sum/hd + eps) |
| 1f. q0 *= q_rms * q_norm_weight[d] |
| q1 *= q_rms * q_norm_weight[d + hd/2] |
+---------------------------------------------------------+
PHASE 2: Q NeoX RoPE
+---------------------------------------------------------+
| 2a. Compute freq = pos / pow(theta, 2*d/hd) |
| (always computed inline — no freq_scale or |
| precomputed freqs buffer in this fused kernel) |
| cos_f = cos(freq), sin_f = sin(freq) |
| 2b. qkv[q_lo] = q0 * cos_f - q1 * sin_f |
| qkv[q_hi] = q1 * cos_f + q0 * sin_f |
+---------------------------------------------------------+
PHASE 3: K RMSNorm + RoPE + Cache Write (if h < n_kv_heads)
+---------------------------------------------------------+
| 3a. Load K elements: k0, k1 |
| 3b. Barrier (reuse shared memory from Q norm) |
| 3c. SIMD reduce k_sq, compute k_rms |
| 3d. k0 *= k_rms * k_norm_weight[d] |
| k1 *= k_rms * k_norm_weight[d + hd/2] |
| 3e. Rotate + write to cache: |
| k_cache[...] = k0*cos_f - k1*sin_f |
| k_cache[...] = k1*cos_f + k0*sin_f |
+---------------------------------------------------------+
PHASE 4: V Cache Write (if h < n_kv_heads)
+---------------------------------------------------------+
| 4a. Copy V directly to cache (no norm, no rotation) |
+---------------------------------------------------------+
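The per-head math (phases 1 and 2) can be captured in a sequential Python reference — no SIMD reduction, just a straight sum, with the eps placement matching rsqrt(sum/hd + eps) from the trace (illustrative helper name):

```python
import math

def head_norm_rope_neox(head, pos, weight, eps=1e-6, theta=10000.0):
    """Reference for one head: RMSNorm over head_dim, then NeoX RoPE."""
    hd = len(head)
    h = hd // 2
    rms = 1.0 / math.sqrt(sum(v * v for v in head) / hd + eps)
    normed = [v * rms * w for v, w in zip(head, weight)]
    out = list(normed)
    for d in range(h):
        ang = pos / (theta ** (2 * d / hd))
        c, s = math.cos(ang), math.sin(ang)
        out[d]     = normed[d] * c - normed[d + h] * s
        out[d + h] = normed[d] * s + normed[d + h] * c
    return out

head = [0.5, -1.0, 2.0, 0.25]
rotated = head_norm_rope_neox(head, pos=3, weight=[1.0] * 4)
# Rotation preserves length, so with unit weights the output RMS stays ~1.
rms_out = math.sqrt(sum(v * v for v in rotated) / len(rotated))
assert abs(rms_out - 1.0) < 1e-3
```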
Shared Memory Reuse
The kernel uses only 4 floats of shared memory (for up to 4 SIMD groups with head_dim=256):
threadgroup float shared_sq[4]; // max 4 SIMD groups for head_dim=256
This memory is reused between the Q and K normalization phases. The barrier before
the K phase ensures the Q reads from shared_sq are complete before K overwrites
them:
// RMSNorm for K (reuse shared memory -- barrier ensures Q reads are done)
threadgroup_barrier(mem_flags::mem_threadgroup);
Reduction Without tg_reduce_sum
Since the reduction spans at most head_dim/2 = 128 threads (4 SIMD groups), it is tiny.
The kernel sums across SIMD groups with a simple loop rather than the full
tg_reduce_sum helper:
float q_total_sq = 0;
for (uint s = 0; s < n_simd; s++) q_total_sq += shared_sq[s];
For 4 SIMD groups, this is 4 additions – negligible. The simd_sum within each
SIMD group does the heavy lifting.
The Performance Impact
Let us quantify the fusion benefit. For a Qwen3 model with 32 heads, head_dim=128, 8 KV heads:
Unfused (per layer, per token):
Dispatch 1: head_rmsnorm on Q (2048 threads, 2 barriers)
Dispatch 2: head_rmsnorm on K (512 threads, 2 barriers)
Dispatch 3: rope_neox_qkv_write (RoPE + KV cache write)
= 3 dispatches per layer
Fused (per layer, per token):
Dispatch 1: head_norm_rope_neox_kv_write (2048 threads, ~4 barriers)
= 1 dispatch per layer
For 32 layers: 96 dispatches --> 32 dispatches (saves 64 dispatch overheads)
At maybe 5-10 microseconds per dispatch overhead, that is 320-640 microseconds saved per token generation. For a model generating at 100 tokens/second (10 ms per token), that could be 3-6% of total time. Fusion matters.
Why RoPE Generalizes to Unseen Lengths
The dot product between two rotated vectors depends only on the relative position difference, not the absolute positions:
<RoPE(x, pos_i), RoPE(y, pos_j)> = f(x, y, pos_i - pos_j)
This is the mathematical property that makes RoPE so powerful: a model trained on 4096-token sequences can attend to tokens at position 8000 and 8005 just as well as positions 0 and 5, because the rotation difference is identical. The multi-scale frequency spectrum (visible in the interactive animation above — fast-spinning pair 0 for local patterns, slow-spinning pair 3 for global structure) gives the model a rich positional encoding at every scale.
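The property can be checked numerically with a toy NeoX-style RoPE (plain Python sketch):

```python
import math

def rope(x, pos, base=10000.0):
    """NeoX-style RoPE on a flat vector (split-half pairing)."""
    d = len(x); h = d // 2
    out = list(x)
    for i in range(h):
        t = pos / (base ** (2 * i / d))
        c, s = math.cos(t), math.sin(t)
        out[i], out[i + h] = x[i] * c - x[i + h] * s, x[i] * s + x[i + h] * c
    return out

def dot(a, b):
    return sum(p * q for p, q in zip(a, b))

q = [0.3, -0.7, 1.1, 0.2]
k = [0.9, 0.4, -0.5, 0.8]
# Same relative offset (5), wildly different absolute positions:
near = dot(rope(q, 0),    rope(k, 5))
far  = dot(rope(q, 8000), rope(k, 8005))
assert math.isclose(near, far, rel_tol=1e-9, abs_tol=1e-9)
```

Each pair rotates independently, and a dot product between two rotated 2D vectors depends only on the angle between them — so the position-dependent parts cancel except for the difference.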
Summary of RoPE Kernels
+-------------------------------------------+--------+----------+-----------+
| Kernel | Style | Tokens | Fusion |
+-------------------------------------------+--------+----------+-----------+
| rope_f16 | Std | Any | RoPE only |
| rope_neox_f16 | NeoX | Any | RoPE only |
| rope_qkv_write_f16 | Std | 1 | +KV write |
| rope_neox_qkv_write_f16 | NeoX | 1 | +KV write |
| rope_batch_kv_write_f16 | Std | Batch | +KV write |
| rope_neox_batch_kv_write_f16 | NeoX | Batch | +KV write |
| head_norm_rope_neox_kv_write_f16 | NeoX | 1 | +Norm+KV |
+-------------------------------------------+--------+----------+-----------+
Key takeaways:
- One thread per rotation pair – embarrassingly parallel, no reductions needed for standalone RoPE.
- Two pairing styles – standard (interleaved) and NeoX (split-half), chosen to match the model’s training convention.
- Fused dispatches – the single-token decode path fuses RoPE for Q and K with KV cache writes into one dispatch. The Qwen3 path further fuses per-head RMSNorm.
- GQA-aware – threads for heads beyond n_kv_heads only rotate Q; those within n_kv_heads also handle K rotation and K/V cache writes.
- Precomputed frequencies – avoids per-thread pow() calls during inference.
- freq_scale – a single multiplier enables dynamic context length extension.
- Function constants – Metal specialization constants compile out the position-ID branch for the common (non-speculative) case.
Next, we will look at the kernels that happen before and after normalization and RoPE: embedding lookups and activation functions.