// rust_threat_detector/ml_scoring.rs

//! Machine Learning-based Threat Scoring v2.0
//!
//! Provides advanced threat scoring using statistical models and
//! feature engineering for improved threat detection accuracy.

use chrono::{DateTime, Datelike, Timelike, Utc};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, VecDeque};
9
/// Feature vector for ML scoring.
///
/// Values are raw as extracted; call `normalize()` to map them into the
/// 0-1 range the linear model expects.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ThreatFeatures {
    // Time-based features
    /// Hour of the event, 0-23 before normalization.
    pub hour_of_day: f64,
    /// ISO weekday, 1 (Monday) - 7 (Sunday), before normalization.
    pub day_of_week: f64,
    /// 1.0 if the event occurred on Saturday/Sunday, else 0.0.
    pub is_weekend: f64,
    /// 1.0 if the event occurred on a weekday with hour in 9..=17, else 0.0.
    pub is_business_hours: f64,

    // Volume-based features
    /// Number of events observed in the last hour.
    pub event_count_1h: f64,
    /// Number of events observed in the last 24 hours.
    pub event_count_24h: f64,
    /// Fraction of events that failed (0-1).
    pub failed_ratio: f64,
    /// Count of distinct event sources.
    pub unique_sources: f64,

    // Behavioral features
    /// Rate of events (events per minute over the last hour).
    pub velocity_score: f64,
    /// Randomness of activity, derived from source diversity.
    pub entropy_score: f64,
    /// Deviation from the entity's learned baseline (0-1).
    pub deviation_score: f64,
    /// Number of anomaly flags raised during extraction.
    pub anomaly_indicators: f64,

    // Contextual features
    /// Geographic risk, 0-100 before normalization.
    pub geo_risk_score: f64,
    /// Criticality of the targeted asset, 0-100 before normalization.
    pub asset_criticality: f64,
    /// Risk attributed to the user, 0-100 before normalization.
    pub user_risk_score: f64,
    /// Risk attributed to the network, 0-100 before normalization.
    pub network_risk_score: f64,
}
37
38impl ThreatFeatures {
39    /// Create empty feature vector
40    pub fn new() -> Self {
41        Self {
42            hour_of_day: 0.0,
43            day_of_week: 0.0,
44            is_weekend: 0.0,
45            is_business_hours: 0.0,
46            event_count_1h: 0.0,
47            event_count_24h: 0.0,
48            failed_ratio: 0.0,
49            unique_sources: 0.0,
50            velocity_score: 0.0,
51            entropy_score: 0.0,
52            deviation_score: 0.0,
53            anomaly_indicators: 0.0,
54            geo_risk_score: 0.0,
55            asset_criticality: 0.0,
56            user_risk_score: 0.0,
57            network_risk_score: 0.0,
58        }
59    }
60
61    /// Convert features to vector for model input
62    pub fn to_vector(&self) -> Vec<f64> {
63        vec![
64            self.hour_of_day,
65            self.day_of_week,
66            self.is_weekend,
67            self.is_business_hours,
68            self.event_count_1h,
69            self.event_count_24h,
70            self.failed_ratio,
71            self.unique_sources,
72            self.velocity_score,
73            self.entropy_score,
74            self.deviation_score,
75            self.anomaly_indicators,
76            self.geo_risk_score,
77            self.asset_criticality,
78            self.user_risk_score,
79            self.network_risk_score,
80        ]
81    }
82
83    /// Normalize features to 0-1 range
84    pub fn normalize(&mut self) {
85        self.hour_of_day /= 24.0;
86        self.day_of_week /= 7.0;
87        // is_weekend and is_business_hours are already 0-1
88        self.event_count_1h = (self.event_count_1h / 1000.0).min(1.0);
89        self.event_count_24h = (self.event_count_24h / 10000.0).min(1.0);
90        // failed_ratio is already 0-1
91        self.unique_sources = (self.unique_sources / 100.0).min(1.0);
92        self.velocity_score = (self.velocity_score / 100.0).min(1.0);
93        // entropy_score, deviation_score are typically 0-1
94        self.anomaly_indicators = (self.anomaly_indicators / 10.0).min(1.0);
95        // risk scores should be 0-100, normalize to 0-1
96        self.geo_risk_score /= 100.0;
97        self.asset_criticality /= 100.0;
98        self.user_risk_score /= 100.0;
99        self.network_risk_score /= 100.0;
100    }
101}
102
impl Default for ThreatFeatures {
    /// Equivalent to [`ThreatFeatures::new`]: an all-zero feature vector.
    fn default() -> Self {
        Self::new()
    }
}
108
/// Linear model weights (pre-trained).
#[derive(Debug, Clone)]
pub struct ModelWeights {
    /// One weight per entry of `ThreatFeatures::to_vector`, same order.
    pub feature_weights: Vec<f64>,
    /// Constant offset added before the sigmoid.
    pub bias: f64,
    /// Decision threshold on the sigmoid output.
    pub threshold: f64,
}

