spacetimedb/client/
client_connection.rs

1use std::ops::Deref;
2use std::sync::atomic::{AtomicBool, Ordering::Relaxed};
3use std::sync::Arc;
4use std::time::Instant;
5
6use super::messages::{OneOffQueryResponseMessage, SerializableMessage};
7use super::{message_handlers, ClientActorId, MessageHandleError};
8use crate::error::DBError;
9use crate::host::module_host::ClientConnectedError;
10use crate::host::{ModuleHost, NoSuchModule, ReducerArgs, ReducerCallError, ReducerCallResult};
11use crate::messages::websocket::Subscribe;
12use crate::util::asyncify;
13use crate::util::prometheus_handle::IntGaugeExt;
14use crate::worker_metrics::WORKER_METRICS;
15use bytes::Bytes;
16use bytestring::ByteString;
17use derive_more::From;
18use futures::prelude::*;
19use prometheus::{Histogram, IntCounter, IntGauge};
20use spacetimedb_client_api_messages::websocket::{
21    BsatnFormat, CallReducerFlags, Compression, FormatSwitch, JsonFormat, SubscribeMulti, SubscribeSingle, Unsubscribe,
22    UnsubscribeMulti,
23};
24use spacetimedb_lib::identity::RequestId;
25use spacetimedb_lib::metrics::ExecutionMetrics;
26use spacetimedb_lib::Identity;
27use tokio::sync::{mpsc, oneshot, watch};
28use tokio::task::AbortHandle;
29
30#[derive(PartialEq, Eq, Clone, Copy, Hash, Debug)]
31pub enum Protocol {
32    Text,
33    Binary,
34}
35
36impl Protocol {
37    pub fn as_str(self) -> &'static str {
38        match self {
39            Protocol::Text => "text",
40            Protocol::Binary => "binary",
41        }
42    }
43
44    pub(crate) fn assert_matches_format_switch<B, J>(self, fs: &FormatSwitch<B, J>) {
45        match (self, fs) {
46            (Protocol::Text, FormatSwitch::Json(_)) | (Protocol::Binary, FormatSwitch::Bsatn(_)) => {}
47            _ => unreachable!("requested protocol does not match output format"),
48        }
49    }
50}
51
52#[derive(Clone, Copy, Debug)]
53pub struct ClientConfig {
54    /// The client's desired protocol (format) when the host replies.
55    pub protocol: Protocol,
56    /// The client's desired (conditional) compression algorithm, if any.
57    pub compression: Compression,
58    /// Whether the client prefers full [`TransactionUpdate`]s
59    /// rather than  [`TransactionUpdateLight`]s on a successful update.
60    // TODO(centril): As more knobs are added, make this into a bitfield (when there's time).
61    pub tx_update_full: bool,
62}
63
64impl ClientConfig {
65    pub fn for_test() -> ClientConfig {
66        Self {
67            protocol: Protocol::Binary,
68            compression: <_>::default(),
69            tx_update_full: true,
70        }
71    }
72}
73
74#[derive(Debug)]
75pub struct ClientConnectionSender {
76    pub id: ClientActorId,
77    pub config: ClientConfig,
78    sendtx: mpsc::Sender<SerializableMessage>,
79    abort_handle: AbortHandle,
80    cancelled: AtomicBool,
81
82    /// Handles on Prometheus metrics related to connections to this database.
83    ///
84    /// Will be `None` when constructed by [`ClientConnectionSender::dummy_with_channel`]
85    /// or [`ClientConnectionSender::dummy`], which are used in tests.
86    /// Will be `Some` whenever this `ClientConnectionSender` is wired up to an actual client connection.
87    metrics: Option<ClientConnectionMetrics>,
88}
89
90#[derive(Debug)]
91pub struct ClientConnectionMetrics {
92    pub websocket_request_msg_size: Histogram,
93    pub websocket_requests: IntCounter,
94
95    /// The `total_outgoing_queue_length` metric labeled with this database's `Identity`,
96    /// which we'll increment whenever sending a message.
97    ///
98    /// This metric will be decremented, and cleaned up,
99    /// by `ws_client_actor_inner` in client-api/src/routes/subscribe.rs.
100    /// Care must be taken not to increment it after the client has disconnected
101    /// and performed its clean-up.
102    pub sendtx_queue_size: IntGauge,
103}
104
105impl ClientConnectionMetrics {
106    fn new(database_identity: Identity, protocol: Protocol) -> Self {
107        let message_kind = protocol.as_str();
108        let websocket_request_msg_size = WORKER_METRICS
109            .websocket_request_msg_size
110            .with_label_values(&database_identity, message_kind);
111        let websocket_requests = WORKER_METRICS
112            .websocket_requests
113            .with_label_values(&database_identity, message_kind);
114        let sendtx_queue_size = WORKER_METRICS
115            .total_outgoing_queue_length
116            .with_label_values(&database_identity);
117
118        Self {
119            websocket_request_msg_size,
120            websocket_requests,
121            sendtx_queue_size,
122        }
123    }
124}
125
126#[derive(Debug, thiserror::Error)]
127pub enum ClientSendError {
128    #[error("client disconnected")]
129    Disconnected,
130    #[error("client was not responding and has been disconnected")]
131    Cancelled,
132}
133
134impl ClientConnectionSender {
135    pub fn dummy_with_channel(id: ClientActorId, config: ClientConfig) -> (Self, mpsc::Receiver<SerializableMessage>) {
136        let (sendtx, rx) = mpsc::channel(1);
137        // just make something up, it doesn't need to be attached to a real task
138        let abort_handle = match tokio::runtime::Handle::try_current() {
139            Ok(h) => h.spawn(async {}).abort_handle(),
140            Err(_) => tokio::runtime::Runtime::new().unwrap().spawn(async {}).abort_handle(),
141        };
142
143        let cancelled = AtomicBool::new(false);
144        let sender = Self {
145            id,
146            config,
147            sendtx,
148            abort_handle,
149            cancelled,
150            metrics: None,
151        };
152        (sender, rx)
153    }
154
155    pub fn dummy(id: ClientActorId, config: ClientConfig) -> Self {
156        Self::dummy_with_channel(id, config).0
157    }
158
159    /// Send a message to the client. For data-related messages, you should probably use
160    /// `BroadcastQueue::send` to ensure that the client sees data messages in a consistent order.
161    pub fn send_message(&self, message: impl Into<SerializableMessage>) -> Result<(), ClientSendError> {
162        self.send(message.into())
163    }
164
165    fn send(&self, message: SerializableMessage) -> Result<(), ClientSendError> {
166        if self.cancelled.load(Relaxed) {
167            return Err(ClientSendError::Cancelled);
168        }
169
170        match self.sendtx.try_send(message) {
171            Err(mpsc::error::TrySendError::Full(_)) => {
172                // we've hit CLIENT_CHANNEL_CAPACITY messages backed up in
173                // the channel, so forcibly kick the client
174                tracing::warn!(identity = %self.id.identity, connection_id = %self.id.connection_id, "client channel capacity exceeded");
175                self.abort_handle.abort();
176                self.cancelled.store(true, Relaxed);
177                return Err(ClientSendError::Cancelled);
178            }
179            Err(mpsc::error::TrySendError::Closed(_)) => return Err(ClientSendError::Disconnected),
180            Ok(()) => {
181                // If we successfully pushed a message into the queue, increment the queue size metric.
182                // Don't do this before pushing because, if the client has disconnected,
183                // it will already have performed its clean-up,
184                // and so would never perform the corresponding `dec` to this `inc`.
185                if let Some(metrics) = &self.metrics {
186                    metrics.sendtx_queue_size.inc();
187                }
188            }
189        }
190
191        Ok(())
192    }
193
194    pub(crate) fn observe_websocket_request_message(&self, message: &DataMessage) {
195        if let Some(metrics) = &self.metrics {
196            metrics.websocket_request_msg_size.observe(message.len() as f64);
197            metrics.websocket_requests.inc();
198        }
199    }
200}
201
202#[derive(Clone)]
203#[non_exhaustive]
204pub struct ClientConnection {
205    sender: Arc<ClientConnectionSender>,
206    pub replica_id: u64,
207    pub module: ModuleHost,
208    module_rx: watch::Receiver<ModuleHost>,
209}
210
211impl Deref for ClientConnection {
212    type Target = ClientConnectionSender;
213    fn deref(&self) -> &Self::Target {
214        &self.sender
215    }
216}
217
218#[derive(Debug, From)]
219pub enum DataMessage {
220    Text(ByteString),
221    Binary(Bytes),
222}
223
224impl From<String> for DataMessage {
225    fn from(value: String) -> Self {
226        ByteString::from(value).into()
227    }
228}
229
230impl From<Vec<u8>> for DataMessage {
231    fn from(value: Vec<u8>) -> Self {
232        Bytes::from(value).into()
233    }
234}
235
236impl DataMessage {
237    pub fn len(&self) -> usize {
238        match self {
239            DataMessage::Text(s) => s.len(),
240            DataMessage::Binary(b) => b.len(),
241        }
242    }
243
244    #[must_use]
245    pub fn is_empty(&self) -> bool {
246        self.len() == 0
247    }
248}
249
250// if a client racks up this many messages in the queue without ACK'ing
251// anything, we boot 'em.
252const CLIENT_CHANNEL_CAPACITY: usize = 16 * KB;
253const KB: usize = 1024;
254
255impl ClientConnection {
256    /// Returns an error if ModuleHost closed
257    pub async fn spawn<Fut>(
258        id: ClientActorId,
259        config: ClientConfig,
260        replica_id: u64,
261        mut module_rx: watch::Receiver<ModuleHost>,
262        actor: impl FnOnce(ClientConnection, mpsc::Receiver<SerializableMessage>) -> Fut,
263    ) -> Result<ClientConnection, ClientConnectedError>
264    where
265        Fut: Future<Output = ()> + Send + 'static,
266    {
267        // Add this client as a subscriber
268        // TODO: Right now this is connecting clients directly to a replica, but their requests should be
269        // logically subscribed to the database, not any particular replica. We should handle failover for
270        // them and stuff. Not right now though.
271        let module = module_rx.borrow_and_update().clone();
272        module.call_identity_connected(id.identity, id.connection_id).await?;
273
274        let (sendtx, sendrx) = mpsc::channel::<SerializableMessage>(CLIENT_CHANNEL_CAPACITY);
275
276        let (fut_tx, fut_rx) = oneshot::channel::<Fut>();
277        // weird dance so that we can get an abort_handle into ClientConnection
278        let module_info = module.info.clone();
279        let database_identity = module_info.database_identity;
280        let abort_handle = tokio::spawn(async move {
281            let Ok(fut) = fut_rx.await else { return };
282
283            let _gauge_guard = module_info.metrics.connected_clients.inc_scope();
284            module_info.metrics.ws_clients_spawned.inc();
285            scopeguard::defer!(module_info.metrics.ws_clients_aborted.inc());
286
287            fut.await
288        })
289        .abort_handle();
290
291        let metrics = ClientConnectionMetrics::new(database_identity, config.protocol);
292
293        let sender = Arc::new(ClientConnectionSender {
294            id,
295            config,
296            sendtx,
297            abort_handle,
298            cancelled: AtomicBool::new(false),
299            metrics: Some(metrics),
300        });
301        let this = Self {
302            sender,
303            replica_id,
304            module,
305            module_rx,
306        };
307
308        let actor_fut = actor(this.clone(), sendrx);
309        // if this fails, the actor() function called .abort(), which like... okay, I guess?
310        let _ = fut_tx.send(actor_fut);
311
312        Ok(this)
313    }
314
315    pub fn dummy(
316        id: ClientActorId,
317        config: ClientConfig,
318        replica_id: u64,
319        mut module_rx: watch::Receiver<ModuleHost>,
320    ) -> Self {
321        let module = module_rx.borrow_and_update().clone();
322        Self {
323            sender: Arc::new(ClientConnectionSender::dummy(id, config)),
324            replica_id,
325            module,
326            module_rx,
327        }
328    }
329
330    pub fn sender(&self) -> Arc<ClientConnectionSender> {
331        self.sender.clone()
332    }
333
334    #[inline]
335    pub fn handle_message(
336        &self,
337        message: impl Into<DataMessage>,
338        timer: Instant,
339    ) -> impl Future<Output = Result<(), MessageHandleError>> + '_ {
340        message_handlers::handle(self, message.into(), timer)
341    }
342
343    pub async fn watch_module_host(&mut self) -> Result<(), NoSuchModule> {
344        match self.module_rx.changed().await {
345            Ok(()) => {
346                self.module = self.module_rx.borrow_and_update().clone();
347                Ok(())
348            }
349            Err(_) => Err(NoSuchModule),
350        }
351    }
352
353    pub async fn call_reducer(
354        &self,
355        reducer: &str,
356        args: ReducerArgs,
357        request_id: RequestId,
358        timer: Instant,
359        flags: CallReducerFlags,
360    ) -> Result<ReducerCallResult, ReducerCallError> {
361        let caller = match flags {
362            CallReducerFlags::FullUpdate => Some(self.sender()),
363            // Setting `sender = None` causes `eval_updates` to skip sending to the caller
364            // as it has no access to the caller other than by id/connection id.
365            CallReducerFlags::NoSuccessNotify => None,
366        };
367
368        self.module
369            .call_reducer(
370                self.id.identity,
371                Some(self.id.connection_id),
372                caller,
373                Some(request_id),
374                Some(timer),
375                reducer,
376                args,
377            )
378            .await
379    }
380
381    pub async fn subscribe_single(
382        &self,
383        subscription: SubscribeSingle,
384        timer: Instant,
385    ) -> Result<Option<ExecutionMetrics>, DBError> {
386        let me = self.clone();
387        asyncify(move || {
388            me.module
389                .subscriptions()
390                .add_single_subscription(me.sender, subscription, timer, None)
391        })
392        .await
393    }
394
395    pub async fn unsubscribe(&self, request: Unsubscribe, timer: Instant) -> Result<Option<ExecutionMetrics>, DBError> {
396        let me = self.clone();
397        asyncify(move || {
398            me.module
399                .subscriptions()
400                .remove_single_subscription(me.sender, request, timer)
401        })
402        .await
403    }
404
405    pub async fn subscribe_multi(
406        &self,
407        request: SubscribeMulti,
408        timer: Instant,
409    ) -> Result<Option<ExecutionMetrics>, DBError> {
410        let me = self.clone();
411        asyncify(move || {
412            me.module
413                .subscriptions()
414                .add_multi_subscription(me.sender, request, timer, None)
415        })
416        .await
417    }
418
419    pub async fn unsubscribe_multi(
420        &self,
421        request: UnsubscribeMulti,
422        timer: Instant,
423    ) -> Result<Option<ExecutionMetrics>, DBError> {
424        let me = self.clone();
425        asyncify(move || {
426            me.module
427                .subscriptions()
428                .remove_multi_subscription(me.sender, request, timer)
429        })
430        .await
431    }
432
433    pub async fn subscribe(&self, subscription: Subscribe, timer: Instant) -> Result<ExecutionMetrics, DBError> {
434        let me = self.clone();
435        asyncify(move || {
436            me.module
437                .subscriptions()
438                .add_legacy_subscriber(me.sender, subscription, timer, None)
439        })
440        .await
441    }
442
443    pub async fn one_off_query_json(
444        &self,
445        query: &str,
446        message_id: &[u8],
447        timer: Instant,
448    ) -> Result<(), anyhow::Error> {
449        self.module
450            .one_off_query::<JsonFormat>(
451                self.id.identity,
452                query.to_owned(),
453                self.sender.clone(),
454                message_id.to_owned(),
455                timer,
456                |msg: OneOffQueryResponseMessage<JsonFormat>| msg.into(),
457            )
458            .await
459    }
460
461    pub async fn one_off_query_bsatn(
462        &self,
463        query: &str,
464        message_id: &[u8],
465        timer: Instant,
466    ) -> Result<(), anyhow::Error> {
467        self.module
468            .one_off_query::<BsatnFormat>(
469                self.id.identity,
470                query.to_owned(),
471                self.sender.clone(),
472                message_id.to_owned(),
473                timer,
474                |msg: OneOffQueryResponseMessage<BsatnFormat>| msg.into(),
475            )
476            .await
477    }
478
479    pub async fn disconnect(self) {
480        self.module.disconnect_client(self.id).await
481    }
482}