Skip to main content

sen_plugin_host/
watcher.rs

//! Hot reload file watcher for plugins.
//!
//! Watches plugin directories for changes to `.wasm` files and automatically
//! loads, reloads, or unloads the corresponding plugins.
5
6use crate::{LoaderError, PluginRegistry};
7use notify::{Event, RecommendedWatcher, RecursiveMode, Watcher};
8use std::path::{Path, PathBuf};
9use std::time::Duration;
10use tokio::sync::mpsc;
11
/// Tunable settings for [`HotReloadWatcher`].
///
/// The [`Default`] value waits 500 ms for file activity to settle and loads
/// the plugins already present in the watched directories at startup.
#[derive(Debug, Clone)]
pub struct WatcherConfig {
    /// How long file events must stop arriving before pending changes are applied
    pub debounce: Duration,
    /// Load the `.wasm` files already present when the watcher starts
    pub load_existing: bool,
}

impl Default for WatcherConfig {
    fn default() -> Self {
        const DEFAULT_DEBOUNCE_MS: u64 = 500;
        Self {
            load_existing: true,
            debounce: Duration::from_millis(DEFAULT_DEBOUNCE_MS),
        }
    }
}
29
30/// Hot reload watcher for plugin directories
31pub struct HotReloadWatcher {
32    registry: PluginRegistry,
33    _watcher: RecommendedWatcher,
34    shutdown_tx: mpsc::Sender<()>,
35}
36
37impl HotReloadWatcher {
38    /// Create a new hot reload watcher for the given directories
39    pub async fn new(
40        registry: PluginRegistry,
41        directories: impl IntoIterator<Item = impl AsRef<Path>>,
42        config: WatcherConfig,
43    ) -> Result<Self, WatcherError> {
44        let directories: Vec<PathBuf> = directories
45            .into_iter()
46            .map(|p| p.as_ref().to_path_buf())
47            .collect();
48
49        // Load existing plugins if configured
50        if config.load_existing {
51            for dir in &directories {
52                if dir.exists() && dir.is_dir() {
53                    Self::load_directory(&registry, dir).await?;
54                }
55            }
56        }
57
58        // Create async channel for file events
59        let (event_tx, mut event_rx) = mpsc::channel::<WatchEvent>(100);
60        let (shutdown_tx, mut shutdown_rx) = mpsc::channel::<()>(1);
61
62        // Create the file watcher
63        let tx_clone = event_tx.clone();
64        let mut watcher = notify::recommended_watcher(move |res: Result<Event, notify::Error>| {
65            if let Ok(event) = res {
66                let _ = tx_clone.blocking_send(WatchEvent::FileEvent(event));
67            }
68        })
69        .map_err(WatcherError::WatcherInit)?;
70
71        // Watch directories
72        for dir in &directories {
73            if dir.exists() {
74                watcher
75                    .watch(dir, RecursiveMode::NonRecursive)
76                    .map_err(WatcherError::WatcherInit)?;
77                tracing::info!(dir = %dir.display(), "Watching directory for plugins");
78            } else {
79                tracing::warn!(dir = %dir.display(), "Directory does not exist, skipping");
80            }
81        }
82
83        // Spawn event processing task
84        let registry_clone = registry.clone();
85        let debounce = config.debounce;
86        tokio::spawn(async move {
87            let mut pending_events: Vec<PathBuf> = Vec::new();
88            let mut debounce_timer: Option<tokio::time::Instant> = None;
89
90            loop {
91                tokio::select! {
92                    // Check for shutdown
93                    _ = shutdown_rx.recv() => {
94                        tracing::info!("Hot reload watcher shutting down");
95                        break;
96                    }
97
98                    // Process file events
99                    Some(WatchEvent::FileEvent(event)) = event_rx.recv() => {
100                        for path in event.paths {
101                            if Self::is_wasm_file(&path) {
102                                if !pending_events.contains(&path) {
103                                    pending_events.push(path);
104                                }
105                                debounce_timer = Some(tokio::time::Instant::now() + debounce);
106                            }
107                        }
108                    }
109
110                    // Check debounce timer
111                    _ = async {
112                        if let Some(deadline) = debounce_timer {
113                            tokio::time::sleep_until(deadline).await;
114                        } else {
115                            std::future::pending::<()>().await;
116                        }
117                    } => {
118                        // Process pending events
119                        for path in pending_events.drain(..) {
120                            Self::handle_file_change(&registry_clone, &path).await;
121                        }
122                        debounce_timer = None;
123                    }
124                }
125            }
126        });
127
128        Ok(Self {
129            registry,
130            _watcher: watcher,
131            shutdown_tx,
132        })
133    }
134
135    /// Load all plugins from a directory
136    async fn load_directory(registry: &PluginRegistry, dir: &Path) -> Result<(), WatcherError> {
137        let entries = std::fs::read_dir(dir).map_err(|e| {
138            WatcherError::Io(format!("Failed to read directory {}: {}", dir.display(), e))
139        })?;
140
141        for entry in entries.flatten() {
142            let path = entry.path();
143            if Self::is_wasm_file(&path) {
144                match registry.load_plugin(&path).await {
145                    Ok(cmd) => {
146                        tracing::info!(command = %cmd, path = %path.display(), "Loaded plugin");
147                    }
148                    Err(e) => {
149                        tracing::warn!(path = %path.display(), error = %e, "Failed to load plugin");
150                    }
151                }
152            }
153        }
154
155        Ok(())
156    }
157
158    /// Handle a file change event
159    async fn handle_file_change(registry: &PluginRegistry, path: &Path) {
160        if path.exists() {
161            // File created or modified - load/reload
162            match registry.reload_by_path(path).await {
163                Ok(cmd) => {
164                    tracing::info!(command = %cmd, path = %path.display(), "Plugin reloaded");
165                }
166                Err(e) => {
167                    tracing::warn!(path = %path.display(), error = %e, "Failed to reload plugin");
168                }
169            }
170        } else {
171            // File deleted - unload
172            if let Some(cmd) = registry.unload_by_path(path).await {
173                tracing::info!(command = %cmd, path = %path.display(), "Plugin unloaded (file deleted)");
174            }
175        }
176    }
177
178    /// Check if a path is a wasm file
179    fn is_wasm_file(path: &Path) -> bool {
180        path.extension().map(|ext| ext == "wasm").unwrap_or(false)
181    }
182
183    /// Get a reference to the plugin registry
184    pub fn registry(&self) -> &PluginRegistry {
185        &self.registry
186    }
187
188    /// Shutdown the watcher
189    pub async fn shutdown(self) {
190        let _ = self.shutdown_tx.send(()).await;
191    }
192}
193
194enum WatchEvent {
195    FileEvent(Event),
196}
197
198/// Errors that can occur during watching
199#[derive(Debug, thiserror::Error)]
200pub enum WatcherError {
201    #[error("Failed to initialize watcher: {0}")]
202    WatcherInit(#[source] notify::Error),
203
204    #[error("IO error: {0}")]
205    Io(String),
206
207    #[error("Loader error: {0}")]
208    Loader(#[from] LoaderError),
209}
210
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::TempDir;

    const HELLO_PLUGIN_WASM: &[u8] = include_bytes!(
        "../../examples/hello-plugin/target/wasm32-unknown-unknown/release/hello_plugin.wasm"
    );

    /// Upper bound on how long a test waits for the watcher to react.
    const WAIT_TIMEOUT: Duration = Duration::from_secs(5);

    /// Poll until `registry.has_command(name)` equals `expected`.
    ///
    /// Returns `false` if `timeout` elapses first. Polling instead of sleeping
    /// for a fixed time keeps the tests fast when events arrive quickly and
    /// avoids spurious failures on slow machines.
    async fn wait_for_command(
        registry: &PluginRegistry,
        name: &str,
        expected: bool,
        timeout: Duration,
    ) -> bool {
        let deadline = tokio::time::Instant::now() + timeout;
        loop {
            if registry.has_command(name).await == expected {
                return true;
            }
            if tokio::time::Instant::now() >= deadline {
                return false;
            }
            tokio::time::sleep(Duration::from_millis(20)).await;
        }
    }

    #[tokio::test]
    async fn test_watcher_loads_existing() {
        let temp = TempDir::new().unwrap();
        let plugin_path = temp.path().join("hello.wasm");
        fs::write(&plugin_path, HELLO_PLUGIN_WASM).unwrap();

        let registry = PluginRegistry::new().unwrap();
        let watcher = HotReloadWatcher::new(
            registry.clone(),
            vec![temp.path()],
            WatcherConfig::default(),
        )
        .await
        .unwrap();

        // Existing plugins are loaded synchronously inside `new`.
        assert!(registry.has_command("hello").await);

        watcher.shutdown().await;
    }

    #[tokio::test]
    async fn test_watcher_hot_reload() {
        let temp = TempDir::new().unwrap();

        let registry = PluginRegistry::new().unwrap();
        let watcher = HotReloadWatcher::new(
            registry.clone(),
            vec![temp.path()],
            WatcherConfig {
                debounce: Duration::from_millis(100),
                load_existing: true,
            },
        )
        .await
        .unwrap();

        // Initially empty
        assert!(!registry.has_command("hello").await);

        // Add a plugin file; the watcher should load it after the debounce.
        let plugin_path = temp.path().join("hello.wasm");
        fs::write(&plugin_path, HELLO_PLUGIN_WASM).unwrap();
        assert!(
            wait_for_command(&registry, "hello", true, WAIT_TIMEOUT).await,
            "plugin was not loaded after its file was created"
        );

        // Delete the plugin file; the watcher should unload it.
        fs::remove_file(&plugin_path).unwrap();
        assert!(
            wait_for_command(&registry, "hello", false, WAIT_TIMEOUT).await,
            "plugin was not unloaded after its file was deleted"
        );

        watcher.shutdown().await;
    }
}