sqlite_watcher/
connection.rs

1use crate::statement::{
2    BatchQuery, Sealed, SqlExecuteStatement, SqlTransactionStatement, Statement, StatementWithInput,
3};
4use crate::watcher::{ObservedTableOp, Watcher};
5use fixedbitset::FixedBitSet;
6use std::error::Error;
7use std::future::Future;
8use std::ops::{Deref, DerefMut};
9use std::sync::Arc;
10use tracing::{debug, trace, warn};
11
12#[cfg(feature = "rusqlite")]
13pub mod rusqlite;
14
15#[cfg(feature = "sqlx")]
16pub mod sqlx;
17
18#[cfg(feature = "diesel")]
19pub mod diesel;
20
/// Abstraction over a synchronous sqlite connection able to run SQL statements.
///
/// The tracking machinery only needs these two primitives to install and remove the
/// temporary tables and triggers it relies on.
pub trait SqlExecutor {
    /// Error reported when a query fails.
    type Error: Error;

    /// Run `query`, which yields zero or more rows made of a single `u32` column, and
    /// collect those values.
    ///
    /// # Errors
    ///
    /// Implementations should report a failed query as an error.
    fn sql_query_values(&self, query: &str) -> Result<Vec<u32>, Self::Error>;

    /// Run `query`, a statement that produces no rows.
    ///
    /// # Errors
    ///
    /// Implementations should report a failed query as an error.
    fn sql_execute(&self, query: &str) -> Result<(), Self::Error>;
}
41
/// Counterpart of [`SqlExecutor`] for connections that need exclusive (`&mut`) access
/// to run statements.
pub trait SqlExecutorMut {
    /// Error reported when a query fails.
    type Error: Error;

    /// Run `query`, which yields zero or more rows made of a single `u32` column, and
    /// collect those values.
    ///
    /// # Errors
    ///
    /// Implementations should report a failed query as an error.
    fn sql_query_values(&mut self, query: &str) -> Result<Vec<u32>, Self::Error>;

    /// Run `query`, a statement that produces no rows.
    ///
    /// # Errors
    ///
    /// Implementations should report a failed query as an error.
    fn sql_execute(&mut self, query: &str) -> Result<(), Self::Error>;
}
60
61// Automatically derive SqlExecutorMut for any implementation of SqlExecutor.
62impl<T: SqlExecutor> SqlExecutorMut for T {
63    type Error = T::Error;
64
65    fn sql_query_values(&mut self, query: &str) -> Result<Vec<u32>, Self::Error> {
66        SqlExecutor::sql_query_values(self, query)
67    }
68
69    fn sql_execute(&mut self, query: &str) -> Result<(), Self::Error> {
70        SqlExecutor::sql_execute(self, query)
71    }
72}
73
/// Asynchronous counterpart of [`SqlExecutor`]: runs SQL statements on a sqlite
/// connection and returns futures instead of blocking.
///
/// This is required so the temporary triggers and tables used to track changes can be set
/// up on async connections. Implementors must be `Send`, and so must the futures they
/// return and their error type.
pub trait SqlExecutorAsync: Send {
    /// Error reported when a query fails.
    type Error: Error + Send;

    /// Run `query`, which yields zero or more rows made of a single `u32` column, and
    /// resolve to those values.
    ///
    /// # Errors
    ///
    /// The returned future should resolve to an error if the query failed.
    fn sql_query_values(
        &mut self,
        query: &str,
    ) -> impl Future<Output = Result<Vec<u32>, Self::Error>> + Send;

