Architecture Descriptors: Data-Driven Design

Here is a problem that every multi-architecture inference engine faces: Llama uses SiLU activation and standard RoPE. Qwen3 uses SiLU but NeoX-style RoPE and per-head QK norms. Gemma uses GELU activation, NeoX RoPE, QK norms, post-attention norms, post-FFN norms, and scales the embedding output by sqrt(dim). Whisper uses LayerNorm instead of RMSNorm, has bias terms on every linear layer, uses sinusoidal positional encoding instead of RoPE, and has an encoder-decoder architecture.

How do you handle all of this without drowning in if (arch == "llama") ... else if (arch == "qwen3") ... else if (arch == "gemma") ... branches scattered across your codebase?

Akunu’s answer is the ArchDescriptor: a plain data struct that captures all architecture-specific behavior as fields. The table builder and prefill engine read from this struct – they never branch on architecture name. Adding a new architecture means writing one factory function and one case in the lookup function. Zero changes to the core engine.

The Problem in Detail

Consider the table builder. It needs to emit dispatch commands for one transformer layer. Here is a partial list of things that vary by architecture:

Aspect               Llama                   Qwen3              Gemma              Whisper
Activation           SiLU gate               SiLU gate          GELU gate          GELU (no gate)
RoPE style           Standard (interleaved)  NeoX (split-half)  NeoX (split-half)  None (sinusoidal PE)
QK head norms        No                      Yes                Yes                No
Post-attention norm  No                      No                 Yes                No
Post-FFN norm        No                      No                 Yes                No
Embedding scaling    No                      No                 Yes (sqrt(dim))    No
Tied embeddings      No                      Yes                Yes                Yes
Encoder-decoder      No                      No                 No                 Yes
Cross attention      No                      No                 No                 Yes
Linear bias          No                      No                 No                 Yes
Norm type            RMSNorm                 RMSNorm            RMSNorm            LayerNorm

Without a descriptor, you would need if-else branches for each of these in every place the table builder makes an architecture-dependent decision. That is 11+ branch points per layer, across hundreds of lines of code. It is unreadable, error-prone, and a maintenance nightmare.

The ArchDescriptor Struct

The descriptor is defined in src/core/arch_descriptor.h:

struct ArchDescriptor {
    // Activation
    const char *activation_kernel;  // "silu_gate_f16", "gelu_gate_f16", or "gelu_f16"

    // Embedding
    float embedding_scale;  // 0 = no scaling, >0 = scale by this value

    // Per-head norms
    bool has_qk_norm;

    // Post-norms (Gemma-style)
    bool has_post_attn_norm;
    bool has_post_ffn_norm;
    const char *post_attn_norm_key;  // weight suffix: "post_attention_norm"
    const char *post_ffn_norm_key;   // "post_ffw_norm"

    // MLX quantization
    int quant_bits;        // 0 = not MLX, 3/4/6/8 = MLX quant bits
    int quant_group_size;  // typically 64

    // RoPE
    const char *rope_kernel;      // fused RoPE+KV write kernel name
    const char *rope_standalone;  // standalone RoPE kernel name
    Buffer rope_freqs;            // precomputed frequency divisors

    // Output
    bool tie_embeddings;  // use token_embedding.weight for logit projection

    // Encoder-decoder
    bool is_encoder_decoder;
    bool has_cross_attention;
    bool has_conv_frontend;
    bool has_bias;
    const char *norm_type;          // "rmsnorm" or "layernorm"
    const char *encoder_activation;
    bool is_embedding_model;        // BERT-style encoder-only
};

Every field is a simple type: bool, int, float, const char*, or Buffer (a POD struct). No virtual methods. No inheritance. No dynamism. The descriptor is created once and read many times.

Field Categories

The fields fall into several categories:

Kernel selection fields (activation_kernel, rope_kernel, rope_standalone, encoder_activation): These are kernel function names that get passed directly to device.get_pipeline(). The table builder does not care what activation function is used – it just dispatches whatever kernel the descriptor says.

Boolean feature flags (has_qk_norm, has_post_attn_norm, has_post_ffn_norm, tie_embeddings, is_encoder_decoder, has_cross_attention, has_conv_frontend, has_bias, is_embedding_model): These control whether certain dispatch commands are emitted. A false flag means the corresponding block is simply skipped.

Numeric parameters (embedding_scale, quant_bits, quant_group_size): These are values that get baked into kernel parameters.

Weight key strings (post_attn_norm_key, post_ffn_norm_key): These are the suffixes used to look up weight tensors. Gemma calls its post-attention norm weights layers.N.post_attention_norm.weight, while other architectures do not have them at all.

Factory Functions

Each supported architecture has a factory function that returns a fully initialized descriptor:

arch_llama() – The Default

inline ArchDescriptor arch_llama() {
    ArchDescriptor d = {};
    d.activation_kernel = "silu_gate_f16";
    d.embedding_scale = 0.0f;
    d.has_qk_norm = false;
    d.has_post_attn_norm = false;
    d.has_post_ffn_norm = false;
    d.rope_kernel = "rope_qkv_write_f16";  // standard (interleaved)
    d.rope_standalone = "rope_f16";
    d.tie_embeddings = false;
    d.is_encoder_decoder = false;
    d.has_bias = false;
    d.norm_type = "rmsnorm";
    return d;
}

Llama is the simplest architecture and serves as the default. SiLU-gated activation, standard RoPE, no special norms, no encoder. Most LLaMA-family models (Llama 2, Llama 3, Mistral, etc.) use this descriptor.1

arch_qwen3() – Incremental Derivation

inline ArchDescriptor arch_qwen3() {
    ArchDescriptor d = arch_llama();  // start from llama defaults
    d.has_qk_norm = true;
    d.rope_kernel = "rope_neox_qkv_write_f16";  // NeoX (split-half)
    d.rope_standalone = "rope_neox_f16";
    d.tie_embeddings = true;
    return d;
}

Qwen3 starts from Llama and overrides three things: adds QK head norms, switches to NeoX-style RoPE, and ties the embedding weights. This derivation pattern makes the differences explicit and minimizes redundancy.

arch_gemma() – The Complex One

inline ArchDescriptor arch_gemma(int dim) {
    ArchDescriptor d = arch_llama();  // start from llama defaults
    d.activation_kernel = "gelu_gate_f16";
    d.embedding_scale = sqrtf((float)dim);
    d.has_qk_norm = true;
    d.has_post_attn_norm = true;
    d.has_post_ffn_norm = true;
    d.post_attn_norm_key = "post_attention_norm";
    d.post_ffn_norm_key = "post_ffw_norm";
    d.rope_kernel = "rope_neox_qkv_write_f16";
    d.rope_standalone = "rope_neox_f16";
    d.tie_embeddings = true;
    return d;
}

Gemma is the most feature-rich decoder-only architecture. Notice that embedding_scale is computed from the model dimension – this is why the factory takes dim as a parameter. The value sqrt(dim) is specific to Gemma’s architectural design.2

arch_whisper() – Encoder-Decoder

inline ArchDescriptor arch_whisper() {
    ArchDescriptor d = {};
    d.activation_kernel = "gelu_f16";  // plain GELU (no gate)
    d.rope_kernel = nullptr;  // no RoPE
    d.rope_standalone = nullptr;
    d.tie_embeddings = true;
    d.is_encoder_decoder = true;
    d.has_cross_attention = true;
    d.has_conv_frontend = true;
    d.has_bias = true;
    d.norm_type = "layernorm";
    d.encoder_activation = "gelu_f16";
    return d;
}

Whisper does NOT derive from Llama – it starts from a zeroed struct because almost everything is different. No RoPE, no gated activation, LayerNorm instead of RMSNorm, bias on every linear layer, and an encoder-decoder architecture with cross-attention and a convolutional audio frontend.3

arch_bert() – Embedding Model

inline ArchDescriptor arch_bert() {
    ArchDescriptor d = {};
    d.activation_kernel = "silu_gate_f16";  // SwiGLU (nomic-bert, modernBERT)
    d.rope_kernel = "rope_neox_qkv_write_f16";
    d.rope_standalone = "rope_neox_f16";
    d.norm_type = "rmsnorm";
    d.is_embedding_model = true;
    return d;
}

BERT-style models (specifically modern variants like nomic-bert) use a LLaMA-like architecture with bidirectional attention. The is_embedding_model flag tells the inference engine to skip autoregressive decoding and instead return the hidden states for mean-pooling.
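To illustrate how such a flag might be consumed downstream, here is a hedged sketch. All names in it (`ArchFlags`, `run_model`, `run_forward`, `mean_pool`) are hypothetical stand-ins, not Akunu's actual API:

```cpp
#include <cassert>
#include <vector>

// Hypothetical sketch of a dispatch site that reads is_embedding_model.
// All names below are illustrative; they are not Akunu's real API.
struct ArchFlags {
    bool is_embedding_model = false;
};

// Stand-in for the transformer forward pass: returns per-token hidden states.
std::vector<float> run_forward() { return {1.0f, 3.0f}; }

// Mean-pool the hidden states into a single embedding vector
// (one float here, to keep the sketch small).
std::vector<float> mean_pool(const std::vector<float>& hidden) {
    float sum = 0.0f;
    for (float h : hidden) sum += h;
    return {sum / (float)hidden.size()};
}

std::vector<float> run_model(const ArchFlags& arch) {
    std::vector<float> hidden = run_forward();
    if (arch.is_embedding_model)
        return mean_pool(hidden);  // encoder-only: return pooled embedding
    return hidden;                 // decoder: hidden states go to the sampler
}
```

The point is the same as everywhere else in this design: the engine reads one boolean and skips or emits a block; it never asks "is this BERT?".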

The Lookup Function

inline ArchDescriptor arch_from_config(const char *arch_name, int dim) {
    if (strstr(arch_name, "whisper")) return arch_whisper();
    if (strstr(arch_name, "bert"))    return arch_bert();
    if (strstr(arch_name, "qwen"))    return arch_qwen3();
    if (strstr(arch_name, "gemma"))   return arch_gemma(dim);
    return arch_llama();  // Default: LLaMA-like
}

This function reads the architecture field from the GGUF metadata (a string like “llama”, “qwen3”, “gemma2”, “whisper”) and returns the appropriate descriptor. The strstr matching is intentionally loose – “gemma2” and “gemma3” both match “gemma”, “qwen2” and “qwen3” both match “qwen”. This works because the architectural differences between Gemma 2 and Gemma 3, or between Qwen 2 and Qwen 3, are captured in the model’s weight metadata (different head counts, etc.), not in the descriptor.

The default fallback is arch_llama(), which handles the vast majority of GGUF models in the wild (Llama, Mistral, CodeLlama, etc.).
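To see the loose matching in action, here is a compressed, self-contained rendition of the factory-plus-lookup pattern. The field set is trimmed, and the trailing underscores mark these as stand-ins for the real functions:

```cpp
#include <cassert>
#include <cmath>
#include <cstring>

// Trimmed-down stand-in for ArchDescriptor, enough to show the lookup.
struct ArchDesc {
    bool has_qk_norm = false;
    bool has_post_attn_norm = false;
    float embedding_scale = 0.0f;
};

inline ArchDesc arch_llama_() { return ArchDesc{}; }

inline ArchDesc arch_qwen3_() {
    ArchDesc d = arch_llama_();
    d.has_qk_norm = true;
    return d;
}

inline ArchDesc arch_gemma_(int dim) {
    ArchDesc d = arch_llama_();
    d.has_qk_norm = true;
    d.has_post_attn_norm = true;
    d.embedding_scale = sqrtf((float)dim);
    return d;
}

// strstr makes "gemma2", "gemma3", "qwen2.5", etc. all resolve correctly.
inline ArchDesc arch_from_config_(const char *name, int dim) {
    if (strstr(name, "qwen"))  return arch_qwen3_();
    if (strstr(name, "gemma")) return arch_gemma_(dim);
    return arch_llama_();  // everything else is treated as LLaMA-like
}
```

For example, `arch_from_config_("gemma2", 2304)` gets the Gemma post-norms and an embedding scale of sqrt(2304) = 48, while an unrecognized string like "mistral" falls through to the llama defaults.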

How the Table Builder Uses Descriptors

Let’s trace through the table builder to see how the descriptor eliminates branching.

Activation Kernel

// In the SiLU/GELU dispatch:
Pipeline act_pso = device.get_pipeline(arch.activation_kernel);

The table builder does not know or care whether the activation is SiLU or GELU. It dispatches whatever arch.activation_kernel says. This single line replaces:

// Without descriptors (do NOT do this):
Pipeline act_pso;
if (cfg.architecture == "llama" || cfg.architecture == "qwen3")
    act_pso = device.get_pipeline("silu_gate_f16");
else if (cfg.architecture == "gemma")
    act_pso = device.get_pipeline("gelu_gate_f16");
else if (cfg.architecture == "whisper")
    act_pso = device.get_pipeline("gelu_f16");

RoPE Selection

// In the RoPE dispatch:
if (arch.rope_kernel) {
    Pipeline rope_pso = device.get_pipeline(arch.rope_kernel);
    // ... emit RoPE command ...
}
// If rope_kernel is nullptr (Whisper), the block is skipped entirely

Post-Norms

if (arch.has_post_attn_norm) {
    snprintf(name, sizeof(name), "layers.%d.%s.weight",
             layer, arch.post_attn_norm_key);
    Buffer post_norm_w = weights.get_tensor(name);
    emit_standalone_norm(table, device, chip, scratch.residual,
                         post_norm_w, scratch.post_norm, dim, cfg.norm_eps);
}

The boolean flag controls whether the block exists; the key string controls which weight tensor is loaded. For Llama and Qwen3, has_post_attn_norm is false and this code is never reached.

Embedding Scaling

if (arch.embedding_scale > 0.0f) {
    float scale = arch.embedding_scale;
    CmdBuilder(table, device.get_pipeline("temperature_scale_f16"), ...)
        .buf(scratch.h0, 0)
        .params(scale, 1)
        .emit();
}

Only Gemma has embedding_scale > 0, so only Gemma gets this dispatch command.

Tied Embeddings

const char *logit_name = arch.tie_embeddings
    ? "token_embedding.weight"
    : "output.weight";
Buffer logit_w = weights.get_tensor(logit_name);

A single ternary replaces a multi-way branch.

Fused QK-Norm + RoPE

The table builder checks whether the fused kernel is applicable:

bool use_fused_norm_rope = arch.has_qk_norm && arch.rope_kernel &&
    strcmp(arch.rope_kernel, "rope_neox_qkv_write_f16") == 0;
// (the null check guards architectures with no RoPE kernel at all)

This is true for Qwen3 and Gemma (NeoX RoPE + QK norms), false for everything else. When true, a single fused kernel replaces 3 separate dispatches.

The DTypeDescriptor: Kernel Registry

While ArchDescriptor captures architecture-level variation, DTypeDescriptor captures dtype-level variation. It is defined in src/core/dtype_descriptor.h:

struct DTypeDescriptor {
    uint32_t dtype;
    const char *gemv_kernel;
    const char *gemv_large_kernel;
    const char *gemv_wide_kernel;
    const char *gemm_kernel;
    const char *gemm_small_kernel;
    const char *embedding_kernel;
    const char *fused_silu_kernel;
    const char *fused_silu_large_kernel;
    int gemv_rows_per_tg;
    int gemv_tg_size;
    int large_rows_per_tg;
    int large_tg_size;
    int wide_rows_per_tg;
    int wide_tg_size;
    int fused_silu_rows;
    int fused_silu_tg;
    int fused_silu_large_rows;
    int fused_silu_large_tg;
    bool is_mlx;
};

This is a flat POD struct; the only pointers are static kernel-name strings, so there is nothing to allocate or own. Each supported GGUF dtype code gets an entry in a static lookup table:

static const DTypeDescriptor kDTypes[] = {
    // dtype  gemv           gemv_large     gemv_wide         gemm              gemm_small           embed                  fused_silu          fused_silu_large  rows tg  lg_r lg_tg w_r w_tg fs_r fs_tg fsl_r fsl_tg mlx
    {0,  "gemv_f16",         nullptr,       "gemv_wide_f16",  "simd_gemm_f16",  "simd_gemm_small_f16", nullptr,             nullptr,            nullptr,           16, 128, 0,   0,   64, 256, 0,   0,    0,    0,    false},
    {2,  "gemv_q4_0",        nullptr,       "gemv_wide_q4_0", "simd_gemm_q4_0", "simd_gemm_small_q4_0","embedding_lookup_q4_0","gemv_q4_0_silu", nullptr,          16, 128, 0,   0,   64, 256, 16,  128,  0,    0,    false},
    {8,  "gemv_q8_0",        nullptr,       "gemv_wide_q8_0", "simd_gemm_q8_0", "simd_gemm_small_q8_0","embedding_lookup_q8_0", nullptr,         nullptr,          32, 256, 0,   0,   64, 256, 0,   0,    0,    0,    false},
    {12, "gemv_q4_k",        nullptr,       "gemv_wide_q4_k", "simd_gemm_q4_k", "simd_gemm_small_q4_k","embedding_lookup_q4_k", nullptr,         nullptr,          16, 256, 0,   0,   32, 256, 0,   0,    0,    0,    false},
    {100,"gemv_mlx_q4",      nullptr,       "gemv_wide_mlx_q4","simd_gemm_mlx_q4","simd_gemm_small_mlx_q4","embedding_lookup_mlx_q4","gemv_mlx_q4_silu",nullptr,   16, 128, 0,   0,   32, 256, 16,  128,  0,    0,    true},
    // ... more entries ...
};

The table currently has 17 entries covering F16, BF16, Q4_0, Q4_1, Q5_0, Q8_0, Q2_K through Q6_K, and MLX Q3/Q4/Q6/Q8 formats.
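The lookup over this table can be a plain linear scan. The sketch below is an assumption about what `device.dtype_lookup` does (the source is not shown here), using a trimmed field set with `Lite` suffixes marking the stand-ins:

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Trimmed stand-in for DTypeDescriptor: just enough fields for the demo.
struct DTypeDescLite {
    uint32_t dtype;
    const char *gemv_kernel;
    bool is_mlx;
};

static const DTypeDescLite kDTypesLite[] = {
    {0,   "gemv_f16",    false},
    {2,   "gemv_q4_0",   false},
    {100, "gemv_mlx_q4", true},
};

// A linear scan is plenty: the real table has ~17 entries, and the result
// is consumed at table-build time, not in the per-token hot loop.
inline const DTypeDescLite *dtype_lookup_lite(uint32_t dtype) {
    for (const auto &d : kDTypesLite)
        if (d.dtype == dtype) return &d;
    return nullptr;  // unsupported dtype: caller reports an error
}
```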

How the Table Builder Uses DTypeDescriptor

The emit_gemv helper in table_builder.cpp reads from the descriptor to select the right kernel and dispatch geometry:

static void emit_gemv(DispatchTable& table, Device& device,
                      const ChipConfig& chip,
                      Buffer input, Buffer weight, Buffer output,
                      int output_offset, uint32_t dtype, int N, int K, ...) {
    const auto& dt = device.dtype_lookup(dtype);

    // Select kernel variant
    bool use_wide = (N > chip.wide_gemv_threshold && dt.gemv_wide_kernel != nullptr);
    bool use_large = (!use_wide && dt.gemv_large_kernel != nullptr
                      && K >= chip.q4_small_k_threshold);

    const char *kernel_name;
    int rows, tg;
    if (use_wide) {
        kernel_name = dt.gemv_wide_kernel;
        rows = dt.wide_rows_per_tg;
        tg = dt.wide_tg_size;
    } else if (use_large) {
        kernel_name = dt.gemv_large_kernel;
        rows = dt.large_rows_per_tg;
        tg = dt.large_tg_size;
    } else {
        kernel_name = dt.gemv_kernel;
        rows = dt.gemv_rows_per_tg;
        tg = dt.gemv_tg_size;
    }

    Pipeline pso = device.get_pipeline(kernel_name);
    int n_groups = (N + rows - 1) / rows;
    // ... create and emit DispatchCmd ...
}

The table builder never mentions “Q4_0” or “Q8_0” by name. It looks up the descriptor for the weight’s dtype code and uses whatever kernel and geometry the table specifies. Supporting a new quantization format means adding one row to kDTypes[] and writing the corresponding Metal kernel.

The is_mlx Flag

MLX quantized formats require special handling: function constants for group_size and K, a different parameter struct layout (MLXParams vs GEMMParams), and different weight byte calculations. The is_mlx boolean flag in DTypeDescriptor controls this:

if (dt.is_mlx) {
    uint32_t fc_indices[] = {0, 1};
    uint32_t fc_values[] = {(uint32_t)quant_group_size, (uint32_t)K};
    pso = device.get_pipeline(kernel_name, cache_key, fc_indices, fc_values, 2);
} else {
    pso = device.get_pipeline(kernel_name);
}

This is the one architectural branch in emit_gemv that is not purely data-driven. It could be eliminated by adding a “pipeline creation strategy” function pointer to the descriptor, but the additional complexity is not justified for two variants.
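For completeness, that function-pointer alternative would look roughly like the sketch below (every name here is hypothetical, not Akunu's API). It works, but it trades one visible branch for an extra level of indirection:

```cpp
#include <cassert>
#include <string>

// Hypothetical "pipeline creation strategy" design: instead of branching on
// is_mlx, each descriptor carries a function pointer that knows how to build
// its own pipeline. Names are illustrative only.
struct PipelineArgs {
    const char *kernel;
    int group_size;
    int K;
};

using PipelineFn = std::string (*)(const PipelineArgs &);

// Plain formats: the kernel name alone identifies the pipeline.
std::string make_plain_pipeline(const PipelineArgs &a) {
    return a.kernel;
}

// MLX formats: group_size and K would be baked in as function constants;
// here we just fold them into the cache-key string to show the shape.
std::string make_mlx_pipeline(const PipelineArgs &a) {
    return std::string(a.kernel) + "/g" + std::to_string(a.group_size)
                                 + "/k" + std::to_string(a.K);
}

struct DTypeDescFn {
    const char *gemv_kernel;
    PipelineFn make_pipeline;  // would replace the is_mlx branch in emit_gemv
};
```

With this design, `emit_gemv` would call `dt.make_pipeline(...)` unconditionally, and each row of the dtype table would pick its strategy. Two variants do not justify the indirection, which is why the `is_mlx` flag wins.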

Design Principles

The ArchDescriptor and DTypeDescriptor embody several design principles:

1. Data Over Code

Instead of encoding architecture-specific behavior in if/else chains, encode it in data structures. The table builder reads the data; it does not interpret architecture names.

2. Flat Over Deep

Both descriptors are flat POD structs. No inheritance hierarchies, no virtual methods, no builder patterns. This makes them trivially copyable, inspectable in a debugger, and cacheline-friendly.

3. Explicit Over Implicit

Every architecture-specific behavior is a named field. When you read arch.has_qk_norm, you know exactly what it controls. There are no hidden side effects, no overridden methods with subtle behavior differences.

4. Default-Safe

The zero-initialization of ArchDescriptor (ArchDescriptor d = {}) produces a safe state: no activation (nullptr), no RoPE (nullptr), no norms (false), not encoder-decoder (false). Every factory function explicitly sets the fields it needs. This prevents “forgot to set X” bugs.
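This rests on C++ value-initialization semantics: `= {}` zeroes every member. A quick self-contained check, with field names mirroring a subset of ArchDescriptor:

```cpp
#include <cassert>

// Subset of ArchDescriptor, used to demonstrate the zero-init guarantee.
struct ArchDescZero {
    const char *activation_kernel;
    const char *rope_kernel;
    bool has_qk_norm;
    bool is_encoder_decoder;
    float embedding_scale;
};

// `= {}` value-initializes every member: pointers become nullptr, bools
// false, floats 0.0f -- the "feature off" state for each field.
bool zero_init_is_safe() {
    ArchDescZero d = {};
    return d.activation_kernel == nullptr
        && d.rope_kernel == nullptr
        && !d.has_qk_norm
        && !d.is_encoder_decoder
        && d.embedding_scale == 0.0f;
}
```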

5. Derivation Without Inheritance

arch_qwen3() derives from arch_llama() by copying the struct and overriding fields. This is simpler and more explicit than C++ inheritance. You can see all the differences at a glance.

Adding a New Architecture

To add support for a hypothetical “Falcon” architecture:

  1. Write a factory function:
inline ArchDescriptor arch_falcon() {
    ArchDescriptor d = arch_llama();  // similar to Llama
    d.has_bias = true;               // Falcon has bias
    d.rope_kernel = "rope_neox_qkv_write_f16";  // NeoX RoPE
    d.rope_standalone = "rope_neox_f16";
    return d;
}
  2. Add a case to the lookup:
inline ArchDescriptor arch_from_config(const char *arch_name, int dim) {
    if (strstr(arch_name, "falcon")) return arch_falcon();
    // ... existing cases ...
}
  3. That’s it. No changes to table_builder.cpp, prefill.cpp, or any other core file (assuming the existing kernel set handles the architecture’s needs).
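The two-step recipe can be exercised end to end. This compressed sketch uses a trimmed field set with `_mini` suffixes marking the stand-ins, and keeps Falcon as the same hypothetical example:

```cpp
#include <cassert>
#include <cstring>

// Trimmed ArchDescriptor stand-in; only the fields Falcon changes.
struct ArchDescMini {
    bool has_bias = false;
    const char *rope_kernel = "rope_qkv_write_f16";
};

inline ArchDescMini arch_llama_mini() { return ArchDescMini{}; }

// Step 1: the new factory derives from the llama defaults.
inline ArchDescMini arch_falcon_mini() {
    ArchDescMini d = arch_llama_mini();
    d.has_bias = true;
    d.rope_kernel = "rope_neox_qkv_write_f16";
    return d;
}

// Step 2: one new case in the lookup; the core engine is untouched.
inline ArchDescMini arch_from_config_mini(const char *name) {
    if (strstr(name, "falcon")) return arch_falcon_mini();
    return arch_llama_mini();
}
```

A GGUF whose architecture string contains "falcon" now resolves to the new descriptor, and every other string still falls through to the llama defaults.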

Summary

The ArchDescriptor pattern replaces architecture-specific branching with data-driven dispatch. A flat POD struct captures all variation points (activation, RoPE, norms, embeddings, encoder/decoder). Factory functions create descriptors for each architecture, with incremental derivation from the Llama default. The DTypeDescriptor does the same for quantization formats, mapping dtype codes to kernel names and dispatch geometry. Together, these two descriptors make the table builder and prefill engine completely agnostic to both architecture and quantization format.



  1. Touvron, H., et al. (2023). “LLaMA: Open and Efficient Foundation Language Models.” arXiv:2302.13971. LLaMA’s architecture (RMSNorm, SwiGLU, RoPE, no bias) has become the de facto standard for open-weight LLMs. See https://arxiv.org/abs/2302.13971.

  2. Google DeepMind. (2024). “Gemma: Open Models Based on Gemini Research and Technology.” arXiv:2403.08295. Gemma scales embedding outputs by sqrt(dim), which is uncommon in other LLM architectures but follows the original Transformer convention. See https://arxiv.org/abs/2403.08295.

  3. Radford, A., et al. (2023). “Robust Speech Recognition via Large-Scale Weak Supervision.” Proceedings of ICML 2023. Whisper’s architecture uses sinusoidal positional encoding, LayerNorm, and cross-attention, following the original Transformer encoder-decoder design. See https://arxiv.org/abs/2212.04356.