1use std::collections::HashMap;
4
5use anyhow::Result;
6use sapient_core::Tensor;
7use sapient_hub::model_info::ModelInfo;
8
9use super::backend::{LlmBackend, LlmBackendDispatch, LlmBackendKind};
10use super::common::{embed_tokens, mean_pool_hidden, merge_heads, split_heads};
11use crate::weights::{
12 detect_weight_prefix, load_hf_weights, resolve_bias, resolve_lm_head, resolve_weight,
13 tie_word_embeddings_from_config,
14};
15
16#[derive(Debug, Default, Clone)]
18struct LayerCache {
19 keys: Option<Tensor>,
20 values: Option<Tensor>,
21 seq_len: usize,
22}
23
24pub struct LlamaForward {
26 info: ModelInfo,
27 prefix: String,
28 weights: HashMap<String, Tensor>,
29 embed_key: String,
30 lm_head: Tensor,
31 cache: Vec<LayerCache>,
32 backend: LlmBackendDispatch,
33}
34
35impl LlamaForward {
36 pub fn from_files(info: ModelInfo, weight_paths: &[std::path::PathBuf]) -> Result<Self> {
37 Self::from_files_with_backend(info, weight_paths, LlmBackendKind::Auto)
38 }
39
40 pub fn from_files_with_backend(
41 info: ModelInfo,
42 weight_paths: &[std::path::PathBuf],
43 backend: LlmBackendKind,
44 ) -> Result<Self> {
45 let weights = load_hf_weights(weight_paths)?;
46 Self::from_weights_with_backend(info, weights, backend)
47 }
48
49 pub fn from_weights(info: ModelInfo, weights: HashMap<String, Tensor>) -> Result<Self> {
50 Self::from_weights_with_backend(info, weights, LlmBackendKind::Auto)
51 }
52
53 pub fn from_weights_with_backend(
54 info: ModelInfo,
55 weights: HashMap<String, Tensor>,
56 backend: LlmBackendKind,
57 ) -> Result<Self> {
58 let prefix = detect_weight_prefix(&weights);
59 let embed_key = format!("{prefix}embed_tokens.weight");
60 let tie = tie_word_embeddings_from_config(&info.raw);
61 let lm_head = resolve_lm_head(&weights, &prefix, tie, &embed_key)?.clone();
62 validate_core_shapes(&info, &weights, &embed_key, &lm_head)?;
63 let backend = LlmBackendDispatch::from_kind(backend)?;
64 tracing::debug!(
65 backend = backend.name(),
66 "initialized Llama forward backend"
67 );
68
69 let max_seq = info.max_position_embeddings;
70 let n_kv = info.num_key_value_heads;
71 let hd = info.head_dim;
72 let cache_shape = vec![1, n_kv, max_seq, hd];
73
74 let use_q8_cache = hd % 32 == 0;
77
78 let cache = (0..info.num_hidden_layers)
79 .map(|_| {
80 let (keys, values) = if use_q8_cache {
81 let numel = n_kv * max_seq * hd;
83 let kv_bytes = numel / 32 * 34;
84 let k = Tensor::from_quant_bytes(
85 &vec![0u8; kv_bytes],
86 cache_shape.clone(),
87 sapient_core::DType::Q8_0,
88 )
89 .unwrap();
90 let v = Tensor::from_quant_bytes(
91 &vec![0u8; kv_bytes],
92 cache_shape.clone(),
93 sapient_core::DType::Q8_0,
94 )
95 .unwrap();
96 (k, v)
97 } else {
98 let k = Tensor::zeros(cache_shape.clone(), sapient_core::DType::F32).unwrap();
99 let v = Tensor::zeros(cache_shape.clone(), sapient_core::DType::F32).unwrap();
100 (k, v)
101 };
102 LayerCache {
103 keys: Some(keys),
104 values: Some(values),
105 seq_len: 0,
106 }
107 })
108 .collect();
109
110 Ok(Self {
111 cache,
112 info,
113 prefix,
114 embed_key,
115 lm_head,
116 weights,
117 backend,
118 })
119 }
120
121 pub fn reset_cache(&mut self) {
122 for layer in &mut self.cache {
123 layer.seq_len = 0;
124 }
125 }
126
127 pub fn forward_logits(&mut self, input_ids: &[u32], use_cache: bool) -> Result<Vec<f32>> {
129 let hidden = self.forward_hidden(input_ids, use_cache)?;
130 self.backend.logits_from_hidden(&hidden, &self.lm_head)
131 }
132
133 pub fn forward_all_logits(&mut self, input_ids: &[u32]) -> Result<Vec<Vec<f32>>> {
136 let hidden = self.forward_hidden(input_ids, false)?;
137 self.backend.all_logits_from_hidden(&hidden, &self.lm_head)
138 }
139
140 pub fn embed(&mut self, input_ids: &[u32]) -> Result<Vec<f32>> {
142 self.reset_cache();
143 let hidden = self.forward_hidden(input_ids, false)?;
144 mean_pool_hidden(&hidden)
145 }
146
147 fn forward_hidden(&mut self, input_ids: &[u32], use_cache: bool) -> Result<Tensor> {
148 let embed = self
149 .weights
150 .get(&self.embed_key)
151 .ok_or_else(|| anyhow::anyhow!("missing embedding weights at '{}'", self.embed_key))?;
152 let mut x = embed_tokens(embed, input_ids)?;
153
154 let start_pos = if use_cache {
155 self.cache.first().map(|l| l.seq_len).unwrap_or(0)
156 } else {
157 self.reset_cache();
158 0
159 };
160
161 let seq_len = input_ids.len();
162 let positions: Vec<usize> = (start_pos..start_pos + seq_len).collect();
163
164 for layer_idx in 0..self.info.num_hidden_layers {
165 x = self.forward_layer(x, layer_idx, &positions, use_cache)?;
166 }
167
168 let norm_w = resolve_weight(&self.weights, &self.prefix, "norm")?;
169 self.backend
170 .rms_norm(&x, norm_w, self.info.rms_norm_eps as f32)
171 }
172
173 fn forward_layer(
174 &mut self,
175 x: Tensor,
176 layer_idx: usize,
177 positions: &[usize],
178 use_cache: bool,
179 ) -> Result<Tensor> {
180 let pfx = format!("layers.{layer_idx}");
181 let eps = self.info.rms_norm_eps as f32;
182 let n_heads = self.info.num_attention_heads;
183 let n_kv = self.info.num_key_value_heads;
184 let head_dim = self.info.head_dim;
185
186 let attn_norm_w = resolve_weight(
187 &self.weights,
188 &self.prefix,
189 &format!("{pfx}.input_layernorm"),
190 )?;
191 let h = self.backend.rms_norm(&x, attn_norm_w, eps)?;
192
193 let q = self.linear(&h, &format!("{pfx}.self_attn.q_proj"))?;
196 let k = self.linear(&h, &format!("{pfx}.self_attn.k_proj"))?;
197 let v = self.linear(&h, &format!("{pfx}.self_attn.v_proj"))?;
198
199 let mut q = split_heads(&q, n_heads, head_dim)?;
200 let mut k = split_heads(&k, n_kv, head_dim)?;
201 let mut v = split_heads(&v, n_kv, head_dim)?;
202
203 q = self
204 .backend
205 .apply_rope_positions(&q, positions, self.info.rope_theta as f32)?;
206 k = self
207 .backend
208 .apply_rope_positions(&k, positions, self.info.rope_theta as f32)?;
209
210 let cache = &mut self.cache[layer_idx];
211 if use_cache {
212 let current_seq = cache.seq_len;
213 if let (Some(ck), Some(cv)) = (&mut cache.keys, &mut cache.values) {
214 k = crate::forward::common::update_kv_cache(ck, current_seq, &k)?;
215 v = crate::forward::common::update_kv_cache(cv, current_seq, &v)?;
216 }
217 cache.seq_len = current_seq + positions.len();
218 }
219
220 let attn = self.backend.gqa_attention(&q, &k, &v, n_kv, true)?;
221 let attn = merge_heads(&attn)?;
222 let o = self.linear(&attn, &format!("{pfx}.self_attn.o_proj"))?;
223 let x = self.backend.add(&x, &o)?;
224
225 let ffn_norm_w = resolve_weight(
226 &self.weights,
227 &self.prefix,
228 &format!("{pfx}.post_attention_layernorm"),
229 )?;
230 let h = self.backend.rms_norm(&x, ffn_norm_w, eps)?;
231
232 let gate = self.backend.linear_3d(
233 &h,
234 resolve_weight(&self.weights, &self.prefix, &format!("{pfx}.mlp.gate_proj"))?,
235 )?;
236 let up = self.backend.linear_3d(
237 &h,
238 resolve_weight(&self.weights, &self.prefix, &format!("{pfx}.mlp.up_proj"))?,
239 )?;
240 let gate = self.backend.silu(&gate)?;
241 let mid = self.backend.mul(&gate, &up)?;
242 let down = self.backend.linear_3d(
243 &mid,
244 resolve_weight(&self.weights, &self.prefix, &format!("{pfx}.mlp.down_proj"))?,
245 )?;
246 self.backend.add(&x, &down)
247 }
248
249 fn linear(&self, x: &Tensor, name: &str) -> Result<Tensor> {
252 let weight = resolve_weight(&self.weights, &self.prefix, name)?;
253 let bias = resolve_bias(&self.weights, &self.prefix, name);
254 self.backend.linear_3d_bias(x, weight, bias)
255 }
256}
257
258fn validate_core_shapes(
259 info: &ModelInfo,
260 weights: &HashMap<String, Tensor>,
261 embed_key: &str,
262 lm_head: &Tensor,
263) -> Result<()> {
264 let embed = weights
265 .get(embed_key)
266 .ok_or_else(|| anyhow::anyhow!("missing embedding weights at '{embed_key}'"))?;
267 let embed_dims = embed.shape().dims();
268 if embed_dims.len() != 2 || embed_dims[1] != info.hidden_size {
269 anyhow::bail!(
270 "embedding shape mismatch at '{embed_key}': expected [vocab, {}], got {:?}",
271 info.hidden_size,
272 embed_dims
273 );
274 }
275 if embed_dims[0] < info.vocab_size {
276 anyhow::bail!(
277 "embedding vocab rows {} are smaller than config vocab_size {}",
278 embed_dims[0],
279 info.vocab_size
280 );
281 }
282
283 let head_dims = lm_head.shape().dims();
284 if head_dims.len() != 2 || head_dims[1] != info.hidden_size {
285 anyhow::bail!(
286 "lm_head shape mismatch: expected [vocab, {}], got {:?}",
287 info.hidden_size,
288 head_dims
289 );
290 }
291
292 Ok(())
293}