1use std::collections::{HashMap, VecDeque};
6
7use candle_core::{DType, Device, Tensor};
8
9use crate::config::PredictionConfig;
10use crate::error::Result;
11
/// Extrapolates parameter gradients from recent history so that some steps
/// can reuse a prediction instead of a full gradient computation, with
/// periodic corrections computed from true gradients.
pub struct GradientPredictor {
    /// Tuning knobs read by this type: `history_size`, `prediction_steps`,
    /// `momentum`, and `correction_weight`.
    config: PredictionConfig,
    /// Device on which zero-valued fallback predictions are allocated.
    device: Device,

    /// Per-parameter window of the most recent true gradients, bounded by
    /// `config.history_size` (oldest entries are evicted from the front).
    gradient_history: HashMap<String, VecDeque<Tensor>>,

    /// Shape of each registered parameter; used to build zero predictions
    /// when a parameter has no recorded history yet.
    shapes: HashMap<String, Vec<usize>>,

    /// Number of predicted steps taken since the last recorded full gradient.
    steps_since_full: usize,

    /// Total steps observed (recorded + predicted).
    total_steps: usize,

    /// The most recent prediction, cached so `compute_correction` can score
    /// it against the next true gradient.
    last_prediction: HashMap<String, Tensor>,

    /// Running sum of (actual - predicted) per parameter, awaiting
    /// `apply_correction`, which drains it.
    correction_accumulator: HashMap<String, Tensor>,

