1use crate::statement::{
2 BatchQuery, Sealed, SqlExecuteStatement, SqlTransactionStatement, Statement, StatementWithInput,
3};
4use crate::watcher::{ObservedTableOp, Watcher};
5use fixedbitset::FixedBitSet;
6use std::error::Error;
7use std::future::Future;
8use std::ops::{Deref, DerefMut};
9use std::sync::Arc;
10use tracing::{debug, trace, warn};
11
12#[cfg(feature = "rusqlite")]
13pub mod rusqlite;
14
15#[cfg(feature = "sqlx")]
16pub mod sqlx;
17
/// Synchronous SQL executor that is driven through a shared reference (`&self`).
///
/// The tracker only needs two operations: reading a list of `u32` values and executing a
/// statement whose result is ignored.
pub trait SqlExecutor {
    /// Error produced by the underlying connection.
    type Error: Error;
    /// Runs `query` and returns the values it yields as `u32`s.
    ///
    /// NOTE(review): the tracker calls this with a single-column `SELECT table_id ...` query;
    /// implementations presumably read the first column of every row — confirm.
    fn sql_query_values(&self, query: &str) -> Result<Vec<u32>, Self::Error>;

    /// Executes `query`, reporting only success or failure.
    fn sql_execute(&self, query: &str) -> Result<(), Self::Error>;
}
38
/// Synchronous SQL executor that requires exclusive access (`&mut self`).
///
/// Every [`SqlExecutor`] also implements this trait through a blanket implementation that
/// forwards each call.
pub trait SqlExecutorMut {
    /// Error produced by the underlying connection.
    type Error: Error;
    /// Runs `query` and returns the values it yields as `u32`s.
    fn sql_query_values(&mut self, query: &str) -> Result<Vec<u32>, Self::Error>;

    /// Executes `query`, reporting only success or failure.
    fn sql_execute(&mut self, query: &str) -> Result<(), Self::Error>;
}
57
58impl<T: SqlExecutor> SqlExecutorMut for T {
60 type Error = T::Error;
61
62 fn sql_query_values(&mut self, query: &str) -> Result<Vec<u32>, Self::Error> {
63 SqlExecutor::sql_query_values(self, query)
64 }
65
66 fn sql_execute(&mut self, query: &str) -> Result<(), Self::Error> {
67 SqlExecutor::sql_execute(self, query)
68 }
69}
70
/// Asynchronous SQL executor. The executor and the futures it returns must be `Send`.
pub trait SqlExecutorAsync: Send {
    /// Error produced by the underlying connection; must be `Send` as well.
    type Error: Error + Send;
    /// Runs `query` and resolves to the values it yields as `u32`s.
    fn sql_query_values(
        &mut self,
        query: &str,
    ) -> impl Future<Output = Result<Vec<u32>, Self::Error>> + Send;

