// spacetimedb/client/messages.rs

1use super::{ClientConfig, DataMessage, Protocol};
2use crate::host::module_host::{EventStatus, ModuleEvent};
3use crate::host::ArgsTuple;
4use crate::messages::websocket as ws;
5use bytes::{BufMut, Bytes, BytesMut};
6use bytestring::ByteString;
7use derive_more::From;
8use spacetimedb_client_api_messages::websocket::{
9    BsatnFormat, Compression, FormatSwitch, JsonFormat, OneOffTable, RowListLen, WebsocketFormat,
10    SERVER_MSG_COMPRESSION_TAG_BROTLI, SERVER_MSG_COMPRESSION_TAG_GZIP, SERVER_MSG_COMPRESSION_TAG_NONE,
11};
12use spacetimedb_datastore::execution_context::WorkloadType;
13use spacetimedb_lib::identity::RequestId;
14use spacetimedb_lib::ser::serde::SerializeWrapper;
15use spacetimedb_lib::{ConnectionId, TimeDuration};
16use spacetimedb_primitives::TableId;
17use spacetimedb_sats::bsatn;
18use std::sync::Arc;
19use std::time::Instant;
20
21/// A server-to-client message which can be encoded according to a [`Protocol`],
22/// resulting in a [`ToProtocol::Encoded`] message.
23pub trait ToProtocol {
24    type Encoded;
25    /// Convert `self` into a [`Self::Encoded`] where rows and arguments are encoded with `protocol`.
26    fn to_protocol(self, protocol: Protocol) -> Self::Encoded;
27}
28
29pub type SwitchedServerMessage = FormatSwitch<ws::ServerMessage<BsatnFormat>, ws::ServerMessage<JsonFormat>>;
30pub(super) type SwitchedDbUpdate = FormatSwitch<ws::DatabaseUpdate<BsatnFormat>, ws::DatabaseUpdate<JsonFormat>>;
31
/// Initial capacity, in bytes, of the buffers allocated by [`SerializeBuffer::new`].
///
/// The compressed buffer gets this capacity only when compression may occur.
/// 4 KiB lines up with the Linux page size and should be ample for the
/// common case.
const SERIALIZE_BUFFER_INIT_CAP: usize = 4 * 1024;
36
37/// A buffer used by [`serialize`]
38pub struct SerializeBuffer {
39    uncompressed: BytesMut,
40    compressed: BytesMut,
41}
42
43impl SerializeBuffer {
44    pub fn new(config: ClientConfig) -> Self {
45        let uncompressed_capacity = SERIALIZE_BUFFER_INIT_CAP;
46        let compressed_capacity = if config.compression == Compression::None || config.protocol == Protocol::Text {
47            0
48        } else {
49            SERIALIZE_BUFFER_INIT_CAP
50        };
51        Self {
52            uncompressed: BytesMut::with_capacity(uncompressed_capacity),
53            compressed: BytesMut::with_capacity(compressed_capacity),
54        }
55    }
56
57    /// Take the uncompressed message as the one to use.
58    fn uncompressed(self) -> (InUseSerializeBuffer, Bytes) {
59        let uncompressed = self.uncompressed.freeze();
60        let in_use = InUseSerializeBuffer::Uncompressed {
61            uncompressed: uncompressed.clone(),
62            compressed: self.compressed,
63        };
64        (in_use, uncompressed)
65    }
66
67    /// Write uncompressed data with a leading tag.
68    fn write_with_tag<F>(&mut self, tag: u8, write: F) -> &[u8]
69    where
70        F: FnOnce(bytes::buf::Writer<&mut BytesMut>),
71    {
72        self.uncompressed.put_u8(tag);
73        write((&mut self.uncompressed).writer());
74        &self.uncompressed[1..]
75    }
76
77    /// Compress the data from a `write_with_tag` call, and change the tag.
78    fn compress_with_tag(
79        self,
80        tag: u8,
81        write: impl FnOnce(&[u8], &mut bytes::buf::Writer<BytesMut>),
82    ) -> (InUseSerializeBuffer, Bytes) {
83        let mut writer = self.compressed.writer();
84        writer.get_mut().put_u8(tag);
85        write(&self.uncompressed[1..], &mut writer);
86        let compressed = writer.into_inner().freeze();
87        let in_use = InUseSerializeBuffer::Compressed {
88            uncompressed: self.uncompressed,
89            compressed: compressed.clone(),
90        };
91        (in_use, compressed)
92    }
93}
94
95type BytesMutWriter<'a> = bytes::buf::Writer<&'a mut BytesMut>;
96
97pub enum InUseSerializeBuffer {
98    Uncompressed { uncompressed: Bytes, compressed: BytesMut },
99    Compressed { uncompressed: BytesMut, compressed: Bytes },
100}
101
102impl InUseSerializeBuffer {
103    pub fn try_reclaim(self) -> Option<SerializeBuffer> {
104        let (mut uncompressed, mut compressed) = match self {
105            Self::Uncompressed {
106                uncompressed,
107                compressed,
108            } => (uncompressed.try_into_mut().ok()?, compressed),
109            Self::Compressed {
110                uncompressed,
111                compressed,
112            } => (uncompressed, compressed.try_into_mut().ok()?),
113        };
114        uncompressed.clear();
115        compressed.clear();
116        Some(SerializeBuffer {
117            uncompressed,
118            compressed,
119        })
120    }
121}
122
123/// Serialize `msg` into a [`DataMessage`] containing a [`ws::ServerMessage`].
124///
125/// If `protocol` is [`Protocol::Binary`],
126/// the message will be conditionally compressed by this method according to `compression`.
127pub fn serialize(
128    mut buffer: SerializeBuffer,
129    msg: impl ToProtocol<Encoded = SwitchedServerMessage>,
130    config: ClientConfig,
131) -> (InUseSerializeBuffer, DataMessage) {
132    match msg.to_protocol(config.protocol) {
133        FormatSwitch::Json(msg) => {
134            let out: BytesMutWriter<'_> = (&mut buffer.uncompressed).writer();
135            serde_json::to_writer(out, &SerializeWrapper::new(msg))
136                .expect("should be able to json encode a `ServerMessage`");
137
138            let (in_use, out) = buffer.uncompressed();
139            // SAFETY: `serde_json::to_writer` states that:
140            // > "Serialization guarantees it only feeds valid UTF-8 sequences to the writer."
141            let msg_json = unsafe { ByteString::from_bytes_unchecked(out) };
142            (in_use, msg_json.into())
143        }
144        FormatSwitch::Bsatn(msg) => {
145            // First write the tag so that we avoid shifting the entire message at the end.
146            let srv_msg = buffer.write_with_tag(SERVER_MSG_COMPRESSION_TAG_NONE, |w| {
147                bsatn::to_writer(w.into_inner(), &msg).unwrap()
148            });
149
150            // Conditionally compress the message.
151            let (in_use, msg_bytes) = match ws::decide_compression(srv_msg.len(), config.compression) {
152                Compression::None => buffer.uncompressed(),
153                Compression::Brotli => buffer.compress_with_tag(SERVER_MSG_COMPRESSION_TAG_BROTLI, ws::brotli_compress),
154                Compression::Gzip => buffer.compress_with_tag(SERVER_MSG_COMPRESSION_TAG_GZIP, ws::gzip_compress),
155            };
156            (in_use, msg_bytes.into())
157        }
158    }
159}
160
161#[derive(Debug, From)]
162pub enum SerializableMessage {
163    QueryBinary(OneOffQueryResponseMessage<BsatnFormat>),
164    QueryText(OneOffQueryResponseMessage<JsonFormat>),
165    Identity(IdentityTokenMessage),
166    Subscribe(SubscriptionUpdateMessage),
167    Subscription(SubscriptionMessage),
168    TxUpdate(TransactionUpdateMessage),
169}
170
171impl SerializableMessage {
172    pub fn num_rows(&self) -> Option<usize> {
173        match self {
174            Self::QueryBinary(msg) => Some(msg.num_rows()),
175            Self::QueryText(msg) => Some(msg.num_rows()),
176            Self::Subscribe(msg) => Some(msg.num_rows()),
177            Self::Subscription(msg) => Some(msg.num_rows()),
178            Self::TxUpdate(msg) => Some(msg.num_rows()),
179            Self::Identity(_) => None,
180        }
181    }
182
183    pub fn workload(&self) -> Option<WorkloadType> {
184        match self {
185            Self::QueryBinary(_) | Self::QueryText(_) => Some(WorkloadType::Sql),
186            Self::Subscribe(_) => Some(WorkloadType::Subscribe),
187            Self::Subscription(msg) => match &msg.result {
188                SubscriptionResult::Subscribe(_) => Some(WorkloadType::Subscribe),
189                SubscriptionResult::Unsubscribe(_) => Some(WorkloadType::Unsubscribe),
190                SubscriptionResult::Error(_) => None,
191                SubscriptionResult::SubscribeMulti(_) => Some(WorkloadType::Subscribe),
192                SubscriptionResult::UnsubscribeMulti(_) => Some(WorkloadType::Unsubscribe),
193            },
194            Self::TxUpdate(_) => Some(WorkloadType::Update),
195            Self::Identity(_) => None,
196        }
197    }
198}
199
200impl ToProtocol for SerializableMessage {
201    type Encoded = SwitchedServerMessage;
202    fn to_protocol(self, protocol: Protocol) -> Self::Encoded {
203        match self {
204            SerializableMessage::QueryBinary(msg) => msg.to_protocol(protocol),
205            SerializableMessage::QueryText(msg) => msg.to_protocol(protocol),
206            SerializableMessage::Identity(msg) => msg.to_protocol(protocol),
207            SerializableMessage::Subscribe(msg) => msg.to_protocol(protocol),
208            SerializableMessage::TxUpdate(msg) => msg.to_protocol(protocol),
209            SerializableMessage::Subscription(msg) => msg.to_protocol(protocol),
210        }
211    }
212}
213
214pub type IdentityTokenMessage = ws::IdentityToken;
215
216impl ToProtocol for IdentityTokenMessage {
217    type Encoded = SwitchedServerMessage;
218    fn to_protocol(self, protocol: Protocol) -> Self::Encoded {
219        match protocol {
220            Protocol::Text => FormatSwitch::Json(ws::ServerMessage::IdentityToken(self)),
221            Protocol::Binary => FormatSwitch::Bsatn(ws::ServerMessage::IdentityToken(self)),
222        }
223    }
224}
225
226#[derive(Debug)]
227pub struct TransactionUpdateMessage {
228    /// The event that caused this update.
229    /// When `None`, this is a light update.
230    pub event: Option<Arc<ModuleEvent>>,
231    pub database_update: SubscriptionUpdateMessage,
232}
233
234impl TransactionUpdateMessage {
235    fn num_rows(&self) -> usize {
236        self.database_update.num_rows()
237    }
238}
239
240impl ToProtocol for TransactionUpdateMessage {
241    type Encoded = SwitchedServerMessage;
242    fn to_protocol(self, protocol: Protocol) -> Self::Encoded {
243        fn convert<F: WebsocketFormat>(
244            event: Option<Arc<ModuleEvent>>,
245            request_id: u32,
246            update: ws::DatabaseUpdate<F>,
247            conv_args: impl FnOnce(&ArgsTuple) -> F::Single,
248        ) -> ws::ServerMessage<F> {
249            let Some(event) = event else {
250                return ws::ServerMessage::TransactionUpdateLight(ws::TransactionUpdateLight { request_id, update });
251            };
252
253            let status = match &event.status {
254                EventStatus::Committed(_) => ws::UpdateStatus::Committed(update),
255                EventStatus::Failed(errmsg) => ws::UpdateStatus::Failed(errmsg.clone().into()),
256                EventStatus::OutOfEnergy => ws::UpdateStatus::OutOfEnergy,
257            };
258
259            let args = conv_args(&event.function_call.args);
260
261            let tx_update = ws::TransactionUpdate {
262                timestamp: event.timestamp,
263                status,
264                caller_identity: event.caller_identity,
265                reducer_call: ws::ReducerCallInfo {
266                    reducer_name: event.function_call.reducer.to_owned().into(),
267                    reducer_id: event.function_call.reducer_id.into(),
268                    args,
269                    request_id,
270                },
271                energy_quanta_used: event.energy_quanta_used,
272                total_host_execution_duration: event.host_execution_duration.into(),
273                caller_connection_id: event.caller_connection_id.unwrap_or(ConnectionId::ZERO),
274            };
275
276            ws::ServerMessage::TransactionUpdate(tx_update)
277        }
278
279        let TransactionUpdateMessage { event, database_update } = self;
280        let update = database_update.database_update;
281        protocol.assert_matches_format_switch(&update);
282        let request_id = database_update.request_id.unwrap_or(0);
283        match update {
284            FormatSwitch::Bsatn(update) => FormatSwitch::Bsatn(convert(event, request_id, update, |args| {
285                Vec::from(args.get_bsatn().clone()).into()
286            })),
287            FormatSwitch::Json(update) => {
288                FormatSwitch::Json(convert(event, request_id, update, |args| args.get_json().clone()))
289            }
290        }
291    }
292}
293
294#[derive(Debug, Clone)]
295pub struct SubscriptionUpdateMessage {
296    pub database_update: SwitchedDbUpdate,
297    pub request_id: Option<RequestId>,
298    pub timer: Option<Instant>,
299}
300
301impl SubscriptionUpdateMessage {
302    pub(crate) fn default_for_protocol(protocol: Protocol, request_id: Option<RequestId>) -> Self {
303        Self {
304            database_update: match protocol {
305                Protocol::Text => FormatSwitch::Json(<_>::default()),
306                Protocol::Binary => FormatSwitch::Bsatn(<_>::default()),
307            },
308            request_id,
309            timer: None,
310        }
311    }
312
313    pub(crate) fn from_event_and_update(event: &ModuleEvent, update: SwitchedDbUpdate) -> Self {
314        Self {
315            database_update: update,
316            request_id: event.request_id,
317            timer: event.timer,
318        }
319    }
320
321    fn num_rows(&self) -> usize {
322        match &self.database_update {
323            FormatSwitch::Bsatn(x) => x.num_rows(),
324            FormatSwitch::Json(x) => x.num_rows(),
325        }
326    }
327}
328
329impl ToProtocol for SubscriptionUpdateMessage {
330    type Encoded = SwitchedServerMessage;
331    fn to_protocol(self, protocol: Protocol) -> Self::Encoded {
332        let request_id = self.request_id.unwrap_or(0);
333        let total_host_execution_duration = self.timer.map_or(TimeDuration::ZERO, |t| t.elapsed().into());
334
335        protocol.assert_matches_format_switch(&self.database_update);
336        match self.database_update {
337            FormatSwitch::Bsatn(database_update) => {
338                FormatSwitch::Bsatn(ws::ServerMessage::InitialSubscription(ws::InitialSubscription {
339                    database_update,
340                    request_id,
341                    total_host_execution_duration,
342                }))
343            }
344            FormatSwitch::Json(database_update) => {
345                FormatSwitch::Json(ws::ServerMessage::InitialSubscription(ws::InitialSubscription {
346                    database_update,
347                    request_id,
348                    total_host_execution_duration,
349                }))
350            }
351        }
352    }
353}
354
355#[derive(Debug, Clone)]
356pub struct SubscriptionData {
357    pub data: FormatSwitch<ws::DatabaseUpdate<BsatnFormat>, ws::DatabaseUpdate<JsonFormat>>,
358}
359
360#[derive(Debug, Clone)]
361pub struct SubscriptionRows {
362    pub table_id: TableId,
363    pub table_name: Box<str>,
364    pub table_rows: FormatSwitch<ws::TableUpdate<BsatnFormat>, ws::TableUpdate<JsonFormat>>,
365}
366
367#[derive(Debug, Clone)]
368pub struct SubscriptionError {
369    pub table_id: Option<TableId>,
370    pub message: Box<str>,
371}
372
373#[derive(Debug, Clone)]
374pub enum SubscriptionResult {
375    Subscribe(SubscriptionRows),
376    Unsubscribe(SubscriptionRows),
377    Error(SubscriptionError),
378    SubscribeMulti(SubscriptionData),
379    UnsubscribeMulti(SubscriptionData),
380}
381
382#[derive(Debug, Clone)]
383pub struct SubscriptionMessage {
384    pub timer: Option<Instant>,
385    pub request_id: Option<RequestId>,
386    pub query_id: Option<ws::QueryId>,
387    pub result: SubscriptionResult,
388}
389
390fn num_rows_in(rows: &SubscriptionRows) -> usize {
391    match &rows.table_rows {
392        FormatSwitch::Bsatn(x) => x.num_rows(),
393        FormatSwitch::Json(x) => x.num_rows(),
394    }
395}
396
397fn subscription_data_rows(rows: &SubscriptionData) -> usize {
398    match &rows.data {
399        FormatSwitch::Bsatn(x) => x.num_rows(),
400        FormatSwitch::Json(x) => x.num_rows(),
401    }
402}
403
404impl SubscriptionMessage {
405    fn num_rows(&self) -> usize {
406        match &self.result {
407            SubscriptionResult::Subscribe(x) => num_rows_in(x),
408            SubscriptionResult::SubscribeMulti(x) => subscription_data_rows(x),
409            SubscriptionResult::UnsubscribeMulti(x) => subscription_data_rows(x),
410            SubscriptionResult::Unsubscribe(x) => num_rows_in(x),
411            _ => 0,
412        }
413    }
414}
415
416impl ToProtocol for SubscriptionMessage {
417    type Encoded = SwitchedServerMessage;
418    fn to_protocol(self, protocol: Protocol) -> Self::Encoded {
419        let request_id = self.request_id.unwrap_or(0);
420        let query_id = self.query_id.unwrap_or(ws::QueryId::new(0));
421        let total_host_execution_duration_micros = self.timer.map_or(0, |t| t.elapsed().as_micros() as u64);
422
423        match self.result {
424            SubscriptionResult::Subscribe(result) => {
425                protocol.assert_matches_format_switch(&result.table_rows);
426                match result.table_rows {
427                    FormatSwitch::Bsatn(table_rows) => FormatSwitch::Bsatn(
428                        ws::SubscribeApplied {
429                            total_host_execution_duration_micros,
430                            request_id,
431                            query_id,
432                            rows: ws::SubscribeRows {
433                                table_id: result.table_id,
434                                table_name: result.table_name,
435                                table_rows,
436                            },
437                        }
438                        .into(),
439                    ),
440                    FormatSwitch::Json(table_rows) => FormatSwitch::Json(
441                        ws::SubscribeApplied {
442                            total_host_execution_duration_micros,
443                            request_id,
444                            query_id,
445                            rows: ws::SubscribeRows {
446                                table_id: result.table_id,
447                                table_name: result.table_name,
448                                table_rows,
449                            },
450                        }
451                        .into(),
452                    ),
453                }
454            }
455            SubscriptionResult::Unsubscribe(result) => {
456                protocol.assert_matches_format_switch(&result.table_rows);
457                match result.table_rows {
458                    FormatSwitch::Bsatn(table_rows) => FormatSwitch::Bsatn(
459                        ws::UnsubscribeApplied {
460                            total_host_execution_duration_micros,
461                            request_id,
462                            query_id,
463                            rows: ws::SubscribeRows {
464                                table_id: result.table_id,
465                                table_name: result.table_name,
466                                table_rows,
467                            },
468                        }
469                        .into(),
470                    ),
471                    FormatSwitch::Json(table_rows) => FormatSwitch::Json(
472                        ws::UnsubscribeApplied {
473                            total_host_execution_duration_micros,
474                            request_id,
475                            query_id,
476                            rows: ws::SubscribeRows {
477                                table_id: result.table_id,
478                                table_name: result.table_name,
479                                table_rows,
480                            },
481                        }
482                        .into(),
483                    ),
484                }
485            }
486            SubscriptionResult::Error(error) => {
487                let msg = ws::SubscriptionError {
488                    total_host_execution_duration_micros,
489                    request_id: self.request_id,           // Pass Option through
490                    query_id: self.query_id.map(|x| x.id), // Pass Option through
491                    table_id: error.table_id,
492                    error: error.message,
493                };
494                match protocol {
495                    Protocol::Binary => FormatSwitch::Bsatn(msg.into()),
496                    Protocol::Text => FormatSwitch::Json(msg.into()),
497                }
498            }
499            SubscriptionResult::SubscribeMulti(result) => {
500                protocol.assert_matches_format_switch(&result.data);
501                match result.data {
502                    FormatSwitch::Bsatn(data) => FormatSwitch::Bsatn(
503                        ws::SubscribeMultiApplied {
504                            total_host_execution_duration_micros,
505                            request_id,
506                            query_id,
507                            update: data,
508                        }
509                        .into(),
510                    ),
511                    FormatSwitch::Json(data) => FormatSwitch::Json(
512                        ws::SubscribeMultiApplied {
513                            total_host_execution_duration_micros,
514                            request_id,
515                            query_id,
516                            update: data,
517                        }
518                        .into(),
519                    ),
520                }
521            }
522            SubscriptionResult::UnsubscribeMulti(result) => {
523                protocol.assert_matches_format_switch(&result.data);
524                match result.data {
525                    FormatSwitch::Bsatn(data) => FormatSwitch::Bsatn(
526                        ws::UnsubscribeMultiApplied {
527                            total_host_execution_duration_micros,
528                            request_id,
529                            query_id,
530                            update: data,
531                        }
532                        .into(),
533                    ),
534                    FormatSwitch::Json(data) => FormatSwitch::Json(
535                        ws::UnsubscribeMultiApplied {
536                            total_host_execution_duration_micros,
537                            request_id,
538                            query_id,
539                            update: data,
540                        }
541                        .into(),
542                    ),
543                }
544            }
545        }
546    }
547}
548
549#[derive(Debug)]
550pub struct OneOffQueryResponseMessage<F: WebsocketFormat> {
551    pub message_id: Vec<u8>,
552    pub error: Option<String>,
553    pub results: Vec<OneOffTable<F>>,
554    pub total_host_execution_duration: TimeDuration,
555}
556
557impl<F: WebsocketFormat> OneOffQueryResponseMessage<F> {
558    fn num_rows(&self) -> usize {
559        self.results.iter().map(|table| table.rows.len()).sum()
560    }
561}
562
563impl ToProtocol for OneOffQueryResponseMessage<BsatnFormat> {
564    type Encoded = SwitchedServerMessage;
565
566    fn to_protocol(self, _: Protocol) -> Self::Encoded {
567        FormatSwitch::Bsatn(convert(self))
568    }
569}
570
571impl ToProtocol for OneOffQueryResponseMessage<JsonFormat> {
572    type Encoded = SwitchedServerMessage;
573    fn to_protocol(self, _: Protocol) -> Self::Encoded {
574        FormatSwitch::Json(convert(self))
575    }
576}
577
578fn convert<F: WebsocketFormat>(msg: OneOffQueryResponseMessage<F>) -> ws::ServerMessage<F> {
579    ws::ServerMessage::OneOffQueryResponse(ws::OneOffQueryResponse {
580        message_id: msg.message_id.into(),
581        error: msg.error.map(Into::into),
582        tables: msg.results.into_boxed_slice(),
583        total_host_execution_duration: msg.total_host_execution_duration,
584    })
585}