Skip to main content

serdes_ai_toolsets/
dynamic.rs

1//! Dynamic toolset implementation.
2//!
3//! This module provides `DynamicToolset`, which allows tools to be
4//! added and removed at runtime.
5
6use async_trait::async_trait;
7use parking_lot::RwLock;
8use serde_json::Value as JsonValue;
9use serdes_ai_tools::{RunContext, Tool, ToolError, ToolReturn};
10use std::collections::HashMap;
11use std::sync::Arc;
12
13use crate::{AbstractToolset, ToolsetTool};
14
15/// Toolset that can have tools added/removed at runtime.
16///
17/// This is useful for scenarios where the available tools change
18/// during the agent's lifetime.
19///
20/// # Thread Safety
21///
22/// All operations are thread-safe and can be called concurrently.
23///
24/// # Example
25///
26/// ```ignore
27/// use serdes_ai_toolsets::DynamicToolset;
28///
29/// let toolset = DynamicToolset::new();
30///
31/// // Add tools at runtime
32/// toolset.add_tool(my_tool);
33///
34/// // Remove tools
35/// toolset.remove_tool("my_tool");
36/// ```
37pub struct DynamicToolset<Deps = ()>
38where
39    Deps: Send + Sync + 'static,
40{
41    id: Option<String>,
42    tools: RwLock<HashMap<String, Arc<dyn Tool<Deps>>>>,
43    max_retries: u32,
44}
45
46impl<Deps: Send + Sync + 'static> DynamicToolset<Deps> {
47    /// Create a new empty dynamic toolset.
48    #[must_use]
49    pub fn new() -> Self {
50        Self {
51            id: None,
52            tools: RwLock::new(HashMap::new()),
53            max_retries: 3,
54        }
55    }
56
57    /// Create with an ID.
58    #[must_use]
59    pub fn with_id(id: impl Into<String>) -> Self {
60        Self {
61            id: Some(id.into()),
62            tools: RwLock::new(HashMap::new()),
63            max_retries: 3,
64        }
65    }
66
67    /// Set max retries.
68    #[must_use]
69    pub fn with_max_retries(mut self, retries: u32) -> Self {
70        self.max_retries = retries;
71        self
72    }
73
74    /// Add a tool.
75    ///
76    /// If a tool with the same name exists, it will be replaced.
77    pub fn add_tool<T: Tool<Deps> + 'static>(&self, tool: T) {
78        let name = tool.definition().name.clone();
79        self.tools.write().insert(name, Arc::new(tool));
80    }
81
82    /// Add a boxed tool.
83    pub fn add_boxed(&self, tool: Arc<dyn Tool<Deps>>) {
84        let name = tool.definition().name.clone();
85        self.tools.write().insert(name, tool);
86    }
87
88    /// Remove a tool by name.
89    ///
90    /// Returns `true` if the tool was removed, `false` if it didn't exist.
91    pub fn remove_tool(&self, name: &str) -> bool {
92        self.tools.write().remove(name).is_some()
93    }
94
95    /// Clear all tools.
96    pub fn clear(&self) {
97        self.tools.write().clear();
98    }
99
100    /// Get the number of tools.
101    #[must_use]
102    pub fn len(&self) -> usize {
103        self.tools.read().len()
104    }
105
106    /// Check if empty.
107    #[must_use]
108    pub fn is_empty(&self) -> bool {
109        self.tools.read().is_empty()
110    }
111
112    /// Check if a tool exists.
113    #[must_use]
114    pub fn contains(&self, name: &str) -> bool {
115        self.tools.read().contains_key(name)
116    }
117
118    /// Get tool names.
119    #[must_use]
120    pub fn tool_names(&self) -> Vec<String> {
121        self.tools.read().keys().cloned().collect()
122    }
123}
124
125impl<Deps: Send + Sync + 'static> Default for DynamicToolset<Deps> {
126    fn default() -> Self {
127        Self::new()
128    }
129}
130
131#[async_trait]
132impl<Deps: Send + Sync + 'static> AbstractToolset<Deps> for DynamicToolset<Deps> {
133    fn id(&self) -> Option<&str> {
134        self.id.as_deref()
135    }
136
137    fn type_name(&self) -> &'static str {
138        "DynamicToolset"
139    }
140
141    async fn get_tools(
142        &self,
143        ctx: &RunContext<Deps>,
144    ) -> Result<HashMap<String, ToolsetTool>, ToolError> {
145        // Clone the tools under the lock to avoid holding it across await
146        let tools_snapshot: Vec<(String, Arc<dyn Tool<Deps>>)> = {
147            let tools = self.tools.read();
148            tools
149                .iter()
150                .map(|(k, v)| (k.clone(), Arc::clone(v)))
151                .collect()
152        };
153
154        let mut result = HashMap::with_capacity(tools_snapshot.len());
155
156        for (name, tool) in tools_snapshot {
157            let def = tool.definition();
158
159            // Apply prepare if available
160            let prepared_def = tool.prepare(ctx, def.clone()).await;
161
162            if let Some(final_def) = prepared_def {
163                let max_retries = tool.max_retries().unwrap_or(self.max_retries);
164                result.insert(
165                    name,
166                    ToolsetTool {
167                        toolset_id: self.id.clone(),
168                        tool_def: final_def,
169                        max_retries,
170                    },
171                );
172            }
173        }
174
175        Ok(result)
176    }
177
178    async fn call_tool(
179        &self,
180        name: &str,
181        args: JsonValue,
182        ctx: &RunContext<Deps>,
183        _tool: &ToolsetTool,
184    ) -> Result<ToolReturn, ToolError> {
185        let tool = {
186            let tools = self.tools.read();
187            tools
188                .get(name)
189                .cloned()
190                .ok_or_else(|| ToolError::not_found(name))?
191        };
192
193        tool.call(ctx, args).await
194    }
195}
196
197impl<Deps: Send + Sync + 'static> std::fmt::Debug for DynamicToolset<Deps> {
198    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
199        f.debug_struct("DynamicToolset")
200            .field("id", &self.id)
201            .field("tool_count", &self.len())
202            .field("max_retries", &self.max_retries)
203            .finish()
204    }
205}
206
#[cfg(test)]
mod tests {
    use super::*;
    use serdes_ai_tools::ToolDefinition;

