ruvector_sparse_inference/predictor/
lowrank.rs1use ndarray::{Array1, Array2, Axis};
4use serde::{Deserialize, Serialize};
5use tracing::{debug, trace};
6
7use crate::config::SparsityConfig;
8use crate::error::{PredictorError, Result};
9use super::{Predictor, PredictorStats};
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct LowRankPredictor {
24 p_matrix: Array2<f32>,
26
27 q_matrix: Array2<f32>,
29
30 config: SparsityConfig,
32
33 #[serde(skip)]
35 stats: PredictorStats,
36}
37
38impl LowRankPredictor {
39 pub fn new(
41 input_dim: usize,
42 hidden_dim: usize,
43 rank: usize,
44 config: SparsityConfig,
45 ) -> Result<Self> {
46 if rank == 0 || rank > input_dim.min(hidden_dim) {
47 return Err(PredictorError::InvalidRank(rank).into());
48 }
49
50 config.validate()
51 .map_err(|e| PredictorError::InvalidConfig(e))?;
52
53 use rand::Rng;
55 use rand::distributions::Uniform;
56 use rand::distributions::Distribution;
57
58 let dist = Uniform::new(-0.01f32, 0.01f32);
59 let mut rng = rand::thread_rng();
60
61 let p_data: Vec<f32> = (0..rank * input_dim)
62 .map(|_| dist.sample(&mut rng))
63 .collect();
64 let p_matrix = Array2::from_shape_vec((rank, input_dim), p_data)
65 .map_err(|e| PredictorError::InvalidConfig(e.to_string()))?;
66
67 let q_data: Vec<f32> = (0..hidden_dim * rank)
68 .map(|_| dist.sample(&mut rng))
69 .collect();
70 let q_matrix = Array2::from_shape_vec((hidden_dim, rank), q_data)
71 .map_err(|e| PredictorError::InvalidConfig(e.to_string()))?;
72
73 Ok(Self {
74 p_matrix,
75 q_matrix,
76 config,
77 stats: PredictorStats {
78 is_calibrated: false,
79 ..Default::default()
80 },
81 })
82 }
83
84 pub fn from_matrices(
86 p_matrix: Array2<f32>,
87 q_matrix: Array2<f32>,
88 config: SparsityConfig,
89 ) -> Result<Self> {
90 let (rank, input_dim) = p_matrix.dim();
91 let (hidden_dim, q_rank) = q_matrix.dim();
92
93 if rank != q_rank {
94 return Err(PredictorError::InvalidConfig(
95 format!("Rank mismatch: P has rank {}, Q has rank {}", rank, q_rank)
96 ).into());
97 }
98
99 config.validate()
100 .map_err(|e| PredictorError::InvalidConfig(e))?;
101
102 Ok(Self {
103 p_matrix,
104 q_matrix,
105 config,
106 stats: PredictorStats {
107 is_calibrated: true,
108 ..Default::default()
109 },
110 })
111 }
112
113 pub fn rank(&self) -> usize {
115 self.p_matrix.nrows()
116 }
117
118 pub fn input_dim(&self) -> usize {
120 self.p_matrix.ncols()
121 }
122
123 pub fn hidden_dim(&self) -> usize {
125 self.q_matrix.nrows()
126 }
127
128 fn compute_scores(&self, input: &[f32]) -> Result<Array1<f32>> {
130 if input.len() != self.input_dim() {
131 return Err(PredictorError::DimensionMismatch {
132 expected: self.input_dim(),
133 actual: input.len(),
134 }.into());
135 }
136
137 let input_vec = Array1::from_vec(input.to_vec());
139
140 trace!("Compressing input from {} to {} dimensions", input.len(), self.rank());
142 let compressed = self.p_matrix.dot(&input_vec);
143
144 trace!("Scoring {} neurons", self.hidden_dim());
146 let scores = self.q_matrix.dot(&compressed);
147
148 Ok(scores)
149 }
150
151 fn select_active_neurons(&self, scores: &Array1<f32>) -> Vec<usize> {
153 if let Some(k) = self.config.top_k {
154 self.select_top_k(scores, k)
156 } else if let Some(threshold) = self.config.threshold {
157 self.select_by_threshold(scores, threshold)
159 } else {
160 vec![]
162 }
163 }
164
165 fn select_top_k(&self, scores: &Array1<f32>, k: usize) -> Vec<usize> {
167 let mut indexed_scores: Vec<(usize, f32)> = scores
168 .iter()
169 .enumerate()
170 .map(|(i, &s)| (i, s))
171 .collect();
172
173 let len = indexed_scores.len();
175 if len == 0 {
176 return vec![];
177 }
178
179 indexed_scores.select_nth_unstable_by(
181 k.min(len - 1),
182 |a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
183 );
184
185 indexed_scores.truncate(k);
186 indexed_scores.sort_by_key(|(i, _)| *i);
187 indexed_scores.into_iter().map(|(i, _)| i).collect()
188 }
189
190 fn select_by_threshold(&self, scores: &Array1<f32>, threshold: f32) -> Vec<usize> {
192 scores
193 .iter()
194 .enumerate()
195 .filter(|(_, &s)| s > threshold)
196 .map(|(i, _)| i)
197 .collect()
198 }
199
200 fn update_stats(&mut self, active_count: usize) {
202 self.stats.predictions += 1;
203
204 let n = self.stats.predictions as f32;
205 let prev_avg = self.stats.avg_active_neurons;
206 self.stats.avg_active_neurons =
207 (prev_avg * (n - 1.0) + active_count as f32) / n;
208
209 let sparsity = 1.0 - (active_count as f32 / self.hidden_dim() as f32);
210 let prev_sparsity = self.stats.avg_sparsity;
211 self.stats.avg_sparsity =
212 (prev_sparsity * (n - 1.0) + sparsity) / n;
213 }
214}
215
216impl Predictor for LowRankPredictor {
217 fn predict(&self, input: &[f32]) -> Result<Vec<usize>> {
218 let scores = self.compute_scores(input)?;
219 let active = self.select_active_neurons(&scores);
220
221 trace!("Predicted {} active neurons (sparsity: {:.2}%)",
222 active.len(),
223 100.0 * (1.0 - active.len() as f32 / self.hidden_dim() as f32)
224 );
225
226 Ok(active)
227 }
228
229 fn calibrate(
230 &mut self,
231 samples: &[Vec<f32>],
232 activations: &[Vec<f32>],
233 ) -> Result<()> {
234 if samples.is_empty() || activations.is_empty() {
235 return Err(PredictorError::CalibrationFailed(
236 "Empty samples or activations".to_string()
237 ).into());
238 }
239
240 if samples.len() != activations.len() {
241 return Err(PredictorError::CalibrationFailed(
242 format!("Sample count ({}) != activation count ({})",
243 samples.len(), activations.len())
244 ).into());
245 }
246
247 debug!("Calibrating predictor with {} samples", samples.len());
248
249 let n_samples = samples.len();
251 let input_dim = self.input_dim();
252 let hidden_dim = self.hidden_dim();
253
254 let mut x_data = Vec::with_capacity(n_samples * input_dim);
256 for sample in samples {
257 if sample.len() != input_dim {
258 return Err(PredictorError::DimensionMismatch {
259 expected: input_dim,
260 actual: sample.len(),
261 }.into());
262 }
263 x_data.extend_from_slice(sample);
264 }
265 let x = Array2::from_shape_vec((n_samples, input_dim), x_data)
266 .map_err(|e| PredictorError::CalibrationFailed(e.to_string()))?;
267
268 let mut y_data = Vec::with_capacity(n_samples * hidden_dim);
270 for activation in activations {
271 if activation.len() != hidden_dim {
272 return Err(PredictorError::DimensionMismatch {
273 expected: hidden_dim,
274 actual: activation.len(),
275 }.into());
276 }
277 y_data.extend_from_slice(activation);
278 }
279 let y = Array2::from_shape_vec((n_samples, hidden_dim), y_data)
280 .map_err(|e| PredictorError::CalibrationFailed(e.to_string()))?;
281
282 let c = x.t().dot(&y) / (n_samples as f32);
292
293 self.stats.is_calibrated = true;
297 debug!("Calibration complete");
298
299 Ok(())
300 }
301
302 fn stats(&self) -> PredictorStats {
303 self.stats.clone()
304 }
305}
306
#[cfg(test)]
mod tests {
    use super::*;

