1use crate::executor::{ToolCall, ToolError, ToolExecutor, ToolOutput};
5use crate::registry::ToolDef;
6
/// Tool executor that chains two executors, trying `first` before `second`.
///
/// For every entry point, `first` is consulted and its result is returned if it
/// produced output; only when it reports "no match" (`Ok(None)`) is `second`
/// tried. The type itself carries no trait bounds — they live on the `impl`
/// blocks that need them — so it can be named (and `Debug`-formatted) without
/// dragging `ToolExecutor` bounds into every signature that mentions it.
#[derive(Debug)]
pub struct CompositeExecutor<A, B> {
    /// Executor consulted first; its output takes precedence.
    first: A,
    /// Fallback executor, consulted only when `first` yields no output.
    second: B,
}
16
17impl<A: ToolExecutor, B: ToolExecutor> CompositeExecutor<A, B> {
18 #[must_use]
19 pub fn new(first: A, second: B) -> Self {
20 Self { first, second }
21 }
22}
23
24impl<A: ToolExecutor, B: ToolExecutor> ToolExecutor for CompositeExecutor<A, B> {
25 async fn execute(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
26 if let Some(output) = self.first.execute(response).await? {
27 return Ok(Some(output));
28 }
29 self.second.execute(response).await
30 }
31
32 async fn execute_confirmed(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
33 if let Some(output) = self.first.execute_confirmed(response).await? {
34 return Ok(Some(output));
35 }
36 self.second.execute_confirmed(response).await
37 }
38
39 fn tool_definitions(&self) -> Vec<ToolDef> {
40 let mut defs = self.first.tool_definitions();
41 let seen: std::collections::HashSet<&str> = defs.iter().map(|d| d.id).collect();
42 for def in self.second.tool_definitions() {
43 if !seen.contains(def.id) {
44 defs.push(def);
45 }
46 }
47 defs
48 }
49
50 async fn execute_tool_call(&self, call: &ToolCall) -> Result<Option<ToolOutput>, ToolError> {
51 if let Some(output) = self.first.execute_tool_call(call).await? {
52 return Ok(Some(output));
53 }
54 self.second.execute_tool_call(call).await
55 }
56}
57
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `ToolOutput` with the given name and summary; the remaining
    /// fields use the fixed values shared by every stub executor below.
    fn stub_output(tool_name: &str, summary: &str) -> ToolOutput {
        ToolOutput {
            tool_name: tool_name.to_owned(),
            summary: summary.to_owned(),
            blocks_executed: 1,
            filter_stats: None,
            diff: None,
            streamed: false,
        }
    }

    /// Builds a `ToolCall` for `tool_id` with an empty parameter map.
    fn call_for(tool_id: &str) -> ToolCall {
        ToolCall {
            tool_id: tool_id.to_owned(),
            params: serde_json::Map::new(),
        }
    }

    /// Matches every text response, reporting summary "matched".
    #[derive(Debug)]
    struct MatchingExecutor;
    impl ToolExecutor for MatchingExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(Some(stub_output("test", "matched")))
        }
    }

    /// Never matches a text response.
    #[derive(Debug)]
    struct NoMatchExecutor;
    impl ToolExecutor for NoMatchExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }
    }

    /// Fails every text response with `ToolError::Blocked`.
    #[derive(Debug)]
    struct ErrorExecutor;
    impl ToolExecutor for ErrorExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Err(ToolError::Blocked {
                command: "test".to_owned(),
            })
        }
    }

    /// Matches every text response, reporting summary "second".
    #[derive(Debug)]
    struct SecondExecutor;
    impl ToolExecutor for SecondExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(Some(stub_output("test", "second")))
        }
    }

    #[tokio::test]
    async fn first_matches_returns_first() {
        let composite = CompositeExecutor::new(MatchingExecutor, SecondExecutor);
        let result = composite.execute("anything").await.unwrap();
        assert_eq!(result.unwrap().summary, "matched");
    }

    #[tokio::test]
    async fn first_none_falls_through_to_second() {
        let composite = CompositeExecutor::new(NoMatchExecutor, SecondExecutor);
        let result = composite.execute("anything").await.unwrap();
        assert_eq!(result.unwrap().summary, "second");
    }

    #[tokio::test]
    async fn both_none_returns_none() {
        let composite = CompositeExecutor::new(NoMatchExecutor, NoMatchExecutor);
        let result = composite.execute("anything").await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn first_error_propagates_without_trying_second() {
        let composite = CompositeExecutor::new(ErrorExecutor, SecondExecutor);
        let result = composite.execute("anything").await;
        assert!(matches!(result, Err(ToolError::Blocked { .. })));
    }

    #[tokio::test]
    async fn second_error_propagates_when_first_none() {
        let composite = CompositeExecutor::new(NoMatchExecutor, ErrorExecutor);
        let result = composite.execute("anything").await;
        assert!(matches!(result, Err(ToolError::Blocked { .. })));
    }

    #[tokio::test]
    async fn execute_confirmed_first_matches() {
        let composite = CompositeExecutor::new(MatchingExecutor, SecondExecutor);
        let result = composite.execute_confirmed("anything").await.unwrap();
        assert_eq!(result.unwrap().summary, "matched");
    }

    #[tokio::test]
    async fn execute_confirmed_falls_through() {
        let composite = CompositeExecutor::new(NoMatchExecutor, SecondExecutor);
        let result = composite.execute_confirmed("anything").await.unwrap();
        assert_eq!(result.unwrap().summary, "second");
    }

    #[tokio::test]
    async fn execute_confirmed_both_none_returns_none() {
        let composite = CompositeExecutor::new(NoMatchExecutor, NoMatchExecutor);
        let result = composite.execute_confirmed("anything").await.unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn composite_debug() {
        let composite = CompositeExecutor::new(MatchingExecutor, SecondExecutor);
        let debug = format!("{composite:?}");
        assert!(debug.contains("CompositeExecutor"));
    }

    /// Handles structured calls to "read" and "write"; ignores text responses.
    #[derive(Debug)]
    struct FileToolExecutor;
    impl ToolExecutor for FileToolExecutor {
        async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }
        async fn execute_tool_call(
            &self,
            call: &ToolCall,
        ) -> Result<Option<ToolOutput>, ToolError> {
            if call.tool_id == "read" || call.tool_id == "write" {
                Ok(Some(stub_output(&call.tool_id, "file_handler")))
            } else {
                Ok(None)
            }
        }
    }

    /// Handles structured calls to "bash"; ignores text responses.
    #[derive(Debug)]
    struct ShellToolExecutor;
    impl ToolExecutor for ShellToolExecutor {
        async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }
        async fn execute_tool_call(
            &self,
            call: &ToolCall,
        ) -> Result<Option<ToolOutput>, ToolError> {
            if call.tool_id == "bash" {
                Ok(Some(stub_output("bash", "shell_handler")))
            } else {
                Ok(None)
            }
        }
    }

    /// Ignores text responses and fails every structured tool call.
    #[derive(Debug)]
    struct ToolCallErrorExecutor;
    impl ToolExecutor for ToolCallErrorExecutor {
        async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }
        async fn execute_tool_call(
            &self,
            _call: &ToolCall,
        ) -> Result<Option<ToolOutput>, ToolError> {
            Err(ToolError::Blocked {
                command: "test".to_owned(),
            })
        }
    }

    #[tokio::test]
    async fn tool_call_routes_to_file_executor() {
        let composite = CompositeExecutor::new(FileToolExecutor, ShellToolExecutor);
        let result = composite
            .execute_tool_call(&call_for("read"))
            .await
            .unwrap()
            .unwrap();
        assert_eq!(result.summary, "file_handler");
    }

    #[tokio::test]
    async fn tool_call_routes_to_shell_executor() {
        let composite = CompositeExecutor::new(FileToolExecutor, ShellToolExecutor);
        let result = composite
            .execute_tool_call(&call_for("bash"))
            .await
            .unwrap()
            .unwrap();
        assert_eq!(result.summary, "shell_handler");
    }

    #[tokio::test]
    async fn tool_call_unhandled_returns_none() {
        let composite = CompositeExecutor::new(FileToolExecutor, ShellToolExecutor);
        let result = composite
            .execute_tool_call(&call_for("unknown"))
            .await
            .unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn tool_call_first_error_propagates_without_trying_second() {
        // "bash" would be handled by the shell executor, but the first
        // executor's error must short-circuit before it is consulted.
        let composite = CompositeExecutor::new(ToolCallErrorExecutor, ShellToolExecutor);
        let result = composite.execute_tool_call(&call_for("bash")).await;
        assert!(matches!(result, Err(ToolError::Blocked { .. })));
    }

    #[tokio::test]
    async fn tool_call_second_not_consulted_when_first_handles() {
        let composite = CompositeExecutor::new(FileToolExecutor, ToolCallErrorExecutor);
        let result = composite
            .execute_tool_call(&call_for("read"))
            .await
            .unwrap()
            .unwrap();
        assert_eq!(result.summary, "file_handler");
    }

    #[tokio::test]
    async fn tool_call_second_error_propagates_when_first_none() {
        let composite = CompositeExecutor::new(FileToolExecutor, ToolCallErrorExecutor);
        let result = composite.execute_tool_call(&call_for("bash")).await;
        assert!(matches!(result, Err(ToolError::Blocked { .. })));
    }
}