
vsa_optim_rs/phase/deterministic_trainer.rs

//! Deterministic phase trainer implementation.
//!
//! Orchestrates phase-based training using deterministic gradient prediction.
//! This trainer guarantees reproducible training outcomes by:
//!
//! 1. Using deterministic least-squares gradient model fitting
//! 2. Tracking residuals for drift correction
//! 3. Requiring warmup before prediction begins
//!
//! # Training Phases
//!
//! ```text
//! WARMUP ──► FULL ──► PREDICT ──► CORRECT ──► FULL ──► ...
//!   │                    │            │
//!   │                    │            └─► Extract residual, refit model
//!   │                    └─► Use predicted gradients
//!   └─► Build gradient history for model fitting
//! ```
//!
//! # Determinism Guarantees
//!
//! - Same random seed + same data order = identical training trajectory
//! - No stochastic operations in prediction
//! - Residuals ensure predictions converge to actual gradients over time
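//!
//! # Example
//!
//! A minimal sketch of the intended driver loop. `compute_loss_and_gradients`,
//! `forward_loss`, and `apply_gradients` are hypothetical stand-ins for the
//! caller's model and optimizer code:
//!
//! ```ignore
//! // `compute_loss_and_gradients`, `forward_loss`, `apply_gradients`: user-supplied (hypothetical).
//! let shapes = vec![("layer.weight".to_string(), vec![16, 32])];
//! let mut trainer = DeterministicPhaseTrainer::new(
//!     &shapes,
//!     DeterministicPhaseConfig::default(),
//!     &Device::Cpu,
//! )?;
//!
//! for _ in 0..100 {
//!     let info = trainer.begin_step()?;
//!     if info.needs_backward {
//!         // WARMUP / FULL / CORRECT: real backward pass, recorded for model fitting.
//!         let (loss, grads) = compute_loss_and_gradients()?;
//!         trainer.record_full_gradients(&grads)?;
//!         apply_gradients(&grads)?;
//!         trainer.end_step(loss)?;
//!     } else {
//!         // PREDICT: deterministic prediction replaces the backward pass.
//!         let grads = trainer.get_predicted_gradients()?;
//!         apply_gradients(&grads)?;
//!         trainer.end_step(forward_loss()?)?;
//!     }
//! }
//! ```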

use std::collections::{HashMap, VecDeque};

use candle_core::{Device, Tensor};

use crate::error::{OptimError, Result};
use crate::prediction::{DeterministicPredictionConfig, DeterministicPredictor};

/// Warn once if the trainer is constructed on the CPU device.
fn warn_cpu_fallback(device: &Device) {
    static WARN_ONCE: std::sync::Once = std::sync::Once::new();
    if matches!(device, Device::Cpu) {
        WARN_ONCE.call_once(|| {
            eprintln!(
                "vsa-optim-rs: CPU device in use. CUDA is the intended default; use Device::cuda_if_available(0) when possible."
            );
        });
    }
}

/// Configuration for deterministic phase training.
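///
/// A minimal construction sketch using the builder methods below (the values
/// shown match the defaults):
///
/// ```ignore
/// let config = DeterministicPhaseConfig::default()
///     .with_warmup_steps(10)
///     .with_full_steps(5)
///     .with_predict_steps(20)
///     .with_correct_every(5);
/// ```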
#[derive(Debug, Clone)]
pub struct DeterministicPhaseConfig {
    /// Warmup steps before prediction begins.
    pub warmup_steps: usize,

    /// Full gradient steps per cycle (after warmup).
    pub full_steps: usize,

    /// Prediction steps per cycle.
    pub predict_steps: usize,

    /// Correction interval: run a correction every this many prediction steps.
    pub correct_every: usize,

    /// History window for model fitting.
    pub history_window: usize,

    /// Whether to adaptively adjust phase lengths.
    pub adaptive_phases: bool,

    /// Relative loss increase that triggers more full steps.
    pub loss_threshold: f32,

    /// Maximum gradient norm for clipping.
    pub max_grad_norm: f32,
}

impl Default for DeterministicPhaseConfig {
    fn default() -> Self {
        Self {
            warmup_steps: 10,
            full_steps: 5,
            predict_steps: 20,
            correct_every: 5,
            history_window: 8,
            adaptive_phases: true,
            loss_threshold: 0.1,
            max_grad_norm: 1.0,
        }
    }
}

impl DeterministicPhaseConfig {
    /// Builder: Set warmup steps.
    #[must_use]
    pub const fn with_warmup_steps(mut self, steps: usize) -> Self {
        self.warmup_steps = steps;
        self
    }

    /// Builder: Set full steps per cycle.
    #[must_use]
    pub const fn with_full_steps(mut self, steps: usize) -> Self {
        self.full_steps = steps;
        self
    }

    /// Builder: Set prediction steps per cycle.
    #[must_use]
    pub const fn with_predict_steps(mut self, steps: usize) -> Self {
        self.predict_steps = steps;
        self
    }

    /// Builder: Set correction frequency.
    #[must_use]
    pub const fn with_correct_every(mut self, every: usize) -> Self {
        self.correct_every = every;
        self
    }
}

/// Training phase for deterministic phase trainer.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DeterministicPhase {
    /// Initial warmup phase - always compute full gradients.
    Warmup,
    /// Full gradient computation phase.
    Full,
    /// Prediction phase - use predicted gradients.
    Predict,
    /// Correction phase - compute full gradients and update residuals.
    Correct,
}

impl std::fmt::Display for DeterministicPhase {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Warmup => write!(f, "WARMUP"),
            Self::Full => write!(f, "FULL"),
            Self::Predict => write!(f, "PREDICT"),
            Self::Correct => write!(f, "CORRECT"),
        }
    }
}

/// Step information from the phase trainer.
#[derive(Debug, Clone)]
pub struct DeterministicStepInfo {
    /// Current phase.
    pub phase: DeterministicPhase,
    /// Step within current phase.
    pub phase_step: usize,
    /// Total training steps.
    pub total_step: usize,
    /// Training cycle count (after warmup).
    pub cycle: usize,
    /// Whether phase changed this step.
    pub phase_changed: bool,
    /// Whether backward pass is needed.
    pub needs_backward: bool,
}

/// Training statistics.
#[derive(Debug, Clone)]
pub struct DeterministicTrainerStats {
    /// Total steps taken.
    pub total_steps: usize,
    /// Warmup steps taken.
    pub warmup_steps: usize,
    /// Full gradient steps taken.
    pub full_steps: usize,
    /// Prediction steps taken.
    pub predict_steps: usize,
    /// Correction steps taken.
    pub correct_steps: usize,
    /// Training cycles completed.
    pub cycles: usize,
    /// Effective speedup ratio.
    pub speedup: f32,
    /// Mean prediction error.
    pub mean_prediction_error: f32,
    /// Current loss (most recent).
    pub current_loss: f32,
}

impl std::fmt::Display for DeterministicTrainerStats {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "Steps: {} | Cycles: {} | Speedup: {:.2}x | Warmup: {} | Full: {} | Predict: {} | Correct: {}",
            self.total_steps,
            self.cycles,
            self.speedup,
            self.warmup_steps,
            self.full_steps,
            self.predict_steps,
            self.correct_steps
        )
    }
}

/// Deterministic phase-based trainer.
///
/// Orchestrates training with guaranteed deterministic outcomes.
/// Uses warmup → full → predict → correct cycle with residual tracking.
pub struct DeterministicPhaseTrainer {
    config: DeterministicPhaseConfig,
    device: Device,

    /// Deterministic gradient predictor.
    predictor: DeterministicPredictor,

    /// Current phase.
    current_phase: DeterministicPhase,

    /// Step within current phase.
    phase_step: usize,

    /// Total training steps.
    total_step: usize,

    /// Training cycles (full → predict → correct sequences).
    cycle_count: usize,

    /// Steps taken per phase.
    warmup_steps_taken: usize,
    full_steps_taken: usize,
    predict_steps_taken: usize,
    correct_steps_taken: usize,

    /// Recent losses for adaptive scheduling.
    recent_losses: VecDeque<f32>,

    /// Last recorded loss.
    last_loss: f32,

    /// Whether warmup is complete.
    warmup_complete: bool,

    /// Effective full steps per cycle (may adapt).
    effective_full_steps: usize,

    /// Effective predict steps per cycle (may adapt).
    effective_predict_steps: usize,
}

impl DeterministicPhaseTrainer {
    /// Create a new deterministic phase trainer.
    ///
    /// # Arguments
    ///
    /// * `param_shapes` - List of (name, shape) tuples for parameters
    /// * `config` - Phase training configuration
    /// * `device` - Device for tensor storage
    ///
    /// # Errors
    ///
    /// Returns error if predictor initialization fails.
    pub fn new(
        param_shapes: &[(String, Vec<usize>)],
        config: DeterministicPhaseConfig,
        device: &Device,
    ) -> Result<Self> {
        warn_cpu_fallback(device);
        let prediction_config = DeterministicPredictionConfig {
            warmup_steps: config.warmup_steps,
            history_window: config.history_window,
            prediction_horizon: config.predict_steps,
            history_decay: 0.95,
            residual_threshold: 0.5,
        };

        let predictor = DeterministicPredictor::new(param_shapes, prediction_config, device)?;

        Ok(Self {
            effective_full_steps: config.full_steps,
            effective_predict_steps: config.predict_steps,
            config,
            device: device.clone(),
            predictor,
            current_phase: DeterministicPhase::Warmup,
            phase_step: 0,
            total_step: 0,
            cycle_count: 0,
            warmup_steps_taken: 0,
            full_steps_taken: 0,
            predict_steps_taken: 0,
            correct_steps_taken: 0,
            recent_losses: VecDeque::with_capacity(100),
            last_loss: 0.0,
            warmup_complete: false,
        })
    }

    /// Begin a training step.
    ///
    /// Returns information about the current phase and whether
    /// backward pass (full gradient computation) is needed.
    pub fn begin_step(&mut self) -> Result<DeterministicStepInfo> {
        // Check for phase transitions
        let (next_phase, phase_changed) = self.compute_next_phase();
        if phase_changed {
            self.transition_to(next_phase);
        }

        // Determine if backward is needed
        let needs_backward = matches!(
            self.current_phase,
            DeterministicPhase::Warmup | DeterministicPhase::Full | DeterministicPhase::Correct
        );

        Ok(DeterministicStepInfo {
            phase: self.current_phase,
            phase_step: self.phase_step,
            total_step: self.total_step,
            cycle: self.cycle_count,
            phase_changed,
            needs_backward,
        })
    }

    /// Compute the next phase based on current state.
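    ///
    /// Worked timeline with the `test_full_cycle` settings (`full_steps = 2`,
    /// `predict_steps = 4`, `correct_every = 2`) and no residual-triggered
    /// corrections: FULL, FULL, PREDICT, PREDICT, CORRECT, PREDICT, CORRECT,
    /// then back to FULL for the next cycle. The scheduled-correction check
    /// runs before the cycle-completion check, so a prediction phase whose
    /// length is a multiple of `correct_every` ends with a CORRECT step.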
    fn compute_next_phase(&self) -> (DeterministicPhase, bool) {
        match self.current_phase {
            DeterministicPhase::Warmup => {
                if self.predictor.is_ready() {
                    (DeterministicPhase::Full, true)
                } else {
                    (DeterministicPhase::Warmup, false)
                }
            }
            DeterministicPhase::Full => {
                if self.phase_step >= self.effective_full_steps {
                    (DeterministicPhase::Predict, true)
                } else {
                    (DeterministicPhase::Full, false)
                }
            }
            DeterministicPhase::Predict => {
                // Check for scheduled correction
                if self.phase_step > 0 && self.phase_step % self.config.correct_every == 0 {
                    return (DeterministicPhase::Correct, true);
                }
                // Check for residual-triggered correction
                if self.predictor.needs_correction() {
                    return (DeterministicPhase::Correct, true);
                }
                // Check for cycle completion
                if self.phase_step >= self.effective_predict_steps {
                    return (DeterministicPhase::Full, true);
                }
                (DeterministicPhase::Predict, false)
            }
            DeterministicPhase::Correct => {
                // After correction, continue prediction or start a new cycle
                let remaining = self.effective_predict_steps.saturating_sub(self.phase_step);
                if remaining > 0 {
                    (DeterministicPhase::Predict, true)
                } else {
                    (DeterministicPhase::Full, true)
                }
            }
        }
    }

    /// Handle phase transition.
    fn transition_to(&mut self, new_phase: DeterministicPhase) {
        let old_phase = self.current_phase;
        self.current_phase = new_phase;

        match new_phase {
            DeterministicPhase::Warmup => {
                // Shouldn't happen - warmup only at start
            }
            DeterministicPhase::Full => {
                // Starting new cycle
                if old_phase != DeterministicPhase::Warmup {
                    self.cycle_count += 1;
                }
                self.phase_step = 0;
                self.warmup_complete = true;

                // Adaptive phase adjustment
                if self.config.adaptive_phases {
                    self.adjust_phase_lengths();
                }
            }
            DeterministicPhase::Predict => {
                if old_phase == DeterministicPhase::Full {
                    self.phase_step = 0;
                }
                // Don't reset phase_step when returning from correction
            }
            DeterministicPhase::Correct => {
                // Don't reset phase_step - we continue prediction count
            }
        }
    }

    /// Adjust phase lengths based on training dynamics.
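    ///
    /// Worked numbers: with the default `loss_threshold = 0.1` and current
    /// effective lengths (5, 20), a late-loss mean more than 10% above the
    /// early mean shifts (full, predict) to (7, 15); a late mean below 95%
    /// of the early mean shifts it to (4, 23).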
    fn adjust_phase_lengths(&mut self) {
        if self.recent_losses.len() < 20 {
            return;
        }

        let losses: Vec<f32> = self.recent_losses.iter().copied().collect();
        let early: f32 = losses[..10].iter().sum::<f32>() / 10.0;
        let late: f32 = losses[losses.len() - 10..].iter().sum::<f32>() / 10.0;

        if late > early * (1.0 + self.config.loss_threshold) {
            // Loss increasing: more full training, less prediction
            self.effective_full_steps = (self.effective_full_steps + 2).min(30);
            self.effective_predict_steps = self.effective_predict_steps.saturating_sub(5).max(5);
        } else if late < early * 0.95 {
            // Loss decreasing well: can use more prediction
            self.effective_full_steps = self.effective_full_steps.saturating_sub(1).max(3);
            self.effective_predict_steps = (self.effective_predict_steps + 3).min(50);
        }
    }

    /// Check if backward pass is needed for current step.
    #[must_use]
    pub fn needs_backward(&self) -> bool {
        matches!(
            self.current_phase,
            DeterministicPhase::Warmup | DeterministicPhase::Full | DeterministicPhase::Correct
        )
    }

    /// Record full gradients after backward pass.
    ///
    /// Called during WARMUP, FULL, or CORRECT phases after computing
    /// gradients via backpropagation.
    ///
    /// # Arguments
    ///
    /// * `gradients` - Map of parameter names to gradient tensors
    pub fn record_full_gradients(&mut self, gradients: &HashMap<String, Tensor>) -> Result<()> {
        let is_correction = self.current_phase == DeterministicPhase::Correct;
        self.predictor.record_gradient(gradients, is_correction)?;
        Ok(())
    }

    /// Get predicted gradients for current step.
    ///
    /// Called during PREDICT phase to get deterministic gradient predictions.
    ///
    /// # Returns
    ///
    /// Map of parameter names to predicted gradient tensors.
    ///
    /// # Errors
    ///
    /// Returns an error if called before warmup is complete.
    pub fn get_predicted_gradients(&mut self) -> Result<HashMap<String, Tensor>> {
        if !self.warmup_complete {
            return Err(OptimError::Prediction(
                "Cannot predict during warmup phase".to_string(),
            ));
        }
        self.predictor.predict_gradient()
    }

    /// End the current training step.
    ///
    /// Updates internal state and statistics.
    ///
    /// # Arguments
    ///
    /// * `loss` - Loss value for this step
    pub fn end_step(&mut self, loss: f32) -> Result<()> {
        // Track loss
        if self.recent_losses.len() >= 100 {
            self.recent_losses.pop_front();
        }
        self.recent_losses.push_back(loss);
        self.last_loss = loss;

        // Update phase-specific counters
        match self.current_phase {
            DeterministicPhase::Warmup => self.warmup_steps_taken += 1,
            DeterministicPhase::Full => self.full_steps_taken += 1,
            DeterministicPhase::Predict => self.predict_steps_taken += 1,
            DeterministicPhase::Correct => self.correct_steps_taken += 1,
        }

        // Update step counters
        self.phase_step += 1;
        self.total_step += 1;

        Ok(())
    }

    /// Get current training phase.
    #[must_use]
    pub const fn current_phase(&self) -> DeterministicPhase {
        self.current_phase
    }

    /// Check if warmup is complete.
    #[must_use]
    pub const fn warmup_complete(&self) -> bool {
        self.warmup_complete
    }

    /// Get training statistics.
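    ///
    /// The reported `speedup` is total steps divided by backward
    /// (gradient-computing) steps. For example, 10 warmup + 5 full +
    /// 4 correct backward steps over 39 total steps gives 39 / 19 ≈ 2.05x.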
    #[must_use]
    #[allow(clippy::cast_precision_loss)]
    pub fn get_stats(&self) -> DeterministicTrainerStats {
        // Calculate speedup: total forward steps / backward steps
        let total_forward = self.total_step as f32;
        let total_backward = (self.warmup_steps_taken
            + self.full_steps_taken
            + self.correct_steps_taken)
            .max(1) as f32;
        let speedup = total_forward / total_backward;

        DeterministicTrainerStats {
            total_steps: self.total_step,
            warmup_steps: self.warmup_steps_taken,
            full_steps: self.full_steps_taken,
            predict_steps: self.predict_steps_taken,
            correct_steps: self.correct_steps_taken,
            cycles: self.cycle_count,
            speedup,
            mean_prediction_error: self.predictor.get_stats().mean_abs_error,
            current_loss: self.last_loss,
        }
    }

    /// Reset trainer state.
    pub fn reset(&mut self) -> Result<()> {
        self.predictor.reset()?;
        self.current_phase = DeterministicPhase::Warmup;
        self.phase_step = 0;
        self.total_step = 0;
        self.cycle_count = 0;
        self.warmup_steps_taken = 0;
        self.full_steps_taken = 0;
        self.predict_steps_taken = 0;
        self.correct_steps_taken = 0;
        self.recent_losses.clear();
        self.last_loss = 0.0;
        self.warmup_complete = false;
        self.effective_full_steps = self.config.full_steps;
        self.effective_predict_steps = self.config.predict_steps;
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    fn create_shapes() -> Vec<(String, Vec<usize>)> {
        vec![
            ("layer.weight".to_string(), vec![16, 32]),
            ("layer.bias".to_string(), vec![16]),
        ]
    }

    fn create_mock_gradients(device: &Device, scale: f32) -> HashMap<String, Tensor> {
        let mut grads = HashMap::new();
        grads.insert(
            "layer.weight".to_string(),
            Tensor::ones((16, 32), candle_core::DType::F32, device)
                .unwrap()
                .affine(scale as f64, 0.0)
                .unwrap(),
        );
        grads.insert(
            "layer.bias".to_string(),
            Tensor::ones(16, candle_core::DType::F32, device)
                .unwrap()
                .affine(scale as f64, 0.0)
                .unwrap(),
        );
        grads
    }

    #[test]
    fn test_warmup_to_full_transition() {
        let config = DeterministicPhaseConfig::default()
            .with_warmup_steps(5)
            .with_full_steps(3);

        let mut trainer =
            DeterministicPhaseTrainer::new(&create_shapes(), config, &Device::Cpu).unwrap();

        // Should start in warmup
        let info = trainer.begin_step().unwrap();
        assert_eq!(info.phase, DeterministicPhase::Warmup);
        assert!(info.needs_backward);

        // Run through warmup
        for i in 0..5 {
            let grads = create_mock_gradients(&Device::Cpu, 1.0 + i as f32 * 0.1);
            trainer.record_full_gradients(&grads).unwrap();
            trainer.end_step(1.0 - i as f32 * 0.1).unwrap();
            trainer.begin_step().unwrap();
        }

        // Should now be in FULL phase
        assert!(trainer.warmup_complete());
        assert_eq!(trainer.current_phase(), DeterministicPhase::Full);
    }

    #[test]
    fn test_full_cycle() {
        let config = DeterministicPhaseConfig::default()
            .with_warmup_steps(3)
            .with_full_steps(2)
            .with_predict_steps(4)
            .with_correct_every(2);

        let mut trainer =
            DeterministicPhaseTrainer::new(&create_shapes(), config, &Device::Cpu).unwrap();

        let mut phases_seen = Vec::new();

        // Run 20 steps
        for i in 0..20 {
            let info = trainer.begin_step().unwrap();
            phases_seen.push(info.phase);

            if info.needs_backward {
                let grads = create_mock_gradients(&Device::Cpu, 1.0 + i as f32 * 0.05);
                trainer.record_full_gradients(&grads).unwrap();
            } else {
                let _predicted = trainer.get_predicted_gradients().unwrap();
            }

            trainer.end_step(1.0 / (i + 1) as f32).unwrap();
        }

        // Should have seen all phase types
        assert!(phases_seen.contains(&DeterministicPhase::Warmup));
        assert!(phases_seen.contains(&DeterministicPhase::Full));
        assert!(phases_seen.contains(&DeterministicPhase::Predict));
        // Correction may or may not trigger depending on residuals
    }

    #[test]
    fn test_deterministic_stats() {
        let config = DeterministicPhaseConfig::default()
            .with_warmup_steps(5)
            .with_full_steps(2)
            .with_predict_steps(8);

        let mut trainer =
            DeterministicPhaseTrainer::new(&create_shapes(), config, &Device::Cpu).unwrap();

        // Run some steps
        for _ in 0..15 {
            let info = trainer.begin_step().unwrap();
            if info.needs_backward {
                let grads = create_mock_gradients(&Device::Cpu, 1.0);
                trainer.record_full_gradients(&grads).unwrap();
            } else {
                let _ = trainer.get_predicted_gradients();
            }
            trainer.end_step(0.5).unwrap();
        }

        let stats = trainer.get_stats();
        assert_eq!(stats.total_steps, 15);
        assert!(stats.speedup >= 1.0);
        assert!(stats.warmup_steps > 0);
    }
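
    // A reproducibility sketch grounded in the module's determinism guarantees:
    // two trainers fed identical gradient streams should traverse identical
    // phase sequences. (Assumes `DeterministicPredictor` is itself fully
    // deterministic, as the module docs state.)
    #[test]
    fn test_identical_runs_same_phases() {
        let run = || {
            let config = DeterministicPhaseConfig::default()
                .with_warmup_steps(3)
                .with_full_steps(2)
                .with_predict_steps(4)
                .with_correct_every(2);
            let mut trainer =
                DeterministicPhaseTrainer::new(&create_shapes(), config, &Device::Cpu).unwrap();
            let mut phases = Vec::new();
            for i in 0..20 {
                let info = trainer.begin_step().unwrap();
                phases.push(info.phase);
                if info.needs_backward {
                    let grads = create_mock_gradients(&Device::Cpu, 1.0 + i as f32 * 0.05);
                    trainer.record_full_gradients(&grads).unwrap();
                } else {
                    let _ = trainer.get_predicted_gradients().unwrap();
                }
                trainer.end_step(1.0 / (i + 1) as f32).unwrap();
            }
            phases
        };
        assert_eq!(run(), run());
    }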
}