Skip to main content

sapient_models/forward/
phi.rs

1//! Phi-family causal LM forward pass.
2
3use std::collections::HashMap;
4
5use anyhow::Result;
6use sapient_core::Tensor;
7use sapient_hub::model_info::ModelInfo;
8
9use super::backend::{LlmBackend, LlmBackendDispatch, LlmBackendKind};
10use super::common::{
11    embed_tokens, mean_pool_hidden, merge_heads, quantize_tensor_to_q8_0,
12    should_quantize_online, split_heads,
13};
14use crate::weights::{
15    detect_weight_prefix, load_hf_weights, resolve_bias, resolve_lm_head, resolve_weight,
16    tie_word_embeddings_from_config,
17};
18
/// Key/value cache state for a single decoder layer.
#[derive(Debug, Default, Clone)]
struct LayerCache {
    /// Key buffer for this layer; `None` means no buffer is available and the
    /// layer skips the cache write.
    keys: Option<Tensor>,
    /// Value buffer for this layer; same convention as `keys`.
    values: Option<Tensor>,
    /// Number of positions currently treated as filled. Resetting the cache
    /// only zeroes this counter; the buffers themselves are left in place.
    seq_len: usize,
}
25
/// Forward-pass state for a Phi-family causal language model.
pub struct PhiForward {
    /// Model hyperparameters (layer count, head counts, head dim, eps, ...).
    info: ModelInfo,
    /// Weight-name prefix detected from the loaded weight map.
    prefix: String,
    /// All model weights keyed by name (projection matrices may have been
    /// quantized to Q8_0 at load time).
    weights: HashMap<String, Tensor>,
    /// Key of the token-embedding matrix inside `weights`.
    embed_key: String,
    /// Output projection used to turn hidden states into logits.
    lm_head: Tensor,
    /// One KV-cache entry per decoder layer.
    cache: Vec<LayerCache>,
    /// Compute backend that executes the tensor operations.
    backend: LlmBackendDispatch,
}
35
36impl PhiForward {
37    pub fn from_files(info: ModelInfo, weight_paths: &[std::path::PathBuf]) -> Result<Self> {
38        Self::from_files_with_backend(info, weight_paths, LlmBackendKind::Auto)
39    }
40
41    pub fn from_files_with_backend(
42        info: ModelInfo,
43        weight_paths: &[std::path::PathBuf],
44        backend: LlmBackendKind,
45    ) -> Result<Self> {
46        let weights = load_hf_weights(weight_paths)?;
47        Self::from_weights_with_backend(info, weights, backend)
48    }
49
50    pub fn from_weights(info: ModelInfo, weights: HashMap<String, Tensor>) -> Result<Self> {
51        Self::from_weights_with_backend(info, weights, LlmBackendKind::Auto)
52    }
53
54    pub fn from_weights_with_backend(
55        info: ModelInfo,
56        weights: HashMap<String, Tensor>,
57        backend: LlmBackendKind,
58    ) -> Result<Self> {
59        let prefix = detect_weight_prefix(&weights);
60
61        // Online quantization: convert F16/BF16 projection matrices to Q8_0 at
62        // load time.  Same rationale as LlamaForward: avoids per-step F16->F32
63        // conversion overhead while using less RAM than F32 expansion.
64        let weights: HashMap<String, Tensor> = weights
65            .into_iter()
66            .map(|(k, v)| {
67                if should_quantize_online(&k, &v) {
68                    (k, quantize_tensor_to_q8_0(v))
69                } else {
70                    (k, v)
71                }
72            })
73            .collect();
74        let embed_key = format!("{prefix}embed_tokens.weight");
75        let tie = tie_word_embeddings_from_config(&info.raw);
76        let lm_head = resolve_lm_head(&weights, &prefix, tie, &embed_key)?.clone();
77        validate_core_shapes(&info, &weights, &embed_key, &lm_head)?;
78        let backend = LlmBackendDispatch::from_kind(backend)?;
79        tracing::debug!(backend = backend.name(), "initialized Phi forward backend");
80
81        let max_seq = info.max_position_embeddings;
82        let n_kv = info.num_key_value_heads;
83        let hd = info.head_dim;
84        let cache_shape = vec![1, n_kv, max_seq, hd];
85
86        // Allocate KV cache as Q8_0 (4× smaller than F32) when head_dim is a multiple
87        // of 32 (the Q8_0 block size).  Fall back to F32 otherwise.
88        let use_q8_cache = hd % 32 == 0;
89
90        let cache = (0..info.num_hidden_layers)
91            .map(|_| {
92                let (keys, values) = if use_q8_cache {
93                    let numel = n_kv * max_seq * hd;
94                    let kv_bytes = numel / 32 * 34;
95                    let k = Tensor::from_quant_bytes(
96                        &vec![0u8; kv_bytes],
97                        cache_shape.clone(),
98                        sapient_core::DType::Q8_0,
99                    )
100                    .unwrap();
101                    let v = Tensor::from_quant_bytes(
102                        &vec![0u8; kv_bytes],
103                        cache_shape.clone(),
104                        sapient_core::DType::Q8_0,
105                    )
106                    .unwrap();
107                    (k, v)
108                } else {
109                    let k = Tensor::zeros(cache_shape.clone(), sapient_core::DType::F32).unwrap();
110                    let v = Tensor::zeros(cache_shape.clone(), sapient_core::DType::F32).unwrap();
111                    (k, v)
112                };
113                LayerCache {
114                    keys: Some(keys),
115                    values: Some(values),
116                    seq_len: 0,
117                }
118            })
119            .collect();
120
121        Ok(Self {
122            cache,
123            info,
124            prefix,
125            embed_key,
126            lm_head,
127            weights,
128            backend,
129        })
130    }
131
132    pub fn reset_cache(&mut self) {
133        for layer in &mut self.cache {
134            layer.seq_len = 0;
135        }
136    }
137
138    pub fn forward_logits(&mut self, input_ids: &[u32], use_cache: bool) -> Result<Vec<f32>> {
139        let hidden = self.forward_hidden(input_ids, use_cache)?;
140        let mut logits = self.backend.logits_from_hidden(&hidden, &self.lm_head)?;
141        // Phi's lm_head has a bias term; add it if present.
142        if let Some(bias) = resolve_bias(&self.weights, &self.prefix, "lm_head") {
143            let bias_cow = bias.to_f32_cow();
144            for (l, b) in logits.iter_mut().zip(bias_cow.iter()) {
145                *l += *b;
146            }
147        }
148        Ok(logits)
149    }
150
151    /// Returns logits for ALL positions without updating the KV cache.
152    pub fn forward_all_logits(&mut self, input_ids: &[u32]) -> Result<Vec<Vec<f32>>> {
153        let hidden = self.forward_hidden(input_ids, false)?;
154        let mut all = self.backend.all_logits_from_hidden(&hidden, &self.lm_head)?;
155        // Phi's lm_head has a bias term; add it to every position if present.
156        if let Some(bias) = resolve_bias(&self.weights, &self.prefix, "lm_head") {
157            let bias_cow = bias.to_f32_cow();
158            for logits in &mut all {
159                for (l, b) in logits.iter_mut().zip(bias_cow.iter()) {
160                    *l += *b;
161                }
162            }
163        }
164        Ok(all)
165    }
166
167    pub fn embed(&mut self, input_ids: &[u32]) -> Result<Vec<f32>> {
168        self.reset_cache();
169        let hidden = self.forward_hidden(input_ids, false)?;
170        mean_pool_hidden(&hidden)
171    }
172
173    fn forward_hidden(&mut self, input_ids: &[u32], use_cache: bool) -> Result<Tensor> {
174        let embed = self
175            .weights
176            .get(&self.embed_key)
177            .ok_or_else(|| anyhow::anyhow!("missing embedding weights at '{}'", self.embed_key))?;
178        let mut x = embed_tokens(embed, input_ids)?;
179
180        let start_pos = if use_cache {
181            self.cache.first().map(|l| l.seq_len).unwrap_or(0)
182        } else {
183            self.reset_cache();
184            0
185        };
186        let seq_len = input_ids.len();
187        let positions: Vec<usize> = (start_pos..start_pos + seq_len).collect();
188
189        for layer_idx in 0..self.info.num_hidden_layers {
190            x = self.forward_layer(x, layer_idx, &positions, use_cache)?;
191        }
192
193        // Phi names the final norm `final_layernorm`; fall back to `norm` for other variants.
194        let (norm_w, norm_b) = match resolve_weight(&self.weights, &self.prefix, "final_layernorm")
195        {
196            Ok(w) => (
197                w,
198                resolve_bias(&self.weights, &self.prefix, "final_layernorm"),
199            ),
200            Err(_) => (
201                resolve_weight(&self.weights, &self.prefix, "norm")?,
202                resolve_bias(&self.weights, &self.prefix, "norm"),
203            ),
204        };
205        self.backend
206            .layer_norm(&x, norm_w, norm_b, self.info.rms_norm_eps as f32)
207    }
208
209    fn forward_layer(
210        &mut self,
211        x: Tensor,
212        layer_idx: usize,
213        positions: &[usize],
214        use_cache: bool,
215    ) -> Result<Tensor> {
216        let pfx = format!("layers.{layer_idx}");
217        let eps = self.info.rms_norm_eps as f32;
218        let n_heads = self.info.num_attention_heads;
219        let head_dim = self.info.head_dim;
220
221        // RoPE is applied to only the first `rotary_dim` channels (Phi partial rotary).
222        let rotary_dim = ((self.info.partial_rotary_factor * head_dim as f64).round() as usize)
223            .clamp(2, head_dim);
224        let theta = self.info.rope_theta as f32;
225
226        // Input LayerNorm (Phi uses LayerNorm with a bias term).
227        let in_ln = format!("{pfx}.input_layernorm");
228        let norm_w = resolve_weight(&self.weights, &self.prefix, &in_ln)?;
229        let norm_b = resolve_bias(&self.weights, &self.prefix, &in_ln);
230        let h = self.backend.layer_norm(&x, norm_w, norm_b, eps)?;
231
232        // Q/K/V projections (Phi has bias on each).
233        let q = self.linear_with_bias(&h, &format!("{pfx}.self_attn.q_proj"), None)?;
234        let k = self.linear_with_bias(&h, &format!("{pfx}.self_attn.k_proj"), None)?;
235        let v = self.linear_with_bias(&h, &format!("{pfx}.self_attn.v_proj"), None)?;
236
237        let q = split_heads(&q, n_heads, head_dim)?;
238        let k = split_heads(&k, n_heads, head_dim)?;
239        let mut v = split_heads(&v, n_heads, head_dim)?;
240
241        let q = self
242            .backend
243            .apply_rope_partial(&q, positions, theta, rotary_dim)?;
244        let mut k = self
245            .backend
246            .apply_rope_partial(&k, positions, theta, rotary_dim)?;
247
248        if use_cache {
249            let current_seq = self.cache[layer_idx].seq_len;
250            let cache = &mut self.cache[layer_idx];
251            if let (Some(ck), Some(cv)) = (&mut cache.keys, &mut cache.values) {
252                k = crate::forward::common::update_kv_cache(ck, current_seq, &k)?;
253                v = crate::forward::common::update_kv_cache(cv, current_seq, &v)?;
254            }
255            cache.seq_len = (current_seq + positions.len()).min(self.info.max_position_embeddings);
256        }
257
258        let attn = self.backend.gqa_attention(&q, &k, &v, n_heads, true)?;
259        let attn = merge_heads(&attn)?;
260        // Attention output projection (Phi-2 calls it `dense`, Phi-3 `o_proj`).
261        let o = self.linear_with_bias(
262            &attn,
263            &format!("{pfx}.self_attn.dense"),
264            Some(&format!("{pfx}.self_attn.o_proj")),
265        )?;
266
267        // Phi-1/1.5/2 ("phi") use a parallel block: attention and MLP both read the
268        // same normalized input `h` and are summed onto the residual.
269        if self.info.model_type == "phi" {
270            let ff = self.mlp_phi2(&h, &pfx)?;
271            let parallel_res = self.backend.add(&o, &ff)?;
272            self.backend.add(&x, &parallel_res)
273        } else {
274            // Phi-3 sequential: residual add, post-attention LayerNorm, then MLP.
275            let x = self.backend.add(&x, &o)?;
276            let post_ln = format!("{pfx}.post_attention_layernorm");
277            let pn_w = resolve_weight(&self.weights, &self.prefix, &post_ln)?;
278            let pn_b = resolve_bias(&self.weights, &self.prefix, &post_ln);
279            let hn = self.backend.layer_norm(&x, pn_w, pn_b, eps)?;
280            let ff = self.mlp_phi3(&hn, &pfx)?;
281            self.backend.add(&x, &ff)
282        }
283    }
284
285    /// Linear projection with optional bias, resolving `name` (or `alt` fallback)
286    /// as the weight key and `<name>.bias` as the bias if present.
287    fn linear_with_bias(&self, x: &Tensor, name: &str, alt: Option<&str>) -> Result<Tensor> {
288        let (weight, bias) = match resolve_weight(&self.weights, &self.prefix, name) {
289            Ok(w) => (w, resolve_bias(&self.weights, &self.prefix, name)),
290            Err(e) => match alt {
291                Some(a) => (
292                    resolve_weight(&self.weights, &self.prefix, a)?,
293                    resolve_bias(&self.weights, &self.prefix, a),
294                ),
295                None => return Err(e),
296            },
297        };
298        self.backend.linear_3d_bias(x, weight, bias)
299    }
300
301    /// Phi-1/1.5/2 MLP: fc1 → gelu_new → fc2 (both with bias).
302    fn mlp_phi2(&self, h: &Tensor, pfx: &str) -> Result<Tensor> {
303        let ff1 = self.linear_with_bias(h, &format!("{pfx}.mlp.fc1"), None)?;
304        let ff1 = self.backend.gelu(&ff1)?;
305        self.linear_with_bias(&ff1, &format!("{pfx}.mlp.fc2"), None)
306    }
307
308    /// Phi-3 MLP: fused gate_up_proj → SwiGLU → down_proj.
309    fn mlp_phi3(&self, h: &Tensor, pfx: &str) -> Result<Tensor> {
310        let gate_up = self.linear_with_bias(h, &format!("{pfx}.mlp.gate_up_proj"), None)?;
311        // gate_up is [1, seq, 2*inter] contiguous; split the last dim into the
312        // gate and up halves. We copy into contiguous buffers rather than use a
313        // strided view, since the elementwise kernels read data contiguously.
314        let dims = gate_up.shape().dims().to_vec();
315        let last = *dims.last().unwrap();
316        let inter = last / 2;
317        let rows: usize = dims[..dims.len() - 1].iter().product();
318        let src = gate_up.to_f32_cow();
319        let mut gate_v = vec![0.0f32; rows * inter];
320        let mut up_v = vec![0.0f32; rows * inter];
321        for r in 0..rows {
322            let base = r * last;
323            gate_v[r * inter..(r + 1) * inter].copy_from_slice(&src[base..base + inter]);
324            up_v[r * inter..(r + 1) * inter].copy_from_slice(&src[base + inter..base + last]);
325        }
326        let mut half_dims = dims.clone();
327        *half_dims.last_mut().unwrap() = inter;
328        let gate = Tensor::from_f32(&gate_v, sapient_core::Shape::new(half_dims.clone()))
329            .map_err(|e| anyhow::anyhow!("{e}"))?;
330        let up = Tensor::from_f32(&up_v, sapient_core::Shape::new(half_dims))
331            .map_err(|e| anyhow::anyhow!("{e}"))?;
332        let gate = self.backend.silu(&gate)?;
333        let activated = self.backend.mul(&gate, &up)?;
334        self.linear_with_bias(&activated, &format!("{pfx}.mlp.down_proj"), None)
335    }
336}
337
338fn validate_core_shapes(
339    info: &ModelInfo,
340    weights: &HashMap<String, Tensor>,
341    embed_key: &str,
342    lm_head: &Tensor,
343) -> Result<()> {
344    let embed = weights
345        .get(embed_key)
346        .ok_or_else(|| anyhow::anyhow!("missing embedding weights at '{embed_key}'"))?;
347    let embed_dims = embed.shape().dims();
348    if embed_dims.len() != 2 || embed_dims[1] != info.hidden_size {
349        anyhow::bail!(
350            "embedding shape mismatch at '{embed_key}': expected [vocab, {}], got {:?}",
351            info.hidden_size,
352            embed_dims
353        );
354    }
355    if embed_dims[0] < info.vocab_size {
356        anyhow::bail!(
357            "embedding vocab rows {} are smaller than config vocab_size {}",
358            embed_dims[0],
359            info.vocab_size
360        );
361    }
362
363    let head_dims = lm_head.shape().dims();
364    if head_dims.len() != 2 || head_dims[1] != info.hidden_size {
365        anyhow::bail!(
366            "lm_head shape mismatch: expected [vocab, {}], got {:?}",
367            info.hidden_size,
368            head_dims
369        );
370    }
371
372    Ok(())
373}