Skip to main content

zeph_tools/
composite.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! Composite executor that chains two [`ToolExecutor`] implementations.
5
6use crate::executor::{ToolCall, ToolError, ToolExecutor, ToolOutput};
7use crate::registry::ToolDef;
8
/// Chains two [`ToolExecutor`] implementations with first-match-wins dispatch.
///
/// For each dispatching method (`execute`, `execute_confirmed`, `execute_tool_call`),
/// `first` is tried first. If it returns `Ok(None)` (i.e. it does not handle the input),
/// `second` is tried. If `first` returns an `Err`, the error propagates immediately
/// without consulting `second`.
///
/// Use this to compose a chain of specialized executors at startup instead of a dynamic
/// `Vec<Box<dyn ...>>`. Nest multiple `CompositeExecutor`s to handle more than two backends.
///
/// Tool definitions from both executors are merged, with `first` taking precedence when
/// both define a tool with the same ID. Retryability is the exception to first-match-wins:
/// a tool is reported as retryable if *either* executor says it is.
///
/// # Example
///
/// ```rust
/// use zeph_tools::{
///     CompositeExecutor, ShellExecutor, WebScrapeExecutor, ShellConfig, ScrapeConfig,
/// };
///
/// let shell = ShellExecutor::new(&ShellConfig::default());
/// let scrape = WebScrapeExecutor::new(&ScrapeConfig::default());
/// let executor = CompositeExecutor::new(shell, scrape);
/// // executor handles both bash blocks and scrape/fetch tool calls.
/// ```
#[derive(Debug)]
// The `ToolExecutor` bounds live on the `impl` blocks rather than on the type itself, so
// other impls over `CompositeExecutor<A, B>` do not have to repeat them.
pub struct CompositeExecutor<A, B> {
    /// Executor consulted first; its results and errors take precedence.
    first: A,
    /// Fallback executor, consulted only when `first` returns `Ok(None)`.
    second: B,
}
38
39impl<A: ToolExecutor, B: ToolExecutor> CompositeExecutor<A, B> {
40    /// Create a new `CompositeExecutor` wrapping `first` and `second`.
41    #[must_use]
42    pub fn new(first: A, second: B) -> Self {
43        Self { first, second }
44    }
45}
46
47impl<A: ToolExecutor, B: ToolExecutor> ToolExecutor for CompositeExecutor<A, B> {
48    async fn execute(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
49        if let Some(output) = self.first.execute(response).await? {
50            return Ok(Some(output));
51        }
52        self.second.execute(response).await
53    }
54
55    async fn execute_confirmed(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
56        if let Some(output) = self.first.execute_confirmed(response).await? {
57            return Ok(Some(output));
58        }
59        self.second.execute_confirmed(response).await
60    }
61
62    fn tool_definitions(&self) -> Vec<ToolDef> {
63        let mut defs = self.first.tool_definitions();
64        let seen: std::collections::HashSet<String> =
65            defs.iter().map(|d| d.id.to_string()).collect();
66        for def in self.second.tool_definitions() {
67            if !seen.contains(def.id.as_ref()) {
68                defs.push(def);
69            }
70        }
71        defs
72    }
73
74    async fn execute_tool_call(&self, call: &ToolCall) -> Result<Option<ToolOutput>, ToolError> {
75        if let Some(output) = self.first.execute_tool_call(call).await? {
76            return Ok(Some(output));
77        }
78        self.second.execute_tool_call(call).await
79    }
80
81    fn is_tool_retryable(&self, tool_id: &str) -> bool {
82        self.first.is_tool_retryable(tool_id) || self.second.is_tool_retryable(tool_id)
83    }
84}
85
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ToolName;

    /// Builds a successful [`ToolOutput`] tagged with `summary`, so assertions can tell
    /// which mock executor produced it.
    fn output(tool_name: ToolName, summary: &str) -> ToolOutput {
        ToolOutput {
            tool_name,
            summary: summary.to_owned(),
            blocks_executed: 1,
            filter_stats: None,
            diff: None,
            streamed: false,
            terminal_id: None,
            locations: None,
            raw_response: None,
            claim_source: None,
        }
    }

    /// Builds a parameterless [`ToolCall`] for `tool_id`.
    fn tool_call(tool_id: &'static str) -> ToolCall {
        ToolCall {
            tool_id: ToolName::new(tool_id),
            params: serde_json::Map::new(),
            caller_id: None,
        }
    }

    /// The error returned by every failing mock.
    fn blocked() -> ToolError {
        ToolError::Blocked {
            command: "test".to_owned(),
        }
    }

    /// Handles every `execute` call, tagging the output "matched".
    #[derive(Debug)]
    struct MatchingExecutor;
    impl ToolExecutor for MatchingExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(Some(output(ToolName::new("test"), "matched")))
        }
    }

    /// Declines every `execute` call.
    #[derive(Debug)]
    struct NoMatchExecutor;
    impl ToolExecutor for NoMatchExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }
    }

    /// Fails every `execute` and `execute_tool_call` with [`ToolError::Blocked`].
    #[derive(Debug)]
    struct ErrorExecutor;
    impl ToolExecutor for ErrorExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Err(blocked())
        }
        async fn execute_tool_call(
            &self,
            _call: &ToolCall,
        ) -> Result<Option<ToolOutput>, ToolError> {
            Err(blocked())
        }
    }

    /// Handles every `execute` call, tagging the output "second".
    #[derive(Debug)]
    struct SecondExecutor;
    impl ToolExecutor for SecondExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(Some(output(ToolName::new("test"), "second")))
        }
    }

    /// Handles only the `read` and `write` tool calls.
    #[derive(Debug)]
    struct FileToolExecutor;
    impl ToolExecutor for FileToolExecutor {
        async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }
        async fn execute_tool_call(
            &self,
            call: &ToolCall,
        ) -> Result<Option<ToolOutput>, ToolError> {
            if call.tool_id == "read" || call.tool_id == "write" {
                Ok(Some(output(call.tool_id.clone(), "file_handler")))
            } else {
                Ok(None)
            }
        }
    }

    /// Handles only the `bash` tool call.
    #[derive(Debug)]
    struct ShellToolExecutor;
    impl ToolExecutor for ShellToolExecutor {
        async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }
        async fn execute_tool_call(
            &self,
            call: &ToolCall,
        ) -> Result<Option<ToolOutput>, ToolError> {
            if call.tool_id == "bash" {
                Ok(Some(output(ToolName::new("bash"), "shell_handler")))
            } else {
                Ok(None)
            }
        }
    }

    /// Reports exactly one tool ID (the wrapped string) as retryable.
    #[derive(Debug)]
    struct RetryableExecutor(&'static str);
    impl ToolExecutor for RetryableExecutor {
        async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }
        fn is_tool_retryable(&self, tool_id: &str) -> bool {
            tool_id == self.0
        }
    }

    #[tokio::test]
    async fn first_matches_returns_first() {
        let composite = CompositeExecutor::new(MatchingExecutor, SecondExecutor);
        let result = composite.execute("anything").await.unwrap();
        assert_eq!(result.unwrap().summary, "matched");
    }

    #[tokio::test]
    async fn first_none_falls_through_to_second() {
        let composite = CompositeExecutor::new(NoMatchExecutor, SecondExecutor);
        let result = composite.execute("anything").await.unwrap();
        assert_eq!(result.unwrap().summary, "second");
    }

    #[tokio::test]
    async fn both_none_returns_none() {
        let composite = CompositeExecutor::new(NoMatchExecutor, NoMatchExecutor);
        let result = composite.execute("anything").await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn first_error_propagates_without_trying_second() {
        let composite = CompositeExecutor::new(ErrorExecutor, SecondExecutor);
        let result = composite.execute("anything").await;
        assert!(matches!(result, Err(ToolError::Blocked { .. })));
    }

    #[tokio::test]
    async fn second_error_propagates_when_first_none() {
        let composite = CompositeExecutor::new(NoMatchExecutor, ErrorExecutor);
        let result = composite.execute("anything").await;
        assert!(matches!(result, Err(ToolError::Blocked { .. })));
    }

    #[tokio::test]
    async fn execute_confirmed_first_matches() {
        let composite = CompositeExecutor::new(MatchingExecutor, SecondExecutor);
        let result = composite.execute_confirmed("anything").await.unwrap();
        assert_eq!(result.unwrap().summary, "matched");
    }

    #[tokio::test]
    async fn execute_confirmed_falls_through() {
        let composite = CompositeExecutor::new(NoMatchExecutor, SecondExecutor);
        let result = composite.execute_confirmed("anything").await.unwrap();
        assert_eq!(result.unwrap().summary, "second");
    }

    #[test]
    fn composite_debug() {
        let composite = CompositeExecutor::new(MatchingExecutor, SecondExecutor);
        let debug = format!("{composite:?}");
        assert!(debug.contains("CompositeExecutor"));
    }

    #[tokio::test]
    async fn tool_call_routes_to_file_executor() {
        let composite = CompositeExecutor::new(FileToolExecutor, ShellToolExecutor);
        let result = composite
            .execute_tool_call(&tool_call("read"))
            .await
            .unwrap()
            .unwrap();
        assert_eq!(result.summary, "file_handler");
    }

    #[tokio::test]
    async fn tool_call_routes_to_shell_executor() {
        let composite = CompositeExecutor::new(FileToolExecutor, ShellToolExecutor);
        let result = composite
            .execute_tool_call(&tool_call("bash"))
            .await
            .unwrap()
            .unwrap();
        assert_eq!(result.summary, "shell_handler");
    }

    #[tokio::test]
    async fn tool_call_unhandled_returns_none() {
        let composite = CompositeExecutor::new(FileToolExecutor, ShellToolExecutor);
        let result = composite
            .execute_tool_call(&tool_call("unknown"))
            .await
            .unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn tool_call_first_error_propagates_without_trying_second() {
        // `ShellToolExecutor` would handle "bash", so an `Err` here proves that the
        // error from `first` short-circuits dispatch.
        let composite = CompositeExecutor::new(ErrorExecutor, ShellToolExecutor);
        let result = composite.execute_tool_call(&tool_call("bash")).await;
        assert!(matches!(result, Err(ToolError::Blocked { .. })));
    }

    #[test]
    fn tool_is_retryable_if_either_executor_reports_it() {
        let composite =
            CompositeExecutor::new(RetryableExecutor("fetch"), RetryableExecutor("bash"));
        assert!(composite.is_tool_retryable("fetch"));
        assert!(composite.is_tool_retryable("bash"));
        assert!(!composite.is_tool_retryable("read"));
    }
}