impl ModelWeights {
    /// Default weights hand-tuned from security expertise.
    ///
    /// Positive weights push the score up (more suspicious); negative
    /// weights pull it down.
    pub fn default_security_model() -> Self {
        let feature_weights = vec![
            0.05,  // hour_of_day: unusual hours increase risk
            0.02,  // day_of_week
            0.10,  // is_weekend: weekend activity is suspicious
            -0.05, // is_business_hours: business hours reduce risk
            0.15,  // event_count_1h: high volume is suspicious
            0.10,  // event_count_24h
            0.25,  // failed_ratio: high failure rate is very suspicious
            0.12,  // unique_sources: many sources are suspicious
            0.18,  // velocity_score: rapid activity is suspicious
            0.20,  // entropy_score: randomness is suspicious
            0.22,  // deviation_score: deviation from baseline
            0.25,  // anomaly_indicators
            0.15,  // geo_risk_score
            0.10,  // asset_criticality
            0.18,  // user_risk_score
            0.12,  // network_risk_score
        ];

        ModelWeights {
            feature_weights,
            bias: 0.1,
            threshold: 0.5,
        }
    }
}

impl Default for ModelWeights {
    /// The hand-tuned security model is the default.
    fn default() -> Self {
        Self::default_security_model()
    }
}
150
/// ML-based threat scorer.
pub struct MLThreatScorer {
    /// Linear model weights used by `score`.
    weights: ModelWeights,
    /// Recent feature vectors per entity.
    /// NOTE(review): initialized in the constructors but never read or
    /// written by any method visible in this file — confirm whether
    /// history tracking is still planned or this field is dead.
    feature_history: HashMap<String, VecDeque<ThreatFeatures>>,
    /// Per-entity baseline statistics used for deviation scoring.
    baseline_stats: HashMap<String, BaselineStats>,
    /// Cap on per-entity history length.
    /// NOTE(review): set to 1000 in the constructors but currently unused.
    max_history: usize,
}
158
/// Baseline statistics for anomaly detection.
#[derive(Debug, Clone)]
pub struct BaselineStats {
    /// Running mean of the observed event rate.
    pub mean_event_rate: f64,
    /// Accumulated sum of squared deviations (Welford M2 term).
    pub std_event_rate: f64,
    /// Running mean of the failed-event ratio.
    pub mean_failed_ratio: f64,
    /// Hours of the day considered normal for this entity.
    pub typical_hours: Vec<u32>,
    /// Number of observations folded in so far.
    pub sample_count: usize,
}

impl BaselineStats {
    /// Start from a generic prior: ~10 events with spread 5, 5% failures,
    /// and 09:00-17:00 as the typical activity window.
    pub fn new() -> Self {
        BaselineStats {
            mean_event_rate: 10.0,
            std_event_rate: 5.0,
            mean_failed_ratio: 0.05,
            typical_hours: (9..18).collect(),
            sample_count: 0,
        }
    }

