Skip to main content

trustformers_debug/
anomaly_detector.rs

1//! Anomaly Detection for Model Debugging
2//!
3//! Detects unusual patterns in model execution, tensor values, and gradients
4//! to help identify potential issues during training and inference.
5
6use anyhow::Result;
7use chrono::{DateTime, Utc};
8use serde::{Deserialize, Serialize};
9use std::collections::{HashMap, VecDeque};
10
11use crate::DebugConfig;
12
/// Anomaly detector for model execution
///
/// Holds the detection configuration, every anomaly reported so far, the
/// auto-recovery log, and the rolling histories that the `check_*` methods
/// consult (performance, gradients, loss, per-layer weight baselines).
#[derive(Debug)]
pub struct AnomalyDetector {
    // Feature toggles and numeric thresholds driving all checks.
    config: AnomalyDetectorConfig,
    // Every anomaly recorded via `report_anomaly`, in arrival order.
    detected_anomalies: Vec<Anomaly>,
    // When monitoring began; reset by `start()`.
    start_time: DateTime<Utc>,
    // Log of auto-recovery attempts made by `attempt_recovery`.
    recovery_attempts: Vec<RecoveryAttempt>,
    // Aggregate counters plus the snapshot window.
    monitoring_stats: MonitoringStats,
    // Rolling performance samples, capped at `monitoring_window_size`.
    performance_history: VecDeque<f64>,
    // Per-layer gradient history; not read anywhere in this file yet.
    #[allow(dead_code)]
    gradient_history: HashMap<String, VecDeque<f64>>,
    // Rolling loss samples used for spike detection.
    loss_history: VecDeque<f64>,
    // First-seen weights per layer; the reference for divergence checks.
    weight_baseline: HashMap<String, Vec<f32>>,
}
27
/// Configuration for anomaly detection
///
/// Each `enable_*` flag toggles one family of checks; the remaining fields
/// are the thresholds and window sizes those checks consult. See the
/// `Default` impl for concrete default values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AnomalyDetectorConfig {
    pub enable_nan_detection: bool,
    pub enable_inf_detection: bool,
    pub enable_gradient_explosion: bool,
    pub enable_gradient_vanishing: bool,
    // Gradient-norm ceiling for the explosion check.
    pub gradient_threshold: f64,
    pub enable_memory_leak_detection: bool,
    pub enable_numerical_instability_detection: bool,
    pub enable_gradient_conflict_detection: bool,
    pub enable_performance_monitoring: bool,
    pub enable_weight_divergence_detection: bool,
    pub enable_activation_dead_detection: bool,
    pub enable_loss_anomaly_detection: bool,
    // Opt-in automatic recovery; see `attempt_recovery`.
    pub enable_auto_recovery: bool,
    // NOTE(review): not read by any check in this file — confirm intended use.
    pub numerical_instability_threshold: f64,
    // Relative performance drop (0.0-1.0) counted as degradation.
    pub performance_degradation_threshold: f64,
    // RMS deviation from baseline weights counted as divergence.
    pub weight_divergence_threshold: f64,
    // Multiplicative loss increase counted as a spike.
    pub loss_spike_threshold: f64,
    // Capacity of the rolling histories and the snapshot window.
    pub monitoring_window_size: usize,
    // NOTE(review): not enforced by `attempt_recovery` — confirm intended use.
    pub recovery_attempts_limit: usize,
}
51
/// Defaults: every detector enabled, auto-recovery off.
impl Default for AnomalyDetectorConfig {
    fn default() -> Self {
        Self {
            enable_nan_detection: true,
            enable_inf_detection: true,
            enable_gradient_explosion: true,
            enable_gradient_vanishing: true,
            gradient_threshold: 1e6,
            enable_memory_leak_detection: true,
            enable_numerical_instability_detection: true,
            enable_gradient_conflict_detection: true,
            enable_performance_monitoring: true,
            enable_weight_divergence_detection: true,
            enable_activation_dead_detection: true,
            enable_loss_anomaly_detection: true,
            enable_auto_recovery: false, // Conservative default
            numerical_instability_threshold: 1e-12,
            performance_degradation_threshold: 0.5, // 50% degradation
            weight_divergence_threshold: 5.0,
            loss_spike_threshold: 10.0, // 10x loss increase
            monitoring_window_size: 100,
            recovery_attempts_limit: 3,
        }
    }
}
77
/// Types of anomalies that can be detected
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AnomalyType {
    // NaN values in a tensor.
    NaN,
    // Infinite values in a tensor.
    Infinity,
    // Gradient norm above `gradient_threshold`.
    GradientExplosion,
    // Gradient norm below the vanishing cutoff.
    GradientVanishing,
    // Memory usage above expected / absolute limits.
    MemoryLeak,
    // Catch-all for unusual activation/weight patterns.
    UnusualActivation,
    // Near-zero or extreme values that threaten numeric stability.
    NumericalInstability,
    // Anti-correlated gradients between two layers.
    GradientConflict,
    // Throughput drop versus the rolling baseline.
    PerformanceDegradation,
    // Weights drifting far from their recorded baseline.
    WeightDivergence,
    // Dead (saturated-at-zero) activations.
    ActivationDead,
    // Loss spike relative to the previous step.
    LossAnomalous,
}
94
/// A detected anomaly
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Anomaly {
    pub anomaly_type: AnomalyType,
    // When the anomaly was detected (UTC).
    pub timestamp: DateTime<Utc>,
    // Where it occurred (layer name, tensor name, or call site label).
    pub location: String,
    // Human-readable summary of what was observed.
    pub description: String,
    pub severity: AnomalySeverity,
    // Free-form key/value details (norms, counts, thresholds, ...).
    pub metadata: HashMap<String, String>,
}
105
/// Severity level of an anomaly, ordered from least to most serious.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AnomalySeverity {
    Low,
    Medium,
    High,
    Critical,
}
114
/// Auto-recovery action that can be taken
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum RecoveryAction {
    // No action taken (also returned when auto-recovery is disabled).
    None,
    ResetGradients,
    // Multiply the learning rate by `factor` (< 1.0 reduces it).
    ReduceLearningRate { factor: f64 },
    // Clip gradients to the given maximum norm.
    ClipGradients { max_norm: f64 },
    RestartOptimizer,
    SkipBatch,
    ResetWeights { layer_name: String },
    ApplyWeightDecay { rate: f64 },
    // Halt training entirely; the only action treated as a failure.
    EmergencyStop,
}
128
/// Recovery attempt record
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RecoveryAttempt {
    // Identifier derived from the anomaly type + its timestamp.
    pub anomaly_id: String,
    // The action that was executed.
    pub action: RecoveryAction,
    // When the attempt was made (UTC).
    pub timestamp: DateTime<Utc>,
    pub success: bool,
    // Populated only when the attempt failed.
    pub error_message: Option<String>,
}
138
/// Real-time monitoring statistics
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MonitoringStats {
    // Total anomalies reported since construction.
    pub total_anomalies: usize,
    // Count per `AnomalyType` (keyed by the type's Debug string).
    pub anomalies_per_type: HashMap<String, usize>,
    pub recovery_attempts: usize,
    pub successful_recoveries: usize,
    // NOTE(review): never updated in this file — confirm where it is computed.
    pub average_detection_time_ms: f64,
    // Rolling snapshots taken by `update_monitoring_window`.
    pub monitoring_window: Vec<AnomalySnapshot>,
}
149
/// Snapshot of anomaly state for monitoring window
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AnomalySnapshot {
    pub timestamp: DateTime<Utc>,
    // Total anomalies detected at snapshot time.
    pub anomaly_count: usize,
    // Histogram keyed by the severity's Debug string.
    pub severity_distribution: HashMap<String, usize>,
    // Latest performance/loss samples, when available.
    pub performance_metrics: HashMap<String, f64>,
}
158
159impl AnomalyDetector {
160    /// Create a new anomaly detector
161    pub fn new(_config: &DebugConfig) -> Self {
162        let monitoring_window_size = AnomalyDetectorConfig::default().monitoring_window_size;
163        Self {
164            config: AnomalyDetectorConfig::default(),
165            detected_anomalies: Vec::new(),
166            start_time: Utc::now(),
167            recovery_attempts: Vec::new(),
168            monitoring_stats: MonitoringStats {
169                total_anomalies: 0,
170                anomalies_per_type: HashMap::new(),
171                recovery_attempts: 0,
172                successful_recoveries: 0,
173                average_detection_time_ms: 0.0,
174                monitoring_window: Vec::new(),
175            },
176            performance_history: VecDeque::with_capacity(monitoring_window_size),
177            gradient_history: HashMap::new(),
178            loss_history: VecDeque::with_capacity(monitoring_window_size),
179            weight_baseline: HashMap::new(),
180        }
181    }
182
183    /// Start the anomaly detector
184    pub async fn start(&mut self) -> Result<()> {
185        self.start_time = Utc::now();
186        self.detected_anomalies.clear();
187        Ok(())
188    }
189
190    /// Check for NaN values in tensors
191    pub fn check_nan(&mut self, values: &[f32], location: &str) -> Result<()> {
192        if !self.config.enable_nan_detection {
193            return Ok(());
194        }
195
196        if values.iter().any(|v| v.is_nan()) {
197            self.report_anomaly(Anomaly {
198                anomaly_type: AnomalyType::NaN,
199                timestamp: Utc::now(),
200                location: location.to_string(),
201                description: "NaN values detected in tensor".to_string(),
202                severity: AnomalySeverity::High,
203                metadata: HashMap::new(),
204            });
205        }
206
207        Ok(())
208    }
209
210    /// Check for infinite values in tensors
211    pub fn check_inf(&mut self, values: &[f32], location: &str) -> Result<()> {
212        if !self.config.enable_inf_detection {
213            return Ok(());
214        }
215
216        if values.iter().any(|v| v.is_infinite()) {
217            self.report_anomaly(Anomaly {
218                anomaly_type: AnomalyType::Infinity,
219                timestamp: Utc::now(),
220                location: location.to_string(),
221                description: "Infinite values detected in tensor".to_string(),
222                severity: AnomalySeverity::High,
223                metadata: HashMap::new(),
224            });
225        }
226
227        Ok(())
228    }
229
230    /// Check for gradient explosion
231    pub fn check_gradient_explosion(&mut self, gradient_norm: f64, location: &str) -> Result<()> {
232        if !self.config.enable_gradient_explosion {
233            return Ok(());
234        }
235
236        if gradient_norm > self.config.gradient_threshold {
237            self.report_anomaly(Anomaly {
238                anomaly_type: AnomalyType::GradientExplosion,
239                timestamp: Utc::now(),
240                location: location.to_string(),
241                description: format!("Gradient explosion detected: norm = {}", gradient_norm),
242                severity: AnomalySeverity::Critical,
243                metadata: {
244                    let mut meta = HashMap::new();
245                    meta.insert("gradient_norm".to_string(), gradient_norm.to_string());
246                    meta
247                },
248            });
249        }
250
251        Ok(())
252    }
253
254    /// Check for vanishing gradients
255    pub fn check_gradient_vanishing(&mut self, gradient_norm: f64, location: &str) -> Result<()> {
256        if !self.config.enable_gradient_vanishing {
257            return Ok(());
258        }
259
260        let vanishing_threshold = 1e-8;
261        if gradient_norm < vanishing_threshold {
262            self.report_anomaly(Anomaly {
263                anomaly_type: AnomalyType::GradientVanishing,
264                timestamp: Utc::now(),
265                location: location.to_string(),
266                description: format!("Vanishing gradient detected: norm = {}", gradient_norm),
267                severity: AnomalySeverity::High,
268                metadata: {
269                    let mut meta = HashMap::new();
270                    meta.insert("gradient_norm".to_string(), gradient_norm.to_string());
271                    meta.insert("threshold".to_string(), vanishing_threshold.to_string());
272                    meta
273                },
274            });
275        }
276
277        Ok(())
278    }
279
280    /// Check for numerical instability
281    pub fn check_numerical_instability(&mut self, values: &[f32], location: &str) -> Result<()> {
282        let mut metadata = HashMap::new();
283
284        // Check for values close to zero that might cause division problems
285        let near_zero_count = values.iter().filter(|&&v| v.abs() < 1e-10 && v != 0.0).count();
286        if near_zero_count > values.len() / 10 {
287            metadata.insert("near_zero_count".to_string(), near_zero_count.to_string());
288            metadata.insert("total_values".to_string(), values.len().to_string());
289
290            self.report_anomaly(Anomaly {
291                anomaly_type: AnomalyType::UnusualActivation,
292                timestamp: Utc::now(),
293                location: location.to_string(),
294                description: format!(
295                    "Numerical instability: {} values near zero",
296                    near_zero_count
297                ),
298                severity: AnomalySeverity::Medium,
299                metadata: metadata.clone(),
300            });
301        }
302
303        // Check for extreme values that might cause overflow
304        let extreme_count = values.iter().filter(|&&v| v.abs() > 1e6).count();
305        if extreme_count > 0 {
306            metadata.insert("extreme_count".to_string(), extreme_count.to_string());
307
308            self.report_anomaly(Anomaly {
309                anomaly_type: AnomalyType::UnusualActivation,
310                timestamp: Utc::now(),
311                location: location.to_string(),
312                description: format!("Numerical instability: {} extreme values", extreme_count),
313                severity: AnomalySeverity::High,
314                metadata,
315            });
316        }
317
318        Ok(())
319    }
320
321    /// Check for activation saturation
322    pub fn check_activation_saturation(
323        &mut self,
324        activations: &[f32],
325        activation_type: &str,
326        location: &str,
327    ) -> Result<()> {
328        let saturation_threshold = match activation_type.to_lowercase().as_str() {
329            "sigmoid" | "tanh" => 0.01, // Close to 0 or 1 for sigmoid, -1 or 1 for tanh
330            "relu" => 0.0,              // Zero values for ReLU
331            _ => 0.01,
332        };
333
334        let saturated_count = match activation_type.to_lowercase().as_str() {
335            "sigmoid" => activations
336                .iter()
337                .filter(|&&v| v < saturation_threshold || v > 1.0 - saturation_threshold)
338                .count(),
339            "tanh" => activations.iter().filter(|&&v| v.abs() > 1.0 - saturation_threshold).count(),
340            "relu" => activations.iter().filter(|&&v| v == 0.0).count(),
341            _ => activations.iter().filter(|&&v| v.abs() < saturation_threshold).count(),
342        };
343
344        let saturation_ratio = saturated_count as f32 / activations.len() as f32;
345
346        if saturation_ratio > 0.9 {
347            let mut metadata = HashMap::new();
348            metadata.insert("activation_type".to_string(), activation_type.to_string());
349            metadata.insert("saturated_count".to_string(), saturated_count.to_string());
350            metadata.insert("total_count".to_string(), activations.len().to_string());
351            metadata.insert("saturation_ratio".to_string(), saturation_ratio.to_string());
352
353            self.report_anomaly(Anomaly {
354                anomaly_type: AnomalyType::UnusualActivation,
355                timestamp: Utc::now(),
356                location: location.to_string(),
357                description: format!(
358                    "Activation saturation detected: {:.1}% of {} activations saturated",
359                    saturation_ratio * 100.0,
360                    activation_type
361                ),
362                severity: AnomalySeverity::High,
363                metadata,
364            });
365        }
366
367        Ok(())
368    }
369
370    /// Check for memory leaks by tracking memory usage patterns
371    pub fn check_memory_leak(
372        &mut self,
373        current_memory_mb: usize,
374        expected_memory_mb: Option<usize>,
375        location: &str,
376    ) -> Result<()> {
377        if !self.config.enable_memory_leak_detection {
378            return Ok(());
379        }
380
381        let mut should_report = false;
382        let mut description = String::new();
383        let mut metadata = HashMap::new();
384
385        metadata.insert(
386            "current_memory_mb".to_string(),
387            current_memory_mb.to_string(),
388        );
389
390        if let Some(expected) = expected_memory_mb {
391            metadata.insert("expected_memory_mb".to_string(), expected.to_string());
392
393            let growth_ratio = current_memory_mb as f64 / expected as f64;
394            if growth_ratio > 2.0 {
395                should_report = true;
396                description = format!(
397                    "Memory usage {}MB is {:.1}x expected {}MB",
398                    current_memory_mb, growth_ratio, expected
399                );
400                metadata.insert("growth_ratio".to_string(), growth_ratio.to_string());
401            }
402        } else {
403            // Check for absolute high memory usage
404            if current_memory_mb > 8192 {
405                // 8GB threshold
406                should_report = true;
407                description = format!("High memory usage detected: {}MB", current_memory_mb);
408            }
409        }
410
411        if should_report {
412            self.report_anomaly(Anomaly {
413                anomaly_type: AnomalyType::MemoryLeak,
414                timestamp: Utc::now(),
415                location: location.to_string(),
416                description,
417                severity: if current_memory_mb > 16384 {
418                    AnomalySeverity::Critical
419                } else {
420                    AnomalySeverity::High
421                },
422                metadata,
423            });
424        }
425
426        Ok(())
427    }
428
429    /// Check for weight explosion in model parameters
430    pub fn check_weight_explosion(&mut self, weights: &[f32], layer_name: &str) -> Result<()> {
431        let weight_threshold = 10.0;
432        let extreme_weights: Vec<f32> =
433            weights.iter().filter(|&&w| w.abs() > weight_threshold).cloned().collect();
434
435        if !extreme_weights.is_empty() {
436            let mut metadata = HashMap::new();
437            metadata.insert("layer_name".to_string(), layer_name.to_string());
438            metadata.insert(
439                "extreme_weight_count".to_string(),
440                extreme_weights.len().to_string(),
441            );
442            metadata.insert("total_weight_count".to_string(), weights.len().to_string());
443            metadata.insert(
444                "max_weight".to_string(),
445                extreme_weights.iter().map(|w| w.abs()).fold(0.0f32, f32::max).to_string(),
446            );
447
448            self.report_anomaly(Anomaly {
449                anomaly_type: AnomalyType::UnusualActivation,
450                timestamp: Utc::now(),
451                location: layer_name.to_string(),
452                description: format!(
453                    "Weight explosion in {}: {} weights > {}",
454                    layer_name,
455                    extreme_weights.len(),
456                    weight_threshold
457                ),
458                severity: AnomalySeverity::High,
459                metadata,
460            });
461        }
462
463        Ok(())
464    }
465
466    /// Report an anomaly
467    fn report_anomaly(&mut self, anomaly: Anomaly) {
468        eprintln!(
469            "🚨 Anomaly detected: {} at {}",
470            anomaly.description, anomaly.location
471        );
472
473        // Update monitoring stats
474        self.monitoring_stats.total_anomalies += 1;
475        let anomaly_type_key = format!("{:?}", anomaly.anomaly_type);
476        *self.monitoring_stats.anomalies_per_type.entry(anomaly_type_key).or_insert(0) += 1;
477
478        self.detected_anomalies.push(anomaly);
479    }
480
481    /// Get all detected anomalies
482    pub fn get_anomalies(&self) -> &[Anomaly] {
483        &self.detected_anomalies
484    }
485
    /// Clear all detected anomalies (the backing buffer keeps its capacity).
    pub fn clear_anomalies(&mut self) {
        self.detected_anomalies.clear();
    }
490
491    /// Check for gradient conflicts between layers
492    pub fn check_gradient_conflict(
493        &mut self,
494        layer_gradients: &HashMap<String, Vec<f32>>,
495    ) -> Result<()> {
496        if !self.config.enable_gradient_conflict_detection {
497            return Ok(());
498        }
499
500        let layer_names: Vec<_> = layer_gradients.keys().cloned().collect();
501
502        for i in 0..layer_names.len() {
503            for j in i + 1..layer_names.len() {
504                let layer1 = &layer_names[i];
505                let layer2 = &layer_names[j];
506
507                if let (Some(grad1), Some(grad2)) =
508                    (layer_gradients.get(layer1), layer_gradients.get(layer2))
509                {
510                    let conflict_score = self.compute_gradient_conflict(grad1, grad2);
511
512                    if conflict_score > 0.8 {
513                        let mut metadata = HashMap::new();
514                        metadata.insert("layer1".to_string(), layer1.clone());
515                        metadata.insert("layer2".to_string(), layer2.clone());
516                        metadata.insert("conflict_score".to_string(), conflict_score.to_string());
517
518                        self.report_anomaly(Anomaly {
519                            anomaly_type: AnomalyType::GradientConflict,
520                            timestamp: Utc::now(),
521                            location: format!("{}↔{}", layer1, layer2),
522                            description: format!(
523                                "Gradient conflict detected between {} and {} (score: {:.2})",
524                                layer1, layer2, conflict_score
525                            ),
526                            severity: AnomalySeverity::High,
527                            metadata,
528                        });
529                    }
530                }
531            }
532        }
533
534        Ok(())
535    }
536
537    /// Check for weight divergence from baseline
538    pub fn check_weight_divergence(
539        &mut self,
540        layer_name: &str,
541        current_weights: &[f32],
542    ) -> Result<()> {
543        if !self.config.enable_weight_divergence_detection {
544            return Ok(());
545        }
546
547        // Initialize baseline if not exists
548        if !self.weight_baseline.contains_key(layer_name) {
549            self.weight_baseline.insert(layer_name.to_string(), current_weights.to_vec());
550            return Ok(());
551        }
552
553        let baseline = self.weight_baseline.get(layer_name).unwrap();
554        if baseline.len() != current_weights.len() {
555            return Ok(()); // Skip if dimensions don't match
556        }
557
558        let divergence = self.compute_weight_divergence(baseline, current_weights);
559
560        if divergence > self.config.weight_divergence_threshold {
561            let mut metadata = HashMap::new();
562            metadata.insert("layer_name".to_string(), layer_name.to_string());
563            metadata.insert("divergence_score".to_string(), divergence.to_string());
564            metadata.insert(
565                "threshold".to_string(),
566                self.config.weight_divergence_threshold.to_string(),
567            );
568
569            self.report_anomaly(Anomaly {
570                anomaly_type: AnomalyType::WeightDivergence,
571                timestamp: Utc::now(),
572                location: layer_name.to_string(),
573                description: format!(
574                    "Weight divergence in {}: {:.2} (threshold: {:.2})",
575                    layer_name, divergence, self.config.weight_divergence_threshold
576                ),
577                severity: if divergence > self.config.weight_divergence_threshold * 2.0 {
578                    AnomalySeverity::Critical
579                } else {
580                    AnomalySeverity::High
581                },
582                metadata,
583            });
584        }
585
586        Ok(())
587    }
588
589    /// Check for performance degradation
590    pub fn check_performance_degradation(
591        &mut self,
592        current_performance: f64,
593        location: &str,
594    ) -> Result<()> {
595        if !self.config.enable_performance_monitoring {
596            return Ok(());
597        }
598
599        // Add to history
600        if self.performance_history.len() >= self.config.monitoring_window_size {
601            self.performance_history.pop_front();
602        }
603        self.performance_history.push_back(current_performance);
604
605        // Check for degradation if we have enough history
606        if self.performance_history.len() >= 10 {
607            let recent_avg = self.performance_history.iter().rev().take(5).sum::<f64>() / 5.0;
608            let baseline_avg = self.performance_history.iter().take(5).sum::<f64>() / 5.0;
609
610            let degradation_ratio = (baseline_avg - recent_avg) / baseline_avg;
611
612            if degradation_ratio > self.config.performance_degradation_threshold {
613                let mut metadata = HashMap::new();
614                metadata.insert("baseline_performance".to_string(), baseline_avg.to_string());
615                metadata.insert("current_performance".to_string(), recent_avg.to_string());
616                metadata.insert(
617                    "degradation_ratio".to_string(),
618                    degradation_ratio.to_string(),
619                );
620
621                self.report_anomaly(Anomaly {
622                    anomaly_type: AnomalyType::PerformanceDegradation,
623                    timestamp: Utc::now(),
624                    location: location.to_string(),
625                    description: format!(
626                        "Performance degradation detected: {:.1}% drop from baseline",
627                        degradation_ratio * 100.0
628                    ),
629                    severity: if degradation_ratio > 0.8 {
630                        AnomalySeverity::Critical
631                    } else {
632                        AnomalySeverity::High
633                    },
634                    metadata,
635                });
636            }
637        }
638
639        Ok(())
640    }
641
642    /// Check for loss anomalies
643    pub fn check_loss_anomaly(&mut self, current_loss: f64, location: &str) -> Result<()> {
644        if !self.config.enable_loss_anomaly_detection {
645            return Ok(());
646        }
647
648        // Add to history
649        if self.loss_history.len() >= self.config.monitoring_window_size {
650            self.loss_history.pop_front();
651        }
652        self.loss_history.push_back(current_loss);
653
654        // Check for loss spikes
655        if self.loss_history.len() >= 3 {
656            let prev_loss = self.loss_history[self.loss_history.len() - 2];
657            let loss_ratio = current_loss / prev_loss;
658
659            if loss_ratio > self.config.loss_spike_threshold {
660                let mut metadata = HashMap::new();
661                metadata.insert("previous_loss".to_string(), prev_loss.to_string());
662                metadata.insert("current_loss".to_string(), current_loss.to_string());
663                metadata.insert("spike_ratio".to_string(), loss_ratio.to_string());
664
665                self.report_anomaly(Anomaly {
666                    anomaly_type: AnomalyType::LossAnomalous,
667                    timestamp: Utc::now(),
668                    location: location.to_string(),
669                    description: format!(
670                        "Loss spike detected: {:.2}x increase (from {:.6} to {:.6})",
671                        loss_ratio, prev_loss, current_loss
672                    ),
673                    severity: if loss_ratio > 100.0 {
674                        AnomalySeverity::Critical
675                    } else {
676                        AnomalySeverity::High
677                    },
678                    metadata,
679                });
680            }
681        }
682
683        Ok(())
684    }
685
686    /// Attempt automatic recovery from an anomaly
687    pub async fn attempt_recovery(&mut self, anomaly: &Anomaly) -> Result<RecoveryAction> {
688        if !self.config.enable_auto_recovery {
689            return Ok(RecoveryAction::None);
690        }
691
692        let action = self.determine_recovery_action(anomaly);
693        let anomaly_id = format!(
694            "{:?}_{}",
695            anomaly.anomaly_type,
696            anomaly.timestamp.timestamp()
697        );
698
699        let success = self.execute_recovery_action(&action).await?;
700
701        self.recovery_attempts.push(RecoveryAttempt {
702            anomaly_id: anomaly_id.clone(),
703            action: action.clone(),
704            timestamp: Utc::now(),
705            success,
706            error_message: if success { None } else { Some("Recovery failed".to_string()) },
707        });
708
709        self.monitoring_stats.recovery_attempts += 1;
710        if success {
711            self.monitoring_stats.successful_recoveries += 1;
712        }
713
714        Ok(action)
715    }
716
    /// Get monitoring statistics (aggregate counters plus the snapshot window).
    pub fn get_monitoring_stats(&self) -> &MonitoringStats {
        &self.monitoring_stats
    }
721
722    /// Get recovery attempts history
723    pub fn get_recovery_attempts(&self) -> &[RecoveryAttempt] {
724        &self.recovery_attempts
725    }
726
727    /// Update monitoring window with current state
728    pub fn update_monitoring_window(&mut self) -> Result<()> {
729        let mut severity_distribution = HashMap::new();
730        for anomaly in &self.detected_anomalies {
731            let key = format!("{:?}", anomaly.severity);
732            *severity_distribution.entry(key).or_insert(0) += 1;
733        }
734
735        let mut performance_metrics = HashMap::new();
736        if let Some(latest_perf) = self.performance_history.back() {
737            performance_metrics.insert("latest_performance".to_string(), *latest_perf);
738        }
739        if let Some(latest_loss) = self.loss_history.back() {
740            performance_metrics.insert("latest_loss".to_string(), *latest_loss);
741        }
742
743        let snapshot = AnomalySnapshot {
744            timestamp: Utc::now(),
745            anomaly_count: self.detected_anomalies.len(),
746            severity_distribution,
747            performance_metrics,
748        };
749
750        self.monitoring_stats.monitoring_window.push(snapshot);
751
752        // Keep only recent snapshots
753        if self.monitoring_stats.monitoring_window.len() > self.config.monitoring_window_size {
754            self.monitoring_stats.monitoring_window.remove(0);
755        }
756
757        Ok(())
758    }
759
760    // Private helper methods for new functionality
761
762    fn compute_gradient_conflict(&self, grad1: &[f32], grad2: &[f32]) -> f64 {
763        if grad1.len() != grad2.len() {
764            return 0.0;
765        }
766
767        let dot_product: f64 =
768            grad1.iter().zip(grad2.iter()).map(|(a, b)| (*a as f64) * (*b as f64)).sum();
769
770        let norm1: f64 = grad1.iter().map(|x| (*x as f64).powi(2)).sum::<f64>().sqrt();
771        let norm2: f64 = grad2.iter().map(|x| (*x as f64).powi(2)).sum::<f64>().sqrt();
772
773        if norm1 == 0.0 || norm2 == 0.0 {
774            return 0.0;
775        }
776
777        // Cosine similarity - conflicts are indicated by negative correlation
778        let cosine_sim = dot_product / (norm1 * norm2);
779
780        // Convert to conflict score (0 = no conflict, 1 = maximum conflict)
781        (1.0 - cosine_sim) / 2.0
782    }
783
784    fn compute_weight_divergence(&self, baseline: &[f32], current: &[f32]) -> f64 {
785        let mse: f64 = baseline
786            .iter()
787            .zip(current.iter())
788            .map(|(a, b)| (*a as f64 - *b as f64).powi(2))
789            .sum::<f64>()
790            / baseline.len() as f64;
791
792        mse.sqrt()
793    }
794
795    fn determine_recovery_action(&self, anomaly: &Anomaly) -> RecoveryAction {
796        match anomaly.anomaly_type {
797            AnomalyType::GradientExplosion => RecoveryAction::ClipGradients { max_norm: 1.0 },
798            AnomalyType::GradientVanishing => RecoveryAction::ReduceLearningRate { factor: 0.5 },
799            AnomalyType::NaN | AnomalyType::Infinity => RecoveryAction::ResetGradients,
800            AnomalyType::WeightDivergence => RecoveryAction::ApplyWeightDecay { rate: 0.01 },
801            AnomalyType::LossAnomalous => RecoveryAction::SkipBatch,
802            AnomalyType::MemoryLeak => RecoveryAction::RestartOptimizer,
803            AnomalyType::PerformanceDegradation => {
804                RecoveryAction::ReduceLearningRate { factor: 0.8 }
805            },
806            _ => RecoveryAction::None,
807        }
808    }
809
810    async fn execute_recovery_action(&self, action: &RecoveryAction) -> Result<bool> {
811        // In a real implementation, this would interface with the training system
812        // For now, we'll simulate the actions
813        match action {
814            RecoveryAction::None => Ok(true),
815            RecoveryAction::ResetGradients => {
816                tracing::info!("Executing recovery: Reset gradients");
817                Ok(true)
818            },
819            RecoveryAction::ReduceLearningRate { factor } => {
820                tracing::info!(
821                    "Executing recovery: Reduce learning rate by factor {}",
822                    factor
823                );
824                Ok(true)
825            },
826            RecoveryAction::ClipGradients { max_norm } => {
827                tracing::info!(
828                    "Executing recovery: Clip gradients to max norm {}",
829                    max_norm
830                );
831                Ok(true)
832            },
833            RecoveryAction::RestartOptimizer => {
834                tracing::info!("Executing recovery: Restart optimizer");
835                Ok(true)
836            },
837            RecoveryAction::SkipBatch => {
838                tracing::info!("Executing recovery: Skip current batch");
839                Ok(true)
840            },
841            RecoveryAction::ResetWeights { layer_name } => {
842                tracing::info!("Executing recovery: Reset weights for layer {}", layer_name);
843                Ok(true)
844            },
845            RecoveryAction::ApplyWeightDecay { rate } => {
846                tracing::info!("Executing recovery: Apply weight decay with rate {}", rate);
847                Ok(true)
848            },
849            RecoveryAction::EmergencyStop => {
850                tracing::warn!("Executing recovery: Emergency stop");
851                Ok(false) // This would actually stop training
852            },
853        }
854    }
855
856    /// Quick anomaly check for simplified interface
857    pub async fn quick_check(&self) -> Result<crate::QuickAnomalySummary> {
858        let anomaly_count = self.detected_anomalies.len();
859
860        let severity_level = match anomaly_count {
861            0 => "None",
862            1..=3 => "Low",
863            4..=10 => "Medium",
864            11..=20 => "High",
865            _ => "Critical",
866        }
867        .to_string();
868
869        let mut recommendations = Vec::new();
870        if anomaly_count > 0 {
871            recommendations.push("Review recent training metrics for instabilities".to_string());
872        }
873        if anomaly_count > 5 {
874            recommendations.push(
875                "Consider adjusting learning rate or implementing gradient clipping".to_string(),
876            );
877        }
878        if anomaly_count > 15 {
879            recommendations
880                .push("Training may need to be restarted with better configuration".to_string());
881        }
882        if anomaly_count == 0 {
883            recommendations.push("No anomalies detected, training appears stable".to_string());
884        }
885
886        Ok(crate::QuickAnomalySummary {
887            anomaly_count,
888            severity_level,
889            recommendations,
890        })
891    }
892
893    /// Generate anomaly detection report
894    pub async fn generate_report(&self) -> Result<AnomalyDetectorReport> {
895        let mut anomaly_counts = HashMap::new();
896        for anomaly in &self.detected_anomalies {
897            let count = anomaly_counts.entry(format!("{:?}", anomaly.anomaly_type)).or_insert(0);
898            *count += 1;
899        }
900
901        Ok(AnomalyDetectorReport {
902            session_duration: Utc::now().signed_duration_since(self.start_time),
903            total_anomalies: self.detected_anomalies.len(),
904            anomaly_counts,
905            most_recent_anomalies: self.detected_anomalies.iter().rev().take(10).cloned().collect(),
906            config: self.config.clone(),
907        })
908    }
909}
910
/// Report generated by the anomaly detector
///
/// Produced by [`AnomalyDetector::generate_report`]; serializable so it can
/// be persisted or shipped to external tooling.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AnomalyDetectorReport {
    /// Wall-clock time between detector creation and report generation.
    pub session_duration: chrono::Duration,
    /// Total number of anomalies recorded during the session.
    pub total_anomalies: usize,
    /// Anomaly tallies keyed by the `Debug` name of the anomaly type.
    pub anomaly_counts: HashMap<String, usize>,
    /// Up to the 10 most recently detected anomalies, newest first.
    pub most_recent_anomalies: Vec<Anomaly>,
    /// Configuration the detector was running with.
    pub config: AnomalyDetectorConfig,
}
920
// Unit tests: each test builds a fresh detector from the default DebugConfig,
// feeds it data crafted to trip exactly one detector, and asserts on the
// recorded anomaly type/count.
#[cfg(test)]
mod tests {
    use super::*;

