Skip to main content

zeph_tools/
composite.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use crate::executor::{ToolCall, ToolError, ToolExecutor, ToolOutput};
5use crate::registry::ToolDef;
6
7/// Chains two `ToolExecutor` implementations with first-match-wins dispatch.
8///
9/// Tries `first`, falls through to `second` if it returns `Ok(None)`.
10/// Errors from `first` propagate immediately without trying `second`.
11#[derive(Debug)]
12pub struct CompositeExecutor<A: ToolExecutor, B: ToolExecutor> {
13    first: A,
14    second: B,
15}
16
17impl<A: ToolExecutor, B: ToolExecutor> CompositeExecutor<A, B> {
18    #[must_use]
19    pub fn new(first: A, second: B) -> Self {
20        Self { first, second }
21    }
22}
23
24impl<A: ToolExecutor, B: ToolExecutor> ToolExecutor for CompositeExecutor<A, B> {
25    async fn execute(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
26        if let Some(output) = self.first.execute(response).await? {
27            return Ok(Some(output));
28        }
29        self.second.execute(response).await
30    }
31
32    async fn execute_confirmed(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
33        if let Some(output) = self.first.execute_confirmed(response).await? {
34            return Ok(Some(output));
35        }
36        self.second.execute_confirmed(response).await
37    }
38
39    fn tool_definitions(&self) -> Vec<ToolDef> {
40        let mut defs = self.first.tool_definitions();
41        let seen: std::collections::HashSet<&str> = defs.iter().map(|d| d.id).collect();
42        for def in self.second.tool_definitions() {
43            if !seen.contains(def.id) {
44                defs.push(def);
45            }
46        }
47        defs
48    }
49
50    async fn execute_tool_call(&self, call: &ToolCall) -> Result<Option<ToolOutput>, ToolError> {
51        if let Some(output) = self.first.execute_tool_call(call).await? {
52            return Ok(Some(output));
53        }
54        self.second.execute_tool_call(call).await
55    }
56}
57
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a successful single-block `ToolOutput` with the given name and summary.
    fn output(tool_name: &str, summary: &str) -> ToolOutput {
        ToolOutput {
            tool_name: tool_name.to_owned(),
            summary: summary.to_owned(),
            blocks_executed: 1,
            filter_stats: None,
            diff: None,
            streamed: false,
        }
    }

    /// Builds a structured `ToolCall` for `tool_id` with an empty parameter map.
    fn call(tool_id: &str) -> ToolCall {
        ToolCall {
            tool_id: tool_id.to_owned(),
            params: serde_json::Map::new(),
        }
    }

    /// The error returned by every failing test executor.
    fn blocked() -> ToolError {
        ToolError::Blocked {
            command: "test".to_owned(),
        }
    }

    /// Handles every text response, reporting summary "matched".
    #[derive(Debug)]
    struct MatchingExecutor;
    impl ToolExecutor for MatchingExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(Some(output("test", "matched")))
        }
    }

    /// Declines every text response.
    #[derive(Debug)]
    struct NoMatchExecutor;
    impl ToolExecutor for NoMatchExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }
    }

    /// Fails every text response with `ToolError::Blocked`.
    #[derive(Debug)]
    struct ErrorExecutor;
    impl ToolExecutor for ErrorExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Err(blocked())
        }
    }

    /// Handles every text response, reporting summary "second".
    #[derive(Debug)]
    struct SecondExecutor;
    impl ToolExecutor for SecondExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(Some(output("test", "second")))
        }
    }

    #[tokio::test]
    async fn first_matches_returns_first() {
        let composite = CompositeExecutor::new(MatchingExecutor, SecondExecutor);
        let result = composite.execute("anything").await.unwrap();
        assert_eq!(result.unwrap().summary, "matched");
    }

    #[tokio::test]
    async fn first_none_falls_through_to_second() {
        let composite = CompositeExecutor::new(NoMatchExecutor, SecondExecutor);
        let result = composite.execute("anything").await.unwrap();
        assert_eq!(result.unwrap().summary, "second");
    }

    #[tokio::test]
    async fn both_none_returns_none() {
        let composite = CompositeExecutor::new(NoMatchExecutor, NoMatchExecutor);
        let result = composite.execute("anything").await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn first_error_propagates_without_trying_second() {
        let composite = CompositeExecutor::new(ErrorExecutor, SecondExecutor);
        let result = composite.execute("anything").await;
        assert!(matches!(result, Err(ToolError::Blocked { .. })));
    }

    #[tokio::test]
    async fn second_error_propagates_when_first_none() {
        let composite = CompositeExecutor::new(NoMatchExecutor, ErrorExecutor);
        let result = composite.execute("anything").await;
        assert!(matches!(result, Err(ToolError::Blocked { .. })));
    }

    #[tokio::test]
    async fn execute_confirmed_first_matches() {
        let composite = CompositeExecutor::new(MatchingExecutor, SecondExecutor);
        let result = composite.execute_confirmed("anything").await.unwrap();
        assert_eq!(result.unwrap().summary, "matched");
    }

    #[tokio::test]
    async fn execute_confirmed_falls_through() {
        let composite = CompositeExecutor::new(NoMatchExecutor, SecondExecutor);
        let result = composite.execute_confirmed("anything").await.unwrap();
        assert_eq!(result.unwrap().summary, "second");
    }

    #[tokio::test]
    async fn execute_confirmed_first_error_propagates() {
        // If the error were swallowed, SecondExecutor would turn this into Ok.
        let composite = CompositeExecutor::new(ErrorExecutor, SecondExecutor);
        let result = composite.execute_confirmed("anything").await;
        assert!(matches!(result, Err(ToolError::Blocked { .. })));
    }

    #[test]
    fn composite_debug() {
        let composite = CompositeExecutor::new(MatchingExecutor, SecondExecutor);
        let debug = format!("{composite:?}");
        assert!(debug.contains("CompositeExecutor"));
    }

    /// Handles structured `read` / `write` calls only.
    #[derive(Debug)]
    struct FileToolExecutor;
    impl ToolExecutor for FileToolExecutor {
        async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }
        async fn execute_tool_call(
            &self,
            call: &ToolCall,
        ) -> Result<Option<ToolOutput>, ToolError> {
            if call.tool_id == "read" || call.tool_id == "write" {
                Ok(Some(output(&call.tool_id, "file_handler")))
            } else {
                Ok(None)
            }
        }
    }

    /// Handles structured `bash` calls only.
    #[derive(Debug)]
    struct ShellToolExecutor;
    impl ToolExecutor for ShellToolExecutor {
        async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }
        async fn execute_tool_call(
            &self,
            call: &ToolCall,
        ) -> Result<Option<ToolOutput>, ToolError> {
            if call.tool_id == "bash" {
                Ok(Some(output("bash", "shell_handler")))
            } else {
                Ok(None)
            }
        }
    }

    /// Fails every structured tool call with `ToolError::Blocked`.
    #[derive(Debug)]
    struct ErrorToolCallExecutor;
    impl ToolExecutor for ErrorToolCallExecutor {
        async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }
        async fn execute_tool_call(
            &self,
            _call: &ToolCall,
        ) -> Result<Option<ToolOutput>, ToolError> {
            Err(blocked())
        }
    }

    #[tokio::test]
    async fn tool_call_routes_to_file_executor() {
        let composite = CompositeExecutor::new(FileToolExecutor, ShellToolExecutor);
        let result = composite
            .execute_tool_call(&call("read"))
            .await
            .unwrap()
            .unwrap();
        assert_eq!(result.summary, "file_handler");
    }

    #[tokio::test]
    async fn tool_call_routes_to_shell_executor() {
        let composite = CompositeExecutor::new(FileToolExecutor, ShellToolExecutor);
        let result = composite
            .execute_tool_call(&call("bash"))
            .await
            .unwrap()
            .unwrap();
        assert_eq!(result.summary, "shell_handler");
    }

    #[tokio::test]
    async fn tool_call_unhandled_returns_none() {
        let composite = CompositeExecutor::new(FileToolExecutor, ShellToolExecutor);
        let result = composite.execute_tool_call(&call("unknown")).await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn tool_call_first_error_propagates_without_trying_second() {
        // ShellToolExecutor would handle "bash" successfully, so an Ok here
        // would mean the error from the first executor was swallowed.
        let composite = CompositeExecutor::new(ErrorToolCallExecutor, ShellToolExecutor);
        let result = composite.execute_tool_call(&call("bash")).await;
        assert!(matches!(result, Err(ToolError::Blocked { .. })));
    }
}