    /// Fold one observation into the running statistics.
    pub fn update(&mut self, event_rate: f64, failed_ratio: f64, hour: u32) {
        self.sample_count += 1;
        let count = self.sample_count as f64;

        // Welford-style running update: nudge the mean toward the sample,
        // accumulate squared deviation into `std_event_rate`.
        let previous_mean = self.mean_event_rate;
        let delta = event_rate - previous_mean;
        self.mean_event_rate = previous_mean + delta / count;
        self.std_event_rate += delta * (event_rate - self.mean_event_rate);

        self.mean_failed_ratio += (failed_ratio - self.mean_failed_ratio) / count;

        // After a short warm-up, widen the set of hours considered typical.
        let already_known = self.typical_hours.contains(&hour);
        if self.sample_count > 10 && !already_known {
            self.typical_hours.push(hour);
        }
    }

    /// Deviation of `event_rate` from the baseline, scaled to 0-1.
    ///
    /// The z-score is capped at three standard deviations and divided by
    /// three; standard deviations below 1.0 are floored to 1.0 so tiny
    /// fluctuations are not exaggerated.
    pub fn calculate_deviation(&self, event_rate: f64) -> f64 {
        if self.std_event_rate == 0.0 {
            return 0.0;
        }
        let samples = self.sample_count.max(1) as f64;
        let std_dev = (self.std_event_rate / samples).sqrt();
        let z_score = (event_rate - self.mean_event_rate) / std_dev.max(1.0);
        z_score.abs().min(3.0) / 3.0
    }
}

