1use std::sync::Arc;
18
19use crate::adversarial_policy::{PolicyDecision, PolicyLlmClient, PolicyValidator};
20use crate::audit::{AuditEntry, AuditLogger, AuditResult, chrono_now};
21use crate::executor::{ClaimSource, ToolCall, ToolError, ToolExecutor, ToolOutput};
22use crate::registry::ToolDef;
23
24pub struct AdversarialPolicyGateExecutor<T: ToolExecutor> {
30 inner: T,
31 validator: Arc<PolicyValidator>,
32 llm: Arc<dyn PolicyLlmClient>,
33 audit: Option<Arc<AuditLogger>>,
34}
35
36impl<T: ToolExecutor + std::fmt::Debug> std::fmt::Debug for AdversarialPolicyGateExecutor<T> {
37 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38 f.debug_struct("AdversarialPolicyGateExecutor")
39 .field("inner", &self.inner)
40 .finish_non_exhaustive()
41 }
42}
43
44impl<T: ToolExecutor> AdversarialPolicyGateExecutor<T> {
45 #[must_use]
47 pub fn new(inner: T, validator: Arc<PolicyValidator>, llm: Arc<dyn PolicyLlmClient>) -> Self {
48 Self {
49 inner,
50 validator,
51 llm,
52 audit: None,
53 }
54 }
55
56 #[must_use]
58 pub fn with_audit(mut self, audit: Arc<AuditLogger>) -> Self {
59 self.audit = Some(audit);
60 self
61 }
62
63 async fn check_policy(&self, call: &ToolCall) -> Result<(), ToolError> {
64 tracing::info!(
65 tool = %call.tool_id,
66 status_spinner = true,
67 "Validating tool policy\u{2026}"
68 );
69
70 let decision = self
71 .validator
72 .validate(call.tool_id.as_str(), &call.params, self.llm.as_ref())
73 .await;
74
75 match decision {
76 PolicyDecision::Allow => {
77 tracing::debug!(tool = %call.tool_id, "adversarial policy: allow");
78 self.write_audit(call, "allow", AuditResult::Success, None)
79 .await;
80 Ok(())
81 }
82 PolicyDecision::Deny { reason } => {
83 tracing::warn!(
84 tool = %call.tool_id,
85 reason = %reason,
86 "adversarial policy: deny"
87 );
88 self.write_audit(
89 call,
90 &format!("deny:{reason}"),
91 AuditResult::Blocked {
92 reason: reason.clone(),
93 },
94 None,
95 )
96 .await;
97 Err(ToolError::Blocked {
99 command: "[adversarial] Tool call denied by policy".to_owned(),
100 })
101 }
102 PolicyDecision::Error { message } => {
103 tracing::warn!(
104 tool = %call.tool_id,
105 error = %message,
106 fail_open = self.validator.fail_open(),
107 "adversarial policy: LLM error"
108 );
109 if self.validator.fail_open() {
110 self.write_audit(
111 call,
112 &format!("error:{message}"),
113 AuditResult::Success,
114 None,
115 )
116 .await;
117 Ok(())
118 } else {
119 self.write_audit(
120 call,
121 &format!("error:{message}"),
122 AuditResult::Blocked {
123 reason: "adversarial policy LLM error (fail-closed)".to_owned(),
124 },
125 None,
126 )
127 .await;
128 Err(ToolError::Blocked {
129 command: "[adversarial] Tool call denied: policy check failed".to_owned(),
130 })
131 }
132 }
133 }
134 }
135
136 async fn write_audit(
137 &self,
138 call: &ToolCall,
139 decision: &str,
140 result: AuditResult,
141 claim_source: Option<ClaimSource>,
142 ) {
143 let Some(audit) = &self.audit else { return };
144 let entry = AuditEntry {
145 timestamp: chrono_now(),
146 tool: call.tool_id.clone(),
147 command: params_summary(&call.params),
148 result,
149 duration_ms: 0,
150 error_category: None,
151 error_domain: None,
152 error_phase: None,
153 claim_source,
154 mcp_server_id: None,
155 injection_flagged: false,
156 embedding_anomalous: false,
157 cross_boundary_mcp_to_acp: false,
158 adversarial_policy_decision: Some(decision.to_owned()),
159 exit_code: None,
160 truncated: false,
161 caller_id: call.caller_id.clone(),
162 policy_match: None,
163 correlation_id: None,
164 vigil_risk: None,
165 };
166 audit.log(&entry).await;
167 }
168}
169
170impl<T: ToolExecutor> ToolExecutor for AdversarialPolicyGateExecutor<T> {
171 async fn execute(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
173 self.inner.execute(response).await
174 }
175
176 async fn execute_confirmed(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
177 self.inner.execute_confirmed(response).await
178 }
179
180 fn tool_definitions(&self) -> Vec<ToolDef> {
182 self.inner.tool_definitions()
183 }
184
185 async fn execute_tool_call(&self, call: &ToolCall) -> Result<Option<ToolOutput>, ToolError> {
186 self.check_policy(call).await?;
187 let output = self.inner.execute_tool_call(call).await?;
188 if let Some(ref out) = output {
189 self.write_audit(
190 call,
191 "allow:executed",
192 AuditResult::Success,
193 out.claim_source,
194 )
195 .await;
196 }
197 Ok(output)
198 }
199
200 async fn execute_tool_call_confirmed(
202 &self,
203 call: &ToolCall,
204 ) -> Result<Option<ToolOutput>, ToolError> {
205 self.check_policy(call).await?;
206 let output = self.inner.execute_tool_call_confirmed(call).await?;
207 if let Some(ref out) = output {
208 self.write_audit(
209 call,
210 "allow:executed",
211 AuditResult::Success,
212 out.claim_source,
213 )
214 .await;
215 }
216 Ok(output)
217 }
218
219 fn set_skill_env(&self, env: Option<std::collections::HashMap<String, String>>) {
220 self.inner.set_skill_env(env);
221 }
222
223 fn set_effective_trust(&self, level: crate::SkillTrustLevel) {
224 self.inner.set_effective_trust(level);
225 }
226
227 fn is_tool_retryable(&self, tool_id: &str) -> bool {
228 self.inner.is_tool_retryable(tool_id)
229 }
230}
231
232fn params_summary(params: &serde_json::Map<String, serde_json::Value>) -> String {
233 let s = serde_json::to_string(params).unwrap_or_default();
234 if s.chars().count() > 500 {
235 let truncated: String = s.chars().take(497).collect();
236 format!("{truncated}\u{2026}")
237 } else {
238 s
239 }
240}
241
#[cfg(test)]
mod tests {
    use std::future::Future;
    use std::path::PathBuf;
    use std::pin::Pin;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Duration;

