1use crate::{Tensor, TensorElement};
16use std::fmt;
17use torsh_core::{dtype::FloatElement, error::Result};
18
/// Configuration controlling which non-finite values are detected and how much
/// detail is collected while scanning a tensor.
///
/// Use `NanInfConfig::default()` or one of the presets (`nan_only`, `inf_only`,
/// `fast`, `detailed`) and override individual fields with struct-update syntax.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NanInfConfig {
    /// Count `NaN` values as issues.
    pub check_nan: bool,
    /// Count positive infinity (`+Inf`) as an issue.
    pub check_pos_inf: bool,
    /// Count negative infinity (`-Inf`) as an issue.
    pub check_neg_inf: bool,
    /// Record the flat index and coordinates of every detected issue.
    pub detailed_report: bool,
    /// Request a SIMD-accelerated scan.
    /// NOTE(review): the scan in `check_nan_inf_with_config` does not currently read this flag.
    pub use_simd: bool,
    /// Stop scanning at the first detected issue.
    pub fail_fast: bool,
}
35
36impl Default for NanInfConfig {
37 fn default() -> Self {
38 Self {
39 check_nan: true,
40 check_pos_inf: true,
41 check_neg_inf: true,
42 detailed_report: false,
43 use_simd: true,
44 fail_fast: false,
45 }
46 }
47}
48
49impl NanInfConfig {
50 pub fn nan_only() -> Self {
52 Self {
53 check_nan: true,
54 check_pos_inf: false,
55 check_neg_inf: false,
56 ..Default::default()
57 }
58 }
59
60 pub fn inf_only() -> Self {
62 Self {
63 check_nan: false,
64 check_pos_inf: true,
65 check_neg_inf: true,
66 ..Default::default()
67 }
68 }
69
70 pub fn fast() -> Self {
72 Self {
73 detailed_report: false,
74 fail_fast: true,
75 ..Default::default()
76 }
77 }
78
79 pub fn detailed() -> Self {
81 Self {
82 detailed_report: true,
83 fail_fast: false,
84 ..Default::default()
85 }
86 }
87}
88
/// Aggregate counts of non-finite values found by a scan.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct NanInfStats {
    /// Number of `NaN` values counted.
    pub nan_count: usize,
    /// Number of `+Inf` values counted.
    pub pos_inf_count: usize,
    /// Number of `-Inf` values counted.
    pub neg_inf_count: usize,
    /// Total number of issues counted (NaN + `+Inf` + `-Inf`).
    pub total_issues: usize,
    /// Number of elements in the scanned tensor.
    pub total_elements: usize,
}
103
104impl NanInfStats {
105 pub fn has_issues(&self) -> bool {
107 self.total_issues > 0
108 }
109
110 pub fn only_nan(&self) -> bool {
112 self.nan_count > 0 && self.pos_inf_count == 0 && self.neg_inf_count == 0
113 }
114
115 pub fn only_inf(&self) -> bool {
117 self.nan_count == 0 && (self.pos_inf_count > 0 || self.neg_inf_count > 0)
118 }
119
120 pub fn issue_percentage(&self) -> f64 {
122 if self.total_elements == 0 {
123 0.0
124 } else {
125 (self.total_issues as f64 / self.total_elements as f64) * 100.0
126 }
127 }
128}
129
130impl fmt::Display for NanInfStats {
131 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
132 write!(
133 f,
134 "NaN/Inf Stats: {} issues in {} elements ({:.2}%) - NaN: {}, +Inf: {}, -Inf: {}",
135 self.total_issues,
136 self.total_elements,
137 self.issue_percentage(),
138 self.nan_count,
139 self.pos_inf_count,
140 self.neg_inf_count
141 )
142 }
143}
144
/// Position and value of a single non-finite element found during a scan.
#[derive(Debug, Clone)]
pub struct IssueLocation {
    /// Index of the element in the flattened (`to_vec`) order of the tensor.
    pub flat_index: usize,
    /// Per-axis coordinates of the element, derived from `flat_index` and the
    /// tensor shape by `flat_to_multi_dim` (last axis varies fastest).
    /// NOTE(review): for a 0-d tensor this is `[0]` rather than an empty vector.
    pub coordinates: Vec<usize>,
    /// The offending value converted to `f64` (NaN for NaN issues).
    pub value: f64,
    /// Which kind of non-finite value was found.
    pub issue_type: IssueType,
}
157
/// Kind of non-finite value detected in a tensor.
///
/// Fieldless, so it is `Copy` and can be used as a key in hash-based collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum IssueType {
    /// Not-a-Number.
    NaN,
    /// Positive infinity (`+Inf`).
    PositiveInfinity,
    /// Negative infinity (`-Inf`).
    NegativeInfinity,
}
168
169impl fmt::Display for IssueType {
170 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
171 match self {
172 IssueType::NaN => write!(f, "NaN"),
173 IssueType::PositiveInfinity => write!(f, "+Inf"),
174 IssueType::NegativeInfinity => write!(f, "-Inf"),
175 }
176 }
177}
178
/// Result of scanning a tensor for NaN / infinite values.
#[derive(Debug, Clone)]
pub struct NanInfReport {
    /// Aggregate counts of the issues found.
    pub stats: NanInfStats,
    /// Individual issue locations in scan order; only filled when the scan ran
    /// with `detailed_report` enabled.
    pub locations: Vec<IssueLocation>,
    /// `true` if the scan did not examine every element — either `fail_fast`
    /// stopped at an issue, or the tensor data could not be read.
    pub early_termination: bool,
}
189
190impl NanInfReport {
191 pub fn is_clean(&self) -> bool {
193 !self.stats.has_issues()
194 }
195
196 pub fn issues_by_type(&self, issue_type: IssueType) -> Vec<&IssueLocation> {
198 self.locations
199 .iter()
200 .filter(|loc| loc.issue_type == issue_type)
201 .collect()
202 }
203}
204
205impl fmt::Display for NanInfReport {
206 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
207 writeln!(f, "{}", self.stats)?;
208 if !self.locations.is_empty() {
209 writeln!(f, "Issue locations:")?;
210 for (i, loc) in self.locations.iter().enumerate() {
211 if i >= 10 {
212 writeln!(f, " ... and {} more", self.locations.len() - 10)?;
213 break;
214 }
215 writeln!(
216 f,
217 " [{:?}] {} = {}",
218 loc.coordinates, loc.issue_type, loc.value
219 )?;
220 }
221 }
222 if self.early_termination {
223 writeln!(f, "Note: Check terminated early (fail_fast mode)")?;
224 }
225 Ok(())
226 }
227}
228
229impl<T: TensorElement + FloatElement> Tensor<T> {
231 pub fn has_nan_inf(&self) -> bool {
247 let config = NanInfConfig::fast();
248 self.check_nan_inf_with_config(&config).stats.has_issues()
249 }
250
251 pub fn has_nan(&self) -> bool {
253 let config = NanInfConfig::nan_only();
254 self.check_nan_inf_with_config(&config).stats.nan_count > 0
255 }
256
257 pub fn has_inf(&self) -> bool {
259 let config = NanInfConfig::inf_only();
260 let stats = &self.check_nan_inf_with_config(&config).stats;
261 stats.pos_inf_count > 0 || stats.neg_inf_count > 0
262 }
263
264 pub fn count_nan_inf(&self) -> NanInfStats {
266 let config = NanInfConfig::default();
267 self.check_nan_inf_with_config(&config).stats
268 }
269
270 pub fn check_nan_inf_with_config(&self, config: &NanInfConfig) -> NanInfReport {
291 let data = match self.to_vec() {
292 Ok(d) => d,
293 Err(_) => {
294 return NanInfReport {
295 stats: NanInfStats::default(),
296 locations: Vec::new(),
297 early_termination: true,
298 }
299 }
300 };
301
302 let mut stats = NanInfStats {
303 total_elements: data.len(),
304 ..Default::default()
305 };
306 let mut locations = Vec::new();
307 let mut early_termination = false;
308
309 let shape_binding = self.shape();
310 let shape = shape_binding.dims();
311
312 for (flat_idx, &value) in data.iter().enumerate() {
313 let val_f64 = match torsh_core::dtype::TensorElement::to_f64(&value) {
314 Some(v) => v,
315 None => continue, };
317 let mut is_issue = false;
318 let mut issue_type = None;
319
320 if config.check_nan && val_f64.is_nan() {
322 stats.nan_count += 1;
323 stats.total_issues += 1;
324 is_issue = true;
325 issue_type = Some(IssueType::NaN);
326 } else if config.check_pos_inf && val_f64.is_infinite() && val_f64.is_sign_positive() {
327 stats.pos_inf_count += 1;
328 stats.total_issues += 1;
329 is_issue = true;
330 issue_type = Some(IssueType::PositiveInfinity);
331 } else if config.check_neg_inf && val_f64.is_infinite() && val_f64.is_sign_negative() {
332 stats.neg_inf_count += 1;
333 stats.total_issues += 1;
334 is_issue = true;
335 issue_type = Some(IssueType::NegativeInfinity);
336 }
337
338 if is_issue && config.detailed_report {
340 let coordinates = flat_to_multi_dim(flat_idx, shape);
341 locations.push(IssueLocation {
342 flat_index: flat_idx,
343 coordinates,
344 value: val_f64,
345 issue_type: issue_type
346 .expect("issue_type should be Some when is_issue is true"),
347 });
348 }
349
350 if is_issue && config.fail_fast {
352 early_termination = true;
353 break;
354 }
355 }
356
357 NanInfReport {
358 stats,
359 locations,
360 early_termination,
361 }
362 }
363
364 pub fn assert_finite(&self) {
381 let report = self.check_nan_inf_with_config(&NanInfConfig::detailed());
382 if report.stats.has_issues() {
383 panic!("Tensor contains non-finite values:\n{}", report);
384 }
385 }
386
387 pub fn replace_nan_inf(
403 &self,
404 nan_replacement: T,
405 pos_inf_replacement: T,
406 neg_inf_replacement: T,
407 ) -> Result<Self> {
408 let data = self.to_vec()?;
409 let mut new_data = Vec::with_capacity(data.len());
410
411 for &value in &data {
412 let val_f64 = match torsh_core::dtype::TensorElement::to_f64(&value) {
413 Some(v) => v,
414 None => {
415 new_data.push(value);
416 continue;
417 }
418 };
419 let new_value = if val_f64.is_nan() {
420 nan_replacement
421 } else if val_f64.is_infinite() && val_f64.is_sign_positive() {
422 pos_inf_replacement
423 } else if val_f64.is_infinite() && val_f64.is_sign_negative() {
424 neg_inf_replacement
425 } else {
426 value
427 };
428 new_data.push(new_value);
429 }
430
431 Self::from_data(new_data, self.shape().dims().to_vec(), self.device)
432 }
433
434 pub fn nan_inf_mask(&self) -> Result<Tensor<bool>> {
451 let data = self.to_vec()?;
452 let mask_data: Vec<bool> = data
453 .iter()
454 .map(|&value| {
455 match torsh_core::dtype::TensorElement::to_f64(&value) {
456 Some(val) => val.is_nan() || val.is_infinite(),
457 None => false, }
459 })
460 .collect();
461
462 Tensor::from_data(mask_data, self.shape().dims().to_vec(), self.device)
463 }
464}
465
/// Converts a flat element index into per-axis coordinates for `shape`.
///
/// The last axis varies fastest (row-major order). An empty `shape` (0-d tensor)
/// yields `[0]`. An axis of size zero triggers a remainder-by-zero panic.
/// NOTE(review): the visible caller only passes indices of existing elements,
/// which cannot exist when an axis is zero.
fn flat_to_multi_dim(flat_idx: usize, shape: &[usize]) -> Vec<usize> {
    if shape.is_empty() {
        return vec![0];
    }

    let mut rest = flat_idx;
    let mut coords = vec![0; shape.len()];
    // Walk axes from innermost to outermost, peeling off one digit per axis.
    for (coord, &extent) in coords.iter_mut().zip(shape).rev() {
        *coord = rest % extent;
        rest /= extent;
    }
    coords
}
483
#[cfg(test)]
mod tests {
    use super::*;
    use crate::creation;
    use torsh_core::device::DeviceType;

