spacetimedb_cli/subcommands/
subscribe.rs

1use anyhow::Context;
2use clap::{value_parser, Arg, ArgAction, ArgMatches};
3use futures::{Sink, SinkExt, TryStream, TryStreamExt};
4use http::header;
5use http::uri::Scheme;
6use serde_json::Value;
7use spacetimedb_client_api_messages::websocket::{self as ws, JsonFormat};
8use spacetimedb_data_structures::map::HashMap;
9use spacetimedb_lib::db::raw_def::v9::RawModuleDefV9;
10use spacetimedb_lib::de::serde::{DeserializeWrapper, SeedWrapper};
11use spacetimedb_lib::ser::serde::SerializeWrapper;
12use std::io;
13use std::time::Duration;
14use thiserror::Error;
15use tokio::io::AsyncWriteExt;
16use tokio_tungstenite::tungstenite::client::IntoClientRequest;
17use tokio_tungstenite::tungstenite::{Error as WsError, Message as WsMessage};
18
19use crate::api::ClientApi;
20use crate::common_args;
21use crate::sql::parse_req;
22use crate::util::UNSTABLE_WARNING;
23use crate::Config;
24
25pub fn cli() -> clap::Command {
26    clap::Command::new("subscribe")
27        .about(format!("Subscribe to SQL queries on the database. {UNSTABLE_WARNING}"))
28        .arg(
29            Arg::new("database")
30                .required(true)
31                .help("The name or identity of the database you would like to query"),
32        )
33        .arg(
34            Arg::new("query")
35                .required(true)
36                .num_args(1..)
37                .help("The SQL query to execute"),
38        )
39        .arg(
40            Arg::new("num-updates")
41                .required(false)
42                .long("num-updates")
43                .short('n')
44                .action(ArgAction::Set)
45                .value_parser(value_parser!(u32))
46                .help("The number of subscription updates to receive before exiting"),
47        )
48        .arg(
49            Arg::new("timeout")
50                .required(false)
51                .short('t')
52                .long("timeout")
53                .action(ArgAction::Set)
54                .value_parser(value_parser!(u32))
55                .help(
56                    "The timeout, in seconds, after which to disconnect and stop receiving \
57                     subscription messages. If `-n` is specified, it will stop after whichever
58                     one comes first.",
59                ),
60        )
61        .arg(
62            Arg::new("print_initial_update")
63                .required(false)
64                .long("print-initial-update")
65                .action(ArgAction::SetTrue)
66                .help("Print the initial update for the queries."),
67        )
68        .arg(common_args::anonymous())
69        .arg(common_args::yes())
70        .arg(common_args::server().help("The nickname, host name or URL of the server hosting the database"))
71}
72
73fn parse_msg_json(msg: &WsMessage) -> Option<ws::ServerMessage<JsonFormat>> {
74    let WsMessage::Text(msg) = msg else { return None };
75    serde_json::from_str::<DeserializeWrapper<ws::ServerMessage<JsonFormat>>>(msg)
76        .inspect_err(|e| eprintln!("couldn't parse message from server: {e}"))
77        .map(|wrapper| wrapper.0)
78        .ok()
79}
80
81fn reformat_update<'a>(
82    msg: &'a ws::DatabaseUpdate<JsonFormat>,
83    schema: &RawModuleDefV9,
84) -> anyhow::Result<HashMap<&'a str, SubscriptionTable>> {
85    msg.tables
86        .iter()
87        .map(|upd| {
88            let table_schema = schema
89                .tables
90                .iter()
91                .find(|tbl| tbl.name == upd.table_name)
92                .context("table not found in schema")?;
93            let table_ty = schema.typespace.resolve(table_schema.product_type_ref);
94
95            let reformat_row = |row: &str| -> anyhow::Result<Value> {
96                // TODO: can the following two calls be merged into a single call to reduce allocations?
97                let row = serde_json::from_str::<Value>(row)?;
98                let row = serde::de::DeserializeSeed::deserialize(SeedWrapper(table_ty), row)?;
99                let row = table_ty.with_value(&row);
100                let row = serde_json::to_value(SerializeWrapper::from_ref(&row))?;
101                Ok(row)
102            };
103
104            let mut deletes = Vec::new();
105            let mut inserts = Vec::new();
106            for upd in &upd.updates {
107                for s in &upd.deletes {
108                    deletes.push(reformat_row(s)?);
109                }
110                for s in &upd.inserts {
111                    inserts.push(reformat_row(s)?);
112                }
113            }
114
115            Ok((&*upd.table_name, SubscriptionTable { deletes, inserts }))
116        })
117        .collect()
118}
119
120#[derive(serde::Serialize, Debug)]
121struct SubscriptionTable {
122    deletes: Vec<serde_json::Value>,
123    inserts: Vec<serde_json::Value>,
124}
125
126pub async fn exec(config: Config, args: &ArgMatches) -> Result<(), anyhow::Error> {
127    eprintln!("{UNSTABLE_WARNING}\n");
128
129    let queries = args.get_many::<String>("query").unwrap();
130    let num = args.get_one::<u32>("num-updates").copied();
131    let timeout = args.get_one::<u32>("timeout").copied();
132    let print_initial_update = args.get_flag("print_initial_update");
133
134    let conn = parse_req(config, args).await?;
135    let api = ClientApi::new(conn);
136    let module_def = api.module_def().await?;
137
138    // Change the URI scheme from `http(s)` to `ws(s)`.
139    let mut uri = http::Uri::try_from(api.con.db_uri("subscribe"))?.into_parts();
140    uri.scheme = uri.scheme.map(|s| {
141        if s == Scheme::HTTP {
142            "ws".parse().unwrap()
143        } else if s == Scheme::HTTPS {
144            "wss".parse().unwrap()
145        } else {
146            s
147        }
148    });
149
150    // Create the websocket request.
151    let mut req = http::Uri::from_parts(uri)?.into_client_request()?;
152    req.headers_mut().insert(
153        header::SEC_WEBSOCKET_PROTOCOL,
154        http::HeaderValue::from_static(ws::TEXT_PROTOCOL),
155    );
156    //  Add the authorization header, if any.
157    if let Some(auth_header) = api.con.auth_header.to_header() {
158        req.headers_mut().insert(header::AUTHORIZATION, auth_header);
159    }
160    let mut ws = tokio_tungstenite::connect_async(req).await.map(|(ws, _)| ws)?;
161
162    let task = async {
163        subscribe(&mut ws, queries.cloned().map(Into::into).collect()).await?;
164        await_initial_update(&mut ws, print_initial_update.then_some(&module_def)).await?;
165        consume_transaction_updates(&mut ws, num, &module_def).await
166    };
167
168    let res = if let Some(timeout) = timeout {
169        let timeout = Duration::from_secs(timeout.into());
170        match tokio::time::timeout(timeout, task).await {
171            Ok(res) => res,
172            Err(_elapsed) => {
173                eprintln!("timed out after {}s", timeout.as_secs());
174                Ok(())
175            }
176        }
177    } else {
178        task.await
179    };
180
181    // Close the connection gracefully.
182    // This will return an error if the server already closed,
183    // or the connection is in a bad state.
184    // The error (if any) relevant to the user is already stored in `res`,
185    // so we can ignore errors here -- graceful close is basically a
186    // courtesy to the server.
187    let _ = ws.close(None).await;
188    // The server closing the connection is not considered an error,
189    // but any other error is.
190    res.or_else(|e| {
191        if e.is_server_closed_connection() {
192            Ok(())
193        } else {
194            Err(e)
195        }
196    })
197    .map_err(anyhow::Error::from)
198}
199
200#[derive(Debug, Error)]
201enum Error {
202    #[error("error sending subscription queries")]
203    Subscribe {
204        #[source]
205        source: WsError,
206    },
207    #[error("protocol error: {details}")]
208    Protocol { details: &'static str },
209    #[error("websocket error: {source}")]
210    Websocket {
211        #[source]
212        source: WsError,
213    },
214    #[error("encountered failed transaction: {reason}")]
215    TransactionFailure { reason: Box<str> },
216    #[error("error formatting response: {source:#}")]
217    Reformat {
218        #[source]
219        source: anyhow::Error,
220    },
221    #[error(transparent)]
222    Serde(#[from] serde_json::Error),
223    #[error(transparent)]
224    Io(#[from] io::Error),
225}
226
227impl Error {
228    fn is_server_closed_connection(&self) -> bool {
229        matches!(
230            self,
231            Self::Websocket {
232                source: WsError::ConnectionClosed
233            }
234        )
235    }
236}
237
238/// Send the subscribe message.
239async fn subscribe<S>(ws: &mut S, query_strings: Box<[Box<str>]>) -> Result<(), Error>
240where
241    S: Sink<WsMessage, Error = WsError> + Unpin,
242{
243    let msg = serde_json::to_string(&SerializeWrapper::new(ws::ClientMessage::<()>::Subscribe(
244        ws::Subscribe {
245            query_strings,
246            request_id: 0,
247        },
248    )))
249    .unwrap();
250    ws.send(msg.into()).await.map_err(|source| Error::Subscribe { source })
251}
252
253/// Await the initial [`ServerMessage::SubscriptionUpdate`].
254/// If `module_def` is `Some`, print a JSON representation to stdout.
255async fn await_initial_update<S>(ws: &mut S, module_def: Option<&RawModuleDefV9>) -> Result<(), Error>
256where
257    S: TryStream<Ok = WsMessage, Error = WsError> + Unpin,
258{
259    const RECV_TX_UPDATE: &str = "protocol error: received transaction update before initial subscription update";
260
261    while let Some(msg) = ws.try_next().await.map_err(|source| Error::Websocket { source })? {
262        let Some(msg) = parse_msg_json(&msg) else { continue };
263        match msg {
264            ws::ServerMessage::InitialSubscription(sub) => {
265                if let Some(module_def) = module_def {
266                    let output = format_output_json(&sub.database_update, module_def)?;
267                    tokio::io::stdout().write_all(output.as_bytes()).await?
268                }
269                break;
270            }
271            ws::ServerMessage::TransactionUpdate(ws::TransactionUpdate { status, .. }) => {
272                return Err(match status {
273                    ws::UpdateStatus::Failed(msg) => Error::TransactionFailure { reason: msg },
274                    _ => Error::Protocol {
275                        details: RECV_TX_UPDATE,
276                    },
277                })
278            }
279            ws::ServerMessage::TransactionUpdateLight(ws::TransactionUpdateLight { .. }) => {
280                return Err(Error::Protocol {
281                    details: RECV_TX_UPDATE,
282                })
283            }
284            _ => continue,
285        }
286    }
287
288    Ok(())
289}
290
291/// Print `num` [`ServerMessage::TransactionUpdate`] messages as JSON.
292/// If `num` is `None`, keep going indefinitely.
293async fn consume_transaction_updates<S>(ws: &mut S, num: Option<u32>, module_def: &RawModuleDefV9) -> Result<(), Error>
294where
295    S: TryStream<Ok = WsMessage, Error = WsError> + Unpin,
296{
297    let mut stdout = tokio::io::stdout();
298    let mut num_received = 0;
299    loop {
300        if num.is_some_and(|n| num_received >= n) {
301            return Ok(());
302        }
303        let Some(msg) = ws.try_next().await.map_err(|source| Error::Websocket { source })? else {
304            eprintln!("disconnected by server");
305            return Err(Error::Websocket {
306                source: WsError::ConnectionClosed,
307            });
308        };
309
310        let Some(msg) = parse_msg_json(&msg) else { continue };
311        match msg {
312            ws::ServerMessage::InitialSubscription(_) => {
313                return Err(Error::Protocol {
314                    details: "received a second initial subscription update",
315                })
316            }
317            ws::ServerMessage::TransactionUpdateLight(ws::TransactionUpdateLight { update, .. })
318            | ws::ServerMessage::TransactionUpdate(ws::TransactionUpdate {
319                status: ws::UpdateStatus::Committed(update),
320                ..
321            }) => {
322                let output = format_output_json(&update, module_def)?;
323                stdout.write_all(output.as_bytes()).await?;
324                num_received += 1;
325            }
326            _ => continue,
327        }
328    }
329}
330
331fn format_output_json(msg: &ws::DatabaseUpdate<JsonFormat>, schema: &RawModuleDefV9) -> Result<String, Error> {
332    let formatted = reformat_update(msg, schema).map_err(|source| Error::Reformat { source })?;
333    let output = serde_json::to_string(&formatted)? + "\n";
334
335    Ok(output)
336}