// Source: trustformers_debug/gradient_debugger/anomaly_detection.rs
//! Advanced Gradient Anomaly Detection System
//!
//! This module provides sophisticated anomaly detection capabilities for gradient
//! analysis, including baseline establishment, pattern recognition, and contextual
//! anomaly classification.

use crate::anomaly_detector::{Anomaly, AnomalySeverity};
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, VecDeque};
/// Advanced gradient anomaly detector operating on per-layer gradient norms.
///
/// Maintains per-layer baseline statistics and a bounded (1000-entry) history
/// of detected anomalies that the pattern-based detectors scan.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GradientAnomalyDetector {
    /// Master switch; when `false`, `detect_anomalies` returns nothing.
    pub enabled: bool,
    /// Detection sensitivity; higher values lower the z-score threshold
    /// (threshold = 2 + (1 - sensitivity) * 2).
    pub sensitivity: f64,
    /// Rolling window size, in steps.
    /// NOTE(review): not referenced by any detection logic in this module —
    /// confirm whether it is consumed elsewhere or dead configuration.
    pub detection_window: usize,
    /// Bounded FIFO of detected anomalies; oldest entries are evicted first.
    pub anomaly_history: VecDeque<GradientAnomaly>,
    /// Per-layer baseline statistics keyed by layer name.
    pub baseline_statistics: HashMap<String, BaselineGradientStats>,
}
21
22impl Default for GradientAnomalyDetector {
23    fn default() -> Self {
24        Self {
25            enabled: true,
26            sensitivity: 0.8,
27            detection_window: 50,
28            anomaly_history: VecDeque::with_capacity(1000),
29            baseline_statistics: HashMap::new(),
30        }
31    }
32}
33
34impl GradientAnomalyDetector {
35    pub fn new(sensitivity: f64, window_size: usize) -> Self {
36        Self {
37            enabled: true,
38            sensitivity,
39            detection_window: window_size,
40            anomaly_history: VecDeque::with_capacity(1000),
41            baseline_statistics: HashMap::new(),
42        }
43    }
44
45    pub fn establish_baseline(&mut self, layer_name: &str, gradient_history: &[f64]) {
46        if gradient_history.len() < 10 {
47            return;
48        }
49
50        let mean = gradient_history.iter().sum::<f64>() / gradient_history.len() as f64;
51        let variance = gradient_history.iter().map(|&x| (x - mean).powi(2)).sum::<f64>()
52            / gradient_history.len() as f64;
53        let std = variance.sqrt();
54
55        let mut sorted_values = gradient_history.to_vec();
56        sorted_values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
57
58        let median_idx = sorted_values.len() / 2;
59        let median = if sorted_values.len() % 2 == 0 {
60            (sorted_values[median_idx - 1] + sorted_values[median_idx]) / 2.0
61        } else {
62            sorted_values[median_idx]
63        };
64
65        let percentile_5_idx = (sorted_values.len() as f64 * 0.05) as usize;
66        let percentile_95_idx = (sorted_values.len() as f64 * 0.95) as usize;
67
68        let baseline = BaselineGradientStats {
69            mean,
70            std,
71            median,
72            percentile_95: sorted_values[percentile_95_idx.min(sorted_values.len() - 1)],
73            percentile_5: sorted_values[percentile_5_idx],
74            samples: gradient_history.len(),
75        };
76
77        self.baseline_statistics.insert(layer_name.to_string(), baseline);
78    }
79
80    pub fn detect_anomalies(
81        &mut self,
82        layer_name: &str,
83        gradient_norm: f64,
84        step: usize,
85    ) -> Vec<GradientAnomaly> {
86        if !self.enabled {
87            return Vec::new();
88        }
89
90        let baseline = match self.baseline_statistics.get(layer_name) {
91            Some(baseline) => baseline,
92            None => return Vec::new(), // No baseline established yet
93        };
94
95        let mut anomalies = Vec::new();
96
97        // Statistical anomaly detection
98        if let Some(anomaly) =
99            self.detect_statistical_anomaly(layer_name, gradient_norm, step, baseline)
100        {
101            anomalies.push(anomaly);
102        }
103
104        // Pattern-based anomaly detection
105        if let Some(anomaly) = self.detect_pattern_anomaly(layer_name, gradient_norm, step) {
106            anomalies.push(anomaly);
107        }
108
109        // Add to history
110        for anomaly in &anomalies {
111            if self.anomaly_history.len() >= 1000 {
112                self.anomaly_history.pop_front();
113            }
114            self.anomaly_history.push_back(anomaly.clone());
115        }
116
117        anomalies
118    }
119
120    fn detect_statistical_anomaly(
121        &self,
122        layer_name: &str,
123        gradient_norm: f64,
124        step: usize,
125        baseline: &BaselineGradientStats,
126    ) -> Option<GradientAnomaly> {
127        let z_score = (gradient_norm - baseline.mean) / baseline.std;
128        let threshold = 2.0 + (1.0 - self.sensitivity) * 2.0; // Threshold between 2-4 based on sensitivity
129
130        if z_score.abs() > threshold {
131            let anomaly_type = if z_score > 0.0 {
132                if z_score > threshold * 1.5 {
133                    AnomalyType::SuddenSpike
134                } else {
135                    AnomalyType::SuddenSpike
136                }
137            } else {
138                AnomalyType::SuddenDrop
139            };
140
141            let severity = (z_score.abs() / threshold).min(1.0);
142
143            Some(GradientAnomaly {
144                layer_name: layer_name.to_string(),
145                anomaly_type,
146                severity,
147                timestamp: Utc::now(),
148                context: AnomalyContext {
149                    step,
150                    gradient_norm,
151                    expected_range: (baseline.percentile_5, baseline.percentile_95),
152                    deviation_magnitude: z_score.abs(),
153                },
154            })
155        } else {
156            None
157        }
158    }
159
160    fn detect_pattern_anomaly(
161        &self,
162        layer_name: &str,
163        gradient_norm: f64,
164        step: usize,
165    ) -> Option<GradientAnomaly> {
166        // Look for patterns in recent anomaly history for this layer
167        let recent_anomalies: Vec<&GradientAnomaly> = self
168            .anomaly_history
169            .iter()
170            .filter(|a| a.layer_name == layer_name)
171            .rev()
172            .take(10)
173            .collect();
174
175        if recent_anomalies.len() >= 3 {
176            // Check for oscillation pattern
177            let oscillation_count = recent_anomalies
178                .windows(2)
179                .filter(|pair| {
180                    matches!(
181                        (&pair[0].anomaly_type, &pair[1].anomaly_type),
182                        (AnomalyType::SuddenSpike, AnomalyType::SuddenDrop)
183                            | (AnomalyType::SuddenDrop, AnomalyType::SuddenSpike)
184                    )
185                })
186                .count();
187
188            if oscillation_count >= 2 {
189                return Some(GradientAnomaly {
190                    layer_name: layer_name.to_string(),
191                    anomaly_type: AnomalyType::Oscillation,
192                    severity: 0.7,
193                    timestamp: Utc::now(),
194                    context: AnomalyContext {
195                        step,
196                        gradient_norm,
197                        expected_range: (0.0, 1.0), // Placeholder
198                        deviation_magnitude: oscillation_count as f64,
199                    },
200                });
201            }
202        }
203
204        // Check for stagnation
205        if recent_anomalies.len() >= 5 {
206            let all_similar = recent_anomalies.windows(2).all(|pair| {
207                (pair[0].context.gradient_norm - pair[1].context.gradient_norm).abs() < 1e-6
208            });
209
210            if all_similar {
211                return Some(GradientAnomaly {
212                    layer_name: layer_name.to_string(),
213                    anomaly_type: AnomalyType::Stagnation,
214                    severity: 0.8,
215                    timestamp: Utc::now(),
216                    context: AnomalyContext {
217                        step,
218                        gradient_norm,
219                        expected_range: (0.0, 1.0), // Placeholder
220                        deviation_magnitude: 0.0,
221                    },
222                });
223            }
224        }
225
226        None
227    }
228
229    pub fn get_anomaly_summary(&self, layer_name: Option<&str>) -> AnomalySummary {
230        let filtered_anomalies: Vec<&GradientAnomaly> = match layer_name {
231            Some(name) => self.anomaly_history.iter().filter(|a| a.layer_name == name).collect(),
232            None => self.anomaly_history.iter().collect(),
233        };
234
235        let total_anomalies = filtered_anomalies.len();
236        let mut anomaly_type_counts = HashMap::new();
237        let mut severity_sum = 0.0;
238
239        for anomaly in &filtered_anomalies {
240            *anomaly_type_counts.entry(anomaly.anomaly_type.clone()).or_insert(0) += 1;
241            severity_sum += anomaly.severity;
242        }
243
244        let average_severity =
245            if total_anomalies > 0 { severity_sum / total_anomalies as f64 } else { 0.0 };
246
247        // Convert GradientAnomaly to Anomaly objects
248        let anomalies: Vec<Anomaly> = filtered_anomalies
249            .iter()
250            .map(|gradient_anomaly| {
251                let severity = if gradient_anomaly.severity >= 0.8 {
252                    AnomalySeverity::Critical
253                } else if gradient_anomaly.severity >= 0.6 {
254                    AnomalySeverity::High
255                } else if gradient_anomaly.severity >= 0.3 {
256                    AnomalySeverity::Medium
257                } else {
258                    AnomalySeverity::Low
259                };
260
261                // Convert gradient-specific anomaly type to general anomaly type
262                let general_anomaly_type = match gradient_anomaly.anomaly_type {
263                    AnomalyType::SuddenSpike => {
264                        crate::anomaly_detector::AnomalyType::GradientExplosion
265                    },
266                    AnomalyType::SuddenDrop => {
267                        crate::anomaly_detector::AnomalyType::GradientVanishing
268                    },
269                    AnomalyType::Oscillation => {
270                        crate::anomaly_detector::AnomalyType::NumericalInstability
271                    },
272                    AnomalyType::Stagnation => {
273                        crate::anomaly_detector::AnomalyType::GradientVanishing
274                    },
275                    AnomalyType::Chaos => {
276                        crate::anomaly_detector::AnomalyType::NumericalInstability
277                    },
278                };
279
280                let description = format!(
281                    "Gradient anomaly of type {:?} detected with severity {:.2}",
282                    gradient_anomaly.anomaly_type, gradient_anomaly.severity
283                );
284
285                let mut metadata = HashMap::new();
286                metadata.insert(
287                    "step".to_string(),
288                    gradient_anomaly.context.step.to_string(),
289                );
290                metadata.insert(
291                    "gradient_norm".to_string(),
292                    gradient_anomaly.context.gradient_norm.to_string(),
293                );
294                metadata.insert(
295                    "expected_range_min".to_string(),
296                    gradient_anomaly.context.expected_range.0.to_string(),
297                );
298                metadata.insert(
299                    "expected_range_max".to_string(),
300                    gradient_anomaly.context.expected_range.1.to_string(),
301                );
302                metadata.insert(
303                    "deviation_magnitude".to_string(),
304                    gradient_anomaly.context.deviation_magnitude.to_string(),
305                );
306                metadata.insert(
307                    "original_anomaly_type".to_string(),
308                    format!("{:?}", gradient_anomaly.anomaly_type),
309                );
310
311                Anomaly {
312                    anomaly_type: general_anomaly_type,
313                    timestamp: gradient_anomaly.timestamp,
314                    location: gradient_anomaly.layer_name.clone(),
315                    description,
316                    severity,
317                    metadata,
318                }
319            })
320            .collect();
321
322        AnomalySummary {
323            layer_name: layer_name.map(|s| s.to_string()),
324            total_anomalies,
325            anomaly_type_counts,
326            average_severity,
327            recent_trend: self.analyze_recent_trend(&filtered_anomalies),
328            recommendations: self.generate_anomaly_recommendations(&filtered_anomalies),
329            anomalies,
330        }
331    }
332
333    fn analyze_recent_trend(&self, anomalies: &[&GradientAnomaly]) -> AnomalyTrend {
334        if anomalies.len() < 5 {
335            return AnomalyTrend::Stable;
336        }
337
338        let recent_anomalies: Vec<&GradientAnomaly> =
339            anomalies.iter().rev().take(10).cloned().collect();
340        let older_anomalies: Vec<&GradientAnomaly> =
341            anomalies.iter().rev().skip(10).take(10).cloned().collect();
342
343        if older_anomalies.is_empty() {
344            return AnomalyTrend::Stable;
345        }
346
347        let recent_avg_severity: f64 = recent_anomalies.iter().map(|a| a.severity).sum::<f64>()
348            / recent_anomalies.len() as f64;
349        let older_avg_severity: f64 =
350            older_anomalies.iter().map(|a| a.severity).sum::<f64>() / older_anomalies.len() as f64;
351
352        let trend_threshold = 0.1;
353        if recent_avg_severity > older_avg_severity + trend_threshold {
354            AnomalyTrend::Increasing
355        } else if recent_avg_severity < older_avg_severity - trend_threshold {
356            AnomalyTrend::Decreasing
357        } else {
358            AnomalyTrend::Stable
359        }
360    }
361
362    fn generate_anomaly_recommendations(&self, anomalies: &[&GradientAnomaly]) -> Vec<String> {
363        let mut recommendations = Vec::new();
364
365        let spike_count = anomalies
366            .iter()
367            .filter(|a| matches!(a.anomaly_type, AnomalyType::SuddenSpike))
368            .count();
369        let drop_count = anomalies
370            .iter()
371            .filter(|a| matches!(a.anomaly_type, AnomalyType::SuddenDrop))
372            .count();
373        let oscillation_count = anomalies
374            .iter()
375            .filter(|a| matches!(a.anomaly_type, AnomalyType::Oscillation))
376            .count();
377        let stagnation_count = anomalies
378            .iter()
379            .filter(|a| matches!(a.anomaly_type, AnomalyType::Stagnation))
380            .count();
381
382        if spike_count > 3 {
383            recommendations
384                .push("Consider reducing learning rate to prevent gradient explosion".to_string());
385            recommendations.push("Add gradient clipping to stabilize training".to_string());
386        }
387
388        if drop_count > 3 {
389            recommendations.push("Check for vanishing gradient issues".to_string());
390            recommendations
391                .push("Consider using residual connections or better initialization".to_string());
392        }
393
394        if oscillation_count > 2 {
395            recommendations.push("Reduce learning rate to dampen oscillations".to_string());
396            recommendations
397                .push("Consider using momentum or adaptive learning rate methods".to_string());
398        }
399
400        if stagnation_count > 2 {
401            recommendations.push(
402                "Learning may have plateaued - consider learning rate scheduling".to_string(),
403            );
404            recommendations
405                .push("Check for potential convergence or training data issues".to_string());
406        }
407
408        if recommendations.is_empty() {
409            recommendations.push("Gradient behavior appears normal".to_string());
410        }
411
412        recommendations
413    }
414}
415
/// A single gradient anomaly event detected for one layer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GradientAnomaly {
    /// Name of the layer the anomaly was observed in.
    pub layer_name: String,
    /// Classification of the anomaly.
    pub anomaly_type: AnomalyType,
    /// Severity score in [0, 1]; pattern detectors use fixed values
    /// (0.7 for oscillation, 0.8 for stagnation).
    pub severity: f64,
    /// UTC time at which the anomaly was detected.
    pub timestamp: DateTime<Utc>,
    /// Measurement context captured at detection time.
    pub context: AnomalyContext,
}
425
/// Types of gradient anomalies.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum AnomalyType {
    /// Gradient norm jumped well above the baseline mean.
    SuddenSpike,
    /// Gradient norm fell well below the baseline mean.
    SuddenDrop,
    /// Alternating spikes and drops in recent history.
    Oscillation,
    /// Gradient norm stuck at (nearly) the same value across steps.
    Stagnation,
    /// Erratic behavior.
    /// NOTE(review): never produced by the detectors in this module; it is
    /// only consumed when converting to the general anomaly type.
    Chaos,
}
435
/// Context information captured alongside an anomaly.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AnomalyContext {
    /// Training step at which the anomaly occurred.
    pub step: usize,
    /// Observed gradient norm that triggered the anomaly.
    pub gradient_norm: f64,
    /// Expected (5th percentile, 95th percentile) range from the baseline;
    /// pattern-based detectors currently fill in a (0.0, 1.0) placeholder.
    pub expected_range: (f64, f64),
    /// Magnitude of the deviation (|z-score| for statistical anomalies;
    /// oscillation count for oscillation anomalies; 0.0 for stagnation).
    pub deviation_magnitude: f64,
}
444
/// Baseline statistics of a layer's gradient norms, used for z-score tests.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BaselineGradientStats {
    /// Arithmetic mean of the baseline samples.
    pub mean: f64,
    /// Population standard deviation of the baseline samples.
    pub std: f64,
    /// Median of the baseline samples.
    pub median: f64,
    /// 95th percentile of the baseline samples.
    pub percentile_95: f64,
    /// 5th percentile of the baseline samples.
    pub percentile_5: f64,
    /// Number of samples the baseline was computed from.
    pub samples: usize,
}
455
/// Summary of anomaly detection results for one layer (or all layers).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AnomalySummary {
    /// Layer the summary covers, or `None` when it spans all layers.
    pub layer_name: Option<String>,
    /// Total number of anomalies included in the summary.
    pub total_anomalies: usize,
    /// Count of included anomalies per type.
    pub anomaly_type_counts: HashMap<AnomalyType, usize>,
    /// Mean severity across included anomalies (0.0 when there are none).
    pub average_severity: f64,
    /// Whether anomaly severity has been rising, falling, or steady recently.
    pub recent_trend: AnomalyTrend,
    /// Human-readable suggestions derived from the anomaly mix.
    pub recommendations: Vec<String>,
    /// The same anomalies converted to the shared `Anomaly` representation.
    pub anomalies: Vec<Anomaly>,
}
467
/// Trend in anomaly severity over recent history.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AnomalyTrend {
    /// Recent anomalies are more severe than older ones.
    Increasing,
    /// No significant change in severity.
    Stable,
    /// Recent anomalies are less severe than older ones.
    Decreasing,
}
</AnomalyTrend>