    use tempfile::TempDir;

    use super::*;
    use crate::adversarial_policy::{PolicyMessage, PolicyValidator};
    use crate::executor::{ToolCall, ToolOutput};

    /// Policy LLM stub that always answers with a fixed response and counts
    /// how many times it was asked.
    struct MockLlm {
        response: String,
        call_count: Arc<AtomicUsize>,
    }

    impl MockLlm {
        /// Returns the shared call counter together with the stub.
        fn new(response: impl Into<String>) -> (Arc<AtomicUsize>, Self) {
            let counter = Arc::new(AtomicUsize::new(0));
            let client = Self {
                response: response.into(),
                call_count: Arc::clone(&counter),
            };
            (counter, client)
        }
    }

    impl PolicyLlmClient for MockLlm {
        fn chat<'a>(
            &'a self,
            _messages: &'a [PolicyMessage],
        ) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send + 'a>> {
            self.call_count.fetch_add(1, Ordering::SeqCst);
            let resp = self.response.clone();
            Box::pin(async move { Ok(resp) })
        }
    }

    /// Policy LLM stub whose every request fails, used to exercise the
    /// fail-open and fail-closed paths.
    struct FailingLlm;

    impl PolicyLlmClient for FailingLlm {
        fn chat<'a>(
            &'a self,
            _: &'a [PolicyMessage],
        ) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send + 'a>> {
            Box::pin(async { Err("network error".to_owned()) })
        }
    }

    /// Inner executor stub that counts structured tool calls and always
    /// succeeds with a minimal output.
    #[derive(Debug)]
    struct MockInner {
        call_count: Arc<AtomicUsize>,
    }

    impl MockInner {
        /// Returns the shared call counter together with the stub.
        fn new() -> (Arc<AtomicUsize>, Self) {
            let counter = Arc::new(AtomicUsize::new(0));
            let exec = Self {
                call_count: Arc::clone(&counter),
            };
            (counter, exec)
        }
    }

    impl ToolExecutor for MockInner {
        async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }

        async fn execute_tool_call(
            &self,
            call: &ToolCall,
        ) -> Result<Option<ToolOutput>, ToolError> {
            self.call_count.fetch_add(1, Ordering::SeqCst);
            Ok(Some(ok_output(call, None)))
        }
    }

    /// Builds the minimal successful output the stub executors return.
    fn ok_output(call: &ToolCall, claim_source: Option<ClaimSource>) -> ToolOutput {
        ToolOutput {
            tool_name: call.tool_id.clone(),
            summary: "ok".into(),
            blocks_executed: 1,
            filter_stats: None,
            diff: None,
            streamed: false,
            terminal_id: None,
            locations: None,
            raw_response: None,
            claim_source,
        }
    }

    /// Builds a parameterless tool call for `tool_id` with no caller id.
    fn make_call(tool_id: &str) -> ToolCall {
        ToolCall {
            tool_id: tool_id.into(),
            params: serde_json::Map::new(),
            caller_id: None,
        }
    }

    /// Builds a validator with a single test policy and the given fail-open
    /// setting.
    fn make_validator(fail_open: bool) -> Arc<PolicyValidator> {
        Arc::new(PolicyValidator::new(
            vec!["test policy".to_owned()],
            Duration::from_millis(500),
            fail_open,
            Vec::new(),
        ))
    }

    /// Creates an audit logger writing to `audit.log` inside a fresh temporary
    /// directory.
    ///
    /// The returned `TempDir` must be kept alive for as long as the log file
    /// is read, otherwise the directory is removed.
    async fn make_audit_logger() -> (TempDir, PathBuf, Arc<crate::audit::AuditLogger>) {
        let dir = TempDir::new().unwrap();
        let log_path = dir.path().join("audit.log");
        let audit_config = crate::config::AuditConfig {
            enabled: true,
            destination: log_path.display().to_string(),
            ..Default::default()
        };
        let audit_logger = Arc::new(
            crate::audit::AuditLogger::from_config(&audit_config, false)
                .await
                .unwrap(),
        );
        (dir, log_path, audit_logger)
    }

    #[tokio::test]
    async fn allow_path_delegates_to_inner() {
        let (llm_count, llm) = MockLlm::new("ALLOW");
        let (inner_count, inner) = MockInner::new();
        let gate = AdversarialPolicyGateExecutor::new(inner, make_validator(false), Arc::new(llm));
        let result = gate.execute_tool_call(&make_call("shell")).await;
        assert!(result.is_ok());
        assert_eq!(
            llm_count.load(Ordering::SeqCst),
            1,
            "LLM must be called once"
        );
        assert_eq!(
            inner_count.load(Ordering::SeqCst),
            1,
            "inner executor must be called on allow"
        );
    }

    #[tokio::test]
    async fn deny_path_blocks_and_does_not_call_inner() {
        let (llm_count, llm) = MockLlm::new("DENY: unsafe command");
        let (inner_count, inner) = MockInner::new();
        let gate = AdversarialPolicyGateExecutor::new(inner, make_validator(false), Arc::new(llm));
        let result = gate.execute_tool_call(&make_call("shell")).await;
        assert!(matches!(result, Err(ToolError::Blocked { .. })));
        assert_eq!(llm_count.load(Ordering::SeqCst), 1);
        assert_eq!(
            inner_count.load(Ordering::SeqCst),
            0,
            "inner must NOT be called on deny"
        );
    }

    #[tokio::test]
    async fn error_message_is_opaque() {
        let (_, llm) = MockLlm::new("DENY: secret internal policy rule XYZ");
        let (_, inner) = MockInner::new();
        let gate = AdversarialPolicyGateExecutor::new(inner, make_validator(false), Arc::new(llm));
        let err = gate
            .execute_tool_call(&make_call("shell"))
            .await
            .unwrap_err();
        let ToolError::Blocked { command } = err else {
            panic!("expected Blocked error");
        };
        assert!(
            !command.contains("secret internal policy rule XYZ"),
            "LLM denial reason must not leak to main LLM"
        );
    }

    #[tokio::test]
    async fn fail_closed_blocks_on_llm_error() {
        let (_, inner) = MockInner::new();
        let gate = AdversarialPolicyGateExecutor::new(
            inner,
            make_validator(false),
            Arc::new(FailingLlm),
        );
        let result = gate.execute_tool_call(&make_call("shell")).await;
        assert!(
            matches!(result, Err(ToolError::Blocked { .. })),
            "fail-closed must block on LLM error"
        );
    }

    #[tokio::test]
    async fn fail_open_allows_on_llm_error() {
        let (inner_count, inner) = MockInner::new();
        let gate = AdversarialPolicyGateExecutor::new(
            inner,
            make_validator(true),
            Arc::new(FailingLlm),
        );
        let result = gate.execute_tool_call(&make_call("shell")).await;
        assert!(result.is_ok(), "fail-open must allow on LLM error");
        assert_eq!(inner_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn confirmed_also_enforces_policy() {
        let (_, llm) = MockLlm::new("DENY: blocked");
        let (_, inner) = MockInner::new();
        let gate = AdversarialPolicyGateExecutor::new(inner, make_validator(false), Arc::new(llm));
        let result = gate.execute_tool_call_confirmed(&make_call("shell")).await;
        assert!(
            matches!(result, Err(ToolError::Blocked { .. })),
            "confirmed path must also enforce adversarial policy"
        );
    }

    #[tokio::test]
    async fn legacy_execute_bypasses_policy() {
        let (llm_count, llm) = MockLlm::new("DENY: anything");
        let (_, inner) = MockInner::new();
        let gate = AdversarialPolicyGateExecutor::new(inner, make_validator(false), Arc::new(llm));
        let result = gate.execute("```shell\necho hi\n```").await;
        assert!(
            result.is_ok(),
            "legacy execute must bypass adversarial policy"
        );
        assert_eq!(
            llm_count.load(Ordering::SeqCst),
            0,
            "LLM must NOT be called for legacy dispatch"
        );
    }

    #[tokio::test]
    async fn delegation_set_skill_env() {
        // Smoke test: the call must reach the inner executor without panicking.
        let (_, llm) = MockLlm::new("ALLOW");
        let (_, inner) = MockInner::new();
        let gate = AdversarialPolicyGateExecutor::new(inner, make_validator(false), Arc::new(llm));
        gate.set_skill_env(None);
    }

    #[tokio::test]
    async fn delegation_set_effective_trust() {
        use crate::SkillTrustLevel;
        // Smoke test: the call must reach the inner executor without panicking.
        let (_, llm) = MockLlm::new("ALLOW");
        let (_, inner) = MockInner::new();
        let gate = AdversarialPolicyGateExecutor::new(inner, make_validator(false), Arc::new(llm));
        gate.set_effective_trust(SkillTrustLevel::Trusted);
    }

    #[tokio::test]
    async fn delegation_is_tool_retryable() {
        let (_, llm) = MockLlm::new("ALLOW");
        let (_, inner) = MockInner::new();
        let gate = AdversarialPolicyGateExecutor::new(inner, make_validator(false), Arc::new(llm));
        let retryable = gate.is_tool_retryable("shell");
        assert!(!retryable, "MockInner returns false for is_tool_retryable");
    }

    #[tokio::test]
    async fn delegation_tool_definitions() {
        let (_, llm) = MockLlm::new("ALLOW");
        let (_, inner) = MockInner::new();
        let gate = AdversarialPolicyGateExecutor::new(inner, make_validator(false), Arc::new(llm));
        let defs = gate.tool_definitions();
        assert!(defs.is_empty(), "MockInner returns empty tool definitions");
    }

    #[tokio::test]
    async fn audit_entry_contains_adversarial_decision() {
        let (_dir, log_path, audit_logger) = make_audit_logger().await;

        let (_, llm) = MockLlm::new("ALLOW");
        let (_, inner) = MockInner::new();
        let gate = AdversarialPolicyGateExecutor::new(inner, make_validator(false), Arc::new(llm))
            .with_audit(Arc::clone(&audit_logger));

        gate.execute_tool_call(&make_call("shell")).await.unwrap();

        let content = tokio::fs::read_to_string(&log_path).await.unwrap();
        assert!(
            content.contains("adversarial_policy_decision"),
            "audit entry must contain adversarial_policy_decision field"
        );
        assert!(
            content.contains("\"allow\""),
            "allow decision must be recorded"
        );
    }

    #[tokio::test]
    async fn audit_entry_deny_contains_decision() {
        let (_dir, log_path, audit_logger) = make_audit_logger().await;

        let (_, llm) = MockLlm::new("DENY: test denial");
        let (_, inner) = MockInner::new();
        let gate = AdversarialPolicyGateExecutor::new(inner, make_validator(false), Arc::new(llm))
            .with_audit(Arc::clone(&audit_logger));

        let _ = gate.execute_tool_call(&make_call("shell")).await;

        let content = tokio::fs::read_to_string(&log_path).await.unwrap();
        assert!(
            content.contains("deny:"),
            "deny decision must be recorded in audit"
        );
    }

    #[tokio::test]
    async fn audit_entry_propagates_claim_source() {
        /// Inner executor whose output always carries a shell claim source.
        #[derive(Debug)]
        struct InnerWithClaimSource;

        impl ToolExecutor for InnerWithClaimSource {
            async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
                Ok(None)
            }

            async fn execute_tool_call(
                &self,
                call: &ToolCall,
            ) -> Result<Option<ToolOutput>, ToolError> {
                Ok(Some(ok_output(
                    call,
                    Some(crate::executor::ClaimSource::Shell),
                )))
            }
        }

        let (_dir, log_path, audit_logger) = make_audit_logger().await;

        let (_, llm) = MockLlm::new("ALLOW");
        let gate = AdversarialPolicyGateExecutor::new(
            InnerWithClaimSource,
            make_validator(false),
            Arc::new(llm),
        )
        .with_audit(Arc::clone(&audit_logger));

        // The tool id is deliberately not "shell": the audit entry also records
        // the tool name, so a "shell" tool would satisfy the assertion below
        // even if the claim source were never written.
        gate.execute_tool_call(&make_call("claim_probe"))
            .await
            .unwrap();

        let content = tokio::fs::read_to_string(&log_path).await.unwrap();
        assert!(
            content.contains("allow:executed"),
            "post-execution audit entry must be written"
        );
        assert!(
            content.contains("\"shell\""),
            "claim_source must be propagated into the post-execution audit entry"
        );
    }
}