1use crate::executor::{ToolCall, ToolError, ToolExecutor, ToolOutput};
5use crate::registry::ToolDef;
6
/// An executor that chains two executors, giving `first` priority over `second`.
///
/// Every request is offered to `first`; `second` is consulted only when `first`
/// declines it by returning `Ok(None)`. An error from `first` is propagated
/// without trying `second`.
///
/// The type parameters are deliberately unbounded here: the `ToolExecutor`
/// requirements live on the `impl` blocks that need them, so code that merely
/// names or stores a `CompositeExecutor<A, B>` does not have to repeat them.
#[derive(Debug)]
pub struct CompositeExecutor<A, B> {
    /// Preferred executor; always tried first.
    first: A,
    /// Fallback executor; tried only when `first` yields `Ok(None)`.
    second: B,
}
16
17impl<A: ToolExecutor, B: ToolExecutor> CompositeExecutor<A, B> {
18 #[must_use]
19 pub fn new(first: A, second: B) -> Self {
20 Self { first, second }
21 }
22}
23
24impl<A: ToolExecutor, B: ToolExecutor> ToolExecutor for CompositeExecutor<A, B> {
25 async fn execute(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
26 if let Some(output) = self.first.execute(response).await? {
27 return Ok(Some(output));
28 }
29 self.second.execute(response).await
30 }
31
32 async fn execute_confirmed(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
33 if let Some(output) = self.first.execute_confirmed(response).await? {
34 return Ok(Some(output));
35 }
36 self.second.execute_confirmed(response).await
37 }
38
39 fn tool_definitions(&self) -> Vec<ToolDef> {
40 let mut defs = self.first.tool_definitions();
41 let seen: std::collections::HashSet<String> =
42 defs.iter().map(|d| d.id.to_string()).collect();
43 for def in self.second.tool_definitions() {
44 if !seen.contains(def.id.as_ref()) {
45 defs.push(def);
46 }
47 }
48 defs
49 }
50
51 async fn execute_tool_call(&self, call: &ToolCall) -> Result<Option<ToolOutput>, ToolError> {
52 if let Some(output) = self.first.execute_tool_call(call).await? {
53 return Ok(Some(output));
54 }
55 self.second.execute_tool_call(call).await
56 }
57}
58
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a single-block `ToolOutput` with the given name and summary and
    /// every optional field left empty.
    fn make_output(tool_name: &str, summary: &str) -> ToolOutput {
        ToolOutput {
            tool_name: tool_name.to_owned(),
            summary: summary.to_owned(),
            blocks_executed: 1,
            filter_stats: None,
            diff: None,
            streamed: false,
            terminal_id: None,
            locations: None,
        }
    }

    /// Builds a parameterless `ToolCall` targeting `tool_id`.
    fn make_call(tool_id: &str) -> ToolCall {
        ToolCall {
            tool_id: tool_id.to_owned(),
            params: serde_json::Map::new(),
        }
    }

    /// Claims every response, answering with summary "matched".
    #[derive(Debug)]
    struct MatchingExecutor;
    impl ToolExecutor for MatchingExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(Some(make_output("test", "matched")))
        }
    }

    /// Declines every response.
    #[derive(Debug)]
    struct NoMatchExecutor;
    impl ToolExecutor for NoMatchExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }
    }

    /// Fails every response with `ToolError::Blocked`.
    #[derive(Debug)]
    struct ErrorExecutor;
    impl ToolExecutor for ErrorExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            let command = "test".to_owned();
            Err(ToolError::Blocked { command })
        }
    }

    /// Claims every response, answering with summary "second".
    #[derive(Debug)]
    struct SecondExecutor;
    impl ToolExecutor for SecondExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(Some(make_output("test", "second")))
        }
    }

    #[tokio::test]
    async fn first_matches_returns_first() {
        let composite = CompositeExecutor::new(MatchingExecutor, SecondExecutor);
        let output = composite.execute("anything").await.unwrap().unwrap();
        assert_eq!(output.summary, "matched");
    }

    #[tokio::test]
    async fn first_none_falls_through_to_second() {
        let composite = CompositeExecutor::new(NoMatchExecutor, SecondExecutor);
        let output = composite.execute("anything").await.unwrap().unwrap();
        assert_eq!(output.summary, "second");
    }

    #[tokio::test]
    async fn both_none_returns_none() {
        let composite = CompositeExecutor::new(NoMatchExecutor, NoMatchExecutor);
        let outcome = composite.execute("anything").await.unwrap();
        assert!(outcome.is_none());
    }

    #[tokio::test]
    async fn first_error_propagates_without_trying_second() {
        let composite = CompositeExecutor::new(ErrorExecutor, SecondExecutor);
        let outcome = composite.execute("anything").await;
        assert!(matches!(outcome, Err(ToolError::Blocked { .. })));
    }

    #[tokio::test]
    async fn second_error_propagates_when_first_none() {
        let composite = CompositeExecutor::new(NoMatchExecutor, ErrorExecutor);
        let outcome = composite.execute("anything").await;
        assert!(matches!(outcome, Err(ToolError::Blocked { .. })));
    }

    #[tokio::test]
    async fn execute_confirmed_first_matches() {
        let composite = CompositeExecutor::new(MatchingExecutor, SecondExecutor);
        let output = composite
            .execute_confirmed("anything")
            .await
            .unwrap()
            .unwrap();
        assert_eq!(output.summary, "matched");
    }

    #[tokio::test]
    async fn execute_confirmed_falls_through() {
        let composite = CompositeExecutor::new(NoMatchExecutor, SecondExecutor);
        let output = composite
            .execute_confirmed("anything")
            .await
            .unwrap()
            .unwrap();
        assert_eq!(output.summary, "second");
    }

    #[test]
    fn composite_debug() {
        let composite = CompositeExecutor::new(MatchingExecutor, SecondExecutor);
        let rendered = format!("{composite:?}");
        assert!(rendered.contains("CompositeExecutor"));
    }

    /// Handles only the "read" and "write" native tool calls.
    #[derive(Debug)]
    struct FileToolExecutor;
    impl ToolExecutor for FileToolExecutor {
        async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }
        async fn execute_tool_call(
            &self,
            call: &ToolCall,
        ) -> Result<Option<ToolOutput>, ToolError> {
            match call.tool_id.as_str() {
                "read" | "write" => Ok(Some(make_output(&call.tool_id, "file_handler"))),
                _ => Ok(None),
            }
        }
    }

    /// Handles only the "bash" native tool call.
    #[derive(Debug)]
    struct ShellToolExecutor;
    impl ToolExecutor for ShellToolExecutor {
        async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }
        async fn execute_tool_call(
            &self,
            call: &ToolCall,
        ) -> Result<Option<ToolOutput>, ToolError> {
            let handled = (call.tool_id == "bash").then(|| make_output("bash", "shell_handler"));
            Ok(handled)
        }
    }

    #[tokio::test]
    async fn tool_call_routes_to_file_executor() {
        let composite = CompositeExecutor::new(FileToolExecutor, ShellToolExecutor);
        let request = make_call("read");
        let output = composite.execute_tool_call(&request).await.unwrap().unwrap();
        assert_eq!(output.summary, "file_handler");
    }

    #[tokio::test]
    async fn tool_call_routes_to_shell_executor() {
        let composite = CompositeExecutor::new(FileToolExecutor, ShellToolExecutor);
        let request = make_call("bash");
        let output = composite.execute_tool_call(&request).await.unwrap().unwrap();
        assert_eq!(output.summary, "shell_handler");
    }

    #[tokio::test]
    async fn tool_call_unhandled_returns_none() {
        let composite = CompositeExecutor::new(FileToolExecutor, ShellToolExecutor);
        let request = make_call("unknown");
        let outcome = composite.execute_tool_call(&request).await.unwrap();
        assert!(outcome.is_none());
    }
}