Speculative Decoding with N-grams

Autoregressive decoding has a fundamental problem: each token depends on the previous one. You cannot compute token 42 until you have produced token 41, because token 41 changes the KV cache and thus the attention output for position 42. This makes decode inherently sequential. For large models, each decode step takes (say) 5-10 milliseconds of GPU time, so generating 500 tokens takes 2.5-5 seconds. The GPU is actually underutilized during decode – it is doing a single matrix-vector multiply per layer, not the fat matrix-matrix multiplies of prefill – but there is no obvious way to parallelize.

Speculative decoding breaks this bottleneck by guessing future tokens cheaply and then verifying them in a single batched forward pass. If the guesses are right (and for many text patterns, they often are), you produce multiple tokens per forward pass without changing the output distribution.

This chapter covers akunu’s speculative decode implementation, which uses n-gram frequency tables as its draft predictor – no draft model required.

The Core Idea

The speculative decode algorithm has three steps:

  1. DRAFT:   Predict N future tokens cheaply
  2. VERIFY:  Run all N+1 tokens through the real model in one batch
  3. ACCEPT:  Keep the longest prefix where draft == verified

  +----------+     +------------------+     +-----------+
  |  N-gram  | --> |  Batched model   | --> |  Compare  |
  | predictor|     |  forward pass    |     | draft vs  |
  | (draft)  |     |  (N+1 positions) |     | verified  |
  +----------+     +------------------+     +-----------+
    ~0 cost         ~1x forward cost        ~0 cost

  If K of N drafts match: you got K+1 tokens for the price of 1 forward pass!

Let us trace through a concrete example. Suppose we are generating text and the most recent token is “the”. The n-gram predictor looks at its frequency tables and predicts: [“quick”, “brown”, “fox”, “jumps”].

We pack these into a batch:

  Position:    42      43       44      45      46
  Input:      "the"  "quick"  "brown"  "fox"  "jumps"
               ^       ^        ^       ^       ^
             known   draft[0] draft[1] draft[2] draft[3]

We run the model on all 5 positions in a single forward pass. The model produces an argmax (or sampled) token for each position:

  Position:    42      43       44      45      46
  Verified:  "quick" "brown"  "fox"   "jumped" "over"
               ^       ^        ^       ^
             matches  matches  matches  MISMATCH at draft[3]

Drafts 0-2 matched (“quick”, “brown”, “fox” were all correct guesses). Draft 3 was wrong (“jumped” vs “jumps”). So we accept the 3 matching drafts plus the bonus token at the mismatch point:

  Accepted tokens: "quick", "brown", "fox", "jumped"
  = 4 tokens from 1 forward pass!

The bonus token at position 45 is the model’s actual prediction given the correct prefix [“the”, “quick”, “brown”, “fox”] – so it is always valid. We just could not predict what came after it.

Why N-grams?

The standard speculative decoding literature uses a small draft model to generate the draft tokens. That draft model is typically 10-50x smaller than the target model, runs fast, and hopefully agrees with the target model on easy tokens.

Akunu takes a different approach: no draft model at all. Instead, it builds n-gram frequency tables from the tokens seen so far (prompt + generated) and uses those to predict the next tokens.

This has several advantages:

  1. Zero overhead at load time. No extra model to load, no extra memory.
  2. Perfect for repetitive patterns. Code, structured data, and template-heavy text often repeat multi-token sequences. N-grams capture these perfectly.
  3. Adapts in real-time. The frequency tables update with every generated token, so the predictor learns patterns specific to this generation.
  4. Simplicity. The entire predictor is ~100 lines of C++ with no external dependencies.

The downside is that n-gram prediction has zero “understanding” – it cannot predict tokens it has not seen in exactly the right context. For highly novel text, the predictor will fail to produce any drafts, and speculative decode degrades gracefully to standard autoregressive decode.

The N-gram Predictor

Let us look at akunu’s NGramPredictor class in detail.

Configuration

  MAX_ORDER   = 4     Up to 4-grams (context of 3 tokens)
  DRAFT_COUNT = 4     Predict up to 4 tokens per round
  MAX_HISTORY = 512   Sliding window of recent tokens

Data Structures

The predictor maintains:

  1. A deque<uint32_t> history_ of the last 512 tokens seen
  2. Three hash tables (one per n-gram order):
    • tables_[0]: bigrams (1 token context -> next token -> count)
    • tables_[1]: trigrams (2 token context -> next token -> count)
    • tables_[2]: 4-grams (3 token context -> next token -> count)

  tables_[order-2]:  hash(context) --> { token_id: count, ... }

  Example for the context "the cat sat":
  tables_[2][hash("the","cat","sat")] = {
      "on":    47,     <-- seen 47 times after "the cat sat"
      "down":  3,      <-- seen 3 times
      "and":   1,      <-- seen 1 time
  }

