1use fixedbitset::FixedBitSet;
2use flume::{Receiver, Sender, TryRecvError};
3use parking_lot::RwLock;
4use slotmap::{SlotMap, new_key_type};
5use std::collections::btree_map::Entry;
6use std::collections::{BTreeMap, BTreeSet};
7use std::sync::atomic::{AtomicU64, Ordering};
8use std::sync::{Arc, Weak};
9use tracing::{debug, error};
10
new_key_type! {
    /// Key identifying an observer registered with a [`Watcher`].
    ///
    /// The background worker stores observers in a `slotmap::SlotMap` keyed by this
    /// type. Slotmap keys are versioned, so a handle whose observer has been removed
    /// will not match a later observer that happens to reuse the same slot.
    pub struct TableObserverHandle;
}
15
/// Observer handle that unregisters its observer from the [`Watcher`] when dropped.
///
/// Only a weak reference to the watcher is stored, so this handle itself does not
/// extend the watcher's lifetime.
///
/// NOTE(review): this type derives `Clone`, and every clone queues a removal of the
/// same observer on drop. Dropping any one clone therefore unsubscribes the observer
/// for all of them. The removals queued by the other clones are no-ops, because the
/// worker ignores handles it no longer knows.
#[derive(Debug, Clone)]
pub struct DropRemoveTableObserverHandle {
    /// Weak reference so that the handle does not keep the watcher alive.
    watcher: Weak<Watcher>,
    /// Handle of the observer that is removed on drop / `unsubscribe`.
    handle: TableObserverHandle,
}
24
25impl DropRemoveTableObserverHandle {
26 fn new(handle: TableObserverHandle, watcher: &Arc<Watcher>) -> Self {
27 Self {
28 watcher: Arc::downgrade(watcher),
29 handle,
30 }
31 }
32
33 #[must_use]
35 pub fn handle(&self) -> TableObserverHandle {
36 self.handle
37 }
38
39 pub fn unsubscribe(&self) -> Result<(), Error> {
49 if let Some(watcher) = self.watcher.upgrade() {
50 watcher.remove_observer(self.handle)
51 } else {
52 Err(Error::Command)
53 }
54 }
55}
56
57impl Drop for DropRemoveTableObserverHandle {
58 fn drop(&mut self) {
59 if let Some(watcher) = self.watcher.upgrade() {
60 if watcher.remove_observer_deferred(self.handle).is_err() {
61 error!("Failed to remove watcher from observer on drop");
62 }
63 }
64 }
65}
66
/// Receives notifications when tables it is interested in are modified.
///
/// Observers are moved to, and invoked from, the watcher's background thread.
pub trait TableObserver: Send + Sync {
    /// Names of the tables this observer wants to be notified about.
    ///
    /// This is queried once, when the observer is registered. Later changes to the
    /// returned list are not picked up. Duplicate names are collapsed.
    fn tables(&self) -> Vec<String>;

