1use anyhow::Result;
7use chrono::{DateTime, Utc};
8use serde::{Deserialize, Serialize};
9use std::collections::{HashMap, VecDeque};
10
11use crate::DebugConfig;
12
#[derive(Debug)]
/// Stateful collector of training anomalies raised by its `check_*` methods.
pub struct AnomalyDetector {
    /// Feature toggles and thresholds consulted by the checks.
    config: AnomalyDetectorConfig,
    /// Anomalies recorded since the last `start` or `clear_anomalies` call.
    detected_anomalies: Vec<Anomaly>,
    /// Set on construction and reset by `start`; used for the report's session duration.
    start_time: DateTime<Utc>,
    /// Log of every recovery executed by `attempt_recovery`.
    recovery_attempts: Vec<RecoveryAttempt>,
    /// Running counters plus the snapshot window.
    monitoring_stats: MonitoringStats,
    /// Recent performance samples, capped at `config.monitoring_window_size`.
    performance_history: VecDeque<f64>,
    /// Initialised empty in `new`; no code visible in this file reads or updates it.
    #[allow(dead_code)]
    gradient_history: HashMap<String, VecDeque<f64>>,
    /// Recent loss samples, capped at `config.monitoring_window_size`.
    loss_history: VecDeque<f64>,
    /// First weights seen per layer name by `check_weight_divergence`.
    weight_baseline: HashMap<String, Vec<f32>>,
}
27
#[derive(Debug, Clone, Serialize, Deserialize)]
/// Toggles and thresholds that control which checks run and when they fire.
pub struct AnomalyDetectorConfig {
    /// Gate for `check_nan`.
    pub enable_nan_detection: bool,
    /// Gate for `check_inf`.
    pub enable_inf_detection: bool,
    /// Gate for `check_gradient_explosion`.
    pub enable_gradient_explosion: bool,
    /// Gate for `check_gradient_vanishing`.
    pub enable_gradient_vanishing: bool,
    /// Gradient norms strictly above this value are reported as explosions.
    pub gradient_threshold: f64,
    /// Gate for `check_memory_leak`.
    pub enable_memory_leak_detection: bool,
    /// Toggle for numerical-instability detection.
    pub enable_numerical_instability_detection: bool,
    /// Gate for `check_gradient_conflict`.
    pub enable_gradient_conflict_detection: bool,
    /// Gate for `check_performance_degradation`.
    pub enable_performance_monitoring: bool,
    /// Gate for `check_weight_divergence`.
    pub enable_weight_divergence_detection: bool,
    /// Toggle for dead-activation detection.
    /// NOTE(review): no check visible in this file consults it; confirm intended wiring.
    pub enable_activation_dead_detection: bool,
    /// Gate for `check_loss_anomaly`.
    pub enable_loss_anomaly_detection: bool,
    /// When false, `attempt_recovery` returns `RecoveryAction::None` without acting.
    pub enable_auto_recovery: bool,
    /// NOTE(review): not read by any code visible in this file; presumably a magnitude cutoff.
    pub numerical_instability_threshold: f64,
    /// Relative drop `(baseline - recent) / baseline` above which degradation is reported.
    pub performance_degradation_threshold: f64,
    /// Root-mean-square weight difference above which divergence is reported.
    pub weight_divergence_threshold: f64,
    /// Ratio `current_loss / previous_loss` above which a loss spike is reported.
    pub loss_spike_threshold: f64,
    /// Maximum length of the performance/loss histories and of the snapshot window.
    pub monitoring_window_size: usize,
    /// NOTE(review): not enforced by any code visible in this file.
    pub recovery_attempts_limit: usize,
}
51
52impl Default for AnomalyDetectorConfig {
53 fn default() -> Self {
54 Self {
55 enable_nan_detection: true,
56 enable_inf_detection: true,
57 enable_gradient_explosion: true,
58 enable_gradient_vanishing: true,
59 gradient_threshold: 1e6,
60 enable_memory_leak_detection: true,
61 enable_numerical_instability_detection: true,
62 enable_gradient_conflict_detection: true,
63 enable_performance_monitoring: true,
64 enable_weight_divergence_detection: true,
65 enable_activation_dead_detection: true,
66 enable_loss_anomaly_detection: true,
67 enable_auto_recovery: false, numerical_instability_threshold: 1e-12,
69 performance_degradation_threshold: 0.5, weight_divergence_threshold: 5.0,
71 loss_spike_threshold: 10.0, monitoring_window_size: 100,
73 recovery_attempts_limit: 3,
74 }
75 }
76}
77
#[derive(Debug, Clone, Serialize, Deserialize)]
/// Category of a detected anomaly; its `Debug` name is also used as a statistics key.
pub enum AnomalyType {
    /// Reported by `check_nan`.
    NaN,
    /// Reported by `check_inf`.
    Infinity,
    /// Reported by `check_gradient_explosion`.
    GradientExplosion,
    /// Reported by `check_gradient_vanishing`.
    GradientVanishing,
    /// Reported by `check_memory_leak`.
    MemoryLeak,
    /// Reported by `check_numerical_instability`, `check_activation_saturation`
    /// and `check_weight_explosion`.
    UnusualActivation,
    /// NOTE(review): not produced by any check visible in this file.
    NumericalInstability,
    /// Reported by `check_gradient_conflict`.
    GradientConflict,
    /// Reported by `check_performance_degradation`.
    PerformanceDegradation,
    /// Reported by `check_weight_divergence`.
    WeightDivergence,
    /// NOTE(review): not produced by any check visible in this file.
    ActivationDead,
    /// Reported by `check_loss_anomaly`.
    LossAnomalous,
}
94
#[derive(Debug, Clone, Serialize, Deserialize)]
/// A single detected anomaly together with its context.
pub struct Anomaly {
    /// Category of the anomaly.
    pub anomaly_type: AnomalyType,
    /// Time at which the check built this record.
    pub timestamp: DateTime<Utc>,
    /// Caller-supplied location label (a layer name for the weight checks).
    pub location: String,
    /// Human-readable summary; also printed to stderr when reported.
    pub description: String,
    /// Severity assigned by the check that raised the anomaly.
    pub severity: AnomalySeverity,
    /// Free-form key/value details, with measured values stored as strings.
    pub metadata: HashMap<String, String>,
}
105
#[derive(Debug, Clone, Serialize, Deserialize)]
/// Severity attached to an anomaly; its `Debug` name keys the snapshot severity distribution.
pub enum AnomalySeverity {
    /// Not assigned by any check visible in this file.
    Low,
    /// Used for the near-zero-values finding of `check_numerical_instability`.
    Medium,
    /// Default severity for most checks.
    High,
    /// Used for gradient explosions and for the escalated cases of several checks.
    Critical,
}
114
#[derive(Debug, Clone, Serialize, Deserialize)]
/// Remedial action chosen by `determine_recovery_action` for an anomaly.
pub enum RecoveryAction {
    /// No action; also returned when auto-recovery is disabled.
    None,
    /// Chosen for NaN and infinity anomalies.
    ResetGradients,
    /// Chosen for vanishing gradients (factor 0.5) and performance degradation (factor 0.8).
    ReduceLearningRate { factor: f64 },
    /// Chosen for gradient explosions (max norm 1.0).
    ClipGradients { max_norm: f64 },
    /// Chosen for memory-leak anomalies.
    RestartOptimizer,
    /// Chosen for loss anomalies.
    SkipBatch,
    /// Never selected by `determine_recovery_action` in this file.
    ResetWeights { layer_name: String },
    /// Chosen for weight divergence (rate 0.01).
    ApplyWeightDecay { rate: f64 },
    /// Never selected by `determine_recovery_action` in this file; executing it reports failure.
    EmergencyStop,
}
128
#[derive(Debug, Clone, Serialize, Deserialize)]
/// Record of one recovery executed by `attempt_recovery`.
pub struct RecoveryAttempt {
    /// `"{:?}_{}"` of the anomaly type and the anomaly timestamp's `timestamp()` value.
    pub anomaly_id: String,
    /// Action that was executed.
    pub action: RecoveryAction,
    /// Time the attempt was recorded.
    pub timestamp: DateTime<Utc>,
    /// Outcome returned by the action's execution.
    pub success: bool,
    /// `Some("Recovery failed")` when `success` is false, otherwise `None`.
    pub error_message: Option<String>,
}
138
#[derive(Debug, Clone, Serialize, Deserialize)]
/// Cumulative counters and periodic snapshots maintained by the detector.
pub struct MonitoringStats {
    /// Number of anomalies ever reported; not reset by `clear_anomalies`.
    pub total_anomalies: usize,
    /// Report count per anomaly type, keyed by the type's `Debug` name.
    pub anomalies_per_type: HashMap<String, usize>,
    /// Number of recoveries executed.
    pub recovery_attempts: usize,
    /// Number of executed recoveries that reported success.
    pub successful_recoveries: usize,
    /// Initialised to 0.0; no code visible in this file updates it.
    pub average_detection_time_ms: f64,
    /// Snapshots appended by `update_monitoring_window`, oldest first.
    pub monitoring_window: Vec<AnomalySnapshot>,
}
149
#[derive(Debug, Clone, Serialize, Deserialize)]
/// Point-in-time summary produced by `update_monitoring_window`.
pub struct AnomalySnapshot {
    /// Time the snapshot was taken.
    pub timestamp: DateTime<Utc>,
    /// Number of anomalies currently held by the detector.
    pub anomaly_count: usize,
    /// Count of currently held anomalies per severity `Debug` name.
    pub severity_distribution: HashMap<String, usize>,
    /// May contain `latest_performance` and `latest_loss` when such samples exist.
    pub performance_metrics: HashMap<String, f64>,
}
158
159impl AnomalyDetector {
160 pub fn new(_config: &DebugConfig) -> Self {
162 let monitoring_window_size = AnomalyDetectorConfig::default().monitoring_window_size;
163 Self {
164 config: AnomalyDetectorConfig::default(),
165 detected_anomalies: Vec::new(),
166 start_time: Utc::now(),
167 recovery_attempts: Vec::new(),
168 monitoring_stats: MonitoringStats {
169 total_anomalies: 0,
170 anomalies_per_type: HashMap::new(),
171 recovery_attempts: 0,
172 successful_recoveries: 0,
173 average_detection_time_ms: 0.0,
174 monitoring_window: Vec::new(),
175 },
176 performance_history: VecDeque::with_capacity(monitoring_window_size),
177 gradient_history: HashMap::new(),
178 loss_history: VecDeque::with_capacity(monitoring_window_size),
179 weight_baseline: HashMap::new(),
180 }
181 }
182
183 pub async fn start(&mut self) -> Result<()> {
185 self.start_time = Utc::now();
186 self.detected_anomalies.clear();
187 Ok(())
188 }
189
190 pub fn check_nan(&mut self, values: &[f32], location: &str) -> Result<()> {
192 if !self.config.enable_nan_detection {
193 return Ok(());
194 }
195
196 if values.iter().any(|v| v.is_nan()) {
197 self.report_anomaly(Anomaly {
198 anomaly_type: AnomalyType::NaN,
199 timestamp: Utc::now(),
200 location: location.to_string(),
201 description: "NaN values detected in tensor".to_string(),
202 severity: AnomalySeverity::High,
203 metadata: HashMap::new(),
204 });
205 }
206
207 Ok(())
208 }
209
210 pub fn check_inf(&mut self, values: &[f32], location: &str) -> Result<()> {
212 if !self.config.enable_inf_detection {
213 return Ok(());
214 }
215
216 if values.iter().any(|v| v.is_infinite()) {
217 self.report_anomaly(Anomaly {
218 anomaly_type: AnomalyType::Infinity,
219 timestamp: Utc::now(),
220 location: location.to_string(),
221 description: "Infinite values detected in tensor".to_string(),
222 severity: AnomalySeverity::High,
223 metadata: HashMap::new(),
224 });
225 }
226
227 Ok(())
228 }
229
230 pub fn check_gradient_explosion(&mut self, gradient_norm: f64, location: &str) -> Result<()> {
232 if !self.config.enable_gradient_explosion {
233 return Ok(());
234 }
235
236 if gradient_norm > self.config.gradient_threshold {
237 self.report_anomaly(Anomaly {
238 anomaly_type: AnomalyType::GradientExplosion,
239 timestamp: Utc::now(),
240 location: location.to_string(),
241 description: format!("Gradient explosion detected: norm = {}", gradient_norm),
242 severity: AnomalySeverity::Critical,
243 metadata: {
244 let mut meta = HashMap::new();
245 meta.insert("gradient_norm".to_string(), gradient_norm.to_string());
246 meta
247 },
248 });
249 }
250
251 Ok(())
252 }
253
254 pub fn check_gradient_vanishing(&mut self, gradient_norm: f64, location: &str) -> Result<()> {
256 if !self.config.enable_gradient_vanishing {
257 return Ok(());
258 }
259
260 let vanishing_threshold = 1e-8;
261 if gradient_norm < vanishing_threshold {
262 self.report_anomaly(Anomaly {
263 anomaly_type: AnomalyType::GradientVanishing,
264 timestamp: Utc::now(),
265 location: location.to_string(),
266 description: format!("Vanishing gradient detected: norm = {}", gradient_norm),
267 severity: AnomalySeverity::High,
268 metadata: {
269 let mut meta = HashMap::new();
270 meta.insert("gradient_norm".to_string(), gradient_norm.to_string());
271 meta.insert("threshold".to_string(), vanishing_threshold.to_string());
272 meta
273 },
274 });
275 }
276
277 Ok(())
278 }
279
280 pub fn check_numerical_instability(&mut self, values: &[f32], location: &str) -> Result<()> {
282 let mut metadata = HashMap::new();
283
284 let near_zero_count = values.iter().filter(|&&v| v.abs() < 1e-10 && v != 0.0).count();
286 if near_zero_count > values.len() / 10 {
287 metadata.insert("near_zero_count".to_string(), near_zero_count.to_string());
288 metadata.insert("total_values".to_string(), values.len().to_string());
289
290 self.report_anomaly(Anomaly {
291 anomaly_type: AnomalyType::UnusualActivation,
292 timestamp: Utc::now(),
293 location: location.to_string(),
294 description: format!(
295 "Numerical instability: {} values near zero",
296 near_zero_count
297 ),
298 severity: AnomalySeverity::Medium,
299 metadata: metadata.clone(),
300 });
301 }
302
303 let extreme_count = values.iter().filter(|&&v| v.abs() > 1e6).count();
305 if extreme_count > 0 {
306 metadata.insert("extreme_count".to_string(), extreme_count.to_string());
307
308 self.report_anomaly(Anomaly {
309 anomaly_type: AnomalyType::UnusualActivation,
310 timestamp: Utc::now(),
311 location: location.to_string(),
312 description: format!("Numerical instability: {} extreme values", extreme_count),
313 severity: AnomalySeverity::High,
314 metadata,
315 });
316 }
317
318 Ok(())
319 }
320
321 pub fn check_activation_saturation(
323 &mut self,
324 activations: &[f32],
325 activation_type: &str,
326 location: &str,
327 ) -> Result<()> {
328 let saturation_threshold = match activation_type.to_lowercase().as_str() {
329 "sigmoid" | "tanh" => 0.01, "relu" => 0.0, _ => 0.01,
332 };
333
334 let saturated_count = match activation_type.to_lowercase().as_str() {
335 "sigmoid" => activations
336 .iter()
337 .filter(|&&v| v < saturation_threshold || v > 1.0 - saturation_threshold)
338 .count(),
339 "tanh" => activations.iter().filter(|&&v| v.abs() > 1.0 - saturation_threshold).count(),
340 "relu" => activations.iter().filter(|&&v| v == 0.0).count(),
341 _ => activations.iter().filter(|&&v| v.abs() < saturation_threshold).count(),
342 };
343
344 let saturation_ratio = saturated_count as f32 / activations.len() as f32;
345
346 if saturation_ratio > 0.9 {
347 let mut metadata = HashMap::new();
348 metadata.insert("activation_type".to_string(), activation_type.to_string());
349 metadata.insert("saturated_count".to_string(), saturated_count.to_string());
350 metadata.insert("total_count".to_string(), activations.len().to_string());
351 metadata.insert("saturation_ratio".to_string(), saturation_ratio.to_string());
352
353 self.report_anomaly(Anomaly {
354 anomaly_type: AnomalyType::UnusualActivation,
355 timestamp: Utc::now(),
356 location: location.to_string(),
357 description: format!(
358 "Activation saturation detected: {:.1}% of {} activations saturated",
359 saturation_ratio * 100.0,
360 activation_type
361 ),
362 severity: AnomalySeverity::High,
363 metadata,
364 });
365 }
366
367 Ok(())
368 }
369
370 pub fn check_memory_leak(
372 &mut self,
373 current_memory_mb: usize,
374 expected_memory_mb: Option<usize>,
375 location: &str,
376 ) -> Result<()> {
377 if !self.config.enable_memory_leak_detection {
378 return Ok(());
379 }
380
381 let mut should_report = false;
382 let mut description = String::new();
383 let mut metadata = HashMap::new();
384
385 metadata.insert(
386 "current_memory_mb".to_string(),
387 current_memory_mb.to_string(),
388 );
389
390 if let Some(expected) = expected_memory_mb {
391 metadata.insert("expected_memory_mb".to_string(), expected.to_string());
392
393 let growth_ratio = current_memory_mb as f64 / expected as f64;
394 if growth_ratio > 2.0 {
395 should_report = true;
396 description = format!(
397 "Memory usage {}MB is {:.1}x expected {}MB",
398 current_memory_mb, growth_ratio, expected
399 );
400 metadata.insert("growth_ratio".to_string(), growth_ratio.to_string());
401 }
402 } else {
403 if current_memory_mb > 8192 {
405 should_report = true;
407 description = format!("High memory usage detected: {}MB", current_memory_mb);
408 }
409 }
410
411 if should_report {
412 self.report_anomaly(Anomaly {
413 anomaly_type: AnomalyType::MemoryLeak,
414 timestamp: Utc::now(),
415 location: location.to_string(),
416 description,
417 severity: if current_memory_mb > 16384 {
418 AnomalySeverity::Critical
419 } else {
420 AnomalySeverity::High
421 },
422 metadata,
423 });
424 }
425
426 Ok(())
427 }
428
429 pub fn check_weight_explosion(&mut self, weights: &[f32], layer_name: &str) -> Result<()> {
431 let weight_threshold = 10.0;
432 let extreme_weights: Vec<f32> =
433 weights.iter().filter(|&&w| w.abs() > weight_threshold).cloned().collect();
434
435 if !extreme_weights.is_empty() {
436 let mut metadata = HashMap::new();
437 metadata.insert("layer_name".to_string(), layer_name.to_string());
438 metadata.insert(
439 "extreme_weight_count".to_string(),
440 extreme_weights.len().to_string(),
441 );
442 metadata.insert("total_weight_count".to_string(), weights.len().to_string());
443 metadata.insert(
444 "max_weight".to_string(),
445 extreme_weights.iter().map(|w| w.abs()).fold(0.0f32, f32::max).to_string(),
446 );
447
448 self.report_anomaly(Anomaly {
449 anomaly_type: AnomalyType::UnusualActivation,
450 timestamp: Utc::now(),
451 location: layer_name.to_string(),
452 description: format!(
453 "Weight explosion in {}: {} weights > {}",
454 layer_name,
455 extreme_weights.len(),
456 weight_threshold
457 ),
458 severity: AnomalySeverity::High,
459 metadata,
460 });
461 }
462
463 Ok(())
464 }
465
466 fn report_anomaly(&mut self, anomaly: Anomaly) {
468 eprintln!(
469 "🚨 Anomaly detected: {} at {}",
470 anomaly.description, anomaly.location
471 );
472
473 self.monitoring_stats.total_anomalies += 1;
475 let anomaly_type_key = format!("{:?}", anomaly.anomaly_type);
476 *self.monitoring_stats.anomalies_per_type.entry(anomaly_type_key).or_insert(0) += 1;
477
478 self.detected_anomalies.push(anomaly);
479 }
480
481 pub fn get_anomalies(&self) -> &[Anomaly] {
483 &self.detected_anomalies
484 }
485
    /// Discards the held anomaly records.
    ///
    /// Only the record list is emptied; the cumulative counters in the monitoring
    /// statistics are not touched.
    pub fn clear_anomalies(&mut self) {
        self.detected_anomalies.clear();
    }
490
491 pub fn check_gradient_conflict(
493 &mut self,
494 layer_gradients: &HashMap<String, Vec<f32>>,
495 ) -> Result<()> {
496 if !self.config.enable_gradient_conflict_detection {
497 return Ok(());
498 }
499
500 let layer_names: Vec<_> = layer_gradients.keys().cloned().collect();
501
502 for i in 0..layer_names.len() {
503 for j in i + 1..layer_names.len() {
504 let layer1 = &layer_names[i];
505 let layer2 = &layer_names[j];
506
507 if let (Some(grad1), Some(grad2)) =
508 (layer_gradients.get(layer1), layer_gradients.get(layer2))
509 {
510 let conflict_score = self.compute_gradient_conflict(grad1, grad2);
511
512 if conflict_score > 0.8 {
513 let mut metadata = HashMap::new();
514 metadata.insert("layer1".to_string(), layer1.clone());
515 metadata.insert("layer2".to_string(), layer2.clone());
516 metadata.insert("conflict_score".to_string(), conflict_score.to_string());
517
518 self.report_anomaly(Anomaly {
519 anomaly_type: AnomalyType::GradientConflict,
520 timestamp: Utc::now(),
521 location: format!("{}↔{}", layer1, layer2),
522 description: format!(
523 "Gradient conflict detected between {} and {} (score: {:.2})",
524 layer1, layer2, conflict_score
525 ),
526 severity: AnomalySeverity::High,
527 metadata,
528 });
529 }
530 }
531 }
532 }
533
534 Ok(())
535 }
536
537 pub fn check_weight_divergence(
539 &mut self,
540 layer_name: &str,
541 current_weights: &[f32],
542 ) -> Result<()> {
543 if !self.config.enable_weight_divergence_detection {
544 return Ok(());
545 }
546
547 if !self.weight_baseline.contains_key(layer_name) {
549 self.weight_baseline.insert(layer_name.to_string(), current_weights.to_vec());
550 return Ok(());
551 }
552
553 let baseline = self.weight_baseline.get(layer_name).unwrap();
554 if baseline.len() != current_weights.len() {
555 return Ok(()); }
557
558 let divergence = self.compute_weight_divergence(baseline, current_weights);
559
560 if divergence > self.config.weight_divergence_threshold {
561 let mut metadata = HashMap::new();
562 metadata.insert("layer_name".to_string(), layer_name.to_string());
563 metadata.insert("divergence_score".to_string(), divergence.to_string());
564 metadata.insert(
565 "threshold".to_string(),
566 self.config.weight_divergence_threshold.to_string(),
567 );
568
569 self.report_anomaly(Anomaly {
570 anomaly_type: AnomalyType::WeightDivergence,
571 timestamp: Utc::now(),
572 location: layer_name.to_string(),
573 description: format!(
574 "Weight divergence in {}: {:.2} (threshold: {:.2})",
575 layer_name, divergence, self.config.weight_divergence_threshold
576 ),
577 severity: if divergence > self.config.weight_divergence_threshold * 2.0 {
578 AnomalySeverity::Critical
579 } else {
580 AnomalySeverity::High
581 },
582 metadata,
583 });
584 }
585
586 Ok(())
587 }
588
589 pub fn check_performance_degradation(
591 &mut self,
592 current_performance: f64,
593 location: &str,
594 ) -> Result<()> {
595 if !self.config.enable_performance_monitoring {
596 return Ok(());
597 }
598
599 if self.performance_history.len() >= self.config.monitoring_window_size {
601 self.performance_history.pop_front();
602 }
603 self.performance_history.push_back(current_performance);
604
605 if self.performance_history.len() >= 10 {
607 let recent_avg = self.performance_history.iter().rev().take(5).sum::<f64>() / 5.0;
608 let baseline_avg = self.performance_history.iter().take(5).sum::<f64>() / 5.0;
609
610 let degradation_ratio = (baseline_avg - recent_avg) / baseline_avg;
611
612 if degradation_ratio > self.config.performance_degradation_threshold {
613 let mut metadata = HashMap::new();
614 metadata.insert("baseline_performance".to_string(), baseline_avg.to_string());
615 metadata.insert("current_performance".to_string(), recent_avg.to_string());
616 metadata.insert(
617 "degradation_ratio".to_string(),
618 degradation_ratio.to_string(),
619 );
620
621 self.report_anomaly(Anomaly {
622 anomaly_type: AnomalyType::PerformanceDegradation,
623 timestamp: Utc::now(),
624 location: location.to_string(),
625 description: format!(
626 "Performance degradation detected: {:.1}% drop from baseline",
627 degradation_ratio * 100.0
628 ),
629 severity: if degradation_ratio > 0.8 {
630 AnomalySeverity::Critical
631 } else {
632 AnomalySeverity::High
633 },
634 metadata,
635 });
636 }
637 }
638
639 Ok(())
640 }
641
642 pub fn check_loss_anomaly(&mut self, current_loss: f64, location: &str) -> Result<()> {
644 if !self.config.enable_loss_anomaly_detection {
645 return Ok(());
646 }
647
648 if self.loss_history.len() >= self.config.monitoring_window_size {
650 self.loss_history.pop_front();
651 }
652 self.loss_history.push_back(current_loss);
653
654 if self.loss_history.len() >= 3 {
656 let prev_loss = self.loss_history[self.loss_history.len() - 2];
657 let loss_ratio = current_loss / prev_loss;
658
659 if loss_ratio > self.config.loss_spike_threshold {
660 let mut metadata = HashMap::new();
661 metadata.insert("previous_loss".to_string(), prev_loss.to_string());
662 metadata.insert("current_loss".to_string(), current_loss.to_string());
663 metadata.insert("spike_ratio".to_string(), loss_ratio.to_string());
664
665 self.report_anomaly(Anomaly {
666 anomaly_type: AnomalyType::LossAnomalous,
667 timestamp: Utc::now(),
668 location: location.to_string(),
669 description: format!(
670 "Loss spike detected: {:.2}x increase (from {:.6} to {:.6})",
671 loss_ratio, prev_loss, current_loss
672 ),
673 severity: if loss_ratio > 100.0 {
674 AnomalySeverity::Critical
675 } else {
676 AnomalySeverity::High
677 },
678 metadata,
679 });
680 }
681 }
682
683 Ok(())
684 }
685
686 pub async fn attempt_recovery(&mut self, anomaly: &Anomaly) -> Result<RecoveryAction> {
688 if !self.config.enable_auto_recovery {
689 return Ok(RecoveryAction::None);
690 }
691
692 let action = self.determine_recovery_action(anomaly);
693 let anomaly_id = format!(
694 "{:?}_{}",
695 anomaly.anomaly_type,
696 anomaly.timestamp.timestamp()
697 );
698
699 let success = self.execute_recovery_action(&action).await?;
700
701 self.recovery_attempts.push(RecoveryAttempt {
702 anomaly_id: anomaly_id.clone(),
703 action: action.clone(),
704 timestamp: Utc::now(),
705 success,
706 error_message: if success { None } else { Some("Recovery failed".to_string()) },
707 });
708
709 self.monitoring_stats.recovery_attempts += 1;
710 if success {
711 self.monitoring_stats.successful_recoveries += 1;
712 }
713
714 Ok(action)
715 }
716
    /// Returns the cumulative counters and the snapshot window.
    pub fn get_monitoring_stats(&self) -> &MonitoringStats {
        &self.monitoring_stats
    }
721
722 pub fn get_recovery_attempts(&self) -> &[RecoveryAttempt] {
724 &self.recovery_attempts
725 }
726
727 pub fn update_monitoring_window(&mut self) -> Result<()> {
729 let mut severity_distribution = HashMap::new();
730 for anomaly in &self.detected_anomalies {
731 let key = format!("{:?}", anomaly.severity);
732 *severity_distribution.entry(key).or_insert(0) += 1;
733 }
734
735 let mut performance_metrics = HashMap::new();
736 if let Some(latest_perf) = self.performance_history.back() {
737 performance_metrics.insert("latest_performance".to_string(), *latest_perf);
738 }
739 if let Some(latest_loss) = self.loss_history.back() {
740 performance_metrics.insert("latest_loss".to_string(), *latest_loss);
741 }
742
743 let snapshot = AnomalySnapshot {
744 timestamp: Utc::now(),
745 anomaly_count: self.detected_anomalies.len(),
746 severity_distribution,
747 performance_metrics,
748 };
749
750 self.monitoring_stats.monitoring_window.push(snapshot);
751
752 if self.monitoring_stats.monitoring_window.len() > self.config.monitoring_window_size {
754 self.monitoring_stats.monitoring_window.remove(0);
755 }
756
757 Ok(())
758 }
759
760 fn compute_gradient_conflict(&self, grad1: &[f32], grad2: &[f32]) -> f64 {
763 if grad1.len() != grad2.len() {
764 return 0.0;
765 }
766
767 let dot_product: f64 =
768 grad1.iter().zip(grad2.iter()).map(|(a, b)| (*a as f64) * (*b as f64)).sum();
769
770 let norm1: f64 = grad1.iter().map(|x| (*x as f64).powi(2)).sum::<f64>().sqrt();
771 let norm2: f64 = grad2.iter().map(|x| (*x as f64).powi(2)).sum::<f64>().sqrt();
772
773 if norm1 == 0.0 || norm2 == 0.0 {
774 return 0.0;
775 }
776
777 let cosine_sim = dot_product / (norm1 * norm2);
779
780 (1.0 - cosine_sim) / 2.0
782 }
783
784 fn compute_weight_divergence(&self, baseline: &[f32], current: &[f32]) -> f64 {
785 let mse: f64 = baseline
786 .iter()
787 .zip(current.iter())
788 .map(|(a, b)| (*a as f64 - *b as f64).powi(2))
789 .sum::<f64>()
790 / baseline.len() as f64;
791
792 mse.sqrt()
793 }
794
795 fn determine_recovery_action(&self, anomaly: &Anomaly) -> RecoveryAction {
796 match anomaly.anomaly_type {
797 AnomalyType::GradientExplosion => RecoveryAction::ClipGradients { max_norm: 1.0 },
798 AnomalyType::GradientVanishing => RecoveryAction::ReduceLearningRate { factor: 0.5 },
799 AnomalyType::NaN | AnomalyType::Infinity => RecoveryAction::ResetGradients,
800 AnomalyType::WeightDivergence => RecoveryAction::ApplyWeightDecay { rate: 0.01 },
801 AnomalyType::LossAnomalous => RecoveryAction::SkipBatch,
802 AnomalyType::MemoryLeak => RecoveryAction::RestartOptimizer,
803 AnomalyType::PerformanceDegradation => {
804 RecoveryAction::ReduceLearningRate { factor: 0.8 }
805 },
806 _ => RecoveryAction::None,
807 }
808 }
809
810 async fn execute_recovery_action(&self, action: &RecoveryAction) -> Result<bool> {
811 match action {
814 RecoveryAction::None => Ok(true),
815 RecoveryAction::ResetGradients => {
816 tracing::info!("Executing recovery: Reset gradients");
817 Ok(true)
818 },
819 RecoveryAction::ReduceLearningRate { factor } => {
820 tracing::info!(
821 "Executing recovery: Reduce learning rate by factor {}",
822 factor
823 );
824 Ok(true)
825 },
826 RecoveryAction::ClipGradients { max_norm } => {
827 tracing::info!(
828 "Executing recovery: Clip gradients to max norm {}",
829 max_norm
830 );
831 Ok(true)
832 },
833 RecoveryAction::RestartOptimizer => {
834 tracing::info!("Executing recovery: Restart optimizer");
835 Ok(true)
836 },
837 RecoveryAction::SkipBatch => {
838 tracing::info!("Executing recovery: Skip current batch");
839 Ok(true)
840 },
841 RecoveryAction::ResetWeights { layer_name } => {
842 tracing::info!("Executing recovery: Reset weights for layer {}", layer_name);
843 Ok(true)
844 },
845 RecoveryAction::ApplyWeightDecay { rate } => {
846 tracing::info!("Executing recovery: Apply weight decay with rate {}", rate);
847 Ok(true)
848 },
849 RecoveryAction::EmergencyStop => {
850 tracing::warn!("Executing recovery: Emergency stop");
851 Ok(false) },
853 }
854 }
855
856 pub async fn quick_check(&self) -> Result<crate::QuickAnomalySummary> {
858 let anomaly_count = self.detected_anomalies.len();
859
860 let severity_level = match anomaly_count {
861 0 => "None",
862 1..=3 => "Low",
863 4..=10 => "Medium",
864 11..=20 => "High",
865 _ => "Critical",
866 }
867 .to_string();
868
869 let mut recommendations = Vec::new();
870 if anomaly_count > 0 {
871 recommendations.push("Review recent training metrics for instabilities".to_string());
872 }
873 if anomaly_count > 5 {
874 recommendations.push(
875 "Consider adjusting learning rate or implementing gradient clipping".to_string(),
876 );
877 }
878 if anomaly_count > 15 {
879 recommendations
880 .push("Training may need to be restarted with better configuration".to_string());
881 }
882 if anomaly_count == 0 {
883 recommendations.push("No anomalies detected, training appears stable".to_string());
884 }
885
886 Ok(crate::QuickAnomalySummary {
887 anomaly_count,
888 severity_level,
889 recommendations,
890 })
891 }
892
893 pub async fn generate_report(&self) -> Result<AnomalyDetectorReport> {
895 let mut anomaly_counts = HashMap::new();
896 for anomaly in &self.detected_anomalies {
897 let count = anomaly_counts.entry(format!("{:?}", anomaly.anomaly_type)).or_insert(0);
898 *count += 1;
899 }
900
901 Ok(AnomalyDetectorReport {
902 session_duration: Utc::now().signed_duration_since(self.start_time),
903 total_anomalies: self.detected_anomalies.len(),
904 anomaly_counts,
905 most_recent_anomalies: self.detected_anomalies.iter().rev().take(10).cloned().collect(),
906 config: self.config.clone(),
907 })
908 }
909}
910
#[derive(Debug, Clone, Serialize, Deserialize)]
/// Summary produced by `AnomalyDetector::generate_report`.
pub struct AnomalyDetectorReport {
    /// Time elapsed between the session start and report generation.
    pub session_duration: chrono::Duration,
    /// Number of anomalies held by the detector when the report was built.
    pub total_anomalies: usize,
    /// Count of held anomalies per anomaly-type `Debug` name.
    pub anomaly_counts: HashMap<String, usize>,
    /// Up to ten anomalies, newest first.
    pub most_recent_anomalies: Vec<Anomaly>,
    /// Copy of the detector configuration in effect.
    pub config: AnomalyDetectorConfig,
}
920
#[cfg(test)]
mod tests {
    use super::*;

