Skip to main content

sqlite_watcher/
connection.rs

1use crate::statement::{
2    BatchQuery, Sealed, SqlExecuteStatement, SqlTransactionStatement, Statement, StatementWithInput,
3};
4use crate::watcher::{ObservedTableOp, Watcher};
5use fixedbitset::FixedBitSet;
6use std::error::Error;
7use std::future::Future;
8use std::ops::{Deref, DerefMut};
9use std::sync::Arc;
10use tracing::{debug, trace, warn};
11
12#[cfg(feature = "rusqlite")]
13pub mod rusqlite;
14
15#[cfg(feature = "sqlx")]
16pub mod sqlx;
17
18#[cfg(feature = "diesel")]
19pub mod diesel;
20
/// Synchronous access to a sqlite connection through a shared reference.
///
/// The tracking machinery needs this to install the temporary tables and triggers it uses to
/// record which tables were modified.
pub trait SqlExecutor {
    /// Error reported when a query or statement fails.
    type Error: Error;

    /// Run `query`, a statement that produces no rows.
    ///
    /// # Errors
    ///
    /// Should return an error if the statement could not be executed.
    fn sql_execute(&self, query: &str) -> Result<(), Self::Error>;

    /// Run `query`, which yields zero or more rows made of a single `u32` column, and collect
    /// those values.
    ///
    /// # Errors
    ///
    /// Should return an error if the query could not be executed.
    fn sql_query_values(&self, query: &str) -> Result<Vec<u32>, Self::Error>;
}
41
/// Variant of [`SqlExecutor`] for connections that can only run queries through a mutable
/// reference.
pub trait SqlExecutorMut {
    /// Error reported when a query or statement fails.
    type Error: Error;

    /// Run `query`, a statement that produces no rows.
    ///
    /// # Errors
    ///
    /// Should return an error if the statement could not be executed.
    fn sql_execute(&mut self, query: &str) -> Result<(), Self::Error>;

    /// Run `query`, which yields zero or more rows made of a single `u32` column, and collect
    /// those values.
    ///
    /// # Errors
    ///
    /// Should return an error if the query could not be executed.
    fn sql_query_values(&mut self, query: &str) -> Result<Vec<u32>, Self::Error>;
}
60
61// Automatically derive SqlExecutorMut for any implementation of SqlExecutor.
62impl<T: SqlExecutor> SqlExecutorMut for T {
63    type Error = T::Error;
64
65    fn sql_query_values(&mut self, query: &str) -> Result<Vec<u32>, Self::Error> {
66        SqlExecutor::sql_query_values(self, query)
67    }
68
69    fn sql_execute(&mut self, query: &str) -> Result<(), Self::Error> {
70        SqlExecutor::sql_execute(self, query)
71    }
72}
73
/// Asynchronous counterpart of [`SqlExecutor`] for connections whose queries must be awaited.
///
/// Like the synchronous traits, it is used to install the temporary tables and triggers
/// needed to track changes. Implementors, their error type and the returned futures must all
/// be [`Send`].
pub trait SqlExecutorAsync: Send {
    /// Error reported when a query or statement fails.
    type Error: Error + Send;

    /// Run `query`, a statement that produces no rows.
    ///
    /// # Errors
    ///
    /// Should return an error if the statement could not be executed.
    fn sql_execute(&mut self, query: &str) -> impl Future<Output = Result<(), Self::Error>> + Send;

