1use async_trait::async_trait;
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::time::SystemTime;
9
10use crate::types::*;
11
12#[async_trait]
14pub trait PolicyEngine: Send + Sync {
15 async fn evaluate_policy(&self, request: PolicyRequest) -> Result<PolicyDecision, PolicyError>;
17
18 async fn register_policy(&self, policy: Policy) -> Result<PolicyId, PolicyError>;
20
21 async fn update_policy(&self, policy_id: PolicyId, policy: Policy) -> Result<(), PolicyError>;
23
24 async fn delete_policy(&self, policy_id: PolicyId) -> Result<(), PolicyError>;
26
27 async fn list_policies(&self) -> Result<Vec<PolicyInfo>, PolicyError>;
29
30 async fn get_policy(&self, policy_id: PolicyId) -> Result<Policy, PolicyError>;
32
33 async fn validate_policy(&self, policy: &Policy) -> Result<ValidationResult, PolicyError>;
35
36 async fn get_policy_stats(&self) -> Result<PolicyStatistics, PolicyError>;
38}
39
40#[derive(Debug, Clone, Serialize, Deserialize)]
42pub struct PolicyRequest {
43 pub agent_id: AgentId,
44 pub action: AgentAction,
45 pub context: PolicyContext,
46 pub timestamp: SystemTime,
47}
48
49#[derive(Debug, Clone, Serialize, Deserialize)]
51pub enum AgentAction {
52 Execute {
53 command: String,
54 args: Vec<String>,
55 },
56 NetworkAccess {
57 destination: String,
58 port: u16,
59 protocol: NetworkProtocol,
60 },
61 FileAccess {
62 path: String,
63 operation: FileOperation,
64 },
65 ResourceAllocation {
66 resource_type: String, amount: u64,
68 },
69 Communication {
70 target: AgentId,
71 message_type: String,
72 },
73 StateTransition {
74 from_state: AgentState,
75 to_state: AgentState,
76 },
77}
78
79#[derive(Debug, Clone, Serialize, Deserialize)]
81pub enum NetworkProtocol {
82 Tcp,
83 Udp,
84 Http,
85 Https,
86 WebSocket,
87}
88
89#[derive(Debug, Clone, Serialize, Deserialize)]
91pub enum FileOperation {
92 Read,
93 Write,
94 Execute,
95 Delete,
96 Create,
97 Modify,
98}
99
100#[derive(Debug, Clone, Serialize, Deserialize)]
102pub struct PolicyContext {
103 pub agent_metadata: AgentMetadata,
104 pub resource_usage: ResourceUsage,
105 pub security_level: SecurityTier,
106 pub environment: HashMap<String, String>,
107 pub previous_actions: Vec<AgentAction>,
108}
109
110#[derive(Debug, Clone, Serialize, Deserialize)]
112pub struct PolicyDecision {
113 pub decision: Decision,
114 pub reason: String,
115 pub conditions: Vec<PolicyCondition>,
116 pub metadata: HashMap<String, String>,
117 pub expires_at: Option<SystemTime>,
118}
119
120#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
122pub enum Decision {
123 Allow,
124 Deny,
125 Conditional,
126 Defer,
127}
128
129#[derive(Debug, Clone, Serialize, Deserialize)]
131pub struct PolicyCondition {
132 pub condition_type: ConditionType,
133 pub parameters: HashMap<String, String>,
134 pub timeout: Option<Duration>,
135}
136
137#[derive(Debug, Clone, Serialize, Deserialize)]
139pub enum ConditionType {
140 ResourceLimit,
141 TimeWindow,
142 ApprovalRequired,
143 AuditRequired,
144 SecurityScan,
145 RateLimited,
146}
147
148#[derive(Debug, Clone, Serialize, Deserialize)]
150pub struct Policy {
151 pub id: Option<PolicyId>,
152 pub name: String,
153 pub description: String,
154 pub version: String,
155 pub rules: Vec<PolicyRule>,
156 pub priority: u32,
157 pub enabled: bool,
158 pub created_at: SystemTime,
159 pub updated_at: SystemTime,
160 pub tags: Vec<String>,
161}
162
163#[derive(Debug, Clone, Serialize, Deserialize)]
165pub struct PolicyRule {
166 pub id: String,
167 pub condition: RuleCondition,
168 pub action: RuleAction,
169 pub metadata: HashMap<String, String>,
170}
171
172#[derive(Debug, Clone, Serialize, Deserialize)]
174pub enum RuleCondition {
175 AgentMatch { patterns: Vec<String> },
176 ActionMatch { action_types: Vec<String> },
177 ResourceMatch { resource_patterns: Vec<String> },
178 TimeMatch { time_windows: Vec<TimeWindow> },
179 SecurityLevelMatch { levels: Vec<SecurityTier> },
180 And { conditions: Vec<RuleCondition> },
181 Or { conditions: Vec<RuleCondition> },
182 Not { condition: Box<RuleCondition> },
183}
184
185#[derive(Debug, Clone, Serialize, Deserialize)]
187pub struct TimeWindow {
188 pub start_time: String, pub end_time: String, pub days: Vec<Weekday>,
191 pub timezone: String,
192}
193
194#[derive(Debug, Clone, Serialize, Deserialize)]
196pub enum Weekday {
197 Monday,
198 Tuesday,
199 Wednesday,
200 Thursday,
201 Friday,
202 Saturday,
203 Sunday,
204}
205
206#[derive(Debug, Clone, Serialize, Deserialize)]
208pub enum RuleAction {
209 Allow { conditions: Vec<PolicyCondition> },
210 Deny { reason: String },
211 Require { requirements: Vec<String> },
212 Limit { limits: HashMap<String, u64> },
213 Audit { level: AuditLevel },
214 Escalate { to: String },
215}
216
217#[derive(Debug, Clone, Serialize, Deserialize)]
219pub enum AuditLevel {
220 Info,
221 Warning,
222 Critical,
223}
224
225#[derive(Debug, Clone, Serialize, Deserialize)]
227pub struct PolicyInfo {
228 pub id: PolicyId,
229 pub name: String,
230 pub description: String,
231 pub version: String,
232 pub priority: u32,
233 pub enabled: bool,
234 pub rule_count: u32,
235 pub created_at: SystemTime,
236 pub updated_at: SystemTime,
237}
238
239#[derive(Debug, Clone, Serialize, Deserialize)]
241pub struct ValidationResult {
242 pub valid: bool,
243 pub errors: Vec<ValidationError>,
244 pub warnings: Vec<ValidationWarning>,
245}
246
247#[derive(Debug, Clone, Serialize, Deserialize)]
249pub struct ValidationError {
250 pub rule_id: Option<String>,
251 pub error_type: String,
252 pub message: String,
253 pub line: Option<u32>,
254 pub column: Option<u32>,
255}
256
257#[derive(Debug, Clone, Serialize, Deserialize)]
259pub struct ValidationWarning {
260 pub rule_id: Option<String>,
261 pub warning_type: String,
262 pub message: String,
263 pub suggestion: Option<String>,
264}
265
266#[derive(Debug, Clone, Serialize, Deserialize)]
268pub struct PolicyStatistics {
269 pub total_evaluations: u64,
270 pub decisions: HashMap<Decision, u64>,
271 pub policy_usage: HashMap<PolicyId, u64>,
272 pub average_evaluation_time: Duration,
273 pub error_rate: f64,
274 pub last_updated: SystemTime,
275}
276
277pub use crate::types::PolicyId;
279
280pub type Duration = std::time::Duration;
282
283pub struct MockPolicyEngine {
285 policies: std::sync::RwLock<HashMap<PolicyId, Policy>>,
286 stats: std::sync::RwLock<PolicyStatistics>,
287}
288
289impl MockPolicyEngine {
290 pub fn new() -> Self {
291 let default_policy = Self::create_default_policy();
292 let mut policies = HashMap::new();
293 if let Some(id) = default_policy.id {
294 policies.insert(id, default_policy);
295 }
296
297 Self {
298 policies: std::sync::RwLock::new(policies),
299 stats: std::sync::RwLock::new(PolicyStatistics {
300 total_evaluations: 0,
301 decisions: HashMap::new(),
302 policy_usage: HashMap::new(),
303 average_evaluation_time: Duration::from_millis(10),
304 error_rate: 0.0,
305 last_updated: SystemTime::now(),
306 }),
307 }
308 }
309
310 fn create_default_policy() -> Policy {
311 Policy {
312 id: Some(PolicyId::new()),
313 name: "Default Allow Policy".to_string(),
314 description: "Default policy that allows most actions".to_string(),
315 version: "1.0.0".to_string(),
316 rules: vec![PolicyRule {
317 id: "default-allow".to_string(),
318 condition: RuleCondition::AgentMatch {
319 patterns: vec!["*".to_string()],
320 },
321 action: RuleAction::Allow { conditions: vec![] },
322 metadata: HashMap::new(),
323 }],
324 priority: 1000,
325 enabled: true,
326 created_at: SystemTime::now(),
327 updated_at: SystemTime::now(),
328 tags: vec!["default".to_string()],
329 }
330 }
331}
332
333impl Default for MockPolicyEngine {
334 fn default() -> Self {
335 Self::new()
336 }
337}
338
339#[async_trait]
340impl PolicyEngine for MockPolicyEngine {
341 async fn evaluate_policy(&self, request: PolicyRequest) -> Result<PolicyDecision, PolicyError> {
342 {
344 let mut stats = self.stats.write().unwrap();
345 stats.total_evaluations += 1;
346 *stats.decisions.entry(Decision::Allow).or_insert(0) += 1;
347 stats.last_updated = SystemTime::now();
348 }
349
350 let decision = match &request.action {
352 AgentAction::Execute { command, .. } => {
353 if command.contains("rm") || command.contains("delete") {
354 Decision::Conditional
355 } else {
356 Decision::Allow
357 }
358 }
359 AgentAction::NetworkAccess { destination, .. } => {
360 if destination.contains("malicious") {
361 Decision::Deny
362 } else {
363 Decision::Allow
364 }
365 }
366 AgentAction::FileAccess {
367 operation: FileOperation::Delete,
368 ..
369 } => Decision::Conditional,
370 AgentAction::FileAccess { .. } => Decision::Allow,
371 _ => Decision::Allow,
372 };
373
374 let conditions = if decision == Decision::Conditional {
375 vec![PolicyCondition {
376 condition_type: ConditionType::ApprovalRequired,
377 parameters: HashMap::new(),
378 timeout: Some(Duration::from_secs(300)),
379 }]
380 } else {
381 vec![]
382 };
383
384 Ok(PolicyDecision {
385 decision,
386 reason: "Mock policy evaluation".to_string(),
387 conditions,
388 metadata: HashMap::new(),
389 expires_at: None,
390 })
391 }
392
393 async fn register_policy(&self, mut policy: Policy) -> Result<PolicyId, PolicyError> {
394 let policy_id = PolicyId::new();
395 policy.id = Some(policy_id);
396 policy.created_at = SystemTime::now();
397 policy.updated_at = SystemTime::now();
398
399 self.policies.write().unwrap().insert(policy_id, policy);
400 Ok(policy_id)
401 }
402
403 async fn update_policy(
404 &self,
405 policy_id: PolicyId,
406 mut policy: Policy,
407 ) -> Result<(), PolicyError> {
408 policy.id = Some(policy_id);
409 policy.updated_at = SystemTime::now();
410
411 let mut policies = self.policies.write().unwrap();
412 if let std::collections::hash_map::Entry::Occupied(mut e) = policies.entry(policy_id) {
413 e.insert(policy);
414 Ok(())
415 } else {
416 Err(PolicyError::PolicyNotFound { id: policy_id })
417 }
418 }
419
420 async fn delete_policy(&self, policy_id: PolicyId) -> Result<(), PolicyError> {
421 let mut policies = self.policies.write().unwrap();
422 if policies.remove(&policy_id).is_some() {
423 Ok(())
424 } else {
425 Err(PolicyError::PolicyNotFound { id: policy_id })
426 }
427 }
428
429 async fn list_policies(&self) -> Result<Vec<PolicyInfo>, PolicyError> {
430 let policies = self.policies.read().unwrap();
431 let policy_infos = policies
432 .values()
433 .map(|policy| PolicyInfo {
434 id: policy.id.unwrap(),
435 name: policy.name.clone(),
436 description: policy.description.clone(),
437 version: policy.version.clone(),
438 priority: policy.priority,
439 enabled: policy.enabled,
440 rule_count: policy.rules.len() as u32,
441 created_at: policy.created_at,
442 updated_at: policy.updated_at,
443 })
444 .collect();
445
446 Ok(policy_infos)
447 }
448
449 async fn get_policy(&self, policy_id: PolicyId) -> Result<Policy, PolicyError> {
450 let policies = self.policies.read().unwrap();
451 policies
452 .get(&policy_id)
453 .cloned()
454 .ok_or(PolicyError::PolicyNotFound { id: policy_id })
455 }
456
457 async fn validate_policy(&self, _policy: &Policy) -> Result<ValidationResult, PolicyError> {
458 Ok(ValidationResult {
460 valid: true,
461 errors: vec![],
462 warnings: vec![],
463 })
464 }
465
466 async fn get_policy_stats(&self) -> Result<PolicyStatistics, PolicyError> {
467 Ok(self.stats.read().unwrap().clone())
468 }
469}
470
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a request for `action` from a fresh agent id with Tier2 test-agent context.
    fn request_for(action: AgentAction) -> PolicyRequest {
        let agent_metadata = AgentMetadata {
            version: "1.0.0".to_string(),
            author: "test".to_string(),
            description: "Test agent".to_string(),
            capabilities: Vec::new(),
            dependencies: Vec::new(),
            resource_requirements: crate::types::agent::ResourceRequirements::default(),
            security_requirements: crate::types::agent::SecurityRequirements::default(),
            custom_fields: std::collections::HashMap::new(),
        };
        let resource_usage = ResourceUsage {
            memory_used: 1024 * 1024,
            cpu_utilization: 1.0,
            disk_io_rate: 0,
            network_io_rate: 0,
            uptime: std::time::Duration::from_secs(60),
        };

        PolicyRequest {
            agent_id: AgentId::new(),
            action,
            context: PolicyContext {
                agent_metadata,
                resource_usage,
                security_level: SecurityTier::Tier2,
                environment: HashMap::new(),
                previous_actions: Vec::new(),
            },
            timestamp: SystemTime::now(),
        }
    }

    /// Registering and fetching a policy round-trips, and a harmless command is allowed.
    #[tokio::test]
    async fn test_mock_policy_engine() {
        let engine = MockPolicyEngine::new();

        let registered_id = engine
            .register_policy(MockPolicyEngine::create_default_policy())
            .await
            .unwrap();
        let stored = engine.get_policy(registered_id).await.unwrap();
        assert_eq!(stored.name, "Default Allow Policy");

        let ls_request = request_for(AgentAction::Execute {
            command: "ls".to_string(),
            args: vec!["-la".to_string()],
        });
        let outcome = engine.evaluate_policy(ls_request).await.unwrap();
        assert_eq!(outcome.decision, Decision::Allow);
    }

    /// The mock validator accepts the default policy without errors.
    #[tokio::test]
    async fn test_policy_validation() {
        let engine = MockPolicyEngine::new();
        let candidate = MockPolicyEngine::create_default_policy();

        let report = engine.validate_policy(&candidate).await.unwrap();
        assert!(report.valid);
        assert!(report.errors.is_empty());
    }
}