
Sampling Kernels

The previous chapter on the sampling pipeline described the algorithmic flow from logits to tokens. This chapter dives into the Metal kernels that implement sampling on the GPU: the Top-K selector, the Gumbel-max sampler, the argmax reduction, and the repetition penalty kernel. These kernels live in backend/metal/kernels/metal/kernel/sampling/.

The GPU sampling path is critical for Akunu’s chain decode performance: by keeping sampling entirely on the GPU, the engine avoids a CPU roundtrip for every generated token. The Gumbel-max kernel is particularly important because it lets sampled generation run at nearly the same throughput as greedy decoding.

Kernel 1: Argmax (argmax_f16)

Argmax is the simplest sampling kernel: find the index of the maximum value in the logit buffer. Despite its simplicity, doing it efficiently on a GPU requires a two-level reduction.

Dispatch

Grid: (1, 1, 1) — single threadgroup
Threadgroup: (1024, 1, 1) — 1024 threads = 32 SIMD groups

Phase 1: Thread-Local Scan

float best_val = -INFINITY;
uint  best_idx = 0;
for (uint i = tid; i < N; i += tg_size) {
    float v = float(logits[i]);
    if (v > best_val) { best_val = v; best_idx = i; }
}

Each thread scans a strided subset of the vocabulary. For a 128K vocabulary with 1024 threads, each thread examines 128 elements.

Phase 2: SIMD Reduction

for (uint offset = SIMD_WIDTH / 2; offset > 0; offset >>= 1) {
    float other_val = simd_shuffle_down(best_val, offset);
    uint  other_idx = simd_shuffle_down(best_idx, offset);
    if (slid + offset < SIMD_WIDTH && other_val > best_val) {
        best_val = other_val;
        best_idx = other_idx;
    }
}

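This is a classic tree reduction within a SIMD group, built on simd_shuffle_down. A true butterfly would use XOR shuffles and leave the result in every lane. Each step halves the number of lanes that still hold live candidates: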

Step 1: lanes 0-15 compare with lanes 16-31  (offset=16)
Step 2: lanes 0-7  compare with lanes 8-15   (offset=8)
Step 3: lanes 0-3  compare with lanes 4-7    (offset=4)
Step 4: lanes 0-1  compare with lanes 2-3    (offset=2)
Step 5: lane 0     compares with lane 1      (offset=1)

After 5 steps, lane 0 of each SIMD group holds the group’s local winner.

Phase 3: Cross-SIMD-Group Reduction

threadgroup float shared_val[32];
threadgroup uint  shared_idx[32];

if (slid == 0) {
    shared_val[sgid] = best_val;
    shared_idx[sgid] = best_idx;
}
threadgroup_barrier(mem_flags::mem_threadgroup);

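if (sgid == 0) {
    // Lanes past the last SIMD group load a sentinel so that every lane
    // takes part in the shuffles (n_sg is 32 at 1024 threads)
    best_val = (slid < n_sg) ? shared_val[slid] : -INFINITY;
    best_idx = (slid < n_sg) ? shared_idx[slid] : 0;
    // Same shuffle-down reduction on the SIMD-group winners
    for (uint offset = SIMD_WIDTH / 2; offset > 0; offset >>= 1) {
        // ...
    }
    if (slid == 0) *result = best_idx;
}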

The SIMD group winners (32 of them at 1024 threads) are written to threadgroup shared memory. A single SIMD group then performs the same shuffle-down reduction on them, and its lane 0 writes the overall winner's index to the result buffer.

Total work: One strided scan of the vocabulary + two levels of reduction. For 128K vocabulary, this takes approximately 5-10 microseconds on Apple Silicon.
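The three phases can be checked end to end on the host. The following C++ sketch is an emulation under stated assumptions, not Akunu code. Sequential loops stand in for the 1024 threads, and thread tid belongs to SIMD group tid / 32. The hypothetical `gpu_style_argmax` returns the index the kernel would write to its result buffer.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

constexpr uint32_t kThreads = 1024;  // one threadgroup, as in the dispatch
constexpr uint32_t kSimdWidth = 32;  // lanes per SIMD group

// Hypothetical CPU emulation of argmax_f16.
uint32_t gpu_style_argmax(const std::vector<float>& logits) {
    const uint32_t n = static_cast<uint32_t>(logits.size());
    const uint32_t n_sg = kThreads / kSimdWidth;

    // Phase 1: thread tid scans elements tid, tid + 1024, tid + 2048, ...
    std::vector<float> val(kThreads, -INFINITY);
    std::vector<uint32_t> idx(kThreads, 0);
    for (uint32_t tid = 0; tid < kThreads; ++tid)
        for (uint32_t i = tid; i < n; i += kThreads)
            if (logits[i] > val[tid]) { val[tid] = logits[i]; idx[tid] = i; }

    // Shuffle-down tree over the 32 lanes starting at `base`. Visiting lanes in
    // ascending order means lane l reads lane l + off before lane l + off is
    // updated, which matches the simultaneous reads of simd_shuffle_down.
    auto reduce = [](std::vector<float>& v, std::vector<uint32_t>& x, uint32_t base) {
        for (uint32_t off = kSimdWidth / 2; off > 0; off >>= 1)
            for (uint32_t l = 0; l + off < kSimdWidth; ++l)
                if (v[base + l + off] > v[base + l]) {
                    v[base + l] = v[base + l + off];
                    x[base + l] = x[base + l + off];
                }
    };

    // Phase 2: reduce each SIMD group; lane 0 ends up with the group winner.
    for (uint32_t sg = 0; sg < n_sg; ++sg) reduce(val, idx, sg * kSimdWidth);

    // Phase 3: copy the group winners to "shared memory" and reduce once more.
    std::vector<float> shared_val(kSimdWidth, -INFINITY);
    std::vector<uint32_t> shared_idx(kSimdWidth, 0);
    for (uint32_t sg = 0; sg < n_sg; ++sg) {
        shared_val[sg] = val[sg * kSimdWidth];
        shared_idx[sg] = idx[sg * kSimdWidth];
    }
    reduce(shared_val, shared_idx, 0);
    return shared_idx[0];
}
```

The emulation uses the same strict `>` comparisons as the kernel, so it also reproduces the tie behavior. On equal maxima, the winner is the candidate from the lowest-numbered thread, which need not be the lowest vocabulary index.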

Kernel 2: Gumbel Top-K (gumbel_topk_f16)

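The Gumbel-max kernel is the most complex sampling kernel. In a single dispatch it applies the repetition penalty, the top-K, top-P, and min-P filters, and the Gumbel noise. The final token is then picked by a separate argmax_f16 dispatch: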

Dispatch: grid=(1), threadgroup=(1024)

One threadgroup of 1024 threads processes the entire vocabulary. The kernel executes four phases:

Phase 0: Repetition Penalty

float rep_penalty = params.repeat_penalty;
if (rep_penalty > 1.0f && params.position > 0) {
    for (uint i = tid; i < params.position; i += 1024) {
        uint token = token_ids[i + 1];
        if (token < V) {
            float val = float(logits[token]);
            logits[token] = half(val > 0 ? val / rep_penalty : val * rep_penalty);
        }
    }
    threadgroup_barrier(mem_flags::mem_device);  // logits live in device memory, so order device writes
}

Each thread processes a strided subset of previously generated tokens, applying the same asymmetric penalty as the CPU path (divide positive logits, multiply negative logits by the penalty factor).
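A CPU reference for the asymmetric rule is handy when comparing the two paths. This C++ sketch uses float instead of half, and `apply_repetition_penalty` is a hypothetical name. It walks the history sequentially, so a token that occurs several times is penalized once per occurrence.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper: asymmetric repetition penalty, as in Phase 0.
// With penalty > 1, dividing a positive logit and multiplying a negative one
// both lower the logit, so the token always becomes less likely.
void apply_repetition_penalty(std::vector<float>& logits,
                              const std::vector<uint32_t>& history,
                              float penalty) {
    if (penalty <= 1.0f) return;               // 1.0 means "disabled"
    for (uint32_t token : history) {
        if (token >= logits.size()) continue;  // same bounds check as the kernel
        const float v = logits[token];
        logits[token] = v > 0 ? v / penalty : v * penalty;
    }
}
```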

Phase 1: Find Global Maximum

float local_max = -INFINITY;
for (uint i = tid; i < V; i += 1024)
    local_max = max(local_max, float(logits[i]));

local_max = simd_max(local_max);
if (slid == 0) tg_vals[sgid] = local_max;
threadgroup_barrier(mem_flags::mem_threadgroup);

if (tid == 0) {
    float m = tg_vals[0];
    for (uint s = 1; s < 32; s++) m = max(m, tg_vals[s]);
    tg_vals[0] = m;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
float global_max = tg_vals[0];

A standard two-level max reduction: SIMD-first with simd_max, then cross-SIMD via threadgroup memory.

Phase 2: Binary Search for K-th Largest

This is the most innovative part of the kernel. Instead of sorting the entire vocabulary to find the top-K, it uses binary search on the value domain to find a threshold above which approximately K elements survive:

  1. Init: lo = global_max - 30, hi = global_max
  2. Iteration 1: mid = (lo+hi)/2 – count elements > mid (1024 threads in parallel). If count > K: lo = mid, else: hi = mid
  3. Iteration 2: Narrower range, count again, adjust bounds
  4. 12 iterations total (2^12 = 4096x precision refinement)
  5. Result: threshold separates top-K from rest
float threshold = global_max;
if (top_k > 1) {
    float lo = global_max - 30.0f;
    float hi = global_max;

    for (int iter = 0; iter < 12; iter++) {
        float mid = (lo + hi) * 0.5f;

        uint local_count = 0;
        for (uint i = tid; i < V; i += 1024)
            local_count += (float(logits[i]) > mid) ? 1 : 0;

        // SIMD reduction
        uint sg_count = simd_sum(local_count);
        if (slid == 0) tg_counts[sgid] = sg_count;
        threadgroup_barrier(mem_flags::mem_threadgroup);

        if (tid == 0) {
            uint total = 0;
            for (uint s = 0; s < 32; s++) total += tg_counts[s];
            tg_counts[0] = total;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        uint total = tg_counts[0];
        // Every thread must read the total before SIMD group 0 overwrites
        // tg_counts[0] in the next iteration
        threadgroup_barrier(mem_flags::mem_threadgroup);

        if (total > (uint)top_k)
            lo = mid;  // too many survivors, raise threshold
        else
            hi = mid;  // at most K survivors, lower threshold
    }
    threshold = hi;
}

Each iteration of the binary search counts how many elements exceed the midpoint. The count uses a strided pass over the vocabulary (1024 threads, each checking ~128 elements), followed by a SIMD + threadgroup reduction.

12 iterations of binary search over a range of 30.0 give a precision of 30 / 2^12 ≈ 0.007, which is more than sufficient to separate top-K from the rest. Each iteration requires 3 barriers (two for the reduction, one before tg_counts is reused) and one full vocabulary scan, for a total of 36 barriers and 12 scans.

Why binary search instead of sorting? Sorting 128K elements on a GPU is expensive (O(N log N)). The binary search approach is O(N * log(range/precision)) ≈ O(N * 12), which is effectively O(N) and much faster. It does not find the exact K-th element – it finds a threshold that separates approximately K elements from the rest – but this is sufficient for sampling.[1]
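Here is the same value-domain search, run sequentially on the host as a C++ sketch. It mirrors the kernel's constants (a 30-logit window and 12 iterations) and the Phase 3 rule that keeps `v >= threshold`. The helper names `topk_threshold` and `survivors` are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: bisection on the value domain, mirroring Phase 2.
// Returns a threshold t such that about top_k logits satisfy v >= t.
float topk_threshold(const std::vector<float>& logits, int top_k) {
    const float global_max = *std::max_element(logits.begin(), logits.end());
    if (top_k <= 1) return global_max;
    float lo = global_max - 30.0f;  // logits more than 30 below the max are never kept
    float hi = global_max;
    for (int iter = 0; iter < 12; ++iter) {
        const float mid = (lo + hi) * 0.5f;
        std::size_t count = 0;  // elements strictly above mid
        for (float v : logits) count += (v > mid) ? 1 : 0;
        if (count > static_cast<std::size_t>(top_k)) lo = mid;  // too many: raise
        else hi = mid;                                           // at most K: lower
    }
    return hi;
}

// Hypothetical helper: tokens that survive the Phase 3 mask (v >= threshold).
std::size_t survivors(const std::vector<float>& logits, float threshold) {
    return static_cast<std::size_t>(std::count_if(
        logits.begin(), logits.end(), [threshold](float v) { return v >= threshold; }));
}
```

Take logits spaced 0.1 apart and K = 50. The loop invariant guarantees more than K elements above `lo` and at most K above `hi`. Because the final interval is narrower than the spacing, the threshold keeps 50 tokens, or 51 if a logit coincides with the final upper bound. That small discrepancy is the approximation the text describes.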

Phase 2b: Top-P Filtering

If top-p is enabled, the kernel performs a second binary search to find the probability threshold:

if (top_p < 1.0f && top_p > 0.0f) {
    // Compute softmax sum for survivors
    float local_exp_sum = 0;
    for (uint i = tid; i < V; i += 1024) {
        float val = float(logits[i]);
        if (val >= threshold)
            local_exp_sum += fast::exp(val - global_max);
    }
    // Reduce to get total_exp
    // Binary search for probability mass threshold
    float target_exp = top_p * total_exp;
    float lo_p = threshold;
    float hi_p = global_max;
    for (int iter = 0; iter < 8; iter++) {
        // Count exp sum above mid_p
        // Adjust bounds
    }
    threshold = max(threshold, hi_p);
}

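The search variable is still a logit cutoff (lo_p and hi_p are logit values). However, each step compares the probability mass above the candidate cutoff, measured as a sum of exp(val - global_max), with top_p times the total, instead of counting elements. 8 iterations are sufficient because the range is narrower (already within the top-K region).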

Phase 2c: Min-P Filtering

Min-P is the simplest filter – just a threshold relative to the maximum:

if (params.min_p > 0.0f && params.min_p < 1.0f) {
    float min_p_threshold = global_max + log(params.min_p);
    threshold = max(threshold, min_p_threshold);
}

Since P(token) ∝ exp(logit), the condition P(token) < min_p * P(max_token) becomes logit < max_logit + log(min_p). This is a single comparison with no iteration needed.
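The equivalence is easy to check numerically. This C++ sketch computes the min-p keep/drop decision two ways: from softmax probabilities, and from the single logit cutoff max + log(min_p). Like the derivation, it assumes probabilities proportional to exp(logit), which means temperature 1. In the kernel, the cutoff is likewise applied to the raw logits before the temperature-scaled noise is added. The function names are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

// Hypothetical helper: keep token i if p_i >= min_p * p_max, with p = softmax(logits).
std::vector<bool> min_p_by_prob(const std::vector<double>& logits, double min_p) {
    const double mx = *std::max_element(logits.begin(), logits.end());
    double z = 0.0;
    for (double v : logits) z += std::exp(v - mx);  // softmax normalizer
    const double p_max = 1.0 / z;                   // probability of the top token
    std::vector<bool> keep;
    for (double v : logits) keep.push_back(std::exp(v - mx) / z >= min_p * p_max);
    return keep;
}

// Hypothetical helper: the same rule as one comparison, v >= max + log(min_p).
std::vector<bool> min_p_by_logit(const std::vector<double>& logits, double min_p) {
    const double cutoff = *std::max_element(logits.begin(), logits.end()) + std::log(min_p);
    std::vector<bool> keep;
    for (double v : logits) keep.push_back(v >= cutoff);
    return keep;
}
```

Take logits {5, 4, 3, 2, 1, 0, -1} and min_p = 0.1. The cutoff is 5 + log(0.1) ≈ 2.70, so both functions keep the three tokens with logits 5, 4, and 3. For example, the token with logit 3 has relative probability e^-2 ≈ 0.135, which is at least 0.1.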

Phase 3: Apply Mask and Gumbel Noise

uint element_seed = (params.position + params.seed_offset) * 2654435761u;
for (uint i = tid; i < V; i += 1024) {
    float val = float(logits[i]);
    if (val < threshold) {
        logits[i] = half(-INFINITY);
    } else {
        float u = pcg_float(element_seed + i);
        u = clamp(u, 1e-7f, 1.0f - 1e-7f);
        float gumbel = -log(-log(u));
        logits[i] = half(val + temp * gumbel);
    }
}

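Tokens below the threshold are masked to -inf, so they can never win the argmax. Surviving tokens receive Gumbel noise scaled by the temperature T. Because val + T * g = T * (val / T + g) and T > 0, this has the same argmax as adding unit Gumbel noise to val / T. The argmax of the perturbed logits is therefore a sample from the filtered distribution at temperature T, that is, softmax(val / T) restricted to the survivors.[2]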
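A small C++ experiment makes the Gumbel-max claim concrete. It perturbs fixed logits with temperature-scaled Gumbel noise, takes the argmax many times, and compares the frequencies with softmax(logits / T). This is a sketch, not Akunu code: std::mt19937 replaces the kernel's PCG hash, doubles replace half, and `gumbel_max_sample` and `softmax` are hypothetical helpers.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <random>
#include <vector>

// Hypothetical helper: one Gumbel-max draw, argmax_i (logit_i + T * g_i)
// with g_i ~ Gumbel(0, 1) generated as -log(-log(u)), u ~ Uniform(0, 1).
std::size_t gumbel_max_sample(const std::vector<double>& logits, double temp,
                              std::mt19937& rng) {
    std::uniform_real_distribution<double> uni(0.0, 1.0);
    std::size_t best = 0;
    double best_val = -INFINITY;
    for (std::size_t i = 0; i < logits.size(); ++i) {
        const double u = std::clamp(uni(rng), 1e-12, 1.0 - 1e-12);  // avoid log(0)
        const double perturbed = logits[i] + temp * -std::log(-std::log(u));
        if (perturbed > best_val) { best_val = perturbed; best = i; }
    }
    return best;
}

// Hypothetical helper: the reference distribution softmax(logits / T).
std::vector<double> softmax(const std::vector<double>& logits, double temp) {
    const double mx = *std::max_element(logits.begin(), logits.end());
    std::vector<double> p;
    double z = 0.0;
    for (double v : logits) { p.push_back(std::exp((v - mx) / temp)); z += p.back(); }
    for (double& x : p) x /= z;
    return p;
}
```

With a few hundred thousand draws, the empirical frequencies should match softmax(logits / T) to within sampling error. No explicit normalization or cumulative sum appears anywhere in the sampler, which is what makes this approach attractive on the GPU.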

The PCG Hash RNG

inline float pcg_float(uint state) {
    state = state * 747796405u + 2891336453u;
    uint word = ((state >> ((state >> 28u) + 4u)) ^ state) * 277803737u;
    word = (word >> 22u) ^ word;
    return float(word) / 4294967296.0f;
}

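This is a stateless hash built from PCG (Permuted Congruential Generator). It applies one LCG step followed by PCG's RXS-M-XS output permutation (a random xorshift, a multiply, and a fixed xorshift), so the same input always produces the same output. The constants are chosen for good statistical properties:[3]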

| Constant | Purpose |
|---|---|
| 747796405 | LCG multiplier (chosen for good spectral properties) |
| 2891336453 | LCG increment |
| 277803737 | Output permutation multiplier |

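Dividing by 2^32 maps the output into [0, 1], approximately uniformly. The upper end is reachable: float(word) keeps only 24 significant bits, so with round-to-nearest conversion, values of word within 128 of 2^32 round up to exactly 4294967296.0f. The clamp to [1e-7, 1-1e-7] therefore guards both ends of the Gumbel computation -log(-log(u)). It prevents log(0) when u = 0 and log(-0), that is log(-log(1)), when u = 1; either would make the noise infinite.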
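For testing on the host, the hash ports directly to C++. Its uint32_t arithmetic wraps modulo 2^32 just as Metal's uint does. This is a sketch for checking properties such as determinism and range, not code taken from Akunu; the `gumbel_from_seed` helper is hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>

// C++ port of pcg_float: one LCG step, then the RXS-M-XS output permutation.
inline float pcg_float(uint32_t state) {
    state = state * 747796405u + 2891336453u;
    uint32_t word = ((state >> ((state >> 28u) + 4u)) ^ state) * 277803737u;
    word = (word >> 22u) ^ word;
    return static_cast<float>(word) / 4294967296.0f;  // can round up to 1.0f
}

// Hypothetical helper: Gumbel(0, 1) noise from a seed, clamped like the kernel.
inline float gumbel_from_seed(uint32_t seed) {
    const float u = std::clamp(pcg_float(seed), 1e-7f, 1.0f - 1e-7f);
    return -std::log(-std::log(u));
}
```

Every step of the hash (the LCG step, the xorshifts, and the multiply by an odd constant) is invertible on 32-bit words, so distinct seeds always map to distinct words before the float conversion.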

The seed (params.position + params.seed_offset) * 2654435761u + i ensures the following (2654435761 is the prime close to 2^32/φ used in Knuth's multiplicative hashing):

  • Different noise per token position (via params.position, patched in chain decode)
  • Different noise per vocabulary element (via + i)
  • Different noise per generation call (via params.seed_offset, time-based)
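The seed arithmetic can also be checked on the host. This C++ sketch reproduces the formula with 32-bit wraparound. It then verifies that two consecutive positions never hand the same seed to any pair of vocabulary elements in a 128K vocabulary. The names `element_seed` and `adjacent_positions_disjoint` are hypothetical.

```cpp
#include <cassert>
#include <cstdint>
#include <unordered_set>

// Hypothetical helper: seed for vocabulary element i, as in Phase 3:
// (position + seed_offset) * 2654435761 + i, all modulo 2^32.
inline uint32_t element_seed(uint32_t position, uint32_t seed_offset, uint32_t i) {
    return (position + seed_offset) * 2654435761u + i;
}

// Hypothetical helper: true if positions p and p + 1 share no seeds across
// the first `vocab` elements.
bool adjacent_positions_disjoint(uint32_t p, uint32_t seed_offset, uint32_t vocab) {
    std::unordered_set<uint32_t> seen;
    for (uint32_t i = 0; i < vocab; ++i) seen.insert(element_seed(p, seed_offset, i));
    for (uint32_t i = 0; i < vocab; ++i)
        if (seen.count(element_seed(p + 1, seed_offset, i)) != 0) return false;
    return true;
}
```

One property follows directly from the formula: position and seed_offset enter only through their sum. The pairs (position, offset) and (position + 1, offset - 1) therefore receive identical noise. A time-based offset makes such collisions across calls unlikely, but not impossible.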

Kernel 3: Top-K Select (topk_select_f16)

The top-K select kernel is an alternative to the Gumbel binary search approach. It explicitly selects the top-K logits using a min-heap algorithm:

Chunked Architecture

Grid: (n_chunks, 1, 1) — one threadgroup per chunk
Threadgroup: (chunk_size, 1, 1) — typically 512 threads

The vocabulary is divided into chunks, and each threadgroup finds its local top-K. The host then merges per-chunk results on the CPU.
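The host-side merge is not shown in this chapter. Here is a plausible C++ sketch, assuming each chunk returns up to K (value, index) candidates: concatenate them and keep the K best with std::partial_sort. The `Candidate` struct and `merge_chunks` function are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical per-chunk result entry.
struct Candidate {
    float val;
    uint32_t idx;  // global vocabulary index
};

// Hypothetical helper: merges per-chunk top-K lists into the global top-K,
// largest first. This is exact because every element of the global top-K
// is also in its own chunk's top-K.
std::vector<Candidate> merge_chunks(const std::vector<std::vector<Candidate>>& chunks,
                                    std::size_t k) {
    std::vector<Candidate> all;
    for (const auto& c : chunks) all.insert(all.end(), c.begin(), c.end());
    k = std::min(k, all.size());
    std::partial_sort(all.begin(), all.begin() + static_cast<std::ptrdiff_t>(k), all.end(),
                      [](const Candidate& a, const Candidate& b) { return a.val > b.val; });
    all.resize(k);
    return all;
}
```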

Thread-Local Best

for (uint i = tid; i < chunk_size; i += tg_size) {
    uint global_idx = base + i;
    if (global_idx >= vocab_size) break;
    float v = float(logits[global_idx]);
    if (v > my_best_val) {
        my_best_val = v;
        my_best_idx = global_idx;
    }
}

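Each thread keeps only its single best candidate from a strided scan. This is still exact because the threadgroup has chunk_size threads, so each thread sees one element and the candidate set is the whole chunk. With fewer threads than chunk elements, a thread holding two of the chunk's top-K values would keep only one of them.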

Min-Heap Selection

// Thread 0 picks top_k from tg_size candidates via min-heap
float heap_vals[MAX_TOP_K];  // MAX_TOP_K = 128
uint  heap_idxs[MAX_TOP_K];
uint  heap_size = 0;

for (uint s = 0; s < tg_size; ++s) {
    float v = tg_vals[s];
    uint gi = tg_idxs[s];
    if (heap_size < k) {
        // Insert at the end and bubble up while the parent is larger
        heap_vals[heap_size] = v;
        heap_idxs[heap_size] = gi;
        heap_size++;
        uint pos = heap_size - 1;
        while (pos > 0) {
            uint parent = (pos - 1) / 2;
            if (heap_vals[parent] <= heap_vals[pos]) break;  // min-heap order holds
            float tv = heap_vals[parent]; heap_vals[parent] = heap_vals[pos]; heap_vals[pos] = tv;
            uint  ti = heap_idxs[parent]; heap_idxs[parent] = heap_idxs[pos]; heap_idxs[pos] = ti;
            pos = parent;
        }
    } else if (v > heap_vals[0]) {
        // Replace min (heap root) and sift down
        heap_vals[0] = v;
        heap_idxs[0] = gi;
        uint pos = 0;
        while (true) {
            uint l = 2*pos + 1, r = 2*pos + 2;
            uint smallest = pos;
            if (l < heap_size && heap_vals[l] < heap_vals[smallest]) smallest = l;
            if (r < heap_size && heap_vals[r] < heap_vals[smallest]) smallest = r;
            if (smallest == pos) break;
            float tv = heap_vals[smallest]; heap_vals[smallest] = heap_vals[pos]; heap_vals[pos] = tv;
            uint  ti = heap_idxs[smallest]; heap_idxs[smallest] = heap_idxs[pos]; heap_idxs[pos] = ti;
            pos = smallest;
        }
    }
}

This is a classic min-heap of size K. The heap root always holds the smallest of the K largest values seen so far. When a new candidate exceeds the root, it replaces the root and is sifted down to maintain the heap property.

The complexity is O(tg_size * log(K)) per chunk. With tg_size=512 and K=128, this is about 512 * 7 = 3584 comparisons – fast enough for a single thread.
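A host-side C++ version of the same size-K min-heap is useful for testing the kernel's output. It follows the structure above: insert and bubble up while the heap is filling, then replace the root and sift down once it is full. It uses std::vector instead of fixed arrays, and `heap_topk` is a hypothetical name.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical helper: indices of the k largest values (heap order, unsorted).
std::vector<uint32_t> heap_topk(const std::vector<float>& vals, std::size_t k) {
    std::vector<float> hv;       // heap values; hv[0] is the smallest kept value
    std::vector<uint32_t> hidx;  // matching vocabulary indices
    for (uint32_t s = 0; s < vals.size(); ++s) {
        const float v = vals[s];
        if (hv.size() < k) {
            hv.push_back(v);
            hidx.push_back(s);
            std::size_t pos = hv.size() - 1;
            while (pos > 0) {                          // bubble up
                const std::size_t parent = (pos - 1) / 2;
                if (hv[parent] <= hv[pos]) break;      // min-heap order holds
                std::swap(hv[parent], hv[pos]);
                std::swap(hidx[parent], hidx[pos]);
                pos = parent;
            }
        } else if (k > 0 && v > hv[0]) {
            hv[0] = v;                                 // replace the current minimum
            hidx[0] = s;
            std::size_t pos = 0;
            while (true) {                             // sift down
                const std::size_t l = 2 * pos + 1, r = 2 * pos + 2;
                std::size_t smallest = pos;
                if (l < hv.size() && hv[l] < hv[smallest]) smallest = l;
                if (r < hv.size() && hv[r] < hv[smallest]) smallest = r;
                if (smallest == pos) break;
                std::swap(hv[smallest], hv[pos]);
                std::swap(hidx[smallest], hidx[pos]);
                pos = smallest;
            }
        }
    }
    return hidx;
}
```

Sorting the returned indices and comparing them against a full sort of the same input gives a simple correctness check for the GPU kernel's per-chunk output.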

| Aspect | Top-K Select | Gumbel Binary Search |
|---|---|---|
| Output | Exact top-K indices + values | Threshold value |
| CPU post-processing | Merge per-chunk results | None (followed by argmax) |
| Chain decode compatible | No (CPU merge needed) | Yes |
| Use case | CPU sampling path | GPU sampling path |

The top-K select kernel produces exact results but requires CPU post-processing, making it unsuitable for chain decode. The Gumbel binary search is approximate (the threshold may not select exactly K elements) but keeps everything on the GPU.

Kernel 4: Repetition Penalty (repetition_penalty_f16)

The standalone repetition penalty kernel applies the penalty to specific token positions in the logit buffer:

kernel void repetition_penalty_f16(
    device half *logits, device const uint32_t *token_ids,
    constant float &penalty, constant uint32_t &n_tokens,
    uint tid [[thread_position_in_grid]]
) {
    if (tid >= n_tokens) return;
    uint32_t token = token_ids[tid];
    float val = float(logits[token]);
    logits[token] = half(val > 0 ? val / penalty : val * penalty);
}

Each thread handles one previous token. The dispatch is ceil(n_tokens / threadgroup_size) threadgroups. This kernel is used in the non-chain-decode path (when repetition penalty is applied as a separate step). In the Gumbel kernel, the penalty is applied inline (Phase 0) to avoid an extra dispatch.

Kernel 5: Temperature Scale (temperature_scale_f16)

kernel void temperature_scale_f16(
    device half *logits, constant float &inv_temperature,
    constant uint32_t &count, uint tid [[thread_position_in_grid]]
) {
    if (tid >= count) return;
    logits[tid] = half(float(logits[tid]) * inv_temperature);
}

A simple element-wise multiply used for:

  • Temperature scaling before sampling (standalone path)
  • Embedding scaling in prefill (Gemma)
  • Any generic “scale all elements” operation
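Temperature can enter in two ways. This kernel scales the logits by 1/T, while the Gumbel kernel leaves the logits alone and scales its noise by T. Because val + T * g = T * (val / T + g) and T > 0, both routes pick the same token for the same noise. The C++ check below illustrates this with double precision. The helper names are hypothetical, and this is not a claim about which sampler the standalone path actually runs afterwards.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <random>
#include <vector>

// Hypothetical helper: index of the largest element (first one on ties).
std::size_t argmax_index(const std::vector<double>& v) {
    std::size_t best = 0;
    for (std::size_t i = 1; i < v.size(); ++i)
        if (v[i] > v[best]) best = i;
    return best;
}

// Route 1: multiply logits by inv_temperature (what temperature_scale_f16
// computes), then add unit-scale noise g.
std::size_t pick_after_scaling(const std::vector<double>& logits,
                               const std::vector<double>& g, double temp) {
    const double inv_temperature = 1.0 / temp;
    std::vector<double> z(logits.size());
    for (std::size_t i = 0; i < z.size(); ++i) z[i] = logits[i] * inv_temperature + g[i];
    return argmax_index(z);
}

// Route 2: keep raw logits and scale the noise by T (as gumbel_topk_f16 does).
std::size_t pick_with_scaled_noise(const std::vector<double>& logits,
                                   const std::vector<double>& g, double temp) {
    std::vector<double> z(logits.size());
    for (std::size_t i = 0; i < z.size(); ++i) z[i] = logits[i] + temp * g[i];
    return argmax_index(z);
}
```

The fused route saves a full pass over the vocabulary, since the noise is generated for surviving tokens anyway.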

Putting It All Together: The GPU Sampling Pipeline

In chain decode with sampling enabled, the last two commands in the dispatch table are:

Command N-1: gumbel_topk_f16
  - Reads: logits buffer, token_ids (previous tokens)
  - Writes: logits buffer (in-place: mask + noise)
  - Params: vocab_size, temperature, position (PATCHED), seed_offset,
            top_k, top_p, repeat_penalty, min_p

Command N:   argmax_f16
  - Reads: logits buffer (now with Gumbel noise)
  - Writes: token_ids[tok+1] (PATCHED output offset)

The position field in the Gumbel kernel and the output offset in the argmax kernel are patched per-token by the dispatch table mechanism, ensuring each token in the chain gets:

  1. A unique RNG seed (different position)
  2. Its own output slot in the token_ids buffer

The rest of the chain decode infrastructure (embedding lookup, layer loop, etc.) is identical to the greedy path. The only difference is two extra kernel dispatches per token.

Summary

Akunu’s sampling kernels demonstrate that sophisticated sampling algorithms can run entirely on the GPU:

  1. Argmax: Two-level shuffle-down tree reduction (SIMD + threadgroup) over 1024 threads, finding the index of the global maximum in O(N) work.
  2. Gumbel Top-K: Binary search on value domain for approximate top-K (O(N * iterations)), followed by Gumbel noise and argmax for exact sampling. PCG hash provides stateless RNG.
  3. Top-K Select: Explicit min-heap per chunk for exact top-K, used by the CPU sampling path.
  4. Repetition Penalty: Simple per-token penalty application, both standalone and inline.

The Gumbel-max approach is the key enabler for chain decode with sampling: by reducing sampling to noise + argmax, it eliminates the CPU roundtrip that would otherwise break the chain.

Performance Characteristics

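The argmax and Gumbel kernels each run as a single threadgroup, which makes them latency-sensitive rather than throughput-sensitive. The element-wise penalty and temperature kernels instead launch one thread per previous token or per element. Here is the approximate timing breakdown: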

| Kernel | Threads | Iterations | Time (128K vocab) | Bottleneck |
|---|---|---|---|---|
| argmax_f16 | 1024 | 1 scan + 2 reductions | ~5-8 us | Memory bandwidth |
| gumbel_topk_f16 | 1024 | 1 + 12 + 8 + 1 scans | ~30-50 us | Memory bandwidth |
| repetition_penalty_f16 | varies | 1 scan of prev tokens | ~1-5 us | Memory latency |
| temperature_scale_f16 | varies | 1 scan | ~3-5 us | Memory bandwidth |

The Gumbel kernel is the most expensive at ~30-50 microseconds, but this is negligible compared to the ~8ms forward pass. Even in chain decode with 128 tokens, the total sampling overhead is only ~4-6ms (128 * 35us), or less than 1% of the total generation time.

Statistical Quality of GPU Sampling

A natural concern: does the GPU sampling path produce the same distribution as the CPU path? Mathematically, yes – the Gumbel-max trick provides exact samples from the categorical distribution, not an approximation.[4] However, there are practical differences:

  1. RNG quality: The CPU uses a Mersenne Twister (std::mt19937) with 19937 bits of state, while the GPU uses a PCG hash with effectively 32 bits of state per element. PCG-family generators pass TestU01’s BigCrush battery in O’Neill’s tests. The GPU variant, however, is a single-step hash of a 32-bit seed, so the independence of the noise across elements rests on how well that hash mixes nearby inputs.

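  2. Top-K precision: The CPU path uses exact top-K via std::partial_sort, while the GPU path uses binary search to find an approximate threshold. The surviving set can differ from the exact top-K only by tokens whose logits fall inside the final search interval (width 30 / 2^12 ≈ 0.007 logit units) or tie exactly with the threshold. Because that interval is narrow, only near-ties at the K-th position are affected. If fewer than K tokens lie within 30 logits of the maximum, only those survive; the rest carry negligible probability.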

  3. Top-P precision: Similarly, the GPU’s 8-iteration binary search for the top-P threshold resolves it only to (global_max - threshold) / 2^8 logit units, for example about 0.03 when the top-K window spans 8 logits. It may therefore include or exclude borderline tokens that the CPU path would handle differently.

In practice, these differences are undetectable in the output quality. The sampling distribution is dominated by the high-probability tokens (which are always included by both paths), and the borderline tokens near the cutoff thresholds have negligible probability mass.

Grammar Bitmask Kernel

While not in the sampling/ directory, it is worth mentioning the grammar_bitmask.metal kernel that can apply grammar constraints on the GPU:

grammar_bitmask.metal: Apply a precomputed bitmask to logits
  - Sets logits[i] = -inf where bitmask bit i is 0
  - Used for XGrammar integration on the GPU path

This kernel enables a hybrid approach: the grammar state machine runs on the CPU to compute the bitmask, but the bitmask application (which touches every logit) runs on the GPU. For vocabularies of 128K+, this saves a significant amount of CPU memory bandwidth.
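The masking step itself is simple. The C++ sketch below shows what such a kernel computes per logit; on the GPU, each thread would handle one index i. It assumes, hypothetically, that the bitmask is packed into 32-bit words, with bit i % 32 of word i / 32 allowing token i. The actual packing used by grammar_bitmask.metal is not shown in this chapter, and `apply_bitmask` is a hypothetical name.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: sets logits[i] = -inf wherever the grammar disallows
// token i. Assumed layout: token i is allowed iff bit (i % 32) of mask[i / 32] is 1.
void apply_bitmask(std::vector<float>& logits, const std::vector<uint32_t>& mask) {
    for (std::size_t i = 0; i < logits.size(); ++i) {
        const bool allowed = ((mask[i / 32] >> (i % 32)) & 1u) != 0;
        if (!allowed) logits[i] = -INFINITY;
    }
}
```

The CPU only has to produce V / 32 words per step (4K words for a 128K vocabulary) instead of touching V logits, which is where the bandwidth saving comes from.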



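  1. GPU k-selection is a well-studied problem in approximate nearest neighbor search; see Johnson, J., Douze, M., and Jegou, H. “Billion-Scale Similarity Search with GPUs.” IEEE Transactions on Big Data, 2019. That paper's own selection algorithm (WarpSelect) is exact and built on register-resident queues merged with sorting networks rather than a value-domain binary search. The same motivation applies to both: avoid a full sort of N elements when only K survivors matter. For logit filtering we do not even need the exact K-th element, just a threshold that selects approximately K candidates.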

  2. Gumbel, E.J. “Statistical Theory of Extreme Values and Some Practical Applications.” National Bureau of Standards Applied Mathematics Series 33, 1954. The Gumbel-max trick for exact categorical sampling is proven in: Maddison, C.J., Tarlow, D., and Minka, T. “A* Sampling.” NeurIPS 2014. See https://arxiv.org/abs/1411.0030.

  3. O’Neill, M.E. “PCG: A Family of Simple Fast Space-Efficient Statistically Good Algorithms for Random Number Generation.” Harvey Mudd College Technical Report HMC-CS-2014-0905, 2014. The report shows PCG-family generators passing TestU01’s BigCrush with small amounts of state. The stateless 32-bit hash used in GPU kernels applies the family's RXS-M-XS output permutation to a single LCG step. See https://www.pcg-random.org/paper.html.

  4. The exactness of the Gumbel-max trick is proven in Theorem 1 of Maddison et al. (2014). The proof shows that for any discrete distribution pi, argmax(log(pi_i) + G_i) where G_i are i.i.d. Gumbel(0,1) random variables yields a sample exactly distributed as Categorical(pi). No approximation is involved.