ruvector_sparse_inference/predictor/
mod.rs

//! Activation predictor module.
//!
//! This module provides predictors that determine which neurons will be active
//! before the full computation is performed.
5
6mod lowrank;
7
8pub use lowrank::LowRankPredictor;
9
10use crate::error::Result;
11
12/// Trait for activation predictors.
13pub trait Predictor: Send + Sync {
14    /// Predict active neurons for the given input.
15    ///
16    /// Returns a vector of neuron indices that are predicted to be active.
17    fn predict(&self, input: &[f32]) -> Result<Vec<usize>>;
18
19    /// Calibrate the predictor using sample data.
20    ///
21    /// # Arguments
22    /// * `samples` - Input samples
23    /// * `activations` - Corresponding activation patterns
24    fn calibrate(
25        &mut self,
26        samples: &[Vec<f32>],
27        activations: &[Vec<f32>],
28    ) -> Result<()>;
29
30    /// Get predictor statistics.
31    fn stats(&self) -> PredictorStats;
32}
33
34/// Alias for backward compatibility.
35pub trait NeuronPredictor: Predictor {}
36
37impl<T: Predictor> NeuronPredictor for T {}
38
39/// Dense predictor that returns all neurons (for baseline comparison).
40pub struct DensePredictor {
41    neuron_count: usize,
42}
43
44impl DensePredictor {
45    /// Create a new dense predictor.
46    pub fn new(neuron_count: usize) -> Self {
47        Self { neuron_count }
48    }
49}
50
51impl Predictor for DensePredictor {
52    fn predict(&self, _input: &[f32]) -> Result<Vec<usize>> {
53        Ok((0..self.neuron_count).collect())
54    }
55
56    fn calibrate(
57        &mut self,
58        _samples: &[Vec<f32>],
59        _activations: &[Vec<f32>],
60    ) -> Result<()> {
61        Ok(())
62    }
63
64    fn stats(&self) -> PredictorStats {
65        PredictorStats {
66            predictions: 0,
67            avg_active_neurons: self.neuron_count as f32,
68            avg_sparsity: 0.0,
69            is_calibrated: true,
70        }
71    }
72}
73
/// Statistics about predictor performance, as returned by
/// [`Predictor::stats`].
///
/// Derives `PartialEq` so snapshots can be compared directly. `Eq` is not
/// derived because the `f32` fields do not support it.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct PredictorStats {
    /// Number of predictions made.
    pub predictions: usize,

    /// Average number of neurons predicted as active.
    pub avg_active_neurons: f32,

    /// Average sparsity ratio, `1 - active / total`; `0.0` means fully dense.
    pub avg_sparsity: f32,

    /// Whether the predictor has been calibrated.
    pub is_calibrated: bool,
}