1#[cfg(test)]
2use async_trait::async_trait;
3use std::sync::Arc;
4use tokio::sync::mpsc;
5
6use crate::error::SoulResult;
7use crate::types::ToolDefinition;
8
/// A capability that can be looked up by name and invoked with JSON arguments.
///
/// The trait requires `Send + Sync` so implementations can be shared, e.g. as
/// `Arc<dyn Tool>` in a [`DynamicToolHandle`]. On `wasm32` the trait is expanded
/// with `async_trait(?Send)`, so the futures returned by `execute` are not
/// required to be `Send` there.
#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
pub trait Tool: Send + Sync {
    /// Name the tool is matched against by [`ToolRegistry::get`] and
    /// [`ToolRegistry::get_dynamic`]. The registry does not enforce uniqueness.
    fn name(&self) -> &str;

    /// Description of the tool as returned by [`ToolRegistry::definitions`].
    fn definition(&self) -> ToolDefinition;

    /// Runs the tool once.
    ///
    /// - `call_id`: identifier of this invocation (presumably the id of the
    ///   model's tool call; TODO confirm with the caller).
    /// - `arguments`: JSON arguments; presumably shaped per
    ///   `definition().input_schema` (not validated here).
    /// - `partial_tx`: optional channel on which the tool may send partial
    ///   output strings before it finishes.
    ///
    /// Tool-level failures can be reported as `Ok(ToolOutput::error(..))`.
    /// NOTE(review): `Err` presumably signals infrastructure failures rather
    /// than tool failures; confirm against callers.
    async fn execute(
        &self,
        call_id: &str,
        arguments: serde_json::Value,
        partial_tx: Option<mpsc::UnboundedSender<String>>,
    ) -> SoulResult<ToolOutput>;
}
27
28#[derive(Debug, Clone)]
30pub struct ToolOutput {
31 pub content: String,
32 pub is_error: bool,
33 pub metadata: serde_json::Value,
34}
35
36impl ToolOutput {
37 pub fn success(content: impl Into<String>) -> Self {
38 Self {
39 content: content.into(),
40 is_error: false,
41 metadata: serde_json::Value::Null,
42 }
43 }
44
45 pub fn error(content: impl Into<String>) -> Self {
46 Self {
47 content: content.into(),
48 is_error: true,
49 metadata: serde_json::Value::Null,
50 }
51 }
52
53 pub fn with_metadata(mut self, metadata: serde_json::Value) -> Self {
54 self.metadata = metadata;
55 self
56 }
57}
58
/// Cloneable handle for adding tools to a [`ToolRegistry`] after it was built.
///
/// Obtained from [`ToolRegistry::dynamic_handle`]. The handle holds an `Arc` to
/// the registry's dynamic tool list, so the handle, all of its clones and the
/// registry itself see the same tools.
#[derive(Clone)]
pub struct DynamicToolHandle {
    /// Shared with the `dynamic` field of the originating registry.
    inner: Arc<std::sync::RwLock<Vec<Arc<dyn Tool>>>>,
}
68
69impl DynamicToolHandle {
70 pub fn register(&self, tool: Arc<dyn Tool>) {
72 let mut tools = self.inner.write().unwrap();
73 tools.push(tool);
74 }
75
76 pub fn len(&self) -> usize {
78 self.inner.read().unwrap().len()
79 }
80
81 pub fn is_empty(&self) -> bool {
83 self.inner.read().unwrap().is_empty()
84 }
85}
86
87impl std::fmt::Debug for DynamicToolHandle {
88 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89 let count = self.inner.read().map(|v| v.len()).unwrap_or(0);
90 f.debug_struct("DynamicToolHandle")
91 .field("count", &count)
92 .finish()
93 }
94}
95
/// Collection of tools, split into two sources.
///
/// - Static tools are owned and added through [`ToolRegistry::register`]
///   (which needs `&mut self`).
/// - Dynamic tools live behind a shared lock and are added through a
///   [`DynamicToolHandle`] from [`ToolRegistry::dynamic_handle`]. They can be
///   added without mutable access to the registry.
pub struct ToolRegistry {
    /// Statically registered tools, in registration order.
    tools: Vec<Box<dyn Tool>>,
    /// Dynamically registered tools; shared with every `DynamicToolHandle`.
    dynamic: Arc<std::sync::RwLock<Vec<Arc<dyn Tool>>>>,
}
104
105impl ToolRegistry {
106 pub fn new() -> Self {
107 Self {
108 tools: Vec::new(),
109 dynamic: Arc::new(std::sync::RwLock::new(Vec::new())),
110 }
111 }
112
113 pub fn register(&mut self, tool: Box<dyn Tool>) {
115 self.tools.push(tool);
116 }
117
118 pub fn dynamic_handle(&self) -> DynamicToolHandle {
123 DynamicToolHandle {
124 inner: self.dynamic.clone(),
125 }
126 }
127
128 pub fn get(&self, name: &str) -> Option<&dyn Tool> {
130 self.tools
131 .iter()
132 .find(|t| t.name() == name)
133 .map(|t| t.as_ref())
134 }
135
136 pub fn get_dynamic(&self, name: &str) -> Option<Arc<dyn Tool>> {
138 let dynamic = self.dynamic.read().unwrap();
139 dynamic
140 .iter()
141 .find(|t| t.name() == name)
142 .cloned()
143 }
144
145 pub fn definitions(&self) -> Vec<ToolDefinition> {
147 let mut defs: Vec<ToolDefinition> =
148 self.tools.iter().map(|t| t.definition()).collect();
149 let dynamic = self.dynamic.read().unwrap();
150 defs.extend(dynamic.iter().map(|t| t.definition()));
151 defs
152 }
153
154 pub fn names(&self) -> Vec<&str> {
155 self.tools.iter().map(|t| t.name()).collect()
156 }
157
158 pub fn all_names(&self) -> Vec<String> {
160 let mut names: Vec<String> = self.tools.iter().map(|t| t.name().to_string()).collect();
161 let dynamic = self.dynamic.read().unwrap();
162 names.extend(dynamic.iter().map(|t| t.name().to_string()));
163 names
164 }
165
166 pub fn len(&self) -> usize {
167 self.tools.len() + self.dynamic.read().unwrap().len()
168 }
169
170 pub fn is_empty(&self) -> bool {
171 self.tools.is_empty() && self.dynamic.read().unwrap().is_empty()
172 }
173}
174
175impl Default for ToolRegistry {
176 fn default() -> Self {
177 Self::new()
178 }
179}
180
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Minimal tool that returns its `message` argument verbatim.
    struct EchoTool;

    #[async_trait]
    impl Tool for EchoTool {
        fn name(&self) -> &str {
            "echo"
        }

        fn definition(&self) -> ToolDefinition {
            ToolDefinition {
                name: "echo".into(),
                description: "Echo back the input".into(),
                input_schema: json!({
                    "type": "object",
                    "properties": {"message": {"type": "string"}},
                    "required": ["message"]
                }),
            }
        }

        async fn execute(
            &self,
            _call_id: &str,
            arguments: serde_json::Value,
            _partial_tx: Option<mpsc::UnboundedSender<String>>,
        ) -> SoulResult<ToolOutput> {
            let message = arguments
                .get("message")
                .and_then(|v| v.as_str())
                .unwrap_or("no message");
            Ok(ToolOutput::success(message))
        }
    }

    #[test]
    fn tool_output_success() {
        let output = ToolOutput::success("result");
        assert_eq!(output.content, "result");
        assert!(!output.is_error);
    }

    #[test]
    fn tool_output_error() {
        let output = ToolOutput::error("failed");
        assert_eq!(output.content, "failed");
        assert!(output.is_error);
    }

    #[test]
    fn tool_output_with_metadata() {
        let output = ToolOutput::success("ok").with_metadata(json!({"duration_ms": 42}));
        assert_eq!(output.metadata["duration_ms"], 42);
    }

    #[test]
    fn registry_register_and_lookup() {
        let mut registry = ToolRegistry::new();
        assert!(registry.is_empty());

        registry.register(Box::new(EchoTool));
        assert_eq!(registry.len(), 1);
        assert!(!registry.is_empty());

        let tool = registry.get("echo");
        assert!(tool.is_some());
        assert_eq!(tool.unwrap().name(), "echo");

        assert!(registry.get("nonexistent").is_none());
    }

    #[test]
    fn registry_definitions() {
        let mut registry = ToolRegistry::new();
        registry.register(Box::new(EchoTool));

        let defs = registry.definitions();
        assert_eq!(defs.len(), 1);
        assert_eq!(defs[0].name, "echo");
    }

    #[test]
    fn registry_names() {
        let mut registry = ToolRegistry::new();
        registry.register(Box::new(EchoTool));

        let names = registry.names();
        assert_eq!(names, vec!["echo"]);
    }

    #[test]
    fn dynamic_handle_registration_is_visible_to_registry() {
        let registry = ToolRegistry::new();
        let handle = registry.dynamic_handle();
        assert!(handle.is_empty());
        assert!(registry.is_empty());

        handle.register(Arc::new(EchoTool));

        assert_eq!(handle.len(), 1);
        assert!(!handle.is_empty());
        assert_eq!(registry.len(), 1);
        assert!(!registry.is_empty());

        // Dynamic tools are reachable via `get_dynamic`, not `get` / `names`.
        assert!(registry.get("echo").is_none());
        assert!(registry.names().is_empty());
        let tool = registry
            .get_dynamic("echo")
            .expect("echo was registered through the dynamic handle");
        assert_eq!(tool.name(), "echo");
        assert!(registry.get_dynamic("nonexistent").is_none());

        assert_eq!(registry.all_names(), vec!["echo".to_string()]);
        let defs = registry.definitions();
        assert_eq!(defs.len(), 1);
        assert_eq!(defs[0].name, "echo");
    }

    #[test]
    fn registry_combines_static_and_dynamic_tools() {
        let mut registry = ToolRegistry::new();
        registry.register(Box::new(EchoTool));
        registry.dynamic_handle().register(Arc::new(EchoTool));

        assert_eq!(registry.len(), 2);
        assert_eq!(
            registry.all_names(),
            vec!["echo".to_string(), "echo".to_string()]
        );
        assert_eq!(registry.definitions().len(), 2);
        // `names` stays static-only.
        assert_eq!(registry.names(), vec!["echo"]);
    }

    #[test]
    fn dynamic_handle_clones_share_state() {
        let registry = ToolRegistry::new();
        let handle = registry.dynamic_handle();
        let clone = handle.clone();

        clone.register(Arc::new(EchoTool));

        assert_eq!(handle.len(), 1);
        assert_eq!(registry.len(), 1);
        assert_eq!(format!("{:?}", handle), "DynamicToolHandle { count: 1 }");
    }

    #[tokio::test]
    async fn tool_execute() {
        let tool = EchoTool;
        let result = tool
            .execute("call_1", json!({"message": "hello world"}), None)
            .await
            .unwrap();
        assert_eq!(result.content, "hello world");
        assert!(!result.is_error);
    }

    #[tokio::test]
    async fn dynamic_tool_execute() {
        let registry = ToolRegistry::new();
        registry.dynamic_handle().register(Arc::new(EchoTool));

        let tool = registry.get_dynamic("echo").unwrap();
        let result = tool
            .execute("call_2", json!({"message": "hi"}), None)
            .await
            .unwrap();
        assert_eq!(result.content, "hi");
        assert!(!result.is_error);
    }

    #[test]
    fn tool_is_object_safe() {
        // Compiles only if `Tool` can be used as a trait object.
        fn _assert_object_safe(_: &dyn Tool) {}
    }
}