1use std::ops::Deref;
2use std::sync::atomic::{AtomicBool, Ordering::Relaxed};
3use std::sync::Arc;
4use std::time::Instant;
5
6use super::messages::{OneOffQueryResponseMessage, SerializableMessage};
7use super::{message_handlers, ClientActorId, MessageHandleError};
8use crate::error::DBError;
9use crate::host::module_host::ClientConnectedError;
10use crate::host::{ModuleHost, NoSuchModule, ReducerArgs, ReducerCallError, ReducerCallResult};
11use crate::messages::websocket::Subscribe;
12use crate::util::asyncify;
13use crate::util::prometheus_handle::IntGaugeExt;
14use crate::worker_metrics::WORKER_METRICS;
15use bytes::Bytes;
16use bytestring::ByteString;
17use derive_more::From;
18use futures::prelude::*;
19use prometheus::{Histogram, IntCounter, IntGauge};
20use spacetimedb_client_api_messages::websocket::{
21 BsatnFormat, CallReducerFlags, Compression, FormatSwitch, JsonFormat, SubscribeMulti, SubscribeSingle, Unsubscribe,
22 UnsubscribeMulti,
23};
24use spacetimedb_lib::identity::RequestId;
25use spacetimedb_lib::metrics::ExecutionMetrics;
26use spacetimedb_lib::Identity;
27use tokio::sync::{mpsc, oneshot, watch};
28use tokio::task::AbortHandle;
29
/// Wire encoding used for a client connection's messages.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Protocol {
    /// Text frames; pairs with the JSON output format.
    Text,
    /// Binary frames; pairs with the BSATN output format.
    Binary,
}
35
36impl Protocol {
37 pub fn as_str(self) -> &'static str {
38 match self {
39 Protocol::Text => "text",
40 Protocol::Binary => "binary",
41 }
42 }
43
44 pub(crate) fn assert_matches_format_switch<B, J>(self, fs: &FormatSwitch<B, J>) {
45 match (self, fs) {
46 (Protocol::Text, FormatSwitch::Json(_)) | (Protocol::Binary, FormatSwitch::Bsatn(_)) => {}
47 _ => unreachable!("requested protocol does not match output format"),
48 }
49 }
50}
51
52#[derive(Clone, Copy, Debug)]
53pub struct ClientConfig {
54 pub protocol: Protocol,
56 pub compression: Compression,
58 pub tx_update_full: bool,
62}
63
64impl ClientConfig {
65 pub fn for_test() -> ClientConfig {
66 Self {
67 protocol: Protocol::Binary,
68 compression: <_>::default(),
69 tx_update_full: true,
70 }
71 }
72}
73
74#[derive(Debug)]
75pub struct ClientConnectionSender {
76 pub id: ClientActorId,
77 pub config: ClientConfig,
78 sendtx: mpsc::Sender<SerializableMessage>,
79 abort_handle: AbortHandle,
80 cancelled: AtomicBool,
81
82 metrics: Option<ClientConnectionMetrics>,
88}
89
90#[derive(Debug)]
91pub struct ClientConnectionMetrics {
92 pub websocket_request_msg_size: Histogram,
93 pub websocket_requests: IntCounter,
94
95 pub sendtx_queue_size: IntGauge,
103}
104
105impl ClientConnectionMetrics {
106 fn new(database_identity: Identity, protocol: Protocol) -> Self {
107 let message_kind = protocol.as_str();
108 let websocket_request_msg_size = WORKER_METRICS
109 .websocket_request_msg_size
110 .with_label_values(&database_identity, message_kind);
111 let websocket_requests = WORKER_METRICS
112 .websocket_requests
113 .with_label_values(&database_identity, message_kind);
114 let sendtx_queue_size = WORKER_METRICS
115 .total_outgoing_queue_length
116 .with_label_values(&database_identity);
117
118 Self {
119 websocket_request_msg_size,
120 websocket_requests,
121 sendtx_queue_size,
122 }
123 }
124}
125
126#[derive(Debug, thiserror::Error)]
127pub enum ClientSendError {
128 #[error("client disconnected")]
129 Disconnected,
130 #[error("client was not responding and has been disconnected")]
131 Cancelled,
132}
133
134impl ClientConnectionSender {
135 pub fn dummy_with_channel(id: ClientActorId, config: ClientConfig) -> (Self, mpsc::Receiver<SerializableMessage>) {
136 let (sendtx, rx) = mpsc::channel(1);
137 let abort_handle = match tokio::runtime::Handle::try_current() {
139 Ok(h) => h.spawn(async {}).abort_handle(),
140 Err(_) => tokio::runtime::Runtime::new().unwrap().spawn(async {}).abort_handle(),
141 };
142
143 let cancelled = AtomicBool::new(false);
144 let sender = Self {
145 id,
146 config,
147 sendtx,
148 abort_handle,
149 cancelled,
150 metrics: None,
151 };
152 (sender, rx)
153 }
154
155 pub fn dummy(id: ClientActorId, config: ClientConfig) -> Self {
156 Self::dummy_with_channel(id, config).0
157 }
158
159 pub fn send_message(&self, message: impl Into<SerializableMessage>) -> Result<(), ClientSendError> {
162 self.send(message.into())
163 }
164
165 fn send(&self, message: SerializableMessage) -> Result<(), ClientSendError> {
166 if self.cancelled.load(Relaxed) {
167 return Err(ClientSendError::Cancelled);
168 }
169
170 match self.sendtx.try_send(message) {
171 Err(mpsc::error::TrySendError::Full(_)) => {
172 tracing::warn!(identity = %self.id.identity, connection_id = %self.id.connection_id, "client channel capacity exceeded");
175 self.abort_handle.abort();
176 self.cancelled.store(true, Relaxed);
177 return Err(ClientSendError::Cancelled);
178 }
179 Err(mpsc::error::TrySendError::Closed(_)) => return Err(ClientSendError::Disconnected),
180 Ok(()) => {
181 if let Some(metrics) = &self.metrics {
186 metrics.sendtx_queue_size.inc();
187 }
188 }
189 }
190
191 Ok(())
192 }
193
194 pub(crate) fn observe_websocket_request_message(&self, message: &DataMessage) {
195 if let Some(metrics) = &self.metrics {
196 metrics.websocket_request_msg_size.observe(message.len() as f64);
197 metrics.websocket_requests.inc();
198 }
199 }
200}
201
202#[derive(Clone)]
203#[non_exhaustive]
204pub struct ClientConnection {
205 sender: Arc<ClientConnectionSender>,
206 pub replica_id: u64,
207 pub module: ModuleHost,
208 module_rx: watch::Receiver<ModuleHost>,
209}
210
211impl Deref for ClientConnection {
212 type Target = ClientConnectionSender;
213 fn deref(&self) -> &Self::Target {
214 &self.sender
215 }
216}
217
218#[derive(Debug, From)]
219pub enum DataMessage {
220 Text(ByteString),
221 Binary(Bytes),
222}
223
224impl From<String> for DataMessage {
225 fn from(value: String) -> Self {
226 ByteString::from(value).into()
227 }
228}
229
230impl From<Vec<u8>> for DataMessage {
231 fn from(value: Vec<u8>) -> Self {
232 Bytes::from(value).into()
233 }
234}
235
236impl DataMessage {
237 pub fn len(&self) -> usize {
238 match self {
239 DataMessage::Text(s) => s.len(),
240 DataMessage::Binary(b) => b.len(),
241 }
242 }
243
244 #[must_use]
245 pub fn is_empty(&self) -> bool {
246 self.len() == 0
247 }
248}
249
/// Multiplier of 1024, used to spell out `CLIENT_CHANNEL_CAPACITY`.
const KB: usize = 1024;

