1use super::messages::{SubscriptionUpdateMessage, SwitchedServerMessage, ToProtocol, TransactionUpdateMessage};
2use super::{ClientConnection, DataMessage, Protocol};
3use crate::energy::EnergyQuanta;
4use crate::host::module_host::{EventStatus, ModuleEvent, ModuleFunctionCall};
5use crate::host::{ReducerArgs, ReducerId};
6use crate::identity::Identity;
7use crate::messages::websocket::{CallReducer, ClientMessage, OneOffQuery};
8use crate::worker_metrics::WORKER_METRICS;
9use spacetimedb_datastore::execution_context::WorkloadType;
10use spacetimedb_lib::de::serde::DeserializeWrapper;
11use spacetimedb_lib::identity::RequestId;
12use spacetimedb_lib::{bsatn, ConnectionId, Timestamp};
13use std::borrow::Cow;
14use std::sync::Arc;
15use std::time::{Duration, Instant};
16
/// Errors that can arise while handling one client `DataMessage` in [`handle`].
///
/// The decode variants describe a frame that could not be parsed; `Execution` wraps a
/// failure that happened while acting on a message that decoded successfully.
#[derive(thiserror::Error, Debug)]
pub enum MessageHandleError {
    /// A binary frame could not be decoded as BSATN into a `ClientMessage`.
    #[error(transparent)]
    BinaryDecode(#[from] bsatn::DecodeError),
    /// A text frame could not be parsed as JSON into a `ClientMessage`.
    #[error(transparent)]
    TextDecode(#[from] serde_json::Error),
    /// A base64 decoding failure.
    // NOTE(review): `handle` never produces this variant; presumably it is raised by
    // other callers of this error type — confirm before removing.
    #[error(transparent)]
    Base64Decode(#[from] base64::DecodeError),

    /// The decoded message was dispatched, but executing it failed.
    #[error(transparent)]
    Execution(#[from] MessageExecutionError),
}
29
30pub async fn handle(client: &ClientConnection, message: DataMessage, timer: Instant) -> Result<(), MessageHandleError> {
31 client.observe_websocket_request_message(&message);
32
33 let message = match message {
34 DataMessage::Text(text) => {
35 let DeserializeWrapper(message) =
37 serde_json::from_str::<DeserializeWrapper<ClientMessage<Cow<str>>>>(&text)?;
38 message.map_args(|s| {
39 ReducerArgs::Json(match s {
40 Cow::Borrowed(s) => text.slice_ref(s),
41 Cow::Owned(string) => string.into(),
42 })
43 })
44 }
45 DataMessage::Binary(message_buf) => bsatn::from_slice::<ClientMessage<&[u8]>>(&message_buf)?
46 .map_args(|b| ReducerArgs::Bsatn(message_buf.slice_ref(b))),
47 };
48
49 let mod_info = client.module.info();
50 let mod_metrics = &mod_info.metrics;
51 let database_identity = mod_info.database_identity;
52 let db = &client.module.replica_ctx().relational_db;
53 let record_metrics = |wl| {
54 move |metrics| {
55 if let Some(metrics) = metrics {
56 db.exec_counters_for(wl).record(&metrics);
57 }
58 }
59 };
60 let sub_metrics = record_metrics(WorkloadType::Subscribe);
61 let unsub_metrics = record_metrics(WorkloadType::Unsubscribe);
62
63 let res = match message {
64 ClientMessage::CallReducer(CallReducer {
65 ref reducer,
66 args,
67 request_id,
68 flags,
69 }) => {
70 let res = client.call_reducer(reducer, args, request_id, timer, flags).await;
71 WORKER_METRICS
72 .request_round_trip
73 .with_label_values(&WorkloadType::Reducer, &database_identity, reducer)
74 .observe(timer.elapsed().as_secs_f64());
75 res.map(drop).map_err(|e| {
76 (
77 Some(reducer),
78 mod_info.module_def.reducer_full(&**reducer).map(|(id, _)| id),
79 e.into(),
80 )
81 })
82 }
83 ClientMessage::SubscribeMulti(subscription) => {
84 let res = client.subscribe_multi(subscription, timer).await.map(sub_metrics);
85 mod_metrics
86 .request_round_trip_subscribe
87 .observe(timer.elapsed().as_secs_f64());
88 res.map_err(|e| (None, None, e.into()))
89 }
90 ClientMessage::UnsubscribeMulti(request) => {
91 let res = client.unsubscribe_multi(request, timer).await.map(unsub_metrics);
92 mod_metrics
93 .request_round_trip_unsubscribe
94 .observe(timer.elapsed().as_secs_f64());
95 res.map_err(|e| (None, None, e.into()))
96 }
97 ClientMessage::SubscribeSingle(subscription) => {
98 let res = client.subscribe_single(subscription, timer).await.map(sub_metrics);
99 mod_metrics
100 .request_round_trip_subscribe
101 .observe(timer.elapsed().as_secs_f64());
102 res.map_err(|e| (None, None, e.into()))
103 }
104 ClientMessage::Unsubscribe(request) => {
105 let res = client.unsubscribe(request, timer).await.map(unsub_metrics);
106 mod_metrics
107 .request_round_trip_unsubscribe
108 .observe(timer.elapsed().as_secs_f64());
109 res.map_err(|e| (None, None, e.into()))
110 }
111 ClientMessage::Subscribe(subscription) => {
112 let res = client.subscribe(subscription, timer).await.map(Some).map(sub_metrics);
113 mod_metrics
114 .request_round_trip_subscribe
115 .observe(timer.elapsed().as_secs_f64());
116 res.map_err(|e| (None, None, e.into()))
117 }
118 ClientMessage::OneOffQuery(OneOffQuery {
119 query_string: query,
120 message_id,
121 }) => {
122 let res = match client.config.protocol {
123 Protocol::Binary => client.one_off_query_bsatn(&query, &message_id, timer).await,
124 Protocol::Text => client.one_off_query_json(&query, &message_id, timer).await,
125 };
126 mod_metrics
127 .request_round_trip_sql
128 .observe(timer.elapsed().as_secs_f64());
129 res.map_err(|err| (None, None, err))
130 }
131 };
132 res.map_err(|(reducer, reducer_id, err)| MessageExecutionError {
133 reducer: reducer.cloned(),
134 reducer_id,
135 caller_identity: client.id.identity,
136 caller_connection_id: Some(client.id.connection_id),
137 err,
138 })?;
139
140 Ok(())
141}
142
/// A failure that occurred while executing an already-decoded client message.
///
/// It records who sent the message and, for reducer calls, which reducer was targeted.
/// The Display output includes the reducer name (Debug-formatted) and the underlying
/// error in alternate (`{:#}`) form.
#[derive(thiserror::Error, Debug)]
#[error("error executing message (reducer: {reducer:?}) (err: {err:#})")]
pub struct MessageExecutionError {
    /// Name of the reducer the client asked to call. `handle` sets this only for reducer calls.
    pub reducer: Option<Box<str>>,
    /// Id of that reducer. `handle` sets this only when the module defines a reducer
    /// with the requested name.
    pub reducer_id: Option<ReducerId>,
    /// Identity of the client that sent the message.
    pub caller_identity: Identity,
    /// Connection the message arrived on. `handle` always populates this.
    pub caller_connection_id: Option<ConnectionId>,
    /// The underlying error, exposed as this error's `source()`.
    #[source]
    pub err: anyhow::Error,
}
153
154impl MessageExecutionError {
155 fn into_event(self) -> ModuleEvent {
156 ModuleEvent {
157 timestamp: Timestamp::now(),
158 caller_identity: self.caller_identity,
159 caller_connection_id: self.caller_connection_id,
160 function_call: ModuleFunctionCall {
161 reducer: self.reducer.unwrap_or_else(|| "<none>".into()).into(),
162 reducer_id: self.reducer_id.unwrap_or(u32::MAX.into()),
163 args: Default::default(),
164 },
165 status: EventStatus::Failed(format!("{:#}", self.err)),
166 energy_quanta_used: EnergyQuanta::ZERO,
167 host_execution_duration: Duration::ZERO,
168 request_id: Some(RequestId::default()),
169 timer: None,
170 }
171 }
172}
173
174impl ToProtocol for MessageExecutionError {
175 type Encoded = SwitchedServerMessage;
176 fn to_protocol(self, protocol: super::Protocol) -> Self::Encoded {
177 TransactionUpdateMessage {
178 event: Some(Arc::new(self.into_event())),
179 database_update: SubscriptionUpdateMessage::default_for_protocol(protocol, None),
180 }
181 .to_protocol(protocol)
182 }
183}