spacetimedb_cli/subcommands/
subscribe.rs1use anyhow::Context;
2use clap::{value_parser, Arg, ArgAction, ArgMatches};
3use futures::{Sink, SinkExt, TryStream, TryStreamExt};
4use http::header;
5use http::uri::Scheme;
6use serde_json::Value;
7use spacetimedb_client_api_messages::websocket::{self as ws, JsonFormat};
8use spacetimedb_data_structures::map::HashMap;
9use spacetimedb_lib::db::raw_def::v9::RawModuleDefV9;
10use spacetimedb_lib::de::serde::{DeserializeWrapper, SeedWrapper};
11use spacetimedb_lib::ser::serde::SerializeWrapper;
12use std::time::Duration;
13use tokio::io::AsyncWriteExt;
14use tokio_tungstenite::tungstenite::client::IntoClientRequest;
15use tokio_tungstenite::tungstenite::Message as WsMessage;
16
17use crate::api::ClientApi;
18use crate::common_args;
19use crate::sql::parse_req;
20use crate::util::UNSTABLE_WARNING;
21use crate::Config;
22
23pub fn cli() -> clap::Command {
24 clap::Command::new("subscribe")
25 .about(format!(
26 "Subscribe to SQL queries on the database. {}",
27 UNSTABLE_WARNING
28 ))
29 .arg(
30 Arg::new("database")
31 .required(true)
32 .help("The name or identity of the database you would like to query"),
33 )
34 .arg(
35 Arg::new("query")
36 .required(true)
37 .num_args(1..)
38 .help("The SQL query to execute"),
39 )
40 .arg(
41 Arg::new("num-updates")
42 .required(false)
43 .long("num-updates")
44 .short('n')
45 .action(ArgAction::Set)
46 .value_parser(value_parser!(u32))
47 .help("The number of subscription updates to receive before exiting"),
48 )
49 .arg(
50 Arg::new("timeout")
51 .required(false)
52 .short('t')
53 .long("timeout")
54 .action(ArgAction::Set)
55 .value_parser(value_parser!(u32))
56 .help(
57 "The timeout, in seconds, after which to disconnect and stop receiving \
58 subscription messages. If `-n` is specified, it will stop after whichever
59 one comes first.",
60 ),
61 )
62 .arg(
63 Arg::new("print_initial_update")
64 .required(false)
65 .long("print-initial-update")
66 .action(ArgAction::SetTrue)
67 .help("Print the initial update for the queries."),
68 )
69 .arg(common_args::anonymous())
70 .arg(common_args::yes())
71 .arg(common_args::server().help("The nickname, host name or URL of the server hosting the database"))
72}
73
74fn parse_msg_json(msg: &WsMessage) -> Option<ws::ServerMessage<JsonFormat>> {
75 let WsMessage::Text(msg) = msg else { return None };
76 serde_json::from_str::<DeserializeWrapper<ws::ServerMessage<JsonFormat>>>(msg)
77 .inspect_err(|e| eprintln!("couldn't parse message from server: {e}"))
78 .map(|wrapper| wrapper.0)
79 .ok()
80}
81
82fn reformat_update<'a>(
83 msg: &'a ws::DatabaseUpdate<JsonFormat>,
84 schema: &RawModuleDefV9,
85) -> anyhow::Result<HashMap<&'a str, SubscriptionTable>> {
86 msg.tables
87 .iter()
88 .map(|upd| {
89 let table_schema = schema
90 .tables
91 .iter()
92 .find(|tbl| tbl.name == upd.table_name)
93 .context("table not found in schema")?;
94 let table_ty = schema.typespace.resolve(table_schema.product_type_ref);
95
96 let reformat_row = |row: &str| -> anyhow::Result<Value> {
97 let row = serde_json::from_str::<Value>(row)?;
99 let row = serde::de::DeserializeSeed::deserialize(SeedWrapper(table_ty), row)?;
100 let row = table_ty.with_value(&row);
101 let row = serde_json::to_value(SerializeWrapper::from_ref(&row))?;
102 Ok(row)
103 };
104
105 let mut deletes = Vec::new();
106 let mut inserts = Vec::new();
107 for upd in &upd.updates {
108 for s in &upd.deletes {
109 deletes.push(reformat_row(s)?);
110 }
111 for s in &upd.inserts {
112 inserts.push(reformat_row(s)?);
113 }
114 }
115
116 Ok((&*upd.table_name, SubscriptionTable { deletes, inserts }))
117 })
118 .collect()
119}
120
121#[derive(serde::Serialize, Debug)]
122struct SubscriptionTable {
123 deletes: Vec<serde_json::Value>,
124 inserts: Vec<serde_json::Value>,
125}
126
127pub async fn exec(config: Config, args: &ArgMatches) -> Result<(), anyhow::Error> {
128 eprintln!("{}\n", UNSTABLE_WARNING);
129
130 let queries = args.get_many::<String>("query").unwrap();
131 let num = args.get_one::<u32>("num-updates").copied();
132 let timeout = args.get_one::<u32>("timeout").copied();
133 let print_initial_update = args.get_flag("print_initial_update");
134
135 let conn = parse_req(config, args).await?;
136 let api = ClientApi::new(conn);
137 let module_def = api.module_def().await?;
138
139 let mut uri = http::Uri::try_from(api.con.db_uri("subscribe"))?.into_parts();
141 uri.scheme = uri.scheme.map(|s| {
142 if s == Scheme::HTTP {
143 "ws".parse().unwrap()
144 } else if s == Scheme::HTTPS {
145 "wss".parse().unwrap()
146 } else {
147 s
148 }
149 });
150
151 let mut req = http::Uri::from_parts(uri)?.into_client_request()?;
153 req.headers_mut().insert(
154 header::SEC_WEBSOCKET_PROTOCOL,
155 http::HeaderValue::from_static(ws::TEXT_PROTOCOL),
156 );
157 if let Some(auth_header) = api.con.auth_header.to_header() {
159 req.headers_mut().insert(header::AUTHORIZATION, auth_header);
160 }
161 let (mut ws, _) = tokio_tungstenite::connect_async(req).await?;
162
163 let task = async {
164 subscribe(&mut ws, queries.cloned().map(Into::into).collect()).await?;
165 await_initial_update(&mut ws, print_initial_update.then_some(&module_def)).await?;
166 consume_transaction_updates(&mut ws, num, &module_def).await
167 };
168
169 let needs_shutdown = if let Some(timeout) = timeout {
170 let timeout = Duration::from_secs(timeout.into());
171 match tokio::time::timeout(timeout, task).await {
172 Ok(res) => res?,
173 Err(_elapsed) => true,
174 }
175 } else {
176 task.await?
177 };
178
179 if needs_shutdown {
180 ws.close(None).await?;
181 }
182
183 Ok(())
184}
185
186async fn subscribe<S>(ws: &mut S, query_strings: Box<[Box<str>]>) -> Result<(), S::Error>
188where
189 S: Sink<WsMessage> + Unpin,
190{
191 let msg = serde_json::to_string(&SerializeWrapper::new(ws::ClientMessage::<()>::Subscribe(
192 ws::Subscribe {
193 query_strings,
194 request_id: 0,
195 },
196 )))
197 .unwrap();
198 ws.send(msg.into()).await
199}
200
201async fn await_initial_update<S>(ws: &mut S, module_def: Option<&RawModuleDefV9>) -> anyhow::Result<()>
204where
205 S: TryStream<Ok = WsMessage> + Unpin,
206 S::Error: std::error::Error + Send + Sync + 'static,
207{
208 const RECV_TX_UPDATE: &str = "protocol error: received transaction update before initial subscription update";
209
210 while let Some(msg) = ws.try_next().await? {
211 let Some(msg) = parse_msg_json(&msg) else { continue };
212 match msg {
213 ws::ServerMessage::InitialSubscription(sub) => {
214 if let Some(module_def) = module_def {
215 let formatted = reformat_update(&sub.database_update, module_def)?;
216 let output = serde_json::to_string(&formatted)? + "\n";
217 tokio::io::stdout().write_all(output.as_bytes()).await?
218 }
219 break;
220 }
221 ws::ServerMessage::TransactionUpdate(ws::TransactionUpdate { status, .. }) => anyhow::bail!(match status {
222 ws::UpdateStatus::Failed(msg) => msg,
223 _ => RECV_TX_UPDATE.into(),
224 }),
225 ws::ServerMessage::TransactionUpdateLight(ws::TransactionUpdateLight { .. }) => {
226 anyhow::bail!(RECV_TX_UPDATE)
227 }
228 _ => continue,
229 }
230 }
231
232 Ok(())
233}
234
235async fn consume_transaction_updates<S>(
238 ws: &mut S,
239 num: Option<u32>,
240 module_def: &RawModuleDefV9,
241) -> anyhow::Result<bool>
242where
243 S: TryStream<Ok = WsMessage> + Unpin,
244 S::Error: std::error::Error + Send + Sync + 'static,
245{
246 let mut stdout = tokio::io::stdout();
247 let mut num_received = 0;
248 loop {
249 if num.is_some_and(|n| num_received >= n) {
250 break Ok(true);
251 }
252 let Some(msg) = ws.try_next().await? else {
253 eprintln!("disconnected by server");
254 break Ok(false);
255 };
256
257 let Some(msg) = parse_msg_json(&msg) else { continue };
258 match msg {
259 ws::ServerMessage::InitialSubscription(_) => {
260 anyhow::bail!("protocol error: received a second initial subscription update")
261 }
262 ws::ServerMessage::TransactionUpdateLight(ws::TransactionUpdateLight { update, .. })
263 | ws::ServerMessage::TransactionUpdate(ws::TransactionUpdate {
264 status: ws::UpdateStatus::Committed(update),
265 ..
266 }) => {
267 let output = serde_json::to_string(&reformat_update(&update, module_def)?)? + "\n";
268 stdout.write_all(output.as_bytes()).await?;
269 num_received += 1;
270 }
271 _ => continue,
272 }
273 }
274}