// stakpak_shared/file_watcher.rs

use std::{
    collections::HashMap,
    hash::{DefaultHasher, Hash, Hasher},
    path::{Path, PathBuf},
    sync::Arc,
};

use notify::{Config, Event, RecommendedWatcher, RecursiveMode, Watcher};
use serde::{Deserialize, Serialize};
use tokio::sync::mpsc;
use walkdir::WalkDir;

13/// Represents a file's content and metadata for tracking changes
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct FileBuffer {
16    pub content: String,
17    pub uri: String,
18    pub hash: u64,
19    pub path: PathBuf,
20}
21
22/// Events that can occur during file watching
23#[derive(Debug, Clone)]
24pub enum FileWatchEvent {
25    /// File was modified
26    Modified {
27        file: FileBuffer,
28        old_content: String,
29    },
30    /// File was deleted
31    Deleted { file: FileBuffer },
32    /// File was created
33    Created { file: FileBuffer },
34    /// Raw filesystem event for custom handling
35    Raw { event: Event },
36}
37
/// Trait for filtering which files should be watched
///
/// Implementors must be `Send + Sync`: the watcher stores the filter behind
/// an `Arc<dyn FileFilter>` and shares it with the filesystem-event callback
/// and the background processing task.
pub trait FileFilter: Send + Sync {
    /// Returns true if the file should be watched
    ///
    /// `path` is the path as reported by the directory walk or by the
    /// filesystem event; it is not normalized before being passed in.
    fn should_watch(&self, path: &Path) -> bool;
}

