// tensorlogic_scirs_backend/precision.rs
use std::fmt;
/// Numeric precision used for tensor element storage.
///
/// `F64` is the default variant. The two 16-bit variants are "mixed" modes:
/// they report a 2-byte storage size while arithmetic is carried out in FP32
/// (see `Precision::compute_precision`).
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub enum Precision {
    /// Single precision: 4 bytes per element, FP32 compute.
    F32,

    /// Double precision: 8 bytes per element, FP64 compute. This is the default.
    #[default]
    F64,

    /// FP16 storage (2 bytes per element) with FP32 compute.
    Mixed16,

    /// BF16 storage (2 bytes per element) with FP32 compute.
    BFloat16,
}
24
25impl Precision {
26 pub fn size_bytes(&self) -> usize {
28 match self {
29 Precision::F32 => 4,
30 Precision::F64 => 8,
31 Precision::Mixed16 => 2, Precision::BFloat16 => 2, }
34 }
35
36 pub fn is_mixed(&self) -> bool {
38 matches!(self, Precision::Mixed16 | Precision::BFloat16)
39 }
40
41 pub fn compute_precision(&self) -> ComputePrecision {
43 match self {
44 Precision::F32 | Precision::Mixed16 | Precision::BFloat16 => ComputePrecision::F32,
45 Precision::F64 => ComputePrecision::F64,
46 }
47 }
48
49 pub fn description(&self) -> &'static str {
51 match self {
52 Precision::F32 => "32-bit floating point",
53 Precision::F64 => "64-bit floating point",
54 Precision::Mixed16 => "Mixed precision (FP16 storage, FP32 compute)",
55 Precision::BFloat16 => "Mixed precision (BF16 storage, FP32 compute)",
56 }
57 }
58
59 pub fn memory_savings(&self) -> f64 {
61 1.0 - (self.size_bytes() as f64 / 8.0)
62 }
63}
64
65impl fmt::Display for Precision {
66 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
67 match self {
68 Precision::F32 => write!(f, "FP32"),
69 Precision::F64 => write!(f, "FP64"),
70 Precision::Mixed16 => write!(f, "Mixed-FP16"),
71 Precision::BFloat16 => write!(f, "Mixed-BF16"),
72 }
73 }
74}
75
/// Floating-point width in which arithmetic is actually carried out.
///
/// `Precision::compute_precision` maps each storage precision to one of these.
/// `Hash` is derived (matching `Precision`) so values can be used as map or
/// set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ComputePrecision {
    /// 32-bit floating-point arithmetic.
    F32,

    /// 64-bit floating-point arithmetic.
    F64,
}
85
86impl fmt::Display for ComputePrecision {
87 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
88 match self {
89 ComputePrecision::F32 => write!(f, "FP32"),
90 ComputePrecision::F64 => write!(f, "FP64"),
91 }
92 }
93}
94
95#[derive(Debug, Clone)]
97pub struct PrecisionConfig {
98 pub default_precision: Precision,
100
101 pub auto_mixed_precision: bool,
103
104 pub loss_scale: Option<f64>,
106
107 pub dynamic_loss_scaling: bool,
109}
110
111impl Default for PrecisionConfig {
112 fn default() -> Self {
113 Self {
114 default_precision: Precision::F64,
115 auto_mixed_precision: false,
116 loss_scale: None,
117 dynamic_loss_scaling: false,
118 }
119 }
120}
121
122impl PrecisionConfig {
123 pub fn f32() -> Self {
125 Self {
126 default_precision: Precision::F32,
127 auto_mixed_precision: false,
128 loss_scale: None,
129 dynamic_loss_scaling: false,
130 }
131 }
132
133 pub fn f64() -> Self {
135 Self {
136 default_precision: Precision::F64,
137 auto_mixed_precision: false,
138 loss_scale: None,
139 dynamic_loss_scaling: false,
140 }
141 }
142
143 pub fn mixed_precision() -> Self {
145 Self {
146 default_precision: Precision::Mixed16,
147 auto_mixed_precision: true,
148 loss_scale: Some(2048.0), dynamic_loss_scaling: true,
150 }
151 }
152
153 pub fn with_auto_mixed_precision(mut self, enable: bool) -> Self {
155 self.auto_mixed_precision = enable;
156 self
157 }
158
159 pub fn with_loss_scale(mut self, scale: f64) -> Self {
161 self.loss_scale = Some(scale);
162 self
163 }
164
165 pub fn with_dynamic_loss_scaling(mut self, enable: bool) -> Self {
167 self.dynamic_loss_scaling = enable;
168 self
169 }
170}
171
172pub trait Scalar:
176 Copy
177 + Clone
178 + PartialEq
179 + PartialOrd
180 + std::fmt::Debug
181 + std::fmt::Display
182 + std::ops::Add<Output = Self>
183 + std::ops::Sub<Output = Self>
184 + std::ops::Mul<Output = Self>
185 + std::ops::Div<Output = Self>
186 + std::ops::Neg<Output = Self>
187 + 'static
188{
189 fn zero() -> Self;
191
192 fn one() -> Self;
194
195 fn max_value() -> Self;
197
198 fn min_value() -> Self;
200
201 fn infinity() -> Self;
203
204 fn neg_infinity() -> Self;
206
207 fn nan() -> Self;
209
210 fn is_nan(self) -> bool;
212
213 fn is_infinite(self) -> bool;
215
216 fn is_finite(self) -> bool;
218
219 fn abs(self) -> Self;
221
222 fn sqrt(self) -> Self;
224
225 fn exp(self) -> Self;
227
228 fn ln(self) -> Self;
230
231 fn max(self, other: Self) -> Self;
233
234 fn min(self, other: Self) -> Self;
236
237 fn from_f64(value: f64) -> Self;
239
240 fn to_f64(self) -> f64;
242
243 fn precision() -> Precision;
245}
246
247impl Scalar for f32 {
248 fn zero() -> Self {
249 0.0
250 }
251
252 fn one() -> Self {
253 1.0
254 }
255
256 fn max_value() -> Self {
257 f32::MAX
258 }
259
260 fn min_value() -> Self {
261 f32::MIN
262 }
263
264 fn infinity() -> Self {
265 f32::INFINITY
266 }
267
268 fn neg_infinity() -> Self {
269 f32::NEG_INFINITY
270 }
271
272 fn nan() -> Self {
273 f32::NAN
274 }
275
276 fn is_nan(self) -> bool {
277 f32::is_nan(self)
278 }
279
280 fn is_infinite(self) -> bool {
281 f32::is_infinite(self)
282 }
283
284 fn is_finite(self) -> bool {
285 f32::is_finite(self)
286 }
287
288 fn abs(self) -> Self {
289 f32::abs(self)
290 }
291
292 fn sqrt(self) -> Self {
293 f32::sqrt(self)
294 }
295
296 fn exp(self) -> Self {
297 f32::exp(self)
298 }
299
300 fn ln(self) -> Self {
301 f32::ln(self)
302 }
303
304 fn max(self, other: Self) -> Self {
305 f32::max(self, other)
306 }
307
308 fn min(self, other: Self) -> Self {
309 f32::min(self, other)
310 }
311
312 fn from_f64(value: f64) -> Self {
313 value as f32
314 }
315
316 fn to_f64(self) -> f64 {
317 self as f64
318 }
319
320 fn precision() -> Precision {
321 Precision::F32
322 }
323}
324
325impl Scalar for f64 {
326 fn zero() -> Self {
327 0.0
328 }
329
330 fn one() -> Self {
331 1.0
332 }
333
334 fn max_value() -> Self {
335 f64::MAX
336 }
337
338 fn min_value() -> Self {
339 f64::MIN
340 }
341
342 fn infinity() -> Self {
343 f64::INFINITY
344 }
345
346 fn neg_infinity() -> Self {
347 f64::NEG_INFINITY
348 }
349
350 fn nan() -> Self {
351 f64::NAN
352 }
353
354 fn is_nan(self) -> bool {
355 f64::is_nan(self)
356 }
357
358 fn is_infinite(self) -> bool {
359 f64::is_infinite(self)
360 }
361
362 fn is_finite(self) -> bool {
363 f64::is_finite(self)
364 }
365
366 fn abs(self) -> Self {
367 f64::abs(self)
368 }
369
370 fn sqrt(self) -> Self {
371 f64::sqrt(self)
372 }
373
374 fn exp(self) -> Self {
375 f64::exp(self)
376 }
377
378 fn ln(self) -> Self {
379 f64::ln(self)
380 }
381
382 fn max(self, other: Self) -> Self {
383 f64::max(self, other)
384 }
385
386 fn min(self, other: Self) -> Self {
387 f64::min(self, other)
388 }
389
390 fn from_f64(value: f64) -> Self {
391 value
392 }
393
394 fn to_f64(self) -> f64 {
395 self
396 }
397
398 fn precision() -> Precision {
399 Precision::F64
400 }
401}
402
#[cfg(test)]
mod tests {
    use super::*;

