sparse_ir/sve/
types.rs

//! Type definitions for SVE computation.

3use simba::scalar::ComplexField;
4
/// Working precision used for SVE computations.
///
/// The discriminants mirror the C-API constants declared in `sparseir.h`:
///
/// | Variant     | Value | C constant             |
/// |-------------|-------|------------------------|
/// | `Auto`      | -1    | `SPIR_TWORK_AUTO`      |
/// | `Float64`   | 0     | `SPIR_TWORK_FLOAT64`   |
/// | `Float64X2` | 1     | `SPIR_TWORK_FLOAT64X2` |
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TworkType {
    /// Let the library pick the precision from the requested epsilon.
    Auto = -1,
    /// Plain double precision (64-bit).
    Float64 = 0,
    /// Extended precision (128-bit double-double).
    Float64X2 = 1,
}
17
/// Strategy used for the SVD step of an SVE computation.
///
/// The discriminants mirror the C-API constants declared in `sparseir.h`:
///
/// | Variant    | Value | C constant               |
/// |------------|-------|--------------------------|
/// | `Auto`     | -1    | `SPIR_SVDSTRAT_AUTO`     |
/// | `Fast`     | 0     | `SPIR_SVDSTRAT_FAST`     |
/// | `Accurate` | 1     | `SPIR_SVDSTRAT_ACCURATE` |
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SVDStrategy {
    /// Let the library pick the strategy.
    Auto = -1,
    /// Faster computation.
    Fast = 0,
    /// More accurate computation.
    Accurate = 1,
}
30
31/// Determine safe epsilon and working precision
32///
33/// This function determines the safe epsilon value based on the working precision,
34/// and automatically selects the working precision if TworkType::Auto is specified.
35///
36/// # Arguments
37///
38/// * `epsilon` - Required accuracy (must be non-negative)
39/// * `twork` - Working precision type (Auto for automatic selection)
40/// * `svd_strategy` - SVD computation strategy (Auto for automatic selection)
41///
42/// # Returns
43///
44/// Tuple of (safe_epsilon, actual_twork, actual_svd_strategy)
45///
46/// # Panics
47///
48/// Panics if epsilon is negative
49pub fn safe_epsilon(
50    epsilon: f64,
51    twork: TworkType,
52    svd_strategy: SVDStrategy,
53) -> (f64, TworkType, SVDStrategy) {
54    // Check for negative epsilon (following C++ implementation)
55    if epsilon < 0.0 {
56        panic!("eps_required must be non-negative");
57    }
58
59    // First, choose the working dtype based on the eps required
60    let twork_actual = match twork {
61        TworkType::Auto => {
62            if epsilon.is_nan() || epsilon < 1e-8 {
63                TworkType::Float64X2 // MAX_DTYPE equivalent
64            } else {
65                TworkType::Float64
66            }
67        }
68        other => other,
69    };
70
71    // Next, work out the actual epsilon
72    let safe_eps = match twork_actual {
73        TworkType::Float64 => {
74            // This is technically a bit too low (the true value is about 1.5e-8),
75            // but it's not too far off and easier to remember for the user.
76            1e-8
77        }
78        TworkType::Float64X2 => {
79            // sqrt(Df64 epsilon) ≈ sqrt(2.465e-32) ≈ 1.57e-16
80            use crate::numeric::CustomNumeric;
81            crate::Df64::epsilon().sqrt().to_f64()
82        }
83        _ => 1e-8,
84    };
85
86    // Work out the SVD strategy to be used
87    let svd_strategy_actual = match svd_strategy {
88        SVDStrategy::Auto => {
89            if !epsilon.is_nan() && epsilon < safe_eps {
90                // TODO: Add warning output like C++
91                SVDStrategy::Accurate
92            } else {
93                SVDStrategy::Fast
94            }
95        }
96        other => other,
97    };
98
99    (safe_eps, twork_actual, svd_strategy_actual)
100}
101
#[cfg(test)]
mod tests {
    use super::*;

    // sqrt(Df64 epsilon) ≈ 1.57e-16: the safe epsilon for double-double precision.
    const FLOAT64X2_SAFE_EPS: f64 = 1.5700924586837752e-16;

