spacetimedb/client/
client_connection.rs

1use std::collections::VecDeque;
2use std::future::poll_fn;
3use std::ops::Deref;
4use std::sync::atomic::Ordering;
5use std::sync::atomic::{AtomicBool, Ordering::Relaxed};
6use std::sync::Arc;
7use std::task::{Context, Poll};
8use std::time::Instant;
9
10use super::messages::{OneOffQueryResponseMessage, SerializableMessage};
11use super::{message_handlers, ClientActorId, MessageHandleError};
12use crate::error::DBError;
13use crate::host::module_host::ClientConnectedError;
14use crate::host::{ModuleHost, NoSuchModule, ReducerArgs, ReducerCallError, ReducerCallResult};
15use crate::messages::websocket::Subscribe;
16use crate::util::asyncify;
17use crate::util::prometheus_handle::IntGaugeExt;
18use crate::worker_metrics::WORKER_METRICS;
19use bytes::Bytes;
20use bytestring::ByteString;
21use derive_more::From;
22use futures::prelude::*;
23use prometheus::{Histogram, IntCounter, IntGauge};
24use spacetimedb_client_api_messages::websocket::{
25    BsatnFormat, CallReducerFlags, Compression, FormatSwitch, JsonFormat, SubscribeMulti, SubscribeSingle, Unsubscribe,
26    UnsubscribeMulti,
27};
28use spacetimedb_lib::identity::RequestId;
29use spacetimedb_lib::metrics::ExecutionMetrics;
30use spacetimedb_lib::Identity;
31use tokio::sync::{mpsc, oneshot, watch};
32use tokio::task::AbortHandle;
33
/// The wire format in which a client wants the host to send its replies.
#[derive(Debug, Hash, Clone, Copy, PartialEq, Eq)]
pub enum Protocol {
    /// Textual frames; paired with the JSON output format.
    Text,
    /// Binary frames; paired with the BSATN output format.
    Binary,
}
39
40impl Protocol {
41    pub fn as_str(self) -> &'static str {
42        match self {
43            Protocol::Text => "text",
44            Protocol::Binary => "binary",
45        }
46    }
47
48    pub(crate) fn assert_matches_format_switch<B, J>(self, fs: &FormatSwitch<B, J>) {
49        match (self, fs) {
50            (Protocol::Text, FormatSwitch::Json(_)) | (Protocol::Binary, FormatSwitch::Bsatn(_)) => {}
51            _ => unreachable!("requested protocol does not match output format"),
52        }
53    }
54}
55
56#[derive(Clone, Copy, Debug)]
57pub struct ClientConfig {
58    /// The client's desired protocol (format) when the host replies.
59    pub protocol: Protocol,
60    /// The client's desired (conditional) compression algorithm, if any.
61    pub compression: Compression,
62    /// Whether the client prefers full [`TransactionUpdate`]s
63    /// rather than  [`TransactionUpdateLight`]s on a successful update.
64    // TODO(centril): As more knobs are added, make this into a bitfield (when there's time).
65    pub tx_update_full: bool,
66}
67
68impl ClientConfig {
69    pub fn for_test() -> ClientConfig {
70        Self {
71            protocol: Protocol::Binary,
72            compression: <_>::default(),
73            tx_update_full: true,
74        }
75    }
76}
77
78#[derive(Debug)]
79pub struct ClientConnectionSender {
80    pub id: ClientActorId,
81    pub config: ClientConfig,
82    sendtx: mpsc::Sender<SerializableMessage>,
83    abort_handle: AbortHandle,
84    cancelled: AtomicBool,
85
86    /// Handles on Prometheus metrics related to connections to this database.
87    ///
88    /// Will be `None` when constructed by [`ClientConnectionSender::dummy_with_channel`]
89    /// or [`ClientConnectionSender::dummy`], which are used in tests.
90    /// Will be `Some` whenever this `ClientConnectionSender` is wired up to an actual client connection.
91    metrics: Option<ClientConnectionMetrics>,
92}
93
94#[derive(Debug)]
95pub struct ClientConnectionMetrics {
96    pub websocket_request_msg_size: Histogram,
97    pub websocket_requests: IntCounter,
98
99    /// The `total_outgoing_queue_length` metric labeled with this database's `Identity`,
100    /// which we'll increment whenever sending a message.
101    ///
102    /// This metric will be decremented, and cleaned up,
103    /// by `ws_client_actor_inner` in client-api/src/routes/subscribe.rs.
104    /// Care must be taken not to increment it after the client has disconnected
105    /// and performed its clean-up.
106    pub sendtx_queue_size: IntGauge,
107}
108
109impl ClientConnectionMetrics {
110    fn new(database_identity: Identity, protocol: Protocol) -> Self {
111        let message_kind = protocol.as_str();
112        let websocket_request_msg_size = WORKER_METRICS
113            .websocket_request_msg_size
114            .with_label_values(&database_identity, message_kind);
115        let websocket_requests = WORKER_METRICS
116            .websocket_requests
117            .with_label_values(&database_identity, message_kind);
118        let sendtx_queue_size = WORKER_METRICS
119            .total_outgoing_queue_length
120            .with_label_values(&database_identity);
121
122        Self {
123            websocket_request_msg_size,
124            websocket_requests,
125            sendtx_queue_size,
126        }
127    }
128}
129
130#[derive(Debug, thiserror::Error)]
131pub enum ClientSendError {
132    #[error("client disconnected")]
133    Disconnected,
134    #[error("client was not responding and has been disconnected")]
135    Cancelled,
136}
137
138impl ClientConnectionSender {
139    pub fn dummy_with_channel(id: ClientActorId, config: ClientConfig) -> (Self, MeteredReceiver<SerializableMessage>) {
140        let (sendtx, rx) = mpsc::channel(1);
141        // just make something up, it doesn't need to be attached to a real task
142        let abort_handle = match tokio::runtime::Handle::try_current() {
143            Ok(h) => h.spawn(async {}).abort_handle(),
144            Err(_) => tokio::runtime::Runtime::new().unwrap().spawn(async {}).abort_handle(),
145        };
146
147        let rx = MeteredReceiver::new(rx);
148        let cancelled = AtomicBool::new(false);
149        let sender = Self {
150            id,
151            config,
152            sendtx,
153            abort_handle,
154            cancelled,
155            metrics: None,
156        };
157        (sender, rx)
158    }
159
160    pub fn dummy(id: ClientActorId, config: ClientConfig) -> Self {
161        Self::dummy_with_channel(id, config).0
162    }
163
164    pub fn is_cancelled(&self) -> bool {
165        self.cancelled.load(Ordering::Relaxed)
166    }
167
168    /// Send a message to the client. For data-related messages, you should probably use
169    /// `BroadcastQueue::send` to ensure that the client sees data messages in a consistent order.
170    pub fn send_message(&self, message: impl Into<SerializableMessage>) -> Result<(), ClientSendError> {
171        self.send(message.into())
172    }
173
174    fn send(&self, message: SerializableMessage) -> Result<(), ClientSendError> {
175        if self.cancelled.load(Relaxed) {
176            return Err(ClientSendError::Cancelled);
177        }
178
179        match self.sendtx.try_send(message) {
180            Err(mpsc::error::TrySendError::Full(_)) => {
181                // we've hit CLIENT_CHANNEL_CAPACITY messages backed up in
182                // the channel, so forcibly kick the client
183                tracing::warn!(identity = %self.id.identity, connection_id = %self.id.connection_id, "client channel capacity exceeded");
184                self.abort_handle.abort();
185                self.cancelled.store(true, Ordering::Relaxed);
186                return Err(ClientSendError::Cancelled);
187            }
188            Err(mpsc::error::TrySendError::Closed(_)) => return Err(ClientSendError::Disconnected),
189            Ok(()) => {
190                // If we successfully pushed a message into the queue, increment the queue size metric.
191                // Don't do this before pushing because, if the client has disconnected,
192                // it will already have performed its clean-up,
193                // and so would never perform the corresponding `dec` to this `inc`.
194                if let Some(metrics) = &self.metrics {
195                    metrics.sendtx_queue_size.inc();
196                }
197            }
198        }
199
200        Ok(())
201    }
202
203    pub(crate) fn observe_websocket_request_message(&self, message: &DataMessage) {
204        if let Some(metrics) = &self.metrics {
205            metrics.websocket_request_msg_size.observe(message.len() as f64);
206            metrics.websocket_requests.inc();
207        }
208    }
209}
210
211#[derive(Clone)]
212#[non_exhaustive]
213pub struct ClientConnection {
214    sender: Arc<ClientConnectionSender>,
215    pub replica_id: u64,
216    pub module: ModuleHost,
217    module_rx: watch::Receiver<ModuleHost>,
218}
219
220impl Deref for ClientConnection {
221    type Target = ClientConnectionSender;
222    fn deref(&self) -> &Self::Target {
223        &self.sender
224    }
225}
226
227#[derive(Debug, From)]
228pub enum DataMessage {
229    Text(ByteString),
230    Binary(Bytes),
231}
232
233impl From<String> for DataMessage {
234    fn from(value: String) -> Self {
235        ByteString::from(value).into()
236    }
237}
238
239impl From<Vec<u8>> for DataMessage {
240    fn from(value: Vec<u8>) -> Self {
241        Bytes::from(value).into()
242    }
243}
244
245impl DataMessage {
246    /// Returns the number of bytes this message consists of.
247    pub fn len(&self) -> usize {
248        match self {
249            Self::Text(s) => s.len(),
250            Self::Binary(b) => b.len(),
251        }
252    }
253
254    /// Is the message empty?
255    #[must_use]
256    pub fn is_empty(&self) -> bool {
257        self.len() == 0
258    }
259
260    /// Returns a handle to the underlying allocation of the message without consuming it.
261    pub fn allocation(&self) -> Bytes {
262        match self {
263            DataMessage::Text(alloc) => alloc.as_bytes().clone(),
264            DataMessage::Binary(alloc) => alloc.clone(),
265        }
266    }
267}
268
269/// Wraps a [VecDeque] with a gauge for tracking its size.
270/// We subtract its size from the gauge on drop to avoid leaking the metric.
271pub struct MeteredDeque<T> {
272    inner: VecDeque<T>,
273    gauge: IntGauge,
274}
275
276impl<T> MeteredDeque<T> {
277    pub fn new(gauge: IntGauge) -> Self {
278        Self {
279            inner: VecDeque::new(),
280            gauge,
281        }
282    }
283
284    pub fn pop_front(&mut self) -> Option<T> {
285        self.inner.pop_front().inspect(|_| {
286            self.gauge.dec();
287        })
288    }
289
290    pub fn pop_back(&mut self) -> Option<T> {
291        self.inner.pop_back().inspect(|_| {
292            self.gauge.dec();
293        })
294    }
295
296    pub fn push_front(&mut self, value: T) {
297        self.gauge.inc();
298        self.inner.push_front(value);
299    }
300
301    pub fn push_back(&mut self, value: T) {
302        self.gauge.inc();
303        self.inner.push_back(value);
304    }
305
306    pub fn len(&self) -> usize {
307        self.inner.len()
308    }
309
310    pub fn is_empty(&self) -> bool {
311        self.inner.is_empty()
312    }
313}
314
315impl<T> Drop for MeteredDeque<T> {
316    fn drop(&mut self) {
317        // Record the number of elements still in the deque on drop
318        self.gauge.sub(self.inner.len() as _);
319    }
320}
321
322/// Wraps the receiving end of a channel with a gauge for tracking the size of the channel.
323/// We subtract the size of the channel from the gauge on drop to avoid leaking the metric.
324pub struct MeteredReceiver<T> {
325    inner: mpsc::Receiver<T>,
326    gauge: Option<IntGauge>,
327}
328
329impl<T> MeteredReceiver<T> {
330    pub fn new(inner: mpsc::Receiver<T>) -> Self {
331        Self { inner, gauge: None }
332    }
333
334    pub fn with_gauge(inner: mpsc::Receiver<T>, gauge: IntGauge) -> Self {
335        Self {
336            inner,
337            gauge: Some(gauge),
338        }
339    }
340
341    pub async fn recv(&mut self) -> Option<T> {
342        poll_fn(|cx| self.poll_recv(cx)).await
343    }
344
345    pub async fn recv_many(&mut self, buf: &mut Vec<T>, max: usize) -> usize {
346        poll_fn(|cx| self.poll_recv_many(cx, buf, max)).await
347    }
348
349    pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> {
350        self.inner.poll_recv(cx).map(|maybe_item| {
351            maybe_item.inspect(|_| {
352                if let Some(gauge) = &self.gauge {
353                    gauge.dec()
354                }
355            })
356        })
357    }
358
359    pub fn poll_recv_many(&mut self, cx: &mut Context<'_>, buf: &mut Vec<T>, max: usize) -> Poll<usize> {
360        self.inner.poll_recv_many(cx, buf, max).map(|n| {
361            if let Some(gauge) = &self.gauge {
362                gauge.sub(n as _);
363            }
364            n
365        })
366    }
367
368    pub fn len(&self) -> usize {
369        self.inner.len()
370    }
371
372    pub fn is_empty(&self) -> bool {
373        self.inner.is_empty()
374    }
375
376    pub fn close(&mut self) {
377        self.inner.close();
378    }
379}
380
381impl<T> Drop for MeteredReceiver<T> {
382    fn drop(&mut self) {
383        // Record the number of elements still in the channel on drop
384        if let Some(gauge) = &self.gauge {
385            gauge.sub(self.inner.len() as _);
386        }
387    }
388}
389
/// 1024, i.e. one "kilo" in the binary sense.
const KB: usize = 1024;

