1use std::sync::Arc;
18
19use crate::adversarial_policy::{PolicyDecision, PolicyLlmClient, PolicyValidator};
20use crate::audit::{AuditEntry, AuditLogger, AuditResult, chrono_now};
21use crate::executor::{ClaimSource, ToolCall, ToolError, ToolExecutor, ToolOutput};
22use crate::registry::ToolDef;
23
24pub struct AdversarialPolicyGateExecutor<T: ToolExecutor> {
30 inner: T,
31 validator: Arc<PolicyValidator>,
32 llm: Arc<dyn PolicyLlmClient>,
33 audit: Option<Arc<AuditLogger>>,
34}
35
36impl<T: ToolExecutor + std::fmt::Debug> std::fmt::Debug for AdversarialPolicyGateExecutor<T> {
37 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38 f.debug_struct("AdversarialPolicyGateExecutor")
39 .field("inner", &self.inner)
40 .finish_non_exhaustive()
41 }
42}
43
44impl<T: ToolExecutor> AdversarialPolicyGateExecutor<T> {
45 #[must_use]
47 pub fn new(inner: T, validator: Arc<PolicyValidator>, llm: Arc<dyn PolicyLlmClient>) -> Self {
48 Self {
49 inner,
50 validator,
51 llm,
52 audit: None,
53 }
54 }
55
56 #[must_use]
58 pub fn with_audit(mut self, audit: Arc<AuditLogger>) -> Self {
59 self.audit = Some(audit);
60 self
61 }
62
63 async fn check_policy(&self, call: &ToolCall) -> Result<(), ToolError> {
64 tracing::info!(
65 tool = %call.tool_id,
66 status_spinner = true,
67 "Validating tool policy\u{2026}"
68 );
69
70 let decision = self
71 .validator
72 .validate(call.tool_id.as_str(), &call.params, self.llm.as_ref())
73 .await;
74
75 match decision {
76 PolicyDecision::Allow => {
77 tracing::debug!(tool = %call.tool_id, "adversarial policy: allow");
78 self.write_audit(call, "allow", AuditResult::Success, None)
79 .await;
80 Ok(())
81 }
82 PolicyDecision::Deny { reason } => {
83 tracing::warn!(
84 tool = %call.tool_id,
85 reason = %reason,
86 "adversarial policy: deny"
87 );
88 self.write_audit(
89 call,
90 &format!("deny:{reason}"),
91 AuditResult::Blocked {
92 reason: reason.clone(),
93 },
94 None,
95 )
96 .await;
97 Err(ToolError::Blocked {
99 command: "[adversarial] Tool call denied by policy".to_owned(),
100 })
101 }
102 PolicyDecision::Error { message } => {
103 tracing::warn!(
104 tool = %call.tool_id,
105 error = %message,
106 fail_open = self.validator.fail_open(),
107 "adversarial policy: LLM error"
108 );
109 if self.validator.fail_open() {
110 self.write_audit(
111 call,
112 &format!("error:{message}"),
113 AuditResult::Success,
114 None,
115 )
116 .await;
117 Ok(())
118 } else {
119 self.write_audit(
120 call,
121 &format!("error:{message}"),
122 AuditResult::Blocked {
123 reason: "adversarial policy LLM error (fail-closed)".to_owned(),
124 },
125 None,
126 )
127 .await;
128 Err(ToolError::Blocked {
129 command: "[adversarial] Tool call denied: policy check failed".to_owned(),
130 })
131 }
132 }
133 }
134 }
135
136 async fn write_audit(
137 &self,
138 call: &ToolCall,
139 decision: &str,
140 result: AuditResult,
141 claim_source: Option<ClaimSource>,
142 ) {
143 let Some(audit) = &self.audit else { return };
144 let entry = AuditEntry {
145 timestamp: chrono_now(),
146 tool: call.tool_id.clone(),
147 command: params_summary(&call.params),
148 result,
149 duration_ms: 0,
150 error_category: None,
151 error_domain: None,
152 error_phase: None,
153 claim_source,
154 mcp_server_id: None,
155 injection_flagged: false,
156 embedding_anomalous: false,
157 cross_boundary_mcp_to_acp: false,
158 adversarial_policy_decision: Some(decision.to_owned()),
159 exit_code: None,
160 truncated: false,
161 caller_id: call.caller_id.clone(),
162 skill_name: call.skill_name.clone(),
163 policy_match: None,
164 correlation_id: None,
165 vigil_risk: None,
166 execution_env: None,
167 resolved_cwd: None,
168 scope_at_definition: None,
169 scope_at_dispatch: None,
170 };
171 audit.log(&entry).await;
172 }
173}
174
175impl<T: ToolExecutor> ToolExecutor for AdversarialPolicyGateExecutor<T> {
176 async fn execute(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
178 self.inner.execute(response).await
179 }
180
181 async fn execute_confirmed(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
182 self.inner.execute_confirmed(response).await
183 }
184
185 fn tool_definitions(&self) -> Vec<ToolDef> {
187 self.inner.tool_definitions()
188 }
189
190 async fn execute_tool_call(&self, call: &ToolCall) -> Result<Option<ToolOutput>, ToolError> {
191 self.check_policy(call).await?;
192 let output = self.inner.execute_tool_call(call).await?;
193 if let Some(ref out) = output {
194 self.write_audit(
195 call,
196 "allow:executed",
197 AuditResult::Success,
198 out.claim_source,
199 )
200 .await;
201 }
202 Ok(output)
203 }
204
205 async fn execute_tool_call_confirmed(
207 &self,
208 call: &ToolCall,
209 ) -> Result<Option<ToolOutput>, ToolError> {
210 self.check_policy(call).await?;
211 let output = self.inner.execute_tool_call_confirmed(call).await?;
212 if let Some(ref out) = output {
213 self.write_audit(
214 call,
215 "allow:executed",
216 AuditResult::Success,
217 out.claim_source,
218 )
219 .await;
220 }
221 Ok(output)
222 }
223
224 fn set_skill_env(&self, env: Option<std::collections::HashMap<String, String>>) {
225 self.inner.set_skill_env(env);
226 }
227
228 fn set_effective_trust(&self, level: crate::SkillTrustLevel) {
229 self.inner.set_effective_trust(level);
230 }
231
232 fn is_tool_retryable(&self, tool_id: &str) -> bool {
233 self.inner.is_tool_retryable(tool_id)
234 }
235}
236
237fn params_summary(params: &serde_json::Map<String, serde_json::Value>) -> String {
238 let s = serde_json::to_string(params).unwrap_or_default();
239 if s.chars().count() > 500 {
240 let truncated: String = s.chars().take(497).collect();
241 format!("{truncated}\u{2026}")
242 } else {
243 s
244 }
245}
246
#[cfg(test)]
mod tests {
    use std::future::Future;
    use std::pin::Pin;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Duration;

