Skip to main content

zeph_tools/
adversarial_gate.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! `AdversarialPolicyGateExecutor`: wraps an inner `ToolExecutor` and runs an LLM-based
5//! policy check before delegating any structured tool call.
6//!
7//! Wiring order (outermost first):
8//!   `PolicyGateExecutor` → `AdversarialPolicyGateExecutor` → `TrustGateExecutor` → ...
9//!
10//! Per CRIT-04 recommendation: declarative `PolicyGateExecutor` is outermost.
11//! Adversarial gate only fires for calls that pass declarative policy — no duplication.
12//!
13//! Per CRIT-06: ALL `ToolExecutor` trait methods are delegated to `self.inner`.
14//! Per CRIT-01: fail behavior (allow/deny on LLM error) is controlled by `fail_open` config.
15//! Per CRIT-11: params are sanitized and wrapped in code fences before LLM call.
16
17use std::sync::Arc;
18
19use crate::adversarial_policy::{PolicyDecision, PolicyLlmClient, PolicyValidator};
20use crate::audit::{AuditEntry, AuditLogger, AuditResult, chrono_now};
21use crate::executor::{ClaimSource, ToolCall, ToolError, ToolExecutor, ToolOutput};
22use crate::registry::ToolDef;
23
24/// Wraps an inner `ToolExecutor`, running an LLM-based adversarial policy check
25/// before delegating structured tool calls.
26///
27/// Only `execute_tool_call` and `execute_tool_call_confirmed` are intercepted.
28/// Legacy `execute` / `execute_confirmed` bypass the check (no structured `tool_id`).
29pub struct AdversarialPolicyGateExecutor<T: ToolExecutor> {
30    inner: T,
31    validator: Arc<PolicyValidator>,
32    llm: Arc<dyn PolicyLlmClient>,
33    audit: Option<Arc<AuditLogger>>,
34}
35
36impl<T: ToolExecutor + std::fmt::Debug> std::fmt::Debug for AdversarialPolicyGateExecutor<T> {
37    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38        f.debug_struct("AdversarialPolicyGateExecutor")
39            .field("inner", &self.inner)
40            .finish_non_exhaustive()
41    }
42}
43
44impl<T: ToolExecutor> AdversarialPolicyGateExecutor<T> {
45    /// Create a new `AdversarialPolicyGateExecutor`.
46    #[must_use]
47    pub fn new(inner: T, validator: Arc<PolicyValidator>, llm: Arc<dyn PolicyLlmClient>) -> Self {
48        Self {
49            inner,
50            validator,
51            llm,
52            audit: None,
53        }
54    }
55
56    /// Attach an audit logger.
57    #[must_use]
58    pub fn with_audit(mut self, audit: Arc<AuditLogger>) -> Self {
59        self.audit = Some(audit);
60        self
61    }
62
63    async fn check_policy(&self, call: &ToolCall) -> Result<(), ToolError> {
64        tracing::info!(
65            tool = %call.tool_id,
66            status_spinner = true,
67            "Validating tool policy\u{2026}"
68        );
69
70        let decision = self
71            .validator
72            .validate(&call.tool_id, &call.params, self.llm.as_ref())
73            .await;
74
75        match decision {
76            PolicyDecision::Allow => {
77                tracing::debug!(tool = %call.tool_id, "adversarial policy: allow");
78                self.write_audit(call, "allow", AuditResult::Success, None)
79                    .await;
80                Ok(())
81            }
82            PolicyDecision::Deny { reason } => {
83                tracing::warn!(
84                    tool = %call.tool_id,
85                    reason = %reason,
86                    "adversarial policy: deny"
87                );
88                self.write_audit(
89                    call,
90                    &format!("deny:{reason}"),
91                    AuditResult::Blocked {
92                        reason: reason.clone(),
93                    },
94                    None,
95                )
96                .await;
97                // MED-03: do NOT surface the LLM reason to the main LLM.
98                Err(ToolError::Blocked {
99                    command: "[adversarial] Tool call denied by policy".to_owned(),
100                })
101            }
102            PolicyDecision::Error { message } => {
103                tracing::warn!(
104                    tool = %call.tool_id,
105                    error = %message,
106                    fail_open = self.validator.fail_open(),
107                    "adversarial policy: LLM error"
108                );
109                if self.validator.fail_open() {
110                    self.write_audit(
111                        call,
112                        &format!("error:{message}"),
113                        AuditResult::Success,
114                        None,
115                    )
116                    .await;
117                    Ok(())
118                } else {
119                    self.write_audit(
120                        call,
121                        &format!("error:{message}"),
122                        AuditResult::Blocked {
123                            reason: "adversarial policy LLM error (fail-closed)".to_owned(),
124                        },
125                        None,
126                    )
127                    .await;
128                    Err(ToolError::Blocked {
129                        command: "[adversarial] Tool call denied: policy check failed".to_owned(),
130                    })
131                }
132            }
133        }
134    }
135
136    async fn write_audit(
137        &self,
138        call: &ToolCall,
139        decision: &str,
140        result: AuditResult,
141        claim_source: Option<ClaimSource>,
142    ) {
143        let Some(audit) = &self.audit else { return };
144        let entry = AuditEntry {
145            timestamp: chrono_now(),
146            tool: call.tool_id.clone(),
147            command: params_summary(&call.params),
148            result,
149            duration_ms: 0,
150            error_category: None,
151            error_domain: None,
152            error_phase: None,
153            claim_source,
154            mcp_server_id: None,
155            injection_flagged: false,
156            embedding_anomalous: false,
157            cross_boundary_mcp_to_acp: false,
158            adversarial_policy_decision: Some(decision.to_owned()),
159            exit_code: None,
160            truncated: false,
161            caller_id: call.caller_id.clone(),
162            policy_match: None,
163        };
164        audit.log(&entry).await;
165    }
166}
167
168impl<T: ToolExecutor> ToolExecutor for AdversarialPolicyGateExecutor<T> {
169    // Legacy dispatch bypasses adversarial check — no structured tool_id available.
170    async fn execute(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
171        self.inner.execute(response).await
172    }
173
174    async fn execute_confirmed(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
175        self.inner.execute_confirmed(response).await
176    }
177
178    // CRIT-06: delegate all pass-through methods to inner executor.
179    fn tool_definitions(&self) -> Vec<ToolDef> {
180        self.inner.tool_definitions()
181    }
182
183    async fn execute_tool_call(&self, call: &ToolCall) -> Result<Option<ToolOutput>, ToolError> {
184        self.check_policy(call).await?;
185        let output = self.inner.execute_tool_call(call).await?;
186        if let Some(ref out) = output {
187            self.write_audit(
188                call,
189                "allow:executed",
190                AuditResult::Success,
191                out.claim_source,
192            )
193            .await;
194        }
195        Ok(output)
196    }
197
198    // MED-04: policy also enforced on confirmed calls.
199    async fn execute_tool_call_confirmed(
200        &self,
201        call: &ToolCall,
202    ) -> Result<Option<ToolOutput>, ToolError> {
203        self.check_policy(call).await?;
204        let output = self.inner.execute_tool_call_confirmed(call).await?;
205        if let Some(ref out) = output {
206            self.write_audit(
207                call,
208                "allow:executed",
209                AuditResult::Success,
210                out.claim_source,
211            )
212            .await;
213        }
214        Ok(output)
215    }
216
217    fn set_skill_env(&self, env: Option<std::collections::HashMap<String, String>>) {
218        self.inner.set_skill_env(env);
219    }
220
221    fn set_effective_trust(&self, level: crate::SkillTrustLevel) {
222        self.inner.set_effective_trust(level);
223    }
224
225    fn is_tool_retryable(&self, tool_id: &str) -> bool {
226        self.inner.is_tool_retryable(tool_id)
227    }
228}
229
230fn params_summary(params: &serde_json::Map<String, serde_json::Value>) -> String {
231    let s = serde_json::to_string(params).unwrap_or_default();
232    if s.chars().count() > 500 {
233        let truncated: String = s.chars().take(497).collect();
234        format!("{truncated}\u{2026}")
235    } else {
236        s
237    }
238}
239
#[cfg(test)]
mod tests {
    use std::future::Future;
    use std::path::PathBuf;
    use std::pin::Pin;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Duration;

