sparse_ir/
numeric.rs

//! Custom numeric traits for high-precision computation.
//!
//! This module provides the `CustomNumeric` trait, implemented for both `f64` and the
//! xprec `Df64` (double-double) backend, for high-precision numerical computation in
//! Gauss quadrature and matrix operations.
5
6use crate::Df64;
7use libm::expm1;
8use simba::scalar::ComplexField;
9use std::fmt::Debug;
10
/// Custom numeric trait for high-precision numerical computation.
///
/// This trait provides the essential numeric operations needed by the gauss module
/// and the matrix_from_gauss functions. It is implemented for both `f64` and `Df64`.
///
/// Common operations come from standard supertraits rather than from this trait:
/// - `num_traits::Zero` for `zero()`
/// - `simba::scalar::ComplexField` for mathematical functions (abs, sqrt, cos, sin, exp, etc.)
/// - `ComplexField::is_finite()` for `is_finite()`
///
/// Note: `ComplexField` also has an `exp_m1` method (the `Df64` impl calls
/// `ComplexField::exp_m1` explicitly), so in generic code a plain `x.exp_m1()` may be
/// ambiguous; write `CustomNumeric::exp_m1(x)` to select this trait's method.
pub trait CustomNumeric:
    Copy
    + Debug
    + PartialOrd
    + std::fmt::Display
    + std::ops::Add<Output = Self>
    + std::ops::Sub<Output = Self>
    + std::ops::Mul<Output = Self>
    + std::ops::Div<Output = Self>
    + std::ops::Neg<Output = Self>
    + num_traits::Zero                    // zero()
    + simba::scalar::ComplexField         // abs, cos, sin, exp, sqrt, etc.
{
    /// Convert from `f64` to `Self` (infallible; no `Option`).
    ///
    /// Call as `T::from_f64_unchecked(x)`. The name avoids a clash with
    /// `num_traits::FromPrimitive::from_f64`, which returns an `Option`.
    fn from_f64_unchecked(x: f64) -> Self;

    /// Convert from any `CustomNumeric` type to `Self`.
    ///
    /// Call as `T::convert_from(other_numeric_value)`. The implementations in this
    /// file copy the value exactly when `U` is `Self`. Every other conversion goes
    /// through `U::to_f64()`, so narrowing conversions such as `Df64 -> f64` round.
    fn convert_from<U: CustomNumeric + 'static>(value: U) -> Self;

    /// Convert to `f64`, rounding if `Self` carries more precision than `f64`.
    fn to_f64(self) -> f64;

    /// Machine epsilon of `Self`.
    fn epsilon() -> Self;

    /// π as `Self`; intended to be accurate to the full precision of `Self`.
    fn pi() -> Self;

    /// Maximum of two values (not provided by ComplexField).
    fn max(self, other: Self) -> Self;

    /// Minimum of two values (not provided by ComplexField).
    fn min(self, other: Self) -> Self;

    /// Return `true` if the value is finite, i.e. neither NaN nor ±infinity.
    fn is_valid(&self) -> bool;

    /// Absolute value returned as `Self` (convenience method).
    ///
    /// Wraps `ComplexField::abs()` and converts the result back to `Self`
    /// to avoid type-conversion issues in generic code.
    /// Note: only abs() has this problem; the other math functions (exp, cos, sin, sqrt)
    /// already return Self directly from ComplexField.
    fn abs_as_same_type(self) -> Self;

    /// Compute `exp(self) - 1` accurately for small values.
    ///
    /// This is more accurate than `self.exp() - 1` when self is close to zero,
    /// because it avoids catastrophic cancellation.
    fn exp_m1(self) -> Self;
}
80
81/// f64 implementation of CustomNumeric
82impl CustomNumeric for f64 {
83    fn from_f64_unchecked(x: f64) -> Self {
84        x
85    }
86
87    fn convert_from<U: CustomNumeric + 'static>(value: U) -> Self {
88        // Use match to optimize conversion based on the source type
89        // Note: Using TypeId for compile-time optimization, but falling back to safe conversion
90        match std::any::TypeId::of::<U>() {
91            // For f64 to f64, this is just a copy (no conversion needed)
92            id if id == std::any::TypeId::of::<f64>() => {
93                // Safe: f64 to f64 conversion
94                // This is a no-op for f64
95                value.to_f64()
96            }
97            // For Df64 to f64, use the conversion method
98            id if id == std::any::TypeId::of::<Df64>() => {
99                // Safe: Df64 to f64 conversion
100                value.to_f64()
101            }
102            // Fallback: convert via f64 for unknown types
103            _ => value.to_f64(),
104        }
105    }
106
107    fn to_f64(self) -> f64 {
108        self
109    }
110
111    fn epsilon() -> Self {
112        f64::EPSILON
113    }
114
115    fn pi() -> Self {
116        std::f64::consts::PI
117    }
118
119    fn max(self, other: Self) -> Self {
120        self.max(other)
121    }
122
123    fn min(self, other: Self) -> Self {
124        self.min(other)
125    }
126
127    fn is_valid(&self) -> bool {
128        num_traits::Float::is_finite(*self)
129    }
130
131    fn abs_as_same_type(self) -> Self {
132        Self::convert_from(self.abs())
133    }
134
135    fn exp_m1(self) -> Self {
136        expm1(self)
137    }
138}
139
140/// Df64 implementation of CustomNumeric
141impl CustomNumeric for Df64 {
142    fn from_f64_unchecked(x: f64) -> Self {
143        Df64::from(x)
144    }
145
146    fn convert_from<U: CustomNumeric + 'static>(value: U) -> Self {
147        // Use match to optimize conversion based on the source type
148        // Note: Using TypeId for compile-time optimization, but falling back to safe conversion
149        match std::any::TypeId::of::<U>() {
150            // For f64 to Df64, use the conversion method
151            id if id == std::any::TypeId::of::<f64>() => {
152                // Safe: f64 to Df64 conversion
153                let f64_value = value.to_f64();
154                Self::from_f64_unchecked(f64_value)
155            }
156            // For Df64 to Df64, this is just a copy (no conversion needed)
157            id if id == std::any::TypeId::of::<Df64>() => {
158                // Safe: Df64 to Df64 conversion (copy)
159                // We know U is Df64 from TypeId check, and Df64 implements Copy
160                // so we can safely copy the value without losing precision
161                // Note: We can't just return `value` directly because Rust's type system
162                // doesn't know that U == Df64 at compile time, even though we've verified
163                // it at runtime with TypeId. So we need unsafe transmute_copy.
164                unsafe {
165                    // Safety: We've verified via TypeId that U == Df64, and Df64 is Copy
166                    // so this is a no-op copy that preserves all precision (hi and lo parts)
167                    std::mem::transmute_copy(&value)
168                }
169            }
170            // Fallback: convert via f64 for unknown types
171            _ => Self::from_f64_unchecked(value.to_f64()),
172        }
173    }
174
175    fn to_f64(self) -> f64 {
176        // Use hi() and lo() methods for better precision
177        let hi = self.hi();
178        let lo = self.lo();
179        hi + lo
180    }
181
182    fn epsilon() -> Self {
183        // Df64::EPSILON = f64::EPSILON * f64::EPSILON / 2.0
184        let epsilon_value = f64::EPSILON * f64::EPSILON / 2.0;
185        Df64::from(epsilon_value)
186    }
187
188    fn pi() -> Self {
189        Df64::from(std::f64::consts::PI)
190    }
191
192    fn max(self, other: Self) -> Self {
193        if self > other { self } else { other }
194    }
195
196    fn min(self, other: Self) -> Self {
197        if self < other { self } else { other }
198    }
199
200    fn is_valid(&self) -> bool {
201        let value = *self;
202        f64::from(value).is_finite() && !f64::from(value).is_nan()
203    }
204
205    fn abs_as_same_type(self) -> Self {
206        Self::convert_from(self.abs())
207    }
208
209    fn exp_m1(self) -> Self {
210        // Use ComplexField::exp_m1() for Df64
211        ComplexField::exp_m1(self)
212    }
213}
214
// Note: no ndarray `ScalarOperand` impls are written here. The orphan rule forbids
// implementing that foreign trait for types defined outside this crate. ndarray itself
// already implements it for standard numeric types such as f64.
// NOTE(review): confirm whether Df64 still needs such an impl after the move to mdarray.

