1pub mod direct;
11pub mod http;
12pub mod llm;
13pub mod mcp;
14pub mod shell;
15
16use std::collections::HashMap;
17use std::sync::Arc;
18
19use async_trait::async_trait;
20use tokio::sync::mpsc;
21
22use crate::error::{SoulError, SoulResult};
23use crate::tool::ToolOutput;
24use crate::types::ToolDefinition;
25
26#[async_trait]
28pub trait ToolExecutor: Send + Sync {
29 async fn execute(
31 &self,
32 definition: &ToolDefinition,
33 call_id: &str,
34 arguments: serde_json::Value,
35 partial_tx: Option<mpsc::UnboundedSender<String>>,
36 ) -> SoulResult<ToolOutput>;
37
38 fn executor_name(&self) -> &str;
40}
41
42#[derive(Debug, Clone)]
44pub struct ConfigTool {
45 pub definition: ToolDefinition,
46 pub executor_name: String,
47 pub executor_config: serde_json::Value,
48}
49
50pub struct ExecutorRegistry {
54 executors: HashMap<String, Arc<dyn ToolExecutor>>,
55 config_tools: HashMap<String, ConfigTool>,
56 fallback: Option<Arc<dyn ToolExecutor>>,
57}
58
59impl ExecutorRegistry {
60 pub fn new() -> Self {
61 Self {
62 executors: HashMap::new(),
63 config_tools: HashMap::new(),
64 fallback: None,
65 }
66 }
67
68 pub fn register_executor(&mut self, executor: Arc<dyn ToolExecutor>) {
70 self.executors
71 .insert(executor.executor_name().to_string(), executor);
72 }
73
74 pub fn register_config_tool(&mut self, tool: ConfigTool) {
76 self.config_tools.insert(tool.definition.name.clone(), tool);
77 }
78
79 pub fn set_fallback(&mut self, executor: Arc<dyn ToolExecutor>) {
81 self.fallback = Some(executor);
82 }
83
84 pub async fn execute(
86 &self,
87 tool_name: &str,
88 call_id: &str,
89 arguments: serde_json::Value,
90 partial_tx: Option<mpsc::UnboundedSender<String>>,
91 ) -> SoulResult<ToolOutput> {
92 if let Some(config_tool) = self.config_tools.get(tool_name) {
94 if let Some(executor) = self.executors.get(&config_tool.executor_name) {
95 return executor
96 .execute(&config_tool.definition, call_id, arguments, partial_tx)
97 .await;
98 }
99 return Err(SoulError::ExecutorNotFound {
100 name: config_tool.executor_name.clone(),
101 });
102 }
103
104 if let Some(fallback) = &self.fallback {
106 let def = ToolDefinition {
107 name: tool_name.to_string(),
108 description: String::new(),
109 input_schema: serde_json::json!({"type": "object"}),
110 };
111 return fallback.execute(&def, call_id, arguments, partial_tx).await;
112 }
113
114 Err(SoulError::ToolExecution {
115 tool_name: tool_name.to_string(),
116 message: format!("No executor found for tool '{tool_name}'"),
117 })
118 }
119
120 pub fn definitions(&self) -> Vec<ToolDefinition> {
122 self.config_tools
123 .values()
124 .map(|ct| ct.definition.clone())
125 .collect()
126 }
127
128 pub fn has_tool(&self, name: &str) -> bool {
130 self.config_tools.contains_key(name) || self.fallback.is_some()
131 }
132}
133
134impl Default for ExecutorRegistry {
135 fn default() -> Self {
136 Self::new()
137 }
138}
139
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Executor named "echo" that always succeeds and reports
    /// `<tool name>(<arguments>)` as its content.
    struct EchoExecutor;

    #[async_trait]
    impl ToolExecutor for EchoExecutor {
        async fn execute(
            &self,
            definition: &ToolDefinition,
            _call_id: &str,
            arguments: serde_json::Value,
            _partial_tx: Option<mpsc::UnboundedSender<String>>,
        ) -> SoulResult<ToolOutput> {
            Ok(ToolOutput::success(format!(
                "{}({})",
                definition.name, arguments
            )))
        }

        fn executor_name(&self) -> &str {
            "echo"
        }
    }

    /// Executor named "fail" that always returns an error *output* (the call
    /// itself still resolves to `Ok`).
    struct FailExecutor;

    #[async_trait]
    impl ToolExecutor for FailExecutor {
        async fn execute(
            &self,
            _definition: &ToolDefinition,
            _call_id: &str,
            _arguments: serde_json::Value,
            _partial_tx: Option<mpsc::UnboundedSender<String>>,
        ) -> SoulResult<ToolOutput> {
            Ok(ToolOutput::error("always fails"))
        }

        fn executor_name(&self) -> &str {
            "fail"
        }
    }

    /// Builds a configured tool called `name` that is bound to `executor`.
    fn config_tool(name: &str, executor: &str) -> ConfigTool {
        ConfigTool {
            definition: ToolDefinition {
                name: name.into(),
                description: format!("Tool {name}"),
                input_schema: json!({"type": "object"}),
            },
            executor_name: executor.into(),
            executor_config: json!({}),
        }
    }

    #[tokio::test]
    async fn routes_to_correct_executor() {
        let mut registry = ExecutorRegistry::new();
        registry.register_executor(Arc::new(EchoExecutor));
        registry.register_config_tool(config_tool("my_tool", "echo"));

        let result = registry
            .execute("my_tool", "c1", json!({"a": 1}), None)
            .await
            .unwrap();
        assert!(result.content.contains("my_tool"));
        assert!(!result.is_error);
    }

    #[tokio::test]
    async fn missing_executor_errors() {
        let mut registry = ExecutorRegistry::new();
        registry.register_config_tool(config_tool("my_tool", "nonexistent"));

        let result = registry.execute("my_tool", "c1", json!({}), None).await;
        // Pin the variant and payload, not just "some error".
        assert!(matches!(
            &result,
            Err(SoulError::ExecutorNotFound { name }) if name == "nonexistent"
        ));
    }

    #[tokio::test]
    async fn unknown_tool_errors() {
        let registry = ExecutorRegistry::new();
        let result = registry.execute("unknown", "c1", json!({}), None).await;
        assert!(matches!(
            &result,
            Err(SoulError::ToolExecution { tool_name, .. }) if tool_name == "unknown"
        ));
    }

    #[tokio::test]
    async fn fallback_executor() {
        let mut registry = ExecutorRegistry::new();
        registry.set_fallback(Arc::new(EchoExecutor));

        let result = registry
            .execute("anything", "c1", json!({}), None)
            .await
            .unwrap();
        assert!(!result.is_error);
        // The synthesized definition must carry the requested tool name.
        assert!(result.content.contains("anything"));
    }

    #[tokio::test]
    async fn config_tool_takes_priority_over_fallback() {
        let mut registry = ExecutorRegistry::new();
        registry.register_executor(Arc::new(FailExecutor));
        registry.set_fallback(Arc::new(EchoExecutor));
        registry.register_config_tool(config_tool("my_tool", "fail"));

        let result = registry
            .execute("my_tool", "c1", json!({}), None)
            .await
            .unwrap();
        // An error output proves FailExecutor ran rather than the echo fallback.
        assert!(result.is_error);
    }

    #[test]
    fn definitions_returns_config_tools() {
        let mut registry = ExecutorRegistry::new();
        registry.register_config_tool(config_tool("tool_a", "echo"));
        registry.register_config_tool(config_tool("tool_b", "echo"));

        let defs = registry.definitions();
        assert_eq!(defs.len(), 2);

        // Compare names order-independently so the check does not rely on the
        // order in which the registry happens to return them.
        let mut names: Vec<&str> = defs.iter().map(|d| d.name.as_str()).collect();
        names.sort_unstable();
        assert_eq!(names, ["tool_a", "tool_b"]);
    }

    #[test]
    fn has_tool_checks_config_and_fallback() {
        let mut registry = ExecutorRegistry::new();
        assert!(!registry.has_tool("anything"));

        registry.register_config_tool(config_tool("my_tool", "echo"));
        assert!(registry.has_tool("my_tool"));
        assert!(!registry.has_tool("other"));

        // Once a fallback exists, every tool name is considered available.
        registry.set_fallback(Arc::new(EchoExecutor));
        assert!(registry.has_tool("other"));
    }

    #[tokio::test]
    async fn multiple_executors() {
        let mut registry = ExecutorRegistry::new();
        registry.register_executor(Arc::new(EchoExecutor));
        registry.register_executor(Arc::new(FailExecutor));
        registry.register_config_tool(config_tool("echo_tool", "echo"));
        registry.register_config_tool(config_tool("fail_tool", "fail"));

        let r1 = registry
            .execute("echo_tool", "c1", json!({}), None)
            .await
            .unwrap();
        assert!(!r1.is_error);

        let r2 = registry
            .execute("fail_tool", "c2", json!({}), None)
            .await
            .unwrap();
        assert!(r2.is_error);
    }
}