/// Maximum number of outgoing messages queued per client before the connection is
/// cancelled. This counts *messages*, not bytes, despite being written in terms of `KB`.
const CLIENT_CHANNEL_CAPACITY: usize = 16 * KB;
254
impl ClientConnection {
    /// Runs the module's client-connected hook for `id`, then starts the client's actor task.
    ///
    /// The actor is built by calling `actor` with a clone of the returned connection and
    /// the receiving half of a bounded outgoing-message queue of size
    /// `CLIENT_CHANNEL_CAPACITY`.
    ///
    /// # Errors
    ///
    /// Returns the error from `call_identity_connected` if the hook fails. In that case
    /// no task is spawned and no channel is created.
    pub async fn spawn<Fut>(
        id: ClientActorId,
        config: ClientConfig,
        replica_id: u64,
        mut module_rx: watch::Receiver<ModuleHost>,
        actor: impl FnOnce(ClientConnection, mpsc::Receiver<SerializableMessage>) -> Fut,
    ) -> Result<ClientConnection, ClientConnectedError>
    where
        Fut: Future<Output = ()> + Send + 'static,
    {
        let module = module_rx.borrow_and_update().clone();
        module.call_identity_connected(id.identity, id.connection_id).await?;

        let (sendtx, sendrx) = mpsc::channel::<SerializableMessage>(CLIENT_CHANNEL_CAPACITY);

        // Chicken-and-egg:
        // - the sender needs the task's abort handle;
        // - the actor future needs the connection, which owns the sender.
        // So spawn the task first, waiting on a oneshot, and hand it the actor future
        // once the connection exists.
        let (fut_tx, fut_rx) = oneshot::channel::<Fut>();
        let module_info = module.info.clone();
        let database_identity = module_info.database_identity;
        let abort_handle = tokio::spawn(async move {
            // If the oneshot sender is dropped without sending, exit without running anything.
            let Ok(fut) = fut_rx.await else { return };

            // The connected-clients gauge stays incremented for as long as this guard lives.
            let _gauge_guard = module_info.metrics.connected_clients.inc_scope();
            module_info.metrics.ws_clients_spawned.inc();
            // NOTE(review): `defer!` runs whenever this scope exits. That covers normal
            // completion of `fut` as well as the future being dropped by an abort, so
            // `ws_clients_aborted` counts every exit past this point. Confirm that is intended.
            scopeguard::defer!(module_info.metrics.ws_clients_aborted.inc());

            fut.await
        })
        .abort_handle();

        let metrics = ClientConnectionMetrics::new(database_identity, config.protocol);

        let sender = Arc::new(ClientConnectionSender {
            id,
            config,
            sendtx,
            abort_handle,
            cancelled: AtomicBool::new(false),
            metrics: Some(metrics),
        });
        let this = Self {
            sender,
            replica_id,
            module,
            module_rx,
        };

        let actor_fut = actor(this.clone(), sendrx);
        // Sending fails only if the task has already dropped its receiver
        // (e.g. it was aborted); there is nothing useful to do in that case.
        let _ = fut_tx.send(actor_fut);

        Ok(this)
    }