    // A fresh detector starts with an empty anomaly log.
    #[test]
    fn test_anomaly_detector_creation() {
        let config = DebugConfig::default();
        let detector = AnomalyDetector::new(&config);
        assert_eq!(detector.get_anomalies().len(), 0);
    }

    // A single NaN among finite values is reported as one NaN anomaly.
    #[test]
    fn test_nan_detection() {
        let config = DebugConfig::default();
        let mut detector = AnomalyDetector::new(&config);

        let values = vec![1.0, 2.0, f32::NAN, 4.0];
        detector.check_nan(&values, "test_location").unwrap();

        assert_eq!(detector.get_anomalies().len(), 1);
        assert!(matches!(
            detector.get_anomalies()[0].anomaly_type,
            AnomalyType::NaN
        ));
    }

    // A single infinity among finite values is reported as one Infinity anomaly.
    #[test]
    fn test_inf_detection() {
        let config = DebugConfig::default();
        let mut detector = AnomalyDetector::new(&config);

        let values = vec![1.0, 2.0, f32::INFINITY, 4.0];
        detector.check_inf(&values, "test_location").unwrap();

        assert_eq!(detector.get_anomalies().len(), 1);
        assert!(matches!(
            detector.get_anomalies()[0].anomaly_type,
            AnomalyType::Infinity
        ));
    }