    use super::*;
    use crate::adversarial_policy::{PolicyMessage, PolicyValidator};
    use crate::executor::{ToolCall, ToolOutput};

    /// Boxed future type returned by `PolicyLlmClient::chat` in the stubs below.
    type ChatFuture<'a> = Pin<Box<dyn Future<Output = Result<String, String>> + Send + 'a>>;

    /// Policy LLM stub that answers every request with a fixed response and
    /// counts how often it was asked.
    struct MockLlm {
        response: String,
        call_count: Arc<AtomicUsize>,
    }

    impl MockLlm {
        /// Returns the shared call counter together with the stub.
        fn new(response: impl Into<String>) -> (Arc<AtomicUsize>, Self) {
            let counter = Arc::new(AtomicUsize::new(0));
            let client = Self {
                response: response.into(),
                call_count: Arc::clone(&counter),
            };
            (counter, client)
        }
    }

    impl PolicyLlmClient for MockLlm {
        fn chat<'a>(&'a self, _messages: &'a [PolicyMessage]) -> ChatFuture<'a> {
            self.call_count.fetch_add(1, Ordering::SeqCst);
            let resp = self.response.clone();
            Box::pin(async move { Ok(resp) })
        }
    }

    /// Policy LLM stub whose every request fails with `"network error"`.
    struct FailingLlm;

