Skip to main content

torsh_tensor/
nan_inf_detection.rs

1//! NaN and Infinity Detection Utilities
2//!
3//! This module provides efficient utilities for detecting NaN (Not a Number) and
4//! infinite values in tensors, with optimized fast paths for clean data and detailed
5//! reporting capabilities.
6//!
7//! # Features
8//!
//! - **Fail-fast checks**: existence checks can stop at the first problematic value
10//! - **Detailed reporting**: Location-specific information about problematic values
//! - **SIMD flag (reserved)**: `NanInfConfig::use_simd` exists, but the current scan is element-by-element
12//! - **Configurable checking**: Enable/disable checks for performance-critical code
13//! - **Statistics collection**: Count and categorize different types of issues
14
15use crate::{Tensor, TensorElement};
16use std::fmt;
17use torsh_core::{dtype::FloatElement, error::Result};
18
/// Controls which kinds of non-finite values a scan looks for and how much
/// detail it records.
///
/// Start from [`NanInfConfig::default`] (every check enabled, counts only, full
/// scan) or from one of the preset constructors, then adjust fields as needed.
#[derive(Debug, Clone)]
pub struct NanInfConfig {
    /// Count and report `NaN` values.
    pub check_nan: bool,
    /// Count and report positive infinity.
    pub check_pos_inf: bool,
    /// Count and report negative infinity.
    pub check_neg_inf: bool,
    /// Also record an [`IssueLocation`] for every problematic value found.
    pub detailed_report: bool,
    /// Request for vectorized detection. Nothing in this module reads this
    /// flag at present; the scan always walks the elements one by one.
    pub use_simd: bool,
    /// Stop at the first problematic value instead of scanning the whole
    /// tensor. The resulting counts are then partial.
    pub fail_fast: bool,
}

impl Default for NanInfConfig {
    /// All three value checks on, SIMD requested, no location details, full scan.
    fn default() -> Self {
        NanInfConfig {
            fail_fast: false,
            use_simd: true,
            detailed_report: false,
            check_neg_inf: true,
            check_pos_inf: true,
            check_nan: true,
        }
    }
}

impl NanInfConfig {
    /// Preset that reports `NaN` only; both infinity checks are disabled.
    pub fn nan_only() -> Self {
        Self::with_checks(true, false, false)
    }

    /// Preset that reports `+Inf`/`-Inf` only; the `NaN` check is disabled.
    pub fn inf_only() -> Self {
        Self::with_checks(false, true, true)
    }

    /// Preset for a cheap existence check: stop at the first problem and skip
    /// location details.
    pub fn fast() -> Self {
        Self::with_reporting(false, true)
    }

    /// Preset that scans every element and records every problematic location.
    pub fn detailed() -> Self {
        Self::with_reporting(true, false)
    }

    /// Default settings with the three value checks replaced.
    fn with_checks(nan: bool, pos_inf: bool, neg_inf: bool) -> Self {
        Self {
            check_nan: nan,
            check_pos_inf: pos_inf,
            check_neg_inf: neg_inf,
            ..Self::default()
        }
    }

    /// Default settings with the reporting mode replaced.
    fn with_reporting(detailed_report: bool, fail_fast: bool) -> Self {
        Self {
            detailed_report,
            fail_fast,
            ..Self::default()
        }
    }
}
88
/// Tally of the non-finite values encountered by a NaN/Inf scan.
///
/// `total_elements` is the length of the scanned data, while the per-kind
/// counters only cover values that were actually examined. A `fail_fast` scan
/// may stop before the end, so the counters can be partial.
#[derive(Debug, Clone, Default)]
pub struct NanInfStats {
    /// How many `NaN` values were counted.
    pub nan_count: usize,
    /// How many `+Inf` values were counted.
    pub pos_inf_count: usize,
    /// How many `-Inf` values were counted.
    pub neg_inf_count: usize,
    /// Total problematic values counted (the scan keeps this equal to the sum
    /// of the three counters above).
    pub total_issues: usize,
    /// Number of elements in the scanned tensor.
    pub total_elements: usize,
}

impl NanInfStats {
    /// `true` when at least one problematic value was counted.
    pub fn has_issues(&self) -> bool {
        self.total_issues != 0
    }

