sqry_core/session/
watcher.rs1use std::collections::HashMap;
21use std::ffi::OsStr;
22use std::path::{Path, PathBuf};
23use std::sync::mpsc::{Receiver, TryRecvError};
24use std::sync::{Arc, Mutex, MutexGuard};
25
26use notify::{
27 Config as NotifyConfig, Event, EventKind, RecommendedWatcher, RecursiveMode, Watcher,
28};
29
30use super::error::{SessionError, SessionResult};
31use crate::config::buffers::watch_event_queue_capacity;
32
/// Name of the directory that contains the graph directory.
const SQRY_DIR_SEGMENT: &str = ".sqry";

/// Name of the directory that directly contains the manifest file.
const GRAPH_DIR_SEGMENT: &str = "graph";

/// File name of the manifest; only events for files with this name (located
/// at `.sqry/graph/`) are dispatched to callbacks.
const MANIFEST_FILE_NAME: &str = "manifest.json";
46
/// Builds `<workspace>/.sqry/graph/manifest.json` (test-only helper).
#[cfg(test)]
fn manifest_path(workspace: &Path) -> PathBuf {
    let mut manifest = workspace.to_path_buf();
    for segment in [SQRY_DIR_SEGMENT, GRAPH_DIR_SEGMENT, MANIFEST_FILE_NAME] {
        manifest.push(segment);
    }
    manifest
}
60
61type Callback = Arc<dyn Fn() + Send + Sync + 'static>;
62
63struct WatcherState {
64 watcher: RecommendedWatcher,
65 rx: Receiver<notify::Result<Event>>,
66 callbacks: Arc<Mutex<HashMap<PathBuf, Callback>>>,
67}
68
69impl WatcherState {
70 fn lock_callbacks(&self) -> MutexGuard<'_, HashMap<PathBuf, Callback>> {
71 self.callbacks
72 .lock()
73 .unwrap_or_else(std::sync::PoisonError::into_inner)
74 }
75}
76
77pub struct FileWatcher {
80 state: Option<WatcherState>,
81}
82
83impl FileWatcher {
84 pub fn new() -> SessionResult<Self> {
90 let capacity = watch_event_queue_capacity();
93 let (tx, rx) = std::sync::mpsc::sync_channel(capacity);
94
95 let watcher = RecommendedWatcher::new(
96 move |event| {
97 let _ = tx.send(event);
98 },
99 NotifyConfig::default(),
100 )
101 .map_err(SessionError::WatcherInit)?;
102
103 Ok(Self {
104 state: Some(WatcherState {
105 watcher,
106 rx,
107 callbacks: Arc::new(Mutex::new(HashMap::new())),
108 }),
109 })
110 }
111
112 #[must_use]
114 pub fn disabled() -> Self {
115 Self { state: None }
116 }
117
118 pub fn watch<F>(&mut self, path: PathBuf, on_change: F) -> SessionResult<()>
130 where
131 F: Fn() + Send + Sync + 'static,
132 {
133 let Some(state) = &mut self.state else {
134 return Ok(());
136 };
137
138 if state.lock_callbacks().contains_key(&path) {
140 return Ok(());
141 }
142
143 state
144 .watcher
145 .watch(&path, RecursiveMode::NonRecursive)
146 .map_err(|source| SessionError::WatchIndex {
147 path: path.clone(),
148 source,
149 })?;
150
151 state.lock_callbacks().insert(path, Arc::new(on_change));
152
153 Ok(())
154 }
155
156 pub fn unwatch(&mut self, path: &Path) -> SessionResult<()> {
162 let Some(state) = &mut self.state else {
163 return Ok(());
164 };
165
166 if state.lock_callbacks().remove(path).is_some() {
167 state
168 .watcher
169 .unwatch(path)
170 .map_err(|source| SessionError::UnwatchIndex {
171 path: path.to_path_buf(),
172 source,
173 })?;
174 }
175
176 Ok(())
177 }
178
179 #[cfg(test)]
185 #[must_use]
186 pub(crate) fn watched_paths(&self) -> Vec<PathBuf> {
187 self.state
188 .as_ref()
189 .map(|state| state.lock_callbacks().keys().cloned().collect())
190 .unwrap_or_default()
191 }
192
193 #[cfg(test)]
197 pub(crate) fn trigger_for_test(&self, path: &Path) -> bool {
198 let callback = self
199 .state
200 .as_ref()
201 .and_then(|state| state.lock_callbacks().get(path).cloned());
202
203 if let Some(callback) = callback {
204 callback();
205 true
206 } else {
207 false
208 }
209 }
210
211 pub fn process_events(&mut self) -> SessionResult<()> {
217 let Some(state) = &mut self.state else {
218 return Ok(());
219 };
220
221 loop {
222 match state.rx.try_recv() {
223 Ok(Ok(event)) => Self::handle_event(state, &event),
224 Ok(Err(err)) => {
225 log::warn!("file watcher error: {err}");
226 }
227 Err(TryRecvError::Empty | TryRecvError::Disconnected) => break,
228 }
229 }
230
231 Ok(())
232 }
233
234 pub fn wait_and_process(&mut self, duration: std::time::Duration) -> SessionResult<()> {
246 let Some(state) = &mut self.state else {
247 return Ok(());
248 };
249
250 let deadline = std::time::Instant::now() + duration;
251
252 while std::time::Instant::now() < deadline {
253 let remaining = deadline.saturating_duration_since(std::time::Instant::now());
254 let poll_interval = std::time::Duration::from_millis(10).min(remaining);
255
256 match state.rx.recv_timeout(poll_interval) {
257 Ok(Ok(event)) => Self::handle_event(state, &event),
258 Ok(Err(err)) => {
259 log::warn!("file watcher error: {err}");
260 }
261 Err(std::sync::mpsc::RecvTimeoutError::Timeout) => {
262 }
264 Err(std::sync::mpsc::RecvTimeoutError::Disconnected) => {
265 break;
266 }
267 }
268 }
269
270 Ok(())
271 }
272
273 fn handle_event(state: &WatcherState, event: &Event) {
274 use EventKind::{Any, Create, Modify, Remove};
275
276 let relevant = matches!(event.kind, Modify(_) | Create(_) | Remove(_) | Any);
277
278 if !relevant {
279 return;
280 }
281
282 let mut callbacks_to_run: Vec<Callback> = Vec::new();
284
285 {
286 let callbacks = state.lock_callbacks();
287 for path in &event.paths {
288 if path
293 .file_name()
294 .is_some_and(|name| name == OsStr::new(MANIFEST_FILE_NAME))
295 && let Some(graph_dir) = path.parent()
296 && graph_dir
297 .file_name()
298 .is_some_and(|name| name == OsStr::new(GRAPH_DIR_SEGMENT))
299 && let Some(sqry_dir) = graph_dir.parent()
300 && sqry_dir
301 .file_name()
302 .is_some_and(|name| name == OsStr::new(SQRY_DIR_SEGMENT))
303 && let Some(callback) = callbacks.get(path)
304 {
305 callbacks_to_run.push(Arc::clone(callback));
306 }
307 }
308 }
309
310 for callback in callbacks_to_run {
311 callback();
312 }
313 }
314}
315
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::time::Duration;
    use tempfile::tempdir;

    /// How long a test waits for a file-system event: 3 s on macOS, 1 s
    /// elsewhere, doubled when the `CI` environment variable is set.
    fn event_timeout() -> Duration {
        let base_secs: u64 = if cfg!(target_os = "macos") { 3 } else { 1 };
        let multiplier: u32 = if std::env::var("CI").is_ok() { 2 } else { 1 };
        Duration::from_secs(base_secs) * multiplier
    }

    /// Creates the manifest file (and its parent directories) under
    /// `workspace` with the given contents and returns its path.
    fn seed_manifest(workspace: &Path, contents: &[u8]) -> PathBuf {
        let manifest = manifest_path(workspace);
        std::fs::create_dir_all(manifest.parent().unwrap()).unwrap();
        std::fs::write(&manifest, contents).unwrap();
        manifest
    }

    #[test]
    #[cfg_attr(target_os = "macos", ignore = "FSEvents timing flaky in CI")]
    fn detects_changes_to_index_file() {
        let temp = tempdir().unwrap();
        let manifest = seed_manifest(temp.path(), b"initial");

        let mut watcher = FileWatcher::new().unwrap();

        let triggered = Arc::new(AtomicBool::new(false));
        let on_change = {
            let flag = Arc::clone(&triggered);
            move || flag.store(true, Ordering::SeqCst)
        };
        watcher.watch(manifest.clone(), on_change).unwrap();

        // Modify the watched file, then give the watcher time to deliver it.
        std::fs::write(&manifest, b"modified").unwrap();
        watcher.wait_and_process(event_timeout()).unwrap();

        assert!(triggered.load(Ordering::SeqCst));
    }

    #[test]
    fn disabled_watcher_is_noop() {
        let temp = tempdir().unwrap();
        let manifest = seed_manifest(temp.path(), b"data");

        let mut watcher = FileWatcher::disabled();
        watcher
            .watch(manifest, || {
                panic!("disabled watcher should not invoke callback");
            })
            .unwrap();

        // Must succeed without ever touching the callback.
        watcher.process_events().unwrap();
    }
}