spacetimedb/subscription/
module_subscription_manager.rs

1use super::execution_unit::QueryHash;
2use super::tx::DeltaTx;
3use crate::client::messages::{
4    SerializableMessage, SubscriptionError, SubscriptionMessage, SubscriptionResult, SubscriptionUpdateMessage,
5    TransactionUpdateMessage,
6};
7use crate::client::{ClientConnectionSender, Protocol};
8use crate::db::datastore::locking_tx_datastore::state_view::StateView;
9use crate::error::DBError;
10use crate::host::module_host::{DatabaseTableUpdate, ModuleEvent, UpdatesRelValue};
11use crate::messages::websocket::{self as ws, TableUpdate};
12use crate::subscription::delta::eval_delta;
13use crate::worker_metrics::WORKER_METRICS;
14use hashbrown::hash_map::OccupiedError;
15use hashbrown::{HashMap, HashSet};
16use itertools::Itertools;
17use parking_lot::RwLock;
18use prometheus::IntGauge;
19use spacetimedb_client_api_messages::websocket::{
20    BsatnFormat, CompressableQueryUpdate, FormatSwitch, JsonFormat, QueryId, QueryUpdate, SingleQueryUpdate,
21    WebsocketFormat,
22};
23use spacetimedb_data_structures::map::{Entry, IntMap};
24use spacetimedb_lib::metrics::ExecutionMetrics;
25use spacetimedb_lib::{AlgebraicValue, ConnectionId, Identity, ProductValue};
26use spacetimedb_primitives::{ColId, IndexId, TableId};
27use spacetimedb_subscription::{JoinEdge, SubscriptionPlan, TableName};
28use std::collections::{BTreeMap, BTreeSet};
29use std::sync::atomic::{AtomicBool, Ordering};
30use std::sync::Arc;
31use tokio::sync::mpsc;
32
/// Clients are uniquely identified by their [`Identity`] and [`ConnectionId`].
/// The `Identity` alone is insufficient because several connections,
/// each with its own `ConnectionId`, can use the same `Identity`.
/// TODO: Determine if ConnectionId is sufficient for uniquely identifying a client.
type ClientId = (Identity, ConnectionId);
/// A reference-counted, shareable subscription [`Plan`].
type Query = Arc<Plan>;
/// A reference-counted handle used to send messages to a client connection.
type Client = Arc<ClientConnectionSender>;
/// A table update in one of the two websocket formats named in its type (BSATN or JSON).
type SwitchedTableUpdate = FormatSwitch<TableUpdate<BsatnFormat>, TableUpdate<JsonFormat>>;
/// A database update in one of the two websocket formats named in its type (BSATN or JSON).
type SwitchedDbUpdate = FormatSwitch<ws::DatabaseUpdate<BsatnFormat>, ws::DatabaseUpdate<JsonFormat>>;

/// ClientQueryId is an identifier for a query set by the client.
type ClientQueryId = QueryId;
/// SubscriptionId is a globally unique identifier for a subscription:
/// the client that owns it paired with the query id that client chose.
type SubscriptionId = (ClientId, ClientQueryId);
46
47#[derive(Debug)]
48pub struct Plan {
49    hash: QueryHash,
50    sql: String,
51    plans: Vec<SubscriptionPlan>,
52}
53
54impl Plan {
55    /// Create a new subscription plan to be cached
56    pub fn new(plans: Vec<SubscriptionPlan>, hash: QueryHash, text: String) -> Self {
57        Self { plans, hash, sql: text }
58    }
59
60    /// Returns the query hash for this subscription
61    pub fn hash(&self) -> QueryHash {
62        self.hash
63    }
64
65    /// A subscription query return rows from a single table.
66    /// This method returns the id of that table.
67    pub fn subscribed_table_id(&self) -> TableId {
68        self.plans[0].subscribed_table_id()
69    }
70
71    /// A subscription query return rows from a single table.
72    /// This method returns the name of that table.
73    pub fn subscribed_table_name(&self) -> &str {
74        self.plans[0].subscribed_table_name()
75    }
76
77    /// Returns the index ids from which this subscription reads
78    pub fn index_ids(&self) -> impl Iterator<Item = (TableId, IndexId)> {
79        self.plans
80            .iter()
81            .flat_map(|plan| plan.index_ids())
82            .collect::<HashSet<_>>()
83            .into_iter()
84    }
85
86    /// Returns the table ids from which this subscription reads
87    pub fn table_ids(&self) -> impl Iterator<Item = TableId> + '_ {
88        self.plans
89            .iter()
90            .flat_map(|plan| plan.table_ids())
91            .collect::<HashSet<_>>()
92            .into_iter()
93    }
94
95    /// Returns the plan fragments that comprise this subscription.
96    /// Will only return one element unless there is a table with multiple RLS rules.
97    pub fn plans_fragments(&self) -> impl Iterator<Item = &SubscriptionPlan> + '_ {
98        self.plans.iter()
99    }
100
101    /// Returns the join edges for this plan, if any.
102    pub fn join_edges(&self) -> impl Iterator<Item = (JoinEdge, AlgebraicValue)> + '_ {
103        self.plans.iter().filter_map(|plan| plan.join_edge())
104    }
105
106    /// The `SQL` text of this subscription.
107    pub fn sql(&self) -> &str {
108        &self.sql
109    }
110}
111
112/// For each client, we hold a handle for sending messages, and we track the queries they are subscribed to.
113#[derive(Debug)]
114struct ClientInfo {
115    outbound_ref: Client,
116    subscriptions: HashMap<SubscriptionId, HashSet<QueryHash>>,
117    subscription_ref_count: HashMap<QueryHash, usize>,
118    // This should be removed when we migrate to SubscribeSingle.
119    legacy_subscriptions: HashSet<QueryHash>,
120    /// This flag is set if an error occurs during a tx update.
121    /// It will be cleaned up async or on resubscribe.
122    ///
123    /// [`Arc`]ed so that this can be updated by the [`SendWorker`]
124    /// and observed by [`SubscriptionManager::remove_dropped_clients`].
125    dropped: Arc<AtomicBool>,
126}
127
128impl ClientInfo {
129    fn new(outbound_ref: Client) -> Self {
130        Self {
131            outbound_ref,
132            subscriptions: HashMap::default(),
133            subscription_ref_count: HashMap::default(),
134            legacy_subscriptions: HashSet::default(),
135            dropped: Arc::new(AtomicBool::new(false)),
136        }
137    }
138
139    /// Check that the subscription ref count matches the actual number of subscriptions.
140    #[cfg(test)]
141    fn assert_ref_count_consistency(&self) {
142        let mut expected_ref_count = HashMap::new();
143        for query_hashes in self.subscriptions.values() {
144            for query_hash in query_hashes {
145                assert!(
146                    self.subscription_ref_count.contains_key(query_hash),
147                    "Query hash not found: {:?}",
148                    query_hash
149                );
150                expected_ref_count
151                    .entry(*query_hash)
152                    .and_modify(|count| *count += 1)
153                    .or_insert(1);
154            }
155        }
156        assert_eq!(
157            self.subscription_ref_count, expected_ref_count,
158            "Checking the reference totals failed"
159        );
160    }
161}
162
163/// For each query that has subscribers, we track a set of legacy subscribers and individual subscriptions.
164#[derive(Debug)]
165struct QueryState {
166    query: Query,
167    // For legacy clients that subscribe to a set of queries, we track them here.
168    legacy_subscribers: HashSet<ClientId>,
169    // For clients that subscribe to a single query, we track them here.
170    subscriptions: HashSet<ClientId>,
171}
172
173impl QueryState {
174    fn new(query: Query) -> Self {
175        Self {
176            query,
177            legacy_subscribers: HashSet::default(),
178            subscriptions: HashSet::default(),
179        }
180    }
181    fn has_subscribers(&self) -> bool {
182        !self.subscriptions.is_empty() || !self.legacy_subscribers.is_empty()
183    }
184
185    // This returns all of the clients listening to a query. If a client has multiple subscriptions for this query, it will appear twice.
186    fn all_clients(&self) -> impl Iterator<Item = &ClientId> {
187        itertools::chain(&self.legacy_subscribers, &self.subscriptions)
188    }
189
190    /// Return the [`Query`] for this [`QueryState`]
191    pub fn query(&self) -> &Query {
192        &self.query
193    }
194
195    /// Return the search arguments for this query
196    fn search_args(&self) -> impl Iterator<Item = (TableId, ColId, AlgebraicValue)> {
197        let mut args = HashSet::new();
198        for arg in self
199            .query
200            .plans
201            .iter()
202            .flat_map(|subscription| subscription.optimized_physical_plan().search_args())
203        {
204            args.insert(arg);
205        }
206        args.into_iter()
207    }
208}
209
210/// In this container, we keep track of parameterized subscription queries.
211/// This is used to prune unnecessary queries during subscription evaluation.
212///
213/// TODO: This container is populated on initial subscription.
214/// Ideally this information would be stored in the datastore,
215/// but because subscriptions are evaluated using a read only tx,
216/// we have to manage this memory separately.
217///
218/// If we stored this information in the datastore,
219/// we could encode pruning logic in the execution plan itself.
220#[derive(Debug, Default)]
221pub struct SearchArguments {
222    /// We parameterize subscriptions if they have an equality selection.
223    /// In this case a parameter is a [TableId], [ColId] pair.
224    ///
225    /// Ex.
226    ///
227    /// ```sql
228    /// SELECT * FROM t WHERE id = <value>
229    /// ```
230    ///
231    /// This query is parameterized by `t.id`.
232    ///
233    /// Ex.
234    ///
235    /// ```sql
236    /// SELECT t.* FROM t JOIN s ON t.id = s.id WHERE s.x = <value>
237    /// ```
238    ///
239    /// This query is parameterized by `s.x`.
240    params: BTreeSet<(TableId, ColId)>,
241    /// For each parameter we keep track of its possible values or arguments.
242    /// These arguments are the different values that clients subscribe with.
243    ///
244    /// Ex.
245    ///
246    /// ```sql
247    /// SELECT * FROM t WHERE id = 3
248    /// SELECT * FROM t WHERE id = 5
249    /// ```
250    ///
251    /// These queries will get parameterized by `t.id`,
252    /// and we will record the args `3` and `5` in this map.
253    args: BTreeSet<(TableId, ColId, AlgebraicValue, QueryHash)>,
254}
255
256impl SearchArguments {
257    /// Return the column ids by which a table is parameterized
258    fn search_params_for_table(&self, table_id: TableId) -> impl Iterator<Item = ColId> + '_ {
259        let lower_bound = (table_id, 0.into());
260        let upper_bound = (table_id, u16::MAX.into());
261        self.params
262            .range(lower_bound..=upper_bound)
263            .map(|(_, col_id)| col_id)
264            .cloned()
265    }
266
267    /// Are there queries parameterized by this table and column?
268    /// If so, do we have a subscriber for this `search_arg`?
269    fn queries_for_search_arg(
270        &self,
271        table_id: TableId,
272        col_id: ColId,
273        search_arg: AlgebraicValue,
274    ) -> impl Iterator<Item = &QueryHash> {
275        let lower_bound = (table_id, col_id, search_arg.clone(), QueryHash::MIN);
276        let upper_bound = (table_id, col_id, search_arg, QueryHash::MAX);
277        self.args.range(lower_bound..upper_bound).map(|(_, _, _, hash)| hash)
278    }
279
280    /// Find the queries that need to be evaluated for this row.
281    fn queries_for_row<'a>(&'a self, table_id: TableId, row: &'a ProductValue) -> impl Iterator<Item = &'a QueryHash> {
282        self.search_params_for_table(table_id)
283            .filter_map(|col_id| row.get_field(col_id.idx(), None).ok().map(|arg| (col_id, arg.clone())))
284            .flat_map(move |(col_id, arg)| self.queries_for_search_arg(table_id, col_id, arg))
285    }
286
287    /// Remove a query hash and its associated data from this container.
288    /// Note, a query hash may be associated with multiple column ids.
289    fn remove_query(&mut self, query: &QueryHash) {
290        // Collect the column parameters for this query
291        let mut params = self
292            .args
293            .iter()
294            .filter(|(_, _, _, hash)| hash == query)
295            .map(|(table_id, col_id, _, _)| (*table_id, *col_id))
296            .dedup()
297            .collect::<HashSet<_>>();
298
299        // Remove the search argument entries for this query
300        self.args.retain(|(_, _, _, hash)| hash != query);
301
302        // Remove column parameters that no longer have any search arguments associated to them
303        params.retain(|(table_id, col_id)| {
304            self.args
305                .range(
306                    (*table_id, *col_id, AlgebraicValue::Min, QueryHash::MIN)
307                        ..=(*table_id, *col_id, AlgebraicValue::Max, QueryHash::MAX),
308                )
309                .next()
310                .is_none()
311        });
312
313        self.params
314            .retain(|(table_id, col_id)| !params.contains(&(*table_id, *col_id)));
315    }
316
317    /// Add a new mapping from search argument to query hash
318    fn insert_query(&mut self, table_id: TableId, col_id: ColId, arg: AlgebraicValue, query: QueryHash) {
319        self.args.insert((table_id, col_id, arg, query));
320        self.params.insert((table_id, col_id));
321    }
322}
323
324/// Keeps track of the indexes that are used in subscriptions.
325#[derive(Debug, Default)]
326pub struct QueriedTableIndexIds {
327    ids: HashMap<TableId, HashMap<IndexId, usize>>,
328}
329
330impl FromIterator<(TableId, IndexId)> for QueriedTableIndexIds {
331    fn from_iter<T: IntoIterator<Item = (TableId, IndexId)>>(iter: T) -> Self {
332        let mut index_ids = Self::default();
333        for (table_id, index_id) in iter {
334            index_ids.insert_index_id(table_id, index_id);
335        }
336        index_ids
337    }
338}
339
340impl QueriedTableIndexIds {
341    /// Returns the index ids that are used in subscriptions for this table.
342    /// Note, it does not return all of the index ids that are defined on this table.
343    /// Only those that are used by at least one subscription query.
344    pub fn index_ids_for_table(&self, table_id: TableId) -> impl Iterator<Item = IndexId> + '_ {
345        self.ids
346            .get(&table_id)
347            .into_iter()
348            .flat_map(|index_ids| index_ids.keys())
349            .copied()
350    }
351
352    /// Insert a new `table_id` `index_id` pair into this container.
353    /// Note, different queries may read from the same index.
354    /// Hence we may already be tracking this index, in which case we bump its ref count.
355    pub fn insert_index_id(&mut self, table_id: TableId, index_id: IndexId) {
356        *self.ids.entry(table_id).or_default().entry(index_id).or_default() += 1;
357    }
358
359    /// Remove a `table_id` `index_id` pair from this container.
360    /// Note, different queries may read from the same index.
361    /// Hence we only remove this key from the map if its ref count goes to zero.
362    pub fn delete_index_id(&mut self, table_id: TableId, index_id: IndexId) {
363        if let Some(ids) = self.ids.get_mut(&table_id) {
364            if let Some(n) = ids.get_mut(&index_id) {
365                *n -= 1;
366
367                if *n == 0 {
368                    ids.remove(&index_id);
369
370                    if ids.is_empty() {
371                        self.ids.remove(&table_id);
372                    }
373                }
374            }
375        }
376    }
377
378    /// Insert the index ids from which a query reads into this mapping.
379    /// Note, an index may already be tracked if another query is already using it.
380    /// In this case we just bump its ref count.
381    pub fn insert_index_ids_for_query(&mut self, query: &Query) {
382        for (table_id, index_id) in query.index_ids() {
383            self.insert_index_id(table_id, index_id);
384        }
385    }
386
387    /// Delete the index ids from which a query reads from this mapping
388    /// Note, we will not remove an index id from this mapping if another query is using it.
389    /// Instead we decrement its ref count.
390    pub fn delete_index_ids_for_query(&mut self, query: &Query) {
391        for (table_id, index_id) in query.index_ids() {
392            self.delete_index_id(table_id, index_id);
393        }
394    }
395}
396
397/// A sorted set of join edges used for pruning queries.
398/// See [`JoinEdge`] for more details.
399#[derive(Debug, Default)]
400pub struct JoinEdges {
401    edges: BTreeMap<JoinEdge, HashMap<AlgebraicValue, QueryHash>>,
402}
403
404impl JoinEdges {
405    /// If this query has any join edges, add them to the map.
406    fn add_query(&mut self, qs: &QueryState) -> bool {
407        let mut inserted = false;
408        for (edge, rhs_val) in qs.query.join_edges() {
409            inserted = true;
410            self.edges.entry(edge).or_default().insert(rhs_val, qs.query.hash);
411        }
412        inserted
413    }
414
415    /// If this query has any join edges, remove them from the map.
416    fn remove_query(&mut self, query: &Query) {
417        for (edge, rhs_val) in query.join_edges() {
418            if let Some(hashes) = self.edges.get_mut(&edge) {
419                hashes.remove(&rhs_val);
420                if hashes.is_empty() {
421                    self.edges.remove(&edge);
422                }
423            }
424        }
425    }
426
427    /// Searches for queries that must be evaluated for this row,
428    /// and effectively prunes queries that do not.
429    fn queries_for_row<'a>(
430        &'a self,
431        table_id: TableId,
432        row: &'a ProductValue,
433        find_rhs_val: impl Fn(&JoinEdge, &ProductValue) -> Option<AlgebraicValue>,
434    ) -> impl Iterator<Item = &'a QueryHash> {
435        self.edges
436            .range(JoinEdge::range_for_table(table_id))
437            .filter_map(move |(edge, hashes)| find_rhs_val(edge, row).as_ref().and_then(|rhs_val| hashes.get(rhs_val)))
438    }
439}
440
/// Responsible for the efficient evaluation of subscriptions.
/// It performs basic multi-query optimization,
/// in that if a query has N subscribers,
/// it is only executed once,
/// with the results copied to the N receivers.
#[derive(Debug)]
pub struct SubscriptionManager {
    /// State for each client, keyed by its `(Identity, ConnectionId)` pair.
    clients: HashMap<ClientId, ClientInfo>,