    #[test]
    fn test_clean_tensor() {
        let tensor = creation::tensor_1d(&[1.0f32, 2.0, 3.0, 4.0])
            .expect("tensor_1d creation should succeed");

        assert!(!tensor.has_nan_inf());
        assert!(!tensor.has_nan());
        assert!(!tensor.has_inf());

        let stats = tensor.count_nan_inf();
        assert_eq!(stats.total_issues, 0);
        assert_eq!(stats.total_elements, 4);
    }

    #[test]
    fn test_nan_detection() {
        let tensor = Tensor::from_data(
            vec![1.0f32, f32::NAN, 3.0, f32::NAN],
            vec![4],
            DeviceType::Cpu,
        )
        .expect("tensor creation should succeed");

        assert!(tensor.has_nan_inf());
        assert!(tensor.has_nan());
        assert!(!tensor.has_inf());

        let stats = tensor.count_nan_inf();
        assert_eq!(stats.nan_count, 2);
        assert_eq!(stats.pos_inf_count, 0);
        assert_eq!(stats.neg_inf_count, 0);
        assert_eq!(stats.total_issues, 2);
    }

    #[test]
    fn test_inf_detection() {
        let tensor = Tensor::from_data(
            vec![1.0f32, f32::INFINITY, 3.0, -f32::INFINITY],
            vec![4],
            DeviceType::Cpu,
        )
        .expect("tensor creation should succeed");

        assert!(tensor.has_nan_inf());
        assert!(!tensor.has_nan());
        assert!(tensor.has_inf());

        let stats = tensor.count_nan_inf();
        assert_eq!(stats.nan_count, 0);
        assert_eq!(stats.pos_inf_count, 1);
        assert_eq!(stats.neg_inf_count, 1);
        assert_eq!(stats.total_issues, 2);
    }

