spacetimedb_cli/subcommands/
sql.rs

1use std::time::Instant;
2
3use crate::api::{from_json_seed, ClientApi, Connection, StmtResultJson};
4use crate::common_args;
5use anyhow::Context;
6use clap::{Arg, ArgAction, ArgMatches};
7use itertools::Itertools;
8use reqwest::RequestBuilder;
9use spacetimedb_lib::de::serde::SeedWrapper;
10use spacetimedb_lib::sats::{satn, Typespace};
11use tabled::settings::Style;
12
13use crate::config::Config;
14use crate::util::{database_identity, get_auth_header, ResponseExt, UNSTABLE_WARNING};
15
16pub fn cli() -> clap::Command {
17    clap::Command::new("sql")
18        .about(format!("Runs a SQL query on the database. {}", UNSTABLE_WARNING))
19        .arg(
20            Arg::new("database")
21                .required(true)
22                .help("The name or identity of the database you would like to query"),
23        )
24        .arg(
25            Arg::new("query")
26                .action(ArgAction::Set)
27                .required(true)
28                .conflicts_with("interactive")
29                .help("The SQL query to execute"),
30        )
31        .arg(
32            Arg::new("interactive")
33                .long("interactive")
34                .action(ArgAction::SetTrue)
35                .conflicts_with("query")
36                .help("Instead of using a query, run an interactive command prompt for `SQL` expressions"),
37        )
38        .arg(common_args::anonymous())
39        .arg(common_args::server().help("The nickname, host name or URL of the server hosting the database"))
40        .arg(common_args::yes())
41}
42
43pub(crate) async fn parse_req(mut config: Config, args: &ArgMatches) -> Result<Connection, anyhow::Error> {
44    let server = args.get_one::<String>("server").map(|s| s.as_ref());
45    let force = args.get_flag("force");
46    let database_name_or_identity = args.get_one::<String>("database").unwrap();
47    let anon_identity = args.get_flag("anon_identity");
48
49    Ok(Connection {
50        host: config.get_host_url(server)?,
51        auth_header: get_auth_header(&mut config, anon_identity, server, !force).await?,
52        database_identity: database_identity(&config, database_name_or_identity, server).await?,
53        database: database_name_or_identity.to_string(),
54    })
55}
56
// TODO: Report the timing of each query back from the backend instead of inferring it here.
/// Formats a row count as a parenthesised footer, e.g. `(1 row)` or `(3 rows)`.
fn print_row_count(rows: usize) -> String {
    let noun = match rows {
        1 => "row",
        _ => "rows",
    };
    format!("({} {})", rows, noun)
}
62
/// Prints the time elapsed since `now` to stdout as `Time: <duration>`,
/// using the `Debug` form of the duration with two decimal places.
fn print_timings(now: Instant) {
    let elapsed = now.elapsed();
    println!("Time: {elapsed:.2?}");
}
66
67pub(crate) async fn run_sql(builder: RequestBuilder, sql: &str, with_stats: bool) -> Result<(), anyhow::Error> {
68    let now = Instant::now();
69
70    let json = builder
71        .body(sql.to_owned())
72        .send()
73        .await?
74        .ensure_content_type("application/json")
75        .await?
76        .text()
77        .await?;
78
79    let stmt_result_json: Vec<StmtResultJson> = serde_json::from_str(&json).context("malformed sql response")?;
80
81    // Print only `OK for empty tables as it's likely a command like `INSERT`.
82    if stmt_result_json.is_empty() {
83        if with_stats {
84            print_timings(now);
85        }
86        println!("OK");
87        return Ok(());
88    };
89
90    stmt_result_json
91        .iter()
92        .map(|stmt_result| {
93            let mut table = stmt_result_to_table(stmt_result)?;
94            if with_stats {
95                // The `tabled::count_rows` add the header as a row, so subtract it.
96                let row_count = table.count_rows().wrapping_sub(1);
97                // For some reason, `table.with(...)` crashes if the row count is 0.
98                if row_count > 0 {
99                    let row_count = print_row_count(row_count);
100                    table.with(tabled::settings::panel::Footer::new(row_count));
101                }
102            }
103            anyhow::Ok(table)
104        })
105        .process_results(|it| println!("{}", it.format("\n\n")))?;
106    if with_stats {
107        print_timings(now);
108    }
109
110    Ok(())
111}
112
113fn stmt_result_to_table(stmt_result: &StmtResultJson) -> anyhow::Result<tabled::Table> {
114    let StmtResultJson { schema, rows } = stmt_result;
115
116    let mut builder = tabled::builder::Builder::default();
117    builder.set_header(
118        schema
119            .elements
120            .iter()
121            .enumerate()
122            .map(|(i, e)| e.name.clone().unwrap_or_else(|| format!("column {i}").into())),
123    );
124
125    let ty = Typespace::EMPTY.with_type(schema);
126    for row in rows {
127        let row = from_json_seed(row.get(), SeedWrapper(ty))?;
128        builder.push_record(
129            ty.with_values(&row)
130                .map(|value| satn::PsqlWrapper { ty: ty.ty(), value }.to_string()),
131        );
132    }
133
134    let mut table = builder.build();
135    table.with(Style::psql());
136
137    Ok(table)
138}
139
140pub async fn exec(config: Config, args: &ArgMatches) -> Result<(), anyhow::Error> {
141    eprintln!("{}\n", UNSTABLE_WARNING);
142    let interactive = args.get_one::<bool>("interactive").unwrap_or(&false);
143    if *interactive {
144        let con = parse_req(config, args).await?;
145
146        crate::repl::exec(con).await?;
147    } else {
148        let query = args.get_one::<String>("query").unwrap();
149
150        let con = parse_req(config, args).await?;
151        let api = ClientApi::new(con);
152
153        run_sql(api.sql(), query, false).await?;
154    }
155    Ok(())
156}