    // A gradient norm of 1e7 exceeds the default explosion threshold.
    #[test]
    fn test_gradient_explosion_detection() {
        let config = DebugConfig::default();
        let mut detector = AnomalyDetector::new(&config);

        detector.check_gradient_explosion(1e7, "test_layer").unwrap();

        assert_eq!(detector.get_anomalies().len(), 1);
        assert!(matches!(
            detector.get_anomalies()[0].anomaly_type,
            AnomalyType::GradientExplosion
        ));
    }

    // A gradient norm of 1e-10 falls below the default vanishing threshold.
    #[test]
    fn test_gradient_vanishing_detection() {
        let config = DebugConfig::default();
        let mut detector = AnomalyDetector::new(&config);

        detector.check_gradient_vanishing(1e-10, "test_layer").unwrap();

        assert_eq!(detector.get_anomalies().len(), 1);
        assert!(matches!(
            detector.get_anomalies()[0].anomaly_type,
            AnomalyType::GradientVanishing
        ));
    }

    // Both instability modes fire: half-near-zero values, then extreme values.
    #[test]
    fn test_numerical_instability_detection() {
        let config = DebugConfig::default();
        let mut detector = AnomalyDetector::new(&config);

        // Test near-zero values
        let near_zero_values: Vec<f32> =
            (0..100).map(|i| if i < 50 { 1e-12 } else { 1.0 }).collect();
        detector
            .check_numerical_instability(&near_zero_values, "test_location")
            .unwrap();
        assert_eq!(detector.get_anomalies().len(), 1);

        detector.clear_anomalies();

        // Test extreme values
        let extreme_values = vec![1.0, 2.0, 1e7, 4.0];
        detector.check_numerical_instability(&extreme_values, "test_location").unwrap();
        assert_eq!(detector.get_anomalies().len(), 1);
    }