    /// Queries for which there is at least one subscriber, keyed by query hash.
    queries: HashMap<QueryHash, QueryState>,

    /// If a query reads from a table,
    /// but does not have a simple equality filter on that table,
    /// we map the table to the query in this inverted index.
    tables: IntMap<TableId, HashSet<QueryHash>>,

    /// Tracks the indices used across all subscriptions
    /// to enable building the appropriate indexes for row updates.
    indexes: QueriedTableIndexIds,

    /// If a query reads from a table,
    /// and has a simple equality filter on that table,
    /// we map the filter values to the query in this lookup table.
    search_args: SearchArguments,

    /// A sorted set of join edges used for pruning queries.
    /// See [`JoinEdge`] for more details.
    join_edges: JoinEdges,

    /// Transmit side of a channel to the manager's [`SendWorker`] task.
    ///
    /// The send worker runs in parallel and pops [`ComputedQueries`]es out in order,
    /// aggregates each client's full set of updates,
    /// then passes them to the clients' websocket workers.
    /// This allows transaction processing to proceed on the main thread
    /// ahead of post-processing and broadcasting updates
    /// while still ensuring that those updates are sent in the correct serial order.
    /// Additionally, it avoids starving the next reducer request of Tokio workers,
    /// as it imposes a delay between unlocking the datastore
    /// and waking the many per-client sender Tokio tasks.
    send_worker_queue: BroadcastQueue,
}
485
/// A single update for one client and one query.
#[derive(Debug)]
struct ClientUpdate {
    /// The client this update is destined for.
    id: ClientId,
    // NOTE(review): presumably the id and name of the table whose rows `update` carries - confirm.
    table_id: TableId,
    table_name: TableName,
    /// The update for a single query, in one of the two websocket formats named in its type.
    update: FormatSwitch<SingleQueryUpdate<BsatnFormat>, SingleQueryUpdate<JsonFormat>>,
}
494
/// The computed incremental update queries with sufficient information
/// to not depend on the transaction lock so that further work can be
/// done in a separate worker: [`SubscriptionManager::send_worker`].
/// The queries in this structure have not been aggregated yet
/// but will be in the worker.
#[derive(Debug)]
struct ComputedQueries {
    /// The per-client, per-query updates, not yet aggregated.
    updates: Vec<ClientUpdate>,
    // NOTE(review): presumably an error message for each client whose evaluation failed - confirm.
    errs: Vec<(ClientId, Box<str>)>,
    // NOTE(review): presumably the module event that produced these updates - confirm.
    event: Arc<ModuleEvent>,
    // NOTE(review): presumably the client that triggered the event, when there is one - confirm.
    caller: Option<Arc<ClientConnectionSender>>,
}
507
508// Wraps a sender so that it will increment a gauge.
509#[derive(Debug)]
510struct SenderWithGauge<T> {
511    tx: mpsc::UnboundedSender<T>,
512    metric: Option<IntGauge>,
513}
514impl<T> Clone for SenderWithGauge<T> {
515    fn clone(&self) -> Self {
516        SenderWithGauge {
517            tx: self.tx.clone(),
518            metric: self.metric.clone(),
519        }
520    }
521}
522
523impl<T> SenderWithGauge<T> {
524    fn new(tx: mpsc::UnboundedSender<T>, metric: Option<IntGauge>) -> Self {
525        Self { tx, metric }
526    }
527
528    /// Send a message to the worker and update the queue length metric.
529    pub fn send(&self, msg: T) -> Result<(), mpsc::error::SendError<T>> {
530        if let Some(metric) = &self.metric {
531            metric.inc();
532        }
533        // Note, this could number would be permanently off if the send call panics.
534        self.tx.send(msg)
535    }
536}
537
/// Message sent by the [`SubscriptionManager`] to the [`SendWorker`].
#[derive(Debug)]
enum SendWorkerMessage {
    /// A transaction has completed and the [`SubscriptionManager`] has evaluated the incremental queries,
    /// so the [`SendWorker`] should broadcast them to clients.
    Broadcast(ComputedQueries),

    /// A new client has been registered in the [`SubscriptionManager`],
    /// so the [`SendWorker`] should also record its existence.
    AddClient {
        client_id: ClientId,
        /// Shared handle on the `dropped` flag in the [`SubscriptionManager`]'s [`ClientInfo`].
        ///
        /// Will be updated by [`SendWorker::run`] and read by [`SubscriptionManager::remove_dropped_clients`].
        dropped: Arc<AtomicBool>,
        outbound_ref: Client,
    },

    /// Send a single message to one client.
    SendMessage {
        recipient: Arc<ClientConnectionSender>,
        message: SerializableMessage,
    },

    /// A client previously added by a [`Self::AddClient`] message has been removed from the [`SubscriptionManager`],
    /// so the [`SendWorker`] should also forget it.
    RemoveClient(ClientId),
}
566
/// A snapshot of the subscription-related gauge values,
/// as computed by `SubscriptionManager::calculate_gauge_stats`.
#[derive(Debug)]
pub struct SubscriptionGaugeStats {
    /// The number of unique queries with at least one subscriber.
    pub num_queries: usize,
    /// The number of client connections tracked by the subscription manager.
    pub num_connections: usize,
    /// The number of subscription sets across all clients.
    pub num_subscription_sets: usize,
    /// The total number of non-legacy subscriptions across all clients and queries.
    pub num_query_subscriptions: usize,
    /// The number of clients that have at least one legacy subscription.
    pub num_legacy_subscriptions: usize,
}
580
581impl SubscriptionManager {
582    pub fn for_test_without_metrics_arc_rwlock() -> Arc<RwLock<Self>> {
583        Arc::new(RwLock::new(Self::for_test_without_metrics()))
584    }
585
586    pub fn for_test_without_metrics() -> Self {
587        Self::new(SendWorker::spawn_new(None))
588    }
589
590    pub fn new(send_worker_queue: BroadcastQueue) -> Self {
591        Self {
592            clients: Default::default(),
593            queries: Default::default(),
594            indexes: Default::default(),
595            tables: Default::default(),
596            search_args: Default::default(),
597            join_edges: Default::default(),
598            send_worker_queue,
599        }
600    }
601
602    pub fn query(&self, hash: &QueryHash) -> Option<Query> {
603        self.queries.get(hash).map(|state| state.query.clone())
604    }
605
606    pub fn calculate_gauge_stats(&self) -> SubscriptionGaugeStats {
607        let num_queries = self.queries.len();
608        let num_connections = self.clients.len();
609        let num_query_subscriptions = self.queries.values().map(|state| state.subscriptions.len()).sum();
610        let num_subscription_sets = self.clients.values().map(|ci| ci.subscriptions.len()).sum();
611        let num_legacy_subscriptions = self
612            .clients
613            .values()
614            .filter(|ci| !ci.legacy_subscriptions.is_empty())
615            .count();
616
617        SubscriptionGaugeStats {
618            num_queries,
619            num_connections,
620            num_query_subscriptions,
621            num_subscription_sets,
622            num_legacy_subscriptions,
623        }
624    }
625
626    /// Add a new [`ClientInfo`] to the `clients` map, and broadcast a message along `send_worker_tx`
627    /// that the [`SendWorker`] should also add this client.
628    ///
629    /// Horrible signature to enable split borrows on [`Self`].
630    fn get_or_make_client_info_and_inform_send_worker<'clients>(
631        clients: &'clients mut HashMap<ClientId, ClientInfo>,
632        send_worker_tx: &BroadcastQueue,
633        client_id: ClientId,
634        outbound_ref: Client,
635    ) -> &'clients mut ClientInfo {
636        clients.entry(client_id).or_insert_with(|| {
637            let info = ClientInfo::new(outbound_ref.clone());
638            send_worker_tx
639                .send(SendWorkerMessage::AddClient {
640                    client_id,
641                    dropped: info.dropped.clone(),
642                    outbound_ref,
643                })
644                .expect("send worker has panicked, or otherwise dropped its recv queue!");
645            info
646        })
647    }
648
649    /// Remove a [`ClientInfo`] from the `clients` map,
650    /// and broadcast a message along `send_worker_tx` that the [`SendWorker`] should also remove it.
651    fn remove_client_and_inform_send_worker(&mut self, client_id: ClientId) -> Option<ClientInfo> {
652        self.clients.remove(&client_id).inspect(|_| {
653            self.send_worker_queue
654                .send(SendWorkerMessage::RemoveClient(client_id))
655                .expect("send worker has panicked, or otherwise dropped its recv queue!");
656        })
657    }
658
    /// The number of distinct queries currently tracked by this manager.
    pub fn num_unique_queries(&self) -> usize {
        self.queries.len()
    }
