spacetimedb/subscription/
module_subscription_manager.rs

1use super::execution_unit::QueryHash;
2use super::tx::DeltaTx;
3use crate::client::messages::{
4    SerializableMessage, SubscriptionError, SubscriptionMessage, SubscriptionResult, SubscriptionUpdateMessage,
5    TransactionUpdateMessage,
6};
7use crate::client::{ClientConnectionSender, Protocol};
8use crate::error::DBError;
9use crate::host::module_host::{DatabaseTableUpdate, ModuleEvent, UpdatesRelValue};
10use crate::messages::websocket::{self as ws, TableUpdate};
11use crate::subscription::delta::eval_delta;
12use crate::worker_metrics::WORKER_METRICS;
13use core::mem;
14use hashbrown::hash_map::OccupiedError;
15use hashbrown::{HashMap, HashSet};
16use parking_lot::RwLock;
17use prometheus::IntGauge;
18use spacetimedb_client_api_messages::websocket::{
19    BsatnFormat, CompressableQueryUpdate, FormatSwitch, JsonFormat, QueryId, QueryUpdate, SingleQueryUpdate,
20    WebsocketFormat,
21};
22use spacetimedb_data_structures::map::{Entry, IntMap};
23use spacetimedb_datastore::locking_tx_datastore::state_view::StateView;
24use spacetimedb_lib::metrics::ExecutionMetrics;
25use spacetimedb_lib::{AlgebraicValue, ConnectionId, Identity, ProductValue};
26use spacetimedb_primitives::{ColId, IndexId, TableId};
27use spacetimedb_subscription::{JoinEdge, SubscriptionPlan, TableName};
28use std::collections::BTreeMap;
29use std::fmt::Debug;
30use std::sync::atomic::{AtomicBool, Ordering};
31use std::sync::Arc;
32use tokio::sync::mpsc;
33
/// Clients are uniquely identified by their Identity and ConnectionId.
/// Identity is insufficient because different ConnectionIds can use the same Identity.
/// TODO: Determine if ConnectionId is sufficient for uniquely identifying a client.
type ClientId = (Identity, ConnectionId);
/// A cached, shared subscription plan.
type Query = Arc<Plan>;
/// Shared handle for sending messages to a single client.
type Client = Arc<ClientConnectionSender>;
/// A single-table update in either the BSATN or JSON wire format.
type SwitchedTableUpdate = FormatSwitch<TableUpdate<BsatnFormat>, TableUpdate<JsonFormat>>;
/// A full database update in either the BSATN or JSON wire format.
type SwitchedDbUpdate = FormatSwitch<ws::DatabaseUpdate<BsatnFormat>, ws::DatabaseUpdate<JsonFormat>>;

/// ClientQueryId is an identifier for a query set by the client.
type ClientQueryId = QueryId;
/// SubscriptionId is a globally unique identifier for a subscription.
type SubscriptionId = (ClientId, ClientQueryId);
47
/// A cached subscription query: its SQL text, its hash, and its compiled plan fragments.
#[derive(Debug)]
pub struct Plan {
    /// Hash of the query; used as its identity throughout the subscription manager.
    hash: QueryHash,
    /// The original SQL text of the query.
    sql: String,
    /// The compiled plan fragments.
    /// Holds more than one element only when a table has multiple RLS rules.
    plans: Vec<SubscriptionPlan>,
}
54
55impl Plan {
56    /// Create a new subscription plan to be cached
57    pub fn new(plans: Vec<SubscriptionPlan>, hash: QueryHash, text: String) -> Self {
58        Self { plans, hash, sql: text }
59    }
60
61    /// Returns the query hash for this subscription
62    pub fn hash(&self) -> QueryHash {
63        self.hash
64    }
65
66    /// A subscription query return rows from a single table.
67    /// This method returns the id of that table.
68    pub fn subscribed_table_id(&self) -> TableId {
69        self.plans[0].subscribed_table_id()
70    }
71
72    /// A subscription query return rows from a single table.
73    /// This method returns the name of that table.
74    pub fn subscribed_table_name(&self) -> &str {
75        self.plans[0].subscribed_table_name()
76    }
77
78    /// Returns the index ids from which this subscription reads
79    pub fn index_ids(&self) -> impl Iterator<Item = (TableId, IndexId)> {
80        self.plans
81            .iter()
82            .flat_map(|plan| plan.index_ids())
83            .collect::<HashSet<_>>()
84            .into_iter()
85    }
86
87    /// Returns the table ids from which this subscription reads
88    pub fn table_ids(&self) -> impl Iterator<Item = TableId> + '_ {
89        self.plans
90            .iter()
91            .flat_map(|plan| plan.table_ids())
92            .collect::<HashSet<_>>()
93            .into_iter()
94    }
95
96    /// Return the search arguments for this query
97    fn search_args(&self) -> impl Iterator<Item = (TableId, ColId, AlgebraicValue)> {
98        let mut args = HashSet::new();
99        for arg in self
100            .plans
101            .iter()
102            .flat_map(|subscription| subscription.optimized_physical_plan().search_args())
103        {
104            args.insert(arg);
105        }
106        args.into_iter()
107    }
108
109    /// Returns the plan fragments that comprise this subscription.
110    /// Will only return one element unless there is a table with multiple RLS rules.
111    pub fn plans_fragments(&self) -> impl Iterator<Item = &SubscriptionPlan> + '_ {
112        self.plans.iter()
113    }
114
115    /// Returns the join edges for this plan, if any.
116    pub fn join_edges(&self) -> impl Iterator<Item = (JoinEdge, AlgebraicValue)> + '_ {
117        self.plans.iter().filter_map(|plan| plan.join_edge())
118    }
119
120    /// The `SQL` text of this subscription.
121    pub fn sql(&self) -> &str {
122        &self.sql
123    }
124}
125
/// For each client, we hold a handle for sending messages, and we track the queries they are subscribed to.
#[derive(Debug)]
struct ClientInfo {
    /// Handle used to send outbound messages to this client.
    outbound_ref: Client,
    /// The set of query hashes registered under each of this client's subscription ids.
    subscriptions: HashMap<SubscriptionId, HashSet<QueryHash>>,
    /// How many of this client's subscriptions reference each query hash.
    /// Kept consistent with `subscriptions`; see `assert_ref_count_consistency`.
    subscription_ref_count: HashMap<QueryHash, usize>,
    // This should be removed when we migrate to SubscribeSingle.
    legacy_subscriptions: HashSet<QueryHash>,
    /// This flag is set if an error occurs during a tx update.
    /// It will be cleaned up async or on resubscribe.
    ///
    /// [`Arc`]ed so that this can be updated by the [`SendWorker`]
    /// and observed by [`SubscriptionManager::remove_dropped_clients`].
    dropped: Arc<AtomicBool>,
}
141
142impl ClientInfo {
143    fn new(outbound_ref: Client) -> Self {
144        Self {
145            outbound_ref,
146            subscriptions: HashMap::default(),
147            subscription_ref_count: HashMap::default(),
148            legacy_subscriptions: HashSet::default(),
149            dropped: Arc::new(AtomicBool::new(false)),
150        }
151    }
152
153    /// Check that the subscription ref count matches the actual number of subscriptions.
154    #[cfg(test)]
155    fn assert_ref_count_consistency(&self) {
156        let mut expected_ref_count = HashMap::new();
157        for query_hashes in self.subscriptions.values() {
158            for query_hash in query_hashes {
159                assert!(
160                    self.subscription_ref_count.contains_key(query_hash),
161                    "Query hash not found: {query_hash:?}"
162                );
163                expected_ref_count
164                    .entry(*query_hash)
165                    .and_modify(|count| *count += 1)
166                    .or_insert(1);
167            }
168        }
169        assert_eq!(
170            self.subscription_ref_count, expected_ref_count,
171            "Checking the reference totals failed"
172        );
173    }
174}
175
/// For each query that has subscribers, we track a set of legacy subscribers and individual subscriptions.
#[derive(Debug)]
struct QueryState {
    /// The cached plan shared by all subscribers of this query.
    query: Query,
    // For legacy clients that subscribe to a set of queries, we track them here.
    legacy_subscribers: HashSet<ClientId>,
    // For clients that subscribe to a single query, we track them here.
    subscriptions: HashSet<ClientId>,
}
185
186impl QueryState {
187    fn new(query: Query) -> Self {
188        Self {
189            query,
190            legacy_subscribers: HashSet::default(),
191            subscriptions: HashSet::default(),
192        }
193    }
194    fn has_subscribers(&self) -> bool {
195        !self.subscriptions.is_empty() || !self.legacy_subscribers.is_empty()
196    }
197
198    // This returns all of the clients listening to a query. If a client has multiple subscriptions for this query, it will appear twice.
199    fn all_clients(&self) -> impl Iterator<Item = &ClientId> {
200        itertools::chain(&self.legacy_subscribers, &self.subscriptions)
201    }
202
203    /// Return the [`Query`] for this [`QueryState`]
204    pub fn query(&self) -> &Query {
205        &self.query
206    }
207
208    /// Return the search arguments for this query
209    fn search_args(&self) -> impl Iterator<Item = (TableId, ColId, AlgebraicValue)> {
210        self.query.search_args()
211    }
212}
213
/// In this container, we keep track of parameterized subscription queries.
/// This is used to prune unnecessary queries during subscription evaluation.
///
/// TODO: This container is populated on initial subscription.
/// Ideally this information would be stored in the datastore,
/// but because subscriptions are evaluated using a read only tx,
/// we have to manage this memory separately.
///
/// If we stored this information in the datastore,
/// we could encode pruning logic in the execution plan itself.
#[derive(Debug, Default)]
pub struct SearchArguments {
    /// We parameterize subscriptions if they have an equality selection.
    /// In this case a parameter is a [TableId], [ColId] pair.
    ///
    /// Ex.
    ///
    /// ```sql
    /// SELECT * FROM t WHERE id = <value>
    /// ```
    ///
    /// This query is parameterized by `t.id`.
    ///
    /// Ex.
    ///
    /// ```sql
    /// SELECT t.* FROM t JOIN s ON t.id = s.id WHERE s.x = <value>
    /// ```
    ///
    /// This query is parameterized by `s.x`.
    cols: HashMap<TableId, HashSet<ColId>>,
    /// For each parameter we keep track of its possible values or arguments.
    /// These arguments are the different values that clients subscribe with.
    ///
    /// Ex.
    ///
    /// ```sql
    /// SELECT * FROM t WHERE id = 3
    /// SELECT * FROM t WHERE id = 5
    /// ```
    ///
    /// These queries will get parameterized by `t.id`,
    /// and we will record the args `3` and `5` in this map.
    ///
    /// Sorted so that all arguments for one `(table, column)` pair
    /// can be scanned as a contiguous range (see `remove_query`).
    args: BTreeMap<(TableId, ColId, AlgebraicValue), HashSet<QueryHash>>,
}
259
impl SearchArguments {
    /// Return the column ids by which a table is parameterized
    fn search_params_for_table(&self, table_id: TableId) -> impl Iterator<Item = &ColId> + '_ {
        self.cols.get(&table_id).into_iter().flatten()
    }

    /// Are there queries parameterized by this table and column?
    /// If so, do we have a subscriber for this `search_arg`?
    fn queries_for_search_arg(
        &self,
        table_id: TableId,
        col_id: ColId,
        search_arg: AlgebraicValue,
    ) -> impl Iterator<Item = &QueryHash> {
        self.args.get(&(table_id, col_id, search_arg)).into_iter().flatten()
    }

    /// Find the queries that need to be evaluated for this row.
    ///
    /// For each column by which `table_id` is parameterized,
    /// extract that column's value from `row` and look up the queries
    /// subscribed with exactly that value as their argument.
    fn queries_for_row<'a>(&'a self, table_id: TableId, row: &'a ProductValue) -> impl Iterator<Item = &'a QueryHash> {
        self.search_params_for_table(table_id)
            .filter_map(|col_id| row.get_field(col_id.idx(), None).ok().map(|arg| (col_id, arg.clone())))
            .flat_map(move |(col_id, arg)| self.queries_for_search_arg(table_id, *col_id, arg))
    }

    /// Remove a query hash and its associated data from this container.
    /// Note, a query hash may be associated with multiple column ids.
    fn remove_query(&mut self, query: &Query) {
        // Collect the column parameters for this query
        let mut params = query.search_args().collect::<Vec<_>>();

        // Remove the search argument entries for this query
        for key in &params {
            if let Some(hashes) = self.args.get_mut(key) {
                hashes.remove(&query.hash);
                // Drop the argument entry entirely once no query subscribes with it.
                if hashes.is_empty() {
                    self.args.remove(key);
                }
            }
        }

        // Retain columns that no longer map to any search arguments.
        // The range scan covers every argument value for a (table, column) pair,
        // so an empty range means no query is parameterized by that column anymore.
        params.retain(|(table_id, col_id, _)| {
            self.args
                .range((*table_id, *col_id, AlgebraicValue::Min)..=(*table_id, *col_id, AlgebraicValue::Max))
                .next()
                .is_none()
        });

        // Remove columns that no longer map to any search arguments
        for (table_id, col_id, _) in params {
            if let Some(col_ids) = self.cols.get_mut(&table_id) {
                col_ids.remove(&col_id);
                if col_ids.is_empty() {
                    self.cols.remove(&table_id);
                }
            }
        }
    }

    /// Add a new mapping from search argument to query hash
    fn insert_query(&mut self, table_id: TableId, col_id: ColId, arg: AlgebraicValue, query: QueryHash) {
        self.args.entry((table_id, col_id, arg)).or_default().insert(query);
        self.cols.entry(table_id).or_default().insert(col_id);
    }
}
325
/// Keeps track of the indexes that are used in subscriptions.
#[derive(Debug, Default)]
pub struct QueriedTableIndexIds {
    /// For each table, the subscribed indexes on it,
    /// each with a count of how many queries read from that index.
    ids: HashMap<TableId, HashMap<IndexId, usize>>,
}
331
332impl FromIterator<(TableId, IndexId)> for QueriedTableIndexIds {
333    fn from_iter<T: IntoIterator<Item = (TableId, IndexId)>>(iter: T) -> Self {
334        let mut index_ids = Self::default();
335        for (table_id, index_id) in iter {
336            index_ids.insert_index_id(table_id, index_id);
337        }
338        index_ids
339    }
340}
341
342impl QueriedTableIndexIds {
343    /// Returns the index ids that are used in subscriptions for this table.
344    /// Note, it does not return all of the index ids that are defined on this table.
345    /// Only those that are used by at least one subscription query.
346    pub fn index_ids_for_table(&self, table_id: TableId) -> impl Iterator<Item = IndexId> + '_ {
347        self.ids
348            .get(&table_id)
349            .into_iter()
350            .flat_map(|index_ids| index_ids.keys())
351            .copied()
352    }
353
354    /// Insert a new `table_id` `index_id` pair into this container.
355    /// Note, different queries may read from the same index.
356    /// Hence we may already be tracking this index, in which case we bump its ref count.
357    pub fn insert_index_id(&mut self, table_id: TableId, index_id: IndexId) {
358        *self.ids.entry(table_id).or_default().entry(index_id).or_default() += 1;
359    }
360
361    /// Remove a `table_id` `index_id` pair from this container.
362    /// Note, different queries may read from the same index.
363    /// Hence we only remove this key from the map if its ref count goes to zero.
364    pub fn delete_index_id(&mut self, table_id: TableId, index_id: IndexId) {
365        if let Some(ids) = self.ids.get_mut(&table_id) {
366            if let Some(n) = ids.get_mut(&index_id) {
367                *n -= 1;
368
369                if *n == 0 {
370                    ids.remove(&index_id);
371
372                    if ids.is_empty() {
373                        self.ids.remove(&table_id);
374                    }
375                }
376            }
377        }
378    }
379
380    /// Insert the index ids from which a query reads into this mapping.
381    /// Note, an index may already be tracked if another query is already using it.
382    /// In this case we just bump its ref count.
383    pub fn insert_index_ids_for_query(&mut self, query: &Query) {
384        for (table_id, index_id) in query.index_ids() {
385            self.insert_index_id(table_id, index_id);
386        }
387    }
388
389    /// Delete the index ids from which a query reads from this mapping
390    /// Note, we will not remove an index id from this mapping if another query is using it.
391    /// Instead we decrement its ref count.
392    pub fn delete_index_ids_for_query(&mut self, query: &Query) {
393        for (table_id, index_id) in query.index_ids() {
394            self.delete_index_id(table_id, index_id);
395        }
396    }
397}
398
/// A sorted set of join edges used for pruning queries.
/// See [`JoinEdge`] for more details.
#[derive(Debug, Default)]
pub struct JoinEdges {
    /// For each join edge, maps a right-hand-side value
    /// to the hashes of the queries subscribed with that value.
    /// Sorted so all edges for one table can be scanned as a range.
    edges: BTreeMap<JoinEdge, HashMap<AlgebraicValue, HashSet<QueryHash>>>,
}
405
406impl JoinEdges {
407    /// If this query has any join edges, add them to the map.
408    fn add_query(&mut self, qs: &QueryState) -> bool {
409        let mut inserted = false;
410        for (edge, rhs_val) in qs.query.join_edges() {
411            inserted = true;
412            self.edges
413                .entry(edge)
414                .or_default()
415                .entry(rhs_val)
416                .or_default()
417                .insert(qs.query.hash);
418        }
419        inserted
420    }
421
422    /// If this query has any join edges, remove them from the map.
423    fn remove_query(&mut self, query: &Query) {
424        for (edge, rhs_val) in query.join_edges() {
425            if let Some(values) = self.edges.get_mut(&edge) {
426                if let Some(hashes) = values.get_mut(&rhs_val) {
427                    hashes.remove(&query.hash);
428                    if hashes.is_empty() {
429                        values.remove(&rhs_val);
430                        if values.is_empty() {
431                            self.edges.remove(&edge);
432                        }
433                    }
434                }
435            }
436        }
437    }
438
439    /// Searches for queries that must be evaluated for this row,
440    /// and effectively prunes queries that do not.
441    fn queries_for_row<'a>(
442        &'a self,
443        table_id: TableId,
444        row: &'a ProductValue,
445        find_rhs_val: impl Fn(&JoinEdge, &ProductValue) -> Option<AlgebraicValue>,
446    ) -> impl Iterator<Item = &'a QueryHash> {
447        self.edges
448            .range(JoinEdge::range_for_table(table_id))
449            .filter_map(move |(edge, hashes)| find_rhs_val(edge, row).as_ref().and_then(|rhs_val| hashes.get(rhs_val)))
450            .flatten()
451    }
452}
453
/// Responsible for the efficient evaluation of subscriptions.
/// It performs basic multi-query optimization,
/// in that if a query has N subscribers,
/// it is only executed once,
/// with the results copied to the N receivers.
#[derive(Debug)]
pub struct SubscriptionManager {
    /// State for each client: its sender handle and the queries it subscribes to.
    clients: HashMap<ClientId, ClientInfo>,

