sklears_ensemble/mixed_precision.rs

//! Mixed-precision training support for ensemble methods
//!
//! This module provides mixed-precision training capabilities using FP16 and FP32
//! to reduce memory usage and improve training speed while maintaining numerical stability.
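//!
//! # Example
//!
//! A minimal sketch of the intended workflow (the import path assumes this
//! module is exported as `sklears_ensemble::mixed_precision`):
//!
//! ```ignore
//! use scirs2_core::ndarray::array;
//! use sklears_ensemble::mixed_precision::MixedPrecisionTrainer;
//!
//! // Trainer with mixed precision enabled and default dynamic loss scaling.
//! let mut trainer = MixedPrecisionTrainer::enable();
//!
//! // Scale gradients before the backward pass, unscale afterwards, and
//! // adapt the loss scale based on whether an overflow was detected.
//! let mut gradients = array![[1.0e-8, 2.0e-8], [3.0e-8, 4.0e-8]];
//! trainer.scale_gradients(&mut gradients);
//! let overflow = trainer.unscale_gradients(&mut gradients);
//! trainer.update_scale(overflow);
//! ```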

use scirs2_core::ndarray::{Array1, Array2};
use sklears_core::error::{Result, SklearsError};
use sklears_core::types::{Float, Int};
use std::collections::HashMap;

/// Half-precision floating point type (FP16)
pub type Half = half::f16;

/// Mixed precision configuration
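///
/// # Examples
///
/// A sketch of overriding selected fields while keeping the rest of the
/// defaults:
///
/// ```ignore
/// let config = MixedPrecisionConfig {
///     enabled: true,
///     growth_interval: 1000,
///     ..Default::default()
/// };
/// assert!(config.dynamic_loss_scaling);
/// ```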
#[derive(Debug, Clone)]
pub struct MixedPrecisionConfig {
    /// Enable mixed precision training
    pub enabled: bool,
    /// Loss scaling factor to prevent gradient underflow
    pub loss_scale: Float,
    /// Dynamic loss scaling
    pub dynamic_loss_scaling: bool,
    /// Initial loss scale for dynamic scaling
    pub initial_loss_scale: Float,
    /// Growth factor for loss scaling
    pub growth_factor: Float,
    /// Backoff factor when overflow is detected
    pub backoff_factor: Float,
    /// Number of steps without overflow before increasing scale
    pub growth_interval: usize,
    /// Operations to keep in FP32 (for numerical stability)
    pub fp32_operations: Vec<String>,
    /// Use automatic mixed precision (AMP)
    pub use_amp: bool,
    /// Gradient clipping threshold
    pub gradient_clip_threshold: Option<Float>,
}

impl Default for MixedPrecisionConfig {
    fn default() -> Self {
        Self {
            enabled: false,
            loss_scale: 65536.0, // 2^16
            dynamic_loss_scaling: true,
            initial_loss_scale: 65536.0,
            growth_factor: 2.0,
            backoff_factor: 0.5,
            growth_interval: 2000,
            fp32_operations: vec![
                "loss_computation".to_string(),
                "batch_norm".to_string(),
                "layer_norm".to_string(),
                "softmax".to_string(),
            ],
            use_amp: true,
            gradient_clip_threshold: Some(1.0),
        }
    }
}

/// Mixed precision trainer
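///
/// Wraps a [`MixedPrecisionConfig`] together with the dynamic loss-scaling
/// state (current scale, overflow counters, [`ScalerState`]).
///
/// # Examples
///
/// A sketch of a scale/unscale round trip (scaling by a power of two is
/// exact, so the gradients are recovered unchanged):
///
/// ```ignore
/// use scirs2_core::ndarray::array;
///
/// let trainer = MixedPrecisionTrainer::enable();
/// let mut grads = array![[0.5, 1.5]];
/// trainer.scale_gradients(&mut grads);
/// let overflow = trainer.unscale_gradients(&mut grads);
/// assert!(!overflow);
/// assert_eq!(grads, array![[0.5, 1.5]]);
/// ```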
pub struct MixedPrecisionTrainer {
    config: MixedPrecisionConfig,
    current_loss_scale: Float,
    overflow_count: usize,
    successful_steps: usize,
    scaler_state: ScalerState,
}

/// Loss scaler state
#[derive(Debug, Clone)]
pub struct ScalerState {
    pub scale: Float,
    pub growth_tracker: usize,
    pub overflow_detected: bool,
    pub should_skip_step: bool,
}

/// Mixed precision data types
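///
/// # Examples
///
/// A sketch comparing the per-element storage of the `Full` and `Half`
/// variants:
///
/// ```ignore
/// use scirs2_core::ndarray::array;
///
/// let full = MixedPrecisionArray::Full(array![[1.0, 2.0], [3.0, 4.0]]);
/// let half = MixedPrecisionArray::Half(
///     array![[1.0f32, 2.0], [3.0, 4.0]].map(|&x| Half::from_f32(x)),
/// );
/// assert!(half.is_mixed_precision());
/// assert!(half.memory_usage_bytes() < full.memory_usage_bytes());
/// ```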
#[derive(Debug, Clone)]
pub enum MixedPrecisionArray {
    /// Full precision (FP32)
    Full(Array2<Float>),
    /// Half precision (FP16)
    Half(Array2<Half>),
    /// Mixed arrays with different precisions per operation
    Mixed {
        fp32_data: Array2<Float>,
        fp16_data: Array2<Half>,
        precision_mask: Array2<bool>, // true = FP32, false = FP16
    },
}

/// Gradient accumulator with mixed precision
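///
/// # Examples
///
/// A sketch of accumulating two FP32 gradient arrays and averaging them:
///
/// ```ignore
/// use scirs2_core::ndarray::array;
///
/// let mut acc = MixedPrecisionGradientAccumulator::new();
/// acc.accumulate("layer1", &MixedPrecisionArray::Full(array![[1.0, 2.0]])).unwrap();
/// acc.accumulate("layer1", &MixedPrecisionArray::Full(array![[3.0, 4.0]])).unwrap();
/// let averaged = acc.get_averaged_gradients();
/// assert_eq!(averaged["layer1"], array![[2.0, 3.0]]);
/// ```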
pub struct MixedPrecisionGradientAccumulator {
    fp32_gradients: HashMap<String, Array2<Float>>,
    fp16_gradients: HashMap<String, Array2<Half>>,
    accumulation_count: usize,
}

/// Automatic Mixed Precision (AMP) context
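///
/// # Examples
///
/// Autocast is scoped: it is enabled only inside the closure passed to
/// [`AMPContext::autocast`]. A sketch:
///
/// ```ignore
/// let mut amp = AMPContext::new(MixedPrecisionConfig::default());
/// amp.autocast(|ctx| {
///     assert!(ctx.is_autocast_enabled());
///     // forward pass would run in reduced precision here
/// });
/// assert!(!amp.is_autocast_enabled());
/// ```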
pub struct AMPContext {
    config: MixedPrecisionConfig,
    scaler: GradientScaler,
    autocast_enabled: bool,
}

/// Gradient scaler for mixed precision training
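///
/// # Examples
///
/// The scale halves on overflow and doubles after `growth_interval`
/// consecutive clean steps; a sketch with a small interval:
///
/// ```ignore
/// let mut scaler = GradientScaler::new(1024.0, 2, 0.5, 2.0);
/// scaler.update(true); // overflow: 1024 * 0.5 = 512
/// assert_eq!(scaler.get_scale(), 512.0);
/// scaler.update(false);
/// scaler.update(false); // two clean steps: 512 * 2.0 = 1024
/// assert_eq!(scaler.get_scale(), 1024.0);
/// ```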
pub struct GradientScaler {
    scale: Float,
    growth_tracker: usize,
    growth_interval: usize,
    backoff_factor: Float,
    growth_factor: Float,
}

impl MixedPrecisionTrainer {
    /// Create new mixed precision trainer
    pub fn new(config: MixedPrecisionConfig) -> Self {
        let current_loss_scale = if config.dynamic_loss_scaling {
            config.initial_loss_scale
        } else {
            config.loss_scale
        };

        Self {
            config,
            current_loss_scale,
            overflow_count: 0,
            successful_steps: 0,
            scaler_state: ScalerState {
                scale: current_loss_scale,
                growth_tracker: 0,
                overflow_detected: false,
                should_skip_step: false,
            },
        }
    }

    /// Create a trainer with mixed precision enabled and default settings
    pub fn enable() -> Self {
        Self::new(MixedPrecisionConfig {
            enabled: true,
            ..Default::default()
        })
    }

    /// Convert array to mixed precision format
    pub fn to_mixed_precision(
        &self,
        array: &Array2<Float>,
        operation_name: &str,
    ) -> MixedPrecisionArray {
        if !self.config.enabled {
            return MixedPrecisionArray::Full(array.clone());
        }

        if self
            .config
            .fp32_operations
            .iter()
            .any(|op| op.as_str() == operation_name)
        {
            // Keep in FP32 for numerical stability
            MixedPrecisionArray::Full(array.clone())
        } else {
            // Convert to FP16
            let half_array = array.map(|&x| Half::from_f32(x as f32));
            MixedPrecisionArray::Half(half_array)
        }
    }

    /// Convert mixed precision array back to FP32
    pub fn to_full_precision(&self, array: &MixedPrecisionArray) -> Array2<Float> {
        match array {
            MixedPrecisionArray::Full(arr) => arr.clone(),
            MixedPrecisionArray::Half(arr) => arr.map(|&x| x.to_f32() as Float),
            MixedPrecisionArray::Mixed {
                fp32_data,
                fp16_data,
                precision_mask,
            } => {
                let mut result = Array2::zeros(fp32_data.dim());
                for ((i, j), &use_fp32) in precision_mask.indexed_iter() {
                    result[[i, j]] = if use_fp32 {
                        fp32_data[[i, j]]
                    } else {
                        fp16_data[[i, j]].to_f32() as Float
                    };
                }
                result
            }
        }
    }

    /// Scale gradients to prevent underflow
    pub fn scale_gradients(&self, gradients: &mut Array2<Float>) {
        if self.config.enabled {
            *gradients *= self.current_loss_scale;
        }
    }

    /// Unscale gradients after the backward pass.
    ///
    /// Returns `true` if any gradient is non-finite (overflow). Gradients are
    /// only divided by the loss scale when no overflow is detected.
    pub fn unscale_gradients(&self, gradients: &mut Array2<Float>) -> bool {
        if !self.config.enabled {
            return false;
        }

        // Check for overflow/infinities
        let has_overflow = gradients.iter().any(|&x| !x.is_finite());

        if !has_overflow {
            *gradients /= self.current_loss_scale;
        }

        has_overflow
    }

    /// Update loss scale based on overflow detection
    pub fn update_scale(&mut self, overflow_detected: bool) {
        if !self.config.dynamic_loss_scaling {
            return;
        }

        self.scaler_state.overflow_detected = overflow_detected;

        if overflow_detected {
            // Reduce scale on overflow
            self.current_loss_scale *= self.config.backoff_factor;
            self.current_loss_scale = self.current_loss_scale.max(1.0);
            self.overflow_count += 1;
            self.successful_steps = 0;
            self.scaler_state.should_skip_step = true;
        } else {
            // Increase scale after successful steps
            self.successful_steps += 1;
            self.scaler_state.should_skip_step = false;

            if self.successful_steps >= self.config.growth_interval {
                self.current_loss_scale *= self.config.growth_factor;
                self.successful_steps = 0;
            }
        }

        self.scaler_state.scale = self.current_loss_scale;
    }

    /// Check if current step should be skipped due to overflow
    pub fn should_skip_step(&self) -> bool {
        self.scaler_state.should_skip_step
    }

    /// Get current loss scale
    pub fn get_loss_scale(&self) -> Float {
        self.current_loss_scale
    }

    /// Train ensemble with mixed precision
    pub fn train_ensemble_mixed_precision<F>(
        &mut self,
        x: &Array2<Float>,
        y: &Array1<Int>,
        n_estimators: usize,
        mut train_fn: F,
    ) -> Result<Vec<Array1<Float>>>
    where
        F: FnMut(&MixedPrecisionArray, &Array1<Int>) -> Result<Array1<Float>>,
    {
        let mut models = Vec::new();

        for _ in 0..n_estimators {
            // Convert input to mixed precision
            let x_mixed = self.to_mixed_precision(x, "forward_pass");

            // Train single model
            let model = train_fn(&x_mixed, y)?;

            // Apply gradient scaling if needed
            // In a real implementation, this would involve the actual gradient computation

            models.push(model);

            // Update loss scale based on training stability
            let overflow_detected = false; // Would be detected during actual training
            self.update_scale(overflow_detected);
        }

        Ok(models)
    }

    /// Get scaler state
    pub fn scaler_state(&self) -> &ScalerState {
        &self.scaler_state
    }

    /// Reset scaler state
    pub fn reset_scaler(&mut self) {
        self.current_loss_scale = self.config.initial_loss_scale;
        self.overflow_count = 0;
        self.successful_steps = 0;
        self.scaler_state = ScalerState {
            scale: self.current_loss_scale,
            growth_tracker: 0,
            overflow_detected: false,
            should_skip_step: false,
        };
    }
}

impl MixedPrecisionArray {
    /// Get the shape of the array
    pub fn shape(&self) -> (usize, usize) {
        match self {
            MixedPrecisionArray::Full(arr) => arr.dim(),
            MixedPrecisionArray::Half(arr) => arr.dim(),
            MixedPrecisionArray::Mixed { fp32_data, .. } => fp32_data.dim(),
        }
    }

    /// Check if array uses mixed precision
    pub fn is_mixed_precision(&self) -> bool {
        matches!(
            self,
            MixedPrecisionArray::Half(_) | MixedPrecisionArray::Mixed { .. }
        )
    }

    /// Get memory usage in bytes
    pub fn memory_usage_bytes(&self) -> usize {
        match self {
            MixedPrecisionArray::Full(arr) => arr.len() * std::mem::size_of::<Float>(),
            MixedPrecisionArray::Half(arr) => arr.len() * std::mem::size_of::<Half>(),
            MixedPrecisionArray::Mixed { precision_mask, .. } => {
                let fp32_count = precision_mask.iter().filter(|&&x| x).count();
                let fp16_count = precision_mask.len() - fp32_count;
                fp32_count * std::mem::size_of::<Float>()
                    + fp16_count * std::mem::size_of::<Half>()
                    + precision_mask.len() * std::mem::size_of::<bool>()
            }
        }
    }

    /// Element-wise addition with automatic precision handling
    pub fn add(&self, other: &Self) -> Result<Self> {
        // Flatten any variant into a dense FP32 array (used for cross-precision operands).
        fn to_fp32(arr: &MixedPrecisionArray) -> Array2<Float> {
            match arr {
                MixedPrecisionArray::Full(a) => a.clone(),
                MixedPrecisionArray::Half(a) => a.map(|&x| x.to_f32() as Float),
                MixedPrecisionArray::Mixed {
                    fp32_data,
                    fp16_data,
                    precision_mask,
                } => {
                    let mut result = Array2::zeros(fp32_data.dim());
                    for ((i, j), &use_fp32) in precision_mask.indexed_iter() {
                        result[[i, j]] = if use_fp32 {
                            fp32_data[[i, j]]
                        } else {
                            fp16_data[[i, j]].to_f32() as Float
                        };
                    }
                    result
                }
            }
        }

        match (self, other) {
            (MixedPrecisionArray::Full(a), MixedPrecisionArray::Full(b)) => {
                Ok(MixedPrecisionArray::Full(a + b))
            }
            (MixedPrecisionArray::Half(a), MixedPrecisionArray::Half(b)) => {
                Ok(MixedPrecisionArray::Half(a + b))
            }
            // Operands with different (or per-element mixed) precisions are
            // promoted to full precision before the addition.
            _ => Ok(MixedPrecisionArray::Full(&to_fp32(self) + &to_fp32(other))),
        }
    }
}

impl Default for MixedPrecisionGradientAccumulator {
    fn default() -> Self {
        Self::new()
    }
}

impl MixedPrecisionGradientAccumulator {
    /// Create new gradient accumulator
    pub fn new() -> Self {
        Self {
            fp32_gradients: HashMap::new(),
            fp16_gradients: HashMap::new(),
            accumulation_count: 0,
        }
    }

    /// Accumulate gradients with mixed precision
    pub fn accumulate(&mut self, name: &str, gradients: &MixedPrecisionArray) -> Result<()> {
        match gradients {
            MixedPrecisionArray::Full(grads) => {
                let entry = self
                    .fp32_gradients
                    .entry(name.to_string())
                    .or_insert_with(|| Array2::zeros(grads.dim()));
                *entry = entry.clone() + grads;
            }
            MixedPrecisionArray::Half(grads) => {
                let entry = self
                    .fp16_gradients
                    .entry(name.to_string())
                    .or_insert_with(|| Array2::zeros(grads.dim()));
                *entry = entry.clone() + grads;
            }
            MixedPrecisionArray::Mixed {
                fp32_data,
                fp16_data,
                precision_mask,
            } => {
                // Flatten the mixed variant to full precision for accumulation
                let mut full_grads = Array2::zeros(fp32_data.dim());
                for ((i, j), &use_fp32) in precision_mask.indexed_iter() {
                    full_grads[[i, j]] = if use_fp32 {
                        fp32_data[[i, j]]
                    } else {
                        fp16_data[[i, j]].to_f32() as Float
                    };
                }

                let entry = self
                    .fp32_gradients
                    .entry(name.to_string())
                    .or_insert_with(|| Array2::zeros(full_grads.dim()));
                *entry = entry.clone() + &full_grads;
            }
        }

        self.accumulation_count += 1;
        Ok(())
    }

    /// Get averaged gradients
    pub fn get_averaged_gradients(&self) -> HashMap<String, Array2<Float>> {
        let mut result = HashMap::new();

        // Average FP32 gradients
        for (name, grads) in &self.fp32_gradients {
            result.insert(
                name.clone(),
                grads.clone() / self.accumulation_count as Float,
            );
        }

        // Convert FP16 gradients to FP32 and average
        for (name, grads) in &self.fp16_gradients {
            let fp32_grads = grads.map(|&x| x.to_f32() as Float);
            result.insert(name.clone(), fp32_grads / self.accumulation_count as Float);
        }

        result
    }

    /// Clear accumulated gradients
    pub fn clear(&mut self) {
        self.fp32_gradients.clear();
        self.fp16_gradients.clear();
        self.accumulation_count = 0;
    }
}

impl AMPContext {
    /// Create new AMP context
    pub fn new(config: MixedPrecisionConfig) -> Self {
        let scaler = GradientScaler::new(
            config.initial_loss_scale,
            config.growth_interval,
            config.backoff_factor,
            config.growth_factor,
        );

        Self {
            config,
            scaler,
            autocast_enabled: false,
        }
    }

    /// Enable autocast for current scope
    pub fn autocast<F, R>(&mut self, f: F) -> R
    where
        F: FnOnce(&mut Self) -> R,
    {
        let old_state = self.autocast_enabled;
        self.autocast_enabled = true;
        let result = f(self);
        self.autocast_enabled = old_state;
        result
    }

    /// Check if autocast is enabled
    pub fn is_autocast_enabled(&self) -> bool {
        self.autocast_enabled
    }

    /// Scale loss for backward pass
    pub fn scale_loss(&mut self, loss: Float) -> Float {
        self.scaler.scale(loss)
    }

    /// Step optimizer with gradient scaling.
    ///
    /// Runs `optimizer_step` unless the scaler indicates the step should be
    /// skipped; returns `true` if the step was taken.
    pub fn step<F>(&mut self, optimizer_step: F) -> bool
    where
        F: FnOnce(),
    {
        if !self.scaler.should_skip_step() {
            optimizer_step();
            self.scaler.update(false); // No overflow
            true
        } else {
            self.scaler.update(true); // Overflow detected
            false
        }
    }
}

impl GradientScaler {
    /// Create new gradient scaler
    pub fn new(
        initial_scale: Float,
        growth_interval: usize,
        backoff_factor: Float,
        growth_factor: Float,
    ) -> Self {
        Self {
            scale: initial_scale,
            growth_tracker: 0,
            growth_interval,
            backoff_factor,
            growth_factor,
        }
    }

    /// Scale value
    pub fn scale(&self, value: Float) -> Float {
        value * self.scale
    }

    /// Unscale value
    pub fn unscale(&self, value: Float) -> Float {
        value / self.scale
    }

    /// Update scale based on overflow detection
    pub fn update(&mut self, overflow_detected: bool) {
        if overflow_detected {
            self.scale *= self.backoff_factor;
            self.scale = self.scale.max(1.0);
            self.growth_tracker = 0;
        } else {
            self.growth_tracker += 1;
            if self.growth_tracker >= self.growth_interval {
                self.scale *= self.growth_factor;
                self.growth_tracker = 0;
            }
        }
    }

    /// Check if step should be skipped.
    ///
    /// Note: `update` clamps the scale at 1.0, so this only returns `true`
    /// when the scaler was constructed with an initial scale below 1.0.
    pub fn should_skip_step(&self) -> bool {
        self.scale < 1.0
    }

    /// Get current scale
    pub fn get_scale(&self) -> Float {
        self.scale
    }
}


/// Utility functions for mixed precision operations
pub mod utils {
    use super::*;

    /// Check if value is representable in FP16 (zero, or magnitude within the normal FP16 range)
    pub fn is_fp16_representable(value: Float) -> bool {
        if value == 0.0 {
            return true;
        }
        let abs_val = value.abs();
        abs_val <= Half::MAX.to_f32() as Float && abs_val >= Half::MIN_POSITIVE.to_f32() as Float
    }

    /// Estimate memory savings from mixed precision.
    ///
    /// Returns `(fp32_bytes, mixed_bytes, savings_ratio)`, where
    /// `mixed_precision_ratio` is the fraction of elements stored in FP16.
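    ///
    /// # Examples
    ///
    /// A worked sketch assuming `Float` is `f64` (8 bytes) and `Half` is 2
    /// bytes: converting half of 8 elements to FP16 drops 64 bytes to 40.
    ///
    /// ```ignore
    /// use scirs2_core::ndarray::array;
    ///
    /// let arrays = vec![array![[1.0, 2.0], [3.0, 4.0]], array![[5.0, 6.0], [7.0, 8.0]]];
    /// let (fp32_mem, mixed_mem, savings) = estimate_memory_savings(&arrays, 0.5);
    /// assert_eq!(fp32_mem, 64);
    /// assert_eq!(mixed_mem, 40);
    /// assert!((savings - 0.375).abs() < 1e-9);
    /// ```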
    pub fn estimate_memory_savings(
        fp32_arrays: &[Array2<Float>],
        mixed_precision_ratio: Float,
    ) -> (usize, usize, Float) {
        let fp32_memory = fp32_arrays
            .iter()
            .map(|arr| arr.len() * std::mem::size_of::<Float>())
            .sum::<usize>();

        let fp16_elements = (fp32_arrays.iter().map(|arr| arr.len()).sum::<usize>() as Float
            * mixed_precision_ratio) as usize;
        let fp32_elements = fp32_arrays.iter().map(|arr| arr.len()).sum::<usize>() - fp16_elements;

        let mixed_memory = fp32_elements * std::mem::size_of::<Float>()
            + fp16_elements * std::mem::size_of::<Half>();

        let savings_ratio = 1.0 - (mixed_memory as Float / fp32_memory as Float);

        (fp32_memory, mixed_memory, savings_ratio)
    }

    /// Convert Float to Half with overflow checking
    pub fn safe_float_to_half(value: Float) -> Result<Half> {
        if value.is_finite() && is_fp16_representable(value) {
            Ok(Half::from_f32(value as f32))
        } else {
            Err(SklearsError::InvalidInput(format!(
                "Value {} cannot be represented in FP16",
                value
            )))
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_mixed_precision_config() {
        let config = MixedPrecisionConfig::default();
        assert!(!config.enabled);
        assert_eq!(config.loss_scale, 65536.0);
        assert!(config.dynamic_loss_scaling);
    }

    #[test]
    fn test_mixed_precision_trainer() {
        let config = MixedPrecisionConfig::default();
        let trainer = MixedPrecisionTrainer::new(config);
        assert_eq!(trainer.get_loss_scale(), 65536.0);
        assert!(!trainer.should_skip_step());
    }

    #[test]
    fn test_mixed_precision_array() {
        let full_array = array![[1.0, 2.0], [3.0, 4.0]];
        let mixed_array = MixedPrecisionArray::Full(full_array.clone());

        assert_eq!(mixed_array.shape(), (2, 2));
        assert!(!mixed_array.is_mixed_precision());
        assert_eq!(
            mixed_array.memory_usage_bytes(),
            4 * std::mem::size_of::<Float>()
        );
    }

    #[test]
    fn test_mixed_precision_array_addition() {
        let a = MixedPrecisionArray::Full(array![[1.0, 2.0], [3.0, 4.0]]);
        let b = MixedPrecisionArray::Full(array![[5.0, 6.0], [7.0, 8.0]]);

        let result = a.add(&b).unwrap();
        match result {
            MixedPrecisionArray::Full(arr) => {
                assert_eq!(arr, array![[6.0, 8.0], [10.0, 12.0]]);
            }
            _ => panic!("Expected full precision result"),
        }
    }

    #[test]
    fn test_gradient_accumulator() {
        let mut accumulator = MixedPrecisionGradientAccumulator::new();

        let grad1 = MixedPrecisionArray::Full(array![[1.0, 2.0], [3.0, 4.0]]);
        let grad2 = MixedPrecisionArray::Full(array![[2.0, 3.0], [4.0, 5.0]]);

        accumulator.accumulate("layer1", &grad1).unwrap();
        accumulator.accumulate("layer1", &grad2).unwrap();

        let averaged = accumulator.get_averaged_gradients();
        let layer1_grads = &averaged["layer1"];
        assert_eq!(*layer1_grads, array![[1.5, 2.5], [3.5, 4.5]]);
    }

    #[test]
    fn test_gradient_scaler() {
        let mut scaler = GradientScaler::new(1024.0, 2000, 0.5, 2.0);

        assert_eq!(scaler.scale(1.0), 1024.0);
        assert_eq!(scaler.unscale(1024.0), 1.0);

        // Test overflow handling
        scaler.update(true);
        assert_eq!(scaler.get_scale(), 512.0);

        // Test growth
        for _ in 0..2000 {
            scaler.update(false);
        }
        assert_eq!(scaler.get_scale(), 1024.0);
    }

    #[test]
    fn test_amp_context() {
        let config = MixedPrecisionConfig::default();
        let mut amp = AMPContext::new(config);

        assert!(!amp.is_autocast_enabled());

        amp.autocast(|ctx| {
            assert!(ctx.is_autocast_enabled());
        });

        assert!(!amp.is_autocast_enabled());
    }

    #[test]
    fn test_memory_savings_estimation() {
        let arrays = vec![
            array![[1.0, 2.0], [3.0, 4.0]],
            array![[5.0, 6.0], [7.0, 8.0]],
        ];

        let (fp32_mem, mixed_mem, savings) = utils::estimate_memory_savings(&arrays, 0.5);

        assert!(fp32_mem > mixed_mem);
        assert!(savings > 0.0);
    }

    #[test]
    fn test_fp16_range_check() {
        assert!(utils::is_fp16_representable(0.0));
        assert!(utils::is_fp16_representable(1.0));
        assert!(utils::is_fp16_representable(-1.0));
        assert!(!utils::is_fp16_representable(Float::INFINITY));
        assert!(!utils::is_fp16_representable(Float::NAN));
    }
}
767}