Skip to main content

trustformers_debug/
anomaly_detector.rs

1//! Anomaly Detection for Model Debugging
2//!
3//! Detects unusual patterns in model execution, tensor values, and gradients
4//! to help identify potential issues during training and inference.
5
6use anyhow::Result;
7use chrono::{DateTime, Utc};
8use serde::{Deserialize, Serialize};
9use std::collections::{HashMap, VecDeque};
10
11use crate::DebugConfig;
12
/// Anomaly detector for model execution
///
/// Tracks detected anomalies, auto-recovery attempts, and rolling histories
/// of performance and loss samples used by the sliding-window checks.
#[derive(Debug)]
pub struct AnomalyDetector {
    /// Detection thresholds and per-check enable flags.
    config: AnomalyDetectorConfig,
    /// Every anomaly reported so far, in detection order.
    detected_anomalies: Vec<Anomaly>,
    /// When monitoring began (reset by `start`).
    start_time: DateTime<Utc>,
    /// History of auto-recovery attempts.
    recovery_attempts: Vec<RecoveryAttempt>,
    /// Aggregate counters plus the monitoring-window snapshots.
    monitoring_stats: MonitoringStats,
    /// Rolling performance samples, bounded by `monitoring_window_size`.
    performance_history: VecDeque<f64>,
    /// Per-layer gradient history; unused in the visible code, kept for
    /// future checks (hence the allow).
    #[allow(dead_code)]
    gradient_history: HashMap<String, VecDeque<f64>>,
    /// Rolling loss samples used for spike detection.
    loss_history: VecDeque<f64>,
    /// First-seen weights per layer — the baseline for divergence checks.
    weight_baseline: HashMap<String, Vec<f32>>,
}
27
/// Configuration for anomaly detection
///
/// Each `enable_*` flag gates one family of checks; the remaining fields
/// are numeric thresholds and window sizes used by those checks.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AnomalyDetectorConfig {
    pub enable_nan_detection: bool,
    pub enable_inf_detection: bool,
    pub enable_gradient_explosion: bool,
    pub enable_gradient_vanishing: bool,
    /// Gradient norms above this are reported as explosions.
    pub gradient_threshold: f64,
    pub enable_memory_leak_detection: bool,
    pub enable_numerical_instability_detection: bool,
    pub enable_gradient_conflict_detection: bool,
    pub enable_performance_monitoring: bool,
    pub enable_weight_divergence_detection: bool,
    pub enable_activation_dead_detection: bool,
    pub enable_loss_anomaly_detection: bool,
    /// When true, `attempt_recovery` may execute recovery actions.
    pub enable_auto_recovery: bool,
    pub numerical_instability_threshold: f64,
    /// Fractional performance drop (0.0–1.0) that triggers a report.
    pub performance_degradation_threshold: f64,
    pub weight_divergence_threshold: f64,
    /// Loss ratio (current / previous) considered a spike.
    pub loss_spike_threshold: f64,
    /// Number of samples kept in the rolling histories and snapshot window.
    pub monitoring_window_size: usize,
    /// Maximum number of auto-recovery attempts.
    pub recovery_attempts_limit: usize,
}
51
52impl Default for AnomalyDetectorConfig {
53    fn default() -> Self {
54        Self {
55            enable_nan_detection: true,
56            enable_inf_detection: true,
57            enable_gradient_explosion: true,
58            enable_gradient_vanishing: true,
59            gradient_threshold: 1e6,
60            enable_memory_leak_detection: true,
61            enable_numerical_instability_detection: true,
62            enable_gradient_conflict_detection: true,
63            enable_performance_monitoring: true,
64            enable_weight_divergence_detection: true,
65            enable_activation_dead_detection: true,
66            enable_loss_anomaly_detection: true,
67            enable_auto_recovery: false, // Conservative default
68            numerical_instability_threshold: 1e-12,
69            performance_degradation_threshold: 0.5, // 50% degradation
70            weight_divergence_threshold: 5.0,
71            loss_spike_threshold: 10.0, // 10x loss increase
72            monitoring_window_size: 100,
73            recovery_attempts_limit: 3,
74        }
75    }
76}
77
/// Types of anomalies that can be detected
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AnomalyType {
    /// Not-a-number values found in a tensor.
    NaN,
    /// Infinite values found in a tensor.
    Infinity,
    /// Gradient norm above the configured threshold.
    GradientExplosion,
    /// Gradient norm near zero.
    GradientVanishing,
    /// Memory usage growing beyond expectations.
    MemoryLeak,
    /// Activation statistics outside normal ranges (e.g. saturation).
    UnusualActivation,
    /// Values prone to numerical underflow/overflow.
    NumericalInstability,
    /// Gradients of two layers pointing in strongly opposing directions.
    GradientConflict,
    /// Recent performance notably below the baseline window.
    PerformanceDegradation,
    /// Weights drifting far from their recorded baseline.
    WeightDivergence,
    /// Dead activations.
    // NOTE(review): not constructed by any visible check — confirm intended use.
    ActivationDead,
    /// Loss spiking relative to recent history.
    LossAnomalous,
}
94
/// A detected anomaly
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Anomaly {
    /// Which kind of anomaly was detected.
    pub anomaly_type: AnomalyType,
    /// When the anomaly was detected.
    pub timestamp: DateTime<Utc>,
    /// Where it occurred (e.g. a layer/tensor name, or a layer pair).
    pub location: String,
    /// Human-readable summary of the finding.
    pub description: String,
    /// How serious the anomaly is considered.
    pub severity: AnomalySeverity,
    /// Extra key/value details (counts, thresholds, ratios).
    pub metadata: HashMap<String, String>,
}
105
/// Severity level of an anomaly, from least to most serious
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AnomalySeverity {
    Low,
    Medium,
    High,
    Critical,
}
114
/// Auto-recovery action that can be taken in response to an anomaly
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum RecoveryAction {
    /// Take no action.
    None,
    /// Zero out accumulated gradients.
    ResetGradients,
    /// Multiply the learning rate by `factor`.
    ReduceLearningRate { factor: f64 },
    /// Clip gradients to at most `max_norm`.
    ClipGradients { max_norm: f64 },
    /// Re-initialize optimizer state.
    RestartOptimizer,
    /// Drop the current batch and continue.
    SkipBatch,
    /// Re-initialize the weights of the named layer.
    ResetWeights { layer_name: String },
    /// Apply weight decay at the given rate.
    ApplyWeightDecay { rate: f64 },
    /// Halt training entirely.
    EmergencyStop,
}
128
/// Record of a single auto-recovery attempt
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RecoveryAttempt {
    /// Identifier derived from the anomaly type and timestamp.
    pub anomaly_id: String,
    /// The action that was executed.
    pub action: RecoveryAction,
    /// When the attempt was made.
    pub timestamp: DateTime<Utc>,
    /// Whether the action reported success.
    pub success: bool,
    /// Failure details when `success` is false, `None` otherwise.
    pub error_message: Option<String>,
}
138
/// Real-time monitoring statistics
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MonitoringStats {
    /// Total anomalies reported since creation.
    pub total_anomalies: usize,
    /// Anomaly counts keyed by the `Debug` form of `AnomalyType`.
    pub anomalies_per_type: HashMap<String, usize>,
    /// Number of recovery attempts made.
    pub recovery_attempts: usize,
    /// Number of attempts whose action reported success.
    pub successful_recoveries: usize,
    /// Average detection latency in milliseconds.
    // NOTE(review): never updated in the visible code — confirm intended use.
    pub average_detection_time_ms: f64,
    /// Rolling window of periodic state snapshots.
    pub monitoring_window: Vec<AnomalySnapshot>,
}
149
/// Snapshot of anomaly state for the monitoring window
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AnomalySnapshot {
    /// When the snapshot was taken.
    pub timestamp: DateTime<Utc>,
    /// Total anomalies recorded at snapshot time.
    pub anomaly_count: usize,
    /// Anomaly counts keyed by the `Debug` form of `AnomalySeverity`.
    pub severity_distribution: HashMap<String, usize>,
    /// Latest performance/loss samples, when available.
    pub performance_metrics: HashMap<String, f64>,
}
158
159impl AnomalyDetector {
160    /// Create a new anomaly detector
161    pub fn new(_config: &DebugConfig) -> Self {
162        let monitoring_window_size = AnomalyDetectorConfig::default().monitoring_window_size;
163        Self {
164            config: AnomalyDetectorConfig::default(),
165            detected_anomalies: Vec::new(),
166            start_time: Utc::now(),
167            recovery_attempts: Vec::new(),
168            monitoring_stats: MonitoringStats {
169                total_anomalies: 0,
170                anomalies_per_type: HashMap::new(),
171                recovery_attempts: 0,
172                successful_recoveries: 0,
173                average_detection_time_ms: 0.0,
174                monitoring_window: Vec::new(),
175            },
176            performance_history: VecDeque::with_capacity(monitoring_window_size),
177            gradient_history: HashMap::new(),
178            loss_history: VecDeque::with_capacity(monitoring_window_size),
179            weight_baseline: HashMap::new(),
180        }
181    }
182
183    /// Start the anomaly detector
184    pub async fn start(&mut self) -> Result<()> {
185        self.start_time = Utc::now();
186        self.detected_anomalies.clear();
187        Ok(())
188    }
189
190    /// Check for NaN values in tensors
191    pub fn check_nan(&mut self, values: &[f32], location: &str) -> Result<()> {
192        if !self.config.enable_nan_detection {
193            return Ok(());
194        }
195
196        if values.iter().any(|v| v.is_nan()) {
197            self.report_anomaly(Anomaly {
198                anomaly_type: AnomalyType::NaN,
199                timestamp: Utc::now(),
200                location: location.to_string(),
201                description: "NaN values detected in tensor".to_string(),
202                severity: AnomalySeverity::High,
203                metadata: HashMap::new(),
204            });
205        }
206
207        Ok(())
208    }
209
210    /// Check for infinite values in tensors
211    pub fn check_inf(&mut self, values: &[f32], location: &str) -> Result<()> {
212        if !self.config.enable_inf_detection {
213            return Ok(());
214        }
215
216        if values.iter().any(|v| v.is_infinite()) {
217            self.report_anomaly(Anomaly {
218                anomaly_type: AnomalyType::Infinity,
219                timestamp: Utc::now(),
220                location: location.to_string(),
221                description: "Infinite values detected in tensor".to_string(),
222                severity: AnomalySeverity::High,
223                metadata: HashMap::new(),
224            });
225        }
226
227        Ok(())
228    }
229
230    /// Check for gradient explosion
231    pub fn check_gradient_explosion(&mut self, gradient_norm: f64, location: &str) -> Result<()> {
232        if !self.config.enable_gradient_explosion {
233            return Ok(());
234        }
235
236        if gradient_norm > self.config.gradient_threshold {
237            self.report_anomaly(Anomaly {
238                anomaly_type: AnomalyType::GradientExplosion,
239                timestamp: Utc::now(),
240                location: location.to_string(),
241                description: format!("Gradient explosion detected: norm = {}", gradient_norm),
242                severity: AnomalySeverity::Critical,
243                metadata: {
244                    let mut meta = HashMap::new();
245                    meta.insert("gradient_norm".to_string(), gradient_norm.to_string());
246                    meta
247                },
248            });
249        }
250
251        Ok(())
252    }
253
254    /// Check for vanishing gradients
255    pub fn check_gradient_vanishing(&mut self, gradient_norm: f64, location: &str) -> Result<()> {
256        if !self.config.enable_gradient_vanishing {
257            return Ok(());
258        }
259
260        let vanishing_threshold = 1e-8;
261        if gradient_norm < vanishing_threshold {
262            self.report_anomaly(Anomaly {
263                anomaly_type: AnomalyType::GradientVanishing,
264                timestamp: Utc::now(),
265                location: location.to_string(),
266                description: format!("Vanishing gradient detected: norm = {}", gradient_norm),
267                severity: AnomalySeverity::High,
268                metadata: {
269                    let mut meta = HashMap::new();
270                    meta.insert("gradient_norm".to_string(), gradient_norm.to_string());
271                    meta.insert("threshold".to_string(), vanishing_threshold.to_string());
272                    meta
273                },
274            });
275        }
276
277        Ok(())
278    }
279
280    /// Check for numerical instability
281    pub fn check_numerical_instability(&mut self, values: &[f32], location: &str) -> Result<()> {
282        let mut metadata = HashMap::new();
283
284        // Check for values close to zero that might cause division problems
285        let near_zero_count = values.iter().filter(|&&v| v.abs() < 1e-10 && v != 0.0).count();
286        if near_zero_count > values.len() / 10 {
287            metadata.insert("near_zero_count".to_string(), near_zero_count.to_string());
288            metadata.insert("total_values".to_string(), values.len().to_string());
289
290            self.report_anomaly(Anomaly {
291                anomaly_type: AnomalyType::UnusualActivation,
292                timestamp: Utc::now(),
293                location: location.to_string(),
294                description: format!(
295                    "Numerical instability: {} values near zero",
296                    near_zero_count
297                ),
298                severity: AnomalySeverity::Medium,
299                metadata: metadata.clone(),
300            });
301        }
302
303        // Check for extreme values that might cause overflow
304        let extreme_count = values.iter().filter(|&&v| v.abs() > 1e6).count();
305        if extreme_count > 0 {
306            metadata.insert("extreme_count".to_string(), extreme_count.to_string());
307
308            self.report_anomaly(Anomaly {
309                anomaly_type: AnomalyType::UnusualActivation,
310                timestamp: Utc::now(),
311                location: location.to_string(),
312                description: format!("Numerical instability: {} extreme values", extreme_count),
313                severity: AnomalySeverity::High,
314                metadata,
315            });
316        }
317
318        Ok(())
319    }
320
321    /// Check for activation saturation
322    pub fn check_activation_saturation(
323        &mut self,
324        activations: &[f32],
325        activation_type: &str,
326        location: &str,
327    ) -> Result<()> {
328        let saturation_threshold = match activation_type.to_lowercase().as_str() {
329            "sigmoid" | "tanh" => 0.01, // Close to 0 or 1 for sigmoid, -1 or 1 for tanh
330            "relu" => 0.0,              // Zero values for ReLU
331            _ => 0.01,
332        };
333
334        let saturated_count = match activation_type.to_lowercase().as_str() {
335            "sigmoid" => activations
336                .iter()
337                .filter(|&&v| v < saturation_threshold || v > 1.0 - saturation_threshold)
338                .count(),
339            "tanh" => activations.iter().filter(|&&v| v.abs() > 1.0 - saturation_threshold).count(),
340            "relu" => activations.iter().filter(|&&v| v == 0.0).count(),
341            _ => activations.iter().filter(|&&v| v.abs() < saturation_threshold).count(),
342        };
343
344        let saturation_ratio = saturated_count as f32 / activations.len() as f32;
345
346        if saturation_ratio > 0.9 {
347            let mut metadata = HashMap::new();
348            metadata.insert("activation_type".to_string(), activation_type.to_string());
349            metadata.insert("saturated_count".to_string(), saturated_count.to_string());
350            metadata.insert("total_count".to_string(), activations.len().to_string());
351            metadata.insert("saturation_ratio".to_string(), saturation_ratio.to_string());
352
353            self.report_anomaly(Anomaly {
354                anomaly_type: AnomalyType::UnusualActivation,
355                timestamp: Utc::now(),
356                location: location.to_string(),
357                description: format!(
358                    "Activation saturation detected: {:.1}% of {} activations saturated",
359                    saturation_ratio * 100.0,
360                    activation_type
361                ),
362                severity: AnomalySeverity::High,
363                metadata,
364            });
365        }
366
367        Ok(())
368    }
369
370    /// Check for memory leaks by tracking memory usage patterns
371    pub fn check_memory_leak(
372        &mut self,
373        current_memory_mb: usize,
374        expected_memory_mb: Option<usize>,
375        location: &str,
376    ) -> Result<()> {
377        if !self.config.enable_memory_leak_detection {
378            return Ok(());
379        }
380
381        let mut should_report = false;
382        let mut description = String::new();
383        let mut metadata = HashMap::new();
384
385        metadata.insert(
386            "current_memory_mb".to_string(),
387            current_memory_mb.to_string(),
388        );
389
390        if let Some(expected) = expected_memory_mb {
391            metadata.insert("expected_memory_mb".to_string(), expected.to_string());
392
393            let growth_ratio = current_memory_mb as f64 / expected as f64;
394            if growth_ratio > 2.0 {
395                should_report = true;
396                description = format!(
397                    "Memory usage {}MB is {:.1}x expected {}MB",
398                    current_memory_mb, growth_ratio, expected
399                );
400                metadata.insert("growth_ratio".to_string(), growth_ratio.to_string());
401            }
402        } else {
403            // Check for absolute high memory usage
404            if current_memory_mb > 8192 {
405                // 8GB threshold
406                should_report = true;
407                description = format!("High memory usage detected: {}MB", current_memory_mb);
408            }
409        }
410
411        if should_report {
412            self.report_anomaly(Anomaly {
413                anomaly_type: AnomalyType::MemoryLeak,
414                timestamp: Utc::now(),
415                location: location.to_string(),
416                description,
417                severity: if current_memory_mb > 16384 {
418                    AnomalySeverity::Critical
419                } else {
420                    AnomalySeverity::High
421                },
422                metadata,
423            });
424        }
425
426        Ok(())
427    }
428
429    /// Check for weight explosion in model parameters
430    pub fn check_weight_explosion(&mut self, weights: &[f32], layer_name: &str) -> Result<()> {
431        let weight_threshold = 10.0;
432        let extreme_weights: Vec<f32> =
433            weights.iter().filter(|&&w| w.abs() > weight_threshold).cloned().collect();
434
435        if !extreme_weights.is_empty() {
436            let mut metadata = HashMap::new();
437            metadata.insert("layer_name".to_string(), layer_name.to_string());
438            metadata.insert(
439                "extreme_weight_count".to_string(),
440                extreme_weights.len().to_string(),
441            );
442            metadata.insert("total_weight_count".to_string(), weights.len().to_string());
443            metadata.insert(
444                "max_weight".to_string(),
445                extreme_weights.iter().map(|w| w.abs()).fold(0.0f32, f32::max).to_string(),
446            );
447
448            self.report_anomaly(Anomaly {
449                anomaly_type: AnomalyType::UnusualActivation,
450                timestamp: Utc::now(),
451                location: layer_name.to_string(),
452                description: format!(
453                    "Weight explosion in {}: {} weights > {}",
454                    layer_name,
455                    extreme_weights.len(),
456                    weight_threshold
457                ),
458                severity: AnomalySeverity::High,
459                metadata,
460            });
461        }
462
463        Ok(())
464    }
465
466    /// Report an anomaly
467    fn report_anomaly(&mut self, anomaly: Anomaly) {
468        eprintln!(
469            "🚨 Anomaly detected: {} at {}",
470            anomaly.description, anomaly.location
471        );
472
473        // Update monitoring stats
474        self.monitoring_stats.total_anomalies += 1;
475        let anomaly_type_key = format!("{:?}", anomaly.anomaly_type);
476        *self.monitoring_stats.anomalies_per_type.entry(anomaly_type_key).or_insert(0) += 1;
477
478        self.detected_anomalies.push(anomaly);
479    }
480
481    /// Get all detected anomalies
482    pub fn get_anomalies(&self) -> &[Anomaly] {
483        &self.detected_anomalies
484    }
485
486    /// Clear all detected anomalies
487    pub fn clear_anomalies(&mut self) {
488        self.detected_anomalies.clear();
489    }
490
491    /// Check for gradient conflicts between layers
492    pub fn check_gradient_conflict(
493        &mut self,
494        layer_gradients: &HashMap<String, Vec<f32>>,
495    ) -> Result<()> {
496        if !self.config.enable_gradient_conflict_detection {
497            return Ok(());
498        }
499
500        let layer_names: Vec<_> = layer_gradients.keys().cloned().collect();
501
502        for i in 0..layer_names.len() {
503            for j in i + 1..layer_names.len() {
504                let layer1 = &layer_names[i];
505                let layer2 = &layer_names[j];
506
507                if let (Some(grad1), Some(grad2)) =
508                    (layer_gradients.get(layer1), layer_gradients.get(layer2))
509                {
510                    let conflict_score = self.compute_gradient_conflict(grad1, grad2);
511
512                    if conflict_score > 0.8 {
513                        let mut metadata = HashMap::new();
514                        metadata.insert("layer1".to_string(), layer1.clone());
515                        metadata.insert("layer2".to_string(), layer2.clone());
516                        metadata.insert("conflict_score".to_string(), conflict_score.to_string());
517
518                        self.report_anomaly(Anomaly {
519                            anomaly_type: AnomalyType::GradientConflict,
520                            timestamp: Utc::now(),
521                            location: format!("{}↔{}", layer1, layer2),
522                            description: format!(
523                                "Gradient conflict detected between {} and {} (score: {:.2})",
524                                layer1, layer2, conflict_score
525                            ),
526                            severity: AnomalySeverity::High,
527                            metadata,
528                        });
529                    }
530                }
531            }
532        }
533
534        Ok(())
535    }
536
537    /// Check for weight divergence from baseline
538    pub fn check_weight_divergence(
539        &mut self,
540        layer_name: &str,
541        current_weights: &[f32],
542    ) -> Result<()> {
543        if !self.config.enable_weight_divergence_detection {
544            return Ok(());
545        }
546
547        // Initialize baseline if not exists
548        if !self.weight_baseline.contains_key(layer_name) {
549            self.weight_baseline.insert(layer_name.to_string(), current_weights.to_vec());
550            return Ok(());
551        }
552
553        let baseline = self
554            .weight_baseline
555            .get(layer_name)
556            .expect("baseline should exist after contains_key check");
557        if baseline.len() != current_weights.len() {
558            return Ok(()); // Skip if dimensions don't match
559        }
560
561        let divergence = self.compute_weight_divergence(baseline, current_weights);
562
563        if divergence > self.config.weight_divergence_threshold {
564            let mut metadata = HashMap::new();
565            metadata.insert("layer_name".to_string(), layer_name.to_string());
566            metadata.insert("divergence_score".to_string(), divergence.to_string());
567            metadata.insert(
568                "threshold".to_string(),
569                self.config.weight_divergence_threshold.to_string(),
570            );
571
572            self.report_anomaly(Anomaly {
573                anomaly_type: AnomalyType::WeightDivergence,
574                timestamp: Utc::now(),
575                location: layer_name.to_string(),
576                description: format!(
577                    "Weight divergence in {}: {:.2} (threshold: {:.2})",
578                    layer_name, divergence, self.config.weight_divergence_threshold
579                ),
580                severity: if divergence > self.config.weight_divergence_threshold * 2.0 {
581                    AnomalySeverity::Critical
582                } else {
583                    AnomalySeverity::High
584                },
585                metadata,
586            });
587        }
588
589        Ok(())
590    }
591
592    /// Check for performance degradation
593    pub fn check_performance_degradation(
594        &mut self,
595        current_performance: f64,
596        location: &str,
597    ) -> Result<()> {
598        if !self.config.enable_performance_monitoring {
599            return Ok(());
600        }
601
602        // Add to history
603        if self.performance_history.len() >= self.config.monitoring_window_size {
604            self.performance_history.pop_front();
605        }
606        self.performance_history.push_back(current_performance);
607
608        // Check for degradation if we have enough history
609        if self.performance_history.len() >= 10 {
610            let recent_avg = self.performance_history.iter().rev().take(5).sum::<f64>() / 5.0;
611            let baseline_avg = self.performance_history.iter().take(5).sum::<f64>() / 5.0;
612
613            let degradation_ratio = (baseline_avg - recent_avg) / baseline_avg;
614
615            if degradation_ratio > self.config.performance_degradation_threshold {
616                let mut metadata = HashMap::new();
617                metadata.insert("baseline_performance".to_string(), baseline_avg.to_string());
618                metadata.insert("current_performance".to_string(), recent_avg.to_string());
619                metadata.insert(
620                    "degradation_ratio".to_string(),
621                    degradation_ratio.to_string(),
622                );
623
624                self.report_anomaly(Anomaly {
625                    anomaly_type: AnomalyType::PerformanceDegradation,
626                    timestamp: Utc::now(),
627                    location: location.to_string(),
628                    description: format!(
629                        "Performance degradation detected: {:.1}% drop from baseline",
630                        degradation_ratio * 100.0
631                    ),
632                    severity: if degradation_ratio > 0.8 {
633                        AnomalySeverity::Critical
634                    } else {
635                        AnomalySeverity::High
636                    },
637                    metadata,
638                });
639            }
640        }
641
642        Ok(())
643    }
644
645    /// Check for loss anomalies
646    pub fn check_loss_anomaly(&mut self, current_loss: f64, location: &str) -> Result<()> {
647        if !self.config.enable_loss_anomaly_detection {
648            return Ok(());
649        }
650
651        // Add to history
652        if self.loss_history.len() >= self.config.monitoring_window_size {
653            self.loss_history.pop_front();
654        }
655        self.loss_history.push_back(current_loss);
656
657        // Check for loss spikes
658        if self.loss_history.len() >= 3 {
659            let prev_loss = self.loss_history[self.loss_history.len() - 2];
660            let loss_ratio = current_loss / prev_loss;
661
662            if loss_ratio > self.config.loss_spike_threshold {
663                let mut metadata = HashMap::new();
664                metadata.insert("previous_loss".to_string(), prev_loss.to_string());
665                metadata.insert("current_loss".to_string(), current_loss.to_string());
666                metadata.insert("spike_ratio".to_string(), loss_ratio.to_string());
667
668                self.report_anomaly(Anomaly {
669                    anomaly_type: AnomalyType::LossAnomalous,
670                    timestamp: Utc::now(),
671                    location: location.to_string(),
672                    description: format!(
673                        "Loss spike detected: {:.2}x increase (from {:.6} to {:.6})",
674                        loss_ratio, prev_loss, current_loss
675                    ),
676                    severity: if loss_ratio > 100.0 {
677                        AnomalySeverity::Critical
678                    } else {
679                        AnomalySeverity::High
680                    },
681                    metadata,
682                });
683            }
684        }
685
686        Ok(())
687    }
688
689    /// Attempt automatic recovery from an anomaly
690    pub async fn attempt_recovery(&mut self, anomaly: &Anomaly) -> Result<RecoveryAction> {
691        if !self.config.enable_auto_recovery {
692            return Ok(RecoveryAction::None);
693        }
694
695        let action = self.determine_recovery_action(anomaly);
696        let anomaly_id = format!(
697            "{:?}_{}",
698            anomaly.anomaly_type,
699            anomaly.timestamp.timestamp()
700        );
701
702        let success = self.execute_recovery_action(&action).await?;
703
704        self.recovery_attempts.push(RecoveryAttempt {
705            anomaly_id: anomaly_id.clone(),
706            action: action.clone(),
707            timestamp: Utc::now(),
708            success,
709            error_message: if success { None } else { Some("Recovery failed".to_string()) },
710        });
711
712        self.monitoring_stats.recovery_attempts += 1;
713        if success {
714            self.monitoring_stats.successful_recoveries += 1;
715        }
716
717        Ok(action)
718    }
719
720    /// Get monitoring statistics
721    pub fn get_monitoring_stats(&self) -> &MonitoringStats {
722        &self.monitoring_stats
723    }
724
725    /// Get recovery attempts history
726    pub fn get_recovery_attempts(&self) -> &[RecoveryAttempt] {
727        &self.recovery_attempts
728    }
729
730    /// Update monitoring window with current state
731    pub fn update_monitoring_window(&mut self) -> Result<()> {
732        let mut severity_distribution = HashMap::new();
733        for anomaly in &self.detected_anomalies {
734            let key = format!("{:?}", anomaly.severity);
735            *severity_distribution.entry(key).or_insert(0) += 1;
736        }
737
738        let mut performance_metrics = HashMap::new();
739        if let Some(latest_perf) = self.performance_history.back() {
740            performance_metrics.insert("latest_performance".to_string(), *latest_perf);
741        }
742        if let Some(latest_loss) = self.loss_history.back() {
743            performance_metrics.insert("latest_loss".to_string(), *latest_loss);
744        }
745
746        let snapshot = AnomalySnapshot {
747            timestamp: Utc::now(),
748            anomaly_count: self.detected_anomalies.len(),
749            severity_distribution,
750            performance_metrics,
751        };
752
753        self.monitoring_stats.monitoring_window.push(snapshot);
754
755        // Keep only recent snapshots
756        if self.monitoring_stats.monitoring_window.len() > self.config.monitoring_window_size {
757            self.monitoring_stats.monitoring_window.remove(0);
758        }
759
760        Ok(())
761    }
762
763    // Private helper methods for new functionality
764
765    fn compute_gradient_conflict(&self, grad1: &[f32], grad2: &[f32]) -> f64 {
766        if grad1.len() != grad2.len() {
767            return 0.0;
768        }
769
770        let dot_product: f64 =
771            grad1.iter().zip(grad2.iter()).map(|(a, b)| (*a as f64) * (*b as f64)).sum();
772
773        let norm1: f64 = grad1.iter().map(|x| (*x as f64).powi(2)).sum::<f64>().sqrt();
774        let norm2: f64 = grad2.iter().map(|x| (*x as f64).powi(2)).sum::<f64>().sqrt();
775
776        if norm1 == 0.0 || norm2 == 0.0 {
777            return 0.0;
778        }
779
780        // Cosine similarity - conflicts are indicated by negative correlation
781        let cosine_sim = dot_product / (norm1 * norm2);
782
783        // Convert to conflict score (0 = no conflict, 1 = maximum conflict)
784        (1.0 - cosine_sim) / 2.0
785    }
786
787    fn compute_weight_divergence(&self, baseline: &[f32], current: &[f32]) -> f64 {
788        let mse: f64 = baseline
789            .iter()
790            .zip(current.iter())
791            .map(|(a, b)| (*a as f64 - *b as f64).powi(2))
792            .sum::<f64>()
793            / baseline.len() as f64;
794
795        mse.sqrt()
796    }
797
798    fn determine_recovery_action(&self, anomaly: &Anomaly) -> RecoveryAction {
799        match anomaly.anomaly_type {
800            AnomalyType::GradientExplosion => RecoveryAction::ClipGradients { max_norm: 1.0 },
801            AnomalyType::GradientVanishing => RecoveryAction::ReduceLearningRate { factor: 0.5 },
802            AnomalyType::NaN | AnomalyType::Infinity => RecoveryAction::ResetGradients,
803            AnomalyType::WeightDivergence => RecoveryAction::ApplyWeightDecay { rate: 0.01 },
804            AnomalyType::LossAnomalous => RecoveryAction::SkipBatch,
805            AnomalyType::MemoryLeak => RecoveryAction::RestartOptimizer,
806            AnomalyType::PerformanceDegradation => {
807                RecoveryAction::ReduceLearningRate { factor: 0.8 }
808            },
809            _ => RecoveryAction::None,
810        }
811    }
812
813    async fn execute_recovery_action(&self, action: &RecoveryAction) -> Result<bool> {
814        // In a real implementation, this would interface with the training system
815        // For now, we'll simulate the actions
816        match action {
817            RecoveryAction::None => Ok(true),
818            RecoveryAction::ResetGradients => {
819                tracing::info!("Executing recovery: Reset gradients");
820                Ok(true)
821            },
822            RecoveryAction::ReduceLearningRate { factor } => {
823                tracing::info!(
824                    "Executing recovery: Reduce learning rate by factor {}",
825                    factor
826                );
827                Ok(true)
828            },
829            RecoveryAction::ClipGradients { max_norm } => {
830                tracing::info!(
831                    "Executing recovery: Clip gradients to max norm {}",
832                    max_norm
833                );
834                Ok(true)
835            },
836            RecoveryAction::RestartOptimizer => {
837                tracing::info!("Executing recovery: Restart optimizer");
838                Ok(true)
839            },
840            RecoveryAction::SkipBatch => {
841                tracing::info!("Executing recovery: Skip current batch");
842                Ok(true)
843            },
844            RecoveryAction::ResetWeights { layer_name } => {
845                tracing::info!("Executing recovery: Reset weights for layer {}", layer_name);
846                Ok(true)
847            },
848            RecoveryAction::ApplyWeightDecay { rate } => {
849                tracing::info!("Executing recovery: Apply weight decay with rate {}", rate);
850                Ok(true)
851            },
852            RecoveryAction::EmergencyStop => {
853                tracing::warn!("Executing recovery: Emergency stop");
854                Ok(false) // This would actually stop training
855            },
856        }
857    }
858
859    /// Quick anomaly check for simplified interface
860    pub async fn quick_check(&self) -> Result<crate::QuickAnomalySummary> {
861        let anomaly_count = self.detected_anomalies.len();
862
863        let severity_level = match anomaly_count {
864            0 => "None",
865            1..=3 => "Low",
866            4..=10 => "Medium",
867            11..=20 => "High",
868            _ => "Critical",
869        }
870        .to_string();
871
872        let mut recommendations = Vec::new();
873        if anomaly_count > 0 {
874            recommendations.push("Review recent training metrics for instabilities".to_string());
875        }
876        if anomaly_count > 5 {
877            recommendations.push(
878                "Consider adjusting learning rate or implementing gradient clipping".to_string(),
879            );
880        }
881        if anomaly_count > 15 {
882            recommendations
883                .push("Training may need to be restarted with better configuration".to_string());
884        }
885        if anomaly_count == 0 {
886            recommendations.push("No anomalies detected, training appears stable".to_string());
887        }
888
889        Ok(crate::QuickAnomalySummary {
890            anomaly_count,
891            severity_level,
892            recommendations,
893        })
894    }
895
896    /// Generate anomaly detection report
897    pub async fn generate_report(&self) -> Result<AnomalyDetectorReport> {
898        let mut anomaly_counts = HashMap::new();
899        for anomaly in &self.detected_anomalies {
900            let count = anomaly_counts.entry(format!("{:?}", anomaly.anomaly_type)).or_insert(0);
901            *count += 1;
902        }
903
904        Ok(AnomalyDetectorReport {
905            session_duration: Utc::now().signed_duration_since(self.start_time),
906            total_anomalies: self.detected_anomalies.len(),
907            anomaly_counts,
908            most_recent_anomalies: self.detected_anomalies.iter().rev().take(10).cloned().collect(),
909            config: self.config.clone(),
910        })
911    }
912}
913
/// Report generated by the anomaly detector
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AnomalyDetectorReport {
    // Wall-clock time since the detector was created.
    pub session_duration: chrono::Duration,
    // Total number of anomalies recorded during the session.
    pub total_anomalies: usize,
    // Tally of anomalies keyed by the Debug rendering of each anomaly type.
    pub anomaly_counts: HashMap<String, usize>,
    // Up to the ten most recently recorded anomalies, newest first.
    pub most_recent_anomalies: Vec<Anomaly>,
    // Snapshot of the configuration the detector was running with.
    pub config: AnomalyDetectorConfig,
}
923
// Unit tests covering each detection path plus auto-recovery and the
// monitoring-stats bookkeeping.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_anomaly_detector_creation() {
        let config = DebugConfig::default();
        let detector = AnomalyDetector::new(&config);
        assert_eq!(detector.get_anomalies().len(), 0);
    }

    #[test]
    fn test_nan_detection() {
        let config = DebugConfig::default();
        let mut detector = AnomalyDetector::new(&config);

        let values = vec![1.0, 2.0, f32::NAN, 4.0];
        detector.check_nan(&values, "test_location").expect("operation failed in test");

        assert_eq!(detector.get_anomalies().len(), 1);
        assert!(matches!(
            detector.get_anomalies()[0].anomaly_type,
            AnomalyType::NaN
        ));
    }

    #[test]
    fn test_inf_detection() {
        let config = DebugConfig::default();
        let mut detector = AnomalyDetector::new(&config);

        let values = vec![1.0, 2.0, f32::INFINITY, 4.0];
        detector.check_inf(&values, "test_location").expect("operation failed in test");

        assert_eq!(detector.get_anomalies().len(), 1);
        assert!(matches!(
            detector.get_anomalies()[0].anomaly_type,
            AnomalyType::Infinity
        ));
    }

    #[test]
    fn test_gradient_explosion_detection() {
        let config = DebugConfig::default();
        let mut detector = AnomalyDetector::new(&config);

        detector
            .check_gradient_explosion(1e7, "test_layer")
            .expect("operation failed in test");

        assert_eq!(detector.get_anomalies().len(), 1);
        assert!(matches!(
            detector.get_anomalies()[0].anomaly_type,
            AnomalyType::GradientExplosion
        ));
    }

    #[test]
    fn test_gradient_vanishing_detection() {
        let config = DebugConfig::default();
        let mut detector = AnomalyDetector::new(&config);

        detector
            .check_gradient_vanishing(1e-10, "test_layer")
            .expect("operation failed in test");

        assert_eq!(detector.get_anomalies().len(), 1);
        assert!(matches!(
            detector.get_anomalies()[0].anomaly_type,
            AnomalyType::GradientVanishing
        ));
    }

    #[test]
    fn test_numerical_instability_detection() {
        let config = DebugConfig::default();
        let mut detector = AnomalyDetector::new(&config);

        // Test near-zero values
        let near_zero_values: Vec<f32> =
            (0..100).map(|i| if i < 50 { 1e-12 } else { 1.0 }).collect();
        detector
            .check_numerical_instability(&near_zero_values, "test_location")
            .expect("operation failed in test");
        assert_eq!(detector.get_anomalies().len(), 1);

        detector.clear_anomalies();

        // Test extreme values
        let extreme_values = vec![1.0, 2.0, 1e7, 4.0];
        detector
            .check_numerical_instability(&extreme_values, "test_location")
            .expect("operation failed in test");
        assert_eq!(detector.get_anomalies().len(), 1);
    }

    #[test]
    fn test_activation_saturation_detection() {
        let config = DebugConfig::default();
        let mut detector = AnomalyDetector::new(&config);

        // Test ReLU saturation (all zeros)
        let relu_saturated: Vec<f32> = vec![0.0; 100];
        detector
            .check_activation_saturation(&relu_saturated, "relu", "test_layer")
            .expect("operation failed in test");
        assert_eq!(detector.get_anomalies().len(), 1);

        detector.clear_anomalies();

        // Test sigmoid saturation (all values pinned near 1.0)
        let sigmoid_saturated: Vec<f32> = vec![0.999; 100];
        detector
            .check_activation_saturation(&sigmoid_saturated, "sigmoid", "test_layer")
            .expect("operation failed in test");
        assert_eq!(detector.get_anomalies().len(), 1);
    }

    #[test]
    fn test_memory_leak_detection() {
        let config = DebugConfig::default();
        let mut detector = AnomalyDetector::new(&config);

        // Test memory growth detection (3x growth should trigger)
        detector
            .check_memory_leak(3072, Some(1024), "test_location")
            .expect("operation failed in test");
        assert_eq!(detector.get_anomalies().len(), 1);
        assert!(matches!(
            detector.get_anomalies()[0].anomaly_type,
            AnomalyType::MemoryLeak
        ));

        detector.clear_anomalies();

        // Test absolute high memory
        detector
            .check_memory_leak(10240, None, "test_location")
            .expect("operation failed in test");
        assert_eq!(detector.get_anomalies().len(), 1);
    }

    #[test]
    fn test_weight_explosion_detection() {
        let config = DebugConfig::default();
        let mut detector = AnomalyDetector::new(&config);

        let weights = vec![1.0, 2.0, 15.0, 4.0, -20.0]; // Two weights exceed threshold of 10.0
        detector
            .check_weight_explosion(&weights, "test_layer")
            .expect("operation failed in test");

        assert_eq!(detector.get_anomalies().len(), 1);
        assert!(matches!(
            detector.get_anomalies()[0].anomaly_type,
            AnomalyType::UnusualActivation
        ));
    }

    #[test]
    fn test_gradient_conflict_detection() {
        let config = DebugConfig::default();
        let mut detector = AnomalyDetector::new(&config);

        let mut layer_gradients = HashMap::new();
        layer_gradients.insert("layer1".to_string(), vec![1.0, 0.0, 0.0]);
        layer_gradients.insert("layer2".to_string(), vec![-1.0, 0.0, 0.0]); // Opposing gradients

        detector
            .check_gradient_conflict(&layer_gradients)
            .expect("operation failed in test");

        assert_eq!(detector.get_anomalies().len(), 1);
        assert!(matches!(
            detector.get_anomalies()[0].anomaly_type,
            AnomalyType::GradientConflict
        ));
    }

    #[test]
    fn test_weight_divergence_detection() {
        let config = DebugConfig::default();
        let mut detector = AnomalyDetector::new(&config);

        let baseline_weights = vec![1.0, 2.0, 3.0, 4.0];
        let diverged_weights = vec![10.0, 20.0, 30.0, 40.0]; // Significant divergence

        // First call establishes baseline
        detector
            .check_weight_divergence("test_layer", &baseline_weights)
            .expect("operation failed in test");
        assert_eq!(detector.get_anomalies().len(), 0);

        // Second call detects divergence
        detector
            .check_weight_divergence("test_layer", &diverged_weights)
            .expect("operation failed in test");
        assert_eq!(detector.get_anomalies().len(), 1);
        assert!(matches!(
            detector.get_anomalies()[0].anomaly_type,
            AnomalyType::WeightDivergence
        ));
    }

    #[test]
    fn test_performance_degradation_detection() {
        let config = DebugConfig::default();
        let mut detector = AnomalyDetector::new(&config);

        // Add baseline performance metrics
        for _ in 0..10 {
            detector
                .check_performance_degradation(100.0, "training")
                .expect("operation failed in test"); // Good performance
        }
        assert_eq!(detector.get_anomalies().len(), 0);

        // Add degraded performance metrics - just enough to trigger once
        for _ in 0..5 {
            detector
                .check_performance_degradation(20.0, "training")
                .expect("operation failed in test"); // Poor performance
        }

        // Should have at least one degradation anomaly
        assert!(!detector.get_anomalies().is_empty());
        assert!(detector
            .get_anomalies()
            .iter()
            .any(|a| matches!(a.anomaly_type, AnomalyType::PerformanceDegradation)));
    }

    #[test]
    fn test_loss_anomaly_detection() {
        let config = DebugConfig::default();
        let mut detector = AnomalyDetector::new(&config);

        // Add normal loss values
        detector.check_loss_anomaly(1.0, "training").expect("operation failed in test");
        detector.check_loss_anomaly(0.9, "training").expect("operation failed in test");
        assert_eq!(detector.get_anomalies().len(), 0);

        // Add loss spike
        detector
            .check_loss_anomaly(100.0, "training")
            .expect("operation failed in test"); // 100x spike
        assert_eq!(detector.get_anomalies().len(), 1);
        assert!(matches!(
            detector.get_anomalies()[0].anomaly_type,
            AnomalyType::LossAnomalous
        ));
    }

    #[tokio::test]
    async fn test_auto_recovery() {
        let config = DebugConfig::default();
        let mut detector = AnomalyDetector::new(&config);
        detector.config.enable_auto_recovery = true;

        let anomaly = Anomaly {
            anomaly_type: AnomalyType::GradientExplosion,
            timestamp: Utc::now(),
            location: "test_layer".to_string(),
            description: "Test gradient explosion".to_string(),
            severity: AnomalySeverity::High,
            metadata: HashMap::new(),
        };

        // Fixed expect message: the previous "temp file creation failed" was a
        // copy-paste leftover unrelated to this call.
        let action = detector.attempt_recovery(&anomaly).await.expect("recovery attempt failed in test");
        assert!(matches!(action, RecoveryAction::ClipGradients { .. }));
        assert_eq!(detector.get_recovery_attempts().len(), 1);
    }

    #[test]
    fn test_monitoring_stats() {
        let config = DebugConfig::default();
        let mut detector = AnomalyDetector::new(&config);

        // Create some anomalies to generate stats
        detector.check_nan(&[f32::NAN], "test").expect("operation failed in test");
        detector.check_inf(&[f32::INFINITY], "test").expect("operation failed in test");

        let stats = detector.get_monitoring_stats();
        assert_eq!(stats.total_anomalies, 2);
        assert!(stats.anomalies_per_type.contains_key("NaN"));
        assert!(stats.anomalies_per_type.contains_key("Infinity"));
    }

    #[test]
    fn test_monitoring_window_update() {
        let config = DebugConfig::default();
        let mut detector = AnomalyDetector::new(&config);

        detector.check_nan(&[f32::NAN], "test").expect("operation failed in test");
        detector.update_monitoring_window().expect("operation failed in test");

        let stats = detector.get_monitoring_stats();
        assert_eq!(stats.monitoring_window.len(), 1);
        assert_eq!(stats.monitoring_window[0].anomaly_count, 1);
    }
}