1use fixedbitset::FixedBitSet;
2use flume::{Receiver, Sender, TryRecvError};
3use parking_lot::RwLock;
4use slotmap::{SlotMap, new_key_type};
5use std::collections::btree_map::Entry;
6use std::collections::{BTreeMap, BTreeSet};
7use std::sync::atomic::{AtomicU64, Ordering};
8use std::sync::{Arc, Weak};
9use tracing::{debug, error};
10
new_key_type! {
    /// Key identifying an observer registered with a [`Watcher`].
    ///
    /// Returned by [`Watcher::add_observer`] and accepted by [`Watcher::remove_observer`] and
    /// [`Watcher::remove_observer_deferred`].
    pub struct TableObserverHandle;
}
15
/// Observer registration that asks the [`Watcher`] to remove its observer when dropped.
///
/// Created by [`Watcher::add_observer_with_drop_remove`]. Only a weak reference to the watcher is
/// stored, so this handle does not keep the watcher alive; if the watcher is already gone,
/// dropping the handle does nothing.
///
/// NOTE(review): this type derives `Clone`, and every clone runs the `Drop` impl, which requests
/// removal of the same observer. Dropping any single clone therefore unsubscribes the observer
/// for all of them — confirm this is intended.
#[derive(Debug, Clone)]
pub struct DropRemoveTableObserverHandle {
    /// Weak link back to the owning watcher; upgraded only when removal is requested.
    watcher: Weak<Watcher>,
    /// Key of the registered observer.
    handle: TableObserverHandle,
}
24
25impl DropRemoveTableObserverHandle {
26 fn new(handle: TableObserverHandle, watcher: &Arc<Watcher>) -> Self {
27 Self {
28 watcher: Arc::downgrade(watcher),
29 handle,
30 }
31 }
32
33 #[must_use]
35 pub fn handle(&self) -> TableObserverHandle {
36 self.handle
37 }
38
39 pub fn unsubscribe(&self) -> Result<(), Error> {
49 if let Some(watcher) = self.watcher.upgrade() {
50 watcher.remove_observer(self.handle)
51 } else {
52 Err(Error::Command)
53 }
54 }
55}
56
57impl Drop for DropRemoveTableObserverHandle {
58 fn drop(&mut self) {
59 if let Some(watcher) = self.watcher.upgrade() {
60 if watcher.remove_observer_deferred(self.handle).is_err() {
61 error!("Failed to remove watcher from observer on drop");
62 }
63 }
64 }
65}
66
/// Callback interface for code that wants to be told when observed tables change.
pub trait TableObserver: Send + Sync {
    /// Names of the tables this observer is interested in.
    ///
    /// Queried once, when the observer is registered; later changes to the returned list are not
    /// picked up. Duplicate names are collapsed into one.
    fn tables(&self) -> Vec<String>;

