sqlite_watcher/
watcher.rs

1use fixedbitset::FixedBitSet;
2use flume::{Receiver, Sender, TryRecvError};
3use parking_lot::RwLock;
4use slotmap::{SlotMap, new_key_type};
5use std::collections::btree_map::Entry;
6use std::collections::{BTreeMap, BTreeSet};
7use std::sync::atomic::{AtomicU64, Ordering};
8use std::sync::{Arc, Weak};
9use tracing::{debug, error};
10
11new_key_type! {
12    /// Handle for a [`TableObserver`].
13    pub struct TableObserverHandle;
14}
15
16/// Utility type that removes an observer from a [`Watcher`] when the type is dropped.
17///
18/// The [`TableObserver`] will be removed with [`Watcher::remove_observer_deferred()`].
19#[derive(Debug, Clone)]
20pub struct DropRemoveTableObserverHandle {
21    watcher: Weak<Watcher>,
22    handle: TableObserverHandle,
23}
24
25impl DropRemoveTableObserverHandle {
26    fn new(handle: TableObserverHandle, watcher: &Arc<Watcher>) -> Self {
27        Self {
28            watcher: Arc::downgrade(watcher),
29            handle,
30        }
31    }
32
33    /// Returns the handle of the table observer.
34    #[must_use]
35    pub fn handle(&self) -> TableObserverHandle {
36        self.handle
37    }
38
39    /// Unsubscribe the observer immediately.
40    ///
41    /// This can be safely called multiple times, the observer
42    /// is only unsubscribed once.
43    ///
44    /// # Errors
45    ///
46    /// Returns error if we can't communicate with Watcher or the
47    /// removal of the observer failed.
48    pub fn unsubscribe(&self) -> Result<(), Error> {
49        if let Some(watcher) = self.watcher.upgrade() {
50            watcher.remove_observer(self.handle)
51        } else {
52            Err(Error::Command)
53        }
54    }
55}
56
57impl Drop for DropRemoveTableObserverHandle {
58    fn drop(&mut self) {
59        if let Some(watcher) = self.watcher.upgrade() {
60            if watcher.remove_observer_deferred(self.handle).is_err() {
61                error!("Failed to remove watcher from observer on drop");
62            }
63        }
64    }
65}
66
/// An observer that wants to be notified when any table of a given set is modified.
pub trait TableObserver: Send + Sync {
    /// Names of the tables this observer is interested in.
    ///
    /// The [`Watcher`] queries this once, when the observer is registered.
    fn tables(&self) -> Vec<String>;

