Skip to main content

sh_layer4/compliance_checker/
checker.rs

1//! # Compliance Checker
2//!
3//! 合规检查器核心实现。
4
5use std::collections::HashMap;
6use std::sync::Arc;
7use tokio::sync::RwLock;
8
9use super::report::{CheckResult, ComplianceReport, ComplianceStatus, ReportFormat, Violation};
10use super::rules::{ComplianceRule, ComplianceStandard, RuleCategory, RuleSeverity};
11
12/// 合规检查配置
13#[derive(Debug, Clone)]
14pub struct ComplianceConfig {
15    /// 启用的标准
16    pub enabled_standards: Vec<ComplianceStandard>,
17    /// 自定义规则
18    pub custom_rules: Vec<ComplianceRule>,
19    /// 检查间隔(秒)
20    pub check_interval_secs: u64,
21    /// 是否自动修复
22    pub auto_remediation: bool,
23    /// 严重性阈值(低于此级别不报警)
24    pub severity_threshold: RuleSeverity,
25    /// 排除的规则 ID
26    pub excluded_rules: Vec<String>,
27}
28
29impl Default for ComplianceConfig {
30    fn default() -> Self {
31        Self {
32            enabled_standards: vec![ComplianceStandard::SOC2],
33            custom_rules: Vec::new(),
34            check_interval_secs: 3600,
35            auto_remediation: false,
36            severity_threshold: RuleSeverity::Low,
37            excluded_rules: Vec::new(),
38        }
39    }
40}
41
42impl ComplianceConfig {
43    pub fn new(standards: Vec<ComplianceStandard>) -> Self {
44        Self {
45            enabled_standards: standards,
46            ..Default::default()
47        }
48    }
49
50    pub fn with_custom_rules(mut self, rules: Vec<ComplianceRule>) -> Self {
51        self.custom_rules = rules;
52        self
53    }
54
55    pub fn with_check_interval(mut self, secs: u64) -> Self {
56        self.check_interval_secs = secs;
57        self
58    }
59
60    pub fn with_auto_remediation(mut self, enabled: bool) -> Self {
61        self.auto_remediation = enabled;
62        self
63    }
64
65    pub fn with_severity_threshold(mut self, threshold: RuleSeverity) -> Self {
66        self.severity_threshold = threshold;
67        self
68    }
69
70    pub fn exclude_rule(mut self, rule_id: impl Into<String>) -> Self {
71        self.excluded_rules.push(rule_id.into());
72        self
73    }
74}
75
76/// 检查上下文
77#[derive(Debug, Clone)]
78pub struct CheckContext {
79    /// 系统信息
80    pub system_info: HashMap<String, String>,
81    /// 资源清单
82    pub resources: Vec<ResourceInfo>,
83    /// 配置快照
84    pub config_snapshot: serde_json::Value,
85    /// 时间戳
86    pub timestamp: chrono::DateTime<chrono::Utc>,
87}
88
89impl Default for CheckContext {
90    fn default() -> Self {
91        Self {
92            system_info: HashMap::new(),
93            resources: Vec::new(),
94            config_snapshot: serde_json::json!({}),
95            timestamp: chrono::Utc::now(),
96        }
97    }
98}
99
100impl CheckContext {
101    pub fn new() -> Self {
102        Self::default()
103    }
104
105    pub fn with_system_info(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
106        self.system_info.insert(key.into(), value.into());
107        self
108    }
109
110    pub fn with_resources(mut self, resources: Vec<ResourceInfo>) -> Self {
111        self.resources = resources;
112        self
113    }
114
115    pub fn with_config(mut self, config: serde_json::Value) -> Self {
116        self.config_snapshot = config;
117        self
118    }
119}
120
/// Describes one resource in the inventory that is being checked.
#[derive(Clone, Debug)]
pub struct ResourceInfo {
    /// Kind of resource.
    pub resource_type: String,
    /// Identifier of the resource.
    pub resource_id: String,
    /// Free-form metadata as key/value pairs.
    pub metadata: HashMap<String, String>,
}

