spacetimedb_client_api/routes/
subscribe.rs

1use std::fmt::Display;
2use std::future::{poll_fn, Future};
3use std::num::NonZeroUsize;
4use std::panic;
5use std::pin::{pin, Pin};
6use std::sync::atomic::{AtomicBool, Ordering};
7use std::sync::Arc;
8use std::task::{Context, Poll};
9use std::time::Duration;
10
11use async_stream::stream;
12use axum::extract::{Path, Query, State};
13use axum::response::IntoResponse;
14use axum::Extension;
15use axum_extra::TypedHeader;
16use bytes::Bytes;
17use bytestring::ByteString;
18use derive_more::From;
19use futures::{pin_mut, Sink, SinkExt, Stream, StreamExt};
20use http::{HeaderValue, StatusCode};
21use prometheus::IntGauge;
22use scopeguard::{defer, ScopeGuard};
23use serde::Deserialize;
24use spacetimedb::client::messages::{
25    serialize, IdentityTokenMessage, SerializableMessage, SerializeBuffer, SwitchedServerMessage, ToProtocol,
26};
27use spacetimedb::client::{
28    ClientActorId, ClientConfig, ClientConnection, DataMessage, MessageExecutionError, MessageHandleError,
29    MeteredReceiver, Protocol,
30};
31use spacetimedb::host::module_host::ClientConnectedError;
32use spacetimedb::host::NoSuchModule;
33use spacetimedb::util::spawn_rayon;
34use spacetimedb::worker_metrics::WORKER_METRICS;
35use spacetimedb::Identity;
36use spacetimedb_client_api_messages::websocket::{self as ws_api, Compression};
37use spacetimedb_datastore::execution_context::WorkloadType;
38use spacetimedb_lib::connection_id::{ConnectionId, ConnectionIdForUrl};
39use std::time::Instant;
40use tokio::sync::{mpsc, watch};
41use tokio::task::JoinHandle;
42use tokio::time::error::Elapsed;
43use tokio::time::{sleep_until, timeout};
44use tokio_tungstenite::tungstenite::Utf8Bytes;
45
46use crate::auth::SpacetimeAuth;
47use crate::util::serde::humantime_duration;
48use crate::util::websocket::{
49    CloseCode, CloseFrame, Message as WsMessage, WebSocketConfig, WebSocketStream, WebSocketUpgrade, WsError,
50};
51use crate::util::{NameOrIdentity, XForwardedFor};
52use crate::{log_and_500, ControlStateDelegate, NodeDelegate};
53
/// Header value naming the text websocket subprotocol.
///
/// Offered during protocol selection in [`handle_websocket`], where it maps to
/// [`Protocol::Text`].
// NOTE(review): the lint is presumably silenced because `HeaderValue` is
// flagged as interior-mutable, while this constant is only ever read -- confirm.
#[allow(clippy::declare_interior_mutable_const)]
pub const TEXT_PROTOCOL: HeaderValue = HeaderValue::from_static(ws_api::TEXT_PROTOCOL);
/// Header value naming the binary websocket subprotocol.
///
/// Offered during protocol selection in [`handle_websocket`], where it maps to
/// [`Protocol::Binary`].
// NOTE(review): same rationale for the `allow` as on `TEXT_PROTOCOL` -- confirm.
#[allow(clippy::declare_interior_mutable_const)]
pub const BIN_PROTOCOL: HeaderValue = HeaderValue::from_static(ws_api::BIN_PROTOCOL);
58
/// Implemented by types that can supply the [`WebSocketOptions`] to apply to
/// websocket connections.
///
/// [`handle_websocket`] requires this of its state type and queries it once
/// per upgrade request.
pub trait HasWebSocketOptions {
    /// Returns the options to use for a new websocket connection.
    fn websocket_options(&self) -> WebSocketOptions;
}
62
63impl<T: HasWebSocketOptions> HasWebSocketOptions for Arc<T> {
64    fn websocket_options(&self) -> WebSocketOptions {
65        (**self).websocket_options()
66    }
67}
68
/// Path parameters of the subscribe endpoint.
#[derive(Deserialize)]
pub struct SubscribeParams {
    /// Name or identity of the database to subscribe to;
    /// resolved to a database identity by [`handle_websocket`].
    pub name_or_identity: NameOrIdentity,
}
73
/// Query-string parameters of the subscribe endpoint.
#[derive(Deserialize)]
pub struct SubscribeQueryParams {
    /// Connection id chosen by the client.
    ///
    /// If absent, [`handle_websocket`] generates a random one.
    /// The parameter is internal and slated for removal (see the log message
    /// emitted by [`handle_websocket`] when it is present).
    pub connection_id: Option<ConnectionIdForUrl>,
    /// Compression setting forwarded into the [`ClientConfig`].
    ///
    /// Falls back to `Compression::default()` when omitted.
    #[serde(default)]
    pub compression: Compression,
    /// Whether we want "light" responses, tailored to network bandwidth constrained clients.
    /// This knob works by setting other, more specific, knobs to the value.
    ///
    /// Currently the only knob derived from it is `ClientConfig::tx_update_full`,
    /// which is set to `!light`. Defaults to `false` when omitted.
    #[serde(default)]
    pub light: bool,
}
84
85pub fn generate_random_connection_id() -> ConnectionId {
86    ConnectionId::from_le_byte_array(rand::random())
87}
88
89pub async fn handle_websocket<S>(
90    State(ctx): State<S>,
91    Path(SubscribeParams { name_or_identity }): Path<SubscribeParams>,
92    Query(SubscribeQueryParams {
93        connection_id,
94        compression,
95        light,
96    }): Query<SubscribeQueryParams>,
97    forwarded_for: Option<TypedHeader<XForwardedFor>>,
98    Extension(auth): Extension<SpacetimeAuth>,
99    ws: WebSocketUpgrade,
100) -> axum::response::Result<impl IntoResponse>
101where
102    S: NodeDelegate + ControlStateDelegate + HasWebSocketOptions,
103{
104    if connection_id.is_some() {
105        // TODO: Bump this up to `log::warn!` after removing the client SDKs' uses of that parameter.
106        log::debug!("The connection_id query parameter to the subscribe HTTP endpoint is internal and will be removed in a future version of SpacetimeDB.");
107    }
108
109    let connection_id = connection_id
110        .map(ConnectionId::from)
111        .unwrap_or_else(generate_random_connection_id);
112
113    if connection_id == ConnectionId::ZERO {
114        Err((
115            StatusCode::BAD_REQUEST,
116            "Invalid connection ID: the all-zeros ConnectionId is reserved.",
117        ))?;
118    }
119
120    let db_identity = name_or_identity.resolve(&ctx).await?;
121
122    let (res, ws_upgrade, protocol) =
123        ws.select_protocol([(BIN_PROTOCOL, Protocol::Binary), (TEXT_PROTOCOL, Protocol::Text)]);
124
125    let protocol = protocol.ok_or((StatusCode::BAD_REQUEST, "no valid protocol selected"))?;
126    let client_config = ClientConfig {
127        protocol,
128        compression,
129        tx_update_full: !light,
130    };
131
132    // TODO: Should also maybe refactor the code and the protocol to allow a single websocket
133    // to connect to multiple modules
134
135    let database = ctx
136        .get_database_by_identity(&db_identity)
137        .unwrap()
138        .ok_or(StatusCode::NOT_FOUND)?;
139
140    let leader = ctx
141        .leader(database.id)
142        .await
143        .map_err(log_and_500)?
144        .ok_or(StatusCode::NOT_FOUND)?;
145
146    let identity_token = auth.creds.token().into();
147
148    let mut module_rx = leader.module_watcher().await.map_err(log_and_500)?;
149
150    let client_id = ClientActorId {
151        identity: auth.identity,
152        connection_id,
153        name: ctx.client_actor_index().next_client_name(),
154    };
155
156    let ws_config = WebSocketConfig::default()
157        .max_message_size(Some(0x2000000))
158        .max_frame_size(None)
159        .accept_unmasked_frames(false);
160    let ws_opts = ctx.websocket_options();
161
162    tokio::spawn(async move {
163        let ws = match ws_upgrade.upgrade(ws_config).await {
164            Ok(ws) => ws,
165            Err(err) => {
166                log::error!("websocket: WebSocket init error: {err}");
167                return;
168            }
169        };
170
171        let identity = client_id.identity;
172        let client_log_string = match forwarded_for {
173            Some(TypedHeader(XForwardedFor(ip))) => {
174                format!("ip {ip} with Identity {identity} and ConnectionId {connection_id}")
175            }
176            None => format!("unknown ip with Identity {identity} and ConnectionId {connection_id}"),
177        };
178
179        log::debug!("websocket: New client connected from {client_log_string}");
180
181        let connected = match ClientConnection::call_client_connected_maybe_reject(&mut module_rx, client_id).await {
182            Ok(connected) => {
183                log::debug!("websocket: client_connected returned Ok for {client_log_string}");
184                connected
185            }
186            Err(e @ (ClientConnectedError::Rejected(_) | ClientConnectedError::OutOfEnergy)) => {
187                log::info!(
188                    "websocket: Rejecting connection for {client_log_string} due to error from client_connected reducer: {e}"
189                );
190                return;
191            }
192            Err(e @ (ClientConnectedError::DBError(_) | ClientConnectedError::ReducerCall(_))) => {
193                log::warn!("websocket: ModuleHost died while {client_log_string} was connecting: {e:#}");
194                return;
195            }
196        };
197
198        log::debug!(
199            "websocket: Database accepted connection from {client_log_string}; spawning ws_client_actor and ClientConnection"
200        );
201
202        let actor = |client, sendrx| ws_client_actor(ws_opts, client, ws, sendrx);
203        let client =
204            ClientConnection::spawn(client_id, client_config, leader.replica_id, module_rx, actor, connected).await;
205
206        // Send the client their identity token message as the first message
207        // NOTE: We're adding this to the protocol because some client libraries are
208        // unable to access the http response headers.
209        // Clients that receive the token from the response headers should ignore this
210        // message.
211        let message = IdentityTokenMessage {
212            identity: auth.identity,
213            token: identity_token,
214            connection_id,
215        };
216        if let Err(e) = client.send_message(message) {
217            log::warn!("websocket: Error sending IdentityToken message to {client_log_string}: {e}");
218        }
219    });
220
221    Ok(res)
222}
223
/// State shared between the tasks that make up a websocket client actor.
///
/// Created in [`ws_client_actor_inner`] and handed, wrapped in an [`Arc`], to
/// the main loop and to the send and receive tasks.
struct ActorState {
    /// Id of the client this actor serves.
    pub client_id: ClientActorId,
    /// Identity of the database the client is connected to.
    pub database: Identity,
    /// Options (timeouts, intervals, queue length) for this connection.
    config: WebSocketOptions,
    /// Whether the connection is in the closing phase.
    ///
    /// Starts out `false`; set through [`ActorState::close`].
    closed: AtomicBool,
    /// Pong flag consulted before a new `Ping` is sent.
    ///
    /// Starts out `true`; cleared by [`ActorState::reset_ponged`] and set
    /// again by [`ActorState::set_ponged`].
    got_pong: AtomicBool,
}
231
232impl ActorState {
233    pub fn new(database: Identity, client_id: ClientActorId, config: WebSocketOptions) -> Self {
234        Self {
235            database,
236            client_id,
237            config,
238            closed: AtomicBool::new(false),
239            got_pong: AtomicBool::new(true),
240        }
241    }
242
243    pub fn closed(&self) -> bool {
244        self.closed.load(Ordering::Relaxed)
245    }
246
247    pub fn close(&self) -> bool {
248        self.closed.swap(true, Ordering::Relaxed)
249    }
250
251    pub fn set_ponged(&self) {
252        self.got_pong.store(true, Ordering::Relaxed);
253    }
254
255    pub fn reset_ponged(&self) -> bool {
256        self.got_pong.swap(false, Ordering::Relaxed)
257    }
258
259    pub fn next_idle_deadline(&self) -> Instant {
260        Instant::now() + self.config.idle_timeout
261    }
262}
263
/// Configuration for WebSocket connections.
///
/// (De)serialized with kebab-case field names, e.g. `ping-interval`.
/// Every field is optional when deserializing and falls back to the default
/// documented on the field.
// NOTE(review): the ordering constraint between `ping_interval` and
// `idle_timeout` documented below is not checked anywhere in this type's
// definition -- confirm whether it is validated elsewhere.
#[derive(Clone, Copy, Debug, PartialEq, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "kebab-case")]
pub struct WebSocketOptions {
    /// Interval at which to send `Ping` frames.
    ///
    /// We use pings for connection keep-alive.
    /// Value must be smaller than `idle_timeout`.
    ///
    /// Default: 15s
    #[serde(with = "humantime_duration")]
    #[serde(default = "WebSocketOptions::default_ping_interval")]
    pub ping_interval: Duration,
    /// Amount of time after which an idle connection is closed.
    ///
    /// A connection is considered idle if no data is received nor sent.
    /// This includes `Ping`/`Pong` frames used for keep-alive.
    ///
    /// Value must be greater than `ping_interval`.
    ///
    /// Default: 30s
    #[serde(with = "humantime_duration")]
    #[serde(default = "WebSocketOptions::default_idle_timeout")]
    pub idle_timeout: Duration,
    /// For how long to keep draining the incoming messages until a client close
    /// is received.
    ///
    /// Default: 250ms
    #[serde(with = "humantime_duration")]
    #[serde(default = "WebSocketOptions::default_close_handshake_timeout")]
    pub close_handshake_timeout: Duration,
    /// Maximum number of messages to queue for processing.
    ///
    /// If this number is exceeded, the client is disconnected.
    ///
    /// Default: 2048
    #[serde(default = "WebSocketOptions::default_incoming_queue_length")]
    pub incoming_queue_length: NonZeroUsize,
}
303
304impl Default for WebSocketOptions {
305    fn default() -> Self {
306        Self::DEFAULT
307    }
308}
309
310impl WebSocketOptions {
311    const DEFAULT_PING_INTERVAL: Duration = Duration::from_secs(15);
312    const DEFAULT_IDLE_TIMEOUT: Duration = Duration::from_secs(30);
313    const DEFAULT_CLOSE_HANDSHAKE_TIMEOUT: Duration = Duration::from_millis(250);
314    const DEFAULT_INCOMING_QUEUE_LENGTH: NonZeroUsize = NonZeroUsize::new(2048).expect("2048 > 0, qed");
315
316    const DEFAULT: Self = Self {
317        ping_interval: Self::DEFAULT_PING_INTERVAL,
318        idle_timeout: Self::DEFAULT_IDLE_TIMEOUT,
319        close_handshake_timeout: Self::DEFAULT_CLOSE_HANDSHAKE_TIMEOUT,
320        incoming_queue_length: Self::DEFAULT_INCOMING_QUEUE_LENGTH,
321    };
322
323    const fn default_ping_interval() -> Duration {
324        Self::DEFAULT_PING_INTERVAL
325    }
326
327    const fn default_idle_timeout() -> Duration {
328        Self::DEFAULT_IDLE_TIMEOUT
329    }
330
331    const fn default_close_handshake_timeout() -> Duration {
332        Self::DEFAULT_CLOSE_HANDSHAKE_TIMEOUT
333    }
334
335    const fn default_incoming_queue_length() -> NonZeroUsize {
336        Self::DEFAULT_INCOMING_QUEUE_LENGTH
337    }
338}
339
340async fn ws_client_actor(
341    options: WebSocketOptions,
342    client: ClientConnection,
343    ws: WebSocketStream,
344    sendrx: MeteredReceiver<SerializableMessage>,
345) {
346    // ensure that even if this task gets cancelled, we always cleanup the connection
347    let mut client = scopeguard::guard(client, |client| {
348        tokio::spawn(client.disconnect());
349    });
350
351    ws_client_actor_inner(&mut client, options, ws, sendrx).await;
352
353    ScopeGuard::into_inner(client).disconnect().await;
354}
355
356async fn ws_client_actor_inner(
357    client: &mut ClientConnection,
358    config: WebSocketOptions,
359    ws: WebSocketStream,
360    sendrx: MeteredReceiver<SerializableMessage>,
361) {
362    let database = client.module.info().database_identity;
363    let client_id = client.id;
364    let client_closed_metric = WORKER_METRICS.ws_clients_closed_connection.with_label_values(&database);
365    let state = Arc::new(ActorState::new(database, client_id, config));
366
367    // Channel for [`UnorderedWsMessage`]s.
368    let (unordered_tx, unordered_rx) = mpsc::unbounded_channel();
369
370    // Split websocket into send and receive halves.
371    let (ws_send, ws_recv) = ws.split();
372
373    // Set up the idle timer.
374    let (idle_tx, idle_rx) = watch::channel(state.next_idle_deadline());
375    let idle_timer = ws_idle_timer(idle_rx);
376
377    // Spawn a task to send outgoing messages
378    // obtained from `sendrx` and `unordered_rx`.
379    let send_task = tokio::spawn(ws_send_loop(
380        state.clone(),
381        client.config,
382        ws_send,
383        sendrx,
384        unordered_rx,
385    ));
386    // Spawn a task to handle incoming messages.
387    let recv_task = tokio::spawn(ws_recv_task(
388        state.clone(),
389        idle_tx,
390        client_closed_metric,
391        {
392            let client = client.clone();
393            move |data, timer| {
394                let client = client.clone();
395                async move { client.handle_message(data, timer).await }
396            }
397        },
398        unordered_tx.clone(),
399        ws_recv,
400    ));
401    let hotswap = {
402        let client = client.clone();
403        move || {
404            let mut client = client.clone();
405            async move { client.watch_module_host().await }
406        }
407    };
408
409    ws_main_loop(state, hotswap, idle_timer, send_task, recv_task, move |msg| {
410        let _ = unordered_tx.send(msg);
411    })
412    .await;
413    log::info!("Client connection ended: {client_id}");
414}
415
/// The main `select!` loop of the websocket client actor.
///
/// > This function is defined standalone with generic parameters so that its
/// > behavior can be tested in isolation, not requiring I/O and allowing to
/// > mock effects easily.
///
/// The loop's responsibilities are:
///
/// - Drive the tasks handling the send and receive ends of the websockets to
///   completion, terminating when either of them completes.
///
/// - Terminating if the connection is idle for longer than [`WebSocketOptions::idle_timeout`].
///   The connection becomes idle if nothing is received from the socket.
///
/// - Periodically sending `Ping` frames to prevent the connection from becoming
///   idle (the client is supposed to respond with `Pong`, which resets the
///   idle timer). See [`WebSocketOptions::ping_interval`].
///
/// - Watch for changes to the [`ClientConnection`]'s module reference.
///   If it changes, the [`ClientConnection`] "hotswaps" the module, if it
///   is exited, the loop schedules a `Close` frame to be sent, initiating a
///   connection shutdown.
///
/// A peculiarity of handling termination is the websocket [close handshake]:
/// whichever side wants to close the connection sends a `Close` frame and needs
/// to wait for the other end to respond with a `Close` for the connection to
/// end cleanly.
///
/// `tungstenite` handles the protocol details of the close handshake for us,
/// but for it to work properly, we must keep polling the socket until the
/// handshake is complete.
///
/// This is straightforward when the client initiates the close, as the receive
/// stream will just become exhausted, and we'll exit the loop.
///
/// In the case of a server-initiated close, it's a bit more tricky, as we're
/// not supposed to send any more data after a `Close` frame (and `tungstenite`
/// prevents it). Yet, we need to keep polling the receive end until either
/// the `Close` response (which could be queued behind a large number of
/// outstanding messages) arrives, or a timeout elapses (in case the client
/// never responds).
///
/// The implementations [`ws_recv_loop`] and [`ws_send_loop`] thus share the
/// [`ActorState`], which tracks whether the connection is in the closing phase
/// ([`ActorState::closed()`]). If closed, both the send and receive loops keep
/// running, but drop any incoming or outgoing messages respectively until
/// either the `Close` response arrives or [`WebSocketOptions::close_handshake_timeout`]
/// elapses.
///
///
/// Parameters:
///
/// * **state**:
///   The shared [`ActorState`]. The loop reads [`ActorState::closed()`] at the
///   start of every iteration to decide whether the hotswap and ping branches
///   are enabled, and calls [`ActorState::reset_ponged()`] on every ping tick
///   to find out whether the previous ping was answered.
///
/// * **hotswap**:
///   An abstraction for [`ClientConnection::watch_module_host`], which updates
///   the connection's internal reference to the module if it was updated,
///   allowing database updates without disconnecting clients.
///
///   It is polled here for its error return value: if the output of the future
///   is `Err(NoSuchModule)`, the database was shut down and existing clients
///   must be disconnected.
///
/// * **idle_timer**:
///   Abstraction for [`ws_idle_timer`]: if and when the future completes, the
///   connection is considered unresponsive, and the connection is closed.
///
///   The idle timer should be reset whenever data is received from the websocket.
///
/// * **send_task**:
///   Task handling outgoing messages. Holds the receive end of `unordered_tx`.
///
///   If the task returns, the connection is considered bad, and the main loop
///   exits. If the task panicked, the panic is resumed on the current thread.
///
///   Note that the send task must not terminate after it has sent a `Close`
///   frame (via `unordered_tx`) -- the websocket protocol mandates that the
///   initiator of the close handshake wait for the other end to respond with
///   a `Close` frame. Thus, the loop must continue to poll `recv_task` and not
///   exit due to `send_task` being complete.
///
///   See [`ws_send_loop`].
///
/// * **recv_task**:
///   Task handling incoming messages.
///
///   If the task returns, the connection is considered closed, and the main
///   loop exits. If the task panicked, the panic is resumed on the current
///   thread.
///
///   See [`ws_recv_task`].
///
/// * **unordered_tx**:
///   Channel connected to `send_task` that allows the loop to send `Ping` and
///   `Close` frames.
///
///   Note that messages sent while the receiving `send_task` is already
///   terminated are silently ignored. This is safe because the loop will exit
///   anyway when the `send_task` is complete.
///
///
/// [close handshake]: https://datatracker.ietf.org/doc/html/rfc6455#section-7
async fn ws_main_loop<HotswapWatcher>(
    state: Arc<ActorState>,
    hotswap: impl Fn() -> HotswapWatcher,
    idle_timer: impl Future<Output = ()>,
    mut send_task: JoinHandle<()>,
    mut recv_task: JoinHandle<()>,
    unordered_tx: impl Fn(UnorderedWsMessage),
) where
    HotswapWatcher: Future<Output = Result<(), NoSuchModule>>,
{
    // Ensure we terminate both tasks if either exits.
    // The `defer!` block runs when this function returns or unwinds.
    let abort_send = send_task.abort_handle();
    let abort_recv = recv_task.abort_handle();
    defer! {
        abort_send.abort();
        abort_recv.abort();
    };
    // Set up the ping interval.
    let mut ping_interval = tokio::time::interval(state.config.ping_interval);
    // Arm the first hotswap watcher.
    let watch_hotswap = hotswap();

    pin_mut!(watch_hotswap);
    pin_mut!(idle_timer);

    loop {
        // Sampled once per iteration; gates the hotswap and ping branches below.
        let closed = state.closed();

        tokio::select! {
            // Drive send and receive tasks to completion,
            // propagating panics.
            //
            // If either task completes,
            // the connection is considered closed and we break the loop.
            //
            // NOTE: We don't abort the tasks until this function returns,
            // so the `Err` can't contain an `is_cancelled()` value.
            //
            // Even if the tasks were cancelled (e.g. if the caller retains
            // [`tokio::task::AbortHandle`]s), the reasonable thing to do is to
            // exit the loop as if the tasks completed normally.
            res = &mut send_task => {
                if let Err(e) = res {
                    if e.is_panic() {
                        panic::resume_unwind(e.into_panic())
                    }
                }
                break;
            },
            res = &mut recv_task => {
                if let Err(e) = res {
                    if e.is_panic() {
                        panic::resume_unwind(e.into_panic())
                    }
                }
                break;
            },

            // Exit if we haven't heard from the client for too long.
            _ = &mut idle_timer => {
                log::warn!("Client {} timed out", state.client_id);
                break;
            },

            // Update the client's module host if it was hotswapped,
            // or close the session if the module exited.
            //
            // Branch is disabled if we already sent a close frame.
            res = &mut watch_hotswap, if !closed => {
                if let Err(NoSuchModule) = res {
                    let close = CloseFrame {
                        code: CloseCode::Away,
                        reason: "module exited".into()
                    };
                    unordered_tx(close.into());
                }
                // The watcher future has completed; arm a fresh one.
                watch_hotswap.set(hotswap());
            },

            // Send ping.
            //
            // If we didn't receive a response to the last ping,
            // we don't bother sending a fresh one.
            //
            // Either the connection is idle (in which case the timer will kick
            // in), or there is a massive backlog to process until the pong
            // appears on the ordered stream. In either case, adding more pings
            // is of no value.
            //
            // Branch is disabled if we already sent a close frame.
            _ = ping_interval.tick(), if !closed => {
                let was_ponged = state.reset_ponged();
                if was_ponged {
                    unordered_tx(UnorderedWsMessage::Ping(Bytes::new()));
                }
            }
        }
    }
}
618
619/// A sleep that can be extended by sending it new deadlines.
620///
621/// Sleeps until the deadline appearing on the `activity` channel,
622/// i.e. if a new deadline appears before the sleep finishes,
623/// the sleep is reset to the new deadline.
624///
625/// The `activity` should be updated whenever a new message is received.
626async fn ws_idle_timer(mut activity: watch::Receiver<Instant>) {
627    let mut deadline = *activity.borrow();
628    let sleep = sleep_until(deadline.into());
629    pin_mut!(sleep);
630
631    loop {
632        tokio::select! {
633            biased;
634
635            Ok(()) = activity.changed() => {
636                let new_deadline = *activity.borrow_and_update();
637                if new_deadline != deadline {
638                    deadline = new_deadline;
639                    sleep.as_mut().reset(deadline.into());
640                }
641            },
642
643            () = &mut sleep => {
644                break;
645            },
646        }
647    }
648}
649
650/// Consumes `ws` by composing [`ws_recv_queue`], [`ws_recv_loop`],
651/// [`ws_client_message_handler`] and `message_handler`.
652///
653/// `idle_tx` is the sending end of a [`ws_idle_timer`]. The [`ws_recv_loop`]
654/// sends a new, extended deadline whenever it receives a message.
655///
656/// `unordered_tx` is used to send message execution errors
657/// or to initiate a close handshake.
658///
659/// Initiates a close handshake if the `message_handler` returns any variant
660/// of [`MessageHandleError`] that is **not** [`MessageHandleError::Execution`].
661///
662/// Terminates if:
663///
664/// - the `ws` stream is exhausted
665/// - or, `unordered_tx` is already closed
666///
667/// In the latter case, we assume that the connection is in an errored state,
668/// such that we wouldn't be able to receive any more messages anyway.
669async fn ws_recv_task<MessageHandler>(
670    state: Arc<ActorState>,
671    idle_tx: watch::Sender<Instant>,
672    client_closed_metric: IntGauge,
673    message_handler: impl Fn(DataMessage, Instant) -> MessageHandler,
674    unordered_tx: mpsc::UnboundedSender<UnorderedWsMessage>,
675    ws: impl Stream<Item = Result<WsMessage, WsError>> + Unpin + Send + 'static,
676) where
677    MessageHandler: Future<Output = Result<(), MessageHandleError>>,
678{
679    let recv_queue = ws_recv_queue(state.clone(), unordered_tx.clone(), ws);
680    let recv_loop = pin!(ws_recv_loop(state.clone(), idle_tx, recv_queue));
681    let recv_handler = ws_client_message_handler(state.clone(), client_closed_metric, recv_loop);
682    pin_mut!(recv_handler);
683
684    while let Some((data, timer)) = recv_handler.next().await {
685        let result = message_handler(data, timer).await;
686        if let Err(e) = result {
687            if let MessageHandleError::Execution(err) = e {
688                log::error!("{err:#}");
689                // If the send task has exited, also exit this recv task.
690                if unordered_tx.send(err.into()).is_err() {
691                    break;
692                }
693                continue;
694            }
695            log::debug!("Client caused error: {e}");
696            let close = CloseFrame {
697                code: CloseCode::Error,
698                reason: format!("{e:#}").into(),
699            };
700            // If the send task has exited, also exit this recv task.
701            // No need to send the close handshake in that case; the client is already gone.
702            if unordered_tx.send(close.into()).is_err() {
703                break;
704            };
705        }
706    }
707}
708
709/// Stream that consumes a stream of [`WsMessage`]s and yields [`ClientMessage`]s.
710///
711/// Terminates if:
712///
713/// - the input stream is exhausted
714/// - the input stream yields an error
715///
716/// If `state.closed`, continues to poll the input stream in order for the
717/// websocket close handshake to complete. Any messages received while in this
718/// state are dropped.
719fn ws_recv_loop(
720    state: Arc<ActorState>,
721    idle_tx: watch::Sender<Instant>,
722    mut ws: impl Stream<Item = Result<WsMessage, WsError>> + Unpin,
723) -> impl Stream<Item = ClientMessage> {
724    // Get the next message from `ws`, or `None` if the stream is exhausted.
725    //
726    // If `state.closed`, `ws` is drained until it either yields an `Err`, is
727    // exhausted, or a timeout of 250ms has elapsed.
728    async fn next_message(
729        state: &ActorState,
730        ws: &mut (impl Stream<Item = Result<WsMessage, WsError>> + Unpin),
731    ) -> Option<Result<WsMessage, WsError>> {
732        if state.closed() {
733            log::trace!("drain websocket waiting for client close");
734            let res: Result<Option<Result<WsMessage, WsError>>, Elapsed> =
735                timeout(state.config.close_handshake_timeout, async {
736                    while let Some(item) = ws.next().await {
737                        match item {
738                            Ok(message) => drop(message),
739                            Err(e) => return Some(Err(e)),
740                        }
741                    }
742                    None
743                })
744                .await;
745            match res {
746                Err(_elapsed) => {
747                    log::warn!("timeout waiting for client close");
748                    None
749                }
750                Ok(item) => item, // either error or `None`
751            }
752        } else {
753            log::trace!("await next client message without timeout");
754            ws.next().await
755        }
756    }
757
758    stream! {
759        loop {
760            let Some(res) = next_message(&state, &mut ws).await else {
761                log::trace!("recv stream exhausted");
762                break;
763            };
764            match res {
765                Ok(m) => {
766                    idle_tx.send(state.next_idle_deadline()).ok();
767
768                    if !state.closed() {
769                        yield ClientMessage::from_message(m);
770                    }
771                    // If closed, keep polling until either:
772                    //
773                    // - the client sends a close frame (`ws` returns `None)
774                    // - or `ws` yields an error
775                    log::trace!("message received while already closed");
776                }
777                // None of the error cases can be meaningfully recovered from
778                // (and some can't even occur on the `ws` stream).
779                // Exit here but spell out an exhaustive match
780                // in order to bring any future library changes to our attention.
781                Err(e) => match e {
782                    e @ (WsError::ConnectionClosed
783                    | WsError::AlreadyClosed
784                    | WsError::Io(_)
785                    | WsError::Tls(_)
786                    | WsError::Capacity(_)
787                    | WsError::Protocol(_)
788                    | WsError::WriteBufferFull(_)
789                    | WsError::Utf8(_)
790                    | WsError::AttackAttempt
791                    | WsError::Url(_)
792                    | WsError::Http(_)
793                    | WsError::HttpFormat(_)) => {
794                        log::warn!("Websocket receive error: {e}");
795                        break;
796                    }
797                },
798            }
799        }
800    }
801}
802
803/// Consumes `ws` and queues its items in a channel.
804///
805/// The channel is initialized with [`ActorConfig::incoming_queue_length`].
806/// If it is at capacity, a connection shutdown is initiated by sending
807/// [`UnorderedWsMessage::Close`] via `unordered_tx`.
808///
809/// Returns the channel receiver.
810///
811/// NOTE: This function is provided for backwards-compatibility, in particular
812/// SDK clients not handling backpressure gracefully, and for observability of
813/// transaction backlogging. It will probably go away in the future, see [#1851].
814///
815/// [#1851]: https://github.com/clockworklabs/SpacetimeDBPrivate/issues/1851
816fn ws_recv_queue(
817    state: Arc<ActorState>,
818    unordered_tx: mpsc::UnboundedSender<UnorderedWsMessage>,
819    mut ws: impl Stream<Item = Result<WsMessage, WsError>> + Unpin + Send + 'static,
820) -> impl Stream<Item = Result<WsMessage, WsError>> {
821    const CLOSE: UnorderedWsMessage = UnorderedWsMessage::Close(CloseFrame {
822        code: CloseCode::Again,
823        reason: Utf8Bytes::from_static("too many requests"),
824    });
825    let on_message_after_close = move |client_id| {
826        log::warn!("client {client_id} sent message after close or error");
827    };
828
829    let (tx, rx) = mpsc::channel(state.config.incoming_queue_length.get());
830    let rx = MeteredReceiverStream {
831        inner: MeteredReceiver::with_gauge(
832            rx,
833            WORKER_METRICS
834                .total_incoming_queue_length
835                .with_label_values(&state.database),
836        ),
837    };
838
839    tokio::spawn(async move {
840        while let Some(item) = ws.next().await {
841            if let Err(e) = tx.try_send(item) {
842                match e {
843                    // If the queue is full, disconnect the client.
844                    mpsc::error::TrySendError::Full(item) => {
845                        // If we can't send close (send task already terminated):
846                        //
847                        // - Let downstream handlers know that we're closing,
848                        //   so that remaining items in the queue are dropped.
849                        //
850                        // - Then exit the loop, as we won't be processing any
851                        //   more messages, and we don't expect a close response
852                        //   to arrive from the client.
853                        if unordered_tx.send(CLOSE).is_err() {
854                            state.close();
855                            break;
856                        }
857                        // If we successfully enqueued `CLOSE`, enqueue `item`
858                        // as well, as soon as there is space in the channel.
859                        //
860                        // This is to allow the client to complete the close
861                        // handshake, for which the downstream handler needs to
862                        // drain the queue.
863                        //
864                        // If `tx.send` fails, the pipeline is broken, so exit.
865                        // See commentary on the `TrySendError::Closed` match
866                        // arm below.
867                        if tx.send(item).await.is_err() {
868                            on_message_after_close(state.client_id);
869                            break;
870                        }
871                    }
872                    // If the downstream consumer went away,
873                    // it has consumed a `Close` frame or `Err` value
874                    // from the queue and thus has determined that it's done.
875                    //
876                    // Well-behaved clients shouldn't send anything after
877                    // closing, so issue a warning.
878                    //
879                    // We're done either way, so break.
880                    mpsc::error::TrySendError::Closed(_item) => {
881                        on_message_after_close(state.client_id);
882                        break;
883                    }
884                }
885            }
886        }
887    });
888
889    rx
890}
891
892/// Turns a [`MeteredReceiver`] into a [`Stream`],
893/// like [`tokio_stream::wrappers::ReceiverStream`] does for [`mpsc::Receiver`].
894struct MeteredReceiverStream<T> {
895    inner: MeteredReceiver<T>,
896}
897
898impl<T> Stream for MeteredReceiverStream<T> {
899    type Item = T;
900
901    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
902        self.inner.poll_recv(cx)
903    }
904}
905
906/// Stream that consumes [`ClientMessage`]s and yields [`DataMessage`]s for
907/// evaluation.
908///
909/// Calls `state.set_ponged()` if and when the input yields a pong message.
910/// Calls `state.close()` if and when the input yields a close frame,
911/// i.e. the client initiated a close handshake, which we track using the
912/// `client_closed_metric`.
913///
914/// Terminates if and when the input stream terminates.
915fn ws_client_message_handler(
916    state: Arc<ActorState>,
917    client_closed_metric: IntGauge,
918    mut messages: impl Stream<Item = ClientMessage> + Unpin,
919) -> impl Stream<Item = (DataMessage, Instant)> {
920    stream! {
921        while let Some(message) = messages.next().await {
922            match message {
923                ClientMessage::Message(message) => {
924                    log::trace!("Received client message");
925                    yield (message, Instant::now());
926                },
927                ClientMessage::Ping(_bytes) => {
928                    log::trace!("Received ping from client {}", state.client_id);
929                    // `tungstenite` will respond with `Pong` for us,
930                    // no need to send it ourselves.
931                },
932                ClientMessage::Pong(_bytes) => {
933                    log::trace!("Received pong from client {}", state.client_id);
934                    state.set_ponged();
935                },
936                ClientMessage::Close(close_frame) => {
937                    log::trace!("Received Close frame from client {}: {:?}", state.client_id, close_frame);
938                    let was_closed = state.close();
939                    // This is the client telling us they want to close.
940                    if !was_closed {
941                        client_closed_metric.inc();
942                    }
943                }
944            }
945        }
946        log::trace!("client message handler done");
947    }
948}
949
/// Outgoing messages that don't need to be ordered wrt subscription updates.
///
/// These are delivered to `ws_send_loop` through an unbounded channel, which
/// the loop polls with priority over the ordered message queue.
/// Once the connection state is closed, the send loop still receives these
/// messages but discards them.
#[derive(Debug, From)]
enum UnorderedWsMessage {
    /// Server-initiated close.
    ///
    /// The send loop writes the frame to the socket and, if that succeeds,
    /// marks the connection state as closed.
    Close(CloseFrame),
    /// Server-initiated ping.
    Ping(Bytes),
    /// Error calling a reducer.
    ///
    /// The error indicates that the reducer was **not** called,
    /// and can thus be unordered wrt subscription updates.
    Error(MessageExecutionError),
}
963
964/// Sink that sends outgoing messages to the `ws` sink.
965///
966/// Consumes `messages`, which yields subscription updates and reducer call
967/// results. Note that [`SerializableMessage`]s require serialization and
968/// potentially compression, which can be costly.
969/// Also consumes `unordered`, which yields [`UnorderedWsMessage`]s.
970///
971/// Terminates if:
972///
973/// - `unordered` is closed
974/// - an error occurs sending to the `ws` sink
975///
976/// If an [`UnorderedWsMessage::Close`] is encountered, a close frame is sent
977/// to the `ws` sink, and `state.close()` is called. When this happens,
978/// `messages` will no longer be polled (no data can be sent after a close
979/// frame anyways), so `messages.close()` will be called.
980///
981/// Keeps polling `unordered` if `state.closed()`, but discards all data.
982/// This is so `ws_client_actor_inner` keeps polling the receive end of the
983/// socket until the close handshake completes -- it would otherwise exit early
984/// when sending to `unordered` fails.
985async fn ws_send_loop(
986    state: Arc<ActorState>,
987    config: ClientConfig,
988    mut ws: impl Sink<WsMessage, Error: Display> + Unpin,
989    mut messages: MeteredReceiver<SerializableMessage>,
990    mut unordered: mpsc::UnboundedReceiver<UnorderedWsMessage>,
991) {
992    let mut messages_buf = Vec::with_capacity(32);
993    let mut serialize_buf = SerializeBuffer::new(config);
994
995    loop {
996        let closed = state.closed();
997
998        tokio::select! {
999            // `biased` towards the unordered queue,
1000            // which may initiate a connection shutdown.
1001            biased;
1002
1003            maybe_msg = unordered.recv() => {
1004                let Some(msg) = maybe_msg else {
1005                    break;
1006                };
1007                // We shall not sent more data after a close frame,
1008                // but keep polling `unordered` so that `ws_client_actor` keeps
1009                // waiting for an acknowledgement from the client,
1010                // even if it spuriously initiates another close itself.
1011                if closed {
1012                    continue;
1013                }
1014                match msg {
1015                    UnorderedWsMessage::Close(close_frame) => {
1016                        log::trace!("sending close frame");
1017                        if let Err(e) = ws.send(WsMessage::Close(Some(close_frame))).await {
1018                            log::warn!("error sending close frame: {e:#}");
1019                            break;
1020                        }
1021                        // NOTE: It's ok to not update the state if we fail to
1022                        // send the close frame, because we assume that the main
1023                        // loop will exit when this future terminates.
1024                        state.close();
1025                        // We won't be polling `messages` anymore,
1026                        // so let senders know.
1027                        messages.close();
1028                    },
1029                    UnorderedWsMessage::Ping(bytes) => {
1030                        log::trace!("sending ping");
1031                        if let Err(e) = ws.feed(WsMessage::Ping(bytes)).await {
1032                            log::warn!("error sending ping: {e:#}");
1033                            break;
1034                        }
1035                    },
1036                    UnorderedWsMessage::Error(err) => {
1037                        log::trace!("sending error result");
1038                        let (msg_alloc, res) = send_message(
1039                            &state.database,
1040                            config,
1041                            serialize_buf,
1042                            None,
1043                            &mut ws,
1044                            err
1045                        ).await;
1046                        serialize_buf = msg_alloc;
1047
1048                        if let Err(e) = res {
1049                            log::warn!("websocket send error: {e}");
1050                            break;
1051                        }
1052                    },
1053                }
1054            },
1055
1056            n = messages.recv_many(&mut messages_buf, 32), if !closed => {
1057                if n == 0 {
1058                    continue;
1059                }
1060                log::trace!("sending {n} outgoing messages");
1061                for msg in messages_buf.drain(..n) {
1062                    let (msg_alloc, res) = send_message(
1063                        &state.database,
1064                        config,
1065                        serialize_buf,
1066                        msg.workload().zip(msg.num_rows()),
1067                        &mut ws,
1068                        msg
1069                    ).await;
1070                    serialize_buf = msg_alloc;
1071
1072                    if let Err(e) = res {
1073                        log::warn!("websocket send error: {e}");
1074                        return;
1075                    }
1076                }
1077            },
1078        }
1079
1080        if let Err(e) = ws.flush().await {
1081            log::warn!("error flushing websocket: {e}");
1082            break;
1083        }
1084    }
1085}
1086
1087/// Serialize and potentially compress `message`, and feed it to the `ws` sink.
1088async fn send_message<S: Sink<WsMessage> + Unpin>(
1089    database_identity: &Identity,
1090    config: ClientConfig,
1091    serialize_buf: SerializeBuffer,
1092    metrics_metadata: Option<(WorkloadType, usize)>,
1093    ws: &mut S,
1094    message: impl ToProtocol<Encoded = SwitchedServerMessage> + Send + 'static,
1095) -> (SerializeBuffer, Result<(), S::Error>) {
1096    let (workload, num_rows) = metrics_metadata.unzip();
1097    // Move large messages to a rayon thread,
1098    // as serialization and compression can take a long time.
1099    // The threshold of 1024 rows is arbitrary, and may need to be refined.
1100    let serialize_and_compress = |serialize_buf, message, config| {
1101        let start = Instant::now();
1102        let (msg_alloc, msg_data) = serialize(serialize_buf, message, config);
1103        (start.elapsed(), msg_alloc, msg_data)
1104    };
1105    let (timing, msg_alloc, msg_data) = if num_rows.is_some_and(|n| n > 1024) {
1106        spawn_rayon(move || serialize_and_compress(serialize_buf, message, config)).await
1107    } else {
1108        serialize_and_compress(serialize_buf, message, config)
1109    };
1110    report_ws_sent_metrics(database_identity, workload, num_rows, timing, &msg_data);
1111
1112    let res = async {
1113        ws.feed(datamsg_to_wsmsg(msg_data)).await?;
1114        // To reclaim the `msg_alloc` memory, we need `SplitSink` to push down
1115        // its item slot to the inner sink, which will copy the `Bytes` and
1116        // drop the reference.
1117        // We don't want to flush the inner sink just yet, as we might be
1118        // writing many messages.
1119        // `SplitSink::poll_ready` does what we want.
1120        poll_fn(|cx| ws.poll_ready_unpin(cx)).await
1121    }
1122    .await;
1123    // Reclaim can fail if we didn't succeed pushing down the data to the
1124    // websocket. We must return a buffer, though, so create a fresh one.
1125    let buf = msg_alloc.try_reclaim().unwrap_or_else(|| SerializeBuffer::new(config));
1126
1127    (buf, res)
1128}
1129
/// A message received from the client's websocket,
/// classified from a [`WsMessage`] by [`ClientMessage::from_message`].
#[derive(Debug)]
enum ClientMessage {
    /// The payload of a text or binary frame.
    Message(DataMessage),
    /// The payload of a ping frame.
    Ping(Bytes),
    /// The payload of a pong frame.
    Pong(Bytes),
    /// A close frame, carrying the optional close code and reason.
    Close(Option<CloseFrame>),
}
1137
1138impl ClientMessage {
1139    fn from_message(msg: WsMessage) -> Self {
1140        match msg {
1141            WsMessage::Text(s) => Self::Message(DataMessage::Text(utf8bytes_to_bytestring(s))),
1142            WsMessage::Binary(b) => Self::Message(DataMessage::Binary(b)),
1143            WsMessage::Ping(b) => Self::Ping(b),
1144            WsMessage::Pong(b) => Self::Pong(b),
1145            WsMessage::Close(frame) => Self::Close(frame),
1146            // WebSocket::read_message() never returns a raw Message::Frame
1147            WsMessage::Frame(_) => unreachable!(),
1148        }
1149    }
1150}
1151
1152/// Report metrics on sent rows and message sizes to a websocket client.
1153fn report_ws_sent_metrics(
1154    addr: &Identity,
1155    workload: Option<WorkloadType>,
1156    num_rows: Option<usize>,
1157    serialize_duration: Duration,
1158    msg_ws: &DataMessage,
1159) {
1160    // These metrics should be updated together,
1161    // or not at all.
1162    if let (Some(workload), Some(num_rows)) = (workload, num_rows) {
1163        WORKER_METRICS
1164            .websocket_sent_num_rows
1165            .with_label_values(addr, &workload)
1166            .observe(num_rows as f64);
1167        WORKER_METRICS
1168            .websocket_sent_msg_size
1169            .with_label_values(addr, &workload)
1170            .observe(msg_ws.len() as f64);
1171    }
1172
1173    WORKER_METRICS
1174        .websocket_serialize_secs
1175        .with_label_values(addr)
1176        .observe(serialize_duration.as_secs_f64());
1177}
1178
1179fn datamsg_to_wsmsg(msg: DataMessage) -> WsMessage {
1180    match msg {
1181        DataMessage::Text(text) => WsMessage::Text(bytestring_to_utf8bytes(text)),
1182        DataMessage::Binary(bin) => WsMessage::Binary(bin),
1183    }
1184}
1185
1186fn utf8bytes_to_bytestring(s: Utf8Bytes) -> ByteString {
1187    // SAFETY: `Utf8Bytes` and `ByteString` have the same invariant of UTF-8 validity
1188    unsafe { ByteString::from_bytes_unchecked(Bytes::from(s)) }
1189}
1190fn bytestring_to_utf8bytes(s: ByteString) -> Utf8Bytes {
1191    // SAFETY: `Utf8Bytes` and `ByteString` have the same invariant of UTF-8 validity
1192    unsafe { Utf8Bytes::from_bytes_unchecked(s.into_bytes()) }
1193}
1194
1195#[cfg(test)]
1196mod tests {
1197    use std::{
1198        future::Future,
1199        pin::Pin,
1200        sync::atomic::AtomicUsize,
1201        task::{Context, Poll},
1202    };
1203
1204    use anyhow::anyhow;
1205    use futures::{
1206        future::{self, Either, FutureExt as _},
1207        sink, stream,
1208    };
1209    use pretty_assertions::assert_matches;
1210    use spacetimedb::client::ClientName;
1211    use tokio::time::sleep;
1212
1213    use super::*;
1214
1215    fn dummy_client_id() -> ClientActorId {
1216        ClientActorId {
1217            identity: Identity::ZERO,
1218            connection_id: ConnectionId::ZERO,
1219            name: ClientName(0),
1220        }
1221    }
1222
1223    fn dummy_actor_state() -> ActorState {
1224        dummy_actor_state_with_config(<_>::default())
1225    }
1226
1227    fn dummy_actor_state_with_config(config: WebSocketOptions) -> ActorState {
1228        ActorState::new(Identity::ZERO, dummy_client_id(), config)
1229    }
1230
1231    #[tokio::test]
1232    async fn idle_timer_extends_sleep() {
1233        let timeout = Duration::from_millis(10);
1234
1235        let start = Instant::now();
1236        let (tx, rx) = watch::channel(start + timeout);
1237        tokio::join!(ws_idle_timer(rx), async {
1238            for _ in 0..5 {
1239                sleep(Duration::from_millis(1)).await;
1240                tx.send(Instant::now() + timeout).unwrap();
1241            }
1242        });
1243        let elapsed = start.elapsed();
1244        let expected = timeout + Duration::from_millis(5);
1245        assert!(
1246            elapsed >= expected,
1247            "{}ms elapsed, expected >= {}ms",
1248            elapsed.as_millis(),
1249            expected.as_millis(),
1250        );
1251    }
1252
1253    #[tokio::test]
1254    async fn recv_loop_terminates_when_input_exhausted() {
1255        let state = Arc::new(dummy_actor_state());
1256        let (idle_tx, _idle_rx) = watch::channel(Instant::now() + state.config.idle_timeout);
1257
1258        let input = stream::iter(vec![Ok(WsMessage::Ping(Bytes::new()))]);
1259        pin_mut!(input);
1260
1261        let recv_loop = ws_recv_loop(state, idle_tx, input);
1262        pin_mut!(recv_loop);
1263
1264        assert_matches!(recv_loop.next().await, Some(ClientMessage::Ping(_)));
1265        assert_matches!(recv_loop.next().await, None);
1266    }
1267
1268    #[tokio::test]
1269    async fn recv_loop_terminates_when_input_yields_err() {
1270        let state = Arc::new(dummy_actor_state());
1271        let (idle_tx, _idle_rx) = watch::channel(Instant::now() + state.config.idle_timeout);
1272
1273        let input = stream::iter(vec![
1274            Ok(WsMessage::Ping(Bytes::new())),
1275            Err(WsError::ConnectionClosed),
1276            Ok(WsMessage::Pong(Bytes::new())),
1277        ]);
1278        pin_mut!(input);
1279
1280        let recv_loop = ws_recv_loop(state, idle_tx, input);
1281        pin_mut!(recv_loop);
1282
1283        assert_matches!(recv_loop.next().await, Some(ClientMessage::Ping(_)));
1284        assert_matches!(recv_loop.next().await, None);
1285    }
1286
1287    #[tokio::test]
1288    async fn recv_loop_drains_remaining_messages_when_closed() {
1289        let state = Arc::new(dummy_actor_state());
1290        let (idle_tx, _idle_rx) = watch::channel(Instant::now() + state.config.idle_timeout);
1291
1292        let input = stream::iter(vec![
1293            Ok(WsMessage::Ping(Bytes::new())),
1294            Ok(WsMessage::Pong(Bytes::new())),
1295        ]);
1296        pin_mut!(input);
1297        {
1298            let recv_loop = ws_recv_loop(state.clone(), idle_tx, &mut input);
1299            pin_mut!(recv_loop);
1300
1301            state.close();
1302            assert_matches!(recv_loop.next().await, None);
1303        }
1304        assert_matches!(input.next().await, None);
1305    }
1306
1307    #[tokio::test]
1308    async fn recv_loop_stops_at_error_while_draining() {
1309        let state = Arc::new(dummy_actor_state());
1310        let (idle_tx, _idle_rx) = watch::channel(Instant::now() + state.config.idle_timeout);
1311
1312        let input = stream::iter(vec![
1313            Ok(WsMessage::Ping(Bytes::new())),
1314            Err(WsError::ConnectionClosed),
1315            Ok(WsMessage::Pong(Bytes::new())),
1316        ]);
1317        pin_mut!(input);
1318        {
1319            let recv_loop = ws_recv_loop(state.clone(), idle_tx, &mut input);
1320            pin_mut!(recv_loop);
1321
1322            state.close();
1323            assert_matches!(recv_loop.next().await, None);
1324        }
1325        assert_matches!(input.next().await, Some(Ok(WsMessage::Pong(_))));
1326    }
1327
1328    #[tokio::test]
1329    async fn recv_loop_updates_idle_channel() {
1330        let state = Arc::new(dummy_actor_state());
1331        let idle_deadline = Instant::now() + state.config.idle_timeout;
1332        let (idle_tx, mut idle_rx) = watch::channel(idle_deadline);
1333
1334        let input = stream::iter(vec![
1335            Ok(WsMessage::Ping(Bytes::new())),
1336            Ok(WsMessage::Pong(Bytes::new())),
1337        ]);
1338        let recv_loop = ws_recv_loop(state, idle_tx, input);
1339        pin_mut!(recv_loop);
1340
1341        let mut new_idle_deadline = *idle_rx.borrow();
1342        while let Some(message) = recv_loop.next().await {
1343            drop(message);
1344            assert!(idle_rx.has_changed().unwrap());
1345            new_idle_deadline = *idle_rx.borrow_and_update();
1346        }
1347        assert!(new_idle_deadline > idle_deadline);
1348    }
1349
1350    #[tokio::test]
1351    async fn client_message_handler_terminates_when_input_exhausted() {
1352        let state = Arc::new(dummy_actor_state());
1353        let metric = IntGauge::new("bleep", "unhelpful").unwrap();
1354
1355        let input = stream::iter(vec![
1356            ClientMessage::Ping(Bytes::new()),
1357            ClientMessage::Message(DataMessage::from("hello".to_owned())),
1358        ]);
1359        let handler = ws_client_message_handler(state, metric, input);
1360        pin_mut!(handler);
1361
1362        assert_matches!(
1363            handler.next().await,
1364            Some((DataMessage::Text(data), _instant)) if data == "hello"
1365        );
1366        assert_matches!(handler.next().await, None);
1367    }
1368
1369    #[tokio::test]
1370    async fn client_message_handler_updates_pong_and_closed_states_and_metric() {
1371        let state = Arc::new(dummy_actor_state());
1372        state.reset_ponged();
1373        let metric = IntGauge::new("bleep", "unhelpful").unwrap();
1374
1375        let input = stream::iter(vec![ClientMessage::Pong(Bytes::new()), ClientMessage::Close(None)]);
1376        let handler = ws_client_message_handler(state.clone(), metric.clone(), input);
1377        handler.map(drop).for_each(future::ready).await;
1378
1379        assert!(state.closed());
1380        assert!(state.reset_ponged());
1381        assert_eq!(metric.get(), 1);
1382    }
1383
1384    #[tokio::test]
1385    async fn send_loop_terminates_when_unordered_closed() {
1386        let state = Arc::new(dummy_actor_state());
1387        let (messages_tx, messages_rx) = mpsc::channel(64);
1388        let messages = MeteredReceiver::new(messages_rx);
1389        let (unordered_tx, unordered_rx) = mpsc::unbounded_channel();
1390
1391        let send_loop = ws_send_loop(state, ClientConfig::for_test(), sink::drain(), messages, unordered_rx);
1392        pin_mut!(send_loop);
1393
1394        assert!(is_pending(&mut send_loop).await);
1395        drop(messages_tx);
1396        assert!(is_pending(&mut send_loop).await);
1397
1398        drop(unordered_tx);
1399        send_loop.await;
1400    }
1401
1402    #[tokio::test]
1403    async fn send_loop_close_message_closes_state_and_messages() {
1404        let state = Arc::new(dummy_actor_state());
1405        let (messages_tx, messages_rx) = mpsc::channel(64);
1406        let messages = MeteredReceiver::new(messages_rx);
1407        let (unordered_tx, unordered_rx) = mpsc::unbounded_channel();
1408
1409        let send_loop = ws_send_loop(
1410            state.clone(),
1411            ClientConfig::for_test(),
1412            sink::drain(),
1413            messages,
1414            unordered_rx,
1415        );
1416        pin_mut!(send_loop);
1417
1418        unordered_tx
1419            .send(UnorderedWsMessage::Close(CloseFrame {
1420                code: CloseCode::Away,
1421                reason: "done".into(),
1422            }))
1423            .unwrap();
1424
1425        assert!(is_pending(&mut send_loop).await);
1426        assert!(state.closed());
1427        assert!(messages_tx.is_closed());
1428    }
1429
1430    #[tokio::test]
1431    async fn send_loop_terminates_if_sink_cant_be_fed() {
1432        let input = [
1433            Either::Left(UnorderedWsMessage::Close(CloseFrame {
1434                code: CloseCode::Away,
1435                reason: "bah!".into(),
1436            })),
1437            Either::Left(UnorderedWsMessage::Ping(Bytes::new())),
1438            Either::Left(UnorderedWsMessage::Error(MessageExecutionError {
1439                reducer: None,
1440                reducer_id: None,
1441                caller_identity: Identity::ZERO,
1442                caller_connection_id: None,
1443                err: anyhow!("it did not work"),
1444            })),
1445            // TODO: This is the easiest to construct,
1446            // but maybe we want other variants, too.
1447            Either::Right(SerializableMessage::Identity(IdentityTokenMessage {
1448                identity: Identity::ZERO,
1449                token: "macaron".into(),
1450                connection_id: ConnectionId::ZERO,
1451            })),
1452        ];
1453
1454        for msg in input {
1455            let state = Arc::new(dummy_actor_state());
1456            let (messages_tx, messages_rx) = mpsc::channel(64);
1457            let messages = MeteredReceiver::new(messages_rx);
1458            let (unordered_tx, unordered_rx) = mpsc::unbounded_channel();
1459
1460            let send_loop = ws_send_loop(
1461                state.clone(),
1462                ClientConfig::for_test(),
1463                UnfeedableSink,
1464                messages,
1465                unordered_rx,
1466            );
1467            pin_mut!(send_loop);
1468
1469            match msg {
1470                Either::Left(unordered) => unordered_tx.send(unordered).unwrap(),
1471                Either::Right(msg) => messages_tx.send(msg).await.unwrap(),
1472            }
1473            send_loop.await;
1474        }
1475    }
1476
1477    #[tokio::test]
1478    async fn send_loop_terminates_if_sink_cant_be_flushed() {
1479        let input = [
1480            Either::Left(UnorderedWsMessage::Close(CloseFrame {
1481                code: CloseCode::Away,
1482                reason: "bah!".into(),
1483            })),
1484            Either::Left(UnorderedWsMessage::Ping(Bytes::new())),
1485            Either::Left(UnorderedWsMessage::Error(MessageExecutionError {
1486                reducer: None,
1487                reducer_id: None,
1488                caller_identity: Identity::ZERO,
1489                caller_connection_id: None,
1490                err: anyhow!("it did not work"),
1491            })),
1492            // TODO: This is the easiest to construct,
1493            // but maybe we want other variants, too.
1494            Either::Right(SerializableMessage::Identity(IdentityTokenMessage {
1495                identity: Identity::ZERO,
1496                token: "macaron".into(),
1497                connection_id: ConnectionId::ZERO,
1498            })),
1499        ];
1500
1501        for msg in input {
1502            let state = Arc::new(dummy_actor_state());
1503            let (messages_tx, messages_rx) = mpsc::channel(64);
1504            let messages = MeteredReceiver::new(messages_rx);
1505            let (unordered_tx, unordered_rx) = mpsc::unbounded_channel();
1506
1507            let send_loop = ws_send_loop(
1508                state.clone(),
1509                ClientConfig::for_test(),
1510                UnflushableSink,
1511                messages,
1512                unordered_rx,
1513            );
1514            pin_mut!(send_loop);
1515
1516            match msg {
1517                Either::Left(unordered) => unordered_tx.send(unordered).unwrap(),
1518                Either::Right(msg) => messages_tx.send(msg).await.unwrap(),
1519            }
1520            send_loop.await;
1521        }
1522    }
1523
1524    #[tokio::test]
1525    async fn main_loop_terminates_if_either_send_or_recv_terminates() {
1526        let state = Arc::new(dummy_actor_state());
1527        ws_main_loop(
1528            state.clone(),
1529            future::pending,
1530            future::pending(),
1531            tokio::spawn(sleep(Duration::from_millis(10))),
1532            tokio::spawn(future::pending()),
1533            drop,
1534        )
1535        .await;
1536        ws_main_loop(
1537            state,
1538            future::pending,
1539            future::pending(),
1540            tokio::spawn(future::pending()),
1541            tokio::spawn(sleep(Duration::from_millis(10))),
1542            drop,
1543        )
1544        .await;
1545    }
1546
1547    #[tokio::test]
1548    async fn main_loop_terminates_on_idle_timeout() {
1549        let state = Arc::new(dummy_actor_state_with_config(WebSocketOptions {
1550            idle_timeout: Duration::from_millis(10),
1551            ..<_>::default()
1552        }));
1553        let (idle_tx, idle_rx) = watch::channel(state.next_idle_deadline());
1554
1555        let start = Instant::now();
1556        let mut t = tokio::spawn({
1557            let state = state.clone();
1558            async move {
1559                ws_main_loop(
1560                    state,
1561                    future::pending,
1562                    ws_idle_timer(idle_rx),
1563                    tokio::spawn(future::pending()),
1564                    tokio::spawn(future::pending()),
1565                    drop,
1566                )
1567                .await
1568            }
1569        });
1570
1571        let loop_start = Instant::now();
1572        for _ in 0..5 {
1573            sleep(Duration::from_millis(5)).await;
1574            idle_tx.send(state.next_idle_deadline()).unwrap();
1575            assert!(is_pending(&mut t).await);
1576        }
1577        let timeout = loop_start.elapsed() + Duration::from_millis(10);
1578
1579        t.await.unwrap();
1580        let elapsed = start.elapsed();
1581        assert!(elapsed >= timeout);
1582        assert!(elapsed < timeout + Duration::from_millis(10));
1583    }
1584
1585    #[tokio::test]
1586    async fn main_loop_keepalive_keeps_alive() {
1587        let state = Arc::new(dummy_actor_state_with_config(WebSocketOptions {
1588            ping_interval: Duration::from_millis(5),
1589            idle_timeout: Duration::from_millis(10),
1590            ..<_>::default()
1591        }));
1592        let (idle_tx, idle_rx) = watch::channel(state.next_idle_deadline());
1593        // Pretend we received a pong immediately after sending a ping,
1594        // but only five times.
1595        let unordered_tx = {
1596            let state = state.clone();
1597            let pings = AtomicUsize::new(0);
1598            move |m| {
1599                if let UnorderedWsMessage::Ping(_) = m {
1600                    let n = pings.fetch_add(1, Ordering::Relaxed);
1601                    if n < 5 {
1602                        state.set_ponged();
1603                        idle_tx.send(state.next_idle_deadline()).ok();
1604                    }
1605                }
1606            }
1607        };
1608
1609        let start = Instant::now();
1610        let t = tokio::spawn({
1611            let state = state.clone();
1612            async move {
1613                ws_main_loop(
1614                    state,
1615                    future::pending,
1616                    ws_idle_timer(idle_rx),
1617                    tokio::spawn(future::pending()),
1618                    tokio::spawn(future::pending()),
1619                    unordered_tx,
1620                )
1621                .await
1622            }
1623        });
1624
1625        let expected_timeout = (5 * state.config.ping_interval) + state.config.idle_timeout;
1626        let res = timeout(expected_timeout, t).await;
1627        let elapsed = start.elapsed();
1628
1629        // It didn't time out.
1630        assert_matches!(res, Ok(Ok(())));
1631        // It didn't exit early. Allow it to miss a ping.
1632        assert!(elapsed >= expected_timeout - state.config.ping_interval);
1633    }
1634
1635    #[tokio::test]
1636    async fn main_loop_terminates_when_module_exits() {
1637        let state = Arc::new(dummy_actor_state());
1638
1639        let (_idle_tx, idle_rx) = watch::channel(state.next_idle_deadline());
1640        let unordered_tx = {
1641            let state = state.clone();
1642            move |m| {
1643                if let UnorderedWsMessage::Close(_) = m {
1644                    state.close();
1645                }
1646            }
1647        };
1648
1649        let start = Instant::now();
1650        tokio::spawn(async move {
1651            let hotswap = || async {
1652                sleep(Duration::from_millis(5)).await;
1653                Err(NoSuchModule)
1654            };
1655
1656            ws_main_loop(
1657                state.clone(),
1658                hotswap,
1659                ws_idle_timer(idle_rx),
1660                // Pretend we received a close immediately after sending one.
1661                tokio::spawn(async move {
1662                    loop {
1663                        if state.closed() {
1664                            break;
1665                        }
1666                        sleep(Duration::from_millis(1)).await
1667                    }
1668                }),
1669                tokio::spawn(future::pending()),
1670                unordered_tx,
1671            )
1672            .await
1673        })
1674        .await
1675        .unwrap();
1676        let elapsed = start.elapsed();
1677        assert!(elapsed >= Duration::from_millis(5));
1678        assert!(elapsed < Duration::from_millis(10));
1679    }
1680
1681    #[tokio::test]
1682    async fn recv_queue_sends_close_when_at_capacity() {
1683        let state = Arc::new(dummy_actor_state_with_config(WebSocketOptions {
1684            incoming_queue_length: 10.try_into().unwrap(),
1685            ..<_>::default()
1686        }));
1687
1688        let (unordered_tx, mut unordered_rx) = mpsc::unbounded_channel();
1689        let input = stream::iter((0..20).map(|i| Ok(WsMessage::text(format!("message {i}")))));
1690
1691        let received = ws_recv_queue(state, unordered_tx, input).collect::<Vec<_>>().await;
1692        assert_matches!(unordered_rx.recv().await, Some(UnorderedWsMessage::Close(_)));
1693        // Should have received all of the input.
1694        assert_eq!(received.len(), 20);
1695    }
1696
1697    #[tokio::test]
1698    async fn recv_queue_closes_state_if_sender_gone() {
1699        let state = Arc::new(dummy_actor_state_with_config(WebSocketOptions {
1700            incoming_queue_length: 10.try_into().unwrap(),
1701            ..<_>::default()
1702        }));
1703
1704        let (unordered_tx, _) = mpsc::unbounded_channel();
1705        let input = stream::iter((0..20).map(|i| Ok(WsMessage::text(format!("message {i}")))));
1706
1707        let received = ws_recv_queue(state.clone(), unordered_tx, input)
1708            .collect::<Vec<_>>()
1709            .await;
1710        assert!(state.closed());
1711        // Should have received up to capacity.
1712        assert_eq!(received.len(), 10);
1713    }
1714
/// Polls `fut` exactly once and reports whether it is still pending.
///
/// The future is only borrowed, so the caller can keep awaiting it afterwards.
async fn is_pending(fut: &mut (impl Future + Unpin)) -> bool {
    poll_fn(|cx| {
        let still_pending = matches!(Pin::new(&mut *fut).poll(cx), Poll::Pending);
        Poll::Ready(still_pending)
    })
    .await
}
1718
1719    struct UnfeedableSink;
1720
1721    impl<T> Sink<T> for UnfeedableSink {
1722        type Error = &'static str;
1723
1724        fn poll_ready(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
1725            Poll::Ready(Ok(()))
1726        }
1727
1728        fn start_send(self: Pin<&mut Self>, _: T) -> Result<(), Self::Error> {
1729            Err("don't feed the sink")
1730        }
1731
1732        fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
1733            Poll::Ready(Ok(()))
1734        }
1735
1736        fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
1737            Poll::Ready(Ok(()))
1738        }
1739    }
1740
1741    struct UnflushableSink;
1742
1743    impl<T> Sink<T> for UnflushableSink {
1744        type Error = &'static str;
1745
1746        fn poll_ready(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
1747            Poll::Ready(Ok(()))
1748        }
1749
1750        fn start_send(self: Pin<&mut Self>, _: T) -> Result<(), Self::Error> {
1751            Ok(())
1752        }
1753
1754        fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
1755            Poll::Ready(Err("don't flush the sink"))
1756        }
1757
1758        fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
1759            Poll::Ready(Ok(()))
1760        }
1761    }
1762
/// The default `WebSocketOptions` must survive serialization to TOML and
/// parsing back without change.
#[test]
fn options_toml_roundtrip() {
    let original = WebSocketOptions::default();
    let serialized = toml::to_string(&original).unwrap();
    let parsed: WebSocketOptions = toml::from_str(&serialized).unwrap();
    assert_eq!(original, parsed);
}
1769
/// Parsing TOML that sets only `ping-interval` and `idle-timeout` must yield
/// those two values with every other option left at its default.
#[test]
fn options_from_partial_toml() {
    let partial = r#"
            ping-interval = "53s"
            idle-timeout = "1m 3s"
"#;
    let parsed: WebSocketOptions = toml::from_str(partial).unwrap();

    let expected = WebSocketOptions {
        ping_interval: Duration::from_secs(53),
        idle_timeout: Duration::from_secs(63),
        ..<_>::default()
    };

    assert_eq!(expected, parsed);
}
1785}