1use super::execution_unit::QueryHash;
2use super::tx::DeltaTx;
3use crate::client::messages::{
4 SerializableMessage, SubscriptionError, SubscriptionMessage, SubscriptionResult, SubscriptionUpdateMessage,
5 TransactionUpdateMessage,
6};
7use crate::client::{ClientConnectionSender, Protocol};
8use crate::db::datastore::locking_tx_datastore::state_view::StateView;
9use crate::error::DBError;
10use crate::host::module_host::{DatabaseTableUpdate, ModuleEvent, UpdatesRelValue};
11use crate::messages::websocket::{self as ws, TableUpdate};
12use crate::subscription::delta::eval_delta;
13use crate::worker_metrics::WORKER_METRICS;
14use hashbrown::hash_map::OccupiedError;
15use hashbrown::{HashMap, HashSet};
16use itertools::Itertools;
17use parking_lot::RwLock;
18use prometheus::IntGauge;
19use spacetimedb_client_api_messages::websocket::{
20 BsatnFormat, CompressableQueryUpdate, FormatSwitch, JsonFormat, QueryId, QueryUpdate, SingleQueryUpdate,
21 WebsocketFormat,
22};
23use spacetimedb_data_structures::map::{Entry, IntMap};
24use spacetimedb_lib::metrics::ExecutionMetrics;
25use spacetimedb_lib::{AlgebraicValue, ConnectionId, Identity, ProductValue};
26use spacetimedb_primitives::{ColId, IndexId, TableId};
27use spacetimedb_subscription::{JoinEdge, SubscriptionPlan, TableName};
28use std::collections::{BTreeMap, BTreeSet};
29use std::sync::atomic::{AtomicBool, Ordering};
30use std::sync::Arc;
31use tokio::sync::mpsc;
32
/// A connected client is identified by its identity plus its connection id.
type ClientId = (Identity, ConnectionId);
/// A compiled query; the same `Arc` is shared by every subscriber of that query.
type Query = Arc<Plan>;
/// Handle used to send messages to a connected client.
type Client = Arc<ClientConnectionSender>;
/// A table update in whichever wire format (BSATN or JSON) the recipient uses.
type SwitchedTableUpdate = FormatSwitch<TableUpdate<BsatnFormat>, TableUpdate<JsonFormat>>;
/// A database update in whichever wire format (BSATN or JSON) the recipient uses.
type SwitchedDbUpdate = FormatSwitch<ws::DatabaseUpdate<BsatnFormat>, ws::DatabaseUpdate<JsonFormat>>;

/// The id a client chose for one of its subscription sets.
type ClientQueryId = QueryId;
/// A subscription set is keyed by the owning client plus the client-chosen id.
type SubscriptionId = (ClientId, ClientQueryId);
46
47#[derive(Debug)]
48pub struct Plan {
49 hash: QueryHash,
50 sql: String,
51 plans: Vec<SubscriptionPlan>,
52}
53
54impl Plan {
55 pub fn new(plans: Vec<SubscriptionPlan>, hash: QueryHash, text: String) -> Self {
57 Self { plans, hash, sql: text }
58 }
59
60 pub fn hash(&self) -> QueryHash {
62 self.hash
63 }
64
65 pub fn subscribed_table_id(&self) -> TableId {
68 self.plans[0].subscribed_table_id()
69 }
70
71 pub fn subscribed_table_name(&self) -> &str {
74 self.plans[0].subscribed_table_name()
75 }
76
77 pub fn index_ids(&self) -> impl Iterator<Item = (TableId, IndexId)> {
79 self.plans
80 .iter()
81 .flat_map(|plan| plan.index_ids())
82 .collect::<HashSet<_>>()
83 .into_iter()
84 }
85
86 pub fn table_ids(&self) -> impl Iterator<Item = TableId> + '_ {
88 self.plans
89 .iter()
90 .flat_map(|plan| plan.table_ids())
91 .collect::<HashSet<_>>()
92 .into_iter()
93 }
94
95 pub fn plans_fragments(&self) -> impl Iterator<Item = &SubscriptionPlan> + '_ {
98 self.plans.iter()
99 }
100
101 pub fn join_edges(&self) -> impl Iterator<Item = (JoinEdge, AlgebraicValue)> + '_ {
103 self.plans.iter().filter_map(|plan| plan.join_edge())
104 }
105
106 pub fn sql(&self) -> &str {
108 &self.sql
109 }
110}
111
112#[derive(Debug)]
114struct ClientInfo {
115 outbound_ref: Client,
116 subscriptions: HashMap<SubscriptionId, HashSet<QueryHash>>,
117 subscription_ref_count: HashMap<QueryHash, usize>,
118 legacy_subscriptions: HashSet<QueryHash>,
120 dropped: Arc<AtomicBool>,
126}
127
128impl ClientInfo {
129 fn new(outbound_ref: Client) -> Self {
130 Self {
131 outbound_ref,
132 subscriptions: HashMap::default(),
133 subscription_ref_count: HashMap::default(),
134 legacy_subscriptions: HashSet::default(),
135 dropped: Arc::new(AtomicBool::new(false)),
136 }
137 }
138
139 #[cfg(test)]
141 fn assert_ref_count_consistency(&self) {
142 let mut expected_ref_count = HashMap::new();
143 for query_hashes in self.subscriptions.values() {
144 for query_hash in query_hashes {
145 assert!(
146 self.subscription_ref_count.contains_key(query_hash),
147 "Query hash not found: {:?}",
148 query_hash
149 );
150 expected_ref_count
151 .entry(*query_hash)
152 .and_modify(|count| *count += 1)
153 .or_insert(1);
154 }
155 }
156 assert_eq!(
157 self.subscription_ref_count, expected_ref_count,
158 "Checking the reference totals failed"
159 );
160 }
161}
162
163#[derive(Debug)]
165struct QueryState {
166 query: Query,
167 legacy_subscribers: HashSet<ClientId>,
169 subscriptions: HashSet<ClientId>,
171}
172
173impl QueryState {
174 fn new(query: Query) -> Self {
175 Self {
176 query,
177 legacy_subscribers: HashSet::default(),
178 subscriptions: HashSet::default(),
179 }
180 }
181 fn has_subscribers(&self) -> bool {
182 !self.subscriptions.is_empty() || !self.legacy_subscribers.is_empty()
183 }
184
185 fn all_clients(&self) -> impl Iterator<Item = &ClientId> {
187 itertools::chain(&self.legacy_subscribers, &self.subscriptions)
188 }
189
190 pub fn query(&self) -> &Query {
192 &self.query
193 }
194
195 fn search_args(&self) -> impl Iterator<Item = (TableId, ColId, AlgebraicValue)> {
197 let mut args = HashSet::new();
198 for arg in self
199 .query
200 .plans
201 .iter()
202 .flat_map(|subscription| subscription.optimized_physical_plan().search_args())
203 {
204 args.insert(arg);
205 }
206 args.into_iter()
207 }
208}
209
210#[derive(Debug, Default)]
221pub struct SearchArguments {
222 params: BTreeSet<(TableId, ColId)>,
241 args: BTreeSet<(TableId, ColId, AlgebraicValue, QueryHash)>,
254}
255
256impl SearchArguments {
257 fn search_params_for_table(&self, table_id: TableId) -> impl Iterator<Item = ColId> + '_ {
259 let lower_bound = (table_id, 0.into());
260 let upper_bound = (table_id, u16::MAX.into());
261 self.params
262 .range(lower_bound..=upper_bound)
263 .map(|(_, col_id)| col_id)
264 .cloned()
265 }
266
267 fn queries_for_search_arg(
270 &self,
271 table_id: TableId,
272 col_id: ColId,
273 search_arg: AlgebraicValue,
274 ) -> impl Iterator<Item = &QueryHash> {
275 let lower_bound = (table_id, col_id, search_arg.clone(), QueryHash::MIN);
276 let upper_bound = (table_id, col_id, search_arg, QueryHash::MAX);
277 self.args.range(lower_bound..upper_bound).map(|(_, _, _, hash)| hash)
278 }
279
280 fn queries_for_row<'a>(&'a self, table_id: TableId, row: &'a ProductValue) -> impl Iterator<Item = &'a QueryHash> {
282 self.search_params_for_table(table_id)
283 .filter_map(|col_id| row.get_field(col_id.idx(), None).ok().map(|arg| (col_id, arg.clone())))
284 .flat_map(move |(col_id, arg)| self.queries_for_search_arg(table_id, col_id, arg))
285 }
286
287 fn remove_query(&mut self, query: &QueryHash) {
290 let mut params = self
292 .args
293 .iter()
294 .filter(|(_, _, _, hash)| hash == query)
295 .map(|(table_id, col_id, _, _)| (*table_id, *col_id))
296 .dedup()
297 .collect::<HashSet<_>>();
298
299 self.args.retain(|(_, _, _, hash)| hash != query);
301
302 params.retain(|(table_id, col_id)| {
304 self.args
305 .range(
306 (*table_id, *col_id, AlgebraicValue::Min, QueryHash::MIN)
307 ..=(*table_id, *col_id, AlgebraicValue::Max, QueryHash::MAX),
308 )
309 .next()
310 .is_none()
311 });
312
313 self.params
314 .retain(|(table_id, col_id)| !params.contains(&(*table_id, *col_id)));
315 }
316
317 fn insert_query(&mut self, table_id: TableId, col_id: ColId, arg: AlgebraicValue, query: QueryHash) {
319 self.args.insert((table_id, col_id, arg, query));
320 self.params.insert((table_id, col_id));
321 }
322}
323
324#[derive(Debug, Default)]
326pub struct QueriedTableIndexIds {
327 ids: HashMap<TableId, HashMap<IndexId, usize>>,
328}
329
330impl FromIterator<(TableId, IndexId)> for QueriedTableIndexIds {
331 fn from_iter<T: IntoIterator<Item = (TableId, IndexId)>>(iter: T) -> Self {
332 let mut index_ids = Self::default();
333 for (table_id, index_id) in iter {
334 index_ids.insert_index_id(table_id, index_id);
335 }
336 index_ids
337 }
338}
339
340impl QueriedTableIndexIds {
341 pub fn index_ids_for_table(&self, table_id: TableId) -> impl Iterator<Item = IndexId> + '_ {
345 self.ids
346 .get(&table_id)
347 .into_iter()
348 .flat_map(|index_ids| index_ids.keys())
349 .copied()
350 }
351
352 pub fn insert_index_id(&mut self, table_id: TableId, index_id: IndexId) {
356 *self.ids.entry(table_id).or_default().entry(index_id).or_default() += 1;
357 }
358
359 pub fn delete_index_id(&mut self, table_id: TableId, index_id: IndexId) {
363 if let Some(ids) = self.ids.get_mut(&table_id) {
364 if let Some(n) = ids.get_mut(&index_id) {
365 *n -= 1;
366
367 if *n == 0 {
368 ids.remove(&index_id);
369
370 if ids.is_empty() {
371 self.ids.remove(&table_id);
372 }
373 }
374 }
375 }
376 }
377
378 pub fn insert_index_ids_for_query(&mut self, query: &Query) {
382 for (table_id, index_id) in query.index_ids() {
383 self.insert_index_id(table_id, index_id);
384 }
385 }
386
387 pub fn delete_index_ids_for_query(&mut self, query: &Query) {
391 for (table_id, index_id) in query.index_ids() {
392 self.delete_index_id(table_id, index_id);
393 }
394 }
395}
396
397#[derive(Debug, Default)]
400pub struct JoinEdges {
401 edges: BTreeMap<JoinEdge, HashMap<AlgebraicValue, QueryHash>>,
402}
403
404impl JoinEdges {
405 fn add_query(&mut self, qs: &QueryState) -> bool {
407 let mut inserted = false;
408 for (edge, rhs_val) in qs.query.join_edges() {
409 inserted = true;
410 self.edges.entry(edge).or_default().insert(rhs_val, qs.query.hash);
411 }
412 inserted
413 }
414
415 fn remove_query(&mut self, query: &Query) {
417 for (edge, rhs_val) in query.join_edges() {
418 if let Some(hashes) = self.edges.get_mut(&edge) {
419 hashes.remove(&rhs_val);
420 if hashes.is_empty() {
421 self.edges.remove(&edge);
422 }
423 }
424 }
425 }
426
427 fn queries_for_row<'a>(
430 &'a self,
431 table_id: TableId,
432 row: &'a ProductValue,
433 find_rhs_val: impl Fn(&JoinEdge, &ProductValue) -> Option<AlgebraicValue>,
434 ) -> impl Iterator<Item = &'a QueryHash> {
435 self.edges
436 .range(JoinEdge::range_for_table(table_id))
437 .filter_map(move |(edge, hashes)| find_rhs_val(edge, row).as_ref().and_then(|rhs_val| hashes.get(rhs_val)))
438 }
439}
440
/// Tracks which clients subscribe to which queries.
///
/// Also maintains the lookup structures used to decide which queries must be
/// re-evaluated when a table changes.
#[derive(Debug)]
pub struct SubscriptionManager {
    /// Per-client subscription state.
    clients: HashMap<ClientId, ClientInfo>,