    /// Queries for which there is at least one subscriber.
    queries: HashMap<QueryHash, QueryState>,

    /// If a query reads from a table,
    /// but does not have a simple equality filter on that table,
    /// we map the table to the query in this inverted index.
    tables: IntMap<TableId, HashSet<QueryHash>>,

    /// Tracks the indices used across all subscriptions
    /// to enable building the appropriate indexes for row updates.
    indexes: QueriedTableIndexIds,

    /// If a query reads from a table,
    /// and has a simple equality filter on that table,
    /// we map the filter values to the query in this lookup table.
    search_args: SearchArguments,

    /// A sorted set of join edges used for pruning queries.
    /// See [`JoinEdge`] for more details.
    join_edges: JoinEdges,

    /// Transmit side of a channel to the manager's [`SendWorker`] task.
    ///
    /// The send worker runs in parallel and pops [`ComputedQueries`]es out in order,
    /// aggregates each client's full set of updates,
    /// then passes them to the clients' websocket workers.
    /// This allows transaction processing to proceed on the main thread
    /// ahead of post-processing and broadcasting updates
    /// while still ensuring that those updates are sent in the correct serial order.
    /// Additionally, it avoids starving the next reducer request of Tokio workers,
    /// as it imposes a delay between unlocking the datastore
    /// and waking the many per-client sender Tokio tasks.
    send_worker_queue: BroadcastQueue,
}
498
/// A single update for one client and one query.
#[derive(Debug)]
struct ClientUpdate {
    /// The client that should receive this update.
    id: ClientId,
    /// The table from which the query returns rows.
    table_id: TableId,
    /// The name of that table.
    table_name: TableName,
    /// The update payload, in this client's wire format (BSATN or JSON).
    update: FormatSwitch<SingleQueryUpdate<BsatnFormat>, SingleQueryUpdate<JsonFormat>>,
}
507
/// The computed incremental update queries with sufficient information
/// to not depend on the transaction lock so that further work can be
/// done in a separate worker: [`SubscriptionManager::send_worker`].
/// The queries in this structure have not been aggregated yet
/// but will be in the worker.
#[derive(Debug)]
struct ComputedQueries {
    /// Per-client, per-query updates, not yet aggregated by client.
    updates: Vec<ClientUpdate>,
    /// Per-client query-evaluation errors.
    errs: Vec<(ClientId, Box<str>)>,
    /// The module event that produced these updates.
    event: Arc<ModuleEvent>,
    /// NOTE(review): presumably the connection that triggered `event`,
    /// when it is still connected — confirm against the send worker.
    caller: Option<Arc<ClientConnectionSender>>,
}
520
// Wraps a sender so that it will increment a gauge.
#[derive(Debug)]
struct SenderWithGauge<T> {
    /// The underlying unbounded channel sender.
    tx: mpsc::UnboundedSender<T>,
    /// Gauge tracking the queue length, if metrics are enabled.
    metric: Option<IntGauge>,
}
527impl<T> Clone for SenderWithGauge<T> {
528    fn clone(&self) -> Self {
529        SenderWithGauge {
530            tx: self.tx.clone(),
531            metric: self.metric.clone(),
532        }
533    }
534}
535
536impl<T> SenderWithGauge<T> {
537    fn new(tx: mpsc::UnboundedSender<T>, metric: Option<IntGauge>) -> Self {
538        Self { tx, metric }
539    }
540
541    /// Send a message to the worker and update the queue length metric.
542    pub fn send(&self, msg: T) -> Result<(), mpsc::error::SendError<T>> {
543        if let Some(metric) = &self.metric {
544            metric.inc();
545        }
546        // Note, this could number would be permanently off if the send call panics.
547        self.tx.send(msg)
548    }
549}
550
/// Message sent by the [`SubscriptionManager`] to the [`SendWorker`].
#[derive(Debug)]
enum SendWorkerMessage {
    /// A transaction has completed and the [`SubscriptionManager`] has evaluated the incremental queries,
    /// so the [`SendWorker`] should broadcast them to clients.
    Broadcast(ComputedQueries),

    /// A new client has been registered in the [`SubscriptionManager`],
    /// so the [`SendWorker`] should also record its existence.
    AddClient {
        client_id: ClientId,
        /// Shared handle on the `dropped` flag in the [`SubscriptionManager`]'s [`ClientInfo`].
        ///
        /// Will be updated by [`SendWorker::run`] and read by [`SubscriptionManager::remove_dropped_clients`].
        dropped: Arc<AtomicBool>,
        outbound_ref: Client,
    },

    /// Send a message directly to a single client.
    SendMessage {
        recipient: Arc<ClientConnectionSender>,
        message: SerializableMessage,
    },

    /// A client previously added by a [`Self::AddClient`] message has been removed from the [`SubscriptionManager`],
    /// so the [`SendWorker`] should also forget it.
    RemoveClient(ClientId),
}
579
/// Tracks some gauges related to subscriptions.
pub struct SubscriptionGaugeStats {
    /// The number of unique queries with at least one subscriber.
    pub num_queries: usize,
    /// The number of unique connections with at least one subscription.
    pub num_connections: usize,
    /// The number of subscription sets across all clients.
    pub num_subscription_sets: usize,
    /// The total number of single-query subscriptions across all clients and queries.
    pub num_query_subscriptions: usize,
    /// The number of clients with at least one legacy subscription.
    pub num_legacy_subscriptions: usize,
}
593
594impl SubscriptionManager {
595    pub fn for_test_without_metrics_arc_rwlock() -> Arc<RwLock<Self>> {
596        Arc::new(RwLock::new(Self::for_test_without_metrics()))
597    }
598
599    pub fn for_test_without_metrics() -> Self {
600        Self::new(SendWorker::spawn_new(None))
601    }
602
603    pub fn new(send_worker_queue: BroadcastQueue) -> Self {
604        Self {
605            clients: Default::default(),
606            queries: Default::default(),
607            indexes: Default::default(),
608            tables: Default::default(),
609            search_args: Default::default(),
610            join_edges: Default::default(),
611            send_worker_queue,
612        }
613    }
614
615    pub fn query(&self, hash: &QueryHash) -> Option<Query> {
616        self.queries.get(hash).map(|state| state.query.clone())
617    }
618
619    pub fn calculate_gauge_stats(&self) -> SubscriptionGaugeStats {
620        let num_queries = self.queries.len();
621        let num_connections = self.clients.len();
622        let num_query_subscriptions = self.queries.values().map(|state| state.subscriptions.len()).sum();
623        let num_subscription_sets = self.clients.values().map(|ci| ci.subscriptions.len()).sum();
624        let num_legacy_subscriptions = self
625            .clients
626            .values()
627            .filter(|ci| !ci.legacy_subscriptions.is_empty())
628            .count();
629
630        SubscriptionGaugeStats {
631            num_queries,
632            num_connections,
633            num_query_subscriptions,
634            num_subscription_sets,
635            num_legacy_subscriptions,
636        }
637    }
638
639    /// Add a new [`ClientInfo`] to the `clients` map, and broadcast a message along `send_worker_tx`
640    /// that the [`SendWorker`] should also add this client.
641    ///
642    /// Horrible signature to enable split borrows on [`Self`].
643    fn get_or_make_client_info_and_inform_send_worker<'clients>(
644        clients: &'clients mut HashMap<ClientId, ClientInfo>,
645        send_worker_tx: &BroadcastQueue,
646        client_id: ClientId,
647        outbound_ref: Client,
648    ) -> &'clients mut ClientInfo {
649        clients.entry(client_id).or_insert_with(|| {
650            let info = ClientInfo::new(outbound_ref.clone());
651            send_worker_tx
652                .send(SendWorkerMessage::AddClient {
653                    client_id,
654                    dropped: info.dropped.clone(),
655                    outbound_ref,
656                })
657                .expect("send worker has panicked, or otherwise dropped its recv queue!");
658            info
659        })
660    }
661
662    /// Remove a [`ClientInfo`] from the `clients` map,
663    /// and broadcast a message along `send_worker_tx` that the [`SendWorker`] should also remove it.
664    fn remove_client_and_inform_send_worker(&mut self, client_id: ClientId) -> Option<ClientInfo> {
665        self.clients.remove(&client_id).inspect(|_| {
666            self.send_worker_queue
667                .send(SendWorkerMessage::RemoveClient(client_id))
668                .expect("send worker has panicked, or otherwise dropped its recv queue!");
669        })
670    }
671
    /// The number of unique queries with at least one subscriber.
    pub fn num_unique_queries(&self) -> usize {
        self.queries.len()
    }