    /// Invoked by the [`Watcher`] when one or more of the tables returned by
    /// [`Self::tables()`] were modified.
    ///
    /// `tables` contains the set of tables that were modified. It is the combined set of
    /// changes being published, so it may also contain tables this observer did not ask for.
    ///
    /// Observers are notified one after the other on the watcher's background thread. Keep
    /// the implementation as short as possible so that other observers are not delayed or
    /// blocked.
    fn on_tables_changed(&self, tables: &BTreeSet<String>);
}
81
82/// The [`Watcher`] is the hub where updates are published regarding tables that updated when
83/// observing a connection.
84///
85/// All changes are published to a background thread which then notifies the respective
86/// [`TableObserver`]s.
87///
88/// # Observing Tables
89///
90/// To be notified of changes, register an observer with [`Watcher::add_observer`].
91///
92/// The [`Watcher`] by itself does not automatically watch all tables. The observed tables
93/// are driven by the tables defined by each [`TableObserver`].
94///
95/// A table can be observed by many [`TableObserver`]. When the last [`TableObserver`] is removed
96/// for a given table, that table stops being tracked.
97///
98/// # Update Propagation
99///
100/// Every time a [`TableObserver`] is added or removed, the list of tracked tables is updated and
101/// a counter is bumped. These changes are propagated to [State] instances when they sync their
102/// state [`State::sync_tables()`](crate::connection::State::sync_tables).
103///
104/// Due to the nature of concurrent operations, it is possible that a connection on different
105/// thread will miss the changes applied from adding/removing an observer on the current thread. On
106/// the next sync this will be rectified.
107///
108/// If both operation happen on the same thread, everything will work as expected.
109///
110/// # Notifications
111///
112/// To notify the [`Watcher`] of changed tables, an instance of either [Connection] or
113/// [State] needs to be used. Check each type for more information on how to use
114/// it correctly.
115///
116/// # Remarks
117///
118/// The [`Watcher`] currently maintains a list of observed tables that is never pruned. It will
119/// keep growing with every new table that is observed. If you have af fixed set of tables that
120/// you watch on a regular basis this is not an issue. If you have a dynamic list of tables
121/// deleted tables are currently not removed. To be addressed in the future.
122///
123///
124/// [Connection]: `crate::connection::Connection`
125/// [State]: `crate::connection::State`
126pub struct Watcher {
127    tables: RwLock<ObservedTables>,
128    tables_version: AtomicU64,
129    sender: Sender<Command>,
130}
131
/// Capacity of the bounded command channel between a [`Watcher`] and its background thread.
const WATCHER_CHANNEL_CAPACITY: usize = 24;
133
134impl Watcher {
135    /// Create a new instance of an in process tracker service.
136    ///
137    /// # Errors
138    /// Returns error if the worker thread fails to spawn.
139    pub fn new() -> Result<Arc<Self>, Error> {
140        let (sender, receiver) = flume::bounded(WATCHER_CHANNEL_CAPACITY);
141        let watcher = Arc::new(Self {
142            tables: RwLock::new(ObservedTables::new()),
143            tables_version: AtomicU64::new(0),
144            sender,
145        });
146
147        let watcher_cloned = Arc::clone(&watcher);
148        std::thread::Builder::new()
149            .name("sqlite_watcher".into())
150            .spawn(move || {
151                Watcher::background_loop(receiver, &watcher_cloned);
152            })
153            .map_err(Error::Thread)?;
154
155        Ok(watcher)
156    }
157
158    /// Register a new observer with a list of interested tables.
159    ///
160    /// This function returns a [`TableObserverHandle`] which can later be used to
161    /// remove the current observer.
162    ///
163    /// # Errors
164    ///
165    /// Returns error if the command which adds the observer to the background thread
166    /// could not be sent or the handle could not be retrieved.
167    pub fn add_observer(
168        &self,
169        observer: Box<dyn TableObserver>,
170    ) -> Result<TableObserverHandle, Error> {
171        let (sender, receiver) = oneshot::channel();
172        if self
173            .sender
174            .send(Command::AddObserver(observer, sender))
175            .is_err()
176        {
177            error!("Failed to send add observer command");
178            return Err(Error::Command);
179        }
180
181        let Ok(handle) = receiver.recv() else {
182            error!("Failed to receive handle for new observer");
183            return Err(Error::Command);
184        };
185
186        Ok(handle)
187    }
188
189    /// Same as [`Self::add_observer`], but returns a handle that removes the observer
190    /// from this [`Watcher`] on drop.
191    ///
192    ///
193    /// # Errors
194    ///
195    /// See [`Self::add_observer`] for more details.
196    pub fn add_observer_with_drop_remove(
197        self: &Arc<Self>,
198        observer: Box<dyn TableObserver>,
199    ) -> Result<DropRemoveTableObserverHandle, Error> {
200        let handle = self.add_observer(observer)?;
201
202        Ok(DropRemoveTableObserverHandle::new(handle, self))
203    }
204
205    /// Remove an observer via its `handle` without waiting for the operation to complete.
206    ///
207    /// The removal of observers is deferred to the background thread and will
208    /// be executed as soon as possible.
209    ///
210    /// If you wish to wait for an observer to finish being removed from the list,
211    /// you should use [`Self::remove_observer()`]
212    ///
213    /// # Errors
214    ///
215    /// Returns error if the command to remove the observer could not be sent.
216    pub fn remove_observer_deferred(&self, handle: TableObserverHandle) -> Result<(), Error> {
217        self.sender
218            .send(Command::RemoveObserverDeferred(handle))
219            .map_err(|_| Error::Command)
220    }
221
222    /// Remove an observer via its `handle` and wait for it to be removed.
223    ///
224    /// If you wish do not wish to wait for an observer to finish being removed from the list,
225    /// you should use [`Self::remove_observer_deferred()`]
226    ///
227    /// # Errors
228    ///
229    /// Returns error if the command to remove the observer could not be sent or the reply
230    /// could not be received.
231    pub fn remove_observer(&self, handle: TableObserverHandle) -> Result<(), Error> {
232        let (sender, receiver) = oneshot::channel();
233        self.sender
234            .send(Command::RemoveObserver(handle, sender))
235            .map_err(|_| Error::Command)?;
236
237        receiver.recv().map_err(|_| {
238            error!("Failed to receive reply for remove observer command");
239            Error::Command
240        })
241    }
242
243    pub(crate) fn publish_changes(&self, table_ids: FixedBitSet) {
244        if self
245            .sender
246            .send(Command::PublishChanges(table_ids))
247            .is_err()
248        {
249            error!("Watcher could not communicate with background thread");
250        }
251    }
252
253    pub(crate) async fn publish_changes_async(&self, table_ids: FixedBitSet) {
254        if self
255            .sender
256            .send_async(Command::PublishChanges(table_ids))
257            .await
258            .is_err()
259        {
260            error!("Watcher could not communicate with background thread");
261        }
262    }
263
264    #[cfg(test)]
265    pub(crate) fn get_table_id(&self, table: &str) -> Option<usize> {
266        self.with_tables(|tables| tables.table_ids.get(table).copied())
267    }
268
269    fn with_tables_mut(&self, f: impl (FnOnce(&mut ObservedTables))) {
270        let mut accessor = self.tables.write();
271        // Save counter to check for significant changes
272        let prev_counter = accessor.counter;
273
274        (f)(&mut accessor);
275
276        // Significant changes were made.
277        let cur_counter = accessor.counter;
278        if prev_counter != cur_counter {
279            self.tables_version.fetch_add(1, Ordering::Release);
280        }
281    }
282
283    fn with_tables<R>(&self, f: impl (FnOnce(&ObservedTables) -> R)) -> R {
284        let accessor = self.tables.read();
285        (f)(&accessor)
286    }
287
288    /// The current version of the tracked tables state.
289    pub(crate) fn tables_version(&self) -> u64 {
290        self.tables_version.load(Ordering::Acquire)
291    }
292
293    /// Return the list of observed tables at this point in time.
294    pub fn observed_tables(&self) -> Vec<String> {
295        self.with_tables(|t| t.tables.clone())
296    }
297
298    pub(crate) fn calculate_sync_changes(
299        &self,
300        connection_state: &FixedBitSet,
301    ) -> (FixedBitSet, Vec<ObservedTableOp>) {
302        self.with_tables(|t| t.calculate_changes(connection_state))
303    }
304    #[tracing::instrument(level= tracing::Level::TRACE, skip(receiver, watcher))]
305    fn background_loop(receiver: Receiver<Command>, watcher: &Watcher) {
306        let mut worker = WatcherWorker::new();
307
308        loop {
309            debug_assert!(worker.add_observers.is_empty());
310            debug_assert!(worker.remove_observers.is_empty());
311            debug_assert!(worker.publish_changes.is_empty());
312
313            let Ok(command) = receiver.recv() else {
314                return;
315            };
316
317            worker.unpack_command(command);
318
319            // try to read more commands if any are queued.
320            loop {
321                match receiver.try_recv() {
322                    Ok(command) => {
323                        worker.unpack_command(command);
324                    }
325                    Err(e) => match e {
326                        TryRecvError::Empty => {
327                            break;
328                        }
329                        TryRecvError::Disconnected => {
330                            return;
331                        }
332                    },
333                }
334            }
335
336            worker.tick(watcher);
337        }
338    }
339}
340
341/// Background watcher worker which responds to [`Command`]s;
342struct WatcherWorker {
343    observers: SlotMap<TableObserverHandle, ActiveObserver>,
344    updated_tables: BTreeSet<String>,
345    remove_observers: Vec<(TableObserverHandle, Option<oneshot::Sender<()>>)>,
346    add_observers: Vec<(Box<dyn TableObserver>, oneshot::Sender<TableObserverHandle>)>,
347    publish_changes: Vec<FixedBitSet>,
348}
349
350impl WatcherWorker {
351    fn new() -> Self {
352        Self {
353            observers: SlotMap::with_capacity_and_key(4),
354            updated_tables: BTreeSet::default(),
355            remove_observers: vec![],
356            add_observers: vec![],
357            publish_changes: vec![],
358        }
359    }
360    fn unpack_command(&mut self, command: Command) {
361        match command {
362            Command::AddObserver(o, r) => self.add_observers.push((o, r)),
363            Command::RemoveObserver(h, r) => self.remove_observers.push((h, Some(r))),
364            Command::RemoveObserverDeferred(h) => {
365                self.remove_observers.push((h, None));
366            }
367            Command::PublishChanges(fixedbitset) => {
368                self.publish_changes.push(fixedbitset);
369            }
370        }
371    }
372
373    fn tick(&mut self, watcher: &Watcher) {
374        // Remove old observers,
375        for (handle, reply) in self.remove_observers.drain(..) {
376            if let Some(observer) = self.observers.remove(handle) {
377                watcher.with_tables_mut(|tables| {
378                    tables.untrack_tables(observer.tables.iter());
379                });
380            }
381
382            if let Some(reply) = reply {
383                if reply.send(()).is_err() {
384                    error!("Failed to send reply for observer removal");
385                }
386            }
387        }
388
389        // Add new observers
390        for (observer, reply) in self.add_observers.drain(..) {
391            let active_observer = ActiveObserver::new(observer);
392            watcher.with_tables_mut(|tables| {
393                tables.track_tables(active_observer.tables.iter().cloned());
394            });
395            let handle = self.observers.insert(active_observer);
396            if reply.send(handle).is_err() {
397                error!("Failed to send reply back to caller, new observer will not be added");
398                self.observers.remove(handle);
399            }
400        }
401
402        // Combine and publish changes.
403        self.updated_tables.clear();
404
405        for table_ids in self.publish_changes.drain(..) {
406            if table_ids.is_clear() {
407                continue;
408            }
409
410            // resolve table names.
411            watcher.with_tables(|observer_tables| {
412                for idx in table_ids.ones() {
413                    // Safeguard against some invalid index, just in case.
414                    if let Some(name) = observer_tables.tables.get(idx).cloned() {
415                        self.updated_tables.insert(name);
416                    }
417                }
418            });
419        }
420
421        if !self.updated_tables.is_empty() {
422            debug!("Changes detected on tables: {:?}", self.updated_tables);
423            // publish changes;
424            {
425                for (_, active_observer) in &self.observers {
426                    if self
427                        .updated_tables
428                        .intersection(&active_observer.tables)
429                        .next()
430                        .is_some()
431                    {
432                        active_observer
433                            .observer
434                            .on_tables_changed(&self.updated_tables);
435                    }
436                }
437            }
438        }
439    }
440}
441
442struct ActiveObserver {
443    observer: Box<dyn TableObserver>,
444    tables: BTreeSet<String>,
445}
446
447impl ActiveObserver {
448    fn new(observer: Box<dyn TableObserver>) -> ActiveObserver {
449        let tables = BTreeSet::from_iter(observer.tables());
450        Self { observer, tables }
451    }
452}
453
454/// Commands send to the background thread.
455enum Command {
456    /// Add a new observer
457    AddObserver(Box<dyn TableObserver>, oneshot::Sender<TableObserverHandle>),
458    /// Remove an observer
459    RemoveObserverDeferred(TableObserverHandle),
460    /// Remove an observer and wait for the operation to finish.
461    RemoveObserver(TableObserverHandle, oneshot::Sender<()>),
462    /// Publish new changes
463    PublishChanges(FixedBitSet),
464}
465
466#[derive(Debug, thiserror::Error)]
467pub enum Error {
468    #[error("Failed to send or receive command to/from background thread")]
469    Command,
470    #[error("Failed to create background thread: {0}")]
471    Thread(std::io::Error),
472}
473
/// A single change a connection has to apply to get in sync with the observed tables.
///
/// Both variants carry the table name and the index/id assigned to that table.
#[derive(Debug, Clone, Eq, PartialEq)]
pub(crate) enum ObservedTableOp {
    /// The table has observers but is not tracked by the connection yet.
    Add(String, usize),
    /// The table is tracked by the connection but no longer has any observers.
    Remove(String, usize),
}
479
/// Registry of all the observed tables.
///
/// Every table is assigned a unique value (its index), which is then propagated to all the
/// trackers when they sync their state. `tables` and `num_observers` are indexed by that
/// value and always have the same length.
struct ObservedTables {
    /// Maps a table name to its index/id.
    table_ids: BTreeMap<String, usize>,
    /// Table names, indexed by index/id.
    tables: Vec<String>,
    /// Number of active observers of each table, indexed by index/id.
    num_observers: Vec<usize>,
    /// Version counter, bumped when a table starts or stops having observers.
    counter: u64,
}
494
495impl ObservedTables {
496    fn new() -> Self {
497        Self {
498            table_ids: BTreeMap::new(),
499            tables: Vec::with_capacity(8),
500            num_observers: Vec::with_capacity(8),
501            counter: 0,
502        }
503    }
504
505    /// Add the `tables` to the list of tables that need to be observed.
506    fn track_tables(&mut self, tables: impl Iterator<Item = String>) {
507        let mut requires_version_bump = false;
508        for table in tables {
509            match self.table_ids.entry(table.clone()) {
510                Entry::Vacant(v) => {
511                    let id = self.num_observers.len();
512                    self.tables.push(table.clone());
513                    self.num_observers.push(1);
514                    v.insert(id);
515                    requires_version_bump = true;
516                }
517                Entry::Occupied(o) => {
518                    let id = *o.get();
519                    let current = self.num_observers[id];
520                    if current == 0 {
521                        // We should start following this table again. If it is not
522                        // 0, we are already observing it.
523                        requires_version_bump = true;
524                    }
525                    self.num_observers[*o.get()] = current + 1;
526                }
527            }
528        }
529
530        if requires_version_bump {
531            self.counter = self.counter.saturating_add(1);
532        }
533    }
534
535    /// Remove the `tables` from the list of tables that need to be observed.
536    fn untrack_tables<'i>(&mut self, tables: impl Iterator<Item = &'i String>) {
537        let mut requires_version_bump = false;
538        for table in tables {
539            if let Some(id) = self.table_ids.get(table) {
540                // We never remove the table entirely, but we need to stop tracking
541                // once all observers have been removed.
542                self.num_observers[*id] -= 1;
543                if self.num_observers[*id] == 0 {
544                    requires_version_bump = true;
545                }
546            }
547        }
548
549        if requires_version_bump {
550            self.counter = self.counter.saturating_add(1);
551        }
552    }
553
554    /// Calculate the which tables should be added or removed from a `connection_state` to
555    /// make sure it is synced up with the current list.
556    ///
557    /// This will return the new updated state as well as the list of triggers that should be
558    /// created or removed.
559    fn calculate_changes(
560        &self,
561        connection_state: &FixedBitSet,
562    ) -> (FixedBitSet, Vec<ObservedTableOp>) {
563        let mut result = connection_state.clone();
564        result.grow(self.tables.len());
565        let mut changes = Vec::with_capacity(self.tables.len());
566        let min_index = connection_state.len().min(self.tables.len());
567        for i in 0..min_index {
568            let is_tracking = connection_state[i];
569            let num_observers = self.num_observers[i];
570
571            if is_tracking && num_observers == 0 {
572                changes.push(ObservedTableOp::Remove(self.tables[i].clone(), i));
573                result.set(i, false);
574            } else if !is_tracking && num_observers != 0 {
575                changes.push(ObservedTableOp::Add(self.tables[i].clone(), i));
576                result.set(i, true);
577            }
578        }
579
580        // Process any new tables that might be missing.
581        for i in min_index..self.num_observers.len() {
582            if self.num_observers[i] != 0 {
583                changes.push(ObservedTableOp::Add(self.tables[i].clone(), i));
584                result.set(i, true);
585            }
586        }
587
588        (result, changes)
589    }
590}
591
#[cfg(test)]
pub(crate) mod tests {
    use crate::watcher::{ObservedTables, TableObserver, Watcher};
    use std::collections::BTreeSet;
    use std::sync::atomic::Ordering;

