1use anyhow::Context;
2use base64::{engine::general_purpose::STANDARD_NO_PAD as BASE_64_STD_NO_PAD, Engine as _};
3use reqwest::{RequestBuilder, Url};
4use spacetimedb_auth::identity::{IncomingClaims, SpacetimeIdentityClaims};
5use spacetimedb_client_api_messages::name::GetNamesResponse;
6use spacetimedb_lib::Identity;
7use std::io::Write;
8use std::path::{Path, PathBuf};
9
10use crate::config::Config;
11use crate::login::{spacetimedb_login_force, DEFAULT_AUTH_HOST};
12
/// Warning text for commands whose interface is not yet stable.
pub const UNSTABLE_WARNING: &str = "WARNING: This command is UNSTABLE and subject to breaking changes.";
14
15pub async fn database_identity(
17 config: &Config,
18 name_or_identity: &str,
19 server: Option<&str>,
20) -> Result<Identity, anyhow::Error> {
21 if let Ok(identity) = Identity::from_hex(name_or_identity) {
22 return Ok(identity);
23 }
24 spacetime_dns(config, name_or_identity, server)
25 .await?
26 .with_context(|| format!("the dns resolution of `{name_or_identity}` failed."))
27}
28
29pub(crate) trait ResponseExt: Sized {
30 async fn ensure_content_type(self, content_type: &str) -> anyhow::Result<Self>;
41
42 async fn json_or_error<T: serde::de::DeserializeOwned>(self) -> anyhow::Result<T>;
44
45 fn found(self) -> Option<Self>;
47}
48
49fn err_status_desc(status: http::StatusCode) -> Option<&'static str> {
50 if status.is_success() {
51 None
52 } else if status.is_client_error() {
53 Some("HTTP status client error")
54 } else if status.is_server_error() {
55 Some("HTTP status server error")
56 } else {
57 Some("unexpected HTTP status code")
58 }
59}
60
61impl ResponseExt for reqwest::Response {
62 async fn ensure_content_type(self, content_type: &str) -> anyhow::Result<Self> {
63 let status = self.status();
64 if self
65 .headers()
66 .get(http::header::CONTENT_TYPE)
67 .is_some_and(|ty| ty == content_type)
68 {
69 return Ok(self);
70 }
71 let url = self.url();
72 let Some(status_desc) = err_status_desc(status) else {
73 anyhow::bail!("HTTP response from url ({url}) was success but did not have content-type: {content_type}");
74 };
75 let url = url.to_string();
76 let status_err = match self.error_for_status_ref() {
77 Err(e) => anyhow::Error::from(e),
78 Ok(_) => anyhow::anyhow!("{status_desc} ({status}) from url ({url})"),
79 };
80 let err = match self.text().await {
81 Ok(text) => status_err.context(text),
82 Err(err) => anyhow::Error::from(err)
83 .context(format!("{status_desc} ({status})"))
84 .context("failed to get response text"),
85 };
86 Err(err)
87 }
88
89 async fn json_or_error<T: serde::de::DeserializeOwned>(self) -> anyhow::Result<T> {
90 let status = self.status();
91 self.ensure_content_type("application/json")
92 .await?
93 .json()
94 .await
95 .map_err(|err| {
96 let mut err = anyhow::Error::from(err);
97 if let Some(desc) = err_status_desc(status) {
98 err = err.context(format!("malformed json payload for {desc} ({status})"))
99 }
100 err
101 })
102 }
103
104 fn found(self) -> Option<Self> {
105 (self.status() != http::StatusCode::NOT_FOUND).then_some(self)
106 }
107}
108
109pub async fn spacetime_dns(
111 config: &Config,
112 domain: &str,
113 server: Option<&str>,
114) -> Result<Option<Identity>, anyhow::Error> {
115 let client = reqwest::Client::new();
116 let url = format!("{}/v1/database/{}/identity", config.get_host_url(server)?, domain);
117 let Some(res) = client.get(url).send().await?.found() else {
118 return Ok(None);
119 };
120 let text = res.error_for_status()?.text().await?;
121 text.parse()
122 .map(Some)
123 .context("identity endpoint did not return an identity")
124}
125
126pub async fn spacetime_server_fingerprint(url: &str) -> anyhow::Result<String> {
127 let builder = reqwest::Client::new().get(format!("{url}/v1/identity/public-key").as_str());
128 let res = builder.send().await?.error_for_status()?;
129 let fingerprint = res.text().await?;
130 Ok(fingerprint)
131}
132
133pub async fn spacetime_reverse_dns(
135 config: &Config,
136 identity: &str,
137 server: Option<&str>,
138) -> Result<GetNamesResponse, anyhow::Error> {
139 let client = reqwest::Client::new();
140 let url = format!("{}/v1/database/{}/names", config.get_host_url(server)?, identity);
141 client.get(url).send().await?.json_or_error().await
142}
143
144pub fn add_auth_header_opt(mut builder: RequestBuilder, auth_header: &AuthHeader) -> RequestBuilder {
146 if let Some(token) = &auth_header.token {
147 builder = builder.bearer_auth(token);
148 }
149 builder
150}
151
152pub async fn get_auth_header(
165 config: &mut Config,
166 anon_identity: bool,
167 target_server: Option<&str>,
168 interactive: bool,
169) -> anyhow::Result<AuthHeader> {
170 let token = if anon_identity {
171 None
172 } else {
173 Some(get_login_token_or_log_in(config, target_server, interactive).await?)
174 };
175 Ok(AuthHeader { token })
176}
177
178#[derive(Debug, Clone)]
179pub struct AuthHeader {
180 token: Option<String>,
181}
182impl AuthHeader {
183 pub fn to_header(&self) -> Option<http::HeaderValue> {
184 self.token.as_ref().map(|token| {
185 let mut val = http::HeaderValue::try_from(["Bearer ", token].concat()).unwrap();
186 val.set_sensitive(true);
187 val
188 })
189 }
190}
191
/// URL schemes accepted by `url_to_host_and_protocol`.
pub const VALID_PROTOCOLS: [&str; 2] = ["http", "https"];
193
194#[derive(Clone, Copy, PartialEq, Debug)]
195pub enum ModuleLanguage {
196 Csharp,
197 Rust,
198}
199impl clap::ValueEnum for ModuleLanguage {
200 fn value_variants<'a>() -> &'a [Self] {
201 &[Self::Csharp, Self::Rust]
202 }
203 fn to_possible_value(&self) -> Option<clap::builder::PossibleValue> {
204 match self {
205 Self::Csharp => Some(clap::builder::PossibleValue::new("csharp").aliases(["c#", "cs", "C#", "CSharp"])),
206 Self::Rust => Some(clap::builder::PossibleValue::new("rust").aliases(["rs", "Rust"])),
207 }
208 }
209}
210
211pub fn detect_module_language(path_to_project: &Path) -> anyhow::Result<ModuleLanguage> {
212 if path_to_project.join("Cargo.toml").exists() {
215 Ok(ModuleLanguage::Rust)
216 } else if path_to_project
217 .read_dir()
218 .unwrap()
219 .any(|entry| entry.unwrap().path().extension() == Some("csproj".as_ref()))
220 {
221 Ok(ModuleLanguage::Csharp)
222 } else {
223 anyhow::bail!("Could not detect the language of the module. Are you in a SpacetimeDB project directory?")
224 }
225}
226
227pub fn url_to_host_and_protocol(url: &str) -> anyhow::Result<(&str, &str)> {
228 if contains_protocol(url) {
229 let protocol = url.split("://").next().unwrap();
230 let host = url.split("://").last().unwrap();
231
232 if !VALID_PROTOCOLS.contains(&protocol) {
233 Err(anyhow::anyhow!("Invalid protocol: {}", protocol))
234 } else {
235 Ok((host, protocol))
236 }
237 } else {
238 Err(anyhow::anyhow!("Invalid url: {}", url))
239 }
240}
241
/// Reports whether `name_or_url` carries an explicit scheme, i.e. contains `://`.
pub fn contains_protocol(name_or_url: &str) -> bool {
    const SCHEME_SEPARATOR: &str = "://";
    name_or_url.contains(SCHEME_SEPARATOR)
}
245
246pub fn host_or_url_to_host_and_protocol(host_or_url: &str) -> (&str, Option<&str>) {
247 if contains_protocol(host_or_url) {
248 let (host, protocol) = url_to_host_and_protocol(host_or_url).unwrap();
249 (host, Some(protocol))
250 } else {
251 (host_or_url, None)
252 }
253}
254
255pub fn y_or_n(force: bool, prompt: &str) -> anyhow::Result<bool> {
259 if force {
260 println!("Skipping confirmation due to --yes");
261 return Ok(true);
262 }
263 let mut input = String::new();
264 print!("{prompt} [y/N]");
265 std::io::stdout().flush()?;
266 std::io::stdin().read_line(&mut input)?;
267 let input = input.trim().to_lowercase();
268 Ok(input == "y" || input == "yes")
269}
270
271pub fn unauth_error_context<T>(res: anyhow::Result<T>, identity: &str, server: &str) -> anyhow::Result<T> {
272 res.with_context(|| {
273 format!(
274 "Identity {identity} is not valid for server {server}.
275Please log back in with `spacetime logout` and then `spacetime login`."
276 )
277 })
278}
279
280pub fn decode_identity(token: &String) -> anyhow::Result<String> {
281 let token_parts: Vec<_> = token.split('.').collect();
285 if token_parts.len() != 3 {
286 return Err(anyhow::anyhow!("Token does not look like a JSON web token: {}", token));
287 }
288 let decoded_bytes = BASE_64_STD_NO_PAD.decode(token_parts[1])?;
289 let decoded_string = String::from_utf8(decoded_bytes)?;
290
291 let claims_data: IncomingClaims = serde_json::from_str(decoded_string.as_str())?;
292 let claims_data: SpacetimeIdentityClaims = claims_data.try_into()?;
293
294 Ok(claims_data.identity.to_string())
295}
296
297pub async fn get_login_token_or_log_in(
298 config: &mut Config,
299 target_server: Option<&str>,
300 interactive: bool,
301) -> anyhow::Result<String> {
302 if let Some(token) = config.spacetimedb_token() {
303 return Ok(token.clone());
304 }
305
306 let full_login = interactive
308 && y_or_n(
309 false,
310 "You are not logged in. Would you like to log in with spacetimedb.com?",
313 )?;
314
315 if full_login {
316 let host = Url::parse(DEFAULT_AUTH_HOST)?;
317 spacetimedb_login_force(config, &host, false).await
318 } else {
319 let host = Url::parse(&config.get_host_url(target_server)?)?;
320 spacetimedb_login_force(config, &host, true).await
321 }
322}
323
324pub fn resolve_sibling_binary(bin_name: &str) -> anyhow::Result<PathBuf> {
325 let resolved_exe = std::env::current_exe().context("could not retrieve current exe")?;
326 let bin_path = resolved_exe
327 .parent()
328 .unwrap()
329 .join(bin_name)
330 .with_extension(std::env::consts::EXE_EXTENSION);
331 Ok(bin_path)
332}