675
    /// Test-only: is this query currently tracked (has at least one subscriber)?
    #[cfg(test)]
    fn contains_query(&self, hash: &QueryHash) -> bool {
        self.queries.contains_key(hash)
    }

    /// Test-only: do we hold any state for this client?
    #[cfg(test)]
    fn contains_client(&self, subscriber: &ClientId) -> bool {
        self.clients.contains_key(subscriber)
    }

    /// Test-only: is `subscriber` a legacy subscriber of `query`?
    #[cfg(test)]
    fn contains_legacy_subscription(&self, subscriber: &ClientId, query: &QueryHash) -> bool {
        self.queries
            .get(query)
            .is_some_and(|state| state.legacy_subscribers.contains(subscriber))
    }

    /// Test-only: is `query` recorded in the `tables` inverted index under `table`?
    #[cfg(test)]
    fn query_reads_from_table(&self, query: &QueryHash, table: &TableId) -> bool {
        self.tables.get(table).is_some_and(|queries| queries.contains(query))
    }

    /// Test-only: is `query` registered under this (table, column, value) search argument?
    #[cfg(test)]
    fn query_has_search_arg(&self, query: QueryHash, table_id: TableId, col_id: ColId, arg: AlgebraicValue) -> bool {
        self.search_args
            .queries_for_search_arg(table_id, col_id, arg)
            .any(|hash| *hash == query)
    }

    /// Test-only: is `table_id` parameterized by `col_id` for some subscription?
    #[cfg(test)]
    fn table_has_search_param(&self, table_id: TableId, col_id: ColId) -> bool {
        self.search_args
            .search_params_for_table(table_id)
            .any(|id| *id == col_id)
    }
711
    /// Remove all of a client's legacy subscriptions,
    /// dropping any queries that are left with no subscribers at all.
    fn remove_legacy_subscriptions(&mut self, client: &ClientId) {
        if let Some(ci) = self.clients.get_mut(client) {
            // Deferred removal: we cannot mutate `self.queries`'s key set while
            // iterating over `ci.legacy_subscriptions`, so collect the hashes first.
            let mut queries_to_remove = Vec::new();
            for query_hash in ci.legacy_subscriptions.iter() {
                let Some(query_state) = self.queries.get_mut(query_hash) else {
                    tracing::warn!("Query state not found for query hash: {:?}", query_hash);
                    continue;
                };

                query_state.legacy_subscribers.remove(client);
                // If no other client (legacy or single-query) still subscribes,
                // purge the query from all of the lookup structures.
                if !query_state.has_subscribers() {
                    SubscriptionManager::remove_query_from_tables(
                        &mut self.tables,
                        &mut self.join_edges,
                        &mut self.indexes,
                        &mut self.search_args,
                        &query_state.query,
                    );
                    queries_to_remove.push(*query_hash);
                }
            }
            ci.legacy_subscriptions.clear();
            for query_hash in queries_to_remove {
                self.queries.remove(&query_hash);
            }
        }
    }
739
740    /// Remove any clients that have been marked for removal
741    pub fn remove_dropped_clients(&mut self) {
742        for id in self.clients.keys().copied().collect::<Vec<_>>() {
743            if let Some(client) = self.clients.get(&id) {
744                if client.dropped.load(Ordering::Relaxed) {
745                    self.remove_all_subscriptions(&id);
746                }
747            }
748        }
749    }
750
    /// Remove a single subscription for a client.
    /// This will return an error if the client does not have a subscription with the given query id.
    pub fn remove_subscription(&mut self, client_id: ClientId, query_id: ClientQueryId) -> Result<Vec<Query>, DBError> {
        let subscription_id = (client_id, query_id);
        // A client already marked `dropped` is treated as not found.
        let Some(ci) = self
            .clients
            .get_mut(&client_id)
            .filter(|ci| !ci.dropped.load(Ordering::Acquire))
        else {
            return Err(anyhow::anyhow!("Client not found: {:?}", client_id).into());
        };

        #[cfg(test)]
        ci.assert_ref_count_consistency();

        let Some(query_hashes) = ci.subscriptions.remove(&subscription_id) else {
            return Err(anyhow::anyhow!("Subscription not found: {:?}", subscription_id).into());
        };
        // Queries this client no longer references via any subscription.
        let mut queries_to_return = Vec::new();
        for hash in query_hashes {
            // Decrement this client's ref count for the query.
            let remaining_refs = {
                let Some(count) = ci.subscription_ref_count.get_mut(&hash) else {
                    return Err(anyhow::anyhow!("Query count not found for query hash: {:?}", hash).into());
                };
                *count -= 1;
                *count
            };
            if remaining_refs > 0 {
                // The client is still subscribed to this query, so we are done for now.
                continue;
            }
            // The client is no longer subscribed to this query.
            ci.subscription_ref_count.remove(&hash);
            let Some(query_state) = self.queries.get_mut(&hash) else {
                return Err(anyhow::anyhow!("Query state not found for query hash: {:?}", hash).into());
            };
            queries_to_return.push(query_state.query.clone());
            query_state.subscriptions.remove(&client_id);
            // If no subscriber remains at all, purge the query from every lookup structure.
            if !query_state.has_subscribers() {
                SubscriptionManager::remove_query_from_tables(
                    &mut self.tables,
                    &mut self.join_edges,
                    &mut self.indexes,
                    &mut self.search_args,
                    &query_state.query,
                );
                self.queries.remove(&hash);
            }
        }

        #[cfg(test)]
        ci.assert_ref_count_consistency();

        Ok(queries_to_return)
    }
806
807    /// Adds a single subscription for a client.
808    pub fn add_subscription(&mut self, client: Client, query: Query, query_id: ClientQueryId) -> Result<(), DBError> {
809        self.add_subscription_multi(client, vec![query], query_id).map(|_| ())
810    }
811
    /// Adds a subscription with id `query_id` covering all of `queries` for a client.
    ///
    /// Returns the subset of `queries` that are new to this client,
    /// i.e. those it was not already subscribed to via another subscription.
    ///
    /// Errors if the client already has an active subscription with this `query_id`,
    /// or on an internal ref-count/query-state inconsistency.
    pub fn add_subscription_multi(
        &mut self,
        client: Client,
        queries: Vec<Query>,
        query_id: ClientQueryId,
    ) -> Result<Vec<Query>, DBError> {
        let client_id = (client.id.identity, client.id.connection_id);

        // Clean up any dropped subscriptions
        // before re-registering state for the same client id.
        if self
            .clients
            .get(&client_id)
            .is_some_and(|ci| ci.dropped.load(Ordering::Acquire))
        {
            self.remove_all_subscriptions(&client_id);
        }

        // Look up (or create) the bookkeeping entry for this client.
        let ci = Self::get_or_make_client_info_and_inform_send_worker(
            &mut self.clients,
            &self.send_worker_queue,
            client_id,
            client,
        );

        #[cfg(test)]
        ci.assert_ref_count_consistency();
        let subscription_id = (client_id, query_id);
        // Reserve the subscription id; a client must not reuse an active query id.
        let hash_set = match ci.subscriptions.try_insert(subscription_id, HashSet::new()) {
            Err(OccupiedError { .. }) => {
                return Err(anyhow::anyhow!(
                    "Subscription with id {:?} already exists for client: {:?}",
                    query_id,
                    client_id
                )
                .into());
            }
            Ok(hash_set) => hash_set,
        };
        // We track the queries that are being added for this client.
        let mut new_queries = Vec::new();

        for query in &queries {
            let hash = query.hash();
            // Deduping queries within this single call.
            if !hash_set.insert(hash) {
                continue;
            }
            // Get or create the shared state for this query.
            let query_state = self
                .queries
                .entry(hash)
                .or_insert_with(|| QueryState::new(query.clone()));

            // Index the query's tables, join edges, and search arguments
            // (a no-op if the query already has subscribers).
            Self::insert_query(
                &mut self.tables,
                &mut self.join_edges,
                &mut self.indexes,
                &mut self.search_args,
                query_state,
            );

            // Bump this client's ref count for the query:
            // the same query may be reachable through several subscriptions.
            let entry = ci.subscription_ref_count.entry(hash).or_insert(0);
            *entry += 1;
            let is_new_entry = *entry == 1;

            let inserted = query_state.subscriptions.insert(client_id);
            // This should arguably crash the server, as it indicates a bug.
            if inserted != is_new_entry {
                return Err(anyhow::anyhow!("Internal error, ref count and query_state mismatch").into());
            }
            if inserted {
                new_queries.push(query.clone());
            }
        }

        #[cfg(test)]
        {
            ci.assert_ref_count_consistency();
        }

        Ok(new_queries)
    }
893
894    /// Adds a client and its queries to the subscription manager.
895    /// Sets up the set of subscriptions for the client, replacing any existing legacy subscriptions.
896    ///
897    /// If a query is not already indexed,
898    /// its table ids added to the inverted index.
899    // #[tracing::instrument(level = "trace", skip_all)]
900    pub fn set_legacy_subscription(&mut self, client: Client, queries: impl IntoIterator<Item = Query>) {
901        let client_id = (client.id.identity, client.id.connection_id);
902        // First, remove any existing legacy subscriptions.
903        self.remove_legacy_subscriptions(&client_id);
904
905        // Now, add the new subscriptions.
906        let ci = Self::get_or_make_client_info_and_inform_send_worker(
907            &mut self.clients,
908            &self.send_worker_queue,
909            client_id,
910            client,
911        );
912
913        for unit in queries {
914            let hash = unit.hash();
915            ci.legacy_subscriptions.insert(hash);
916            let query_state = self
917                .queries
918                .entry(hash)
919                .or_insert_with(|| QueryState::new(unit.clone()));
920            Self::insert_query(
921                &mut self.tables,
922                &mut self.join_edges,
923                &mut self.indexes,
924                &mut self.search_args,
925                query_state,
926            );
927            query_state.legacy_subscribers.insert(client_id);
928        }
929    }
930
931    // Update the mapping from table id to related queries by removing the given query.
932    // If this removes all queries for a table, the map entry for that table is removed altogether.
933    // This takes a ref to the table map instead of `self` to avoid borrowing issues.
934    fn remove_query_from_tables(
935        tables: &mut IntMap<TableId, HashSet<QueryHash>>,
936        join_edges: &mut JoinEdges,
937        index_ids: &mut QueriedTableIndexIds,
938        search_args: &mut SearchArguments,
939        query: &Query,
940    ) {
941        let hash = query.hash();
942        join_edges.remove_query(query);
943        search_args.remove_query(query);
944        index_ids.delete_index_ids_for_query(query);
945        for table_id in query.table_ids() {
946            if let Entry::Occupied(mut entry) = tables.entry(table_id) {
947                let hashes = entry.get_mut();
948                if hashes.remove(&hash) && hashes.is_empty() {
949                    entry.remove();
950                }
951            }
952        }
953    }
954
    // Update the mapping from table id to related queries by inserting the given query.
    // Also add any search arguments the query may have.
    // This takes a ref to the table map instead of `self` to avoid borrowing issues.
    //
    // Each table the query reads is registered in exactly one structure,
    // preferring (1) search arguments, then (2) a join edge for the return table,
    // then (3) the plain `tables` map — so a table update reaches the query
    // through a single lookup path.
    fn insert_query(
        tables: &mut IntMap<TableId, HashSet<QueryHash>>,
        join_edges: &mut JoinEdges,
        index_ids: &mut QueriedTableIndexIds,
        search_args: &mut SearchArguments,
        query_state: &QueryState,
    ) {
        // If this is new, we need to update the table to query mapping.
        // (If the query already has subscribers, it is already fully indexed.)
        if !query_state.has_subscribers() {
            let hash = query_state.query.hash;
            let query = query_state.query();
            let return_table = query.subscribed_table_id();
            // Tables not yet claimed by a more specific structure below.
            let mut table_ids = query.table_ids().collect::<HashSet<_>>();

            // Update the index id mapping
            index_ids.insert_index_ids_for_query(query);

            // Update the search arguments
            for (table_id, col_id, arg) in query_state.search_args() {
                table_ids.remove(&table_id);
                search_args.insert_query(table_id, col_id, arg, hash);
            }

            // Update the join edges if the return table didn't have any search arguments
            if table_ids.contains(&return_table) && join_edges.add_query(query_state) {
                table_ids.remove(&return_table);
            }

            // Finally update the `tables` map if the query didn't have a search argument or a join edge for a table
            for table_id in table_ids {
                tables.entry(table_id).or_default().insert(hash);
            }
        }
    }