    /// Run `query`, a statement that produces no rows.
    ///
    /// # Errors
    ///
    /// The returned future should resolve to an error if the query failed.
    fn sql_execute(&mut self, query: &str) -> impl Future<Output = Result<(), Self::Error>> + Send;
}
97
98/// Building block to provide tracking capabilities to any type of sqlite connection which
99/// implements the [`SqlExecutor`] trait.
100///
101/// # Initialization
102///
103/// It's recommended to call [`State::set_pragmas()`] to enable in memory temporary tables and recursive
104/// triggers. If your connection already has this set up, this can be skipped.
105///
106/// Next you need to create the infrastructure to track changes. This can be accomplished with
107/// [`State::start_tracking()`].
108///
109/// # Tracking changes
110///
111/// To make sure we only track required tables always call [`State::sync_tables()`] before a query/statement
112/// or a transaction.
113///
114/// When the query/statement or transaction are completed, call [`State::publish_changes()`] to check
115/// which tables have been modified and send this information to the watcher.
116///
117/// # Disable Tracking
118///
119/// If you wish to remove all the tracking infrastructure from a connection on which
120/// [`State::start_tracking()`] was called, then call [`State::stop_tracking()`].
121///
122/// # See Also
123///
124/// The [`Connection`] type provided by this crate provides an example integration implementation.
125#[derive(Debug, Default)]
126pub struct State {
127    tracked_tables: FixedBitSet,
128    last_sync_version: u64,
129}
130
131impl State {
132    /// Enable required pragmas for execution.
133    #[must_use]
134    pub fn set_pragmas() -> impl Statement {
135        SqlExecuteStatement::new("PRAGMA temp_store = MEMORY")
136            .then(SqlExecuteStatement::new("PRAGMA recursive_triggers='ON'"))
137    }
138
139    /// Prepare the `connection` for tracking.
140    ///
141    /// This will create the temporary table used to track change.
142    #[must_use]
143    #[tracing::instrument(level = tracing::Level::DEBUG)]
144    pub fn start_tracking() -> impl Statement {
145        // create tracking table and cleanup previous data if re-used from a connection pool.
146        SqlTransactionStatement::temporary(
147            SqlExecuteStatement::new(create_tracking_table_query())
148                .then(SqlExecuteStatement::new(empty_tracking_table_query())),
149        )
150        .spanned_in_current()
151    }
152
153    /// Remove all triggers and the tracking table from `connection`.
154    //
155    /// # Errors
156    ///
157    /// Returns error if the initialization failed.
158    #[tracing::instrument(level = tracing::Level::DEBUG, skip_all)]
159    pub fn stop_tracking(&self, watcher: &Watcher) -> impl Statement {
160        let tables = watcher.observed_tables();
161        SqlTransactionStatement::temporary(
162            BatchQuery::new(
163                tables
164                    .into_iter()
165                    .enumerate()
166                    .flat_map(|(id, table_name)| drop_triggers(&table_name, id)),
167            )
168            .then(SqlExecuteStatement::new(drop_tracking_table_query())),
169        )
170        .spanned_in_current()
171    }
172
173    /// Create a new instance without initializing any connection.
174    #[must_use]
175    pub fn new() -> Self {
176        Self {
177            tracked_tables: FixedBitSet::new(),
178            last_sync_version: 0,
179        }
180    }
181
182    /// Synchronize the table list from the watcher.
183    ///
184    /// This method will create new triggers for tables that are not being watched over this
185    /// connection and remove triggers for tables that are no longer observed by the watcher.
186    ///
187    /// # Errors
188    ///
189    /// Returns error if creation or removal of triggers failed.
190    #[tracing::instrument(level=tracing::Level::DEBUG, skip(self, watcher))]
191    pub fn sync_tables(&mut self, watcher: &Watcher) -> Option<impl Statement + '_> {
192        let new_version = self.should_sync(watcher)?;
193
194        debug!("Syncing tables from observer");
195        let Some((new_tracker_state, tracker_changes)) = self.calculate_sync_changes(watcher)
196        else {
197            debug!("No changes");
198            return None;
199        };
200
201        let mut queries = BatchQuery::new([]);
202
203        if self.tracked_tables.is_empty() {
204            // It is possible on certain circumstances that if a connection can have leftover
205            // tracking data that is not cleared. To make sure this is reset, we force empty
206            // the table if we detect that we are not watching any tables at the moment.
207            queries.push(SqlExecuteStatement::new(empty_tracking_table_query()));
208        }
209        for change in tracker_changes {
210            match change {
211                ObservedTableOp::Add(table_name, id) => {
212                    debug!("Add watcher for table {table_name} id={id}");
213                    queries.extend(create_triggers(&table_name, id));
214                }
215                ObservedTableOp::Remove(table_name, id) => {
216                    debug!("Remove watcher for table {table_name}");
217                    queries.extend(drop_triggers(&table_name, id));
218                }
219            }
220        }
221
222        let tx = SqlTransactionStatement::temporary(queries);
223        Some(
224            tx.then(ConcludeStateChangeStatement {
225                state: self,
226                tracked_tables: new_tracker_state,
227                new_version,
228            })
229            .spanned_in_current(),
230        )
231    }
232
233    /// Check the tracking table and report finding to the [Watcher].
234    ///
235    /// The table where the changes are tracked is read and reset. Any
236    /// table that has been modified will be communicated to the [Watcher], which in turn
237    /// will notify the respective [TableObserver].
238    ///
239    /// # Errors
240    ///
241    /// Returns error if we failed to read from the temporary tables.
242    ///
243    /// [Watcher]: `crate::watcher::Watcher`
244    /// [TableObserver]: `crate::watcher::TableObserver`
245    #[tracing::instrument(level=tracing::Level::DEBUG, skip(self, watcher))]
246    pub fn publish_changes(&self, watcher: &Watcher) -> impl Statement {
247        SqlReadTableIdsStatement
248            .pipe(CalculateWatcherUpdatesStatement { state: self })
249            .pipe(MaybeResetResultsQuery)
250            .pipe(PublishWatcherChangesStatement(watcher))
251            .spanned_in_current()
252    }
253
254    fn prepare_watcher_changes(&self, modified_table_ids: Vec<u32>) -> FixedBitSet {
255        trace!("Preparing watcher changes");
256        let mut result = FixedBitSet::with_capacity(self.tracked_tables.len());
257        result.grow(self.tracked_tables.len());
258        for id in modified_table_ids {
259            let id = id as usize;
260            debug!("Table {} has been modified", id);
261            if id >= result.len() {
262                warn!(
263                    "Received update for table {id}, but only tracking {} tables",
264                    self.tracked_tables.len(),
265                );
266                // We need to grow on the index + 1.
267                result.grow(id + 1);
268            }
269            result.set(id, true);
270        }
271
272        result
273    }
274
275    fn should_sync(&self, watcher: &Watcher) -> Option<u64> {
276        let service_version = watcher.tables_version();
277        if service_version == self.last_sync_version {
278            None
279        } else {
280            Some(service_version)
281        }
282    }
283
284    /// Determine which tables should start and/or stop being watched.
285    fn calculate_sync_changes(
286        &self,
287        watcher: &Watcher,
288    ) -> Option<(FixedBitSet, Vec<ObservedTableOp>)> {
289        trace!("Calculating sync changes");
290        let (new_tracker_state, tracker_changes) =
291            watcher.calculate_sync_changes(&self.tracked_tables);
292
293        if tracker_changes.is_empty() {
294            return None;
295        }
296
297        Some((new_tracker_state, tracker_changes))
298    }
299
300    /// Once we are satisfied with the changes, apply the new state.
301    fn apply_sync_changes(&mut self, new_tracker_state: FixedBitSet, new_version: u64) {
302        // Update local tracker bitset
303        trace!("Applying sync changes");
304        self.tracked_tables = new_tracker_state;
305        self.last_sync_version = new_version;
306    }
307}
308
309/// Connection abstraction that provides on possible implementation which uses the building
310/// blocks ([`State`]) provided by this crate.
311///
312/// For simplicity, it takes ownership of an existing type which implements [`SqlExecutor`] and
313/// initializes all the tracking infrastructure. The original type can still be accessed as
314/// [`Connection`] implements both [`Deref`] and [`DerefMut`].
315///
316/// # Remarks
317///
318/// To make sure all changes are capture, it's recommended to always call
319/// [`Connection::sync_watcher_tables()`]
320/// before any query/statement or transaction.
321///
322/// # Example
323///
324/// ## Single Query/Statement
325///
326/// ```rust
327/// use sqlite_watcher::connection::Connection;
328/// use sqlite_watcher::connection::SqlExecutor;
329/// use sqlite_watcher::watcher::Watcher;
330///
331/// pub fn track_changes<C:SqlExecutor>(connection: C) {
332///     let watcher = Watcher::new().unwrap();
333///     let mut connection = Connection::new(connection, watcher).unwrap();
334///
335///     // Sync tables so we are up to date.
336///     connection.sync_watcher_tables().unwrap();
337///
338///     connection.sql_execute("sql query here").unwrap();
339///
340///     // Publish changes to the watcher
341///     connection.publish_watcher_changes().unwrap();
342/// }
343/// ```
344///
345/// ## Transaction
346///
347/// ```rust
348/// use sqlite_watcher::connection::Connection;
349/// use sqlite_watcher::connection::{SqlExecutor};
350/// use sqlite_watcher::watcher::Watcher;
351///
352/// pub fn track_changes<C:SqlExecutor>(connection: C) {
353///     let watcher = Watcher::new().unwrap();
354///     let mut connection = Connection::new(connection, watcher).unwrap();
355///
356///     // Sync tables so we are up to date.
357///     connection.sync_watcher_tables().unwrap();
358///
359///     // Start a transaction
360///     connection.sql_execute("sql query here").unwrap();
361///     connection.sql_execute("sql query here").unwrap();
362///     // Commit transaction
363///
364///     // Publish changes to the watcher
365///     connection.publish_watcher_changes().unwrap();
366/// }
367/// ```
368pub struct Connection<C: SqlExecutorMut> {
369    state: State,
370    watcher: Arc<Watcher>,
371    connection: C,
372}
373impl<C: SqlExecutorMut> Connection<C> {
374    /// Create a new connection with `connection` and `watcher`.
375    ///
376    /// See [`State::start_tracking()`] for more information about initialization.
377    ///
378    /// # Errors
379    ///
380    /// Returns error if the initialization failed.
381    pub fn new(mut connection: C, watcher: Arc<Watcher>) -> Result<Self, C::Error> {
382        let state = State::new();
383        State::set_pragmas().execute_mut(&mut connection)?;
384        State::start_tracking().execute_mut(&mut connection)?;
385        Ok(Self {
386            state,
387            watcher,
388            connection,
389        })
390    }
391
392    /// Sync tables from the [`Watcher`] and update tracking infrastructure.
393    ///
394    /// See [`State::sync_tables()`] for more information.
395    ///
396    /// # Errors
397    ///
398    /// Returns error if we failed to sync the changes to the database.
399    pub fn sync_watcher_tables(&mut self) -> Result<(), C::Error> {
400        self.state
401            .sync_tables(&self.watcher)
402            .execute_mut(&mut self.connection)?;
403        Ok(())
404    }
405
406    /// Check if any tables have changed and notify the [`Watcher`]
407    ///
408    /// See [`State::publish_changes()`] for more information.
409    ///
410    /// It is recommended to call this method
411    ///
412    /// # Errors
413    ///
414    /// Returns error if we failed to check for changes.
415    pub fn publish_watcher_changes(&mut self) -> Result<(), C::Error> {
416        self.state
417            .publish_changes(&self.watcher)
418            .execute_mut(&mut self.connection)?;
419        Ok(())
420    }
421
422    /// Disable all tracking on this connection.
423    ///
424    /// See [`State::stop_tracking`] for more details.
425    ///
426    /// # Errors
427    ///
428    /// Returns error if the queries failed.
429    pub fn stop_tracking(&mut self) -> Result<(), C::Error> {
430        self.state
431            .stop_tracking(&self.watcher)
432            .execute_mut(&mut self.connection)?;
433        Ok(())
434    }
435
436    /// Consume the current connection and take ownership of the real sql connection.
437    ///
438    /// # Remarks
439    ///
440    /// This does not stop the tracking infrastructure enabled on the connection.
441    /// Use [`Self::stop_tracking()`] to disable it first.
442    pub fn take(self) -> C {
443        self.connection
444    }
445}
446
447/// Same as [`Connection`] but with an async executor.
448#[allow(clippy::module_name_repetitions)]
449pub struct ConnectionAsync<C: SqlExecutorAsync> {
450    state: State,
451    watcher: Arc<Watcher>,
452    connection: C,
453}
454impl<C: SqlExecutorAsync> ConnectionAsync<C> {
455    /// Create a new connection with `connection` and `watcher`.
456    ///
457    /// See [`State::start_tracking()`] for more information about initialization.
458    ///
459    /// # Errors
460    ///
461    /// Returns error if the initialization failed.
462    pub async fn new(mut connection: C, watcher: Arc<Watcher>) -> Result<Self, C::Error> {
463        let state = State::new();
464        State::set_pragmas().execute_async(&mut connection).await?;
465        State::start_tracking()
466            .execute_async(&mut connection)
467            .await?;
468        Ok(Self {
469            state,
470            watcher,
471            connection,
472        })
473    }
474
475    /// See [`Connection::sync_watcher_tables`] for more details.
476    ///
477    /// # Errors
478    ///
479    /// Returns error if we failed to sync the changes to the database.
480    pub async fn sync_watcher_tables(&mut self) -> Result<(), C::Error> {
481        self.state
482            .sync_tables(&self.watcher)
483            .execute_async(&mut self.connection)
484            .await?;
485        Ok(())
486    }
487
488    /// See [`Connection::publish_watcher_changes`] for more details.
489    ///
490    /// # Errors
491    ///
492    /// Returns error if we failed to check for changes.
493    pub async fn publish_watcher_changes(&mut self) -> Result<(), C::Error> {
494        self.state
495            .publish_changes(&self.watcher)
496            .execute_async(&mut self.connection)
497            .await?;
498        Ok(())
499    }
500
501    /// See [`Connection::stop_tracking`] for more details.
502    ///
503    /// # Errors
504    ///
505    /// Returns error if the queries failed.
506    pub async fn stop_tracking(&mut self) -> Result<(), C::Error> {
507        self.state
508            .stop_tracking(&self.watcher)
509            .execute_async(&mut self.connection)
510            .await?;
511        Ok(())
512    }
513
514    /// Consume the current connection and take ownership of the real sql connection.
515    ///
516    /// # Remarks
517    ///
518    /// This does not stop the tracking infrastructure enabled on the connection.
519    /// Use [`Self::stop_tracking()`] to disable it first.
520    pub fn take(self) -> C {
521        self.connection
522    }
523}
524
525impl<C: SqlExecutorAsync> Deref for ConnectionAsync<C> {
526    type Target = C;
527
528    fn deref(&self) -> &Self::Target {
529        &self.connection
530    }
531}
532
533impl<C: SqlExecutorAsync> DerefMut for ConnectionAsync<C> {
534    fn deref_mut(&mut self) -> &mut Self::Target {
535        &mut self.connection
536    }
537}
538
539impl<C: SqlExecutorAsync> AsRef<C> for ConnectionAsync<C> {
540    fn as_ref(&self) -> &C {
541        &self.connection
542    }
543}
544
545impl<C: SqlExecutorAsync> AsMut<C> for ConnectionAsync<C> {
546    fn as_mut(&mut self) -> &mut C {
547        &mut self.connection
548    }
549}
550
551impl<C: SqlExecutor> Deref for Connection<C> {
552    type Target = C;
553
554    fn deref(&self) -> &Self::Target {
555        &self.connection
556    }
557}
558
559impl<C: SqlExecutor> DerefMut for Connection<C> {
560    fn deref_mut(&mut self) -> &mut Self::Target {
561        &mut self.connection
562    }
563}
564
565impl<C: SqlExecutor> AsRef<C> for Connection<C> {
566    fn as_ref(&self) -> &C {
567        &self.connection
568    }
569}
570
571impl<C: SqlExecutor> AsMut<C> for Connection<C> {
572    fn as_mut(&mut self) -> &mut C {
573        &mut self.connection
574    }
575}
576
/// Name of the temporary table that stores, per tracked table id, whether that table was
/// modified (`updated=1`) since the flags were last reset.
const TRACKER_TABLE_NAME: &str = "rsqlite_watcher_version_tracker";