    // Saturation is activation-aware: all-zero for ReLU, near-one for sigmoid.
    #[test]
    fn test_activation_saturation_detection() {
        let config = DebugConfig::default();
        let mut detector = AnomalyDetector::new(&config);

        // Test ReLU saturation (all zeros)
        let relu_saturated: Vec<f32> = vec![0.0; 100];
        detector
            .check_activation_saturation(&relu_saturated, "relu", "test_layer")
            .unwrap();
        assert_eq!(detector.get_anomalies().len(), 1);

        detector.clear_anomalies();

        // Test sigmoid saturation (all ones)
        let sigmoid_saturated: Vec<f32> = vec![0.999; 100];
        detector
            .check_activation_saturation(&sigmoid_saturated, "sigmoid", "test_layer")
            .unwrap();
        assert_eq!(detector.get_anomalies().len(), 1);
    }

    // Memory leaks fire on relative growth vs. a previous sample, and on
    // absolute usage with no previous sample.
    #[test]
    fn test_memory_leak_detection() {
        let config = DebugConfig::default();
        let mut detector = AnomalyDetector::new(&config);

        // Test memory growth detection (3x growth should trigger)
        detector.check_memory_leak(3072, Some(1024), "test_location").unwrap();
        assert_eq!(detector.get_anomalies().len(), 1);
        assert!(matches!(
            detector.get_anomalies()[0].anomaly_type,
            AnomalyType::MemoryLeak
        ));

        detector.clear_anomalies();

        // Test absolute high memory
        detector.check_memory_leak(10240, None, "test_location").unwrap();
        assert_eq!(detector.get_anomalies().len(), 1);
    }