    /// Detector built from the default debug configuration; every test starts from one.
    fn fresh_detector() -> AnomalyDetector {
        AnomalyDetector::new(&DebugConfig::default())
    }

    /// Asserts that exactly one anomaly is held and returns its type.
    fn sole_anomaly_type(detector: &AnomalyDetector) -> &AnomalyType {
        let held = detector.get_anomalies();
        assert_eq!(held.len(), 1);
        &held[0].anomaly_type
    }

    #[test]
    fn test_anomaly_detector_creation() {
        let detector = fresh_detector();
        assert_eq!(detector.get_anomalies().len(), 0);
    }

    #[test]
    fn test_nan_detection() {
        let mut detector = fresh_detector();

        detector.check_nan(&[1.0, 2.0, f32::NAN, 4.0], "test_location").unwrap();

        assert!(matches!(sole_anomaly_type(&detector), AnomalyType::NaN));
    }

    #[test]
    fn test_inf_detection() {
        let mut detector = fresh_detector();

        detector.check_inf(&[1.0, 2.0, f32::INFINITY, 4.0], "test_location").unwrap();

        assert!(matches!(sole_anomaly_type(&detector), AnomalyType::Infinity));
    }

    #[test]
    fn test_gradient_explosion_detection() {
        let mut detector = fresh_detector();

        // 1e7 is above the default threshold of 1e6.
        detector.check_gradient_explosion(1e7, "test_layer").unwrap();

        assert!(matches!(
            sole_anomaly_type(&detector),
            AnomalyType::GradientExplosion
        ));
    }

