1use std::sync::Arc;
18
19use crate::adversarial_policy::{PolicyDecision, PolicyLlmClient, PolicyValidator};
20use crate::audit::{AuditEntry, AuditLogger, AuditResult, chrono_now};
21use crate::executor::{ClaimSource, ToolCall, ToolError, ToolExecutor, ToolOutput};
22use crate::registry::ToolDef;
23
24pub struct AdversarialPolicyGateExecutor<T: ToolExecutor> {
30 inner: T,
31 validator: Arc<PolicyValidator>,
32 llm: Arc<dyn PolicyLlmClient>,
33 audit: Option<Arc<AuditLogger>>,
34}
35
36impl<T: ToolExecutor + std::fmt::Debug> std::fmt::Debug for AdversarialPolicyGateExecutor<T> {
37 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38 f.debug_struct("AdversarialPolicyGateExecutor")
39 .field("inner", &self.inner)
40 .finish_non_exhaustive()
41 }
42}
43
44impl<T: ToolExecutor> AdversarialPolicyGateExecutor<T> {
45 #[must_use]
47 pub fn new(inner: T, validator: Arc<PolicyValidator>, llm: Arc<dyn PolicyLlmClient>) -> Self {
48 Self {
49 inner,
50 validator,
51 llm,
52 audit: None,
53 }
54 }
55
56 #[must_use]
58 pub fn with_audit(mut self, audit: Arc<AuditLogger>) -> Self {
59 self.audit = Some(audit);
60 self
61 }
62
63 async fn check_policy(&self, call: &ToolCall) -> Result<(), ToolError> {
64 tracing::info!(
65 tool = %call.tool_id,
66 status_spinner = true,
67 "Validating tool policy\u{2026}"
68 );
69
70 let decision = self
71 .validator
72 .validate(&call.tool_id, &call.params, self.llm.as_ref())
73 .await;
74
75 match decision {
76 PolicyDecision::Allow => {
77 tracing::debug!(tool = %call.tool_id, "adversarial policy: allow");
78 self.write_audit(call, "allow", AuditResult::Success, None)
79 .await;
80 Ok(())
81 }
82 PolicyDecision::Deny { reason } => {
83 tracing::warn!(
84 tool = %call.tool_id,
85 reason = %reason,
86 "adversarial policy: deny"
87 );
88 self.write_audit(
89 call,
90 &format!("deny:{reason}"),
91 AuditResult::Blocked {
92 reason: reason.clone(),
93 },
94 None,
95 )
96 .await;
97 Err(ToolError::Blocked {
99 command: "[adversarial] Tool call denied by policy".to_owned(),
100 })
101 }
102 PolicyDecision::Error { message } => {
103 tracing::warn!(
104 tool = %call.tool_id,
105 error = %message,
106 fail_open = self.validator.fail_open(),
107 "adversarial policy: LLM error"
108 );
109 if self.validator.fail_open() {
110 self.write_audit(
111 call,
112 &format!("error:{message}"),
113 AuditResult::Success,
114 None,
115 )
116 .await;
117 Ok(())
118 } else {
119 self.write_audit(
120 call,
121 &format!("error:{message}"),
122 AuditResult::Blocked {
123 reason: "adversarial policy LLM error (fail-closed)".to_owned(),
124 },
125 None,
126 )
127 .await;
128 Err(ToolError::Blocked {
129 command: "[adversarial] Tool call denied: policy check failed".to_owned(),
130 })
131 }
132 }
133 }
134 }
135
136 async fn write_audit(
137 &self,
138 call: &ToolCall,
139 decision: &str,
140 result: AuditResult,
141 claim_source: Option<ClaimSource>,
142 ) {
143 let Some(audit) = &self.audit else { return };
144 let entry = AuditEntry {
145 timestamp: chrono_now(),
146 tool: call.tool_id.clone(),
147 command: params_summary(&call.params),
148 result,
149 duration_ms: 0,
150 error_category: None,
151 error_domain: None,
152 error_phase: None,
153 claim_source,
154 mcp_server_id: None,
155 injection_flagged: false,
156 embedding_anomalous: false,
157 cross_boundary_mcp_to_acp: false,
158 adversarial_policy_decision: Some(decision.to_owned()),
159 exit_code: None,
160 truncated: false,
161 };
162 audit.log(&entry).await;
163 }
164}
165
166impl<T: ToolExecutor> ToolExecutor for AdversarialPolicyGateExecutor<T> {
167 async fn execute(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
169 self.inner.execute(response).await
170 }
171
172 async fn execute_confirmed(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
173 self.inner.execute_confirmed(response).await
174 }
175
176 fn tool_definitions(&self) -> Vec<ToolDef> {
178 self.inner.tool_definitions()
179 }
180
181 async fn execute_tool_call(&self, call: &ToolCall) -> Result<Option<ToolOutput>, ToolError> {
182 self.check_policy(call).await?;
183 let output = self.inner.execute_tool_call(call).await?;
184 if let Some(ref out) = output {
185 self.write_audit(
186 call,
187 "allow:executed",
188 AuditResult::Success,
189 out.claim_source,
190 )
191 .await;
192 }
193 Ok(output)
194 }
195
196 async fn execute_tool_call_confirmed(
198 &self,
199 call: &ToolCall,
200 ) -> Result<Option<ToolOutput>, ToolError> {
201 self.check_policy(call).await?;
202 let output = self.inner.execute_tool_call_confirmed(call).await?;
203 if let Some(ref out) = output {
204 self.write_audit(
205 call,
206 "allow:executed",
207 AuditResult::Success,
208 out.claim_source,
209 )
210 .await;
211 }
212 Ok(output)
213 }
214
215 fn set_skill_env(&self, env: Option<std::collections::HashMap<String, String>>) {
216 self.inner.set_skill_env(env);
217 }
218
219 fn set_effective_trust(&self, level: crate::SkillTrustLevel) {
220 self.inner.set_effective_trust(level);
221 }
222
223 fn is_tool_retryable(&self, tool_id: &str) -> bool {
224 self.inner.is_tool_retryable(tool_id)
225 }
226}
227
228fn params_summary(params: &serde_json::Map<String, serde_json::Value>) -> String {
229 let s = serde_json::to_string(params).unwrap_or_default();
230 if s.chars().count() > 500 {
231 let truncated: String = s.chars().take(497).collect();
232 format!("{truncated}\u{2026}")
233 } else {
234 s
235 }
236}
237
#[cfg(test)]
mod tests {
    use std::future::Future;
    use std::path::PathBuf;
    use std::pin::Pin;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Duration;

