Skip to main content

tensorlogic_scirs_backend/
precision.rs

1//! Precision control for tensor computations.
2//!
3//! This module provides abstractions for controlling numerical precision
4//! (f32, f64, mixed precision).
5
6use std::fmt;
7
/// Numeric format used to store (and, for the mixed modes, compute) tensor data.
///
/// The default is [`Precision::F64`].
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub enum Precision {
    /// IEEE 754 single precision (32-bit): faster and half the memory of `F64`.
    F32,

    /// IEEE 754 double precision (64-bit): the most accurate option.
    #[default]
    F64,

    /// Mixed precision: values are stored as FP16 and computed in FP32.
    Mixed16,

    /// Mixed precision: values are stored as BF16 and computed in FP32.
    BFloat16,
}
24
25impl Precision {
26    /// Returns the size in bytes of this precision.
27    pub fn size_bytes(&self) -> usize {
28        match self {
29            Precision::F32 => 4,
30            Precision::F64 => 8,
31            Precision::Mixed16 => 2,  // Storage size
32            Precision::BFloat16 => 2, // Storage size
33        }
34    }
35
36    /// Returns true if this is a mixed precision mode.
37    pub fn is_mixed(&self) -> bool {
38        matches!(self, Precision::Mixed16 | Precision::BFloat16)
39    }
40
41    /// Returns the computation precision (the precision used for actual operations).
42    pub fn compute_precision(&self) -> ComputePrecision {
43        match self {
44            Precision::F32 | Precision::Mixed16 | Precision::BFloat16 => ComputePrecision::F32,
45            Precision::F64 => ComputePrecision::F64,
46        }
47    }
48
49    /// Returns a human-readable description.
50    pub fn description(&self) -> &'static str {
51        match self {
52            Precision::F32 => "32-bit floating point",
53            Precision::F64 => "64-bit floating point",
54            Precision::Mixed16 => "Mixed precision (FP16 storage, FP32 compute)",
55            Precision::BFloat16 => "Mixed precision (BF16 storage, FP32 compute)",
56        }
57    }
58
59    /// Memory savings compared to F64.
60    pub fn memory_savings(&self) -> f64 {
61        1.0 - (self.size_bytes() as f64 / 8.0)
62    }
63}
64
65impl fmt::Display for Precision {
66    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
67        match self {
68            Precision::F32 => write!(f, "FP32"),
69            Precision::F64 => write!(f, "FP64"),
70            Precision::Mixed16 => write!(f, "Mixed-FP16"),
71            Precision::BFloat16 => write!(f, "Mixed-BF16"),
72        }
73    }
74}
75
/// Precision in which arithmetic is actually performed.
///
/// Returned by [`Precision::compute_precision`]; the mixed-precision storage
/// modes compute in [`ComputePrecision::F32`].
///
/// Derives `Hash` (like `Precision`) so it can be used as a map or set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ComputePrecision {
    /// 32-bit (FP32) computation.
    F32,

    /// 64-bit (FP64) computation.
    F64,
}
85
86impl fmt::Display for ComputePrecision {
87    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
88        match self {
89            ComputePrecision::F32 => write!(f, "FP32"),
90            ComputePrecision::F64 => write!(f, "FP64"),
91        }
92    }
93}
94
95/// Precision configuration for an executor.
96#[derive(Debug, Clone)]
97pub struct PrecisionConfig {
98    /// Default precision for tensors
99    pub default_precision: Precision,
100
101    /// Enable automatic mixed precision
102    pub auto_mixed_precision: bool,
103
104    /// Loss scaling for mixed precision training
105    pub loss_scale: Option<f64>,
106
107    /// Dynamic loss scaling (adjust based on gradients)
108    pub dynamic_loss_scaling: bool,
109}
110
111impl Default for PrecisionConfig {
112    fn default() -> Self {
113        Self {
114            default_precision: Precision::F64,
115            auto_mixed_precision: false,
116            loss_scale: None,
117            dynamic_loss_scaling: false,
118        }
119    }
120}
121
122impl PrecisionConfig {
123    /// Create a configuration for FP32 precision.
124    pub fn f32() -> Self {
125        Self {
126            default_precision: Precision::F32,
127            auto_mixed_precision: false,
128            loss_scale: None,
129            dynamic_loss_scaling: false,
130        }
131    }
132
133    /// Create a configuration for FP64 precision.
134    pub fn f64() -> Self {
135        Self {
136            default_precision: Precision::F64,
137            auto_mixed_precision: false,
138            loss_scale: None,
139            dynamic_loss_scaling: false,
140        }
141    }
142
143    /// Create a configuration for mixed precision training.
144    pub fn mixed_precision() -> Self {
145        Self {
146            default_precision: Precision::Mixed16,
147            auto_mixed_precision: true,
148            loss_scale: Some(2048.0), // Common starting value
149            dynamic_loss_scaling: true,
150        }
151    }
152
153    /// Enable automatic mixed precision.
154    pub fn with_auto_mixed_precision(mut self, enable: bool) -> Self {
155        self.auto_mixed_precision = enable;
156        self
157    }
158
159    /// Set the loss scale for mixed precision training.
160    pub fn with_loss_scale(mut self, scale: f64) -> Self {
161        self.loss_scale = Some(scale);
162        self
163    }
164
165    /// Enable dynamic loss scaling.
166    pub fn with_dynamic_loss_scaling(mut self, enable: bool) -> Self {
167        self.dynamic_loss_scaling = enable;
168        self
169    }
170}
171
172/// Trait for scalar types that can be used in tensor computations.
173///
174/// This trait abstracts over f32 and f64 for generic tensor operations.
175pub trait Scalar:
176    Copy
177    + Clone
178    + PartialEq
179    + PartialOrd
180    + std::fmt::Debug
181    + std::fmt::Display
182    + std::ops::Add<Output = Self>
183    + std::ops::Sub<Output = Self>
184    + std::ops::Mul<Output = Self>
185    + std::ops::Div<Output = Self>
186    + std::ops::Neg<Output = Self>
187    + 'static
188{
189    /// Zero value
190    fn zero() -> Self;
191
192    /// One value
193    fn one() -> Self;
194
195    /// Maximum value
196    fn max_value() -> Self;
197
198    /// Minimum value (most negative)
199    fn min_value() -> Self;
200
201    /// Positive infinity
202    fn infinity() -> Self;
203
204    /// Negative infinity
205    fn neg_infinity() -> Self;
206
207    /// Not a number
208    fn nan() -> Self;
209
210    /// Check if value is NaN
211    fn is_nan(self) -> bool;
212
213    /// Check if value is infinite
214    fn is_infinite(self) -> bool;
215
216    /// Check if value is finite
217    fn is_finite(self) -> bool;
218
219    /// Absolute value
220    fn abs(self) -> Self;
221
222    /// Square root
223    fn sqrt(self) -> Self;
224
225    /// Exponential
226    fn exp(self) -> Self;
227
228    /// Natural logarithm
229    fn ln(self) -> Self;
230
231    /// Maximum of two values
232    fn max(self, other: Self) -> Self;
233
234    /// Minimum of two values
235    fn min(self, other: Self) -> Self;
236
237    /// Convert from f64
238    fn from_f64(value: f64) -> Self;
239
240    /// Convert to f64
241    fn to_f64(self) -> f64;
242
243    /// The precision type
244    fn precision() -> Precision;
245}
246
247impl Scalar for f32 {
248    fn zero() -> Self {
249        0.0
250    }
251
252    fn one() -> Self {
253        1.0
254    }
255
256    fn max_value() -> Self {
257        f32::MAX
258    }
259
260    fn min_value() -> Self {
261        f32::MIN
262    }
263
264    fn infinity() -> Self {
265        f32::INFINITY
266    }
267
268    fn neg_infinity() -> Self {
269        f32::NEG_INFINITY
270    }
271
272    fn nan() -> Self {
273        f32::NAN
274    }
275
276    fn is_nan(self) -> bool {
277        f32::is_nan(self)
278    }
279
280    fn is_infinite(self) -> bool {
281        f32::is_infinite(self)
282    }
283
284    fn is_finite(self) -> bool {
285        f32::is_finite(self)
286    }
287
288    fn abs(self) -> Self {
289        f32::abs(self)
290    }
291
292    fn sqrt(self) -> Self {
293        f32::sqrt(self)
294    }
295
296    fn exp(self) -> Self {
297        f32::exp(self)
298    }
299
300    fn ln(self) -> Self {
301        f32::ln(self)
302    }
303
304    fn max(self, other: Self) -> Self {
305        f32::max(self, other)
306    }
307
308    fn min(self, other: Self) -> Self {
309        f32::min(self, other)
310    }
311
312    fn from_f64(value: f64) -> Self {
313        value as f32
314    }
315
316    fn to_f64(self) -> f64 {
317        self as f64
318    }
319
320    fn precision() -> Precision {
321        Precision::F32
322    }
323}
324
325impl Scalar for f64 {
326    fn zero() -> Self {
327        0.0
328    }
329
330    fn one() -> Self {
331        1.0
332    }
333
334    fn max_value() -> Self {
335        f64::MAX
336    }
337
338    fn min_value() -> Self {
339        f64::MIN
340    }
341
342    fn infinity() -> Self {
343        f64::INFINITY
344    }
345
346    fn neg_infinity() -> Self {
347        f64::NEG_INFINITY
348    }
349
350    fn nan() -> Self {
351        f64::NAN
352    }
353
354    fn is_nan(self) -> bool {
355        f64::is_nan(self)
356    }
357
358    fn is_infinite(self) -> bool {
359        f64::is_infinite(self)
360    }
361
362    fn is_finite(self) -> bool {
363        f64::is_finite(self)
364    }
365
366    fn abs(self) -> Self {
367        f64::abs(self)
368    }
369
370    fn sqrt(self) -> Self {
371        f64::sqrt(self)
372    }
373
374    fn exp(self) -> Self {
375        f64::exp(self)
376    }
377
378    fn ln(self) -> Self {
379        f64::ln(self)
380    }
381
382    fn max(self, other: Self) -> Self {
383        f64::max(self, other)
384    }
385
386    fn min(self, other: Self) -> Self {
387        f64::min(self, other)
388    }
389
390    fn from_f64(value: f64) -> Self {
391        value
392    }
393
394    fn to_f64(self) -> f64 {
395        self
396    }
397
398    fn precision() -> Precision {
399        Precision::F64
400    }
401}
402
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `Precision` variant, for table-driven checks.
    const ALL_PRECISIONS: [Precision; 4] = [
        Precision::F32,
        Precision::F64,
        Precision::Mixed16,
        Precision::BFloat16,
    ];

    #[test]
    fn test_precision_properties() {
        assert_eq!(Precision::F32.size_bytes(), 4);
        assert_eq!(Precision::F64.size_bytes(), 8);
        assert_eq!(Precision::Mixed16.size_bytes(), 2);
        assert_eq!(Precision::BFloat16.size_bytes(), 2);

        assert!(!Precision::F32.is_mixed());
        assert!(!Precision::F64.is_mixed());
        assert!(Precision::Mixed16.is_mixed());
        assert!(Precision::BFloat16.is_mixed());
    }

    #[test]
    fn test_precision_compute_precision() {
        // Only F64 computes in double precision; every other mode uses FP32.
        let fp32 = Precision::F32.compute_precision();
        let fp64 = Precision::F64.compute_precision();
        assert_ne!(fp32, fp64);
        assert_eq!(Precision::Mixed16.compute_precision(), fp32);
        assert_eq!(Precision::BFloat16.compute_precision(), fp32);
        assert_eq!(fp32.to_string(), "FP32");
        assert_eq!(fp64.to_string(), "FP64");
    }

    #[test]
    fn test_precision_description_is_distinct() {
        for (i, a) in ALL_PRECISIONS.iter().enumerate() {
            assert!(!a.description().is_empty());
            for b in &ALL_PRECISIONS[i + 1..] {
                assert_ne!(a.description(), b.description());
            }
        }
    }

    #[test]
    fn test_precision_default() {
        let precision = Precision::default();
        assert_eq!(precision, Precision::F64);
    }

    #[test]
    fn test_precision_display() {
        assert_eq!(Precision::F32.to_string(), "FP32");
        assert_eq!(Precision::F64.to_string(), "FP64");
        assert_eq!(Precision::Mixed16.to_string(), "Mixed-FP16");
        assert_eq!(Precision::BFloat16.to_string(), "Mixed-BF16");
    }

    #[test]
    fn test_precision_memory_savings() {
        assert!((Precision::F32.memory_savings() - 0.5).abs() < 0.01); // 50% savings vs F64
        assert!((Precision::F64.memory_savings()).abs() < 0.01); // 0% savings
        assert!((Precision::Mixed16.memory_savings() - 0.75).abs() < 0.01); // 75% savings
        assert!((Precision::BFloat16.memory_savings() - 0.75).abs() < 0.01); // 75% savings

        for p in ALL_PRECISIONS.iter() {
            let savings = p.memory_savings();
            assert!((0.0..1.0).contains(&savings), "{:?} -> {}", p, savings);
        }
    }

    #[test]
    fn test_precision_config_default() {
        let config = PrecisionConfig::default();
        assert_eq!(config.default_precision, Precision::F64);
        assert!(!config.auto_mixed_precision);
        assert_eq!(config.loss_scale, None);
        assert!(!config.dynamic_loss_scaling);
    }

    #[test]
    fn test_precision_config_builders() {
        let f32_config = PrecisionConfig::f32();
        assert_eq!(f32_config.default_precision, Precision::F32);
        assert!(!f32_config.auto_mixed_precision);
        assert_eq!(f32_config.loss_scale, None);

        let f64_config = PrecisionConfig::f64();
        assert_eq!(f64_config.default_precision, Precision::F64);

        let mixed_config = PrecisionConfig::mixed_precision();
        assert_eq!(mixed_config.default_precision, Precision::Mixed16);
        assert!(mixed_config.auto_mixed_precision);
        assert!(mixed_config.loss_scale.is_some());
        assert!(mixed_config.dynamic_loss_scaling);
    }

    #[test]
    fn test_precision_config_builder_methods() {
        let config = PrecisionConfig::f32()
            .with_auto_mixed_precision(true)
            .with_loss_scale(1024.0)
            .with_dynamic_loss_scaling(true);

        assert_eq!(config.default_precision, Precision::F32);
        assert!(config.auto_mixed_precision);
        assert_eq!(config.loss_scale, Some(1024.0));
        assert!(config.dynamic_loss_scaling);
    }

    #[test]
    fn test_scalar_f32() {
        assert_eq!(f32::zero(), 0.0_f32);
        assert_eq!(f32::one(), 1.0_f32);
        assert!(f32::infinity().is_infinite());
        assert!(f32::nan().is_nan());

        let x = 2.0_f32;
        assert_eq!(x.abs(), 2.0);
        assert!((x.sqrt() - std::f32::consts::SQRT_2).abs() < 1e-6);
        assert_eq!(f32::precision(), Precision::F32);
    }

    #[test]
    fn test_scalar_f64() {
        assert_eq!(f64::zero(), 0.0_f64);
        assert_eq!(f64::one(), 1.0_f64);
        assert!(f64::infinity().is_infinite());
        assert!(f64::nan().is_nan());

        let x = 2.0_f64;
        assert_eq!(x.abs(), 2.0);
        assert!((x.sqrt() - std::f64::consts::SQRT_2).abs() < 1e-10);
        assert_eq!(f64::precision(), Precision::F64);
    }

    #[test]
    fn test_scalar_conversions() {
        let x_f64 = std::f64::consts::PI;
        let x_f32 = f32::from_f64(x_f64);
        let back_to_f64 = x_f32.to_f64();

        assert!((x_f32 - std::f32::consts::PI).abs() < 1e-5);
        assert!((back_to_f64 - x_f64).abs() < 1e-5);
    }
}