    /// Run `query`, which yields zero or more rows made of a single `u32` column, and resolve
    /// to the collected values.
    ///
    /// # Errors
    ///
    /// Should return an error if the query could not be executed.
    fn sql_query_values(
        &mut self,
        query: &str,
    ) -> impl Future<Output = Result<Vec<u32>, Self::Error>> + Send;
}
97
98/// Building block to provide tracking capabilities to any type of sqlite connection which
99/// implements the [`SqlExecutor`] trait.
100///
101/// # Initialization
102///
103/// It's recommended to call [`State::set_pragmas()`] to enable in memory temporary tables and recursive
104/// triggers. If your connection already has this set up, this can be skipped.
105///
106/// Next you need to create the infrastructure to track changes. This can be accomplished with
107/// [`State::start_tracking()`].
108///
109/// # Tracking changes
110///
111/// To make sure we only track required tables always call [`State::sync_tables()`] before a query/statement
112/// or a transaction.
113///
114/// When the query/statement or transaction are completed, call [`State::publish_changes()`] to check
115/// which tables have been modified and send this information to the watcher.
116///
117/// # Disable Tracking
118///
119/// If you wish to remove all the tracking infrastructure from a connection on which
120/// [`State::start_tracking()`] was called, then call [`State::stop_tracking()`].
121///
122/// # See Also
123///
124/// The [`Connection`] type provided by this crate provides an example integration implementation.
125#[derive(Debug, Default)]
126pub struct State {
127    tracked_tables: FixedBitSet,
128    last_sync_version: u64,
129}
130
131impl State {
132    /// Enable required pragmas for execution.
133    #[must_use]
134    pub fn set_pragmas() -> impl Statement {
135        SqlExecuteStatement::new("PRAGMA temp_store = MEMORY")
136            .then(SqlExecuteStatement::new("PRAGMA recursive_triggers='ON'"))
137    }
138
139    /// Prepare the `connection` for tracking.
140    ///
141    /// This will create the temporary table used to track change.
142    #[must_use]
143    #[tracing::instrument(level = tracing::Level::DEBUG)]
144    pub fn start_tracking() -> impl Statement {
145        // create tracking table and cleanup previous data if re-used from a connection pool.
146        SqlTransactionStatement::temporary(
147            SqlExecuteStatement::new(create_tracking_table_query())
148                .then(SqlExecuteStatement::new(empty_tracking_table_query())),
149        )
150        .spanned_in_current()
151    }
152
153    /// Remove all triggers and the tracking table from `connection`.
154    //
155    /// # Errors
156    ///
157    /// Returns error if the initialization failed.
158    #[tracing::instrument(level = tracing::Level::DEBUG, skip_all)]
159    pub fn stop_tracking(&self, watcher: &Watcher) -> impl Statement {
160        let tables = watcher.observed_tables();
161        SqlTransactionStatement::temporary(
162            BatchQuery::new(
163                tables
164                    .into_iter()
165                    .enumerate()
166                    .flat_map(|(id, table_name)| drop_triggers(&table_name, id)),
167            )
168            .then(SqlExecuteStatement::new(drop_tracking_table_query())),
169        )
170        .spanned_in_current()
171    }
172
173    /// Create a new instance without initializing any connection.
174    #[must_use]
175    pub fn new() -> Self {
176        Self {
177            tracked_tables: FixedBitSet::new(),
178            last_sync_version: 0,
179        }
180    }
181
182    /// Synchronize the table list from the watcher.
183    ///
184    /// This method will create new triggers for tables that are not being watched over this
185    /// connection and remove triggers for tables that are no longer observed by the watcher.
186    ///
187    /// # Errors
188    ///
189    /// Returns error if creation or removal of triggers failed.
190    #[tracing::instrument(level=tracing::Level::DEBUG, skip(self, watcher))]
191    pub fn sync_tables(&mut self, watcher: &Watcher) -> Option<impl Statement + '_> {
192        let new_version = self.should_sync(watcher)?;
193
194        debug!("Syncing tables from observer");
195        let Some((new_tracker_state, tracker_changes)) = self.calculate_sync_changes(watcher)
196        else {
197            debug!("No changes");
198            return None;
199        };
200
201        let mut queries = BatchQuery::new([]);
202
203        if self.tracked_tables.is_empty() {
204            // It is possible on certain circumstances that if a connection can have leftover
205            // tracking data that is not cleared. To make sure this is reset, we force empty
206            // the table if we detect that we are not watching any tables at the moment.
207            queries.push(SqlExecuteStatement::new(empty_tracking_table_query()));
208        }
209        for change in tracker_changes {
210            match change {
211                ObservedTableOp::Add(table_name, id) => {
212                    debug!("Add watcher for table {table_name} id={id}");
213                    queries.extend(create_triggers(&table_name, id));
214                }
215                ObservedTableOp::Remove(table_name, id) => {
216                    debug!("Remove watcher for table {table_name}");
217                    queries.extend(drop_triggers(&table_name, id));
218                }
219            }
220        }
221
222        let tx = SqlTransactionStatement::temporary(queries);
223        Some(
224            tx.then(ConcludeStateChangeStatement {
225                state: self,
226                tracked_tables: new_tracker_state,
227                new_version,
228            })
229            .spanned_in_current(),
230        )
231    }
232
233    /// Check the tracking table and report finding to the [Watcher].
234    ///
235    /// The table where the changes are tracked is read and reset. Any
236    /// table that has been modified will be communicated to the [Watcher], which in turn
237    /// will notify the respective [TableObserver].
238    ///
239    /// # Errors
240    ///
241    /// Returns error if we failed to read from the temporary tables.
242    ///
243    /// [Watcher]: `crate::watcher::Watcher`
244    /// [TableObserver]: `crate::watcher::TableObserver`
245    #[tracing::instrument(level=tracing::Level::DEBUG, skip(self, watcher))]
246    pub fn publish_changes(&self, watcher: &Watcher) -> impl Statement {
247        SqlReadTableIdsStatement
248            .pipe(CalculateWatcherUpdatesStatement { state: self })
249            .pipe(MaybeResetResultsQuery)
250            .pipe(PublishWatcherChangesStatement(watcher))
251            .spanned_in_current()
252    }
253
254    fn prepare_watcher_changes(&self, modified_table_ids: Vec<u32>) -> FixedBitSet {
255        trace!("Preparing watcher changes");
256        let mut result = FixedBitSet::with_capacity(self.tracked_tables.len());
257        result.grow(self.tracked_tables.len());
258        for id in modified_table_ids {
259            let id = id as usize;
260            debug!("Table {} has been modified", id);
261            if id >= result.len() {
262                warn!(
263                    "Received update for table {id}, but only tracking {} tables",
264                    self.tracked_tables.len(),
265                );
266                // We need to grow on the index + 1.
267                result.grow(id + 1);
268            }
269            result.set(id, true);
270        }
271
272        result
273    }
274
275    fn should_sync(&self, watcher: &Watcher) -> Option<u64> {
276        let service_version = watcher.tables_version();
277        if service_version == self.last_sync_version {
278            None
279        } else {
280            Some(service_version)
281        }
282    }
283
284    /// Determine which tables should start and/or stop being watched.
285    fn calculate_sync_changes(
286        &self,
287        watcher: &Watcher,
288    ) -> Option<(FixedBitSet, Vec<ObservedTableOp>)> {
289        trace!("Calculating sync changes");
290        let (new_tracker_state, tracker_changes) =
291            watcher.calculate_sync_changes(&self.tracked_tables);
292
293        if tracker_changes.is_empty() {
294            return None;
295        }
296
297        Some((new_tracker_state, tracker_changes))
298    }
299
300    /// Once we are satisfied with the changes, apply the new state.
301    fn apply_sync_changes(&mut self, new_tracker_state: FixedBitSet, new_version: u64) {
302        // Update local tracker bitset
303        trace!("Applying sync changes");
304        self.tracked_tables = new_tracker_state;
305        self.last_sync_version = new_version;
306    }
307}
308
309/// Connection abstraction that provides on possible implementation which uses the building
310/// blocks ([`State`]) provided by this crate.
311///
312/// For simplicity, it takes ownership of an existing type which implements [`SqlExecutor`] and
313/// initializes all the tracking infrastructure. The original type can still be accessed as
314/// [`Connection`] implements both [`Deref`] and [`DerefMut`].
315///
316/// # Remarks
317///
318/// To make sure all changes are capture, it's recommended to always call
319/// [`Connection::sync_watcher_tables()`]
320/// before any query/statement or transaction.
321///
322/// # Example
323///
324/// ## Single Query/Statement
325///
326/// ```rust
327/// use sqlite_watcher::connection::Connection;
328/// use sqlite_watcher::connection::SqlExecutor;
329/// use sqlite_watcher::watcher::Watcher;
330///
331/// pub fn track_changes<C:SqlExecutor>(connection: C) {
332///     let watcher = Watcher::new().unwrap();
333///     let mut connection = Connection::new(connection, watcher).unwrap();
334///
335///     // Sync tables so we are up to date.
336///     connection.sync_watcher_tables().unwrap();
337///
338///     connection.sql_execute("sql query here").unwrap();
339///
340///     // Publish changes to the watcher
341///     connection.publish_watcher_changes().unwrap();
342/// }
343/// ```
344///
345/// ## Transaction
346///
347/// ```rust
348/// use sqlite_watcher::connection::Connection;
349/// use sqlite_watcher::connection::{SqlExecutor};
350/// use sqlite_watcher::watcher::Watcher;
351///
352/// pub fn track_changes<C:SqlExecutor>(connection: C) {
353///     let watcher = Watcher::new().unwrap();
354///     let mut connection = Connection::new(connection, watcher).unwrap();
355///
356///     // Sync tables so we are up to date.
357///     connection.sync_watcher_tables().unwrap();
358///
359///     // Start a transaction
360///     connection.sql_execute("sql query here").unwrap();
361///     connection.sql_execute("sql query here").unwrap();
362///     // Commit transaction
363///
364///     // Publish changes to the watcher
365///     connection.publish_watcher_changes().unwrap();
366/// }
367/// ```
368pub struct Connection<C: SqlExecutorMut> {
369    state: State,
370    watcher: Arc<Watcher>,
371    #[allow(clippy::struct_field_names)]
372    connection: C,
373}
374impl<C: SqlExecutorMut> Connection<C> {
375    /// Create a new connection with `connection` and `watcher`.
376    ///
377    /// See [`State::start_tracking()`] for more information about initialization.
378    ///
379    /// # Errors
380    ///
381    /// Returns error if the initialization failed.
382    pub fn new(mut connection: C, watcher: Arc<Watcher>) -> Result<Self, C::Error> {
383        let state = State::new();
384        State::set_pragmas().execute_mut(&mut connection)?;
385        State::start_tracking().execute_mut(&mut connection)?;
386        Ok(Self {
387            state,
388            watcher,
389            connection,
390        })
391    }
392
393    /// Sync tables from the [`Watcher`] and update tracking infrastructure.
394    ///
395    /// See [`State::sync_tables()`] for more information.
396    ///
397    /// # Errors
398    ///
399    /// Returns error if we failed to sync the changes to the database.
400    pub fn sync_watcher_tables(&mut self) -> Result<(), C::Error> {
401        self.state
402            .sync_tables(&self.watcher)
403            .execute_mut(&mut self.connection)?;
404        Ok(())
405    }
406
407    /// Check if any tables have changed and notify the [`Watcher`]
408    ///
409    /// See [`State::publish_changes()`] for more information.
410    ///
411    /// It is recommended to call this method
412    ///
413    /// # Errors
414    ///
415    /// Returns error if we failed to check for changes.
416    pub fn publish_watcher_changes(&mut self) -> Result<(), C::Error> {
417        self.state
418            .publish_changes(&self.watcher)
419            .execute_mut(&mut self.connection)?;
420        Ok(())
421    }
422
423    /// Disable all tracking on this connection.
424    ///
425    /// See [`State::stop_tracking`] for more details.
426    ///
427    /// # Errors
428    ///
429    /// Returns error if the queries failed.
430    pub fn stop_tracking(&mut self) -> Result<(), C::Error> {
431        self.state
432            .stop_tracking(&self.watcher)
433            .execute_mut(&mut self.connection)?;
434        Ok(())
435    }
436
437    /// Consume the current connection and take ownership of the real sql connection.
438    ///
439    /// # Remarks
440    ///
441    /// This does not stop the tracking infrastructure enabled on the connection.
442    /// Use [`Self::stop_tracking()`] to disable it first.
443    pub fn take(self) -> C {
444        self.connection
445    }
446}
447
448/// Same as [`Connection`] but with an async executor.
449#[allow(clippy::module_name_repetitions)]
450pub struct ConnectionAsync<C: SqlExecutorAsync> {
451    state: State,
452    watcher: Arc<Watcher>,
453    connection: C,
454}
455impl<C: SqlExecutorAsync> ConnectionAsync<C> {
456    /// Create a new connection with `connection` and `watcher`.
457    ///
458    /// See [`State::start_tracking()`] for more information about initialization.
459    ///
460    /// # Errors
461    ///
462    /// Returns error if the initialization failed.
463    pub async fn new(mut connection: C, watcher: Arc<Watcher>) -> Result<Self, C::Error> {
464        let state = State::new();
465        State::set_pragmas().execute_async(&mut connection).await?;
466        State::start_tracking()
467            .execute_async(&mut connection)
468            .await?;
469        Ok(Self {
470            state,
471            watcher,
472            connection,
473        })
474    }
475
476    /// See [`Connection::sync_watcher_tables`] for more details.
477    ///
478    /// # Errors
479    ///
480    /// Returns error if we failed to sync the changes to the database.
481    pub async fn sync_watcher_tables(&mut self) -> Result<(), C::Error> {
482        self.state
483            .sync_tables(&self.watcher)
484            .execute_async(&mut self.connection)
485            .await?;
486        Ok(())
487    }
488
489    /// See [`Connection::publish_watcher_changes`] for more details.
490    ///
491    /// # Errors
492    ///
493    /// Returns error if we failed to check for changes.
494    pub async fn publish_watcher_changes(&mut self) -> Result<(), C::Error> {
495        self.state
496            .publish_changes(&self.watcher)
497            .execute_async(&mut self.connection)
498            .await?;
499        Ok(())
500    }
501
502    /// See [`Connection::stop_tracking`] for more details.
503    ///
504    /// # Errors
505    ///
506    /// Returns error if the queries failed.
507    pub async fn stop_tracking(&mut self) -> Result<(), C::Error> {
508        self.state
509            .stop_tracking(&self.watcher)
510            .execute_async(&mut self.connection)
511            .await?;
512        Ok(())
513    }
514
515    /// Consume the current connection and take ownership of the real sql connection.
516    ///
517    /// # Remarks
518    ///
519    /// This does not stop the tracking infrastructure enabled on the connection.
520    /// Use [`Self::stop_tracking()`] to disable it first.
521    pub fn take(self) -> C {
522        self.connection
523    }
524}
525
526impl<C: SqlExecutorAsync> Deref for ConnectionAsync<C> {
527    type Target = C;
528
529    fn deref(&self) -> &Self::Target {
530        &self.connection
531    }
532}
533
534impl<C: SqlExecutorAsync> DerefMut for ConnectionAsync<C> {
535    fn deref_mut(&mut self) -> &mut Self::Target {
536        &mut self.connection
537    }
538}
539
540impl<C: SqlExecutorAsync> AsRef<C> for ConnectionAsync<C> {
541    fn as_ref(&self) -> &C {
542        &self.connection
543    }
544}
545
546impl<C: SqlExecutorAsync> AsMut<C> for ConnectionAsync<C> {
547    fn as_mut(&mut self) -> &mut C {
548        &mut self.connection
549    }
550}
551
552impl<C: SqlExecutor> Deref for Connection<C> {
553    type Target = C;
554
555    fn deref(&self) -> &Self::Target {
556        &self.connection
557    }
558}
559
560impl<C: SqlExecutor> DerefMut for Connection<C> {
561    fn deref_mut(&mut self) -> &mut Self::Target {
562        &mut self.connection
563    }
564}
565
566impl<C: SqlExecutor> AsRef<C> for Connection<C> {
567    fn as_ref(&self) -> &C {
568        &self.connection
569    }
570}
571
572impl<C: SqlExecutor> AsMut<C> for Connection<C> {
573    fn as_mut(&mut self) -> &mut C {
574        &mut self.connection
575    }
576}
577
/// Name of the temporary table holding one `(table_id, updated)` row per tracked table.
const TRACKER_TABLE_NAME: &str = "rsqlite_watcher_version_tracker";