    #[test]
    fn test_gradient_vanishing_detection() {
        let mut detector = fresh_detector();

        // 1e-10 is below the fixed 1e-8 cutoff.
        detector.check_gradient_vanishing(1e-10, "test_layer").unwrap();

        assert!(matches!(
            sole_anomaly_type(&detector),
            AnomalyType::GradientVanishing
        ));
    }

    #[test]
    fn test_numerical_instability_detection() {
        let mut detector = fresh_detector();

        // Half of the values are tiny but non-zero.
        let mut near_zero_values: Vec<f32> = vec![1e-12; 50];
        near_zero_values.extend(vec![1.0; 50]);
        detector
            .check_numerical_instability(&near_zero_values, "test_location")
            .unwrap();
        assert_eq!(detector.get_anomalies().len(), 1);

        detector.clear_anomalies();

        // A single extreme value is enough.
        detector
            .check_numerical_instability(&[1.0, 2.0, 1e7, 4.0], "test_location")
            .unwrap();
        assert_eq!(detector.get_anomalies().len(), 1);
    }

    #[test]
    fn test_activation_saturation_detection() {
        let mut detector = fresh_detector();

        // All-zero ReLU outputs.
        let relu_saturated = [0.0f32; 100];
        detector
            .check_activation_saturation(&relu_saturated, "relu", "test_layer")
            .unwrap();
        assert_eq!(detector.get_anomalies().len(), 1);

        detector.clear_anomalies();

        // Sigmoid outputs pinned near 1.
        let sigmoid_saturated = [0.999f32; 100];
        detector
            .check_activation_saturation(&sigmoid_saturated, "sigmoid", "test_layer")
            .unwrap();
        assert_eq!(detector.get_anomalies().len(), 1);
    }