    /// `true` when NaNs were counted and no infinity of either sign was.
    pub fn only_nan(&self) -> bool {
        !self.any_inf() && self.nan_count != 0
    }

    /// `true` when some infinity was counted and no NaN was.
    pub fn only_inf(&self) -> bool {
        self.any_inf() && self.nan_count == 0
    }

    /// Problematic values as a percentage of `total_elements`; `0.0` when the
    /// tensor is empty.
    pub fn issue_percentage(&self) -> f64 {
        match self.total_elements {
            0 => 0.0,
            total => (self.total_issues as f64 / total as f64) * 100.0,
        }
    }

    /// Whether either infinity counter is non-zero.
    fn any_inf(&self) -> bool {
        self.pos_inf_count != 0 || self.neg_inf_count != 0
    }
}

impl fmt::Display for NanInfStats {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let pct = self.issue_percentage();
        write!(
            f,
            "NaN/Inf Stats: {issues} issues in {total} elements ({pct:.2}%) - NaN: {nan}, +Inf: {pos}, -Inf: {neg}",
            issues = self.total_issues,
            total = self.total_elements,
            pct = pct,
            nan = self.nan_count,
            pos = self.pos_inf_count,
            neg = self.neg_inf_count,
        )
    }
}
144
/// Where a problematic value was found and what it was.
#[derive(Debug, Clone)]
pub struct IssueLocation {
    /// Index into the tensor's flattened data (the order returned by `to_vec`).
    pub flat_index: usize,
    /// The same position as one index per dimension, derived from
    /// `flat_index` with the last dimension varying fastest.
    pub coordinates: Vec<usize>,
    /// The offending value, widened to `f64`.
    pub value: f64,
    /// Which kind of non-finite value this is.
    pub issue_type: IssueType,
}

/// Kind of non-finite value found by a scan.
///
/// This is a plain tag with no payload, so it is `Copy` and can be used as a
/// hash-map / hash-set key (for example to group locations by kind).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum IssueType {
    /// Not a Number
    NaN,
    /// Positive infinity
    PositiveInfinity,
    /// Negative infinity
    NegativeInfinity,
}