662
    /// Test helper: is a query with this hash currently tracked?
    #[cfg(test)]
    fn contains_query(&self, hash: &QueryHash) -> bool {
        self.queries.contains_key(hash)
    }
667
    /// Test helper: is this client currently tracked?
    #[cfg(test)]
    fn contains_client(&self, subscriber: &ClientId) -> bool {
        self.clients.contains_key(subscriber)
    }
672
673    #[cfg(test)]
674    fn contains_legacy_subscription(&self, subscriber: &ClientId, query: &QueryHash) -> bool {
675        self.queries
676            .get(query)
677            .is_some_and(|state| state.legacy_subscribers.contains(subscriber))
678    }
679
680    #[cfg(test)]
681    fn query_reads_from_table(&self, query: &QueryHash, table: &TableId) -> bool {
682        self.tables.get(table).is_some_and(|queries| queries.contains(query))
683    }
684
685    #[cfg(test)]
686    fn query_has_search_arg(&self, query: QueryHash, table_id: TableId, col_id: ColId, arg: AlgebraicValue) -> bool {
687        self.search_args
688            .queries_for_search_arg(table_id, col_id, arg)
689            .any(|hash| *hash == query)
690    }
691
692    #[cfg(test)]
693    fn table_has_search_param(&self, table_id: TableId, col_id: ColId) -> bool {
694        self.search_args
695            .search_params_for_table(table_id)
696            .any(|id| id == col_id)
697    }
698
699    fn remove_legacy_subscriptions(&mut self, client: &ClientId) {
700        if let Some(ci) = self.clients.get_mut(client) {
701            let mut queries_to_remove = Vec::new();
702            for query_hash in ci.legacy_subscriptions.iter() {
703                let Some(query_state) = self.queries.get_mut(query_hash) else {
704                    tracing::warn!("Query state not found for query hash: {:?}", query_hash);
705                    continue;
706                };
707
708                query_state.legacy_subscribers.remove(client);
709                if !query_state.has_subscribers() {
710                    SubscriptionManager::remove_query_from_tables(
711                        &mut self.tables,
712                        &mut self.join_edges,
713                        &mut self.indexes,
714                        &mut self.search_args,
715                        &query_state.query,
716                    );
717                    queries_to_remove.push(*query_hash);
718                }
719            }
720            ci.legacy_subscriptions.clear();
721            for query_hash in queries_to_remove {
722                self.queries.remove(&query_hash);
723            }
724        }
725    }
726
727    /// Remove any clients that have been marked for removal
728    pub fn remove_dropped_clients(&mut self) {
729        for id in self.clients.keys().copied().collect::<Vec<_>>() {
730            if let Some(client) = self.clients.get(&id) {
731                if client.dropped.load(Ordering::Relaxed) {
732                    self.remove_all_subscriptions(&id);
733                }
734            }
735        }
736    }
737
738    /// Remove a single subscription for a client.
739    /// This will return an error if the client does not have a subscription with the given query id.
740    pub fn remove_subscription(&mut self, client_id: ClientId, query_id: ClientQueryId) -> Result<Vec<Query>, DBError> {
741        let subscription_id = (client_id, query_id);
742        let Some(ci) = self
743            .clients
744            .get_mut(&client_id)
745            .filter(|ci| !ci.dropped.load(Ordering::Acquire))
746        else {
747            return Err(anyhow::anyhow!("Client not found: {:?}", client_id).into());
748        };
749
750        #[cfg(test)]
751        ci.assert_ref_count_consistency();
752
753        let Some(query_hashes) = ci.subscriptions.remove(&subscription_id) else {
754            return Err(anyhow::anyhow!("Subscription not found: {:?}", subscription_id).into());
755        };
756        let mut queries_to_return = Vec::new();
757        for hash in query_hashes {
758            let remaining_refs = {
759                let Some(count) = ci.subscription_ref_count.get_mut(&hash) else {
760                    return Err(anyhow::anyhow!("Query count not found for query hash: {:?}", hash).into());
761                };
762                *count -= 1;
763                *count
764            };
765            if remaining_refs > 0 {
766                // The client is still subscribed to this query, so we are done for now.
767                continue;
768            }
769            // The client is no longer subscribed to this query.
770            ci.subscription_ref_count.remove(&hash);
771            let Some(query_state) = self.queries.get_mut(&hash) else {
772                return Err(anyhow::anyhow!("Query state not found for query hash: {:?}", hash).into());
773            };
774            queries_to_return.push(query_state.query.clone());
775            query_state.subscriptions.remove(&client_id);
776            if !query_state.has_subscribers() {
777                SubscriptionManager::remove_query_from_tables(
778                    &mut self.tables,
779                    &mut self.join_edges,
780                    &mut self.indexes,
781                    &mut self.search_args,
782                    &query_state.query,
783                );
784                self.queries.remove(&hash);
785            }
786        }
787
788        #[cfg(test)]
789        ci.assert_ref_count_consistency();
790
791        Ok(queries_to_return)
792    }
793
794    /// Adds a single subscription for a client.
795    pub fn add_subscription(&mut self, client: Client, query: Query, query_id: ClientQueryId) -> Result<(), DBError> {
796        self.add_subscription_multi(client, vec![query], query_id).map(|_| ())
797    }
798
799    pub fn add_subscription_multi(
800        &mut self,
801        client: Client,
802        queries: Vec<Query>,
803        query_id: ClientQueryId,
804    ) -> Result<Vec<Query>, DBError> {
805        let client_id = (client.id.identity, client.id.connection_id);
806
807        // Clean up any dropped subscriptions
808        if self
809            .clients
810            .get(&client_id)
811            .is_some_and(|ci| ci.dropped.load(Ordering::Acquire))
812        {
813            self.remove_all_subscriptions(&client_id);
814        }
815
816        let ci = Self::get_or_make_client_info_and_inform_send_worker(
817            &mut self.clients,
818            &self.send_worker_queue,
819            client_id,
820            client,
821        );
822
823        #[cfg(test)]
824        ci.assert_ref_count_consistency();
825        let subscription_id = (client_id, query_id);
826        let hash_set = match ci.subscriptions.try_insert(subscription_id, HashSet::new()) {
827            Err(OccupiedError { .. }) => {
828                return Err(anyhow::anyhow!(
829                    "Subscription with id {:?} already exists for client: {:?}",
830                    query_id,
831                    client_id
832                )
833                .into());
834            }
835            Ok(hash_set) => hash_set,
836        };
837        // We track the queries that are being added for this client.
838        let mut new_queries = Vec::new();
839
840        for query in &queries {
841            let hash = query.hash();
842            // Deduping queries within this single call.
843            if !hash_set.insert(hash) {
844                continue;
845            }
846            let query_state = self
847                .queries
848                .entry(hash)
849                .or_insert_with(|| QueryState::new(query.clone()));
850
851            Self::insert_query(
852                &mut self.tables,
853                &mut self.join_edges,
854                &mut self.indexes,
855                &mut self.search_args,
856                query_state,
857            );
858
859            let entry = ci.subscription_ref_count.entry(hash).or_insert(0);
860            *entry += 1;
861            let is_new_entry = *entry == 1;
862
863            let inserted = query_state.subscriptions.insert(client_id);
864            // This should arguably crash the server, as it indicates a bug.
865            if inserted != is_new_entry {
866                return Err(anyhow::anyhow!("Internal error, ref count and query_state mismatch").into());
867            }
868            if inserted {
869                new_queries.push(query.clone());
870            }
871        }
872
873        #[cfg(test)]
874        {
875            ci.assert_ref_count_consistency();
876        }
877
878        Ok(new_queries)
879    }
880
881    /// Adds a client and its queries to the subscription manager.
882    /// Sets up the set of subscriptions for the client, replacing any existing legacy subscriptions.
883    ///
884    /// If a query is not already indexed,
885    /// its table ids added to the inverted index.
886    // #[tracing::instrument(level = "trace", skip_all)]
887    pub fn set_legacy_subscription(&mut self, client: Client, queries: impl IntoIterator<Item = Query>) {
888        let client_id = (client.id.identity, client.id.connection_id);
889        // First, remove any existing legacy subscriptions.
890        self.remove_legacy_subscriptions(&client_id);
891
892        // Now, add the new subscriptions.
893        let ci = Self::get_or_make_client_info_and_inform_send_worker(
894            &mut self.clients,
895            &self.send_worker_queue,
896            client_id,
897            client,
898        );
899
900        for unit in queries {
901            let hash = unit.hash();
902            ci.legacy_subscriptions.insert(hash);
903            let query_state = self
904                .queries
905                .entry(hash)
906                .or_insert_with(|| QueryState::new(unit.clone()));
907            Self::insert_query(
908                &mut self.tables,
909                &mut self.join_edges,
910                &mut self.indexes,
911                &mut self.search_args,
912                query_state,
913            );
914            query_state.legacy_subscribers.insert(client_id);
915        }
916    }
917
918    // Update the mapping from table id to related queries by removing the given query.
919    // If this removes all queries for a table, the map entry for that table is removed altogether.
920    // This takes a ref to the table map instead of `self` to avoid borrowing issues.
921    fn remove_query_from_tables(
922        tables: &mut IntMap<TableId, HashSet<QueryHash>>,
923        join_edges: &mut JoinEdges,
924        index_ids: &mut QueriedTableIndexIds,
925        search_args: &mut SearchArguments,
926        query: &Query,
927    ) {
928        let hash = query.hash();
929        join_edges.remove_query(query);
930        search_args.remove_query(&hash);
931        index_ids.delete_index_ids_for_query(query);
932        for table_id in query.table_ids() {
933            if let Entry::Occupied(mut entry) = tables.entry(table_id) {
934                let hashes = entry.get_mut();
935                if hashes.remove(&hash) && hashes.is_empty() {
936                    entry.remove();
937                }
938            }
939        }
940    }
941
942    // Update the mapping from table id to related queries by inserting the given query.
943    // Also add any search arguments the query may have.
944    // This takes a ref to the table map instead of `self` to avoid borrowing issues.
945    fn insert_query(
946        tables: &mut IntMap<TableId, HashSet<QueryHash>>,
947        join_edges: &mut JoinEdges,
948        index_ids: &mut QueriedTableIndexIds,
949        search_args: &mut SearchArguments,
950        query_state: &QueryState,
951    ) {
952        // If this is new, we need to update the table to query mapping.
953        if !query_state.has_subscribers() {
954            let hash = query_state.query.hash;
955            let query = query_state.query();
956            let return_table = query.subscribed_table_id();
957            let mut table_ids = query.table_ids().collect::<HashSet<_>>();
958
959            // Update the index id mapping
960            index_ids.insert_index_ids_for_query(query);
961
962            // Update the search arguments
963            for (table_id, col_id, arg) in query_state.search_args() {
964                table_ids.remove(&table_id);
965                search_args.insert_query(table_id, col_id, arg, hash);
966            }
967
968            // Update the join edges if the return table didn't have any search arguments
969            if table_ids.contains(&return_table) && join_edges.add_query(query_state) {
970                table_ids.remove(&return_table);
971            }
972
973            // Finally update the `tables` map if the query didn't have a search argument or a join edge for a table
974            for table_id in table_ids {
975                tables.entry(table_id).or_default().insert(hash);
976            }
977        }
978    }
979
980    /// Removes a client from the subscriber mapping.
981    /// If a query no longer has any subscribers,
982    /// it is removed from the index along with its table ids.
983    #[tracing::instrument(level = "trace", skip_all)]
984    pub fn remove_all_subscriptions(&mut self, client: &ClientId) {
985        self.remove_legacy_subscriptions(client);
986        let Some(client_info) = self.remove_client_and_inform_send_worker(*client) else {
987            return;
988        };
989
990        debug_assert!(client_info.legacy_subscriptions.is_empty());
991        let mut queries_to_remove = Vec::new();
992        for query_hash in client_info.subscription_ref_count.keys() {
993            let Some(query_state) = self.queries.get_mut(query_hash) else {
994                tracing::warn!("Query state not found for query hash: {:?}", query_hash);
995                return;
996            };
997            query_state.subscriptions.remove(client);
998            // This could happen twice for the same hash if a client has a duplicate, but that's fine. It is idepotent.
999            if !query_state.has_subscribers() {
1000                queries_to_remove.push(*query_hash);
1001                SubscriptionManager::remove_query_from_tables(
1002                    &mut self.tables,
1003                    &mut self.join_edges,
1004                    &mut self.indexes,
1005                    &mut self.search_args,
1006                    &query_state.query,
1007                );
1008            }
1009        }
1010        for query_hash in queries_to_remove {
1011            self.queries.remove(&query_hash);
1012        }
1013    }
1014
1015    /// Find the queries that need to be evaluated for this table update.
1016    ///
1017    /// Note, this tries to prune irrelevant queries from the subscription.
1018    ///
1019    /// When is this beneficial?
1020    ///
1021    /// If many different clients subscribe to the same parameterized query,
1022    /// but they all subscribe with different parameter values,
1023    /// and if these rows contain only a few unique values for this parameter,
1024    /// most clients will not receive an update,
1025    /// and so we can avoid evaluating queries for them entirely.
1026    ///
1027    /// Ex.
1028    ///
1029    /// 1000 clients subscribe to `SELECT * FROM t WHERE id = ?`,
1030    /// each one with a different value for `?`.
1031    /// If there are transactions that only ever update one row of `t` at a time,
1032    /// we only pay the cost of evaluating one query.
1033    ///
1034    /// When is this not beneficial?
1035    ///
1036    /// If the table update contains a lot of unique values for a parameter,
1037    /// we won't be able to prune very many queries from the subscription,
1038    /// so this could add some overhead linear in the size of the table update.
1039    ///
1040    /// TODO: This logic should be expressed in the execution plan itself,
1041    /// so that we don't have to preprocess the table update before execution.
1042    fn queries_for_table_update<'a>(
1043        &'a self,
1044        table_update: &'a DatabaseTableUpdate,
1045        find_rhs_val: &impl Fn(&JoinEdge, &ProductValue) -> Option<AlgebraicValue>,
1046    ) -> impl Iterator<Item = &'a QueryHash> {
1047        let mut queries = HashSet::new();
1048        for hash in table_update
1049            .inserts
1050            .iter()
1051            .chain(table_update.deletes.iter())
1052            .flat_map(|row| self.queries_for_row(table_update.table_id, row, find_rhs_val))
1053        {
1054            queries.insert(hash);
1055        }
1056        for hash in self.tables.get(&table_update.table_id).into_iter().flatten() {
1057            queries.insert(hash);
1058        }
1059        queries.into_iter()
1060    }
1061
1062    /// Find the queries that need to be evaluated for this row.
1063    fn queries_for_row<'a>(
1064        &'a self,
1065        table_id: TableId,
1066        row: &'a ProductValue,
1067        find_rhs_val: impl Fn(&JoinEdge, &ProductValue) -> Option<AlgebraicValue>,
1068    ) -> impl Iterator<Item = &'a QueryHash> {
1069        self.search_args
1070            .queries_for_row(table_id, row)
1071            .chain(self.join_edges.queries_for_row(table_id, row, find_rhs_val))
1072    }
1073
    /// Returns the index ids that are used in subscription queries.
    ///
    /// This is a borrowed view of the manager's `indexes` field; nothing is copied.
    pub fn index_ids_for_subscriptions(&self) -> &QueriedTableIndexIds {
        &self.indexes
    }