/// Trigger events installed on every tracked table, as
/// `(SQL event keyword, trigger-name suffix)` pairs.
const TRIGGER_LIST: [(&str, &str); 3] = [
    ("INSERT", "insert"),
    ("UPDATE", "update"),
    ("DELETE", "delete"),
];
585
586#[inline]
587fn create_tracking_table_query() -> String {
588    format!(
589        "CREATE TEMP TABLE IF NOT EXISTS `{TRACKER_TABLE_NAME}` (table_id INTEGER PRIMARY KEY, updated INTEGER)"
590    )
591}
592#[inline]
593fn empty_tracking_table_query() -> String {
594    format!("DELETE FROM `{TRACKER_TABLE_NAME}`")
595}
596#[inline]
597fn drop_tracking_table_query() -> String {
598    format!("DROP TABLE IF EXISTS `{TRACKER_TABLE_NAME}`")
599}
600
601#[inline]
602fn create_trigger_query(
603    table_name: &str,
604    trigger: &str,
605    trigger_name: &str,
606    table_id: usize,
607) -> String {
608    format!(
609        r"
610CREATE TEMP TRIGGER IF NOT EXISTS `{TRACKER_TABLE_NAME}_trigger_{table_name}_{trigger_name}` AFTER {trigger} ON `{table_name}`
611BEGIN
612    UPDATE  `{TRACKER_TABLE_NAME}` SET updated=1 WHERE table_id={table_id};
613END
614            "
615    )
616}
617
618#[inline]
619fn insert_table_id_into_tracking_table_query(id: usize) -> String {
620    format!("INSERT INTO `{TRACKER_TABLE_NAME}` VALUES ({id},0)")
621}
622
623#[inline]
624fn drop_trigger_query(table_name: &str, trigger_name: &str) -> String {
625    format!("DROP TRIGGER IF EXISTS `{TRACKER_TABLE_NAME}_trigger_{table_name}_{trigger_name}`")
626}
627
628#[inline]
629fn remove_table_id_from_tracking_table_query(table_id: usize) -> String {
630    format!("DELETE FROM `{TRACKER_TABLE_NAME}` WHERE table_id={table_id}")
631}
632
633#[inline]
634fn select_updated_tables_query() -> String {
635    format!("SELECT table_id  FROM `{TRACKER_TABLE_NAME}` WHERE updated=1")
636}
637
638#[inline]
639fn reset_updated_tables_query() -> String {
640    format!("UPDATE `{TRACKER_TABLE_NAME}` SET updated=0 WHERE updated=1")
641}
642
643/// Create tracking triggers for `table` with `id`.
644fn create_triggers(table: &str, id: usize) -> Vec<SqlExecuteStatement<String>> {
645    TRIGGER_LIST
646        .iter()
647        .map(|(trigger, trigger_name)| {
648            let query = create_trigger_query(table, trigger, trigger_name, id);
649            SqlExecuteStatement::new(query)
650        })
651        .chain(std::iter::once_with(|| {
652            let query = insert_table_id_into_tracking_table_query(id);
653            SqlExecuteStatement::new(query)
654        }))
655        .collect()
656}
657
658/// Remove tracking triggers for `table` with `id`.
659fn drop_triggers(table: &str, id: usize) -> Vec<SqlExecuteStatement<String>> {
660    TRIGGER_LIST
661        .iter()
662        .map(|(_, trigger_name)| {
663            let query = drop_trigger_query(table, trigger_name);
664            SqlExecuteStatement::new(query)
665        })
666        .chain(std::iter::once_with(|| {
667            let query = remove_table_id_from_tracking_table_query(id);
668            SqlExecuteStatement::new(query)
669        }))
670        .collect()
671}
672
673/// Apply the new tracked table state to a [`State`].
674struct ConcludeStateChangeStatement<'s> {
675    state: &'s mut State,
676    tracked_tables: FixedBitSet,
677    new_version: u64,
678}
679
680impl Sealed for ConcludeStateChangeStatement<'_> {}
681impl Statement for ConcludeStateChangeStatement<'_> {
682    type Output = ();
683    fn execute<S: SqlExecutor>(self, _: &S) -> Result<Self::Output, S::Error> {
684        self.state
685            .apply_sync_changes(self.tracked_tables, self.new_version);
686        Ok(())
687    }
688
689    fn execute_mut<S: SqlExecutorMut>(self, _: &mut S) -> Result<Self::Output, S::Error> {
690        self.state
691            .apply_sync_changes(self.tracked_tables, self.new_version);
692        Ok(())
693    }
694
695    async fn execute_async<S: SqlExecutorAsync>(self, _: &mut S) -> Result<Self::Output, S::Error> {
696        self.state
697            .apply_sync_changes(self.tracked_tables, self.new_version);
698        Ok(())
699    }
700}
701
702/// Calculate what the changes to be sent to the watcher.
703struct CalculateWatcherUpdatesStatement<'s> {
704    state: &'s State,
705}
706
707impl StatementWithInput for CalculateWatcherUpdatesStatement<'_> {
708    type Input = Vec<u32>;
709    type Output = FixedBitSet;
710
711    fn execute<S: SqlExecutor>(self, _: &S, input: Self::Input) -> Result<Self::Output, S::Error> {
712        Ok(self.state.prepare_watcher_changes(input))
713    }
714    fn execute_mut<S: SqlExecutorMut>(
715        self,
716        _: &mut S,
717        input: Self::Input,
718    ) -> Result<Self::Output, S::Error> {
719        Ok(self.state.prepare_watcher_changes(input))
720    }
721    async fn execute_async<S: SqlExecutorAsync>(
722        self,
723        _: &mut S,
724        input: Self::Input,
725    ) -> Result<Self::Output, S::Error> {
726        Ok(self.state.prepare_watcher_changes(input))
727    }
728}
729
730/// Publish the changes to the watcher.
731struct PublishWatcherChangesStatement<'w>(&'w Watcher);
732
733impl Sealed for PublishWatcherChangesStatement<'_> {}
734
735impl StatementWithInput for PublishWatcherChangesStatement<'_> {
736    type Input = FixedBitSet;
737    type Output = ();
738
739    fn execute<S: SqlExecutor>(self, _: &S, input: Self::Input) -> Result<Self::Output, S::Error> {
740        self.0.publish_changes(input);
741        Ok(())
742    }
743    fn execute_mut<S: SqlExecutorMut>(
744        self,
745        _: &mut S,
746        input: Self::Input,
747    ) -> Result<Self::Output, S::Error> {
748        self.0.publish_changes(input);
749        Ok(())
750    }
751    async fn execute_async<S: SqlExecutorAsync>(
752        self,
753        _: &mut S,
754        input: Self::Input,
755    ) -> Result<Self::Output, S::Error> {
756        self.0.publish_changes_async(input).await;
757        Ok(())
758    }
759}
760
761impl Sealed for SqlReadTableIdsStatement {}
762struct SqlReadTableIdsStatement;
763impl Statement for SqlReadTableIdsStatement {
764    type Output = Vec<u32>;
765    fn execute<S: SqlExecutor>(self, connection: &S) -> Result<Self::Output, S::Error> {
766        connection.sql_query_values(&select_updated_tables_query())
767    }
768    fn execute_mut<S: SqlExecutorMut>(self, connection: &mut S) -> Result<Self::Output, S::Error> {
769        connection.sql_query_values(&select_updated_tables_query())
770    }
771    async fn execute_async<S: SqlExecutorAsync>(
772        self,
773        connection: &mut S,
774    ) -> Result<Self::Output, S::Error> {
775        connection
776            .sql_query_values(&select_updated_tables_query())
777            .await
778    }
779}
780
781/// It is possible on certain circumstances that if a connection can have leftover
782/// tracking data that is not cleared. To make sure this is reset, we force empty
783/// the table if we detect that we are not watching any tables at the moment.
784struct MaybeResetResultsQuery;
785impl StatementWithInput for MaybeResetResultsQuery {
786    type Input = FixedBitSet;
787    type Output = FixedBitSet;
788
789    fn execute<S: SqlExecutor>(
790        self,
791        connection: &S,
792        input: Self::Input,
793    ) -> Result<Self::Output, S::Error> {
794        if !input.is_clear() {
795            // Reset updated values.
796            connection.sql_execute(&reset_updated_tables_query())?;
797        }
798        Ok(input)
799    }
800    fn execute_mut<S: SqlExecutorMut>(
801        self,
802        connection: &mut S,
803        input: Self::Input,
804    ) -> Result<Self::Output, S::Error> {
805        if !input.is_clear() {
806            // Reset updated values.
807            connection.sql_execute(&reset_updated_tables_query())?;
808        }
809        Ok(input)
810    }
811    async fn execute_async<S: SqlExecutorAsync>(
812        self,
813        connection: &mut S,
814        input: Self::Input,
815    ) -> Result<Self::Output, S::Error> {
816        if !input.is_clear() {
817            // Reset updated values.
818            connection
819                .sql_execute(&reset_updated_tables_query())
820                .await?;
821        }
822        Ok(input)
823    }
824}
825
#[cfg(test)]
mod test {
    use crate::connection::State;
    use crate::watcher::tests::new_test_observer;
    use crate::watcher::{ObservedTableOp, TableObserver, Watcher};
    use std::collections::BTreeSet;
    use std::sync::Mutex;
    use std::sync::mpsc::{Receiver, SyncSender};