    #[test]
    fn test_detailed_report() {
        let tensor = Tensor::from_data(
            vec![1.0f32, f32::NAN, f32::INFINITY, -f32::INFINITY],
            vec![4],
            DeviceType::Cpu,
        )
        .expect("tensor creation should succeed");

        let config = NanInfConfig::detailed();
        let report = tensor.check_nan_inf_with_config(&config);

        assert_eq!(report.stats.total_issues, 3);
        assert_eq!(report.locations.len(), 3);
        assert!(!report.early_termination);

        assert_eq!(report.locations[0].flat_index, 1);
        assert_eq!(report.locations[0].issue_type, IssueType::NaN);

        assert_eq!(report.locations[1].flat_index, 2);
        assert_eq!(report.locations[1].issue_type, IssueType::PositiveInfinity);

        assert_eq!(report.locations[2].flat_index, 3);
        assert_eq!(report.locations[2].issue_type, IssueType::NegativeInfinity);
    }

    #[test]
    fn test_selective_checks() {
        let tensor = Tensor::from_data(
            vec![f32::NAN, f32::INFINITY, -f32::INFINITY, 1.0],
            vec![4],
            DeviceType::Cpu,
        )
        .expect("tensor creation should succeed");

        // nan_only ignores both infinities.
        let nan_report = tensor.check_nan_inf_with_config(&NanInfConfig::nan_only());
        assert_eq!(nan_report.stats.nan_count, 1);
        assert_eq!(nan_report.stats.pos_inf_count, 0);
        assert_eq!(nan_report.stats.neg_inf_count, 0);
        assert_eq!(nan_report.stats.total_issues, 1);

        // inf_only ignores the NaN.
        let inf_report = tensor.check_nan_inf_with_config(&NanInfConfig::inf_only());
        assert_eq!(inf_report.stats.nan_count, 0);
        assert_eq!(inf_report.stats.pos_inf_count, 1);
        assert_eq!(inf_report.stats.neg_inf_count, 1);
        assert_eq!(inf_report.stats.total_issues, 2);
    }

