spacetimedb_cli/subcommands/
subscribe.rs1use anyhow::Context;
2use clap::{value_parser, Arg, ArgAction, ArgMatches};
3use futures::{Sink, SinkExt, TryStream, TryStreamExt};
4use http::header;
5use http::uri::Scheme;
6use serde_json::Value;
7use spacetimedb_client_api_messages::websocket::{self as ws, JsonFormat};
8use spacetimedb_data_structures::map::HashMap;
9use spacetimedb_lib::db::raw_def::v9::RawModuleDefV9;
10use spacetimedb_lib::de::serde::{DeserializeWrapper, SeedWrapper};
11use spacetimedb_lib::ser::serde::SerializeWrapper;
12use std::io;
13use std::time::Duration;
14use thiserror::Error;
15use tokio::io::AsyncWriteExt;
16use tokio_tungstenite::tungstenite::client::IntoClientRequest;
17use tokio_tungstenite::tungstenite::{Error as WsError, Message as WsMessage};
18
19use crate::api::ClientApi;
20use crate::common_args;
21use crate::sql::parse_req;
22use crate::util::UNSTABLE_WARNING;
23use crate::Config;
24
25pub fn cli() -> clap::Command {
26 clap::Command::new("subscribe")
27 .about(format!("Subscribe to SQL queries on the database. {UNSTABLE_WARNING}"))
28 .arg(
29 Arg::new("database")
30 .required(true)
31 .help("The name or identity of the database you would like to query"),
32 )
33 .arg(
34 Arg::new("query")
35 .required(true)
36 .num_args(1..)
37 .help("The SQL query to execute"),
38 )
39 .arg(
40 Arg::new("num-updates")
41 .required(false)
42 .long("num-updates")
43 .short('n')
44 .action(ArgAction::Set)
45 .value_parser(value_parser!(u32))
46 .help("The number of subscription updates to receive before exiting"),
47 )
48 .arg(
49 Arg::new("timeout")
50 .required(false)
51 .short('t')
52 .long("timeout")
53 .action(ArgAction::Set)
54 .value_parser(value_parser!(u32))
55 .help(
56 "The timeout, in seconds, after which to disconnect and stop receiving \
57 subscription messages. If `-n` is specified, it will stop after whichever
58 one comes first.",
59 ),
60 )
61 .arg(
62 Arg::new("print_initial_update")
63 .required(false)
64 .long("print-initial-update")
65 .action(ArgAction::SetTrue)
66 .help("Print the initial update for the queries."),
67 )
68 .arg(common_args::anonymous())
69 .arg(common_args::yes())
70 .arg(common_args::server().help("The nickname, host name or URL of the server hosting the database"))
71}
72
73fn parse_msg_json(msg: &WsMessage) -> Option<ws::ServerMessage<JsonFormat>> {
74 let WsMessage::Text(msg) = msg else { return None };
75 serde_json::from_str::<DeserializeWrapper<ws::ServerMessage<JsonFormat>>>(msg)
76 .inspect_err(|e| eprintln!("couldn't parse message from server: {e}"))
77 .map(|wrapper| wrapper.0)
78 .ok()
79}
80
81fn reformat_update<'a>(
82 msg: &'a ws::DatabaseUpdate<JsonFormat>,
83 schema: &RawModuleDefV9,
84) -> anyhow::Result<HashMap<&'a str, SubscriptionTable>> {
85 msg.tables
86 .iter()
87 .map(|upd| {
88 let table_schema = schema
89 .tables
90 .iter()
91 .find(|tbl| tbl.name == upd.table_name)
92 .context("table not found in schema")?;
93 let table_ty = schema.typespace.resolve(table_schema.product_type_ref);
94
95 let reformat_row = |row: &str| -> anyhow::Result<Value> {
96 let row = serde_json::from_str::<Value>(row)?;
98 let row = serde::de::DeserializeSeed::deserialize(SeedWrapper(table_ty), row)?;
99 let row = table_ty.with_value(&row);
100 let row = serde_json::to_value(SerializeWrapper::from_ref(&row))?;
101 Ok(row)
102 };
103
104 let mut deletes = Vec::new();
105 let mut inserts = Vec::new();
106 for upd in &upd.updates {
107 for s in &upd.deletes {
108 deletes.push(reformat_row(s)?);
109 }
110 for s in &upd.inserts {
111 inserts.push(reformat_row(s)?);
112 }
113 }
114
115 Ok((&*upd.table_name, SubscriptionTable { deletes, inserts }))
116 })
117 .collect()
118}
119
120#[derive(serde::Serialize, Debug)]
121struct SubscriptionTable {
122 deletes: Vec<serde_json::Value>,
123 inserts: Vec<serde_json::Value>,
124}
125
126pub async fn exec(config: Config, args: &ArgMatches) -> Result<(), anyhow::Error> {
127 eprintln!("{UNSTABLE_WARNING}\n");
128
129 let queries = args.get_many::<String>("query").unwrap();
130 let num = args.get_one::<u32>("num-updates").copied();
131 let timeout = args.get_one::<u32>("timeout").copied();
132 let print_initial_update = args.get_flag("print_initial_update");
133
134 let conn = parse_req(config, args).await?;
135 let api = ClientApi::new(conn);
136 let module_def = api.module_def().await?;
137
138 let mut uri = http::Uri::try_from(api.con.db_uri("subscribe"))?.into_parts();
140 uri.scheme = uri.scheme.map(|s| {
141 if s == Scheme::HTTP {
142 "ws".parse().unwrap()
143 } else if s == Scheme::HTTPS {
144 "wss".parse().unwrap()
145 } else {
146 s
147 }
148 });
149
150 let mut req = http::Uri::from_parts(uri)?.into_client_request()?;
152 req.headers_mut().insert(
153 header::SEC_WEBSOCKET_PROTOCOL,
154 http::HeaderValue::from_static(ws::TEXT_PROTOCOL),
155 );
156 if let Some(auth_header) = api.con.auth_header.to_header() {
158 req.headers_mut().insert(header::AUTHORIZATION, auth_header);
159 }
160 let mut ws = tokio_tungstenite::connect_async(req).await.map(|(ws, _)| ws)?;
161
162 let task = async {
163 subscribe(&mut ws, queries.cloned().map(Into::into).collect()).await?;
164 await_initial_update(&mut ws, print_initial_update.then_some(&module_def)).await?;
165 consume_transaction_updates(&mut ws, num, &module_def).await
166 };
167
168 let res = if let Some(timeout) = timeout {
169 let timeout = Duration::from_secs(timeout.into());
170 match tokio::time::timeout(timeout, task).await {
171 Ok(res) => res,
172 Err(_elapsed) => {
173 eprintln!("timed out after {}s", timeout.as_secs());
174 Ok(())
175 }
176 }
177 } else {
178 task.await
179 };
180
181 let _ = ws.close(None).await;
188 res.or_else(|e| {
191 if e.is_server_closed_connection() {
192 Ok(())
193 } else {
194 Err(e)
195 }
196 })
197 .map_err(anyhow::Error::from)
198}
199
200#[derive(Debug, Error)]
201enum Error {
202 #[error("error sending subscription queries")]
203 Subscribe {
204 #[source]
205 source: WsError,
206 },
207 #[error("protocol error: {details}")]
208 Protocol { details: &'static str },
209 #[error("websocket error: {source}")]
210 Websocket {
211 #[source]
212 source: WsError,
213 },
214 #[error("encountered failed transaction: {reason}")]
215 TransactionFailure { reason: Box<str> },
216 #[error("error formatting response: {source:#}")]
217 Reformat {
218 #[source]
219 source: anyhow::Error,
220 },
221 #[error(transparent)]
222 Serde(#[from] serde_json::Error),
223 #[error(transparent)]
224 Io(#[from] io::Error),
225}
226
227impl Error {
228 fn is_server_closed_connection(&self) -> bool {
229 matches!(
230 self,
231 Self::Websocket {
232 source: WsError::ConnectionClosed
233 }
234 )
235 }
236}
237
238async fn subscribe<S>(ws: &mut S, query_strings: Box<[Box<str>]>) -> Result<(), Error>
240where
241 S: Sink<WsMessage, Error = WsError> + Unpin,
242{
243 let msg = serde_json::to_string(&SerializeWrapper::new(ws::ClientMessage::<()>::Subscribe(
244 ws::Subscribe {
245 query_strings,
246 request_id: 0,
247 },
248 )))
249 .unwrap();
250 ws.send(msg.into()).await.map_err(|source| Error::Subscribe { source })
251}
252
253async fn await_initial_update<S>(ws: &mut S, module_def: Option<&RawModuleDefV9>) -> Result<(), Error>
256where
257 S: TryStream<Ok = WsMessage, Error = WsError> + Unpin,
258{
259 const RECV_TX_UPDATE: &str = "protocol error: received transaction update before initial subscription update";
260
261 while let Some(msg) = ws.try_next().await.map_err(|source| Error::Websocket { source })? {
262 let Some(msg) = parse_msg_json(&msg) else { continue };
263 match msg {
264 ws::ServerMessage::InitialSubscription(sub) => {
265 if let Some(module_def) = module_def {
266 let output = format_output_json(&sub.database_update, module_def)?;
267 tokio::io::stdout().write_all(output.as_bytes()).await?
268 }
269 break;
270 }
271 ws::ServerMessage::TransactionUpdate(ws::TransactionUpdate { status, .. }) => {
272 return Err(match status {
273 ws::UpdateStatus::Failed(msg) => Error::TransactionFailure { reason: msg },
274 _ => Error::Protocol {
275 details: RECV_TX_UPDATE,
276 },
277 })
278 }
279 ws::ServerMessage::TransactionUpdateLight(ws::TransactionUpdateLight { .. }) => {
280 return Err(Error::Protocol {
281 details: RECV_TX_UPDATE,
282 })
283 }
284 _ => continue,
285 }
286 }
287
288 Ok(())
289}
290
291async fn consume_transaction_updates<S>(ws: &mut S, num: Option<u32>, module_def: &RawModuleDefV9) -> Result<(), Error>
294where
295 S: TryStream<Ok = WsMessage, Error = WsError> + Unpin,
296{
297 let mut stdout = tokio::io::stdout();
298 let mut num_received = 0;
299 loop {
300 if num.is_some_and(|n| num_received >= n) {
301 return Ok(());
302 }
303 let Some(msg) = ws.try_next().await.map_err(|source| Error::Websocket { source })? else {
304 eprintln!("disconnected by server");
305 return Err(Error::Websocket {
306 source: WsError::ConnectionClosed,
307 });
308 };
309
310 let Some(msg) = parse_msg_json(&msg) else { continue };
311 match msg {
312 ws::ServerMessage::InitialSubscription(_) => {
313 return Err(Error::Protocol {
314 details: "received a second initial subscription update",
315 })
316 }
317 ws::ServerMessage::TransactionUpdateLight(ws::TransactionUpdateLight { update, .. })
318 | ws::ServerMessage::TransactionUpdate(ws::TransactionUpdate {
319 status: ws::UpdateStatus::Committed(update),
320 ..
321 }) => {
322 let output = format_output_json(&update, module_def)?;
323 stdout.write_all(output.as_bytes()).await?;
324 num_received += 1;
325 }
326 _ => continue,
327 }
328 }
329}
330
331fn format_output_json(msg: &ws::DatabaseUpdate<JsonFormat>, schema: &RawModuleDefV9) -> Result<String, Error> {
332 let formatted = reformat_update(msg, schema).map_err(|source| Error::Reformat { source })?;
333 let output = serde_json::to_string(&formatted)? + "\n";
334
335 Ok(output)
336}