sparse_ir/sve/types.rs
1//! Type definitions for SVE computation
2
3use simba::scalar::ComplexField;
4
/// Floating-point precision used internally while computing an SVE.
///
/// Each discriminant is identical to the corresponding `SPIR_TWORK_*`
/// constant of the C API (`sparseir.h`), so `variant as i32` yields the
/// matching C-API value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TworkType {
    /// Let the precision be chosen from the requested epsilon (`SPIR_TWORK_AUTO`).
    Auto = -1,
    /// Plain 64-bit double precision (`SPIR_TWORK_FLOAT64`).
    Float64 = 0,
    /// Extended double-double precision, 128 bits in total (`SPIR_TWORK_FLOAT64X2`).
    Float64X2 = 1,
}
17
/// Strategy used for the singular value decompositions inside an SVE.
///
/// Each discriminant is identical to the corresponding `SPIR_SVDSTRAT_*`
/// constant of the C API (`sparseir.h`), so `variant as i32` yields the
/// matching C-API value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SVDStrategy {
    /// Let the strategy be chosen automatically (`SPIR_SVDSTRAT_AUTO`).
    Auto = -1,
    /// Fast SVD computation (`SPIR_SVDSTRAT_FAST`).
    Fast = 0,
    /// Accurate SVD computation (`SPIR_SVDSTRAT_ACCURATE`).
    Accurate = 1,
}
30
31/// Determine safe epsilon and working precision
32///
33/// This function determines the safe epsilon value based on the working precision,
34/// and automatically selects the working precision if TworkType::Auto is specified.
35///
36/// # Arguments
37///
38/// * `epsilon` - Required accuracy (must be non-negative)
39/// * `twork` - Working precision type (Auto for automatic selection)
40/// * `svd_strategy` - SVD computation strategy (Auto for automatic selection)
41///
42/// # Returns
43///
44/// Tuple of (safe_epsilon, actual_twork, actual_svd_strategy)
45///
46/// # Panics
47///
48/// Panics if epsilon is negative
49pub fn safe_epsilon(
50 epsilon: f64,
51 twork: TworkType,
52 svd_strategy: SVDStrategy,
53) -> (f64, TworkType, SVDStrategy) {
54 // Check for negative epsilon (following C++ implementation)
55 if epsilon < 0.0 {
56 panic!("eps_required must be non-negative");
57 }
58
59 // First, choose the working dtype based on the eps required
60 let twork_actual = match twork {
61 TworkType::Auto => {
62 if epsilon.is_nan() || epsilon < 1e-8 {
63 TworkType::Float64X2 // MAX_DTYPE equivalent
64 } else {
65 TworkType::Float64
66 }
67 }
68 other => other,
69 };
70
71 // Next, work out the actual epsilon.
72 // The precision floor is the smallest epsilon achievable with the chosen
73 // working type. The returned epsilon is the *larger* of the user's
74 // request and this floor so that (a) we never promise more accuracy than
75 // the arithmetic can deliver and (b) a user who asks for *less* accuracy
76 // actually gets what they asked for.
77 let precision_floor = match twork_actual {
78 TworkType::Float64 => {
79 // This is technically a bit too low (the true value is about 1.5e-8),
80 // but it's not too far off and easier to remember for the user.
81 1e-8
82 }
83 TworkType::Float64X2 => {
84 // sqrt(Df64 epsilon) ≈ sqrt(2.465e-32) ≈ 1.57e-16
85 use crate::numeric::CustomNumeric;
86 crate::Df64::epsilon().sqrt().to_f64()
87 }
88 _ => 1e-8,
89 };
90 let safe_eps = if epsilon.is_nan() {
91 precision_floor
92 } else {
93 epsilon.max(precision_floor)
94 };
95
96 // Work out the SVD strategy to be used
97 let svd_strategy_actual = match svd_strategy {
98 SVDStrategy::Auto => {
99 if !epsilon.is_nan() && epsilon < safe_eps {
100 // TODO: Add warning output like C++
101 SVDStrategy::Accurate
102 } else {
103 SVDStrategy::Fast
104 }
105 }
106 other => other,
107 };
108
109 (safe_eps, twork_actual, svd_strategy_actual)
110}
111
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_safe_epsilon_auto_float64() {
        // epsilon=1e-7 > floor=1e-8 → safe_eps should honour the user's request
        let (safe_eps, twork, _) = safe_epsilon(1e-7, TworkType::Auto, SVDStrategy::Auto);
        assert_eq!(twork, TworkType::Float64);
        assert_eq!(safe_eps, 1e-7);
    }

    #[test]
    fn test_safe_epsilon_auto_float64x2() {
        // epsilon=1e-10 > floor≈1.57e-16 → safe_eps should be the user's epsilon
        let (safe_eps, twork, _) = safe_epsilon(1e-10, TworkType::Auto, SVDStrategy::Auto);
        assert_eq!(twork, TworkType::Float64X2);
        assert_eq!(safe_eps, 1e-10);
    }

    #[test]
    fn test_safe_epsilon_explicit_precision() {
        // epsilon=1e-7 > floor≈1.57e-16 → safe_eps should honour the user's epsilon
        let (safe_eps, twork, _) = safe_epsilon(1e-7, TworkType::Float64X2, SVDStrategy::Auto);
        assert_eq!(twork, TworkType::Float64X2);
        assert_eq!(safe_eps, 1e-7);
    }

    #[test]
    fn test_safe_epsilon_float64_clamps_to_floor() {
        // Explicit Float64 cannot deliver 1e-10: clamp to the 1e-8 floor and go Accurate
        let (safe_eps, twork, strategy) =
            safe_epsilon(1e-10, TworkType::Float64, SVDStrategy::Auto);
        assert_eq!(twork, TworkType::Float64);
        assert_eq!(safe_eps, 1e-8);
        assert_eq!(strategy, SVDStrategy::Accurate);
    }

    #[test]
    fn test_safe_epsilon_auto_threshold_boundary() {
        // epsilon == 1e-8 is not below the Auto threshold → Float64, no clamping, Fast
        let (safe_eps, twork, strategy) = safe_epsilon(1e-8, TworkType::Auto, SVDStrategy::Auto);
        assert_eq!(twork, TworkType::Float64);
        assert_eq!(safe_eps, 1e-8);
        assert_eq!(strategy, SVDStrategy::Fast);
    }

    #[test]
    fn test_safe_epsilon_nan_uses_floor() {
        // NaN → extended precision, safe_eps is the Float64X2 floor, strategy Fast
        let (safe_eps, twork, strategy) =
            safe_epsilon(f64::NAN, TworkType::Auto, SVDStrategy::Auto);
        assert_eq!(twork, TworkType::Float64X2);
        assert!(safe_eps.is_finite());
        assert!(safe_eps > 1e-17 && safe_eps < 1e-15, "unexpected floor {}", safe_eps);
        assert_eq!(strategy, SVDStrategy::Fast);
    }

    #[test]
    fn test_safe_epsilon_zero() {
        // Zero is allowed (non-negative) and is clamped up to the Float64X2 floor
        let (safe_eps, twork, strategy) = safe_epsilon(0.0, TworkType::Auto, SVDStrategy::Auto);
        assert_eq!(twork, TworkType::Float64X2);
        assert!(safe_eps > 0.0);
        assert_eq!(strategy, SVDStrategy::Accurate);
    }

    #[test]
    fn test_svd_strategy_auto_accurate() {
        // epsilon = 1e-20 < 1.57e-16 (safe_eps for Float64X2) → Accurate
        let (_, _, strategy) = safe_epsilon(1e-20, TworkType::Auto, SVDStrategy::Auto);
        assert_eq!(strategy, SVDStrategy::Accurate);
    }

    #[test]
    fn test_svd_strategy_auto_fast() {
        let (_, _, strategy) = safe_epsilon(1e-7, TworkType::Auto, SVDStrategy::Auto);
        assert_eq!(strategy, SVDStrategy::Fast);
    }

    #[test]
    fn test_explicit_svd_strategy_is_preserved() {
        // Explicit strategies are never overridden, whatever epsilon is
        let (_, _, strategy) = safe_epsilon(1e-7, TworkType::Auto, SVDStrategy::Accurate);
        assert_eq!(strategy, SVDStrategy::Accurate);

        let (safe_eps, twork, strategy) =
            safe_epsilon(1e-20, TworkType::Float64, SVDStrategy::Fast);
        assert_eq!(twork, TworkType::Float64);
        assert_eq!(safe_eps, 1e-8);
        assert_eq!(strategy, SVDStrategy::Fast);
    }

    #[test]
    #[should_panic(expected = "eps_required must be non-negative")]
    fn test_negative_epsilon_panics() {
        safe_epsilon(-1.0, TworkType::Auto, SVDStrategy::Auto);
    }
}