Skip to main content

zeph_tools/
composite.rs

// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
// SPDX-License-Identifier: MIT OR Apache-2.0
3
use std::collections::HashSet;

use crate::executor::{ToolCall, ToolError, ToolExecutor, ToolOutput};
use crate::registry::ToolDef;
6
7/// Chains two `ToolExecutor` implementations with first-match-wins dispatch.
8///
9/// Tries `first`, falls through to `second` if it returns `Ok(None)`.
10/// Errors from `first` propagate immediately without trying `second`.
11#[derive(Debug)]
12pub struct CompositeExecutor<A: ToolExecutor, B: ToolExecutor> {
13    first: A,
14    second: B,
15}
16
17impl<A: ToolExecutor, B: ToolExecutor> CompositeExecutor<A, B> {
18    #[must_use]
19    pub fn new(first: A, second: B) -> Self {
20        Self { first, second }
21    }
22}
23
24impl<A: ToolExecutor, B: ToolExecutor> ToolExecutor for CompositeExecutor<A, B> {
25    async fn execute(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
26        if let Some(output) = self.first.execute(response).await? {
27            return Ok(Some(output));
28        }
29        self.second.execute(response).await
30    }
31
32    async fn execute_confirmed(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
33        if let Some(output) = self.first.execute_confirmed(response).await? {
34            return Ok(Some(output));
35        }
36        self.second.execute_confirmed(response).await
37    }
38
39    fn tool_definitions(&self) -> Vec<ToolDef> {
40        let mut defs = self.first.tool_definitions();
41        let seen: std::collections::HashSet<String> =
42            defs.iter().map(|d| d.id.to_string()).collect();
43        for def in self.second.tool_definitions() {
44            if !seen.contains(def.id.as_ref()) {
45                defs.push(def);
46            }
47        }
48        defs
49    }
50
51    async fn execute_tool_call(&self, call: &ToolCall) -> Result<Option<ToolOutput>, ToolError> {
52        if let Some(output) = self.first.execute_tool_call(call).await? {
53            return Ok(Some(output));
54        }
55        self.second.execute_tool_call(call).await
56    }
57
58    fn is_tool_retryable(&self, tool_id: &str) -> bool {
59        self.first.is_tool_retryable(tool_id) || self.second.is_tool_retryable(tool_id)
60    }
61}
62
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a one-block `ToolOutput` with the given tool name and summary;
    /// every optional field is left as `None` and `streamed` is `false`.
    fn output(tool_name: &str, summary: &str) -> ToolOutput {
        ToolOutput {
            tool_name: tool_name.to_owned(),
            summary: summary.to_owned(),
            blocks_executed: 1,
            filter_stats: None,
            diff: None,
            streamed: false,
            terminal_id: None,
            locations: None,
            raw_response: None,
            claim_source: None,
        }
    }

    /// Builds a `ToolCall` for `tool_id` with an empty parameter map.
    fn call_for(tool_id: &str) -> ToolCall {
        ToolCall {
            tool_id: tool_id.to_owned(),
            params: serde_json::Map::new(),
        }
    }

    /// Always matches, reporting the summary `"matched"`.
    #[derive(Debug)]
    struct MatchingExecutor;
    impl ToolExecutor for MatchingExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(Some(output("test", "matched")))
        }
    }

    /// Never matches: `execute` always returns `Ok(None)`.
    #[derive(Debug)]
    struct NoMatchExecutor;
    impl ToolExecutor for NoMatchExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }
    }

    /// Always fails with `ToolError::Blocked`.
    #[derive(Debug)]
    struct ErrorExecutor;
    impl ToolExecutor for ErrorExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Err(ToolError::Blocked {
                command: "test".to_owned(),
            })
        }
    }

    /// Always matches, reporting the summary `"second"` so tests can tell it
    /// apart from `MatchingExecutor`.
    #[derive(Debug)]
    struct SecondExecutor;
    impl ToolExecutor for SecondExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(Some(output("test", "second")))
        }
    }

    #[tokio::test]
    async fn first_matches_returns_first() {
        let composite = CompositeExecutor::new(MatchingExecutor, SecondExecutor);
        let result = composite.execute("anything").await.unwrap();
        assert_eq!(result.unwrap().summary, "matched");
    }

    #[tokio::test]
    async fn first_none_falls_through_to_second() {
        let composite = CompositeExecutor::new(NoMatchExecutor, SecondExecutor);
        let result = composite.execute("anything").await.unwrap();
        assert_eq!(result.unwrap().summary, "second");
    }

    #[tokio::test]
    async fn both_none_returns_none() {
        let composite = CompositeExecutor::new(NoMatchExecutor, NoMatchExecutor);
        let result = composite.execute("anything").await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn first_error_propagates_without_trying_second() {
        let composite = CompositeExecutor::new(ErrorExecutor, SecondExecutor);
        let result = composite.execute("anything").await;
        assert!(matches!(result, Err(ToolError::Blocked { .. })));
    }

    #[tokio::test]
    async fn second_error_propagates_when_first_none() {
        let composite = CompositeExecutor::new(NoMatchExecutor, ErrorExecutor);
        let result = composite.execute("anything").await;
        assert!(matches!(result, Err(ToolError::Blocked { .. })));
    }

    #[tokio::test]
    async fn execute_confirmed_first_matches() {
        let composite = CompositeExecutor::new(MatchingExecutor, SecondExecutor);
        let result = composite.execute_confirmed("anything").await.unwrap();
        assert_eq!(result.unwrap().summary, "matched");
    }

    #[tokio::test]
    async fn execute_confirmed_falls_through() {
        let composite = CompositeExecutor::new(NoMatchExecutor, SecondExecutor);
        let result = composite.execute_confirmed("anything").await.unwrap();
        assert_eq!(result.unwrap().summary, "second");
    }

    #[test]
    fn composite_debug() {
        let composite = CompositeExecutor::new(MatchingExecutor, SecondExecutor);
        let debug = format!("{composite:?}");
        assert!(debug.contains("CompositeExecutor"));
    }

    /// Handles structured calls for the `read` and `write` tool ids only;
    /// text-based `execute` never matches.
    #[derive(Debug)]
    struct FileToolExecutor;
    impl ToolExecutor for FileToolExecutor {
        async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }
        async fn execute_tool_call(
            &self,
            call: &ToolCall,
        ) -> Result<Option<ToolOutput>, ToolError> {
            if call.tool_id == "read" || call.tool_id == "write" {
                Ok(Some(output(&call.tool_id, "file_handler")))
            } else {
                Ok(None)
            }
        }
    }

    /// Handles structured calls for the `bash` tool id only; text-based
    /// `execute` never matches.
    #[derive(Debug)]
    struct ShellToolExecutor;
    impl ToolExecutor for ShellToolExecutor {
        async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }
        async fn execute_tool_call(
            &self,
            call: &ToolCall,
        ) -> Result<Option<ToolOutput>, ToolError> {
            if call.tool_id == "bash" {
                Ok(Some(output("bash", "shell_handler")))
            } else {
                Ok(None)
            }
        }
    }

    #[tokio::test]
    async fn tool_call_routes_to_file_executor() {
        let composite = CompositeExecutor::new(FileToolExecutor, ShellToolExecutor);
        let call = call_for("read");
        let result = composite.execute_tool_call(&call).await.unwrap().unwrap();
        assert_eq!(result.summary, "file_handler");
    }

    #[tokio::test]
    async fn tool_call_routes_to_shell_executor() {
        let composite = CompositeExecutor::new(FileToolExecutor, ShellToolExecutor);
        let call = call_for("bash");
        let result = composite.execute_tool_call(&call).await.unwrap().unwrap();
        assert_eq!(result.summary, "shell_handler");
    }

    #[tokio::test]
    async fn tool_call_unhandled_returns_none() {
        let composite = CompositeExecutor::new(FileToolExecutor, ShellToolExecutor);
        let call = call_for("unknown");
        let result = composite.execute_tool_call(&call).await.unwrap();
        assert!(result.is_none());
    }
}