// sh_layer0/threat_detector.rs

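//! # Threat Detector
//!
//! Threat detector: detects and responds to potential security threats.
//!
//! ## Features
//! - Anomalous behavior detection
//! - Attack pattern recognition
//! - Threat scoring
//! - Automated response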
10
11use parking_lot::RwLock;
12use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14use std::time::{Duration, Instant};
15use thiserror::Error;
16
17/// 威胁检测错误
18#[derive(Debug, Error)]
19pub enum ThreatError {
20    #[error("Invalid threshold: {0}")]
21    InvalidThreshold(String),
22
23    #[error("Detection failed: {0}")]
24    DetectionFailed(String),
25
26    #[error("Rule not found: {0}")]
27    RuleNotFound(String),
28}
29
30/// 威胁级别
31#[derive(
32    Debug, Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize,
33)]
34pub enum ThreatLevel {
35    /// 信息级别 - 无威胁
36    #[default]
37    Info,
38    /// 低危 - 轻微异常
39    Low,
40    /// 中危 - 需要关注
41    Medium,
42    /// 高危 - 需要立即处理
43    High,
44    /// 严重 - 紧急响应
45    Critical,
46}
47
48impl ThreatLevel {
49    pub fn as_str(&self) -> &'static str {
50        match self {
51            Self::Info => "info",
52            Self::Low => "low",
53            Self::Medium => "medium",
54            Self::High => "high",
55            Self::Critical => "critical",
56        }
57    }
58
59    /// Parse from string (non-standard name to avoid confusion with FromStr trait)
60    pub fn parse(s: &str) -> Option<Self> {
61        match s.to_lowercase().as_str() {
62            "info" => Some(Self::Info),
63            "low" => Some(Self::Low),
64            "medium" => Some(Self::Medium),
65            "high" => Some(Self::High),
66            "critical" => Some(Self::Critical),
67            _ => None,
68        }
69    }
70}
71
72/// 威胁类型
73#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
74pub enum ThreatType {
75    /// 暴力破解
76    BruteForce,
77    /// DDoS 攻击
78    DDoS,
79    /// SQL 注入
80    SqlInjection,
81    /// XSS 攻击
82    #[allow(clippy::upper_case_acronyms)]
83    XSS,
84    /// 路径遍历
85    PathTraversal,
86    /// 权限提升
87    PrivilegeEscalation,
88    /// 数据泄露
89    DataExfiltration,
90    /// 异常访问
91    AnomalousAccess,
92    /// 速率异常
93    RateAnomaly,
94    /// 自定义威胁
95    Custom,
96}
97
98impl ThreatType {
99    pub fn as_str(&self) -> &'static str {
100        match self {
101            Self::BruteForce => "brute_force",
102            Self::DDoS => "ddos",
103            Self::SqlInjection => "sql_injection",
104            Self::XSS => "xss",
105            Self::PathTraversal => "path_traversal",
106            Self::PrivilegeEscalation => "privilege_escalation",
107            Self::DataExfiltration => "data_exfiltration",
108            Self::AnomalousAccess => "anomalous_access",
109            Self::RateAnomaly => "rate_anomaly",
110            Self::Custom => "custom",
111        }
112    }
113}
114
115/// 检测到的威胁
116#[derive(Debug, Clone, Serialize, Deserialize)]
117pub struct Threat {
118    /// 威胁 ID
119    pub id: String,
120    /// 威胁类型
121    pub threat_type: ThreatType,
122    /// 威胁级别
123    pub level: ThreatLevel,
124    /// 描述
125    pub description: String,
126    /// 来源 IP 或标识
127    pub source: String,
128    /// 目标资源
129    pub target: Option<String>,
130    /// 检测时间
131    pub detected_at: String,
132    /// 证据数据
133    pub evidence: serde_json::Value,
134    /// 置信度 (0.0-1.0)
135    pub confidence: f32,
136    /// 是否已处理
137    pub handled: bool,
138}
139
140impl Threat {
141    pub fn new(
142        threat_type: ThreatType,
143        level: ThreatLevel,
144        source: impl Into<String>,
145        description: impl Into<String>,
146    ) -> Self {
147        Self {
148            id: format!("THR-{}", uuid::Uuid::new_v4()),
149            threat_type,
150            level,
151            description: description.into(),
152            source: source.into(),
153            target: None,
154            detected_at: chrono::Utc::now().to_rfc3339(),
155            evidence: serde_json::json!({}),
156            confidence: 0.5,
157            handled: false,
158        }
159    }
160
161    pub fn with_target(mut self, target: impl Into<String>) -> Self {
162        self.target = Some(target.into());
163        self
164    }
165
166    pub fn with_evidence(mut self, evidence: serde_json::Value) -> Self {
167        self.evidence = evidence;
168        self
169    }
170
171    pub fn with_confidence(mut self, confidence: f32) -> Self {
172        self.confidence = confidence.clamp(0.0, 1.0);
173        self
174    }
175
176    pub fn mark_handled(&mut self) {
177        self.handled = true;
178    }
179}
180
181/// 检测规则
182#[derive(Debug, Clone, Serialize, Deserialize)]
183pub struct DetectionRule {
184    /// 规则 ID
185    pub id: String,
186    /// 规则名称
187    pub name: String,
188    /// 威胁类型
189    pub threat_type: ThreatType,
190    /// 基础威胁级别
191    pub base_level: ThreatLevel,
192    /// 触发阈值
193    pub threshold: f32,
194    /// 时间窗口(秒)
195    pub time_window_secs: u64,
196    /// 是否启用
197    pub enabled: bool,
198    /// 描述
199    pub description: String,
200}
201
202impl DetectionRule {
203    pub fn new(name: impl Into<String>, threat_type: ThreatType, base_level: ThreatLevel) -> Self {
204        Self {
205            id: format!("RULE-{}", uuid::Uuid::new_v4()),
206            name: name.into(),
207            threat_type,
208            base_level,
209            threshold: 0.5,
210            time_window_secs: 300,
211            enabled: true,
212            description: String::new(),
213        }
214    }
215
216    pub fn with_threshold(mut self, threshold: f32) -> Self {
217        self.threshold = threshold;
218        self
219    }
220
221    pub fn with_time_window(mut self, secs: u64) -> Self {
222        self.time_window_secs = secs;
223        self
224    }
225
226    pub fn with_description(mut self, desc: impl Into<String>) -> Self {
227        self.description = desc.into();
228        self
229    }
230}
231
232/// 响应动作
233#[derive(Debug, Clone, Serialize, Deserialize)]
234pub enum ResponseAction {
235    /// 仅记录
236    Log,
237    /// 发送告警
238    Alert,
239    /// 临时封禁
240    TempBan { duration_secs: u64 },
241    /// 永久封禁
242    PermanentBan,
243    /// 限流
244    RateLimit { requests_per_sec: u32 },
245    /// 自定义动作
246    Custom { action: String },
247}
248
249/// 响应规则
250#[derive(Debug, Clone)]
251pub struct ResponseRule {
252    /// 触发的威胁级别
253    pub min_level: ThreatLevel,
254    /// 响应动作
255    pub action: ResponseAction,
256    /// 是否启用
257    pub enabled: bool,
258}
259
260impl ResponseRule {
261    pub fn new(min_level: ThreatLevel, action: ResponseAction) -> Self {
262        Self {
263            min_level,
264            action,
265            enabled: true,
266        }
267    }
268}
269
270/// 威胁统计
271#[derive(Debug, Clone, Default)]
272pub struct ThreatStats {
273    /// 总检测次数
274    pub total_detections: u64,
275    /// 各级别威胁数量
276    pub by_level: HashMap<ThreatLevel, u64>,
277    /// 各类型威胁数量
278    pub by_type: HashMap<ThreatType, u64>,
279    /// 活跃威胁数
280    pub active_threats: u64,
281    /// 已处理威胁数
282    pub handled_threats: u64,
283}
284
285/// 活动记录
286#[derive(Debug, Clone)]
287struct ActivityRecord {
288    /// 活动类型
289    activity_type: String,
290    /// 来源
291    source: String,
292    /// 时间戳
293    timestamp: Instant,
294    /// 相关数据
295    data: serde_json::Value,
296}
297
298/// 威胁检测器配置
299#[derive(Debug, Clone)]
300pub struct ThreatDetectorConfig {
301    /// 是否启用检测
302    pub enabled: bool,
303    /// 检测间隔(秒)
304    pub detection_interval_secs: u64,
305    /// 历史记录保留时间(秒)
306    pub history_retention_secs: u64,
307    /// 自动响应
308    pub auto_response: bool,
309    /// 告警阈值
310    pub alert_threshold: ThreatLevel,
311}
312
313impl Default for ThreatDetectorConfig {
314    fn default() -> Self {
315        Self {
316            enabled: true,
317            detection_interval_secs: 60,
318            history_retention_secs: 3600 * 24,
319            auto_response: true,
320            alert_threshold: ThreatLevel::High,
321        }
322    }
323}
324
325/// 威胁检测器
326pub struct ThreatDetector {
327    /// 配置
328    config: ThreatDetectorConfig,
329    /// 检测规则
330    rules: RwLock<Vec<DetectionRule>>,
331    /// 响应规则
332    response_rules: RwLock<Vec<ResponseRule>>,
333    /// 活动历史
334    activity_history: RwLock<Vec<ActivityRecord>>,
335    /// 检测到的威胁
336    threats: RwLock<Vec<Threat>>,
337    /// 统计信息
338    stats: RwLock<ThreatStats>,
339    /// 封禁列表
340    ban_list: RwLock<HashMap<String, Instant>>,
341}
342
343impl ThreatDetector {
344    /// 创建新的威胁检测器
345    pub fn new() -> Self {
346        Self::with_config(ThreatDetectorConfig::default())
347    }
348
349    /// 使用自定义配置创建
350    pub fn with_config(config: ThreatDetectorConfig) -> Self {
351        let detector = Self {
352            config,
353            rules: RwLock::new(Vec::new()),
354            response_rules: RwLock::new(Vec::new()),
355            activity_history: RwLock::new(Vec::new()),
356            threats: RwLock::new(Vec::new()),
357            stats: RwLock::new(ThreatStats::default()),
358            ban_list: RwLock::new(HashMap::new()),
359        };
360
361        // 添加默认规则
362        detector.add_default_rules();
363        detector.add_default_response_rules();
364        detector
365    }
366
367    fn add_default_rules(&self) {
368        let default_rules = vec![
369            DetectionRule::new(
370                "Brute Force Detection",
371                ThreatType::BruteForce,
372                ThreatLevel::High,
373            )
374            .with_threshold(0.3)
375            .with_time_window(300)
376            .with_description("检测短时间内多次失败登录尝试"),
377            DetectionRule::new(
378                "Rate Anomaly Detection",
379                ThreatType::RateAnomaly,
380                ThreatLevel::Medium,
381            )
382            .with_threshold(0.5)
383            .with_time_window(60)
384            .with_description("检测异常请求速率"),
385            DetectionRule::new(
386                "SQL Injection Detection",
387                ThreatType::SqlInjection,
388                ThreatLevel::Critical,
389            )
390            .with_threshold(0.8)
391            .with_time_window(1)
392            .with_description("检测 SQL 注入模式"),
393            DetectionRule::new(
394                "Path Traversal Detection",
395                ThreatType::PathTraversal,
396                ThreatLevel::High,
397            )
398            .with_threshold(0.7)
399            .with_time_window(1)
400            .with_description("检测路径遍历攻击"),
401            DetectionRule::new("XSS Detection", ThreatType::XSS, ThreatLevel::High)
402                .with_threshold(0.7)
403                .with_time_window(1)
404                .with_description("检测跨站脚本攻击"),
405        ];
406
407        let mut rules = self.rules.write();
408        for rule in default_rules {
409            rules.push(rule);
410        }
411    }
412
413    fn add_default_response_rules(&self) {
414        let default_responses = vec![
415            ResponseRule::new(ThreatLevel::Critical, ResponseAction::PermanentBan),
416            ResponseRule::new(
417                ThreatLevel::High,
418                ResponseAction::TempBan {
419                    duration_secs: 3600,
420                },
421            ),
422            ResponseRule::new(
423                ThreatLevel::Medium,
424                ResponseAction::RateLimit {
425                    requests_per_sec: 10,
426                },
427            ),
428            ResponseRule::new(ThreatLevel::Low, ResponseAction::Alert),
429            ResponseRule::new(ThreatLevel::Info, ResponseAction::Log),
430        ];
431
432        let mut response_rules = self.response_rules.write();
433        for rule in default_responses {
434            response_rules.push(rule);
435        }
436    }
437
438    /// 记录活动
439    pub fn record_activity(
440        &self,
441        activity_type: impl Into<String>,
442        source: impl Into<String>,
443        data: serde_json::Value,
444    ) {
445        let record = ActivityRecord {
446            activity_type: activity_type.into(),
447            source: source.into(),
448            timestamp: Instant::now(),
449            data,
450        };
451
452        self.activity_history.write().push(record);
453        self.cleanup_old_activities();
454    }
455
456    /// 检测威胁
457    pub fn detect(&self) -> Vec<Threat> {
458        if !self.config.enabled {
459            return Vec::new();
460        }
461
462        let mut detected_threats = Vec::new();
463        let rules = self.rules.read();
464        let activity = self.activity_history.read();
465
466        for rule in rules.iter().filter(|r| r.enabled) {
467            let threats = self.detect_with_rule(rule, &activity);
468            detected_threats.extend(threats);
469        }
470
471        // 更新统计
472        self.update_stats(&detected_threats);
473
474        // 存储威胁
475        let mut threats = self.threats.write();
476        for threat in &detected_threats {
477            threats.push(threat.clone());
478        }
479
480        detected_threats
481    }
482
483    fn detect_with_rule(&self, rule: &DetectionRule, activity: &[ActivityRecord]) -> Vec<Threat> {
484        let window = Duration::from_secs(rule.time_window_secs);
485
486        let relevant_activities: Vec<_> = activity
487            .iter()
488            .filter(|a| {
489                a.timestamp.elapsed() < window && self.activity_matches_rule(&a.activity_type, rule)
490            })
491            .collect();
492
493        if relevant_activities.is_empty() {
494            return Vec::new();
495        }
496
497        // 计算威胁分数
498        let score = self.calculate_threat_score(rule, &relevant_activities);
499
500        if score >= rule.threshold {
501            // 按来源分组
502            let mut by_source: HashMap<String, Vec<&ActivityRecord>> = HashMap::new();
503            for act in relevant_activities {
504                by_source.entry(act.source.clone()).or_default().push(act);
505            }
506
507            by_source
508                .into_iter()
509                .map(|(source, activities)| {
510                    let level = self.calculate_level(rule.base_level, score);
511                    Threat::new(rule.threat_type, level, source, &rule.description)
512                        .with_confidence(score)
513                        .with_evidence(serde_json::json!({
514                            "rule_id": rule.id,
515                            "rule_name": rule.name,
516                            "activity_count": activities.len(),
517                            "score": score,
518                        }))
519                })
520                .collect()
521        } else {
522            Vec::new()
523        }
524    }
525
526    fn activity_matches_rule(&self, activity_type: &str, rule: &DetectionRule) -> bool {
527        match rule.threat_type {
528            ThreatType::BruteForce => {
529                activity_type.contains("login") || activity_type.contains("auth")
530            }
531            ThreatType::RateAnomaly => activity_type.contains("request"),
532            ThreatType::SqlInjection => {
533                activity_type.contains("query") || activity_type.contains("sql")
534            }
535            ThreatType::XSS => activity_type.contains("input") || activity_type.contains("html"),
536            ThreatType::PathTraversal => {
537                activity_type.contains("file") || activity_type.contains("path")
538            }
539            ThreatType::DDoS => {
540                activity_type.contains("request") || activity_type.contains("connection")
541            }
542            ThreatType::PrivilegeEscalation => {
543                activity_type.contains("permission") || activity_type.contains("admin")
544            }
545            ThreatType::DataExfiltration => {
546                activity_type.contains("download") || activity_type.contains("export")
547            }
548            ThreatType::AnomalousAccess => activity_type.contains("access"),
549            ThreatType::Custom => true,
550        }
551    }
552
553    fn calculate_threat_score(&self, rule: &DetectionRule, activities: &[&ActivityRecord]) -> f32 {
554        if activities.is_empty() {
555            return 0.0;
556        }
557
558        let count = activities.len() as f32;
559        let window = rule.time_window_secs as f32;
560        let rate = count / window.max(1.0);
561
562        // 基于活动频率计算分数
563        match rule.threat_type {
564            ThreatType::BruteForce => (rate * 60.0).min(1.0),
565            ThreatType::RateAnomaly => (rate / 10.0).min(1.0),
566            ThreatType::SqlInjection | ThreatType::XSS | ThreatType::PathTraversal => {
567                // 检测恶意模式
568                activities.iter().any(|a| {
569                    let data = a.data.to_string().to_lowercase();
570                    data.contains("select") || data.contains("script") || data.contains("../")
571                }) as usize as f32
572            }
573            _ => (count / 5.0).min(1.0),
574        }
575    }
576
577    fn calculate_level(&self, base_level: ThreatLevel, score: f32) -> ThreatLevel {
578        if score >= 0.9 {
579            ThreatLevel::Critical
580        } else if score >= 0.7 {
581            ThreatLevel::High
582        } else if score >= 0.5 {
583            base_level
584        } else if score >= 0.3 {
585            ThreatLevel::Low
586        } else {
587            ThreatLevel::Info
588        }
589    }
590
591    fn update_stats(&self, new_threats: &[Threat]) {
592        let mut stats = self.stats.write();
593        stats.total_detections += new_threats.len() as u64;
594
595        for threat in new_threats {
596            *stats.by_level.entry(threat.level).or_default() += 1;
597            *stats.by_type.entry(threat.threat_type).or_default() += 1;
598            if threat.handled {
599                stats.handled_threats += 1;
600            } else {
601                stats.active_threats += 1;
602            }
603        }
604    }
605
606    /// 响应威胁
607    pub fn respond(&self, threat: &Threat) -> Option<ResponseAction> {
608        if !self.config.auto_response {
609            return None;
610        }
611
612        let response_rules = self.response_rules.read();
613        response_rules
614            .iter()
615            .filter(|r| r.enabled && threat.level >= r.min_level)
616            .max_by_key(|r| r.min_level as i32)
617            .map(|r| r.action.clone())
618    }
619
620    /// 添加检测规则
621    pub fn add_rule(&self, rule: DetectionRule) {
622        self.rules.write().push(rule);
623    }
624
625    /// 添加响应规则
626    pub fn add_response_rule(&self, rule: ResponseRule) {
627        self.response_rules.write().push(rule);
628    }
629
630    /// 获取威胁列表
631    pub fn get_threats(&self) -> Vec<Threat> {
632        self.threats.read().clone()
633    }
634
635    /// 获取活跃威胁
636    pub fn get_active_threats(&self) -> Vec<Threat> {
637        self.threats
638            .read()
639            .iter()
640            .filter(|t| !t.handled)
641            .cloned()
642            .collect()
643    }
644
645    /// 标记威胁已处理
646    pub fn handle_threat(&self, threat_id: &str) -> Result<(), ThreatError> {
647        let mut threats = self.threats.write();
648        if let Some(threat) = threats.iter_mut().find(|t| t.id == threat_id) {
649            threat.mark_handled();
650            let mut stats = self.stats.write();
651            stats.active_threats = stats.active_threats.saturating_sub(1);
652            stats.handled_threats += 1;
653            Ok(())
654        } else {
655            Err(ThreatError::RuleNotFound(threat_id.to_string()))
656        }
657    }
658
659    /// 检查是否被封禁
660    pub fn is_banned(&self, source: &str) -> bool {
661        let ban_list = self.ban_list.read();
662        if let Some(&ban_time) = ban_list.get(source) {
663            if ban_time.elapsed() < Duration::from_secs(3600) {
664                return true;
665            }
666        }
667        false
668    }
669
670    /// 封禁来源
671    pub fn ban(&self, source: &str, duration_secs: u64) {
672        let expiry = Instant::now() + Duration::from_secs(duration_secs);
673        self.ban_list.write().insert(source.to_string(), expiry);
674    }
675
676    /// 解封来源
677    pub fn unban(&self, source: &str) {
678        self.ban_list.write().remove(source);
679    }
680
681    /// 获取统计信息
682    pub fn get_stats(&self) -> ThreatStats {
683        self.stats.read().clone()
684    }
685
686    /// 清理旧活动记录
687    fn cleanup_old_activities(&self) {
688        let retention = Duration::from_secs(self.config.history_retention_secs);
689        self.activity_history
690            .write()
691            .retain(|a| a.timestamp.elapsed() < retention);
692    }
693
694    /// 重置统计
695    pub fn reset_stats(&self) {
696        *self.stats.write() = ThreatStats::default();
697    }
698
699    /// 清空威胁记录
700    pub fn clear_threats(&self) {
701        self.threats.write().clear();
702        self.stats.write().active_threats = 0;
703    }
704}
705
706impl Default for ThreatDetector {
707    fn default() -> Self {
708        Self::new()
709    }
710}
711
#[cfg(test)]
mod tests {
    use super::*;

    /// Records `times` copies of the same activity.
    fn record_n(
        detector: &ThreatDetector,
        kind: &str,
        source: &str,
        data: serde_json::Value,
        times: usize,
    ) {
        for _ in 0..times {
            detector.record_activity(kind, source, data.clone());
        }
    }

    #[test]
    fn test_threat_creation() {
        let threat = Threat::new(
            ThreatType::BruteForce,
            ThreatLevel::High,
            "192.168.1.1",
            "Multiple failed login attempts",
        );

        assert_eq!(
            (threat.threat_type, threat.level),
            (ThreatType::BruteForce, ThreatLevel::High)
        );
        assert!(!threat.handled);
    }

    #[test]
    fn test_threat_level_conversion() {
        let cases = [
            ("high", Some(ThreatLevel::High)),
            ("critical", Some(ThreatLevel::Critical)),
            ("invalid", None),
        ];
        for (input, expected) in cases {
            assert_eq!(ThreatLevel::parse(input), expected, "input: {input}");
        }
    }

    #[test]
    fn test_detection_rule() {
        let rule = DetectionRule::new("Test Rule", ThreatType::BruteForce, ThreatLevel::Medium)
            .with_threshold(0.8)
            .with_time_window(60);

        assert!(rule.enabled);
        assert_eq!(rule.time_window_secs, 60);
        assert_eq!(rule.threshold, 0.8);
    }

    #[test]
    fn test_threat_detector_creation() {
        let detector = ThreatDetector::new();
        assert!(detector.config.enabled);
        assert!(!detector.rules.read().is_empty());
    }

    #[test]
    fn test_record_activity() {
        let detector = ThreatDetector::new();
        record_n(
            &detector,
            "login_attempt",
            "192.168.1.1",
            serde_json::json!({"success": false}),
            1,
        );

        assert_eq!(detector.activity_history.read().len(), 1);
    }

    #[test]
    fn test_detect_no_threats() {
        // No activity recorded, so nothing can be detected.
        assert!(ThreatDetector::new().detect().is_empty());
    }

    #[test]
    fn test_detect_brute_force() {
        let detector = ThreatDetector::new();
        // Simulate a burst of failed logins from one address.
        record_n(
            &detector,
            "login_attempt",
            "192.168.1.100",
            serde_json::json!({"success": false}),
            15,
        );

        let threats = detector.detect();
        assert!(!threats.is_empty());
        assert!(threats
            .iter()
            .any(|t| t.threat_type == ThreatType::BruteForce));
    }

    #[test]
    fn test_respond_to_threat() {
        let detector = ThreatDetector::new();

        let critical = Threat::new(
            ThreatType::SqlInjection,
            ThreatLevel::Critical,
            "192.168.1.1",
            "SQL injection detected",
        );
        assert!(matches!(
            detector.respond(&critical),
            Some(ResponseAction::PermanentBan)
        ));

        let low = Threat::new(
            ThreatType::RateAnomaly,
            ThreatLevel::Low,
            "192.168.1.2",
            "Slightly elevated rate",
        );
        assert!(matches!(detector.respond(&low), Some(ResponseAction::Alert)));
    }

    #[test]
    fn test_ban_functionality() {
        let detector = ThreatDetector::new();
        let ip = "192.168.1.1";

        detector.ban(ip, 3600);
        assert!(detector.is_banned(ip));

        detector.unban(ip);
        assert!(!detector.is_banned(ip));
    }

    #[test]
    fn test_handle_threat() {
        let detector = ThreatDetector::new();

        let mut threat = Threat::new(
            ThreatType::BruteForce,
            ThreatLevel::High,
            "192.168.1.1",
            "Test threat",
        );
        threat.id = "THR-TEST-1".to_string();
        detector.threats.write().push(threat);

        assert!(detector.handle_threat("THR-TEST-1").is_ok());

        let handled = detector
            .threats
            .read()
            .iter()
            .find(|t| t.id == "THR-TEST-1")
            .map(|t| t.handled);
        assert_eq!(handled, Some(true));
    }

    #[test]
    fn test_get_stats() {
        let detector = ThreatDetector::new();
        record_n(&detector, "login", "192.168.1.1", serde_json::json!({}), 2);

        let _threats = detector.detect();
        let _stats = detector.get_stats();
    }

    #[test]
    fn test_add_custom_rule() {
        let detector = ThreatDetector::new();
        let before = detector.rules.read().len();

        detector.add_rule(DetectionRule::new(
            "Custom Rule",
            ThreatType::Custom,
            ThreatLevel::Medium,
        ));

        assert_eq!(detector.rules.read().len(), before + 1);
    }

    #[test]
    fn test_clear_threats() {
        let detector = ThreatDetector::new();
        record_n(&detector, "login", "192.168.1.1", serde_json::json!({}), 5);
        detector.detect();

        detector.clear_threats();
        assert!(detector.threats.read().is_empty());
    }

    #[test]
    fn test_disabled_detector() {
        let detector = ThreatDetector::with_config(ThreatDetectorConfig {
            enabled: false,
            ..Default::default()
        });

        record_n(&detector, "login", "192.168.1.1", serde_json::json!({}), 1);

        // A disabled detector must not report anything.
        assert!(detector.detect().is_empty());
    }

    #[test]
    fn test_auto_response_disabled() {
        let detector = ThreatDetector::with_config(ThreatDetectorConfig {
            auto_response: false,
            ..Default::default()
        });

        let threat = Threat::new(
            ThreatType::BruteForce,
            ThreatLevel::High,
            "192.168.1.1",
            "Test",
        );

        assert!(detector.respond(&threat).is_none());
    }

    #[test]
    fn test_threat_with_target() {
        let threat = Threat::new(
            ThreatType::SqlInjection,
            ThreatLevel::High,
            "attacker",
            "Attack",
        )
        .with_target("users_table");

        assert_eq!(threat.target.as_deref(), Some("users_table"));
    }

    #[test]
    fn test_threat_confidence_clamping() {
        for (input, expected) in [(1.5, 1.0), (-0.5, 0.0)] {
            let threat = Threat::new(
                ThreatType::BruteForce,
                ThreatLevel::Medium,
                "source",
                "desc",
            )
            .with_confidence(input);
            assert_eq!(threat.confidence, expected);
        }
    }

    #[test]
    fn test_multiple_sources() {
        let detector = ThreatDetector::new();

        for ip in ["192.168.1.1", "192.168.1.2", "192.168.1.3"] {
            detector.record_activity("login", ip, serde_json::json!({}));
        }

        // Whether anything fires depends on thresholds; at most one threat per source.
        assert!(detector.detect().len() <= 3);
    }
}