1use std::sync::Arc;
22
23use tracing::{Instrument as _, info_span};
24
25use crate::executor::{ToolCall, ToolError, ToolExecutor, ToolOutput};
26use crate::registry::ToolDef;
27
/// Asynchronous safety check consulted before a tool call is executed.
///
/// The trait is object-safe (`probe` has only a lifetime parameter and returns a
/// boxed future), which is what allows it to be stored as `Arc<dyn ProbeGate>`.
/// The `Send + Sync` supertraits let that shared handle be used from any thread.
pub trait ProbeGate: Send + Sync {
    /// Evaluate a single tool call and decide whether it may proceed.
    ///
    /// - `qualified_tool_id`: the tool identifier of the call (for example
    ///   `builtin:shell` in this module's tests).
    /// - `args`: the call's arguments as JSON; `ShadowProbeExecutor` passes the
    ///   call's `params` map wrapped in `serde_json::Value::Object`.
    /// - `turn_number`: the turn counter value read when the call is probed.
    /// - `risk_level`: the current risk-level label read when the call is probed.
    ///
    /// Returns a boxed `Send` future that resolves to a [`ProbeOutcome`]; the
    /// future may borrow all reference arguments for `'a`.
    fn probe<'a>(
        &'a self,
        qualified_tool_id: &'a str,
        args: &'a serde_json::Value,
        turn_number: u64,
        risk_level: &'a str,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ProbeOutcome> + Send + 'a>>;
}
45
/// Verdict produced by a [`ProbeGate`] for a single tool call.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ProbeOutcome {
    /// The call may proceed to the wrapped executor.
    Allow,
    /// The call is blocked.
    Deny {
        /// Explanation for the denial; `ShadowProbeExecutor` logs it and returns it
        /// to the caller inside `ToolError::SafetyDenied`.
        reason: String,
    },
    /// The probe did not evaluate the call; `ShadowProbeExecutor` forwards the
    /// call exactly as it does for `Allow`.
    Skip,
}
60
61pub struct ShadowProbeExecutor<T: ToolExecutor> {
71 inner: T,
72 probe: Arc<dyn ProbeGate>,
73 turn_number: Arc<std::sync::atomic::AtomicU64>,
76 risk_level: Arc<parking_lot::RwLock<String>>,
78}
79
80impl<T: ToolExecutor + std::fmt::Debug> std::fmt::Debug for ShadowProbeExecutor<T> {
81 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
82 f.debug_struct("ShadowProbeExecutor")
83 .field("inner", &self.inner)
84 .finish_non_exhaustive()
85 }
86}
87
88impl<T: ToolExecutor> ShadowProbeExecutor<T> {
89 #[must_use]
98 pub fn new(
99 inner: T,
100 probe: Arc<dyn ProbeGate>,
101 turn_number: Arc<std::sync::atomic::AtomicU64>,
102 risk_level: Arc<parking_lot::RwLock<String>>,
103 ) -> Self {
104 Self {
105 inner,
106 probe,
107 turn_number,
108 risk_level,
109 }
110 }
111
112 fn current_turn(&self) -> u64 {
113 self.turn_number.load(std::sync::atomic::Ordering::Acquire)
114 }
115
116 fn current_risk_level(&self) -> String {
117 self.risk_level.read().clone()
118 }
119}
120
121impl<T: ToolExecutor> ToolExecutor for ShadowProbeExecutor<T> {
122 async fn execute(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
124 self.inner.execute(response).await
125 }
126
127 async fn execute_confirmed(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
129 self.inner.execute_confirmed(response).await
130 }
131
132 fn tool_definitions(&self) -> Vec<ToolDef> {
133 self.inner.tool_definitions()
134 }
135
136 async fn execute_tool_call(&self, call: &ToolCall) -> Result<Option<ToolOutput>, ToolError> {
141 let span = info_span!(
142 "security.shadow.probe_executor",
143 tool_id = %call.tool_id
144 );
145
146 let args = serde_json::Value::Object(call.params.clone());
147 let turn = self.current_turn();
148 let risk = self.current_risk_level();
149
150 let outcome = self
151 .probe
152 .probe(call.tool_id.as_str(), &args, turn, &risk)
153 .instrument(span)
154 .await;
155
156 match outcome {
157 ProbeOutcome::Allow | ProbeOutcome::Skip => self.inner.execute_tool_call(call).await,
158 ProbeOutcome::Deny { reason } => {
159 tracing::warn!(
160 tool_id = %call.tool_id,
161 reason = %reason,
162 "ShadowProbeExecutor: safety probe denied tool call"
163 );
164 Err(ToolError::SafetyDenied { reason })
165 }
166 }
167 }
168
169 async fn execute_tool_call_confirmed(
173 &self,
174 call: &ToolCall,
175 ) -> Result<Option<ToolOutput>, ToolError> {
176 let span = info_span!(
177 "security.shadow.probe_executor_confirmed",
178 tool_id = %call.tool_id
179 );
180
181 let args = serde_json::Value::Object(call.params.clone());
182 let turn = self.current_turn();
183 let risk = self.current_risk_level();
184
185 let outcome = self
186 .probe
187 .probe(call.tool_id.as_str(), &args, turn, &risk)
188 .instrument(span)
189 .await;
190
191 match outcome {
192 ProbeOutcome::Allow | ProbeOutcome::Skip => {
193 self.inner.execute_tool_call_confirmed(call).await
194 }
195 ProbeOutcome::Deny { reason } => {
196 tracing::warn!(
197 tool_id = %call.tool_id,
198 reason = %reason,
199 "ShadowProbeExecutor: safety probe denied confirmed tool call"
200 );
201 Err(ToolError::SafetyDenied { reason })
202 }
203 }
204 }
205
206 fn set_skill_env(&self, env: Option<std::collections::HashMap<String, String>>) {
207 self.inner.set_skill_env(env);
208 }
209
210 fn set_effective_trust(&self, level: crate::SkillTrustLevel) {
211 self.inner.set_effective_trust(level);
212 }
213
214 fn is_tool_retryable(&self, tool_id: &str) -> bool {
215 self.inner.is_tool_retryable(tool_id)
216 }
217
218 fn is_tool_speculatable(&self, tool_id: &str) -> bool {
219 let _ = tool_id;
222 false
223 }
224
225 fn requires_confirmation(&self, call: &ToolCall) -> bool {
226 self.inner.requires_confirmation(call)
227 }
228}
229
#[cfg(test)]
mod tests {
    use super::*;
    use crate::executor::{ToolError, ToolOutput};
    use crate::{ToolCall, ToolExecutor};
    use zeph_common::ToolName;

