1use super::{ClientConfig, DataMessage, Protocol};
2use crate::execution_context::WorkloadType;
3use crate::host::module_host::{EventStatus, ModuleEvent};
4use crate::host::ArgsTuple;
5use crate::messages::websocket as ws;
6use derive_more::From;
7use spacetimedb_client_api_messages::websocket::{
8 BsatnFormat, Compression, FormatSwitch, JsonFormat, OneOffTable, RowListLen, WebsocketFormat,
9 SERVER_MSG_COMPRESSION_TAG_BROTLI, SERVER_MSG_COMPRESSION_TAG_GZIP, SERVER_MSG_COMPRESSION_TAG_NONE,
10};
11use spacetimedb_lib::identity::RequestId;
12use spacetimedb_lib::ser::serde::SerializeWrapper;
13use spacetimedb_lib::{ConnectionId, TimeDuration};
14use spacetimedb_primitives::TableId;
15use spacetimedb_sats::bsatn;
16use std::sync::Arc;
17use std::time::Instant;
18
19pub trait ToProtocol {
22 type Encoded;
23 fn to_protocol(self, protocol: Protocol) -> Self::Encoded;
25}
26
27pub(super) type SwitchedServerMessage = FormatSwitch<ws::ServerMessage<BsatnFormat>, ws::ServerMessage<JsonFormat>>;
28pub(super) type SwitchedDbUpdate = FormatSwitch<ws::DatabaseUpdate<BsatnFormat>, ws::DatabaseUpdate<JsonFormat>>;
29
30pub fn serialize(msg: impl ToProtocol<Encoded = SwitchedServerMessage>, config: ClientConfig) -> DataMessage {
35 match msg.to_protocol(config.protocol) {
38 FormatSwitch::Json(msg) => serde_json::to_string(&SerializeWrapper::new(msg)).unwrap().into(),
39 FormatSwitch::Bsatn(msg) => {
40 let mut msg_bytes = vec![SERVER_MSG_COMPRESSION_TAG_NONE];
42 bsatn::to_writer(&mut msg_bytes, &msg).unwrap();
43
44 let srv_msg = &msg_bytes[1..];
46 let msg_bytes = match ws::decide_compression(srv_msg.len(), config.compression) {
47 Compression::None => msg_bytes,
48 Compression::Brotli => {
49 let mut out = vec![SERVER_MSG_COMPRESSION_TAG_BROTLI];
50 ws::brotli_compress(srv_msg, &mut out);
51 out
52 }
53 Compression::Gzip => {
54 let mut out = vec![SERVER_MSG_COMPRESSION_TAG_GZIP];
55 ws::gzip_compress(srv_msg, &mut out);
56 out
57 }
58 };
59 msg_bytes.into()
60 }
61 }
62}
63
64#[derive(Debug, From)]
65pub enum SerializableMessage {
66 QueryBinary(OneOffQueryResponseMessage<BsatnFormat>),
67 QueryText(OneOffQueryResponseMessage<JsonFormat>),
68 Identity(IdentityTokenMessage),
69 Subscribe(SubscriptionUpdateMessage),
70 Subscription(SubscriptionMessage),
71 TxUpdate(TransactionUpdateMessage),
72}
73
74impl SerializableMessage {
75 pub fn num_rows(&self) -> Option<usize> {
76 match self {
77 Self::QueryBinary(msg) => Some(msg.num_rows()),
78 Self::QueryText(msg) => Some(msg.num_rows()),
79 Self::Subscribe(msg) => Some(msg.num_rows()),
80 Self::Subscription(msg) => Some(msg.num_rows()),
81 Self::TxUpdate(msg) => Some(msg.num_rows()),
82 Self::Identity(_) => None,
83 }
84 }
85
86 pub fn workload(&self) -> Option<WorkloadType> {
87 match self {
88 Self::QueryBinary(_) | Self::QueryText(_) => Some(WorkloadType::Sql),
89 Self::Subscribe(_) => Some(WorkloadType::Subscribe),
90 Self::Subscription(msg) => match &msg.result {
91 SubscriptionResult::Subscribe(_) => Some(WorkloadType::Subscribe),
92 SubscriptionResult::Unsubscribe(_) => Some(WorkloadType::Unsubscribe),
93 SubscriptionResult::Error(_) => None,
94 SubscriptionResult::SubscribeMulti(_) => Some(WorkloadType::Subscribe),
95 SubscriptionResult::UnsubscribeMulti(_) => Some(WorkloadType::Unsubscribe),
96 },
97 Self::TxUpdate(_) => Some(WorkloadType::Update),
98 Self::Identity(_) => None,
99 }
100 }
101}
102
103impl ToProtocol for SerializableMessage {
104 type Encoded = SwitchedServerMessage;
105 fn to_protocol(self, protocol: Protocol) -> Self::Encoded {
106 match self {
107 SerializableMessage::QueryBinary(msg) => msg.to_protocol(protocol),
108 SerializableMessage::QueryText(msg) => msg.to_protocol(protocol),
109 SerializableMessage::Identity(msg) => msg.to_protocol(protocol),
110 SerializableMessage::Subscribe(msg) => msg.to_protocol(protocol),
111 SerializableMessage::TxUpdate(msg) => msg.to_protocol(protocol),
112 SerializableMessage::Subscription(msg) => msg.to_protocol(protocol),
113 }
114 }
115}
116
117pub type IdentityTokenMessage = ws::IdentityToken;
118
119impl ToProtocol for IdentityTokenMessage {
120 type Encoded = SwitchedServerMessage;
121 fn to_protocol(self, protocol: Protocol) -> Self::Encoded {
122 match protocol {
123 Protocol::Text => FormatSwitch::Json(ws::ServerMessage::IdentityToken(self)),
124 Protocol::Binary => FormatSwitch::Bsatn(ws::ServerMessage::IdentityToken(self)),
125 }
126 }
127}
128
129#[derive(Debug)]
130pub struct TransactionUpdateMessage {
131 pub event: Option<Arc<ModuleEvent>>,
134 pub database_update: SubscriptionUpdateMessage,
135}
136
137impl TransactionUpdateMessage {
138 fn num_rows(&self) -> usize {
139 self.database_update.num_rows()
140 }
141}
142
143impl ToProtocol for TransactionUpdateMessage {
144 type Encoded = SwitchedServerMessage;
145 fn to_protocol(self, protocol: Protocol) -> Self::Encoded {
146 fn convert<F: WebsocketFormat>(
147 event: Option<Arc<ModuleEvent>>,
148 request_id: u32,
149 update: ws::DatabaseUpdate<F>,
150 conv_args: impl FnOnce(&ArgsTuple) -> F::Single,
151 ) -> ws::ServerMessage<F> {
152 let Some(event) = event else {
153 return ws::ServerMessage::TransactionUpdateLight(ws::TransactionUpdateLight { request_id, update });
154 };
155
156 let status = match &event.status {
157 EventStatus::Committed(_) => ws::UpdateStatus::Committed(update),
158 EventStatus::Failed(errmsg) => ws::UpdateStatus::Failed(errmsg.clone().into()),
159 EventStatus::OutOfEnergy => ws::UpdateStatus::OutOfEnergy,
160 };
161
162 let args = conv_args(&event.function_call.args);
163
164 let tx_update = ws::TransactionUpdate {
165 timestamp: event.timestamp,
166 status,
167 caller_identity: event.caller_identity,
168 reducer_call: ws::ReducerCallInfo {
169 reducer_name: event.function_call.reducer.to_owned().into(),
170 reducer_id: event.function_call.reducer_id.into(),
171 args,
172 request_id,
173 },
174 energy_quanta_used: event.energy_quanta_used,
175 total_host_execution_duration: event.host_execution_duration.into(),
176 caller_connection_id: event.caller_connection_id.unwrap_or(ConnectionId::ZERO),
177 };
178
179 ws::ServerMessage::TransactionUpdate(tx_update)
180 }
181
182 let TransactionUpdateMessage { event, database_update } = self;
183 let update = database_update.database_update;
184 protocol.assert_matches_format_switch(&update);
185 let request_id = database_update.request_id.unwrap_or(0);
186 match update {
187 FormatSwitch::Bsatn(update) => FormatSwitch::Bsatn(convert(event, request_id, update, |args| {
188 Vec::from(args.get_bsatn().clone()).into()
189 })),
190 FormatSwitch::Json(update) => {
191 FormatSwitch::Json(convert(event, request_id, update, |args| args.get_json().clone()))
192 }
193 }
194 }
195}
196
197#[derive(Debug, Clone)]
198pub struct SubscriptionUpdateMessage {
199 pub database_update: SwitchedDbUpdate,
200 pub request_id: Option<RequestId>,
201 pub timer: Option<Instant>,
202}
203
204impl SubscriptionUpdateMessage {
205 pub(crate) fn default_for_protocol(protocol: Protocol, request_id: Option<RequestId>) -> Self {
206 Self {
207 database_update: match protocol {
208 Protocol::Text => FormatSwitch::Json(<_>::default()),
209 Protocol::Binary => FormatSwitch::Bsatn(<_>::default()),
210 },
211 request_id,
212 timer: None,
213 }
214 }
215
216 pub(crate) fn from_event_and_update(event: &ModuleEvent, update: SwitchedDbUpdate) -> Self {
217 Self {
218 database_update: update,
219 request_id: event.request_id,
220 timer: event.timer,
221 }
222 }
223
224 fn num_rows(&self) -> usize {
225 match &self.database_update {
226 FormatSwitch::Bsatn(x) => x.num_rows(),
227 FormatSwitch::Json(x) => x.num_rows(),
228 }
229 }
230}
231
232impl ToProtocol for SubscriptionUpdateMessage {
233 type Encoded = SwitchedServerMessage;
234 fn to_protocol(self, protocol: Protocol) -> Self::Encoded {
235 let request_id = self.request_id.unwrap_or(0);
236 let total_host_execution_duration = self.timer.map_or(TimeDuration::ZERO, |t| t.elapsed().into());
237
238 protocol.assert_matches_format_switch(&self.database_update);
239 match self.database_update {
240 FormatSwitch::Bsatn(database_update) => {
241 FormatSwitch::Bsatn(ws::ServerMessage::InitialSubscription(ws::InitialSubscription {
242 database_update,
243 request_id,
244 total_host_execution_duration,
245 }))
246 }
247 FormatSwitch::Json(database_update) => {
248 FormatSwitch::Json(ws::ServerMessage::InitialSubscription(ws::InitialSubscription {
249 database_update,
250 request_id,
251 total_host_execution_duration,
252 }))
253 }
254 }
255 }
256}
257
258#[derive(Debug, Clone)]
259pub struct SubscriptionData {
260 pub data: FormatSwitch<ws::DatabaseUpdate<BsatnFormat>, ws::DatabaseUpdate<JsonFormat>>,
261}
262
263#[derive(Debug, Clone)]
264pub struct SubscriptionRows {
265 pub table_id: TableId,
266 pub table_name: Box<str>,
267 pub table_rows: FormatSwitch<ws::TableUpdate<BsatnFormat>, ws::TableUpdate<JsonFormat>>,
268}
269
270#[derive(Debug, Clone)]
271pub struct SubscriptionError {
272 pub table_id: Option<TableId>,
273 pub message: Box<str>,
274}
275
276#[derive(Debug, Clone)]
277pub enum SubscriptionResult {
278 Subscribe(SubscriptionRows),
279 Unsubscribe(SubscriptionRows),
280 Error(SubscriptionError),
281 SubscribeMulti(SubscriptionData),
282 UnsubscribeMulti(SubscriptionData),
283}
284
285#[derive(Debug, Clone)]
286pub struct SubscriptionMessage {
287 pub timer: Option<Instant>,
288 pub request_id: Option<RequestId>,
289 pub query_id: Option<ws::QueryId>,
290 pub result: SubscriptionResult,
291}
292
293fn num_rows_in(rows: &SubscriptionRows) -> usize {
294 match &rows.table_rows {
295 FormatSwitch::Bsatn(x) => x.num_rows(),
296 FormatSwitch::Json(x) => x.num_rows(),
297 }
298}
299
300fn subscription_data_rows(rows: &SubscriptionData) -> usize {
301 match &rows.data {
302 FormatSwitch::Bsatn(x) => x.num_rows(),
303 FormatSwitch::Json(x) => x.num_rows(),
304 }
305}
306
307impl SubscriptionMessage {
308 fn num_rows(&self) -> usize {
309 match &self.result {
310 SubscriptionResult::Subscribe(x) => num_rows_in(x),
311 SubscriptionResult::SubscribeMulti(x) => subscription_data_rows(x),
312 SubscriptionResult::UnsubscribeMulti(x) => subscription_data_rows(x),
313 SubscriptionResult::Unsubscribe(x) => num_rows_in(x),
314 _ => 0,
315 }
316 }
317}
318
319impl ToProtocol for SubscriptionMessage {
320 type Encoded = SwitchedServerMessage;
321 fn to_protocol(self, protocol: Protocol) -> Self::Encoded {
322 let request_id = self.request_id.unwrap_or(0);
323 let query_id = self.query_id.unwrap_or(ws::QueryId::new(0));
324 let total_host_execution_duration_micros = self.timer.map_or(0, |t| t.elapsed().as_micros() as u64);
325
326 match self.result {
327 SubscriptionResult::Subscribe(result) => {
328 protocol.assert_matches_format_switch(&result.table_rows);
329 match result.table_rows {
330 FormatSwitch::Bsatn(table_rows) => FormatSwitch::Bsatn(
331 ws::SubscribeApplied {
332 total_host_execution_duration_micros,
333 request_id,
334 query_id,
335 rows: ws::SubscribeRows {
336 table_id: result.table_id,
337 table_name: result.table_name,
338 table_rows,
339 },
340 }
341 .into(),
342 ),
343 FormatSwitch::Json(table_rows) => FormatSwitch::Json(
344 ws::SubscribeApplied {
345 total_host_execution_duration_micros,
346 request_id,
347 query_id,
348 rows: ws::SubscribeRows {
349 table_id: result.table_id,
350 table_name: result.table_name,
351 table_rows,
352 },
353 }
354 .into(),
355 ),
356 }
357 }
358 SubscriptionResult::Unsubscribe(result) => {
359 protocol.assert_matches_format_switch(&result.table_rows);
360 match result.table_rows {
361 FormatSwitch::Bsatn(table_rows) => FormatSwitch::Bsatn(
362 ws::UnsubscribeApplied {
363 total_host_execution_duration_micros,
364 request_id,
365 query_id,
366 rows: ws::SubscribeRows {
367 table_id: result.table_id,
368 table_name: result.table_name,
369 table_rows,
370 },
371 }
372 .into(),
373 ),
374 FormatSwitch::Json(table_rows) => FormatSwitch::Json(
375 ws::UnsubscribeApplied {
376 total_host_execution_duration_micros,
377 request_id,
378 query_id,
379 rows: ws::SubscribeRows {
380 table_id: result.table_id,
381 table_name: result.table_name,
382 table_rows,
383 },
384 }
385 .into(),
386 ),
387 }
388 }
389 SubscriptionResult::Error(error) => {
390 let msg = ws::SubscriptionError {
391 total_host_execution_duration_micros,
392 request_id: self.request_id, query_id: self.query_id.map(|x| x.id), table_id: error.table_id,
395 error: error.message,
396 };
397 match protocol {
398 Protocol::Binary => FormatSwitch::Bsatn(msg.into()),
399 Protocol::Text => FormatSwitch::Json(msg.into()),
400 }
401 }
402 SubscriptionResult::SubscribeMulti(result) => {
403 protocol.assert_matches_format_switch(&result.data);
404 match result.data {
405 FormatSwitch::Bsatn(data) => FormatSwitch::Bsatn(
406 ws::SubscribeMultiApplied {
407 total_host_execution_duration_micros,
408 request_id,
409 query_id,
410 update: data,
411 }
412 .into(),
413 ),
414 FormatSwitch::Json(data) => FormatSwitch::Json(
415 ws::SubscribeMultiApplied {
416 total_host_execution_duration_micros,
417 request_id,
418 query_id,
419 update: data,
420 }
421 .into(),
422 ),
423 }
424 }
425 SubscriptionResult::UnsubscribeMulti(result) => {
426 protocol.assert_matches_format_switch(&result.data);
427 match result.data {
428 FormatSwitch::Bsatn(data) => FormatSwitch::Bsatn(
429 ws::UnsubscribeMultiApplied {
430 total_host_execution_duration_micros,
431 request_id,
432 query_id,
433 update: data,
434 }
435 .into(),
436 ),
437 FormatSwitch::Json(data) => FormatSwitch::Json(
438 ws::UnsubscribeMultiApplied {
439 total_host_execution_duration_micros,
440 request_id,
441 query_id,
442 update: data,
443 }
444 .into(),
445 ),
446 }
447 }
448 }
449 }
450}
451
452#[derive(Debug)]
453pub struct OneOffQueryResponseMessage<F: WebsocketFormat> {
454 pub message_id: Vec<u8>,
455 pub error: Option<String>,
456 pub results: Vec<OneOffTable<F>>,
457 pub total_host_execution_duration: TimeDuration,
458}
459
460impl<F: WebsocketFormat> OneOffQueryResponseMessage<F> {
461 fn num_rows(&self) -> usize {
462 self.results.iter().map(|table| table.rows.len()).sum()
463 }
464}
465
466impl ToProtocol for OneOffQueryResponseMessage<BsatnFormat> {
467 type Encoded = SwitchedServerMessage;
468
469 fn to_protocol(self, _: Protocol) -> Self::Encoded {
470 FormatSwitch::Bsatn(convert(self))
471 }
472}
473
474impl ToProtocol for OneOffQueryResponseMessage<JsonFormat> {
475 type Encoded = SwitchedServerMessage;
476 fn to_protocol(self, _: Protocol) -> Self::Encoded {
477 FormatSwitch::Json(convert(self))
478 }
479}
480
481fn convert<F: WebsocketFormat>(msg: OneOffQueryResponseMessage<F>) -> ws::ServerMessage<F> {
482 ws::ServerMessage::OneOffQueryResponse(ws::OneOffQueryResponse {
483 message_id: msg.message_id.into(),
484 error: msg.error.map(Into::into),
485 tables: msg.results.into_boxed_slice(),
486 total_host_execution_duration: msg.total_host_execution_duration,
487 })
488}