    /// Called when at least one of this observer's tables has changed.
    ///
    /// `tables` holds every table reported as changed in the same batch, so it may
    /// include tables this observer did not ask for.
    ///
    /// This runs on the watcher's background thread. Do not call blocking watcher
    /// methods such as `Watcher::add_observer` or `Watcher::remove_observer` from
    /// here: they wait for a reply from this very thread and would never return.
    fn on_tables_changed(&self, tables: &BTreeSet<String>);
}
81
/// Tracks which tables are observed and dispatches change notifications to
/// registered [`TableObserver`]s from a dedicated background thread.
pub struct Watcher {
    /// Registry of observed tables. Within this module it is mutated only through
    /// `with_tables_mut`, which the background worker calls.
    tables: RwLock<ObservedTables>,
    /// Incremented each time a registry mutation changes `ObservedTables::counter`,
    /// i.e. when the set of tables with at least one observer changes.
    tables_version: AtomicU64,
    /// Sending side of the command channel consumed by the background thread.
    sender: Sender<Command>,
}
131
/// Capacity of the bounded command channel between [`Watcher`] and its background
/// thread. Blocking sends wait while this many commands are queued.
const WATCHER_CHANNEL_CAPACITY: usize = 24;
133
134impl Watcher {
135 pub fn new() -> Result<Arc<Self>, Error> {
140 let (sender, receiver) = flume::bounded(WATCHER_CHANNEL_CAPACITY);
141 let watcher = Arc::new(Self {
142 tables: RwLock::new(ObservedTables::new()),
143 tables_version: AtomicU64::new(0),
144 sender,
145 });
146
147 let watcher_cloned = Arc::clone(&watcher);
148 std::thread::Builder::new()
149 .name("sqlite_watcher".into())
150 .spawn(move || {
151 Watcher::background_loop(receiver, &watcher_cloned);
152 })
153 .map_err(Error::Thread)?;
154
155 Ok(watcher)
156 }
157
158 pub fn add_observer(
168 &self,
169 observer: Box<dyn TableObserver>,
170 ) -> Result<TableObserverHandle, Error> {
171 let (sender, receiver) = oneshot::channel();
172 if self
173 .sender
174 .send(Command::AddObserver(observer, sender))
175 .is_err()
176 {
177 error!("Failed to send add observer command");
178 return Err(Error::Command);
179 }
180
181 let Ok(handle) = receiver.recv() else {
182 error!("Failed to receive handle for new observer");
183 return Err(Error::Command);
184 };
185
186 Ok(handle)
187 }
188
189 pub fn add_observer_with_drop_remove(
197 self: &Arc<Self>,
198 observer: Box<dyn TableObserver>,
199 ) -> Result<DropRemoveTableObserverHandle, Error> {
200 let handle = self.add_observer(observer)?;
201
202 Ok(DropRemoveTableObserverHandle::new(handle, self))
203 }
204
205 pub fn remove_observer_deferred(&self, handle: TableObserverHandle) -> Result<(), Error> {
217 self.sender
218 .send(Command::RemoveObserverDeferred(handle))
219 .map_err(|_| Error::Command)
220 }
221
222 pub fn remove_observer(&self, handle: TableObserverHandle) -> Result<(), Error> {
232 let (sender, receiver) = oneshot::channel();
233 self.sender
234 .send(Command::RemoveObserver(handle, sender))
235 .map_err(|_| Error::Command)?;
236
237 receiver.recv().map_err(|_| {
238 error!("Failed to receive reply for remove observer command");
239 Error::Command
240 })
241 }
242
243 pub(crate) fn publish_changes(&self, table_ids: FixedBitSet) {
244 if self
245 .sender
246 .send(Command::PublishChanges(table_ids))
247 .is_err()
248 {
249 error!("Watcher could not communicate with background thread");
250 }
251 }
252
253 pub(crate) async fn publish_changes_async(&self, table_ids: FixedBitSet) {
254 if self
255 .sender
256 .send_async(Command::PublishChanges(table_ids))
257 .await
258 .is_err()
259 {
260 error!("Watcher could not communicate with background thread");
261 }
262 }
263
264 #[cfg(test)]
265 pub(crate) fn get_table_id(&self, table: &str) -> Option<usize> {
266 self.with_tables(|tables| tables.table_ids.get(table).copied())
267 }
268
269 fn with_tables_mut(&self, f: impl FnOnce(&mut ObservedTables)) {
270 let mut accessor = self.tables.write();
271 let prev_counter = accessor.counter;
273
274 (f)(&mut accessor);
275
276 let cur_counter = accessor.counter;
278 if prev_counter != cur_counter {
279 self.tables_version.fetch_add(1, Ordering::Release);
280 }
281 }
282
283 fn with_tables<R>(&self, f: impl (FnOnce(&ObservedTables) -> R)) -> R {
284 let accessor = self.tables.read();
285 (f)(&accessor)
286 }
287
288 pub(crate) fn tables_version(&self) -> u64 {
290 self.tables_version.load(Ordering::Acquire)
291 }
292
293 pub fn observed_tables(&self) -> Vec<String> {
295 self.with_tables(|t| t.tables.clone())
296 }
297
298 pub(crate) fn calculate_sync_changes(
299 &self,
300 connection_state: &FixedBitSet,
301 ) -> (FixedBitSet, Vec<ObservedTableOp>) {
302 self.with_tables(|t| t.calculate_changes(connection_state))
303 }
304
305 #[allow(clippy::needless_pass_by_value)]
306 #[tracing::instrument(level= tracing::Level::TRACE, skip(receiver, watcher))]
307 fn background_loop(receiver: Receiver<Command>, watcher: &Watcher) {
308 let mut worker = WatcherWorker::new();
309
310 loop {
311 debug_assert!(worker.add_observers.is_empty());
312 debug_assert!(worker.remove_observers.is_empty());
313 debug_assert!(worker.publish_changes.is_empty());
314
315 let Ok(command) = receiver.recv() else {
316 return;
317 };
318
319 worker.unpack_command(command);
320
321 loop {
323 match receiver.try_recv() {
324 Ok(command) => {
325 worker.unpack_command(command);
326 }
327 Err(e) => match e {
328 TryRecvError::Empty => {
329 break;
330 }
331 TryRecvError::Disconnected => {
332 return;
333 }
334 },
335 }
336 }
337
338 worker.tick(watcher);
339 }
340 }
341}
342
/// State owned by the background thread: the registered observers plus the
/// commands collected for the current tick.
struct WatcherWorker {
    /// Registered observers, keyed by the handle returned to callers.
    observers: SlotMap<TableObserverHandle, ActiveObserver>,
    /// Scratch set of table names reported as changed in the current tick.
    updated_tables: BTreeSet<String>,
    /// Pending removals. `Some` carries the reply channel of a blocking
    /// `remove_observer` call; `None` comes from a deferred removal.
    remove_observers: Vec<(TableObserverHandle, Option<oneshot::Sender<()>>)>,
    /// Pending registrations, each with the channel used to return its handle.
    add_observers: Vec<(Box<dyn TableObserver>, oneshot::Sender<TableObserverHandle>)>,
    /// Pending change notifications, as bitsets of table ids.
    publish_changes: Vec<FixedBitSet>,
}
351
352impl WatcherWorker {
353 fn new() -> Self {
354 Self {
355 observers: SlotMap::with_capacity_and_key(4),
356 updated_tables: BTreeSet::default(),
357 remove_observers: vec![],
358 add_observers: vec![],
359 publish_changes: vec![],
360 }
361 }
362 fn unpack_command(&mut self, command: Command) {
363 match command {
364 Command::AddObserver(o, r) => self.add_observers.push((o, r)),
365 Command::RemoveObserver(h, r) => self.remove_observers.push((h, Some(r))),
366 Command::RemoveObserverDeferred(h) => {
367 self.remove_observers.push((h, None));
368 }
369 Command::PublishChanges(fixedbitset) => {
370 self.publish_changes.push(fixedbitset);
371 }
372 }
373 }
374
375 fn tick(&mut self, watcher: &Watcher) {
376 for (handle, reply) in self.remove_observers.drain(..) {
378 if let Some(observer) = self.observers.remove(handle) {
379 watcher.with_tables_mut(|tables| {
380 tables.untrack_tables(observer.tables.iter());
381 });
382 }
383
384 if let Some(reply) = reply {
385 if reply.send(()).is_err() {
386 error!("Failed to send reply for observer removal");
387 }
388 }
389 }
390
391 for (observer, reply) in self.add_observers.drain(..) {
393 let active_observer = ActiveObserver::new(observer);
394 watcher.with_tables_mut(|tables| {
395 tables.track_tables(active_observer.tables.iter().cloned());
396 });
397 let handle = self.observers.insert(active_observer);
398 if reply.send(handle).is_err() {
399 error!("Failed to send reply back to caller, new observer will not be added");
400 self.observers.remove(handle);
401 }
402 }
403
404 self.updated_tables.clear();
406
407 for table_ids in self.publish_changes.drain(..) {
408 if table_ids.is_clear() {
409 continue;
410 }
411
412 watcher.with_tables(|observer_tables| {
414 for idx in table_ids.ones() {
415 if let Some(name) = observer_tables.tables.get(idx).cloned() {
417 self.updated_tables.insert(name);
418 }
419 }
420 });
421 }
422
423 if !self.updated_tables.is_empty() {
424 debug!("Changes detected on tables: {:?}", self.updated_tables);
425 {
427 for (_, active_observer) in &self.observers {
428 if self
429 .updated_tables
430 .intersection(&active_observer.tables)
431 .next()
432 .is_some()
433 {
434 active_observer
435 .observer
436 .on_tables_changed(&self.updated_tables);
437 }
438 }
439 }
440 }
441 }
442}
443
/// A registered observer together with the tables it reported at registration.
struct ActiveObserver {
    /// The user-supplied observer.
    observer: Box<dyn TableObserver>,
    /// Snapshot of `observer.tables()` taken when the observer was registered.
    tables: BTreeSet<String>,
}
448
449impl ActiveObserver {
450 fn new(observer: Box<dyn TableObserver>) -> ActiveObserver {
451 let tables = BTreeSet::from_iter(observer.tables());
452 Self { observer, tables }
453 }
454}
455
/// Messages sent from [`Watcher`] to its background thread.
enum Command {
    /// Register an observer and reply with its newly allocated handle.
    AddObserver(Box<dyn TableObserver>, oneshot::Sender<TableObserverHandle>),
    /// Remove an observer without sending a reply.
    RemoveObserverDeferred(TableObserverHandle),
    /// Remove an observer and reply once the removal has been processed.
    RemoveObserver(TableObserverHandle, oneshot::Sender<()>),
    /// The tables whose ids are set in the bitset have changed.
    PublishChanges(FixedBitSet),
}
467
/// Errors returned by [`Watcher`] operations.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// A command or its reply could not be exchanged with the background thread.
    /// It is also returned by `DropRemoveTableObserverHandle::unsubscribe` when the
    /// watcher no longer exists.
    #[error("Failed to send or receive command to/from background thread")]
    Command,
    /// Spawning the background thread failed.
    #[error("Failed to create background thread: {0}")]
    Thread(std::io::Error),
}
475
/// A tracking change produced by `ObservedTables::calculate_changes`.
///
/// Each variant carries the table name and its id (its index in the registry).
#[derive(Debug, Clone, Eq, PartialEq)]
pub(crate) enum ObservedTableOp {
    /// Start tracking a table that has observers but is not yet tracked.
    Add(String, usize),
    /// Stop tracking a table that is tracked but no longer has observers.
    Remove(String, usize),
}
481
/// Registry of every table that has ever been observed, with per-table observer counts.
struct ObservedTables {
    /// Maps a table name to its id. Ids index into `tables` and `num_observers`.
    table_ids: BTreeMap<String, usize>,
    /// Table names indexed by id. Entries are appended, never removed.
    tables: Vec<String>,
    /// Number of registered observers per table id.
    num_observers: Vec<usize>,
    /// Incremented (saturating) whenever a table gains its first observer or loses
    /// its last one.
    counter: u64,
}
496
497impl ObservedTables {
498 fn new() -> Self {
499 Self {
500 table_ids: BTreeMap::new(),
501 tables: Vec::with_capacity(8),
502 num_observers: Vec::with_capacity(8),
503 counter: 0,
504 }
505 }
506
507 fn track_tables(&mut self, tables: impl Iterator<Item = String>) {
509 let mut requires_version_bump = false;
510 for table in tables {
511 match self.table_ids.entry(table.clone()) {
512 Entry::Vacant(v) => {
513 let id = self.num_observers.len();
514 self.tables.push(table.clone());
515 self.num_observers.push(1);
516 v.insert(id);
517 requires_version_bump = true;
518 }
519 Entry::Occupied(o) => {
520 let id = *o.get();
521 let current = self.num_observers[id];
522 if current == 0 {
523 requires_version_bump = true;
526 }
527 self.num_observers[*o.get()] = current + 1;
528 }
529 }
530 }
531
532 if requires_version_bump {
533 self.counter = self.counter.saturating_add(1);
534 }
535 }
536
537 fn untrack_tables<'i>(&mut self, tables: impl Iterator<Item = &'i String>) {
539 let mut requires_version_bump = false;
540 for table in tables {
541 if let Some(id) = self.table_ids.get(table) {
542 self.num_observers[*id] -= 1;
545 if self.num_observers[*id] == 0 {
546 requires_version_bump = true;
547 }
548 }
549 }
550
551 if requires_version_bump {
552 self.counter = self.counter.saturating_add(1);
553 }
554 }
555
556 fn calculate_changes(
562 &self,
563 connection_state: &FixedBitSet,
564 ) -> (FixedBitSet, Vec<ObservedTableOp>) {
565 let mut result = connection_state.clone();
566 result.grow(self.tables.len());
567 let mut changes = Vec::with_capacity(self.tables.len());
568 let min_index = connection_state.len().min(self.tables.len());
569 for i in 0..min_index {
570 let is_tracking = connection_state[i];
571 let num_observers = self.num_observers[i];
572
573 if is_tracking && num_observers == 0 {
574 changes.push(ObservedTableOp::Remove(self.tables[i].clone(), i));
575 result.set(i, false);
576 } else if !is_tracking && num_observers != 0 {
577 changes.push(ObservedTableOp::Add(self.tables[i].clone(), i));
578 result.set(i, true);
579 }
580 }
581
582 for i in min_index..self.num_observers.len() {
584 if self.num_observers[i] != 0 {
585 changes.push(ObservedTableOp::Add(self.tables[i].clone(), i));
586 result.set(i, true);
587 }
588 }
589
590 (result, changes)
591 }
592}
593
#[cfg(test)]
pub(crate) mod tests {
    use crate::watcher::{ObservedTables, TableObserver, Watcher};
    use std::collections::BTreeSet;
    use std::sync::atomic::Ordering;

