1pub mod direct;
11#[cfg(feature = "native")]
12pub mod http;
13pub mod llm;
14pub mod mcp;
15pub mod shell;
16
17use std::collections::HashMap;
18use std::sync::Arc;
19
20use async_trait::async_trait;
21use tokio::sync::mpsc;
22
23use crate::error::{SoulError, SoulResult};
24use crate::tool::ToolOutput;
25use crate::types::ToolDefinition;
26
/// A backend capable of running tool calls.
///
/// Executors are registered with [`ExecutorRegistry::register_executor`] and are looked
/// up by the string returned from [`ToolExecutor::executor_name`]. They are stored behind
/// `Arc<dyn ToolExecutor>` and shared, hence the `Send + Sync` bound.
#[async_trait]
pub trait ToolExecutor: Send + Sync {
    /// Runs one tool call.
    ///
    /// * `definition` — the definition of the tool being invoked. When dispatched for a
    ///   config tool this is that tool's own definition; when dispatched as the registry
    ///   fallback it is a synthetic definition carrying only the tool name (empty
    ///   description, `{"type": "object"}` input schema).
    /// * `call_id` — identifier of this call, passed through unchanged by the registry.
    /// * `arguments` — the call's JSON arguments.
    /// * `partial_tx` — optional channel for `String` chunks; presumably used to stream
    ///   partial output while the tool runs (TODO confirm with implementors).
    ///
    /// NOTE(review): the test executors report tool-level failure in-band as
    /// `Ok(ToolOutput::error(..))`; `Err` appears intended for executor/infrastructure
    /// failures — confirm this convention.
    async fn execute(
        &self,
        definition: &ToolDefinition,
        call_id: &str,
        arguments: serde_json::Value,
        partial_tx: Option<mpsc::UnboundedSender<String>>,
    ) -> SoulResult<ToolOutput>;

