tensorlogic_scirs_backend/
fallback.rs1use crate::error::{NumericalErrorKind, TlBackendError, TlBackendResult};
7use scirs2_core::ndarray::ArrayD;
8
/// Policy describing how non-finite values (NaN, `+inf`, `-inf`) in a tensor
/// are handled, plus optional bounds applied afterwards.
///
/// Consumed by [`sanitize_tensor`]. Start from `Default::default()`,
/// [`FallbackConfig::strict`] or [`FallbackConfig::permissive`] and refine it
/// with the `with_*` builder methods.
#[derive(Debug, Clone, PartialEq)]
pub struct FallbackConfig {
    /// Value substituted for NaN elements when `fail_on_nan` is `false`.
    pub nan_replacement: f64,
    /// Value substituted for `+inf` elements when `fail_on_inf` is `false`.
    pub pos_inf_replacement: f64,
    /// Value substituted for `-inf` elements when `fail_on_inf` is `false`.
    pub neg_inf_replacement: f64,
    /// When `true`, a NaN element is reported as an error instead of replaced.
    pub fail_on_nan: bool,
    /// When `true`, an infinite element is reported as an error instead of replaced.
    pub fail_on_inf: bool,
    /// Optional lower bound; smaller values are raised to it.
    pub min_value: Option<f64>,
    /// Optional upper bound; larger values are lowered to it.
    pub max_value: Option<f64>,
}
27
28impl Default for FallbackConfig {
29 fn default() -> Self {
30 Self {
31 nan_replacement: 0.0,
32 pos_inf_replacement: 1e10,
33 neg_inf_replacement: -1e10,
34 fail_on_nan: false,
35 fail_on_inf: false,
36 min_value: None,
37 max_value: None,
38 }
39 }
40}
41
42impl FallbackConfig {
43 pub fn strict() -> Self {
45 Self {
46 fail_on_nan: true,
47 fail_on_inf: true,
48 ..Default::default()
49 }
50 }
51
52 pub fn permissive() -> Self {
54 Self {
55 fail_on_nan: false,
56 fail_on_inf: false,
57 ..Default::default()
58 }
59 }
60
61 pub fn with_nan_replacement(mut self, value: f64) -> Self {
63 self.nan_replacement = value;
64 self
65 }
66
67 pub fn with_inf_replacement(mut self, pos: f64, neg: f64) -> Self {
69 self.pos_inf_replacement = pos;
70 self.neg_inf_replacement = neg;
71 self
72 }
73
74 pub fn with_clamp(mut self, min: f64, max: f64) -> Self {
76 self.min_value = Some(min);
77 self.max_value = Some(max);
78 self
79 }
80}
81
82pub fn sanitize_tensor(
84 tensor: &ArrayD<f64>,
85 config: &FallbackConfig,
86 location: &str,
87) -> TlBackendResult<ArrayD<f64>> {
88 let mut result = tensor.clone();
89
90 for value in result.iter_mut() {
92 if value.is_nan() {
93 if config.fail_on_nan {
94 return Err(TlBackendError::numerical(NumericalErrorKind::NaN, location));
95 }
96 *value = config.nan_replacement;
97 } else if value.is_infinite() {
98 if config.fail_on_inf {
99 return Err(TlBackendError::numerical(
100 NumericalErrorKind::Infinity,
101 location,
102 ));
103 }
104 *value = if *value > 0.0 {
105 config.pos_inf_replacement
106 } else {
107 config.neg_inf_replacement
108 };
109 }
110
111 if let Some(min) = config.min_value {
113 if *value < min {
114 *value = min;
115 }
116 }
117 if let Some(max) = config.max_value {
118 if *value > max {
119 *value = max;
120 }
121 }
122 }
123
124 Ok(result)
125}
126
127pub fn contains_nan(tensor: &ArrayD<f64>) -> bool {
129 tensor.iter().any(|v| v.is_nan())
130}
131
132pub fn contains_inf(tensor: &ArrayD<f64>) -> bool {
134 tensor.iter().any(|v| v.is_infinite())
135}
136
137pub fn is_valid(tensor: &ArrayD<f64>) -> bool {
139 !contains_nan(tensor) && !contains_inf(tensor)
140}
141
142pub fn replace_nan(tensor: &ArrayD<f64>, replacement: f64) -> ArrayD<f64> {
144 tensor.mapv(|v| if v.is_nan() { replacement } else { v })
145}
146
147pub fn replace_inf(
149 tensor: &ArrayD<f64>,
150 pos_replacement: f64,
151 neg_replacement: f64,
152) -> ArrayD<f64> {
153 tensor.mapv(|v| {
154 if v.is_infinite() {
155 if v > 0.0 {
156 pos_replacement
157 } else {
158 neg_replacement
159 }
160 } else {
161 v
162 }
163 })
164}
165
166pub fn clamp(tensor: &ArrayD<f64>, min: f64, max: f64) -> ArrayD<f64> {
168 tensor.mapv(|v| v.max(min).min(max))
169}
170
171pub fn safe_div(a: &ArrayD<f64>, b: &ArrayD<f64>, epsilon: f64) -> ArrayD<f64> {
173 let b_safe = b.mapv(|v| {
174 if v.abs() < epsilon {
175 epsilon * v.signum()
176 } else {
177 v
178 }
179 });
180 a / &b_safe
181}
182
183pub fn safe_log(tensor: &ArrayD<f64>, epsilon: f64) -> ArrayD<f64> {
185 tensor.mapv(|v| (v.max(epsilon)).ln())
186}
187
188pub fn safe_sqrt(tensor: &ArrayD<f64>) -> ArrayD<f64> {
190 tensor.mapv(|v| v.max(0.0).sqrt())
191}
192
193pub fn detect_issues(tensor: &ArrayD<f64>) -> Vec<NumericalIssue> {
195 let mut issues = Vec::new();
196
197 let nan_count = tensor.iter().filter(|v| v.is_nan()).count();
198 if nan_count > 0 {
199 issues.push(NumericalIssue {
200 kind: NumericalErrorKind::NaN,
201 count: nan_count,
202 percentage: (nan_count as f64 / tensor.len() as f64) * 100.0,
203 });
204 }
205
206 let inf_count = tensor.iter().filter(|v| v.is_infinite()).count();
207 if inf_count > 0 {
208 issues.push(NumericalIssue {
209 kind: NumericalErrorKind::Infinity,
210 count: inf_count,
211 percentage: (inf_count as f64 / tensor.len() as f64) * 100.0,
212 });
213 }
214
215 let large_count = tensor
217 .iter()
218 .filter(|v| v.abs() > 1e100 && v.is_finite())
219 .count();
220 if large_count > 0 {
221 issues.push(NumericalIssue {
222 kind: NumericalErrorKind::Overflow,
223 count: large_count,
224 percentage: (large_count as f64 / tensor.len() as f64) * 100.0,
225 });
226 }
227
228 let small_count = tensor
230 .iter()
231 .filter(|v| v.abs() < 1e-100 && **v != 0.0)
232 .count();
233 if small_count > 0 {
234 issues.push(NumericalIssue {
235 kind: NumericalErrorKind::Underflow,
236 count: small_count,
237 percentage: (small_count as f64 / tensor.len() as f64) * 100.0,
238 });
239 }
240
241 issues
242}
243
244#[derive(Debug, Clone)]
246pub struct NumericalIssue {
247 pub kind: NumericalErrorKind,
249 pub count: usize,
251 pub percentage: f64,
253}
254
255impl std::fmt::Display for NumericalIssue {
256 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
257 write!(
258 f,
259 "{:?}: {} values ({:.2}%)",
260 self.kind, self.count, self.percentage
261 )
262 }
263}
264
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_contains_nan() {
        let valid = array![1.0, 2.0, 3.0].into_dyn();
        assert!(!contains_nan(&valid));

        let invalid = array![1.0, f64::NAN, 3.0].into_dyn();
        assert!(contains_nan(&invalid));
    }

    #[test]
    fn test_contains_inf() {
        let valid = array![1.0, 2.0, 3.0].into_dyn();
        assert!(!contains_inf(&valid));

        let invalid = array![1.0, f64::INFINITY, 3.0].into_dyn();
        assert!(contains_inf(&invalid));

        let negative = array![f64::NEG_INFINITY].into_dyn();
        assert!(contains_inf(&negative));
    }

    #[test]
    fn test_is_valid() {
        let valid = array![1.0, 2.0, 3.0].into_dyn();
        assert!(is_valid(&valid));

        let nan_tensor = array![1.0, f64::NAN, 3.0].into_dyn();
        assert!(!is_valid(&nan_tensor));

        let inf_tensor = array![1.0, f64::INFINITY, 3.0].into_dyn();
        assert!(!is_valid(&inf_tensor));

        // An empty tensor has no invalid elements.
        let empty = ArrayD::<f64>::zeros(vec![0]);
        assert!(is_valid(&empty));
    }

    #[test]
    fn test_replace_nan() {
        let tensor = array![1.0, f64::NAN, 3.0, f64::NAN].into_dyn();
        let result = replace_nan(&tensor, 0.0);

        assert_eq!(result[[0]], 1.0);
        assert_eq!(result[[1]], 0.0);
        assert_eq!(result[[2]], 3.0);
        assert_eq!(result[[3]], 0.0);
    }

    #[test]
    fn test_replace_inf() {
        let tensor = array![1.0, f64::INFINITY, -3.0, f64::NEG_INFINITY].into_dyn();
        let result = replace_inf(&tensor, 100.0, -100.0);

        assert_eq!(result[[0]], 1.0);
        assert_eq!(result[[1]], 100.0);
        assert_eq!(result[[2]], -3.0);
        assert_eq!(result[[3]], -100.0);
    }

    #[test]
    fn test_clamp() {
        let tensor = array![-5.0, 0.0, 5.0, 10.0].into_dyn();
        let result = clamp(&tensor, -2.0, 7.0);

        assert_eq!(result[[0]], -2.0);
        assert_eq!(result[[1]], 0.0);
        assert_eq!(result[[2]], 5.0);
        assert_eq!(result[[3]], 7.0);
    }

    #[test]
    fn test_clamp_maps_nan_to_min() {
        // `f64::max` ignores NaN, so NaN is pulled up to the lower bound.
        let tensor = array![f64::NAN].into_dyn();
        let result = clamp(&tensor, -1.0, 1.0);

        assert_eq!(result[[0]], -1.0);
    }

    #[test]
    fn test_sanitize_tensor_permissive() {
        let tensor = array![1.0, f64::NAN, f64::INFINITY, -3.0].into_dyn();
        let config = FallbackConfig::permissive();
        let result = sanitize_tensor(&tensor, &config, "test").unwrap();

        assert_eq!(result[[0]], 1.0);
        // NaN -> default nan_replacement.
        assert_eq!(result[[1]], 0.0);
        // +inf -> default pos_inf_replacement.
        assert_eq!(result[[2]], 1e10);
        assert_eq!(result[[3]], -3.0);
    }

    #[test]
    fn test_sanitize_tensor_neg_inf_and_clamp() {
        let tensor = array![-5.0, f64::NEG_INFINITY, f64::NAN, 1.0, 50.0].into_dyn();
        let config = FallbackConfig::permissive().with_clamp(-2.0, 2.0);
        let result = sanitize_tensor(&tensor, &config, "test").unwrap();

        assert_eq!(result[[0]], -2.0);
        // -inf -> -1e10, then raised to the lower bound.
        assert_eq!(result[[1]], -2.0);
        assert_eq!(result[[2]], 0.0);
        assert_eq!(result[[3]], 1.0);
        assert_eq!(result[[4]], 2.0);
    }

    #[test]
    fn test_sanitize_tensor_strict() {
        let tensor = array![1.0, f64::NAN, 3.0].into_dyn();
        let config = FallbackConfig::strict();
        let result = sanitize_tensor(&tensor, &config, "test");

        assert!(result.is_err());
    }

    #[test]
    fn test_sanitize_tensor_strict_rejects_inf() {
        let tensor = array![1.0, f64::NEG_INFINITY].into_dyn();
        let result = sanitize_tensor(&tensor, &FallbackConfig::strict(), "test");

        assert!(result.is_err());
    }

    #[test]
    fn test_sanitize_tensor_strict_accepts_finite() {
        let tensor = array![1.0, -2.0, 0.0].into_dyn();
        let result = sanitize_tensor(&tensor, &FallbackConfig::strict(), "test").unwrap();

        assert_eq!(result, tensor);
    }

    #[test]
    fn test_safe_div() {
        let a = array![1.0, 2.0, 3.0].into_dyn();
        let b = array![2.0, 0.0, 4.0].into_dyn();
        let result = safe_div(&a, &b, 1e-10);

        assert_eq!(result[[0]], 0.5);
        // Division by zero is replaced by division by epsilon.
        assert!(result[[1]].is_finite());
        assert_eq!(result[[2]], 0.75);
    }

    #[test]
    fn test_safe_div_preserves_sign_of_zero() {
        let a = array![1.0].into_dyn();
        let b = array![-0.0].into_dyn();
        let result = safe_div(&a, &b, 1e-10);

        assert!(result[[0]].is_finite());
        assert!(result[[0]] < 0.0);
    }

    #[test]
    fn test_safe_log() {
        let tensor = array![1.0, 0.0, 10.0].into_dyn();
        let result = safe_log(&tensor, 1e-10);

        assert_eq!(result[[0]], 0.0);
        // log(0) is replaced by log(epsilon).
        assert!(result[[1]].is_finite());
        assert!((result[[2]] - 10.0_f64.ln()).abs() < 1e-10);
    }

    #[test]
    fn test_safe_sqrt() {
        let tensor = array![4.0, -1.0, 9.0].into_dyn();
        let result = safe_sqrt(&tensor);

        assert_eq!(result[[0]], 2.0);
        // Negative inputs are floored at zero.
        assert_eq!(result[[1]], 0.0);
        assert_eq!(result[[2]], 3.0);
    }

    #[test]
    fn test_detect_issues() {
        let tensor =
            array![1.0, f64::NAN, 3.0, f64::INFINITY, 5.0, f64::NAN, 7.0, 8.0].into_dyn();

        let issues = detect_issues(&tensor);

        assert!(issues
            .iter()
            .any(|i| matches!(i.kind, NumericalErrorKind::NaN)));
        assert!(issues
            .iter()
            .any(|i| matches!(i.kind, NumericalErrorKind::Infinity)));

        let nan_issue = issues
            .iter()
            .find(|i| matches!(i.kind, NumericalErrorKind::NaN))
            .unwrap();
        assert_eq!(nan_issue.count, 2);
        assert!((nan_issue.percentage - 25.0).abs() < 1e-12);
        assert!(nan_issue.to_string().ends_with(": 2 values (25.00%)"));
    }

    #[test]
    fn test_detect_issues_overflow_underflow() {
        let tensor = array![1e200, -1e150, 1e-200, 0.0, 1.0].into_dyn();

        let issues = detect_issues(&tensor);

        assert_eq!(issues.len(), 2);
        assert!(matches!(issues[0].kind, NumericalErrorKind::Overflow));
        assert_eq!(issues[0].count, 2);
        assert!((issues[0].percentage - 40.0).abs() < 1e-9);
        assert!(matches!(issues[1].kind, NumericalErrorKind::Underflow));
        assert_eq!(issues[1].count, 1);
        assert!((issues[1].percentage - 20.0).abs() < 1e-9);
    }

    #[test]
    fn test_detect_issues_clean_and_empty() {
        let clean = array![0.0, 1.0, -2.5].into_dyn();
        assert!(detect_issues(&clean).is_empty());

        let empty = ArrayD::<f64>::zeros(vec![0]);
        assert!(detect_issues(&empty).is_empty());
    }

    #[test]
    fn test_fallback_config_builder() {
        let config = FallbackConfig::default()
            .with_nan_replacement(1.0)
            .with_inf_replacement(1e5, -1e5)
            .with_clamp(-100.0, 100.0);

        assert_eq!(config.nan_replacement, 1.0);
        assert_eq!(config.pos_inf_replacement, 1e5);
        assert_eq!(config.neg_inf_replacement, -1e5);
        assert_eq!(config.min_value, Some(-100.0));
        assert_eq!(config.max_value, Some(100.0));
    }
}