44/// Simple closure-based file filter
45pub struct ClosureFilter<F>
46where
47    F: Fn(&Path) -> bool + Send + Sync,
48{
49    filter_fn: F,
50}
51
52impl<F> ClosureFilter<F>
53where
54    F: Fn(&Path) -> bool + Send + Sync,
55{
56    pub fn new(filter_fn: F) -> Self {
57        Self { filter_fn }
58    }
59}
60
61impl<F> FileFilter for ClosureFilter<F>
62where
63    F: Fn(&Path) -> bool + Send + Sync,
64{
65    fn should_watch(&self, path: &Path) -> bool {
66        (self.filter_fn)(path)
67    }
68}
69
70/// Main file watcher that can watch directories for changes
71pub struct FileWatcher {
72    watch_dir: PathBuf,
73    watched_files: HashMap<String, FileBuffer>,
74    filter: Arc<dyn FileFilter>,
75    watcher: Option<RecommendedWatcher>,
76}
77
78impl FileWatcher {
79    /// Create a new file watcher
80    pub fn new<F>(watch_dir: PathBuf, filter: F) -> Self
81    where
82        F: FileFilter + 'static,
83    {
84        Self {
85            watch_dir,
86            watched_files: HashMap::new(),
87            filter: Arc::new(filter),
88            watcher: None,
89        }
90    }
91
92    /// Initialize the watcher and scan for existing files
93    pub fn initialize(&mut self) -> Result<(), String> {
94        self.watched_files = self.scan_directory()?;
95        Ok(())
96    }
97
98    /// Start watching the directory and return a receiver for processed events
99    pub async fn start_watching(&mut self) -> Result<mpsc::Receiver<FileWatchEvent>, String> {
100        let (processed_tx, processed_rx) = mpsc::channel(100);
101        let (raw_tx, mut raw_rx) = mpsc::unbounded_channel();
102
103        let watch_dir = self.watch_dir.clone();
104        let filter = Arc::clone(&self.filter);
105        let raw_tx_clone = raw_tx.clone();
106
107        // Create the filesystem watcher
108        let watcher: Result<RecommendedWatcher, notify::Error> = RecommendedWatcher::new(
109            move |result: Result<Event, notify::Error>| {
110                if let Ok(event) = result {
111                    // Filter events based on paths
112                    let should_process = event
113                        .paths
114                        .iter()
115                        .any(|path| path.is_file() && filter.should_watch(path));
116
117                    if should_process {
118                        let _ = raw_tx_clone.send(event);
119                    }
120                }
121            },
122            Config::default(),
123        );
124        let mut watcher = watcher.map_err(|e| format!("Failed to create watcher: {}", e))?;
125
126        // Start watching
127        watcher
128            .watch(&watch_dir, RecursiveMode::Recursive)
129            .map_err(|e| format!("Failed to watch directory: {}", e))?;
130
131        self.watcher = Some(watcher);
132
133        // Spawn background task to process raw events
134        let watch_dir_clone = self.watch_dir.clone();
135        let filter_clone = Arc::clone(&self.filter);
136        let watched_files = self.watched_files.clone();
137
138        tokio::spawn(async move {
139            let mut internal_watcher = InternalEventProcessor {
140                watch_dir: watch_dir_clone,
141                watched_files,
142                filter: filter_clone,
143                processed_tx,
144            };
145
146            while let Some(raw_event) = raw_rx.recv().await {
147                if let Err(e) = internal_watcher.process_event(raw_event).await {
148                    eprintln!("Error processing file watch event: {}", e);
149                }
150            }
151        });
152
153        Ok(processed_rx)
154    }
155
156    /// Get current watched files (snapshot at initialization)
157    pub fn get_watched_files(&self) -> &HashMap<String, FileBuffer> {
158        &self.watched_files
159    }
160
161    /// Get the directory being watched
162    pub fn watch_dir(&self) -> &Path {
163        &self.watch_dir
164    }
165
166    /// Scan directory for existing files
167    fn scan_directory(&self) -> Result<HashMap<String, FileBuffer>, String> {
168        let mut files = HashMap::new();
169
170        for entry in WalkDir::new(&self.watch_dir)
171            .into_iter()
172            .filter_map(Result::ok)
173            .filter(|entry| entry.path().is_file() && self.filter.should_watch(entry.path()))
174        {
175            let path = entry.path();
176            if let Ok(buffer) = self.create_file_buffer(path) {
177                files.insert(buffer.uri.clone(), buffer);
178            }
179        }
180
181        Ok(files)
182    }
183
184    /// Create a file buffer from a path
185    fn create_file_buffer(&self, path: &Path) -> Result<FileBuffer, String> {
186        let content = std::fs::read_to_string(path)
187            .map_err(|e| format!("Failed to read file {}: {}", path.display(), e))?;
188
189        let hash = self.hash_content(&content);
190        let uri = self.path_to_uri(path);
191
192        // Use canonical path for consistency
193        let canonical_path = path.canonicalize().unwrap_or_else(|_| path.to_path_buf());
194
195        Ok(FileBuffer {
196            content,
197            uri,
198            hash,
199            path: canonical_path,
200        })
201    }
202
203    /// Convert path to URI
204    fn path_to_uri(&self, path: &Path) -> String {
205        // Use canonical path to ensure consistency across all platforms
206        let canonical_path = path.canonicalize().unwrap_or_else(|_| path.to_path_buf());
207
208        // Create absolute URI instead of relative
209        format!(
210            "file://{}",
211            canonical_path.to_string_lossy().replace('\\', "/")
212        )
213    }
214
215    /// Hash file content
216    fn hash_content(&self, content: &str) -> u64 {
217        let mut hasher = DefaultHasher::new();
218        content.hash(&mut hasher);
219        hasher.finish()
220    }
221}
222
223/// Internal event processor that handles raw events and produces processed events
224struct InternalEventProcessor {
225    #[allow(dead_code)]
226    watch_dir: PathBuf,
227    watched_files: HashMap<String, FileBuffer>,
228    filter: Arc<dyn FileFilter>,
229    processed_tx: mpsc::Sender<FileWatchEvent>,
230}
231
232impl InternalEventProcessor {
233    /// Process a raw filesystem event and send processed events
234    async fn process_event(&mut self, event: Event) -> Result<(), String> {
235        let mut events_to_send = Vec::new();
236
237        // Handle deletions first
238        self.process_deletions(&mut events_to_send);
239
240        // Handle modifications and creations
241        self.process_modifications(&event, &mut events_to_send)?;
242
243        // Send all processed events
244        for event in events_to_send {
245            if self.processed_tx.send(event).await.is_err() {
246                // Channel was closed, stop processing
247                return Err("Event channel closed".to_string());
248            }
249        }
250
251        Ok(())
252    }
253
254    /// Process file deletions
255    fn process_deletions(&mut self, events: &mut Vec<FileWatchEvent>) {
256        let mut to_remove = Vec::new();
257
258        for (uri, buffer) in &self.watched_files {
259            if !buffer.path.exists() {
260                events.push(FileWatchEvent::Deleted {
261                    file: buffer.clone(),
262                });
263                to_remove.push(uri.clone());
264            }
265        }
266
267        for uri in to_remove {
268            self.watched_files.remove(&uri);
269        }
270    }
271
272    /// Process file modifications and creations
273    fn process_modifications(
274        &mut self,
275        event: &Event,
276        events: &mut Vec<FileWatchEvent>,
277    ) -> Result<(), String> {
278        for path in &event.paths {
279            if !path.is_file() || !self.filter.should_watch(path) {
280                continue;
281            }
282
283            let uri = self.path_to_uri(path);
284
285            match self.create_file_buffer(path) {
286                Ok(new_buffer) => {
287                    if let Some(old_buffer) = self.watched_files.get(&uri) {
288                        // File exists and was modified
289                        if old_buffer.hash != new_buffer.hash {
290                            events.push(FileWatchEvent::Modified {
291                                file: new_buffer.clone(),
292                                old_content: old_buffer.content.clone(),
293                            });
294                            self.watched_files.insert(uri, new_buffer);
295                        }
296                    } else {
297                        // New file created
298                        events.push(FileWatchEvent::Created {
299                            file: new_buffer.clone(),
300                        });
301                        self.watched_files.insert(uri, new_buffer);
302                    }
303                }
304                Err(_) => {
305                    // File might have been deleted
306                    if let Some(old_buffer) = self.watched_files.remove(&uri) {
307                        events.push(FileWatchEvent::Deleted { file: old_buffer });
308                    }
309                }
310            }
311        }
312
313        Ok(())
314    }
315
316    /// Create a file buffer from a path
317    fn create_file_buffer(&self, path: &Path) -> Result<FileBuffer, String> {
318        let content = std::fs::read_to_string(path)
319            .map_err(|e| format!("Failed to read file {}: {}", path.display(), e))?;
320
321        let hash = self.hash_content(&content);
322        let uri = self.path_to_uri(path);
323
324        // Use canonical path for consistency
325        let canonical_path = path.canonicalize().unwrap_or_else(|_| path.to_path_buf());
326
327        Ok(FileBuffer {
328            content,
329            uri,
330            hash,
331            path: canonical_path,
332        })
333    }
334
335    /// Convert path to URI
336    fn path_to_uri(&self, path: &Path) -> String {
337        // Use canonical path to ensure consistency across all platforms
338        let canonical_path = path.canonicalize().unwrap_or_else(|_| path.to_path_buf());
339
340        // Create absolute URI instead of relative
341        format!(
342            "file://{}",
343            canonical_path.to_string_lossy().replace('\\', "/")
344        )
345    }
346
347    /// Hash file content
348    fn hash_content(&self, content: &str) -> u64 {
349        let mut hasher = DefaultHasher::new();
350        content.hash(&mut hasher);
351        hasher.finish()
352    }
353}
354
355/// Convenience function to create a file watcher with closure-based filter
356pub fn create_file_watcher<F>(watch_dir: PathBuf, filter: F) -> Result<FileWatcher, String>
357where
358    F: Fn(&Path) -> bool + Send + Sync + 'static,
359{
360    let filter = ClosureFilter::new(filter);
361    let watcher = FileWatcher::new(watch_dir, filter);
362    Ok(watcher)
363}
364
365/// Convenience function to create and start a file watcher, returning the event receiver
366pub async fn create_and_start_watcher<F>(
367    watch_dir: PathBuf,
368    filter: F,
369) -> Result<(FileWatcher, mpsc::Receiver<FileWatchEvent>), String>
370where
371    F: Fn(&Path) -> bool + Send + Sync + 'static,
372{
373    let mut watcher = create_file_watcher(watch_dir, filter)?;
374    watcher.initialize()?;
375    let receiver = watcher.start_watching().await?;
376    Ok((watcher, receiver))
377}
378
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use std::path::Path;
    use tempfile::TempDir;
    use tokio::time::Duration;

    // Helper function to create a test directory with some files
    fn create_test_directory() -> TempDir {
        let temp_dir = TempDir::new().expect("Failed to create temp directory");
        let temp_path = temp_dir.path();

        // Create some test files
        fs::write(temp_path.join("test1.txt"), "content1").expect("Failed to write test1.txt");
        fs::write(temp_path.join("test2.rs"), "fn main() {}").expect("Failed to write test2.rs");
        fs::write(temp_path.join("ignore.log"), "log content").expect("Failed to write ignore.log");

        // Create subdirectory with files
        let sub_dir = temp_path.join("subdir");
        fs::create_dir(&sub_dir).expect("Failed to create subdirectory");
        fs::write(sub_dir.join("nested.txt"), "nested content")
            .expect("Failed to write nested.txt");

        temp_dir
    }

    // Simple test filter that only watches .txt and .rs files
    fn test_filter(path: &Path) -> bool {
        matches!(
            path.extension().and_then(|ext| ext.to_str()),
            Some("txt" | "rs")
        )
    }

    // Receives events until `is_match` accepts one or `limit` has elapsed.
    // Returns whether a matching event was seen.
    async fn wait_for_event<F>(
        rx: &mut tokio::sync::mpsc::Receiver<FileWatchEvent>,
        limit: Duration,
        mut is_match: F,
    ) -> bool
    where
        F: FnMut(FileWatchEvent) -> bool,
    {
        let search = async {
            while let Some(event) = rx.recv().await {
                if is_match(event) {
                    return true;
                }
            }
            false
        };

        tokio::time::timeout(limit, search).await.unwrap_or(false)
    }

    #[test]
    fn test_file_filter_trait() {
        let filter = ClosureFilter::new(test_filter);

        assert!(filter.should_watch(Path::new("test.txt")));
        assert!(filter.should_watch(Path::new("test.rs")));
        assert!(!filter.should_watch(Path::new("test.log")));
        assert!(!filter.should_watch(Path::new("test")));
    }

    #[test]
    fn test_file_watcher_creation() {
        let temp_dir = create_test_directory();
        let filter = ClosureFilter::new(test_filter);

        let watcher = FileWatcher::new(temp_dir.path().to_path_buf(), filter);

        assert_eq!(watcher.watch_dir(), temp_dir.path());
        assert_eq!(watcher.get_watched_files().len(), 0); // Not initialized yet
    }

    #[test]
    fn test_file_watcher_initialization() {
        let temp_dir = create_test_directory();
        let filter = ClosureFilter::new(test_filter);

        let mut watcher = FileWatcher::new(temp_dir.path().to_path_buf(), filter);
        watcher.initialize().expect("Failed to initialize watcher");

        let watched_files = watcher.get_watched_files();

        // Should have 3 files: test1.txt, test2.rs, and nested.txt (filtered by extension)
        assert_eq!(watched_files.len(), 3);

        // Check that files are properly tracked
        let file_names: Vec<_> = watched_files
            .values()
            .map(|f| f.path.file_name().unwrap().to_str().unwrap())
            .collect();

        assert!(file_names.contains(&"test1.txt"));
        assert!(file_names.contains(&"test2.rs"));
        assert!(file_names.contains(&"nested.txt"));
    }

    #[tokio::test]
    async fn test_create_and_start_watcher() {
        let temp_dir = create_test_directory();

        let (watcher, _rx) = create_and_start_watcher(temp_dir.path().to_path_buf(), test_filter)
            .await
            .expect("Failed to create and start watcher");

        // Should have the same files as the basic test
        assert_eq!(watcher.get_watched_files().len(), 3);
        assert_eq!(watcher.watch_dir(), temp_dir.path());
    }

    #[tokio::test]
    async fn test_real_file_creation_detection() {
        let temp_dir = TempDir::new().expect("Failed to create temp directory");

        let (_watcher, mut event_rx) =
            create_and_start_watcher(temp_dir.path().to_path_buf(), test_filter)
                .await
                .expect("Failed to create and start watcher");

        // Give the watcher a moment to start
        tokio::time::sleep(Duration::from_millis(200)).await;

        // Create a new file
        let new_file = temp_dir.path().join("new_test.txt");
        fs::write(&new_file, "new file content").expect("Failed to create new file");
        let new_file_canonical = new_file
            .canonicalize()
            .expect("Failed to canonicalize path");

        // Wait for processed events
        let creation_detected =
            wait_for_event(&mut event_rx, Duration::from_secs(2), |event| match event {
                FileWatchEvent::Created { file } if file.path == new_file_canonical => {
                    assert_eq!(file.content, "new file content");
                    true
                }
                _ => false,
            })
            .await;

        assert!(creation_detected, "File creation was not detected");
    }

    #[tokio::test]
    async fn test_real_file_modification_detection() {
        let temp_dir = TempDir::new().expect("Failed to create temp directory");

        // Create initial file
        let test_file = temp_dir.path().join("modify_test.txt");
        fs::write(&test_file, "initial content").expect("Failed to create initial file");
        let test_file_canonical = test_file
            .canonicalize()
            .expect("Failed to canonicalize path");

        let (_watcher, mut event_rx) =
            create_and_start_watcher(temp_dir.path().to_path_buf(), test_filter)
                .await
                .expect("Failed to create and start watcher");

        // Give the watcher a moment to start
        tokio::time::sleep(Duration::from_millis(200)).await;

        // Modify the file
        fs::write(&test_file, "modified content").expect("Failed to modify file");

        // Wait for processed events
        let modification_detected =
            wait_for_event(&mut event_rx, Duration::from_secs(2), |event| match event {
                FileWatchEvent::Modified { file, old_content }
                    if file.path == test_file_canonical =>
                {
                    assert_eq!(file.content, "modified content");
                    assert_eq!(old_content, "initial content");
                    true
                }
                _ => false,
            })
            .await;

        assert!(modification_detected, "File modification was not detected");
    }

    #[tokio::test]
    async fn test_file_filter_in_real_watching() {
        let temp_dir = TempDir::new().expect("Failed to create temp directory");

        let (_watcher, mut event_rx) = create_and_start_watcher(
            temp_dir.path().to_path_buf(),
            test_filter, // Only watches .txt and .rs files
        )
        .await
        .expect("Failed to create and start watcher");

        // Give the watcher a moment to start
        tokio::time::sleep(Duration::from_millis(200)).await;

        // Create files with different extensions
        let txt_file = temp_dir.path().join("watched.txt");
        let log_file = temp_dir.path().join("ignored.log");

        fs::write(&txt_file, "should be watched").expect("Failed to create txt file");
        fs::write(&log_file, "should be ignored").expect("Failed to create log file");

        let txt_file_canonical = txt_file
            .canonicalize()
            .expect("Failed to canonicalize txt file path");
        let log_file_canonical = log_file
            .canonicalize()
            .expect("Failed to canonicalize log file path");

        // Wait for the txt file, noting any log event seen along the way
        let mut log_detected = false;
        let txt_detected = wait_for_event(&mut event_rx, Duration::from_secs(2), |event| {
            if let FileWatchEvent::Created { file } = event {
                if file.path == log_file_canonical {
                    log_detected = true;
                }
                return file.path == txt_file_canonical;
            }
            false
        })
        .await;

        // The log file is written after the txt file, so an event for it
        // could only arrive later. Keep listening for a short while so the
        // negative assertion below can actually fail.
        if !log_detected {
            log_detected = wait_for_event(&mut event_rx, Duration::from_millis(300), |event| {
                matches!(
                    event,
                    FileWatchEvent::Created { file } if file.path == log_file_canonical
                )
            })
            .await;
        }

        assert!(txt_detected, "TXT file creation should be detected");
        assert!(!log_detected, "LOG file creation should be filtered out");
    }
}