// Note: the Df64ArrayOps trait and its impl were removed during the migration from
// ndarray to mdarray. Array operations should now use mdarray Tensor methods directly.
221
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_f64_custom_numeric() {
        let x = 1.5_f64;
        let y = -2.0_f64;

        // Test basic operations
        assert_eq!(x.abs(), 1.5);
        assert_eq!(y.abs(), 2.0);

        // Test mathematical functions
        let cos_x = x.cos();
        assert!(f64::from(cos_x).is_finite());

        let sqrt_x = x.sqrt();
        assert!(f64::from(sqrt_x).is_finite());

        // Test conversion
        assert_eq!(x.to_f64(), 1.5);
        assert_eq!(<f64 as CustomNumeric>::epsilon(), f64::EPSILON);
    }

    #[test]
    fn test_df64_custom_numeric() {
        let x = Df64::from_f64_unchecked(1.5);
        let y = Df64::from_f64_unchecked(-2.0);

        // Test basic operations
        assert_eq!(x.abs(), Df64::from_f64_unchecked(1.5));
        assert_eq!(y.abs(), Df64::from_f64_unchecked(2.0));

        // Test mathematical functions
        let cos_x = x.cos();
        assert!(f64::from(cos_x).is_finite());

        let sqrt_x = x.sqrt();
        assert!(f64::from(sqrt_x).is_finite());

        // Test conversion
        let x_f64 = x.to_f64();
        assert!((x_f64 - 1.5).abs() < 1e-15);

        // Test epsilon
        let eps = Df64::epsilon();
        assert!(eps > Df64::from_f64_unchecked(0.0));
        assert!(eps < Df64::from_f64_unchecked(1.0));
    }

    #[test]
    fn test_precision_comparison() {
        // Test that Df64 provides higher precision than f64
        let pi_f64 = std::f64::consts::PI;
        let pi_tf = Df64::from_f64_unchecked(pi_f64);

        // Both should be finite
        assert!(pi_f64.is_finite());
        assert!(f64::from(pi_tf).is_finite());

        // Df64 should convert back to f64 with minimal loss
        let pi_back = pi_tf.to_f64();
        assert!((pi_back - pi_f64).abs() < f64::EPSILON);
    }

    /// Same-type `convert_from` must be an exact copy: the low word of a Df64,
    /// which a round trip through f64 would lose, has to survive.
    #[test]
    fn test_convert_from_preserves_df64_precision() {
        let one = Df64::from_f64_unchecked(1.0);
        let x = one + Df64::from_f64_unchecked(1e-20);

        let copied = <Df64 as CustomNumeric>::convert_from(x);
        assert_eq!(copied, x);

        // The 1e-20 part is below f64 resolution at 1.0, so it lives in the low word.
        let residual = (copied - one).to_f64();
        assert!((residual - 1e-20).abs() < 1e-30);

        // Cross-type conversions go through f64.
        assert_eq!(
            <f64 as CustomNumeric>::convert_from(Df64::from_f64_unchecked(2.5)),
            2.5
        );
        assert_eq!(
            <Df64 as CustomNumeric>::convert_from(2.5_f64),
            Df64::from_f64_unchecked(2.5)
        );
    }

    #[test]
    fn test_is_valid_rejects_non_finite() {
        assert!(<f64 as CustomNumeric>::is_valid(&1.0));
        assert!(!<f64 as CustomNumeric>::is_valid(&f64::NAN));
        assert!(!<f64 as CustomNumeric>::is_valid(&f64::INFINITY));
        assert!(!<f64 as CustomNumeric>::is_valid(&f64::NEG_INFINITY));

        assert!(<Df64 as CustomNumeric>::is_valid(&Df64::from_f64_unchecked(1.0)));
        assert!(!<Df64 as CustomNumeric>::is_valid(&Df64::from_f64_unchecked(f64::NAN)));
        assert!(!<Df64 as CustomNumeric>::is_valid(&Df64::from_f64_unchecked(f64::INFINITY)));
    }

    #[test]
    fn test_max_min() {
        assert_eq!(<f64 as CustomNumeric>::max(1.0, 2.0), 2.0);
        assert_eq!(<f64 as CustomNumeric>::min(1.0, 2.0), 1.0);

        let a = Df64::from_f64_unchecked(1.0);
        let b = Df64::from_f64_unchecked(2.0);
        assert_eq!(<Df64 as CustomNumeric>::max(a, b), b);
        assert_eq!(<Df64 as CustomNumeric>::min(a, b), a);
    }

    #[test]
    fn test_exp_m1_small_argument() {
        // For x = 1e-10, a naive exp(x) - 1 in f64 has a relative error near 1e-6;
        // a proper expm1 stays accurate to well below 1e-9.
        let x = 1e-10_f64;
        let y = <f64 as CustomNumeric>::exp_m1(x);
        assert!((y / x - 1.0).abs() < 1e-9);

        let y_dd = <Df64 as CustomNumeric>::exp_m1(Df64::from_f64_unchecked(x)).to_f64();
        assert!((y_dd / x - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_pi_rounds_to_f64_pi() {
        assert_eq!(<f64 as CustomNumeric>::pi(), std::f64::consts::PI);
        assert_eq!(<Df64 as CustomNumeric>::pi().to_f64(), std::f64::consts::PI);
    }
}