ruvector_sparse_inference/model/
runners.rs

1//! Model runners for different architectures with sparse inference support
2
3use crate::error::SparseInferenceError;
4use crate::model::loader::{ModelLoader, ModelMetadata};
5use crate::model::types::{CalibrationStats, InferenceConfig, ModelInput, ModelOutput, Tensor};
6use crate::ops::{Linear, Embedding, RMSNorm, LayerNorm, silu};
7use std::collections::HashMap;
8
9type Result<T> = std::result::Result<T, SparseInferenceError>;
10
11/// Trait for running inference on models
12pub trait ModelRunner {
13    /// Forward pass with optional sparse computation
14    fn forward(&self, input: &ModelInput, config: &InferenceConfig) -> Result<ModelOutput>;
15
16    /// Get predictor for a specific layer (if available)
17    fn get_predictor(&self, layer_idx: usize) -> Option<&LowRankPredictor>;
18
19    /// Calibrate predictors with sample data
20    fn calibrate(&mut self, samples: &[ModelInput]) -> Result<CalibrationStats>;
21
22    /// Get model metadata
23    fn metadata(&self) -> &ModelMetadata;
24}
25
/// Low-rank predictor that scores intermediate neurons so that only the most
/// promising ones need to be evaluated.
///
/// For an input `x` of length `d`, the scores are `Vᵀ (Uᵀ x)` with `U` of shape
/// `d × r` and `V` of shape `r × m`. This costs `O(r·(d + m))` instead of the
/// `O(d·m)` a dense scoring matrix would need.
#[derive(Debug, Clone)]
pub struct LowRankPredictor {
    /// `U` factor, stored as `d` rows of `r` entries.
    pub u: Vec<Vec<f32>>,
    /// `V` factor, stored as `r` rows of `m` entries.
    pub v: Vec<Vec<f32>>,
    /// Rank `r` of the factorisation (length of the intermediate projection).
    pub rank: usize,
}

impl LowRankPredictor {
    /// Creates a zero-initialised predictor mapping `input_dim` features to
    /// `output_dim` neuron scores through a rank-`rank` bottleneck.
    pub fn new(input_dim: usize, output_dim: usize, rank: usize) -> Self {
        Self {
            u: vec![vec![0.0; rank]; input_dim],
            v: vec![vec![0.0; output_dim]; rank],
            rank,
        }
    }

    /// Returns the indices of the `k` highest-scoring neurons for `input`,
    /// ordered from highest to lowest score.
    ///
    /// Ties are broken by ascending index. NaN scores rank below every real
    /// score, so a bad activation can neither panic nor displace valid
    /// neurons. If `k` exceeds the number of neurons, every index is returned.
    pub fn predict_active(&self, input: &[f32], k: usize) -> Vec<usize> {
        let scores = self.forward(input);
        let k = k.min(scores.len());
        if k == 0 {
            return Vec::new();
        }

        // NaN breaks `partial_cmp`; demote it to the lowest possible rank.
        let key = |idx: usize| {
            let s = scores[idx];
            if s.is_nan() {
                f32::NEG_INFINITY
            } else {
                s
            }
        };
        // Descending score, then ascending index: a strict total order, so the
        // sort routines below cannot panic and their result is deterministic.
        let by_rank = |a: &usize, b: &usize| {
            key(*b)
                .partial_cmp(&key(*a))
                .unwrap_or(std::cmp::Ordering::Equal)
                .then_with(|| a.cmp(b))
        };

        let mut indices: Vec<usize> = (0..scores.len()).collect();
        if k < indices.len() {
            // O(m) partition; only the k survivors are sorted afterwards.
            indices.select_nth_unstable_by(k - 1, &by_rank);
            indices.truncate(k);
        }
        indices.sort_unstable_by(&by_rank);
        indices
    }