    #[test]
    fn test_issues_by_type_and_is_clean() {
        let tensor = Tensor::from_data(
            vec![f32::NAN, 2.0, f32::NAN, f32::INFINITY],
            vec![4],
            DeviceType::Cpu,
        )
        .expect("tensor creation should succeed");

        let report = tensor.check_nan_inf_with_config(&NanInfConfig::detailed());
        assert!(!report.is_clean());

        let nans = report.issues_by_type(IssueType::NaN);
        assert_eq!(nans.len(), 2);
        assert_eq!(nans[0].flat_index, 0);
        assert_eq!(nans[1].flat_index, 2);
        assert_eq!(report.issues_by_type(IssueType::PositiveInfinity).len(), 1);
        assert!(report.issues_by_type(IssueType::NegativeInfinity).is_empty());
    }

    #[test]
    fn test_replace_nan_inf() {
        let tensor = Tensor::from_data(
            vec![1.0f32, f32::NAN, f32::INFINITY, -f32::INFINITY],
            vec![4],
            DeviceType::Cpu,
        )
        .expect("tensor creation should succeed");

        let cleaned = tensor
            .replace_nan_inf(0.0, 1e6, -1e6)
            .expect("replace_nan_inf should succeed");
        assert!(!cleaned.has_nan_inf());

        let data = cleaned.to_vec().expect("to_vec conversion should succeed");
        assert_eq!(data, vec![1.0, 0.0, 1e6, -1e6]);
    }