impl Default for BaselineStats {
    /// Equivalent to [`BaselineStats::new`].
    fn default() -> Self {
        Self::new()
    }
}
212
/// Threat score result.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ThreatScore {
    /// Sigmoid model output in (0, 1); higher means more threatening.
    pub score: f64,
    /// Confidence in the score, capped at 0.95.
    pub confidence: f64,
    /// Categorical classification derived from `score`.
    pub risk_level: RiskLevel,
    /// Strongest factors behind the score (at most five), sorted by
    /// absolute contribution, descending.
    pub contributing_factors: Vec<ContributingFactor>,
    /// When the score was produced.
    pub timestamp: DateTime<Utc>,
}
222
223/// Risk level classification
224#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
225pub enum RiskLevel {
226    Minimal,
227    Low,
228    Medium,
229    High,
230    Critical,
231}
232
233impl RiskLevel {
234    pub fn from_score(score: f64) -> Self {
235        match score {
236            s if s >= 0.9 => RiskLevel::Critical,
237            s if s >= 0.7 => RiskLevel::High,
238            s if s >= 0.5 => RiskLevel::Medium,
239            s if s >= 0.3 => RiskLevel::Low,
240            _ => RiskLevel::Minimal,
241        }
242    }
243}
244
/// Factor contributing to a threat score.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContributingFactor {
    /// Human-readable feature name (e.g. "Failure Rate").
    pub name: String,
    /// Normalized feature value (0-1).
    pub value: f64,
    /// Signed `weight * value` term this factor added to the raw score.
    pub contribution: f64,
    /// Short explanation of what the contribution means.
    pub description: String,
}
253
254impl MLThreatScorer {
255    /// Create new ML threat scorer
256    pub fn new() -> Self {
257        Self {
258            weights: ModelWeights::default(),
259            feature_history: HashMap::new(),
260            baseline_stats: HashMap::new(),
261            max_history: 1000,
262        }
263    }
264
265    /// Create with custom weights
266    pub fn with_weights(weights: ModelWeights) -> Self {
267        Self {
268            weights,
269            feature_history: HashMap::new(),
270            baseline_stats: HashMap::new(),
271            max_history: 1000,
272        }
273    }
274
275    /// Extract features from event data
276    pub fn extract_features(
277        &mut self,
278        entity_id: &str,
279        timestamp: DateTime<Utc>,
280        event_count_1h: usize,
281        event_count_24h: usize,
282        failed_count: usize,
283        total_count: usize,
284        unique_sources: usize,
285        source_ip: Option<&str>,
286        asset_criticality: f64,
287    ) -> ThreatFeatures {
288        let hour = timestamp.format("%H").to_string().parse::<f64>().unwrap_or(0.0);
289        let day = timestamp.format("%u").to_string().parse::<f64>().unwrap_or(1.0);
290        let is_weekend = if day >= 6.0 { 1.0 } else { 0.0 };
291        let is_business_hours = if hour >= 9.0 && hour <= 17.0 && day < 6.0 { 1.0 } else { 0.0 };
292
293        let failed_ratio = if total_count > 0 {
294            failed_count as f64 / total_count as f64
295        } else {
296            0.0
297        };
298
299        // Calculate velocity (events per minute in last hour)
300        let velocity = event_count_1h as f64 / 60.0;
301
302        // Get or create baseline
303        let baseline = self.baseline_stats
304            .entry(entity_id.to_string())
305            .or_insert_with(BaselineStats::new);
306
307        let deviation = baseline.calculate_deviation(event_count_1h as f64);
308
309        // Update baseline
310        baseline.update(event_count_1h as f64, failed_ratio, hour as u32);
311
312        // Calculate entropy (simplified - based on source diversity)
313        let entropy = if unique_sources > 1 {
314            (unique_sources as f64).ln() / 10.0_f64.ln()
315        } else {
316            0.0
317        };
318
319        // Geo risk based on IP (simplified)
320        let geo_risk = match source_ip {
321            Some(ip) if ip.starts_with("10.") || ip.starts_with("192.168.") => 10.0,
322            Some(_) => 50.0, // External IP
323            None => 30.0,    // Unknown
324        };
325
326        // User risk placeholder
327        let user_risk = if failed_ratio > 0.3 { 70.0 } else { 20.0 };
328
329        // Network risk placeholder
330        let network_risk = if unique_sources > 10 { 60.0 } else { 20.0 };
331
332        // Count anomaly indicators
333        let mut anomaly_count = 0.0;
334        if is_weekend > 0.0 && event_count_1h > 100 { anomaly_count += 1.0; }
335        if failed_ratio > 0.5 { anomaly_count += 2.0; }
336        if deviation > 0.5 { anomaly_count += 1.0; }
337        if velocity > 10.0 { anomaly_count += 1.0; }
338
339        ThreatFeatures {
340            hour_of_day: hour,
341            day_of_week: day,
342            is_weekend,
343            is_business_hours,
344            event_count_1h: event_count_1h as f64,
345            event_count_24h: event_count_24h as f64,
346            failed_ratio,
347            unique_sources: unique_sources as f64,
348            velocity_score: velocity,
349            entropy_score: entropy,
350            deviation_score: deviation,
351            anomaly_indicators: anomaly_count,
352            geo_risk_score: geo_risk,
353            asset_criticality,
354            user_risk_score: user_risk,
355            network_risk_score: network_risk,
356        }
357    }
358
359    /// Calculate threat score from features
360    pub fn score(&self, features: &ThreatFeatures) -> ThreatScore {
361        let mut normalized = features.clone();
362        normalized.normalize();
363
364        let feature_vec = normalized.to_vector();
365        let mut raw_score = self.weights.bias;
366        let mut contributing_factors = Vec::new();
367
368        let factor_names = [
369            "Hour of Day", "Day of Week", "Weekend Activity", "Business Hours",
370            "Event Volume (1h)", "Event Volume (24h)", "Failure Rate", "Unique Sources",
371            "Velocity", "Entropy", "Baseline Deviation", "Anomaly Indicators",
372            "Geographic Risk", "Asset Criticality", "User Risk", "Network Risk",
373        ];
374
375        for (i, (&value, &weight)) in feature_vec.iter().zip(self.weights.feature_weights.iter()).enumerate() {
376            let contribution = value * weight;
377            raw_score += contribution;
378
379            if contribution.abs() > 0.01 {
380                contributing_factors.push(ContributingFactor {
381                    name: factor_names.get(i).unwrap_or(&"Unknown").to_string(),
382                    value,
383                    contribution,
384                    description: self.describe_contribution(factor_names.get(i).unwrap_or(&""), value),
385                });
386            }
387        }
388
389        // Apply sigmoid for probability-like output
390        let score = 1.0 / (1.0 + (-raw_score).exp());
391
392        // Sort contributing factors by absolute contribution
393        contributing_factors.sort_by(|a, b| {
394            b.contribution.abs().partial_cmp(&a.contribution.abs()).unwrap()
395        });
396        contributing_factors.truncate(5);
397
398        // Calculate confidence based on sample size
399        let confidence = self.calculate_confidence(features);
400
401        ThreatScore {
402            score,
403            confidence,
404            risk_level: RiskLevel::from_score(score),
405            contributing_factors,
406            timestamp: Utc::now(),
407        }
408    }
409
410    /// Describe what a contribution means
411    fn describe_contribution(&self, name: &str, value: f64) -> String {
412        match name {
413            "Failure Rate" if value > 0.5 => "High failure rate indicates potential brute force".to_string(),
414            "Failure Rate" => "Normal failure rate".to_string(),
415            "Weekend Activity" if value > 0.0 => "Activity during weekend (unusual)".to_string(),
416            "Velocity" if value > 0.5 => "Rapid event generation (suspicious)".to_string(),
417            "Baseline Deviation" if value > 0.5 => "Significant deviation from normal behavior".to_string(),
418            "Geographic Risk" if value > 0.5 => "External or suspicious source location".to_string(),
419            "Anomaly Indicators" if value > 0.0 => "Multiple anomaly flags detected".to_string(),
420            _ => format!("{} score: {:.2}", name, value),
421        }
422    }
423
424    /// Calculate confidence in the score
425    fn calculate_confidence(&self, features: &ThreatFeatures) -> f64 {
426        // Higher confidence with more data
427        let event_factor = (features.event_count_24h / 100.0).min(1.0);
428
429        // Higher confidence when features are clear
430        let clarity_factor = if features.failed_ratio > 0.5 || features.deviation_score > 0.5 {
431            0.9
432        } else if features.failed_ratio < 0.1 && features.deviation_score < 0.2 {
433            0.9
434        } else {
435            0.6
436        };
437
438        (event_factor * 0.4 + clarity_factor * 0.6).min(0.95)
439    }
440
441    /// Score multiple features in batch
442    pub fn score_batch(&self, features_list: &[ThreatFeatures]) -> Vec<ThreatScore> {
443        features_list.iter().map(|f| self.score(f)).collect()
444    }
445
446    /// Get top threats from batch
447    pub fn get_top_threats<'a>(&self, scores: &'a [ThreatScore], min_level: RiskLevel, limit: usize) -> Vec<&'a ThreatScore> {
448        let mut filtered: Vec<&'a ThreatScore> = scores
449            .iter()
450            .filter(|s| s.risk_level as u8 >= min_level as u8)
451            .collect();
452
453        filtered.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
454        filtered.truncate(limit);
455        filtered
456    }
457
458    /// Clear old baseline data
459    pub fn clear_old_baselines(&mut self, min_samples: usize) {
460        self.baseline_stats.retain(|_, stats| stats.sample_count >= min_samples);
461    }
462}
463
impl Default for MLThreatScorer {
    /// Equivalent to [`MLThreatScorer::new`]: default security model weights.
    fn default() -> Self {
        Self::new()
    }
}
469
#[cfg(test)]
mod tests {
    use super::*;

