// tenflowers_core/half_precision.rs

1//! Half precision floating point support
2//!
3//! This module provides support for IEEE 754-2008 half precision (f16) and
4//! Google's brain floating point (bf16) data types for mixed precision training.
5
6pub use half::{bf16, f16};
7
/// Trait abstracting over 16-bit floating point types (`f16`, `bf16`) so
/// mixed precision code can convert to and from `f32` generically.
pub trait HalfPrecision: Copy + Clone + Send + Sync + 'static {
    /// The full precision type this half precision type widens to.
    type FullPrecision: scirs2_core::num_traits::Float;

    /// Convert to full precision (f32).
    fn to_f32(self) -> f32;

    /// Convert from full precision (f32); may lose precision.
    fn from_f32(value: f32) -> Self;

    /// The `DType` tag corresponding to this half precision type.
    fn dtype() -> crate::DType;
}
21
22impl HalfPrecision for f16 {
23    type FullPrecision = f32;
24
25    fn to_f32(self) -> f32 {
26        self.to_f32()
27    }
28
29    fn from_f32(value: f32) -> Self {
30        f16::from_f32(value)
31    }
32
33    fn dtype() -> crate::DType {
34        crate::DType::Float16
35    }
36}
37
38impl HalfPrecision for bf16 {
39    type FullPrecision = f32;
40
41    fn to_f32(self) -> f32 {
42        self.to_f32()
43    }
44
45    fn from_f32(value: f32) -> Self {
46        bf16::from_f32(value)
47    }
48
49    fn dtype() -> crate::DType {
50        crate::DType::BFloat16
51    }
52}
53
/// Mixed precision configuration for automatic mixed precision (AMP) training.
///
/// Carries dynamic loss scaling state: the scale is backed off when gradient
/// overflow is reported and grown again after a long overflow-free run
/// (see `update_loss_scale`).
#[derive(Debug, Clone)]
pub struct MixedPrecisionConfig {
    /// Whether to enable automatic mixed precision.
    pub enabled: bool,
    /// Loss scaling factor applied to prevent gradient underflow in half precision.
    pub loss_scale: f32,
    /// Multiplier applied to `loss_scale` after `growth_interval` overflow-free steps.
    pub growth_factor: f32,
    /// Multiplier applied to `loss_scale` when an overflow is detected.
    pub backoff_factor: f32,
    /// Number of consecutive overflow-free steps required before growing the scale.
    pub growth_interval: u32,
    /// Running count of steps since the last overflow (reset on overflow or growth).
    pub(crate) steps_without_overflow: u32,
    /// Whether to use bf16 instead of f16 as the half precision type.
    pub use_bfloat16: bool,
}
72
impl Default for MixedPrecisionConfig {
    /// Disabled-by-default configuration with conventional dynamic loss
    /// scaling parameters.
    fn default() -> Self {
        Self {
            enabled: false,
            // 2^16 — a commonly used initial loss scale for AMP training.
            loss_scale: 65536.0,
            // Double the scale on growth, halve it on overflow.
            growth_factor: 2.0,
            backoff_factor: 0.5,
            // Require 2000 consecutive overflow-free steps before growing.
            growth_interval: 2000,
            steps_without_overflow: 0,
            // f16 by default; opt into bf16 via `with_bfloat16`.
            use_bfloat16: false,
        }
    }
}
86
87impl MixedPrecisionConfig {
88    /// Create a new mixed precision configuration with default settings
89    pub fn new() -> Self {
90        Self::default()
91    }
92
93    /// Enable mixed precision training
94    pub fn enable(mut self) -> Self {
95        self.enabled = true;
96        self
97    }
98
99    /// Set the initial loss scale
100    pub fn with_loss_scale(mut self, scale: f32) -> Self {
101        self.loss_scale = scale;
102        self
103    }
104
105    /// Use bfloat16 instead of float16
106    pub fn with_bfloat16(mut self) -> Self {
107        self.use_bfloat16 = true;
108        self
109    }
110
111    /// Check for gradient overflow and update loss scaling
112    pub fn update_loss_scale(&mut self, has_overflow: bool) {
113        if has_overflow {
114            // Decrease loss scale and reset counter
115            self.loss_scale *= self.backoff_factor;
116            self.steps_without_overflow = 0;
117        } else {
118            // Increment counter
119            self.steps_without_overflow += 1;
120
121            // Increase loss scale if no overflow for growth_interval steps
122            if self.steps_without_overflow >= self.growth_interval {
123                self.loss_scale *= self.growth_factor;
124                self.steps_without_overflow = 0;
125            }
126        }
127
128        // Ensure loss scale doesn't become too small or too large
129        self.loss_scale = self.loss_scale.clamp(1.0, f32::MAX / 1000.0);
130    }
131
132    /// Get the target half precision type
133    pub fn target_dtype(&self) -> crate::DType {
134        if self.use_bfloat16 {
135            crate::DType::BFloat16
136        } else {
137            crate::DType::Float16
138        }
139    }
140}
141
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_f16_conversion() {
        // Round-trip PI through f16; its reduced significand introduces
        // a small error, so compare with a loose tolerance.
        let original = std::f32::consts::PI;
        let roundtrip = f16::from_f32(original).to_f32();
        assert!((roundtrip - original).abs() < 0.01);
    }

    #[test]
    fn test_bf16_conversion() {
        // bf16 keeps f32's exponent range at the cost of significand bits;
        // the round-trip error for PI still fits inside this bound.
        let original = std::f32::consts::PI;
        let roundtrip = bf16::from_f32(original).to_f32();
        assert!((roundtrip - original).abs() < 0.001);
    }

    #[test]
    fn test_mixed_precision_config() {
        let mut config = MixedPrecisionConfig::new()
            .enable()
            .with_loss_scale(1024.0)
            .with_bfloat16();

        assert!(config.enabled);
        assert_eq!(config.loss_scale, 1024.0);
        assert!(config.use_bfloat16);
        assert_eq!(config.target_dtype(), crate::DType::BFloat16);

        // An overflow step halves the scale (backoff 0.5) and resets the counter.
        config.update_loss_scale(true);
        assert_eq!(config.loss_scale, 512.0);
        assert_eq!(config.steps_without_overflow, 0);

        // growth_interval clean steps double it back (growth factor 2.0).
        (0..config.growth_interval).for_each(|_| config.update_loss_scale(false));
        assert_eq!(config.loss_scale, 1024.0);
    }

    #[test]
    fn test_dtype_mapping() {
        // Fully qualified calls make clear these exercise the trait impls.
        assert_eq!(<f16 as HalfPrecision>::dtype(), crate::DType::Float16);
        assert_eq!(<bf16 as HalfPrecision>::dtype(), crate::DType::BFloat16);
    }
}
195}