1use std::collections::HashMap;
4
5use anyhow::Result;
6use sapient_core::Tensor;
7use sapient_hub::model_info::ModelInfo;
8
9use super::backend::{LlmBackend, LlmBackendDispatch, LlmBackendKind};
10use super::common::{
11 embed_tokens, mean_pool_hidden, merge_heads, quantize_tensor_to_q8_0,
12 should_quantize_online, split_heads,
13};
14use crate::weights::{
15 detect_weight_prefix, load_hf_weights, resolve_bias, resolve_lm_head, resolve_weight,
16 tie_word_embeddings_from_config,
17};
18
19#[derive(Debug, Default, Clone)]
20struct LayerCache {
21 keys: Option<Tensor>,
22 values: Option<Tensor>,
23 seq_len: usize,
24}
25
/// Forward pass for Phi-family causal language models.
///
/// Owns the (optionally online-quantized) weights, the resolved LM head, a preallocated
/// per-layer KV cache and the compute backend used to run every tensor operation.
pub struct PhiForward {
    /// Model configuration (layer count, head sizes, RoPE and norm parameters).
    info: ModelInfo,
    /// Key prefix detected on the weight names; prepended when resolving weights.
    prefix: String,
    /// Weight tensors keyed by name; some may have been converted to Q8_0 at load time.
    weights: HashMap<String, Tensor>,
    /// Full key of the token-embedding matrix inside `weights`.
    embed_key: String,
    /// Output projection used to turn hidden states into logits
    /// (presumably a copy of the embedding when word embeddings are tied).
    lm_head: Tensor,
    /// One KV cache entry per transformer layer.
    cache: Vec<LayerCache>,
    /// Backend that executes the tensor operations.
    backend: LlmBackendDispatch,
}
35
36impl PhiForward {
37 pub fn from_files(info: ModelInfo, weight_paths: &[std::path::PathBuf]) -> Result<Self> {
38 Self::from_files_with_backend(info, weight_paths, LlmBackendKind::Auto)
39 }
40
41 pub fn from_files_with_backend(
42 info: ModelInfo,
43 weight_paths: &[std::path::PathBuf],
44 backend: LlmBackendKind,
45 ) -> Result<Self> {
46 let weights = load_hf_weights(weight_paths)?;
47 Self::from_weights_with_backend(info, weights, backend)
48 }
49
50 pub fn from_weights(info: ModelInfo, weights: HashMap<String, Tensor>) -> Result<Self> {
51 Self::from_weights_with_backend(info, weights, LlmBackendKind::Auto)
52 }
53
54 pub fn from_weights_with_backend(
55 info: ModelInfo,
56 weights: HashMap<String, Tensor>,
57 backend: LlmBackendKind,
58 ) -> Result<Self> {
59 let prefix = detect_weight_prefix(&weights);
60
61 let weights: HashMap<String, Tensor> = weights
65 .into_iter()
66 .map(|(k, v)| {
67 if should_quantize_online(&k, &v) {
68 (k, quantize_tensor_to_q8_0(v))
69 } else {
70 (k, v)
71 }
72 })
73 .collect();
74 let embed_key = format!("{prefix}embed_tokens.weight");
75 let tie = tie_word_embeddings_from_config(&info.raw);
76 let lm_head = resolve_lm_head(&weights, &prefix, tie, &embed_key)?.clone();
77 validate_core_shapes(&info, &weights, &embed_key, &lm_head)?;
78 let backend = LlmBackendDispatch::from_kind(backend)?;
79 tracing::debug!(backend = backend.name(), "initialized Phi forward backend");
80
81 let max_seq = info.max_position_embeddings;
82 let n_kv = info.num_key_value_heads;
83 let hd = info.head_dim;
84 let cache_shape = vec![1, n_kv, max_seq, hd];
85
86 let use_q8_cache = hd % 32 == 0;
89
90 let cache = (0..info.num_hidden_layers)
91 .map(|_| {
92 let (keys, values) = if use_q8_cache {
93 let numel = n_kv * max_seq * hd;
94 let kv_bytes = numel / 32 * 34;
95 let k = Tensor::from_quant_bytes(
96 &vec![0u8; kv_bytes],
97 cache_shape.clone(),
98 sapient_core::DType::Q8_0,
99 )
100 .unwrap();
101 let v = Tensor::from_quant_bytes(
102 &vec![0u8; kv_bytes],
103 cache_shape.clone(),
104 sapient_core::DType::Q8_0,
105 )
106 .unwrap();
107 (k, v)
108 } else {
109 let k = Tensor::zeros(cache_shape.clone(), sapient_core::DType::F32).unwrap();
110 let v = Tensor::zeros(cache_shape.clone(), sapient_core::DType::F32).unwrap();
111 (k, v)
112 };
113 LayerCache {
114 keys: Some(keys),
115 values: Some(values),
116 seq_len: 0,
117 }
118 })
119 .collect();
120
121 Ok(Self {
122 cache,
123 info,
124 prefix,
125 embed_key,
126 lm_head,
127 weights,
128 backend,
129 })
130 }
131
132 pub fn reset_cache(&mut self) {
133 for layer in &mut self.cache {
134 layer.seq_len = 0;
135 }
136 }
137
138 pub fn forward_logits(&mut self, input_ids: &[u32], use_cache: bool) -> Result<Vec<f32>> {
139 let hidden = self.forward_hidden(input_ids, use_cache)?;
140 let mut logits = self.backend.logits_from_hidden(&hidden, &self.lm_head)?;
141 if let Some(bias) = resolve_bias(&self.weights, &self.prefix, "lm_head") {
143 let bias_cow = bias.to_f32_cow();
144 for (l, b) in logits.iter_mut().zip(bias_cow.iter()) {
145 *l += *b;
146 }
147 }
148 Ok(logits)
149 }
150
151 pub fn forward_all_logits(&mut self, input_ids: &[u32]) -> Result<Vec<Vec<f32>>> {
153 let hidden = self.forward_hidden(input_ids, false)?;
154 let mut all = self.backend.all_logits_from_hidden(&hidden, &self.lm_head)?;
155 if let Some(bias) = resolve_bias(&self.weights, &self.prefix, "lm_head") {
157 let bias_cow = bias.to_f32_cow();
158 for logits in &mut all {
159 for (l, b) in logits.iter_mut().zip(bias_cow.iter()) {
160 *l += *b;
161 }
162 }
163 }
164 Ok(all)
165 }
166
167 pub fn embed(&mut self, input_ids: &[u32]) -> Result<Vec<f32>> {
168 self.reset_cache();
169 let hidden = self.forward_hidden(input_ids, false)?;
170 mean_pool_hidden(&hidden)
171 }
172
173 fn forward_hidden(&mut self, input_ids: &[u32], use_cache: bool) -> Result<Tensor> {
174 let embed = self
175 .weights
176 .get(&self.embed_key)
177 .ok_or_else(|| anyhow::anyhow!("missing embedding weights at '{}'", self.embed_key))?;
178 let mut x = embed_tokens(embed, input_ids)?;
179
180 let start_pos = if use_cache {
181 self.cache.first().map(|l| l.seq_len).unwrap_or(0)
182 } else {
183 self.reset_cache();
184 0
185 };
186 let seq_len = input_ids.len();
187 let positions: Vec<usize> = (start_pos..start_pos + seq_len).collect();
188
189 for layer_idx in 0..self.info.num_hidden_layers {
190 x = self.forward_layer(x, layer_idx, &positions, use_cache)?;
191 }
192
193 let (norm_w, norm_b) = match resolve_weight(&self.weights, &self.prefix, "final_layernorm")
195 {
196 Ok(w) => (
197 w,
198 resolve_bias(&self.weights, &self.prefix, "final_layernorm"),
199 ),
200 Err(_) => (
201 resolve_weight(&self.weights, &self.prefix, "norm")?,
202 resolve_bias(&self.weights, &self.prefix, "norm"),
203 ),
204 };
205 self.backend
206 .layer_norm(&x, norm_w, norm_b, self.info.rms_norm_eps as f32)
207 }
208
209 fn forward_layer(
210 &mut self,
211 x: Tensor,
212 layer_idx: usize,
213 positions: &[usize],
214 use_cache: bool,
215 ) -> Result<Tensor> {
216 let pfx = format!("layers.{layer_idx}");
217 let eps = self.info.rms_norm_eps as f32;
218 let n_heads = self.info.num_attention_heads;
219 let head_dim = self.info.head_dim;
220
221 let rotary_dim = ((self.info.partial_rotary_factor * head_dim as f64).round() as usize)
223 .clamp(2, head_dim);
224 let theta = self.info.rope_theta as f32;
225
226 let in_ln = format!("{pfx}.input_layernorm");
228 let norm_w = resolve_weight(&self.weights, &self.prefix, &in_ln)?;
229 let norm_b = resolve_bias(&self.weights, &self.prefix, &in_ln);
230 let h = self.backend.layer_norm(&x, norm_w, norm_b, eps)?;
231
232 let q = self.linear_with_bias(&h, &format!("{pfx}.self_attn.q_proj"), None)?;
234 let k = self.linear_with_bias(&h, &format!("{pfx}.self_attn.k_proj"), None)?;
235 let v = self.linear_with_bias(&h, &format!("{pfx}.self_attn.v_proj"), None)?;
236
237 let q = split_heads(&q, n_heads, head_dim)?;
238 let k = split_heads(&k, n_heads, head_dim)?;
239 let mut v = split_heads(&v, n_heads, head_dim)?;
240
241 let q = self
242 .backend
243 .apply_rope_partial(&q, positions, theta, rotary_dim)?;
244 let mut k = self
245 .backend
246 .apply_rope_partial(&k, positions, theta, rotary_dim)?;
247
248 if use_cache {
249 let current_seq = self.cache[layer_idx].seq_len;
250 let cache = &mut self.cache[layer_idx];
251 if let (Some(ck), Some(cv)) = (&mut cache.keys, &mut cache.values) {
252 k = crate::forward::common::update_kv_cache(ck, current_seq, &k)?;
253 v = crate::forward::common::update_kv_cache(cv, current_seq, &v)?;
254 }
255 cache.seq_len = (current_seq + positions.len()).min(self.info.max_position_embeddings);
256 }
257
258 let attn = self.backend.gqa_attention(&q, &k, &v, n_heads, true)?;
259 let attn = merge_heads(&attn)?;
260 let o = self.linear_with_bias(
262 &attn,
263 &format!("{pfx}.self_attn.dense"),
264 Some(&format!("{pfx}.self_attn.o_proj")),
265 )?;
266
267 if self.info.model_type == "phi" {
270 let ff = self.mlp_phi2(&h, &pfx)?;
271 let parallel_res = self.backend.add(&o, &ff)?;
272 self.backend.add(&x, ¶llel_res)
273 } else {
274 let x = self.backend.add(&x, &o)?;
276 let post_ln = format!("{pfx}.post_attention_layernorm");
277 let pn_w = resolve_weight(&self.weights, &self.prefix, &post_ln)?;
278 let pn_b = resolve_bias(&self.weights, &self.prefix, &post_ln);
279 let hn = self.backend.layer_norm(&x, pn_w, pn_b, eps)?;
280 let ff = self.mlp_phi3(&hn, &pfx)?;
281 self.backend.add(&x, &ff)
282 }
283 }
284
285 fn linear_with_bias(&self, x: &Tensor, name: &str, alt: Option<&str>) -> Result<Tensor> {
288 let (weight, bias) = match resolve_weight(&self.weights, &self.prefix, name) {
289 Ok(w) => (w, resolve_bias(&self.weights, &self.prefix, name)),
290 Err(e) => match alt {
291 Some(a) => (
292 resolve_weight(&self.weights, &self.prefix, a)?,
293 resolve_bias(&self.weights, &self.prefix, a),
294 ),
295 None => return Err(e),
296 },
297 };
298 self.backend.linear_3d_bias(x, weight, bias)
299 }
300
301 fn mlp_phi2(&self, h: &Tensor, pfx: &str) -> Result<Tensor> {
303 let ff1 = self.linear_with_bias(h, &format!("{pfx}.mlp.fc1"), None)?;
304 let ff1 = self.backend.gelu(&ff1)?;
305 self.linear_with_bias(&ff1, &format!("{pfx}.mlp.fc2"), None)
306 }
307
308 fn mlp_phi3(&self, h: &Tensor, pfx: &str) -> Result<Tensor> {
310 let gate_up = self.linear_with_bias(h, &format!("{pfx}.mlp.gate_up_proj"), None)?;
311 let dims = gate_up.shape().dims().to_vec();
315 let last = *dims.last().unwrap();
316 let inter = last / 2;
317 let rows: usize = dims[..dims.len() - 1].iter().product();
318 let src = gate_up.to_f32_cow();
319 let mut gate_v = vec![0.0f32; rows * inter];
320 let mut up_v = vec![0.0f32; rows * inter];
321 for r in 0..rows {
322 let base = r * last;
323 gate_v[r * inter..(r + 1) * inter].copy_from_slice(&src[base..base + inter]);
324 up_v[r * inter..(r + 1) * inter].copy_from_slice(&src[base + inter..base + last]);
325 }
326 let mut half_dims = dims.clone();
327 *half_dims.last_mut().unwrap() = inter;
328 let gate = Tensor::from_f32(&gate_v, sapient_core::Shape::new(half_dims.clone()))
329 .map_err(|e| anyhow::anyhow!("{e}"))?;
330 let up = Tensor::from_f32(&up_v, sapient_core::Shape::new(half_dims))
331 .map_err(|e| anyhow::anyhow!("{e}"))?;
332 let gate = self.backend.silu(&gate)?;
333 let activated = self.backend.mul(&gate, &up)?;
334 self.linear_with_bias(&activated, &format!("{pfx}.mlp.down_proj"), None)
335 }
336}
337
338fn validate_core_shapes(
339 info: &ModelInfo,
340 weights: &HashMap<String, Tensor>,
341 embed_key: &str,
342 lm_head: &Tensor,
343) -> Result<()> {
344 let embed = weights
345 .get(embed_key)
346 .ok_or_else(|| anyhow::anyhow!("missing embedding weights at '{embed_key}'"))?;
347 let embed_dims = embed.shape().dims();
348 if embed_dims.len() != 2 || embed_dims[1] != info.hidden_size {
349 anyhow::bail!(
350 "embedding shape mismatch at '{embed_key}': expected [vocab, {}], got {:?}",
351 info.hidden_size,
352 embed_dims
353 );
354 }
355 if embed_dims[0] < info.vocab_size {
356 anyhow::bail!(
357 "embedding vocab rows {} are smaller than config vocab_size {}",
358 embed_dims[0],
359 info.vocab_size
360 );
361 }
362
363 let head_dims = lm_head.shape().dims();
364 if head_dims.len() != 2 || head_dims[1] != info.hidden_size {
365 anyhow::bail!(
366 "lm_head shape mismatch: expected [vocab, {}], got {:?}",
367 info.hidden_size,
368 head_dims
369 );
370 }
371
372 Ok(())
373}