992
993    /// Removes a client from the subscriber mapping.
994    /// If a query no longer has any subscribers,
995    /// it is removed from the index along with its table ids.
996    #[tracing::instrument(level = "trace", skip_all)]
997    pub fn remove_all_subscriptions(&mut self, client: &ClientId) {
998        self.remove_legacy_subscriptions(client);
999        let Some(client_info) = self.remove_client_and_inform_send_worker(*client) else {
1000            return;
1001        };
1002
1003        debug_assert!(client_info.legacy_subscriptions.is_empty());
1004        let mut queries_to_remove = Vec::new();
1005        for query_hash in client_info.subscription_ref_count.keys() {
1006            let Some(query_state) = self.queries.get_mut(query_hash) else {
1007                tracing::warn!("Query state not found for query hash: {:?}", query_hash);
1008                return;
1009            };
1010            query_state.subscriptions.remove(client);
1011            // This could happen twice for the same hash if a client has a duplicate, but that's fine. It is idepotent.
1012            if !query_state.has_subscribers() {
1013                queries_to_remove.push(*query_hash);
1014                SubscriptionManager::remove_query_from_tables(
1015                    &mut self.tables,
1016                    &mut self.join_edges,
1017                    &mut self.indexes,
1018                    &mut self.search_args,
1019                    &query_state.query,
1020                );
1021            }
1022        }
1023        for query_hash in queries_to_remove {
1024            self.queries.remove(&query_hash);
1025        }
1026    }
1027
1028    /// Find the queries that need to be evaluated for this table update.
1029    ///
1030    /// Note, this tries to prune irrelevant queries from the subscription.
1031    ///
1032    /// When is this beneficial?
1033    ///
1034    /// If many different clients subscribe to the same parameterized query,
1035    /// but they all subscribe with different parameter values,
1036    /// and if these rows contain only a few unique values for this parameter,
1037    /// most clients will not receive an update,
1038    /// and so we can avoid evaluating queries for them entirely.
1039    ///
1040    /// Ex.
1041    ///
1042    /// 1000 clients subscribe to `SELECT * FROM t WHERE id = ?`,
1043    /// each one with a different value for `?`.
1044    /// If there are transactions that only ever update one row of `t` at a time,
1045    /// we only pay the cost of evaluating one query.
1046    ///
1047    /// When is this not beneficial?
1048    ///
1049    /// If the table update contains a lot of unique values for a parameter,
1050    /// we won't be able to prune very many queries from the subscription,
1051    /// so this could add some overhead linear in the size of the table update.
1052    ///
1053    /// TODO: This logic should be expressed in the execution plan itself,
1054    /// so that we don't have to preprocess the table update before execution.
1055    fn queries_for_table_update<'a>(
1056        &'a self,
1057        table_update: &'a DatabaseTableUpdate,
1058        find_rhs_val: &impl Fn(&JoinEdge, &ProductValue) -> Option<AlgebraicValue>,
1059    ) -> impl Iterator<Item = &'a QueryHash> {
1060        let mut queries = HashSet::new();
1061        for hash in table_update
1062            .inserts
1063            .iter()
1064            .chain(table_update.deletes.iter())
1065            .flat_map(|row| self.queries_for_row(table_update.table_id, row, find_rhs_val))
1066        {
1067            queries.insert(hash);
1068        }
1069        for hash in self.tables.get(&table_update.table_id).into_iter().flatten() {
1070            queries.insert(hash);
1071        }
1072        queries.into_iter()
1073    }
1074
1075    /// Find the queries that need to be evaluated for this row.
1076    fn queries_for_row<'a>(
1077        &'a self,
1078        table_id: TableId,
1079        row: &'a ProductValue,
1080        find_rhs_val: impl Fn(&JoinEdge, &ProductValue) -> Option<AlgebraicValue>,
1081    ) -> impl Iterator<Item = &'a QueryHash> {
1082        self.search_args
1083            .queries_for_row(table_id, row)
1084            .chain(self.join_edges.queries_for_row(table_id, row, find_rhs_val))
1085    }
1086
    /// Returns the index ids that are used in subscription queries.
    ///
    /// This is the set maintained by `insert_query` / `remove_query_from_tables`
    /// as subscriptions come and go.
    pub fn index_ids_for_subscriptions(&self) -> &QueriedTableIndexIds {
        &self.indexes
    }
1091
    /// This method takes a set of delta tables,
    /// evaluates only the necessary queries for those delta tables,
    /// and then sends the results to each client.
    ///
    /// Results are handed off to the send worker for delivery;
    /// the returned [`ExecutionMetrics`] cover the evaluation done here.
    ///
    /// This previously used rayon to parallelize subscription evaluation.
    /// However, in order to optimize for the common case of small updates,
    /// we removed rayon and switched to a single-threaded execution,
    /// which removed significant overhead associated with thread switching.
    #[tracing::instrument(level = "trace", skip_all)]
    pub fn eval_updates_sequential(
        &self,
        tx: &DeltaTx,
        event: Arc<ModuleEvent>,
        caller: Option<Arc<ClientConnectionSender>>,
    ) -> ExecutionMetrics {
        use FormatSwitch::{Bsatn, Json};

        // NOTE(review): panics if the event's status carries no database update;
        // presumably callers only pass committed events here — confirm.
        let tables = &event.status.database_update().unwrap().tables;

        let span = tracing::info_span!("eval_incr").entered();

        /// Accumulator for the fold over all affected (query, plan fragment) pairs.
        #[derive(Default)]
        struct FoldState {
            /// Per-client, per-table row updates destined for the send worker.
            updates: Vec<ClientUpdate>,
            /// Clients whose query evaluation failed, with the error text to report.
            errs: Vec<(ClientId, Box<str>)>,
            /// Metrics accumulated across all query evaluations.
            metrics: ExecutionMetrics,
        }

        /// Returns the value pointed to by this join edge
        fn find_rhs_val(edge: &JoinEdge, row: &ProductValue, tx: &DeltaTx) -> Option<AlgebraicValue> {
            // What if the joining row was deleted in this tx?
            // Will we prune a query that we shouldn't have?
            //
            // Ultimately no we will not.
            // We may prune it for this row specifically,
            // but we will eventually include it for the joining row.
            tx.iter_by_col_eq(
                edge.rhs_table,
                edge.rhs_join_col,
                &row.elements[edge.lhs_join_col.idx()],
            )
            .expect("This read should always succeed, and it's a bug if it doesn't")
            .next()
            .map(|row| {
                row.read_col(edge.rhs_col)
                    .expect("This read should always succeed, and it's a bug if it doesn't")
            })
        }

        let FoldState { updates, errs, metrics } = tables
            .iter()
            // Skip tables with no actual changes.
            .filter(|table| !table.inserts.is_empty() || !table.deletes.is_empty())
            // Find the queries that may be affected by each table update.
            .flat_map(|table_update| {
                self.queries_for_table_update(table_update, &|edge, row| find_rhs_val(edge, row, tx))
            })
            // deduplicate queries by their hash
            .filter({
                let mut seen = HashSet::new();
                // (HashSet::insert returns true for novel elements)
                move |&hash| seen.insert(hash)
            })
            // Expand each query into its constituent plan fragments.
            .flat_map(|hash| {
                let qstate = &self.queries[hash];
                qstate
                    .query
                    .plans_fragments()
                    .map(move |plan_fragment| (qstate, plan_fragment))
            })
            // If N clients are subscribed to a query,
            // we copy the DatabaseTableUpdate N times,
            // which involves cloning BSATN (binary) or product values (json).
            .fold(FoldState::default(), |mut acc, (qstate, plan)| {
                let table_id = plan.subscribed_table_id();
                let table_name = plan.subscribed_table_name().clone();
                // Store at most one copy for both the serialization to BSATN and JSON.
                // Each subscriber gets to pick which of these they want,
                // but we only fill `ops_bin_uncompressed` and `ops_json` at most once.
                // The former will be `Some(_)` if some subscriber uses `Protocol::Binary`
                // and the latter `Some(_)` if some subscriber uses `Protocol::Text`.
                //
                // Previously we were compressing each `QueryUpdate` within a `TransactionUpdate`.
                // The reason was simple - many clients can subscribe to the same query.
                // If we compress `TransactionUpdate`s independently for each client,
                // we could be doing a lot of redundant compression.
                //
                // However the risks associated with this approach include:
                //   1. We have to hold the tx lock when compressing
                //   2. A potentially worse compression ratio
                //   3. Extra decompression overhead on the client
                //
                // Because transaction processing is currently single-threaded,
                // the risks of holding the tx lock for longer than necessary,
                // as well as additional the message processing overhead on the client,
                // outweighed the benefit of reduced cpu with the former approach.
                let mut ops_bin_uncompressed: Option<(CompressableQueryUpdate<BsatnFormat>, _, _)> = None;
                let mut ops_json: Option<(QueryUpdate<JsonFormat>, _, _)> = None;

                /// Serialize `updates` in format `F` exactly once,
                /// caching the encoding in `memory` for subsequent subscribers.
                fn memo_encode<F: WebsocketFormat>(
                    updates: &UpdatesRelValue<'_>,
                    memory: &mut Option<(F::QueryUpdate, u64, usize)>,
                    metrics: &mut ExecutionMetrics,
                ) -> SingleQueryUpdate<F> {
                    let (update, num_rows, num_bytes) = memory
                        .get_or_insert_with(|| {
                            let encoded = updates.encode::<F>();
                            // The first time we insert into this map, we call encode.
                            // This is when we serialize the rows to BSATN/JSON.
                            // Hence this is where we increment `bytes_scanned`.
                            metrics.bytes_scanned += encoded.2;
                            encoded
                        })
                        .clone();
                    // We call this function for each query,
                    // and for each client subscribed to it.
                    // Therefore every time we call this function,
                    // we update the `bytes_sent_to_clients` metric.
                    metrics.bytes_sent_to_clients += num_bytes;
                    SingleQueryUpdate { update, num_rows }
                }

                let clients_for_query = qstate.all_clients();

                match eval_delta(tx, &mut acc.metrics, plan) {
                    // Evaluation failed: record the error for every subscribed client.
                    Err(err) => {
                        tracing::error!(
                            message = "Query errored during tx update",
                            sql = qstate.query.sql,
                            reason = ?err,
                        );
                        let err = DBError::WithSql {
                            sql: qstate.query.sql.as_str().into(),
                            error: Box::new(err.into()),
                        }
                        .to_string()
                        .into_boxed_str();

                        acc.errs.extend(clients_for_query.map(|id| (*id, err.clone())))
                    }
                    // The query didn't return any rows to update
                    Ok(None) => {}
                    // The query did return updates - process them and add them to the accumulator
                    Ok(Some(delta_updates)) => {
                        let row_iter = clients_for_query.map(|id| {
                            let client = &self.clients[id].outbound_ref;
                            // Encode in the client's wire format, reusing the memoized encoding.
                            let update = match client.config.protocol {
                                Protocol::Binary => Bsatn(memo_encode::<BsatnFormat>(
                                    &delta_updates,
                                    &mut ops_bin_uncompressed,
                                    &mut acc.metrics,
                                )),
                                Protocol::Text => Json(memo_encode::<JsonFormat>(
                                    &delta_updates,
                                    &mut ops_json,
                                    &mut acc.metrics,
                                )),
                            };
                            ClientUpdate {
                                id: *id,
                                table_id,
                                table_name: table_name.clone(),
                                update,
                            }
                        });
                        acc.updates.extend(row_iter);
                    }
                }

                acc
            });

        // We've now finished all of the work which needs to read from the datastore,
        // so get this work off the main thread and over to the `send_worker`,
        // then return ASAP in order to unlock the datastore and start running the next transaction.
        // See comment on the `send_worker_tx` field in [`SubscriptionManager`] for more motivation.
        self.send_worker_queue
            .send(SendWorkerMessage::Broadcast(ComputedQueries {
                updates,
                errs,
                event,
                caller,
            }))
            .expect("send worker has panicked, or otherwise dropped its recv queue!");

        drop(span);

        metrics
    }
