Skip to main content

sapient_models/forward/
llama.rs

1//! Llama-family causal LM forward pass (Llama, Mistral, Qwen, SmolVLM text backbone).
2
3use std::collections::HashMap;
4
5use anyhow::Result;
6use sapient_core::Tensor;
7use sapient_hub::model_info::ModelInfo;
8
9use super::backend::{LlmBackend, LlmBackendDispatch, LlmBackendKind};
10use super::common::{embed_tokens, mean_pool_hidden, merge_heads, split_heads};
11use crate::weights::{
12    detect_weight_prefix, load_hf_weights, resolve_bias, resolve_lm_head, resolve_weight,
13    tie_word_embeddings_from_config,
14};
15
16/// Per-layer KV cache stored as concatenated 4-D tensors.
17#[derive(Debug, Default, Clone)]
18struct LayerCache {
19    keys: Option<Tensor>,
20    values: Option<Tensor>,
21    seq_len: usize,
22}
23
24/// Real Llama-architecture forward engine backed by safetensors weights.
25pub struct LlamaForward {
26    info: ModelInfo,
27    prefix: String,
28    weights: HashMap<String, Tensor>,
29    embed_key: String,
30    lm_head: Tensor,
31    cache: Vec<LayerCache>,
32    backend: LlmBackendDispatch,
33}
34
35impl LlamaForward {
36    pub fn from_files(info: ModelInfo, weight_paths: &[std::path::PathBuf]) -> Result<Self> {
37        Self::from_files_with_backend(info, weight_paths, LlmBackendKind::Auto)
38    }
39
40    pub fn from_files_with_backend(
41        info: ModelInfo,
42        weight_paths: &[std::path::PathBuf],
43        backend: LlmBackendKind,
44    ) -> Result<Self> {
45        let weights = load_hf_weights(weight_paths)?;
46        Self::from_weights_with_backend(info, weights, backend)
47    }
48
49    pub fn from_weights(info: ModelInfo, weights: HashMap<String, Tensor>) -> Result<Self> {
50        Self::from_weights_with_backend(info, weights, LlmBackendKind::Auto)
51    }
52
53    pub fn from_weights_with_backend(
54        info: ModelInfo,
55        weights: HashMap<String, Tensor>,
56        backend: LlmBackendKind,
57    ) -> Result<Self> {
58        let prefix = detect_weight_prefix(&weights);
59        let embed_key = format!("{prefix}embed_tokens.weight");
60        let tie = tie_word_embeddings_from_config(&info.raw);
61        let lm_head = resolve_lm_head(&weights, &prefix, tie, &embed_key)?.clone();
62        validate_core_shapes(&info, &weights, &embed_key, &lm_head)?;
63        let backend = LlmBackendDispatch::from_kind(backend)?;
64        tracing::debug!(
65            backend = backend.name(),
66            "initialized Llama forward backend"
67        );
68
69        let max_seq = info.max_position_embeddings;
70        let n_kv = info.num_key_value_heads;
71        let hd = info.head_dim;
72        let cache_shape = vec![1, n_kv, max_seq, hd];
73
74        // Allocate KV cache as Q8_0 (4× smaller than F32) when head_dim is a multiple
75        // of 32 (the Q8_0 block size).  Fall back to F32 otherwise.
76        let use_q8_cache = hd % 32 == 0;
77
78        let cache = (0..info.num_hidden_layers)
79            .map(|_| {
80                let (keys, values) = if use_q8_cache {
81                    // Q8_0: numel/32 blocks × 34 bytes each.
82                    let numel = n_kv * max_seq * hd;
83                    let kv_bytes = numel / 32 * 34;
84                    let k = Tensor::from_quant_bytes(
85                        &vec![0u8; kv_bytes],
86                        cache_shape.clone(),
87                        sapient_core::DType::Q8_0,
88                    )
89                    .unwrap();
90                    let v = Tensor::from_quant_bytes(
91                        &vec![0u8; kv_bytes],
92                        cache_shape.clone(),
93                        sapient_core::DType::Q8_0,
94                    )
95                    .unwrap();
96                    (k, v)
97                } else {
98                    let k = Tensor::zeros(cache_shape.clone(), sapient_core::DType::F32).unwrap();
99                    let v = Tensor::zeros(cache_shape.clone(), sapient_core::DType::F32).unwrap();
100                    (k, v)
101                };
102                LayerCache {
103                    keys: Some(keys),
104                    values: Some(values),
105                    seq_len: 0,
106                }
107            })
108            .collect();
109
110        Ok(Self {
111            cache,
112            info,
113            prefix,
114            embed_key,
115            lm_head,
116            weights,
117            backend,
118        })
119    }
120
121    pub fn reset_cache(&mut self) {
122        for layer in &mut self.cache {
123            layer.seq_len = 0;
124        }
125    }
126
127    /// Run forward on token ids and return logits for the last token.
128    pub fn forward_logits(&mut self, input_ids: &[u32], use_cache: bool) -> Result<Vec<f32>> {
129        let hidden = self.forward_hidden(input_ids, use_cache)?;
130        self.backend.logits_from_hidden(&hidden, &self.lm_head)
131    }
132
133    /// Returns logits for ALL positions without updating the KV cache.
134    /// Used by speculative decoding to verify draft tokens in one shot.
135    pub fn forward_all_logits(&mut self, input_ids: &[u32]) -> Result<Vec<Vec<f32>>> {
136        let hidden = self.forward_hidden(input_ids, false)?;
137        self.backend.all_logits_from_hidden(&hidden, &self.lm_head)
138    }
139
140    /// Mean-pooled hidden states for embedding models.
141    pub fn embed(&mut self, input_ids: &[u32]) -> Result<Vec<f32>> {
142        self.reset_cache();
143        let hidden = self.forward_hidden(input_ids, false)?;
144        mean_pool_hidden(&hidden)
145    }
146
147    fn forward_hidden(&mut self, input_ids: &[u32], use_cache: bool) -> Result<Tensor> {
148        let embed = self
149            .weights
150            .get(&self.embed_key)
151            .ok_or_else(|| anyhow::anyhow!("missing embedding weights at '{}'", self.embed_key))?;
152        let mut x = embed_tokens(embed, input_ids)?;
153
154        let start_pos = if use_cache {
155            self.cache.first().map(|l| l.seq_len).unwrap_or(0)
156        } else {
157            self.reset_cache();
158            0
159        };
160
161        let seq_len = input_ids.len();
162        let positions: Vec<usize> = (start_pos..start_pos + seq_len).collect();
163
164        for layer_idx in 0..self.info.num_hidden_layers {
165            x = self.forward_layer(x, layer_idx, &positions, use_cache)?;
166        }
167
168        let norm_w = resolve_weight(&self.weights, &self.prefix, "norm")?;
169        self.backend
170            .rms_norm(&x, norm_w, self.info.rms_norm_eps as f32)
171    }
172
173    fn forward_layer(
174        &mut self,
175        x: Tensor,
176        layer_idx: usize,
177        positions: &[usize],
178        use_cache: bool,
179    ) -> Result<Tensor> {
180        let pfx = format!("layers.{layer_idx}");
181        let eps = self.info.rms_norm_eps as f32;
182        let n_heads = self.info.num_attention_heads;
183        let n_kv = self.info.num_key_value_heads;
184        let head_dim = self.info.head_dim;
185
186        let attn_norm_w = resolve_weight(
187            &self.weights,
188            &self.prefix,
189            &format!("{pfx}.input_layernorm"),
190        )?;
191        let h = self.backend.rms_norm(&x, attn_norm_w, eps)?;
192
193        // Q/K/V projections. Llama/Mistral have no bias; Qwen2 has q/k/v biases —
194        // resolve_bias returns None when absent, so this is correct for both.
195        let q = self.linear(&h, &format!("{pfx}.self_attn.q_proj"))?;
196        let k = self.linear(&h, &format!("{pfx}.self_attn.k_proj"))?;
197        let v = self.linear(&h, &format!("{pfx}.self_attn.v_proj"))?;
198
199        let mut q = split_heads(&q, n_heads, head_dim)?;
200        let mut k = split_heads(&k, n_kv, head_dim)?;
201        let mut v = split_heads(&v, n_kv, head_dim)?;
202
203        q = self
204            .backend
205            .apply_rope_positions(&q, positions, self.info.rope_theta as f32)?;
206        k = self
207            .backend
208            .apply_rope_positions(&k, positions, self.info.rope_theta as f32)?;
209
210        let cache = &mut self.cache[layer_idx];
211        if use_cache {
212            let current_seq = cache.seq_len;
213            if let (Some(ck), Some(cv)) = (&mut cache.keys, &mut cache.values) {
214                k = crate::forward::common::update_kv_cache(ck, current_seq, &k)?;
215                v = crate::forward::common::update_kv_cache(cv, current_seq, &v)?;
216            }
217            cache.seq_len = current_seq + positions.len();
218        }
219
220        let attn = self.backend.gqa_attention(&q, &k, &v, n_kv, true)?;
221        let attn = merge_heads(&attn)?;
222        let o = self.linear(&attn, &format!("{pfx}.self_attn.o_proj"))?;
223        let x = self.backend.add(&x, &o)?;
224
225        let ffn_norm_w = resolve_weight(
226            &self.weights,
227            &self.prefix,
228            &format!("{pfx}.post_attention_layernorm"),
229        )?;
230        let h = self.backend.rms_norm(&x, ffn_norm_w, eps)?;
231
232        let gate = self.backend.linear_3d(
233            &h,
234            resolve_weight(&self.weights, &self.prefix, &format!("{pfx}.mlp.gate_proj"))?,
235        )?;
236        let up = self.backend.linear_3d(
237            &h,
238            resolve_weight(&self.weights, &self.prefix, &format!("{pfx}.mlp.up_proj"))?,
239        )?;
240        let gate = self.backend.silu(&gate)?;
241        let mid = self.backend.mul(&gate, &up)?;
242        let down = self.backend.linear_3d(
243            &mid,
244            resolve_weight(&self.weights, &self.prefix, &format!("{pfx}.mlp.down_proj"))?,
245        )?;
246        self.backend.add(&x, &down)
247    }
248
249    /// Linear projection that automatically applies a bias when the model has one
250    /// (Qwen2 q/k/v), and is a plain matmul otherwise (Llama, Mistral).
251    fn linear(&self, x: &Tensor, name: &str) -> Result<Tensor> {
252        let weight = resolve_weight(&self.weights, &self.prefix, name)?;
253        let bias = resolve_bias(&self.weights, &self.prefix, name);
254        self.backend.linear_3d_bias(x, weight, bias)
255    }
256}
257
258fn validate_core_shapes(
259    info: &ModelInfo,
260    weights: &HashMap<String, Tensor>,
261    embed_key: &str,
262    lm_head: &Tensor,
263) -> Result<()> {
264    let embed = weights
265        .get(embed_key)
266        .ok_or_else(|| anyhow::anyhow!("missing embedding weights at '{embed_key}'"))?;
267    let embed_dims = embed.shape().dims();
268    if embed_dims.len() != 2 || embed_dims[1] != info.hidden_size {
269        anyhow::bail!(
270            "embedding shape mismatch at '{embed_key}': expected [vocab, {}], got {:?}",
271            info.hidden_size,
272            embed_dims
273        );
274    }
275    if embed_dims[0] < info.vocab_size {
276        anyhow::bail!(
277            "embedding vocab rows {} are smaller than config vocab_size {}",
278            embed_dims[0],
279            info.vocab_size
280        );
281    }
282
283    let head_dims = lm_head.shape().dims();
284    if head_dims.len() != 2 || head_dims[1] != info.hidden_size {
285        anyhow::bail!(
286            "lm_head shape mismatch: expected [vocab, {}], got {:?}",
287            info.hidden_size,
288            head_dims
289        );
290    }
291
292    Ok(())
293}