Skip to main content

sapient_models/forward/
llama.rs

1//! Llama-family causal LM forward pass (Llama, Mistral, Qwen, SmolVLM text backbone).
2
3use std::collections::HashMap;
4
5use anyhow::Result;
6use sapient_core::Tensor;
7use sapient_hub::model_info::ModelInfo;
8
9use super::backend::{LlmBackend, LlmBackendDispatch, LlmBackendKind};
10use super::common::{
11    embed_tokens, mean_pool_hidden, merge_heads, quantize_tensor_to_q8_0,
12    should_quantize_online, split_heads,
13};
14use crate::weights::{
15    detect_weight_prefix, load_hf_weights, resolve_bias, resolve_lm_head, resolve_weight,
16    tie_word_embeddings_from_config,
17};
18
19/// Per-layer KV cache stored as concatenated 4-D tensors.
20#[derive(Debug, Default, Clone)]
21struct LayerCache {
22    keys: Option<Tensor>,
23    values: Option<Tensor>,
24    seq_len: usize,
25}
26
27/// Real Llama-architecture forward engine backed by safetensors weights.
28pub struct LlamaForward {
29    info: ModelInfo,
30    prefix: String,
31    weights: HashMap<String, Tensor>,
32    embed_key: String,
33    lm_head: Tensor,
34    cache: Vec<LayerCache>,
35    backend: LlmBackendDispatch,
36}
37
38impl LlamaForward {
39    pub fn from_files(info: ModelInfo, weight_paths: &[std::path::PathBuf]) -> Result<Self> {
40        Self::from_files_with_backend(info, weight_paths, LlmBackendKind::Auto)
41    }
42
43    pub fn from_files_with_backend(
44        info: ModelInfo,
45        weight_paths: &[std::path::PathBuf],
46        backend: LlmBackendKind,
47    ) -> Result<Self> {
48        let weights = load_hf_weights(weight_paths)?;
49        Self::from_weights_with_backend(info, weights, backend)
50    }
51
52    pub fn from_weights(info: ModelInfo, weights: HashMap<String, Tensor>) -> Result<Self> {
53        Self::from_weights_with_backend(info, weights, LlmBackendKind::Auto)
54    }
55
56    pub fn from_weights_with_backend(
57        info: ModelInfo,
58        weights: HashMap<String, Tensor>,
59        backend: LlmBackendKind,
60    ) -> Result<Self> {
61        let prefix = detect_weight_prefix(&weights);
62
63        // Online quantization: convert F16/BF16 projection matrices to Q8_0 at
64        // load time.  This is strictly better than expanding to F32:
65        //   - F32 expansion: 2 bytes/weight (F16) -> 4 bytes/weight (F32) = 2x RAM
66        //   - Q8_0 quantization: 2 bytes/weight (F16) -> ~1.06 bytes/weight = half RAM
67        //   - Per-step bandwidth: Q8_0 kernel reads ~1 byte/weight vs 4 for F32
68        //   - Quality: Q8_0 is near-lossless (~0.01 PPL increase over F16)
69        // Norm weights, biases, and embeddings retain their original dtype since
70        // they are accessed differently (row gather, broadcast, etc.).
71        // For already-quantized (Q4_0/Q8_0/K-quant) models this is a no-op.
72        let weights: HashMap<String, Tensor> = weights
73            .into_iter()
74            .map(|(k, v)| {
75                if should_quantize_online(&k, &v) {
76                    (k, quantize_tensor_to_q8_0(v))
77                } else {
78                    (k, v)
79                }
80            })
81            .collect();
82        let embed_key = format!("{prefix}embed_tokens.weight");
83        let tie = tie_word_embeddings_from_config(&info.raw);
84        let lm_head = resolve_lm_head(&weights, &prefix, tie, &embed_key)?.clone();
85        validate_core_shapes(&info, &weights, &embed_key, &lm_head)?;
86        let backend = LlmBackendDispatch::from_kind(backend)?;
87        tracing::debug!(
88            backend = backend.name(),
89            "initialized Llama forward backend"
90        );
91
92        let max_seq = info.max_position_embeddings;
93        let n_kv = info.num_key_value_heads;
94        let hd = info.head_dim;
95        let cache_shape = vec![1, n_kv, max_seq, hd];
96
97        // Allocate KV cache as Q8_0 (4× smaller than F32) when head_dim is a multiple
98        // of 32 (the Q8_0 block size).  Fall back to F32 otherwise.
99        let use_q8_cache = hd % 32 == 0;
100
101        let cache = (0..info.num_hidden_layers)
102            .map(|_| {
103                let (keys, values) = if use_q8_cache {
104                    // Q8_0: numel/32 blocks × 34 bytes each.
105                    let numel = n_kv * max_seq * hd;
106                    let kv_bytes = numel / 32 * 34;
107                    let k = Tensor::from_quant_bytes(
108                        &vec![0u8; kv_bytes],
109                        cache_shape.clone(),
110                        sapient_core::DType::Q8_0,
111                    )
112                    .unwrap();
113                    let v = Tensor::from_quant_bytes(
114                        &vec![0u8; kv_bytes],
115                        cache_shape.clone(),
116                        sapient_core::DType::Q8_0,
117                    )
118                    .unwrap();
119                    (k, v)
120                } else {
121                    let k = Tensor::zeros(cache_shape.clone(), sapient_core::DType::F32).unwrap();
122                    let v = Tensor::zeros(cache_shape.clone(), sapient_core::DType::F32).unwrap();
123                    (k, v)
124                };
125                LayerCache {
126                    keys: Some(keys),
127                    values: Some(values),
128                    seq_len: 0,
129                }
130            })
131            .collect();
132
133        Ok(Self {
134            cache,
135            info,
136            prefix,
137            embed_key,
138            lm_head,
139            weights,
140            backend,
141        })
142    }
143
144    pub fn reset_cache(&mut self) {
145        for layer in &mut self.cache {
146            layer.seq_len = 0;
147        }
148    }
149
150    /// Run forward on token ids and return logits for the last token.
151    pub fn forward_logits(&mut self, input_ids: &[u32], use_cache: bool) -> Result<Vec<f32>> {
152        let hidden = self.forward_hidden(input_ids, use_cache)?;
153        self.backend.logits_from_hidden(&hidden, &self.lm_head)
154    }
155
156    /// Returns logits for ALL positions without updating the KV cache.
157    /// Used by speculative decoding to verify draft tokens in one shot.
158    pub fn forward_all_logits(&mut self, input_ids: &[u32]) -> Result<Vec<Vec<f32>>> {
159        let hidden = self.forward_hidden(input_ids, false)?;
160        self.backend.all_logits_from_hidden(&hidden, &self.lm_head)
161    }
162
163    /// Mean-pooled hidden states for embedding models.
164    pub fn embed(&mut self, input_ids: &[u32]) -> Result<Vec<f32>> {
165        self.reset_cache();
166        let hidden = self.forward_hidden(input_ids, false)?;
167        mean_pool_hidden(&hidden)
168    }
169
170    fn forward_hidden(&mut self, input_ids: &[u32], use_cache: bool) -> Result<Tensor> {
171        let embed = self
172            .weights
173            .get(&self.embed_key)
174            .ok_or_else(|| anyhow::anyhow!("missing embedding weights at '{}'", self.embed_key))?;
175        let mut x = embed_tokens(embed, input_ids)?;
176
177        let start_pos = if use_cache {
178            self.cache.first().map(|l| l.seq_len).unwrap_or(0)
179        } else {
180            self.reset_cache();
181            0
182        };
183
184        let seq_len = input_ids.len();
185        let positions: Vec<usize> = (start_pos..start_pos + seq_len).collect();
186
187        for layer_idx in 0..self.info.num_hidden_layers {
188            x = self.forward_layer(x, layer_idx, &positions, use_cache)?;
189        }
190
191        let norm_w = resolve_weight(&self.weights, &self.prefix, "norm")?;
192        self.backend
193            .rms_norm(&x, norm_w, self.info.rms_norm_eps as f32)
194    }
195
196    fn forward_layer(
197        &mut self,
198        x: Tensor,
199        layer_idx: usize,
200        positions: &[usize],
201        use_cache: bool,
202    ) -> Result<Tensor> {
203        let pfx = format!("layers.{layer_idx}");
204        let eps = self.info.rms_norm_eps as f32;
205        let n_heads = self.info.num_attention_heads;
206        let n_kv = self.info.num_key_value_heads;
207        let head_dim = self.info.head_dim;
208
209        let attn_norm_w = resolve_weight(
210            &self.weights,
211            &self.prefix,
212            &format!("{pfx}.input_layernorm"),
213        )?;
214        let h = self.backend.rms_norm(&x, attn_norm_w, eps)?;
215
216        // Q/K/V projections — parallel on CPU (rayon thread pool, thread-safe).
217        // Sequential on Metal/GPU: Metal command buffers do not support concurrent
218        // encoding from multiple threads — parallel join causes assertion failures.
219        let q_name = format!("{pfx}.self_attn.q_proj");
220        let k_name = format!("{pfx}.self_attn.k_proj");
221        let v_name = format!("{pfx}.self_attn.v_proj");
222        let (q, k, v) = if self.backend.is_cpu() {
223            let ((q_res, k_res), v_res) = rayon::join(
224                || rayon::join(
225                    || self.linear(&h, &q_name),
226                    || self.linear(&h, &k_name),
227                ),
228                || self.linear(&h, &v_name),
229            );
230            (q_res?, k_res?, v_res?)
231        } else {
232            let q = self.linear(&h, &q_name)?;
233            let k = self.linear(&h, &k_name)?;
234            let v = self.linear(&h, &v_name)?;
235            (q, k, v)
236        };
237
238        let mut q = split_heads(&q, n_heads, head_dim)?;
239        let mut k = split_heads(&k, n_kv, head_dim)?;
240        let mut v = split_heads(&v, n_kv, head_dim)?;
241
242        q = self
243            .backend
244            .apply_rope_positions(&q, positions, self.info.rope_theta as f32)?;
245        k = self
246            .backend
247            .apply_rope_positions(&k, positions, self.info.rope_theta as f32)?;
248
249        let cache = &mut self.cache[layer_idx];
250        if use_cache {
251            let current_seq = cache.seq_len;
252            if let (Some(ck), Some(cv)) = (&mut cache.keys, &mut cache.values) {
253                k = crate::forward::common::update_kv_cache(ck, current_seq, &k)?;
254                v = crate::forward::common::update_kv_cache(cv, current_seq, &v)?;
255            }
256            cache.seq_len = current_seq + positions.len();
257        }
258
259        let attn = self.backend.gqa_attention(&q, &k, &v, n_kv, true)?;
260        let attn = merge_heads(&attn)?;
261        let o = self.linear(&attn, &format!("{pfx}.self_attn.o_proj"))?;
262        let x = self.backend.add(&x, &o)?;
263
264        let ffn_norm_w = resolve_weight(
265            &self.weights,
266            &self.prefix,
267            &format!("{pfx}.post_attention_layernorm"),
268        )?;
269        let h = self.backend.rms_norm(&x, ffn_norm_w, eps)?;
270
271        // Gate and up projections — parallel on CPU, sequential on Metal.
272        let gate_w = resolve_weight(&self.weights, &self.prefix, &format!("{pfx}.mlp.gate_proj"))?;
273        let up_w = resolve_weight(&self.weights, &self.prefix, &format!("{pfx}.mlp.up_proj"))?;
274        let (gate, up) = if self.backend.is_cpu() {
275            let backend = &self.backend;
276            let (gr, ur) = rayon::join(
277                || backend.linear_3d(&h, gate_w),
278                || backend.linear_3d(&h, up_w),
279            );
280            (gr?, ur?)
281        } else {
282            (
283                self.backend.linear_3d(&h, gate_w)?,
284                self.backend.linear_3d(&h, up_w)?,
285            )
286        };
287        let gate = self.backend.silu(&gate)?;
288        let mid = self.backend.mul(&gate, &up)?;
289        let down = self.backend.linear_3d(
290            &mid,
291            resolve_weight(&self.weights, &self.prefix, &format!("{pfx}.mlp.down_proj"))?,
292        )?;
293        self.backend.add(&x, &down)
294    }
295
296    /// Linear projection that automatically applies a bias when the model has one
297    /// (Qwen2 q/k/v), and is a plain matmul otherwise (Llama, Mistral).
298    fn linear(&self, x: &Tensor, name: &str) -> Result<Tensor> {
299        let weight = resolve_weight(&self.weights, &self.prefix, name)?;
300        let bias = resolve_bias(&self.weights, &self.prefix, name);
301        self.backend.linear_3d_bias(x, weight, bias)
302    }
303}
304
305fn validate_core_shapes(
306    info: &ModelInfo,
307    weights: &HashMap<String, Tensor>,
308    embed_key: &str,
309    lm_head: &Tensor,
310) -> Result<()> {
311    let embed = weights
312        .get(embed_key)
313        .ok_or_else(|| anyhow::anyhow!("missing embedding weights at '{embed_key}'"))?;
314    let embed_dims = embed.shape().dims();
315    if embed_dims.len() != 2 || embed_dims[1] != info.hidden_size {
316        anyhow::bail!(
317            "embedding shape mismatch at '{embed_key}': expected [vocab, {}], got {:?}",
318            info.hidden_size,
319            embed_dims
320        );
321    }
322    if embed_dims[0] < info.vocab_size {
323        anyhow::bail!(
324            "embedding vocab rows {} are smaller than config vocab_size {}",
325            embed_dims[0],
326            info.vocab_size
327        );
328    }
329
330    let head_dims = lm_head.shape().dims();
331    if head_dims.len() != 2 || head_dims[1] != info.hidden_size {
332        anyhow::bail!(
333            "lm_head shape mismatch: expected [vocab, {}], got {:?}",
334            info.hidden_size,
335            head_dims
336        );
337    }
338
339    Ok(())
340}