    /// Observer with a fixed table list that ignores change notifications.
    pub struct TestObserver {
        tables: Vec<String>,
    }

    impl TableObserver for TestObserver {
        fn tables(&self) -> Vec<String> {
            self.tables.clone()
        }

        fn on_tables_changed(&self, _: &BTreeSet<String>) {}
    }

    /// Builds a boxed [`TestObserver`] that watches `tables`.
    pub(crate) fn new_test_observer(
        tables: impl IntoIterator<Item = &'static str>,
    ) -> Box<dyn TableObserver + Send + 'static> {
        let owned: Vec<String> = tables.into_iter().map(String::from).collect();
        Box::new(TestObserver { tables: owned })
    }

    /// Asserts that table `name` is registered and has exactly `expected` observers.
    fn check_table_counter(tables: &ObservedTables, name: &str, expected: usize) {
        let Some(&idx) = tables.table_ids.get(name) else {
            panic!("could not find table by name");
        };
        assert_eq!(tables.num_observers[idx], expected);
    }

    /// Asserts that the registry holds `table_count` tables and that every
    /// `(name, count)` pair in `expected` matches.
    fn check_tables(service: &Watcher, table_count: usize, expected: &[(&str, usize)]) {
        service.with_tables(|tables| {
            assert_eq!(tables.num_observers.len(), table_count);
            for &(name, count) in expected {
                check_table_counter(tables, name, count);
            }
        });
    }

    #[test]
    fn test_observer_tables_version_counter() {
        let service = Watcher::new().unwrap();
        let current_version = || service.tables_version.load(Ordering::Relaxed);
        let mut expected_version = current_version();

        // "foo" and "bar" gain their first observer: version bump.
        let foo_bar = service
            .add_observer(new_test_observer(["foo", "bar"]))
            .unwrap();
        check_tables(&service, 2, &[("foo", 1), ("bar", 1)]);
        expected_version += 1;
        assert_eq!(expected_version, current_version());

        // "bar" was already observed: no bump.
        let bar = service.add_observer(new_test_observer(["bar"])).unwrap();
        check_tables(&service, 2, &[("foo", 1), ("bar", 2)]);
        assert_eq!(expected_version, current_version());

        // "omega" is new: bump.
        let bar_omega = service
            .add_observer(new_test_observer(["bar", "omega"]))
            .unwrap();
        check_tables(&service, 3, &[("foo", 1), ("omega", 1), ("bar", 3)]);
        expected_version += 1;
        assert_eq!(expected_version, current_version());

        // No table drops to zero observers: no bump.
        service.remove_observer(bar).unwrap();
        check_tables(&service, 3, &[("foo", 1), ("bar", 2), ("omega", 1)]);
        assert_eq!(expected_version, current_version());

        // "omega" loses its last observer: bump.
        service.remove_observer(bar_omega).unwrap();
        check_tables(&service, 3, &[("foo", 1), ("bar", 1), ("omega", 0)]);
        expected_version += 1;
        assert_eq!(expected_version, current_version());

        // Everything becomes unobserved: bump.
        service.remove_observer(foo_bar).unwrap();
        check_tables(&service, 3, &[("foo", 0), ("bar", 0), ("omega", 0)]);
        expected_version += 1;
        assert_eq!(expected_version, current_version());
    }
}