1use super::execution_unit::QueryHash;
2use super::tx::DeltaTx;
3use crate::client::messages::{
4 SerializableMessage, SubscriptionError, SubscriptionMessage, SubscriptionResult, SubscriptionUpdateMessage,
5 TransactionUpdateMessage,
6};
7use crate::client::{ClientConnectionSender, Protocol};
8use crate::error::DBError;
9use crate::host::module_host::{DatabaseTableUpdate, ModuleEvent, UpdatesRelValue};
10use crate::messages::websocket::{self as ws, TableUpdate};
11use crate::subscription::delta::eval_delta;
12use crate::worker_metrics::WORKER_METRICS;
13use core::mem;
14use hashbrown::hash_map::OccupiedError;
15use hashbrown::{HashMap, HashSet};
16use parking_lot::RwLock;
17use prometheus::IntGauge;
18use spacetimedb_client_api_messages::websocket::{
19 BsatnFormat, CompressableQueryUpdate, FormatSwitch, JsonFormat, QueryId, QueryUpdate, SingleQueryUpdate,
20 WebsocketFormat,
21};
22use spacetimedb_data_structures::map::{Entry, IntMap};
23use spacetimedb_datastore::locking_tx_datastore::state_view::StateView;
24use spacetimedb_lib::metrics::ExecutionMetrics;
25use spacetimedb_lib::{AlgebraicValue, ConnectionId, Identity, ProductValue};
26use spacetimedb_primitives::{ColId, IndexId, TableId};
27use spacetimedb_subscription::{JoinEdge, SubscriptionPlan, TableName};
28use std::collections::BTreeMap;
29use std::fmt::Debug;
30use std::sync::atomic::{AtomicBool, Ordering};
31use std::sync::Arc;
32use tokio::sync::mpsc;
33
/// A client connection is identified by its identity together with its connection id.
type ClientId = (Identity, ConnectionId);
/// A compiled query, shared (via `Arc`) between every place that references it.
type Query = Arc<Plan>;
/// Handle used to send messages to a connected client.
type Client = Arc<ClientConnectionSender>;
/// A per-table update in either the BSATN or the JSON wire format.
type SwitchedTableUpdate = FormatSwitch<TableUpdate<BsatnFormat>, TableUpdate<JsonFormat>>;
/// A whole-database update in either the BSATN or the JSON wire format.
type SwitchedDbUpdate = FormatSwitch<ws::DatabaseUpdate<BsatnFormat>, ws::DatabaseUpdate<JsonFormat>>;

/// The per-client query id that names one of a client's subscription sets.
type ClientQueryId = QueryId;
/// A subscription set is identified by its owning client plus that client's query id.
type SubscriptionId = (ClientId, ClientQueryId);
47
/// A compiled subscription query.
///
/// One SQL query may compile to several [`SubscriptionPlan`] fragments; they all
/// share this plan's hash and SQL text.
#[derive(Debug)]
pub struct Plan {
    /// Hash identifying this query; the key used throughout the subscription manager.
    hash: QueryHash,
    /// The original SQL text, used when reporting evaluation errors.
    sql: String,
    /// The compiled fragments. `subscribed_table_id`/`subscribed_table_name` read the
    /// first element, so this is expected to be non-empty.
    plans: Vec<SubscriptionPlan>,
}
54
55impl Plan {
56 pub fn new(plans: Vec<SubscriptionPlan>, hash: QueryHash, text: String) -> Self {
58 Self { plans, hash, sql: text }
59 }
60
61 pub fn hash(&self) -> QueryHash {
63 self.hash
64 }
65
66 pub fn subscribed_table_id(&self) -> TableId {
69 self.plans[0].subscribed_table_id()
70 }
71
72 pub fn subscribed_table_name(&self) -> &str {
75 self.plans[0].subscribed_table_name()
76 }
77
78 pub fn index_ids(&self) -> impl Iterator<Item = (TableId, IndexId)> {
80 self.plans
81 .iter()
82 .flat_map(|plan| plan.index_ids())
83 .collect::<HashSet<_>>()
84 .into_iter()
85 }
86
87 pub fn table_ids(&self) -> impl Iterator<Item = TableId> + '_ {
89 self.plans
90 .iter()
91 .flat_map(|plan| plan.table_ids())
92 .collect::<HashSet<_>>()
93 .into_iter()
94 }
95
96 fn search_args(&self) -> impl Iterator<Item = (TableId, ColId, AlgebraicValue)> {
98 let mut args = HashSet::new();
99 for arg in self
100 .plans
101 .iter()
102 .flat_map(|subscription| subscription.optimized_physical_plan().search_args())
103 {
104 args.insert(arg);
105 }
106 args.into_iter()
107 }
108
109 pub fn plans_fragments(&self) -> impl Iterator<Item = &SubscriptionPlan> + '_ {
112 self.plans.iter()
113 }
114
115 pub fn join_edges(&self) -> impl Iterator<Item = (JoinEdge, AlgebraicValue)> + '_ {
117 self.plans.iter().filter_map(|plan| plan.join_edge())
118 }
119
120 pub fn sql(&self) -> &str {
122 &self.sql
123 }
124}
125
/// Per-client bookkeeping held by the subscription manager.
#[derive(Debug)]
struct ClientInfo {
    /// Handle used to send messages to this client.
    outbound_ref: Client,
    /// The query hashes contained in each of this client's subscription sets.
    subscriptions: HashMap<SubscriptionId, HashSet<QueryHash>>,
    /// For each query hash, how many of this client's subscription sets contain it.
    subscription_ref_count: HashMap<QueryHash, usize>,
    /// Queries set through `set_legacy_subscription`, which replaces this whole set on each call.
    legacy_subscriptions: HashSet<QueryHash>,
    /// Flag shared with the send worker (see `SendWorkerMessage::AddClient`). When it is set,
    /// the manager removes the client's subscriptions and the send worker skips its updates.
    /// NOTE(review): the writer of this flag is not in this chunk.
    dropped: Arc<AtomicBool>,
}
141
142impl ClientInfo {
143 fn new(outbound_ref: Client) -> Self {
144 Self {
145 outbound_ref,
146 subscriptions: HashMap::default(),
147 subscription_ref_count: HashMap::default(),
148 legacy_subscriptions: HashSet::default(),
149 dropped: Arc::new(AtomicBool::new(false)),
150 }
151 }
152
153 #[cfg(test)]
155 fn assert_ref_count_consistency(&self) {
156 let mut expected_ref_count = HashMap::new();
157 for query_hashes in self.subscriptions.values() {
158 for query_hash in query_hashes {
159 assert!(
160 self.subscription_ref_count.contains_key(query_hash),
161 "Query hash not found: {query_hash:?}"
162 );
163 expected_ref_count
164 .entry(*query_hash)
165 .and_modify(|count| *count += 1)
166 .or_insert(1);
167 }
168 }
169 assert_eq!(
170 self.subscription_ref_count, expected_ref_count,
171 "Checking the reference totals failed"
172 );
173 }
174}
175
/// A unique query together with the clients subscribed to it.
#[derive(Debug)]
struct QueryState {
    /// The compiled query.
    query: Query,
    /// Clients subscribed through `set_legacy_subscription`.
    legacy_subscribers: HashSet<ClientId>,
    /// Clients subscribed through `add_subscription`/`add_subscription_multi`.
    subscriptions: HashSet<ClientId>,
}
185
186impl QueryState {
187 fn new(query: Query) -> Self {
188 Self {
189 query,
190 legacy_subscribers: HashSet::default(),
191 subscriptions: HashSet::default(),
192 }
193 }
194 fn has_subscribers(&self) -> bool {
195 !self.subscriptions.is_empty() || !self.legacy_subscribers.is_empty()
196 }
197
198 fn all_clients(&self) -> impl Iterator<Item = &ClientId> {
200 itertools::chain(&self.legacy_subscribers, &self.subscriptions)
201 }
202
203 pub fn query(&self) -> &Query {
205 &self.query
206 }
207
208 fn search_args(&self) -> impl Iterator<Item = (TableId, ColId, AlgebraicValue)> {
210 self.query.search_args()
211 }
212}
213
/// Subscribed queries indexed by the `(table, column, value)` search arguments of
/// their physical plans.
///
/// A changed row only selects queries whose argument value equals the row's value
/// in that column (see `queries_for_row`).
#[derive(Debug, Default)]
pub struct SearchArguments {
    /// For each table, the columns that have at least one registered search argument,
    /// so a changed row only has to be probed on those columns.
    cols: HashMap<TableId, HashSet<ColId>>,
    /// Each search argument mapped to the hashes of the queries that use it. Ordered so
    /// that all arguments of one `(table, column)` can be found with a range scan.
    args: BTreeMap<(TableId, ColId, AlgebraicValue), HashSet<QueryHash>>,
}
259
260impl SearchArguments {
261 fn search_params_for_table(&self, table_id: TableId) -> impl Iterator<Item = &ColId> + '_ {
263 self.cols.get(&table_id).into_iter().flatten()
264 }
265
266 fn queries_for_search_arg(
269 &self,
270 table_id: TableId,
271 col_id: ColId,
272 search_arg: AlgebraicValue,
273 ) -> impl Iterator<Item = &QueryHash> {
274 self.args.get(&(table_id, col_id, search_arg)).into_iter().flatten()
275 }
276
277 fn queries_for_row<'a>(&'a self, table_id: TableId, row: &'a ProductValue) -> impl Iterator<Item = &'a QueryHash> {
279 self.search_params_for_table(table_id)
280 .filter_map(|col_id| row.get_field(col_id.idx(), None).ok().map(|arg| (col_id, arg.clone())))
281 .flat_map(move |(col_id, arg)| self.queries_for_search_arg(table_id, *col_id, arg))
282 }
283
284 fn remove_query(&mut self, query: &Query) {
287 let mut params = query.search_args().collect::<Vec<_>>();
289
290 for key in ¶ms {
292 if let Some(hashes) = self.args.get_mut(key) {
293 hashes.remove(&query.hash);
294 if hashes.is_empty() {
295 self.args.remove(key);
296 }
297 }
298 }
299
300 params.retain(|(table_id, col_id, _)| {
302 self.args
303 .range((*table_id, *col_id, AlgebraicValue::Min)..=(*table_id, *col_id, AlgebraicValue::Max))
304 .next()
305 .is_none()
306 });
307
308 for (table_id, col_id, _) in params {
310 if let Some(col_ids) = self.cols.get_mut(&table_id) {
311 col_ids.remove(&col_id);
312 if col_ids.is_empty() {
313 self.cols.remove(&table_id);
314 }
315 }
316 }
317 }
318
319 fn insert_query(&mut self, table_id: TableId, col_id: ColId, arg: AlgebraicValue, query: QueryHash) {
321 self.args.entry((table_id, col_id, arg)).or_default().insert(query);
322 self.cols.entry(table_id).or_default().insert(col_id);
323 }
324}
325
/// Reference counts of the indexes used by subscribed queries, grouped by table.
///
/// An index stays listed while its count is non-zero.
#[derive(Debug, Default)]
pub struct QueriedTableIndexIds {
    /// table -> index -> number of registrations of that index.
    ids: HashMap<TableId, HashMap<IndexId, usize>>,
}
331
332impl FromIterator<(TableId, IndexId)> for QueriedTableIndexIds {
333 fn from_iter<T: IntoIterator<Item = (TableId, IndexId)>>(iter: T) -> Self {
334 let mut index_ids = Self::default();
335 for (table_id, index_id) in iter {
336 index_ids.insert_index_id(table_id, index_id);
337 }
338 index_ids
339 }
340}
341
342impl QueriedTableIndexIds {
343 pub fn index_ids_for_table(&self, table_id: TableId) -> impl Iterator<Item = IndexId> + '_ {
347 self.ids
348 .get(&table_id)
349 .into_iter()
350 .flat_map(|index_ids| index_ids.keys())
351 .copied()
352 }
353
354 pub fn insert_index_id(&mut self, table_id: TableId, index_id: IndexId) {
358 *self.ids.entry(table_id).or_default().entry(index_id).or_default() += 1;
359 }
360
361 pub fn delete_index_id(&mut self, table_id: TableId, index_id: IndexId) {
365 if let Some(ids) = self.ids.get_mut(&table_id) {
366 if let Some(n) = ids.get_mut(&index_id) {
367 *n -= 1;
368
369 if *n == 0 {
370 ids.remove(&index_id);
371
372 if ids.is_empty() {
373 self.ids.remove(&table_id);
374 }
375 }
376 }
377 }
378 }
379
380 pub fn insert_index_ids_for_query(&mut self, query: &Query) {
384 for (table_id, index_id) in query.index_ids() {
385 self.insert_index_id(table_id, index_id);
386 }
387 }
388
389 pub fn delete_index_ids_for_query(&mut self, query: &Query) {
393 for (table_id, index_id) in query.index_ids() {
394 self.delete_index_id(table_id, index_id);
395 }
396 }
397}
398
/// Subscribed queries indexed by join edge and by the rhs value the query filters on.
///
/// For a changed row of an edge's table, `queries_for_row` computes the rhs value by
/// following the edge (through a caller-supplied lookup). It then returns only the
/// queries registered under that value.
#[derive(Debug, Default)]
pub struct JoinEdges {
    /// edge -> rhs value -> hashes of the queries using that (edge, value).
    edges: BTreeMap<JoinEdge, HashMap<AlgebraicValue, HashSet<QueryHash>>>,
}
405
406impl JoinEdges {
407 fn add_query(&mut self, qs: &QueryState) -> bool {
409 let mut inserted = false;
410 for (edge, rhs_val) in qs.query.join_edges() {
411 inserted = true;
412 self.edges
413 .entry(edge)
414 .or_default()
415 .entry(rhs_val)
416 .or_default()
417 .insert(qs.query.hash);
418 }
419 inserted
420 }
421
422 fn remove_query(&mut self, query: &Query) {
424 for (edge, rhs_val) in query.join_edges() {
425 if let Some(values) = self.edges.get_mut(&edge) {
426 if let Some(hashes) = values.get_mut(&rhs_val) {
427 hashes.remove(&query.hash);
428 if hashes.is_empty() {
429 values.remove(&rhs_val);
430 if values.is_empty() {
431 self.edges.remove(&edge);
432 }
433 }
434 }
435 }
436 }
437 }
438
439 fn queries_for_row<'a>(
442 &'a self,
443 table_id: TableId,
444 row: &'a ProductValue,
445 find_rhs_val: impl Fn(&JoinEdge, &ProductValue) -> Option<AlgebraicValue>,
446 ) -> impl Iterator<Item = &'a QueryHash> {
447 self.edges
448 .range(JoinEdge::range_for_table(table_id))
449 .filter_map(move |(edge, hashes)| find_rhs_val(edge, row).as_ref().and_then(|rhs_val| hashes.get(rhs_val)))
450 .flatten()
451 }
452}
453
/// Tracks every client's subscriptions and indexes the subscribed queries.
///
/// The indexes let the queries affected by a transaction's table updates be found
/// without scanning all queries.
#[derive(Debug)]
pub struct SubscriptionManager {
    /// Bookkeeping for each client the manager knows about.
    clients: HashMap<ClientId, ClientInfo>,

