sh_layer4/compliance_checker/
checker.rs1use std::collections::HashMap;
6use std::sync::Arc;
7use tokio::sync::RwLock;
8
9use super::report::{CheckResult, ComplianceReport, ComplianceStatus, ReportFormat, Violation};
10use super::rules::{ComplianceRule, ComplianceStandard, RuleCategory, RuleSeverity};
11
12#[derive(Debug, Clone)]
14pub struct ComplianceConfig {
15 pub enabled_standards: Vec<ComplianceStandard>,
17 pub custom_rules: Vec<ComplianceRule>,
19 pub check_interval_secs: u64,
21 pub auto_remediation: bool,
23 pub severity_threshold: RuleSeverity,
25 pub excluded_rules: Vec<String>,
27}
28
29impl Default for ComplianceConfig {
30 fn default() -> Self {
31 Self {
32 enabled_standards: vec![ComplianceStandard::SOC2],
33 custom_rules: Vec::new(),
34 check_interval_secs: 3600,
35 auto_remediation: false,
36 severity_threshold: RuleSeverity::Low,
37 excluded_rules: Vec::new(),
38 }
39 }
40}
41
42impl ComplianceConfig {
43 pub fn new(standards: Vec<ComplianceStandard>) -> Self {
44 Self {
45 enabled_standards: standards,
46 ..Default::default()
47 }
48 }
49
50 pub fn with_custom_rules(mut self, rules: Vec<ComplianceRule>) -> Self {
51 self.custom_rules = rules;
52 self
53 }
54
55 pub fn with_check_interval(mut self, secs: u64) -> Self {
56 self.check_interval_secs = secs;
57 self
58 }
59
60 pub fn with_auto_remediation(mut self, enabled: bool) -> Self {
61 self.auto_remediation = enabled;
62 self
63 }
64
65 pub fn with_severity_threshold(mut self, threshold: RuleSeverity) -> Self {
66 self.severity_threshold = threshold;
67 self
68 }
69
70 pub fn exclude_rule(mut self, rule_id: impl Into<String>) -> Self {
71 self.excluded_rules.push(rule_id.into());
72 self
73 }
74}
75
76#[derive(Debug, Clone)]
78pub struct CheckContext {
79 pub system_info: HashMap<String, String>,
81 pub resources: Vec<ResourceInfo>,
83 pub config_snapshot: serde_json::Value,
85 pub timestamp: chrono::DateTime<chrono::Utc>,
87}
88
89impl Default for CheckContext {
90 fn default() -> Self {
91 Self {
92 system_info: HashMap::new(),
93 resources: Vec::new(),
94 config_snapshot: serde_json::json!({}),
95 timestamp: chrono::Utc::now(),
96 }
97 }
98}
99
100impl CheckContext {
101 pub fn new() -> Self {
102 Self::default()
103 }
104
105 pub fn with_system_info(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
106 self.system_info.insert(key.into(), value.into());
107 self
108 }
109
110 pub fn with_resources(mut self, resources: Vec<ResourceInfo>) -> Self {
111 self.resources = resources;
112 self
113 }
114
115 pub fn with_config(mut self, config: serde_json::Value) -> Self {
116 self.config_snapshot = config;
117 self
118 }
119}
120
/// A single resource (identified by type and id) that a check may inspect.
#[derive(Debug, Clone)]
pub struct ResourceInfo {
    /// Kind of resource, e.g. "server".
    pub resource_type: String,
    /// Identifier of the resource within its type.
    pub resource_id: String,
    /// Free-form key/value attributes of the resource.
    pub metadata: HashMap<String, String>,
}

impl ResourceInfo {
    /// Creates a resource with no metadata.
    pub fn new(resource_type: impl Into<String>, resource_id: impl Into<String>) -> Self {
        Self {
            resource_type: resource_type.into(),
            resource_id: resource_id.into(),
            metadata: HashMap::new(),
        }
    }