/// `(SQL trigger event, trigger name suffix)` pairs; one `AFTER` trigger per event is
/// installed on every tracked table.
const TRIGGER_LIST: [(&str, &str); 3] = [
    ("INSERT", "insert"),
    ("UPDATE", "update"),
    ("DELETE", "delete"),
];

/// Escape `name` for use inside a backtick-quoted SQLite identifier.
///
/// SQLite reads a doubled backtick inside such an identifier as a literal backtick. Without
/// this, a table name containing one would terminate the quoting early and break the query.
fn escape_identifier(name: &str) -> String {
    name.replace('`', "``")
}

/// Query creating the tracking table, unless it already exists.
fn create_tracking_table_query() -> String {
    format!(
        "CREATE TEMP TABLE IF NOT EXISTS `{TRACKER_TABLE_NAME}` (table_id INTEGER PRIMARY KEY, updated INTEGER)"
    )
}

/// Query deleting every row of the tracking table.
fn empty_tracking_table_query() -> String {
    format!("DELETE FROM `{TRACKER_TABLE_NAME}`")
}

/// Query dropping the tracking table, if it exists.
fn drop_tracking_table_query() -> String {
    format!("DROP TABLE IF EXISTS `{TRACKER_TABLE_NAME}`")
}

/// Query creating an `AFTER {trigger}` trigger on `table_name` that flags the tracking row
/// of `table_id` as updated.
///
/// `trigger_name` is the suffix that makes the trigger name unique per event; see
/// [`TRIGGER_LIST`].
fn create_trigger_query(
    table_name: &str,
    trigger: &str,
    trigger_name: &str,
    table_id: usize,
) -> String {
    let table_name = escape_identifier(table_name);
    format!(
        r"
CREATE TEMP TRIGGER IF NOT EXISTS `{TRACKER_TABLE_NAME}_trigger_{table_name}_{trigger_name}` AFTER {trigger} ON `{table_name}`
BEGIN
    UPDATE  `{TRACKER_TABLE_NAME}` SET updated=1 WHERE table_id={table_id};
END
            "
    )
}

