//! Model runners for sparse inference: the `ModelRunner` trait, Llama/LFM2/BERT
//! backends, and the low-rank predictor used to select active FFN neurons.

use crate::error::SparseInferenceError;
use crate::model::loader::ModelMetadata;
use crate::model::types::{CalibrationStats, InferenceConfig, ModelInput, ModelOutput};
use crate::ops::{silu, Embedding, LayerNorm, Linear, RMSNorm};
use std::collections::HashMap;

type Result<T> = std::result::Result<T, SparseInferenceError>;

/// Common interface implemented by every supported model architecture.
pub trait ModelRunner {
    /// Run a forward pass over `input` under the given inference config.
    fn forward(&self, input: &ModelInput, config: &InferenceConfig) -> Result<ModelOutput>;

    /// The activation predictor attached to layer `layer_idx`, if any.
    fn get_predictor(&self, layer_idx: usize) -> Option<&LowRankPredictor>;

    /// Gather activation statistics from a set of calibration samples.
    fn calibrate(&mut self, samples: &[ModelInput]) -> Result<CalibrationStats>;

    /// Static metadata describing the loaded model.
    fn metadata(&self) -> &ModelMetadata;
}

/// Low-rank predictor that scores FFN neurons so only the top-k are computed:
/// the score matrix is factored as `U` (input_dim x rank) times `V`
/// (rank x output_dim).
#[derive(Debug, Clone)]
pub struct LowRankPredictor {
    pub u: Vec<Vec<f32>>,
    pub v: Vec<Vec<f32>>,
    pub rank: usize,
}

impl LowRankPredictor {
    pub fn new(input_dim: usize, output_dim: usize, rank: usize) -> Self {
        Self {
            u: vec![vec![0.0; rank]; input_dim],
            v: vec![vec![0.0; output_dim]; rank],
            rank,
        }
    }

    /// Score every output neuron and return the indices of the `k` highest.
    pub fn predict_active(&self, input: &[f32], k: usize) -> Vec<usize> {
        let scores = self.forward(input);
        let mut indices: Vec<usize> = (0..scores.len()).collect();
        // Sort descending by score; treat incomparable (NaN) scores as equal
        // instead of panicking.
        indices.sort_by(|&a, &b| {
            scores[b]
                .partial_cmp(&scores[a])
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        indices.truncate(k);
        indices
    }

    /// Low-rank forward pass: `scores = (input * U) * V`.
    fn forward(&self, input: &[f32]) -> Vec<f32> {
        // hidden = input * U (rank-sized intermediate).
        let mut hidden = vec![0.0; self.rank];
        for i in 0..self.rank {
            for (j, u_row) in self.u.iter().enumerate() {
                if j < input.len() && i < u_row.len() {
                    hidden[i] += u_row[i] * input[j];
                }
            }
        }

        // output = hidden * V.
        let output_dim = self.v.first().map(|row| row.len()).unwrap_or(0);
        let mut output = vec![0.0; output_dim];
        for i in 0..output_dim {
            for (j, &h) in hidden.iter().enumerate() {
                if j < self.v.len() && i < self.v[j].len() {
                    output[i] += self.v[j][i] * h;
                }
            }
        }

        output
    }
}

pub struct LlamaModel {
    pub metadata: ModelMetadata,
    pub layers: Vec<LlamaLayer>,
    pub embed_tokens: Embedding,
    pub norm: RMSNorm,
    pub lm_head: Option<Linear>,
}

pub struct LlamaLayer {
    pub input_layernorm: RMSNorm,
    pub self_attn: LlamaAttention,
    pub post_attention_layernorm: RMSNorm,
    pub mlp: LlamaMLP,
    pub predictor: Option<LowRankPredictor>,
}

pub struct LlamaAttention {
    pub q_proj: Linear,
    pub k_proj: Linear,
    pub v_proj: Linear,
    pub o_proj: Linear,
    pub num_heads: usize,
    pub head_dim: usize,
}

pub struct LlamaMLP {
    pub gate_proj: Linear,
    pub up_proj: Linear,
    pub down_proj: Linear,
}

impl LlamaMLP {
    pub fn forward(&self, x: &[f32]) -> Vec<f32> {
        let gate = self.gate_proj.forward(x);
        let up = self.up_proj.forward(x);

        // SwiGLU: silu(gate) * up, elementwise.
        let hidden: Vec<f32> = gate
            .iter()
            .zip(up.iter())
            .map(|(&g, &u)| silu(g) * u)
            .collect();

        self.down_proj.forward(&hidden)
    }

    /// Sparse forward pass: compute only the gate/up rows selected by the
    /// predictor, then only the matching columns of the down projection.
    pub fn forward_sparse(&self, x: &[f32], active_neurons: &[usize]) -> Vec<f32> {
        // Compacted activations: entry i corresponds to neuron active_neurons[i].
        let gate = sparse_matmul(&self.gate_proj, x, active_neurons);
        let up = sparse_matmul(&self.up_proj, x, active_neurons);

        let hidden: Vec<f32> = gate
            .iter()
            .zip(up.iter())
            .map(|(&g, &u)| silu(g) * u)
            .collect();

        // Scatter the compacted activations back to their original neuron
        // positions so the down projection indexes them correctly.
        let mut full_hidden = vec![0.0; self.down_proj.in_features];
        for (&neuron, &h) in active_neurons.iter().zip(hidden.iter()) {
            if neuron < full_hidden.len() {
                full_hidden[neuron] = h;
            }
        }

        sparse_matmul_full(&self.down_proj, &full_hidden, active_neurons)
    }
}

impl ModelRunner for LlamaModel {
    fn forward(&self, input: &ModelInput, config: &InferenceConfig) -> Result<ModelOutput> {
        let mut hidden_states = self.embed_tokens.forward(&input.input_ids);

        let mut all_hidden_states = if config.output_hidden_states {
            Some(Vec::new())
        } else {
            None
        };

        for layer in self.layers.iter() {
            if let Some(ref mut states) = all_hidden_states {
                states.push(hidden_states.clone());
            }

            // Pre-norm attention block with residual connection.
            let normed = layer.input_layernorm.forward(&hidden_states);
            let attn_output = layer.self_attn.forward(&normed);
            hidden_states = add_vectors(&hidden_states, &attn_output);

            // Pre-norm MLP block; take the sparse path when a predictor is
            // attached and sparse FFN is enabled.
            let normed = layer.post_attention_layernorm.forward(&hidden_states);
            let mlp_output = if config.use_sparse_ffn {
                if let Some(ref predictor) = layer.predictor {
                    let k = config.active_neurons_per_layer.unwrap_or(
                        (self.metadata.intermediate_size as f32 * (1.0 - config.sparsity))
                            as usize,
                    );
                    let active = predictor.predict_active(&normed, k);
                    layer.mlp.forward_sparse(&normed, &active)
                } else {
                    layer.mlp.forward(&normed)
                }
            } else {
                layer.mlp.forward(&normed)
            };

            hidden_states = add_vectors(&hidden_states, &mlp_output);
        }

        hidden_states = self.norm.forward(&hidden_states);

        let logits = if let Some(ref lm_head) = self.lm_head {
            lm_head.forward(&hidden_states)
        } else {
            hidden_states
        };

        Ok(ModelOutput::new(logits).with_hidden_states(all_hidden_states.unwrap_or_default()))
    }

    fn get_predictor(&self, layer_idx: usize) -> Option<&LowRankPredictor> {
        self.layers.get(layer_idx)?.predictor.as_ref()
    }

    fn calibrate(&mut self, samples: &[ModelInput]) -> Result<CalibrationStats> {
        // Placeholder: real calibration would run the samples through the
        // model and record per-layer activation statistics.
        Ok(CalibrationStats {
            num_samples: samples.len(),
            average_sparsity: 0.9,
            layer_stats: HashMap::new(),
        })
    }

    fn metadata(&self) -> &ModelMetadata {
        &self.metadata
    }
}

impl LlamaAttention {
    /// Placeholder attention: applies the Q and output projections only.
    /// K/V are projected for parity with a full attention path but are not
    /// yet used in score computation.
    pub fn forward(&self, hidden_states: &[f32]) -> Vec<f32> {
        let q = self.q_proj.forward(hidden_states);
        let _k = self.k_proj.forward(hidden_states);
        let _v = self.v_proj.forward(hidden_states);

        self.o_proj.forward(&q)
    }
}

pub struct LFM2Model {
    pub metadata: ModelMetadata,
    pub embedding: Embedding,
    pub layers: Vec<LFM2Layer>,
    pub pooler: Option<Pooler>,
}

pub struct LFM2Layer {
    pub gated_conv: GatedConv1d,
    pub attention: GroupedQueryAttention,
    pub ffn: SparseFfn,
    pub norm: LayerNorm,
}

pub struct GatedConv1d {
    pub weight: Vec<Vec<f32>>,
    pub gate: Linear,
}

pub struct GroupedQueryAttention {
    pub q_proj: Linear,
    pub k_proj: Linear,
    pub v_proj: Linear,
    pub o_proj: Linear,
    pub num_groups: usize,
}

pub struct SparseFfn {
    pub w1: Linear,
    pub w2: Linear,
    pub predictor: Option<LowRankPredictor>,
}

impl ModelRunner for LFM2Model {
    fn forward(&self, input: &ModelInput, config: &InferenceConfig) -> Result<ModelOutput> {
        let mut hidden = self.embedding.forward(&input.input_ids);

        for layer in &self.layers {
            hidden = layer.gated_conv.forward(&hidden);

            let attn_out = layer.attention.forward(&hidden);
            hidden = add_vectors(&hidden, &attn_out);

            let ffn_out = layer.ffn.forward(&hidden, config);
            hidden = add_vectors(&hidden, &ffn_out);

            hidden = layer.norm.forward(&hidden);
        }

        Ok(ModelOutput::new(hidden))
    }

    fn get_predictor(&self, layer_idx: usize) -> Option<&LowRankPredictor> {
        self.layers.get(layer_idx)?.ffn.predictor.as_ref()
    }

    fn calibrate(&mut self, _samples: &[ModelInput]) -> Result<CalibrationStats> {
        // Placeholder: returns fixed statistics until calibration is wired up.
        Ok(CalibrationStats {
            num_samples: 0,
            average_sparsity: 0.9,
            layer_stats: HashMap::new(),
        })
    }

    fn metadata(&self) -> &ModelMetadata {
        &self.metadata
    }
}

impl GatedConv1d {
    pub fn forward(&self, x: &[f32]) -> Vec<f32> {
        // Placeholder: passes the input through unchanged until the gated
        // convolution is implemented.
        x.to_vec()
    }
}

impl GroupedQueryAttention {
    pub fn forward(&self, x: &[f32]) -> Vec<f32> {
        // Placeholder: applies only the output projection.
        self.o_proj.forward(x)
    }
}

impl SparseFfn {
    pub fn forward(&self, x: &[f32], config: &InferenceConfig) -> Vec<f32> {
        if config.use_sparse_ffn {
            if let Some(ref predictor) = self.predictor {
                let k = (self.w1.out_features as f32 * (1.0 - config.sparsity)) as usize;
                let active = predictor.predict_active(x, k);
                // Compute only the active rows of w1, then scatter them back
                // to their original positions so w2 can index them by column.
                let hidden = sparse_matmul(&self.w1, x, &active);
                let mut full_hidden = vec![0.0; self.w2.in_features];
                for (&neuron, &h) in active.iter().zip(hidden.iter()) {
                    if neuron < full_hidden.len() {
                        full_hidden[neuron] = h;
                    }
                }
                return sparse_matmul_full(&self.w2, &full_hidden, &active);
            }
        }
        self.w2.forward(&self.w1.forward(x))
    }
}

pub struct BertModel {
    pub metadata: ModelMetadata,
    pub embeddings: BertEmbeddings,
    pub encoder: Vec<BertLayer>,
    pub pooler: Option<Pooler>,
}

pub struct BertEmbeddings {
    pub word_embeddings: Embedding,
    pub position_embeddings: Embedding,
    pub token_type_embeddings: Embedding,
    pub layer_norm: LayerNorm,
}

pub struct BertLayer {
    pub attention: MultiHeadAttention,
    pub intermediate: Linear,
    pub output: Linear,
    pub layer_norm1: LayerNorm,
    pub layer_norm2: LayerNorm,
}

pub struct MultiHeadAttention {
    pub q_proj: Linear,
    pub k_proj: Linear,
    pub v_proj: Linear,
    pub o_proj: Linear,
    pub num_heads: usize,
}

pub struct Pooler {
    pub dense: Linear,
}

impl ModelRunner for BertModel {
    fn forward(&self, input: &ModelInput, _config: &InferenceConfig) -> Result<ModelOutput> {
        let mut hidden = self.embeddings.forward(&input.input_ids);

        for layer in &self.encoder {
            // Post-norm residual blocks, as in the original BERT.
            let attn_out = layer.attention.forward(&hidden);
            hidden = layer.layer_norm1.forward(&add_vectors(&hidden, &attn_out));

            let intermediate = layer.intermediate.forward(&hidden);
            let output = layer.output.forward(&intermediate);
            hidden = layer.layer_norm2.forward(&add_vectors(&hidden, &output));
        }

        Ok(ModelOutput::new(hidden))
    }

    fn get_predictor(&self, _layer_idx: usize) -> Option<&LowRankPredictor> {
        // BERT layers carry no sparsity predictors.
        None
    }

    fn calibrate(&mut self, _samples: &[ModelInput]) -> Result<CalibrationStats> {
        // Placeholder: returns empty statistics until calibration is wired up.
        Ok(CalibrationStats {
            num_samples: 0,
            average_sparsity: 0.0,
            layer_stats: HashMap::new(),
        })
    }

    fn metadata(&self) -> &ModelMetadata {
        &self.metadata
    }
}

impl BertEmbeddings {
    pub fn forward(&self, input_ids: &[u64]) -> Vec<f32> {
        // Placeholder: position and token-type embeddings are not yet added.
        self.word_embeddings.forward(input_ids)
    }
}

impl MultiHeadAttention {
    pub fn forward(&self, x: &[f32]) -> Vec<f32> {
        // Placeholder: applies only the output projection.
        self.o_proj.forward(x)
    }
}

/// Architecture-dispatching wrapper so callers can hold any supported model
/// behind a single type.
pub enum SparseModel {
    Llama(LlamaModel),
    LFM2(LFM2Model),
    Bert(BertModel),
}

impl ModelRunner for SparseModel {
    fn forward(&self, input: &ModelInput, config: &InferenceConfig) -> Result<ModelOutput> {
        match self {
            Self::Llama(m) => m.forward(input, config),
            Self::LFM2(m) => m.forward(input, config),
            Self::Bert(m) => m.forward(input, config),
        }
    }

    fn get_predictor(&self, layer_idx: usize) -> Option<&LowRankPredictor> {
        match self {
            Self::Llama(m) => m.get_predictor(layer_idx),
            Self::LFM2(m) => m.get_predictor(layer_idx),
            Self::Bert(m) => m.get_predictor(layer_idx),
        }
    }

    fn calibrate(&mut self, samples: &[ModelInput]) -> Result<CalibrationStats> {
        match self {
            Self::Llama(m) => m.calibrate(samples),
            Self::LFM2(m) => m.calibrate(samples),
            Self::Bert(m) => m.calibrate(samples),
        }
    }

    fn metadata(&self) -> &ModelMetadata {
        match self {
            Self::Llama(m) => m.metadata(),
            Self::LFM2(m) => m.metadata(),
            Self::Bert(m) => m.metadata(),
        }
    }
}

/// Multiply `input` by only the rows of `linear.weight` listed in
/// `active_cols`, returning a compacted output where entry `i` corresponds to
/// row `active_cols[i]`.
fn sparse_matmul(linear: &Linear, input: &[f32], active_cols: &[usize]) -> Vec<f32> {
    let mut output = vec![0.0; active_cols.len()];

    for (out_idx, &col_idx) in active_cols.iter().enumerate() {
        if col_idx < linear.out_features {
            for (in_idx, &x) in input.iter().enumerate() {
                if in_idx < linear.in_features {
                    output[out_idx] += linear.weight[col_idx][in_idx] * x;
                }
            }
            if let Some(ref bias) = linear.bias {
                output[out_idx] += bias[col_idx];
            }
        }
    }

    output
}

/// Full-width output that sums only over the input columns listed in
/// `active_input_cols`; `input` must be indexed by original column position,
/// not compacted.
fn sparse_matmul_full(linear: &Linear, input: &[f32], active_input_cols: &[usize]) -> Vec<f32> {
    let mut output = vec![0.0; linear.out_features];

    for out_idx in 0..linear.out_features {
        for &in_idx in active_input_cols {
            if in_idx < input.len() && in_idx < linear.in_features {
                output[out_idx] += linear.weight[out_idx][in_idx] * input[in_idx];
            }
        }
        if let Some(ref bias) = linear.bias {
            output[out_idx] += bias[out_idx];
        }
    }

    output
}

/// Elementwise sum; the result is truncated to the shorter of the two inputs.
fn add_vectors(a: &[f32], b: &[f32]) -> Vec<f32> {
    a.iter().zip(b.iter()).map(|(x, y)| x + y).collect()
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_low_rank_predictor() {
        let predictor = LowRankPredictor::new(128, 512, 16);
        let input = vec![1.0; 128];
        let active = predictor.predict_active(&input, 10);
        assert_eq!(active.len(), 10);
    }
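
    // Sanity-check sketch of the factored scoring math with hand-computed
    // values. With rank 1, u = [[1], [2]] and v = [[0.5, 1.0, 1.5]], an input
    // of [3, 4] gives hidden = 1*3 + 2*4 = 11 and scores [5.5, 11.0, 16.5],
    // so the top-2 neurons should be index 2 then index 1.
    #[test]
    fn test_low_rank_predictor_ranking() {
        let predictor = LowRankPredictor {
            u: vec![vec![1.0], vec![2.0]],
            v: vec![vec![0.5, 1.0, 1.5]],
            rank: 1,
        };
        let active = predictor.predict_active(&[3.0, 4.0], 2);
        assert_eq!(active, vec![2, 1]);
    }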

    #[test]
    fn test_add_vectors() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 5.0, 6.0];
        let result = add_vectors(&a, &b);
        assert_eq!(result, vec![5.0, 7.0, 9.0]);
    }
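
    // Boundary-behavior sketches worth pinning down: `predict_active` with
    // `k` larger than the neuron count returns all neurons (truncate is a
    // no-op), and `add_vectors` zips, so its result is truncated to the
    // shorter input.
    #[test]
    fn test_predict_active_k_exceeds_output_dim() {
        let predictor = LowRankPredictor::new(8, 4, 2);
        let active = predictor.predict_active(&vec![1.0; 8], 100);
        assert_eq!(active.len(), 4);
    }

    #[test]
    fn test_add_vectors_unequal_lengths() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![10.0, 20.0];
        assert_eq!(add_vectors(&a, &b), vec![11.0, 22.0]);
    }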
}