    /// Rolling window (capacity 100) of mean-absolute prediction errors.
    prediction_errors: VecDeque<f32>,
}
70
71impl GradientPredictor {
72 pub fn new(
84 param_shapes: &[(String, Vec<usize>)],
85 config: PredictionConfig,
86 device: &Device,
87 ) -> Result<Self> {
88 let mut gradient_history = HashMap::new();
89 let mut shapes = HashMap::new();
90
91 for (name, shape) in param_shapes {
92 gradient_history.insert(name.clone(), VecDeque::with_capacity(config.history_size));
93 shapes.insert(name.clone(), shape.clone());
94 }
95
96 Ok(Self {
97 config,
98 device: device.clone(),
99 gradient_history,
100 shapes,
101 steps_since_full: 0,
102 total_steps: 0,
103 last_prediction: HashMap::new(),
104 correction_accumulator: HashMap::new(),
105 prediction_errors: VecDeque::with_capacity(100),
106 })
107 }
108
109 #[must_use]
116 pub fn should_compute_full(&self) -> bool {
117 let any_history = self.gradient_history.values().next();
119 if let Some(history) = any_history {
120 if history.len() < 2 {
121 return true;
122 }
123 } else {
124 return true;
125 }
126
127 if self.steps_since_full >= self.config.prediction_steps {
129 return true;
130 }
131
132 if self.prediction_errors.len() >= 10 {
134 let recent: f32 = self.prediction_errors.iter().rev().take(10).sum::<f32>() / 10.0;
135 if recent > 0.5 {
136 return true;
137 }
138 }
139
140 false
141 }
142
143 pub fn record_gradient(&mut self, gradients: &HashMap<String, Tensor>) -> Result<()> {
155 for (name, grad) in gradients {
156 if let Some(history) = self.gradient_history.get_mut(name) {
157 if history.len() >= self.config.history_size {
159 history.pop_front();
160 }
161 history.push_back(grad.clone());
162 }
163 }
164
165 self.steps_since_full = 0;
166 self.total_steps += 1;
167 Ok(())
168 }
169
170 pub fn predict_gradient(&mut self) -> Result<HashMap<String, Tensor>> {
185 let mut predicted = HashMap::new();
186 let momentum = self.config.momentum;
187
188 for (name, history) in &self.gradient_history {
189 let prediction = match history.len() {
190 0 => {
191 if let Some(shape) = self.shapes.get(name) {
193 Tensor::zeros(shape.as_slice(), DType::F32, &self.device)?
194 } else {
195 continue;
196 }
197 }
198 1 => {
199 history.back().unwrap().clone()
201 }
202 _ => {
203 let g_prev = &history[history.len() - 2];
205 let g_curr = history.back().unwrap();
206
207 let delta = g_curr.sub(g_prev)?;
209
210 let scaled_delta = (&delta * momentum as f64)?;
212 g_curr.add(&scaled_delta)?
213 }
214 };
215
216 predicted.insert(name.clone(), prediction);
217 }
218
219 self.last_prediction = predicted.clone();
220 self.steps_since_full += 1;
221 self.total_steps += 1;
222
223 Ok(predicted)
224 }
225
226 pub fn compute_correction(
243 &mut self,
244 actual_gradients: &HashMap<String, Tensor>,
245 ) -> Result<HashMap<String, Tensor>> {
246 let mut corrections = HashMap::new();
247
248 for (name, actual) in actual_gradients {
249 if let Some(predicted) = self.last_prediction.get(name) {
250 let correction = actual.sub(predicted)?;
252
253 let error = correction
255 .abs()?
256 .mean_all()?
257 .to_scalar::<f32>()?;
258
259 if self.prediction_errors.len() >= 100 {
260 self.prediction_errors.pop_front();
261 }
262 self.prediction_errors.push_back(error);
263
264 if let Some(existing) = self.correction_accumulator.get(name) {
266 self.correction_accumulator
267 .insert(name.clone(), existing.add(&correction)?);
268 } else {
269 self.correction_accumulator
270 .insert(name.clone(), correction.clone());
271 }
272
273 corrections.insert(name.clone(), correction);
274 }
275 }
276
277 Ok(corrections)
278 }
279
280 pub fn apply_correction(
294 &mut self,
295 gradients: &mut HashMap<String, Tensor>,
296 ) -> Result<()> {
297 let weight = self.config.correction_weight;
298
299 for (name, grad) in gradients.iter_mut() {
300 if let Some(correction) = self.correction_accumulator.get(name) {
301 let scaled = (correction * weight as f64)?;
303 *grad = grad.add(&scaled)?;
304 }
305 }
306
307 self.correction_accumulator.clear();
309 Ok(())
310 }
311
312 #[must_use]
314 #[allow(clippy::cast_precision_loss)]
315 pub fn get_stats(&self) -> PredictorStats {
316 let mean_error = if !self.prediction_errors.is_empty() {
317 self.prediction_errors.iter().sum::<f32>() / self.prediction_errors.len() as f32
318 } else {
319 0.0
320 };
321
322 let recent_error = if self.prediction_errors.len() >= 10 {
323 self.prediction_errors.iter().rev().take(10).sum::<f32>() / 10.0
324 } else if !self.prediction_errors.is_empty() {
325 self.prediction_errors.iter().sum::<f32>() / self.prediction_errors.len() as f32
326 } else {
327 0.0
328 };
329
330 let prediction_ratio = 1.0 - (1.0 / (self.config.prediction_steps + 1) as f32);
331
332 PredictorStats {
333 total_steps: self.total_steps,
334 prediction_ratio,
335 mean_error,
336 recent_error,
337 history_size: self.gradient_history.values().next().map_or(0, |h| h.len()),
338 }
339 }
340
341 #[must_use]
343 pub const fn total_steps(&self) -> usize {
344 self.total_steps
345 }
346
347 pub fn reset(&mut self) {
349 for history in self.gradient_history.values_mut() {
350 history.clear();
351 }
352 self.steps_since_full = 0;
353 self.total_steps = 0;
354 self.last_prediction.clear();
355 self.correction_accumulator.clear();
356 self.prediction_errors.clear();
357 }
358}
359
/// Snapshot of predictor behaviour, as produced by `GradientPredictor::get_stats`.
#[derive(Debug, Clone)]
pub struct PredictorStats {
    /// Total steps observed (recorded + predicted).
    pub total_steps: usize,
    /// Expected fraction of steps served by prediction:
    /// `1 - 1 / (prediction_steps + 1)`.
    pub prediction_ratio: f32,
    /// Mean absolute prediction error over the whole error window.
    pub mean_error: f32,
    /// Mean absolute error over the last 10 predictions (or the overall
    /// mean, if fewer than 10 have been collected).
    pub recent_error: f32,
    /// Current gradient-history length of one tracked parameter.
    pub history_size: usize,
}
374
375impl std::fmt::Display for PredictorStats {
376 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
377 write!(
378 f,
379 "Steps: {} | Prediction ratio: {:.1}% | Mean error: {:.4} | Recent error: {:.4}",
380 self.total_steps,
381 self.prediction_ratio * 100.0,
382 self.mean_error,
383 self.recent_error
384 )
385 }
386}
387
#[cfg(test)]
mod tests {
    use super::*;

    /// Shapes for a tiny two-parameter "model" shared by every test.
    fn create_param_shapes() -> Vec<(String, Vec<usize>)> {
        let weight = ("layer1.weight".to_string(), vec![64, 128]);
        let bias = ("layer1.bias".to_string(), vec![64]);
        vec![weight, bias]
    }

    /// Random gradients matching `create_param_shapes`.
    fn create_mock_gradients(device: &Device) -> HashMap<String, Tensor> {
        let mut gradients = HashMap::new();
        let weight = Tensor::randn(0.0f32, 0.1, (64, 128), device).unwrap();
        gradients.insert("layer1.weight".to_string(), weight);
        let bias = Tensor::randn(0.0f32, 0.1, 64, device).unwrap();
        gradients.insert("layer1.bias".to_string(), bias);
        gradients
    }