/// Query registering `id` in the tracking table with its `updated` flag cleared.
fn insert_table_id_into_tracking_table_query(id: usize) -> String {
    format!("INSERT INTO `{TRACKER_TABLE_NAME}` VALUES ({id},0)")
}

/// Query dropping the trigger that [`create_trigger_query`] creates for `table_name` and
/// `trigger_name`, if it exists. Escaping matches `create_trigger_query` so the names agree.
fn drop_trigger_query(table_name: &str, trigger_name: &str) -> String {
    let table_name = escape_identifier(table_name);
    format!("DROP TRIGGER IF EXISTS `{TRACKER_TABLE_NAME}_trigger_{table_name}_{trigger_name}`")
}

/// Query deleting the tracking row of `table_id`.
fn remove_table_id_from_tracking_table_query(table_id: usize) -> String {
    format!("DELETE FROM `{TRACKER_TABLE_NAME}` WHERE table_id={table_id}")
}

/// Query selecting the ids of every table flagged as updated.
fn select_updated_tables_query() -> String {
    format!("SELECT table_id  FROM `{TRACKER_TABLE_NAME}` WHERE updated=1")
}

/// Query clearing every `updated` flag that is currently set.
fn reset_updated_tables_query() -> String {
    format!("UPDATE `{TRACKER_TABLE_NAME}` SET updated=0 WHERE updated=1")
}
641
642/// Create tracking triggers for `table` with `id`.
643fn create_triggers(table: &str, id: usize) -> Vec<SqlExecuteStatement<String>> {
644    TRIGGER_LIST
645        .iter()
646        .map(|(trigger, trigger_name)| {
647            let query = create_trigger_query(table, trigger, trigger_name, id);
648            SqlExecuteStatement::new(query)
649        })
650        .chain(std::iter::once_with(|| {
651            let query = insert_table_id_into_tracking_table_query(id);
652            SqlExecuteStatement::new(query)
653        }))
654        .collect()
655}
656
657/// Remove tracking triggers for `table` with `id`.
658fn drop_triggers(table: &str, id: usize) -> Vec<SqlExecuteStatement<String>> {
659    TRIGGER_LIST
660        .iter()
661        .map(|(_, trigger_name)| {
662            let query = drop_trigger_query(table, trigger_name);
663            SqlExecuteStatement::new(query)
664        })
665        .chain(std::iter::once_with(|| {
666            let query = remove_table_id_from_tracking_table_query(id);
667            SqlExecuteStatement::new(query)
668        }))
669        .collect()
670}
671
672/// Apply the new tracked table state to a [`State`].
673struct ConcludeStateChangeStatement<'s> {
674    state: &'s mut State,
675    tracked_tables: FixedBitSet,
676    new_version: u64,
677}
678
679impl Sealed for ConcludeStateChangeStatement<'_> {}
680impl Statement for ConcludeStateChangeStatement<'_> {
681    type Output = ();
682    fn execute<S: SqlExecutor>(self, _: &S) -> Result<Self::Output, S::Error> {
683        self.state
684            .apply_sync_changes(self.tracked_tables, self.new_version);
685        Ok(())
686    }
687
688    fn execute_mut<S: SqlExecutorMut>(self, _: &mut S) -> Result<Self::Output, S::Error> {
689        self.state
690            .apply_sync_changes(self.tracked_tables, self.new_version);
691        Ok(())
692    }
693
694    async fn execute_async<S: SqlExecutorAsync>(self, _: &mut S) -> Result<Self::Output, S::Error> {
695        self.state
696            .apply_sync_changes(self.tracked_tables, self.new_version);
697        Ok(())
698    }
699}
700
701/// Calculate what the changes to be sent to the watcher.
702struct CalculateWatcherUpdatesStatement<'s> {
703    state: &'s State,
704}
705
706impl StatementWithInput for CalculateWatcherUpdatesStatement<'_> {
707    type Input = Vec<u32>;
708    type Output = FixedBitSet;
709
710    fn execute<S: SqlExecutor>(self, _: &S, input: Self::Input) -> Result<Self::Output, S::Error> {
711        Ok(self.state.prepare_watcher_changes(input))
712    }
713    fn execute_mut<S: SqlExecutorMut>(
714        self,
715        _: &mut S,
716        input: Self::Input,
717    ) -> Result<Self::Output, S::Error> {
718        Ok(self.state.prepare_watcher_changes(input))
719    }
720    async fn execute_async<S: SqlExecutorAsync>(
721        self,
722        _: &mut S,
723        input: Self::Input,
724    ) -> Result<Self::Output, S::Error> {
725        Ok(self.state.prepare_watcher_changes(input))
726    }
727}
728
729/// Publish the changes to the watcher.
730struct PublishWatcherChangesStatement<'w>(&'w Watcher);
731
732impl Sealed for PublishWatcherChangesStatement<'_> {}
733
734impl StatementWithInput for PublishWatcherChangesStatement<'_> {
735    type Input = FixedBitSet;
736    type Output = ();
737
738    fn execute<S: SqlExecutor>(self, _: &S, input: Self::Input) -> Result<Self::Output, S::Error> {
739        self.0.publish_changes(input);
740        Ok(())
741    }
742    fn execute_mut<S: SqlExecutorMut>(
743        self,
744        _: &mut S,
745        input: Self::Input,
746    ) -> Result<Self::Output, S::Error> {
747        self.0.publish_changes(input);
748        Ok(())
749    }
750    async fn execute_async<S: SqlExecutorAsync>(
751        self,
752        _: &mut S,
753        input: Self::Input,
754    ) -> Result<Self::Output, S::Error> {
755        self.0.publish_changes_async(input).await;
756        Ok(())
757    }
758}
759
760impl Sealed for SqlReadTableIdsStatement {}
761struct SqlReadTableIdsStatement;
762impl Statement for SqlReadTableIdsStatement {
763    type Output = Vec<u32>;
764    fn execute<S: SqlExecutor>(self, connection: &S) -> Result<Self::Output, S::Error> {
765        connection.sql_query_values(&select_updated_tables_query())
766    }
767    fn execute_mut<S: SqlExecutorMut>(self, connection: &mut S) -> Result<Self::Output, S::Error> {
768        connection.sql_query_values(&select_updated_tables_query())
769    }
770    async fn execute_async<S: SqlExecutorAsync>(
771        self,
772        connection: &mut S,
773    ) -> Result<Self::Output, S::Error> {
774        connection
775            .sql_query_values(&select_updated_tables_query())
776            .await
777    }
778}
779
780/// It is possible on certain circumstances that if a connection can have leftover
781/// tracking data that is not cleared. To make sure this is reset, we force empty
782/// the table if we detect that we are not watching any tables at the moment.
783struct MaybeResetResultsQuery;
784impl StatementWithInput for MaybeResetResultsQuery {
785    type Input = FixedBitSet;
786    type Output = FixedBitSet;
787
788    fn execute<S: SqlExecutor>(
789        self,
790        connection: &S,
791        input: Self::Input,
792    ) -> Result<Self::Output, S::Error> {
793        if !input.is_clear() {
794            // Reset updated values.
795            connection.sql_execute(&reset_updated_tables_query())?;
796        }
797        Ok(input)
798    }
799    fn execute_mut<S: SqlExecutorMut>(
800        self,
801        connection: &mut S,
802        input: Self::Input,
803    ) -> Result<Self::Output, S::Error> {
804        if !input.is_clear() {
805            // Reset updated values.
806            connection.sql_execute(&reset_updated_tables_query())?;
807        }
808        Ok(input)
809    }
810    async fn execute_async<S: SqlExecutorAsync>(
811        self,
812        connection: &mut S,
813        input: Self::Input,
814    ) -> Result<Self::Output, S::Error> {
815        if !input.is_clear() {
816            // Reset updated values.
817            connection
818                .sql_execute(&reset_updated_tables_query())
819                .await?;
820        }
821        Ok(input)
822    }
823}
824
#[cfg(test)]
mod test {
    use crate::connection::State;
    use crate::watcher::tests::new_test_observer;
    use crate::watcher::{ObservedTableOp, TableObserver, Watcher};
    use std::collections::BTreeSet;
    use std::sync::Mutex;
    use std::sync::mpsc::{Receiver, SyncSender, sync_channel};

