Skip to main content

zeph_tools/
composite.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use crate::executor::{ToolCall, ToolError, ToolExecutor, ToolOutput};
5use crate::registry::ToolDef;
6
/// Chains two `ToolExecutor` implementations with first-match-wins dispatch.
///
/// `execute`, `execute_confirmed` and `execute_tool_call` try `first`. They fall
/// through to `second` only if `first` returns `Ok(None)`. Errors from `first`
/// propagate immediately without trying `second`.
///
/// `tool_definitions` returns `first`'s definitions, followed by those of `second`
/// whose id `first` does not already define. `is_tool_retryable` reports a tool as
/// retryable when either executor does.
///
/// The type parameters are intentionally unbounded on the struct itself. The
/// `ToolExecutor` requirement lives on the `impl` blocks that actually need it,
/// so merely naming or storing a `CompositeExecutor<A, B>` imposes no bounds.
#[derive(Debug)]
pub struct CompositeExecutor<A, B> {
    /// Executor consulted first; its output or error takes precedence.
    first: A,
    /// Fallback executor, consulted only when `first` declines with `Ok(None)`.
    second: B,
}
16
17impl<A: ToolExecutor, B: ToolExecutor> CompositeExecutor<A, B> {
18    #[must_use]
19    pub fn new(first: A, second: B) -> Self {
20        Self { first, second }
21    }
22}
23
24impl<A: ToolExecutor, B: ToolExecutor> ToolExecutor for CompositeExecutor<A, B> {
25    async fn execute(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
26        if let Some(output) = self.first.execute(response).await? {
27            return Ok(Some(output));
28        }
29        self.second.execute(response).await
30    }
31
32    async fn execute_confirmed(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
33        if let Some(output) = self.first.execute_confirmed(response).await? {
34            return Ok(Some(output));
35        }
36        self.second.execute_confirmed(response).await
37    }
38
39    fn tool_definitions(&self) -> Vec<ToolDef> {
40        let mut defs = self.first.tool_definitions();
41        let seen: std::collections::HashSet<String> =
42            defs.iter().map(|d| d.id.to_string()).collect();
43        for def in self.second.tool_definitions() {
44            if !seen.contains(def.id.as_ref()) {
45                defs.push(def);
46            }
47        }
48        defs
49    }
50
51    async fn execute_tool_call(&self, call: &ToolCall) -> Result<Option<ToolOutput>, ToolError> {
52        if let Some(output) = self.first.execute_tool_call(call).await? {
53            return Ok(Some(output));
54        }
55        self.second.execute_tool_call(call).await
56    }
57
58    fn is_tool_retryable(&self, tool_id: &str) -> bool {
59        self.first.is_tool_retryable(tool_id) || self.second.is_tool_retryable(tool_id)
60    }
61}
62
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a single-block `ToolOutput` with every optional field left empty.
    /// The test doubles below then only spell out what the assertions inspect.
    fn output(tool_name: &str, summary: &str) -> ToolOutput {
        ToolOutput {
            tool_name: tool_name.to_owned(),
            summary: summary.to_owned(),
            blocks_executed: 1,
            filter_stats: None,
            diff: None,
            streamed: false,
            terminal_id: None,
            locations: None,
            raw_response: None,
        }
    }

    /// Builds a parameterless `ToolCall` for `tool_id`.
    fn tool_call(tool_id: &str) -> ToolCall {
        ToolCall {
            tool_id: tool_id.to_owned(),
            params: serde_json::Map::new(),
        }
    }

    /// Claims every text response, summarising it as "matched".
    #[derive(Debug)]
    struct MatchingExecutor;
    impl ToolExecutor for MatchingExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(Some(output("test", "matched")))
        }
    }

    /// Declines every text response.
    #[derive(Debug)]
    struct NoMatchExecutor;
    impl ToolExecutor for NoMatchExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }
    }

    /// Fails both text responses and structured tool calls with `ToolError::Blocked`.
    #[derive(Debug)]
    struct ErrorExecutor;
    impl ToolExecutor for ErrorExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Err(ToolError::Blocked {
                command: "test".to_owned(),
            })
        }

        async fn execute_tool_call(
            &self,
            _call: &ToolCall,
        ) -> Result<Option<ToolOutput>, ToolError> {
            Err(ToolError::Blocked {
                command: "test".to_owned(),
            })
        }
    }

    /// Claims every text response, summarising it as "second". Used as the
    /// fallback so tests can tell which executor answered.
    #[derive(Debug)]
    struct SecondExecutor;
    impl ToolExecutor for SecondExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(Some(output("test", "second")))
        }
    }

    #[tokio::test]
    async fn first_matches_returns_first() {
        let composite = CompositeExecutor::new(MatchingExecutor, SecondExecutor);
        let result = composite.execute("anything").await.unwrap();
        assert_eq!(result.unwrap().summary, "matched");
    }

    #[tokio::test]
    async fn first_none_falls_through_to_second() {
        let composite = CompositeExecutor::new(NoMatchExecutor, SecondExecutor);
        let result = composite.execute("anything").await.unwrap();
        assert_eq!(result.unwrap().summary, "second");
    }

    #[tokio::test]
    async fn both_none_returns_none() {
        let composite = CompositeExecutor::new(NoMatchExecutor, NoMatchExecutor);
        let result = composite.execute("anything").await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn first_error_propagates_without_trying_second() {
        let composite = CompositeExecutor::new(ErrorExecutor, SecondExecutor);
        let result = composite.execute("anything").await;
        assert!(matches!(result, Err(ToolError::Blocked { .. })));
    }

    #[tokio::test]
    async fn second_error_propagates_when_first_none() {
        let composite = CompositeExecutor::new(NoMatchExecutor, ErrorExecutor);
        let result = composite.execute("anything").await;
        assert!(matches!(result, Err(ToolError::Blocked { .. })));
    }

    #[tokio::test]
    async fn execute_confirmed_first_matches() {
        let composite = CompositeExecutor::new(MatchingExecutor, SecondExecutor);
        let result = composite.execute_confirmed("anything").await.unwrap();
        assert_eq!(result.unwrap().summary, "matched");
    }

    #[tokio::test]
    async fn execute_confirmed_falls_through() {
        let composite = CompositeExecutor::new(NoMatchExecutor, SecondExecutor);
        let result = composite.execute_confirmed("anything").await.unwrap();
        assert_eq!(result.unwrap().summary, "second");
    }

    #[test]
    fn composite_debug() {
        let composite = CompositeExecutor::new(MatchingExecutor, SecondExecutor);
        let debug = format!("{composite:?}");
        assert!(debug.contains("CompositeExecutor"));
    }

    /// Handles only the "read" and "write" structured tool calls.
    #[derive(Debug)]
    struct FileToolExecutor;
    impl ToolExecutor for FileToolExecutor {
        async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }

        async fn execute_tool_call(
            &self,
            call: &ToolCall,
        ) -> Result<Option<ToolOutput>, ToolError> {
            if matches!(call.tool_id.as_str(), "read" | "write") {
                Ok(Some(output(&call.tool_id, "file_handler")))
            } else {
                Ok(None)
            }
        }
    }

    /// Handles only the "bash" structured tool call.
    #[derive(Debug)]
    struct ShellToolExecutor;
    impl ToolExecutor for ShellToolExecutor {
        async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }

        async fn execute_tool_call(
            &self,
            call: &ToolCall,
        ) -> Result<Option<ToolOutput>, ToolError> {
            if call.tool_id == "bash" {
                Ok(Some(output("bash", "shell_handler")))
            } else {
                Ok(None)
            }
        }
    }

    /// Declines all work but reports exactly one tool id as retryable.
    #[derive(Debug)]
    struct RetryableExecutor(&'static str);
    impl ToolExecutor for RetryableExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }

        fn is_tool_retryable(&self, tool_id: &str) -> bool {
            tool_id == self.0
        }
    }

    #[tokio::test]
    async fn tool_call_routes_to_file_executor() {
        let composite = CompositeExecutor::new(FileToolExecutor, ShellToolExecutor);
        let result = composite
            .execute_tool_call(&tool_call("read"))
            .await
            .unwrap()
            .unwrap();
        assert_eq!(result.summary, "file_handler");
    }

    #[tokio::test]
    async fn tool_call_routes_to_shell_executor() {
        let composite = CompositeExecutor::new(FileToolExecutor, ShellToolExecutor);
        let result = composite
            .execute_tool_call(&tool_call("bash"))
            .await
            .unwrap()
            .unwrap();
        assert_eq!(result.summary, "shell_handler");
    }

    #[tokio::test]
    async fn tool_call_unhandled_returns_none() {
        let composite = CompositeExecutor::new(FileToolExecutor, ShellToolExecutor);
        let result = composite
            .execute_tool_call(&tool_call("unknown"))
            .await
            .unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn tool_call_first_error_propagates_without_trying_second() {
        // `ShellToolExecutor` could answer "bash". Getting `Blocked` back shows
        // that the error from `first` wins over that fallback.
        let composite = CompositeExecutor::new(ErrorExecutor, ShellToolExecutor);
        let result = composite.execute_tool_call(&tool_call("bash")).await;
        assert!(matches!(result, Err(ToolError::Blocked { .. })));
    }

    #[test]
    fn retryable_if_either_executor_reports_it() {
        let composite =
            CompositeExecutor::new(RetryableExecutor("fetch"), RetryableExecutor("bash"));
        assert!(composite.is_tool_retryable("fetch"));
        assert!(composite.is_tool_retryable("bash"));
        assert!(!composite.is_tool_retryable("read"));
    }
}