    /// Computes one score per neuron: `hidden = Uᵀ·input`, then `scores = Vᵀ·hidden`.
    ///
    /// Rows of `U` beyond `input.len()` and row entries beyond the
    /// hidden/output length are ignored rather than indexed. The output length
    /// is taken from the first row of `V`, or 0 when `V` has no rows.
    fn forward(&self, input: &[f32]) -> Vec<f32> {
        // hidden[i] = Σ_j U[j][i] · input[j], walked row by row for locality.
        let mut hidden = vec![0.0f32; self.rank];
        for (u_row, &x) in self.u.iter().zip(input) {
            for (h, &w) in hidden.iter_mut().zip(u_row) {
                *h += w * x;
            }
        }

        // scores[i] = Σ_j V[j][i] · hidden[j]
        let output_dim = self.v.first().map_or(0, Vec::len);
        let mut output = vec![0.0f32; output_dim];
        for (v_row, &h) in self.v.iter().zip(&hidden) {
            for (o, &w) in output.iter_mut().zip(v_row) {
                *o += w * h;
            }
        }
        output
    }
}
78
79// ============================================================================
80// Llama Model
81// ============================================================================
82
83/// Llama model for sparse inference
84pub struct LlamaModel {
85    pub metadata: ModelMetadata,
86    pub layers: Vec<LlamaLayer>,
87    pub embed_tokens: Embedding,
88    pub norm: RMSNorm,
89    pub lm_head: Option<Linear>,
90}
91
92pub struct LlamaLayer {
93    pub input_layernorm: RMSNorm,
94    pub self_attn: LlamaAttention,
95    pub post_attention_layernorm: RMSNorm,
96    pub mlp: LlamaMLP,
97    pub predictor: Option<LowRankPredictor>,
98}
99
100pub struct LlamaAttention {
101    pub q_proj: Linear,
102    pub k_proj: Linear,
103    pub v_proj: Linear,
104    pub o_proj: Linear,
105    pub num_heads: usize,
106    pub head_dim: usize,
107}
108
109pub struct LlamaMLP {
110    pub gate_proj: Linear,  // W1 for SwiGLU gate
111    pub up_proj: Linear,    // W3 for SwiGLU up
112    pub down_proj: Linear,  // W2 for down projection
113}
114
115impl LlamaMLP {
116    /// Standard forward pass (dense)
117    pub fn forward(&self, x: &[f32]) -> Vec<f32> {
118        let gate = self.gate_proj.forward(x);
119        let up = self.up_proj.forward(x);
120
121        // SwiGLU: silu(gate) ⊙ up
122        let hidden: Vec<f32> = gate
123            .iter()
124            .zip(up.iter())
125            .map(|(&g, &u)| silu(g) * u)
126            .collect();
127
128        self.down_proj.forward(&hidden)
129    }
130
131    /// Sparse forward pass using predictor
132    pub fn forward_sparse(
133        &self,
134        x: &[f32],
135        active_neurons: &[usize],
136    ) -> Vec<f32> {
137        // Only compute for active neurons in intermediate layer
138        let gate = sparse_matmul(&self.gate_proj, x, active_neurons);
139        let up = sparse_matmul(&self.up_proj, x, active_neurons);
140
141        // SwiGLU on active neurons only
142        let hidden: Vec<f32> = gate
143            .iter()
144            .zip(up.iter())
145            .map(|(&g, &u)| silu(g) * u)
146            .collect();
147
148        // Sparse down projection
149        sparse_matmul_full(&self.down_proj, &hidden, active_neurons)
150    }
151}
152
153impl ModelRunner for LlamaModel {
154    fn forward(&self, input: &ModelInput, config: &InferenceConfig) -> Result<ModelOutput> {
155        // Embed tokens
156        let mut hidden_states = self.embed_tokens.forward(&input.input_ids);
157
158        let mut all_hidden_states = if config.output_hidden_states {
159            Some(Vec::new())
160        } else {
161            None
162        };
163
164        // Process each layer
165        for (idx, layer) in self.layers.iter().enumerate() {
166            if let Some(ref mut states) = all_hidden_states {
167                states.push(hidden_states.clone());
168            }
169
170            // Layer norm
171            let normed = layer.input_layernorm.forward(&hidden_states);
172
173            // Self-attention (simplified, no KV cache)
174            let attn_output = layer.self_attn.forward(&normed);
175
176            // Residual
177            hidden_states = add_vectors(&hidden_states, &attn_output);
178
179            // Post-attention norm
180            let normed = layer.post_attention_layernorm.forward(&hidden_states);
181
182            // MLP with optional sparsity
183            let mlp_output = if config.use_sparse_ffn {
184                if let Some(ref predictor) = layer.predictor {
185                    let k = config.active_neurons_per_layer.unwrap_or(
186                        (self.metadata.intermediate_size as f32 * (1.0 - config.sparsity)) as usize,
187                    );
188                    let active = predictor.predict_active(&normed, k);
189                    layer.mlp.forward_sparse(&normed, &active)
190                } else {
191                    layer.mlp.forward(&normed)
192                }
193            } else {
194                layer.mlp.forward(&normed)
195            };
196
197            // Residual
198            hidden_states = add_vectors(&hidden_states, &mlp_output);
199        }
200
201        // Final norm
202        hidden_states = self.norm.forward(&hidden_states);
203
204        // LM head
205        let logits = if let Some(ref lm_head) = self.lm_head {
206            lm_head.forward(&hidden_states)
207        } else {
208            hidden_states
209        };
210
211        Ok(ModelOutput::new(logits).with_hidden_states(all_hidden_states.unwrap_or_default()))
212    }
213
214    fn get_predictor(&self, layer_idx: usize) -> Option<&LowRankPredictor> {
215        self.layers.get(layer_idx)?.predictor.as_ref()
216    }
217
218    fn calibrate(&mut self, samples: &[ModelInput]) -> Result<CalibrationStats> {
219        // Placeholder: would collect activation statistics
220        Ok(CalibrationStats {
221            num_samples: samples.len(),
222            average_sparsity: 0.9,
223            layer_stats: HashMap::new(),
224        })
225    }
226
227    fn metadata(&self) -> &ModelMetadata {
228        &self.metadata
229    }
230}
231
232impl LlamaAttention {
233    pub fn forward(&self, hidden_states: &[f32]) -> Vec<f32> {
234        // Simplified: full attention without KV cache
235        let q = self.q_proj.forward(hidden_states);
236        let k = self.k_proj.forward(hidden_states);
237        let v = self.v_proj.forward(hidden_states);
238
239        // Placeholder: would do scaled dot-product attention
240        self.o_proj.forward(&q)
241    }
242}
243
244// ============================================================================
245// LFM2 Model (Liquid AI)
246// ============================================================================
247
248pub struct LFM2Model {
249    pub metadata: ModelMetadata,
250    pub embedding: Embedding,
251    pub layers: Vec<LFM2Layer>,
252    pub pooler: Option<Pooler>,
253}
254
255pub struct LFM2Layer {
256    pub gated_conv: GatedConv1d,
257    pub attention: GroupedQueryAttention,
258    pub ffn: SparseFfn,
259    pub norm: LayerNorm,
260}
261
262pub struct GatedConv1d {
263    pub weight: Vec<Vec<f32>>,
264    pub gate: Linear,
265}
266
267pub struct GroupedQueryAttention {
268    pub q_proj: Linear,
269    pub k_proj: Linear,
270    pub v_proj: Linear,
271    pub o_proj: Linear,
272    pub num_groups: usize,
273}
274
275pub struct SparseFfn {
276    pub w1: Linear,
277    pub w2: Linear,
278    pub predictor: Option<LowRankPredictor>,
279}
280
281impl ModelRunner for LFM2Model {
282    fn forward(&self, input: &ModelInput, config: &InferenceConfig) -> Result<ModelOutput> {
283        let mut hidden = self.embedding.forward(&input.input_ids);
284
285        for layer in &self.layers {
286            // Gated convolution for local context
287            hidden = layer.gated_conv.forward(&hidden);
288
289            // Grouped query attention
290            let attn_out = layer.attention.forward(&hidden);
291            hidden = add_vectors(&hidden, &attn_out);
292
293            // Sparse FFN
294            let ffn_out = layer.ffn.forward(&hidden, config);
295            hidden = add_vectors(&hidden, &ffn_out);
296
297            hidden = layer.norm.forward(&hidden);
298        }
299
300        Ok(ModelOutput::new(hidden))
301    }
302
303    fn get_predictor(&self, layer_idx: usize) -> Option<&LowRankPredictor> {
304        self.layers.get(layer_idx)?.ffn.predictor.as_ref()
305    }
306
307    fn calibrate(&mut self, _samples: &[ModelInput]) -> Result<CalibrationStats> {
308        Ok(CalibrationStats {
309            num_samples: 0,
310            average_sparsity: 0.9,
311            layer_stats: HashMap::new(),
312        })
313    }
314
315    fn metadata(&self) -> &ModelMetadata {
316        &self.metadata
317    }
318}
319
320impl GatedConv1d {
321    pub fn forward(&self, x: &[f32]) -> Vec<f32> {
322        // Simplified convolution
323        x.to_vec()
324    }
325}
326
327impl GroupedQueryAttention {
328    pub fn forward(&self, x: &[f32]) -> Vec<f32> {
329        self.o_proj.forward(x)
330    }
331}
332
333impl SparseFfn {
334    pub fn forward(&self, x: &[f32], config: &InferenceConfig) -> Vec<f32> {
335        if config.use_sparse_ffn {
336            if let Some(ref predictor) = self.predictor {
337                let k = (self.w1.out_features as f32 * (1.0 - config.sparsity)) as usize;
338                let active = predictor.predict_active(x, k);
339                return sparse_matmul_full(&self.w2, &self.w1.forward(x), &active);
340            }
341        }
342        self.w2.forward(&self.w1.forward(x))
343    }
344}
345
346// ============================================================================
347// BERT Model
348// ============================================================================
349
350pub struct BertModel {
351    pub metadata: ModelMetadata,
352    pub embeddings: BertEmbeddings,
353    pub encoder: Vec<BertLayer>,
354    pub pooler: Option<Pooler>,
355}
356
357pub struct BertEmbeddings {
358    pub word_embeddings: Embedding,
359    pub position_embeddings: Embedding,
360    pub token_type_embeddings: Embedding,
361    pub layer_norm: LayerNorm,
362}
363
364pub struct BertLayer {
365    pub attention: MultiHeadAttention,
366    pub intermediate: Linear,
367    pub output: Linear,
368    pub layer_norm1: LayerNorm,
369    pub layer_norm2: LayerNorm,
370}
371
372pub struct MultiHeadAttention {
373    pub q_proj: Linear,
374    pub k_proj: Linear,
375    pub v_proj: Linear,
376    pub o_proj: Linear,
377    pub num_heads: usize,
378}
379
380pub struct Pooler {
381    pub dense: Linear,
382}
383
384impl ModelRunner for BertModel {
385    fn forward(&self, input: &ModelInput, config: &InferenceConfig) -> Result<ModelOutput> {
386        let mut hidden = self.embeddings.forward(&input.input_ids);
387
388        for layer in &self.encoder {
389            let attn_out = layer.attention.forward(&hidden);
390            hidden = layer.layer_norm1.forward(&add_vectors(&hidden, &attn_out));
391
392            let intermediate = layer.intermediate.forward(&hidden);
393            let output = layer.output.forward(&intermediate);
394            hidden = layer.layer_norm2.forward(&add_vectors(&hidden, &output));
395        }
396
397        Ok(ModelOutput::new(hidden))
398    }
399
400    fn get_predictor(&self, _layer_idx: usize) -> Option<&LowRankPredictor> {
401        None
402    }
403
404    fn calibrate(&mut self, _samples: &[ModelInput]) -> Result<CalibrationStats> {
405        Ok(CalibrationStats {
406            num_samples: 0,
407            average_sparsity: 0.0,
408            layer_stats: HashMap::new(),
409        })
410    }
411
412    fn metadata(&self) -> &ModelMetadata {
413        &self.metadata
414    }
415}
416
417impl BertEmbeddings {
418    pub fn forward(&self, input_ids: &[u64]) -> Vec<f32> {
419        self.word_embeddings.forward(input_ids)
420    }
421}
422
423impl MultiHeadAttention {
424    pub fn forward(&self, x: &[f32]) -> Vec<f32> {
425        self.o_proj.forward(x)
426    }
427}
428
429// ============================================================================
430// Unified Model Wrapper
431// ============================================================================
432
433pub enum SparseModel {
434    Llama(LlamaModel),
435    LFM2(LFM2Model),
436    Bert(BertModel),
437}
438
439impl ModelRunner for SparseModel {
440    fn forward(&self, input: &ModelInput, config: &InferenceConfig) -> Result<ModelOutput> {
441        match self {
442            Self::Llama(m) => m.forward(input, config),
443            Self::LFM2(m) => m.forward(input, config),
444            Self::Bert(m) => m.forward(input, config),
445        }
446    }
447
448    fn get_predictor(&self, layer_idx: usize) -> Option<&LowRankPredictor> {
449        match self {
450            Self::Llama(m) => m.get_predictor(layer_idx),
451            Self::LFM2(m) => m.get_predictor(layer_idx),
452            Self::Bert(m) => m.get_predictor(layer_idx),
453        }
454    }
455
456    fn calibrate(&mut self, samples: &[ModelInput]) -> Result<CalibrationStats> {
457        match self {
458            Self::Llama(m) => m.calibrate(samples),
459            Self::LFM2(m) => m.calibrate(samples),
460            Self::Bert(m) => m.calibrate(samples),
461        }
462    }
463
464    fn metadata(&self) -> &ModelMetadata {
465        match self {
466            Self::Llama(m) => m.metadata(),
467            Self::LFM2(m) => m.metadata(),
468            Self::Bert(m) => m.metadata(),
469        }
470    }
471}
472
473// ============================================================================
474// Helper Functions
475// ============================================================================
476
477fn sparse_matmul(linear: &Linear, input: &[f32], active_cols: &[usize]) -> Vec<f32> {
478    let mut output = vec![0.0; active_cols.len()];
479
480    for (out_idx, &col_idx) in active_cols.iter().enumerate() {
481        if col_idx < linear.out_features {
482            for (in_idx, &x) in input.iter().enumerate() {
483                if in_idx < linear.in_features {
484                    output[out_idx] += linear.weight[col_idx][in_idx] * x;
485                }
486            }
487            if let Some(ref bias) = linear.bias {
488                output[out_idx] += bias[col_idx];
489            }
490        }
491    }
492
493    output
494}
495
496fn sparse_matmul_full(linear: &Linear, input: &[f32], active_input_cols: &[usize]) -> Vec<f32> {
497    let mut output = vec![0.0; linear.out_features];
498
499    for out_idx in 0..linear.out_features {
500        for &in_idx in active_input_cols {
501            if in_idx < input.len() && in_idx < linear.in_features {
502                output[out_idx] += linear.weight[out_idx][in_idx] * input[in_idx];
503            }
504        }
505        if let Some(ref bias) = linear.bias {
506            output[out_idx] += bias[out_idx];
507        }
508    }
509
510    output
511}
512
/// Element-wise sum of `a` and `b`; the result is as long as the shorter input.
fn add_vectors(a: &[f32], b: &[f32]) -> Vec<f32> {
    let mut sum = Vec::with_capacity(a.len().min(b.len()));
    for (&x, &y) in a.iter().zip(b) {
        sum.push(x + y);
    }
    sum
}
516
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_low_rank_predictor() {
        let zeroed = LowRankPredictor::new(128, 512, 16);
        let picked = zeroed.predict_active(&[1.0; 128], 10);
        assert_eq!(picked.len(), 10);
    }

    #[test]
    fn test_add_vectors() {
        let sum = add_vectors(&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]);
        assert_eq!(sum, [5.0, 7.0, 9.0]);
    }
}