    /// Every query with at least one subscriber, keyed by hash.
    queries: HashMap<QueryHash, QueryState>,

    /// Queries to re-evaluate on any change to a table. A table ends up here when the
    /// query could not be registered for it in `search_args` or `join_edges`.
    tables: IntMap<TableId, HashSet<QueryHash>>,

    /// Reference-counted set of indexes used by subscribed queries.
    indexes: QueriedTableIndexIds,

    /// Queries keyed by the `(table, column, value)` search arguments of their plans.
    search_args: SearchArguments,

    /// Queries keyed by the join edges they traverse.
    join_edges: JoinEdges,

    /// Queue to the background send worker that delivers messages to clients.
    send_worker_queue: BroadcastQueue,
}
498
/// One query fragment's update for one client, already encoded in that client's
/// wire format.
#[derive(Debug)]
struct ClientUpdate {
    /// The recipient.
    id: ClientId,
    /// The table whose rows the update carries.
    table_id: TableId,
    /// Name of `table_id`.
    table_name: TableName,
    /// The encoded rows, BSATN or JSON depending on the client's protocol.
    update: FormatSwitch<SingleQueryUpdate<BsatnFormat>, SingleQueryUpdate<JsonFormat>>,
}
507
/// The outcome of evaluating subscriptions against one transaction.
///
/// `SubscriptionManager::eval_updates_sequential` hands this to the send worker for delivery.
#[derive(Debug)]
struct ComputedQueries {
    /// Successful per-client updates.
    updates: Vec<ClientUpdate>,
    /// Per-client error messages for queries that failed to evaluate.
    errs: Vec<(ClientId, Box<str>)>,
    /// The module event whose transaction was evaluated.
    event: Arc<ModuleEvent>,
    /// The connection passed as `caller` to `eval_updates_sequential`, if any.
    caller: Option<Arc<ClientConnectionSender>>,
}
520
/// An unbounded channel sender paired with an optional gauge.
///
/// The gauge is incremented for every message sent. The receiver (`SendWorker::run`)
/// decrements it for every message received, so it tracks the queue length.
#[derive(Debug)]
struct SenderWithGauge<T> {
    /// The sending half of the channel.
    tx: mpsc::UnboundedSender<T>,
    /// Queue-length gauge, if metrics are enabled.
    metric: Option<IntGauge>,
}
527impl<T> Clone for SenderWithGauge<T> {
528 fn clone(&self) -> Self {
529 SenderWithGauge {
530 tx: self.tx.clone(),
531 metric: self.metric.clone(),
532 }
533 }
534}
535
536impl<T> SenderWithGauge<T> {
537 fn new(tx: mpsc::UnboundedSender<T>, metric: Option<IntGauge>) -> Self {
538 Self { tx, metric }
539 }
540
541 pub fn send(&self, msg: T) -> Result<(), mpsc::error::SendError<T>> {
543 if let Some(metric) = &self.metric {
544 metric.inc();
545 }
546 self.tx.send(msg)
548 }
549}
550
/// Messages processed, in channel order, by the send worker's `run` loop.
#[derive(Debug)]
enum SendWorkerMessage {
    /// Deliver the results of evaluating subscriptions against one transaction.
    Broadcast(ComputedQueries),

    /// Start tracking a client, so broadcasts can check whether it is dropped or cancelled.
    AddClient {
        /// The client being added.
        client_id: ClientId,
        /// The same flag as the manager's `ClientInfo::dropped` for this client.
        dropped: Arc<AtomicBool>,
        /// Handle used to reach the client.
        outbound_ref: Client,
    },

    /// Send a single message directly to `recipient`.
    SendMessage {
        /// The connection to send to.
        recipient: Arc<ClientConnectionSender>,
        /// The message to send.
        message: SerializableMessage,
    },

