rust_threat_detector/
anomaly_detection.rs

1//! # Anomaly Detection Engine
2//!
3//! Statistical and time-series based anomaly detection for identifying
4//! unusual patterns in security logs and metrics.
5
6use crate::{LogEntry, ThreatAlert, ThreatCategory, ThreatSeverity};
7use chrono::{DateTime, Utc};
8use serde::{Deserialize, Serialize};
9use std::collections::{HashMap, VecDeque};
10
/// Statistical strategy used to decide whether a sample is anomalous.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum DetectionMethod {
    /// Distance from the historical mean, measured in standard deviations.
    ZScore,
    /// Distance from the average of the most recent window of samples.
    MovingAverage,
    /// Distance from an exponentially weighted moving average.
    ExponentialSmoothing,
    /// Position relative to fences derived from the inter-quartile range.
    IQR,
}
23
24/// Time series metric for tracking
25#[derive(Debug, Clone, Serialize, Deserialize)]
26pub struct TimeSeries {
27    pub name: String,
28    pub values: VecDeque<f64>,
29    pub timestamps: VecDeque<DateTime<Utc>>,
30    pub max_size: usize,
31}
32
33impl TimeSeries {
34    /// Create new time series with maximum history
35    pub fn new(name: String, max_size: usize) -> Self {
36        Self {
37            name,
38            values: VecDeque::with_capacity(max_size),
39            timestamps: VecDeque::with_capacity(max_size),
40            max_size,
41        }
42    }
43
44    /// Add value to time series
45    pub fn add(&mut self, value: f64, timestamp: DateTime<Utc>) {
46        if self.values.len() >= self.max_size {
47            self.values.pop_front();
48            self.timestamps.pop_front();
49        }
50        self.values.push_back(value);
51        self.timestamps.push_back(timestamp);
52    }
53
54    /// Calculate mean
55    pub fn mean(&self) -> f64 {
56        if self.values.is_empty() {
57            return 0.0;
58        }
59        self.values.iter().sum::<f64>() / self.values.len() as f64
60    }
61
62    /// Calculate standard deviation
63    pub fn std_dev(&self) -> f64 {
64        if self.values.len() < 2 {
65            return 0.0;
66        }
67        let mean = self.mean();
68        let variance = self.values.iter().map(|x| (x - mean).powi(2)).sum::<f64>()
69            / (self.values.len() - 1) as f64;
70        variance.sqrt()
71    }
72
73    /// Calculate moving average over window
74    pub fn moving_average(&self, window_size: usize) -> f64 {
75        if self.values.is_empty() {
76            return 0.0;
77        }
78        let window = window_size.min(self.values.len());
79        let start = self.values.len().saturating_sub(window);
80        self.values.iter().skip(start).sum::<f64>() / window as f64
81    }
82
83    /// Get percentile value
84    pub fn percentile(&self, p: f64) -> f64 {
85        if self.values.is_empty() {
86            return 0.0;
87        }
88        let mut sorted: Vec<f64> = self.values.iter().copied().collect();
89        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
90        let index = ((p / 100.0) * (sorted.len() - 1) as f64).round() as usize;
91        sorted[index]
92    }
93
94    /// Calculate IQR (Inter-Quartile Range)
95    pub fn iqr(&self) -> (f64, f64, f64) {
96        let q1 = self.percentile(25.0);
97        let q3 = self.percentile(75.0);
98        let iqr = q3 - q1;
99        (q1, q3, iqr)
100    }
101}
102
103/// Anomaly detection engine
104pub struct AnomalyDetector {
105    metrics: HashMap<String, TimeSeries>,
106    z_score_threshold: f64,
107    iqr_multiplier: f64,
108    moving_avg_window: usize,
109    smoothing_alpha: f64,
110}
111
112impl AnomalyDetector {
113    /// Create new anomaly detector with default parameters
114    pub fn new() -> Self {
115        Self {
116            metrics: HashMap::new(),
117            z_score_threshold: 3.0, // 3 sigma
118            iqr_multiplier: 1.5,    // Standard IQR multiplier
119            moving_avg_window: 10,
120            smoothing_alpha: 0.3, // Exponential smoothing factor
121        }
122    }
123
124    /// Create with custom parameters
125    pub fn with_params(
126        z_score_threshold: f64,
127        iqr_multiplier: f64,
128        moving_avg_window: usize,
129        smoothing_alpha: f64,
130    ) -> Self {
131        Self {
132            metrics: HashMap::new(),
133            z_score_threshold,
134            iqr_multiplier,
135            moving_avg_window,
136            smoothing_alpha,
137        }
138    }
139
140    /// Track metric value
141    pub fn track_metric(&mut self, name: &str, value: f64, timestamp: DateTime<Utc>) {
142        let metric = self
143            .metrics
144            .entry(name.to_string())
145            .or_insert_with(|| TimeSeries::new(name.to_string(), 1000));
146        metric.add(value, timestamp);
147    }
148
149    /// Detect anomaly using specified method
150    pub fn detect(
151        &self,
152        metric_name: &str,
153        current_value: f64,
154        method: DetectionMethod,
155    ) -> Option<AnomalyResult> {
156        let metric = self.metrics.get(metric_name)?;
157
158        if metric.values.is_empty() {
159            return None;
160        }
161
162        match method {
163            DetectionMethod::ZScore => self.detect_zscore(metric, current_value),
164            DetectionMethod::MovingAverage => self.detect_moving_avg(metric, current_value),
165            DetectionMethod::ExponentialSmoothing => self.detect_exponential(metric, current_value),
166            DetectionMethod::IQR => self.detect_iqr(metric, current_value),
167        }
168    }
169
170    /// Detect using Z-score (standard deviations from mean)
171    fn detect_zscore(&self, metric: &TimeSeries, value: f64) -> Option<AnomalyResult> {
172        if metric.values.len() < 10 {
173            return None; // Need sufficient history
174        }
175
176        let mean = metric.mean();
177        let std_dev = metric.std_dev();
178
179        if std_dev == 0.0 {
180            return None; // No variation
181        }
182
183        let z_score = (value - mean).abs() / std_dev;
184
185        if z_score > self.z_score_threshold {
186            Some(AnomalyResult {
187                metric_name: metric.name.clone(),
188                current_value: value,
189                expected_value: mean,
190                deviation: z_score,
191                method: DetectionMethod::ZScore,
192                severity: self.calculate_severity(z_score, self.z_score_threshold),
193                description: format!(
194                    "Value {:.2} deviates {:.2} standard deviations from mean {:.2}",
195                    value, z_score, mean
196                ),
197            })
198        } else {
199            None
200        }
201    }
202
203    /// Detect using moving average
204    fn detect_moving_avg(&self, metric: &TimeSeries, value: f64) -> Option<AnomalyResult> {
205        if metric.values.len() < self.moving_avg_window {
206            return None;
207        }
208
209        let moving_avg = metric.moving_average(self.moving_avg_window);
210        let std_dev = metric.std_dev();
211
212        if std_dev == 0.0 {
213            return None;
214        }
215
216        let deviation = (value - moving_avg).abs() / std_dev;
217
218        if deviation > self.z_score_threshold {
219            Some(AnomalyResult {
220                metric_name: metric.name.clone(),
221                current_value: value,
222                expected_value: moving_avg,
223                deviation,
224                method: DetectionMethod::MovingAverage,
225                severity: self.calculate_severity(deviation, self.z_score_threshold),
226                description: format!(
227                    "Value {:.2} deviates from moving average {:.2} by {:.2} std devs",
228                    value, moving_avg, deviation
229                ),
230            })
231        } else {
232            None
233        }
234    }
235
236    /// Detect using exponential smoothing
237    fn detect_exponential(&self, metric: &TimeSeries, value: f64) -> Option<AnomalyResult> {
238        if metric.values.is_empty() {
239            return None;
240        }
241
242        // Calculate exponentially weighted moving average
243        let mut ewma = metric.values[0];
244        for &v in metric.values.iter().skip(1) {
245            ewma = self.smoothing_alpha * v + (1.0 - self.smoothing_alpha) * ewma;
246        }
247
248        let std_dev = metric.std_dev();
249        if std_dev == 0.0 {
250            return None;
251        }
252
253        let deviation = (value - ewma).abs() / std_dev;
254
255        if deviation > self.z_score_threshold {
256            Some(AnomalyResult {
257                metric_name: metric.name.clone(),
258                current_value: value,
259                expected_value: ewma,
260                deviation,
261                method: DetectionMethod::ExponentialSmoothing,
262                severity: self.calculate_severity(deviation, self.z_score_threshold),
263                description: format!(
264                    "Value {:.2} deviates from exponential moving average {:.2}",
265                    value, ewma
266                ),
267            })
268        } else {
269            None
270        }
271    }
272
273    /// Detect using IQR (Inter-Quartile Range)
274    fn detect_iqr(&self, metric: &TimeSeries, value: f64) -> Option<AnomalyResult> {
275        if metric.values.len() < 10 {
276            return None;
277        }
278
279        let (q1, q3, iqr) = metric.iqr();
280        let lower_bound = q1 - self.iqr_multiplier * iqr;
281        let upper_bound = q3 + self.iqr_multiplier * iqr;
282
283        if value < lower_bound || value > upper_bound {
284            let deviation = if value < lower_bound {
285                (lower_bound - value) / iqr
286            } else {
287                (value - upper_bound) / iqr
288            };
289
290            Some(AnomalyResult {
291                metric_name: metric.name.clone(),
292                current_value: value,
293                expected_value: (q1 + q3) / 2.0,
294                deviation,
295                method: DetectionMethod::IQR,
296                severity: self.calculate_severity(deviation, 1.0),
297                description: format!(
298                    "Value {:.2} outside IQR bounds [{:.2}, {:.2}]",
299                    value, lower_bound, upper_bound
300                ),
301            })
302        } else {
303            None
304        }
305    }
306
307    /// Calculate severity based on deviation
308    fn calculate_severity(&self, deviation: f64, threshold: f64) -> ThreatSeverity {
309        let ratio = deviation / threshold;
310        if ratio > 3.0 {
311            ThreatSeverity::Critical
312        } else if ratio > 2.0 {
313            ThreatSeverity::High
314        } else if ratio > 1.5 {
315            ThreatSeverity::Medium
316        } else {
317            ThreatSeverity::Low
318        }
319    }
320
321    /// Analyze log for metric anomalies
322    pub fn analyze_log(&mut self, log: &LogEntry) -> Vec<ThreatAlert> {
323        let mut alerts = Vec::new();
324
325        // Extract metrics from log metadata
326        for (key, value_str) in &log.metadata {
327            if let Ok(value) = value_str.parse::<f64>() {
328                let metric_name = format!("log.{}", key);
329                self.track_metric(&metric_name, value, log.timestamp);
330
331                // Try all detection methods
332                for method in &[
333                    DetectionMethod::ZScore,
334                    DetectionMethod::MovingAverage,
335                    DetectionMethod::IQR,
336                ] {
337                    if let Some(anomaly) = self.detect(&metric_name, value, *method) {
338                        alerts.push(anomaly.to_threat_alert(log));
339                        break; // Only generate one alert per metric
340                    }
341                }
342            }
343        }
344
345        alerts
346    }
347
348    /// Get metric statistics
349    pub fn get_metric(&self, name: &str) -> Option<&TimeSeries> {
350        self.metrics.get(name)
351    }
352
353    /// Get all tracked metrics
354    pub fn get_all_metrics(&self) -> Vec<&str> {
355        self.metrics.keys().map(|s| s.as_str()).collect()
356    }
357
358    /// Clear old data from metrics
359    pub fn clear_old_data(&mut self, before: DateTime<Utc>) {
360        for metric in self.metrics.values_mut() {
361            while let Some(&timestamp) = metric.timestamps.front() {
362                if timestamp < before {
363                    metric.timestamps.pop_front();
364                    metric.values.pop_front();
365                } else {
366                    break;
367                }
368            }
369        }
370    }
371}
372
373impl Default for AnomalyDetector {
374    fn default() -> Self {
375        Self::new()
376    }
377}
378
379/// Result of anomaly detection
380#[derive(Debug, Clone)]
381pub struct AnomalyResult {
382    pub metric_name: String,
383    pub current_value: f64,
384    pub expected_value: f64,
385    pub deviation: f64,
386    pub method: DetectionMethod,
387    pub severity: ThreatSeverity,
388    pub description: String,
389}
390
391impl AnomalyResult {
392    /// Convert to ThreatAlert
393    pub fn to_threat_alert(&self, source_log: &LogEntry) -> ThreatAlert {
394        ThreatAlert {
395            alert_id: format!("ANOMALY-{}", chrono::Utc::now().timestamp()),
396            timestamp: Utc::now(),
397            severity: self.severity,
398            category: ThreatCategory::AnomalousActivity,
399            description: format!(
400                "Statistical anomaly in {}: {}",
401                self.metric_name, self.description
402            ),
403            source_log: format!("{} - {}", source_log.timestamp, source_log.message),
404            indicators: vec![
405                format!("Current: {:.2}", self.current_value),
406                format!("Expected: {:.2}", self.expected_value),
407                format!("Deviation: {:.2}", self.deviation),
408                format!("Method: {:?}", self.method),
409            ],
410            recommended_action:
411                "Investigate metric anomaly, review related logs, check for system issues"
412                    .to_string(),
413            threat_score: self.calculate_threat_score(),
414            correlated_alerts: vec![],
415        }
416    }
417
418    fn calculate_threat_score(&self) -> u32 {
419        let base_score = match self.severity {
420            ThreatSeverity::Info => 10,
421            ThreatSeverity::Low => 25,
422            ThreatSeverity::Medium => 50,
423            ThreatSeverity::High => 75,
424            ThreatSeverity::Critical => 95,
425        };
426
427        // Adjust based on deviation magnitude
428        let deviation_bonus = (self.deviation * 2.0).min(20.0) as u32;
429        (base_score + deviation_bonus).min(100)
430    }
431}
432
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Duration;
    use std::collections::HashMap;

    /// Builds a series named "test" (capacity 100) holding `samples` in order.
    fn series_of(samples: impl IntoIterator<Item = f64>) -> TimeSeries {
        let mut series = TimeSeries::new("test".to_string(), 100);
        for sample in samples {
            series.add(sample, Utc::now());
        }
        series
    }

    /// Records `samples` in `detector` under the metric name "test_metric".
    fn feed(detector: &mut AnomalyDetector, samples: impl IntoIterator<Item = f64>) {
        for sample in samples {
            detector.track_metric("test_metric", sample, Utc::now());
        }
    }

    /// Builds a log entry from 192.168.1.1 / user "test" stamped with the
    /// current time.
    fn log_entry(event_type: &str, message: &str, metadata: HashMap<String, String>) -> LogEntry {
        LogEntry {
            timestamp: Utc::now(),
            source_ip: Some("192.168.1.1".to_string()),
            user: Some("test".to_string()),
            event_type: event_type.to_string(),
            message: message.to_string(),
            metadata,
        }
    }

    /// Builds metadata holding a single `request_count` field.
    fn request_count(count: &str) -> HashMap<String, String> {
        std::iter::once(("request_count".to_string(), count.to_string())).collect()
    }

    #[test]
    fn test_time_series_mean() {
        let ts = series_of(vec![10.0, 20.0, 30.0]);

        assert_eq!(ts.mean(), 20.0);
    }

    #[test]
    fn test_time_series_std_dev() {
        let ts = series_of((1..=10_i32).map(f64::from));

        let std_dev = ts.std_dev();
        assert!(std_dev > 0.0);
        assert!(std_dev < 4.0); // Approximate std dev for 1-10
    }

    #[test]
    fn test_time_series_moving_average() {
        let ts = series_of(vec![10.0, 20.0, 30.0, 40.0]);

        // Only the last two samples count: (30 + 40) / 2
        assert_eq!(ts.moving_average(2), 35.0);
    }

    #[test]
    fn test_time_series_percentile() {
        let ts = series_of((1..=100_i32).map(f64::from));

        assert_eq!(ts.percentile(0.0), 1.0);
        assert_eq!(ts.percentile(100.0), 100.0);
        let median = ts.percentile(50.0);
        assert!((49.0..=52.0).contains(&median));
    }

    #[test]
    fn test_time_series_iqr() {
        let ts = series_of((1..=100_i32).map(f64::from));

        let (q1, q3, iqr) = ts.iqr();
        assert!((24.0..=27.0).contains(&q1));
        assert!((73.0..=77.0).contains(&q3));
        assert!((48.0..=52.0).contains(&iqr));
    }

    #[test]
    fn test_zscore_detection() {
        let mut detector = AnomalyDetector::new();

        // Baseline: 100, 101, ..., 119
        feed(&mut detector, (0..20_i32).map(|i| 100.0 + f64::from(i)));

        // A value inside the baseline range must not be reported.
        assert!(detector
            .detect("test_metric", 110.0, DetectionMethod::ZScore)
            .is_none());

        // A value far outside the baseline must be reported.
        let result = detector.detect("test_metric", 500.0, DetectionMethod::ZScore);
        assert!(result.is_some());
        let anomaly = result.unwrap();
        assert_eq!(anomaly.metric_name, "test_metric");
        assert_eq!(anomaly.current_value, 500.0);
    }

    #[test]
    fn test_moving_average_detection() {
        let mut detector = AnomalyDetector::new();
        feed(&mut detector, (0..15_i32).map(|i| 50.0 + f64::from(i)));

        let result = detector.detect("test_metric", 200.0, DetectionMethod::MovingAverage);
        assert!(result.is_some());
    }

    #[test]
    fn test_iqr_detection() {
        let mut detector = AnomalyDetector::new();

        // Baseline: 10, 20, ..., 200
        feed(&mut detector, (1..=20_i32).map(|i| f64::from(i) * 10.0));

        // Outlier
        assert!(detector
            .detect("test_metric", 1000.0, DetectionMethod::IQR)
            .is_some());

        // Value in the middle of the baseline
        assert!(detector
            .detect("test_metric", 105.0, DetectionMethod::IQR)
            .is_none());
    }

    #[test]
    fn test_exponential_smoothing() {
        let mut detector = AnomalyDetector::with_params(3.0, 1.5, 10, 0.3);
        feed(&mut detector, (0..20_i32).map(|i| 100.0 + f64::from(i)));

        let result = detector.detect("test_metric", 500.0, DetectionMethod::ExponentialSmoothing);
        assert!(result.is_some());
    }

    #[test]
    fn test_severity_calculation() {
        let detector = AnomalyDetector::new();
        let severity_of = |deviation: f64| detector.calculate_severity(deviation, 3.0);

        assert_eq!(severity_of(10.0), ThreatSeverity::Critical);
        assert_eq!(severity_of(6.5), ThreatSeverity::High);
        assert_eq!(severity_of(4.8), ThreatSeverity::Medium);
        assert_eq!(severity_of(3.2), ThreatSeverity::Low);
    }

    #[test]
    fn test_analyze_log() {
        let mut detector = AnomalyDetector::new();

        // Baseline: twenty identical readings.
        for _ in 0..20 {
            let log = log_entry("metric", "Normal traffic", request_count("100"));
            detector.analyze_log(&log);
        }

        // A reading two orders of magnitude above the baseline.
        let spike = log_entry("metric", "Spike in traffic", request_count("10000"));

        let alerts = detector.analyze_log(&spike);
        assert!(!alerts.is_empty());
        assert_eq!(alerts[0].category, ThreatCategory::AnomalousActivity);
    }

    #[test]
    fn test_clear_old_data() {
        let mut detector = AnomalyDetector::new();

        let old_time = Utc::now() - Duration::hours(2);
        let new_time = Utc::now();

        detector.track_metric("test", 10.0, old_time);
        detector.track_metric("test", 20.0, new_time);
        assert_eq!(detector.get_metric("test").unwrap().values.len(), 2);

        // Cut off between the two samples: only the newer one survives.
        let cutoff = Utc::now() - Duration::hours(1);
        detector.clear_old_data(cutoff);

        let metric = detector.get_metric("test").unwrap();
        assert_eq!(metric.values.len(), 1);
        assert_eq!(metric.values[0], 20.0);
    }

    #[test]
    fn test_get_all_metrics() {
        let mut detector = AnomalyDetector::new();

        detector.track_metric("metric1", 10.0, Utc::now());
        detector.track_metric("metric2", 20.0, Utc::now());
        detector.track_metric("metric3", 30.0, Utc::now());

        let metrics = detector.get_all_metrics();
        assert_eq!(metrics.len(), 3);
        for expected in ["metric1", "metric2", "metric3"].iter() {
            assert!(metrics.contains(expected));
        }
    }

    #[test]
    fn test_anomaly_to_threat_alert() {
        let anomaly = AnomalyResult {
            metric_name: "test_metric".to_string(),
            current_value: 500.0,
            expected_value: 100.0,
            deviation: 5.0,
            method: DetectionMethod::ZScore,
            severity: ThreatSeverity::High,
            description: "Test anomaly".to_string(),
        };
        let log = log_entry("test", "test message", HashMap::new());

        let alert = anomaly.to_threat_alert(&log);
        assert_eq!(alert.severity, ThreatSeverity::High);
        assert_eq!(alert.category, ThreatCategory::AnomalousActivity);
        assert!(alert.threat_score > 0);
    }
}