1use scirs2_core::ndarray::Array2;
29use serde::{Deserialize, Serialize};
30use std::collections::HashMap;
31
32use crate::error::TrainResult;
33
/// Numeric precision used for working (forward/backward) computation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum PrecisionMode {
    /// IEEE 754 single precision — 4 bytes, full-precision baseline.
    FP32,
    /// IEEE 754 half precision — 2 bytes, narrow dynamic range (max ~6.55e4,
    /// see `numerical_range`), so loss scaling is typically required.
    FP16,
    /// bfloat16 — 2 bytes, FP32-like dynamic range (max ~3.39e38) but with a
    /// much shorter mantissa than FP16.
    BF16,
}
44
45impl PrecisionMode {
46 pub fn bytes_per_element(&self) -> usize {
48 match self {
49 PrecisionMode::FP32 => 4,
50 PrecisionMode::FP16 => 2,
51 PrecisionMode::BF16 => 2,
52 }
53 }
54
55 pub fn memory_reduction(&self) -> f32 {
57 match self {
58 PrecisionMode::FP32 => 1.0,
59 PrecisionMode::FP16 => 2.0,
60 PrecisionMode::BF16 => 2.0,
61 }
62 }
63
64 pub fn numerical_range(&self) -> (f32, f32) {
66 match self {
67 PrecisionMode::FP32 => (-3.4e38, 3.4e38),
68 PrecisionMode::FP16 => (-6.55e4, 6.55e4),
69 PrecisionMode::BF16 => (-3.39e38, 3.39e38), }
71 }
72}
73
/// Loss-scaling strategy used to keep small gradient values representable
/// when training in reduced precision.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum LossScaler {
    /// No scaling; the effective scale is always 1.0.
    None,
    /// A fixed, caller-chosen scale factor.
    Static { scale: f32 },
    /// Scale that grows after a streak of overflow-free steps and backs off
    /// whenever an overflow is detected.
    Dynamic {
        /// Current scale applied to the loss.
        scale: f32,
        /// Multiplier applied when the scale grows.
        growth_factor: f32,
        /// Multiplier applied on overflow (0.5 when built via `dynamic`).
        backoff_factor: f32,
        /// Number of consecutive overflow-free steps required before growth.
        growth_interval: usize,
        /// Overflow-free steps accumulated since the last overflow or growth.
        steps_since_overflow: usize,
    },
}
95
96impl LossScaler {
97 pub fn static_scale(scale: f32) -> Self {
99 Self::Static { scale }
100 }
101
102 pub fn dynamic(initial_scale: f32, growth_factor: f32, growth_interval: usize) -> Self {
109 Self::Dynamic {
110 scale: initial_scale,
111 growth_factor,
112 backoff_factor: 0.5,
113 growth_interval,
114 steps_since_overflow: 0,
115 }
116 }
117
118 pub fn get_scale(&self) -> f32 {
120 match self {
121 Self::None => 1.0,
122 Self::Static { scale } => *scale,
123 Self::Dynamic { scale, .. } => *scale,
124 }
125 }
126
127 pub fn scale_loss(&self, loss: f32) -> f32 {
129 loss * self.get_scale()
130 }
131
132 pub fn unscale_gradients(&self, gradients: &mut Array2<f32>) {
134 let scale = self.get_scale();
135 if scale != 1.0 {
136 *gradients /= scale;
137 }
138 }
139
140 pub fn update(&mut self, overflow_detected: bool) -> bool {
148 if let Self::Dynamic {
149 scale,
150 growth_factor,
151 backoff_factor,
152 growth_interval,
153 steps_since_overflow,
154 } = self
155 {
156 if overflow_detected {
157 *scale *= *backoff_factor;
159 *steps_since_overflow = 0;
160 false } else {
162 *steps_since_overflow += 1;
164
165 if *steps_since_overflow >= *growth_interval {
167 *scale *= *growth_factor;
168 *steps_since_overflow = 0;
169 }
170 true }
172 } else {
173 !overflow_detected
175 }
176 }
177}
178
/// Coordinates mixed-precision training: keeps FP32 master weights, casts
/// them to a reduced working precision, and manages loss scaling plus
/// overflow statistics.
pub struct MixedPrecisionTrainer {
    // Working precision used by `cast_to_working_precision`.
    mode: PrecisionMode,
    // Loss-scaling strategy applied to losses and gradients.
    scaler: LossScaler,
    // Full-precision master copies of the model weights, keyed by name.
    master_weights: HashMap<String, Array2<f32>>,
    // Step/overflow counters accumulated across training.
    stats: MixedPrecisionStats,
}
190
191impl MixedPrecisionTrainer {
192 pub fn new(mode: PrecisionMode, scaler: LossScaler) -> Self {
194 Self {
195 mode,
196 scaler,
197 master_weights: HashMap::new(),
198 stats: MixedPrecisionStats::default(),
199 }
200 }
201
202 pub fn register_weights(&mut self, name: String, weights: Array2<f32>) {
204 self.master_weights.insert(name, weights);
205 }
206
207 pub fn cast_to_working_precision(&self, weights: &Array2<f32>) -> Array2<f32> {
209 match self.mode {
210 PrecisionMode::FP32 => weights.clone(),
211 PrecisionMode::FP16 => self.simulate_fp16(weights),
212 PrecisionMode::BF16 => self.simulate_bf16(weights),
213 }
214 }
215
216 fn simulate_fp16(&self, weights: &Array2<f32>) -> Array2<f32> {
218 weights.mapv(|x| {
219 let clamped = x.clamp(-65504.0, 65504.0);
221 let scale = 2.0_f32.powi(10);
223 (clamped * scale).round() / scale
224 })
225 }
226
227 fn simulate_bf16(&self, weights: &Array2<f32>) -> Array2<f32> {
229 weights.mapv(|x| {
230 let scale = 2.0_f32.powi(7);
232 (x * scale).round() / scale
233 })
234 }
235
236 pub fn scale_loss(&mut self, loss: f32) -> f32 {
238 self.stats.total_steps += 1;
239 self.scaler.scale_loss(loss)
240 }
241
242 pub fn unscale_and_check_gradients(
247 &mut self,
248 gradients: &mut HashMap<String, Array2<f32>>,
249 ) -> TrainResult<(bool, bool)> {
250 let mut overflow = false;
252 for (_name, grad) in gradients.iter() {
253 if grad.iter().any(|&x| !x.is_finite()) {
254 overflow = true;
255 break;
256 }
257 }
258
259 if overflow {
260 self.stats.overflow_steps += 1;
261 }
262
263 for (_name, grad) in gradients.iter_mut() {
265 self.scaler.unscale_gradients(grad);
266 }
267
268 let should_step = self.scaler.update(overflow);
270
271 Ok((should_step, overflow))
272 }
273
274 pub fn update_master_weights(&mut self, updates: &HashMap<String, Array2<f32>>) {
276 for (name, update) in updates {
277 if let Some(master) = self.master_weights.get_mut(name) {
278 *master = master.clone() + update;
279 }
280 }
281 }
282
283 pub fn mode(&self) -> PrecisionMode {
285 self.mode
286 }
287
288 pub fn current_scale(&self) -> f32 {
290 self.scaler.get_scale()
291 }
292
293 pub fn stats(&self) -> &MixedPrecisionStats {
295 &self.stats
296 }
297
298 pub fn reset_stats(&mut self) {
300 self.stats = MixedPrecisionStats::default();
301 }
302}
303
/// Step counters accumulated by `MixedPrecisionTrainer` across training.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct MixedPrecisionStats {
    /// Total number of loss-scaling steps taken.
    pub total_steps: usize,
    /// Steps on which a gradient overflow (inf/NaN) was detected.
    pub overflow_steps: usize,
    /// Steps that completed without overflow.
    pub successful_steps: usize,
}
314
315impl MixedPrecisionStats {
316 pub fn overflow_rate(&self) -> f32 {
318 if self.total_steps == 0 {
319 0.0
320 } else {
321 self.overflow_steps as f32 / self.total_steps as f32
322 }
323 }
324
325 pub fn success_rate(&self) -> f32 {
327 if self.total_steps == 0 {
328 0.0
329 } else {
330 self.successful_steps as f32 / self.total_steps as f32
331 }
332 }
333}
334
/// Convenience wrapper around `LossScaler` with an on/off switch; when
/// disabled, `scale`/`unscale` are identity operations.
pub struct GradientScaler {
    // Underlying scaling strategy (set to `LossScaler::None` when disabled
    // via `new(false)`).
    scaler: LossScaler,
    // Master switch checked by `scale`, `unscale`, and `step`.
    enabled: bool,
}
340
341impl GradientScaler {
342 pub fn new(enabled: bool) -> Self {
344 let scaler = if enabled {
345 LossScaler::dynamic(2.0_f32.powi(15), 2.0, 2000)
346 } else {
347 LossScaler::None
348 };
349
350 Self { scaler, enabled }
351 }
352
353 pub fn with_scaler(scaler: LossScaler, enabled: bool) -> Self {
355 Self { scaler, enabled }
356 }
357
358 pub fn scale(&self, loss: f32) -> f32 {
360 if self.enabled {
361 self.scaler.scale_loss(loss)
362 } else {
363 loss
364 }
365 }
366
367 pub fn unscale(&self, gradients: &mut Array2<f32>) {
369 if self.enabled {
370 self.scaler.unscale_gradients(gradients);
371 }
372 }
373
374 pub fn step(&mut self, overflow_detected: bool) -> bool {
376 if self.enabled {
377 self.scaler.update(overflow_detected)
378 } else {
379 !overflow_detected
380 }
381 }
382
383 pub fn get_scale(&self) -> f32 {
385 self.scaler.get_scale()
386 }
387}
388
/// Lightweight autocast scope: casts tensors to a target precision when
/// enabled, and passes them through unchanged otherwise.
pub struct AutocastContext {
    // When false, `cast` returns an unmodified clone of the input.
    enabled: bool,
    // Target precision used by `cast` when enabled.
    mode: PrecisionMode,
}
394
395impl AutocastContext {
396 pub fn new(enabled: bool, mode: PrecisionMode) -> Self {
398 Self { enabled, mode }
399 }
400
401 pub fn is_enabled(&self) -> bool {
403 self.enabled
404 }
405
406 pub fn mode(&self) -> PrecisionMode {
408 self.mode
409 }
410
411 pub fn cast(&self, tensor: &Array2<f32>) -> Array2<f32> {
413 if !self.enabled || self.mode == PrecisionMode::FP32 {
414 return tensor.clone();
415 }
416
417 match self.mode {
418 PrecisionMode::FP16 => self.simulate_fp16(tensor),
419 PrecisionMode::BF16 => self.simulate_bf16(tensor),
420 PrecisionMode::FP32 => tensor.clone(),
421 }
422 }
423
424 fn simulate_fp16(&self, tensor: &Array2<f32>) -> Array2<f32> {
425 tensor.mapv(|x| {
426 let clamped = x.clamp(-65504.0, 65504.0);
427 let scale = 2.0_f32.powi(10);
428 (clamped * scale).round() / scale
429 })
430 }
431
432 fn simulate_bf16(&self, tensor: &Array2<f32>) -> Array2<f32> {
433 tensor.mapv(|x| {
434 let scale = 2.0_f32.powi(7);
435 (x * scale).round() / scale
436 })
437 }
438}
439
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    // Byte sizes and memory-reduction factors for each precision mode.
    #[test]
    fn test_precision_mode_properties() {
        assert_eq!(PrecisionMode::FP32.bytes_per_element(), 4);
        assert_eq!(PrecisionMode::FP16.bytes_per_element(), 2);
        assert_eq!(PrecisionMode::BF16.bytes_per_element(), 2);

        assert_eq!(PrecisionMode::FP16.memory_reduction(), 2.0);
        assert_eq!(PrecisionMode::BF16.memory_reduction(), 2.0);
    }

    // A static scaler multiplies the loss by its fixed scale.
    #[test]
    fn test_static_loss_scaler() {
        let scaler = LossScaler::static_scale(1024.0);
        assert_eq!(scaler.get_scale(), 1024.0);

        let loss = 0.5;
        let scaled = scaler.scale_loss(loss);
        assert_eq!(scaled, 512.0);
    }

    // Dynamic scaler: grows after `growth_interval` clean steps, halves
    // (backoff 0.5) on overflow.
    #[test]
    fn test_dynamic_loss_scaler() {
        let mut scaler = LossScaler::dynamic(1000.0, 2.0, 3);
        assert_eq!(scaler.get_scale(), 1000.0);

        assert!(scaler.update(false));
        assert!(scaler.update(false));
        assert!(scaler.update(false));
        // Third clean step hits the growth interval: scale doubles.
        assert_eq!(scaler.get_scale(), 2000.0);
        // Overflow: step rejected and scale backed off by 0.5.
        assert!(!scaler.update(true));
        assert_eq!(scaler.get_scale(), 1000.0);
    }

    // Unscaling divides every gradient element by the current scale.
    #[test]
    fn test_gradient_unscaling() {
        let mut gradients =
            Array2::from_shape_vec((2, 2), vec![100.0, 200.0, 300.0, 400.0]).unwrap();
        let scaler = LossScaler::static_scale(10.0);

        scaler.unscale_gradients(&mut gradients);

        assert_eq!(gradients[[0, 0]], 10.0);
        assert_eq!(gradients[[0, 1]], 20.0);
        assert_eq!(gradients[[1, 0]], 30.0);
        assert_eq!(gradients[[1, 1]], 40.0);
    }

    // scale_loss applies the scaler and bumps total_steps.
    #[test]
    fn test_mixed_precision_trainer() {
        let mut trainer =
            MixedPrecisionTrainer::new(PrecisionMode::FP16, LossScaler::static_scale(100.0));

        let loss = 0.5;
        let scaled_loss = trainer.scale_loss(loss);
        assert_eq!(scaled_loss, 50.0);
        assert_eq!(trainer.stats().total_steps, 1);
    }

    // FP16 casting quantizes values and clamps to the FP16 range.
    #[test]
    fn test_fp16_simulation() {
        let trainer = MixedPrecisionTrainer::new(PrecisionMode::FP16, LossScaler::None);

        let weights =
            Array2::from_shape_vec((2, 2), vec![1.234_567, 100000.0, -100000.0, 0.0001]).unwrap();
        let fp16_weights = trainer.cast_to_working_precision(&weights);

        // Quantization changes the value but keeps it in the same unit range.
        assert_ne!(fp16_weights[[0, 0]], 1.234_567);
        assert!(fp16_weights[[0, 0]] > 1.0 && fp16_weights[[0, 0]] < 2.0);

        // Out-of-range inputs are clamped to ±65504 (FP16 max).
        assert!(fp16_weights[[0, 1]] <= 65504.0);
        assert!(fp16_weights[[1, 0]] >= -65504.0);
    }

    // BF16 casting quantizes values (coarser mantissa, no clamping needed).
    #[test]
    fn test_bf16_simulation() {
        let trainer = MixedPrecisionTrainer::new(PrecisionMode::BF16, LossScaler::None);

        let weights =
            Array2::from_shape_vec((2, 2), vec![1.234_567, 100.5, -50.25, 0.125]).unwrap();
        let bf16_weights = trainer.cast_to_working_precision(&weights);

        assert_ne!(bf16_weights[[0, 0]], 1.234_567);
    }

    // Non-finite gradients are reported as overflow and veto the step.
    #[test]
    fn test_overflow_detection() {
        let mut trainer =
            MixedPrecisionTrainer::new(PrecisionMode::FP16, LossScaler::dynamic(1000.0, 2.0, 100));

        let mut gradients = HashMap::new();
        gradients.insert(
            "layer1".to_string(),
            Array2::from_shape_vec((2, 2), vec![f32::INFINITY, 1.0, 2.0, 3.0]).unwrap(),
        );

        let (should_step, overflow) = trainer.unscale_and_check_gradients(&mut gradients).unwrap();

        assert!(!should_step);
        assert!(overflow);
        assert_eq!(trainer.stats().overflow_steps, 1);
    }

    // Enabled GradientScaler scales up losses and scales gradients back down.
    #[test]
    fn test_gradient_scaler() {
        let scaler = GradientScaler::new(true);

        let loss = 1.0;
        let scaled = scaler.scale(loss);
        assert!(scaled > loss);
        let mut grads = Array2::from_shape_vec((2, 2), vec![1000.0; 4]).unwrap();
        scaler.unscale(&mut grads);
        assert!(grads[[0, 0]] < 1000.0);
    }

    // Enabled autocast to FP16 quantizes tensor values.
    #[test]
    fn test_autocast_context() {
        let ctx = AutocastContext::new(true, PrecisionMode::FP16);
        assert!(ctx.is_enabled());
        assert_eq!(ctx.mode(), PrecisionMode::FP16);

        let tensor = Array2::from_shape_vec((2, 2), vec![1.234_567; 4]).unwrap();
        let casted = ctx.cast(&tensor);

        assert_ne!(casted[[0, 0]], 1.234_567);
    }

    // Disabled autocast is an identity transformation.
    #[test]
    fn test_autocast_disabled() {
        let ctx = AutocastContext::new(false, PrecisionMode::FP16);
        assert!(!ctx.is_enabled());

        let tensor = Array2::from_shape_vec((2, 2), vec![1.234_567; 4]).unwrap();
        let casted = ctx.cast(&tensor);

        assert_eq!(casted, tensor);
    }

    // update_master_weights adds updates element-wise to registered weights.
    #[test]
    fn test_master_weights_update() {
        let mut trainer = MixedPrecisionTrainer::new(PrecisionMode::FP16, LossScaler::None);

        let weights = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        trainer.register_weights("layer1".to_string(), weights.clone());

        let mut updates = HashMap::new();
        updates.insert(
            "layer1".to_string(),
            Array2::from_shape_vec((2, 2), vec![0.1, 0.1, 0.1, 0.1]).unwrap(),
        );

        trainer.update_master_weights(&updates);

        let master = &trainer.master_weights["layer1"];
        assert_relative_eq!(master[[0, 0]], 1.1, epsilon = 1e-6);
    }

    // Rates are simple ratios of the stored counters.
    #[test]
    fn test_mixed_precision_stats() {
        let stats = MixedPrecisionStats {
            total_steps: 100,
            overflow_steps: 5,
            successful_steps: 95,
        };

        assert_eq!(stats.overflow_rate(), 0.05);
        assert_eq!(stats.success_rate(), 0.95);
    }

    // Growth happens exactly when the clean-step streak reaches the interval.
    #[test]
    fn test_loss_scaler_growth() {
        let mut scaler = LossScaler::dynamic(1000.0, 2.0, 2);

        assert!(scaler.update(false));
        assert_eq!(scaler.get_scale(), 1000.0);

        assert!(scaler.update(false));
        assert_eq!(scaler.get_scale(), 2000.0);
    }
}