    #[test]
    fn test_nan_inf_mask() {
        let tensor = Tensor::from_data(
            vec![1.0f32, f32::NAN, 3.0, f32::INFINITY],
            vec![4],
            DeviceType::Cpu,
        )
        .expect("tensor creation should succeed");

        let mask = tensor.nan_inf_mask().expect("nan_inf_mask should succeed");
        let mask_data = mask.to_vec().expect("to_vec conversion should succeed");
        assert_eq!(mask_data, vec![false, true, false, true]);
    }

    #[test]
    fn test_multi_dimensional_coordinates() {
        let tensor = Tensor::from_data(
            vec![1.0f32, f32::NAN, 3.0, f32::INFINITY, 5.0, 6.0],
            vec![2, 3],
            DeviceType::Cpu,
        )
        .expect("tensor creation should succeed");

        let config = NanInfConfig::detailed();
        let report = tensor.check_nan_inf_with_config(&config);

        assert_eq!(report.locations.len(), 2);
        // Flat index 1 -> row 0, column 1.
        assert_eq!(report.locations[0].coordinates, vec![0, 1]);
        // Flat index 3 -> row 1, column 0.
        assert_eq!(report.locations[1].coordinates, vec![1, 0]);
    }

    #[test]
    fn test_fail_fast() {
        let tensor = Tensor::from_data(
            vec![f32::NAN, f32::INFINITY, 3.0, 4.0],
            vec![4],
            DeviceType::Cpu,
        )
        .expect("tensor creation should succeed");

        let config = NanInfConfig::fast();
        let report = tensor.check_nan_inf_with_config(&config);

        assert!(report.early_termination);
        // The scan stops at the leading NaN, so the following +Inf is never counted.
        assert_eq!(report.stats.total_issues, 1);
        assert_eq!(report.stats.nan_count, 1);
        assert_eq!(report.stats.pos_inf_count, 0);
        // The fast preset does not record locations.
        assert!(report.locations.is_empty());
    }

