Skip to main content

zeph_tools/
adversarial_gate.rs

// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
// SPDX-License-Identifier: MIT OR Apache-2.0

//! `AdversarialPolicyGateExecutor`: wraps an inner `ToolExecutor` and runs an LLM-based
//! policy check before delegating any structured tool call.
//!
//! Wiring order (outermost first):
//!   `PolicyGateExecutor` → `AdversarialPolicyGateExecutor` → `TrustGateExecutor` → ...
//!
//! Per CRIT-04 recommendation: declarative `PolicyGateExecutor` is outermost.
//! Adversarial gate only fires for calls that pass declarative policy — no duplication.
//!
//! Per CRIT-06: ALL `ToolExecutor` trait methods are delegated to `self.inner`.
//! Per CRIT-01: fail behavior (allow/deny on LLM error) is controlled by `fail_open` config.
//! Per CRIT-11: params are sanitized and wrapped in code fences before LLM call.

17use std::sync::Arc;
18
19use crate::adversarial_policy::{PolicyDecision, PolicyLlmClient, PolicyValidator};
20use crate::audit::{AuditEntry, AuditLogger, AuditResult, chrono_now};
21use crate::executor::{ClaimSource, ToolCall, ToolError, ToolExecutor, ToolOutput};
22use crate::registry::ToolDef;
23
24/// Wraps an inner `ToolExecutor`, running an LLM-based adversarial policy check
25/// before delegating structured tool calls.
26///
27/// Only `execute_tool_call` and `execute_tool_call_confirmed` are intercepted.
28/// Legacy `execute` / `execute_confirmed` bypass the check (no structured `tool_id`).
29pub struct AdversarialPolicyGateExecutor<T: ToolExecutor> {
30    inner: T,
31    validator: Arc<PolicyValidator>,
32    llm: Arc<dyn PolicyLlmClient>,
33    audit: Option<Arc<AuditLogger>>,
34}
35
36impl<T: ToolExecutor + std::fmt::Debug> std::fmt::Debug for AdversarialPolicyGateExecutor<T> {
37    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38        f.debug_struct("AdversarialPolicyGateExecutor")
39            .field("inner", &self.inner)
40            .finish_non_exhaustive()
41    }
42}
43
44impl<T: ToolExecutor> AdversarialPolicyGateExecutor<T> {
45    /// Create a new `AdversarialPolicyGateExecutor`.
46    #[must_use]
47    pub fn new(inner: T, validator: Arc<PolicyValidator>, llm: Arc<dyn PolicyLlmClient>) -> Self {
48        Self {
49            inner,
50            validator,
51            llm,
52            audit: None,
53        }
54    }
55
56    /// Attach an audit logger.
57    #[must_use]
58    pub fn with_audit(mut self, audit: Arc<AuditLogger>) -> Self {
59        self.audit = Some(audit);
60        self
61    }
62
63    async fn check_policy(&self, call: &ToolCall) -> Result<(), ToolError> {
64        tracing::info!(
65            tool = %call.tool_id,
66            status_spinner = true,
67            "Validating tool policy\u{2026}"
68        );
69
70        let decision = self
71            .validator
72            .validate(call.tool_id.as_str(), &call.params, self.llm.as_ref())
73            .await;
74
75        match decision {
76            PolicyDecision::Allow => {
77                tracing::debug!(tool = %call.tool_id, "adversarial policy: allow");
78                self.write_audit(call, "allow", AuditResult::Success, None)
79                    .await;
80                Ok(())
81            }
82            PolicyDecision::Deny { reason } => {
83                tracing::warn!(
84                    tool = %call.tool_id,
85                    reason = %reason,
86                    "adversarial policy: deny"
87                );
88                self.write_audit(
89                    call,
90                    &format!("deny:{reason}"),
91                    AuditResult::Blocked {
92                        reason: reason.clone(),
93                    },
94                    None,
95                )
96                .await;
97                // MED-03: do NOT surface the LLM reason to the main LLM.
98                Err(ToolError::Blocked {
99                    command: "[adversarial] Tool call denied by policy".to_owned(),
100                })
101            }
102            PolicyDecision::Error { message } => {
103                tracing::warn!(
104                    tool = %call.tool_id,
105                    error = %message,
106                    fail_open = self.validator.fail_open(),
107                    "adversarial policy: LLM error"
108                );
109                if self.validator.fail_open() {
110                    self.write_audit(
111                        call,
112                        &format!("error:{message}"),
113                        AuditResult::Success,
114                        None,
115                    )
116                    .await;
117                    Ok(())
118                } else {
119                    self.write_audit(
120                        call,
121                        &format!("error:{message}"),
122                        AuditResult::Blocked {
123                            reason: "adversarial policy LLM error (fail-closed)".to_owned(),
124                        },
125                        None,
126                    )
127                    .await;
128                    Err(ToolError::Blocked {
129                        command: "[adversarial] Tool call denied: policy check failed".to_owned(),
130                    })
131                }
132            }
133        }
134    }
135
136    async fn write_audit(
137        &self,
138        call: &ToolCall,
139        decision: &str,
140        result: AuditResult,
141        claim_source: Option<ClaimSource>,
142    ) {
143        let Some(audit) = &self.audit else { return };
144        let entry = AuditEntry {
145            timestamp: chrono_now(),
146            tool: call.tool_id.clone(),
147            command: params_summary(&call.params),
148            result,
149            duration_ms: 0,
150            error_category: None,
151            error_domain: None,
152            error_phase: None,
153            claim_source,
154            mcp_server_id: None,
155            injection_flagged: false,
156            embedding_anomalous: false,
157            cross_boundary_mcp_to_acp: false,
158            adversarial_policy_decision: Some(decision.to_owned()),
159            exit_code: None,
160            truncated: false,
161            caller_id: call.caller_id.clone(),
162            skill_name: call.skill_name.clone(),
163            policy_match: None,
164            correlation_id: None,
165            vigil_risk: None,
166            execution_env: None,
167            resolved_cwd: None,
168            scope_at_definition: None,
169            scope_at_dispatch: None,
170        };
171        audit.log(&entry).await;
172    }
173}
174
175impl<T: ToolExecutor> ToolExecutor for AdversarialPolicyGateExecutor<T> {
176    // Legacy dispatch bypasses adversarial check — no structured tool_id available.
177    async fn execute(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
178        self.inner.execute(response).await
179    }
180
181    async fn execute_confirmed(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
182        self.inner.execute_confirmed(response).await
183    }
184
185    // CRIT-06: delegate all pass-through methods to inner executor.
186    fn tool_definitions(&self) -> Vec<ToolDef> {
187        self.inner.tool_definitions()
188    }
189
190    async fn execute_tool_call(&self, call: &ToolCall) -> Result<Option<ToolOutput>, ToolError> {
191        self.check_policy(call).await?;
192        let output = self.inner.execute_tool_call(call).await?;
193        if let Some(ref out) = output {
194            self.write_audit(
195                call,
196                "allow:executed",
197                AuditResult::Success,
198                out.claim_source,
199            )
200            .await;
201        }
202        Ok(output)
203    }
204
205    // MED-04: policy also enforced on confirmed calls.
206    async fn execute_tool_call_confirmed(
207        &self,
208        call: &ToolCall,
209    ) -> Result<Option<ToolOutput>, ToolError> {
210        self.check_policy(call).await?;
211        let output = self.inner.execute_tool_call_confirmed(call).await?;
212        if let Some(ref out) = output {
213            self.write_audit(
214                call,
215                "allow:executed",
216                AuditResult::Success,
217                out.claim_source,
218            )
219            .await;
220        }
221        Ok(output)
222    }
223
224    fn set_skill_env(&self, env: Option<std::collections::HashMap<String, String>>) {
225        self.inner.set_skill_env(env);
226    }
227
228    fn set_effective_trust(&self, level: crate::SkillTrustLevel) {
229        self.inner.set_effective_trust(level);
230    }
231
232    fn is_tool_retryable(&self, tool_id: &str) -> bool {
233        self.inner.is_tool_retryable(tool_id)
234    }
235}
236
237fn params_summary(params: &serde_json::Map<String, serde_json::Value>) -> String {
238    let s = serde_json::to_string(params).unwrap_or_default();
239    if s.chars().count() > 500 {
240        let truncated: String = s.chars().take(497).collect();
241        format!("{truncated}\u{2026}")
242    } else {
243        s
244    }
245}
246
#[cfg(test)]
mod tests {
    use std::future::Future;
    use std::path::Path;
    use std::pin::Pin;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Duration;