1279}
1280
/// Per-client state tracked by the [`SendWorker`],
/// mirroring an entry in the [`SubscriptionManager`]'s client map.
struct SendWorkerClient {
    /// This flag is set if an error occurs during a tx update.
    /// It will be cleaned up async or on resubscribe.
    ///
    /// [`Arc`]ed so that this can be updated by [`Self::run`]
    /// and observed by [`SubscriptionManager::remove_dropped_clients`].
    dropped: Arc<AtomicBool>,
    /// Sender handle used to push outgoing messages to this client's WebSocket worker.
    outbound_ref: Client,
}
1290
impl SendWorkerClient {
    /// True if this client was flagged as dropped, e.g. after an error during a tx update.
    fn is_dropped(&self) -> bool {
        self.dropped.load(Ordering::Relaxed)
    }

    /// True if the client's outbound connection has been cancelled.
    fn is_cancelled(&self) -> bool {
        self.outbound_ref.is_cancelled()
    }
}
1300
/// Asynchronous background worker which aggregates each of the clients' updates from a [`ComputedQueries`]
/// into `DbUpdate`s and then sends them to the clients' WebSocket workers.
///
/// See comment on the `send_worker_tx` field in [`SubscriptionManager`] for motivation.
struct SendWorker {
    /// Receiver end of the [`SubscriptionManager`]'s `send_worker_tx` channel.
    ///
    /// The worker loop in [`Self::run`] exits once every sender is dropped.
    rx: mpsc::UnboundedReceiver<SendWorkerMessage>,

    /// `subscription_send_queue_length` metric labeled for this database's `Identity`.
    ///
    /// If `Some`, this metric will be decremented each time we pop a [`ComputedQueries`] from `rx`.
    ///
    /// Will be `None` in contexts where there is no database `Identity` to use as label,
    /// i.e. in tests.
    queue_length_metric: Option<IntGauge>,

    /// Mirror of the [`SubscriptionManager`]'s `clients` map local to this actor.
    ///
    /// Updated by [`SendWorkerMessage::AddClient`] and [`SendWorkerMessage::RemoveClient`] messages
    /// sent along `self.rx`.
    clients: HashMap<ClientId, SendWorkerClient>,

    /// The `Identity` which labels the `queue_length_metric`.
    ///
    /// If `Some`, this type's `drop` method will do `remove_label_values` to clean up the metric on exit.
    database_identity_to_clean_up_metric: Option<Identity>,

    /// A map (re)used by [`SendWorker::send_one_computed_queries`]
    /// to avoid creating new allocations.
    ///
    /// Keyed by `(client, table)` during per-table aggregation.
    table_updates_client_id_table_id: HashMap<(ClientId, TableId), SwitchedTableUpdate>,