    /// Hand-built rank-2 predictor with `P = I₂` and
    /// `Q = [[1, 0], [0, 1], [-1, -1]]`. The input `[a, b]` therefore scores
    /// the three neurons as `[a, b, -(a + b)]`.
    fn fixed_predictor(config: SparsityConfig) -> LowRankPredictor {
        let p = Array2::from_shape_vec((2, 2), vec![1.0, 0.0, 0.0, 1.0]).unwrap();
        let q = Array2::from_shape_vec((3, 2), vec![1.0, 0.0, 0.0, 1.0, -1.0, -1.0]).unwrap();
        LowRankPredictor::from_matrices(p, q, config).unwrap()
    }

    #[test]
    fn test_predictor_creation() {
        let config = SparsityConfig::with_top_k(100);
        let predictor = LowRankPredictor::new(128, 512, 64, config).unwrap();

        assert_eq!(predictor.input_dim(), 128);
        assert_eq!(predictor.hidden_dim(), 512);
        assert_eq!(predictor.rank(), 64);
        assert!(!predictor.stats().is_calibrated);
    }

    #[test]
    fn test_invalid_rank_rejected() {
        assert!(LowRankPredictor::new(128, 512, 0, SparsityConfig::with_top_k(10)).is_err());
        assert!(LowRankPredictor::new(128, 512, 129, SparsityConfig::with_top_k(10)).is_err());
    }

