Skip to main content

tensorlogic_scirs_backend/
fallback.rs

1//! Fallback mechanisms for numerical stability.
2//!
3//! This module provides utilities to handle NaN, Inf, and other numerical
4//! issues gracefully during tensor operations.
5
6use crate::error::{NumericalErrorKind, TlBackendError, TlBackendResult};
7use scirs2_core::ndarray::ArrayD;
8
/// Policy describing how [`sanitize_tensor`] treats NaN, infinite, and out-of-range values.
///
/// For each non-finite category the policy either fails (`fail_on_*`) or substitutes the
/// matching replacement value; the optional bounds are applied afterwards as a clamp.
#[derive(Debug, Clone, PartialEq)]
pub struct FallbackConfig {
    /// Value substituted for NaN when `fail_on_nan` is `false`.
    pub nan_replacement: f64,
    /// Value substituted for `+inf` when `fail_on_inf` is `false`.
    pub pos_inf_replacement: f64,
    /// Value substituted for `-inf` when `fail_on_inf` is `false`.
    pub neg_inf_replacement: f64,
    /// If `true`, encountering NaN is an error instead of being replaced.
    pub fail_on_nan: bool,
    /// If `true`, encountering `±inf` is an error instead of being replaced.
    pub fail_on_inf: bool,
    /// Lower bound: values below this are raised to it (`None` = unbounded).
    pub min_value: Option<f64>,
    /// Upper bound: values above this are lowered to it (`None` = unbounded).
    pub max_value: Option<f64>,
}
27
28impl Default for FallbackConfig {
29    fn default() -> Self {
30        Self {
31            nan_replacement: 0.0,
32            pos_inf_replacement: 1e10,
33            neg_inf_replacement: -1e10,
34            fail_on_nan: false,
35            fail_on_inf: false,
36            min_value: None,
37            max_value: None,
38        }
39    }
40}
41
42impl FallbackConfig {
43    /// Create a strict config that fails on any numerical issue
44    pub fn strict() -> Self {
45        Self {
46            fail_on_nan: true,
47            fail_on_inf: true,
48            ..Default::default()
49        }
50    }
51
52    /// Create a permissive config that replaces all invalid values
53    pub fn permissive() -> Self {
54        Self {
55            fail_on_nan: false,
56            fail_on_inf: false,
57            ..Default::default()
58        }
59    }
60
61    /// Set NaN replacement value
62    pub fn with_nan_replacement(mut self, value: f64) -> Self {
63        self.nan_replacement = value;
64        self
65    }
66
67    /// Set infinity replacement values
68    pub fn with_inf_replacement(mut self, pos: f64, neg: f64) -> Self {
69        self.pos_inf_replacement = pos;
70        self.neg_inf_replacement = neg;
71        self
72    }
73
74    /// Set value clamping range
75    pub fn with_clamp(mut self, min: f64, max: f64) -> Self {
76        self.min_value = Some(min);
77        self.max_value = Some(max);
78        self
79    }
80}
81
82/// Check and potentially fix numerical issues in a tensor
83pub fn sanitize_tensor(
84    tensor: &ArrayD<f64>,
85    config: &FallbackConfig,
86    location: &str,
87) -> TlBackendResult<ArrayD<f64>> {
88    let mut result = tensor.clone();
89
90    // Check for NaN and Inf
91    for value in result.iter_mut() {
92        if value.is_nan() {
93            if config.fail_on_nan {
94                return Err(TlBackendError::numerical(NumericalErrorKind::NaN, location));
95            }
96            *value = config.nan_replacement;
97        } else if value.is_infinite() {
98            if config.fail_on_inf {
99                return Err(TlBackendError::numerical(
100                    NumericalErrorKind::Infinity,
101                    location,
102                ));
103            }
104            *value = if *value > 0.0 {
105                config.pos_inf_replacement
106            } else {
107                config.neg_inf_replacement
108            };
109        }
110
111        // Apply clamping if configured
112        if let Some(min) = config.min_value {
113            if *value < min {
114                *value = min;
115            }
116        }
117        if let Some(max) = config.max_value {
118            if *value > max {
119                *value = max;
120            }
121        }
122    }
123
124    Ok(result)
125}
126
127/// Check if a tensor contains any NaN values
128pub fn contains_nan(tensor: &ArrayD<f64>) -> bool {
129    tensor.iter().any(|v| v.is_nan())
130}
131
132/// Check if a tensor contains any infinite values
133pub fn contains_inf(tensor: &ArrayD<f64>) -> bool {
134    tensor.iter().any(|v| v.is_infinite())
135}
136
137/// Check if a tensor is numerically valid (no NaN or Inf)
138pub fn is_valid(tensor: &ArrayD<f64>) -> bool {
139    !contains_nan(tensor) && !contains_inf(tensor)
140}
141
142/// Replace NaN values with a specific value
143pub fn replace_nan(tensor: &ArrayD<f64>, replacement: f64) -> ArrayD<f64> {
144    tensor.mapv(|v| if v.is_nan() { replacement } else { v })
145}
146
147/// Replace infinite values with finite values
148pub fn replace_inf(
149    tensor: &ArrayD<f64>,
150    pos_replacement: f64,
151    neg_replacement: f64,
152) -> ArrayD<f64> {
153    tensor.mapv(|v| {
154        if v.is_infinite() {
155            if v > 0.0 {
156                pos_replacement
157            } else {
158                neg_replacement
159            }
160        } else {
161            v
162        }
163    })
164}
165
166/// Clamp tensor values to a range
167pub fn clamp(tensor: &ArrayD<f64>, min: f64, max: f64) -> ArrayD<f64> {
168    tensor.mapv(|v| v.max(min).min(max))
169}
170
171/// Safe division that avoids division by zero
172pub fn safe_div(a: &ArrayD<f64>, b: &ArrayD<f64>, epsilon: f64) -> ArrayD<f64> {
173    let b_safe = b.mapv(|v| {
174        if v.abs() < epsilon {
175            epsilon * v.signum()
176        } else {
177            v
178        }
179    });
180    a / &b_safe
181}
182
183/// Safe logarithm that avoids log(0)
184pub fn safe_log(tensor: &ArrayD<f64>, epsilon: f64) -> ArrayD<f64> {
185    tensor.mapv(|v| (v.max(epsilon)).ln())
186}
187
188/// Safe square root that avoids sqrt(negative)
189pub fn safe_sqrt(tensor: &ArrayD<f64>) -> ArrayD<f64> {
190    tensor.mapv(|v| v.max(0.0).sqrt())
191}
192
193/// Detect and report numerical issues in a tensor
194pub fn detect_issues(tensor: &ArrayD<f64>) -> Vec<NumericalIssue> {
195    let mut issues = Vec::new();
196
197    let nan_count = tensor.iter().filter(|v| v.is_nan()).count();
198    if nan_count > 0 {
199        issues.push(NumericalIssue {
200            kind: NumericalErrorKind::NaN,
201            count: nan_count,
202            percentage: (nan_count as f64 / tensor.len() as f64) * 100.0,
203        });
204    }
205
206    let inf_count = tensor.iter().filter(|v| v.is_infinite()).count();
207    if inf_count > 0 {
208        issues.push(NumericalIssue {
209            kind: NumericalErrorKind::Infinity,
210            count: inf_count,
211            percentage: (inf_count as f64 / tensor.len() as f64) * 100.0,
212        });
213    }
214
215    // Check for potential overflow (very large values)
216    let large_count = tensor
217        .iter()
218        .filter(|v| v.abs() > 1e100 && v.is_finite())
219        .count();
220    if large_count > 0 {
221        issues.push(NumericalIssue {
222            kind: NumericalErrorKind::Overflow,
223            count: large_count,
224            percentage: (large_count as f64 / tensor.len() as f64) * 100.0,
225        });
226    }
227
228    // Check for potential underflow (very small values)
229    let small_count = tensor
230        .iter()
231        .filter(|v| v.abs() < 1e-100 && **v != 0.0)
232        .count();
233    if small_count > 0 {
234        issues.push(NumericalIssue {
235            kind: NumericalErrorKind::Underflow,
236            count: small_count,
237            percentage: (small_count as f64 / tensor.len() as f64) * 100.0,
238        });
239    }
240
241    issues
242}
243
244/// Description of a numerical issue found in a tensor
245#[derive(Debug, Clone)]
246pub struct NumericalIssue {
247    /// Type of issue
248    pub kind: NumericalErrorKind,
249    /// Number of affected values
250    pub count: usize,
251    /// Percentage of tensor affected
252    pub percentage: f64,
253}
254
255impl std::fmt::Display for NumericalIssue {
256    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
257        write!(
258            f,
259            "{:?}: {} values ({:.2}%)",
260            self.kind, self.count, self.percentage
261        )
262    }
263}
264
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Flatten a tensor into a `Vec` in iteration order for whole-array comparisons.
    fn values(t: &ArrayD<f64>) -> Vec<f64> {
        t.iter().copied().collect()
    }

    #[test]
    fn test_contains_nan() {
        assert!(!contains_nan(&array![1.0, 2.0, 3.0].into_dyn()));
        assert!(contains_nan(&array![1.0, f64::NAN, 3.0].into_dyn()));
    }

    #[test]
    fn test_contains_inf() {
        assert!(!contains_inf(&array![1.0, 2.0, 3.0].into_dyn()));
        assert!(contains_inf(&array![1.0, f64::INFINITY, 3.0].into_dyn()));
    }

    #[test]
    fn test_is_valid() {
        assert!(is_valid(&array![1.0, 2.0, 3.0].into_dyn()));
        assert!(!is_valid(&array![1.0, f64::NAN, 3.0].into_dyn()));
        assert!(!is_valid(&array![1.0, f64::INFINITY, 3.0].into_dyn()));
    }

    #[test]
    fn test_replace_nan() {
        let input = array![1.0, f64::NAN, 3.0, f64::NAN].into_dyn();
        assert_eq!(values(&replace_nan(&input, 0.0)), vec![1.0, 0.0, 3.0, 0.0]);
    }

    #[test]
    fn test_replace_inf() {
        let input = array![1.0, f64::INFINITY, -3.0, f64::NEG_INFINITY].into_dyn();
        assert_eq!(
            values(&replace_inf(&input, 100.0, -100.0)),
            vec![1.0, 100.0, -3.0, -100.0]
        );
    }

    #[test]
    fn test_clamp() {
        let input = array![-5.0, 0.0, 5.0, 10.0].into_dyn();
        assert_eq!(values(&clamp(&input, -2.0, 7.0)), vec![-2.0, 0.0, 5.0, 7.0]);
    }

    #[test]
    fn test_sanitize_tensor_permissive() {
        let input = array![1.0, f64::NAN, f64::INFINITY, -3.0].into_dyn();
        let cleaned = sanitize_tensor(&input, &FallbackConfig::permissive(), "test")
            .expect("permissive config should not fail");
        // NaN -> 0.0 and +inf -> 1e10 under the default replacements.
        assert_eq!(values(&cleaned), vec![1.0, 0.0, 1e10, -3.0]);
    }

    #[test]
    fn test_sanitize_tensor_strict() {
        let input = array![1.0, f64::NAN, 3.0].into_dyn();
        assert!(sanitize_tensor(&input, &FallbackConfig::strict(), "test").is_err());
    }

    #[test]
    fn test_safe_div() {
        let numer = array![1.0, 2.0, 3.0].into_dyn();
        let denom = array![2.0, 0.0, 4.0].into_dyn();
        let quotient = safe_div(&numer, &denom, 1e-10);

        assert_eq!(quotient[[0]], 0.5);
        assert!(quotient[[1]].is_finite(), "division by zero must not yield inf");
        assert_eq!(quotient[[2]], 0.75);
    }

    #[test]
    fn test_safe_log() {
        let logs = safe_log(&array![1.0, 0.0, 10.0].into_dyn(), 1e-10);

        assert_eq!(logs[[0]], 0.0);
        assert!(logs[[1]].is_finite(), "log(0) must not yield -inf");
        assert!((logs[[2]] - 10.0_f64.ln()).abs() < 1e-10);
    }

    #[test]
    fn test_safe_sqrt() {
        let roots = safe_sqrt(&array![4.0, -1.0, 9.0].into_dyn());
        // The negative input is treated as 0.
        assert_eq!(values(&roots), vec![2.0, 0.0, 3.0]);
    }

    #[test]
    fn test_detect_issues() {
        let input =
            array![1.0, f64::NAN, 3.0, f64::INFINITY, 5.0, f64::NAN, 7.0, 8.0].into_dyn();
        let issues = detect_issues(&input);

        let nan_issue = issues
            .iter()
            .find(|issue| matches!(issue.kind, NumericalErrorKind::NaN))
            .expect("NaN issue should be reported");
        assert_eq!(nan_issue.count, 2);

        assert!(issues
            .iter()
            .any(|issue| matches!(issue.kind, NumericalErrorKind::Infinity)));
    }

    #[test]
    fn test_fallback_config_builder() {
        let cfg = FallbackConfig::default()
            .with_nan_replacement(1.0)
            .with_inf_replacement(1e5, -1e5)
            .with_clamp(-100.0, 100.0);

        assert_eq!(cfg.nan_replacement, 1.0);
        assert_eq!(cfg.pos_inf_replacement, 1e5);
        assert_eq!(cfg.neg_inf_replacement, -1e5);
        assert_eq!(cfg.min_value, Some(-100.0));
        assert_eq!(cfg.max_value, Some(100.0));
    }
}