    /// Every query that currently has at least one subscriber, keyed by hash.
    queries: HashMap<QueryHash, QueryState>,

    /// For each table, queries to consider on *any* change to that table.
    /// These are the tables not covered by `search_args` or `join_edges` for that query.
    tables: IntMap<TableId, HashSet<QueryHash>>,

    /// Reference-counted indexes used by subscribed queries.
    indexes: QueriedTableIndexIds,

    /// Queries selectable for a changed row by matching a column value.
    search_args: SearchArguments,

    /// Queries selectable for a changed row by following a join edge.
    join_edges: JoinEdges,

    /// Queue to the send worker, which delivers messages to clients.
    send_worker_queue: BroadcastQueue,
}
485
/// One plan fragment's update destined for one client.
///
/// Already encoded in that client's wire format.
#[derive(Debug)]
struct ClientUpdate {
    id: ClientId,
    table_id: TableId,
    table_name: TableName,
    update: FormatSwitch<SingleQueryUpdate<BsatnFormat>, SingleQueryUpdate<JsonFormat>>,
}
494
/// Results of evaluating subscriptions for one transaction.
///
/// Handed to the send worker for delivery.
#[derive(Debug)]
struct ComputedQueries {
    /// Successfully computed per-client updates.
    updates: Vec<ClientUpdate>,
    /// Per-client error messages; updates for these clients are filtered out
    /// before sending.
    errs: Vec<(ClientId, Box<str>)>,
    /// The module event that produced the transaction.
    event: Arc<ModuleEvent>,
    /// The client that caused the transaction, if any.
    /// It is sent a transaction update even when none of its queries changed.
    caller: Option<Arc<ClientConnectionSender>>,
}
507
508#[derive(Debug)]
510struct SenderWithGauge<T> {
511 tx: mpsc::UnboundedSender<T>,
512 metric: Option<IntGauge>,
513}
514impl<T> Clone for SenderWithGauge<T> {
515 fn clone(&self) -> Self {
516 SenderWithGauge {
517 tx: self.tx.clone(),
518 metric: self.metric.clone(),
519 }
520 }
521}
522
523impl<T> SenderWithGauge<T> {
524 fn new(tx: mpsc::UnboundedSender<T>, metric: Option<IntGauge>) -> Self {
525 Self { tx, metric }
526 }
527
528 pub fn send(&self, msg: T) -> Result<(), mpsc::error::SendError<T>> {
530 if let Some(metric) = &self.metric {
531 metric.inc();
532 }
533 self.tx.send(msg)
535 }
536}
537
/// Work items processed by the send worker.
#[derive(Debug)]
enum SendWorkerMessage {
    /// Deliver the results of evaluating subscriptions for one transaction.
    Broadcast(ComputedQueries),

    /// Start tracking a client.
    AddClient {
        client_id: ClientId,
        /// Shared with the manager's `ClientInfo` for the same client.
        dropped: Arc<AtomicBool>,
        /// Handle used to send messages to the client.
        outbound_ref: Client,
    },

    /// Send a single message directly to `recipient`.
    SendMessage {
        recipient: Arc<ClientConnectionSender>,
        message: SerializableMessage,
    },

    /// Stop tracking a client.
    RemoveClient(ClientId),
}
566
/// Snapshot of subscription counts, as computed by `SubscriptionManager::calculate_gauge_stats`.
pub struct SubscriptionGaugeStats {
    /// Number of unique queries.
    pub num_queries: usize,
    /// Number of tracked client connections.
    pub num_connections: usize,
    /// Total number of subscription sets across all clients.
    pub num_subscription_sets: usize,
    /// Total number of (query, subscription-set subscriber) pairs.
    pub num_query_subscriptions: usize,
    /// Number of clients with at least one legacy subscription.
    pub num_legacy_subscriptions: usize,
}
580
581impl SubscriptionManager {
582 pub fn for_test_without_metrics_arc_rwlock() -> Arc<RwLock<Self>> {
583 Arc::new(RwLock::new(Self::for_test_without_metrics()))
584 }
585
586 pub fn for_test_without_metrics() -> Self {
587 Self::new(SendWorker::spawn_new(None))
588 }
589
590 pub fn new(send_worker_queue: BroadcastQueue) -> Self {
591 Self {
592 clients: Default::default(),
593 queries: Default::default(),
594 indexes: Default::default(),
595 tables: Default::default(),
596 search_args: Default::default(),
597 join_edges: Default::default(),
598 send_worker_queue,
599 }
600 }
601
602 pub fn query(&self, hash: &QueryHash) -> Option<Query> {
603 self.queries.get(hash).map(|state| state.query.clone())
604 }
605
606 pub fn calculate_gauge_stats(&self) -> SubscriptionGaugeStats {
607 let num_queries = self.queries.len();
608 let num_connections = self.clients.len();
609 let num_query_subscriptions = self.queries.values().map(|state| state.subscriptions.len()).sum();
610 let num_subscription_sets = self.clients.values().map(|ci| ci.subscriptions.len()).sum();
611 let num_legacy_subscriptions = self
612 .clients
613 .values()
614 .filter(|ci| !ci.legacy_subscriptions.is_empty())
615 .count();
616
617 SubscriptionGaugeStats {
618 num_queries,
619 num_connections,
620 num_query_subscriptions,
621 num_subscription_sets,
622 num_legacy_subscriptions,
623 }
624 }
625
626 fn get_or_make_client_info_and_inform_send_worker<'clients>(
631 clients: &'clients mut HashMap<ClientId, ClientInfo>,
632 send_worker_tx: &BroadcastQueue,
633 client_id: ClientId,
634 outbound_ref: Client,
635 ) -> &'clients mut ClientInfo {
636 clients.entry(client_id).or_insert_with(|| {
637 let info = ClientInfo::new(outbound_ref.clone());
638 send_worker_tx
639 .send(SendWorkerMessage::AddClient {
640 client_id,
641 dropped: info.dropped.clone(),
642 outbound_ref,
643 })
644 .expect("send worker has panicked, or otherwise dropped its recv queue!");
645 info
646 })
647 }
648
649 fn remove_client_and_inform_send_worker(&mut self, client_id: ClientId) -> Option<ClientInfo> {
652 self.clients.remove(&client_id).inspect(|_| {
653 self.send_worker_queue
654 .send(SendWorkerMessage::RemoveClient(client_id))
655 .expect("send worker has panicked, or otherwise dropped its recv queue!");
656 })
657 }
658
659 pub fn num_unique_queries(&self) -> usize {
660 self.queries.len()
661 }
662
663 #[cfg(test)]
664 fn contains_query(&self, hash: &QueryHash) -> bool {
665 self.queries.contains_key(hash)
666 }
667
668 #[cfg(test)]
669 fn contains_client(&self, subscriber: &ClientId) -> bool {
670 self.clients.contains_key(subscriber)
671 }
672
673 #[cfg(test)]
674 fn contains_legacy_subscription(&self, subscriber: &ClientId, query: &QueryHash) -> bool {
675 self.queries
676 .get(query)
677 .is_some_and(|state| state.legacy_subscribers.contains(subscriber))
678 }
679
680 #[cfg(test)]
681 fn query_reads_from_table(&self, query: &QueryHash, table: &TableId) -> bool {
682 self.tables.get(table).is_some_and(|queries| queries.contains(query))
683 }
684
685 #[cfg(test)]
686 fn query_has_search_arg(&self, query: QueryHash, table_id: TableId, col_id: ColId, arg: AlgebraicValue) -> bool {
687 self.search_args
688 .queries_for_search_arg(table_id, col_id, arg)
689 .any(|hash| *hash == query)
690 }
691
692 #[cfg(test)]
693 fn table_has_search_param(&self, table_id: TableId, col_id: ColId) -> bool {
694 self.search_args
695 .search_params_for_table(table_id)
696 .any(|id| id == col_id)
697 }
698
699 fn remove_legacy_subscriptions(&mut self, client: &ClientId) {
700 if let Some(ci) = self.clients.get_mut(client) {
701 let mut queries_to_remove = Vec::new();
702 for query_hash in ci.legacy_subscriptions.iter() {
703 let Some(query_state) = self.queries.get_mut(query_hash) else {
704 tracing::warn!("Query state not found for query hash: {:?}", query_hash);
705 continue;
706 };
707
708 query_state.legacy_subscribers.remove(client);
709 if !query_state.has_subscribers() {
710 SubscriptionManager::remove_query_from_tables(
711 &mut self.tables,
712 &mut self.join_edges,
713 &mut self.indexes,
714 &mut self.search_args,
715 &query_state.query,
716 );
717 queries_to_remove.push(*query_hash);
718 }
719 }
720 ci.legacy_subscriptions.clear();
721 for query_hash in queries_to_remove {
722 self.queries.remove(&query_hash);
723 }
724 }
725 }
726
727 pub fn remove_dropped_clients(&mut self) {
729 for id in self.clients.keys().copied().collect::<Vec<_>>() {
730 if let Some(client) = self.clients.get(&id) {
731 if client.dropped.load(Ordering::Relaxed) {
732 self.remove_all_subscriptions(&id);
733 }
734 }
735 }
736 }
737
    /// Removes the subscription set `query_id` from client `client_id`.
    ///
    /// Each query in the set has the client's reference count decremented.
    /// Queries whose count reaches zero are returned, since the client no longer
    /// receives them through any of its sets. Such a query is also unregistered
    /// and deleted if it then has no subscribers at all.
    ///
    /// # Errors
    /// Returns an error if:
    /// - the client is unknown or marked dropped;
    /// - it has no subscription set with this id;
    /// - the ref-count or query-state bookkeeping is missing an entry.
    ///
    /// NOTE(review): the bookkeeping errors return midway, after the set and
    /// possibly some ref counts have already been modified.
    pub fn remove_subscription(&mut self, client_id: ClientId, query_id: ClientQueryId) -> Result<Vec<Query>, DBError> {
        let subscription_id = (client_id, query_id);
        // Dropped clients are treated as not found.
        let Some(ci) = self
            .clients
            .get_mut(&client_id)
            .filter(|ci| !ci.dropped.load(Ordering::Acquire))
        else {
            return Err(anyhow::anyhow!("Client not found: {:?}", client_id).into());
        };

        #[cfg(test)]
        ci.assert_ref_count_consistency();

        let Some(query_hashes) = ci.subscriptions.remove(&subscription_id) else {
            return Err(anyhow::anyhow!("Subscription not found: {:?}", subscription_id).into());
        };
        let mut queries_to_return = Vec::new();
        for hash in query_hashes {
            let remaining_refs = {
                let Some(count) = ci.subscription_ref_count.get_mut(&hash) else {
                    return Err(anyhow::anyhow!("Query count not found for query hash: {:?}", hash).into());
                };
                *count -= 1;
                *count
            };
            if remaining_refs > 0 {
                // Another of this client's subscription sets still includes the query.
                continue;
            }
            ci.subscription_ref_count.remove(&hash);
            // The client no longer receives this query through any set.
            let Some(query_state) = self.queries.get_mut(&hash) else {
                return Err(anyhow::anyhow!("Query state not found for query hash: {:?}", hash).into());
            };
            queries_to_return.push(query_state.query.clone());
            query_state.subscriptions.remove(&client_id);
            if !query_state.has_subscribers() {
                SubscriptionManager::remove_query_from_tables(
                    &mut self.tables,
                    &mut self.join_edges,
                    &mut self.indexes,
                    &mut self.search_args,
                    &query_state.query,
                );
                self.queries.remove(&hash);
            }
        }

        #[cfg(test)]
        ci.assert_ref_count_consistency();

        Ok(queries_to_return)
    }
