strands_agents/tools/
loader.rs

1//! Tool loading utilities for dynamic tool discovery.
2
3use std::collections::HashMap;
4use std::path::{Path, PathBuf};
5use std::sync::Arc;
6
7use crate::tools::AgentTool;
8use crate::types::errors::{Result, StrandsError};
9
/// Configuration describing where tool files are looked for.
///
/// Values are assembled builder-style: start from [`ToolLoaderConfig::new`]
/// and chain [`add_dir`](ToolLoaderConfig::add_dir),
/// [`recursive`](ToolLoaderConfig::recursive) and
/// [`file_patterns`](ToolLoaderConfig::file_patterns).
#[derive(Debug, Clone)]
pub struct ToolLoaderConfig {
    /// Directories to search, in the order they were added.
    pub tool_dirs: Vec<PathBuf>,
    /// Whether sub-directories of each entry in `tool_dirs` are searched too.
    pub recursive: bool,
    /// File-name patterns identifying tool files; defaults to `*.rs`.
    pub file_patterns: Vec<String>,
}

impl Default for ToolLoaderConfig {
    /// No directories, non-recursive, with the single pattern `*.rs`.
    fn default() -> Self {
        Self {
            tool_dirs: Vec::new(),
            recursive: false,
            file_patterns: vec!["*.rs".to_string()],
        }
    }
}

impl ToolLoaderConfig {
    /// Creates the default configuration (same as [`ToolLoaderConfig::default`]).
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends `dir` to the list of directories to search.
    #[must_use]
    pub fn add_dir(mut self, dir: impl Into<PathBuf>) -> Self {
        self.tool_dirs.push(dir.into());
        self
    }

    /// Sets whether sub-directories are searched as well.
    #[must_use]
    pub fn recursive(mut self, recursive: bool) -> Self {
        self.recursive = recursive;
        self
    }