impl fmt::Display for IssueType {
    /// Short symbolic form: `NaN`, `+Inf` or `-Inf`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            IssueType::NaN => write!(f, "NaN"),
            IssueType::PositiveInfinity => write!(f, "+Inf"),
            IssueType::NegativeInfinity => write!(f, "-Inf"),
        }
    }
}
178
179/// Detailed report of NaN/Inf detection
180#[derive(Debug, Clone)]
181pub struct NanInfReport {
182    /// Overall statistics
183    pub stats: NanInfStats,
184    /// Detailed locations (if enabled)
185    pub locations: Vec<IssueLocation>,
186    /// Whether the check was terminated early
187    pub early_termination: bool,
188}
189
190impl NanInfReport {
191    /// Check if the tensor is clean (no issues)
192    pub fn is_clean(&self) -> bool {
193        !self.stats.has_issues()
194    }
195
196    /// Get issues by type
197    pub fn issues_by_type(&self, issue_type: IssueType) -> Vec<&IssueLocation> {
198        self.locations
199            .iter()
200            .filter(|loc| loc.issue_type == issue_type)
201            .collect()
202    }
203}
204
205impl fmt::Display for NanInfReport {
206    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
207        writeln!(f, "{}", self.stats)?;
208        if !self.locations.is_empty() {
209            writeln!(f, "Issue locations:")?;
210            for (i, loc) in self.locations.iter().enumerate() {
211                if i >= 10 {
212                    writeln!(f, "  ... and {} more", self.locations.len() - 10)?;
213                    break;
214                }
215                writeln!(
216                    f,
217                    "  [{:?}] {} = {}",
218                    loc.coordinates, loc.issue_type, loc.value
219                )?;
220            }
221        }
222        if self.early_termination {
223            writeln!(f, "Note: Check terminated early (fail_fast mode)")?;
224        }
225        Ok(())
226    }
227}
228
229/// NaN/Inf detection utilities for tensors
230impl<T: TensorElement + FloatElement> Tensor<T> {
231    /// Quick check if tensor contains any NaN or infinite values (optimized fast path)
232    ///
233    /// This is the fastest check - it returns `true` if any issues are found,
234    /// `false` if the tensor is clean. No detailed information is provided.
235    ///
236    /// # Examples
237    /// ```rust
238    /// # use torsh_tensor::Tensor;
239    /// # use torsh_core::device::DeviceType;
240    /// let clean = Tensor::from_data(vec![1.0, 2.0, 3.0], vec![3], DeviceType::Cpu).expect("tensor creation should succeed");
241    /// assert!(!clean.has_nan_inf());
242    ///
243    /// let dirty = Tensor::from_data(vec![1.0, f32::NAN, 3.0], vec![3], DeviceType::Cpu).expect("tensor creation should succeed");
244    /// assert!(dirty.has_nan_inf());
245    /// ```
246    pub fn has_nan_inf(&self) -> bool {
247        let config = NanInfConfig::fast();
248        self.check_nan_inf_with_config(&config).stats.has_issues()
249    }
250
251    /// Check for NaN values only
252    pub fn has_nan(&self) -> bool {
253        let config = NanInfConfig::nan_only();
254        self.check_nan_inf_with_config(&config).stats.nan_count > 0
255    }
256
257    /// Check for infinite values only
258    pub fn has_inf(&self) -> bool {
259        let config = NanInfConfig::inf_only();
260        let stats = &self.check_nan_inf_with_config(&config).stats;
261        stats.pos_inf_count > 0 || stats.neg_inf_count > 0
262    }
263
264    /// Count NaN and infinite values
265    pub fn count_nan_inf(&self) -> NanInfStats {
266        let config = NanInfConfig::default();
267        self.check_nan_inf_with_config(&config).stats
268    }
269
270    /// Comprehensive NaN/Inf detection with detailed reporting
271    ///
272    /// # Examples
273    /// ```rust
274    /// # use torsh_tensor::{Tensor, nan_inf_detection::NanInfConfig};
275    /// # use torsh_core::device::DeviceType;
276    /// let tensor = Tensor::from_data(
277    ///     vec![1.0, f32::NAN, f32::INFINITY, -f32::INFINITY],
278    ///     vec![4],
279    ///     DeviceType::Cpu
280    /// ).expect("tensor creation should succeed");
281    ///
282    /// let config = NanInfConfig::detailed();
283    /// let report = tensor.check_nan_inf_with_config(&config);
284    ///
285    /// assert_eq!(report.stats.nan_count, 1);
286    /// assert_eq!(report.stats.pos_inf_count, 1);
287    /// assert_eq!(report.stats.neg_inf_count, 1);
288    /// assert_eq!(report.locations.len(), 3);
289    /// ```
290    pub fn check_nan_inf_with_config(&self, config: &NanInfConfig) -> NanInfReport {
291        let data = match self.to_vec() {
292            Ok(d) => d,
293            Err(_) => {
294                return NanInfReport {
295                    stats: NanInfStats::default(),
296                    locations: Vec::new(),
297                    early_termination: true,
298                }
299            }
300        };
301
302        let mut stats = NanInfStats {
303            total_elements: data.len(),
304            ..Default::default()
305        };
306        let mut locations = Vec::new();
307        let mut early_termination = false;
308
309        let shape_binding = self.shape();
310        let shape = shape_binding.dims();
311
312        for (flat_idx, &value) in data.iter().enumerate() {
313            let val_f64 = match torsh_core::dtype::TensorElement::to_f64(&value) {
314                Some(v) => v,
315                None => continue, // Skip values that can't be converted to f64
316            };
317            let mut is_issue = false;
318            let mut issue_type = None;
319
320            // Check for each type of issue based on config
321            if config.check_nan && val_f64.is_nan() {
322                stats.nan_count += 1;
323                stats.total_issues += 1;
324                is_issue = true;
325                issue_type = Some(IssueType::NaN);
326            } else if config.check_pos_inf && val_f64.is_infinite() && val_f64.is_sign_positive() {
327                stats.pos_inf_count += 1;
328                stats.total_issues += 1;
329                is_issue = true;
330                issue_type = Some(IssueType::PositiveInfinity);
331            } else if config.check_neg_inf && val_f64.is_infinite() && val_f64.is_sign_negative() {
332                stats.neg_inf_count += 1;
333                stats.total_issues += 1;
334                is_issue = true;
335                issue_type = Some(IssueType::NegativeInfinity);
336            }
337
338            // Record detailed location if requested
339            if is_issue && config.detailed_report {
340                let coordinates = flat_to_multi_dim(flat_idx, shape);
341                locations.push(IssueLocation {
342                    flat_index: flat_idx,
343                    coordinates,
344                    value: val_f64,
345                    issue_type: issue_type
346                        .expect("issue_type should be Some when is_issue is true"),
347                });
348            }
349
350            // Early termination if fail_fast enabled
351            if is_issue && config.fail_fast {
352                early_termination = true;
353                break;
354            }
355        }
356
357        NanInfReport {
358            stats,
359            locations,
360            early_termination,
361        }
362    }
363
364    /// Assert that tensor contains no NaN or infinite values
365    ///
366    /// # Panics
367    /// Panics if any NaN or infinite values are found
368    ///
369    /// # Examples
370    /// ```rust
371    /// # use torsh_tensor::Tensor;
372    /// # use torsh_core::device::DeviceType;
373    /// let tensor = Tensor::from_data(vec![1.0, 2.0, 3.0], vec![3], DeviceType::Cpu).expect("tensor creation should succeed");
374    /// tensor.assert_finite(); // OK
375    ///
376    /// // This would panic:
377    /// // let bad = Tensor::from_data(vec![1.0, f32::NAN], vec![2], DeviceType::Cpu).expect("tensor creation should succeed");
378    /// // bad.assert_finite(); // Panics!
379    /// ```
380    pub fn assert_finite(&self) {
381        let report = self.check_nan_inf_with_config(&NanInfConfig::detailed());
382        if report.stats.has_issues() {
383            panic!("Tensor contains non-finite values:\n{}", report);
384        }
385    }
386
387    /// Replace NaN and infinite values with specified replacements
388    ///
389    /// # Examples
390    /// ```rust
391    /// # use torsh_tensor::Tensor;
392    /// # use torsh_core::device::DeviceType;
393    /// let mut tensor = Tensor::from_data(
394    ///     vec![1.0, f32::NAN, f32::INFINITY, -f32::INFINITY],
395    ///     vec![4],
396    ///     DeviceType::Cpu
397    /// ).expect("tensor creation should succeed");
398    ///
399    /// let cleaned = tensor.replace_nan_inf(0.0, 1e6, -1e6).expect("replace_nan_inf should succeed");
400    /// assert!(!cleaned.has_nan_inf());
401    /// ```
402    pub fn replace_nan_inf(
403        &self,
404        nan_replacement: T,
405        pos_inf_replacement: T,
406        neg_inf_replacement: T,
407    ) -> Result<Self> {
408        let data = self.to_vec()?;
409        let mut new_data = Vec::with_capacity(data.len());
410
411        for &value in &data {
412            let val_f64 = match torsh_core::dtype::TensorElement::to_f64(&value) {
413                Some(v) => v,
414                None => {
415                    new_data.push(value);
416                    continue;
417                }
418            };
419            let new_value = if val_f64.is_nan() {
420                nan_replacement
421            } else if val_f64.is_infinite() && val_f64.is_sign_positive() {
422                pos_inf_replacement
423            } else if val_f64.is_infinite() && val_f64.is_sign_negative() {
424                neg_inf_replacement
425            } else {
426                value
427            };
428            new_data.push(new_value);
429        }
430
431        Self::from_data(new_data, self.shape().dims().to_vec(), self.device)
432    }
433
434    /// Create a boolean mask indicating locations of NaN/Inf values
435    ///
436    /// # Examples
437    /// ```rust
438    /// # use torsh_tensor::Tensor;
439    /// # use torsh_core::device::DeviceType;
440    /// let tensor = Tensor::from_data(
441    ///     vec![1.0, f32::NAN, 3.0, f32::INFINITY],
442    ///     vec![4],
443    ///     DeviceType::Cpu
444    /// ).expect("tensor creation should succeed");
445    ///
446    /// let mask = tensor.nan_inf_mask().expect("nan_inf_mask should succeed");
447    /// let mask_data = mask.to_vec().expect("to_vec conversion should succeed");
448    /// assert_eq!(mask_data, vec![false, true, false, true]);
449    /// ```
450    pub fn nan_inf_mask(&self) -> Result<Tensor<bool>> {
451        let data = self.to_vec()?;
452        let mask_data: Vec<bool> = data
453            .iter()
454            .map(|&value| {
455                match torsh_core::dtype::TensorElement::to_f64(&value) {
456                    Some(val) => val.is_nan() || val.is_infinite(),
457                    None => false, // Can't be NaN/Inf if not convertible to f64
458                }
459            })
460            .collect();
461
462        Tensor::from_data(mask_data, self.shape().dims().to_vec(), self.device)
463    }
464}
465
/// Maps a flat index onto per-dimension coordinates for `shape`.
///
/// The last dimension varies fastest (row-major order). An empty `shape`
/// (a scalar) maps to `[0]`. A zero-sized dimension would divide by zero.
/// Callers only pass indices of existing elements, and no elements exist when
/// any dimension is 0.
fn flat_to_multi_dim(flat_idx: usize, shape: &[usize]) -> Vec<usize> {
    if shape.is_empty() {
        return vec![0];
    }

    let mut coords = vec![0; shape.len()];
    let mut rest = flat_idx;
    // Peel off the fastest-varying (last) dimension first.
    for (slot, &extent) in coords.iter_mut().zip(shape).rev() {
        *slot = rest % extent;
        rest /= extent;
    }
    coords
}
483
#[cfg(test)]
mod tests {
    use super::*;
    use crate::creation;
    use torsh_core::device::DeviceType;