793
794 pub fn add_subscription(&mut self, client: Client, query: Query, query_id: ClientQueryId) -> Result<(), DBError> {
796 self.add_subscription_multi(client, vec![query], query_id).map(|_| ())
797 }
798
    /// Adds a subscription set `query_id` containing `queries` for `client`.
    ///
    /// If the client is still tracked but marked dropped, all of its old
    /// subscriptions are removed first and a fresh `ClientInfo` is created.
    /// Hashes repeated within `queries` are ignored.
    ///
    /// Returns the queries this client was not already subscribed to through
    /// another subscription set.
    ///
    /// # Errors
    /// Returns an error if the client already has a set with this `query_id`,
    /// or if the client's ref counts and a query's subscriber set disagree.
    pub fn add_subscription_multi(
        &mut self,
        client: Client,
        queries: Vec<Query>,
        query_id: ClientQueryId,
    ) -> Result<Vec<Query>, DBError> {
        let client_id = (client.id.identity, client.id.connection_id);

        // A dropped-but-still-tracked client is cleaned up before reuse.
        if self
            .clients
            .get(&client_id)
            .is_some_and(|ci| ci.dropped.load(Ordering::Acquire))
        {
            self.remove_all_subscriptions(&client_id);
        }

        let ci = Self::get_or_make_client_info_and_inform_send_worker(
            &mut self.clients,
            &self.send_worker_queue,
            client_id,
            client,
        );

        #[cfg(test)]
        ci.assert_ref_count_consistency();
        let subscription_id = (client_id, query_id);
        let hash_set = match ci.subscriptions.try_insert(subscription_id, HashSet::new()) {
            Err(OccupiedError { .. }) => {
                return Err(anyhow::anyhow!(
                    "Subscription with id {:?} already exists for client: {:?}",
                    query_id,
                    client_id
                )
                .into());
            }
            Ok(hash_set) => hash_set,
        };
        let mut new_queries = Vec::new();

        for query in &queries {
            let hash = query.hash();
            // Skip hashes already seen in this same set.
            if !hash_set.insert(hash) {
                continue;
            }
            let query_state = self
                .queries
                .entry(hash)
                .or_insert_with(|| QueryState::new(query.clone()));

            // Must run before the subscriber is recorded below: `insert_query`
            // only registers queries that have no subscribers yet.
            Self::insert_query(
                &mut self.tables,
                &mut self.join_edges,
                &mut self.indexes,
                &mut self.search_args,
                query_state,
            );

            let entry = ci.subscription_ref_count.entry(hash).or_insert(0);
            *entry += 1;
            let is_new_entry = *entry == 1;

            // The client should join the query's subscriber set exactly when this
            // is its first reference to the query.
            let inserted = query_state.subscriptions.insert(client_id);
            if inserted != is_new_entry {
                return Err(anyhow::anyhow!("Internal error, ref count and query_state mismatch").into());
            }
            if inserted {
                new_queries.push(query.clone());
            }
        }

        #[cfg(test)]
        {
            ci.assert_ref_count_consistency();
        }

        Ok(new_queries)
    }
880
881 pub fn set_legacy_subscription(&mut self, client: Client, queries: impl IntoIterator<Item = Query>) {
888 let client_id = (client.id.identity, client.id.connection_id);
889 self.remove_legacy_subscriptions(&client_id);
891
892 let ci = Self::get_or_make_client_info_and_inform_send_worker(
894 &mut self.clients,
895 &self.send_worker_queue,
896 client_id,
897 client,
898 );
899
900 for unit in queries {
901 let hash = unit.hash();
902 ci.legacy_subscriptions.insert(hash);
903 let query_state = self
904 .queries
905 .entry(hash)
906 .or_insert_with(|| QueryState::new(unit.clone()));
907 Self::insert_query(
908 &mut self.tables,
909 &mut self.join_edges,
910 &mut self.indexes,
911 &mut self.search_args,
912 query_state,
913 );
914 query_state.legacy_subscribers.insert(client_id);
915 }
916 }
917
918 fn remove_query_from_tables(
922 tables: &mut IntMap<TableId, HashSet<QueryHash>>,
923 join_edges: &mut JoinEdges,
924 index_ids: &mut QueriedTableIndexIds,
925 search_args: &mut SearchArguments,
926 query: &Query,
927 ) {
928 let hash = query.hash();
929 join_edges.remove_query(query);
930 search_args.remove_query(&hash);
931 index_ids.delete_index_ids_for_query(query);
932 for table_id in query.table_ids() {
933 if let Entry::Occupied(mut entry) = tables.entry(table_id) {
934 let hashes = entry.get_mut();
935 if hashes.remove(&hash) && hashes.is_empty() {
936 entry.remove();
937 }
938 }
939 }
940 }
941
    /// Registers a query in the lookup structures used to find affected queries.
    ///
    /// Does nothing if the query already has subscribers, since it was
    /// registered when the first subscriber arrived. Callers therefore invoke
    /// this before recording the new subscriber in `query_state`.
    ///
    /// Registration steps:
    /// 1. The query's indexes are reference-counted in `index_ids`.
    /// 2. Tables with a search argument are registered in `search_args` and
    ///    removed from the remaining-table set.
    /// 3. If the returned table is still in that set and the query has at least
    ///    one join edge, the query goes into `join_edges` and that table is
    ///    removed from the set.
    /// 4. Every remaining table gets the query in `tables`, so any change to
    ///    that table will consider it.
    fn insert_query(
        tables: &mut IntMap<TableId, HashSet<QueryHash>>,
        join_edges: &mut JoinEdges,
        index_ids: &mut QueriedTableIndexIds,
        search_args: &mut SearchArguments,
        query_state: &QueryState,
    ) {
        if !query_state.has_subscribers() {
            let hash = query_state.query.hash;
            let query = query_state.query();
            let return_table = query.subscribed_table_id();
            let mut table_ids = query.table_ids().collect::<HashSet<_>>();

            index_ids.insert_index_ids_for_query(query);

            // Tables covered by a search argument are looked up by value instead.
            for (table_id, col_id, arg) in query_state.search_args() {
                table_ids.remove(&table_id);
                search_args.insert_query(table_id, col_id, arg, hash);
            }

            // Note: `add_query` is only called (short-circuit) when the returned
            // table was not already covered by a search argument.
            if table_ids.contains(&return_table) && join_edges.add_query(query_state) {
                table_ids.remove(&return_table);
            }

            // Whatever is left must be considered on any change to the table.
            for table_id in table_ids {
                tables.entry(table_id).or_default().insert(hash);
            }
        }
    }