    impl PolicyLlmClient for FailingLlm {
        fn chat<'a>(&'a self, _: &'a [PolicyMessage]) -> ChatFuture<'a> {
            Box::pin(async { Err("network error".to_owned()) })
        }
    }

    /// Builds the successful output used by the inner-executor stubs.
    fn ok_output(call: &ToolCall, claim_source: Option<crate::executor::ClaimSource>) -> ToolOutput {
        ToolOutput {
            tool_name: call.tool_id.clone(),
            summary: "ok".into(),
            blocks_executed: 1,
            filter_stats: None,
            diff: None,
            streamed: false,
            terminal_id: None,
            locations: None,
            raw_response: None,
            claim_source,
        }
    }

    /// Inner executor stub that counts structured tool calls and always
    /// succeeds with no claim source.
    #[derive(Debug)]
    struct MockInner {
        call_count: Arc<AtomicUsize>,
    }

    impl MockInner {
        /// Returns the shared call counter together with the stub.
        fn new() -> (Arc<AtomicUsize>, Self) {
            let counter = Arc::new(AtomicUsize::new(0));
            let exec = Self {
                call_count: Arc::clone(&counter),
            };
            (counter, exec)
        }
    }

    impl ToolExecutor for MockInner {
        async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }

        async fn execute_tool_call(
            &self,
            call: &ToolCall,
        ) -> Result<Option<ToolOutput>, ToolError> {
            self.call_count.fetch_add(1, Ordering::SeqCst);
            Ok(Some(ok_output(call, None)))
        }
    }

    /// Builds a parameterless tool call for `tool_id`.
    fn make_call(tool_id: &str) -> ToolCall {
        ToolCall {
            tool_id: tool_id.into(),
            params: serde_json::Map::new(),
            caller_id: None,
            context: None,

            tool_call_id: String::new(),
            skill_name: None,
        }
    }

    /// Builds a validator with one test policy and the given fail-open setting.
    fn make_validator(fail_open: bool) -> Arc<PolicyValidator> {
        Arc::new(PolicyValidator::new(
            vec!["test policy".to_owned()],
            Duration::from_millis(500),
            fail_open,
            Vec::new(),
        ))
    }

    /// Creates a file-backed audit logger inside a fresh temporary directory.
    ///
    /// The returned `TempDir` must be kept alive for as long as the log file is
    /// needed; dropping it removes the directory.
    async fn file_audit_logger() -> (
        tempfile::TempDir,
        std::path::PathBuf,
        Arc<crate::audit::AuditLogger>,
    ) {
        let dir = tempfile::TempDir::new().unwrap();
        let log_path = dir.path().join("audit.log");
        let audit_config = crate::config::AuditConfig {
            enabled: true,
            destination: crate::config::AuditDestination::File(log_path.clone()),
            ..Default::default()
        };
        let audit_logger = Arc::new(
            crate::audit::AuditLogger::from_config(&audit_config, false)
                .await
                .unwrap(),
        );
        (dir, log_path, audit_logger)
    }

    #[tokio::test]
    async fn allow_path_delegates_to_inner() {
        let (llm_count, llm) = MockLlm::new("ALLOW");
        let (inner_count, inner) = MockInner::new();
        let gate = AdversarialPolicyGateExecutor::new(inner, make_validator(false), Arc::new(llm));
        let result = gate.execute_tool_call(&make_call("shell")).await;
        assert!(result.is_ok());
        assert_eq!(
            llm_count.load(Ordering::SeqCst),
            1,
            "LLM must be called once"
        );
        assert_eq!(
            inner_count.load(Ordering::SeqCst),
            1,
            "inner executor must be called on allow"
        );
    }

    #[tokio::test]
    async fn deny_path_blocks_and_does_not_call_inner() {
        let (llm_count, llm) = MockLlm::new("DENY: unsafe command");
        let (inner_count, inner) = MockInner::new();
        let gate = AdversarialPolicyGateExecutor::new(inner, make_validator(false), Arc::new(llm));
        let result = gate.execute_tool_call(&make_call("shell")).await;
        assert!(matches!(result, Err(ToolError::Blocked { .. })));
        assert_eq!(llm_count.load(Ordering::SeqCst), 1);
        assert_eq!(
            inner_count.load(Ordering::SeqCst),
            0,
            "inner must NOT be called on deny"
        );
    }

    #[tokio::test]
    async fn error_message_is_opaque() {
        let (_, llm) = MockLlm::new("DENY: secret internal policy rule XYZ");
        let (_, inner) = MockInner::new();
        let gate = AdversarialPolicyGateExecutor::new(inner, make_validator(false), Arc::new(llm));
        let err = gate
            .execute_tool_call(&make_call("shell"))
            .await
            .unwrap_err();
        if let ToolError::Blocked { command } = err {
            assert!(
                !command.contains("secret internal policy rule XYZ"),
                "LLM denial reason must not leak to main LLM"
            );
        } else {
            panic!("expected Blocked error");
        }
    }

    #[tokio::test]
    async fn fail_closed_blocks_on_llm_error() {
        let (_, inner) = MockInner::new();
        let gate =
            AdversarialPolicyGateExecutor::new(inner, make_validator(false), Arc::new(FailingLlm));
        let result = gate.execute_tool_call(&make_call("shell")).await;
        assert!(
            matches!(result, Err(ToolError::Blocked { .. })),
            "fail-closed must block on LLM error"
        );
    }

    #[tokio::test]
    async fn fail_open_allows_on_llm_error() {
        let (inner_count, inner) = MockInner::new();
        let gate =
            AdversarialPolicyGateExecutor::new(inner, make_validator(true), Arc::new(FailingLlm));
        let result = gate.execute_tool_call(&make_call("shell")).await;
        assert!(result.is_ok(), "fail-open must allow on LLM error");
        assert_eq!(inner_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn confirmed_also_enforces_policy() {
        let (_, llm) = MockLlm::new("DENY: blocked");
        let (_, inner) = MockInner::new();
        let gate = AdversarialPolicyGateExecutor::new(inner, make_validator(false), Arc::new(llm));
        let result = gate.execute_tool_call_confirmed(&make_call("shell")).await;
        assert!(
            matches!(result, Err(ToolError::Blocked { .. })),
            "confirmed path must also enforce adversarial policy"
        );
    }

    #[tokio::test]
    async fn legacy_execute_bypasses_policy() {
        let (llm_count, llm) = MockLlm::new("DENY: anything");
        let (_, inner) = MockInner::new();
        let gate = AdversarialPolicyGateExecutor::new(inner, make_validator(false), Arc::new(llm));
        let result = gate.execute("```shell\necho hi\n```").await;
        assert!(
            result.is_ok(),
            "legacy execute must bypass adversarial policy"
        );
        assert_eq!(
            llm_count.load(Ordering::SeqCst),
            0,
            "LLM must NOT be called for legacy dispatch"
        );
    }

    #[tokio::test]
    async fn delegation_set_skill_env() {
        let (_, llm) = MockLlm::new("ALLOW");
        let (_, inner) = MockInner::new();
        let gate = AdversarialPolicyGateExecutor::new(inner, make_validator(false), Arc::new(llm));
        // Smoke test: the call must simply forward without panicking.
        gate.set_skill_env(None);
    }

    #[tokio::test]
    async fn delegation_set_effective_trust() {
        use crate::SkillTrustLevel;
        let (_, llm) = MockLlm::new("ALLOW");
        let (_, inner) = MockInner::new();
        let gate = AdversarialPolicyGateExecutor::new(inner, make_validator(false), Arc::new(llm));
        // Smoke test: the call must simply forward without panicking.
        gate.set_effective_trust(SkillTrustLevel::Trusted);
    }

    #[tokio::test]
    async fn delegation_is_tool_retryable() {
        let (_, llm) = MockLlm::new("ALLOW");
        let (_, inner) = MockInner::new();
        let gate = AdversarialPolicyGateExecutor::new(inner, make_validator(false), Arc::new(llm));
        let retryable = gate.is_tool_retryable("shell");
        assert!(!retryable, "MockInner returns false for is_tool_retryable");
    }

    #[tokio::test]
    async fn delegation_tool_definitions() {
        let (_, llm) = MockLlm::new("ALLOW");
        let (_, inner) = MockInner::new();
        let gate = AdversarialPolicyGateExecutor::new(inner, make_validator(false), Arc::new(llm));
        let defs = gate.tool_definitions();
        assert!(defs.is_empty(), "MockInner returns empty tool definitions");
    }

    #[tokio::test]
    async fn audit_entry_contains_adversarial_decision() {
        let (_dir, log_path, audit_logger) = file_audit_logger().await;

        let (_, llm) = MockLlm::new("ALLOW");
        let (_, inner) = MockInner::new();
        let gate = AdversarialPolicyGateExecutor::new(inner, make_validator(false), Arc::new(llm))
            .with_audit(Arc::clone(&audit_logger));

        gate.execute_tool_call(&make_call("shell")).await.unwrap();

        let content = tokio::fs::read_to_string(&log_path).await.unwrap();
        assert!(
            content.contains("adversarial_policy_decision"),
            "audit entry must contain adversarial_policy_decision field"
        );
        assert!(
            content.contains("\"allow\""),
            "allow decision must be recorded"
        );
    }

    #[tokio::test]
    async fn audit_entry_deny_contains_decision() {
        let (_dir, log_path, audit_logger) = file_audit_logger().await;

        let (_, llm) = MockLlm::new("DENY: test denial");
        let (_, inner) = MockInner::new();
        let gate = AdversarialPolicyGateExecutor::new(inner, make_validator(false), Arc::new(llm))
            .with_audit(Arc::clone(&audit_logger));

        let _ = gate.execute_tool_call(&make_call("shell")).await;

        let content = tokio::fs::read_to_string(&log_path).await.unwrap();
        assert!(
            content.contains("deny:"),
            "deny decision must be recorded in audit"
        );
    }

    #[tokio::test]
    async fn audit_entry_propagates_claim_source() {
        /// Inner executor stub whose output carries `ClaimSource::Shell`.
        #[derive(Debug)]
        struct InnerWithClaimSource;

        impl ToolExecutor for InnerWithClaimSource {
            async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
                Ok(None)
            }

            async fn execute_tool_call(
                &self,
                call: &ToolCall,
            ) -> Result<Option<ToolOutput>, ToolError> {
                Ok(Some(ok_output(
                    call,
                    Some(crate::executor::ClaimSource::Shell),
                )))
            }
        }

        let (_dir, log_path, audit_logger) = file_audit_logger().await;

        let (_, llm) = MockLlm::new("ALLOW");
        let gate = AdversarialPolicyGateExecutor::new(
            InnerWithClaimSource,
            make_validator(false),
            Arc::new(llm),
        )
        .with_audit(Arc::clone(&audit_logger));

        // The tool id must not be "shell": the audit entry also records the
        // tool id, so calling the "shell" tool would satisfy the `"shell"`
        // check below even if `claim_source` were dropped.
        gate.execute_tool_call(&make_call("claim_probe"))
            .await
            .unwrap();

        let content = tokio::fs::read_to_string(&log_path).await.unwrap();
        assert!(
            content.contains("allow:executed"),
            "post-execution audit entry must be written"
        );
        // Assumes `ClaimSource::Shell` is serialized as "shell", as the
        // original assertion did.
        assert!(
            content.contains("\"shell\""),
            "claim_source must be propagated into the post-execution audit entry"
        );
    }
}