    /// Convenience: predictor over the standard shapes on CPU.
    fn make_predictor(config: PredictionConfig) -> GradientPredictor {
        GradientPredictor::new(&create_param_shapes(), config, &Device::Cpu).unwrap()
    }

    #[test]
    fn test_predictor_creation() {
        let predictor = make_predictor(PredictionConfig::default());
        assert_eq!(predictor.total_steps(), 0);
    }

    #[test]
    fn test_should_compute_full_initially() {
        let predictor = make_predictor(PredictionConfig::default());
        // No history yet, so a full gradient is required.
        assert!(predictor.should_compute_full());
    }

    #[test]
    fn test_record_gradient() {
        let mut predictor = make_predictor(PredictionConfig::default());
        let gradients = create_mock_gradients(&Device::Cpu);

        predictor.record_gradient(&gradients).unwrap();
        assert_eq!(predictor.total_steps(), 1);
        // One entry is not enough history to extrapolate from.
        assert!(predictor.should_compute_full());

        predictor.record_gradient(&gradients).unwrap();
        assert_eq!(predictor.total_steps(), 2);
        // Two entries allow prediction.
        assert!(!predictor.should_compute_full());
    }

    #[test]
    fn test_predict_gradient() {
        let config = PredictionConfig::default().with_prediction_steps(4);
        let mut predictor = make_predictor(config);
        let gradients = create_mock_gradients(&Device::Cpu);

        predictor.record_gradient(&gradients).unwrap();
        predictor.record_gradient(&gradients).unwrap();

        let predicted = predictor.predict_gradient().unwrap();
        assert_eq!(predicted.len(), 2);

        // Every registered parameter gets a prediction.
        for (name, _) in &create_param_shapes() {
            assert!(predicted.contains_key(name));
        }
    }

    #[test]
    fn test_correction_cycle() {
        let config = PredictionConfig::default().with_prediction_steps(2);
        let mut predictor = make_predictor(config);
        let gradients = create_mock_gradients(&Device::Cpu);

        predictor.record_gradient(&gradients).unwrap();
        predictor.record_gradient(&gradients).unwrap();

        // Exhaust the two-step prediction budget.
        predictor.predict_gradient().unwrap();
        predictor.predict_gradient().unwrap();

        assert!(predictor.should_compute_full());
    }

    #[test]
    fn test_compute_correction() {
        let mut predictor = make_predictor(PredictionConfig::default());
        let gradients = create_mock_gradients(&Device::Cpu);

        predictor.record_gradient(&gradients).unwrap();
        predictor.record_gradient(&gradients).unwrap();
        let _predicted = predictor.predict_gradient().unwrap();

        let actual = create_mock_gradients(&Device::Cpu);
        let corrections = predictor.compute_correction(&actual).unwrap();

        assert_eq!(corrections.len(), 2);
    }

    #[test]
    fn test_apply_correction() {
        let config = PredictionConfig::default().with_correction_weight(0.5);
        let mut predictor = make_predictor(config);
        let gradients = create_mock_gradients(&Device::Cpu);

        predictor.record_gradient(&gradients).unwrap();
        predictor.record_gradient(&gradients).unwrap();
        let _predicted = predictor.predict_gradient().unwrap();

        let actual = create_mock_gradients(&Device::Cpu);
        predictor.compute_correction(&actual).unwrap();

        let mut grads_to_modify = create_mock_gradients(&Device::Cpu);
        predictor.apply_correction(&mut grads_to_modify).unwrap();

        // The accumulator must be consumed once applied.
        assert!(predictor.correction_accumulator.is_empty());
    }

    #[test]
    fn test_stats() {
        let config = PredictionConfig::default().with_prediction_steps(4);
        let mut predictor = make_predictor(config);
        let gradients = create_mock_gradients(&Device::Cpu);

        predictor.record_gradient(&gradients).unwrap();
        predictor.record_gradient(&gradients).unwrap();
        predictor.predict_gradient().unwrap();

        let stats = predictor.get_stats();
        assert_eq!(stats.total_steps, 3);
        // With a budget of 4, the expected ratio is 1 - 1/5 = 0.8.
        assert!(stats.prediction_ratio > 0.7);
    }

    #[test]
    fn test_reset() {
        let mut predictor = make_predictor(PredictionConfig::default());
        let gradients = create_mock_gradients(&Device::Cpu);

        predictor.record_gradient(&gradients).unwrap();
        predictor.record_gradient(&gradients).unwrap();
        assert_eq!(predictor.total_steps(), 2);

        predictor.reset();

        assert_eq!(predictor.total_steps(), 0);
        assert!(predictor.should_compute_full());
    }
}