rust_threat_detector/
lib.rs

1//! # Rust Threat Detector v2.0
2//!
3//! An advanced memory-safe SIEM threat detection component with ML-based scoring,
4//! automated incident response, and proactive threat hunting capabilities.
5//!
6//! ## What's New in v2.0
7//!
8//! - **ML-Based Scoring**: Feature-engineered threat scoring with statistical models
9//! - **Automated Incident Response**: Playbook-driven response workflows
10//! - **Threat Hunting**: Hypothesis-driven hunting with IOC sweeps
11//! - **Enhanced Detection**: Improved pattern matching and behavioral analysis
12//!
13//! ## Features
14//!
15//! - **Memory Safety**: Built with Rust to prevent vulnerabilities in security tools
16//! - **Real-time Analysis**: Fast pattern matching and threat detection
17//! - **MITRE ATT&CK Framework**: 10+ technique detection patterns
18//! - **Pattern Library**: Pre-configured threat patterns
19//! - **Behavioral Analytics**: User and Entity Behavior Analytics (UEBA) for anomaly detection
20//! - **Threat Intelligence**: IOC matching against known malicious indicators
21//! - **Anomaly Detection**: Statistical and machine learning-based anomaly detection
22//! - **Alert Generation**: Structured alert output for SIEM integration
23//! - **SIEM Export**: Multiple export formats (CEF, LEEF, JSON, Syslog)
24//! - **Incident Response**: Automated playbooks and response actions
25//! - **Threat Hunting**: Proactive hunting capabilities with templates
26//!
27//! ## Quick Start
28//!
29//! ```rust,no_run
30//! use rust_threat_detector::{ThreatDetector, LogEntry};
31//! use chrono::Utc;
32//! use std::collections::HashMap;
33//!
34//! let mut detector = ThreatDetector::new();
35//!
36//! let log = LogEntry {
37//!     timestamp: Utc::now(),
38//!     source_ip: Some("192.168.1.100".to_string()),
39//!     user: Some("admin".to_string()),
40//!     event_type: "auth".to_string(),
41//!     message: "Failed login attempt for admin".to_string(),
42//!     metadata: HashMap::new(),
43//! };
44//!
45//! let alerts = detector.analyze(&log);
46//! for alert in alerts {
47//!     println!("Alert: {} (Score: {})", alert.description, alert.threat_score);
48//! }
49//! ```
50//!
51//! ## Alignment with Federal Guidance
52//!
53//! Implements memory-safe security monitoring tools, aligning with 2024-2025 CISA/FBI/NSA
54//! guidance for critical infrastructure protection.
55
56pub mod mitre_attack;
57pub use mitre_attack::{AttackTactic, AttackTechnique, MitreAttackDetector, ThreatDetection};
58
59pub mod behavioral_analytics;
60pub use behavioral_analytics::{BehavioralAnalytics, EntityProfile, UserProfile};
61
62pub mod threat_intelligence;
63pub use threat_intelligence::{IOCType, ThreatIntelligence, IOC};
64
65pub mod siem_formats;
66pub use siem_formats::{BatchExporter, SIEMExporter, SIEMFormat};
67
68pub mod anomaly_detection;
69pub use anomaly_detection::{AnomalyDetector, AnomalyResult, DetectionMethod, TimeSeries};
70
71// v2.0 modules
72pub mod ml_scoring;
73pub use ml_scoring::{
74    MLThreatScorer, ThreatFeatures, ThreatScore, RiskLevel,
75    ContributingFactor, ModelWeights, BaselineStats,
76};
77
78pub mod incident_response;
79pub use incident_response::{
80    IncidentResponseManager, Incident, IncidentStatus, Playbook,
81    PlaybookAction, ResponseAction, ActionResult, IncidentStatistics,
82};
83
84pub mod threat_hunting;
85pub use threat_hunting::{
86    ThreatHuntingEngine, ThreatHunt, HuntStatus, HuntQuery,
87    HuntFinding, HuntTemplate, HuntIOC, IOCType as HuntIOCType,
88    HuntStatistics, QueryMatch, IOCSweepResult,
89};
90
91use chrono::{DateTime, Duration, Utc};
92use regex::Regex;
93use serde::{Deserialize, Serialize};
94use std::collections::HashMap;
95use thiserror::Error;
96
97/// Threat detection errors
98#[derive(Error, Debug)]
99pub enum DetectionError {
100    #[error("Invalid log format: {0}")]
101    InvalidLogFormat(String),
102
103    #[error("Pattern compilation failed: {0}")]
104    PatternError(String),
105}
106
107/// Threat severity levels
108#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord, Hash)]
109pub enum ThreatSeverity {
110    Info,
111    Low,
112    Medium,
113    High,
114    Critical,
115}
116
117/// Threat categories
118#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
119pub enum ThreatCategory {
120    BruteForce,
121    MalwareDetection,
122    DataExfiltration,
123    UnauthorizedAccess,
124    AnomalousActivity,
125    PolicyViolation,
126    SystemCompromise,
127}
128
129/// Log entry for analysis
130#[derive(Debug, Clone, Serialize, Deserialize)]
131pub struct LogEntry {
132    pub timestamp: DateTime<Utc>,
133    pub source_ip: Option<String>,
134    pub user: Option<String>,
135    pub event_type: String,
136    pub message: String,
137    pub metadata: HashMap<String, String>,
138}
139
140/// Detected threat
141#[derive(Debug, Clone, Serialize, Deserialize)]
142pub struct ThreatAlert {
143    pub alert_id: String,
144    pub timestamp: DateTime<Utc>,
145    pub severity: ThreatSeverity,
146    pub category: ThreatCategory,
147    pub description: String,
148    pub source_log: String,
149    pub indicators: Vec<String>,
150    pub recommended_action: String,
151    pub threat_score: u32,
152    pub correlated_alerts: Vec<String>,
153}
154
155impl ThreatAlert {
156    /// Export alert as JSON
157    pub fn to_json(&self) -> Result<String, serde_json::Error> {
158        serde_json::to_string_pretty(self)
159    }
160
161    /// Get risk assessment based on threat score
162    pub fn risk_assessment(&self) -> &str {
163        match self.threat_score {
164            0..=20 => "Low Risk",
165            21..=50 => "Medium Risk",
166            51..=80 => "High Risk",
167            _ => "Critical Risk",
168        }
169    }
170}
171
172/// Alert aggregation for pattern analysis
173#[derive(Debug, Clone)]
174#[allow(dead_code)]
175struct AlertAggregation {
176    category: ThreatCategory,
177    first_seen: DateTime<Utc>,
178    last_seen: DateTime<Utc>,
179    count: usize,
180    sources: Vec<String>,
181}
182
183/// Threat detection pattern
184#[derive(Debug, Clone)]
185pub struct ThreatPattern {
186    pub name: String,
187    pub category: ThreatCategory,
188    pub severity: ThreatSeverity,
189    pub pattern: Regex,
190    pub description: String,
191    pub recommended_action: String,
192}
193
194/// Threat detector
195pub struct ThreatDetector {
196    patterns: Vec<ThreatPattern>,
197    alert_count: usize,
198    alert_history: Vec<ThreatAlert>,
199    aggregations: HashMap<String, AlertAggregation>,
200}
201
202impl ThreatDetector {
203    /// Create a new threat detector with default patterns
204    pub fn new() -> Self {
205        let mut detector = Self {
206            patterns: Vec::new(),
207            alert_count: 0,
208            alert_history: Vec::new(),
209            aggregations: HashMap::new(),
210        };
211        detector.load_default_patterns();
212        detector
213    }
214
215    /// Load default threat detection patterns
216    fn load_default_patterns(&mut self) {
217        // Brute force detection
218        self.add_pattern(ThreatPattern {
219            name: "Failed Login Attempts".to_string(),
220            category: ThreatCategory::BruteForce,
221            severity: ThreatSeverity::High,
222            pattern: Regex::new(r"(?i)(failed.*login|authentication.*failed|invalid.*password)")
223                .unwrap(),
224            description: "Multiple failed login attempts detected".to_string(),
225            recommended_action: "Block source IP, enable MFA, review user account".to_string(),
226        });
227
228        // Malware indicators
229        self.add_pattern(ThreatPattern {
230            name: "Malware Signature".to_string(),
231            category: ThreatCategory::MalwareDetection,
232            severity: ThreatSeverity::Critical,
233            pattern: Regex::new(r"(?i)(malware|virus|trojan|ransomware|backdoor)").unwrap(),
234            description: "Malware signature detected in logs".to_string(),
235            recommended_action: "Isolate system, run full scan, investigate infection vector"
236                .to_string(),
237        });
238
239        // Data exfiltration
240        self.add_pattern(ThreatPattern {
241            name: "Large Data Transfer".to_string(),
242            category: ThreatCategory::DataExfiltration,
243            severity: ThreatSeverity::High,
244            pattern: Regex::new(r"(?i)(large.*transfer|exfiltration|unusual.*download)").unwrap(),
245            description: "Potential data exfiltration detected".to_string(),
246            recommended_action: "Block transfer, investigate user activity, review DLP policies"
247                .to_string(),
248        });
249
250        // Unauthorized access
251        self.add_pattern(ThreatPattern {
252            name: "Privilege Escalation".to_string(),
253            category: ThreatCategory::UnauthorizedAccess,
254            severity: ThreatSeverity::Critical,
255            pattern: Regex::new(
256                r"(?i)(privilege.*escalation|unauthorized.*access|sudo|admin.*access)",
257            )
258            .unwrap(),
259            description: "Unauthorized privilege escalation attempt".to_string(),
260            recommended_action: "Revoke privileges, investigate account, review access logs"
261                .to_string(),
262        });
263
264        // SQL Injection
265        self.add_pattern(ThreatPattern {
266            name: "SQL Injection Attempt".to_string(),
267            category: ThreatCategory::SystemCompromise,
268            severity: ThreatSeverity::Critical,
269            pattern: Regex::new(r"(?i)(union.*select|' or '1'='1|drop.*table|;--|exec\()").unwrap(),
270            description: "SQL injection attack detected".to_string(),
271            recommended_action: "Block source IP, patch application, review WAF rules".to_string(),
272        });
273
274        // Suspicious IP access
275        self.add_pattern(ThreatPattern {
276            name: "Suspicious IP Address".to_string(),
277            category: ThreatCategory::AnomalousActivity,
278            severity: ThreatSeverity::Medium,
279            pattern: Regex::new(r"(^0\.|^10\.|^127\.|^169\.254\.|^172\.(1[6-9]|2[0-9]|3[0-1])\.|^192\.168\.|^224\.)").unwrap(),
280            description: "Access from suspicious IP range".to_string(),
281            recommended_action: "Verify IP legitimacy, check geo-location, review firewall rules".to_string(),
282        });
283    }
284
285    /// Add a custom threat pattern
286    pub fn add_pattern(&mut self, pattern: ThreatPattern) {
287        self.patterns.push(pattern);
288    }
289
290    /// Analyze a log entry for threats
291    pub fn analyze(&mut self, log: &LogEntry) -> Vec<ThreatAlert> {
292        let mut alerts = Vec::new();
293        let mut new_alerts = Vec::new();
294
295        for pattern in &self.patterns {
296            if pattern.pattern.is_match(&log.message) {
297                self.alert_count += 1;
298
299                // Calculate threat score
300                let threat_score = self.calculate_threat_score(pattern.severity, log);
301
302                // Find correlated alerts
303                let correlated = self.find_correlated_alerts(&pattern.category, log);
304
305                let alert = ThreatAlert {
306                    alert_id: format!("ALERT-{:08}", self.alert_count),
307                    timestamp: Utc::now(),
308                    severity: pattern.severity,
309                    category: pattern.category.clone(),
310                    description: format!("{}: {}", pattern.name, pattern.description),
311                    source_log: format!("{} - {}", log.timestamp, log.message),
312                    indicators: self.extract_indicators(&log.message, &pattern.pattern),
313                    recommended_action: pattern.recommended_action.clone(),
314                    threat_score,
315                    correlated_alerts: correlated,
316                };
317
318                new_alerts.push((alert.clone(), log.clone()));
319                alerts.push(alert);
320            }
321        }
322
323        // Update aggregation and store in history after loop
324        for (alert, log) in new_alerts {
325            self.update_aggregation(&alert, &log);
326            self.alert_history.push(alert);
327        }
328
329        alerts
330    }
331
332    /// Calculate threat score based on multiple factors
333    fn calculate_threat_score(&self, severity: ThreatSeverity, log: &LogEntry) -> u32 {
334        let mut score: u32 = match severity {
335            ThreatSeverity::Info => 5,
336            ThreatSeverity::Low => 15,
337            ThreatSeverity::Medium => 40,
338            ThreatSeverity::High => 70,
339            ThreatSeverity::Critical => 95,
340        };
341
342        // Increase score if from external IP (simplified check)
343        if let Some(ref ip) = log.source_ip {
344            if !ip.starts_with("192.168.") && !ip.starts_with("10.") {
345                score += 10;
346            }
347        }
348
349        // Increase score for root/admin users
350        if let Some(ref user) = log.user {
351            if user.contains("admin") || user.contains("root") {
352                score += 15;
353            }
354        }
355
356        score.min(100)
357    }
358
359    /// Find correlated alerts in recent history
360    fn find_correlated_alerts(&self, category: &ThreatCategory, log: &LogEntry) -> Vec<String> {
361        let window_start = log.timestamp - Duration::hours(1);
362        let mut correlated = Vec::new();
363
364        for alert in &self.alert_history {
365            if alert.timestamp >= window_start && alert.category == *category {
366                // Check for same source
367                if let Some(ref ip) = log.source_ip {
368                    if alert.source_log.contains(ip) {
369                        correlated.push(alert.alert_id.clone());
370                    }
371                }
372            }
373        }
374
375        correlated
376    }
377
378    /// Update alert aggregation for pattern analysis
379    fn update_aggregation(&mut self, alert: &ThreatAlert, log: &LogEntry) {
380        let key = format!("{:?}", alert.category);
381        let source = log
382            .source_ip
383            .clone()
384            .unwrap_or_else(|| "unknown".to_string());
385
386        self.aggregations
387            .entry(key.clone())
388            .and_modify(|agg| {
389                agg.last_seen = alert.timestamp;
390                agg.count += 1;
391                if !agg.sources.contains(&source) {
392                    agg.sources.push(source.clone());
393                }
394            })
395            .or_insert(AlertAggregation {
396                category: alert.category.clone(),
397                first_seen: alert.timestamp,
398                last_seen: alert.timestamp,
399                count: 1,
400                sources: vec![source],
401            });
402    }
403
404    /// Extract threat indicators from log message
405    fn extract_indicators(&self, message: &str, pattern: &Regex) -> Vec<String> {
406        let mut indicators = Vec::new();
407
408        if let Some(captures) = pattern.captures(message) {
409            for i in 1..captures.len() {
410                if let Some(matched) = captures.get(i) {
411                    indicators.push(matched.as_str().to_string());
412                }
413            }
414        }
415
416        if indicators.is_empty() {
417            indicators.push("Pattern match".to_string());
418        }
419
420        indicators
421    }
422
423    /// Analyze multiple log entries in batch
424    pub fn analyze_batch(&mut self, logs: &[LogEntry]) -> Vec<ThreatAlert> {
425        let mut all_alerts = Vec::new();
426        for log in logs {
427            all_alerts.extend(self.analyze(log));
428        }
429        all_alerts
430    }
431
432    /// Get alert history for a time window
433    pub fn get_alert_history(&self, since: DateTime<Utc>) -> Vec<&ThreatAlert> {
434        self.alert_history
435            .iter()
436            .filter(|alert| alert.timestamp >= since)
437            .collect()
438    }
439
440    /// Get aggregated patterns
441    pub fn get_aggregations(&self) -> HashMap<String, (usize, usize)> {
442        self.aggregations
443            .iter()
444            .map(|(k, v)| (k.clone(), (v.count, v.sources.len())))
445            .collect()
446    }
447
448    /// Deduplicate alerts by removing similar alerts within time window
449    pub fn deduplicate_alerts(&mut self, window_minutes: i64) -> usize {
450        let mut to_remove = Vec::new();
451        let mut seen = HashMap::new();
452
453        for (i, alert) in self.alert_history.iter().enumerate() {
454            let key = format!("{:?}-{}", alert.category, alert.source_log);
455            if let Some(&prev_idx) = seen.get(&key) {
456                let prev_alert: &ThreatAlert = &self.alert_history[prev_idx];
457                if alert
458                    .timestamp
459                    .signed_duration_since(prev_alert.timestamp)
460                    .num_minutes()
461                    <= window_minutes
462                {
463                    to_remove.push(i);
464                    continue;
465                }
466            }
467            seen.insert(key, i);
468        }
469
470        let removed_count = to_remove.len();
471        // Remove in reverse order to maintain indices
472        for &idx in to_remove.iter().rev() {
473            self.alert_history.remove(idx);
474        }
475
476        removed_count
477    }
478
479    /// Get top threat sources
480    pub fn get_top_sources(&self, limit: usize) -> Vec<(String, usize)> {
481        let mut source_counts: HashMap<String, usize> = HashMap::new();
482
483        for agg in self.aggregations.values() {
484            for source in &agg.sources {
485                *source_counts.entry(source.clone()).or_insert(0) += 1;
486            }
487        }
488
489        let mut sorted: Vec<(String, usize)> = source_counts.into_iter().collect();
490        sorted.sort_by(|a, b| b.1.cmp(&a.1));
491        sorted.truncate(limit);
492        sorted
493    }
494
495    /// Clear old alerts from history (memory management)
496    pub fn clear_old_alerts(&mut self, before: DateTime<Utc>) {
497        self.alert_history.retain(|alert| alert.timestamp >= before);
498    }
499
500    /// Get statistics
501    pub fn get_stats(&self) -> HashMap<String, usize> {
502        let mut stats = HashMap::new();
503        stats.insert("total_patterns".to_string(), self.patterns.len());
504        stats.insert("total_alerts".to_string(), self.alert_count);
505        stats.insert("alerts_in_history".to_string(), self.alert_history.len());
506        stats.insert("active_aggregations".to_string(), self.aggregations.len());
507        stats
508    }
509
510    /// Get alerts by severity
511    pub fn filter_by_severity(
512        &self,
513        alerts: &[ThreatAlert],
514        min_severity: ThreatSeverity,
515    ) -> Vec<ThreatAlert> {
516        alerts
517            .iter()
518            .filter(|alert| alert.severity >= min_severity)
519            .cloned()
520            .collect()
521    }
522
523    /// Get alerts by category
524    pub fn filter_by_category(
525        &self,
526        alerts: &[ThreatAlert],
527        category: &ThreatCategory,
528    ) -> Vec<ThreatAlert> {
529        alerts
530            .iter()
531            .filter(|alert| alert.category == *category)
532            .cloned()
533            .collect()
534    }
535}
536
537impl Default for ThreatDetector {
538    fn default() -> Self {
539        Self::new()
540    }
541}
542
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a log entry from an internal IP and a non-privileged user so that
    /// only the pattern severity drives the score.
    fn create_log_entry(message: &str) -> LogEntry {
        LogEntry {
            timestamp: Utc::now(),
            source_ip: Some("192.168.1.100".to_string()),
            user: Some("test_user".to_string()),
            event_type: "security_event".to_string(),
            message: message.to_string(),
            metadata: HashMap::new(),
        }
    }

    #[test]
    fn test_brute_force_detection() {
        let mut detector = ThreatDetector::new();
        let log = create_log_entry("Failed login attempt for user admin");

        let alerts = detector.analyze(&log);
        assert!(!alerts.is_empty());
        assert_eq!(alerts[0].category, ThreatCategory::BruteForce);
    }

    #[test]
    fn test_malware_detection() {
        let mut detector = ThreatDetector::new();
        let log = create_log_entry("Malware detected in file system");

        let alerts = detector.analyze(&log);
        assert!(!alerts.is_empty());
        assert_eq!(alerts[0].category, ThreatCategory::MalwareDetection);
        assert_eq!(alerts[0].severity, ThreatSeverity::Critical);
    }

    #[test]
    fn test_sql_injection_detection() {
        let mut detector = ThreatDetector::new();
        let log = create_log_entry("Query: SELECT * FROM users WHERE id='1' OR '1'='1'");

        let alerts = detector.analyze(&log);
        assert!(!alerts.is_empty());
        assert_eq!(alerts[0].severity, ThreatSeverity::Critical);
    }

    #[test]
    fn test_no_threat_detected() {
        let mut detector = ThreatDetector::new();
        let log = create_log_entry("User successfully logged in");

        let alerts = detector.analyze(&log);
        assert!(alerts.is_empty());
    }

    #[test]
    fn test_severity_filtering() {
        let mut detector = ThreatDetector::new();
        let log1 = create_log_entry("Failed login attempt");
        let log2 = create_log_entry("Malware detected");

        let mut all_alerts = Vec::new();
        all_alerts.extend(detector.analyze(&log1));
        all_alerts.extend(detector.analyze(&log2));

        let critical_alerts = detector.filter_by_severity(&all_alerts, ThreatSeverity::Critical);
        assert_eq!(critical_alerts.len(), 1);
        assert_eq!(
            critical_alerts[0].category,
            ThreatCategory::MalwareDetection
        );
    }

    #[test]
    fn test_threat_scoring() {
        let mut detector = ThreatDetector::new();

        // High severity with external IP
        let log = LogEntry {
            timestamp: Utc::now(),
            source_ip: Some("1.2.3.4".to_string()), // External IP
            user: Some("admin".to_string()),        // Admin user
            event_type: "security_event".to_string(),
            message: "Malware detected".to_string(),
            metadata: HashMap::new(),
        };

        let alerts = detector.analyze(&log);
        assert!(!alerts.is_empty());
        assert!(alerts[0].threat_score > 95); // Critical + external + admin
        assert_eq!(alerts[0].threat_score, 100); // capped
    }

    #[test]
    fn test_alert_correlation() {
        let mut detector = ThreatDetector::new();
        let ip = "192.168.1.100".to_string();

        // First alert; the message also mentions the IP.
        let log1 = LogEntry {
            timestamp: Utc::now(),
            source_ip: Some(ip.clone()),
            user: Some("user1".to_string()),
            event_type: "security_event".to_string(),
            message: format!("Failed login attempt from {}", ip),
            metadata: HashMap::new(),
        };

        let alerts1 = detector.analyze(&log1);
        assert_eq!(alerts1[0].correlated_alerts.len(), 0);

        // Second alert from the same IP shortly afterwards.
        let log2 = LogEntry {
            timestamp: Utc::now(),
            source_ip: Some(ip.clone()),
            user: Some("user2".to_string()),
            event_type: "security_event".to_string(),
            message: format!("Failed login attempt from {}", ip),
            metadata: HashMap::new(),
        };

        let alerts2 = detector.analyze(&log2);
        assert!(!alerts2[0].correlated_alerts.is_empty()); // Should correlate with first alert
        assert!(alerts2[0]
            .correlated_alerts
            .contains(&alerts1[0].alert_id));
    }

    #[test]
    fn test_batch_analysis() {
        let mut detector = ThreatDetector::new();

        let logs = vec![
            create_log_entry("Failed login"),
            create_log_entry("Malware detected"),
            create_log_entry("Normal activity"),
        ];

        let alerts = detector.analyze_batch(&logs);
        assert_eq!(alerts.len(), 2); // Only failed login and malware
    }

    #[test]
    fn test_alert_history() {
        let mut detector = ThreatDetector::new();

        let log1 = create_log_entry("Failed login");
        let log2 = create_log_entry("Malware detected");

        detector.analyze(&log1);
        detector.analyze(&log2);

        let since = Utc::now() - Duration::hours(1);
        let history = detector.get_alert_history(since);
        assert_eq!(history.len(), 2);
    }

    #[test]
    fn test_aggregations() {
        let mut detector = ThreatDetector::new();

        for _ in 0..5 {
            let log = create_log_entry("Failed login attempt");
            detector.analyze(&log);
        }

        let aggregations = detector.get_aggregations();
        assert!(!aggregations.is_empty());

        // Must have a BruteForce aggregation with a count of 5 from a single source.
        let brute_force_key = format!("{:?}", ThreatCategory::BruteForce);
        let (count, sources) = aggregations
            .get(&brute_force_key)
            .copied()
            .expect("BruteForce aggregation should exist");
        assert_eq!(count, 5);
        assert_eq!(sources, 1);
    }

    #[test]
    fn test_deduplication() {
        let mut detector = ThreatDetector::new();

        // Create duplicate alerts with the same log timestamp so their
        // `source_log` strings are identical.
        let timestamp = Utc::now();
        for _ in 0..3 {
            let mut log = create_log_entry("Failed login attempt");
            log.timestamp = timestamp; // Use same timestamp for all
            detector.analyze(&log);
        }

        let initial_count = detector.alert_history.len();
        assert_eq!(initial_count, 3);

        let removed = detector.deduplicate_alerts(60); // 60 minute window
        assert_eq!(removed, 2); // Only the first alert survives
        assert_eq!(detector.alert_history.len(), 1);
    }

    #[test]
    fn test_top_sources() {
        let mut detector = ThreatDetector::new();

        // Generate alerts from different sources
        for i in 0..5 {
            let log = LogEntry {
                timestamp: Utc::now(),
                source_ip: Some(format!("192.168.1.{}", i)),
                user: Some("user1".to_string()),
                event_type: "security_event".to_string(),
                message: "Failed login attempt".to_string(),
                metadata: HashMap::new(),
            };
            detector.analyze(&log);
        }

        let top_sources = detector.get_top_sources(3);
        assert_eq!(top_sources.len(), 3);
        // Every source appears in exactly one category (BruteForce).
        assert!(top_sources.iter().all(|(_, count)| *count == 1));
    }

    #[test]
    fn test_clear_old_alerts() {
        let mut detector = ThreatDetector::new();

        let log = create_log_entry("Failed login");
        detector.analyze(&log);

        assert_eq!(detector.alert_history.len(), 1);

        let cutoff = Utc::now() + Duration::hours(1); // Future time
        detector.clear_old_alerts(cutoff);

        assert_eq!(detector.alert_history.len(), 0);
    }

    #[test]
    fn test_risk_assessment() {
        let alert = ThreatAlert {
            alert_id: "TEST-001".to_string(),
            timestamp: Utc::now(),
            severity: ThreatSeverity::Critical,
            category: ThreatCategory::MalwareDetection,
            description: "Test alert".to_string(),
            source_log: "Test log".to_string(),
            indicators: vec![],
            recommended_action: "Test action".to_string(),
            threat_score: 95,
            correlated_alerts: vec![],
        };

        assert_eq!(alert.risk_assessment(), "Critical Risk");

        // Exercise every band boundary.
        let cases: &[(u32, &str)] = &[
            (0, "Low Risk"),
            (15, "Low Risk"),
            (20, "Low Risk"),
            (21, "Medium Risk"),
            (50, "Medium Risk"),
            (51, "High Risk"),
            (80, "High Risk"),
            (81, "Critical Risk"),
            (100, "Critical Risk"),
        ];
        for &(score, expected) in cases {
            let scored = ThreatAlert {
                threat_score: score,
                ..alert.clone()
            };
            assert_eq!(scored.risk_assessment(), expected, "score {}", score);
        }
    }

    #[test]
    fn test_category_filtering() {
        let mut detector = ThreatDetector::new();
        let log1 = create_log_entry("Failed login attempt");
        let log2 = create_log_entry("Malware detected");

        let mut all_alerts = Vec::new();
        all_alerts.extend(detector.analyze(&log1));
        all_alerts.extend(detector.analyze(&log2));

        let brute_force_alerts =
            detector.filter_by_category(&all_alerts, &ThreatCategory::BruteForce);
        assert_eq!(brute_force_alerts.len(), 1);
        assert_eq!(brute_force_alerts[0].category, ThreatCategory::BruteForce);
    }

    #[test]
    fn test_json_export() {
        let mut detector = ThreatDetector::new();
        let log = create_log_entry("Failed login attempt");
        let alerts = detector.analyze(&log);

        let json = alerts[0].to_json();
        assert!(json.is_ok());
        let json_str = json.unwrap();
        assert!(json_str.contains("threat_score"));
        assert!(json_str.contains("correlated_alerts"));
    }
}