979
980 #[tracing::instrument(level = "trace", skip_all)]
984 pub fn remove_all_subscriptions(&mut self, client: &ClientId) {
985 self.remove_legacy_subscriptions(client);
986 let Some(client_info) = self.remove_client_and_inform_send_worker(*client) else {
987 return;
988 };
989
990 debug_assert!(client_info.legacy_subscriptions.is_empty());
991 let mut queries_to_remove = Vec::new();
992 for query_hash in client_info.subscription_ref_count.keys() {
993 let Some(query_state) = self.queries.get_mut(query_hash) else {
994 tracing::warn!("Query state not found for query hash: {:?}", query_hash);
995 return;
996 };
997 query_state.subscriptions.remove(client);
998 if !query_state.has_subscribers() {
1000 queries_to_remove.push(*query_hash);
1001 SubscriptionManager::remove_query_from_tables(
1002 &mut self.tables,
1003 &mut self.join_edges,
1004 &mut self.indexes,
1005 &mut self.search_args,
1006 &query_state.query,
1007 );
1008 }
1009 }
1010 for query_hash in queries_to_remove {
1011 self.queries.remove(&query_hash);
1012 }
1013 }
1014
1015 fn queries_for_table_update<'a>(
1043 &'a self,
1044 table_update: &'a DatabaseTableUpdate,
1045 find_rhs_val: &impl Fn(&JoinEdge, &ProductValue) -> Option<AlgebraicValue>,
1046 ) -> impl Iterator<Item = &'a QueryHash> {
1047 let mut queries = HashSet::new();
1048 for hash in table_update
1049 .inserts
1050 .iter()
1051 .chain(table_update.deletes.iter())
1052 .flat_map(|row| self.queries_for_row(table_update.table_id, row, find_rhs_val))
1053 {
1054 queries.insert(hash);
1055 }
1056 for hash in self.tables.get(&table_update.table_id).into_iter().flatten() {
1057 queries.insert(hash);
1058 }
1059 queries.into_iter()
1060 }
1061
1062 fn queries_for_row<'a>(
1064 &'a self,
1065 table_id: TableId,
1066 row: &'a ProductValue,
1067 find_rhs_val: impl Fn(&JoinEdge, &ProductValue) -> Option<AlgebraicValue>,
1068 ) -> impl Iterator<Item = &'a QueryHash> {
1069 self.search_args
1070 .queries_for_row(table_id, row)
1071 .chain(self.join_edges.queries_for_row(table_id, row, find_rhs_val))
1072 }
1073
1074 pub fn index_ids_for_subscriptions(&self) -> &QueriedTableIndexIds {
1076 &self.indexes
1077 }
1078
    /// Evaluates every query that the transaction in `event` may affect.
    ///
    /// The results go to the send worker as a single `Broadcast` message.
    ///
    /// Details:
    /// - Only tables with at least one insert or delete are considered.
    /// - Each affected query is evaluated once, even if several tables lead to it.
    /// - Encoded updates are memoized per plan fragment and wire format, so
    ///   clients sharing a fragment and protocol share one encoding.
    /// - Clients marked dropped are skipped.
    /// - A fragment that fails to evaluate yields an error entry for each of its
    ///   clients instead of updates.
    ///
    /// Returns the execution metrics accumulated during evaluation and encoding.
    ///
    /// # Panics
    /// Panics if:
    /// - `event.status.database_update()` is unwrapped without a value;
    /// - a read in the join lookup fails;
    /// - the send worker's queue has been closed.
    #[tracing::instrument(level = "trace", skip_all)]
    pub fn eval_updates_sequential(
        &self,
        tx: &DeltaTx,
        event: Arc<ModuleEvent>,
        caller: Option<Arc<ClientConnectionSender>>,
    ) -> ExecutionMetrics {
        use FormatSwitch::{Bsatn, Json};

        let tables = &event.status.database_update().unwrap().tables;

        let span = tracing::info_span!("eval_incr").entered();

        #[derive(Default)]
        struct FoldState {
            updates: Vec<ClientUpdate>,
            errs: Vec<(ClientId, Box<str>)>,
            metrics: ExecutionMetrics,
        }

        // Reads `edge.rhs_col` from the first row of `edge.rhs_table` whose
        // `edge.rhs_join_col` equals `row`'s `edge.lhs_join_col` value, or `None`
        // when no such row exists in `tx`.
        fn find_rhs_val(edge: &JoinEdge, row: &ProductValue, tx: &DeltaTx) -> Option<AlgebraicValue> {
            tx.iter_by_col_eq(
                edge.rhs_table,
                edge.rhs_join_col,
                &row.elements[edge.lhs_join_col.idx()],
            )
            .expect("This read should always succeed, and it's a bug if it doesn't")
            .next()
            .map(|row| {
                row.read_col(edge.rhs_col)
                    .expect("This read should always succeed, and it's a bug if it doesn't")
            })
        }

        let FoldState { updates, errs, metrics } = tables
            .iter()
            .filter(|table| !table.inserts.is_empty() || !table.deletes.is_empty())
            .flat_map(|table_update| {
                self.queries_for_table_update(table_update, &|edge, row| find_rhs_val(edge, row, tx))
            })
            .filter({
                // A query may be reached from several tables; evaluate it only once.
                let mut seen = HashSet::new();
                move |&hash| seen.insert(hash)
            })
            .flat_map(|hash| {
                let qstate = &self.queries[hash];
                qstate
                    .query
                    .plans_fragments()
                    .map(move |plan_fragment| (qstate, plan_fragment))
            })
            .fold(FoldState::default(), |mut acc, (qstate, plan)| {
                let table_id = plan.subscribed_table_id();
                let table_name = plan.subscribed_table_name().clone();
                // Per-fragment memo of the encoded update, one slot per wire format.
                let mut ops_bin_uncompressed: Option<(CompressableQueryUpdate<BsatnFormat>, _, _)> = None;
                let mut ops_json: Option<(QueryUpdate<JsonFormat>, _, _)> = None;

                // Encodes `updates` in format `F` the first time it is requested
                // (counting the encoded size in `bytes_scanned` once), then returns
                // a clone of the memoized encoding. Each call adds the size to
                // `bytes_sent_to_clients`.
                fn memo_encode<F: WebsocketFormat>(
                    updates: &UpdatesRelValue<'_>,
                    memory: &mut Option<(F::QueryUpdate, u64, usize)>,
                    metrics: &mut ExecutionMetrics,
                ) -> SingleQueryUpdate<F> {
                    let (update, num_rows, num_bytes) = memory
                        .get_or_insert_with(|| {
                            let encoded = updates.encode::<F>();
                            metrics.bytes_scanned += encoded.2;
                            encoded
                        })
                        .clone();
                    metrics.bytes_sent_to_clients += num_bytes;
                    SingleQueryUpdate { update, num_rows }
                }

                // Subscribers of this query that have not been marked dropped.
                let clients_for_query = qstate.all_clients().filter(|id| {
                    self.clients
                        .get(*id)
                        .is_some_and(|info| !info.dropped.load(Ordering::Acquire))
                });

                match eval_delta(tx, &mut acc.metrics, plan) {
                    Err(err) => {
                        tracing::error!(
                            message = "Query errored during tx update",
                            sql = qstate.query.sql,
                            reason = ?err,
                        );
                        let err = DBError::WithSql {
                            sql: qstate.query.sql.as_str().into(),
                            error: Box::new(err.into()),
                        }
                        .to_string()
                        .into_boxed_str();

                        acc.errs.extend(clients_for_query.map(|id| (*id, err.clone())))
                    }
                    // NOTE(review): presumably "no changes for this fragment"; nothing is sent.
                    Ok(None) => {}
                    Ok(Some(delta_updates)) => {
                        // Encode once per protocol, in the format each client asked for.
                        let row_iter = clients_for_query.map(|id| {
                            let client = &self.clients[id].outbound_ref;
                            let update = match client.config.protocol {
                                Protocol::Binary => Bsatn(memo_encode::<BsatnFormat>(
                                    &delta_updates,
                                    &mut ops_bin_uncompressed,
                                    &mut acc.metrics,
                                )),
                                Protocol::Text => Json(memo_encode::<JsonFormat>(
                                    &delta_updates,
                                    &mut ops_json,
                                    &mut acc.metrics,
                                )),
                            };
                            ClientUpdate {
                                id: *id,
                                table_id,
                                table_name: table_name.clone(),
                                update,
                            }
                        });
                        acc.updates.extend(row_iter);
                    }
                }

                acc
            });

        // Delivery (grouping and sending) happens on the send worker.
        self.send_worker_queue
            .send(SendWorkerMessage::Broadcast(ComputedQueries {
                updates,
                errs,
                event,
                caller,
            }))
            .expect("send worker has panicked, or otherwise dropped its recv queue!");

        drop(span);

        metrics
    }
1271}
1272
/// The send worker's record of a client registered via `AddClient`.
struct SendWorkerClient {
    /// Shared with the manager's `ClientInfo` for the same client.
    dropped: Arc<AtomicBool>,
    /// Handle used to send messages to the client.
    outbound_ref: Client,
}
1282
/// Background task that processes `SendWorkerMessage`s and delivers messages to clients.
struct SendWorker {
    /// Incoming work, fed by `BroadcastQueue`.
    rx: mpsc::UnboundedReceiver<SendWorkerMessage>,

    /// Gauge of queued messages; decremented as each message is received.
    queue_length_metric: Option<IntGauge>,

    /// Clients added via `AddClient` and not yet removed via `RemoveClient`.
    clients: HashMap<ClientId, SendWorkerClient>,