    /// Executes `query`, resolving only to success or failure.
    fn sql_execute(&mut self, query: &str) -> impl Future<Output = Result<(), Self::Error>> + Send;
}
94
/// Per-connection bookkeeping for change tracking: which tables have triggers installed and
/// which watcher table version that reflects.
#[derive(Debug, Default)]
pub struct State {
    // One bit per table id; set for tables whose triggers were installed as of the last
    // applied sync.
    tracked_tables: FixedBitSet,
    // `Watcher::tables_version()` value this state was last synchronised with.
    last_sync_version: u64,
}
127
128impl State {
129 #[must_use]
131 pub fn set_pragmas() -> impl Statement {
132 SqlExecuteStatement::new("PRAGMA temp_store = MEMORY")
133 .then(SqlExecuteStatement::new("PRAGMA recursive_triggers='ON'"))
134 }
135
136 #[must_use]
140 #[tracing::instrument(level = tracing::Level::DEBUG)]
141 pub fn start_tracking() -> impl Statement {
142 SqlTransactionStatement::temporary(
144 SqlExecuteStatement::new(create_tracking_table_query())
145 .then(SqlExecuteStatement::new(empty_tracking_table_query())),
146 )
147 .spanned_in_current()
148 }
149
150 #[tracing::instrument(level = tracing::Level::DEBUG, skip_all)]
156 pub fn stop_tracking(&self, watcher: &Watcher) -> impl Statement {
157 let tables = watcher.observed_tables();
158 SqlTransactionStatement::temporary(
159 BatchQuery::new(
160 tables
161 .into_iter()
162 .enumerate()
163 .flat_map(|(id, table_name)| drop_triggers(&table_name, id)),
164 )
165 .then(SqlExecuteStatement::new(drop_tracking_table_query())),
166 )
167 .spanned_in_current()
168 }
169
170 #[must_use]
172 pub fn new() -> Self {
173 Self {
174 tracked_tables: FixedBitSet::new(),
175 last_sync_version: 0,
176 }
177 }
178
179 #[tracing::instrument(level=tracing::Level::DEBUG, skip(self, watcher))]
188 pub fn sync_tables(&mut self, watcher: &Watcher) -> Option<impl Statement + '_> {
189 let new_version = self.should_sync(watcher)?;
190
191 debug!("Syncing tables from observer");
192 let Some((new_tracker_state, tracker_changes)) = self.calculate_sync_changes(watcher)
193 else {
194 debug!("No changes");
195 return None;
196 };
197
198 let mut queries = BatchQuery::new([]);
199
200 if self.tracked_tables.is_empty() {
201 queries.push(SqlExecuteStatement::new(empty_tracking_table_query()));
205 }
206 for change in tracker_changes {
207 match change {
208 ObservedTableOp::Add(table_name, id) => {
209 debug!("Add watcher for table {table_name} id={id}");
210 queries.extend(create_triggers(&table_name, id));
211 }
212 ObservedTableOp::Remove(table_name, id) => {
213 debug!("Remove watcher for table {table_name}");
214 queries.extend(drop_triggers(&table_name, id));
215 }
216 }
217 }
218
219 let tx = SqlTransactionStatement::temporary(queries);
220 Some(
221 tx.then(ConcludeStateChangeStatement {
222 state: self,
223 tracked_tables: new_tracker_state,
224 new_version,
225 })
226 .spanned_in_current(),
227 )
228 }
229
230 #[tracing::instrument(level=tracing::Level::DEBUG, skip(self, watcher))]
243 pub fn publish_changes(&self, watcher: &Watcher) -> impl Statement {
244 SqlReadTableIdsStatement
245 .pipe(CalculateWatcherUpdatesStatement { state: self })
246 .pipe(MaybeResetResultsQuery)
247 .pipe(PublishWatcherChangesStatement(watcher))
248 .spanned_in_current()
249 }
250
251 fn prepare_watcher_changes(&self, modified_table_ids: Vec<u32>) -> FixedBitSet {
252 trace!("Preparing watcher changes");
253 let mut result = FixedBitSet::with_capacity(self.tracked_tables.len());
254 for id in modified_table_ids {
255 let id = id as usize;
256 debug!("Table {} has been modified", id);
257 if id >= result.len() {
258 warn!(
259 "Received update for table {id}, but only tracking {} tables",
260 self.tracked_tables.len(),
261 );
262 result.grow(id + 1);
264 }
265 result.set(id, true);
266 }
267
268 result
269 }
270
271 fn should_sync(&self, watcher: &Watcher) -> Option<u64> {
272 let service_version = watcher.tables_version();
273 if service_version == self.last_sync_version {
274 None
275 } else {
276 Some(service_version)
277 }
278 }
279
280 fn calculate_sync_changes(
282 &self,
283 watcher: &Watcher,
284 ) -> Option<(FixedBitSet, Vec<ObservedTableOp>)> {
285 trace!("Calculating sync changes");
286 let (new_tracker_state, tracker_changes) =
287 watcher.calculate_sync_changes(&self.tracked_tables);
288
289 if tracker_changes.is_empty() {
290 return None;
291 }
292
293 Some((new_tracker_state, tracker_changes))
294 }
295
296 fn apply_sync_changes(&mut self, new_tracker_state: FixedBitSet, new_version: u64) {
298 trace!("Applying sync changes");
300 self.tracked_tables = new_tracker_state;
301 self.last_sync_version = new_version;
302 }
303}
304
/// Synchronous connection wrapper.
///
/// It installs change-tracking triggers for the tables a [`Watcher`] observes and publishes
/// detected modifications back to that watcher.
pub struct Connection<C: SqlExecutor> {
    // Tracking bookkeeping for this specific connection.
    state: State,
    // Shared watcher that owns the observed-table registry and the observers.
    watcher: Arc<Watcher>,
    // The wrapped executor that all tracking statements run on.
    connection: C,
}
369impl<C: SqlExecutor> Connection<C> {
370 pub fn new(connection: C, watcher: Arc<Watcher>) -> Result<Self, C::Error> {
378 let state = State::new();
379 State::set_pragmas().execute(&connection)?;
380 State::start_tracking().execute(&connection)?;
381 Ok(Self {
382 state,
383 watcher,
384 connection,
385 })
386 }
387
388 pub fn sync_watcher_tables(&mut self) -> Result<(), C::Error> {
396 self.state
397 .sync_tables(&self.watcher)
398 .execute(&self.connection)?;
399 Ok(())
400 }
401
402 pub fn publish_watcher_changes(&mut self) -> Result<(), C::Error> {
412 self.state
413 .publish_changes(&self.watcher)
414 .execute(&self.connection)?;
415 Ok(())
416 }
417
418 pub fn stop_tracking(&mut self) -> Result<(), C::Error> {
426 self.state
427 .stop_tracking(&self.watcher)
428 .execute(&self.connection)?;
429 Ok(())
430 }
431
432 pub fn take(self) -> C {
439 self.connection
440 }
441}
442
/// Asynchronous counterpart of `Connection`, driving a [`SqlExecutorAsync`].
// `ConnectionAsync` starts with the module name (`connection`); the name is kept on purpose
// so the async type mirrors `Connection`.
#[allow(clippy::module_name_repetitions)]
pub struct ConnectionAsync<C: SqlExecutorAsync> {
    // Tracking bookkeeping for this specific connection.
    state: State,
    // Shared watcher that owns the observed-table registry and the observers.
    watcher: Arc<Watcher>,
    // The wrapped async executor that all tracking statements run on.
    connection: C,
}
450impl<C: SqlExecutorAsync> ConnectionAsync<C> {
451 pub async fn new(mut connection: C, watcher: Arc<Watcher>) -> Result<Self, C::Error> {
459 let state = State::new();
460 State::set_pragmas().execute_async(&mut connection).await?;
461 State::start_tracking()
462 .execute_async(&mut connection)
463 .await?;
464 Ok(Self {
465 state,
466 watcher,
467 connection,
468 })
469 }
470
471 pub async fn sync_watcher_tables(&mut self) -> Result<(), C::Error> {
477 self.state
478 .sync_tables(&self.watcher)
479 .execute_async(&mut self.connection)
480 .await?;
481 Ok(())
482 }
483
484 pub async fn publish_watcher_changes(&mut self) -> Result<(), C::Error> {
490 self.state
491 .publish_changes(&self.watcher)
492 .execute_async(&mut self.connection)
493 .await?;
494 Ok(())
495 }
496
497 pub async fn stop_tracking(&mut self) -> Result<(), C::Error> {
503 self.state
504 .stop_tracking(&self.watcher)
505 .execute_async(&mut self.connection)
506 .await?;
507 Ok(())
508 }
509
510 pub fn take(self) -> C {
517 self.connection
518 }
519}
520
521impl<C: SqlExecutorAsync> Deref for ConnectionAsync<C> {
522 type Target = C;
523
524 fn deref(&self) -> &Self::Target {
525 &self.connection
526 }
527}
528
529impl<C: SqlExecutorAsync> DerefMut for ConnectionAsync<C> {
530 fn deref_mut(&mut self) -> &mut Self::Target {
531 &mut self.connection
532 }
533}
534
535impl<C: SqlExecutorAsync> AsRef<C> for ConnectionAsync<C> {
536 fn as_ref(&self) -> &C {
537 &self.connection
538 }
539}
540
541impl<C: SqlExecutorAsync> AsMut<C> for ConnectionAsync<C> {
542 fn as_mut(&mut self) -> &mut C {
543 &mut self.connection
544 }
545}
546
547impl<C: SqlExecutor> Deref for Connection<C> {
548 type Target = C;
549
550 fn deref(&self) -> &Self::Target {
551 &self.connection
552 }
553}
554
555impl<C: SqlExecutor> DerefMut for Connection<C> {
556 fn deref_mut(&mut self) -> &mut Self::Target {
557 &mut self.connection
558 }
559}
560
561impl<C: SqlExecutor> AsRef<C> for Connection<C> {
562 fn as_ref(&self) -> &C {
563 &self.connection
564 }
565}
566
567impl<C: SqlExecutor> AsMut<C> for Connection<C> {
568 fn as_mut(&mut self) -> &mut C {
569 &mut self.connection
570 }
571}
572
/// Name of the temporary table holding one row per tracked table id. The row's `updated`
/// column is set to `1` by the triggers.
const TRACKER_TABLE_NAME: &str = "rsqlite_watcher_version_tracker";