Hashing

Context tokens are hashed using FNV-1a, a simple non-cryptographic hash:

static uint64_t context_hash(const uint32_t *tokens, int n) {
    uint64_t h = 14695981039346656037ULL;  // FNV-1a offset basis
    for (int i = 0; i < n; i++) {
        h ^= tokens[i];
        h *= 1099511628211ULL;             // FNV-1a prime
    }
    return h;
}

This maps a variable-length token sequence to a 64-bit key. Collisions are theoretically possible but vanishingly unlikely in practice: the hash space is 2^64, so by the birthday bound collisions only become likely after billions of distinct contexts, while the predictor inserts at most three new contexts per generated token – millions of contexts even in very long sessions.

Update

When a new token is generated, update() adds it to the frequency tables at all applicable orders:

  History: [... t_{n-3}, t_{n-2}, t_{n-1}]
  New token: t_n

  Bigram:  hash(t_{n-1})                    -> t_n  count++
  Trigram: hash(t_{n-2}, t_{n-1})           -> t_n  count++
  4-gram:  hash(t_{n-3}, t_{n-2}, t_{n-1}) -> t_n  count++

The prompt tokens are added via update_batch(), which calls update() for each token. This seeds the frequency tables with the patterns present in the prompt – which is crucial for tasks like “continue this code” where the prompt establishes the patterns that will repeat.

Prediction

The predict() method generates up to DRAFT_COUNT (4) predicted tokens. For each position, it tries the longest matching context first:

  Tentative context: [... t_{n-2}, t_{n-1}, t_n]

  Try 4-gram: lookup hash(t_{n-2}, t_{n-1}, t_n) in tables_[2]
    Found?  Pick the most frequent continuation.  Done.
    Not found?  Try 3-gram.

  Try 3-gram: lookup hash(t_{n-1}, t_n) in tables_[1]
    Found?  Pick the most frequent continuation.  Done.
    Not found?  Try 2-gram.

  Try 2-gram: lookup hash(t_n) in tables_[0]
    Found?  Pick the most frequent continuation.  Done.
    Not found?  No draft for this position.  Stop.

Once a token is predicted, it is appended to the tentative context, and the process repeats for the next position. This means the predictor can chain its own predictions – predicting “quick” allows it to then predict “brown” given the extended context.

  predict() trace:

  Context: ["the", "cat", "sat"]

  Step 0: 4-gram("the","cat","sat") --> "on"    (count=47)
          Context becomes ["the","cat","sat","on"]

  Step 1: 4-gram("cat","sat","on")  --> "the"   (count=31)
          Context becomes ["the","cat","sat","on","the"]

  Step 2: 4-gram("sat","on","the")  --> "mat"   (count=23)
          Context becomes ["the","cat","sat","on","the","mat"]

  Step 3: 4-gram("on","the","mat")  --> not found
          3-gram("the","mat")       --> "."     (count=5)

  Draft: ["on", "the", "mat", "."]

The Verification Loop

Now let us trace through decode_speculative() in detail.

No-draft fallback

If the predictor fails to produce any drafts (n_draft <= 0), the loop falls back to standard single-token decode:

if (n_draft <= 0) {
    // Single-token forward pass
    state.device->write_buffer(state.scratch.token_ids, &next_token, 4);
    state.device->begin_encoding();
    state.device->encode_dispatch_table(&state.dispatch_table, pos, 1);
    state.device->end_encoding_sync();
    state.kv_cache.advance(1);
    // ... emit token, update predictor ...
    continue;
}

This is important: when the predictor produces nothing, speculative decode behaves identically to standard decode. And even when drafts are produced but all rejected, the batch still yields one valid token (the bonus), so the only cost is the modest overhead of the larger forward pass.

Batched verification

When drafts are available, we pack them into the token_ids buffer along with the current token:

uint32_t *chain_buf = (uint32_t *)state.device->buffer_contents(
    state.scratch.token_ids);
chain_buf[0] = next_token;      // the known token
for (int i = 0; i < n_draft; i++)
    chain_buf[i + 1] = drafts[i]; // the draft tokens

Then we run the model on the entire batch:

state.device->begin_encoding();
state.device->encode_dispatch_table(&state.dispatch_table, pos, batch_size);
state.device->end_encoding_sync();
state.kv_cache.advance(batch_size);

