Skip to main content

zeph_tools/
composite.rs

// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
// SPDX-License-Identifier: MIT OR Apache-2.0

//! Composite executor that chains two [`ToolExecutor`] implementations.
5
6use crate::executor::{ToolCall, ToolError, ToolExecutor, ToolOutput};
7use crate::registry::ToolDef;
8
/// Chains two [`ToolExecutor`] implementations with first-match-wins dispatch.
///
/// For each dispatch method (`execute`, `execute_confirmed`, `execute_tool_call`), `first`
/// is tried first. If it returns `Ok(None)` (i.e. it does not handle the input), `second`
/// is tried. If `first` returns an `Err`, the error propagates immediately without
/// consulting `second`.
///
/// Use this to compose a chain of specialized executors at startup instead of a dynamic
/// `Vec<Box<dyn ...>>`. Nest multiple `CompositeExecutor`s to handle more than two backends.
///
/// Tool definitions from both executors are merged, with `first` taking precedence when
/// both define a tool with the same ID. A tool ID is reported as retryable when either
/// executor reports it as retryable.
///
/// The struct itself carries no trait bounds. Following the Rust API guidelines, bounds
/// live on the `impl` blocks that need them, so the type can be named in generic code
/// without restating them. Constructing and dispatching still require both halves to
/// implement [`ToolExecutor`].
///
/// # Example
///
/// ```rust
/// use zeph_tools::{
///     CompositeExecutor, ShellExecutor, WebScrapeExecutor,
///     config::{ShellConfig, ScrapeConfig},
/// };
///
/// let shell = ShellExecutor::new(&ShellConfig::default());
/// let scrape = WebScrapeExecutor::new(&ScrapeConfig::default());
/// let executor = CompositeExecutor::new(shell, scrape);
/// // executor handles both bash blocks and scrape/fetch tool calls.
/// ```
#[derive(Debug)]
pub struct CompositeExecutor<A, B> {
    /// Executor consulted first; its result wins whenever it handles the input.
    first: A,
    /// Fallback executor, consulted only when `first` returns `Ok(None)`.
    second: B,
}
39
40impl<A: ToolExecutor, B: ToolExecutor> CompositeExecutor<A, B> {
41    /// Create a new `CompositeExecutor` wrapping `first` and `second`.
42    #[must_use]
43    pub fn new(first: A, second: B) -> Self {
44        Self { first, second }
45    }
46}
47
48impl<A: ToolExecutor, B: ToolExecutor> ToolExecutor for CompositeExecutor<A, B> {
49    async fn execute(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
50        if let Some(output) = self.first.execute(response).await? {
51            return Ok(Some(output));
52        }
53        self.second.execute(response).await
54    }
55
56    async fn execute_confirmed(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
57        if let Some(output) = self.first.execute_confirmed(response).await? {
58            return Ok(Some(output));
59        }
60        self.second.execute_confirmed(response).await
61    }
62
63    fn tool_definitions(&self) -> Vec<ToolDef> {
64        let mut defs = self.first.tool_definitions();
65        let seen: std::collections::HashSet<String> =
66            defs.iter().map(|d| d.id.to_string()).collect();
67        for def in self.second.tool_definitions() {
68            if !seen.contains(def.id.as_ref()) {
69                defs.push(def);
70            }
71        }
72        defs
73    }
74
75    async fn execute_tool_call(&self, call: &ToolCall) -> Result<Option<ToolOutput>, ToolError> {
76        if let Some(output) = self.first.execute_tool_call(call).await? {
77            return Ok(Some(output));
78        }
79        self.second.execute_tool_call(call).await
80    }
81
82    fn is_tool_retryable(&self, tool_id: &str) -> bool {
83        self.first.is_tool_retryable(tool_id) || self.second.is_tool_retryable(tool_id)
84    }
85}
86
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ToolName;

    // Builds a minimal successful `ToolOutput`, so each fake executor states only the
    // fields that matter for the test (name and summary).
    fn output(tool_name: ToolName, summary: &str) -> ToolOutput {
        ToolOutput {
            tool_name,
            summary: summary.to_owned(),
            blocks_executed: 1,
            filter_stats: None,
            diff: None,
            streamed: false,
            terminal_id: None,
            locations: None,
            raw_response: None,
            claim_source: None,
        }
    }

    // Builds a `ToolCall` for `tool_id` with empty params and no caller.
    fn tool_call(tool_id: &'static str) -> ToolCall {
        ToolCall {
            tool_id: ToolName::new(tool_id),
            params: serde_json::Map::new(),
            caller_id: None,
        }
    }

    #[derive(Debug)]
    struct MatchingExecutor;
    impl ToolExecutor for MatchingExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(Some(output(ToolName::new("test"), "matched")))
        }
    }

    #[derive(Debug)]
    struct NoMatchExecutor;
    impl ToolExecutor for NoMatchExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }
    }

    #[derive(Debug)]
    struct ErrorExecutor;
    impl ToolExecutor for ErrorExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Err(ToolError::Blocked {
                command: "test".to_owned(),
            })
        }
    }

    #[derive(Debug)]
    struct SecondExecutor;
    impl ToolExecutor for SecondExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(Some(output(ToolName::new("test"), "second")))
        }
    }

    #[tokio::test]
    async fn first_matches_returns_first() {
        let composite = CompositeExecutor::new(MatchingExecutor, SecondExecutor);
        let result = composite.execute("anything").await.unwrap();
        assert_eq!(result.unwrap().summary, "matched");
    }

    #[tokio::test]
    async fn first_none_falls_through_to_second() {
        let composite = CompositeExecutor::new(NoMatchExecutor, SecondExecutor);
        let result = composite.execute("anything").await.unwrap();
        assert_eq!(result.unwrap().summary, "second");
    }

    #[tokio::test]
    async fn both_none_returns_none() {
        let composite = CompositeExecutor::new(NoMatchExecutor, NoMatchExecutor);
        let result = composite.execute("anything").await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn first_error_propagates_without_trying_second() {
        let composite = CompositeExecutor::new(ErrorExecutor, SecondExecutor);
        let result = composite.execute("anything").await;
        assert!(matches!(result, Err(ToolError::Blocked { .. })));
    }

    #[tokio::test]
    async fn second_error_propagates_when_first_none() {
        let composite = CompositeExecutor::new(NoMatchExecutor, ErrorExecutor);
        let result = composite.execute("anything").await;
        assert!(matches!(result, Err(ToolError::Blocked { .. })));
    }

    #[tokio::test]
    async fn execute_confirmed_first_matches() {
        let composite = CompositeExecutor::new(MatchingExecutor, SecondExecutor);
        let result = composite.execute_confirmed("anything").await.unwrap();
        assert_eq!(result.unwrap().summary, "matched");
    }

    #[tokio::test]
    async fn execute_confirmed_falls_through() {
        let composite = CompositeExecutor::new(NoMatchExecutor, SecondExecutor);
        let result = composite.execute_confirmed("anything").await.unwrap();
        assert_eq!(result.unwrap().summary, "second");
    }

    #[tokio::test]
    async fn execute_confirmed_first_error_propagates() {
        // These fakes only override `execute`; `execute_confirmed_first_matches` shows the
        // default `execute_confirmed` reaches it, so the error should surface here too.
        let composite = CompositeExecutor::new(ErrorExecutor, SecondExecutor);
        let result = composite.execute_confirmed("anything").await;
        assert!(matches!(result, Err(ToolError::Blocked { .. })));
    }

    #[test]
    fn composite_debug() {
        let composite = CompositeExecutor::new(MatchingExecutor, SecondExecutor);
        let debug = format!("{composite:?}");
        assert!(debug.contains("CompositeExecutor"));
    }

    #[derive(Debug)]
    struct FileToolExecutor;
    impl ToolExecutor for FileToolExecutor {
        async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }
        async fn execute_tool_call(
            &self,
            call: &ToolCall,
        ) -> Result<Option<ToolOutput>, ToolError> {
            if call.tool_id == "read" || call.tool_id == "write" {
                Ok(Some(output(call.tool_id.clone(), "file_handler")))
            } else {
                Ok(None)
            }
        }
    }

    #[derive(Debug)]
    struct ShellToolExecutor;
    impl ToolExecutor for ShellToolExecutor {
        async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }
        async fn execute_tool_call(
            &self,
            call: &ToolCall,
        ) -> Result<Option<ToolOutput>, ToolError> {
            if call.tool_id == "bash" {
                Ok(Some(output(ToolName::new("bash"), "shell_handler")))
            } else {
                Ok(None)
            }
        }
    }

    // Rejects every structured tool call, to check error short-circuiting in
    // `execute_tool_call`.
    #[derive(Debug)]
    struct FailingToolExecutor;
    impl ToolExecutor for FailingToolExecutor {
        async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }
        async fn execute_tool_call(
            &self,
            _call: &ToolCall,
        ) -> Result<Option<ToolOutput>, ToolError> {
            Err(ToolError::Blocked {
                command: "test".to_owned(),
            })
        }
    }

    // Marks only the `fetch` tool as retryable.
    #[derive(Debug)]
    struct RetryableExecutor;
    impl ToolExecutor for RetryableExecutor {
        async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }
        fn is_tool_retryable(&self, tool_id: &str) -> bool {
            tool_id == "fetch"
        }
    }

    #[tokio::test]
    async fn tool_call_routes_to_file_executor() {
        let composite = CompositeExecutor::new(FileToolExecutor, ShellToolExecutor);
        let call = tool_call("read");
        let result = composite.execute_tool_call(&call).await.unwrap().unwrap();
        assert_eq!(result.summary, "file_handler");
    }

    #[tokio::test]
    async fn tool_call_routes_to_shell_executor() {
        let composite = CompositeExecutor::new(FileToolExecutor, ShellToolExecutor);
        let call = tool_call("bash");
        let result = composite.execute_tool_call(&call).await.unwrap().unwrap();
        assert_eq!(result.summary, "shell_handler");
    }

    #[tokio::test]
    async fn tool_call_unhandled_returns_none() {
        let composite = CompositeExecutor::new(FileToolExecutor, ShellToolExecutor);
        let call = tool_call("unknown");
        let result = composite.execute_tool_call(&call).await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn tool_call_first_error_propagates_without_trying_second() {
        // `ShellToolExecutor` would handle "bash", so a success here would mean the error
        // from `first` was swallowed.
        let composite = CompositeExecutor::new(FailingToolExecutor, ShellToolExecutor);
        let call = tool_call("bash");
        let result = composite.execute_tool_call(&call).await;
        assert!(matches!(result, Err(ToolError::Blocked { .. })));
    }

    #[test]
    fn retryable_if_either_executor_reports_it() {
        let first_knows = CompositeExecutor::new(RetryableExecutor, NoMatchExecutor);
        assert!(first_knows.is_tool_retryable("fetch"));
        let second_knows = CompositeExecutor::new(NoMatchExecutor, RetryableExecutor);
        assert!(second_knows.is_tool_retryable("fetch"));
    }

    #[test]
    fn not_retryable_when_neither_reports_it() {
        let composite = CompositeExecutor::new(RetryableExecutor, RetryableExecutor);
        assert!(!composite.is_tool_retryable("bash"));
    }
}