1078
1079    /// This method takes a set of delta tables,
1080    /// evaluates only the necessary queries for those delta tables,
1081    /// and then sends the results to each client.
1082    ///
1083    /// This previously used rayon to parallelize subscription evaluation.
1084    /// However, in order to optimize for the common case of small updates,
1085    /// we removed rayon and switched to a single-threaded execution,
1086    /// which removed significant overhead associated with thread switching.
1087    #[tracing::instrument(level = "trace", skip_all)]
1088    pub fn eval_updates_sequential(
1089        &self,
1090        tx: &DeltaTx,
1091        event: Arc<ModuleEvent>,
1092        caller: Option<Arc<ClientConnectionSender>>,
1093    ) -> ExecutionMetrics {
1094        use FormatSwitch::{Bsatn, Json};
1095
1096        let tables = &event.status.database_update().unwrap().tables;
1097
1098        let span = tracing::info_span!("eval_incr").entered();
1099
1100        #[derive(Default)]
1101        struct FoldState {
1102            updates: Vec<ClientUpdate>,
1103            errs: Vec<(ClientId, Box<str>)>,
1104            metrics: ExecutionMetrics,
1105        }
1106
1107        /// Returns the value pointed to by this join edge
1108        fn find_rhs_val(edge: &JoinEdge, row: &ProductValue, tx: &DeltaTx) -> Option<AlgebraicValue> {
1109            // What if the joining row was deleted in this tx?
1110            // Will we prune a query that we shouldn't have?
1111            //
1112            // Ultimately no we will not.
1113            // We may prune it for this row specifically,
1114            // but we will eventually include it for the joining row.
1115            tx.iter_by_col_eq(
1116                edge.rhs_table,
1117                edge.rhs_join_col,
1118                &row.elements[edge.lhs_join_col.idx()],
1119            )
1120            .expect("This read should always succeed, and it's a bug if it doesn't")
1121            .next()
1122            .map(|row| {
1123                row.read_col(edge.rhs_col)
1124                    .expect("This read should always succeed, and it's a bug if it doesn't")
1125            })
1126        }
1127
1128        let FoldState { updates, errs, metrics } = tables
1129            .iter()
1130            .filter(|table| !table.inserts.is_empty() || !table.deletes.is_empty())
1131            .flat_map(|table_update| {
1132                self.queries_for_table_update(table_update, &|edge, row| find_rhs_val(edge, row, tx))
1133            })
1134            // deduplicate queries by their hash
1135            .filter({
1136                let mut seen = HashSet::new();
1137                // (HashSet::insert returns true for novel elements)
1138                move |&hash| seen.insert(hash)
1139            })
1140            .flat_map(|hash| {
1141                let qstate = &self.queries[hash];
1142                qstate
1143                    .query
1144                    .plans_fragments()
1145                    .map(move |plan_fragment| (qstate, plan_fragment))
1146            })
1147            // If N clients are subscribed to a query,
1148            // we copy the DatabaseTableUpdate N times,
1149            // which involves cloning BSATN (binary) or product values (json).
1150            .fold(FoldState::default(), |mut acc, (qstate, plan)| {
1151                let table_id = plan.subscribed_table_id();
1152                let table_name = plan.subscribed_table_name().clone();
1153                // Store at most one copy for both the serialization to BSATN and JSON.
1154                // Each subscriber gets to pick which of these they want,
1155                // but we only fill `ops_bin_uncompressed` and `ops_json` at most once.
1156                // The former will be `Some(_)` if some subscriber uses `Protocol::Binary`
1157                // and the latter `Some(_)` if some subscriber uses `Protocol::Text`.
1158                //
1159                // Previously we were compressing each `QueryUpdate` within a `TransactionUpdate`.
1160                // The reason was simple - many clients can subscribe to the same query.
1161                // If we compress `TransactionUpdate`s independently for each client,
1162                // we could be doing a lot of redundant compression.
1163                //
1164                // However the risks associated with this approach include:
1165                //   1. We have to hold the tx lock when compressing
1166                //   2. A potentially worse compression ratio
1167                //   3. Extra decompression overhead on the client
1168                //
1169                // Because transaction processing is currently single-threaded,
1170                // the risks of holding the tx lock for longer than necessary,
1171                // as well as additional the message processing overhead on the client,
1172                // outweighed the benefit of reduced cpu with the former approach.
1173                let mut ops_bin_uncompressed: Option<(CompressableQueryUpdate<BsatnFormat>, _, _)> = None;
1174                let mut ops_json: Option<(QueryUpdate<JsonFormat>, _, _)> = None;
1175
1176                fn memo_encode<F: WebsocketFormat>(
1177                    updates: &UpdatesRelValue<'_>,
1178                    memory: &mut Option<(F::QueryUpdate, u64, usize)>,
1179                    metrics: &mut ExecutionMetrics,
1180                ) -> SingleQueryUpdate<F> {
1181                    let (update, num_rows, num_bytes) = memory
1182                        .get_or_insert_with(|| {
1183                            let encoded = updates.encode::<F>();
1184                            // The first time we insert into this map, we call encode.
1185                            // This is when we serialize the rows to BSATN/JSON.
1186                            // Hence this is where we increment `bytes_scanned`.
1187                            metrics.bytes_scanned += encoded.2;
1188                            encoded
1189                        })
1190                        .clone();
1191                    // We call this function for each query,
1192                    // and for each client subscribed to it.
1193                    // Therefore every time we call this function,
1194                    // we update the `bytes_sent_to_clients` metric.
1195                    metrics.bytes_sent_to_clients += num_bytes;
1196                    SingleQueryUpdate { update, num_rows }
1197                }
1198
1199                // filter out clients that've dropped
1200                let clients_for_query = qstate.all_clients().filter(|id| {
1201                    self.clients
1202                        .get(*id)
1203                        .is_some_and(|info| !info.dropped.load(Ordering::Acquire))
1204                });
1205
1206                match eval_delta(tx, &mut acc.metrics, plan) {
1207                    Err(err) => {
1208                        tracing::error!(
1209                            message = "Query errored during tx update",
1210                            sql = qstate.query.sql,
1211                            reason = ?err,
1212                        );
1213                        let err = DBError::WithSql {
1214                            sql: qstate.query.sql.as_str().into(),
1215                            error: Box::new(err.into()),
1216                        }
1217                        .to_string()
1218                        .into_boxed_str();
1219
1220                        acc.errs.extend(clients_for_query.map(|id| (*id, err.clone())))
1221                    }
1222                    // The query didn't return any rows to update
1223                    Ok(None) => {}
1224                    // The query did return updates - process them and add them to the accumulator
1225                    Ok(Some(delta_updates)) => {
1226                        let row_iter = clients_for_query.map(|id| {
1227                            let client = &self.clients[id].outbound_ref;
1228                            let update = match client.config.protocol {
1229                                Protocol::Binary => Bsatn(memo_encode::<BsatnFormat>(
1230                                    &delta_updates,
1231                                    &mut ops_bin_uncompressed,
1232                                    &mut acc.metrics,
1233                                )),
1234                                Protocol::Text => Json(memo_encode::<JsonFormat>(
1235                                    &delta_updates,
1236                                    &mut ops_json,
1237                                    &mut acc.metrics,
1238                                )),
1239                            };
1240                            ClientUpdate {
1241                                id: *id,
1242                                table_id,
1243                                table_name: table_name.clone(),
1244                                update,
1245                            }
1246                        });
1247                        acc.updates.extend(row_iter);
1248                    }
1249                }
1250
1251                acc
1252            });
1253
1254        // We've now finished all of the work which needs to read from the datastore,
1255        // so get this work off the main thread and over to the `send_worker`,
1256        // then return ASAP in order to unlock the datastore and start running the next transaction.
1257        // See comment on the `send_worker_tx` field in [`SubscriptionManager`] for more motivation.
1258        self.send_worker_queue
1259            .send(SendWorkerMessage::Broadcast(ComputedQueries {
1260                updates,
1261                errs,
1262                event,
1263                caller,
1264            }))
1265            .expect("send worker has panicked, or otherwise dropped its recv queue!");
1266
1267        drop(span);
1268
1269        metrics
1270    }
1271}
1272
/// Per-client state kept by the [`SendWorker`] in its local `clients` map.
struct SendWorkerClient {
    /// This flag is set if an error occurs during a tx update.
    /// It will be cleaned up async or on resubscribe.
    ///
    /// [`Arc`]ed so that this can be updated by the [`SendWorker`]
    /// (it is set in `SendWorker::send_one_computed_queries`, which is called from `SendWorker::run`)
    /// and observed by [`SubscriptionManager::remove_dropped_clients`].
    dropped: Arc<AtomicBool>,
    /// Handle through which messages are sent to this client.
    outbound_ref: Client,
}
1282
/// Asynchronous background worker which aggregates each of the clients' updates from a [`ComputedQueries`]
/// into `DbUpdate`s and then sends them to the clients' WebSocket workers.
///
/// It also handles the other [`SendWorkerMessage`]s: adding and removing clients from its local map,
/// and forwarding individual messages to a recipient.
///
/// See comment on the `send_worker_queue` field in [`SubscriptionManager`] for motivation.
struct SendWorker {
    /// Receiver end of the channel whose sending half is wrapped by the [`BroadcastQueue`].
    rx: mpsc::UnboundedReceiver<SendWorkerMessage>,