    use super::*;
    use crate::adversarial_policy::{PolicyMessage, PolicyValidator};
    use crate::executor::{ToolCall, ToolOutput};

    /// Boxed future type returned by `PolicyLlmClient::chat` in the mocks below.
    type ChatFuture<'a> = Pin<Box<dyn Future<Output = Result<String, String>> + Send + 'a>>;

    // --- Mock LLM clients ---

    /// LLM stub that answers every request with a fixed response and counts calls.
    struct MockLlm {
        response: String,
        call_count: Arc<AtomicUsize>,
    }

    impl MockLlm {
        /// Returns the shared call counter together with the mock itself.
        fn new(response: impl Into<String>) -> (Arc<AtomicUsize>, Self) {
            let counter = Arc::new(AtomicUsize::new(0));
            let client = Self {
                response: response.into(),
                call_count: Arc::clone(&counter),
            };
            (counter, client)
        }
    }

    impl PolicyLlmClient for MockLlm {
        fn chat<'a>(&'a self, _messages: &'a [PolicyMessage]) -> ChatFuture<'a> {
            self.call_count.fetch_add(1, Ordering::SeqCst);
            let resp = self.response.clone();
            Box::pin(async move { Ok(resp) })
        }
    }

    /// LLM stub whose every request fails with a transport-style error.
    struct FailingLlm;