    /// Extracted features should reflect the raw inputs:
    /// failed_ratio = failed / total, counts carried through unchanged.
    #[test]
    fn test_feature_extraction() {
        let mut scorer = MLThreatScorer::new();

        let features = scorer.extract_features(
            "user1",
            Utc::now(),
            100,  // event_count_1h
            500,  // event_count_24h
            20,   // failed_count
            100,  // total_count
            5,    // unique_sources
            Some("192.168.1.100"),
            50.0, // asset_criticality
        );

        assert_eq!(features.failed_ratio, 0.2);
        assert_eq!(features.event_count_1h, 100.0);
    }

    /// High-risk feature values should push the sigmoid output well above
    /// the all-zero baseline.
    #[test]
    fn test_threat_scoring() {
        let scorer = MLThreatScorer::new();

        // High-risk features - use more extreme values
        let mut features = ThreatFeatures::new();
        features.failed_ratio = 0.95;
        features.velocity_score = 100.0;
        features.deviation_score = 1.0;
        features.anomaly_indicators = 10.0;
        features.geo_risk_score = 80.0;
        features.user_risk_score = 90.0;

        let score = scorer.score(&features);
        // Just verify the score is elevated
        assert!(score.score > 0.3);
    }

    /// All-zero features should never land in the critical range; only the
    /// model bias contributes to the raw score.
    #[test]
    fn test_low_risk_scoring() {
        let scorer = MLThreatScorer::new();

        // Low-risk features - all minimal
        let features = ThreatFeatures::new(); // All defaults to 0

        let score = scorer.score(&features);
        // The bias will give a base score, so just check it's reasonable
        assert!(score.score < 0.8); // Not in critical range
    }

