1use crate::executor::{ToolCall, ToolError, ToolExecutor, ToolOutput};
7use crate::registry::ToolDef;
8
/// Chains two tool executors: every request is offered to `first`, and only
/// reaches `second` when `first` reports no match (`Ok(None)`).
///
/// Trait bounds are intentionally placed on the `impl` blocks rather than on
/// the struct definition (Rust API guidelines, C-STRUCT-BOUNDS); `new` and the
/// `ToolExecutor` impl still require both halves to implement `ToolExecutor`.
#[derive(Debug)]
pub struct CompositeExecutor<A, B> {
    /// Executor consulted first; its result wins whenever it handles a request.
    first: A,
    /// Fallback executor, consulted only when `first` returns `Ok(None)`.
    second: B,
}
38
39impl<A: ToolExecutor, B: ToolExecutor> CompositeExecutor<A, B> {
40 #[must_use]
42 pub fn new(first: A, second: B) -> Self {
43 Self { first, second }
44 }
45}
46
47impl<A: ToolExecutor, B: ToolExecutor> ToolExecutor for CompositeExecutor<A, B> {
48 async fn execute(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
49 if let Some(output) = self.first.execute(response).await? {
50 return Ok(Some(output));
51 }
52 self.second.execute(response).await
53 }
54
55 async fn execute_confirmed(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
56 if let Some(output) = self.first.execute_confirmed(response).await? {
57 return Ok(Some(output));
58 }
59 self.second.execute_confirmed(response).await
60 }
61
62 fn tool_definitions(&self) -> Vec<ToolDef> {
63 let mut defs = self.first.tool_definitions();
64 let seen: std::collections::HashSet<String> =
65 defs.iter().map(|d| d.id.to_string()).collect();
66 for def in self.second.tool_definitions() {
67 if !seen.contains(def.id.as_ref()) {
68 defs.push(def);
69 }
70 }
71 defs
72 }
73
74 async fn execute_tool_call(&self, call: &ToolCall) -> Result<Option<ToolOutput>, ToolError> {
75 if let Some(output) = self.first.execute_tool_call(call).await? {
76 return Ok(Some(output));
77 }
78 self.second.execute_tool_call(call).await
79 }
80
81 fn is_tool_retryable(&self, tool_id: &str) -> bool {
82 self.first.is_tool_retryable(tool_id) || self.second.is_tool_retryable(tool_id)
83 }
84}
85
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ToolName;

    /// Builds a single-block `ToolOutput` with every optional field left empty.
    fn output(tool_name: ToolName, summary: &str) -> ToolOutput {
        ToolOutput {
            tool_name,
            summary: summary.to_owned(),
            blocks_executed: 1,
            filter_stats: None,
            diff: None,
            streamed: false,
            terminal_id: None,
            locations: None,
            raw_response: None,
            claim_source: None,
        }
    }

    /// Builds a parameterless `ToolCall` addressed to `tool_id`.
    fn tool_call(tool_id: &'static str) -> ToolCall {
        ToolCall {
            tool_id: ToolName::new(tool_id),
            params: serde_json::Map::new(),
            caller_id: None,
        }
    }

    #[derive(Debug)]
    struct MatchingExecutor;
    impl ToolExecutor for MatchingExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(Some(output(ToolName::new("test"), "matched")))
        }
    }

    #[derive(Debug)]
    struct NoMatchExecutor;
    impl ToolExecutor for NoMatchExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }
    }

    #[derive(Debug)]
    struct ErrorExecutor;
    impl ToolExecutor for ErrorExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Err(ToolError::Blocked {
                command: "test".to_owned(),
            })
        }
    }

    #[derive(Debug)]
    struct SecondExecutor;
    impl ToolExecutor for SecondExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(Some(output(ToolName::new("test"), "second")))
        }
    }

    #[tokio::test]
    async fn first_matches_returns_first() {
        let composite = CompositeExecutor::new(MatchingExecutor, SecondExecutor);
        let result = composite.execute("anything").await.unwrap();
        assert_eq!(result.unwrap().summary, "matched");
    }

    #[tokio::test]
    async fn first_none_falls_through_to_second() {
        let composite = CompositeExecutor::new(NoMatchExecutor, SecondExecutor);
        let result = composite.execute("anything").await.unwrap();
        assert_eq!(result.unwrap().summary, "second");
    }

    #[tokio::test]
    async fn both_none_returns_none() {
        let composite = CompositeExecutor::new(NoMatchExecutor, NoMatchExecutor);
        let result = composite.execute("anything").await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn first_error_propagates_without_trying_second() {
        let composite = CompositeExecutor::new(ErrorExecutor, SecondExecutor);
        let result = composite.execute("anything").await;
        assert!(matches!(result, Err(ToolError::Blocked { .. })));
    }

    #[tokio::test]
    async fn second_error_propagates_when_first_none() {
        let composite = CompositeExecutor::new(NoMatchExecutor, ErrorExecutor);
        let result = composite.execute("anything").await;
        assert!(matches!(result, Err(ToolError::Blocked { .. })));
    }

    #[tokio::test]
    async fn execute_confirmed_first_matches() {
        let composite = CompositeExecutor::new(MatchingExecutor, SecondExecutor);
        let result = composite.execute_confirmed("anything").await.unwrap();
        assert_eq!(result.unwrap().summary, "matched");
    }

    #[tokio::test]
    async fn execute_confirmed_falls_through() {
        let composite = CompositeExecutor::new(NoMatchExecutor, SecondExecutor);
        let result = composite.execute_confirmed("anything").await.unwrap();
        assert_eq!(result.unwrap().summary, "second");
    }

    #[test]
    fn composite_debug() {
        let composite = CompositeExecutor::new(MatchingExecutor, SecondExecutor);
        let debug = format!("{composite:?}");
        assert!(debug.contains("CompositeExecutor"));
    }

    /// Handles structured calls to `read` and `write` only.
    #[derive(Debug)]
    struct FileToolExecutor;
    impl ToolExecutor for FileToolExecutor {
        async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }
        async fn execute_tool_call(
            &self,
            call: &ToolCall,
        ) -> Result<Option<ToolOutput>, ToolError> {
            if call.tool_id == "read" || call.tool_id == "write" {
                Ok(Some(output(call.tool_id.clone(), "file_handler")))
            } else {
                Ok(None)
            }
        }
    }

    /// Handles structured calls to `bash` only.
    #[derive(Debug)]
    struct ShellToolExecutor;
    impl ToolExecutor for ShellToolExecutor {
        async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }
        async fn execute_tool_call(
            &self,
            call: &ToolCall,
        ) -> Result<Option<ToolOutput>, ToolError> {
            if call.tool_id == "bash" {
                Ok(Some(output(ToolName::new("bash"), "shell_handler")))
            } else {
                Ok(None)
            }
        }
    }

    #[tokio::test]
    async fn tool_call_routes_to_file_executor() {
        let composite = CompositeExecutor::new(FileToolExecutor, ShellToolExecutor);
        let call = tool_call("read");
        let result = composite.execute_tool_call(&call).await.unwrap().unwrap();
        assert_eq!(result.summary, "file_handler");
    }

    #[tokio::test]
    async fn tool_call_routes_to_shell_executor() {
        let composite = CompositeExecutor::new(FileToolExecutor, ShellToolExecutor);
        let call = tool_call("bash");
        let result = composite.execute_tool_call(&call).await.unwrap().unwrap();
        assert_eq!(result.summary, "shell_handler");
    }

    #[tokio::test]
    async fn tool_call_unhandled_returns_none() {
        let composite = CompositeExecutor::new(FileToolExecutor, ShellToolExecutor);
        let call = tool_call("unknown");
        let result = composite.execute_tool_call(&call).await.unwrap();
        assert!(result.is_none());
    }

    /// Never matches; reports only `fetch` as retryable.
    #[derive(Debug)]
    struct FetchRetryableExecutor;
    impl ToolExecutor for FetchRetryableExecutor {
        async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }
        fn is_tool_retryable(&self, tool_id: &str) -> bool {
            tool_id == "fetch"
        }
    }

    /// Never matches; reports no tool as retryable. The override is explicit so
    /// these tests do not depend on the trait's default.
    #[derive(Debug)]
    struct NeverRetryableExecutor;
    impl ToolExecutor for NeverRetryableExecutor {
        async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }
        fn is_tool_retryable(&self, _tool_id: &str) -> bool {
            false
        }
    }

    #[test]
    fn retryable_when_second_says_so() {
        let composite = CompositeExecutor::new(NeverRetryableExecutor, FetchRetryableExecutor);
        assert!(composite.is_tool_retryable("fetch"));
        assert!(!composite.is_tool_retryable("bash"));
    }

    #[test]
    fn retryable_when_first_says_so() {
        let composite = CompositeExecutor::new(FetchRetryableExecutor, NeverRetryableExecutor);
        assert!(composite.is_tool_retryable("fetch"));
        assert!(!composite.is_tool_retryable("bash"));
    }
}