Greedy Decoding and Chain Decode
Greedy decoding is the simplest generation strategy: at each step, pick the token with the highest logit. No temperature, no sampling, no randomness – just argmax. Despite its simplicity, the implementation in Akunu is anything but trivial, because Akunu’s greedy path does not generate one token per GPU submission. Instead, it generates an entire chain of tokens in a single GPU command buffer.
This chapter covers decode_greedy in src/inference/decode_greedy.cpp, the dispatch table replay mechanism, and the chain decode technique that makes greedy generation remarkably fast.
What Makes Greedy Special
When temperature is zero, the entire sampling pipeline collapses to a single operation: argmax. There is no need for softmax, no random number generation, no top-k filtering. And crucially, the argmax result is deterministic – given the same input, you always get the same output.
This determinism enables a powerful optimization: if the GPU can compute argmax as the last step of the forward pass, it can immediately feed the result as the input to the next forward pass, without ever returning control to the CPU. This is chain decode.
The Chain Decode Concept
The idea is deceptively simple. Instead of:
CPU: write token → GPU: forward pass → CPU: read result → CPU: write next token → ...
We do:
CPU: write token → GPU: [forward pass → argmax → forward pass → argmax → ... × N] → CPU: read N results
The entire chain of N tokens is encoded into a single GPU command buffer. The GPU executes the full sequence without any CPU intervention.
The benefit is eliminating the CPU-GPU synchronization that otherwise happens between tokens. On Apple Silicon, each end_encoding_sync() call costs roughly 20-50 microseconds of Metal command buffer overhead; at 50 tokens/sec (20ms per token), that alone is negligible. But chain decode also eliminates the per-submission cost of creating a command buffer, which can be 100-300 microseconds. For a chunk of 64 tokens, we save ~63 command buffer creations and ~63 synchronizations.
The decode_greedy Implementation
Let’s read the actual code:
```cpp
int decode_greedy(ModelState& state, akunu_model_t model,
                  uint32_t& next_token, int& pos, int max_tokens,
                  akunu_token_callback callback, void *user_data) {
    int generated = 0;
    int chunk_size = state.chip.chain_decode_chunk;  // hardware-tuned chain length
    bool first = true;

    while (generated < max_tokens) {
        int remaining = max_tokens - generated;
        int n = std::min(chunk_size, remaining);

        // Seed the chain: the CPU writes only token_ids[0] (one 4-byte token ID).
        state.device->write_buffer(state.scratch.token_ids, &next_token, 4);

        // Replay the dispatch table n times into a single command buffer.
        state.device->begin_encoding();
        state.device->encode_dispatch_table(&state.dispatch_table, pos, n);
        if (first) {
            state.device->end_encoding_sync();
            first = false;
        } else {
            state.device->end_encoding_async();
            state.device->wait();
        }

        // The GPU processed n positions: advance the KV cache and position.
        state.kv_cache.advance(n);
        pos += n;

        // token_ids[1..n] now hold this chunk's argmax outputs.
        uint32_t *tokens = (uint32_t *)state.device->buffer_contents(
            state.scratch.token_ids);
        for (int i = 0; i < n; i++) {
            uint32_t tok = tokens[i + 1];
            generated++;
            if (state.tokenizer.is_eos(tok))
                return generated;
            if (callback) {
                const char *text = decode_token_text(state, tok);
                if (!callback(tok, text, user_data))
                    return generated;  // caller requested cancellation
            }
            next_token = tok;  // seed for the next chunk
        }
    }
    return generated;
}
```
There is a lot packed into these ~40 lines. Let’s unpack each piece.
Chunk Size Selection
```cpp
int chunk_size = state.chip.chain_decode_chunk;
```
The chunk size comes from ChipConfig and varies by hardware:
| Hardware Tier | GPU Cores | GPU Family | Chunk Size |
|---|---|---|---|
| M1/M2/M3 Base | < 16 | < 9 | 64 |
| M3 Pro | >= 16 | < 9 | 96 |
| M4 Base | < 16 | >= 9 | 128 |
| M4 Pro | >= 16 | >= 9 | 128 |
| M-series Max | >= 30 | any | 128 |
| M-series Ultra | >= 60 | any | 128 |
The M4 family gets larger chunks because its GPU command processor is more efficient at handling long command buffers, and its memory subsystem has better bandwidth for the interleaved read-write patterns of chain decode.[1]
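Read as a decision procedure, the table amounts to a few comparisons. The sketch below is hypothetical: select_chain_chunk and its integer inputs are not Akunu's ChipConfig API, and the order of the checks is an assumption chosen so that every row of the table comes out as listed.

```cpp
#include <cassert>

// Hypothetical helper: maps GPU core count and GPU family to a chain-decode
// chunk size, reproducing the tier table above. The check order is an
// assumption, not Akunu's actual ChipConfig logic.
int select_chain_chunk(int gpu_cores, int gpu_family) {
    assert(gpu_cores > 0);
    if (gpu_cores >= 30) return 128;  // Max and Ultra tiers, any family
    if (gpu_family >= 9) return 128;  // M4 base and M4 Pro
    if (gpu_cores >= 16) return 96;   // Pro-class chips on older families
    return 64;                        // base chips on older families
}
```

Putting the core-count check for Max and Ultra first lets those tiers ignore the family entirely, matching the "any" entries in the table.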
Writing the Input Token
```cpp
state.device->write_buffer(state.scratch.token_ids, &next_token, 4);
```
Only 4 bytes are written: the single token ID that starts the chain. The rest of the token_ids buffer will be filled by the GPU as each step’s argmax writes its result.
Encoding the Dispatch Table
```cpp
state.device->begin_encoding();
state.device->encode_dispatch_table(&state.dispatch_table, pos, n);
```
This is the heart of chain decode. encode_dispatch_table calls encode_chain, which replays the dispatch table n times with position patching:
```cpp
inline void encode_chain(Device& device, const DispatchTable& table,
                         int start_position, int count) {
    const auto& cmds = table.commands;
    const int n_cmds = (int)cmds.size();
    for (int tok = 0; tok < count; tok++) {
        int pos = start_position + tok;
        for (int c = 0; c < n_cmds; c++) {
            const auto& cmd = cmds[c];
            device.set_pipeline(cmd.pso);
            // ... set buffers, params, dispatch ...
        }
    }
}
```
For a 7B model, one forward pass is on the order of a hundred or more dispatches (32 layers at several dispatches each, plus the embedding and output stages), so a chunk of 64 tokens puts thousands of GPU dispatches into a single command buffer, all submitted to the GPU at once.
The Replicated Dispatch Pattern
Each iteration of the outer loop (each token) replays the same dispatch table, but with patched fields. The per-token patching ensures each forward pass uses the correct:
- Token ID offset: The embedding lookup reads from token_ids[tok] instead of token_ids[0].
- Position: RoPE uses the correct absolute position for this token.
- KV sequence length: Attention knows how many KV entries are valid.
- Argmax output offset: The argmax result writes to token_ids[tok + 1].
Here is how patching works for the most common types:
```cpp
// Token offset patching (embedding lookup)
if (cmd.patch_type == DispatchCmd::PATCH_TOKEN_OFFSET && b == 0) {
    offset = tok * 4;  // byte offset to token_ids[tok]
}
// Argmax output patching
if (cmd.patch_type == DispatchCmd::PATCH_ARGMAX_OUTPUT && b == 1) {
    offset = (tok + 1) * 4;  // write result to token_ids[tok+1]
}
// Position patching (RoPE, attention)
if (cmd.patch_type == DispatchCmd::PATCH_POSITION) {
    uint32_t pos_val = (uint32_t)pos;
    memcpy(patched + cmd.patch_offset_1, &pos_val, 4);
}
// Combined position + KV length patching
if (cmd.patch_type == DispatchCmd::PATCH_POS_AND_KV) {
    uint32_t pos_val = (uint32_t)pos;
    uint32_t kv_val = (uint32_t)(pos + 1);
    memcpy(patched + cmd.patch_offset_1, &pos_val, 4);
    memcpy(patched + cmd.patch_offset_2, &kv_val, 4);
}
```
The token_ids buffer acts as a conveyor belt: token_ids[0] holds the input to the first forward pass, token_ids[1] gets the argmax output of the first pass (and becomes the input to the second pass), and so on.
```text
token_ids buffer:
┌─────────┬────────┬────────┬────────┬────────┬─────┐
│   t0    │   t1   │   t2   │   t3   │   t4   │ ... │
│ (input) │ (out1) │ (out2) │ (out3) │ (out4) │     │
│         │ (in2)  │ (in3)  │ (in4)  │ (in5)  │     │
└─────────┴────────┴────────┴────────┴────────┴─────┘
     ↑         ↑
CPU writes  CPU reads all
one token   after GPU done
```
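The patch step itself is small enough to sketch end to end. Everything here is a simplified, hypothetical stand-in (the PatchType enum, the Cmd and Patched structs, patch_for_token, and the byte helpers); only the four patch kinds and the offset arithmetic follow the fragments above. One assumption worth flagging: the sketch patches a per-token copy of the parameter bytes and leaves the table entry untouched.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// Hypothetical, simplified stand-ins; Akunu's real DispatchCmd differs.
enum class PatchType { None, TokenOffset, ArgmaxOutput, Position, PosAndKV };

struct Cmd {
    PatchType patch_type = PatchType::None;
    int patch_offset_1 = 0;        // byte offset of the position field in params
    int patch_offset_2 = 0;        // byte offset of the KV-length field in params
    std::vector<uint8_t> params;   // parameter bytes, built once at table-build time
};

struct Patched {
    uint32_t buffer_offset = 0;    // byte offset into token_ids for this chain step
    std::vector<uint8_t> params;   // per-token copy of the parameter bytes
};

// Writes a uint32 into a byte blob at `offset` (bounds-checked).
void write_u32(std::vector<uint8_t>& bytes, int offset, uint32_t v) {
    assert(offset >= 0 && static_cast<size_t>(offset) + 4 <= bytes.size());
    std::memcpy(bytes.data() + offset, &v, 4);
}

// Reads a uint32 back out of a byte blob at `offset` (bounds-checked).
uint32_t read_u32(const std::vector<uint8_t>& bytes, int offset) {
    assert(offset >= 0 && static_cast<size_t>(offset) + 4 <= bytes.size());
    uint32_t v;
    std::memcpy(&v, bytes.data() + offset, 4);
    return v;
}

// Patches command `cmd` for chain step `tok` at absolute position `pos`.
Patched patch_for_token(const Cmd& cmd, int tok, int pos) {
    Patched out;
    out.params = cmd.params;       // the table entry itself is never modified
    switch (cmd.patch_type) {
    case PatchType::TokenOffset:   // embedding reads token_ids[tok]
        out.buffer_offset = static_cast<uint32_t>(tok) * 4;
        break;
    case PatchType::ArgmaxOutput:  // argmax writes token_ids[tok + 1]
        out.buffer_offset = static_cast<uint32_t>(tok + 1) * 4;
        break;
    case PatchType::Position:      // RoPE / attention position
        write_u32(out.params, cmd.patch_offset_1, static_cast<uint32_t>(pos));
        break;
    case PatchType::PosAndKV:      // position plus KV length (pos + 1)
        write_u32(out.params, cmd.patch_offset_1, static_cast<uint32_t>(pos));
        write_u32(out.params, cmd.patch_offset_2, static_cast<uint32_t>(pos + 1));
        break;
    case PatchType::None:
        break;
    }
    return out;
}
```

Under these assumptions, at chain step 3 the embedding reads byte offset 12 (token_ids[3]) and the argmax writes byte offset 16 (token_ids[4]), which is exactly the overlap that turns the buffer into a conveyor belt.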
Double Buffering
```cpp
if (first) {
    state.device->end_encoding_sync();
    first = false;
} else {
    state.device->end_encoding_async();
    state.device->wait();
}
```
The first chunk uses synchronous execution because there is nothing to overlap with. Subsequent chunks use asynchronous submission: end_encoding_async() submits the command buffer and returns immediately, then wait() blocks until the GPU signals completion.
Note that in this loop wait() follows end_encoding_async() immediately, and the previous chunk’s tokens were already read before the current chunk was encoded, so CPU result processing does not actually overlap GPU execution here. Whatever the async path saves comes from how the Metal driver handles submission and completion. In practice, this saves about 10-30 microseconds per chunk: small, and smaller still once spread across the 64-128 tokens in each chunk.
Reading Results
```cpp
uint32_t *tokens = (uint32_t *)state.device->buffer_contents(
    state.scratch.token_ids);
for (int i = 0; i < n; i++) {
    uint32_t tok = tokens[i + 1];  // offset by 1: tokens[0] was input
    generated++;
    if (state.tokenizer.is_eos(tok))
        return generated;
    if (callback) {
        const char *text = decode_token_text(state, tok);
        if (!callback(tok, text, user_data))
            return generated;
    }
    next_token = tok;
}
```
After the GPU completes, the CPU reads all n tokens at once from the token_ids buffer. For each token:
- Check for EOS: if the model generated an end-of-sequence token, stop immediately.
- Invoke the callback: convert the token to text and stream it to the user.
- Update next_token: the last non-EOS token becomes the seed for the next chunk.
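Note that EOS checking happens after the GPU has finished the entire chunk. If the model generated EOS at position 5 in a 64-token chunk, positions 6-63 were computed unnecessarily. This is the price of chain decode: you cannot stop mid-chain. The wasted work is bounded, though: it occurs only in the final chunk of a generation and never exceeds one chunk (typically 64-128 tokens), whereas the overhead savings accrue on every chunk. For long generations that tradeoff favors chaining; for short replies that hit EOS early in a large chunk, the wasted forward passes can cost more GPU time than chaining saves.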
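The read-back loop can also be tested in isolation. This sketch mirrors the loop in decode_greedy, with the tokenizer's EOS check and the text callback replaced by plain parameters; consume_chunk and ChunkResult are hypothetical names. It additionally reports how many computed positions were discarded, i.e. the wasted work just described.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <vector>

struct ChunkResult {
    int generated = 0;     // tokens consumed, including a terminating EOS
    bool stopped = false;  // true if EOS or the callback ended generation
    int wasted = 0;        // positions the GPU computed after the stop point
};

// token_ids holds n + 1 entries: the chunk's input token followed by n outputs.
// Hypothetical helper mirroring the read-back loop in decode_greedy.
ChunkResult consume_chunk(const std::vector<uint32_t>& token_ids, uint32_t eos_id,
                          const std::function<bool(uint32_t)>& callback,
                          uint32_t& next_token) {
    assert(!token_ids.empty());
    const int n = static_cast<int>(token_ids.size()) - 1;
    ChunkResult r;
    for (int i = 0; i < n; i++) {
        uint32_t tok = token_ids[i + 1];  // slot 0 was the input
        r.generated++;
        // EOS is checked first, so the callback never sees the EOS token.
        if (tok == eos_id || (callback && !callback(tok))) {
            r.stopped = true;
            r.wasted = n - (i + 1);       // outputs computed but never used
            return r;
        }
        next_token = tok;                 // seed for the next chunk
    }
    return r;
}
```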
When Does the Chain Break?
Every chunk ends with a GPU submission, after which the loop simply encodes the next chunk. Generation stops altogether in three cases:
- EOS token: The CPU detects EOS when reading results and stops generating.
- Callback returns false: The user requested cancellation.
- Max tokens reached: The generation limit was hit.
Importantly, the chain does not break for stop sequences or for multi-token EOS patterns. Those are detected at the application level after the chain completes.
Performance Impact of Chain Decode
Let’s quantify the benefit. Consider generating 200 tokens on an M4 Pro with a 7B Q4_0 model:
Without chain decode (one submission per token):
| Component | Per-Token Cost | Total (200 tokens) |
|---|---|---|
| Forward pass GPU time | ~8ms | 1600ms |
| Command buffer overhead | ~0.2ms | 40ms |
| CPU-GPU sync | ~0.05ms | 10ms |
| Total | ~8.25ms | 1650ms |
With chain decode (chunk=128, 2 submissions):
| Component | Per-Chunk Cost | Total (2 chunks) |
|---|---|---|
| Forward pass GPU time | ~128 * 8ms = 1024ms | 1600ms |
| Command buffer overhead | ~0.2ms | 0.4ms |
| CPU-GPU sync | ~0.05ms | 0.1ms |
| Total | ~1024ms | 1600.5ms |
The GPU time is identical – the forward pass takes the same time regardless of whether it is in a chain or standalone. But the overhead drops from 50ms to 0.5ms: a 100x reduction. At 8ms per token, this is a ~3% throughput improvement, which is modest but free.
The real benefit emerges at faster per-token speeds. On an M4 Max with a 3B model doing 3ms/token, the same ~0.25ms of per-token overhead would be roughly 8% of total time without chaining and <0.1% with chaining.
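The arithmetic behind both tables fits in one function. The sketch assumes, as the tables do, a fixed GPU time per token and a fixed cost per submission (0.2ms of command buffer overhead plus 0.05ms of sync, so 0.25ms); overhead_fraction is a hypothetical helper, not part of Akunu.

```cpp
#include <cassert>

// Fraction of total time spent on submission overhead when generating `tokens`
// tokens in chunks of `chunk`, with `gpu_ms` of GPU time per token and
// `overhead_ms` of fixed cost per submission (one submission per chunk).
double overhead_fraction(int tokens, int chunk, double gpu_ms, double overhead_ms) {
    assert(tokens > 0 && chunk > 0);
    int submissions = (tokens + chunk - 1) / chunk;  // ceil(tokens / chunk)
    double overhead = submissions * overhead_ms;
    return overhead / (tokens * gpu_ms + overhead);
}
```

For 200 tokens at 8ms each this gives about 3.0% unchained and about 0.03% with 128-token chunks; at 3ms per token the unchained figure rises to about 7.7%.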
The Dispatch Table in Detail
The dispatch table for a typical 32-layer 7B Llama model is built from the following command groups:
| Command Group | Count | Description |
|---|---|---|
| Embedding | 1 | Token ID -> hidden state |
| Layer loop (x32) | ~128-192 | 32 layers * 4-6 cmds each (fused; breakdown below) |
| RMSNorm | 32 | One per layer (fused) |
| RoPE + KV write | 32 | Fused per layer |
| Attention | 32 | Flash attention decode |
| O projection | 32 | GEMV per layer |
| SiLU + down | 32 | Fused activation GEMV or separate |
| Output norm | 1 | Final RMSNorm |
| Logit GEMV | 1 | Hidden state -> logits |
| Argmax | 1 | Logits -> token ID |
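Counted one per layer, the five per-layer rows alone come to 160 dispatches, and an unfused layer would need even more (separate gate and up GEMVs, a SiLU element-wise pass, and separate residual adds). The key is that many operations are fused. For example, a layer with fused SiLU GEMV (gemv_q4_0_silu) replaces three dispatches (gate GEMV, up GEMV, SiLU element-wise) with a single dispatch. Residual + RMSNorm is another fusion. With full fusion, a layer can be as few as 4 dispatches: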
- Fused residual + RMSNorm + QKV GEMV
- RoPE + KV write
- Attention
- Fused SiLU GEMV (gate * up * down)
In practice, Akunu achieves 4-6 dispatches per layer depending on the model architecture and quantization format.
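Why fusion cuts dispatches without changing the math is easiest to see in a CPU reference. The sketch below uses plain float rows (no Q4_0 dequantization) and is hypothetical throughout: unfused runs the gate GEMV, up GEMV, and SiLU step as three passes, while fused_silu_gemv does all three in one pass per row, which is the kind of work the text attributes to gemv_q4_0_silu. The down projection is left out.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<float>>;  // row-major: rows x cols

float silu(float v) { return v / (1.0f + std::exp(-v)); }

// Unfused reference: three passes, standing in for three separate dispatches.
std::vector<float> unfused(const Matrix& gate, const Matrix& up, const std::vector<float>& x) {
    assert(gate.size() == up.size());
    std::vector<float> g(gate.size()), u(up.size()), out(gate.size());
    for (size_t r = 0; r < gate.size(); r++)   // pass 1: gate GEMV
        for (size_t c = 0; c < x.size(); c++) g[r] += gate[r][c] * x[c];
    for (size_t r = 0; r < up.size(); r++)     // pass 2: up GEMV
        for (size_t c = 0; c < x.size(); c++) u[r] += up[r][c] * x[c];
    for (size_t r = 0; r < out.size(); r++)    // pass 3: element-wise SiLU(gate) * up
        out[r] = silu(g[r]) * u[r];
    return out;
}

// Fused: one pass over rows; the gate and up dot products never leave locals.
std::vector<float> fused_silu_gemv(const Matrix& gate, const Matrix& up, const std::vector<float>& x) {
    assert(gate.size() == up.size());
    std::vector<float> out(gate.size());
    for (size_t r = 0; r < gate.size(); r++) {
        float g = 0.0f, u = 0.0f;
        for (size_t c = 0; c < x.size(); c++) {
            g += gate[r][c] * x[c];
            u += up[r][c] * x[c];
        }
        out[r] = silu(g) * u;
    }
    return out;
}
```

The fused loop never writes the gate and up vectors to memory; on a GPU the same idea means two fewer dispatches and no intermediate gate/up buffers.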
Chain Decode vs. Batched Decode
It is worth distinguishing chain decode from the batched decoding (continuous batching) used by server-side inference engines (e.g., vLLM, TensorRT-LLM). Those systems process multiple independent sequences simultaneously, sharing the weight reads across sequences. Chain decode processes a single sequence, but chains multiple sequential tokens into one GPU submission.
| Feature | Chain Decode (Akunu) | Continuous Batching (vLLM) |
|---|---|---|
| Sequences | 1 | Many |
| Tokens per submission | N (sequential) | 1 per sequence, B sequences |
| Weight reuse | Across N tokens (sequential) | Across B sequences (parallel) |
| Use case | On-device inference | Server inference |
| KV cache | Single stream | Multiple streams |
Chain decode is uniquely suited to single-user on-device inference, where there is only one sequence to generate but we want to minimize CPU-GPU synchronization overhead.
How Sampled Decode Achieves Chain Performance
The sampled decode path (decode_sampled) achieves essentially the same throughput as greedy by using the Gumbel-max trick on the GPU. Instead of a plain argmax at the end of each forward pass, it adds Gumbel noise to the logits and then takes the argmax, which is mathematically equivalent to sampling from the softmax distribution. The sampled dispatch table is identical to the greedy one except:
- A gumbel_topk_f16 kernel is inserted before the argmax.
- The Gumbel kernel has a PATCH_POSITION patch, so each token gets unique noise.
This means the chain is unbroken – no CPU roundtrip is needed for random number generation or softmax computation. The RNG is a PCG hash function that takes the position as input, so each token in the chain gets a different noise sample despite being computed in a single GPU submission.
```cpp
// In the Gumbel kernel:
uint element_seed = (params.position + params.seed_offset) * 2654435761u;
float u = pcg_float(element_seed + i);
float gumbel = -log(-log(u));
logits[i] = half(val + temp * gumbel);
```
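The params.position is different for each token in the chain (patched by PATCH_POSITION), and each vocabulary element offsets element_seed by its own index via the + i term. Together, these give every (position, vocabulary index) pair its own hash input, so each token in the chain draws fresh noise for every vocabulary entry.[2]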
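The sampling recipe can be checked on the CPU. This sketch follows the kernel excerpt: seed from position and vocabulary index, draw u, add temp * (-log(-log(u))), take the argmax. The chapter does not show pcg_float, so pcg_hash below is a commonly used PCG-style 32-bit integer hash standing in for it (an assumption), and gumbel_argmax is a hypothetical name. Scaling the noise by temp rather than dividing the logits by temp gives the same argmax, since multiplying logits/temp + g by temp > 0 does not change which entry is largest.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <vector>

// A commonly used PCG-style 32-bit integer hash. Assumption: stands in for the
// kernel's pcg_float, whose exact definition is not shown in this chapter.
uint32_t pcg_hash(uint32_t input) {
    uint32_t state = input * 747796405u + 2891336453u;
    uint32_t word = ((state >> ((state >> 28u) + 4u)) ^ state) * 277803737u;
    return (word >> 22u) ^ word;
}

// Maps a hash to a double strictly inside (0, 1), so -log(-log(u)) stays finite.
double unit_open(uint32_t h) {
    return ((h >> 8) + 0.5) / 16777216.0;  // top 24 bits, shifted to the bin center
}

// Gumbel-max: argmax_i(logits[i] + temp * g_i) is distributed as
// softmax(logits / temp) for temp > 0, and is plain argmax for temp == 0.
int gumbel_argmax(const std::vector<float>& logits, float temp,
                  uint32_t position, uint32_t seed_offset) {
    assert(!logits.empty() && temp >= 0.0f);
    uint32_t element_seed = (position + seed_offset) * 2654435761u;  // per-token seed
    int best = 0;
    double best_val = -std::numeric_limits<double>::infinity();
    for (size_t i = 0; i < logits.size(); i++) {
        // + i gives each vocabulary entry its own hash input
        double u = unit_open(pcg_hash(element_seed + static_cast<uint32_t>(i)));
        double gumbel = -std::log(-std::log(u));
        double val = logits[i] + temp * gumbel;
        if (val > best_val) {
            best_val = val;
            best = static_cast<int>(i);
        }
    }
    return best;
}
```

Calling gumbel_argmax once per position, as the chain does, yields draws whose frequencies approach softmax(logits / temp); with temp = 0 it reduces to the greedy argmax.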
Understanding the Data Flow in Detail
Let’s trace exactly what happens inside the GPU during a chain of 3 tokens. Assume a simple model with 1 layer, dim=4096, vocab_size=32000:
```text
Chain of 3: token_ids = [42, ?, ?, ?]

TOKEN 0 (position P):
  Embedding:   read token_ids[0]=42 → hidden[4096]
  RMSNorm:     hidden → normed
  Q GEMV:      normed @ Wq → q[4096]
  K GEMV:      normed @ Wk → k[1024]
  V GEMV:      normed @ Wv → v[1024]
  RoPE:        rotate q, k in-place; write k, v to KV cache at pos P
  Attention:   q @ K_cache^T → scores → softmax → @ V_cache → attn_out  (P+1 KV entries)
  O GEMV:      attn_out @ Wo → residual
  ... (FFN) ...
  Logit GEMV:  final_hidden @ W_logit → logits[32000]
  Argmax:      logits → token_ids[1] = 7891  (writes to slot 1)

TOKEN 1 (position P+1):
  Embedding:   read token_ids[1]=7891 → hidden[4096]  (reads what token 0 wrote!)
  ... (same operations, but position is P+1; attention sees P+2 KV entries) ...
  Argmax:      logits → token_ids[2] = 512  (writes to slot 2)

TOKEN 2 (position P+2):
  Embedding:   read token_ids[2]=512 → hidden[4096]
  ... (same operations, position P+2; attention sees P+3 KV entries) ...
  Argmax:      logits → token_ids[3] = 1044  (writes to slot 3)

CPU reads: token_ids = [42, 7891, 512, 1044]
Stream tokens 7891, 512, 1044 to callback.
```
The critical data dependency is between the argmax write and the next token’s embedding read. Metal guarantees that dispatches within a single command buffer execute in order, so token 1’s embedding dispatch will see the value written by token 0’s argmax dispatch. No explicit synchronization is needed.
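The conveyor-belt data flow can be checked by simulation. In this sketch, toy_forward is an arbitrary deterministic function of token, position, and KV length standing in for a full forward pass plus argmax; decode_step_by_step models one submission per token, and decode_chained uses the slot and patching rules above (read token_ids[t], write token_ids[t + 1], KV length = position + 1). All three names are hypothetical.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Toy stand-in for one forward pass + argmax: deterministic in its inputs.
uint32_t toy_forward(uint32_t token, int pos, int kv_len, uint32_t vocab) {
    return (token * 31u + static_cast<uint32_t>(pos) * 7u +
            static_cast<uint32_t>(kv_len)) % vocab;
}

// One "submission" per token: the CPU reads each result and feeds it back.
std::vector<uint32_t> decode_step_by_step(uint32_t first, int pos, int n, uint32_t vocab) {
    std::vector<uint32_t> out;
    uint32_t tok = first;
    for (int i = 0; i < n; i++) {
        tok = toy_forward(tok, pos + i, pos + i + 1, vocab);
        out.push_back(tok);
    }
    return out;
}

// Chained: step t reads token_ids[t] and writes token_ids[t + 1], with the
// position and KV length patched per step; the CPU writes only slot 0.
std::vector<uint32_t> decode_chained(uint32_t first, int pos, int n, uint32_t vocab) {
    std::vector<uint32_t> token_ids(n + 1, 0);
    token_ids[0] = first;                  // the only CPU write
    for (int t = 0; t < n; t++) {
        int p = pos + t;                   // patched position
        int kv = p + 1;                    // patched KV length
        token_ids[t + 1] = toy_forward(token_ids[t], p, kv, vocab);
    }
    // The CPU reads slots 1..n after the whole chain has run.
    return std::vector<uint32_t>(token_ids.begin() + 1, token_ids.end());
}
```

Because each step reads only the slot written by the step before it, the two functions must agree: this is the in-order execution argument above in miniature.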
Edge Cases and Robustness
Buffer Sizing
The token_ids buffer must be large enough for chunk_size + 1 entries (input token + N output tokens). The speculative path has an additional constraint:
```cpp
int max_batch = (int)(state.scratch.token_ids.size / sizeof(uint32_t)) - 1;
if (n_draft + 1 > max_batch)
    n_draft = max_batch - 1;
```
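The sizing rule is easy to state as code. This sketch assumes 4-byte token IDs, as in the code above; token_ids_bytes and clamp_draft are hypothetical helpers, and clamp_draft simply restates the speculative-path clamp from the excerpt.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Bytes needed for a chain of `chunk` tokens: one input slot plus `chunk` outputs.
size_t token_ids_bytes(int chunk) {
    assert(chunk > 0);
    return static_cast<size_t>(chunk + 1) * sizeof(uint32_t);
}

// Restates the speculative-path clamp: max_batch is the slot count minus one,
// and n_draft drops to max_batch - 1 whenever n_draft + 1 exceeds it.
int clamp_draft(size_t buffer_bytes, int n_draft) {
    int max_batch = static_cast<int>(buffer_bytes / sizeof(uint32_t)) - 1;
    if (n_draft + 1 > max_batch)
        n_draft = max_batch - 1;
    return n_draft;
}
```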
KV Cache Advancement
After each chunk, the KV cache advances by the full chunk size:
```cpp
state.kv_cache.advance(n);
pos += n;
```
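This is correct even if EOS was generated mid-chunk, because the KV entries for positions after EOS are simply ignored: no subsequent attention operation in this generation will query them. If the same KV cache is reused afterwards (for example, for a follow-up turn), the caller would need to roll pos and the cache back to the last consumed token, since the loop advances both by the full chunk.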
Position Overflow
The position counter uses int, which limits sequences to ~2 billion tokens. In practice, the KV cache size (typically 4K-128K) is the binding constraint, not the position counter.
Summary
Greedy decoding in Akunu is deceptively simple on the surface but architecturally sophisticated:
- Chain decode processes multiple sequential tokens in a single GPU command buffer, eliminating per-token CPU-GPU synchronization overhead.
- Dispatch table replay with per-token patching enables the chain without recompiling or rebuilding GPU commands.
- The token_ids conveyor belt passes each token’s argmax result as the next token’s input, all within GPU memory.
- Asynchronous submission after the first chunk trims per-chunk submission overhead.
- Hardware-tuned chunk sizes balance throughput against streaming latency.
The same chain decode mechanism powers sampled decode (via Gumbel-max) and speculative decode (via batched verification), making it the foundational building block of Akunu’s generation pipeline.
Deep Dive: The Dispatch Table Build
To understand chain decode fully, we need to understand how the dispatch table is constructed. During model initialization, Akunu builds a DispatchTable that represents the complete forward pass for a single token. Here is the conceptual structure for a 32-layer Llama model:
```text
Command 0:   embedding_lookup                 (PATCH_TOKEN_OFFSET)
Command 1:   rmsnorm_f16                      (layer 0 attn norm)
Command 2:   gemv_q4_0                        (Q projection)
Command 3:   gemv_q4_0                        (K projection)
Command 4:   gemv_q4_0                        (V projection)
Command 5:   rope_kv_write_f16                (PATCH_POS_AND_KV)
Command 6:   flash_attention_decode_fast_f16  (PATCH_KV_SEQ_LEN)
Command 7:   gemv_q4_0                        (O projection)
Command 8:   residual_rmsnorm_f16             (layer 0 FFN norm)
Command 9:   gemv_q4_0_silu                   (fused gate*up*down)
... repeat for layers 1-31 (commands 10-288) ...
Command 289: residual_rmsnorm_f16             (output norm)
Command 290: gemv_q4_0                        (logit projection)
Command 291: argmax_f16                       (PATCH_ARGMAX_OUTPUT)
```
Each command is a fixed-size DispatchCmd struct (no heap allocations), and the full table is a single contiguous vector, so replaying it is a linear walk through memory. The table is built once and never modified during generation: only the patched values change from token to token.
Hot/Cold Data Split
The DispatchTable uses a hot/cold split for profiling data:
```cpp
struct DispatchTable {
    std::vector<DispatchCmd> commands;  // HOT: iterated every token
    std::vector<DispatchLabel> labels;  // COLD: only used during profiling
};
```
During generation, only commands is accessed. The labels vector (with 48-byte strings per command) is never touched unless a GPU profiler is attached. This prevents the labels from evicting hot command data from the CPU cache.
Buffer Bindings Are Static
A key property of the dispatch table: all buffer bindings are static (fixed at build time). The weight buffers, scratch buffers, and KV cache buffers are allocated during model init and never change. The only per-token dynamic data is:
- Which offset within a buffer to use (patched via PATCH_TOKEN_OFFSET and PATCH_ARGMAX_OUTPUT)
- Which scalar parameter values to use (patched via PATCH_POSITION, PATCH_KV_SEQ_LEN, PATCH_POS_AND_KV)
This means the Metal runtime can reuse pipeline state objects (PSOs) and buffer bindings across tokens, minimizing the command encoding overhead.
Practical Considerations for Chain Decode
Streaming Latency vs. Throughput
Chain decode introduces a fundamental tension: larger chunks give higher throughput (less overhead per token) but worse streaming latency (the user sees nothing until the chunk completes). Here is the tradeoff:
| Chunk Size | Overhead Savings | Time to First Token in Chunk | Tokens Buffered |
|---|---|---|---|
| 1 (no chain) | 0% (baseline) | ~8ms | 0 |
| 16 | ~94% | ~128ms | 16 |
| 64 | ~98.5% | ~512ms | 64 |
| 128 | ~99.2% | ~1024ms | 128 |
At chunk_size=128, the user waits up to 1 second before seeing a burst of 128 tokens appear nearly simultaneously. Whether this is acceptable depends on the application: for interactive chat, chunk_size=64 provides a good balance; for batch processing, chunk_size=128 maximizes throughput.
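The table follows from two simple formulas: one submission per chunk cuts submission overhead to 1/chunk of the unchained cost, and the first token of a chunk reaches the callback only after the whole chunk has run. This sketch assumes the same ~8ms per token as the earlier example; both helpers are hypothetical.

```cpp
#include <cassert>

// Fraction of per-token submission overhead eliminated by chaining `chunk` tokens.
double overhead_savings(int chunk) {
    assert(chunk > 0);
    return 1.0 - 1.0 / chunk;
}

// Wait, in ms, before the first token of a chunk reaches the callback.
double chunk_latency_ms(int chunk, double gpu_ms_per_token) {
    assert(chunk > 0);
    return chunk * gpu_ms_per_token;
}
```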
Interaction with Stop Sequences
Modern LLM applications often use stop sequences (e.g., "\n\nHuman:") to terminate generation. With chain decode, stop sequence detection happens after the chunk completes, not during. If the model generates the stop sequence at token 10 of a 64-token chunk, tokens 11-64 are wasted computation.
However, the wasted work is bounded: at most one chunk’s worth of tokens, incurred once per generation, while the savings from eliminating 63 command buffer creations per 64-token chunk accrue on every chunk. The net benefit therefore grows with generation length.
Memory Bandwidth During Chain Decode
During a chain of N tokens, the GPU reads the full model weights N times (once per token). However, Apple Silicon’s SLC (System Level Cache) can cache a portion of the weights across tokens. For a 7B Q4_0 model (~3.5 GB), the SLC on different chips caches:
| Chip | SLC Size | % of Model Cached | Effective Bandwidth Boost |
|---|---|---|---|
| M4 Base | 16 MB | 0.5% | Negligible |
| M4 Pro | 32 MB | 0.9% | Small |
| M4 Max | 48 MB | 1.4% | Moderate |
| M4 Ultra | 96 MB | 2.7% | Moderate |
Even when only 2.7% of the model fits in the cache, the SLC provides measurable benefit because the cached portion includes the hot first-layer weights and norm parameters that are accessed every token. The threadgroup swizzling in GEMM kernels (covered in the GEMM chapter) is designed to maximize this SLC reuse.
For smaller models (1-3B) that partially fit in the SLC, the benefit is much larger. A 1B model at Q4_0 is ~500 MB, and a 96 MB SLC can cache ~19% of the weights, providing a meaningful bandwidth amplification.
Correctness of Chain Decode
A natural question: is chain decode mathematically equivalent to single-token decode? Yes, because:
- The dispatch table encodes the same forward pass regardless of chaining.
- Per-token patching ensures correct positions and KV cache lengths.
- The argmax writes to the correct output slot, and the next token reads from the previous slot.
- There are no data hazards: each forward pass within the chain reads only from buffers written by the previous pass, and the GPU’s command execution model guarantees sequential ordering within a command buffer.
The only difference is that stop conditions (EOS, callback cancellation) are checked after the chunk rather than after each token. This does not affect the generated text – it only affects how quickly the engine responds to stop conditions.
[1] Apple. “Apple M4 chip.” apple.com, 2024. The M4’s GPU command processor improvements include reduced dispatch latency and better utilization of the Apple GPU’s tile-based deferred rendering architecture for compute workloads. See https://www.apple.com/newsroom/2024/05/apple-introduces-m4-chip/.

[2] Maddison, C.J., Tarlow, D., and Minka, T. “A* Sampling.” NeurIPS 2014. The Gumbel-max trick provides an exact sample from the categorical distribution without explicitly computing the softmax or CDF. The PCG hash function (O’Neill, M. “PCG: A Family of Simple Fast Space-Efficient Statistically Good Algorithms for Random Number Generation.” Harvey Mudd College Technical Report HMC-CS-2014-0905) provides high-quality pseudorandomness with minimal state. See https://arxiv.org/abs/1411.0030.