    #[test]
    fn test_clean_tensor() {
        let tensor = creation::tensor_1d(&[1.0f32, 2.0, 3.0, 4.0])
            .expect("tensor_1d creation should succeed");

        assert!(!tensor.has_nan_inf());
        assert!(!tensor.has_nan());
        assert!(!tensor.has_inf());

        let stats = tensor.count_nan_inf();
        assert_eq!(stats.total_issues, 0);
        assert_eq!(stats.total_elements, 4);
    }

    #[test]
    fn test_nan_detection() {
        let tensor = Tensor::from_data(
            vec![1.0f32, f32::NAN, 3.0, f32::NAN],
            vec![4],
            DeviceType::Cpu,
        )
        .expect("tensor creation should succeed");

        assert!(tensor.has_nan_inf());
        assert!(tensor.has_nan());
        assert!(!tensor.has_inf());

        let stats = tensor.count_nan_inf();
        assert_eq!(stats.nan_count, 2);
        assert_eq!(stats.pos_inf_count, 0);
        assert_eq!(stats.neg_inf_count, 0);
        assert_eq!(stats.total_issues, 2);
    }

    #[test]
    fn test_inf_detection() {
        let tensor = Tensor::from_data(
            vec![1.0f32, f32::INFINITY, 3.0, -f32::INFINITY],
            vec![4],
            DeviceType::Cpu,
        )
        .expect("tensor creation should succeed");

        assert!(tensor.has_nan_inf());
        assert!(!tensor.has_nan());
        assert!(tensor.has_inf());

        let stats = tensor.count_nan_inf();
        assert_eq!(stats.nan_count, 0);
        assert_eq!(stats.pos_inf_count, 1);
        assert_eq!(stats.neg_inf_count, 1);
        assert_eq!(stats.total_issues, 2);
    }