    /// Observer asserting that each change notification matches the next expected table set.
    pub struct TestObserver {
        /// Expected notifications, stored in reverse order so the next one is popped last.
        expected: Mutex<Vec<BTreeSet<String>>>,
        tables: Vec<String>,
        // Channel is here to make sure we don't trigger a merge of multiple pending updates.
        checked_channel: SyncSender<()>,
    }

    impl TestObserver {
        /// Create an observer of `tables` that expects the notifications in `expected`, in
        /// order. The returned receiver gets one message per verified notification.
        pub fn new(
            tables: Vec<String>,
            expected: impl IntoIterator<Item = BTreeSet<String>>,
        ) -> (Self, Receiver<()>) {
            let (checked_channel, receiver) = sync_channel::<()>(0);
            let mut pending: Vec<BTreeSet<String>> = expected.into_iter().collect();
            pending.reverse();
            let observer = Self {
                expected: Mutex::new(pending),
                tables,
                checked_channel,
            };
            (observer, receiver)
        }
    }

    impl TableObserver for TestObserver {
        fn tables(&self) -> Vec<String> {
            self.tables.clone()
        }

        fn on_tables_changed(&self, tables: &BTreeSet<String>) {
            let next = self.expected.lock().unwrap().pop().unwrap();
            assert_eq!(*tables, next);
            self.checked_channel.send(()).unwrap();
        }
    }

