1use std::collections::HashMap;
4
5use anyhow::Result;
6use sapient_core::Tensor;
7use sapient_hub::model_info::ModelInfo;
8
9use super::backend::{LlmBackend, LlmBackendDispatch, LlmBackendKind};
10use super::common::{
11 embed_tokens, mean_pool_hidden, merge_heads, quantize_tensor_to_q8_0,
12 should_quantize_online, split_heads,
13};
14use crate::weights::{
15 detect_weight_prefix, load_hf_weights, resolve_bias, resolve_lm_head, resolve_weight,
16 tie_word_embeddings_from_config,
17};
18
/// Key/value cache state for a single transformer layer.
///
/// `LlamaForward::from_weights_with_backend` creates one of these per hidden
/// layer with both tensors preallocated to `[1, n_kv, max_seq, head_dim]`.
#[derive(Debug, Default, Clone)]
struct LayerCache {
    /// Cached key tensor. `None` only for a default-constructed value; the
    /// model constructor always stores `Some`.
    keys: Option<Tensor>,
    /// Cached value tensor, laid out the same way as `keys`.
    values: Option<Tensor>,
    /// Number of positions already written for this layer. Used as the write
    /// offset for the next cached step and zeroed by `reset_cache`.
    seq_len: usize,
}
26
/// Llama-style decoder forward pass with a per-layer key/value cache.
///
/// Built from a `ModelInfo` plus a map of named weight tensors; computation is
/// delegated to the selected `LlmBackendDispatch`.
pub struct LlamaForward {
    /// Model configuration (head counts, `head_dim`, `rope_theta`, `rms_norm_eps`, ...).
    info: ModelInfo,
    /// Weight-name prefix returned by `detect_weight_prefix` at construction.
    prefix: String,
    /// All weights by name, after the optional online Q8_0 quantization pass.
    weights: HashMap<String, Tensor>,
    /// Key of the token-embedding matrix: `"{prefix}embed_tokens.weight"`.
    embed_key: String,
    /// Output projection, cloned from whatever `resolve_lm_head` selected.
    lm_head: Tensor,
    /// One cache entry per hidden layer, indexed by layer number.
    cache: Vec<LayerCache>,
    /// Backend that executes the tensor operations.
    backend: LlmBackendDispatch,
}
37
38impl LlamaForward {
39 pub fn from_files(info: ModelInfo, weight_paths: &[std::path::PathBuf]) -> Result<Self> {
40 Self::from_files_with_backend(info, weight_paths, LlmBackendKind::Auto)
41 }
42
43 pub fn from_files_with_backend(
44 info: ModelInfo,
45 weight_paths: &[std::path::PathBuf],
46 backend: LlmBackendKind,
47 ) -> Result<Self> {
48 let weights = load_hf_weights(weight_paths)?;
49 Self::from_weights_with_backend(info, weights, backend)
50 }
51
52 pub fn from_weights(info: ModelInfo, weights: HashMap<String, Tensor>) -> Result<Self> {
53 Self::from_weights_with_backend(info, weights, LlmBackendKind::Auto)
54 }
55
56 pub fn from_weights_with_backend(
57 info: ModelInfo,
58 weights: HashMap<String, Tensor>,
59 backend: LlmBackendKind,
60 ) -> Result<Self> {
61 let prefix = detect_weight_prefix(&weights);
62
63 let weights: HashMap<String, Tensor> = weights
73 .into_iter()
74 .map(|(k, v)| {
75 if should_quantize_online(&k, &v) {
76 (k, quantize_tensor_to_q8_0(v))
77 } else {
78 (k, v)
79 }
80 })
81 .collect();
82 let embed_key = format!("{prefix}embed_tokens.weight");
83 let tie = tie_word_embeddings_from_config(&info.raw);
84 let lm_head = resolve_lm_head(&weights, &prefix, tie, &embed_key)?.clone();
85 validate_core_shapes(&info, &weights, &embed_key, &lm_head)?;
86 let backend = LlmBackendDispatch::from_kind(backend)?;
87 tracing::debug!(
88 backend = backend.name(),
89 "initialized Llama forward backend"
90 );
91
92 let max_seq = info.max_position_embeddings;
93 let n_kv = info.num_key_value_heads;
94 let hd = info.head_dim;
95 let cache_shape = vec![1, n_kv, max_seq, hd];
96
97 let use_q8_cache = hd % 32 == 0;
100
101 let cache = (0..info.num_hidden_layers)
102 .map(|_| {
103 let (keys, values) = if use_q8_cache {
104 let numel = n_kv * max_seq * hd;
106 let kv_bytes = numel / 32 * 34;
107 let k = Tensor::from_quant_bytes(
108 &vec![0u8; kv_bytes],
109 cache_shape.clone(),
110 sapient_core::DType::Q8_0,
111 )
112 .unwrap();
113 let v = Tensor::from_quant_bytes(
114 &vec![0u8; kv_bytes],
115 cache_shape.clone(),
116 sapient_core::DType::Q8_0,
117 )
118 .unwrap();
119 (k, v)
120 } else {
121 let k = Tensor::zeros(cache_shape.clone(), sapient_core::DType::F32).unwrap();
122 let v = Tensor::zeros(cache_shape.clone(), sapient_core::DType::F32).unwrap();
123 (k, v)
124 };
125 LayerCache {
126 keys: Some(keys),
127 values: Some(values),
128 seq_len: 0,
129 }
130 })
131 .collect();
132
133 Ok(Self {
134 cache,
135 info,
136 prefix,
137 embed_key,
138 lm_head,
139 weights,
140 backend,
141 })
142 }
143
144 pub fn reset_cache(&mut self) {
145 for layer in &mut self.cache {
146 layer.seq_len = 0;
147 }
148 }
149
150 pub fn forward_logits(&mut self, input_ids: &[u32], use_cache: bool) -> Result<Vec<f32>> {
152 let hidden = self.forward_hidden(input_ids, use_cache)?;
153 self.backend.logits_from_hidden(&hidden, &self.lm_head)
154 }
155
156 pub fn forward_all_logits(&mut self, input_ids: &[u32]) -> Result<Vec<Vec<f32>>> {
159 let hidden = self.forward_hidden(input_ids, false)?;
160 self.backend.all_logits_from_hidden(&hidden, &self.lm_head)
161 }
162
163 pub fn embed(&mut self, input_ids: &[u32]) -> Result<Vec<f32>> {
165 self.reset_cache();
166 let hidden = self.forward_hidden(input_ids, false)?;
167 mean_pool_hidden(&hidden)
168 }
169
170 fn forward_hidden(&mut self, input_ids: &[u32], use_cache: bool) -> Result<Tensor> {
171 let embed = self
172 .weights
173 .get(&self.embed_key)
174 .ok_or_else(|| anyhow::anyhow!("missing embedding weights at '{}'", self.embed_key))?;
175 let mut x = embed_tokens(embed, input_ids)?;
176
177 let start_pos = if use_cache {
178 self.cache.first().map(|l| l.seq_len).unwrap_or(0)
179 } else {
180 self.reset_cache();
181 0
182 };
183
184 let seq_len = input_ids.len();
185 let positions: Vec<usize> = (start_pos..start_pos + seq_len).collect();
186
187 for layer_idx in 0..self.info.num_hidden_layers {
188 x = self.forward_layer(x, layer_idx, &positions, use_cache)?;
189 }
190
191 let norm_w = resolve_weight(&self.weights, &self.prefix, "norm")?;
192 self.backend
193 .rms_norm(&x, norm_w, self.info.rms_norm_eps as f32)
194 }
195
196 fn forward_layer(
197 &mut self,
198 x: Tensor,
199 layer_idx: usize,
200 positions: &[usize],
201 use_cache: bool,
202 ) -> Result<Tensor> {
203 let pfx = format!("layers.{layer_idx}");
204 let eps = self.info.rms_norm_eps as f32;
205 let n_heads = self.info.num_attention_heads;
206 let n_kv = self.info.num_key_value_heads;
207 let head_dim = self.info.head_dim;
208
209 let attn_norm_w = resolve_weight(
210 &self.weights,
211 &self.prefix,
212 &format!("{pfx}.input_layernorm"),
213 )?;
214 let h = self.backend.rms_norm(&x, attn_norm_w, eps)?;
215
216 let q_name = format!("{pfx}.self_attn.q_proj");
220 let k_name = format!("{pfx}.self_attn.k_proj");
221 let v_name = format!("{pfx}.self_attn.v_proj");
222 let (q, k, v) = if self.backend.is_cpu() {
223 let ((q_res, k_res), v_res) = rayon::join(
224 || rayon::join(
225 || self.linear(&h, &q_name),
226 || self.linear(&h, &k_name),
227 ),
228 || self.linear(&h, &v_name),
229 );
230 (q_res?, k_res?, v_res?)
231 } else {
232 let q = self.linear(&h, &q_name)?;
233 let k = self.linear(&h, &k_name)?;
234 let v = self.linear(&h, &v_name)?;
235 (q, k, v)
236 };
237
238 let mut q = split_heads(&q, n_heads, head_dim)?;
239 let mut k = split_heads(&k, n_kv, head_dim)?;
240 let mut v = split_heads(&v, n_kv, head_dim)?;
241
242 q = self
243 .backend
244 .apply_rope_positions(&q, positions, self.info.rope_theta as f32)?;
245 k = self
246 .backend
247 .apply_rope_positions(&k, positions, self.info.rope_theta as f32)?;
248
249 let cache = &mut self.cache[layer_idx];
250 if use_cache {
251 let current_seq = cache.seq_len;
252 if let (Some(ck), Some(cv)) = (&mut cache.keys, &mut cache.values) {
253 k = crate::forward::common::update_kv_cache(ck, current_seq, &k)?;
254 v = crate::forward::common::update_kv_cache(cv, current_seq, &v)?;
255 }
256 cache.seq_len = current_seq + positions.len();
257 }
258
259 let attn = self.backend.gqa_attention(&q, &k, &v, n_kv, true)?;
260 let attn = merge_heads(&attn)?;
261 let o = self.linear(&attn, &format!("{pfx}.self_attn.o_proj"))?;
262 let x = self.backend.add(&x, &o)?;
263
264 let ffn_norm_w = resolve_weight(
265 &self.weights,
266 &self.prefix,
267 &format!("{pfx}.post_attention_layernorm"),
268 )?;
269 let h = self.backend.rms_norm(&x, ffn_norm_w, eps)?;
270
271 let gate_w = resolve_weight(&self.weights, &self.prefix, &format!("{pfx}.mlp.gate_proj"))?;
273 let up_w = resolve_weight(&self.weights, &self.prefix, &format!("{pfx}.mlp.up_proj"))?;
274 let (gate, up) = if self.backend.is_cpu() {
275 let backend = &self.backend;
276 let (gr, ur) = rayon::join(
277 || backend.linear_3d(&h, gate_w),
278 || backend.linear_3d(&h, up_w),
279 );
280 (gr?, ur?)
281 } else {
282 (
283 self.backend.linear_3d(&h, gate_w)?,
284 self.backend.linear_3d(&h, up_w)?,
285 )
286 };
287 let gate = self.backend.silu(&gate)?;
288 let mid = self.backend.mul(&gate, &up)?;
289 let down = self.backend.linear_3d(
290 &mid,
291 resolve_weight(&self.weights, &self.prefix, &format!("{pfx}.mlp.down_proj"))?,
292 )?;
293 self.backend.add(&x, &down)
294 }
295
296 fn linear(&self, x: &Tensor, name: &str) -> Result<Tensor> {
299 let weight = resolve_weight(&self.weights, &self.prefix, name)?;
300 let bias = resolve_bias(&self.weights, &self.prefix, name);
301 self.backend.linear_3d_bias(x, weight, bias)
302 }
303}
304
305fn validate_core_shapes(
306 info: &ModelInfo,
307 weights: &HashMap<String, Tensor>,
308 embed_key: &str,
309 lm_head: &Tensor,
310) -> Result<()> {
311 let embed = weights
312 .get(embed_key)
313 .ok_or_else(|| anyhow::anyhow!("missing embedding weights at '{embed_key}'"))?;
314 let embed_dims = embed.shape().dims();
315 if embed_dims.len() != 2 || embed_dims[1] != info.hidden_size {
316 anyhow::bail!(
317 "embedding shape mismatch at '{embed_key}': expected [vocab, {}], got {:?}",
318 info.hidden_size,
319 embed_dims
320 );
321 }
322 if embed_dims[0] < info.vocab_size {
323 anyhow::bail!(
324 "embedding vocab rows {} are smaller than config vocab_size {}",
325 embed_dims[0],
326 info.vocab_size
327 );
328 }
329
330 let head_dims = lm_head.shape().dims();
331 if head_dims.len() != 2 || head_dims[1] != info.hidden_size {
332 anyhow::bail!(
333 "lm_head shape mismatch: expected [vocab, {}], got {:?}",
334 info.hidden_size,
335 head_dims
336 );
337 }
338
339 Ok(())
340}