serdes_ai_toolsets/
dynamic.rs1use async_trait::async_trait;
7use parking_lot::RwLock;
8use serde_json::Value as JsonValue;
9use serdes_ai_tools::{RunContext, Tool, ToolError, ToolReturn};
10use std::collections::HashMap;
11use std::sync::Arc;
12
13use crate::{AbstractToolset, ToolsetTool};
14
15pub struct DynamicToolset<Deps = ()>
38where
39 Deps: Send + Sync + 'static,
40{
41 id: Option<String>,
42 tools: RwLock<HashMap<String, Arc<dyn Tool<Deps>>>>,
43 max_retries: u32,
44}
45
46impl<Deps: Send + Sync + 'static> DynamicToolset<Deps> {
47 #[must_use]
49 pub fn new() -> Self {
50 Self {
51 id: None,
52 tools: RwLock::new(HashMap::new()),
53 max_retries: 3,
54 }
55 }
56
57 #[must_use]
59 pub fn with_id(id: impl Into<String>) -> Self {
60 Self {
61 id: Some(id.into()),
62 tools: RwLock::new(HashMap::new()),
63 max_retries: 3,
64 }
65 }
66
67 #[must_use]
69 pub fn with_max_retries(mut self, retries: u32) -> Self {
70 self.max_retries = retries;
71 self
72 }
73
74 pub fn add_tool<T: Tool<Deps> + 'static>(&self, tool: T) {
78 let name = tool.definition().name.clone();
79 self.tools.write().insert(name, Arc::new(tool));
80 }
81
82 pub fn add_boxed(&self, tool: Arc<dyn Tool<Deps>>) {
84 let name = tool.definition().name.clone();
85 self.tools.write().insert(name, tool);
86 }
87
88 pub fn remove_tool(&self, name: &str) -> bool {
92 self.tools.write().remove(name).is_some()
93 }
94
95 pub fn clear(&self) {
97 self.tools.write().clear();
98 }
99
100 #[must_use]
102 pub fn len(&self) -> usize {
103 self.tools.read().len()
104 }
105
106 #[must_use]
108 pub fn is_empty(&self) -> bool {
109 self.tools.read().is_empty()
110 }
111
112 #[must_use]
114 pub fn contains(&self, name: &str) -> bool {
115 self.tools.read().contains_key(name)
116 }
117
118 #[must_use]
120 pub fn tool_names(&self) -> Vec<String> {
121 self.tools.read().keys().cloned().collect()
122 }
123}
124
125impl<Deps: Send + Sync + 'static> Default for DynamicToolset<Deps> {
126 fn default() -> Self {
127 Self::new()
128 }
129}
130
131#[async_trait]
132impl<Deps: Send + Sync + 'static> AbstractToolset<Deps> for DynamicToolset<Deps> {
133 fn id(&self) -> Option<&str> {
134 self.id.as_deref()
135 }
136
137 fn type_name(&self) -> &'static str {
138 "DynamicToolset"
139 }
140
141 async fn get_tools(
142 &self,
143 ctx: &RunContext<Deps>,
144 ) -> Result<HashMap<String, ToolsetTool>, ToolError> {
145 let tools_snapshot: Vec<(String, Arc<dyn Tool<Deps>>)> = {
147 let tools = self.tools.read();
148 tools
149 .iter()
150 .map(|(k, v)| (k.clone(), Arc::clone(v)))
151 .collect()
152 };
153
154 let mut result = HashMap::with_capacity(tools_snapshot.len());
155
156 for (name, tool) in tools_snapshot {
157 let def = tool.definition();
158
159 let prepared_def = tool.prepare(ctx, def.clone()).await;
161
162 if let Some(final_def) = prepared_def {
163 let max_retries = tool.max_retries().unwrap_or(self.max_retries);
164 result.insert(
165 name,
166 ToolsetTool {
167 toolset_id: self.id.clone(),
168 tool_def: final_def,
169 max_retries,
170 },
171 );
172 }
173 }
174
175 Ok(result)
176 }
177
178 async fn call_tool(
179 &self,
180 name: &str,
181 args: JsonValue,
182 ctx: &RunContext<Deps>,
183 _tool: &ToolsetTool,
184 ) -> Result<ToolReturn, ToolError> {
185 let tool = {
186 let tools = self.tools.read();
187 tools
188 .get(name)
189 .cloned()
190 .ok_or_else(|| ToolError::not_found(name))?
191 };
192
193 tool.call(ctx, args).await
194 }
195}
196
197impl<Deps: Send + Sync + 'static> std::fmt::Debug for DynamicToolset<Deps> {
198 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
199 f.debug_struct("DynamicToolset")
200 .field("id", &self.id)
201 .field("tool_count", &self.len())
202 .field("max_retries", &self.max_retries)
203 .finish()
204 }
205}
206
#[cfg(test)]
mod tests {
    use super::*;
    use serdes_ai_tools::ToolDefinition;

