1use std::collections::VecDeque;
2use std::future::poll_fn;
3use std::ops::Deref;
4use std::sync::atomic::Ordering;
5use std::sync::atomic::{AtomicBool, Ordering::Relaxed};
6use std::sync::Arc;
7use std::task::{Context, Poll};
8use std::time::Instant;
9
10use super::messages::{OneOffQueryResponseMessage, SerializableMessage};
11use super::{message_handlers, ClientActorId, MessageHandleError};
12use crate::error::DBError;
13use crate::host::module_host::ClientConnectedError;
14use crate::host::{ModuleHost, NoSuchModule, ReducerArgs, ReducerCallError, ReducerCallResult};
15use crate::messages::websocket::Subscribe;
16use crate::util::asyncify;
17use crate::util::prometheus_handle::IntGaugeExt;
18use crate::worker_metrics::WORKER_METRICS;
19use bytes::Bytes;
20use bytestring::ByteString;
21use derive_more::From;
22use futures::prelude::*;
23use prometheus::{Histogram, IntCounter, IntGauge};
24use spacetimedb_client_api_messages::websocket::{
25 BsatnFormat, CallReducerFlags, Compression, FormatSwitch, JsonFormat, SubscribeMulti, SubscribeSingle, Unsubscribe,
26 UnsubscribeMulti,
27};
28use spacetimedb_lib::identity::RequestId;
29use spacetimedb_lib::metrics::ExecutionMetrics;
30use spacetimedb_lib::Identity;
31use tokio::sync::{mpsc, oneshot, watch};
32use tokio::task::AbortHandle;
33
/// Protocol spoken on a client connection.
///
/// [`Protocol::as_str`] maps the variants to the labels `"text"` and `"binary"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Protocol {
    /// Text protocol; expected to pair with `FormatSwitch::Json`.
    Text,
    /// Binary protocol; expected to pair with `FormatSwitch::Bsatn`.
    Binary,
}
39
40impl Protocol {
41 pub fn as_str(self) -> &'static str {
42 match self {
43 Protocol::Text => "text",
44 Protocol::Binary => "binary",
45 }
46 }
47
48 pub(crate) fn assert_matches_format_switch<B, J>(self, fs: &FormatSwitch<B, J>) {
49 match (self, fs) {
50 (Protocol::Text, FormatSwitch::Json(_)) | (Protocol::Binary, FormatSwitch::Bsatn(_)) => {}
51 _ => unreachable!("requested protocol does not match output format"),
52 }
53 }
54}
55
56#[derive(Clone, Copy, Debug)]
57pub struct ClientConfig {
58 pub protocol: Protocol,
60 pub compression: Compression,
62 pub tx_update_full: bool,
66}
67
68impl ClientConfig {
69 pub fn for_test() -> ClientConfig {
70 Self {
71 protocol: Protocol::Binary,
72 compression: <_>::default(),
73 tx_update_full: true,
74 }
75 }
76}
77
78#[derive(Debug)]
79pub struct ClientConnectionSender {
80 pub id: ClientActorId,
81 pub config: ClientConfig,
82 sendtx: mpsc::Sender<SerializableMessage>,
83 abort_handle: AbortHandle,
84 cancelled: AtomicBool,
85
86 metrics: Option<ClientConnectionMetrics>,
92}
93
94#[derive(Debug)]
95pub struct ClientConnectionMetrics {
96 pub websocket_request_msg_size: Histogram,
97 pub websocket_requests: IntCounter,
98
99 pub sendtx_queue_size: IntGauge,
107}
108
109impl ClientConnectionMetrics {
110 fn new(database_identity: Identity, protocol: Protocol) -> Self {
111 let message_kind = protocol.as_str();
112 let websocket_request_msg_size = WORKER_METRICS
113 .websocket_request_msg_size
114 .with_label_values(&database_identity, message_kind);
115 let websocket_requests = WORKER_METRICS
116 .websocket_requests
117 .with_label_values(&database_identity, message_kind);
118 let sendtx_queue_size = WORKER_METRICS
119 .total_outgoing_queue_length
120 .with_label_values(&database_identity);
121
122 Self {
123 websocket_request_msg_size,
124 websocket_requests,
125 sendtx_queue_size,
126 }
127 }
128}
129
130#[derive(Debug, thiserror::Error)]
131pub enum ClientSendError {
132 #[error("client disconnected")]
133 Disconnected,
134 #[error("client was not responding and has been disconnected")]
135 Cancelled,
136}
137
138impl ClientConnectionSender {
139 pub fn dummy_with_channel(id: ClientActorId, config: ClientConfig) -> (Self, MeteredReceiver<SerializableMessage>) {
140 let (sendtx, rx) = mpsc::channel(1);
141 let abort_handle = match tokio::runtime::Handle::try_current() {
143 Ok(h) => h.spawn(async {}).abort_handle(),
144 Err(_) => tokio::runtime::Runtime::new().unwrap().spawn(async {}).abort_handle(),
145 };
146
147 let rx = MeteredReceiver::new(rx);
148 let cancelled = AtomicBool::new(false);
149 let sender = Self {
150 id,
151 config,
152 sendtx,
153 abort_handle,
154 cancelled,
155 metrics: None,
156 };
157 (sender, rx)
158 }
159
160 pub fn dummy(id: ClientActorId, config: ClientConfig) -> Self {
161 Self::dummy_with_channel(id, config).0
162 }
163
164 pub fn is_cancelled(&self) -> bool {
165 self.cancelled.load(Ordering::Relaxed)
166 }
167
168 pub fn send_message(&self, message: impl Into<SerializableMessage>) -> Result<(), ClientSendError> {
171 self.send(message.into())
172 }
173
174 fn send(&self, message: SerializableMessage) -> Result<(), ClientSendError> {
175 if self.cancelled.load(Relaxed) {
176 return Err(ClientSendError::Cancelled);
177 }
178
179 match self.sendtx.try_send(message) {
180 Err(mpsc::error::TrySendError::Full(_)) => {
181 tracing::warn!(identity = %self.id.identity, connection_id = %self.id.connection_id, "client channel capacity exceeded");
184 self.abort_handle.abort();
185 self.cancelled.store(true, Ordering::Relaxed);
186 return Err(ClientSendError::Cancelled);
187 }
188 Err(mpsc::error::TrySendError::Closed(_)) => return Err(ClientSendError::Disconnected),
189 Ok(()) => {
190 if let Some(metrics) = &self.metrics {
195 metrics.sendtx_queue_size.inc();
196 }
197 }
198 }
199
200 Ok(())
201 }
202
203 pub(crate) fn observe_websocket_request_message(&self, message: &DataMessage) {
204 if let Some(metrics) = &self.metrics {
205 metrics.websocket_request_msg_size.observe(message.len() as f64);
206 metrics.websocket_requests.inc();
207 }
208 }
209}
210
211#[derive(Clone)]
212#[non_exhaustive]
213pub struct ClientConnection {
214 sender: Arc<ClientConnectionSender>,
215 pub replica_id: u64,
216 pub module: ModuleHost,
217 module_rx: watch::Receiver<ModuleHost>,
218}
219
220impl Deref for ClientConnection {
221 type Target = ClientConnectionSender;
222 fn deref(&self) -> &Self::Target {
223 &self.sender
224 }
225}
226
227#[derive(Debug, From)]
228pub enum DataMessage {
229 Text(ByteString),
230 Binary(Bytes),
231}
232
233impl From<String> for DataMessage {
234 fn from(value: String) -> Self {
235 ByteString::from(value).into()
236 }
237}
238
239impl From<Vec<u8>> for DataMessage {
240 fn from(value: Vec<u8>) -> Self {
241 Bytes::from(value).into()
242 }
243}
244
245impl DataMessage {
246 pub fn len(&self) -> usize {
248 match self {
249 Self::Text(s) => s.len(),
250 Self::Binary(b) => b.len(),
251 }
252 }
253
254 #[must_use]
256 pub fn is_empty(&self) -> bool {
257 self.len() == 0
258 }
259
260 pub fn allocation(&self) -> Bytes {
262 match self {
263 DataMessage::Text(alloc) => alloc.as_bytes().clone(),
264 DataMessage::Binary(alloc) => alloc.clone(),
265 }
266 }
267}
268
269pub struct MeteredDeque<T> {
272 inner: VecDeque<T>,
273 gauge: IntGauge,
274}
275
276impl<T> MeteredDeque<T> {
277 pub fn new(gauge: IntGauge) -> Self {
278 Self {
279 inner: VecDeque::new(),
280 gauge,
281 }
282 }
283
284 pub fn pop_front(&mut self) -> Option<T> {
285 self.inner.pop_front().inspect(|_| {
286 self.gauge.dec();
287 })
288 }
289
290 pub fn pop_back(&mut self) -> Option<T> {
291 self.inner.pop_back().inspect(|_| {
292 self.gauge.dec();
293 })
294 }
295
296 pub fn push_front(&mut self, value: T) {
297 self.gauge.inc();
298 self.inner.push_front(value);
299 }
300
301 pub fn push_back(&mut self, value: T) {
302 self.gauge.inc();
303 self.inner.push_back(value);
304 }
305
306 pub fn len(&self) -> usize {
307 self.inner.len()
308 }
309
310 pub fn is_empty(&self) -> bool {
311 self.inner.is_empty()
312 }
313}
314
315impl<T> Drop for MeteredDeque<T> {
316 fn drop(&mut self) {
317 self.gauge.sub(self.inner.len() as _);
319 }
320}
321
322pub struct MeteredReceiver<T> {
325 inner: mpsc::Receiver<T>,
326 gauge: Option<IntGauge>,
327}
328
329impl<T> MeteredReceiver<T> {
330 pub fn new(inner: mpsc::Receiver<T>) -> Self {
331 Self { inner, gauge: None }
332 }
333
334 pub fn with_gauge(inner: mpsc::Receiver<T>, gauge: IntGauge) -> Self {
335 Self {
336 inner,
337 gauge: Some(gauge),
338 }
339 }
340
341 pub async fn recv(&mut self) -> Option<T> {
342 poll_fn(|cx| self.poll_recv(cx)).await
343 }
344
345 pub async fn recv_many(&mut self, buf: &mut Vec<T>, max: usize) -> usize {
346 poll_fn(|cx| self.poll_recv_many(cx, buf, max)).await
347 }
348
349 pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> {
350 self.inner.poll_recv(cx).map(|maybe_item| {
351 maybe_item.inspect(|_| {
352 if let Some(gauge) = &self.gauge {
353 gauge.dec()
354 }
355 })
356 })
357 }
358
359 pub fn poll_recv_many(&mut self, cx: &mut Context<'_>, buf: &mut Vec<T>, max: usize) -> Poll<usize> {
360 self.inner.poll_recv_many(cx, buf, max).map(|n| {
361 if let Some(gauge) = &self.gauge {
362 gauge.sub(n as _);
363 }
364 n
365 })
366 }
367
368 pub fn len(&self) -> usize {
369 self.inner.len()
370 }
371
372 pub fn is_empty(&self) -> bool {
373 self.inner.is_empty()
374 }
375
376 pub fn close(&mut self) {
377 self.inner.close();
378 }
379}
380
381impl<T> Drop for MeteredReceiver<T> {
382 fn drop(&mut self) {
383 if let Some(gauge) = &self.gauge {
385 gauge.sub(self.inner.len() as _);
386 }
387 }
388}
389
/// 1024, the unit used to express the channel capacity below.
const KB: usize = 1 << 10;
/// Number of messages the per-client outgoing channel can hold (16 * 1024).
const CLIENT_CHANNEL_CAPACITY: usize = KB * 16;
394
/// Proof value that must be handed to `ClientConnection::spawn`.
///
/// In this file it is produced only by
/// `ClientConnection::call_client_connected_maybe_reject` when that call
/// succeeds. The private field keeps code outside this module from
/// constructing it directly.
#[non_exhaustive]
pub struct Connected {
    /// Zero-sized private marker; carries no data.
    _private: (),
}
408
409impl ClientConnection {
410 pub async fn call_client_connected_maybe_reject(
416 module_rx: &mut watch::Receiver<ModuleHost>,
417 id: ClientActorId,
418 ) -> Result<Connected, ClientConnectedError> {
419 let module = module_rx.borrow_and_update().clone();
420 module.call_identity_connected(id.identity, id.connection_id).await?;
421 Ok(Connected { _private: () })
422 }
423
424 pub async fn spawn<Fut>(
431 id: ClientActorId,
432 config: ClientConfig,
433 replica_id: u64,
434 mut module_rx: watch::Receiver<ModuleHost>,
435 actor: impl FnOnce(ClientConnection, MeteredReceiver<SerializableMessage>) -> Fut,
436 _proof_of_client_connected_call: Connected,
437 ) -> ClientConnection
438 where
439 Fut: Future<Output = ()> + Send + 'static,
440 {
441 let module = module_rx.borrow_and_update().clone();
446
447 let (sendtx, sendrx) = mpsc::channel::<SerializableMessage>(CLIENT_CHANNEL_CAPACITY);
448
449 let (fut_tx, fut_rx) = oneshot::channel::<Fut>();
450 let module_info = module.info.clone();
452 let database_identity = module_info.database_identity;
453 let abort_handle = tokio::spawn(async move {
454 let Ok(fut) = fut_rx.await else { return };
455
456 let _gauge_guard = module_info.metrics.connected_clients.inc_scope();
457 module_info.metrics.ws_clients_spawned.inc();
458 scopeguard::defer! {
459 let database_identity = module_info.database_identity;
460 let client_identity = id.identity;
461 log::warn!("websocket connection aborted for client identity `{client_identity}` and database identity `{database_identity}`");
462 module_info.metrics.ws_clients_aborted.inc();
463 };
464
465 fut.await
466 })
467 .abort_handle();
468
469 let metrics = ClientConnectionMetrics::new(database_identity, config.protocol);
470 let sendrx = MeteredReceiver::with_gauge(sendrx, metrics.sendtx_queue_size.clone());
471
472 let sender = Arc::new(ClientConnectionSender {
473 id,
474 config,
475 sendtx,
476 abort_handle,
477 cancelled: AtomicBool::new(false),
478 metrics: Some(metrics),
479 });
480 let this = Self {
481 sender,
482 replica_id,
483 module,
484 module_rx,
485 };
486
487 let actor_fut = actor(this.clone(), sendrx);
488 let _ = fut_tx.send(actor_fut);
490
491 this
492 }
493
494 pub fn dummy(
495 id: ClientActorId,
496 config: ClientConfig,
497 replica_id: u64,
498 mut module_rx: watch::Receiver<ModuleHost>,
499 ) -> Self {
500 let module = module_rx.borrow_and_update().clone();
501 Self {
502 sender: Arc::new(ClientConnectionSender::dummy(id, config)),
503 replica_id,
504 module,
505 module_rx,
506 }
507 }
508
509 pub fn sender(&self) -> Arc<ClientConnectionSender> {
510 self.sender.clone()
511 }
512
513 #[inline]
514 pub fn handle_message(
515 &self,
516 message: impl Into<DataMessage>,
517 timer: Instant,
518 ) -> impl Future<Output = Result<(), MessageHandleError>> + '_ {
519 message_handlers::handle(self, message.into(), timer)
520 }
521
522 pub async fn watch_module_host(&mut self) -> Result<(), NoSuchModule> {
523 match self.module_rx.changed().await {
524 Ok(()) => {
525 self.module = self.module_rx.borrow_and_update().clone();
526 Ok(())
527 }
528 Err(_) => Err(NoSuchModule),
529 }
530 }
531
532 pub async fn call_reducer(
533 &self,
534 reducer: &str,
535 args: ReducerArgs,
536 request_id: RequestId,
537 timer: Instant,
538 flags: CallReducerFlags,
539 ) -> Result<ReducerCallResult, ReducerCallError> {
540 let caller = match flags {
541 CallReducerFlags::FullUpdate => Some(self.sender()),
542 CallReducerFlags::NoSuccessNotify => None,
545 };
546
547 self.module
548 .call_reducer(
549 self.id.identity,
550 Some(self.id.connection_id),
551 caller,
552 Some(request_id),
553 Some(timer),
554 reducer,
555 args,
556 )
557 .await
558 }
559
560 pub async fn subscribe_single(
561 &self,
562 subscription: SubscribeSingle,
563 timer: Instant,
564 ) -> Result<Option<ExecutionMetrics>, DBError> {
565 let me = self.clone();
566 self.module
567 .on_module_thread("subscribe_single", move || {
568 me.module
569 .subscriptions()
570 .add_single_subscription(me.sender, subscription, timer, None)
571 })
572 .await?
573 }
574
575 pub async fn unsubscribe(&self, request: Unsubscribe, timer: Instant) -> Result<Option<ExecutionMetrics>, DBError> {
576 let me = self.clone();
577 asyncify(move || {
578 me.module
579 .subscriptions()
580 .remove_single_subscription(me.sender, request, timer)
581 })
582 .await
583 }
584
585 pub async fn subscribe_multi(
586 &self,
587 request: SubscribeMulti,
588 timer: Instant,
589 ) -> Result<Option<ExecutionMetrics>, DBError> {
590 let me = self.clone();
591 self.module
592 .on_module_thread("subscribe_multi", move || {
593 me.module
594 .subscriptions()
595 .add_multi_subscription(me.sender, request, timer, None)
596 })
597 .await?
598 }
599
600 pub async fn unsubscribe_multi(
601 &self,
602 request: UnsubscribeMulti,
603 timer: Instant,
604 ) -> Result<Option<ExecutionMetrics>, DBError> {
605 let me = self.clone();
606 self.module
607 .on_module_thread("unsubscribe_multi", move || {
608 me.module
609 .subscriptions()
610 .remove_multi_subscription(me.sender, request, timer)
611 })
612 .await?
613 }
614
615 pub async fn subscribe(&self, subscription: Subscribe, timer: Instant) -> Result<ExecutionMetrics, DBError> {
616 let me = self.clone();
617 asyncify(move || {
618 me.module
619 .subscriptions()
620 .add_legacy_subscriber(me.sender, subscription, timer, None)
621 })
622 .await
623 }
624
625 pub async fn one_off_query_json(
626 &self,
627 query: &str,
628 message_id: &[u8],
629 timer: Instant,
630 ) -> Result<(), anyhow::Error> {
631 self.module
632 .one_off_query::<JsonFormat>(
633 self.id.identity,
634 query.to_owned(),
635 self.sender.clone(),
636 message_id.to_owned(),
637 timer,
638 |msg: OneOffQueryResponseMessage<JsonFormat>| msg.into(),
639 )
640 .await
641 }
642
643 pub async fn one_off_query_bsatn(
644 &self,
645 query: &str,
646 message_id: &[u8],
647 timer: Instant,
648 ) -> Result<(), anyhow::Error> {
649 self.module
650 .one_off_query::<BsatnFormat>(
651 self.id.identity,
652 query.to_owned(),
653 self.sender.clone(),
654 message_id.to_owned(),
655 timer,
656 |msg: OneOffQueryResponseMessage<BsatnFormat>| msg.into(),
657 )
658 .await
659 }
660
661 pub async fn disconnect(self) {
662 self.module.disconnect_client(self.id).await
663 }
664}