    /// After learning a steady baseline, an abnormal rate must deviate at
    /// least as much as the on-baseline rate.
    #[test]
    fn test_baseline_deviation() {
        let mut baseline = BaselineStats::new();

        // Establish baseline with consistent values
        for _ in 0..100 {
            baseline.update(10.0, 0.05, 10);
        }

        // Normal event rate should have lower deviation than abnormal
        let normal_deviation = baseline.calculate_deviation(10.0);
        let abnormal_deviation = baseline.calculate_deviation(100.0);

        // The abnormal rate should produce higher deviation than normal
        assert!(abnormal_deviation >= normal_deviation);
    }

    /// Score thresholds: >=0.9 Critical, >=0.7 High, >=0.5 Medium,
    /// >=0.3 Low, else Minimal.
    #[test]
    fn test_risk_level_classification() {
        assert_eq!(RiskLevel::from_score(0.95), RiskLevel::Critical);
        assert_eq!(RiskLevel::from_score(0.75), RiskLevel::High);
        assert_eq!(RiskLevel::from_score(0.55), RiskLevel::Medium);
        assert_eq!(RiskLevel::from_score(0.35), RiskLevel::Low);
        assert_eq!(RiskLevel::from_score(0.15), RiskLevel::Minimal);
    }

    /// Dominant features should appear among the reported contributing
    /// factors, with failure rate surfaced by name.
    #[test]
    fn test_contributing_factors() {
        let scorer = MLThreatScorer::new();

        let mut features = ThreatFeatures::new();
        features.failed_ratio = 0.9;
        features.deviation_score = 0.8;

        let score = scorer.score(&features);
        assert!(!score.contributing_factors.is_empty());

        // Should include failure rate as top contributor
        assert!(score.contributing_factors.iter().any(|f| f.name.contains("Failure")));
    }

    /// Batch scoring returns one ThreatScore per input feature vector.
    #[test]
    fn test_batch_scoring() {
        let scorer = MLThreatScorer::new();

        let features_list: Vec<ThreatFeatures> = (0..5)
            .map(|i| {
                let mut f = ThreatFeatures::new();
                f.failed_ratio = i as f64 * 0.2;
                f
            })
            .collect();

        let scores = scorer.score_batch(&features_list);
        assert_eq!(scores.len(), 5);
    }

    /// Top-threat selection filters by minimum risk level, sorts by score
    /// descending, and truncates to the requested limit.
    #[test]
    fn test_top_threats() {
        let scorer = MLThreatScorer::new();

        let scores: Vec<ThreatScore> = (0..10)
            .map(|i| ThreatScore {
                score: i as f64 / 10.0,
                confidence: 0.8,
                risk_level: RiskLevel::from_score(i as f64 / 10.0),
                contributing_factors: vec![],
                timestamp: Utc::now(),
            })
            .collect();

        let top = scorer.get_top_threats(&scores, RiskLevel::Medium, 3);
        assert_eq!(top.len(), 3);
        assert!(top[0].score >= top[1].score);
    }
}