use anyhow::Result;
use scirs2_core::Complex;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use trustformers_core::tensor::Tensor;

/// Configuration for mixed-precision training with dynamic loss scaling.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MixedPrecisionConfig {
    /// Whether mixed-precision training is enabled.
    pub enabled: bool,
    /// Initial loss scale.
    pub init_scale: f32,
    /// Factor by which the scale grows after a clean window of steps.
    pub scale_factor: f32,
    /// Factor by which the scale shrinks after an overflow.
    pub backoff_factor: f32,
    /// Number of consecutive overflow-free steps before the scale grows.
    pub scale_window: usize,
    /// Lower bound on the loss scale.
    pub min_scale: f32,
    /// Upper bound on the loss scale.
    pub max_scale: f32,
    /// Whether to stop unscaling as soon as an inf/NaN gradient is found.
    pub skip_inf_nan: bool,
}

impl Default for MixedPrecisionConfig {
    fn default() -> Self {
        Self {
            enabled: false,
            init_scale: 2f32.powf(16.0),
            scale_factor: 2.0,
            backoff_factor: 0.5,
            scale_window: 2000,
            min_scale: 1.0,
            max_scale: 2f32.powf(24.0),
            skip_inf_nan: true,
        }
    }
}
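
// With the defaults above, the loss scale starts at 2^16 = 65536, doubles
// after every 2000 consecutive overflow-free steps (capped at 2^24), and is
// halved on each overflow (floored at 1.0). A single overflow therefore
// drops the initial scale to 32768, and 2000 clean steps restore it.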

/// Dynamic loss scaler that raises the loss scale during stable training and
/// backs it off when gradient overflow is detected.
#[derive(Debug, Clone)]
pub struct LossScaler {
    config: MixedPrecisionConfig,
    current_scale: f32,
    steps_since_overflow: usize,
    overflow_detected: bool,
}

impl LossScaler {
    pub fn new(config: MixedPrecisionConfig) -> Self {
        Self {
            current_scale: config.init_scale,
            steps_since_overflow: 0,
            overflow_detected: false,
            config,
        }
    }

    /// Returns the current loss scale, or 1.0 when scaling is disabled.
    pub fn get_scale(&self) -> f32 {
        if self.config.enabled {
            self.current_scale
        } else {
            1.0
        }
    }

    /// Multiplies the loss by the current scale ahead of the backward pass.
    pub fn scale_loss(&self, loss: &Tensor) -> Result<Tensor> {
        if !self.config.enabled {
            return Ok(loss.clone());
        }

        loss.scalar_mul(self.current_scale).map_err(|e| anyhow::anyhow!(e))
    }

    /// Divides every gradient by the current scale. Returns `true` if any
    /// gradient contains non-finite values, i.e. the step overflowed.
    pub fn unscale_gradients(&self, gradients: &mut HashMap<String, Tensor>) -> Result<bool> {
        if !self.config.enabled {
            return Ok(false);
        }

        let scale = self.current_scale;
        let mut overflow_detected = false;

        for (_, gradient) in gradients.iter_mut() {
            // Check for inf/NaN before unscaling.
            if self.has_inf_nan(gradient)? {
                overflow_detected = true;
                if self.config.skip_inf_nan {
                    // The step will be discarded anyway, so stop early.
                    break;
                }
            }

            *gradient = gradient.scalar_mul(1.0 / scale).map_err(|e| anyhow::anyhow!(e))?;
        }

        Ok(overflow_detected)
    }

    /// Updates the scale after a step: back off on overflow, grow after a
    /// full window of overflow-free steps.
    pub fn update_scale(&mut self, overflow_detected: bool) -> Result<()> {
        if !self.config.enabled {
            return Ok(());
        }

        if overflow_detected {
            self.current_scale =
                (self.current_scale * self.config.backoff_factor).max(self.config.min_scale);
            self.steps_since_overflow = 0;
            self.overflow_detected = true;
        } else {
            self.steps_since_overflow += 1;
            self.overflow_detected = false;

            if self.steps_since_overflow >= self.config.scale_window {
                self.current_scale =
                    (self.current_scale * self.config.scale_factor).min(self.config.max_scale);
                self.steps_since_overflow = 0;
            }
        }

        Ok(())
    }

    /// Whether the most recent `update_scale` call observed an overflow.
    pub fn overflow_detected(&self) -> bool {
        self.overflow_detected
    }

    fn has_inf_nan(&self, tensor: &Tensor) -> Result<bool> {
        match tensor {
            Tensor::F32(arr) => {
                for &value in arr.iter() {
                    if !value.is_finite() {
                        return Ok(true);
                    }
                }
                Ok(false)
            },
            Tensor::F64(arr) => {
                for &value in arr.iter() {
                    if !value.is_finite() {
                        return Ok(true);
                    }
                }
                Ok(false)
            },
            Tensor::F16(arr) => {
                for &value in arr.iter() {
                    if !value.to_f32().is_finite() {
                        return Ok(true);
                    }
                }
                Ok(false)
            },
            Tensor::BF16(arr) => {
                for &value in arr.iter() {
                    if !value.to_f32().is_finite() {
                        return Ok(true);
                    }
                }
                Ok(false)
            },
            Tensor::I64(_) => Ok(false),
            Tensor::C32(arr) => {
                for &value in arr.iter() {
                    if !value.re.is_finite() || !value.im.is_finite() {
                        return Ok(true);
                    }
                }
                Ok(false)
            },
            Tensor::C64(arr) => {
                for &value in arr.iter() {
                    if !value.re.is_finite() || !value.im.is_finite() {
                        return Ok(true);
                    }
                }
                Ok(false)
            },
            Tensor::CF16(arr) => {
                for &value in arr.iter() {
                    if !value.re.to_f32().is_finite() || !value.im.to_f32().is_finite() {
                        return Ok(true);
                    }
                }
                Ok(false)
            },
            Tensor::CBF16(arr) => {
                for &value in arr.iter() {
                    if !value.re.to_f32().is_finite() || !value.im.to_f32().is_finite() {
                        return Ok(true);
                    }
                }
                Ok(false)
            },
            _ => Ok(false),
        }
    }
}
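
// Example (sketch): a typical dynamic-loss-scaling loop. The `loss` and
// `grads` values below are hypothetical stand-ins for a model's loss tensor
// and backward-pass gradients, which this module does not produce itself.
//
//     let mut scaler = LossScaler::new(MixedPrecisionConfig {
//         enabled: true,
//         ..MixedPrecisionConfig::default()
//     });
//     let scaled_loss = scaler.scale_loss(&loss)?; // backward runs on this
//     let overflow = scaler.unscale_gradients(&mut grads)?;
//     scaler.update_scale(overflow)?;
//     if !overflow {
//         // safe to apply the optimizer step
//     }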

/// Automatic mixed-precision (AMP) manager that couples a `LossScaler` with
/// precision-conversion helpers.
#[derive(Debug)]
pub struct AMPManager {
    pub loss_scaler: LossScaler,
    pub config: MixedPrecisionConfig,
}

impl AMPManager {
    pub fn new(config: MixedPrecisionConfig) -> Self {
        let loss_scaler = LossScaler::new(config.clone());
        Self {
            loss_scaler,
            config,
        }
    }

    /// Simulates half-precision storage. F32 data is clamped to the fp16
    /// representable range and quantized to a coarser grid to emulate the
    /// reduced mantissa; other dtypes are returned unchanged.
    pub fn to_half_precision(&self, tensor: &Tensor) -> Result<Tensor> {
        if !self.config.enabled {
            return Ok(tensor.clone());
        }

        match tensor {
            Tensor::F32(arr) => {
                let quantized = arr.mapv(|x| {
                    // Clamp to fp16's maximum finite value, then quantize to
                    // steps of 1/1024 to mimic fp16 precision loss.
                    let clamped = x.clamp(-65504.0, 65504.0);
                    (clamped * 1024.0).round() / 1024.0
                });
                Ok(Tensor::F32(quantized))
            },
            Tensor::F64(_) => Ok(tensor.clone()),
            Tensor::F16(_) => Ok(tensor.clone()),
            Tensor::BF16(_) => Ok(tensor.clone()),
            Tensor::I64(_) => Ok(tensor.clone()),
            Tensor::C32(arr) => {
                let quantized = arr.mapv(|x| {
                    let re_clamped = x.re.clamp(-65504.0, 65504.0);
                    let im_clamped = x.im.clamp(-65504.0, 65504.0);
                    let re_scaled = (re_clamped * 1024.0).round() / 1024.0;
                    let im_scaled = (im_clamped * 1024.0).round() / 1024.0;
                    Complex::new(re_scaled, im_scaled)
                });
                Ok(Tensor::C32(quantized))
            },
            Tensor::C64(_) => Ok(tensor.clone()),
            Tensor::CF16(_) => Ok(tensor.clone()),
            Tensor::CBF16(_) => Ok(tensor.clone()),
            _ => Ok(tensor.clone()),
        }
    }

    /// Returns the tensor unchanged; full-precision master copies are
    /// assumed to be maintained elsewhere.
    pub fn to_full_precision(&self, tensor: &Tensor) -> Result<Tensor> {
        Ok(tensor.clone())
    }

    pub fn forward_with_amp<F>(&self, forward_fn: F) -> Result<Tensor>
    where
        F: FnOnce() -> Result<Tensor>,
    {
        if !self.config.enabled {
            return forward_fn();
        }

        // Run the forward pass and convert the output back to full precision.
        let output = forward_fn()?;
        self.to_full_precision(&output)
    }

    /// Runs one loss-scaling step: scale the loss, unscale the gradients,
    /// and update the scale. Returns `true` if the step overflowed. The
    /// scaled loss is computed but not propagated here; the caller is
    /// expected to have produced `gradients` from its own backward pass.
    pub fn backward_with_amp(
        &mut self,
        loss: &Tensor,
        gradients: &mut HashMap<String, Tensor>,
    ) -> Result<bool> {
        let _scaled_loss = self.loss_scaler.scale_loss(loss)?;

        let overflow = self.loss_scaler.unscale_gradients(gradients)?;

        self.loss_scaler.update_scale(overflow)?;

        Ok(overflow)
    }

    pub fn get_loss_scale(&self) -> f32 {
        self.loss_scaler.get_scale()
    }

    pub fn is_enabled(&self) -> bool {
        self.config.enabled
    }
}
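
// Example (sketch): driving a training step through `AMPManager`. The
// `model_forward` closure, `loss`, and `grads` are illustrative placeholders.
//
//     let mut amp = AMPManager::new(utils::default_fp16_config());
//     let output = amp.forward_with_amp(|| model_forward())?;
//     let overflow = amp.backward_with_amp(&loss, &mut grads)?;
//     if overflow {
//         // skip this optimizer step; the loss scale has been backed off
//     }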

/// Convenience constructors and precision-safety helpers.
pub mod utils {
    use super::*;

    /// Standard fp16 training configuration with dynamic loss scaling.
    pub fn default_fp16_config() -> MixedPrecisionConfig {
        MixedPrecisionConfig {
            enabled: true,
            init_scale: 2f32.powf(16.0),
            scale_factor: 2.0,
            backoff_factor: 0.5,
            scale_window: 2000,
            min_scale: 1.0,
            max_scale: 2f32.powf(24.0),
            skip_inf_nan: true,
        }
    }

    /// bf16 configuration. bf16 shares fp32's exponent range, so loss
    /// scaling is effectively disabled: the scale is pinned at 1.0 and the
    /// growth window never elapses.
    pub fn default_bf16_config() -> MixedPrecisionConfig {
        MixedPrecisionConfig {
            enabled: true,
            init_scale: 1.0,
            scale_factor: 1.0,
            backoff_factor: 1.0,
            scale_window: usize::MAX,
            min_scale: 1.0,
            max_scale: 1.0,
            skip_inf_nan: true,
        }
    }

    /// Returns `true` if every value in the tensor fits within fp16's
    /// finite range (|x| <= 65504).
    pub fn is_fp16_safe(tensor: &Tensor) -> Result<bool> {
        match tensor {
            Tensor::F32(arr) => {
                for &value in arr.iter() {
                    if value.abs() > 65504.0 || (!value.is_finite() && value != 0.0) {
                        return Ok(false);
                    }
                }
                Ok(true)
            },
            Tensor::F64(arr) => {
                for &value in arr.iter() {
                    if value.abs() > 65504.0 || (!value.is_finite() && value != 0.0) {
                        return Ok(false);
                    }
                }
                Ok(true)
            },
            Tensor::F16(_) => Ok(true),
            Tensor::BF16(_) => Ok(true),
            Tensor::I64(_) => Ok(true),
            Tensor::C32(arr) => {
                for &value in arr.iter() {
                    if value.re.abs() > 65504.0
                        || value.im.abs() > 65504.0
                        || (!value.re.is_finite() && value.re != 0.0)
                        || (!value.im.is_finite() && value.im != 0.0)
                    {
                        return Ok(false);
                    }
                }
                Ok(true)
            },
            Tensor::C64(arr) => {
                for &value in arr.iter() {
                    if value.re.abs() > 65504.0
                        || value.im.abs() > 65504.0
                        || (!value.re.is_finite() && value.re != 0.0)
                        || (!value.im.is_finite() && value.im != 0.0)
                    {
                        return Ok(false);
                    }
                }
                Ok(true)
            },
            Tensor::CF16(_) => Ok(true),
            Tensor::CBF16(_) => Ok(true),
            _ => Ok(true),
        }
    }

    /// Returns the (min, max) of the finite values in the tensor; complex
    /// tensors report the range of element magnitudes.
    pub fn calculate_dynamic_range(tensor: &Tensor) -> Result<(f32, f32)> {
        match tensor {
            Tensor::F32(arr) => {
                let mut min_val = f32::INFINITY;
                let mut max_val = f32::NEG_INFINITY;

                for &value in arr.iter() {
                    if value.is_finite() {
                        min_val = min_val.min(value);
                        max_val = max_val.max(value);
                    }
                }

                Ok((min_val, max_val))
            },
            Tensor::F64(arr) => {
                let min_val = arr
                    .iter()
                    .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
                    .copied()
                    .unwrap_or(0.0) as f32;
                let max_val = arr
                    .iter()
                    .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
                    .copied()
                    .unwrap_or(0.0) as f32;
                Ok((min_val, max_val))
            },
            Tensor::F16(arr) => {
                let mut min_val = f32::INFINITY;
                let mut max_val = f32::NEG_INFINITY;

                for &value in arr.iter() {
                    let f32_val = value.to_f32();
                    if f32_val.is_finite() {
                        min_val = min_val.min(f32_val);
                        max_val = max_val.max(f32_val);
                    }
                }

                Ok((min_val, max_val))
            },
            Tensor::BF16(arr) => {
                let mut min_val = f32::INFINITY;
                let mut max_val = f32::NEG_INFINITY;

                for &value in arr.iter() {
                    let f32_val = value.to_f32();
                    if f32_val.is_finite() {
                        min_val = min_val.min(f32_val);
                        max_val = max_val.max(f32_val);
                    }
                }

                Ok((min_val, max_val))
            },
            Tensor::I64(arr) => {
                let min_val = arr.iter().min().copied().unwrap_or(0) as f32;
                let max_val = arr.iter().max().copied().unwrap_or(0) as f32;
                Ok((min_val, max_val))
            },
            Tensor::C32(arr) => {
                let mut min_val = f32::INFINITY;
                let mut max_val = f32::NEG_INFINITY;

                for &value in arr.iter() {
                    let magnitude = value.norm();
                    if magnitude.is_finite() {
                        min_val = min_val.min(magnitude);
                        max_val = max_val.max(magnitude);
                    }
                }

                Ok((min_val, max_val))
            },
            Tensor::C64(arr) => {
                let mut min_val = f32::INFINITY;
                let mut max_val = f32::NEG_INFINITY;

                for &value in arr.iter() {
                    let magnitude = value.norm() as f32;
                    if magnitude.is_finite() {
                        min_val = min_val.min(magnitude);
                        max_val = max_val.max(magnitude);
                    }
                }

                Ok((min_val, max_val))
            },
            Tensor::CF16(arr) => {
                let mut min_val = f32::INFINITY;
                let mut max_val = f32::NEG_INFINITY;

                for &value in arr.iter() {
                    let magnitude = (value.re.to_f32().powi(2) + value.im.to_f32().powi(2)).sqrt();
                    if magnitude.is_finite() {
                        min_val = min_val.min(magnitude);
                        max_val = max_val.max(magnitude);
                    }
                }

                Ok((min_val, max_val))
            },
            Tensor::CBF16(arr) => {
                let mut min_val = f32::INFINITY;
                let mut max_val = f32::NEG_INFINITY;

                for &value in arr.iter() {
                    let magnitude = (value.re.to_f32().powi(2) + value.im.to_f32().powi(2)).sqrt();
                    if magnitude.is_finite() {
                        min_val = min_val.min(magnitude);
                        max_val = max_val.max(magnitude);
                    }
                }

                Ok((min_val, max_val))
            },
            _ => Ok((0.0, 1.0)),
        }
    }
}
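
// Example (sketch): choosing a starting configuration. fp16 needs dynamic
// loss scaling to cover its narrow exponent range, while bf16 usually does
// not; the `prefers_bf16` flag here is a hypothetical hardware query.
//
//     let config = if prefers_bf16 {
//         utils::default_bf16_config()
//     } else {
//         utils::default_fp16_config()
//     };
//     let manager = AMPManager::new(config);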

/// Extended mixed-precision configuration with automatic precision and
/// per-layer scaling controls.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdvancedMixedPrecisionConfig {
    /// Base loss-scaling configuration.
    pub base_config: MixedPrecisionConfig,
    /// Whether dynamic loss-scale adaptation is enabled.
    pub enable_dynamic_scaling: bool,
    /// Whether to maintain a separate scale factor per layer.
    pub enable_per_layer_scaling: bool,
    /// Whether to switch precision automatically from runtime telemetry.
    pub enable_auto_precision: bool,
    /// Lowest precision the manager may select (e.g. "fp16").
    pub min_precision: String,
    /// Highest precision the manager may select (e.g. "fp32").
    pub max_precision: String,
    /// Rate at which precision decisions adapt.
    pub precision_adaptation_rate: f32,
    /// Memory-usage fraction above which lower precision is preferred.
    pub memory_threshold: f32,
    /// Performance score below which higher precision is preferred.
    pub performance_threshold: f32,
}

impl Default for AdvancedMixedPrecisionConfig {
    fn default() -> Self {
        Self {
            base_config: MixedPrecisionConfig::default(),
            enable_dynamic_scaling: true,
            enable_per_layer_scaling: true,
            enable_auto_precision: true,
            min_precision: "fp16".to_string(),
            max_precision: "fp32".to_string(),
            precision_adaptation_rate: 0.1,
            memory_threshold: 0.8,
            performance_threshold: 0.9,
        }
    }
}

/// Per-layer scaling state: a scale factor plus bounded histories used to
/// adapt it.
#[derive(Debug, Clone)]
pub struct LayerScalingConfig {
    /// Name of the layer this state belongs to.
    pub layer_name: String,
    /// Current per-layer gradient scale factor.
    pub scale_factor: f32,
    /// Recent gradient norms (bounded window).
    pub gradient_norm_history: Vec<f32>,
    /// Recent loss values.
    pub loss_history: Vec<f32>,
    /// Number of steps on which this layer's gradients overflowed.
    pub overflow_count: usize,
    /// Number of overflow-free steps recorded for this layer.
    pub underflow_count: usize,
}

impl LayerScalingConfig {
    pub fn new(layer_name: String) -> Self {
        Self {
            layer_name,
            scale_factor: 1.0,
            gradient_norm_history: Vec::new(),
            loss_history: Vec::new(),
            overflow_count: 0,
            underflow_count: 0,
        }
    }
}

/// Manager that layers automatic precision switching, per-layer gradient
/// scaling, and telemetry tracking on top of `AMPManager`.
#[derive(Debug)]
pub struct AdvancedMixedPrecisionManager {
    config: AdvancedMixedPrecisionConfig,
    base_manager: AMPManager,
    layer_configs: HashMap<String, LayerScalingConfig>,
    current_precision: String,
    precision_history: Vec<(usize, String)>,
    memory_usage_history: Vec<f32>,
    performance_history: Vec<f32>,
    step_count: usize,
}

impl AdvancedMixedPrecisionManager {
    pub fn new(config: AdvancedMixedPrecisionConfig) -> Self {
        let base_manager = AMPManager::new(config.base_config.clone());
        Self {
            config,
            base_manager,
            layer_configs: HashMap::new(),
            current_precision: "fp32".to_string(),
            precision_history: Vec::new(),
            memory_usage_history: Vec::new(),
            performance_history: Vec::new(),
            step_count: 0,
        }
    }

    /// Records one step of telemetry and, if enabled, adapts the precision.
    pub fn update_step(&mut self, memory_usage: f32, performance_score: f32) {
        self.step_count += 1;
        self.memory_usage_history.push(memory_usage);
        self.performance_history.push(performance_score);

        // Keep the telemetry windows bounded to the last 100 steps.
        if self.memory_usage_history.len() > 100 {
            self.memory_usage_history.remove(0);
            self.performance_history.remove(0);
        }

        if self.config.enable_auto_precision {
            self.adapt_precision();
        }
    }

    fn adapt_precision(&mut self) {
        let avg_memory =
            self.memory_usage_history.iter().sum::<f32>() / self.memory_usage_history.len() as f32;
        let avg_performance =
            self.performance_history.iter().sum::<f32>() / self.performance_history.len() as f32;

        let target_precision = if avg_memory > self.config.memory_threshold {
            // Memory pressure: move toward the most compact precision.
            match self.current_precision.as_str() {
                "fp32" => "fp16",
                "bf16" => "fp16",
                _ => "fp16",
            }
        } else if avg_performance < self.config.performance_threshold {
            // Low performance score: move toward higher precision.
            match self.current_precision.as_str() {
                "fp16" => "bf16",
                "bf16" => "fp32",
                _ => "fp32",
            }
        } else {
            &self.current_precision
        };

        if target_precision != self.current_precision {
            self.switch_precision(target_precision.to_string());
        }
    }

    fn switch_precision(&mut self, new_precision: String) {
        self.current_precision = new_precision.clone();
        self.precision_history.push((self.step_count, new_precision));

        // Reconfigure the base manager to match the new precision.
        match self.current_precision.as_str() {
            "fp16" => {
                self.base_manager.config = utils::default_fp16_config();
            },
            "bf16" => {
                self.base_manager.config = utils::default_bf16_config();
            },
            "fp32" => {
                self.base_manager.config = MixedPrecisionConfig {
                    enabled: false,
                    ..Default::default()
                };
            },
            _ => {
                self.base_manager.config = utils::default_fp16_config();
            },
        }
    }

    /// Applies per-layer gradient scaling and adapts each layer's scale
    /// factor from its recent gradient-norm history. Returns `true` if any
    /// scaled gradient overflowed.
    pub fn scale_gradients_per_layer(
        &mut self,
        gradients: &mut HashMap<String, Tensor>,
    ) -> Result<bool> {
        let mut global_overflow = false;

        for (layer_name, gradient) in gradients.iter_mut() {
            // Lazily create scaling state for layers seen for the first time.
            if !self.layer_configs.contains_key(layer_name) {
                self.layer_configs.insert(
                    layer_name.clone(),
                    LayerScalingConfig::new(layer_name.clone()),
                );
            }

            let grad_norm = self.compute_gradient_norm(gradient)?;

            let enable_per_layer_scaling = self.config.enable_per_layer_scaling;

            let layer_config = self
                .layer_configs
                .get_mut(layer_name)
                .expect("layer config should exist after the insertion above");
            layer_config.gradient_norm_history.push(grad_norm);

            // Keep a bounded window of recent gradient norms.
            if layer_config.gradient_norm_history.len() > 50 {
                layer_config.gradient_norm_history.remove(0);
            }

            if enable_per_layer_scaling && layer_config.gradient_norm_history.len() >= 5 {
                let recent_norms: Vec<f32> =
                    layer_config.gradient_norm_history.iter().rev().take(5).cloned().collect();

                let avg_norm = recent_norms.iter().sum::<f32>() / recent_norms.len() as f32;

                // Nudge the factor up for large recent norms and down for
                // vanishing ones, then clamp to a sane range.
                if avg_norm > 10.0 {
                    layer_config.scale_factor *= 1.1;
                } else if avg_norm < 0.01 {
                    layer_config.scale_factor *= 0.9;
                }

                layer_config.scale_factor = layer_config.scale_factor.clamp(0.01, 1000.0);
            }

            let scale_factor = layer_config.scale_factor;

            // End the mutable borrow of `self.layer_configs` before calling
            // the `&self` helpers below.
            let _ = layer_config;

            *gradient = self.scale_tensor(gradient, scale_factor)?;

            let has_overflow = self.has_overflow(gradient)?;

            let layer_config = self
                .layer_configs
                .get_mut(layer_name)
                .expect("layer config should exist after the insertion above");
            if has_overflow {
                layer_config.overflow_count += 1;
                global_overflow = true;
            } else {
                // Overflow-free steps are tallied as `underflow_count` and
                // surfaced in reports as `total_underflows`.
                layer_config.underflow_count += 1;
            }
        }

        Ok(global_overflow)
    }

    #[allow(dead_code)]
    fn adapt_layer_scaling(&mut self, layer_config: &mut LayerScalingConfig) {
        if layer_config.gradient_norm_history.len() < 5 {
            return;
        }

        let recent_norms: Vec<f32> =
            layer_config.gradient_norm_history.iter().rev().take(5).cloned().collect();

        let avg_norm = recent_norms.iter().sum::<f32>() / recent_norms.len() as f32;

        // Same adaptation rule as `scale_gradients_per_layer`, but with a
        // tighter clamp on the resulting factor.
        if avg_norm > 10.0 {
            layer_config.scale_factor *= 1.1;
        } else if avg_norm < 0.01 {
            layer_config.scale_factor *= 0.9;
        }

        layer_config.scale_factor = layer_config.scale_factor.clamp(0.1, 10.0);
    }

    fn compute_gradient_norm(&self, gradient: &Tensor) -> Result<f32> {
        match gradient {
            Tensor::F32(arr) => {
                let norm = arr.iter().map(|&x| x * x).sum::<f32>().sqrt();
                Ok(norm)
            },
            Tensor::F64(arr) => {
                let norm = arr.iter().map(|&x| x * x).sum::<f64>().sqrt() as f32;
                Ok(norm)
            },
            Tensor::F16(arr) => {
                let norm = arr
                    .iter()
                    .map(|&x| {
                        let f32_val = x.to_f32();
                        f32_val * f32_val
                    })
                    .sum::<f32>()
                    .sqrt();
                Ok(norm)
            },
            Tensor::BF16(arr) => {
                let norm = arr
                    .iter()
                    .map(|&x| {
                        let f32_val = x.to_f32();
                        f32_val * f32_val
                    })
                    .sum::<f32>()
                    .sqrt();
                Ok(norm)
            },
            Tensor::I64(_) => Ok(0.0),
            Tensor::C32(arr) => {
                let norm = arr.iter().map(|&x| x.norm_sqr()).sum::<f32>().sqrt();
                Ok(norm)
            },
            Tensor::C64(arr) => {
                let norm = arr.iter().map(|&x| x.norm_sqr() as f32).sum::<f32>().sqrt();
                Ok(norm)
            },
            Tensor::CF16(arr) => {
                let norm = arr
                    .iter()
                    .map(|&x| {
                        let re = x.re.to_f32();
                        let im = x.im.to_f32();
                        re * re + im * im
                    })
                    .sum::<f32>()
                    .sqrt();
                Ok(norm)
            },
            Tensor::CBF16(arr) => {
                let norm = arr
                    .iter()
                    .map(|&x| {
                        let re = x.re.to_f32();
                        let im = x.im.to_f32();
                        re * re + im * im
                    })
                    .sum::<f32>()
                    .sqrt();
                Ok(norm)
            },
            _ => Ok(1.0),
        }
    }

    fn scale_tensor(&self, tensor: &Tensor, factor: f32) -> Result<Tensor> {
        match tensor {
            Tensor::F32(arr) => {
                let scaled = arr.mapv(|x| x * factor);
                Ok(Tensor::F32(scaled))
            },
            Tensor::F64(arr) => {
                let scaled = arr.mapv(|x| x * factor as f64);
                Ok(Tensor::F64(scaled))
            },
            Tensor::F16(arr) => {
                let factor_f16 = half::f16::from_f32(factor);
                let scaled = arr.mapv(|x| x * factor_f16);
                Ok(Tensor::F16(scaled))
            },
            Tensor::BF16(arr) => {
                let factor_bf16 = half::bf16::from_f32(factor);
                let scaled = arr.mapv(|x| x * factor_bf16);
                Ok(Tensor::BF16(scaled))
            },
            Tensor::I64(arr) => Ok(Tensor::I64(arr.clone())),
            Tensor::C32(arr) => {
                let scaled = arr.mapv(|x| x * factor);
                Ok(Tensor::C32(scaled))
            },
            Tensor::C64(arr) => {
                let scaled = arr.mapv(|x| x * factor as f64);
                Ok(Tensor::C64(scaled))
            },
            Tensor::CF16(arr) => {
                let factor_f16 = half::f16::from_f32(factor);
                let scaled = arr.mapv(|x| Complex::new(x.re * factor_f16, x.im * factor_f16));
                Ok(Tensor::CF16(scaled))
            },
            Tensor::CBF16(arr) => {
                let factor_bf16 = half::bf16::from_f32(factor);
                let scaled = arr.mapv(|x| Complex::new(x.re * factor_bf16, x.im * factor_bf16));
                Ok(Tensor::CBF16(scaled))
            },
            _ => Ok(tensor.clone()),
        }
    }

    fn has_overflow(&self, tensor: &Tensor) -> Result<bool> {
        match tensor {
            Tensor::F32(arr) => {
                for &value in arr.iter() {
                    if !value.is_finite() {
                        return Ok(true);
                    }
                }
                Ok(false)
            },
            Tensor::F64(arr) => {
                for &value in arr.iter() {
                    if !value.is_finite() {
                        return Ok(true);
                    }
                }
                Ok(false)
            },
            Tensor::F16(arr) => {
                for &value in arr.iter() {
                    if !value.to_f32().is_finite() {
                        return Ok(true);
                    }
                }
                Ok(false)
            },
            Tensor::BF16(arr) => {
                for &value in arr.iter() {
                    if !value.to_f32().is_finite() {
                        return Ok(true);
                    }
                }
                Ok(false)
            },
            Tensor::I64(_) => Ok(false),
            Tensor::C32(arr) => {
                for &value in arr.iter() {
                    if !value.re.is_finite() || !value.im.is_finite() {
                        return Ok(true);
                    }
                }
                Ok(false)
            },
            Tensor::C64(arr) => {
                for &value in arr.iter() {
                    if !value.re.is_finite() || !value.im.is_finite() {
                        return Ok(true);
                    }
                }
                Ok(false)
            },
            Tensor::CF16(arr) => {
                for &value in arr.iter() {
                    if !value.re.to_f32().is_finite() || !value.im.to_f32().is_finite() {
                        return Ok(true);
                    }
                }
                Ok(false)
            },
            Tensor::CBF16(arr) => {
                for &value in arr.iter() {
                    if !value.re.to_f32().is_finite() || !value.im.to_f32().is_finite() {
                        return Ok(true);
                    }
                }
                Ok(false)
            },
            _ => Ok(false),
        }
    }

    pub fn get_current_precision(&self) -> &str {
        &self.current_precision
    }

    pub fn get_precision_history(&self) -> &[(usize, String)] {
        &self.precision_history
    }

    pub fn get_layer_configs(&self) -> &HashMap<String, LayerScalingConfig> {
        &self.layer_configs
    }

    pub fn forward_with_advanced_amp<F>(&mut self, forward_fn: F) -> Result<Tensor>
    where
        F: FnOnce() -> Result<Tensor>,
    {
        let output = forward_fn()?;

        // Post-process the output to match the active precision.
        match self.current_precision.as_str() {
            "fp16" => self.optimize_for_fp16(&output),
            "bf16" => self.optimize_for_bf16(&output),
            "fp32" => Ok(output),
            _ => Ok(output),
        }
    }

    fn optimize_for_fp16(&self, tensor: &Tensor) -> Result<Tensor> {
        match tensor {
            Tensor::F32(arr) => {
                let optimized = arr.mapv(|x| {
                    // Clamp to fp16's finite range, then quantize to steps of
                    // 1/1024 to emulate its 10-bit mantissa.
                    let clamped = x.clamp(-65504.0, 65504.0);
                    (clamped * 1024.0).round() / 1024.0
                });
                Ok(Tensor::F32(optimized))
            },
            _ => Ok(tensor.clone()),
        }
    }

    fn optimize_for_bf16(&self, tensor: &Tensor) -> Result<Tensor> {
        match tensor {
            Tensor::F32(arr) => {
                let optimized = arr.mapv(|x| {
                    // Quantize to steps of 1/128 to emulate bf16's 7-bit
                    // mantissa; no clamping is needed because bf16 shares
                    // fp32's exponent range.
                    (x * 128.0).round() / 128.0
                });
                Ok(Tensor::F32(optimized))
            },
            _ => Ok(tensor.clone()),
        }
    }

    pub fn generate_report(&self) -> MixedPrecisionReport {
        let total_overflows = self.layer_configs.values().map(|config| config.overflow_count).sum();

        let total_underflows =
            self.layer_configs.values().map(|config| config.underflow_count).sum();

        let avg_memory_usage = if !self.memory_usage_history.is_empty() {
            self.memory_usage_history.iter().sum::<f32>() / self.memory_usage_history.len() as f32
        } else {
            0.0
        };

        let avg_performance = if !self.performance_history.is_empty() {
            self.performance_history.iter().sum::<f32>() / self.performance_history.len() as f32
        } else {
            0.0
        };

        MixedPrecisionReport {
            current_precision: self.current_precision.clone(),
            step_count: self.step_count,
            total_overflows,
            total_underflows,
            avg_memory_usage,
            avg_performance,
            precision_switches: self.precision_history.len(),
            layer_count: self.layer_configs.len(),
            recommendations: self.generate_recommendations(),
        }
    }

    fn generate_recommendations(&self) -> Vec<String> {
        let mut recommendations = Vec::new();

        let total_overflows =
            self.layer_configs.values().map(|config| config.overflow_count).sum::<usize>();

        if total_overflows > self.step_count / 10 {
            recommendations
                .push("High overflow rate detected - consider reducing learning rate".to_string());
        }

        let avg_memory = if !self.memory_usage_history.is_empty() {
            self.memory_usage_history.iter().sum::<f32>() / self.memory_usage_history.len() as f32
        } else {
            0.0
        };

        if avg_memory > 0.9 {
            recommendations
                .push("High memory usage - consider using fp16 or reducing batch size".to_string());
        }

        let avg_performance = if !self.performance_history.is_empty() {
            self.performance_history.iter().sum::<f32>() / self.performance_history.len() as f32
        } else {
            0.0
        };

        if avg_performance < 0.5 {
            recommendations.push(
                "Low performance - consider using higher precision or adjusting hyperparameters"
                    .to_string(),
            );
        }

        if self.precision_history.len() > 10 {
            recommendations.push(
                "Frequent precision switches - consider adjusting adaptation thresholds"
                    .to_string(),
            );
        }

        if recommendations.is_empty() {
            recommendations.push("Mixed precision training is working well".to_string());
        }

        recommendations
    }
}
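
// Example (sketch): one step with the advanced manager. The memory and
// performance scores are assumed to be normalized to [0, 1] by the caller;
// `grads` is a placeholder gradient map.
//
//     let mut manager =
//         AdvancedMixedPrecisionManager::new(AdvancedMixedPrecisionConfig::default());
//     manager.update_step(0.92, 0.7); // high memory pressure -> drift toward fp16
//     let overflow = manager.scale_gradients_per_layer(&mut grads)?;
//     println!("precision now: {}", manager.get_current_precision());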

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MixedPrecisionReport {
    pub current_precision: String,
    pub step_count: usize,
    pub total_overflows: usize,
    pub total_underflows: usize,
    pub avg_memory_usage: f32,
    pub avg_performance: f32,
    pub precision_switches: usize,
    pub layer_count: usize,
    pub recommendations: Vec<String>,
}

/// Configuration for runtime batch-size adaptation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DynamicBatchingConfig {
    /// Batch size to start from.
    pub initial_batch_size: usize,
    /// Largest batch size the manager may select.
    pub max_batch_size: usize,
    /// Smallest batch size the manager may select.
    pub min_batch_size: usize,
    /// Fractional step size for each adaptation.
    pub adaptation_rate: f32,
    /// Memory-usage fraction above which the batch shrinks.
    pub memory_threshold: f32,
    /// Performance score above which the batch grows.
    pub performance_threshold: f32,
}

impl Default for DynamicBatchingConfig {
    fn default() -> Self {
        Self {
            initial_batch_size: 32,
            max_batch_size: 128,
            min_batch_size: 8,
            adaptation_rate: 0.1,
            memory_threshold: 0.85,
            performance_threshold: 0.9,
        }
    }
}
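
// Worked example of the adaptation rule below: starting from the default
// batch size of 32 with `adaptation_rate` 0.1, a step whose average memory
// usage exceeds 0.85 shrinks the batch by (32.0 * 0.1) as usize = 3 to 29
// (never below 8), while sustained performance above 0.9 grows it by the
// same 10% increments up to the cap of 128.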

/// Adjusts the batch size at runtime from memory and performance telemetry.
#[derive(Debug)]
pub struct DynamicBatchingManager {
    config: DynamicBatchingConfig,
    current_batch_size: usize,
    batch_size_history: Vec<(usize, usize)>,
    memory_usage_history: Vec<f32>,
    performance_history: Vec<f32>,
    step_count: usize,
}

impl DynamicBatchingManager {
    pub fn new(config: DynamicBatchingConfig) -> Self {
        Self {
            current_batch_size: config.initial_batch_size,
            config,
            batch_size_history: Vec::new(),
            memory_usage_history: Vec::new(),
            performance_history: Vec::new(),
            step_count: 0,
        }
    }

    /// Records one step of telemetry and adapts the batch size.
    pub fn update_step(&mut self, memory_usage: f32, performance_score: f32) {
        self.step_count += 1;
        self.memory_usage_history.push(memory_usage);
        self.performance_history.push(performance_score);

        // Keep the telemetry windows bounded to the last 50 steps.
        if self.memory_usage_history.len() > 50 {
            self.memory_usage_history.remove(0);
            self.performance_history.remove(0);
        }

        self.adapt_batch_size();
    }

    fn adapt_batch_size(&mut self) {
        let avg_memory =
            self.memory_usage_history.iter().sum::<f32>() / self.memory_usage_history.len() as f32;
        let avg_performance =
            self.performance_history.iter().sum::<f32>() / self.performance_history.len() as f32;

        let old_batch_size = self.current_batch_size;

        if avg_memory > self.config.memory_threshold {
            // Memory pressure: shrink the batch, but never below the floor.
            let reduction = (self.current_batch_size as f32 * self.config.adaptation_rate) as usize;
            self.current_batch_size =
                (self.current_batch_size - reduction).max(self.config.min_batch_size);
        } else if avg_performance > self.config.performance_threshold {
            // Headroom available: grow the batch up to the cap.
            let increase = (self.current_batch_size as f32 * self.config.adaptation_rate) as usize;
            self.current_batch_size =
                (self.current_batch_size + increase).min(self.config.max_batch_size);
        }

        if self.current_batch_size != old_batch_size {
            self.batch_size_history.push((self.step_count, self.current_batch_size));
        }
    }

    pub fn get_current_batch_size(&self) -> usize {
        self.current_batch_size
    }

    pub fn get_batch_size_history(&self) -> &[(usize, usize)] {
        &self.batch_size_history
    }

    pub fn generate_report(&self) -> DynamicBatchingReport {
        let avg_memory = if !self.memory_usage_history.is_empty() {
            self.memory_usage_history.iter().sum::<f32>() / self.memory_usage_history.len() as f32
        } else {
            0.0
        };

        let avg_performance = if !self.performance_history.is_empty() {
            self.performance_history.iter().sum::<f32>() / self.performance_history.len() as f32
        } else {
            0.0
        };

        DynamicBatchingReport {
            current_batch_size: self.current_batch_size,
            step_count: self.step_count,
            avg_memory_usage: avg_memory,
            avg_performance,
            batch_size_changes: self.batch_size_history.len(),
            memory_efficiency: 1.0 - avg_memory,
            performance_score: avg_performance,
        }
    }
}
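
// Example (sketch): feeding runtime telemetry into the batching manager.
// The measurement values are illustrative.
//
//     let mut batcher = DynamicBatchingManager::new(DynamicBatchingConfig::default());
//     batcher.update_step(0.9, 0.95); // memory-bound step -> batch size shrinks
//     let batch_size = batcher.get_current_batch_size();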

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DynamicBatchingReport {
    pub current_batch_size: usize,
    pub step_count: usize,
    pub avg_memory_usage: f32,
    pub avg_performance: f32,
    pub batch_size_changes: usize,
    pub memory_efficiency: f32,
    pub performance_score: f32,
}

/// Top-level coordinator that drives mixed precision and dynamic batching
/// from shared telemetry.
#[derive(Debug)]
pub struct ComputeOptimizationManager {
    mixed_precision_manager: AdvancedMixedPrecisionManager,
    dynamic_batching_manager: DynamicBatchingManager,
    kernel_fusion_enabled: bool,
    pipeline_optimization_enabled: bool,
}

impl ComputeOptimizationManager {
    pub fn new(
        mixed_precision_config: AdvancedMixedPrecisionConfig,
        dynamic_batching_config: DynamicBatchingConfig,
    ) -> Self {
        Self {
            mixed_precision_manager: AdvancedMixedPrecisionManager::new(mixed_precision_config),
            dynamic_batching_manager: DynamicBatchingManager::new(dynamic_batching_config),
            kernel_fusion_enabled: true,
            pipeline_optimization_enabled: true,
        }
    }

    /// Forwards telemetry to both sub-managers.
    pub fn update_step(&mut self, memory_usage: f32, performance_score: f32) {
        self.mixed_precision_manager.update_step(memory_usage, performance_score);
        self.dynamic_batching_manager.update_step(memory_usage, performance_score);
    }

    pub fn get_current_batch_size(&self) -> usize {
        self.dynamic_batching_manager.get_current_batch_size()
    }

    pub fn get_current_precision(&self) -> &str {
        self.mixed_precision_manager.get_current_precision()
    }

    pub fn generate_report(&self) -> ComputeOptimizationReport {
        ComputeOptimizationReport {
            mixed_precision_report: self.mixed_precision_manager.generate_report(),
            dynamic_batching_report: self.dynamic_batching_manager.generate_report(),
            kernel_fusion_enabled: self.kernel_fusion_enabled,
            pipeline_optimization_enabled: self.pipeline_optimization_enabled,
        }
    }
}
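
// Example (sketch): the combined manager drives both subsystems from the
// same telemetry. Values are illustrative.
//
//     let mut opt = ComputeOptimizationManager::new(
//         AdvancedMixedPrecisionConfig::default(),
//         DynamicBatchingConfig::default(),
//     );
//     opt.update_step(0.75, 0.95);
//     let report = opt.generate_report();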

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComputeOptimizationReport {
    pub mixed_precision_report: MixedPrecisionReport,
    pub dynamic_batching_report: DynamicBatchingReport,
    pub kernel_fusion_enabled: bool,
    pub pipeline_optimization_enabled: bool,
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    #[test]
    fn test_mixed_precision_config_default() {
        let config = MixedPrecisionConfig::default();
        assert!(!config.enabled);
        assert_eq!(config.init_scale, 65536.0);
        assert_eq!(config.scale_factor, 2.0);
        assert_eq!(config.backoff_factor, 0.5);
        assert_eq!(config.scale_window, 2000);
    }

    #[test]
    fn test_loss_scaler_creation() {
        let config = MixedPrecisionConfig::default();
        let scaler = LossScaler::new(config);
        // Scaling is disabled by default, so the effective scale is 1.0.
        assert_eq!(scaler.get_scale(), 1.0);
    }

    #[test]
    fn test_loss_scaler_enabled() {
        let config = MixedPrecisionConfig {
            enabled: true,
            ..MixedPrecisionConfig::default()
        };
        let scaler = LossScaler::new(config);
        assert_eq!(scaler.get_scale(), 65536.0);
    }

    #[test]
    fn test_loss_scaling() {
        let config = MixedPrecisionConfig {
            enabled: true,
            ..MixedPrecisionConfig::default()
        };
        let scaler = LossScaler::new(config);

        let loss = Tensor::ones(&[2, 2]).expect("tensor operation failed");
        let scaled_loss = scaler.scale_loss(&loss).expect("operation failed in test");

        match (&loss, &scaled_loss) {
            (Tensor::F32(orig), Tensor::F32(scaled)) => {
                // The scaled loss should be the original times the loss scale.
                assert!((scaled[[0, 0]] / orig[[0, 0]] - 65536.0).abs() < 1e-6);
            },
            _ => panic!("Unexpected tensor types"),
        }
    }

    #[test]
    fn test_gradient_unscaling() {
        let config = MixedPrecisionConfig {
            enabled: true,
            ..MixedPrecisionConfig::default()
        };
        let scaler = LossScaler::new(config);

        let mut gradients = HashMap::new();
        gradients.insert(
            "param1".to_string(),
            Tensor::ones(&[2, 2]).expect("tensor operation failed"),
        );

        let overflow = scaler.unscale_gradients(&mut gradients).expect("operation failed in test");
        assert!(!overflow);

        let gradient = gradients.get("param1").expect("expected value not found");
        match gradient {
            Tensor::F32(arr) => {
                assert!((arr[[0, 0]] - 1.0 / 65536.0).abs() < 1e-6);
            },
            _ => panic!("Unexpected tensor type"),
        }
    }

    #[test]
    fn test_amp_manager_creation() {
        let config = MixedPrecisionConfig::default();
        let manager = AMPManager::new(config);
        assert!(!manager.is_enabled());
    }

    #[test]
    fn test_amp_manager_enabled() {
        let config = MixedPrecisionConfig {
            enabled: true,
            ..MixedPrecisionConfig::default()
        };
        let manager = AMPManager::new(config);
        assert!(manager.is_enabled());
        assert_eq!(manager.get_loss_scale(), 65536.0);
    }
1456 #[test]
1457 fn test_half_precision_conversion() {
1458 let config = utils::default_fp16_config();
1459 let manager = AMPManager::new(config);
1460
1461 let tensor = Tensor::from_vec(vec![1.0, 2.5, -3.7, 1000.0], &[2, 2])
1462 .expect("tensor operation failed");
1463 let half_precision = manager.to_half_precision(&tensor).expect("tensor operation failed");
1464 let full_precision =
1465 manager.to_full_precision(&half_precision).expect("operation failed in test");
1466
1467 match (&tensor, &full_precision) {
1469 (Tensor::F32(orig), Tensor::F32(converted)) => {
1470 for (o, c) in orig.iter().zip(converted.iter()) {
1471 assert!((o - c).abs() < 0.1); }
1473 },
1474 _ => panic!("Unexpected tensor types"),
1475 }
1476 }
1477
1478 #[test]
1479 fn test_fp16_safety_check() {
1480 let safe_tensor =
1481 Tensor::from_vec(vec![1.0, 2.0, -3.0], &[3]).expect("tensor operation failed");
1482 assert!(utils::is_fp16_safe(&safe_tensor).expect("tensor operation failed"));
1483
1484 let unsafe_tensor =
1485 Tensor::from_vec(vec![1.0, 70000.0, -3.0], &[3]).expect("tensor operation failed");
1486 assert!(!utils::is_fp16_safe(&unsafe_tensor).expect("tensor operation failed"));
1487 }
1488
1489 #[test]
1490 fn test_dynamic_range_calculation() {
1491 let tensor =
1492 Tensor::from_vec(vec![1.0, 5.0, -2.0, 3.0], &[2, 2]).expect("tensor operation failed");
1493 let (min_val, max_val) =
1494 utils::calculate_dynamic_range(&tensor).expect("tensor operation failed");
1495 assert_eq!(min_val, -2.0);
1496 assert_eq!(max_val, 5.0);
1497 }
1498
1499 #[test]
1500 fn test_overflow_detection_and_scale_update() {
1501 let config = MixedPrecisionConfig {
1502 enabled: true,
1503 backoff_factor: 0.5,
1504 ..MixedPrecisionConfig::default()
1505 };
1506 let mut scaler = LossScaler::new(config);
1507
1508 let initial_scale = scaler.get_scale();
1509
1510 scaler.update_scale(true).expect("operation failed in test");
1512 assert_eq!(scaler.get_scale(), initial_scale * 0.5);
1513 assert!(scaler.overflow_detected());
1514
1515 scaler.update_scale(false).expect("operation failed in test");
1517 assert!(!scaler.overflow_detected());
1518 }
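
    // Added sketch test: after `scale_window` consecutive overflow-free
    // steps, the scale should grow by `scale_factor`.
    #[test]
    fn test_scale_growth_after_clean_window() {
        let config = MixedPrecisionConfig {
            enabled: true,
            scale_window: 2,
            ..MixedPrecisionConfig::default()
        };
        let mut scaler = LossScaler::new(config);
        let initial_scale = scaler.get_scale();

        scaler.update_scale(false).expect("operation failed in test");
        assert_eq!(scaler.get_scale(), initial_scale);

        scaler.update_scale(false).expect("operation failed in test");
        assert_eq!(scaler.get_scale(), initial_scale * 2.0);
    }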
}