Skip to main content

sh_layer0/
threat_detector.rs

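//! # Threat Detector
//!
//! Threat detector: detects and responds to potential security threats.
//!
//! ## Features
//! - Anomalous behavior detection
//! - Attack pattern recognition
//! - Threat scoring
//! - Automatic response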
10
11use parking_lot::RwLock;
12use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14use std::time::{Duration, Instant};
15use thiserror::Error;
16
17/// 威胁检测错误
18#[derive(Debug, Error)]
19pub enum ThreatError {
20    #[error("Invalid threshold: {0}")]
21    InvalidThreshold(String),
22
23    #[error("Detection failed: {0}")]
24    DetectionFailed(String),
25
26    #[error("Rule not found: {0}")]
27    RuleNotFound(String),
28}
29
30/// 威胁级别
31#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
32pub enum ThreatLevel {
33    /// 信息级别 - 无威胁
34    Info,
35    /// 低危 - 轻微异常
36    Low,
37    /// 中危 - 需要关注
38    Medium,
39    /// 高危 - 需要立即处理
40    High,
41    /// 严重 - 紧急响应
42    Critical,
43}
44
45impl Default for ThreatLevel {
46    fn default() -> Self {
47        Self::Info
48    }
49}
50
51impl ThreatLevel {
52    pub fn as_str(&self) -> &'static str {
53        match self {
54            Self::Info => "info",
55            Self::Low => "low",
56            Self::Medium => "medium",
57            Self::High => "high",
58            Self::Critical => "critical",
59        }
60    }
61
62    pub fn from_str(s: &str) -> Option<Self> {
63        match s.to_lowercase().as_str() {
64            "info" => Some(Self::Info),
65            "low" => Some(Self::Low),
66            "medium" => Some(Self::Medium),
67            "high" => Some(Self::High),
68            "critical" => Some(Self::Critical),
69            _ => None,
70        }
71    }
72}
73
74/// 威胁类型
75#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
76pub enum ThreatType {
77    /// 暴力破解
78    BruteForce,
79    /// DDoS 攻击
80    DDoS,
81    /// SQL 注入
82    SqlInjection,
83    /// XSS 攻击
84    XSS,
85    /// 路径遍历
86    PathTraversal,
87    /// 权限提升
88    PrivilegeEscalation,
89    /// 数据泄露
90    DataExfiltration,
91    /// 异常访问
92    AnomalousAccess,
93    /// 速率异常
94    RateAnomaly,
95    /// 自定义威胁
96    Custom,
97}
98
99impl ThreatType {
100    pub fn as_str(&self) -> &'static str {
101        match self {
102            Self::BruteForce => "brute_force",
103            Self::DDoS => "ddos",
104            Self::SqlInjection => "sql_injection",
105            Self::XSS => "xss",
106            Self::PathTraversal => "path_traversal",
107            Self::PrivilegeEscalation => "privilege_escalation",
108            Self::DataExfiltration => "data_exfiltration",
109            Self::AnomalousAccess => "anomalous_access",
110            Self::RateAnomaly => "rate_anomaly",
111            Self::Custom => "custom",
112        }
113    }
114}
115
116/// 检测到的威胁
117#[derive(Debug, Clone, Serialize, Deserialize)]
118pub struct Threat {
119    /// 威胁 ID
120    pub id: String,
121    /// 威胁类型
122    pub threat_type: ThreatType,
123    /// 威胁级别
124    pub level: ThreatLevel,
125    /// 描述
126    pub description: String,
127    /// 来源 IP 或标识
128    pub source: String,
129    /// 目标资源
130    pub target: Option<String>,
131    /// 检测时间
132    pub detected_at: String,
133    /// 证据数据
134    pub evidence: serde_json::Value,
135    /// 置信度 (0.0-1.0)
136    pub confidence: f32,
137    /// 是否已处理
138    pub handled: bool,
139}
140
141impl Threat {
142    pub fn new(
143        threat_type: ThreatType,
144        level: ThreatLevel,
145        source: impl Into<String>,
146        description: impl Into<String>,
147    ) -> Self {
148        Self {
149            id: format!("THR-{}", uuid::Uuid::new_v4()),
150            threat_type,
151            level,
152            description: description.into(),
153            source: source.into(),
154            target: None,
155            detected_at: chrono::Utc::now().to_rfc3339(),
156            evidence: serde_json::json!({}),
157            confidence: 0.5,
158            handled: false,
159        }
160    }
161
162    pub fn with_target(mut self, target: impl Into<String>) -> Self {
163        self.target = Some(target.into());
164        self
165    }
166
167    pub fn with_evidence(mut self, evidence: serde_json::Value) -> Self {
168        self.evidence = evidence;
169        self
170    }
171
172    pub fn with_confidence(mut self, confidence: f32) -> Self {
173        self.confidence = confidence.clamp(0.0, 1.0);
174        self
175    }
176
177    pub fn mark_handled(&mut self) {
178        self.handled = true;
179    }
180}
181
182/// 检测规则
183#[derive(Debug, Clone, Serialize, Deserialize)]
184pub struct DetectionRule {
185    /// 规则 ID
186    pub id: String,
187    /// 规则名称
188    pub name: String,
189    /// 威胁类型
190    pub threat_type: ThreatType,
191    /// 基础威胁级别
192    pub base_level: ThreatLevel,
193    /// 触发阈值
194    pub threshold: f32,
195    /// 时间窗口(秒)
196    pub time_window_secs: u64,
197    /// 是否启用
198    pub enabled: bool,
199    /// 描述
200    pub description: String,
201}
202
203impl DetectionRule {
204    pub fn new(name: impl Into<String>, threat_type: ThreatType, base_level: ThreatLevel) -> Self {
205        Self {
206            id: format!("RULE-{}", uuid::Uuid::new_v4()),
207            name: name.into(),
208            threat_type,
209            base_level,
210            threshold: 0.5,
211            time_window_secs: 300,
212            enabled: true,
213            description: String::new(),
214        }
215    }
216
217    pub fn with_threshold(mut self, threshold: f32) -> Self {
218        self.threshold = threshold;
219        self
220    }
221
222    pub fn with_time_window(mut self, secs: u64) -> Self {
223        self.time_window_secs = secs;
224        self
225    }
226
227    pub fn with_description(mut self, desc: impl Into<String>) -> Self {
228        self.description = desc.into();
229        self
230    }
231}
232
233/// 响应动作
234#[derive(Debug, Clone, Serialize, Deserialize)]
235pub enum ResponseAction {
236    /// 仅记录
237    Log,
238    /// 发送告警
239    Alert,
240    /// 临时封禁
241    TempBan { duration_secs: u64 },
242    /// 永久封禁
243    PermanentBan,
244    /// 限流
245    RateLimit { requests_per_sec: u32 },
246    /// 自定义动作
247    Custom { action: String },
248}
249
250/// 响应规则
251#[derive(Debug, Clone)]
252pub struct ResponseRule {
253    /// 触发的威胁级别
254    pub min_level: ThreatLevel,
255    /// 响应动作
256    pub action: ResponseAction,
257    /// 是否启用
258    pub enabled: bool,
259}
260
261impl ResponseRule {
262    pub fn new(min_level: ThreatLevel, action: ResponseAction) -> Self {
263        Self {
264            min_level,
265            action,
266            enabled: true,
267        }
268    }
269}
270
271/// 威胁统计
272#[derive(Debug, Clone, Default)]
273pub struct ThreatStats {
274    /// 总检测次数
275    pub total_detections: u64,
276    /// 各级别威胁数量
277    pub by_level: HashMap<ThreatLevel, u64>,
278    /// 各类型威胁数量
279    pub by_type: HashMap<ThreatType, u64>,
280    /// 活跃威胁数
281    pub active_threats: u64,
282    /// 已处理威胁数
283    pub handled_threats: u64,
284}
285
286/// 活动记录
287#[derive(Debug, Clone)]
288struct ActivityRecord {
289    /// 活动类型
290    activity_type: String,
291    /// 来源
292    source: String,
293    /// 时间戳
294    timestamp: Instant,
295    /// 相关数据
296    data: serde_json::Value,
297}
298
299/// 威胁检测器配置
300#[derive(Debug, Clone)]
301pub struct ThreatDetectorConfig {
302    /// 是否启用检测
303    pub enabled: bool,
304    /// 检测间隔(秒)
305    pub detection_interval_secs: u64,
306    /// 历史记录保留时间(秒)
307    pub history_retention_secs: u64,
308    /// 自动响应
309    pub auto_response: bool,
310    /// 告警阈值
311    pub alert_threshold: ThreatLevel,
312}
313
314impl Default for ThreatDetectorConfig {
315    fn default() -> Self {
316        Self {
317            enabled: true,
318            detection_interval_secs: 60,
319            history_retention_secs: 3600 * 24,
320            auto_response: true,
321            alert_threshold: ThreatLevel::High,
322        }
323    }
324}
325
326/// 威胁检测器
327pub struct ThreatDetector {
328    /// 配置
329    config: ThreatDetectorConfig,
330    /// 检测规则
331    rules: RwLock<Vec<DetectionRule>>,
332    /// 响应规则
333    response_rules: RwLock<Vec<ResponseRule>>,
334    /// 活动历史
335    activity_history: RwLock<Vec<ActivityRecord>>,
336    /// 检测到的威胁
337    threats: RwLock<Vec<Threat>>,
338    /// 统计信息
339    stats: RwLock<ThreatStats>,
340    /// 封禁列表
341    ban_list: RwLock<HashMap<String, Instant>>,
342}
343
344impl ThreatDetector {
345    /// 创建新的威胁检测器
346    pub fn new() -> Self {
347        Self::with_config(ThreatDetectorConfig::default())
348    }
349
350    /// 使用自定义配置创建
351    pub fn with_config(config: ThreatDetectorConfig) -> Self {
352        let detector = Self {
353            config,
354            rules: RwLock::new(Vec::new()),
355            response_rules: RwLock::new(Vec::new()),
356            activity_history: RwLock::new(Vec::new()),
357            threats: RwLock::new(Vec::new()),
358            stats: RwLock::new(ThreatStats::default()),
359            ban_list: RwLock::new(HashMap::new()),
360        };
361
362        // 添加默认规则
363        detector.add_default_rules();
364        detector.add_default_response_rules();
365        detector
366    }
367
368    fn add_default_rules(&self) {
369        let default_rules = vec![
370            DetectionRule::new(
371                "Brute Force Detection",
372                ThreatType::BruteForce,
373                ThreatLevel::High,
374            )
375            .with_threshold(0.3)
376            .with_time_window(300)
377            .with_description("检测短时间内多次失败登录尝试"),
378            DetectionRule::new(
379                "Rate Anomaly Detection",
380                ThreatType::RateAnomaly,
381                ThreatLevel::Medium,
382            )
383            .with_threshold(0.5)
384            .with_time_window(60)
385            .with_description("检测异常请求速率"),
386            DetectionRule::new(
387                "SQL Injection Detection",
388                ThreatType::SqlInjection,
389                ThreatLevel::Critical,
390            )
391            .with_threshold(0.8)
392            .with_time_window(1)
393            .with_description("检测 SQL 注入模式"),
394            DetectionRule::new(
395                "Path Traversal Detection",
396                ThreatType::PathTraversal,
397                ThreatLevel::High,
398            )
399            .with_threshold(0.7)
400            .with_time_window(1)
401            .with_description("检测路径遍历攻击"),
402            DetectionRule::new("XSS Detection", ThreatType::XSS, ThreatLevel::High)
403                .with_threshold(0.7)
404                .with_time_window(1)
405                .with_description("检测跨站脚本攻击"),
406        ];
407
408        let mut rules = self.rules.write();
409        for rule in default_rules {
410            rules.push(rule);
411        }
412    }
413
414    fn add_default_response_rules(&self) {
415        let default_responses = vec![
416            ResponseRule::new(ThreatLevel::Critical, ResponseAction::PermanentBan),
417            ResponseRule::new(
418                ThreatLevel::High,
419                ResponseAction::TempBan {
420                    duration_secs: 3600,
421                },
422            ),
423            ResponseRule::new(
424                ThreatLevel::Medium,
425                ResponseAction::RateLimit {
426                    requests_per_sec: 10,
427                },
428            ),
429            ResponseRule::new(ThreatLevel::Low, ResponseAction::Alert),
430            ResponseRule::new(ThreatLevel::Info, ResponseAction::Log),
431        ];
432
433        let mut response_rules = self.response_rules.write();
434        for rule in default_responses {
435            response_rules.push(rule);
436        }
437    }
438
439    /// 记录活动
440    pub fn record_activity(
441        &self,
442        activity_type: impl Into<String>,
443        source: impl Into<String>,
444        data: serde_json::Value,
445    ) {
446        let record = ActivityRecord {
447            activity_type: activity_type.into(),
448            source: source.into(),
449            timestamp: Instant::now(),
450            data,
451        };
452
453        self.activity_history.write().push(record);
454        self.cleanup_old_activities();
455    }
456
457    /// 检测威胁
458    pub fn detect(&self) -> Vec<Threat> {
459        if !self.config.enabled {
460            return Vec::new();
461        }
462
463        let mut detected_threats = Vec::new();
464        let rules = self.rules.read();
465        let activity = self.activity_history.read();
466
467        for rule in rules.iter().filter(|r| r.enabled) {
468            let threats = self.detect_with_rule(rule, &activity);
469            detected_threats.extend(threats);
470        }
471
472        // 更新统计
473        self.update_stats(&detected_threats);
474
475        // 存储威胁
476        let mut threats = self.threats.write();
477        for threat in &detected_threats {
478            threats.push(threat.clone());
479        }
480
481        detected_threats
482    }
483
484    fn detect_with_rule(&self, rule: &DetectionRule, activity: &[ActivityRecord]) -> Vec<Threat> {
485        let now = Instant::now();
486        let window = Duration::from_secs(rule.time_window_secs);
487
488        let relevant_activities: Vec<_> = activity
489            .iter()
490            .filter(|a| {
491                a.timestamp.elapsed() < window && self.activity_matches_rule(&a.activity_type, rule)
492            })
493            .collect();
494
495        if relevant_activities.is_empty() {
496            return Vec::new();
497        }
498
499        // 计算威胁分数
500        let score = self.calculate_threat_score(rule, &relevant_activities);
501
502        if score >= rule.threshold {
503            // 按来源分组
504            let mut by_source: HashMap<String, Vec<&ActivityRecord>> = HashMap::new();
505            for act in relevant_activities {
506                by_source.entry(act.source.clone()).or_default().push(act);
507            }
508
509            by_source
510                .into_iter()
511                .map(|(source, activities)| {
512                    let level = self.calculate_level(rule.base_level, score);
513                    Threat::new(rule.threat_type, level, source, &rule.description)
514                        .with_confidence(score)
515                        .with_evidence(serde_json::json!({
516                            "rule_id": rule.id,
517                            "rule_name": rule.name,
518                            "activity_count": activities.len(),
519                            "score": score,
520                        }))
521                })
522                .collect()
523        } else {
524            Vec::new()
525        }
526    }
527
528    fn activity_matches_rule(&self, activity_type: &str, rule: &DetectionRule) -> bool {
529        match rule.threat_type {
530            ThreatType::BruteForce => {
531                activity_type.contains("login") || activity_type.contains("auth")
532            }
533            ThreatType::RateAnomaly => activity_type.contains("request"),
534            ThreatType::SqlInjection => {
535                activity_type.contains("query") || activity_type.contains("sql")
536            }
537            ThreatType::XSS => activity_type.contains("input") || activity_type.contains("html"),
538            ThreatType::PathTraversal => {
539                activity_type.contains("file") || activity_type.contains("path")
540            }
541            ThreatType::DDoS => {
542                activity_type.contains("request") || activity_type.contains("connection")
543            }
544            ThreatType::PrivilegeEscalation => {
545                activity_type.contains("permission") || activity_type.contains("admin")
546            }
547            ThreatType::DataExfiltration => {
548                activity_type.contains("download") || activity_type.contains("export")
549            }
550            ThreatType::AnomalousAccess => activity_type.contains("access"),
551            ThreatType::Custom => true,
552        }
553    }
554
555    fn calculate_threat_score(&self, rule: &DetectionRule, activities: &[&ActivityRecord]) -> f32 {
556        if activities.is_empty() {
557            return 0.0;
558        }
559
560        let count = activities.len() as f32;
561        let window = rule.time_window_secs as f32;
562        let rate = count / window.max(1.0);
563
564        // 基于活动频率计算分数
565        match rule.threat_type {
566            ThreatType::BruteForce => (rate * 60.0).min(1.0),
567            ThreatType::RateAnomaly => (rate / 10.0).min(1.0),
568            ThreatType::SqlInjection | ThreatType::XSS | ThreatType::PathTraversal => {
569                // 检测恶意模式
570                activities.iter().any(|a| {
571                    let data = a.data.to_string().to_lowercase();
572                    data.contains("select") || data.contains("script") || data.contains("../")
573                }) as usize as f32
574            }
575            _ => (count / 5.0).min(1.0),
576        }
577    }
578
579    fn calculate_level(&self, base_level: ThreatLevel, score: f32) -> ThreatLevel {
580        if score >= 0.9 {
581            ThreatLevel::Critical
582        } else if score >= 0.7 {
583            ThreatLevel::High
584        } else if score >= 0.5 {
585            base_level
586        } else if score >= 0.3 {
587            ThreatLevel::Low
588        } else {
589            ThreatLevel::Info
590        }
591    }
592
593    fn update_stats(&self, new_threats: &[Threat]) {
594        let mut stats = self.stats.write();
595        stats.total_detections += new_threats.len() as u64;
596
597        for threat in new_threats {
598            *stats.by_level.entry(threat.level).or_default() += 1;
599            *stats.by_type.entry(threat.threat_type).or_default() += 1;
600            if threat.handled {
601                stats.handled_threats += 1;
602            } else {
603                stats.active_threats += 1;
604            }
605        }
606    }
607
608    /// 响应威胁
609    pub fn respond(&self, threat: &Threat) -> Option<ResponseAction> {
610        if !self.config.auto_response {
611            return None;
612        }
613
614        let response_rules = self.response_rules.read();
615        response_rules
616            .iter()
617            .filter(|r| r.enabled && threat.level >= r.min_level)
618            .max_by_key(|r| r.min_level.clone() as i32)
619            .map(|r| r.action.clone())
620    }
621
622    /// 添加检测规则
623    pub fn add_rule(&self, rule: DetectionRule) {
624        self.rules.write().push(rule);
625    }
626
627    /// 添加响应规则
628    pub fn add_response_rule(&self, rule: ResponseRule) {
629        self.response_rules.write().push(rule);
630    }
631
632    /// 获取威胁列表
633    pub fn get_threats(&self) -> Vec<Threat> {
634        self.threats.read().clone()
635    }
636
637    /// 获取活跃威胁
638    pub fn get_active_threats(&self) -> Vec<Threat> {
639        self.threats
640            .read()
641            .iter()
642            .filter(|t| !t.handled)
643            .cloned()
644            .collect()
645    }
646
647    /// 标记威胁已处理
648    pub fn handle_threat(&self, threat_id: &str) -> Result<(), ThreatError> {
649        let mut threats = self.threats.write();
650        if let Some(threat) = threats.iter_mut().find(|t| t.id == threat_id) {
651            threat.mark_handled();
652            let mut stats = self.stats.write();
653            stats.active_threats = stats.active_threats.saturating_sub(1);
654            stats.handled_threats += 1;
655            Ok(())
656        } else {
657            Err(ThreatError::RuleNotFound(threat_id.to_string()))
658        }
659    }
660
661    /// 检查是否被封禁
662    pub fn is_banned(&self, source: &str) -> bool {
663        let ban_list = self.ban_list.read();
664        if let Some(&ban_time) = ban_list.get(source) {
665            if ban_time.elapsed() < Duration::from_secs(3600) {
666                return true;
667            }
668        }
669        false
670    }
671
672    /// 封禁来源
673    pub fn ban(&self, source: &str, duration_secs: u64) {
674        let expiry = Instant::now() + Duration::from_secs(duration_secs);
675        self.ban_list.write().insert(source.to_string(), expiry);
676    }
677
678    /// 解封来源
679    pub fn unban(&self, source: &str) {
680        self.ban_list.write().remove(source);
681    }
682
683    /// 获取统计信息
684    pub fn get_stats(&self) -> ThreatStats {
685        self.stats.read().clone()
686    }
687
688    /// 清理旧活动记录
689    fn cleanup_old_activities(&self) {
690        let retention = Duration::from_secs(self.config.history_retention_secs);
691        self.activity_history
692            .write()
693            .retain(|a| a.timestamp.elapsed() < retention);
694    }
695
696    /// 重置统计
697    pub fn reset_stats(&self) {
698        *self.stats.write() = ThreatStats::default();
699    }
700
701    /// 清空威胁记录
702    pub fn clear_threats(&self) {
703        self.threats.write().clear();
704        self.stats.write().active_threats = 0;
705    }
706}
707
708impl Default for ThreatDetector {
709    fn default() -> Self {
710        Self::new()
711    }
712}
713
#[cfg(test)]
mod tests {
    //! Unit tests for the threat detector and its data types.

    use super::*;

    #[test]
    fn test_threat_creation() {
        let threat = Threat::new(
            ThreatType::BruteForce,
            ThreatLevel::High,
            "192.168.1.1",
            "Multiple failed login attempts",
        );

        assert_eq!(threat.threat_type, ThreatType::BruteForce);
        assert_eq!(threat.level, ThreatLevel::High);
        assert!(!threat.handled);
    }

    #[test]
    fn test_threat_level_conversion() {
        assert_eq!(ThreatLevel::from_str("high"), Some(ThreatLevel::High));
        assert_eq!(
            ThreatLevel::from_str("critical"),
            Some(ThreatLevel::Critical)
        );
        assert_eq!(ThreatLevel::from_str("invalid"), None);
    }

    #[test]
    fn test_detection_rule() {
        let rule = DetectionRule::new("Test Rule", ThreatType::BruteForce, ThreatLevel::Medium)
            .with_threshold(0.8)
            .with_time_window(60);

        assert_eq!(rule.threshold, 0.8);
        assert_eq!(rule.time_window_secs, 60);
        assert!(rule.enabled);
    }

    #[test]
    fn test_threat_detector_creation() {
        let detector = ThreatDetector::new();
        assert!(detector.config.enabled);
        assert!(!detector.rules.read().is_empty());
    }

    #[test]
    fn test_record_activity() {
        let detector = ThreatDetector::new();
        detector.record_activity(
            "login_attempt",
            "192.168.1.1",
            serde_json::json!({"success": false}),
        );

        let history = detector.activity_history.read();
        assert_eq!(history.len(), 1);
    }

    #[test]
    fn test_detect_no_threats() {
        let detector = ThreatDetector::new();
        let threats = detector.detect();
        // No activity, so no threats.
        assert!(threats.is_empty());
    }

    #[test]
    fn test_detect_brute_force() {
        let detector = ThreatDetector::new();

        // Simulate repeated failed logins.
        for _ in 0..15 {
            detector.record_activity(
                "login_attempt",
                "192.168.1.100",
                serde_json::json!({"success": false}),
            );
        }

        let threats = detector.detect();
        assert!(!threats.is_empty());

        let bf_threats: Vec<_> = threats
            .iter()
            .filter(|t| t.threat_type == ThreatType::BruteForce)
            .collect();
        assert!(!bf_threats.is_empty());
    }

    #[test]
    fn test_respond_to_threat() {
        let detector = ThreatDetector::new();

        let critical_threat = Threat::new(
            ThreatType::SqlInjection,
            ThreatLevel::Critical,
            "192.168.1.1",
            "SQL injection detected",
        );

        let response = detector.respond(&critical_threat);
        assert!(matches!(response, Some(ResponseAction::PermanentBan)));

        let low_threat = Threat::new(
            ThreatType::RateAnomaly,
            ThreatLevel::Low,
            "192.168.1.2",
            "Slightly elevated rate",
        );

        let response = detector.respond(&low_threat);
        assert!(matches!(response, Some(ResponseAction::Alert)));
    }

    #[test]
    fn test_ban_functionality() {
        let detector = ThreatDetector::new();

        detector.ban("192.168.1.1", 3600);
        assert!(detector.is_banned("192.168.1.1"));

        detector.unban("192.168.1.1");
        assert!(!detector.is_banned("192.168.1.1"));
    }

    #[test]
    fn test_handle_threat() {
        let detector = ThreatDetector::new();

        // Create a threat with a known id and store it directly.
        let mut threat = Threat::new(
            ThreatType::BruteForce,
            ThreatLevel::High,
            "192.168.1.1",
            "Test threat",
        );
        threat.id = "THR-TEST-1".to_string();
        detector.threats.write().push(threat);

        let result = detector.handle_threat("THR-TEST-1");
        assert!(result.is_ok());

        let threats = detector.threats.read();
        assert!(
            threats
                .iter()
                .find(|t| t.id == "THR-TEST-1")
                .unwrap()
                .handled
        );
    }

    #[test]
    fn test_get_stats() {
        let detector = ThreatDetector::new();

        detector.record_activity("login", "192.168.1.1", serde_json::json!({}));
        detector.record_activity("login", "192.168.1.1", serde_json::json!({}));

        let threats = detector.detect();
        let stats = detector.get_stats();

        // On a fresh detector the statistics must mirror what `detect`
        // returned. The old `total_detections >= 0` check was always true for
        // an unsigned counter and verified nothing.
        assert_eq!(stats.total_detections, threats.len() as u64);
        assert_eq!(stats.active_threats, threats.len() as u64);
        assert_eq!(stats.handled_threats, 0);
    }

    #[test]
    fn test_add_custom_rule() {
        let detector = ThreatDetector::new();
        let initial_count = detector.rules.read().len();

        let custom_rule =
            DetectionRule::new("Custom Rule", ThreatType::Custom, ThreatLevel::Medium);
        detector.add_rule(custom_rule);

        assert_eq!(detector.rules.read().len(), initial_count + 1);
    }

    #[test]
    fn test_clear_threats() {
        let detector = ThreatDetector::new();

        // Produce some threats.
        for _ in 0..5 {
            detector.record_activity("login", "192.168.1.1", serde_json::json!({}));
        }
        detector.detect();

        detector.clear_threats();
        assert!(detector.threats.read().is_empty());
    }

    #[test]
    fn test_disabled_detector() {
        let config = ThreatDetectorConfig {
            enabled: false,
            ..Default::default()
        };
        let detector = ThreatDetector::with_config(config);

        detector.record_activity("login", "192.168.1.1", serde_json::json!({}));
        let threats = detector.detect();

        // A disabled detector must not report threats.
        assert!(threats.is_empty());
    }

    #[test]
    fn test_auto_response_disabled() {
        let config = ThreatDetectorConfig {
            auto_response: false,
            ..Default::default()
        };
        let detector = ThreatDetector::with_config(config);

        let threat = Threat::new(
            ThreatType::BruteForce,
            ThreatLevel::High,
            "192.168.1.1",
            "Test",
        );
        let response = detector.respond(&threat);

        assert!(response.is_none());
    }

    #[test]
    fn test_threat_with_target() {
        let threat = Threat::new(
            ThreatType::SqlInjection,
            ThreatLevel::High,
            "attacker",
            "Attack",
        )
        .with_target("users_table");

        assert_eq!(threat.target, Some("users_table".to_string()));
    }

    #[test]
    fn test_threat_confidence_clamping() {
        let threat = Threat::new(
            ThreatType::BruteForce,
            ThreatLevel::Medium,
            "source",
            "desc",
        )
        .with_confidence(1.5);
        assert_eq!(threat.confidence, 1.0);

        let threat = Threat::new(
            ThreatType::BruteForce,
            ThreatLevel::Medium,
            "source",
            "desc",
        )
        .with_confidence(-0.5);
        assert_eq!(threat.confidence, 0.0);
    }

    #[test]
    fn test_multiple_sources() {
        let detector = ThreatDetector::new();

        // Activity from different sources.
        detector.record_activity("login", "192.168.1.1", serde_json::json!({}));
        detector.record_activity("login", "192.168.1.2", serde_json::json!({}));
        detector.record_activity("login", "192.168.1.3", serde_json::json!({}));

        let threats = detector.detect();
        // May or may not trigger (depends on thresholds), but never more
        // than one threat per source here.
        assert!(threats.len() <= 3);
    }
}