/// Trigger events installed on each tracked table. Each event is paired with the lowercase
/// suffix used in the generated trigger name.
const TRIGGER_LIST: [(&str, &str); 3] = [
    ("INSERT", "insert"),
    ("UPDATE", "update"),
    ("DELETE", "delete"),
];

/// Escapes `name` for use inside a backtick-quoted SQLite identifier by doubling every
/// embedded backtick. Without this, a backtick in a table name would end the quoted
/// identifier early and corrupt the statement.
#[inline]
fn escape_identifier(name: &str) -> String {
    name.replace('`', "``")
}

/// `CREATE TEMP TABLE IF NOT EXISTS` for the tracking table (`table_id` primary key plus the
/// `updated` flag column).
#[inline]
fn create_tracking_table_query() -> String {
    format!(
        "CREATE TEMP TABLE IF NOT EXISTS `{TRACKER_TABLE_NAME}` (table_id INTEGER PRIMARY KEY, updated INTEGER)"
    )
}
/// Deletes every row of the tracking table.
#[inline]
fn empty_tracking_table_query() -> String {
    format!("DELETE FROM `{TRACKER_TABLE_NAME}`")
}
/// Drops the tracking table if it exists.
#[inline]
fn drop_tracking_table_query() -> String {
    format!("DROP TABLE IF EXISTS `{TRACKER_TABLE_NAME}`")
}