    #[test]
    fn test_stats_helpers() {
        let nan_tensor = Tensor::from_data(
            vec![f32::NAN, 1.0, 2.0, 3.0],
            vec![4],
            DeviceType::Cpu,
        )
        .expect("tensor creation should succeed");
        let stats = nan_tensor.count_nan_inf();
        assert!(stats.only_nan());
        assert!(!stats.only_inf());
        assert_eq!(stats.issue_percentage(), 25.0);

        let inf_tensor = Tensor::from_data(vec![1.0f32, f32::INFINITY], vec![2], DeviceType::Cpu)
            .expect("tensor creation should succeed");
        let stats = inf_tensor.count_nan_inf();
        assert!(stats.only_inf());
        assert!(!stats.only_nan());
        assert_eq!(stats.issue_percentage(), 50.0);
    }

    #[test]
    fn test_detailed_report() {
        let tensor = Tensor::from_data(
            vec![1.0f32, f32::NAN, f32::INFINITY, -f32::INFINITY],
            vec![4],
            DeviceType::Cpu,
        )
        .expect("tensor creation should succeed");

        let config = NanInfConfig::detailed();
        let report = tensor.check_nan_inf_with_config(&config);

        assert_eq!(report.stats.total_issues, 3);
        assert_eq!(report.locations.len(), 3);

        // Check specific locations
        assert_eq!(report.locations[0].flat_index, 1);
        assert_eq!(report.locations[0].issue_type, IssueType::NaN);

        assert_eq!(report.locations[1].flat_index, 2);
        assert_eq!(report.locations[1].issue_type, IssueType::PositiveInfinity);

        assert_eq!(report.locations[2].flat_index, 3);
        assert_eq!(report.locations[2].issue_type, IssueType::NegativeInfinity);
    }

