strands_agents/tools/
watcher.rs

1//! Tool watcher for hot reloading tools during development.
2
use std::collections::{HashMap, HashSet};
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex, PoisonError, RwLock};
use std::time::Duration;

use crate::tools::registry::ToolRegistry;
use crate::types::errors::Result;
10
11/// Watches tool directories for changes and reloads tools when modified.
12pub struct ToolWatcher {
13    tool_registry: Arc<RwLock<ToolRegistry>>,
14    change_handler: Arc<ToolChangeHandler>,
15    watched_dirs: HashSet<PathBuf>,
16    running: Arc<Mutex<bool>>,
17}
18
19impl ToolWatcher {
20    /// Create a new tool watcher for the given tool registry.
21    pub fn new(tool_registry: Arc<RwLock<ToolRegistry>>) -> Self {
22        let change_handler = Arc::new(ToolChangeHandler::new(tool_registry.clone()));
23        Self {
24            tool_registry,
25            change_handler,
26            watched_dirs: HashSet::new(),
27            running: Arc::new(Mutex::new(false)),
28        }
29    }
30
31    /// Add a directory to watch.
32    pub fn watch_dir(&mut self, dir: PathBuf) {
33        self.watched_dirs.insert(dir);
34    }
35
36    /// Get the tool registry.
37    pub fn tool_registry(&self) -> &Arc<RwLock<ToolRegistry>> {
38        &self.tool_registry
39    }
40
41    /// Get the change handler.
42    pub fn change_handler(&self) -> &Arc<ToolChangeHandler> {
43        &self.change_handler
44    }
45
46    /// Start watching for changes using polling.
47    pub fn start(&self) -> Result<()> {
48        let mut running = self.running.lock().unwrap();
49        if *running {
50            return Ok(());
51        }
52        *running = true;
53
54        let running_flag = self.running.clone();
55        let handler = self.change_handler.clone();
56        let dirs: Vec<PathBuf> = self.watched_dirs.iter().cloned().collect();
57
58        tokio::spawn(async move {
59            while *running_flag.lock().unwrap() {
60                for dir in &dirs {
61                    let changed = handler.poll_changes(dir);
62                    for path in changed {
63                        handler.on_modified(&path);
64                    }
65                }
66                tokio::time::sleep(Duration::from_secs(1)).await;
67            }
68        });
69
70        tracing::debug!("Tool watcher started for {} directories", self.watched_dirs.len());
71
72        Ok(())
73    }
74
75    /// Stop watching for changes.
76    pub fn stop(&self) {
77        let mut running = self.running.lock().unwrap();
78        *running = false;
79        tracing::debug!("Tool watcher stopped");
80    }
81
82    /// Check if the watcher is running.
83    pub fn is_running(&self) -> bool {
84        *self.running.lock().unwrap()
85    }
86
87    /// Get the watched directories.
88    pub fn watched_dirs(&self) -> &HashSet<PathBuf> {
89        &self.watched_dirs
90    }
91}
92
93/// Change handler for tool file modifications.
94pub struct ToolChangeHandler {
95    tool_registry: Arc<RwLock<ToolRegistry>>,
96    file_timestamps: Mutex<HashMap<PathBuf, std::time::SystemTime>>,
97}
98
99impl ToolChangeHandler {
100    pub fn new(tool_registry: Arc<RwLock<ToolRegistry>>) -> Self {
101        Self {
102            tool_registry,
103            file_timestamps: Mutex::new(HashMap::new()),
104        }
105    }
106
107    /// Handle a file modification event.
108    pub fn on_modified(&self, path: &PathBuf) {
109        if let Some(ext) = path.extension() {
110            if ext != "rs" {
111                return;
112            }
113        } else {
114            return;
115        }
116
117        if let Some(stem) = path.file_stem() {
118            let stem_str = stem.to_string_lossy();
119            if stem_str == "mod" || stem_str.starts_with("_") {
120                return;
121            }
122
123            tracing::debug!("Tool change detected: {}", stem_str);
124
125            if let Ok(mut registry) = self.tool_registry.write() {
126                if let Err(e) = registry.reload_tool(&stem_str) {
127                    tracing::error!("Failed to reload tool {}: {}", stem_str, e);
128                }
129            }
130        }
131    }
132
133    /// Check for modifications using polling.
134    pub fn poll_changes(&self, dir: &PathBuf) -> Vec<PathBuf> {
135        let mut changed = Vec::new();
136        let mut timestamps = self.file_timestamps.lock().unwrap();
137
138        if let Ok(entries) = std::fs::read_dir(dir) {
139            for entry in entries.flatten() {
140                let path = entry.path();
141                if path.extension().map(|e| e == "rs").unwrap_or(false) {
142                    if let Ok(metadata) = std::fs::metadata(&path) {
143                        if let Ok(modified) = metadata.modified() {
144                            let prev = timestamps.get(&path).copied();
145                            if prev.map(|p| modified > p).unwrap_or(true) {
146                                timestamps.insert(path.clone(), modified);
147                                if prev.is_some() {
148                                    changed.push(path);
149                                }
150                            }
151                        }
152                    }
153                }
154            }
155        }
156
157        changed
158    }
159}
160
161/// Master handler that delegates to all registered handlers.
162pub struct MasterChangeHandler {
163    dir_path: PathBuf,
164    handlers: Arc<RwLock<HashMap<String, Arc<ToolChangeHandler>>>>,
165}
166
167impl MasterChangeHandler {
168    pub fn new(dir_path: PathBuf) -> Self {
169        Self {
170            dir_path,
171            handlers: Arc::new(RwLock::new(HashMap::new())),
172        }
173    }
174
175    /// Register a handler for a registry.
176    pub fn add_handler(&self, registry_id: String, handler: Arc<ToolChangeHandler>) {
177        if let Ok(mut handlers) = self.handlers.write() {
178            handlers.insert(registry_id, handler);
179        }
180    }
181
182    /// Remove a handler.
183    pub fn remove_handler(&self, registry_id: &str) {
184        if let Ok(mut handlers) = self.handlers.write() {
185            handlers.remove(registry_id);
186        }
187    }
188
189    /// Handle a file modification.
190    pub fn on_modified(&self, path: &PathBuf) {
191        if let Ok(handlers) = self.handlers.read() {
192            for handler in handlers.values() {
193                handler.on_modified(path);
194            }
195        }
196    }
197
198    /// Get the directory path.
199    pub fn dir_path(&self) -> &PathBuf {
200        &self.dir_path
201    }
202}
203
204/// Polling-based watcher that periodically checks for file changes.
205pub struct PollingWatcher {
206    watched_dirs: Vec<PathBuf>,
207    interval: Duration,
208    running: Arc<Mutex<bool>>,
209    handler: Arc<ToolChangeHandler>,
210}
211
212impl PollingWatcher {
213    /// Create a new polling watcher.
214    pub fn new(tool_registry: Arc<RwLock<ToolRegistry>>, interval: Duration) -> Self {
215        let handler = Arc::new(ToolChangeHandler::new(tool_registry));
216        Self {
217            watched_dirs: Vec::new(),
218            interval,
219            running: Arc::new(Mutex::new(false)),
220            handler,
221        }
222    }
223
224    /// Add a directory to watch.
225    pub fn watch_dir(&mut self, dir: PathBuf) {
226        self.watched_dirs.push(dir);
227    }
228
229    /// Start the polling loop in a background task.
230    pub fn start(&self) {
231        let mut running = self.running.lock().unwrap();
232        if *running {
233            return;
234        }
235        *running = true;
236
237        let running_flag = self.running.clone();
238        let handler = self.handler.clone();
239        let dirs = self.watched_dirs.clone();
240        let interval = self.interval;
241
242        tokio::spawn(async move {
243            while *running_flag.lock().unwrap() {
244                for dir in &dirs {
245                    let changed = handler.poll_changes(dir);
246                    for path in changed {
247                        handler.on_modified(&path);
248                    }
249                }
250                tokio::time::sleep(interval).await;
251            }
252        });
253
254        tracing::info!("Polling watcher started with {:?} interval", self.interval);
255    }
256
257    /// Stop the polling loop.
258    pub fn stop(&self) {
259        let mut running = self.running.lock().unwrap();
260        *running = false;
261        tracing::info!("Polling watcher stopped");
262    }
263
264    /// Check if running.
265    pub fn is_running(&self) -> bool {
266        *self.running.lock().unwrap()
267    }
268}
269
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh, empty registry shared by the tests.
    fn new_registry() -> Arc<RwLock<ToolRegistry>> {
        Arc::new(RwLock::new(ToolRegistry::new()))
    }

    #[test]
    fn test_tool_watcher_creation() {
        let mut watcher = ToolWatcher::new(new_registry());

        watcher.watch_dir(PathBuf::from("/tmp/tools"));
        // Directories are kept in a set, so re-adding one is a no-op.
        watcher.watch_dir(PathBuf::from("/tmp/tools"));
        assert_eq!(watcher.watched_dirs().len(), 1);
        assert!(!watcher.is_running());
    }

    #[test]
    fn test_tool_change_handler() {
        let handler = ToolChangeHandler::new(new_registry());

        // Non-Rust files, `mod.rs` and `_`-prefixed files are all ignored and
        // must not touch the registry.
        handler.on_modified(&PathBuf::from("/tmp/test.txt"));
        handler.on_modified(&PathBuf::from("/tmp/mod.rs"));
        handler.on_modified(&PathBuf::from("/tmp/_private.rs"));
    }

    #[test]
    fn test_poll_changes_detects_modified_rust_files() {
        let dir = std::env::temp_dir().join(format!(
            "tool_watcher_poll_test_{}",
            std::process::id()
        ));
        let _ = std::fs::remove_dir_all(&dir);
        std::fs::create_dir_all(&dir).unwrap();
        let tool = dir.join("my_tool.rs");
        let notes = dir.join("notes.txt");
        std::fs::write(&tool, "// v1").unwrap();
        std::fs::write(&notes, "not a tool").unwrap();

        let handler = ToolChangeHandler::new(new_registry());
        // The first scan only records a baseline.
        assert!(handler.poll_changes(&dir).is_empty());
        // Nothing changed since the baseline.
        assert!(handler.poll_changes(&dir).is_empty());

        // Bump mtimes explicitly instead of rewriting the files: filesystems with
        // coarse timestamps might not register a quick rewrite.
        let later = std::time::SystemTime::now() + Duration::from_secs(60);
        for path in [&tool, &notes] {
            std::fs::File::options()
                .write(true)
                .open(path)
                .unwrap()
                .set_modified(later)
                .unwrap();
        }

        // Only the `.rs` file is reported, exactly once.
        assert_eq!(handler.poll_changes(&dir), vec![tool]);
        assert!(handler.poll_changes(&dir).is_empty());

        let _ = std::fs::remove_dir_all(&dir);
    }

    #[test]
    fn test_master_change_handler() {
        let handler = MasterChangeHandler::new(PathBuf::from("/tmp/tools"));
        assert_eq!(handler.dir_path(), &PathBuf::from("/tmp/tools"));
    }
}
297