strands_agents/tools/
registry.rs1use std::collections::HashMap;
4use std::sync::Arc;
5use uuid::Uuid;
6
7use crate::types::errors::StrandsError;
8use crate::types::tools::ToolSpec;
9
10use super::mcp::ToolProvider;
11use super::AgentTool;
12
13pub enum ToolInput {
15 Tool(Box<dyn AgentTool>),
17 Provider(Arc<dyn ToolProvider>),
19 Multiple(Vec<ToolInput>),
21}
22
23impl ToolInput {
24 pub fn tool(tool: impl AgentTool + 'static) -> Self {
26 Self::Tool(Box::new(tool))
27 }
28
29 pub fn provider(provider: impl ToolProvider + 'static) -> Self {
31 Self::Provider(Arc::new(provider))
32 }
33
34 pub fn multiple(inputs: impl IntoIterator<Item = ToolInput>) -> Self {
36 Self::Multiple(inputs.into_iter().collect())
37 }
38}
39
40pub struct ToolRegistry {
42 tools: HashMap<String, Arc<dyn AgentTool>>,
43 dynamic_tools: HashMap<String, Arc<dyn AgentTool>>,
44 tool_providers: Vec<Arc<dyn ToolProvider>>,
45 registry_id: String,
46}
47
48impl Default for ToolRegistry {
49 fn default() -> Self { Self::new() }
50}
51
52impl ToolRegistry {
53 pub fn new() -> Self {
54 Self {
55 tools: HashMap::new(),
56 dynamic_tools: HashMap::new(),
57 tool_providers: Vec::new(),
58 registry_id: Uuid::new_v4().to_string(),
59 }
60 }
61
62 pub async fn process_tools(&mut self, inputs: Vec<ToolInput>) -> Result<Vec<String>, StrandsError> {
75 let mut tool_names = Vec::new();
76 self.process_tools_recursive(inputs, &mut tool_names).await?;
77 Ok(tool_names)
78 }
79
80 async fn process_tools_recursive(
82 &mut self,
83 inputs: Vec<ToolInput>,
84 tool_names: &mut Vec<String>,
85 ) -> Result<(), StrandsError> {
86 for input in inputs {
87 match input {
88 ToolInput::Tool(tool) => {
89 let name = tool.tool_name().to_string();
90 if tool.is_dynamic() {
91 self.dynamic_tools.insert(name.clone(), Arc::from(tool));
92 } else {
93 self.tools.insert(name.clone(), Arc::from(tool));
94 }
95 tool_names.push(name);
96 }
97 ToolInput::Provider(provider) => {
98 provider.add_consumer(&self.registry_id);
99 let provider_tools = provider.load_tools().await
100 .map_err(|e| StrandsError::ToolError {
101 tool_name: "provider".to_string(),
102 message: format!("Failed to load tools from provider: {}", e),
103 })?;
104
105 for tool in provider_tools {
106 let name = tool.tool_name().to_string();
107 self.tools.insert(name.clone(), tool);
108 tool_names.push(name);
109 }
110
111 self.tool_providers.push(provider);
112 }
113 ToolInput::Multiple(nested) => {
114 Box::pin(self.process_tools_recursive(nested, tool_names)).await?;
115 }
116 }
117 }
118 Ok(())
119 }
120
121 pub fn process_tools_sync(&mut self, inputs: Vec<ToolInput>) -> Result<Vec<String>, StrandsError> {
125 crate::async_utils::run_async(self.process_tools(inputs))
126 }
127
128 pub fn register(&mut self, tool: Box<dyn AgentTool>) {
130 let name = tool.tool_name().to_string();
131 self.tools.insert(name, Arc::from(tool));
132 }
133
134 pub fn register_typed(&mut self, tool: impl AgentTool + 'static) -> Result<(), StrandsError> {
136 let name = tool.tool_name().to_string();
137
138 if self.tools.contains_key(&name) {
139 return Err(StrandsError::ConfigurationError {
140 message: format!("Tool '{name}' already exists"),
141 });
142 }
143
144 let normalized_name = name.replace('-', "_");
145 for existing_name in self.tools.keys() {
146 if existing_name.replace('-', "_") == normalized_name && *existing_name != name {
147 return Err(StrandsError::ConfigurationError {
148 message: format!(
149 "Tool '{name}' conflicts with existing tool '{existing_name}' (differ only by - vs _)"
150 ),
151 });
152 }
153 }
154
155 self.tools.insert(name, Arc::new(tool));
156 Ok(())
157 }
158
159 pub fn register_all(
161 &mut self,
162 tools: impl IntoIterator<Item = impl AgentTool + 'static>,
163 ) {
164 for tool in tools {
165 self.tools.insert(tool.tool_name().to_string(), Arc::new(tool));
166 }
167 }
168
169 pub fn get(&self, name: &str) -> Option<Arc<dyn AgentTool>> {
171 self.tools.get(name).or_else(|| self.dynamic_tools.get(name)).cloned()
172 }
173
174 pub fn tool_names(&self) -> Vec<&str> {
176 self.tools.keys().chain(self.dynamic_tools.keys()).map(|s| s.as_str()).collect()
177 }
178
179 pub fn get_all_tool_specs(&self) -> Vec<ToolSpec> {
181 self.tools.values().chain(self.dynamic_tools.values()).map(|t| t.tool_spec()).collect()
182 }
183
184 pub fn get_all_tools_config(&self) -> HashMap<String, ToolSpec> {
186 self.tools.iter().chain(self.dynamic_tools.iter()).map(|(n, t)| (n.clone(), t.tool_spec())).collect()
187 }
188
189 pub fn len(&self) -> usize { self.tools.len() + self.dynamic_tools.len() }
190 pub fn is_empty(&self) -> bool { self.tools.is_empty() && self.dynamic_tools.is_empty() }
191
192 pub fn register_dynamic(&mut self, tool: impl AgentTool + 'static) -> Result<(), StrandsError> {
194 let name = tool.tool_name().to_string();
195
196 if self.tools.contains_key(&name) || self.dynamic_tools.contains_key(&name) {
197 return Err(StrandsError::ConfigurationError {
198 message: format!("Tool '{name}' already exists"),
199 });
200 }
201
202 self.dynamic_tools.insert(name, Arc::new(tool));
203 Ok(())
204 }
205
206 pub fn register_spec(&mut self, spec: ToolSpec) -> Result<(), StrandsError> {
208 let tool = super::structured_output::StructuredOutputAgentTool::from_spec(spec);
209 self.register_typed(tool)
210 }
211
212 pub fn remove_dynamic(&mut self, name: &str) -> bool {
214 self.dynamic_tools.remove(name).is_some()
215 }
216
217 pub fn replace(&mut self, tool: impl AgentTool + 'static) -> Result<(), StrandsError> {
219 let name = tool.tool_name().to_string();
220 let tool_arc = Arc::new(tool);
221
222 if let Some(entry) = self.tools.get_mut(&name) {
223 *entry = tool_arc;
224 Ok(())
225 } else if let Some(entry) = self.dynamic_tools.get_mut(&name) {
226 *entry = tool_arc;
227 Ok(())
228 } else {
229 Err(StrandsError::ToolNotFound { tool_name: name })
230 }
231 }
232
233 pub fn clear(&mut self) {
235 self.tools.clear();
236 self.dynamic_tools.clear();
237 }
238
239 pub fn cleanup(&mut self) {
244
245 for provider in &self.tool_providers {
246 provider.remove_consumer(&self.registry_id);
247 tracing::debug!("provider cleanup | removed consumer");
248 }
249 self.tool_providers.clear();
250 self.clear();
251 }
252
253 pub fn registry_id(&self) -> &str {
255 &self.registry_id
256 }
257
258 pub fn reload_tool(&mut self, name: &str) -> Result<(), StrandsError> {
269 if !self.tools.contains_key(name) && !self.dynamic_tools.contains_key(name) {
270 return Err(StrandsError::ToolNotFound {
271 tool_name: name.to_string(),
272 });
273 }
274
275 tracing::info!(
276 "tool_name=<{}> | reload requested (compiled Rust tools do not support hot reload)",
277 name
278 );
279 Ok(())
280 }
281
282 pub fn get_tools_dirs(&self) -> Vec<std::path::PathBuf> {
289 let mut dirs = Vec::new();
290
291 if let Ok(cwd) = std::env::current_dir() {
292 let tools_dir = cwd.join("tools");
293 if tools_dir.exists() && tools_dir.is_dir() {
294 tracing::debug!("tools_dir=<{}> | found tools directory", tools_dir.display());
295 dirs.push(tools_dir);
296 }
297 }
298
299 dirs
300 }
301
302 pub fn discover_tool_modules(&self) -> HashMap<String, std::path::PathBuf> {
309 let mut tool_modules = HashMap::new();
310
311 for tools_dir in self.get_tools_dirs() {
312 tracing::debug!("tools_dir=<{}> | scanning", tools_dir.display());
313
314 let entries = match std::fs::read_dir(&tools_dir) {
315 Ok(e) => e,
316 Err(e) => {
317 tracing::warn!("tools_dir=<{}> | failed to read: {}", tools_dir.display(), e);
318 continue;
319 }
320 };
321
322 let valid_extensions = ["json", "yaml", "yml", "wasm"];
323
324 for entry in entries.flatten() {
325 let path = entry.path();
326 if !path.is_file() {
327 continue;
328 }
329
330 let extension = path.extension().and_then(|e| e.to_str()).unwrap_or("");
331 if !valid_extensions.contains(&extension) {
332 continue;
333 }
334
335 if let Some(stem) = path.file_stem().and_then(|s| s.to_str()) {
336 if stem.starts_with('_') {
337 continue;
338 }
339
340 tracing::debug!(
341 "tools_dir=<{}>, module_name=<{}> | discovered tool",
342 tools_dir.display(),
343 stem
344 );
345 tool_modules.insert(stem.to_string(), path);
346 }
347 }
348 }
349
350 tracing::debug!("tool_modules=<{:?}> | discovered", tool_modules.keys().collect::<Vec<_>>());
351 tool_modules
352 }
353
354 pub fn validate_spec(spec: &ToolSpec) -> Result<(), StrandsError> {
356 if spec.name.is_empty() {
357 return Err(StrandsError::ToolValidationError {
358 message: "Tool name cannot be empty".to_string(),
359 });
360 }
361
362 if spec.description.is_empty() {
363 return Err(StrandsError::ToolValidationError {
364 message: format!("Tool '{}' has an empty description", spec.name),
365 });
366 }
367
368 Ok(())
369 }
370}
371
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::{ToolContext, ToolResult2};
    use async_trait::async_trait;

    /// Minimal `AgentTool` whose only configurable property is its name.
    struct DummyTool {
        name: String,
    }

    impl DummyTool {
        fn new(name: &str) -> Self {
            DummyTool {
                name: name.to_owned(),
            }
        }
    }

    #[async_trait]
    impl AgentTool for DummyTool {
        fn name(&self) -> &str {
            self.name.as_str()
        }

        fn description(&self) -> &str {
            "A dummy tool"
        }

        fn tool_spec(&self) -> ToolSpec {
            ToolSpec::new(&self.name, "A dummy tool")
        }

        async fn invoke(
            &self,
            _input: serde_json::Value,
            _context: &ToolContext,
        ) -> std::result::Result<ToolResult2, String> {
            Ok(ToolResult2::success("dummy result"))
        }
    }

    /// Builds a registry holding one `DummyTool` per name, registered via `register_typed`.
    fn registry_with(names: &[&str]) -> ToolRegistry {
        let mut registry = ToolRegistry::new();
        for &name in names {
            registry.register_typed(DummyTool::new(name)).unwrap();
        }
        registry
    }

    #[test]
    fn test_registry_register() {
        let registry = registry_with(&["test"]);
        assert_eq!(registry.len(), 1);
        assert!(registry.get("test").is_some());
    }

    #[test]
    fn test_registry_duplicate() {
        let mut registry = registry_with(&["test"]);
        assert!(registry.register_typed(DummyTool::new("test")).is_err());
    }

    #[test]
    fn test_registry_normalized_conflict() {
        let mut registry = registry_with(&["my_tool"]);
        assert!(registry.register_typed(DummyTool::new("my-tool")).is_err());
    }

    #[test]
    fn test_registry_get_all_specs() {
        let registry = registry_with(&["tool1", "tool2"]);
        assert_eq!(registry.get_all_tool_specs().len(), 2);
    }
}