1use crate::{LogEntry, ThreatAlert, ThreatCategory, ThreatSeverity};
7use chrono::{DateTime, Utc};
8use serde::{Deserialize, Serialize};
9use std::collections::{HashMap, VecDeque};
10
/// Statistical technique used to decide whether a sample is anomalous.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DetectionMethod {
    /// Distance from the series mean, measured in standard deviations.
    ZScore,
    /// Distance from the mean of the most recent window of samples.
    MovingAverage,
    /// Distance from an exponentially weighted moving average.
    ExponentialSmoothing,
    /// Position relative to the interquartile-range fences.
    IQR,
}
23
24#[derive(Debug, Clone, Serialize, Deserialize)]
26pub struct TimeSeries {
27 pub name: String,
28 pub values: VecDeque<f64>,
29 pub timestamps: VecDeque<DateTime<Utc>>,
30 pub max_size: usize,
31}
32
33impl TimeSeries {
34 pub fn new(name: String, max_size: usize) -> Self {
36 Self {
37 name,
38 values: VecDeque::with_capacity(max_size),
39 timestamps: VecDeque::with_capacity(max_size),
40 max_size,
41 }
42 }
43
44 pub fn add(&mut self, value: f64, timestamp: DateTime<Utc>) {
46 if self.values.len() >= self.max_size {
47 self.values.pop_front();
48 self.timestamps.pop_front();
49 }
50 self.values.push_back(value);
51 self.timestamps.push_back(timestamp);
52 }
53
54 pub fn mean(&self) -> f64 {
56 if self.values.is_empty() {
57 return 0.0;
58 }
59 self.values.iter().sum::<f64>() / self.values.len() as f64
60 }
61
62 pub fn std_dev(&self) -> f64 {
64 if self.values.len() < 2 {
65 return 0.0;
66 }
67 let mean = self.mean();
68 let variance = self.values.iter().map(|x| (x - mean).powi(2)).sum::<f64>()
69 / (self.values.len() - 1) as f64;
70 variance.sqrt()
71 }
72
73 pub fn moving_average(&self, window_size: usize) -> f64 {
75 if self.values.is_empty() {
76 return 0.0;
77 }
78 let window = window_size.min(self.values.len());
79 let start = self.values.len().saturating_sub(window);
80 self.values.iter().skip(start).sum::<f64>() / window as f64
81 }
82
83 pub fn percentile(&self, p: f64) -> f64 {
85 if self.values.is_empty() {
86 return 0.0;
87 }
88 let mut sorted: Vec<f64> = self.values.iter().copied().collect();
89 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
90 let index = ((p / 100.0) * (sorted.len() - 1) as f64).round() as usize;
91 sorted[index]
92 }
93
94 pub fn iqr(&self) -> (f64, f64, f64) {
96 let q1 = self.percentile(25.0);
97 let q3 = self.percentile(75.0);
98 let iqr = q3 - q1;
99 (q1, q3, iqr)
100 }
101}
102
103pub struct AnomalyDetector {
105 metrics: HashMap<String, TimeSeries>,
106 z_score_threshold: f64,
107 iqr_multiplier: f64,
108 moving_avg_window: usize,
109 smoothing_alpha: f64,
110}
111
112impl AnomalyDetector {
113 pub fn new() -> Self {
115 Self {
116 metrics: HashMap::new(),
117 z_score_threshold: 3.0, iqr_multiplier: 1.5, moving_avg_window: 10,
120 smoothing_alpha: 0.3, }
122 }
123
124 pub fn with_params(
126 z_score_threshold: f64,
127 iqr_multiplier: f64,
128 moving_avg_window: usize,
129 smoothing_alpha: f64,
130 ) -> Self {
131 Self {
132 metrics: HashMap::new(),
133 z_score_threshold,
134 iqr_multiplier,
135 moving_avg_window,
136 smoothing_alpha,
137 }
138 }
139
140 pub fn track_metric(&mut self, name: &str, value: f64, timestamp: DateTime<Utc>) {
142 let metric = self
143 .metrics
144 .entry(name.to_string())
145 .or_insert_with(|| TimeSeries::new(name.to_string(), 1000));
146 metric.add(value, timestamp);
147 }
148
149 pub fn detect(
151 &self,
152 metric_name: &str,
153 current_value: f64,
154 method: DetectionMethod,
155 ) -> Option<AnomalyResult> {
156 let metric = self.metrics.get(metric_name)?;
157
158 if metric.values.is_empty() {
159 return None;
160 }
161
162 match method {
163 DetectionMethod::ZScore => self.detect_zscore(metric, current_value),
164 DetectionMethod::MovingAverage => self.detect_moving_avg(metric, current_value),
165 DetectionMethod::ExponentialSmoothing => self.detect_exponential(metric, current_value),
166 DetectionMethod::IQR => self.detect_iqr(metric, current_value),
167 }
168 }
169
170 fn detect_zscore(&self, metric: &TimeSeries, value: f64) -> Option<AnomalyResult> {
172 if metric.values.len() < 10 {
173 return None; }
175
176 let mean = metric.mean();
177 let std_dev = metric.std_dev();
178
179 if std_dev == 0.0 {
180 return None; }
182
183 let z_score = (value - mean).abs() / std_dev;
184
185 if z_score > self.z_score_threshold {
186 Some(AnomalyResult {
187 metric_name: metric.name.clone(),
188 current_value: value,
189 expected_value: mean,
190 deviation: z_score,
191 method: DetectionMethod::ZScore,
192 severity: self.calculate_severity(z_score, self.z_score_threshold),
193 description: format!(
194 "Value {:.2} deviates {:.2} standard deviations from mean {:.2}",
195 value, z_score, mean
196 ),
197 })
198 } else {
199 None
200 }
201 }
202
203 fn detect_moving_avg(&self, metric: &TimeSeries, value: f64) -> Option<AnomalyResult> {
205 if metric.values.len() < self.moving_avg_window {
206 return None;
207 }
208
209 let moving_avg = metric.moving_average(self.moving_avg_window);
210 let std_dev = metric.std_dev();
211
212 if std_dev == 0.0 {
213 return None;
214 }
215
216 let deviation = (value - moving_avg).abs() / std_dev;
217
218 if deviation > self.z_score_threshold {
219 Some(AnomalyResult {
220 metric_name: metric.name.clone(),
221 current_value: value,
222 expected_value: moving_avg,
223 deviation,
224 method: DetectionMethod::MovingAverage,
225 severity: self.calculate_severity(deviation, self.z_score_threshold),
226 description: format!(
227 "Value {:.2} deviates from moving average {:.2} by {:.2} std devs",
228 value, moving_avg, deviation
229 ),
230 })
231 } else {
232 None
233 }
234 }
235
236 fn detect_exponential(&self, metric: &TimeSeries, value: f64) -> Option<AnomalyResult> {
238 if metric.values.is_empty() {
239 return None;
240 }
241
242 let mut ewma = metric.values[0];
244 for &v in metric.values.iter().skip(1) {
245 ewma = self.smoothing_alpha * v + (1.0 - self.smoothing_alpha) * ewma;
246 }
247
248 let std_dev = metric.std_dev();
249 if std_dev == 0.0 {
250 return None;
251 }
252
253 let deviation = (value - ewma).abs() / std_dev;
254
255 if deviation > self.z_score_threshold {
256 Some(AnomalyResult {
257 metric_name: metric.name.clone(),
258 current_value: value,
259 expected_value: ewma,
260 deviation,
261 method: DetectionMethod::ExponentialSmoothing,
262 severity: self.calculate_severity(deviation, self.z_score_threshold),
263 description: format!(
264 "Value {:.2} deviates from exponential moving average {:.2}",
265 value, ewma
266 ),
267 })
268 } else {
269 None
270 }
271 }
272
273 fn detect_iqr(&self, metric: &TimeSeries, value: f64) -> Option<AnomalyResult> {
275 if metric.values.len() < 10 {
276 return None;
277 }
278
279 let (q1, q3, iqr) = metric.iqr();
280 let lower_bound = q1 - self.iqr_multiplier * iqr;
281 let upper_bound = q3 + self.iqr_multiplier * iqr;
282
283 if value < lower_bound || value > upper_bound {
284 let deviation = if value < lower_bound {
285 (lower_bound - value) / iqr
286 } else {
287 (value - upper_bound) / iqr
288 };
289
290 Some(AnomalyResult {
291 metric_name: metric.name.clone(),
292 current_value: value,
293 expected_value: (q1 + q3) / 2.0,
294 deviation,
295 method: DetectionMethod::IQR,
296 severity: self.calculate_severity(deviation, 1.0),
297 description: format!(
298 "Value {:.2} outside IQR bounds [{:.2}, {:.2}]",
299 value, lower_bound, upper_bound
300 ),
301 })
302 } else {
303 None
304 }
305 }
306
307 fn calculate_severity(&self, deviation: f64, threshold: f64) -> ThreatSeverity {
309 let ratio = deviation / threshold;
310 if ratio > 3.0 {
311 ThreatSeverity::Critical
312 } else if ratio > 2.0 {
313 ThreatSeverity::High
314 } else if ratio > 1.5 {
315 ThreatSeverity::Medium
316 } else {
317 ThreatSeverity::Low
318 }
319 }
320
321 pub fn analyze_log(&mut self, log: &LogEntry) -> Vec<ThreatAlert> {
323 let mut alerts = Vec::new();
324
325 for (key, value_str) in &log.metadata {
327 if let Ok(value) = value_str.parse::<f64>() {
328 let metric_name = format!("log.{}", key);
329 self.track_metric(&metric_name, value, log.timestamp);
330
331 for method in &[
333 DetectionMethod::ZScore,
334 DetectionMethod::MovingAverage,
335 DetectionMethod::IQR,
336 ] {
337 if let Some(anomaly) = self.detect(&metric_name, value, *method) {
338 alerts.push(anomaly.to_threat_alert(log));
339 break; }
341 }
342 }
343 }
344
345 alerts
346 }
347
348 pub fn get_metric(&self, name: &str) -> Option<&TimeSeries> {
350 self.metrics.get(name)
351 }
352
353 pub fn get_all_metrics(&self) -> Vec<&str> {
355 self.metrics.keys().map(|s| s.as_str()).collect()
356 }
357
358 pub fn clear_old_data(&mut self, before: DateTime<Utc>) {
360 for metric in self.metrics.values_mut() {
361 while let Some(×tamp) = metric.timestamps.front() {
362 if timestamp < before {
363 metric.timestamps.pop_front();
364 metric.values.pop_front();
365 } else {
366 break;
367 }
368 }
369 }
370 }
371}
372
373impl Default for AnomalyDetector {
374 fn default() -> Self {
375 Self::new()
376 }
377}
378
379#[derive(Debug, Clone)]
381pub struct AnomalyResult {
382 pub metric_name: String,
383 pub current_value: f64,
384 pub expected_value: f64,
385 pub deviation: f64,
386 pub method: DetectionMethod,
387 pub severity: ThreatSeverity,
388 pub description: String,
389}
390
391impl AnomalyResult {
392 pub fn to_threat_alert(&self, source_log: &LogEntry) -> ThreatAlert {
394 ThreatAlert {
395 alert_id: format!("ANOMALY-{}", chrono::Utc::now().timestamp()),
396 timestamp: Utc::now(),
397 severity: self.severity,
398 category: ThreatCategory::AnomalousActivity,
399 description: format!(
400 "Statistical anomaly in {}: {}",
401 self.metric_name, self.description
402 ),
403 source_log: format!("{} - {}", source_log.timestamp, source_log.message),
404 indicators: vec![
405 format!("Current: {:.2}", self.current_value),
406 format!("Expected: {:.2}", self.expected_value),
407 format!("Deviation: {:.2}", self.deviation),
408 format!("Method: {:?}", self.method),
409 ],
410 recommended_action:
411 "Investigate metric anomaly, review related logs, check for system issues"
412 .to_string(),
413 threat_score: self.calculate_threat_score(),
414 correlated_alerts: vec![],
415 }
416 }
417
418 fn calculate_threat_score(&self) -> u32 {
419 let base_score = match self.severity {
420 ThreatSeverity::Info => 10,
421 ThreatSeverity::Low => 25,
422 ThreatSeverity::Medium => 50,
423 ThreatSeverity::High => 75,
424 ThreatSeverity::Critical => 95,
425 };
426
427 let deviation_bonus = (self.deviation * 2.0).min(20.0) as u32;
429 (base_score + deviation_bonus).min(100)
430 }
431}
432
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Duration;
    use std::collections::HashMap;

    /// Builds a 100-slot series named "test" holding `samples` in order.
    fn series_of<I>(samples: I) -> TimeSeries
    where
        I: IntoIterator<Item = f64>,
    {
        let mut series = TimeSeries::new("test".to_string(), 100);
        for sample in samples {
            series.add(sample, Utc::now());
        }
        series
    }

    /// Records `samples` under "test_metric" on `detector`.
    fn feed<I>(detector: &mut AnomalyDetector, samples: I)
    where
        I: IntoIterator<Item = f64>,
    {
        for sample in samples {
            detector.track_metric("test_metric", sample, Utc::now());
        }
    }

    /// Builds a log entry from the fixed test host and user.
    fn log_entry(event_type: &str, message: &str, metadata: HashMap<String, String>) -> LogEntry {
        LogEntry {
            timestamp: Utc::now(),
            source_ip: Some("192.168.1.1".to_string()),
            user: Some("test".to_string()),
            event_type: event_type.to_string(),
            message: message.to_string(),
            metadata,
        }
    }

    /// Builds a "metric" log whose only metadata entry is `request_count`.
    fn request_count_log(count: &str, message: &str) -> LogEntry {
        let mut metadata = HashMap::new();
        metadata.insert("request_count".to_string(), count.to_string());
        log_entry("metric", message, metadata)
    }

    #[test]
    fn test_time_series_mean() {
        let ts = series_of(vec![10.0, 20.0, 30.0]);

        assert_eq!(ts.mean(), 20.0);
    }

    #[test]
    fn test_time_series_std_dev() {
        let ts = series_of((1..=10).map(|i: i32| f64::from(i)));

        let std_dev = ts.std_dev();
        assert!(std_dev > 0.0);
        assert!(std_dev < 4.0);
    }

    #[test]
    fn test_time_series_moving_average() {
        let ts = series_of(vec![10.0, 20.0, 30.0, 40.0]);

        assert_eq!(ts.moving_average(2), 35.0);
    }

    #[test]
    fn test_time_series_percentile() {
        let ts = series_of((1..=100).map(|i: i32| f64::from(i)));

        assert_eq!(ts.percentile(0.0), 1.0);
        assert_eq!(ts.percentile(100.0), 100.0);
        let median = ts.percentile(50.0);
        assert!((49.0..=52.0).contains(&median));
    }

    #[test]
    fn test_time_series_iqr() {
        let ts = series_of((1..=100).map(|i: i32| f64::from(i)));

        let (q1, q3, iqr) = ts.iqr();
        assert!((24.0..=27.0).contains(&q1));
        assert!((73.0..=77.0).contains(&q3));
        assert!((48.0..=52.0).contains(&iqr));
    }

    #[test]
    fn test_zscore_detection() {
        let mut detector = AnomalyDetector::new();
        feed(&mut detector, (0..20).map(|i: i32| 100.0 + f64::from(i)));

        // A value inside the recorded range is not flagged.
        assert!(detector
            .detect("test_metric", 110.0, DetectionMethod::ZScore)
            .is_none());

        // A value far outside the recorded range is flagged.
        let outlier = detector.detect("test_metric", 500.0, DetectionMethod::ZScore);
        assert!(outlier.is_some());
        let anomaly = outlier.unwrap();
        assert_eq!(anomaly.metric_name, "test_metric");
        assert_eq!(anomaly.current_value, 500.0);
    }

    #[test]
    fn test_moving_average_detection() {
        let mut detector = AnomalyDetector::new();
        feed(&mut detector, (0..15).map(|i: i32| 50.0 + f64::from(i)));

        let result = detector.detect("test_metric", 200.0, DetectionMethod::MovingAverage);
        assert!(result.is_some());
    }

    #[test]
    fn test_iqr_detection() {
        let mut detector = AnomalyDetector::new();
        feed(&mut detector, (1..=20).map(|i: i32| f64::from(i) * 10.0));

        let far_out = detector.detect("test_metric", 1000.0, DetectionMethod::IQR);
        assert!(far_out.is_some());

        let in_range = detector.detect("test_metric", 105.0, DetectionMethod::IQR);
        assert!(in_range.is_none());
    }

    #[test]
    fn test_exponential_smoothing() {
        let mut detector = AnomalyDetector::with_params(3.0, 1.5, 10, 0.3);
        feed(&mut detector, (0..20).map(|i: i32| 100.0 + f64::from(i)));

        let result = detector.detect("test_metric", 500.0, DetectionMethod::ExponentialSmoothing);
        assert!(result.is_some());
    }

    #[test]
    fn test_severity_calculation() {
        let detector = AnomalyDetector::new();
        let cases = [
            (10.0, ThreatSeverity::Critical),
            (6.5, ThreatSeverity::High),
            (4.8, ThreatSeverity::Medium),
            (3.2, ThreatSeverity::Low),
        ];

        for (deviation, expected) in cases.iter() {
            assert_eq!(detector.calculate_severity(*deviation, 3.0), *expected);
        }
    }

    #[test]
    fn test_analyze_log() {
        let mut detector = AnomalyDetector::new();

        // Establish a steady baseline.
        for _ in 0..20 {
            detector.analyze_log(&request_count_log("100", "Normal traffic"));
        }

        // A sudden spike must raise an alert.
        let alerts = detector.analyze_log(&request_count_log("10000", "Spike in traffic"));
        assert!(!alerts.is_empty());
        assert_eq!(alerts[0].category, ThreatCategory::AnomalousActivity);
    }

    #[test]
    fn test_clear_old_data() {
        let mut detector = AnomalyDetector::new();

        let old_time = Utc::now() - Duration::hours(2);
        let new_time = Utc::now();
        detector.track_metric("test", 10.0, old_time);
        detector.track_metric("test", 20.0, new_time);
        assert_eq!(detector.get_metric("test").unwrap().values.len(), 2);

        let cutoff = Utc::now() - Duration::hours(1);
        detector.clear_old_data(cutoff);

        let remaining = &detector.get_metric("test").unwrap().values;
        assert_eq!(remaining.len(), 1);
        assert_eq!(remaining[0], 20.0);
    }

    #[test]
    fn test_get_all_metrics() {
        let mut detector = AnomalyDetector::new();
        let tracked = [("metric1", 10.0), ("metric2", 20.0), ("metric3", 30.0)];
        for (name, value) in tracked.iter() {
            detector.track_metric(name, *value, Utc::now());
        }

        let metrics = detector.get_all_metrics();
        assert_eq!(metrics.len(), 3);
        for (name, _) in tracked.iter() {
            assert!(metrics.contains(name));
        }
    }

    #[test]
    fn test_anomaly_to_threat_alert() {
        let anomaly = AnomalyResult {
            metric_name: "test_metric".to_string(),
            current_value: 500.0,
            expected_value: 100.0,
            deviation: 5.0,
            method: DetectionMethod::ZScore,
            severity: ThreatSeverity::High,
            description: "Test anomaly".to_string(),
        };
        let log = log_entry("test", "test message", HashMap::new());

        let alert = anomaly.to_threat_alert(&log);
        assert_eq!(alert.severity, ThreatSeverity::High);
        assert_eq!(alert.category, ThreatCategory::AnomalousActivity);
        assert!(alert.threat_score > 0);
    }
}