sqry_core/session/
watcher.rs1use std::collections::HashMap;
16use std::ffi::OsStr;
17use std::path::{Path, PathBuf};
18use std::sync::mpsc::{Receiver, TryRecvError};
19use std::sync::{Arc, Mutex, MutexGuard};
20
21use notify::{
22 Config as NotifyConfig, Event, EventKind, RecommendedWatcher, RecursiveMode, Watcher,
23};
24
25use super::error::{SessionError, SessionResult};
26use crate::config::buffers::watch_event_queue_capacity;
27
/// File name whose creation, modification, or removal inside a watched
/// directory triggers that directory's registered callback.
const INDEX_FILE_NAME: &str = ".sqry-index";

/// Change callback stored per watched directory; cloned out of the callback
/// table (cheap `Arc` clone) so it can be invoked after the table lock is released.
type Callback = Arc<dyn Fn() + Send + Sync + 'static>;
32
33struct WatcherState {
34 watcher: RecommendedWatcher,
35 rx: Receiver<notify::Result<Event>>,
36 callbacks: Arc<Mutex<HashMap<PathBuf, Callback>>>,
37}
38
39impl WatcherState {
40 fn lock_callbacks(&self) -> MutexGuard<'_, HashMap<PathBuf, Callback>> {
41 self.callbacks
42 .lock()
43 .unwrap_or_else(std::sync::PoisonError::into_inner)
44 }
45}
46
/// Watches workspace directories and runs a callback whenever the
/// `.sqry-index` file directly inside one of them changes.
///
/// A watcher created with [`FileWatcher::disabled`] owns no backend; every
/// method on it returns `Ok(())` without doing anything.
pub struct FileWatcher {
    /// `None` when watching is disabled.
    state: Option<WatcherState>,
}
51
52impl FileWatcher {
53 pub fn new() -> SessionResult<Self> {
59 let capacity = watch_event_queue_capacity();
62 let (tx, rx) = std::sync::mpsc::sync_channel(capacity);
63
64 let watcher = RecommendedWatcher::new(
65 move |event| {
66 let _ = tx.send(event);
67 },
68 NotifyConfig::default(),
69 )
70 .map_err(SessionError::WatcherInit)?;
71
72 Ok(Self {
73 state: Some(WatcherState {
74 watcher,
75 rx,
76 callbacks: Arc::new(Mutex::new(HashMap::new())),
77 }),
78 })
79 }
80
81 #[must_use]
83 pub fn disabled() -> Self {
84 Self { state: None }
85 }
86
87 pub fn watch<F>(&mut self, path: PathBuf, on_change: F) -> SessionResult<()>
95 where
96 F: Fn() + Send + Sync + 'static,
97 {
98 let Some(state) = &mut self.state else {
99 return Ok(());
101 };
102
103 if state.lock_callbacks().contains_key(&path) {
105 return Ok(());
106 }
107
108 state
109 .watcher
110 .watch(&path, RecursiveMode::NonRecursive)
111 .map_err(|source| SessionError::WatchIndex {
112 path: path.clone(),
113 source,
114 })?;
115
116 state.lock_callbacks().insert(path, Arc::new(on_change));
117
118 Ok(())
119 }
120
121 pub fn unwatch(&mut self, path: &Path) -> SessionResult<()> {
127 let Some(state) = &mut self.state else {
128 return Ok(());
129 };
130
131 if state.lock_callbacks().remove(path).is_some() {
132 state
133 .watcher
134 .unwatch(path)
135 .map_err(|source| SessionError::UnwatchIndex {
136 path: path.to_path_buf(),
137 source,
138 })?;
139 }
140
141 Ok(())
142 }
143
144 pub fn process_events(&mut self) -> SessionResult<()> {
150 let Some(state) = &mut self.state else {
151 return Ok(());
152 };
153
154 loop {
155 match state.rx.try_recv() {
156 Ok(Ok(event)) => Self::handle_event(state, &event),
157 Ok(Err(err)) => {
158 log::warn!("file watcher error: {err}");
159 }
160 Err(TryRecvError::Empty | TryRecvError::Disconnected) => break,
161 }
162 }
163
164 Ok(())
165 }
166
167 pub fn wait_and_process(&mut self, duration: std::time::Duration) -> SessionResult<()> {
179 let Some(state) = &mut self.state else {
180 return Ok(());
181 };
182
183 let deadline = std::time::Instant::now() + duration;
184
185 while std::time::Instant::now() < deadline {
186 let remaining = deadline.saturating_duration_since(std::time::Instant::now());
187 let poll_interval = std::time::Duration::from_millis(10).min(remaining);
188
189 match state.rx.recv_timeout(poll_interval) {
190 Ok(Ok(event)) => Self::handle_event(state, &event),
191 Ok(Err(err)) => {
192 log::warn!("file watcher error: {err}");
193 }
194 Err(std::sync::mpsc::RecvTimeoutError::Timeout) => {
195 }
197 Err(std::sync::mpsc::RecvTimeoutError::Disconnected) => {
198 break;
199 }
200 }
201 }
202
203 Ok(())
204 }
205
206 fn handle_event(state: &WatcherState, event: &Event) {
207 use EventKind::{Any, Create, Modify, Remove};
208
209 let relevant = matches!(event.kind, Modify(_) | Create(_) | Remove(_) | Any);
210
211 if !relevant {
212 return;
213 }
214
215 let mut callbacks_to_run: Vec<Callback> = Vec::new();
217
218 {
219 let callbacks = state.lock_callbacks();
220 for path in &event.paths {
221 if path
222 .file_name()
223 .is_some_and(|name| name == OsStr::new(INDEX_FILE_NAME))
224 && let Some(parent) = path.parent()
225 && let Some(callback) = callbacks.get(parent)
226 {
227 callbacks_to_run.push(Arc::clone(callback));
228 }
229 }
230 }
231
232 for callback in callbacks_to_run {
233 callback();
234 }
235 }
236}
237
#[cfg(test)]
mod tests {
    use super::*;
    use notify::event::{AccessKind, ModifyKind};
    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
    use std::time::Duration;
    use tempfile::tempdir;