/// Builds a `CREATE TEMP TRIGGER IF NOT EXISTS` statement. The trigger fires after `trigger`
/// (`INSERT`/`UPDATE`/`DELETE`) on `table_name` and sets `updated=1` on the tracking-table
/// row whose `table_id` is `table_id`.
///
/// `table_name` is escaped both in the trigger's name and in the `ON` clause. The trigger
/// name is `<tracker>_trigger_<table>_<trigger_name>`, matching `drop_trigger_query`.
#[inline]
fn create_trigger_query(
    table_name: &str,
    trigger: &str,
    trigger_name: &str,
    table_id: usize,
) -> String {
    let table_name = escape_identifier(table_name);
    format!(
        r"
CREATE TEMP TRIGGER IF NOT EXISTS `{TRACKER_TABLE_NAME}_trigger_{table_name}_{trigger_name}` AFTER {trigger} ON `{table_name}`
BEGIN
    UPDATE `{TRACKER_TABLE_NAME}` SET updated=1 WHERE table_id={table_id};
END
    "
    )
}

/// Inserts the tracking row for table `id` with its `updated` flag cleared.
#[inline]
fn insert_table_id_into_tracking_table_query(id: usize) -> String {
    format!("INSERT INTO `{TRACKER_TABLE_NAME}` VALUES ({id},0)")
}

/// Drops the trigger created by `create_trigger_query` for `table_name` / `trigger_name`, if
/// it exists. `table_name` is escaped the same way so the names match.
#[inline]
fn drop_trigger_query(table_name: &str, trigger_name: &str) -> String {
    let table_name = escape_identifier(table_name);
    format!("DROP TRIGGER IF EXISTS `{TRACKER_TABLE_NAME}_trigger_{table_name}_{trigger_name}`")
}

/// Deletes the tracking row for `table_id`.
#[inline]
fn remove_table_id_from_tracking_table_query(table_id: usize) -> String {
    format!("DELETE FROM `{TRACKER_TABLE_NAME}` WHERE table_id={table_id}")
}

