The Decode Loop

After prefill populates the KV cache and produces the first token, the engine enters the decode loop: the main generation loop that produces tokens one at a time. This is where the user experiences “streaming” – each token appears as soon as it is generated, and the loop continues until a stop condition is met.

In Akunu, the decode loop lives in src/inference/decode_loop.cpp, and the four decode paths it can dispatch are declared in src/inference/decode_paths.h. The loop itself is surprisingly compact – most of the complexity lives in the individual decode paths – but the orchestration logic contains several subtleties worth understanding.

The Four Decode Paths

Akunu does not have a single “decode” function. Instead, it offers four distinct decode paths, each optimized for a different scenario:

                ┌──────────────────────────┐
                │    run_decode_loop()     │
                │                          │
                │  After prefill:          │
                │  choose a decode path    │
                └────────────┬─────────────┘
                             │
                   grammar?  │
              ┌──────────────┴──────────────┐
              │ YES                         │ NO
              ▼                             │
   ┌───────────────────┐  temperature > 0? │
   │ decode_grammar    │     ┌──────────────┴─────────────┐
   │ (constrained)     │     │ YES                        │ NO
   └───────────────────┘     ▼                            │
                  ┌───────────────────┐     speculation?  │
                  │ decode_sampled    │      ┌────────────┴────────────┐
                  │ (Gumbel-max GPU)  │      │ YES                     │ NO
                  └───────────────────┘      ▼                         ▼
                                   ┌────────────────────┐     ┌────────────────┐
                                   │ decode_speculative │     │ decode_greedy  │
                                   │ (n-gram draft)     │     │ (chain decode) │
                                   └────────────────────┘     └────────────────┘

Here is the decision logic from the source:

if (grammar) {
    generated += decode_grammar(state, model, next_token, pos, max_tokens - generated,
                                sampling, prompt_tokens, n_prompt, *grammar,
                                callback, user_data);
} else {
    bool use_sampling = (sampling.temperature > 0.0f);
    if (use_sampling) {
        generated += decode_sampled(state, model, next_token, pos,
                                    max_tokens - generated, sampling,
                                    prompt_tokens, n_prompt, callback, user_data);
    } else if (state.speculation_enabled) {
        generated += decode_speculative(state, model, next_token, pos,
                                        max_tokens - generated, prompt_tokens,
                                        n_prompt, callback, user_data);
    } else {
        generated += decode_greedy(state, model, next_token, pos,
                                   max_tokens - generated, callback, user_data);
    }
}

Let’s characterize each path:

| Path | When Used | GPU Roundtrips | Sampling | CPU Involvement |
|------|-----------|----------------|----------|-----------------|
| decode_greedy | temperature=0, no grammar, no speculation | 1 per chunk | argmax on GPU | Minimal – read token IDs |
| decode_sampled | temperature>0, no grammar | 1 per chunk | Gumbel-max on GPU | Minimal – read token IDs |
| decode_speculative | temperature=0, speculation enabled | 1 per batch | argmax on GPU | n-gram prediction |
| decode_grammar | grammar constraints active | 1 per token | CPU (grammar mask) | Heavy – grammar state per token |

The Decode Loop Entry Point

Let’s walk through run_decode_loop in detail. The function signature reveals the full set of inputs:

AkunuGenerationStats run_decode_loop(
    ModelState& state, akunu_model_t model,
    const uint32_t *prompt_tokens, int n_prompt,
    int start_pos, int max_tokens,
    AkunuSamplingConfig sampling,
    akunu_token_callback callback, void *user_data,
    GrammarHandle *grammar = nullptr);

The callback is what enables streaming. Every time a token is generated, the callback receives both the token ID and its decoded text:

typedef bool (*akunu_token_callback)(uint32_t token, const char *text, void *user_data);

If the callback returns false, generation stops immediately. This is how the host application can implement user cancellation or stop-sequence detection at the application level.

Phase 1: Prefill

The first thing the decode loop does is run prefill (covered in the previous chapter):

