1use std::collections::{HashMap, VecDeque};
27
28use candle_core::{Device, Tensor};
29
30use crate::error::{OptimError, Result};
31use crate::prediction::{DeterministicPredictionConfig, DeterministicPredictor};
32
33fn warn_cpu_fallback(device: &Device) {
34 static WARN_ONCE: std::sync::Once = std::sync::Once::new();
35 if matches!(device, Device::Cpu) {
36 WARN_ONCE.call_once(|| {
37 eprintln!(
38 "vsa-optim-rs: CPU device in use. CUDA is the intended default; use Device::cuda_if_available(0) when possible."
39 );
40 });
41 }
42}
43
/// Settings that drive the phase schedule of a deterministic phase trainer.
#[derive(Debug, Clone)]
pub struct DeterministicPhaseConfig {
    /// Warmup step count; passed through to the predictor's configuration.
    pub warmup_steps: usize,

    /// Initial number of steps spent in a `Full` phase before predicting starts.
    pub full_steps: usize,

    /// Initial length of a prediction phase; also used as the predictor's
    /// prediction horizon.
    pub predict_steps: usize,

    /// Spacing, in phase steps, between scheduled correction steps while
    /// predicting.
    pub correct_every: usize,

    /// History window size; passed through to the predictor's configuration.
    pub history_window: usize,

    /// When `true`, phase lengths are re-tuned from the recent loss trend each
    /// time a `Full` phase is entered.
    pub adaptive_phases: bool,

    /// Relative increase of the most recent average loss over the oldest
    /// retained average loss above which `Full` phases are lengthened and
    /// prediction phases shortened.
    pub loss_threshold: f32,

