Sampled Decoding and GPU Gumbel-Max

Greedy decoding is clean and deterministic. You run argmax on the logits, you get a token, you move on. But anybody who has used an LLM for more than five minutes knows that greedy decoding produces text that is boring. Repetitive, predictable, and lifeless. If you want creative prose, diverse completions, or anything that feels like genuine language, you need to sample from the probability distribution rather than simply picking the mode.

This chapter covers how akunu implements sampled decoding. We will start with the theory – temperature scaling, the Gumbel-max trick, and the family of filtering methods (top-k, top-p, min-p). Then we will trace through the actual GPU kernel that does all of this in a single dispatch, entirely avoiding a CPU round-trip. Finally, we will look at the CPU fallback path for when the GPU kernel is unavailable.

Why Sampling Matters

Consider a language model predicting the next word after “The cat sat on the”. Greedy decoding will always produce “mat” (or whatever token has the highest logit). Every single time. That is useful for factual Q&A, but terrible for storytelling.

Sampling says: the model assigned 40% probability to “mat”, 15% to “roof”, 10% to “fence”, 8% to “windowsill”, and so on. Let us actually use that distribution. Roll a die weighted by those probabilities and pick a token accordingly. Now different runs produce different continuations, and the text feels more natural.

The tricky part is that naive categorical sampling from the full vocabulary (often 128k+ tokens) can produce garbage. The model assigns tiny but nonzero probability to tokens like “xyzzy” or “<0xFF>” – and occasionally you will land on one. So in practice, everybody applies some combination of temperature scaling and filtering before sampling. Let us walk through each.

Temperature Scaling

Temperature is the simplest knob. Given raw logits z_i from the model, we compute:

p_i = softmax(z / T)_i = exp(z_i / T) / sum_j exp(z_j / T)

where T is the temperature.

  T = 0.0  -->  argmax (greedy); the limit of the formula as T -> 0, handled as a special case
  T = 1.0  -->  sample from the model's native distribution
  T > 1.0  -->  flatten the distribution (more random)
  T < 1.0  -->  sharpen the distribution (more deterministic)

Visually, here is what temperature does to a toy distribution:

  Logits:   [3.0,  1.0,  0.5, -1.0, -2.0]

  T=0.5 (sharp):
  Prob:     [0.98, 0.02, 0.01, 0.00, 0.00]
            #######################################
            #
            .
            .
            .

  T=1.0 (native):
  Prob:     [0.80, 0.11, 0.07, 0.01, 0.01]
            ################################
            ####
            ###
            .
            .

  T=2.0 (flat):
  Prob:     [0.53, 0.20, 0.15, 0.07, 0.04]
            #####################
            ########
            ######
            ###
            ##

In akunu’s CPU path (sampling.cpp), temperature scaling is applied as multiplication by the inverse temperature:

float inv_temp = 1.0f / temperature;
for (int i = 0; i < vocab_size; i++)
    logits[i] *= inv_temp;

Multiplying all logits by 1/T before softmax is mathematically equivalent to dividing by T inside the exponential. It avoids a division per element.
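To make the scaling concrete, here is a minimal CPU sketch of softmax with temperature. The function name softmax_with_temperature is made up for this example and is not part of akunu; it assumes a non-empty logit vector and T > 0.

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Softmax with temperature. Assumes logits is non-empty and temperature > 0.
// softmax_with_temperature is a name invented for this sketch, not an akunu API.
std::vector<float> softmax_with_temperature(std::vector<float> logits, float temperature) {
    const float inv_temp = 1.0f / temperature;  // one division, then V multiplies
    for (float& z : logits) z *= inv_temp;

    // Subtract the maximum before exponentiating so exp() cannot overflow.
    const float max_logit = *std::max_element(logits.begin(), logits.end());
    float sum = 0.0f;
    for (float& z : logits) {
        z = std::exp(z - max_logit);
        sum += z;
    }
    for (float& z : logits) z /= sum;  // normalize so the result sums to 1
    return logits;
}
```

Calling it on the same logits with T = 0.5, 1.0, and 2.0 shows the sharpening and flattening directly: the probability of the top token shrinks as T grows.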

On the GPU path, the Gumbel-max trick (discussed below) absorbs temperature into the noise magnitude, so there is no separate scaling pass.

The Gumbel-Max Trick

Here is the key insight that makes GPU-native sampling possible.

Theorem (Gumbel-Max): If you add independent Gumbel(0,1) noise to each logit and then take the argmax, the result is a sample from the categorical distribution defined by softmax of those logits.

More precisely, let g_i ~ Gumbel(0,1) be independent. Then:

argmax_i (z_i + g_i) ~ Categorical(softmax(z))

And for temperature scaling:

argmax_i (z_i + T * g_i) ~ Categorical(softmax(z / T))

This is remarkable. It means we can turn sampling into argmax – which is exactly the operation we already have a fast GPU kernel for (greedy decoding). We just need to add appropriately scaled random noise to the logits first.
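The temperature form follows from the plain theorem: for T > 0, dividing every score by T does not change which index wins, so argmax_i (z_i / T + g_i) and argmax_i (z_i + T * g_i) pick the same token. If you want to convince yourself empirically, the following CPU sketch draws many Gumbel-max samples and returns the observed token frequencies, which should approach softmax(z / T). The helper names and the use of std::mt19937 are choices made for this illustration only; akunu's GPU path uses the hash-based noise described below.

```cpp
#include <cmath>
#include <cstddef>
#include <limits>
#include <random>
#include <vector>

// One Gumbel-max draw: argmax_i (z_i + T * g_i) with g_i ~ Gumbel(0,1).
// gumbel_max_sample is a hypothetical helper for this sketch.
template <class Rng>
std::size_t gumbel_max_sample(const std::vector<float>& logits, float temperature, Rng& rng) {
    // Stay strictly inside (0, 1) so both logarithms are finite.
    std::uniform_real_distribution<double> uniform(1e-12, 1.0 - 1e-12);
    std::size_t best = 0;
    double best_score = -std::numeric_limits<double>::infinity();
    for (std::size_t i = 0; i < logits.size(); ++i) {
        const double g = -std::log(-std::log(uniform(rng)));
        const double score = logits[i] + temperature * g;
        if (score > best_score) {
            best_score = score;
            best = i;
        }
    }
    return best;
}

// Empirical frequency of each token over `draws` independent samples.
std::vector<double> gumbel_max_frequencies(const std::vector<float>& logits, float temperature,
                                           int draws, unsigned seed) {
    std::mt19937 rng(seed);
    std::vector<double> freq(logits.size(), 0.0);
    for (int d = 0; d < draws; ++d) freq[gumbel_max_sample(logits, temperature, rng)] += 1.0;
    for (double& f : freq) f /= draws;
    return freq;
}
```

With a few hundred thousand draws, the frequencies should agree with the softmax probabilities to about two decimal places.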

Here is the pipeline:

  +------------------+     +------------------+     +------------------+
  |  Model forward   | --> | Add Gumbel noise | --> |     Argmax       |
  |  (produces       |     | scaled by temp   |     |  (same kernel    |
  |   logits)        |     | to each logit    |     |   as greedy)     |
  +------------------+     +------------------+     +------------------+
         GPU                      GPU                      GPU
                      No CPU round-trip needed!

Compare this with the traditional approach:

  +------------------+     +----------+     +----------+     +----------+
  |  Model forward   | --> | Copy to  | --> | Softmax  | --> | Sample   |
  |  (produces       |     | CPU      |     | + filter |     | (rand)   |
  |   logits)        |     |          |     |          |     |          |
  +------------------+     +----------+     +----------+     +----------+
         GPU                  sync!            CPU              CPU

The traditional approach requires a GPU-to-CPU synchronization to copy the logits, then CPU work for softmax and sampling, then writing the result back. That synchronization stall can cost 50-200 microseconds per token – which at high throughput becomes a significant fraction of total decode time.

The Gumbel-max trick keeps everything on the GPU. The sampled dispatch table simply inserts the gumbel_topk_f16 kernel between the model forward pass and the existing argmax kernel. No sync, no copy, no stall.

Generating Gumbel Noise

A Gumbel(0,1) random variable is generated from a uniform random variable u ~ Uniform(0,1) via the inverse CDF:

g = -log(-log(u))

On the GPU, we need a fast source of uniform random numbers. We cannot use rand() (no such thing in Metal compute shaders). Instead, we use a PCG (Permuted Congruential Generator) hash function:

inline float pcg_float(uint state) {
    state = state * 747796405u + 2891336453u;
    uint word = ((state >> ((state >> 28u) + 4u)) ^ state) * 277803737u;
    word = (word >> 22u) ^ word;
    return float(word) / 4294967296.0f;
}

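This is not a PRNG in the traditional sense – it is a hash function. You feed it a seed and get back a pseudo-random float that is nominally in [0, 1); because the final float conversion can round the largest hash values up to exactly 1.0, the value is clamped before any logarithm is taken. The seed for each vocabulary element is computed as: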

uint element_seed = (params.position + params.seed_offset) * 2654435761u;
// Then for token i:
float u = pcg_float(element_seed + i);

The position is the current sequence position (patched per-token via PATCH_POSITION in the dispatch table), and seed_offset is derived from std::chrono::high_resolution_clock at the start of each generate call. The multiplicative constant 2654435761 is approximately 2^32 divided by the golden ratio (Knuth's multiplicative hashing constant), a common choice for hash mixing.

This scheme ensures:

  1. Different tokens at the same position get different noise (seeded by i)
  2. Different positions get different noise (seeded by position)
  3. Different calls get different noise (seeded by seed_offset)
  4. The same (position, seed_offset, i) triple always produces the same noise (reproducibility when seeds are fixed)

The Gumbel noise is then:

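u = clamp(u, 1e-7f, 1.0f - 1e-7f);  // keep both logarithms finite
float gumbel = -log(-log(u));
logits[i] = half(val + temp * gumbel);

where val is the current logit read as a float and temp is the temperature parameter. The clamp(u, 1e-7, 1-1e-7) ensures we never take log(0).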
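For experimenting off the GPU, the noise generation can be ported to plain C++ almost line for line. This is a sketch under assumptions: gumbel_noise and its argument order are invented here, and std::uint32_t stands in for Metal's uint (both wrap modulo 2^32).

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>

// CPU port of the hash shown above. Unsigned 32-bit arithmetic wraps
// modulo 2^32, matching Metal's uint.
inline float pcg_float(std::uint32_t state) {
    state = state * 747796405u + 2891336453u;
    std::uint32_t word = ((state >> ((state >> 28u) + 4u)) ^ state) * 277803737u;
    word = (word >> 22u) ^ word;
    return static_cast<float>(word) / 4294967296.0f;
}

// Gumbel(0,1) noise for vocabulary element i at a given sequence position.
// The function name and signature are assumptions made for this sketch.
inline float gumbel_noise(std::uint32_t position, std::uint32_t seed_offset, std::uint32_t i) {
    const std::uint32_t element_seed = (position + seed_offset) * 2654435761u;
    float u = pcg_float(element_seed + i);
    u = std::min(std::max(u, 1e-7f), 1.0f - 1e-7f);  // clamp so neither log sees 0
    return -std::log(-std::log(u));
}
```

Because the function has no hidden state, calling it twice with the same (position, seed_offset, i) triple returns the same value, which is the reproducibility property listed above.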

Filtering Methods

Pure sampling from the full vocabulary is too noisy. Filtering methods prune the candidate set before sampling. akunu implements three, applied in this order on the GPU: top-k, top-p, and min-p.

Top-K Filtering

Keep only the K tokens with the highest logits. Mask everything else to negative infinity.

  Before top-k (K=3):
  Token:  A     B     C     D     E     F     G
  Logit: 5.2   3.1   2.8   1.5   0.3  -1.0  -2.5
         ^^^   ^^^   ^^^   ---   ---   ---   ---
         keep  keep  keep  mask  mask  mask  mask

  After top-k:
  Logit: 5.2   3.1   2.8  -inf  -inf  -inf  -inf

The GPU kernel finds the k-th largest logit via binary search rather than sorting. This is a clever approach: sorting 128k elements is expensive, but counting how many elements exceed a threshold is cheap (each thread scans its portion in parallel, then a SIMD reduction sums the counts).

// Binary search: 12 iterations to find the k-th largest value
float lo = global_max - 30.0f;
float hi = global_max;

for (int iter = 0; iter < 12; iter++) {
    float mid = (lo + hi) * 0.5f;

    uint local_count = 0;
    for (uint i = tid; i < V; i += 1024)
        local_count += (float(logits[i]) > mid) ? 1 : 0;

    // SIMD reduction
    uint sg_count = simd_sum(local_count);
    if (slid == 0) tg_counts[sgid] = sg_count;
    threadgroup_barrier(mem_flags::mem_threadgroup);

    if (tid == 0) {
        uint total = 0;
        for (uint s = 0; s < 32; s++) total += tg_counts[s];
        tg_counts[0] = total;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    if (tg_counts[0] > (uint)top_k)
        lo = mid;   // too many survivors, raise threshold
    else
        hi = mid;   // too few, lower threshold
}
threshold = hi;

After 12 iterations of binary search, the threshold converges to within 30 / 2^12 ~ 0.007 logit units – more than precise enough. Each iteration requires only 2 threadgroup barriers and a parallel scan, so the total cost is modest.

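The alternative – a full sort – would cost O(V log V) work, which for V=128k (log2 V is about 17) means roughly 17 passes' worth of comparisons plus all the data movement a sort requires. The binary search costs O(V * iterations) = O(V * 12): twelve read-only counting passes that never move an element. And it parallelizes trivially.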
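The same search is easy to prototype serially. The sketch below mirrors the kernel excerpt (a 30-logit window below the maximum, 12 halvings) but replaces the threads, SIMD sums, and barriers with a single std::count_if; topk_threshold is a hypothetical name and the input is assumed non-empty.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Serial sketch of the counting binary search. Returns a threshold such that
// at most k logits are strictly greater than it. Assumes logits is non-empty.
float topk_threshold(const std::vector<float>& logits, std::size_t k) {
    const float global_max = *std::max_element(logits.begin(), logits.end());
    float lo = global_max - 30.0f;
    float hi = global_max;
    for (int iter = 0; iter < 12; ++iter) {
        const float mid = 0.5f * (lo + hi);
        const std::size_t count = static_cast<std::size_t>(
            std::count_if(logits.begin(), logits.end(), [mid](float z) { return z > mid; }));
        if (count > k)
            lo = mid;  // too many survivors, raise the threshold
        else
            hi = mid;  // k or fewer survivors, try a lower threshold
    }
    return hi;
}
```

For the seven-token example above with K=3, the returned threshold lands just above 1.5, so exactly A, B, and C survive a `logit >= threshold` test.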

Top-P (Nucleus) Filtering

Top-p keeps the smallest set of tokens whose cumulative probability reaches at least p. This is adaptive – when the model is confident, it might keep only 2-3 tokens; when uncertain, it might keep hundreds.

  Sorted by probability:
  Token:   A      B      C      D      E      ...
  Prob:   0.40   0.25   0.15   0.10   0.05   ...
  CumSum: 0.40   0.65   0.80   0.90   0.95   ...
                                ^^^^
                         top_p=0.9 cutoff here

  Keep A, B, C, D.  Mask E and everything after.
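As a reference point, this is what the definition looks like when implemented the straightforward way, with a sort and a running sum. It is a textbook sketch, not the akunu kernel (which avoids sorting); nucleus_indices is a hypothetical name and the input is assumed to be a normalized probability vector.

```cpp
#include <algorithm>
#include <cstddef>
#include <numeric>
#include <vector>

// Returns the token indices kept by top-p, most probable first: the shortest
// prefix of the probability-sorted tokens whose cumulative mass reaches top_p.
// Assumes probs is non-negative and sums to 1.
std::vector<std::size_t> nucleus_indices(const std::vector<float>& probs, float top_p) {
    std::vector<std::size_t> order(probs.size());
    std::iota(order.begin(), order.end(), std::size_t{0});
    std::sort(order.begin(), order.end(),
              [&probs](std::size_t a, std::size_t b) { return probs[a] > probs[b]; });

    std::vector<std::size_t> kept;
    float cumulative = 0.0f;
    for (std::size_t idx : order) {
        kept.push_back(idx);
        cumulative += probs[idx];
        if (cumulative >= top_p) break;  // the token that crosses top_p is kept
    }
    return kept;
}
```

The O(V log V) sort is exactly the cost the GPU path is designed to avoid, which is why this version is only useful as a correctness reference.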

On the GPU, akunu again uses binary search. It first computes the total exp-sum for tokens above the current threshold (from top-k), then searches for the logit value where the cumulative probability reaches top_p:

if (top_p < 1.0f && top_p > 0.0f) {
    // Compute total exp(logit - max) for survivors
    float total_exp = ...;  // parallel reduction
    float target_exp = top_p * total_exp;

    // Binary search: 8 iterations
    float lo_p = threshold;
    float hi_p = global_max;
    for (int iter = 0; iter < 8; iter++) {
        float mid_p = (lo_p + hi_p) * 0.5f;
        // Count exp-sum for logits >= mid_p
        // If sum > target, raise threshold (too many)
        // If sum <= target, lower threshold (too few)
    }
    threshold = max(threshold, hi_p);
}

The key insight: threshold can only go up from top-p, never down. If top-k already restricted us to 40 tokens, top-p can further restrict to (say) 10, but it will never expand beyond 40. That is why we take max(threshold, hi_p).

Min-P Filtering

Min-p is the newest and arguably most elegant filtering method. It keeps all tokens whose probability is at least min_p times the probability of the most likely token.

  Most likely token has probability 0.40
  min_p = 0.1

  Threshold = 0.40 * 0.1 = 0.04

  Token:  A      B      C      D      E      F
  Prob:  0.40   0.25   0.15   0.10   0.05   0.02
         keep   keep   keep   keep   keep   MASK

Since probabilities are proportional to exp(logit), and the max probability corresponds to exp(global_max), the min-p condition becomes:

exp(logit_i) >= min_p * exp(global_max)
logit_i >= global_max + log(min_p)

This is a single threshold comparison – no sorting, no cumulative sums:

if (params.min_p > 0.0f && params.min_p < 1.0f) {
    float min_p_threshold = global_max + log(params.min_p);
    threshold = max(threshold, min_p_threshold);
}

Beautifully simple on the GPU. One log, one add, one max.
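A CPU sketch makes the point that no softmax is required: the cutoff is computed and applied entirely in logit space. The name apply_min_p is invented for this example; it assumes a non-empty logit vector and 0 < min_p < 1.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Masks every logit whose probability would fall below min_p * p_max and
// returns how many tokens survive. Assumes logits is non-empty, 0 < min_p < 1.
std::size_t apply_min_p(std::vector<float>& logits, float min_p) {
    const float global_max = *std::max_element(logits.begin(), logits.end());
    const float cutoff = global_max + std::log(min_p);  // log(min_p) is negative
    std::size_t survivors = 0;
    for (float& z : logits) {
        if (z < cutoff)
            z = -std::numeric_limits<float>::infinity();
        else
            ++survivors;
    }
    return survivors;
}
```

Because the cutoff is defined relative to the maximum, adding a constant to every logit leaves the surviving set unchanged, just as it leaves the softmax unchanged.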

Repetition Penalty

Before any filtering happens, the kernel applies repetition penalty to discourage the model from repeating tokens it has already generated. The algorithm is asymmetric:

  • If a logit is positive, divide by the penalty factor
  • If a logit is negative, multiply by the penalty factor

if (rep_penalty > 1.0f && params.position > 0) {
    for (uint i = tid; i < params.position; i += 1024) {
        uint token = token_ids[i + 1];
        if (token < V) {
            float val = float(logits[token]);
            logits[token] = half(val > 0 ? val / rep_penalty : val * rep_penalty);
        }
    }
}

This is Phase 0 of the kernel. Each thread in the 1024-thread threadgroup processes a different previous token (strided access). The token_ids buffer contains the sequence generated so far, indexed from position 1 (position 0 is the input token).

The asymmetric formulation ensures that the penalty always decreases the probability of repeated tokens, regardless of whether their logit is positive or negative. If we just divided all logits by the penalty, negative logits would become less negative (higher probability), which is the opposite of what we want.
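Here is the asymmetric rule as a small serial sketch. One deliberate simplification, which is an assumption of this example rather than a statement about the kernel: each distinct token is penalized once, however many times it appears in the history. The function name is hypothetical.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Applies the asymmetric repetition penalty in place. Each distinct token in
// `history` is penalized once (a simplification chosen for this sketch).
void apply_repetition_penalty(std::vector<float>& logits,
                              const std::vector<std::uint32_t>& history,
                              float penalty) {
    if (penalty <= 1.0f) return;  // 1.0 means disabled
    std::vector<bool> seen(logits.size(), false);
    for (std::uint32_t token : history) {
        if (token >= logits.size() || seen[token]) continue;  // skip out-of-range and repeats
        seen[token] = true;
        float& z = logits[token];
        // Positive logits shrink toward zero; negative logits move further below zero.
        z = (z > 0.0f) ? z / penalty : z * penalty;
    }
}
```

With a penalty of 2, a logit of 2.0 becomes 1.0 and a logit of -2.0 becomes -4.0: both moves lower the token's probability.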

The Complete GPU Kernel: Phase by Phase

The gumbel_topk_f16 Metal kernel executes in a single threadgroup of 1024 threads. Here is the full pipeline:

  +-------------------------------------------------------+
  |                gumbel_topk_f16 kernel                  |
  |                                                        |
  |  Phase 0: Repetition penalty                           |
  |    For each prev token (strided across 1024 threads):  |
  |      logit[tok] /= penalty  (if positive)              |
  |      logit[tok] *= penalty  (if negative)              |
  |    [threadgroup barrier]                               |
  |                                                        |
  |  Phase 1: Find global maximum                          |
  |    Each thread: local_max over its strided elements     |
  |    SIMD reduction -> simdgroup maxes                   |
  |    Threadgroup reduction -> single global_max           |
  |    [threadgroup barrier]                               |
  |                                                        |
  |  Phase 2: Find top-k threshold via binary search       |
  |    12 iterations:                                      |
  |      mid = (lo + hi) / 2                               |
  |      Each thread: count elements > mid (strided)       |
  |      SIMD sum -> simdgroup counts                      |
  |      Threadgroup sum -> total count                    |
  |      if total > k: lo = mid  else: hi = mid            |
  |    [2 barriers per iteration = 24 barriers]            |
  |                                                        |
  |  Phase 2b: Top-p via binary search                     |
  |    8 iterations of similar structure                   |
  |    Raises threshold if cumulative prob > top_p          |
  |                                                        |
  |  Phase 2c: Min-p threshold                             |
  |    threshold = max(threshold, global_max + log(min_p)) |
  |                                                        |
  |  Phase 3: Apply mask + Gumbel noise                    |
  |    For each element (strided):                         |
  |      if logit < threshold: logit = -inf                |
  |      else: logit += temp * Gumbel(pcg_hash(seed + i))  |
  +-------------------------------------------------------+
         |
         v
  +-------------------------------------------------------+
  |              argmax_f16 kernel (existing)               |
  |    Standard parallel reduction -> winning token ID      |
  +-------------------------------------------------------+

The entire pipeline – from raw logits to sampled token ID – never leaves the GPU. The sampled_dispatch_table in akunu chains these kernels together as part of the same command buffer as the model forward pass.

Dispatch Table Integration

In decode_sampled.cpp, the function first checks whether a sampled dispatch table exists:

bool have_sampled_table = !state.sampled_dispatch_table.commands.empty();

If it does, the function patches the Gumbel kernel’s parameters directly into the command buffer:

auto& gumbel_cmd = cmds[cmds.size() - 2];  // second-to-last command
// Param layout: [vocab(0), temp(4), pos(8), seed(12),
//                top_k(16), top_p(20), rep_penalty(24), min_p(28)]
memcpy(gumbel_cmd.param_bytes + 4, &temp, sizeof(float));
memcpy(gumbel_cmd.param_bytes + 12, &seed_base, sizeof(uint32_t));

The second-to-last command is the Gumbel kernel (the last is argmax). The seed is derived from the high-resolution clock:

uint32_t seed_base =
    (uint32_t)std::chrono::high_resolution_clock::now()
        .time_since_epoch().count();

The position field is patched per-token by the dispatch table’s PATCH_POSITION mechanism – the same mechanism used for KV cache position indexing during chain decode.

If the command's parameter block is at least 32 bytes long, the extended parameters (top_k, top_p, repeat_penalty, min_p) at offsets 16 through 28 are also patched. This backwards-compatible layout means older compiled dispatch tables (without the extended params) still work – they just do not get filtering.
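The size check can be sketched as follows. Everything here is illustrative: the struct GumbelParams, its field types, and patch_gumbel_params are assumptions that simply mirror the byte offsets quoted in the comment above; akunu's real definitions may differ.

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>

// Hypothetical mirror of the quoted parameter layout. Field types are guesses;
// only the names and byte offsets come from the text.
struct GumbelParams {
    std::uint32_t vocab;        // offset 0
    float         temp;         // offset 4
    std::uint32_t pos;          // offset 8  (patched separately, per token)
    std::uint32_t seed;         // offset 12
    std::int32_t  top_k;        // offset 16 (extended fields start here)
    float         top_p;        // offset 20
    float         rep_penalty;  // offset 24
    float         min_p;        // offset 28
};
static_assert(offsetof(GumbelParams, seed) == 12, "base layout mismatch");
static_assert(offsetof(GumbelParams, top_k) == 16, "extended layout mismatch");
static_assert(sizeof(GumbelParams) == 32, "expected a 32-byte parameter block");

// Patches temp and seed, plus the extended fields when the block has room.
// Returns true only if the extended fields were written.
bool patch_gumbel_params(unsigned char* param_bytes, std::size_t param_size,
                         float temp, std::uint32_t seed, std::int32_t top_k,
                         float top_p, float rep_penalty, float min_p) {
    constexpr std::size_t kBaseParamBytes = 16;  // vocab, temp, pos, seed
    if (param_size < kBaseParamBytes) return false;
    std::memcpy(param_bytes + offsetof(GumbelParams, temp), &temp, sizeof temp);
    std::memcpy(param_bytes + offsetof(GumbelParams, seed), &seed, sizeof seed);
    if (param_size < sizeof(GumbelParams)) return false;  // older 16-byte layout
    std::memcpy(param_bytes + offsetof(GumbelParams, top_k), &top_k, sizeof top_k);
    std::memcpy(param_bytes + offsetof(GumbelParams, top_p), &top_p, sizeof top_p);
    std::memcpy(param_bytes + offsetof(GumbelParams, rep_penalty), &rep_penalty, sizeof rep_penalty);
    std::memcpy(param_bytes + offsetof(GumbelParams, min_p), &min_p, sizeof min_p);
    return true;
}
```

Using offsetof with static_assert instead of bare numeric offsets means a layout change fails at compile time rather than silently corrupting a neighbouring field.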

The Decode Loop

The actual decode loop in decode_sampled is structurally identical to greedy chain decode:

while (generated < max_tokens) {
    int remaining = max_tokens - generated;
    int n = std::min(chunk, remaining);

    state.device->write_buffer(state.scratch.token_ids, &next_token, 4);
    state.device->begin_encoding();
    state.device->encode_dispatch_table(&table, pos, n);
    state.device->end_encoding_sync();
    state.kv_cache.advance(n);
    pos += n;

    uint32_t *tokens = (uint32_t *)state.device->buffer_contents(
        state.scratch.token_ids);
    for (int i = 0; i < n; i++) {
        uint32_t tok = tokens[i + 1];
        // ... emit token, check EOS, update next_token ...
    }
}

The chunk size is state.chip.chain_decode_chunk, typically 1 for sampled decode (since each token depends on the previous sample). But the dispatch table could support multi-token chunks if the Gumbel kernel were extended to produce multiple samples per dispatch.

The beauty of this design is that from the loop’s perspective, sampled decode and greedy decode look identical. The only difference is which dispatch table gets encoded: state.dispatch_table (greedy argmax) vs state.sampled_dispatch_table (Gumbel + argmax).

CPU Fallback Path

When the sampled dispatch table is not available (e.g., the Metal library was not compiled with the Gumbel kernel, or the model format does not support it), decode_sampled falls back to using the regular dispatch table:

DispatchTable& table = have_sampled_table
    ? state.sampled_dispatch_table
    : state.dispatch_table;

In this case, the regular dispatch table produces raw logits (no argmax), and the CPU sampling pipeline in sampling.cpp takes over. We will cover that pipeline in detail in Chapter 32, but the high-level flow is:

  GPU: model forward -> raw logits in buffer
  CPU: copy F16 logits -> convert to F32 -> temperature scale
       -> softmax -> top-k (partial sort) -> min-p filter
       -> top-p filter -> renormalize -> categorical sample

The CPU path is more flexible (easier to add new filtering methods) but slower due to the GPU-CPU synchronization.
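The last step of that flow, the categorical sample, is the classic inverse-CDF walk. The sketch below is a generic illustration of that step, not a copy of sampling.cpp; categorical_sample is a hypothetical name, u is assumed uniform in [0, 1), and probs is assumed renormalized after filtering.

```cpp
#include <cstddef>
#include <vector>

// Inverse-CDF categorical sampling: returns the first index whose running
// total exceeds u. Tokens with zero probability (masked) are never returned.
std::size_t categorical_sample(const std::vector<float>& probs, float u) {
    float cumulative = 0.0f;
    std::size_t last_nonzero = 0;
    for (std::size_t i = 0; i < probs.size(); ++i) {
        if (probs[i] <= 0.0f) continue;  // masked token
        last_nonzero = i;
        cumulative += probs[i];
        if (u < cumulative) return i;
    }
    // Reached only if rounding left the total slightly below u.
    return last_nonzero;
}
```

This is the step that Gumbel-max replaces on the GPU: instead of a sequential walk over a normalized distribution, every element is perturbed independently and a parallel argmax picks the winner.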

Parameter Struct

The AkunuSamplingConfig struct defines all sampling knobs:

  +-----+-----+-----+-----+-----+
  |  0  |  4  |  8  | 12  | 16  |  byte offset
  +-----+-----+-----+-----+-----+
  |temp | topk| topp| minp| rep |
  |f32  | i32 | f32 | f32 | f32 |
  +-----+-----+-----+-----+-----+
  20 bytes total, 4-byte aligned

Default values:

  • temperature = 0.0 (greedy – no sampling)
  • top_k = 40
  • top_p = 0.9
  • min_p = 0.0 (disabled)
  • repeat_penalty = 1.0 (disabled)

A temperature of 0 triggers the greedy path; any positive temperature engages sampled decode.
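Put together, the layout and defaults above correspond to a struct along these lines. This is a sketch: SamplingConfigSketch and uses_sampled_decode are invented names, and the field names are taken from the defaults list rather than from akunu's header.

```cpp
#include <cstddef>
#include <cstdint>

// Illustrative stand-in for AkunuSamplingConfig, following the 20-byte layout
// and the default values listed above. Not copied from akunu's headers.
struct SamplingConfigSketch {
    float        temperature    = 0.0f;  // offset 0:  0 means greedy
    std::int32_t top_k          = 40;    // offset 4
    float        top_p          = 0.9f;  // offset 8
    float        min_p          = 0.0f;  // offset 12: 0 means disabled
    float        repeat_penalty = 1.0f;  // offset 16: 1 means disabled
};
static_assert(sizeof(SamplingConfigSketch) == 20, "expected 20 bytes");
static_assert(alignof(SamplingConfigSketch) == 4, "expected 4-byte alignment");

// Any positive temperature selects the sampled path; zero selects greedy.
inline bool uses_sampled_decode(const SamplingConfigSketch& cfg) {
    return cfg.temperature > 0.0f;
}
```

A default-constructed config therefore decodes greedily until the caller raises the temperature.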

Practical Considerations

Reproducibility. If you fix seed_offset (by seeding from a known value rather than the clock), the GPU Gumbel-max path is fully deterministic for a given (model, prompt, parameters) tuple. The PCG hash is deterministic, and the argmax over the noised logits is a deterministic function of the logits and the noise – so the same logits and the same noise produce the same token.

Quality of randomness. The PCG hash is not cryptographically secure, and the per-element hash inputs are consecutive integers (they differ by 1) rather than independently drawn seeds. In practice, this does not matter for LLM sampling. The quality of the generated text is dominated by the logit values, not by the precise distribution of the noise. Any reasonable hash function would work.

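Filter interaction. The three filters (top-k, top-p, min-p) combine through a single shared threshold: each can only raise it, never lower it. Min-p depends only on the global maximum, so it can be applied at any point without changing the result. Top-p is different: as described above, its exp-sum is taken over the tokens that survived top-k, so it selects a nucleus of the top-k set rather than of the full vocabulary, and the order of those two filters does affect the outcome. Doing top-k first is also the efficient choice on the GPU, because it establishes a tight initial threshold that narrows the range the top-p binary search has to cover.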

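Throughput. The Gumbel kernel dispatches with grid=(1), threadgroup=(1024). That is a single threadgroup of 1024 threads. On Apple Silicon, a single GPU core can handle this in microseconds. The kernel is latency-bound, not throughput-bound – it does O(V) work per phase, and each binary-search iteration re-reads the logits. At V=128k the F16 logits occupy 256 KB, far more than the threadgroup memory available on Apple GPUs (typically 32 KB), so the logits stay in device memory; only the small per-simdgroup partial results such as tg_counts live in threadgroup memory.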

When to use sampling. For code generation, math, and factual Q&A, greedy decode (or low temperature like 0.1) usually wins. For creative writing, conversation, and brainstorming, temperature 0.7-1.0 with top-p 0.9 is a solid default. Min-p around 0.05-0.1 is increasingly popular as a replacement for top-k, since it adapts to the model’s confidence level.

Summary

Akunu’s sampled decode takes advantage of the Gumbel-max trick to keep the entire sampling pipeline on the GPU. The single gumbel_topk_f16 Metal kernel implements repetition penalty, global max finding, top-k via binary search, top-p via binary search, min-p filtering, and Gumbel noise injection in one dispatch of 1024 threads. The existing argmax kernel then produces the final token. No CPU round-trip, no synchronization stall, and throughput close to that of greedy decoding.

When the GPU kernel is unavailable, the CPU fallback in sampling.cpp provides the same functionality with explicit F16-to-F32 conversion, softmax, partial sort, and categorical sampling. We will dissect that path in Chapter 32.