use chrono::{DateTime, Datelike, Timelike, Utc};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, VecDeque};

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ThreatFeatures {
    // Temporal features
    pub hour_of_day: f64,
    pub day_of_week: f64,
    pub is_weekend: f64,
    pub is_business_hours: f64,

    // Volume features
    pub event_count_1h: f64,
    pub event_count_24h: f64,
    pub failed_ratio: f64,
    pub unique_sources: f64,

    // Behavioral features
    pub velocity_score: f64,
    pub entropy_score: f64,
    pub deviation_score: f64,
    pub anomaly_indicators: f64,

    // Contextual risk features
    pub geo_risk_score: f64,
    pub asset_criticality: f64,
    pub user_risk_score: f64,
    pub network_risk_score: f64,
}

impl ThreatFeatures {
    pub fn new() -> Self {
        Self {
            hour_of_day: 0.0,
            day_of_week: 0.0,
            is_weekend: 0.0,
            is_business_hours: 0.0,
            event_count_1h: 0.0,
            event_count_24h: 0.0,
            failed_ratio: 0.0,
            unique_sources: 0.0,
            velocity_score: 0.0,
            entropy_score: 0.0,
            deviation_score: 0.0,
            anomaly_indicators: 0.0,
            geo_risk_score: 0.0,
            asset_criticality: 0.0,
            user_risk_score: 0.0,
            network_risk_score: 0.0,
        }
    }

    pub fn to_vector(&self) -> Vec<f64> {
        vec![
            self.hour_of_day,
            self.day_of_week,
            self.is_weekend,
            self.is_business_hours,
            self.event_count_1h,
            self.event_count_24h,
            self.failed_ratio,
            self.unique_sources,
            self.velocity_score,
            self.entropy_score,
            self.deviation_score,
            self.anomaly_indicators,
            self.geo_risk_score,
            self.asset_criticality,
            self.user_risk_score,
            self.network_risk_score,
        ]
    }

    /// Scale features into roughly [0, 1]: count-style features are capped at
    /// fixed ceilings, and 0-100 risk scores are divided by 100.
    pub fn normalize(&mut self) {
        self.hour_of_day /= 24.0;
        self.day_of_week /= 7.0;
        self.event_count_1h = (self.event_count_1h / 1000.0).min(1.0);
        self.event_count_24h = (self.event_count_24h / 10000.0).min(1.0);
        self.unique_sources = (self.unique_sources / 100.0).min(1.0);
        self.velocity_score = (self.velocity_score / 100.0).min(1.0);
        self.anomaly_indicators = (self.anomaly_indicators / 10.0).min(1.0);
        self.geo_risk_score /= 100.0;
        self.asset_criticality /= 100.0;
        self.user_risk_score /= 100.0;
        self.network_risk_score /= 100.0;
    }
}

impl Default for ThreatFeatures {
    fn default() -> Self {
        Self::new()
    }
}

#[derive(Debug, Clone)]
pub struct ModelWeights {
    pub feature_weights: Vec<f64>,
    pub bias: f64,
    pub threshold: f64,
}

impl ModelWeights {
    pub fn default_security_model() -> Self {
        Self {
            // One weight per ThreatFeatures field, in to_vector() order.
            feature_weights: vec![
                0.05, 0.02, 0.10, -0.05, // temporal
                0.15, 0.10, 0.25, 0.12, // volume
                0.18, 0.20, 0.22, 0.25, // behavioral
                0.15, 0.10, 0.18, 0.12, // contextual risk
            ],
            bias: 0.1,
            threshold: 0.5,
        }
    }
}

impl Default for ModelWeights {
    fn default() -> Self {
        Self::default_security_model()
    }
}

/// Scores threat likelihood from extracted features using a weighted linear
/// model passed through a sigmoid.
pub struct MLThreatScorer {
    weights: ModelWeights,
    feature_history: HashMap<String, VecDeque<ThreatFeatures>>,
    baseline_stats: HashMap<String, BaselineStats>,
    max_history: usize,
}

/// Per-entity running statistics used to measure deviation from normal behavior.
#[derive(Debug, Clone)]
pub struct BaselineStats {
    pub mean_event_rate: f64,
    pub std_event_rate: f64,
    pub mean_failed_ratio: f64,
    pub typical_hours: Vec<u32>,
    pub sample_count: usize,
}

impl BaselineStats {
    pub fn new() -> Self {
        Self {
            mean_event_rate: 10.0,
            std_event_rate: 5.0,
            mean_failed_ratio: 0.05,
            typical_hours: (9..18).collect(),
            sample_count: 0,
        }
    }

    /// Welford-style online update; `std_event_rate` accumulates the running
    /// sum of squared deviations (M2), not the standard deviation itself.
    pub fn update(&mut self, event_rate: f64, failed_ratio: f64, hour: u32) {
        self.sample_count += 1;
        let n = self.sample_count as f64;

        let old_mean = self.mean_event_rate;
        self.mean_event_rate += (event_rate - old_mean) / n;
        self.std_event_rate += (event_rate - old_mean) * (event_rate - self.mean_event_rate);

        self.mean_failed_ratio += (failed_ratio - self.mean_failed_ratio) / n;

        if !self.typical_hours.contains(&hour) && self.sample_count > 10 {
            self.typical_hours.push(hour);
        }
    }

    /// Absolute z-score against the learned baseline (standard deviation
    /// floored at 1.0), capped at 3 sigma and rescaled to [0, 1].
    pub fn calculate_deviation(&self, event_rate: f64) -> f64 {
        if self.std_event_rate == 0.0 {
            return 0.0;
        }
        let std = (self.std_event_rate / self.sample_count.max(1) as f64).sqrt();
        ((event_rate - self.mean_event_rate) / std.max(1.0)).abs().min(3.0) / 3.0
    }
}

impl Default for BaselineStats {
    fn default() -> Self {
        Self::new()
    }
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ThreatScore {
    pub score: f64,
    pub confidence: f64,
    pub risk_level: RiskLevel,
    pub contributing_factors: Vec<ContributingFactor>,
    pub timestamp: DateTime<Utc>,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RiskLevel {
    Minimal,
    Low,
    Medium,
    High,
    Critical,
}

impl RiskLevel {
    pub fn from_score(score: f64) -> Self {
        match score {
            s if s >= 0.9 => RiskLevel::Critical,
            s if s >= 0.7 => RiskLevel::High,
            s if s >= 0.5 => RiskLevel::Medium,
            s if s >= 0.3 => RiskLevel::Low,
            _ => RiskLevel::Minimal,
        }
    }
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContributingFactor {
    pub name: String,
    pub value: f64,
    pub contribution: f64,
    pub description: String,
}

impl MLThreatScorer {
    pub fn new() -> Self {
        Self {
            weights: ModelWeights::default(),
            feature_history: HashMap::new(),
            baseline_stats: HashMap::new(),
            max_history: 1000,
        }
    }

    pub fn with_weights(weights: ModelWeights) -> Self {
        Self {
            weights,
            feature_history: HashMap::new(),
            baseline_stats: HashMap::new(),
            max_history: 1000,
        }
    }

    pub fn extract_features(
        &mut self,
        entity_id: &str,
        timestamp: DateTime<Utc>,
        event_count_1h: usize,
        event_count_24h: usize,
        failed_count: usize,
        total_count: usize,
        unique_sources: usize,
        source_ip: Option<&str>,
        asset_criticality: f64,
    ) -> ThreatFeatures {
        // Temporal context
        let hour = timestamp.hour() as f64;
        let day = timestamp.weekday().number_from_monday() as f64;
        let is_weekend = if day >= 6.0 { 1.0 } else { 0.0 };
        let is_business_hours = if hour >= 9.0 && hour <= 17.0 && day < 6.0 { 1.0 } else { 0.0 };

        let failed_ratio = if total_count > 0 {
            failed_count as f64 / total_count as f64
        } else {
            0.0
        };

        // Events per minute over the last hour
        let velocity = event_count_1h as f64 / 60.0;

        let baseline = self.baseline_stats
            .entry(entity_id.to_string())
            .or_insert_with(BaselineStats::new);

        // Measure deviation before folding the new observation into the baseline.
        let deviation = baseline.calculate_deviation(event_count_1h as f64);
        baseline.update(event_count_1h as f64, failed_ratio, hour as u32);

        // Source diversity on a log10 scale
        let entropy = if unique_sources > 1 {
            (unique_sources as f64).ln() / 10.0_f64.ln()
        } else {
            0.0
        };

        // Crude source heuristic: 10.x and 192.168.x prefixes are treated as
        // internal/low risk, other addresses as moderate, missing IP as unknown.
        let geo_risk = match source_ip {
            Some(ip) if ip.starts_with("10.") || ip.starts_with("192.168.") => 10.0,
            Some(_) => 50.0,
            None => 30.0,
        };

        let user_risk = if failed_ratio > 0.3 { 70.0 } else { 20.0 };
        let network_risk = if unique_sources > 10 { 60.0 } else { 20.0 };

        // Count simple anomaly flags
        let mut anomaly_count = 0.0;
        if is_weekend > 0.0 && event_count_1h > 100 { anomaly_count += 1.0; }
        if failed_ratio > 0.5 { anomaly_count += 2.0; }
        if deviation > 0.5 { anomaly_count += 1.0; }
        if velocity > 10.0 { anomaly_count += 1.0; }

        ThreatFeatures {
            hour_of_day: hour,
            day_of_week: day,
            is_weekend,
            is_business_hours,
            event_count_1h: event_count_1h as f64,
            event_count_24h: event_count_24h as f64,
            failed_ratio,
            unique_sources: unique_sources as f64,
            velocity_score: velocity,
            entropy_score: entropy,
            deviation_score: deviation,
            anomaly_indicators: anomaly_count,
            geo_risk_score: geo_risk,
            asset_criticality,
            user_risk_score: user_risk,
            network_risk_score: network_risk,
        }
    }

    pub fn score(&self, features: &ThreatFeatures) -> ThreatScore {
        let mut normalized = features.clone();
        normalized.normalize();

        let feature_vec = normalized.to_vector();
        let mut raw_score = self.weights.bias;
        let mut contributing_factors = Vec::new();

        let factor_names = [
            "Hour of Day", "Day of Week", "Weekend Activity", "Business Hours",
            "Event Volume (1h)", "Event Volume (24h)", "Failure Rate", "Unique Sources",
            "Velocity", "Entropy", "Baseline Deviation", "Anomaly Indicators",
            "Geographic Risk", "Asset Criticality", "User Risk", "Network Risk",
        ];

        // Weighted sum of normalized features plus bias
        for (i, (&value, &weight)) in feature_vec.iter().zip(self.weights.feature_weights.iter()).enumerate() {
            let contribution = value * weight;
            raw_score += contribution;

            if contribution.abs() > 0.01 {
                contributing_factors.push(ContributingFactor {
                    name: factor_names.get(i).unwrap_or(&"Unknown").to_string(),
                    value,
                    contribution,
                    description: self.describe_contribution(factor_names.get(i).unwrap_or(&""), value),
                });
            }
        }

        // Sigmoid squashes the raw linear score into (0, 1)
        let score = 1.0 / (1.0 + (-raw_score).exp());

        // Keep only the five largest contributors by absolute contribution
        contributing_factors.sort_by(|a, b| {
            b.contribution.abs().partial_cmp(&a.contribution.abs()).unwrap()
        });
        contributing_factors.truncate(5);

        let confidence = self.calculate_confidence(features);

        ThreatScore {
            score,
            confidence,
            risk_level: RiskLevel::from_score(score),
            contributing_factors,
            timestamp: Utc::now(),
        }
    }

    fn describe_contribution(&self, name: &str, value: f64) -> String {
        match name {
            "Failure Rate" if value > 0.5 => "High failure rate indicates potential brute force".to_string(),
            "Failure Rate" => "Normal failure rate".to_string(),
            "Weekend Activity" if value > 0.0 => "Activity during weekend (unusual)".to_string(),
            "Velocity" if value > 0.5 => "Rapid event generation (suspicious)".to_string(),
            "Baseline Deviation" if value > 0.5 => "Significant deviation from normal behavior".to_string(),
            "Geographic Risk" if value > 0.5 => "External or suspicious source location".to_string(),
            "Anomaly Indicators" if value > 0.0 => "Multiple anomaly flags detected".to_string(),
            _ => format!("{} score: {:.2}", name, value),
        }
    }

    fn calculate_confidence(&self, features: &ThreatFeatures) -> f64 {
        let event_factor = (features.event_count_24h / 100.0).min(1.0);

        // Confidence is higher when the signal is clearly benign or clearly
        // suspicious, lower in the ambiguous middle.
        let clarity_factor = if features.failed_ratio > 0.5 || features.deviation_score > 0.5 {
            0.9
        } else if features.failed_ratio < 0.1 && features.deviation_score < 0.2 {
            0.9
        } else {
            0.6
        };

        (event_factor * 0.4 + clarity_factor * 0.6).min(0.95)
    }

    pub fn score_batch(&self, features_list: &[ThreatFeatures]) -> Vec<ThreatScore> {
        features_list.iter().map(|f| self.score(f)).collect()
    }

    pub fn get_top_threats<'a>(&self, scores: &'a [ThreatScore], min_level: RiskLevel, limit: usize) -> Vec<&'a ThreatScore> {
        // RiskLevel variants are declared in ascending severity, so the
        // discriminant value gives a usable ordering.
        let mut filtered: Vec<&'a ThreatScore> = scores
            .iter()
            .filter(|s| s.risk_level as u8 >= min_level as u8)
            .collect();

        filtered.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
        filtered.truncate(limit);
        filtered
    }

    pub fn clear_old_baselines(&mut self, min_samples: usize) {
        self.baseline_stats.retain(|_, stats| stats.sample_count >= min_samples);
    }
}

impl Default for MLThreatScorer {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_feature_extraction() {
        let mut scorer = MLThreatScorer::new();

        let features = scorer.extract_features(
            "user1",
            Utc::now(),
            100,
            500,
            20,
            100,
            5,
            Some("192.168.1.100"),
            50.0,
        );

        assert_eq!(features.failed_ratio, 0.2);
        assert_eq!(features.event_count_1h, 100.0);
    }
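
    // Added illustrative test (a sketch, not part of the original suite): after
    // normalize(), every value in to_vector() should sit within [0, 1] when the
    // count-style inputs exceed their caps and the 0-100 risk scores are at
    // their maximums. The specific input values below are arbitrary.
    #[test]
    fn test_feature_normalization_bounds() {
        let mut features = ThreatFeatures::new();
        features.hour_of_day = 23.0;
        features.day_of_week = 7.0;
        features.is_weekend = 1.0;
        features.is_business_hours = 1.0;
        features.event_count_1h = 5_000.0;
        features.event_count_24h = 50_000.0;
        features.failed_ratio = 0.95;
        features.unique_sources = 500.0;
        features.velocity_score = 500.0;
        features.entropy_score = 1.0;
        features.deviation_score = 1.0;
        features.anomaly_indicators = 50.0;
        features.geo_risk_score = 100.0;
        features.asset_criticality = 100.0;
        features.user_risk_score = 100.0;
        features.network_risk_score = 100.0;

        features.normalize();
        for value in features.to_vector() {
            assert!((0.0..=1.0).contains(&value));
        }
    }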

    #[test]
    fn test_threat_scoring() {
        let scorer = MLThreatScorer::new();

        let mut features = ThreatFeatures::new();
        features.failed_ratio = 0.95;
        features.velocity_score = 100.0;
        features.deviation_score = 1.0;
        features.anomaly_indicators = 10.0;
        features.geo_risk_score = 80.0;
        features.user_risk_score = 90.0;

        let score = scorer.score(&features);
        assert!(score.score > 0.3);
    }
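
    // Added illustrative test (sketch): the logistic transform in score() keeps
    // the output strictly inside (0, 1) for any finite feature vector, including
    // the all-zero default.
    #[test]
    fn test_score_is_bounded() {
        let scorer = MLThreatScorer::new();

        let low = scorer.score(&ThreatFeatures::new());
        assert!(low.score > 0.0 && low.score < 1.0);

        let mut extreme = ThreatFeatures::new();
        extreme.failed_ratio = 1.0;
        extreme.anomaly_indicators = 100.0;
        extreme.deviation_score = 1.0;
        let high = scorer.score(&extreme);
        assert!(high.score > 0.0 && high.score < 1.0);
    }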

    #[test]
    fn test_low_risk_scoring() {
        let scorer = MLThreatScorer::new();

        let features = ThreatFeatures::new();
        let score = scorer.score(&features);
        assert!(score.score < 0.8);
    }

    #[test]
    fn test_baseline_deviation() {
        let mut baseline = BaselineStats::new();

        for _ in 0..100 {
            baseline.update(10.0, 0.05, 10);
        }

        let normal_deviation = baseline.calculate_deviation(10.0);
        let abnormal_deviation = baseline.calculate_deviation(100.0);

        assert!(abnormal_deviation >= normal_deviation);
    }
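
    // Added illustrative test (sketch): the Welford-style running mean in
    // BaselineStats::update tracks a constant input exactly, because the first
    // sample (n = 1) replaces the seeded prior outright.
    #[test]
    fn test_baseline_running_mean() {
        let mut baseline = BaselineStats::new();

        for _ in 0..50 {
            baseline.update(20.0, 0.1, 14);
        }

        assert!((baseline.mean_event_rate - 20.0).abs() < 1e-9);
        assert!((baseline.mean_failed_ratio - 0.1).abs() < 1e-9);
        assert_eq!(baseline.sample_count, 50);
    }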

    #[test]
    fn test_risk_level_classification() {
        assert_eq!(RiskLevel::from_score(0.95), RiskLevel::Critical);
        assert_eq!(RiskLevel::from_score(0.75), RiskLevel::High);
        assert_eq!(RiskLevel::from_score(0.55), RiskLevel::Medium);
        assert_eq!(RiskLevel::from_score(0.35), RiskLevel::Low);
        assert_eq!(RiskLevel::from_score(0.15), RiskLevel::Minimal);
    }

    #[test]
    fn test_contributing_factors() {
        let scorer = MLThreatScorer::new();

        let mut features = ThreatFeatures::new();
        features.failed_ratio = 0.9;
        features.deviation_score = 0.8;

        let score = scorer.score(&features);
        assert!(!score.contributing_factors.is_empty());
        assert!(score.contributing_factors.iter().any(|f| f.name.contains("Failure")));
    }
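
    // Added illustrative test (sketch): calculate_confidence blends an event
    // volume factor with a clarity factor and caps the result, so reported
    // confidence should stay within (0, 0.95].
    #[test]
    fn test_confidence_is_capped() {
        let scorer = MLThreatScorer::new();

        let mut features = ThreatFeatures::new();
        features.event_count_24h = 10_000.0;
        features.failed_ratio = 0.9;
        features.deviation_score = 0.9;

        let score = scorer.score(&features);
        assert!(score.confidence > 0.0);
        assert!(score.confidence <= 0.95);
    }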

    #[test]
    fn test_batch_scoring() {
        let scorer = MLThreatScorer::new();

        let features_list: Vec<ThreatFeatures> = (0..5)
            .map(|i| {
                let mut f = ThreatFeatures::new();
                f.failed_ratio = i as f64 * 0.2;
                f
            })
            .collect();

        let scores = scorer.score_batch(&features_list);
        assert_eq!(scores.len(), 5);
    }
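
    // Added illustrative test (sketch): the default model should carry exactly
    // one weight per feature emitted by ThreatFeatures::to_vector(); otherwise
    // trailing features would be silently dropped by the zip in score().
    #[test]
    fn test_weights_match_feature_dimension() {
        let weights = ModelWeights::default_security_model();
        let features = ThreatFeatures::new();
        assert_eq!(weights.feature_weights.len(), features.to_vector().len());
    }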

    #[test]
    fn test_top_threats() {
        let scorer = MLThreatScorer::new();

        let scores: Vec<ThreatScore> = (0..10)
            .map(|i| ThreatScore {
                score: i as f64 / 10.0,
                confidence: 0.8,
                risk_level: RiskLevel::from_score(i as f64 / 10.0),
                contributing_factors: vec![],
                timestamp: Utc::now(),
            })
            .collect();

        let top = scorer.get_top_threats(&scores, RiskLevel::Medium, 3);
        assert_eq!(top.len(), 3);
        assert!(top[0].score >= top[1].score);
    }
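
    // Added illustrative end-to-end test (sketch): feature extraction feeding
    // directly into scoring. The entity id and source IP are arbitrary example
    // values; assertions only cover properties guaranteed by score() itself.
    #[test]
    fn test_extract_and_score_end_to_end() {
        let mut scorer = MLThreatScorer::new();

        let features = scorer.extract_features(
            "server-01",
            Utc::now(),
            250,
            1200,
            120,
            250,
            30,
            Some("203.0.113.7"),
            80.0,
        );

        let score = scorer.score(&features);
        assert!(score.score > 0.0 && score.score < 1.0);
        assert_eq!(score.risk_level, RiskLevel::from_score(score.score));
        assert!(score.contributing_factors.len() <= 5);
    }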
}