    /// Builds a connection around a dummy sender (see `ClientConnectionSender::dummy`).
    ///
    /// Does not call the connected hook and spawns no actor task.
    pub fn dummy(
        id: ClientActorId,
        config: ClientConfig,
        replica_id: u64,
        mut module_rx: watch::Receiver<ModuleHost>,
    ) -> Self {
        let module = module_rx.borrow_and_update().clone();
        Self {
            sender: Arc::new(ClientConnectionSender::dummy(id, config)),
            replica_id,
            module,
            module_rx,
        }
    }

    /// Returns another shared handle to this connection's sender.
    pub fn sender(&self) -> Arc<ClientConnectionSender> {
        self.sender.clone()
    }

    /// Dispatches an incoming client message to `message_handlers::handle`.
    ///
    /// `timer` is forwarded unchanged. It is presumably the instant the message was
    /// received (TODO confirm).
    #[inline]
    pub fn handle_message(
        &self,
        message: impl Into<DataMessage>,
        timer: Instant,
    ) -> impl Future<Output = Result<(), MessageHandleError>> + '_ {
        message_handlers::handle(self, message.into(), timer)
    }

    /// Waits until a new module host is published, then makes it the current `module`.
    ///
    /// # Errors
    ///
    /// Returns `NoSuchModule` if the watch channel's sending side has gone away.
    pub async fn watch_module_host(&mut self) -> Result<(), NoSuchModule> {
        match self.module_rx.changed().await {
            Ok(()) => {
                self.module = self.module_rx.borrow_and_update().clone();
                Ok(())
            }
            Err(_) => Err(NoSuchModule),
        }
    }

    /// Calls `reducer` on the current module on behalf of this client.
    ///
    /// Which `flags` value passes which caller:
    /// - `FullUpdate`: this connection's sender is passed as the caller.
    /// - `NoSuccessNotify`: no caller is passed (presumably so the client is not sent
    ///   a success notification; confirm in `ModuleHost::call_reducer`).
    pub async fn call_reducer(
        &self,
        reducer: &str,
        args: ReducerArgs,
        request_id: RequestId,
        timer: Instant,
        flags: CallReducerFlags,
    ) -> Result<ReducerCallResult, ReducerCallError> {
        let caller = match flags {
            CallReducerFlags::FullUpdate => Some(self.sender()),
            CallReducerFlags::NoSuccessNotify => None,
        };

        self.module
            .call_reducer(
                self.id.identity,
                Some(self.id.connection_id),
                caller,
                Some(request_id),
                Some(timer),
                reducer,
                args,
            )
            .await
    }

    /// Adds a single-query subscription for this client.
    ///
    /// The connection is cloned so the closure can own it; the work runs through
    /// `asyncify` (presumably off the async executor; confirm).
    pub async fn subscribe_single(
        &self,
        subscription: SubscribeSingle,
        timer: Instant,
    ) -> Result<Option<ExecutionMetrics>, DBError> {
        let me = self.clone();
        asyncify(move || {
            me.module
                .subscriptions()
                .add_single_subscription(me.sender, subscription, timer, None)
        })
        .await
    }

    /// Removes a single-query subscription for this client, via `asyncify`.
    pub async fn unsubscribe(&self, request: Unsubscribe, timer: Instant) -> Result<Option<ExecutionMetrics>, DBError> {
        let me = self.clone();
        asyncify(move || {
            me.module
                .subscriptions()
                .remove_single_subscription(me.sender, request, timer)
        })
        .await
    }

    /// Adds a multi-query subscription for this client, via `asyncify`.
    pub async fn subscribe_multi(
        &self,
        request: SubscribeMulti,
        timer: Instant,
    ) -> Result<Option<ExecutionMetrics>, DBError> {
        let me = self.clone();
        asyncify(move || {
            me.module
                .subscriptions()
                .add_multi_subscription(me.sender, request, timer, None)
        })
        .await
    }

    /// Removes a multi-query subscription for this client, via `asyncify`.
    pub async fn unsubscribe_multi(
        &self,
        request: UnsubscribeMulti,
        timer: Instant,
    ) -> Result<Option<ExecutionMetrics>, DBError> {
        let me = self.clone();
        asyncify(move || {
            me.module
                .subscriptions()
                .remove_multi_subscription(me.sender, request, timer)
        })
        .await
    }

    /// Registers this client through the legacy `Subscribe` path, via `asyncify`.
    pub async fn subscribe(&self, subscription: Subscribe, timer: Instant) -> Result<ExecutionMetrics, DBError> {
        let me = self.clone();
        asyncify(move || {
            me.module
                .subscriptions()
                .add_legacy_subscriber(me.sender, subscription, timer, None)
        })
        .await
    }

    /// Runs a one-off `query` for this client, using the JSON message format.
    ///
    /// `message_id` is copied and passed through to the module.
    pub async fn one_off_query_json(
        &self,
        query: &str,
        message_id: &[u8],
        timer: Instant,
    ) -> Result<(), anyhow::Error> {
        self.module
            .one_off_query::<JsonFormat>(
                self.id.identity,
                query.to_owned(),
                self.sender.clone(),
                message_id.to_owned(),
                timer,
                |msg: OneOffQueryResponseMessage<JsonFormat>| msg.into(),
            )
            .await
    }

    /// Runs a one-off `query` for this client, using the BSATN message format.
    ///
    /// `message_id` is copied and passed through to the module.
    pub async fn one_off_query_bsatn(
        &self,
        query: &str,
        message_id: &[u8],
        timer: Instant,
    ) -> Result<(), anyhow::Error> {
        self.module
            .one_off_query::<BsatnFormat>(
                self.id.identity,
                query.to_owned(),
                self.sender.clone(),
                message_id.to_owned(),
                timer,
                |msg: OneOffQueryResponseMessage<BsatnFormat>| msg.into(),
            )
            .await
    }

    /// Consumes this handle and asks the current module to disconnect the client.
    pub async fn disconnect(self) {
        self.module.disconnect_client(self.id).await
    }
}