/// Selects the ids of tables whose `updated` flag is set.
#[inline]
fn select_updated_tables_query() -> String {
    format!("SELECT table_id FROM `{TRACKER_TABLE_NAME}` WHERE updated=1")
}

/// Clears every set `updated` flag.
#[inline]
fn reset_updated_tables_query() -> String {
    format!("UPDATE `{TRACKER_TABLE_NAME}` SET updated=0 WHERE updated=1")
}
637
638fn create_triggers(table: &str, id: usize) -> Vec<SqlExecuteStatement<String>> {
640 TRIGGER_LIST
641 .iter()
642 .map(|(trigger, trigger_name)| {
643 let query = create_trigger_query(table, trigger, trigger_name, id);
644 SqlExecuteStatement::new(query)
645 })
646 .chain(std::iter::once_with(|| {
647 let query = insert_table_id_into_tracking_table_query(id);
648 SqlExecuteStatement::new(query)
649 }))
650 .collect()
651}
652
653fn drop_triggers(table: &str, id: usize) -> Vec<SqlExecuteStatement<String>> {
655 TRIGGER_LIST
656 .iter()
657 .map(|(_, trigger_name)| {
658 let query = drop_trigger_query(table, trigger_name);
659 SqlExecuteStatement::new(query)
660 })
661 .chain(std::iter::once_with(|| {
662 let query = remove_table_id_from_tracking_table_query(id);
663 SqlExecuteStatement::new(query)
664 }))
665 .collect()
666}
667
668struct ConcludeStateChangeStatement<'s> {
670 state: &'s mut State,
671 tracked_tables: FixedBitSet,
672 new_version: u64,
673}
674
675impl Sealed for ConcludeStateChangeStatement<'_> {}
676impl Statement for ConcludeStateChangeStatement<'_> {
677 type Output = ();
678 fn execute<S: SqlExecutor>(self, _: &S) -> Result<Self::Output, S::Error> {
679 self.state
680 .apply_sync_changes(self.tracked_tables, self.new_version);
681 Ok(())
682 }
683
684 fn execute_mut<S: SqlExecutorMut>(self, _: &mut S) -> Result<Self::Output, S::Error> {
685 self.state
686 .apply_sync_changes(self.tracked_tables, self.new_version);
687 Ok(())
688 }
689
690 async fn execute_async<S: SqlExecutorAsync>(self, _: &mut S) -> Result<Self::Output, S::Error> {
691 self.state
692 .apply_sync_changes(self.tracked_tables, self.new_version);
693 Ok(())
694 }
695}
696
697struct CalculateWatcherUpdatesStatement<'s> {
699 state: &'s State,
700}
701
702impl StatementWithInput for CalculateWatcherUpdatesStatement<'_> {
703 type Input = Vec<u32>;
704 type Output = FixedBitSet;
705
706 fn execute<S: SqlExecutor>(self, _: &S, input: Self::Input) -> Result<Self::Output, S::Error> {
707 Ok(self.state.prepare_watcher_changes(input))
708 }
709 fn execute_mut<S: SqlExecutorMut>(
710 self,
711 _: &mut S,
712 input: Self::Input,
713 ) -> Result<Self::Output, S::Error> {
714 Ok(self.state.prepare_watcher_changes(input))
715 }
716 async fn execute_async<S: SqlExecutorAsync>(
717 self,
718 _: &mut S,
719 input: Self::Input,
720 ) -> Result<Self::Output, S::Error> {
721 Ok(self.state.prepare_watcher_changes(input))
722 }
723}
724
725struct PublishWatcherChangesStatement<'w>(&'w Watcher);
727
728impl Sealed for PublishWatcherChangesStatement<'_> {}
729
730impl StatementWithInput for PublishWatcherChangesStatement<'_> {
731 type Input = FixedBitSet;
732 type Output = ();
733
734 fn execute<S: SqlExecutor>(self, _: &S, input: Self::Input) -> Result<Self::Output, S::Error> {
735 self.0.publish_changes(input);
736 Ok(())
737 }
738 fn execute_mut<S: SqlExecutorMut>(
739 self,
740 _: &mut S,
741 input: Self::Input,
742 ) -> Result<Self::Output, S::Error> {
743 self.0.publish_changes(input);
744 Ok(())
745 }
746 async fn execute_async<S: SqlExecutorAsync>(
747 self,
748 _: &mut S,
749 input: Self::Input,
750 ) -> Result<Self::Output, S::Error> {
751 self.0.publish_changes_async(input).await;
752 Ok(())
753 }
754}
755
756impl Sealed for SqlReadTableIdsStatement {}
757struct SqlReadTableIdsStatement;
758impl Statement for SqlReadTableIdsStatement {
759 type Output = Vec<u32>;
760 fn execute<S: SqlExecutor>(self, connection: &S) -> Result<Self::Output, S::Error> {
761 connection.sql_query_values(&select_updated_tables_query())
762 }
763 fn execute_mut<S: SqlExecutorMut>(self, connection: &mut S) -> Result<Self::Output, S::Error> {
764 connection.sql_query_values(&select_updated_tables_query())
765 }
766 async fn execute_async<S: SqlExecutorAsync>(
767 self,
768 connection: &mut S,
769 ) -> Result<Self::Output, S::Error> {
770 connection
771 .sql_query_values(&select_updated_tables_query())
772 .await
773 }
774}
775
776struct MaybeResetResultsQuery;
780impl StatementWithInput for MaybeResetResultsQuery {
781 type Input = FixedBitSet;
782 type Output = FixedBitSet;
783
784 fn execute<S: SqlExecutor>(
785 self,
786 connection: &S,
787 input: Self::Input,
788 ) -> Result<Self::Output, S::Error> {
789 if !input.is_clear() {
790 connection.sql_execute(&reset_updated_tables_query())?;
792 }
793 Ok(input)
794 }
795 fn execute_mut<S: SqlExecutorMut>(
796 self,
797 connection: &mut S,
798 input: Self::Input,
799 ) -> Result<Self::Output, S::Error> {
800 if !input.is_clear() {
801 connection.sql_execute(&reset_updated_tables_query())?;
803 }
804 Ok(input)
805 }
806 async fn execute_async<S: SqlExecutorAsync>(
807 self,
808 connection: &mut S,
809 input: Self::Input,
810 ) -> Result<Self::Output, S::Error> {
811 if !input.is_clear() {
812 connection
814 .sql_execute(&reset_updated_tables_query())
815 .await?;
816 }
817 Ok(input)
818 }
819}
820
#[cfg(test)]
mod test {
    use crate::connection::State;
    use crate::watcher::tests::new_test_observer;
    use crate::watcher::{ObservedTableOp, TableObserver, Watcher};
    use std::collections::BTreeSet;
    use std::sync::Mutex;
    use std::sync::mpsc::{Receiver, SyncSender};

