Skip to main content

trustformers_debug/
stability_checker.rs

1//! Numerical stability checking for model debugging
2//!
3//! This module provides tools to detect and analyze numerical stability issues in
4//! neural network computations, including NaN, Inf, underflow, and overflow detection.
5
6use anyhow::Result;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::path::Path;
10
/// Numerical stability checker for detecting computational issues.
///
/// Every issue found by `check_tensor` is stored here, grouped by the layer
/// name the tensor was checked under, until `clear` is called.
#[derive(Debug)]
pub struct StabilityChecker {
    /// Detected issues, keyed by layer name
    issues: HashMap<String, Vec<StabilityIssue>>,
    /// Configuration selecting which checks run and with which thresholds
    config: StabilityConfig,
    /// ID that the next recorded issue will receive; reset to 0 by `clear`
    issue_counter: usize,
}
21
/// Configuration for stability checking.
///
/// The `check_*` switches enable the individual passes run by
/// `StabilityChecker::check_tensor`; the thresholds only matter when the
/// corresponding pass is enabled.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StabilityConfig {
    /// Check for NaN values
    pub check_nan: bool,
    /// Check for Inf values (positive and negative are reported separately)
    pub check_inf: bool,
    /// Check for underflow (values too close to zero)
    pub check_underflow: bool,
    /// Check for overflow (values too large)
    pub check_overflow: bool,
    /// Underflow threshold: finite, non-zero values whose absolute value is
    /// strictly below this are flagged
    pub underflow_threshold: f64,
    /// Overflow threshold: finite values whose absolute value is strictly
    /// above this are flagged
    pub overflow_threshold: f64,
    /// Whether `check_tensor` should return an error when the tensor it just
    /// checked produced any issue. All enabled checks still run, and their
    /// issues are recorded, before the error is returned.
    pub stop_on_first_issue: bool,
}
40
41impl Default for StabilityConfig {
42    fn default() -> Self {
43        Self {
44            check_nan: true,
45            check_inf: true,
46            check_underflow: true,
47            check_overflow: true,
48            underflow_threshold: 1e-15,
49            overflow_threshold: 1e15,
50            stop_on_first_issue: false,
51        }
52    }
53}
54
/// Type of stability issue.
///
/// Used both to tag a [`StabilityIssue`] and as the key of the per-kind
/// counters in [`StabilitySummary`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum IssueKind {
    /// Not a Number (NaN)
    NaN,
    /// Positive infinity
    PosInf,
    /// Negative infinity
    NegInf,
    /// Underflow (finite, non-zero value closer to zero than the configured threshold)
    Underflow,
    /// Overflow (finite value whose magnitude exceeds the configured threshold)
    Overflow,
    /// Precision loss (not emitted by any check visible in this file)
    PrecisionLoss,
}
71
/// Detected stability issue.
///
/// One record is created per issue kind per checked tensor; it aggregates
/// every affected value of that kind rather than describing a single value.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StabilityIssue {
    /// Unique issue ID (unique until the checker is cleared, which restarts numbering)
    pub id: usize,
    /// Layer or operation name
    pub layer_name: String,
    /// Type of issue
    pub kind: IssueKind,
    /// Number of affected values
    pub count: usize,
    /// Position in tensor of each affected value; the checks in this file
    /// store one single-element vector per value, holding its index in the
    /// checked slice
    pub positions: Vec<Vec<usize>>,
    /// Sample values that triggered the issue (the checks keep at most the first 10)
    pub sample_values: Vec<f64>,
    /// Timestamp when detected, in whole seconds since the Unix epoch
    pub timestamp: u64,
    /// Additional context; the underflow and overflow checks store the
    /// threshold that was applied, the other checks leave it empty
    pub context: Option<String>,
}
92
/// Summary of all detected issues.
///
/// All counts refer to issue records ([`StabilityIssue`]), not to the number
/// of individual tensor values affected.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StabilitySummary {
    /// Total number of issues
    pub total_issues: usize,
    /// Issues grouped by kind
    pub issues_by_kind: HashMap<IssueKind, usize>,
    /// Issues grouped by layer
    pub issues_by_layer: HashMap<String, usize>,
    /// Most problematic layers as `(layer name, issue count)`, highest count first
    pub problematic_layers: Vec<(String, usize)>,
}
105
106impl StabilityChecker {
107    /// Create a new stability checker
108    ///
109    /// # Example
110    ///
111    /// ```
112    /// use trustformers_debug::StabilityChecker;
113    ///
114    /// let checker = StabilityChecker::new();
115    /// ```
116    pub fn new() -> Self {
117        Self {
118            issues: HashMap::new(),
119            config: StabilityConfig::default(),
120            issue_counter: 0,
121        }
122    }
123
124    /// Create a stability checker with custom configuration
125    pub fn with_config(config: StabilityConfig) -> Self {
126        Self {
127            issues: HashMap::new(),
128            config,
129            issue_counter: 0,
130        }
131    }
132
133    /// Check a tensor for stability issues
134    ///
135    /// # Arguments
136    ///
137    /// * `layer_name` - Name of the layer or operation
138    /// * `values` - Tensor values to check
139    ///
140    /// # Example
141    ///
142    /// ```
143    /// # use trustformers_debug::StabilityChecker;
144    /// # let mut checker = StabilityChecker::new();
145    /// let values = vec![1.0, f64::NAN, 2.0, f64::INFINITY];
146    /// let issues = checker.check_tensor("layer1", &values).unwrap();
147    /// assert!(issues > 0);
148    /// ```
149    pub fn check_tensor(&mut self, layer_name: &str, values: &[f64]) -> Result<usize> {
150        let mut issues_found = 0;
151
152        // Check for NaN
153        if self.config.check_nan {
154            issues_found += self.check_nan(layer_name, values)?;
155        }
156
157        // Check for Inf
158        if self.config.check_inf {
159            issues_found += self.check_inf(layer_name, values)?;
160        }
161
162        // Check for underflow
163        if self.config.check_underflow {
164            issues_found += self.check_underflow(layer_name, values)?;
165        }
166
167        // Check for overflow
168        if self.config.check_overflow {
169            issues_found += self.check_overflow(layer_name, values)?;
170        }
171
172        if self.config.stop_on_first_issue && issues_found > 0 {
173            anyhow::bail!("Stability issues detected in {}", layer_name);
174        }
175
176        Ok(issues_found)
177    }
178
179    /// Check for NaN values
180    fn check_nan(&mut self, layer_name: &str, values: &[f64]) -> Result<usize> {
181        let mut positions = Vec::new();
182        let mut sample_values = Vec::new();
183
184        for (i, &value) in values.iter().enumerate() {
185            if value.is_nan() {
186                positions.push(vec![i]);
187                if sample_values.len() < 10 {
188                    sample_values.push(value);
189                }
190            }
191        }
192
193        if !positions.is_empty() {
194            let id = self.next_issue_id();
195            self.add_issue(StabilityIssue {
196                id,
197                layer_name: layer_name.to_string(),
198                kind: IssueKind::NaN,
199                count: positions.len(),
200                positions,
201                sample_values,
202                timestamp: current_timestamp()?,
203                context: None,
204            });
205            Ok(1)
206        } else {
207            Ok(0)
208        }
209    }
210
211    /// Check for Inf values
212    fn check_inf(&mut self, layer_name: &str, values: &[f64]) -> Result<usize> {
213        let mut pos_inf_positions = Vec::new();
214        let mut neg_inf_positions = Vec::new();
215        let mut pos_inf_samples = Vec::new();
216        let mut neg_inf_samples = Vec::new();
217
218        for (i, &value) in values.iter().enumerate() {
219            if value.is_infinite() {
220                if value.is_sign_positive() {
221                    pos_inf_positions.push(vec![i]);
222                    if pos_inf_samples.len() < 10 {
223                        pos_inf_samples.push(value);
224                    }
225                } else {
226                    neg_inf_positions.push(vec![i]);
227                    if neg_inf_samples.len() < 10 {
228                        neg_inf_samples.push(value);
229                    }
230                }
231            }
232        }
233
234        let mut issues_count = 0;
235
236        if !pos_inf_positions.is_empty() {
237            let id = self.next_issue_id();
238            self.add_issue(StabilityIssue {
239                id,
240                layer_name: layer_name.to_string(),
241                kind: IssueKind::PosInf,
242                count: pos_inf_positions.len(),
243                positions: pos_inf_positions,
244                sample_values: pos_inf_samples,
245                timestamp: current_timestamp()?,
246                context: None,
247            });
248            issues_count += 1;
249        }
250
251        if !neg_inf_positions.is_empty() {
252            let id = self.next_issue_id();
253            self.add_issue(StabilityIssue {
254                id,
255                layer_name: layer_name.to_string(),
256                kind: IssueKind::NegInf,
257                count: neg_inf_positions.len(),
258                positions: neg_inf_positions,
259                sample_values: neg_inf_samples,
260                timestamp: current_timestamp()?,
261                context: None,
262            });
263            issues_count += 1;
264        }
265
266        Ok(issues_count)
267    }
268
269    /// Check for underflow
270    fn check_underflow(&mut self, layer_name: &str, values: &[f64]) -> Result<usize> {
271        let mut positions = Vec::new();
272        let mut sample_values = Vec::new();
273
274        for (i, &value) in values.iter().enumerate() {
275            if !value.is_nan()
276                && !value.is_infinite()
277                && value != 0.0
278                && value.abs() < self.config.underflow_threshold
279            {
280                positions.push(vec![i]);
281                if sample_values.len() < 10 {
282                    sample_values.push(value);
283                }
284            }
285        }
286
287        if !positions.is_empty() {
288            let id = self.next_issue_id();
289            let threshold = self.config.underflow_threshold;
290            self.add_issue(StabilityIssue {
291                id,
292                layer_name: layer_name.to_string(),
293                kind: IssueKind::Underflow,
294                count: positions.len(),
295                positions,
296                sample_values,
297                timestamp: current_timestamp()?,
298                context: Some(format!("threshold: {}", threshold)),
299            });
300            Ok(1)
301        } else {
302            Ok(0)
303        }
304    }
305
306    /// Check for overflow
307    fn check_overflow(&mut self, layer_name: &str, values: &[f64]) -> Result<usize> {
308        let mut positions = Vec::new();
309        let mut sample_values = Vec::new();
310
311        for (i, &value) in values.iter().enumerate() {
312            if !value.is_nan()
313                && !value.is_infinite()
314                && value.abs() > self.config.overflow_threshold
315            {
316                positions.push(vec![i]);
317                if sample_values.len() < 10 {
318                    sample_values.push(value);
319                }
320            }
321        }
322
323        if !positions.is_empty() {
324            let id = self.next_issue_id();
325            let threshold = self.config.overflow_threshold;
326            self.add_issue(StabilityIssue {
327                id,
328                layer_name: layer_name.to_string(),
329                kind: IssueKind::Overflow,
330                count: positions.len(),
331                positions,
332                sample_values,
333                timestamp: current_timestamp()?,
334                context: Some(format!("threshold: {}", threshold)),
335            });
336            Ok(1)
337        } else {
338            Ok(0)
339        }
340    }
341
342    /// Add an issue to the tracker
343    fn add_issue(&mut self, issue: StabilityIssue) {
344        let layer_name = issue.layer_name.clone();
345        self.issues.entry(layer_name).or_default().push(issue);
346    }
347
348    /// Get the next issue ID
349    fn next_issue_id(&mut self) -> usize {
350        let id = self.issue_counter;
351        self.issue_counter += 1;
352        id
353    }
354
    /// Get all issues recorded for a specific layer.
    ///
    /// Returns `None` when nothing has been recorded under `layer_name`;
    /// otherwise the issues appear in the order they were recorded.
    pub fn get_issues(&self, layer_name: &str) -> Option<&Vec<StabilityIssue>> {
        self.issues.get(layer_name)
    }
359
360    /// Get all issues
361    pub fn get_all_issues(&self) -> Vec<&StabilityIssue> {
362        self.issues.values().flatten().collect()
363    }
364
365    /// Get summary of all issues
366    pub fn summary(&self) -> StabilitySummary {
367        let mut issues_by_kind: HashMap<IssueKind, usize> = HashMap::new();
368        let mut issues_by_layer: HashMap<String, usize> = HashMap::new();
369
370        for (layer_name, layer_issues) in &self.issues {
371            issues_by_layer.insert(layer_name.clone(), layer_issues.len());
372
373            for issue in layer_issues {
374                *issues_by_kind.entry(issue.kind).or_insert(0) += 1;
375            }
376        }
377
378        let mut problematic_layers: Vec<_> =
379            issues_by_layer.iter().map(|(k, &v)| (k.clone(), v)).collect();
380        problematic_layers.sort_by(|a, b| b.1.cmp(&a.1));
381
382        let total_issues = self.get_all_issues().len();
383
384        StabilitySummary {
385            total_issues,
386            issues_by_kind,
387            issues_by_layer,
388            problematic_layers,
389        }
390    }
391
392    /// Print a detailed report
393    pub fn report(&self) -> String {
394        let mut output = String::new();
395        output.push_str("Numerical Stability Report\n");
396        output.push_str(&"=".repeat(80));
397        output.push('\n');
398
399        let summary = self.summary();
400
401        output.push_str(&format!("\nTotal Issues: {}\n", summary.total_issues));
402
403        output.push_str("\nIssues by Type:\n");
404        for (kind, count) in &summary.issues_by_kind {
405            output.push_str(&format!("  {:?}: {}\n", kind, count));
406        }
407
408        output.push_str("\nMost Problematic Layers:\n");
409        for (layer, count) in summary.problematic_layers.iter().take(10) {
410            output.push_str(&format!("  {}: {} issues\n", layer, count));
411        }
412
413        output.push_str("\nDetailed Issues:\n");
414        for (layer_name, layer_issues) in &self.issues {
415            output.push_str(&format!("\n  Layer: {}\n", layer_name));
416            for issue in layer_issues {
417                output.push_str(&format!(
418                    "    [{:?}] {} occurrences",
419                    issue.kind, issue.count
420                ));
421                if let Some(ref context) = issue.context {
422                    output.push_str(&format!(" ({})", context));
423                }
424                output.push('\n');
425            }
426        }
427
428        output
429    }
430
431    /// Export issues to JSON
432    pub fn export_to_json(&self, output_path: &Path) -> Result<()> {
433        let json = serde_json::to_string_pretty(&self.issues)?;
434        std::fs::write(output_path, json)?;
435        Ok(())
436    }
437
438    /// Clear all recorded issues
439    pub fn clear(&mut self) {
440        self.issues.clear();
441        self.issue_counter = 0;
442    }
443
444    /// Check if any issues were detected
445    pub fn has_issues(&self) -> bool {
446        !self.issues.is_empty()
447    }
448
449    /// Get total number of issues
450    pub fn total_issues(&self) -> usize {
451        self.issues.values().map(|v| v.len()).sum()
452    }
453}
454
455impl Default for StabilityChecker {
456    fn default() -> Self {
457        Self::new()
458    }
459}
460
461/// Helper function to get current timestamp
462fn current_timestamp() -> Result<u64> {
463    Ok(std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH)?.as_secs())
464}
465
/// Unit tests for the stability checker.
///
/// The assertions pin exact issue counts: each check records one issue per
/// kind per tensor, regardless of how many values are affected.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_stability_checker_creation() {
        let checker = StabilityChecker::new();
        assert_eq!(checker.total_issues(), 0);
        assert!(!checker.has_issues());
    }

    #[test]
    fn test_check_nan() {
        let mut checker = StabilityChecker::new();
        let values = [1.0, f64::NAN, 2.0, f64::NAN];

        let issues = checker.check_tensor("layer1", &values).unwrap();
        assert_eq!(issues, 1);
        assert!(checker.has_issues());

        // Both NaNs are folded into one record that lists each index.
        let recorded = checker.get_issues("layer1").unwrap();
        assert_eq!(recorded.len(), 1);
        assert_eq!(recorded[0].kind, IssueKind::NaN);
        assert_eq!(recorded[0].count, 2);
        assert_eq!(recorded[0].positions, vec![vec![1], vec![3]]);
    }

    #[test]
    fn test_check_inf() {
        let mut checker = StabilityChecker::new();
        let values = [1.0, f64::INFINITY, 2.0, f64::NEG_INFINITY];

        // Positive and negative infinity are reported as separate issues.
        let issues = checker.check_tensor("layer1", &values).unwrap();
        assert_eq!(issues, 2);
        assert!(checker.has_issues());

        let recorded = checker.get_issues("layer1").unwrap();
        assert_eq!(recorded[0].kind, IssueKind::PosInf);
        assert_eq!(recorded[0].positions, vec![vec![1]]);
        assert_eq!(recorded[1].kind, IssueKind::NegInf);
        assert_eq!(recorded[1].positions, vec![vec![3]]);
    }

    #[test]
    fn test_check_underflow() {
        let mut checker = StabilityChecker::new();
        let values = [1.0, 1e-20, 2.0, 1e-18];

        let issues = checker.check_tensor("layer1", &values).unwrap();
        assert_eq!(issues, 1);

        let recorded = checker.get_issues("layer1").unwrap();
        assert_eq!(recorded[0].kind, IssueKind::Underflow);
        assert_eq!(recorded[0].count, 2);
    }

    #[test]
    fn test_zero_is_not_underflow() {
        let mut checker = StabilityChecker::new();

        // Exact zeros of either sign are legitimate values, not underflow.
        let issues = checker.check_tensor("layer1", &[0.0, -0.0, 1.0]).unwrap();
        assert_eq!(issues, 0);
        assert!(!checker.has_issues());
    }

    #[test]
    fn test_check_overflow() {
        let config = StabilityConfig {
            overflow_threshold: 100.0,
            ..Default::default()
        };

        let mut checker = StabilityChecker::with_config(config);
        let values = [1.0, 200.0, 2.0, 300.0];

        let issues = checker.check_tensor("layer1", &values).unwrap();
        assert_eq!(issues, 1);

        let recorded = checker.get_issues("layer1").unwrap();
        assert_eq!(recorded[0].kind, IssueKind::Overflow);
        assert_eq!(recorded[0].count, 2);
    }

    #[test]
    fn test_summary() {
        let mut checker = StabilityChecker::new();

        checker.check_tensor("layer1", &[f64::NAN, 1.0]).unwrap();
        checker.check_tensor("layer2", &[f64::INFINITY, 2.0]).unwrap();

        let summary = checker.summary();
        assert_eq!(summary.total_issues, 2);
        assert_eq!(summary.issues_by_layer.len(), 2);
        assert_eq!(summary.issues_by_kind.get(&IssueKind::NaN), Some(&1));
        assert_eq!(summary.issues_by_kind.get(&IssueKind::PosInf), Some(&1));
        assert_eq!(summary.problematic_layers.len(), 2);
    }

    #[test]
    fn test_report() {
        let mut checker = StabilityChecker::new();
        checker.check_tensor("layer1", &[f64::NAN, 1.0]).unwrap();

        let report = checker.report();
        assert!(report.contains("Numerical Stability Report"));
        assert!(report.contains("Total Issues: 1"));
        assert!(report.contains("layer1"));
    }

    #[test]
    fn test_export_to_json() {
        // Include the process ID so concurrent test runs do not share a file.
        let file_name = format!("stability_issues_{}.json", std::process::id());
        let output_path = std::env::temp_dir().join(file_name);

        let mut checker = StabilityChecker::new();
        checker.check_tensor("layer1", &[f64::NAN, 1.0]).unwrap();

        checker.export_to_json(&output_path).unwrap();
        let written = std::fs::read_to_string(&output_path);

        // Clean up before asserting so a failure does not leave the file behind.
        let _ = std::fs::remove_file(&output_path);

        let written = written.unwrap();
        assert!(written.contains("layer1"));
    }

    #[test]
    fn test_clear() {
        let mut checker = StabilityChecker::new();
        checker.check_tensor("layer1", &[f64::NAN]).unwrap();

        assert!(checker.has_issues());

        checker.clear();
        assert!(!checker.has_issues());
        assert_eq!(checker.total_issues(), 0);
        assert!(checker.get_issues("layer1").is_none());
    }

    #[test]
    fn test_no_issues() {
        let mut checker = StabilityChecker::new();
        let values = [1.0, 2.0, 3.0, 4.0, 5.0];

        let issues = checker.check_tensor("layer1", &values).unwrap();
        assert_eq!(issues, 0);
        assert!(!checker.has_issues());
    }

    #[test]
    fn test_custom_config() {
        let config = StabilityConfig {
            check_nan: true,
            check_inf: false,
            check_underflow: false,
            check_overflow: false,
            underflow_threshold: 1e-10,
            overflow_threshold: 1e10,
            stop_on_first_issue: false,
        };

        let mut checker = StabilityChecker::with_config(config);
        let values = [1.0, f64::INFINITY, f64::NAN];

        // Should only detect NaN, not Inf
        let issues = checker.check_tensor("layer1", &values).unwrap();
        assert_eq!(issues, 1);
    }

    #[test]
    fn test_stop_on_first_issue() {
        let config = StabilityConfig {
            stop_on_first_issue: true,
            ..Default::default()
        };
        let mut checker = StabilityChecker::with_config(config);

        // Clean tensors still succeed.
        assert_eq!(checker.check_tensor("clean", &[1.0, 2.0]).unwrap(), 0);

        // A problematic tensor turns into an error, but the issue is kept.
        assert!(checker.check_tensor("layer1", &[f64::NAN]).is_err());
        assert_eq!(checker.total_issues(), 1);
    }
}