1use crate::anomaly_detector::{Anomaly, AnomalySeverity};
8use chrono::{DateTime, Utc};
9use serde::{Deserialize, Serialize};
10use std::collections::{HashMap, VecDeque};
11
/// Statistical detector for anomalous per-layer gradient norms.
///
/// Baselines are built per layer via `establish_baseline`; subsequent calls to
/// `detect_anomalies` compare incoming gradient norms against them. Detected
/// anomalies are kept in a bounded (1000-entry) rolling history that feeds the
/// pattern checks and the summaries.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GradientAnomalyDetector {
    // Master switch; when false, `detect_anomalies` returns no results.
    pub enabled: bool,
    // Detection sensitivity; higher values lower the z-score threshold used by
    // the statistical check (threshold = 2.0 + (1.0 - sensitivity) * 2.0).
    pub sensitivity: f64,
    // Size of the detection window in steps.
    // NOTE(review): stored but not read by any method visible in this file —
    // confirm whether callers elsewhere use it.
    pub detection_window: usize,
    // Rolling log of detected anomalies, bounded to 1000 entries.
    pub anomaly_history: VecDeque<GradientAnomaly>,
    // Per-layer baseline statistics keyed by layer name.
    pub baseline_statistics: HashMap<String, BaselineGradientStats>,
}
21
22impl Default for GradientAnomalyDetector {
23 fn default() -> Self {
24 Self {
25 enabled: true,
26 sensitivity: 0.8,
27 detection_window: 50,
28 anomaly_history: VecDeque::with_capacity(1000),
29 baseline_statistics: HashMap::new(),
30 }
31 }
32}
33
34impl GradientAnomalyDetector {
35 pub fn new(sensitivity: f64, window_size: usize) -> Self {
36 Self {
37 enabled: true,
38 sensitivity,
39 detection_window: window_size,
40 anomaly_history: VecDeque::with_capacity(1000),
41 baseline_statistics: HashMap::new(),
42 }
43 }
44
45 pub fn establish_baseline(&mut self, layer_name: &str, gradient_history: &[f64]) {
46 if gradient_history.len() < 10 {
47 return;
48 }
49
50 let mean = gradient_history.iter().sum::<f64>() / gradient_history.len() as f64;
51 let variance = gradient_history.iter().map(|&x| (x - mean).powi(2)).sum::<f64>()
52 / gradient_history.len() as f64;
53 let std = variance.sqrt();
54
55 let mut sorted_values = gradient_history.to_vec();
56 sorted_values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
57
58 let median_idx = sorted_values.len() / 2;
59 let median = if sorted_values.len() % 2 == 0 {
60 (sorted_values[median_idx - 1] + sorted_values[median_idx]) / 2.0
61 } else {
62 sorted_values[median_idx]
63 };
64
65 let percentile_5_idx = (sorted_values.len() as f64 * 0.05) as usize;
66 let percentile_95_idx = (sorted_values.len() as f64 * 0.95) as usize;
67
68 let baseline = BaselineGradientStats {
69 mean,
70 std,
71 median,
72 percentile_95: sorted_values[percentile_95_idx.min(sorted_values.len() - 1)],
73 percentile_5: sorted_values[percentile_5_idx],
74 samples: gradient_history.len(),
75 };
76
77 self.baseline_statistics.insert(layer_name.to_string(), baseline);
78 }
79
80 pub fn detect_anomalies(
81 &mut self,
82 layer_name: &str,
83 gradient_norm: f64,
84 step: usize,
85 ) -> Vec<GradientAnomaly> {
86 if !self.enabled {
87 return Vec::new();
88 }
89
90 let baseline = match self.baseline_statistics.get(layer_name) {
91 Some(baseline) => baseline,
92 None => return Vec::new(), };
94
95 let mut anomalies = Vec::new();
96
97 if let Some(anomaly) =
99 self.detect_statistical_anomaly(layer_name, gradient_norm, step, baseline)
100 {
101 anomalies.push(anomaly);
102 }
103
104 if let Some(anomaly) = self.detect_pattern_anomaly(layer_name, gradient_norm, step) {
106 anomalies.push(anomaly);
107 }
108
109 for anomaly in &anomalies {
111 if self.anomaly_history.len() >= 1000 {
112 self.anomaly_history.pop_front();
113 }
114 self.anomaly_history.push_back(anomaly.clone());
115 }
116
117 anomalies
118 }
119
120 fn detect_statistical_anomaly(
121 &self,
122 layer_name: &str,
123 gradient_norm: f64,
124 step: usize,
125 baseline: &BaselineGradientStats,
126 ) -> Option<GradientAnomaly> {
127 let z_score = (gradient_norm - baseline.mean) / baseline.std;
128 let threshold = 2.0 + (1.0 - self.sensitivity) * 2.0; if z_score.abs() > threshold {
131 let anomaly_type = if z_score > 0.0 {
132 if z_score > threshold * 1.5 {
133 AnomalyType::SuddenSpike
134 } else {
135 AnomalyType::SuddenSpike
136 }
137 } else {
138 AnomalyType::SuddenDrop
139 };
140
141 let severity = (z_score.abs() / threshold).min(1.0);
142
143 Some(GradientAnomaly {
144 layer_name: layer_name.to_string(),
145 anomaly_type,
146 severity,
147 timestamp: Utc::now(),
148 context: AnomalyContext {
149 step,
150 gradient_norm,
151 expected_range: (baseline.percentile_5, baseline.percentile_95),
152 deviation_magnitude: z_score.abs(),
153 },
154 })
155 } else {
156 None
157 }
158 }
159
160 fn detect_pattern_anomaly(
161 &self,
162 layer_name: &str,
163 gradient_norm: f64,
164 step: usize,
165 ) -> Option<GradientAnomaly> {
166 let recent_anomalies: Vec<&GradientAnomaly> = self
168 .anomaly_history
169 .iter()
170 .filter(|a| a.layer_name == layer_name)
171 .rev()
172 .take(10)
173 .collect();
174
175 if recent_anomalies.len() >= 3 {
176 let oscillation_count = recent_anomalies
178 .windows(2)
179 .filter(|pair| {
180 matches!(
181 (&pair[0].anomaly_type, &pair[1].anomaly_type),
182 (AnomalyType::SuddenSpike, AnomalyType::SuddenDrop)
183 | (AnomalyType::SuddenDrop, AnomalyType::SuddenSpike)
184 )
185 })
186 .count();
187
188 if oscillation_count >= 2 {
189 return Some(GradientAnomaly {
190 layer_name: layer_name.to_string(),
191 anomaly_type: AnomalyType::Oscillation,
192 severity: 0.7,
193 timestamp: Utc::now(),
194 context: AnomalyContext {
195 step,
196 gradient_norm,
197 expected_range: (0.0, 1.0), deviation_magnitude: oscillation_count as f64,
199 },
200 });
201 }
202 }
203
204 if recent_anomalies.len() >= 5 {
206 let all_similar = recent_anomalies.windows(2).all(|pair| {
207 (pair[0].context.gradient_norm - pair[1].context.gradient_norm).abs() < 1e-6
208 });
209
210 if all_similar {
211 return Some(GradientAnomaly {
212 layer_name: layer_name.to_string(),
213 anomaly_type: AnomalyType::Stagnation,
214 severity: 0.8,
215 timestamp: Utc::now(),
216 context: AnomalyContext {
217 step,
218 gradient_norm,
219 expected_range: (0.0, 1.0), deviation_magnitude: 0.0,
221 },
222 });
223 }
224 }
225
226 None
227 }
228
229 pub fn get_anomaly_summary(&self, layer_name: Option<&str>) -> AnomalySummary {
230 let filtered_anomalies: Vec<&GradientAnomaly> = match layer_name {
231 Some(name) => self.anomaly_history.iter().filter(|a| a.layer_name == name).collect(),
232 None => self.anomaly_history.iter().collect(),
233 };
234
235 let total_anomalies = filtered_anomalies.len();
236 let mut anomaly_type_counts = HashMap::new();
237 let mut severity_sum = 0.0;
238
239 for anomaly in &filtered_anomalies {
240 *anomaly_type_counts.entry(anomaly.anomaly_type.clone()).or_insert(0) += 1;
241 severity_sum += anomaly.severity;
242 }
243
244 let average_severity =
245 if total_anomalies > 0 { severity_sum / total_anomalies as f64 } else { 0.0 };
246
247 let anomalies: Vec<Anomaly> = filtered_anomalies
249 .iter()
250 .map(|gradient_anomaly| {
251 let severity = if gradient_anomaly.severity >= 0.8 {
252 AnomalySeverity::Critical
253 } else if gradient_anomaly.severity >= 0.6 {
254 AnomalySeverity::High
255 } else if gradient_anomaly.severity >= 0.3 {
256 AnomalySeverity::Medium
257 } else {
258 AnomalySeverity::Low
259 };
260
261 let general_anomaly_type = match gradient_anomaly.anomaly_type {
263 AnomalyType::SuddenSpike => {
264 crate::anomaly_detector::AnomalyType::GradientExplosion
265 },
266 AnomalyType::SuddenDrop => {
267 crate::anomaly_detector::AnomalyType::GradientVanishing
268 },
269 AnomalyType::Oscillation => {
270 crate::anomaly_detector::AnomalyType::NumericalInstability
271 },
272 AnomalyType::Stagnation => {
273 crate::anomaly_detector::AnomalyType::GradientVanishing
274 },
275 AnomalyType::Chaos => {
276 crate::anomaly_detector::AnomalyType::NumericalInstability
277 },
278 };
279
280 let description = format!(
281 "Gradient anomaly of type {:?} detected with severity {:.2}",
282 gradient_anomaly.anomaly_type, gradient_anomaly.severity
283 );
284
285 let mut metadata = HashMap::new();
286 metadata.insert(
287 "step".to_string(),
288 gradient_anomaly.context.step.to_string(),
289 );
290 metadata.insert(
291 "gradient_norm".to_string(),
292 gradient_anomaly.context.gradient_norm.to_string(),
293 );
294 metadata.insert(
295 "expected_range_min".to_string(),
296 gradient_anomaly.context.expected_range.0.to_string(),
297 );
298 metadata.insert(
299 "expected_range_max".to_string(),
300 gradient_anomaly.context.expected_range.1.to_string(),
301 );
302 metadata.insert(
303 "deviation_magnitude".to_string(),
304 gradient_anomaly.context.deviation_magnitude.to_string(),
305 );
306 metadata.insert(
307 "original_anomaly_type".to_string(),
308 format!("{:?}", gradient_anomaly.anomaly_type),
309 );
310
311 Anomaly {
312 anomaly_type: general_anomaly_type,
313 timestamp: gradient_anomaly.timestamp,
314 location: gradient_anomaly.layer_name.clone(),
315 description,
316 severity,
317 metadata,
318 }
319 })
320 .collect();
321
322 AnomalySummary {
323 layer_name: layer_name.map(|s| s.to_string()),
324 total_anomalies,
325 anomaly_type_counts,
326 average_severity,
327 recent_trend: self.analyze_recent_trend(&filtered_anomalies),
328 recommendations: self.generate_anomaly_recommendations(&filtered_anomalies),
329 anomalies,
330 }
331 }
332
333 fn analyze_recent_trend(&self, anomalies: &[&GradientAnomaly]) -> AnomalyTrend {
334 if anomalies.len() < 5 {
335 return AnomalyTrend::Stable;
336 }
337
338 let recent_anomalies: Vec<&GradientAnomaly> =
339 anomalies.iter().rev().take(10).cloned().collect();
340 let older_anomalies: Vec<&GradientAnomaly> =
341 anomalies.iter().rev().skip(10).take(10).cloned().collect();
342
343 if older_anomalies.is_empty() {
344 return AnomalyTrend::Stable;
345 }
346
347 let recent_avg_severity: f64 = recent_anomalies.iter().map(|a| a.severity).sum::<f64>()
348 / recent_anomalies.len() as f64;
349 let older_avg_severity: f64 =
350 older_anomalies.iter().map(|a| a.severity).sum::<f64>() / older_anomalies.len() as f64;
351
352 let trend_threshold = 0.1;
353 if recent_avg_severity > older_avg_severity + trend_threshold {
354 AnomalyTrend::Increasing
355 } else if recent_avg_severity < older_avg_severity - trend_threshold {
356 AnomalyTrend::Decreasing
357 } else {
358 AnomalyTrend::Stable
359 }
360 }
361
362 fn generate_anomaly_recommendations(&self, anomalies: &[&GradientAnomaly]) -> Vec<String> {
363 let mut recommendations = Vec::new();
364
365 let spike_count = anomalies
366 .iter()
367 .filter(|a| matches!(a.anomaly_type, AnomalyType::SuddenSpike))
368 .count();
369 let drop_count = anomalies
370 .iter()
371 .filter(|a| matches!(a.anomaly_type, AnomalyType::SuddenDrop))
372 .count();
373 let oscillation_count = anomalies
374 .iter()
375 .filter(|a| matches!(a.anomaly_type, AnomalyType::Oscillation))
376 .count();
377 let stagnation_count = anomalies
378 .iter()
379 .filter(|a| matches!(a.anomaly_type, AnomalyType::Stagnation))
380 .count();
381
382 if spike_count > 3 {
383 recommendations
384 .push("Consider reducing learning rate to prevent gradient explosion".to_string());
385 recommendations.push("Add gradient clipping to stabilize training".to_string());
386 }
387
388 if drop_count > 3 {
389 recommendations.push("Check for vanishing gradient issues".to_string());
390 recommendations
391 .push("Consider using residual connections or better initialization".to_string());
392 }
393
394 if oscillation_count > 2 {
395 recommendations.push("Reduce learning rate to dampen oscillations".to_string());
396 recommendations
397 .push("Consider using momentum or adaptive learning rate methods".to_string());
398 }
399
400 if stagnation_count > 2 {
401 recommendations.push(
402 "Learning may have plateaued - consider learning rate scheduling".to_string(),
403 );
404 recommendations
405 .push("Check for potential convergence or training data issues".to_string());
406 }
407
408 if recommendations.is_empty() {
409 recommendations.push("Gradient behavior appears normal".to_string());
410 }
411
412 recommendations
413 }
414}
415
/// A single detected gradient anomaly for one layer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GradientAnomaly {
    // Name of the layer the anomaly was observed on.
    pub layer_name: String,
    // Category of the anomaly (spike, drop, oscillation, ...).
    pub anomaly_type: AnomalyType,
    // Severity in `0.0..=1.0` (1.0 is most severe).
    pub severity: f64,
    // Wall-clock time (UTC) when the anomaly was detected.
    pub timestamp: DateTime<Utc>,
    // Measurement context captured at detection time.
    pub context: AnomalyContext,
}
425
/// Categories of gradient anomaly the detector can record.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum AnomalyType {
    // Gradient norm far above the baseline mean.
    SuddenSpike,
    // Gradient norm far below the baseline mean.
    SuddenDrop,
    // Alternating spikes and drops in quick succession.
    Oscillation,
    // Repeated near-identical gradient norms.
    Stagnation,
    // NOTE(review): not produced by any detection path visible in this file;
    // summaries map it to numerical instability. Confirm whether it is
    // emitted elsewhere in the crate.
    Chaos,
}
435
/// Measurement context captured alongside a detected anomaly.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AnomalyContext {
    // Training step at which the anomaly was observed.
    pub step: usize,
    // Observed gradient norm that accompanied the anomaly.
    pub gradient_norm: f64,
    // Expected (5th, 95th percentile) range from the baseline for statistical
    // anomalies; pattern anomalies store a placeholder (0.0, 1.0).
    pub expected_range: (f64, f64),
    // Magnitude of the deviation (|z-score| for statistical anomalies;
    // oscillation count for oscillation anomalies; 0.0 for stagnation).
    pub deviation_magnitude: f64,
}
444
/// Per-layer baseline statistics computed from a gradient-norm history.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BaselineGradientStats {
    // Arithmetic mean of the history.
    pub mean: f64,
    // Population standard deviation (variance divided by n).
    pub std: f64,
    // Median of the history.
    pub median: f64,
    // 95th-percentile value (nearest-rank, truncating index).
    pub percentile_95: f64,
    // 5th-percentile value (nearest-rank, truncating index).
    pub percentile_5: f64,
    // Number of samples the baseline was computed from.
    pub samples: usize,
}
455
/// Aggregated view of the anomaly history, optionally scoped to one layer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AnomalySummary {
    // Layer the summary is scoped to, or `None` for all layers.
    pub layer_name: Option<String>,
    // Total number of anomalies included in the summary.
    pub total_anomalies: usize,
    // Count of anomalies per anomaly type.
    pub anomaly_type_counts: HashMap<AnomalyType, usize>,
    // Mean severity across the included anomalies (0.0 when there are none).
    pub average_severity: f64,
    // Whether anomaly severity has recently been rising, stable, or falling.
    pub recent_trend: AnomalyTrend,
    // Human-readable training advice derived from the anomaly mix.
    pub recommendations: Vec<String>,
    // The anomalies converted to the crate-wide `Anomaly` representation.
    pub anomalies: Vec<Anomaly>,
}
467
/// Direction of change in recent anomaly severity.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AnomalyTrend {
    // Recent anomalies are noticeably more severe than older ones.
    Increasing,
    // No significant change in average severity.
    Stable,
    // Recent anomalies are noticeably less severe than older ones.
    Decreasing,
}