spacetimedb/subscription/
module_subscription_actor.rs

1use super::execution_unit::QueryHash;
2use super::module_subscription_manager::{
3    spawn_send_worker, BroadcastError, BroadcastQueue, Plan, SubscriptionGaugeStats, SubscriptionManager,
4};
5use super::query::compile_query_with_hashes;
6use super::tx::DeltaTx;
7use super::{collect_table_update, TableUpdateType};
8use crate::client::messages::{
9    SerializableMessage, SubscriptionData, SubscriptionError, SubscriptionMessage, SubscriptionResult,
10    SubscriptionRows, SubscriptionUpdateMessage, TransactionUpdateMessage,
11};
12use crate::client::{ClientActorId, ClientConnectionSender, Protocol};
13use crate::db::relational_db::{MutTx, RelationalDB, Tx};
14use crate::error::DBError;
15use crate::estimation::estimate_rows_scanned;
16use crate::host::module_host::{DatabaseUpdate, EventStatus, ModuleEvent};
17use crate::messages::websocket::Subscribe;
18use crate::subscription::execute_plans;
19use crate::subscription::query::is_subscribe_to_all_tables;
20use crate::util::prometheus_handle::IntGaugeExt;
21use crate::vm::check_row_limit;
22use crate::worker_metrics::WORKER_METRICS;
23use parking_lot::RwLock;
24use prometheus::{Histogram, HistogramTimer, IntCounter, IntGauge};
25use spacetimedb_client_api_messages::websocket::{
26    self as ws, BsatnFormat, FormatSwitch, JsonFormat, SubscribeMulti, SubscribeSingle, TableUpdate, Unsubscribe,
27    UnsubscribeMulti,
28};
29use spacetimedb_datastore::db_metrics::DB_METRICS;
30use spacetimedb_datastore::execution_context::{Workload, WorkloadType};
31use spacetimedb_datastore::locking_tx_datastore::TxId;
32use spacetimedb_execution::pipelined::PipelinedProject;
33use spacetimedb_lib::identity::AuthCtx;
34use spacetimedb_lib::metrics::ExecutionMetrics;
35use spacetimedb_lib::Identity;
36use std::{sync::Arc, time::Instant};
37
38type Subscriptions = Arc<RwLock<SubscriptionManager>>;
39
/// Handle on a database's subscription state: evaluates subscription queries,
/// registers them with the [`SubscriptionManager`], and queues results for
/// delivery to clients via the broadcast queue.
#[derive(Clone)]
pub struct ModuleSubscriptions {
    relational_db: Arc<RelationalDB>,
    /// If taking a lock (tx) on the db at the same time, ALWAYS lock the db first.
    /// You will deadlock otherwise.
    subscriptions: Subscriptions,
    /// Queue feeding the send worker that delivers messages to clients.
    broadcast_queue: BroadcastQueue,
    /// Identity of the database owner; used to build an [`AuthCtx`] for queries.
    owner_identity: Identity,
    /// Prometheus gauges tracking this database's subscription statistics.
    stats: Arc<SubscriptionGauges>,
}
50
/// Per-database Prometheus gauges for subscription statistics.
/// The reported values are computed as a [`SubscriptionGaugeStats`]
/// by the [`SubscriptionManager`]; see [`SubscriptionGauges::report`].
#[derive(Debug, Clone)]
pub struct SubscriptionGauges {
    // Identity used as the metric label for this database.
    db_identity: Identity,
    num_queries: IntGauge,
    num_connections: IntGauge,
    num_subscription_sets: IntGauge,
    num_query_subscriptions: IntGauge,
    num_legacy_subscriptions: IntGauge,
}
60
61impl SubscriptionGauges {
62    fn new(db_identity: &Identity) -> Self {
63        let num_queries = WORKER_METRICS.subscription_queries.with_label_values(db_identity);
64        let num_connections = DB_METRICS.subscription_connections.with_label_values(db_identity);
65        let num_subscription_sets = DB_METRICS.subscription_sets.with_label_values(db_identity);
66        let num_query_subscriptions = DB_METRICS.total_query_subscriptions.with_label_values(db_identity);
67        let num_legacy_subscriptions = DB_METRICS.num_legacy_subscriptions.with_label_values(db_identity);
68        Self {
69            db_identity: *db_identity,
70            num_queries,
71            num_connections,
72            num_subscription_sets,
73            num_query_subscriptions,
74            num_legacy_subscriptions,
75        }
76    }
77
78    // Clear the subscription gauges for this database.
79    fn unregister(&self) {
80        let _ = WORKER_METRICS
81            .subscription_queries
82            .remove_label_values(&self.db_identity);
83        let _ = DB_METRICS
84            .subscription_connections
85            .remove_label_values(&self.db_identity);
86        let _ = DB_METRICS.subscription_sets.remove_label_values(&self.db_identity);
87        let _ = DB_METRICS
88            .total_query_subscriptions
89            .remove_label_values(&self.db_identity);
90        let _ = DB_METRICS
91            .num_legacy_subscriptions
92            .remove_label_values(&self.db_identity);
93    }
94
95    fn report(&self, stats: &SubscriptionGaugeStats) {
96        self.num_queries.set(stats.num_queries as i64);
97        self.num_connections.set(stats.num_connections as i64);
98        self.num_subscription_sets.set(stats.num_subscription_sets as i64);
99        self.num_query_subscriptions.set(stats.num_query_subscriptions as i64);
100        self.num_legacy_subscriptions.set(stats.num_legacy_subscriptions as i64);
101    }
102}
103
/// Metric handles for subscription operations of a single [`WorkloadType`]:
/// lock contention, compile time, and query counters.
pub struct SubscriptionMetrics {
    /// How many callers are currently waiting on the subscription lock.
    pub lock_waiters: IntGauge,
    /// How long callers wait to acquire the subscription lock.
    pub lock_wait_time: Histogram,
    /// Time spent compiling subscription queries.
    pub compilation_time: Histogram,
    /// Total number of queries requested across subscriptions.
    pub num_queries_subscribed: IntCounter,
    /// Queries that were not cached and had to be compiled.
    pub num_new_queries_subscribed: IntCounter,
    /// Queries actually evaluated to produce results.
    pub num_queries_evaluated: IntCounter,
}
112
113impl SubscriptionMetrics {
114    pub fn new(db: &Identity, workload: &WorkloadType) -> Self {
115        Self {
116            lock_waiters: DB_METRICS.subscription_lock_waiters.with_label_values(db, workload),
117            lock_wait_time: DB_METRICS.subscription_lock_wait_time.with_label_values(db, workload),
118            compilation_time: DB_METRICS.subscription_compile_time.with_label_values(db, workload),
119            num_queries_subscribed: DB_METRICS.num_queries_subscribed.with_label_values(db),
120            num_new_queries_subscribed: DB_METRICS.num_new_queries_subscribed.with_label_values(db),
121            num_queries_evaluated: DB_METRICS.num_queries_evaluated.with_label_values(db, workload),
122        }
123    }
124}
125
/// Test-only hook invoked with the open read transaction during subscription calls.
type AssertTxFn = Arc<dyn Fn(&Tx)>;
/// A single table's update in either wire format (BSATN or JSON).
type SubscriptionUpdate = FormatSwitch<TableUpdate<BsatnFormat>, TableUpdate<JsonFormat>>;
/// A full database update in either wire format (BSATN or JSON).
type FullSubscriptionUpdate = FormatSwitch<ws::DatabaseUpdate<BsatnFormat>, ws::DatabaseUpdate<JsonFormat>>;
129
/// A utility for sending an error message to a client and returning early.
///
/// On `Err`, forwards the stringified error to `$handler` and returns
/// `Ok($metrics)` from the *enclosing* function; on `Ok`, the macro
/// evaluates to the inner value.
macro_rules! return_on_err {
    ($expr:expr, $handler:expr, $metrics:expr) => {
        match $expr {
            Ok(val) => val,
            Err(e) => {
                // TODO: Handle errors sending messages.
                let _ = $handler(e.to_string().into());
                return Ok($metrics);
            }
        }
    };
}
143
/// A utility for sending an error message to a client and returning early.
///
/// Like [`return_on_err`], but first wraps the error in [`DBError::WithSql`]
/// so the offending query text is attached, and always returns `Ok(None)`
/// from the enclosing function on failure.
macro_rules! return_on_err_with_sql {
    ($expr:expr, $sql:expr, $handler:expr) => {
        match $expr.map_err(|err| DBError::WithSql {
            sql: $sql.into(),
            error: Box::new(DBError::Other(err.into())),
        }) {
            Ok(val) => val,
            Err(e) => {
                // TODO: Handle errors sending messages.
                let _ = $handler(e.to_string().into());
                return Ok(None);
            }
        }
    };
}
160
161impl ModuleSubscriptions {
162    pub fn new(
163        relational_db: Arc<RelationalDB>,
164        subscriptions: Subscriptions,
165        broadcast_queue: BroadcastQueue,
166        owner_identity: Identity,
167    ) -> Self {
168        let db = &relational_db.database_identity();
169        let stats = Arc::new(SubscriptionGauges::new(db));
170
171        Self {
172            relational_db,
173            subscriptions,
174            broadcast_queue,
175            owner_identity,
176            stats,
177        }
178    }
179
    /// Construct a new [`ModuleSubscriptions`] for use in testing,
    /// creating a new [`tokio::runtime::Runtime`] to run its send worker.
    ///
    /// The returned runtime must be kept alive as long as the returned
    /// `ModuleSubscriptions` is in use.
    pub fn for_test_new_runtime(db: Arc<RelationalDB>) -> (ModuleSubscriptions, tokio::runtime::Runtime) {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        // Enter the runtime so the send worker spawned below lands on it.
        let _rt = runtime.enter();
        (Self::for_test_enclosing_runtime(db), runtime)
    }
187
    /// Construct a new [`ModuleSubscriptions`] for use in testing,
    /// running its send worker on the dynamically enclosing [`tokio::runtime::Runtime`].
    ///
    /// Uses a metrics-free [`SubscriptionManager`] and [`Identity::ZERO`]
    /// as the owner identity.
    pub fn for_test_enclosing_runtime(db: Arc<RelationalDB>) -> ModuleSubscriptions {
        let send_worker_queue = spawn_send_worker(None);
        ModuleSubscriptions::new(
            db,
            SubscriptionManager::for_test_without_metrics_arc_rwlock(),
            send_worker_queue,
            Identity::ZERO,
        )
    }
199
200    // Recompute gauges to update metrics.
201    pub fn update_gauges(&self) {
202        let num_queries = self.subscriptions.read().calculate_gauge_stats();
203        self.stats.report(&num_queries);
204    }
205
    // Remove the subscription gauges for this database.
    // Safe to call even if the gauges were never registered (removal
    // errors are ignored in `unregister`).
    // TODO: This should be called when the database is shut down.
    pub fn remove_gauges(&self) {
        self.stats.unregister();
    }
211
    /// Run auth and row limit checks for a new subscriber, then compute the initial query results.
    ///
    /// Returns the rows for `query`'s subscribed table, wrapped in the
    /// sender's negotiated wire format, plus the execution metrics gathered
    /// while collecting them.
    ///
    /// # Errors
    /// Fails if the row-limit check rejects the query, if plan optimization
    /// fails, or if execution fails.
    fn evaluate_initial_subscription(
        &self,
        sender: Arc<ClientConnectionSender>,
        query: Arc<Plan>,
        tx: &TxId,
        auth: &AuthCtx,
        update_type: TableUpdateType,
    ) -> Result<(SubscriptionUpdate, ExecutionMetrics), DBError> {
        // Reject the subscription up front if its estimated cost is too high.
        check_row_limit(
            &[&query],
            &self.relational_db,
            tx,
            |plan, tx| {
                // Cost = estimated rows scanned summed over all plan
                // fragments, saturating to avoid overflow.
                plan.plans_fragments()
                    .map(|plan_fragment| estimate_rows_scanned(tx, plan_fragment.optimized_physical_plan()))
                    .fold(0, |acc, rows_scanned| acc.saturating_add(rows_scanned))
            },
            auth,
        )?;

        let table_id = query.subscribed_table_id();
        let table_name = query.subscribed_table_name();

        // Re-optimize each fragment's physical plan and lower it to a
        // pipelined executor.
        let plans = query
            .plans_fragments()
            .map(|fragment| fragment.optimized_physical_plan())
            .cloned()
            .map(|plan| plan.optimize())
            .collect::<Result<Vec<_>, _>>()?
            .into_iter()
            .map(PipelinedProject::from)
            .collect::<Vec<_>>();

        // Wrap the tx for the execution engine.
        let tx = DeltaTx::from(tx);

        // Both arms run the same collection; the protocol only selects the
        // output format via the `FormatSwitch` wrapper's type inference.
        Ok(match sender.config.protocol {
            Protocol::Binary => collect_table_update(&plans, table_id, table_name.into(), &tx, update_type)
                .map(|(table_update, metrics)| (FormatSwitch::Bsatn(table_update), metrics)),
            Protocol::Text => collect_table_update(&plans, table_id, table_name.into(), &tx, update_type)
                .map(|(table_update, metrics)| (FormatSwitch::Json(table_update), metrics)),
        }?)
    }
255
    /// Run the row-limit check, then evaluate `queries` and collect their
    /// results into a single database update in the sender's wire format,
    /// along with the execution metrics.
    ///
    /// # Errors
    /// Fails if the row-limit check rejects the queries or execution fails.
    fn evaluate_queries(
        &self,
        sender: Arc<ClientConnectionSender>,
        queries: &[Arc<Plan>],
        tx: &TxId,
        auth: &AuthCtx,
        update_type: TableUpdateType,
    ) -> Result<(FullSubscriptionUpdate, ExecutionMetrics), DBError> {
        // Reject the request up front if its estimated cost is too high.
        check_row_limit(
            queries,
            &self.relational_db,
            tx,
            |plan, tx| {
                // Cost of a query = estimated rows scanned summed over its
                // fragments, saturating to avoid overflow.
                plan.plans_fragments()
                    .map(|plan_fragment| estimate_rows_scanned(tx, plan_fragment.optimized_physical_plan()))
                    .fold(0, |acc, rows_scanned| acc.saturating_add(rows_scanned))
            },
            auth,
        )?;

        // Wrap the tx for the execution engine.
        let tx = DeltaTx::from(tx);
        // The two arms are textually identical; the protocol selects the
        // output format through type inference on the `FormatSwitch` variant.
        match sender.config.protocol {
            Protocol::Binary => {
                let (update, metrics) = execute_plans(queries, &tx, update_type)?;
                Ok((FormatSwitch::Bsatn(update), metrics))
            }
            Protocol::Text => {
                let (update, metrics) = execute_plans(queries, &tx, update_type)?;
                Ok((FormatSwitch::Json(update), metrics))
            }
        }
    }
288
    /// Add a subscription to a single query.
    ///
    /// Compiles `request.query` (or reuses a cached plan), evaluates the
    /// initial rows, registers the subscription, and queues a
    /// [`SubscriptionResult::Subscribe`] message for the client.
    /// Failures that only affect this client are reported back to it and
    /// surface here as `Ok(None)`.
    #[tracing::instrument(level = "trace", skip_all)]
    pub fn add_single_subscription(
        &self,
        sender: Arc<ClientConnectionSender>,
        request: SubscribeSingle,
        timer: Instant,
        _assert: Option<AssertTxFn>,
    ) -> Result<Option<ExecutionMetrics>, DBError> {
        // Send an error message to the client
        let send_err_msg = |message| {
            self.broadcast_queue.send_client_message(
                sender.clone(),
                SubscriptionMessage {
                    request_id: Some(request.request_id),
                    query_id: Some(request.query_id),
                    timer: Some(timer),
                    result: SubscriptionResult::Error(SubscriptionError {
                        table_id: None,
                        message,
                    }),
                },
            )
        };

        let sql = request.query;
        let auth = AuthCtx::new(self.owner_identity, sender.id.identity);
        // Hash both with and without parameterization: either form may be
        // cached (see `compile_queries` for the rationale).
        let hash = QueryHash::from_string(&sql, auth.caller, false);
        let hash_with_param = QueryHash::from_string(&sql, auth.caller, true);

        // Begin a read tx; the guard releases it and reports tx metrics on
        // every exit path.
        let tx = scopeguard::guard(self.relational_db.begin_tx(Workload::Subscribe), |tx| {
            let (tx_metrics, reducer) = self.relational_db.release_tx(tx);
            self.relational_db.report_read_tx_metrics(reducer, tx_metrics);
        });

        // Check the plan cache before compiling.
        let existing_query = {
            let guard = self.subscriptions.read();
            guard.query(&hash)
        };

        let query = return_on_err_with_sql!(
            existing_query.map(Ok).unwrap_or_else(|| compile_query_with_hashes(
                &auth,
                &tx,
                &sql,
                hash,
                hash_with_param
            )
            .map(Arc::new)),
            sql,
            send_err_msg
        );

        let (table_rows, metrics) = return_on_err_with_sql!(
            self.evaluate_initial_subscription(sender.clone(), query.clone(), &tx, &auth, TableUpdateType::Subscribe),
            query.sql(),
            send_err_msg
        );

        // It acquires the subscription lock after `eval`, allowing `add_subscription` to run concurrently.
        // This also makes it possible for `broadcast_event` to get scheduled before the subsequent part here
        // but that should not pose an issue.
        let mut subscriptions = self.subscriptions.write();
        subscriptions.add_subscription(sender.clone(), query.clone(), request.query_id)?;

        #[cfg(test)]
        if let Some(assert) = _assert {
            assert(&tx);
        }

        // Note: to make sure transaction updates are consistent, we need to put this in the broadcast
        // queue while we are still holding a read-lock on the database.

        // That will avoid race conditions because reducers first grab a write lock on the db, then
        // grab a read lock on the subscriptions.

        // Holding a write lock on `self.subscriptions` would also be sufficient.
        let _ = self.broadcast_queue.send_client_message(
            sender.clone(),
            SubscriptionMessage {
                request_id: Some(request.request_id),
                query_id: Some(request.query_id),
                timer: Some(timer),
                result: SubscriptionResult::Subscribe(SubscriptionRows {
                    table_id: query.subscribed_table_id(),
                    table_name: query.subscribed_table_name().into(),
                    table_rows,
                }),
            },
        );
        Ok(Some(metrics))
    }
381
    /// Remove a subscription for a single query.
    ///
    /// Evaluates the rows the client loses by unsubscribing and queues a
    /// [`SubscriptionResult::Unsubscribe`] message for the client.
    /// Failures that only affect this client are reported back to it and
    /// surface here as `Ok(None)`.
    pub fn remove_single_subscription(
        &self,
        sender: Arc<ClientConnectionSender>,
        request: Unsubscribe,
        timer: Instant,
    ) -> Result<Option<ExecutionMetrics>, DBError> {
        // Send an error message to the client
        let send_err_msg = |message| {
            self.broadcast_queue.send_client_message(
                sender.clone(),
                SubscriptionMessage {
                    request_id: Some(request.request_id),
                    query_id: Some(request.query_id),
                    timer: Some(timer),
                    result: SubscriptionResult::Error(SubscriptionError {
                        table_id: None,
                        message,
                    }),
                },
            )
        };

        // NOTE(review): the subscription write lock is taken here *before*
        // the db tx below, which appears to contradict the lock-order rule
        // documented on `ModuleSubscriptions::subscriptions` — confirm this
        // is safe (or intended, given this function is slated for removal).
        let mut subscriptions = self.subscriptions.write();

        let queries = return_on_err!(
            subscriptions.remove_subscription((sender.id.identity, sender.id.connection_id), request.query_id),
            // Apparently we ignore errors sending messages.
            send_err_msg,
            None
        );
        // This is technically a bug, since this could be empty if the client has another duplicate subscription.
        // This whole function should be removed soon, so I don't think we need to fix it.
        let [query] = &*queries else {
            // Apparently we ignore errors sending messages.
            let _ = send_err_msg("Internal error".into());
            return Ok(None);
        };

        // Begin a read tx; the guard releases it and reports tx metrics on
        // every exit path.
        let tx = scopeguard::guard(self.relational_db.begin_tx(Workload::Unsubscribe), |tx| {
            let (tx_metrics, reducer) = self.relational_db.release_tx(tx);
            self.relational_db.report_read_tx_metrics(reducer, tx_metrics);
        });
        let auth = AuthCtx::new(self.owner_identity, sender.id.identity);
        // Evaluate the removed query so the client knows which rows to drop.
        let (table_rows, metrics) = return_on_err_with_sql!(
            self.evaluate_initial_subscription(sender.clone(), query.clone(), &tx, &auth, TableUpdateType::Unsubscribe),
            query.sql(),
            send_err_msg
        );

        // Note: to make sure transaction updates are consistent, we need to put this in the broadcast
        // queue while we are still holding a read-lock on the database.

        // That will avoid race conditions because reducers first grab a write lock on the db, then
        // grab a read lock on the subscriptions.

        // Holding a write lock on `self.subscriptions` would also be sufficient.
        let _ = self.broadcast_queue.send_client_message(
            sender.clone(),
            SubscriptionMessage {
                request_id: Some(request.request_id),
                query_id: Some(request.query_id),
                timer: Some(timer),
                result: SubscriptionResult::Unsubscribe(SubscriptionRows {
                    table_id: query.subscribed_table_id(),
                    table_name: query.subscribed_table_name().into(),
                    table_rows,
                }),
            },
        );
        Ok(Some(metrics))
    }
454
    /// Remove a client's subscription for a set of queries.
    ///
    /// Evaluates the rows the client loses by unsubscribing and queues an
    /// [`SubscriptionResult::UnsubscribeMulti`] message for the client.
    /// Failures that only affect this client are reported back to it and
    /// surface here as `Ok(None)`.
    #[tracing::instrument(level = "trace", skip_all)]
    pub fn remove_multi_subscription(
        &self,
        sender: Arc<ClientConnectionSender>,
        request: UnsubscribeMulti,
        timer: Instant,
    ) -> Result<Option<ExecutionMetrics>, DBError> {
        // Send an error message to the client
        let send_err_msg = |message| {
            self.broadcast_queue.send_client_message(
                sender.clone(),
                SubscriptionMessage {
                    request_id: Some(request.request_id),
                    query_id: Some(request.query_id),
                    timer: Some(timer),
                    result: SubscriptionResult::Error(SubscriptionError {
                        table_id: None,
                        message,
                    }),
                },
            )
        };

        let database_identity = self.relational_db.database_identity();
        let subscription_metrics = SubscriptionMetrics::new(&database_identity, &WorkloadType::Unsubscribe);

        // Always lock the db before the subscription lock to avoid deadlocks.
        let tx = scopeguard::guard(self.relational_db.begin_tx(Workload::Unsubscribe), |tx| {
            let (tx_metrics, reducer) = self.relational_db.release_tx(tx);
            self.relational_db.report_read_tx_metrics(reducer, tx_metrics);
        });

        let removed_queries = {
            // NOTE(review): this times the removal with the compile-time
            // histogram — confirm that's intended (no compilation happens here).
            let _compile_timer = subscription_metrics.compilation_time.start_timer();
            let mut subscriptions = {
                // How contended is the lock?
                let _wait_guard = subscription_metrics.lock_waiters.inc_scope();
                let _wait_timer = subscription_metrics.lock_wait_time.start_timer();
                self.subscriptions.write()
            };

            return_on_err!(
                subscriptions.remove_subscription((sender.id.identity, sender.id.connection_id), request.query_id),
                send_err_msg,
                None
            )
        };

        // Evaluate the removed queries so the client knows which rows to drop.
        let (update, metrics) = return_on_err!(
            self.evaluate_queries(
                sender.clone(),
                &removed_queries,
                &tx,
                &AuthCtx::new(self.owner_identity, sender.id.identity),
                TableUpdateType::Unsubscribe,
            ),
            send_err_msg,
            None
        );

        // How many queries did we evaluate?
        subscription_metrics
            .num_queries_evaluated
            .inc_by(removed_queries.len() as _);

        // Note: to make sure transaction updates are consistent, we need to put this in the broadcast
        // queue while we are still holding a read-lock on the database.

        // That will avoid race conditions because reducers first grab a write lock on the db, then
        // grab a read lock on the subscriptions.

        // Holding a write lock on `self.subscriptions` would also be sufficient.
        let _ = self.broadcast_queue.send_client_message(
            sender,
            SubscriptionMessage {
                request_id: Some(request.request_id),
                query_id: Some(request.query_id),
                timer: Some(timer),
                result: SubscriptionResult::UnsubscribeMulti(SubscriptionData { data: update }),
            },
        );

        Ok(Some(metrics))
    }
540
    /// Compiles the queries in a [Subscribe] or [SubscribeMulti] message.
    ///
    /// Note, we hash queries to avoid recompilation,
    /// but we need to know if a query is parameterized in order to hash it correctly.
    /// This requires that we type check which in turn requires that we start a tx.
    ///
    /// Unfortunately parsing with sqlparser is quite expensive,
    /// so we'd like to avoid that cost while holding the tx lock,
    /// especially since all we're trying to do is generate a hash.
    ///
    /// Instead we generate two hashes and outside of the tx lock.
    /// If either one is currently tracked, we can avoid recompilation.
    ///
    /// Returns the compiled plans, the auth context, the still-open read tx,
    /// and a running compile-time timer. The caller is responsible for
    /// releasing the tx and dropping the timer.
    #[allow(clippy::type_complexity)]
    fn compile_queries(
        &self,
        sender: Identity,
        queries: &[Box<str>],
        num_queries: usize,
        metrics: &SubscriptionMetrics,
    ) -> Result<(Vec<Arc<Plan>>, AuthCtx, TxId, HistogramTimer), DBError> {
        let mut subscribe_to_all_tables = false;
        let mut plans = Vec::with_capacity(num_queries);
        let mut query_hashes = Vec::with_capacity(num_queries);

        // Hash all queries before taking any locks (parsing is expensive).
        for sql in queries {
            let sql = sql.trim();
            if is_subscribe_to_all_tables(sql) {
                subscribe_to_all_tables = true;
                continue;
            }
            let hash = QueryHash::from_string(sql, sender, false);
            let hash_with_param = QueryHash::from_string(sql, sender, true);
            query_hashes.push((sql, hash, hash_with_param));
        }

        let auth = AuthCtx::new(self.owner_identity, sender);

        // We always get the db lock before the subscription lock to avoid deadlocks.
        let tx = scopeguard::guard(self.relational_db.begin_tx(Workload::Subscribe), |tx| {
            let (tx_metrics, reducer) = self.relational_db.release_tx(tx);
            self.relational_db.report_read_tx_metrics(reducer, tx_metrics);
        });

        let compile_timer = metrics.compilation_time.start_timer();

        let guard = {
            // How contended is the lock?
            let _wait_guard = metrics.lock_waiters.inc_scope();
            let _wait_timer = metrics.lock_wait_time.start_timer();
            self.subscriptions.read()
        };

        // A subscribe-to-all-tables request expands into one plan per table
        // via `get_all`.
        if subscribe_to_all_tables {
            plans.extend(
                super::subscription::get_all(&self.relational_db, &tx, &auth)?
                    .into_iter()
                    .map(Arc::new),
            );
        }

        let mut new_queries = 0;

        // Reuse a cached plan under either hash; compile only on a double miss.
        for (sql, hash, hash_with_param) in query_hashes {
            if let Some(unit) = guard.query(&hash) {
                plans.push(unit);
            } else if let Some(unit) = guard.query(&hash_with_param) {
                plans.push(unit);
            } else {
                plans.push(Arc::new(
                    compile_query_with_hashes(&auth, &tx, sql, hash, hash_with_param).map_err(|err| {
                        DBError::WithSql {
                            error: Box::new(DBError::Other(err.into())),
                            sql: sql.into(),
                        }
                    })?,
                ));
                new_queries += 1;
            }
        }

        // How many queries in this subscription are not cached?
        metrics.num_new_queries_subscribed.inc_by(new_queries);

        // Disarm the scopeguard: ownership of the open tx passes to the caller.
        Ok((plans, auth, scopeguard::ScopeGuard::into_inner(tx), compile_timer))
    }
626
    /// Send a message to a client connection.
    /// This will eventually be sent by the send-worker.
    /// This takes a `TxId`, because this should be called while still holding a lock on the database;
    /// the `_tx_id` parameter exists only to enforce that at the type level.
    pub fn send_client_message(
        &self,
        recipient: Arc<ClientConnectionSender>,
        message: impl Into<SerializableMessage>,
        _tx_id: &TxId,
    ) -> Result<(), BroadcastError> {
        self.broadcast_queue.send_client_message(recipient, message)
    }
638
    /// Add a client subscription for a set of queries.
    ///
    /// Compiles the queries (or reuses cached plans), registers the
    /// subscription, evaluates the initial rows, and queues a
    /// [`SubscriptionResult::SubscribeMulti`] message for the client.
    /// Failures that only affect this client are reported back to it and
    /// surface here as `Ok(None)`.
    #[tracing::instrument(level = "trace", skip_all)]
    pub fn add_multi_subscription(
        &self,
        sender: Arc<ClientConnectionSender>,
        request: SubscribeMulti,
        timer: Instant,
        _assert: Option<AssertTxFn>,
    ) -> Result<Option<ExecutionMetrics>, DBError> {
        // Send an error message to the client
        let send_err_msg = |message| {
            let _ = self.broadcast_queue.send_client_message(
                sender.clone(),
                SubscriptionMessage {
                    request_id: Some(request.request_id),
                    query_id: Some(request.query_id),
                    timer: Some(timer),
                    result: SubscriptionResult::Error(SubscriptionError {
                        table_id: None,
                        message,
                    }),
                },
            );
        };

        let num_queries = request.query_strings.len();

        let database_identity = self.relational_db.database_identity();
        let subscription_metrics = SubscriptionMetrics::new(&database_identity, &WorkloadType::Subscribe);

        // How many queries make up this subscription?
        subscription_metrics.num_queries_subscribed.inc_by(num_queries as _);

        // Compile the queries; this also opens the read tx returned here.
        let (queries, auth, tx, compile_timer) = return_on_err!(
            self.compile_queries(
                sender.id.identity,
                &request.query_strings,
                num_queries,
                &subscription_metrics
            ),
            send_err_msg,
            None
        );
        // Re-arm a guard so the tx is released (and its metrics reported)
        // on every exit path.
        let tx = scopeguard::guard(tx, |tx| {
            let (tx_metrics, reducer) = self.relational_db.release_tx(tx);
            self.relational_db.report_read_tx_metrics(reducer, tx_metrics);
        });

        // We minimize locking so that other clients can add subscriptions concurrently.
        // We are protected from race conditions with broadcasts, because we have the db lock,
        // an `commit_and_broadcast_event` grabs a read lock on `subscriptions` while it still has a
        // write lock on the db.
        let queries = {
            let mut subscriptions = {
                // How contended is the lock?
                let _wait_guard = subscription_metrics.lock_waiters.inc_scope();
                let _wait_timer = subscription_metrics.lock_wait_time.start_timer();
                self.subscriptions.write()
            };

            subscriptions.add_subscription_multi(sender.clone(), queries, request.query_id)?
        };

        // Record how long it took to compile the subscription
        drop(compile_timer);

        let Ok((update, metrics)) =
            self.evaluate_queries(sender.clone(), &queries, &tx, &auth, TableUpdateType::Subscribe)
        else {
            // If we fail the query, we need to remove the subscription.
            let mut subscriptions = {
                // How contended is the lock?
                let _wait_guard = subscription_metrics.lock_waiters.inc_scope();
                let _wait_timer = subscription_metrics.lock_wait_time.start_timer();
                self.subscriptions.write()
            };
            {
                // NOTE(review): this times the removal with the compile-time
                // histogram — confirm that's intended (no compilation happens here).
                let _compile_timer = subscription_metrics.compilation_time.start_timer();
                subscriptions.remove_subscription((sender.id.identity, sender.id.connection_id), request.query_id)?;
            }

            send_err_msg("Internal error evaluating queries".into());
            return Ok(None);
        };

        // How many queries did we actually evaluate?
        subscription_metrics.num_queries_evaluated.inc_by(queries.len() as _);

        #[cfg(test)]
        if let Some(assert) = _assert {
            assert(&tx);
        }

        // Note: to make sure transaction updates are consistent, we need to put this in the broadcast
        // queue while we are still holding a read-lock on the database.

        // That will avoid race conditions because reducers first grab a write lock on the db, then
        // grab a read lock on the subscriptions.

        // Holding a write lock on `self.subscriptions` would also be sufficient.

        let _ = self.broadcast_queue.send_client_message(
            sender.clone(),
            SubscriptionMessage {
                request_id: Some(request.request_id),
                query_id: Some(request.query_id),
                timer: Some(timer),
                result: SubscriptionResult::SubscribeMulti(SubscriptionData { data: update }),
            },
        );

        Ok(Some(metrics))
    }
751
    /// Add a subscriber to the module. NOTE: this function is blocking.
    /// This is used for the legacy subscription API which uses a set of queries.
    ///
    /// Compiles every query in `subscription`, evaluates the initial result set
    /// under a read transaction, installs the query set in the
    /// `SubscriptionManager`, and enqueues the initial update message for
    /// `sender` while the read tx is still held (see ordering note below).
    ///
    /// Returns the [`ExecutionMetrics`] gathered while evaluating the initial
    /// subscription, or a [`DBError`] if compilation, the row-limit check, or
    /// evaluation fails.
    #[tracing::instrument(level = "trace", skip_all)]
    pub fn add_legacy_subscriber(
        &self,
        sender: Arc<ClientConnectionSender>,
        subscription: Subscribe,
        timer: Instant,
        _assert: Option<AssertTxFn>,
    ) -> Result<ExecutionMetrics, DBError> {
        let num_queries = subscription.query_strings.len();
        let database_identity = self.relational_db.database_identity();
        let subscription_metrics = SubscriptionMetrics::new(&database_identity, &WorkloadType::Subscribe);

        // How many queries make up this subscription?
        subscription_metrics.num_queries_subscribed.inc_by(num_queries as _);

        let (queries, auth, tx, compile_timer) = self.compile_queries(
            sender.id.identity,
            &subscription.query_strings,
            num_queries,
            &subscription_metrics,
        )?;
        // Ensure the read tx is always released and its metrics reported,
        // on both the success and the early-error paths below.
        let tx = scopeguard::guard(tx, |tx| {
            let (tx_metrics, reducer) = self.relational_db.release_tx(tx);
            self.relational_db.report_read_tx_metrics(reducer, tx_metrics);
        });

        // Reject the subscription up front if the estimated scan cost
        // (summed over all plan fragments) exceeds the configured row limit.
        check_row_limit(
            &queries,
            &self.relational_db,
            &tx,
            |plan, tx| {
                plan.plans_fragments()
                    .map(|plan_fragment| estimate_rows_scanned(tx, plan_fragment.optimized_physical_plan()))
                    .fold(0, |acc, rows_scanned| acc.saturating_add(rows_scanned))
            },
            &auth,
        )?;

        // Record how long it took to compile the subscription
        drop(compile_timer);

        let tx = DeltaTx::from(&*tx);
        // Evaluate the initial result set in the client's wire format.
        let (database_update, metrics) = match sender.config.protocol {
            Protocol::Binary => execute_plans(&queries, &tx, TableUpdateType::Subscribe)
                .map(|(table_update, metrics)| (FormatSwitch::Bsatn(table_update), metrics))?,
            Protocol::Text => execute_plans(&queries, &tx, TableUpdateType::Subscribe)
                .map(|(table_update, metrics)| (FormatSwitch::Json(table_update), metrics))?,
        };

        // It acquires the subscription lock after `eval`, allowing `add_subscription` to run concurrently.
        // This also makes it possible for `broadcast_event` to get scheduled before the subsequent part here
        // but that should not pose an issue.
        {
            // NOTE(review): this timer is the `compilation_time` metric but it
            // wraps the subscription-manager update, not query compilation —
            // looks mislabeled; confirm which metric was intended.
            let _compile_timer = subscription_metrics.compilation_time.start_timer();

            let mut subscriptions = {
                // How contended is the lock?
                let _wait_guard = subscription_metrics.lock_waiters.inc_scope();
                let _wait_timer = subscription_metrics.lock_wait_time.start_timer();
                self.subscriptions.write()
            };

            subscriptions.set_legacy_subscription(sender.clone(), queries.into_iter());
        }

        #[cfg(test)]
        if let Some(assert) = _assert {
            assert(&tx);
        }

        // Note: to make sure transaction updates are consistent, we need to put this in the broadcast
        // queue while we are still holding a read-lock on the database.

        // That will avoid race conditions because reducers first grab a write lock on the db, then
        // grab a read lock on the subscriptions.

        // Holding a write lock on `self.subscriptions` would also be sufficient.
        let _ = self.broadcast_queue.send_client_message(
            sender,
            SubscriptionUpdateMessage {
                database_update,
                request_id: Some(subscription.request_id),
                timer: Some(timer),
            },
        );

        Ok(metrics)
    }
842
843    pub fn remove_subscriber(&self, client_id: ClientActorId) {
844        let mut subscriptions = self.subscriptions.write();
845        subscriptions.remove_all_subscriptions(&(client_id.identity, client_id.connection_id));
846    }
847
    /// Commit a transaction and broadcast its ModuleEvent to all interested subscribers.
    ///
    /// The returned [`ExecutionMetrics`] are reported in this method via `report_tx_metrics`.
    /// They are returned for testing purposes but should not be reported separately.
    ///
    /// Returns `Ok(Err(WriteConflict))` when the commit loses a write-write
    /// conflict; `Err(_)` only for database-level failures.
    pub fn commit_and_broadcast_event(
        &self,
        caller: Option<Arc<ClientConnectionSender>>,
        mut event: ModuleEvent,
        tx: MutTx,
    ) -> Result<Result<(Arc<ModuleEvent>, ExecutionMetrics), WriteConflict>, DBError> {
        let database_identity = self.relational_db.database_identity();
        let subscription_metrics = SubscriptionMetrics::new(&database_identity, &WorkloadType::Update);

        // Take a read lock on `subscriptions` before committing tx
        // else it can result in subscriber receiving duplicate updates.
        let subscriptions = {
            // How contended is the lock?
            let _wait_guard = subscription_metrics.lock_waiters.inc_scope();
            let _wait_timer = subscription_metrics.lock_wait_time.start_timer();
            self.subscriptions.read()
        };

        let stdb = &self.relational_db;
        // Downgrade mutable tx.
        // We'll later ensure tx is released/cleaned up once out of scope.
        let (read_tx, tx_data, tx_metrics_mut) = match &mut event.status {
            EventStatus::Committed(db_update) => {
                // `commit_tx_downgrade` yields `None` on a write conflict.
                let Some((tx_data, tx_metrics, read_tx)) = stdb.commit_tx_downgrade(tx, Workload::Update)? else {
                    return Ok(Err(WriteConflict));
                };
                // Replace the event's update with the writes actually committed.
                *db_update = DatabaseUpdate::from_writes(&tx_data);
                (read_tx, Some(tx_data), tx_metrics)
            }
            EventStatus::Failed(_) | EventStatus::OutOfEnergy => {
                // Failed/out-of-energy reducers roll back, but still downgrade
                // to a read tx so the failure can be broadcast consistently.
                let (tx_metrics, tx) = stdb.rollback_mut_tx_downgrade(tx, Workload::Update);
                (tx, None, tx_metrics)
            }
        };

        let tx_data = tx_data.map(Arc::new);

        // When we're done with this method, release the tx and report metrics.
        let mut read_tx = scopeguard::guard(read_tx, |tx| {
            let (tx_metrics_read, reducer) = self.relational_db.release_tx(tx);
            self.relational_db
                .report_tx_metrics(reducer, tx_data.clone(), Some(tx_metrics_mut), Some(tx_metrics_read));
        });
        // Create the delta transaction we'll use to eval updates against.
        let delta_read_tx = tx_data
            .as_ref()
            .as_ref()
            .map(|tx_data| DeltaTx::new(&read_tx, tx_data, subscriptions.index_ids_for_subscriptions()))
            .unwrap_or_else(|| DeltaTx::from(&*read_tx));

        let event = Arc::new(event);
        let mut update_metrics: ExecutionMetrics = ExecutionMetrics::default();

        match &event.status {
            EventStatus::Committed(_) => {
                // Fan the committed delta out to all matching subscriptions.
                update_metrics = subscriptions.eval_updates_sequential(&delta_read_tx, event.clone(), caller);
            }
            EventStatus::Failed(_) => {
                // Only the calling client (if any) is told about a failed reducer.
                if let Some(client) = caller {
                    let message = TransactionUpdateMessage {
                        event: Some(event.clone()),
                        database_update: SubscriptionUpdateMessage::default_for_protocol(client.config.protocol, None),
                    };

                    let _ = self.broadcast_queue.send_client_message(client, message);
                } else {
                    log::trace!("Reducer failed but there is no client to send the failure to!")
                }
            }
            EventStatus::OutOfEnergy => {} // ?
        }

        // Merge in the subscription evaluation metrics.
        read_tx.metrics.merge(update_metrics);

        Ok(Ok((event, update_metrics)))
    }
929}
930
/// Marker type indicating that committing a transaction failed due to a
/// write-write conflict with a concurrently committed transaction.
pub struct WriteConflict;
932
933#[cfg(test)]
934mod tests {
935    use super::{AssertTxFn, ModuleSubscriptions};
936    use crate::client::messages::{
937        SerializableMessage, SubscriptionData, SubscriptionError, SubscriptionMessage, SubscriptionResult,
938        SubscriptionUpdateMessage, TransactionUpdateMessage,
939    };
940    use crate::client::{ClientActorId, ClientConfig, ClientConnectionSender, ClientName, MeteredReceiver, Protocol};
941    use crate::db::relational_db::tests_utils::{
942        begin_mut_tx, begin_tx, insert, with_auto_commit, with_read_only, TestDB,
943    };
944    use crate::db::relational_db::RelationalDB;
945    use crate::error::DBError;
946    use crate::host::module_host::{DatabaseUpdate, EventStatus, ModuleEvent, ModuleFunctionCall};
947    use crate::messages::websocket as ws;
948    use crate::sql::execute::run;
949    use crate::subscription::module_subscription_manager::{spawn_send_worker, SubscriptionManager};
950    use crate::subscription::query::compile_read_only_query;
951    use crate::subscription::TableUpdateType;
952    use hashbrown::HashMap;
953    use itertools::Itertools;
954    use pretty_assertions::assert_matches;
955    use spacetimedb_client_api_messages::energy::EnergyQuanta;
956    use spacetimedb_client_api_messages::websocket::{
957        CompressableQueryUpdate, Compression, FormatSwitch, QueryId, Subscribe, SubscribeMulti, SubscribeSingle,
958        TableUpdate, Unsubscribe, UnsubscribeMulti,
959    };
960    use spacetimedb_datastore::system_tables::{StRowLevelSecurityRow, ST_ROW_LEVEL_SECURITY_ID};
961    use spacetimedb_execution::dml::MutDatastore;
962    use spacetimedb_lib::bsatn::ToBsatn;
963    use spacetimedb_lib::db::auth::StAccess;
964    use spacetimedb_lib::identity::AuthCtx;
965    use spacetimedb_lib::metrics::ExecutionMetrics;
966    use spacetimedb_lib::{bsatn, ConnectionId, ProductType, ProductValue, Timestamp};
967    use spacetimedb_lib::{error::ResultTest, AlgebraicType, Identity};
968    use spacetimedb_primitives::TableId;
969    use spacetimedb_sats::product;
970    use std::time::Instant;
971    use std::{sync::Arc, time::Duration};
972    use tokio::sync::mpsc::{self};
973
974    fn add_subscriber(db: Arc<RelationalDB>, sql: &str, assert: Option<AssertTxFn>) -> Result<(), DBError> {
975        // Create and enter a Tokio runtime to run the `ModuleSubscriptions`' background workers in parallel.
976        let runtime = tokio::runtime::Runtime::new().unwrap();
977        let _rt = runtime.enter();
978        let owner = Identity::from_byte_array([1; 32]);
979        let client = ClientActorId::for_test(Identity::ZERO);
980        let config = ClientConfig::for_test();
981        let sender = Arc::new(ClientConnectionSender::dummy(client, config));
982        let send_worker_queue = spawn_send_worker(None);
983        let module_subscriptions = ModuleSubscriptions::new(
984            db.clone(),
985            SubscriptionManager::for_test_without_metrics_arc_rwlock(),
986            send_worker_queue,
987            owner,
988        );
989
990        let subscribe = Subscribe {
991            query_strings: [sql.into()].into(),
992            request_id: 0,
993        };
994        module_subscriptions.add_legacy_subscriber(sender, subscribe, Instant::now(), assert)?;
995        Ok(())
996    }
997
998    /// An in-memory `RelationalDB` for testing
999    fn relational_db() -> anyhow::Result<Arc<RelationalDB>> {
1000        let TestDB { db, .. } = TestDB::in_memory()?;
1001        Ok(Arc::new(db))
1002    }
1003
1004    /// A [SubscribeSingle] message for testing
1005    fn single_subscribe(sql: &str, query_id: u32) -> SubscribeSingle {
1006        SubscribeSingle {
1007            query: sql.into(),
1008            request_id: 0,
1009            query_id: QueryId::new(query_id),
1010        }
1011    }
1012
1013    /// A [SubscribeMulti] message for testing
1014    fn multi_subscribe(query_strings: &[&'static str], query_id: u32) -> SubscribeMulti {
1015        SubscribeMulti {
1016            query_strings: query_strings
1017                .iter()
1018                .map(|sql| String::from(*sql).into_boxed_str())
1019                .collect(),
1020            request_id: 0,
1021            query_id: QueryId::new(query_id),
1022        }
1023    }
1024
    /// An [UnsubscribeMulti] message for testing
    fn multi_unsubscribe(query_id: u32) -> UnsubscribeMulti {
        UnsubscribeMulti {
            request_id: 0,
            query_id: QueryId::new(query_id),
        }
    }
1032
1033    /// An [Unsubscribe] message for testing
1034    fn single_unsubscribe(query_id: u32) -> Unsubscribe {
1035        Unsubscribe {
1036            request_id: 0,
1037            query_id: QueryId::new(query_id),
1038        }
1039    }
1040
1041    /// A dummy [ModuleEvent] for testing
1042    fn module_event() -> ModuleEvent {
1043        ModuleEvent {
1044            timestamp: Timestamp::now(),
1045            caller_identity: Identity::ZERO,
1046            caller_connection_id: None,
1047            function_call: ModuleFunctionCall::default(),
1048            status: EventStatus::Committed(DatabaseUpdate::default()),
1049            energy_quanta_used: EnergyQuanta { quanta: 0 },
1050            host_execution_duration: Duration::from_millis(0),
1051            request_id: None,
1052            timer: None,
1053        }
1054    }
1055
1056    /// Create an [Identity] from a [u8]
1057    fn identity_from_u8(v: u8) -> Identity {
1058        Identity::from_byte_array([v; 32])
1059    }
1060
1061    /// Create an [ConnectionId] from a [u8]
1062    fn connection_id_from_u8(v: u8) -> ConnectionId {
1063        ConnectionId::from_be_byte_array([v; 16])
1064    }
1065
1066    /// Create an [ClientActorId] from a [u8].
1067    /// Calls [identity_from_u8] internally with the passed value.
1068    fn client_id_from_u8(v: u8) -> ClientActorId {
1069        ClientActorId {
1070            identity: identity_from_u8(v),
1071            connection_id: connection_id_from_u8(v),
1072            name: ClientName(v as u64),
1073        }
1074    }
1075
1076    /// Instantiate a client connection with compression
1077    fn client_connection_with_compression(
1078        client_id: ClientActorId,
1079        compression: Compression,
1080    ) -> (Arc<ClientConnectionSender>, MeteredReceiver<SerializableMessage>) {
1081        let (sender, rx) = ClientConnectionSender::dummy_with_channel(
1082            client_id,
1083            ClientConfig {
1084                protocol: Protocol::Binary,
1085                compression,
1086                tx_update_full: true,
1087            },
1088        );
1089        (Arc::new(sender), rx)
1090    }
1091
    /// Instantiate a client connection
    ///
    /// Compression is disabled so tests can inspect raw row payloads.
    fn client_connection(
        client_id: ClientActorId,
    ) -> (Arc<ClientConnectionSender>, MeteredReceiver<SerializableMessage>) {
        client_connection_with_compression(client_id, Compression::None)
    }
1098
1099    /// Insert rules into the RLS system table
1100    fn insert_rls_rules(
1101        db: &RelationalDB,
1102        table_ids: impl IntoIterator<Item = TableId>,
1103        rules: impl IntoIterator<Item = &'static str>,
1104    ) -> anyhow::Result<()> {
1105        with_auto_commit(db, |tx| {
1106            for (table_id, sql) in table_ids.into_iter().zip(rules) {
1107                db.insert(
1108                    tx,
1109                    ST_ROW_LEVEL_SECURITY_ID,
1110                    &ProductValue::from(StRowLevelSecurityRow {
1111                        table_id,
1112                        sql: sql.into(),
1113                    })
1114                    .to_bsatn_vec()?,
1115                )?;
1116            }
1117            Ok(())
1118        })
1119    }
1120
1121    /// Subscribe to a query as a client
1122    fn subscribe_single(
1123        subs: &ModuleSubscriptions,
1124        sql: &'static str,
1125        sender: Arc<ClientConnectionSender>,
1126        counter: &mut u32,
1127    ) -> anyhow::Result<()> {
1128        *counter += 1;
1129        subs.add_single_subscription(sender, single_subscribe(sql, *counter), Instant::now(), None)?;
1130        Ok(())
1131    }
1132
1133    /// Subscribe to a set of queries as a client
1134    fn subscribe_multi(
1135        subs: &ModuleSubscriptions,
1136        queries: &[&'static str],
1137        sender: Arc<ClientConnectionSender>,
1138        counter: &mut u32,
1139    ) -> anyhow::Result<ExecutionMetrics> {
1140        *counter += 1;
1141        let metrics = subs
1142            .add_multi_subscription(sender, multi_subscribe(queries, *counter), Instant::now(), None)
1143            .map(|metrics| metrics.unwrap_or_default())?;
1144        Ok(metrics)
1145    }
1146
1147    /// Unsubscribe from a single query
1148    fn unsubscribe_single(
1149        subs: &ModuleSubscriptions,
1150        sender: Arc<ClientConnectionSender>,
1151        query_id: u32,
1152    ) -> anyhow::Result<()> {
1153        subs.remove_single_subscription(sender, single_unsubscribe(query_id), Instant::now())?;
1154        Ok(())
1155    }
1156
1157    /// Unsubscribe from a set of queries
1158    fn unsubscribe_multi(
1159        subs: &ModuleSubscriptions,
1160        sender: Arc<ClientConnectionSender>,
1161        query_id: u32,
1162    ) -> anyhow::Result<()> {
1163        subs.remove_multi_subscription(sender, multi_unsubscribe(query_id), Instant::now())?;
1164        Ok(())
1165    }
1166
    /// Pull a message from receiver and assert that it is a `TxUpdate` with the expected rows
    ///
    /// Rows are tallied with a signed net count (+1 per insert, -1 per delete),
    /// so a row inserted and deleted an equal number of times cancels out and
    /// is expected in neither set.
    async fn assert_tx_update_for_table(
        rx: &mut MeteredReceiver<SerializableMessage>,
        table_id: TableId,
        schema: &ProductType,
        inserts: impl IntoIterator<Item = ProductValue>,
        deletes: impl IntoIterator<Item = ProductValue>,
    ) {
        match rx.recv().await {
            Some(SerializableMessage::TxUpdate(TransactionUpdateMessage {
                database_update:
                    SubscriptionUpdateMessage {
                        database_update: FormatSwitch::Bsatn(ws::DatabaseUpdate { mut tables }),
                        ..
                    },
                ..
            })) => {
                // Assume an update for only one table
                assert_eq!(tables.len(), 1);

                let table_update = tables.pop().unwrap();

                // We should not be sending empty updates to clients
                assert_ne!(table_update.num_rows, 0);

                // It should be the table we expect
                assert_eq!(table_update.table_id, table_id);

                // Net count per row: positive = inserted, negative = deleted.
                let mut rows_received: HashMap<ProductValue, i32> = HashMap::new();

                for uncompressed in table_update.updates {
                    let CompressableQueryUpdate::Uncompressed(table_update) = uncompressed else {
                        panic!("expected an uncompressed table update")
                    };

                    // Decode each BSATN-encoded insert against the given schema.
                    for row in table_update
                        .inserts
                        .into_iter()
                        .map(|bytes| ProductValue::decode(schema, &mut &*bytes).unwrap())
                    {
                        *rows_received.entry(row).or_insert(0) += 1;
                    }

                    for row in table_update
                        .deletes
                        .into_iter()
                        .map(|bytes| ProductValue::decode(schema, &mut &*bytes).unwrap())
                    {
                        *rows_received.entry(row).or_insert(0) -= 1;
                    }
                }

                // Rows with a positive net count must equal the expected inserts.
                assert_eq!(
                    rows_received
                        .iter()
                        .filter(|(_, n)| n > &&0)
                        .map(|(row, _)| row)
                        .cloned()
                        .sorted()
                        .collect::<Vec<_>>(),
                    inserts.into_iter().sorted().collect::<Vec<_>>()
                );
                // Rows with a negative net count must equal the expected deletes.
                assert_eq!(
                    rows_received
                        .iter()
                        .filter(|(_, n)| n < &&0)
                        .map(|(row, _)| row)
                        .cloned()
                        .sorted()
                        .collect::<Vec<_>>(),
                    deletes.into_iter().sorted().collect::<Vec<_>>()
                );
            }
            Some(msg) => panic!("expected a TxUpdate, but got {msg:#?}"),
            None => panic!("The receiver closed due to an error"),
        }
    }
1244
1245    /// Commit a set of row updates and broadcast to subscribers
1246    fn commit_tx(
1247        db: &RelationalDB,
1248        subs: &ModuleSubscriptions,
1249        deletes: impl IntoIterator<Item = (TableId, ProductValue)>,
1250        inserts: impl IntoIterator<Item = (TableId, ProductValue)>,
1251    ) -> anyhow::Result<ExecutionMetrics> {
1252        let mut tx = begin_mut_tx(db);
1253        for (table_id, row) in deletes {
1254            tx.delete_product_value(table_id, &row)?;
1255        }
1256        for (table_id, row) in inserts {
1257            db.insert(&mut tx, table_id, &bsatn::to_vec(&row)?)?;
1258        }
1259
1260        let Ok(Ok((_, metrics))) = subs.commit_and_broadcast_event(None, module_event(), tx) else {
1261            panic!("Encountered an error in `commit_and_broadcast_event`");
1262        };
1263        Ok(metrics)
1264    }
1265
    /// Verify the [ExecutionMetrics] reported when evaluating an initial
    /// subscription against an indexed table.
    #[test]
    fn test_subscribe_metrics() -> anyhow::Result<()> {
        let client_id = client_id_from_u8(1);
        let (sender, _) = client_connection(client_id);

        let db = relational_db()?;
        let (subs, _runtime) = ModuleSubscriptions::for_test_new_runtime(db.clone());

        // Create a table `t` with index on `id`
        let table_id = db.create_table_for_test("t", &[("id", AlgebraicType::U64)], &[0.into()])?;
        with_auto_commit(&db, |tx| -> anyhow::Result<_> {
            db.insert(tx, table_id, &bsatn::to_vec(&product![1_u64])?)?;
            Ok(())
        })?;

        let auth = AuthCtx::for_testing();
        let sql = "select * from t where id = 1";
        let tx = begin_tx(&db);
        let plan = compile_read_only_query(&auth, &tx, sql)?;
        let plan = Arc::new(plan);

        let (_, metrics) = subs.evaluate_queries(sender, &[plan], &tx, &auth, TableUpdateType::Subscribe)?;

        // We only probe the index once
        assert_eq!(metrics.index_seeks, 1);
        // We scan a single u64 when serializing the result
        assert_eq!(metrics.bytes_scanned, 8);
        // Subscriptions are read-only
        assert_eq!(metrics.bytes_written, 0);
        // Bytes scanned and bytes sent will always be the same for an initial subscription,
        // because a subscription is initiated by a single client.
        assert_eq!(metrics.bytes_sent_to_clients, 8);

        // Note, rows scanned may be greater than one.
        // It depends on the number of operators used to answer the query.
        assert!(metrics.rows_scanned > 0);
        Ok(())
    }
1304
1305    fn check_subscription_err(sql: &str, result: Option<SerializableMessage>) {
1306        if let Some(SerializableMessage::Subscription(SubscriptionMessage {
1307            result: SubscriptionResult::Error(SubscriptionError { message, .. }),
1308            ..
1309        })) = result
1310        {
1311            assert!(
1312                message.contains(sql),
1313                "Expected error message to contain the SQL query: {sql}, but got: {message}",
1314            );
1315            return;
1316        }
1317        panic!("Expected a subscription error message, but got: {result:?}");
1318    }
1319
1320    /// Test that clients receive error messages on subscribe
1321    #[tokio::test]
1322    async fn subscribe_single_error() -> anyhow::Result<()> {
1323        let client_id = client_id_from_u8(1);
1324        let (tx, mut rx) = client_connection(client_id);
1325
1326        let db = relational_db()?;
1327        let subs = ModuleSubscriptions::for_test_enclosing_runtime(db.clone());
1328
1329        db.create_table_for_test("t", &[("x", AlgebraicType::U8)], &[])?;
1330
1331        // Subscribe to an invalid query (r is not in scope)
1332        let sql = "select r.* from t";
1333        subscribe_single(&subs, sql, tx, &mut 0)?;
1334
1335        check_subscription_err(sql, rx.recv().await);
1336
1337        Ok(())
1338    }
1339
1340    /// Test that clients receive error messages on subscribe
1341    #[tokio::test]
1342    async fn subscribe_multi_error() -> anyhow::Result<()> {
1343        let client_id = client_id_from_u8(1);
1344        let (tx, mut rx) = client_connection(client_id);
1345
1346        let db = relational_db()?;
1347        let subs = ModuleSubscriptions::for_test_enclosing_runtime(db.clone());
1348
1349        db.create_table_for_test("t", &[("x", AlgebraicType::U8)], &[])?;
1350
1351        // Subscribe to an invalid query (r is not in scope)
1352        let sql = "select r.* from t";
1353        subscribe_multi(&subs, &[sql], tx, &mut 0)?;
1354
1355        check_subscription_err(sql, rx.recv().await);
1356
1357        Ok(())
1358    }
1359
    /// Test that clients receive error messages on unsubscribe
    ///
    /// Exercises the path where the cached plan for a subscription becomes
    /// invalid (its index is dropped) before the client unsubscribes.
    #[tokio::test]
    async fn unsubscribe_single_error() -> anyhow::Result<()> {
        let client_id = client_id_from_u8(1);
        let (tx, mut rx) = client_connection(client_id);

        let db = relational_db()?;
        let subs = ModuleSubscriptions::for_test_enclosing_runtime(db.clone());

        // Create a table `t` with an index on `id`
        let table_id = db.create_table_for_test("t", &[("id", AlgebraicType::U8)], &[0.into()])?;
        // Look up the id of that (only) index so we can drop it later.
        let index_id = with_read_only(&db, |tx| {
            db.schema_for_table(&*tx, table_id).map(|schema| {
                schema
                    .indexes
                    .first()
                    .map(|index_schema| index_schema.index_id)
                    .unwrap()
            })
        })?;

        let mut query_id = 0;

        // Subscribe to `t`
        let sql = "select * from t where id = 1";
        subscribe_single(&subs, sql, tx.clone(), &mut query_id)?;

        // The initial subscription should succeed
        assert!(matches!(
            rx.recv().await,
            Some(SerializableMessage::Subscription(SubscriptionMessage {
                result: SubscriptionResult::Subscribe(..),
                ..
            }))
        ));

        // Remove the index from `id`
        with_auto_commit(&db, |tx| db.drop_index(tx, index_id))?;

        // Unsubscribe from `t`
        unsubscribe_single(&subs, tx, query_id)?;

        // Why does the unsubscribe fail?
        // This relies on some knowledge of the underlying implementation.
        // Specifically that we do not recompile queries on unsubscribe.
        // We execute the cached plan which in this case is an index scan.
        // The index no longer exists, and therefore it fails.
        check_subscription_err(sql, rx.recv().await);

        Ok(())
    }
1411
    /// Test that clients receive error messages on unsubscribe
    ///
    /// Multi-query variant of `unsubscribe_single_error`: the cached plan's
    /// index is dropped before the client unsubscribes.
    ///
    /// Needs a multi-threaded tokio runtime so that the module subscription worker can run in parallel.
    #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
    async fn unsubscribe_multi_error() -> anyhow::Result<()> {
        let client_id = client_id_from_u8(1);
        let (tx, mut rx) = client_connection(client_id);

        let db = relational_db()?;
        let subs = ModuleSubscriptions::for_test_enclosing_runtime(db.clone());

        // Create a table `t` with an index on `id`
        let table_id = db.create_table_for_test("t", &[("id", AlgebraicType::U8)], &[0.into()])?;
        // Look up the id of that (only) index so we can drop it later.
        let index_id = with_read_only(&db, |tx| {
            db.schema_for_table(&*tx, table_id).map(|schema| {
                schema
                    .indexes
                    .first()
                    .map(|index_schema| index_schema.index_id)
                    .unwrap()
            })
        })?;

        commit_tx(&db, &subs, [], [(table_id, product![0_u8])])?;

        let mut query_id = 0;

        // Subscribe to `t`
        let sql = "select * from t where id = 1";
        subscribe_multi(&subs, &[sql], tx.clone(), &mut query_id)?;

        // The initial subscription should succeed
        assert!(matches!(
            rx.recv().await,
            Some(SerializableMessage::Subscription(SubscriptionMessage {
                result: SubscriptionResult::SubscribeMulti(..),
                ..
            }))
        ));

        // Remove the index from `id`
        with_auto_commit(&db, |tx| db.drop_index(tx, index_id))?;

        // Unsubscribe from `t`
        unsubscribe_multi(&subs, tx, query_id)?;

        // Why does the unsubscribe fail?
        // This relies on some knowledge of the underlying implementation.
        // Specifically that we do not recompile queries on unsubscribe.
        // We execute the cached plan which in this case is an index scan.
        // The index no longer exists, and therefore it fails.
        check_subscription_err(sql, rx.recv().await);

        Ok(())
    }
1467
    /// Test that clients receive error messages on tx updates
    ///
    /// A cached join plan is invalidated (its index is dropped) between the
    /// initial subscribe and a subsequent transaction update.
    #[tokio::test]
    async fn tx_update_error() -> anyhow::Result<()> {
        let client_id = client_id_from_u8(1);
        let (tx, mut rx) = client_connection(client_id);

        let db = relational_db()?;
        let subs = ModuleSubscriptions::for_test_enclosing_runtime(db.clone());

        // Create two tables `t` and `s` with indexes on their `id` columns
        let t_id = db.create_table_for_test("t", &[("id", AlgebraicType::U8)], &[0.into()])?;
        let s_id = db.create_table_for_test("s", &[("id", AlgebraicType::U8)], &[0.into()])?;
        // Look up the id of the index on `s` so we can drop it later.
        let index_id = with_read_only(&db, |tx| {
            db.schema_for_table(&*tx, s_id).map(|schema| {
                schema
                    .indexes
                    .first()
                    .map(|index_schema| index_schema.index_id)
                    .unwrap()
            })
        })?;
        let sql = "select t.* from t join s on t.id = s.id";
        subscribe_single(&subs, sql, tx, &mut 0)?;

        // The initial subscription should succeed
        assert!(matches!(
            rx.recv().await,
            Some(SerializableMessage::Subscription(SubscriptionMessage {
                result: SubscriptionResult::Subscribe(..),
                ..
            }))
        ));

        // Remove the index from `s`
        with_auto_commit(&db, |tx| db.drop_index(tx, index_id))?;

        // Start a new transaction and insert a new row into `t`
        let mut tx = begin_mut_tx(&db);
        db.insert(&mut tx, t_id, &bsatn::to_vec(&product![2_u8])?)?;

        assert!(matches!(
            subs.commit_and_broadcast_event(None, module_event(), tx),
            Ok(Ok(_))
        ));

        // Why does the update fail?
        // This relies on some knowledge of the underlying implementation.
        // Specifically, plans are cached on the initial subscribe.
        // Hence we execute a cached plan which happens to be an index join.
        // We've removed the index on `s`, and therefore it fails.
        check_subscription_err(sql, rx.recv().await);

        Ok(())
    }
1522
1523    /// Test that two clients can subscribe to a parameterized query and get the correct rows.
1524    #[tokio::test]
1525    async fn test_parameterized_subscription() -> anyhow::Result<()> {
1526        // Create identities for two different clients
1527        let id_for_a = identity_from_u8(1);
1528        let id_for_b = identity_from_u8(2);
1529
1530        let client_id_for_a = client_id_from_u8(1);
1531        let client_id_for_b = client_id_from_u8(2);
1532
1533        // Establish a connection for each client
1534        let (tx_for_a, mut rx_for_a) = client_connection(client_id_for_a);
1535        let (tx_for_b, mut rx_for_b) = client_connection(client_id_for_b);
1536
1537        let db = relational_db()?;
1538        let subs = ModuleSubscriptions::for_test_enclosing_runtime(db.clone());
1539
1540        let schema = [("identity", AlgebraicType::identity())];
1541
1542        let table_id = db.create_table_for_test("t", &schema, &[])?;
1543
1544        let mut query_ids = 0;
1545
1546        // Have each client subscribe to the same parameterized query.
1547        // Each client should receive different rows.
1548        subscribe_multi(
1549            &subs,
1550            &["select * from t where identity = :sender"],
1551            tx_for_a,
1552            &mut query_ids,
1553        )?;
1554        subscribe_multi(
1555            &subs,
1556            &["select * from t where identity = :sender"],
1557            tx_for_b,
1558            &mut query_ids,
1559        )?;
1560
1561        // Wait for both subscriptions
1562        assert!(matches!(
1563            rx_for_a.recv().await,
1564            Some(SerializableMessage::Subscription(_))
1565        ));
1566        assert!(matches!(
1567            rx_for_b.recv().await,
1568            Some(SerializableMessage::Subscription(_))
1569        ));
1570
1571        // Insert two identities - one for each caller - into the table
1572        let mut tx = begin_mut_tx(&db);
1573        db.insert(&mut tx, table_id, &bsatn::to_vec(&product![id_for_a])?)?;
1574        db.insert(&mut tx, table_id, &bsatn::to_vec(&product![id_for_b])?)?;
1575
1576        assert!(matches!(
1577            subs.commit_and_broadcast_event(None, module_event(), tx),
1578            Ok(Ok(_))
1579        ));
1580
1581        let schema = ProductType::from([AlgebraicType::identity()]);
1582
1583        // Both clients should only receive their identities and not the other's.
1584        assert_tx_update_for_table(&mut rx_for_a, table_id, &schema, [product![id_for_a]], []).await;
1585        assert_tx_update_for_table(&mut rx_for_b, table_id, &schema, [product![id_for_b]], []).await;
1586        Ok(())
1587    }
1588
1589    /// Test that two clients can subscribe to a table with RLS rules and get the correct rows
1590    #[tokio::test]
1591    async fn test_rls_subscription() -> anyhow::Result<()> {
1592        // Create identities for two different clients
1593        let id_for_a = identity_from_u8(1);
1594        let id_for_b = identity_from_u8(2);
1595
1596        let client_id_for_a = client_id_from_u8(1);
1597        let client_id_for_b = client_id_from_u8(2);
1598
1599        // Establish a connection for each client
1600        let (tx_for_a, mut rx_for_a) = client_connection(client_id_for_a);
1601        let (tx_for_b, mut rx_for_b) = client_connection(client_id_for_b);
1602
1603        let db = relational_db()?;
1604        let subs = ModuleSubscriptions::for_test_enclosing_runtime(db.clone());
1605
1606        let schema = [("id", AlgebraicType::identity())];
1607
1608        let u_id = db.create_table_for_test("u", &schema, &[0.into()])?;
1609        let v_id = db.create_table_for_test("v", &schema, &[0.into()])?;
1610        let w_id = db.create_table_for_test("w", &schema, &[0.into()])?;
1611
1612        insert_rls_rules(
1613            &db,
1614            [u_id, v_id, w_id, w_id],
1615            [
1616                "select * from u where id = :sender",
1617                "select * from v where id = :sender",
1618                "select w.* from u join w on u.id = w.id",
1619                "select w.* from v join w on v.id = w.id",
1620            ],
1621        )?;
1622
1623        let mut query_ids = 0;
1624
1625        // Have each client subscribe to `w`.
1626        // Because `w` is gated using parameterized RLS rules,
1627        // each client should receive different rows.
1628        subscribe_multi(&subs, &["select * from w"], tx_for_a, &mut query_ids)?;
1629        subscribe_multi(&subs, &["select * from w"], tx_for_b, &mut query_ids)?;
1630
1631        // Wait for both subscriptions
1632        assert!(matches!(
1633            rx_for_a.recv().await,
1634            Some(SerializableMessage::Subscription(_))
1635        ));
1636        assert!(matches!(
1637            rx_for_b.recv().await,
1638            Some(SerializableMessage::Subscription(_))
1639        ));
1640
1641        // Insert a row into `u` for client "a".
1642        // Insert a row into `v` for client "b".
1643        // Insert a row into `w` for both.
1644        let mut tx = begin_mut_tx(&db);
1645        db.insert(&mut tx, u_id, &bsatn::to_vec(&product![id_for_a])?)?;
1646        db.insert(&mut tx, v_id, &bsatn::to_vec(&product![id_for_b])?)?;
1647        db.insert(&mut tx, w_id, &bsatn::to_vec(&product![id_for_a])?)?;
1648        db.insert(&mut tx, w_id, &bsatn::to_vec(&product![id_for_b])?)?;
1649
1650        assert!(matches!(
1651            subs.commit_and_broadcast_event(None, module_event(), tx),
1652            Ok(Ok(_))
1653        ));
1654
1655        let schema = ProductType::from([AlgebraicType::identity()]);
1656
1657        // Both clients should only receive their identities and not the other's.
1658        assert_tx_update_for_table(&mut rx_for_a, w_id, &schema, [product![id_for_a]], []).await;
1659        assert_tx_update_for_table(&mut rx_for_b, w_id, &schema, [product![id_for_b]], []).await;
1660        Ok(())
1661    }
1662
1663    /// Test that a client and the database owner can subscribe to the same query
1664    #[tokio::test]
1665    async fn test_rls_for_owner() -> anyhow::Result<()> {
1666        // Establish a connection for owner and client
1667        let (tx_for_a, mut rx_for_a) = client_connection(client_id_from_u8(0));
1668        let (tx_for_b, mut rx_for_b) = client_connection(client_id_from_u8(1));
1669
1670        let db = relational_db()?;
1671        let subs = ModuleSubscriptions::for_test_enclosing_runtime(db.clone());
1672
1673        // Create table `t`
1674        let table_id = db.create_table_for_test("t", &[("id", AlgebraicType::identity())], &[0.into()])?;
1675
1676        // Restrict access to `t`
1677        insert_rls_rules(&db, [table_id], ["select * from t where id = :sender"])?;
1678
1679        let mut query_ids = 0;
1680
1681        // Have owner and client subscribe to `t`
1682        subscribe_multi(&subs, &["select * from t"], tx_for_a, &mut query_ids)?;
1683        subscribe_multi(&subs, &["select * from t"], tx_for_b, &mut query_ids)?;
1684
1685        // Wait for both subscriptions
1686        assert_matches!(
1687            rx_for_a.recv().await,
1688            Some(SerializableMessage::Subscription(SubscriptionMessage {
1689                result: SubscriptionResult::SubscribeMulti(_),
1690                ..
1691            }))
1692        );
1693        assert_matches!(
1694            rx_for_b.recv().await,
1695            Some(SerializableMessage::Subscription(SubscriptionMessage {
1696                result: SubscriptionResult::SubscribeMulti(_),
1697                ..
1698            }))
1699        );
1700
1701        let schema = ProductType::from([AlgebraicType::identity()]);
1702
1703        let id_for_b = identity_from_u8(1);
1704        let id_for_c = identity_from_u8(2);
1705
1706        commit_tx(
1707            &db,
1708            &subs,
1709            [],
1710            [
1711                // Insert an identity for client `b` plus a random identity
1712                (table_id, product![id_for_b]),
1713                (table_id, product![id_for_c]),
1714            ],
1715        )?;
1716
1717        assert_tx_update_for_table(
1718            &mut rx_for_a,
1719            table_id,
1720            &schema,
1721            // The owner should receive both identities
1722            [product![id_for_b], product![id_for_c]],
1723            [],
1724        )
1725        .await;
1726
1727        assert_tx_update_for_table(
1728            &mut rx_for_b,
1729            table_id,
1730            &schema,
1731            // Client `b` should only receive its identity
1732            [product![id_for_b]],
1733            [],
1734        )
1735        .await;
1736
1737        Ok(())
1738    }
1739
1740    /// Test that we do not send empty updates to clients
1741    #[tokio::test]
1742    async fn test_no_empty_updates() -> anyhow::Result<()> {
1743        // Establish a client connection
1744        let (tx, mut rx) = client_connection(client_id_from_u8(1));
1745
1746        let db = relational_db()?;
1747        let subs = ModuleSubscriptions::for_test_enclosing_runtime(db.clone());
1748
1749        let schema = [("x", AlgebraicType::U8)];
1750
1751        let t_id = db.create_table_for_test("t", &schema, &[])?;
1752
1753        // Subscribe to rows of `t` where `x` is 0
1754        subscribe_multi(&subs, &["select * from t where x = 0"], tx, &mut 0)?;
1755
1756        // Wait to receive the initial subscription message
1757        assert!(matches!(rx.recv().await, Some(SerializableMessage::Subscription(_))));
1758
1759        // Insert a row that does not match the query
1760        let mut tx = begin_mut_tx(&db);
1761        db.insert(&mut tx, t_id, &bsatn::to_vec(&product![1_u8])?)?;
1762
1763        assert!(matches!(
1764            subs.commit_and_broadcast_event(None, module_event(), tx),
1765            Ok(Ok(_))
1766        ));
1767
1768        // Insert a row that does match the query
1769        let mut tx = begin_mut_tx(&db);
1770        db.insert(&mut tx, t_id, &bsatn::to_vec(&product![0_u8])?)?;
1771
1772        assert!(matches!(
1773            subs.commit_and_broadcast_event(None, module_event(), tx),
1774            Ok(Ok(_))
1775        ));
1776
1777        let schema = ProductType::from([AlgebraicType::U8]);
1778
1779        // If the server sends empty updates, this assertion will fail,
1780        // because we will receive one for the first transaction.
1781        assert_tx_update_for_table(&mut rx, t_id, &schema, [product![0_u8]], []).await;
1782        Ok(())
1783    }
1784
1785    /// Test that we do not compress within a [SubscriptionMessage].
1786    /// The message itself is compressed before being sent over the wire,
1787    /// but we don't care about that for this test.
1788    ///
1789    /// Needs a multi-threaded tokio runtime so that the module subscription worker can run in parallel.
1790    #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
1791    async fn test_no_compression_for_subscribe() -> anyhow::Result<()> {
1792        // Establish a client connection with compression
1793        let (tx, mut rx) = client_connection_with_compression(client_id_from_u8(1), Compression::Brotli);
1794
1795        let db = relational_db()?;
1796        let subs = ModuleSubscriptions::for_test_enclosing_runtime(db.clone());
1797
1798        let table_id = db.create_table_for_test("t", &[("x", AlgebraicType::U64)], &[])?;
1799
1800        let mut inserts = vec![];
1801
1802        for i in 0..16_000u64 {
1803            inserts.push((table_id, product![i]));
1804        }
1805
1806        // Insert a lot of rows into `t`.
1807        // We want to insert enough to cross any threshold there might be for compression.
1808        commit_tx(&db, &subs, [], inserts)?;
1809
1810        // Subscribe to the entire table
1811        subscribe_multi(&subs, &["select * from t"], tx, &mut 0)?;
1812
1813        // Assert the table updates within this message are all be uncompressed
1814        match rx.recv().await {
1815            Some(SerializableMessage::Subscription(SubscriptionMessage {
1816                result:
1817                    SubscriptionResult::SubscribeMulti(SubscriptionData {
1818                        data: FormatSwitch::Bsatn(ws::DatabaseUpdate { tables }),
1819                    }),
1820                ..
1821            })) => {
1822                assert!(tables.iter().all(|TableUpdate { updates, .. }| updates
1823                    .iter()
1824                    .all(|query_update| matches!(query_update, CompressableQueryUpdate::Uncompressed(_)))));
1825            }
1826            Some(_) => panic!("unexpected message from subscription"),
1827            None => panic!("channel unexpectedly closed"),
1828        };
1829
1830        Ok(())
1831    }
1832
1833    /// Test that we receive subscription updates for DML
1834    #[tokio::test]
1835    async fn test_updates_for_dml() -> anyhow::Result<()> {
1836        // Establish a client connection
1837        let (tx, mut rx) = client_connection(client_id_from_u8(1));
1838
1839        let db = relational_db()?;
1840        let subs = ModuleSubscriptions::for_test_enclosing_runtime(db.clone());
1841        let schema = [("x", AlgebraicType::U8), ("y", AlgebraicType::U8)];
1842        let t_id = db.create_table_for_test("t", &schema, &[])?;
1843
1844        // Subscribe to `t`
1845        subscribe_multi(&subs, &["select * from t"], tx, &mut 0)?;
1846
1847        // Wait to receive the initial subscription message
1848        assert_matches!(rx.recv().await, Some(SerializableMessage::Subscription(_)));
1849
1850        let schema = ProductType::from([AlgebraicType::U8, AlgebraicType::U8]);
1851
1852        // Only the owner can invoke DML commands
1853        let auth = AuthCtx::new(identity_from_u8(0), identity_from_u8(0));
1854
1855        run(
1856            &db,
1857            "INSERT INTO t (x, y) VALUES (0, 1)",
1858            auth,
1859            Some(&subs),
1860            &mut vec![],
1861        )?;
1862
1863        // Client should receive insert
1864        assert_tx_update_for_table(&mut rx, t_id, &schema, [product![0_u8, 1_u8]], []).await;
1865
1866        run(&db, "UPDATE t SET y=2 WHERE x=0", auth, Some(&subs), &mut vec![])?;
1867
1868        // Client should receive update
1869        assert_tx_update_for_table(&mut rx, t_id, &schema, [product![0_u8, 2_u8]], [product![0_u8, 1_u8]]).await;
1870
1871        run(&db, "DELETE FROM t WHERE x=0", auth, Some(&subs), &mut vec![])?;
1872
1873        // Client should receive delete
1874        assert_tx_update_for_table(&mut rx, t_id, &schema, [], [product![0_u8, 2_u8]]).await;
1875        Ok(())
1876    }
1877
1878    /// Test that we do not compress within a [TransactionUpdateMessage].
1879    /// The message itself is compressed before being sent over the wire,
1880    /// but we don't care about that for this test.
1881    #[tokio::test]
1882    async fn test_no_compression_for_update() -> anyhow::Result<()> {
1883        // Establish a client connection with compression
1884        let (tx, mut rx) = client_connection_with_compression(client_id_from_u8(1), Compression::Brotli);
1885
1886        let db = relational_db()?;
1887        let subs = ModuleSubscriptions::for_test_enclosing_runtime(db.clone());
1888
1889        let table_id = db.create_table_for_test("t", &[("x", AlgebraicType::U64)], &[])?;
1890
1891        let mut inserts = vec![];
1892
1893        for i in 0..16_000u64 {
1894            inserts.push((table_id, product![i]));
1895        }
1896
1897        // Subscribe to the entire table
1898        subscribe_multi(&subs, &["select * from t"], tx, &mut 0)?;
1899
1900        // Wait to receive the initial subscription message
1901        assert!(matches!(rx.recv().await, Some(SerializableMessage::Subscription(_))));
1902
1903        // Insert a lot of rows into `t`.
1904        // We want to insert enough to cross any threshold there might be for compression.
1905        commit_tx(&db, &subs, [], inserts)?;
1906
1907        // Assert the table updates within this message are all be uncompressed
1908        match rx.recv().await {
1909            Some(SerializableMessage::TxUpdate(TransactionUpdateMessage {
1910                database_update:
1911                    SubscriptionUpdateMessage {
1912                        database_update: FormatSwitch::Bsatn(ws::DatabaseUpdate { tables }),
1913                        ..
1914                    },
1915                ..
1916            })) => {
1917                assert!(tables.iter().all(|TableUpdate { updates, .. }| updates
1918                    .iter()
1919                    .all(|query_update| matches!(query_update, CompressableQueryUpdate::Uncompressed(_)))));
1920            }
1921            Some(_) => panic!("unexpected message from subscription"),
1922            None => panic!("channel unexpectedly closed"),
1923        };
1924
1925        Ok(())
1926    }
1927
    /// In this test we subscribe to a join query, update the lhs table,
    /// and assert that the server sends the correct delta to the client.
    #[tokio::test]
    async fn test_update_for_join() -> anyhow::Result<()> {
        // Helper that runs the whole insert/update scenario against one set of
        // subscription queries. Called below with several syntactic variants of
        // the same join query, all of which should behave identically.
        async fn test_subscription_updates(queries: &[&'static str]) -> anyhow::Result<()> {
            // Establish a client connection
            let (sender, mut rx) = client_connection(client_id_from_u8(1));

            let db = relational_db()?;
            let subs = ModuleSubscriptions::for_test_enclosing_runtime(db.clone());

            // `p` is the lhs table updated below; `l` is the rhs of the join.
            let p_schema = [("id", AlgebraicType::U64), ("signed_in", AlgebraicType::Bool)];
            let l_schema = [
                ("id", AlgebraicType::U64),
                ("x", AlgebraicType::U64),
                ("z", AlgebraicType::U64),
            ];

            let p_id = db.create_table_for_test("p", &p_schema, &[0.into()])?;
            let l_id = db.create_table_for_test("l", &l_schema, &[0.into()])?;

            subscribe_multi(&subs, queries, sender, &mut 0)?;

            // Wait for the initial subscription result
            assert!(matches!(rx.recv().await, Some(SerializableMessage::Subscription(_))));

            // Insert two matching player rows
            commit_tx(
                &db,
                &subs,
                [],
                [
                    (p_id, product![1_u64, true]),
                    (p_id, product![2_u64, true]),
                    (l_id, product![1_u64, 2_u64, 2_u64]),
                    (l_id, product![2_u64, 3_u64, 3_u64]),
                ],
            )?;

            let schema = ProductType::from(p_schema);

            // We should receive both matching player rows
            assert_tx_update_for_table(
                &mut rx,
                p_id,
                &schema,
                [product![1_u64, true], product![2_u64, true]],
                [],
            )
            .await;

            // Update one of the matching player rows
            commit_tx(
                &db,
                &subs,
                [(p_id, product![2_u64, true])],
                [(p_id, product![2_u64, false])],
            )?;

            // We should receive an update for it because it is still matching
            assert_tx_update_for_table(
                &mut rx,
                p_id,
                &schema,
                [product![2_u64, false]],
                [product![2_u64, true]],
            )
            .await;

            // Update the same matching player row, back to its original value
            commit_tx(
                &db,
                &subs,
                [(p_id, product![2_u64, false])],
                [(p_id, product![2_u64, true])],
            )?;

            // We should receive an update for it because it is still matching
            assert_tx_update_for_table(
                &mut rx,
                p_id,
                &schema,
                [product![2_u64, true]],
                [product![2_u64, false]],
            )
            .await;

            Ok(())
        }

        test_subscription_updates(&[
            "select * from p where id = 1",
            "select p.* from p join l on p.id = l.id where l.x > 0 and l.x < 5 and l.z > 0 and l.z < 5",
        ])
        .await?;
        test_subscription_updates(&[
            "select * from p where id = 1",
            "select p.* from p join l on p.id = l.id where 0 < l.x and l.x < 5 and 0 < l.z and l.z < 5",
        ])
        .await?;
        // NOTE(review): `l.x > 0` appears twice in this predicate — was the second
        // occurrence meant to be `l.z > 0`, as in the other variants? With the `l`
        // rows inserted above the result is the same either way; confirm intent.
        test_subscription_updates(&[
            "select * from p where id = 1",
            "select p.* from p join l on p.id = l.id where l.x > 0 and l.x < 5 and l.x > 0 and l.z < 5 and l.id != 1",
        ])
        .await?;
        test_subscription_updates(&[
            "select * from p where id = 1",
            "select p.* from p join l on p.id = l.id where 0 < l.x and l.x < 5 and 0 < l.z and l.z < 5 and l.id != 1",
        ])
        .await?;

        Ok(())
    }
2040
    /// Test that we do not evaluate queries that we know will not match table update rows
    ///
    /// Needs a multi-threaded tokio runtime so that the module subscription worker can run in parallel.
    #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
    async fn test_query_pruning() -> anyhow::Result<()> {
        // Establish a connection for each client
        let (tx_for_a, mut rx_for_a) = client_connection(client_id_from_u8(1));
        let (tx_for_b, mut rx_for_b) = client_connection(client_id_from_u8(2));

        let db = relational_db()?;
        let subs = ModuleSubscriptions::for_test_enclosing_runtime(db.clone());

        // `u` is indexed on `i`; `v` is indexed on both `i` and `x`,
        // so updates to `v` can be pruned by their `x` value.
        let u_id = db.create_table_for_test(
            "u",
            &[
                ("i", AlgebraicType::U64),
                ("a", AlgebraicType::U64),
                ("b", AlgebraicType::U64),
            ],
            &[0.into()],
        )?;
        let v_id = db.create_table_for_test(
            "v",
            &[
                ("i", AlgebraicType::U64),
                ("x", AlgebraicType::U64),
                ("y", AlgebraicType::U64),
            ],
            &[0.into(), 1.into()],
        )?;

        // Seed both tables
        commit_tx(
            &db,
            &subs,
            [],
            [
                (u_id, product![0u64, 1u64, 1u64]),
                (u_id, product![1u64, 2u64, 2u64]),
                (u_id, product![2u64, 3u64, 3u64]),
                (v_id, product![0u64, 4u64, 4u64]),
                (v_id, product![1u64, 5u64, 5u64]),
            ],
        )?;

        let mut query_ids = 0;

        // Returns (i: 0, a: 1, b: 1)
        subscribe_multi(
            &subs,
            &[
                "select u.* from u join v on u.i = v.i where v.x = 4",
                "select u.* from u join v on u.i = v.i where v.x = 6",
            ],
            tx_for_a,
            &mut query_ids,
        )?;

        // Returns (i: 1, a: 2, b: 2)
        subscribe_multi(
            &subs,
            &[
                "select u.* from u join v on u.i = v.i where v.x = 5",
                "select u.* from u join v on u.i = v.i where v.x = 7",
            ],
            tx_for_b,
            &mut query_ids,
        )?;

        // Wait for both subscriptions
        assert!(matches!(
            rx_for_a.recv().await,
            Some(SerializableMessage::Subscription(SubscriptionMessage {
                result: SubscriptionResult::SubscribeMulti(_),
                ..
            }))
        ));
        assert!(matches!(
            rx_for_b.recv().await,
            Some(SerializableMessage::Subscription(SubscriptionMessage {
                result: SubscriptionResult::SubscribeMulti(_),
                ..
            }))
        ));

        // Modify a single row in `v`: only `y` changes, `x` stays 5
        let metrics = commit_tx(
            &db,
            &subs,
            [(v_id, product![1u64, 5u64, 5u64])],
            [(v_id, product![1u64, 5u64, 6u64])],
        )?;

        // We should only have evaluated a single query (the one with `v.x = 5`);
        // the other three are pruned because no touched row has their `x` value.
        assert_eq!(metrics.delta_queries_evaluated, 1);
        assert_eq!(metrics.delta_queries_matched, 0);

        // Insert a new row into `v` with `x = 6`
        let metrics = commit_tx(&db, &subs, [], [(v_id, product![2u64, 6u64, 6u64])])?;

        // Client "a" subscribed to `v.x = 6`, so it receives the joining `u` row
        assert_tx_update_for_table(
            &mut rx_for_a,
            u_id,
            &ProductType::from([AlgebraicType::U64, AlgebraicType::U64, AlgebraicType::U64]),
            [product![2u64, 3u64, 3u64]],
            [],
        )
        .await;

        // We should only have evaluated a single query
        assert_eq!(metrics.delta_queries_evaluated, 1);
        assert_eq!(metrics.delta_queries_matched, 1);

        // Modify a matching row in `u`
        let metrics = commit_tx(
            &db,
            &subs,
            [(u_id, product![1u64, 2u64, 2u64])],
            [(u_id, product![1u64, 2u64, 3u64])],
        )?;

        assert_tx_update_for_table(
            &mut rx_for_b,
            u_id,
            &ProductType::from([AlgebraicType::U64, AlgebraicType::U64, AlgebraicType::U64]),
            [product![1u64, 2u64, 3u64]],
            [product![1u64, 2u64, 2u64]],
        )
        .await;

        // We should have evaluated all of the queries,
        // since a `u` update cannot be pruned by `v.x`
        assert_eq!(metrics.delta_queries_evaluated, 4);
        assert_eq!(metrics.delta_queries_matched, 1);

        // Insert a non-matching row in `u`
        let metrics = commit_tx(&db, &subs, [], [(u_id, product![3u64, 0u64, 0u64])])?;

        // We should have evaluated all of the queries
        assert_eq!(metrics.delta_queries_evaluated, 4);
        assert_eq!(metrics.delta_queries_matched, 0);

        Ok(())
    }
2183
    /// Test that we do not evaluate queries that we know will not match row updates
    #[tokio::test]
    async fn test_join_pruning() -> anyhow::Result<()> {
        let (tx, mut rx) = client_connection(client_id_from_u8(1));

        let db = relational_db()?;
        let subs = ModuleSubscriptions::for_test_enclosing_runtime(db.clone());

        let u_id = db.create_table_for_test_with_the_works(
            "u",
            &[
                ("i", AlgebraicType::U64),
                ("a", AlgebraicType::U64),
                ("b", AlgebraicType::U64),
            ],
            &[0.into()],
            // The join column for this table does not have to be unique,
            // because pruning only requires us to probe the join index on `v`.
            &[],
            StAccess::Public,
        )?;
        let v_id = db.create_table_for_test_with_the_works(
            "v",
            &[
                ("i", AlgebraicType::U64),
                ("x", AlgebraicType::U64),
                ("y", AlgebraicType::U64),
            ],
            &[0.into(), 1.into()],
            &[0.into()],
            StAccess::Public,
        )?;

        let schema = ProductType::from([AlgebraicType::U64, AlgebraicType::U64, AlgebraicType::U64]);

        // Seed `v` with rows (i, x, y) = (k, k, k) for k in 1..=5
        commit_tx(
            &db,
            &subs,
            [],
            [
                (v_id, product![1u64, 1u64, 1u64]),
                (v_id, product![2u64, 2u64, 2u64]),
                (v_id, product![3u64, 3u64, 3u64]),
                (v_id, product![4u64, 4u64, 4u64]),
                (v_id, product![5u64, 5u64, 5u64]),
            ],
        )?;

        let mut query_ids = 0;

        // Subscribe to one join query per `x` value, so each update below can
        // only match the queries whose `x` literal appears in a touched row.
        subscribe_multi(
            &subs,
            &[
                "select u.* from u join v on u.i = v.i where v.x = 1",
                "select u.* from u join v on u.i = v.i where v.x = 2",
                "select u.* from u join v on u.i = v.i where v.x = 3",
                "select u.* from u join v on u.i = v.i where v.x = 4",
                "select u.* from u join v on u.i = v.i where v.x = 5",
            ],
            tx,
            &mut query_ids,
        )?;

        assert_matches!(
            rx.recv().await,
            Some(SerializableMessage::Subscription(SubscriptionMessage {
                result: SubscriptionResult::SubscribeMulti(_),
                ..
            }))
        );

        // Insert a new row into `u` that joins with `x = 1`
        let metrics = commit_tx(&db, &subs, [], [(u_id, product![1u64, 2u64, 3u64])])?;

        assert_tx_update_for_table(&mut rx, u_id, &schema, [product![1u64, 2u64, 3u64]], []).await;

        // We should only have evaluated a single query
        assert_eq!(metrics.delta_queries_evaluated, 1);
        assert_eq!(metrics.delta_queries_matched, 1);

        // UPDATE v SET y = 2 WHERE id = 1
        let metrics = commit_tx(
            &db,
            &subs,
            [(v_id, product![1u64, 1u64, 1u64])],
            [(v_id, product![1u64, 1u64, 2u64])],
        )?;

        // We should only have evaluated a single query,
        // since `x` is unchanged and only `y` was modified
        assert_eq!(metrics.delta_queries_evaluated, 1);
        assert_eq!(metrics.delta_queries_matched, 0);

        // UPDATE v SET x = 2 WHERE id = 1
        let metrics = commit_tx(
            &db,
            &subs,
            [(v_id, product![1u64, 1u64, 2u64])],
            [(v_id, product![1u64, 2u64, 2u64])],
        )?;

        // Results in a no-op
        assert_tx_update_for_table(&mut rx, u_id, &schema, [], []).await;

        // We should have evaluated queries for `x = 1` and `x = 2`
        // (the old and new values of the modified row)
        assert_eq!(metrics.delta_queries_evaluated, 2);
        assert_eq!(metrics.delta_queries_matched, 2);

        // Insert new row into `u` that joins with `x = 3`
        // UPDATE v SET x = 4 WHERE id = 3
        let metrics = commit_tx(
            &db,
            &subs,
            [(v_id, product![3u64, 3u64, 3u64])],
            [(v_id, product![3u64, 4u64, 3u64]), (u_id, product![3u64, 4u64, 5u64])],
        )?;

        assert_tx_update_for_table(&mut rx, u_id, &schema, [product![3u64, 4u64, 5u64]], []).await;

        // We should have evaluated queries for `x = 3` and `x = 4`
        assert_eq!(metrics.delta_queries_evaluated, 2);
        assert_eq!(metrics.delta_queries_matched, 1);

        // UPDATE v SET x = 0 WHERE id = 3
        let metrics = commit_tx(
            &db,
            &subs,
            [(v_id, product![3u64, 4u64, 3u64])],
            [(v_id, product![3u64, 0u64, 3u64])],
        )?;

        assert_tx_update_for_table(&mut rx, u_id, &schema, [], [product![3u64, 4u64, 5u64]]).await;

        // We should only have evaluated the query for `x = 4`;
        // no subscription exists for the new value `x = 0`
        assert_eq!(metrics.delta_queries_evaluated, 1);
        assert_eq!(metrics.delta_queries_matched, 1);

        // Insert new row into `u` that joins with `x = 5`
        // UPDATE v SET x = 6 WHERE id = 5
        // Should result in a no-op
        let metrics = commit_tx(
            &db,
            &subs,
            [(v_id, product![5u64, 5u64, 5u64])],
            [(v_id, product![5u64, 6u64, 6u64]), (u_id, product![5u64, 6u64, 7u64])],
        )?;

        // We should only have evaluated the query for `x = 5`;
        // no subscription exists for the new value `x = 6`
        assert_eq!(metrics.delta_queries_evaluated, 1);
        assert_eq!(metrics.delta_queries_matched, 0);

        Ok(())
    }
2336
2337    /// Test that one client subscribing does not affect another
2338    #[tokio::test]
2339    async fn test_subscribe_distinct_queries_same_plan() -> anyhow::Result<()> {
2340        // Establish a connection for each client
2341        let (tx_for_a, mut rx_for_a) = client_connection(client_id_from_u8(1));
2342        let (tx_for_b, mut rx_for_b) = client_connection(client_id_from_u8(2));
2343
2344        let db = relational_db()?;
2345        let subs = ModuleSubscriptions::for_test_enclosing_runtime(db.clone());
2346
2347        let u_id = db.create_table_for_test_with_the_works(
2348            "u",
2349            &[
2350                ("i", AlgebraicType::U64),
2351                ("a", AlgebraicType::U64),
2352                ("b", AlgebraicType::U64),
2353            ],
2354            &[0.into()],
2355            // The join column for this table does not have to be unique,
2356            // because pruning only requires us to probe the join index on `v`.
2357            &[],
2358            StAccess::Public,
2359        )?;
2360        let v_id = db.create_table_for_test_with_the_works(
2361            "v",
2362            &[
2363                ("i", AlgebraicType::U64),
2364                ("x", AlgebraicType::U64),
2365                ("y", AlgebraicType::U64),
2366            ],
2367            &[0.into(), 1.into()],
2368            &[0.into()],
2369            StAccess::Public,
2370        )?;
2371
2372        commit_tx(&db, &subs, [], [(v_id, product![1u64, 1u64, 1u64])])?;
2373
2374        let mut query_ids = 0;
2375
2376        // Both clients subscribe to the same query modulo whitespace
2377        subscribe_multi(
2378            &subs,
2379            &["select u.* from u join v on u.i = v.i where v.x = 1"],
2380            tx_for_a,
2381            &mut query_ids,
2382        )?;
2383        subscribe_multi(
2384            &subs,
2385            &["select u.* from u join v on u.i = v.i where v.x =  1"],
2386            tx_for_b.clone(),
2387            &mut query_ids,
2388        )?;
2389
2390        // Wait for both subscriptions
2391        assert_matches!(
2392            rx_for_a.recv().await,
2393            Some(SerializableMessage::Subscription(SubscriptionMessage {
2394                result: SubscriptionResult::SubscribeMulti(_),
2395                ..
2396            }))
2397        );
2398        assert_matches!(
2399            rx_for_b.recv().await,
2400            Some(SerializableMessage::Subscription(SubscriptionMessage {
2401                result: SubscriptionResult::SubscribeMulti(_),
2402                ..
2403            }))
2404        );
2405
2406        // Insert a new row into `u`
2407        commit_tx(&db, &subs, [], [(u_id, product![1u64, 0u64, 0u64])])?;
2408
2409        assert_tx_update_for_table(
2410            &mut rx_for_a,
2411            u_id,
2412            &ProductType::from([AlgebraicType::U64, AlgebraicType::U64, AlgebraicType::U64]),
2413            [product![1u64, 0u64, 0u64]],
2414            [],
2415        )
2416        .await;
2417
2418        assert_tx_update_for_table(
2419            &mut rx_for_b,
2420            u_id,
2421            &ProductType::from([AlgebraicType::U64, AlgebraicType::U64, AlgebraicType::U64]),
2422            [product![1u64, 0u64, 0u64]],
2423            [],
2424        )
2425        .await;
2426
2427        Ok(())
2428    }
2429
2430    /// Test that one client unsubscribing does not affect another
2431    #[tokio::test]
2432    async fn test_unsubscribe_distinct_queries_same_plan() -> anyhow::Result<()> {
2433        // Establish a connection for each client
2434        let (tx_for_a, mut rx_for_a) = client_connection(client_id_from_u8(1));
2435        let (tx_for_b, mut rx_for_b) = client_connection(client_id_from_u8(2));
2436
2437        let db = relational_db()?;
2438        let subs = ModuleSubscriptions::for_test_enclosing_runtime(db.clone());
2439
2440        let u_id = db.create_table_for_test_with_the_works(
2441            "u",
2442            &[
2443                ("i", AlgebraicType::U64),
2444                ("a", AlgebraicType::U64),
2445                ("b", AlgebraicType::U64),
2446            ],
2447            &[0.into()],
2448            // The join column for this table does not have to be unique,
2449            // because pruning only requires us to probe the join index on `v`.
2450            &[],
2451            StAccess::Public,
2452        )?;
2453        let v_id = db.create_table_for_test_with_the_works(
2454            "v",
2455            &[
2456                ("i", AlgebraicType::U64),
2457                ("x", AlgebraicType::U64),
2458                ("y", AlgebraicType::U64),
2459            ],
2460            &[0.into(), 1.into()],
2461            &[0.into()],
2462            StAccess::Public,
2463        )?;
2464
2465        commit_tx(&db, &subs, [], [(v_id, product![1u64, 1u64, 1u64])])?;
2466
2467        let mut query_ids = 0;
2468
2469        subscribe_multi(
2470            &subs,
2471            &["select u.* from u join v on u.i = v.i where v.x = 1"],
2472            tx_for_a,
2473            &mut query_ids,
2474        )?;
2475        subscribe_multi(
2476            &subs,
2477            &["select u.* from u join v on u.i = v.i where  v.x = 1"],
2478            tx_for_b.clone(),
2479            &mut query_ids,
2480        )?;
2481
2482        // Wait for both subscriptions
2483        assert_matches!(
2484            rx_for_a.recv().await,
2485            Some(SerializableMessage::Subscription(SubscriptionMessage {
2486                result: SubscriptionResult::SubscribeMulti(_),
2487                ..
2488            }))
2489        );
2490        assert_matches!(
2491            rx_for_b.recv().await,
2492            Some(SerializableMessage::Subscription(SubscriptionMessage {
2493                result: SubscriptionResult::SubscribeMulti(_),
2494                ..
2495            }))
2496        );
2497
2498        unsubscribe_multi(&subs, tx_for_b, query_ids)?;
2499
2500        assert_matches!(
2501            rx_for_b.recv().await,
2502            Some(SerializableMessage::Subscription(SubscriptionMessage {
2503                result: SubscriptionResult::UnsubscribeMulti(_),
2504                ..
2505            }))
2506        );
2507
2508        // Insert a new row into `u`
2509        let metrics = commit_tx(&db, &subs, [], [(u_id, product![1u64, 0u64, 0u64])])?;
2510
2511        assert_tx_update_for_table(
2512            &mut rx_for_a,
2513            u_id,
2514            &ProductType::from([AlgebraicType::U64, AlgebraicType::U64, AlgebraicType::U64]),
2515            [product![1u64, 0u64, 0u64]],
2516            [],
2517        )
2518        .await;
2519
2520        // We should only have evaluated a single query
2521        assert_eq!(metrics.delta_queries_evaluated, 1);
2522        assert_eq!(metrics.delta_queries_matched, 1);
2523
2524        // Modify a matching row in `v`
2525        let metrics = commit_tx(
2526            &db,
2527            &subs,
2528            [(v_id, product![1u64, 1u64, 1u64])],
2529            [(v_id, product![1u64, 2u64, 2u64])],
2530        )?;
2531
2532        // We should only have evaluated a single query
2533        assert_eq!(metrics.delta_queries_evaluated, 1);
2534        assert_eq!(metrics.delta_queries_matched, 1);
2535
2536        Ok(())
2537    }
2538
2539    /// Test that we do not evaluate queries that return trivially empty results
2540    ///
2541    /// Needs a multi-threaded tokio runtime so that the module subscription worker can run in parallel.
2542    #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
2543    async fn test_query_pruning_for_empty_tables() -> anyhow::Result<()> {
2544        // Establish a client connection
2545        let (tx, mut rx) = client_connection(client_id_from_u8(1));
2546
2547        let db = relational_db()?;
2548        let subs = ModuleSubscriptions::for_test_enclosing_runtime(db.clone());
2549
2550        let schema = &[("id", AlgebraicType::U64), ("a", AlgebraicType::U64)];
2551        let indices = &[0.into()];
2552        // Create tables `t` and `s` with `(i: u64, a: u64)`.
2553        db.create_table_for_test("t", schema, indices)?;
2554        let s_id = db.create_table_for_test("s", schema, indices)?;
2555
2556        // Insert one row into `s`, but leave `t` empty.
2557        commit_tx(&db, &subs, [], [(s_id, product![0u64, 0u64])])?;
2558
2559        // Subscribe to queries that return empty results
2560        let metrics = subscribe_multi(
2561            &subs,
2562            &[
2563                "select t.* from t where a = 0",
2564                "select t.* from t join s on t.id = s.id where s.a = 0",
2565                "select s.* from t join s on t.id = s.id where t.a = 0",
2566            ],
2567            tx,
2568            &mut 0,
2569        )?;
2570
2571        assert_matches!(
2572            rx.recv().await,
2573            Some(SerializableMessage::Subscription(SubscriptionMessage {
2574                result: SubscriptionResult::SubscribeMulti(_),
2575                ..
2576            }))
2577        );
2578
2579        assert_eq!(metrics.rows_scanned, 0);
2580        assert_eq!(metrics.index_seeks, 0);
2581
2582        Ok(())
2583    }
2584
    /// Asserts that a subscription holds a tx handle for the entire length of its evaluation.
    ///
    /// A reader task evaluates a subscription while deliberately sleeping inside the
    /// evaluation callback; a writer task concurrently inserts a second row. If the
    /// subscription holds its tx for the whole evaluation, the reader can only observe
    /// the single row that existed when its tx began.
    #[test]
    fn test_tx_subscription_ordering() -> ResultTest<()> {
        let test_db = TestDB::durable()?;

        let runtime = test_db.runtime().cloned().unwrap();
        let db = Arc::new(test_db.db.clone());

        // Create table with one row
        let table_id = db.create_table_for_test("T", &[("a", AlgebraicType::U8)], &[])?;
        with_auto_commit(&db, |tx| insert(&db, tx, table_id, &product!(1_u8)).map(drop))?;

        // Channel the reader uses to signal the writer that its tx has started.
        let (send, mut recv) = mpsc::unbounded_channel();

        // Subscribing to T should return a single row.
        // Run on a blocking thread because the callback sleeps.
        let db2 = db.clone();
        let query_handle = runtime.spawn_blocking(move || {
            add_subscriber(
                db.clone(),
                "select * from T",
                Some(Arc::new(move |tx: &_| {
                    // Wake up writer thread after starting the reader tx
                    let _ = send.send(());
                    // Then go to sleep
                    std::thread::sleep(Duration::from_secs(1));
                    // Assuming subscription evaluation holds a lock on the db,
                    // any mutations to T will necessarily occur after,
                    // and therefore we should only see a single row returned.
                    assert_eq!(1, db.iter(tx, table_id).unwrap().count());
                })),
            )
        });

        // Write a second row to T concurrently with the reader thread
        let write_handle = runtime.spawn(async move {
            // Block until the reader's tx is live, then mutate.
            let _ = recv.recv().await;
            with_auto_commit(&db2, |tx| insert(&db2, tx, table_id, &product!(2_u8)).map(drop))
        });

        // `??` surfaces both the join error and the task's own result.
        runtime.block_on(write_handle)??;
        runtime.block_on(query_handle)??;

        test_db.close()?;

        Ok(())
    }
2631
2632    #[test]
2633    fn subs_cannot_access_private_tables() -> ResultTest<()> {
2634        let test_db = TestDB::durable()?;
2635        let db = Arc::new(test_db.db.clone());
2636
2637        // Create a public table.
2638        let indexes = &[0.into()];
2639        let cols = &[("a", AlgebraicType::U8)];
2640        let _ = db.create_table_for_test("public", cols, indexes)?;
2641
2642        // Create a private table.
2643        let _ = db.create_table_for_test_with_access("private", cols, indexes, StAccess::Private)?;
2644
2645        // We can subscribe to a public table.
2646        let subscribe = |sql| add_subscriber(db.clone(), sql, None);
2647        assert!(subscribe("SELECT * FROM public").is_ok());
2648
2649        // We cannot subscribe when a private table is mentioned,
2650        // not even when in a join where the projection doesn't mention the table,
2651        // as the mere fact of joining can leak information from the private table.
2652        for sql in [
2653            "SELECT * FROM private",
2654            // Even if the query will return no rows, we still reject it.
2655            "SELECT * FROM private WHERE false",
2656            "SELECT private.* FROM private",
2657            "SELECT public.* FROM public JOIN private ON public.a = private.a WHERE private.a = 1",
2658            "SELECT private.* FROM private JOIN public ON private.a = public.a WHERE public.a = 1",
2659        ] {
2660            assert!(subscribe(sql).is_err(),);
2661        }
2662
2663        Ok(())
2664    }
2665}