use anyhow::Context;
use clap::{value_parser, Arg, ArgAction, ArgMatches};
use futures::{Sink, SinkExt, TryStream, TryStreamExt};
use http::header;
use http::uri::Scheme;
use serde_json::Value;
use spacetimedb_client_api_messages::websocket::{self as ws, EncodedValue};
use spacetimedb_data_structures::map::HashMap;
use spacetimedb_lib::de::serde::{DeserializeWrapper, SeedWrapper};
use spacetimedb_lib::ser::serde::SerializeWrapper;
use spacetimedb_lib::RawModuleDefV8;
use spacetimedb_standalone::TEXT_PROTOCOL;
use std::time::Duration;
use tokio::io::AsyncWriteExt;
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
use tokio_tungstenite::tungstenite::Message as WsMessage;
use crate::api::ClientApi;
use crate::common_args;
use crate::sql::parse_req;
use crate::Config;
/// Builds the `subscribe` subcommand definition.
///
/// Positional arguments are the database and one or more SQL queries.
/// Optional arguments bound the session (`-n`, `--timeout`), enable printing
/// of the initial update, and select the identity and server to use.
pub fn cli() -> clap::Command {
    clap::Command::new("subscribe")
        .about("Subscribe to SQL queries on the database.")
        .arg(
            Arg::new("database")
                .required(true)
                .help("The domain or address of the database you would like to query"),
        )
        .arg(
            Arg::new("query")
                .required(true)
                .num_args(1..)
                .help("The SQL query to execute"),
        )
        .arg(
            Arg::new("num-updates")
                .required(false)
                .short('n')
                .action(ArgAction::Set)
                .value_parser(value_parser!(u32))
                .help("The number of subscription updates to receive before exiting"),
        )
        .arg(
            Arg::new("timeout")
                .required(false)
                .short('t')
                .long("timeout")
                .action(ArgAction::Set)
                .value_parser(value_parser!(u32))
                // Every wrapped line ends in `\` so that no raw newline (or
                // source indentation) leaks into the rendered help text.
                .help(
                    "The timeout, in seconds, after which to disconnect and stop receiving \
                     subscription messages. If `-n` is specified, it will stop after whichever \
                     one comes first.",
                ),
        )
        .arg(
            Arg::new("print_initial_update")
                .required(false)
                .long("print-initial-update")
                .action(ArgAction::SetTrue)
                .help("Print the initial update for the queries."),
        )
        .arg(
            common_args::identity()
                .conflicts_with("anon_identity")
                .help("The identity to use for querying the database")
                .long_help(
                    "The identity to use for querying the database. \
                     If no identity is provided, the default one will be used.",
                ),
        )
        .arg(
            Arg::new("anon_identity")
                .long("anon-identity")
                .short('a')
                .conflicts_with("identity")
                .action(ArgAction::SetTrue)
                .help("If this flag is present, no identity will be provided when querying the database"),
        )
        .arg(common_args::server().help("The nickname, host name or URL of the server hosting the database"))
}
/// Decodes a text WebSocket frame into a [`ws::ServerMessage`].
///
/// Returns `None` for any frame that is not `WsMessage::Text`, and also when
/// the text fails to deserialize; in the latter case the error is reported on
/// stderr.
fn parse_msg_json(msg: &WsMessage) -> Option<ws::ServerMessage> {
    let text = match msg {
        WsMessage::Text(text) => text,
        _ => return None,
    };
    match serde_json::from_str::<DeserializeWrapper<ws::ServerMessage>>(text) {
        Ok(wrapper) => Some(wrapper.0),
        Err(e) => {
            eprintln!("couldn't parse message from server: {e}");
            None
        }
    }
}
fn reformat_update(
msg: ws::DatabaseUpdate,
schema: &RawModuleDefV8,
) -> anyhow::Result<HashMap<String, SubscriptionTable>> {
msg.tables
.into_iter()
.map(|upd| {
let table_schema = schema
.tables
.iter()
.find(|tbl| tbl.schema.table_name.as_ref() == upd.table_name)
.context("table not found in schema")?;
let table_ty = schema.typespace.resolve(table_schema.data);
let reformat_row = |row: EncodedValue| {
let EncodedValue::Text(row) = row else {
anyhow::bail!("Expected row in text format but found {row:?}");
};
let row = serde_json::from_str::<Value>(&row)?;
let row = serde::de::DeserializeSeed::deserialize(SeedWrapper(table_ty), row)?;
let row = table_ty.with_value(&row);
let row = serde_json::to_value(SerializeWrapper::from_ref(&row))?;
Ok(row)
};
let deletes = upd
.deletes
.into_iter()
.map(reformat_row)
.collect::<anyhow::Result<Vec<_>>>()?;
let inserts = upd
.inserts
.into_iter()
.map(reformat_row)
.collect::<anyhow::Result<Vec<_>>>()?;
Ok((upd.table_name, SubscriptionTable { deletes, inserts }))
})
.collect()
}
#[derive(serde::Serialize, Debug)]
struct SubscriptionTable {
deletes: Vec<serde_json::Value>,
inserts: Vec<serde_json::Value>,
}
pub async fn exec(config: Config, args: &ArgMatches) -> Result<(), anyhow::Error> {
let queries = args.get_many::<String>("query").unwrap();
let num = args.get_one::<u32>("num-updates").copied();
let timeout = args.get_one::<u32>("timeout").copied();
let print_initial_update = args.get_flag("print_initial_update");
let conn = parse_req(config, args).await?;
let api = ClientApi::new(conn);
let module_def = api.module_def().await?;
let mut uri = http::Uri::try_from(api.con.db_uri("subscribe"))?.into_parts();
uri.scheme = uri.scheme.map(|s| {
if s == Scheme::HTTP {
"ws".parse().unwrap()
} else if s == Scheme::HTTPS {
"wss".parse().unwrap()
} else {
s
}
});
let mut req = http::Uri::from_parts(uri)?.into_client_request()?;
req.headers_mut().insert(header::SEC_WEBSOCKET_PROTOCOL, TEXT_PROTOCOL);
if let Some(auth_header) = &api.con.auth_header {
req.headers_mut().insert(header::AUTHORIZATION, auth_header.try_into()?);
}
let (mut ws, _) = tokio_tungstenite::connect_async(req).await?;
let task = async {
subscribe(&mut ws, queries.cloned().collect()).await?;
await_initial_update(&mut ws, print_initial_update.then_some(&module_def)).await?;
consume_transaction_updates(&mut ws, num, &module_def).await
};
let needs_shutdown = if let Some(timeout) = timeout {
let timeout = Duration::from_secs(timeout.into());
match tokio::time::timeout(timeout, task).await {
Ok(res) => res?,
Err(_elapsed) => true,
}
} else {
task.await?
};
if needs_shutdown {
ws.close(None).await?;
}
Ok(())
}
async fn subscribe<S>(ws: &mut S, query_strings: Vec<String>) -> Result<(), S::Error>
where
S: Sink<WsMessage> + Unpin,
{
let msg = serde_json::to_string(&SerializeWrapper::new(ws::ClientMessage::<()>::Subscribe(
ws::Subscribe {
query_strings,
request_id: 0,
},
)))
.unwrap();
ws.send(msg.into()).await
}
/// Reads from `ws` until the initial subscription update arrives.
///
/// When `module_def` is `Some`, the initial update is reformatted and written
/// to stdout as one line of JSON; when it is `None`, the update is consumed
/// silently. Frames that cannot be parsed and other message kinds are skipped.
/// Returns `Ok(())` as well if the stream ends before an initial update is
/// seen.
///
/// # Errors
///
/// Fails if a transaction update arrives first (using the server's message
/// when its status is `Failed`), or on stream, formatting, or stdout errors.
async fn await_initial_update<S>(ws: &mut S, module_def: Option<&RawModuleDefV8>) -> anyhow::Result<()>
where
    S: TryStream<Ok = WsMessage> + Unpin,
    S::Error: std::error::Error + Send + Sync + 'static,
{
    loop {
        let Some(frame) = ws.try_next().await? else {
            return Ok(());
        };
        let Some(parsed) = parse_msg_json(&frame) else {
            continue;
        };
        match parsed {
            ws::ServerMessage::InitialSubscription(sub) => {
                if let Some(module_def) = module_def {
                    let tables = reformat_update(sub.database_update, module_def)?;
                    let mut line = serde_json::to_string(&tables)?;
                    line.push('\n');
                    tokio::io::stdout().write_all(line.as_bytes()).await?;
                }
                return Ok(());
            }
            ws::ServerMessage::TransactionUpdate(ws::TransactionUpdate { status, .. }) => {
                let message = if let ws::UpdateStatus::Failed(reason) = status {
                    reason
                } else {
                    "protocol error: received transaction update before initial subscription update".to_string()
                };
                anyhow::bail!(message)
            }
            _ => {}
        }
    }
}
/// Prints committed transaction updates from `ws` to stdout, one JSON line
/// per update, until `num` updates have been printed (when `num` is `Some`)
/// or the stream ends.
///
/// Returns `Ok(true)` when the update limit was reached and `Ok(false)` when
/// the server closed the stream (a notice is printed to stderr). Unparseable
/// frames and all other message kinds are skipped without being counted.
///
/// # Errors
///
/// Fails if a second initial subscription update is received, or on stream,
/// formatting, or stdout errors.
async fn consume_transaction_updates<S>(
    ws: &mut S,
    num: Option<u32>,
    module_def: &RawModuleDefV8,
) -> anyhow::Result<bool>
where
    S: TryStream<Ok = WsMessage> + Unpin,
    S::Error: std::error::Error + Send + Sync + 'static,
{
    let mut stdout = tokio::io::stdout();
    let mut received: u32 = 0;

    // With no limit the loop only ends through one of the returns below.
    while num.map_or(true, |limit| received < limit) {
        let frame = match ws.try_next().await? {
            Some(frame) => frame,
            None => {
                eprintln!("disconnected by server");
                return Ok(false);
            }
        };
        let Some(parsed) = parse_msg_json(&frame) else {
            continue;
        };
        match parsed {
            ws::ServerMessage::InitialSubscription(_) => {
                anyhow::bail!("protocol error: received a second initial subscription update")
            }
            ws::ServerMessage::TransactionUpdate(ws::TransactionUpdate {
                status: ws::UpdateStatus::Committed(update),
                ..
            }) => {
                let tables = reformat_update(update, module_def)?;
                let mut line = serde_json::to_string(&tables)?;
                line.push('\n');
                stdout.write_all(line.as_bytes()).await?;
                received += 1;
            }
            _ => {}
        }
    }
    Ok(true)
}