    /// Called from the watcher's background thread when at least one of this observer's tables
    /// was reported as changed.
    ///
    /// `tables` contains every table reported changed in the current batch, which may include
    /// tables this observer did not ask for.
    fn on_tables_changed(&self, tables: &BTreeSet<String>);
}
81
/// Tracks which tables are observed and dispatches change notifications to registered
/// [`TableObserver`]s.
///
/// Observer registration, observer removal and change publication are sent as commands to a
/// dedicated background thread (named `sqlite_watcher`), which also invokes observer callbacks.
pub struct Watcher {
    /// Table name/id registry with per-table observer counts.
    tables: RwLock<ObservedTables>,
    /// Incremented whenever a table gains its first observer or loses its last one.
    tables_version: AtomicU64,
    /// Command channel into the background thread.
    sender: Sender<Command>,
}
131
/// Capacity of the bounded command channel between a [`Watcher`] and its background thread.
/// Blocking sends wait while this many commands are queued.
const WATCHER_CHANNEL_CAPACITY: usize = 24;
133
134impl Watcher {
135 pub fn new() -> Result<Arc<Self>, Error> {
140 let (sender, receiver) = flume::bounded(WATCHER_CHANNEL_CAPACITY);
141 let watcher = Arc::new(Self {
142 tables: RwLock::new(ObservedTables::new()),
143 tables_version: AtomicU64::new(0),
144 sender,
145 });
146
147 let watcher_cloned = Arc::clone(&watcher);
148 std::thread::Builder::new()
149 .name("sqlite_watcher".into())
150 .spawn(move || {
151 Watcher::background_loop(receiver, &watcher_cloned);
152 })
153 .map_err(Error::Thread)?;
154
155 Ok(watcher)
156 }
157
158 pub fn add_observer(
168 &self,
169 observer: Box<dyn TableObserver>,
170 ) -> Result<TableObserverHandle, Error> {
171 let (sender, receiver) = oneshot::channel();
172 if self
173 .sender
174 .send(Command::AddObserver(observer, sender))
175 .is_err()
176 {
177 error!("Failed to send add observer command");
178 return Err(Error::Command);
179 }
180
181 let Ok(handle) = receiver.recv() else {
182 error!("Failed to receive handle for new observer");
183 return Err(Error::Command);
184 };
185
186 Ok(handle)
187 }
188
189 pub fn add_observer_with_drop_remove(
197 self: &Arc<Self>,
198 observer: Box<dyn TableObserver>,
199 ) -> Result<DropRemoveTableObserverHandle, Error> {
200 let handle = self.add_observer(observer)?;
201
202 Ok(DropRemoveTableObserverHandle::new(handle, self))
203 }
204
205 pub fn remove_observer_deferred(&self, handle: TableObserverHandle) -> Result<(), Error> {
217 self.sender
218 .send(Command::RemoveObserverDeferred(handle))
219 .map_err(|_| Error::Command)
220 }
221
222 pub fn remove_observer(&self, handle: TableObserverHandle) -> Result<(), Error> {
232 let (sender, receiver) = oneshot::channel();
233 self.sender
234 .send(Command::RemoveObserver(handle, sender))
235 .map_err(|_| Error::Command)?;
236
237 receiver.recv().map_err(|_| {
238 error!("Failed to receive reply for remove observer command");
239 Error::Command
240 })
241 }
242
243 pub(crate) fn publish_changes(&self, table_ids: FixedBitSet) {
244 if self
245 .sender
246 .send(Command::PublishChanges(table_ids))
247 .is_err()
248 {
249 error!("Watcher could not communicate with background thread");
250 }
251 }
252
253 pub(crate) async fn publish_changes_async(&self, table_ids: FixedBitSet) {
254 if self
255 .sender
256 .send_async(Command::PublishChanges(table_ids))
257 .await
258 .is_err()
259 {
260 error!("Watcher could not communicate with background thread");
261 }
262 }
263
264 #[cfg(test)]
265 pub(crate) fn get_table_id(&self, table: &str) -> Option<usize> {
266 self.with_tables(|tables| tables.table_ids.get(table).copied())
267 }
268
269 fn with_tables_mut(&self, f: impl (FnOnce(&mut ObservedTables))) {
270 let mut accessor = self.tables.write();
271 let prev_counter = accessor.counter;
273
274 (f)(&mut accessor);
275
276 let cur_counter = accessor.counter;
278 if prev_counter != cur_counter {
279 self.tables_version.fetch_add(1, Ordering::Release);
280 }
281 }
282
283 fn with_tables<R>(&self, f: impl (FnOnce(&ObservedTables) -> R)) -> R {
284 let accessor = self.tables.read();
285 (f)(&accessor)
286 }
287
288 pub(crate) fn tables_version(&self) -> u64 {
290 self.tables_version.load(Ordering::Acquire)
291 }
292
293 pub fn observed_tables(&self) -> Vec<String> {
295 self.with_tables(|t| t.tables.clone())
296 }
297
298 pub(crate) fn calculate_sync_changes(
299 &self,
300 connection_state: &FixedBitSet,
301 ) -> (FixedBitSet, Vec<ObservedTableOp>) {
302 self.with_tables(|t| t.calculate_changes(connection_state))
303 }
304 #[tracing::instrument(level= tracing::Level::TRACE, skip(receiver, watcher))]
305 fn background_loop(receiver: Receiver<Command>, watcher: &Watcher) {
306 let mut worker = WatcherWorker::new();
307
308 loop {
309 debug_assert!(worker.add_observers.is_empty());
310 debug_assert!(worker.remove_observers.is_empty());
311 debug_assert!(worker.publish_changes.is_empty());
312
313 let Ok(command) = receiver.recv() else {
314 return;
315 };
316
317 worker.unpack_command(command);
318
319 loop {
321 match receiver.try_recv() {
322 Ok(command) => {
323 worker.unpack_command(command);
324 }
325 Err(e) => match e {
326 TryRecvError::Empty => {
327 break;
328 }
329 TryRecvError::Disconnected => {
330 return;
331 }
332 },
333 }
334 }
335
336 worker.tick(watcher);
337 }
338 }
339}
340
/// State owned by the watcher's background thread: registered observers plus the commands
/// batched for the next tick.
struct WatcherWorker {
    /// Registered observers keyed by their handle.
    observers: SlotMap<TableObserverHandle, ActiveObserver>,
    /// Scratch set of table names reported changed during the current tick.
    updated_tables: BTreeSet<String>,
    /// Pending removals; `Some` carries the reply channel of a blocking `remove_observer` call.
    remove_observers: Vec<(TableObserverHandle, Option<oneshot::Sender<()>>)>,
    /// Pending registrations, each with the channel used to return the new handle.
    add_observers: Vec<(Box<dyn TableObserver>, oneshot::Sender<TableObserverHandle>)>,
    /// Pending change notifications as table-id bitsets.
    publish_changes: Vec<FixedBitSet>,
}
349
350impl WatcherWorker {
351 fn new() -> Self {
352 Self {
353 observers: SlotMap::with_capacity_and_key(4),
354 updated_tables: BTreeSet::default(),
355 remove_observers: vec![],
356 add_observers: vec![],
357 publish_changes: vec![],
358 }
359 }
360 fn unpack_command(&mut self, command: Command) {
361 match command {
362 Command::AddObserver(o, r) => self.add_observers.push((o, r)),
363 Command::RemoveObserver(h, r) => self.remove_observers.push((h, Some(r))),
364 Command::RemoveObserverDeferred(h) => {
365 self.remove_observers.push((h, None));
366 }
367 Command::PublishChanges(fixedbitset) => {
368 self.publish_changes.push(fixedbitset);
369 }
370 }
371 }
372
373 fn tick(&mut self, watcher: &Watcher) {
374 for (handle, reply) in self.remove_observers.drain(..) {
376 if let Some(observer) = self.observers.remove(handle) {
377 watcher.with_tables_mut(|tables| {
378 tables.untrack_tables(observer.tables.iter());
379 });
380 }
381
382 if let Some(reply) = reply {
383 if reply.send(()).is_err() {
384 error!("Failed to send reply for observer removal");
385 }
386 }
387 }
388
389 for (observer, reply) in self.add_observers.drain(..) {
391 let active_observer = ActiveObserver::new(observer);
392 watcher.with_tables_mut(|tables| {
393 tables.track_tables(active_observer.tables.iter().cloned());
394 });
395 let handle = self.observers.insert(active_observer);
396 if reply.send(handle).is_err() {
397 error!("Failed to send reply back to caller, new observer will not be added");
398 self.observers.remove(handle);
399 }
400 }
401
402 self.updated_tables.clear();
404
405 for table_ids in self.publish_changes.drain(..) {
406 if table_ids.is_clear() {
407 continue;
408 }
409
410 watcher.with_tables(|observer_tables| {
412 for idx in table_ids.ones() {
413 if let Some(name) = observer_tables.tables.get(idx).cloned() {
415 self.updated_tables.insert(name);
416 }
417 }
418 });
419 }
420
421 if !self.updated_tables.is_empty() {
422 debug!("Changes detected on tables: {:?}", self.updated_tables);
423 {
425 for (_, active_observer) in &self.observers {
426 if self
427 .updated_tables
428 .intersection(&active_observer.tables)
429 .next()
430 .is_some()
431 {
432 active_observer
433 .observer
434 .on_tables_changed(&self.updated_tables);
435 }
436 }
437 }
438 }
439 }
440}
441
/// A registered observer together with the de-duplicated set of tables it asked for at
/// registration time.
struct ActiveObserver {
    /// The user-supplied observer.
    observer: Box<dyn TableObserver>,
    /// Snapshot of `observer.tables()` taken when the observer was registered.
    tables: BTreeSet<String>,
}
446
447impl ActiveObserver {
448 fn new(observer: Box<dyn TableObserver>) -> ActiveObserver {
449 let tables = BTreeSet::from_iter(observer.tables());
450 Self { observer, tables }
451 }
452}
453
/// Messages sent from a [`Watcher`] to its background thread.
enum Command {
    /// Register an observer; the new handle is sent back on the oneshot channel.
    AddObserver(Box<dyn TableObserver>, oneshot::Sender<TableObserverHandle>),
    /// Remove an observer without sending a reply.
    RemoveObserverDeferred(TableObserverHandle),
    /// Remove an observer and reply once the request was processed (even if the handle was
    /// unknown).
    RemoveObserver(TableObserverHandle, oneshot::Sender<()>),
    /// Notify observers about changes to the tables whose ids are set in the bitset.
    PublishChanges(FixedBitSet),
}
465
/// Errors returned by [`Watcher`] operations.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// A command or its reply could not be exchanged with the background thread. Also returned
    /// by `DropRemoveTableObserverHandle::unsubscribe` when the watcher no longer exists.
    #[error("Failed to send or receive command to/from background thread")]
    Command,
    /// Spawning the background thread failed.
    #[error("Failed to create background thread: {0}")]
    Thread(std::io::Error),
}
473
/// A change a connection must apply to its set of tracked tables, as computed by
/// `ObservedTables::calculate_changes`.
#[derive(Debug, Clone, Eq, PartialEq)]
pub(crate) enum ObservedTableOp {
    /// Start tracking the table with the given name and id.
    Add(String, usize),
    /// Stop tracking the table with the given name and id.
    Remove(String, usize),
}
479
/// Registry of every table name any observer has asked for, with per-table observer counts.
///
/// Table ids are indices into `tables` and `num_observers`. They are never reused: a table whose
/// observer count drops to zero keeps its id.
struct ObservedTables {
    /// Table name -> table id.
    table_ids: BTreeMap<String, usize>,
    /// Table names indexed by table id.
    tables: Vec<String>,
    /// Number of registered observers per table id.
    num_observers: Vec<usize>,
    /// Bumped whenever a table gains its first observer or loses its last one.
    counter: u64,
}
494
495impl ObservedTables {
496 fn new() -> Self {
497 Self {
498 table_ids: BTreeMap::new(),
499 tables: Vec::with_capacity(8),
500 num_observers: Vec::with_capacity(8),
501 counter: 0,
502 }
503 }
504
505 fn track_tables(&mut self, tables: impl Iterator<Item = String>) {
507 let mut requires_version_bump = false;
508 for table in tables {
509 match self.table_ids.entry(table.clone()) {
510 Entry::Vacant(v) => {
511 let id = self.num_observers.len();
512 self.tables.push(table.clone());
513 self.num_observers.push(1);
514 v.insert(id);
515 requires_version_bump = true;
516 }
517 Entry::Occupied(o) => {
518 let id = *o.get();
519 let current = self.num_observers[id];
520 if current == 0 {
521 requires_version_bump = true;
524 }
525 self.num_observers[*o.get()] = current + 1;
526 }
527 }
528 }
529
530 if requires_version_bump {
531 self.counter = self.counter.saturating_add(1);
532 }
533 }
534
535 fn untrack_tables<'i>(&mut self, tables: impl Iterator<Item = &'i String>) {
537 let mut requires_version_bump = false;
538 for table in tables {
539 if let Some(id) = self.table_ids.get(table) {
540 self.num_observers[*id] -= 1;
543 if self.num_observers[*id] == 0 {
544 requires_version_bump = true;
545 }
546 }
547 }
548
549 if requires_version_bump {
550 self.counter = self.counter.saturating_add(1);
551 }
552 }
553
554 fn calculate_changes(
560 &self,
561 connection_state: &FixedBitSet,
562 ) -> (FixedBitSet, Vec<ObservedTableOp>) {
563 let mut result = connection_state.clone();
564 result.grow(self.tables.len());
565 let mut changes = Vec::with_capacity(self.tables.len());
566 let min_index = connection_state.len().min(self.tables.len());
567 for i in 0..min_index {
568 let is_tracking = connection_state[i];
569 let num_observers = self.num_observers[i];
570
571 if is_tracking && num_observers == 0 {
572 changes.push(ObservedTableOp::Remove(self.tables[i].clone(), i));
573 result.set(i, false);
574 } else if !is_tracking && num_observers != 0 {
575 changes.push(ObservedTableOp::Add(self.tables[i].clone(), i));
576 result.set(i, true);
577 }
578 }
579
580 for i in min_index..self.num_observers.len() {
582 if self.num_observers[i] != 0 {
583 changes.push(ObservedTableOp::Add(self.tables[i].clone(), i));
584 result.set(i, true);
585 }
586 }
587
588 (result, changes)
589 }
590}
591
#[cfg(test)]
pub(crate) mod tests {
    use crate::watcher::{ObservedTables, TableObserver, Watcher};
    use std::collections::BTreeSet;
    use std::sync::atomic::Ordering;

