1use crate::executor::{ToolCall, ToolError, ToolExecutor, ToolOutput};
5use crate::registry::ToolDef;
6
/// Chains two [`ToolExecutor`]s so that `first` is consulted before `second`.
///
/// The `ToolExecutor` bounds live on the `impl` blocks rather than on the type itself, so
/// generic code can name `CompositeExecutor<A, B>` without repeating those bounds.
#[derive(Debug)]
pub struct CompositeExecutor<A, B> {
    /// Executor given the first chance to handle every request.
    first: A,
    /// Fallback executor, used only when `first` reports no match (`Ok(None)`).
    second: B,
}
16
17impl<A: ToolExecutor, B: ToolExecutor> CompositeExecutor<A, B> {
18 #[must_use]
19 pub fn new(first: A, second: B) -> Self {
20 Self { first, second }
21 }
22}
23
24impl<A: ToolExecutor, B: ToolExecutor> ToolExecutor for CompositeExecutor<A, B> {
25 async fn execute(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
26 if let Some(output) = self.first.execute(response).await? {
27 return Ok(Some(output));
28 }
29 self.second.execute(response).await
30 }
31
32 async fn execute_confirmed(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
33 if let Some(output) = self.first.execute_confirmed(response).await? {
34 return Ok(Some(output));
35 }
36 self.second.execute_confirmed(response).await
37 }
38
39 fn tool_definitions(&self) -> Vec<ToolDef> {
40 let mut defs = self.first.tool_definitions();
41 let seen: std::collections::HashSet<String> =
42 defs.iter().map(|d| d.id.to_string()).collect();
43 for def in self.second.tool_definitions() {
44 if !seen.contains(def.id.as_ref()) {
45 defs.push(def);
46 }
47 }
48 defs
49 }
50
51 async fn execute_tool_call(&self, call: &ToolCall) -> Result<Option<ToolOutput>, ToolError> {
52 if let Some(output) = self.first.execute_tool_call(call).await? {
53 return Ok(Some(output));
54 }
55 self.second.execute_tool_call(call).await
56 }
57}
58
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `ToolOutput` with the given name and summary and fixed values for every
    /// other field, so each stub executor only spells out what distinguishes it.
    fn make_output(tool_name: &str, summary: &str) -> ToolOutput {
        ToolOutput {
            tool_name: tool_name.to_owned(),
            summary: summary.to_owned(),
            blocks_executed: 1,
            filter_stats: None,
            diff: None,
            streamed: false,
            terminal_id: None,
        }
    }

    /// Always matches; its output summary is `"matched"`.
    #[derive(Debug)]
    struct MatchingExecutor;
    impl ToolExecutor for MatchingExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(Some(make_output("test", "matched")))
        }
    }

    /// Never matches (`Ok(None)`).
    #[derive(Debug)]
    struct NoMatchExecutor;
    impl ToolExecutor for NoMatchExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }
    }

    /// Always fails with `ToolError::Blocked`.
    #[derive(Debug)]
    struct ErrorExecutor;
    impl ToolExecutor for ErrorExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Err(ToolError::Blocked {
                command: "test".to_owned(),
            })
        }
    }

    /// Always matches; its output summary is `"second"`, distinguishing it from `MatchingExecutor`.
    #[derive(Debug)]
    struct SecondExecutor;
    impl ToolExecutor for SecondExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(Some(make_output("test", "second")))
        }
    }

    #[tokio::test]
    async fn first_matches_returns_first() {
        let composite = CompositeExecutor::new(MatchingExecutor, SecondExecutor);
        let result = composite.execute("anything").await.unwrap();
        assert_eq!(result.unwrap().summary, "matched");
    }

    #[tokio::test]
    async fn first_none_falls_through_to_second() {
        let composite = CompositeExecutor::new(NoMatchExecutor, SecondExecutor);
        let result = composite.execute("anything").await.unwrap();
        assert_eq!(result.unwrap().summary, "second");
    }

    #[tokio::test]
    async fn both_none_returns_none() {
        let composite = CompositeExecutor::new(NoMatchExecutor, NoMatchExecutor);
        let result = composite.execute("anything").await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn first_error_propagates_without_trying_second() {
        let composite = CompositeExecutor::new(ErrorExecutor, SecondExecutor);
        let result = composite.execute("anything").await;
        assert!(matches!(result, Err(ToolError::Blocked { .. })));
    }

    #[tokio::test]
    async fn second_error_propagates_when_first_none() {
        let composite = CompositeExecutor::new(NoMatchExecutor, ErrorExecutor);
        let result = composite.execute("anything").await;
        assert!(matches!(result, Err(ToolError::Blocked { .. })));
    }

    #[tokio::test]
    async fn execute_confirmed_first_matches() {
        let composite = CompositeExecutor::new(MatchingExecutor, SecondExecutor);
        let result = composite.execute_confirmed("anything").await.unwrap();
        assert_eq!(result.unwrap().summary, "matched");
    }

    #[tokio::test]
    async fn execute_confirmed_falls_through() {
        let composite = CompositeExecutor::new(NoMatchExecutor, SecondExecutor);
        let result = composite.execute_confirmed("anything").await.unwrap();
        assert_eq!(result.unwrap().summary, "second");
    }

    // Mirrors `first_error_propagates_without_trying_second` for the confirmed path: had the
    // error been swallowed, `SecondExecutor` would have produced `Ok(Some(..))` instead.
    #[tokio::test]
    async fn execute_confirmed_first_error_propagates() {
        let composite = CompositeExecutor::new(ErrorExecutor, SecondExecutor);
        let result = composite.execute_confirmed("anything").await;
        assert!(matches!(result, Err(ToolError::Blocked { .. })));
    }

    #[test]
    fn composite_debug() {
        let composite = CompositeExecutor::new(MatchingExecutor, SecondExecutor);
        let debug = format!("{composite:?}");
        assert!(debug.contains("CompositeExecutor"));
    }

    /// Handles only the `read` and `write` tool calls.
    #[derive(Debug)]
    struct FileToolExecutor;
    impl ToolExecutor for FileToolExecutor {
        async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }
        async fn execute_tool_call(
            &self,
            call: &ToolCall,
        ) -> Result<Option<ToolOutput>, ToolError> {
            if call.tool_id == "read" || call.tool_id == "write" {
                Ok(Some(make_output(&call.tool_id, "file_handler")))
            } else {
                Ok(None)
            }
        }
    }

    /// Handles only the `bash` tool call.
    #[derive(Debug)]
    struct ShellToolExecutor;
    impl ToolExecutor for ShellToolExecutor {
        async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }
        async fn execute_tool_call(
            &self,
            call: &ToolCall,
        ) -> Result<Option<ToolOutput>, ToolError> {
            if call.tool_id == "bash" {
                Ok(Some(make_output("bash", "shell_handler")))
            } else {
                Ok(None)
            }
        }
    }

    #[tokio::test]
    async fn tool_call_routes_to_file_executor() {
        let composite = CompositeExecutor::new(FileToolExecutor, ShellToolExecutor);
        let call = ToolCall {
            tool_id: "read".to_owned(),
            params: serde_json::Map::new(),
        };
        let result = composite.execute_tool_call(&call).await.unwrap().unwrap();
        assert_eq!(result.summary, "file_handler");
    }

    #[tokio::test]
    async fn tool_call_routes_to_shell_executor() {
        let composite = CompositeExecutor::new(FileToolExecutor, ShellToolExecutor);
        let call = ToolCall {
            tool_id: "bash".to_owned(),
            params: serde_json::Map::new(),
        };
        let result = composite.execute_tool_call(&call).await.unwrap().unwrap();
        assert_eq!(result.summary, "shell_handler");
    }

    #[tokio::test]
    async fn tool_call_unhandled_returns_none() {
        let composite = CompositeExecutor::new(FileToolExecutor, ShellToolExecutor);
        let call = ToolCall {
            tool_id: "unknown".to_owned(),
            params: serde_json::Map::new(),
        };
        let result = composite.execute_tool_call(&call).await.unwrap();
        assert!(result.is_none());
    }
}