    #[test]
    fn test_memory_leak_detection() {
        let mut detector = fresh_detector();

        // Three times the expected usage.
        detector.check_memory_leak(3072, Some(1024), "test_location").unwrap();
        assert!(matches!(sole_anomaly_type(&detector), AnomalyType::MemoryLeak));

        detector.clear_anomalies();

        // No expectation given, but above the absolute 8192 MB limit.
        detector.check_memory_leak(10240, None, "test_location").unwrap();
        assert_eq!(detector.get_anomalies().len(), 1);
    }

    #[test]
    fn test_weight_explosion_detection() {
        let mut detector = fresh_detector();

        // Two weights exceed magnitude 10, yet only one anomaly is expected.
        detector
            .check_weight_explosion(&[1.0, 2.0, 15.0, 4.0, -20.0], "test_layer")
            .unwrap();

        assert!(matches!(
            sole_anomaly_type(&detector),
            AnomalyType::UnusualActivation
        ));
    }

    #[test]
    fn test_gradient_conflict_detection() {
        let mut detector = fresh_detector();

        // Exactly opposed gradients.
        let mut layer_gradients: HashMap<String, Vec<f32>> = HashMap::new();
        layer_gradients.insert("layer1".to_string(), vec![1.0, 0.0, 0.0]);
        layer_gradients.insert("layer2".to_string(), vec![-1.0, 0.0, 0.0]);
        detector.check_gradient_conflict(&layer_gradients).unwrap();

        assert!(matches!(
            sole_anomaly_type(&detector),
            AnomalyType::GradientConflict
        ));
    }