    /// Drives every `Scalar` method through the trait and compares the results
    /// with `f64` reference values within `tol`.
    ///
    /// Going through a generic `T` guarantees the trait impls are exercised.
    /// On a concrete `f32`, a call like `x.abs()` resolves to the inherent std
    /// method and would leave the impl untested.
    fn check_scalar_impl<T: Scalar>(tol: f64) {
        let zero = T::zero();
        let one = T::one();
        assert_eq!(zero.to_f64(), 0.0);
        assert_eq!(one.to_f64(), 1.0);

        assert!(T::max_value().is_finite() && T::max_value() > zero);
        assert!(T::min_value().is_finite() && T::min_value() < zero);
        assert!(T::infinity().is_infinite() && T::infinity() > zero);
        assert!(T::neg_infinity().is_infinite() && T::neg_infinity() < zero);
        assert!(T::nan().is_nan() && !T::nan().is_finite());
        assert!(one.is_finite() && !one.is_nan() && !one.is_infinite());

        let two = T::from_f64(2.0);
        assert_eq!((-two).abs().to_f64(), 2.0);
        assert!((two.sqrt().to_f64() - std::f64::consts::SQRT_2).abs() < tol);
        assert!((one.exp().to_f64() - std::f64::consts::E).abs() < tol);
        assert!(one.ln().to_f64().abs() < tol);
        assert_eq!(two.max(one).to_f64(), 2.0);
        assert_eq!(two.min(one).to_f64(), 1.0);
    }

