1pub mod direct;
11#[cfg(feature = "native")]
12pub mod http;
13pub mod llm;
14pub mod mcp;
15pub mod shell;
16
17use std::collections::HashMap;
18use std::sync::Arc;
19
20#[cfg(test)]
21use async_trait::async_trait;
22use tokio::sync::mpsc;
23
24use crate::error::{SoulError, SoulResult};
25use crate::tool::ToolOutput;
26use crate::types::ToolDefinition;
27
28#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
30#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
31pub trait ToolExecutor: Send + Sync {
32 async fn execute(
34 &self,
35 definition: &ToolDefinition,
36 call_id: &str,
37 arguments: serde_json::Value,
38 partial_tx: Option<mpsc::UnboundedSender<String>>,
39 ) -> SoulResult<ToolOutput>;
40
41 fn executor_name(&self) -> &str;
43}
44
45#[derive(Debug, Clone)]
47pub struct ConfigTool {
48 pub definition: ToolDefinition,
49 pub executor_name: String,
50 pub executor_config: serde_json::Value,
51}
52
53pub struct ExecutorRegistry {
57 executors: HashMap<String, Arc<dyn ToolExecutor>>,
58 config_tools: HashMap<String, ConfigTool>,
59 fallback: Option<Arc<dyn ToolExecutor>>,
60}
61
62impl ExecutorRegistry {
63 pub fn new() -> Self {
64 Self {
65 executors: HashMap::new(),
66 config_tools: HashMap::new(),
67 fallback: None,
68 }
69 }
70
71 pub fn register_executor(&mut self, executor: Arc<dyn ToolExecutor>) {
73 self.executors
74 .insert(executor.executor_name().to_string(), executor);
75 }
76
77 pub fn register_config_tool(&mut self, tool: ConfigTool) {
79 self.config_tools.insert(tool.definition.name.clone(), tool);
80 }
81
82 pub fn set_fallback(&mut self, executor: Arc<dyn ToolExecutor>) {
84 self.fallback = Some(executor);
85 }
86
87 pub async fn execute(
89 &self,
90 tool_name: &str,
91 call_id: &str,
92 arguments: serde_json::Value,
93 partial_tx: Option<mpsc::UnboundedSender<String>>,
94 ) -> SoulResult<ToolOutput> {
95 if let Some(config_tool) = self.config_tools.get(tool_name) {
97 if let Some(executor) = self.executors.get(&config_tool.executor_name) {
98 return executor
99 .execute(&config_tool.definition, call_id, arguments, partial_tx)
100 .await;
101 }
102 return Err(SoulError::ExecutorNotFound {
103 name: config_tool.executor_name.clone(),
104 });
105 }
106
107 if let Some(fallback) = &self.fallback {
109 let def = ToolDefinition {
110 name: tool_name.to_string(),
111 description: String::new(),
112 input_schema: serde_json::json!({"type": "object"}),
113 };
114 return fallback.execute(&def, call_id, arguments, partial_tx).await;
115 }
116
117 Err(SoulError::ToolExecution {
118 tool_name: tool_name.to_string(),
119 message: format!("No executor found for tool '{tool_name}'"),
120 })
121 }
122
123 pub fn definitions(&self) -> Vec<ToolDefinition> {
125 self.config_tools
126 .values()
127 .map(|ct| ct.definition.clone())
128 .collect()
129 }
130
131 pub fn has_tool(&self, name: &str) -> bool {
133 self.config_tools.contains_key(name) || self.fallback.is_some()
134 }
135}
136
137impl Default for ExecutorRegistry {
138 fn default() -> Self {
139 Self::new()
140 }
141}
142
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Succeeds with `"<tool name>(<arguments>)"` as the output content.
    struct EchoExecutor;

    #[async_trait]
    impl ToolExecutor for EchoExecutor {
        async fn execute(
            &self,
            definition: &ToolDefinition,
            _call_id: &str,
            arguments: serde_json::Value,
            _partial_tx: Option<mpsc::UnboundedSender<String>>,
        ) -> SoulResult<ToolOutput> {
            Ok(ToolOutput::success(format!(
                "{}({})",
                definition.name, arguments
            )))
        }

        fn executor_name(&self) -> &str {
            "echo"
        }
    }

    /// Always returns an `Ok` output that is flagged as an error.
    struct FailExecutor;

    #[async_trait]
    impl ToolExecutor for FailExecutor {
        async fn execute(
            &self,
            _definition: &ToolDefinition,
            _call_id: &str,
            _arguments: serde_json::Value,
            _partial_tx: Option<mpsc::UnboundedSender<String>>,
        ) -> SoulResult<ToolOutput> {
            Ok(ToolOutput::error("always fails"))
        }

        fn executor_name(&self) -> &str {
            "fail"
        }
    }

    /// Builds a config tool named `name` bound to the executor `executor`.
    fn config_tool(name: &str, executor: &str) -> ConfigTool {
        ConfigTool {
            definition: ToolDefinition {
                name: name.into(),
                description: format!("Tool {name}"),
                input_schema: json!({"type": "object"}),
            },
            executor_name: executor.into(),
            executor_config: json!({}),
        }
    }

    #[tokio::test]
    async fn routes_to_correct_executor() {
        let mut registry = ExecutorRegistry::new();
        registry.register_executor(Arc::new(EchoExecutor));
        registry.register_config_tool(config_tool("my_tool", "echo"));

        let result = registry
            .execute("my_tool", "c1", json!({"a": 1}), None)
            .await
            .unwrap();
        assert!(result.content.contains("my_tool"));
        assert!(!result.is_error);
    }

    #[tokio::test]
    async fn missing_executor_errors() {
        let mut registry = ExecutorRegistry::new();
        registry.register_config_tool(config_tool("my_tool", "nonexistent"));

        let result = registry.execute("my_tool", "c1", json!({}), None).await;
        assert!(matches!(
            &result,
            Err(SoulError::ExecutorNotFound { name }) if name == "nonexistent"
        ));
    }

    #[tokio::test]
    async fn missing_executor_does_not_use_fallback() {
        let mut registry = ExecutorRegistry::new();
        registry.set_fallback(Arc::new(EchoExecutor));
        registry.register_config_tool(config_tool("my_tool", "nonexistent"));

        let result = registry.execute("my_tool", "c1", json!({}), None).await;
        assert!(matches!(
            &result,
            Err(SoulError::ExecutorNotFound { .. })
        ));
    }

    #[tokio::test]
    async fn unknown_tool_errors() {
        let registry = ExecutorRegistry::new();
        let result = registry.execute("unknown", "c1", json!({}), None).await;
        assert!(matches!(
            &result,
            Err(SoulError::ToolExecution { tool_name, .. }) if tool_name == "unknown"
        ));
    }

    #[tokio::test]
    async fn fallback_executor() {
        let mut registry = ExecutorRegistry::new();
        registry.set_fallback(Arc::new(EchoExecutor));

        let result = registry
            .execute("anything", "c1", json!({}), None)
            .await
            .unwrap();
        assert!(!result.is_error);
        // The echo output proves the fallback ran with the requested name.
        assert!(result.content.contains("anything"));
    }

    #[tokio::test]
    async fn config_tool_takes_priority_over_fallback() {
        let mut registry = ExecutorRegistry::new();
        registry.register_executor(Arc::new(FailExecutor));
        registry.set_fallback(Arc::new(EchoExecutor));
        registry.register_config_tool(config_tool("my_tool", "fail"));

        let result = registry
            .execute("my_tool", "c1", json!({}), None)
            .await
            .unwrap();
        assert!(result.is_error);
    }

    #[test]
    fn definitions_returns_config_tools() {
        let mut registry = ExecutorRegistry::new();
        registry.register_config_tool(config_tool("tool_a", "echo"));
        registry.register_config_tool(config_tool("tool_b", "echo"));

        let mut names: Vec<String> = registry
            .definitions()
            .into_iter()
            .map(|d| d.name)
            .collect();
        // Sort locally so the assertion does not depend on the order returned.
        names.sort();
        assert_eq!(names, ["tool_a", "tool_b"]);
    }

    #[test]
    fn has_tool_checks_config_and_fallback() {
        let mut registry = ExecutorRegistry::new();
        assert!(!registry.has_tool("anything"));

        registry.register_config_tool(config_tool("my_tool", "echo"));
        assert!(registry.has_tool("my_tool"));
        assert!(!registry.has_tool("other"));

        registry.set_fallback(Arc::new(EchoExecutor));
        assert!(registry.has_tool("other"));
    }

    #[tokio::test]
    async fn multiple_executors() {
        let mut registry = ExecutorRegistry::new();
        registry.register_executor(Arc::new(EchoExecutor));
        registry.register_executor(Arc::new(FailExecutor));
        registry.register_config_tool(config_tool("echo_tool", "echo"));
        registry.register_config_tool(config_tool("fail_tool", "fail"));

        let r1 = registry
            .execute("echo_tool", "c1", json!({}), None)
            .await
            .unwrap();
        assert!(!r1.is_error);

        let r2 = registry
            .execute("fail_tool", "c2", json!({}), None)
            .await
            .unwrap();
        assert!(r2.is_error);
    }
}