1use crate::executor::{ToolCall, ToolError, ToolExecutor, ToolOutput};
7use crate::registry::ToolDef;
8
/// Chains two [`ToolExecutor`]s: `first` gets every request and `second` is only used as
/// a fallback.
///
/// Behaviour of the `ToolExecutor` impl for this type:
/// - `execute`, `execute_confirmed` and `execute_tool_call` return `first`'s output when it
///   produces `Some`, otherwise they defer to `second`. An error from `first` is returned
///   immediately without consulting `second`.
/// - `tool_definitions` lists `first`'s definitions, followed by those of `second` whose id
///   `first` does not already define (so `first` wins on id collisions).
/// - `is_tool_retryable` is `true` when either executor reports the tool as retryable.
///
/// Since the composite is itself a `ToolExecutor`, composites can be nested to chain more
/// than two executors.
///
/// The struct deliberately carries no trait bounds; they live on the `impl` blocks that
/// actually need them, so merely naming or debug-printing the type never requires them.
#[derive(Debug)]
pub struct CompositeExecutor<A, B> {
    /// Executor consulted first; its results take precedence.
    first: A,
    /// Fallback executor, consulted only when `first` yields `Ok(None)`.
    second: B,
}
39
40impl<A: ToolExecutor, B: ToolExecutor> CompositeExecutor<A, B> {
41 #[must_use]
43 pub fn new(first: A, second: B) -> Self {
44 Self { first, second }
45 }
46}
47
48impl<A: ToolExecutor, B: ToolExecutor> ToolExecutor for CompositeExecutor<A, B> {
49 async fn execute(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
50 if let Some(output) = self.first.execute(response).await? {
51 return Ok(Some(output));
52 }
53 self.second.execute(response).await
54 }
55
56 async fn execute_confirmed(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
57 if let Some(output) = self.first.execute_confirmed(response).await? {
58 return Ok(Some(output));
59 }
60 self.second.execute_confirmed(response).await
61 }
62
63 fn tool_definitions(&self) -> Vec<ToolDef> {
64 let mut defs = self.first.tool_definitions();
65 let seen: std::collections::HashSet<String> =
66 defs.iter().map(|d| d.id.to_string()).collect();
67 for def in self.second.tool_definitions() {
68 if !seen.contains(def.id.as_ref()) {
69 defs.push(def);
70 }
71 }
72 defs
73 }
74
75 async fn execute_tool_call(&self, call: &ToolCall) -> Result<Option<ToolOutput>, ToolError> {
76 if let Some(output) = self.first.execute_tool_call(call).await? {
77 return Ok(Some(output));
78 }
79 self.second.execute_tool_call(call).await
80 }
81
82 fn is_tool_retryable(&self, tool_id: &str) -> bool {
83 self.first.is_tool_retryable(tool_id) || self.second.is_tool_retryable(tool_id)
84 }
85}
86
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ToolName;

    /// Builds a single-block `ToolOutput` with every optional field unset, so each test
    /// executor only states the parts that distinguish it (name and summary).
    fn make_output(tool_name: ToolName, summary: &str) -> ToolOutput {
        ToolOutput {
            tool_name,
            summary: summary.to_owned(),
            blocks_executed: 1,
            filter_stats: None,
            diff: None,
            streamed: false,
            terminal_id: None,
            locations: None,
            raw_response: None,
            claim_source: None,
        }
    }

    /// Builds a parameterless `ToolCall` for `tool`.
    fn call_for(tool: &'static str) -> ToolCall {
        ToolCall {
            tool_id: ToolName::new(tool),
            params: serde_json::Map::new(),
            caller_id: None,
        }
    }

    /// Always handles the response, reporting summary `"matched"`.
    #[derive(Debug)]
    struct MatchingExecutor;
    impl ToolExecutor for MatchingExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(Some(make_output(ToolName::new("test"), "matched")))
        }
    }

    /// Never handles anything.
    #[derive(Debug)]
    struct NoMatchExecutor;
    impl ToolExecutor for NoMatchExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }
    }

    /// Always fails with `ToolError::Blocked`.
    #[derive(Debug)]
    struct ErrorExecutor;
    impl ToolExecutor for ErrorExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Err(ToolError::Blocked {
                command: "test".to_owned(),
            })
        }
    }

    /// Always handles the response, reporting summary `"second"`.
    #[derive(Debug)]
    struct SecondExecutor;
    impl ToolExecutor for SecondExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(Some(make_output(ToolName::new("test"), "second")))
        }
    }

    #[tokio::test]
    async fn first_matches_returns_first() {
        let composite = CompositeExecutor::new(MatchingExecutor, SecondExecutor);
        let result = composite.execute("anything").await.unwrap();
        assert_eq!(result.unwrap().summary, "matched");
    }

    #[tokio::test]
    async fn first_match_short_circuits_second() {
        // `second` would fail if it were consulted, so an `Ok` proves it was skipped.
        let composite = CompositeExecutor::new(MatchingExecutor, ErrorExecutor);
        let result = composite.execute("anything").await.unwrap();
        assert_eq!(result.unwrap().summary, "matched");
    }

    #[tokio::test]
    async fn first_none_falls_through_to_second() {
        let composite = CompositeExecutor::new(NoMatchExecutor, SecondExecutor);
        let result = composite.execute("anything").await.unwrap();
        assert_eq!(result.unwrap().summary, "second");
    }

    #[tokio::test]
    async fn both_none_returns_none() {
        let composite = CompositeExecutor::new(NoMatchExecutor, NoMatchExecutor);
        let result = composite.execute("anything").await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn first_error_propagates_without_trying_second() {
        let composite = CompositeExecutor::new(ErrorExecutor, SecondExecutor);
        let result = composite.execute("anything").await;
        assert!(matches!(result, Err(ToolError::Blocked { .. })));
    }

    #[tokio::test]
    async fn second_error_propagates_when_first_none() {
        let composite = CompositeExecutor::new(NoMatchExecutor, ErrorExecutor);
        let result = composite.execute("anything").await;
        assert!(matches!(result, Err(ToolError::Blocked { .. })));
    }

    #[tokio::test]
    async fn nested_composites_fall_through() {
        let inner = CompositeExecutor::new(NoMatchExecutor, NoMatchExecutor);
        let composite = CompositeExecutor::new(inner, SecondExecutor);
        let result = composite.execute("anything").await.unwrap();
        assert_eq!(result.unwrap().summary, "second");
    }

    #[tokio::test]
    async fn execute_confirmed_first_matches() {
        let composite = CompositeExecutor::new(MatchingExecutor, SecondExecutor);
        let result = composite.execute_confirmed("anything").await.unwrap();
        assert_eq!(result.unwrap().summary, "matched");
    }

    #[tokio::test]
    async fn execute_confirmed_falls_through() {
        let composite = CompositeExecutor::new(NoMatchExecutor, SecondExecutor);
        let result = composite.execute_confirmed("anything").await.unwrap();
        assert_eq!(result.unwrap().summary, "second");
    }

    #[tokio::test]
    async fn execute_confirmed_both_none_returns_none() {
        let composite = CompositeExecutor::new(NoMatchExecutor, NoMatchExecutor);
        let result = composite.execute_confirmed("anything").await.unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn composite_debug() {
        let composite = CompositeExecutor::new(MatchingExecutor, SecondExecutor);
        let debug = format!("{composite:?}");
        assert!(debug.contains("CompositeExecutor"));
    }

    /// Handles structured calls to `read` and `write` only.
    #[derive(Debug)]
    struct FileToolExecutor;
    impl ToolExecutor for FileToolExecutor {
        async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }
        async fn execute_tool_call(
            &self,
            call: &ToolCall,
        ) -> Result<Option<ToolOutput>, ToolError> {
            if call.tool_id == "read" || call.tool_id == "write" {
                Ok(Some(make_output(call.tool_id.clone(), "file_handler")))
            } else {
                Ok(None)
            }
        }
    }

    /// Handles structured calls to `bash` only.
    #[derive(Debug)]
    struct ShellToolExecutor;
    impl ToolExecutor for ShellToolExecutor {
        async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }
        async fn execute_tool_call(
            &self,
            call: &ToolCall,
        ) -> Result<Option<ToolOutput>, ToolError> {
            if call.tool_id == "bash" {
                Ok(Some(make_output(ToolName::new("bash"), "shell_handler")))
            } else {
                Ok(None)
            }
        }
    }

    #[tokio::test]
    async fn tool_call_routes_to_file_executor() {
        let composite = CompositeExecutor::new(FileToolExecutor, ShellToolExecutor);
        let result = composite
            .execute_tool_call(&call_for("read"))
            .await
            .unwrap()
            .unwrap();
        assert_eq!(result.summary, "file_handler");
    }

    #[tokio::test]
    async fn tool_call_routes_to_shell_executor() {
        let composite = CompositeExecutor::new(FileToolExecutor, ShellToolExecutor);
        let result = composite
            .execute_tool_call(&call_for("bash"))
            .await
            .unwrap()
            .unwrap();
        assert_eq!(result.summary, "shell_handler");
    }

    #[tokio::test]
    async fn tool_call_unhandled_returns_none() {
        let composite = CompositeExecutor::new(FileToolExecutor, ShellToolExecutor);
        let result = composite
            .execute_tool_call(&call_for("unknown"))
            .await
            .unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn tool_call_routes_through_nested_composite() {
        // The outer `first` declines `read`, so the inner composite must route it.
        let inner = CompositeExecutor::new(FileToolExecutor, ShellToolExecutor);
        let composite = CompositeExecutor::new(ShellToolExecutor, inner);
        let result = composite
            .execute_tool_call(&call_for("read"))
            .await
            .unwrap()
            .unwrap();
        assert_eq!(result.summary, "file_handler");
    }
}