    /// Observer that checks each `on_tables_changed` notification against a queue of expected
    /// table sets. After each successful check it signals the paired `Receiver`.
    pub struct TestObserver {
        // Remaining expectations, stored reversed so `pop()` yields them in the given order.
        expected: Mutex<Vec<BTreeSet<String>>>,
        tables: Vec<String>,
        // Zero-capacity (rendezvous) channel: `send` blocks until the receiver takes the value.
        checked_channel: SyncSender<()>,
    }

    impl TestObserver {
        /// Creates an observer for `tables` that expects the given notifications in order.
        /// Also returns the receiver that is signalled after each checked notification.
        pub fn new(
            tables: Vec<String>,
            expected: impl IntoIterator<Item = BTreeSet<String>>,
        ) -> (Self, Receiver<()>) {
            let (sender, receiver) = std::sync::mpsc::sync_channel::<()>(0);
            let mut expected = expected.into_iter().collect::<Vec<_>>();
            // Reverse so that `pop()` returns the first expectation first.
            expected.reverse();
            (
                Self {
                    expected: Mutex::new(expected),
                    tables,
                    checked_channel: sender,
                },
                receiver,
            )
        }
    }

    impl TableObserver for TestObserver {
        fn tables(&self) -> Vec<String> {
            self.tables.clone()
        }

        fn on_tables_changed(&self, tables: &BTreeSet<String>) {
            // Panics if more notifications arrive than were expected.
            let expected = self.expected.lock().unwrap().pop().unwrap();
            assert_eq!(*tables, expected);
            self.checked_channel.send(()).unwrap();
        }
    }