    // NOTE(review): weight explosion is reported under UnusualActivation, not
    // a dedicated variant — confirm this mapping is intentional.
    #[test]
    fn test_weight_explosion_detection() {
        let config = DebugConfig::default();
        let mut detector = AnomalyDetector::new(&config);

        let weights = vec![1.0, 2.0, 15.0, 4.0, -20.0]; // Two weights exceed threshold of 10.0
        detector.check_weight_explosion(&weights, "test_layer").unwrap();

        assert_eq!(detector.get_anomalies().len(), 1);
        assert!(matches!(
            detector.get_anomalies()[0].anomaly_type,
            AnomalyType::UnusualActivation
        ));
    }

    // Exactly opposing gradient vectors (cosine similarity -1) conflict.
    #[test]
    fn test_gradient_conflict_detection() {
        let config = DebugConfig::default();
        let mut detector = AnomalyDetector::new(&config);

        let mut layer_gradients = HashMap::new();
        layer_gradients.insert("layer1".to_string(), vec![1.0, 0.0, 0.0]);
        layer_gradients.insert("layer2".to_string(), vec![-1.0, 0.0, 0.0]); // Opposing gradients

        detector.check_gradient_conflict(&layer_gradients).unwrap();

        assert_eq!(detector.get_anomalies().len(), 1);
        assert!(matches!(
            detector.get_anomalies()[0].anomaly_type,
            AnomalyType::GradientConflict
        ));
    }