    /// Stop tracking a client.
    RemoveClient(ClientId),
}
579
/// Point-in-time subscription counts, as computed by
/// `SubscriptionManager::calculate_gauge_stats`.
pub struct SubscriptionGaugeStats {
    /// Number of unique queries with at least one subscriber.
    pub num_queries: usize,
    /// Number of clients the manager is tracking.
    pub num_connections: usize,
    /// Total number of subscription sets across all clients.
    pub num_subscription_sets: usize,
    /// Sum over all queries of the number of non-legacy clients subscribed to each.
    pub num_query_subscriptions: usize,
    /// Number of clients with at least one legacy subscription
    /// (not the number of legacy query subscriptions).
    pub num_legacy_subscriptions: usize,
}
593
594impl SubscriptionManager {
595 pub fn for_test_without_metrics_arc_rwlock() -> Arc<RwLock<Self>> {
596 Arc::new(RwLock::new(Self::for_test_without_metrics()))
597 }
598
599 pub fn for_test_without_metrics() -> Self {
600 Self::new(SendWorker::spawn_new(None))
601 }
602
603 pub fn new(send_worker_queue: BroadcastQueue) -> Self {
604 Self {
605 clients: Default::default(),
606 queries: Default::default(),
607 indexes: Default::default(),
608 tables: Default::default(),
609 search_args: Default::default(),
610 join_edges: Default::default(),
611 send_worker_queue,
612 }
613 }
614
615 pub fn query(&self, hash: &QueryHash) -> Option<Query> {
616 self.queries.get(hash).map(|state| state.query.clone())
617 }
618
619 pub fn calculate_gauge_stats(&self) -> SubscriptionGaugeStats {
620 let num_queries = self.queries.len();
621 let num_connections = self.clients.len();
622 let num_query_subscriptions = self.queries.values().map(|state| state.subscriptions.len()).sum();
623 let num_subscription_sets = self.clients.values().map(|ci| ci.subscriptions.len()).sum();
624 let num_legacy_subscriptions = self
625 .clients
626 .values()
627 .filter(|ci| !ci.legacy_subscriptions.is_empty())
628 .count();
629
630 SubscriptionGaugeStats {
631 num_queries,
632 num_connections,
633 num_query_subscriptions,
634 num_subscription_sets,
635 num_legacy_subscriptions,
636 }
637 }
638
639 fn get_or_make_client_info_and_inform_send_worker<'clients>(
644 clients: &'clients mut HashMap<ClientId, ClientInfo>,
645 send_worker_tx: &BroadcastQueue,
646 client_id: ClientId,
647 outbound_ref: Client,
648 ) -> &'clients mut ClientInfo {
649 clients.entry(client_id).or_insert_with(|| {
650 let info = ClientInfo::new(outbound_ref.clone());
651 send_worker_tx
652 .send(SendWorkerMessage::AddClient {
653 client_id,
654 dropped: info.dropped.clone(),
655 outbound_ref,
656 })
657 .expect("send worker has panicked, or otherwise dropped its recv queue!");
658 info
659 })
660 }
661
662 fn remove_client_and_inform_send_worker(&mut self, client_id: ClientId) -> Option<ClientInfo> {
665 self.clients.remove(&client_id).inspect(|_| {
666 self.send_worker_queue
667 .send(SendWorkerMessage::RemoveClient(client_id))
668 .expect("send worker has panicked, or otherwise dropped its recv queue!");
669 })
670 }
671
672 pub fn num_unique_queries(&self) -> usize {
673 self.queries.len()
674 }
675
676 #[cfg(test)]
677 fn contains_query(&self, hash: &QueryHash) -> bool {
678 self.queries.contains_key(hash)
679 }
680
681 #[cfg(test)]
682 fn contains_client(&self, subscriber: &ClientId) -> bool {
683 self.clients.contains_key(subscriber)
684 }
685
686 #[cfg(test)]
687 fn contains_legacy_subscription(&self, subscriber: &ClientId, query: &QueryHash) -> bool {
688 self.queries
689 .get(query)
690 .is_some_and(|state| state.legacy_subscribers.contains(subscriber))
691 }
692
693 #[cfg(test)]
694 fn query_reads_from_table(&self, query: &QueryHash, table: &TableId) -> bool {
695 self.tables.get(table).is_some_and(|queries| queries.contains(query))
696 }
697
698 #[cfg(test)]
699 fn query_has_search_arg(&self, query: QueryHash, table_id: TableId, col_id: ColId, arg: AlgebraicValue) -> bool {
700 self.search_args
701 .queries_for_search_arg(table_id, col_id, arg)
702 .any(|hash| *hash == query)
703 }
704
705 #[cfg(test)]
706 fn table_has_search_param(&self, table_id: TableId, col_id: ColId) -> bool {
707 self.search_args
708 .search_params_for_table(table_id)
709 .any(|id| *id == col_id)
710 }
711
712 fn remove_legacy_subscriptions(&mut self, client: &ClientId) {
713 if let Some(ci) = self.clients.get_mut(client) {
714 let mut queries_to_remove = Vec::new();
715 for query_hash in ci.legacy_subscriptions.iter() {
716 let Some(query_state) = self.queries.get_mut(query_hash) else {
717 tracing::warn!("Query state not found for query hash: {:?}", query_hash);
718 continue;
719 };
720
721 query_state.legacy_subscribers.remove(client);
722 if !query_state.has_subscribers() {
723 SubscriptionManager::remove_query_from_tables(
724 &mut self.tables,
725 &mut self.join_edges,
726 &mut self.indexes,
727 &mut self.search_args,
728 &query_state.query,
729 );
730 queries_to_remove.push(*query_hash);
731 }
732 }
733 ci.legacy_subscriptions.clear();
734 for query_hash in queries_to_remove {
735 self.queries.remove(&query_hash);
736 }
737 }
738 }
739
740 pub fn remove_dropped_clients(&mut self) {
742 for id in self.clients.keys().copied().collect::<Vec<_>>() {
743 if let Some(client) = self.clients.get(&id) {
744 if client.dropped.load(Ordering::Relaxed) {
745 self.remove_all_subscriptions(&id);
746 }
747 }
748 }
749 }
750
751 pub fn remove_subscription(&mut self, client_id: ClientId, query_id: ClientQueryId) -> Result<Vec<Query>, DBError> {
754 let subscription_id = (client_id, query_id);
755 let Some(ci) = self
756 .clients
757 .get_mut(&client_id)
758 .filter(|ci| !ci.dropped.load(Ordering::Acquire))
759 else {
760 return Err(anyhow::anyhow!("Client not found: {:?}", client_id).into());
761 };
762
763 #[cfg(test)]
764 ci.assert_ref_count_consistency();
765
766 let Some(query_hashes) = ci.subscriptions.remove(&subscription_id) else {
767 return Err(anyhow::anyhow!("Subscription not found: {:?}", subscription_id).into());
768 };
769 let mut queries_to_return = Vec::new();
770 for hash in query_hashes {
771 let remaining_refs = {
772 let Some(count) = ci.subscription_ref_count.get_mut(&hash) else {
773 return Err(anyhow::anyhow!("Query count not found for query hash: {:?}", hash).into());
774 };
775 *count -= 1;
776 *count
777 };
778 if remaining_refs > 0 {
779 continue;
781 }
782 ci.subscription_ref_count.remove(&hash);
784 let Some(query_state) = self.queries.get_mut(&hash) else {
785 return Err(anyhow::anyhow!("Query state not found for query hash: {:?}", hash).into());
786 };
787 queries_to_return.push(query_state.query.clone());
788 query_state.subscriptions.remove(&client_id);
789 if !query_state.has_subscribers() {
790 SubscriptionManager::remove_query_from_tables(
791 &mut self.tables,
792 &mut self.join_edges,
793 &mut self.indexes,
794 &mut self.search_args,
795 &query_state.query,
796 );
797 self.queries.remove(&hash);
798 }
799 }
800
801 #[cfg(test)]
802 ci.assert_ref_count_consistency();
803
804 Ok(queries_to_return)
805 }
806
807 pub fn add_subscription(&mut self, client: Client, query: Query, query_id: ClientQueryId) -> Result<(), DBError> {
809 self.add_subscription_multi(client, vec![query], query_id).map(|_| ())
810 }
811
812 pub fn add_subscription_multi(
813 &mut self,
814 client: Client,
815 queries: Vec<Query>,
816 query_id: ClientQueryId,
817 ) -> Result<Vec<Query>, DBError> {
818 let client_id = (client.id.identity, client.id.connection_id);
819
820 if self
822 .clients
823 .get(&client_id)
824 .is_some_and(|ci| ci.dropped.load(Ordering::Acquire))
825 {
826 self.remove_all_subscriptions(&client_id);
827 }
828
829 let ci = Self::get_or_make_client_info_and_inform_send_worker(
830 &mut self.clients,
831 &self.send_worker_queue,
832 client_id,
833 client,
834 );
835
836 #[cfg(test)]
837 ci.assert_ref_count_consistency();
838 let subscription_id = (client_id, query_id);
839 let hash_set = match ci.subscriptions.try_insert(subscription_id, HashSet::new()) {
840 Err(OccupiedError { .. }) => {
841 return Err(anyhow::anyhow!(
842 "Subscription with id {:?} already exists for client: {:?}",
843 query_id,
844 client_id
845 )
846 .into());
847 }
848 Ok(hash_set) => hash_set,
849 };
850 let mut new_queries = Vec::new();
852
853 for query in &queries {
854 let hash = query.hash();
855 if !hash_set.insert(hash) {
857 continue;
858 }
859 let query_state = self
860 .queries
861 .entry(hash)
862 .or_insert_with(|| QueryState::new(query.clone()));
863
864 Self::insert_query(
865 &mut self.tables,
866 &mut self.join_edges,
867 &mut self.indexes,
868 &mut self.search_args,
869 query_state,
870 );
871
872 let entry = ci.subscription_ref_count.entry(hash).or_insert(0);
873 *entry += 1;
874 let is_new_entry = *entry == 1;
875
876 let inserted = query_state.subscriptions.insert(client_id);
877 if inserted != is_new_entry {
879 return Err(anyhow::anyhow!("Internal error, ref count and query_state mismatch").into());
880 }
881 if inserted {
882 new_queries.push(query.clone());
883 }
884 }
885
886 #[cfg(test)]
887 {
888 ci.assert_ref_count_consistency();
889 }
890
891 Ok(new_queries)
892 }
893
894 pub fn set_legacy_subscription(&mut self, client: Client, queries: impl IntoIterator<Item = Query>) {
901 let client_id = (client.id.identity, client.id.connection_id);
902 self.remove_legacy_subscriptions(&client_id);
904
905 let ci = Self::get_or_make_client_info_and_inform_send_worker(
907 &mut self.clients,
908 &self.send_worker_queue,
909 client_id,
910 client,
911 );
912
913 for unit in queries {
914 let hash = unit.hash();
915 ci.legacy_subscriptions.insert(hash);
916 let query_state = self
917 .queries
918 .entry(hash)
919 .or_insert_with(|| QueryState::new(unit.clone()));
920 Self::insert_query(
921 &mut self.tables,
922 &mut self.join_edges,
923 &mut self.indexes,
924 &mut self.search_args,
925 query_state,
926 );
927 query_state.legacy_subscribers.insert(client_id);
928 }
929 }
930
931 fn remove_query_from_tables(
935 tables: &mut IntMap<TableId, HashSet<QueryHash>>,
936 join_edges: &mut JoinEdges,
937 index_ids: &mut QueriedTableIndexIds,
938 search_args: &mut SearchArguments,
939 query: &Query,
940 ) {
941 let hash = query.hash();
942 join_edges.remove_query(query);
943 search_args.remove_query(query);
944 index_ids.delete_index_ids_for_query(query);
945 for table_id in query.table_ids() {
946 if let Entry::Occupied(mut entry) = tables.entry(table_id) {
947 let hashes = entry.get_mut();
948 if hashes.remove(&hash) && hashes.is_empty() {
949 entry.remove();
950 }
951 }
952 }
953 }
954
955 fn insert_query(
959 tables: &mut IntMap<TableId, HashSet<QueryHash>>,
960 join_edges: &mut JoinEdges,
961 index_ids: &mut QueriedTableIndexIds,
962 search_args: &mut SearchArguments,
963 query_state: &QueryState,
964 ) {
965 if !query_state.has_subscribers() {
967 let hash = query_state.query.hash;
968 let query = query_state.query();
969 let return_table = query.subscribed_table_id();
970 let mut table_ids = query.table_ids().collect::<HashSet<_>>();
971
972 index_ids.insert_index_ids_for_query(query);
974
975 for (table_id, col_id, arg) in query_state.search_args() {
977 table_ids.remove(&table_id);
978 search_args.insert_query(table_id, col_id, arg, hash);
979 }
980
981 if table_ids.contains(&return_table) && join_edges.add_query(query_state) {
983 table_ids.remove(&return_table);
984 }
985
986 for table_id in table_ids {
988 tables.entry(table_id).or_default().insert(hash);
989 }
990 }
991 }
992
993 #[tracing::instrument(level = "trace", skip_all)]
997 pub fn remove_all_subscriptions(&mut self, client: &ClientId) {
998 self.remove_legacy_subscriptions(client);
999 let Some(client_info) = self.remove_client_and_inform_send_worker(*client) else {
1000 return;
1001 };
1002
1003 debug_assert!(client_info.legacy_subscriptions.is_empty());
1004 let mut queries_to_remove = Vec::new();
1005 for query_hash in client_info.subscription_ref_count.keys() {
1006 let Some(query_state) = self.queries.get_mut(query_hash) else {
1007 tracing::warn!("Query state not found for query hash: {:?}", query_hash);
1008 return;
1009 };
1010 query_state.subscriptions.remove(client);
1011 if !query_state.has_subscribers() {
1013 queries_to_remove.push(*query_hash);
1014 SubscriptionManager::remove_query_from_tables(
1015 &mut self.tables,
1016 &mut self.join_edges,
1017 &mut self.indexes,
1018 &mut self.search_args,
1019 &query_state.query,
1020 );
1021 }
1022 }
1023 for query_hash in queries_to_remove {
1024 self.queries.remove(&query_hash);
1025 }
1026 }
1027
1028 fn queries_for_table_update<'a>(
1056 &'a self,
1057 table_update: &'a DatabaseTableUpdate,
1058 find_rhs_val: &impl Fn(&JoinEdge, &ProductValue) -> Option<AlgebraicValue>,
1059 ) -> impl Iterator<Item = &'a QueryHash> {
1060 let mut queries = HashSet::new();
1061 for hash in table_update
1062 .inserts
1063 .iter()
1064 .chain(table_update.deletes.iter())
1065 .flat_map(|row| self.queries_for_row(table_update.table_id, row, find_rhs_val))
1066 {
1067 queries.insert(hash);
1068 }
1069 for hash in self.tables.get(&table_update.table_id).into_iter().flatten() {
1070 queries.insert(hash);
1071 }
1072 queries.into_iter()
1073 }
1074
1075 fn queries_for_row<'a>(
1077 &'a self,
1078 table_id: TableId,
1079 row: &'a ProductValue,
1080 find_rhs_val: impl Fn(&JoinEdge, &ProductValue) -> Option<AlgebraicValue>,
1081 ) -> impl Iterator<Item = &'a QueryHash> {
1082 self.search_args
1083 .queries_for_row(table_id, row)
1084 .chain(self.join_edges.queries_for_row(table_id, row, find_rhs_val))
1085 }
1086
1087 pub fn index_ids_for_subscriptions(&self) -> &QueriedTableIndexIds {
1089 &self.indexes
1090 }
1091
1092 #[tracing::instrument(level = "trace", skip_all)]
1101 pub fn eval_updates_sequential(
1102 &self,
1103 tx: &DeltaTx,
1104 event: Arc<ModuleEvent>,
1105 caller: Option<Arc<ClientConnectionSender>>,
1106 ) -> ExecutionMetrics {
1107 use FormatSwitch::{Bsatn, Json};
1108
1109 let tables = &event.status.database_update().unwrap().tables;
1110
1111 let span = tracing::info_span!("eval_incr").entered();
1112
1113 #[derive(Default)]
1114 struct FoldState {
1115 updates: Vec<ClientUpdate>,
1116 errs: Vec<(ClientId, Box<str>)>,
1117 metrics: ExecutionMetrics,
1118 }
1119
1120 fn find_rhs_val(edge: &JoinEdge, row: &ProductValue, tx: &DeltaTx) -> Option<AlgebraicValue> {
1122 tx.iter_by_col_eq(
1129 edge.rhs_table,
1130 edge.rhs_join_col,
1131 &row.elements[edge.lhs_join_col.idx()],
1132 )
1133 .expect("This read should always succeed, and it's a bug if it doesn't")
1134 .next()
1135 .map(|row| {
1136 row.read_col(edge.rhs_col)
1137 .expect("This read should always succeed, and it's a bug if it doesn't")
1138 })
1139 }
1140
1141 let FoldState { updates, errs, metrics } = tables
1142 .iter()
1143 .filter(|table| !table.inserts.is_empty() || !table.deletes.is_empty())
1144 .flat_map(|table_update| {
1145 self.queries_for_table_update(table_update, &|edge, row| find_rhs_val(edge, row, tx))
1146 })
1147 .filter({
1149 let mut seen = HashSet::new();
1150 move |&hash| seen.insert(hash)
1152 })
1153 .flat_map(|hash| {
1154 let qstate = &self.queries[hash];
1155 qstate
1156 .query
1157 .plans_fragments()
1158 .map(move |plan_fragment| (qstate, plan_fragment))
1159 })
1160 .fold(FoldState::default(), |mut acc, (qstate, plan)| {
1164 let table_id = plan.subscribed_table_id();
1165 let table_name = plan.subscribed_table_name().clone();
1166 let mut ops_bin_uncompressed: Option<(CompressableQueryUpdate<BsatnFormat>, _, _)> = None;
1187 let mut ops_json: Option<(QueryUpdate<JsonFormat>, _, _)> = None;
1188
1189 fn memo_encode<F: WebsocketFormat>(
1190 updates: &UpdatesRelValue<'_>,
1191 memory: &mut Option<(F::QueryUpdate, u64, usize)>,
1192 metrics: &mut ExecutionMetrics,
1193 ) -> SingleQueryUpdate<F> {
1194 let (update, num_rows, num_bytes) = memory
1195 .get_or_insert_with(|| {
1196 let encoded = updates.encode::<F>();
1197 metrics.bytes_scanned += encoded.2;
1201 encoded
1202 })
1203 .clone();
1204 metrics.bytes_sent_to_clients += num_bytes;
1209 SingleQueryUpdate { update, num_rows }
1210 }
1211
1212 let clients_for_query = qstate.all_clients();
1213
1214 match eval_delta(tx, &mut acc.metrics, plan) {
1215 Err(err) => {
1216 tracing::error!(
1217 message = "Query errored during tx update",
1218 sql = qstate.query.sql,
1219 reason = ?err,
1220 );
1221 let err = DBError::WithSql {
1222 sql: qstate.query.sql.as_str().into(),
1223 error: Box::new(err.into()),
1224 }
1225 .to_string()
1226 .into_boxed_str();
1227
1228 acc.errs.extend(clients_for_query.map(|id| (*id, err.clone())))
1229 }
1230 Ok(None) => {}
1232 Ok(Some(delta_updates)) => {
1234 let row_iter = clients_for_query.map(|id| {
1235 let client = &self.clients[id].outbound_ref;
1236 let update = match client.config.protocol {
1237 Protocol::Binary => Bsatn(memo_encode::<BsatnFormat>(
1238 &delta_updates,
1239 &mut ops_bin_uncompressed,
1240 &mut acc.metrics,
1241 )),
1242 Protocol::Text => Json(memo_encode::<JsonFormat>(
1243 &delta_updates,
1244 &mut ops_json,
1245 &mut acc.metrics,
1246 )),
1247 };
1248 ClientUpdate {
1249 id: *id,
1250 table_id,
1251 table_name: table_name.clone(),
1252 update,
1253 }
1254 });
1255 acc.updates.extend(row_iter);
1256 }
1257 }
1258
1259 acc
1260 });
1261
1262 self.send_worker_queue
1267 .send(SendWorkerMessage::Broadcast(ComputedQueries {
1268 updates,
1269 errs,
1270 event,
1271 caller,
1272 }))
1273 .expect("send worker has panicked, or otherwise dropped its recv queue!");
1274
1275 drop(span);
1276
1277 metrics
1278 }
1279}
1280
/// The send worker's view of a client.
struct SendWorkerClient {
    /// The same flag as the manager's `ClientInfo::dropped` for this client
    /// (shared via `SendWorkerMessage::AddClient`).
    dropped: Arc<AtomicBool>,
    /// Handle used to reach the client.
    outbound_ref: Client,
}
1290
1291impl SendWorkerClient {
1292 fn is_dropped(&self) -> bool {
1293 self.dropped.load(Ordering::Relaxed)
1294 }
1295
1296 fn is_cancelled(&self) -> bool {
1297 self.outbound_ref.is_cancelled()
1298 }
1299}
1300
/// Background task that delivers messages to clients, fed through a [`BroadcastQueue`].
struct SendWorker {
    /// Incoming messages, processed in order by `run`.
    rx: mpsc::UnboundedReceiver<SendWorkerMessage>,

