Skip to main content

sparse_ir/sve/
types.rs

//! Type definitions and the precision/SVD-strategy selection helper for SVE computations.
2
3use simba::scalar::ComplexField;
4
/// Working precision used by SVE computations.
///
/// The discriminants are part of the C API: each one equals the matching
/// `SPIR_TWORK_*` constant defined in `sparseir.h`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TworkType {
    /// Double precision (64-bit floats); equals `SPIR_TWORK_FLOAT64`.
    Float64 = 0,
    /// Extended precision (128-bit double-double); equals `SPIR_TWORK_FLOAT64X2`.
    Float64X2 = 1,
    /// Pick the precision automatically from the requested epsilon;
    /// equals `SPIR_TWORK_AUTO`.
    Auto = -1,
}
17
/// Strategy used when computing the SVD.
///
/// The discriminants are part of the C API: each one equals the matching
/// `SPIR_SVDSTRAT_*` constant defined in `sparseir.h`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SVDStrategy {
    /// Fast computation; equals `SPIR_SVDSTRAT_FAST`.
    Fast = 0,
    /// Accurate computation; equals `SPIR_SVDSTRAT_ACCURATE`.
    Accurate = 1,
    /// Pick the strategy automatically; equals `SPIR_SVDSTRAT_AUTO`.
    Auto = -1,
}
30
31/// Determine safe epsilon and working precision
32///
33/// This function determines the safe epsilon value based on the working precision,
34/// and automatically selects the working precision if TworkType::Auto is specified.
35///
36/// # Arguments
37///
38/// * `epsilon` - Required accuracy (must be non-negative)
39/// * `twork` - Working precision type (Auto for automatic selection)
40/// * `svd_strategy` - SVD computation strategy (Auto for automatic selection)
41///
42/// # Returns
43///
44/// Tuple of (safe_epsilon, actual_twork, actual_svd_strategy)
45///
46/// # Panics
47///
48/// Panics if epsilon is negative
49pub fn safe_epsilon(
50    epsilon: f64,
51    twork: TworkType,
52    svd_strategy: SVDStrategy,
53) -> (f64, TworkType, SVDStrategy) {
54    // Check for negative epsilon (following C++ implementation)
55    if epsilon < 0.0 {
56        panic!("eps_required must be non-negative");
57    }
58
59    // First, choose the working dtype based on the eps required
60    let twork_actual = match twork {
61        TworkType::Auto => {
62            if epsilon.is_nan() || epsilon < 1e-8 {
63                TworkType::Float64X2 // MAX_DTYPE equivalent
64            } else {
65                TworkType::Float64
66            }
67        }
68        other => other,
69    };
70
71    // Next, work out the actual epsilon.
72    // The precision floor is the smallest epsilon achievable with the chosen
73    // working type.  The returned epsilon is the *larger* of the user's
74    // request and this floor so that (a) we never promise more accuracy than
75    // the arithmetic can deliver and (b) a user who asks for *less* accuracy
76    // actually gets what they asked for.
77    let precision_floor = match twork_actual {
78        TworkType::Float64 => {
79            // This is technically a bit too low (the true value is about 1.5e-8),
80            // but it's not too far off and easier to remember for the user.
81            1e-8
82        }
83        TworkType::Float64X2 => {
84            // sqrt(Df64 epsilon) ≈ sqrt(2.465e-32) ≈ 1.57e-16
85            use crate::numeric::CustomNumeric;
86            crate::Df64::epsilon().sqrt().to_f64()
87        }
88        _ => 1e-8,
89    };
90    let safe_eps = if epsilon.is_nan() {
91        precision_floor
92    } else {
93        epsilon.max(precision_floor)
94    };
95
96    // Work out the SVD strategy to be used
97    let svd_strategy_actual = match svd_strategy {
98        SVDStrategy::Auto => {
99            if !epsilon.is_nan() && epsilon < safe_eps {
100                // TODO: Add warning output like C++
101                SVDStrategy::Accurate
102            } else {
103                SVDStrategy::Fast
104            }
105        }
106        other => other,
107    };
108
109    (safe_eps, twork_actual, svd_strategy_actual)
110}
111
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_safe_epsilon_auto_float64() {
        // epsilon=1e-7 > floor=1e-8 → safe_eps should honour the user's request
        let (safe_eps, twork, _) = safe_epsilon(1e-7, TworkType::Auto, SVDStrategy::Auto);
        assert_eq!(twork, TworkType::Float64);
        assert_eq!(safe_eps, 1e-7);
    }

    #[test]
    fn test_safe_epsilon_auto_float64x2() {
        // epsilon=1e-10 > floor≈1.57e-16 → safe_eps should be the user's epsilon
        let (safe_eps, twork, _) = safe_epsilon(1e-10, TworkType::Auto, SVDStrategy::Auto);
        assert_eq!(twork, TworkType::Float64X2);
        assert_eq!(safe_eps, 1e-10);
    }

    #[test]
    fn test_safe_epsilon_explicit_precision() {
        // epsilon=1e-7 > floor≈1.57e-16 → safe_eps should honour the user's epsilon
        let (safe_eps, twork, _) = safe_epsilon(1e-7, TworkType::Float64X2, SVDStrategy::Auto);
        assert_eq!(twork, TworkType::Float64X2);
        assert_eq!(safe_eps, 1e-7);
    }

    #[test]
    fn test_safe_epsilon_clamped_to_float64_floor() {
        // epsilon=1e-10 < floor=1e-8 with explicit Float64 → clamped up to the
        // floor, and asking for more than the floor selects the accurate SVD.
        let (safe_eps, twork, strategy) =
            safe_epsilon(1e-10, TworkType::Float64, SVDStrategy::Auto);
        assert_eq!(twork, TworkType::Float64);
        assert_eq!(safe_eps, 1e-8);
        assert_eq!(strategy, SVDStrategy::Accurate);
    }

    #[test]
    fn test_safe_epsilon_zero() {
        // Zero is non-negative: Auto picks extended precision, the result is
        // the (tiny, positive) extended floor, and the accurate SVD is used.
        let (safe_eps, twork, strategy) = safe_epsilon(0.0, TworkType::Auto, SVDStrategy::Auto);
        assert_eq!(twork, TworkType::Float64X2);
        assert!(safe_eps > 0.0 && safe_eps < 1e-15);
        assert_eq!(strategy, SVDStrategy::Accurate);
    }

    #[test]
    fn test_safe_epsilon_nan_auto() {
        // NaN passes the sign check, selects extended precision under Auto,
        // and yields the precision floor with the fast SVD strategy.
        let (safe_eps, twork, strategy) =
            safe_epsilon(f64::NAN, TworkType::Auto, SVDStrategy::Auto);
        assert_eq!(twork, TworkType::Float64X2);
        assert!(!safe_eps.is_nan());
        assert!(safe_eps > 0.0 && safe_eps < 1e-15);
        assert_eq!(strategy, SVDStrategy::Fast);
    }

    #[test]
    fn test_safe_epsilon_nan_explicit_float64() {
        // With an explicit Float64 working type, NaN maps to the 1e-8 floor.
        let (safe_eps, twork, strategy) =
            safe_epsilon(f64::NAN, TworkType::Float64, SVDStrategy::Auto);
        assert_eq!(twork, TworkType::Float64);
        assert_eq!(safe_eps, 1e-8);
        assert_eq!(strategy, SVDStrategy::Fast);
    }

    #[test]
    fn test_svd_strategy_auto_accurate() {
        // epsilon = 1e-20 < 1.57e-16 (safe_eps for Float64X2) → Accurate
        let (_, _, strategy) = safe_epsilon(1e-20, TworkType::Auto, SVDStrategy::Auto);
        assert_eq!(strategy, SVDStrategy::Accurate);
    }

    #[test]
    fn test_svd_strategy_auto_fast() {
        let (_, _, strategy) = safe_epsilon(1e-7, TworkType::Auto, SVDStrategy::Auto);
        assert_eq!(strategy, SVDStrategy::Fast);
    }

    #[test]
    fn test_explicit_svd_strategy_is_respected() {
        // Explicit strategies are returned unchanged regardless of epsilon.
        let (_, _, strategy) = safe_epsilon(1e-20, TworkType::Auto, SVDStrategy::Fast);
        assert_eq!(strategy, SVDStrategy::Fast);
        let (_, _, strategy) = safe_epsilon(1e-7, TworkType::Auto, SVDStrategy::Accurate);
        assert_eq!(strategy, SVDStrategy::Accurate);
    }

    #[test]
    #[should_panic(expected = "eps_required must be non-negative")]
    fn test_negative_epsilon_panics() {
        safe_epsilon(-1.0, TworkType::Auto, SVDStrategy::Auto);
    }
}