    /// Observer that checks each change notification against a scripted sequence of
    /// expected table sets.
    pub struct TestObserver {
        // Remaining expectations, stored last-first so the next one can be popped.
        expected: Mutex<Vec<BTreeSet<String>>>,
        tables: Vec<String>,
        // Channel is here to make sure we don't trigger a merge of multiple pending updates.
        checked_channel: SyncSender<()>,
    }

    impl TestObserver {
        pub fn new(
            tables: Vec<String>,
            expected: impl IntoIterator<Item = BTreeSet<String>>,
        ) -> (Self, Receiver<()>) {
            let (checked_channel, receiver) = std::sync::mpsc::sync_channel::<()>(0);
            // Reverse so that `Vec::pop` hands expectations out in their original order.
            let in_order: Vec<_> = expected.into_iter().collect();
            let expected: Vec<_> = in_order.into_iter().rev().collect();
            let observer = Self {
                expected: Mutex::new(expected),
                tables,
                checked_channel,
            };
            (observer, receiver)
        }
    }

    impl TableObserver for TestObserver {
        fn tables(&self) -> Vec<String> {
            self.tables.clone()
        }

        fn on_tables_changed(&self, tables: &BTreeSet<String>) {
            let next = self.expected.lock().unwrap().pop().unwrap();
            assert_eq!(*tables, next);
            self.checked_channel.send(()).unwrap();
        }
    }

