vsa_optim_rs/prediction/predictor.rs

//! Gradient predictor implementation.
//!
//! Predicts future gradients from history using momentum-based extrapolation.

use std::collections::{HashMap, VecDeque};

use candle_core::{DType, Device, Tensor};

use crate::config::PredictionConfig;
use crate::error::Result;

/// Predict future gradients from history.
///
/// Gradient prediction reduces backward-pass compute by ~80% (4 predicted
/// steps per computed step) while maintaining convergence quality through
/// periodic correction cycles.
///
/// The predictor maintains a history of recent gradients and uses a
/// momentum-based extrapolation to predict future gradients. Corrections
/// are computed as the difference between predicted and actual gradients.
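///
/// After the two-step warm-up (prediction needs at least two history
/// entries), a cycle with `prediction_steps = 4` looks like this
/// (illustrative):
///
/// ```text
/// full pred pred pred pred | full pred pred pred pred | ...
/// ```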
///
/// # Example
///
/// ```ignore
/// use candle_core::Device;
/// use vsa_optim_rs::prediction::GradientPredictor;
/// use vsa_optim_rs::PredictionConfig;
///
/// let shapes = vec![
///     ("layer1.weight".to_string(), vec![64, 128]),
/// ];
/// let mut predictor = GradientPredictor::new(&shapes, PredictionConfig::default(), &Device::Cpu)?;
///
/// // Training loop
/// for step in 0..total_steps {
///     if predictor.should_compute_full() {
///         // loss.backward() - compute full gradients
///         predictor.record_gradient(&gradients)?;
///         predictor.compute_correction(&gradients)?;
///         predictor.apply_correction(&mut gradients)?;
///     } else {
///         let predicted = predictor.predict_gradient()?;
///         // Use predicted gradients for the optimizer step
///     }
/// }
/// ```
pub struct GradientPredictor {
    config: PredictionConfig,
    device: Device,

    /// Gradient history per parameter (circular buffer).
    gradient_history: HashMap<String, VecDeque<Tensor>>,

    /// Original shapes for reconstruction.
    shapes: HashMap<String, Vec<usize>>,

    /// Steps since last full gradient computation.
    steps_since_full: usize,

    /// Total training steps.
    total_steps: usize,

    /// Last predicted gradients.
    last_prediction: HashMap<String, Tensor>,

    /// Accumulated corrections.
    correction_accumulator: HashMap<String, Tensor>,

    /// Recent prediction errors for adaptive prediction.
    prediction_errors: VecDeque<f32>,
}

impl GradientPredictor {
    /// Create a new gradient predictor.
    ///
    /// # Arguments
    ///
    /// * `param_shapes` - List of (name, shape) tuples for parameters
    /// * `config` - Prediction configuration
    /// * `device` - Device for tensor storage
    ///
    /// # Errors
    ///
    /// Returns error if initialization fails.
    pub fn new(
        param_shapes: &[(String, Vec<usize>)],
        config: PredictionConfig,
        device: &Device,
    ) -> Result<Self> {
        let mut gradient_history = HashMap::new();
        let mut shapes = HashMap::new();

        for (name, shape) in param_shapes {
            gradient_history.insert(name.clone(), VecDeque::with_capacity(config.history_size));
            shapes.insert(name.clone(), shape.clone());
        }

        Ok(Self {
            config,
            device: device.clone(),
            gradient_history,
            shapes,
            steps_since_full: 0,
            total_steps: 0,
            last_prediction: HashMap::new(),
            correction_accumulator: HashMap::new(),
            prediction_errors: VecDeque::with_capacity(100),
        })
    }

    /// Check if full gradient computation is needed.
    ///
    /// Full computation is needed:
    /// 1. At the start (fewer than two history entries)
    /// 2. After `prediction_steps` predicted steps (correction cycle)
    /// 3. When prediction quality degrades (mean of the last 10 tracked
    ///    errors exceeds 0.5)
    #[must_use]
    pub fn should_compute_full(&self) -> bool {
        // Need full gradient at start for history
        let any_history = self.gradient_history.values().next();
        if let Some(history) = any_history {
            if history.len() < 2 {
                return true;
            }
        } else {
            return true;
        }

        // Need correction after prediction_steps
        if self.steps_since_full >= self.config.prediction_steps {
            return true;
        }

        // Check if prediction quality is poor
        if self.prediction_errors.len() >= 10 {
            let recent: f32 = self.prediction_errors.iter().rev().take(10).sum::<f32>() / 10.0;
            if recent > 0.5 {
                return true;
            }
        }

        false
    }

    /// Record current gradients to history.
    ///
    /// Called after full gradient computation to update history.
    ///
    /// # Arguments
    ///
    /// * `gradients` - Map of parameter names to gradient tensors
    ///
    /// # Errors
    ///
    /// Returns error if tensor cloning fails.
    pub fn record_gradient(&mut self, gradients: &HashMap<String, Tensor>) -> Result<()> {
        for (name, grad) in gradients {
            if let Some(history) = self.gradient_history.get_mut(name) {
                // Maintain max history size
                if history.len() >= self.config.history_size {
                    history.pop_front();
                }
                history.push_back(grad.clone());
            }
        }

        self.steps_since_full = 0;
        self.total_steps += 1;
        Ok(())
    }

    /// Predict gradients based on history.
    ///
    /// Uses momentum-based extrapolation from gradient history:
    /// ```text
    /// g_pred = g[-1] + momentum * (g[-1] - g[-2])
    /// ```
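    ///
    /// For instance (a scalar sketch, assuming `momentum = 0.9`): with
    /// `g[-2] = 0.20` and `g[-1] = 0.30`, the delta is `0.10`, so the
    /// predicted gradient is `0.30 + 0.9 * 0.10 = 0.39`.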
    ///
    /// # Returns
    ///
    /// Map of parameter names to predicted gradients.
    ///
    /// # Errors
    ///
    /// Returns error if tensor operations fail.
    pub fn predict_gradient(&mut self) -> Result<HashMap<String, Tensor>> {
        let mut predicted = HashMap::new();
        let momentum = self.config.momentum;

        for (name, history) in &self.gradient_history {
            let prediction = match history.len() {
                0 => {
                    // No history, create zeros
                    if let Some(shape) = self.shapes.get(name) {
                        Tensor::zeros(shape.as_slice(), DType::F32, &self.device)?
                    } else {
                        continue;
                    }
                }
                1 => {
                    // Single history entry, use as-is
                    history.back().unwrap().clone()
                }
                _ => {
                    // Momentum-based extrapolation
                    let g_prev = &history[history.len() - 2];
                    let g_curr = history.back().unwrap();

                    // delta = g_curr - g_prev
                    let delta = g_curr.sub(g_prev)?;

                    // g_pred = g_curr + momentum * delta
                    let scaled_delta = (&delta * momentum as f64)?;
                    g_curr.add(&scaled_delta)?
                }
            };

            predicted.insert(name.clone(), prediction);
        }

        self.last_prediction = predicted.clone();
        self.steps_since_full += 1;
        self.total_steps += 1;

        Ok(predicted)
    }

    /// Compute correction between predicted and actual gradients.
    ///
    /// The correction term captures the prediction error and is
    /// accumulated to apply a "catch-up" adjustment.
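    ///
    /// In scalar terms (an illustrative sketch): if the last prediction was
    /// `g_pred = 0.39` but the freshly computed gradient is `g_actual = 0.35`,
    /// the correction is `-0.04`; it is banked in the accumulator and later
    /// folded back in by [`Self::apply_correction`].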
    ///
    /// # Arguments
    ///
    /// * `actual_gradients` - The actual computed gradients
    ///
    /// # Returns
    ///
    /// Map of correction terms, keyed by parameter name.
    ///
    /// # Errors
    ///
    /// Returns error if tensor operations fail.
    pub fn compute_correction(
        &mut self,
        actual_gradients: &HashMap<String, Tensor>,
    ) -> Result<HashMap<String, Tensor>> {
        let mut corrections = HashMap::new();

        for (name, actual) in actual_gradients {
            if let Some(predicted) = self.last_prediction.get(name) {
                // Correction = actual - predicted
                let correction = actual.sub(predicted)?;

                // Track prediction error
                let error = correction
                    .abs()?
                    .mean_all()?
                    .to_scalar::<f32>()?;

                if self.prediction_errors.len() >= 100 {
                    self.prediction_errors.pop_front();
                }
                self.prediction_errors.push_back(error);

                // Accumulate for later application
                if let Some(existing) = self.correction_accumulator.get(name) {
                    self.correction_accumulator
                        .insert(name.clone(), existing.add(&correction)?);
                } else {
                    self.correction_accumulator
                        .insert(name.clone(), correction.clone());
                }

                corrections.insert(name.clone(), correction);
            }
        }

        Ok(corrections)
    }

    /// Apply accumulated corrections to gradients.
    ///
    /// After computing full gradients, adds the accumulated
    /// correction to account for prediction errors from previous
    /// predicted steps.
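    ///
    /// Per parameter, the update is:
    ///
    /// ```text
    /// g_corrected = g + correction_weight * accumulated_correction
    /// ```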
    ///
    /// # Arguments
    ///
    /// * `gradients` - Mutable map of gradients to modify in-place
    ///
    /// # Errors
    ///
    /// Returns error if tensor operations fail.
    pub fn apply_correction(
        &mut self,
        gradients: &mut HashMap<String, Tensor>,
    ) -> Result<()> {
        let weight = self.config.correction_weight;

        for (name, grad) in gradients.iter_mut() {
            if let Some(correction) = self.correction_accumulator.get(name) {
                // Add weighted correction: grad += weight * correction
                let scaled = (correction * weight as f64)?;
                *grad = grad.add(&scaled)?;
            }
        }

        // Clear accumulator
        self.correction_accumulator.clear();
        Ok(())
    }

    /// Get prediction statistics.
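    ///
    /// The reported `prediction_ratio` is the steady-state fraction of steps
    /// served by prediction, `1 - 1 / (prediction_steps + 1)`; for example,
    /// `prediction_steps = 4` yields 0.8 (four predicted steps per full one).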
    #[must_use]
    #[allow(clippy::cast_precision_loss)]
    pub fn get_stats(&self) -> PredictorStats {
        let mean_error = if !self.prediction_errors.is_empty() {
            self.prediction_errors.iter().sum::<f32>() / self.prediction_errors.len() as f32
        } else {
            0.0
        };

        let recent_error = if self.prediction_errors.len() >= 10 {
            self.prediction_errors.iter().rev().take(10).sum::<f32>() / 10.0
        } else if !self.prediction_errors.is_empty() {
            self.prediction_errors.iter().sum::<f32>() / self.prediction_errors.len() as f32
        } else {
            0.0
        };

        let prediction_ratio = 1.0 - (1.0 / (self.config.prediction_steps + 1) as f32);

        PredictorStats {
            total_steps: self.total_steps,
            prediction_ratio,
            mean_error,
            recent_error,
            history_size: self.gradient_history.values().next().map_or(0, |h| h.len()),
        }
    }

    /// Get total steps.
    #[must_use]
    pub const fn total_steps(&self) -> usize {
        self.total_steps
    }

    /// Reset predictor state.
    pub fn reset(&mut self) {
        for history in self.gradient_history.values_mut() {
            history.clear();
        }
        self.steps_since_full = 0;
        self.total_steps = 0;
        self.last_prediction.clear();
        self.correction_accumulator.clear();
        self.prediction_errors.clear();
    }
}

/// Prediction statistics.
#[derive(Debug, Clone)]
pub struct PredictorStats {
    /// Total training steps.
    pub total_steps: usize,
    /// Fraction of steps using prediction (0 to 1).
    pub prediction_ratio: f32,
    /// Mean prediction error across all history.
    pub mean_error: f32,
    /// Recent prediction error (last 10 steps).
    pub recent_error: f32,
    /// Current history size.
    pub history_size: usize,
}

impl std::fmt::Display for PredictorStats {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "Steps: {} | Prediction ratio: {:.1}% | Mean error: {:.4} | Recent error: {:.4}",
            self.total_steps,
            self.prediction_ratio * 100.0,
            self.mean_error,
            self.recent_error
        )
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    fn create_param_shapes() -> Vec<(String, Vec<usize>)> {
        vec![
            ("layer1.weight".to_string(), vec![64, 128]),
            ("layer1.bias".to_string(), vec![64]),
        ]
    }

    fn create_mock_gradients(device: &Device) -> HashMap<String, Tensor> {
        let mut gradients = HashMap::new();
        gradients.insert(
            "layer1.weight".to_string(),
            Tensor::randn(0.0f32, 0.1, (64, 128), device).unwrap(),
        );
        gradients.insert(
            "layer1.bias".to_string(),
            Tensor::randn(0.0f32, 0.1, 64, device).unwrap(),
        );
        gradients
    }

    #[test]
    fn test_predictor_creation() {
        let shapes = create_param_shapes();
        let device = Device::Cpu;
        let config = PredictionConfig::default();

        let predictor = GradientPredictor::new(&shapes, config, &device).unwrap();
        assert_eq!(predictor.total_steps(), 0);
    }

    #[test]
    fn test_should_compute_full_initially() {
        let shapes = create_param_shapes();
        let device = Device::Cpu;
        let config = PredictionConfig::default();

        let predictor = GradientPredictor::new(&shapes, config, &device).unwrap();

        // Should compute full at start (no history)
        assert!(predictor.should_compute_full());
    }

    #[test]
    fn test_record_gradient() {
        let shapes = create_param_shapes();
        let device = Device::Cpu;
        let config = PredictionConfig::default();

        let mut predictor = GradientPredictor::new(&shapes, config, &device).unwrap();
        let gradients = create_mock_gradients(&device);

        predictor.record_gradient(&gradients).unwrap();
        assert_eq!(predictor.total_steps(), 1);

        // Still should compute full (need 2+ entries in history)
        assert!(predictor.should_compute_full());

        predictor.record_gradient(&gradients).unwrap();
        assert_eq!(predictor.total_steps(), 2);

        // Now should not require full computation (has history)
        assert!(!predictor.should_compute_full());
    }

    #[test]
    fn test_predict_gradient() {
        let shapes = create_param_shapes();
        let device = Device::Cpu;
        let config = PredictionConfig::default().with_prediction_steps(4);

        let mut predictor = GradientPredictor::new(&shapes, config, &device).unwrap();
        let gradients = create_mock_gradients(&device);

        // Build history
        predictor.record_gradient(&gradients).unwrap();
        predictor.record_gradient(&gradients).unwrap();

        // Now we can predict
        let predicted = predictor.predict_gradient().unwrap();
        assert_eq!(predicted.len(), 2);

        // Check shapes match
        for (name, _shape) in &shapes {
            assert!(predicted.contains_key(name));
        }
    }

    #[test]
    fn test_correction_cycle() {
        let shapes = create_param_shapes();
        let device = Device::Cpu;
        let config = PredictionConfig::default().with_prediction_steps(2);

        let mut predictor = GradientPredictor::new(&shapes, config, &device).unwrap();
        let gradients = create_mock_gradients(&device);

        // Build history
        predictor.record_gradient(&gradients).unwrap();
        predictor.record_gradient(&gradients).unwrap();

        // Predict for 2 steps
        predictor.predict_gradient().unwrap();
        predictor.predict_gradient().unwrap();

        // Now should require full computation (correction cycle)
        assert!(predictor.should_compute_full());
    }

    #[test]
    fn test_compute_correction() {
        let shapes = create_param_shapes();
        let device = Device::Cpu;
        let config = PredictionConfig::default();

        let mut predictor = GradientPredictor::new(&shapes, config, &device).unwrap();
        let gradients = create_mock_gradients(&device);

        // Build history and predict
        predictor.record_gradient(&gradients).unwrap();
        predictor.record_gradient(&gradients).unwrap();
        let _predicted = predictor.predict_gradient().unwrap();

        // Compute correction with actual gradients
        let actual = create_mock_gradients(&device);
        let corrections = predictor.compute_correction(&actual).unwrap();

        assert_eq!(corrections.len(), 2);
    }

    #[test]
    fn test_apply_correction() {
        let shapes = create_param_shapes();
        let device = Device::Cpu;
        let config = PredictionConfig::default().with_correction_weight(0.5);

        let mut predictor = GradientPredictor::new(&shapes, config, &device).unwrap();
        let gradients = create_mock_gradients(&device);

        // Build history and predict
        predictor.record_gradient(&gradients).unwrap();
        predictor.record_gradient(&gradients).unwrap();
        let _predicted = predictor.predict_gradient().unwrap();

        // Compute correction
        let actual = create_mock_gradients(&device);
        predictor.compute_correction(&actual).unwrap();

        // Apply correction
        let mut grads_to_modify = create_mock_gradients(&device);
        predictor.apply_correction(&mut grads_to_modify).unwrap();

        // Correction should be cleared
        assert!(predictor.correction_accumulator.is_empty());
    }

    #[test]
    fn test_stats() {
        let shapes = create_param_shapes();
        let device = Device::Cpu;
        let config = PredictionConfig::default().with_prediction_steps(4);

        let mut predictor = GradientPredictor::new(&shapes, config, &device).unwrap();
        let gradients = create_mock_gradients(&device);

        predictor.record_gradient(&gradients).unwrap();
        predictor.record_gradient(&gradients).unwrap();
        predictor.predict_gradient().unwrap();

        let stats = predictor.get_stats();
        assert_eq!(stats.total_steps, 3);
        assert!(stats.prediction_ratio > 0.7); // 4/(4+1) = 0.8
    }

    #[test]
    fn test_reset() {
        let shapes = create_param_shapes();
        let device = Device::Cpu;
        let config = PredictionConfig::default();

        let mut predictor = GradientPredictor::new(&shapes, config, &device).unwrap();
        let gradients = create_mock_gradients(&device);

        predictor.record_gradient(&gradients).unwrap();
        predictor.record_gradient(&gradients).unwrap();

        assert_eq!(predictor.total_steps(), 2);

        predictor.reset();

        assert_eq!(predictor.total_steps(), 0);
        assert!(predictor.should_compute_full());
    }
}