Skip to main content

sapient_models/forward/
phi.rs

1//! Phi-family causal LM forward pass.
2
3use std::collections::HashMap;
4
5use anyhow::Result;
6use sapient_core::Tensor;
7use sapient_hub::model_info::ModelInfo;
8
9use super::backend::{LlmBackend, LlmBackendDispatch, LlmBackendKind};
10use super::common::{embed_tokens, mean_pool_hidden, merge_heads, split_heads};
11use crate::weights::{
12    detect_weight_prefix, load_hf_weights, resolve_bias, resolve_lm_head, resolve_weight,
13    tie_word_embeddings_from_config,
14};
15
16#[derive(Debug, Default, Clone)]
17struct LayerCache {
18    keys: Option<Tensor>,
19    values: Option<Tensor>,
20    seq_len: usize,
21}
22
23pub struct PhiForward {
24    info: ModelInfo,
25    prefix: String,
26    weights: HashMap<String, Tensor>,
27    embed_key: String,
28    lm_head: Tensor,
29    cache: Vec<LayerCache>,
30    backend: LlmBackendDispatch,
31}
32
33impl PhiForward {
34    pub fn from_files(info: ModelInfo, weight_paths: &[std::path::PathBuf]) -> Result<Self> {
35        Self::from_files_with_backend(info, weight_paths, LlmBackendKind::Auto)
36    }
37
38    pub fn from_files_with_backend(
39        info: ModelInfo,
40        weight_paths: &[std::path::PathBuf],
41        backend: LlmBackendKind,
42    ) -> Result<Self> {
43        let weights = load_hf_weights(weight_paths)?;
44        Self::from_weights_with_backend(info, weights, backend)
45    }
46
47    pub fn from_weights(info: ModelInfo, weights: HashMap<String, Tensor>) -> Result<Self> {
48        Self::from_weights_with_backend(info, weights, LlmBackendKind::Auto)
49    }
50
51    pub fn from_weights_with_backend(
52        info: ModelInfo,
53        weights: HashMap<String, Tensor>,
54        backend: LlmBackendKind,
55    ) -> Result<Self> {
56        let prefix = detect_weight_prefix(&weights);
57        let embed_key = format!("{prefix}embed_tokens.weight");
58        let tie = tie_word_embeddings_from_config(&info.raw);
59        let lm_head = resolve_lm_head(&weights, &prefix, tie, &embed_key)?.clone();
60        validate_core_shapes(&info, &weights, &embed_key, &lm_head)?;
61        let backend = LlmBackendDispatch::from_kind(backend)?;
62        tracing::debug!(backend = backend.name(), "initialized Phi forward backend");
63
64        let max_seq = info.max_position_embeddings;
65        let n_kv = info.num_key_value_heads;
66        let hd = info.head_dim;
67        let cache_shape = vec![1, n_kv, max_seq, hd];
68
69        let cache = (0..info.num_hidden_layers)
70            .map(|_| {
71                let keys = Tensor::zeros(cache_shape.clone(), sapient_core::DType::F32).unwrap();
72                let values = Tensor::zeros(cache_shape.clone(), sapient_core::DType::F32).unwrap();
73                LayerCache {
74                    keys: Some(keys),
75                    values: Some(values),
76                    seq_len: 0,
77                }
78            })
79            .collect();
80
81        Ok(Self {
82            cache,
83            info,
84            prefix,
85            embed_key,
86            lm_head,
87            weights,
88            backend,
89        })
90    }
91
92    pub fn reset_cache(&mut self) {
93        for layer in &mut self.cache {
94            layer.seq_len = 0;
95        }
96    }
97
98    pub fn forward_logits(&mut self, input_ids: &[u32], use_cache: bool) -> Result<Vec<f32>> {
99        let hidden = self.forward_hidden(input_ids, use_cache)?;
100        let mut logits = self.backend.logits_from_hidden(&hidden, &self.lm_head)?;
101        // Phi's lm_head has a bias term; add it if present.
102        if let Some(bias) = resolve_bias(&self.weights, &self.prefix, "lm_head") {
103            let bias_cow = bias.to_f32_cow();
104            for (l, b) in logits.iter_mut().zip(bias_cow.iter()) {
105                *l += *b;
106            }
107        }
108        Ok(logits)
109    }
110
111    pub fn embed(&mut self, input_ids: &[u32]) -> Result<Vec<f32>> {
112        self.reset_cache();
113        let hidden = self.forward_hidden(input_ids, false)?;
114        mean_pool_hidden(&hidden)
115    }
116
117    fn forward_hidden(&mut self, input_ids: &[u32], use_cache: bool) -> Result<Tensor> {
118        let embed = self
119            .weights
120            .get(&self.embed_key)
121            .ok_or_else(|| anyhow::anyhow!("missing embedding weights at '{}'", self.embed_key))?;
122        let mut x = embed_tokens(embed, input_ids)?;
123
124        let start_pos = if use_cache {
125            self.cache.first().map(|l| l.seq_len).unwrap_or(0)
126        } else {
127            self.reset_cache();
128            0
129        };
130        let seq_len = input_ids.len();
131        let positions: Vec<usize> = (start_pos..start_pos + seq_len).collect();
132
133        for layer_idx in 0..self.info.num_hidden_layers {
134            x = self.forward_layer(x, layer_idx, &positions, use_cache)?;
135        }
136
137        // Phi names the final norm `final_layernorm`; fall back to `norm` for other variants.
138        let (norm_w, norm_b) = match resolve_weight(&self.weights, &self.prefix, "final_layernorm")
139        {
140            Ok(w) => (
141                w,
142                resolve_bias(&self.weights, &self.prefix, "final_layernorm"),
143            ),
144            Err(_) => (
145                resolve_weight(&self.weights, &self.prefix, "norm")?,
146                resolve_bias(&self.weights, &self.prefix, "norm"),
147            ),
148        };
149        self.backend
150            .layer_norm(&x, norm_w, norm_b, self.info.rms_norm_eps as f32)
151    }
152
153    fn forward_layer(
154        &mut self,
155        x: Tensor,
156        layer_idx: usize,
157        positions: &[usize],
158        use_cache: bool,
159    ) -> Result<Tensor> {
160        let pfx = format!("layers.{layer_idx}");
161        let eps = self.info.rms_norm_eps as f32;
162        let n_heads = self.info.num_attention_heads;
163        let head_dim = self.info.head_dim;
164
165        // RoPE is applied to only the first `rotary_dim` channels (Phi partial rotary).
166        let rotary_dim = ((self.info.partial_rotary_factor * head_dim as f64).round() as usize)
167            .clamp(2, head_dim);
168        let theta = self.info.rope_theta as f32;
169
170        // Input LayerNorm (Phi uses LayerNorm with a bias term).
171        let in_ln = format!("{pfx}.input_layernorm");
172        let norm_w = resolve_weight(&self.weights, &self.prefix, &in_ln)?;
173        let norm_b = resolve_bias(&self.weights, &self.prefix, &in_ln);
174        let h = self.backend.layer_norm(&x, norm_w, norm_b, eps)?;
175
176        // Q/K/V projections (Phi has bias on each).
177        let q = self.linear_with_bias(&h, &format!("{pfx}.self_attn.q_proj"), None)?;
178        let k = self.linear_with_bias(&h, &format!("{pfx}.self_attn.k_proj"), None)?;
179        let v = self.linear_with_bias(&h, &format!("{pfx}.self_attn.v_proj"), None)?;
180
181        let q = split_heads(&q, n_heads, head_dim)?;
182        let k = split_heads(&k, n_heads, head_dim)?;
183        let mut v = split_heads(&v, n_heads, head_dim)?;
184
185        let q = self
186            .backend
187            .apply_rope_partial(&q, positions, theta, rotary_dim)?;
188        let mut k = self
189            .backend
190            .apply_rope_partial(&k, positions, theta, rotary_dim)?;
191
192        if use_cache {
193            let current_seq = self.cache[layer_idx].seq_len;
194            let cache = &mut self.cache[layer_idx];
195            if let (Some(ck), Some(cv)) = (&mut cache.keys, &mut cache.values) {
196                k = crate::forward::common::update_kv_cache(ck, current_seq, &k)?;
197                v = crate::forward::common::update_kv_cache(cv, current_seq, &v)?;
198            }
199            cache.seq_len = (current_seq + positions.len()).min(self.info.max_position_embeddings);
200        }
201
202        let attn = self.backend.gqa_attention(&q, &k, &v, n_heads, true)?;
203        let attn = merge_heads(&attn)?;
204        // Attention output projection (Phi-2 calls it `dense`, Phi-3 `o_proj`).
205        let o = self.linear_with_bias(
206            &attn,
207            &format!("{pfx}.self_attn.dense"),
208            Some(&format!("{pfx}.self_attn.o_proj")),
209        )?;
210
211        // Phi-1/1.5/2 ("phi") use a parallel block: attention and MLP both read the
212        // same normalized input `h` and are summed onto the residual.
213        if self.info.model_type == "phi" {
214            let ff = self.mlp_phi2(&h, &pfx)?;
215            let parallel_res = self.backend.add(&o, &ff)?;
216            self.backend.add(&x, &parallel_res)
217        } else {
218            // Phi-3 sequential: residual add, post-attention LayerNorm, then MLP.
219            let x = self.backend.add(&x, &o)?;
220            let post_ln = format!("{pfx}.post_attention_layernorm");
221            let pn_w = resolve_weight(&self.weights, &self.prefix, &post_ln)?;
222            let pn_b = resolve_bias(&self.weights, &self.prefix, &post_ln);
223            let hn = self.backend.layer_norm(&x, pn_w, pn_b, eps)?;
224            let ff = self.mlp_phi3(&hn, &pfx)?;
225            self.backend.add(&x, &ff)
226        }
227    }
228
229    /// Linear projection with optional bias, resolving `name` (or `alt` fallback)
230    /// as the weight key and `<name>.bias` as the bias if present.
231    fn linear_with_bias(&self, x: &Tensor, name: &str, alt: Option<&str>) -> Result<Tensor> {
232        let (weight, bias) = match resolve_weight(&self.weights, &self.prefix, name) {
233            Ok(w) => (w, resolve_bias(&self.weights, &self.prefix, name)),
234            Err(e) => match alt {
235                Some(a) => (
236                    resolve_weight(&self.weights, &self.prefix, a)?,
237                    resolve_bias(&self.weights, &self.prefix, a),
238                ),
239                None => return Err(e),
240            },
241        };
242        self.backend.linear_3d_bias(x, weight, bias)
243    }
244
245    /// Phi-1/1.5/2 MLP: fc1 → gelu_new → fc2 (both with bias).
246    fn mlp_phi2(&self, h: &Tensor, pfx: &str) -> Result<Tensor> {
247        let ff1 = self.linear_with_bias(h, &format!("{pfx}.mlp.fc1"), None)?;
248        let ff1 = self.backend.gelu(&ff1)?;
249        self.linear_with_bias(&ff1, &format!("{pfx}.mlp.fc2"), None)
250    }
251
252    /// Phi-3 MLP: fused gate_up_proj → SwiGLU → down_proj.
253    fn mlp_phi3(&self, h: &Tensor, pfx: &str) -> Result<Tensor> {
254        let gate_up = self.linear_with_bias(h, &format!("{pfx}.mlp.gate_up_proj"), None)?;
255        // gate_up is [1, seq, 2*inter] contiguous; split the last dim into the
256        // gate and up halves. We copy into contiguous buffers rather than use a
257        // strided view, since the elementwise kernels read data contiguously.
258        let dims = gate_up.shape().dims().to_vec();
259        let last = *dims.last().unwrap();
260        let inter = last / 2;
261        let rows: usize = dims[..dims.len() - 1].iter().product();
262        let src = gate_up.to_f32_cow();
263        let mut gate_v = vec![0.0f32; rows * inter];
264        let mut up_v = vec![0.0f32; rows * inter];
265        for r in 0..rows {
266            let base = r * last;
267            gate_v[r * inter..(r + 1) * inter].copy_from_slice(&src[base..base + inter]);
268            up_v[r * inter..(r + 1) * inter].copy_from_slice(&src[base + inter..base + last]);
269        }
270        let mut half_dims = dims.clone();
271        *half_dims.last_mut().unwrap() = inter;
272        let gate = Tensor::from_f32(&gate_v, sapient_core::Shape::new(half_dims.clone()))
273            .map_err(|e| anyhow::anyhow!("{e}"))?;
274        let up = Tensor::from_f32(&up_v, sapient_core::Shape::new(half_dims))
275            .map_err(|e| anyhow::anyhow!("{e}"))?;
276        let gate = self.backend.silu(&gate)?;
277        let activated = self.backend.mul(&gate, &up)?;
278        self.linear_with_bias(&activated, &format!("{pfx}.mlp.down_proj"), None)
279    }
280}
281
282fn validate_core_shapes(
283    info: &ModelInfo,
284    weights: &HashMap<String, Tensor>,
285    embed_key: &str,
286    lm_head: &Tensor,
287) -> Result<()> {
288    let embed = weights
289        .get(embed_key)
290        .ok_or_else(|| anyhow::anyhow!("missing embedding weights at '{embed_key}'"))?;
291    let embed_dims = embed.shape().dims();
292    if embed_dims.len() != 2 || embed_dims[1] != info.hidden_size {
293        anyhow::bail!(
294            "embedding shape mismatch at '{embed_key}': expected [vocab, {}], got {:?}",
295            info.hidden_size,
296            embed_dims
297        );
298    }
299    if embed_dims[0] < info.vocab_size {
300        anyhow::bail!(
301            "embedding vocab rows {} are smaller than config vocab_size {}",
302            embed_dims[0],
303            info.vocab_size
304        );
305    }
306
307    let head_dims = lm_head.shape().dims();
308    if head_dims.len() != 2 || head_dims[1] != info.hidden_size {
309        anyhow::bail!(
310            "lm_head shape mismatch: expected [vocab, {}], got {:?}",
311            info.hidden_size,
312            head_dims
313        );
314    }
315
316    Ok(())
317}