    use super::*;
    use crate::adversarial_policy::{PolicyMessage, PolicyValidator};
    use crate::executor::{ToolCall, ToolOutput};

    // --- Mock LLM clients ---

    /// Policy LLM stub that always replies with `response` and counts invocations.
    struct MockLlm {
        response: String,
        call_count: Arc<AtomicUsize>,
    }

    impl MockLlm {
        /// Returns the shared call counter together with the client.
        fn new(response: impl Into<String>) -> (Arc<AtomicUsize>, Self) {
            let counter = Arc::new(AtomicUsize::new(0));
            let client = Self {
                response: response.into(),
                call_count: Arc::clone(&counter),
            };
            (counter, client)
        }
    }

    impl PolicyLlmClient for MockLlm {
        fn chat<'a>(
            &'a self,
            _messages: &'a [PolicyMessage],
        ) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send + 'a>> {
            self.call_count.fetch_add(1, Ordering::SeqCst);
            let resp = self.response.clone();
            Box::pin(async move { Ok(resp) })
        }
    }

    /// Policy LLM stub whose every call fails, exercising the validator's error path.
    struct FailingLlm;

    impl PolicyLlmClient for FailingLlm {
        fn chat<'a>(
            &'a self,
            _: &'a [PolicyMessage],
        ) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send + 'a>> {
            Box::pin(async { Err("network error".to_owned()) })
        }
    }

    // --- Mock inner executors ---

    /// Minimal successful output for `call`, tagged with `claim_source`.
    fn ok_output(call: &ToolCall, claim_source: Option<ClaimSource>) -> ToolOutput {
        ToolOutput {
            tool_name: call.tool_id.clone(),
            summary: "ok".into(),
            blocks_executed: 1,
            filter_stats: None,
            diff: None,
            streamed: false,
            terminal_id: None,
            locations: None,
            raw_response: None,
            claim_source,
        }
    }

    /// Inner executor that counts structured calls and always succeeds.
    #[derive(Debug)]
    struct MockInner {
        call_count: Arc<AtomicUsize>,
    }

    impl MockInner {
        /// Returns the shared call counter together with the executor.
        fn new() -> (Arc<AtomicUsize>, Self) {
            let counter = Arc::new(AtomicUsize::new(0));
            let exec = Self {
                call_count: Arc::clone(&counter),
            };
            (counter, exec)
        }
    }

    impl ToolExecutor for MockInner {
        async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }

        async fn execute_tool_call(
            &self,
            call: &ToolCall,
        ) -> Result<Option<ToolOutput>, ToolError> {
            self.call_count.fetch_add(1, Ordering::SeqCst);
            Ok(Some(ok_output(call, None)))
        }
    }

    // --- Fixtures ---

    fn make_call(tool_id: &str) -> ToolCall {
        ToolCall {
            tool_id: tool_id.into(),
            params: serde_json::Map::new(),
            caller_id: None,
        }
    }

    fn make_validator(fail_open: bool) -> Arc<PolicyValidator> {
        Arc::new(PolicyValidator::new(
            vec!["test policy".to_owned()],
            Duration::from_millis(500),
            fail_open,
            Vec::new(),
        ))
    }

    /// File-backed audit logger writing to `audit.log` inside `dir`.
    ///
    /// Returns the log path (to read entries back) and the shared logger.
    async fn file_audit_logger(dir: &tempfile::TempDir) -> (PathBuf, Arc<AuditLogger>) {
        let log_path = dir.path().join("audit.log");
        let audit_config = crate::config::AuditConfig {
            enabled: true,
            destination: log_path.display().to_string(),
            ..Default::default()
        };
        let logger = crate::audit::AuditLogger::from_config(&audit_config)
            .await
            .unwrap();
        (log_path, Arc::new(logger))
    }

    // --- Policy decisions ---

    #[tokio::test]
    async fn allow_path_delegates_to_inner() {
        let (llm_count, llm) = MockLlm::new("ALLOW");
        let (inner_count, inner) = MockInner::new();
        let gate = AdversarialPolicyGateExecutor::new(inner, make_validator(false), Arc::new(llm));
        let result = gate.execute_tool_call(&make_call("shell")).await;
        assert!(result.is_ok());
        assert_eq!(
            llm_count.load(Ordering::SeqCst),
            1,
            "LLM must be called once"
        );
        assert_eq!(
            inner_count.load(Ordering::SeqCst),
            1,
            "inner executor must be called on allow"
        );
    }

    #[tokio::test]
    async fn deny_path_blocks_and_does_not_call_inner() {
        let (llm_count, llm) = MockLlm::new("DENY: unsafe command");
        let (inner_count, inner) = MockInner::new();
        let gate = AdversarialPolicyGateExecutor::new(inner, make_validator(false), Arc::new(llm));
        let result = gate.execute_tool_call(&make_call("shell")).await;
        assert!(matches!(result, Err(ToolError::Blocked { .. })));
        assert_eq!(llm_count.load(Ordering::SeqCst), 1);
        assert_eq!(
            inner_count.load(Ordering::SeqCst),
            0,
            "inner must NOT be called on deny"
        );
    }

    #[tokio::test]
    async fn error_message_is_opaque() {
        // MED-03: error returned to main LLM must not contain the LLM denial reason.
        let (_, llm) = MockLlm::new("DENY: secret internal policy rule XYZ");
        let (_, inner) = MockInner::new();
        let gate = AdversarialPolicyGateExecutor::new(inner, make_validator(false), Arc::new(llm));
        let err = gate
            .execute_tool_call(&make_call("shell"))
            .await
            .unwrap_err();
        let ToolError::Blocked { command } = err else {
            panic!("expected Blocked error");
        };
        assert!(
            !command.contains("secret internal policy rule XYZ"),
            "LLM denial reason must not leak to main LLM"
        );
    }

    #[tokio::test]
    async fn fail_closed_blocks_on_llm_error() {
        let (inner_count, inner) = MockInner::new();
        let gate = AdversarialPolicyGateExecutor::new(
            inner,
            make_validator(false), // fail_open = false
            Arc::new(FailingLlm),
        );
        let result = gate.execute_tool_call(&make_call("shell")).await;
        assert!(
            matches!(result, Err(ToolError::Blocked { .. })),
            "fail-closed must block on LLM error"
        );
        assert_eq!(
            inner_count.load(Ordering::SeqCst),
            0,
            "inner must NOT be called when a fail-closed check errors"
        );
    }

    #[tokio::test]
    async fn fail_open_allows_on_llm_error() {
        let (inner_count, inner) = MockInner::new();
        let gate = AdversarialPolicyGateExecutor::new(
            inner,
            make_validator(true), // fail_open = true
            Arc::new(FailingLlm),
        );
        let result = gate.execute_tool_call(&make_call("shell")).await;
        assert!(result.is_ok(), "fail-open must allow on LLM error");
        assert_eq!(inner_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn confirmed_also_enforces_policy() {
        let (_, llm) = MockLlm::new("DENY: blocked");
        let (inner_count, inner) = MockInner::new();
        let gate = AdversarialPolicyGateExecutor::new(inner, make_validator(false), Arc::new(llm));
        let result = gate.execute_tool_call_confirmed(&make_call("shell")).await;
        assert!(
            matches!(result, Err(ToolError::Blocked { .. })),
            "confirmed path must also enforce adversarial policy"
        );
        assert_eq!(
            inner_count.load(Ordering::SeqCst),
            0,
            "inner must NOT be called when a confirmed call is denied"
        );
    }

    #[tokio::test]
    async fn legacy_execute_bypasses_policy() {
        let (llm_count, llm) = MockLlm::new("DENY: anything");
        let (_, inner) = MockInner::new();
        let gate = AdversarialPolicyGateExecutor::new(inner, make_validator(false), Arc::new(llm));
        let result = gate.execute("```shell\necho hi\n```").await;
        assert!(
            result.is_ok(),
            "legacy execute must bypass adversarial policy"
        );
        assert_eq!(
            llm_count.load(Ordering::SeqCst),
            0,
            "LLM must NOT be called for legacy dispatch"
        );
    }

    // --- CRIT-06 delegation ---

    #[tokio::test]
    async fn delegation_set_skill_env() {
        // Verify that set_skill_env reaches the inner executor without panic.
        let (_, llm) = MockLlm::new("ALLOW");
        let (_, inner) = MockInner::new();
        let gate = AdversarialPolicyGateExecutor::new(inner, make_validator(false), Arc::new(llm));
        gate.set_skill_env(None);
    }

    #[tokio::test]
    async fn delegation_set_effective_trust() {
        use crate::SkillTrustLevel;
        let (_, llm) = MockLlm::new("ALLOW");
        let (_, inner) = MockInner::new();
        let gate = AdversarialPolicyGateExecutor::new(inner, make_validator(false), Arc::new(llm));
        gate.set_effective_trust(SkillTrustLevel::Trusted);
    }

    #[tokio::test]
    async fn delegation_is_tool_retryable() {
        let (_, llm) = MockLlm::new("ALLOW");
        let (_, inner) = MockInner::new();
        let gate = AdversarialPolicyGateExecutor::new(inner, make_validator(false), Arc::new(llm));
        let retryable = gate.is_tool_retryable("shell");
        assert!(!retryable, "MockInner returns false for is_tool_retryable");
    }

    #[tokio::test]
    async fn delegation_tool_definitions() {
        let (_, llm) = MockLlm::new("ALLOW");
        let (_, inner) = MockInner::new();
        let gate = AdversarialPolicyGateExecutor::new(inner, make_validator(false), Arc::new(llm));
        let defs = gate.tool_definitions();
        assert!(defs.is_empty(), "MockInner returns empty tool definitions");
    }

    // --- Audit trail ---

    #[tokio::test]
    async fn audit_entry_contains_adversarial_decision() {
        let dir = tempfile::TempDir::new().unwrap();
        let (log_path, audit_logger) = file_audit_logger(&dir).await;

        let (_, llm) = MockLlm::new("ALLOW");
        let (_, inner) = MockInner::new();
        let gate = AdversarialPolicyGateExecutor::new(inner, make_validator(false), Arc::new(llm))
            .with_audit(Arc::clone(&audit_logger));

        gate.execute_tool_call(&make_call("shell")).await.unwrap();

        let content = tokio::fs::read_to_string(&log_path).await.unwrap();
        assert!(
            content.contains("adversarial_policy_decision"),
            "audit entry must contain adversarial_policy_decision field"
        );
        assert!(
            content.contains("\"allow\""),
            "allow decision must be recorded"
        );
        assert!(
            content.contains("\"allow:executed\""),
            "successful execution must be recorded after the allow decision"
        );
    }

    #[tokio::test]
    async fn audit_entry_deny_contains_decision() {
        let dir = tempfile::TempDir::new().unwrap();
        let (log_path, audit_logger) = file_audit_logger(&dir).await;

        let (_, llm) = MockLlm::new("DENY: test denial");
        let (_, inner) = MockInner::new();
        let gate = AdversarialPolicyGateExecutor::new(inner, make_validator(false), Arc::new(llm))
            .with_audit(Arc::clone(&audit_logger));

        let _ = gate.execute_tool_call(&make_call("shell")).await;

        let content = tokio::fs::read_to_string(&log_path).await.unwrap();
        assert!(
            content.contains("deny:"),
            "deny decision must be recorded in audit"
        );
        assert!(
            !content.contains("allow:executed"),
            "a denied call must not produce an execution audit entry"
        );
    }

    #[tokio::test]
    async fn audit_entry_propagates_claim_source() {
        /// Inner executor whose outputs report `ClaimSource::Shell`.
        #[derive(Debug)]
        struct InnerWithClaimSource;

        impl ToolExecutor for InnerWithClaimSource {
            async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
                Ok(None)
            }

            async fn execute_tool_call(
                &self,
                call: &ToolCall,
            ) -> Result<Option<ToolOutput>, ToolError> {
                Ok(Some(ok_output(call, Some(ClaimSource::Shell))))
            }
        }

        let dir = tempfile::TempDir::new().unwrap();
        let (log_path, audit_logger) = file_audit_logger(&dir).await;

        let (_, llm) = MockLlm::new("ALLOW");
        let gate = AdversarialPolicyGateExecutor::new(
            InnerWithClaimSource,
            make_validator(false),
            Arc::new(llm),
        )
        .with_audit(Arc::clone(&audit_logger));

        // The tool id deliberately does not contain "shell". The assertion below can
        // then only be satisfied by the serialized `claim_source`, not by the entry's
        // `tool` field.
        gate.execute_tool_call(&make_call("read_file")).await.unwrap();

        let content = tokio::fs::read_to_string(&log_path).await.unwrap();
        // Case-insensitive so the check does not depend on ClaimSource's serde casing.
        assert!(
            content.to_ascii_lowercase().contains("\"shell\""),
            "claim_source must be propagated into the post-execution audit entry"
        );
    }
}