    #[test]
    fn test_from_matrices_rank_mismatch() {
        let p = Array2::<f32>::zeros((2, 4));
        let q = Array2::<f32>::zeros((8, 3));
        assert!(LowRankPredictor::from_matrices(p, q, SparsityConfig::with_top_k(2)).is_err());
    }

    #[test]
    fn test_prediction() {
        let config = SparsityConfig::with_top_k(50);
        let predictor = LowRankPredictor::new(128, 512, 64, config).unwrap();

        let active = predictor.predict(&[0.1; 128]).unwrap();

        assert_eq!(active.len(), 50);
        // Indices come back strictly increasing (sorted, no duplicates).
        assert!(active.windows(2).all(|w| w[0] < w[1]));
    }

    #[test]
    fn test_top_k_selects_highest_scores() {
        let input = [1.0, 2.0]; // scores [1, 2, -3]

        let top1 = fixed_predictor(SparsityConfig::with_top_k(1));
        assert_eq!(top1.predict(&input).unwrap(), vec![1]);

        let top2 = fixed_predictor(SparsityConfig::with_top_k(2));
        assert_eq!(top2.predict(&input).unwrap(), vec![0, 1]);

        // Asking for more neurons than exist returns all of them, in index order.
        let top10 = fixed_predictor(SparsityConfig::with_top_k(10));
        assert_eq!(top10.predict(&input).unwrap(), vec![0, 1, 2]);
    }

    #[test]
    fn test_threshold_selection() {
        let config = SparsityConfig::with_threshold(0.0);
        let predictor = LowRankPredictor::new(128, 512, 64, config).unwrap();
        let active = predictor.predict(&[100.0; 128]).unwrap();
        assert!(active.len() <= 512);
        assert!(active.windows(2).all(|w| w[0] < w[1]));

        // Deterministic check: only strictly positive scores pass a 0.0 threshold.
        let predictor = fixed_predictor(SparsityConfig::with_threshold(0.0));
        assert_eq!(predictor.predict(&[1.0, 2.0]).unwrap(), vec![0, 1]); // [1, 2, -3]
        assert_eq!(predictor.predict(&[-1.0, -2.0]).unwrap(), vec![2]); // [-1, -2, 3]
    }

    #[test]
    fn test_calibration_input_validation() {
        let mut predictor =
            LowRankPredictor::new(4, 8, 2, SparsityConfig::with_top_k(2)).unwrap();

        assert!(predictor.calibrate(&[], &[]).is_err());
        assert!(predictor.calibrate(&[vec![0.0; 4]], &[]).is_err());
        assert!(predictor.calibrate(&[vec![0.0; 3]], &[vec![0.0; 8]]).is_err());
        assert!(predictor.calibrate(&[vec![0.0; 4]], &[vec![0.0; 7]]).is_err());
        assert!(!predictor.stats().is_calibrated);

        let samples: Vec<Vec<f32>> = (0..4)
            .map(|i| (0..4).map(|j| if i == j { 1.0 } else { 0.0 }).collect())
            .collect();
        let activations: Vec<Vec<f32>> = (0..4)
            .map(|i| (0..8).map(|j| (i * 8 + j) as f32).collect())
            .collect();
        predictor.calibrate(&samples, &activations).unwrap();

        assert!(predictor.stats().is_calibrated);
        assert_eq!(predictor.predict(&[1.0, 0.0, 0.0, 0.0]).unwrap().len(), 2);
    }

    #[test]
    fn test_dimension_mismatch() {
        let config = SparsityConfig::with_top_k(50);
        let predictor = LowRankPredictor::new(128, 512, 64, config).unwrap();

        let result = predictor.predict(&[0.1; 64]);

        assert!(result.is_err());
    }
}