Skip to main content

zccache_watcher/
polling_watcher.rs

1use std::collections::{HashMap, HashSet};
2use std::path::Path;
3use std::sync::{mpsc, Arc, Mutex};
4use std::thread::{self, JoinHandle};
5use std::time::{Duration, Instant};
6
7use globset::{Glob, GlobSet, GlobSetBuilder};
8use zccache_core::NormalizedPath;
9
/// Change-detection fingerprint recorded for one file during a scan.
///
/// Two scans that produce equal `FileState` values for the same path are
/// treated as "file unchanged" when snapshots are compared.
#[derive(Clone, Debug, Eq, PartialEq)]
struct FileState {
    /// Modification time in nanoseconds since the Unix epoch; the scanner
    /// stores `0` when the timestamp cannot be obtained.
    mtime_ns: u128,
    /// File length in bytes, taken from the file's metadata.
    size: u64,
}
15
16#[derive(Clone)]
17struct ScanConfig {
18    root: NormalizedPath,
19    include_folders: Vec<NormalizedPath>,
20    include_globs: GlobSet,
21    exclude_globs: GlobSet,
22    excluded_names: HashSet<String>,
23}
24
25#[derive(Clone, Debug, PartialEq, Eq)]
26pub struct PollWatchBatch {
27    pub changed: Vec<NormalizedPath>,
28    pub removed: Vec<NormalizedPath>,
29    pub overflow: bool,
30}
31
32impl PollWatchBatch {
33    #[must_use]
34    pub fn is_empty(&self) -> bool {
35        self.changed.is_empty() && self.removed.is_empty() && !self.overflow
36    }
37}
38
39pub trait PollWatchObserver: Send + Sync {
40    fn on_batch(&self, batch: &PollWatchBatch);
41}
42
/// Adapter that lets a plain closure act as a `PollWatchObserver`.
struct FnObserver<F> {
    /// The wrapped closure, invoked with every delivered batch.
    callback: F,
}
46
47impl<F> PollWatchObserver for FnObserver<F>
48where
49    F: Fn(&PollWatchBatch) + Send + Sync + 'static,
50{
51    fn on_batch(&self, batch: &PollWatchBatch) {
52        (self.callback)(batch);
53    }
54}
55
56#[derive(Clone, Debug)]
57pub struct PollingWatcherConfig {
58    pub root: NormalizedPath,
59    pub include_folders: Vec<NormalizedPath>,
60    pub include_globs: Vec<String>,
61    pub excluded_patterns: Vec<String>,
62    pub poll_interval: Duration,
63    pub debounce: Duration,
64}
65
66impl PollingWatcherConfig {
67    #[must_use]
68    pub fn new(root: impl Into<NormalizedPath>) -> Self {
69        Self {
70            root: root.into(),
71            include_folders: Vec::new(),
72            include_globs: Vec::new(),
73            excluded_patterns: Vec::new(),
74            poll_interval: Duration::from_millis(100),
75            debounce: Duration::from_millis(200),
76        }
77    }
78}
79
80pub struct PollingWatcher {
81    config: ScanConfig,
82    poll_interval: Duration,
83    debounce: Duration,
84    observers: Arc<Mutex<Vec<Arc<dyn PollWatchObserver>>>>,
85    poll_rx: Mutex<Option<mpsc::Receiver<PollWatchBatch>>>,
86    worker_shutdown_tx: Mutex<Option<mpsc::Sender<()>>>,
87    worker_handle: Mutex<Option<JoinHandle<()>>>,
88    dispatch_shutdown_tx: Mutex<Option<mpsc::Sender<()>>>,
89    dispatch_handle: Mutex<Option<JoinHandle<()>>>,
90}
91
92impl PollingWatcher {
93    pub fn new(config: PollingWatcherConfig) -> std::io::Result<Self> {
94        let root = config.root;
95        if !root.is_dir() {
96            return Err(std::io::Error::new(
97                std::io::ErrorKind::NotFound,
98                format!(
99                    "watch root does not exist or is not a directory: {}",
100                    root.display()
101                ),
102            ));
103        }
104
105        let scan_config = build_config(
106            &root,
107            &config.include_folders,
108            &config.include_globs,
109            &config.excluded_patterns,
110        )?;
111
112        Ok(Self {
113            config: scan_config,
114            poll_interval: config.poll_interval.max(Duration::from_millis(1)),
115            debounce: config.debounce,
116            observers: Arc::new(Mutex::new(Vec::new())),
117            poll_rx: Mutex::new(None),
118            worker_shutdown_tx: Mutex::new(None),
119            worker_handle: Mutex::new(None),
120            dispatch_shutdown_tx: Mutex::new(None),
121            dispatch_handle: Mutex::new(None),
122        })
123    }
124
125    pub fn start(&self) -> std::io::Result<()> {
126        if self.is_running() {
127            return Ok(());
128        }
129
130        let (worker_batch_tx, worker_batch_rx) = mpsc::channel();
131        let (poll_tx, poll_rx) = mpsc::channel();
132        let (worker_shutdown_tx, worker_shutdown_rx) = mpsc::channel();
133        let (dispatch_shutdown_tx, dispatch_shutdown_rx) = mpsc::channel();
134        let (ready_tx, ready_rx) = mpsc::channel();
135        let config = self.config.clone();
136        let poll_interval = self.poll_interval;
137        let debounce = self.debounce;
138        let observers = Arc::clone(&self.observers);
139
140        let worker_handle = thread::Builder::new()
141            .name("zccache-polling-watcher".to_string())
142            .spawn(move || {
143                run_poll_loop(
144                    config,
145                    poll_interval,
146                    debounce,
147                    worker_batch_tx,
148                    worker_shutdown_rx,
149                    ready_tx,
150                )
151            })?;
152
153        match ready_rx.recv() {
154            Ok(()) => {}
155            Err(_) => {
156                let _ = worker_handle.join();
157                return Err(std::io::Error::other(
158                    "watcher worker exited before initialization completed",
159                ));
160            }
161        }
162
163        let dispatch_handle = thread::Builder::new()
164            .name("zccache-polling-watcher-dispatch".to_string())
165            .spawn(move || {
166                run_dispatch_loop(worker_batch_rx, poll_tx, dispatch_shutdown_rx, observers)
167            })?;
168
169        *self
170            .poll_rx
171            .lock()
172            .map_err(|_| std::io::Error::other("watcher receiver lock poisoned"))? = Some(poll_rx);
173        *self
174            .worker_shutdown_tx
175            .lock()
176            .map_err(|_| std::io::Error::other("watcher shutdown lock poisoned"))? =
177            Some(worker_shutdown_tx);
178        *self
179            .worker_handle
180            .lock()
181            .map_err(|_| std::io::Error::other("watcher worker lock poisoned"))? =
182            Some(worker_handle);
183        *self
184            .dispatch_shutdown_tx
185            .lock()
186            .map_err(|_| std::io::Error::other("watcher dispatch shutdown lock poisoned"))? =
187            Some(dispatch_shutdown_tx);
188        *self
189            .dispatch_handle
190            .lock()
191            .map_err(|_| std::io::Error::other("watcher dispatch lock poisoned"))? =
192            Some(dispatch_handle);
193
194        Ok(())
195    }
196
197    pub fn resume(&self) -> std::io::Result<()> {
198        self.start()
199    }
200
201    pub fn stop(&self) -> std::io::Result<()> {
202        let worker_shutdown = self
203            .worker_shutdown_tx
204            .lock()
205            .map_err(|_| std::io::Error::other("watcher shutdown lock poisoned"))?
206            .take();
207        if let Some(tx) = worker_shutdown {
208            let _ = tx.send(());
209        }
210
211        let dispatch_shutdown = self
212            .dispatch_shutdown_tx
213            .lock()
214            .map_err(|_| std::io::Error::other("watcher dispatch shutdown lock poisoned"))?
215            .take();
216        if let Some(tx) = dispatch_shutdown {
217            let _ = tx.send(());
218        }
219
220        let worker = self
221            .worker_handle
222            .lock()
223            .map_err(|_| std::io::Error::other("watcher worker lock poisoned"))?
224            .take();
225        if let Some(handle) = worker {
226            handle
227                .join()
228                .map_err(|_| std::io::Error::other("watcher worker thread panicked"))?;
229        }
230
231        let dispatch = self
232            .dispatch_handle
233            .lock()
234            .map_err(|_| std::io::Error::other("watcher dispatch lock poisoned"))?
235            .take();
236        if let Some(handle) = dispatch {
237            handle
238                .join()
239                .map_err(|_| std::io::Error::other("watcher dispatch thread panicked"))?;
240        }
241
242        *self
243            .poll_rx
244            .lock()
245            .map_err(|_| std::io::Error::other("watcher receiver lock poisoned"))? = None;
246
247        Ok(())
248    }
249
250    #[must_use]
251    pub fn is_running(&self) -> bool {
252        self.worker_handle
253            .lock()
254            .ok()
255            .and_then(|guard| guard.as_ref().map(JoinHandle::is_finished))
256            .is_some_and(|finished| !finished)
257    }
258
259    pub fn poll(&self) -> std::io::Result<Option<PollWatchBatch>> {
260        self.poll_timeout(Duration::ZERO)
261    }
262
263    pub fn poll_timeout(&self, timeout: Duration) -> std::io::Result<Option<PollWatchBatch>> {
264        let receiver_guard = self
265            .poll_rx
266            .lock()
267            .map_err(|_| std::io::Error::other("watcher receiver lock poisoned"))?;
268        let Some(receiver) = receiver_guard.as_ref() else {
269            return Ok(None);
270        };
271
272        if timeout.is_zero() {
273            match receiver.try_recv() {
274                Ok(batch) => Ok(Some(batch)),
275                Err(mpsc::TryRecvError::Empty | mpsc::TryRecvError::Disconnected) => Ok(None),
276            }
277        } else {
278            match receiver.recv_timeout(timeout) {
279                Ok(batch) => Ok(Some(batch)),
280                Err(mpsc::RecvTimeoutError::Timeout | mpsc::RecvTimeoutError::Disconnected) => {
281                    Ok(None)
282                }
283            }
284        }
285    }
286
287    pub fn add_observer(&self, observer: Arc<dyn PollWatchObserver>) -> std::io::Result<()> {
288        self.observers
289            .lock()
290            .map_err(|_| std::io::Error::other("watcher observers lock poisoned"))?
291            .push(observer);
292        Ok(())
293    }
294
295    pub fn add_callback<F>(&self, callback: F) -> std::io::Result<()>
296    where
297        F: Fn(&PollWatchBatch) + Send + Sync + 'static,
298    {
299        self.add_observer(Arc::new(FnObserver { callback }))
300    }
301}
302
303impl Drop for PollingWatcher {
304    fn drop(&mut self) {
305        let _ = self.stop();
306    }
307}
308
309fn run_dispatch_loop(
310    worker_batch_rx: mpsc::Receiver<PollWatchBatch>,
311    poll_tx: mpsc::Sender<PollWatchBatch>,
312    dispatch_shutdown_rx: mpsc::Receiver<()>,
313    observers: Arc<Mutex<Vec<Arc<dyn PollWatchObserver>>>>,
314) {
315    loop {
316        if dispatch_shutdown_rx.try_recv().is_ok() {
317            break;
318        }
319
320        let batch = match worker_batch_rx.recv_timeout(Duration::from_millis(25)) {
321            Ok(batch) => batch,
322            Err(mpsc::RecvTimeoutError::Timeout) => continue,
323            Err(mpsc::RecvTimeoutError::Disconnected) => break,
324        };
325
326        if poll_tx.send(batch.clone()).is_err() {
327            break;
328        }
329
330        let snapshot = match observers.lock() {
331            Ok(guard) => guard.clone(),
332            Err(_) => break,
333        };
334        for observer in snapshot {
335            observer.on_batch(&batch);
336        }
337    }
338}
339
340fn run_poll_loop(
341    config: ScanConfig,
342    poll_interval: Duration,
343    debounce: Duration,
344    batch_tx: mpsc::Sender<PollWatchBatch>,
345    shutdown_rx: mpsc::Receiver<()>,
346    ready_tx: mpsc::Sender<()>,
347) {
348    let mut snapshot = scan_snapshot(&config);
349    let _ = ready_tx.send(());
350    let mut pending_changed: HashSet<NormalizedPath> = HashSet::new();
351    let mut pending_removed: HashSet<NormalizedPath> = HashSet::new();
352    let mut last_change: Option<Instant> = None;
353
354    loop {
355        if shutdown_rx.try_recv().is_ok() {
356            break;
357        }
358
359        let current = scan_snapshot(&config);
360        let (changed, removed) = diff_snapshots(&snapshot, &current);
361
362        if !changed.is_empty() || !removed.is_empty() {
363            for path in changed {
364                pending_removed.remove(&path);
365                pending_changed.insert(path);
366            }
367            for path in removed {
368                pending_changed.remove(&path);
369                pending_removed.insert(path);
370            }
371            last_change = Some(Instant::now());
372        } else if let Some(last) = last_change {
373            if last.elapsed() >= debounce
374                && (!pending_changed.is_empty() || !pending_removed.is_empty())
375            {
376                let mut changed: Vec<NormalizedPath> = pending_changed.drain().collect();
377                let mut removed: Vec<NormalizedPath> = pending_removed.drain().collect();
378                changed.sort();
379                removed.sort();
380                if batch_tx
381                    .send(PollWatchBatch {
382                        changed,
383                        removed,
384                        overflow: false,
385                    })
386                    .is_err()
387                {
388                    break;
389                }
390                last_change = None;
391            }
392        }
393
394        snapshot = current;
395
396        if shutdown_rx.recv_timeout(poll_interval).is_ok() {
397            break;
398        }
399    }
400}
401
402fn build_config(
403    root: &Path,
404    include_folders: &[NormalizedPath],
405    include_globs: &[String],
406    excluded_patterns: &[String],
407) -> std::io::Result<ScanConfig> {
408    let root = NormalizedPath::new(root.canonicalize()?);
409
410    let include_folders = if include_folders.is_empty() {
411        vec![root.clone()]
412    } else {
413        include_folders
414            .iter()
415            .map(|folder| {
416                let absolute = if folder.is_absolute() {
417                    folder.clone().into_path_buf()
418                } else {
419                    root.join(folder).into_path_buf()
420                };
421                Ok(NormalizedPath::new(
422                    absolute.canonicalize().unwrap_or(absolute),
423                ))
424            })
425            .collect::<std::io::Result<Vec<_>>>()?
426    };
427
428    let include_patterns = if include_globs.is_empty() {
429        vec!["**".to_string()]
430    } else {
431        include_globs.to_vec()
432    };
433    let include_globs = build_globset(&expand_patterns(&include_patterns))?;
434
435    let excluded_names = excluded_patterns
436        .iter()
437        .filter(|pattern| !has_glob_meta(pattern) && !pattern.contains('/'))
438        .cloned()
439        .collect::<HashSet<_>>();
440    let exclude_globs = build_globset(&expand_patterns(excluded_patterns))?;
441
442    Ok(ScanConfig {
443        root,
444        include_folders,
445        include_globs,
446        exclude_globs,
447        excluded_names,
448    })
449}
450
451fn build_globset(patterns: &[String]) -> std::io::Result<GlobSet> {
452    let mut builder = GlobSetBuilder::new();
453    for pattern in patterns {
454        builder.add(
455            Glob::new(pattern).map_err(|e| std::io::Error::other(format!("invalid glob: {e}")))?,
456        );
457    }
458    builder
459        .build()
460        .map_err(|e| std::io::Error::other(format!("failed to compile glob set: {e}")))
461}
462
/// Expands each pattern into itself plus every variant obtained by deleting
/// `**/` or `/**` segments, so a recursive pattern also matches at the top
/// level (e.g. `**/*.cpp` additionally yields `*.cpp`, `build/**` yields
/// `build`).
///
/// Backslashes are converted to `/` first. Variants are de-duplicated within
/// one input pattern, not across patterns.
fn expand_patterns(patterns: &[String]) -> Vec<String> {
    // Segments whose removal produces an additional variant, in the order the
    // variants are queued.
    const REDUCIBLE: [&str; 2] = ["**/", "/**"];

    let mut output = Vec::new();
    for raw in patterns {
        let mut visited: HashSet<String> = HashSet::new();
        let mut stack = vec![raw.replace('\\', "/")];

        while let Some(candidate) = stack.pop() {
            if visited.contains(&candidate) {
                continue;
            }
            visited.insert(candidate.clone());

            for token in REDUCIBLE.iter() {
                if candidate.contains(*token) {
                    stack.push(candidate.replace(*token, ""));
                }
            }
            output.push(candidate);
        }
    }
    output
}
483
/// Returns `true` when `pattern` contains a glob metacharacter.
///
/// Recognized characters are `*`, `?`, `[` (character class) and `{`
/// (alternation such as `{a,b}`). Patterns without any of them are plain
/// names and may be matched literally.
fn has_glob_meta(pattern: &str) -> bool {
    pattern.contains(|c: char| matches!(c, '*' | '?' | '[' | '{'))
}
487
488fn scan_snapshot(config: &ScanConfig) -> HashMap<NormalizedPath, FileState> {
489    let mut result = HashMap::new();
490
491    for base in &config.include_folders {
492        if !base.exists() {
493            continue;
494        }
495
496        let root = config.root.clone();
497        let exclude_names = config.excluded_names.clone();
498        let exclude_globs = config.exclude_globs.clone();
499
500        let walker = jwalk::WalkDir::new(base)
501            .follow_links(false)
502            .skip_hidden(false)
503            .process_read_dir(move |_depth, _path, _state, children| {
504                children.retain(|entry| {
505                    let Ok(entry) = entry else {
506                        return true;
507                    };
508                    if !entry.file_type.is_dir() {
509                        return true;
510                    }
511                    let path = entry.path();
512                    if let Some(name) = path.file_name().and_then(|name| name.to_str()) {
513                        if exclude_names.contains(name) {
514                            return false;
515                        }
516                    }
517                    let rel = rel_string(&root, &path);
518                    !exclude_globs.is_match(&rel)
519                });
520            });
521
522        for entry in walker.into_iter().flatten() {
523            if !entry.file_type.is_file() {
524                continue;
525            }
526            let path = entry.path();
527            let rel = rel_string(&config.root, &path);
528            if config.exclude_globs.is_match(&rel) || !config.include_globs.is_match(&rel) {
529                continue;
530            }
531            if let Ok(metadata) = path.metadata() {
532                result.insert(
533                    NormalizedPath::new(path),
534                    FileState {
535                        mtime_ns: metadata
536                            .modified()
537                            .ok()
538                            .and_then(|time| time.duration_since(std::time::UNIX_EPOCH).ok())
539                            .map_or(0, |duration| duration.as_nanos()),
540                        size: metadata.len(),
541                    },
542                );
543            }
544        }
545    }
546
547    result
548}
549
550fn diff_snapshots(
551    previous: &HashMap<NormalizedPath, FileState>,
552    current: &HashMap<NormalizedPath, FileState>,
553) -> (HashSet<NormalizedPath>, HashSet<NormalizedPath>) {
554    let mut changed = HashSet::new();
555    let mut removed = HashSet::new();
556
557    for (path, state) in current {
558        if previous.get(path) != Some(state) {
559            changed.insert(path.clone());
560        }
561    }
562
563    for path in previous.keys() {
564        if !current.contains_key(path) {
565            removed.insert(path.clone());
566        }
567    }
568
569    (changed, removed)
570}
571
/// Renders `path` relative to `root` as a `/`-separated string, independent of
/// the platform's path separator.
///
/// When `path` does not start with `root`, the whole path is rendered
/// instead. Non-UTF-8 components are converted lossily.
fn rel_string(root: &Path, path: &Path) -> String {
    let relative = match path.strip_prefix(root) {
        Ok(stripped) => stripped,
        Err(_) => path,
    };

    let mut joined = String::new();
    for (index, component) in relative.components().enumerate() {
        if index > 0 {
            joined.push('/');
        }
        joined.push_str(&component.as_os_str().to_string_lossy());
    }
    joined
}
580
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use tempfile::tempdir;

    /// Short interval used for both polling and debouncing so tests finish
    /// quickly.
    const FAST: Duration = Duration::from_millis(20);

    /// Builds a configuration rooted at `root` that watches `**/*.cpp` with
    /// fast polling and debouncing.
    fn cpp_config(root: &Path) -> PollingWatcherConfig {
        let mut config = PollingWatcherConfig::new(root);
        config.include_globs = vec!["**/*.cpp".to_string()];
        config.poll_interval = FAST;
        config.debounce = FAST;
        config
    }

    /// Canonical, normalized form of `path`, matching what the watcher
    /// reports.
    fn canonical(path: impl AsRef<Path>) -> NormalizedPath {
        NormalizedPath::new(path.as_ref().canonicalize().unwrap())
    }

    /// Polls in 100 ms slices until a batch arrives, failing once three
    /// seconds have passed without one.
    fn wait_for_batch(watcher: &PollingWatcher) -> PollWatchBatch {
        let deadline = Instant::now() + Duration::from_secs(3);
        loop {
            let polled = watcher
                .poll_timeout(Duration::from_millis(100))
                .expect("poll should succeed");
            match polled {
                Some(batch) => return batch,
                None => assert!(
                    Instant::now() < deadline,
                    "timed out waiting for watcher batch"
                ),
            }
        }
    }

    /// Only files passing the include/exclude filters show up in a batch.
    #[test]
    fn polling_watcher_reports_filtered_changes() {
        let dir = tempdir().unwrap();
        let root = dir.path();
        let watched = root.join("src/watch.cpp");
        let ignored = root.join("build/ignore.cpp");
        fs::create_dir_all(root.join("src")).unwrap();
        fs::create_dir_all(root.join("build")).unwrap();
        fs::write(&watched, "a\n").unwrap();
        fs::write(&ignored, "a\n").unwrap();

        let mut config = cpp_config(root);
        config.include_folders = vec![NormalizedPath::from("src"), NormalizedPath::from("build")];
        config.excluded_patterns = vec!["build".to_string()];

        let watcher = PollingWatcher::new(config).unwrap();
        watcher.start().unwrap();
        fs::write(&watched, "b\n").unwrap();
        fs::write(&ignored, "b\n").unwrap();

        let batch = wait_for_batch(&watcher);
        watcher.stop().unwrap();

        assert_eq!(batch.changed, vec![canonical(&watched)]);
        assert!(batch.removed.is_empty());
    }

    /// Changes made while stopped are absorbed into the new baseline and are
    /// not reported after `resume`.
    #[test]
    fn polling_watcher_resume_resets_baseline() {
        let dir = tempdir().unwrap();
        let root = dir.path();
        let target = root.join("watch.cpp");
        fs::write(&target, "a\n").unwrap();

        let watcher = PollingWatcher::new(cpp_config(root)).unwrap();
        watcher.start().unwrap();
        watcher.stop().unwrap();
        fs::write(&target, "b\n").unwrap();
        watcher.resume().unwrap();

        let stale = watcher.poll_timeout(Duration::from_millis(200)).unwrap();
        assert!(stale.is_none());

        fs::write(&target, "c\n").unwrap();
        let batch = wait_for_batch(&watcher);
        watcher.stop().unwrap();

        assert_eq!(batch.changed, vec![canonical(&target)]);
    }

    /// A batch is delivered both to registered callbacks and to `poll`.
    #[test]
    fn polling_watcher_callbacks_and_polling_share_events() {
        let dir = tempdir().unwrap();
        let root = dir.path();
        let target = root.join("watch.cpp");
        fs::write(&target, "a\n").unwrap();

        let watcher = PollingWatcher::new(cpp_config(root)).unwrap();
        let callback_count = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&callback_count);
        watcher
            .add_callback(move |_batch| {
                counter.fetch_add(1, Ordering::SeqCst);
            })
            .unwrap();
        watcher.start().unwrap();

        fs::write(&target, "b\n").unwrap();
        let batch = wait_for_batch(&watcher);
        watcher.stop().unwrap();

        assert_eq!(callback_count.load(Ordering::SeqCst), 1);
        assert_eq!(batch.changed, vec![canonical(&target)]);
    }
}