    /// A map (re)used by [`SendWorker::send_one_computed_queries`]
    /// to avoid creating new allocations.
    ///
    /// Keyed by client while stitching table updates into one update list per client.
    table_updates_client_id: HashMap<ClientId, SwitchedDbUpdate>,
}
1336
1337impl Drop for SendWorker {
1338    fn drop(&mut self) {
1339        if let Some(identity) = self.database_identity_to_clean_up_metric {
1340            let _ = WORKER_METRICS
1341                .subscription_send_queue_length
1342                .remove_label_values(&identity);
1343        }
1344    }
1345}
1346
1347impl SendWorker {
1348    fn is_client_dropped_or_cancelled(&self, client_id: &ClientId) -> bool {
1349        self.clients
1350            .get(client_id)
1351            .is_some_and(|client| client.is_cancelled() || client.is_dropped())
1352    }
1353}
1354
/// Cloneable handle for enqueueing [`SendWorkerMessage`]s onto a [`SendWorker`]'s channel,
/// updating the queue-length gauge as messages are sent.
#[derive(Debug, Clone)]
pub struct BroadcastQueue(SenderWithGauge<SendWorkerMessage>);
1357
/// Error returned when a send to the [`SendWorker`] fails,
/// i.e. when the worker has shut down and dropped its receiver.
#[derive(thiserror::Error, Debug)]
#[error(transparent)]
pub struct BroadcastError(#[from] mpsc::error::SendError<SendWorkerMessage>);
1361
1362impl BroadcastQueue {
1363    fn send(&self, message: SendWorkerMessage) -> Result<(), BroadcastError> {
1364        self.0.send(message)?;
1365        Ok(())
1366    }
1367
1368    pub fn send_client_message(
1369        &self,
1370        recipient: Arc<ClientConnectionSender>,
1371        message: impl Into<SerializableMessage>,
1372    ) -> Result<(), BroadcastError> {
1373        self.0.send(SendWorkerMessage::SendMessage {
1374            recipient,
1375            message: message.into(),
1376        })?;
1377        Ok(())
1378    }
1379}
/// Spawns a new [`SendWorker`] on the tokio runtime and returns a queue handle for it.
///
/// If `metric_database_identity` is `Some`, the worker maintains the
/// `subscription_send_queue_length` metric labeled with that identity
/// and cleans it up when the worker is dropped.
pub fn spawn_send_worker(metric_database_identity: Option<Identity>) -> BroadcastQueue {
    SendWorker::spawn_new(metric_database_identity)
}
1383impl SendWorker {
1384    fn new(
1385        rx: mpsc::UnboundedReceiver<SendWorkerMessage>,
1386        queue_length_metric: Option<IntGauge>,
1387        database_identity_to_clean_up_metric: Option<Identity>,
1388    ) -> Self {
1389        Self {
1390            rx,
1391            queue_length_metric,
1392            clients: Default::default(),
1393            database_identity_to_clean_up_metric,
1394            table_updates_client_id_table_id: <_>::default(),
1395            table_updates_client_id: <_>::default(),
1396        }
1397    }
1398
    /// Spawn a new send worker as a tokio task and return a handle to its queue.
    ///
    /// If a `metric_database_identity` is provided, the worker maintains the
    /// corresponding `subscription_send_queue_length` gauge: the same gauge is
    /// handed to the sender side (via [`SenderWithGauge`]) and to the worker,
    /// which decrements it in [`Self::run`] as messages are popped,
    /// and removes the labeled metric entirely when dropped.
    fn spawn_new(metric_database_identity: Option<Identity>) -> BroadcastQueue {
        let metric = metric_database_identity.map(|identity| {
            WORKER_METRICS
                .subscription_send_queue_length
                .with_label_values(&identity)
        });
        let (send_worker_tx, rx) = mpsc::unbounded_channel();
        // The worker owns the receiving end; it runs until all senders are dropped.
        tokio::spawn(Self::new(rx, metric.clone(), metric_database_identity).run());
        BroadcastQueue(SenderWithGauge::new(send_worker_tx, metric))
    }
1412
    /// Main actor loop: drains `self.rx` until every sender is dropped,
    /// keeping the queue-length gauge in sync and dispatching each message.
    async fn run(mut self) {
        while let Some(message) = self.rx.recv().await {
            // The gauge counts enqueued-but-unprocessed messages; balance it on pop.
            if let Some(metric) = &self.queue_length_metric {
                metric.dec();
            }

            match message {
                // Mirror a newly added client into our local map.
                SendWorkerMessage::AddClient {
                    client_id,
                    dropped,
                    outbound_ref,
                } => {
                    self.clients
                        .insert(client_id, SendWorkerClient { dropped, outbound_ref });
                }
                // Forward a single message to one client; send failures are ignored here.
                SendWorkerMessage::SendMessage { recipient, message } => {
                    let _ = recipient.send_message(message);
                }
                // Drop our mirror of a removed client.
                SendWorkerMessage::RemoveClient(client_id) => {
                    self.clients.remove(&client_id);
                }
                // Aggregate a transaction's computed query results and fan them out.
                SendWorkerMessage::Broadcast(queries) => {
                    self.send_one_computed_queries(queries);
                }
            }
        }
    }
1440
1441    fn send_one_computed_queries(
1442        &mut self,
1443        ComputedQueries {
1444            updates,
1445            errs,
1446            event,
1447            caller,
1448        }: ComputedQueries,
1449    ) {
1450        use FormatSwitch::{Bsatn, Json};
1451
1452        let clients_with_errors = errs.iter().map(|(id, _)| id).collect::<HashSet<_>>();
1453
1454        let span = tracing::info_span!("eval_incr_group_messages_by_client");
1455
1456        // Reuse the aggregation maps from the worker.
1457        let client_table_id_updates = mem::take(&mut self.table_updates_client_id_table_id);
1458        let client_id_updates = mem::take(&mut self.table_updates_client_id);
1459
1460        // For each subscriber, aggregate all the updates for the same table.
1461        // That is, we build a map `(subscriber_id, table_id) -> updates`.
1462        // A particular subscriber uses only one format,
1463        // so their `TableUpdate` will contain either JSON (`Protocol::Text`)
1464        // or BSATN (`Protocol::Binary`).
1465        let mut client_table_id_updates = updates
1466            .into_iter()
1467            // Filter out dropped or cancelled clients
1468            .filter(|upd| !self.is_client_dropped_or_cancelled(&upd.id))
1469            // Filter out clients whose subscriptions failed
1470            .filter(|upd| !clients_with_errors.contains(&upd.id))
1471            // Do the aggregation.
1472            .fold(client_table_id_updates, |mut tables, upd| {
1473                match tables.entry((upd.id, upd.table_id)) {
1474                    Entry::Occupied(mut entry) => match entry.get_mut().zip_mut(upd.update) {
1475                        Bsatn((tbl_upd, update)) => tbl_upd.push(update),
1476                        Json((tbl_upd, update)) => tbl_upd.push(update),
1477                    },
1478                    Entry::Vacant(entry) => drop(entry.insert(match upd.update {
1479                        Bsatn(update) => Bsatn(TableUpdate::new(upd.table_id, (&*upd.table_name).into(), update)),
1480                        Json(update) => Json(TableUpdate::new(upd.table_id, (&*upd.table_name).into(), update)),
1481                    })),
1482                }
1483                tables
1484            });
1485
1486        // Each client receives a single list of updates per transaction.
1487        // So before sending the updates to each client,
1488        // we must stitch together the `TableUpdate*`s into an aggregated list.
1489        let mut client_id_updates = client_table_id_updates
1490            .drain()
1491            // Do the aggregation.
1492            .fold(client_id_updates, |mut updates, ((id, _), update)| {
1493                let entry = updates.entry(id);
1494                let entry = entry.or_insert_with(|| match &update {
1495                    Bsatn(_) => Bsatn(<_>::default()),
1496                    Json(_) => Json(<_>::default()),
1497                });
1498                match entry.zip_mut(update) {
1499                    Bsatn((list, elem)) => list.tables.push(elem),
1500                    Json((list, elem)) => list.tables.push(elem),
1501                }
1502                updates
1503            });
1504
1505        drop(clients_with_errors);
1506        drop(span);
1507
1508        let _span = tracing::info_span!("eval_send").entered();
1509
1510        // We might have a known caller that hasn't been hidden from here..
1511        // This caller may have subscribed to some query.
1512        // If they haven't, we'll send them an empty update.
1513        // Regardless, the update that we send to the caller, if we send any,
1514        // is a full tx update, rather than a light one.
1515        // That is, in the case of the caller, we don't respect the light setting.
1516        if let Some(caller) = caller {
1517            let caller_id = (caller.id.identity, caller.id.connection_id);
1518            let database_update = client_id_updates
1519                .remove(&caller_id)
1520                .map(|update| SubscriptionUpdateMessage::from_event_and_update(&event, update))
1521                .unwrap_or_else(|| {
1522                    SubscriptionUpdateMessage::default_for_protocol(caller.config.protocol, event.request_id)
1523                });
1524            let message = TransactionUpdateMessage {
1525                event: Some(event.clone()),
1526                database_update,
1527            };
1528            send_to_client(&caller, message);
1529        }
1530
1531        // Send all the other updates.
1532        for (id, update) in client_id_updates.drain() {
1533            let database_update = SubscriptionUpdateMessage::from_event_and_update(&event, update);
1534            let client = self.clients[&id].outbound_ref.clone();
1535            // Conditionally send out a full update or a light one otherwise.
1536            let event = client.config.tx_update_full.then(|| event.clone());
1537            let message = TransactionUpdateMessage { event, database_update };
1538            send_to_client(&client, message);
1539        }
1540
1541        // Put back the aggregation maps into the worker.
1542        self.table_updates_client_id_table_id = client_table_id_updates;
1543        self.table_updates_client_id = client_id_updates;
1544
1545        // Send error messages and mark clients for removal
1546        for (id, message) in errs {
1547            if let Some(client) = self.clients.get(&id) {
1548                client.dropped.store(true, Ordering::Release);
1549                send_to_client(
1550                    &client.outbound_ref,
1551                    SubscriptionMessage {
1552                        request_id: None,
1553                        query_id: None,
1554                        timer: None,
1555                        result: SubscriptionResult::Error(SubscriptionError {
1556                            table_id: None,
1557                            message,
1558                        }),
1559                    },
1560                );
1561            }
1562        }
1563    }
1564}
1565
1566fn send_to_client(client: &ClientConnectionSender, message: impl Into<SerializableMessage>) {
1567    if let Err(e) = client.send_message(message) {
1568        tracing::warn!(%client.id, "failed to send update message to client: {e}")
1569    }
1570}
1571
1572#[cfg(test)]
1573mod tests {
1574    use std::{sync::Arc, time::Duration};
1575
1576    use spacetimedb_client_api_messages::websocket::QueryId;
1577    use spacetimedb_lib::AlgebraicValue;
1578    use spacetimedb_lib::{error::ResultTest, identity::AuthCtx, AlgebraicType, ConnectionId, Identity, Timestamp};
1579    use spacetimedb_primitives::{ColId, TableId};
1580    use spacetimedb_sats::product;
1581    use spacetimedb_subscription::SubscriptionPlan;
1582
1583    use super::{Plan, SubscriptionManager};
1584    use crate::db::relational_db::tests_utils::with_read_only;
1585    use crate::host::module_host::DatabaseTableUpdate;
1586    use crate::sql::ast::SchemaViewer;
1587    use crate::subscription::module_subscription_manager::ClientQueryId;
1588    use crate::{
1589        client::{ClientActorId, ClientConfig, ClientConnectionSender, ClientName},
1590        db::relational_db::{tests_utils::TestDB, RelationalDB},
1591        energy::EnergyQuanta,
1592        host::{
1593            module_host::{DatabaseUpdate, EventStatus, ModuleEvent, ModuleFunctionCall},
1594            ArgsTuple,
1595        },
1596        subscription::execution_unit::QueryHash,
1597    };
1598    use spacetimedb_datastore::execution_context::Workload;
1599
1600    fn create_table(db: &RelationalDB, name: &str) -> ResultTest<TableId> {
1601        Ok(db.create_table_for_test(name, &[("a", AlgebraicType::U8)], &[])?)
1602    }
1603
1604    fn compile_plan(db: &RelationalDB, sql: &str) -> ResultTest<Arc<Plan>> {
1605        with_read_only(db, |tx| {
1606            let auth = AuthCtx::for_testing();
1607            let tx = SchemaViewer::new(&*tx, &auth);
1608            let (plans, has_param) = SubscriptionPlan::compile(sql, &tx, &auth).unwrap();
1609            let hash = QueryHash::from_string(sql, auth.caller, has_param);
1610            Ok(Arc::new(Plan::new(plans, hash, sql.into())))
1611        })
1612    }
1613
1614    fn id(connection_id: u128) -> (Identity, ConnectionId) {
1615        (Identity::ZERO, ConnectionId::from_u128(connection_id))
1616    }
1617
1618    fn client(connection_id: u128) -> ClientConnectionSender {
1619        let (identity, connection_id) = id(connection_id);
1620        ClientConnectionSender::dummy(
1621            ClientActorId {
1622                identity,
1623                connection_id,
1624                name: ClientName(0),
1625            },
1626            ClientConfig::for_test(),
1627        )
1628    }
1629
1630    #[test]
1631    fn test_subscribe_legacy() -> ResultTest<()> {
1632        let db = TestDB::durable()?;
1633
1634        let table_id = create_table(&db, "T")?;
1635        let sql = "select * from T";
1636        let plan = compile_plan(&db, sql)?;
1637        let hash = plan.hash();
1638
1639        let id = id(0);
1640        let client = Arc::new(client(0));
1641
1642        let runtime = tokio::runtime::Runtime::new().unwrap();
1643        let _rt = runtime.enter();
1644
1645        let mut subscriptions = SubscriptionManager::for_test_without_metrics();
1646        subscriptions.set_legacy_subscription(client.clone(), [plan.clone()]);
1647
1648        assert!(subscriptions.contains_query(&hash));
1649        assert!(subscriptions.contains_legacy_subscription(&id, &hash));
1650        assert!(subscriptions.query_reads_from_table(&hash, &table_id));
1651
1652        Ok(())
1653    }
1654
1655    #[test]
1656    fn test_subscribe_single_adds_table_mapping() -> ResultTest<()> {
1657        let db = TestDB::durable()?;
1658
1659        let table_id = create_table(&db, "T")?;
1660        let sql = "select * from T";
1661        let plan = compile_plan(&db, sql)?;
1662        let hash = plan.hash();
1663
1664        let client = Arc::new(client(0));
1665
1666        let query_id: ClientQueryId = QueryId::new(1);
1667
1668        let runtime = tokio::runtime::Runtime::new().unwrap();
1669        let _rt = runtime.enter();
1670
1671        let mut subscriptions = SubscriptionManager::for_test_without_metrics();
1672        subscriptions.add_subscription(client.clone(), plan.clone(), query_id)?;
1673        assert!(subscriptions.query_reads_from_table(&hash, &table_id));
1674
1675        Ok(())
1676    }
1677
1678    #[test]
1679    fn test_unsubscribe_from_the_only_subscription() -> ResultTest<()> {
1680        let db = TestDB::durable()?;
1681
1682        let table_id = create_table(&db, "T")?;
1683        let sql = "select * from T";
1684        let plan = compile_plan(&db, sql)?;
1685        let hash = plan.hash();
1686
1687        let client = Arc::new(client(0));
1688
1689        let query_id: ClientQueryId = QueryId::new(1);
1690
1691        let runtime = tokio::runtime::Runtime::new().unwrap();
1692        let _rt = runtime.enter();
1693
1694        let mut subscriptions = SubscriptionManager::for_test_without_metrics();
1695        subscriptions.add_subscription(client.clone(), plan.clone(), query_id)?;
1696        assert!(subscriptions.query_reads_from_table(&hash, &table_id));
1697
1698        let client_id = (client.id.identity, client.id.connection_id);
1699        subscriptions.remove_subscription(client_id, query_id)?;
1700        assert!(!subscriptions.query_reads_from_table(&hash, &table_id));
1701
1702        Ok(())
1703    }
1704
1705    #[test]
1706    fn test_unsubscribe_with_unknown_query_id_fails() -> ResultTest<()> {
1707        let db = TestDB::durable()?;
1708
1709        create_table(&db, "T")?;
1710        let sql = "select * from T";
1711        let plan = compile_plan(&db, sql)?;
1712
1713        let client = Arc::new(client(0));
1714
1715        let query_id: ClientQueryId = QueryId::new(1);
1716
1717        let runtime = tokio::runtime::Runtime::new().unwrap();
1718        let _rt = runtime.enter();
1719
1720        let mut subscriptions = SubscriptionManager::for_test_without_metrics();
1721        subscriptions.add_subscription(client.clone(), plan.clone(), query_id)?;
1722
1723        let client_id = (client.id.identity, client.id.connection_id);
1724        assert!(subscriptions.remove_subscription(client_id, QueryId::new(2)).is_err());
1725
1726        Ok(())
1727    }
1728
1729    #[test]
1730    fn test_subscribe_and_unsubscribe_with_duplicate_queries() -> ResultTest<()> {
1731        let db = TestDB::durable()?;
1732
1733        let table_id = create_table(&db, "T")?;
1734        let sql = "select * from T";
1735        let plan = compile_plan(&db, sql)?;
1736        let hash = plan.hash();
1737
1738        let client = Arc::new(client(0));
1739
1740        let query_id: ClientQueryId = QueryId::new(1);
1741
1742        let runtime = tokio::runtime::Runtime::new().unwrap();
1743        let _rt = runtime.enter();
1744
1745        let mut subscriptions = SubscriptionManager::for_test_without_metrics();
1746        subscriptions.add_subscription(client.clone(), plan.clone(), query_id)?;
1747        subscriptions.add_subscription(client.clone(), plan.clone(), QueryId::new(2))?;
1748
1749        let client_id = (client.id.identity, client.id.connection_id);
1750        subscriptions.remove_subscription(client_id, query_id)?;
1751
1752        assert!(subscriptions.query_reads_from_table(&hash, &table_id));
1753
1754        Ok(())
1755    }
1756
1757    /// A very simple test case of a duplicate query.
1758    #[test]
1759    fn test_subscribe_and_unsubscribe_with_duplicate_queries_multi() -> ResultTest<()> {
1760        let db = TestDB::durable()?;
1761
1762        let table_id = create_table(&db, "T")?;
1763        let sql = "select * from T";
1764        let plan = compile_plan(&db, sql)?;
1765        let hash = plan.hash();
1766
1767        let client = Arc::new(client(0));
1768
1769        let query_id: ClientQueryId = QueryId::new(1);
1770
1771        let runtime = tokio::runtime::Runtime::new().unwrap();
1772        let _rt = runtime.enter();
1773
1774        let mut subscriptions = SubscriptionManager::for_test_without_metrics();
1775        let added_query = subscriptions.add_subscription_multi(client.clone(), vec![plan.clone()], query_id)?;
1776        assert!(added_query.len() == 1);
1777        assert_eq!(added_query[0].hash, hash);
1778        let second_one = subscriptions.add_subscription_multi(client.clone(), vec![plan.clone()], QueryId::new(2))?;
1779        assert!(second_one.is_empty());
1780
1781        let client_id = (client.id.identity, client.id.connection_id);
1782        let removed_queries = subscriptions.remove_subscription(client_id, query_id)?;
1783        assert!(removed_queries.is_empty());
1784
1785        assert!(subscriptions.query_reads_from_table(&hash, &table_id));
1786        let removed_queries = subscriptions.remove_subscription(client_id, QueryId::new(2))?;
1787        assert!(removed_queries.len() == 1);
1788        assert_eq!(removed_queries[0].hash, hash);
1789
1790        assert!(!subscriptions.query_reads_from_table(&hash, &table_id));
1791
1792        Ok(())
1793    }
1794
1795    #[test]
1796    fn test_unsubscribe_doesnt_remove_other_clients() -> ResultTest<()> {
1797        let db = TestDB::durable()?;
1798
1799        let table_id = create_table(&db, "T")?;
1800        let sql = "select * from T";
1801        let plan = compile_plan(&db, sql)?;
1802        let hash = plan.hash();
1803
1804        let clients = (0..3).map(|i| Arc::new(client(i))).collect::<Vec<_>>();
1805
1806        // All of the clients are using the same query id.
1807        let query_id: ClientQueryId = QueryId::new(1);
1808
1809        let runtime = tokio::runtime::Runtime::new().unwrap();
1810        let _rt = runtime.enter();
1811
1812        let mut subscriptions = SubscriptionManager::for_test_without_metrics();
1813        subscriptions.add_subscription(clients[0].clone(), plan.clone(), query_id)?;
1814        subscriptions.add_subscription(clients[1].clone(), plan.clone(), query_id)?;
1815        subscriptions.add_subscription(clients[2].clone(), plan.clone(), query_id)?;
1816
1817        assert!(subscriptions.query_reads_from_table(&hash, &table_id));
1818
1819        let client_ids = clients
1820            .iter()
1821            .map(|client| (client.id.identity, client.id.connection_id))
1822            .collect::<Vec<_>>();
1823        subscriptions.remove_subscription(client_ids[0], query_id)?;
1824        // There are still two left.
1825        assert!(subscriptions.query_reads_from_table(&hash, &table_id));
1826        subscriptions.remove_subscription(client_ids[1], query_id)?;
1827        // There is still one left.
1828        assert!(subscriptions.query_reads_from_table(&hash, &table_id));
1829        subscriptions.remove_subscription(client_ids[2], query_id)?;
1830        // Now there are no subscribers.
1831        assert!(!subscriptions.query_reads_from_table(&hash, &table_id));
1832
1833        Ok(())
1834    }
1835
1836    #[test]
1837    fn test_unsubscribe_all_doesnt_remove_other_clients() -> ResultTest<()> {
1838        let db = TestDB::durable()?;
1839
1840        let table_id = create_table(&db, "T")?;
1841        let sql = "select * from T";
1842        let plan = compile_plan(&db, sql)?;
1843        let hash = plan.hash();
1844
1845        let clients = (0..3).map(|i| Arc::new(client(i))).collect::<Vec<_>>();
1846
1847        // All of the clients are using the same query id.
1848        let query_id: ClientQueryId = QueryId::new(1);
1849
1850        let runtime = tokio::runtime::Runtime::new().unwrap();
1851        let _rt = runtime.enter();
1852
1853        let mut subscriptions = SubscriptionManager::for_test_without_metrics();
1854        subscriptions.add_subscription(clients[0].clone(), plan.clone(), query_id)?;
1855        subscriptions.add_subscription(clients[1].clone(), plan.clone(), query_id)?;
1856        subscriptions.add_subscription(clients[2].clone(), plan.clone(), query_id)?;
1857
1858        assert!(subscriptions.query_reads_from_table(&hash, &table_id));
1859
1860        let client_ids = clients
1861            .iter()
1862            .map(|client| (client.id.identity, client.id.connection_id))
1863            .collect::<Vec<_>>();
1864        subscriptions.remove_all_subscriptions(&client_ids[0]);
1865        assert!(!subscriptions.contains_client(&client_ids[0]));
1866        // There are still two left.
1867        assert!(subscriptions.query_reads_from_table(&hash, &table_id));
1868        subscriptions.remove_all_subscriptions(&client_ids[1]);
1869        // There is still one left.
1870        assert!(subscriptions.query_reads_from_table(&hash, &table_id));
1871        assert!(!subscriptions.contains_client(&client_ids[1]));
1872        subscriptions.remove_all_subscriptions(&client_ids[2]);
1873        // Now there are no subscribers.
1874        assert!(!subscriptions.query_reads_from_table(&hash, &table_id));
1875        assert!(!subscriptions.contains_client(&client_ids[2]));
1876
1877        Ok(())
1878    }
1879
1880    // This test has a single client with 3 queries of different tables, and tests removing them.
1881    #[test]
1882    fn test_multiple_queries() -> ResultTest<()> {
1883        let db = TestDB::durable()?;
1884
1885        let table_names = ["T", "S", "U"];
1886        let table_ids = table_names
1887            .iter()
1888            .map(|name| create_table(&db, name))
1889            .collect::<ResultTest<Vec<_>>>()?;
1890        let queries = table_names
1891            .iter()
1892            .map(|name| format!("select * from {name}"))
1893            .map(|sql| compile_plan(&db, &sql))
1894            .collect::<ResultTest<Vec<_>>>()?;
1895
1896        let client = Arc::new(client(0));
1897
1898        let runtime = tokio::runtime::Runtime::new().unwrap();
1899        let _rt = runtime.enter();
1900
1901        let mut subscriptions = SubscriptionManager::for_test_without_metrics();
1902        subscriptions.add_subscription(client.clone(), queries[0].clone(), QueryId::new(1))?;
1903        subscriptions.add_subscription(client.clone(), queries[1].clone(), QueryId::new(2))?;
1904        subscriptions.add_subscription(client.clone(), queries[2].clone(), QueryId::new(3))?;
1905        for i in 0..3 {
1906            assert!(subscriptions.query_reads_from_table(&queries[i].hash(), &table_ids[i]));
1907        }
1908
1909        let client_id = (client.id.identity, client.id.connection_id);
1910        subscriptions.remove_subscription(client_id, QueryId::new(1))?;
1911        assert!(!subscriptions.query_reads_from_table(&queries[0].hash(), &table_ids[0]));
1912        // Assert that the rest are there.
1913        for i in 1..3 {
1914            assert!(subscriptions.query_reads_from_table(&queries[i].hash(), &table_ids[i]));
1915        }
1916
1917        // Now remove the final two at once.
1918        subscriptions.remove_all_subscriptions(&client_id);
1919        for i in 0..3 {
1920            assert!(!subscriptions.query_reads_from_table(&queries[i].hash(), &table_ids[i]));
1921        }
1922
1923        Ok(())
1924    }
1925
1926    #[test]
1927    fn test_multiple_query_sets() -> ResultTest<()> {
1928        let db = TestDB::durable()?;
1929
1930        let table_names = ["T", "S", "U"];
1931        let table_ids = table_names
1932            .iter()
1933            .map(|name| create_table(&db, name))
1934            .collect::<ResultTest<Vec<_>>>()?;
1935        let queries = table_names
1936            .iter()
1937            .map(|name| format!("select * from {name}"))
1938            .map(|sql| compile_plan(&db, &sql))
1939            .collect::<ResultTest<Vec<_>>>()?;
1940
1941        let client = Arc::new(client(0));
1942
1943        let runtime = tokio::runtime::Runtime::new().unwrap();
1944        let _rt = runtime.enter();
1945
1946        let mut subscriptions = SubscriptionManager::for_test_without_metrics();
1947        let added = subscriptions.add_subscription_multi(client.clone(), vec![queries[0].clone()], QueryId::new(1))?;
1948        assert_eq!(added.len(), 1);
1949        assert_eq!(added[0].hash, queries[0].hash());
1950        let added = subscriptions.add_subscription_multi(client.clone(), vec![queries[1].clone()], QueryId::new(2))?;
1951        assert_eq!(added.len(), 1);
1952        assert_eq!(added[0].hash, queries[1].hash());
1953        let added = subscriptions.add_subscription_multi(client.clone(), vec![queries[2].clone()], QueryId::new(3))?;
1954        assert_eq!(added.len(), 1);
1955        assert_eq!(added[0].hash, queries[2].hash());
1956        for i in 0..3 {
1957            assert!(subscriptions.query_reads_from_table(&queries[i].hash(), &table_ids[i]));
1958        }
1959
1960        let client_id = (client.id.identity, client.id.connection_id);
1961        let removed = subscriptions.remove_subscription(client_id, QueryId::new(1))?;
1962        assert_eq!(removed.len(), 1);
1963        assert_eq!(removed[0].hash, queries[0].hash());
1964        assert!(!subscriptions.query_reads_from_table(&queries[0].hash(), &table_ids[0]));
1965        // Assert that the rest are there.
1966        for i in 1..3 {
1967            assert!(subscriptions.query_reads_from_table(&queries[i].hash(), &table_ids[i]));
1968        }
1969
1970        // Now remove the final two at once.
1971        subscriptions.remove_all_subscriptions(&client_id);
1972        for i in 0..3 {
1973            assert!(!subscriptions.query_reads_from_table(&queries[i].hash(), &table_ids[i]));
1974        }
1975
1976        Ok(())
1977    }
1978
1979    #[test]
1980    fn test_internals_for_search_args() -> ResultTest<()> {
1981        let db = TestDB::durable()?;
1982
1983        let table_id = create_table(&db, "t")?;
1984
1985        let client = Arc::new(client(0));
1986
1987        let runtime = tokio::runtime::Runtime::new().unwrap();
1988        let _rt = runtime.enter();
1989
1990        let mut subscriptions = SubscriptionManager::for_test_without_metrics();
1991
1992        // Subscribe to queries that have search arguments
1993        let queries = (0u8..5)
1994            .map(|name| format!("select * from t where a = {name}"))
1995            .map(|sql| compile_plan(&db, &sql))
1996            .collect::<ResultTest<Vec<_>>>()?;
1997
1998        for (i, query) in queries.iter().enumerate().take(5) {
1999            let added =
2000                subscriptions.add_subscription_multi(client.clone(), vec![query.clone()], QueryId::new(i as u32))?;
2001            assert_eq!(added.len(), 1);
2002            assert_eq!(added[0].hash, queries[i].hash);
2003        }
2004
2005        // Assert this table has a search parameter
2006        assert!(subscriptions.table_has_search_param(table_id, ColId(0)));
2007
2008        for (i, query) in queries.iter().enumerate().take(5) {
2009            assert!(subscriptions.query_has_search_arg(query.hash, table_id, ColId(0), AlgebraicValue::U8(i as u8)));
2010
2011            // Only one of `query_reads_from_table` and `query_has_search_arg` can be true at any given time
2012            assert!(!subscriptions.query_reads_from_table(&queries[i].hash, &table_id));
2013        }
2014
2015        // Remove one of the subscriptions
2016        let query_id = QueryId::new(2);
2017        let client_id = (client.id.identity, client.id.connection_id);
2018        let removed = subscriptions.remove_subscription(client_id, query_id)?;
2019        assert_eq!(removed.len(), 1);
2020
2021        // We haven't removed the other subscriptions,
2022        // so this table should still have a search parameter.
2023        assert!(subscriptions.table_has_search_param(table_id, ColId(0)));
2024
2025        // We should have removed the search argument for this query
2026        assert!(!subscriptions.query_reads_from_table(&queries[2].hash, &table_id));
2027        assert!(!subscriptions.query_has_search_arg(queries[2].hash, table_id, ColId(0), AlgebraicValue::U8(2)));
2028
2029        for (i, query) in queries.iter().enumerate().take(5) {
2030            if i != 2 {
2031                assert!(subscriptions.query_has_search_arg(
2032                    query.hash,
2033                    table_id,
2034                    ColId(0),
2035                    AlgebraicValue::U8(i as u8)
2036                ));
2037            }
2038        }
2039
2040        // Remove all of the subscriptions
2041        subscriptions.remove_all_subscriptions(&client_id);
2042
2043        // We should no longer record a search parameter for this table
2044        assert!(!subscriptions.table_has_search_param(table_id, ColId(0)));
2045        for (i, query) in queries.iter().enumerate().take(5) {
2046            assert!(!subscriptions.query_has_search_arg(query.hash, table_id, ColId(0), AlgebraicValue::U8(i as u8)));
2047        }
2048
2049        Ok(())
2050    }
2051
2052    #[test]
2053    fn test_search_args_for_selects() -> ResultTest<()> {
2054        let db = TestDB::durable()?;
2055
2056        let table_id = create_table(&db, "t")?;
2057
2058        let client = Arc::new(client(0));
2059
2060        let runtime = tokio::runtime::Runtime::new().unwrap();
2061        let _rt = runtime.enter();
2062
2063        let mut subscriptions = SubscriptionManager::for_test_without_metrics();
2064
2065        let queries = (0u8..5)
2066            .map(|name| format!("select * from t where a = {name}"))
2067            .chain(std::iter::once(String::from("select * from t")))
2068            .map(|sql| compile_plan(&db, &sql))
2069            .collect::<ResultTest<Vec<_>>>()?;
2070
2071        for (i, query) in queries.iter().enumerate() {
2072            subscriptions.add_subscription_multi(client.clone(), vec![query.clone()], QueryId::new(i as u32))?;
2073        }
2074
2075        let hash_for_2 = queries[2].hash;
2076        let hash_for_3 = queries[3].hash;
2077        let hash_for_5 = queries[5].hash;
2078
2079        // Which queries are relevant for this table update? Only:
2080        //
2081        // select * from t where a = 2
2082        // select * from t where a = 3
2083        // select * from t
2084        let table_update = DatabaseTableUpdate {
2085            table_id,
2086            table_name: "t".into(),
2087            inserts: [product![2u8]].into(),
2088            deletes: [product![3u8]].into(),
2089        };
2090
2091        let hashes = subscriptions
2092            .queries_for_table_update(&table_update, &|_, _| None)
2093            .collect::<Vec<_>>();
2094
2095        assert!(hashes.len() == 3);
2096        assert!(hashes.contains(&&hash_for_2));
2097        assert!(hashes.contains(&&hash_for_3));
2098        assert!(hashes.contains(&&hash_for_5));
2099
2100        // Which queries are relevant for this table update?
2101        // Only: select * from t
2102        let table_update = DatabaseTableUpdate {
2103            table_id,
2104            table_name: "t".into(),
2105            inserts: [product![8u8]].into(),
2106            deletes: [product![9u8]].into(),
2107        };
2108
2109        let hashes = subscriptions
2110            .queries_for_table_update(&table_update, &|_, _| None)
2111            .collect::<Vec<_>>();
2112
2113        assert!(hashes.len() == 1);
2114        assert!(hashes.contains(&&hash_for_5));
2115
2116        Ok(())
2117    }
2118
2119    #[test]
2120    fn test_search_args_for_join() -> ResultTest<()> {
2121        let db = TestDB::durable()?;
2122
2123        let schema = [("id", AlgebraicType::U8), ("a", AlgebraicType::U8)];
2124
2125        let t_id = db.create_table_for_test("t", &schema, &[0.into()])?;
2126        let s_id = db.create_table_for_test("s", &schema, &[0.into()])?;
2127
2128        let client = Arc::new(client(0));
2129
2130        let runtime = tokio::runtime::Runtime::new().unwrap();
2131        let _rt = runtime.enter();
2132
2133        let mut subscriptions = SubscriptionManager::for_test_without_metrics();
2134
2135        let plan = compile_plan(&db, "select t.* from t join s on t.id = s.id where s.a = 1")?;
2136        let hash = plan.hash;
2137
2138        subscriptions.add_subscription_multi(client.clone(), vec![plan], QueryId::new(0))?;
2139
2140        // Do we need to evaluate the above join query for this table update?
2141        // Yes, because the above query does not filter on `t`.
2142        // Therefore we must evaluate it for any update on `t`.
2143        let table_update = DatabaseTableUpdate {
2144            table_id: t_id,
2145            table_name: "t".into(),
2146            inserts: [product![0u8, 0u8]].into(),
2147            deletes: [].into(),
2148        };
2149
2150        let hashes = subscriptions
2151            .queries_for_table_update(&table_update, &|_, _| None)
2152            .cloned()
2153            .collect::<Vec<_>>();
2154
2155        assert_eq!(hashes, vec![hash]);
2156
2157        // Do we need to evaluate the above join query for this table update?
2158        // Yes, because `s.a = 1`.
2159        let table_update = DatabaseTableUpdate {
2160            table_id: s_id,
2161            table_name: "s".into(),
2162            inserts: [product![0u8, 1u8]].into(),
2163            deletes: [].into(),
2164        };
2165
2166        let hashes = subscriptions
2167            .queries_for_table_update(&table_update, &|_, _| None)
2168            .cloned()
2169            .collect::<Vec<_>>();
2170
2171        assert_eq!(hashes, vec![hash]);
2172
2173        // Do we need to evaluate the above join query for this table update?
2174        // No, because `s.a != 1`.
2175        let table_update = DatabaseTableUpdate {
2176            table_id: s_id,
2177            table_name: "s".into(),
2178            inserts: [product![0u8, 2u8]].into(),
2179            deletes: [].into(),
2180        };
2181
2182        let hashes = subscriptions
2183            .queries_for_table_update(&table_update, &|_, _| None)
2184            .cloned()
2185            .collect::<Vec<_>>();
2186
2187        assert!(hashes.is_empty());
2188
2189        Ok(())
2190    }
2191
2192    #[test]
2193    fn test_subscribe_fails_with_duplicate_request_id() -> ResultTest<()> {
2194        let db = TestDB::durable()?;
2195
2196        create_table(&db, "T")?;
2197        let sql = "select * from T";
2198        let plan = compile_plan(&db, sql)?;
2199
2200        let client = Arc::new(client(0));
2201
2202        let query_id: ClientQueryId = QueryId::new(1);
2203
2204        let runtime = tokio::runtime::Runtime::new().unwrap();
2205        let _rt = runtime.enter();
2206
2207        let mut subscriptions = SubscriptionManager::for_test_without_metrics();
2208        subscriptions.add_subscription(client.clone(), plan.clone(), query_id)?;
2209
2210        assert!(subscriptions
2211            .add_subscription(client.clone(), plan.clone(), query_id)
2212            .is_err());
2213
2214        Ok(())
2215    }
2216
2217    #[test]
2218    fn test_subscribe_multi_fails_with_duplicate_request_id() -> ResultTest<()> {
2219        let db = TestDB::durable()?;
2220
2221        create_table(&db, "T")?;
2222        let sql = "select * from T";
2223        let plan = compile_plan(&db, sql)?;
2224
2225        let client = Arc::new(client(0));
2226
2227        let query_id: ClientQueryId = QueryId::new(1);
2228
2229        let runtime = tokio::runtime::Runtime::new().unwrap();
2230        let _rt = runtime.enter();
2231
2232        let mut subscriptions = SubscriptionManager::for_test_without_metrics();
2233        let result = subscriptions.add_subscription_multi(client.clone(), vec![plan.clone()], query_id)?;
2234        assert_eq!(result[0].hash, plan.hash);
2235
2236        assert!(subscriptions
2237            .add_subscription_multi(client.clone(), vec![plan.clone()], query_id)
2238            .is_err());
2239
2240        Ok(())
2241    }
2242
2243    #[test]
2244    fn test_unsubscribe() -> ResultTest<()> {
2245        let db = TestDB::durable()?;
2246
2247        let table_id = create_table(&db, "T")?;
2248        let sql = "select * from T";
2249        let plan = compile_plan(&db, sql)?;
2250        let hash = plan.hash();
2251
2252        let id = id(0);
2253        let client = Arc::new(client(0));
2254
2255        let runtime = tokio::runtime::Runtime::new().unwrap();
2256        let _rt = runtime.enter();
2257
2258        let mut subscriptions = SubscriptionManager::for_test_without_metrics();
2259        subscriptions.set_legacy_subscription(client, [plan]);
2260        subscriptions.remove_all_subscriptions(&id);
2261
2262        assert!(!subscriptions.contains_query(&hash));
2263        assert!(!subscriptions.contains_legacy_subscription(&id, &hash));
2264        assert!(!subscriptions.query_reads_from_table(&hash, &table_id));
2265
2266        Ok(())
2267    }
2268
2269    #[test]
2270    fn test_subscribe_idempotent() -> ResultTest<()> {
2271        let db = TestDB::durable()?;
2272
2273        let table_id = create_table(&db, "T")?;
2274        let sql = "select * from T";
2275        let plan = compile_plan(&db, sql)?;
2276        let hash = plan.hash();
2277
2278        let id = id(0);
2279        let client = Arc::new(client(0));
2280
2281        let runtime = tokio::runtime::Runtime::new().unwrap();
2282        let _rt = runtime.enter();
2283
2284        let mut subscriptions = SubscriptionManager::for_test_without_metrics();
2285        subscriptions.set_legacy_subscription(client.clone(), [plan.clone()]);
2286        subscriptions.set_legacy_subscription(client.clone(), [plan.clone()]);
2287
2288        assert!(subscriptions.contains_query(&hash));
2289        assert!(subscriptions.contains_legacy_subscription(&id, &hash));
2290        assert!(subscriptions.query_reads_from_table(&hash, &table_id));
2291
2292        subscriptions.remove_all_subscriptions(&id);
2293
2294        assert!(!subscriptions.contains_query(&hash));
2295        assert!(!subscriptions.contains_legacy_subscription(&id, &hash));
2296        assert!(!subscriptions.query_reads_from_table(&hash, &table_id));
2297
2298        Ok(())
2299    }
2300
2301    #[test]
2302    fn test_share_queries_full() -> ResultTest<()> {
2303        let db = TestDB::durable()?;
2304
2305        let table_id = create_table(&db, "T")?;
2306        let sql = "select * from T";
2307        let plan = compile_plan(&db, sql)?;
2308        let hash = plan.hash();
2309
2310        let id0 = id(0);
2311        let client0 = Arc::new(client(0));
2312
2313        let id1 = id(1);
2314        let client1 = Arc::new(client(1));
2315
2316        let runtime = tokio::runtime::Runtime::new().unwrap();
2317        let _rt = runtime.enter();
2318
2319        let mut subscriptions = SubscriptionManager::for_test_without_metrics();
2320        subscriptions.set_legacy_subscription(client0, [plan.clone()]);
2321        subscriptions.set_legacy_subscription(client1, [plan.clone()]);
2322
2323        assert!(subscriptions.contains_query(&hash));
2324        assert!(subscriptions.contains_legacy_subscription(&id0, &hash));
2325        assert!(subscriptions.contains_legacy_subscription(&id1, &hash));
2326        assert!(subscriptions.query_reads_from_table(&hash, &table_id));
2327
2328        subscriptions.remove_all_subscriptions(&id0);
2329
2330        assert!(subscriptions.contains_query(&hash));
2331        assert!(subscriptions.contains_legacy_subscription(&id1, &hash));
2332        assert!(subscriptions.query_reads_from_table(&hash, &table_id));
2333
2334        assert!(!subscriptions.contains_legacy_subscription(&id0, &hash));
2335
2336        Ok(())
2337    }
2338
    /// Two clients share one query (`scan`) while each also holds a private
    /// point-select query. Removing one client must drop only that client's
    /// private query while leaving the shared query and the other client's
    /// state intact.
    ///
    /// Note on the negative `query_reads_from_table` asserts below: the
    /// point-select queries (`a = 0`, `a = 1`) are tracked via their search
    /// arguments (see the `query_has_search_arg` asserts) rather than as
    /// full table readers, so `query_reads_from_table` is expected to be
    /// `false` for them even on the table they select from.
    #[test]
    fn test_share_queries_partial() -> ResultTest<()> {
        let db = TestDB::durable()?;

        let t = create_table(&db, "T")?;
        let s = create_table(&db, "S")?;

        // One shared full scan plus one distinct point select per client.
        let scan = "select * from T";
        let select0 = "select * from T where a = 0";
        let select1 = "select * from S where a = 1";

        let plan_scan = compile_plan(&db, scan)?;
        let plan_select0 = compile_plan(&db, select0)?;
        let plan_select1 = compile_plan(&db, select1)?;

        let hash_scan = plan_scan.hash();
        let hash_select0 = plan_select0.hash();
        let hash_select1 = plan_select1.hash();

        let id0 = id(0);
        let client0 = Arc::new(client(0));

        let id1 = id(1);
        let client1 = Arc::new(client(1));

        // The subscription manager requires an active Tokio runtime context.
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let _rt = runtime.enter();
        let mut subscriptions = SubscriptionManager::for_test_without_metrics();
        subscriptions.set_legacy_subscription(client0, [plan_scan.clone(), plan_select0.clone()]);
        subscriptions.set_legacy_subscription(client1, [plan_scan.clone(), plan_select1.clone()]);

        // All three queries are registered.
        assert!(subscriptions.contains_query(&hash_scan));
        assert!(subscriptions.contains_query(&hash_select0));
        assert!(subscriptions.contains_query(&hash_select1));

        // Client 0 holds the scan and its own select ...
        assert!(subscriptions.contains_legacy_subscription(&id0, &hash_scan));
        assert!(subscriptions.contains_legacy_subscription(&id0, &hash_select0));

        // ... and so does client 1.
        assert!(subscriptions.contains_legacy_subscription(&id1, &hash_scan));
        assert!(subscriptions.contains_legacy_subscription(&id1, &hash_select1));

        // Only the full scan registers as a table reader; the point selects
        // are indexed by their (table, column, value) search arguments.
        assert!(subscriptions.query_reads_from_table(&hash_scan, &t));
        assert!(subscriptions.query_has_search_arg(hash_select0, t, ColId(0), AlgebraicValue::U8(0)));
        assert!(subscriptions.query_has_search_arg(hash_select1, s, ColId(0), AlgebraicValue::U8(1)));

        // Point selects do not appear as table readers — not even on their
        // own table (they are reachable via search args instead).
        assert!(!subscriptions.query_reads_from_table(&hash_scan, &s));
        assert!(!subscriptions.query_reads_from_table(&hash_select0, &t));
        assert!(!subscriptions.query_reads_from_table(&hash_select1, &s));
        assert!(!subscriptions.query_reads_from_table(&hash_select0, &s));
        assert!(!subscriptions.query_reads_from_table(&hash_select1, &t));

        subscriptions.remove_all_subscriptions(&id0);

        // Only client 0's private select is evicted; the shared scan and
        // client 1's select remain.
        assert!(subscriptions.contains_query(&hash_scan));
        assert!(subscriptions.contains_query(&hash_select1));
        assert!(!subscriptions.contains_query(&hash_select0));

        assert!(subscriptions.contains_legacy_subscription(&id1, &hash_scan));
        assert!(subscriptions.contains_legacy_subscription(&id1, &hash_select1));

        assert!(!subscriptions.contains_legacy_subscription(&id0, &hash_scan));
        assert!(!subscriptions.contains_legacy_subscription(&id0, &hash_select0));

        // Table-reader / search-arg bookkeeping is unchanged for client 1.
        assert!(subscriptions.query_reads_from_table(&hash_scan, &t));
        assert!(subscriptions.query_has_search_arg(hash_select1, s, ColId(0), AlgebraicValue::U8(1)));

        assert!(!subscriptions.query_reads_from_table(&hash_select1, &s));
        assert!(!subscriptions.query_reads_from_table(&hash_scan, &s));
        assert!(!subscriptions.query_reads_from_table(&hash_select1, &t));

        Ok(())
    }
2411
    #[test]
    fn test_caller_transaction_update_without_subscription() -> ResultTest<()> {
        // Verify that a transaction update is delivered to the reducer caller
        // even when the caller hasn't subscribed to any queries.
        let db = TestDB::durable()?;

        // Build a caller client with an attached channel so we can observe
        // what the subscription manager sends to it.
        let id0 = Identity::ZERO;
        let client0 = ClientActorId::for_test(id0);
        let config = ClientConfig::for_test();
        let (client0, mut rx) = ClientConnectionSender::dummy_with_channel(client0, config);

        // The subscription manager requires an active Tokio runtime context.
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let _rt = runtime.enter();
        // Note: no subscriptions are registered — that's the point of the test.
        let subscriptions = SubscriptionManager::for_test_without_metrics();

        // A committed reducer event whose caller is `client0`. The dummy
        // reducer id (u32::MAX) and empty args are sufficient here since the
        // update is empty and only delivery is under test.
        let event = Arc::new(ModuleEvent {
            timestamp: Timestamp::now(),
            caller_identity: id0,
            caller_connection_id: Some(client0.id.connection_id),
            function_call: ModuleFunctionCall {
                reducer: "DummyReducer".into(),
                reducer_id: u32::MAX.into(),
                args: ArgsTuple::nullary(),
            },
            status: EventStatus::Committed(DatabaseUpdate::default()),
            energy_quanta_used: EnergyQuanta::ZERO,
            host_execution_duration: Duration::default(),
            request_id: None,
            timer: None,
        });

        // Evaluate the event, passing the caller explicitly so the manager
        // can route the update to it despite the lack of subscriptions.
        db.with_read_only(Workload::Update, |tx| {
            subscriptions.eval_updates_sequential(&(&*tx).into(), event, Some(Arc::new(client0)))
        });

        // The caller's channel must receive at least one message; time out
        // quickly so a regression fails fast instead of hanging the test.
        runtime.block_on(async move {
            tokio::time::timeout(Duration::from_millis(20), async move {
                rx.recv().await.expect("Expected at least one message");
            })
            .await
            .expect("Timed out waiting for a message to the client");
        });

        Ok(())
    }
2457}