vsa_optim_rs/prediction/deterministic.rs

//! Deterministic gradient prediction for phase-based training.
//!
//! This module implements a deterministic prediction system that models
//! gradient evolution during training. Unlike stochastic extrapolation,
//! this approach guarantees reproducible predictions given the same history.
//!
//! # Algorithm
//!
//! The predictor fits a linear regression model to the gradient trajectory:
//!
//! ```text
//! g(t) = g(0) + α * t + residual(t)
//! ```
//!
//! Where:
//! - `g(t)` is the gradient at step t
//! - `g(0)` is the baseline gradient (from warmup)
//! - `α` is the fitted gradient velocity (change per step)
//! - `residual(t)` accumulates prediction errors for correction
//!
//! # Phases
//!
//! 1. **Warmup**: `warmup_steps` steps of full gradient computation to establish a baseline
//! 2. **Predict**: Extrapolate gradients using the fitted model
//! 3. **Correct**: Compare prediction with actual, update model and residuals
//!
//! # Determinism
//!
//! Predictions are fully deterministic because:
//! - No random sampling or stochastic operations
//! - Model fit uses closed-form least squares (no iterative optimization)
//! - Same history always produces same prediction
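//!
//! # Example
//!
//! A minimal sketch of the phase loop; `param_shapes`, `device`, `total_steps`,
//! and `compute_gradients` stand in for the caller's own setup and backward
//! pass and are not part of this module.
//!
//! ```ignore
//! let config = DeterministicPredictionConfig::default();
//! let mut predictor = DeterministicPredictor::new(&param_shapes, config, &device)?;
//!
//! for _step in 0..total_steps {
//!     let grads = if predictor.in_warmup() || predictor.needs_correction() {
//!         // Full backward pass; after warmup this acts as a correction step.
//!         let full = compute_gradients()?;
//!         predictor.record_gradient(&full, !predictor.in_warmup())?;
//!         full
//!     } else {
//!         // Cheap, deterministic extrapolation from the fitted linear model.
//!         predictor.predict_gradient()?
//!     };
//!     // apply `grads` via the optimizer ...
//! }
//! ```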

use std::collections::HashMap;

use candle_core::{DType, Device, Tensor};

use crate::error::{OptimError, Result};

/// Configuration for deterministic gradient prediction.
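///
/// # Example
///
/// A minimal sketch of builder-style configuration; the values shown are
/// illustrative, not recommended settings.
///
/// ```ignore
/// let config = DeterministicPredictionConfig::default()
///     .with_warmup_steps(20)
///     .with_history_window(12)
///     .with_prediction_horizon(6)
///     .with_history_decay(0.9);
/// ```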
#[derive(Debug, Clone)]
pub struct DeterministicPredictionConfig {
    /// Minimum steps of full training before prediction begins.
    pub warmup_steps: usize,

    /// Number of history steps to use for fitting.
    pub history_window: usize,

    /// Steps to predict before correction.
    pub prediction_horizon: usize,

    /// Exponential decay for older history (1.0 = no decay).
    pub history_decay: f32,

    /// Threshold for residual magnitude to trigger early correction.
    pub residual_threshold: f32,
}

impl Default for DeterministicPredictionConfig {
    fn default() -> Self {
        Self {
            warmup_steps: 10,
            history_window: 8,
            prediction_horizon: 4,
            history_decay: 0.95,
            residual_threshold: 0.5,
        }
    }
}

impl DeterministicPredictionConfig {
    /// Builder: Set warmup steps.
    #[must_use]
    pub const fn with_warmup_steps(mut self, steps: usize) -> Self {
        self.warmup_steps = steps;
        self
    }

    /// Builder: Set history window.
    #[must_use]
    pub const fn with_history_window(mut self, window: usize) -> Self {
        self.history_window = window;
        self
    }

    /// Builder: Set prediction horizon.
    #[must_use]
    pub const fn with_prediction_horizon(mut self, horizon: usize) -> Self {
        self.prediction_horizon = horizon;
        self
    }

    /// Builder: Set history decay.
    #[must_use]
    pub const fn with_history_decay(mut self, decay: f32) -> Self {
        self.history_decay = decay;
        self
    }
}

/// Gradient history entry with step index.
#[derive(Clone)]
struct GradientSnapshot {
    /// Global step index when this gradient was recorded.
    step: usize,
    /// The gradient tensor.
    gradient: Tensor,
}

/// Linear model for gradient evolution: g(t) = baseline + velocity * t
#[derive(Clone)]
struct LinearGradientModel {
    /// Baseline gradient (intercept).
    baseline: Tensor,
    /// Gradient velocity (slope) per step.
    velocity: Tensor,
    /// Step index at which model was fitted.
    fit_step: usize,
}

/// Deterministic gradient predictor.
///
/// Maintains gradient history and fits a linear model to predict
/// future gradients deterministically.
pub struct DeterministicPredictor {
    config: DeterministicPredictionConfig,
    device: Device,

    /// Parameter shapes for reconstruction.
    shapes: HashMap<String, Vec<usize>>,

    /// Gradient history per parameter.
    history: HashMap<String, Vec<GradientSnapshot>>,

    /// Fitted linear models per parameter.
    models: HashMap<String, LinearGradientModel>,

    /// Accumulated residuals (prediction errors) per parameter.
    residuals: HashMap<String, Tensor>,

    /// Current global step.
    global_step: usize,

    /// Steps since last model fit.
    steps_since_fit: usize,

    /// Whether warmup is complete.
    warmup_complete: bool,

    /// Statistics tracking.
    stats: PredictorStatistics,
}

/// Statistics for prediction quality monitoring.
#[derive(Debug, Clone, Default)]
pub struct PredictorStatistics {
    /// Total steps processed.
    pub total_steps: usize,
    /// Steps with full gradient computation.
    pub full_steps: usize,
    /// Steps with predicted gradients.
    pub predicted_steps: usize,
    /// Mean absolute prediction error.
    pub mean_abs_error: f32,
    /// Maximum observed residual.
    pub max_residual: f32,
    /// Number of early corrections triggered.
    pub early_corrections: usize,
}

impl DeterministicPredictor {
    /// Create a new deterministic predictor.
    ///
    /// # Arguments
    ///
    /// * `param_shapes` - List of (name, shape) tuples for parameters
    /// * `config` - Prediction configuration
    /// * `device` - Device for tensor storage
    ///
    /// # Errors
    ///
    /// Returns an error if the zero-initialized residual tensors cannot be
    /// allocated on `device`.
    pub fn new(
        param_shapes: &[(String, Vec<usize>)],
        config: DeterministicPredictionConfig,
        device: &Device,
    ) -> Result<Self> {
        let mut shapes = HashMap::new();
        let mut history = HashMap::new();
        let mut residuals = HashMap::new();

        for (name, shape) in param_shapes {
            shapes.insert(name.clone(), shape.clone());
            history.insert(name.clone(), Vec::with_capacity(config.history_window + 4));
            // Initialize residuals to zero
            residuals.insert(
                name.clone(),
                Tensor::zeros(shape.as_slice(), DType::F32, device)?,
            );
        }

        Ok(Self {
            config,
            device: device.clone(),
            shapes,
            history,
            models: HashMap::new(),
            residuals,
            global_step: 0,
            steps_since_fit: 0,
            warmup_complete: false,
            stats: PredictorStatistics::default(),
        })
    }

    /// Check if still in warmup phase (must compute full gradients).
    #[must_use]
    pub fn in_warmup(&self) -> bool {
        !self.warmup_complete
    }

    /// Check if correction is needed based on residual magnitude.
    #[must_use]
    pub fn needs_correction(&self) -> bool {
        // Need correction after prediction horizon
        if self.steps_since_fit >= self.config.prediction_horizon {
            return true;
        }

        // Check residual threshold
        for residual in self.residuals.values() {
            // Flatten so the max reduction yields a scalar for any tensor rank.
            let max_abs = residual
                .abs()
                .and_then(|t| t.flatten_all())
                .and_then(|t| t.max(0))
                .and_then(|t| t.to_scalar::<f32>());
            if let Ok(max_abs) = max_abs {
                if max_abs > self.config.residual_threshold {
                    return true;
                }
            }
        }

        false
    }

    /// Record a gradient from full computation.
    ///
    /// Updates history and potentially refits the prediction model.
    ///
    /// # Arguments
    ///
    /// * `gradients` - Map of parameter names to gradient tensors
    /// * `is_correction` - Whether this is a correction step (vs. regular full step)
    pub fn record_gradient(
        &mut self,
        gradients: &HashMap<String, Tensor>,
        is_correction: bool,
    ) -> Result<()> {
        // Record to history
        for (name, grad) in gradients {
            if let Some(hist) = self.history.get_mut(name) {
                hist.push(GradientSnapshot {
                    step: self.global_step,
                    gradient: grad.clone(),
                });

                // Trim history to window size
                let window = self.config.history_window;
                if hist.len() > window + 2 {
                    hist.drain(0..hist.len() - window - 2);
                }
            }
        }

        // Update statistics
        self.stats.total_steps += 1;
        self.stats.full_steps += 1;

        // Check warmup completion
        if !self.warmup_complete {
            let min_history = self.history.values().map(|h| h.len()).min().unwrap_or(0);
            if min_history >= self.config.warmup_steps {
                self.warmup_complete = true;
                self.fit_models()?;
            }
        } else if is_correction {
            // Update residuals based on prediction error
            self.update_residuals(gradients)?;
            // Refit model with new data
            self.fit_models()?;
        } else {
            // Regular full step - refit model
            self.fit_models()?;
        }

        self.global_step += 1;
        self.steps_since_fit = 0;

        Ok(())
    }

    /// Predict gradient for current step.
    ///
    /// Uses the fitted linear model to extrapolate from history.
    /// Prediction is fully deterministic.
    ///
    /// # Returns
    ///
    /// Map of parameter names to predicted gradient tensors.
    ///
    /// # Errors
    ///
    /// Returns an error if called before the warmup phase is complete.
    pub fn predict_gradient(&mut self) -> Result<HashMap<String, Tensor>> {
        if !self.warmup_complete {
            return Err(OptimError::Prediction(
                "Cannot predict during warmup phase".to_string(),
            ));
        }

        let mut predicted = HashMap::new();

        for (name, model) in &self.models {
            // Steps since model was fitted
            let dt = (self.global_step - model.fit_step) as f64;

            // Linear prediction: g(t) = baseline + velocity * dt
            let velocity_term = (&model.velocity * dt)?;
            let mut prediction = model.baseline.add(&velocity_term)?;

            // Add accumulated residual correction
            if let Some(residual) = self.residuals.get(name) {
                // Apply weighted residual (decays with steps since correction)
                let residual_weight = self.config.history_decay.powi(self.steps_since_fit as i32);
                let scaled_residual = (residual * residual_weight as f64)?;
                prediction = prediction.add(&scaled_residual)?;
            }

            predicted.insert(name.clone(), prediction);
        }

        // Update statistics
        self.stats.total_steps += 1;
        self.stats.predicted_steps += 1;
        self.global_step += 1;
        self.steps_since_fit += 1;

        Ok(predicted)
    }

    /// Update residuals based on prediction error.
    ///
    /// Called during correction step to accumulate the difference
    /// between predicted and actual gradients.
    fn update_residuals(&mut self, actual: &HashMap<String, Tensor>) -> Result<()> {
        for (name, actual_grad) in actual {
            if let Some(model) = self.models.get(name) {
                // What we predicted for this step
                let dt = (self.global_step - model.fit_step) as f64;
                let velocity_term = (&model.velocity * dt)?;
                let predicted = model.baseline.add(&velocity_term)?;

                // Residual = actual - predicted
                let error = actual_grad.sub(&predicted)?;

                // Track statistics before the error tensor is consumed below.
                if let Ok(mean_err) = error
                    .abs()
                    .and_then(|t| t.mean_all())
                    .and_then(|t| t.to_scalar::<f32>())
                {
                    self.stats.mean_abs_error =
                        0.9 * self.stats.mean_abs_error + 0.1 * mean_err;
                }

                // Update accumulated residual with exponential averaging
                if let Some(existing) = self.residuals.get(name) {
                    let decay = self.config.history_decay as f64;
                    let decayed_existing = (existing * decay)?;
                    let new_contribution = (&error * (1.0 - decay))?;
                    self.residuals
                        .insert(name.clone(), decayed_existing.add(&new_contribution)?);
                } else {
                    self.residuals.insert(name.clone(), error);
                }
            }
        }

        Ok(())
    }

    /// Fit linear models to gradient history.
    ///
    /// Uses weighted least squares to fit g(t) = baseline + velocity * t
    /// for each parameter.
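    ///
    /// The weighted normal equations are solved in closed form (the same
    /// expressions appear in the body below as Cramer's rule):
    ///
    /// ```text
    /// det      = sum_w * sum_wt2 - sum_wt^2
    /// baseline = (sum_wt2 * sum_wg - sum_wt * sum_wtg) / det
    /// velocity = (sum_w * sum_wtg - sum_wt * sum_wg) / det
    /// ```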
    fn fit_models(&mut self) -> Result<()> {
        for (name, hist) in &self.history {
            if hist.len() < 2 {
                continue;
            }

            let shape = self.shapes.get(name).ok_or_else(|| {
                OptimError::Prediction(format!("Unknown parameter: {name}"))
            })?;

            // Weighted least squares fitting
            // g(t) = baseline + velocity * t
            // Minimize: sum_i w_i * (g_i - baseline - velocity * t_i)^2

            let n = hist.len();
            let mut sum_w = 0.0f64;
            let mut sum_wt = 0.0f64;
            let mut sum_wt2 = 0.0f64;
            let mut sum_wg: Option<Tensor> = None;
            let mut sum_wtg: Option<Tensor> = None;

            // Reference step for numerical stability
            let t_ref = hist.last().map(|s| s.step).unwrap_or(0);

            for (i, snapshot) in hist.iter().enumerate() {
                // Exponential weight favoring recent gradients
                let age = (n - 1 - i) as i32;
                let w = self.config.history_decay.powi(age) as f64;

                // Relative step index
                let t = (snapshot.step as i64 - t_ref as i64) as f64;

                sum_w += w;
                sum_wt += w * t;
                sum_wt2 += w * t * t;

                // Accumulate weighted gradients
                let wg = (&snapshot.gradient * w)?;
                let wtg = (&snapshot.gradient * (w * t))?;

                sum_wg = Some(match sum_wg {
                    Some(acc) => acc.add(&wg)?,
                    None => wg,
                });

                sum_wtg = Some(match sum_wtg {
                    Some(acc) => acc.add(&wtg)?,
                    None => wtg,
                });
            }

            // Solve normal equations for least squares
            // [sum_w    sum_wt  ] [baseline]   [sum_wg ]
            // [sum_wt   sum_wt2 ] [velocity] = [sum_wtg]

            let det = sum_w * sum_wt2 - sum_wt * sum_wt;
            if det.abs() < 1e-10 {
                // Degenerate case: use latest gradient as baseline, zero velocity
                let baseline = hist.last().unwrap().gradient.clone();
                let velocity = Tensor::zeros(shape.as_slice(), DType::F32, &self.device)?;
                self.models.insert(
                    name.clone(),
                    LinearGradientModel {
                        baseline,
                        velocity,
                        fit_step: self.global_step,
                    },
                );
                continue;
            }

            let sum_wg = sum_wg.ok_or_else(|| {
                OptimError::Prediction("Empty gradient history".to_string())
            })?;
            let sum_wtg = sum_wtg.ok_or_else(|| {
                OptimError::Prediction("Empty gradient history".to_string())
            })?;

            // Cramer's rule
            // baseline = (sum_wt2 * sum_wg - sum_wt * sum_wtg) / det
            // velocity = (sum_w * sum_wtg - sum_wt * sum_wg) / det

            let baseline = {
                let term1 = (&sum_wg * sum_wt2)?;
                let term2 = (&sum_wtg * sum_wt)?;
                let numer = term1.sub(&term2)?;
                (&numer * (1.0 / det))?
            };

            let velocity = {
                let term1 = (&sum_wtg * sum_w)?;
                let term2 = (&sum_wg * sum_wt)?;
                let numer = term1.sub(&term2)?;
                (&numer * (1.0 / det))?
            };

            self.models.insert(
                name.clone(),
                LinearGradientModel {
                    baseline,
                    velocity,
                    fit_step: self.global_step,
                },
            );
        }

        Ok(())
    }

    /// Get prediction statistics.
    #[must_use]
    pub fn get_stats(&self) -> &PredictorStatistics {
        &self.stats
    }

    /// Reset predictor state.
    pub fn reset(&mut self) -> Result<()> {
        for hist in self.history.values_mut() {
            hist.clear();
        }
        self.models.clear();

        // Reset residuals to zero
        for (name, shape) in &self.shapes {
            self.residuals.insert(
                name.clone(),
                Tensor::zeros(shape.as_slice(), DType::F32, &self.device)?,
            );
        }

        self.global_step = 0;
        self.steps_since_fit = 0;
        self.warmup_complete = false;
        self.stats = PredictorStatistics::default();

        Ok(())
    }

    /// Get current global step.
    #[must_use]
    pub const fn global_step(&self) -> usize {
        self.global_step
    }

    /// Check if warmup is complete.
    #[must_use]
    pub const fn is_ready(&self) -> bool {
        self.warmup_complete
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    fn create_shapes() -> Vec<(String, Vec<usize>)> {
        vec![
            ("layer.weight".to_string(), vec![16, 32]),
            ("layer.bias".to_string(), vec![16]),
        ]
    }

    #[test]
    fn test_warmup_phase() {
        let config = DeterministicPredictionConfig::default().with_warmup_steps(5);
        let mut predictor =
            DeterministicPredictor::new(&create_shapes(), config, &Device::Cpu).unwrap();

        assert!(predictor.in_warmup());
        assert!(!predictor.is_ready());

        // Record warmup gradients
        for i in 0..5 {
            let mut grads = HashMap::new();
            grads.insert(
                "layer.weight".to_string(),
                Tensor::ones((16, 32), DType::F32, &Device::Cpu)
                    .unwrap()
                    .affine(i as f64, 0.0)
                    .unwrap(),
            );
            grads.insert(
                "layer.bias".to_string(),
                Tensor::ones(16, DType::F32, &Device::Cpu)
                    .unwrap()
                    .affine(i as f64, 0.0)
                    .unwrap(),
            );
            predictor.record_gradient(&grads, false).unwrap();
        }

        assert!(!predictor.in_warmup());
        assert!(predictor.is_ready());
    }

    #[test]
    fn test_deterministic_prediction() {
        let config = DeterministicPredictionConfig::default()
            .with_warmup_steps(3)
            .with_prediction_horizon(2);
        let device = Device::Cpu;

        // Create two identical predictors
        let shapes = create_shapes();
        let mut pred1 = DeterministicPredictor::new(&shapes, config.clone(), &device).unwrap();
        let mut pred2 = DeterministicPredictor::new(&shapes, config, &device).unwrap();

        // Feed identical history
        for i in 0..5 {
            let mut grads = HashMap::new();
            grads.insert(
                "layer.weight".to_string(),
                Tensor::ones((16, 32), DType::F32, &device)
                    .unwrap()
                    .affine(1.0 + i as f64 * 0.1, 0.0)
                    .unwrap(),
            );
            grads.insert(
                "layer.bias".to_string(),
                Tensor::ones(16, DType::F32, &device)
                    .unwrap()
                    .affine(1.0 + i as f64 * 0.1, 0.0)
                    .unwrap(),
            );
            pred1.record_gradient(&grads, false).unwrap();
            pred2.record_gradient(&grads, false).unwrap();
        }

        // Predictions should be identical
        let p1 = pred1.predict_gradient().unwrap();
        let p2 = pred2.predict_gradient().unwrap();

        for (name, t1) in &p1 {
            let t2 = p2.get(name).unwrap();
            let diff: f32 = t1
                .sub(t2)
                .unwrap()
                .abs()
                .unwrap()
                .flatten_all()
                .unwrap()
                .max(0)
                .unwrap()
                .to_scalar()
                .unwrap();
            assert!(
                diff < 1e-6,
                "Predictions should be deterministic, got diff={diff}"
            );
        }
    }

    #[test]
    fn test_linear_fit_quality() {
        // Test that linear model correctly fits linear gradient evolution
        let config = DeterministicPredictionConfig::default()
            .with_warmup_steps(5)
            .with_prediction_horizon(3);
        let device = Device::Cpu;
        let shapes = vec![("param".to_string(), vec![8])];

        let mut predictor = DeterministicPredictor::new(&shapes, config, &device).unwrap();

        // Generate perfectly linear gradients: g(t) = 1 + 0.1*t
        for t in 0..5 {
            let mut grads = HashMap::new();
            grads.insert(
                "param".to_string(),
                Tensor::ones(8, DType::F32, &device)
                    .unwrap()
                    .affine(1.0 + 0.1 * t as f64, 0.0)
                    .unwrap(),
            );
            predictor.record_gradient(&grads, false).unwrap();
        }

        // Predict next step - should be close to 1 + 0.1*5 = 1.5
        let predicted = predictor.predict_gradient().unwrap();
        let pred_vals: Vec<f32> = predicted
            .get("param")
            .unwrap()
            .to_vec1()
            .unwrap();

        // All values should be close to 1.5
        for v in &pred_vals {
            assert!(
                (*v - 1.5).abs() < 0.1,
                "Linear prediction should be accurate, got {v}"
            );
        }
    }
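
    // A minimal extra check sketching how the correction trigger behaves:
    // constant gradients keep residuals at zero, so only the horizon
    // condition can fire `needs_correction`.
    #[test]
    fn test_needs_correction_after_horizon() {
        let config = DeterministicPredictionConfig::default()
            .with_warmup_steps(3)
            .with_prediction_horizon(2);
        let device = Device::Cpu;
        let shapes = vec![("param".to_string(), vec![4])];

        let mut predictor = DeterministicPredictor::new(&shapes, config, &device).unwrap();

        // Warmup with constant gradients.
        for _ in 0..3 {
            let mut grads = HashMap::new();
            grads.insert(
                "param".to_string(),
                Tensor::ones(4, DType::F32, &device).unwrap(),
            );
            predictor.record_gradient(&grads, false).unwrap();
        }
        assert!(predictor.is_ready());
        assert!(!predictor.needs_correction());

        // One predicted step is still within the horizon of 2.
        predictor.predict_gradient().unwrap();
        assert!(!predictor.needs_correction());

        // A second predicted step reaches the horizon and requests correction.
        predictor.predict_gradient().unwrap();
        assert!(predictor.needs_correction());
    }
}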
677}