1use std::sync::Arc;
16
17use tracing::debug;
18
19use crate::audit::{AuditEntry, AuditLogger, AuditResult, chrono_now};
20use crate::executor::{ToolCall, ToolError, ToolExecutor, ToolOutput};
21use crate::policy::{PolicyContext, PolicyDecision, PolicyEnforcer};
22use crate::registry::ToolDef;
23
24pub struct PolicyGateExecutor<T: ToolExecutor> {
29 inner: T,
30 enforcer: Arc<PolicyEnforcer>,
31 context: Arc<std::sync::RwLock<PolicyContext>>,
32 audit: Option<Arc<AuditLogger>>,
33}
34
35impl<T: ToolExecutor + std::fmt::Debug> std::fmt::Debug for PolicyGateExecutor<T> {
36 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
37 f.debug_struct("PolicyGateExecutor")
38 .field("inner", &self.inner)
39 .finish_non_exhaustive()
40 }
41}
42
43impl<T: ToolExecutor> PolicyGateExecutor<T> {
44 #[must_use]
46 pub fn new(
47 inner: T,
48 enforcer: Arc<PolicyEnforcer>,
49 context: Arc<std::sync::RwLock<PolicyContext>>,
50 ) -> Self {
51 Self {
52 inner,
53 enforcer,
54 context,
55 audit: None,
56 }
57 }
58
59 #[must_use]
61 pub fn with_audit(mut self, audit: Arc<AuditLogger>) -> Self {
62 self.audit = Some(audit);
63 self
64 }
65
66 fn read_context(&self) -> PolicyContext {
67 match self.context.read() {
70 Ok(ctx) => ctx.clone(),
71 Err(poisoned) => {
72 tracing::warn!("PolicyContext RwLock poisoned; using poisoned value");
73 poisoned.into_inner().clone()
74 }
75 }
76 }
77
78 pub fn update_context(&self, new_ctx: PolicyContext) {
80 match self.context.write() {
81 Ok(mut ctx) => *ctx = new_ctx,
82 Err(poisoned) => {
83 tracing::warn!("PolicyContext RwLock poisoned on write; overwriting");
84 *poisoned.into_inner() = new_ctx;
85 }
86 }
87 }
88
89 async fn check_policy(&self, call: &ToolCall) -> Result<(), ToolError> {
90 let ctx = self.read_context();
91 let decision = self.enforcer.evaluate(&call.tool_id, &call.params, &ctx);
92
93 match &decision {
94 PolicyDecision::Allow { trace } => {
95 debug!(tool = %call.tool_id, trace = %trace, "policy: allow");
96 if let Some(audit) = &self.audit {
97 let entry = AuditEntry {
98 timestamp: chrono_now(),
99 tool: call.tool_id.clone(),
100 command: truncate_params(&call.params),
101 result: AuditResult::Success,
102 duration_ms: 0,
103 error_category: None,
104 error_domain: None,
105 error_phase: None,
106 claim_source: None,
107 mcp_server_id: None,
108 injection_flagged: false,
109 embedding_anomalous: false,
110 cross_boundary_mcp_to_acp: false,
111 adversarial_policy_decision: None,
112 exit_code: None,
113 truncated: false,
114 caller_id: call.caller_id.clone(),
115 policy_match: Some(trace.clone()),
117 };
118 audit.log(&entry).await;
119 }
120 Ok(())
121 }
122 PolicyDecision::Deny { trace } => {
123 debug!(tool = %call.tool_id, trace = %trace, "policy: deny");
124 if let Some(audit) = &self.audit {
125 let entry = AuditEntry {
126 timestamp: chrono_now(),
127 tool: call.tool_id.clone(),
128 command: truncate_params(&call.params),
129 result: AuditResult::Blocked {
130 reason: trace.clone(),
131 },
132 duration_ms: 0,
133 error_category: Some("policy_blocked".to_owned()),
134 error_domain: Some("action".to_owned()),
135 error_phase: None,
136 claim_source: None,
137 mcp_server_id: None,
138 injection_flagged: false,
139 embedding_anomalous: false,
140 cross_boundary_mcp_to_acp: false,
141 adversarial_policy_decision: None,
142 exit_code: None,
143 truncated: false,
144 caller_id: call.caller_id.clone(),
145 policy_match: Some(trace.clone()),
147 };
148 audit.log(&entry).await;
149 }
150 Err(ToolError::Blocked {
152 command: "Tool call denied by policy".to_owned(),
153 })
154 }
155 }
156 }
157}
158
159impl<T: ToolExecutor> ToolExecutor for PolicyGateExecutor<T> {
160 async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
163 Err(ToolError::Blocked {
164 command:
165 "legacy unstructured dispatch is not supported when policy enforcement is enabled"
166 .into(),
167 })
168 }
169
170 async fn execute_confirmed(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
171 Err(ToolError::Blocked {
172 command:
173 "legacy unstructured dispatch is not supported when policy enforcement is enabled"
174 .into(),
175 })
176 }
177
178 fn tool_definitions(&self) -> Vec<ToolDef> {
179 self.inner.tool_definitions()
180 }
181
182 async fn execute_tool_call(&self, call: &ToolCall) -> Result<Option<ToolOutput>, ToolError> {
183 self.check_policy(call).await?;
184 let result = self.inner.execute_tool_call(call).await;
185 if let Ok(Some(ref output)) = result
188 && let Some(colon) = output.tool_name.find(':')
189 {
190 let server_id = output.tool_name[..colon].to_owned();
191 if let Some(audit) = &self.audit {
192 let entry = AuditEntry {
193 timestamp: chrono_now(),
194 tool: call.tool_id.clone(),
195 command: truncate_params(&call.params),
196 result: AuditResult::Success,
197 duration_ms: 0,
198 error_category: None,
199 error_domain: None,
200 error_phase: None,
201 claim_source: None,
202 mcp_server_id: Some(server_id),
203 injection_flagged: false,
204 embedding_anomalous: false,
205 cross_boundary_mcp_to_acp: false,
206 adversarial_policy_decision: None,
207 exit_code: None,
208 truncated: false,
209 caller_id: call.caller_id.clone(),
210 policy_match: None,
211 };
212 audit.log(&entry).await;
213 }
214 }
215 result
216 }
217
218 async fn execute_tool_call_confirmed(
221 &self,
222 call: &ToolCall,
223 ) -> Result<Option<ToolOutput>, ToolError> {
224 self.check_policy(call).await?;
225 self.inner.execute_tool_call_confirmed(call).await
226 }
227
228 fn set_skill_env(&self, env: Option<std::collections::HashMap<String, String>>) {
229 self.inner.set_skill_env(env);
230 }
231
232 fn set_effective_trust(&self, level: crate::SkillTrustLevel) {
233 match self.context.write() {
234 Ok(mut ctx) => ctx.trust_level = level,
235 Err(poisoned) => {
236 tracing::warn!("PolicyContext RwLock poisoned on trust update; overwriting");
237 poisoned.into_inner().trust_level = level;
238 }
239 }
240 self.inner.set_effective_trust(level);
241 }
242
243 fn is_tool_retryable(&self, tool_id: &str) -> bool {
244 self.inner.is_tool_retryable(tool_id)
245 }
246}
247
248fn truncate_params(params: &serde_json::Map<String, serde_json::Value>) -> String {
249 let s = serde_json::to_string(params).unwrap_or_default();
250 if s.chars().count() > 500 {
251 let truncated: String = s.chars().take(497).collect();
252 format!("{truncated}…")
253 } else {
254 s
255 }
256}
257
#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::sync::Arc;

    use super::*;
    use crate::SkillTrustLevel;
    use crate::policy::{
        DefaultEffect, PolicyConfig, PolicyEffect, PolicyEnforcer, PolicyRuleConfig,
    };

    /// Inner executor stub: structured calls succeed and echo the call's `tool_id` back as
    /// the output `tool_name`; legacy `execute` returns `Ok(None)`.
    #[derive(Debug)]
    struct MockExecutor;

    impl ToolExecutor for MockExecutor {
        async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }

        async fn execute_tool_call(
            &self,
            call: &ToolCall,
        ) -> Result<Option<ToolOutput>, ToolError> {
            let output = ToolOutput {
                tool_name: call.tool_id.clone(),
                summary: "ok".into(),
                blocks_executed: 1,
                filter_stats: None,
                diff: None,
                streamed: false,
                terminal_id: None,
                locations: None,
                raw_response: None,
                claim_source: None,
            };
            Ok(Some(output))
        }
    }

    /// Policy config with `enabled: true`, no `policy_file`, and the given default + rules.
    fn policy(default_effect: DefaultEffect, rules: Vec<PolicyRuleConfig>) -> PolicyConfig {
        PolicyConfig {
            enabled: true,
            default_effect,
            rules,
            policy_file: None,
        }
    }

    /// Rule for `tool` with the given `paths` and optional trust threshold; `env`,
    /// `args_match` and `capabilities` are left empty.
    fn rule(
        effect: PolicyEffect,
        tool: &str,
        paths: &[&str],
        trust_level: Option<SkillTrustLevel>,
    ) -> PolicyRuleConfig {
        PolicyRuleConfig {
            effect,
            tool: tool.to_owned(),
            paths: paths.iter().map(|p| (*p).to_owned()).collect(),
            env: vec![],
            trust_level,
            args_match: None,
            capabilities: vec![],
        }
    }

    /// Gate around `MockExecutor` whose context starts at `Trusted` with an empty env.
    fn make_gate(config: &PolicyConfig) -> PolicyGateExecutor<MockExecutor> {
        let enforcer = PolicyEnforcer::compile(config).unwrap();
        let context = PolicyContext {
            trust_level: SkillTrustLevel::Trusted,
            env: HashMap::new(),
        };
        PolicyGateExecutor::new(
            MockExecutor,
            Arc::new(enforcer),
            Arc::new(std::sync::RwLock::new(context)),
        )
    }

    /// Tool call with no caller id; when `file_path` is given it becomes the only param.
    fn tool_call(tool_id: &str, file_path: Option<&str>) -> ToolCall {
        let mut params = serde_json::Map::new();
        if let Some(path) = file_path {
            params.insert("file_path".into(), serde_json::Value::String(path.into()));
        }
        ToolCall {
            tool_id: tool_id.into(),
            params,
            caller_id: None,
        }
    }

    #[tokio::test]
    async fn allow_by_default_when_default_allow() {
        let gate = make_gate(&policy(DefaultEffect::Allow, vec![]));
        let outcome = gate.execute_tool_call(&tool_call("bash", None)).await;
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn deny_by_default_when_default_deny() {
        let gate = make_gate(&policy(DefaultEffect::Deny, vec![]));
        let outcome = gate.execute_tool_call(&tool_call("bash", None)).await;
        assert!(matches!(outcome, Err(ToolError::Blocked { .. })));
    }

    #[tokio::test]
    async fn deny_rule_blocks_tool() {
        let gate = make_gate(&policy(
            DefaultEffect::Allow,
            vec![rule(PolicyEffect::Deny, "shell", &["/etc/*"], None)],
        ));
        let outcome = gate
            .execute_tool_call(&tool_call("shell", Some("/etc/passwd")))
            .await;
        assert!(matches!(outcome, Err(ToolError::Blocked { .. })));
    }

    #[tokio::test]
    async fn allow_rule_permits_tool() {
        let gate = make_gate(&policy(
            DefaultEffect::Deny,
            vec![rule(PolicyEffect::Allow, "shell", &["/tmp/*"], None)],
        ));
        let outcome = gate
            .execute_tool_call(&tool_call("shell", Some("/tmp/foo.sh")))
            .await;
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn error_message_is_generic() {
        let gate = make_gate(&policy(DefaultEffect::Deny, vec![]));
        let err = gate
            .execute_tool_call(&tool_call("bash", None))
            .await
            .unwrap_err();
        let ToolError::Blocked { command } = err else {
            panic!("expected Blocked error");
        };
        assert!(!command.contains("rule["), "must not leak rule index");
        assert!(!command.contains("/etc/"), "must not leak path pattern");
    }

    #[tokio::test]
    async fn confirmed_also_enforces_policy() {
        let gate = make_gate(&policy(DefaultEffect::Deny, vec![]));
        let outcome = gate
            .execute_tool_call_confirmed(&tool_call("bash", None))
            .await;
        assert!(matches!(outcome, Err(ToolError::Blocked { .. })));
    }

    #[tokio::test]
    async fn confirmed_allow_delegates_to_inner() {
        let gate = make_gate(&policy(DefaultEffect::Allow, vec![]));
        let output = gate
            .execute_tool_call_confirmed(&tool_call("shell", None))
            .await
            .expect("allow path must not return an error")
            .expect("inner executor must be invoked and return output on allow");
        assert_eq!(
            output.tool_name, "shell",
            "output tool_name must match the confirmed call"
        );
    }

    #[tokio::test]
    async fn legacy_execute_blocked_when_policy_enabled() {
        let gate = make_gate(&policy(DefaultEffect::Deny, vec![]));
        let response = "```bash\necho hi\n```";
        assert!(matches!(
            gate.execute(response).await,
            Err(ToolError::Blocked { .. })
        ));
        assert!(matches!(
            gate.execute_confirmed(response).await,
            Err(ToolError::Blocked { .. })
        ));
    }

    #[tokio::test]
    async fn set_effective_trust_quarantined_blocks_verified_threshold_rule() {
        let gate = make_gate(&policy(
            DefaultEffect::Deny,
            vec![rule(
                PolicyEffect::Allow,
                "shell",
                &[],
                Some(SkillTrustLevel::Verified),
            )],
        ));
        gate.set_effective_trust(SkillTrustLevel::Quarantined);
        let outcome = gate.execute_tool_call(&tool_call("shell", None)).await;
        assert!(
            matches!(outcome, Err(ToolError::Blocked { .. })),
            "Quarantined context must not satisfy a Verified trust threshold allow rule"
        );
    }

    #[tokio::test]
    async fn set_effective_trust_trusted_satisfies_verified_threshold_rule() {
        let gate = make_gate(&policy(
            DefaultEffect::Deny,
            vec![rule(
                PolicyEffect::Allow,
                "shell",
                &[],
                Some(SkillTrustLevel::Verified),
            )],
        ));
        gate.set_effective_trust(SkillTrustLevel::Trusted);
        let outcome = gate.execute_tool_call(&tool_call("shell", None)).await;
        assert!(
            outcome.is_ok(),
            "Trusted context must satisfy a Verified trust threshold allow rule"
        );
    }
}