1use scirs2_core::ndarray::{Array1, Array2};
7use sklears_core::error::{Result, SklearsError};
8use sklears_core::types::{Float, Int};
9use std::collections::HashMap;
10
/// Half-precision (IEEE 754 binary16) scalar type used for FP16 storage.
pub type Half = half::f16;
13
/// Configuration for mixed-precision (FP16/FP32) training.
#[derive(Debug, Clone)]
pub struct MixedPrecisionConfig {
    /// Master switch; when `false`, all arrays stay in full precision.
    pub enabled: bool,
    /// Static loss scale, used when `dynamic_loss_scaling` is off.
    pub loss_scale: Float,
    /// When `true`, the loss scale adapts based on overflow detection.
    pub dynamic_loss_scaling: bool,
    /// Starting loss scale for dynamic scaling.
    pub initial_loss_scale: Float,
    /// Multiplier applied to the scale after `growth_interval` clean steps.
    pub growth_factor: Float,
    /// Multiplier applied to the scale when an overflow is detected.
    pub backoff_factor: Float,
    /// Number of consecutive overflow-free steps before the scale grows.
    pub growth_interval: usize,
    /// Operation names that must always run in FP32 (see `to_mixed_precision`).
    pub fp32_operations: Vec<String>,
    /// Automatic mixed precision flag.
    // NOTE(review): not read anywhere in this file — confirm intended use.
    pub use_amp: bool,
    /// Optional gradient clipping threshold.
    // NOTE(review): not read anywhere in this file — confirm intended use.
    pub gradient_clip_threshold: Option<Float>,
}
38
39impl Default for MixedPrecisionConfig {
40 fn default() -> Self {
41 Self {
42 enabled: false,
43 loss_scale: 65536.0, dynamic_loss_scaling: true,
45 initial_loss_scale: 65536.0,
46 growth_factor: 2.0,
47 backoff_factor: 0.5,
48 growth_interval: 2000,
49 fp32_operations: vec![
50 "loss_computation".to_string(),
51 "batch_norm".to_string(),
52 "layer_norm".to_string(),
53 "softmax".to_string(),
54 ],
55 use_amp: true,
56 gradient_clip_threshold: Some(1.0),
57 }
58 }
59}
60
/// Drives mixed-precision training: precision conversion plus dynamic
/// loss-scale management.
pub struct MixedPrecisionTrainer {
    /// Trainer configuration.
    config: MixedPrecisionConfig,
    /// Loss scale currently in effect.
    current_loss_scale: Float,
    /// Total number of overflow events observed so far.
    overflow_count: usize,
    /// Consecutive overflow-free steps since the last scale change.
    successful_steps: usize,
    /// Bookkeeping mirror exposed via `scaler_state()`.
    scaler_state: ScalerState,
}
69
/// Snapshot of the loss-scaler bookkeeping.
#[derive(Debug, Clone)]
pub struct ScalerState {
    /// Current loss scale.
    pub scale: Float,
    /// Growth progress counter.
    // NOTE(review): initialized to 0 and never advanced by
    // `MixedPrecisionTrainer` — confirm whether this is intentional.
    pub growth_tracker: usize,
    /// Whether the most recent update observed an overflow.
    pub overflow_detected: bool,
    /// Whether the upcoming optimizer step should be skipped.
    pub should_skip_step: bool,
}
78
/// A 2-D array stored in one of three precision layouts.
#[derive(Debug, Clone)]
pub enum MixedPrecisionArray {
    /// All elements in FP32.
    Full(Array2<Float>),
    /// All elements in FP16.
    Half(Array2<Half>),
    /// Per-element precision selection.
    Mixed {
        /// FP32 values (read where the mask is `true`).
        fp32_data: Array2<Float>,
        /// FP16 values (read where the mask is `false`).
        fp16_data: Array2<Half>,
        /// `true` selects the FP32 value for that element.
        precision_mask: Array2<bool>,
    },
}
93
/// Accumulates gradients across steps, keeping FP32 and FP16 sums separate.
pub struct MixedPrecisionGradientAccumulator {
    /// Per-parameter FP32 gradient sums (also receives promoted `Mixed` grads).
    fp32_gradients: HashMap<String, Array2<Float>>,
    /// Per-parameter FP16 gradient sums.
    fp16_gradients: HashMap<String, Array2<Half>>,
    /// Number of `accumulate` calls since the last `clear`.
    accumulation_count: usize,
}
100
/// Automatic mixed precision (AMP) context: an autocast flag plus a
/// gradient scaler.
pub struct AMPContext {
    /// AMP configuration.
    config: MixedPrecisionConfig,
    /// Loss/gradient scaler driven by `step`.
    scaler: GradientScaler,
    /// `true` while inside an `autocast` region.
    autocast_enabled: bool,
}
107
/// Stand-alone dynamic loss scaler: grows the scale on clean streaks,
/// backs off on overflow.
pub struct GradientScaler {
    /// Current scale factor.
    scale: Float,
    /// Consecutive overflow-free updates since the last scale change.
    growth_tracker: usize,
    /// Clean updates required before the scale grows.
    growth_interval: usize,
    /// Scale multiplier on overflow (0.5 by default config).
    backoff_factor: Float,
    /// Scale multiplier on growth (2.0 by default config).
    growth_factor: Float,
}
116
117impl MixedPrecisionTrainer {
118 pub fn new(config: MixedPrecisionConfig) -> Self {
120 let current_loss_scale = if config.dynamic_loss_scaling {
121 config.initial_loss_scale
122 } else {
123 config.loss_scale
124 };
125
126 Self {
127 config,
128 current_loss_scale,
129 overflow_count: 0,
130 successful_steps: 0,
131 scaler_state: ScalerState {
132 scale: current_loss_scale,
133 growth_tracker: 0,
134 overflow_detected: false,
135 should_skip_step: false,
136 },
137 }
138 }
139
140 pub fn enable() -> Self {
142 Self::new(MixedPrecisionConfig {
143 enabled: true,
144 ..Default::default()
145 })
146 }
147
148 pub fn to_mixed_precision(
150 &self,
151 array: &Array2<Float>,
152 operation_name: &str,
153 ) -> MixedPrecisionArray {
154 if !self.config.enabled {
155 return MixedPrecisionArray::Full(array.clone());
156 }
157
158 if self
159 .config
160 .fp32_operations
161 .contains(&operation_name.to_string())
162 {
163 MixedPrecisionArray::Full(array.clone())
165 } else {
166 let half_array = array.map(|&x| Half::from_f32(x as f32));
168 MixedPrecisionArray::Half(half_array)
169 }
170 }
171
172 pub fn to_full_precision(&self, array: &MixedPrecisionArray) -> Array2<Float> {
174 match array {
175 MixedPrecisionArray::Full(arr) => arr.clone(),
176 MixedPrecisionArray::Half(arr) => arr.map(|&x| x.to_f32() as Float),
177 MixedPrecisionArray::Mixed {
178 fp32_data,
179 fp16_data,
180 precision_mask,
181 } => {
182 let mut result = Array2::zeros(fp32_data.dim());
183 for ((i, j), &use_fp32) in precision_mask.indexed_iter() {
184 result[[i, j]] = if use_fp32 {
185 fp32_data[[i, j]]
186 } else {
187 fp16_data[[i, j]].to_f32() as Float
188 };
189 }
190 result
191 }
192 }
193 }
194
195 pub fn scale_gradients(&self, gradients: &mut Array2<Float>) {
197 if self.config.enabled {
198 *gradients *= self.current_loss_scale;
199 }
200 }
201
202 pub fn unscale_gradients(&self, gradients: &mut Array2<Float>) -> bool {
204 if !self.config.enabled {
205 return false;
206 }
207
208 let has_overflow = gradients.iter().any(|&x| !x.is_finite());
210
211 if !has_overflow {
212 *gradients /= self.current_loss_scale;
213 }
214
215 has_overflow
216 }
217
218 pub fn update_scale(&mut self, overflow_detected: bool) {
220 if !self.config.dynamic_loss_scaling {
221 return;
222 }
223
224 self.scaler_state.overflow_detected = overflow_detected;
225
226 if overflow_detected {
227 self.current_loss_scale *= self.config.backoff_factor;
229 self.current_loss_scale = self.current_loss_scale.max(1.0);
230 self.overflow_count += 1;
231 self.successful_steps = 0;
232 self.scaler_state.should_skip_step = true;
233 } else {
234 self.successful_steps += 1;
236 self.scaler_state.should_skip_step = false;
237
238 if self.successful_steps >= self.config.growth_interval {
239 self.current_loss_scale *= self.config.growth_factor;
240 self.successful_steps = 0;
241 }
242 }
243
244 self.scaler_state.scale = self.current_loss_scale;
245 }
246
247 pub fn should_skip_step(&self) -> bool {
249 self.scaler_state.should_skip_step
250 }
251
252 pub fn get_loss_scale(&self) -> Float {
254 self.current_loss_scale
255 }
256
257 pub fn train_ensemble_mixed_precision<F>(
259 &mut self,
260 x: &Array2<Float>,
261 y: &Array1<Int>,
262 n_estimators: usize,
263 mut train_fn: F,
264 ) -> Result<Vec<Array1<Float>>>
265 where
266 F: FnMut(&MixedPrecisionArray, &Array1<Int>) -> Result<Array1<Float>>,
267 {
268 let mut models = Vec::new();
269
270 for i in 0..n_estimators {
271 let x_mixed = self.to_mixed_precision(x, "forward_pass");
273
274 let model = train_fn(&x_mixed, y)?;
276
277 models.push(model);
281
282 let overflow_detected = false; self.update_scale(overflow_detected);
285 }
286
287 Ok(models)
288 }
289
290 pub fn scaler_state(&self) -> &ScalerState {
292 &self.scaler_state
293 }
294
295 pub fn reset_scaler(&mut self) {
297 self.current_loss_scale = self.config.initial_loss_scale;
298 self.overflow_count = 0;
299 self.successful_steps = 0;
300 self.scaler_state = ScalerState {
301 scale: self.current_loss_scale,
302 growth_tracker: 0,
303 overflow_detected: false,
304 should_skip_step: false,
305 };
306 }
307}
308
309impl MixedPrecisionArray {
310 pub fn shape(&self) -> (usize, usize) {
312 match self {
313 MixedPrecisionArray::Full(arr) => arr.dim(),
314 MixedPrecisionArray::Half(arr) => arr.dim(),
315 MixedPrecisionArray::Mixed { fp32_data, .. } => fp32_data.dim(),
316 }
317 }
318
319 pub fn is_mixed_precision(&self) -> bool {
321 matches!(
322 self,
323 MixedPrecisionArray::Half(_) | MixedPrecisionArray::Mixed { .. }
324 )
325 }
326
327 pub fn memory_usage_bytes(&self) -> usize {
329 match self {
330 MixedPrecisionArray::Full(arr) => arr.len() * std::mem::size_of::<Float>(),
331 MixedPrecisionArray::Half(arr) => arr.len() * std::mem::size_of::<Half>(),
332 MixedPrecisionArray::Mixed {
333 fp32_data,
334 fp16_data,
335 precision_mask,
336 } => {
337 let fp32_count = precision_mask.iter().filter(|&&x| x).count();
338 let fp16_count = precision_mask.len() - fp32_count;
339 fp32_count * std::mem::size_of::<Float>()
340 + fp16_count * std::mem::size_of::<Half>()
341 + precision_mask.len() * std::mem::size_of::<bool>()
342 }
343 }
344 }
345
346 pub fn add(&self, other: &Self) -> Result<Self> {
348 match (self, other) {
349 (MixedPrecisionArray::Full(a), MixedPrecisionArray::Full(b)) => {
350 Ok(MixedPrecisionArray::Full(a + b))
351 }
352 (MixedPrecisionArray::Half(a), MixedPrecisionArray::Half(b)) => {
353 Ok(MixedPrecisionArray::Half(a + b))
354 }
355 _ => {
356 let a_full = match self {
358 MixedPrecisionArray::Full(arr) => arr.clone(),
359 MixedPrecisionArray::Half(arr) => arr.map(|&x| x.to_f32() as Float),
360 MixedPrecisionArray::Mixed {
361 fp32_data,
362 fp16_data,
363 precision_mask,
364 } => {
365 let mut result = Array2::zeros(fp32_data.dim());
366 for ((i, j), &use_fp32) in precision_mask.indexed_iter() {
367 result[[i, j]] = if use_fp32 {
368 fp32_data[[i, j]]
369 } else {
370 fp16_data[[i, j]].to_f32() as Float
371 };
372 }
373 result
374 }
375 };
376
377 let b_full = match other {
378 MixedPrecisionArray::Full(arr) => arr.clone(),
379 MixedPrecisionArray::Half(arr) => arr.map(|&x| x.to_f32() as Float),
380 MixedPrecisionArray::Mixed {
381 fp32_data,
382 fp16_data,
383 precision_mask,
384 } => {
385 let mut result = Array2::zeros(fp32_data.dim());
386 for ((i, j), &use_fp32) in precision_mask.indexed_iter() {
387 result[[i, j]] = if use_fp32 {
388 fp32_data[[i, j]]
389 } else {
390 fp16_data[[i, j]].to_f32() as Float
391 };
392 }
393 result
394 }
395 };
396
397 Ok(MixedPrecisionArray::Full(&a_full + &b_full))
398 }
399 }
400 }
401}
402
impl Default for MixedPrecisionGradientAccumulator {
    /// Equivalent to [`MixedPrecisionGradientAccumulator::new`].
    fn default() -> Self {
        Self::new()
    }
}
408
409impl MixedPrecisionGradientAccumulator {
410 pub fn new() -> Self {
412 Self {
413 fp32_gradients: HashMap::new(),
414 fp16_gradients: HashMap::new(),
415 accumulation_count: 0,
416 }
417 }
418
419 pub fn accumulate(&mut self, name: &str, gradients: &MixedPrecisionArray) -> Result<()> {
421 match gradients {
422 MixedPrecisionArray::Full(grads) => {
423 let entry = self
424 .fp32_gradients
425 .entry(name.to_string())
426 .or_insert_with(|| Array2::zeros(grads.dim()));
427 *entry = entry.clone() + grads;
428 }
429 MixedPrecisionArray::Half(grads) => {
430 let entry = self
431 .fp16_gradients
432 .entry(name.to_string())
433 .or_insert_with(|| Array2::zeros(grads.dim()));
434 *entry = entry.clone() + grads;
435 }
436 MixedPrecisionArray::Mixed { .. } => {
437 let full_grads = match gradients {
439 MixedPrecisionArray::Mixed {
440 fp32_data,
441 fp16_data,
442 precision_mask,
443 } => {
444 let mut result = Array2::zeros(fp32_data.dim());
445 for ((i, j), &use_fp32) in precision_mask.indexed_iter() {
446 result[[i, j]] = if use_fp32 {
447 fp32_data[[i, j]]
448 } else {
449 fp16_data[[i, j]].to_f32() as Float
450 };
451 }
452 result
453 }
454 _ => unreachable!(),
455 };
456
457 let entry = self
458 .fp32_gradients
459 .entry(name.to_string())
460 .or_insert_with(|| Array2::zeros(full_grads.dim()));
461 *entry = entry.clone() + &full_grads;
462 }
463 }
464
465 self.accumulation_count += 1;
466 Ok(())
467 }
468
469 pub fn get_averaged_gradients(&self) -> HashMap<String, Array2<Float>> {
471 let mut result = HashMap::new();
472
473 for (name, grads) in &self.fp32_gradients {
475 result.insert(
476 name.clone(),
477 grads.clone() / self.accumulation_count as Float,
478 );
479 }
480
481 for (name, grads) in &self.fp16_gradients {
483 let fp32_grads = grads.map(|&x| x.to_f32() as Float);
484 result.insert(name.clone(), fp32_grads / self.accumulation_count as Float);
485 }
486
487 result
488 }
489
490 pub fn clear(&mut self) {
492 self.fp32_gradients.clear();
493 self.fp16_gradients.clear();
494 self.accumulation_count = 0;
495 }
496}
497
498impl AMPContext {
499 pub fn new(config: MixedPrecisionConfig) -> Self {
501 let scaler = GradientScaler::new(
502 config.initial_loss_scale,
503 config.growth_interval,
504 config.backoff_factor,
505 config.growth_factor,
506 );
507
508 Self {
509 config,
510 scaler,
511 autocast_enabled: false,
512 }
513 }
514
515 pub fn autocast<F, R>(&mut self, f: F) -> R
517 where
518 F: FnOnce(&mut Self) -> R,
519 {
520 let old_state = self.autocast_enabled;
521 self.autocast_enabled = true;
522 let result = f(self);
523 self.autocast_enabled = old_state;
524 result
525 }
526
527 pub fn is_autocast_enabled(&self) -> bool {
529 self.autocast_enabled
530 }
531
532 pub fn scale_loss(&mut self, loss: Float) -> Float {
534 self.scaler.scale(loss)
535 }
536
537 pub fn step<F>(&mut self, optimizer_step: F) -> bool
539 where
540 F: FnOnce(),
541 {
542 if !self.scaler.should_skip_step() {
543 optimizer_step();
544 self.scaler.update(false); true
546 } else {
547 self.scaler.update(true); false
549 }
550 }
551}
552
553impl GradientScaler {
554 pub fn new(
556 initial_scale: Float,
557 growth_interval: usize,
558 backoff_factor: Float,
559 growth_factor: Float,
560 ) -> Self {
561 Self {
562 scale: initial_scale,
563 growth_tracker: 0,
564 growth_interval,
565 backoff_factor,
566 growth_factor,
567 }
568 }
569
570 pub fn scale(&self, value: Float) -> Float {
572 value * self.scale
573 }
574
575 pub fn unscale(&self, value: Float) -> Float {
577 value / self.scale
578 }
579
580 pub fn update(&mut self, overflow_detected: bool) {
582 if overflow_detected {
583 self.scale *= self.backoff_factor;
584 self.scale = self.scale.max(1.0);
585 self.growth_tracker = 0;
586 } else {
587 self.growth_tracker += 1;
588 if self.growth_tracker >= self.growth_interval {
589 self.scale *= self.growth_factor;
590 self.growth_tracker = 0;
591 }
592 }
593 }
594
595 pub fn should_skip_step(&self) -> bool {
597 self.scale < 1.0
598 }
599
600 pub fn get_scale(&self) -> Float {
602 self.scale
603 }
604}
605
606pub mod utils {
608 use super::*;
609
610 pub fn is_fp16_representable(value: Float) -> bool {
612 let abs_val = value.abs();
613 abs_val <= Half::MAX.to_f32() as Float && abs_val >= Half::MIN_POSITIVE.to_f32() as Float
614 }
615
616 pub fn estimate_memory_savings(
618 fp32_arrays: &[Array2<Float>],
619 mixed_precision_ratio: Float,
620 ) -> (usize, usize, Float) {
621 let fp32_memory = fp32_arrays
622 .iter()
623 .map(|arr| arr.len() * std::mem::size_of::<Float>())
624 .sum::<usize>();
625
626 let fp16_elements = (fp32_arrays.iter().map(|arr| arr.len()).sum::<usize>() as Float
627 * mixed_precision_ratio) as usize;
628 let fp32_elements = fp32_arrays.iter().map(|arr| arr.len()).sum::<usize>() - fp16_elements;
629
630 let mixed_memory = fp32_elements * std::mem::size_of::<Float>()
631 + fp16_elements * std::mem::size_of::<Half>();
632
633 let savings_ratio = 1.0 - (mixed_memory as Float / fp32_memory as Float);
634
635 (fp32_memory, mixed_memory, savings_ratio)
636 }
637
638 pub fn safe_float_to_half(value: Float) -> Result<Half> {
640 if value.is_finite() && is_fp16_representable(value) {
641 Ok(Half::from_f32(value as f32))
642 } else {
643 Err(SklearsError::InvalidInput(format!(
644 "Value {} cannot be represented in FP16",
645 value
646 )))
647 }
648 }
649}
650
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    // Defaults: mixed precision off, static scale 2^16, dynamic scaling on.
    #[test]
    fn test_mixed_precision_config() {
        let config = MixedPrecisionConfig::default();
        assert!(!config.enabled);
        assert_eq!(config.loss_scale, 65536.0);
        assert!(config.dynamic_loss_scaling);
    }

    // A fresh trainer starts at the initial loss scale with no skip pending.
    #[test]
    fn test_mixed_precision_trainer() {
        let config = MixedPrecisionConfig::default();
        let trainer = MixedPrecisionTrainer::new(config);
        assert_eq!(trainer.get_loss_scale(), 65536.0);
        assert!(!trainer.should_skip_step());
    }

    // Full-precision variant: shape, precision flag, and byte accounting.
    #[test]
    fn test_mixed_precision_array() {
        let full_array = array![[1.0, 2.0], [3.0, 4.0]];
        let mixed_array = MixedPrecisionArray::Full(full_array.clone());

        assert_eq!(mixed_array.shape(), (2, 2));
        assert!(!mixed_array.is_mixed_precision());
        assert_eq!(
            mixed_array.memory_usage_bytes(),
            4 * std::mem::size_of::<Float>()
        );
    }

    // Full + Full stays in full precision with element-wise sums.
    #[test]
    fn test_mixed_precision_array_addition() {
        let a = MixedPrecisionArray::Full(array![[1.0, 2.0], [3.0, 4.0]]);
        let b = MixedPrecisionArray::Full(array![[5.0, 6.0], [7.0, 8.0]]);

        let result = a.add(&b).unwrap();
        match result {
            MixedPrecisionArray::Full(arr) => {
                assert_eq!(arr, array![[6.0, 8.0], [10.0, 12.0]]);
            }
            _ => panic!("Expected full precision result"),
        }
    }

    // Two accumulations average to the element-wise mean.
    #[test]
    fn test_gradient_accumulator() {
        let mut accumulator = MixedPrecisionGradientAccumulator::new();

        let grad1 = MixedPrecisionArray::Full(array![[1.0, 2.0], [3.0, 4.0]]);
        let grad2 = MixedPrecisionArray::Full(array![[2.0, 3.0], [4.0, 5.0]]);

        accumulator.accumulate("layer1", &grad1).unwrap();
        accumulator.accumulate("layer1", &grad2).unwrap();

        let averaged = accumulator.get_averaged_gradients();
        let layer1_grads = &averaged["layer1"];
        assert_eq!(*layer1_grads, array![[1.5, 2.5], [3.5, 4.5]]);
    }

    // Overflow halves the scale; a full growth interval of clean updates
    // doubles it back.
    #[test]
    fn test_gradient_scaler() {
        let mut scaler = GradientScaler::new(1024.0, 2000, 0.5, 2.0);

        assert_eq!(scaler.scale(1.0), 1024.0);
        assert_eq!(scaler.unscale(1024.0), 1.0);

        scaler.update(true);
        assert_eq!(scaler.get_scale(), 512.0);

        for _ in 0..2000 {
            scaler.update(false);
        }
        assert_eq!(scaler.get_scale(), 1024.0);
    }

    // Autocast flag is set only inside the closure and restored after.
    #[test]
    fn test_amp_context() {
        let config = MixedPrecisionConfig::default();
        let mut amp = AMPContext::new(config);

        assert!(!amp.is_autocast_enabled());

        amp.autocast(|ctx| {
            assert!(ctx.is_autocast_enabled());
        });

        assert!(!amp.is_autocast_enabled());
    }

    // Halving half the elements must shrink memory and yield positive savings.
    #[test]
    fn test_memory_savings_estimation() {
        let arrays = vec![
            array![[1.0, 2.0], [3.0, 4.0]],
            array![[5.0, 6.0], [7.0, 8.0]],
        ];

        let (fp32_mem, mixed_mem, savings) = utils::estimate_memory_savings(&arrays, 0.5);

        assert!(fp32_mem > mixed_mem);
        assert!(savings > 0.0);
    }

    // Normal values are representable; non-finite values are not.
    #[test]
    fn test_fp16_range_check() {
        assert!(utils::is_fp16_representable(1.0));
        assert!(utils::is_fp16_representable(-1.0));
        assert!(!utils::is_fp16_representable(Float::INFINITY));
        assert!(!utils::is_fp16_representable(Float::NAN));
    }
}
767}