    /// Perform one sync round on `state`. A new version and a non-empty set of changes are
    /// both required. Checks which table ids end up tracked or untracked, applies the
    /// changes and returns the computed operations.
    fn sync_round(
        state: &mut State,
        watcher: &Watcher,
        tracked: &[usize],
        untracked: &[usize],
    ) -> Vec<ObservedTableOp> {
        let new_version = state
            .should_sync(watcher)
            .expect("Should have new version");
        let (tracker, ops) = state
            .calculate_sync_changes(watcher)
            .expect("must have changes");
        for &id in tracked {
            assert!(tracker[id]);
        }
        for &id in untracked {
            assert!(!tracker[id]);
        }
        state.apply_sync_changes(tracker, new_version);
        ops
    }

    #[test]
    fn connection_state() {
        let service = Watcher::new().unwrap();

        let observer_1 = new_test_observer(["foo", "bar"]);
        let observer_2 = new_test_observer(["bar"]);
        let observer_3 = new_test_observer(["bar", "omega"]);

        let mut local_state = State::new();

        // Nothing observed yet, so nothing to sync.
        assert!(local_state.should_sync(&service).is_none());

        // First observer: both of its tables get added.
        let observer_id_1 = service.add_observer(observer_1).unwrap();
        let foo = service.get_table_id("foo").unwrap();
        let bar = service.get_table_id("bar").unwrap();
        let ops = sync_round(&mut local_state, &service, &[bar, foo], &[]);
        assert_eq!(
            ops,
            vec![
                ObservedTableOp::Add("bar".to_string(), bar),
                ObservedTableOp::Add("foo".to_string(), foo),
            ]
        );

        // Second observer only uses a table that is already tracked.
        let observer_id_2 = service.add_observer(observer_2).unwrap();
        assert!(local_state.should_sync(&service).is_none());

        // Third observer introduces `omega`.
        let observer_id_3 = service.add_observer(observer_3).unwrap();
        let omega = service.get_table_id("omega").unwrap();
        let ops = sync_round(&mut local_state, &service, &[foo, bar, omega], &[]);
        assert_eq!(ops, vec![ObservedTableOp::Add("omega".to_string(), omega)]);

        // Removing the second observer leaves every table in use.
        service.remove_observer(observer_id_2).unwrap();
        assert!(local_state.should_sync(&service).is_none());

        // Removing the third observer drops `omega`.
        service.remove_observer(observer_id_3).unwrap();
        let ops = sync_round(&mut local_state, &service, &[foo, bar], &[omega]);
        assert_eq!(
            ops,
            vec![ObservedTableOp::Remove("omega".to_string(), omega)]
        );

        // Removing the last observer drops the remaining tables.
        service.remove_observer(observer_id_1).unwrap();
        let ops = sync_round(&mut local_state, &service, &[], &[foo, bar, omega]);
        assert_eq!(
            ops,
            vec![
                ObservedTableOp::Remove("bar".to_string(), bar),
                ObservedTableOp::Remove("foo".to_string(), foo),
            ]
        );
    }

    #[test]
    fn prepare_watcher_changes_out_of_bounds_table_id() {
        // No tracked tables, yet updates arrive for positions we do not know about.
        let empty = State::new();
        let result = empty.prepare_watcher_changes(vec![4, 3]);
        assert_eq!(result.len(), 5);
        assert!(result[4]);
        assert!(result[3]);

        // State with N tables, but updates arrive for ids N and beyond.
        let mut state = State::new();
        state.tracked_tables.grow(4);
        let result = state.prepare_watcher_changes(vec![4, 8]);
        assert_eq!(result.len(), 9);
        assert!(result[4]);
        assert!(result[8]);
    }
}