1use std::sync::Arc;
22
23use tracing::info_span;
24
25use crate::executor::{ToolCall, ToolError, ToolExecutor, ToolOutput};
26use crate::registry::ToolDef;
27
28pub trait ProbeGate: Send + Sync {
36 fn probe<'a>(
38 &'a self,
39 qualified_tool_id: &'a str,
40 args: &'a serde_json::Value,
41 turn_number: u64,
42 risk_level: &'a str,
43 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ProbeOutcome> + Send + 'a>>;
44}
45
/// Verdict produced by a [`ProbeGate`] for one tool call.
///
/// [`ShadowProbeExecutor`] lets the call through on `Allow` and `Skip` and rejects it on
/// `Deny`.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum ProbeOutcome {
    /// The probe approved the call.
    Allow,
    /// The probe rejected the call.
    Deny {
        /// Explanation carried into `ToolError::SafetyDenied` and the warning log.
        reason: String,
    },
    /// The probe made no decision; the executor treats this exactly like `Allow`.
    Skip,
}
60
61pub struct ShadowProbeExecutor<T: ToolExecutor> {
71 inner: T,
72 probe: Arc<dyn ProbeGate>,
73 turn_number: Arc<std::sync::atomic::AtomicU64>,
76 risk_level: Arc<parking_lot::RwLock<String>>,
78}
79
80impl<T: ToolExecutor + std::fmt::Debug> std::fmt::Debug for ShadowProbeExecutor<T> {
81 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
82 f.debug_struct("ShadowProbeExecutor")
83 .field("inner", &self.inner)
84 .finish_non_exhaustive()
85 }
86}
87
88impl<T: ToolExecutor> ShadowProbeExecutor<T> {
89 #[must_use]
98 pub fn new(
99 inner: T,
100 probe: Arc<dyn ProbeGate>,
101 turn_number: Arc<std::sync::atomic::AtomicU64>,
102 risk_level: Arc<parking_lot::RwLock<String>>,
103 ) -> Self {
104 Self {
105 inner,
106 probe,
107 turn_number,
108 risk_level,
109 }
110 }
111
112 fn current_turn(&self) -> u64 {
113 self.turn_number.load(std::sync::atomic::Ordering::Acquire)
114 }
115
116 fn current_risk_level(&self) -> String {
117 self.risk_level.read().clone()
118 }
119}
120
121impl<T: ToolExecutor> ToolExecutor for ShadowProbeExecutor<T> {
122 async fn execute(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
124 self.inner.execute(response).await
125 }
126
127 async fn execute_confirmed(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
129 self.inner.execute_confirmed(response).await
130 }
131
132 fn tool_definitions(&self) -> Vec<ToolDef> {
133 self.inner.tool_definitions()
134 }
135
136 async fn execute_tool_call(&self, call: &ToolCall) -> Result<Option<ToolOutput>, ToolError> {
141 let span = info_span!(
142 "security.shadow.probe_executor",
143 tool_id = %call.tool_id
144 );
145 let _enter = span.enter();
146
147 let args = serde_json::Value::Object(call.params.clone());
148 let turn = self.current_turn();
149 let risk = self.current_risk_level();
150
151 let outcome = self
152 .probe
153 .probe(call.tool_id.as_str(), &args, turn, &risk)
154 .await;
155
156 match outcome {
157 ProbeOutcome::Allow | ProbeOutcome::Skip => self.inner.execute_tool_call(call).await,
158 ProbeOutcome::Deny { reason } => {
159 tracing::warn!(
160 tool_id = %call.tool_id,
161 reason = %reason,
162 "ShadowProbeExecutor: safety probe denied tool call"
163 );
164 Err(ToolError::SafetyDenied { reason })
165 }
166 }
167 }
168
169 async fn execute_tool_call_confirmed(
173 &self,
174 call: &ToolCall,
175 ) -> Result<Option<ToolOutput>, ToolError> {
176 let span = info_span!(
177 "security.shadow.probe_executor_confirmed",
178 tool_id = %call.tool_id
179 );
180 let _enter = span.enter();
181
182 let args = serde_json::Value::Object(call.params.clone());
183 let turn = self.current_turn();
184 let risk = self.current_risk_level();
185
186 let outcome = self
187 .probe
188 .probe(call.tool_id.as_str(), &args, turn, &risk)
189 .await;
190
191 match outcome {
192 ProbeOutcome::Allow | ProbeOutcome::Skip => {
193 self.inner.execute_tool_call_confirmed(call).await
194 }
195 ProbeOutcome::Deny { reason } => {
196 tracing::warn!(
197 tool_id = %call.tool_id,
198 reason = %reason,
199 "ShadowProbeExecutor: safety probe denied confirmed tool call"
200 );
201 Err(ToolError::SafetyDenied { reason })
202 }
203 }
204 }
205
206 fn set_skill_env(&self, env: Option<std::collections::HashMap<String, String>>) {
207 self.inner.set_skill_env(env);
208 }
209
210 fn set_effective_trust(&self, level: crate::SkillTrustLevel) {
211 self.inner.set_effective_trust(level);
212 }
213
214 fn is_tool_retryable(&self, tool_id: &str) -> bool {
215 self.inner.is_tool_retryable(tool_id)
216 }
217
218 fn is_tool_speculatable(&self, tool_id: &str) -> bool {
219 let _ = tool_id;
222 false
223 }
224
225 fn requires_confirmation(&self, call: &ToolCall) -> bool {
226 self.inner.requires_confirmation(call)
227 }
228}
229
#[cfg(test)]
mod tests {
    use super::*;
    use crate::executor::{ToolError, ToolOutput};
    use crate::{ToolCall, ToolExecutor};
    use zeph_common::ToolName;