    fn assert_float64x2_safe_eps(safe_eps: f64) {
        assert!(
            (safe_eps - FLOAT64X2_SAFE_EPS).abs() < 1e-20,
            "unexpected Float64X2 safe epsilon: {:e}",
            safe_eps
        );
    }

    #[test]
    fn test_safe_epsilon_auto_float64() {
        let (safe_eps, twork, _) = safe_epsilon(1e-7, TworkType::Auto, SVDStrategy::Auto);
        assert_eq!(twork, TworkType::Float64);
        assert_eq!(safe_eps, 1e-8);
    }

    #[test]
    fn test_safe_epsilon_auto_threshold_boundary() {
        // Exactly 1e-8 is still representable in plain double precision.
        let (safe_eps, twork, strategy) =
            safe_epsilon(1e-8, TworkType::Auto, SVDStrategy::Auto);
        assert_eq!(twork, TworkType::Float64);
        assert_eq!(safe_eps, 1e-8);
        assert_eq!(strategy, SVDStrategy::Fast);
    }

    #[test]
    fn test_safe_epsilon_auto_float64x2() {
        let (safe_eps, twork, _) = safe_epsilon(1e-10, TworkType::Auto, SVDStrategy::Auto);
        assert_eq!(twork, TworkType::Float64X2);
        assert_float64x2_safe_eps(safe_eps);
    }

    #[test]
    fn test_safe_epsilon_explicit_precision() {
        let (safe_eps, twork, _) = safe_epsilon(1e-7, TworkType::Float64X2, SVDStrategy::Auto);
        assert_eq!(twork, TworkType::Float64X2);
        assert_float64x2_safe_eps(safe_eps);
    }

    #[test]
    fn test_safe_epsilon_nan_means_unspecified() {
        let (safe_eps, twork, strategy) =
            safe_epsilon(f64::NAN, TworkType::Auto, SVDStrategy::Auto);
        assert_eq!(twork, TworkType::Float64X2);
        assert_float64x2_safe_eps(safe_eps);
        assert_eq!(strategy, SVDStrategy::Fast);
    }

    #[test]
    fn test_safe_epsilon_zero_is_allowed() {
        let (safe_eps, twork, strategy) = safe_epsilon(0.0, TworkType::Auto, SVDStrategy::Auto);
        assert_eq!(twork, TworkType::Float64X2);
        assert_float64x2_safe_eps(safe_eps);
        assert_eq!(strategy, SVDStrategy::Accurate);
    }

    #[test]
    fn test_svd_strategy_auto_accurate() {
        // epsilon = 1e-20 < 1.57e-16 (safe_eps for Float64X2) → Accurate
        let (_, _, strategy) = safe_epsilon(1e-20, TworkType::Auto, SVDStrategy::Auto);
        assert_eq!(strategy, SVDStrategy::Accurate);
    }

    #[test]
    fn test_svd_strategy_auto_accurate_with_explicit_float64() {
        // epsilon = 1e-10 < 1e-8 (safe_eps for Float64) → Accurate
        let (safe_eps, twork, strategy) =
            safe_epsilon(1e-10, TworkType::Float64, SVDStrategy::Auto);
        assert_eq!(twork, TworkType::Float64);
        assert_eq!(safe_eps, 1e-8);
        assert_eq!(strategy, SVDStrategy::Accurate);
    }

    #[test]
    fn test_svd_strategy_auto_fast() {
        let (_, _, strategy) = safe_epsilon(1e-7, TworkType::Auto, SVDStrategy::Auto);
        assert_eq!(strategy, SVDStrategy::Fast);
    }

    #[test]
    fn test_explicit_svd_strategy_is_preserved() {
        let (_, _, strategy) = safe_epsilon(1e-20, TworkType::Auto, SVDStrategy::Fast);
        assert_eq!(strategy, SVDStrategy::Fast);
        let (_, _, strategy) = safe_epsilon(1e-7, TworkType::Auto, SVDStrategy::Accurate);
        assert_eq!(strategy, SVDStrategy::Accurate);
    }

    #[test]
    #[should_panic(expected = "eps_required must be non-negative")]
    fn test_negative_epsilon_panics() {
        safe_epsilon(-1.0, TworkType::Auto, SVDStrategy::Auto);
    }
}