use std::collections::{HashMap, VecDeque};

use candle_core::{Device, Tensor};

use crate::config::PhaseConfig;
use crate::error::Result;
use crate::prediction::GradientPredictor;
use crate::ternary::TernaryGradientAccumulator;
use crate::vsa::VSAGradientCompressor;

/// Emits a one-time warning when the trainer is constructed on the CPU.
fn warn_cpu_fallback(device: &Device) {
    static WARN_ONCE: std::sync::Once = std::sync::Once::new();
    if matches!(device, Device::Cpu) {
        WARN_ONCE.call_once(|| {
            eprintln!(
                "vsa-optim-rs: CPU device in use. CUDA is the intended default; use Device::cuda_if_available(0) when possible."
            );
        });
    }
}

/// Phases of the predict-correct training cycle.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TrainingPhase {
    /// Full forward and backward passes; real gradients are recorded.
    Full,
    /// Forward-only steps that run on predicted gradients.
    Predict,
    /// A full-gradient step that corrects the predictor mid-PREDICT.
    Correct,
}

impl std::fmt::Display for TrainingPhase {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Full => write!(f, "FULL"),
            Self::Predict => write!(f, "PREDICT"),
            Self::Correct => write!(f, "CORRECT"),
        }
    }
}

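// A usage sketch for the phase cycle (hedged: `backward_pass`, `apply_to_optimizer`,
// `loss`, and `max_steps` below are hypothetical caller-side names, so the doctest
// is marked `ignore`):
/// Orchestrates predict-correct training: FULL steps record real gradients,
/// PREDICT steps reuse predicted gradients without a backward pass, and
/// periodic CORRECT steps re-anchor the predictor with real gradients.
///
/// # Example
///
/// ```ignore
/// let mut trainer = PhaseTrainer::new(&param_shapes, PhaseConfig::default(), &device)?;
/// for _ in 0..max_steps {
///     trainer.begin_step()?;
///     let grads = if trainer.should_compute_full() {
///         let g = backward_pass()?;          // real forward + backward
///         trainer.record_full_gradients(&g)?;
///         g
///     } else {
///         trainer.get_predicted_gradients()? // forward pass only
///     };
///     apply_to_optimizer(&grads)?;
///     trainer.end_step(loss)?;
/// }
/// ```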
pub struct PhaseTrainer {
    config: PhaseConfig,
    device: Device,

    /// Predicts gradients for PREDICT steps and builds corrections on CORRECT steps.
    predictor: GradientPredictor,

    /// Ternary gradient accumulator (exposed via `ternary_accumulator_mut`).
    ternary_accum: TernaryGradientAccumulator,

    /// VSA gradient compressor (exposed via `vsa_compressor_mut`).
    vsa_compressor: VSAGradientCompressor,

    /// Phase the trainer is currently in.
    current_phase: TrainingPhase,

    /// Step index within the current phase (a PREDICT span keeps counting
    /// through its interleaved CORRECT steps).
    phase_step: usize,

    /// Total steps since construction or the last `reset`.
    total_step: usize,

    /// Number of returns to the FULL phase.
    cycle_count: usize,

    /// Loss history per phase, kept for statistics.
    phase_losses: HashMap<TrainingPhase, Vec<f32>>,

    /// Sliding window of the last 100 losses, used for adaptive phase sizing.
    recent_losses: VecDeque<f32>,

    /// Ratio of forward passes to backward passes taken so far.
    speedup_ratio: f32,

    /// Per-phase step counters.
    full_steps_taken: usize,
    predict_steps_taken: usize,
    correct_steps_taken: usize,

    /// Parameter names and shapes the trainer was built for.
    param_shapes: Vec<(String, Vec<usize>)>,
}

impl PhaseTrainer {
    /// Builds a trainer for the given parameter shapes.
    ///
    /// Constructs the gradient predictor, ternary accumulator, and VSA
    /// compressor from the sub-configs in `config`, and warns once if
    /// `device` is the CPU.
    pub fn new(
        param_shapes: &[(String, Vec<usize>)],
        config: PhaseConfig,
        device: &Device,
    ) -> Result<Self> {
        warn_cpu_fallback(device);
        let predictor = GradientPredictor::new(
            param_shapes,
            config.prediction_config.clone(),
            device,
        )?;

        let ternary_accum = TernaryGradientAccumulator::new(
            param_shapes,
            config.ternary_config.clone(),
            device,
        )?;

        // Total element count across all tensors, for the flat VSA compressor.
        let param_count: usize = param_shapes
            .iter()
            .map(|(_, s)| s.iter().product::<usize>())
            .sum();
        let vsa_compressor = VSAGradientCompressor::new(param_count, config.vsa_config.clone());

        let mut phase_losses = HashMap::new();
        phase_losses.insert(TrainingPhase::Full, Vec::new());
        phase_losses.insert(TrainingPhase::Predict, Vec::new());
        phase_losses.insert(TrainingPhase::Correct, Vec::new());

        Ok(Self {
            config,
            device: device.clone(),
            predictor,
            ternary_accum,
            vsa_compressor,
            current_phase: TrainingPhase::Full,
            phase_step: 0,
            total_step: 0,
            cycle_count: 0,
            phase_losses,
            recent_losses: VecDeque::with_capacity(100),
            speedup_ratio: 1.0,
            full_steps_taken: 0,
            predict_steps_taken: 0,
            correct_steps_taken: 0,
            param_shapes: param_shapes.to_vec(),
        })
    }
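    // A hand-traced schedule from this function plus `transition_phase`,
    // assuming full_steps = 2, predict_steps = 4, correct_every = 2 (the
    // values used in `test_phase_transitions` below):
    //
    //   step:  1     2     3        4        5        6        7        8
    //   phase: FULL  FULL  PREDICT  PREDICT  CORRECT  PREDICT  CORRECT  FULL
    //
    /// Computes the phase for the upcoming step from the current phase and
    /// `phase_step`, without mutating any state.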
    fn get_next_phase(&self) -> TrainingPhase {
        match self.current_phase {
            TrainingPhase::Full => {
                if self.phase_step >= self.config.full_steps {
                    TrainingPhase::Predict
                } else {
                    TrainingPhase::Full
                }
            }
            TrainingPhase::Predict => {
                // The correction check comes before the phase-length check,
                // so a CORRECT step can fire on the last step of a PREDICT span.
                if self.phase_step > 0 && self.phase_step % self.config.correct_every == 0 {
                    return TrainingPhase::Correct;
                }
                if self.phase_step >= self.config.predict_steps {
                    return TrainingPhase::Full;
                }
                TrainingPhase::Predict
            }
            TrainingPhase::Correct => {
                // Resume the PREDICT span if it has steps left; otherwise
                // start a new FULL phase.
                let remaining_predict = self.config.predict_steps.saturating_sub(self.phase_step);
                if remaining_predict > 0 {
                    TrainingPhase::Predict
                } else {
                    TrainingPhase::Full
                }
            }
        }
    }

    fn transition_phase(&mut self, new_phase: TrainingPhase) {
        let old_phase = self.current_phase;
        self.current_phase = new_phase;

        match new_phase {
            TrainingPhase::Full => {
                // A new cycle begins each time we return to FULL.
                self.phase_step = 0;
                self.cycle_count += 1;

                // Re-tune phase lengths once a full 20-loss window is available.
                if self.config.adaptive_phases && self.recent_losses.len() >= 20 {
                    self.adjust_phase_lengths();
                }
            }
            TrainingPhase::Predict => {
                // Only reset the counter when entering PREDICT from FULL;
                // returning from CORRECT resumes the span where it left off.
                if old_phase == TrainingPhase::Full {
                    self.phase_step = 0;
                }
            }
            TrainingPhase::Correct => {
                // CORRECT keeps the surrounding PREDICT span's counter running.
            }
        }
    }

    /// Adapts phase lengths from the trend of the recent-loss window.
    ///
    /// Compares the mean of the 10 oldest and 10 newest recent losses. If the
    /// newest mean is worse by more than `loss_threshold` (e.g., with a
    /// threshold of 0.1, an early mean of 1.0 against a late mean above 1.1),
    /// prediction is hurting: lengthen FULL (up to 50) and shorten PREDICT
    /// (down to 10). If the newest mean improved by at least 5%, prediction
    /// is working: shorten FULL (down to 5) and lengthen PREDICT (up to 100).
    fn adjust_phase_lengths(&mut self) {
        if self.recent_losses.len() < 20 {
            return;
        }

        let losses: Vec<f32> = self.recent_losses.iter().copied().collect();
        let early: f32 = losses[..10].iter().sum::<f32>() / 10.0;
        let late: f32 = losses[losses.len() - 10..].iter().sum::<f32>() / 10.0;

        if late > early * (1.0 + self.config.loss_threshold) {
            // Loss is rising: rely more on full gradients.
            self.config.full_steps = (self.config.full_steps + 5).min(50);
            self.config.predict_steps = self.config.predict_steps.saturating_sub(10).max(10);
        } else if late < early * 0.95 {
            // Loss is falling: predicted gradients are working; use more of them.
            self.config.full_steps = self.config.full_steps.saturating_sub(2).max(5);
            self.config.predict_steps = (self.config.predict_steps + 5).min(100);
        }
    }

    /// Advances the phase machine and reports what kind of step to run next.
    ///
    /// Call this at the start of every training step, then branch on
    /// [`StepInfo::phase`] (or [`Self::should_compute_full`]) to decide
    /// whether a backward pass is needed.
    pub fn begin_step(&mut self) -> Result<StepInfo> {
        let next_phase = self.get_next_phase();
        let phase_changed = next_phase != self.current_phase;
        if phase_changed {
            self.transition_phase(next_phase);
        }

        Ok(StepInfo {
            phase: self.current_phase,
            phase_step: self.phase_step,
            total_step: self.total_step,
            cycle: self.cycle_count,
            phase_changed,
        })
    }

    /// Records real gradients from a full backward pass.
    ///
    /// The gradients feed the predictor's history; during CORRECT steps they
    /// are also compared against the prediction to form a correction term.
    pub fn record_full_gradients(&mut self, gradients: &HashMap<String, Tensor>) -> Result<()> {
        self.predictor.record_gradient(gradients)?;

        if self.current_phase == TrainingPhase::Correct {
            self.predictor.compute_correction(gradients)?;
        }

        Ok(())
    }

    /// Returns predicted gradients for a forward-only PREDICT step.
    pub fn get_predicted_gradients(&mut self) -> Result<HashMap<String, Tensor>> {
        self.predictor.predict_gradient()
    }

    /// Applies the most recent correction term to `gradients` in place.
    pub fn apply_correction(&mut self, gradients: &mut HashMap<String, Tensor>) -> Result<()> {
        self.predictor.apply_correction(gradients)
    }

    /// Finishes the current step: records `loss`, bumps the step counters,
    /// and refreshes the speedup estimate.
    ///
    /// The speedup is the ratio of forward passes (all steps) to backward
    /// passes (FULL and CORRECT steps only). For example, one FULL step
    /// followed by three PREDICT steps gives 4 / 1 = 4.0x.
    #[allow(clippy::cast_precision_loss)]
    pub fn end_step(&mut self, loss: f32) -> Result<()> {
        // Keep a bounded window of the most recent 100 losses.
        if self.recent_losses.len() >= 100 {
            self.recent_losses.pop_front();
        }
        self.recent_losses.push_back(loss);

        if let Some(phase_losses) = self.phase_losses.get_mut(&self.current_phase) {
            phase_losses.push(loss);
        }

        match self.current_phase {
            TrainingPhase::Full => self.full_steps_taken += 1,
            TrainingPhase::Predict => self.predict_steps_taken += 1,
            TrainingPhase::Correct => self.correct_steps_taken += 1,
        }

        self.phase_step += 1;
        self.total_step += 1;

        // Every step runs a forward pass; only FULL and CORRECT run backward.
        let total_forward =
            (self.full_steps_taken + self.predict_steps_taken + self.correct_steps_taken) as f32;
        let total_backward = (self.full_steps_taken + self.correct_steps_taken).max(1) as f32;
        self.speedup_ratio = total_forward / total_backward;

        Ok(())
    }

    /// Phase the trainer is currently in.
    #[must_use]
    pub const fn current_phase(&self) -> TrainingPhase {
        self.current_phase
    }

    /// Total steps taken since construction or the last `reset`.
    #[must_use]
    pub const fn total_step(&self) -> usize {
        self.total_step
    }

    /// Number of completed returns to the FULL phase.
    #[must_use]
    pub const fn cycle_count(&self) -> usize {
        self.cycle_count
    }

    /// Estimated forward-to-backward pass ratio.
    #[must_use]
    pub const fn speedup_ratio(&self) -> f32 {
        self.speedup_ratio
    }

    /// Returns a snapshot of training statistics, including per-phase average
    /// loss over each phase's most recent (up to) 100 steps.
    #[must_use]
    #[allow(clippy::cast_precision_loss)]
    pub fn get_stats(&self) -> TrainerStats {
        let mut phase_avg_losses = HashMap::new();

        for (phase, losses) in &self.phase_losses {
            if !losses.is_empty() {
                let count = losses.len().min(100);
                let avg = losses.iter().rev().take(100).sum::<f32>() / count as f32;
                phase_avg_losses.insert(*phase, avg);
            }
        }

        TrainerStats {
            total_steps: self.total_step,
            cycles: self.cycle_count,
            speedup: self.speedup_ratio,
            full_steps: self.full_steps_taken,
            predict_steps: self.predict_steps_taken,
            correct_steps: self.correct_steps_taken,
            current_full_steps: self.config.full_steps,
            current_predict_steps: self.config.predict_steps,
            phase_avg_losses,
        }
    }

    /// Resets the trainer to its initial state: clears the predictor and
    /// ternary accumulator, returns to the FULL phase, and zeroes all
    /// counters and loss history.
    pub fn reset(&mut self) -> Result<()> {
        self.predictor.reset();
        self.ternary_accum.reset()?;
        self.current_phase = TrainingPhase::Full;
        self.phase_step = 0;
        self.total_step = 0;
        self.cycle_count = 0;
        self.recent_losses.clear();
        self.speedup_ratio = 1.0;
        self.full_steps_taken = 0;
        self.predict_steps_taken = 0;
        self.correct_steps_taken = 0;

        for losses in self.phase_losses.values_mut() {
            losses.clear();
        }

        Ok(())
    }

    /// Mutable access to the VSA gradient compressor.
    pub fn vsa_compressor_mut(&mut self) -> &mut VSAGradientCompressor {
        &mut self.vsa_compressor
    }

    /// Mutable access to the ternary gradient accumulator.
    pub fn ternary_accumulator_mut(&mut self) -> &mut TernaryGradientAccumulator {
        &mut self.ternary_accum
    }

    /// Whether the current step needs a real backward pass (FULL or CORRECT).
    #[must_use]
    pub fn should_compute_full(&self) -> bool {
        matches!(self.current_phase, TrainingPhase::Full | TrainingPhase::Correct)
    }
}

/// Per-step report returned by [`PhaseTrainer::begin_step`].
#[derive(Debug, Clone)]
pub struct StepInfo {
    /// Phase to run for this step.
    pub phase: TrainingPhase,
    /// Step index within the current phase.
    pub phase_step: usize,
    /// Total steps taken so far.
    pub total_step: usize,
    /// Current cycle count.
    pub cycle: usize,
    /// True if this step begins a new phase.
    pub phase_changed: bool,
}

/// Snapshot of trainer statistics from [`PhaseTrainer::get_stats`].
#[derive(Debug, Clone)]
pub struct TrainerStats {
    /// Total steps taken.
    pub total_steps: usize,
    /// Completed cycles.
    pub cycles: usize,
    /// Estimated forward-to-backward pass ratio.
    pub speedup: f32,
    /// FULL steps taken.
    pub full_steps: usize,
    /// PREDICT steps taken.
    pub predict_steps: usize,
    /// CORRECT steps taken.
    pub correct_steps: usize,
    /// Current (possibly adapted) FULL phase length.
    pub current_full_steps: usize,
    /// Current (possibly adapted) PREDICT phase length.
    pub current_predict_steps: usize,
    /// Average loss per phase over each phase's recent steps.
    pub phase_avg_losses: HashMap<TrainingPhase, f32>,
}

impl std::fmt::Display for TrainerStats {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "Steps: {} | Cycles: {} | Speedup: {:.2}x | Full: {} | Predict: {} | Correct: {}",
            self.total_steps,
            self.cycles,
            self.speedup,
            self.full_steps,
            self.predict_steps,
            self.correct_steps
        )
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    fn create_param_shapes() -> Vec<(String, Vec<usize>)> {
        vec![
            ("layer1.weight".to_string(), vec![64, 128]),
            ("layer1.bias".to_string(), vec![64]),
        ]
    }

    fn create_mock_gradients(device: &Device) -> HashMap<String, Tensor> {
        let mut gradients = HashMap::new();
        gradients.insert(
            "layer1.weight".to_string(),
            Tensor::randn(0.0f32, 0.1, (64, 128), device).unwrap(),
        );
        gradients.insert(
            "layer1.bias".to_string(),
            Tensor::randn(0.0f32, 0.1, 64, device).unwrap(),
        );
        gradients
    }

    #[test]
    fn test_trainer_creation() {
        let shapes = create_param_shapes();
        let device = Device::Cpu;
        let config = PhaseConfig::default();

        let trainer = PhaseTrainer::new(&shapes, config, &device).unwrap();
        assert_eq!(trainer.current_phase(), TrainingPhase::Full);
        assert_eq!(trainer.total_step(), 0);
    }

    #[test]
    fn test_phase_transitions() {
        let shapes = create_param_shapes();
        let device = Device::Cpu;
        let config = PhaseConfig::default()
            .with_full_steps(2)
            .with_predict_steps(4)
            .with_correct_every(2);

        let mut trainer = PhaseTrainer::new(&shapes, config, &device).unwrap();
        let gradients = create_mock_gradients(&device);

        // Starts in FULL.
        assert_eq!(trainer.current_phase(), TrainingPhase::Full);

        // FULL step 1.
        let info = trainer.begin_step().unwrap();
        assert_eq!(info.phase, TrainingPhase::Full);
        trainer.record_full_gradients(&gradients).unwrap();
        trainer.end_step(1.0).unwrap();

        // FULL step 2.
        let info = trainer.begin_step().unwrap();
        assert_eq!(info.phase, TrainingPhase::Full);
        trainer.record_full_gradients(&gradients).unwrap();
        trainer.end_step(0.9).unwrap();

        // After `full_steps` FULL steps, the trainer switches to PREDICT.
        let info = trainer.begin_step().unwrap();
        assert!(info.phase_changed);
        assert_eq!(info.phase, TrainingPhase::Predict);
    }

    #[test]
    fn test_speedup_calculation() {
        let shapes = create_param_shapes();
        let device = Device::Cpu;
        let config = PhaseConfig::default()
            .with_full_steps(1)
            .with_predict_steps(3)
            .with_correct_every(10);

        let mut trainer = PhaseTrainer::new(&shapes, config, &device).unwrap();
        let gradients = create_mock_gradients(&device);

        // One FULL step (forward + backward).
        trainer.begin_step().unwrap();
        trainer.record_full_gradients(&gradients).unwrap();
        trainer.end_step(1.0).unwrap();

        // Three PREDICT steps (forward only).
        for _ in 0..3 {
            trainer.begin_step().unwrap();
            let _ = trainer.get_predicted_gradients().unwrap();
            trainer.end_step(0.9).unwrap();
        }

        // 4 forward passes / 1 backward pass = 4x.
        assert!((trainer.speedup_ratio() - 4.0).abs() < 0.1);
    }
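
    // An extra check, not in the original suite: `should_compute_full` should
    // report true exactly on FULL (and CORRECT) steps. Assumes the same
    // `PhaseConfig` builder methods used by the tests above.
    #[test]
    fn test_should_compute_full_flags() {
        let shapes = create_param_shapes();
        let device = Device::Cpu;
        let config = PhaseConfig::default()
            .with_full_steps(1)
            .with_predict_steps(3)
            .with_correct_every(10);

        let mut trainer = PhaseTrainer::new(&shapes, config, &device).unwrap();
        let gradients = create_mock_gradients(&device);

        // First step is FULL: a real backward pass is required.
        trainer.begin_step().unwrap();
        assert!(trainer.should_compute_full());
        trainer.record_full_gradients(&gradients).unwrap();
        trainer.end_step(1.0).unwrap();

        // Second step is PREDICT: forward only.
        trainer.begin_step().unwrap();
        assert!(!trainer.should_compute_full());
    }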

    #[test]
    fn test_stats() {
        let shapes = create_param_shapes();
        let device = Device::Cpu;
        let config = PhaseConfig::default();

        let mut trainer = PhaseTrainer::new(&shapes, config, &device).unwrap();
        let gradients = create_mock_gradients(&device);

        for i in 0..5 {
            trainer.begin_step().unwrap();
            if trainer.should_compute_full() {
                trainer.record_full_gradients(&gradients).unwrap();
            } else {
                let _ = trainer.get_predicted_gradients().unwrap();
            }
            trainer.end_step(1.0 - i as f32 * 0.1).unwrap();
        }

        let stats = trainer.get_stats();
        assert_eq!(stats.total_steps, 5);
    }

    #[test]
    fn test_reset() {
        let shapes = create_param_shapes();
        let device = Device::Cpu;
        let config = PhaseConfig::default();

        let mut trainer = PhaseTrainer::new(&shapes, config, &device).unwrap();
        let gradients = create_mock_gradients(&device);

        trainer.begin_step().unwrap();
        trainer.record_full_gradients(&gradients).unwrap();
        trainer.end_step(1.0).unwrap();

        assert_eq!(trainer.total_step(), 1);

        trainer.reset().unwrap();

        assert_eq!(trainer.total_step(), 0);
        assert_eq!(trainer.current_phase(), TrainingPhase::Full);
    }
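
    // A scheduling check, not in the original suite: with `correct_every = 2`,
    // the PREDICT span should yield a CORRECT step once `phase_step` reaches a
    // positive multiple of 2. Mirrors the call pattern of the tests above.
    #[test]
    fn test_correct_phase_scheduling() {
        let shapes = create_param_shapes();
        let device = Device::Cpu;
        let config = PhaseConfig::default()
            .with_full_steps(1)
            .with_predict_steps(4)
            .with_correct_every(2);

        let mut trainer = PhaseTrainer::new(&shapes, config, &device).unwrap();
        let gradients = create_mock_gradients(&device);

        // Step 1: FULL.
        trainer.begin_step().unwrap();
        trainer.record_full_gradients(&gradients).unwrap();
        trainer.end_step(1.0).unwrap();

        // Steps 2-3: PREDICT at phase_step 0 and 1.
        for _ in 0..2 {
            let info = trainer.begin_step().unwrap();
            assert_eq!(info.phase, TrainingPhase::Predict);
            let _ = trainer.get_predicted_gradients().unwrap();
            trainer.end_step(0.9).unwrap();
        }

        // Step 4: phase_step (2) is a positive multiple of correct_every (2),
        // so a CORRECT step is scheduled and a backward pass is required.
        let info = trainer.begin_step().unwrap();
        assert_eq!(info.phase, TrainingPhase::Correct);
        assert!(trainer.should_compute_full());
    }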
}