spacetimedb/client/messages.rs

1use super::{ClientConfig, DataMessage, Protocol};
2use crate::execution_context::WorkloadType;
3use crate::host::module_host::{EventStatus, ModuleEvent};
4use crate::host::ArgsTuple;
5use crate::messages::websocket as ws;
6use derive_more::From;
7use spacetimedb_client_api_messages::websocket::{
8    BsatnFormat, Compression, FormatSwitch, JsonFormat, OneOffTable, RowListLen, WebsocketFormat,
9    SERVER_MSG_COMPRESSION_TAG_BROTLI, SERVER_MSG_COMPRESSION_TAG_GZIP, SERVER_MSG_COMPRESSION_TAG_NONE,
10};
11use spacetimedb_lib::identity::RequestId;
12use spacetimedb_lib::ser::serde::SerializeWrapper;
13use spacetimedb_lib::{ConnectionId, TimeDuration};
14use spacetimedb_primitives::TableId;
15use spacetimedb_sats::bsatn;
16use std::sync::Arc;
17use std::time::Instant;
18
19/// A server-to-client message which can be encoded according to a [`Protocol`],
20/// resulting in a [`ToProtocol::Encoded`] message.
21pub trait ToProtocol {
22    type Encoded;
23    /// Convert `self` into a [`Self::Encoded`] where rows and arguments are encoded with `protocol`.
24    fn to_protocol(self, protocol: Protocol) -> Self::Encoded;
25}
26
27pub(super) type SwitchedServerMessage = FormatSwitch<ws::ServerMessage<BsatnFormat>, ws::ServerMessage<JsonFormat>>;
28pub(super) type SwitchedDbUpdate = FormatSwitch<ws::DatabaseUpdate<BsatnFormat>, ws::DatabaseUpdate<JsonFormat>>;
29
30/// Serialize `msg` into a [`DataMessage`] containing a [`ws::ServerMessage`].
31///
32/// If `protocol` is [`Protocol::Binary`],
33/// the message will be conditionally compressed by this method according to `compression`.
34pub fn serialize(msg: impl ToProtocol<Encoded = SwitchedServerMessage>, config: ClientConfig) -> DataMessage {
35    // TODO(centril, perf): here we are allocating buffers only to throw them away eventually.
36    // Consider pooling these allocations so that we reuse them.
37    match msg.to_protocol(config.protocol) {
38        FormatSwitch::Json(msg) => serde_json::to_string(&SerializeWrapper::new(msg)).unwrap().into(),
39        FormatSwitch::Bsatn(msg) => {
40            // First write the tag so that we avoid shifting the entire message at the end.
41            let mut msg_bytes = vec![SERVER_MSG_COMPRESSION_TAG_NONE];
42            bsatn::to_writer(&mut msg_bytes, &msg).unwrap();
43
44            // Conditionally compress the message.
45            let srv_msg = &msg_bytes[1..];
46            let msg_bytes = match ws::decide_compression(srv_msg.len(), config.compression) {
47                Compression::None => msg_bytes,
48                Compression::Brotli => {
49                    let mut out = vec![SERVER_MSG_COMPRESSION_TAG_BROTLI];
50                    ws::brotli_compress(srv_msg, &mut out);
51                    out
52                }
53                Compression::Gzip => {
54                    let mut out = vec![SERVER_MSG_COMPRESSION_TAG_GZIP];
55                    ws::gzip_compress(srv_msg, &mut out);
56                    out
57                }
58            };
59            msg_bytes.into()
60        }
61    }
62}
63
64#[derive(Debug, From)]
65pub enum SerializableMessage {
66    QueryBinary(OneOffQueryResponseMessage<BsatnFormat>),
67    QueryText(OneOffQueryResponseMessage<JsonFormat>),
68    Identity(IdentityTokenMessage),
69    Subscribe(SubscriptionUpdateMessage),
70    Subscription(SubscriptionMessage),
71    TxUpdate(TransactionUpdateMessage),
72}
73
74impl SerializableMessage {
75    pub fn num_rows(&self) -> Option<usize> {
76        match self {
77            Self::QueryBinary(msg) => Some(msg.num_rows()),
78            Self::QueryText(msg) => Some(msg.num_rows()),
79            Self::Subscribe(msg) => Some(msg.num_rows()),
80            Self::Subscription(msg) => Some(msg.num_rows()),
81            Self::TxUpdate(msg) => Some(msg.num_rows()),
82            Self::Identity(_) => None,
83        }
84    }
85
86    pub fn workload(&self) -> Option<WorkloadType> {
87        match self {
88            Self::QueryBinary(_) | Self::QueryText(_) => Some(WorkloadType::Sql),
89            Self::Subscribe(_) => Some(WorkloadType::Subscribe),
90            Self::Subscription(msg) => match &msg.result {
91                SubscriptionResult::Subscribe(_) => Some(WorkloadType::Subscribe),
92                SubscriptionResult::Unsubscribe(_) => Some(WorkloadType::Unsubscribe),
93                SubscriptionResult::Error(_) => None,
94                SubscriptionResult::SubscribeMulti(_) => Some(WorkloadType::Subscribe),
95                SubscriptionResult::UnsubscribeMulti(_) => Some(WorkloadType::Unsubscribe),
96            },
97            Self::TxUpdate(_) => Some(WorkloadType::Update),
98            Self::Identity(_) => None,
99        }
100    }
101}
102
103impl ToProtocol for SerializableMessage {
104    type Encoded = SwitchedServerMessage;
105    fn to_protocol(self, protocol: Protocol) -> Self::Encoded {
106        match self {
107            SerializableMessage::QueryBinary(msg) => msg.to_protocol(protocol),
108            SerializableMessage::QueryText(msg) => msg.to_protocol(protocol),
109            SerializableMessage::Identity(msg) => msg.to_protocol(protocol),
110            SerializableMessage::Subscribe(msg) => msg.to_protocol(protocol),
111            SerializableMessage::TxUpdate(msg) => msg.to_protocol(protocol),
112            SerializableMessage::Subscription(msg) => msg.to_protocol(protocol),
113        }
114    }
115}
116
117pub type IdentityTokenMessage = ws::IdentityToken;
118
119impl ToProtocol for IdentityTokenMessage {
120    type Encoded = SwitchedServerMessage;
121    fn to_protocol(self, protocol: Protocol) -> Self::Encoded {
122        match protocol {
123            Protocol::Text => FormatSwitch::Json(ws::ServerMessage::IdentityToken(self)),
124            Protocol::Binary => FormatSwitch::Bsatn(ws::ServerMessage::IdentityToken(self)),
125        }
126    }
127}
128
129#[derive(Debug)]
130pub struct TransactionUpdateMessage {
131    /// The event that caused this update.
132    /// When `None`, this is a light update.
133    pub event: Option<Arc<ModuleEvent>>,
134    pub database_update: SubscriptionUpdateMessage,
135}
136
137impl TransactionUpdateMessage {
138    fn num_rows(&self) -> usize {
139        self.database_update.num_rows()
140    }
141}
142
143impl ToProtocol for TransactionUpdateMessage {
144    type Encoded = SwitchedServerMessage;
145    fn to_protocol(self, protocol: Protocol) -> Self::Encoded {
146        fn convert<F: WebsocketFormat>(
147            event: Option<Arc<ModuleEvent>>,
148            request_id: u32,
149            update: ws::DatabaseUpdate<F>,
150            conv_args: impl FnOnce(&ArgsTuple) -> F::Single,
151        ) -> ws::ServerMessage<F> {
152            let Some(event) = event else {
153                return ws::ServerMessage::TransactionUpdateLight(ws::TransactionUpdateLight { request_id, update });
154            };
155
156            let status = match &event.status {
157                EventStatus::Committed(_) => ws::UpdateStatus::Committed(update),
158                EventStatus::Failed(errmsg) => ws::UpdateStatus::Failed(errmsg.clone().into()),
159                EventStatus::OutOfEnergy => ws::UpdateStatus::OutOfEnergy,
160            };
161
162            let args = conv_args(&event.function_call.args);
163
164            let tx_update = ws::TransactionUpdate {
165                timestamp: event.timestamp,
166                status,
167                caller_identity: event.caller_identity,
168                reducer_call: ws::ReducerCallInfo {
169                    reducer_name: event.function_call.reducer.to_owned().into(),
170                    reducer_id: event.function_call.reducer_id.into(),
171                    args,
172                    request_id,
173                },
174                energy_quanta_used: event.energy_quanta_used,
175                total_host_execution_duration: event.host_execution_duration.into(),
176                caller_connection_id: event.caller_connection_id.unwrap_or(ConnectionId::ZERO),
177            };
178
179            ws::ServerMessage::TransactionUpdate(tx_update)
180        }
181
182        let TransactionUpdateMessage { event, database_update } = self;
183        let update = database_update.database_update;
184        protocol.assert_matches_format_switch(&update);
185        let request_id = database_update.request_id.unwrap_or(0);
186        match update {
187            FormatSwitch::Bsatn(update) => FormatSwitch::Bsatn(convert(event, request_id, update, |args| {
188                Vec::from(args.get_bsatn().clone()).into()
189            })),
190            FormatSwitch::Json(update) => {
191                FormatSwitch::Json(convert(event, request_id, update, |args| args.get_json().clone()))
192            }
193        }
194    }
195}
196
197#[derive(Debug, Clone)]
198pub struct SubscriptionUpdateMessage {
199    pub database_update: SwitchedDbUpdate,
200    pub request_id: Option<RequestId>,
201    pub timer: Option<Instant>,
202}
203
204impl SubscriptionUpdateMessage {
205    pub(crate) fn default_for_protocol(protocol: Protocol, request_id: Option<RequestId>) -> Self {
206        Self {
207            database_update: match protocol {
208                Protocol::Text => FormatSwitch::Json(<_>::default()),
209                Protocol::Binary => FormatSwitch::Bsatn(<_>::default()),
210            },
211            request_id,
212            timer: None,
213        }
214    }
215
216    pub(crate) fn from_event_and_update(event: &ModuleEvent, update: SwitchedDbUpdate) -> Self {
217        Self {
218            database_update: update,
219            request_id: event.request_id,
220            timer: event.timer,
221        }
222    }
223
224    fn num_rows(&self) -> usize {
225        match &self.database_update {
226            FormatSwitch::Bsatn(x) => x.num_rows(),
227            FormatSwitch::Json(x) => x.num_rows(),
228        }
229    }
230}
231
232impl ToProtocol for SubscriptionUpdateMessage {
233    type Encoded = SwitchedServerMessage;
234    fn to_protocol(self, protocol: Protocol) -> Self::Encoded {
235        let request_id = self.request_id.unwrap_or(0);
236        let total_host_execution_duration = self.timer.map_or(TimeDuration::ZERO, |t| t.elapsed().into());
237
238        protocol.assert_matches_format_switch(&self.database_update);
239        match self.database_update {
240            FormatSwitch::Bsatn(database_update) => {
241                FormatSwitch::Bsatn(ws::ServerMessage::InitialSubscription(ws::InitialSubscription {
242                    database_update,
243                    request_id,
244                    total_host_execution_duration,
245                }))
246            }
247            FormatSwitch::Json(database_update) => {
248                FormatSwitch::Json(ws::ServerMessage::InitialSubscription(ws::InitialSubscription {
249                    database_update,
250                    request_id,
251                    total_host_execution_duration,
252                }))
253            }
254        }
255    }
256}
257
258#[derive(Debug, Clone)]
259pub struct SubscriptionData {
260    pub data: FormatSwitch<ws::DatabaseUpdate<BsatnFormat>, ws::DatabaseUpdate<JsonFormat>>,
261}
262
263#[derive(Debug, Clone)]
264pub struct SubscriptionRows {
265    pub table_id: TableId,
266    pub table_name: Box<str>,
267    pub table_rows: FormatSwitch<ws::TableUpdate<BsatnFormat>, ws::TableUpdate<JsonFormat>>,
268}
269
270#[derive(Debug, Clone)]
271pub struct SubscriptionError {
272    pub table_id: Option<TableId>,
273    pub message: Box<str>,
274}
275
276#[derive(Debug, Clone)]
277pub enum SubscriptionResult {
278    Subscribe(SubscriptionRows),
279    Unsubscribe(SubscriptionRows),
280    Error(SubscriptionError),
281    SubscribeMulti(SubscriptionData),
282    UnsubscribeMulti(SubscriptionData),
283}
284
285#[derive(Debug, Clone)]
286pub struct SubscriptionMessage {
287    pub timer: Option<Instant>,
288    pub request_id: Option<RequestId>,
289    pub query_id: Option<ws::QueryId>,
290    pub result: SubscriptionResult,
291}
292
293fn num_rows_in(rows: &SubscriptionRows) -> usize {
294    match &rows.table_rows {
295        FormatSwitch::Bsatn(x) => x.num_rows(),
296        FormatSwitch::Json(x) => x.num_rows(),
297    }
298}
299
300fn subscription_data_rows(rows: &SubscriptionData) -> usize {
301    match &rows.data {
302        FormatSwitch::Bsatn(x) => x.num_rows(),
303        FormatSwitch::Json(x) => x.num_rows(),
304    }
305}
306
307impl SubscriptionMessage {
308    fn num_rows(&self) -> usize {
309        match &self.result {
310            SubscriptionResult::Subscribe(x) => num_rows_in(x),
311            SubscriptionResult::SubscribeMulti(x) => subscription_data_rows(x),
312            SubscriptionResult::UnsubscribeMulti(x) => subscription_data_rows(x),
313            SubscriptionResult::Unsubscribe(x) => num_rows_in(x),
314            _ => 0,
315        }
316    }
317}
318
319impl ToProtocol for SubscriptionMessage {
320    type Encoded = SwitchedServerMessage;
321    fn to_protocol(self, protocol: Protocol) -> Self::Encoded {
322        let request_id = self.request_id.unwrap_or(0);
323        let query_id = self.query_id.unwrap_or(ws::QueryId::new(0));
324        let total_host_execution_duration_micros = self.timer.map_or(0, |t| t.elapsed().as_micros() as u64);
325
326        match self.result {
327            SubscriptionResult::Subscribe(result) => {
328                protocol.assert_matches_format_switch(&result.table_rows);
329                match result.table_rows {
330                    FormatSwitch::Bsatn(table_rows) => FormatSwitch::Bsatn(
331                        ws::SubscribeApplied {
332                            total_host_execution_duration_micros,
333                            request_id,
334                            query_id,
335                            rows: ws::SubscribeRows {
336                                table_id: result.table_id,
337                                table_name: result.table_name,
338                                table_rows,
339                            },
340                        }
341                        .into(),
342                    ),
343                    FormatSwitch::Json(table_rows) => FormatSwitch::Json(
344                        ws::SubscribeApplied {
345                            total_host_execution_duration_micros,
346                            request_id,
347                            query_id,
348                            rows: ws::SubscribeRows {
349                                table_id: result.table_id,
350                                table_name: result.table_name,
351                                table_rows,
352                            },
353                        }
354                        .into(),
355                    ),
356                }
357            }
358            SubscriptionResult::Unsubscribe(result) => {
359                protocol.assert_matches_format_switch(&result.table_rows);
360                match result.table_rows {
361                    FormatSwitch::Bsatn(table_rows) => FormatSwitch::Bsatn(
362                        ws::UnsubscribeApplied {
363                            total_host_execution_duration_micros,
364                            request_id,
365                            query_id,
366                            rows: ws::SubscribeRows {
367                                table_id: result.table_id,
368                                table_name: result.table_name,
369                                table_rows,
370                            },
371                        }
372                        .into(),
373                    ),
374                    FormatSwitch::Json(table_rows) => FormatSwitch::Json(
375                        ws::UnsubscribeApplied {
376                            total_host_execution_duration_micros,
377                            request_id,
378                            query_id,
379                            rows: ws::SubscribeRows {
380                                table_id: result.table_id,
381                                table_name: result.table_name,
382                                table_rows,
383                            },
384                        }
385                        .into(),
386                    ),
387                }
388            }
389            SubscriptionResult::Error(error) => {
390                let msg = ws::SubscriptionError {
391                    total_host_execution_duration_micros,
392                    request_id: self.request_id,           // Pass Option through
393                    query_id: self.query_id.map(|x| x.id), // Pass Option through
394                    table_id: error.table_id,
395                    error: error.message,
396                };
397                match protocol {
398                    Protocol::Binary => FormatSwitch::Bsatn(msg.into()),
399                    Protocol::Text => FormatSwitch::Json(msg.into()),
400                }
401            }
402            SubscriptionResult::SubscribeMulti(result) => {
403                protocol.assert_matches_format_switch(&result.data);
404                match result.data {
405                    FormatSwitch::Bsatn(data) => FormatSwitch::Bsatn(
406                        ws::SubscribeMultiApplied {
407                            total_host_execution_duration_micros,
408                            request_id,
409                            query_id,
410                            update: data,
411                        }
412                        .into(),
413                    ),
414                    FormatSwitch::Json(data) => FormatSwitch::Json(
415                        ws::SubscribeMultiApplied {
416                            total_host_execution_duration_micros,
417                            request_id,
418                            query_id,
419                            update: data,
420                        }
421                        .into(),
422                    ),
423                }
424            }
425            SubscriptionResult::UnsubscribeMulti(result) => {
426                protocol.assert_matches_format_switch(&result.data);
427                match result.data {
428                    FormatSwitch::Bsatn(data) => FormatSwitch::Bsatn(
429                        ws::UnsubscribeMultiApplied {
430                            total_host_execution_duration_micros,
431                            request_id,
432                            query_id,
433                            update: data,
434                        }
435                        .into(),
436                    ),
437                    FormatSwitch::Json(data) => FormatSwitch::Json(
438                        ws::UnsubscribeMultiApplied {
439                            total_host_execution_duration_micros,
440                            request_id,
441                            query_id,
442                            update: data,
443                        }
444                        .into(),
445                    ),
446                }
447            }
448        }
449    }
450}
451
452#[derive(Debug)]
453pub struct OneOffQueryResponseMessage<F: WebsocketFormat> {
454    pub message_id: Vec<u8>,
455    pub error: Option<String>,
456    pub results: Vec<OneOffTable<F>>,
457    pub total_host_execution_duration: TimeDuration,
458}
459
460impl<F: WebsocketFormat> OneOffQueryResponseMessage<F> {
461    fn num_rows(&self) -> usize {
462        self.results.iter().map(|table| table.rows.len()).sum()
463    }
464}
465
466impl ToProtocol for OneOffQueryResponseMessage<BsatnFormat> {
467    type Encoded = SwitchedServerMessage;
468
469    fn to_protocol(self, _: Protocol) -> Self::Encoded {
470        FormatSwitch::Bsatn(convert(self))
471    }
472}
473
474impl ToProtocol for OneOffQueryResponseMessage<JsonFormat> {
475    type Encoded = SwitchedServerMessage;
476    fn to_protocol(self, _: Protocol) -> Self::Encoded {
477        FormatSwitch::Json(convert(self))
478    }
479}
480
481fn convert<F: WebsocketFormat>(msg: OneOffQueryResponseMessage<F>) -> ws::ServerMessage<F> {
482    ws::ServerMessage::OneOffQueryResponse(ws::OneOffQueryResponse {
483        message_id: msg.message_id.into(),
484        error: msg.error.map(Into::into),
485        tables: msg.results.into_boxed_slice(),
486        total_host_execution_duration: msg.total_host_execution_duration,
487    })
488}