    /// Test tool named `echo` that returns `prefix` followed by the `msg` argument.
    struct EchoTool {
        prefix: String,
    }

    impl EchoTool {
        fn new(prefix: impl Into<String>) -> Self {
            Self {
                prefix: prefix.into(),
            }
        }
    }

    #[async_trait]
    impl Tool<()> for EchoTool {
        fn definition(&self) -> ToolDefinition {
            ToolDefinition::new("echo", "Echo with prefix")
        }

        async fn call(
            &self,
            _ctx: &RunContext<()>,
            args: JsonValue,
        ) -> Result<ToolReturn, ToolError> {
            let msg = args["msg"].as_str().unwrap_or("<none>");
            Ok(ToolReturn::text(format!("{}{}", self.prefix, msg)))
        }
    }

    /// Test tool named `add` that returns the sum of integer arguments `a` and `b`.
    struct AddTool;

    #[async_trait]
    impl Tool<()> for AddTool {
        fn definition(&self) -> ToolDefinition {
            ToolDefinition::new("add", "Add numbers")
        }

        async fn call(
            &self,
            _ctx: &RunContext<()>,
            args: JsonValue,
        ) -> Result<ToolReturn, ToolError> {
            let a = args["a"].as_i64().unwrap_or(0);
            let b = args["b"].as_i64().unwrap_or(0);
            Ok(ToolReturn::text((a + b).to_string()))
        }
    }

    #[test]
    fn test_dynamic_toolset_new() {
        let toolset = DynamicToolset::<()>::new();
        assert!(toolset.is_empty());
    }

    #[test]
    fn test_dynamic_toolset_add_remove() {
        let toolset = DynamicToolset::<()>::new();

        toolset.add_tool(EchoTool::new(">>> "));
        assert_eq!(toolset.len(), 1);
        assert!(toolset.contains("echo"));

        toolset.add_tool(AddTool);
        assert_eq!(toolset.len(), 2);

        assert!(toolset.remove_tool("echo"));
        assert_eq!(toolset.len(), 1);
        assert!(!toolset.contains("echo"));

        assert!(!toolset.remove_tool("nonexistent"));
    }

    #[test]
    fn test_dynamic_toolset_clear() {
        let toolset = DynamicToolset::<()>::new();
        toolset.add_tool(EchoTool::new(""));
        toolset.add_tool(AddTool);

        toolset.clear();
        assert!(toolset.is_empty());
    }

    #[test]
    fn test_dynamic_toolset_tool_names() {
        let toolset = DynamicToolset::<()>::new();
        toolset.add_tool(EchoTool::new(""));
        toolset.add_tool(AddTool);

        let names = toolset.tool_names();
        assert_eq!(names.len(), 2);
        assert!(names.contains(&"echo".to_string()));
        assert!(names.contains(&"add".to_string()));
    }

    #[test]
    fn test_dynamic_toolset_debug() {
        let toolset = DynamicToolset::<()>::new().with_max_retries(7);
        toolset.add_tool(AddTool);

        let rendered = format!("{:?}", toolset);
        assert!(rendered.contains("tool_count: 1"));
        assert!(rendered.contains("max_retries: 7"));
    }

