1use crate::executor::{ToolCall, ToolError, ToolExecutor, ToolOutput};
5use crate::registry::ToolDef;
6
/// Chains two tool executors: `first` is tried, and `second` is consulted only
/// when `first` reports that it did not handle the input (`Ok(None)`).
///
/// Trait bounds are placed on the `impl` blocks rather than on the struct, so
/// the type can be named in generic code without repeating `ToolExecutor`
/// bounds.
#[derive(Debug)]
pub struct CompositeExecutor<A, B> {
    /// Executor consulted first; when it produces output, that output is used.
    first: A,
    /// Fallback executor, consulted only when `first` returns `Ok(None)`.
    second: B,
}
16
17impl<A: ToolExecutor, B: ToolExecutor> CompositeExecutor<A, B> {
18 #[must_use]
19 pub fn new(first: A, second: B) -> Self {
20 Self { first, second }
21 }
22}
23
24impl<A: ToolExecutor, B: ToolExecutor> ToolExecutor for CompositeExecutor<A, B> {
25 async fn execute(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
26 if let Some(output) = self.first.execute(response).await? {
27 return Ok(Some(output));
28 }
29 self.second.execute(response).await
30 }
31
32 async fn execute_confirmed(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
33 if let Some(output) = self.first.execute_confirmed(response).await? {
34 return Ok(Some(output));
35 }
36 self.second.execute_confirmed(response).await
37 }
38
39 fn tool_definitions(&self) -> Vec<ToolDef> {
40 let mut defs = self.first.tool_definitions();
41 let seen: std::collections::HashSet<String> =
42 defs.iter().map(|d| d.id.to_string()).collect();
43 for def in self.second.tool_definitions() {
44 if !seen.contains(def.id.as_ref()) {
45 defs.push(def);
46 }
47 }
48 defs
49 }
50
51 async fn execute_tool_call(&self, call: &ToolCall) -> Result<Option<ToolOutput>, ToolError> {
52 if let Some(output) = self.first.execute_tool_call(call).await? {
53 return Ok(Some(output));
54 }
55 self.second.execute_tool_call(call).await
56 }
57
58 fn is_tool_retryable(&self, tool_id: &str) -> bool {
59 self.first.is_tool_retryable(tool_id) || self.second.is_tool_retryable(tool_id)
60 }
61}
62
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `ToolOutput` with the given tool name and summary. Every other
    /// field is fixed: one executed block, no optional metadata.
    fn output(tool_name: &str, summary: &str) -> ToolOutput {
        ToolOutput {
            tool_name: tool_name.to_owned(),
            summary: summary.to_owned(),
            blocks_executed: 1,
            filter_stats: None,
            diff: None,
            streamed: false,
            terminal_id: None,
            locations: None,
            raw_response: None,
            claim_source: None,
        }
    }

    /// Builds a `ToolCall` for `tool_id` with empty params and no caller id.
    fn call(tool_id: &str) -> ToolCall {
        ToolCall {
            tool_id: tool_id.to_owned(),
            params: serde_json::Map::new(),
            caller_id: None,
        }
    }

    /// Handles every free-form response, summarising it as "matched".
    #[derive(Debug)]
    struct MatchingExecutor;
    impl ToolExecutor for MatchingExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(Some(output("test", "matched")))
        }
    }

    /// Declines every free-form response.
    #[derive(Debug)]
    struct NoMatchExecutor;
    impl ToolExecutor for NoMatchExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }
    }

    /// Fails every free-form response with `ToolError::Blocked`.
    #[derive(Debug)]
    struct ErrorExecutor;
    impl ToolExecutor for ErrorExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Err(ToolError::Blocked {
                command: "test".to_owned(),
            })
        }
    }

    /// Handles every free-form response, summarising it as "second".
    #[derive(Debug)]
    struct SecondExecutor;
    impl ToolExecutor for SecondExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(Some(output("test", "second")))
        }
    }

    #[tokio::test]
    async fn first_matches_returns_first() {
        let composite = CompositeExecutor::new(MatchingExecutor, SecondExecutor);
        let result = composite.execute("anything").await.unwrap();
        assert_eq!(result.unwrap().summary, "matched");
    }

    #[tokio::test]
    async fn first_none_falls_through_to_second() {
        let composite = CompositeExecutor::new(NoMatchExecutor, SecondExecutor);
        let result = composite.execute("anything").await.unwrap();
        assert_eq!(result.unwrap().summary, "second");
    }

    #[tokio::test]
    async fn both_none_returns_none() {
        let composite = CompositeExecutor::new(NoMatchExecutor, NoMatchExecutor);
        let result = composite.execute("anything").await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn first_error_propagates_without_trying_second() {
        let composite = CompositeExecutor::new(ErrorExecutor, SecondExecutor);
        let result = composite.execute("anything").await;
        assert!(matches!(result, Err(ToolError::Blocked { .. })));
    }

    #[tokio::test]
    async fn second_error_propagates_when_first_none() {
        let composite = CompositeExecutor::new(NoMatchExecutor, ErrorExecutor);
        let result = composite.execute("anything").await;
        assert!(matches!(result, Err(ToolError::Blocked { .. })));
    }

    #[tokio::test]
    async fn execute_confirmed_first_matches() {
        let composite = CompositeExecutor::new(MatchingExecutor, SecondExecutor);
        let result = composite.execute_confirmed("anything").await.unwrap();
        assert_eq!(result.unwrap().summary, "matched");
    }

    #[tokio::test]
    async fn execute_confirmed_falls_through() {
        let composite = CompositeExecutor::new(NoMatchExecutor, SecondExecutor);
        let result = composite.execute_confirmed("anything").await.unwrap();
        assert_eq!(result.unwrap().summary, "second");
    }

    #[test]
    fn composite_debug() {
        let composite = CompositeExecutor::new(MatchingExecutor, SecondExecutor);
        let debug = format!("{composite:?}");
        assert!(debug.contains("CompositeExecutor"));
    }

    /// Handles the structured `read` and `write` tool calls only.
    #[derive(Debug)]
    struct FileToolExecutor;
    impl ToolExecutor for FileToolExecutor {
        async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }
        async fn execute_tool_call(
            &self,
            call: &ToolCall,
        ) -> Result<Option<ToolOutput>, ToolError> {
            if call.tool_id == "read" || call.tool_id == "write" {
                Ok(Some(output(&call.tool_id, "file_handler")))
            } else {
                Ok(None)
            }
        }
    }

    /// Handles the structured `bash` tool call only.
    #[derive(Debug)]
    struct ShellToolExecutor;
    impl ToolExecutor for ShellToolExecutor {
        async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }
        async fn execute_tool_call(
            &self,
            call: &ToolCall,
        ) -> Result<Option<ToolOutput>, ToolError> {
            if call.tool_id == "bash" {
                Ok(Some(output("bash", "shell_handler")))
            } else {
                Ok(None)
            }
        }
    }

    /// Fails every structured tool call with `ToolError::Blocked`.
    #[derive(Debug)]
    struct ToolCallErrorExecutor;
    impl ToolExecutor for ToolCallErrorExecutor {
        async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }
        async fn execute_tool_call(
            &self,
            call: &ToolCall,
        ) -> Result<Option<ToolOutput>, ToolError> {
            Err(ToolError::Blocked {
                command: call.tool_id.clone(),
            })
        }
    }

    /// Reports exactly one tool id (the wrapped string) as retryable.
    #[derive(Debug)]
    struct RetryableExecutor(&'static str);
    impl ToolExecutor for RetryableExecutor {
        async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }
        fn is_tool_retryable(&self, tool_id: &str) -> bool {
            tool_id == self.0
        }
    }

    #[tokio::test]
    async fn tool_call_routes_to_file_executor() {
        let composite = CompositeExecutor::new(FileToolExecutor, ShellToolExecutor);
        let result = composite.execute_tool_call(&call("read")).await.unwrap().unwrap();
        assert_eq!(result.summary, "file_handler");
    }

    #[tokio::test]
    async fn tool_call_routes_to_shell_executor() {
        let composite = CompositeExecutor::new(FileToolExecutor, ShellToolExecutor);
        let result = composite.execute_tool_call(&call("bash")).await.unwrap().unwrap();
        assert_eq!(result.summary, "shell_handler");
    }

    #[tokio::test]
    async fn tool_call_unhandled_returns_none() {
        let composite = CompositeExecutor::new(FileToolExecutor, ShellToolExecutor);
        let result = composite.execute_tool_call(&call("unknown")).await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn tool_call_first_error_propagates_without_trying_second() {
        let composite = CompositeExecutor::new(ToolCallErrorExecutor, ShellToolExecutor);
        let result = composite.execute_tool_call(&call("bash")).await;
        assert!(matches!(result, Err(ToolError::Blocked { .. })));
    }

    #[tokio::test]
    async fn tool_call_second_error_propagates_when_first_none() {
        let composite = CompositeExecutor::new(FileToolExecutor, ToolCallErrorExecutor);
        let result = composite.execute_tool_call(&call("bash")).await;
        assert!(matches!(result, Err(ToolError::Blocked { .. })));
    }

    #[test]
    fn retryable_if_either_executor_reports_it() {
        let composite =
            CompositeExecutor::new(RetryableExecutor("fetch"), RetryableExecutor("bash"));
        assert!(composite.is_tool_retryable("fetch"));
        assert!(composite.is_tool_retryable("bash"));
        assert!(!composite.is_tool_retryable("read"));
    }
}