    // The first observation per layer is stored as the baseline; divergence is
    // only measured from the second call onward.
    #[test]
    fn test_weight_divergence_detection() {
        let config = DebugConfig::default();
        let mut detector = AnomalyDetector::new(&config);

        let baseline_weights = vec![1.0, 2.0, 3.0, 4.0];
        let diverged_weights = vec![10.0, 20.0, 30.0, 40.0]; // Significant divergence

        // First call establishes baseline
        detector.check_weight_divergence("test_layer", &baseline_weights).unwrap();
        assert_eq!(detector.get_anomalies().len(), 0);

        // Second call detects divergence
        detector.check_weight_divergence("test_layer", &diverged_weights).unwrap();
        assert_eq!(detector.get_anomalies().len(), 1);
        assert!(matches!(
            detector.get_anomalies()[0].anomaly_type,
            AnomalyType::WeightDivergence
        ));
    }

    // Degradation is judged against a rolling baseline, so good samples are
    // fed first; a 5x throughput drop must then trigger at least once.
    #[test]
    fn test_performance_degradation_detection() {
        let config = DebugConfig::default();
        let mut detector = AnomalyDetector::new(&config);

        // Add baseline performance metrics
        for _ in 0..10 {
            detector.check_performance_degradation(100.0, "training").unwrap(); // Good performance
        }
        assert_eq!(detector.get_anomalies().len(), 0);

        // Add degraded performance metrics - just enough to trigger once
        for _ in 0..5 {
            detector.check_performance_degradation(20.0, "training").unwrap(); // Poor performance
        }

        // Should have at least one degradation anomaly
        assert!(!detector.get_anomalies().is_empty());
        assert!(detector
            .get_anomalies()
            .iter()
            .any(|a| matches!(a.anomaly_type, AnomalyType::PerformanceDegradation)));
    }