    /// Maximum gradient norm.
    ///
    /// NOTE(review): not read anywhere in the visible code; presumably consumed
    /// by callers for gradient clipping. Confirm before relying on it.
    pub max_grad_norm: f32,
}
71
72impl Default for DeterministicPhaseConfig {
73 fn default() -> Self {
74 Self {
75 warmup_steps: 10,
76 full_steps: 5,
77 predict_steps: 20,
78 correct_every: 5,
79 history_window: 8,
80 adaptive_phases: true,
81 loss_threshold: 0.1,
82 max_grad_norm: 1.0,
83 }
84 }
85}
86
87impl DeterministicPhaseConfig {
88 #[must_use]
90 pub const fn with_warmup_steps(mut self, steps: usize) -> Self {
91 self.warmup_steps = steps;
92 self
93 }
94
95 #[must_use]
97 pub const fn with_full_steps(mut self, steps: usize) -> Self {
98 self.full_steps = steps;
99 self
100 }
101
102 #[must_use]
104 pub const fn with_predict_steps(mut self, steps: usize) -> Self {
105 self.predict_steps = steps;
106 self
107 }
108
109 #[must_use]
111 pub const fn with_correct_every(mut self, every: usize) -> Self {
112 self.correct_every = every;
113 self
114 }
115}
116
/// The stage of the training schedule a step belongs to.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DeterministicPhase {
    /// Initial phase; lasts until the predictor reports it is ready.
    Warmup,
    /// Steps that run a backward pass and record the resulting gradients.
    Full,
    /// Steps that use predicted gradients instead of a backward pass.
    Predict,
    /// A backward-pass step taken in the middle of a prediction phase; its
    /// gradients are recorded as a correction.
    Correct,
}
129
130impl std::fmt::Display for DeterministicPhase {
131 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
132 match self {
133 Self::Warmup => write!(f, "WARMUP"),
134 Self::Full => write!(f, "FULL"),
135 Self::Predict => write!(f, "PREDICT"),
136 Self::Correct => write!(f, "CORRECT"),
137 }
138 }
139}
140
141#[derive(Debug, Clone)]
143pub struct DeterministicStepInfo {
144 pub phase: DeterministicPhase,
146 pub phase_step: usize,
148 pub total_step: usize,
150 pub cycle: usize,
152 pub phase_changed: bool,
154 pub needs_backward: bool,
156}
157
/// Aggregate counters and metrics reported by the trainer.
#[derive(Debug, Clone)]
pub struct DeterministicTrainerStats {
    /// Total number of completed steps.
    pub total_steps: usize,
    /// Steps completed while in the `Warmup` phase.
    pub warmup_steps: usize,
    /// Steps completed while in the `Full` phase.
    pub full_steps: usize,
    /// Steps completed while in the `Predict` phase.
    pub predict_steps: usize,
    /// Steps completed while in the `Correct` phase.
    pub correct_steps: usize,
    /// Number of completed cycles.
    pub cycles: usize,
    /// Total steps divided by the number of steps that needed a backward pass.
    pub speedup: f32,
    /// Mean absolute prediction error as reported by the predictor.
    pub mean_prediction_error: f32,
    /// Loss passed to the most recently completed step.
    pub current_loss: f32,
}
180
181impl std::fmt::Display for DeterministicTrainerStats {
182 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
183 write!(
184 f,
185 "Steps: {} | Cycles: {} | Speedup: {:.2}x | Warmup: {} | Full: {} | Predict: {} | Correct: {}",
186 self.total_steps,
187 self.cycles,
188 self.speedup,
189 self.warmup_steps,
190 self.full_steps,
191 self.predict_steps,
192 self.correct_steps
193 )
194 }
195}
196
197pub struct DeterministicPhaseTrainer {
202 config: DeterministicPhaseConfig,
203 device: Device,
204
205 predictor: DeterministicPredictor,
207
208 current_phase: DeterministicPhase,
210
211 phase_step: usize,
213
214 total_step: usize,
216
217 cycle_count: usize,
219
220 warmup_steps_taken: usize,
222 full_steps_taken: usize,
223 predict_steps_taken: usize,
224 correct_steps_taken: usize,
225
226 recent_losses: VecDeque<f32>,
228
229 last_loss: f32,
231
232 warmup_complete: bool,
234
235 effective_full_steps: usize,
237
238 effective_predict_steps: usize,
240}
241
242impl DeterministicPhaseTrainer {
243 pub fn new(
255 param_shapes: &[(String, Vec<usize>)],
256 config: DeterministicPhaseConfig,
257 device: &Device,
258 ) -> Result<Self> {
259 warn_cpu_fallback(device);
260 let prediction_config = DeterministicPredictionConfig {
261 warmup_steps: config.warmup_steps,
262 history_window: config.history_window,
263 prediction_horizon: config.predict_steps,
264 history_decay: 0.95,
265 residual_threshold: 0.5,
266 };
267
268 let predictor = DeterministicPredictor::new(param_shapes, prediction_config, device)?;
269
270 Ok(Self {
271 effective_full_steps: config.full_steps,
272 effective_predict_steps: config.predict_steps,
273 config,
274 device: device.clone(),
275 predictor,
276 current_phase: DeterministicPhase::Warmup,
277 phase_step: 0,
278 total_step: 0,
279 cycle_count: 0,
280 warmup_steps_taken: 0,
281 full_steps_taken: 0,
282 predict_steps_taken: 0,
283 correct_steps_taken: 0,
284 recent_losses: VecDeque::with_capacity(100),
285 last_loss: 0.0,
286 warmup_complete: false,
287 })
288 }
289
290 pub fn begin_step(&mut self) -> Result<DeterministicStepInfo> {
295 let (next_phase, phase_changed) = self.compute_next_phase();
297 if phase_changed {
298 self.transition_to(next_phase);
299 }
300
301 let needs_backward = matches!(
303 self.current_phase,
304 DeterministicPhase::Warmup | DeterministicPhase::Full | DeterministicPhase::Correct
305 );
306
307 Ok(DeterministicStepInfo {
308 phase: self.current_phase,
309 phase_step: self.phase_step,
310 total_step: self.total_step,
311 cycle: self.cycle_count,
312 phase_changed,
313 needs_backward,
314 })
315 }
316
317 fn compute_next_phase(&self) -> (DeterministicPhase, bool) {
319 match self.current_phase {
320 DeterministicPhase::Warmup => {
321 if self.predictor.is_ready() {
322 (DeterministicPhase::Full, true)
323 } else {
324 (DeterministicPhase::Warmup, false)
325 }
326 }
327 DeterministicPhase::Full => {
328 if self.phase_step >= self.effective_full_steps {
329 (DeterministicPhase::Predict, true)
330 } else {
331 (DeterministicPhase::Full, false)
332 }
333 }
334 DeterministicPhase::Predict => {
335 if self.phase_step > 0 && self.phase_step % self.config.correct_every == 0 {
337 return (DeterministicPhase::Correct, true);
338 }
339 if self.predictor.needs_correction() {
341 return (DeterministicPhase::Correct, true);
342 }
343 if self.phase_step >= self.effective_predict_steps {
345 return (DeterministicPhase::Full, true);
346 }
347 (DeterministicPhase::Predict, false)
348 }
349 DeterministicPhase::Correct => {
350 let remaining = self.effective_predict_steps.saturating_sub(self.phase_step);
352 if remaining > 0 {
353 (DeterministicPhase::Predict, true)
354 } else {
355 (DeterministicPhase::Full, true)
356 }
357 }
358 }
359 }
360
361 fn transition_to(&mut self, new_phase: DeterministicPhase) {
363 let old_phase = self.current_phase;
364 self.current_phase = new_phase;
365
366 match new_phase {
367 DeterministicPhase::Warmup => {
368 }
370 DeterministicPhase::Full => {
371 if old_phase != DeterministicPhase::Warmup {
373 self.cycle_count += 1;
374 }
375 self.phase_step = 0;
376 self.warmup_complete = true;
377
378 if self.config.adaptive_phases {
380 self.adjust_phase_lengths();
381 }
382 }
383 DeterministicPhase::Predict => {
384 if old_phase == DeterministicPhase::Full {
385 self.phase_step = 0;
386 }
387 }
389 DeterministicPhase::Correct => {
390 }
392 }
393 }
394
395 fn adjust_phase_lengths(&mut self) {
397 if self.recent_losses.len() < 20 {
398 return;
399 }
400
401 let losses: Vec<f32> = self.recent_losses.iter().copied().collect();
402 let early: f32 = losses[..10].iter().sum::<f32>() / 10.0;
403 let late: f32 = losses[losses.len() - 10..].iter().sum::<f32>() / 10.0;
404
405 if late > early * (1.0 + self.config.loss_threshold) {
406 self.effective_full_steps = (self.effective_full_steps + 2).min(30);
408 self.effective_predict_steps = self.effective_predict_steps.saturating_sub(5).max(5);
409 } else if late < early * 0.95 {
410 self.effective_full_steps = self.effective_full_steps.saturating_sub(1).max(3);
412 self.effective_predict_steps = (self.effective_predict_steps + 3).min(50);
413 }
414 }
415
416 #[must_use]
418 pub fn needs_backward(&self) -> bool {
419 matches!(
420 self.current_phase,
421 DeterministicPhase::Warmup | DeterministicPhase::Full | DeterministicPhase::Correct
422 )
423 }
424
425 pub fn record_full_gradients(&mut self, gradients: &HashMap<String, Tensor>) -> Result<()> {
434 let is_correction = self.current_phase == DeterministicPhase::Correct;
435 self.predictor.record_gradient(gradients, is_correction)?;
436 Ok(())
437 }
438
439 pub fn get_predicted_gradients(&mut self) -> Result<HashMap<String, Tensor>> {
447 if !self.warmup_complete {
448 return Err(OptimError::Prediction(
449 "Cannot predict during warmup phase".to_string(),
450 ));
451 }
452 self.predictor.predict_gradient()
453 }
454
455 #[allow(clippy::cast_precision_loss)]
463 pub fn end_step(&mut self, loss: f32) -> Result<()> {
464 if self.recent_losses.len() >= 100 {
466 self.recent_losses.pop_front();
467 }
468 self.recent_losses.push_back(loss);
469 self.last_loss = loss;
470
471 match self.current_phase {
473 DeterministicPhase::Warmup => self.warmup_steps_taken += 1,
474 DeterministicPhase::Full => self.full_steps_taken += 1,
475 DeterministicPhase::Predict => self.predict_steps_taken += 1,
476 DeterministicPhase::Correct => self.correct_steps_taken += 1,
477 }
478
479 self.phase_step += 1;
481 self.total_step += 1;
482
483 Ok(())
484 }
485
486 #[must_use]
488 pub const fn current_phase(&self) -> DeterministicPhase {
489 self.current_phase
490 }
491
492 #[must_use]
494 pub const fn warmup_complete(&self) -> bool {
495 self.warmup_complete
496 }
497
498 #[must_use]
500 #[allow(clippy::cast_precision_loss)]
501 pub fn get_stats(&self) -> DeterministicTrainerStats {
502 let total_forward = self.total_step as f32;
504 let total_backward = (self.warmup_steps_taken
505 + self.full_steps_taken
506 + self.correct_steps_taken)
507 .max(1) as f32;
508 let speedup = total_forward / total_backward;
509
510 DeterministicTrainerStats {
511 total_steps: self.total_step,
512 warmup_steps: self.warmup_steps_taken,
513 full_steps: self.full_steps_taken,
514 predict_steps: self.predict_steps_taken,
515 correct_steps: self.correct_steps_taken,
516 cycles: self.cycle_count,
517 speedup,
518 mean_prediction_error: self.predictor.get_stats().mean_abs_error,
519 current_loss: self.last_loss,
520 }
521 }
522
523 pub fn reset(&mut self) -> Result<()> {
525 self.predictor.reset()?;
526 self.current_phase = DeterministicPhase::Warmup;
527 self.phase_step = 0;
528 self.total_step = 0;
529 self.cycle_count = 0;
530 self.warmup_steps_taken = 0;
531 self.full_steps_taken = 0;
532 self.predict_steps_taken = 0;
533 self.correct_steps_taken = 0;
534 self.recent_losses.clear();
535 self.last_loss = 0.0;
536 self.warmup_complete = false;
537 self.effective_full_steps = self.config.full_steps;
538 self.effective_predict_steps = self.config.predict_steps;
539 Ok(())
540 }
541}
542
#[cfg(test)]
mod tests {
    use super::*;

