tensorlogic_train/mixed_precision.rs

//! Mixed precision training infrastructure for memory efficiency and speed.
//!
//! This module provides utilities for training with reduced precision (FP16/BF16)
//! while maintaining numerical stability through loss scaling and gradient management.
//!
//! # Features
//! - FP16 and BF16 precision modes
//! - Dynamic and static loss scaling
//! - Gradient overflow detection
//! - Master weight management
//! - Automatic precision casting
//!
//! # Example
//!
//! ```rust,ignore
//! use tensorlogic_train::{MixedPrecisionTrainer, PrecisionMode, LossScaler};
//!
//! // Create FP16 trainer with dynamic loss scaling
//! let mut trainer = MixedPrecisionTrainer::new(
//!     PrecisionMode::FP16,
//!     LossScaler::dynamic(2.0_f32.powi(15), 2.0, 2000),
//! );
//!
//! // Train with automatic precision management
//! let scaled_loss = trainer.scale_loss(loss);
//! ```

use scirs2_core::ndarray::Array2;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

use crate::error::TrainResult;

/// Precision mode for mixed precision training.
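///
/// Illustrative usage of the accessors defined below:
///
/// ```rust,ignore
/// use tensorlogic_train::PrecisionMode;
///
/// // FP16 stores each element in two bytes, halving memory versus FP32.
/// assert_eq!(PrecisionMode::FP16.bytes_per_element(), 2);
/// assert_eq!(PrecisionMode::FP16.memory_reduction(), 2.0);
/// ```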
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum PrecisionMode {
    /// Full precision (FP32) - baseline
    FP32,
    /// Half precision (FP16) - 2x memory reduction
    FP16,
    /// Brain floating point (BF16) - better for training
    BF16,
}

impl PrecisionMode {
    /// Returns the number of bytes per element for this precision.
    pub fn bytes_per_element(&self) -> usize {
        match self {
            PrecisionMode::FP32 => 4,
            PrecisionMode::FP16 => 2,
            PrecisionMode::BF16 => 2,
        }
    }

    /// Returns the memory reduction factor compared to FP32.
    pub fn memory_reduction(&self) -> f32 {
        match self {
            PrecisionMode::FP32 => 1.0,
            PrecisionMode::FP16 => 2.0,
            PrecisionMode::BF16 => 2.0,
        }
    }

    /// Returns the typical numerical range for this precision.
    pub fn numerical_range(&self) -> (f32, f32) {
        match self {
            PrecisionMode::FP32 => (-3.4e38, 3.4e38),
            PrecisionMode::FP16 => (-6.55e4, 6.55e4),
            PrecisionMode::BF16 => (-3.39e38, 3.39e38), // Same exponent range as FP32
        }
    }
}

/// Loss scaling strategy for mixed precision training.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum LossScaler {
    /// No loss scaling (not recommended for FP16)
    None,
    /// Static loss scaling with fixed scale factor
    Static { scale: f32 },
    /// Dynamic loss scaling that adjusts based on gradient overflow
    Dynamic {
        /// Current scale factor
        scale: f32,
        /// Growth factor when no overflow (typically 2.0)
        growth_factor: f32,
        /// Backoff factor when overflow detected (typically 0.5)
        backoff_factor: f32,
        /// Number of successful steps before growing scale
        growth_interval: usize,
        /// Current step counter
        steps_since_overflow: usize,
    },
}

impl LossScaler {
    /// Creates a static loss scaler.
    pub fn static_scale(scale: f32) -> Self {
        Self::Static { scale }
    }

    /// Creates a dynamic loss scaler with typical defaults.
    ///
    /// # Arguments
    /// * `initial_scale` - Starting scale (typically 2^15 = 32768)
    /// * `growth_factor` - How much to grow scale (typically 2.0)
    /// * `growth_interval` - Steps before growing (typically 2000)
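    ///
    /// Illustrative usage with the commonly used defaults:
    ///
    /// ```rust,ignore
    /// // Start at 2^15 and double every 2000 overflow-free steps.
    /// let scaler = LossScaler::dynamic(2.0_f32.powi(15), 2.0, 2000);
    /// assert_eq!(scaler.get_scale(), 32768.0);
    /// ```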
    pub fn dynamic(initial_scale: f32, growth_factor: f32, growth_interval: usize) -> Self {
        Self::Dynamic {
            scale: initial_scale,
            growth_factor,
            backoff_factor: 0.5,
            growth_interval,
            steps_since_overflow: 0,
        }
    }

    /// Gets the current scale factor.
    pub fn get_scale(&self) -> f32 {
        match self {
            Self::None => 1.0,
            Self::Static { scale } => *scale,
            Self::Dynamic { scale, .. } => *scale,
        }
    }

    /// Scales a loss value.
    pub fn scale_loss(&self, loss: f32) -> f32 {
        loss * self.get_scale()
    }

    /// Unscales gradients (divides by scale).
    pub fn unscale_gradients(&self, gradients: &mut Array2<f32>) {
        let scale = self.get_scale();
        if scale != 1.0 {
            *gradients /= scale;
        }
    }

    /// Updates the dynamic scaler based on overflow detection.
    ///
    /// # Arguments
    /// * `overflow_detected` - Whether gradient overflow was detected
    ///
    /// # Returns
    /// True if the optimizer step should proceed
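    ///
    /// Illustrative behaviour of the dynamic variant (a sketch):
    ///
    /// ```rust,ignore
    /// let mut scaler = LossScaler::dynamic(1024.0, 2.0, 2000);
    /// // An overflow halves the scale and signals the caller to skip the step.
    /// assert!(!scaler.update(true));
    /// assert_eq!(scaler.get_scale(), 512.0);
    /// ```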
    pub fn update(&mut self, overflow_detected: bool) -> bool {
        if let Self::Dynamic {
            scale,
            growth_factor,
            backoff_factor,
            growth_interval,
            steps_since_overflow,
        } = self
        {
            if overflow_detected {
                // Backoff: reduce scale and reset counter
                *scale *= *backoff_factor;
                *steps_since_overflow = 0;
                false // Skip optimizer step
            } else {
                // Increment counter
                *steps_since_overflow += 1;

                // Grow scale if interval reached
                if *steps_since_overflow >= *growth_interval {
                    *scale *= *growth_factor;
                    *steps_since_overflow = 0;
                }
                true // Proceed with optimizer step
            }
        } else {
            // Static or None scaling: proceed only if no overflow was detected
            !overflow_detected
        }
    }
}

/// Mixed precision training manager.
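///
/// Sketch of a single training step. Here `loss` comes from the forward pass, and
/// `gradients` / `updates` are placeholders for whatever the caller's backward pass
/// and optimizer produce:
///
/// ```rust,ignore
/// use tensorlogic_train::{MixedPrecisionTrainer, PrecisionMode, LossScaler};
///
/// let mut trainer = MixedPrecisionTrainer::new(
///     PrecisionMode::BF16,
///     LossScaler::dynamic(2.0_f32.powi(15), 2.0, 2000),
/// );
///
/// // Scale the loss before backpropagation ...
/// let scaled_loss = trainer.scale_loss(loss);
///
/// // ... then unscale the resulting gradients and check for overflow.
/// let (should_step, _overflow) = trainer.unscale_and_check_gradients(&mut gradients)?;
/// if should_step {
///     trainer.update_master_weights(&updates);
/// }
/// ```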
pub struct MixedPrecisionTrainer {
    /// Precision mode
    mode: PrecisionMode,
    /// Loss scaler
    scaler: LossScaler,
    /// Master weights (FP32) - keeps full precision copy
    master_weights: HashMap<String, Array2<f32>>,
    /// Training statistics
    stats: MixedPrecisionStats,
}

impl MixedPrecisionTrainer {
    /// Creates a new mixed precision trainer.
    pub fn new(mode: PrecisionMode, scaler: LossScaler) -> Self {
        Self {
            mode,
            scaler,
            master_weights: HashMap::new(),
            stats: MixedPrecisionStats::default(),
        }
    }

    /// Registers weights to maintain master copy.
    pub fn register_weights(&mut self, name: String, weights: Array2<f32>) {
        self.master_weights.insert(name, weights);
    }

    /// Converts FP32 weights to working precision.
    pub fn cast_to_working_precision(&self, weights: &Array2<f32>) -> Array2<f32> {
        match self.mode {
            PrecisionMode::FP32 => weights.clone(),
            PrecisionMode::FP16 => self.simulate_fp16(weights),
            PrecisionMode::BF16 => self.simulate_bf16(weights),
        }
    }

    /// Simulates FP16 precision (values stay in an FP32 container for compatibility).
    fn simulate_fp16(&self, weights: &Array2<f32>) -> Array2<f32> {
        weights.mapv(|x| {
            // Clamp to the representable FP16 range.
            let clamped = x.clamp(-65504.0, 65504.0);
            // Coarse approximation of the 10-bit FP16 mantissa: quantize to a
            // fixed step of 2^-10 (real FP16 spacing varies with magnitude).
            let scale = 2.0_f32.powi(10);
            (clamped * scale).round() / scale
        })
    }

    /// Simulates BF16 precision (values stay in an FP32 container for compatibility).
    fn simulate_bf16(&self, weights: &Array2<f32>) -> Array2<f32> {
        weights.mapv(|x| {
            // BF16 keeps the FP32 exponent range but only 7 mantissa bits;
            // approximate that here by quantizing to a fixed step of 2^-7.
            let scale = 2.0_f32.powi(7);
            (x * scale).round() / scale
        })
    }

    /// Scales loss for backward pass.
    pub fn scale_loss(&mut self, loss: f32) -> f32 {
        self.stats.total_steps += 1;
        self.scaler.scale_loss(loss)
    }

    /// Unscales and checks gradients for overflow.
    ///
    /// # Returns
    /// (should_step, overflow_detected)
    pub fn unscale_and_check_gradients(
        &mut self,
        gradients: &mut HashMap<String, Array2<f32>>,
    ) -> TrainResult<(bool, bool)> {
        // Check the scaled gradients for non-finite values (overflow) before unscaling
        let overflow = gradients
            .values()
            .any(|grad| grad.iter().any(|&x| !x.is_finite()));

        if overflow {
            self.stats.overflow_steps += 1;
        }

        // Unscale gradients (divide by the current loss scale)
        for grad in gradients.values_mut() {
            self.scaler.unscale_gradients(grad);
        }

        // Update the scaler and determine whether the optimizer step should proceed
        let should_step = self.scaler.update(overflow);
        if should_step {
            self.stats.successful_steps += 1;
        }

        Ok((should_step, overflow))
    }

    /// Applies parameter updates (e.g., optimizer deltas) to the FP32 master weights.
    pub fn update_master_weights(&mut self, updates: &HashMap<String, Array2<f32>>) {
        for (name, update) in updates {
            if let Some(master) = self.master_weights.get_mut(name) {
                *master += update;
            }
        }
    }

    /// Gets the current precision mode.
    pub fn mode(&self) -> PrecisionMode {
        self.mode
    }

    /// Gets the current loss scale.
    pub fn current_scale(&self) -> f32 {
        self.scaler.get_scale()
    }

    /// Gets training statistics.
    pub fn stats(&self) -> &MixedPrecisionStats {
        &self.stats
    }

    /// Resets statistics.
    pub fn reset_stats(&mut self) {
        self.stats = MixedPrecisionStats::default();
    }
}

/// Statistics for mixed precision training.
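///
/// The rates are simple ratios over `total_steps`; for example:
///
/// ```rust,ignore
/// let stats = MixedPrecisionStats {
///     total_steps: 100,
///     overflow_steps: 5,
///     successful_steps: 95,
/// };
/// assert_eq!(stats.overflow_rate(), 0.05);
/// assert_eq!(stats.success_rate(), 0.95);
/// ```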
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct MixedPrecisionStats {
    /// Total training steps attempted
    pub total_steps: usize,
    /// Steps with gradient overflow
    pub overflow_steps: usize,
    /// Successful optimizer steps
    pub successful_steps: usize,
}

impl MixedPrecisionStats {
    /// Calculates overflow rate.
    pub fn overflow_rate(&self) -> f32 {
        if self.total_steps == 0 {
            0.0
        } else {
            self.overflow_steps as f32 / self.total_steps as f32
        }
    }

    /// Calculates success rate.
    pub fn success_rate(&self) -> f32 {
        if self.total_steps == 0 {
            0.0
        } else {
            self.successful_steps as f32 / self.total_steps as f32
        }
    }
}

/// Gradient scaler for automatic mixed precision.
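///
/// Sketch of the intended flow, where `loss`, `gradients`, and `overflow_detected`
/// stand in for values produced by the caller's forward/backward pass:
///
/// ```rust,ignore
/// let mut scaler = GradientScaler::new(true);
///
/// let scaled_loss = scaler.scale(loss);
/// // ... backward pass on the scaled loss ...
/// scaler.unscale(&mut gradients);
/// if scaler.step(overflow_detected) {
///     // Safe to apply the optimizer step.
/// }
/// ```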
pub struct GradientScaler {
    scaler: LossScaler,
    enabled: bool,
}

impl GradientScaler {
    /// Creates a new gradient scaler.
    pub fn new(enabled: bool) -> Self {
        let scaler = if enabled {
            LossScaler::dynamic(2.0_f32.powi(15), 2.0, 2000)
        } else {
            LossScaler::None
        };

        Self { scaler, enabled }
    }

    /// Creates a gradient scaler with custom settings.
    pub fn with_scaler(scaler: LossScaler, enabled: bool) -> Self {
        Self { scaler, enabled }
    }

    /// Scales a loss value if scaling is enabled.
    pub fn scale(&self, loss: f32) -> f32 {
        if self.enabled {
            self.scaler.scale_loss(loss)
        } else {
            loss
        }
    }

    /// Unscales gradients.
    pub fn unscale(&self, gradients: &mut Array2<f32>) {
        if self.enabled {
            self.scaler.unscale_gradients(gradients);
        }
    }

    /// Updates the scaler after a step; returns whether the optimizer step should proceed.
    pub fn step(&mut self, overflow_detected: bool) -> bool {
        if self.enabled {
            self.scaler.update(overflow_detected)
        } else {
            !overflow_detected
        }
    }

    /// Gets the current scale.
    pub fn get_scale(&self) -> f32 {
        self.scaler.get_scale()
    }
}

/// Automatic Mixed Precision (AMP) context manager.
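///
/// Illustrative usage (a sketch):
///
/// ```rust,ignore
/// use scirs2_core::ndarray::Array2;
///
/// let ctx = AutocastContext::new(true, PrecisionMode::FP16);
/// let activations = Array2::from_elem((2, 2), 1.234_567_f32);
/// // Values are quantized to (simulated) FP16 precision when autocast is enabled.
/// let casted = ctx.cast(&activations);
/// ```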
pub struct AutocastContext {
    enabled: bool,
    mode: PrecisionMode,
}

impl AutocastContext {
    /// Creates a new autocast context.
    pub fn new(enabled: bool, mode: PrecisionMode) -> Self {
        Self { enabled, mode }
    }

    /// Checks if autocast is enabled.
    pub fn is_enabled(&self) -> bool {
        self.enabled
    }

    /// Gets the target precision mode.
    pub fn mode(&self) -> PrecisionMode {
        self.mode
    }

    /// Casts tensor to working precision if enabled.
    pub fn cast(&self, tensor: &Array2<f32>) -> Array2<f32> {
        if !self.enabled || self.mode == PrecisionMode::FP32 {
            return tensor.clone();
        }

        match self.mode {
            PrecisionMode::FP16 => self.simulate_fp16(tensor),
            PrecisionMode::BF16 => self.simulate_bf16(tensor),
            PrecisionMode::FP32 => tensor.clone(),
        }
    }

    /// Same coarse FP16 approximation used by `MixedPrecisionTrainer::simulate_fp16`.
    fn simulate_fp16(&self, tensor: &Array2<f32>) -> Array2<f32> {
        tensor.mapv(|x| {
            // Clamp to the FP16 range, then quantize to a fixed step of 2^-10.
            let clamped = x.clamp(-65504.0, 65504.0);
            let scale = 2.0_f32.powi(10);
            (clamped * scale).round() / scale
        })
    }

    /// Same coarse BF16 approximation used by `MixedPrecisionTrainer::simulate_bf16`.
    fn simulate_bf16(&self, tensor: &Array2<f32>) -> Array2<f32> {
        tensor.mapv(|x| {
            // Quantize to a fixed step of 2^-7 to mimic the reduced BF16 mantissa.
            let scale = 2.0_f32.powi(7);
            (x * scale).round() / scale
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn test_precision_mode_properties() {
        assert_eq!(PrecisionMode::FP32.bytes_per_element(), 4);
        assert_eq!(PrecisionMode::FP16.bytes_per_element(), 2);
        assert_eq!(PrecisionMode::BF16.bytes_per_element(), 2);

        assert_eq!(PrecisionMode::FP16.memory_reduction(), 2.0);
        assert_eq!(PrecisionMode::BF16.memory_reduction(), 2.0);
    }

    #[test]
    fn test_static_loss_scaler() {
        let scaler = LossScaler::static_scale(1024.0);
        assert_eq!(scaler.get_scale(), 1024.0);

        let loss = 0.5;
        let scaled = scaler.scale_loss(loss);
        assert_eq!(scaled, 512.0);
    }

    #[test]
    fn test_dynamic_loss_scaler() {
        let mut scaler = LossScaler::dynamic(1000.0, 2.0, 3);
        assert_eq!(scaler.get_scale(), 1000.0);

        // No overflow, should grow after 3 steps
        assert!(scaler.update(false));
        assert!(scaler.update(false));
        assert!(scaler.update(false));
        assert_eq!(scaler.get_scale(), 2000.0); // Grew

        // Overflow, should backoff
        assert!(!scaler.update(true));
        assert_eq!(scaler.get_scale(), 1000.0); // Backoff
    }

    #[test]
    fn test_gradient_unscaling() {
        let mut gradients =
            Array2::from_shape_vec((2, 2), vec![100.0, 200.0, 300.0, 400.0]).unwrap();
        let scaler = LossScaler::static_scale(10.0);

        scaler.unscale_gradients(&mut gradients);

        assert_eq!(gradients[[0, 0]], 10.0);
        assert_eq!(gradients[[0, 1]], 20.0);
        assert_eq!(gradients[[1, 0]], 30.0);
        assert_eq!(gradients[[1, 1]], 40.0);
    }

    #[test]
    fn test_mixed_precision_trainer() {
        let mut trainer =
            MixedPrecisionTrainer::new(PrecisionMode::FP16, LossScaler::static_scale(100.0));

        let loss = 0.5;
        let scaled_loss = trainer.scale_loss(loss);
        assert_eq!(scaled_loss, 50.0);
        assert_eq!(trainer.stats().total_steps, 1);
    }

    #[test]
    fn test_fp16_simulation() {
        let trainer = MixedPrecisionTrainer::new(PrecisionMode::FP16, LossScaler::None);

        let weights =
            Array2::from_shape_vec((2, 2), vec![1.234_567, 100000.0, -100000.0, 0.0001]).unwrap();
        let fp16_weights = trainer.cast_to_working_precision(&weights);

        // Should be quantized
        assert_ne!(fp16_weights[[0, 0]], 1.234_567); // Reduced precision
        assert!(fp16_weights[[0, 0]] > 1.0 && fp16_weights[[0, 0]] < 2.0);

        // Large values should be clamped to FP16 range
        assert!(fp16_weights[[0, 1]] <= 65504.0);
        assert!(fp16_weights[[1, 0]] >= -65504.0);
    }

    #[test]
    fn test_bf16_simulation() {
        let trainer = MixedPrecisionTrainer::new(PrecisionMode::BF16, LossScaler::None);

        let weights =
            Array2::from_shape_vec((2, 2), vec![1.234_567, 100.5, -50.25, 0.125]).unwrap();
        let bf16_weights = trainer.cast_to_working_precision(&weights);

        // Should have reduced mantissa precision
        assert_ne!(bf16_weights[[0, 0]], 1.234_567);
    }

    #[test]
    fn test_overflow_detection() {
        let mut trainer =
            MixedPrecisionTrainer::new(PrecisionMode::FP16, LossScaler::dynamic(1000.0, 2.0, 100));

        let mut gradients = HashMap::new();
        gradients.insert(
            "layer1".to_string(),
            Array2::from_shape_vec((2, 2), vec![f32::INFINITY, 1.0, 2.0, 3.0]).unwrap(),
        );

        let (should_step, overflow) = trainer.unscale_and_check_gradients(&mut gradients).unwrap();

        assert!(!should_step);
        assert!(overflow);
        assert_eq!(trainer.stats().overflow_steps, 1);
    }

    #[test]
    fn test_gradient_scaler() {
        let scaler = GradientScaler::new(true);

        let loss = 1.0;
        let scaled = scaler.scale(loss);
        assert!(scaled > loss); // Should be scaled

        let mut grads = Array2::from_shape_vec((2, 2), vec![1000.0; 4]).unwrap();
        scaler.unscale(&mut grads);
        assert!(grads[[0, 0]] < 1000.0); // Should be unscaled
    }

    #[test]
    fn test_autocast_context() {
        let ctx = AutocastContext::new(true, PrecisionMode::FP16);
        assert!(ctx.is_enabled());
        assert_eq!(ctx.mode(), PrecisionMode::FP16);

        let tensor = Array2::from_shape_vec((2, 2), vec![1.234_567; 4]).unwrap();
        let casted = ctx.cast(&tensor);

        // Should have reduced precision
        assert_ne!(casted[[0, 0]], 1.234_567);
    }

    #[test]
    fn test_autocast_disabled() {
        let ctx = AutocastContext::new(false, PrecisionMode::FP16);
        assert!(!ctx.is_enabled());

        let tensor = Array2::from_shape_vec((2, 2), vec![1.234_567; 4]).unwrap();
        let casted = ctx.cast(&tensor);

        // Should be unchanged
        assert_eq!(casted, tensor);
    }

    #[test]
    fn test_master_weights_update() {
        let mut trainer = MixedPrecisionTrainer::new(PrecisionMode::FP16, LossScaler::None);

        let weights = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        trainer.register_weights("layer1".to_string(), weights.clone());

        let mut updates = HashMap::new();
        updates.insert(
            "layer1".to_string(),
            Array2::from_shape_vec((2, 2), vec![0.1, 0.1, 0.1, 0.1]).unwrap(),
        );

        trainer.update_master_weights(&updates);

        let master = &trainer.master_weights["layer1"];
        assert_relative_eq!(master[[0, 0]], 1.1, epsilon = 1e-6);
    }

    #[test]
    fn test_mixed_precision_stats() {
        let stats = MixedPrecisionStats {
            total_steps: 100,
            overflow_steps: 5,
            successful_steps: 95,
        };

        assert_eq!(stats.overflow_rate(), 0.05);
        assert_eq!(stats.success_rate(), 0.95);
    }

    #[test]
    fn test_loss_scaler_growth() {
        let mut scaler = LossScaler::dynamic(1000.0, 2.0, 2);

        // First successful step
        assert!(scaler.update(false));
        assert_eq!(scaler.get_scale(), 1000.0);

        // Second successful step - should trigger growth
        assert!(scaler.update(false));
        assert_eq!(scaler.get_scale(), 2000.0);
    }
}