1use anyhow::Result;
7use chrono::{DateTime, Utc};
8use serde::{Deserialize, Serialize};
9use std::collections::{HashMap, VecDeque};
10
11use crate::DebugConfig;
12
/// Stateful detector for training-time anomalies (NaN/Inf tensors, gradient
/// explosion/vanishing, memory growth, weight drift, loss spikes, …).
///
/// Findings accumulate in `detected_anomalies`; aggregate counters live in
/// `monitoring_stats`, and optional auto-recovery attempts are logged in
/// `recovery_attempts`.
#[derive(Debug)]
pub struct AnomalyDetector {
    // Feature toggles and thresholds governing which checks run.
    config: AnomalyDetectorConfig,
    // Every anomaly reported since construction / the last `clear_anomalies`.
    detected_anomalies: Vec<Anomaly>,
    // Session start; reset by `start()` and used for report durations.
    start_time: DateTime<Utc>,
    // Log of auto-recovery attempts (see `attempt_recovery`).
    recovery_attempts: Vec<RecoveryAttempt>,
    // Aggregate counters and periodic snapshots.
    monitoring_stats: MonitoringStats,
    // Sliding window of recent performance readings (bounded by
    // `config.monitoring_window_size`).
    performance_history: VecDeque<f64>,
    // Reserved for per-layer gradient tracking; not consulted yet.
    #[allow(dead_code)]
    gradient_history: HashMap<String, VecDeque<f64>>,
    // Sliding window of recent loss values (same bound as above).
    loss_history: VecDeque<f64>,
    // First weight snapshot seen per layer, used as the divergence baseline.
    weight_baseline: HashMap<String, Vec<f32>>,
}
27
/// Tunable switches and thresholds for [`AnomalyDetector`].
///
/// Each `enable_*` flag turns one check on or off; the numeric fields are the
/// thresholds those checks compare against. See `Default` for baseline values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AnomalyDetectorConfig {
    /// Report NaN values found in tensors.
    pub enable_nan_detection: bool,
    /// Report infinite values found in tensors.
    pub enable_inf_detection: bool,
    /// Report gradient norms above `gradient_threshold`.
    pub enable_gradient_explosion: bool,
    /// Report vanishing gradient norms.
    pub enable_gradient_vanishing: bool,
    /// Norm above which a gradient counts as exploding.
    pub gradient_threshold: f64,
    /// Report suspicious memory growth.
    pub enable_memory_leak_detection: bool,
    /// Report numerical-instability symptoms (underflow/overflow patterns).
    pub enable_numerical_instability_detection: bool,
    /// Report strongly opposing gradients between layers.
    pub enable_gradient_conflict_detection: bool,
    /// Track performance readings and report degradation.
    pub enable_performance_monitoring: bool,
    /// Report weight drift away from the first recorded baseline.
    pub enable_weight_divergence_detection: bool,
    /// Report dead activations. NOTE(review): no check currently reads this
    /// flag — confirm whether it is wired up elsewhere.
    pub enable_activation_dead_detection: bool,
    /// Report sudden loss spikes.
    pub enable_loss_anomaly_detection: bool,
    /// Allow `attempt_recovery` to take corrective actions.
    pub enable_auto_recovery: bool,
    /// Magnitude below which values are considered unstable.
    pub numerical_instability_threshold: f64,
    /// Relative performance drop that counts as degradation (0.5 = 50%).
    pub performance_degradation_threshold: f64,
    /// RMS weight drift that counts as divergence.
    pub weight_divergence_threshold: f64,
    /// Loss ratio (current / previous) that counts as a spike.
    pub loss_spike_threshold: f64,
    /// Bound on the sliding histories and snapshot window.
    pub monitoring_window_size: usize,
    /// Maximum number of auto-recovery attempts per session.
    pub recovery_attempts_limit: usize,
}
51
52impl Default for AnomalyDetectorConfig {
53 fn default() -> Self {
54 Self {
55 enable_nan_detection: true,
56 enable_inf_detection: true,
57 enable_gradient_explosion: true,
58 enable_gradient_vanishing: true,
59 gradient_threshold: 1e6,
60 enable_memory_leak_detection: true,
61 enable_numerical_instability_detection: true,
62 enable_gradient_conflict_detection: true,
63 enable_performance_monitoring: true,
64 enable_weight_divergence_detection: true,
65 enable_activation_dead_detection: true,
66 enable_loss_anomaly_detection: true,
67 enable_auto_recovery: false, numerical_instability_threshold: 1e-12,
69 performance_degradation_threshold: 0.5, weight_divergence_threshold: 5.0,
71 loss_spike_threshold: 10.0, monitoring_window_size: 100,
73 recovery_attempts_limit: 3,
74 }
75 }
76}
77
/// Category of a detected training anomaly.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AnomalyType {
    /// NaN values found in a tensor.
    NaN,
    /// Infinite values found in a tensor.
    Infinity,
    /// Gradient norm above the configured threshold.
    GradientExplosion,
    /// Gradient norm effectively zero.
    GradientVanishing,
    /// Suspicious memory growth.
    MemoryLeak,
    /// Activation/weight statistics outside expected ranges.
    UnusualActivation,
    /// Underflow/overflow-style numeric symptoms.
    NumericalInstability,
    /// Strongly opposing gradients between two layers.
    GradientConflict,
    /// Sustained drop in a monitored performance metric.
    PerformanceDegradation,
    /// Weights drifted far from their recorded baseline.
    WeightDivergence,
    /// Activations stuck at zero (dead units).
    ActivationDead,
    /// Sudden spike in the training loss.
    LossAnomalous,
}
94
/// A single detected anomaly: what, when, where, how bad, plus free-form
/// key/value details in `metadata` (thresholds, measured values, …).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Anomaly {
    /// Category of the anomaly.
    pub anomaly_type: AnomalyType,
    /// Detection time (UTC).
    pub timestamp: DateTime<Utc>,
    /// Where it was observed (layer name, phase label, …).
    pub location: String,
    /// Human-readable summary.
    pub description: String,
    /// How serious the finding is.
    pub severity: AnomalySeverity,
    /// Check-specific details as stringified key/value pairs.
    pub metadata: HashMap<String, String>,
}
105
/// Severity grade attached to an [`Anomaly`], from benign to run-threatening.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AnomalySeverity {
    Low,
    Medium,
    High,
    Critical,
}
114
/// Corrective action the detector may recommend/execute for an anomaly.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum RecoveryAction {
    /// Take no action.
    None,
    /// Zero out accumulated gradients.
    ResetGradients,
    /// Multiply the learning rate by `factor` (< 1.0 slows training down).
    ReduceLearningRate { factor: f64 },
    /// Clip gradient norm to `max_norm`.
    ClipGradients { max_norm: f64 },
    /// Re-initialize optimizer state.
    RestartOptimizer,
    /// Drop the current batch and continue.
    SkipBatch,
    /// Re-initialize the weights of one layer.
    ResetWeights { layer_name: String },
    /// Apply L2-style decay at `rate` to pull weights back.
    ApplyWeightDecay { rate: f64 },
    /// Halt training entirely.
    EmergencyStop,
}
128
/// Record of one auto-recovery attempt, linking an anomaly to the action
/// taken and its outcome.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RecoveryAttempt {
    /// Identifier derived from the anomaly type and timestamp.
    pub anomaly_id: String,
    /// Action that was executed.
    pub action: RecoveryAction,
    /// When the attempt was made (UTC).
    pub timestamp: DateTime<Utc>,
    /// Whether the action reported success.
    pub success: bool,
    /// Failure description when `success` is false.
    pub error_message: Option<String>,
}
138
/// Aggregate counters and periodic snapshots accumulated over a session.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MonitoringStats {
    /// Total anomalies reported (never reset by `clear_anomalies`).
    pub total_anomalies: usize,
    /// Count per `AnomalyType`, keyed by the type's Debug name.
    pub anomalies_per_type: HashMap<String, usize>,
    /// Number of recovery attempts made.
    pub recovery_attempts: usize,
    /// Number of recovery attempts that reported success.
    pub successful_recoveries: usize,
    /// Mean detection latency. NOTE(review): never updated by the visible
    /// code — confirm whether a caller maintains it.
    pub average_detection_time_ms: f64,
    /// Bounded history of snapshots from `update_monitoring_window`.
    pub monitoring_window: Vec<AnomalySnapshot>,
}
149
/// Point-in-time view of the detector state, captured by
/// `update_monitoring_window`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AnomalySnapshot {
    /// Capture time (UTC).
    pub timestamp: DateTime<Utc>,
    /// Anomalies recorded at capture time.
    pub anomaly_count: usize,
    /// Count per severity, keyed by the severity's Debug name.
    pub severity_distribution: HashMap<String, usize>,
    /// Latest performance/loss readings, when available.
    pub performance_metrics: HashMap<String, f64>,
}
158
159impl AnomalyDetector {
160 pub fn new(_config: &DebugConfig) -> Self {
162 let monitoring_window_size = AnomalyDetectorConfig::default().monitoring_window_size;
163 Self {
164 config: AnomalyDetectorConfig::default(),
165 detected_anomalies: Vec::new(),
166 start_time: Utc::now(),
167 recovery_attempts: Vec::new(),
168 monitoring_stats: MonitoringStats {
169 total_anomalies: 0,
170 anomalies_per_type: HashMap::new(),
171 recovery_attempts: 0,
172 successful_recoveries: 0,
173 average_detection_time_ms: 0.0,
174 monitoring_window: Vec::new(),
175 },
176 performance_history: VecDeque::with_capacity(monitoring_window_size),
177 gradient_history: HashMap::new(),
178 loss_history: VecDeque::with_capacity(monitoring_window_size),
179 weight_baseline: HashMap::new(),
180 }
181 }
182
183 pub async fn start(&mut self) -> Result<()> {
185 self.start_time = Utc::now();
186 self.detected_anomalies.clear();
187 Ok(())
188 }
189
190 pub fn check_nan(&mut self, values: &[f32], location: &str) -> Result<()> {
192 if !self.config.enable_nan_detection {
193 return Ok(());
194 }
195
196 if values.iter().any(|v| v.is_nan()) {
197 self.report_anomaly(Anomaly {
198 anomaly_type: AnomalyType::NaN,
199 timestamp: Utc::now(),
200 location: location.to_string(),
201 description: "NaN values detected in tensor".to_string(),
202 severity: AnomalySeverity::High,
203 metadata: HashMap::new(),
204 });
205 }
206
207 Ok(())
208 }
209
210 pub fn check_inf(&mut self, values: &[f32], location: &str) -> Result<()> {
212 if !self.config.enable_inf_detection {
213 return Ok(());
214 }
215
216 if values.iter().any(|v| v.is_infinite()) {
217 self.report_anomaly(Anomaly {
218 anomaly_type: AnomalyType::Infinity,
219 timestamp: Utc::now(),
220 location: location.to_string(),
221 description: "Infinite values detected in tensor".to_string(),
222 severity: AnomalySeverity::High,
223 metadata: HashMap::new(),
224 });
225 }
226
227 Ok(())
228 }
229
230 pub fn check_gradient_explosion(&mut self, gradient_norm: f64, location: &str) -> Result<()> {
232 if !self.config.enable_gradient_explosion {
233 return Ok(());
234 }
235
236 if gradient_norm > self.config.gradient_threshold {
237 self.report_anomaly(Anomaly {
238 anomaly_type: AnomalyType::GradientExplosion,
239 timestamp: Utc::now(),
240 location: location.to_string(),
241 description: format!("Gradient explosion detected: norm = {}", gradient_norm),
242 severity: AnomalySeverity::Critical,
243 metadata: {
244 let mut meta = HashMap::new();
245 meta.insert("gradient_norm".to_string(), gradient_norm.to_string());
246 meta
247 },
248 });
249 }
250
251 Ok(())
252 }
253
254 pub fn check_gradient_vanishing(&mut self, gradient_norm: f64, location: &str) -> Result<()> {
256 if !self.config.enable_gradient_vanishing {
257 return Ok(());
258 }
259
260 let vanishing_threshold = 1e-8;
261 if gradient_norm < vanishing_threshold {
262 self.report_anomaly(Anomaly {
263 anomaly_type: AnomalyType::GradientVanishing,
264 timestamp: Utc::now(),
265 location: location.to_string(),
266 description: format!("Vanishing gradient detected: norm = {}", gradient_norm),
267 severity: AnomalySeverity::High,
268 metadata: {
269 let mut meta = HashMap::new();
270 meta.insert("gradient_norm".to_string(), gradient_norm.to_string());
271 meta.insert("threshold".to_string(), vanishing_threshold.to_string());
272 meta
273 },
274 });
275 }
276
277 Ok(())
278 }
279
280 pub fn check_numerical_instability(&mut self, values: &[f32], location: &str) -> Result<()> {
282 let mut metadata = HashMap::new();
283
284 let near_zero_count = values.iter().filter(|&&v| v.abs() < 1e-10 && v != 0.0).count();
286 if near_zero_count > values.len() / 10 {
287 metadata.insert("near_zero_count".to_string(), near_zero_count.to_string());
288 metadata.insert("total_values".to_string(), values.len().to_string());
289
290 self.report_anomaly(Anomaly {
291 anomaly_type: AnomalyType::UnusualActivation,
292 timestamp: Utc::now(),
293 location: location.to_string(),
294 description: format!(
295 "Numerical instability: {} values near zero",
296 near_zero_count
297 ),
298 severity: AnomalySeverity::Medium,
299 metadata: metadata.clone(),
300 });
301 }
302
303 let extreme_count = values.iter().filter(|&&v| v.abs() > 1e6).count();
305 if extreme_count > 0 {
306 metadata.insert("extreme_count".to_string(), extreme_count.to_string());
307
308 self.report_anomaly(Anomaly {
309 anomaly_type: AnomalyType::UnusualActivation,
310 timestamp: Utc::now(),
311 location: location.to_string(),
312 description: format!("Numerical instability: {} extreme values", extreme_count),
313 severity: AnomalySeverity::High,
314 metadata,
315 });
316 }
317
318 Ok(())
319 }
320
321 pub fn check_activation_saturation(
323 &mut self,
324 activations: &[f32],
325 activation_type: &str,
326 location: &str,
327 ) -> Result<()> {
328 let saturation_threshold = match activation_type.to_lowercase().as_str() {
329 "sigmoid" | "tanh" => 0.01, "relu" => 0.0, _ => 0.01,
332 };
333
334 let saturated_count = match activation_type.to_lowercase().as_str() {
335 "sigmoid" => activations
336 .iter()
337 .filter(|&&v| v < saturation_threshold || v > 1.0 - saturation_threshold)
338 .count(),
339 "tanh" => activations.iter().filter(|&&v| v.abs() > 1.0 - saturation_threshold).count(),
340 "relu" => activations.iter().filter(|&&v| v == 0.0).count(),
341 _ => activations.iter().filter(|&&v| v.abs() < saturation_threshold).count(),
342 };
343
344 let saturation_ratio = saturated_count as f32 / activations.len() as f32;
345
346 if saturation_ratio > 0.9 {
347 let mut metadata = HashMap::new();
348 metadata.insert("activation_type".to_string(), activation_type.to_string());
349 metadata.insert("saturated_count".to_string(), saturated_count.to_string());
350 metadata.insert("total_count".to_string(), activations.len().to_string());
351 metadata.insert("saturation_ratio".to_string(), saturation_ratio.to_string());
352
353 self.report_anomaly(Anomaly {
354 anomaly_type: AnomalyType::UnusualActivation,
355 timestamp: Utc::now(),
356 location: location.to_string(),
357 description: format!(
358 "Activation saturation detected: {:.1}% of {} activations saturated",
359 saturation_ratio * 100.0,
360 activation_type
361 ),
362 severity: AnomalySeverity::High,
363 metadata,
364 });
365 }
366
367 Ok(())
368 }
369
370 pub fn check_memory_leak(
372 &mut self,
373 current_memory_mb: usize,
374 expected_memory_mb: Option<usize>,
375 location: &str,
376 ) -> Result<()> {
377 if !self.config.enable_memory_leak_detection {
378 return Ok(());
379 }
380
381 let mut should_report = false;
382 let mut description = String::new();
383 let mut metadata = HashMap::new();
384
385 metadata.insert(
386 "current_memory_mb".to_string(),
387 current_memory_mb.to_string(),
388 );
389
390 if let Some(expected) = expected_memory_mb {
391 metadata.insert("expected_memory_mb".to_string(), expected.to_string());
392
393 let growth_ratio = current_memory_mb as f64 / expected as f64;
394 if growth_ratio > 2.0 {
395 should_report = true;
396 description = format!(
397 "Memory usage {}MB is {:.1}x expected {}MB",
398 current_memory_mb, growth_ratio, expected
399 );
400 metadata.insert("growth_ratio".to_string(), growth_ratio.to_string());
401 }
402 } else {
403 if current_memory_mb > 8192 {
405 should_report = true;
407 description = format!("High memory usage detected: {}MB", current_memory_mb);
408 }
409 }
410
411 if should_report {
412 self.report_anomaly(Anomaly {
413 anomaly_type: AnomalyType::MemoryLeak,
414 timestamp: Utc::now(),
415 location: location.to_string(),
416 description,
417 severity: if current_memory_mb > 16384 {
418 AnomalySeverity::Critical
419 } else {
420 AnomalySeverity::High
421 },
422 metadata,
423 });
424 }
425
426 Ok(())
427 }
428
429 pub fn check_weight_explosion(&mut self, weights: &[f32], layer_name: &str) -> Result<()> {
431 let weight_threshold = 10.0;
432 let extreme_weights: Vec<f32> =
433 weights.iter().filter(|&&w| w.abs() > weight_threshold).cloned().collect();
434
435 if !extreme_weights.is_empty() {
436 let mut metadata = HashMap::new();
437 metadata.insert("layer_name".to_string(), layer_name.to_string());
438 metadata.insert(
439 "extreme_weight_count".to_string(),
440 extreme_weights.len().to_string(),
441 );
442 metadata.insert("total_weight_count".to_string(), weights.len().to_string());
443 metadata.insert(
444 "max_weight".to_string(),
445 extreme_weights.iter().map(|w| w.abs()).fold(0.0f32, f32::max).to_string(),
446 );
447
448 self.report_anomaly(Anomaly {
449 anomaly_type: AnomalyType::UnusualActivation,
450 timestamp: Utc::now(),
451 location: layer_name.to_string(),
452 description: format!(
453 "Weight explosion in {}: {} weights > {}",
454 layer_name,
455 extreme_weights.len(),
456 weight_threshold
457 ),
458 severity: AnomalySeverity::High,
459 metadata,
460 });
461 }
462
463 Ok(())
464 }
465
466 fn report_anomaly(&mut self, anomaly: Anomaly) {
468 eprintln!(
469 "🚨 Anomaly detected: {} at {}",
470 anomaly.description, anomaly.location
471 );
472
473 self.monitoring_stats.total_anomalies += 1;
475 let anomaly_type_key = format!("{:?}", anomaly.anomaly_type);
476 *self.monitoring_stats.anomalies_per_type.entry(anomaly_type_key).or_insert(0) += 1;
477
478 self.detected_anomalies.push(anomaly);
479 }
480
481 pub fn get_anomalies(&self) -> &[Anomaly] {
483 &self.detected_anomalies
484 }
485
486 pub fn clear_anomalies(&mut self) {
488 self.detected_anomalies.clear();
489 }
490
491 pub fn check_gradient_conflict(
493 &mut self,
494 layer_gradients: &HashMap<String, Vec<f32>>,
495 ) -> Result<()> {
496 if !self.config.enable_gradient_conflict_detection {
497 return Ok(());
498 }
499
500 let layer_names: Vec<_> = layer_gradients.keys().cloned().collect();
501
502 for i in 0..layer_names.len() {
503 for j in i + 1..layer_names.len() {
504 let layer1 = &layer_names[i];
505 let layer2 = &layer_names[j];
506
507 if let (Some(grad1), Some(grad2)) =
508 (layer_gradients.get(layer1), layer_gradients.get(layer2))
509 {
510 let conflict_score = self.compute_gradient_conflict(grad1, grad2);
511
512 if conflict_score > 0.8 {
513 let mut metadata = HashMap::new();
514 metadata.insert("layer1".to_string(), layer1.clone());
515 metadata.insert("layer2".to_string(), layer2.clone());
516 metadata.insert("conflict_score".to_string(), conflict_score.to_string());
517
518 self.report_anomaly(Anomaly {
519 anomaly_type: AnomalyType::GradientConflict,
520 timestamp: Utc::now(),
521 location: format!("{}↔{}", layer1, layer2),
522 description: format!(
523 "Gradient conflict detected between {} and {} (score: {:.2})",
524 layer1, layer2, conflict_score
525 ),
526 severity: AnomalySeverity::High,
527 metadata,
528 });
529 }
530 }
531 }
532 }
533
534 Ok(())
535 }
536
537 pub fn check_weight_divergence(
539 &mut self,
540 layer_name: &str,
541 current_weights: &[f32],
542 ) -> Result<()> {
543 if !self.config.enable_weight_divergence_detection {
544 return Ok(());
545 }
546
547 if !self.weight_baseline.contains_key(layer_name) {
549 self.weight_baseline.insert(layer_name.to_string(), current_weights.to_vec());
550 return Ok(());
551 }
552
553 let baseline = self
554 .weight_baseline
555 .get(layer_name)
556 .expect("baseline should exist after contains_key check");
557 if baseline.len() != current_weights.len() {
558 return Ok(()); }
560
561 let divergence = self.compute_weight_divergence(baseline, current_weights);
562
563 if divergence > self.config.weight_divergence_threshold {
564 let mut metadata = HashMap::new();
565 metadata.insert("layer_name".to_string(), layer_name.to_string());
566 metadata.insert("divergence_score".to_string(), divergence.to_string());
567 metadata.insert(
568 "threshold".to_string(),
569 self.config.weight_divergence_threshold.to_string(),
570 );
571
572 self.report_anomaly(Anomaly {
573 anomaly_type: AnomalyType::WeightDivergence,
574 timestamp: Utc::now(),
575 location: layer_name.to_string(),
576 description: format!(
577 "Weight divergence in {}: {:.2} (threshold: {:.2})",
578 layer_name, divergence, self.config.weight_divergence_threshold
579 ),
580 severity: if divergence > self.config.weight_divergence_threshold * 2.0 {
581 AnomalySeverity::Critical
582 } else {
583 AnomalySeverity::High
584 },
585 metadata,
586 });
587 }
588
589 Ok(())
590 }
591
592 pub fn check_performance_degradation(
594 &mut self,
595 current_performance: f64,
596 location: &str,
597 ) -> Result<()> {
598 if !self.config.enable_performance_monitoring {
599 return Ok(());
600 }
601
602 if self.performance_history.len() >= self.config.monitoring_window_size {
604 self.performance_history.pop_front();
605 }
606 self.performance_history.push_back(current_performance);
607
608 if self.performance_history.len() >= 10 {
610 let recent_avg = self.performance_history.iter().rev().take(5).sum::<f64>() / 5.0;
611 let baseline_avg = self.performance_history.iter().take(5).sum::<f64>() / 5.0;
612
613 let degradation_ratio = (baseline_avg - recent_avg) / baseline_avg;
614
615 if degradation_ratio > self.config.performance_degradation_threshold {
616 let mut metadata = HashMap::new();
617 metadata.insert("baseline_performance".to_string(), baseline_avg.to_string());
618 metadata.insert("current_performance".to_string(), recent_avg.to_string());
619 metadata.insert(
620 "degradation_ratio".to_string(),
621 degradation_ratio.to_string(),
622 );
623
624 self.report_anomaly(Anomaly {
625 anomaly_type: AnomalyType::PerformanceDegradation,
626 timestamp: Utc::now(),
627 location: location.to_string(),
628 description: format!(
629 "Performance degradation detected: {:.1}% drop from baseline",
630 degradation_ratio * 100.0
631 ),
632 severity: if degradation_ratio > 0.8 {
633 AnomalySeverity::Critical
634 } else {
635 AnomalySeverity::High
636 },
637 metadata,
638 });
639 }
640 }
641
642 Ok(())
643 }
644
645 pub fn check_loss_anomaly(&mut self, current_loss: f64, location: &str) -> Result<()> {
647 if !self.config.enable_loss_anomaly_detection {
648 return Ok(());
649 }
650
651 if self.loss_history.len() >= self.config.monitoring_window_size {
653 self.loss_history.pop_front();
654 }
655 self.loss_history.push_back(current_loss);
656
657 if self.loss_history.len() >= 3 {
659 let prev_loss = self.loss_history[self.loss_history.len() - 2];
660 let loss_ratio = current_loss / prev_loss;
661
662 if loss_ratio > self.config.loss_spike_threshold {
663 let mut metadata = HashMap::new();
664 metadata.insert("previous_loss".to_string(), prev_loss.to_string());
665 metadata.insert("current_loss".to_string(), current_loss.to_string());
666 metadata.insert("spike_ratio".to_string(), loss_ratio.to_string());
667
668 self.report_anomaly(Anomaly {
669 anomaly_type: AnomalyType::LossAnomalous,
670 timestamp: Utc::now(),
671 location: location.to_string(),
672 description: format!(
673 "Loss spike detected: {:.2}x increase (from {:.6} to {:.6})",
674 loss_ratio, prev_loss, current_loss
675 ),
676 severity: if loss_ratio > 100.0 {
677 AnomalySeverity::Critical
678 } else {
679 AnomalySeverity::High
680 },
681 metadata,
682 });
683 }
684 }
685
686 Ok(())
687 }
688
689 pub async fn attempt_recovery(&mut self, anomaly: &Anomaly) -> Result<RecoveryAction> {
691 if !self.config.enable_auto_recovery {
692 return Ok(RecoveryAction::None);
693 }
694
695 let action = self.determine_recovery_action(anomaly);
696 let anomaly_id = format!(
697 "{:?}_{}",
698 anomaly.anomaly_type,
699 anomaly.timestamp.timestamp()
700 );
701
702 let success = self.execute_recovery_action(&action).await?;
703
704 self.recovery_attempts.push(RecoveryAttempt {
705 anomaly_id: anomaly_id.clone(),
706 action: action.clone(),
707 timestamp: Utc::now(),
708 success,
709 error_message: if success { None } else { Some("Recovery failed".to_string()) },
710 });
711
712 self.monitoring_stats.recovery_attempts += 1;
713 if success {
714 self.monitoring_stats.successful_recoveries += 1;
715 }
716
717 Ok(action)
718 }
719
720 pub fn get_monitoring_stats(&self) -> &MonitoringStats {
722 &self.monitoring_stats
723 }
724
725 pub fn get_recovery_attempts(&self) -> &[RecoveryAttempt] {
727 &self.recovery_attempts
728 }
729
730 pub fn update_monitoring_window(&mut self) -> Result<()> {
732 let mut severity_distribution = HashMap::new();
733 for anomaly in &self.detected_anomalies {
734 let key = format!("{:?}", anomaly.severity);
735 *severity_distribution.entry(key).or_insert(0) += 1;
736 }
737
738 let mut performance_metrics = HashMap::new();
739 if let Some(latest_perf) = self.performance_history.back() {
740 performance_metrics.insert("latest_performance".to_string(), *latest_perf);
741 }
742 if let Some(latest_loss) = self.loss_history.back() {
743 performance_metrics.insert("latest_loss".to_string(), *latest_loss);
744 }
745
746 let snapshot = AnomalySnapshot {
747 timestamp: Utc::now(),
748 anomaly_count: self.detected_anomalies.len(),
749 severity_distribution,
750 performance_metrics,
751 };
752
753 self.monitoring_stats.monitoring_window.push(snapshot);
754
755 if self.monitoring_stats.monitoring_window.len() > self.config.monitoring_window_size {
757 self.monitoring_stats.monitoring_window.remove(0);
758 }
759
760 Ok(())
761 }
762
763 fn compute_gradient_conflict(&self, grad1: &[f32], grad2: &[f32]) -> f64 {
766 if grad1.len() != grad2.len() {
767 return 0.0;
768 }
769
770 let dot_product: f64 =
771 grad1.iter().zip(grad2.iter()).map(|(a, b)| (*a as f64) * (*b as f64)).sum();
772
773 let norm1: f64 = grad1.iter().map(|x| (*x as f64).powi(2)).sum::<f64>().sqrt();
774 let norm2: f64 = grad2.iter().map(|x| (*x as f64).powi(2)).sum::<f64>().sqrt();
775
776 if norm1 == 0.0 || norm2 == 0.0 {
777 return 0.0;
778 }
779
780 let cosine_sim = dot_product / (norm1 * norm2);
782
783 (1.0 - cosine_sim) / 2.0
785 }
786
787 fn compute_weight_divergence(&self, baseline: &[f32], current: &[f32]) -> f64 {
788 let mse: f64 = baseline
789 .iter()
790 .zip(current.iter())
791 .map(|(a, b)| (*a as f64 - *b as f64).powi(2))
792 .sum::<f64>()
793 / baseline.len() as f64;
794
795 mse.sqrt()
796 }
797
798 fn determine_recovery_action(&self, anomaly: &Anomaly) -> RecoveryAction {
799 match anomaly.anomaly_type {
800 AnomalyType::GradientExplosion => RecoveryAction::ClipGradients { max_norm: 1.0 },
801 AnomalyType::GradientVanishing => RecoveryAction::ReduceLearningRate { factor: 0.5 },
802 AnomalyType::NaN | AnomalyType::Infinity => RecoveryAction::ResetGradients,
803 AnomalyType::WeightDivergence => RecoveryAction::ApplyWeightDecay { rate: 0.01 },
804 AnomalyType::LossAnomalous => RecoveryAction::SkipBatch,
805 AnomalyType::MemoryLeak => RecoveryAction::RestartOptimizer,
806 AnomalyType::PerformanceDegradation => {
807 RecoveryAction::ReduceLearningRate { factor: 0.8 }
808 },
809 _ => RecoveryAction::None,
810 }
811 }
812
813 async fn execute_recovery_action(&self, action: &RecoveryAction) -> Result<bool> {
814 match action {
817 RecoveryAction::None => Ok(true),
818 RecoveryAction::ResetGradients => {
819 tracing::info!("Executing recovery: Reset gradients");
820 Ok(true)
821 },
822 RecoveryAction::ReduceLearningRate { factor } => {
823 tracing::info!(
824 "Executing recovery: Reduce learning rate by factor {}",
825 factor
826 );
827 Ok(true)
828 },
829 RecoveryAction::ClipGradients { max_norm } => {
830 tracing::info!(
831 "Executing recovery: Clip gradients to max norm {}",
832 max_norm
833 );
834 Ok(true)
835 },
836 RecoveryAction::RestartOptimizer => {
837 tracing::info!("Executing recovery: Restart optimizer");
838 Ok(true)
839 },
840 RecoveryAction::SkipBatch => {
841 tracing::info!("Executing recovery: Skip current batch");
842 Ok(true)
843 },
844 RecoveryAction::ResetWeights { layer_name } => {
845 tracing::info!("Executing recovery: Reset weights for layer {}", layer_name);
846 Ok(true)
847 },
848 RecoveryAction::ApplyWeightDecay { rate } => {
849 tracing::info!("Executing recovery: Apply weight decay with rate {}", rate);
850 Ok(true)
851 },
852 RecoveryAction::EmergencyStop => {
853 tracing::warn!("Executing recovery: Emergency stop");
854 Ok(false) },
856 }
857 }
858
859 pub async fn quick_check(&self) -> Result<crate::QuickAnomalySummary> {
861 let anomaly_count = self.detected_anomalies.len();
862
863 let severity_level = match anomaly_count {
864 0 => "None",
865 1..=3 => "Low",
866 4..=10 => "Medium",
867 11..=20 => "High",
868 _ => "Critical",
869 }
870 .to_string();
871
872 let mut recommendations = Vec::new();
873 if anomaly_count > 0 {
874 recommendations.push("Review recent training metrics for instabilities".to_string());
875 }
876 if anomaly_count > 5 {
877 recommendations.push(
878 "Consider adjusting learning rate or implementing gradient clipping".to_string(),
879 );
880 }
881 if anomaly_count > 15 {
882 recommendations
883 .push("Training may need to be restarted with better configuration".to_string());
884 }
885 if anomaly_count == 0 {
886 recommendations.push("No anomalies detected, training appears stable".to_string());
887 }
888
889 Ok(crate::QuickAnomalySummary {
890 anomaly_count,
891 severity_level,
892 recommendations,
893 })
894 }
895
896 pub async fn generate_report(&self) -> Result<AnomalyDetectorReport> {
898 let mut anomaly_counts = HashMap::new();
899 for anomaly in &self.detected_anomalies {
900 let count = anomaly_counts.entry(format!("{:?}", anomaly.anomaly_type)).or_insert(0);
901 *count += 1;
902 }
903
904 Ok(AnomalyDetectorReport {
905 session_duration: Utc::now().signed_duration_since(self.start_time),
906 total_anomalies: self.detected_anomalies.len(),
907 anomaly_counts,
908 most_recent_anomalies: self.detected_anomalies.iter().rev().take(10).cloned().collect(),
909 config: self.config.clone(),
910 })
911 }
912}
913
/// Session summary produced by [`AnomalyDetector::generate_report`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AnomalyDetectorReport {
    /// Time elapsed since `start_time`.
    pub session_duration: chrono::Duration,
    /// Total anomalies currently recorded.
    pub total_anomalies: usize,
    /// Count per `AnomalyType`, keyed by the type's Debug name.
    pub anomaly_counts: HashMap<String, usize>,
    /// Up to ten most recent anomalies, newest first.
    pub most_recent_anomalies: Vec<Anomaly>,
    /// Configuration active during the session.
    pub config: AnomalyDetectorConfig,
}
923
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_anomaly_detector_creation() {
        let config = DebugConfig::default();
        let detector = AnomalyDetector::new(&config);
        assert_eq!(detector.get_anomalies().len(), 0);
    }

    #[test]
    fn test_nan_detection() {
        let config = DebugConfig::default();
        let mut detector = AnomalyDetector::new(&config);

        let values = vec![1.0, 2.0, f32::NAN, 4.0];
        detector.check_nan(&values, "test_location").expect("operation failed in test");

        assert_eq!(detector.get_anomalies().len(), 1);
        assert!(matches!(
            detector.get_anomalies()[0].anomaly_type,
            AnomalyType::NaN
        ));
    }

    #[test]
    fn test_inf_detection() {
        let config = DebugConfig::default();
        let mut detector = AnomalyDetector::new(&config);

        let values = vec![1.0, 2.0, f32::INFINITY, 4.0];
        detector.check_inf(&values, "test_location").expect("operation failed in test");

        assert_eq!(detector.get_anomalies().len(), 1);
        assert!(matches!(
            detector.get_anomalies()[0].anomaly_type,
            AnomalyType::Infinity
        ));
    }

    #[test]
    fn test_gradient_explosion_detection() {
        let config = DebugConfig::default();
        let mut detector = AnomalyDetector::new(&config);

        // 1e7 is above the default 1e6 threshold.
        detector
            .check_gradient_explosion(1e7, "test_layer")
            .expect("operation failed in test");

        assert_eq!(detector.get_anomalies().len(), 1);
        assert!(matches!(
            detector.get_anomalies()[0].anomaly_type,
            AnomalyType::GradientExplosion
        ));
    }

    #[test]
    fn test_gradient_vanishing_detection() {
        let config = DebugConfig::default();
        let mut detector = AnomalyDetector::new(&config);

        // 1e-10 is below the fixed 1e-8 vanishing threshold.
        detector
            .check_gradient_vanishing(1e-10, "test_layer")
            .expect("operation failed in test");

        assert_eq!(detector.get_anomalies().len(), 1);
        assert!(matches!(
            detector.get_anomalies()[0].anomaly_type,
            AnomalyType::GradientVanishing
        ));
    }

    #[test]
    fn test_numerical_instability_detection() {
        let config = DebugConfig::default();
        let mut detector = AnomalyDetector::new(&config);

        // Half of the values are tiny but non-zero → underflow symptom.
        let near_zero_values: Vec<f32> =
            (0..100).map(|i| if i < 50 { 1e-12 } else { 1.0 }).collect();
        detector
            .check_numerical_instability(&near_zero_values, "test_location")
            .expect("operation failed in test");
        assert_eq!(detector.get_anomalies().len(), 1);

        detector.clear_anomalies();

        // One value above 1e6 → overflow symptom.
        let extreme_values = vec![1.0, 2.0, 1e7, 4.0];
        detector
            .check_numerical_instability(&extreme_values, "test_location")
            .expect("operation failed in test");
        assert_eq!(detector.get_anomalies().len(), 1);
    }

    #[test]
    fn test_activation_saturation_detection() {
        let config = DebugConfig::default();
        let mut detector = AnomalyDetector::new(&config);

        // All-zero ReLU outputs count as fully saturated (dead units).
        let relu_saturated: Vec<f32> = vec![0.0; 100];
        detector
            .check_activation_saturation(&relu_saturated, "relu", "test_layer")
            .expect("operation failed in test");
        assert_eq!(detector.get_anomalies().len(), 1);

        detector.clear_anomalies();

        // Sigmoid outputs pinned near 1.0 are saturated.
        let sigmoid_saturated: Vec<f32> = vec![0.999; 100];
        detector
            .check_activation_saturation(&sigmoid_saturated, "sigmoid", "test_layer")
            .expect("operation failed in test");
        assert_eq!(detector.get_anomalies().len(), 1);
    }

    #[test]
    fn test_memory_leak_detection() {
        let config = DebugConfig::default();
        let mut detector = AnomalyDetector::new(&config);

        // 3x the expected footprint → growth-ratio branch fires.
        detector
            .check_memory_leak(3072, Some(1024), "test_location")
            .expect("operation failed in test");
        assert_eq!(detector.get_anomalies().len(), 1);
        assert!(matches!(
            detector.get_anomalies()[0].anomaly_type,
            AnomalyType::MemoryLeak
        ));

        detector.clear_anomalies();

        // No expectation given: absolute 8192 MB ceiling applies.
        detector
            .check_memory_leak(10240, None, "test_location")
            .expect("operation failed in test");
        assert_eq!(detector.get_anomalies().len(), 1);
    }

    #[test]
    fn test_weight_explosion_detection() {
        let config = DebugConfig::default();
        let mut detector = AnomalyDetector::new(&config);

        // Two weights exceed the |w| > 10.0 threshold.
        let weights = vec![1.0, 2.0, 15.0, 4.0, -20.0];
        detector
            .check_weight_explosion(&weights, "test_layer")
            .expect("operation failed in test");

        assert_eq!(detector.get_anomalies().len(), 1);
        assert!(matches!(
            detector.get_anomalies()[0].anomaly_type,
            AnomalyType::UnusualActivation
        ));
    }

    #[test]
    fn test_gradient_conflict_detection() {
        let config = DebugConfig::default();
        let mut detector = AnomalyDetector::new(&config);

        // Exactly opposed gradients → conflict score 1.0 > 0.8.
        let mut layer_gradients = HashMap::new();
        layer_gradients.insert("layer1".to_string(), vec![1.0, 0.0, 0.0]);
        layer_gradients.insert("layer2".to_string(), vec![-1.0, 0.0, 0.0]);
        detector
            .check_gradient_conflict(&layer_gradients)
            .expect("operation failed in test");

        assert_eq!(detector.get_anomalies().len(), 1);
        assert!(matches!(
            detector.get_anomalies()[0].anomaly_type,
            AnomalyType::GradientConflict
        ));
    }

    #[test]
    fn test_weight_divergence_detection() {
        let config = DebugConfig::default();
        let mut detector = AnomalyDetector::new(&config);

        let baseline_weights = vec![1.0, 2.0, 3.0, 4.0];
        let diverged_weights = vec![10.0, 20.0, 30.0, 40.0];

        // First call only records the baseline — no anomaly yet.
        detector
            .check_weight_divergence("test_layer", &baseline_weights)
            .expect("operation failed in test");
        assert_eq!(detector.get_anomalies().len(), 0);

        // Second call measures drift against the baseline.
        detector
            .check_weight_divergence("test_layer", &diverged_weights)
            .expect("operation failed in test");
        assert_eq!(detector.get_anomalies().len(), 1);
        assert!(matches!(
            detector.get_anomalies()[0].anomaly_type,
            AnomalyType::WeightDivergence
        ));
    }

    #[test]
    fn test_performance_degradation_detection() {
        let config = DebugConfig::default();
        let mut detector = AnomalyDetector::new(&config);

        // Establish a stable baseline window.
        for _ in 0..10 {
            detector
                .check_performance_degradation(100.0, "training")
                .expect("operation failed in test");
        }
        assert_eq!(detector.get_anomalies().len(), 0);

        // An 80% drop should exceed the default 0.5 threshold.
        for _ in 0..5 {
            detector
                .check_performance_degradation(20.0, "training")
                .expect("operation failed in test");
        }

        assert!(!detector.get_anomalies().is_empty());
        assert!(detector
            .get_anomalies()
            .iter()
            .any(|a| matches!(a.anomaly_type, AnomalyType::PerformanceDegradation)));
    }

    #[test]
    fn test_loss_anomaly_detection() {
        let config = DebugConfig::default();
        let mut detector = AnomalyDetector::new(&config);

        // Mild decrease — no spike.
        detector.check_loss_anomaly(1.0, "training").expect("operation failed in test");
        detector.check_loss_anomaly(0.9, "training").expect("operation failed in test");
        assert_eq!(detector.get_anomalies().len(), 0);

        // 100/0.9 ≈ 111x increase → spike.
        detector
            .check_loss_anomaly(100.0, "training")
            .expect("operation failed in test");
        assert_eq!(detector.get_anomalies().len(), 1);
        assert!(matches!(
            detector.get_anomalies()[0].anomaly_type,
            AnomalyType::LossAnomalous
        ));
    }

    #[tokio::test]
    async fn test_auto_recovery() {
        let config = DebugConfig::default();
        let mut detector = AnomalyDetector::new(&config);
        detector.config.enable_auto_recovery = true;

        let anomaly = Anomaly {
            anomaly_type: AnomalyType::GradientExplosion,
            timestamp: Utc::now(),
            location: "test_layer".to_string(),
            description: "Test gradient explosion".to_string(),
            severity: AnomalySeverity::High,
            metadata: HashMap::new(),
        };

        // Fixed: the expect message previously said "temp file creation
        // failed", a copy-paste leftover unrelated to this call.
        let action = detector.attempt_recovery(&anomaly).await.expect("recovery attempt failed");
        assert!(matches!(action, RecoveryAction::ClipGradients { .. }));
        assert_eq!(detector.get_recovery_attempts().len(), 1);
    }

    #[test]
    fn test_monitoring_stats() {
        let config = DebugConfig::default();
        let mut detector = AnomalyDetector::new(&config);

        detector.check_nan(&[f32::NAN], "test").expect("operation failed in test");
        detector.check_inf(&[f32::INFINITY], "test").expect("operation failed in test");

        let stats = detector.get_monitoring_stats();
        assert_eq!(stats.total_anomalies, 2);
        assert!(stats.anomalies_per_type.contains_key("NaN"));
        assert!(stats.anomalies_per_type.contains_key("Infinity"));
    }

    #[test]
    fn test_monitoring_window_update() {
        let config = DebugConfig::default();
        let mut detector = AnomalyDetector::new(&config);

        detector.check_nan(&[f32::NAN], "test").expect("operation failed in test");
        detector.update_monitoring_window().expect("operation failed in test");

        let stats = detector.get_monitoring_stats();
        assert_eq!(stats.monitoring_window.len(), 1);
        assert_eq!(stats.monitoring_window[0].anomaly_count, 1);
    }
}