strands_agents/tools/
registry.rs

1//! Tool registry for managing agent tools.
2
3use std::collections::HashMap;
4use std::sync::Arc;
5use uuid::Uuid;
6
7use crate::types::errors::StrandsError;
8use crate::types::tools::ToolSpec;
9
10use super::mcp::ToolProvider;
11use super::AgentTool;
12
13/// A tool input that can be processed by the registry.
14pub enum ToolInput {
15    /// A single tool instance.
16    Tool(Box<dyn AgentTool>),
17    /// A tool provider.
18    Provider(Arc<dyn ToolProvider>),
19    /// Multiple tool inputs (nested).
20    Multiple(Vec<ToolInput>),
21}
22
23impl ToolInput {
24    /// Create a tool input from an AgentTool.
25    pub fn tool(tool: impl AgentTool + 'static) -> Self {
26        Self::Tool(Box::new(tool))
27    }
28
29    /// Create a tool input from a ToolProvider.
30    pub fn provider(provider: impl ToolProvider + 'static) -> Self {
31        Self::Provider(Arc::new(provider))
32    }
33
34    /// Create a nested collection of tool inputs.
35    pub fn multiple(inputs: impl IntoIterator<Item = ToolInput>) -> Self {
36        Self::Multiple(inputs.into_iter().collect())
37    }
38}
39
40/// Registry for managing agent tools.
41pub struct ToolRegistry {
42    tools: HashMap<String, Arc<dyn AgentTool>>,
43    dynamic_tools: HashMap<String, Arc<dyn AgentTool>>,
44    tool_providers: Vec<Arc<dyn ToolProvider>>,
45    registry_id: String,
46}
47
48impl Default for ToolRegistry {
49    fn default() -> Self { Self::new() }
50}
51
52impl ToolRegistry {
53    pub fn new() -> Self {
54        Self {
55            tools: HashMap::new(),
56            dynamic_tools: HashMap::new(),
57            tool_providers: Vec::new(),
58            registry_id: Uuid::new_v4().to_string(),
59        }
60    }
61
62    /// Process a list of tools.
63    ///
64    /// This is the Rust equivalent of Python's `process_tools` method.
65    /// It handles:
66    /// - AgentTool instances
67    /// - ToolProvider instances
68    /// - Nested collections of tools
69    ///
70    /// Note: Unlike Python, Rust cannot dynamically load modules from file paths.
71    /// For file-based tools, use MCP servers or pre-compiled tool modules.
72    ///
73    /// Returns the list of tool names that were registered.
74    pub async fn process_tools(&mut self, inputs: Vec<ToolInput>) -> Result<Vec<String>, StrandsError> {
75        let mut tool_names = Vec::new();
76        self.process_tools_recursive(inputs, &mut tool_names).await?;
77        Ok(tool_names)
78    }
79
80    /// Internal recursive helper for process_tools.
81    async fn process_tools_recursive(
82        &mut self,
83        inputs: Vec<ToolInput>,
84        tool_names: &mut Vec<String>,
85    ) -> Result<(), StrandsError> {
86        for input in inputs {
87            match input {
88                ToolInput::Tool(tool) => {
89                    let name = tool.tool_name().to_string();
90                    if tool.is_dynamic() {
91                        self.dynamic_tools.insert(name.clone(), Arc::from(tool));
92                    } else {
93                        self.tools.insert(name.clone(), Arc::from(tool));
94                    }
95                    tool_names.push(name);
96                }
97                ToolInput::Provider(provider) => {
98                    provider.add_consumer(&self.registry_id);
99                    let provider_tools = provider.load_tools().await
100                        .map_err(|e| StrandsError::ToolError {
101                            tool_name: "provider".to_string(),
102                            message: format!("Failed to load tools from provider: {}", e),
103                        })?;
104
105                    for tool in provider_tools {
106                        let name = tool.tool_name().to_string();
107                        self.tools.insert(name.clone(), tool);
108                        tool_names.push(name);
109                    }
110
111                    self.tool_providers.push(provider);
112                }
113                ToolInput::Multiple(nested) => {
114                    Box::pin(self.process_tools_recursive(nested, tool_names)).await?;
115                }
116            }
117        }
118        Ok(())
119    }
120
121    /// Process tools synchronously (blocking version).
122    ///
123    /// This is useful when you don't have an async runtime.
124    pub fn process_tools_sync(&mut self, inputs: Vec<ToolInput>) -> Result<Vec<String>, StrandsError> {
125        crate::async_utils::run_async(self.process_tools(inputs))
126    }
127
128    /// Registers a tool with the registry.
129    pub fn register(&mut self, tool: Box<dyn AgentTool>) {
130        let name = tool.tool_name().to_string();
131        self.tools.insert(name, Arc::from(tool));
132    }
133
134    /// Registers a tool with the registry (typed version).
135    pub fn register_typed(&mut self, tool: impl AgentTool + 'static) -> Result<(), StrandsError> {
136        let name = tool.tool_name().to_string();
137
138        if self.tools.contains_key(&name) {
139            return Err(StrandsError::ConfigurationError {
140                message: format!("Tool '{name}' already exists"),
141            });
142        }
143
144        let normalized_name = name.replace('-', "_");
145        for existing_name in self.tools.keys() {
146            if existing_name.replace('-', "_") == normalized_name && *existing_name != name {
147                return Err(StrandsError::ConfigurationError {
148                    message: format!(
149                        "Tool '{name}' conflicts with existing tool '{existing_name}' (differ only by - vs _)"
150                    ),
151                });
152            }
153        }
154
155        self.tools.insert(name, Arc::new(tool));
156        Ok(())
157    }
158
159    /// Registers multiple tools with the registry.
160    pub fn register_all(
161        &mut self,
162        tools: impl IntoIterator<Item = impl AgentTool + 'static>,
163    ) {
164        for tool in tools {
165            self.tools.insert(tool.tool_name().to_string(), Arc::new(tool));
166        }
167    }
168
169    /// Gets a tool by name.
170    pub fn get(&self, name: &str) -> Option<Arc<dyn AgentTool>> {
171        self.tools.get(name).or_else(|| self.dynamic_tools.get(name)).cloned()
172    }
173
174    /// Returns a list of all tool names.
175    pub fn tool_names(&self) -> Vec<&str> {
176        self.tools.keys().chain(self.dynamic_tools.keys()).map(|s| s.as_str()).collect()
177    }
178
179    /// Returns all tool specifications.
180    pub fn get_all_tool_specs(&self) -> Vec<ToolSpec> {
181        self.tools.values().chain(self.dynamic_tools.values()).map(|t| t.tool_spec()).collect()
182    }
183
184    /// Returns all tools as a configuration map.
185    pub fn get_all_tools_config(&self) -> HashMap<String, ToolSpec> {
186        self.tools.iter().chain(self.dynamic_tools.iter()).map(|(n, t)| (n.clone(), t.tool_spec())).collect()
187    }
188
189    pub fn len(&self) -> usize { self.tools.len() + self.dynamic_tools.len() }
190    pub fn is_empty(&self) -> bool { self.tools.is_empty() && self.dynamic_tools.is_empty() }
191
192    /// Registers a dynamic tool.
193    pub fn register_dynamic(&mut self, tool: impl AgentTool + 'static) -> Result<(), StrandsError> {
194        let name = tool.tool_name().to_string();
195
196        if self.tools.contains_key(&name) || self.dynamic_tools.contains_key(&name) {
197            return Err(StrandsError::ConfigurationError {
198                message: format!("Tool '{name}' already exists"),
199            });
200        }
201
202        self.dynamic_tools.insert(name, Arc::new(tool));
203        Ok(())
204    }
205
206    /// Registers a tool from a ToolSpec using StructuredOutputAgentTool.
207    pub fn register_spec(&mut self, spec: ToolSpec) -> Result<(), StrandsError> {
208        let tool = super::structured_output::StructuredOutputAgentTool::from_spec(spec);
209        self.register_typed(tool)
210    }
211
212    /// Removes a dynamic tool by name.
213    pub fn remove_dynamic(&mut self, name: &str) -> bool {
214        self.dynamic_tools.remove(name).is_some()
215    }
216
217    /// Replaces an existing tool.
218    pub fn replace(&mut self, tool: impl AgentTool + 'static) -> Result<(), StrandsError> {
219        let name = tool.tool_name().to_string();
220        let tool_arc = Arc::new(tool);
221
222        if let Some(entry) = self.tools.get_mut(&name) {
223            *entry = tool_arc;
224            Ok(())
225        } else if let Some(entry) = self.dynamic_tools.get_mut(&name) {
226            *entry = tool_arc;
227            Ok(())
228        } else {
229            Err(StrandsError::ToolNotFound { tool_name: name })
230        }
231    }
232
233    /// Clears all tools from the registry.
234    pub fn clear(&mut self) {
235        self.tools.clear();
236        self.dynamic_tools.clear();
237    }
238
239    /// Cleans up the registry and all tool providers.
240    ///
241    /// This removes all consumers from tool providers and clears the registry.
242    /// Errors from individual providers are logged but don't stop cleanup.
243    pub fn cleanup(&mut self) {
244
245        for provider in &self.tool_providers {
246            provider.remove_consumer(&self.registry_id);
247            tracing::debug!("provider cleanup | removed consumer");
248        }
249        self.tool_providers.clear();
250        self.clear();
251    }
252
253    /// Get the registry ID.
254    pub fn registry_id(&self) -> &str {
255        &self.registry_id
256    }
257
258    /// Reloads a tool by name.
259    ///
260    /// In Python, this reloads a tool module from disk for hot-reloading during development.
261    /// In Rust, compiled tools cannot be dynamically reloaded. This method is provided for
262    /// API compatibility and will:
263    /// 1. Log that a reload was requested
264    /// 2. Return success without any action for compiled tools
265    ///
266    /// For dynamic tool implementations (e.g., MCP tools, external process tools),
267    /// subclasses or custom implementations may override this behavior.
268    pub fn reload_tool(&mut self, name: &str) -> Result<(), StrandsError> {
269        if !self.tools.contains_key(name) && !self.dynamic_tools.contains_key(name) {
270            return Err(StrandsError::ToolNotFound {
271                tool_name: name.to_string(),
272            });
273        }
274
275        tracing::info!(
276            "tool_name=<{}> | reload requested (compiled Rust tools do not support hot reload)",
277            name
278        );
279        Ok(())
280    }
281
282    /// Gets the directories being watched for tools.
283    ///
284    /// Returns the standard tool directories that would be scanned for tools.
285    /// In Python, tools can be loaded from these directories at runtime.
286    /// In Rust, this is provided for API compatibility and returns the current
287    /// working directory's `./tools/` path if it exists.
288    pub fn get_tools_dirs(&self) -> Vec<std::path::PathBuf> {
289        let mut dirs = Vec::new();
290
291        if let Ok(cwd) = std::env::current_dir() {
292            let tools_dir = cwd.join("tools");
293            if tools_dir.exists() && tools_dir.is_dir() {
294                tracing::debug!("tools_dir=<{}> | found tools directory", tools_dir.display());
295                dirs.push(tools_dir);
296            }
297        }
298
299        dirs
300    }
301
302    /// Discovers available tool modules in all tools directories.
303    ///
304    /// Returns a map of tool names to their full paths.
305    /// In Python, this scans for `.py` files in the tools directories.
306    /// In Rust, this is provided for API compatibility and scans for common
307    /// tool configuration files (JSON, YAML, WASM).
308    pub fn discover_tool_modules(&self) -> HashMap<String, std::path::PathBuf> {
309        let mut tool_modules = HashMap::new();
310
311        for tools_dir in self.get_tools_dirs() {
312            tracing::debug!("tools_dir=<{}> | scanning", tools_dir.display());
313
314            let entries = match std::fs::read_dir(&tools_dir) {
315                Ok(e) => e,
316                Err(e) => {
317                    tracing::warn!("tools_dir=<{}> | failed to read: {}", tools_dir.display(), e);
318                    continue;
319                }
320            };
321
322            let valid_extensions = ["json", "yaml", "yml", "wasm"];
323
324            for entry in entries.flatten() {
325                let path = entry.path();
326                if !path.is_file() {
327                    continue;
328                }
329
330                let extension = path.extension().and_then(|e| e.to_str()).unwrap_or("");
331                if !valid_extensions.contains(&extension) {
332                    continue;
333                }
334
335                if let Some(stem) = path.file_stem().and_then(|s| s.to_str()) {
336                    if stem.starts_with('_') {
337                        continue;
338                    }
339
340                    tracing::debug!(
341                        "tools_dir=<{}>, module_name=<{}> | discovered tool",
342                        tools_dir.display(),
343                        stem
344                    );
345                    tool_modules.insert(stem.to_string(), path);
346                }
347            }
348        }
349
350        tracing::debug!("tool_modules=<{:?}> | discovered", tool_modules.keys().collect::<Vec<_>>());
351        tool_modules
352    }
353
354    /// Validates a tool specification.
355    pub fn validate_spec(spec: &ToolSpec) -> Result<(), StrandsError> {
356        if spec.name.is_empty() {
357            return Err(StrandsError::ToolValidationError {
358                message: "Tool name cannot be empty".to_string(),
359            });
360        }
361
362        if spec.description.is_empty() {
363            return Err(StrandsError::ToolValidationError {
364                message: format!("Tool '{}' has an empty description", spec.name),
365            });
366        }
367
368        Ok(())
369    }
370}
371
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::{ToolContext, ToolResult2};
    use async_trait::async_trait;

    /// Minimal tool with a configurable name, used to exercise the registry.
    struct DummyTool {
        name: String,
    }

    impl DummyTool {
        fn new(name: &str) -> Self {
            DummyTool {
                name: String::from(name),
            }
        }
    }

    #[async_trait]
    impl AgentTool for DummyTool {
        fn name(&self) -> &str {
            self.name.as_str()
        }

        fn description(&self) -> &str {
            "A dummy tool"
        }

        fn tool_spec(&self) -> ToolSpec {
            ToolSpec::new(&self.name, "A dummy tool")
        }

        async fn invoke(
            &self,
            _input: serde_json::Value,
            _context: &ToolContext,
        ) -> std::result::Result<ToolResult2, String> {
            Ok(ToolResult2::success("dummy result"))
        }
    }

    /// Builds a registry pre-populated (via `register_typed`) with the given names.
    fn registry_with(names: &[&str]) -> ToolRegistry {
        let mut registry = ToolRegistry::new();
        for &name in names {
            registry
                .register_typed(DummyTool::new(name))
                .expect("initial registration should succeed");
        }
        registry
    }

    #[test]
    fn test_registry_register() {
        let registry = registry_with(&["test"]);
        assert_eq!(registry.len(), 1);
        assert!(registry.get("test").is_some());
    }

    #[test]
    fn test_registry_duplicate() {
        let mut registry = registry_with(&["test"]);
        let outcome = registry.register_typed(DummyTool::new("test"));
        assert!(outcome.is_err());
    }

    #[test]
    fn test_registry_normalized_conflict() {
        let mut registry = registry_with(&["my_tool"]);
        let outcome = registry.register_typed(DummyTool::new("my-tool"));
        assert!(outcome.is_err());
    }

    #[test]
    fn test_registry_get_all_specs() {
        let registry = registry_with(&["tool1", "tool2"]);
        assert_eq!(registry.get_all_tool_specs().len(), 2);
    }
}