    #[test]
    fn test_detailed_report_on_clean_tensor() {
        let tensor =
            creation::tensor_1d(&[1.0f32, 2.0, 3.0]).expect("tensor_1d creation should succeed");

        let report = tensor.check_nan_inf_with_config(&NanInfConfig::detailed());
        assert!(report.is_clean());
        assert!(report.locations.is_empty());
        assert!(!report.early_termination);
    }

    #[test]
    fn test_issues_by_type() {
        let tensor = Tensor::from_data(
            vec![f32::NAN, f32::INFINITY, f32::NAN, -f32::INFINITY],
            vec![4],
            DeviceType::Cpu,
        )
        .expect("tensor creation should succeed");

        let report = tensor.check_nan_inf_with_config(&NanInfConfig::detailed());
        assert!(!report.is_clean());

        let nans = report.issues_by_type(IssueType::NaN);
        assert_eq!(nans.len(), 2);
        assert_eq!(nans[0].flat_index, 0);
        assert_eq!(nans[1].flat_index, 2);
        assert_eq!(report.issues_by_type(IssueType::PositiveInfinity).len(), 1);
        assert_eq!(report.issues_by_type(IssueType::NegativeInfinity).len(), 1);
    }

    #[test]
    fn test_config_filters() {
        let tensor = Tensor::from_data(
            vec![f32::NAN, f32::INFINITY, -f32::INFINITY],
            vec![3],
            DeviceType::Cpu,
        )
        .expect("tensor creation should succeed");

        let nan_stats = tensor
            .check_nan_inf_with_config(&NanInfConfig::nan_only())
            .stats;
        assert_eq!(nan_stats.nan_count, 1);
        assert_eq!(nan_stats.pos_inf_count, 0);
        assert_eq!(nan_stats.neg_inf_count, 0);
        assert_eq!(nan_stats.total_issues, 1);

        let inf_stats = tensor
            .check_nan_inf_with_config(&NanInfConfig::inf_only())
            .stats;
        assert_eq!(inf_stats.nan_count, 0);
        assert_eq!(inf_stats.pos_inf_count, 1);
        assert_eq!(inf_stats.neg_inf_count, 1);
        assert_eq!(inf_stats.total_issues, 2);
    }

    #[test]
    fn test_replace_nan_inf() {
        let tensor = Tensor::from_data(
            vec![1.0f32, f32::NAN, f32::INFINITY, -f32::INFINITY],
            vec![4],
            DeviceType::Cpu,
        )
        .expect("tensor creation should succeed");

        let cleaned = tensor
            .replace_nan_inf(0.0, 1e6, -1e6)
            .expect("replace_nan_inf should succeed");
        assert!(!cleaned.has_nan_inf());

        let data = cleaned.to_vec().expect("to_vec conversion should succeed");
        assert_eq!(data, vec![1.0, 0.0, 1e6, -1e6]);
    }