    /// Queue-length gauge, decremented once per received message
    /// (the sending side increments it).
    queue_length_metric: Option<IntGauge>,

    /// Clients announced via `AddClient` and not yet removed via `RemoveClient`.
    clients: HashMap<ClientId, SendWorkerClient>,

    /// If set, this identity's queue-length gauge label is removed when the worker is dropped.
    database_identity_to_clean_up_metric: Option<Identity>,

    /// Scratch map for grouping updates by (client, table). It is taken with `mem::take` in
    /// `send_one_computed_queries`; presumably it is put back afterward to reuse its
    /// allocation — TODO confirm.
    table_updates_client_id_table_id: HashMap<(ClientId, TableId), SwitchedTableUpdate>,

    /// Scratch map for grouping table updates by client; handled like the map above.
    table_updates_client_id: HashMap<ClientId, SwitchedDbUpdate>,
}
1336
1337impl Drop for SendWorker {
1338 fn drop(&mut self) {
1339 if let Some(identity) = self.database_identity_to_clean_up_metric {
1340 let _ = WORKER_METRICS
1341 .subscription_send_queue_length
1342 .remove_label_values(&identity);
1343 }
1344 }
1345}
1346
1347impl SendWorker {
1348 fn is_client_dropped_or_cancelled(&self, client_id: &ClientId) -> bool {
1349 self.clients
1350 .get(client_id)
1351 .is_some_and(|client| client.is_cancelled() || client.is_dropped())
1352 }
1353}
1354
/// Cloneable handle for enqueuing messages to the send worker.
/// The underlying channel is unbounded, so sending never waits.
#[derive(Debug, Clone)]
pub struct BroadcastQueue(SenderWithGauge<SendWorkerMessage>);
1357
/// A message could not be enqueued because the send worker's receiving end is closed.
/// The unsent message is carried inside the wrapped `SendError`.
#[derive(thiserror::Error, Debug)]
#[error(transparent)]
pub struct BroadcastError(#[from] mpsc::error::SendError<SendWorkerMessage>);
1361
1362impl BroadcastQueue {
1363 fn send(&self, message: SendWorkerMessage) -> Result<(), BroadcastError> {
1364 self.0.send(message)?;
1365 Ok(())
1366 }
1367
1368 pub fn send_client_message(
1369 &self,
1370 recipient: Arc<ClientConnectionSender>,
1371 message: impl Into<SerializableMessage>,
1372 ) -> Result<(), BroadcastError> {
1373 self.0.send(SendWorkerMessage::SendMessage {
1374 recipient,
1375 message: message.into(),
1376 })?;
1377 Ok(())
1378 }
1379}
/// Spawns a send worker task and returns the queue that feeds it.
///
/// When `metric_database_identity` is `Some`, the queue length is reported on the
/// `subscription_send_queue_length` gauge under that identity's label.
pub fn spawn_send_worker(metric_database_identity: Option<Identity>) -> BroadcastQueue {
    SendWorker::spawn_new(metric_database_identity)
}
1383impl SendWorker {
1384 fn new(
1385 rx: mpsc::UnboundedReceiver<SendWorkerMessage>,
1386 queue_length_metric: Option<IntGauge>,
1387 database_identity_to_clean_up_metric: Option<Identity>,
1388 ) -> Self {
1389 Self {
1390 rx,
1391 queue_length_metric,
1392 clients: Default::default(),
1393 database_identity_to_clean_up_metric,
1394 table_updates_client_id_table_id: <_>::default(),
1395 table_updates_client_id: <_>::default(),
1396 }
1397 }
1398
1399 fn spawn_new(metric_database_identity: Option<Identity>) -> BroadcastQueue {
1403 let metric = metric_database_identity.map(|identity| {
1404 WORKER_METRICS
1405 .subscription_send_queue_length
1406 .with_label_values(&identity)
1407 });
1408 let (send_worker_tx, rx) = mpsc::unbounded_channel();
1409 tokio::spawn(Self::new(rx, metric.clone(), metric_database_identity).run());
1410 BroadcastQueue(SenderWithGauge::new(send_worker_tx, metric))
1411 }
1412
1413 async fn run(mut self) {
1414 while let Some(message) = self.rx.recv().await {
1415 if let Some(metric) = &self.queue_length_metric {
1416 metric.dec();
1417 }
1418
1419 match message {
1420 SendWorkerMessage::AddClient {
1421 client_id,
1422 dropped,
1423 outbound_ref,
1424 } => {
1425 self.clients
1426 .insert(client_id, SendWorkerClient { dropped, outbound_ref });
1427 }
1428 SendWorkerMessage::SendMessage { recipient, message } => {
1429 let _ = recipient.send_message(message);
1430 }
1431 SendWorkerMessage::RemoveClient(client_id) => {
1432 self.clients.remove(&client_id);
1433 }
1434 SendWorkerMessage::Broadcast(queries) => {
1435 self.send_one_computed_queries(queries);
1436 }
1437 }
1438 }
1439 }
1440
1441 fn send_one_computed_queries(
1442 &mut self,
1443 ComputedQueries {
1444 updates,
1445 errs,
1446 event,
1447 caller,
1448 }: ComputedQueries,
1449 ) {
1450 use FormatSwitch::{Bsatn, Json};
1451
1452 let clients_with_errors = errs.iter().map(|(id, _)| id).collect::<HashSet<_>>();
1453
1454 let span = tracing::info_span!("eval_incr_group_messages_by_client");
1455
1456 let client_table_id_updates = mem::take(&mut self.table_updates_client_id_table_id);
1458 let client_id_updates = mem::take(&mut self.table_updates_client_id);
1459
1460 let mut client_table_id_updates = updates
1466 .into_iter()
1467 .filter(|upd| !self.is_client_dropped_or_cancelled(&upd.id))
1469 .filter(|upd| !clients_with_errors.contains(&upd.id))
1471 .fold(client_table_id_updates, |mut tables, upd| {
1473 match tables.entry((upd.id, upd.table_id)) {
1474 Entry::Occupied(mut entry) => match entry.get_mut().zip_mut(upd.update) {
1475 Bsatn((tbl_upd, update)) => tbl_upd.push(update),
1476 Json((tbl_upd, update)) => tbl_upd.push(update),
1477 },
1478 Entry::Vacant(entry) => drop(entry.insert(match upd.update {
1479 Bsatn(update) => Bsatn(TableUpdate::new(upd.table_id, (&*upd.table_name).into(), update)),
1480 Json(update) => Json(TableUpdate::new(upd.table_id, (&*upd.table_name).into(), update)),
1481 })),
1482 }
1483 tables
1484 });
1485
1486 let mut client_id_updates = client_table_id_updates
1490 .drain()
1491 .fold(client_id_updates, |mut updates, ((id, _), update)| {
1493 let entry = updates.entry(id);
1494 let entry = entry.or_insert_with(|| match &update {
1495 Bsatn(_) => Bsatn(<_>::default()),
1496 Json(_) => Json(<_>::default()),
1497 });
1498 match entry.zip_mut(update) {
1499 Bsatn((list, elem)) => list.tables.push(elem),
1500 Json((list, elem)) => list.tables.push(elem),
1501 }
1502 updates
1503 });
1504
1505 drop(clients_with_errors);
1506 drop(span);
1507
1508 let _span = tracing::info_span!("eval_send").entered();
1509
1510 if let Some(caller) = caller {
1517 let caller_id = (caller.id.identity, caller.id.connection_id);
1518 let database_update = client_id_updates
1519 .remove(&caller_id)
1520 .map(|update| SubscriptionUpdateMessage::from_event_and_update(&event, update))
1521 .unwrap_or_else(|| {
1522 SubscriptionUpdateMessage::default_for_protocol(caller.config.protocol, event.request_id)
1523 });
1524 let message = TransactionUpdateMessage {
1525 event: Some(event.clone()),
1526 database_update,
1527 };
1528 send_to_client(&caller, message);
1529 }
1530
1531 for (id, update) in client_id_updates.drain() {
1533 let database_update = SubscriptionUpdateMessage::from_event_and_update(&event, update);
1534 let client = self.clients[&id].outbound_ref.clone();
1535 let event = client.config.tx_update_full.then(|| event.clone());
1537 let message = TransactionUpdateMessage { event, database_update };
1538 send_to_client(&client, message);
1539 }
1540
1541 self.table_updates_client_id_table_id = client_table_id_updates;
1543 self.table_updates_client_id = client_id_updates;
1544
1545 for (id, message) in errs {
1547 if let Some(client) = self.clients.get(&id) {
1548 client.dropped.store(true, Ordering::Release);
1549 send_to_client(
1550 &client.outbound_ref,
1551 SubscriptionMessage {
1552 request_id: None,
1553 query_id: None,
1554 timer: None,
1555 result: SubscriptionResult::Error(SubscriptionError {
1556 table_id: None,
1557 message,
1558 }),
1559 },
1560 );
1561 }
1562 }
1563 }
1564}
1565
1566fn send_to_client(client: &ClientConnectionSender, message: impl Into<SerializableMessage>) {
1567 if let Err(e) = client.send_message(message) {
1568 tracing::warn!(%client.id, "failed to send update message to client: {e}")
1569 }
1570}
1571
1572#[cfg(test)]
1573mod tests {
1574 use std::{sync::Arc, time::Duration};
1575
1576 use spacetimedb_client_api_messages::websocket::QueryId;
1577 use spacetimedb_lib::AlgebraicValue;
1578 use spacetimedb_lib::{error::ResultTest, identity::AuthCtx, AlgebraicType, ConnectionId, Identity, Timestamp};
1579 use spacetimedb_primitives::{ColId, TableId};
1580 use spacetimedb_sats::product;
1581 use spacetimedb_subscription::SubscriptionPlan;
1582
1583 use super::{Plan, SubscriptionManager};
1584 use crate::db::relational_db::tests_utils::with_read_only;
1585 use crate::host::module_host::DatabaseTableUpdate;
1586 use crate::sql::ast::SchemaViewer;
1587 use crate::subscription::module_subscription_manager::ClientQueryId;
1588 use crate::{
1589 client::{ClientActorId, ClientConfig, ClientConnectionSender, ClientName},
1590 db::relational_db::{tests_utils::TestDB, RelationalDB},
1591 energy::EnergyQuanta,
1592 host::{
1593 module_host::{DatabaseUpdate, EventStatus, ModuleEvent, ModuleFunctionCall},
1594 ArgsTuple,
1595 },
1596 subscription::execution_unit::QueryHash,
1597 };
1598 use spacetimedb_datastore::execution_context::Workload;
1599
1600 fn create_table(db: &RelationalDB, name: &str) -> ResultTest<TableId> {
1601 Ok(db.create_table_for_test(name, &[("a", AlgebraicType::U8)], &[])?)
1602 }
1603
1604 fn compile_plan(db: &RelationalDB, sql: &str) -> ResultTest<Arc<Plan>> {
1605 with_read_only(db, |tx| {
1606 let auth = AuthCtx::for_testing();
1607 let tx = SchemaViewer::new(&*tx, &auth);
1608 let (plans, has_param) = SubscriptionPlan::compile(sql, &tx, &auth).unwrap();
1609 let hash = QueryHash::from_string(sql, auth.caller, has_param);
1610 Ok(Arc::new(Plan::new(plans, hash, sql.into())))
1611 })
1612 }
1613
1614 fn id(connection_id: u128) -> (Identity, ConnectionId) {
1615 (Identity::ZERO, ConnectionId::from_u128(connection_id))
1616 }
1617
1618 fn client(connection_id: u128) -> ClientConnectionSender {
1619 let (identity, connection_id) = id(connection_id);
1620 ClientConnectionSender::dummy(
1621 ClientActorId {
1622 identity,
1623 connection_id,
1624 name: ClientName(0),
1625 },
1626 ClientConfig::for_test(),
1627 )
1628 }
1629
1630 #[test]
1631 fn test_subscribe_legacy() -> ResultTest<()> {
1632 let db = TestDB::durable()?;
1633
1634 let table_id = create_table(&db, "T")?;
1635 let sql = "select * from T";
1636 let plan = compile_plan(&db, sql)?;
1637 let hash = plan.hash();
1638
1639 let id = id(0);
1640 let client = Arc::new(client(0));
1641
1642 let runtime = tokio::runtime::Runtime::new().unwrap();
1643 let _rt = runtime.enter();
1644
1645 let mut subscriptions = SubscriptionManager::for_test_without_metrics();
1646 subscriptions.set_legacy_subscription(client.clone(), [plan.clone()]);
1647
1648 assert!(subscriptions.contains_query(&hash));
1649 assert!(subscriptions.contains_legacy_subscription(&id, &hash));
1650 assert!(subscriptions.query_reads_from_table(&hash, &table_id));
1651
1652 Ok(())
1653 }
1654
1655 #[test]
1656 fn test_subscribe_single_adds_table_mapping() -> ResultTest<()> {
1657 let db = TestDB::durable()?;
1658
1659 let table_id = create_table(&db, "T")?;
1660 let sql = "select * from T";
1661 let plan = compile_plan(&db, sql)?;
1662 let hash = plan.hash();
1663
1664 let client = Arc::new(client(0));
1665
1666 let query_id: ClientQueryId = QueryId::new(1);
1667
1668 let runtime = tokio::runtime::Runtime::new().unwrap();
1669 let _rt = runtime.enter();
1670
1671 let mut subscriptions = SubscriptionManager::for_test_without_metrics();
1672 subscriptions.add_subscription(client.clone(), plan.clone(), query_id)?;
1673 assert!(subscriptions.query_reads_from_table(&hash, &table_id));
1674
1675 Ok(())
1676 }
1677
1678 #[test]
1679 fn test_unsubscribe_from_the_only_subscription() -> ResultTest<()> {
1680 let db = TestDB::durable()?;
1681
1682 let table_id = create_table(&db, "T")?;
1683 let sql = "select * from T";
1684 let plan = compile_plan(&db, sql)?;
1685 let hash = plan.hash();
1686
1687 let client = Arc::new(client(0));
1688
1689 let query_id: ClientQueryId = QueryId::new(1);
1690
1691 let runtime = tokio::runtime::Runtime::new().unwrap();
1692 let _rt = runtime.enter();
1693
1694 let mut subscriptions = SubscriptionManager::for_test_without_metrics();
1695 subscriptions.add_subscription(client.clone(), plan.clone(), query_id)?;
1696 assert!(subscriptions.query_reads_from_table(&hash, &table_id));
1697
1698 let client_id = (client.id.identity, client.id.connection_id);
1699 subscriptions.remove_subscription(client_id, query_id)?;
1700 assert!(!subscriptions.query_reads_from_table(&hash, &table_id));
1701
1702 Ok(())
1703 }
1704
1705 #[test]
1706 fn test_unsubscribe_with_unknown_query_id_fails() -> ResultTest<()> {
1707 let db = TestDB::durable()?;
1708
1709 create_table(&db, "T")?;
1710 let sql = "select * from T";
1711 let plan = compile_plan(&db, sql)?;
1712
1713 let client = Arc::new(client(0));
1714
1715 let query_id: ClientQueryId = QueryId::new(1);
1716
1717 let runtime = tokio::runtime::Runtime::new().unwrap();
1718 let _rt = runtime.enter();
1719
1720 let mut subscriptions = SubscriptionManager::for_test_without_metrics();
1721 subscriptions.add_subscription(client.clone(), plan.clone(), query_id)?;
1722
1723 let client_id = (client.id.identity, client.id.connection_id);
1724 assert!(subscriptions.remove_subscription(client_id, QueryId::new(2)).is_err());
1725
1726 Ok(())
1727 }
1728
1729 #[test]
1730 fn test_subscribe_and_unsubscribe_with_duplicate_queries() -> ResultTest<()> {
1731 let db = TestDB::durable()?;
1732
1733 let table_id = create_table(&db, "T")?;
1734 let sql = "select * from T";
1735 let plan = compile_plan(&db, sql)?;
1736 let hash = plan.hash();
1737
1738 let client = Arc::new(client(0));
1739
1740 let query_id: ClientQueryId = QueryId::new(1);
1741
1742 let runtime = tokio::runtime::Runtime::new().unwrap();
1743 let _rt = runtime.enter();
1744
1745 let mut subscriptions = SubscriptionManager::for_test_without_metrics();
1746 subscriptions.add_subscription(client.clone(), plan.clone(), query_id)?;
1747 subscriptions.add_subscription(client.clone(), plan.clone(), QueryId::new(2))?;
1748
1749 let client_id = (client.id.identity, client.id.connection_id);
1750 subscriptions.remove_subscription(client_id, query_id)?;
1751
1752 assert!(subscriptions.query_reads_from_table(&hash, &table_id));
1753
1754 Ok(())
1755 }
1756
1757 #[test]
1759 fn test_subscribe_and_unsubscribe_with_duplicate_queries_multi() -> ResultTest<()> {
1760 let db = TestDB::durable()?;
1761
1762 let table_id = create_table(&db, "T")?;
1763 let sql = "select * from T";
1764 let plan = compile_plan(&db, sql)?;
1765 let hash = plan.hash();
1766
1767 let client = Arc::new(client(0));
1768
1769 let query_id: ClientQueryId = QueryId::new(1);
1770
1771 let runtime = tokio::runtime::Runtime::new().unwrap();
1772 let _rt = runtime.enter();
1773
1774 let mut subscriptions = SubscriptionManager::for_test_without_metrics();
1775 let added_query = subscriptions.add_subscription_multi(client.clone(), vec![plan.clone()], query_id)?;
1776 assert!(added_query.len() == 1);
1777 assert_eq!(added_query[0].hash, hash);
1778 let second_one = subscriptions.add_subscription_multi(client.clone(), vec![plan.clone()], QueryId::new(2))?;
1779 assert!(second_one.is_empty());
1780
1781 let client_id = (client.id.identity, client.id.connection_id);
1782 let removed_queries = subscriptions.remove_subscription(client_id, query_id)?;
1783 assert!(removed_queries.is_empty());
1784
1785 assert!(subscriptions.query_reads_from_table(&hash, &table_id));
1786 let removed_queries = subscriptions.remove_subscription(client_id, QueryId::new(2))?;
1787 assert!(removed_queries.len() == 1);
1788 assert_eq!(removed_queries[0].hash, hash);
1789
1790 assert!(!subscriptions.query_reads_from_table(&hash, &table_id));
1791
1792 Ok(())
1793 }
1794
1795 #[test]
1796 fn test_unsubscribe_doesnt_remove_other_clients() -> ResultTest<()> {
1797 let db = TestDB::durable()?;
1798
1799 let table_id = create_table(&db, "T")?;
1800 let sql = "select * from T";
1801 let plan = compile_plan(&db, sql)?;
1802 let hash = plan.hash();
1803
1804 let clients = (0..3).map(|i| Arc::new(client(i))).collect::<Vec<_>>();
1805
1806 let query_id: ClientQueryId = QueryId::new(1);
1808
1809 let runtime = tokio::runtime::Runtime::new().unwrap();
1810 let _rt = runtime.enter();
1811
1812 let mut subscriptions = SubscriptionManager::for_test_without_metrics();
1813 subscriptions.add_subscription(clients[0].clone(), plan.clone(), query_id)?;
1814 subscriptions.add_subscription(clients[1].clone(), plan.clone(), query_id)?;
1815 subscriptions.add_subscription(clients[2].clone(), plan.clone(), query_id)?;
1816
1817 assert!(subscriptions.query_reads_from_table(&hash, &table_id));
1818
1819 let client_ids = clients
1820 .iter()
1821 .map(|client| (client.id.identity, client.id.connection_id))
1822 .collect::<Vec<_>>();
1823 subscriptions.remove_subscription(client_ids[0], query_id)?;
1824 assert!(subscriptions.query_reads_from_table(&hash, &table_id));
1826 subscriptions.remove_subscription(client_ids[1], query_id)?;
1827 assert!(subscriptions.query_reads_from_table(&hash, &table_id));
1829 subscriptions.remove_subscription(client_ids[2], query_id)?;
1830 assert!(!subscriptions.query_reads_from_table(&hash, &table_id));
1832
1833 Ok(())
1834 }
1835
1836 #[test]
1837 fn test_unsubscribe_all_doesnt_remove_other_clients() -> ResultTest<()> {
1838 let db = TestDB::durable()?;
1839
1840 let table_id = create_table(&db, "T")?;
1841 let sql = "select * from T";
1842 let plan = compile_plan(&db, sql)?;
1843 let hash = plan.hash();
1844
1845 let clients = (0..3).map(|i| Arc::new(client(i))).collect::<Vec<_>>();
1846
1847 let query_id: ClientQueryId = QueryId::new(1);
1849
1850 let runtime = tokio::runtime::Runtime::new().unwrap();
1851 let _rt = runtime.enter();
1852
1853 let mut subscriptions = SubscriptionManager::for_test_without_metrics();
1854 subscriptions.add_subscription(clients[0].clone(), plan.clone(), query_id)?;
1855 subscriptions.add_subscription(clients[1].clone(), plan.clone(), query_id)?;
1856 subscriptions.add_subscription(clients[2].clone(), plan.clone(), query_id)?;
1857
1858 assert!(subscriptions.query_reads_from_table(&hash, &table_id));
1859
1860 let client_ids = clients
1861 .iter()
1862 .map(|client| (client.id.identity, client.id.connection_id))
1863 .collect::<Vec<_>>();
1864 subscriptions.remove_all_subscriptions(&client_ids[0]);
1865 assert!(!subscriptions.contains_client(&client_ids[0]));
1866 assert!(subscriptions.query_reads_from_table(&hash, &table_id));
1868 subscriptions.remove_all_subscriptions(&client_ids[1]);
1869 assert!(subscriptions.query_reads_from_table(&hash, &table_id));
1871 assert!(!subscriptions.contains_client(&client_ids[1]));
1872 subscriptions.remove_all_subscriptions(&client_ids[2]);
1873 assert!(!subscriptions.query_reads_from_table(&hash, &table_id));
1875 assert!(!subscriptions.contains_client(&client_ids[2]));
1876
1877 Ok(())
1878 }
1879
1880 #[test]
1882 fn test_multiple_queries() -> ResultTest<()> {
1883 let db = TestDB::durable()?;
1884
1885 let table_names = ["T", "S", "U"];
1886 let table_ids = table_names
1887 .iter()
1888 .map(|name| create_table(&db, name))
1889 .collect::<ResultTest<Vec<_>>>()?;
1890 let queries = table_names
1891 .iter()
1892 .map(|name| format!("select * from {name}"))
1893 .map(|sql| compile_plan(&db, &sql))
1894 .collect::<ResultTest<Vec<_>>>()?;
1895
1896 let client = Arc::new(client(0));
1897
1898 let runtime = tokio::runtime::Runtime::new().unwrap();
1899 let _rt = runtime.enter();
1900
1901 let mut subscriptions = SubscriptionManager::for_test_without_metrics();
1902 subscriptions.add_subscription(client.clone(), queries[0].clone(), QueryId::new(1))?;
1903 subscriptions.add_subscription(client.clone(), queries[1].clone(), QueryId::new(2))?;
1904 subscriptions.add_subscription(client.clone(), queries[2].clone(), QueryId::new(3))?;
1905 for i in 0..3 {
1906 assert!(subscriptions.query_reads_from_table(&queries[i].hash(), &table_ids[i]));
1907 }
1908
1909 let client_id = (client.id.identity, client.id.connection_id);
1910 subscriptions.remove_subscription(client_id, QueryId::new(1))?;
1911 assert!(!subscriptions.query_reads_from_table(&queries[0].hash(), &table_ids[0]));
1912 for i in 1..3 {
1914 assert!(subscriptions.query_reads_from_table(&queries[i].hash(), &table_ids[i]));
1915 }
1916
1917 subscriptions.remove_all_subscriptions(&client_id);
1919 for i in 0..3 {
1920 assert!(!subscriptions.query_reads_from_table(&queries[i].hash(), &table_ids[i]));
1921 }
1922
1923 Ok(())
1924 }
1925
1926 #[test]
1927 fn test_multiple_query_sets() -> ResultTest<()> {
1928 let db = TestDB::durable()?;
1929
1930 let table_names = ["T", "S", "U"];
1931 let table_ids = table_names
1932 .iter()
1933 .map(|name| create_table(&db, name))
1934 .collect::<ResultTest<Vec<_>>>()?;
1935 let queries = table_names
1936 .iter()
1937 .map(|name| format!("select * from {name}"))
1938 .map(|sql| compile_plan(&db, &sql))
1939 .collect::<ResultTest<Vec<_>>>()?;
1940
1941 let client = Arc::new(client(0));
1942
1943 let runtime = tokio::runtime::Runtime::new().unwrap();
1944 let _rt = runtime.enter();
1945
1946 let mut subscriptions = SubscriptionManager::for_test_without_metrics();
1947 let added = subscriptions.add_subscription_multi(client.clone(), vec![queries[0].clone()], QueryId::new(1))?;
1948 assert_eq!(added.len(), 1);
1949 assert_eq!(added[0].hash, queries[0].hash());
1950 let added = subscriptions.add_subscription_multi(client.clone(), vec![queries[1].clone()], QueryId::new(2))?;
1951 assert_eq!(added.len(), 1);
1952 assert_eq!(added[0].hash, queries[1].hash());
1953 let added = subscriptions.add_subscription_multi(client.clone(), vec![queries[2].clone()], QueryId::new(3))?;
1954 assert_eq!(added.len(), 1);
1955 assert_eq!(added[0].hash, queries[2].hash());
1956 for i in 0..3 {
1957 assert!(subscriptions.query_reads_from_table(&queries[i].hash(), &table_ids[i]));
1958 }
1959
1960 let client_id = (client.id.identity, client.id.connection_id);
1961 let removed = subscriptions.remove_subscription(client_id, QueryId::new(1))?;
1962 assert_eq!(removed.len(), 1);
1963 assert_eq!(removed[0].hash, queries[0].hash());
1964 assert!(!subscriptions.query_reads_from_table(&queries[0].hash(), &table_ids[0]));
1965 for i in 1..3 {
1967 assert!(subscriptions.query_reads_from_table(&queries[i].hash(), &table_ids[i]));
1968 }
1969
1970 subscriptions.remove_all_subscriptions(&client_id);
1972 for i in 0..3 {
1973 assert!(!subscriptions.query_reads_from_table(&queries[i].hash(), &table_ids[i]));
1974 }
1975
1976 Ok(())
1977 }
1978
1979 #[test]
1980 fn test_internals_for_search_args() -> ResultTest<()> {
1981 let db = TestDB::durable()?;
1982
1983 let table_id = create_table(&db, "t")?;
1984
1985 let client = Arc::new(client(0));
1986
1987 let runtime = tokio::runtime::Runtime::new().unwrap();
1988 let _rt = runtime.enter();
1989
1990 let mut subscriptions = SubscriptionManager::for_test_without_metrics();
1991
1992 let queries = (0u8..5)
1994 .map(|name| format!("select * from t where a = {name}"))
1995 .map(|sql| compile_plan(&db, &sql))
1996 .collect::<ResultTest<Vec<_>>>()?;
1997
1998 for (i, query) in queries.iter().enumerate().take(5) {
1999 let added =
2000 subscriptions.add_subscription_multi(client.clone(), vec![query.clone()], QueryId::new(i as u32))?;
2001 assert_eq!(added.len(), 1);
2002 assert_eq!(added[0].hash, queries[i].hash);
2003 }
2004
2005 assert!(subscriptions.table_has_search_param(table_id, ColId(0)));
2007
2008 for (i, query) in queries.iter().enumerate().take(5) {
2009 assert!(subscriptions.query_has_search_arg(query.hash, table_id, ColId(0), AlgebraicValue::U8(i as u8)));
2010
2011 assert!(!subscriptions.query_reads_from_table(&queries[i].hash, &table_id));
2013 }
2014
2015 let query_id = QueryId::new(2);
2017 let client_id = (client.id.identity, client.id.connection_id);
2018 let removed = subscriptions.remove_subscription(client_id, query_id)?;
2019 assert_eq!(removed.len(), 1);
2020
2021 assert!(subscriptions.table_has_search_param(table_id, ColId(0)));
2024
2025 assert!(!subscriptions.query_reads_from_table(&queries[2].hash, &table_id));
2027 assert!(!subscriptions.query_has_search_arg(queries[2].hash, table_id, ColId(0), AlgebraicValue::U8(2)));
2028
2029 for (i, query) in queries.iter().enumerate().take(5) {
2030 if i != 2 {
2031 assert!(subscriptions.query_has_search_arg(
2032 query.hash,
2033 table_id,
2034 ColId(0),
2035 AlgebraicValue::U8(i as u8)
2036 ));
2037 }
2038 }
2039
2040 subscriptions.remove_all_subscriptions(&client_id);
2042
2043 assert!(!subscriptions.table_has_search_param(table_id, ColId(0)));
2045 for (i, query) in queries.iter().enumerate().take(5) {
2046 assert!(!subscriptions.query_has_search_arg(query.hash, table_id, ColId(0), AlgebraicValue::U8(i as u8)));
2047 }
2048
2049 Ok(())
2050 }
2051
2052 #[test]
2053 fn test_search_args_for_selects() -> ResultTest<()> {
2054 let db = TestDB::durable()?;
2055
2056 let table_id = create_table(&db, "t")?;
2057
2058 let client = Arc::new(client(0));
2059
2060 let runtime = tokio::runtime::Runtime::new().unwrap();
2061 let _rt = runtime.enter();
2062
2063 let mut subscriptions = SubscriptionManager::for_test_without_metrics();
2064
2065 let queries = (0u8..5)
2066 .map(|name| format!("select * from t where a = {name}"))
2067 .chain(std::iter::once(String::from("select * from t")))
2068 .map(|sql| compile_plan(&db, &sql))
2069 .collect::<ResultTest<Vec<_>>>()?;
2070
2071 for (i, query) in queries.iter().enumerate() {
2072 subscriptions.add_subscription_multi(client.clone(), vec![query.clone()], QueryId::new(i as u32))?;
2073 }
2074
2075 let hash_for_2 = queries[2].hash;
2076 let hash_for_3 = queries[3].hash;
2077 let hash_for_5 = queries[5].hash;
2078
2079 let table_update = DatabaseTableUpdate {
2085 table_id,
2086 table_name: "t".into(),
2087 inserts: [product![2u8]].into(),
2088 deletes: [product![3u8]].into(),
2089 };
2090
2091 let hashes = subscriptions
2092 .queries_for_table_update(&table_update, &|_, _| None)
2093 .collect::<Vec<_>>();
2094
2095 assert!(hashes.len() == 3);
2096 assert!(hashes.contains(&&hash_for_2));
2097 assert!(hashes.contains(&&hash_for_3));
2098 assert!(hashes.contains(&&hash_for_5));
2099
2100 let table_update = DatabaseTableUpdate {
2103 table_id,
2104 table_name: "t".into(),
2105 inserts: [product![8u8]].into(),
2106 deletes: [product![9u8]].into(),
2107 };
2108
2109 let hashes = subscriptions
2110 .queries_for_table_update(&table_update, &|_, _| None)
2111 .collect::<Vec<_>>();
2112
2113 assert!(hashes.len() == 1);
2114 assert!(hashes.contains(&&hash_for_5));
2115
2116 Ok(())
2117 }
2118
2119 #[test]
2120 fn test_search_args_for_join() -> ResultTest<()> {
2121 let db = TestDB::durable()?;
2122
2123 let schema = [("id", AlgebraicType::U8), ("a", AlgebraicType::U8)];
2124
2125 let t_id = db.create_table_for_test("t", &schema, &[0.into()])?;
2126 let s_id = db.create_table_for_test("s", &schema, &[0.into()])?;
2127
2128 let client = Arc::new(client(0));
2129
2130 let runtime = tokio::runtime::Runtime::new().unwrap();
2131 let _rt = runtime.enter();
2132
2133 let mut subscriptions = SubscriptionManager::for_test_without_metrics();
2134
2135 let plan = compile_plan(&db, "select t.* from t join s on t.id = s.id where s.a = 1")?;
2136 let hash = plan.hash;
2137
2138 subscriptions.add_subscription_multi(client.clone(), vec![plan], QueryId::new(0))?;
2139
2140 let table_update = DatabaseTableUpdate {
2144 table_id: t_id,
2145 table_name: "t".into(),
2146 inserts: [product![0u8, 0u8]].into(),
2147 deletes: [].into(),
2148 };
2149
2150 let hashes = subscriptions
2151 .queries_for_table_update(&table_update, &|_, _| None)
2152 .cloned()
2153 .collect::<Vec<_>>();
2154
2155 assert_eq!(hashes, vec![hash]);
2156
2157 let table_update = DatabaseTableUpdate {
2160 table_id: s_id,
2161 table_name: "s".into(),
2162 inserts: [product![0u8, 1u8]].into(),
2163 deletes: [].into(),
2164 };
2165
2166 let hashes = subscriptions
2167 .queries_for_table_update(&table_update, &|_, _| None)
2168 .cloned()
2169 .collect::<Vec<_>>();
2170
2171 assert_eq!(hashes, vec![hash]);
2172
2173 let table_update = DatabaseTableUpdate {
2176 table_id: s_id,
2177 table_name: "s".into(),
2178 inserts: [product![0u8, 2u8]].into(),
2179 deletes: [].into(),
2180 };
2181
2182 let hashes = subscriptions
2183 .queries_for_table_update(&table_update, &|_, _| None)
2184 .cloned()
2185 .collect::<Vec<_>>();
2186
2187 assert!(hashes.is_empty());
2188
2189 Ok(())
2190 }
2191
2192 #[test]
2193 fn test_subscribe_fails_with_duplicate_request_id() -> ResultTest<()> {
2194 let db = TestDB::durable()?;
2195
2196 create_table(&db, "T")?;
2197 let sql = "select * from T";
2198 let plan = compile_plan(&db, sql)?;
2199
2200 let client = Arc::new(client(0));
2201
2202 let query_id: ClientQueryId = QueryId::new(1);
2203
2204 let runtime = tokio::runtime::Runtime::new().unwrap();
2205 let _rt = runtime.enter();
2206
2207 let mut subscriptions = SubscriptionManager::for_test_without_metrics();
2208 subscriptions.add_subscription(client.clone(), plan.clone(), query_id)?;
2209
2210 assert!(subscriptions
2211 .add_subscription(client.clone(), plan.clone(), query_id)
2212 .is_err());
2213
2214 Ok(())
2215 }
2216
2217 #[test]
2218 fn test_subscribe_multi_fails_with_duplicate_request_id() -> ResultTest<()> {
2219 let db = TestDB::durable()?;
2220
2221 create_table(&db, "T")?;
2222 let sql = "select * from T";
2223 let plan = compile_plan(&db, sql)?;
2224
2225 let client = Arc::new(client(0));
2226
2227 let query_id: ClientQueryId = QueryId::new(1);
2228
2229 let runtime = tokio::runtime::Runtime::new().unwrap();
2230 let _rt = runtime.enter();
2231
2232 let mut subscriptions = SubscriptionManager::for_test_without_metrics();
2233 let result = subscriptions.add_subscription_multi(client.clone(), vec![plan.clone()], query_id)?;
2234 assert_eq!(result[0].hash, plan.hash);
2235
2236 assert!(subscriptions
2237 .add_subscription_multi(client.clone(), vec![plan.clone()], query_id)
2238 .is_err());
2239
2240 Ok(())
2241 }
2242
2243 #[test]
2244 fn test_unsubscribe() -> ResultTest<()> {
2245 let db = TestDB::durable()?;
2246
2247 let table_id = create_table(&db, "T")?;
2248 let sql = "select * from T";
2249 let plan = compile_plan(&db, sql)?;
2250 let hash = plan.hash();
2251
2252 let id = id(0);
2253 let client = Arc::new(client(0));
2254
2255 let runtime = tokio::runtime::Runtime::new().unwrap();
2256 let _rt = runtime.enter();
2257
2258 let mut subscriptions = SubscriptionManager::for_test_without_metrics();
2259 subscriptions.set_legacy_subscription(client, [plan]);
2260 subscriptions.remove_all_subscriptions(&id);
2261
2262 assert!(!subscriptions.contains_query(&hash));
2263 assert!(!subscriptions.contains_legacy_subscription(&id, &hash));
2264 assert!(!subscriptions.query_reads_from_table(&hash, &table_id));
2265
2266 Ok(())
2267 }
2268
2269 #[test]
2270 fn test_subscribe_idempotent() -> ResultTest<()> {
2271 let db = TestDB::durable()?;
2272
2273 let table_id = create_table(&db, "T")?;
2274 let sql = "select * from T";
2275 let plan = compile_plan(&db, sql)?;
2276 let hash = plan.hash();
2277
2278 let id = id(0);
2279 let client = Arc::new(client(0));
2280
2281 let runtime = tokio::runtime::Runtime::new().unwrap();
2282 let _rt = runtime.enter();
2283
2284 let mut subscriptions = SubscriptionManager::for_test_without_metrics();
2285 subscriptions.set_legacy_subscription(client.clone(), [plan.clone()]);
2286 subscriptions.set_legacy_subscription(client.clone(), [plan.clone()]);
2287
2288 assert!(subscriptions.contains_query(&hash));
2289 assert!(subscriptions.contains_legacy_subscription(&id, &hash));
2290 assert!(subscriptions.query_reads_from_table(&hash, &table_id));
2291
2292 subscriptions.remove_all_subscriptions(&id);
2293
2294 assert!(!subscriptions.contains_query(&hash));
2295 assert!(!subscriptions.contains_legacy_subscription(&id, &hash));
2296 assert!(!subscriptions.query_reads_from_table(&hash, &table_id));
2297
2298 Ok(())
2299 }
2300
2301 #[test]
2302 fn test_share_queries_full() -> ResultTest<()> {
2303 let db = TestDB::durable()?;
2304
2305 let table_id = create_table(&db, "T")?;
2306 let sql = "select * from T";
2307 let plan = compile_plan(&db, sql)?;
2308 let hash = plan.hash();
2309
2310 let id0 = id(0);
2311 let client0 = Arc::new(client(0));
2312
2313 let id1 = id(1);
2314 let client1 = Arc::new(client(1));
2315
2316 let runtime = tokio::runtime::Runtime::new().unwrap();
2317 let _rt = runtime.enter();
2318
2319 let mut subscriptions = SubscriptionManager::for_test_without_metrics();
2320 subscriptions.set_legacy_subscription(client0, [plan.clone()]);
2321 subscriptions.set_legacy_subscription(client1, [plan.clone()]);
2322
2323 assert!(subscriptions.contains_query(&hash));
2324 assert!(subscriptions.contains_legacy_subscription(&id0, &hash));
2325 assert!(subscriptions.contains_legacy_subscription(&id1, &hash));
2326 assert!(subscriptions.query_reads_from_table(&hash, &table_id));
2327
2328 subscriptions.remove_all_subscriptions(&id0);
2329
2330 assert!(subscriptions.contains_query(&hash));
2331 assert!(subscriptions.contains_legacy_subscription(&id1, &hash));
2332 assert!(subscriptions.query_reads_from_table(&hash, &table_id));
2333
2334 assert!(!subscriptions.contains_legacy_subscription(&id0, &hash));
2335
2336 Ok(())
2337 }
2338
2339 #[test]
2340 fn test_share_queries_partial() -> ResultTest<()> {
2341 let db = TestDB::durable()?;
2342
2343 let t = create_table(&db, "T")?;
2344 let s = create_table(&db, "S")?;
2345
2346 let scan = "select * from T";
2347 let select0 = "select * from T where a = 0";
2348 let select1 = "select * from S where a = 1";
2349
2350 let plan_scan = compile_plan(&db, scan)?;
2351 let plan_select0 = compile_plan(&db, select0)?;
2352 let plan_select1 = compile_plan(&db, select1)?;
2353
2354 let hash_scan = plan_scan.hash();
2355 let hash_select0 = plan_select0.hash();
2356 let hash_select1 = plan_select1.hash();
2357
2358 let id0 = id(0);
2359 let client0 = Arc::new(client(0));
2360
2361 let id1 = id(1);
2362 let client1 = Arc::new(client(1));
2363
2364 let runtime = tokio::runtime::Runtime::new().unwrap();
2365 let _rt = runtime.enter();
2366 let mut subscriptions = SubscriptionManager::for_test_without_metrics();
2367 subscriptions.set_legacy_subscription(client0, [plan_scan.clone(), plan_select0.clone()]);
2368 subscriptions.set_legacy_subscription(client1, [plan_scan.clone(), plan_select1.clone()]);
2369
2370 assert!(subscriptions.contains_query(&hash_scan));
2371 assert!(subscriptions.contains_query(&hash_select0));
2372 assert!(subscriptions.contains_query(&hash_select1));
2373
2374 assert!(subscriptions.contains_legacy_subscription(&id0, &hash_scan));
2375 assert!(subscriptions.contains_legacy_subscription(&id0, &hash_select0));
2376
2377 assert!(subscriptions.contains_legacy_subscription(&id1, &hash_scan));
2378 assert!(subscriptions.contains_legacy_subscription(&id1, &hash_select1));
2379
2380 assert!(subscriptions.query_reads_from_table(&hash_scan, &t));
2381 assert!(subscriptions.query_has_search_arg(hash_select0, t, ColId(0), AlgebraicValue::U8(0)));
2382 assert!(subscriptions.query_has_search_arg(hash_select1, s, ColId(0), AlgebraicValue::U8(1)));
2383
2384 assert!(!subscriptions.query_reads_from_table(&hash_scan, &s));
2385 assert!(!subscriptions.query_reads_from_table(&hash_select0, &t));
2386 assert!(!subscriptions.query_reads_from_table(&hash_select1, &s));
2387 assert!(!subscriptions.query_reads_from_table(&hash_select0, &s));
2388 assert!(!subscriptions.query_reads_from_table(&hash_select1, &t));
2389
2390 subscriptions.remove_all_subscriptions(&id0);
2391
2392 assert!(subscriptions.contains_query(&hash_scan));
2393 assert!(subscriptions.contains_query(&hash_select1));
2394 assert!(!subscriptions.contains_query(&hash_select0));
2395
2396 assert!(subscriptions.contains_legacy_subscription(&id1, &hash_scan));
2397 assert!(subscriptions.contains_legacy_subscription(&id1, &hash_select1));
2398
2399 assert!(!subscriptions.contains_legacy_subscription(&id0, &hash_scan));
2400 assert!(!subscriptions.contains_legacy_subscription(&id0, &hash_select0));
2401
2402 assert!(subscriptions.query_reads_from_table(&hash_scan, &t));
2403 assert!(subscriptions.query_has_search_arg(hash_select1, s, ColId(0), AlgebraicValue::U8(1)));
2404
2405 assert!(!subscriptions.query_reads_from_table(&hash_select1, &s));
2406 assert!(!subscriptions.query_reads_from_table(&hash_scan, &s));
2407 assert!(!subscriptions.query_reads_from_table(&hash_select1, &t));
2408
2409 Ok(())
2410 }
2411
/// The caller of a reducer must receive a message for its transaction even
/// when that client has not subscribed to any queries.
///
/// The `SubscriptionManager` used here is freshly constructed and never given
/// a subscription. After evaluating a committed event whose caller is
/// `client0`, the test expects at least one message to arrive on `client0`'s
/// outgoing channel.
#[test]
fn test_caller_transaction_update_without_subscription() -> ResultTest<()> {
    let db = TestDB::durable()?;

    // Upper bound on how long to wait for the caller's message. `recv()`
    // resolves as soon as a message is queued, so this bound only affects how
    // long a *failing* run takes. A generous value keeps the test from
    // flaking when the scheduler is slow (e.g. on a loaded CI machine), which
    // a 20ms bound did not guarantee.
    const RECV_TIMEOUT: Duration = Duration::from_secs(1);

    // The reducer caller: a dummy connection whose outgoing messages land in `rx`.
    let id0 = Identity::ZERO;
    let client0 = ClientActorId::for_test(id0);
    let config = ClientConfig::for_test();
    let (client0, mut rx) = ClientConnectionSender::dummy_with_channel(client0, config);

    // Enter the runtime before building the manager, mirroring the preceding
    // test. NOTE(review): presumably the manager or message delivery spawns
    // Tokio tasks and needs an ambient runtime; confirm.
    let runtime = tokio::runtime::Runtime::new().unwrap();
    let _rt = runtime.enter();

    // Deliberately left without any subscriptions.
    let subscriptions = SubscriptionManager::for_test_without_metrics();

    // A committed reducer call attributed to `client0`, carrying an empty update.
    let event = Arc::new(ModuleEvent {
        timestamp: Timestamp::now(),
        caller_identity: id0,
        caller_connection_id: Some(client0.id.connection_id),
        function_call: ModuleFunctionCall {
            reducer: "DummyReducer".into(),
            reducer_id: u32::MAX.into(),
            args: ArgsTuple::nullary(),
        },
        status: EventStatus::Committed(DatabaseUpdate::default()),
        energy_quanta_used: EnergyQuanta::ZERO,
        host_execution_duration: Duration::default(),
        request_id: None,
        timer: None,
    });

    // Evaluate the event, passing `client0` explicitly as the caller.
    db.with_read_only(Workload::Update, |tx| {
        subscriptions.eval_updates_sequential(&(&*tx).into(), event, Some(Arc::new(client0)))
    });

    // The caller must be sent something despite having no subscriptions.
    runtime.block_on(async move {
        tokio::time::timeout(RECV_TIMEOUT, async move {
            rx.recv().await.expect("Expected at least one message");
        })
        .await
        .expect("Timed out waiting for a message to the client");
    });

    Ok(())
}
2457}