1use crate::executor::{ToolCall, ToolError, ToolExecutor, ToolOutput};
5use crate::registry::ToolDef;
6
/// Chains two executors, consulting `first` and falling back to `second`.
///
/// For `execute`, `execute_confirmed` and `execute_tool_call`, `first` is tried
/// first. `second` is consulted only when `first` returns `Ok(None)`. An `Err`
/// from `first` is returned immediately and `second` is not tried.
/// `tool_definitions` lists `first`'s definitions, followed by those of
/// `second` whose `id` is not already present in `first`'s list.
///
/// The type parameters are deliberately unbounded here. The `ToolExecutor`
/// requirement lives on the `impl` blocks that actually need it, so the
/// struct itself does not duplicate those bounds.
#[derive(Debug)]
pub struct CompositeExecutor<A, B> {
    /// Executor given the first chance to handle every request.
    first: A,
    /// Fallback executor, used only when `first` declines with `Ok(None)`.
    second: B,
}
16
17impl<A: ToolExecutor, B: ToolExecutor> CompositeExecutor<A, B> {
18 #[must_use]
19 pub fn new(first: A, second: B) -> Self {
20 Self { first, second }
21 }
22}
23
24impl<A: ToolExecutor, B: ToolExecutor> ToolExecutor for CompositeExecutor<A, B> {
25 async fn execute(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
26 if let Some(output) = self.first.execute(response).await? {
27 return Ok(Some(output));
28 }
29 self.second.execute(response).await
30 }
31
32 async fn execute_confirmed(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
33 if let Some(output) = self.first.execute_confirmed(response).await? {
34 return Ok(Some(output));
35 }
36 self.second.execute_confirmed(response).await
37 }
38
39 fn tool_definitions(&self) -> Vec<ToolDef> {
40 let mut defs = self.first.tool_definitions();
41 let seen: std::collections::HashSet<String> =
42 defs.iter().map(|d| d.id.to_string()).collect();
43 for def in self.second.tool_definitions() {
44 if !seen.contains(def.id.as_ref()) {
45 defs.push(def);
46 }
47 }
48 defs
49 }
50
51 async fn execute_tool_call(&self, call: &ToolCall) -> Result<Option<ToolOutput>, ToolError> {
52 if let Some(output) = self.first.execute_tool_call(call).await? {
53 return Ok(Some(output));
54 }
55 self.second.execute_tool_call(call).await
56 }
57}
58
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `ToolOutput` with the given `tool_name` and `summary`.
    /// Every other field gets the fixed value all the mock executors below use.
    fn make_output(tool_name: &str, summary: &str) -> ToolOutput {
        ToolOutput {
            tool_name: tool_name.to_owned(),
            summary: summary.to_owned(),
            blocks_executed: 1,
            filter_stats: None,
            diff: None,
            streamed: false,
            terminal_id: None,
            locations: None,
            raw_response: None,
        }
    }

    /// Builds a structured `ToolCall` for `tool_id` with an empty parameter map.
    fn make_call(tool_id: &str) -> ToolCall {
        ToolCall {
            tool_id: tool_id.to_owned(),
            params: serde_json::Map::new(),
        }
    }

    /// Always handles the response, reporting summary `"matched"`.
    #[derive(Debug)]
    struct MatchingExecutor;
    impl ToolExecutor for MatchingExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(Some(make_output("test", "matched")))
        }
    }

    /// Never handles the response.
    #[derive(Debug)]
    struct NoMatchExecutor;
    impl ToolExecutor for NoMatchExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }
    }

    /// Always fails with `ToolError::Blocked`.
    #[derive(Debug)]
    struct ErrorExecutor;
    impl ToolExecutor for ErrorExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Err(ToolError::Blocked {
                command: "test".to_owned(),
            })
        }
    }

    /// Always handles the response, reporting summary `"second"`. Used as the fallback.
    #[derive(Debug)]
    struct SecondExecutor;
    impl ToolExecutor for SecondExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(Some(make_output("test", "second")))
        }
    }

    #[tokio::test]
    async fn first_matches_returns_first() {
        let composite = CompositeExecutor::new(MatchingExecutor, SecondExecutor);
        let result = composite.execute("anything").await.unwrap();
        assert_eq!(result.unwrap().summary, "matched");
    }

    #[tokio::test]
    async fn first_none_falls_through_to_second() {
        let composite = CompositeExecutor::new(NoMatchExecutor, SecondExecutor);
        let result = composite.execute("anything").await.unwrap();
        assert_eq!(result.unwrap().summary, "second");
    }

    #[tokio::test]
    async fn both_none_returns_none() {
        let composite = CompositeExecutor::new(NoMatchExecutor, NoMatchExecutor);
        let result = composite.execute("anything").await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn first_error_propagates_without_trying_second() {
        let composite = CompositeExecutor::new(ErrorExecutor, SecondExecutor);
        let result = composite.execute("anything").await;
        assert!(matches!(result, Err(ToolError::Blocked { .. })));
    }

    #[tokio::test]
    async fn second_error_propagates_when_first_none() {
        let composite = CompositeExecutor::new(NoMatchExecutor, ErrorExecutor);
        let result = composite.execute("anything").await;
        assert!(matches!(result, Err(ToolError::Blocked { .. })));
    }

    #[tokio::test]
    async fn nested_composite_falls_through_to_innermost() {
        let inner = CompositeExecutor::new(NoMatchExecutor, SecondExecutor);
        let composite = CompositeExecutor::new(NoMatchExecutor, inner);
        let result = composite.execute("anything").await.unwrap();
        assert_eq!(result.unwrap().summary, "second");
    }

    #[tokio::test]
    async fn execute_confirmed_first_matches() {
        let composite = CompositeExecutor::new(MatchingExecutor, SecondExecutor);
        let result = composite.execute_confirmed("anything").await.unwrap();
        assert_eq!(result.unwrap().summary, "matched");
    }

    #[tokio::test]
    async fn execute_confirmed_falls_through() {
        let composite = CompositeExecutor::new(NoMatchExecutor, SecondExecutor);
        let result = composite.execute_confirmed("anything").await.unwrap();
        assert_eq!(result.unwrap().summary, "second");
    }

    #[tokio::test]
    async fn execute_confirmed_both_none_returns_none() {
        let composite = CompositeExecutor::new(NoMatchExecutor, NoMatchExecutor);
        let result = composite.execute_confirmed("anything").await.unwrap();
        assert!(result.is_none());
    }

    // Relies on the trait's default `execute_confirmed` delegating to `execute`,
    // as `execute_confirmed_first_matches` above already does.
    #[tokio::test]
    async fn execute_confirmed_first_error_propagates() {
        let composite = CompositeExecutor::new(ErrorExecutor, SecondExecutor);
        let result = composite.execute_confirmed("anything").await;
        assert!(matches!(result, Err(ToolError::Blocked { .. })));
    }

    #[test]
    fn composite_debug() {
        let composite = CompositeExecutor::new(MatchingExecutor, SecondExecutor);
        let debug = format!("{composite:?}");
        assert!(debug.contains("CompositeExecutor"));
    }

    /// Handles structured `read`/`write` calls and ignores everything else.
    #[derive(Debug)]
    struct FileToolExecutor;
    impl ToolExecutor for FileToolExecutor {
        async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }
        async fn execute_tool_call(
            &self,
            call: &ToolCall,
        ) -> Result<Option<ToolOutput>, ToolError> {
            if matches!(call.tool_id.as_str(), "read" | "write") {
                Ok(Some(make_output(&call.tool_id, "file_handler")))
            } else {
                Ok(None)
            }
        }
    }

    /// Handles structured `bash` calls and ignores everything else.
    #[derive(Debug)]
    struct ShellToolExecutor;
    impl ToolExecutor for ShellToolExecutor {
        async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }
        async fn execute_tool_call(
            &self,
            call: &ToolCall,
        ) -> Result<Option<ToolOutput>, ToolError> {
            if call.tool_id == "bash" {
                Ok(Some(make_output("bash", "shell_handler")))
            } else {
                Ok(None)
            }
        }
    }

    /// Rejects every structured tool call with `ToolError::Blocked`.
    #[derive(Debug)]
    struct BlockingToolCallExecutor;
    impl ToolExecutor for BlockingToolCallExecutor {
        async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }
        async fn execute_tool_call(
            &self,
            call: &ToolCall,
        ) -> Result<Option<ToolOutput>, ToolError> {
            Err(ToolError::Blocked {
                command: call.tool_id.clone(),
            })
        }
    }

    #[tokio::test]
    async fn tool_call_routes_to_file_executor() {
        let composite = CompositeExecutor::new(FileToolExecutor, ShellToolExecutor);
        let call = make_call("read");
        let result = composite.execute_tool_call(&call).await.unwrap().unwrap();
        assert_eq!(result.summary, "file_handler");
    }

    #[tokio::test]
    async fn tool_call_write_routes_to_file_executor_and_keeps_name() {
        let composite = CompositeExecutor::new(FileToolExecutor, ShellToolExecutor);
        let call = make_call("write");
        let result = composite.execute_tool_call(&call).await.unwrap().unwrap();
        assert_eq!(result.summary, "file_handler");
        assert_eq!(result.tool_name, "write");
    }

    #[tokio::test]
    async fn tool_call_routes_to_shell_executor() {
        let composite = CompositeExecutor::new(FileToolExecutor, ShellToolExecutor);
        let call = make_call("bash");
        let result = composite.execute_tool_call(&call).await.unwrap().unwrap();
        assert_eq!(result.summary, "shell_handler");
    }

    #[tokio::test]
    async fn tool_call_unhandled_returns_none() {
        let composite = CompositeExecutor::new(FileToolExecutor, ShellToolExecutor);
        let call = make_call("unknown");
        let result = composite.execute_tool_call(&call).await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn tool_call_first_error_propagates_without_trying_second() {
        let composite = CompositeExecutor::new(BlockingToolCallExecutor, ShellToolExecutor);
        let call = make_call("bash");
        let result = composite.execute_tool_call(&call).await;
        assert!(matches!(result, Err(ToolError::Blocked { .. })));
    }
}