    /// When set, this identity's queue-length metric label is removed on drop.
    database_identity_to_clean_up_metric: Option<Identity>,
}
1310
1311impl Drop for SendWorker {
1312 fn drop(&mut self) {
1313 if let Some(identity) = self.database_identity_to_clean_up_metric {
1314 let _ = WORKER_METRICS
1315 .subscription_send_queue_length
1316 .remove_label_values(&identity);
1317 }
1318 }
1319}
1320
/// Cloneable handle for enqueueing work on a send worker.
#[derive(Debug, Clone)]
pub struct BroadcastQueue(SenderWithGauge<SendWorkerMessage>);
1323
/// Failure to enqueue work because the send worker's receiving half is closed.
#[derive(thiserror::Error, Debug)]
#[error(transparent)]
pub struct BroadcastError(#[from] mpsc::error::SendError<SendWorkerMessage>);
1327
1328impl BroadcastQueue {
1329 fn send(&self, message: SendWorkerMessage) -> Result<(), BroadcastError> {
1330 self.0.send(message)?;
1331 Ok(())
1332 }
1333
1334 pub fn send_client_message(
1335 &self,
1336 recipient: Arc<ClientConnectionSender>,
1337 message: impl Into<SerializableMessage>,
1338 ) -> Result<(), BroadcastError> {
1339 self.0.send(SendWorkerMessage::SendMessage {
1340 recipient,
1341 message: message.into(),
1342 })?;
1343 Ok(())
1344 }
1345}
1346pub fn spawn_send_worker(metric_database_identity: Option<Identity>) -> BroadcastQueue {
1347 SendWorker::spawn_new(metric_database_identity)
1348}
1349impl SendWorker {
1350 fn new(
1351 rx: mpsc::UnboundedReceiver<SendWorkerMessage>,
1352 queue_length_metric: Option<IntGauge>,
1353 database_identity_to_clean_up_metric: Option<Identity>,
1354 ) -> Self {
1355 Self {
1356 rx,
1357 queue_length_metric,
1358 clients: Default::default(),
1359 database_identity_to_clean_up_metric,
1360 }
1361 }
1362
1363 fn spawn_new(metric_database_identity: Option<Identity>) -> BroadcastQueue {
1367 let metric = metric_database_identity.map(|identity| {
1368 WORKER_METRICS
1369 .subscription_send_queue_length
1370 .with_label_values(&identity)
1371 });
1372 let (send_worker_tx, rx) = mpsc::unbounded_channel();
1373 tokio::spawn(Self::new(rx, metric.clone(), metric_database_identity).run());
1374 BroadcastQueue(SenderWithGauge::new(send_worker_tx, metric))
1375 }
1376
1377 async fn run(mut self) {
1378 while let Some(message) = self.rx.recv().await {
1379 if let Some(metric) = &self.queue_length_metric {
1380 metric.dec();
1381 }
1382
1383 match message {
1384 SendWorkerMessage::AddClient {
1385 client_id,
1386 dropped,
1387 outbound_ref,
1388 } => {
1389 self.clients
1390 .insert(client_id, SendWorkerClient { dropped, outbound_ref });
1391 }
1392 SendWorkerMessage::SendMessage { recipient, message } => {
1393 let _ = recipient.send_message(message);
1394 }
1395 SendWorkerMessage::RemoveClient(client_id) => {
1396 self.clients.remove(&client_id);
1397 }
1398 SendWorkerMessage::Broadcast(queries) => {
1399 self.send_one_computed_queries(queries);
1400 }
1401 }
1402 }
1403 }
1404
1405 fn send_one_computed_queries(
1406 &self,
1407 ComputedQueries {
1408 updates,
1409 errs,
1410 event,
1411 caller,
1412 }: ComputedQueries,
1413 ) {
1414 use FormatSwitch::{Bsatn, Json};
1415
1416 let clients_with_errors = errs.iter().map(|(id, _)| id).collect::<HashSet<_>>();
1417
1418 let span = tracing::info_span!("eval_incr_group_messages_by_client");
1419
1420 let mut eval = updates
1421 .into_iter()
1422 .filter(|upd| !clients_with_errors.contains(&upd.id))
1424 .fold(
1430 HashMap::<(ClientId, TableId), SwitchedTableUpdate>::new(),
1431 |mut tables, upd| {
1432 match tables.entry((upd.id, upd.table_id)) {
1433 Entry::Occupied(mut entry) => match entry.get_mut().zip_mut(upd.update) {
1434 Bsatn((tbl_upd, update)) => tbl_upd.push(update),
1435 Json((tbl_upd, update)) => tbl_upd.push(update),
1436 },
1437 Entry::Vacant(entry) => drop(entry.insert(match upd.update {
1438 Bsatn(update) => Bsatn(TableUpdate::new(upd.table_id, (&*upd.table_name).into(), update)),
1439 Json(update) => Json(TableUpdate::new(upd.table_id, (&*upd.table_name).into(), update)),
1440 })),
1441 }
1442 tables
1443 },
1444 )
1445 .into_iter()
1446 .fold(
1450 HashMap::<ClientId, SwitchedDbUpdate>::new(),
1451 |mut updates, ((id, _), update)| {
1452 let entry = updates.entry(id);
1453 let entry = entry.or_insert_with(|| match &update {
1454 Bsatn(_) => Bsatn(<_>::default()),
1455 Json(_) => Json(<_>::default()),
1456 });
1457 match entry.zip_mut(update) {
1458 Bsatn((list, elem)) => list.tables.push(elem),
1459 Json((list, elem)) => list.tables.push(elem),
1460 }
1461 updates
1462 },
1463 );
1464
1465 drop(clients_with_errors);
1466 drop(span);
1467
1468 let _span = tracing::info_span!("eval_send").entered();
1469
1470 if let Some(caller) = caller {
1477 let caller_id = (caller.id.identity, caller.id.connection_id);
1478 let database_update = eval
1479 .remove(&caller_id)
1480 .map(|update| SubscriptionUpdateMessage::from_event_and_update(&event, update))
1481 .unwrap_or_else(|| {
1482 SubscriptionUpdateMessage::default_for_protocol(caller.config.protocol, event.request_id)
1483 });
1484 let message = TransactionUpdateMessage {
1485 event: Some(event.clone()),
1486 database_update,
1487 };
1488 send_to_client(&caller, message);
1489 }
1490
1491 for (id, update) in eval {
1493 let database_update = SubscriptionUpdateMessage::from_event_and_update(&event, update);
1494 let client = self.clients[&id].outbound_ref.clone();
1495 let event = client.config.tx_update_full.then(|| event.clone());
1497 let message = TransactionUpdateMessage { event, database_update };
1498 send_to_client(&client, message);
1499 }
1500
1501 for (id, message) in errs {
1503 if let Some(client) = self.clients.get(&id) {
1504 client.dropped.store(true, Ordering::Release);
1505 send_to_client(
1506 &client.outbound_ref,
1507 SubscriptionMessage {
1508 request_id: None,
1509 query_id: None,
1510 timer: None,
1511 result: SubscriptionResult::Error(SubscriptionError {
1512 table_id: None,
1513 message,
1514 }),
1515 },
1516 );
1517 }
1518 }
1519 }
1520}
1521
1522fn send_to_client(client: &ClientConnectionSender, message: impl Into<SerializableMessage>) {
1523 if let Err(e) = client.send_message(message) {
1524 tracing::warn!(%client.id, "failed to send update message to client: {e}")
1525 }
1526}
1527
/// Unit tests for `SubscriptionManager` bookkeeping and for the caller notification
/// performed by the send worker.
///
/// Each test enters a tokio runtime before building the manager. This is presumably
/// because constructing it spawns the send worker via `tokio::spawn`, which requires a
/// runtime context.
#[cfg(test)]
mod tests {
    use std::{sync::Arc, time::Duration};

    use spacetimedb_client_api_messages::websocket::QueryId;
    use spacetimedb_lib::AlgebraicValue;
    use spacetimedb_lib::{error::ResultTest, identity::AuthCtx, AlgebraicType, ConnectionId, Identity, Timestamp};
    use spacetimedb_primitives::{ColId, TableId};
    use spacetimedb_sats::product;
    use spacetimedb_subscription::SubscriptionPlan;

    use super::{Plan, SubscriptionManager};
    use crate::db::relational_db::tests_utils::with_read_only;
    use crate::execution_context::Workload;
    use crate::host::module_host::DatabaseTableUpdate;
    use crate::sql::ast::SchemaViewer;
    use crate::subscription::module_subscription_manager::ClientQueryId;
    use crate::{
        client::{ClientActorId, ClientConfig, ClientConnectionSender, ClientName},
        db::relational_db::{tests_utils::TestDB, RelationalDB},
        energy::EnergyQuanta,
        host::{
            module_host::{DatabaseUpdate, EventStatus, ModuleEvent, ModuleFunctionCall},
            ArgsTuple,
        },
        subscription::execution_unit::QueryHash,
    };

    /// Creates a test table `name` with a single `u8` column `a`.
    fn create_table(db: &RelationalDB, name: &str) -> ResultTest<TableId> {
        Ok(db.create_table_for_test(name, &[("a", AlgebraicType::U8)], &[])?)
    }

    /// Compiles `sql` into a shared [`Plan`] under a testing auth context.
    ///
    /// The query hash is derived from the SQL text, the caller identity and whether the
    /// query has parameters. Panics if compilation fails.
    fn compile_plan(db: &RelationalDB, sql: &str) -> ResultTest<Arc<Plan>> {
        with_read_only(db, |tx| {
            let auth = AuthCtx::for_testing();
            let tx = SchemaViewer::new(&*tx, &auth);
            let (plans, has_param) = SubscriptionPlan::compile(sql, &tx, &auth).unwrap();
            let hash = QueryHash::from_string(sql, auth.caller, has_param);
            Ok(Arc::new(Plan::new(plans, hash, sql.into())))
        })
    }

    /// Builds a client id from the zero identity and the given connection id.
    fn id(connection_id: u128) -> (Identity, ConnectionId) {
        (Identity::ZERO, ConnectionId::from_u128(connection_id))
    }

    /// Builds a dummy client sender whose id matches `id(connection_id)`.
    fn client(connection_id: u128) -> ClientConnectionSender {
        let (identity, connection_id) = id(connection_id);
        ClientConnectionSender::dummy(
            ClientActorId {
                identity,
                connection_id,
                name: ClientName(0),
            },
            ClientConfig::for_test(),
        )
    }

    /// A legacy subscription registers the query, the client's subscription, and the
    /// query's table.
    #[test]
    fn test_subscribe_legacy() -> ResultTest<()> {
        let db = TestDB::durable()?;

        let table_id = create_table(&db, "T")?;
        let sql = "select * from T";
        let plan = compile_plan(&db, sql)?;
        let hash = plan.hash();

        let id = id(0);
        let client = Arc::new(client(0));

        let runtime = tokio::runtime::Runtime::new().unwrap();
        let _rt = runtime.enter();

        let mut subscriptions = SubscriptionManager::for_test_without_metrics();
        subscriptions.set_legacy_subscription(client.clone(), [plan.clone()]);

        assert!(subscriptions.contains_query(&hash));
        assert!(subscriptions.contains_legacy_subscription(&id, &hash));
        assert!(subscriptions.query_reads_from_table(&hash, &table_id));

        Ok(())
    }

    /// A single-query subscription maps the query to the table it reads.
    #[test]
    fn test_subscribe_single_adds_table_mapping() -> ResultTest<()> {
        let db = TestDB::durable()?;

        let table_id = create_table(&db, "T")?;
        let sql = "select * from T";
        let plan = compile_plan(&db, sql)?;
        let hash = plan.hash();

        let client = Arc::new(client(0));

        let query_id: ClientQueryId = QueryId::new(1);

        let runtime = tokio::runtime::Runtime::new().unwrap();
        let _rt = runtime.enter();

        let mut subscriptions = SubscriptionManager::for_test_without_metrics();
        subscriptions.add_subscription(client.clone(), plan.clone(), query_id)?;
        assert!(subscriptions.query_reads_from_table(&hash, &table_id));

        Ok(())
    }

    /// Removing the only subscription to a query also removes its table mapping.
    #[test]
    fn test_unsubscribe_from_the_only_subscription() -> ResultTest<()> {
        let db = TestDB::durable()?;

        let table_id = create_table(&db, "T")?;
        let sql = "select * from T";
        let plan = compile_plan(&db, sql)?;
        let hash = plan.hash();

        let client = Arc::new(client(0));

        let query_id: ClientQueryId = QueryId::new(1);

        let runtime = tokio::runtime::Runtime::new().unwrap();
        let _rt = runtime.enter();

        let mut subscriptions = SubscriptionManager::for_test_without_metrics();
        subscriptions.add_subscription(client.clone(), plan.clone(), query_id)?;
        assert!(subscriptions.query_reads_from_table(&hash, &table_id));

        let client_id = (client.id.identity, client.id.connection_id);
        subscriptions.remove_subscription(client_id, query_id)?;
        assert!(!subscriptions.query_reads_from_table(&hash, &table_id));

        Ok(())
    }

    /// Unsubscribing with a query id the client never used returns an error.
    #[test]
    fn test_unsubscribe_with_unknown_query_id_fails() -> ResultTest<()> {
        let db = TestDB::durable()?;

        create_table(&db, "T")?;
        let sql = "select * from T";
        let plan = compile_plan(&db, sql)?;

        let client = Arc::new(client(0));

        let query_id: ClientQueryId = QueryId::new(1);

        let runtime = tokio::runtime::Runtime::new().unwrap();
        let _rt = runtime.enter();

        let mut subscriptions = SubscriptionManager::for_test_without_metrics();
        subscriptions.add_subscription(client.clone(), plan.clone(), query_id)?;

        let client_id = (client.id.identity, client.id.connection_id);
        assert!(subscriptions.remove_subscription(client_id, QueryId::new(2)).is_err());

        Ok(())
    }

    /// The same query subscribed under two query ids survives removal of one of them.
    #[test]
    fn test_subscribe_and_unsubscribe_with_duplicate_queries() -> ResultTest<()> {
        let db = TestDB::durable()?;

        let table_id = create_table(&db, "T")?;
        let sql = "select * from T";
        let plan = compile_plan(&db, sql)?;
        let hash = plan.hash();

        let client = Arc::new(client(0));

        let query_id: ClientQueryId = QueryId::new(1);

        let runtime = tokio::runtime::Runtime::new().unwrap();
        let _rt = runtime.enter();

        let mut subscriptions = SubscriptionManager::for_test_without_metrics();
        subscriptions.add_subscription(client.clone(), plan.clone(), query_id)?;
        subscriptions.add_subscription(client.clone(), plan.clone(), QueryId::new(2))?;

        let client_id = (client.id.identity, client.id.connection_id);
        subscriptions.remove_subscription(client_id, query_id)?;

        assert!(subscriptions.query_reads_from_table(&hash, &table_id));

        Ok(())
    }

    /// Multi-query variant of the duplicate test.
    ///
    /// The query is reported as added only on its first subscription, and as removed only
    /// when its last subscription goes away.
    #[test]
    fn test_subscribe_and_unsubscribe_with_duplicate_queries_multi() -> ResultTest<()> {
        let db = TestDB::durable()?;

        let table_id = create_table(&db, "T")?;
        let sql = "select * from T";
        let plan = compile_plan(&db, sql)?;
        let hash = plan.hash();

        let client = Arc::new(client(0));

        let query_id: ClientQueryId = QueryId::new(1);

        let runtime = tokio::runtime::Runtime::new().unwrap();
        let _rt = runtime.enter();

        let mut subscriptions = SubscriptionManager::for_test_without_metrics();
        let added_query = subscriptions.add_subscription_multi(client.clone(), vec![plan.clone()], query_id)?;
        assert!(added_query.len() == 1);
        assert_eq!(added_query[0].hash, hash);
        let second_one = subscriptions.add_subscription_multi(client.clone(), vec![plan.clone()], QueryId::new(2))?;
        assert!(second_one.is_empty());

        let client_id = (client.id.identity, client.id.connection_id);
        let removed_queries = subscriptions.remove_subscription(client_id, query_id)?;
        assert!(removed_queries.is_empty());

        assert!(subscriptions.query_reads_from_table(&hash, &table_id));
        let removed_queries = subscriptions.remove_subscription(client_id, QueryId::new(2))?;
        assert!(removed_queries.len() == 1);
        assert_eq!(removed_queries[0].hash, hash);

        assert!(!subscriptions.query_reads_from_table(&hash, &table_id));

        Ok(())
    }

    /// A query shared by three clients keeps its table mapping until the last client
    /// unsubscribes.
    #[test]
    fn test_unsubscribe_doesnt_remove_other_clients() -> ResultTest<()> {
        let db = TestDB::durable()?;

        let table_id = create_table(&db, "T")?;
        let sql = "select * from T";
        let plan = compile_plan(&db, sql)?;
        let hash = plan.hash();

        let clients = (0..3).map(|i| Arc::new(client(i))).collect::<Vec<_>>();

        let query_id: ClientQueryId = QueryId::new(1);

        let runtime = tokio::runtime::Runtime::new().unwrap();
        let _rt = runtime.enter();

        let mut subscriptions = SubscriptionManager::for_test_without_metrics();
        subscriptions.add_subscription(clients[0].clone(), plan.clone(), query_id)?;
        subscriptions.add_subscription(clients[1].clone(), plan.clone(), query_id)?;
        subscriptions.add_subscription(clients[2].clone(), plan.clone(), query_id)?;

        assert!(subscriptions.query_reads_from_table(&hash, &table_id));

        let client_ids = clients
            .iter()
            .map(|client| (client.id.identity, client.id.connection_id))
            .collect::<Vec<_>>();
        subscriptions.remove_subscription(client_ids[0], query_id)?;
        assert!(subscriptions.query_reads_from_table(&hash, &table_id));
        subscriptions.remove_subscription(client_ids[1], query_id)?;
        assert!(subscriptions.query_reads_from_table(&hash, &table_id));
        subscriptions.remove_subscription(client_ids[2], query_id)?;
        assert!(!subscriptions.query_reads_from_table(&hash, &table_id));

        Ok(())
    }

    /// Like the previous test, but using `remove_all_subscriptions`.
    ///
    /// It also checks that each removed client is forgotten by the manager.
    #[test]
    fn test_unsubscribe_all_doesnt_remove_other_clients() -> ResultTest<()> {
        let db = TestDB::durable()?;

        let table_id = create_table(&db, "T")?;
        let sql = "select * from T";
        let plan = compile_plan(&db, sql)?;
        let hash = plan.hash();

        let clients = (0..3).map(|i| Arc::new(client(i))).collect::<Vec<_>>();

        let query_id: ClientQueryId = QueryId::new(1);

        let runtime = tokio::runtime::Runtime::new().unwrap();
        let _rt = runtime.enter();

        let mut subscriptions = SubscriptionManager::for_test_without_metrics();
        subscriptions.add_subscription(clients[0].clone(), plan.clone(), query_id)?;
        subscriptions.add_subscription(clients[1].clone(), plan.clone(), query_id)?;
        subscriptions.add_subscription(clients[2].clone(), plan.clone(), query_id)?;

        assert!(subscriptions.query_reads_from_table(&hash, &table_id));

        let client_ids = clients
            .iter()
            .map(|client| (client.id.identity, client.id.connection_id))
            .collect::<Vec<_>>();
        subscriptions.remove_all_subscriptions(&client_ids[0]);
        assert!(!subscriptions.contains_client(&client_ids[0]));
        assert!(subscriptions.query_reads_from_table(&hash, &table_id));
        subscriptions.remove_all_subscriptions(&client_ids[1]);
        assert!(subscriptions.query_reads_from_table(&hash, &table_id));
        assert!(!subscriptions.contains_client(&client_ids[1]));
        subscriptions.remove_all_subscriptions(&client_ids[2]);
        assert!(!subscriptions.query_reads_from_table(&hash, &table_id));
        assert!(!subscriptions.contains_client(&client_ids[2]));

        Ok(())
    }

    /// One client subscribes to three queries over three tables.
    ///
    /// Removing one query leaves the others intact; removing all clears every mapping.
    #[test]
    fn test_multiple_queries() -> ResultTest<()> {
        let db = TestDB::durable()?;

        let table_names = ["T", "S", "U"];
        let table_ids = table_names
            .iter()
            .map(|name| create_table(&db, name))
            .collect::<ResultTest<Vec<_>>>()?;
        let queries = table_names
            .iter()
            .map(|name| format!("select * from {}", name))
            .map(|sql| compile_plan(&db, &sql))
            .collect::<ResultTest<Vec<_>>>()?;

        let client = Arc::new(client(0));

        let runtime = tokio::runtime::Runtime::new().unwrap();
        let _rt = runtime.enter();

        let mut subscriptions = SubscriptionManager::for_test_without_metrics();
        subscriptions.add_subscription(client.clone(), queries[0].clone(), QueryId::new(1))?;
        subscriptions.add_subscription(client.clone(), queries[1].clone(), QueryId::new(2))?;
        subscriptions.add_subscription(client.clone(), queries[2].clone(), QueryId::new(3))?;
        for i in 0..3 {
            assert!(subscriptions.query_reads_from_table(&queries[i].hash(), &table_ids[i]));
        }

        let client_id = (client.id.identity, client.id.connection_id);
        subscriptions.remove_subscription(client_id, QueryId::new(1))?;
        assert!(!subscriptions.query_reads_from_table(&queries[0].hash(), &table_ids[0]));
        for i in 1..3 {
            assert!(subscriptions.query_reads_from_table(&queries[i].hash(), &table_ids[i]));
        }

        subscriptions.remove_all_subscriptions(&client_id);
        for i in 0..3 {
            assert!(!subscriptions.query_reads_from_table(&queries[i].hash(), &table_ids[i]));
        }

        Ok(())
    }

    /// Multi-query variant of `test_multiple_queries`.
    ///
    /// It also checks the added/removed query lists returned by the manager.
    #[test]
    fn test_multiple_query_sets() -> ResultTest<()> {
        let db = TestDB::durable()?;

        let table_names = ["T", "S", "U"];
        let table_ids = table_names
            .iter()
            .map(|name| create_table(&db, name))
            .collect::<ResultTest<Vec<_>>>()?;
        let queries = table_names
            .iter()
            .map(|name| format!("select * from {}", name))
            .map(|sql| compile_plan(&db, &sql))
            .collect::<ResultTest<Vec<_>>>()?;

        let client = Arc::new(client(0));

        let runtime = tokio::runtime::Runtime::new().unwrap();
        let _rt = runtime.enter();

        let mut subscriptions = SubscriptionManager::for_test_without_metrics();
        let added = subscriptions.add_subscription_multi(client.clone(), vec![queries[0].clone()], QueryId::new(1))?;
        assert_eq!(added.len(), 1);
        assert_eq!(added[0].hash, queries[0].hash());
        let added = subscriptions.add_subscription_multi(client.clone(), vec![queries[1].clone()], QueryId::new(2))?;
        assert_eq!(added.len(), 1);
        assert_eq!(added[0].hash, queries[1].hash());
        let added = subscriptions.add_subscription_multi(client.clone(), vec![queries[2].clone()], QueryId::new(3))?;
        assert_eq!(added.len(), 1);
        assert_eq!(added[0].hash, queries[2].hash());
        for i in 0..3 {
            assert!(subscriptions.query_reads_from_table(&queries[i].hash(), &table_ids[i]));
        }

        let client_id = (client.id.identity, client.id.connection_id);
        let removed = subscriptions.remove_subscription(client_id, QueryId::new(1))?;
        assert_eq!(removed.len(), 1);
        assert_eq!(removed[0].hash, queries[0].hash());
        assert!(!subscriptions.query_reads_from_table(&queries[0].hash(), &table_ids[0]));
        for i in 1..3 {
            assert!(subscriptions.query_reads_from_table(&queries[i].hash(), &table_ids[i]));
        }

        subscriptions.remove_all_subscriptions(&client_id);
        for i in 0..3 {
            assert!(!subscriptions.query_reads_from_table(&queries[i].hash(), &table_ids[i]));
        }

        Ok(())
    }

    /// Queries of the form `where a = <const>` are tracked by search argument, not by table.
    ///
    /// `query_has_search_arg` holds for them while `query_reads_from_table` does not.
    /// Removing one query removes only its own search argument.
    #[test]
    fn test_internals_for_search_args() -> ResultTest<()> {
        let db = TestDB::durable()?;

        let table_id = create_table(&db, "t")?;

        let client = Arc::new(client(0));

        let runtime = tokio::runtime::Runtime::new().unwrap();
        let _rt = runtime.enter();

        let mut subscriptions = SubscriptionManager::for_test_without_metrics();

        let queries = (0u8..5)
            .map(|name| format!("select * from t where a = {}", name))
            .map(|sql| compile_plan(&db, &sql))
            .collect::<ResultTest<Vec<_>>>()?;

        for (i, query) in queries.iter().enumerate().take(5) {
            let added =
                subscriptions.add_subscription_multi(client.clone(), vec![query.clone()], QueryId::new(i as u32))?;
            assert_eq!(added.len(), 1);
            assert_eq!(added[0].hash, queries[i].hash);
        }

        assert!(subscriptions.table_has_search_param(table_id, ColId(0)));

        for (i, query) in queries.iter().enumerate().take(5) {
            assert!(subscriptions.query_has_search_arg(query.hash, table_id, ColId(0), AlgebraicValue::U8(i as u8)));

            assert!(!subscriptions.query_reads_from_table(&queries[i].hash, &table_id));
        }

        let query_id = QueryId::new(2);
        let client_id = (client.id.identity, client.id.connection_id);
        let removed = subscriptions.remove_subscription(client_id, query_id)?;
        assert_eq!(removed.len(), 1);

        assert!(subscriptions.table_has_search_param(table_id, ColId(0)));

        assert!(!subscriptions.query_reads_from_table(&queries[2].hash, &table_id));
        assert!(!subscriptions.query_has_search_arg(queries[2].hash, table_id, ColId(0), AlgebraicValue::U8(2)));

        for (i, query) in queries.iter().enumerate().take(5) {
            if i != 2 {
                assert!(subscriptions.query_has_search_arg(
                    query.hash,
                    table_id,
                    ColId(0),
                    AlgebraicValue::U8(i as u8)
                ));
            }
        }

        subscriptions.remove_all_subscriptions(&client_id);

        assert!(!subscriptions.table_has_search_param(table_id, ColId(0)));
        for (i, query) in queries.iter().enumerate().take(5) {
            assert!(!subscriptions.query_has_search_arg(query.hash, table_id, ColId(0), AlgebraicValue::U8(i as u8)));
        }

        Ok(())
    }

    /// `queries_for_table_update` selects the filtered queries matching a value in the
    /// inserts or deletes, plus the unfiltered scan (`queries[5]`).
    #[test]
    fn test_search_args_for_selects() -> ResultTest<()> {
        let db = TestDB::durable()?;

        let table_id = create_table(&db, "t")?;

        let client = Arc::new(client(0));

        let runtime = tokio::runtime::Runtime::new().unwrap();
        let _rt = runtime.enter();

        let mut subscriptions = SubscriptionManager::for_test_without_metrics();

        let queries = (0u8..5)
            .map(|name| format!("select * from t where a = {}", name))
            .chain(std::iter::once(String::from("select * from t")))
            .map(|sql| compile_plan(&db, &sql))
            .collect::<ResultTest<Vec<_>>>()?;

        for (i, query) in queries.iter().enumerate() {
            subscriptions.add_subscription_multi(client.clone(), vec![query.clone()], QueryId::new(i as u32))?;
        }

        let hash_for_2 = queries[2].hash;
        let hash_for_3 = queries[3].hash;
        let hash_for_5 = queries[5].hash;

        // Insert `a = 2` and delete `a = 3`: both filters match, plus the scan.
        let table_update = DatabaseTableUpdate {
            table_id,
            table_name: "t".into(),
            inserts: [product![2u8]].into(),
            deletes: [product![3u8]].into(),
        };

        let hashes = subscriptions
            .queries_for_table_update(&table_update, &|_, _| None)
            .collect::<Vec<_>>();

        assert!(hashes.len() == 3);
        assert!(hashes.contains(&&hash_for_2));
        assert!(hashes.contains(&&hash_for_3));
        assert!(hashes.contains(&&hash_for_5));

        // Values outside 0..5 match no filter; only the scan is selected.
        let table_update = DatabaseTableUpdate {
            table_id,
            table_name: "t".into(),
            inserts: [product![8u8]].into(),
            deletes: [product![9u8]].into(),
        };

        let hashes = subscriptions
            .queries_for_table_update(&table_update, &|_, _| None)
            .collect::<Vec<_>>();

        assert!(hashes.len() == 1);
        assert!(hashes.contains(&&hash_for_5));

        Ok(())
    }

    /// For `t join s ... where s.a = 1`, updates to `t` always select the query.
    ///
    /// Updates to `s` select it only when the row's `a` equals 1.
    #[test]
    fn test_search_args_for_join() -> ResultTest<()> {
        let db = TestDB::durable()?;

        let schema = [("id", AlgebraicType::U8), ("a", AlgebraicType::U8)];

        let t_id = db.create_table_for_test("t", &schema, &[0.into()])?;
        let s_id = db.create_table_for_test("s", &schema, &[0.into()])?;

        let client = Arc::new(client(0));

        let runtime = tokio::runtime::Runtime::new().unwrap();
        let _rt = runtime.enter();

        let mut subscriptions = SubscriptionManager::for_test_without_metrics();

        let plan = compile_plan(&db, "select t.* from t join s on t.id = s.id where s.a = 1")?;
        let hash = plan.hash;

        subscriptions.add_subscription_multi(client.clone(), vec![plan], QueryId::new(0))?;

        let table_update = DatabaseTableUpdate {
            table_id: t_id,
            table_name: "t".into(),
            inserts: [product![0u8, 0u8]].into(),
            deletes: [].into(),
        };

        let hashes = subscriptions
            .queries_for_table_update(&table_update, &|_, _| None)
            .cloned()
            .collect::<Vec<_>>();

        assert_eq!(hashes, vec![hash]);

        // `s` row with `a = 1` matches the filter.
        let table_update = DatabaseTableUpdate {
            table_id: s_id,
            table_name: "s".into(),
            inserts: [product![0u8, 1u8]].into(),
            deletes: [].into(),
        };

        let hashes = subscriptions
            .queries_for_table_update(&table_update, &|_, _| None)
            .cloned()
            .collect::<Vec<_>>();

        assert_eq!(hashes, vec![hash]);

        // `s` row with `a = 2` does not match the filter.
        let table_update = DatabaseTableUpdate {
            table_id: s_id,
            table_name: "s".into(),
            inserts: [product![0u8, 2u8]].into(),
            deletes: [].into(),
        };

        let hashes = subscriptions
            .queries_for_table_update(&table_update, &|_, _| None)
            .cloned()
            .collect::<Vec<_>>();

        assert!(hashes.is_empty());

        Ok(())
    }

    /// Reusing a query id for the same client is rejected by `add_subscription`.
    #[test]
    fn test_subscribe_fails_with_duplicate_request_id() -> ResultTest<()> {
        let db = TestDB::durable()?;

        create_table(&db, "T")?;
        let sql = "select * from T";
        let plan = compile_plan(&db, sql)?;

        let client = Arc::new(client(0));

        let query_id: ClientQueryId = QueryId::new(1);

        let runtime = tokio::runtime::Runtime::new().unwrap();
        let _rt = runtime.enter();

        let mut subscriptions = SubscriptionManager::for_test_without_metrics();
        subscriptions.add_subscription(client.clone(), plan.clone(), query_id)?;

        assert!(subscriptions
            .add_subscription(client.clone(), plan.clone(), query_id)
            .is_err());

        Ok(())
    }

    /// Reusing a query id for the same client is rejected by `add_subscription_multi`.
    #[test]
    fn test_subscribe_multi_fails_with_duplicate_request_id() -> ResultTest<()> {
        let db = TestDB::durable()?;

        create_table(&db, "T")?;
        let sql = "select * from T";
        let plan = compile_plan(&db, sql)?;

        let client = Arc::new(client(0));

        let query_id: ClientQueryId = QueryId::new(1);

        let runtime = tokio::runtime::Runtime::new().unwrap();
        let _rt = runtime.enter();

        let mut subscriptions = SubscriptionManager::for_test_without_metrics();
        let result = subscriptions.add_subscription_multi(client.clone(), vec![plan.clone()], query_id)?;
        assert_eq!(result[0].hash, plan.hash);

        assert!(subscriptions
            .add_subscription_multi(client.clone(), vec![plan.clone()], query_id)
            .is_err());

        Ok(())
    }

    /// Removing the only client of a legacy subscription clears the query entirely.
    #[test]
    fn test_unsubscribe() -> ResultTest<()> {
        let db = TestDB::durable()?;

        let table_id = create_table(&db, "T")?;
        let sql = "select * from T";
        let plan = compile_plan(&db, sql)?;
        let hash = plan.hash();

        let id = id(0);
        let client = Arc::new(client(0));

        let runtime = tokio::runtime::Runtime::new().unwrap();
        let _rt = runtime.enter();

        let mut subscriptions = SubscriptionManager::for_test_without_metrics();
        subscriptions.set_legacy_subscription(client, [plan]);
        subscriptions.remove_all_subscriptions(&id);

        assert!(!subscriptions.contains_query(&hash));
        assert!(!subscriptions.contains_legacy_subscription(&id, &hash));
        assert!(!subscriptions.query_reads_from_table(&hash, &table_id));

        Ok(())
    }

    /// Setting the same legacy subscription twice behaves like setting it once.
    ///
    /// In particular, a single removal clears it.
    #[test]
    fn test_subscribe_idempotent() -> ResultTest<()> {
        let db = TestDB::durable()?;

        let table_id = create_table(&db, "T")?;
        let sql = "select * from T";
        let plan = compile_plan(&db, sql)?;
        let hash = plan.hash();

        let id = id(0);
        let client = Arc::new(client(0));

        let runtime = tokio::runtime::Runtime::new().unwrap();
        let _rt = runtime.enter();

        let mut subscriptions = SubscriptionManager::for_test_without_metrics();
        subscriptions.set_legacy_subscription(client.clone(), [plan.clone()]);
        subscriptions.set_legacy_subscription(client.clone(), [plan.clone()]);

        assert!(subscriptions.contains_query(&hash));
        assert!(subscriptions.contains_legacy_subscription(&id, &hash));
        assert!(subscriptions.query_reads_from_table(&hash, &table_id));

        subscriptions.remove_all_subscriptions(&id);

        assert!(!subscriptions.contains_query(&hash));
        assert!(!subscriptions.contains_legacy_subscription(&id, &hash));
        assert!(!subscriptions.query_reads_from_table(&hash, &table_id));

        Ok(())
    }

    /// Two clients share one legacy query; removing one client keeps it for the other.
    #[test]
    fn test_share_queries_full() -> ResultTest<()> {
        let db = TestDB::durable()?;

        let table_id = create_table(&db, "T")?;
        let sql = "select * from T";
        let plan = compile_plan(&db, sql)?;
        let hash = plan.hash();

        let id0 = id(0);
        let client0 = Arc::new(client(0));

        let id1 = id(1);
        let client1 = Arc::new(client(1));

        let runtime = tokio::runtime::Runtime::new().unwrap();
        let _rt = runtime.enter();

        let mut subscriptions = SubscriptionManager::for_test_without_metrics();
        subscriptions.set_legacy_subscription(client0, [plan.clone()]);
        subscriptions.set_legacy_subscription(client1, [plan.clone()]);

        assert!(subscriptions.contains_query(&hash));
        assert!(subscriptions.contains_legacy_subscription(&id0, &hash));
        assert!(subscriptions.contains_legacy_subscription(&id1, &hash));
        assert!(subscriptions.query_reads_from_table(&hash, &table_id));

        subscriptions.remove_all_subscriptions(&id0);

        assert!(subscriptions.contains_query(&hash));
        assert!(subscriptions.contains_legacy_subscription(&id1, &hash));
        assert!(subscriptions.query_reads_from_table(&hash, &table_id));

        assert!(!subscriptions.contains_legacy_subscription(&id0, &hash));

        Ok(())
    }

    /// Two clients share a scan but each has a private filtered query.
    ///
    /// Removing client 0 drops only its private query and keeps the shared scan and
    /// client 1's query.
    #[test]
    fn test_share_queries_partial() -> ResultTest<()> {
        let db = TestDB::durable()?;

        let t = create_table(&db, "T")?;
        let s = create_table(&db, "S")?;

        let scan = "select * from T";
        let select0 = "select * from T where a = 0";
        let select1 = "select * from S where a = 1";

        let plan_scan = compile_plan(&db, scan)?;
        let plan_select0 = compile_plan(&db, select0)?;
        let plan_select1 = compile_plan(&db, select1)?;

        let hash_scan = plan_scan.hash();
        let hash_select0 = plan_select0.hash();
        let hash_select1 = plan_select1.hash();

        let id0 = id(0);
        let client0 = Arc::new(client(0));

        let id1 = id(1);
        let client1 = Arc::new(client(1));

        let runtime = tokio::runtime::Runtime::new().unwrap();
        let _rt = runtime.enter();
        let mut subscriptions = SubscriptionManager::for_test_without_metrics();
        subscriptions.set_legacy_subscription(client0, [plan_scan.clone(), plan_select0.clone()]);
        subscriptions.set_legacy_subscription(client1, [plan_scan.clone(), plan_select1.clone()]);

        assert!(subscriptions.contains_query(&hash_scan));
        assert!(subscriptions.contains_query(&hash_select0));
        assert!(subscriptions.contains_query(&hash_select1));

        assert!(subscriptions.contains_legacy_subscription(&id0, &hash_scan));
        assert!(subscriptions.contains_legacy_subscription(&id0, &hash_select0));

        assert!(subscriptions.contains_legacy_subscription(&id1, &hash_scan));
        assert!(subscriptions.contains_legacy_subscription(&id1, &hash_select1));

        assert!(subscriptions.query_reads_from_table(&hash_scan, &t));
        assert!(subscriptions.query_has_search_arg(hash_select0, t, ColId(0), AlgebraicValue::U8(0)));
        assert!(subscriptions.query_has_search_arg(hash_select1, s, ColId(0), AlgebraicValue::U8(1)));

        assert!(!subscriptions.query_reads_from_table(&hash_scan, &s));
        assert!(!subscriptions.query_reads_from_table(&hash_select0, &t));
        assert!(!subscriptions.query_reads_from_table(&hash_select1, &s));
        assert!(!subscriptions.query_reads_from_table(&hash_select0, &s));
        assert!(!subscriptions.query_reads_from_table(&hash_select1, &t));

        subscriptions.remove_all_subscriptions(&id0);

        assert!(subscriptions.contains_query(&hash_scan));
        assert!(subscriptions.contains_query(&hash_select1));
        assert!(!subscriptions.contains_query(&hash_select0));

        assert!(subscriptions.contains_legacy_subscription(&id1, &hash_scan));
        assert!(subscriptions.contains_legacy_subscription(&id1, &hash_select1));

        assert!(!subscriptions.contains_legacy_subscription(&id0, &hash_scan));
        assert!(!subscriptions.contains_legacy_subscription(&id0, &hash_select0));

        assert!(subscriptions.query_reads_from_table(&hash_scan, &t));
        assert!(subscriptions.query_has_search_arg(hash_select1, s, ColId(0), AlgebraicValue::U8(1)));

        assert!(!subscriptions.query_reads_from_table(&hash_select1, &s));
        assert!(!subscriptions.query_reads_from_table(&hash_scan, &s));
        assert!(!subscriptions.query_reads_from_table(&hash_select1, &t));

        Ok(())
    }

    /// The caller of a reducer receives a message even when it has no subscriptions.
    ///
    /// The message must arrive within 20 ms of evaluating the updates.
    #[test]
    fn test_caller_transaction_update_without_subscription() -> ResultTest<()> {
        let db = TestDB::durable()?;

        let id0 = Identity::ZERO;
        let client0 = ClientActorId::for_test(id0);
        let config = ClientConfig::for_test();
        let (client0, mut rx) = ClientConnectionSender::dummy_with_channel(client0, config);

        let runtime = tokio::runtime::Runtime::new().unwrap();
        let _rt = runtime.enter();
        let subscriptions = SubscriptionManager::for_test_without_metrics();

        let event = Arc::new(ModuleEvent {
            timestamp: Timestamp::now(),
            caller_identity: id0,
            caller_connection_id: Some(client0.id.connection_id),
            function_call: ModuleFunctionCall {
                reducer: "DummyReducer".into(),
                reducer_id: u32::MAX.into(),
                args: ArgsTuple::nullary(),
            },
            status: EventStatus::Committed(DatabaseUpdate::default()),
            energy_quanta_used: EnergyQuanta::ZERO,
            host_execution_duration: Duration::default(),
            request_id: None,
            timer: None,
        });

        db.with_read_only(Workload::Update, |tx| {
            subscriptions.eval_updates_sequential(&(&*tx).into(), event, Some(Arc::new(client0)))
        });

        runtime.block_on(async move {
            tokio::time::timeout(Duration::from_millis(20), async move {
                rx.recv().await.expect("Expected at least one message");
            })
            .await
            .expect("Timed out waiting for a message to the client");
        });

        Ok(())
    }
}