    #[test]
    fn test_precision_properties() {
        assert_eq!(Precision::F32.size_bytes(), 4);
        assert_eq!(Precision::F64.size_bytes(), 8);
        assert_eq!(Precision::Mixed16.size_bytes(), 2);
        assert_eq!(Precision::BFloat16.size_bytes(), 2);

        assert!(!Precision::F32.is_mixed());
        assert!(!Precision::F64.is_mixed());
        assert!(Precision::Mixed16.is_mixed());
        assert!(Precision::BFloat16.is_mixed());
    }

    #[test]
    fn test_compute_precision() {
        assert_eq!(Precision::F32.compute_precision(), ComputePrecision::F32);
        assert_eq!(Precision::F64.compute_precision(), ComputePrecision::F64);
        assert_eq!(Precision::Mixed16.compute_precision(), ComputePrecision::F32);
        assert_eq!(Precision::BFloat16.compute_precision(), ComputePrecision::F32);
    }

    #[test]
    fn test_precision_default() {
        let precision = Precision::default();
        assert_eq!(precision, Precision::F64);
    }

    #[test]
    fn test_precision_display() {
        assert_eq!(Precision::F32.to_string(), "FP32");
        assert_eq!(Precision::F64.to_string(), "FP64");
        assert_eq!(Precision::Mixed16.to_string(), "Mixed-FP16");
        assert_eq!(Precision::BFloat16.to_string(), "Mixed-BF16");

        assert_eq!(ComputePrecision::F32.to_string(), "FP32");
        assert_eq!(ComputePrecision::F64.to_string(), "FP64");
    }

    #[test]
    fn test_precision_description() {
        assert_eq!(Precision::F32.description(), "32-bit floating point");
        assert_eq!(Precision::F64.description(), "64-bit floating point");
        assert_eq!(
            Precision::Mixed16.description(),
            "Mixed precision (FP16 storage, FP32 compute)"
        );
        assert_eq!(
            Precision::BFloat16.description(),
            "Mixed precision (BF16 storage, FP32 compute)"
        );
    }

