// soul_core/tool/mod.rs

1#[cfg(test)]
2use async_trait::async_trait;
3use std::sync::Arc;
4use tokio::sync::mpsc;
5
6use crate::error::SoulResult;
7use crate::types::ToolDefinition;
8
9/// A tool that can be executed by the agent
10#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
11#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
12pub trait Tool: Send + Sync {
13    /// Tool name (must match the definition name)
14    fn name(&self) -> &str;
15
16    /// Tool definition for sending to the LLM
17    fn definition(&self) -> ToolDefinition;
18
19    /// Execute the tool with the given arguments
20    async fn execute(
21        &self,
22        call_id: &str,
23        arguments: serde_json::Value,
24        partial_tx: Option<mpsc::UnboundedSender<String>>,
25    ) -> SoulResult<ToolOutput>;
26}
27
28/// Output from a tool execution
29#[derive(Debug, Clone)]
30pub struct ToolOutput {
31    pub content: String,
32    pub is_error: bool,
33    pub metadata: serde_json::Value,
34}
35
36impl ToolOutput {
37    pub fn success(content: impl Into<String>) -> Self {
38        Self {
39            content: content.into(),
40            is_error: false,
41            metadata: serde_json::Value::Null,
42        }
43    }
44
45    pub fn error(content: impl Into<String>) -> Self {
46        Self {
47            content: content.into(),
48            is_error: true,
49            metadata: serde_json::Value::Null,
50        }
51    }
52
53    pub fn with_metadata(mut self, metadata: serde_json::Value) -> Self {
54        self.metadata = metadata;
55        self
56    }
57}
58
59/// Handle for adding tools at runtime (from within a running agent).
60///
61/// Clone this handle and give it to tools that need to register new tools
62/// dynamically (e.g., `install_skill`). Tools added via the handle become
63/// visible to the LLM on the next turn.
64#[derive(Clone)]
65pub struct DynamicToolHandle {
66    inner: Arc<std::sync::RwLock<Vec<Arc<dyn Tool>>>>,
67}
68
69impl DynamicToolHandle {
70    /// Register a tool at runtime. Visible to the agent on the next turn.
71    pub fn register(&self, tool: Arc<dyn Tool>) {
72        let mut tools = self.inner.write().unwrap();
73        tools.push(tool);
74    }
75
76    /// Number of dynamically-added tools.
77    pub fn len(&self) -> usize {
78        self.inner.read().unwrap().len()
79    }
80
81    /// Check if any dynamic tools have been added.
82    pub fn is_empty(&self) -> bool {
83        self.inner.read().unwrap().is_empty()
84    }
85}
86
87impl std::fmt::Debug for DynamicToolHandle {
88    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89        let count = self.inner.read().map(|v| v.len()).unwrap_or(0);
90        f.debug_struct("DynamicToolHandle")
91            .field("count", &count)
92            .finish()
93    }
94}
95
96/// Registry of tools available to the agent.
97///
98/// Supports both static tools (registered at startup via `register()`) and
99/// dynamic tools (registered at runtime via `dynamic_handle()`).
100pub struct ToolRegistry {
101    tools: Vec<Box<dyn Tool>>,
102    dynamic: Arc<std::sync::RwLock<Vec<Arc<dyn Tool>>>>,
103}
104
105impl ToolRegistry {
106    pub fn new() -> Self {
107        Self {
108            tools: Vec::new(),
109            dynamic: Arc::new(std::sync::RwLock::new(Vec::new())),
110        }
111    }
112
113    /// Register a tool at build time (before the agent starts).
114    pub fn register(&mut self, tool: Box<dyn Tool>) {
115        self.tools.push(tool);
116    }
117
118    /// Get a handle for runtime tool registration.
119    ///
120    /// Tools added via this handle become visible to the LLM on the next turn
121    /// (when `definitions()` is called again).
122    pub fn dynamic_handle(&self) -> DynamicToolHandle {
123        DynamicToolHandle {
124            inner: self.dynamic.clone(),
125        }
126    }
127
128    /// Look up a static tool by name.
129    pub fn get(&self, name: &str) -> Option<&dyn Tool> {
130        self.tools
131            .iter()
132            .find(|t| t.name() == name)
133            .map(|t| t.as_ref())
134    }
135
136    /// Look up a dynamic tool by name (returns Arc clone, safe across await).
137    pub fn get_dynamic(&self, name: &str) -> Option<Arc<dyn Tool>> {
138        let dynamic = self.dynamic.read().unwrap();
139        dynamic
140            .iter()
141            .find(|t| t.name() == name)
142            .cloned()
143    }
144
145    /// Get all tool definitions (static + dynamic).
146    pub fn definitions(&self) -> Vec<ToolDefinition> {
147        let mut defs: Vec<ToolDefinition> =
148            self.tools.iter().map(|t| t.definition()).collect();
149        let dynamic = self.dynamic.read().unwrap();
150        defs.extend(dynamic.iter().map(|t| t.definition()));
151        defs
152    }
153
154    pub fn names(&self) -> Vec<&str> {
155        self.tools.iter().map(|t| t.name()).collect()
156    }
157
158    /// Get all tool names including dynamic (owned strings).
159    pub fn all_names(&self) -> Vec<String> {
160        let mut names: Vec<String> = self.tools.iter().map(|t| t.name().to_string()).collect();
161        let dynamic = self.dynamic.read().unwrap();
162        names.extend(dynamic.iter().map(|t| t.name().to_string()));
163        names
164    }
165
166    pub fn len(&self) -> usize {
167        self.tools.len() + self.dynamic.read().unwrap().len()
168    }
169
170    pub fn is_empty(&self) -> bool {
171        self.tools.is_empty() && self.dynamic.read().unwrap().is_empty()
172    }
173}
174
175impl Default for ToolRegistry {
176    fn default() -> Self {
177        Self::new()
178    }
179}
180
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Minimal tool that returns its `message` argument.
    struct EchoTool;

    #[async_trait]
    impl Tool for EchoTool {
        fn name(&self) -> &str {
            "echo"
        }

        fn definition(&self) -> ToolDefinition {
            ToolDefinition {
                name: "echo".into(),
                description: "Echo back the input".into(),
                input_schema: json!({
                    "type": "object",
                    "properties": {"message": {"type": "string"}},
                    "required": ["message"]
                }),
            }
        }

        async fn execute(
            &self,
            _call_id: &str,
            arguments: serde_json::Value,
            _partial_tx: Option<mpsc::UnboundedSender<String>>,
        ) -> SoulResult<ToolOutput> {
            let message = arguments
                .get("message")
                .and_then(|v| v.as_str())
                .unwrap_or("no message");
            Ok(ToolOutput::success(message))
        }
    }

    #[test]
    fn tool_output_success() {
        let output = ToolOutput::success("result");
        assert_eq!(output.content, "result");
        assert!(!output.is_error);
    }

    #[test]
    fn tool_output_error() {
        let output = ToolOutput::error("failed");
        assert_eq!(output.content, "failed");
        assert!(output.is_error);
    }

    #[test]
    fn tool_output_with_metadata() {
        let output = ToolOutput::success("ok").with_metadata(json!({"duration_ms": 42}));
        assert_eq!(output.metadata["duration_ms"], 42);
    }

    #[test]
    fn registry_register_and_lookup() {
        let mut registry = ToolRegistry::new();
        assert!(registry.is_empty());

        registry.register(Box::new(EchoTool));
        assert_eq!(registry.len(), 1);
        assert!(!registry.is_empty());

        let tool = registry.get("echo");
        assert!(tool.is_some());
        assert_eq!(tool.unwrap().name(), "echo");

        assert!(registry.get("nonexistent").is_none());
    }

    #[test]
    fn registry_definitions() {
        let mut registry = ToolRegistry::new();
        registry.register(Box::new(EchoTool));

        let defs = registry.definitions();
        assert_eq!(defs.len(), 1);
        assert_eq!(defs[0].name, "echo");
    }

    #[test]
    fn registry_names() {
        let mut registry = ToolRegistry::new();
        registry.register(Box::new(EchoTool));

        let names = registry.names();
        assert_eq!(names, vec!["echo"]);
    }

    #[test]
    fn registry_dynamic_registration() {
        let registry = ToolRegistry::new();
        let handle = registry.dynamic_handle();
        assert!(handle.is_empty());

        handle.register(Arc::new(EchoTool));

        assert_eq!(handle.len(), 1);
        assert_eq!(registry.len(), 1);
        assert!(!registry.is_empty());
        // Dynamic tools are reachable only through the dynamic lookup.
        assert!(registry.get("echo").is_none());
        assert_eq!(registry.get_dynamic("echo").unwrap().name(), "echo");
        assert!(registry.get_dynamic("nonexistent").is_none());
        // `names()` lists static tools only; `all_names()` includes dynamic ones.
        assert!(registry.names().is_empty());
        assert_eq!(registry.all_names(), vec!["echo".to_string()]);
        let defs = registry.definitions();
        assert_eq!(defs.len(), 1);
        assert_eq!(defs[0].name, "echo");
    }

    #[test]
    fn dynamic_handle_clones_share_state() {
        let registry = ToolRegistry::new();
        let first = registry.dynamic_handle();
        let second = first.clone();

        first.register(Arc::new(EchoTool));

        assert_eq!(second.len(), 1);
        assert_eq!(format!("{:?}", second), "DynamicToolHandle { count: 1 }");
    }

    #[tokio::test]
    async fn tool_execute() {
        let tool = EchoTool;
        let result = tool
            .execute("call_1", json!({"message": "hello world"}), None)
            .await
            .unwrap();
        assert_eq!(result.content, "hello world");
        assert!(!result.is_error);
    }

    // Trait object safety: this only compiles if `Tool` is usable as `dyn Tool`.
    #[test]
    fn tool_is_object_safe() {
        fn _assert_object_safe(_: &dyn Tool) {}
    }

    // The registry and its handle are shared across threads/tasks.
    #[test]
    fn registry_types_are_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<ToolRegistry>();
        assert_send_sync::<DynamicToolHandle>();
    }
}