    /// Observer with a fixed list of tables which ignores every notification.
    pub struct TestObserver {
        tables: Vec<String>,
    }

    impl TableObserver for TestObserver {
        fn tables(&self) -> Vec<String> {
            self.tables.clone()
        }

        fn on_tables_changed(&self, _: &BTreeSet<String>) {}
    }

    /// Create a boxed [`TestObserver`] for the given table names.
    pub(crate) fn new_test_observer(
        tables: impl IntoIterator<Item = &'static str>,
    ) -> Box<dyn TableObserver + Send + 'static> {
        let tables = tables.into_iter().map(String::from).collect();
        Box::new(TestObserver { tables })
    }

    /// Assert that the table `name` currently has `expected` observers.
    fn check_table_counter(tables: &ObservedTables, name: &str, expected: usize) {
        let idx = tables
            .table_ids
            .get(name)
            .copied()
            .expect("could not find table by name");
        assert_eq!(tables.num_observers[idx], expected);
    }

    /// Assert the number of known tables, the observer count of each listed table and the
    /// current tables version of `service`.
    fn assert_state(
        service: &Watcher,
        num_tables: usize,
        counters: &[(&str, usize)],
        version: u64,
    ) {
        service.with_tables(|tables| {
            assert_eq!(tables.num_observers.len(), num_tables);
            for &(name, expected) in counters {
                check_table_counter(tables, name, expected);
            }
        });
        assert_eq!(version, service.tables_version.load(Ordering::Relaxed));
    }

    #[test]
    fn test_observer_tables_version_counter() {
        let service = Watcher::new().unwrap();
        let mut version = service.tables_version.load(Ordering::Relaxed);

        // Adding new observer triggers change.
        let observer_1_id = service
            .add_observer(new_test_observer(["foo", "bar"]))
            .unwrap();
        version += 1;
        assert_state(&service, 2, &[("foo", 1), ("bar", 1)], version);

        // Adding an observer for only bar does not change version counter.
        let observer_2_id = service.add_observer(new_test_observer(["bar"])).unwrap();
        assert_state(&service, 2, &[("foo", 1), ("bar", 2)], version);

        // Adding this observer causes another change.
        let observer_3_id = service
            .add_observer(new_test_observer(["bar", "omega"]))
            .unwrap();
        version += 1;
        assert_state(&service, 3, &[("foo", 1), ("omega", 1), ("bar", 3)], version);

        // Remove observer 2 causes no version change.
        service.remove_observer(observer_2_id).unwrap();
        assert_state(&service, 3, &[("foo", 1), ("bar", 2), ("omega", 1)], version);

        // Remove observer 3 causes version change.
        service.remove_observer(observer_3_id).unwrap();
        version += 1;
        assert_state(&service, 3, &[("foo", 1), ("bar", 1), ("omega", 0)], version);

        // Remove observer 1 causes version change.
        service.remove_observer(observer_1_id).unwrap();
        version += 1;
        assert_state(&service, 3, &[("foo", 0), ("bar", 0), ("omega", 0)], version);
    }
}