ruvector_sparse_inference/predictor/
mod.rs1mod lowrank;
7
8pub use lowrank::LowRankPredictor;
9
10use crate::error::Result;
11
12pub trait Predictor: Send + Sync {
14 fn predict(&self, input: &[f32]) -> Result<Vec<usize>>;
18
19 fn calibrate(
25 &mut self,
26 samples: &[Vec<f32>],
27 activations: &[Vec<f32>],
28 ) -> Result<()>;
29
30 fn stats(&self) -> PredictorStats;
32}
33
34pub trait NeuronPredictor: Predictor {}
36
37impl<T: Predictor> NeuronPredictor for T {}
38
/// Baseline predictor that reports every neuron as active.
///
/// It does no prediction work: every input maps to the full index range
/// `0..neuron_count`, so its reported sparsity is zero.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DensePredictor {
    /// Total number of neurons; predictions cover indices `0..neuron_count`.
    neuron_count: usize,
}

impl DensePredictor {
    /// Creates a dense predictor over `neuron_count` neurons.
    #[must_use]
    pub fn new(neuron_count: usize) -> Self {
        Self { neuron_count }
    }
}
50
51impl Predictor for DensePredictor {
52 fn predict(&self, _input: &[f32]) -> Result<Vec<usize>> {
53 Ok((0..self.neuron_count).collect())
54 }
55
56 fn calibrate(
57 &mut self,
58 _samples: &[Vec<f32>],
59 _activations: &[Vec<f32>],
60 ) -> Result<()> {
61 Ok(())
62 }
63
64 fn stats(&self) -> PredictorStats {
65 PredictorStats {
66 predictions: 0,
67 avg_active_neurons: self.neuron_count as f32,
68 avg_sparsity: 0.0,
69 is_calibrated: true,
70 }
71 }
72}
73
/// Summary statistics reported by a predictor's `stats()` method.
///
/// `PartialEq` is derived so snapshots can be compared directly, e.g. in tests.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct PredictorStats {
    /// Number of predictions recorded.
    /// NOTE(review): presumably counts `predict` calls; `DensePredictor`
    /// always reports 0.
    pub predictions: usize,

    /// Average number of neurons predicted active
    /// (`DensePredictor` reports its full neuron count).
    pub avg_active_neurons: f32,

    /// Average sparsity of the predictions.
    /// NOTE(review): presumably the fraction of neurons predicted inactive;
    /// a dense predictor reports 0.0. Confirm against sparse implementations.
    pub avg_sparsity: f32,

    /// Whether the predictor reports itself as calibrated
    /// (`DensePredictor` always reports `true`).
    pub is_calibrated: bool,
}