auto prefill_start = std::chrono::high_resolution_clock::now();
int chunk_size = state.scratch.max_prefill_chunk;
uint32_t next_token = 0;
int prefill_pos = 0;
while (prefill_pos < n_prompt) {
    int chunk = std::min(chunk_size, n_prompt - prefill_pos);
    next_token = encode_prefill(*state.device, *state.weights, state.config,
                                state.arch, state.kv_cache, state.scratch,
                                prompt_tokens + prefill_pos, chunk,
                                start_pos + prefill_pos);
    prefill_pos += chunk;
}
auto prefill_end = std::chrono::high_resolution_clock::now();

The timing is precise: only the prefill computation itself is measured, not the decode loop setup.

Phase 2: First Token Handling

The first generated token comes from the prefill argmax (or from sampling when grammar is active). This token gets special treatment:

int generated = 0;
int pos = start_pos + n_prompt;
bool stopped = false;

if (grammar) {
    // Read logits, apply grammar mask, sample/argmax
} else {
    // Use prefill's argmax result directly
    if (!state.tokenizer.is_eos(next_token)) {
        generated++;
        if (callback) {
            const char *text = decode_token_text(state, next_token);
            if (!callback(next_token, text, user_data))
                stopped = true;
        }
    } else {
        stopped = true;
    }
}

When there is no grammar, the first token is essentially “free” – it came from the prefill’s argmax and requires no additional GPU work. When grammar is active, the first token requires reading the F16 logits from GPU memory, converting to F32, applying the grammar bitmask, and then sampling:

const __fp16 *f16 = (const __fp16 *)state.device->buffer_contents(state.scratch.logits);
for (uint32_t i = 0; i < vocab_count; i++)
    logits[i] = (float)f16[i];

// Apply grammar mask
grammar->legacy.apply(logits, vocab_count);

The F16-to-F32 conversion happens on the CPU here. This is one of the cases where the CPU path is unavoidable: grammar bitmasks need to be computed on the CPU (they depend on the grammar state machine), and it is cheaper to convert the logits to F32 on the CPU than to launch another GPU kernel just for the conversion.

Phase 3: Main Decode Loop

After the first token, the selected decode path takes over:

auto decode_start = std::chrono::high_resolution_clock::now();

if (!stopped && generated < max_tokens) {
    // ... dispatch to selected path ...
}

auto decode_end = std::chrono::high_resolution_clock::now();

Each decode path returns the number of tokens generated. The decode loop tracks generated and ensures we never exceed max_tokens.

The Grammar-Constrained Path

Grammar-constrained decoding (decode_grammar) is the most complex path because it must synchronize with the grammar state machine on every token. This means:

  1. Run the forward pass on the GPU
  2. Read logits back to CPU
  3. Convert F16 -> F32
  4. Apply the grammar bitmask (setting disallowed tokens to -inf)
  5. Sample or argmax from the masked logits
  6. Update the grammar state machine with the accepted token
  7. Repeat

Akunu supports two grammar backends:

#ifdef AKUNU_HAS_XGRAMMAR
if (grammar->use_xgrammar) {
    int bm_size = grammar->xgrammar.bitmask_size();
    std::vector<int32_t> bm(bm_size);
    grammar->xgrammar.fill_next_token_bitmask(bm.data());
    for (uint32_t i = 0; i < vocab_count; i++) {
        if (!((bm[i / 32] >> (i % 32)) & 1))
            logits[i] = -std::numeric_limits<float>::infinity();
    }
} else
#endif
{
    grammar->legacy.apply(logits, vocab_count);
}

The XGrammar backend uses a compact bitmask representation: one bit per vocabulary token, packed into 32-bit integers. A token is allowed if its corresponding bit is 1. This is remarkably efficient – for a 128K vocabulary, the bitmask is only 16KB.1

Grammar decode is inherently sequential and CPU-bound. Each token requires a full GPU->CPU->GPU roundtrip, which limits throughput to perhaps 20-30 tokens/sec even on fast hardware. This is why Akunu only uses this path when grammar constraints are explicitly requested.

The Sampled Path (GPU-Driven)

The sampled decode path (decode_sampled) is fascinating because it achieves the same throughput as greedy decoding while sampling from the full probability distribution. The secret: Gumbel-max sampling on the GPU.2