    /// How long to wait for a real filesystem event to arrive.
    ///
    /// macOS gets a longer window than other platforms, and the window is
    /// doubled whenever the `CI` environment variable is set.
    fn event_timeout() -> Duration {
        let base = if cfg!(target_os = "macos") {
            Duration::from_secs(3)
        } else {
            Duration::from_secs(1)
        };

        if std::env::var("CI").is_ok() {
            base * 2
        } else {
            base
        }
    }

    /// Builds a synthetic "modified" event for `path`.
    fn modify_event(path: PathBuf) -> Event {
        Event::new(EventKind::Modify(ModifyKind::Any)).add_path(path)
    }

    /// Returns the single directory key registered in an enabled watcher.
    ///
    /// The key is read back from the table rather than recomputed, so these tests
    /// do not depend on how `watch` normalizes paths.
    fn registered_dir(watcher: &FileWatcher) -> PathBuf {
        let state = watcher.state.as_ref().expect("watcher should be enabled");
        let callbacks = state.lock_callbacks();
        assert_eq!(callbacks.len(), 1, "expected exactly one registration");
        callbacks
            .keys()
            .next()
            .cloned()
            .expect("one registration")
    }

    #[test]
    #[cfg_attr(target_os = "macos", ignore = "FSEvents timing flaky in CI")]
    fn detects_changes_to_index_file() {
        let temp = tempdir().unwrap();
        let workspace = temp.path();
        let index_path = workspace.join(".sqry-index");
        std::fs::write(&index_path, b"initial").unwrap();

        let mut watcher = FileWatcher::new().unwrap();

        let triggered = Arc::new(AtomicBool::new(false));
        let flag = Arc::clone(&triggered);
        watcher
            .watch(workspace.to_path_buf(), move || {
                flag.store(true, Ordering::SeqCst);
            })
            .unwrap();

        std::fs::write(&index_path, b"modified").unwrap();

        watcher.wait_and_process(event_timeout()).unwrap();

        assert!(triggered.load(Ordering::SeqCst));
    }

    #[test]
    fn dispatches_only_for_index_file_in_watched_dir() {
        let temp = tempdir().unwrap();
        let mut watcher = FileWatcher::new().unwrap();

        let hits = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&hits);
        watcher
            .watch(temp.path().to_path_buf(), move || {
                counter.fetch_add(1, Ordering::SeqCst);
            })
            .unwrap();

        let dir = registered_dir(&watcher);
        let state = watcher.state.as_ref().unwrap();

        // Unrelated file inside the watched directory.
        FileWatcher::handle_event(state, &modify_event(dir.join("other.txt")));
        // Index file inside a directory that is not watched.
        FileWatcher::handle_event(
            state,
            &modify_event(dir.join("nested").join(INDEX_FILE_NAME)),
        );
        // Irrelevant event kind on the watched index file.
        FileWatcher::handle_event(
            state,
            &Event::new(EventKind::Access(AccessKind::Any)).add_path(dir.join(INDEX_FILE_NAME)),
        );
        assert_eq!(hits.load(Ordering::SeqCst), 0);

        FileWatcher::handle_event(state, &modify_event(dir.join(INDEX_FILE_NAME)));
        assert_eq!(hits.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn watching_twice_keeps_first_registration() {
        let temp = tempdir().unwrap();
        let mut watcher = FileWatcher::new().unwrap();

        let second_called = Arc::new(AtomicBool::new(false));
        let flag = Arc::clone(&second_called);
        watcher.watch(temp.path().to_path_buf(), || {}).unwrap();
        watcher
            .watch(temp.path().to_path_buf(), move || {
                flag.store(true, Ordering::SeqCst);
            })
            .unwrap();

        // Also asserts that only one registration exists.
        let dir = registered_dir(&watcher);
        FileWatcher::handle_event(
            watcher.state.as_ref().unwrap(),
            &modify_event(dir.join(INDEX_FILE_NAME)),
        );

        assert!(!second_called.load(Ordering::SeqCst));
    }

    #[test]
    fn unwatch_removes_registration_and_is_idempotent() {
        let temp = tempdir().unwrap();
        let workspace = temp.path().to_path_buf();
        let mut watcher = FileWatcher::new().unwrap();

        watcher.watch(workspace.clone(), || {}).unwrap();
        watcher.unwatch(&workspace).unwrap();
        assert!(watcher.state.as_ref().unwrap().lock_callbacks().is_empty());

        // Nothing left to remove: must still succeed.
        watcher.unwatch(&workspace).unwrap();
    }

    #[test]
    fn disabled_watcher_is_noop() {
        let temp = tempdir().unwrap();
        let workspace = temp.path();
        std::fs::write(workspace.join(".sqry-index"), b"data").unwrap();

        let mut watcher = FileWatcher::disabled();
        watcher
            .watch(workspace.to_path_buf(), || {
                panic!("disabled watcher should not invoke callback");
            })
            .unwrap();
        watcher.process_events().unwrap();
        watcher.wait_and_process(Duration::from_millis(1)).unwrap();
        watcher.unwatch(workspace).unwrap();
    }
}