1use std::sync::Arc;
18
19use crate::adversarial_policy::{PolicyDecision, PolicyLlmClient, PolicyValidator};
20use crate::audit::{AuditEntry, AuditLogger, AuditResult, chrono_now};
21use crate::executor::{ClaimSource, ToolCall, ToolError, ToolExecutor, ToolOutput};
22use crate::registry::ToolDef;
23
24pub struct AdversarialPolicyGateExecutor<T: ToolExecutor> {
30 inner: T,
31 validator: Arc<PolicyValidator>,
32 llm: Arc<dyn PolicyLlmClient>,
33 audit: Option<Arc<AuditLogger>>,
34}
35
36impl<T: ToolExecutor + std::fmt::Debug> std::fmt::Debug for AdversarialPolicyGateExecutor<T> {
37 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38 f.debug_struct("AdversarialPolicyGateExecutor")
39 .field("inner", &self.inner)
40 .finish_non_exhaustive()
41 }
42}
43
44impl<T: ToolExecutor> AdversarialPolicyGateExecutor<T> {
45 #[must_use]
47 pub fn new(inner: T, validator: Arc<PolicyValidator>, llm: Arc<dyn PolicyLlmClient>) -> Self {
48 Self {
49 inner,
50 validator,
51 llm,
52 audit: None,
53 }
54 }
55
56 #[must_use]
58 pub fn with_audit(mut self, audit: Arc<AuditLogger>) -> Self {
59 self.audit = Some(audit);
60 self
61 }
62
63 async fn check_policy(&self, call: &ToolCall) -> Result<(), ToolError> {
64 tracing::info!(
65 tool = %call.tool_id,
66 status_spinner = true,
67 "Validating tool policy\u{2026}"
68 );
69
70 let decision = self
71 .validator
72 .validate(&call.tool_id, &call.params, self.llm.as_ref())
73 .await;
74
75 match decision {
76 PolicyDecision::Allow => {
77 tracing::debug!(tool = %call.tool_id, "adversarial policy: allow");
78 self.write_audit(call, "allow", AuditResult::Success, None)
79 .await;
80 Ok(())
81 }
82 PolicyDecision::Deny { reason } => {
83 tracing::warn!(
84 tool = %call.tool_id,
85 reason = %reason,
86 "adversarial policy: deny"
87 );
88 self.write_audit(
89 call,
90 &format!("deny:{reason}"),
91 AuditResult::Blocked {
92 reason: reason.clone(),
93 },
94 None,
95 )
96 .await;
97 Err(ToolError::Blocked {
99 command: "[adversarial] Tool call denied by policy".to_owned(),
100 })
101 }
102 PolicyDecision::Error { message } => {
103 tracing::warn!(
104 tool = %call.tool_id,
105 error = %message,
106 fail_open = self.validator.fail_open(),
107 "adversarial policy: LLM error"
108 );
109 if self.validator.fail_open() {
110 self.write_audit(
111 call,
112 &format!("error:{message}"),
113 AuditResult::Success,
114 None,
115 )
116 .await;
117 Ok(())
118 } else {
119 self.write_audit(
120 call,
121 &format!("error:{message}"),
122 AuditResult::Blocked {
123 reason: "adversarial policy LLM error (fail-closed)".to_owned(),
124 },
125 None,
126 )
127 .await;
128 Err(ToolError::Blocked {
129 command: "[adversarial] Tool call denied: policy check failed".to_owned(),
130 })
131 }
132 }
133 }
134 }
135
136 async fn write_audit(
137 &self,
138 call: &ToolCall,
139 decision: &str,
140 result: AuditResult,
141 claim_source: Option<ClaimSource>,
142 ) {
143 let Some(audit) = &self.audit else { return };
144 let entry = AuditEntry {
145 timestamp: chrono_now(),
146 tool: call.tool_id.clone(),
147 command: params_summary(&call.params),
148 result,
149 duration_ms: 0,
150 error_category: None,
151 error_domain: None,
152 error_phase: None,
153 claim_source,
154 mcp_server_id: None,
155 injection_flagged: false,
156 embedding_anomalous: false,
157 cross_boundary_mcp_to_acp: false,
158 adversarial_policy_decision: Some(decision.to_owned()),
159 exit_code: None,
160 truncated: false,
161 caller_id: call.caller_id.clone(),
162 policy_match: None,
163 };
164 audit.log(&entry).await;
165 }
166}
167
168impl<T: ToolExecutor> ToolExecutor for AdversarialPolicyGateExecutor<T> {
169 async fn execute(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
171 self.inner.execute(response).await
172 }
173
174 async fn execute_confirmed(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
175 self.inner.execute_confirmed(response).await
176 }
177
178 fn tool_definitions(&self) -> Vec<ToolDef> {
180 self.inner.tool_definitions()
181 }
182
183 async fn execute_tool_call(&self, call: &ToolCall) -> Result<Option<ToolOutput>, ToolError> {
184 self.check_policy(call).await?;
185 let output = self.inner.execute_tool_call(call).await?;
186 if let Some(ref out) = output {
187 self.write_audit(
188 call,
189 "allow:executed",
190 AuditResult::Success,
191 out.claim_source,
192 )
193 .await;
194 }
195 Ok(output)
196 }
197
198 async fn execute_tool_call_confirmed(
200 &self,
201 call: &ToolCall,
202 ) -> Result<Option<ToolOutput>, ToolError> {
203 self.check_policy(call).await?;
204 let output = self.inner.execute_tool_call_confirmed(call).await?;
205 if let Some(ref out) = output {
206 self.write_audit(
207 call,
208 "allow:executed",
209 AuditResult::Success,
210 out.claim_source,
211 )
212 .await;
213 }
214 Ok(output)
215 }
216
217 fn set_skill_env(&self, env: Option<std::collections::HashMap<String, String>>) {
218 self.inner.set_skill_env(env);
219 }
220
221 fn set_effective_trust(&self, level: crate::SkillTrustLevel) {
222 self.inner.set_effective_trust(level);
223 }
224
225 fn is_tool_retryable(&self, tool_id: &str) -> bool {
226 self.inner.is_tool_retryable(tool_id)
227 }
228}
229
230fn params_summary(params: &serde_json::Map<String, serde_json::Value>) -> String {
231 let s = serde_json::to_string(params).unwrap_or_default();
232 if s.chars().count() > 500 {
233 let truncated: String = s.chars().take(497).collect();
234 format!("{truncated}\u{2026}")
235 } else {
236 s
237 }
238}
239
#[cfg(test)]
mod tests {
    use std::future::Future;
    use std::pin::Pin;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Duration;