The key insight is that argmax(logit + temperature * Gumbel_noise) is equivalent to sampling from Categorical(softmax(logit / temperature)). This means we can replace the entire CPU sampling pipeline with:

  1. Apply Gumbel noise to logits (GPU kernel)
  2. Argmax (GPU kernel)

No CPU roundtrip. No softmax. No random number generation on the CPU.

// The sampled_dispatch_table includes gumbel_temperature + argmax
DispatchTable& table = have_sampled_table
    ? state.sampled_dispatch_table
    : state.dispatch_table;

while (generated < max_tokens) {
    int n = std::min(chunk, remaining);
    state.device->write_buffer(state.scratch.token_ids, &next_token, 4);
    state.device->begin_encoding();
    state.device->encode_dispatch_table(&table, pos, n);
    state.device->end_encoding_sync();
    state.kv_cache.advance(n);
    pos += n;
    // Read results...
}

The sampled_dispatch_table is identical to the dispatch_table (greedy) except that it has a gumbel_topk_f16 kernel inserted before the argmax. The Gumbel kernel also handles top-k, top-p, min-p, and repetition penalty – all on the GPU.

The sampling parameters are patched into the kernel’s parameter buffer before the loop starts:

auto& gumbel_cmd = cmds[cmds.size() - 2];  // second-to-last command
memcpy(gumbel_cmd.param_bytes + 4, &temp, sizeof(float));
memcpy(gumbel_cmd.param_bytes + 12, &seed_base, sizeof(uint32_t));
memcpy(gumbel_cmd.param_bytes + 16, &top_k, sizeof(int32_t));
memcpy(gumbel_cmd.param_bytes + 20, &top_p, sizeof(float));

The position field in the Gumbel params is patched per-token by the dispatch table’s PATCH_POSITION mechanism, which ensures each token gets a unique seed for the Gumbel noise RNG.

The Speculative Path

Speculative decoding (decode_speculative) uses an n-gram predictor to guess multiple tokens ahead, then verifies them in a single batched forward pass:

auto drafts = state.predictor.predict();
int n_draft = std::min((int)drafts.size(), max_tokens - generated - 1);

uint32_t *chain_buf = (uint32_t *)state.device->buffer_contents(
    state.scratch.token_ids);
chain_buf[0] = next_token;
for (int i = 0; i < n_draft; i++)
    chain_buf[i + 1] = drafts[i];

state.device->begin_encoding();
state.device->encode_dispatch_table(&state.dispatch_table, pos, batch_size);
state.device->end_encoding_sync();

The batch of [1 + n_draft] tokens is processed in a single GPU submission using the chain decode mechanism (covered in the next chapter). The GPU produces an argmax token for each position. If the draft token at position i matches the argmax at position i-1, it was correctly predicted and we get a “free” token. The speculative path can generate 2-3 tokens per GPU submission when the n-gram predictor is accurate.

Stop Conditions

The decode loop checks several stop conditions:

  1. EOS token: state.tokenizer.is_eos(tok) – the model emits an end-of-sequence token.
  2. Max tokens reached: generated >= max_tokens.
  3. Callback cancellation: the token callback returns false.
  4. Grammar completion: grammar->xgrammar.is_terminated() || grammar->xgrammar.is_completed().

These checks happen after each token (or chunk of tokens for chain decode), ensuring responsive stopping behavior.

Statistics and Timing

The decode loop returns detailed statistics:

AkunuGenerationStats stats = {};
stats.prompt_tokens = n_prompt;
stats.prefill_time_ms = (float)prefill_ms;
stats.generated_tokens = generated;
stats.decode_time_ms = (float)decode_ms;
stats.decode_tokens_per_sec = (float)(generated * 1000.0 / decode_ms);
stats.prefill_tokens_per_sec = (float)(n_prompt * 1000.0 / prefill_ms);
return stats;

The decode_tokens_per_sec metric is the one that matters most for user experience – it determines how fast text appears on screen.

The Token-by-Token Flow

Let’s trace the lifecycle of a single token through the decode loop in the common case (greedy decode, no grammar):

Token lifecycle (greedy, no grammar):

  1. Write token_id to GPU buffer
          │
          ▼
  2. Encode dispatch table (N commands)
          │
          ▼
  3. GPU executes forward pass
          │
          ▼
  4. Read result token from buffer
          │
          ▼
  5. Stream to callback

