1use super::{ClientConfig, DataMessage, Protocol};
2use crate::host::module_host::{EventStatus, ModuleEvent};
3use crate::host::ArgsTuple;
4use crate::messages::websocket as ws;
5use bytes::{BufMut, Bytes, BytesMut};
6use bytestring::ByteString;
7use derive_more::From;
8use spacetimedb_client_api_messages::websocket::{
9 BsatnFormat, Compression, FormatSwitch, JsonFormat, OneOffTable, RowListLen, WebsocketFormat,
10 SERVER_MSG_COMPRESSION_TAG_BROTLI, SERVER_MSG_COMPRESSION_TAG_GZIP, SERVER_MSG_COMPRESSION_TAG_NONE,
11};
12use spacetimedb_datastore::execution_context::WorkloadType;
13use spacetimedb_lib::identity::RequestId;
14use spacetimedb_lib::ser::serde::SerializeWrapper;
15use spacetimedb_lib::{ConnectionId, TimeDuration};
16use spacetimedb_primitives::TableId;
17use spacetimedb_sats::bsatn;
18use std::sync::Arc;
19use std::time::Instant;
20
/// Conversion of a server-side message into the message type used for a given [`Protocol`].
pub trait ToProtocol {
    /// The protocol-specific form produced by [`ToProtocol::to_protocol`].
    type Encoded;
    /// Consumes `self` and converts it into the form expected for `protocol`.
    fn to_protocol(self, protocol: Protocol) -> Self::Encoded;
}
28
/// A server message in either the BSATN (binary) or JSON (text) format.
pub type SwitchedServerMessage = FormatSwitch<ws::ServerMessage<BsatnFormat>, ws::ServerMessage<JsonFormat>>;
/// A database update in either the BSATN (binary) or JSON (text) format.
pub(super) type SwitchedDbUpdate = FormatSwitch<ws::DatabaseUpdate<BsatnFormat>, ws::DatabaseUpdate<JsonFormat>>;