    #[test]
    fn test_replace_nan_inf_preserves_shape_and_finite_values() {
        let tensor = Tensor::from_data(
            vec![1.5f32, f32::NAN, -2.0, f32::NEG_INFINITY],
            vec![2, 2],
            DeviceType::Cpu,
        )
        .expect("tensor creation should succeed");

        let cleaned = tensor
            .replace_nan_inf(0.0, 9.0, -9.0)
            .expect("replace_nan_inf should succeed");
        let shape = cleaned.shape();
        assert_eq!(shape.dims().to_vec(), vec![2, 2]);

        let data = cleaned.to_vec().expect("to_vec conversion should succeed");
        assert_eq!(data, vec![1.5, 0.0, -2.0, -9.0]);
    }

    #[test]
    fn test_nan_inf_mask() {
        let tensor = Tensor::from_data(
            vec![1.0f32, f32::NAN, 3.0, f32::INFINITY],
            vec![4],
            DeviceType::Cpu,
        )
        .expect("tensor creation should succeed");

        let mask = tensor.nan_inf_mask().expect("nan_inf_mask should succeed");
        let mask_data = mask.to_vec().expect("to_vec conversion should succeed");
        assert_eq!(mask_data, vec![false, true, false, true]);
    }

    #[test]
    fn test_multi_dimensional_coordinates() {
        let tensor = Tensor::from_data(
            vec![1.0f32, f32::NAN, 3.0, f32::INFINITY, 5.0, 6.0],
            vec![2, 3],
            DeviceType::Cpu,
        )
        .expect("tensor creation should succeed");

        let config = NanInfConfig::detailed();
        let report = tensor.check_nan_inf_with_config(&config);

        assert_eq!(report.locations.len(), 2);
        assert_eq!(report.locations[0].coordinates, vec![0, 1]); // NaN at [0,1]
        assert_eq!(report.locations[1].coordinates, vec![1, 0]); // Inf at [1,0]
    }

    #[test]
    fn test_flat_to_multi_dim() {
        assert_eq!(flat_to_multi_dim(0, &[]), vec![0]); // scalar
        assert_eq!(flat_to_multi_dim(5, &[2, 3]), vec![1, 2]);
        assert_eq!(flat_to_multi_dim(7, &[2, 3, 4]), vec![0, 1, 3]);
        assert_eq!(flat_to_multi_dim(23, &[2, 3, 4]), vec![1, 2, 3]);
    }

    #[test]
    fn test_fail_fast() {
        let tensor = Tensor::from_data(
            vec![f32::NAN, f32::INFINITY, 3.0, 4.0],
            vec![4],
            DeviceType::Cpu,
        )
        .expect("tensor creation should succeed");

        let config = NanInfConfig::fast();
        let report = tensor.check_nan_inf_with_config(&config);

        assert!(report.early_termination);
        // The scan stops at the leading NaN, so exactly one issue is counted.
        assert_eq!(report.stats.total_issues, 1);
        assert_eq!(report.stats.nan_count, 1);
        assert_eq!(report.stats.pos_inf_count, 0);
        assert!(report.locations.is_empty());
    }

    #[test]
    #[should_panic(expected = "Tensor contains non-finite values")]
    fn test_assert_finite_panic() {
        let tensor = Tensor::from_data(vec![1.0f32, f32::NAN], vec![2], DeviceType::Cpu)
            .expect("tensor creation should succeed");

        tensor.assert_finite();
    }

    #[test]
    fn test_assert_finite_ok() {
        let tensor =
            creation::tensor_1d(&[1.0f32, 2.0, 3.0]).expect("tensor_1d creation should succeed");
        tensor.assert_finite(); // Should not panic
    }

    #[test]
    fn test_config_presets() {
        let nan_config = NanInfConfig::nan_only();
        assert!(nan_config.check_nan);
        assert!(!nan_config.check_pos_inf);
        assert!(!nan_config.check_neg_inf);

        let inf_config = NanInfConfig::inf_only();
        assert!(!inf_config.check_nan);
        assert!(inf_config.check_pos_inf);
        assert!(inf_config.check_neg_inf);

        let fast_config = NanInfConfig::fast();
        assert!(!fast_config.detailed_report);
        assert!(fast_config.fail_fast);

        let detailed_config = NanInfConfig::detailed();
        assert!(detailed_config.detailed_report);
        assert!(!detailed_config.fail_fast);
    }
}