Step 2 is the critical one: the dispatch table encodes the entire forward pass as a pre-compiled sequence of GPU commands. The next chapter covers how this dispatch table works and how chain decode amortizes the per-token overhead.

Streaming and Latency

From the user’s perspective, what matters is the time between successive tokens appearing on screen. This is the inter-token latency (ITL), and it is determined by:

ITL = GPU_forward_pass + CPU_token_read + callback_overhead

For chain decode with a chunk size of 64, the GPU processes 64 tokens in one submission, and then the CPU reads all 64 results at once. The effective ITL per token is:

effective_ITL = GPU_time_for_64_tokens / 64

However, the perceived latency for the first token in each chunk is higher because the user sees nothing until the entire chunk completes. This is the tradeoff of chain decode: higher throughput at the cost of chunkier streaming. In practice, with a chunk time of ~500ms for 64 tokens, the burst of tokens appears fast enough that users perceive smooth streaming.

The Dispatch Table Mechanism

The decode loop does not manually encode GPU commands for each token. Instead, it uses a pre-compiled dispatch table (DispatchTable): a flat array of DispatchCmd structs that describe every GPU dispatch needed for one forward pass.

struct DispatchCmd {
    Pipeline pso;
    Buffer buffers[MAX_BUFFERS];
    uint32_t offsets[MAX_BUFFERS];
    int buffer_count;
    uint8_t param_bytes[64];
    int param_size;
    // ...
    PatchType patch_type;
    int patch_offset_1;
    int patch_offset_2;
};

The table is built once during model initialization. During decode, the engine simply iterates the table, patching per-token fields (position, KV sequence length, token buffer offset) and dispatching. The encode_chain function does this:

for (int tok = 0; tok < count; tok++) {
    int pos = start_position + tok;
    for (int c = 0; c < n_cmds; c++) {
        const auto& cmd = cmds[c];
        device.set_pipeline(cmd.pso);
        // Set buffers with per-token offset patching
        // Set params with per-token position patching
        device.dispatch(cmd.grid, cmd.threadgroup);
    }
}

The hot/cold split is deliberate: profiling labels are stored in a separate labels vector so that the commands vector stays cache-dense during the hot path. Since DispatchCmd is a POD struct with fixed-size arrays (no heap allocations), iterating the commands vector generates predictable memory access patterns.

The patch types handle all per-token dynamic data:

| Patch Type | What Gets Patched | Used By |
|------------|-------------------|---------|
| PATCH_NONE | Nothing | Most commands |
| PATCH_TOKEN_OFFSET | buffers[0].offset = tok * 4 | Embedding lookup |
| PATCH_POSITION | Position field in params | RoPE, attention |
| PATCH_KV_SEQ_LEN | KV sequence length in params | Attention |
| PATCH_POS_AND_KV | Both position and KV length | Fused RoPE+KV |
| PATCH_ARGMAX_OUTPUT | buffers[1].offset = (tok+1) * 4 | Argmax output |

Double Buffering in Greedy Decode

The greedy path uses a double-buffering strategy for the first chunk:

bool first = true;
while (generated < max_tokens) {
    // ... prepare and encode dispatch table ...
    if (first) {
        state.device->end_encoding_sync();
        first = false;
    } else {
        state.device->end_encoding_async();
        state.device->wait();
    }
}

The first chunk uses end_encoding_sync() (blocks until GPU completes). Subsequent chunks use end_encoding_async() + wait(), which allows overlapping GPU execution with CPU processing of the previous chunk’s results. This is a subtle but important optimization: while the GPU is computing chunk N+1, the CPU is reading results from chunk N, decoding tokens to text, and calling the user’s callback.

Summary

The decode loop is the orchestrator of Akunu’s token generation. It:

  1. Runs prefill to populate the KV cache and get the first token.
  2. Selects one of four decode paths based on sampling config and grammar constraints.
  3. Manages the generate-stream-check loop until a stop condition is met.
  4. Collects timing statistics for performance reporting.

The path selection is the most impactful architectural decision: greedy and sampled paths can use chain decode (next chapter) for maximum throughput, grammar decode must go through the CPU for every token, and speculative decode trades prediction accuracy for throughput.

