use std::collections::HashMap;

use candle_core::{DType, Device, Tensor};

use crate::error::{OptimError, Result};

/// Configuration for the deterministic gradient predictor.
#[derive(Debug, Clone)]
pub struct DeterministicPredictionConfig {
    /// Number of full gradient steps to record before predictions are allowed.
    pub warmup_steps: usize,

    /// Number of recent gradient snapshots kept per parameter for model fitting.
    pub history_window: usize,

    /// Maximum number of consecutive predicted steps before a correction is required.
    pub prediction_horizon: usize,

    /// Exponential decay applied to older history entries and to residual feedback.
    pub history_decay: f32,

    /// Absolute residual above which a correction step is requested.
    pub residual_threshold: f32,
}

impl Default for DeterministicPredictionConfig {
    fn default() -> Self {
        Self {
            warmup_steps: 10,
            history_window: 8,
            prediction_horizon: 4,
            history_decay: 0.95,
            residual_threshold: 0.5,
        }
    }
}

impl DeterministicPredictionConfig {
    /// Sets the number of warmup steps.
    #[must_use]
    pub const fn with_warmup_steps(mut self, steps: usize) -> Self {
        self.warmup_steps = steps;
        self
    }

    /// Sets the history window size.
    #[must_use]
    pub const fn with_history_window(mut self, window: usize) -> Self {
        self.history_window = window;
        self
    }

    /// Sets the prediction horizon.
    #[must_use]
    pub const fn with_prediction_horizon(mut self, horizon: usize) -> Self {
        self.prediction_horizon = horizon;
        self
    }

    /// Sets the history decay factor.
    #[must_use]
    pub const fn with_history_decay(mut self, decay: f32) -> Self {
        self.history_decay = decay;
        self
    }
}

/// A single recorded gradient and the step at which it was observed.
#[derive(Clone)]
struct GradientSnapshot {
    /// Global step at which this gradient was recorded.
    step: usize,
    /// The recorded gradient tensor.
    gradient: Tensor,
}

/// Per-parameter linear model `g(t) ≈ baseline + velocity * t`.
#[derive(Clone)]
struct LinearGradientModel {
    /// Gradient estimate at the reference (fit) step.
    baseline: Tensor,
    /// Estimated per-step change of the gradient.
    velocity: Tensor,
    /// Global step at which the model was fitted.
    fit_step: usize,
}

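/// Deterministic gradient predictor based on per-parameter linear
/// extrapolation with decayed residual feedback.
///
/// Typical drive loop (a minimal sketch; `real_gradients()` is a placeholder
/// for the caller's own backward pass and is not part of this module):
///
/// ```ignore
/// let mut predictor = DeterministicPredictor::new(&shapes, config, &device)?;
///
/// loop {
///     if predictor.in_warmup() || predictor.needs_correction() {
///         // Full step: compute real gradients and feed them back.
///         let grads = real_gradients();
///         let is_correction = !predictor.in_warmup();
///         predictor.record_gradient(&grads, is_correction)?;
///     } else {
///         // Cheap step: reuse the fitted linear models.
///         let grads = predictor.predict_gradient()?;
///         // ...apply `grads` to the optimizer state...
///     }
/// }
/// ```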
pub struct DeterministicPredictor {
    config: DeterministicPredictionConfig,
    device: Device,

    /// Shapes of all tracked parameters, keyed by name.
    shapes: HashMap<String, Vec<usize>>,

    /// Recent gradient snapshots per parameter.
    history: HashMap<String, Vec<GradientSnapshot>>,

    /// Fitted linear models per parameter.
    models: HashMap<String, LinearGradientModel>,

    /// Exponential moving average of prediction errors per parameter.
    residuals: HashMap<String, Tensor>,

    /// Total number of recorded and predicted steps.
    global_step: usize,

    /// Number of predicted steps since the models were last refitted.
    steps_since_fit: usize,

    /// Whether enough history has accumulated to fit the models.
    warmup_complete: bool,

    /// Accumulated usage and accuracy statistics.
    stats: PredictorStatistics,
}

/// Usage and accuracy statistics for a [`DeterministicPredictor`].
#[derive(Debug, Clone, Default)]
pub struct PredictorStatistics {
    /// Total number of steps (full plus predicted).
    pub total_steps: usize,
    /// Steps that used a real gradient.
    pub full_steps: usize,
    /// Steps that used a predicted gradient.
    pub predicted_steps: usize,
    /// Smoothed mean absolute prediction error.
    pub mean_abs_error: f32,
    /// Largest residual observed.
    pub max_residual: f32,
    /// Corrections triggered before the prediction horizon was reached.
    pub early_corrections: usize,
}

impl DeterministicPredictor {
    /// Creates a predictor for the given parameter shapes.
    ///
    /// # Errors
    ///
    /// Returns an error if the zero-initialized residual tensors cannot be
    /// allocated on `device`.
    pub fn new(
        param_shapes: &[(String, Vec<usize>)],
        config: DeterministicPredictionConfig,
        device: &Device,
    ) -> Result<Self> {
        let mut shapes = HashMap::new();
        let mut history = HashMap::new();
        let mut residuals = HashMap::new();

        for (name, shape) in param_shapes {
            shapes.insert(name.clone(), shape.clone());
            history.insert(name.clone(), Vec::with_capacity(config.history_window + 4));
            residuals.insert(
                name.clone(),
                Tensor::zeros(shape.as_slice(), DType::F32, device)?,
            );
        }

        Ok(Self {
            config,
            device: device.clone(),
            shapes,
            history,
            models: HashMap::new(),
            residuals,
            global_step: 0,
            steps_since_fit: 0,
            warmup_complete: false,
            stats: PredictorStatistics::default(),
        })
    }

    /// Returns `true` while the predictor is still collecting warmup gradients.
    #[must_use]
    pub fn in_warmup(&self) -> bool {
        !self.warmup_complete
    }

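    /// Returns `true` when the next step should use a real gradient instead of
    /// a prediction: either the prediction horizon has been exhausted
    /// (`steps_since_fit >= prediction_horizon`) or some parameter's residual
    /// exceeds `residual_threshold` in absolute value. Either condition means
    /// the linear models can no longer be trusted without fresh data.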
    #[must_use]
    pub fn needs_correction(&self) -> bool {
        if self.steps_since_fit >= self.config.prediction_horizon {
            return true;
        }

        for residual in self.residuals.values() {
            // Flatten before reducing so the max is taken over all elements,
            // not just along the first dimension.
            let max_abs = residual
                .abs()
                .and_then(|t| t.flatten_all())
                .and_then(|t| t.max(0))
                .and_then(|t| t.to_scalar::<f32>());
            if let Ok(max_abs) = max_abs {
                if max_abs > self.config.residual_threshold {
                    return true;
                }
            }
        }

        false
    }

    /// Records actual gradients for the current step and refits the linear
    /// models. Set `is_correction` to `true` when this full step was taken to
    /// correct drifted predictions, so that residual estimates are updated
    /// from the observed prediction error.
    ///
    /// # Errors
    ///
    /// Returns an error if a tensor operation fails while updating residuals
    /// or refitting the models.
    pub fn record_gradient(
        &mut self,
        gradients: &HashMap<String, Tensor>,
        is_correction: bool,
    ) -> Result<()> {
        for (name, grad) in gradients {
            if let Some(hist) = self.history.get_mut(name) {
                hist.push(GradientSnapshot {
                    step: self.global_step,
                    gradient: grad.clone(),
                });

                // Trim history, keeping a small slack beyond the configured window.
                let window = self.config.history_window;
                if hist.len() > window + 2 {
                    hist.drain(0..hist.len() - window - 2);
                }
            }
        }

        self.stats.total_steps += 1;
        self.stats.full_steps += 1;

        if !self.warmup_complete {
            let min_history = self.history.values().map(|h| h.len()).min().unwrap_or(0);
            if min_history >= self.config.warmup_steps {
                self.warmup_complete = true;
                self.fit_models()?;
            }
        } else if is_correction {
            self.update_residuals(gradients)?;
            self.fit_models()?;
        } else {
            self.fit_models()?;
        }

        self.global_step += 1;
        self.steps_since_fit = 0;

        Ok(())
    }

    /// Predicts the next gradient for every parameter from the fitted linear
    /// models, plus an exponentially decayed residual correction.
    ///
    /// # Errors
    ///
    /// Returns an error if called during warmup or if a tensor operation
    /// fails.
    pub fn predict_gradient(&mut self) -> Result<HashMap<String, Tensor>> {
        if !self.warmup_complete {
            return Err(OptimError::Prediction(
                "Cannot predict during warmup phase".to_string(),
            ));
        }

        let mut predicted = HashMap::new();

        for (name, model) in &self.models {
            // Linear extrapolation: baseline + velocity * elapsed steps.
            let dt = (self.global_step - model.fit_step) as f64;

            let velocity_term = (&model.velocity * dt)?;
            let mut prediction = model.baseline.add(&velocity_term)?;

            // Residual feedback fades as more steps pass without a refit.
            if let Some(residual) = self.residuals.get(name) {
                let residual_weight = self.config.history_decay.powi(self.steps_since_fit as i32);
                let scaled_residual = (residual * f64::from(residual_weight))?;
                prediction = prediction.add(&scaled_residual)?;
            }

            predicted.insert(name.clone(), prediction);
        }

        self.stats.total_steps += 1;
        self.stats.predicted_steps += 1;
        self.global_step += 1;
        self.steps_since_fit += 1;

        Ok(predicted)
    }

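    /// Folds the error of a correction gradient into the per-parameter
    /// residual estimates as an exponential moving average,
    /// `residual ← decay · residual + (1 − decay) · error`, with
    /// `decay = history_decay`. The mean absolute error statistic is smoothed
    /// the same way with a fixed 0.9/0.1 blend.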
    fn update_residuals(&mut self, actual: &HashMap<String, Tensor>) -> Result<()> {
        for (name, actual_grad) in actual {
            if let Some(model) = self.models.get(name) {
                let dt = (self.global_step - model.fit_step) as f64;
                let velocity_term = (&model.velocity * dt)?;
                let predicted = model.baseline.add(&velocity_term)?;

                let error = actual_grad.sub(&predicted)?;

                if let Some(existing) = self.residuals.get(name) {
                    let decay = f64::from(self.config.history_decay);
                    let decayed_existing = (existing * decay)?;
                    let new_contribution = (&error * (1.0 - decay))?;
                    self.residuals
                        .insert(name.clone(), decayed_existing.add(&new_contribution)?);
                } else {
                    self.residuals.insert(name.clone(), error);
                }

                if let Ok(mean_err) = actual_grad
                    .sub(&predicted)
                    .and_then(|t| t.abs())
                    .and_then(|t| t.mean_all())
                    .and_then(|t| t.to_scalar::<f32>())
                {
                    self.stats.mean_abs_error =
                        0.9 * self.stats.mean_abs_error + 0.1 * mean_err;
                }
            }
        }

        Ok(())
    }

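    /// Fits, for each parameter, a weighted linear model
    /// `g(t) ≈ baseline + velocity * t`, where `t` is measured relative to the
    /// newest snapshot and weights decay geometrically with age
    /// (`w_i = history_decay^age_i`). The closed-form weighted least-squares
    /// solution used below is
    ///
    /// ```text
    /// det      = Σw · Σw·t² − (Σw·t)²
    /// baseline = (Σw·g · Σw·t² − Σw·t·g · Σw·t) / det
    /// velocity = (Σw·t·g · Σw  − Σw·g   · Σw·t) / det
    /// ```
    ///
    /// falling back to a constant model (zero velocity, latest gradient as
    /// baseline) when `det` is numerically singular.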
    fn fit_models(&mut self) -> Result<()> {
        for (name, hist) in &self.history {
            if hist.len() < 2 {
                continue;
            }

            let shape = self.shapes.get(name).ok_or_else(|| {
                OptimError::Prediction(format!("Unknown parameter: {name}"))
            })?;

            // Accumulate weighted sums for the normal equations.
            let n = hist.len();
            let mut sum_w = 0.0f64;
            let mut sum_wt = 0.0f64;
            let mut sum_wt2 = 0.0f64;
            let mut sum_wg: Option<Tensor> = None;
            let mut sum_wtg: Option<Tensor> = None;

            let t_ref = hist.last().map(|s| s.step).unwrap_or(0);

            for (i, snapshot) in hist.iter().enumerate() {
                let age = (n - 1 - i) as i32;
                let w = f64::from(self.config.history_decay.powi(age));

                let t = (snapshot.step as i64 - t_ref as i64) as f64;

                sum_w += w;
                sum_wt += w * t;
                sum_wt2 += w * t * t;

                let wg = (&snapshot.gradient * w)?;
                let wtg = (&snapshot.gradient * (w * t))?;

                sum_wg = Some(match sum_wg {
                    Some(acc) => acc.add(&wg)?,
                    None => wg,
                });

                sum_wtg = Some(match sum_wtg {
                    Some(acc) => acc.add(&wtg)?,
                    None => wtg,
                });
            }

            // Degenerate fit: fall back to a constant model.
            let det = sum_w * sum_wt2 - sum_wt * sum_wt;
            if det.abs() < 1e-10 {
                let baseline = hist.last().unwrap().gradient.clone();
                let velocity = Tensor::zeros(shape.as_slice(), DType::F32, &self.device)?;
                self.models.insert(
                    name.clone(),
                    LinearGradientModel {
                        baseline,
                        velocity,
                        fit_step: self.global_step,
                    },
                );
                continue;
            }

            let sum_wg = sum_wg.ok_or_else(|| {
                OptimError::Prediction("Empty gradient history".to_string())
            })?;
            let sum_wtg = sum_wtg.ok_or_else(|| {
                OptimError::Prediction("Empty gradient history".to_string())
            })?;

            let baseline = {
                let term1 = (&sum_wg * sum_wt2)?;
                let term2 = (&sum_wtg * sum_wt)?;
                let numer = term1.sub(&term2)?;
                (&numer * (1.0 / det))?
            };

            let velocity = {
                let term1 = (&sum_wtg * sum_w)?;
                let term2 = (&sum_wg * sum_wt)?;
                let numer = term1.sub(&term2)?;
                (&numer * (1.0 / det))?
            };

            self.models.insert(
                name.clone(),
                LinearGradientModel {
                    baseline,
                    velocity,
                    fit_step: self.global_step,
                },
            );
        }

        Ok(())
    }

    /// Returns accumulated predictor statistics.
    #[must_use]
    pub fn get_stats(&self) -> &PredictorStatistics {
        &self.stats
    }

    /// Clears all history, models, residuals, and statistics, returning the
    /// predictor to its initial warmup state.
    ///
    /// # Errors
    ///
    /// Returns an error if the residual tensors cannot be reallocated.
    pub fn reset(&mut self) -> Result<()> {
        for hist in self.history.values_mut() {
            hist.clear();
        }
        self.models.clear();

        for (name, shape) in &self.shapes {
            self.residuals.insert(
                name.clone(),
                Tensor::zeros(shape.as_slice(), DType::F32, &self.device)?,
            );
        }

        self.global_step = 0;
        self.steps_since_fit = 0;
        self.warmup_complete = false;
        self.stats = PredictorStatistics::default();

        Ok(())
    }

    /// Returns the total number of steps processed so far.
    #[must_use]
    pub const fn global_step(&self) -> usize {
        self.global_step
    }

    /// Returns `true` once warmup is complete and predictions are available.
    #[must_use]
    pub const fn is_ready(&self) -> bool {
        self.warmup_complete
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    fn create_shapes() -> Vec<(String, Vec<usize>)> {
        vec![
            ("layer.weight".to_string(), vec![16, 32]),
            ("layer.bias".to_string(), vec![16]),
        ]
    }

    #[test]
    fn test_warmup_phase() {
        let config = DeterministicPredictionConfig::default().with_warmup_steps(5);
        let mut predictor =
            DeterministicPredictor::new(&create_shapes(), config, &Device::Cpu).unwrap();

        assert!(predictor.in_warmup());
        assert!(!predictor.is_ready());

        for i in 0..5 {
            let mut grads = HashMap::new();
            grads.insert(
                "layer.weight".to_string(),
                Tensor::ones((16, 32), DType::F32, &Device::Cpu)
                    .unwrap()
                    .affine(i as f64, 0.0)
                    .unwrap(),
            );
            grads.insert(
                "layer.bias".to_string(),
                Tensor::ones(16, DType::F32, &Device::Cpu)
                    .unwrap()
                    .affine(i as f64, 0.0)
                    .unwrap(),
            );
            predictor.record_gradient(&grads, false).unwrap();
        }

        assert!(!predictor.in_warmup());
        assert!(predictor.is_ready());
    }

    #[test]
    fn test_deterministic_prediction() {
        let config = DeterministicPredictionConfig::default()
            .with_warmup_steps(3)
            .with_prediction_horizon(2);
        let device = Device::Cpu;

        let shapes = create_shapes();
        let mut pred1 = DeterministicPredictor::new(&shapes, config.clone(), &device).unwrap();
        let mut pred2 = DeterministicPredictor::new(&shapes, config, &device).unwrap();

        for i in 0..5 {
            let mut grads = HashMap::new();
            grads.insert(
                "layer.weight".to_string(),
                Tensor::ones((16, 32), DType::F32, &device)
                    .unwrap()
                    .affine(1.0 + i as f64 * 0.1, 0.0)
                    .unwrap(),
            );
            grads.insert(
                "layer.bias".to_string(),
                Tensor::ones(16, DType::F32, &device)
                    .unwrap()
                    .affine(1.0 + i as f64 * 0.1, 0.0)
                    .unwrap(),
            );
            pred1.record_gradient(&grads, false).unwrap();
            pred2.record_gradient(&grads, false).unwrap();
        }

        let p1 = pred1.predict_gradient().unwrap();
        let p2 = pred2.predict_gradient().unwrap();

        for (name, t1) in &p1 {
            let t2 = p2.get(name).unwrap();
            let diff: f32 = t1
                .sub(t2)
                .unwrap()
                .abs()
                .unwrap()
                .flatten_all()
                .unwrap()
                .max(0)
                .unwrap()
                .to_scalar()
                .unwrap();
            assert!(
                diff < 1e-6,
                "Predictions should be deterministic, got diff={diff}"
            );
        }
    }

    #[test]
    fn test_linear_fit_quality() {
        let config = DeterministicPredictionConfig::default()
            .with_warmup_steps(5)
            .with_prediction_horizon(3);
        let device = Device::Cpu;
        let shapes = vec![("param".to_string(), vec![8])];

        let mut predictor = DeterministicPredictor::new(&shapes, config, &device).unwrap();

        // Gradients follow g(t) = 1.0 + 0.1 * t for t = 0..4.
        for t in 0..5 {
            let mut grads = HashMap::new();
            grads.insert(
                "param".to_string(),
                Tensor::ones(8, DType::F32, &device)
                    .unwrap()
                    .affine(1.0 + 0.1 * t as f64, 0.0)
                    .unwrap(),
            );
            predictor.record_gradient(&grads, false).unwrap();
        }

        // The next step is t = 5, so the extrapolated gradient should be ~1.5.
        let predicted = predictor.predict_gradient().unwrap();
        let pred_vals: Vec<f32> = predicted.get("param").unwrap().to_vec1().unwrap();

        for v in &pred_vals {
            assert!(
                (*v - 1.5).abs() < 0.1,
                "Linear prediction should be accurate, got {v}"
            );
        }
    }
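
    // A small sketch of the intended drive loop beyond warmup: after the
    // configured horizon of predicted steps, `needs_correction` asks for a
    // real gradient again. Uses the same toy linear gradients as the tests
    // above.
    #[test]
    fn test_horizon_triggers_correction() {
        let config = DeterministicPredictionConfig::default()
            .with_warmup_steps(3)
            .with_prediction_horizon(2);
        let device = Device::Cpu;
        let shapes = vec![("param".to_string(), vec![4])];

        let mut predictor = DeterministicPredictor::new(&shapes, config, &device).unwrap();

        for t in 0..3 {
            let mut grads = HashMap::new();
            grads.insert(
                "param".to_string(),
                Tensor::ones(4, DType::F32, &device)
                    .unwrap()
                    .affine(1.0 + 0.1 * t as f64, 0.0)
                    .unwrap(),
            );
            predictor.record_gradient(&grads, false).unwrap();
        }

        // Freshly fitted: no correction needed yet.
        assert!(!predictor.needs_correction());

        // Two predicted steps exhaust a horizon of 2.
        predictor.predict_gradient().unwrap();
        predictor.predict_gradient().unwrap();
        assert!(predictor.needs_correction());
    }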
}