    #[test]
    fn connection_state() {
        let service = Watcher::new().unwrap();

        let observer_1 = new_test_observer(["foo", "bar"]);
        let observer_2 = new_test_observer(["bar"]);
        let observer_3 = new_test_observer(["bar", "omega"]);

        // Fetch the pending version and table diff, failing the test if either is missing.
        let pending = |state: &State| {
            let version = state
                .should_sync(&service)
                .expect("Should have new version");
            let (tracker, ops) = state
                .calculate_sync_changes(&service)
                .expect("must have changes");
            (version, tracker, ops)
        };

        let mut local_state = State::new();
        assert!(local_state.should_sync(&service).is_none());

        // Observer 1 introduces `foo` and `bar`.
        let observer_id_1 = service.add_observer(observer_1).unwrap();
        let foo_table_id = service.get_table_id("foo").unwrap();
        let bar_table_id = service.get_table_id("bar").unwrap();

        let (version, tracker, ops) = pending(&local_state);
        assert!(tracker[bar_table_id]);
        assert!(tracker[foo_table_id]);
        assert_eq!(
            ops,
            vec![
                ObservedTableOp::Add("bar".to_string(), bar_table_id),
                ObservedTableOp::Add("foo".to_string(), foo_table_id),
            ]
        );
        local_state.apply_sync_changes(tracker, version);

        // Observer 2 only watches an already tracked table.
        let observer_id_2 = service.add_observer(observer_2).unwrap();
        assert!(local_state.should_sync(&service).is_none());

        // Observer 3 adds `omega`.
        let observer_id_3 = service.add_observer(observer_3).unwrap();
        let omega_table_id = service.get_table_id("omega").unwrap();

        let (version, tracker, ops) = pending(&local_state);
        assert!(tracker[foo_table_id]);
        assert!(tracker[bar_table_id]);
        assert!(tracker[omega_table_id]);
        assert_eq!(
            ops,
            vec![ObservedTableOp::Add("omega".to_string(), omega_table_id)]
        );
        local_state.apply_sync_changes(tracker, version);

        // Dropping observer 2 leaves every table still observed.
        service.remove_observer(observer_id_2).unwrap();
        assert!(local_state.should_sync(&service).is_none());

        // Dropping observer 3 releases `omega`.
        service.remove_observer(observer_id_3).unwrap();
        let (version, tracker, ops) = pending(&local_state);
        assert!(tracker[foo_table_id]);
        assert!(tracker[bar_table_id]);
        assert!(!tracker[omega_table_id]);
        assert_eq!(
            ops,
            vec![ObservedTableOp::Remove("omega".to_string(), omega_table_id)]
        );
        local_state.apply_sync_changes(tracker, version);

        // Dropping observer 1 releases everything else.
        service.remove_observer(observer_id_1).unwrap();
        let (version, tracker, ops) = pending(&local_state);
        assert!(!tracker[foo_table_id]);
        assert!(!tracker[bar_table_id]);
        assert!(!tracker[omega_table_id]);
        assert_eq!(
            ops,
            vec![
                ObservedTableOp::Remove("bar".to_string(), bar_table_id),
                ObservedTableOp::Remove("foo".to_string(), foo_table_id),
            ]
        );
        local_state.apply_sync_changes(tracker, version);
    }

    #[test]
    fn prepare_watcher_changes_out_of_bounds_table_id() {
        // No tables tracked, yet updates arrive for positions we do not know about.
        let empty = State::new();
        let changes = empty.prepare_watcher_changes(vec![4, 3]);
        assert_eq!(changes.len(), 5);
        assert!(changes[4]);
        assert!(changes[3]);

        // Four tables tracked, but updates arrive beyond that range.
        let mut four_tables = State::new();
        four_tables.tracked_tables.grow(4);
        let changes = four_tables.prepare_watcher_changes(vec![4, 8]);
        assert_eq!(changes.len(), 9);
        assert!(changes[4]);
        assert!(changes[8]);
    }
}