    // A large spike relative to recent loss history is flagged; gentle
    // decreases are not.
    #[test]
    fn test_loss_anomaly_detection() {
        let config = DebugConfig::default();
        let mut detector = AnomalyDetector::new(&config);

        // Add normal loss values
        detector.check_loss_anomaly(1.0, "training").unwrap();
        detector.check_loss_anomaly(0.9, "training").unwrap();
        assert_eq!(detector.get_anomalies().len(), 0);

        // Add loss spike
        detector.check_loss_anomaly(100.0, "training").unwrap(); // 100x spike
        assert_eq!(detector.get_anomalies().len(), 1);
        assert!(matches!(
            detector.get_anomalies()[0].anomaly_type,
            AnomalyType::LossAnomalous
        ));
    }

    // With auto-recovery enabled, a gradient explosion maps to gradient
    // clipping and the attempt is logged.
    #[tokio::test]
    async fn test_auto_recovery() {
        let config = DebugConfig::default();
        let mut detector = AnomalyDetector::new(&config);
        detector.config.enable_auto_recovery = true;

        let anomaly = Anomaly {
            anomaly_type: AnomalyType::GradientExplosion,
            timestamp: Utc::now(),
            location: "test_layer".to_string(),
            description: "Test gradient explosion".to_string(),
            severity: AnomalySeverity::High,
            metadata: HashMap::new(),
        };

        let action = detector.attempt_recovery(&anomaly).await.unwrap();
        assert!(matches!(action, RecoveryAction::ClipGradients { .. }));
        assert_eq!(detector.get_recovery_attempts().len(), 1);
    }

    // Monitoring stats aggregate totals and per-type counts keyed by the
    // Debug name of the anomaly variant.
    #[test]
    fn test_monitoring_stats() {
        let config = DebugConfig::default();
        let mut detector = AnomalyDetector::new(&config);

        // Create some anomalies to generate stats
        detector.check_nan(&[f32::NAN], "test").unwrap();
        detector.check_inf(&[f32::INFINITY], "test").unwrap();

        let stats = detector.get_monitoring_stats();
        assert_eq!(stats.total_anomalies, 2);
        assert!(stats.anomalies_per_type.contains_key("NaN"));
        assert!(stats.anomalies_per_type.contains_key("Infinity"));
    }

    // update_monitoring_window snapshots the current anomaly count into the
    // rolling monitoring window.
    #[test]
    fn test_monitoring_window_update() {
        let config = DebugConfig::default();
        let mut detector = AnomalyDetector::new(&config);

        detector.check_nan(&[f32::NAN], "test").unwrap();
        detector.update_monitoring_window().unwrap();

        let stats = detector.get_monitoring_stats();
        assert_eq!(stats.monitoring_window.len(), 1);
        assert_eq!(stats.monitoring_window[0].anomaly_count, 1);
    }
}