1use std::collections::{HashMap, HashSet};
2use std::path::Path;
3use std::sync::{mpsc, Arc, Mutex};
4use std::thread::{self, JoinHandle};
5use std::time::{Duration, Instant};
6
7use globset::{Glob, GlobSet, GlobSetBuilder};
8use zccache_core::NormalizedPath;
9
/// Cheap fingerprint of a regular file, compared between scans to detect edits.
///
/// A file is considered changed when either field differs from the value
/// recorded in the previous scan.
#[derive(Clone, Debug, PartialEq, Eq)]
struct FileState {
    /// Modification time in nanoseconds since the Unix epoch (0 when unavailable).
    mtime_ns: u128,
    /// File length in bytes.
    size: u64,
}
15
16#[derive(Clone)]
17struct ScanConfig {
18 root: NormalizedPath,
19 include_folders: Vec<NormalizedPath>,
20 include_globs: GlobSet,
21 exclude_globs: GlobSet,
22 excluded_names: HashSet<String>,
23}
24
25#[derive(Clone, Debug, PartialEq, Eq)]
26pub struct PollWatchBatch {
27 pub changed: Vec<NormalizedPath>,
28 pub removed: Vec<NormalizedPath>,
29 pub overflow: bool,
30}
31
32impl PollWatchBatch {
33 #[must_use]
34 pub fn is_empty(&self) -> bool {
35 self.changed.is_empty() && self.removed.is_empty() && !self.overflow
36 }
37}
38
39pub trait PollWatchObserver: Send + Sync {
40 fn on_batch(&self, batch: &PollWatchBatch);
41}
42
/// Adapter that lets a plain closure act as a [`PollWatchObserver`].
///
/// Instances are created by `PollingWatcher::add_callback`.
struct FnObserver<F> {
    /// Closure invoked with every batch.
    callback: F,
}
46
47impl<F> PollWatchObserver for FnObserver<F>
48where
49 F: Fn(&PollWatchBatch) + Send + Sync + 'static,
50{
51 fn on_batch(&self, batch: &PollWatchBatch) {
52 (self.callback)(batch);
53 }
54}
55
56#[derive(Clone, Debug)]
57pub struct PollingWatcherConfig {
58 pub root: NormalizedPath,
59 pub include_folders: Vec<NormalizedPath>,
60 pub include_globs: Vec<String>,
61 pub excluded_patterns: Vec<String>,
62 pub poll_interval: Duration,
63 pub debounce: Duration,
64}
65
66impl PollingWatcherConfig {
67 #[must_use]
68 pub fn new(root: impl Into<NormalizedPath>) -> Self {
69 Self {
70 root: root.into(),
71 include_folders: Vec::new(),
72 include_globs: Vec::new(),
73 excluded_patterns: Vec::new(),
74 poll_interval: Duration::from_millis(100),
75 debounce: Duration::from_millis(200),
76 }
77 }
78}
79
80pub struct PollingWatcher {
81 config: ScanConfig,
82 poll_interval: Duration,
83 debounce: Duration,
84 observers: Arc<Mutex<Vec<Arc<dyn PollWatchObserver>>>>,
85 poll_rx: Mutex<Option<mpsc::Receiver<PollWatchBatch>>>,
86 worker_shutdown_tx: Mutex<Option<mpsc::Sender<()>>>,
87 worker_handle: Mutex<Option<JoinHandle<()>>>,
88 dispatch_shutdown_tx: Mutex<Option<mpsc::Sender<()>>>,
89 dispatch_handle: Mutex<Option<JoinHandle<()>>>,
90}
91
92impl PollingWatcher {
93 pub fn new(config: PollingWatcherConfig) -> std::io::Result<Self> {
94 let root = config.root;
95 if !root.is_dir() {
96 return Err(std::io::Error::new(
97 std::io::ErrorKind::NotFound,
98 format!(
99 "watch root does not exist or is not a directory: {}",
100 root.display()
101 ),
102 ));
103 }
104
105 let scan_config = build_config(
106 &root,
107 &config.include_folders,
108 &config.include_globs,
109 &config.excluded_patterns,
110 )?;
111
112 Ok(Self {
113 config: scan_config,
114 poll_interval: config.poll_interval.max(Duration::from_millis(1)),
115 debounce: config.debounce,
116 observers: Arc::new(Mutex::new(Vec::new())),
117 poll_rx: Mutex::new(None),
118 worker_shutdown_tx: Mutex::new(None),
119 worker_handle: Mutex::new(None),
120 dispatch_shutdown_tx: Mutex::new(None),
121 dispatch_handle: Mutex::new(None),
122 })
123 }
124
125 pub fn start(&self) -> std::io::Result<()> {
126 if self.is_running() {
127 return Ok(());
128 }
129
130 let (worker_batch_tx, worker_batch_rx) = mpsc::channel();
131 let (poll_tx, poll_rx) = mpsc::channel();
132 let (worker_shutdown_tx, worker_shutdown_rx) = mpsc::channel();
133 let (dispatch_shutdown_tx, dispatch_shutdown_rx) = mpsc::channel();
134 let (ready_tx, ready_rx) = mpsc::channel();
135 let config = self.config.clone();
136 let poll_interval = self.poll_interval;
137 let debounce = self.debounce;
138 let observers = Arc::clone(&self.observers);
139
140 let worker_handle = thread::Builder::new()
141 .name("zccache-polling-watcher".to_string())
142 .spawn(move || {
143 run_poll_loop(
144 config,
145 poll_interval,
146 debounce,
147 worker_batch_tx,
148 worker_shutdown_rx,
149 ready_tx,
150 )
151 })?;
152
153 match ready_rx.recv() {
154 Ok(()) => {}
155 Err(_) => {
156 let _ = worker_handle.join();
157 return Err(std::io::Error::other(
158 "watcher worker exited before initialization completed",
159 ));
160 }
161 }
162
163 let dispatch_handle = thread::Builder::new()
164 .name("zccache-polling-watcher-dispatch".to_string())
165 .spawn(move || {
166 run_dispatch_loop(worker_batch_rx, poll_tx, dispatch_shutdown_rx, observers)
167 })?;
168
169 *self
170 .poll_rx
171 .lock()
172 .map_err(|_| std::io::Error::other("watcher receiver lock poisoned"))? = Some(poll_rx);
173 *self
174 .worker_shutdown_tx
175 .lock()
176 .map_err(|_| std::io::Error::other("watcher shutdown lock poisoned"))? =
177 Some(worker_shutdown_tx);
178 *self
179 .worker_handle
180 .lock()
181 .map_err(|_| std::io::Error::other("watcher worker lock poisoned"))? =
182 Some(worker_handle);
183 *self
184 .dispatch_shutdown_tx
185 .lock()
186 .map_err(|_| std::io::Error::other("watcher dispatch shutdown lock poisoned"))? =
187 Some(dispatch_shutdown_tx);
188 *self
189 .dispatch_handle
190 .lock()
191 .map_err(|_| std::io::Error::other("watcher dispatch lock poisoned"))? =
192 Some(dispatch_handle);
193
194 Ok(())
195 }
196
197 pub fn resume(&self) -> std::io::Result<()> {
198 self.start()
199 }
200
201 pub fn stop(&self) -> std::io::Result<()> {
202 let worker_shutdown = self
203 .worker_shutdown_tx
204 .lock()
205 .map_err(|_| std::io::Error::other("watcher shutdown lock poisoned"))?
206 .take();
207 if let Some(tx) = worker_shutdown {
208 let _ = tx.send(());
209 }
210
211 let dispatch_shutdown = self
212 .dispatch_shutdown_tx
213 .lock()
214 .map_err(|_| std::io::Error::other("watcher dispatch shutdown lock poisoned"))?
215 .take();
216 if let Some(tx) = dispatch_shutdown {
217 let _ = tx.send(());
218 }
219
220 let worker = self
221 .worker_handle
222 .lock()
223 .map_err(|_| std::io::Error::other("watcher worker lock poisoned"))?
224 .take();
225 if let Some(handle) = worker {
226 handle
227 .join()
228 .map_err(|_| std::io::Error::other("watcher worker thread panicked"))?;
229 }
230
231 let dispatch = self
232 .dispatch_handle
233 .lock()
234 .map_err(|_| std::io::Error::other("watcher dispatch lock poisoned"))?
235 .take();
236 if let Some(handle) = dispatch {
237 handle
238 .join()
239 .map_err(|_| std::io::Error::other("watcher dispatch thread panicked"))?;
240 }
241
242 *self
243 .poll_rx
244 .lock()
245 .map_err(|_| std::io::Error::other("watcher receiver lock poisoned"))? = None;
246
247 Ok(())
248 }
249
250 #[must_use]
251 pub fn is_running(&self) -> bool {
252 self.worker_handle
253 .lock()
254 .ok()
255 .and_then(|guard| guard.as_ref().map(JoinHandle::is_finished))
256 .is_some_and(|finished| !finished)
257 }
258
259 pub fn poll(&self) -> std::io::Result<Option<PollWatchBatch>> {
260 self.poll_timeout(Duration::ZERO)
261 }
262
263 pub fn poll_timeout(&self, timeout: Duration) -> std::io::Result<Option<PollWatchBatch>> {
264 let receiver_guard = self
265 .poll_rx
266 .lock()
267 .map_err(|_| std::io::Error::other("watcher receiver lock poisoned"))?;
268 let Some(receiver) = receiver_guard.as_ref() else {
269 return Ok(None);
270 };
271
272 if timeout.is_zero() {
273 match receiver.try_recv() {
274 Ok(batch) => Ok(Some(batch)),
275 Err(mpsc::TryRecvError::Empty | mpsc::TryRecvError::Disconnected) => Ok(None),
276 }
277 } else {
278 match receiver.recv_timeout(timeout) {
279 Ok(batch) => Ok(Some(batch)),
280 Err(mpsc::RecvTimeoutError::Timeout | mpsc::RecvTimeoutError::Disconnected) => {
281 Ok(None)
282 }
283 }
284 }
285 }
286
287 pub fn add_observer(&self, observer: Arc<dyn PollWatchObserver>) -> std::io::Result<()> {
288 self.observers
289 .lock()
290 .map_err(|_| std::io::Error::other("watcher observers lock poisoned"))?
291 .push(observer);
292 Ok(())
293 }
294
295 pub fn add_callback<F>(&self, callback: F) -> std::io::Result<()>
296 where
297 F: Fn(&PollWatchBatch) + Send + Sync + 'static,
298 {
299 self.add_observer(Arc::new(FnObserver { callback }))
300 }
301}
302
303impl Drop for PollingWatcher {
304 fn drop(&mut self) {
305 let _ = self.stop();
306 }
307}
308
309fn run_dispatch_loop(
310 worker_batch_rx: mpsc::Receiver<PollWatchBatch>,
311 poll_tx: mpsc::Sender<PollWatchBatch>,
312 dispatch_shutdown_rx: mpsc::Receiver<()>,
313 observers: Arc<Mutex<Vec<Arc<dyn PollWatchObserver>>>>,
314) {
315 loop {
316 if dispatch_shutdown_rx.try_recv().is_ok() {
317 break;
318 }
319
320 let batch = match worker_batch_rx.recv_timeout(Duration::from_millis(25)) {
321 Ok(batch) => batch,
322 Err(mpsc::RecvTimeoutError::Timeout) => continue,
323 Err(mpsc::RecvTimeoutError::Disconnected) => break,
324 };
325
326 if poll_tx.send(batch.clone()).is_err() {
327 break;
328 }
329
330 let snapshot = match observers.lock() {
331 Ok(guard) => guard.clone(),
332 Err(_) => break,
333 };
334 for observer in snapshot {
335 observer.on_batch(&batch);
336 }
337 }
338}
339
340fn run_poll_loop(
341 config: ScanConfig,
342 poll_interval: Duration,
343 debounce: Duration,
344 batch_tx: mpsc::Sender<PollWatchBatch>,
345 shutdown_rx: mpsc::Receiver<()>,
346 ready_tx: mpsc::Sender<()>,
347) {
348 let mut snapshot = scan_snapshot(&config);
349 let _ = ready_tx.send(());
350 let mut pending_changed: HashSet<NormalizedPath> = HashSet::new();
351 let mut pending_removed: HashSet<NormalizedPath> = HashSet::new();
352 let mut last_change: Option<Instant> = None;
353
354 loop {
355 if shutdown_rx.try_recv().is_ok() {
356 break;
357 }
358
359 let current = scan_snapshot(&config);
360 let (changed, removed) = diff_snapshots(&snapshot, ¤t);
361
362 if !changed.is_empty() || !removed.is_empty() {
363 for path in changed {
364 pending_removed.remove(&path);
365 pending_changed.insert(path);
366 }
367 for path in removed {
368 pending_changed.remove(&path);
369 pending_removed.insert(path);
370 }
371 last_change = Some(Instant::now());
372 } else if let Some(last) = last_change {
373 if last.elapsed() >= debounce
374 && (!pending_changed.is_empty() || !pending_removed.is_empty())
375 {
376 let mut changed: Vec<NormalizedPath> = pending_changed.drain().collect();
377 let mut removed: Vec<NormalizedPath> = pending_removed.drain().collect();
378 changed.sort();
379 removed.sort();
380 if batch_tx
381 .send(PollWatchBatch {
382 changed,
383 removed,
384 overflow: false,
385 })
386 .is_err()
387 {
388 break;
389 }
390 last_change = None;
391 }
392 }
393
394 snapshot = current;
395
396 if shutdown_rx.recv_timeout(poll_interval).is_ok() {
397 break;
398 }
399 }
400}
401
402fn build_config(
403 root: &Path,
404 include_folders: &[NormalizedPath],
405 include_globs: &[String],
406 excluded_patterns: &[String],
407) -> std::io::Result<ScanConfig> {
408 let root = NormalizedPath::new(root.canonicalize()?);
409
410 let include_folders = if include_folders.is_empty() {
411 vec![root.clone()]
412 } else {
413 include_folders
414 .iter()
415 .map(|folder| {
416 let absolute = if folder.is_absolute() {
417 folder.clone().into_path_buf()
418 } else {
419 root.join(folder).into_path_buf()
420 };
421 Ok(NormalizedPath::new(
422 absolute.canonicalize().unwrap_or(absolute),
423 ))
424 })
425 .collect::<std::io::Result<Vec<_>>>()?
426 };
427
428 let include_patterns = if include_globs.is_empty() {
429 vec!["**".to_string()]
430 } else {
431 include_globs.to_vec()
432 };
433 let include_globs = build_globset(&expand_patterns(&include_patterns))?;
434
435 let excluded_names = excluded_patterns
436 .iter()
437 .filter(|pattern| !has_glob_meta(pattern) && !pattern.contains('/'))
438 .cloned()
439 .collect::<HashSet<_>>();
440 let exclude_globs = build_globset(&expand_patterns(excluded_patterns))?;
441
442 Ok(ScanConfig {
443 root,
444 include_folders,
445 include_globs,
446 exclude_globs,
447 excluded_names,
448 })
449}
450
451fn build_globset(patterns: &[String]) -> std::io::Result<GlobSet> {
452 let mut builder = GlobSetBuilder::new();
453 for pattern in patterns {
454 builder.add(
455 Glob::new(pattern).map_err(|e| std::io::Error::other(format!("invalid glob: {e}")))?,
456 );
457 }
458 builder
459 .build()
460 .map_err(|e| std::io::Error::other(format!("failed to compile glob set: {e}")))
461}
462
/// Expands each pattern into itself plus looser variants with `**/` and/or
/// `/**` removed. Backslashes are first normalized to `/`.
///
/// For example, `build/**` also yields `build`, so the directory itself can
/// match and not only its contents. Variants are deduplicated per input
/// pattern, but not across patterns, and are emitted in depth-first order
/// starting with the normalized original.
fn expand_patterns(patterns: &[String]) -> Vec<String> {
    // Markers whose removal produces a looser variant of a pattern.
    const COLLAPSIBLE: [&str; 2] = ["**/", "/**"];

    patterns
        .iter()
        .flat_map(|pattern| {
            let mut variants = Vec::new();
            let mut visited: HashSet<String> = HashSet::new();
            let mut stack = vec![pattern.replace('\\', "/")];
            while let Some(candidate) = stack.pop() {
                if visited.contains(&candidate) {
                    continue;
                }
                for marker in COLLAPSIBLE {
                    if candidate.contains(marker) {
                        stack.push(candidate.replace(marker, ""));
                    }
                }
                visited.insert(candidate.clone());
                variants.push(candidate);
            }
            variants
        })
        .collect()
}
483
/// Reports whether `pattern` contains a glob metacharacter and so must be
/// treated as a glob rather than a literal name.
///
/// Recognizes `*`, `?`, `[` (character classes) and `{` (alternation such as
/// `{build,dist}`). Globset supports alternation, so a pattern containing `{`
/// is not a literal directory name.
fn has_glob_meta(pattern: &str) -> bool {
    pattern
        .chars()
        .any(|c| matches!(c, '*' | '?' | '[' | '{'))
}
487
488fn scan_snapshot(config: &ScanConfig) -> HashMap<NormalizedPath, FileState> {
489 let mut result = HashMap::new();
490
491 for base in &config.include_folders {
492 if !base.exists() {
493 continue;
494 }
495
496 let root = config.root.clone();
497 let exclude_names = config.excluded_names.clone();
498 let exclude_globs = config.exclude_globs.clone();
499
500 let walker = jwalk::WalkDir::new(base)
501 .follow_links(false)
502 .skip_hidden(false)
503 .process_read_dir(move |_depth, _path, _state, children| {
504 children.retain(|entry| {
505 let Ok(entry) = entry else {
506 return true;
507 };
508 if !entry.file_type.is_dir() {
509 return true;
510 }
511 let path = entry.path();
512 if let Some(name) = path.file_name().and_then(|name| name.to_str()) {
513 if exclude_names.contains(name) {
514 return false;
515 }
516 }
517 let rel = rel_string(&root, &path);
518 !exclude_globs.is_match(&rel)
519 });
520 });
521
522 for entry in walker.into_iter().flatten() {
523 if !entry.file_type.is_file() {
524 continue;
525 }
526 let path = entry.path();
527 let rel = rel_string(&config.root, &path);
528 if config.exclude_globs.is_match(&rel) || !config.include_globs.is_match(&rel) {
529 continue;
530 }
531 if let Ok(metadata) = path.metadata() {
532 result.insert(
533 NormalizedPath::new(path),
534 FileState {
535 mtime_ns: metadata
536 .modified()
537 .ok()
538 .and_then(|time| time.duration_since(std::time::UNIX_EPOCH).ok())
539 .map_or(0, |duration| duration.as_nanos()),
540 size: metadata.len(),
541 },
542 );
543 }
544 }
545 }
546
547 result
548}
549
550fn diff_snapshots(
551 previous: &HashMap<NormalizedPath, FileState>,
552 current: &HashMap<NormalizedPath, FileState>,
553) -> (HashSet<NormalizedPath>, HashSet<NormalizedPath>) {
554 let mut changed = HashSet::new();
555 let mut removed = HashSet::new();
556
557 for (path, state) in current {
558 if previous.get(path) != Some(state) {
559 changed.insert(path.clone());
560 }
561 }
562
563 for path in previous.keys() {
564 if !current.contains_key(path) {
565 removed.insert(path.clone());
566 }
567 }
568
569 (changed, removed)
570}
571
/// Renders `path` relative to `root` as a `/`-separated string for glob
/// matching.
///
/// Paths outside `root` are rendered in full. The root (and any prefix)
/// component is emitted without an extra separator, so `/other/x` stays
/// `/other/x` rather than becoming `//other/x`, and `C:\other\x` becomes
/// `C:/other/x`. Non-UTF-8 components are converted lossily.
fn rel_string(root: &Path, path: &Path) -> String {
    let relative = path.strip_prefix(root).unwrap_or(path);
    let mut rendered = String::new();
    for component in relative.components() {
        if matches!(component, std::path::Component::RootDir) {
            rendered.push('/');
            continue;
        }
        // Separate from the previous component, unless this is the first one
        // or it directly follows the root separator.
        if !rendered.is_empty() && !rendered.ends_with('/') {
            rendered.push('/');
        }
        rendered.push_str(&component.as_os_str().to_string_lossy());
    }
    rendered
}
580
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use tempfile::tempdir;

    // Every rewrite in these tests changes the file's length, not just its
    // bytes. The watcher compares (mtime, size), and on filesystems with coarse
    // timestamp resolution two same-length writes in quick succession can
    // produce identical fingerprints and go unnoticed.

    /// Polls `watcher` until a batch arrives, failing the test after 3 seconds.
    fn wait_for_batch(watcher: &PollingWatcher) -> PollWatchBatch {
        let deadline = Instant::now() + Duration::from_secs(3);
        loop {
            if let Some(batch) = watcher
                .poll_timeout(Duration::from_millis(100))
                .expect("poll should succeed")
            {
                return batch;
            }
            assert!(
                Instant::now() < deadline,
                "timed out waiting for watcher batch"
            );
        }
    }

    /// Files under an excluded directory are not reported, even when that
    /// directory is explicitly listed as an include folder.
    #[test]
    fn polling_watcher_reports_filtered_changes() {
        let dir = tempdir().unwrap();
        let root = dir.path();
        fs::create_dir_all(root.join("src")).unwrap();
        fs::create_dir_all(root.join("build")).unwrap();
        fs::write(root.join("src/watch.cpp"), "a\n").unwrap();
        fs::write(root.join("build/ignore.cpp"), "a\n").unwrap();

        let mut config = PollingWatcherConfig::new(root);
        config.include_folders = vec![NormalizedPath::from("src"), NormalizedPath::from("build")];
        config.include_globs = vec!["**/*.cpp".to_string()];
        config.excluded_patterns = vec!["build".to_string()];
        config.poll_interval = Duration::from_millis(20);
        config.debounce = Duration::from_millis(20);

        let watcher = PollingWatcher::new(config).unwrap();
        watcher.start().unwrap();
        fs::write(root.join("src/watch.cpp"), "bb\n").unwrap();
        fs::write(root.join("build/ignore.cpp"), "bb\n").unwrap();

        let batch = wait_for_batch(&watcher);
        watcher.stop().unwrap();

        assert_eq!(
            batch.changed,
            vec![NormalizedPath::new(
                root.join("src/watch.cpp").canonicalize().unwrap(),
            )]
        );
        assert!(batch.removed.is_empty());
    }

    /// Edits made while the watcher is stopped become part of the new baseline
    /// on resume; only later edits are reported.
    #[test]
    fn polling_watcher_resume_resets_baseline() {
        let dir = tempdir().unwrap();
        let root = dir.path();
        fs::write(root.join("watch.cpp"), "a\n").unwrap();

        let mut config = PollingWatcherConfig::new(root);
        config.include_globs = vec!["**/*.cpp".to_string()];
        config.poll_interval = Duration::from_millis(20);
        config.debounce = Duration::from_millis(20);

        let watcher = PollingWatcher::new(config).unwrap();
        watcher.start().unwrap();
        watcher.stop().unwrap();
        fs::write(root.join("watch.cpp"), "bb\n").unwrap();
        watcher.resume().unwrap();
        assert!(watcher
            .poll_timeout(Duration::from_millis(200))
            .unwrap()
            .is_none());
        fs::write(root.join("watch.cpp"), "ccc\n").unwrap();
        let batch = wait_for_batch(&watcher);
        watcher.stop().unwrap();

        assert_eq!(
            batch.changed,
            vec![NormalizedPath::new(
                root.join("watch.cpp").canonicalize().unwrap()
            )]
        );
    }

    /// A single change is delivered both to `poll` and to registered callbacks.
    #[test]
    fn polling_watcher_callbacks_and_polling_share_events() {
        let dir = tempdir().unwrap();
        let root = dir.path();
        fs::write(root.join("watch.cpp"), "a\n").unwrap();

        let mut config = PollingWatcherConfig::new(root);
        config.include_globs = vec!["**/*.cpp".to_string()];
        config.poll_interval = Duration::from_millis(20);
        config.debounce = Duration::from_millis(20);

        let watcher = PollingWatcher::new(config).unwrap();
        let callback_count = Arc::new(AtomicUsize::new(0));
        let callback_count_clone = Arc::clone(&callback_count);
        watcher
            .add_callback(move |_batch| {
                callback_count_clone.fetch_add(1, Ordering::SeqCst);
            })
            .unwrap();
        watcher.start().unwrap();

        fs::write(root.join("watch.cpp"), "bb\n").unwrap();
        let batch = wait_for_batch(&watcher);
        watcher.stop().unwrap();

        assert_eq!(callback_count.load(Ordering::SeqCst), 1);
        assert_eq!(
            batch.changed,
            vec![NormalizedPath::new(
                root.join("watch.cpp").canonicalize().unwrap()
            )]
        );
    }
}