    /// Probe that allows every call.
    struct AllowProbe;
    impl ProbeGate for AllowProbe {
        fn probe<'a>(
            &'a self,
            _: &'a str,
            _: &'a serde_json::Value,
            _: u64,
            _: &'a str,
        ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ProbeOutcome> + Send + 'a>>
        {
            Box::pin(async { ProbeOutcome::Allow })
        }
    }

    /// Probe that denies every call with a fixed reason.
    struct DenyProbe;
    impl ProbeGate for DenyProbe {
        fn probe<'a>(
            &'a self,
            _: &'a str,
            _: &'a serde_json::Value,
            _: u64,
            _: &'a str,
        ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ProbeOutcome> + Send + 'a>>
        {
            Box::pin(async {
                ProbeOutcome::Deny {
                    reason: "test denial".to_owned(),
                }
            })
        }
    }

    /// Probe that skips every call.
    struct SkipProbe;
    impl ProbeGate for SkipProbe {
        fn probe<'a>(
            &'a self,
            _: &'a str,
            _: &'a serde_json::Value,
            _: u64,
            _: &'a str,
        ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ProbeOutcome> + Send + 'a>>
        {
            Box::pin(async { ProbeOutcome::Skip })
        }
    }

    /// Probe that denies every call and encodes the inputs it received into the
    /// denial reason, so tests can assert exactly what the executor forwarded.
    struct EchoProbe;
    impl ProbeGate for EchoProbe {
        fn probe<'a>(
            &'a self,
            qualified_tool_id: &'a str,
            args: &'a serde_json::Value,
            turn_number: u64,
            risk_level: &'a str,
        ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ProbeOutcome> + Send + 'a>>
        {
            let reason = format!("{qualified_tool_id}|{args}|{turn_number}|{risk_level}");
            Box::pin(async move { ProbeOutcome::Deny { reason } })
        }
    }

