
Adding a New Model Architecture

One of Akunu’s design goals is that adding support for a new transformer architecture should not require touching the core inference loop, the dispatch table builder, or any Metal kernel. Instead, you define a data descriptor that captures the architecture’s unique properties, and the existing machinery handles the rest. This chapter walks through exactly how to do that.1

The Data-Driven Approach

In many inference engines, adding a new architecture means writing a new forward pass function full of if (arch == "llama") ... else if (arch == "qwen") ... branches. Akunu takes a different approach: the ArchDescriptor struct encodes every architecture-specific decision as a data field, and build_dispatch_table() reads from the descriptor without ever branching on architecture name.

Here is the ArchDescriptor:

struct ArchDescriptor {
    // Activation
    const char *activation_kernel;  // "silu_gate_f16", "gelu_gate_f16"

    // Embedding
    float embedding_scale;  // 0 = no scaling

    // Per-head norms
    bool has_qk_norm;  // Q/K head-level RMSNorm

    // Post-norms (Gemma-style)
    bool has_post_attn_norm;
    bool has_post_ffn_norm;
    const char *post_attn_norm_key;  // weight name suffix
    const char *post_ffn_norm_key;

    // MLX quantization
    int quant_bits;
    int quant_group_size;

    // RoPE
    const char *rope_kernel;      // fused kernel name
    const char *rope_standalone;  // standalone kernel name
    Buffer rope_freqs;            // precomputed frequencies

    // Output
    bool tie_embeddings;  // reuse embedding for logit projection

    // Encoder-decoder properties
    bool is_encoder_decoder;
    bool has_cross_attention;
    bool has_conv_frontend;
    bool has_bias;
    const char *norm_type;           // "rmsnorm" or "layernorm"
    const char *encoder_activation;
    bool is_embedding_model;
};

Every field in this struct corresponds to a decision point in the dispatch table builder. Let us trace how these fields affect inference.

How the Descriptor Drives Inference

Activation Kernel

activation_kernel = "silu_gate_f16"   -> SwiGLU: SiLU(gate) * up
activation_kernel = "gelu_gate_f16"   -> GeGLU:  GELU(gate) * up
activation_kernel = "gelu_f16"        -> Plain GELU (no gate, Whisper)

The table builder passes this directly to device.get_pipeline():

Pipeline act_pso = device.get_pipeline(arch.activation_kernel);

If you have a new architecture that uses, say, ReLU-squared gating, you would add a relu_sq_gate_f16 Metal kernel and set activation_kernel = "relu_sq_gate_f16".

Embedding Scale

Gemma models multiply the embedding output by sqrt(dim) before feeding it into the transformer. This is captured as:

float embedding_scale;  // 0 = no scaling, sqrt(dim) for Gemma

The table builder checks:

if (arch.embedding_scale > 0.0f) {
    // emit a temperature_scale_f16 dispatch
}

QK-Norm

Qwen3 and Gemma apply per-head RMSNorm to Q and K after projection. The flag has_qk_norm triggers emission of either:

  • A fused head_norm_rope_neox_kv_write_f16 kernel (if compatible RoPE style)
  • Separate head_rmsnorm_f16 dispatches for Q and K

Post-Norms

Gemma is unusual in that it applies additional RMSNorm layers after the attention output projection and after the FFN down projection, before the residual add. The flags has_post_attn_norm and has_post_ffn_norm control whether these extra norm dispatches are emitted, and the post_attn_norm_key / post_ffn_norm_key strings tell the table builder which weight names to look up.
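As a scalar sketch of where these post-norms sit relative to the residual: the sublayer output is normalized first, then added to the residual stream. The rmsnorm below is a plain C++ stand-in for the Metal kernel, with an assumed epsilon of 1e-6.

```cpp
#include <cmath>
#include <vector>

// Scalar stand-in for the RMSNorm kernel (epsilon assumed to be 1e-6).
std::vector<float> rmsnorm(const std::vector<float> &x,
                           const std::vector<float> &w, float eps = 1e-6f) {
    float ss = 0.0f;
    for (float v : x) ss += v * v;
    float scale = 1.0f / std::sqrt(ss / x.size() + eps);
    std::vector<float> out(x.size());
    for (size_t i = 0; i < x.size(); i++) out[i] = x[i] * scale * w[i];
    return out;
}

// Post-norm placement: normalize the attention output, THEN add the residual.
std::vector<float> attn_block_output(const std::vector<float> &residual,
                                     const std::vector<float> &attn_out,
                                     const std::vector<float> &post_w,
                                     bool has_post_attn_norm) {
    std::vector<float> y = has_post_attn_norm ? rmsnorm(attn_out, post_w)
                                              : attn_out;
    std::vector<float> out(residual.size());
    for (size_t i = 0; i < out.size(); i++) out[i] = residual[i] + y[i];
    return out;
}
```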

RoPE Variant

Different architectures use different RoPE implementations:

Architecture   rope_kernel               Description
LLaMA          rope_qkv_write_f16        Standard interleaved RoPE
Qwen3          rope_neox_qkv_write_f16   NeoX split-half RoPE
Gemma          rope_neox_qkv_write_f16   NeoX split-half RoPE
Whisper        nullptr                   No RoPE (sinusoidal PE)

Setting rope_kernel = nullptr causes the table builder to skip the RoPE+KV-write dispatch entirely.

Tied Embeddings

When tie_embeddings = true, the logit projection reuses the token embedding weight matrix instead of a separate output.weight tensor:

const char *logit_name = arch.tie_embeddings
    ? "token_embedding.weight"
    : "output.weight";
Buffer logit_w = weights.get_tensor(logit_name);

Walkthrough: Adding a Hypothetical Architecture

Let us add support for a hypothetical “Phoenix” architecture with these properties:

  • SwiGLU activation (same as LLaMA)
  • NeoX-style RoPE (same as Qwen3)
  • No QK-norm
  • No post-norms
  • Tied embeddings
  • Embedding scale of 1.0 / sqrt(dim) (inverse, unlike Gemma)
  • Standard decoder-only (no encoder)

Step 1: Create the Factory Function

In arch_descriptor.h, add:

inline ArchDescriptor arch_phoenix(int dim) {
    ArchDescriptor d = arch_llama();  // start from LLaMA defaults

    // Override what differs
    d.rope_kernel = "rope_neox_qkv_write_f16";  // NeoX RoPE
    d.rope_standalone = "rope_neox_f16";
    d.tie_embeddings = true;
    d.embedding_scale = 1.0f / sqrtf((float)dim);  // inverse scaling

    return d;
}

Notice we start from arch_llama() and only override what is different. This is the inheritance pattern – most architectures share 80% of their properties with LLaMA.

Step 2: Register in arch_from_config

Add a case to the dispatch function:

inline ArchDescriptor arch_from_config(const char *arch_name, int dim) {
    if (strstr(arch_name, "phoenix"))
        return arch_phoenix(dim);
    if (strstr(arch_name, "whisper"))
        return arch_whisper();
    if (strstr(arch_name, "bert"))
        return arch_bert();
    if (strstr(arch_name, "qwen"))
        return arch_qwen3();
    if (strstr(arch_name, "gemma"))
        return arch_gemma(dim);
    return arch_llama();  // default
}

The strstr matching is intentionally loose – it matches "phoenix", "phoenix-1.5", "phoenix_moe", etc. The first match wins, so ordering matters for architectures whose names are substrings of others.
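A condensed, self-contained sketch of that matching logic (mirroring the arch_from_config above, minus the dim arguments) makes first-match-wins easy to see:

```cpp
#include <cstring>

// Condensed sketch of arch_from_config's matching: substring match,
// first hit wins, LLaMA as the fallback.
const char *detect_arch(const char *name) {
    if (strstr(name, "phoenix")) return "phoenix";
    if (strstr(name, "whisper")) return "whisper";
    if (strstr(name, "qwen"))    return "qwen3";
    if (strstr(name, "gemma"))   return "gemma";
    return "llama";  // default
}
```

Because "phoenix" is tested before "qwen", a hypothetical merged model named "phoenix-qwen-hybrid" would resolve to phoenix; swap the two lines and it resolves to qwen3 instead.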

Step 3: Add Weight Name Mappings (if different from LLaMA)

If Phoenix uses different tensor names in its weight files, you need mappings. For GGUF, the tensor names are already canonical (llama.cpp normalizes them). For MLX/SafeTensors, add rules to kMLXRules:

// If Phoenix uses different HF naming:
{"model.layers.{n}.self_attn.qkv_proj.weight",
 "layers.{n}.attention.qkv.weight"},

Most architectures follow the LLaMA naming convention in GGUF, so this step is often unnecessary.

Step 4: Handle Any Unique Weight Structure

If Phoenix has a fused QKV weight (a single matrix instead of separate Q, K, V), you might need to add logic to split it during loading. But if Phoenix follows the standard separate Q/K/V pattern (most models do), nothing extra is needed.

Step 5: Test

Load a Phoenix model and verify:

  1. Config extraction produces correct dimensions
  2. All weights are found and loaded
  3. The dispatch table has the expected number of commands
  4. The embedding scale is applied
  5. NeoX RoPE is used instead of standard RoPE
  6. Output projection uses the embedding weight

A minimal test:

// Load model
akunu_model_t model = akunu_load("phoenix-7b.gguf", "akunu.metallib");
AkunuModelConfig cfg = akunu_get_config(model);

// Verify architecture was detected
assert(strstr(cfg.architecture, "phoenix") != nullptr);

// Generate a few tokens to verify correctness
uint32_t tokens[] = {1, 15043, 29892, 920};  // "Hello, world"
akunu_generate(model, tokens, 4, 10, sampling, callback, nullptr);

akunu_free_model(model);

For numerical correctness testing, compare output logits against a reference implementation (e.g., Hugging Face Transformers in Python) on the same input tokens.

What Does NOT Need to Change

This is the key point of the data-driven approach. When you add a new architecture, the following files remain untouched:

File                    Why it does not change
table_builder.cpp       Reads from ArchDescriptor, never branches on arch name
table_builder.h         Only the function signature, no arch-specific logic
All .metal files        Kernels are generic (F16 GEMV, RoPE, etc.)
device.h / device.cpp   Hardware abstraction, no model knowledge
dispatch_table.h        Format of commands, no model knowledge
chain_decoder.cpp       Replays dispatch table, no model knowledge
prefill.cpp             Uses same kernels via dtype descriptors
serve.h                 HTTP server, model-agnostic

The only files that change are:

File                    What changes
arch_descriptor.h       Add factory function + case in arch_from_config
mlx_weight_store.h      Add name mapping rules (if different naming)
weight_store.cpp        Handle any unique GGUF tensor layout (rare)

That is typically 10-30 lines of code for a standard transformer variant.

When the Descriptor Is Not Enough

There are cases where the ArchDescriptor pattern does not cover an architectural difference:

Mixture of Experts (MoE). MoE models like Mixtral have a routing mechanism that selects a subset of FFN experts per token. This requires a fundamentally different dispatch pattern (router + sparse expert selection) that cannot be expressed as a boolean flag.

Novel attention patterns. If an architecture uses linear attention, sliding window attention with a non-standard pattern, or multi-query attention with a different head structure, the attention dispatch logic may need extension.

Non-standard normalization. If an architecture uses something other than RMSNorm or LayerNorm (e.g., CRMSNorm, QKNorm with different epsilon handling), a new kernel may be needed.

For these cases, the approach is:

  1. Add the new kernel(s) as described in the previous chapter
  2. Add new flags to ArchDescriptor to control the new behavior
  3. Add conditional logic to build_dispatch_table() gated on those flags
  4. The new logic only runs when the flag is set – existing architectures are unaffected

The goal is to keep build_dispatch_table() as a data-driven loop, not a forest of architecture-specific branches. Even when new logic is needed, it should be expressed as “if this flag, use this kernel” rather than “if architecture is X”.
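Steps 2 and 3 together can be sketched as follows. The has_expert_router flag and the kernel names here are hypothetical, invented to illustrate the extension pattern; they are not fields of the current descriptor.

```cpp
#include <string>
#include <vector>

// Hypothetical extension sketch: a new capability becomes a descriptor flag,
// and the builder gates new dispatches on that flag alone.
struct ArchDescriptor {
    bool has_expert_router = false;       // new flag, off by default
    const char *router_kernel = nullptr;  // e.g. "moe_router_f16" (hypothetical)
};

void emit_ffn(const ArchDescriptor &arch, std::vector<std::string> &cmds) {
    if (arch.has_expert_router) {
        cmds.push_back(arch.router_kernel);  // "if this flag, use this kernel"
        cmds.push_back("expert_gather");
    }
    cmds.push_back("ffn");  // existing architectures take only this path
}
```

Architectures that never set the flag emit exactly the same commands as before, which is what keeps the change localized.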

Existing Architecture Descriptors

For reference, here are the current architectures and how they differ:

LLaMA (Default)

activation:     SiLU gate (SwiGLU)
embedding_scale: none
qk_norm:        no
post_norms:     no
rope:           standard interleaved
tie_embeddings: no
norm_type:      rmsnorm
bias:           no

Qwen3

activation:     SiLU gate (SwiGLU)     <- same as LLaMA
embedding_scale: none                   <- same
qk_norm:        YES                     <- different
post_norms:     no                      <- same
rope:           NeoX split-half         <- different
tie_embeddings: YES                     <- different
norm_type:      rmsnorm                 <- same
bias:           no                      <- same

Gemma 3

activation:     GELU gate (GeGLU)       <- different
embedding_scale: sqrt(dim)              <- different
qk_norm:        YES                     <- different
post_norms:     YES (both)              <- different
rope:           NeoX split-half         <- same as Qwen3
tie_embeddings: YES                     <- same as Qwen3
norm_type:      rmsnorm                 <- same
bias:           no                      <- same
sliding_window: every 6th layer global  <- unique

Whisper

activation:     GELU (no gate)          <- different
embedding_scale: none                   <- same as LLaMA
qk_norm:        no                      <- same as LLaMA
post_norms:     no                      <- same
rope:           NONE (sinusoidal PE)    <- different
tie_embeddings: YES                     <- different
norm_type:      layernorm               <- different
bias:           YES (all layers)        <- different
encoder_decoder: YES                    <- unique
cross_attention: YES                    <- unique
conv_frontend:   YES                    <- unique

BERT (nomic-bert)

activation:     SiLU gate (SwiGLU)      <- same as LLaMA
embedding_scale: none                   <- same
qk_norm:        no                      <- same
post_norms:     no                      <- same
rope:           NeoX split-half         <- same as Qwen3
tie_embeddings: no                      <- same as LLaMA
norm_type:      rmsnorm                 <- same
bias:           no                      <- same
is_embedding:   YES                     <- unique

The pattern is clear: each architecture differs from LLaMA in a small number of dimensions. The descriptor captures exactly those differences, and the table builder handles the combinatorics.

Testing Methodology

When you add a new architecture, testing should cover multiple levels:

Level 1: Config Extraction

Verify that the model’s config (dimensions, layer count, head count, etc.) is extracted correctly from either GGUF metadata or MLX config.json:

akunu_model_t model = akunu_load("phoenix-7b.gguf", "akunu.metallib");
AkunuModelConfig cfg = akunu_get_config(model);

assert(cfg.dim == 4096);
assert(cfg.n_layers == 32);
assert(cfg.n_heads == 32);
assert(cfg.n_kv_heads == 8);
assert(cfg.head_dim == 128);

Level 2: Dispatch Table Sanity

Check that the dispatch table has the expected structure. Count the total commands and verify key labels are present:

Expected for a standard 32-layer decoder-only model:
  1  embedding
  1  initial norm
  32 * ~11 commands per layer = 352  (varies with fusion)
  1  output norm
  1  logit projection
  1  argmax
  ~357 total commands

If your architecture has QK-norm, expect +1-2 commands per layer (or 0 if fused). If it has post-norms, expect +1-2 per layer. The numbers do not need to be exact, but a gross mismatch (e.g., 100 commands for a 32-layer model) indicates a problem.
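The arithmetic above can be wrapped in a rough sanity check. This is a sketch: the per-layer baseline of ~11 commands and the factor-of-two tolerance are assumptions that vary with kernel fusion, not fixed constants of the system.

```cpp
// Rough dispatch-table sanity check: estimate the expected command count
// from the layer count and per-layer baseline, and flag gross mismatches.
int expected_commands(int n_layers, int per_layer,
                      bool has_qk_norm, bool has_post_norms) {
    // Per-layer extras assume unfused dispatches (fusion reduces them).
    int layer_cmds = per_layer + (has_qk_norm ? 1 : 0) + (has_post_norms ? 2 : 0);
    // embedding + initial norm + layers + output norm + logit projection + argmax
    return 1 + 1 + n_layers * layer_cmds + 1 + 1 + 1;
}

bool grossly_wrong(int actual, int expected) {
    // Tolerate fusion-driven variation; reject order-of-magnitude mismatches.
    return actual < expected / 2 || actual > expected * 2;
}
```

For the standard 32-layer model above, expected_commands(32, 11, false, false) reproduces the ~357 figure, while an actual count of 100 would trip the check.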

Level 3: Single-Token Numerical Correctness

The gold standard is comparing logit outputs against a reference implementation. The procedure:

  1. Pick a known input sequence (e.g., [1, 15043, 29892])
  2. Run the same input through Hugging Face Transformers in Python
  3. Extract the logits for the last position
  4. Run the same input through Akunu
  5. Compare the top-5 token IDs and their logit values

For quantized models, exact numerical match is not expected. But the top-1 token should match, and top-5 should overlap significantly. If the top-1 tokens diverge on simple inputs, something is wrong with the architecture implementation.
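A minimal sketch of steps 3-5: extract the k highest-logit token IDs from each side and count the overlap. The reference logits would come from the Python run (e.g. dumped to a file); the helper names here are illustrative.

```cpp
#include <algorithm>
#include <numeric>
#include <vector>

// Return the indices of the k largest logits, highest first.
std::vector<int> top_k_ids(const std::vector<float> &logits, int k) {
    std::vector<int> ids(logits.size());
    std::iota(ids.begin(), ids.end(), 0);
    std::partial_sort(ids.begin(), ids.begin() + k, ids.end(),
                      [&](int a, int b) { return logits[a] > logits[b]; });
    ids.resize(k);
    return ids;
}

// Count how many token IDs the two top-k lists share.
int overlap(const std::vector<int> &a, const std::vector<int> &b) {
    int n = 0;
    for (int id : a)
        if (std::find(b.begin(), b.end(), id) != b.end()) n++;
    return n;
}
```

A reasonable acceptance criterion per the text: top_k_ids(ref, 5)[0] == top_k_ids(akunu, 5)[0], and overlap of the two top-5 lists well above zero.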

Level 4: Generation Quality

Run the model on a few prompts and check that the output is coherent. This is subjective but important – subtle bugs (wrong RoPE variant, incorrect norm epsilon, swapped Q/K norms) can produce output that is plausible-looking but degraded in quality. Compare against the same model running in its native framework (e.g., MLX for MLX models, llama.cpp for GGUF models).

Troubleshooting Guide

Common issues when adding a new architecture:

Problem                               Likely cause                                Fix
Model loads but generates gibberish   Wrong RoPE variant or wrong rope_theta      Check rope_kernel field and rope_theta in config
Outputs repeat the same token         Missing positional encoding (RoPE or PE)    Verify RoPE kernel is dispatched, or PE is added
First token correct, rest wrong       KV cache write not happening                Check that RoPE+KV write dispatch exists in table
Crash during weight loading           Tensor name mismatch                        Add missing name mapping rules
NaN in output                         Wrong norm epsilon or missing norm weight   Check norm_eps value and weight names
Quality worse than reference          Wrong activation (SiLU vs GELU)             Check activation_kernel field
Embedding values too large/small      Missing or wrong embedding scale            Check embedding_scale field

Summary

Adding a new architecture to Akunu is a three-step process:

1. Write an arch_xxx() factory function (5-20 lines)
   |
2. Add case to arch_from_config() (1 line)
   |
3. Add weight name mappings if needed (0-10 lines)
   |
   Done. No kernel changes. No table builder changes.
   No dispatch table changes. No server changes.

The ArchDescriptor pattern is what makes this possible. By encoding architectural decisions as data rather than control flow, the system remains modular: the kernel layer knows about math, the dispatch layer knows about GPU commands, and only the descriptor layer knows about transformer architecture variants. Each layer can be modified independently, and adding a new architecture is a localized change that cannot break existing ones.


  1. The data-driven approach is inspired by compiler design, where instruction selection is driven by pattern tables rather than hand-coded switch statements. The ArchDescriptor plays a similar role to an ISA descriptor in a retargetable code generator.