strands_agents/tools/
loader.rs1use std::collections::HashMap;
4use std::path::{Path, PathBuf};
5use std::sync::Arc;
6
7use crate::tools::AgentTool;
8use crate::types::errors::{Result, StrandsError};
9
/// Settings describing where [`ToolLoader`] looks for tool source files.
///
/// Start from [`ToolLoaderConfig::new`] (or `Default`) and chain the builder
/// methods to fill it in.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ToolLoaderConfig {
    /// Directories to scan, in the order they were added.
    pub tool_dirs: Vec<PathBuf>,
    /// Whether subdirectories of each entry in `tool_dirs` should be scanned too.
    pub recursive: bool,
    /// Glob-style file-name patterns describing tool source files
    /// (the default is `["*.rs"]`).
    pub file_patterns: Vec<String>,
}

impl Default for ToolLoaderConfig {
    /// No directories, non-recursive, and the single pattern `*.rs`.
    fn default() -> Self {
        Self {
            tool_dirs: Vec::new(),
            recursive: false,
            file_patterns: vec!["*.rs".to_string()],
        }
    }
}

impl ToolLoaderConfig {
    /// Creates a configuration equal to [`ToolLoaderConfig::default`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends `dir` to the list of directories to scan.
    #[must_use]
    pub fn add_dir(mut self, dir: impl Into<PathBuf>) -> Self {
        self.tool_dirs.push(dir.into());
        self
    }

    /// Sets whether subdirectories should be scanned as well.
    #[must_use]
    pub fn recursive(mut self, recursive: bool) -> Self {
        self.recursive = recursive;
        self
    }

    /// Replaces the file-name patterns with `patterns`.
    ///
    /// Completes the builder API: every other field already has a setter.
    /// Passing an empty collection clears the list.
    #[must_use]
    pub fn file_patterns<I, S>(mut self, patterns: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        self.file_patterns = patterns.into_iter().map(Into::into).collect();
        self
    }
}
43
44pub struct ToolLoader {
46 config: ToolLoaderConfig,
47 loaded_tools: HashMap<String, Arc<dyn AgentTool>>,
48 tool_paths: HashMap<String, PathBuf>,
49}
50
51impl ToolLoader {
52 pub fn new(config: ToolLoaderConfig) -> Self {
53 Self {
54 config,
55 loaded_tools: HashMap::new(),
56 tool_paths: HashMap::new(),
57 }
58 }
59
60 pub fn tool_dirs(&self) -> &[PathBuf] {
62 &self.config.tool_dirs
63 }
64
65 pub fn tools(&self) -> Vec<Arc<dyn AgentTool>> {
67 self.loaded_tools.values().cloned().collect()
68 }
69
70 pub fn get_tool(&self, name: &str) -> Option<Arc<dyn AgentTool>> {
72 self.loaded_tools.get(name).cloned()
73 }
74
75 pub fn has_tool(&self, name: &str) -> bool {
77 self.loaded_tools.contains_key(name)
78 }
79
80 pub fn register_tool(&mut self, tool: Arc<dyn AgentTool>, path: Option<PathBuf>) {
82 let name = tool.tool_name().to_string();
83 self.loaded_tools.insert(name.clone(), tool);
84 if let Some(p) = path {
85 self.tool_paths.insert(name, p);
86 }
87 }
88
89 pub fn unregister_tool(&mut self, name: &str) -> Option<Arc<dyn AgentTool>> {
91 self.tool_paths.remove(name);
92 self.loaded_tools.remove(name)
93 }
94
95 pub fn tool_path(&self, name: &str) -> Option<&PathBuf> {
97 self.tool_paths.get(name)
98 }
99
100 pub fn scan_directories(&self) -> Result<Vec<PathBuf>> {
102 let mut files = Vec::new();
103
104 for dir in &self.config.tool_dirs {
105 if !dir.exists() {
106 continue;
107 }
108
109 self.scan_directory(dir, &mut files)?;
110 }
111
112 Ok(files)
113 }
114
115 fn scan_directory(&self, dir: &Path, files: &mut Vec<PathBuf>) -> Result<()> {
116 let entries = std::fs::read_dir(dir).map_err(|e| StrandsError::InternalError {
117 message: format!("Failed to read directory {}: {}", dir.display(), e),
118 })?;
119
120 for entry in entries.flatten() {
121 let path = entry.path();
122
123 if path.is_dir() && self.config.recursive {
124 self.scan_directory(&path, files)?;
125 } else if path.is_file() {
126 if let Some(ext) = path.extension() {
127 if ext == "rs" {
128 files.push(path);
129 }
130 }
131 }
132 }
133
134 Ok(())
135 }
136}
137
138pub type ReloadCallback = Arc<dyn Fn(&str) + Send + Sync>;
140
141pub struct ToolWatcher {
155 loader: ToolLoader,
156 on_reload: Option<ReloadCallback>,
157}
158
159impl ToolWatcher {
160 pub fn new(loader: ToolLoader) -> Self {
161 Self {
162 loader,
163 on_reload: None,
164 }
165 }
166
167 pub fn on_reload(mut self, callback: ReloadCallback) -> Self {
168 self.on_reload = Some(callback);
169 self
170 }
171
172 pub fn loader(&self) -> &ToolLoader {
174 &self.loader
175 }
176
177 pub fn loader_mut(&mut self) -> &mut ToolLoader {
179 &mut self.loader
180 }
181
182 pub fn notify_modified(&self, tool_name: &str) {
184 if let Some(ref callback) = self.on_reload {
185 callback(tool_name);
186 }
187 }
188
189 pub fn watched_dirs(&self) -> &[PathBuf] {
191 self.loader.tool_dirs()
192 }
193}
194
#[cfg(test)]
mod tests {
    use super::*;

    /// The builder records the added directory and the recursion flag.
    #[test]
    fn test_tool_loader_config() {
        let cfg = ToolLoaderConfig::new().add_dir("/tmp/tools").recursive(true);

        assert_eq!(1, cfg.tool_dirs.len());
        assert!(cfg.recursive, "recursive(true) should set the flag");
    }

    /// A loader built from the default config has no tools and no directories.
    #[test]
    fn test_tool_loader_creation() {
        let loader = ToolLoader::new(ToolLoaderConfig::new());

        assert_eq!(loader.tools().len(), 0);
        assert_eq!(loader.tool_dirs().len(), 0);
    }
}
218