1pub mod mitre_attack;
21pub use mitre_attack::{AttackTactic, AttackTechnique, MitreAttackDetector, ThreatDetection};
22
23use chrono::{DateTime, Duration, Utc};
24use regex::Regex;
25use serde::{Deserialize, Serialize};
26use std::collections::HashMap;
27use thiserror::Error;
28
29#[derive(Error, Debug)]
31pub enum DetectionError {
32 #[error("Invalid log format: {0}")]
33 InvalidLogFormat(String),
34
35 #[error("Pattern compilation failed: {0}")]
36 PatternError(String),
37}
38
39#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord)]
41pub enum ThreatSeverity {
42 Info,
43 Low,
44 Medium,
45 High,
46 Critical,
47}
48
49#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
51pub enum ThreatCategory {
52 BruteForce,
53 MalwareDetection,
54 DataExfiltration,
55 UnauthorizedAccess,
56 AnomalousActivity,
57 PolicyViolation,
58 SystemCompromise,
59}
60
61#[derive(Debug, Clone, Serialize, Deserialize)]
63pub struct LogEntry {
64 pub timestamp: DateTime<Utc>,
65 pub source_ip: Option<String>,
66 pub user: Option<String>,
67 pub event_type: String,
68 pub message: String,
69 pub metadata: HashMap<String, String>,
70}
71
72#[derive(Debug, Clone, Serialize, Deserialize)]
74pub struct ThreatAlert {
75 pub alert_id: String,
76 pub timestamp: DateTime<Utc>,
77 pub severity: ThreatSeverity,
78 pub category: ThreatCategory,
79 pub description: String,
80 pub source_log: String,
81 pub indicators: Vec<String>,
82 pub recommended_action: String,
83 pub threat_score: u32,
84 pub correlated_alerts: Vec<String>,
85}
86
87impl ThreatAlert {
88 pub fn to_json(&self) -> Result<String, serde_json::Error> {
90 serde_json::to_string_pretty(self)
91 }
92
93 pub fn risk_assessment(&self) -> &str {
95 match self.threat_score {
96 0..=20 => "Low Risk",
97 21..=50 => "Medium Risk",
98 51..=80 => "High Risk",
99 _ => "Critical Risk",
100 }
101 }
102}
103
104#[derive(Debug, Clone)]
106struct AlertAggregation {
107 category: ThreatCategory,
108 first_seen: DateTime<Utc>,
109 last_seen: DateTime<Utc>,
110 count: usize,
111 sources: Vec<String>,
112}
113
114#[derive(Debug, Clone)]
116pub struct ThreatPattern {
117 pub name: String,
118 pub category: ThreatCategory,
119 pub severity: ThreatSeverity,
120 pub pattern: Regex,
121 pub description: String,
122 pub recommended_action: String,
123}
124
125pub struct ThreatDetector {
127 patterns: Vec<ThreatPattern>,
128 alert_count: usize,
129 alert_history: Vec<ThreatAlert>,
130 aggregations: HashMap<String, AlertAggregation>,
131}
132
133impl ThreatDetector {
134 pub fn new() -> Self {
136 let mut detector = Self {
137 patterns: Vec::new(),
138 alert_count: 0,
139 alert_history: Vec::new(),
140 aggregations: HashMap::new(),
141 };
142 detector.load_default_patterns();
143 detector
144 }
145
146 fn load_default_patterns(&mut self) {
148 self.add_pattern(ThreatPattern {
150 name: "Failed Login Attempts".to_string(),
151 category: ThreatCategory::BruteForce,
152 severity: ThreatSeverity::High,
153 pattern: Regex::new(r"(?i)(failed.*login|authentication.*failed|invalid.*password)").unwrap(),
154 description: "Multiple failed login attempts detected".to_string(),
155 recommended_action: "Block source IP, enable MFA, review user account".to_string(),
156 });
157
158 self.add_pattern(ThreatPattern {
160 name: "Malware Signature".to_string(),
161 category: ThreatCategory::MalwareDetection,
162 severity: ThreatSeverity::Critical,
163 pattern: Regex::new(r"(?i)(malware|virus|trojan|ransomware|backdoor)").unwrap(),
164 description: "Malware signature detected in logs".to_string(),
165 recommended_action: "Isolate system, run full scan, investigate infection vector".to_string(),
166 });
167
168 self.add_pattern(ThreatPattern {
170 name: "Large Data Transfer".to_string(),
171 category: ThreatCategory::DataExfiltration,
172 severity: ThreatSeverity::High,
173 pattern: Regex::new(r"(?i)(large.*transfer|exfiltration|unusual.*download)").unwrap(),
174 description: "Potential data exfiltration detected".to_string(),
175 recommended_action: "Block transfer, investigate user activity, review DLP policies".to_string(),
176 });
177
178 self.add_pattern(ThreatPattern {
180 name: "Privilege Escalation".to_string(),
181 category: ThreatCategory::UnauthorizedAccess,
182 severity: ThreatSeverity::Critical,
183 pattern: Regex::new(r"(?i)(privilege.*escalation|unauthorized.*access|sudo|admin.*access)").unwrap(),
184 description: "Unauthorized privilege escalation attempt".to_string(),
185 recommended_action: "Revoke privileges, investigate account, review access logs".to_string(),
186 });
187
188 self.add_pattern(ThreatPattern {
190 name: "SQL Injection Attempt".to_string(),
191 category: ThreatCategory::SystemCompromise,
192 severity: ThreatSeverity::Critical,
193 pattern: Regex::new(r"(?i)(union.*select|' or '1'='1|drop.*table|;--|exec\()").unwrap(),
194 description: "SQL injection attack detected".to_string(),
195 recommended_action: "Block source IP, patch application, review WAF rules".to_string(),
196 });
197
198 self.add_pattern(ThreatPattern {
200 name: "Suspicious IP Address".to_string(),
201 category: ThreatCategory::AnomalousActivity,
202 severity: ThreatSeverity::Medium,
203 pattern: Regex::new(r"(^0\.|^10\.|^127\.|^169\.254\.|^172\.(1[6-9]|2[0-9]|3[0-1])\.|^192\.168\.|^224\.)").unwrap(),
204 description: "Access from suspicious IP range".to_string(),
205 recommended_action: "Verify IP legitimacy, check geo-location, review firewall rules".to_string(),
206 });
207 }
208
209 pub fn add_pattern(&mut self, pattern: ThreatPattern) {
211 self.patterns.push(pattern);
212 }
213
214 pub fn analyze(&mut self, log: &LogEntry) -> Vec<ThreatAlert> {
216 let mut alerts = Vec::new();
217 let mut new_alerts = Vec::new();
218
219 for pattern in &self.patterns {
220 if pattern.pattern.is_match(&log.message) {
221 self.alert_count += 1;
222
223 let threat_score = self.calculate_threat_score(pattern.severity, log);
225
226 let correlated = self.find_correlated_alerts(&pattern.category, log);
228
229 let alert = ThreatAlert {
230 alert_id: format!("ALERT-{:08}", self.alert_count),
231 timestamp: Utc::now(),
232 severity: pattern.severity,
233 category: pattern.category.clone(),
234 description: format!("{}: {}", pattern.name, pattern.description),
235 source_log: format!("{} - {}", log.timestamp, log.message),
236 indicators: self.extract_indicators(&log.message, &pattern.pattern),
237 recommended_action: pattern.recommended_action.clone(),
238 threat_score,
239 correlated_alerts: correlated,
240 };
241
242 new_alerts.push((alert.clone(), log.clone()));
243 alerts.push(alert);
244 }
245 }
246
247 for (alert, log) in new_alerts {
249 self.update_aggregation(&alert, &log);
250 self.alert_history.push(alert);
251 }
252
253 alerts
254 }
255
256 fn calculate_threat_score(&self, severity: ThreatSeverity, log: &LogEntry) -> u32 {
258 let mut score: u32 = match severity {
259 ThreatSeverity::Info => 5,
260 ThreatSeverity::Low => 15,
261 ThreatSeverity::Medium => 40,
262 ThreatSeverity::High => 70,
263 ThreatSeverity::Critical => 95,
264 };
265
266 if let Some(ref ip) = log.source_ip {
268 if !ip.starts_with("192.168.") && !ip.starts_with("10.") {
269 score += 10;
270 }
271 }
272
273 if let Some(ref user) = log.user {
275 if user.contains("admin") || user.contains("root") {
276 score += 15;
277 }
278 }
279
280 score.min(100)
281 }
282
283 fn find_correlated_alerts(&self, category: &ThreatCategory, log: &LogEntry) -> Vec<String> {
285 let window_start = log.timestamp - Duration::hours(1);
286 let mut correlated = Vec::new();
287
288 for alert in &self.alert_history {
289 if alert.timestamp >= window_start && alert.category == *category {
290 if let Some(ref ip) = log.source_ip {
292 if alert.source_log.contains(ip) {
293 correlated.push(alert.alert_id.clone());
294 }
295 }
296 }
297 }
298
299 correlated
300 }
301
302 fn update_aggregation(&mut self, alert: &ThreatAlert, log: &LogEntry) {
304 let key = format!("{:?}", alert.category);
305 let source = log.source_ip.clone().unwrap_or_else(|| "unknown".to_string());
306
307 self.aggregations
308 .entry(key.clone())
309 .and_modify(|agg| {
310 agg.last_seen = alert.timestamp;
311 agg.count += 1;
312 if !agg.sources.contains(&source) {
313 agg.sources.push(source.clone());
314 }
315 })
316 .or_insert(AlertAggregation {
317 category: alert.category.clone(),
318 first_seen: alert.timestamp,
319 last_seen: alert.timestamp,
320 count: 1,
321 sources: vec![source],
322 });
323 }
324
325 fn extract_indicators(&self, message: &str, pattern: &Regex) -> Vec<String> {
327 let mut indicators = Vec::new();
328
329 if let Some(captures) = pattern.captures(message) {
330 for i in 1..captures.len() {
331 if let Some(matched) = captures.get(i) {
332 indicators.push(matched.as_str().to_string());
333 }
334 }
335 }
336
337 if indicators.is_empty() {
338 indicators.push("Pattern match".to_string());
339 }
340
341 indicators
342 }
343
344 pub fn analyze_batch(&mut self, logs: &[LogEntry]) -> Vec<ThreatAlert> {
346 let mut all_alerts = Vec::new();
347 for log in logs {
348 all_alerts.extend(self.analyze(log));
349 }
350 all_alerts
351 }
352
353 pub fn get_alert_history(&self, since: DateTime<Utc>) -> Vec<&ThreatAlert> {
355 self.alert_history
356 .iter()
357 .filter(|alert| alert.timestamp >= since)
358 .collect()
359 }
360
361 pub fn get_aggregations(&self) -> HashMap<String, (usize, usize)> {
363 self.aggregations
364 .iter()
365 .map(|(k, v)| (k.clone(), (v.count, v.sources.len())))
366 .collect()
367 }
368
369 pub fn deduplicate_alerts(&mut self, window_minutes: i64) -> usize {
371 let mut to_remove = Vec::new();
372 let mut seen = HashMap::new();
373
374 for (i, alert) in self.alert_history.iter().enumerate() {
375 let key = format!("{:?}-{}", alert.category, alert.source_log);
376 if let Some(&prev_idx) = seen.get(&key) {
377 let prev_alert: &ThreatAlert = &self.alert_history[prev_idx];
378 if alert.timestamp.signed_duration_since(prev_alert.timestamp).num_minutes() <= window_minutes {
379 to_remove.push(i);
380 continue;
381 }
382 }
383 seen.insert(key, i);
384 }
385
386 let removed_count = to_remove.len();
387 for &idx in to_remove.iter().rev() {
389 self.alert_history.remove(idx);
390 }
391
392 removed_count
393 }
394
395 pub fn get_top_sources(&self, limit: usize) -> Vec<(String, usize)> {
397 let mut source_counts: HashMap<String, usize> = HashMap::new();
398
399 for agg in self.aggregations.values() {
400 for source in &agg.sources {
401 *source_counts.entry(source.clone()).or_insert(0) += 1;
402 }
403 }
404
405 let mut sorted: Vec<(String, usize)> = source_counts.into_iter().collect();
406 sorted.sort_by(|a, b| b.1.cmp(&a.1));
407 sorted.truncate(limit);
408 sorted
409 }
410
411 pub fn clear_old_alerts(&mut self, before: DateTime<Utc>) {
413 self.alert_history.retain(|alert| alert.timestamp >= before);
414 }
415
416 pub fn get_stats(&self) -> HashMap<String, usize> {
418 let mut stats = HashMap::new();
419 stats.insert("total_patterns".to_string(), self.patterns.len());
420 stats.insert("total_alerts".to_string(), self.alert_count);
421 stats.insert("alerts_in_history".to_string(), self.alert_history.len());
422 stats.insert("active_aggregations".to_string(), self.aggregations.len());
423 stats
424 }
425
426 pub fn filter_by_severity(
428 &self,
429 alerts: &[ThreatAlert],
430 min_severity: ThreatSeverity,
431 ) -> Vec<ThreatAlert> {
432 alerts
433 .iter()
434 .filter(|alert| alert.severity >= min_severity)
435 .cloned()
436 .collect()
437 }
438
439 pub fn filter_by_category(
441 &self,
442 alerts: &[ThreatAlert],
443 category: &ThreatCategory,
444 ) -> Vec<ThreatAlert> {
445 alerts
446 .iter()
447 .filter(|alert| alert.category == *category)
448 .cloned()
449 .collect()
450 }
451}
452
453impl Default for ThreatDetector {
454 fn default() -> Self {
455 Self::new()
456 }
457}
458
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a log entry from a fixed internal IP and an unprivileged user.
    fn create_log_entry(message: &str) -> LogEntry {
        create_log_entry_from(message, "192.168.1.100")
    }

    /// Builds a log entry with the given message and source IP.
    fn create_log_entry_from(message: &str, ip: &str) -> LogEntry {
        LogEntry {
            timestamp: Utc::now(),
            source_ip: Some(ip.to_string()),
            user: Some("test_user".to_string()),
            event_type: "security_event".to_string(),
            message: message.to_string(),
            metadata: HashMap::new(),
        }
    }

    #[test]
    fn test_brute_force_detection() {
        let mut detector = ThreatDetector::new();
        let log = create_log_entry("Failed login attempt for user admin");

        let alerts = detector.analyze(&log);
        assert!(!alerts.is_empty());
        assert_eq!(alerts[0].category, ThreatCategory::BruteForce);
    }

    #[test]
    fn test_malware_detection() {
        let mut detector = ThreatDetector::new();
        let log = create_log_entry("Malware detected in file system");

        let alerts = detector.analyze(&log);
        assert!(!alerts.is_empty());
        assert_eq!(alerts[0].category, ThreatCategory::MalwareDetection);
        assert_eq!(alerts[0].severity, ThreatSeverity::Critical);
    }

    #[test]
    fn test_sql_injection_detection() {
        let mut detector = ThreatDetector::new();
        let log = create_log_entry("Query: SELECT * FROM users WHERE id='1' OR '1'='1'");

        let alerts = detector.analyze(&log);
        assert!(!alerts.is_empty());
        assert_eq!(alerts[0].severity, ThreatSeverity::Critical);
    }

    #[test]
    fn test_no_threat_detected() {
        let mut detector = ThreatDetector::new();
        let log = create_log_entry("User successfully logged in");

        let alerts = detector.analyze(&log);
        assert!(alerts.is_empty());
    }

    #[test]
    fn test_severity_filtering() {
        let mut detector = ThreatDetector::new();
        let log1 = create_log_entry("Failed login attempt");
        let log2 = create_log_entry("Malware detected");

        let mut all_alerts = Vec::new();
        all_alerts.extend(detector.analyze(&log1));
        all_alerts.extend(detector.analyze(&log2));

        let critical_alerts = detector.filter_by_severity(&all_alerts, ThreatSeverity::Critical);
        assert_eq!(critical_alerts.len(), 1);
        assert_eq!(critical_alerts[0].category, ThreatCategory::MalwareDetection);
    }

    #[test]
    fn test_threat_scoring() {
        let mut detector = ThreatDetector::new();

        // External IP plus a privileged user pushes a critical alert past its base score of 95.
        let log = LogEntry {
            timestamp: Utc::now(),
            source_ip: Some("1.2.3.4".to_string()),
            user: Some("admin".to_string()),
            event_type: "security_event".to_string(),
            message: "Malware detected".to_string(),
            metadata: HashMap::new(),
        };

        let alerts = detector.analyze(&log);
        assert!(!alerts.is_empty());
        assert!(alerts[0].threat_score > 95);
    }

    #[test]
    fn test_alert_correlation() {
        let mut detector = ThreatDetector::new();
        let ip = "192.168.1.100".to_string();

        let log1 = LogEntry {
            timestamp: Utc::now(),
            source_ip: Some(ip.clone()),
            user: Some("user1".to_string()),
            event_type: "security_event".to_string(),
            message: "Failed login attempt".to_string(),
            metadata: HashMap::new(),
        };

        let alerts1 = detector.analyze(&log1);
        assert_eq!(alerts1[0].correlated_alerts.len(), 0);

        // Same category and same source IP within the window: correlated with the first alert.
        let log2 = LogEntry {
            timestamp: Utc::now(),
            source_ip: Some(ip),
            user: Some("user2".to_string()),
            event_type: "security_event".to_string(),
            message: "Failed login attempt again".to_string(),
            metadata: HashMap::new(),
        };

        let alerts2 = detector.analyze(&log2);
        assert!(alerts2[0].correlated_alerts.contains(&alerts1[0].alert_id));
    }

    #[test]
    fn test_no_correlation_across_sources() {
        let mut detector = ThreatDetector::new();

        let first = detector.analyze(&create_log_entry_from("Failed login attempt", "192.168.1.1"));
        assert!(first[0].correlated_alerts.is_empty());

        let second =
            detector.analyze(&create_log_entry_from("Failed login attempt", "192.168.1.2"));
        assert!(second[0].correlated_alerts.is_empty());
    }

    #[test]
    fn test_batch_analysis() {
        let mut detector = ThreatDetector::new();

        let logs = vec![
            create_log_entry("Failed login"),
            create_log_entry("Malware detected"),
            create_log_entry("Normal activity"),
        ];

        let alerts = detector.analyze_batch(&logs);
        assert_eq!(alerts.len(), 2);
    }

    #[test]
    fn test_alert_history() {
        let mut detector = ThreatDetector::new();

        let log1 = create_log_entry("Failed login");
        let log2 = create_log_entry("Malware detected");

        detector.analyze(&log1);
        detector.analyze(&log2);

        let since = Utc::now() - Duration::hours(1);
        let history = detector.get_alert_history(since);
        assert_eq!(history.len(), 2);
    }

    #[test]
    fn test_aggregations() {
        let mut detector = ThreatDetector::new();

        for _ in 0..5 {
            let log = create_log_entry("Failed login attempt");
            detector.analyze(&log);
        }

        let aggregations = detector.get_aggregations();
        assert!(!aggregations.is_empty());

        // The aggregation must exist; a missing key is a failure, not a silent pass.
        let brute_force_key = format!("{:?}", ThreatCategory::BruteForce);
        let &(count, sources) = aggregations
            .get(&brute_force_key)
            .expect("brute-force aggregation should exist");
        assert_eq!(count, 5);
        assert_eq!(sources, 1);
    }

    #[test]
    fn test_deduplication() {
        let mut detector = ThreatDetector::new();

        // Identical event from the same source, three times in quick succession.
        for _ in 0..3 {
            let log = create_log_entry("Failed login attempt");
            detector.analyze(&log);
        }

        let initial_count = detector.alert_history.len();
        assert_eq!(initial_count, 3);

        let removed = detector.deduplicate_alerts(60);
        assert_eq!(removed, 2);
        assert_eq!(detector.alert_history.len(), 1);
    }

    #[test]
    fn test_deduplication_keeps_distinct_sources() {
        let mut detector = ThreatDetector::new();

        for i in 1..=3 {
            let ip = format!("192.168.1.{}", i);
            detector.analyze(&create_log_entry_from("Failed login attempt", &ip));
        }

        assert_eq!(detector.deduplicate_alerts(60), 0);
        assert_eq!(detector.alert_history.len(), 3);
    }

    #[test]
    fn test_top_sources() {
        let mut detector = ThreatDetector::new();

        for i in 0..5 {
            let log = LogEntry {
                timestamp: Utc::now(),
                source_ip: Some(format!("192.168.1.{}", i)),
                user: Some("user1".to_string()),
                event_type: "security_event".to_string(),
                message: "Failed login attempt".to_string(),
                metadata: HashMap::new(),
            };
            detector.analyze(&log);
        }

        // Five distinct sources, each present in one category: exactly `limit` come back.
        let top_sources = detector.get_top_sources(3);
        assert_eq!(top_sources.len(), 3);
        assert!(top_sources.iter().all(|(_, count)| *count == 1));
    }

    #[test]
    fn test_clear_old_alerts() {
        let mut detector = ThreatDetector::new();

        let log = create_log_entry("Failed login");
        detector.analyze(&log);

        assert_eq!(detector.alert_history.len(), 1);

        let cutoff = Utc::now() + Duration::hours(1);
        detector.clear_old_alerts(cutoff);

        assert_eq!(detector.alert_history.len(), 0);
    }

    #[test]
    fn test_risk_assessment() {
        let alert = ThreatAlert {
            alert_id: "TEST-001".to_string(),
            timestamp: Utc::now(),
            severity: ThreatSeverity::Critical,
            category: ThreatCategory::MalwareDetection,
            description: "Test alert".to_string(),
            source_log: "Test log".to_string(),
            indicators: vec![],
            recommended_action: "Test action".to_string(),
            threat_score: 95,
            correlated_alerts: vec![],
        };

        assert_eq!(alert.risk_assessment(), "Critical Risk");

        // Pin every band boundary.
        let cases: &[(u32, &str)] = &[
            (0, "Low Risk"),
            (15, "Low Risk"),
            (20, "Low Risk"),
            (21, "Medium Risk"),
            (50, "Medium Risk"),
            (51, "High Risk"),
            (80, "High Risk"),
            (81, "Critical Risk"),
        ];
        for &(score, expected) in cases {
            let scored = ThreatAlert {
                threat_score: score,
                ..alert.clone()
            };
            assert_eq!(scored.risk_assessment(), expected, "score {}", score);
        }
    }

    #[test]
    fn test_category_filtering() {
        let mut detector = ThreatDetector::new();
        let log1 = create_log_entry("Failed login attempt");
        let log2 = create_log_entry("Malware detected");

        let mut all_alerts = Vec::new();
        all_alerts.extend(detector.analyze(&log1));
        all_alerts.extend(detector.analyze(&log2));

        let brute_force_alerts =
            detector.filter_by_category(&all_alerts, &ThreatCategory::BruteForce);
        assert_eq!(brute_force_alerts.len(), 1);
        assert_eq!(brute_force_alerts[0].category, ThreatCategory::BruteForce);
    }

    #[test]
    fn test_json_export() {
        let mut detector = ThreatDetector::new();
        let log = create_log_entry("Failed login attempt");
        let alerts = detector.analyze(&log);

        let json = alerts[0].to_json();
        assert!(json.is_ok());
        let json_str = json.unwrap();
        assert!(json_str.contains("threat_score"));
        assert!(json_str.contains("correlated_alerts"));
    }
}