spacetimedb_cli/subcommands/
subscribe.rs

1use anyhow::Context;
2use clap::{value_parser, Arg, ArgAction, ArgMatches};
3use futures::{Sink, SinkExt, TryStream, TryStreamExt};
4use http::header;
5use http::uri::Scheme;
6use serde_json::Value;
7use spacetimedb_client_api_messages::websocket::{self as ws, JsonFormat};
8use spacetimedb_data_structures::map::HashMap;
9use spacetimedb_lib::db::raw_def::v9::RawModuleDefV9;
10use spacetimedb_lib::de::serde::{DeserializeWrapper, SeedWrapper};
11use spacetimedb_lib::ser::serde::SerializeWrapper;
12use std::time::Duration;
13use tokio::io::AsyncWriteExt;
14use tokio_tungstenite::tungstenite::client::IntoClientRequest;
15use tokio_tungstenite::tungstenite::Message as WsMessage;
16
17use crate::api::ClientApi;
18use crate::common_args;
19use crate::sql::parse_req;
20use crate::util::UNSTABLE_WARNING;
21use crate::Config;
22
23pub fn cli() -> clap::Command {
24    clap::Command::new("subscribe")
25        .about(format!(
26            "Subscribe to SQL queries on the database. {}",
27            UNSTABLE_WARNING
28        ))
29        .arg(
30            Arg::new("database")
31                .required(true)
32                .help("The name or identity of the database you would like to query"),
33        )
34        .arg(
35            Arg::new("query")
36                .required(true)
37                .num_args(1..)
38                .help("The SQL query to execute"),
39        )
40        .arg(
41            Arg::new("num-updates")
42                .required(false)
43                .long("num-updates")
44                .short('n')
45                .action(ArgAction::Set)
46                .value_parser(value_parser!(u32))
47                .help("The number of subscription updates to receive before exiting"),
48        )
49        .arg(
50            Arg::new("timeout")
51                .required(false)
52                .short('t')
53                .long("timeout")
54                .action(ArgAction::Set)
55                .value_parser(value_parser!(u32))
56                .help(
57                    "The timeout, in seconds, after which to disconnect and stop receiving \
58                     subscription messages. If `-n` is specified, it will stop after whichever
59                     one comes first.",
60                ),
61        )
62        .arg(
63            Arg::new("print_initial_update")
64                .required(false)
65                .long("print-initial-update")
66                .action(ArgAction::SetTrue)
67                .help("Print the initial update for the queries."),
68        )
69        .arg(common_args::anonymous())
70        .arg(common_args::yes())
71        .arg(common_args::server().help("The nickname, host name or URL of the server hosting the database"))
72}
73
74fn parse_msg_json(msg: &WsMessage) -> Option<ws::ServerMessage<JsonFormat>> {
75    let WsMessage::Text(msg) = msg else { return None };
76    serde_json::from_str::<DeserializeWrapper<ws::ServerMessage<JsonFormat>>>(msg)
77        .inspect_err(|e| eprintln!("couldn't parse message from server: {e}"))
78        .map(|wrapper| wrapper.0)
79        .ok()
80}
81
82fn reformat_update<'a>(
83    msg: &'a ws::DatabaseUpdate<JsonFormat>,
84    schema: &RawModuleDefV9,
85) -> anyhow::Result<HashMap<&'a str, SubscriptionTable>> {
86    msg.tables
87        .iter()
88        .map(|upd| {
89            let table_schema = schema
90                .tables
91                .iter()
92                .find(|tbl| tbl.name == upd.table_name)
93                .context("table not found in schema")?;
94            let table_ty = schema.typespace.resolve(table_schema.product_type_ref);
95
96            let reformat_row = |row: &str| -> anyhow::Result<Value> {
97                // TODO: can the following two calls be merged into a single call to reduce allocations?
98                let row = serde_json::from_str::<Value>(row)?;
99                let row = serde::de::DeserializeSeed::deserialize(SeedWrapper(table_ty), row)?;
100                let row = table_ty.with_value(&row);
101                let row = serde_json::to_value(SerializeWrapper::from_ref(&row))?;
102                Ok(row)
103            };
104
105            let mut deletes = Vec::new();
106            let mut inserts = Vec::new();
107            for upd in &upd.updates {
108                for s in &upd.deletes {
109                    deletes.push(reformat_row(s)?);
110                }
111                for s in &upd.inserts {
112                    inserts.push(reformat_row(s)?);
113                }
114            }
115
116            Ok((&*upd.table_name, SubscriptionTable { deletes, inserts }))
117        })
118        .collect()
119}
120
121#[derive(serde::Serialize, Debug)]
122struct SubscriptionTable {
123    deletes: Vec<serde_json::Value>,
124    inserts: Vec<serde_json::Value>,
125}
126
127pub async fn exec(config: Config, args: &ArgMatches) -> Result<(), anyhow::Error> {
128    eprintln!("{}\n", UNSTABLE_WARNING);
129
130    let queries = args.get_many::<String>("query").unwrap();
131    let num = args.get_one::<u32>("num-updates").copied();
132    let timeout = args.get_one::<u32>("timeout").copied();
133    let print_initial_update = args.get_flag("print_initial_update");
134
135    let conn = parse_req(config, args).await?;
136    let api = ClientApi::new(conn);
137    let module_def = api.module_def().await?;
138
139    // Change the URI scheme from `http(s)` to `ws(s)`.
140    let mut uri = http::Uri::try_from(api.con.db_uri("subscribe"))?.into_parts();
141    uri.scheme = uri.scheme.map(|s| {
142        if s == Scheme::HTTP {
143            "ws".parse().unwrap()
144        } else if s == Scheme::HTTPS {
145            "wss".parse().unwrap()
146        } else {
147            s
148        }
149    });
150
151    // Create the websocket request.
152    let mut req = http::Uri::from_parts(uri)?.into_client_request()?;
153    req.headers_mut().insert(
154        header::SEC_WEBSOCKET_PROTOCOL,
155        http::HeaderValue::from_static(ws::TEXT_PROTOCOL),
156    );
157    //  Add the authorization header, if any.
158    if let Some(auth_header) = api.con.auth_header.to_header() {
159        req.headers_mut().insert(header::AUTHORIZATION, auth_header);
160    }
161    let (mut ws, _) = tokio_tungstenite::connect_async(req).await?;
162
163    let task = async {
164        subscribe(&mut ws, queries.cloned().map(Into::into).collect()).await?;
165        await_initial_update(&mut ws, print_initial_update.then_some(&module_def)).await?;
166        consume_transaction_updates(&mut ws, num, &module_def).await
167    };
168
169    let needs_shutdown = if let Some(timeout) = timeout {
170        let timeout = Duration::from_secs(timeout.into());
171        match tokio::time::timeout(timeout, task).await {
172            Ok(res) => res?,
173            Err(_elapsed) => true,
174        }
175    } else {
176        task.await?
177    };
178
179    if needs_shutdown {
180        ws.close(None).await?;
181    }
182
183    Ok(())
184}
185
186/// Send the subscribe message.
187async fn subscribe<S>(ws: &mut S, query_strings: Box<[Box<str>]>) -> Result<(), S::Error>
188where
189    S: Sink<WsMessage> + Unpin,
190{
191    let msg = serde_json::to_string(&SerializeWrapper::new(ws::ClientMessage::<()>::Subscribe(
192        ws::Subscribe {
193            query_strings,
194            request_id: 0,
195        },
196    )))
197    .unwrap();
198    ws.send(msg.into()).await
199}
200
201/// Await the initial [`ServerMessage::SubscriptionUpdate`].
202/// If `module_def` is `Some`, print a JSON representation to stdout.
203async fn await_initial_update<S>(ws: &mut S, module_def: Option<&RawModuleDefV9>) -> anyhow::Result<()>
204where
205    S: TryStream<Ok = WsMessage> + Unpin,
206    S::Error: std::error::Error + Send + Sync + 'static,
207{
208    const RECV_TX_UPDATE: &str = "protocol error: received transaction update before initial subscription update";
209
210    while let Some(msg) = ws.try_next().await? {
211        let Some(msg) = parse_msg_json(&msg) else { continue };
212        match msg {
213            ws::ServerMessage::InitialSubscription(sub) => {
214                if let Some(module_def) = module_def {
215                    let formatted = reformat_update(&sub.database_update, module_def)?;
216                    let output = serde_json::to_string(&formatted)? + "\n";
217                    tokio::io::stdout().write_all(output.as_bytes()).await?
218                }
219                break;
220            }
221            ws::ServerMessage::TransactionUpdate(ws::TransactionUpdate { status, .. }) => anyhow::bail!(match status {
222                ws::UpdateStatus::Failed(msg) => msg,
223                _ => RECV_TX_UPDATE.into(),
224            }),
225            ws::ServerMessage::TransactionUpdateLight(ws::TransactionUpdateLight { .. }) => {
226                anyhow::bail!(RECV_TX_UPDATE)
227            }
228            _ => continue,
229        }
230    }
231
232    Ok(())
233}
234
235/// Print `num` [`ServerMessage::TransactionUpdate`] messages as JSON.
236/// If `num` is `None`, keep going indefinitely.
237async fn consume_transaction_updates<S>(
238    ws: &mut S,
239    num: Option<u32>,
240    module_def: &RawModuleDefV9,
241) -> anyhow::Result<bool>
242where
243    S: TryStream<Ok = WsMessage> + Unpin,
244    S::Error: std::error::Error + Send + Sync + 'static,
245{
246    let mut stdout = tokio::io::stdout();
247    let mut num_received = 0;
248    loop {
249        if num.is_some_and(|n| num_received >= n) {
250            break Ok(true);
251        }
252        let Some(msg) = ws.try_next().await? else {
253            eprintln!("disconnected by server");
254            break Ok(false);
255        };
256
257        let Some(msg) = parse_msg_json(&msg) else { continue };
258        match msg {
259            ws::ServerMessage::InitialSubscription(_) => {
260                anyhow::bail!("protocol error: received a second initial subscription update")
261            }
262            ws::ServerMessage::TransactionUpdateLight(ws::TransactionUpdateLight { update, .. })
263            | ws::ServerMessage::TransactionUpdate(ws::TransactionUpdate {
264                status: ws::UpdateStatus::Committed(update),
265                ..
266            }) => {
267                let output = serde_json::to_string(&reformat_update(&update, module_def)?)? + "\n";
268                stdout.write_all(output.as_bytes()).await?;
269                num_received += 1;
270            }
271            _ => continue,
272        }
273    }
274}