    #[test]
    fn test_precision_memory_savings() {
        assert!((Precision::F32.memory_savings() - 0.5).abs() < 0.01);
        assert!((Precision::F64.memory_savings()).abs() < 0.01);
        assert!((Precision::Mixed16.memory_savings() - 0.75).abs() < 0.01);
        assert!((Precision::BFloat16.memory_savings() - 0.75).abs() < 0.01);
    }

    #[test]
    fn test_precision_config_default() {
        let config = PrecisionConfig::default();
        assert_eq!(config.default_precision, Precision::F64);
        assert!(!config.auto_mixed_precision);
        assert_eq!(config.loss_scale, None);
        assert!(!config.dynamic_loss_scaling);
    }

    #[test]
    fn test_precision_config_builders() {
        let f32_config = PrecisionConfig::f32();
        assert_eq!(f32_config.default_precision, Precision::F32);
        assert!(!f32_config.auto_mixed_precision);
        assert_eq!(f32_config.loss_scale, None);
        assert!(!f32_config.dynamic_loss_scaling);

        let f64_config = PrecisionConfig::f64();
        assert_eq!(f64_config.default_precision, Precision::F64);
        assert!(!f64_config.auto_mixed_precision);
        assert_eq!(f64_config.loss_scale, None);
        assert!(!f64_config.dynamic_loss_scaling);

        let mixed_config = PrecisionConfig::mixed_precision();
        assert_eq!(mixed_config.default_precision, Precision::Mixed16);
        assert!(mixed_config.auto_mixed_precision);
        assert!(mixed_config.loss_scale.is_some());
        assert!(mixed_config.dynamic_loss_scaling);
    }

    #[test]
    fn test_precision_config_builder_methods() {
        let config = PrecisionConfig::f32()
            .with_auto_mixed_precision(true)
            .with_loss_scale(1024.0)
            .with_dynamic_loss_scaling(true);

        assert_eq!(config.default_precision, Precision::F32);
        assert!(config.auto_mixed_precision);
        assert_eq!(config.loss_scale, Some(1024.0));
        assert!(config.dynamic_loss_scaling);
    }

    #[test]
    fn test_scalar_f32() {
        check_scalar_impl::<f32>(1e-6);

        assert_eq!(f32::zero(), 0.0_f32);
        assert_eq!(f32::one(), 1.0_f32);
        assert_eq!(<f32 as Scalar>::max_value(), f32::MAX);
        assert_eq!(<f32 as Scalar>::min_value(), f32::MIN);
        // Qualified calls so the trait impl (not the inherent method) is tested.
        assert!(Scalar::is_infinite(f32::infinity()));
        assert!(Scalar::is_nan(f32::nan()));

        let x = 2.0_f32;
        assert_eq!(Scalar::abs(x), 2.0);
        assert!((Scalar::sqrt(x) - std::f32::consts::SQRT_2).abs() < 1e-6);
        assert_eq!(f32::precision(), Precision::F32);
    }

    #[test]
    fn test_scalar_f64() {
        check_scalar_impl::<f64>(1e-12);

        assert_eq!(f64::zero(), 0.0_f64);
        assert_eq!(f64::one(), 1.0_f64);
        assert_eq!(<f64 as Scalar>::max_value(), f64::MAX);
        assert_eq!(<f64 as Scalar>::min_value(), f64::MIN);
        assert!(Scalar::is_infinite(f64::infinity()));
        assert!(Scalar::is_nan(f64::nan()));

        let x = 2.0_f64;
        assert_eq!(Scalar::abs(x), 2.0);
        assert!((Scalar::sqrt(x) - std::f64::consts::SQRT_2).abs() < 1e-10);
        assert_eq!(f64::precision(), Precision::F64);
    }

    #[test]
    fn test_scalar_conversions() {
        let x_f64 = std::f64::consts::PI;
        let x_f32 = f32::from_f64(x_f64);
        let back_to_f64 = x_f32.to_f64();

        assert!((x_f32 - std::f32::consts::PI).abs() < 1e-5);
        assert!((back_to_f64 - x_f64).abs() < 1e-5);

        // f64 <-> f64 conversions are exact identities.
        assert_eq!(<f64 as Scalar>::from_f64(x_f64), x_f64);
        assert_eq!(x_f64.to_f64(), x_f64);
    }
}