Skip to main content

tuitbot_core/automation/watchtower/
mod.rs

1//! Watchtower content source watcher and shared ingest pipeline.
2//!
3//! Watches configured local directories for `.md` and `.txt` changes via
4//! the `notify` crate with debouncing, and polls remote content sources
5//! (e.g. Google Drive) on a configurable interval.  Both local filesystem
6//! events and remote polls funnel through `ingest_content()`, ensuring
7//! identical state transitions.
8
9pub mod loopback;
10
11#[cfg(test)]
12mod tests;
13
14use std::collections::HashMap;
15use std::path::{Path, PathBuf};
16use std::sync::Mutex;
17use std::time::{Duration, Instant};
18
19use notify_debouncer_full::{
20    new_debouncer, notify::RecursiveMode, DebounceEventResult, Debouncer, RecommendedCache,
21};
22use sha2::{Digest, Sha256};
23use tokio_util::sync::CancellationToken;
24
25use crate::config::ContentSourcesConfig;
26use crate::source::ContentSourceProvider;
27use crate::storage::watchtower as store;
28use crate::storage::DbPool;
29
30// ---------------------------------------------------------------------------
31// Error type
32// ---------------------------------------------------------------------------
33
34/// Errors specific to Watchtower operations.
35#[derive(Debug, thiserror::Error)]
36pub enum WatchtowerError {
37    #[error("IO error: {0}")]
38    Io(#[from] std::io::Error),
39
40    #[error("storage error: {0}")]
41    Storage(#[from] crate::error::StorageError),
42
43    #[error("notify error: {0}")]
44    Notify(#[from] notify::Error),
45
46    #[error("config error: {0}")]
47    Config(String),
48}
49
50// ---------------------------------------------------------------------------
51// Ingest result types
52// ---------------------------------------------------------------------------
53
/// Summary of a batch ingest operation.
///
/// Derives `Clone`/`PartialEq`/`Eq` so callers and tests can copy and compare
/// summaries directly.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct IngestSummary {
    /// Number of items that were inserted or updated.
    pub ingested: u32,
    /// Number of items whose content was unchanged and therefore skipped.
    pub skipped: u32,
    /// One human-readable message per item that failed to ingest.
    pub errors: Vec<String>,
}
61
/// Parsed front-matter from a markdown file.
///
/// Derives `Clone`/`PartialEq`/`Eq` so parsed metadata can be copied and
/// compared directly.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct ParsedFrontMatter {
    /// Value of the `title` key, when it is a string.
    pub title: Option<String>,
    /// Tags from the `tags` key, joined with `,`; `None` when there are none.
    pub tags: Option<String>,
    /// The verbatim front-matter text, whenever a front-matter block was found.
    pub raw_yaml: Option<String>,
}
69
70// ---------------------------------------------------------------------------
71// Front-matter parsing
72// ---------------------------------------------------------------------------
73
74/// Parse YAML front-matter from file content.
75///
76/// Returns extracted metadata and the body text (content after front-matter).
77pub fn parse_front_matter(content: &str) -> (ParsedFrontMatter, &str) {
78    let (yaml_str, body) = loopback::split_front_matter(content);
79
80    let yaml_str = match yaml_str {
81        Some(y) => y,
82        None => return (ParsedFrontMatter::default(), content),
83    };
84
85    let parsed: Result<serde_yaml::Value, _> = serde_yaml::from_str(yaml_str);
86    match parsed {
87        Ok(serde_yaml::Value::Mapping(map)) => {
88            let title = map
89                .get(serde_yaml::Value::String("title".to_string()))
90                .and_then(|v| v.as_str())
91                .map(|s| s.to_string());
92
93            let tags = map
94                .get(serde_yaml::Value::String("tags".to_string()))
95                .map(|v| match v {
96                    serde_yaml::Value::Sequence(seq) => seq
97                        .iter()
98                        .filter_map(|item| item.as_str())
99                        .collect::<Vec<_>>()
100                        .join(","),
101                    serde_yaml::Value::String(s) => s.clone(),
102                    _ => String::new(),
103                })
104                .filter(|s| !s.is_empty());
105
106            let fm = ParsedFrontMatter {
107                title,
108                tags,
109                raw_yaml: Some(yaml_str.to_string()),
110            };
111            (fm, body)
112        }
113        _ => (
114            ParsedFrontMatter {
115                raw_yaml: Some(yaml_str.to_string()),
116                ..Default::default()
117            },
118            body,
119        ),
120    }
121}
122
123// ---------------------------------------------------------------------------
124// Pattern matching
125// ---------------------------------------------------------------------------
126
127/// Check whether a file path matches any of the given glob patterns.
128///
129/// Matches against the file name only (not the full path), so `*.md`
130/// matches `sub/dir/note.md`.
131pub fn matches_patterns(path: &Path, patterns: &[String]) -> bool {
132    let file_name = match path.file_name().and_then(|n| n.to_str()) {
133        Some(n) => n,
134        None => return false,
135    };
136
137    for pattern in patterns {
138        if let Ok(p) = glob::Pattern::new(pattern) {
139            if p.matches(file_name) {
140                return true;
141            }
142        }
143    }
144    false
145}
146
/// Render a relative path with `/` separators regardless of the host platform.
///
/// Each component is converted lossily to UTF-8, and the components are joined with `/`.
fn relative_path_string(path: &Path) -> String {
    let mut joined = String::new();
    for (index, component) in path.iter().enumerate() {
        if index > 0 {
            joined.push('/');
        }
        joined.push_str(&component.to_string_lossy());
    }
    joined
}
154
155// ---------------------------------------------------------------------------
156// Shared ingest pipeline
157// ---------------------------------------------------------------------------
158
159/// Ingest raw text content into the Watchtower pipeline.
160///
161/// This is the provider-agnostic code path that both local file reads and
162/// remote content fetches funnel through. It parses front-matter, computes
163/// a content hash, and upserts the content node in the database.
164pub async fn ingest_content(
165    pool: &DbPool,
166    source_id: i64,
167    provider_id: &str,
168    content: &str,
169    force: bool,
170) -> Result<store::UpsertResult, WatchtowerError> {
171    let (fm, body) = parse_front_matter(content);
172
173    let hash = if force {
174        let mut hasher = Sha256::new();
175        hasher.update(content.as_bytes());
176        hasher.update(
177            std::time::SystemTime::now()
178                .duration_since(std::time::UNIX_EPOCH)
179                .unwrap_or_default()
180                .as_nanos()
181                .to_le_bytes(),
182        );
183        format!("{:x}", hasher.finalize())
184    } else {
185        let mut hasher = Sha256::new();
186        hasher.update(content.as_bytes());
187        format!("{:x}", hasher.finalize())
188    };
189
190    let result = store::upsert_content_node(
191        pool,
192        source_id,
193        provider_id,
194        &hash,
195        fm.title.as_deref(),
196        body,
197        fm.raw_yaml.as_deref(),
198        fm.tags.as_deref(),
199    )
200    .await?;
201
202    Ok(result)
203}
204
205/// Ingest a single file from the local filesystem into the Watchtower pipeline.
206///
207/// Convenience wrapper that reads the file then delegates to `ingest_content`.
208pub async fn ingest_file(
209    pool: &DbPool,
210    source_id: i64,
211    base_path: &Path,
212    relative_path: &str,
213    force: bool,
214) -> Result<store::UpsertResult, WatchtowerError> {
215    let full_path = base_path.join(relative_path);
216    let content = tokio::fs::read_to_string(&full_path).await?;
217    ingest_content(pool, source_id, relative_path, &content, force).await
218}
219
220/// Ingest multiple files, collecting results into a summary.
221pub async fn ingest_files(
222    pool: &DbPool,
223    source_id: i64,
224    base_path: &Path,
225    paths: &[String],
226    force: bool,
227) -> IngestSummary {
228    let mut summary = IngestSummary::default();
229
230    for rel_path in paths {
231        match ingest_file(pool, source_id, base_path, rel_path, force).await {
232            Ok(store::UpsertResult::Inserted | store::UpsertResult::Updated) => {
233                summary.ingested += 1;
234            }
235            Ok(store::UpsertResult::Skipped) => {
236                summary.skipped += 1;
237            }
238            Err(e) => {
239                summary.errors.push(format!("{rel_path}: {e}"));
240            }
241        }
242    }
243
244    summary
245}
246
247// ---------------------------------------------------------------------------
248// Cooldown set
249// ---------------------------------------------------------------------------
250
/// Remembers paths this process wrote recently, so that their change events
/// can be ignored instead of re-ingested.
struct CooldownSet {
    /// When each path was last marked.
    entries: HashMap<PathBuf, Instant>,
    /// How long a marked path stays in cooldown.
    ttl: Duration,
}

impl CooldownSet {
    /// Create an empty set whose entries expire after `ttl`.
    fn new(ttl: Duration) -> Self {
        CooldownSet {
            ttl,
            entries: HashMap::new(),
        }
    }

    /// Record that `path` was just written by us (used by loop-back writes and tests).
    // Not called from the watch loop itself, hence the lint allowance.
    #[allow(dead_code)]
    fn mark(&mut self, path: PathBuf) {
        let written_at = Instant::now();
        self.entries.insert(path, written_at);
    }

    /// `true` when `path` was marked less than `ttl` ago.
    fn is_cooling(&self, path: &Path) -> bool {
        self.entries
            .get(path)
            .is_some_and(|written_at| written_at.elapsed() < self.ttl)
    }

    /// Drop every entry whose cooldown has expired, keeping the map bounded.
    fn cleanup(&mut self) {
        let ttl = self.ttl;
        self.entries.retain(|_, written_at| written_at.elapsed() < ttl);
    }
}
285
286// ---------------------------------------------------------------------------
287// WatchtowerLoop
288// ---------------------------------------------------------------------------
289
290/// A registered remote source: (db_source_id, provider, file_patterns, poll_interval).
291type RemoteSource = (i64, Box<dyn ContentSourceProvider>, Vec<String>, Duration);
292
293/// The Watchtower content source watcher service.
294///
295/// Watches configured source directories for file changes, debounces events,
296/// and ingests changed files into the database via the shared pipeline.
297pub struct WatchtowerLoop {
298    pool: DbPool,
299    config: ContentSourcesConfig,
300    debounce_duration: Duration,
301    fallback_scan_interval: Duration,
302    cooldown_ttl: Duration,
303}
304
305impl WatchtowerLoop {
306    /// Create a new WatchtowerLoop.
307    pub fn new(pool: DbPool, config: ContentSourcesConfig) -> Self {
308        Self {
309            pool,
310            config,
311            debounce_duration: Duration::from_secs(2),
312            fallback_scan_interval: Duration::from_secs(300), // 5 minutes
313            cooldown_ttl: Duration::from_secs(5),
314        }
315    }
316
317    /// Run the watchtower loop until the cancellation token is triggered.
318    ///
319    /// Registers both local filesystem and remote sources, then runs:
320    /// - `notify` watcher + fallback polling for local sources
321    /// - interval-based polling for remote sources (e.g. Google Drive)
322    pub async fn run(&self, cancel: CancellationToken) {
323        // Split config into local (watchable) and remote (pollable) sources.
324        let local_sources: Vec<_> = self
325            .config
326            .sources
327            .iter()
328            .filter(|s| s.source_type == "local_fs" && s.watch && s.path.is_some())
329            .collect();
330
331        let remote_sources: Vec<_> = self
332            .config
333            .sources
334            .iter()
335            .filter(|s| s.source_type == "google_drive" && s.folder_id.is_some())
336            .collect();
337
338        if local_sources.is_empty() && remote_sources.is_empty() {
339            tracing::info!("Watchtower: no watch sources configured, exiting");
340            return;
341        }
342
343        // Register local source contexts in DB.
344        let mut source_map: Vec<(i64, PathBuf, Vec<String>)> = Vec::new();
345        for src in &local_sources {
346            let path_str = src.path.as_deref().unwrap();
347            let expanded = PathBuf::from(crate::storage::expand_tilde(path_str));
348
349            let config_json = serde_json::json!({
350                "path": path_str,
351                "file_patterns": src.file_patterns,
352                "loop_back_enabled": src.loop_back_enabled,
353            })
354            .to_string();
355
356            match store::ensure_local_fs_source(&self.pool, path_str, &config_json).await {
357                Ok(source_id) => {
358                    source_map.push((source_id, expanded, src.file_patterns.clone()));
359                }
360                Err(e) => {
361                    tracing::error!(path = path_str, error = %e, "Failed to register source context");
362                }
363            }
364        }
365
366        // Register remote source contexts and build provider instances.
367        let mut remote_map: Vec<RemoteSource> = Vec::new();
368        for src in &remote_sources {
369            let folder_id = src.folder_id.as_deref().unwrap();
370            let config_json = serde_json::json!({
371                "folder_id": folder_id,
372                "file_patterns": src.file_patterns,
373                "service_account_key": src.service_account_key,
374            })
375            .to_string();
376
377            match store::ensure_google_drive_source(&self.pool, folder_id, &config_json).await {
378                Ok(source_id) => {
379                    let key_path = src.service_account_key.clone().unwrap_or_default();
380                    let provider = crate::source::google_drive::GoogleDriveProvider::new(
381                        folder_id.to_string(),
382                        key_path,
383                    );
384                    let interval = Duration::from_secs(src.poll_interval_seconds.unwrap_or(300));
385                    remote_map.push((
386                        source_id,
387                        Box::new(provider),
388                        src.file_patterns.clone(),
389                        interval,
390                    ));
391                }
392                Err(e) => {
393                    tracing::error!(
394                        folder_id = folder_id,
395                        error = %e,
396                        "Failed to register Google Drive source"
397                    );
398                }
399            }
400        }
401
402        if source_map.is_empty() && remote_map.is_empty() {
403            tracing::warn!("Watchtower: no sources registered, exiting");
404            return;
405        }
406
407        // Initial scan of all local directories.
408        for (source_id, base_path, patterns) in &source_map {
409            if let Err(e) = self.scan_directory(*source_id, base_path, patterns).await {
410                tracing::error!(
411                    path = %base_path.display(),
412                    error = %e,
413                    "Initial scan failed"
414                );
415            }
416        }
417
418        // Initial poll of remote sources.
419        if !remote_map.is_empty() {
420            self.poll_remote_sources(&remote_map).await;
421        }
422
423        // If there are no local sources, only run remote polling.
424        if source_map.is_empty() {
425            self.remote_only_loop(&remote_map, cancel).await;
426            return;
427        }
428
429        // Bridge notify's sync callback to an async-friendly tokio channel.
430        let (async_tx, mut async_rx) = tokio::sync::mpsc::channel::<DebounceEventResult>(256);
431
432        let handler = move |result: DebounceEventResult| {
433            let _ = async_tx.blocking_send(result);
434        };
435
436        let debouncer_result = new_debouncer(self.debounce_duration, None, handler);
437        let mut debouncer: Debouncer<notify::RecommendedWatcher, RecommendedCache> =
438            match debouncer_result {
439                Ok(d) => d,
440                Err(e) => {
441                    tracing::error!(error = %e, "Failed to create filesystem watcher, falling back to polling");
442                    self.polling_loop(&source_map, cancel).await;
443                    return;
444                }
445            };
446
447        // Register directories with the watcher.
448        for (_, base_path, _) in &source_map {
449            if let Err(e) = debouncer.watch(base_path, RecursiveMode::Recursive) {
450                tracing::error!(
451                    path = %base_path.display(),
452                    error = %e,
453                    "Failed to watch directory"
454                );
455            }
456        }
457
458        tracing::info!(
459            local_sources = source_map.len(),
460            remote_sources = remote_map.len(),
461            "Watchtower watching for changes"
462        );
463
464        let cooldown = Mutex::new(CooldownSet::new(self.cooldown_ttl));
465
466        // Main event loop.
467        let mut fallback_timer = tokio::time::interval(self.fallback_scan_interval);
468        fallback_timer.tick().await; // Consume the immediate first tick.
469
470        // Remote poll interval (use minimum configured or fallback default).
471        let remote_interval = remote_map
472            .iter()
473            .map(|(_, _, _, d)| *d)
474            .min()
475            .unwrap_or(self.fallback_scan_interval);
476        let mut remote_timer = tokio::time::interval(remote_interval);
477        remote_timer.tick().await; // Consume the immediate first tick.
478
479        loop {
480            tokio::select! {
481                () = cancel.cancelled() => {
482                    tracing::info!("Watchtower: cancellation received, shutting down");
483                    break;
484                }
485                _ = fallback_timer.tick() => {
486                    // Periodic fallback scan to catch any missed events.
487                    for (source_id, base_path, patterns) in &source_map {
488                        if let Err(e) = self.scan_directory(*source_id, base_path, patterns).await {
489                            tracing::warn!(
490                                path = %base_path.display(),
491                                error = %e,
492                                "Fallback scan failed"
493                            );
494                        }
495                    }
496                    if let Ok(mut cd) = cooldown.lock() {
497                        cd.cleanup();
498                    }
499                }
500                _ = remote_timer.tick(), if !remote_map.is_empty() => {
501                    self.poll_remote_sources(&remote_map).await;
502                }
503                result = async_rx.recv() => {
504                    match result {
505                        Some(Ok(events)) => {
506                            for event in events {
507                                for path in &event.paths {
508                                    self.handle_event(path, &source_map, &cooldown).await;
509                                }
510                            }
511                        }
512                        Some(Err(errs)) => {
513                            for e in errs {
514                                tracing::warn!(error = %e, "Watcher error");
515                            }
516                        }
517                        None => {
518                            tracing::warn!("Watcher event channel closed");
519                            break;
520                        }
521                    }
522                }
523            }
524        }
525
526        // Drop the debouncer to stop watching.
527        drop(debouncer);
528        tracing::info!("Watchtower shut down");
529    }
530
531    /// Handle a single filesystem event for a changed path.
532    async fn handle_event(
533        &self,
534        path: &Path,
535        source_map: &[(i64, PathBuf, Vec<String>)],
536        cooldown: &Mutex<CooldownSet>,
537    ) {
538        // Check cooldown.
539        if let Ok(cd) = cooldown.lock() {
540            if cd.is_cooling(path) {
541                tracing::debug!(path = %path.display(), "Skipping cooldown path");
542                return;
543            }
544        }
545
546        // Find matching source.
547        for (source_id, base_path, patterns) in source_map {
548            if path.starts_with(base_path) {
549                // Check pattern match.
550                if !matches_patterns(path, patterns) {
551                    return;
552                }
553
554                // Compute relative path.
555                let rel = match path.strip_prefix(base_path) {
556                    Ok(r) => relative_path_string(r),
557                    Err(_) => return,
558                };
559
560                match ingest_file(&self.pool, *source_id, base_path, &rel, false).await {
561                    Ok(result) => {
562                        tracing::debug!(
563                            path = %rel,
564                            result = ?result,
565                            "Watchtower ingested file"
566                        );
567                    }
568                    Err(e) => {
569                        tracing::warn!(
570                            path = %rel,
571                            error = %e,
572                            "Watchtower ingest failed"
573                        );
574                    }
575                }
576                return;
577            }
578        }
579    }
580
581    /// Scan a directory for all matching files and ingest them.
582    async fn scan_directory(
583        &self,
584        source_id: i64,
585        base_path: &Path,
586        patterns: &[String],
587    ) -> Result<IngestSummary, WatchtowerError> {
588        let mut rel_paths = Vec::new();
589        Self::walk_directory(base_path, base_path, patterns, &mut rel_paths)?;
590
591        let summary = ingest_files(&self.pool, source_id, base_path, &rel_paths, false).await;
592
593        tracing::debug!(
594            path = %base_path.display(),
595            ingested = summary.ingested,
596            skipped = summary.skipped,
597            errors = summary.errors.len(),
598            "Directory scan complete"
599        );
600
601        // Update sync cursor.
602        let cursor = chrono::Utc::now().to_rfc3339();
603        if let Err(e) = store::update_sync_cursor(&self.pool, source_id, &cursor).await {
604            tracing::warn!(error = %e, "Failed to update sync cursor");
605        }
606
607        Ok(summary)
608    }
609
610    /// Recursively walk a directory, collecting relative paths of matching files.
611    fn walk_directory(
612        base: &Path,
613        current: &Path,
614        patterns: &[String],
615        out: &mut Vec<String>,
616    ) -> Result<(), WatchtowerError> {
617        let entries = std::fs::read_dir(current)?;
618        for entry in entries {
619            let entry = entry?;
620            let file_type = entry.file_type()?;
621            let path = entry.path();
622
623            if file_type.is_dir() {
624                // Skip hidden directories.
625                if let Some(name) = path.file_name().and_then(|n| n.to_str()) {
626                    if name.starts_with('.') {
627                        continue;
628                    }
629                }
630                Self::walk_directory(base, &path, patterns, out)?;
631            } else if file_type.is_file() && matches_patterns(&path, patterns) {
632                if let Ok(rel) = path.strip_prefix(base) {
633                    out.push(relative_path_string(rel));
634                }
635            }
636        }
637        Ok(())
638    }
639
640    /// Poll all remote sources for changes, ingest new/updated content.
641    async fn poll_remote_sources(&self, remote_sources: &[RemoteSource]) {
642        for (source_id, provider, patterns, _interval) in remote_sources {
643            let cursor = match store::get_source_context(&self.pool, *source_id).await {
644                Ok(Some(ctx)) => ctx.sync_cursor,
645                Ok(None) => None,
646                Err(e) => {
647                    tracing::warn!(source_id, error = %e, "Failed to get source context");
648                    continue;
649                }
650            };
651
652            match provider.scan_for_changes(cursor.as_deref(), patterns).await {
653                Ok(files) => {
654                    let mut ingested = 0u32;
655                    let mut skipped = 0u32;
656                    for file in &files {
657                        match provider.read_content(&file.provider_id).await {
658                            Ok(content) => {
659                                match ingest_content(
660                                    &self.pool,
661                                    *source_id,
662                                    &file.provider_id,
663                                    &content,
664                                    false,
665                                )
666                                .await
667                                {
668                                    Ok(
669                                        store::UpsertResult::Inserted
670                                        | store::UpsertResult::Updated,
671                                    ) => {
672                                        ingested += 1;
673                                    }
674                                    Ok(store::UpsertResult::Skipped) => {
675                                        skipped += 1;
676                                    }
677                                    Err(e) => {
678                                        tracing::warn!(
679                                            provider_id = %file.provider_id,
680                                            error = %e,
681                                            "Remote ingest failed"
682                                        );
683                                    }
684                                }
685                            }
686                            Err(e) => {
687                                tracing::warn!(
688                                    provider_id = %file.provider_id,
689                                    error = %e,
690                                    "Failed to read remote content"
691                                );
692                            }
693                        }
694                    }
695
696                    tracing::debug!(
697                        source_type = provider.source_type(),
698                        ingested,
699                        skipped,
700                        total = files.len(),
701                        "Remote poll complete"
702                    );
703
704                    // Update sync cursor.
705                    let new_cursor = chrono::Utc::now().to_rfc3339();
706                    if let Err(e) =
707                        store::update_sync_cursor(&self.pool, *source_id, &new_cursor).await
708                    {
709                        tracing::warn!(error = %e, "Failed to update remote sync cursor");
710                    }
711                }
712                Err(e) => {
713                    tracing::warn!(
714                        source_type = provider.source_type(),
715                        error = %e,
716                        "Remote scan failed"
717                    );
718                    let _ = store::update_source_status(
719                        &self.pool,
720                        *source_id,
721                        "error",
722                        Some(&e.to_string()),
723                    )
724                    .await;
725                }
726            }
727        }
728    }
729
730    /// Loop for when only remote sources are configured (no local watchers).
731    async fn remote_only_loop(&self, remote_map: &[RemoteSource], cancel: CancellationToken) {
732        let interval_dur = remote_map
733            .iter()
734            .map(|(_, _, _, d)| *d)
735            .min()
736            .unwrap_or(self.fallback_scan_interval);
737        let mut interval = tokio::time::interval(interval_dur);
738        interval.tick().await;
739
740        loop {
741            tokio::select! {
742                () = cancel.cancelled() => {
743                    tracing::info!("Watchtower remote-only loop cancelled");
744                    break;
745                }
746                _ = interval.tick() => {
747                    self.poll_remote_sources(remote_map).await;
748                }
749            }
750        }
751    }
752
753    /// Polling-only fallback loop when the notify watcher fails to initialize.
754    async fn polling_loop(
755        &self,
756        source_map: &[(i64, PathBuf, Vec<String>)],
757        cancel: CancellationToken,
758    ) {
759        let mut interval = tokio::time::interval(self.fallback_scan_interval);
760        interval.tick().await; // Consume immediate tick.
761
762        loop {
763            tokio::select! {
764                () = cancel.cancelled() => {
765                    tracing::info!("Watchtower polling loop cancelled");
766                    break;
767                }
768                _ = interval.tick() => {
769                    for (source_id, base_path, patterns) in source_map {
770                        if let Err(e) = self.scan_directory(*source_id, base_path, patterns).await {
771                            tracing::warn!(
772                                path = %base_path.display(),
773                                error = %e,
774                                "Polling scan failed"
775                            );
776                        }
777                    }
778                }
779            }
780        }
781    }
782}