    #[test]
    fn test_weight_divergence_detection() {
        let mut detector = fresh_detector();

        // The first call only establishes the baseline.
        detector
            .check_weight_divergence("test_layer", &[1.0, 2.0, 3.0, 4.0])
            .unwrap();
        assert_eq!(detector.get_anomalies().len(), 0);

        // Ten-fold weights are well past the default threshold.
        detector
            .check_weight_divergence("test_layer", &[10.0, 20.0, 30.0, 40.0])
            .unwrap();
        assert!(matches!(
            sole_anomaly_type(&detector),
            AnomalyType::WeightDivergence
        ));
    }

    #[test]
    fn test_performance_degradation_detection() {
        let mut detector = fresh_detector();

        // Stable baseline.
        (0..10).for_each(|_| detector.check_performance_degradation(100.0, "training").unwrap());
        assert_eq!(detector.get_anomalies().len(), 0);

        // Sharp drop.
        (0..5).for_each(|_| detector.check_performance_degradation(20.0, "training").unwrap());

        let held = detector.get_anomalies();
        assert!(!held.is_empty());
        assert!(held
            .iter()
            .any(|a| matches!(a.anomaly_type, AnomalyType::PerformanceDegradation)));
    }

    #[test]
    fn test_loss_anomaly_detection() {
        let mut detector = fresh_detector();

        // Two ordinary samples first.
        for loss in [1.0, 0.9].iter().copied() {
            detector.check_loss_anomaly(loss, "training").unwrap();
        }
        assert_eq!(detector.get_anomalies().len(), 0);

        // Then a spike of more than 100x.
        detector.check_loss_anomaly(100.0, "training").unwrap();
        assert!(matches!(
            sole_anomaly_type(&detector),
            AnomalyType::LossAnomalous
        ));
    }