    #[tokio::test]
    async fn test_dynamic_toolset_get_tools() {
        let toolset = DynamicToolset::<()>::new();
        toolset.add_tool(EchoTool::new(""));

        let ctx = RunContext::minimal("test");
        let tools = toolset.get_tools(&ctx).await.unwrap();

        assert_eq!(tools.len(), 1);
        assert!(tools.contains_key("echo"));
    }

    #[tokio::test]
    async fn test_dynamic_toolset_get_tools_propagates_id() {
        let toolset = DynamicToolset::<()>::with_id("dyn");
        toolset.add_tool(AddTool);
        assert_eq!(toolset.id(), Some("dyn"));

        let ctx = RunContext::minimal("test");
        let tools = toolset.get_tools(&ctx).await.unwrap();

        assert_eq!(tools["add"].toolset_id.as_deref(), Some("dyn"));
    }

    #[tokio::test]
    async fn test_dynamic_toolset_call_tool() {
        let toolset = DynamicToolset::<()>::new();
        toolset.add_tool(EchoTool::new("[PREFIX] "));

        let ctx = RunContext::minimal("test");
        let tools = toolset.get_tools(&ctx).await.unwrap();
        let tool = tools.get("echo").unwrap();

        let result = toolset
            .call_tool("echo", serde_json::json!({"msg": "hello"}), &ctx, tool)
            .await
            .unwrap();

        assert_eq!(result.as_text(), Some("[PREFIX] hello"));
    }

    #[tokio::test]
    async fn test_dynamic_toolset_call_add_tool() {
        let toolset = DynamicToolset::<()>::new();
        toolset.add_tool(AddTool);

        let ctx = RunContext::minimal("test");
        let tools = toolset.get_tools(&ctx).await.unwrap();
        let tool = tools.get("add").unwrap();

        let result = toolset
            .call_tool("add", serde_json::json!({"a": 2, "b": 3}), &ctx, tool)
            .await
            .unwrap();

        assert_eq!(result.as_text(), Some("5"));
    }

    #[tokio::test]
    async fn test_dynamic_toolset_call_removed_tool_fails() {
        let toolset = DynamicToolset::<()>::new();
        toolset.add_tool(EchoTool::new(""));

        let ctx = RunContext::minimal("test");
        let tools = toolset.get_tools(&ctx).await.unwrap();
        let tool = tools.get("echo").unwrap();

        // Removing after `get_tools` must make later calls fail, because
        // `call_tool` resolves the tool by name at call time.
        assert!(toolset.remove_tool("echo"));
        let result = toolset
            .call_tool("echo", serde_json::json!({"msg": "hi"}), &ctx, tool)
            .await;

        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_dynamic_toolset_replace_tool() {
        let toolset = DynamicToolset::<()>::new();
        toolset.add_tool(EchoTool::new("v1: "));

        let ctx = RunContext::minimal("test");
        let tools = toolset.get_tools(&ctx).await.unwrap();
        let tool = tools.get("echo").unwrap();

        let result1 = toolset
            .call_tool("echo", serde_json::json!({"msg": "test"}), &ctx, tool)
            .await
            .unwrap();
        assert_eq!(result1.as_text(), Some("v1: test"));

        toolset.add_tool(EchoTool::new("v2: "));

        let result2 = toolset
            .call_tool("echo", serde_json::json!({"msg": "test"}), &ctx, tool)
            .await
            .unwrap();
        assert_eq!(result2.as_text(), Some("v2: test"));
    }

    #[tokio::test]
    async fn test_dynamic_toolset_concurrent_access() {
        use tokio::task::JoinSet;

        let toolset = Arc::new(DynamicToolset::<()>::new());

        let mut tasks = JoinSet::new();

        for i in 0..10 {
            let ts = Arc::clone(&toolset);
            tasks.spawn(async move {
                ts.add_tool(EchoTool::new(format!("task{}: ", i)));
            });
        }

        // Propagate panics from spawned tasks instead of silently discarding
        // the `JoinError`, which would let a failing task pass the test.
        while let Some(joined) = tasks.join_next().await {
            joined.expect("add_tool task panicked");
        }

        // Every task registers a tool named "echo", so exactly one entry remains.
        assert_eq!(toolset.len(), 1);
        assert!(toolset.contains("echo"));
    }
}