The dispatch table processes batch_size = 1 + n_draft positions in a single encoding pass. This is the same mechanism used for prefill – the model sees multiple tokens and produces a prediction for each position.

  Token IDs buffer before forward pass:
  +-------+----------+----------+----------+----------+
  | known | draft[0] | draft[1] | draft[2] | draft[3] |
  |  tok  |          |          |          |          |
  +-------+----------+----------+----------+----------+
    pos     pos+1      pos+2      pos+3      pos+4

  Token IDs buffer after forward pass (output shifted by 1):
  +-------+----------+----------+----------+----------+----------+
  | (in)  | verify0  | verify1  | verify2  | verify3  | verify4  |
  +-------+----------+----------+----------+----------+----------+
    [0]      [1]        [2]        [3]        [4]        [5]

  verify0 = model's prediction given context up to pos
  verify1 = model's prediction given context up to pos+1 (including draft[0])
  ...

Acceptance logic

We compare each verified token against the corresponding draft:

int accepted = 0;
uint32_t bonus_token = 0;
for (int i = 0; i < batch_size; i++) {
    uint32_t verified = chain_buf[i + 1];
    if (i < n_draft) {
        if (verified == drafts[i])
            accepted++;       // Draft was correct!
        else {
            bonus_token = verified;  // First mismatch
            break;
        }
    } else {
        bonus_token = verified;  // Token after all drafts
    }
}

This finds the longest prefix of correct drafts. The key insight: when draft[i] matches verified[i], the model was going to produce that token anyway. So we can accept it for free. The first mismatch gives us a “bonus token” – the model’s actual prediction at that position.

KV cache rollback

Here is a subtle but critical detail. The forward pass advanced the KV cache by batch_size positions. But we only accepted accepted + 1 tokens (the accepted drafts plus the bonus). The remaining KV cache entries (for the rejected drafts) are wrong – they were computed based on a context that included incorrect draft tokens.

We must roll back the KV cache:

int keep = accepted + 1;
state.kv_cache.rollback(pos + keep);
pos += keep;

rollback(n) truncates the KV cache to position n, discarding everything after. This is a cheap operation – typically just updating a length counter, since the actual data does not need to be zeroed (it will be overwritten by the next forward pass).

Complete flow diagram

  +-------------------+
  | N-gram predictor  |
  | predict() -> 4    |
  | draft tokens      |
  +-------------------+
           |
           v
  +-------------------+
  | Pack batch:       |
  | [known, d0, d1,   |
  |  d2, d3]          |
  | batch_size = 5    |
  +-------------------+
           |
           v
  +-------------------+
  | Model forward     |
  | on 5 positions    |
  | (single GPU pass) |
  +-------------------+
           |
           v
  +-------------------+
  | Read outputs:     |
  | [v0,v1,v2,v3,v4]  |
  +-------------------+
           |
           v
  +-------------------+     +-------------------+
  | Compare:          |     | Example:          |
  | d0==v0? yes       |     | d0="on"  v0="on"  |
  | d1==v1? yes       |     | d1="the" v1="the" |
  | d2==v2? no        |     | d2="mat" v2="rug" |
  |   bonus = v2      |     |   bonus = "rug"   |
  +-------------------+     +-------------------+
           |
           v
  +-------------------+
  | Rollback KV cache |
  | keep = 2+1 = 3    |
  | Emit: d0, d1,     |
  |       bonus       |
  | = 3 tokens for    |
  |   1 forward pass! |
  +-------------------+

Buffer Capacity Guard

There is a practical limit on batch size: the token_ids buffer has a fixed allocation. The code guards against overflow:

int max_batch = (int)(state.scratch.token_ids.size / sizeof(uint32_t)) - 1;
if (n_draft + 1 > max_batch)
    n_draft = max_batch - 1;

The -1 accounts for the output slot: the buffer needs batch_size + 1 uint32 slots (input tokens plus one extra output slot at the end).

Speedup Analysis

The theoretical speedup from speculative decoding is:

  Speedup = (accepted + 1) / 1 = accepted + 1

That is: if K drafts are accepted, you get K+1 tokens from 1 forward pass instead of K+1 forward passes.

The actual speedup depends on the acceptance rate – what fraction of drafts match the model’s predictions. Let alpha be the per-token acceptance probability. With D drafts, the expected number of accepted tokens is:

  E[accepted] = sum_{k=0}^{D-1} k * alpha^k * (1-alpha)
              + D * alpha^D    (all D accepted)
              = sum_{k=1}^{D} alpha^k

  For D=4, alpha=0.5:  E[accepted] = 0.94
  For D=4, alpha=0.7:  E[accepted] = 1.77
  For D=4, alpha=0.9:  E[accepted] = 3.10

So the expected speedup is roughly:

  alpha=0.5:  1.94x  (barely worth it)
  alpha=0.7:  2.77x  (significant)
  alpha=0.9:  4.10x  (near-optimal)

