//! Mixed precision training utilities: precision modes, loss scaling,
//! configuration, per-step state tracking, and gradient checkpointing.

use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use thiserror::Error;

/// Errors that can occur during mixed precision training.
#[derive(Error, Debug, Clone, PartialEq)]
pub enum MixedPrecisionError {
    #[error("Loss scale overflow: scale={0}")]
    LossScaleOverflow(f64),

    #[error("Loss scale underflow: scale={0}")]
    LossScaleUnderflow(f64),

    #[error("Gradient overflow detected in {0} gradients")]
    GradientOverflow(usize),

    #[error("NaN detected in gradients")]
    GradientNaN,

    #[error("Unsupported precision mode: {0:?}")]
    UnsupportedPrecisionMode(PrecisionMode),

    #[error("Mixed precision not supported by backend")]
    NotSupported,

    #[error("Numerical instability detected: {0}")]
    NumericalInstability(String),
}

/// Numeric precision used for parameters and computation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum PrecisionMode {
    /// 32-bit IEEE 754 single precision.
    FP32,
    /// 16-bit IEEE 754 half precision.
    FP16,
    /// 16-bit brain floating point (truncated FP32 mantissa).
    BF16,
    /// 64-bit IEEE 754 double precision.
    FP64,
    /// 8-bit floating point.
    FP8,
}

impl PrecisionMode {
    /// Memory footprint of a single element in bytes.
    pub fn bytes_per_element(&self) -> usize {
        match self {
            PrecisionMode::FP64 => 8,
            PrecisionMode::FP32 => 4,
            PrecisionMode::FP16 | PrecisionMode::BF16 => 2,
            PrecisionMode::FP8 => 1,
        }
    }

    /// Whether this mode is a reduced-precision compute type that is
    /// normally paired with higher-precision master weights.
    pub fn is_mixed_precision(&self) -> bool {
        matches!(
            self,
            PrecisionMode::FP16 | PrecisionMode::BF16 | PrecisionMode::FP8
        )
    }

    /// Conventional dtype name for this mode.
    pub fn name(&self) -> &'static str {
        match self {
            PrecisionMode::FP32 => "float32",
            PrecisionMode::FP16 => "float16",
            PrecisionMode::BF16 => "bfloat16",
            PrecisionMode::FP64 => "float64",
            PrecisionMode::FP8 => "float8",
        }
    }
}

/// Strategy used to scale the loss before backpropagation.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum LossScalingStrategy {
    /// No scaling (scale is fixed at 1.0).
    None,

    /// A fixed scale applied on every step.
    Static { scale: f64 },

    /// A scale that grows after a run of stable steps and backs off on overflow.
    Dynamic {
        /// Initial loss scale.
        init_scale: f64,
        /// Multiplier applied after `growth_interval` stable steps.
        growth_factor: f64,
        /// Multiplier applied when an overflow is detected.
        backoff_factor: f64,
        /// Number of consecutive stable steps before the scale grows.
        growth_interval: usize,
    },
}

impl Default for LossScalingStrategy {
    fn default() -> Self {
        LossScalingStrategy::Dynamic {
            init_scale: 65536.0,
            growth_factor: 2.0,
            backoff_factor: 0.5,
            growth_interval: 2000,
        }
    }
}

/// Tracks the current loss scale and applies the configured scaling strategy.
#[derive(Debug, Clone)]
pub struct LossScaler {
    strategy: LossScalingStrategy,
    current_scale: f64,
    growth_tracker: usize,
    overflow_count: usize,
    total_steps: usize,
}

impl LossScaler {
    /// Creates a scaler with the initial scale implied by the strategy.
    pub fn new(strategy: LossScalingStrategy) -> Self {
        let current_scale = match &strategy {
            LossScalingStrategy::None => 1.0,
            LossScalingStrategy::Static { scale } => *scale,
            LossScalingStrategy::Dynamic { init_scale, .. } => *init_scale,
        };

        Self {
            strategy,
            current_scale,
            growth_tracker: 0,
            overflow_count: 0,
            total_steps: 0,
        }
    }

    /// Current loss scale.
    pub fn scale(&self) -> f64 {
        self.current_scale
    }

    /// Multiplies the loss by the current scale.
    pub fn scale_loss(&self, loss: f64) -> f64 {
        loss * self.current_scale
    }

    /// Divides every gradient by the current scale, in place.
    pub fn unscale_gradients(&self, grads: &mut HashMap<String, f64>) {
        let inv_scale = 1.0 / self.current_scale;
        for grad in grads.values_mut() {
            *grad *= inv_scale;
        }
    }

    /// Checks gradients for NaN or infinity, reporting NaN first.
    pub fn check_overflow(&self, grads: &HashMap<String, f64>) -> Result<(), MixedPrecisionError> {
        let mut has_nan = false;
        let mut has_inf = false;

        for grad in grads.values() {
            if grad.is_nan() {
                has_nan = true;
            }
            if grad.is_infinite() {
                has_inf = true;
            }
        }

        if has_nan {
            return Err(MixedPrecisionError::GradientNaN);
        }
        if has_inf {
            return Err(MixedPrecisionError::GradientOverflow(
                grads.values().filter(|g| g.is_infinite()).count(),
            ));
        }

        Ok(())
    }

    /// Advances the scaler by one step, backing off on overflow and growing
    /// after `growth_interval` consecutive stable steps.
    pub fn update(&mut self, found_overflow: bool) -> Result<(), MixedPrecisionError> {
        self.total_steps += 1;

        match &self.strategy {
            LossScalingStrategy::None | LossScalingStrategy::Static { .. } => {}
            LossScalingStrategy::Dynamic {
                growth_factor,
                backoff_factor,
                growth_interval,
                ..
            } => {
                if found_overflow {
                    self.current_scale *= backoff_factor;
                    self.growth_tracker = 0;
                    self.overflow_count += 1;

                    if self.current_scale < 1.0 {
                        return Err(MixedPrecisionError::LossScaleUnderflow(self.current_scale));
                    }
                } else {
                    self.growth_tracker += 1;

                    if self.growth_tracker >= *growth_interval {
                        self.current_scale *= growth_factor;
                        self.growth_tracker = 0;

                        if self.current_scale > 1e10 {
                            return Err(MixedPrecisionError::LossScaleOverflow(self.current_scale));
                        }
                    }
                }
            }
        }

        Ok(())
    }

    /// Summary statistics for monitoring.
    pub fn stats(&self) -> LossScalerStats {
        LossScalerStats {
            current_scale: self.current_scale,
            overflow_count: self.overflow_count,
            total_steps: self.total_steps,
            overflow_rate: self.overflow_count as f64 / self.total_steps.max(1) as f64,
            growth_tracker: self.growth_tracker,
        }
    }

    /// Restores the initial scale and clears all counters.
    pub fn reset(&mut self) {
        let init_scale = match &self.strategy {
            LossScalingStrategy::None => 1.0,
            LossScalingStrategy::Static { scale } => *scale,
            LossScalingStrategy::Dynamic { init_scale, .. } => *init_scale,
        };

        self.current_scale = init_scale;
        self.growth_tracker = 0;
        self.overflow_count = 0;
        self.total_steps = 0;
    }
}

/// Snapshot of a `LossScaler`'s state.
#[derive(Debug, Clone, PartialEq)]
pub struct LossScalerStats {
    pub current_scale: f64,
    pub overflow_count: usize,
    pub total_steps: usize,
    pub overflow_rate: f64,
    pub growth_tracker: usize,
}

/// Configuration for mixed precision training.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct MixedPrecisionConfig {
    /// Precision used for forward/backward computation.
    pub compute_dtype: PrecisionMode,

    /// Precision used to store parameters (master weights).
    pub param_dtype: PrecisionMode,

    /// Loss scaling strategy.
    pub loss_scaling: LossScalingStrategy,

    /// Recompute activations during the backward pass to save memory.
    pub gradient_checkpointing: bool,

    /// Enable gradient clipping.
    pub gradient_clipping: bool,

    /// Maximum gradient norm when clipping is enabled.
    pub max_gradient_norm: f64,

    /// Run NaN/Inf stability checks on gradients.
    pub stability_checks: bool,

    /// Skip the optimizer step when an overflow is detected.
    pub skip_on_overflow: bool,

    /// Keep a full-precision master copy of the weights.
    pub use_master_weights: bool,
}

impl Default for MixedPrecisionConfig {
    fn default() -> Self {
        Self {
            compute_dtype: PrecisionMode::FP16,
            param_dtype: PrecisionMode::FP32,
            loss_scaling: LossScalingStrategy::default(),
            gradient_checkpointing: false,
            gradient_clipping: true,
            max_gradient_norm: 1.0,
            stability_checks: true,
            skip_on_overflow: true,
            use_master_weights: true,
        }
    }
}

impl MixedPrecisionConfig {
    /// Creates a config with the given compute and parameter dtypes and
    /// defaults for everything else.
    pub fn new(compute_dtype: PrecisionMode, param_dtype: PrecisionMode) -> Self {
        Self {
            compute_dtype,
            param_dtype,
            ..Default::default()
        }
    }

    pub fn with_compute_dtype(mut self, dtype: PrecisionMode) -> Self {
        self.compute_dtype = dtype;
        self
    }

    pub fn with_param_dtype(mut self, dtype: PrecisionMode) -> Self {
        self.param_dtype = dtype;
        self
    }

    pub fn with_loss_scaling(mut self, strategy: LossScalingStrategy) -> Self {
        self.loss_scaling = strategy;
        self
    }

    pub fn with_gradient_checkpointing(mut self, enabled: bool) -> Self {
        self.gradient_checkpointing = enabled;
        self
    }

    pub fn with_gradient_clipping(mut self, enabled: bool, max_norm: f64) -> Self {
        self.gradient_clipping = enabled;
        self.max_gradient_norm = max_norm;
        self
    }

    pub fn with_stability_checks(mut self, enabled: bool) -> Self {
        self.stability_checks = enabled;
        self
    }

    pub fn with_master_weights(mut self, enabled: bool) -> Self {
        self.use_master_weights = enabled;
        self
    }

    /// FP16 compute with FP32 master weights.
    pub fn fp16() -> Self {
        Self::new(PrecisionMode::FP16, PrecisionMode::FP32)
    }

    /// BF16 compute with FP32 master weights.
    pub fn bf16() -> Self {
        Self::new(PrecisionMode::BF16, PrecisionMode::FP32)
    }

    /// FP8 compute with FP32 master weights.
    pub fn fp8() -> Self {
        Self::new(PrecisionMode::FP8, PrecisionMode::FP32)
    }

    /// Validates the configuration, rejecting parameter dtypes that are less
    /// precise than the compute dtype and ill-formed dynamic scaling settings.
    pub fn validate(&self) -> Result<(), MixedPrecisionError> {
        let compute_bytes = self.compute_dtype.bytes_per_element();
        let param_bytes = self.param_dtype.bytes_per_element();

        if param_bytes < compute_bytes {
            return Err(MixedPrecisionError::NumericalInstability(format!(
                "Parameter dtype ({:?}) should be at least as precise as compute dtype ({:?})",
                self.param_dtype, self.compute_dtype
            )));
        }

        if let LossScalingStrategy::Dynamic {
            init_scale,
            growth_factor,
            backoff_factor,
            ..
        } = &self.loss_scaling
        {
            if *init_scale <= 0.0 {
                return Err(MixedPrecisionError::LossScaleUnderflow(*init_scale));
            }
            if *growth_factor <= 1.0 {
                return Err(MixedPrecisionError::NumericalInstability(format!(
                    "Growth factor must be > 1.0, got {}",
                    growth_factor
                )));
            }
            if *backoff_factor >= 1.0 || *backoff_factor <= 0.0 {
                return Err(MixedPrecisionError::NumericalInstability(format!(
                    "Backoff factor must be in (0, 1), got {}",
                    backoff_factor
                )));
            }
        }

        Ok(())
    }
}

/// Runtime state for mixed precision training.
#[derive(Debug, Clone)]
pub struct MixedPrecisionState {
    /// The active configuration.
    pub config: MixedPrecisionConfig,

    /// Loss scaler driven by `config.loss_scaling`.
    pub scaler: LossScaler,

    /// Full-precision master copy of the parameters (if enabled).
    pub master_weights: HashMap<String, Vec<f64>>,

    /// Steps whose gradients were applied.
    pub successful_steps: usize,

    /// Steps skipped because of gradient overflow.
    pub skipped_steps: usize,

    /// Total steps processed.
    pub step: usize,
}

impl MixedPrecisionState {
    /// Validates the config and builds the initial state.
    pub fn new(config: MixedPrecisionConfig) -> Result<Self, MixedPrecisionError> {
        config.validate()?;

        Ok(Self {
            scaler: LossScaler::new(config.loss_scaling.clone()),
            config,
            master_weights: HashMap::new(),
            successful_steps: 0,
            skipped_steps: 0,
            step: 0,
        })
    }

    /// Stores a full-precision copy of the parameters when master weights are enabled.
    pub fn init_master_weights(&mut self, params: &HashMap<String, Vec<f64>>) {
        if self.config.use_master_weights {
            self.master_weights = params.clone();
        }
    }

    /// Current loss scale.
    pub fn current_loss_scale(&self) -> f64 {
        self.scaler.scale()
    }

    /// Aggregated statistics for the run so far.
    pub fn stats(&self) -> MixedPrecisionStats {
        let scaler_stats = self.scaler.stats();

        MixedPrecisionStats {
            compute_dtype: self.config.compute_dtype,
            param_dtype: self.config.param_dtype,
            current_scale: scaler_stats.current_scale,
            total_steps: self.step,
            successful_steps: self.successful_steps,
            skipped_steps: self.skipped_steps,
            overflow_count: scaler_stats.overflow_count,
            overflow_rate: scaler_stats.overflow_rate,
            success_rate: self.successful_steps as f64 / self.step.max(1) as f64,
        }
    }

    /// Processes one optimization step: scales the loss and gradients,
    /// unscales, checks for overflow, and updates the loss scaler.
    /// Returns `Ok(true)` if the step should be applied and `Ok(false)` if it
    /// was skipped because of overflow.
    pub fn process_step(
        &mut self,
        loss: f64,
        gradients: &mut HashMap<String, f64>,
    ) -> Result<bool, MixedPrecisionError> {
        self.step += 1;

        let _scaled_loss = self.scaler.scale_loss(loss);

        // Scale gradients as backpropagation through the scaled loss would.
        for grad in gradients.values_mut() {
            *grad *= self.scaler.scale();
        }

        // Unscale before checking and (conceptually) applying the update.
        self.scaler.unscale_gradients(gradients);

        let found_overflow = self.scaler.check_overflow(gradients).is_err();

        self.scaler.update(found_overflow)?;

        if found_overflow {
            self.skipped_steps += 1;
            Ok(false)
        } else {
            self.successful_steps += 1;
            Ok(true)
        }
    }
}

/// Summary statistics for a mixed precision training run.
#[derive(Debug, Clone, PartialEq)]
pub struct MixedPrecisionStats {
    pub compute_dtype: PrecisionMode,
    pub param_dtype: PrecisionMode,
    pub current_scale: f64,
    pub total_steps: usize,
    pub successful_steps: usize,
    pub skipped_steps: usize,
    pub overflow_count: usize,
    pub overflow_rate: f64,
    pub success_rate: f64,
}

impl std::fmt::Display for MixedPrecisionStats {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "Mixed Precision Training Statistics")?;
        writeln!(f, "====================================")?;
        writeln!(f, "Compute dtype: {:?}", self.compute_dtype)?;
        writeln!(f, "Parameter dtype: {:?}", self.param_dtype)?;
        writeln!(f, "Current scale: {:.0}", self.current_scale)?;
        writeln!(f, "Total steps: {}", self.total_steps)?;
        writeln!(f, "Successful steps: {}", self.successful_steps)?;
        writeln!(f, "Skipped steps: {}", self.skipped_steps)?;
        writeln!(f, "Overflow count: {}", self.overflow_count)?;
        writeln!(f, "Overflow rate: {:.2}%", self.overflow_rate * 100.0)?;
        writeln!(f, "Success rate: {:.2}%", self.success_rate * 100.0)?;
        Ok(())
    }
}

/// Saved tensors for a gradient (activation) checkpointing region.
#[derive(Debug, Clone)]
pub struct GradientCheckpoint {
    /// Identifier for the checkpointed region (e.g. a layer name).
    pub id: String,

    /// Tensors saved for recomputation during the backward pass.
    pub saved_tensors: HashMap<String, Vec<f64>>,

    /// Total size in bytes of the saved tensors.
    pub memory_saved: usize,
}

impl GradientCheckpoint {
    pub fn new(id: String) -> Self {
        Self {
            id,
            saved_tensors: HashMap::new(),
            memory_saved: 0,
        }
    }

    /// Saves a tensor and accounts for its memory footprint.
    pub fn save_tensor(&mut self, name: String, data: Vec<f64>) {
        let bytes = data.len() * std::mem::size_of::<f64>();
        self.memory_saved += bytes;
        self.saved_tensors.insert(name, data);
    }

    /// Memory held by the checkpoint, in mebibytes.
    pub fn memory_saved_mb(&self) -> f64 {
        self.memory_saved as f64 / (1024.0 * 1024.0)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_precision_mode_bytes() {
        assert_eq!(PrecisionMode::FP64.bytes_per_element(), 8);
        assert_eq!(PrecisionMode::FP32.bytes_per_element(), 4);
        assert_eq!(PrecisionMode::FP16.bytes_per_element(), 2);
        assert_eq!(PrecisionMode::BF16.bytes_per_element(), 2);
        assert_eq!(PrecisionMode::FP8.bytes_per_element(), 1);
    }

    #[test]
    fn test_precision_mode_is_mixed() {
        assert!(!PrecisionMode::FP32.is_mixed_precision());
        assert!(PrecisionMode::FP16.is_mixed_precision());
        assert!(PrecisionMode::BF16.is_mixed_precision());
        assert!(PrecisionMode::FP8.is_mixed_precision());
    }
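
    // Companion check for `name()`: the conventional dtype strings defined above.
    #[test]
    fn test_precision_mode_names() {
        assert_eq!(PrecisionMode::FP32.name(), "float32");
        assert_eq!(PrecisionMode::FP16.name(), "float16");
        assert_eq!(PrecisionMode::BF16.name(), "bfloat16");
        assert_eq!(PrecisionMode::FP64.name(), "float64");
        assert_eq!(PrecisionMode::FP8.name(), "float8");
    }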

    #[test]
    fn test_loss_scaler_static() {
        let scaler = LossScaler::new(LossScalingStrategy::Static { scale: 1024.0 });
        assert_eq!(scaler.scale(), 1024.0);

        let loss = 0.5;
        let scaled = scaler.scale_loss(loss);
        assert_eq!(scaled, 512.0);
    }

    #[test]
    fn test_loss_scaler_dynamic_no_overflow() {
        let mut scaler = LossScaler::new(LossScalingStrategy::Dynamic {
            init_scale: 1024.0,
            growth_factor: 2.0,
            backoff_factor: 0.5,
            growth_interval: 2,
        });

        assert_eq!(scaler.scale(), 1024.0);

        scaler.update(false).unwrap();
        scaler.update(false).unwrap();

        assert_eq!(scaler.scale(), 2048.0);
    }

    #[test]
    fn test_loss_scaler_dynamic_with_overflow() {
        let mut scaler = LossScaler::new(LossScalingStrategy::Dynamic {
            init_scale: 1024.0,
            growth_factor: 2.0,
            backoff_factor: 0.5,
            growth_interval: 2,
        });

        assert_eq!(scaler.scale(), 1024.0);

        scaler.update(true).unwrap();

        assert_eq!(scaler.scale(), 512.0);
    }

    #[test]
    fn test_loss_scaler_overflow_detection() {
        let scaler = LossScaler::new(LossScalingStrategy::None);

        let mut grads = HashMap::new();
        grads.insert("w1".to_string(), 1.0);
        grads.insert("w2".to_string(), 2.0);

        assert!(scaler.check_overflow(&grads).is_ok());

        grads.insert("w3".to_string(), f64::NAN);
        assert!(matches!(
            scaler.check_overflow(&grads),
            Err(MixedPrecisionError::GradientNaN)
        ));

        grads.remove("w3");
        grads.insert("w4".to_string(), f64::INFINITY);
        assert!(matches!(
            scaler.check_overflow(&grads),
            Err(MixedPrecisionError::GradientOverflow(_))
        ));
    }

    #[test]
    fn test_mixed_precision_config_default() {
        let config = MixedPrecisionConfig::default();
        assert_eq!(config.compute_dtype, PrecisionMode::FP16);
        assert_eq!(config.param_dtype, PrecisionMode::FP32);
        assert!(config.use_master_weights);
        assert!(config.stability_checks);
    }

    #[test]
    fn test_mixed_precision_config_builders() {
        let config = MixedPrecisionConfig::fp16();
        assert_eq!(config.compute_dtype, PrecisionMode::FP16);

        let config = MixedPrecisionConfig::bf16();
        assert_eq!(config.compute_dtype, PrecisionMode::BF16);

        let config = MixedPrecisionConfig::fp8();
        assert_eq!(config.compute_dtype, PrecisionMode::FP8);
    }

    #[test]
    fn test_mixed_precision_config_validation() {
        let config = MixedPrecisionConfig::new(PrecisionMode::FP16, PrecisionMode::FP32);
        assert!(config.validate().is_ok());

        let config = MixedPrecisionConfig::new(PrecisionMode::FP32, PrecisionMode::FP16);
        assert!(config.validate().is_err());
    }
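
    // Additional validation coverage: dynamic scaling parameters must satisfy
    // growth_factor > 1.0 and backoff_factor in (0, 1).
    #[test]
    fn test_mixed_precision_config_validation_dynamic_params() {
        let config = MixedPrecisionConfig::fp16().with_loss_scaling(LossScalingStrategy::Dynamic {
            init_scale: 1024.0,
            growth_factor: 1.0,
            backoff_factor: 0.5,
            growth_interval: 100,
        });
        assert!(config.validate().is_err());

        let config = MixedPrecisionConfig::fp16().with_loss_scaling(LossScalingStrategy::Dynamic {
            init_scale: 1024.0,
            growth_factor: 2.0,
            backoff_factor: 1.5,
            growth_interval: 100,
        });
        assert!(config.validate().is_err());
    }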

    #[test]
    fn test_mixed_precision_state() {
        let config = MixedPrecisionConfig::fp16();
        let mut state = MixedPrecisionState::new(config).unwrap();

        assert_eq!(state.step, 0);
        assert_eq!(state.successful_steps, 0);
        assert_eq!(state.skipped_steps, 0);

        let mut grads = HashMap::new();
        grads.insert("w1".to_string(), 0.1);
        grads.insert("w2".to_string(), 0.2);

        let should_update = state.process_step(0.5, &mut grads).unwrap();
        assert!(should_update);
        assert_eq!(state.step, 1);
        assert_eq!(state.successful_steps, 1);
        assert_eq!(state.skipped_steps, 0);
    }
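
    // Overflow path of `process_step`: an infinite gradient is detected, the
    // step is skipped, and the default dynamic scale (65536.0) backs off by
    // the default factor of 0.5.
    #[test]
    fn test_mixed_precision_state_skips_on_overflow() {
        let config = MixedPrecisionConfig::fp16();
        let mut state = MixedPrecisionState::new(config).unwrap();

        let mut grads = HashMap::new();
        grads.insert("w1".to_string(), f64::INFINITY);

        let should_update = state.process_step(0.5, &mut grads).unwrap();
        assert!(!should_update);
        assert_eq!(state.skipped_steps, 1);
        assert_eq!(state.successful_steps, 0);
        assert_eq!(state.current_loss_scale(), 32768.0);
    }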

    #[test]
    fn test_mixed_precision_stats_display() {
        let stats = MixedPrecisionStats {
            compute_dtype: PrecisionMode::FP16,
            param_dtype: PrecisionMode::FP32,
            current_scale: 1024.0,
            total_steps: 100,
            successful_steps: 95,
            skipped_steps: 5,
            overflow_count: 5,
            overflow_rate: 0.05,
            success_rate: 0.95,
        };

        let display = format!("{}", stats);
        assert!(display.contains("FP16"));
        assert!(display.contains("1024"));
        assert!(display.contains("95"));
    }

    #[test]
    fn test_gradient_checkpoint() {
        let mut checkpoint = GradientCheckpoint::new("layer1".to_string());
        assert_eq!(checkpoint.memory_saved, 0);

        checkpoint.save_tensor("activations".to_string(), vec![1.0, 2.0, 3.0]);
        assert!(checkpoint.memory_saved > 0);
        assert!(checkpoint.memory_saved_mb() > 0.0);
    }
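
    // Master-weight handling: with the default config a full-precision copy
    // of the parameters is stored; with master weights disabled, nothing is.
    #[test]
    fn test_init_master_weights() {
        let mut params = HashMap::new();
        params.insert("w1".to_string(), vec![0.1, 0.2, 0.3]);

        let mut state = MixedPrecisionState::new(MixedPrecisionConfig::fp16()).unwrap();
        state.init_master_weights(&params);
        assert_eq!(state.master_weights.get("w1"), params.get("w1"));

        let config = MixedPrecisionConfig::fp16().with_master_weights(false);
        let mut state = MixedPrecisionState::new(config).unwrap();
        state.init_master_weights(&params);
        assert!(state.master_weights.is_empty());
    }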

    #[test]
    fn test_loss_scaler_stats() {
        let mut scaler = LossScaler::new(LossScalingStrategy::Dynamic {
            init_scale: 1024.0,
            growth_factor: 2.0,
            backoff_factor: 0.5,
            growth_interval: 2,
        });

        scaler.update(false).unwrap();
        scaler.update(true).unwrap();
        scaler.update(false).unwrap();

        let stats = scaler.stats();
        assert_eq!(stats.total_steps, 3);
        assert_eq!(stats.overflow_count, 1);
        assert!((stats.overflow_rate - 0.333).abs() < 0.01);
    }

    #[test]
    fn test_loss_scaler_reset() {
        let mut scaler = LossScaler::new(LossScalingStrategy::Dynamic {
            init_scale: 1024.0,
            growth_factor: 2.0,
            backoff_factor: 0.5,
            growth_interval: 2,
        });

        scaler.update(false).unwrap();
        scaler.update(false).unwrap();
        assert_eq!(scaler.scale(), 2048.0);

        scaler.reset();
        assert_eq!(scaler.scale(), 1024.0);
        assert_eq!(scaler.stats().total_steps, 0);
    }
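
    // Underflow guard: an overflow that pushes a dynamic scale below 1.0 is
    // reported as `LossScaleUnderflow`.
    #[test]
    fn test_loss_scaler_underflow_error() {
        let mut scaler = LossScaler::new(LossScalingStrategy::Dynamic {
            init_scale: 1.0,
            growth_factor: 2.0,
            backoff_factor: 0.5,
            growth_interval: 2,
        });

        assert!(matches!(
            scaler.update(true),
            Err(MixedPrecisionError::LossScaleUnderflow(_))
        ));
    }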

    #[test]
    fn test_unscale_gradients() {
        let scaler = LossScaler::new(LossScalingStrategy::Static { scale: 1024.0 });

        let mut grads = HashMap::new();
        grads.insert("w1".to_string(), 1024.0);
        grads.insert("w2".to_string(), 2048.0);

        scaler.unscale_gradients(&mut grads);

        assert_eq!(grads.get("w1").unwrap(), &1.0);
        assert_eq!(grads.get("w2").unwrap(), &2.0);
    }
}