    /// `subscription_send_queue_length` metric labeled for this database's `Identity`.
    ///
    /// If `Some`, this metric will be decremented each time we pop a message from `rx`.
    ///
    /// Will be `None` in contexts where there is no database `Identity` to use as label,
    /// i.e. in tests.
    queue_length_metric: Option<IntGauge>,

    /// Mirror of the [`SubscriptionManager`]'s `clients` map local to this actor.
    ///
    /// Updated by [`SendWorkerMessage::AddClient`] and [`SendWorkerMessage::RemoveClient`] messages
    /// sent along `self.rx`.
    clients: HashMap<ClientId, SendWorkerClient>,

    /// The `Identity` which labels the `queue_length_metric`.
    ///
    /// If `Some`, this type's `drop` method will do `remove_label_values` to clean up the metric on exit.
    database_identity_to_clean_up_metric: Option<Identity>,
}
1310
1311impl Drop for SendWorker {
1312    fn drop(&mut self) {
1313        if let Some(identity) = self.database_identity_to_clean_up_metric {
1314            let _ = WORKER_METRICS
1315                .subscription_send_queue_length
1316                .remove_label_values(&identity);
1317        }
1318    }
1319}
1320
/// Handle for enqueueing [`SendWorkerMessage`]s to a spawned [`SendWorker`].
///
/// Wraps the sending half of the worker's channel; obtained from [`spawn_send_worker`].
#[derive(Debug, Clone)]
pub struct BroadcastQueue(SenderWithGauge<SendWorkerMessage>);
1323
/// Error returned when a message could not be enqueued on a [`BroadcastQueue`].
///
/// Transparent wrapper around the channel's [`mpsc::error::SendError`].
#[derive(thiserror::Error, Debug)]
#[error(transparent)]
pub struct BroadcastError(#[from] mpsc::error::SendError<SendWorkerMessage>);
1327
1328impl BroadcastQueue {
1329    fn send(&self, message: SendWorkerMessage) -> Result<(), BroadcastError> {
1330        self.0.send(message)?;
1331        Ok(())
1332    }
1333
1334    pub fn send_client_message(
1335        &self,
1336        recipient: Arc<ClientConnectionSender>,
1337        message: impl Into<SerializableMessage>,
1338    ) -> Result<(), BroadcastError> {
1339        self.0.send(SendWorkerMessage::SendMessage {
1340            recipient,
1341            message: message.into(),
1342        })?;
1343        Ok(())
1344    }
1345}
/// Spawns a [`SendWorker`] and returns the [`BroadcastQueue`] used to send it messages.
///
/// If `metric_database_identity` is provided,
/// it is used to label the `subscription_send_queue_length` metric for the worker.
pub fn spawn_send_worker(metric_database_identity: Option<Identity>) -> BroadcastQueue {
    SendWorker::spawn_new(metric_database_identity)
}
1349impl SendWorker {
1350    fn new(
1351        rx: mpsc::UnboundedReceiver<SendWorkerMessage>,
1352        queue_length_metric: Option<IntGauge>,
1353        database_identity_to_clean_up_metric: Option<Identity>,
1354    ) -> Self {
1355        Self {
1356            rx,
1357            queue_length_metric,
1358            clients: Default::default(),
1359            database_identity_to_clean_up_metric,
1360        }
1361    }
1362
1363    // Spawn a new send worker.
1364    // If a `metric_database_identity` is provided, we will decrement the corresponding
1365    // `subscription_send_queue_length` metric, and clean it up on drop.
1366    fn spawn_new(metric_database_identity: Option<Identity>) -> BroadcastQueue {
1367        let metric = metric_database_identity.map(|identity| {
1368            WORKER_METRICS
1369                .subscription_send_queue_length
1370                .with_label_values(&identity)
1371        });
1372        let (send_worker_tx, rx) = mpsc::unbounded_channel();
1373        tokio::spawn(Self::new(rx, metric.clone(), metric_database_identity).run());
1374        BroadcastQueue(SenderWithGauge::new(send_worker_tx, metric))
1375    }
1376
1377    async fn run(mut self) {
1378        while let Some(message) = self.rx.recv().await {
1379            if let Some(metric) = &self.queue_length_metric {
1380                metric.dec();
1381            }
1382
1383            match message {
1384                SendWorkerMessage::AddClient {
1385                    client_id,
1386                    dropped,
1387                    outbound_ref,
1388                } => {
1389                    self.clients
1390                        .insert(client_id, SendWorkerClient { dropped, outbound_ref });
1391                }
1392                SendWorkerMessage::SendMessage { recipient, message } => {
1393                    let _ = recipient.send_message(message);
1394                }
1395                SendWorkerMessage::RemoveClient(client_id) => {
1396                    self.clients.remove(&client_id);
1397                }
1398                SendWorkerMessage::Broadcast(queries) => {
1399                    self.send_one_computed_queries(queries);
1400                }
1401            }
1402        }
1403    }
1404
1405    fn send_one_computed_queries(
1406        &self,
1407        ComputedQueries {
1408            updates,
1409            errs,
1410            event,
1411            caller,
1412        }: ComputedQueries,
1413    ) {
1414        use FormatSwitch::{Bsatn, Json};
1415
1416        let clients_with_errors = errs.iter().map(|(id, _)| id).collect::<HashSet<_>>();
1417
1418        let span = tracing::info_span!("eval_incr_group_messages_by_client");
1419
1420        let mut eval = updates
1421            .into_iter()
1422            // Filter out clients whose subscriptions failed
1423            .filter(|upd| !clients_with_errors.contains(&upd.id))
1424            // For each subscriber, aggregate all the updates for the same table.
1425            // That is, we build a map `(subscriber_id, table_id) -> updates`.
1426            // A particular subscriber uses only one format,
1427            // so their `TableUpdate` will contain either JSON (`Protocol::Text`)
1428            // or BSATN (`Protocol::Binary`).
1429            .fold(
1430                HashMap::<(ClientId, TableId), SwitchedTableUpdate>::new(),
1431                |mut tables, upd| {
1432                    match tables.entry((upd.id, upd.table_id)) {
1433                        Entry::Occupied(mut entry) => match entry.get_mut().zip_mut(upd.update) {
1434                            Bsatn((tbl_upd, update)) => tbl_upd.push(update),
1435                            Json((tbl_upd, update)) => tbl_upd.push(update),
1436                        },
1437                        Entry::Vacant(entry) => drop(entry.insert(match upd.update {
1438                            Bsatn(update) => Bsatn(TableUpdate::new(upd.table_id, (&*upd.table_name).into(), update)),
1439                            Json(update) => Json(TableUpdate::new(upd.table_id, (&*upd.table_name).into(), update)),
1440                        })),
1441                    }
1442                    tables
1443                },
1444            )
1445            .into_iter()
1446            // Each client receives a single list of updates per transaction.
1447            // So before sending the updates to each client,
1448            // we must stitch together the `TableUpdate*`s into an aggregated list.
1449            .fold(
1450                HashMap::<ClientId, SwitchedDbUpdate>::new(),
1451                |mut updates, ((id, _), update)| {
1452                    let entry = updates.entry(id);
1453                    let entry = entry.or_insert_with(|| match &update {
1454                        Bsatn(_) => Bsatn(<_>::default()),
1455                        Json(_) => Json(<_>::default()),
1456                    });
1457                    match entry.zip_mut(update) {
1458                        Bsatn((list, elem)) => list.tables.push(elem),
1459                        Json((list, elem)) => list.tables.push(elem),
1460                    }
1461                    updates
1462                },
1463            );
1464
1465        drop(clients_with_errors);
1466        drop(span);
1467
1468        let _span = tracing::info_span!("eval_send").entered();
1469
1470        // We might have a known caller that hasn't been hidden from here..
1471        // This caller may have subscribed to some query.
1472        // If they haven't, we'll send them an empty update.
1473        // Regardless, the update that we send to the caller, if we send any,
1474        // is a full tx update, rather than a light one.
1475        // That is, in the case of the caller, we don't respect the light setting.
1476        if let Some(caller) = caller {
1477            let caller_id = (caller.id.identity, caller.id.connection_id);
1478            let database_update = eval
1479                .remove(&caller_id)
1480                .map(|update| SubscriptionUpdateMessage::from_event_and_update(&event, update))
1481                .unwrap_or_else(|| {
1482                    SubscriptionUpdateMessage::default_for_protocol(caller.config.protocol, event.request_id)
1483                });
1484            let message = TransactionUpdateMessage {
1485                event: Some(event.clone()),
1486                database_update,
1487            };
1488            send_to_client(&caller, message);
1489        }
1490
1491        // Send all the other updates.
1492        for (id, update) in eval {
1493            let database_update = SubscriptionUpdateMessage::from_event_and_update(&event, update);
1494            let client = self.clients[&id].outbound_ref.clone();
1495            // Conditionally send out a full update or a light one otherwise.
1496            let event = client.config.tx_update_full.then(|| event.clone());
1497            let message = TransactionUpdateMessage { event, database_update };
1498            send_to_client(&client, message);
1499        }
1500
1501        // Send error messages and mark clients for removal
1502        for (id, message) in errs {
1503            if let Some(client) = self.clients.get(&id) {
1504                client.dropped.store(true, Ordering::Release);
1505                send_to_client(
1506                    &client.outbound_ref,
1507                    SubscriptionMessage {
1508                        request_id: None,
1509                        query_id: None,
1510                        timer: None,
1511                        result: SubscriptionResult::Error(SubscriptionError {
1512                            table_id: None,
1513                            message,
1514                        }),
1515                    },
1516                );
1517            }
1518        }
1519    }
1520}
1521
1522fn send_to_client(client: &ClientConnectionSender, message: impl Into<SerializableMessage>) {
1523    if let Err(e) = client.send_message(message) {
1524        tracing::warn!(%client.id, "failed to send update message to client: {e}")
1525    }
1526}
1527
1528#[cfg(test)]
1529mod tests {
1530    use std::{sync::Arc, time::Duration};
1531
1532    use spacetimedb_client_api_messages::websocket::QueryId;
1533    use spacetimedb_lib::AlgebraicValue;
1534    use spacetimedb_lib::{error::ResultTest, identity::AuthCtx, AlgebraicType, ConnectionId, Identity, Timestamp};
1535    use spacetimedb_primitives::{ColId, TableId};
1536    use spacetimedb_sats::product;
1537    use spacetimedb_subscription::SubscriptionPlan;
1538
1539    use super::{Plan, SubscriptionManager};
1540    use crate::db::relational_db::tests_utils::with_read_only;
1541    use crate::execution_context::Workload;
1542    use crate::host::module_host::DatabaseTableUpdate;
1543    use crate::sql::ast::SchemaViewer;
1544    use crate::subscription::module_subscription_manager::ClientQueryId;
1545    use crate::{
1546        client::{ClientActorId, ClientConfig, ClientConnectionSender, ClientName},
1547        db::relational_db::{tests_utils::TestDB, RelationalDB},
1548        energy::EnergyQuanta,
1549        host::{
1550            module_host::{DatabaseUpdate, EventStatus, ModuleEvent, ModuleFunctionCall},
1551            ArgsTuple,
1552        },
1553        subscription::execution_unit::QueryHash,
1554    };
1555
1556    fn create_table(db: &RelationalDB, name: &str) -> ResultTest<TableId> {
1557        Ok(db.create_table_for_test(name, &[("a", AlgebraicType::U8)], &[])?)
1558    }
1559
1560    fn compile_plan(db: &RelationalDB, sql: &str) -> ResultTest<Arc<Plan>> {
1561        with_read_only(db, |tx| {
1562            let auth = AuthCtx::for_testing();
1563            let tx = SchemaViewer::new(&*tx, &auth);
1564            let (plans, has_param) = SubscriptionPlan::compile(sql, &tx, &auth).unwrap();
1565            let hash = QueryHash::from_string(sql, auth.caller, has_param);
1566            Ok(Arc::new(Plan::new(plans, hash, sql.into())))
1567        })
1568    }
1569
1570    fn id(connection_id: u128) -> (Identity, ConnectionId) {
1571        (Identity::ZERO, ConnectionId::from_u128(connection_id))
1572    }
1573
1574    fn client(connection_id: u128) -> ClientConnectionSender {
1575        let (identity, connection_id) = id(connection_id);
1576        ClientConnectionSender::dummy(
1577            ClientActorId {
1578                identity,
1579                connection_id,
1580                name: ClientName(0),
1581            },
1582            ClientConfig::for_test(),
1583        )
1584    }
1585
1586    #[test]
1587    fn test_subscribe_legacy() -> ResultTest<()> {
1588        let db = TestDB::durable()?;
1589
1590        let table_id = create_table(&db, "T")?;
1591        let sql = "select * from T";
1592        let plan = compile_plan(&db, sql)?;
1593        let hash = plan.hash();
1594
1595        let id = id(0);
1596        let client = Arc::new(client(0));
1597
1598        let runtime = tokio::runtime::Runtime::new().unwrap();
1599        let _rt = runtime.enter();
1600
1601        let mut subscriptions = SubscriptionManager::for_test_without_metrics();
1602        subscriptions.set_legacy_subscription(client.clone(), [plan.clone()]);
1603
1604        assert!(subscriptions.contains_query(&hash));
1605        assert!(subscriptions.contains_legacy_subscription(&id, &hash));
1606        assert!(subscriptions.query_reads_from_table(&hash, &table_id));
1607
1608        Ok(())
1609    }
1610
1611    #[test]
1612    fn test_subscribe_single_adds_table_mapping() -> ResultTest<()> {
1613        let db = TestDB::durable()?;
1614
1615        let table_id = create_table(&db, "T")?;
1616        let sql = "select * from T";
1617        let plan = compile_plan(&db, sql)?;
1618        let hash = plan.hash();
1619
1620        let client = Arc::new(client(0));
1621
1622        let query_id: ClientQueryId = QueryId::new(1);
1623
1624        let runtime = tokio::runtime::Runtime::new().unwrap();
1625        let _rt = runtime.enter();
1626
1627        let mut subscriptions = SubscriptionManager::for_test_without_metrics();
1628        subscriptions.add_subscription(client.clone(), plan.clone(), query_id)?;
1629        assert!(subscriptions.query_reads_from_table(&hash, &table_id));
1630
1631        Ok(())
1632    }
1633
1634    #[test]
1635    fn test_unsubscribe_from_the_only_subscription() -> ResultTest<()> {
1636        let db = TestDB::durable()?;
1637
1638        let table_id = create_table(&db, "T")?;
1639        let sql = "select * from T";
1640        let plan = compile_plan(&db, sql)?;
1641        let hash = plan.hash();
1642
1643        let client = Arc::new(client(0));
1644
1645        let query_id: ClientQueryId = QueryId::new(1);
1646
1647        let runtime = tokio::runtime::Runtime::new().unwrap();
1648        let _rt = runtime.enter();
1649
1650        let mut subscriptions = SubscriptionManager::for_test_without_metrics();
1651        subscriptions.add_subscription(client.clone(), plan.clone(), query_id)?;
1652        assert!(subscriptions.query_reads_from_table(&hash, &table_id));
1653
1654        let client_id = (client.id.identity, client.id.connection_id);
1655        subscriptions.remove_subscription(client_id, query_id)?;
1656        assert!(!subscriptions.query_reads_from_table(&hash, &table_id));
1657
1658        Ok(())
1659    }
1660
1661    #[test]
1662    fn test_unsubscribe_with_unknown_query_id_fails() -> ResultTest<()> {
1663        let db = TestDB::durable()?;
1664
1665        create_table(&db, "T")?;
1666        let sql = "select * from T";
1667        let plan = compile_plan(&db, sql)?;
1668
1669        let client = Arc::new(client(0));
1670
1671        let query_id: ClientQueryId = QueryId::new(1);
1672
1673        let runtime = tokio::runtime::Runtime::new().unwrap();
1674        let _rt = runtime.enter();
1675
1676        let mut subscriptions = SubscriptionManager::for_test_without_metrics();
1677        subscriptions.add_subscription(client.clone(), plan.clone(), query_id)?;
1678
1679        let client_id = (client.id.identity, client.id.connection_id);
1680        assert!(subscriptions.remove_subscription(client_id, QueryId::new(2)).is_err());
1681
1682        Ok(())
1683    }
1684
1685    #[test]
1686    fn test_subscribe_and_unsubscribe_with_duplicate_queries() -> ResultTest<()> {
1687        let db = TestDB::durable()?;
1688
1689        let table_id = create_table(&db, "T")?;
1690        let sql = "select * from T";
1691        let plan = compile_plan(&db, sql)?;
1692        let hash = plan.hash();
1693
1694        let client = Arc::new(client(0));
1695
1696        let query_id: ClientQueryId = QueryId::new(1);
1697
1698        let runtime = tokio::runtime::Runtime::new().unwrap();
1699        let _rt = runtime.enter();
1700
1701        let mut subscriptions = SubscriptionManager::for_test_without_metrics();
1702        subscriptions.add_subscription(client.clone(), plan.clone(), query_id)?;
1703        subscriptions.add_subscription(client.clone(), plan.clone(), QueryId::new(2))?;
1704
1705        let client_id = (client.id.identity, client.id.connection_id);
1706        subscriptions.remove_subscription(client_id, query_id)?;
1707
1708        assert!(subscriptions.query_reads_from_table(&hash, &table_id));
1709
1710        Ok(())
1711    }
1712
1713    /// A very simple test case of a duplicate query.
1714    #[test]
1715    fn test_subscribe_and_unsubscribe_with_duplicate_queries_multi() -> ResultTest<()> {
1716        let db = TestDB::durable()?;
1717
1718        let table_id = create_table(&db, "T")?;
1719        let sql = "select * from T";
1720        let plan = compile_plan(&db, sql)?;
1721        let hash = plan.hash();
1722
1723        let client = Arc::new(client(0));
1724
1725        let query_id: ClientQueryId = QueryId::new(1);
1726
1727        let runtime = tokio::runtime::Runtime::new().unwrap();
1728        let _rt = runtime.enter();
1729
1730        let mut subscriptions = SubscriptionManager::for_test_without_metrics();
1731        let added_query = subscriptions.add_subscription_multi(client.clone(), vec![plan.clone()], query_id)?;
1732        assert!(added_query.len() == 1);
1733        assert_eq!(added_query[0].hash, hash);
1734        let second_one = subscriptions.add_subscription_multi(client.clone(), vec![plan.clone()], QueryId::new(2))?;
1735        assert!(second_one.is_empty());
1736
1737        let client_id = (client.id.identity, client.id.connection_id);
1738        let removed_queries = subscriptions.remove_subscription(client_id, query_id)?;
1739        assert!(removed_queries.is_empty());
1740
1741        assert!(subscriptions.query_reads_from_table(&hash, &table_id));
1742        let removed_queries = subscriptions.remove_subscription(client_id, QueryId::new(2))?;
1743        assert!(removed_queries.len() == 1);
1744        assert_eq!(removed_queries[0].hash, hash);
1745
1746        assert!(!subscriptions.query_reads_from_table(&hash, &table_id));
1747
1748        Ok(())
1749    }
1750
1751    #[test]
1752    fn test_unsubscribe_doesnt_remove_other_clients() -> ResultTest<()> {
1753        let db = TestDB::durable()?;
1754
1755        let table_id = create_table(&db, "T")?;
1756        let sql = "select * from T";
1757        let plan = compile_plan(&db, sql)?;
1758        let hash = plan.hash();
1759
1760        let clients = (0..3).map(|i| Arc::new(client(i))).collect::<Vec<_>>();
1761
1762        // All of the clients are using the same query id.
1763        let query_id: ClientQueryId = QueryId::new(1);
1764
1765        let runtime = tokio::runtime::Runtime::new().unwrap();
1766        let _rt = runtime.enter();
1767
1768        let mut subscriptions = SubscriptionManager::for_test_without_metrics();
1769        subscriptions.add_subscription(clients[0].clone(), plan.clone(), query_id)?;
1770        subscriptions.add_subscription(clients[1].clone(), plan.clone(), query_id)?;
1771        subscriptions.add_subscription(clients[2].clone(), plan.clone(), query_id)?;
1772
1773        assert!(subscriptions.query_reads_from_table(&hash, &table_id));
1774
1775        let client_ids = clients
1776            .iter()
1777            .map(|client| (client.id.identity, client.id.connection_id))
1778            .collect::<Vec<_>>();
1779        subscriptions.remove_subscription(client_ids[0], query_id)?;
1780        // There are still two left.
1781        assert!(subscriptions.query_reads_from_table(&hash, &table_id));
1782        subscriptions.remove_subscription(client_ids[1], query_id)?;
1783        // There is still one left.
1784        assert!(subscriptions.query_reads_from_table(&hash, &table_id));
1785        subscriptions.remove_subscription(client_ids[2], query_id)?;
1786        // Now there are no subscribers.
1787        assert!(!subscriptions.query_reads_from_table(&hash, &table_id));
1788
1789        Ok(())
1790    }
1791
1792    #[test]
1793    fn test_unsubscribe_all_doesnt_remove_other_clients() -> ResultTest<()> {
1794        let db = TestDB::durable()?;
1795
1796        let table_id = create_table(&db, "T")?;
1797        let sql = "select * from T";
1798        let plan = compile_plan(&db, sql)?;
1799        let hash = plan.hash();
1800
1801        let clients = (0..3).map(|i| Arc::new(client(i))).collect::<Vec<_>>();
1802
1803        // All of the clients are using the same query id.
1804        let query_id: ClientQueryId = QueryId::new(1);
1805
1806        let runtime = tokio::runtime::Runtime::new().unwrap();
1807        let _rt = runtime.enter();
1808
1809        let mut subscriptions = SubscriptionManager::for_test_without_metrics();
1810        subscriptions.add_subscription(clients[0].clone(), plan.clone(), query_id)?;
1811        subscriptions.add_subscription(clients[1].clone(), plan.clone(), query_id)?;
1812        subscriptions.add_subscription(clients[2].clone(), plan.clone(), query_id)?;
1813
1814        assert!(subscriptions.query_reads_from_table(&hash, &table_id));
1815
1816        let client_ids = clients
1817            .iter()
1818            .map(|client| (client.id.identity, client.id.connection_id))
1819            .collect::<Vec<_>>();
1820        subscriptions.remove_all_subscriptions(&client_ids[0]);
1821        assert!(!subscriptions.contains_client(&client_ids[0]));
1822        // There are still two left.
1823        assert!(subscriptions.query_reads_from_table(&hash, &table_id));
1824        subscriptions.remove_all_subscriptions(&client_ids[1]);
1825        // There is still one left.
1826        assert!(subscriptions.query_reads_from_table(&hash, &table_id));
1827        assert!(!subscriptions.contains_client(&client_ids[1]));
1828        subscriptions.remove_all_subscriptions(&client_ids[2]);
1829        // Now there are no subscribers.
1830        assert!(!subscriptions.query_reads_from_table(&hash, &table_id));
1831        assert!(!subscriptions.contains_client(&client_ids[2]));
1832
1833        Ok(())
1834    }
1835
1836    // This test has a single client with 3 queries of different tables, and tests removing them.
1837    #[test]
1838    fn test_multiple_queries() -> ResultTest<()> {
1839        let db = TestDB::durable()?;
1840
1841        let table_names = ["T", "S", "U"];
1842        let table_ids = table_names
1843            .iter()
1844            .map(|name| create_table(&db, name))
1845            .collect::<ResultTest<Vec<_>>>()?;
1846        let queries = table_names
1847            .iter()
1848            .map(|name| format!("select * from {}", name))
1849            .map(|sql| compile_plan(&db, &sql))
1850            .collect::<ResultTest<Vec<_>>>()?;
1851
1852        let client = Arc::new(client(0));
1853
1854        let runtime = tokio::runtime::Runtime::new().unwrap();
1855        let _rt = runtime.enter();
1856
1857        let mut subscriptions = SubscriptionManager::for_test_without_metrics();
1858        subscriptions.add_subscription(client.clone(), queries[0].clone(), QueryId::new(1))?;
1859        subscriptions.add_subscription(client.clone(), queries[1].clone(), QueryId::new(2))?;
1860        subscriptions.add_subscription(client.clone(), queries[2].clone(), QueryId::new(3))?;
1861        for i in 0..3 {
1862            assert!(subscriptions.query_reads_from_table(&queries[i].hash(), &table_ids[i]));
1863        }
1864
1865        let client_id = (client.id.identity, client.id.connection_id);
1866        subscriptions.remove_subscription(client_id, QueryId::new(1))?;
1867        assert!(!subscriptions.query_reads_from_table(&queries[0].hash(), &table_ids[0]));
1868        // Assert that the rest are there.
1869        for i in 1..3 {
1870            assert!(subscriptions.query_reads_from_table(&queries[i].hash(), &table_ids[i]));
1871        }
1872
1873        // Now remove the final two at once.
1874        subscriptions.remove_all_subscriptions(&client_id);
1875        for i in 0..3 {
1876            assert!(!subscriptions.query_reads_from_table(&queries[i].hash(), &table_ids[i]));
1877        }
1878
1879        Ok(())
1880    }
1881
1882    #[test]
1883    fn test_multiple_query_sets() -> ResultTest<()> {
1884        let db = TestDB::durable()?;
1885
1886        let table_names = ["T", "S", "U"];
1887        let table_ids = table_names
1888            .iter()
1889            .map(|name| create_table(&db, name))
1890            .collect::<ResultTest<Vec<_>>>()?;
1891        let queries = table_names
1892            .iter()
1893            .map(|name| format!("select * from {}", name))
1894            .map(|sql| compile_plan(&db, &sql))
1895            .collect::<ResultTest<Vec<_>>>()?;
1896
1897        let client = Arc::new(client(0));
1898
1899        let runtime = tokio::runtime::Runtime::new().unwrap();
1900        let _rt = runtime.enter();
1901
1902        let mut subscriptions = SubscriptionManager::for_test_without_metrics();
1903        let added = subscriptions.add_subscription_multi(client.clone(), vec![queries[0].clone()], QueryId::new(1))?;
1904        assert_eq!(added.len(), 1);
1905        assert_eq!(added[0].hash, queries[0].hash());
1906        let added = subscriptions.add_subscription_multi(client.clone(), vec![queries[1].clone()], QueryId::new(2))?;
1907        assert_eq!(added.len(), 1);
1908        assert_eq!(added[0].hash, queries[1].hash());
1909        let added = subscriptions.add_subscription_multi(client.clone(), vec![queries[2].clone()], QueryId::new(3))?;
1910        assert_eq!(added.len(), 1);
1911        assert_eq!(added[0].hash, queries[2].hash());
1912        for i in 0..3 {
1913            assert!(subscriptions.query_reads_from_table(&queries[i].hash(), &table_ids[i]));
1914        }
1915
1916        let client_id = (client.id.identity, client.id.connection_id);
1917        let removed = subscriptions.remove_subscription(client_id, QueryId::new(1))?;
1918        assert_eq!(removed.len(), 1);
1919        assert_eq!(removed[0].hash, queries[0].hash());
1920        assert!(!subscriptions.query_reads_from_table(&queries[0].hash(), &table_ids[0]));
1921        // Assert that the rest are there.
1922        for i in 1..3 {
1923            assert!(subscriptions.query_reads_from_table(&queries[i].hash(), &table_ids[i]));
1924        }
1925
1926        // Now remove the final two at once.
1927        subscriptions.remove_all_subscriptions(&client_id);
1928        for i in 0..3 {
1929            assert!(!subscriptions.query_reads_from_table(&queries[i].hash(), &table_ids[i]));
1930        }
1931
1932        Ok(())
1933    }
1934
1935    #[test]
1936    fn test_internals_for_search_args() -> ResultTest<()> {
1937        let db = TestDB::durable()?;
1938
1939        let table_id = create_table(&db, "t")?;
1940
1941        let client = Arc::new(client(0));
1942
1943        let runtime = tokio::runtime::Runtime::new().unwrap();
1944        let _rt = runtime.enter();
1945
1946        let mut subscriptions = SubscriptionManager::for_test_without_metrics();
1947
1948        // Subscribe to queries that have search arguments
1949        let queries = (0u8..5)
1950            .map(|name| format!("select * from t where a = {}", name))
1951            .map(|sql| compile_plan(&db, &sql))
1952            .collect::<ResultTest<Vec<_>>>()?;
1953
1954        for (i, query) in queries.iter().enumerate().take(5) {
1955            let added =
1956                subscriptions.add_subscription_multi(client.clone(), vec![query.clone()], QueryId::new(i as u32))?;
1957            assert_eq!(added.len(), 1);
1958            assert_eq!(added[0].hash, queries[i].hash);
1959        }
1960
1961        // Assert this table has a search parameter
1962        assert!(subscriptions.table_has_search_param(table_id, ColId(0)));
1963
1964        for (i, query) in queries.iter().enumerate().take(5) {
1965            assert!(subscriptions.query_has_search_arg(query.hash, table_id, ColId(0), AlgebraicValue::U8(i as u8)));
1966
1967            // Only one of `query_reads_from_table` and `query_has_search_arg` can be true at any given time
1968            assert!(!subscriptions.query_reads_from_table(&queries[i].hash, &table_id));
1969        }
1970
1971        // Remove one of the subscriptions
1972        let query_id = QueryId::new(2);
1973        let client_id = (client.id.identity, client.id.connection_id);
1974        let removed = subscriptions.remove_subscription(client_id, query_id)?;
1975        assert_eq!(removed.len(), 1);
1976
1977        // We haven't removed the other subscriptions,
1978        // so this table should still have a search parameter.
1979        assert!(subscriptions.table_has_search_param(table_id, ColId(0)));
1980
1981        // We should have removed the search argument for this query
1982        assert!(!subscriptions.query_reads_from_table(&queries[2].hash, &table_id));
1983        assert!(!subscriptions.query_has_search_arg(queries[2].hash, table_id, ColId(0), AlgebraicValue::U8(2)));
1984
1985        for (i, query) in queries.iter().enumerate().take(5) {
1986            if i != 2 {
1987                assert!(subscriptions.query_has_search_arg(
1988                    query.hash,
1989                    table_id,
1990                    ColId(0),
1991                    AlgebraicValue::U8(i as u8)
1992                ));
1993            }
1994        }
1995
1996        // Remove all of the subscriptions
1997        subscriptions.remove_all_subscriptions(&client_id);
1998
1999        // We should no longer record a search parameter for this table
2000        assert!(!subscriptions.table_has_search_param(table_id, ColId(0)));
2001        for (i, query) in queries.iter().enumerate().take(5) {
2002            assert!(!subscriptions.query_has_search_arg(query.hash, table_id, ColId(0), AlgebraicValue::U8(i as u8)));
2003        }
2004
2005        Ok(())
2006    }
2007
2008    #[test]
2009    fn test_search_args_for_selects() -> ResultTest<()> {
2010        let db = TestDB::durable()?;
2011
2012        let table_id = create_table(&db, "t")?;
2013
2014        let client = Arc::new(client(0));
2015
2016        let runtime = tokio::runtime::Runtime::new().unwrap();
2017        let _rt = runtime.enter();
2018
2019        let mut subscriptions = SubscriptionManager::for_test_without_metrics();
2020
2021        let queries = (0u8..5)
2022            .map(|name| format!("select * from t where a = {}", name))
2023            .chain(std::iter::once(String::from("select * from t")))
2024            .map(|sql| compile_plan(&db, &sql))
2025            .collect::<ResultTest<Vec<_>>>()?;
2026
2027        for (i, query) in queries.iter().enumerate() {
2028            subscriptions.add_subscription_multi(client.clone(), vec![query.clone()], QueryId::new(i as u32))?;
2029        }
2030
2031        let hash_for_2 = queries[2].hash;
2032        let hash_for_3 = queries[3].hash;
2033        let hash_for_5 = queries[5].hash;
2034
2035        // Which queries are relevant for this table update? Only:
2036        //
2037        // select * from t where a = 2
2038        // select * from t where a = 3
2039        // select * from t
2040        let table_update = DatabaseTableUpdate {
2041            table_id,
2042            table_name: "t".into(),
2043            inserts: [product![2u8]].into(),
2044            deletes: [product![3u8]].into(),
2045        };
2046
2047        let hashes = subscriptions
2048            .queries_for_table_update(&table_update, &|_, _| None)
2049            .collect::<Vec<_>>();
2050
2051        assert!(hashes.len() == 3);
2052        assert!(hashes.contains(&&hash_for_2));
2053        assert!(hashes.contains(&&hash_for_3));
2054        assert!(hashes.contains(&&hash_for_5));
2055
2056        // Which queries are relevant for this table update?
2057        // Only: select * from t
2058        let table_update = DatabaseTableUpdate {
2059            table_id,
2060            table_name: "t".into(),
2061            inserts: [product![8u8]].into(),
2062            deletes: [product![9u8]].into(),
2063        };
2064
2065        let hashes = subscriptions
2066            .queries_for_table_update(&table_update, &|_, _| None)
2067            .collect::<Vec<_>>();
2068
2069        assert!(hashes.len() == 1);
2070        assert!(hashes.contains(&&hash_for_5));
2071
2072        Ok(())
2073    }
2074
2075    #[test]
2076    fn test_search_args_for_join() -> ResultTest<()> {
2077        let db = TestDB::durable()?;
2078
2079        let schema = [("id", AlgebraicType::U8), ("a", AlgebraicType::U8)];
2080
2081        let t_id = db.create_table_for_test("t", &schema, &[0.into()])?;
2082        let s_id = db.create_table_for_test("s", &schema, &[0.into()])?;
2083
2084        let client = Arc::new(client(0));
2085
2086        let runtime = tokio::runtime::Runtime::new().unwrap();
2087        let _rt = runtime.enter();
2088
2089        let mut subscriptions = SubscriptionManager::for_test_without_metrics();
2090
2091        let plan = compile_plan(&db, "select t.* from t join s on t.id = s.id where s.a = 1")?;
2092        let hash = plan.hash;
2093
2094        subscriptions.add_subscription_multi(client.clone(), vec![plan], QueryId::new(0))?;
2095
2096        // Do we need to evaluate the above join query for this table update?
2097        // Yes, because the above query does not filter on `t`.
2098        // Therefore we must evaluate it for any update on `t`.
2099        let table_update = DatabaseTableUpdate {
2100            table_id: t_id,
2101            table_name: "t".into(),
2102            inserts: [product![0u8, 0u8]].into(),
2103            deletes: [].into(),
2104        };
2105
2106        let hashes = subscriptions
2107            .queries_for_table_update(&table_update, &|_, _| None)
2108            .cloned()
2109            .collect::<Vec<_>>();
2110
2111        assert_eq!(hashes, vec![hash]);
2112
2113        // Do we need to evaluate the above join query for this table update?
2114        // Yes, because `s.a = 1`.
2115        let table_update = DatabaseTableUpdate {
2116            table_id: s_id,
2117            table_name: "s".into(),
2118            inserts: [product![0u8, 1u8]].into(),
2119            deletes: [].into(),
2120        };
2121
2122        let hashes = subscriptions
2123            .queries_for_table_update(&table_update, &|_, _| None)
2124            .cloned()
2125            .collect::<Vec<_>>();
2126
2127        assert_eq!(hashes, vec![hash]);
2128
2129        // Do we need to evaluate the above join query for this table update?
2130        // No, because `s.a != 1`.
2131        let table_update = DatabaseTableUpdate {
2132            table_id: s_id,
2133            table_name: "s".into(),
2134            inserts: [product![0u8, 2u8]].into(),
2135            deletes: [].into(),
2136        };
2137
2138        let hashes = subscriptions
2139            .queries_for_table_update(&table_update, &|_, _| None)
2140            .cloned()
2141            .collect::<Vec<_>>();
2142
2143        assert!(hashes.is_empty());
2144
2145        Ok(())
2146    }
2147
2148    #[test]
2149    fn test_subscribe_fails_with_duplicate_request_id() -> ResultTest<()> {
2150        let db = TestDB::durable()?;
2151
2152        create_table(&db, "T")?;
2153        let sql = "select * from T";
2154        let plan = compile_plan(&db, sql)?;
2155
2156        let client = Arc::new(client(0));
2157
2158        let query_id: ClientQueryId = QueryId::new(1);
2159
2160        let runtime = tokio::runtime::Runtime::new().unwrap();
2161        let _rt = runtime.enter();
2162
2163        let mut subscriptions = SubscriptionManager::for_test_without_metrics();
2164        subscriptions.add_subscription(client.clone(), plan.clone(), query_id)?;
2165
2166        assert!(subscriptions
2167            .add_subscription(client.clone(), plan.clone(), query_id)
2168            .is_err());
2169
2170        Ok(())
2171    }
2172
2173    #[test]
2174    fn test_subscribe_multi_fails_with_duplicate_request_id() -> ResultTest<()> {
2175        let db = TestDB::durable()?;
2176
2177        create_table(&db, "T")?;
2178        let sql = "select * from T";
2179        let plan = compile_plan(&db, sql)?;
2180
2181        let client = Arc::new(client(0));
2182
2183        let query_id: ClientQueryId = QueryId::new(1);
2184
2185        let runtime = tokio::runtime::Runtime::new().unwrap();
2186        let _rt = runtime.enter();
2187
2188        let mut subscriptions = SubscriptionManager::for_test_without_metrics();
2189        let result = subscriptions.add_subscription_multi(client.clone(), vec![plan.clone()], query_id)?;
2190        assert_eq!(result[0].hash, plan.hash);
2191
2192        assert!(subscriptions
2193            .add_subscription_multi(client.clone(), vec![plan.clone()], query_id)
2194            .is_err());
2195
2196        Ok(())
2197    }
2198
2199    #[test]
2200    fn test_unsubscribe() -> ResultTest<()> {
2201        let db = TestDB::durable()?;
2202
2203        let table_id = create_table(&db, "T")?;
2204        let sql = "select * from T";
2205        let plan = compile_plan(&db, sql)?;
2206        let hash = plan.hash();
2207
2208        let id = id(0);
2209        let client = Arc::new(client(0));
2210
2211        let runtime = tokio::runtime::Runtime::new().unwrap();
2212        let _rt = runtime.enter();
2213
2214        let mut subscriptions = SubscriptionManager::for_test_without_metrics();
2215        subscriptions.set_legacy_subscription(client, [plan]);
2216        subscriptions.remove_all_subscriptions(&id);
2217
2218        assert!(!subscriptions.contains_query(&hash));
2219        assert!(!subscriptions.contains_legacy_subscription(&id, &hash));
2220        assert!(!subscriptions.query_reads_from_table(&hash, &table_id));
2221
2222        Ok(())
2223    }
2224
2225    #[test]
2226    fn test_subscribe_idempotent() -> ResultTest<()> {
2227        let db = TestDB::durable()?;
2228
2229        let table_id = create_table(&db, "T")?;
2230        let sql = "select * from T";
2231        let plan = compile_plan(&db, sql)?;
2232        let hash = plan.hash();
2233
2234        let id = id(0);
2235        let client = Arc::new(client(0));
2236
2237        let runtime = tokio::runtime::Runtime::new().unwrap();
2238        let _rt = runtime.enter();
2239
2240        let mut subscriptions = SubscriptionManager::for_test_without_metrics();
2241        subscriptions.set_legacy_subscription(client.clone(), [plan.clone()]);
2242        subscriptions.set_legacy_subscription(client.clone(), [plan.clone()]);
2243
2244        assert!(subscriptions.contains_query(&hash));
2245        assert!(subscriptions.contains_legacy_subscription(&id, &hash));
2246        assert!(subscriptions.query_reads_from_table(&hash, &table_id));
2247
2248        subscriptions.remove_all_subscriptions(&id);
2249
2250        assert!(!subscriptions.contains_query(&hash));
2251        assert!(!subscriptions.contains_legacy_subscription(&id, &hash));
2252        assert!(!subscriptions.query_reads_from_table(&hash, &table_id));
2253
2254        Ok(())
2255    }
2256
/// Two clients subscribe to the very same query. Removing the first client
/// must leave the query registered on behalf of the second one.
#[test]
fn test_share_queries_full() -> ResultTest<()> {
    let db = TestDB::durable()?;

    let table = create_table(&db, "T")?;
    let shared_query = compile_plan(&db, "select * from T")?;
    let shared_hash = shared_query.hash();

    let (first_id, first_sender) = (id(0), Arc::new(client(0)));
    let (second_id, second_sender) = (id(1), Arc::new(client(1)));

    // Keep a tokio runtime entered for the remainder of the test.
    let rt = tokio::runtime::Runtime::new().unwrap();
    let _guard = rt.enter();

    let mut manager = SubscriptionManager::for_test_without_metrics();
    manager.set_legacy_subscription(first_sender, [shared_query.clone()]);
    manager.set_legacy_subscription(second_sender, [shared_query.clone()]);

    // Both clients are attached to the one query.
    assert!(manager.contains_query(&shared_hash));
    for subscriber in [&first_id, &second_id] {
        assert!(manager.contains_legacy_subscription(subscriber, &shared_hash));
    }
    assert!(manager.query_reads_from_table(&shared_hash, &table));

    manager.remove_all_subscriptions(&first_id);

    // The query is still there for the remaining client ...
    assert!(manager.contains_query(&shared_hash));
    assert!(manager.contains_legacy_subscription(&second_id, &shared_hash));
    assert!(manager.query_reads_from_table(&shared_hash, &table));

    // ... while the removed client no longer holds it.
    assert!(!manager.contains_legacy_subscription(&first_id, &shared_hash));

    Ok(())
}
2294
/// Two clients share a full scan of `T`, and each adds one filtered query of
/// its own (`T where a = 0` for the first, `S where a = 1` for the second).
/// Removing the first client must drop only the query nobody else holds.
#[test]
fn test_share_queries_partial() -> ResultTest<()> {
    let db = TestDB::durable()?;

    let table_t = create_table(&db, "T")?;
    let table_s = create_table(&db, "S")?;

    let scan_plan = compile_plan(&db, "select * from T")?;
    let filter_t_plan = compile_plan(&db, "select * from T where a = 0")?;
    let filter_s_plan = compile_plan(&db, "select * from S where a = 1")?;

    let scan_hash = scan_plan.hash();
    let filter_t_hash = filter_t_plan.hash();
    let filter_s_hash = filter_s_plan.hash();

    let (first_id, first_sender) = (id(0), Arc::new(client(0)));
    let (second_id, second_sender) = (id(1), Arc::new(client(1)));

    // Keep a tokio runtime entered for the remainder of the test.
    let rt = tokio::runtime::Runtime::new().unwrap();
    let _guard = rt.enter();
    let mut manager = SubscriptionManager::for_test_without_metrics();
    manager.set_legacy_subscription(first_sender, [scan_plan.clone(), filter_t_plan.clone()]);
    manager.set_legacy_subscription(second_sender, [scan_plan.clone(), filter_s_plan.clone()]);

    // Every submitted query is registered.
    for hash in [&scan_hash, &filter_t_hash, &filter_s_hash] {
        assert!(manager.contains_query(hash));
    }

    // Each client holds the shared scan plus its own filtered query.
    assert!(manager.contains_legacy_subscription(&first_id, &scan_hash));
    assert!(manager.contains_legacy_subscription(&first_id, &filter_t_hash));

    assert!(manager.contains_legacy_subscription(&second_id, &scan_hash));
    assert!(manager.contains_legacy_subscription(&second_id, &filter_s_hash));

    // The scan is expected under table reads; the filtered queries are
    // expected under a search argument on column 0 instead.
    assert!(manager.query_reads_from_table(&scan_hash, &table_t));
    assert!(manager.query_has_search_arg(filter_t_hash, table_t, ColId(0), AlgebraicValue::U8(0)));
    assert!(manager.query_has_search_arg(filter_s_hash, table_s, ColId(0), AlgebraicValue::U8(1)));

    // None of these (query, table) pairs may show up as a plain table read.
    for (hash, table) in [
        (&scan_hash, &table_s),
        (&filter_t_hash, &table_t),
        (&filter_s_hash, &table_s),
        (&filter_t_hash, &table_s),
        (&filter_s_hash, &table_t),
    ] {
        assert!(!manager.query_reads_from_table(hash, table));
    }

    manager.remove_all_subscriptions(&first_id);

    // Only the query exclusive to the removed client disappears.
    assert!(manager.contains_query(&scan_hash));
    assert!(manager.contains_query(&filter_s_hash));
    assert!(!manager.contains_query(&filter_t_hash));

    // The second client keeps both of its subscriptions.
    assert!(manager.contains_legacy_subscription(&second_id, &scan_hash));
    assert!(manager.contains_legacy_subscription(&second_id, &filter_s_hash));

    // The first client holds nothing any more.
    assert!(!manager.contains_legacy_subscription(&first_id, &scan_hash));
    assert!(!manager.contains_legacy_subscription(&first_id, &filter_t_hash));

    // Lookups for the surviving queries are unchanged.
    assert!(manager.query_reads_from_table(&scan_hash, &table_t));
    assert!(manager.query_has_search_arg(filter_s_hash, table_s, ColId(0), AlgebraicValue::U8(1)));

    for (hash, table) in [
        (&filter_s_hash, &table_s),
        (&scan_hash, &table_s),
        (&filter_s_hash, &table_t),
    ] {
        assert!(!manager.query_reads_from_table(hash, table));
    }

    Ok(())
}
2367
/// Checks that the caller of a reducer receives a message for the resulting
/// transaction even though it has not subscribed to any updates.
#[test]
fn test_caller_transaction_update_without_subscription() -> ResultTest<()> {
    let db = TestDB::durable()?;

    let caller_identity = Identity::ZERO;
    let actor_id = ClientActorId::for_test(caller_identity);
    let client_config = ClientConfig::for_test();
    // `receiver` is the end on which messages sent to the caller arrive.
    let (sender, mut receiver) = ClientConnectionSender::dummy_with_channel(actor_id, client_config);

    let rt = tokio::runtime::Runtime::new().unwrap();
    let _guard = rt.enter();
    // No subscriptions are ever registered on this manager.
    let manager = SubscriptionManager::for_test_without_metrics();

    let function_call = ModuleFunctionCall {
        reducer: "DummyReducer".into(),
        reducer_id: u32::MAX.into(),
        args: ArgsTuple::nullary(),
    };
    // A committed reducer call with an empty database update, attributed to
    // the caller's identity and connection.
    let event = Arc::new(ModuleEvent {
        timestamp: Timestamp::now(),
        caller_identity,
        caller_connection_id: Some(sender.id.connection_id),
        function_call,
        status: EventStatus::Committed(DatabaseUpdate::default()),
        energy_quanta_used: EnergyQuanta::ZERO,
        host_execution_duration: Duration::default(),
        request_id: None,
        timer: None,
    });

    // Evaluate the event, passing the caller's sender explicitly.
    db.with_read_only(Workload::Update, |tx| {
        manager.eval_updates_sequential(&(&*tx).into(), event, Some(Arc::new(sender)))
    });

    // Allow at most 20ms for a message to reach the caller's channel.
    let wait_for_message = async move {
        let received = tokio::time::timeout(Duration::from_millis(20), receiver.recv())
            .await
            .expect("Timed out waiting for a message to the client");
        received.expect("Expected at least one message");
    };
    rt.block_on(wait_for_message);

    Ok(())
}
2413}