Adding a New Model Architecture
One of Akunu’s design goals is that adding support for a new transformer architecture should not require touching the core inference loop, the dispatch table builder, or any Metal kernel. Instead, you define a data descriptor that captures the architecture’s unique properties, and the existing machinery handles the rest. This chapter walks through exactly how to do that.¹
The Data-Driven Approach
In many inference engines, adding a new architecture means writing a new forward pass function full of if (arch == "llama") ... else if (arch == "qwen") ... branches. Akunu takes a different approach: the ArchDescriptor struct encodes every architecture-specific decision as a data field, and build_dispatch_table() reads from the descriptor without ever branching on architecture name.
Here is the ArchDescriptor:
```cpp
struct ArchDescriptor {
    // Activation
    const char *activation_kernel;   // "silu_gate_f16", "gelu_gate_f16"
    // Embedding
    float embedding_scale;           // 0 = no scaling
    // Per-head norms
    bool has_qk_norm;                // Q/K head-level RMSNorm
    // Post-norms (Gemma-style)
    bool has_post_attn_norm;
    bool has_post_ffn_norm;
    const char *post_attn_norm_key;  // weight name suffix
    const char *post_ffn_norm_key;
    // MLX quantization
    int quant_bits;
    int quant_group_size;
    // RoPE
    const char *rope_kernel;         // fused kernel name
    const char *rope_standalone;     // standalone kernel name
    Buffer rope_freqs;               // precomputed frequencies
    // Output
    bool tie_embeddings;             // reuse embedding for logit projection
    // Encoder-decoder properties
    bool is_encoder_decoder;
    bool has_cross_attention;
    bool has_conv_frontend;
    bool has_bias;
    const char *norm_type;           // "rmsnorm" or "layernorm"
    const char *encoder_activation;
    bool is_embedding_model;
};
```
Every field in this struct corresponds to a decision point in the dispatch table builder. Let us trace how these fields affect inference.
How the Descriptor Drives Inference
Activation Kernel
```
activation_kernel = "silu_gate_f16" -> SwiGLU: SiLU(gate) * up
activation_kernel = "gelu_gate_f16" -> GeGLU:  GELU(gate) * up
activation_kernel = "gelu_f16"      -> Plain GELU (no gate, Whisper)
```
The table builder passes this directly to device.get_pipeline():
```cpp
Pipeline act_pso = device.get_pipeline(arch.activation_kernel);
```
If you have a new architecture that uses, say, ReLU-squared gating, you would add a relu_sq_gate_f16 Metal kernel and set activation_kernel = "relu_sq_gate_f16".
Embedding Scale
Gemma models multiply the embedding output by sqrt(dim) before feeding it into the transformer. This is captured as:
```cpp
float embedding_scale; // 0 = no scaling, sqrt(dim) for Gemma
```
The table builder checks:
```cpp
if (arch.embedding_scale > 0.0f) {
    // emit a temperature_scale_f16 dispatch
}
```
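As an illustrative sketch (not Akunu’s actual dispatch code), the same guard can be modeled on the CPU: a scale of 0 means the scaling step is simply never performed.

```cpp
#include <cassert>

// Sketch only: models the "0 = no scaling" convention of ArchDescriptor.
// The real implementation emits (or skips) a GPU dispatch; this CPU helper
// just demonstrates the guard.
inline void apply_embedding_scale(float *x, int n, float embedding_scale) {
    if (embedding_scale <= 0.0f) return; // LLaMA-style: no scale applied
    for (int i = 0; i < n; ++i) x[i] *= embedding_scale;
}
```

With this convention a single float field covers LLaMA (0), Gemma (sqrt(dim)), and any other positive scale, without an architecture branch.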
QK-Norm
Qwen3 and Gemma apply per-head RMSNorm to Q and K after projection. The flag has_qk_norm triggers emission of either:
- A fused `head_norm_rope_neox_kv_write_f16` kernel (if the RoPE style is compatible)
- Separate `head_rmsnorm_f16` dispatches for Q and K
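The fused-versus-separate decision can be sketched as a small selection function. This is an illustration of the logic, not the actual table-builder code; the kernel names are the ones mentioned above.

```cpp
#include <cassert>
#include <cstring>

// Hedged sketch: choose the QK-norm dispatch strategy from descriptor
// fields. The fused path is only valid when the RoPE kernel is the NeoX
// variant the fused kernel was written for; otherwise fall back to the
// standalone per-head norm.
inline const char *qk_norm_kernel(bool has_qk_norm, const char *rope_kernel) {
    if (!has_qk_norm) return nullptr; // no extra dispatch at all
    if (rope_kernel && strstr(rope_kernel, "neox"))
        return "head_norm_rope_neox_kv_write_f16";
    return "head_rmsnorm_f16";
}
```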
Post-Norms
Gemma is unusual in adding extra RMSNorm layers after the attention output projection and after the FFN down projection, before the residual add. The flags `has_post_attn_norm` and `has_post_ffn_norm` control whether these extra norm dispatches are emitted, and the `post_attn_norm_key` / `post_ffn_norm_key` strings tell the table builder which weight names to look up.
RoPE Variant
Different architectures use different RoPE implementations:
| Architecture | rope_kernel | Description |
|---|---|---|
| LLaMA | rope_qkv_write_f16 | Standard interleaved RoPE |
| Qwen3 | rope_neox_qkv_write_f16 | NeoX split-half RoPE |
| Gemma | rope_neox_qkv_write_f16 | NeoX split-half RoPE |
| Whisper | nullptr | No RoPE (sinusoidal PE) |
Setting rope_kernel = nullptr causes the table builder to skip the RoPE+KV-write dispatch entirely.
Tied Embeddings
When tie_embeddings = true, the logit projection reuses the token embedding weight matrix instead of a separate output.weight tensor:
```cpp
const char *logit_name = arch.tie_embeddings
                             ? "token_embedding.weight"
                             : "output.weight";
Buffer logit_w = weights.get_tensor(logit_name);
```
Walkthrough: Adding a Hypothetical Architecture
Let us add support for a hypothetical “Phoenix” architecture with these properties:
- SwiGLU activation (same as LLaMA)
- NeoX-style RoPE (same as Qwen3)
- No QK-norm
- No post-norms
- Tied embeddings
- Embedding scale of 1.0 / sqrt(dim) (inverse, unlike Gemma)
- Standard decoder-only (no encoder)
Step 1: Create the Factory Function
In arch_descriptor.h, add:
```cpp
inline ArchDescriptor arch_phoenix(int dim) {
    ArchDescriptor d = arch_llama();               // start from LLaMA defaults
    // Override what differs
    d.rope_kernel = "rope_neox_qkv_write_f16";     // NeoX RoPE
    d.rope_standalone = "rope_neox_f16";
    d.tie_embeddings = true;
    d.embedding_scale = 1.0f / sqrtf((float)dim);  // inverse scaling
    return d;
}
```
Notice we start from arch_llama() and only override what is different. This is the inheritance pattern – most architectures share 80% of their properties with LLaMA.
Step 2: Register in arch_from_config
Add a case to the dispatch function:
```cpp
inline ArchDescriptor arch_from_config(const char *arch_name, int dim) {
    if (strstr(arch_name, "phoenix")) return arch_phoenix(dim);
    if (strstr(arch_name, "whisper")) return arch_whisper();
    if (strstr(arch_name, "bert"))    return arch_bert();
    if (strstr(arch_name, "qwen"))    return arch_qwen3();
    if (strstr(arch_name, "gemma"))   return arch_gemma(dim);
    return arch_llama(); // default
}
```
The strstr matching is intentionally loose – it matches "phoenix", "phoenix-1.5", "phoenix_moe", etc. The first match wins, so ordering matters for architectures whose names are substrings of others.
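The ordering sensitivity can be demonstrated with a self-contained sketch (using a hypothetical `match_arch` helper rather than the real `arch_from_config`):

```cpp
#include <cassert>
#include <cstring>

// strstr does substring matching, so "phoenix" also matches "phoenix-1.5"
// and "phoenix_moe". If one architecture's key were a substring of
// another's, the more specific pattern would have to be tested first.
// match_arch is a hypothetical stand-in used for illustration.
inline const char *match_arch(const char *name) {
    if (strstr(name, "phoenix")) return "phoenix";
    if (strstr(name, "qwen"))    return "qwen3";
    if (strstr(name, "gemma"))   return "gemma";
    return "llama"; // default
}
```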
Step 3: Add Weight Name Mappings (if different from LLaMA)
If Phoenix uses different tensor names in its weight files, you need mappings. For GGUF, the tensor names are already canonical (llama.cpp normalizes them). For MLX/SafeTensors, add rules to kMLXRules:
```cpp
// If Phoenix uses different HF naming:
{"model.layers.{n}.self_attn.qkv_proj.weight",
 "layers.{n}.attention.qkv.weight"},
```
Most architectures follow the LLaMA naming convention in GGUF, so this step is often unnecessary.
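The `{n}` in these rules is a layer-index placeholder. A minimal expansion helper, written here as an assumption about how such rules could be instantiated (not the actual `kMLXRules` machinery):

```cpp
#include <cassert>
#include <string>

// Replace every "{n}" placeholder in a rule pattern with a concrete layer
// number. Illustrative only; the rule format is taken from the text above.
inline std::string expand_layer(const std::string &pattern, int n) {
    std::string out = pattern;
    const std::string key = "{n}";
    size_t pos;
    while ((pos = out.find(key)) != std::string::npos)
        out.replace(pos, key.size(), std::to_string(n));
    return out;
}
```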
Step 4: Handle Any Unique Weight Structure
If Phoenix has a fused QKV weight (a single matrix instead of separate Q, K, V), you might need to add logic to split it during loading. But if Phoenix follows the standard separate Q/K/V pattern (most models do), nothing extra is needed.
Step 5: Test
Load a Phoenix model and verify:
- Config extraction produces correct dimensions
- All weights are found and loaded
- The dispatch table has the expected number of commands
- The embedding scale is applied
- NeoX RoPE is used instead of standard RoPE
- Output projection uses the embedding weight
A minimal test:
```cpp
// Load model
akunu_model_t model = akunu_load("phoenix-7b.gguf", "akunu.metallib");
AkunuModelConfig cfg = akunu_get_config(model);

// Verify architecture was detected
assert(strstr(cfg.architecture, "phoenix") != nullptr);

// Generate a few tokens to verify correctness
uint32_t tokens[] = {1, 15043, 29892, 920}; // "Hello, world"
akunu_generate(model, tokens, 4, 10, sampling, callback, nullptr);
akunu_free_model(model);
```
For numerical correctness testing, compare output logits against a reference implementation (e.g., Hugging Face Transformers in Python) on the same input tokens.
What Does NOT Need to Change
This is the key point of the data-driven approach. When you add a new architecture, the following files remain untouched:
| File | Why it does not change |
|---|---|
| table_builder.cpp | Reads from ArchDescriptor, never branches on arch name |
| table_builder.h | Only the function signature, no arch-specific logic |
| All .metal files | Kernels are generic (F16 GEMV, RoPE, etc.) |
| device.h / device.cpp | Hardware abstraction, no model knowledge |
| dispatch_table.h | Format of commands, no model knowledge |
| chain_decoder.cpp | Replays dispatch table, no model knowledge |
| prefill.cpp | Uses same kernels via dtype descriptors |
| serve.h | HTTP server, model-agnostic |
The only files that change are:
| File | What changes |
|---|---|
| arch_descriptor.h | Add factory function + case in arch_from_config |
| mlx_weight_store.h | Add name mapping rules (if different naming) |
| weight_store.cpp | Handle any unique GGUF tensor layout (rare) |
That is typically 10-30 lines of code for a standard transformer variant.
When the Descriptor Is Not Enough
There are cases where the ArchDescriptor pattern does not cover an architectural difference:
Mixture of Experts (MoE). MoE models like Mixtral have a routing mechanism that selects a subset of FFN experts per token. This requires a fundamentally different dispatch pattern (router + sparse expert selection) that cannot be expressed as a boolean flag.
Novel attention patterns. If an architecture uses linear attention, sliding window attention with a non-standard pattern, or multi-query attention with a different head structure, the attention dispatch logic may need extension.
Non-standard normalization. If an architecture uses something other than RMSNorm or LayerNorm (e.g., CRMSNorm, QKNorm with different epsilon handling), a new kernel may be needed.
For these cases, the approach is:
- Add the new kernel(s) as described in the previous chapter
- Add new flags to `ArchDescriptor` to control the new behavior
- Add conditional logic to `build_dispatch_table()` gated on those flags
- The new logic only runs when the flag is set – existing architectures are unaffected
The goal is to keep build_dispatch_table() as a data-driven loop, not a forest of architecture-specific branches. Even when new logic is needed, it should be expressed as “if this flag, use this kernel” rather than “if architecture is X”.
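To make the distinction concrete, here is a toy emission loop gated purely on flags. `MiniDesc` and the command names are simplified stand-ins for illustration, not Akunu’s real types:

```cpp
#include <cassert>
#include <string>
#include <vector>

// Simplified descriptor: just the flags this sketch needs.
struct MiniDesc {
    bool has_qk_norm;
    bool has_post_attn_norm;
    const char *rope_kernel; // nullptr = skip the RoPE dispatch
};

// Per-layer command emission driven only by descriptor fields. Note that
// nothing here tests an architecture name -- "if this flag, use this
// kernel" is the only kind of branch.
inline std::vector<std::string> emit_layer(const MiniDesc &d) {
    std::vector<std::string> cmds = {"rmsnorm_f16", "qkv_gemv_f16"};
    if (d.has_qk_norm)        cmds.push_back("head_rmsnorm_f16");
    if (d.rope_kernel)        cmds.push_back(d.rope_kernel);
    cmds.push_back("attention_f16");
    if (d.has_post_attn_norm) cmds.push_back("post_attn_rmsnorm_f16");
    return cmds;
}
```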
Existing Architecture Descriptors
For reference, here are the current architectures and how they differ:
LLaMA (Default)
```
activation:       SiLU gate (SwiGLU)
embedding_scale:  none
qk_norm:          no
post_norms:       no
rope:             standard interleaved
tie_embeddings:   no
norm_type:        rmsnorm
bias:             no
```
Qwen3
```
activation:       SiLU gate (SwiGLU)   <- same as LLaMA
embedding_scale:  none                 <- same
qk_norm:          YES                  <- different
post_norms:       no                   <- same
rope:             NeoX split-half      <- different
tie_embeddings:   YES                  <- different
norm_type:        rmsnorm              <- same
bias:             no                   <- same
```
Gemma 3
```
activation:       GELU gate (GeGLU)      <- different
embedding_scale:  sqrt(dim)              <- different
qk_norm:          YES                    <- different
post_norms:       YES (both)             <- different
rope:             NeoX split-half        <- same as Qwen3
tie_embeddings:   YES                    <- same as Qwen3
norm_type:        rmsnorm                <- same
bias:             no                     <- same
sliding_window:   every 6th layer global <- unique
```
Whisper
```
activation:       GELU (no gate)       <- different
embedding_scale:  none                 <- same as LLaMA
qk_norm:          no                   <- same as LLaMA
post_norms:       no                   <- same
rope:             NONE (sinusoidal PE) <- different
tie_embeddings:   YES                  <- different
norm_type:        layernorm            <- different
bias:             YES (all layers)     <- different
encoder_decoder:  YES                  <- unique
cross_attention:  YES                  <- unique
conv_frontend:    YES                  <- unique
```
BERT (nomic-bert)
```
activation:       SiLU gate (SwiGLU)   <- same as LLaMA
embedding_scale:  none                 <- same
qk_norm:          no                   <- same
post_norms:       no                   <- same
rope:             NeoX split-half      <- same as Qwen3
tie_embeddings:   no                   <- same as LLaMA
norm_type:        rmsnorm              <- same
bias:             no                   <- same
is_embedding:     YES                  <- unique
```
The pattern is clear: each architecture differs from LLaMA in a small number of dimensions. The descriptor captures exactly those differences, and the table builder handles the combinatorics.
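The Qwen3 column above, expressed in the factory style from the walkthrough. The struct here is a reduced stand-in for `ArchDescriptor` so the example is self-contained; field values follow the comparison listing:

```cpp
#include <cassert>
#include <cstring>

// Reduced descriptor: only the fields Qwen3 differs in, plus activation.
struct Desc {
    const char *activation_kernel;
    bool has_qk_norm;
    const char *rope_kernel;
    bool tie_embeddings;
};

inline Desc mini_llama() {
    return {"silu_gate_f16", false, "rope_qkv_write_f16", false};
}

inline Desc mini_qwen3() {
    Desc d = mini_llama();                       // start from LLaMA defaults
    d.has_qk_norm = true;                        // per-head Q/K RMSNorm
    d.rope_kernel = "rope_neox_qkv_write_f16";   // NeoX split-half RoPE
    d.tie_embeddings = true;                     // reuse embedding for logits
    return d;                                    // activation stays SwiGLU
}
```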
Testing Methodology
When you add a new architecture, testing should cover multiple levels:
Level 1: Config Extraction
Verify that the model’s config (dimensions, layer count, head count, etc.) is extracted correctly from either GGUF metadata or MLX config.json:
```cpp
akunu_model_t model = akunu_load("phoenix-7b.gguf", "akunu.metallib");
AkunuModelConfig cfg = akunu_get_config(model);
assert(cfg.dim == 4096);
assert(cfg.n_layers == 32);
assert(cfg.n_heads == 32);
assert(cfg.n_kv_heads == 8);
assert(cfg.head_dim == 128);
```
Level 2: Dispatch Table Sanity
Check that the dispatch table has the expected structure. Count the total commands and verify key labels are present:
```
Expected for a standard 32-layer decoder-only model:
  1 embedding
  1 initial norm
  32 * ~11 commands per layer = 352 (varies with fusion)
  1 output norm
  1 logit projection
  1 argmax
  ~357 total commands
```
If your architecture has QK-norm, expect +1-2 commands per layer (or 0 if fused). If it has post-norms, expect +1-2 per layer. The numbers do not need to be exact, but a gross mismatch (e.g., 100 commands for a 32-layer model) indicates a problem.
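The arithmetic above can be wrapped in a rough estimator for use as a sanity-check baseline. The per-layer base of ~11 and the +2 increments are the approximate figures from the text, so treat the result as an expectation, not an exact count:

```cpp
#include <cassert>

// Back-of-envelope command count for a decoder-only dispatch table.
// Worst case (no fusion) is assumed for the optional norms.
inline int expected_commands(int n_layers, bool qk_norm, bool post_norms) {
    int per_layer = 11;          // approximate base per layer
    if (qk_norm)    per_layer += 2;
    if (post_norms) per_layer += 2;
    // 1 embedding + 1 initial norm + layers + 1 output norm
    // + 1 logit projection + 1 argmax
    return 2 + n_layers * per_layer + 3;
}
```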
Level 3: Single-Token Numerical Correctness
The gold standard is comparing logit outputs against a reference implementation. The procedure:
- Pick a known input sequence (e.g., `[1, 15043, 29892]`)
- Run the same input through Hugging Face Transformers in Python
- Extract the logits for the last position
- Run the same input through Akunu
- Compare the top-5 token IDs and their logit values
For quantized models, exact numerical match is not expected. But the top-1 token should match, and top-5 should overlap significantly. If the top-1 tokens diverge on simple inputs, something is wrong with the architecture implementation.
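The top-1/top-5 comparison can be automated with a small helper; this sketch assumes plain float logit vectors already copied back to the CPU:

```cpp
#include <algorithm>
#include <cassert>
#include <numeric>
#include <vector>

// Return the indices of the k largest logits, highest first.
inline std::vector<int> top_k_ids(const std::vector<float> &logits, int k) {
    std::vector<int> ids(logits.size());
    std::iota(ids.begin(), ids.end(), 0);
    std::partial_sort(ids.begin(), ids.begin() + k, ids.end(),
                      [&](int a, int b) { return logits[a] > logits[b]; });
    ids.resize(k);
    return ids;
}

// Count how many of a's top-k token IDs also appear in b's top-k.
inline int top_k_overlap(const std::vector<float> &a,
                         const std::vector<float> &b, int k) {
    std::vector<int> ta = top_k_ids(a, k), tb = top_k_ids(b, k);
    int overlap = 0;
    for (int id : ta)
        if (std::find(tb.begin(), tb.end(), id) != tb.end()) ++overlap;
    return overlap;
}
```

A reasonable acceptance rule for a quantized model would then be: top-1 IDs equal, and `top_k_overlap(ref, got, 5) >= 4` on simple inputs.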
Level 4: Generation Quality
Run the model on a few prompts and check that the output is coherent. This is subjective but important – subtle bugs (wrong RoPE variant, incorrect norm epsilon, swapped Q/K norms) can produce output that is plausible-looking but degraded in quality. Compare against the same model running in its native framework (e.g., MLX for MLX models, llama.cpp for GGUF models).
Troubleshooting Guide
Common issues when adding a new architecture:
| Problem | Likely cause | Fix |
|---|---|---|
| Model loads but generates gibberish | Wrong RoPE variant or wrong rope_theta | Check rope_kernel field and rope_theta in config |
| Outputs repeat the same token | Missing positional encoding (RoPE or PE) | Verify RoPE kernel is dispatched, or PE is added |
| First token correct, rest wrong | KV cache write not happening | Check that RoPE+KV write dispatch exists in table |
| Crash during weight loading | Tensor name mismatch | Add missing name mapping rules |
| NaN in output | Wrong norm epsilon or missing norm weight | Check norm_eps value and weight names |
| Quality worse than reference | Wrong activation (SiLU vs GELU) | Check activation_kernel field |
| Embedding values too large/small | Missing or wrong embedding scale | Check embedding_scale field |
Summary
Adding a new architecture to Akunu is a three-step process:
1. Write an arch_xxx() factory function (5-20 lines)
2. Add a case to arch_from_config() (1 line)
3. Add weight name mappings if needed (0-10 lines)

Done. No kernel changes. No table builder changes. No dispatch table changes. No server changes.
The ArchDescriptor pattern is what makes this possible. By encoding architectural decisions as data rather than control flow, the system remains modular: the kernel layer knows about math, the dispatch layer knows about GPU commands, and only the descriptor layer knows about transformer architecture variants. Each layer can be modified independently, and adding a new architecture is a localized change that cannot break existing ones.
---

¹ The data-driven approach is inspired by compiler design, where instruction selection is driven by pattern tables rather than hand-coded switch statements. The `ArchDescriptor` plays a similar role to an ISA descriptor in a retargetable code generator.