    #[test]
    #[should_panic(expected = "Tensor contains non-finite values")]
    fn test_assert_finite_panic() {
        let tensor = Tensor::from_data(vec![1.0f32, f32::NAN], vec![2], DeviceType::Cpu)
            .expect("tensor creation should succeed");

        tensor.assert_finite();
    }

    #[test]
    fn test_assert_finite_ok() {
        let tensor =
            creation::tensor_1d(&[1.0f32, 2.0, 3.0]).expect("tensor_1d creation should succeed");
        // Must not panic.
        tensor.assert_finite();
    }

    #[test]
    fn test_config_presets() {
        let nan_config = NanInfConfig::nan_only();
        assert!(nan_config.check_nan);
        assert!(!nan_config.check_pos_inf);
        assert!(!nan_config.check_neg_inf);

        let inf_config = NanInfConfig::inf_only();
        assert!(!inf_config.check_nan);
        assert!(inf_config.check_pos_inf);
        assert!(inf_config.check_neg_inf);

        let fast_config = NanInfConfig::fast();
        assert!(!fast_config.detailed_report);
        assert!(fast_config.fail_fast);

        let detailed_config = NanInfConfig::detailed();
        assert!(detailed_config.detailed_report);
        assert!(!detailed_config.fail_fast);
    }

    #[test]
    fn test_stats_helpers_and_display() {
        let nan_only = NanInfStats {
            nan_count: 2,
            total_issues: 2,
            total_elements: 8,
            ..Default::default()
        };
        assert!(nan_only.has_issues());
        assert!(nan_only.only_nan());
        assert!(!nan_only.only_inf());
        assert_eq!(nan_only.issue_percentage(), 25.0);
        assert_eq!(
            nan_only.to_string(),
            "NaN/Inf Stats: 2 issues in 8 elements (25.00%) - NaN: 2, +Inf: 0, -Inf: 0"
        );

        let inf_only = NanInfStats {
            neg_inf_count: 1,
            total_issues: 1,
            total_elements: 4,
            ..Default::default()
        };
        assert!(inf_only.only_inf());
        assert!(!inf_only.only_nan());

        let empty = NanInfStats::default();
        assert!(!empty.has_issues());
        assert_eq!(empty.issue_percentage(), 0.0);

        assert_eq!(IssueType::NaN.to_string(), "NaN");
        assert_eq!(IssueType::PositiveInfinity.to_string(), "+Inf");
        assert_eq!(IssueType::NegativeInfinity.to_string(), "-Inf");
    }

    #[test]
    fn test_report_display_truncates_locations() {
        let locations: Vec<IssueLocation> = (0..12)
            .map(|i| IssueLocation {
                flat_index: i,
                coordinates: vec![i],
                value: f64::NAN,
                issue_type: IssueType::NaN,
            })
            .collect();
        let report = NanInfReport {
            stats: NanInfStats {
                nan_count: 12,
                total_issues: 12,
                total_elements: 12,
                ..Default::default()
            },
            locations,
            early_termination: false,
        };

        let text = report.to_string();
        assert!(text.contains("Issue locations:"));
        assert!(text.contains("... and 2 more"));
        assert!(!text.contains("terminated early"));
        // Stats line + header + 10 listed locations + summary line.
        assert_eq!(text.lines().count(), 13);
    }

    #[test]
    fn test_flat_to_multi_dim() {
        assert_eq!(flat_to_multi_dim(5, &[2, 3]), vec![1, 2]);
        assert_eq!(flat_to_multi_dim(7, &[2, 2, 2]), vec![1, 1, 1]);
        assert_eq!(flat_to_multi_dim(3, &[4]), vec![3]);
        assert_eq!(flat_to_multi_dim(0, &[]), vec![0]);
    }
}