Skip to main content

zeph_tools/
composite.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use crate::executor::{ToolCall, ToolError, ToolExecutor, ToolOutput};
5use crate::registry::ToolDef;
6
/// Chains two [`ToolExecutor`] implementations with first-match-wins dispatch.
///
/// Every dispatch method tries `first`; if it returns `Ok(None)` (the input was
/// not handled), the same input is forwarded to `second`. Errors from `first`
/// propagate immediately without trying `second`.
///
/// Because `CompositeExecutor` itself implements `ToolExecutor`, composites can
/// be nested (e.g. as `second`) to build a longer priority-ordered chain.
///
/// The struct definition intentionally carries no trait bounds: they live on the
/// `impl` blocks that need them, so the type can be named in generic code
/// without repeating `ToolExecutor` bounds.
#[derive(Debug)]
pub struct CompositeExecutor<A, B> {
    /// Executor consulted first; its result wins whenever it handles the input.
    first: A,
    /// Fallback executor, consulted only when `first` returns `Ok(None)`.
    second: B,
}
16
17impl<A: ToolExecutor, B: ToolExecutor> CompositeExecutor<A, B> {
18    #[must_use]
19    pub fn new(first: A, second: B) -> Self {
20        Self { first, second }
21    }
22}
23
24impl<A: ToolExecutor, B: ToolExecutor> ToolExecutor for CompositeExecutor<A, B> {
25    async fn execute(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
26        if let Some(output) = self.first.execute(response).await? {
27            return Ok(Some(output));
28        }
29        self.second.execute(response).await
30    }
31
32    async fn execute_confirmed(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
33        if let Some(output) = self.first.execute_confirmed(response).await? {
34            return Ok(Some(output));
35        }
36        self.second.execute_confirmed(response).await
37    }
38
39    fn tool_definitions(&self) -> Vec<ToolDef> {
40        let mut defs = self.first.tool_definitions();
41        let seen: std::collections::HashSet<String> =
42            defs.iter().map(|d| d.id.to_string()).collect();
43        for def in self.second.tool_definitions() {
44            if !seen.contains(def.id.as_ref()) {
45                defs.push(def);
46            }
47        }
48        defs
49    }
50
51    async fn execute_tool_call(&self, call: &ToolCall) -> Result<Option<ToolOutput>, ToolError> {
52        if let Some(output) = self.first.execute_tool_call(call).await? {
53            return Ok(Some(output));
54        }
55        self.second.execute_tool_call(call).await
56    }
57}
58
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a single-block `ToolOutput` with the given name and summary; the
    /// remaining fields take the neutral values every stub below uses.
    fn output(tool_name: &str, summary: &str) -> ToolOutput {
        ToolOutput {
            tool_name: tool_name.to_owned(),
            summary: summary.to_owned(),
            blocks_executed: 1,
            filter_stats: None,
            diff: None,
            streamed: false,
            terminal_id: None,
        }
    }

    /// Builds a parameterless `ToolCall` for `tool_id`.
    fn tool_call(tool_id: &str) -> ToolCall {
        ToolCall {
            tool_id: tool_id.to_owned(),
            params: serde_json::Map::new(),
        }
    }

    /// Handles every response, reporting summary `"matched"`.
    #[derive(Debug)]
    struct MatchingExecutor;
    impl ToolExecutor for MatchingExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(Some(output("test", "matched")))
        }
    }

    /// Declines every response.
    #[derive(Debug)]
    struct NoMatchExecutor;
    impl ToolExecutor for NoMatchExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }
    }

    /// Fails every response with `ToolError::Blocked`.
    #[derive(Debug)]
    struct ErrorExecutor;
    impl ToolExecutor for ErrorExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Err(ToolError::Blocked {
                command: "test".to_owned(),
            })
        }
    }

    /// Handles every response, reporting summary `"second"`.
    #[derive(Debug)]
    struct SecondExecutor;
    impl ToolExecutor for SecondExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(Some(output("test", "second")))
        }
    }

    #[tokio::test]
    async fn first_matches_returns_first() {
        let composite = CompositeExecutor::new(MatchingExecutor, SecondExecutor);
        let result = composite.execute("anything").await.unwrap();
        assert_eq!(result.unwrap().summary, "matched");
    }

    #[tokio::test]
    async fn first_none_falls_through_to_second() {
        let composite = CompositeExecutor::new(NoMatchExecutor, SecondExecutor);
        let result = composite.execute("anything").await.unwrap();
        assert_eq!(result.unwrap().summary, "second");
    }

    #[tokio::test]
    async fn both_none_returns_none() {
        let composite = CompositeExecutor::new(NoMatchExecutor, NoMatchExecutor);
        let result = composite.execute("anything").await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn first_error_propagates_without_trying_second() {
        let composite = CompositeExecutor::new(ErrorExecutor, SecondExecutor);
        let result = composite.execute("anything").await;
        assert!(matches!(result, Err(ToolError::Blocked { .. })));
    }

    #[tokio::test]
    async fn second_error_propagates_when_first_none() {
        let composite = CompositeExecutor::new(NoMatchExecutor, ErrorExecutor);
        let result = composite.execute("anything").await;
        assert!(matches!(result, Err(ToolError::Blocked { .. })));
    }

    #[tokio::test]
    async fn execute_confirmed_first_matches() {
        let composite = CompositeExecutor::new(MatchingExecutor, SecondExecutor);
        let result = composite.execute_confirmed("anything").await.unwrap();
        assert_eq!(result.unwrap().summary, "matched");
    }

    #[tokio::test]
    async fn execute_confirmed_falls_through() {
        let composite = CompositeExecutor::new(NoMatchExecutor, SecondExecutor);
        let result = composite.execute_confirmed("anything").await.unwrap();
        assert_eq!(result.unwrap().summary, "second");
    }

    #[tokio::test]
    async fn execute_confirmed_both_none_returns_none() {
        let composite = CompositeExecutor::new(NoMatchExecutor, NoMatchExecutor);
        let result = composite.execute_confirmed("anything").await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn execute_confirmed_first_error_propagates() {
        let composite = CompositeExecutor::new(ErrorExecutor, SecondExecutor);
        let result = composite.execute_confirmed("anything").await;
        assert!(matches!(result, Err(ToolError::Blocked { .. })));
    }

    #[test]
    fn composite_debug() {
        let composite = CompositeExecutor::new(MatchingExecutor, SecondExecutor);
        let debug = format!("{composite:?}");
        assert!(debug.contains("CompositeExecutor"));
    }

    /// Handles the `read` and `write` tool calls; declines everything else.
    #[derive(Debug)]
    struct FileToolExecutor;
    impl ToolExecutor for FileToolExecutor {
        async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }
        async fn execute_tool_call(
            &self,
            call: &ToolCall,
        ) -> Result<Option<ToolOutput>, ToolError> {
            if call.tool_id == "read" || call.tool_id == "write" {
                Ok(Some(output(&call.tool_id, "file_handler")))
            } else {
                Ok(None)
            }
        }
    }

    /// Handles the `bash` tool call; declines everything else.
    #[derive(Debug)]
    struct ShellToolExecutor;
    impl ToolExecutor for ShellToolExecutor {
        async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }
        async fn execute_tool_call(
            &self,
            call: &ToolCall,
        ) -> Result<Option<ToolOutput>, ToolError> {
            if call.tool_id == "bash" {
                Ok(Some(output("bash", "shell_handler")))
            } else {
                Ok(None)
            }
        }
    }

    /// Rejects every structured tool call with `ToolError::Blocked`.
    #[derive(Debug)]
    struct ErrorToolExecutor;
    impl ToolExecutor for ErrorToolExecutor {
        async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }
        async fn execute_tool_call(
            &self,
            call: &ToolCall,
        ) -> Result<Option<ToolOutput>, ToolError> {
            Err(ToolError::Blocked {
                command: call.tool_id.clone(),
            })
        }
    }

    #[tokio::test]
    async fn tool_call_routes_to_file_executor() {
        let composite = CompositeExecutor::new(FileToolExecutor, ShellToolExecutor);
        let result = composite
            .execute_tool_call(&tool_call("read"))
            .await
            .unwrap()
            .unwrap();
        assert_eq!(result.summary, "file_handler");
    }

    #[tokio::test]
    async fn tool_call_routes_to_shell_executor() {
        let composite = CompositeExecutor::new(FileToolExecutor, ShellToolExecutor);
        let result = composite
            .execute_tool_call(&tool_call("bash"))
            .await
            .unwrap()
            .unwrap();
        assert_eq!(result.summary, "shell_handler");
    }

    #[tokio::test]
    async fn tool_call_unhandled_returns_none() {
        let composite = CompositeExecutor::new(FileToolExecutor, ShellToolExecutor);
        let result = composite
            .execute_tool_call(&tool_call("unknown"))
            .await
            .unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn tool_call_first_error_propagates_without_trying_second() {
        // `ShellToolExecutor` would handle "bash"; getting an error proves the
        // failure from `first` short-circuited the fallback.
        let composite = CompositeExecutor::new(ErrorToolExecutor, ShellToolExecutor);
        let result = composite.execute_tool_call(&tool_call("bash")).await;
        assert!(matches!(result, Err(ToolError::Blocked { .. })));
    }

    #[tokio::test]
    async fn tool_call_second_error_propagates_when_first_declines() {
        let composite = CompositeExecutor::new(FileToolExecutor, ErrorToolExecutor);
        let result = composite.execute_tool_call(&tool_call("bash")).await;
        assert!(matches!(result, Err(ToolError::Blocked { .. })));
    }
}