/// Initial capacity, in bytes, reserved for a [`SerializeBuffer`]'s uncompressed half
/// (and for its compressed half when the client may receive compressed messages).
const SERIALIZE_BUFFER_INIT_CAP: usize = 4096;
36
/// Reusable scratch space for encoding outgoing server messages.
///
/// `uncompressed` receives the encoded message (for binary clients, preceded by a
/// one-byte compression tag); `compressed` receives the tagged, compressed form
/// when compression is applied.
pub struct SerializeBuffer {
    uncompressed: BytesMut,
    compressed: BytesMut,
}
42
43impl SerializeBuffer {
44 pub fn new(config: ClientConfig) -> Self {
45 let uncompressed_capacity = SERIALIZE_BUFFER_INIT_CAP;
46 let compressed_capacity = if config.compression == Compression::None || config.protocol == Protocol::Text {
47 0
48 } else {
49 SERIALIZE_BUFFER_INIT_CAP
50 };
51 Self {
52 uncompressed: BytesMut::with_capacity(uncompressed_capacity),
53 compressed: BytesMut::with_capacity(compressed_capacity),
54 }
55 }
56
57 fn uncompressed(self) -> (InUseSerializeBuffer, Bytes) {
59 let uncompressed = self.uncompressed.freeze();
60 let in_use = InUseSerializeBuffer::Uncompressed {
61 uncompressed: uncompressed.clone(),
62 compressed: self.compressed,
63 };
64 (in_use, uncompressed)
65 }
66
67 fn write_with_tag<F>(&mut self, tag: u8, write: F) -> &[u8]
69 where
70 F: FnOnce(bytes::buf::Writer<&mut BytesMut>),
71 {
72 self.uncompressed.put_u8(tag);
73 write((&mut self.uncompressed).writer());
74 &self.uncompressed[1..]
75 }
76
77 fn compress_with_tag(
79 self,
80 tag: u8,
81 write: impl FnOnce(&[u8], &mut bytes::buf::Writer<BytesMut>),
82 ) -> (InUseSerializeBuffer, Bytes) {
83 let mut writer = self.compressed.writer();
84 writer.get_mut().put_u8(tag);
85 write(&self.uncompressed[1..], &mut writer);
86 let compressed = writer.into_inner().freeze();
87 let in_use = InUseSerializeBuffer::Compressed {
88 uncompressed: self.uncompressed,
89 compressed: compressed.clone(),
90 };
91 (in_use, compressed)
92 }
93}
94
/// A [`std::io::Write`]-style adapter (from the `bytes` crate) that appends into a borrowed [`BytesMut`].
type BytesMutWriter<'a> = bytes::buf::Writer<&'a mut BytesMut>;
96
/// A [`SerializeBuffer`] whose contents have been handed out as an encoded message.
///
/// The variant records which half was frozen into the returned [`Bytes`]; the other
/// half is kept as a [`BytesMut`] so the buffer can later be recovered with
/// `InUseSerializeBuffer::try_reclaim`.
pub enum InUseSerializeBuffer {
    /// The uncompressed half was sent; `compressed` was left untouched.
    Uncompressed { uncompressed: Bytes, compressed: BytesMut },
    /// The compressed half was sent; `uncompressed` holds the data that was compressed.
    Compressed { uncompressed: BytesMut, compressed: Bytes },
}
101
102impl InUseSerializeBuffer {
103 pub fn try_reclaim(self) -> Option<SerializeBuffer> {
104 let (mut uncompressed, mut compressed) = match self {
105 Self::Uncompressed {
106 uncompressed,
107 compressed,
108 } => (uncompressed.try_into_mut().ok()?, compressed),
109 Self::Compressed {
110 uncompressed,
111 compressed,
112 } => (uncompressed, compressed.try_into_mut().ok()?),
113 };
114 uncompressed.clear();
115 compressed.clear();
116 Some(SerializeBuffer {
117 uncompressed,
118 compressed,
119 })
120 }
121}
122
123pub fn serialize(
128 mut buffer: SerializeBuffer,
129 msg: impl ToProtocol<Encoded = SwitchedServerMessage>,
130 config: ClientConfig,
131) -> (InUseSerializeBuffer, DataMessage) {
132 match msg.to_protocol(config.protocol) {
133 FormatSwitch::Json(msg) => {
134 let out: BytesMutWriter<'_> = (&mut buffer.uncompressed).writer();
135 serde_json::to_writer(out, &SerializeWrapper::new(msg))
136 .expect("should be able to json encode a `ServerMessage`");
137
138 let (in_use, out) = buffer.uncompressed();
139 let msg_json = unsafe { ByteString::from_bytes_unchecked(out) };
142 (in_use, msg_json.into())
143 }
144 FormatSwitch::Bsatn(msg) => {
145 let srv_msg = buffer.write_with_tag(SERVER_MSG_COMPRESSION_TAG_NONE, |w| {
147 bsatn::to_writer(w.into_inner(), &msg).unwrap()
148 });
149
150 let (in_use, msg_bytes) = match ws::decide_compression(srv_msg.len(), config.compression) {
152 Compression::None => buffer.uncompressed(),
153 Compression::Brotli => buffer.compress_with_tag(SERVER_MSG_COMPRESSION_TAG_BROTLI, ws::brotli_compress),
154 Compression::Gzip => buffer.compress_with_tag(SERVER_MSG_COMPRESSION_TAG_GZIP, ws::gzip_compress),
155 };
156 (in_use, msg_bytes.into())
157 }
158 }
159}
160
/// Any outgoing message that can be passed to [`serialize`].
#[derive(Debug, From)]
pub enum SerializableMessage {
    /// A one-off query response already encoded for binary (BSATN) clients.
    QueryBinary(OneOffQueryResponseMessage<BsatnFormat>),
    /// A one-off query response already encoded for text (JSON) clients.
    QueryText(OneOffQueryResponseMessage<JsonFormat>),
    /// The client's identity token.
    Identity(IdentityTokenMessage),
    /// An initial subscription update.
    Subscribe(SubscriptionUpdateMessage),
    /// A subscribe/unsubscribe result or subscription error.
    Subscription(SubscriptionMessage),
    /// An update produced by a transaction.
    TxUpdate(TransactionUpdateMessage),
}
170
171impl SerializableMessage {
172 pub fn num_rows(&self) -> Option<usize> {
173 match self {
174 Self::QueryBinary(msg) => Some(msg.num_rows()),
175 Self::QueryText(msg) => Some(msg.num_rows()),
176 Self::Subscribe(msg) => Some(msg.num_rows()),
177 Self::Subscription(msg) => Some(msg.num_rows()),
178 Self::TxUpdate(msg) => Some(msg.num_rows()),
179 Self::Identity(_) => None,
180 }
181 }
182
183 pub fn workload(&self) -> Option<WorkloadType> {
184 match self {
185 Self::QueryBinary(_) | Self::QueryText(_) => Some(WorkloadType::Sql),
186 Self::Subscribe(_) => Some(WorkloadType::Subscribe),
187 Self::Subscription(msg) => match &msg.result {
188 SubscriptionResult::Subscribe(_) => Some(WorkloadType::Subscribe),
189 SubscriptionResult::Unsubscribe(_) => Some(WorkloadType::Unsubscribe),
190 SubscriptionResult::Error(_) => None,
191 SubscriptionResult::SubscribeMulti(_) => Some(WorkloadType::Subscribe),
192 SubscriptionResult::UnsubscribeMulti(_) => Some(WorkloadType::Unsubscribe),
193 },
194 Self::TxUpdate(_) => Some(WorkloadType::Update),
195 Self::Identity(_) => None,
196 }
197 }
198}
199
200impl ToProtocol for SerializableMessage {
201 type Encoded = SwitchedServerMessage;
202 fn to_protocol(self, protocol: Protocol) -> Self::Encoded {
203 match self {
204 SerializableMessage::QueryBinary(msg) => msg.to_protocol(protocol),
205 SerializableMessage::QueryText(msg) => msg.to_protocol(protocol),
206 SerializableMessage::Identity(msg) => msg.to_protocol(protocol),
207 SerializableMessage::Subscribe(msg) => msg.to_protocol(protocol),
208 SerializableMessage::TxUpdate(msg) => msg.to_protocol(protocol),
209 SerializableMessage::Subscription(msg) => msg.to_protocol(protocol),
210 }
211 }
212}
213
/// The identity-token message sent to a client; an alias of the websocket type.
pub type IdentityTokenMessage = ws::IdentityToken;
215
216impl ToProtocol for IdentityTokenMessage {
217 type Encoded = SwitchedServerMessage;
218 fn to_protocol(self, protocol: Protocol) -> Self::Encoded {
219 match protocol {
220 Protocol::Text => FormatSwitch::Json(ws::ServerMessage::IdentityToken(self)),
221 Protocol::Binary => FormatSwitch::Bsatn(ws::ServerMessage::IdentityToken(self)),
222 }
223 }
224}
225
/// A transaction update to be sent to a client.
#[derive(Debug)]
pub struct TransactionUpdateMessage {
    /// The module event behind the update. When `None`, a `TransactionUpdateLight`
    /// (rows only) is sent instead of a full `TransactionUpdate`.
    pub event: Option<Arc<ModuleEvent>>,
    /// The rows to send, plus the request id to echo back.
    pub database_update: SubscriptionUpdateMessage,
}
233
234impl TransactionUpdateMessage {
235 fn num_rows(&self) -> usize {
236 self.database_update.num_rows()
237 }
238}
239
240impl ToProtocol for TransactionUpdateMessage {
241 type Encoded = SwitchedServerMessage;
242 fn to_protocol(self, protocol: Protocol) -> Self::Encoded {
243 fn convert<F: WebsocketFormat>(
244 event: Option<Arc<ModuleEvent>>,
245 request_id: u32,
246 update: ws::DatabaseUpdate<F>,
247 conv_args: impl FnOnce(&ArgsTuple) -> F::Single,
248 ) -> ws::ServerMessage<F> {
249 let Some(event) = event else {
250 return ws::ServerMessage::TransactionUpdateLight(ws::TransactionUpdateLight { request_id, update });
251 };
252
253 let status = match &event.status {
254 EventStatus::Committed(_) => ws::UpdateStatus::Committed(update),
255 EventStatus::Failed(errmsg) => ws::UpdateStatus::Failed(errmsg.clone().into()),
256 EventStatus::OutOfEnergy => ws::UpdateStatus::OutOfEnergy,
257 };
258
259 let args = conv_args(&event.function_call.args);
260
261 let tx_update = ws::TransactionUpdate {
262 timestamp: event.timestamp,
263 status,
264 caller_identity: event.caller_identity,
265 reducer_call: ws::ReducerCallInfo {
266 reducer_name: event.function_call.reducer.to_owned().into(),
267 reducer_id: event.function_call.reducer_id.into(),
268 args,
269 request_id,
270 },
271 energy_quanta_used: event.energy_quanta_used,
272 total_host_execution_duration: event.host_execution_duration.into(),
273 caller_connection_id: event.caller_connection_id.unwrap_or(ConnectionId::ZERO),
274 };
275
276 ws::ServerMessage::TransactionUpdate(tx_update)
277 }
278
279 let TransactionUpdateMessage { event, database_update } = self;
280 let update = database_update.database_update;
281 protocol.assert_matches_format_switch(&update);
282 let request_id = database_update.request_id.unwrap_or(0);
283 match update {
284 FormatSwitch::Bsatn(update) => FormatSwitch::Bsatn(convert(event, request_id, update, |args| {
285 Vec::from(args.get_bsatn().clone()).into()
286 })),
287 FormatSwitch::Json(update) => {
288 FormatSwitch::Json(convert(event, request_id, update, |args| args.get_json().clone()))
289 }
290 }
291 }
292}
293
/// A database update destined for a client, with its request metadata.
#[derive(Debug, Clone)]
pub struct SubscriptionUpdateMessage {
    /// The rows, already encoded in the client's format.
    pub database_update: SwitchedDbUpdate,
    /// The request id to echo back; reported as `0` in the wire message when absent.
    pub request_id: Option<RequestId>,
    /// Start instant used to compute the reported host execution duration, if any.
    pub timer: Option<Instant>,
}
300
301impl SubscriptionUpdateMessage {
302 pub(crate) fn default_for_protocol(protocol: Protocol, request_id: Option<RequestId>) -> Self {
303 Self {
304 database_update: match protocol {
305 Protocol::Text => FormatSwitch::Json(<_>::default()),
306 Protocol::Binary => FormatSwitch::Bsatn(<_>::default()),
307 },
308 request_id,
309 timer: None,
310 }
311 }
312
313 pub(crate) fn from_event_and_update(event: &ModuleEvent, update: SwitchedDbUpdate) -> Self {
314 Self {
315 database_update: update,
316 request_id: event.request_id,
317 timer: event.timer,
318 }
319 }
320
321 fn num_rows(&self) -> usize {
322 match &self.database_update {
323 FormatSwitch::Bsatn(x) => x.num_rows(),
324 FormatSwitch::Json(x) => x.num_rows(),
325 }
326 }
327}
328
329impl ToProtocol for SubscriptionUpdateMessage {
330 type Encoded = SwitchedServerMessage;
331 fn to_protocol(self, protocol: Protocol) -> Self::Encoded {
332 let request_id = self.request_id.unwrap_or(0);
333 let total_host_execution_duration = self.timer.map_or(TimeDuration::ZERO, |t| t.elapsed().into());
334
335 protocol.assert_matches_format_switch(&self.database_update);
336 match self.database_update {
337 FormatSwitch::Bsatn(database_update) => {
338 FormatSwitch::Bsatn(ws::ServerMessage::InitialSubscription(ws::InitialSubscription {
339 database_update,
340 request_id,
341 total_host_execution_duration,
342 }))
343 }
344 FormatSwitch::Json(database_update) => {
345 FormatSwitch::Json(ws::ServerMessage::InitialSubscription(ws::InitialSubscription {
346 database_update,
347 request_id,
348 total_host_execution_duration,
349 }))
350 }
351 }
352 }
353}
354
/// Rows returned for a multi-query subscribe or unsubscribe.
#[derive(Debug, Clone)]
pub struct SubscriptionData {
    /// The affected rows as a full database update, in BSATN or JSON form.
    pub data: FormatSwitch<ws::DatabaseUpdate<BsatnFormat>, ws::DatabaseUpdate<JsonFormat>>,
}
359
/// Rows of a single table returned for a single-query subscribe or unsubscribe.
#[derive(Debug, Clone)]
pub struct SubscriptionRows {
    pub table_id: TableId,
    pub table_name: Box<str>,
    /// The table's rows, in BSATN or JSON form.
    pub table_rows: FormatSwitch<ws::TableUpdate<BsatnFormat>, ws::TableUpdate<JsonFormat>>,
}
366
/// A failed subscription request.
#[derive(Debug, Clone)]
pub struct SubscriptionError {
    /// The table involved, when the error can be attributed to one.
    pub table_id: Option<TableId>,
    /// Human-readable error text forwarded to the client.
    pub message: Box<str>,
}
372
/// Outcome of a subscription request.
#[derive(Debug, Clone)]
pub enum SubscriptionResult {
    /// A single-table subscribe succeeded.
    Subscribe(SubscriptionRows),
    /// A single-table unsubscribe succeeded.
    Unsubscribe(SubscriptionRows),
    /// The request failed.
    Error(SubscriptionError),
    /// A multi-query subscribe succeeded.
    SubscribeMulti(SubscriptionData),
    /// A multi-query unsubscribe succeeded.
    UnsubscribeMulti(SubscriptionData),
}
381
/// A subscription result together with the request metadata to echo back.
#[derive(Debug, Clone)]
pub struct SubscriptionMessage {
    /// Start instant used to compute the reported execution time, if any.
    pub timer: Option<Instant>,
    /// Reported as `0` in success messages when absent; forwarded as-is in errors.
    pub request_id: Option<RequestId>,
    /// Reported as query id `0` in success messages when absent; forwarded as-is in errors.
    pub query_id: Option<ws::QueryId>,
    pub result: SubscriptionResult,
}
389
390fn num_rows_in(rows: &SubscriptionRows) -> usize {
391 match &rows.table_rows {
392 FormatSwitch::Bsatn(x) => x.num_rows(),
393 FormatSwitch::Json(x) => x.num_rows(),
394 }
395}
396
397fn subscription_data_rows(rows: &SubscriptionData) -> usize {
398 match &rows.data {
399 FormatSwitch::Bsatn(x) => x.num_rows(),
400 FormatSwitch::Json(x) => x.num_rows(),
401 }
402}
403
404impl SubscriptionMessage {
405 fn num_rows(&self) -> usize {
406 match &self.result {
407 SubscriptionResult::Subscribe(x) => num_rows_in(x),
408 SubscriptionResult::SubscribeMulti(x) => subscription_data_rows(x),
409 SubscriptionResult::UnsubscribeMulti(x) => subscription_data_rows(x),
410 SubscriptionResult::Unsubscribe(x) => num_rows_in(x),
411 _ => 0,
412 }
413 }
414}
415
416impl ToProtocol for SubscriptionMessage {
417 type Encoded = SwitchedServerMessage;
418 fn to_protocol(self, protocol: Protocol) -> Self::Encoded {
419 let request_id = self.request_id.unwrap_or(0);
420 let query_id = self.query_id.unwrap_or(ws::QueryId::new(0));
421 let total_host_execution_duration_micros = self.timer.map_or(0, |t| t.elapsed().as_micros() as u64);
422
423 match self.result {
424 SubscriptionResult::Subscribe(result) => {
425 protocol.assert_matches_format_switch(&result.table_rows);
426 match result.table_rows {
427 FormatSwitch::Bsatn(table_rows) => FormatSwitch::Bsatn(
428 ws::SubscribeApplied {
429 total_host_execution_duration_micros,
430 request_id,
431 query_id,
432 rows: ws::SubscribeRows {
433 table_id: result.table_id,
434 table_name: result.table_name,
435 table_rows,
436 },
437 }
438 .into(),
439 ),
440 FormatSwitch::Json(table_rows) => FormatSwitch::Json(
441 ws::SubscribeApplied {
442 total_host_execution_duration_micros,
443 request_id,
444 query_id,
445 rows: ws::SubscribeRows {
446 table_id: result.table_id,
447 table_name: result.table_name,
448 table_rows,
449 },
450 }
451 .into(),
452 ),
453 }
454 }
455 SubscriptionResult::Unsubscribe(result) => {
456 protocol.assert_matches_format_switch(&result.table_rows);
457 match result.table_rows {
458 FormatSwitch::Bsatn(table_rows) => FormatSwitch::Bsatn(
459 ws::UnsubscribeApplied {
460 total_host_execution_duration_micros,
461 request_id,
462 query_id,
463 rows: ws::SubscribeRows {
464 table_id: result.table_id,
465 table_name: result.table_name,
466 table_rows,
467 },
468 }
469 .into(),
470 ),
471 FormatSwitch::Json(table_rows) => FormatSwitch::Json(
472 ws::UnsubscribeApplied {
473 total_host_execution_duration_micros,
474 request_id,
475 query_id,
476 rows: ws::SubscribeRows {
477 table_id: result.table_id,
478 table_name: result.table_name,
479 table_rows,
480 },
481 }
482 .into(),
483 ),
484 }
485 }
486 SubscriptionResult::Error(error) => {
487 let msg = ws::SubscriptionError {
488 total_host_execution_duration_micros,
489 request_id: self.request_id, query_id: self.query_id.map(|x| x.id), table_id: error.table_id,
492 error: error.message,
493 };
494 match protocol {
495 Protocol::Binary => FormatSwitch::Bsatn(msg.into()),
496 Protocol::Text => FormatSwitch::Json(msg.into()),
497 }
498 }
499 SubscriptionResult::SubscribeMulti(result) => {
500 protocol.assert_matches_format_switch(&result.data);
501 match result.data {
502 FormatSwitch::Bsatn(data) => FormatSwitch::Bsatn(
503 ws::SubscribeMultiApplied {
504 total_host_execution_duration_micros,
505 request_id,
506 query_id,
507 update: data,
508 }
509 .into(),
510 ),
511 FormatSwitch::Json(data) => FormatSwitch::Json(
512 ws::SubscribeMultiApplied {
513 total_host_execution_duration_micros,
514 request_id,
515 query_id,
516 update: data,
517 }
518 .into(),
519 ),
520 }
521 }
522 SubscriptionResult::UnsubscribeMulti(result) => {
523 protocol.assert_matches_format_switch(&result.data);
524 match result.data {
525 FormatSwitch::Bsatn(data) => FormatSwitch::Bsatn(
526 ws::UnsubscribeMultiApplied {
527 total_host_execution_duration_micros,
528 request_id,
529 query_id,
530 update: data,
531 }
532 .into(),
533 ),
534 FormatSwitch::Json(data) => FormatSwitch::Json(
535 ws::UnsubscribeMultiApplied {
536 total_host_execution_duration_micros,
537 request_id,
538 query_id,
539 update: data,
540 }
541 .into(),
542 ),
543 }
544 }
545 }
546 }
547}
548
/// The result of a one-off query, already encoded in format `F`.
#[derive(Debug)]
pub struct OneOffQueryResponseMessage<F: WebsocketFormat> {
    /// Client-supplied id correlating this response with its request.
    pub message_id: Vec<u8>,
    /// Error text if the query failed.
    pub error: Option<String>,
    /// One entry per table returned by the query.
    pub results: Vec<OneOffTable<F>>,
    pub total_host_execution_duration: TimeDuration,
}
556
557impl<F: WebsocketFormat> OneOffQueryResponseMessage<F> {
558 fn num_rows(&self) -> usize {
559 self.results.iter().map(|table| table.rows.len()).sum()
560 }
561}
562
563impl ToProtocol for OneOffQueryResponseMessage<BsatnFormat> {
564 type Encoded = SwitchedServerMessage;
565
566 fn to_protocol(self, _: Protocol) -> Self::Encoded {
567 FormatSwitch::Bsatn(convert(self))
568 }
569}
570
571impl ToProtocol for OneOffQueryResponseMessage<JsonFormat> {
572 type Encoded = SwitchedServerMessage;
573 fn to_protocol(self, _: Protocol) -> Self::Encoded {
574 FormatSwitch::Json(convert(self))
575 }
576}
577
578fn convert<F: WebsocketFormat>(msg: OneOffQueryResponseMessage<F>) -> ws::ServerMessage<F> {
579 ws::ServerMessage::OneOffQueryResponse(ws::OneOffQueryResponse {
580 message_id: msg.message_id.into(),
581 error: msg.error.map(Into::into),
582 tables: msg.results.into_boxed_slice(),
583 total_host_execution_duration: msg.total_host_execution_duration,
584 })
585}