    /// Observer with a fixed table list that ignores change notifications.
    pub struct TestObserver {
        tables: Vec<String>,
    }

    impl TableObserver for TestObserver {
        fn tables(&self) -> Vec<String> {
            self.tables.clone()
        }

        fn on_tables_changed(&self, _: &BTreeSet<String>) {}
    }

    /// Builds a boxed [`TestObserver`] watching the given table names.
    pub(crate) fn new_test_observer(
        tables: impl IntoIterator<Item = &'static str>,
    ) -> Box<dyn TableObserver + Send + 'static> {
        let tables = tables.into_iter().map(String::from).collect();
        Box::new(TestObserver { tables })
    }

    /// Asserts that table `name` is registered and has exactly `expected` observers.
    fn check_table_counter(tables: &ObservedTables, name: &str, expected: usize) {
        let Some(&idx) = tables.table_ids.get(name) else {
            panic!("could not find table by name");
        };
        assert_eq!(tables.num_observers[idx], expected);
    }

    /// Asserts the number of registered tables and the observer count of each listed table.
    fn check_tables(service: &Watcher, registered: usize, counts: &[(&str, usize)]) {
        service.with_tables(|tables| {
            assert_eq!(tables.num_observers.len(), registered);
            for &(name, expected) in counts {
                check_table_counter(tables, name, expected);
            }
        });
    }

    #[test]
    fn test_observer_tables_version_counter() {
        let service = Watcher::new().unwrap();
        let version = || service.tables_version.load(Ordering::Relaxed);
        let mut expected_version = version();

        // First observer introduces both tables: version bump.
        let foo_bar = service
            .add_observer(new_test_observer(["foo", "bar"]))
            .unwrap();
        check_tables(&service, 2, &[("foo", 1), ("bar", 1)]);
        expected_version += 1;
        assert_eq!(expected_version, version());

        // Only raises the count of an already observed table: no bump.
        let bar_only = service.add_observer(new_test_observer(["bar"])).unwrap();
        check_tables(&service, 2, &[("foo", 1), ("bar", 2)]);
        assert_eq!(expected_version, version());

        // Introduces "omega": bump.
        let bar_omega = service
            .add_observer(new_test_observer(["bar", "omega"]))
            .unwrap();
        check_tables(&service, 3, &[("foo", 1), ("omega", 1), ("bar", 3)]);
        expected_version += 1;
        assert_eq!(expected_version, version());

        // Every table keeps at least one observer: no bump.
        service.remove_observer(bar_only).unwrap();
        check_tables(&service, 3, &[("foo", 1), ("bar", 2), ("omega", 1)]);
        assert_eq!(expected_version, version());

        // "omega" loses its last observer: bump.
        service.remove_observer(bar_omega).unwrap();
        check_tables(&service, 3, &[("foo", 1), ("bar", 1), ("omega", 0)]);
        expected_version += 1;
        assert_eq!(expected_version, version());

        // "foo" and "bar" lose their last observer: bump.
        service.remove_observer(foo_bar).unwrap();
        check_tables(&service, 3, &[("foo", 0), ("bar", 0), ("omega", 0)]);
        expected_version += 1;
        assert_eq!(expected_version, version());
    }
}