    /// Replaces the file-name patterns, discarding the current list
    /// (including the default `*.rs`).
    ///
    /// Completes the builder: previously `file_patterns` could only be changed
    /// by assigning the public field directly.
    #[must_use]
    pub fn file_patterns<I, S>(mut self, patterns: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        self.file_patterns = patterns.into_iter().map(Into::into).collect();
        self
    }
}
43
44/// Tool loader for discovering and loading tools from directories.
45pub struct ToolLoader {
46    config: ToolLoaderConfig,
47    loaded_tools: HashMap<String, Arc<dyn AgentTool>>,
48    tool_paths: HashMap<String, PathBuf>,
49}
50
51impl ToolLoader {
52    pub fn new(config: ToolLoaderConfig) -> Self {
53        Self {
54            config,
55            loaded_tools: HashMap::new(),
56            tool_paths: HashMap::new(),
57        }
58    }
59
60    /// Get the list of tool directories being watched.
61    pub fn tool_dirs(&self) -> &[PathBuf] {
62        &self.config.tool_dirs
63    }
64
65    /// Get all loaded tools.
66    pub fn tools(&self) -> Vec<Arc<dyn AgentTool>> {
67        self.loaded_tools.values().cloned().collect()
68    }
69
70    /// Get a specific tool by name.
71    pub fn get_tool(&self, name: &str) -> Option<Arc<dyn AgentTool>> {
72        self.loaded_tools.get(name).cloned()
73    }
74
75    /// Check if a tool is loaded.
76    pub fn has_tool(&self, name: &str) -> bool {
77        self.loaded_tools.contains_key(name)
78    }
79
80    /// Register a tool.
81    pub fn register_tool(&mut self, tool: Arc<dyn AgentTool>, path: Option<PathBuf>) {
82        let name = tool.tool_name().to_string();
83        self.loaded_tools.insert(name.clone(), tool);
84        if let Some(p) = path {
85            self.tool_paths.insert(name, p);
86        }
87    }
88
89    /// Unregister a tool.
90    pub fn unregister_tool(&mut self, name: &str) -> Option<Arc<dyn AgentTool>> {
91        self.tool_paths.remove(name);
92        self.loaded_tools.remove(name)
93    }
94
95    /// Get the path for a tool.
96    pub fn tool_path(&self, name: &str) -> Option<&PathBuf> {
97        self.tool_paths.get(name)
98    }
99
100    /// Scan directories for tool files.
101    pub fn scan_directories(&self) -> Result<Vec<PathBuf>> {
102        let mut files = Vec::new();
103
104        for dir in &self.config.tool_dirs {
105            if !dir.exists() {
106                continue;
107            }
108
109            self.scan_directory(dir, &mut files)?;
110        }
111
112        Ok(files)
113    }
114
115    fn scan_directory(&self, dir: &Path, files: &mut Vec<PathBuf>) -> Result<()> {
116        let entries = std::fs::read_dir(dir).map_err(|e| StrandsError::InternalError {
117            message: format!("Failed to read directory {}: {}", dir.display(), e),
118        })?;
119
120        for entry in entries.flatten() {
121            let path = entry.path();
122
123            if path.is_dir() && self.config.recursive {
124                self.scan_directory(&path, files)?;
125            } else if path.is_file() {
126                if let Some(ext) = path.extension() {
127                    if ext == "rs" {
128                        files.push(path);
129                    }
130                }
131            }
132        }
133
134        Ok(())
135    }
136}
137
/// Shared handler invoked with the name of a tool when it is reported as
/// modified. The `Send + Sync` bounds let one handler be shared across threads.
pub type ReloadCallback = Arc<dyn Fn(&str) + Send + Sync>;
140
141/// Tool watcher for monitoring tool changes during development.
142///
143/// This implementation provides a callback-based notification system for tool changes.
144/// Unlike the Python SDK which uses `watchdog` for automatic file system monitoring,
145/// Rust tools are typically compiled and don't support runtime hot-reloading.
146///
147/// This watcher is useful for:
148/// - Notifying when tools are programmatically reloaded
149/// - Integration with external file watchers (e.g., `notify` crate)
150/// - Development-time tooling that manages tool lifecycle
151///
152/// For automatic file watching, integrate with the `notify` crate and call
153/// `notify_modified` when file changes are detected.
154pub struct ToolWatcher {
155    loader: ToolLoader,
156    on_reload: Option<ReloadCallback>,
157}
158
159impl ToolWatcher {
160    pub fn new(loader: ToolLoader) -> Self {
161        Self {
162            loader,
163            on_reload: None,
164        }
165    }
166
167    pub fn on_reload(mut self, callback: ReloadCallback) -> Self {
168        self.on_reload = Some(callback);
169        self
170    }
171
172    /// Get the tool loader.
173    pub fn loader(&self) -> &ToolLoader {
174        &self.loader
175    }
176
177    /// Get mutable reference to the tool loader.
178    pub fn loader_mut(&mut self) -> &mut ToolLoader {
179        &mut self.loader
180    }
181
182    /// Notify that a tool has been modified.
183    pub fn notify_modified(&self, tool_name: &str) {
184        if let Some(ref callback) = self.on_reload {
185            callback(tool_name);
186        }
187    }
188
189    /// Get the directories being watched.
190    pub fn watched_dirs(&self) -> &[PathBuf] {
191        self.loader.tool_dirs()
192    }
193}
194
#[cfg(test)]
mod tests {
    use super::*;

    /// The builder records the added directory and the recursion flag.
    #[test]
    fn test_tool_loader_config() {
        let ToolLoaderConfig {
            tool_dirs,
            recursive,
            ..
        } = ToolLoaderConfig::new().add_dir("/tmp/tools").recursive(true);

        assert_eq!(tool_dirs.len(), 1);
        assert!(recursive);
    }

    /// A loader built from the default configuration starts out empty.
    #[test]
    fn test_tool_loader_creation() {
        let loader = ToolLoader::new(ToolLoaderConfig::new());

        let registered = loader.tools();
        assert!(registered.is_empty());

        let dirs = loader.tool_dirs();
        assert!(dirs.is_empty());
    }
}
218