    /// Walks a `State` through adding and removing observers. Each step checks that
    /// `should_sync` / `calculate_sync_changes` report exactly the expected table additions
    /// and removals.
    #[test]
    fn connection_state() {
        let service = Watcher::new().unwrap();

        let observer_1 = new_test_observer(["foo", "bar"]);
        let observer_2 = new_test_observer(["bar"]);
        let observer_3 = new_test_observer(["bar", "omega"]);

        let mut local_state = State::new();

        // A fresh state is already in sync with a fresh watcher.
        assert!(local_state.should_sync(&service).is_none());
        let observer_id_1 = service.add_observer(observer_1).unwrap();
        let foo_table_id = service.get_table_id("foo").unwrap();
        let bar_table_id = service.get_table_id("bar").unwrap();
        {
            let new_version = local_state
                .should_sync(&service)
                .expect("Should have new version");
            let (tracker, ops) = local_state
                .calculate_sync_changes(&service)
                .expect("must have changes");
            assert!(tracker[bar_table_id]);
            assert!(tracker[foo_table_id]);
            assert_eq!(ops.len(), 2);
            // NOTE(review): relies on the watcher emitting `bar` before `foo`; presumably
            // ordered by table name — confirm against `Watcher::calculate_sync_changes`.
            assert_eq!(
                ops[0],
                ObservedTableOp::Add("bar".to_string(), bar_table_id)
            );
            assert_eq!(
                ops[1],
                ObservedTableOp::Add("foo".to_string(), foo_table_id)
            );

            local_state.apply_sync_changes(tracker, new_version);
        }

        // `bar` is already observed, so adding observer_2 must not bump the table version.
        let observer_id_2 = service.add_observer(observer_2).unwrap();
        assert!(local_state.should_sync(&service).is_none());

        // observer_3 introduces `omega`, the only new table.
        let observer_id_3 = service.add_observer(observer_3).unwrap();
        let omega_table_id = service.get_table_id("omega").unwrap();
        {
            let new_version = local_state
                .should_sync(&service)
                .expect("Should have new version");
            let (tracker, ops) = local_state
                .calculate_sync_changes(&service)
                .expect("must have changes");
            assert!(tracker[foo_table_id]);
            assert!(tracker[bar_table_id]);
            assert!(tracker[omega_table_id]);
            assert_eq!(ops.len(), 1);
            assert_eq!(
                ops[0],
                ObservedTableOp::Add("omega".to_string(), omega_table_id)
            );

            local_state.apply_sync_changes(tracker, new_version);
        }

        // `bar` is still observed by others, so removing observer_2 changes nothing.
        service.remove_observer(observer_id_2).unwrap();
        assert!(local_state.should_sync(&service).is_none());

        // Removing observer_3 leaves `omega` unobserved.
        service.remove_observer(observer_id_3).unwrap();
        {
            let new_version = local_state
                .should_sync(&service)
                .expect("Should have new version");
            let (tracker, ops) = local_state
                .calculate_sync_changes(&service)
                .expect("must have changes");
            assert!(tracker[foo_table_id]);
            assert!(tracker[bar_table_id]);
            assert!(!tracker[omega_table_id]);
            assert_eq!(ops.len(), 1);
            assert_eq!(
                ops[0],
                ObservedTableOp::Remove("omega".to_string(), omega_table_id)
            );

            local_state.apply_sync_changes(tracker, new_version);
        }

        // Removing the last observer removes every remaining table.
        service.remove_observer(observer_id_1).unwrap();
        {
            let new_version = local_state
                .should_sync(&service)
                .expect("Should have new version");
            let (tracker, ops) = local_state
                .calculate_sync_changes(&service)
                .expect("must have changes");
            assert!(!tracker[foo_table_id]);
            assert!(!tracker[bar_table_id]);
            assert!(!tracker[omega_table_id]);
            assert_eq!(ops.len(), 2);
            assert_eq!(
                ops[1],
                ObservedTableOp::Remove("foo".to_string(), foo_table_id)
            );
            assert_eq!(
                ops[0],
                ObservedTableOp::Remove("bar".to_string(), bar_table_id)
            );

            local_state.apply_sync_changes(tracker, new_version);
        }
    }
}