Skip to main content

zeph_tools/
composite.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use crate::executor::{ToolCall, ToolError, ToolExecutor, ToolOutput};
5use crate::registry::ToolDef;
6
/// Chains two [`ToolExecutor`] implementations with first-match-wins dispatch.
///
/// Each dispatch method (`execute`, `execute_confirmed`, `execute_tool_call`)
/// tries `first` and falls through to `second` only if `first` returns `Ok(None)`.
/// Errors from `first` propagate immediately without trying `second`.
///
/// Tool definitions from both executors are merged, with `first`'s definitions
/// taking precedence when ids collide.
///
/// The struct itself places no trait bounds on `A`/`B`; the `ToolExecutor`
/// bounds live on the `impl` blocks that actually need them.
#[derive(Debug)]
pub struct CompositeExecutor<A, B> {
    /// Executor consulted first; its results win.
    first: A,
    /// Fallback executor, consulted only when `first` returns `Ok(None)`.
    second: B,
}
16
17impl<A: ToolExecutor, B: ToolExecutor> CompositeExecutor<A, B> {
18    #[must_use]
19    pub fn new(first: A, second: B) -> Self {
20        Self { first, second }
21    }
22}
23
24impl<A: ToolExecutor, B: ToolExecutor> ToolExecutor for CompositeExecutor<A, B> {
25    async fn execute(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
26        if let Some(output) = self.first.execute(response).await? {
27            return Ok(Some(output));
28        }
29        self.second.execute(response).await
30    }
31
32    async fn execute_confirmed(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
33        if let Some(output) = self.first.execute_confirmed(response).await? {
34            return Ok(Some(output));
35        }
36        self.second.execute_confirmed(response).await
37    }
38
39    fn tool_definitions(&self) -> Vec<ToolDef> {
40        let mut defs = self.first.tool_definitions();
41        let seen: std::collections::HashSet<String> =
42            defs.iter().map(|d| d.id.to_string()).collect();
43        for def in self.second.tool_definitions() {
44            if !seen.contains(def.id.as_ref()) {
45                defs.push(def);
46            }
47        }
48        defs
49    }
50
51    async fn execute_tool_call(&self, call: &ToolCall) -> Result<Option<ToolOutput>, ToolError> {
52        if let Some(output) = self.first.execute_tool_call(call).await? {
53            return Ok(Some(output));
54        }
55        self.second.execute_tool_call(call).await
56    }
57}
58
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a successful `ToolOutput` with the given tool name and summary;
    /// every other field uses the neutral values shared by all test doubles.
    fn output(tool_name: &str, summary: &str) -> ToolOutput {
        ToolOutput {
            tool_name: tool_name.to_owned(),
            summary: summary.to_owned(),
            blocks_executed: 1,
            filter_stats: None,
            diff: None,
            streamed: false,
            terminal_id: None,
            locations: None,
        }
    }

    /// Builds a `ToolCall` for `tool_id` with empty parameters.
    fn tool_call(tool_id: &str) -> ToolCall {
        ToolCall {
            tool_id: tool_id.to_owned(),
            params: serde_json::Map::new(),
        }
    }

    /// Always claims the response, reporting summary `"matched"`.
    #[derive(Debug)]
    struct MatchingExecutor;
    impl ToolExecutor for MatchingExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(Some(output("test", "matched")))
        }
    }

    /// Never claims the response.
    #[derive(Debug)]
    struct NoMatchExecutor;
    impl ToolExecutor for NoMatchExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }
    }

    /// Always fails with `ToolError::Blocked`.
    #[derive(Debug)]
    struct ErrorExecutor;
    impl ToolExecutor for ErrorExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Err(ToolError::Blocked {
                command: "test".to_owned(),
            })
        }
    }

    /// Always claims the response, reporting summary `"second"`.
    #[derive(Debug)]
    struct SecondExecutor;
    impl ToolExecutor for SecondExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(Some(output("test", "second")))
        }
    }

    /// Panics if any dispatch method is invoked. Used as `second` to prove the
    /// composite short-circuits instead of merely ignoring the fallback's result.
    #[derive(Debug)]
    struct UnreachableExecutor;
    impl ToolExecutor for UnreachableExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            panic!("second executor must not be called (execute)");
        }
        async fn execute_confirmed(
            &self,
            _response: &str,
        ) -> Result<Option<ToolOutput>, ToolError> {
            panic!("second executor must not be called (execute_confirmed)");
        }
        async fn execute_tool_call(
            &self,
            _call: &ToolCall,
        ) -> Result<Option<ToolOutput>, ToolError> {
            panic!("second executor must not be called (execute_tool_call)");
        }
    }

    #[tokio::test]
    async fn first_matches_returns_first() {
        let composite = CompositeExecutor::new(MatchingExecutor, SecondExecutor);
        let result = composite.execute("anything").await.unwrap();
        assert_eq!(result.unwrap().summary, "matched");
    }

    #[tokio::test]
    async fn first_match_does_not_call_second() {
        let composite = CompositeExecutor::new(MatchingExecutor, UnreachableExecutor);
        let result = composite.execute("anything").await.unwrap();
        assert_eq!(result.unwrap().summary, "matched");
    }

    #[tokio::test]
    async fn first_none_falls_through_to_second() {
        let composite = CompositeExecutor::new(NoMatchExecutor, SecondExecutor);
        let result = composite.execute("anything").await.unwrap();
        assert_eq!(result.unwrap().summary, "second");
    }

    #[tokio::test]
    async fn both_none_returns_none() {
        let composite = CompositeExecutor::new(NoMatchExecutor, NoMatchExecutor);
        let result = composite.execute("anything").await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn first_error_propagates_without_trying_second() {
        // `UnreachableExecutor` panics if consulted, so this proves `second` is never tried.
        let composite = CompositeExecutor::new(ErrorExecutor, UnreachableExecutor);
        let result = composite.execute("anything").await;
        assert!(matches!(result, Err(ToolError::Blocked { .. })));
    }

    #[tokio::test]
    async fn second_error_propagates_when_first_none() {
        let composite = CompositeExecutor::new(NoMatchExecutor, ErrorExecutor);
        let result = composite.execute("anything").await;
        assert!(matches!(result, Err(ToolError::Blocked { .. })));
    }

    #[tokio::test]
    async fn execute_confirmed_first_matches() {
        let composite = CompositeExecutor::new(MatchingExecutor, SecondExecutor);
        let result = composite.execute_confirmed("anything").await.unwrap();
        assert_eq!(result.unwrap().summary, "matched");
    }

    #[tokio::test]
    async fn execute_confirmed_first_match_does_not_call_second() {
        let composite = CompositeExecutor::new(MatchingExecutor, UnreachableExecutor);
        let result = composite.execute_confirmed("anything").await.unwrap();
        assert_eq!(result.unwrap().summary, "matched");
    }

    #[tokio::test]
    async fn execute_confirmed_falls_through() {
        let composite = CompositeExecutor::new(NoMatchExecutor, SecondExecutor);
        let result = composite.execute_confirmed("anything").await.unwrap();
        assert_eq!(result.unwrap().summary, "second");
    }

    #[test]
    fn composite_debug() {
        let composite = CompositeExecutor::new(MatchingExecutor, SecondExecutor);
        let debug = format!("{composite:?}");
        assert!(debug.contains("CompositeExecutor"));
    }

    /// Handles structured calls for `read` and `write`; declines everything else.
    #[derive(Debug)]
    struct FileToolExecutor;
    impl ToolExecutor for FileToolExecutor {
        async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }
        async fn execute_tool_call(
            &self,
            call: &ToolCall,
        ) -> Result<Option<ToolOutput>, ToolError> {
            if call.tool_id == "read" || call.tool_id == "write" {
                Ok(Some(output(&call.tool_id, "file_handler")))
            } else {
                Ok(None)
            }
        }
    }

    /// Handles structured calls for `bash`; declines everything else.
    #[derive(Debug)]
    struct ShellToolExecutor;
    impl ToolExecutor for ShellToolExecutor {
        async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }
        async fn execute_tool_call(
            &self,
            call: &ToolCall,
        ) -> Result<Option<ToolOutput>, ToolError> {
            if call.tool_id == "bash" {
                Ok(Some(output("bash", "shell_handler")))
            } else {
                Ok(None)
            }
        }
    }

    #[tokio::test]
    async fn tool_call_routes_to_file_executor() {
        let composite = CompositeExecutor::new(FileToolExecutor, ShellToolExecutor);
        let call = tool_call("read");
        let result = composite.execute_tool_call(&call).await.unwrap().unwrap();
        assert_eq!(result.summary, "file_handler");
    }

    #[tokio::test]
    async fn tool_call_routes_to_shell_executor() {
        let composite = CompositeExecutor::new(FileToolExecutor, ShellToolExecutor);
        let call = tool_call("bash");
        let result = composite.execute_tool_call(&call).await.unwrap().unwrap();
        assert_eq!(result.summary, "shell_handler");
    }

    #[tokio::test]
    async fn tool_call_first_match_does_not_call_second() {
        let composite = CompositeExecutor::new(FileToolExecutor, UnreachableExecutor);
        let call = tool_call("write");
        let result = composite.execute_tool_call(&call).await.unwrap().unwrap();
        assert_eq!(result.summary, "file_handler");
        assert_eq!(result.tool_name, "write");
    }

    #[tokio::test]
    async fn tool_call_unhandled_returns_none() {
        let composite = CompositeExecutor::new(FileToolExecutor, ShellToolExecutor);
        let call = tool_call("unknown");
        let result = composite.execute_tool_call(&call).await.unwrap();
        assert!(result.is_none());
    }
}