spacetimedb_cli/
util.rs

1use anyhow::Context;
2use base64::{engine::general_purpose::STANDARD_NO_PAD as BASE_64_STD_NO_PAD, Engine as _};
3use reqwest::{RequestBuilder, Url};
4use spacetimedb_auth::identity::{IncomingClaims, SpacetimeIdentityClaims};
5use spacetimedb_client_api_messages::name::GetNamesResponse;
6use spacetimedb_lib::Identity;
7use std::io::Write;
8use std::path::{Path, PathBuf};
9
10use crate::config::Config;
11use crate::login::{spacetimedb_login_force, DEFAULT_AUTH_HOST};
12
/// Standard warning text for commands whose interface is not yet stable; the text itself
/// tells the user the command is subject to breaking changes.
pub const UNSTABLE_WARNING: &str = "WARNING: This command is UNSTABLE and subject to breaking changes.";
14
15/// Determine the identity of the `database`.
16pub async fn database_identity(
17    config: &Config,
18    name_or_identity: &str,
19    server: Option<&str>,
20) -> Result<Identity, anyhow::Error> {
21    if let Ok(identity) = Identity::from_hex(name_or_identity) {
22        return Ok(identity);
23    }
24    spacetime_dns(config, name_or_identity, server)
25        .await?
26        .with_context(|| format!("the dns resolution of `{name_or_identity}` failed."))
27}
28
29pub(crate) trait ResponseExt: Sized {
30    /// Ensure that this response has the given content-type, especially if it's
31    /// a success response.
32    ///
33    /// This checks the response status for you, so you shouldn't call
34    /// `error_for_status()` beforehand.
35    ///
36    /// If the response does not have the given content type, assume it's an error message
37    /// and return it as such. Success responses with the wrong content type are treated
38    /// as a bug in the API implementation, since that makes it harder to tell what's
39    /// meant to be a structured response and what's a plain-text error message.
40    async fn ensure_content_type(self, content_type: &str) -> anyhow::Result<Self>;
41
42    /// Like [`reqwest::Response::json()`], but handles non-JSON error messages gracefully.
43    async fn json_or_error<T: serde::de::DeserializeOwned>(self) -> anyhow::Result<T>;
44
45    /// Transforms a status of `NOT_FOUND` into `None`.
46    fn found(self) -> Option<Self>;
47}
48
49fn err_status_desc(status: http::StatusCode) -> Option<&'static str> {
50    if status.is_success() {
51        None
52    } else if status.is_client_error() {
53        Some("HTTP status client error")
54    } else if status.is_server_error() {
55        Some("HTTP status server error")
56    } else {
57        Some("unexpected HTTP status code")
58    }
59}
60
61impl ResponseExt for reqwest::Response {
62    async fn ensure_content_type(self, content_type: &str) -> anyhow::Result<Self> {
63        let status = self.status();
64        if self
65            .headers()
66            .get(http::header::CONTENT_TYPE)
67            .is_some_and(|ty| ty == content_type)
68        {
69            return Ok(self);
70        }
71        let url = self.url();
72        let Some(status_desc) = err_status_desc(status) else {
73            anyhow::bail!("HTTP response from url ({url}) was success but did not have content-type: {content_type}");
74        };
75        let url = url.to_string();
76        let status_err = match self.error_for_status_ref() {
77            Err(e) => anyhow::Error::from(e),
78            Ok(_) => anyhow::anyhow!("{status_desc} ({status}) from url ({url})"),
79        };
80        let err = match self.text().await {
81            Ok(text) => status_err.context(text),
82            Err(err) => anyhow::Error::from(err)
83                .context(format!("{status_desc} ({status})"))
84                .context("failed to get response text"),
85        };
86        Err(err)
87    }
88
89    async fn json_or_error<T: serde::de::DeserializeOwned>(self) -> anyhow::Result<T> {
90        let status = self.status();
91        self.ensure_content_type("application/json")
92            .await?
93            .json()
94            .await
95            .map_err(|err| {
96                let mut err = anyhow::Error::from(err);
97                if let Some(desc) = err_status_desc(status) {
98                    err = err.context(format!("malformed json payload for {desc} ({status})"))
99                }
100                err
101            })
102    }
103
104    fn found(self) -> Option<Self> {
105        (self.status() != http::StatusCode::NOT_FOUND).then_some(self)
106    }
107}
108
109/// Converts a name to a database identity.
110pub async fn spacetime_dns(
111    config: &Config,
112    domain: &str,
113    server: Option<&str>,
114) -> Result<Option<Identity>, anyhow::Error> {
115    let client = reqwest::Client::new();
116    let url = format!("{}/v1/database/{}/identity", config.get_host_url(server)?, domain);
117    let Some(res) = client.get(url).send().await?.found() else {
118        return Ok(None);
119    };
120    let text = res.error_for_status()?.text().await?;
121    text.parse()
122        .map(Some)
123        .context("identity endpoint did not return an identity")
124}
125
126pub async fn spacetime_server_fingerprint(url: &str) -> anyhow::Result<String> {
127    let builder = reqwest::Client::new().get(format!("{url}/v1/identity/public-key").as_str());
128    let res = builder.send().await?.error_for_status()?;
129    let fingerprint = res.text().await?;
130    Ok(fingerprint)
131}
132
133/// Returns all known names for the given identity.
134pub async fn spacetime_reverse_dns(
135    config: &Config,
136    identity: &str,
137    server: Option<&str>,
138) -> Result<GetNamesResponse, anyhow::Error> {
139    let client = reqwest::Client::new();
140    let url = format!("{}/v1/database/{}/names", config.get_host_url(server)?, identity);
141    client.get(url).send().await?.json_or_error().await
142}
143
144/// Add an authorization header, if provided, to the request `builder`.
145pub fn add_auth_header_opt(mut builder: RequestBuilder, auth_header: &AuthHeader) -> RequestBuilder {
146    if let Some(token) = &auth_header.token {
147        builder = builder.bearer_auth(token);
148    }
149    builder
150}
151
152/// Gets the `auth_header` for a request to the server depending on how you want
153/// to identify yourself.  If you specify `anon_identity = true` then no
154/// `auth_header` is returned. If you specify an identity this function will try
155/// to find the identity in the config file. If no identity can be found, the
156/// program will `exit(1)`. If you do not specify an identity this function will
157/// either get the default identity if one exists or create and save a new
158/// default identity returning the one that was just created.
159///
160/// # Arguments
161///  * `config` - The config file reference
162///  * `anon_identity` - Whether or not to just use an anonymous identity (no identity)
163///  * `identity_or_name` - The identity to try to lookup, which is typically provided from the command line
164pub async fn get_auth_header(
165    config: &mut Config,
166    anon_identity: bool,
167    target_server: Option<&str>,
168    interactive: bool,
169) -> anyhow::Result<AuthHeader> {
170    let token = if anon_identity {
171        None
172    } else {
173        Some(get_login_token_or_log_in(config, target_server, interactive).await?)
174    };
175    Ok(AuthHeader { token })
176}
177
178#[derive(Debug, Clone)]
179pub struct AuthHeader {
180    token: Option<String>,
181}
182impl AuthHeader {
183    pub fn to_header(&self) -> Option<http::HeaderValue> {
184        self.token.as_ref().map(|token| {
185            let mut val = http::HeaderValue::try_from(["Bearer ", token].concat()).unwrap();
186            val.set_sensitive(true);
187            val
188        })
189    }
190}
191
/// URL protocols (schemes) accepted by [`url_to_host_and_protocol`].
pub const VALID_PROTOCOLS: [&str; 2] = ["http", "https"];
193
194#[derive(Clone, Copy, PartialEq, Debug)]
195pub enum ModuleLanguage {
196    Csharp,
197    Rust,
198}
199impl clap::ValueEnum for ModuleLanguage {
200    fn value_variants<'a>() -> &'a [Self] {
201        &[Self::Csharp, Self::Rust]
202    }
203    fn to_possible_value(&self) -> Option<clap::builder::PossibleValue> {
204        match self {
205            Self::Csharp => Some(clap::builder::PossibleValue::new("csharp").aliases(["c#", "cs", "C#", "CSharp"])),
206            Self::Rust => Some(clap::builder::PossibleValue::new("rust").aliases(["rs", "Rust"])),
207        }
208    }
209}
210
211pub fn detect_module_language(path_to_project: &Path) -> anyhow::Result<ModuleLanguage> {
212    // TODO: Possible add a config file durlng spacetime init with the language
213    // check for Cargo.toml
214    if path_to_project.join("Cargo.toml").exists() {
215        Ok(ModuleLanguage::Rust)
216    } else if path_to_project
217        .read_dir()
218        .unwrap()
219        .any(|entry| entry.unwrap().path().extension() == Some("csproj".as_ref()))
220    {
221        Ok(ModuleLanguage::Csharp)
222    } else {
223        anyhow::bail!("Could not detect the language of the module. Are you in a SpacetimeDB project directory?")
224    }
225}
226
227pub fn url_to_host_and_protocol(url: &str) -> anyhow::Result<(&str, &str)> {
228    if contains_protocol(url) {
229        let protocol = url.split("://").next().unwrap();
230        let host = url.split("://").last().unwrap();
231
232        if !VALID_PROTOCOLS.contains(&protocol) {
233            Err(anyhow::anyhow!("Invalid protocol: {}", protocol))
234        } else {
235            Ok((host, protocol))
236        }
237    } else {
238        Err(anyhow::anyhow!("Invalid url: {}", url))
239    }
240}
241
/// Reports whether `name_or_url` includes a `://` protocol separator, i.e. whether it
/// looks like a full URL rather than a bare host or server name.
pub fn contains_protocol(name_or_url: &str) -> bool {
    name_or_url.split_once("://").is_some()
}
245
246pub fn host_or_url_to_host_and_protocol(host_or_url: &str) -> (&str, Option<&str>) {
247    if contains_protocol(host_or_url) {
248        let (host, protocol) = url_to_host_and_protocol(host_or_url).unwrap();
249        (host, Some(protocol))
250    } else {
251        (host_or_url, None)
252    }
253}
254
255/// Prompt the user for `y` or `n` from stdin.
256///
257/// Return `false` unless the input is `y`.
258pub fn y_or_n(force: bool, prompt: &str) -> anyhow::Result<bool> {
259    if force {
260        println!("Skipping confirmation due to --yes");
261        return Ok(true);
262    }
263    let mut input = String::new();
264    print!("{prompt} [y/N]");
265    std::io::stdout().flush()?;
266    std::io::stdin().read_line(&mut input)?;
267    let input = input.trim().to_lowercase();
268    Ok(input == "y" || input == "yes")
269}
270
271pub fn unauth_error_context<T>(res: anyhow::Result<T>, identity: &str, server: &str) -> anyhow::Result<T> {
272    res.with_context(|| {
273        format!(
274            "Identity {identity} is not valid for server {server}.
275Please log back in with `spacetime logout` and then `spacetime login`."
276        )
277    })
278}
279
280pub fn decode_identity(token: &String) -> anyhow::Result<String> {
281    // Here, we manually extract and decode the claims from the json web token.
282    // We do this without using the `jsonwebtoken` crate because it doesn't seem to have a way to skip signature verification.
283    // But signature verification would require getting the public key from a server, and we don't necessarily want to do that.
284    let token_parts: Vec<_> = token.split('.').collect();
285    if token_parts.len() != 3 {
286        return Err(anyhow::anyhow!("Token does not look like a JSON web token: {}", token));
287    }
288    let decoded_bytes = BASE_64_STD_NO_PAD.decode(token_parts[1])?;
289    let decoded_string = String::from_utf8(decoded_bytes)?;
290
291    let claims_data: IncomingClaims = serde_json::from_str(decoded_string.as_str())?;
292    let claims_data: SpacetimeIdentityClaims = claims_data.try_into()?;
293
294    Ok(claims_data.identity.to_string())
295}
296
297pub async fn get_login_token_or_log_in(
298    config: &mut Config,
299    target_server: Option<&str>,
300    interactive: bool,
301) -> anyhow::Result<String> {
302    if let Some(token) = config.spacetimedb_token() {
303        return Ok(token.clone());
304    }
305
306    // Note: We pass `force: false` to `y_or_n` because if we're running non-interactively we want to default to "no", not yes!
307    let full_login = interactive
308        && y_or_n(
309            false,
310            // It would be "ideal" if we could print the `spacetimedb.com` by deriving it from the `default_auth_host` constant,
311            // but this will change _so_ infrequently that it's not even worth the time to write that code and test it.
312            "You are not logged in. Would you like to log in with spacetimedb.com?",
313        )?;
314
315    if full_login {
316        let host = Url::parse(DEFAULT_AUTH_HOST)?;
317        spacetimedb_login_force(config, &host, false).await
318    } else {
319        let host = Url::parse(&config.get_host_url(target_server)?)?;
320        spacetimedb_login_force(config, &host, true).await
321    }
322}
323
324pub fn resolve_sibling_binary(bin_name: &str) -> anyhow::Result<PathBuf> {
325    let resolved_exe = std::env::current_exe().context("could not retrieve current exe")?;
326    let bin_path = resolved_exe
327        .parent()
328        .unwrap()
329        .join(bin_name)
330        .with_extension(std::env::consts::EXE_EXTENSION);
331    Ok(bin_path)
332}