Token Text Decoding

A subtle but important detail: the callback receives decoded text, not just token IDs. The decode_token_text helper converts a token ID to a UTF-8 string:

const char *text = decode_token_text(state, next_token);
if (!callback(next_token, text, user_data))
    stopped = true;

This decoding happens on the CPU after each token (or chunk of tokens) is read from the GPU. For most tokenizers, this is a simple lookup into a vocabulary table, taking nanoseconds. However, multi-byte UTF-8 characters can span multiple tokens, and the tokenizer must handle partial characters gracefully.

The text is returned as a const char * pointing to an internal buffer that is valid until the next call to decode_token_text. This avoids allocation overhead in the hot path.

Error Handling and Edge Cases

The decode loop handles several edge cases that are easy to overlook:

Empty Prompts

When n_prompt = 0, the prefill loop body never executes, and next_token remains 0. The decode loop will immediately generate from token 0, which is typically a padding or BOS token. In practice, the caller always provides at least the BOS token.

Max Tokens = 0

If max_tokens = 0, the decode paths are never entered, and the function returns with generated = 0 (or 1 if the prefill’s first token was streamed). This is useful for prompt evaluation without generation.

KV Cache Exhaustion

The KV cache has a fixed maximum length. If pos + n would exceed the cache capacity, behavior depends on the cache implementation: ring-buffer caches wrap around, while linear caches simply fail. The decode loop itself does not check for cache exhaustion – that responsibility belongs to the caller.

Very Long Generations

For generations exceeding the chunk size (e.g., generating 10,000 tokens with chunk_size=128), the decode loop iterates about 79 times (⌈10000/128⌉), each time submitting a chunk to the GPU, waiting for completion, streaming results, and looping. The overhead of this outer loop is negligible compared to the GPU computation time.

Memory and State Management

The decode loop operates on ModelState, which bundles all GPU resources:

struct ModelState {
    Device *device;
    WeightProvider *weights;
    AkunuModelConfig config;
    ArchDescriptor arch;
    KVCache kv_cache;
    ScratchBuffers scratch;
    DispatchTable dispatch_table;
    DispatchTable sampled_dispatch_table;
    ChipConfig chip;
    Tokenizer tokenizer;
    // ...
};

The next_token and pos parameters are passed by reference and updated in-place. This allows the caller to resume generation from where it left off (e.g., for multi-turn conversations):

// First turn
stats1 = run_decode_loop(state, model, prompt1, n1, 0, max_tokens, ...);
// pos is now at start_pos + n1 + generated1

// Second turn (continue from where we left off)
stats2 = run_decode_loop(state, model, prompt2, n2, pos, max_tokens, ...);

The KV cache retains all previously computed K/V entries, so the second turn can attend to the first turn’s context without re-prefilling.

Profiling and Debugging

The dispatch table supports per-command labels for GPU profiling:

struct DispatchLabel {
    char text[48];
};

When Metal GPU capture is active, each dispatch in the command buffer carries a label like "L12.attention" or "L5.ffn.down_gemv", making it straightforward to identify performance hotspots in Instruments or Xcode’s GPU debugger.

The labels are stored in a cold parallel vector (DispatchTable::labels) separate from the hot command array (DispatchTable::commands), ensuring the profiling metadata does not pollute the cache during the decode hot path.



  1. XGrammar is a high-performance grammar engine from the TVM team. See: Dong, Y., et al. “XGrammar: Flexible and Efficient Structured Generation Engine for Large Language Models.” MLSys 2025 (arXiv:2411.15100, 2024). The bitmask approach is O(vocab_size) to apply but very cache-friendly. See https://arxiv.org/abs/2411.15100.

  2. Gumbel, E. J. “Statistical Theory of Extreme Values and Some Practical Applications.” National Bureau of Standards, 1954. The Gumbel-max trick is described in: Maddison, C.J., et al. “A* Sampling.” NeurIPS 2014 and Jang, E., et al. “Categorical Reparameterization with Gumbel-Softmax.” ICLR 2017. The key result: argmax(log(pi) + G_i) where G_i ~ Gumbel(0,1) samples from Categorical(pi). See https://arxiv.org/abs/1411.0030.