/// Capacity of each client's outgoing message queue.
// if a client racks up this many messages in the queue without ACK'ing
// anything, we boot 'em.
const CLIENT_CHANNEL_CAPACITY: usize = 16 * KB;
394
/// Value returned by [`ClientConnection::call_client_connected_maybe_reject`]
/// and consumed by [`ClientConnection::spawn`] which acts as a proof that the client is authorized.
///
/// Because this struct does not capture the module or database info or the client connection info,
/// a malicious caller could [`ClientConnection::call_client_connected_maybe_reject`] for one client
/// and then use the resulting `Connected` token to [`ClientConnection::spawn`] for a different client.
/// We're not particularly worried about that.
/// This token exists as a sanity check that non-malicious callers don't accidentally [`ClientConnection::spawn`]
/// for an unauthorized client.
#[non_exhaustive]
pub struct Connected {
    // Private field so the token can only be minted inside this module.
    _private: (),
}
408
409impl ClientConnection {
410    /// Call the database at `module_rx`'s `client_connection` reducer, if any,
411    /// and return `Err` if it signals rejecting this client's connection.
412    ///
413    /// Call this method before [`Self::spawn`]
414    /// and pass the returned [`Connected`] to [`Self::spawn`] as proof that the client is authorized.
415    pub async fn call_client_connected_maybe_reject(
416        module_rx: &mut watch::Receiver<ModuleHost>,
417        id: ClientActorId,
418    ) -> Result<Connected, ClientConnectedError> {
419        let module = module_rx.borrow_and_update().clone();
420        module.call_identity_connected(id.identity, id.connection_id).await?;
421        Ok(Connected { _private: () })
422    }
423
424    /// Spawn a new [`ClientConnection`] for a WebSocket subscriber.
425    ///
426    /// Callers should first call [`Self::call_client_connected_maybe_reject`]
427    /// to verify that the database at `module_rx` approves of this connection,
428    /// and should not invoke this method if that call returns an error,
429    /// and pass the returned [`Connected`] as `_proof_of_client_connected_call`.
430    pub async fn spawn<Fut>(
431        id: ClientActorId,
432        config: ClientConfig,
433        replica_id: u64,
434        mut module_rx: watch::Receiver<ModuleHost>,
435        actor: impl FnOnce(ClientConnection, MeteredReceiver<SerializableMessage>) -> Fut,
436        _proof_of_client_connected_call: Connected,
437    ) -> ClientConnection
438    where
439        Fut: Future<Output = ()> + Send + 'static,
440    {
441        // Add this client as a subscriber
442        // TODO: Right now this is connecting clients directly to a replica, but their requests should be
443        // logically subscribed to the database, not any particular replica. We should handle failover for
444        // them and stuff. Not right now though.
445        let module = module_rx.borrow_and_update().clone();
446
447        let (sendtx, sendrx) = mpsc::channel::<SerializableMessage>(CLIENT_CHANNEL_CAPACITY);
448
449        let (fut_tx, fut_rx) = oneshot::channel::<Fut>();
450        // weird dance so that we can get an abort_handle into ClientConnection
451        let module_info = module.info.clone();
452        let database_identity = module_info.database_identity;
453        let abort_handle = tokio::spawn(async move {
454            let Ok(fut) = fut_rx.await else { return };
455
456            let _gauge_guard = module_info.metrics.connected_clients.inc_scope();
457            module_info.metrics.ws_clients_spawned.inc();
458            scopeguard::defer! {
459                let database_identity = module_info.database_identity;
460                let client_identity = id.identity;
461                log::warn!("websocket connection aborted for client identity `{client_identity}` and database identity `{database_identity}`");
462                module_info.metrics.ws_clients_aborted.inc();
463            };
464
465            fut.await
466        })
467        .abort_handle();
468
469        let metrics = ClientConnectionMetrics::new(database_identity, config.protocol);
470        let sendrx = MeteredReceiver::with_gauge(sendrx, metrics.sendtx_queue_size.clone());
471
472        let sender = Arc::new(ClientConnectionSender {
473            id,
474            config,
475            sendtx,
476            abort_handle,
477            cancelled: AtomicBool::new(false),
478            metrics: Some(metrics),
479        });
480        let this = Self {
481            sender,
482            replica_id,
483            module,
484            module_rx,
485        };
486
487        let actor_fut = actor(this.clone(), sendrx);
488        // if this fails, the actor() function called .abort(), which like... okay, I guess?
489        let _ = fut_tx.send(actor_fut);
490
491        this
492    }
493
494    pub fn dummy(
495        id: ClientActorId,
496        config: ClientConfig,
497        replica_id: u64,
498        mut module_rx: watch::Receiver<ModuleHost>,
499    ) -> Self {
500        let module = module_rx.borrow_and_update().clone();
501        Self {
502            sender: Arc::new(ClientConnectionSender::dummy(id, config)),
503            replica_id,
504            module,
505            module_rx,
506        }
507    }
508
509    pub fn sender(&self) -> Arc<ClientConnectionSender> {
510        self.sender.clone()
511    }
512
513    #[inline]
514    pub fn handle_message(
515        &self,
516        message: impl Into<DataMessage>,
517        timer: Instant,
518    ) -> impl Future<Output = Result<(), MessageHandleError>> + '_ {
519        message_handlers::handle(self, message.into(), timer)
520    }
521
522    pub async fn watch_module_host(&mut self) -> Result<(), NoSuchModule> {
523        match self.module_rx.changed().await {
524            Ok(()) => {
525                self.module = self.module_rx.borrow_and_update().clone();
526                Ok(())
527            }
528            Err(_) => Err(NoSuchModule),
529        }
530    }
531
532    pub async fn call_reducer(
533        &self,
534        reducer: &str,
535        args: ReducerArgs,
536        request_id: RequestId,
537        timer: Instant,
538        flags: CallReducerFlags,
539    ) -> Result<ReducerCallResult, ReducerCallError> {
540        let caller = match flags {
541            CallReducerFlags::FullUpdate => Some(self.sender()),
542            // Setting `sender = None` causes `eval_updates` to skip sending to the caller
543            // as it has no access to the caller other than by id/connection id.
544            CallReducerFlags::NoSuccessNotify => None,
545        };
546
547        self.module
548            .call_reducer(
549                self.id.identity,
550                Some(self.id.connection_id),
551                caller,
552                Some(request_id),
553                Some(timer),
554                reducer,
555                args,
556            )
557            .await
558    }
559
560    pub async fn subscribe_single(
561        &self,
562        subscription: SubscribeSingle,
563        timer: Instant,
564    ) -> Result<Option<ExecutionMetrics>, DBError> {
565        let me = self.clone();
566        self.module
567            .on_module_thread("subscribe_single", move || {
568                me.module
569                    .subscriptions()
570                    .add_single_subscription(me.sender, subscription, timer, None)
571            })
572            .await?
573    }
574
575    pub async fn unsubscribe(&self, request: Unsubscribe, timer: Instant) -> Result<Option<ExecutionMetrics>, DBError> {
576        let me = self.clone();
577        asyncify(move || {
578            me.module
579                .subscriptions()
580                .remove_single_subscription(me.sender, request, timer)
581        })
582        .await
583    }
584
585    pub async fn subscribe_multi(
586        &self,
587        request: SubscribeMulti,
588        timer: Instant,
589    ) -> Result<Option<ExecutionMetrics>, DBError> {
590        let me = self.clone();
591        self.module
592            .on_module_thread("subscribe_multi", move || {
593                me.module
594                    .subscriptions()
595                    .add_multi_subscription(me.sender, request, timer, None)
596            })
597            .await?
598    }
599
600    pub async fn unsubscribe_multi(
601        &self,
602        request: UnsubscribeMulti,
603        timer: Instant,
604    ) -> Result<Option<ExecutionMetrics>, DBError> {
605        let me = self.clone();
606        self.module
607            .on_module_thread("unsubscribe_multi", move || {
608                me.module
609                    .subscriptions()
610                    .remove_multi_subscription(me.sender, request, timer)
611            })
612            .await?
613    }
614
615    pub async fn subscribe(&self, subscription: Subscribe, timer: Instant) -> Result<ExecutionMetrics, DBError> {
616        let me = self.clone();
617        asyncify(move || {
618            me.module
619                .subscriptions()
620                .add_legacy_subscriber(me.sender, subscription, timer, None)
621        })
622        .await
623    }
624
625    pub async fn one_off_query_json(
626        &self,
627        query: &str,
628        message_id: &[u8],
629        timer: Instant,
630    ) -> Result<(), anyhow::Error> {
631        self.module
632            .one_off_query::<JsonFormat>(
633                self.id.identity,
634                query.to_owned(),
635                self.sender.clone(),
636                message_id.to_owned(),
637                timer,
638                |msg: OneOffQueryResponseMessage<JsonFormat>| msg.into(),
639            )
640            .await
641    }
642
643    pub async fn one_off_query_bsatn(
644        &self,
645        query: &str,
646        message_id: &[u8],
647        timer: Instant,
648    ) -> Result<(), anyhow::Error> {
649        self.module
650            .one_off_query::<BsatnFormat>(
651                self.id.identity,
652                query.to_owned(),
653                self.sender.clone(),
654                message_id.to_owned(),
655                timer,
656                |msg: OneOffQueryResponseMessage<BsatnFormat>| msg.into(),
657            )
658            .await
659    }
660
661    pub async fn disconnect(self) {
662        self.module.disconnect_client(self.id).await
663    }
664}