sqry_core/session/
watcher.rs1use std::collections::HashMap;
21use std::ffi::OsStr;
22use std::path::{Path, PathBuf};
23use std::sync::mpsc::{Receiver, TryRecvError};
24use std::sync::{Arc, Mutex, MutexGuard};
25
26use notify::{
27 Config as NotifyConfig, Event, EventKind, RecommendedWatcher, RecursiveMode, Watcher,
28};
29
30use super::error::{SessionError, SessionResult};
31use crate::config::buffers::watch_event_queue_capacity;
32
/// Name of the per-workspace sqry directory (`<workspace>/.sqry`).
const SQRY_DIR_SEGMENT: &str = ".sqry";

/// Graph subdirectory nested directly under [`SQRY_DIR_SEGMENT`].
const GRAPH_DIR_SEGMENT: &str = "graph";

/// File name of the graph manifest; only changes to this file trigger callbacks.
const MANIFEST_FILE_NAME: &str = "manifest.json";
46
/// Builds `<workspace>/.sqry/graph/manifest.json`, the path tests register with the watcher.
#[cfg(test)]
fn manifest_path(workspace: &Path) -> PathBuf {
    let mut manifest = workspace.to_path_buf();
    manifest.extend([SQRY_DIR_SEGMENT, GRAPH_DIR_SEGMENT, MANIFEST_FILE_NAME]);
    manifest
}
60
61type Callback = Arc<dyn Fn() + Send + Sync + 'static>;
62
63struct WatcherState {
64 watcher: RecommendedWatcher,
65 rx: Receiver<notify::Result<Event>>,
66 callbacks: Arc<Mutex<HashMap<PathBuf, Callback>>>,
67}
68
69impl WatcherState {
70 fn lock_callbacks(&self) -> MutexGuard<'_, HashMap<PathBuf, Callback>> {
71 self.callbacks
72 .lock()
73 .unwrap_or_else(std::sync::PoisonError::into_inner)
74 }
75}
76
77pub struct FileWatcher {
80 state: Option<WatcherState>,
81}
82
83impl FileWatcher {
84 pub fn new() -> SessionResult<Self> {
90 let capacity = watch_event_queue_capacity();
93 let (tx, rx) = std::sync::mpsc::sync_channel(capacity);
94
95 let watcher = RecommendedWatcher::new(
96 move |event| {
97 let _ = tx.send(event);
98 },
99 NotifyConfig::default(),
100 )
101 .map_err(SessionError::WatcherInit)?;
102
103 Ok(Self {
104 state: Some(WatcherState {
105 watcher,
106 rx,
107 callbacks: Arc::new(Mutex::new(HashMap::new())),
108 }),
109 })
110 }
111
112 #[must_use]
114 pub fn disabled() -> Self {
115 Self { state: None }
116 }
117
118 pub fn watch<F>(&mut self, path: PathBuf, on_change: F) -> SessionResult<()>
130 where
131 F: Fn() + Send + Sync + 'static,
132 {
133 let Some(state) = &mut self.state else {
134 return Ok(());
136 };
137
138 if state.lock_callbacks().contains_key(&path) {
140 return Ok(());
141 }
142
143 state
144 .watcher
145 .watch(&path, RecursiveMode::NonRecursive)
146 .map_err(|source| SessionError::WatchIndex {
147 path: path.clone(),
148 source,
149 })?;
150
151 state.lock_callbacks().insert(path, Arc::new(on_change));
152
153 Ok(())
154 }
155
156 pub fn unwatch(&mut self, path: &Path) -> SessionResult<()> {
162 let Some(state) = &mut self.state else {
163 return Ok(());
164 };
165
166 if state.lock_callbacks().remove(path).is_some() {
167 state
168 .watcher
169 .unwatch(path)
170 .map_err(|source| SessionError::UnwatchIndex {
171 path: path.to_path_buf(),
172 source,
173 })?;
174 }
175
176 Ok(())
177 }
178
179 #[cfg(test)]
185 #[must_use]
186 pub(crate) fn watched_paths(&self) -> Vec<PathBuf> {
187 self.state
188 .as_ref()
189 .map(|state| state.lock_callbacks().keys().cloned().collect())
190 .unwrap_or_default()
191 }
192
193 pub fn process_events(&mut self) -> SessionResult<()> {
199 let Some(state) = &mut self.state else {
200 return Ok(());
201 };
202
203 loop {
204 match state.rx.try_recv() {
205 Ok(Ok(event)) => Self::handle_event(state, &event),
206 Ok(Err(err)) => {
207 log::warn!("file watcher error: {err}");
208 }
209 Err(TryRecvError::Empty | TryRecvError::Disconnected) => break,
210 }
211 }
212
213 Ok(())
214 }
215
216 pub fn wait_and_process(&mut self, duration: std::time::Duration) -> SessionResult<()> {
228 let Some(state) = &mut self.state else {
229 return Ok(());
230 };
231
232 let deadline = std::time::Instant::now() + duration;
233
234 while std::time::Instant::now() < deadline {
235 let remaining = deadline.saturating_duration_since(std::time::Instant::now());
236 let poll_interval = std::time::Duration::from_millis(10).min(remaining);
237
238 match state.rx.recv_timeout(poll_interval) {
239 Ok(Ok(event)) => Self::handle_event(state, &event),
240 Ok(Err(err)) => {
241 log::warn!("file watcher error: {err}");
242 }
243 Err(std::sync::mpsc::RecvTimeoutError::Timeout) => {
244 }
246 Err(std::sync::mpsc::RecvTimeoutError::Disconnected) => {
247 break;
248 }
249 }
250 }
251
252 Ok(())
253 }
254
255 fn handle_event(state: &WatcherState, event: &Event) {
256 use EventKind::{Any, Create, Modify, Remove};
257
258 let relevant = matches!(event.kind, Modify(_) | Create(_) | Remove(_) | Any);
259
260 if !relevant {
261 return;
262 }
263
264 let mut callbacks_to_run: Vec<Callback> = Vec::new();
266
267 {
268 let callbacks = state.lock_callbacks();
269 for path in &event.paths {
270 if path
275 .file_name()
276 .is_some_and(|name| name == OsStr::new(MANIFEST_FILE_NAME))
277 && let Some(graph_dir) = path.parent()
278 && graph_dir
279 .file_name()
280 .is_some_and(|name| name == OsStr::new(GRAPH_DIR_SEGMENT))
281 && let Some(sqry_dir) = graph_dir.parent()
282 && sqry_dir
283 .file_name()
284 .is_some_and(|name| name == OsStr::new(SQRY_DIR_SEGMENT))
285 && let Some(callback) = callbacks.get(path)
286 {
287 callbacks_to_run.push(Arc::clone(callback));
288 }
289 }
290 }
291
292 for callback in callbacks_to_run {
293 callback();
294 }
295 }
296}
297
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::time::Duration;
    use tempfile::tempdir;

    /// How long to wait for a filesystem notification before giving up.
    ///
    /// macOS gets a 3 s base budget and other platforms 1 s. The budget is
    /// doubled when the `CI` environment variable is set.
    fn event_timeout() -> Duration {
        let base_secs = if cfg!(target_os = "macos") { 3 } else { 1 };
        let ci_multiplier = if std::env::var("CI").is_ok() { 2 } else { 1 };
        Duration::from_secs(base_secs * ci_multiplier)
    }

    /// Creates `<workspace>/.sqry/graph/manifest.json` holding `contents` and returns its path.
    fn seed_manifest(workspace: &Path, contents: &[u8]) -> PathBuf {
        let manifest = manifest_path(workspace);
        std::fs::create_dir_all(manifest.parent().unwrap()).unwrap();
        std::fs::write(&manifest, contents).unwrap();
        manifest
    }

    #[test]
    #[cfg_attr(target_os = "macos", ignore = "FSEvents timing flaky in CI")]
    fn detects_changes_to_index_file() {
        let temp = tempdir().unwrap();
        let manifest = seed_manifest(temp.path(), b"initial");

        let mut watcher = FileWatcher::new().unwrap();
        let fired = Arc::new(AtomicBool::new(false));
        watcher
            .watch(manifest.clone(), {
                let fired = Arc::clone(&fired);
                move || fired.store(true, Ordering::SeqCst)
            })
            .unwrap();

        std::fs::write(&manifest, b"modified").unwrap();
        watcher.wait_and_process(event_timeout()).unwrap();

        assert!(fired.load(Ordering::SeqCst));
    }

    #[test]
    fn disabled_watcher_is_noop() {
        let temp = tempdir().unwrap();
        let manifest = seed_manifest(temp.path(), b"data");

        let mut watcher = FileWatcher::disabled();
        watcher
            .watch(manifest, || {
                panic!("disabled watcher should not invoke callback");
            })
            .unwrap();
        watcher.process_events().unwrap();
    }
}