    /// Parameter shapes for a single layer: a 16x32 weight and a 16-element bias.
    fn create_shapes() -> Vec<(String, Vec<usize>)> {
        vec![
            ("layer.weight".to_string(), vec![16, 32]),
            ("layer.bias".to_string(), vec![16]),
        ]
    }

    /// Builds gradients matching `create_shapes` from all-ones tensors scaled
    /// by `scale`.
    fn create_mock_gradients(device: &Device, scale: f32) -> HashMap<String, Tensor> {
        let mut grads = HashMap::new();
        grads.insert(
            "layer.weight".to_string(),
            Tensor::ones((16, 32), candle_core::DType::F32, device)
                .unwrap()
                .affine(f64::from(scale), 0.0)
                .unwrap(),
        );
        grads.insert(
            "layer.bias".to_string(),
            Tensor::ones(16, candle_core::DType::F32, device)
                .unwrap()
                .affine(f64::from(scale), 0.0)
                .unwrap(),
        );
        grads
    }

    /// After recording the configured number of warmup gradients the trainer
    /// must have moved on to the `Full` phase.
    #[test]
    fn test_warmup_to_full_transition() {
        let config = DeterministicPhaseConfig::default()
            .with_warmup_steps(5)
            .with_full_steps(3);

        let mut trainer =
            DeterministicPhaseTrainer::new(&create_shapes(), config, &Device::Cpu).unwrap();

        let info = trainer.begin_step().unwrap();
        assert_eq!(info.phase, DeterministicPhase::Warmup);
        assert!(info.needs_backward);

        for i in 0..5 {
            let grads = create_mock_gradients(&Device::Cpu, 1.0 + i as f32 * 0.1);
            trainer.record_full_gradients(&grads).unwrap();
            trainer.end_step(1.0 - i as f32 * 0.1).unwrap();
            trainer.begin_step().unwrap();
        }

        assert!(trainer.warmup_complete());
        assert_eq!(trainer.current_phase(), DeterministicPhase::Full);
    }