    use super::*;
    use crate::adversarial_policy::{PolicyMessage, PolicyValidator};
    use crate::executor::{ToolCall, ToolOutput};

    /// Policy LLM stub that always answers with a fixed response and counts calls.
    struct MockLlm {
        response: String,
        call_count: Arc<AtomicUsize>,
    }

    impl MockLlm {
        fn new(response: impl Into<String>) -> (Arc<AtomicUsize>, Self) {
            let counter = Arc::new(AtomicUsize::new(0));
            let client = Self {
                response: response.into(),
                call_count: Arc::clone(&counter),
            };
            (counter, client)
        }
    }

    impl PolicyLlmClient for MockLlm {
        fn chat<'a>(
            &'a self,
            _messages: &'a [PolicyMessage],
        ) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send + 'a>> {
            self.call_count.fetch_add(1, Ordering::SeqCst);
            let resp = self.response.clone();
            Box::pin(async move { Ok(resp) })
        }
    }

    /// Policy LLM stub whose every request fails. Used to exercise the
    /// fail-open and fail-closed paths.
    struct FailingLlm;

    impl PolicyLlmClient for FailingLlm {
        fn chat<'a>(
            &'a self,
            _: &'a [PolicyMessage],
        ) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send + 'a>> {
            Box::pin(async { Err("network error".to_owned()) })
        }
    }

    /// Inner executor stub that counts structured calls and returns a fixed output.
    #[derive(Debug)]
    struct MockInner {
        call_count: Arc<AtomicUsize>,
    }

    impl MockInner {
        fn new() -> (Arc<AtomicUsize>, Self) {
            let counter = Arc::new(AtomicUsize::new(0));
            let exec = Self {
                call_count: Arc::clone(&counter),
            };
            (counter, exec)
        }
    }

    impl ToolExecutor for MockInner {
        async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }

        async fn execute_tool_call(
            &self,
            call: &ToolCall,
        ) -> Result<Option<ToolOutput>, ToolError> {
            self.call_count.fetch_add(1, Ordering::SeqCst);
            Ok(Some(ToolOutput {
                tool_name: call.tool_id.clone(),
                summary: "ok".into(),
                blocks_executed: 1,
                filter_stats: None,
                diff: None,
                streamed: false,
                terminal_id: None,
                locations: None,
                raw_response: None,
                claim_source: None,
            }))
        }
    }

    fn make_call(tool_id: &str) -> ToolCall {
        ToolCall {
            tool_id: tool_id.into(),
            params: serde_json::Map::new(),
            caller_id: None,
        }
    }

    fn make_validator(fail_open: bool) -> Arc<PolicyValidator> {
        Arc::new(PolicyValidator::new(
            vec!["test policy".to_owned()],
            Duration::from_millis(500),
            fail_open,
            Vec::new(),
        ))
    }

    /// Builds an enabled, file-backed audit logger writing into `dir`. Returns the
    /// log file path together with the logger.
    async fn file_audit_logger(
        dir: &tempfile::TempDir,
    ) -> (std::path::PathBuf, Arc<crate::audit::AuditLogger>) {
        let log_path = dir.path().join("audit.log");
        let audit_config = crate::config::AuditConfig {
            enabled: true,
            destination: log_path.display().to_string(),
            ..Default::default()
        };
        let audit_logger = Arc::new(
            crate::audit::AuditLogger::from_config(&audit_config)
                .await
                .unwrap(),
        );
        (log_path, audit_logger)
    }

    #[tokio::test]
    async fn allow_path_delegates_to_inner() {
        let (llm_count, llm) = MockLlm::new("ALLOW");
        let (inner_count, inner) = MockInner::new();
        let gate = AdversarialPolicyGateExecutor::new(inner, make_validator(false), Arc::new(llm));
        let result = gate.execute_tool_call(&make_call("shell")).await;
        assert!(result.is_ok());
        assert_eq!(
            llm_count.load(Ordering::SeqCst),
            1,
            "LLM must be called once"
        );
        assert_eq!(
            inner_count.load(Ordering::SeqCst),
            1,
            "inner executor must be called on allow"
        );
    }

    #[tokio::test]
    async fn deny_path_blocks_and_does_not_call_inner() {
        let (llm_count, llm) = MockLlm::new("DENY: unsafe command");
        let (inner_count, inner) = MockInner::new();
        let gate = AdversarialPolicyGateExecutor::new(inner, make_validator(false), Arc::new(llm));
        let result = gate.execute_tool_call(&make_call("shell")).await;
        assert!(matches!(result, Err(ToolError::Blocked { .. })));
        assert_eq!(llm_count.load(Ordering::SeqCst), 1);
        assert_eq!(
            inner_count.load(Ordering::SeqCst),
            0,
            "inner must NOT be called on deny"
        );
    }

    #[tokio::test]
    async fn error_message_is_opaque() {
        let (_, llm) = MockLlm::new("DENY: secret internal policy rule XYZ");
        let (_, inner) = MockInner::new();
        let gate = AdversarialPolicyGateExecutor::new(inner, make_validator(false), Arc::new(llm));
        let err = gate
            .execute_tool_call(&make_call("shell"))
            .await
            .unwrap_err();
        let ToolError::Blocked { command } = err else {
            panic!("expected Blocked error");
        };
        assert!(
            !command.contains("secret internal policy rule XYZ"),
            "LLM denial reason must not leak to main LLM"
        );
    }

    #[tokio::test]
    async fn fail_closed_blocks_on_llm_error() {
        let (inner_count, inner) = MockInner::new();
        let gate =
            AdversarialPolicyGateExecutor::new(inner, make_validator(false), Arc::new(FailingLlm));
        let result = gate.execute_tool_call(&make_call("shell")).await;
        assert!(
            matches!(result, Err(ToolError::Blocked { .. })),
            "fail-closed must block on LLM error"
        );
        assert_eq!(
            inner_count.load(Ordering::SeqCst),
            0,
            "inner must NOT be called when fail-closed blocks"
        );
    }

    #[tokio::test]
    async fn fail_open_allows_on_llm_error() {
        let (inner_count, inner) = MockInner::new();
        let gate =
            AdversarialPolicyGateExecutor::new(inner, make_validator(true), Arc::new(FailingLlm));
        let result = gate.execute_tool_call(&make_call("shell")).await;
        assert!(result.is_ok(), "fail-open must allow on LLM error");
        assert_eq!(inner_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn confirmed_also_enforces_policy() {
        let (_, llm) = MockLlm::new("DENY: blocked");
        let (inner_count, inner) = MockInner::new();
        let gate = AdversarialPolicyGateExecutor::new(inner, make_validator(false), Arc::new(llm));
        let result = gate.execute_tool_call_confirmed(&make_call("shell")).await;
        assert!(
            matches!(result, Err(ToolError::Blocked { .. })),
            "confirmed path must also enforce adversarial policy"
        );
        assert_eq!(
            inner_count.load(Ordering::SeqCst),
            0,
            "inner must NOT be called when the confirmed path is denied"
        );
    }

    #[tokio::test]
    async fn legacy_execute_bypasses_policy() {
        let (llm_count, llm) = MockLlm::new("DENY: anything");
        let (_, inner) = MockInner::new();
        let gate = AdversarialPolicyGateExecutor::new(inner, make_validator(false), Arc::new(llm));
        let result = gate.execute("```shell\necho hi\n```").await;
        assert!(
            result.is_ok(),
            "legacy execute must bypass adversarial policy"
        );
        assert_eq!(
            llm_count.load(Ordering::SeqCst),
            0,
            "LLM must NOT be called for legacy dispatch"
        );
    }

    #[tokio::test]
    async fn delegation_set_skill_env() {
        let (_, llm) = MockLlm::new("ALLOW");
        let (_, inner) = MockInner::new();
        let gate = AdversarialPolicyGateExecutor::new(inner, make_validator(false), Arc::new(llm));
        gate.set_skill_env(None);
    }

    #[tokio::test]
    async fn delegation_set_effective_trust() {
        use crate::SkillTrustLevel;
        let (_, llm) = MockLlm::new("ALLOW");
        let (_, inner) = MockInner::new();
        let gate = AdversarialPolicyGateExecutor::new(inner, make_validator(false), Arc::new(llm));
        gate.set_effective_trust(SkillTrustLevel::Trusted);
    }

    #[tokio::test]
    async fn delegation_is_tool_retryable() {
        let (_, llm) = MockLlm::new("ALLOW");
        let (_, inner) = MockInner::new();
        let gate = AdversarialPolicyGateExecutor::new(inner, make_validator(false), Arc::new(llm));
        let retryable = gate.is_tool_retryable("shell");
        assert!(!retryable, "MockInner returns false for is_tool_retryable");
    }

    #[tokio::test]
    async fn delegation_tool_definitions() {
        let (_, llm) = MockLlm::new("ALLOW");
        let (_, inner) = MockInner::new();
        let gate = AdversarialPolicyGateExecutor::new(inner, make_validator(false), Arc::new(llm));
        let defs = gate.tool_definitions();
        assert!(defs.is_empty(), "MockInner returns empty tool definitions");
    }

    #[tokio::test]
    async fn audit_entry_contains_adversarial_decision() {
        let dir = tempfile::TempDir::new().unwrap();
        let (log_path, audit_logger) = file_audit_logger(&dir).await;

        let (_, llm) = MockLlm::new("ALLOW");
        let (_, inner) = MockInner::new();
        let gate = AdversarialPolicyGateExecutor::new(inner, make_validator(false), Arc::new(llm))
            .with_audit(Arc::clone(&audit_logger));

        gate.execute_tool_call(&make_call("shell")).await.unwrap();

        let content = tokio::fs::read_to_string(&log_path).await.unwrap();
        assert!(
            content.contains("adversarial_policy_decision"),
            "audit entry must contain adversarial_policy_decision field"
        );
        assert!(
            content.contains("\"allow\""),
            "allow decision must be recorded"
        );
    }

    #[tokio::test]
    async fn audit_entry_deny_contains_decision() {
        let dir = tempfile::TempDir::new().unwrap();
        let (log_path, audit_logger) = file_audit_logger(&dir).await;

        let (_, llm) = MockLlm::new("DENY: test denial");
        let (_, inner) = MockInner::new();
        let gate = AdversarialPolicyGateExecutor::new(inner, make_validator(false), Arc::new(llm))
            .with_audit(Arc::clone(&audit_logger));

        let _ = gate.execute_tool_call(&make_call("shell")).await;

        let content = tokio::fs::read_to_string(&log_path).await.unwrap();
        assert!(
            content.contains("deny:"),
            "deny decision must be recorded in audit"
        );
    }

    #[tokio::test]
    async fn audit_entry_propagates_claim_source() {
        #[derive(Debug)]
        struct InnerWithClaimSource;

        impl ToolExecutor for InnerWithClaimSource {
            async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
                Ok(None)
            }

            async fn execute_tool_call(
                &self,
                call: &ToolCall,
            ) -> Result<Option<ToolOutput>, ToolError> {
                Ok(Some(ToolOutput {
                    tool_name: call.tool_id.clone(),
                    summary: "ok".into(),
                    blocks_executed: 1,
                    filter_stats: None,
                    diff: None,
                    streamed: false,
                    terminal_id: None,
                    locations: None,
                    raw_response: None,
                    claim_source: Some(crate::executor::ClaimSource::Shell),
                }))
            }
        }

        let dir = tempfile::TempDir::new().unwrap();
        let (log_path, audit_logger) = file_audit_logger(&dir).await;

        let (_, llm) = MockLlm::new("ALLOW");
        let gate = AdversarialPolicyGateExecutor::new(
            InnerWithClaimSource,
            make_validator(false),
            Arc::new(llm),
        )
        .with_audit(Arc::clone(&audit_logger));

        // Deliberately NOT named "shell": the tool id is written into every audit
        // entry, so using "shell" would let the claim_source assertion pass even
        // if the claim source were dropped.
        gate.execute_tool_call(&make_call("custom_tool"))
            .await
            .unwrap();

        let content = tokio::fs::read_to_string(&log_path).await.unwrap();
        assert!(
            content.contains("allow:executed"),
            "post-execution audit entry must be written"
        );
        // Nothing else in the entries (tool id, empty params, decisions) contains
        // "shell", so a match here can only come from the serialized claim source.
        assert!(
            content.to_lowercase().contains("\"shell\""),
            "claim_source must be propagated into the post-execution audit entry"
        );
    }
}