tenflowers_core/
half_precision.rs1pub use half::{bf16, f16};
7
8pub trait HalfPrecision: Copy + Clone + Send + Sync + 'static {
10 type FullPrecision: scirs2_core::num_traits::Float;
11
12 fn to_f32(self) -> f32;
14
15 fn from_f32(value: f32) -> Self;
17
18 fn dtype() -> crate::DType;
20}
21
22impl HalfPrecision for f16 {
23 type FullPrecision = f32;
24
25 fn to_f32(self) -> f32 {
26 self.to_f32()
27 }
28
29 fn from_f32(value: f32) -> Self {
30 f16::from_f32(value)
31 }
32
33 fn dtype() -> crate::DType {
34 crate::DType::Float16
35 }
36}
37
38impl HalfPrecision for bf16 {
39 type FullPrecision = f32;
40
41 fn to_f32(self) -> f32 {
42 self.to_f32()
43 }
44
45 fn from_f32(value: f32) -> Self {
46 bf16::from_f32(value)
47 }
48
49 fn dtype() -> crate::DType {
50 crate::DType::BFloat16
51 }
52}
53
/// Configuration for mixed-precision training with dynamic loss scaling.
///
/// Loss scaling multiplies the loss before backprop so small gradients survive
/// the reduced half-precision range; the scale adapts based on whether
/// overflows are observed.
#[derive(Debug, Clone)]
pub struct MixedPrecisionConfig {
    /// Master switch for mixed-precision execution.
    pub enabled: bool,
    /// Current loss-scale multiplier.
    pub loss_scale: f32,
    /// Factor applied to grow the scale after a stable stretch of steps.
    pub growth_factor: f32,
    /// Factor applied to shrink the scale when an overflow is detected.
    pub backoff_factor: f32,
    /// Number of consecutive overflow-free steps required before growing.
    pub growth_interval: u32,
    /// Internal counter of consecutive steps without overflow.
    pub(crate) steps_without_overflow: u32,
    /// Use bfloat16 instead of IEEE float16 as the reduced-precision dtype.
    pub use_bfloat16: bool,
}
72
73impl Default for MixedPrecisionConfig {
74 fn default() -> Self {
75 Self {
76 enabled: false,
77 loss_scale: 65536.0,
78 growth_factor: 2.0,
79 backoff_factor: 0.5,
80 growth_interval: 2000,
81 steps_without_overflow: 0,
82 use_bfloat16: false,
83 }
84 }
85}
86
87impl MixedPrecisionConfig {
88 pub fn new() -> Self {
90 Self::default()
91 }
92
93 pub fn enable(mut self) -> Self {
95 self.enabled = true;
96 self
97 }
98
99 pub fn with_loss_scale(mut self, scale: f32) -> Self {
101 self.loss_scale = scale;
102 self
103 }
104
105 pub fn with_bfloat16(mut self) -> Self {
107 self.use_bfloat16 = true;
108 self
109 }
110
111 pub fn update_loss_scale(&mut self, has_overflow: bool) {
113 if has_overflow {
114 self.loss_scale *= self.backoff_factor;
116 self.steps_without_overflow = 0;
117 } else {
118 self.steps_without_overflow += 1;
120
121 if self.steps_without_overflow >= self.growth_interval {
123 self.loss_scale *= self.growth_factor;
124 self.steps_without_overflow = 0;
125 }
126 }
127
128 self.loss_scale = self.loss_scale.clamp(1.0, f32::MAX / 1000.0);
130 }
131
132 pub fn target_dtype(&self) -> crate::DType {
134 if self.use_bfloat16 {
135 crate::DType::BFloat16
136 } else {
137 crate::DType::Float16
138 }
139 }
140}
141
#[cfg(test)]
mod tests {
    use super::*;

    /// f16 has ~3 decimal digits of precision; pi should survive within 0.01.
    #[test]
    fn test_f16_conversion() {
        let original = std::f32::consts::PI;
        let roundtrip = f16::from_f32(original).to_f32();
        assert!((roundtrip - original).abs() < 0.01);
    }

    /// bf16 keeps f32's exponent range; pi round-trips within 0.001 here.
    #[test]
    fn test_bf16_conversion() {
        let original = std::f32::consts::PI;
        let roundtrip = bf16::from_f32(original).to_f32();
        assert!((roundtrip - original).abs() < 0.001);
    }

    /// Exercises the builder chain and the dynamic loss-scale state machine.
    #[test]
    fn test_mixed_precision_config() {
        let mut cfg = MixedPrecisionConfig::new()
            .enable()
            .with_loss_scale(1024.0)
            .with_bfloat16();

        assert!(cfg.enabled);
        assert_eq!(cfg.loss_scale, 1024.0);
        assert!(cfg.use_bfloat16);
        assert_eq!(cfg.target_dtype(), crate::DType::BFloat16);

        // One overflow halves the scale (backoff_factor = 0.5) and resets
        // the clean-step counter.
        cfg.update_loss_scale(true);
        assert_eq!(cfg.loss_scale, 512.0);
        assert_eq!(cfg.steps_without_overflow, 0);

        // A full growth interval of clean steps doubles it back.
        for _ in 0..cfg.growth_interval {
            cfg.update_loss_scale(false);
        }
        assert_eq!(cfg.loss_scale, 1024.0);
    }

    /// Each half type maps to its matching crate dtype tag.
    #[test]
    fn test_dtype_mapping() {
        assert_eq!(f16::dtype(), crate::DType::Float16);
        assert_eq!(bf16::dtype(), crate::DType::BFloat16);
    }
}