1use std::collections::HashMap;
4
5use anyhow::Result;
6use sapient_core::Tensor;
7use sapient_hub::model_info::ModelInfo;
8
9use super::backend::{LlmBackend, LlmBackendDispatch, LlmBackendKind};
10use super::common::{embed_tokens, mean_pool_hidden, merge_heads, split_heads};
11use crate::weights::{
12 detect_weight_prefix, load_hf_weights, resolve_bias, resolve_lm_head, resolve_weight,
13 tie_word_embeddings_from_config,
14};
15
16#[derive(Debug, Default, Clone)]
17struct LayerCache {
18 keys: Option<Tensor>,
19 values: Option<Tensor>,
20 seq_len: usize,
21}
22
23pub struct PhiForward {
24 info: ModelInfo,
25 prefix: String,
26 weights: HashMap<String, Tensor>,
27 embed_key: String,
28 lm_head: Tensor,
29 cache: Vec<LayerCache>,
30 backend: LlmBackendDispatch,
31}
32
33impl PhiForward {
34 pub fn from_files(info: ModelInfo, weight_paths: &[std::path::PathBuf]) -> Result<Self> {
35 Self::from_files_with_backend(info, weight_paths, LlmBackendKind::Auto)
36 }
37
38 pub fn from_files_with_backend(
39 info: ModelInfo,
40 weight_paths: &[std::path::PathBuf],
41 backend: LlmBackendKind,
42 ) -> Result<Self> {
43 let weights = load_hf_weights(weight_paths)?;
44 Self::from_weights_with_backend(info, weights, backend)
45 }
46
47 pub fn from_weights(info: ModelInfo, weights: HashMap<String, Tensor>) -> Result<Self> {
48 Self::from_weights_with_backend(info, weights, LlmBackendKind::Auto)
49 }
50
51 pub fn from_weights_with_backend(
52 info: ModelInfo,
53 weights: HashMap<String, Tensor>,
54 backend: LlmBackendKind,
55 ) -> Result<Self> {
56 let prefix = detect_weight_prefix(&weights);
57 let embed_key = format!("{prefix}embed_tokens.weight");
58 let tie = tie_word_embeddings_from_config(&info.raw);
59 let lm_head = resolve_lm_head(&weights, &prefix, tie, &embed_key)?.clone();
60 validate_core_shapes(&info, &weights, &embed_key, &lm_head)?;
61 let backend = LlmBackendDispatch::from_kind(backend)?;
62 tracing::debug!(backend = backend.name(), "initialized Phi forward backend");
63
64 let max_seq = info.max_position_embeddings;
65 let n_kv = info.num_key_value_heads;
66 let hd = info.head_dim;
67 let cache_shape = vec![1, n_kv, max_seq, hd];
68
69 let use_q8_cache = hd % 32 == 0;
72
73 let cache = (0..info.num_hidden_layers)
74 .map(|_| {
75 let (keys, values) = if use_q8_cache {
76 let numel = n_kv * max_seq * hd;
77 let kv_bytes = numel / 32 * 34;
78 let k = Tensor::from_quant_bytes(
79 &vec![0u8; kv_bytes],
80 cache_shape.clone(),
81 sapient_core::DType::Q8_0,
82 )
83 .unwrap();
84 let v = Tensor::from_quant_bytes(
85 &vec![0u8; kv_bytes],
86 cache_shape.clone(),
87 sapient_core::DType::Q8_0,
88 )
89 .unwrap();
90 (k, v)
91 } else {
92 let k = Tensor::zeros(cache_shape.clone(), sapient_core::DType::F32).unwrap();
93 let v = Tensor::zeros(cache_shape.clone(), sapient_core::DType::F32).unwrap();
94 (k, v)
95 };
96 LayerCache {
97 keys: Some(keys),
98 values: Some(values),
99 seq_len: 0,
100 }
101 })
102 .collect();
103
104 Ok(Self {
105 cache,
106 info,
107 prefix,
108 embed_key,
109 lm_head,
110 weights,
111 backend,
112 })
113 }
114
115 pub fn reset_cache(&mut self) {
116 for layer in &mut self.cache {
117 layer.seq_len = 0;
118 }
119 }
120
121 pub fn forward_logits(&mut self, input_ids: &[u32], use_cache: bool) -> Result<Vec<f32>> {
122 let hidden = self.forward_hidden(input_ids, use_cache)?;
123 let mut logits = self.backend.logits_from_hidden(&hidden, &self.lm_head)?;
124 if let Some(bias) = resolve_bias(&self.weights, &self.prefix, "lm_head") {
126 let bias_cow = bias.to_f32_cow();
127 for (l, b) in logits.iter_mut().zip(bias_cow.iter()) {
128 *l += *b;
129 }
130 }
131 Ok(logits)
132 }
133
134 pub fn forward_all_logits(&mut self, input_ids: &[u32]) -> Result<Vec<Vec<f32>>> {
136 let hidden = self.forward_hidden(input_ids, false)?;
137 let mut all = self.backend.all_logits_from_hidden(&hidden, &self.lm_head)?;
138 if let Some(bias) = resolve_bias(&self.weights, &self.prefix, "lm_head") {
140 let bias_cow = bias.to_f32_cow();
141 for logits in &mut all {
142 for (l, b) in logits.iter_mut().zip(bias_cow.iter()) {
143 *l += *b;
144 }
145 }
146 }
147 Ok(all)
148 }
149
150 pub fn embed(&mut self, input_ids: &[u32]) -> Result<Vec<f32>> {
151 self.reset_cache();
152 let hidden = self.forward_hidden(input_ids, false)?;
153 mean_pool_hidden(&hidden)
154 }
155
156 fn forward_hidden(&mut self, input_ids: &[u32], use_cache: bool) -> Result<Tensor> {
157 let embed = self
158 .weights
159 .get(&self.embed_key)
160 .ok_or_else(|| anyhow::anyhow!("missing embedding weights at '{}'", self.embed_key))?;
161 let mut x = embed_tokens(embed, input_ids)?;
162
163 let start_pos = if use_cache {
164 self.cache.first().map(|l| l.seq_len).unwrap_or(0)
165 } else {
166 self.reset_cache();
167 0
168 };
169 let seq_len = input_ids.len();
170 let positions: Vec<usize> = (start_pos..start_pos + seq_len).collect();
171
172 for layer_idx in 0..self.info.num_hidden_layers {
173 x = self.forward_layer(x, layer_idx, &positions, use_cache)?;
174 }
175
176 let (norm_w, norm_b) = match resolve_weight(&self.weights, &self.prefix, "final_layernorm")
178 {
179 Ok(w) => (
180 w,
181 resolve_bias(&self.weights, &self.prefix, "final_layernorm"),
182 ),
183 Err(_) => (
184 resolve_weight(&self.weights, &self.prefix, "norm")?,
185 resolve_bias(&self.weights, &self.prefix, "norm"),
186 ),
187 };
188 self.backend
189 .layer_norm(&x, norm_w, norm_b, self.info.rms_norm_eps as f32)
190 }
191
192 fn forward_layer(
193 &mut self,
194 x: Tensor,
195 layer_idx: usize,
196 positions: &[usize],
197 use_cache: bool,
198 ) -> Result<Tensor> {
199 let pfx = format!("layers.{layer_idx}");
200 let eps = self.info.rms_norm_eps as f32;
201 let n_heads = self.info.num_attention_heads;
202 let head_dim = self.info.head_dim;
203
204 let rotary_dim = ((self.info.partial_rotary_factor * head_dim as f64).round() as usize)
206 .clamp(2, head_dim);
207 let theta = self.info.rope_theta as f32;
208
209 let in_ln = format!("{pfx}.input_layernorm");
211 let norm_w = resolve_weight(&self.weights, &self.prefix, &in_ln)?;
212 let norm_b = resolve_bias(&self.weights, &self.prefix, &in_ln);
213 let h = self.backend.layer_norm(&x, norm_w, norm_b, eps)?;
214
215 let q = self.linear_with_bias(&h, &format!("{pfx}.self_attn.q_proj"), None)?;
217 let k = self.linear_with_bias(&h, &format!("{pfx}.self_attn.k_proj"), None)?;
218 let v = self.linear_with_bias(&h, &format!("{pfx}.self_attn.v_proj"), None)?;
219
220 let q = split_heads(&q, n_heads, head_dim)?;
221 let k = split_heads(&k, n_heads, head_dim)?;
222 let mut v = split_heads(&v, n_heads, head_dim)?;
223
224 let q = self
225 .backend
226 .apply_rope_partial(&q, positions, theta, rotary_dim)?;
227 let mut k = self
228 .backend
229 .apply_rope_partial(&k, positions, theta, rotary_dim)?;
230
231 if use_cache {
232 let current_seq = self.cache[layer_idx].seq_len;
233 let cache = &mut self.cache[layer_idx];
234 if let (Some(ck), Some(cv)) = (&mut cache.keys, &mut cache.values) {
235 k = crate::forward::common::update_kv_cache(ck, current_seq, &k)?;
236 v = crate::forward::common::update_kv_cache(cv, current_seq, &v)?;
237 }
238 cache.seq_len = (current_seq + positions.len()).min(self.info.max_position_embeddings);
239 }
240
241 let attn = self.backend.gqa_attention(&q, &k, &v, n_heads, true)?;
242 let attn = merge_heads(&attn)?;
243 let o = self.linear_with_bias(
245 &attn,
246 &format!("{pfx}.self_attn.dense"),
247 Some(&format!("{pfx}.self_attn.o_proj")),
248 )?;
249
250 if self.info.model_type == "phi" {
253 let ff = self.mlp_phi2(&h, &pfx)?;
254 let parallel_res = self.backend.add(&o, &ff)?;
255 self.backend.add(&x, ¶llel_res)
256 } else {
257 let x = self.backend.add(&x, &o)?;
259 let post_ln = format!("{pfx}.post_attention_layernorm");
260 let pn_w = resolve_weight(&self.weights, &self.prefix, &post_ln)?;
261 let pn_b = resolve_bias(&self.weights, &self.prefix, &post_ln);
262 let hn = self.backend.layer_norm(&x, pn_w, pn_b, eps)?;
263 let ff = self.mlp_phi3(&hn, &pfx)?;
264 self.backend.add(&x, &ff)
265 }
266 }
267
268 fn linear_with_bias(&self, x: &Tensor, name: &str, alt: Option<&str>) -> Result<Tensor> {
271 let (weight, bias) = match resolve_weight(&self.weights, &self.prefix, name) {
272 Ok(w) => (w, resolve_bias(&self.weights, &self.prefix, name)),
273 Err(e) => match alt {
274 Some(a) => (
275 resolve_weight(&self.weights, &self.prefix, a)?,
276 resolve_bias(&self.weights, &self.prefix, a),
277 ),
278 None => return Err(e),
279 },
280 };
281 self.backend.linear_3d_bias(x, weight, bias)
282 }
283
284 fn mlp_phi2(&self, h: &Tensor, pfx: &str) -> Result<Tensor> {
286 let ff1 = self.linear_with_bias(h, &format!("{pfx}.mlp.fc1"), None)?;
287 let ff1 = self.backend.gelu(&ff1)?;
288 self.linear_with_bias(&ff1, &format!("{pfx}.mlp.fc2"), None)
289 }
290
291 fn mlp_phi3(&self, h: &Tensor, pfx: &str) -> Result<Tensor> {
293 let gate_up = self.linear_with_bias(h, &format!("{pfx}.mlp.gate_up_proj"), None)?;
294 let dims = gate_up.shape().dims().to_vec();
298 let last = *dims.last().unwrap();
299 let inter = last / 2;
300 let rows: usize = dims[..dims.len() - 1].iter().product();
301 let src = gate_up.to_f32_cow();
302 let mut gate_v = vec![0.0f32; rows * inter];
303 let mut up_v = vec![0.0f32; rows * inter];
304 for r in 0..rows {
305 let base = r * last;
306 gate_v[r * inter..(r + 1) * inter].copy_from_slice(&src[base..base + inter]);
307 up_v[r * inter..(r + 1) * inter].copy_from_slice(&src[base + inter..base + last]);
308 }
309 let mut half_dims = dims.clone();
310 *half_dims.last_mut().unwrap() = inter;
311 let gate = Tensor::from_f32(&gate_v, sapient_core::Shape::new(half_dims.clone()))
312 .map_err(|e| anyhow::anyhow!("{e}"))?;
313 let up = Tensor::from_f32(&up_v, sapient_core::Shape::new(half_dims))
314 .map_err(|e| anyhow::anyhow!("{e}"))?;
315 let gate = self.backend.silu(&gate)?;
316 let activated = self.backend.mul(&gate, &up)?;
317 self.linear_with_bias(&activated, &format!("{pfx}.mlp.down_proj"), None)
318 }
319}
320
321fn validate_core_shapes(
322 info: &ModelInfo,
323 weights: &HashMap<String, Tensor>,
324 embed_key: &str,
325 lm_head: &Tensor,
326) -> Result<()> {
327 let embed = weights
328 .get(embed_key)
329 .ok_or_else(|| anyhow::anyhow!("missing embedding weights at '{embed_key}'"))?;
330 let embed_dims = embed.shape().dims();
331 if embed_dims.len() != 2 || embed_dims[1] != info.hidden_size {
332 anyhow::bail!(
333 "embedding shape mismatch at '{embed_key}': expected [vocab, {}], got {:?}",
334 info.hidden_size,
335 embed_dims
336 );
337 }
338 if embed_dims[0] < info.vocab_size {
339 anyhow::bail!(
340 "embedding vocab rows {} are smaller than config vocab_size {}",
341 embed_dims[0],
342 info.vocab_size
343 );
344 }
345
346 let head_dims = lm_head.shape().dims();
347 if head_dims.len() != 2 || head_dims[1] != info.hidden_size {
348 anyhow::bail!(
349 "lm_head shape mismatch: expected [vocab, {}], got {:?}",
350 info.hidden_size,
351 head_dims
352 );
353 }
354
355 Ok(())
356}