Sampling Kernels
The previous chapter on the sampling pipeline described the algorithmic flow from logits to tokens. This chapter dives into the Metal kernels that implement sampling on the GPU: the Top-K selector, the Gumbel-max sampler, the argmax reduction, and the repetition penalty kernel. These kernels live in backend/metal/kernels/metal/kernel/sampling/.
The GPU sampling path is critical for Akunu’s chain decode performance: by keeping sampling entirely on the GPU, the engine avoids a CPU roundtrip for every generated token. The Gumbel-max kernel is particularly important because it enables sampled generation at the same throughput as greedy decoding.
Kernel 1: Argmax (argmax_f16)
Argmax is the simplest sampling kernel: find the index of the maximum value in the logit buffer. Despite its simplicity, getting it right on a GPU requires a two-level reduction.
Dispatch
Grid: (1, 1, 1) — single threadgroup
Threadgroup: (1024, 1, 1) — 1024 threads = 32 SIMD groups
Phase 1: Thread-Local Scan
float best_val = -INFINITY;
uint best_idx = 0;
for (uint i = tid; i < N; i += tg_size) {
float v = float(logits[i]);
if (v > best_val) { best_val = v; best_idx = i; }
}
Each thread scans a strided subset of the vocabulary. For a 128K vocabulary with 1024 threads, each thread examines 128 elements.
Phase 2: SIMD Reduction
for (uint offset = SIMD_WIDTH / 2; offset > 0; offset >>= 1) {
float other_val = simd_shuffle_down(best_val, offset);
uint other_idx = simd_shuffle_down(best_idx, offset);
if (slid + offset < SIMD_WIDTH && other_val > best_val) {
best_val = other_val;
best_idx = other_idx;
}
}
This is a classic butterfly reduction within a SIMD group. Each step halves the active lanes:
Step 1: lanes 0-15 compare with lanes 16-31 (offset=16)
Step 2: lanes 0-7 compare with lanes 8-15 (offset=8)
Step 3: lanes 0-3 compare with lanes 4-7 (offset=4)
Step 4: lanes 0-1 compare with lanes 2-3 (offset=2)
Step 5: lane 0 compares with lane 1 (offset=1)
After 5 steps, lane 0 of each SIMD group holds the group’s local winner.
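The five steps can be sketched as a CPU reference model (Python, illustrative only — it walks the lanes sequentially, whereas the hardware shuffles execute in lockstep; because each lane reads a strictly higher lane, the sequential walk sees the same pre-step values the hardware would):

```python
# CPU reference model of the SIMD shuffle-down argmax reduction.
# Not the Metal kernel itself: lanes are simulated one at a time.
SIMD_WIDTH = 32

def simd_argmax(vals):
    """Reduce 32 (value, index) lanes the way the kernel's butterfly does."""
    best_val = list(vals)
    best_idx = list(range(SIMD_WIDTH))
    offset = SIMD_WIDTH // 2
    while offset > 0:
        for lane in range(SIMD_WIDTH):
            src = lane + offset
            if src < SIMD_WIDTH:  # guard, like slid + offset < SIMD_WIDTH
                if best_val[src] > best_val[lane]:
                    best_val[lane] = best_val[src]
                    best_idx[lane] = best_idx[src]
        offset //= 2
    return best_val[0], best_idx[0]  # lane 0 holds the group winner
```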
Phase 3: Cross-SIMD-Group Reduction
threadgroup float shared_val[32];
threadgroup uint shared_idx[32];
if (slid == 0) {
shared_val[sgid] = best_val;
shared_idx[sgid] = best_idx;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (sgid == 0 && slid < n_sg) {
best_val = shared_val[slid];
best_idx = shared_idx[slid];
// Same butterfly reduction on 32 SG winners
for (uint offset = SIMD_WIDTH / 2; offset > 0; offset >>= 1) {
// ...
}
if (slid == 0) *result = best_idx;
}
The 32 SIMD group winners are written to threadgroup shared memory, then a single SIMD group performs a final butterfly reduction. The overall winner is written to the result buffer.
Total work: One strided scan of the vocabulary + two levels of reduction. For 128K vocabulary, this takes approximately 5-10 microseconds on Apple Silicon.
Kernel 2: Gumbel Top-K (gumbel_topk_f16)
The Gumbel-max kernel is the most complex sampling kernel. It performs the complete sampling pipeline on the GPU in a single dispatch:
Dispatch: grid=(1), threadgroup=(1024)
One threadgroup of 1024 threads processes the entire vocabulary. The kernel executes four phases:
Phase 0: Repetition Penalty
float rep_penalty = params.repeat_penalty;
if (rep_penalty > 1.0f && params.position > 0) {
for (uint i = tid; i < params.position; i += 1024) {
uint token = token_ids[i + 1];
if (token < V) {
float val = float(logits[token]);
logits[token] = half(val > 0 ? val / rep_penalty : val * rep_penalty);
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
Each thread processes a strided subset of previously generated tokens, applying the same asymmetric penalty as the CPU path (divide positive logits, multiply negative logits by the penalty factor).
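The same logic as a CPU reference (Python sketch; `apply_repetition_penalty` is a hypothetical helper mirroring the kernel's per-position loop, so a token repeated in the history is penalized more than once, just as on the GPU):

```python
def apply_repetition_penalty(logits, prev_tokens, penalty):
    """Asymmetric penalty: divide positive logits by the factor,
    multiply negative logits by it."""
    for t in prev_tokens:
        if t < len(logits):  # guard, like the kernel's token < V check
            v = logits[t]
            logits[t] = v / penalty if v > 0 else v * penalty
    return logits
```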
Phase 1: Find Global Maximum
float local_max = -INFINITY;
for (uint i = tid; i < V; i += 1024)
local_max = max(local_max, float(logits[i]));
local_max = simd_max(local_max);
if (slid == 0) tg_vals[sgid] = local_max;
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tid == 0) {
float m = tg_vals[0];
for (uint s = 1; s < 32; s++) m = max(m, tg_vals[s]);
tg_vals[0] = m;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
float global_max = tg_vals[0];
A standard two-level max reduction: SIMD-first with simd_max, then cross-SIMD via threadgroup memory.
Phase 2: Binary Search for K-th Largest
This is the most innovative part of the kernel. Instead of sorting the entire vocabulary to find the top-K, it uses binary search on the value domain to find the threshold value below which exactly K elements survive:
- Init: lo = global_max - 30, hi = global_max
- Iteration 1: mid = (lo + hi) / 2; count elements > mid (1024 threads in parallel). If count > K: lo = mid, else: hi = mid
- Iteration 2: narrower range, count again, adjust bounds
- … 12 iterations total (2^12 = 4096x precision refinement)
- Result: the threshold separates the top-K from the rest
float threshold = global_max;
if (top_k > 1) {
float lo = global_max - 30.0f;
float hi = global_max;
for (int iter = 0; iter < 12; iter++) {
float mid = (lo + hi) * 0.5f;
uint local_count = 0;
for (uint i = tid; i < V; i += 1024)
local_count += (float(logits[i]) > mid) ? 1 : 0;
// SIMD reduction
uint sg_count = simd_sum(local_count);
if (slid == 0) tg_counts[sgid] = sg_count;
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tid == 0) {
uint total = 0;
for (uint s = 0; s < 32; s++) total += tg_counts[s];
tg_counts[0] = total;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tg_counts[0] > (uint)top_k)
lo = mid; // too many survivors, raise threshold
else
hi = mid; // too few survivors, lower threshold
}
threshold = hi;
}
Each iteration of the binary search counts how many elements exceed the threshold. The count uses a parallel scan (1024 threads, each checking ~128 elements), followed by a SIMD + threadgroup reduction.
12 iterations of binary search with a range of 30.0 gives a precision of 30 / 2^12 ≈ 0.007, which is more than sufficient to separate top-K from the rest. Each iteration requires 2 barriers and one full vocabulary scan, for a total of 24 barriers and 12 scans.
Why binary search instead of sorting? Sorting 128K elements on a GPU is expensive (O(N log N)). The binary search approach is O(N * log(range/precision)) ≈ O(N * 12), which is effectively O(N) and much faster. It does not find the exact K-th element – it finds a threshold that separates approximately K elements from the rest – but this is sufficient for sampling.[1]
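Stripped of the GPU reductions, the search reduces to a few lines (Python reference sketch; `topk_threshold` is a hypothetical helper, not the kernel):

```python
def topk_threshold(logits, k, iters=12, search_range=30.0):
    """Binary-search the value domain for a threshold with ~k survivors."""
    gmax = max(logits)
    lo, hi = gmax - search_range, gmax
    for _ in range(iters):
        mid = (lo + hi) / 2
        count = sum(1 for v in logits if v > mid)
        if count > k:
            lo = mid   # too many survivors: raise the threshold
        else:
            hi = mid   # k or fewer survivors: lower the upper bound
    return hi
```

The invariant is that `hi` always admits at most k survivors, so the kernel never selects more than K tokens; it may select slightly fewer when several logits fall inside the final search interval (width 30 / 2^12).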
Phase 2b: Top-P Filtering
If top-p is enabled, the kernel performs a second binary search to find the probability threshold:
if (top_p < 1.0f && top_p > 0.0f) {
// Compute softmax sum for survivors
float local_exp_sum = 0;
for (uint i = tid; i < V; i += 1024) {
float val = float(logits[i]);
if (val >= threshold)
local_exp_sum += fast::exp(val - global_max);
}
// Reduce to get total_exp
// Binary search for probability mass threshold
float target_exp = top_p * total_exp;
float lo_p = threshold;
float hi_p = global_max;
for (int iter = 0; iter < 8; iter++) {
// Count exp sum above mid_p
// Adjust bounds
}
threshold = max(threshold, hi_p);
}
This is a binary search on the cumulative probability mass, not on raw logit values. 8 iterations are sufficient because the range is narrower (already within the top-K region).
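The elided Phase 2b logic can be filled in as a CPU reference (Python sketch; `topp_threshold` is hypothetical, and the final `max(threshold, hi)` mirrors the kernel's `threshold = max(threshold, hi_p)`, so borderline tokens at the cutoff are excluded rather than included):

```python
import math

def topp_threshold(logits, base_threshold, top_p, iters=8):
    """Binary-search a logit cutoff whose survivors hold ~top_p of the
    softmax mass of the tokens that passed the top-k threshold."""
    gmax = max(logits)
    total = sum(math.exp(v - gmax) for v in logits if v >= base_threshold)
    target = top_p * total
    lo, hi = base_threshold, gmax
    for _ in range(iters):
        mid = (lo + hi) / 2
        mass = sum(math.exp(v - gmax) for v in logits if v > mid)
        if mass > target:
            lo = mid   # survivors still hold too much mass: raise cutoff
        else:
            hi = mid
    return max(base_threshold, hi)
```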
Phase 2c: Min-P Filtering
Min-P is the simplest filter – just a threshold relative to the maximum:
if (params.min_p > 0.0f && params.min_p < 1.0f) {
float min_p_threshold = global_max + log(params.min_p);
threshold = max(threshold, min_p_threshold);
}
Since P(token) ∝ exp(logit), the condition P(token) < min_p * P(max_token) becomes logit < max_logit + log(min_p). This is a single comparison with no iteration needed.
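The identity is easy to verify numerically (Python sketch with toy values; the probability-domain and log-domain filters must keep the same tokens):

```python
import math

# P(token) >= min_p * P(max_token)  <=>  logit >= max_logit + log(min_p)
logits = [2.0, 0.0, -3.0]
min_p = 0.1
max_logit = max(logits)
rel_probs = [math.exp(v - max_logit) for v in logits]  # P(token) / P(max_token)
kept_prob = [v for v, p in zip(logits, rel_probs) if p >= min_p]
kept_log = [v for v in logits if v >= max_logit + math.log(min_p)]
```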
Phase 3: Apply Mask and Gumbel Noise
uint element_seed = (params.position + params.seed_offset) * 2654435761u;
for (uint i = tid; i < V; i += 1024) {
float val = float(logits[i]);
if (val < threshold) {
logits[i] = half(-INFINITY);
} else {
float u = pcg_float(element_seed + i);
u = clamp(u, 1e-7f, 1.0f - 1e-7f);
float gumbel = -log(-log(u));
logits[i] = half(val + temp * gumbel);
}
}
Tokens below the threshold are masked to -inf (they can never win the argmax). Surviving tokens receive Gumbel noise scaled by temperature. The argmax of these perturbed logits is equivalent to sampling from the filtered distribution.[2]
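A quick way to convince yourself of this equivalence is a CPU simulation (Python; `gumbel_argmax` is a hypothetical reference, not the kernel): with enough samples, the empirical argmax frequencies converge to the softmax probabilities.

```python
import math
import random

def gumbel_argmax(logits, rng):
    """Draw one exact categorical sample via the Gumbel-max trick."""
    noisy = []
    for v in logits:
        u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)  # guard log(0), as the kernel does
        noisy.append(v + -math.log(-math.log(u)))
    return max(range(len(logits)), key=noisy.__getitem__)

rng = random.Random(42)
logits = [1.0, 0.0, -1.0]
n = 20000
counts = [0] * len(logits)
for _ in range(n):
    counts[gumbel_argmax(logits, rng)] += 1
exps = [math.exp(v) for v in logits]
softmax = [e / sum(exps) for e in exps]
freqs = [c / n for c in counts]
```

The kernel's `val + temp * gumbel` corresponds to sampling from softmax(logits / temp): argmax(l + T·g) picks the same index as argmax(l/T + g).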
The PCG Hash RNG
inline float pcg_float(uint state) {
state = state * 747796405u + 2891336453u;
uint word = ((state >> ((state >> 28u) + 4u)) ^ state) * 277803737u;
word = (word >> 22u) ^ word;
return float(word) / 4294967296.0f;
}
PCG (Permuted Congruential Generator) is a stateless hash function: the same input always produces the same output. The constants are chosen to ensure good statistical properties:[3]
| Constant | Purpose |
|---|---|
| 747796405 | LCG multiplier (chosen for good spectral properties) |
| 2891336453 | LCG increment |
| 277803737 | Output permutation multiplier |
The output is uniformly distributed in [0, 1) after division by 2^32. The clamp to [1e-7, 1-1e-7] prevents log(0) at either end of the Gumbel computation: at u = 0 the inner log(u) diverges, and at u = 1 the outer log(-log(u)) would be log(0).
The seed (params.position + params.seed_offset) * 2654435761u + i ensures:
- Different noise per token position (via params.position, patched in chain decode)
- Different noise per vocabulary element (via + i)
- Different noise per generation call (via params.seed_offset, time-based)
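The hash ports directly to Python; the only subtlety is making the 32-bit wraparound explicit, since Metal's uint arithmetic wraps implicitly (illustrative sketch):

```python
MASK32 = 0xFFFFFFFF  # emulate 32-bit unsigned overflow

def pcg_float(state):
    """Python port of the kernel's stateless PCG hash."""
    state = (state * 747796405 + 2891336453) & MASK32
    word = (((state >> ((state >> 28) + 4)) ^ state) * 277803737) & MASK32
    word = (word >> 22) ^ word
    return word / 4294967296.0  # uniform in [0, 1)
```

Being stateless, the same input always yields the same output, which is what makes per-element, order-independent noise generation possible on the GPU.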
Kernel 3: Top-K Select (topk_select_f16)
The top-K select kernel is an alternative to the Gumbel binary search approach. It explicitly selects the top-K logits using a min-heap algorithm:
Chunked Architecture
Grid: (n_chunks, 1, 1) — one threadgroup per chunk
Threadgroup: (chunk_size, 1, 1) — typically 512 threads
The vocabulary is divided into chunks, and each threadgroup finds its local top-K. The host then merges per-chunk results on the CPU.
Thread-Local Best
for (uint i = tid; i < chunk_size; i += tg_size) {
uint global_idx = base + i;
if (global_idx >= vocab_size) break;
float v = float(logits[global_idx]);
if (v > my_best_val) {
my_best_val = v;
my_best_idx = global_idx;
}
}
Each thread finds its single best candidate from a strided scan.
Min-Heap Selection
// Thread 0 picks top_k from tg_size candidates via min-heap
float heap_vals[MAX_TOP_K]; // MAX_TOP_K = 128
uint heap_idxs[MAX_TOP_K];
uint heap_size = 0;
for (uint s = 0; s < tg_size; ++s) {
float v = tg_vals[s];
uint gi = tg_idxs[s];
if (heap_size < k) {
// Insert and bubble up
heap_vals[heap_size] = v;
heap_idxs[heap_size] = gi;
heap_size++;
uint pos = heap_size - 1;
while (pos > 0) {
uint parent = (pos - 1) / 2;
if (heap_vals[parent] <= heap_vals[pos]) break; // min-heap property holds
// swap with parent
float tv = heap_vals[parent]; heap_vals[parent] = heap_vals[pos]; heap_vals[pos] = tv;
uint ti = heap_idxs[parent]; heap_idxs[parent] = heap_idxs[pos]; heap_idxs[pos] = ti;
pos = parent;
}
} else if (v > heap_vals[0]) {
// Replace min (heap root) and sift down
heap_vals[0] = v;
heap_idxs[0] = gi;
uint pos = 0;
while (true) {
uint l = 2*pos + 1, r = 2*pos + 2;
uint smallest = pos;
if (l < heap_size && heap_vals[l] < heap_vals[smallest]) smallest = l;
if (r < heap_size && heap_vals[r] < heap_vals[smallest]) smallest = r;
if (smallest == pos) break;
// swap with the smaller child
float tv = heap_vals[smallest]; heap_vals[smallest] = heap_vals[pos]; heap_vals[pos] = tv;
uint ti = heap_idxs[smallest]; heap_idxs[smallest] = heap_idxs[pos]; heap_idxs[pos] = ti;
pos = smallest;
}
}
}
This is a classic min-heap of size K. The heap root always holds the smallest of the K largest values seen so far. When a new candidate exceeds the root, it replaces the root and is sifted down to maintain the heap property.
The complexity is O(tg_size * log(K)) per chunk. With tg_size=512 and K=128, this is about 512 * 7 = 3584 comparisons – fast enough for a single thread.
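The same selection strategy as a CPU reference (Python's heapq stands in for the hand-rolled heap; `topk_minheap` is an illustrative sketch, not the kernel):

```python
import heapq

def topk_minheap(vals, k):
    """Exact top-k selection with a size-k min-heap.
    heap[0] is always the smallest of the k largest values seen so far."""
    heap = []  # (value, index) pairs
    for i, v in enumerate(vals):
        if len(heap) < k:
            heapq.heappush(heap, (v, i))
        elif v > heap[0][0]:
            heapq.heapreplace(heap, (v, i))  # replace root, sift down
    return sorted(heap, reverse=True)
```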
Comparison with Gumbel Binary Search
| Aspect | Top-K Select | Gumbel Binary Search |
|---|---|---|
| Output | Exact top-K indices + values | Threshold value |
| CPU post-processing | Merge per-chunk results | None (followed by argmax) |
| Chain decode compatible | No (CPU merge needed) | Yes |
| Use case | CPU sampling path | GPU sampling path |
The top-K select kernel produces exact results but requires CPU post-processing, making it unsuitable for chain decode. The Gumbel binary search is approximate (the threshold may not select exactly K elements) but keeps everything on the GPU.
Kernel 4: Repetition Penalty (repetition_penalty_f16)
The standalone repetition penalty kernel applies the penalty to specific token positions in the logit buffer:
kernel void repetition_penalty_f16(
device half *logits, device const uint32_t *token_ids,
constant float &penalty, constant uint32_t &n_tokens,
uint tid [[thread_position_in_grid]]
) {
if (tid >= n_tokens) return;
uint32_t token = token_ids[tid];
float val = float(logits[token]);
logits[token] = half(val > 0 ? val / penalty : val * penalty);
}
Each thread handles one previous token. The dispatch is ceil(n_tokens / threadgroup_size) threadgroups. This kernel is used in the non-chain-decode path (when repetition penalty is applied as a separate step). In the Gumbel kernel, the penalty is applied inline (Phase 0) to avoid an extra dispatch.
Kernel 5: Temperature Scale (temperature_scale_f16)
kernel void temperature_scale_f16(
device half *logits, constant float &inv_temperature,
constant uint32_t &count, uint tid [[thread_position_in_grid]]
) {
if (tid >= count) return;
logits[tid] = half(float(logits[tid]) * inv_temperature);
}
A simple element-wise multiply used for:
- Temperature scaling before sampling (standalone path)
- Embedding scaling in prefill (Gemma)
- Any generic “scale all elements” operation
Putting It All Together: The GPU Sampling Pipeline
In chain decode with sampling enabled, the last two commands in the dispatch table are:
Command N-1: gumbel_topk_f16
- Reads: logits buffer, token_ids (previous tokens)
- Writes: logits buffer (in-place: mask + noise)
- Params: vocab_size, temperature, position (PATCHED), seed_offset, top_k, top_p, repeat_penalty, min_p
Command N: argmax_f16
- Reads: logits buffer (now with Gumbel noise)
- Writes: token_ids[tok+1] (PATCHED output offset)
The position field in the Gumbel kernel and the output offset in the argmax kernel are patched per-token by the dispatch table mechanism, ensuring each token in the chain gets:
- A unique RNG seed (different position)
- Its own output slot in the token_ids buffer
The rest of the chain decode infrastructure (embedding lookup, layer loop, etc.) is identical to the greedy path. The only difference is two extra kernel dispatches per token.
Summary
Akunu’s sampling kernels demonstrate that sophisticated sampling algorithms can run entirely on the GPU:
- Argmax: Two-level butterfly reduction (SIMD + threadgroup) for O(N) global maximum with 1024 threads.
- Gumbel Top-K: Binary search on value domain for approximate top-K (O(N * iterations)), followed by Gumbel noise and argmax for exact sampling. PCG hash provides stateless RNG.
- Top-K Select: Explicit min-heap per chunk for exact top-K, used by the CPU sampling path.
- Repetition Penalty: Simple per-token penalty application, both standalone and inline.
The Gumbel-max approach is the key enabler for chain decode with sampling: by reducing sampling to noise + argmax, it eliminates the CPU roundtrip that would otherwise break the chain.
Performance Characteristics
The sampling kernels are all dispatched with a single threadgroup, making them latency-sensitive rather than throughput-sensitive. Here is the approximate timing breakdown:
| Kernel | Threads | Iterations | Time (128K vocab) | Bottleneck |
|---|---|---|---|---|
| argmax_f16 | 1024 | 1 scan + 2 reductions | ~5-8 us | Memory bandwidth |
| gumbel_topk_f16 | 1024 | 1 + 12 + 8 + 1 scans | ~30-50 us | Memory bandwidth |
| repetition_penalty_f16 | varies | 1 scan of prev tokens | ~1-5 us | Memory latency |
| temperature_scale_f16 | varies | 1 scan | ~3-5 us | Memory bandwidth |
The Gumbel kernel is the most expensive at ~30-50 microseconds, but this is negligible compared to the ~8ms forward pass. Even in chain decode with 128 tokens, the total sampling overhead is only ~4-6ms (128 * 35us), or less than 1% of the total generation time.
Statistical Quality of GPU Sampling
A natural concern: does the GPU sampling path produce the same distribution as the CPU path? Mathematically, yes – the Gumbel-max trick provides exact samples from the categorical distribution, not an approximation.[4] However, there are practical differences:
- RNG quality: The CPU uses a Mersenne Twister (std::mt19937) with 19937 bits of state, while the GPU uses a PCG hash with effectively 32 bits of state per element. The PCG hash has been validated against TestU01's BigCrush battery, but its per-element independence relies on the hash quality.
- Top-K precision: The CPU path uses exact top-K via std::partial_sort, while the GPU path uses binary search to find an approximate threshold. The binary search may select slightly more or fewer than K elements, but the difference is bounded by the search precision (K +/- 1 at the threshold boundary).
- Top-P precision: Similarly, the GPU's binary search for the top-P threshold has 8 iterations of precision (~0.03 logit units), which may include or exclude borderline tokens that the CPU path would handle differently.
In practice, these differences are undetectable in the output quality. The sampling distribution is dominated by the high-probability tokens (which are always included by both paths), and the borderline tokens near the cutoff thresholds have negligible probability mass.
Grammar Bitmask Kernel
While not in the sampling/ directory, it is worth mentioning the grammar_bitmask.metal kernel that can apply grammar constraints on the GPU:
grammar_bitmask.metal: Apply a precomputed bitmask to logits
- Sets logits[i] = -inf where bitmask bit i is 0
- Used for XGrammar integration on the GPU path
This kernel enables a hybrid approach: the grammar state machine runs on the CPU to compute the bitmask, but the bitmask application (which touches every logit) runs on the GPU. For vocabularies of 128K+, this saves a significant amount of CPU memory bandwidth.
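The masking step itself is trivial; a CPU sketch of the idea (Python; the 32-bit packing convention shown — token i at bit i % 32 of word i // 32 — is an assumption for illustration, since the actual layout is defined by the XGrammar integration):

```python
NEG_INF = float("-inf")

def apply_bitmask(logits, mask_words):
    """Mask logits to -inf wherever the packed bitmask has a 0 bit."""
    for i in range(len(logits)):
        if not (mask_words[i >> 5] >> (i & 31)) & 1:
            logits[i] = NEG_INF  # token forbidden by the grammar
    return logits
```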
[1] The binary search approach for GPU top-K was pioneered in approximate nearest neighbor search. See: Johnson, J., Douze, M., and Jegou, H. "Billion-Scale Similarity Search with GPUs." IEEE Transactions on Big Data, 2019. The same principle applies to logit filtering: we do not need the exact K-th element, just a threshold that selects approximately K candidates.

[2] Gumbel, E.J. "Statistical Theory of Extreme Values and Some Practical Applications." National Bureau of Standards Applied Mathematics Series 33, 1954. The Gumbel-max trick for exact categorical sampling is proven in: Maddison, C.J., Tarlow, D., and Minka, T. "A* Sampling." NeurIPS 2014. See https://arxiv.org/abs/1411.0030.

[3] O'Neill, M.E. "PCG: A Family of Simple Fast Space-Efficient Statistically Good Algorithms for Random Number Generation." Harvey Mudd College Technical Report HMC-CS-2014-0905, 2014. The PCG hash provides excellent statistical properties (passing TestU01's BigCrush) with minimal state, making it ideal for GPU kernels. See https://www.pcg-random.org/paper.html.

[4] The exactness of the Gumbel-max trick is proven in Theorem 1 of Maddison et al. (2014). The proof shows that for any discrete distribution pi, argmax(log(pi_i) + G_i) where G_i are i.i.d. Gumbel(0,1) random variables yields a sample exactly distributed as Categorical(pi). No approximation is involved.