    impl PolicyLlmClient for FailingLlm {
        fn chat<'a>(&'a self, _: &'a [PolicyMessage]) -> ChatFuture<'a> {
            Box::pin(async { Err("network error".to_owned()) })
        }
    }

    // --- Mock inner executors ---

    /// Successful tool output for `call`, tagged with the given claim source.
    fn ok_output(call: &ToolCall, claim_source: Option<crate::executor::ClaimSource>) -> ToolOutput {
        ToolOutput {
            tool_name: call.tool_id.clone(),
            summary: "ok".into(),
            blocks_executed: 1,
            filter_stats: None,
            diff: None,
            streamed: false,
            terminal_id: None,
            locations: None,
            raw_response: None,
            claim_source,
        }
    }

    /// Inner executor that counts structured tool calls and always succeeds.
    #[derive(Debug)]
    struct MockInner {
        call_count: Arc<AtomicUsize>,
    }

    impl MockInner {
        /// Returns the shared call counter together with the mock itself.
        fn new() -> (Arc<AtomicUsize>, Self) {
            let counter = Arc::new(AtomicUsize::new(0));
            let exec = Self {
                call_count: Arc::clone(&counter),
            };
            (counter, exec)
        }
    }

    impl ToolExecutor for MockInner {
        async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }

        async fn execute_tool_call(
            &self,
            call: &ToolCall,
        ) -> Result<Option<ToolOutput>, ToolError> {
            self.call_count.fetch_add(1, Ordering::SeqCst);
            Ok(Some(ok_output(call, None)))
        }
    }

    /// Inner executor whose outputs carry `ClaimSource::Shell`.
    #[derive(Debug)]
    struct InnerWithClaimSource;

    impl ToolExecutor for InnerWithClaimSource {
        async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }

        async fn execute_tool_call(
            &self,
            call: &ToolCall,
        ) -> Result<Option<ToolOutput>, ToolError> {
            Ok(Some(ok_output(
                call,
                Some(crate::executor::ClaimSource::Shell),
            )))
        }
    }

    // --- Fixtures ---

    fn make_call(tool_id: &str) -> ToolCall {
        ToolCall {
            tool_id: tool_id.into(),
            params: serde_json::Map::new(),
            caller_id: None,
            context: None,

            tool_call_id: String::new(),
            skill_name: None,
        }
    }

    fn make_validator(fail_open: bool) -> Arc<PolicyValidator> {
        Arc::new(PolicyValidator::new(
            vec!["test policy".to_owned()],
            Duration::from_millis(500),
            fail_open,
            Vec::new(),
        ))
    }

    /// Audit logger that writes to the file at `log_path`.
    async fn file_audit_logger(log_path: &Path) -> Arc<crate::audit::AuditLogger> {
        let audit_config = crate::config::AuditConfig {
            enabled: true,
            destination: crate::config::AuditDestination::File(log_path.to_path_buf()),
            ..Default::default()
        };
        Arc::new(
            crate::audit::AuditLogger::from_config(&audit_config, false)
                .await
                .unwrap(),
        )
    }

    // --- Policy decisions ---

    #[tokio::test]
    async fn allow_path_delegates_to_inner() {
        let (llm_count, llm) = MockLlm::new("ALLOW");
        let (inner_count, inner) = MockInner::new();
        let gate = AdversarialPolicyGateExecutor::new(inner, make_validator(false), Arc::new(llm));
        let result = gate.execute_tool_call(&make_call("shell")).await;
        assert!(result.is_ok());
        assert_eq!(
            llm_count.load(Ordering::SeqCst),
            1,
            "LLM must be called once"
        );
        assert_eq!(
            inner_count.load(Ordering::SeqCst),
            1,
            "inner executor must be called on allow"
        );
    }

    #[tokio::test]
    async fn deny_path_blocks_and_does_not_call_inner() {
        let (llm_count, llm) = MockLlm::new("DENY: unsafe command");
        let (inner_count, inner) = MockInner::new();
        let gate = AdversarialPolicyGateExecutor::new(inner, make_validator(false), Arc::new(llm));
        let result = gate.execute_tool_call(&make_call("shell")).await;
        assert!(matches!(result, Err(ToolError::Blocked { .. })));
        assert_eq!(llm_count.load(Ordering::SeqCst), 1);
        assert_eq!(
            inner_count.load(Ordering::SeqCst),
            0,
            "inner must NOT be called on deny"
        );
    }

    #[tokio::test]
    async fn error_message_is_opaque() {
        // MED-03: error returned to main LLM must not contain the LLM denial reason.
        let (_, llm) = MockLlm::new("DENY: secret internal policy rule XYZ");
        let (_, inner) = MockInner::new();
        let gate = AdversarialPolicyGateExecutor::new(inner, make_validator(false), Arc::new(llm));
        let err = gate
            .execute_tool_call(&make_call("shell"))
            .await
            .unwrap_err();
        if let ToolError::Blocked { command } = err {
            assert!(
                !command.contains("secret internal policy rule XYZ"),
                "LLM denial reason must not leak to main LLM"
            );
        } else {
            panic!("expected Blocked error");
        }
    }

    #[tokio::test]
    async fn fail_closed_blocks_on_llm_error() {
        let (_, inner) = MockInner::new();
        let gate = AdversarialPolicyGateExecutor::new(
            inner,
            make_validator(false), // fail_open = false
            Arc::new(FailingLlm),
        );
        let result = gate.execute_tool_call(&make_call("shell")).await;
        assert!(
            matches!(result, Err(ToolError::Blocked { .. })),
            "fail-closed must block on LLM error"
        );
    }

    #[tokio::test]
    async fn fail_open_allows_on_llm_error() {
        let (inner_count, inner) = MockInner::new();
        let gate = AdversarialPolicyGateExecutor::new(
            inner,
            make_validator(true), // fail_open = true
            Arc::new(FailingLlm),
        );
        let result = gate.execute_tool_call(&make_call("shell")).await;
        assert!(result.is_ok(), "fail-open must allow on LLM error");
        assert_eq!(inner_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn confirmed_also_enforces_policy() {
        let (_, llm) = MockLlm::new("DENY: blocked");
        let (_, inner) = MockInner::new();
        let gate = AdversarialPolicyGateExecutor::new(inner, make_validator(false), Arc::new(llm));
        let result = gate.execute_tool_call_confirmed(&make_call("shell")).await;
        assert!(
            matches!(result, Err(ToolError::Blocked { .. })),
            "confirmed path must also enforce adversarial policy"
        );
    }

    #[tokio::test]
    async fn legacy_execute_bypasses_policy() {
        let (llm_count, llm) = MockLlm::new("DENY: anything");
        let (_, inner) = MockInner::new();
        let gate = AdversarialPolicyGateExecutor::new(inner, make_validator(false), Arc::new(llm));
        let result = gate.execute("```shell\necho hi\n```").await;
        assert!(
            result.is_ok(),
            "legacy execute must bypass adversarial policy"
        );
        assert_eq!(
            llm_count.load(Ordering::SeqCst),
            0,
            "LLM must NOT be called for legacy dispatch"
        );
    }

    // --- Pass-through delegation ---

    #[tokio::test]
    async fn delegation_set_skill_env() {
        // Verify that set_skill_env reaches the inner executor without panic.
        let (_, llm) = MockLlm::new("ALLOW");
        let (_, inner) = MockInner::new();
        let gate = AdversarialPolicyGateExecutor::new(inner, make_validator(false), Arc::new(llm));
        gate.set_skill_env(None);
    }

    #[tokio::test]
    async fn delegation_set_effective_trust() {
        use crate::SkillTrustLevel;
        let (_, llm) = MockLlm::new("ALLOW");
        let (_, inner) = MockInner::new();
        let gate = AdversarialPolicyGateExecutor::new(inner, make_validator(false), Arc::new(llm));
        gate.set_effective_trust(SkillTrustLevel::Trusted);
    }

    #[tokio::test]
    async fn delegation_is_tool_retryable() {
        let (_, llm) = MockLlm::new("ALLOW");
        let (_, inner) = MockInner::new();
        let gate = AdversarialPolicyGateExecutor::new(inner, make_validator(false), Arc::new(llm));
        let retryable = gate.is_tool_retryable("shell");
        assert!(!retryable, "MockInner returns false for is_tool_retryable");
    }

    #[tokio::test]
    async fn delegation_tool_definitions() {
        let (_, llm) = MockLlm::new("ALLOW");
        let (_, inner) = MockInner::new();
        let gate = AdversarialPolicyGateExecutor::new(inner, make_validator(false), Arc::new(llm));
        let defs = gate.tool_definitions();
        assert!(defs.is_empty(), "MockInner returns empty tool definitions");
    }

    // --- Audit logging ---

    #[tokio::test]
    async fn audit_entry_contains_adversarial_decision() {
        use tempfile::TempDir;

        let dir = TempDir::new().unwrap();
        let log_path = dir.path().join("audit.log");
        let audit_logger = file_audit_logger(&log_path).await;

        let (_, llm) = MockLlm::new("ALLOW");
        let (_, inner) = MockInner::new();
        let gate = AdversarialPolicyGateExecutor::new(inner, make_validator(false), Arc::new(llm))
            .with_audit(Arc::clone(&audit_logger));

        gate.execute_tool_call(&make_call("shell")).await.unwrap();

        let content = tokio::fs::read_to_string(&log_path).await.unwrap();
        assert!(
            content.contains("adversarial_policy_decision"),
            "audit entry must contain adversarial_policy_decision field"
        );
        assert!(
            content.contains("\"allow\""),
            "allow decision must be recorded"
        );
    }

    #[tokio::test]
    async fn audit_entry_deny_contains_decision() {
        use tempfile::TempDir;

        let dir = TempDir::new().unwrap();
        let log_path = dir.path().join("audit.log");
        let audit_logger = file_audit_logger(&log_path).await;

        let (_, llm) = MockLlm::new("DENY: test denial");
        let (_, inner) = MockInner::new();
        let gate = AdversarialPolicyGateExecutor::new(inner, make_validator(false), Arc::new(llm))
            .with_audit(Arc::clone(&audit_logger));

        let _ = gate.execute_tool_call(&make_call("shell")).await;

        let content = tokio::fs::read_to_string(&log_path).await.unwrap();
        assert!(
            content.contains("deny:"),
            "deny decision must be recorded in audit"
        );
    }

    #[tokio::test]
    async fn audit_entry_propagates_claim_source() {
        use tempfile::TempDir;

        let dir = TempDir::new().unwrap();
        let log_path = dir.path().join("audit.log");
        let audit_logger = file_audit_logger(&log_path).await;

        let (_, llm) = MockLlm::new("ALLOW");
        let gate = AdversarialPolicyGateExecutor::new(
            InnerWithClaimSource,
            make_validator(false),
            Arc::new(llm),
        )
        .with_audit(Arc::clone(&audit_logger));

        // The tool id is also written to the audit entry, so it must not be "shell":
        // otherwise the assertion below would pass on the tool id alone, even if the
        // claim source were never propagated.
        gate.execute_tool_call(&make_call("claim_probe"))
            .await
            .unwrap();

        let content = tokio::fs::read_to_string(&log_path).await.unwrap();
        assert!(
            content.contains("\"shell\""),
            "claim_source must be propagated into the post-execution audit entry"
        );
    }
}