    /// Probe that approves every call.
    struct AllowProbe;
    impl ProbeGate for AllowProbe {
        fn probe<'a>(
            &'a self,
            _: &'a str,
            _: &'a serde_json::Value,
            _: u64,
            _: &'a str,
        ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ProbeOutcome> + Send + 'a>>
        {
            Box::pin(async { ProbeOutcome::Allow })
        }
    }

    /// Probe that denies every call with a fixed reason.
    struct DenyProbe;
    impl ProbeGate for DenyProbe {
        fn probe<'a>(
            &'a self,
            _: &'a str,
            _: &'a serde_json::Value,
            _: u64,
            _: &'a str,
        ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ProbeOutcome> + Send + 'a>>
        {
            Box::pin(async {
                ProbeOutcome::Deny {
                    reason: "test denial".to_owned(),
                }
            })
        }
    }

    /// Probe that abstains on every call.
    struct SkipProbe;
    impl ProbeGate for SkipProbe {
        fn probe<'a>(
            &'a self,
            _: &'a str,
            _: &'a serde_json::Value,
            _: u64,
            _: &'a str,
        ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ProbeOutcome> + Send + 'a>>
        {
            Box::pin(async { ProbeOutcome::Skip })
        }
    }

    /// Probe that records the arguments of every invocation and then allows the call.
    #[derive(Default)]
    struct RecordingProbe {
        seen: parking_lot::Mutex<Vec<(String, serde_json::Value, u64, String)>>,
    }
    impl ProbeGate for RecordingProbe {
        fn probe<'a>(
            &'a self,
            qualified_tool_id: &'a str,
            args: &'a serde_json::Value,
            turn_number: u64,
            risk_level: &'a str,
        ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ProbeOutcome> + Send + 'a>>
        {
            self.seen.lock().push((
                qualified_tool_id.to_owned(),
                args.clone(),
                turn_number,
                risk_level.to_owned(),
            ));
            Box::pin(async { ProbeOutcome::Allow })
        }
    }

    /// Inner executor whose structured calls always succeed with a dummy output.
    struct OkInner;
    impl ToolExecutor for OkInner {
        async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }

        async fn execute_tool_call(
            &self,
            call: &ToolCall,
        ) -> Result<Option<ToolOutput>, ToolError> {
            Ok(Some(ToolOutput {
                tool_name: call.tool_id.clone(),
                summary: "ok".to_owned(),
                blocks_executed: 1,
                filter_stats: None,
                diff: None,
                streamed: false,
                terminal_id: None,
                locations: None,
                raw_response: None,
                claim_source: None,
            }))
        }
    }

    fn make_call(tool: &str) -> ToolCall {
        ToolCall {
            tool_id: ToolName::new(tool),
            params: serde_json::Map::new(),
            caller_id: None,
            context: None,
            tool_call_id: String::new(),
        }
    }

    fn make_executor<P: ProbeGate + 'static>(probe: P) -> ShadowProbeExecutor<OkInner> {
        ShadowProbeExecutor::new(
            OkInner,
            Arc::new(probe),
            Arc::new(std::sync::atomic::AtomicU64::new(1)),
            Arc::new(parking_lot::RwLock::new("calm".to_owned())),
        )
    }

    #[tokio::test]
    async fn allow_probe_delegates_to_inner() {
        let exec = make_executor(AllowProbe);
        let result = exec.execute_tool_call(&make_call("builtin:shell")).await;
        assert!(result.unwrap().is_some());
    }

    #[tokio::test]
    async fn deny_probe_returns_safety_denied() {
        let exec = make_executor(DenyProbe);
        let result = exec.execute_tool_call(&make_call("builtin:shell")).await;
        match result {
            Err(ToolError::SafetyDenied { reason }) => {
                assert_eq!(reason, "test denial");
            }
            other => panic!("expected SafetyDenied, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn skip_probe_delegates_to_inner() {
        let exec = make_executor(SkipProbe);
        let result = exec.execute_tool_call(&make_call("builtin:read")).await;
        assert!(result.unwrap().is_some());
    }

    #[tokio::test]
    async fn legacy_execute_bypasses_probe() {
        let exec = make_executor(DenyProbe);
        let result = exec.execute("some text").await;
        // The free-text path never consults the probe, so even a denying probe lets it through.
        assert!(result.unwrap().is_none());
    }

    #[tokio::test]
    async fn deny_probe_blocks_confirmed_call() {
        let exec = make_executor(DenyProbe);
        // User confirmation must not bypass the probe.
        let result = exec
            .execute_tool_call_confirmed(&make_call("builtin:shell"))
            .await;
        match result {
            Err(ToolError::SafetyDenied { reason }) => {
                assert_eq!(reason, "test denial");
            }
            other => panic!("expected SafetyDenied on confirmed call, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn probe_receives_call_context_and_live_state() {
        let probe = Arc::new(RecordingProbe::default());
        let turn = Arc::new(std::sync::atomic::AtomicU64::new(1));
        let risk = Arc::new(parking_lot::RwLock::new("calm".to_owned()));
        let exec = ShadowProbeExecutor::new(
            OkInner,
            Arc::clone(&probe) as Arc<dyn ProbeGate>,
            Arc::clone(&turn),
            Arc::clone(&risk),
        );

        let mut call = make_call("builtin:shell");
        call.params.insert(
            "cmd".to_owned(),
            serde_json::Value::String("ls".to_owned()),
        );

        // Update the shared state after construction: the executor must read it per call.
        turn.store(7, std::sync::atomic::Ordering::Release);
        *risk.write() = "elevated".to_owned();

        let result = exec.execute_tool_call(&call).await;
        assert!(result.unwrap().is_some());

        let seen = probe.seen.lock();
        assert_eq!(seen.len(), 1);
        let (tool_id, args, turn_seen, risk_seen) = &seen[0];
        assert_eq!(tool_id, "builtin:shell");
        assert_eq!(args, &serde_json::Value::Object(call.params.clone()));
        assert_eq!(*turn_seen, 7);
        assert_eq!(risk_seen, "elevated");
    }

    #[test]
    fn is_tool_speculatable_always_false() {
        let exec = make_executor(AllowProbe);
        assert!(!exec.is_tool_speculatable("builtin:read"));
        assert!(!exec.is_tool_speculatable("builtin:shell"));
    }
}