vsa_optim_rs/phase/trainer.rs

//! Phase trainer implementation.
//!
//! Orchestrates phase-based training for acceleration by combining
//! gradient prediction, VSA compression, and ternary optimization.

use std::collections::{HashMap, VecDeque};

use candle_core::{Device, Tensor};

use crate::config::PhaseConfig;
use crate::error::Result;
use crate::prediction::GradientPredictor;
use crate::ternary::TernaryGradientAccumulator;
use crate::vsa::VSAGradientCompressor;

fn warn_cpu_fallback(device: &Device) {
    static WARN_ONCE: std::sync::Once = std::sync::Once::new();
    if matches!(device, Device::Cpu) {
        WARN_ONCE.call_once(|| {
            eprintln!(
                "vsa-optim-rs: CPU device in use. CUDA is the intended default; use Device::cuda_if_available(0) when possible."
            );
        });
    }
}

/// Training phase types.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TrainingPhase {
    /// Full gradient computation.
    Full,
    /// Predicted gradients.
    Predict,
    /// Correction phase.
    Correct,
}

impl std::fmt::Display for TrainingPhase {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Full => write!(f, "FULL"),
            Self::Predict => write!(f, "PREDICT"),
            Self::Correct => write!(f, "CORRECT"),
        }
    }
}

/// Orchestrates phase-based training for acceleration.
///
/// This is the main training orchestrator that combines all optimization
/// techniques. It manages the phase transitions and ensures convergence
/// while maximizing training speed.
///
/// The trainer automatically:
/// 1. Tracks which phase we're in
/// 2. Manages gradient prediction during the PREDICT phase
/// 3. Applies corrections to prevent drift
/// 4. Uses ternary accumulation for memory efficiency
/// 5. Optionally uses VSA compression for gradient storage
///
/// # Example
///
/// ```ignore
/// use candle_core::Device;
/// use vsa_optim_rs::phase::{PhaseTrainer, TrainingPhase};
/// use vsa_optim_rs::PhaseConfig;
///
/// let shapes = vec![("layer.weight".to_string(), vec![64, 128])];
/// let config = PhaseConfig::default();
/// let mut trainer = PhaseTrainer::new(&shapes, config, &Device::Cpu)?;
///
/// // Training loop
/// for step in 0..total_steps {
///     let step_info = trainer.begin_step()?;
///
///     match step_info.phase {
///         TrainingPhase::Full | TrainingPhase::Correct => {
///             // Compute full gradients via backprop
///             trainer.record_full_gradients(&gradients)?;
///         }
///         TrainingPhase::Predict => {
///             // Use predicted gradients
///             let predicted = trainer.get_predicted_gradients()?;
///         }
///     }
///
///     trainer.end_step(loss_value)?;
/// }
/// ```
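///
/// A minimal device-selection sketch (assuming candle's CUDA support is
/// compiled in; `cuda_if_available` falls back to CPU, which triggers the
/// one-time warning above):
///
/// ```ignore
/// use candle_core::Device;
///
/// let device = Device::cuda_if_available(0)?;
/// let mut trainer = PhaseTrainer::new(&shapes, config, &device)?;
/// ```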
pub struct PhaseTrainer {
    config: PhaseConfig,
    device: Device,

    /// Gradient predictor.
    predictor: GradientPredictor,

    /// Ternary gradient accumulator.
    ternary_accum: TernaryGradientAccumulator,

    /// VSA gradient compressor.
    vsa_compressor: VSAGradientCompressor,

    /// Current training phase.
    current_phase: TrainingPhase,

    /// Steps in current phase.
    phase_step: usize,

    /// Total training steps.
    total_step: usize,

    /// Cycle count (full phase completions).
    cycle_count: usize,

    /// Per-phase loss tracking.
    phase_losses: HashMap<TrainingPhase, Vec<f32>>,

    /// Recent losses for adaptive scheduling.
    recent_losses: VecDeque<f32>,

    /// Speedup ratio.
    speedup_ratio: f32,

    /// Steps taken per phase type.
    full_steps_taken: usize,
    predict_steps_taken: usize,
    correct_steps_taken: usize,

    /// Parameter shapes for reference.
    param_shapes: Vec<(String, Vec<usize>)>,
}

impl PhaseTrainer {
    /// Create a new phase trainer.
    ///
    /// # Arguments
    ///
    /// * `param_shapes` - List of (name, shape) tuples for parameters
    /// * `config` - Phase training configuration
    /// * `device` - Device for tensor storage
    ///
    /// # Errors
    ///
    /// Returns error if component initialization fails.
    pub fn new(
        param_shapes: &[(String, Vec<usize>)],
        config: PhaseConfig,
        device: &Device,
    ) -> Result<Self> {
        warn_cpu_fallback(device);
        let predictor = GradientPredictor::new(
            param_shapes,
            config.prediction_config.clone(),
            device,
        )?;

        let ternary_accum = TernaryGradientAccumulator::new(
            param_shapes,
            config.ternary_config.clone(),
            device,
        )?;

        let param_count: usize = param_shapes
            .iter()
            .map(|(_, s)| s.iter().product::<usize>())
            .sum();
        let vsa_compressor = VSAGradientCompressor::new(param_count, config.vsa_config.clone());

        let mut phase_losses = HashMap::new();
        phase_losses.insert(TrainingPhase::Full, Vec::new());
        phase_losses.insert(TrainingPhase::Predict, Vec::new());
        phase_losses.insert(TrainingPhase::Correct, Vec::new());

        Ok(Self {
            config,
            device: device.clone(),
            predictor,
            ternary_accum,
            vsa_compressor,
            current_phase: TrainingPhase::Full,
            phase_step: 0,
            total_step: 0,
            cycle_count: 0,
            phase_losses,
            recent_losses: VecDeque::with_capacity(100),
            speedup_ratio: 1.0,
            full_steps_taken: 0,
            predict_steps_taken: 0,
            correct_steps_taken: 0,
            param_shapes: param_shapes.to_vec(),
        })
    }

    /// Determine the next training phase.
    ///
    /// With `full_steps = 2`, `predict_steps = 4`, and `correct_every = 2`,
    /// one cycle runs: FULL, FULL, PREDICT, PREDICT, CORRECT, PREDICT, then
    /// back to FULL.
    fn get_next_phase(&self) -> TrainingPhase {
        match self.current_phase {
            TrainingPhase::Full => {
                if self.phase_step >= self.config.full_steps {
                    TrainingPhase::Predict
                } else {
                    TrainingPhase::Full
                }
            }
            TrainingPhase::Predict => {
                // Check for cycle completion first, so a correction is never
                // scheduled on the step that returns to FULL (FULL recomputes
                // exact gradients anyway, making that correction redundant)
                if self.phase_step >= self.config.predict_steps {
                    return TrainingPhase::Full;
                }
                // Check for a periodic correction
                if self.phase_step > 0 && self.phase_step % self.config.correct_every == 0 {
                    return TrainingPhase::Correct;
                }
                TrainingPhase::Predict
            }
            TrainingPhase::Correct => {
                // After correction, back to predict or full
                let remaining_predict = self.config.predict_steps.saturating_sub(self.phase_step);
                if remaining_predict > 0 {
                    TrainingPhase::Predict
                } else {
                    TrainingPhase::Full
                }
            }
        }
    }

    /// Handle phase transition.
    fn transition_phase(&mut self, new_phase: TrainingPhase) {
        let old_phase = self.current_phase;
        self.current_phase = new_phase;

        match new_phase {
            TrainingPhase::Full => {
                // Starting new cycle
                self.phase_step = 0;
                self.cycle_count += 1;

                // Apply adaptive scheduling if enabled (adjust_phase_lengths
                // itself requires at least 20 recent losses)
                if self.config.adaptive_phases && self.recent_losses.len() >= 20 {
                    self.adjust_phase_lengths();
                }
            }
            TrainingPhase::Predict => {
                if old_phase == TrainingPhase::Full {
                    // Entering predict from full
                    self.phase_step = 0;
                }
            }
            TrainingPhase::Correct => {
                // Correction is a single step; don't reset phase_step
            }
        }
    }

    /// Adjust phase lengths based on training dynamics.
    ///
    /// Compares the mean of the oldest 10 and newest 10 losses in the
    /// recent-loss window: a rising loss shifts budget toward FULL steps,
    /// while a clearly falling loss allows more PREDICT steps.
    fn adjust_phase_lengths(&mut self) {
        if self.recent_losses.len() < 20 {
            return;
        }

        let losses: Vec<f32> = self.recent_losses.iter().copied().collect();
        let early: f32 = losses[..10].iter().sum::<f32>() / 10.0;
        let late: f32 = losses[losses.len() - 10..].iter().sum::<f32>() / 10.0;

        if late > early * (1.0 + self.config.loss_threshold) {
            // Loss increasing: more full training
            self.config.full_steps = (self.config.full_steps + 5).min(50);
            self.config.predict_steps = self.config.predict_steps.saturating_sub(10).max(10);
        } else if late < early * 0.95 {
            // Loss decreasing well: can use more prediction
            self.config.full_steps = self.config.full_steps.saturating_sub(2).max(5);
            self.config.predict_steps = (self.config.predict_steps + 5).min(100);
        }
    }

    /// Begin a training step. Returns info about current phase.
    ///
    /// # Returns
    ///
    /// Step information including phase and whether phase changed.
    ///
    /// # Errors
    ///
    /// Returns error if phase transition fails.
    pub fn begin_step(&mut self) -> Result<StepInfo> {
        // Check for phase transition
        let next_phase = self.get_next_phase();
        let phase_changed = next_phase != self.current_phase;
        if phase_changed {
            self.transition_phase(next_phase);
        }

        Ok(StepInfo {
            phase: self.current_phase,
            phase_step: self.phase_step,
            total_step: self.total_step,
            cycle: self.cycle_count,
            phase_changed,
        })
    }

    /// Record full gradients after backprop (for FULL or CORRECT phase).
    ///
    /// # Arguments
    ///
    /// * `gradients` - Map of parameter names to gradient tensors
    ///
    /// # Errors
    ///
    /// Returns error if recording fails.
    pub fn record_full_gradients(&mut self, gradients: &HashMap<String, Tensor>) -> Result<()> {
        // Record for prediction
        self.predictor.record_gradient(gradients)?;

        // If in correction phase, compute and apply correction
        if self.current_phase == TrainingPhase::Correct {
            self.predictor.compute_correction(gradients)?;
        }

        Ok(())
    }

    /// Get predicted gradients (for PREDICT phase).
    ///
    /// # Returns
    ///
    /// Map of parameter names to predicted gradient tensors.
    ///
    /// # Errors
    ///
    /// Returns error if prediction fails.
    pub fn get_predicted_gradients(&mut self) -> Result<HashMap<String, Tensor>> {
        self.predictor.predict_gradient()
    }

    /// Apply correction to gradients.
    ///
    /// # Arguments
    ///
    /// * `gradients` - Mutable map of gradients to modify in-place
    ///
    /// # Errors
    ///
    /// Returns error if correction fails.
    pub fn apply_correction(&mut self, gradients: &mut HashMap<String, Tensor>) -> Result<()> {
        self.predictor.apply_correction(gradients)
    }

    /// End the training step.
    ///
    /// # Arguments
    ///
    /// * `loss` - Loss value for this step
    ///
    /// # Errors
    ///
    /// Returns error if tracking fails.
    #[allow(clippy::cast_precision_loss)]
    pub fn end_step(&mut self, loss: f32) -> Result<()> {
        // Track loss in a bounded window of the last 100 steps
        if self.recent_losses.len() >= 100 {
            self.recent_losses.pop_front();
        }
        self.recent_losses.push_back(loss);

        if let Some(phase_losses) = self.phase_losses.get_mut(&self.current_phase) {
            phase_losses.push(loss);
        }

        // Update per-phase step totals
        match self.current_phase {
            TrainingPhase::Full => self.full_steps_taken += 1,
            TrainingPhase::Predict => self.predict_steps_taken += 1,
            TrainingPhase::Correct => self.correct_steps_taken += 1,
        }

        // Update counters
        self.phase_step += 1;
        self.total_step += 1;

        // Speedup = forward passes per backward pass; FULL and CORRECT
        // are the only phases that run backprop
        let total_forward =
            (self.full_steps_taken + self.predict_steps_taken + self.correct_steps_taken) as f32;
        let total_backward = (self.full_steps_taken + self.correct_steps_taken).max(1) as f32;
        self.speedup_ratio = total_forward / total_backward;

        Ok(())
    }

    /// Get current training phase.
    #[must_use]
    pub const fn current_phase(&self) -> TrainingPhase {
        self.current_phase
    }

    /// Get total step count.
    #[must_use]
    pub const fn total_step(&self) -> usize {
        self.total_step
    }

    /// Get cycle count.
    #[must_use]
    pub const fn cycle_count(&self) -> usize {
        self.cycle_count
    }

    /// Get speedup ratio (forward passes per backward pass).
    #[must_use]
    pub const fn speedup_ratio(&self) -> f32 {
        self.speedup_ratio
    }

    /// Get training statistics.
    #[must_use]
    #[allow(clippy::cast_precision_loss)]
    pub fn get_stats(&self) -> TrainerStats {
        let mut phase_avg_losses = HashMap::new();

        for (phase, losses) in &self.phase_losses {
            if !losses.is_empty() {
                // Average over at most the last 100 recorded losses
                let take = losses.len().min(100);
                let recent = &losses[losses.len() - take..];
                let avg: f32 = recent.iter().sum::<f32>() / take as f32;
                phase_avg_losses.insert(*phase, avg);
            }
        }

        TrainerStats {
            total_steps: self.total_step,
            cycles: self.cycle_count,
            speedup: self.speedup_ratio,
            full_steps: self.full_steps_taken,
            predict_steps: self.predict_steps_taken,
            correct_steps: self.correct_steps_taken,
            current_full_steps: self.config.full_steps,
            current_predict_steps: self.config.predict_steps,
            phase_avg_losses,
        }
    }

    /// Reset trainer state.
    pub fn reset(&mut self) -> Result<()> {
        self.predictor.reset();
        self.ternary_accum.reset()?;
        self.current_phase = TrainingPhase::Full;
        self.phase_step = 0;
        self.total_step = 0;
        self.cycle_count = 0;
        self.recent_losses.clear();
        self.speedup_ratio = 1.0;
        self.full_steps_taken = 0;
        self.predict_steps_taken = 0;
        self.correct_steps_taken = 0;

        for losses in self.phase_losses.values_mut() {
            losses.clear();
        }

        Ok(())
    }

    /// Get mutable access to VSA compressor.
    pub fn vsa_compressor_mut(&mut self) -> &mut VSAGradientCompressor {
        &mut self.vsa_compressor
    }

    /// Get mutable access to ternary accumulator.
    pub fn ternary_accumulator_mut(&mut self) -> &mut TernaryGradientAccumulator {
        &mut self.ternary_accum
    }

    /// Check whether full gradients should be computed this step.
    #[must_use]
    pub fn should_compute_full(&self) -> bool {
        matches!(self.current_phase, TrainingPhase::Full | TrainingPhase::Correct)
    }
}

/// Information about current training step.
#[derive(Debug, Clone)]
pub struct StepInfo {
    /// Current training phase.
    pub phase: TrainingPhase,
    /// Step within current phase.
    pub phase_step: usize,
    /// Total training step.
    pub total_step: usize,
    /// Cycle count.
    pub cycle: usize,
    /// Whether phase changed this step.
    pub phase_changed: bool,
}

/// Training statistics.
#[derive(Debug, Clone)]
pub struct TrainerStats {
    /// Total training steps.
    pub total_steps: usize,
    /// Cycle count.
    pub cycles: usize,
    /// Speedup ratio (total steps / backward steps).
    pub speedup: f32,
    /// Full phase steps taken.
    pub full_steps: usize,
    /// Predict phase steps taken.
    pub predict_steps: usize,
    /// Correct phase steps taken.
    pub correct_steps: usize,
    /// Current full steps per cycle.
    pub current_full_steps: usize,
    /// Current predict steps per cycle.
    pub current_predict_steps: usize,
    /// Average loss per phase.
    pub phase_avg_losses: HashMap<TrainingPhase, f32>,
}

impl std::fmt::Display for TrainerStats {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "Steps: {} | Cycles: {} | Speedup: {:.2}x | Full: {} | Predict: {} | Correct: {}",
            self.total_steps,
            self.cycles,
            self.speedup,
            self.full_steps,
            self.predict_steps,
            self.correct_steps
        )
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    fn create_param_shapes() -> Vec<(String, Vec<usize>)> {
        vec![
            ("layer1.weight".to_string(), vec![64, 128]),
            ("layer1.bias".to_string(), vec![64]),
        ]
    }

    fn create_mock_gradients(device: &Device) -> HashMap<String, Tensor> {
        let mut gradients = HashMap::new();
        gradients.insert(
            "layer1.weight".to_string(),
            Tensor::randn(0.0f32, 0.1, (64, 128), device).unwrap(),
        );
        gradients.insert(
            "layer1.bias".to_string(),
            Tensor::randn(0.0f32, 0.1, 64, device).unwrap(),
        );
        gradients
    }

    #[test]
    fn test_trainer_creation() {
        let shapes = create_param_shapes();
        let device = Device::Cpu;
        let config = PhaseConfig::default();

        let trainer = PhaseTrainer::new(&shapes, config, &device).unwrap();
        assert_eq!(trainer.current_phase(), TrainingPhase::Full);
        assert_eq!(trainer.total_step(), 0);
    }

    #[test]
    fn test_phase_transitions() {
        let shapes = create_param_shapes();
        let device = Device::Cpu;
        let config = PhaseConfig::default()
            .with_full_steps(2)
            .with_predict_steps(4)
            .with_correct_every(2);

        let mut trainer = PhaseTrainer::new(&shapes, config, &device).unwrap();
        let gradients = create_mock_gradients(&device);

        // Start in FULL phase
        assert_eq!(trainer.current_phase(), TrainingPhase::Full);

        // Step 1: FULL
        let info = trainer.begin_step().unwrap();
        assert_eq!(info.phase, TrainingPhase::Full);
        trainer.record_full_gradients(&gradients).unwrap();
        trainer.end_step(1.0).unwrap();

        // Step 2: still FULL (phase_step is 1, below full_steps = 2)
        let info = trainer.begin_step().unwrap();
        assert_eq!(info.phase, TrainingPhase::Full);
        trainer.record_full_gradients(&gradients).unwrap();
        trainer.end_step(0.9).unwrap();

        // Step 3: should transition to PREDICT
        let info = trainer.begin_step().unwrap();
        assert!(info.phase_changed);
        assert_eq!(info.phase, TrainingPhase::Predict);
    }
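
    // A sketch of a full cycle including a correction step, under the same
    // settings as test_phase_transitions (full_steps = 2, predict_steps = 4,
    // correct_every = 2). Following get_next_phase, the expected sequence is
    // FULL FULL PREDICT PREDICT CORRECT PREDICT FULL.
    #[test]
    fn test_correction_scheduling() {
        let shapes = create_param_shapes();
        let device = Device::Cpu;
        let config = PhaseConfig::default()
            .with_full_steps(2)
            .with_predict_steps(4)
            .with_correct_every(2);

        let mut trainer = PhaseTrainer::new(&shapes, config, &device).unwrap();
        let gradients = create_mock_gradients(&device);

        let expected = [
            TrainingPhase::Full,
            TrainingPhase::Full,
            TrainingPhase::Predict,
            TrainingPhase::Predict,
            TrainingPhase::Correct,
            TrainingPhase::Predict,
            TrainingPhase::Full,
        ];

        for &phase in &expected {
            let info = trainer.begin_step().unwrap();
            assert_eq!(info.phase, phase);
            if trainer.should_compute_full() {
                trainer.record_full_gradients(&gradients).unwrap();
            } else {
                let _ = trainer.get_predicted_gradients().unwrap();
            }
            trainer.end_step(1.0).unwrap();
        }
    }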

    #[test]
    fn test_speedup_calculation() {
        let shapes = create_param_shapes();
        let device = Device::Cpu;
        let config = PhaseConfig::default()
            .with_full_steps(1)
            .with_predict_steps(3)
            .with_correct_every(10); // No correction in this short test

        let mut trainer = PhaseTrainer::new(&shapes, config, &device).unwrap();
        let gradients = create_mock_gradients(&device);

        // 1 full step
        trainer.begin_step().unwrap();
        trainer.record_full_gradients(&gradients).unwrap();
        trainer.end_step(1.0).unwrap();

        // 3 predict steps
        for _ in 0..3 {
            trainer.begin_step().unwrap();
            let _ = trainer.get_predicted_gradients().unwrap();
            trainer.end_step(0.9).unwrap();
        }

        // Speedup should be 4/1 = 4.0 (4 total steps, 1 backward step)
        assert!((trainer.speedup_ratio() - 4.0).abs() < 0.1);
    }
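
    // Companion sketch to the test above: a CORRECT step runs backprop, so it
    // should count toward the backward-pass denominator of the speedup ratio.
    // With full_steps = 1, predict_steps = 4, correct_every = 2, the first
    // five steps run FULL P P CORRECT P: 5 forward / 2 backward = 2.5x.
    #[test]
    fn test_correct_steps_count_as_backward() {
        let shapes = create_param_shapes();
        let device = Device::Cpu;
        let config = PhaseConfig::default()
            .with_full_steps(1)
            .with_predict_steps(4)
            .with_correct_every(2);

        let mut trainer = PhaseTrainer::new(&shapes, config, &device).unwrap();
        let gradients = create_mock_gradients(&device);

        for _ in 0..5 {
            trainer.begin_step().unwrap();
            if trainer.should_compute_full() {
                trainer.record_full_gradients(&gradients).unwrap();
            } else {
                let _ = trainer.get_predicted_gradients().unwrap();
            }
            trainer.end_step(1.0).unwrap();
        }

        assert!((trainer.speedup_ratio() - 2.5).abs() < 1e-3);
    }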

    #[test]
    fn test_stats() {
        let shapes = create_param_shapes();
        let device = Device::Cpu;
        let config = PhaseConfig::default();

        let mut trainer = PhaseTrainer::new(&shapes, config, &device).unwrap();
        let gradients = create_mock_gradients(&device);

        // Run a few steps
        for i in 0..5 {
            trainer.begin_step().unwrap();
            if trainer.should_compute_full() {
                trainer.record_full_gradients(&gradients).unwrap();
            } else {
                let _ = trainer.get_predicted_gradients().unwrap();
            }
            trainer.end_step(1.0 - i as f32 * 0.1).unwrap();
        }

        let stats = trainer.get_stats();
        assert_eq!(stats.total_steps, 5);
    }

    #[test]
    fn test_reset() {
        let shapes = create_param_shapes();
        let device = Device::Cpu;
        let config = PhaseConfig::default();

        let mut trainer = PhaseTrainer::new(&shapes, config, &device).unwrap();
        let gradients = create_mock_gradients(&device);

        // Run some steps
        trainer.begin_step().unwrap();
        trainer.record_full_gradients(&gradients).unwrap();
        trainer.end_step(1.0).unwrap();

        assert_eq!(trainer.total_step(), 1);

        // Reset
        trainer.reset().unwrap();

        assert_eq!(trainer.total_step(), 0);
        assert_eq!(trainer.current_phase(), TrainingPhase::Full);
    }
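
    // A sketch of the adaptive scheduler in isolation: feed a clearly rising
    // loss window and call the module-private adjust_phase_lengths directly.
    // The exact magnitudes depend on PhaseConfig::default()'s loss_threshold
    // (assumed here to be well below 4.0), so only the direction is asserted.
    #[test]
    fn test_adaptive_shifts_toward_full_on_rising_loss() {
        let shapes = create_param_shapes();
        let device = Device::Cpu;
        let config = PhaseConfig::default();

        let mut trainer = PhaseTrainer::new(&shapes, config, &device).unwrap();
        let full_before = trainer.config.full_steps;

        // Oldest 10 losses around 1.0, newest 10 around 5.0: a rising curve.
        for i in 0..20 {
            trainer.recent_losses.push_back(if i < 10 { 1.0 } else { 5.0 });
        }
        trainer.adjust_phase_lengths();

        // A rising loss should never shrink the FULL budget.
        assert!(trainer.config.full_steps >= full_before);
    }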
668    }
669}