impl ResourceInfo {
    /// Creates a resource with the given type and identifier and no metadata.
    pub fn new(resource_type: impl Into<String>, resource_id: impl Into<String>) -> Self {
        let resource_type: String = resource_type.into();
        let resource_id: String = resource_id.into();
        Self {
            resource_type,
            resource_id,
            metadata: HashMap::default(),
        }
    }

    /// Adds one metadata entry, overwriting any earlier value stored under `key`.
    pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        let key: String = key.into();
        let value: String = value.into();
        self.metadata.insert(key, value);
        self
    }
}
143
144/// 规则检查器 Trait
145pub trait RuleChecker: Send + Sync {
146    fn check(&self, context: &CheckContext) -> CheckResult;
147    fn rule_id(&self) -> &str;
148}
149
150/// 规则检查函数类型
151type RuleCheckFn = Arc<dyn Fn(&CheckContext) -> Option<Violation> + Send + Sync>;
152
153/// 内置规则检查器
154pub struct BuiltinRuleChecker {
155    rule: ComplianceRule,
156    check_fn: RuleCheckFn,
157}
158
159impl BuiltinRuleChecker {
160    pub fn new<F>(rule: ComplianceRule, check_fn: F) -> Self
161    where
162        F: Fn(&CheckContext) -> Option<Violation> + Send + Sync + 'static,
163    {
164        Self {
165            rule,
166            check_fn: Arc::new(check_fn),
167        }
168    }
169}
170
171impl RuleChecker for BuiltinRuleChecker {
172    fn check(&self, context: &CheckContext) -> CheckResult {
173        match (self.check_fn)(context) {
174            Some(violation) => CheckResult::non_compliant(&self.rule, vec![violation]),
175            None => CheckResult::compliant(&self.rule),
176        }
177    }
178
179    fn rule_id(&self) -> &str {
180        &self.rule.id
181    }
182}
183
184/// 合规检查器
185pub struct ComplianceChecker {
186    config: ComplianceConfig,
187    rules: Vec<ComplianceRule>,
188    checkers: HashMap<String, Arc<dyn RuleChecker>>,
189    last_report: Arc<RwLock<Option<ComplianceReport>>>,
190}
191
192impl ComplianceChecker {
193    pub fn new(config: ComplianceConfig) -> Self {
194        let mut checker = Self {
195            config,
196            rules: Vec::new(),
197            checkers: HashMap::new(),
198            last_report: Arc::new(RwLock::new(None)),
199        };
200        checker.load_rules();
201        checker
202    }
203
204    fn load_rules(&mut self) {
205        // 加载启用标准的默认规则
206        for standard in &self.config.enabled_standards {
207            for rule in standard.default_rules() {
208                if !self.config.excluded_rules.contains(&rule.id) {
209                    self.rules.push(rule);
210                }
211            }
212        }
213
214        // 加载自定义规则
215        for rule in &self.config.custom_rules {
216            if !self.config.excluded_rules.contains(&rule.id) {
217                self.rules.push(rule.clone());
218            }
219        }
220
221        // 为每个规则创建默认检查器
222        for rule in &self.rules {
223            let checker = self.create_default_checker(rule);
224            self.checkers.insert(rule.id.clone(), Arc::new(checker));
225        }
226    }
227
228    fn create_default_checker(&self, rule: &ComplianceRule) -> BuiltinRuleChecker {
229        let rule_clone = rule.clone();
230        BuiltinRuleChecker::new(rule_clone, move |_context| {
231            // 默认检查逻辑 - 需要根据规则类型实现具体检查
232            // 这里返回 None 表示合规(占位实现)
233            None
234        })
235    }
236
237    /// 注册自定义检查器
238    pub fn register_checker(&mut self, checker: Arc<dyn RuleChecker>) {
239        self.checkers.insert(checker.rule_id().to_string(), checker);
240    }
241
242    /// 执行检查
243    pub async fn check(&self, context: &CheckContext) -> ComplianceReport {
244        let mut report = ComplianceReport::new(
245            format!(
246                "Compliance Check - {}",
247                chrono::Utc::now().format("%Y-%m-%d %H:%M")
248            ),
249            self.config.enabled_standards.clone(),
250        );
251
252        for rule in &self.rules {
253            if let Some(checker) = self.checkers.get(&rule.id) {
254                let result = checker.check(context);
255
256                // 过滤低于严重性阈值的结果
257                let should_include = match result.status {
258                    ComplianceStatus::NonCompliant => result
259                        .violations
260                        .iter()
261                        .any(|v| v.severity >= self.config.severity_threshold),
262                    _ => true,
263                };
264
265                if should_include {
266                    report.add_result(result);
267                }
268            }
269        }
270
271        report.calculate_score();
272
273        // 保存报告
274        {
275            let mut last_report = self.last_report.write().await;
276            *last_report = Some(report.clone());
277        }
278
279        report
280    }
281
282    /// 获取上次报告
283    pub async fn last_report(&self) -> Option<ComplianceReport> {
284        self.last_report.read().await.clone()
285    }
286
287    /// 导出报告
288    pub async fn export_report(&self, format: ReportFormat) -> Option<Vec<u8>> {
289        let report = self.last_report.read().await;
290        report.as_ref().map(|r| r.export(format))
291    }
292
293    /// 获取所有规则
294    pub fn rules(&self) -> &[ComplianceRule] {
295        &self.rules
296    }
297
298    /// 按类别获取规则
299    pub fn rules_by_category(&self, category: RuleCategory) -> Vec<&ComplianceRule> {
300        self.rules
301            .iter()
302            .filter(|r| r.category == category)
303            .collect()
304    }
305
306    /// 按严重性获取规则
307    pub fn rules_by_severity(&self, severity: RuleSeverity) -> Vec<&ComplianceRule> {
308        self.rules
309            .iter()
310            .filter(|r| r.severity == severity)
311            .collect()
312    }
313
314    /// 快速检查(不生成完整报告)
315    pub fn quick_check(&self, context: &CheckContext) -> QuickCheckResult {
316        let mut violations = Vec::new();
317        let mut checked_count = 0;
318
319        for rule in &self.rules {
320            if let Some(checker) = self.checkers.get(&rule.id) {
321                let result = checker.check(context);
322                checked_count += 1;
323
324                if !result.violations.is_empty() {
325                    violations.extend(result.violations);
326                }
327            }
328        }
329
330        let status = if violations.is_empty() {
331            ComplianceStatus::Compliant
332        } else if violations
333            .iter()
334            .any(|v| matches!(v.severity, RuleSeverity::Critical | RuleSeverity::High))
335        {
336            ComplianceStatus::NonCompliant
337        } else {
338            ComplianceStatus::PartiallyCompliant
339        };
340
341        QuickCheckResult {
342            status,
343            checked_count,
344            violation_count: violations.len(),
345            critical_count: violations
346                .iter()
347                .filter(|v| v.severity == RuleSeverity::Critical)
348                .count(),
349            high_count: violations
350                .iter()
351                .filter(|v| v.severity == RuleSeverity::High)
352                .count(),
353        }
354    }
355
356    /// 生成合规摘要
357    pub fn generate_summary(&self, report: &ComplianceReport) -> ComplianceSummary {
358        let critical_issues: Vec<_> = report
359            .violations
360            .iter()
361            .filter(|v| v.severity == RuleSeverity::Critical)
362            .collect();
363
364        let high_issues: Vec<_> = report
365            .violations
366            .iter()
367            .filter(|v| v.severity == RuleSeverity::High)
368            .collect();
369
370        let recommendations: Vec<String> = critical_issues
371            .iter()
372            .filter_map(|v| v.remediation.clone())
373            .chain(high_issues.iter().filter_map(|v| v.remediation.clone()))
374            .take(5)
375            .collect();
376
377        ComplianceSummary {
378            score: report.compliance_score,
379            status: report.overall_status,
380            total_rules: report.summary.total_rules,
381            compliant_rules: report.summary.compliant_rules,
382            total_violations: report.summary.total_violations,
383            critical_violations: report.summary.critical_violations,
384            high_violations: report.summary.high_violations,
385            recommendations,
386        }
387    }
388}
389
390/// 快速检查结果
391#[derive(Debug, Clone)]
392pub struct QuickCheckResult {
393    pub status: ComplianceStatus,
394    pub checked_count: usize,
395    pub violation_count: usize,
396    pub critical_count: usize,
397    pub high_count: usize,
398}
399
400/// 合规摘要
401#[derive(Debug, Clone)]
402pub struct ComplianceSummary {
403    pub score: f32,
404    pub status: ComplianceStatus,
405    pub total_rules: usize,
406    pub compliant_rules: usize,
407    pub total_violations: usize,
408    pub critical_violations: usize,
409    pub high_violations: usize,
410    pub recommendations: Vec<String>,
411}
412
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder methods on `ComplianceConfig` set the fields they are named after.
    #[test]
    fn test_compliance_config() {
        let standards = vec![ComplianceStandard::SOC2, ComplianceStandard::HIPAA];
        let config = ComplianceConfig::new(standards)
            .with_check_interval(1800)
            .with_auto_remediation(true);

        assert_eq!(config.enabled_standards.len(), 2);
        assert_eq!(config.check_interval_secs, 1800);
        assert!(config.auto_remediation);
    }

    /// System-info entries added through the builder can be read back.
    #[test]
    fn test_check_context() {
        let context = CheckContext::new()
            .with_system_info("version", "1.0.0")
            .with_system_info("environment", "production");

        let version = context.system_info.get("version").map(String::as_str);
        assert_eq!(version, Some("1.0.0"));
    }

    /// `ResourceInfo` keeps its type and any metadata added through the builder.
    #[test]
    fn test_resource_info() {
        let resource = ResourceInfo::new("server", "srv-001").with_metadata("region", "us-east-1");

        assert_eq!(resource.resource_type, "server");
        let region = resource.metadata.get("region").map(String::as_str);
        assert_eq!(region, Some("us-east-1"));
    }

    /// A checker built for SOC2 loads at least one default rule.
    #[tokio::test]
    async fn test_compliance_checker_creation() {
        let checker = ComplianceChecker::new(ComplianceConfig::new(vec![ComplianceStandard::SOC2]));

        assert!(!checker.rules().is_empty());
    }

    /// A quick check evaluates every loaded rule.
    #[tokio::test]
    async fn test_quick_check() {
        let checker = ComplianceChecker::new(ComplianceConfig::new(vec![ComplianceStandard::SOC2]));
        let context = CheckContext::new();

        let outcome = checker.quick_check(&context);
        assert_eq!(outcome.checked_count, checker.rules().len());
    }

    /// A full check yields results and a score within 0..=100.
    #[tokio::test]
    async fn test_check_generates_report() {
        let checker = ComplianceChecker::new(ComplianceConfig::new(vec![ComplianceStandard::SOC2]));
        let context = CheckContext::new();

        let report = checker.check(&context).await;
        assert!(!report.results.is_empty());
        assert!((0.0..=100.0).contains(&report.compliance_score));
    }

    /// An excluded rule must not appear in the loaded rule list.
    #[test]
    fn test_rules_filtering() {
        let config = ComplianceConfig::new(vec![ComplianceStandard::SOC2]).exclude_rule("SOC2-001");
        let checker = ComplianceChecker::new(config);

        assert!(checker.rules().iter().all(|rule| rule.id != "SOC2-001"));
    }

    /// Filtering SOC2 rules by the security category returns at least one rule.
    #[test]
    fn test_rules_by_category() {
        let checker = ComplianceChecker::new(ComplianceConfig::new(vec![ComplianceStandard::SOC2]));

        let security_rules = checker.rules_by_category(RuleCategory::Security);
        assert!(!security_rules.is_empty());
    }

    /// Filtering by severity never returns more rules than are loaded.
    #[test]
    fn test_rules_by_severity() {
        let checker = ComplianceChecker::new(ComplianceConfig::new(vec![ComplianceStandard::SOC2]));

        let critical_rules = checker.rules_by_severity(RuleSeverity::Critical);
        // The exact count depends on how the default rules are defined.
        assert!(critical_rules.len() <= checker.rules().len());
    }
}