    use tempfile::TempDir;

    use super::*;
    use crate::adversarial_policy::{PolicyMessage, PolicyValidator};
    use crate::executor::{ToolCall, ToolOutput};

    /// Policy LLM stub that always answers with a fixed response and counts
    /// how many times it was asked.
    struct MockLlm {
        response: String,
        call_count: Arc<AtomicUsize>,
    }

    impl MockLlm {
        /// Returns the shared call counter together with the stub.
        fn new(response: impl Into<String>) -> (Arc<AtomicUsize>, Self) {
            let counter = Arc::new(AtomicUsize::new(0));
            let client = Self {
                response: response.into(),
                call_count: Arc::clone(&counter),
            };
            (counter, client)
        }
    }

    impl PolicyLlmClient for MockLlm {
        fn chat<'a>(
            &'a self,
            _messages: &'a [PolicyMessage],
        ) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send + 'a>> {
            self.call_count.fetch_add(1, Ordering::SeqCst);
            let resp = self.response.clone();
            Box::pin(async move { Ok(resp) })
        }
    }

    /// Policy LLM stub whose every request fails.
    struct FailingLlm;

    impl PolicyLlmClient for FailingLlm {
        fn chat<'a>(
            &'a self,
            _: &'a [PolicyMessage],
        ) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send + 'a>> {
            Box::pin(async { Err("network error".to_owned()) })
        }
    }

    /// Inner executor stub that counts structured tool calls and reports no
    /// claim source.
    #[derive(Debug)]
    struct MockInner {
        call_count: Arc<AtomicUsize>,
    }

    impl MockInner {
        /// Returns the shared call counter together with the stub.
        fn new() -> (Arc<AtomicUsize>, Self) {
            let counter = Arc::new(AtomicUsize::new(0));
            let exec = Self {
                call_count: Arc::clone(&counter),
            };
            (counter, exec)
        }
    }

    impl ToolExecutor for MockInner {
        async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }

        async fn execute_tool_call(
            &self,
            call: &ToolCall,
        ) -> Result<Option<ToolOutput>, ToolError> {
            self.call_count.fetch_add(1, Ordering::SeqCst);
            Ok(Some(ok_output(call, None)))
        }
    }

    /// Inner executor stub whose output always claims `ClaimSource::Shell`.
    #[derive(Debug)]
    struct InnerWithClaimSource;

    impl ToolExecutor for InnerWithClaimSource {
        async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }

        async fn execute_tool_call(
            &self,
            call: &ToolCall,
        ) -> Result<Option<ToolOutput>, ToolError> {
            Ok(Some(ok_output(call, Some(ClaimSource::Shell))))
        }
    }

    /// Builds the minimal successful output the stub executors return.
    fn ok_output(call: &ToolCall, claim_source: Option<ClaimSource>) -> ToolOutput {
        ToolOutput {
            tool_name: call.tool_id.clone(),
            summary: "ok".into(),
            blocks_executed: 1,
            filter_stats: None,
            diff: None,
            streamed: false,
            terminal_id: None,
            locations: None,
            raw_response: None,
            claim_source,
        }
    }

    /// Builds a parameterless call to `tool_id`.
    fn make_call(tool_id: &str) -> ToolCall {
        ToolCall {
            tool_id: tool_id.into(),
            params: serde_json::Map::new(),
        }
    }

    fn make_validator(fail_open: bool) -> Arc<PolicyValidator> {
        Arc::new(PolicyValidator::new(
            vec!["test policy".to_owned()],
            Duration::from_millis(500),
            fail_open,
            Vec::new(),
        ))
    }

    /// Builds a gate around `inner` with the given fail-open mode and LLM stub.
    fn make_gate<I>(
        inner: I,
        fail_open: bool,
        llm: impl PolicyLlmClient + 'static,
    ) -> AdversarialPolicyGateExecutor<I>
    where
        I: ToolExecutor,
    {
        AdversarialPolicyGateExecutor::new(inner, make_validator(fail_open), Arc::new(llm))
    }

    /// Creates an enabled audit logger writing to `audit.log` inside `dir` and
    /// returns the log path with it. The caller keeps `dir` alive so the file
    /// survives until it has been read back.
    async fn make_audit_logger(dir: &TempDir) -> (PathBuf, Arc<AuditLogger>) {
        let log_path = dir.path().join("audit.log");
        let audit_config = crate::config::AuditConfig {
            enabled: true,
            destination: log_path.display().to_string(),
            ..Default::default()
        };
        let logger = crate::audit::AuditLogger::from_config(&audit_config)
            .await
            .expect("audit logger must initialize in a fresh temp dir");
        (log_path, Arc::new(logger))
    }

    #[tokio::test]
    async fn allow_path_delegates_to_inner() {
        let (llm_count, llm) = MockLlm::new("ALLOW");
        let (inner_count, inner) = MockInner::new();
        let gate = make_gate(inner, false, llm);
        let result = gate.execute_tool_call(&make_call("shell")).await;
        assert!(result.is_ok());
        assert_eq!(
            llm_count.load(Ordering::SeqCst),
            1,
            "LLM must be called once"
        );
        assert_eq!(
            inner_count.load(Ordering::SeqCst),
            1,
            "inner executor must be called on allow"
        );
    }

    #[tokio::test]
    async fn deny_path_blocks_and_does_not_call_inner() {
        let (llm_count, llm) = MockLlm::new("DENY: unsafe command");
        let (inner_count, inner) = MockInner::new();
        let gate = make_gate(inner, false, llm);
        let result = gate.execute_tool_call(&make_call("shell")).await;
        assert!(matches!(result, Err(ToolError::Blocked { .. })));
        assert_eq!(llm_count.load(Ordering::SeqCst), 1);
        assert_eq!(
            inner_count.load(Ordering::SeqCst),
            0,
            "inner must NOT be called on deny"
        );
    }

    #[tokio::test]
    async fn error_message_is_opaque() {
        let (_, llm) = MockLlm::new("DENY: secret internal policy rule XYZ");
        let (_, inner) = MockInner::new();
        let gate = make_gate(inner, false, llm);
        let err = gate
            .execute_tool_call(&make_call("shell"))
            .await
            .unwrap_err();
        let ToolError::Blocked { command } = err else {
            panic!("expected Blocked error");
        };
        assert!(
            !command.contains("secret internal policy rule XYZ"),
            "LLM denial reason must not leak to main LLM"
        );
    }

    #[tokio::test]
    async fn fail_closed_blocks_on_llm_error() {
        let (inner_count, inner) = MockInner::new();
        let gate = make_gate(inner, false, FailingLlm);
        let result = gate.execute_tool_call(&make_call("shell")).await;
        assert!(
            matches!(result, Err(ToolError::Blocked { .. })),
            "fail-closed must block on LLM error"
        );
        assert_eq!(
            inner_count.load(Ordering::SeqCst),
            0,
            "inner must NOT be called when fail-closed blocks"
        );
    }

    #[tokio::test]
    async fn fail_open_allows_on_llm_error() {
        let (inner_count, inner) = MockInner::new();
        let gate = make_gate(inner, true, FailingLlm);
        let result = gate.execute_tool_call(&make_call("shell")).await;
        assert!(result.is_ok(), "fail-open must allow on LLM error");
        assert_eq!(inner_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn confirmed_also_enforces_policy() {
        let (_, llm) = MockLlm::new("DENY: blocked");
        let (_, inner) = MockInner::new();
        let gate = make_gate(inner, false, llm);
        let result = gate.execute_tool_call_confirmed(&make_call("shell")).await;
        assert!(
            matches!(result, Err(ToolError::Blocked { .. })),
            "confirmed path must also enforce adversarial policy"
        );
    }

    #[tokio::test]
    async fn legacy_execute_bypasses_policy() {
        let (llm_count, llm) = MockLlm::new("DENY: anything");
        let (_, inner) = MockInner::new();
        let gate = make_gate(inner, false, llm);
        let result = gate.execute("```shell\necho hi\n```").await;
        assert!(
            result.is_ok(),
            "legacy execute must bypass adversarial policy"
        );
        assert_eq!(
            llm_count.load(Ordering::SeqCst),
            0,
            "LLM must NOT be called for legacy dispatch"
        );
    }

    // The delegation tests below only check that the call reaches the inner
    // executor without panicking (and, where there is a return value, that
    // the inner executor's answer is passed through).

    #[tokio::test]
    async fn delegation_set_skill_env() {
        let (_, llm) = MockLlm::new("ALLOW");
        let (_, inner) = MockInner::new();
        let gate = make_gate(inner, false, llm);
        gate.set_skill_env(None);
    }

    #[tokio::test]
    async fn delegation_set_effective_trust() {
        use crate::SkillTrustLevel;
        let (_, llm) = MockLlm::new("ALLOW");
        let (_, inner) = MockInner::new();
        let gate = make_gate(inner, false, llm);
        gate.set_effective_trust(SkillTrustLevel::Trusted);
    }

    #[tokio::test]
    async fn delegation_is_tool_retryable() {
        let (_, llm) = MockLlm::new("ALLOW");
        let (_, inner) = MockInner::new();
        let gate = make_gate(inner, false, llm);
        let retryable = gate.is_tool_retryable("shell");
        assert!(!retryable, "MockInner returns false for is_tool_retryable");
    }

    #[tokio::test]
    async fn delegation_tool_definitions() {
        let (_, llm) = MockLlm::new("ALLOW");
        let (_, inner) = MockInner::new();
        let gate = make_gate(inner, false, llm);
        let defs = gate.tool_definitions();
        assert!(defs.is_empty(), "MockInner returns empty tool definitions");
    }

    #[tokio::test]
    async fn audit_entry_contains_adversarial_decision() {
        let dir = TempDir::new().unwrap();
        let (log_path, audit_logger) = make_audit_logger(&dir).await;

        let (_, llm) = MockLlm::new("ALLOW");
        let (_, inner) = MockInner::new();
        let gate = make_gate(inner, false, llm).with_audit(Arc::clone(&audit_logger));

        gate.execute_tool_call(&make_call("shell")).await.unwrap();

        let content = tokio::fs::read_to_string(&log_path).await.unwrap();
        assert!(
            content.contains("adversarial_policy_decision"),
            "audit entry must contain adversarial_policy_decision field"
        );
        assert!(
            content.contains("\"allow\""),
            "allow decision must be recorded"
        );
    }

    #[tokio::test]
    async fn audit_entry_deny_contains_decision() {
        let dir = TempDir::new().unwrap();
        let (log_path, audit_logger) = make_audit_logger(&dir).await;

        let (_, llm) = MockLlm::new("DENY: test denial");
        let (_, inner) = MockInner::new();
        let gate = make_gate(inner, false, llm).with_audit(Arc::clone(&audit_logger));

        let _ = gate.execute_tool_call(&make_call("shell")).await;

        let content = tokio::fs::read_to_string(&log_path).await.unwrap();
        assert!(
            content.contains("deny:"),
            "deny decision must be recorded in audit"
        );
    }

    #[tokio::test]
    async fn audit_entry_propagates_claim_source() {
        let dir = TempDir::new().unwrap();
        let (log_path, audit_logger) = make_audit_logger(&dir).await;

        let (_, llm) = MockLlm::new("ALLOW");
        let gate = make_gate(InnerWithClaimSource, false, llm)
            .with_audit(Arc::clone(&audit_logger));

        // The tool id must not be "shell": the audit entry also records the
        // tool name, so a "shell" tool would satisfy the assertion below even
        // if the claim source were never propagated.
        gate.execute_tool_call(&make_call("probe_tool"))
            .await
            .unwrap();

        let content = tokio::fs::read_to_string(&log_path).await.unwrap();
        assert!(
            content.contains("\"shell\""),
            "claim_source must be propagated into the post-execution audit entry"
        );
    }
}