    /// Test tool named `echo` that returns the `msg` argument behind a prefix.
    struct EchoTool {
        prefix: String,
    }

    impl EchoTool {
        fn new(prefix: impl Into<String>) -> Self {
            Self {
                prefix: prefix.into(),
            }
        }
    }

    #[async_trait]
    impl Tool<()> for EchoTool {
        fn definition(&self) -> ToolDefinition {
            ToolDefinition::new("echo", "Echo with prefix")
        }

        async fn call(
            &self,
            _ctx: &RunContext<()>,
            args: JsonValue,
        ) -> Result<ToolReturn, ToolError> {
            let msg = args["msg"].as_str().unwrap_or("<none>");
            Ok(ToolReturn::text(format!("{}{}", self.prefix, msg)))
        }
    }

    /// Test tool named `add` that sums the integer arguments `a` and `b`.
    struct AddTool;

    #[async_trait]
    impl Tool<()> for AddTool {
        fn definition(&self) -> ToolDefinition {
            ToolDefinition::new("add", "Add numbers")
        }

        async fn call(
            &self,
            _ctx: &RunContext<()>,
            args: JsonValue,
        ) -> Result<ToolReturn, ToolError> {
            let a = args["a"].as_i64().unwrap_or(0);
            let b = args["b"].as_i64().unwrap_or(0);
            Ok(ToolReturn::text(format!("{}", a + b)))
        }
    }

    #[test]
    fn test_dynamic_toolset_new() {
        let toolset = DynamicToolset::<()>::new();
        assert!(toolset.is_empty());
    }

    #[test]
    fn test_dynamic_toolset_add_remove() {
        let toolset = DynamicToolset::<()>::new();

        toolset.add_tool(EchoTool::new(">>> "));
        assert_eq!(toolset.len(), 1);
        assert!(toolset.contains("echo"));

        toolset.add_tool(AddTool);
        assert_eq!(toolset.len(), 2);

        assert!(toolset.remove_tool("echo"));
        assert_eq!(toolset.len(), 1);
        assert!(!toolset.contains("echo"));

        assert!(!toolset.remove_tool("nonexistent"));
    }