    /// Name under which this executor is registered; config tools refer to it via
    /// [`ConfigTool::executor_name`].
    fn executor_name(&self) -> &str;
}
42
/// A tool declared through configuration and bound to a named executor.
#[derive(Debug, Clone)]
pub struct ConfigTool {
    /// Definition advertised for this tool; `definition.name` is the key the registry
    /// stores it under.
    pub definition: ToolDefinition,
    /// Name of the [`ToolExecutor`] (as returned by `executor_name()`) that runs this tool.
    pub executor_name: String,
    /// Executor-specific configuration for this tool.
    ///
    /// NOTE(review): `ExecutorRegistry::execute` passes only `definition` to the executor,
    /// not this value — confirm whether executors obtain it some other way.
    pub executor_config: serde_json::Value,
}
50
/// Routes tool calls by name to registered [`ToolExecutor`]s.
///
/// A tool name that matches a registered [`ConfigTool`] is dispatched to the executor
/// that config tool names; any other tool name goes to the optional fallback executor.
pub struct ExecutorRegistry {
    /// Executors keyed by their `executor_name()`.
    executors: HashMap<String, Arc<dyn ToolExecutor>>,
    /// Config-declared tools keyed by `definition.name`.
    config_tools: HashMap<String, ConfigTool>,
    /// Executor used for tool names with no config-tool entry, if any.
    fallback: Option<Arc<dyn ToolExecutor>>,
}
59
60impl ExecutorRegistry {
61 pub fn new() -> Self {
62 Self {
63 executors: HashMap::new(),
64 config_tools: HashMap::new(),
65 fallback: None,
66 }
67 }
68
69 pub fn register_executor(&mut self, executor: Arc<dyn ToolExecutor>) {
71 self.executors
72 .insert(executor.executor_name().to_string(), executor);
73 }
74
75 pub fn register_config_tool(&mut self, tool: ConfigTool) {
77 self.config_tools.insert(tool.definition.name.clone(), tool);
78 }
79
80 pub fn set_fallback(&mut self, executor: Arc<dyn ToolExecutor>) {
82 self.fallback = Some(executor);
83 }
84
85 pub async fn execute(
87 &self,
88 tool_name: &str,
89 call_id: &str,
90 arguments: serde_json::Value,
91 partial_tx: Option<mpsc::UnboundedSender<String>>,
92 ) -> SoulResult<ToolOutput> {
93 if let Some(config_tool) = self.config_tools.get(tool_name) {
95 if let Some(executor) = self.executors.get(&config_tool.executor_name) {
96 return executor
97 .execute(&config_tool.definition, call_id, arguments, partial_tx)
98 .await;
99 }
100 return Err(SoulError::ExecutorNotFound {
101 name: config_tool.executor_name.clone(),
102 });
103 }
104
105 if let Some(fallback) = &self.fallback {
107 let def = ToolDefinition {
108 name: tool_name.to_string(),
109 description: String::new(),
110 input_schema: serde_json::json!({"type": "object"}),
111 };
112 return fallback.execute(&def, call_id, arguments, partial_tx).await;
113 }
114
115 Err(SoulError::ToolExecution {
116 tool_name: tool_name.to_string(),
117 message: format!("No executor found for tool '{tool_name}'"),
118 })
119 }
120
121 pub fn definitions(&self) -> Vec<ToolDefinition> {
123 self.config_tools
124 .values()
125 .map(|ct| ct.definition.clone())
126 .collect()
127 }
128
129 pub fn has_tool(&self, name: &str) -> bool {
131 self.config_tools.contains_key(name) || self.fallback.is_some()
132 }
133}
134
135impl Default for ExecutorRegistry {
136 fn default() -> Self {
137 Self::new()
138 }
139}
140
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Succeeds with `"<tool name>(<arguments>)"` so tests can see what was routed where.
    struct EchoExecutor;

    #[async_trait]
    impl ToolExecutor for EchoExecutor {
        async fn execute(
            &self,
            definition: &ToolDefinition,
            _call_id: &str,
            arguments: serde_json::Value,
            _partial_tx: Option<mpsc::UnboundedSender<String>>,
        ) -> SoulResult<ToolOutput> {
            Ok(ToolOutput::success(format!(
                "{}({})",
                definition.name, arguments
            )))
        }

        fn executor_name(&self) -> &str {
            "echo"
        }
    }

    /// Always reports a tool-level failure in-band (`Ok(ToolOutput::error(..))`).
    struct FailExecutor;

    #[async_trait]
    impl ToolExecutor for FailExecutor {
        async fn execute(
            &self,
            _definition: &ToolDefinition,
            _call_id: &str,
            _arguments: serde_json::Value,
            _partial_tx: Option<mpsc::UnboundedSender<String>>,
        ) -> SoulResult<ToolOutput> {
            Ok(ToolOutput::error("always fails"))
        }

        fn executor_name(&self) -> &str {
            "fail"
        }
    }

    /// Builds a config tool named `name` that is bound to the executor named `executor`.
    fn config_tool(name: &str, executor: &str) -> ConfigTool {
        ConfigTool {
            definition: ToolDefinition {
                name: name.into(),
                description: format!("Tool {name}"),
                input_schema: json!({"type": "object"}),
            },
            executor_name: executor.into(),
            executor_config: json!({}),
        }
    }

    #[tokio::test]
    async fn routes_to_correct_executor() {
        let mut registry = ExecutorRegistry::new();
        registry.register_executor(Arc::new(EchoExecutor));
        registry.register_config_tool(config_tool("my_tool", "echo"));

        let result = registry
            .execute("my_tool", "c1", json!({"a": 1}), None)
            .await
            .unwrap();
        assert!(result.content.contains("my_tool"));
        assert!(!result.is_error);
    }

    #[tokio::test]
    async fn missing_executor_errors() {
        let mut registry = ExecutorRegistry::new();
        registry.register_config_tool(config_tool("my_tool", "nonexistent"));

        let result = registry.execute("my_tool", "c1", json!({}), None).await;
        assert!(matches!(
            &result,
            Err(SoulError::ExecutorNotFound { name }) if name == "nonexistent"
        ));
    }

    #[tokio::test]
    async fn missing_executor_does_not_use_fallback() {
        let mut registry = ExecutorRegistry::new();
        registry.set_fallback(Arc::new(EchoExecutor));
        registry.register_config_tool(config_tool("my_tool", "nonexistent"));

        let result = registry.execute("my_tool", "c1", json!({}), None).await;
        assert!(matches!(&result, Err(SoulError::ExecutorNotFound { .. })));
    }

    #[tokio::test]
    async fn unknown_tool_errors() {
        let registry = ExecutorRegistry::new();
        let result = registry.execute("unknown", "c1", json!({}), None).await;
        assert!(matches!(
            &result,
            Err(SoulError::ToolExecution { tool_name, .. }) if tool_name == "unknown"
        ));
    }

    #[tokio::test]
    async fn fallback_executor() {
        let mut registry = ExecutorRegistry::new();
        registry.set_fallback(Arc::new(EchoExecutor));

        let result = registry
            .execute("anything", "c1", json!({}), None)
            .await
            .unwrap();
        assert!(!result.is_error);
        // The fallback receives a synthetic definition carrying the requested name.
        assert!(result.content.contains("anything"));
    }

    #[tokio::test]
    async fn config_tool_takes_priority_over_fallback() {
        let mut registry = ExecutorRegistry::new();
        registry.register_executor(Arc::new(FailExecutor));
        registry.set_fallback(Arc::new(EchoExecutor));
        registry.register_config_tool(config_tool("my_tool", "fail"));

        let result = registry
            .execute("my_tool", "c1", json!({}), None)
            .await
            .unwrap();
        assert!(result.is_error);
    }

    #[tokio::test]
    async fn reregistering_config_tool_replaces_it() {
        let mut registry = ExecutorRegistry::new();
        registry.register_executor(Arc::new(EchoExecutor));
        registry.register_executor(Arc::new(FailExecutor));
        registry.register_config_tool(config_tool("my_tool", "fail"));
        registry.register_config_tool(config_tool("my_tool", "echo"));

        assert_eq!(registry.definitions().len(), 1);
        let result = registry
            .execute("my_tool", "c1", json!({}), None)
            .await
            .unwrap();
        assert!(!result.is_error);
    }

    #[test]
    fn definitions_returns_config_tools() {
        let mut registry = ExecutorRegistry::new();
        registry.register_config_tool(config_tool("tool_a", "echo"));
        registry.register_config_tool(config_tool("tool_b", "echo"));

        let defs = registry.definitions();
        let mut names: Vec<&str> = defs.iter().map(|d| d.name.as_str()).collect();
        names.sort_unstable();
        assert_eq!(names, ["tool_a", "tool_b"]);
    }

    #[test]
    fn has_tool_checks_config_and_fallback() {
        let mut registry = ExecutorRegistry::new();
        assert!(!registry.has_tool("anything"));

        registry.register_config_tool(config_tool("my_tool", "echo"));
        assert!(registry.has_tool("my_tool"));
        assert!(!registry.has_tool("other"));

        registry.set_fallback(Arc::new(EchoExecutor));
        assert!(registry.has_tool("other"));
    }

    #[tokio::test]
    async fn multiple_executors() {
        let mut registry = ExecutorRegistry::new();
        registry.register_executor(Arc::new(EchoExecutor));
        registry.register_executor(Arc::new(FailExecutor));
        registry.register_config_tool(config_tool("echo_tool", "echo"));
        registry.register_config_tool(config_tool("fail_tool", "fail"));

        let r1 = registry
            .execute("echo_tool", "c1", json!({}), None)
            .await
            .unwrap();
        assert!(!r1.is_error);

        let r2 = registry
            .execute("fail_tool", "c2", json!({}), None)
            .await
            .unwrap();
        assert!(r2.is_error);
    }
}