1use std::sync::Arc;
18
19use crate::adversarial_policy::{PolicyDecision, PolicyLlmClient, PolicyValidator};
20use crate::audit::{AuditEntry, AuditLogger, AuditResult, chrono_now};
21use crate::executor::{ClaimSource, ToolCall, ToolError, ToolExecutor, ToolOutput};
22use crate::registry::ToolDef;
23
24pub struct AdversarialPolicyGateExecutor<T: ToolExecutor> {
30 inner: T,
31 validator: Arc<PolicyValidator>,
32 llm: Arc<dyn PolicyLlmClient>,
33 audit: Option<Arc<AuditLogger>>,
34}
35
36impl<T: ToolExecutor + std::fmt::Debug> std::fmt::Debug for AdversarialPolicyGateExecutor<T> {
37 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38 f.debug_struct("AdversarialPolicyGateExecutor")
39 .field("inner", &self.inner)
40 .finish_non_exhaustive()
41 }
42}
43
44impl<T: ToolExecutor> AdversarialPolicyGateExecutor<T> {
45 #[must_use]
47 pub fn new(inner: T, validator: Arc<PolicyValidator>, llm: Arc<dyn PolicyLlmClient>) -> Self {
48 Self {
49 inner,
50 validator,
51 llm,
52 audit: None,
53 }
54 }
55
56 #[must_use]
58 pub fn with_audit(mut self, audit: Arc<AuditLogger>) -> Self {
59 self.audit = Some(audit);
60 self
61 }
62
63 async fn check_policy(&self, call: &ToolCall) -> Result<(), ToolError> {
64 tracing::info!(
65 tool = %call.tool_id,
66 status_spinner = true,
67 "Validating tool policy\u{2026}"
68 );
69
70 let decision = self
71 .validator
72 .validate(call.tool_id.as_str(), &call.params, self.llm.as_ref())
73 .await;
74
75 match decision {
76 PolicyDecision::Allow => {
77 tracing::debug!(tool = %call.tool_id, "adversarial policy: allow");
78 self.write_audit(call, "allow", AuditResult::Success, None)
79 .await;
80 Ok(())
81 }
82 PolicyDecision::Deny { reason } => {
83 tracing::warn!(
84 tool = %call.tool_id,
85 reason = %reason,
86 "adversarial policy: deny"
87 );
88 self.write_audit(
89 call,
90 &format!("deny:{reason}"),
91 AuditResult::Blocked {
92 reason: reason.clone(),
93 },
94 None,
95 )
96 .await;
97 Err(ToolError::Blocked {
99 command: "[adversarial] Tool call denied by policy".to_owned(),
100 })
101 }
102 PolicyDecision::Error { message } => {
103 tracing::warn!(
104 tool = %call.tool_id,
105 error = %message,
106 fail_open = self.validator.fail_open(),
107 "adversarial policy: LLM error"
108 );
109 if self.validator.fail_open() {
110 self.write_audit(
111 call,
112 &format!("error:{message}"),
113 AuditResult::Success,
114 None,
115 )
116 .await;
117 Ok(())
118 } else {
119 self.write_audit(
120 call,
121 &format!("error:{message}"),
122 AuditResult::Blocked {
123 reason: "adversarial policy LLM error (fail-closed)".to_owned(),
124 },
125 None,
126 )
127 .await;
128 Err(ToolError::Blocked {
129 command: "[adversarial] Tool call denied: policy check failed".to_owned(),
130 })
131 }
132 }
133 }
134 }
135
136 async fn write_audit(
137 &self,
138 call: &ToolCall,
139 decision: &str,
140 result: AuditResult,
141 claim_source: Option<ClaimSource>,
142 ) {
143 let Some(audit) = &self.audit else { return };
144 let entry = AuditEntry {
145 timestamp: chrono_now(),
146 tool: call.tool_id.clone(),
147 command: params_summary(&call.params),
148 result,
149 duration_ms: 0,
150 error_category: None,
151 error_domain: None,
152 error_phase: None,
153 claim_source,
154 mcp_server_id: None,
155 injection_flagged: false,
156 embedding_anomalous: false,
157 cross_boundary_mcp_to_acp: false,
158 adversarial_policy_decision: Some(decision.to_owned()),
159 exit_code: None,
160 truncated: false,
161 caller_id: call.caller_id.clone(),
162 policy_match: None,
163 correlation_id: None,
164 vigil_risk: None,
165 execution_env: None,
166 resolved_cwd: None,
167 scope_at_definition: None,
168 scope_at_dispatch: None,
169 };
170 audit.log(&entry).await;
171 }
172}
173
174impl<T: ToolExecutor> ToolExecutor for AdversarialPolicyGateExecutor<T> {
175 async fn execute(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
177 self.inner.execute(response).await
178 }
179
180 async fn execute_confirmed(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
181 self.inner.execute_confirmed(response).await
182 }
183
184 fn tool_definitions(&self) -> Vec<ToolDef> {
186 self.inner.tool_definitions()
187 }
188
189 async fn execute_tool_call(&self, call: &ToolCall) -> Result<Option<ToolOutput>, ToolError> {
190 self.check_policy(call).await?;
191 let output = self.inner.execute_tool_call(call).await?;
192 if let Some(ref out) = output {
193 self.write_audit(
194 call,
195 "allow:executed",
196 AuditResult::Success,
197 out.claim_source,
198 )
199 .await;
200 }
201 Ok(output)
202 }
203
204 async fn execute_tool_call_confirmed(
206 &self,
207 call: &ToolCall,
208 ) -> Result<Option<ToolOutput>, ToolError> {
209 self.check_policy(call).await?;
210 let output = self.inner.execute_tool_call_confirmed(call).await?;
211 if let Some(ref out) = output {
212 self.write_audit(
213 call,
214 "allow:executed",
215 AuditResult::Success,
216 out.claim_source,
217 )
218 .await;
219 }
220 Ok(output)
221 }
222
223 fn set_skill_env(&self, env: Option<std::collections::HashMap<String, String>>) {
224 self.inner.set_skill_env(env);
225 }
226
227 fn set_effective_trust(&self, level: crate::SkillTrustLevel) {
228 self.inner.set_effective_trust(level);
229 }
230
231 fn is_tool_retryable(&self, tool_id: &str) -> bool {
232 self.inner.is_tool_retryable(tool_id)
233 }
234}
235
236fn params_summary(params: &serde_json::Map<String, serde_json::Value>) -> String {
237 let s = serde_json::to_string(params).unwrap_or_default();
238 if s.chars().count() > 500 {
239 let truncated: String = s.chars().take(497).collect();
240 format!("{truncated}\u{2026}")
241 } else {
242 s
243 }
244}
245
#[cfg(test)]
mod tests {
    use std::future::Future;
    use std::pin::Pin;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Duration;