    #[test]
    fn test_dynamic_toolset_clear() {
        let toolset = DynamicToolset::<()>::new();
        toolset.add_tool(EchoTool::new(""));
        toolset.add_tool(AddTool);

        toolset.clear();
        assert!(toolset.is_empty());
    }

    #[test]
    fn test_dynamic_toolset_tool_names() {
        let toolset = DynamicToolset::<()>::new();
        toolset.add_tool(EchoTool::new(""));
        toolset.add_tool(AddTool);

        let names = toolset.tool_names();
        assert_eq!(names.len(), 2);
        assert!(names.contains(&"echo".to_string()));
        assert!(names.contains(&"add".to_string()));
    }

    #[tokio::test]
    async fn test_dynamic_toolset_get_tools() {
        let toolset = DynamicToolset::<()>::new();
        toolset.add_tool(EchoTool::new(""));

        let ctx = RunContext::minimal("test");
        let tools = toolset.get_tools(&ctx).await.unwrap();

        assert_eq!(tools.len(), 1);
        assert!(tools.contains_key("echo"));
    }

    #[tokio::test]
    async fn test_dynamic_toolset_call_tool() {
        let toolset = DynamicToolset::<()>::new();
        toolset.add_tool(EchoTool::new("[PREFIX] "));

        let ctx = RunContext::minimal("test");
        let tools = toolset.get_tools(&ctx).await.unwrap();
        let tool = tools.get("echo").unwrap();

        let result = toolset
            .call_tool("echo", serde_json::json!({"msg": "hello"}), &ctx, tool)
            .await
            .unwrap();

        assert_eq!(result.as_text(), Some("[PREFIX] hello"));
    }

    #[tokio::test]
    async fn test_dynamic_toolset_call_missing_tool() {
        let toolset = DynamicToolset::<()>::new();
        toolset.add_tool(EchoTool::new(""));

        let ctx = RunContext::minimal("test");
        let tools = toolset.get_tools(&ctx).await.unwrap();
        let tool = tools.get("echo").unwrap();

        // The tool disappears between listing and invocation.
        assert!(toolset.remove_tool("echo"));

        let result = toolset
            .call_tool("echo", serde_json::json!({"msg": "hello"}), &ctx, tool)
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_dynamic_toolset_replace_tool() {
        let toolset = DynamicToolset::<()>::new();
        toolset.add_tool(EchoTool::new("v1: "));

        let ctx = RunContext::minimal("test");
        let tools = toolset.get_tools(&ctx).await.unwrap();
        let tool = tools.get("echo").unwrap();

        let result1 = toolset
            .call_tool("echo", serde_json::json!({"msg": "test"}), &ctx, tool)
            .await
            .unwrap();
        assert_eq!(result1.as_text(), Some("v1: test"));

        // Replace with v2
        toolset.add_tool(EchoTool::new("v2: "));

        let result2 = toolset
            .call_tool("echo", serde_json::json!({"msg": "test"}), &ctx, tool)
            .await
            .unwrap();
        assert_eq!(result2.as_text(), Some("v2: test"));
    }

    #[tokio::test]
    async fn test_dynamic_toolset_concurrent_access() {
        use std::sync::Arc;
        use tokio::task::JoinSet;

        let toolset = Arc::new(DynamicToolset::<()>::new());

        let mut tasks = JoinSet::new();

        // Spawn multiple tasks that add tools
        for i in 0..10 {
            let ts = toolset.clone();
            tasks.spawn(async move {
                ts.add_tool(EchoTool::new(format!("task{}: ", i)));
            });
        }

        // Wait for all to complete, failing the test if any task panicked
        // instead of silently discarding the join error.
        while let Some(joined) = tasks.join_next().await {
            joined.expect("add_tool task panicked");
        }

        // Every task registers a tool named "echo", so each insert replaces
        // the previous one and exactly one entry remains.
        assert_eq!(toolset.len(), 1);
        assert!(toolset.contains("echo"));
    }
}