Skip to main content

zeph_tools/
tool_filter.rs

// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use crate::executor::{ToolCall, ToolError, ToolExecutor, ToolOutput};
5use crate::registry::ToolDef;
6
/// Wraps a [`ToolExecutor`] and hides a fixed set of tool ids from it.
///
/// Suppressed ids are removed from `tool_definitions`. Any
/// `execute_tool_call` whose `tool_id` is suppressed returns `Ok(None)`
/// without reaching the inner executor. The `execute` / `execute_confirmed`
/// paths, which take a raw response string, are forwarded unchanged.
///
/// Used to hide `FileExecutor` tools (e.g. `read`, `glob`) when
/// `AcpFileExecutor` provides equivalent IDE-proxied alternatives.
///
/// The struct itself places no bound on `E`. The `ToolExecutor` requirement
/// lives on the `impl` blocks that actually need it, so the type can be named
/// (and `Debug`-printed) without repeating the bound.
#[derive(Debug)]
pub struct ToolFilter<E> {
    /// Executor that receives every request not suppressed by this filter.
    inner: E,
    /// Tool ids to hide, matched by exact (case-sensitive) string equality.
    suppressed: &'static [&'static str],
}
17
18impl<E: ToolExecutor> ToolFilter<E> {
19    #[must_use]
20    pub fn new(inner: E, suppressed: &'static [&'static str]) -> Self {
21        Self { inner, suppressed }
22    }
23}
24
25impl<E: ToolExecutor> ToolExecutor for ToolFilter<E> {
26    async fn execute(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
27        self.inner.execute(response).await
28    }
29
30    async fn execute_confirmed(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
31        self.inner.execute_confirmed(response).await
32    }
33
34    fn tool_definitions(&self) -> Vec<ToolDef> {
35        self.inner
36            .tool_definitions()
37            .into_iter()
38            .filter(|d| !self.suppressed.contains(&d.id.as_ref()))
39            .collect()
40    }
41
42    async fn execute_tool_call(&self, call: &ToolCall) -> Result<Option<ToolOutput>, ToolError> {
43        if self.suppressed.contains(&call.tool_id.as_str()) {
44            return Ok(None);
45        }
46        self.inner.execute_tool_call(call).await
47    }
48}
49
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ToolName;

    /// Ids hidden by the filter in the suppression tests below.
    const SUPPRESSED: &[&str] = &["read", "glob"];

    /// Builds a stub definition that differs only in `id` and `description`.
    fn stub_def(id: &'static str, description: &'static str) -> ToolDef {
        ToolDef {
            id: id.into(),
            description: description.into(),
            schema: schemars::schema_for!(String),
            invocation: crate::registry::InvocationHint::ToolCall,
        }
    }

    /// Builds a parameterless call to `tool_id`.
    fn call_for(tool_id: &'static str) -> ToolCall {
        ToolCall {
            tool_id: ToolName::new(tool_id),
            params: serde_json::Map::new(),
            caller_id: None,
        }
    }

    /// Inner executor that advertises `read`, `glob` and `edit` and answers
    /// every tool call with `Some(..)`. A `None` coming back from the filter
    /// can therefore only be caused by suppression.
    #[derive(Debug)]
    struct StubExecutor;

    impl ToolExecutor for StubExecutor {
        async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }

        fn tool_definitions(&self) -> Vec<ToolDef> {
            vec![
                stub_def("read", "read a file"),
                stub_def("glob", "find files"),
                stub_def("edit", "edit a file"),
            ]
        }

        async fn execute_tool_call(
            &self,
            call: &ToolCall,
        ) -> Result<Option<ToolOutput>, ToolError> {
            Ok(Some(ToolOutput {
                tool_name: call.tool_id.clone(),
                summary: "stub".to_owned(),
                blocks_executed: 1,
                filter_stats: None,
                diff: None,
                streamed: false,
                terminal_id: None,
                locations: None,
                raw_response: None,
                claim_source: None,
            }))
        }
    }

    #[test]
    fn suppressed_tools_hidden_from_definitions() {
        let filter = ToolFilter::new(StubExecutor, SUPPRESSED);
        let defs = filter.tool_definitions();
        let ids: Vec<&str> = defs.iter().map(|d| d.id.as_ref()).collect();
        assert!(!ids.contains(&"read"));
        assert!(!ids.contains(&"glob"));
        assert!(ids.contains(&"edit"));
    }

    #[test]
    fn empty_suppression_list_keeps_all_definitions() {
        let filter = ToolFilter::new(StubExecutor, &[]);
        assert_eq!(filter.tool_definitions().len(), 3);
    }

    #[tokio::test]
    async fn suppressed_tool_call_returns_none() {
        let filter = ToolFilter::new(StubExecutor, SUPPRESSED);
        for &id in SUPPRESSED {
            let result = filter.execute_tool_call(&call_for(id)).await.unwrap();
            assert!(result.is_none(), "{id} should be suppressed");
        }
    }

    #[tokio::test]
    async fn allowed_tool_call_passes_through() {
        let filter = ToolFilter::new(StubExecutor, SUPPRESSED);
        let output = filter
            .execute_tool_call(&call_for("edit"))
            .await
            .unwrap()
            .expect("allowed call must reach the inner executor");
        assert_eq!(output.tool_name.as_str(), "edit");
    }

    #[tokio::test]
    async fn empty_suppression_list_forwards_every_call() {
        let filter = ToolFilter::new(StubExecutor, &[]);
        let result = filter.execute_tool_call(&call_for("read")).await.unwrap();
        assert!(result.is_some());
    }
}