Greedy Decoding and Chain Decode

Greedy decoding is the simplest generation strategy: at each step, pick the token with the highest logit. No temperature, no sampling, no randomness – just argmax. Despite its simplicity, the implementation in Akunu is anything but trivial, because Akunu’s greedy path does not generate one token per GPU submission. Instead, it generates an entire chain of tokens in a single GPU command buffer.

This chapter covers decode_greedy in src/inference/decode_greedy.cpp, the dispatch table replay mechanism, and the chain decode technique that makes greedy generation remarkably fast.

What Makes Greedy Special

When temperature is zero, the entire sampling pipeline collapses to a single operation: argmax. There is no need for softmax, no random number generation, no top-k filtering. And crucially, the argmax result is deterministic – given the same input, you always get the same output.

This determinism enables a powerful optimization: if the GPU can compute argmax as the last step of the forward pass, it can immediately feed the result as the input to the next forward pass, without ever returning control to the CPU. This is chain decode.

The Chain Decode Concept

The idea is deceptively simple. Instead of:

CPU: write token → GPU: forward pass → CPU: read result → CPU: write next token → ...

We do:

CPU: write token → GPU: [forward pass → argmax → forward pass → argmax → ... × N] → CPU: read N results

The entire chain of N tokens is encoded into a single GPU command buffer. The GPU executes the full sequence without any CPU intervention.

[Figure: Traditional decode alternates GPU and CPU work, with one GPU roundtrip per token. Chain decode makes one GPU submission for N tokens (token 1 → token 2 → token 3 → ... → token N), after which the CPU reads all results.]

The benefit is eliminating the CPU-GPU synchronization overhead that occurs between tokens. On Apple Silicon, each end_encoding_sync() call costs roughly 20-50 microseconds in Metal command buffer overhead. At 50 tokens/sec, this overhead is negligible. But chain decode also eliminates the command buffer creation overhead, which can be 100-300 microseconds per submission. For a chunk of 64 tokens, we save ~63 command buffer creations.

The decode_greedy Implementation

Let’s read the actual code:

int decode_greedy(ModelState& state, akunu_model_t model,
                  uint32_t& next_token, int& pos, int max_tokens,
                  akunu_token_callback callback, void *user_data) {
    int generated = 0;
    int chunk_size = state.chip.chain_decode_chunk;

    bool first = true;
    while (generated < max_tokens) {
        int remaining = max_tokens - generated;
        int n = std::min(chunk_size, remaining);

        state.device->write_buffer(state.scratch.token_ids, &next_token, 4);
        state.device->begin_encoding();
        state.device->encode_dispatch_table(&state.dispatch_table, pos, n);

        if (first) {
            state.device->end_encoding_sync();
            first = false;
        } else {
            state.device->end_encoding_async();
            state.device->wait();
        }
        state.kv_cache.advance(n);
        pos += n;

        uint32_t *tokens = (uint32_t *)state.device->buffer_contents(
            state.scratch.token_ids);
        for (int i = 0; i < n; i++) {
            uint32_t tok = tokens[i + 1];
            generated++;
            if (state.tokenizer.is_eos(tok))
                return generated;
            if (callback) {
                const char *text = decode_token_text(state, tok);
                if (!callback(tok, text, user_data))
                    return generated;
            }
            next_token = tok;
        }
    }
    return generated;
}

There is a lot packed into these roughly 40 lines. Let’s unpack each piece.

Chunk Size Selection

int chunk_size = state.chip.chain_decode_chunk;

The chunk size comes from ChipConfig and varies by hardware:

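| Hardware Tier  | GPU Cores | Family | Chunk Size |
|----------------|-----------|--------|------------|
| M1/M2/M3 Base  | < 16      | < 9    | 64         |
| M3 Pro         | >= 16     | < 9    | 96         |
| M4 Base        | < 16      | >= 9   | 128        |
| M4 Pro         | >= 16     | >= 9   | 128        |
| M-series Max   | >= 30     | any    | 128        |
| M-series Ultra | >= 60     | any    | 128        |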

The M4 family gets larger chunks because its GPU command processor is more efficient at handling long command buffers, and its memory subsystem has better bandwidth for the interleaved read-write patterns of chain decode. [1]
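The table can be read as a small decision function. The sketch below is only an illustration of that reading: pick_chain_decode_chunk and its parameters are hypothetical names, and the rule order (core count first, then family) is an assumption inferred from the table rather than taken from Akunu's ChipConfig code.

```cpp
// Hypothetical helper mirroring the chunk-size table; not Akunu's actual code.
// gpu_cores: number of GPU cores; gpu_family: Apple GPU family number.
inline int pick_chain_decode_chunk(int gpu_cores, int gpu_family) {
    if (gpu_cores >= 30) return 128;  // Max and Ultra tiers, any family
    if (gpu_family >= 9) return 128;  // M4 Base and M4 Pro
    if (gpu_cores >= 16) return 96;   // M3 Pro tier
    return 64;                        // M1/M2/M3 Base
}
```

Under these assumptions, a 10-core family-8 chip maps to 64 and an 18-core family-8 chip maps to 96, matching the first two rows.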

Writing the Input Token

state.device->write_buffer(state.scratch.token_ids, &next_token, 4);

Only 4 bytes are written: the single token ID that starts the chain. The rest of the token_ids buffer will be filled by the GPU as each step’s argmax writes its result.

Encoding the Dispatch Table

state.device->begin_encoding();
state.device->encode_dispatch_table(&state.dispatch_table, pos, n);

This is the heart of chain decode. encode_dispatch_table calls encode_chain, which replays the dispatch table n times with position patching:

inline void encode_chain(Device& device, const DispatchTable& table,
                         int start_position, int count) {
    const auto& cmds = table.commands;
    const int n_cmds = (int)cmds.size();

    for (int tok = 0; tok < count; tok++) {
        int pos = start_position + tok;
        for (int c = 0; c < n_cmds; c++) {
            const auto& cmd = cmds[c];
            device.set_pipeline(cmd.pso);
            // ... set buffers, params, dispatch ...
        }
    }
}

For a 7B model with ~60 dispatches per token and a chunk of 64 tokens, this encodes 60 * 64 = 3840 GPU dispatches into a single command buffer. The Metal runtime batches these efficiently.
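To make the replay structure concrete, here is a minimal sketch that records dispatches instead of encoding them. MockCmd, MockDevice, RecordedDispatch, and encode_chain_sketch are hypothetical stand-ins; the real Device and DispatchCmd carry pipelines, buffers, and parameters that are omitted here.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical stand-ins for illustration; the real Device and DispatchCmd differ.
struct MockCmd { int id; };

struct RecordedDispatch {
    int cmd_id;    // which table entry was replayed
    int position;  // absolute position used for this replay
};

struct MockDevice {
    std::vector<RecordedDispatch> encoded;
    void dispatch(const MockCmd& cmd, int position) {
        encoded.push_back({cmd.id, position});
    }
};

// Replay the whole table once per token, advancing the position each time.
inline void encode_chain_sketch(MockDevice& device,
                                const std::vector<MockCmd>& table,
                                int start_position, int count) {
    for (int tok = 0; tok < count; tok++) {
        const int pos = start_position + tok;
        for (const MockCmd& cmd : table) {
            device.dispatch(cmd, pos);
        }
    }
}
```

Replaying a 60-entry table 64 times records 3840 dispatches, and the position attached to each replay increases by one per token.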

The Replicated Dispatch Pattern

Each iteration of the outer loop (each token) replays the same dispatch table, but with patched fields. The per-token patching ensures each forward pass uses the correct:

  • Token ID offset: The embedding lookup reads from token_ids[tok] instead of token_ids[0].
  • Position: RoPE uses the correct absolute position for this token.
  • KV sequence length: Attention knows how many KV entries are valid.
  • Argmax output offset: The argmax result writes to token_ids[tok + 1].

Here is how patching works for the most common types:

// Token offset patching (embedding lookup)
if (cmd.patch_type == DispatchCmd::PATCH_TOKEN_OFFSET && b == 0) {
    offset = tok * 4;  // byte offset to token_ids[tok]
}

// Argmax output patching
if (cmd.patch_type == DispatchCmd::PATCH_ARGMAX_OUTPUT && b == 1) {
    offset = (tok + 1) * 4;  // write result to token_ids[tok+1]
}

// Position patching (RoPE, attention)
if (cmd.patch_type == DispatchCmd::PATCH_POSITION) {
    uint32_t pos_val = (uint32_t)pos;
    memcpy(patched + cmd.patch_offset_1, &pos_val, 4);
}

// Combined position + KV length patching
if (cmd.patch_type == DispatchCmd::PATCH_POS_AND_KV) {
    uint32_t pos_val = (uint32_t)pos;
    uint32_t kv_val = (uint32_t)(pos + 1);
    memcpy(patched + cmd.patch_offset_1, &pos_val, 4);
    memcpy(patched + cmd.patch_offset_2, &kv_val, 4);
}
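The fragments above write into a patched buffer rather than into the command itself. A self-contained sketch of that pattern follows, assuming a simplified 16-byte parameter block; PatchSketchCmd and patch_pos_and_kv are hypothetical names, and the real DispatchCmd layout is not shown in this chapter.

```cpp
#include <cstdint>
#include <cstring>

// Hypothetical, simplified command layout for illustration only.
struct PatchSketchCmd {
    uint8_t params[16];       // kernel parameter block captured at build time
    uint32_t patch_offset_1;  // byte offset of the position field
    uint32_t patch_offset_2;  // byte offset of the KV length field
};

// Copy the static parameter block, then overwrite the two dynamic fields.
// `patched` must have room for sizeof(cmd.params) bytes.
inline void patch_pos_and_kv(const PatchSketchCmd& cmd, int pos, uint8_t* patched) {
    std::memcpy(patched, cmd.params, sizeof(cmd.params));
    const uint32_t pos_val = static_cast<uint32_t>(pos);
    const uint32_t kv_val = static_cast<uint32_t>(pos + 1);  // entries visible to attention
    std::memcpy(patched + cmd.patch_offset_1, &pos_val, sizeof(pos_val));
    std::memcpy(patched + cmd.patch_offset_2, &kv_val, sizeof(kv_val));
}
```

Because the command is taken by const reference and only the copy is modified, the table entry stays untouched across tokens.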

The token_ids buffer acts as a conveyor belt: token_ids[0] holds the input to the first forward pass, token_ids[1] gets the argmax output of the first pass (and becomes the input to the second pass), and so on.

token_ids buffer:
┌──────┬──────┬──────┬──────┬──────┬──────┐
│  t0  │  t1  │  t2  │  t3  │  t4  │ ...  │
│(input)│(out1)│(out2)│(out3)│(out4)│      │
│      │(in2) │(in3) │(in4) │(in5) │      │
└──────┴──────┴──────┴──────┴──────┴──────┘
   ↑                                   ↑
   CPU writes                    CPU reads all
   one token                     after GPU done
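The conveyor belt can be simulated on the CPU to check that chaining produces the same tokens as stepping one token at a time. In this sketch, toy_next_token is a made-up deterministic function standing in for "forward pass + argmax"; run_chain and run_stepwise are hypothetical names as well.

```cpp
#include <cstdint>
#include <vector>

// Toy stand-in for "forward pass + argmax": any deterministic function of
// (input token, position) is enough to demonstrate the data flow.
inline uint32_t toy_next_token(uint32_t token, int pos) {
    return (token * 31u + static_cast<uint32_t>(pos) * 7u + 1u) % 1000u;
}

// Chain of n steps: step `tok` reads slot tok and writes slot tok + 1,
// as the patched embedding and argmax dispatches do.
// token_ids must hold at least n + 1 entries.
inline void run_chain(std::vector<uint32_t>& token_ids, int start_pos, int n) {
    for (int tok = 0; tok < n; tok++) {
        token_ids[tok + 1] = toy_next_token(token_ids[tok], start_pos + tok);
    }
}

// Reference: one step per "submission", with the CPU carrying the token across.
inline std::vector<uint32_t> run_stepwise(uint32_t first, int start_pos, int n) {
    std::vector<uint32_t> out;
    uint32_t cur = first;
    for (int i = 0; i < n; i++) {
        cur = toy_next_token(cur, start_pos + i);
        out.push_back(cur);
    }
    return out;
}
```

After run_chain, slots 1..n of token_ids hold the same sequence that run_stepwise returns, and slot 0 still holds the original input.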

Double Buffering

if (first) {
    state.device->end_encoding_sync();
    first = false;
} else {
    state.device->end_encoding_async();
    state.device->wait();
}

The first chunk uses synchronous execution because there is nothing to overlap with. Subsequent chunks use the asynchronous path: end_encoding_async() submits the command buffer and returns immediately, then wait() blocks until the GPU signals completion.

Note that wait() directly follows the submit, so the CPU reads a chunk’s results only after that chunk has finished; the loop as written does not process tokens while the GPU is running. The gain is on the submission side, where the Metal driver can start preparing the command buffer for submission without the CPU blocking inside the submit call. In practice, this saves 10-30 microseconds per chunk – small, but it adds up over thousands of chunks.

Reading Results

uint32_t *tokens = (uint32_t *)state.device->buffer_contents(
    state.scratch.token_ids);
for (int i = 0; i < n; i++) {
    uint32_t tok = tokens[i + 1];  // offset by 1: tokens[0] was input
    generated++;
    if (state.tokenizer.is_eos(tok))
        return generated;
    if (callback) {
        const char *text = decode_token_text(state, tok);
        if (!callback(tok, text, user_data))
            return generated;
    }
    next_token = tok;
}

After the GPU completes, the CPU reads all n tokens at once from the token_ids buffer. For each token:

  1. Check for EOS: if the model generated an end-of-sequence token, stop immediately.
  2. Invoke the callback: convert the token to text and stream it to the user.
  3. Update next_token: the last non-EOS token becomes the seed for the next chunk.

Note that EOS checking happens after the GPU has finished the entire chunk. If the model generated EOS at position 5 in a 64-token chunk, positions 6-63 were computed unnecessarily. This is the price of chain decode: you cannot stop mid-chain. However, the wasted work is bounded by one chunk (typically 64-128 tokens), and the throughput gain from chain decode far outweighs this cost.
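The amount of discarded work is easy to compute from a chunk's outputs. The helper below is a hypothetical illustration (wasted_after_eos is not part of Akunu); it counts the forward passes that ran after the step that produced EOS.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical helper: given the outputs of one chunk and an EOS id, return how
// many trailing forward passes were computed after the EOS step.
inline int wasted_after_eos(const std::vector<uint32_t>& chunk_outputs, uint32_t eos_id) {
    const int n = static_cast<int>(chunk_outputs.size());
    for (int i = 0; i < n; i++) {
        if (chunk_outputs[i] == eos_id) {
            return n - (i + 1);  // steps i+1 .. n-1 ran but are discarded
        }
    }
    return 0;  // no EOS in this chunk, nothing wasted
}
```

For a 64-token chunk with EOS at index 5, this returns 58, which corresponds to positions 6-63 in the example above.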

When Does the Chain Break?

The chain stops (no further chunk is submitted) in three cases:

  1. EOS token: The CPU detects EOS when reading results and stops generating.
  2. Callback returns false: The user requested cancellation.
  3. Max tokens reached: The generation limit was hit.

Importantly, the chain does not break for stop sequences or for multi-token EOS patterns. Those are detected at the application level after the chain completes.

Performance Impact of Chain Decode

Let’s quantify the benefit. Consider generating 200 tokens on an M4 Pro with a 7B Q4_0 model:

Without chain decode (one submission per token):

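| Component               | Per-Token Cost | Total (200 tokens) |
|-------------------------|----------------|--------------------|
| Forward pass GPU time   | ~8ms           | 1600ms             |
| Command buffer overhead | ~0.2ms         | 40ms               |
| CPU-GPU sync            | ~0.05ms        | 10ms               |
| Total                   | ~8.25ms        | 1650ms             |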

With chain decode (chunk=128, 2 submissions):

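| Component               | Per-Chunk Cost        | Total (2 chunks) |
|-------------------------|-----------------------|------------------|
| Forward pass GPU time   | ~128 * 8ms = 1024ms   | 1600ms           |
| Command buffer overhead | ~0.2ms                | 0.4ms            |
| CPU-GPU sync            | ~0.05ms               | 0.1ms            |
| Total                   | ~1024ms               | 1600.5ms         |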

The GPU time is identical – the forward pass takes the same time regardless of whether it is in a chain or standalone. But the overhead drops from 50ms to 0.5ms: a 100x reduction. At 8ms per token, this is a ~3% throughput improvement, which is modest but free.

The real benefit emerges at faster per-token speeds. On an M4 Max with a 3B model doing 3ms/token, the same ~0.25ms of per-token overhead would be about 8% without chaining and <0.1% with chaining.

The Dispatch Table in Detail

The dispatch table for a typical 7B Llama model contains approximately 60 commands per token:

| Command Group     | Count | Description                       |
|-------------------|-------|-----------------------------------|
| Embedding         | 1     | Token ID -> hidden state          |
| Layer loop (x32)  | ~48   | 32 layers * ~1.5 cmds each        |
| RMSNorm           | 32    | One per layer (fused)             |
| RoPE + KV write   | 32    | Fused per layer                   |
| Attention         | 32    | Flash attention decode            |
| O projection      | 32    | GEMV per layer                    |
| SiLU + down       | 32    | Fused activation GEMV or separate |
| Output norm       | 1     | Final RMSNorm                     |
| Logit GEMV        | 1     | Hidden state -> logits            |
| Argmax            | 1     | Logits -> token ID                |

Wait, that is more than 60. The key is that many operations are fused. For example, a layer with fused SiLU GEMV (gemv_q4_0_silu) replaces three dispatches (gate GEMV, up GEMV, SiLU element-wise) with a single dispatch. Residual + RMSNorm is another fusion. With full fusion, a layer can be as few as 4 dispatches:

  1. Fused residual + RMSNorm + QKV GEMV
  2. RoPE + KV write
  3. Attention
  4. Fused SiLU GEMV (gate * up * down)

In practice, Akunu achieves 4-6 dispatches per layer depending on the model architecture and quantization format.

Chain Decode vs. Batched Decode

It is worth distinguishing chain decode from batched/continuous batching used by server-side inference engines (e.g., vLLM, TensorRT-LLM). Those systems process multiple independent sequences simultaneously, sharing the weight reads across sequences. Chain decode processes a single sequence, but chains multiple sequential tokens into one GPU submission.

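| Feature               | Chain Decode (Akunu)         | Continuous Batching (vLLM)    |
|-----------------------|------------------------------|-------------------------------|
| Sequences             | 1                            | Many                          |
| Tokens per submission | N (sequential)               | 1 per sequence, B sequences   |
| Weight reuse          | Across N tokens (sequential) | Across B sequences (parallel) |
| Use case              | On-device inference          | Server inference              |
| KV cache              | Single stream                | Multiple streams              |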

Chain decode is uniquely suited to single-user on-device inference, where there is only one sequence to generate but we want to minimize CPU-GPU synchronization overhead.

How Sampled Decode Achieves Chain Performance

The sampled decode path (decode_sampled) keeps the same chained structure as greedy, and therefore nearly the same throughput, by using the Gumbel-max trick on the GPU. Instead of a plain argmax at the end of each forward pass, it adds Gumbel noise to the logits and then takes the argmax, which is mathematically equivalent to sampling from the temperature-scaled distribution. The sampled dispatch table is identical to the greedy one except:

  1. A gumbel_topk_f16 kernel is inserted before the argmax
  2. The Gumbel kernel has a PATCH_POSITION so each token gets unique noise

This means the chain is unbroken – no CPU roundtrip is needed for random number generation or softmax computation. The RNG is a PCG hash function that takes the position as input, so each token in the chain gets a different noise sample despite being computed in a single GPU submission.

// In the Gumbel kernel:
uint element_seed = (params.position + params.seed_offset) * 2654435761u;
float u = pcg_float(element_seed + i);
float gumbel = -log(-log(u));
logits[i] = half(val + temp * gumbel);

The params.position is different for each token in the chain (patched by PATCH_POSITION), and the + i term in the pcg_float argument varies the hash input per vocabulary element. Together, these give each (position, vocabulary element) pair its own Gumbel noise sample. [2]
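A CPU version of the same idea is sketched below. The hash is a widely used PCG-style 32-bit hash and is an assumption here; it is not necessarily what sits behind Akunu's pcg_float. The names pcg_hash, hash_to_unit_float, and gumbel_max_sample are hypothetical, and the sketch skips the top-k filtering that the kernel name gumbel_topk_f16 suggests.

```cpp
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// A commonly used PCG-style 32-bit hash; an assumption for this sketch.
inline uint32_t pcg_hash(uint32_t input) {
    uint32_t state = input * 747796405u + 2891336453u;
    uint32_t word = ((state >> ((state >> 28u) + 4u)) ^ state) * 277803737u;
    return (word >> 22u) ^ word;
}

// Map a hash to a float strictly inside (0, 1) so both logarithms stay finite.
inline float hash_to_unit_float(uint32_t h) {
    return (static_cast<float>(h >> 9) + 0.5f) / 8388608.0f;  // 23 random bits
}

// Gumbel-max sampling: argmax over logits[i] + temp * gumbel_i.
// The seed depends on the position, so each token in a chain gets its own noise.
inline int gumbel_max_sample(const std::vector<float>& logits, float temp,
                             uint32_t position, uint32_t seed_offset) {
    const uint32_t element_seed = (position + seed_offset) * 2654435761u;
    int best = 0;
    float best_val = -INFINITY;
    for (std::size_t i = 0; i < logits.size(); i++) {
        const float u =
            hash_to_unit_float(pcg_hash(element_seed + static_cast<uint32_t>(i)));
        const float gumbel = -std::log(-std::log(u));
        const float val = logits[i] + temp * gumbel;
        if (val > best_val) {
            best_val = val;
            best = static_cast<int>(i);
        }
    }
    return best;
}
```

With temp set to 0 the noise term vanishes and the function reduces to a plain argmax, which is why the greedy and sampled paths can share one chain structure. The result is also fully determined by (position, seed_offset), so a chunk can be replayed reproducibly.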

Understanding the Data Flow in Detail

Let’s trace exactly what happens inside the GPU during a chain of 3 tokens. Assume a simple model with 1 layer, dim=4096, vocab_size=32000:

Chain of 3: token_ids = [42, ?, ?, ?]

TOKEN 0 (position P):
  Embedding: read token_ids[0]=42 → hidden[4096]
  RMSNorm: hidden → normed
  Q GEMV: normed @ Wq → q[4096]
  K GEMV: normed @ Wk → k[1024]    (written to KV cache at pos P)
  V GEMV: normed @ Wv → v[1024]    (written to KV cache at pos P)
  RoPE: rotate q, k in-place
  Attention: q @ K_cache^T → scores → softmax → @ V_cache → attn_out
  O GEMV: attn_out @ Wo → residual
  ... (FFN) ...
  Logit GEMV: final_hidden @ W_logit → logits[32000]
  Argmax: logits → token_ids[1] = 7891  (writes to slot 1)

TOKEN 1 (position P+1):
  Embedding: read token_ids[1]=7891 → hidden[4096]   (reads what token 0 wrote!)
  ... (same operations, but position is P+1, attention sees P+2 KV entries) ...
  Argmax: logits → token_ids[2] = 512   (writes to slot 2)

TOKEN 2 (position P+2):
  Embedding: read token_ids[2]=512 → hidden[4096]
  ... (same operations, position P+2, attention sees P+3 KV entries) ...
  Argmax: logits → token_ids[3] = 1044  (writes to slot 3)

CPU reads: token_ids = [42, 7891, 512, 1044]
Stream tokens 7891, 512, 1044 to callback.

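The critical data dependency is between the argmax write and the next token’s embedding read. Within a single command buffer, Metal’s default serial dispatch mode orders compute dispatches so that a later dispatch observes an earlier dispatch’s writes to the same hazard-tracked buffer. Token 1’s embedding dispatch therefore sees the value written by token 0’s argmax dispatch, and no explicit synchronization is needed as long as the encoder is not switched to concurrent dispatch.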

Edge Cases and Robustness

Buffer Sizing

The token_ids buffer must be large enough for chunk_size + 1 entries (input token + N output tokens). The speculative path has an additional constraint:

int max_batch = (int)(state.scratch.token_ids.size / sizeof(uint32_t)) - 1;
if (n_draft + 1 > max_batch)
    n_draft = max_batch - 1;
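The same capacity rule can be applied to the greedy path. The guard below is a hypothetical sketch (clamp_chain_length is not shown in the source); it assumes the only constraints are the configured chunk size, the remaining token budget, and the n + 1 slots that a chain of n tokens needs.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>

// Hypothetical guard: a chain of n tokens needs n + 1 slots (one input plus
// n argmax outputs), so cap n by the buffer's capacity.
inline int clamp_chain_length(int chunk_size, int remaining, std::size_t token_ids_bytes) {
    const int slots = static_cast<int>(token_ids_bytes / sizeof(uint32_t));
    const int max_chain = std::max(slots - 1, 0);
    return std::min({chunk_size, remaining, max_chain});
}
```

With a 65-slot buffer, a configured chunk of 128 is reduced to 64, leaving slot 0 for the input token.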

KV Cache Advancement

After each chunk, the KV cache advances by the full chunk size:

state.kv_cache.advance(n);
pos += n;

This is correct even if EOS was generated mid-chunk, because the KV entries for positions after EOS are simply ignored – they will never be queried by subsequent attention operations (there won’t be any).

Position Overflow

The position counter uses int, which limits sequences to ~2 billion tokens. In practice, the KV cache size (typically 4K-128K) is the binding constraint, not the position counter.

Summary

Greedy decoding in Akunu is deceptively simple on the surface but architecturally sophisticated:

  1. Chain decode processes multiple sequential tokens in a single GPU command buffer, eliminating per-token CPU-GPU synchronization overhead.
  2. Dispatch table replay with per-token patching enables the chain without recompiling or rebuilding GPU commands.
  3. The token_ids conveyor belt passes each token’s argmax result as the next token’s input, all within GPU memory.
  4. Asynchronous submission after the first chunk trims per-chunk submission overhead.
  5. Hardware-tuned chunk sizes balance throughput against streaming latency.

The same chain decode mechanism powers sampled decode (via Gumbel-max) and speculative decode (via batched verification), making it the foundational building block of Akunu’s generation pipeline.

Deep Dive: The Dispatch Table Build

To understand chain decode fully, we need to understand how the dispatch table is constructed. During model initialization, Akunu builds a DispatchTable that represents the complete forward pass for a single token. Here is the conceptual structure for a 32-layer Llama model:

Command 0:  embedding_lookup (PATCH_TOKEN_OFFSET)
Command 1:  rmsnorm_f16 (layer 0 attn norm)
Command 2:  gemv_q4_0 (Q projection)
Command 3:  gemv_q4_0 (K projection)
Command 4:  gemv_q4_0 (V projection)
Command 5:  rope_kv_write_f16 (PATCH_POS_AND_KV)
Command 6:  flash_attention_decode_fast_f16 (PATCH_KV_SEQ_LEN)
Command 7:  gemv_q4_0 (O projection)
Command 8:  residual_rmsnorm_f16 (layer 0 FFN norm)
Command 9:  gemv_q4_0_silu (fused gate*up*down)
  ... repeat for layers 1-31 (N = total command count) ...
Command N-3: residual_rmsnorm_f16 (output norm)
Command N-2: gemv_q4_0 (logit projection)
Command N-1: argmax_f16 (PATCH_ARGMAX_OUTPUT)

Each command is a fixed-size DispatchCmd struct (no heap allocations), and the full table is a contiguous vector, so replaying it is a linear, cache-friendly walk. The table is built once and never modified during generation – only the patched offsets and parameter values written into each encoded replay differ from token to token.

Hot/Cold Data Split

The DispatchTable uses a hot/cold split for profiling data:

struct DispatchTable {
    std::vector<DispatchCmd> commands;     // HOT: iterated every token
    std::vector<DispatchLabel> labels;     // COLD: only used during profiling
};

During generation, only commands is accessed. The labels vector (with 48-byte strings per command) is never touched unless a GPU profiler is attached. This prevents the labels from evicting hot command data from the CPU cache.

Buffer Bindings Are Static

A key property of the dispatch table: all buffer bindings are static (fixed at build time). The weight buffers, scratch buffers, and KV cache buffers are allocated during model init and never change. The only per-token dynamic data is:

  1. Which offset within a buffer to use (patched via PATCH_TOKEN_OFFSET and PATCH_ARGMAX_OUTPUT)
  2. Which scalar parameter values to use (patched via PATCH_POSITION, PATCH_KV_SEQ_LEN, PATCH_POS_AND_KV)

This means the Metal runtime can reuse pipeline state objects (PSOs) and buffer bindings across tokens, minimizing the command encoding overhead.

Practical Considerations for Chain Decode

Streaming Latency vs. Throughput

Chain decode introduces a fundamental tension: larger chunks give higher throughput (less overhead per token) but worse streaming latency (the user sees nothing until the chunk completes). Here is the tradeoff:

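| Chunk Size   | Overhead Savings | Time to First Token in Chunk | Tokens Buffered |
|--------------|------------------|------------------------------|-----------------|
| 1 (no chain) | 0% (baseline)    | ~8ms                         | 0               |
| 16           | ~94%             | ~128ms                       | 16              |
| 64           | ~98.5%           | ~512ms                       | 64              |
| 128          | ~99.2%           | ~1024ms                      | 128             |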

At chunk_size=128, the user waits up to 1 second before seeing a burst of 128 tokens appear nearly simultaneously. Whether this is acceptable depends on the application: for interactive chat, chunk_size=64 provides a good balance; for batch processing, chunk_size=128 maximizes throughput.

Interaction with Stop Sequences

Modern LLM applications often use stop sequences (e.g., "\n\nHuman:") to terminate generation. With chain decode, stop sequence detection happens after the chunk completes, not during. If the model generates the stop sequence at token 10 of a 64-token chunk, tokens 11-64 are wasted computation.

However, the wasted work is bounded: at most one chunk’s worth of tokens. Since the overhead savings from chain decode (eliminating 63 command buffer creations per chunk) far exceed the cost of a few wasted tokens, the net benefit is strongly positive.

Memory Bandwidth During Chain Decode

During a chain of N tokens, the GPU reads the full model weights N times (once per token). However, Apple Silicon’s SLC (System Level Cache) can cache a portion of the weights across tokens. For a 7B Q4_0 model (~3.5 GB), the SLC on different chips caches:

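| Chip     | SLC Size | % of Model Cached | Effective Bandwidth Boost |
|----------|----------|-------------------|---------------------------|
| M4 Base  | 16 MB    | 0.5%              | Negligible                |
| M4 Pro   | 32 MB    | 0.9%              | Small                     |
| M4 Max   | 48 MB    | 1.4%              | Moderate                  |
| M4 Ultra | 96 MB    | 2.7%              | Moderate                  |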

Even at 2.7% cache hit rate, the SLC provides measurable benefit because the cached portion includes the hot first-layer weights and norm parameters that are accessed every token. The threadgroup swizzling in GEMM kernels (covered in the GEMM chapter) is designed to maximize this SLC reuse.

For smaller models (1-3B) that partially fit in the SLC, the benefit is much larger. A 1B model at Q4_0 is ~500 MB, and a 96 MB SLC can cache ~19% of the weights, providing a meaningful bandwidth amplification.

Correctness of Chain Decode

A natural question: is chain decode mathematically equivalent to single-token decode? Yes, because:

  1. The dispatch table encodes the same forward pass regardless of chaining.
  2. Per-token patching ensures correct positions and KV cache lengths.
  3. The argmax writes to the correct output slot, and the next token reads from the previous slot.
  4. There are no data hazards: the only state carried from one forward pass to the next is the token_ids slot and the KV cache, both written by earlier dispatches in the chain, and the GPU’s command execution model guarantees sequential ordering within a command buffer.

The only difference is that stop conditions (EOS, callback cancellation) are checked after the chunk rather than after each token. This does not affect the generated text – it only affects how quickly the engine responds to stop conditions.



  1. Apple. “Apple M4 chip.” apple.com, 2024. The M4’s GPU command processor improvements include reduced dispatch latency and better utilization of the Apple GPU’s tile-based deferred rendering architecture for compute workloads. See https://www.apple.com/newsroom/2024/05/apple-introduces-m4-chip/.

  2. Maddison, C.J., Tarlow, D., and Minka, T. “A* Sampling.” NeurIPS 2014. See https://arxiv.org/abs/1411.0030. The Gumbel-max trick provides an exact sample from the categorical distribution without explicitly computing the softmax or CDF. The PCG hash function (O’Neill, M. “PCG: A Family of Simple Fast Space-Efficient Statistically Good Algorithms for Random Number Generation.” Harvey Mudd College Technical Report HMC-CS-2014-0905) provides high-quality pseudorandomness with minimal state.