But wait – there is overhead. The batched forward pass with 5 tokens is not free. It is faster than 5 separate forward passes (because attention over 5 positions is still cheap), but it does more work than a single forward pass.

In practice, on Apple Silicon, the overhead of a 5-token batch vs a 1-token decode is about 20-40%. So the actual speedup is closer to:

  alpha=0.5:  1.94 / 1.3 = 1.5x
  alpha=0.7:  2.77 / 1.3 = 2.1x
  alpha=0.9:  4.10 / 1.3 = 3.2x

Still quite good for alpha >= 0.7.

When Does N-gram Prediction Work Well?

The acceptance rate depends heavily on the content being generated:

High acceptance rate (alpha > 0.8):

  • Repetitive code patterns (boilerplate, getters/setters)
  • Template-based text (email signatures, legal disclaimers)
  • Continuation of previously seen phrases
  • JSON/XML with repeated structure
  • Code that mirrors the prompt (e.g., implementing similar functions)

Medium acceptance rate (alpha 0.4-0.7):

  • Natural language with common phrases
  • Code with some repetition but also novel logic
  • Structured output that partially follows established patterns

Low acceptance rate (alpha < 0.3):

  • Highly creative or novel text
  • Mathematical proofs with unique symbol sequences
  • First-time generation of a new pattern
  • Diverse conversational responses

The n-gram predictor excels in exactly the situations where LLMs are most boring – repetitive, predictable text. This is a happy coincidence: the tokens that are easy to predict are also the tokens that take the most wall-clock time (because there are many of them), so accelerating them has the most impact.

Comparison with Draft-Model Speculative Decoding

The classic approach uses a small draft model (e.g., a 160M parameter model drafting for a 7B target). How does n-gram prediction compare?

  +-------------------+-------------------+-------------------+
  |                   | N-gram            | Draft model       |
  +-------------------+-------------------+-------------------+
  | Setup cost        | None              | Load 2nd model    |
  | Memory            | ~1MB tables       | 100s MB - GBs     |
  | Draft speed       | ~0 (hash lookup)  | Fast but nonzero  |
  | Novel text        | Poor              | Good              |
  | Repetitive text   | Excellent         | Good              |
  | Code patterns     | Excellent         | Good              |
  | Implementation    | ~100 lines        | Full model stack  |
  | Correctness       | Exact (greedy)    | Exact (rejection) |
  +-------------------+-------------------+-------------------+

For akunu’s use case – a lightweight local inference engine on Apple Silicon where memory is precious – n-gram prediction is an excellent choice. It adds negligible memory overhead and provides significant speedup on the kinds of text that dominate many workloads (code, structured output, templates).

Interaction with Greedy vs Sampled Decode

A subtle point: the current implementation uses the regular dispatch_table (greedy argmax) for verification. This means speculative decode currently produces the same output as greedy decode – it is a pure speedup with no quality change.

Extending this to sampled decode requires the rejection sampling variant of speculative decoding, where draft tokens are accepted probabilistically rather than by exact match. The acceptance probability for draft token d when the model would sample t is:

  accept with probability min(1, P_target(d) / P_draft(d))

This preserves the target model’s sampling distribution exactly. Akunu does not currently implement this variant, but the n-gram predictor’s frequency counts could be used as approximate draft probabilities if needed.

The Predictor’s Sliding Window

The MAX_HISTORY = 512 sliding window serves two purposes:

  1. Memory bound. Without a limit, the frequency tables would grow indefinitely during long generations. The deque automatically evicts old tokens.

  2. Recency bias. Patterns from 10,000 tokens ago are less likely to repeat than patterns from 100 tokens ago. The sliding window implicitly prioritizes recent context.

Note that the frequency tables themselves are not pruned when tokens leave the sliding window. Only the history used for context matching is windowed. This means the tables accumulate counts from the entire generation, but prediction only uses the last 512 tokens as context. In practice, this works well – the tables capture global patterns while the context window provides relevance filtering.

Summary

Speculative decoding with n-gram prediction gives akunu a clean, zero-overhead way to accelerate the most tedious part of LLM inference: autoregressive decode of predictable tokens. The implementation is lean – about 200 lines total between the predictor and the decode loop – and the algorithm is remarkably simple: guess tokens from frequency tables, verify in batch, accept the longest matching prefix, roll back the KV cache for rejected drafts.

The key architectural insight is that the n-gram predictor and the batched forward pass are completely decoupled. Any prediction source could be swapped in (a draft model, a lookup table, a regex-based predictor for structured output) without changing the verification logic. The n-gram approach just happens to be the best cost/benefit ratio for a memory-constrained local inference engine.