Skip to main content

trustformers_training/
mixed_precision.rs

1use anyhow::Result;
2use scirs2_core::Complex; // SciRS2 Integration Policy
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use trustformers_core::tensor::Tensor;
6
/// Configuration for mixed precision training
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MixedPrecisionConfig {
    /// Enable automatic mixed precision
    pub enabled: bool,
    /// Initial loss scale value
    pub init_scale: f32,
    /// Factor to multiply the loss *scale* by after `scale_window`
    /// consecutive steps without overflow (growth factor, typically 2.0)
    pub scale_factor: f32,
    /// Factor to multiply the loss *scale* by when overflow is detected
    /// (backoff factor, typically 0.5)
    pub backoff_factor: f32,
    /// Number of consecutive steps without overflow before increasing scale
    pub scale_window: usize,
    /// Minimum loss scale value (floor applied when backing off)
    pub min_scale: f32,
    /// Maximum loss scale value (cap applied when growing)
    pub max_scale: f32,
    /// Skip optimizer update if loss is inf/nan
    pub skip_inf_nan: bool,
}
27
28impl Default for MixedPrecisionConfig {
29    fn default() -> Self {
30        Self {
31            enabled: false,
32            init_scale: 2f32.powf(16.0), // 65536
33            scale_factor: 2.0,
34            backoff_factor: 0.5,
35            scale_window: 2000,
36            min_scale: 1.0,
37            max_scale: 2f32.powf(24.0), // 16M
38            skip_inf_nan: true,
39        }
40    }
41}
42
/// Loss scaler for automatic mixed precision training
#[derive(Debug, Clone)]
pub struct LossScaler {
    /// Scaling policy: growth/backoff factors, window, and scale bounds.
    config: MixedPrecisionConfig,
    /// Loss scale currently applied before the backward pass.
    current_scale: f32,
    /// Consecutive steps since the last overflow; reset on overflow and
    /// whenever the scale is grown.
    steps_since_overflow: usize,
    /// Whether the most recent `update_scale` call saw an overflow.
    overflow_detected: bool,
}
51
52impl LossScaler {
53    pub fn new(config: MixedPrecisionConfig) -> Self {
54        Self {
55            current_scale: config.init_scale,
56            steps_since_overflow: 0,
57            overflow_detected: false,
58            config,
59        }
60    }
61
62    /// Get current loss scale
63    pub fn get_scale(&self) -> f32 {
64        if self.config.enabled {
65            self.current_scale
66        } else {
67            1.0
68        }
69    }
70
71    /// Scale the loss tensor for backward pass
72    pub fn scale_loss(&self, loss: &Tensor) -> Result<Tensor> {
73        if !self.config.enabled {
74            return Ok(loss.clone());
75        }
76
77        loss.scalar_mul(self.current_scale).map_err(|e| anyhow::anyhow!(e))
78    }
79
80    /// Unscale gradients after backward pass
81    pub fn unscale_gradients(&self, gradients: &mut HashMap<String, Tensor>) -> Result<bool> {
82        if !self.config.enabled {
83            return Ok(false);
84        }
85
86        let scale = self.current_scale;
87        let mut overflow_detected = false;
88
89        for (_, gradient) in gradients.iter_mut() {
90            // Check for inf/nan values
91            if self.has_inf_nan(gradient)? {
92                overflow_detected = true;
93                if self.config.skip_inf_nan {
94                    break;
95                }
96            }
97
98            // Unscale gradient
99            *gradient = gradient.scalar_mul(1.0 / scale).map_err(|e| anyhow::anyhow!(e))?;
100        }
101
102        Ok(overflow_detected)
103    }
104
105    /// Update loss scale based on overflow detection
106    pub fn update_scale(&mut self, overflow_detected: bool) -> Result<()> {
107        if !self.config.enabled {
108            return Ok(());
109        }
110
111        if overflow_detected {
112            // Decrease scale and reset counter
113            self.current_scale =
114                (self.current_scale * self.config.backoff_factor).max(self.config.min_scale);
115            self.steps_since_overflow = 0;
116            self.overflow_detected = true;
117        } else {
118            // Increase counter and potentially scale up
119            self.steps_since_overflow += 1;
120            self.overflow_detected = false;
121
122            if self.steps_since_overflow >= self.config.scale_window {
123                self.current_scale =
124                    (self.current_scale * self.config.scale_factor).min(self.config.max_scale);
125                self.steps_since_overflow = 0;
126            }
127        }
128
129        Ok(())
130    }
131
132    /// Check if overflow was detected in the last step
133    pub fn overflow_detected(&self) -> bool {
134        self.overflow_detected
135    }
136
137    /// Check for inf/nan values in tensor
138    fn has_inf_nan(&self, tensor: &Tensor) -> Result<bool> {
139        match tensor {
140            Tensor::F32(arr) => {
141                for &value in arr.iter() {
142                    if !value.is_finite() {
143                        return Ok(true);
144                    }
145                }
146                Ok(false)
147            },
148            Tensor::F64(arr) => {
149                for &value in arr.iter() {
150                    if !value.is_finite() {
151                        return Ok(true);
152                    }
153                }
154                Ok(false)
155            },
156            Tensor::F16(arr) => {
157                for &value in arr.iter() {
158                    if !value.to_f32().is_finite() {
159                        return Ok(true);
160                    }
161                }
162                Ok(false)
163            },
164            Tensor::BF16(arr) => {
165                for &value in arr.iter() {
166                    if !value.to_f32().is_finite() {
167                        return Ok(true);
168                    }
169                }
170                Ok(false)
171            },
172            Tensor::I64(_) => Ok(false), // Integer tensors can't have inf/nan
173            Tensor::C32(arr) => {
174                for &value in arr.iter() {
175                    if !value.re.is_finite() || !value.im.is_finite() {
176                        return Ok(true);
177                    }
178                }
179                Ok(false)
180            },
181            Tensor::C64(arr) => {
182                for &value in arr.iter() {
183                    if !value.re.is_finite() || !value.im.is_finite() {
184                        return Ok(true);
185                    }
186                }
187                Ok(false)
188            },
189            Tensor::CF16(arr) => {
190                for &value in arr.iter() {
191                    if !value.re.to_f32().is_finite() || !value.im.to_f32().is_finite() {
192                        return Ok(true);
193                    }
194                }
195                Ok(false)
196            },
197            Tensor::CBF16(arr) => {
198                for &value in arr.iter() {
199                    if !value.re.to_f32().is_finite() || !value.im.to_f32().is_finite() {
200                        return Ok(true);
201                    }
202                }
203                Ok(false)
204            },
205            _ => Ok(false), // Assume sparse and other tensor types are validated
206        }
207    }
208}
209
/// Automatic Mixed Precision (AMP) manager
///
/// Bundles a [`LossScaler`] with its [`MixedPrecisionConfig`] and provides
/// precision-conversion helpers plus scaled forward/backward entry points.
#[derive(Debug)]
pub struct AMPManager {
    /// Loss scaler driving dynamic loss scaling.
    pub loss_scaler: LossScaler,
    /// Mixed precision configuration (a copy is also held by `loss_scaler`).
    pub config: MixedPrecisionConfig,
}
216
217impl AMPManager {
218    pub fn new(config: MixedPrecisionConfig) -> Self {
219        let loss_scaler = LossScaler::new(config.clone());
220        Self {
221            loss_scaler,
222            config,
223        }
224    }
225
226    /// Convert tensor to half precision (fp16 simulation using f32)
227    /// In a real implementation, this would convert to actual fp16
228    pub fn to_half_precision(&self, tensor: &Tensor) -> Result<Tensor> {
229        if !self.config.enabled {
230            return Ok(tensor.clone());
231        }
232
233        // Simulate fp16 precision by quantizing to fp16 range
234        match tensor {
235            Tensor::F32(arr) => {
236                let quantized = arr.mapv(|x| {
237                    // Simulate fp16 precision limitations
238                    let clamped = x.clamp(-65504.0, 65504.0); // fp16 range
239
240                    // Simulate fp16 precision by reducing mantissa bits
241                    // This is a simplified simulation
242
243                    (clamped * 1024.0).round() / 1024.0
244                });
245                Ok(Tensor::F32(quantized))
246            },
247            Tensor::F64(_) => Ok(tensor.clone()),
248            Tensor::F16(_) => Ok(tensor.clone()), // Already fp16 precision
249            Tensor::BF16(_) => Ok(tensor.clone()), // Already reduced precision
250            Tensor::I64(_) => Ok(tensor.clone()),
251            Tensor::C32(arr) => {
252                let quantized = arr.mapv(|x| {
253                    let re_clamped = x.re.clamp(-65504.0, 65504.0);
254                    let im_clamped = x.im.clamp(-65504.0, 65504.0);
255                    let re_scaled = (re_clamped * 1024.0).round() / 1024.0;
256                    let im_scaled = (im_clamped * 1024.0).round() / 1024.0;
257                    Complex::new(re_scaled, im_scaled)
258                });
259                Ok(Tensor::C32(quantized))
260            },
261            Tensor::C64(_) => Ok(tensor.clone()),
262            Tensor::CF16(_) => Ok(tensor.clone()), // Already fp16 precision
263            Tensor::CBF16(_) => Ok(tensor.clone()), // Already reduced precision
264            _ => Ok(tensor.clone()),               // Sparse and other tensor types unchanged
265        }
266    }
267
268    /// Convert tensor back to full precision
269    pub fn to_full_precision(&self, tensor: &Tensor) -> Result<Tensor> {
270        // In fp16 simulation, this is a no-op since we're still using f32
271        Ok(tensor.clone())
272    }
273
274    /// Perform forward pass with automatic mixed precision
275    pub fn forward_with_amp<F>(&self, forward_fn: F) -> Result<Tensor>
276    where
277        F: FnOnce() -> Result<Tensor>,
278    {
279        if !self.config.enabled {
280            return forward_fn();
281        }
282
283        // In a real implementation, this would:
284        // 1. Cast model weights to fp16
285        // 2. Perform forward pass in fp16
286        // 3. Cast output back to fp32 for loss computation
287
288        let output = forward_fn()?;
289        self.to_full_precision(&output)
290    }
291
292    /// Perform backward pass with loss scaling
293    pub fn backward_with_amp(
294        &mut self,
295        loss: &Tensor,
296        gradients: &mut HashMap<String, Tensor>,
297    ) -> Result<bool> {
298        // Scale loss
299        let _scaled_loss = self.loss_scaler.scale_loss(loss)?;
300
301        // Simulate backward pass (in real implementation, this would compute gradients)
302        // For simulation, we assume gradients are already computed and scaled
303
304        // Unscale gradients and check for overflow
305        let overflow = self.loss_scaler.unscale_gradients(gradients)?;
306
307        // Update loss scale
308        self.loss_scaler.update_scale(overflow)?;
309
310        Ok(overflow)
311    }
312
313    /// Get current loss scale
314    pub fn get_loss_scale(&self) -> f32 {
315        self.loss_scaler.get_scale()
316    }
317
318    /// Check if AMP is enabled
319    pub fn is_enabled(&self) -> bool {
320        self.config.enabled
321    }
322}
323
324/// Mixed precision training utilities
325pub mod utils {
326    use super::*;
327
328    /// Create a default AMP configuration for fp16 training
329    pub fn default_fp16_config() -> MixedPrecisionConfig {
330        MixedPrecisionConfig {
331            enabled: true,
332            init_scale: 2f32.powf(16.0),
333            scale_factor: 2.0,
334            backoff_factor: 0.5,
335            scale_window: 2000,
336            min_scale: 1.0,
337            max_scale: 2f32.powf(24.0),
338            skip_inf_nan: true,
339        }
340    }
341
342    /// Create a default AMP configuration for bfloat16 training
343    pub fn default_bf16_config() -> MixedPrecisionConfig {
344        MixedPrecisionConfig {
345            enabled: true,
346            init_scale: 1.0, // bfloat16 doesn't typically need loss scaling
347            scale_factor: 1.0,
348            backoff_factor: 1.0,
349            scale_window: usize::MAX,
350            min_scale: 1.0,
351            max_scale: 1.0,
352            skip_inf_nan: true,
353        }
354    }
355
356    /// Check if tensor values are within fp16 range
357    pub fn is_fp16_safe(tensor: &Tensor) -> Result<bool> {
358        match tensor {
359            Tensor::F32(arr) => {
360                for &value in arr.iter() {
361                    if value.abs() > 65504.0 || (!value.is_finite() && value != 0.0) {
362                        return Ok(false);
363                    }
364                }
365                Ok(true)
366            },
367            Tensor::F64(arr) => {
368                for &value in arr.iter() {
369                    if value.abs() > 65504.0 || (!value.is_finite() && value != 0.0) {
370                        return Ok(false);
371                    }
372                }
373                Ok(true)
374            },
375            Tensor::F16(_) => Ok(true),  // Already fp16, so always safe
376            Tensor::BF16(_) => Ok(true), // BF16 has similar range to fp16
377            Tensor::I64(_) => Ok(true),
378            Tensor::C32(arr) => {
379                for &value in arr.iter() {
380                    if value.re.abs() > 65504.0
381                        || value.im.abs() > 65504.0
382                        || (!value.re.is_finite() && value.re != 0.0)
383                        || (!value.im.is_finite() && value.im != 0.0)
384                    {
385                        return Ok(false);
386                    }
387                }
388                Ok(true)
389            },
390            Tensor::C64(arr) => {
391                for &value in arr.iter() {
392                    if value.re.abs() > 65504.0
393                        || value.im.abs() > 65504.0
394                        || (!value.re.is_finite() && value.re != 0.0)
395                        || (!value.im.is_finite() && value.im != 0.0)
396                    {
397                        return Ok(false);
398                    }
399                }
400                Ok(true)
401            },
402            Tensor::CF16(_) => Ok(true),  // Already fp16, so always safe
403            Tensor::CBF16(_) => Ok(true), // BF16 has similar range to fp16
404            _ => Ok(true),                // Assume sparse and other tensor types are safe
405        }
406    }
407
408    /// Calculate the dynamic range of a tensor
409    pub fn calculate_dynamic_range(tensor: &Tensor) -> Result<(f32, f32)> {
410        match tensor {
411            Tensor::F32(arr) => {
412                let mut min_val = f32::INFINITY;
413                let mut max_val = f32::NEG_INFINITY;
414
415                for &value in arr.iter() {
416                    if value.is_finite() {
417                        min_val = min_val.min(value);
418                        max_val = max_val.max(value);
419                    }
420                }
421
422                Ok((min_val, max_val))
423            },
424            Tensor::F64(arr) => {
425                let min_val = arr
426                    .iter()
427                    .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
428                    .copied()
429                    .unwrap_or(0.0) as f32;
430                let max_val = arr
431                    .iter()
432                    .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
433                    .copied()
434                    .unwrap_or(0.0) as f32;
435                Ok((min_val, max_val))
436            },
437            Tensor::F16(arr) => {
438                let mut min_val = f32::INFINITY;
439                let mut max_val = f32::NEG_INFINITY;
440
441                for &value in arr.iter() {
442                    let f32_val = value.to_f32();
443                    if f32_val.is_finite() {
444                        min_val = min_val.min(f32_val);
445                        max_val = max_val.max(f32_val);
446                    }
447                }
448
449                Ok((min_val, max_val))
450            },
451            Tensor::BF16(arr) => {
452                let mut min_val = f32::INFINITY;
453                let mut max_val = f32::NEG_INFINITY;
454
455                for &value in arr.iter() {
456                    let f32_val = value.to_f32();
457                    if f32_val.is_finite() {
458                        min_val = min_val.min(f32_val);
459                        max_val = max_val.max(f32_val);
460                    }
461                }
462
463                Ok((min_val, max_val))
464            },
465            Tensor::I64(arr) => {
466                let min_val = arr.iter().min().copied().unwrap_or(0) as f32;
467                let max_val = arr.iter().max().copied().unwrap_or(0) as f32;
468                Ok((min_val, max_val))
469            },
470            Tensor::C32(arr) => {
471                let mut min_val = f32::INFINITY;
472                let mut max_val = f32::NEG_INFINITY;
473
474                for &value in arr.iter() {
475                    let magnitude = value.norm();
476                    if magnitude.is_finite() {
477                        min_val = min_val.min(magnitude);
478                        max_val = max_val.max(magnitude);
479                    }
480                }
481
482                Ok((min_val, max_val))
483            },
484            Tensor::C64(arr) => {
485                let mut min_val = f32::INFINITY;
486                let mut max_val = f32::NEG_INFINITY;
487
488                for &value in arr.iter() {
489                    let magnitude = value.norm() as f32;
490                    if magnitude.is_finite() {
491                        min_val = min_val.min(magnitude);
492                        max_val = max_val.max(magnitude);
493                    }
494                }
495
496                Ok((min_val, max_val))
497            },
498            Tensor::CF16(arr) => {
499                let mut min_val = f32::INFINITY;
500                let mut max_val = f32::NEG_INFINITY;
501
502                for &value in arr.iter() {
503                    let magnitude = (value.re.to_f32().powi(2) + value.im.to_f32().powi(2)).sqrt();
504                    if magnitude.is_finite() {
505                        min_val = min_val.min(magnitude);
506                        max_val = max_val.max(magnitude);
507                    }
508                }
509
510                Ok((min_val, max_val))
511            },
512            Tensor::CBF16(arr) => {
513                let mut min_val = f32::INFINITY;
514                let mut max_val = f32::NEG_INFINITY;
515
516                for &value in arr.iter() {
517                    let magnitude = (value.re.to_f32().powi(2) + value.im.to_f32().powi(2)).sqrt();
518                    if magnitude.is_finite() {
519                        min_val = min_val.min(magnitude);
520                        max_val = max_val.max(magnitude);
521                    }
522                }
523
524                Ok((min_val, max_val))
525            },
526            _ => Ok((0.0, 1.0)), // Default range for sparse and other tensor types
527        }
528    }
529}
530
/// Advanced mixed precision training enhancements
///
/// Extends the base config with adaptive behavior: dynamic loss scaling,
/// per-layer gradient scaling, and automatic precision selection.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdvancedMixedPrecisionConfig {
    /// Base mixed precision config
    pub base_config: MixedPrecisionConfig,
    /// Enable dynamic loss scaling
    pub enable_dynamic_scaling: bool,
    /// Enable gradient scaling per layer
    pub enable_per_layer_scaling: bool,
    /// Enable automatic precision selection
    pub enable_auto_precision: bool,
    /// Minimum precision ("fp16", "bf16", "fp32")
    /// NOTE(review): not read by the visible manager code — confirm usage.
    pub min_precision: String,
    /// Maximum precision ("fp16", "bf16", "fp32")
    /// NOTE(review): not read by the visible manager code — confirm usage.
    pub max_precision: String,
    /// Precision adaptation rate
    /// NOTE(review): not read by the visible manager code — confirm usage.
    pub precision_adaptation_rate: f32,
    /// Memory usage threshold above which precision is lowered
    pub memory_threshold: f32,
    /// Performance threshold below which precision is raised
    pub performance_threshold: f32,
}
553
554impl Default for AdvancedMixedPrecisionConfig {
555    fn default() -> Self {
556        Self {
557            base_config: MixedPrecisionConfig::default(),
558            enable_dynamic_scaling: true,
559            enable_per_layer_scaling: true,
560            enable_auto_precision: true,
561            min_precision: "fp16".to_string(),
562            max_precision: "fp32".to_string(),
563            precision_adaptation_rate: 0.1,
564            memory_threshold: 0.8,
565            performance_threshold: 0.9,
566        }
567    }
568}
569
/// Per-layer scaling configuration
#[derive(Debug, Clone)]
pub struct LayerScalingConfig {
    /// Layer name (also the key under which this entry is stored).
    pub layer_name: String,
    /// Current scale factor applied to this layer's gradients.
    pub scale_factor: f32,
    /// Recent gradient norms (bounded sliding window).
    pub gradient_norm_history: Vec<f32>,
    /// Loss history for this layer.
    /// NOTE(review): never written in the visible code — confirm it is used.
    pub loss_history: Vec<f32>,
    /// Number of steps on which this layer's gradient overflowed.
    pub overflow_count: usize,
    /// Underflow count.
    /// NOTE(review): `scale_gradients_per_layer` increments this on every
    /// overflow-free step, so it effectively counts non-overflow steps
    /// rather than true underflows.
    pub underflow_count: usize,
}
586
587impl LayerScalingConfig {
588    pub fn new(layer_name: String) -> Self {
589        Self {
590            layer_name,
591            scale_factor: 1.0,
592            gradient_norm_history: Vec::new(),
593            loss_history: Vec::new(),
594            overflow_count: 0,
595            underflow_count: 0,
596        }
597    }
598}
599
/// Advanced mixed precision manager
///
/// Wraps an [`AMPManager`] and adds per-layer scaling state plus automatic
/// precision switching driven by memory/performance measurements.
#[derive(Debug)]
pub struct AdvancedMixedPrecisionManager {
    /// Adaptive-behavior configuration.
    config: AdvancedMixedPrecisionConfig,
    /// Underlying AMP manager; reconfigured when precision switches.
    base_manager: AMPManager,
    /// Per-layer scaling state, keyed by layer name.
    layer_configs: HashMap<String, LayerScalingConfig>,
    /// Active precision label: "fp16", "bf16", or "fp32".
    current_precision: String,
    /// Log of precision switches as `(step_count, new_precision)`.
    precision_history: Vec<(usize, String)>,
    /// Recent memory-usage samples (bounded sliding window).
    memory_usage_history: Vec<f32>,
    /// Recent performance scores (bounded sliding window).
    performance_history: Vec<f32>,
    /// Number of `update_step` calls observed so far.
    step_count: usize,
}
612
613impl AdvancedMixedPrecisionManager {
614    pub fn new(config: AdvancedMixedPrecisionConfig) -> Self {
615        let base_manager = AMPManager::new(config.base_config.clone());
616        Self {
617            config,
618            base_manager,
619            layer_configs: HashMap::new(),
620            current_precision: "fp32".to_string(),
621            precision_history: Vec::new(),
622            memory_usage_history: Vec::new(),
623            performance_history: Vec::new(),
624            step_count: 0,
625        }
626    }
627
628    /// Update with training step information
629    pub fn update_step(&mut self, memory_usage: f32, performance_score: f32) {
630        self.step_count += 1;
631        self.memory_usage_history.push(memory_usage);
632        self.performance_history.push(performance_score);
633
634        // Keep only recent history
635        if self.memory_usage_history.len() > 100 {
636            self.memory_usage_history.remove(0);
637            self.performance_history.remove(0);
638        }
639
640        // Adapt precision if enabled
641        if self.config.enable_auto_precision {
642            self.adapt_precision();
643        }
644    }
645
646    /// Adapt precision based on memory usage and performance
647    fn adapt_precision(&mut self) {
648        let avg_memory =
649            self.memory_usage_history.iter().sum::<f32>() / self.memory_usage_history.len() as f32;
650        let avg_performance =
651            self.performance_history.iter().sum::<f32>() / self.performance_history.len() as f32;
652
653        let target_precision = if avg_memory > self.config.memory_threshold {
654            // High memory usage, switch to lower precision
655            match self.current_precision.as_str() {
656                "fp32" => "fp16",
657                "bf16" => "fp16",
658                _ => "fp16",
659            }
660        } else if avg_performance < self.config.performance_threshold {
661            // Low performance, switch to higher precision
662            match self.current_precision.as_str() {
663                "fp16" => "bf16",
664                "bf16" => "fp32",
665                _ => "fp32",
666            }
667        } else {
668            &self.current_precision
669        };
670
671        if target_precision != self.current_precision {
672            self.switch_precision(target_precision.to_string());
673        }
674    }
675
676    /// Switch to a different precision
677    fn switch_precision(&mut self, new_precision: String) {
678        self.current_precision = new_precision.clone();
679        self.precision_history.push((self.step_count, new_precision));
680
681        // Update base manager configuration
682        match self.current_precision.as_str() {
683            "fp16" => {
684                self.base_manager.config = utils::default_fp16_config();
685            },
686            "bf16" => {
687                self.base_manager.config = utils::default_bf16_config();
688            },
689            "fp32" => {
690                self.base_manager.config = MixedPrecisionConfig {
691                    enabled: false,
692                    ..Default::default()
693                };
694            },
695            _ => {
696                self.base_manager.config = utils::default_fp16_config();
697            },
698        }
699    }
700
701    /// Scale gradients with per-layer scaling
702    pub fn scale_gradients_per_layer(
703        &mut self,
704        gradients: &mut HashMap<String, Tensor>,
705    ) -> Result<bool> {
706        let mut global_overflow = false;
707
708        for (layer_name, gradient) in gradients.iter_mut() {
709            // Get or create layer config
710            if !self.layer_configs.contains_key(layer_name) {
711                self.layer_configs.insert(
712                    layer_name.clone(),
713                    LayerScalingConfig::new(layer_name.clone()),
714                );
715            }
716
717            // Compute gradient norm first
718            let grad_norm = self.compute_gradient_norm(gradient)?;
719
720            // Get scale factor before mutably borrowing
721            let enable_per_layer_scaling = self.config.enable_per_layer_scaling;
722
723            let layer_config = self
724                .layer_configs
725                .get_mut(layer_name)
726                .expect("layer config should exist after insertion at line 710-715");
727            layer_config.gradient_norm_history.push(grad_norm);
728
729            // Keep only recent history
730            if layer_config.gradient_norm_history.len() > 50 {
731                layer_config.gradient_norm_history.remove(0);
732            }
733
734            // Adapt layer-specific scaling
735            if enable_per_layer_scaling {
736                // Inline the adaptation logic to avoid borrow issues
737                if layer_config.gradient_norm_history.len() >= 5 {
738                    let recent_norms: Vec<f32> =
739                        layer_config.gradient_norm_history.iter().rev().take(5).cloned().collect();
740
741                    let avg_norm = recent_norms.iter().sum::<f32>() / recent_norms.len() as f32;
742
743                    // Adjust scaling based on gradient norm
744                    if avg_norm > 10.0 {
745                        // Large gradients, increase scaling
746                        layer_config.scale_factor *= 1.1;
747                    } else if avg_norm < 0.01 {
748                        // Small gradients, decrease scaling
749                        layer_config.scale_factor *= 0.9;
750                    }
751
752                    // Clamp scale factor to reasonable range
753                    layer_config.scale_factor = layer_config.scale_factor.clamp(0.01, 1000.0);
754                }
755            }
756
757            let scale_factor = layer_config.scale_factor;
758
759            // Drop the mutable borrow before calling scale_tensor
760            let _ = layer_config;
761
762            // Scale gradient using extracted scale factor
763            *gradient = self.scale_tensor(gradient, scale_factor)?;
764
765            // Check for overflow and update counters
766            let has_overflow = self.has_overflow(gradient)?;
767
768            // Re-acquire mutable borrow to update counters
769            let layer_config = self
770                .layer_configs
771                .get_mut(layer_name)
772                .expect("layer config should exist after insertion at line 710-715");
773            if has_overflow {
774                layer_config.overflow_count += 1;
775                global_overflow = true;
776            } else {
777                layer_config.underflow_count += 1;
778            }
779        }
780
781        Ok(global_overflow)
782    }
783
784    /// Adapt scaling for a specific layer
785    #[allow(dead_code)]
786    fn adapt_layer_scaling(&mut self, layer_config: &mut LayerScalingConfig) {
787        if layer_config.gradient_norm_history.len() < 5 {
788            return;
789        }
790
791        let recent_norms: Vec<f32> =
792            layer_config.gradient_norm_history.iter().rev().take(5).cloned().collect();
793
794        let avg_norm = recent_norms.iter().sum::<f32>() / recent_norms.len() as f32;
795
796        // Adjust scaling based on gradient norm
797        if avg_norm > 10.0 {
798            // Large gradients, increase scaling
799            layer_config.scale_factor *= 1.1;
800        } else if avg_norm < 0.01 {
801            // Small gradients, decrease scaling
802            layer_config.scale_factor *= 0.9;
803        }
804
805        // Clamp scaling factor
806        layer_config.scale_factor = layer_config.scale_factor.clamp(0.1, 10.0);
807    }
808
809    /// Compute gradient norm
810    fn compute_gradient_norm(&self, gradient: &Tensor) -> Result<f32> {
811        match gradient {
812            Tensor::F32(arr) => {
813                let norm = arr.iter().map(|&x| x * x).sum::<f32>().sqrt();
814                Ok(norm)
815            },
816            Tensor::F64(arr) => {
817                let norm = arr.iter().map(|&x| x * x).sum::<f64>().sqrt() as f32;
818                Ok(norm)
819            },
820            Tensor::F16(arr) => {
821                let norm = arr
822                    .iter()
823                    .map(|&x| {
824                        let f32_val = x.to_f32();
825                        f32_val * f32_val
826                    })
827                    .sum::<f32>()
828                    .sqrt();
829                Ok(norm)
830            },
831            Tensor::BF16(arr) => {
832                let norm = arr
833                    .iter()
834                    .map(|&x| {
835                        let f32_val = x.to_f32();
836                        f32_val * f32_val
837                    })
838                    .sum::<f32>()
839                    .sqrt();
840                Ok(norm)
841            },
842            Tensor::I64(_) => Ok(0.0),
843            Tensor::C32(arr) => {
844                let norm = arr.iter().map(|&x| x.norm_sqr()).sum::<f32>().sqrt();
845                Ok(norm)
846            },
847            Tensor::C64(arr) => {
848                let norm = arr.iter().map(|&x| x.norm_sqr() as f32).sum::<f32>().sqrt();
849                Ok(norm)
850            },
851            Tensor::CF16(arr) => {
852                let norm = arr
853                    .iter()
854                    .map(|&x| {
855                        let re = x.re.to_f32();
856                        let im = x.im.to_f32();
857                        re * re + im * im
858                    })
859                    .sum::<f32>()
860                    .sqrt();
861                Ok(norm)
862            },
863            Tensor::CBF16(arr) => {
864                let norm = arr
865                    .iter()
866                    .map(|&x| {
867                        let re = x.re.to_f32();
868                        let im = x.im.to_f32();
869                        re * re + im * im
870                    })
871                    .sum::<f32>()
872                    .sqrt();
873                Ok(norm)
874            },
875            _ => Ok(1.0), // Default norm for sparse and other tensor types
876        }
877    }
878
879    /// Scale tensor by factor
880    fn scale_tensor(&self, tensor: &Tensor, factor: f32) -> Result<Tensor> {
881        match tensor {
882            Tensor::F32(arr) => {
883                let scaled = arr.mapv(|x| x * factor);
884                Ok(Tensor::F32(scaled))
885            },
886            Tensor::F64(arr) => {
887                let scaled = arr.mapv(|x| x * factor as f64);
888                Ok(Tensor::F64(scaled))
889            },
890            Tensor::F16(arr) => {
891                let factor_f16 = half::f16::from_f32(factor);
892                let scaled = arr.mapv(|x| x * factor_f16);
893                Ok(Tensor::F16(scaled))
894            },
895            Tensor::BF16(arr) => {
896                let factor_bf16 = half::bf16::from_f32(factor);
897                let scaled = arr.mapv(|x| x * factor_bf16);
898                Ok(Tensor::BF16(scaled))
899            },
900            Tensor::I64(arr) => Ok(Tensor::I64(arr.clone())),
901            Tensor::C32(arr) => {
902                let scaled = arr.mapv(|x| x * factor);
903                Ok(Tensor::C32(scaled))
904            },
905            Tensor::C64(arr) => {
906                let scaled = arr.mapv(|x| x * factor as f64);
907                Ok(Tensor::C64(scaled))
908            },
909            Tensor::CF16(arr) => {
910                let factor_f16 = half::f16::from_f32(factor);
911                let scaled = arr.mapv(|x| Complex::new(x.re * factor_f16, x.im * factor_f16));
912                Ok(Tensor::CF16(scaled))
913            },
914            Tensor::CBF16(arr) => {
915                let factor_bf16 = half::bf16::from_f32(factor);
916                let scaled = arr.mapv(|x| Complex::new(x.re * factor_bf16, x.im * factor_bf16));
917                Ok(Tensor::CBF16(scaled))
918            },
919            _ => Ok(tensor.clone()), // Don't scale sparse and other tensor types
920        }
921    }
922
923    /// Check for overflow in tensor
924    fn has_overflow(&self, tensor: &Tensor) -> Result<bool> {
925        match tensor {
926            Tensor::F32(arr) => {
927                for &value in arr.iter() {
928                    if !value.is_finite() {
929                        return Ok(true);
930                    }
931                }
932                Ok(false)
933            },
934            Tensor::F64(arr) => {
935                for &value in arr.iter() {
936                    if !value.is_finite() {
937                        return Ok(true);
938                    }
939                }
940                Ok(false)
941            },
942            Tensor::F16(arr) => {
943                for &value in arr.iter() {
944                    if !value.to_f32().is_finite() {
945                        return Ok(true);
946                    }
947                }
948                Ok(false)
949            },
950            Tensor::BF16(arr) => {
951                for &value in arr.iter() {
952                    if !value.to_f32().is_finite() {
953                        return Ok(true);
954                    }
955                }
956                Ok(false)
957            },
958            Tensor::I64(_) => Ok(false),
959            Tensor::C32(arr) => {
960                for &value in arr.iter() {
961                    if !value.re.is_finite() || !value.im.is_finite() {
962                        return Ok(true);
963                    }
964                }
965                Ok(false)
966            },
967            Tensor::C64(arr) => {
968                for &value in arr.iter() {
969                    if !value.re.is_finite() || !value.im.is_finite() {
970                        return Ok(true);
971                    }
972                }
973                Ok(false)
974            },
975            Tensor::CF16(arr) => {
976                for &value in arr.iter() {
977                    if !value.re.to_f32().is_finite() || !value.im.to_f32().is_finite() {
978                        return Ok(true);
979                    }
980                }
981                Ok(false)
982            },
983            Tensor::CBF16(arr) => {
984                for &value in arr.iter() {
985                    if !value.re.to_f32().is_finite() || !value.im.to_f32().is_finite() {
986                        return Ok(true);
987                    }
988                }
989                Ok(false)
990            },
991            _ => Ok(false), // Assume sparse and other tensor types don't overflow
992        }
993    }
994
995    /// Get current precision
996    pub fn get_current_precision(&self) -> &str {
997        &self.current_precision
998    }
999
1000    /// Get precision history
1001    pub fn get_precision_history(&self) -> &[(usize, String)] {
1002        &self.precision_history
1003    }
1004
    /// Get layer configurations
    ///
    /// Returns the per-layer scaling configuration map (presumably keyed by
    /// layer name — confirm against the code that populates it).
    pub fn get_layer_configs(&self) -> &HashMap<String, LayerScalingConfig> {
        &self.layer_configs
    }
1009
1010    /// Forward pass with advanced mixed precision
1011    pub fn forward_with_advanced_amp<F>(&mut self, forward_fn: F) -> Result<Tensor>
1012    where
1013        F: FnOnce() -> Result<Tensor>,
1014    {
1015        let output = forward_fn()?;
1016
1017        // Apply precision-specific optimizations
1018        match self.current_precision.as_str() {
1019            "fp16" => self.optimize_for_fp16(&output),
1020            "bf16" => self.optimize_for_bf16(&output),
1021            "fp32" => Ok(output),
1022            _ => Ok(output),
1023        }
1024    }
1025
1026    /// Optimize tensor for fp16
1027    fn optimize_for_fp16(&self, tensor: &Tensor) -> Result<Tensor> {
1028        match tensor {
1029            Tensor::F32(arr) => {
1030                let optimized = arr.mapv(|x| {
1031                    // Apply fp16 optimizations
1032                    let clamped = x.clamp(-65504.0, 65504.0);
1033
1034                    (clamped * 1024.0).round() / 1024.0
1035                });
1036                Ok(Tensor::F32(optimized))
1037            },
1038            _ => Ok(tensor.clone()),
1039        }
1040    }
1041
1042    /// Optimize tensor for bf16
1043    fn optimize_for_bf16(&self, tensor: &Tensor) -> Result<Tensor> {
1044        match tensor {
1045            Tensor::F32(arr) => {
1046                let optimized = arr.mapv(|x| {
1047                    // Apply bf16 optimizations (wider range, lower precision)
1048
1049                    (x * 128.0).round() / 128.0
1050                });
1051                Ok(Tensor::F32(optimized))
1052            },
1053            _ => Ok(tensor.clone()),
1054        }
1055    }
1056
1057    /// Generate mixed precision report
1058    pub fn generate_report(&self) -> MixedPrecisionReport {
1059        let total_overflows = self.layer_configs.values().map(|config| config.overflow_count).sum();
1060
1061        let total_underflows =
1062            self.layer_configs.values().map(|config| config.underflow_count).sum();
1063
1064        let avg_memory_usage = if !self.memory_usage_history.is_empty() {
1065            self.memory_usage_history.iter().sum::<f32>() / self.memory_usage_history.len() as f32
1066        } else {
1067            0.0
1068        };
1069
1070        let avg_performance = if !self.performance_history.is_empty() {
1071            self.performance_history.iter().sum::<f32>() / self.performance_history.len() as f32
1072        } else {
1073            0.0
1074        };
1075
1076        MixedPrecisionReport {
1077            current_precision: self.current_precision.clone(),
1078            step_count: self.step_count,
1079            total_overflows,
1080            total_underflows,
1081            avg_memory_usage,
1082            avg_performance,
1083            precision_switches: self.precision_history.len(),
1084            layer_count: self.layer_configs.len(),
1085            recommendations: self.generate_recommendations(),
1086        }
1087    }
1088
1089    /// Generate recommendations for mixed precision training
1090    fn generate_recommendations(&self) -> Vec<String> {
1091        let mut recommendations = Vec::new();
1092
1093        // Check overflow rates
1094        let total_overflows =
1095            self.layer_configs.values().map(|config| config.overflow_count).sum::<usize>();
1096
1097        if total_overflows > self.step_count / 10 {
1098            recommendations
1099                .push("High overflow rate detected - consider reducing learning rate".to_string());
1100        }
1101
1102        // Check memory usage
1103        let avg_memory = if !self.memory_usage_history.is_empty() {
1104            self.memory_usage_history.iter().sum::<f32>() / self.memory_usage_history.len() as f32
1105        } else {
1106            0.0
1107        };
1108
1109        if avg_memory > 0.9 {
1110            recommendations
1111                .push("High memory usage - consider using fp16 or reducing batch size".to_string());
1112        }
1113
1114        // Check performance
1115        let avg_performance = if !self.performance_history.is_empty() {
1116            self.performance_history.iter().sum::<f32>() / self.performance_history.len() as f32
1117        } else {
1118            0.0
1119        };
1120
1121        if avg_performance < 0.5 {
1122            recommendations.push(
1123                "Low performance - consider using higher precision or adjusting hyperparameters"
1124                    .to_string(),
1125            );
1126        }
1127
1128        // Check precision switches
1129        if self.precision_history.len() > 10 {
1130            recommendations.push(
1131                "Frequent precision switches - consider adjusting adaptation thresholds"
1132                    .to_string(),
1133            );
1134        }
1135
1136        if recommendations.is_empty() {
1137            recommendations.push("Mixed precision training is working well".to_string());
1138        }
1139
1140        recommendations
1141    }
1142}
1143
/// Mixed precision training report
///
/// Aggregated snapshot of mixed precision training state, produced by the
/// manager's `generate_report`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MixedPrecisionReport {
    /// Precision mode in effect when the report was generated.
    pub current_precision: String,
    /// Number of training steps observed so far.
    pub step_count: usize,
    /// Sum of overflow counts across all layer configurations.
    pub total_overflows: usize,
    /// Sum of underflow counts across all layer configurations.
    pub total_underflows: usize,
    /// Mean of the recorded memory-usage samples (0.0 when none recorded).
    pub avg_memory_usage: f32,
    /// Mean of the recorded performance samples (0.0 when none recorded).
    pub avg_performance: f32,
    /// Number of precision-switch events recorded.
    pub precision_switches: usize,
    /// Number of layers with an individual scaling configuration.
    pub layer_count: usize,
    /// Human-readable tuning suggestions; always contains at least one entry.
    pub recommendations: Vec<String>,
}
1157
/// Dynamic batching strategy
///
/// Tuning knobs for `DynamicBatchingManager`. The thresholds appear to be
/// fractions in [0, 1] (defaults 0.85 / 0.9) — confirm against callers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DynamicBatchingConfig {
    /// Initial batch size
    pub initial_batch_size: usize,
    /// Maximum batch size
    pub max_batch_size: usize,
    /// Minimum batch size
    pub min_batch_size: usize,
    /// Batch size adaptation rate (fraction of the current batch size
    /// added or removed per adjustment)
    pub adaptation_rate: f32,
    /// Memory threshold for batch size reduction
    pub memory_threshold: f32,
    /// Performance threshold for batch size increase
    pub performance_threshold: f32,
}
1174
impl Default for DynamicBatchingConfig {
    /// Conservative defaults: batches adapt within [8, 128] starting at 32,
    /// moving 10% of the current size per adjustment; shrink above 0.85
    /// memory usage, grow above a 0.9 performance score.
    fn default() -> Self {
        Self {
            initial_batch_size: 32,
            max_batch_size: 128,
            min_batch_size: 8,
            adaptation_rate: 0.1,
            memory_threshold: 0.85,
            performance_threshold: 0.9,
        }
    }
}
1187
/// Dynamic batching manager
///
/// Tracks rolling memory-usage and performance measurements and adapts the
/// training batch size within the configured [min, max] bounds.
#[derive(Debug)]
pub struct DynamicBatchingManager {
    config: DynamicBatchingConfig,
    // Batch size currently in effect.
    current_batch_size: usize,
    // (step, new_batch_size) pairs, recorded whenever the size changes.
    batch_size_history: Vec<(usize, usize)>,
    // Rolling window of recent memory-usage samples.
    memory_usage_history: Vec<f32>,
    // Rolling window of recent performance scores, kept in lockstep with
    // `memory_usage_history`.
    performance_history: Vec<f32>,
    // Number of `update_step` calls seen so far.
    step_count: usize,
}
1198
1199impl DynamicBatchingManager {
1200    pub fn new(config: DynamicBatchingConfig) -> Self {
1201        Self {
1202            current_batch_size: config.initial_batch_size,
1203            config,
1204            batch_size_history: Vec::new(),
1205            memory_usage_history: Vec::new(),
1206            performance_history: Vec::new(),
1207            step_count: 0,
1208        }
1209    }
1210
1211    /// Update with training step information
1212    pub fn update_step(&mut self, memory_usage: f32, performance_score: f32) {
1213        self.step_count += 1;
1214        self.memory_usage_history.push(memory_usage);
1215        self.performance_history.push(performance_score);
1216
1217        // Keep only recent history
1218        if self.memory_usage_history.len() > 50 {
1219            self.memory_usage_history.remove(0);
1220            self.performance_history.remove(0);
1221        }
1222
1223        // Adapt batch size
1224        self.adapt_batch_size();
1225    }
1226
1227    /// Adapt batch size based on memory usage and performance
1228    fn adapt_batch_size(&mut self) {
1229        let avg_memory =
1230            self.memory_usage_history.iter().sum::<f32>() / self.memory_usage_history.len() as f32;
1231        let avg_performance =
1232            self.performance_history.iter().sum::<f32>() / self.performance_history.len() as f32;
1233
1234        let old_batch_size = self.current_batch_size;
1235
1236        if avg_memory > self.config.memory_threshold {
1237            // Reduce batch size
1238            let reduction = (self.current_batch_size as f32 * self.config.adaptation_rate) as usize;
1239            self.current_batch_size =
1240                (self.current_batch_size - reduction).max(self.config.min_batch_size);
1241        } else if avg_performance > self.config.performance_threshold {
1242            // Increase batch size
1243            let increase = (self.current_batch_size as f32 * self.config.adaptation_rate) as usize;
1244            self.current_batch_size =
1245                (self.current_batch_size + increase).min(self.config.max_batch_size);
1246        }
1247
1248        if self.current_batch_size != old_batch_size {
1249            self.batch_size_history.push((self.step_count, self.current_batch_size));
1250        }
1251    }
1252
1253    /// Get current batch size
1254    pub fn get_current_batch_size(&self) -> usize {
1255        self.current_batch_size
1256    }
1257
1258    /// Get batch size history
1259    pub fn get_batch_size_history(&self) -> &[(usize, usize)] {
1260        &self.batch_size_history
1261    }
1262
1263    /// Generate batching report
1264    pub fn generate_report(&self) -> DynamicBatchingReport {
1265        let avg_memory = if !self.memory_usage_history.is_empty() {
1266            self.memory_usage_history.iter().sum::<f32>() / self.memory_usage_history.len() as f32
1267        } else {
1268            0.0
1269        };
1270
1271        let avg_performance = if !self.performance_history.is_empty() {
1272            self.performance_history.iter().sum::<f32>() / self.performance_history.len() as f32
1273        } else {
1274            0.0
1275        };
1276
1277        DynamicBatchingReport {
1278            current_batch_size: self.current_batch_size,
1279            step_count: self.step_count,
1280            avg_memory_usage: avg_memory,
1281            avg_performance,
1282            batch_size_changes: self.batch_size_history.len(),
1283            memory_efficiency: 1.0 - avg_memory,
1284            performance_score: avg_performance,
1285        }
1286    }
1287}
1288
/// Dynamic batching report
///
/// Snapshot produced by `DynamicBatchingManager::generate_report`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DynamicBatchingReport {
    /// Batch size in effect when the report was generated.
    pub current_batch_size: usize,
    /// Number of training steps observed so far.
    pub step_count: usize,
    /// Mean of the recorded memory-usage samples (0.0 when none recorded).
    pub avg_memory_usage: f32,
    /// Mean of the recorded performance samples (0.0 when none recorded).
    pub avg_performance: f32,
    /// Number of times the batch size changed.
    pub batch_size_changes: usize,
    /// 1.0 minus the average memory usage.
    pub memory_efficiency: f32,
    /// Same value as `avg_performance`.
    pub performance_score: f32,
}
1300
/// Compute optimization manager
///
/// Facade combining advanced mixed precision management with dynamic
/// batching, plus two feature flags (both enabled at construction).
#[derive(Debug)]
pub struct ComputeOptimizationManager {
    mixed_precision_manager: AdvancedMixedPrecisionManager,
    dynamic_batching_manager: DynamicBatchingManager,
    // Set to true in `new`; not toggled anywhere in this file.
    kernel_fusion_enabled: bool,
    // Set to true in `new`; not toggled anywhere in this file.
    pipeline_optimization_enabled: bool,
}
1309
1310impl ComputeOptimizationManager {
1311    pub fn new(
1312        mixed_precision_config: AdvancedMixedPrecisionConfig,
1313        dynamic_batching_config: DynamicBatchingConfig,
1314    ) -> Self {
1315        Self {
1316            mixed_precision_manager: AdvancedMixedPrecisionManager::new(mixed_precision_config),
1317            dynamic_batching_manager: DynamicBatchingManager::new(dynamic_batching_config),
1318            kernel_fusion_enabled: true,
1319            pipeline_optimization_enabled: true,
1320        }
1321    }
1322
1323    /// Update with training step information
1324    pub fn update_step(&mut self, memory_usage: f32, performance_score: f32) {
1325        self.mixed_precision_manager.update_step(memory_usage, performance_score);
1326        self.dynamic_batching_manager.update_step(memory_usage, performance_score);
1327    }
1328
1329    /// Get current batch size
1330    pub fn get_current_batch_size(&self) -> usize {
1331        self.dynamic_batching_manager.get_current_batch_size()
1332    }
1333
1334    /// Get current precision
1335    pub fn get_current_precision(&self) -> &str {
1336        self.mixed_precision_manager.get_current_precision()
1337    }
1338
1339    /// Generate comprehensive optimization report
1340    pub fn generate_report(&self) -> ComputeOptimizationReport {
1341        ComputeOptimizationReport {
1342            mixed_precision_report: self.mixed_precision_manager.generate_report(),
1343            dynamic_batching_report: self.dynamic_batching_manager.generate_report(),
1344            kernel_fusion_enabled: self.kernel_fusion_enabled,
1345            pipeline_optimization_enabled: self.pipeline_optimization_enabled,
1346        }
1347    }
1348}
1349
/// Comprehensive compute optimization report
///
/// Combined snapshot produced by `ComputeOptimizationManager::generate_report`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComputeOptimizationReport {
    /// Report from the advanced mixed precision manager.
    pub mixed_precision_report: MixedPrecisionReport,
    /// Report from the dynamic batching manager.
    pub dynamic_batching_report: DynamicBatchingReport,
    /// Whether kernel fusion was enabled on the manager.
    pub kernel_fusion_enabled: bool,
    /// Whether pipeline optimization was enabled on the manager.
    pub pipeline_optimization_enabled: bool,
}
1358
/// Unit tests for mixed precision configuration, loss scaling, AMP
/// management, and the fp16 utility helpers.
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    // Defaults: AMP disabled, scale 2^16, grow x2 / back off x0.5, 2000-step
    // growth window.
    #[test]
    fn test_mixed_precision_config_default() {
        let config = MixedPrecisionConfig::default();
        assert!(!config.enabled);
        assert_eq!(config.init_scale, 65536.0);
        assert_eq!(config.scale_factor, 2.0);
        assert_eq!(config.backoff_factor, 0.5);
        assert_eq!(config.scale_window, 2000);
    }

    // With AMP disabled the scaler reports a neutral scale of 1.0.
    #[test]
    fn test_loss_scaler_creation() {
        let config = MixedPrecisionConfig::default();
        let scaler = LossScaler::new(config);
        assert_eq!(scaler.get_scale(), 1.0); // Disabled by default
    }

    // Enabling AMP exposes the configured initial loss scale.
    #[test]
    fn test_loss_scaler_enabled() {
        let config = MixedPrecisionConfig {
            enabled: true,
            ..MixedPrecisionConfig::default()
        };
        let scaler = LossScaler::new(config);
        assert_eq!(scaler.get_scale(), 65536.0);
    }

    // scale_loss multiplies every element by the current loss scale.
    #[test]
    fn test_loss_scaling() {
        let config = MixedPrecisionConfig {
            enabled: true,
            ..MixedPrecisionConfig::default()
        };
        let scaler = LossScaler::new(config);

        let loss = Tensor::ones(&[2, 2]).expect("tensor operation failed");
        let scaled_loss = scaler.scale_loss(&loss).expect("operation failed in test");

        match (&loss, &scaled_loss) {
            (Tensor::F32(orig), Tensor::F32(scaled)) => {
                // Values should be scaled by the loss scale factor
                assert!((scaled[[0, 0]] / orig[[0, 0]] - 65536.0).abs() < 1e-6);
            },
            _ => panic!("Unexpected tensor types"),
        }
    }

    // unscale_gradients divides gradients by the scale in place and reports
    // whether an overflow was seen.
    #[test]
    fn test_gradient_unscaling() {
        let config = MixedPrecisionConfig {
            enabled: true,
            ..MixedPrecisionConfig::default()
        };
        let scaler = LossScaler::new(config);

        let mut gradients = HashMap::new();
        gradients.insert(
            "param1".to_string(),
            Tensor::ones(&[2, 2]).expect("tensor operation failed"),
        );

        let overflow = scaler.unscale_gradients(&mut gradients).expect("operation failed in test");
        assert!(!overflow);

        // Check that gradients are unscaled
        let gradient = gradients.get("param1").expect("expected value not found");
        match gradient {
            Tensor::F32(arr) => {
                assert!((arr[[0, 0]] - 1.0 / 65536.0).abs() < 1e-6);
            },
            _ => panic!("Unexpected tensor type"),
        }
    }

    #[test]
    fn test_amp_manager_creation() {
        let config = MixedPrecisionConfig::default();
        let manager = AMPManager::new(config);
        assert!(!manager.is_enabled());
    }

    #[test]
    fn test_amp_manager_enabled() {
        let config = MixedPrecisionConfig {
            enabled: true,
            ..MixedPrecisionConfig::default()
        };
        let manager = AMPManager::new(config);
        assert!(manager.is_enabled());
        assert_eq!(manager.get_loss_scale(), 65536.0);
    }

    // Round-tripping F32 -> half precision -> F32 should preserve values to
    // within fp16 quantization error.
    #[test]
    fn test_half_precision_conversion() {
        let config = utils::default_fp16_config();
        let manager = AMPManager::new(config);

        let tensor = Tensor::from_vec(vec![1.0, 2.5, -3.7, 1000.0], &[2, 2])
            .expect("tensor operation failed");
        let half_precision = manager.to_half_precision(&tensor).expect("tensor operation failed");
        let full_precision =
            manager.to_full_precision(&half_precision).expect("operation failed in test");

        // Values should be quantized but still reasonable
        match (&tensor, &full_precision) {
            (Tensor::F32(orig), Tensor::F32(converted)) => {
                for (o, c) in orig.iter().zip(converted.iter()) {
                    assert!((o - c).abs() < 0.1); // Some precision loss expected
                }
            },
            _ => panic!("Unexpected tensor types"),
        }
    }

    // 70000.0 exceeds fp16's max finite value (65504), so the tensor is
    // flagged as unsafe for fp16.
    #[test]
    fn test_fp16_safety_check() {
        let safe_tensor =
            Tensor::from_vec(vec![1.0, 2.0, -3.0], &[3]).expect("tensor operation failed");
        assert!(utils::is_fp16_safe(&safe_tensor).expect("tensor operation failed"));

        let unsafe_tensor =
            Tensor::from_vec(vec![1.0, 70000.0, -3.0], &[3]).expect("tensor operation failed");
        assert!(!utils::is_fp16_safe(&unsafe_tensor).expect("tensor operation failed"));
    }

    #[test]
    fn test_dynamic_range_calculation() {
        let tensor =
            Tensor::from_vec(vec![1.0, 5.0, -2.0, 3.0], &[2, 2]).expect("tensor operation failed");
        let (min_val, max_val) =
            utils::calculate_dynamic_range(&tensor).expect("tensor operation failed");
        assert_eq!(min_val, -2.0);
        assert_eq!(max_val, 5.0);
    }

    // An overflow halves the scale (backoff_factor) and sets the overflow
    // flag; a clean step clears the flag.
    #[test]
    fn test_overflow_detection_and_scale_update() {
        let config = MixedPrecisionConfig {
            enabled: true,
            backoff_factor: 0.5,
            ..MixedPrecisionConfig::default()
        };
        let mut scaler = LossScaler::new(config);

        let initial_scale = scaler.get_scale();

        // Simulate overflow
        scaler.update_scale(true).expect("operation failed in test");
        assert_eq!(scaler.get_scale(), initial_scale * 0.5);
        assert!(scaler.overflow_detected());

        // Simulate no overflow
        scaler.update_scale(false).expect("operation failed in test");
        assert!(!scaler.overflow_detected());
    }
}