1use anyhow::Result;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::path::Path;
10
11#[derive(Debug)]
13pub struct StabilityChecker {
14 issues: HashMap<String, Vec<StabilityIssue>>,
16 config: StabilityConfig,
18 issue_counter: usize,
20}
21
22#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct StabilityConfig {
25 pub check_nan: bool,
27 pub check_inf: bool,
29 pub check_underflow: bool,
31 pub check_overflow: bool,
33 pub underflow_threshold: f64,
35 pub overflow_threshold: f64,
37 pub stop_on_first_issue: bool,
39}
40
41impl Default for StabilityConfig {
42 fn default() -> Self {
43 Self {
44 check_nan: true,
45 check_inf: true,
46 check_underflow: true,
47 check_overflow: true,
48 underflow_threshold: 1e-15,
49 overflow_threshold: 1e15,
50 stop_on_first_issue: false,
51 }
52 }
53}
54
55#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
57pub enum IssueKind {
58 NaN,
60 PosInf,
62 NegInf,
64 Underflow,
66 Overflow,
68 PrecisionLoss,
70}
71
72#[derive(Debug, Clone, Serialize, Deserialize)]
74pub struct StabilityIssue {
75 pub id: usize,
77 pub layer_name: String,
79 pub kind: IssueKind,
81 pub count: usize,
83 pub positions: Vec<Vec<usize>>,
85 pub sample_values: Vec<f64>,
87 pub timestamp: u64,
89 pub context: Option<String>,
91}
92
93#[derive(Debug, Clone, Serialize, Deserialize)]
95pub struct StabilitySummary {
96 pub total_issues: usize,
98 pub issues_by_kind: HashMap<IssueKind, usize>,
100 pub issues_by_layer: HashMap<String, usize>,
102 pub problematic_layers: Vec<(String, usize)>,
104}
105
106impl StabilityChecker {
107 pub fn new() -> Self {
117 Self {
118 issues: HashMap::new(),
119 config: StabilityConfig::default(),
120 issue_counter: 0,
121 }
122 }
123
124 pub fn with_config(config: StabilityConfig) -> Self {
126 Self {
127 issues: HashMap::new(),
128 config,
129 issue_counter: 0,
130 }
131 }
132
133 pub fn check_tensor(&mut self, layer_name: &str, values: &[f64]) -> Result<usize> {
150 let mut issues_found = 0;
151
152 if self.config.check_nan {
154 issues_found += self.check_nan(layer_name, values)?;
155 }
156
157 if self.config.check_inf {
159 issues_found += self.check_inf(layer_name, values)?;
160 }
161
162 if self.config.check_underflow {
164 issues_found += self.check_underflow(layer_name, values)?;
165 }
166
167 if self.config.check_overflow {
169 issues_found += self.check_overflow(layer_name, values)?;
170 }
171
172 if self.config.stop_on_first_issue && issues_found > 0 {
173 anyhow::bail!("Stability issues detected in {}", layer_name);
174 }
175
176 Ok(issues_found)
177 }
178
179 fn check_nan(&mut self, layer_name: &str, values: &[f64]) -> Result<usize> {
181 let mut positions = Vec::new();
182 let mut sample_values = Vec::new();
183
184 for (i, &value) in values.iter().enumerate() {
185 if value.is_nan() {
186 positions.push(vec![i]);
187 if sample_values.len() < 10 {
188 sample_values.push(value);
189 }
190 }
191 }
192
193 if !positions.is_empty() {
194 let id = self.next_issue_id();
195 self.add_issue(StabilityIssue {
196 id,
197 layer_name: layer_name.to_string(),
198 kind: IssueKind::NaN,
199 count: positions.len(),
200 positions,
201 sample_values,
202 timestamp: current_timestamp()?,
203 context: None,
204 });
205 Ok(1)
206 } else {
207 Ok(0)
208 }
209 }
210
211 fn check_inf(&mut self, layer_name: &str, values: &[f64]) -> Result<usize> {
213 let mut pos_inf_positions = Vec::new();
214 let mut neg_inf_positions = Vec::new();
215 let mut pos_inf_samples = Vec::new();
216 let mut neg_inf_samples = Vec::new();
217
218 for (i, &value) in values.iter().enumerate() {
219 if value.is_infinite() {
220 if value.is_sign_positive() {
221 pos_inf_positions.push(vec![i]);
222 if pos_inf_samples.len() < 10 {
223 pos_inf_samples.push(value);
224 }
225 } else {
226 neg_inf_positions.push(vec![i]);
227 if neg_inf_samples.len() < 10 {
228 neg_inf_samples.push(value);
229 }
230 }
231 }
232 }
233
234 let mut issues_count = 0;
235
236 if !pos_inf_positions.is_empty() {
237 let id = self.next_issue_id();
238 self.add_issue(StabilityIssue {
239 id,
240 layer_name: layer_name.to_string(),
241 kind: IssueKind::PosInf,
242 count: pos_inf_positions.len(),
243 positions: pos_inf_positions,
244 sample_values: pos_inf_samples,
245 timestamp: current_timestamp()?,
246 context: None,
247 });
248 issues_count += 1;
249 }
250
251 if !neg_inf_positions.is_empty() {
252 let id = self.next_issue_id();
253 self.add_issue(StabilityIssue {
254 id,
255 layer_name: layer_name.to_string(),
256 kind: IssueKind::NegInf,
257 count: neg_inf_positions.len(),
258 positions: neg_inf_positions,
259 sample_values: neg_inf_samples,
260 timestamp: current_timestamp()?,
261 context: None,
262 });
263 issues_count += 1;
264 }
265
266 Ok(issues_count)
267 }
268
269 fn check_underflow(&mut self, layer_name: &str, values: &[f64]) -> Result<usize> {
271 let mut positions = Vec::new();
272 let mut sample_values = Vec::new();
273
274 for (i, &value) in values.iter().enumerate() {
275 if !value.is_nan()
276 && !value.is_infinite()
277 && value != 0.0
278 && value.abs() < self.config.underflow_threshold
279 {
280 positions.push(vec![i]);
281 if sample_values.len() < 10 {
282 sample_values.push(value);
283 }
284 }
285 }
286
287 if !positions.is_empty() {
288 let id = self.next_issue_id();
289 let threshold = self.config.underflow_threshold;
290 self.add_issue(StabilityIssue {
291 id,
292 layer_name: layer_name.to_string(),
293 kind: IssueKind::Underflow,
294 count: positions.len(),
295 positions,
296 sample_values,
297 timestamp: current_timestamp()?,
298 context: Some(format!("threshold: {}", threshold)),
299 });
300 Ok(1)
301 } else {
302 Ok(0)
303 }
304 }
305
306 fn check_overflow(&mut self, layer_name: &str, values: &[f64]) -> Result<usize> {
308 let mut positions = Vec::new();
309 let mut sample_values = Vec::new();
310
311 for (i, &value) in values.iter().enumerate() {
312 if !value.is_nan()
313 && !value.is_infinite()
314 && value.abs() > self.config.overflow_threshold
315 {
316 positions.push(vec![i]);
317 if sample_values.len() < 10 {
318 sample_values.push(value);
319 }
320 }
321 }
322
323 if !positions.is_empty() {
324 let id = self.next_issue_id();
325 let threshold = self.config.overflow_threshold;
326 self.add_issue(StabilityIssue {
327 id,
328 layer_name: layer_name.to_string(),
329 kind: IssueKind::Overflow,
330 count: positions.len(),
331 positions,
332 sample_values,
333 timestamp: current_timestamp()?,
334 context: Some(format!("threshold: {}", threshold)),
335 });
336 Ok(1)
337 } else {
338 Ok(0)
339 }
340 }
341
342 fn add_issue(&mut self, issue: StabilityIssue) {
344 let layer_name = issue.layer_name.clone();
345 self.issues.entry(layer_name).or_default().push(issue);
346 }
347
348 fn next_issue_id(&mut self) -> usize {
350 let id = self.issue_counter;
351 self.issue_counter += 1;
352 id
353 }
354
355 pub fn get_issues(&self, layer_name: &str) -> Option<&Vec<StabilityIssue>> {
357 self.issues.get(layer_name)
358 }
359
360 pub fn get_all_issues(&self) -> Vec<&StabilityIssue> {
362 self.issues.values().flatten().collect()
363 }
364
365 pub fn summary(&self) -> StabilitySummary {
367 let mut issues_by_kind: HashMap<IssueKind, usize> = HashMap::new();
368 let mut issues_by_layer: HashMap<String, usize> = HashMap::new();
369
370 for (layer_name, layer_issues) in &self.issues {
371 issues_by_layer.insert(layer_name.clone(), layer_issues.len());
372
373 for issue in layer_issues {
374 *issues_by_kind.entry(issue.kind).or_insert(0) += 1;
375 }
376 }
377
378 let mut problematic_layers: Vec<_> =
379 issues_by_layer.iter().map(|(k, &v)| (k.clone(), v)).collect();
380 problematic_layers.sort_by_key(|item| std::cmp::Reverse(item.1));
381
382 let total_issues = self.get_all_issues().len();
383
384 StabilitySummary {
385 total_issues,
386 issues_by_kind,
387 issues_by_layer,
388 problematic_layers,
389 }
390 }
391
392 pub fn report(&self) -> String {
394 let mut output = String::new();
395 output.push_str("Numerical Stability Report\n");
396 output.push_str(&"=".repeat(80));
397 output.push('\n');
398
399 let summary = self.summary();
400
401 output.push_str(&format!("\nTotal Issues: {}\n", summary.total_issues));
402
403 output.push_str("\nIssues by Type:\n");
404 for (kind, count) in &summary.issues_by_kind {
405 output.push_str(&format!(" {:?}: {}\n", kind, count));
406 }
407
408 output.push_str("\nMost Problematic Layers:\n");
409 for (layer, count) in summary.problematic_layers.iter().take(10) {
410 output.push_str(&format!(" {}: {} issues\n", layer, count));
411 }
412
413 output.push_str("\nDetailed Issues:\n");
414 for (layer_name, layer_issues) in &self.issues {
415 output.push_str(&format!("\n Layer: {}\n", layer_name));
416 for issue in layer_issues {
417 output.push_str(&format!(
418 " [{:?}] {} occurrences",
419 issue.kind, issue.count
420 ));
421 if let Some(ref context) = issue.context {
422 output.push_str(&format!(" ({})", context));
423 }
424 output.push('\n');
425 }
426 }
427
428 output
429 }
430
431 pub fn export_to_json(&self, output_path: &Path) -> Result<()> {
433 let json = serde_json::to_string_pretty(&self.issues)?;
434 std::fs::write(output_path, json)?;
435 Ok(())
436 }
437
438 pub fn clear(&mut self) {
440 self.issues.clear();
441 self.issue_counter = 0;
442 }
443
444 pub fn has_issues(&self) -> bool {
446 !self.issues.is_empty()
447 }
448
449 pub fn total_issues(&self) -> usize {
451 self.issues.values().map(|v| v.len()).sum()
452 }
453}
454
455impl Default for StabilityChecker {
456 fn default() -> Self {
457 Self::new()
458 }
459}
460
461fn current_timestamp() -> Result<u64> {
463 Ok(std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH)?.as_secs())
464}
465
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_stability_checker_creation() {
        let checker = StabilityChecker::new();
        assert_eq!(checker.total_issues(), 0);
        assert!(!checker.has_issues());
    }

    #[test]
    fn test_check_nan() {
        let mut checker = StabilityChecker::new();
        let values = vec![1.0, f64::NAN, 2.0, f64::NAN];

        let issues = checker.check_tensor("layer1", &values).expect("tensor operation failed");
        assert_eq!(issues, 1);
        assert!(checker.has_issues());

        let recorded = checker.get_issues("layer1").expect("layer1 should have issues");
        assert_eq!(recorded.len(), 1);
        assert_eq!(recorded[0].kind, IssueKind::NaN);
        assert_eq!(recorded[0].count, 2);
        assert_eq!(recorded[0].positions, vec![vec![1], vec![3]]);
    }

    #[test]
    fn test_check_inf() {
        let mut checker = StabilityChecker::new();
        let values = vec![1.0, f64::INFINITY, 2.0, f64::NEG_INFINITY];

        // One issue per sign.
        let issues = checker.check_tensor("layer1", &values).expect("tensor operation failed");
        assert_eq!(issues, 2);
        assert!(checker.has_issues());

        let recorded = checker.get_issues("layer1").expect("layer1 should have issues");
        assert_eq!(recorded[0].kind, IssueKind::PosInf);
        assert_eq!(recorded[1].kind, IssueKind::NegInf);
    }

    #[test]
    fn test_check_underflow() {
        let mut checker = StabilityChecker::new();
        let values = vec![1.0, 1e-20, 2.0, 1e-18];

        let issues = checker.check_tensor("layer1", &values).expect("tensor operation failed");
        assert_eq!(issues, 1);

        let recorded = checker.get_issues("layer1").expect("layer1 should have issues");
        assert_eq!(recorded[0].kind, IssueKind::Underflow);
        assert_eq!(recorded[0].count, 2);
    }

    #[test]
    fn test_check_overflow() {
        let config = StabilityConfig {
            overflow_threshold: 100.0,
            ..StabilityConfig::default()
        };

        let mut checker = StabilityChecker::with_config(config);
        let values = vec![1.0, 200.0, 2.0, 300.0];

        let issues = checker.check_tensor("layer1", &values).expect("tensor operation failed");
        assert_eq!(issues, 1);

        let recorded = checker.get_issues("layer1").expect("layer1 should have issues");
        assert_eq!(recorded[0].kind, IssueKind::Overflow);
        assert_eq!(recorded[0].count, 2);
    }

    #[test]
    fn test_summary() {
        let mut checker = StabilityChecker::new();

        checker
            .check_tensor("layer1", &[f64::NAN, 1.0])
            .expect("tensor operation failed");
        checker
            .check_tensor("layer2", &[f64::INFINITY, 2.0])
            .expect("tensor operation failed");

        let summary = checker.summary();
        assert_eq!(summary.total_issues, 2);
        assert_eq!(summary.issues_by_layer.len(), 2);
        assert_eq!(summary.issues_by_kind.get(&IssueKind::NaN), Some(&1));
        assert_eq!(summary.issues_by_kind.get(&IssueKind::PosInf), Some(&1));
        assert_eq!(summary.problematic_layers.len(), 2);
    }

    #[test]
    fn test_report() {
        let mut checker = StabilityChecker::new();
        checker
            .check_tensor("layer1", &[f64::NAN, 1.0])
            .expect("tensor operation failed");

        let report = checker.report();
        assert!(report.contains("Numerical Stability Report"));
        assert!(report.contains("layer1"));
        assert!(report.contains("[NaN] 1 occurrences"));
    }

    #[test]
    fn test_export_to_json() {
        // Include the process id so concurrent test runs do not share a file.
        let file_name = format!("stability_issues_{}.json", std::process::id());
        let output_path = std::env::temp_dir().join(file_name);

        let mut checker = StabilityChecker::new();
        checker
            .check_tensor("layer1", &[f64::NAN, 1.0])
            .expect("tensor operation failed");

        checker.export_to_json(&output_path).expect("operation failed in test");

        // Read and clean up before asserting so a failure does not leak the file.
        let written = std::fs::read_to_string(&output_path);
        let _ = std::fs::remove_file(&output_path);

        let json = written.expect("exported file should be readable");
        assert!(json.contains("layer1"));
    }

    #[test]
    fn test_clear() {
        let mut checker = StabilityChecker::new();
        checker.check_tensor("layer1", &[f64::NAN]).expect("tensor operation failed");

        assert!(checker.has_issues());

        checker.clear();
        assert!(!checker.has_issues());
        assert_eq!(checker.total_issues(), 0);
        assert!(checker.get_issues("layer1").is_none());
    }

    #[test]
    fn test_no_issues() {
        let mut checker = StabilityChecker::new();
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];

        let issues = checker.check_tensor("layer1", &values).expect("tensor operation failed");
        assert_eq!(issues, 0);
        assert!(!checker.has_issues());
    }

    #[test]
    fn test_custom_config() {
        let config = StabilityConfig {
            check_nan: true,
            check_inf: false,
            check_underflow: false,
            check_overflow: false,
            underflow_threshold: 1e-10,
            overflow_threshold: 1e10,
            stop_on_first_issue: false,
        };

        let mut checker = StabilityChecker::with_config(config);
        let values = vec![1.0, f64::INFINITY, f64::NAN];

        // Only the NaN check is enabled, so the infinity must be ignored.
        let issues = checker.check_tensor("layer1", &values).expect("tensor operation failed");
        assert_eq!(issues, 1);

        let recorded = checker.get_issues("layer1").expect("layer1 should have issues");
        assert_eq!(recorded[0].kind, IssueKind::NaN);
    }

    #[test]
    fn test_stop_on_first_issue() {
        let config = StabilityConfig {
            stop_on_first_issue: true,
            ..StabilityConfig::default()
        };
        let mut checker = StabilityChecker::with_config(config);

        // A clean tensor passes.
        let clean = checker.check_tensor("layer1", &[1.0, 2.0]).expect("tensor operation failed");
        assert_eq!(clean, 0);

        // A bad tensor yields an error, but the issue is still recorded.
        assert!(checker.check_tensor("layer1", &[f64::NAN]).is_err());
        assert!(checker.has_issues());
    }
}