    /// Inner executor whose structured path always succeeds with a fixed output.
    struct OkInner;
    impl ToolExecutor for OkInner {
        async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }

        async fn execute_tool_call(
            &self,
            call: &ToolCall,
        ) -> Result<Option<ToolOutput>, ToolError> {
            Ok(Some(ToolOutput {
                tool_name: call.tool_id.clone(),
                summary: "ok".to_owned(),
                blocks_executed: 1,
                filter_stats: None,
                diff: None,
                streamed: false,
                terminal_id: None,
                locations: None,
                raw_response: None,
                claim_source: None,
            }))
        }
    }

    fn make_call(tool: &str) -> ToolCall {
        ToolCall {
            tool_id: ToolName::new(tool),
            params: serde_json::Map::new(),
            caller_id: None,
            context: None,
            tool_call_id: String::new(),
            skill_name: None,
        }
    }

    fn make_executor<P: ProbeGate + 'static>(probe: P) -> ShadowProbeExecutor<OkInner> {
        ShadowProbeExecutor::new(
            OkInner,
            Arc::new(probe),
            Arc::new(std::sync::atomic::AtomicU64::new(1)),
            Arc::new(parking_lot::RwLock::new("calm".to_owned())),
        )
    }

    #[tokio::test]
    async fn allow_probe_delegates_to_inner() {
        let exec = make_executor(AllowProbe);
        let result = exec.execute_tool_call(&make_call("builtin:shell")).await;
        assert!(result.unwrap().is_some());
    }

    #[tokio::test]
    async fn deny_probe_returns_safety_denied() {
        let exec = make_executor(DenyProbe);
        let result = exec.execute_tool_call(&make_call("builtin:shell")).await;
        match result {
            Err(ToolError::SafetyDenied { reason }) => {
                assert_eq!(reason, "test denial");
            }
            other => panic!("expected SafetyDenied, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn skip_probe_delegates_to_inner() {
        let exec = make_executor(SkipProbe);
        let result = exec.execute_tool_call(&make_call("builtin:read")).await;
        assert!(result.unwrap().is_some());
    }

    #[tokio::test]
    async fn legacy_execute_bypasses_probe() {
        let exec = make_executor(DenyProbe);
        let result = exec.execute("some text").await;
        // The text path never consults the probe, so even a denying probe lets it through.
        assert!(result.unwrap().is_none());
    }

    #[tokio::test]
    async fn deny_probe_blocks_confirmed_call() {
        let exec = make_executor(DenyProbe);
        // User confirmation must not bypass the safety probe.
        let result = exec
            .execute_tool_call_confirmed(&make_call("builtin:shell"))
            .await;
        match result {
            Err(ToolError::SafetyDenied { reason }) => {
                assert_eq!(reason, "test denial");
            }
            other => panic!("expected SafetyDenied on confirmed call, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn probe_receives_call_args_and_live_context() {
        let turn = Arc::new(std::sync::atomic::AtomicU64::new(1));
        let risk = Arc::new(parking_lot::RwLock::new("calm".to_owned()));
        let exec = ShadowProbeExecutor::new(
            OkInner,
            Arc::new(EchoProbe),
            Arc::clone(&turn),
            Arc::clone(&risk),
        );

        // Update the shared handles after construction: the executor must read them per call.
        turn.store(7, std::sync::atomic::Ordering::Release);
        *risk.write() = "elevated".to_owned();

        let mut call = make_call("builtin:shell");
        call.params.insert(
            "cmd".to_owned(),
            serde_json::Value::String("ls".to_owned()),
        );
        let expected = format!("{}|{{\"cmd\":\"ls\"}}|7|elevated", call.tool_id.as_str());

        match exec.execute_tool_call(&call).await {
            Err(ToolError::SafetyDenied { reason }) => assert_eq!(reason, expected),
            other => panic!("expected SafetyDenied, got {other:?}"),
        }
    }

    #[test]
    fn is_tool_speculatable_always_false() {
        let exec = make_executor(AllowProbe);
        assert!(!exec.is_tool_speculatable("builtin:read"));
        assert!(!exec.is_tool_speculatable("builtin:shell"));
    }
}