    /// Running enough steps must visit the warmup, full and predict phases.
    #[test]
    fn test_full_cycle() {
        let config = DeterministicPhaseConfig::default()
            .with_warmup_steps(3)
            .with_full_steps(2)
            .with_predict_steps(4)
            .with_correct_every(2);

        let mut trainer =
            DeterministicPhaseTrainer::new(&create_shapes(), config, &Device::Cpu).unwrap();

        let mut phases_seen = Vec::new();

        for i in 0..20 {
            let info = trainer.begin_step().unwrap();
            phases_seen.push(info.phase);

            if info.needs_backward {
                let grads = create_mock_gradients(&Device::Cpu, 1.0 + i as f32 * 0.05);
                trainer.record_full_gradients(&grads).unwrap();
            } else {
                let _predicted = trainer.get_predicted_gradients().unwrap();
            }

            trainer.end_step(1.0 / (i + 1) as f32).unwrap();
        }

        assert!(phases_seen.contains(&DeterministicPhase::Warmup));
        assert!(phases_seen.contains(&DeterministicPhase::Full));
        assert!(phases_seen.contains(&DeterministicPhase::Predict));
    }

    /// Statistics must reflect the number of steps run and report a speedup
    /// of at least 1.
    #[test]
    fn test_deterministic_stats() {
        let config = DeterministicPhaseConfig::default()
            .with_warmup_steps(5)
            .with_full_steps(2)
            .with_predict_steps(8);

        let mut trainer =
            DeterministicPhaseTrainer::new(&create_shapes(), config, &Device::Cpu).unwrap();

        // The step index is not needed here, only the number of iterations.
        for _ in 0..15 {
            let info = trainer.begin_step().unwrap();
            if info.needs_backward {
                let grads = create_mock_gradients(&Device::Cpu, 1.0);
                trainer.record_full_gradients(&grads).unwrap();
            } else {
                // Only the step accounting is under test; the prediction
                // result itself is intentionally ignored.
                let _ = trainer.get_predicted_gradients();
            }
            trainer.end_step(0.5).unwrap();
        }

        let stats = trainer.get_stats();
        assert_eq!(stats.total_steps, 15);
        assert!(stats.speedup >= 1.0);
        assert!(stats.warmup_steps > 0);
    }
}