    #[tokio::test]
    async fn test_auto_recovery() {
        let mut detector = fresh_detector();
        detector.config.enable_auto_recovery = true;

        let anomaly = Anomaly {
            anomaly_type: AnomalyType::GradientExplosion,
            timestamp: Utc::now(),
            location: "test_layer".to_string(),
            description: "Test gradient explosion".to_string(),
            severity: AnomalySeverity::High,
            metadata: HashMap::new(),
        };

        let action = detector.attempt_recovery(&anomaly).await.unwrap();
        assert!(matches!(action, RecoveryAction::ClipGradients { .. }));
        assert_eq!(detector.get_recovery_attempts().len(), 1);
    }

    #[test]
    fn test_monitoring_stats() {
        let mut detector = fresh_detector();

        detector.check_nan(&[f32::NAN], "test").unwrap();
        detector.check_inf(&[f32::INFINITY], "test").unwrap();

        let stats = detector.get_monitoring_stats();
        assert_eq!(stats.total_anomalies, 2);
        for key in ["NaN", "Infinity"].iter() {
            assert!(stats.anomalies_per_type.contains_key(*key));
        }
    }

    #[test]
    fn test_monitoring_window_update() {
        let mut detector = fresh_detector();

        detector.check_nan(&[f32::NAN], "test").unwrap();
        detector.update_monitoring_window().unwrap();

        let window = &detector.get_monitoring_stats().monitoring_window;
        assert_eq!(window.len(), 1);
        assert_eq!(window[0].anomaly_count, 1);
    }
}