    use super::*;
    use crate::adversarial_policy::{PolicyMessage, PolicyValidator};
    use crate::executor::{ToolCall, ToolOutput};

    /// Boxed future shape returned by `PolicyLlmClient::chat`.
    type ChatFuture<'a> = Pin<Box<dyn Future<Output = Result<String, String>> + Send + 'a>>;

    /// Policy LLM stub that always answers with a canned string and counts
    /// how often it was asked.
    struct MockLlm {
        response: String,
        call_count: Arc<AtomicUsize>,
    }

    impl MockLlm {
        /// Returns the shared call counter together with the stub.
        fn new(response: impl Into<String>) -> (Arc<AtomicUsize>, Self) {
            let call_count = Arc::new(AtomicUsize::new(0));
            let stub = Self {
                response: response.into(),
                call_count: Arc::clone(&call_count),
            };
            (call_count, stub)
        }
    }

    impl PolicyLlmClient for MockLlm {
        fn chat<'a>(&'a self, _messages: &'a [PolicyMessage]) -> ChatFuture<'a> {
            self.call_count.fetch_add(1, Ordering::SeqCst);
            let canned = self.response.clone();
            Box::pin(async move { Ok(canned) })
        }
    }

    /// Policy LLM stub whose every call fails with `"network error"`.
    struct FailingLlm;

    impl PolicyLlmClient for FailingLlm {
        fn chat<'a>(&'a self, _: &'a [PolicyMessage]) -> ChatFuture<'a> {
            Box::pin(async { Err("network error".to_owned()) })
        }
    }

    /// Inner executor stub that counts `execute_tool_call` invocations.
    #[derive(Debug)]
    struct MockInner {
        call_count: Arc<AtomicUsize>,
    }

    impl MockInner {
        /// Returns the shared call counter together with the stub.
        fn new() -> (Arc<AtomicUsize>, Self) {
            let call_count = Arc::new(AtomicUsize::new(0));
            let stub = Self {
                call_count: Arc::clone(&call_count),
            };
            (call_count, stub)
        }
    }

    impl ToolExecutor for MockInner {
        async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }

        async fn execute_tool_call(
            &self,
            call: &ToolCall,
        ) -> Result<Option<ToolOutput>, ToolError> {
            self.call_count.fetch_add(1, Ordering::SeqCst);
            Ok(Some(ok_output(call, None)))
        }
    }

    /// Builds the successful output every stub executor returns for `call`.
    fn ok_output(call: &ToolCall, claim_source: Option<crate::executor::ClaimSource>) -> ToolOutput {
        ToolOutput {
            tool_name: call.tool_id.clone(),
            summary: "ok".into(),
            blocks_executed: 1,
            filter_stats: None,
            diff: None,
            streamed: false,
            terminal_id: None,
            locations: None,
            raw_response: None,
            claim_source,
        }
    }

    /// Builds a parameterless tool call for `tool_id`.
    fn make_call(tool_id: &str) -> ToolCall {
        ToolCall {
            tool_id: tool_id.into(),
            params: serde_json::Map::new(),
            caller_id: None,
            context: None,

            tool_call_id: String::new(),
        }
    }

    /// Builds a validator with a single policy and a 500 ms timeout.
    fn make_validator(fail_open: bool) -> Arc<PolicyValidator> {
        Arc::new(PolicyValidator::new(
            vec!["test policy".to_owned()],
            Duration::from_millis(500),
            fail_open,
            Vec::new(),
        ))
    }

    /// Builds a gate around `inner` using a fresh validator and `llm`.
    fn make_gate<T: ToolExecutor>(
        inner: T,
        fail_open: bool,
        llm: impl PolicyLlmClient + 'static,
    ) -> AdversarialPolicyGateExecutor<T> {
        AdversarialPolicyGateExecutor::new(inner, make_validator(fail_open), Arc::new(llm))
    }

    /// Creates a file-backed audit logger inside a fresh temporary directory.
    ///
    /// The directory guard is returned so the caller can keep it alive until
    /// the log file has been read back.
    async fn temp_audit_logger() -> (tempfile::TempDir, std::path::PathBuf, Arc<AuditLogger>) {
        let dir = tempfile::TempDir::new().unwrap();
        let log_path = dir.path().join("audit.log");
        let audit_config = crate::config::AuditConfig {
            enabled: true,
            destination: crate::config::AuditDestination::File(log_path.clone()),
            ..Default::default()
        };
        let audit_logger = Arc::new(
            crate::audit::AuditLogger::from_config(&audit_config, false)
                .await
                .unwrap(),
        );
        (dir, log_path, audit_logger)
    }

    #[tokio::test]
    async fn allow_path_delegates_to_inner() {
        let (llm_calls, llm) = MockLlm::new("ALLOW");
        let (inner_calls, inner) = MockInner::new();
        let gate = make_gate(inner, false, llm);

        let outcome = gate.execute_tool_call(&make_call("shell")).await;

        assert!(outcome.is_ok());
        assert_eq!(
            llm_calls.load(Ordering::SeqCst),
            1,
            "LLM must be called once"
        );
        assert_eq!(
            inner_calls.load(Ordering::SeqCst),
            1,
            "inner executor must be called on allow"
        );
    }

    #[tokio::test]
    async fn deny_path_blocks_and_does_not_call_inner() {
        let (llm_calls, llm) = MockLlm::new("DENY: unsafe command");
        let (inner_calls, inner) = MockInner::new();
        let gate = make_gate(inner, false, llm);

        let outcome = gate.execute_tool_call(&make_call("shell")).await;

        assert!(matches!(outcome, Err(ToolError::Blocked { .. })));
        assert_eq!(llm_calls.load(Ordering::SeqCst), 1);
        assert_eq!(
            inner_calls.load(Ordering::SeqCst),
            0,
            "inner must NOT be called on deny"
        );
    }

    #[tokio::test]
    async fn error_message_is_opaque() {
        let (_, llm) = MockLlm::new("DENY: secret internal policy rule XYZ");
        let (_, inner) = MockInner::new();
        let gate = make_gate(inner, false, llm);

        let err = gate
            .execute_tool_call(&make_call("shell"))
            .await
            .unwrap_err();

        let ToolError::Blocked { command } = err else {
            panic!("expected Blocked error");
        };
        assert!(
            !command.contains("secret internal policy rule XYZ"),
            "LLM denial reason must not leak to main LLM"
        );
    }

    #[tokio::test]
    async fn fail_closed_blocks_on_llm_error() {
        let (_, inner) = MockInner::new();
        let gate = make_gate(inner, false, FailingLlm);

        let outcome = gate.execute_tool_call(&make_call("shell")).await;

        assert!(
            matches!(outcome, Err(ToolError::Blocked { .. })),
            "fail-closed must block on LLM error"
        );
    }

    #[tokio::test]
    async fn fail_open_allows_on_llm_error() {
        let (inner_calls, inner) = MockInner::new();
        let gate = make_gate(inner, true, FailingLlm);

        let outcome = gate.execute_tool_call(&make_call("shell")).await;

        assert!(outcome.is_ok(), "fail-open must allow on LLM error");
        assert_eq!(inner_calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn confirmed_also_enforces_policy() {
        let (_, llm) = MockLlm::new("DENY: blocked");
        let (_, inner) = MockInner::new();
        let gate = make_gate(inner, false, llm);

        let outcome = gate.execute_tool_call_confirmed(&make_call("shell")).await;

        assert!(
            matches!(outcome, Err(ToolError::Blocked { .. })),
            "confirmed path must also enforce adversarial policy"
        );
    }

    #[tokio::test]
    async fn legacy_execute_bypasses_policy() {
        let (llm_calls, llm) = MockLlm::new("DENY: anything");
        let (_, inner) = MockInner::new();
        let gate = make_gate(inner, false, llm);

        let outcome = gate.execute("```shell\necho hi\n```").await;

        assert!(
            outcome.is_ok(),
            "legacy execute must bypass adversarial policy"
        );
        assert_eq!(
            llm_calls.load(Ordering::SeqCst),
            0,
            "LLM must NOT be called for legacy dispatch"
        );
    }

    #[tokio::test]
    async fn delegation_set_skill_env() {
        let (_, llm) = MockLlm::new("ALLOW");
        let (_, inner) = MockInner::new();
        let gate = make_gate(inner, false, llm);

        // Passes when the forwarded call returns without panicking.
        gate.set_skill_env(None);
    }

    #[tokio::test]
    async fn delegation_set_effective_trust() {
        use crate::SkillTrustLevel;

        let (_, llm) = MockLlm::new("ALLOW");
        let (_, inner) = MockInner::new();
        let gate = make_gate(inner, false, llm);

        // Passes when the forwarded call returns without panicking.
        gate.set_effective_trust(SkillTrustLevel::Trusted);
    }

    #[tokio::test]
    async fn delegation_is_tool_retryable() {
        let (_, llm) = MockLlm::new("ALLOW");
        let (_, inner) = MockInner::new();
        let gate = make_gate(inner, false, llm);

        let retryable = gate.is_tool_retryable("shell");

        assert!(!retryable, "MockInner returns false for is_tool_retryable");
    }

    #[tokio::test]
    async fn delegation_tool_definitions() {
        let (_, llm) = MockLlm::new("ALLOW");
        let (_, inner) = MockInner::new();
        let gate = make_gate(inner, false, llm);

        let defs = gate.tool_definitions();

        assert!(defs.is_empty(), "MockInner returns empty tool definitions");
    }

    #[tokio::test]
    async fn audit_entry_contains_adversarial_decision() {
        let (_dir, log_path, audit_logger) = temp_audit_logger().await;

        let (_, llm) = MockLlm::new("ALLOW");
        let (_, inner) = MockInner::new();
        let gate = make_gate(inner, false, llm).with_audit(Arc::clone(&audit_logger));

        gate.execute_tool_call(&make_call("shell")).await.unwrap();

        let content = tokio::fs::read_to_string(&log_path).await.unwrap();
        assert!(
            content.contains("adversarial_policy_decision"),
            "audit entry must contain adversarial_policy_decision field"
        );
        assert!(
            content.contains("\"allow\""),
            "allow decision must be recorded"
        );
    }

    #[tokio::test]
    async fn audit_entry_deny_contains_decision() {
        let (_dir, log_path, audit_logger) = temp_audit_logger().await;

        let (_, llm) = MockLlm::new("DENY: test denial");
        let (_, inner) = MockInner::new();
        let gate = make_gate(inner, false, llm).with_audit(Arc::clone(&audit_logger));

        let _ = gate.execute_tool_call(&make_call("shell")).await;

        let content = tokio::fs::read_to_string(&log_path).await.unwrap();
        assert!(
            content.contains("deny:"),
            "deny decision must be recorded in audit"
        );
    }

    #[tokio::test]
    async fn audit_entry_propagates_claim_source() {
        /// Inner executor stub whose output reports a shell claim source.
        #[derive(Debug)]
        struct InnerWithClaimSource;

        impl ToolExecutor for InnerWithClaimSource {
            async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
                Ok(None)
            }

            async fn execute_tool_call(
                &self,
                call: &ToolCall,
            ) -> Result<Option<ToolOutput>, ToolError> {
                Ok(Some(ok_output(
                    call,
                    Some(crate::executor::ClaimSource::Shell),
                )))
            }
        }

        let (_dir, log_path, audit_logger) = temp_audit_logger().await;

        let (_, llm) = MockLlm::new("ALLOW");
        let gate =
            make_gate(InnerWithClaimSource, false, llm).with_audit(Arc::clone(&audit_logger));

        gate.execute_tool_call(&make_call("shell")).await.unwrap();

        let content = tokio::fs::read_to_string(&log_path).await.unwrap();
        assert!(
            content.contains("\"shell\""),
            "claim_source must be propagated into the post-execution audit entry"
        );
    }
}