    /// Inserts (or overwrites) one metadata entry.
    #[must_use]
    pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.metadata.insert(key.into(), value.into());
        self
    }
}
143
144pub trait RuleChecker: Send + Sync {
146 fn check(&self, context: &CheckContext) -> CheckResult;
147 fn rule_id(&self) -> &str;
148}
149
150type RuleCheckFn = Arc<dyn Fn(&CheckContext) -> Option<Violation> + Send + Sync>;
152
153pub struct BuiltinRuleChecker {
155 rule: ComplianceRule,
156 check_fn: RuleCheckFn,
157}
158
159impl BuiltinRuleChecker {
160 pub fn new<F>(rule: ComplianceRule, check_fn: F) -> Self
161 where
162 F: Fn(&CheckContext) -> Option<Violation> + Send + Sync + 'static,
163 {
164 Self {
165 rule,
166 check_fn: Arc::new(check_fn),
167 }
168 }
169}
170
171impl RuleChecker for BuiltinRuleChecker {
172 fn check(&self, context: &CheckContext) -> CheckResult {
173 match (self.check_fn)(context) {
174 Some(violation) => CheckResult::non_compliant(&self.rule, vec![violation]),
175 None => CheckResult::compliant(&self.rule),
176 }
177 }
178
179 fn rule_id(&self) -> &str {
180 &self.rule.id
181 }
182}
183
184pub struct ComplianceChecker {
186 config: ComplianceConfig,
187 rules: Vec<ComplianceRule>,
188 checkers: HashMap<String, Arc<dyn RuleChecker>>,
189 last_report: Arc<RwLock<Option<ComplianceReport>>>,
190}
191
192impl ComplianceChecker {
193 pub fn new(config: ComplianceConfig) -> Self {
194 let mut checker = Self {
195 config,
196 rules: Vec::new(),
197 checkers: HashMap::new(),
198 last_report: Arc::new(RwLock::new(None)),
199 };
200 checker.load_rules();
201 checker
202 }
203
204 fn load_rules(&mut self) {
205 for standard in &self.config.enabled_standards {
207 for rule in standard.default_rules() {
208 if !self.config.excluded_rules.contains(&rule.id) {
209 self.rules.push(rule);
210 }
211 }
212 }
213
214 for rule in &self.config.custom_rules {
216 if !self.config.excluded_rules.contains(&rule.id) {
217 self.rules.push(rule.clone());
218 }
219 }
220
221 for rule in &self.rules {
223 let checker = self.create_default_checker(rule);
224 self.checkers.insert(rule.id.clone(), Arc::new(checker));
225 }
226 }
227
228 fn create_default_checker(&self, rule: &ComplianceRule) -> BuiltinRuleChecker {
229 let rule_clone = rule.clone();
230 BuiltinRuleChecker::new(rule_clone, move |_context| {
231 None
234 })
235 }
236
237 pub fn register_checker(&mut self, checker: Arc<dyn RuleChecker>) {
239 self.checkers.insert(checker.rule_id().to_string(), checker);
240 }
241
242 pub async fn check(&self, context: &CheckContext) -> ComplianceReport {
244 let mut report = ComplianceReport::new(
245 format!(
246 "Compliance Check - {}",
247 chrono::Utc::now().format("%Y-%m-%d %H:%M")
248 ),
249 self.config.enabled_standards.clone(),
250 );
251
252 for rule in &self.rules {
253 if let Some(checker) = self.checkers.get(&rule.id) {
254 let result = checker.check(context);
255
256 let should_include = match result.status {
258 ComplianceStatus::NonCompliant => result
259 .violations
260 .iter()
261 .any(|v| v.severity >= self.config.severity_threshold),
262 _ => true,
263 };
264
265 if should_include {
266 report.add_result(result);
267 }
268 }
269 }
270
271 report.calculate_score();
272
273 {
275 let mut last_report = self.last_report.write().await;
276 *last_report = Some(report.clone());
277 }
278
279 report
280 }
281
282 pub async fn last_report(&self) -> Option<ComplianceReport> {
284 self.last_report.read().await.clone()
285 }
286
287 pub async fn export_report(&self, format: ReportFormat) -> Option<Vec<u8>> {
289 let report = self.last_report.read().await;
290 report.as_ref().map(|r| r.export(format))
291 }
292
293 pub fn rules(&self) -> &[ComplianceRule] {
295 &self.rules
296 }
297
298 pub fn rules_by_category(&self, category: RuleCategory) -> Vec<&ComplianceRule> {
300 self.rules
301 .iter()
302 .filter(|r| r.category == category)
303 .collect()
304 }
305
306 pub fn rules_by_severity(&self, severity: RuleSeverity) -> Vec<&ComplianceRule> {
308 self.rules
309 .iter()
310 .filter(|r| r.severity == severity)
311 .collect()
312 }
313
314 pub fn quick_check(&self, context: &CheckContext) -> QuickCheckResult {
316 let mut violations = Vec::new();
317 let mut checked_count = 0;
318
319 for rule in &self.rules {
320 if let Some(checker) = self.checkers.get(&rule.id) {
321 let result = checker.check(context);
322 checked_count += 1;
323
324 if !result.violations.is_empty() {
325 violations.extend(result.violations);
326 }
327 }
328 }
329
330 let status = if violations.is_empty() {
331 ComplianceStatus::Compliant
332 } else if violations
333 .iter()
334 .any(|v| matches!(v.severity, RuleSeverity::Critical | RuleSeverity::High))
335 {
336 ComplianceStatus::NonCompliant
337 } else {
338 ComplianceStatus::PartiallyCompliant
339 };
340
341 QuickCheckResult {
342 status,
343 checked_count,
344 violation_count: violations.len(),
345 critical_count: violations
346 .iter()
347 .filter(|v| v.severity == RuleSeverity::Critical)
348 .count(),
349 high_count: violations
350 .iter()
351 .filter(|v| v.severity == RuleSeverity::High)
352 .count(),
353 }
354 }
355
356 pub fn generate_summary(&self, report: &ComplianceReport) -> ComplianceSummary {
358 let critical_issues: Vec<_> = report
359 .violations
360 .iter()
361 .filter(|v| v.severity == RuleSeverity::Critical)
362 .collect();
363
364 let high_issues: Vec<_> = report
365 .violations
366 .iter()
367 .filter(|v| v.severity == RuleSeverity::High)
368 .collect();
369
370 let recommendations: Vec<String> = critical_issues
371 .iter()
372 .filter_map(|v| v.remediation.clone())
373 .chain(high_issues.iter().filter_map(|v| v.remediation.clone()))
374 .take(5)
375 .collect();
376
377 ComplianceSummary {
378 score: report.compliance_score,
379 status: report.overall_status,
380 total_rules: report.summary.total_rules,
381 compliant_rules: report.summary.compliant_rules,
382 total_violations: report.summary.total_violations,
383 critical_violations: report.summary.critical_violations,
384 high_violations: report.summary.high_violations,
385 recommendations,
386 }
387 }
388}
389
390#[derive(Debug, Clone)]
392pub struct QuickCheckResult {
393 pub status: ComplianceStatus,
394 pub checked_count: usize,
395 pub violation_count: usize,
396 pub critical_count: usize,
397 pub high_count: usize,
398}
399
400#[derive(Debug, Clone)]
402pub struct ComplianceSummary {
403 pub score: f32,
404 pub status: ComplianceStatus,
405 pub total_rules: usize,
406 pub compliant_rules: usize,
407 pub total_violations: usize,
408 pub critical_violations: usize,
409 pub high_violations: usize,
410 pub recommendations: Vec<String>,
411}
412
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_compliance_config() {
        let config =
            ComplianceConfig::new(vec![ComplianceStandard::SOC2, ComplianceStandard::HIPAA])
                .with_check_interval(1800)
                .with_auto_remediation(true);

        assert_eq!(config.enabled_standards.len(), 2);
        assert_eq!(config.check_interval_secs, 1800);
        assert!(config.auto_remediation);
    }

    #[test]
    fn test_check_context() {
        let context = CheckContext::new()
            .with_system_info("version", "1.0.0")
            .with_system_info("environment", "production");

        assert_eq!(
            context.system_info.get("version"),
            Some(&"1.0.0".to_string())
        );
        assert_eq!(
            context.system_info.get("environment"),
            Some(&"production".to_string())
        );
    }

    #[test]
    fn test_resource_info() {
        let resource = ResourceInfo::new("server", "srv-001").with_metadata("region", "us-east-1");

        assert_eq!(resource.resource_type, "server");
        assert_eq!(resource.resource_id, "srv-001");
        assert_eq!(
            resource.metadata.get("region"),
            Some(&"us-east-1".to_string())
        );
    }

    #[test]
    fn test_compliance_checker_creation() {
        let config = ComplianceConfig::new(vec![ComplianceStandard::SOC2]);
        let checker = ComplianceChecker::new(config);

        assert!(!checker.rules().is_empty());
    }

    #[test]
    fn test_quick_check() {
        let config = ComplianceConfig::new(vec![ComplianceStandard::SOC2]);
        let checker = ComplianceChecker::new(config);
        let context = CheckContext::new();

        let result = checker.quick_check(&context);
        assert_eq!(result.checked_count, checker.rules().len());
        // Critical and high violations are subsets of all violations.
        assert!(result.critical_count + result.high_count <= result.violation_count);
    }

    #[tokio::test]
    async fn test_check_generates_report() {
        let config = ComplianceConfig::new(vec![ComplianceStandard::SOC2]);
        let checker = ComplianceChecker::new(config);
        let context = CheckContext::new();

        assert!(checker.last_report().await.is_none());

        let report = checker.check(&context).await;
        assert!(!report.results.is_empty());
        assert!(report.compliance_score >= 0.0 && report.compliance_score <= 100.0);

        // `check` caches the report it returns.
        assert!(checker.last_report().await.is_some());
    }

    #[test]
    fn test_rules_filtering() {
        let config = ComplianceConfig::new(vec![ComplianceStandard::SOC2]).exclude_rule("SOC2-001");

        let checker = ComplianceChecker::new(config);

        assert!(!checker.rules().iter().any(|r| r.id == "SOC2-001"));
    }

    #[test]
    fn test_rules_by_category() {
        let config = ComplianceConfig::new(vec![ComplianceStandard::SOC2]);
        let checker = ComplianceChecker::new(config);

        let security_rules = checker.rules_by_category(RuleCategory::Security);
        assert!(!security_rules.is_empty());
        assert!(security_rules
            .iter()
            .all(|r| r.category == RuleCategory::Security));
    }

    #[test]
    fn test_rules_by_severity() {
        let config = ComplianceConfig::new(vec![ComplianceStandard::SOC2]);
        let checker = ComplianceChecker::new(config);

        let critical_rules = checker.rules_by_severity(RuleSeverity::Critical);
        // The filter must return exactly the critical rules, no more, no fewer.
        assert!(critical_rules
            .iter()
            .all(|r| r.severity == RuleSeverity::Critical));
        let expected = checker
            .rules()
            .iter()
            .filter(|r| r.severity == RuleSeverity::Critical)
            .count();
        assert_eq!(critical_rules.len(), expected);
    }
}