// tensorlogic_infer/mixed_precision.rs

//! Mixed precision training utilities.
//!
//! This module provides comprehensive mixed precision training support:
//! - FP16 (half precision) and BF16 (bfloat16) computation modes
//! - Automatic loss scaling with dynamic adjustment
//! - Gradient checkpointing integration
//! - Mixed precision optimizer wrappers
//! - Numerical stability monitoring
//! - Performance profiling for mixed precision operations
//!
//! ## Example
//!
//! ```rust,ignore
//! use tensorlogic_infer::{MixedPrecisionConfig, PrecisionMode, LossScalingStrategy, MixedPrecisionTrainer};
//!
//! // Configure mixed precision training
//! let config = MixedPrecisionConfig::default()
//!     .with_compute_dtype(PrecisionMode::FP16)
//!     .with_param_dtype(PrecisionMode::FP32)
//!     .with_loss_scaling(LossScalingStrategy::Dynamic {
//!         init_scale: 65536.0,
//!         growth_factor: 2.0,
//!         backoff_factor: 0.5,
//!         growth_interval: 2000,
//!     });
//!
//! // Create mixed precision trainer
//! let mut trainer = MixedPrecisionTrainer::new(executor, config);
//!
//! // Training loop with automatic loss scaling
//! for batch in dataset {
//!     let loss = trainer.train_step(&batch)?;
//!     println!("Loss: {:.4}, Scale: {:.0}", loss, trainer.current_loss_scale());
//! }
//! ```

use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use thiserror::Error;

/// Mixed precision training errors.
#[derive(Error, Debug, Clone, PartialEq)]
pub enum MixedPrecisionError {
    #[error("Loss scale overflow: scale={0}")]
    LossScaleOverflow(f64),

    #[error("Loss scale underflow: scale={0}")]
    LossScaleUnderflow(f64),

    #[error("Gradient overflow detected in {0} gradients")]
    GradientOverflow(usize),

    #[error("NaN detected in gradients")]
    GradientNaN,

    #[error("Unsupported precision mode: {0:?}")]
    UnsupportedPrecisionMode(PrecisionMode),

    #[error("Mixed precision not supported by backend")]
    NotSupported,

    #[error("Numerical instability detected: {0}")]
    NumericalInstability(String),
}

/// Precision mode for computation.
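///
/// A minimal sketch of querying a mode's properties (assuming the enum is
/// re-exported from the crate root, as in the module-level example):
///
/// ```rust,ignore
/// use tensorlogic_infer::PrecisionMode;
///
/// let mode = PrecisionMode::BF16;
/// assert_eq!(mode.bytes_per_element(), 2);
/// assert_eq!(mode.name(), "bfloat16");
/// assert!(mode.is_mixed_precision());
/// ```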
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum PrecisionMode {
    /// 32-bit floating point (full precision)
    FP32,
    /// 16-bit floating point (half precision)
    FP16,
    /// BFloat16 (brain floating point 16)
    BF16,
    /// 64-bit floating point (double precision)
    FP64,
    /// 8-bit floating point (experimental)
    FP8,
}

impl PrecisionMode {
    /// Get the number of bytes per element.
    pub fn bytes_per_element(&self) -> usize {
        match self {
            PrecisionMode::FP64 => 8,
            PrecisionMode::FP32 => 4,
            PrecisionMode::FP16 | PrecisionMode::BF16 => 2,
            PrecisionMode::FP8 => 1,
        }
    }

    /// Check if this precision mode is mixed (lower than FP32).
    pub fn is_mixed_precision(&self) -> bool {
        matches!(
            self,
            PrecisionMode::FP16 | PrecisionMode::BF16 | PrecisionMode::FP8
        )
    }

    /// Get the precision name.
    pub fn name(&self) -> &'static str {
        match self {
            PrecisionMode::FP32 => "float32",
            PrecisionMode::FP16 => "float16",
            PrecisionMode::BF16 => "bfloat16",
            PrecisionMode::FP64 => "float64",
            PrecisionMode::FP8 => "float8",
        }
    }
}

/// Loss scaling strategy for mixed precision training.
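///
/// A minimal sketch of picking a strategy (the `Dynamic` values below are the
/// same as the defaults, repeated here only for illustration):
///
/// ```rust,ignore
/// use tensorlogic_infer::LossScalingStrategy;
///
/// // Fixed scale factor that is never adjusted.
/// let static_scaling = LossScalingStrategy::Static { scale: 1024.0 };
///
/// // Scale grows after `growth_interval` clean steps and backs off on overflow.
/// let dynamic_scaling = LossScalingStrategy::Dynamic {
///     init_scale: 65536.0,
///     growth_factor: 2.0,
///     backoff_factor: 0.5,
///     growth_interval: 2000,
/// };
/// ```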
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum LossScalingStrategy {
    /// No loss scaling
    None,

    /// Static loss scaling with a fixed scale factor
    Static { scale: f64 },

    /// Dynamic loss scaling with automatic adjustment
    Dynamic {
        /// Initial scale factor
        init_scale: f64,
        /// Growth factor when no overflow detected
        growth_factor: f64,
        /// Backoff factor when overflow detected
        backoff_factor: f64,
        /// Number of successful steps before growing scale
        growth_interval: usize,
    },
}

impl Default for LossScalingStrategy {
    fn default() -> Self {
        LossScalingStrategy::Dynamic {
            init_scale: 65536.0,
            growth_factor: 2.0,
            backoff_factor: 0.5,
            growth_interval: 2000,
        }
    }
}

/// Loss scaler for automatic mixed precision training.
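///
/// A minimal sketch of the scale / check / update cycle (gradients are modeled
/// as a `HashMap<String, f64>`, matching this module's helper methods; the
/// gradient value is a placeholder rather than real backprop output):
///
/// ```rust,ignore
/// use std::collections::HashMap;
/// use tensorlogic_infer::{LossScaler, LossScalingStrategy};
///
/// let mut scaler = LossScaler::new(LossScalingStrategy::default());
/// let _scaled_loss = scaler.scale_loss(0.25);
///
/// // Pretend this gradient came out of backprop through the scaled loss.
/// let mut grads: HashMap<String, f64> = HashMap::new();
/// grads.insert("w".to_string(), 0.5 * scaler.scale());
///
/// scaler.unscale_gradients(&mut grads);
/// let overflow = scaler.check_overflow(&grads).is_err();
/// scaler.update(overflow).unwrap();
/// ```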
#[derive(Debug, Clone)]
pub struct LossScaler {
    strategy: LossScalingStrategy,
    current_scale: f64,
    growth_tracker: usize,
    overflow_count: usize,
    total_steps: usize,
}

impl LossScaler {
    /// Create a new loss scaler with the given strategy.
    pub fn new(strategy: LossScalingStrategy) -> Self {
        let current_scale = match &strategy {
            LossScalingStrategy::None => 1.0,
            LossScalingStrategy::Static { scale } => *scale,
            LossScalingStrategy::Dynamic { init_scale, .. } => *init_scale,
        };

        Self {
            strategy,
            current_scale,
            growth_tracker: 0,
            overflow_count: 0,
            total_steps: 0,
        }
    }

    /// Get the current loss scale.
    pub fn scale(&self) -> f64 {
        self.current_scale
    }

    /// Scale the loss value.
    pub fn scale_loss(&self, loss: f64) -> f64 {
        loss * self.current_scale
    }

    /// Unscale gradients.
    pub fn unscale_gradients(&self, grads: &mut HashMap<String, f64>) {
        let inv_scale = 1.0 / self.current_scale;
        for grad in grads.values_mut() {
            *grad *= inv_scale;
        }
    }

    /// Check for gradient overflow or NaN values.
    pub fn check_overflow(&self, grads: &HashMap<String, f64>) -> Result<(), MixedPrecisionError> {
        let mut has_nan = false;
        let mut has_inf = false;

        for grad in grads.values() {
            if grad.is_nan() {
                has_nan = true;
            }
            if grad.is_infinite() {
                has_inf = true;
            }
        }

        if has_nan {
            return Err(MixedPrecisionError::GradientNaN);
        }
        if has_inf {
            return Err(MixedPrecisionError::GradientOverflow(
                grads.values().filter(|g| g.is_infinite()).count(),
            ));
        }

        Ok(())
    }

    /// Update the loss scale based on overflow detection.
    pub fn update(&mut self, found_overflow: bool) -> Result<(), MixedPrecisionError> {
        self.total_steps += 1;

        match &self.strategy {
            LossScalingStrategy::None | LossScalingStrategy::Static { .. } => {
                // No update for static scaling
            }
            LossScalingStrategy::Dynamic {
                growth_factor,
                backoff_factor,
                growth_interval,
                ..
            } => {
                if found_overflow {
                    // Reduce scale
                    self.current_scale *= backoff_factor;
                    self.growth_tracker = 0;
                    self.overflow_count += 1;

                    // Check for underflow
                    if self.current_scale < 1.0 {
                        return Err(MixedPrecisionError::LossScaleUnderflow(self.current_scale));
                    }
                } else {
                    // Increase growth tracker
                    self.growth_tracker += 1;

                    // Grow scale if interval reached
                    if self.growth_tracker >= *growth_interval {
                        self.current_scale *= growth_factor;
                        self.growth_tracker = 0;

                        // Check for overflow
                        if self.current_scale > 1e10 {
                            return Err(MixedPrecisionError::LossScaleOverflow(self.current_scale));
                        }
                    }
                }
            }
        }

        Ok(())
    }

    /// Get statistics about loss scaling.
    pub fn stats(&self) -> LossScalerStats {
        LossScalerStats {
            current_scale: self.current_scale,
            overflow_count: self.overflow_count,
            total_steps: self.total_steps,
            overflow_rate: self.overflow_count as f64 / self.total_steps.max(1) as f64,
            growth_tracker: self.growth_tracker,
        }
    }

    /// Reset the loss scaler to initial state.
    pub fn reset(&mut self) {
        let init_scale = match &self.strategy {
            LossScalingStrategy::None => 1.0,
            LossScalingStrategy::Static { scale } => *scale,
            LossScalingStrategy::Dynamic { init_scale, .. } => *init_scale,
        };

        self.current_scale = init_scale;
        self.growth_tracker = 0;
        self.overflow_count = 0;
        self.total_steps = 0;
    }
}

/// Statistics about loss scaling.
#[derive(Debug, Clone, PartialEq)]
pub struct LossScalerStats {
    /// Current scale factor
    pub current_scale: f64,
    /// Number of overflow events
    pub overflow_count: usize,
    /// Total training steps
    pub total_steps: usize,
    /// Overflow rate (overflows / total steps)
    pub overflow_rate: f64,
    /// Current growth tracker value
    pub growth_tracker: usize,
}

/// Mixed precision training configuration.
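///
/// A minimal sketch of the builder-style presets and overrides:
///
/// ```rust,ignore
/// use tensorlogic_infer::{LossScalingStrategy, MixedPrecisionConfig};
///
/// // BF16 compute with FP32 master weights, static loss scaling,
/// // and gradient clipping at norm 0.5.
/// let config = MixedPrecisionConfig::bf16()
///     .with_loss_scaling(LossScalingStrategy::Static { scale: 1024.0 })
///     .with_gradient_clipping(true, 0.5);
///
/// assert!(config.validate().is_ok());
/// ```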
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct MixedPrecisionConfig {
    /// Precision mode for computation
    pub compute_dtype: PrecisionMode,

    /// Precision mode for parameters (usually FP32)
    pub param_dtype: PrecisionMode,

    /// Loss scaling strategy
    pub loss_scaling: LossScalingStrategy,

    /// Enable gradient checkpointing to save memory
    pub gradient_checkpointing: bool,

    /// Enable gradient clipping before unscaling
    pub gradient_clipping: bool,

    /// Maximum gradient norm for clipping
    pub max_gradient_norm: f64,

    /// Enable numerical stability checks
    pub stability_checks: bool,

    /// Skip optimizer step if overflow detected
    pub skip_on_overflow: bool,

    /// Enable master weights in FP32
    pub use_master_weights: bool,
}

impl Default for MixedPrecisionConfig {
    fn default() -> Self {
        Self {
            compute_dtype: PrecisionMode::FP16,
            param_dtype: PrecisionMode::FP32,
            loss_scaling: LossScalingStrategy::default(),
            gradient_checkpointing: false,
            gradient_clipping: true,
            max_gradient_norm: 1.0,
            stability_checks: true,
            skip_on_overflow: true,
            use_master_weights: true,
        }
    }
}

impl MixedPrecisionConfig {
    /// Create a new mixed precision config.
    pub fn new(compute_dtype: PrecisionMode, param_dtype: PrecisionMode) -> Self {
        Self {
            compute_dtype,
            param_dtype,
            ..Default::default()
        }
    }

    /// Set the compute dtype.
    pub fn with_compute_dtype(mut self, dtype: PrecisionMode) -> Self {
        self.compute_dtype = dtype;
        self
    }

    /// Set the parameter dtype.
    pub fn with_param_dtype(mut self, dtype: PrecisionMode) -> Self {
        self.param_dtype = dtype;
        self
    }

    /// Set the loss scaling strategy.
    pub fn with_loss_scaling(mut self, strategy: LossScalingStrategy) -> Self {
        self.loss_scaling = strategy;
        self
    }

    /// Enable or disable gradient checkpointing.
    pub fn with_gradient_checkpointing(mut self, enabled: bool) -> Self {
        self.gradient_checkpointing = enabled;
        self
    }

    /// Enable or disable gradient clipping.
    pub fn with_gradient_clipping(mut self, enabled: bool, max_norm: f64) -> Self {
        self.gradient_clipping = enabled;
        self.max_gradient_norm = max_norm;
        self
    }

    /// Enable or disable stability checks.
    pub fn with_stability_checks(mut self, enabled: bool) -> Self {
        self.stability_checks = enabled;
        self
    }

    /// Enable or disable master weights.
    pub fn with_master_weights(mut self, enabled: bool) -> Self {
        self.use_master_weights = enabled;
        self
    }

    /// Create FP16 mixed precision config.
    pub fn fp16() -> Self {
        Self::new(PrecisionMode::FP16, PrecisionMode::FP32)
    }

    /// Create BF16 mixed precision config.
    pub fn bf16() -> Self {
        Self::new(PrecisionMode::BF16, PrecisionMode::FP32)
    }

    /// Create FP8 mixed precision config (experimental).
    pub fn fp8() -> Self {
        Self::new(PrecisionMode::FP8, PrecisionMode::FP32)
    }

    /// Validate the configuration.
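    ///
    /// A minimal sketch of a configuration that fails validation (FP16
    /// parameters are less precise than FP32 compute):
    ///
    /// ```rust,ignore
    /// use tensorlogic_infer::{MixedPrecisionConfig, PrecisionMode};
    ///
    /// let bad = MixedPrecisionConfig::new(PrecisionMode::FP32, PrecisionMode::FP16);
    /// assert!(bad.validate().is_err());
    /// ```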
    pub fn validate(&self) -> Result<(), MixedPrecisionError> {
        // Check that param dtype is at least as precise as compute dtype
        let compute_bytes = self.compute_dtype.bytes_per_element();
        let param_bytes = self.param_dtype.bytes_per_element();

        if param_bytes < compute_bytes {
            return Err(MixedPrecisionError::NumericalInstability(format!(
                "Parameter dtype ({:?}) should be at least as precise as compute dtype ({:?})",
                self.param_dtype, self.compute_dtype
            )));
        }

        // Validate loss scaling parameters
        if let LossScalingStrategy::Dynamic {
            init_scale,
            growth_factor,
            backoff_factor,
            ..
        } = &self.loss_scaling
        {
            if *init_scale <= 0.0 {
                return Err(MixedPrecisionError::LossScaleUnderflow(*init_scale));
            }
            if *growth_factor <= 1.0 {
                return Err(MixedPrecisionError::NumericalInstability(format!(
                    "Growth factor must be > 1.0, got {}",
                    growth_factor
                )));
            }
            if *backoff_factor >= 1.0 || *backoff_factor <= 0.0 {
                return Err(MixedPrecisionError::NumericalInstability(format!(
                    "Backoff factor must be in (0, 1), got {}",
                    backoff_factor
                )));
            }
        }

        Ok(())
    }
}

/// Mixed precision training state.
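///
/// A minimal sketch of driving one training step through the state (the
/// gradient values are placeholders; in real use they would come from the
/// backward pass):
///
/// ```rust,ignore
/// use std::collections::HashMap;
/// use tensorlogic_infer::{MixedPrecisionConfig, MixedPrecisionState};
///
/// let mut state = MixedPrecisionState::new(MixedPrecisionConfig::fp16()).unwrap();
///
/// let mut grads: HashMap<String, f64> = HashMap::new();
/// grads.insert("w1".to_string(), 0.1);
///
/// // `true` means the optimizer step should run; `false` means it was skipped
/// // because the gradients overflowed at the current loss scale.
/// let apply_update = state.process_step(0.5, &mut grads).unwrap();
/// println!("{}", state.stats());
/// ```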
#[derive(Debug, Clone)]
pub struct MixedPrecisionState {
    /// Configuration
    pub config: MixedPrecisionConfig,

    /// Loss scaler
    pub scaler: LossScaler,

    /// Master weights (FP32 copies of parameters)
    pub master_weights: HashMap<String, Vec<f64>>,

    /// Number of successful steps
    pub successful_steps: usize,

    /// Number of skipped steps due to overflow
    pub skipped_steps: usize,

    /// Training step counter
    pub step: usize,
}

impl MixedPrecisionState {
    /// Create a new mixed precision training state.
    pub fn new(config: MixedPrecisionConfig) -> Result<Self, MixedPrecisionError> {
        config.validate()?;

        Ok(Self {
            scaler: LossScaler::new(config.loss_scaling.clone()),
            config,
            master_weights: HashMap::new(),
            successful_steps: 0,
            skipped_steps: 0,
            step: 0,
        })
    }

    /// Initialize master weights.
    pub fn init_master_weights(&mut self, params: &HashMap<String, Vec<f64>>) {
        if self.config.use_master_weights {
            self.master_weights = params.clone();
        }
    }

    /// Get current loss scale.
    pub fn current_loss_scale(&self) -> f64 {
        self.scaler.scale()
    }

    /// Get training statistics.
    pub fn stats(&self) -> MixedPrecisionStats {
        let scaler_stats = self.scaler.stats();

        MixedPrecisionStats {
            compute_dtype: self.config.compute_dtype,
            param_dtype: self.config.param_dtype,
            current_scale: scaler_stats.current_scale,
            total_steps: self.step,
            successful_steps: self.successful_steps,
            skipped_steps: self.skipped_steps,
            overflow_count: scaler_stats.overflow_count,
            overflow_rate: scaler_stats.overflow_rate,
            success_rate: self.successful_steps as f64 / self.step.max(1) as f64,
        }
    }

    /// Process training step with automatic loss scaling.
    pub fn process_step(
        &mut self,
        loss: f64,
        gradients: &mut HashMap<String, f64>,
    ) -> Result<bool, MixedPrecisionError> {
        self.step += 1;

        // Scale the loss; in a real backward pass the gradients of this scaled
        // loss would already carry the scale factor.
        let _scaled_loss = self.scaler.scale_loss(loss);

        // Scale the incoming gradients explicitly, mimicking backprop through
        // the scaled loss.
        for grad in gradients.values_mut() {
            *grad *= self.scaler.scale();
        }

        // Unscale back to the true gradient magnitudes before the overflow
        // check and optimizer step.
        self.scaler.unscale_gradients(gradients);

        // Check for overflow
        let found_overflow = self.scaler.check_overflow(gradients).is_err();

        // Update loss scale
        self.scaler.update(found_overflow)?;

        if found_overflow {
            self.skipped_steps += 1;
            Ok(false) // Skip optimizer step
        } else {
            self.successful_steps += 1;
            Ok(true) // Proceed with optimizer step
        }
    }
}

/// Statistics about mixed precision training.
#[derive(Debug, Clone, PartialEq)]
pub struct MixedPrecisionStats {
    /// Compute dtype
    pub compute_dtype: PrecisionMode,
    /// Parameter dtype
    pub param_dtype: PrecisionMode,
    /// Current loss scale
    pub current_scale: f64,
    /// Total training steps
    pub total_steps: usize,
    /// Successful steps (no overflow)
    pub successful_steps: usize,
    /// Skipped steps (due to overflow)
    pub skipped_steps: usize,
    /// Total overflow count
    pub overflow_count: usize,
    /// Overflow rate
    pub overflow_rate: f64,
    /// Success rate
    pub success_rate: f64,
}

impl std::fmt::Display for MixedPrecisionStats {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "Mixed Precision Training Statistics")?;
        writeln!(f, "====================================")?;
        writeln!(f, "Compute dtype:     {:?}", self.compute_dtype)?;
        writeln!(f, "Parameter dtype:   {:?}", self.param_dtype)?;
        writeln!(f, "Current scale:     {:.0}", self.current_scale)?;
        writeln!(f, "Total steps:       {}", self.total_steps)?;
        writeln!(f, "Successful steps:  {}", self.successful_steps)?;
        writeln!(f, "Skipped steps:     {}", self.skipped_steps)?;
        writeln!(f, "Overflow count:    {}", self.overflow_count)?;
        writeln!(f, "Overflow rate:     {:.2}%", self.overflow_rate * 100.0)?;
        writeln!(f, "Success rate:      {:.2}%", self.success_rate * 100.0)?;
        Ok(())
    }
}

/// Gradient checkpointing utility.
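///
/// A minimal sketch of recording a checkpointed activation (the names and
/// sizes are placeholders):
///
/// ```rust,ignore
/// use tensorlogic_infer::GradientCheckpoint;
///
/// let mut ckpt = GradientCheckpoint::new("encoder.layer0".to_string());
/// ckpt.save_tensor("activations".to_string(), vec![0.0; 1024]);
/// println!("saved ~{:.3} MB", ckpt.memory_saved_mb());
/// ```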
#[derive(Debug, Clone)]
pub struct GradientCheckpoint {
    /// Checkpoint identifier
    pub id: String,

    /// Saved tensors for recomputation
    pub saved_tensors: HashMap<String, Vec<f64>>,

    /// Memory saved by checkpointing (bytes)
    pub memory_saved: usize,
}

impl GradientCheckpoint {
    /// Create a new gradient checkpoint.
    pub fn new(id: String) -> Self {
        Self {
            id,
            saved_tensors: HashMap::new(),
            memory_saved: 0,
        }
    }

    /// Save a tensor for recomputation.
    pub fn save_tensor(&mut self, name: String, data: Vec<f64>) {
        let bytes = data.len() * std::mem::size_of::<f64>();
        self.memory_saved += bytes;
        self.saved_tensors.insert(name, data);
    }

    /// Get memory saved by this checkpoint.
    pub fn memory_saved_mb(&self) -> f64 {
        self.memory_saved as f64 / (1024.0 * 1024.0)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_precision_mode_bytes() {
        assert_eq!(PrecisionMode::FP64.bytes_per_element(), 8);
        assert_eq!(PrecisionMode::FP32.bytes_per_element(), 4);
        assert_eq!(PrecisionMode::FP16.bytes_per_element(), 2);
        assert_eq!(PrecisionMode::BF16.bytes_per_element(), 2);
        assert_eq!(PrecisionMode::FP8.bytes_per_element(), 1);
    }

    #[test]
    fn test_precision_mode_is_mixed() {
        assert!(!PrecisionMode::FP32.is_mixed_precision());
        assert!(PrecisionMode::FP16.is_mixed_precision());
        assert!(PrecisionMode::BF16.is_mixed_precision());
        assert!(PrecisionMode::FP8.is_mixed_precision());
    }

    #[test]
    fn test_loss_scaler_static() {
        let scaler = LossScaler::new(LossScalingStrategy::Static { scale: 1024.0 });
        assert_eq!(scaler.scale(), 1024.0);

        let loss = 0.5;
        let scaled = scaler.scale_loss(loss);
        assert_eq!(scaled, 512.0);
    }

    #[test]
    fn test_loss_scaler_dynamic_no_overflow() {
        let mut scaler = LossScaler::new(LossScalingStrategy::Dynamic {
            init_scale: 1024.0,
            growth_factor: 2.0,
            backoff_factor: 0.5,
            growth_interval: 2,
        });

        assert_eq!(scaler.scale(), 1024.0);

        // No overflow for 2 steps
        scaler.update(false).unwrap();
        scaler.update(false).unwrap();

        // Scale should have grown
        assert_eq!(scaler.scale(), 2048.0);
    }

    #[test]
    fn test_loss_scaler_dynamic_with_overflow() {
        let mut scaler = LossScaler::new(LossScalingStrategy::Dynamic {
            init_scale: 1024.0,
            growth_factor: 2.0,
            backoff_factor: 0.5,
            growth_interval: 2,
        });

        assert_eq!(scaler.scale(), 1024.0);

        // Overflow detected
        scaler.update(true).unwrap();

        // Scale should have reduced
        assert_eq!(scaler.scale(), 512.0);
    }

    #[test]
    fn test_loss_scaler_overflow_detection() {
        let scaler = LossScaler::new(LossScalingStrategy::None);

        let mut grads = HashMap::new();
        grads.insert("w1".to_string(), 1.0);
        grads.insert("w2".to_string(), 2.0);

        // No overflow
        assert!(scaler.check_overflow(&grads).is_ok());

        // Add NaN
        grads.insert("w3".to_string(), f64::NAN);
        assert!(matches!(
            scaler.check_overflow(&grads),
            Err(MixedPrecisionError::GradientNaN)
        ));

        // Remove NaN and add Inf
        grads.remove("w3");
        grads.insert("w4".to_string(), f64::INFINITY);
        assert!(matches!(
            scaler.check_overflow(&grads),
            Err(MixedPrecisionError::GradientOverflow(_))
        ));
    }

    #[test]
    fn test_mixed_precision_config_default() {
        let config = MixedPrecisionConfig::default();
        assert_eq!(config.compute_dtype, PrecisionMode::FP16);
        assert_eq!(config.param_dtype, PrecisionMode::FP32);
        assert!(config.use_master_weights);
        assert!(config.stability_checks);
    }

    #[test]
    fn test_mixed_precision_config_builders() {
        let config = MixedPrecisionConfig::fp16();
        assert_eq!(config.compute_dtype, PrecisionMode::FP16);

        let config = MixedPrecisionConfig::bf16();
        assert_eq!(config.compute_dtype, PrecisionMode::BF16);

        let config = MixedPrecisionConfig::fp8();
        assert_eq!(config.compute_dtype, PrecisionMode::FP8);
    }

    #[test]
    fn test_mixed_precision_config_validation() {
        // Valid config
        let config = MixedPrecisionConfig::new(PrecisionMode::FP16, PrecisionMode::FP32);
        assert!(config.validate().is_ok());

        // Invalid: param dtype less precise than compute dtype
        let config = MixedPrecisionConfig::new(PrecisionMode::FP32, PrecisionMode::FP16);
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_mixed_precision_state() {
        let config = MixedPrecisionConfig::fp16();
        let mut state = MixedPrecisionState::new(config).unwrap();

        assert_eq!(state.step, 0);
        assert_eq!(state.successful_steps, 0);
        assert_eq!(state.skipped_steps, 0);

        // Process step without overflow
        let mut grads = HashMap::new();
        grads.insert("w1".to_string(), 0.1);
        grads.insert("w2".to_string(), 0.2);

        let should_update = state.process_step(0.5, &mut grads).unwrap();
        assert!(should_update);
        assert_eq!(state.step, 1);
        assert_eq!(state.successful_steps, 1);
        assert_eq!(state.skipped_steps, 0);
    }

    #[test]
    fn test_mixed_precision_stats_display() {
        let stats = MixedPrecisionStats {
            compute_dtype: PrecisionMode::FP16,
            param_dtype: PrecisionMode::FP32,
            current_scale: 1024.0,
            total_steps: 100,
            successful_steps: 95,
            skipped_steps: 5,
            overflow_count: 5,
            overflow_rate: 0.05,
            success_rate: 0.95,
        };

        let display = format!("{}", stats);
        assert!(display.contains("FP16"));
        assert!(display.contains("1024"));
        assert!(display.contains("95"));
    }

    #[test]
    fn test_gradient_checkpoint() {
        let mut checkpoint = GradientCheckpoint::new("layer1".to_string());
        assert_eq!(checkpoint.memory_saved, 0);

        checkpoint.save_tensor("activations".to_string(), vec![1.0, 2.0, 3.0]);
        assert!(checkpoint.memory_saved > 0);
        assert!(checkpoint.memory_saved_mb() > 0.0);
    }

    #[test]
    fn test_loss_scaler_stats() {
        let mut scaler = LossScaler::new(LossScalingStrategy::Dynamic {
            init_scale: 1024.0,
            growth_factor: 2.0,
            backoff_factor: 0.5,
            growth_interval: 2,
        });

        scaler.update(false).unwrap();
        scaler.update(true).unwrap();
        scaler.update(false).unwrap();

        let stats = scaler.stats();
        assert_eq!(stats.total_steps, 3);
        assert_eq!(stats.overflow_count, 1);
        assert!((stats.overflow_rate - 0.333).abs() < 0.01);
    }

    #[test]
    fn test_loss_scaler_reset() {
        let mut scaler = LossScaler::new(LossScalingStrategy::Dynamic {
            init_scale: 1024.0,
            growth_factor: 2.0,
            backoff_factor: 0.5,
            growth_interval: 2,
        });

        scaler.update(false).unwrap();
        scaler.update(false).unwrap();
        assert_eq!(scaler.scale(), 2048.0);

        scaler.reset();
        assert_eq!(scaler.scale(), 1024.0);
        assert_eq!(scaler.stats().total_steps, 0);
    }

    #[test]
    fn test_unscale_gradients() {
        let scaler = LossScaler::new(LossScalingStrategy::Static { scale: 1024.0 });

        let mut grads = HashMap::new();
        grads.insert("w1".to_string(), 1024.0);
        grads.insert("w2".to_string(), 2048.0);

        scaler.unscale_gradients(&mut grads);

        assert_eq!(grads.get("w1").unwrap(), &1.0);
        assert_eq!(grads.get("w2").unwrap(), &2.0);
    }
}