trailbase_client/
lib.rs

1//! A client library to connect to a TrailBase server via HTTP.
2//!
3//! TrailBase is a sub-millisecond, open-source application server with type-safe APIs, built-in
4//! JS/ES6/TS runtime, realtime, auth, and admin UI built on Rust, SQLite & V8.
5
6#![allow(clippy::needless_return)]
7
8use eventsource_stream::Eventsource;
9pub use futures::Stream;
10use futures::StreamExt;
11use parking_lot::RwLock;
12use reqwest::header::{HeaderMap, HeaderValue};
13use reqwest::Method;
14use std::borrow::Cow;
15use std::sync::Arc;
16use thiserror::Error;
17
18use serde::de::DeserializeOwned;
19use serde::{Deserialize, Serialize};
20
21#[derive(Debug, Error)]
22pub enum Error {
23  #[error("Reqwest: {0}")]
24  Reqwest(#[from] reqwest::Error),
25  #[error("JSON: {0}")]
26  Json(#[from] serde_json::Error),
27  #[error("JWT: {0}")]
28  Jwt(#[from] jsonwebtoken::errors::Error),
29  #[error("Url: {0}")]
30  Url(#[from] url::ParseError),
31  #[error("Precondition: {0}")]
32  Precondition(&'static str),
33}
34
/// The currently logged-in user, as described by the auth token's claims.
#[derive(Clone, Debug)]
pub struct User {
  /// The auth token's `sub` claim.
  pub sub: String,
  /// The auth token's `email` claim.
  pub email: String,
}
41
42/// Holds the tokens minted by the server on login.
43///
44/// It is also the exact JSON serialization format.
45#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
46pub struct Tokens {
47  pub auth_token: String,
48  pub refresh_token: Option<String>,
49  pub csrf_token: Option<String>,
50}
51
/// Cursor-based pagination parameters for [`RecordApi::list`].
#[derive(Clone, Debug, Default)]
pub struct Pagination {
  /// Sent as the `cursor` query parameter; presumably the `cursor` of a previous
  /// [`ListResponse`]. `None` omits the parameter.
  pub cursor: Option<String>,
  /// Sent as the `limit` query parameter. `None` omits the parameter.
  pub limit: Option<usize>,
}

impl Pagination {
  /// Builds pagination from an optional `limit` and an optional `cursor`.
  pub fn with(limit: impl Into<Option<usize>>, cursor: impl Into<Option<String>>) -> Pagination {
    let (limit, cursor) = (limit.into(), cursor.into());
    return Pagination { cursor, limit };
  }

  /// Builds pagination with only a `limit`; the cursor is left unset.
  pub fn with_limit(limit: impl Into<Option<usize>>) -> Pagination {
    return Pagination {
      limit: limit.into(),
      ..Default::default()
    };
  }

  /// Builds pagination with only a `cursor`; the limit is left unset.
  pub fn with_cursor(cursor: impl Into<Option<String>>) -> Pagination {
    return Pagination {
      cursor: cursor.into(),
      ..Default::default()
    };
  }
}
74
75#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
76pub enum DbEvent {
77  Update(Option<serde_json::Value>),
78  Insert(Option<serde_json::Value>),
79  Delete(Option<serde_json::Value>),
80  Error(String),
81}
82
83#[derive(Clone, Debug, Deserialize)]
84pub struct ListResponse<T> {
85  pub cursor: Option<String>,
86  pub total_count: Option<usize>,
87  pub records: Vec<T>,
88}
89
/// Types usable as a record id, i.e. the last path segment of a record URL.
pub trait RecordId<'a> {
  /// Converts the id into its string form, borrowing when the input is already a string slice.
  fn serialized_id(self) -> Cow<'a, str>;
}

impl RecordId<'_> for String {
  fn serialized_id(self) -> Cow<'static, str> {
    return Cow::from(self);
  }
}

impl<'a> RecordId<'a> for &'a String {
  fn serialized_id(self) -> Cow<'a, str> {
    return Cow::from(self.as_str());
  }
}

impl<'a> RecordId<'a> for &'a str {
  fn serialized_id(self) -> Cow<'a, str> {
    return Cow::from(self);
  }
}

impl RecordId<'_> for i64 {
  fn serialized_id(self) -> Cow<'static, str> {
    let text = self.to_string();
    return Cow::from(text);
  }
}
117
118pub trait ReadArgumentsTrait<'a> {
119  fn serialized_id(self) -> Cow<'a, str>;
120  fn expand(&self) -> Option<&'a [&'a str]>;
121}
122
123impl<'a, T: RecordId<'a>> ReadArgumentsTrait<'a> for T {
124  fn serialized_id(self) -> Cow<'a, str> {
125    return self.serialized_id();
126  }
127
128  fn expand(&self) -> Option<&'a [&'a str]> {
129    return None;
130  }
131}
132
133#[derive(Debug, Default)]
134pub struct ReadArguments<'a, T: RecordId<'a>> {
135  pub id: T,
136  pub expand: Option<&'a [&'a str]>,
137}
138
139impl<'a, T: RecordId<'a>> ReadArgumentsTrait<'a> for ReadArguments<'a, T> {
140  fn serialized_id(self) -> Cow<'a, str> {
141    return self.id.serialized_id();
142  }
143
144  fn expand(&self) -> Option<&'a [&'a str]> {
145    return self.expand;
146  }
147}
148
149struct ThinClient {
150  client: reqwest::Client,
151  url: url::Url,
152}
153
154impl ThinClient {
155  async fn fetch<T: Serialize>(
156    &self,
157    path: &str,
158    headers: HeaderMap,
159    method: Method,
160    body: Option<&T>,
161    query_params: Option<&[(Cow<'static, str>, Cow<'static, str>)]>,
162  ) -> Result<reqwest::Response, Error> {
163    assert!(path.starts_with("/"));
164
165    let mut url = self.url.clone();
166    url.set_path(path);
167
168    if let Some(query_params) = query_params {
169      let mut params = url.query_pairs_mut();
170      for (key, value) in query_params {
171        params.append_pair(key, value);
172      }
173    }
174
175    let request = {
176      let mut builder = self.client.request(method, url).headers(headers);
177      if let Some(ref body) = body {
178        builder = builder.body(serde_json::to_string(body)?);
179      }
180      builder.build()?
181    };
182
183    return Ok(self.client.execute(request).await?);
184  }
185}
186
187#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
188struct JwtTokenClaims {
189  sub: String,
190  iat: i64,
191  exp: i64,
192  email: String,
193  csrf_token: String,
194}
195
196fn decode_auth_token<T: DeserializeOwned>(token: &str) -> Result<T, Error> {
197  let decoding_key = jsonwebtoken::DecodingKey::from_secret(&[]);
198
199  // Don't validate the token, we don't have the secret key. Just deserialize the claims/contents.
200  let mut validation = jsonwebtoken::Validation::new(jsonwebtoken::Algorithm::EdDSA);
201  validation.insecure_disable_signature_validation();
202
203  return Ok(jsonwebtoken::decode::<T>(token, &decoding_key, &validation).map(|data| data.claims)?);
204}
205
206#[derive(Clone)]
207pub struct RecordApi {
208  client: Arc<ClientState>,
209  name: String,
210}
211
212#[derive(Default)]
213pub struct ListArguments<'a> {
214  pub pagination: Pagination,
215  pub order: Option<&'a [&'a str]>,
216  pub filters: Option<&'a [&'a str]>,
217  pub expand: Option<&'a [&'a str]>,
218  pub count: bool,
219}
220
221impl RecordApi {
222  pub async fn list<T: DeserializeOwned>(
223    &self,
224    args: ListArguments<'_>,
225  ) -> Result<ListResponse<T>, Error> {
226    let mut params: Vec<(Cow<'static, str>, Cow<'static, str>)> = vec![];
227    if let Some(cursor) = args.pagination.cursor {
228      params.push((Cow::Borrowed("cursor"), Cow::Owned(cursor)));
229    }
230
231    if let Some(limit) = args.pagination.limit {
232      params.push((Cow::Borrowed("limit"), Cow::Owned(limit.to_string())));
233    }
234
235    #[inline]
236    fn to_list(slice: &[&str]) -> String {
237      return slice.join(",");
238    }
239
240    if let Some(order) = args.order {
241      if !order.is_empty() {
242        params.push((Cow::Borrowed("order"), Cow::Owned(to_list(order))));
243      }
244    }
245
246    if let Some(expand) = args.expand {
247      if !expand.is_empty() {
248        params.push((Cow::Borrowed("expand"), Cow::Owned(to_list(expand))));
249      }
250    }
251
252    if args.count {
253      params.push((Cow::Borrowed("count"), Cow::Borrowed("true")));
254    }
255
256    if let Some(filters) = args.filters {
257      for filter in filters {
258        let Some((name_op, value)) = filter.split_once("=") else {
259          panic!("Filter '{filter}' does not match: 'name[op]=value'");
260        };
261
262        params.push((
263          Cow::Owned(name_op.to_string()),
264          Cow::Owned(value.to_string()),
265        ));
266      }
267    }
268
269    let response = self
270      .client
271      .fetch(
272        &format!("/{RECORD_API}/{}", self.name),
273        Method::GET,
274        None::<&()>,
275        Some(&params),
276      )
277      .await?;
278
279    return Ok(response.json().await?);
280  }
281
282  pub async fn read<'a, T: DeserializeOwned>(
283    &self,
284    args: impl ReadArgumentsTrait<'a>,
285  ) -> Result<T, Error> {
286    let expand = args
287      .expand()
288      .map(|e| vec![(Cow::Borrowed("expand"), Cow::Owned(e.join(",")))]);
289
290    let response = self
291      .client
292      .fetch(
293        &format!(
294          "/{RECORD_API}/{name}/{id}",
295          name = self.name,
296          id = args.serialized_id()
297        ),
298        Method::GET,
299        None::<&()>,
300        expand.as_deref(),
301      )
302      .await?;
303
304    return Ok(response.json().await?);
305  }
306
307  pub async fn create<T: Serialize>(&self, record: T) -> Result<String, Error> {
308    return Ok(self.create_impl(record).await?.swap_remove(0));
309  }
310
311  pub async fn create_bulk<T: Serialize>(&self, record: &[T]) -> Result<Vec<String>, Error> {
312    return self.create_impl(record).await;
313  }
314
315  async fn create_impl<T: Serialize>(&self, record: T) -> Result<Vec<String>, Error> {
316    let response = self
317      .client
318      .fetch(
319        &format!("/{RECORD_API}/{name}", name = self.name),
320        Method::POST,
321        Some(&record),
322        None,
323      )
324      .await?;
325
326    #[derive(Deserialize)]
327    pub struct RecordIdResponse {
328      pub ids: Vec<String>,
329    }
330
331    return Ok(response.json::<RecordIdResponse>().await?.ids);
332  }
333
334  pub async fn update<'a, T: Serialize>(
335    &self,
336    id: impl RecordId<'a>,
337    record: T,
338  ) -> Result<(), Error> {
339    self
340      .client
341      .fetch(
342        &format!(
343          "/{RECORD_API}/{name}/{id}",
344          name = self.name,
345          id = id.serialized_id()
346        ),
347        Method::PATCH,
348        Some(&record),
349        None,
350      )
351      .await?;
352
353    return Ok(());
354  }
355
356  pub async fn delete<'a>(&self, id: impl RecordId<'a>) -> Result<(), Error> {
357    self
358      .client
359      .fetch(
360        &format!(
361          "/{RECORD_API}/{name}/{id}",
362          name = self.name,
363          id = id.serialized_id()
364        ),
365        Method::DELETE,
366        None::<&()>,
367        None,
368      )
369      .await?;
370
371    return Ok(());
372  }
373
374  pub async fn subscribe<'a>(
375    &self,
376    id: impl RecordId<'a>,
377  ) -> Result<impl Stream<Item = DbEvent>, Error> {
378    // TODO: Might have to add HeaderValue::from_static("text/event-stream").
379    let response = self
380      .client
381      .fetch(
382        &format!(
383          "/{RECORD_API}/{name}/subscribe/{id}",
384          name = self.name,
385          id = id.serialized_id()
386        ),
387        Method::GET,
388        None::<&()>,
389        None,
390      )
391      .await?;
392
393    return Ok(
394      response
395        .bytes_stream()
396        .eventsource()
397        .filter_map(|event_or| async {
398          if let Ok(event) = event_or {
399            if let Ok(db_event) = serde_json::from_str::<DbEvent>(&event.data) {
400              return Some(db_event);
401            }
402          }
403          return None;
404        }),
405    );
406  }
407}
408
409#[derive(Clone, Debug)]
410struct TokenState {
411  state: Option<(Tokens, JwtTokenClaims)>,
412  headers: HeaderMap,
413}
414
415impl TokenState {
416  fn build(tokens: Option<&Tokens>) -> TokenState {
417    let headers = build_headers(tokens);
418    return TokenState {
419      state: tokens.and_then(|tokens| {
420        let Ok(jwt_token) = decode_auth_token::<JwtTokenClaims>(&tokens.auth_token) else {
421          log::error!("Failed to decode auth token.");
422          return None;
423        };
424        return Some((tokens.clone(), jwt_token));
425      }),
426      headers,
427    };
428  }
429}
430
431struct ClientState {
432  client: ThinClient,
433  site: String,
434  tokens: RwLock<TokenState>,
435}
436
437impl ClientState {
438  #[inline]
439  async fn fetch<T: Serialize>(
440    &self,
441    path: &str,
442    method: Method,
443    body: Option<&T>,
444    query_params: Option<&[(Cow<'static, str>, Cow<'static, str>)]>,
445  ) -> Result<reqwest::Response, Error> {
446    let (mut headers, refresh_token) = self.extract_headers_and_refresh_token_if_exp();
447    if let Some(refresh_token) = refresh_token {
448      let new_tokens = ClientState::refresh_tokens(&self.client, headers, refresh_token).await?;
449
450      headers = new_tokens.headers.clone();
451      *self.tokens.write() = new_tokens;
452    }
453
454    return Ok(
455      self
456        .client
457        .fetch(path, headers, method, body, query_params)
458        .await?
459        .error_for_status()?,
460    );
461  }
462
463  #[inline]
464  fn extract_headers_and_refresh_token_if_exp(&self) -> (HeaderMap, Option<String>) {
465    #[inline]
466    fn should_refresh(jwt: &JwtTokenClaims) -> bool {
467      return jwt.exp - 60 < now() as i64;
468    }
469
470    let tokens = self.tokens.read();
471    let headers = tokens.headers.clone();
472    return match tokens.state {
473      Some(ref state) if should_refresh(&state.1) => (headers, state.0.refresh_token.clone()),
474      _ => (headers, None),
475    };
476  }
477
478  fn extract_headers_refresh_token(&self) -> Result<(HeaderMap, String), Error> {
479    let tokens = self.tokens.read();
480    let Some(ref state) = tokens.state else {
481      return Err(Error::Precondition("Not logged int?"));
482    };
483
484    let Some(ref refresh_token) = state.0.refresh_token else {
485      return Err(Error::Precondition("Missing refresh token"));
486    };
487
488    return Ok((tokens.headers.clone(), refresh_token.clone()));
489  }
490
491  async fn refresh_tokens(
492    client: &ThinClient,
493    headers: HeaderMap,
494    refresh_token: String,
495  ) -> Result<TokenState, Error> {
496    #[derive(Serialize)]
497    struct RefreshRequest<'a> {
498      refresh_token: &'a str,
499    }
500
501    let response = client
502      .fetch(
503        &format!("/{AUTH_API}/refresh"),
504        headers,
505        Method::POST,
506        Some(&RefreshRequest {
507          refresh_token: &refresh_token,
508        }),
509        None,
510      )
511      .await?;
512
513    #[derive(Deserialize)]
514    struct RefreshResponse {
515      auth_token: String,
516      csrf_token: Option<String>,
517    }
518
519    let refresh_response: RefreshResponse = response.json().await?;
520    return Ok(TokenState::build(Some(&Tokens {
521      auth_token: refresh_response.auth_token,
522      refresh_token: Some(refresh_token),
523      csrf_token: refresh_response.csrf_token,
524    })));
525  }
526}
527
528#[derive(Clone)]
529pub struct Client {
530  state: Arc<ClientState>,
531}
532
533impl Client {
534  pub fn new(site: &str, tokens: Option<Tokens>) -> Result<Client, Error> {
535    return Ok(Client {
536      state: Arc::new(ClientState {
537        client: ThinClient {
538          client: reqwest::Client::new(),
539          url: url::Url::parse(site)?,
540        },
541        site: site.to_string(),
542        tokens: RwLock::new(TokenState::build(tokens.as_ref())),
543      }),
544    });
545  }
546
547  pub fn site(&self) -> String {
548    return self.state.site.clone();
549  }
550
551  pub fn tokens(&self) -> Option<Tokens> {
552    return self.state.tokens.read().state.as_ref().map(|x| x.0.clone());
553  }
554
555  pub fn user(&self) -> Option<User> {
556    if let Some(state) = &self.state.tokens.read().state {
557      return Some(User {
558        sub: state.1.sub.clone(),
559        email: state.1.email.clone(),
560      });
561    }
562    return None;
563  }
564
565  pub fn records(&self, api_name: &str) -> RecordApi {
566    return RecordApi {
567      client: self.state.clone(),
568      name: api_name.to_string(),
569    };
570  }
571
572  pub async fn refresh(&self) -> Result<(), Error> {
573    let (headers, refresh_token) = self.state.extract_headers_refresh_token()?;
574    let new_tokens =
575      ClientState::refresh_tokens(&self.state.client, headers, refresh_token).await?;
576
577    *self.state.tokens.write() = new_tokens;
578    return Ok(());
579  }
580
581  pub async fn login(&self, email: &str, password: &str) -> Result<Tokens, Error> {
582    #[derive(Serialize)]
583    struct Credentials<'a> {
584      email: &'a str,
585      password: &'a str,
586    }
587
588    let response = self
589      .state
590      .fetch(
591        &format!("/{AUTH_API}/login"),
592        Method::POST,
593        Some(&Credentials { email, password }),
594        None,
595      )
596      .await?;
597
598    let tokens: Tokens = response.json().await?;
599    self.update_tokens(Some(&tokens));
600    return Ok(tokens);
601  }
602
603  pub async fn logout(&self) -> Result<(), Error> {
604    #[derive(Serialize)]
605    struct LogoutRequest {
606      refresh_token: String,
607    }
608
609    let response_or = match self.state.extract_headers_refresh_token() {
610      Ok((_headers, refresh_token)) => {
611        self
612          .state
613          .fetch(
614            &format!("/{AUTH_API}/logout"),
615            Method::POST,
616            Some(&LogoutRequest { refresh_token }),
617            None,
618          )
619          .await
620      }
621      _ => {
622        self
623          .state
624          .fetch(
625            &format!("/{AUTH_API}/logout"),
626            Method::GET,
627            None::<&()>,
628            None,
629          )
630          .await
631      }
632    };
633
634    self.update_tokens(None);
635
636    return response_or.map(|_| ());
637  }
638
639  fn update_tokens(&self, tokens: Option<&Tokens>) -> TokenState {
640    let state = TokenState::build(tokens);
641
642    *self.state.tokens.write() = state.clone();
643    // _authChange?.call(this, state.state?.$1);
644
645    if let Some(ref s) = state.state {
646      let now = now();
647      if s.1.exp < now as i64 {
648        log::warn!("Token expired");
649      }
650    }
651
652    return state;
653  }
654}
655
656fn build_headers(tokens: Option<&Tokens>) -> HeaderMap {
657  let mut base = HeaderMap::new();
658  base.insert("Content-Type", HeaderValue::from_static("application/json"));
659
660  if let Some(tokens) = tokens {
661    if let Ok(value) = HeaderValue::from_str(&format!("Bearer {}", tokens.auth_token)) {
662      base.insert("Authorization", value);
663    } else {
664      log::error!("Failed to build bearer token.");
665    }
666
667    if let Some(ref refresh) = tokens.refresh_token {
668      if let Ok(value) = HeaderValue::from_str(refresh) {
669        base.insert("Refresh-Token", value);
670      } else {
671        log::error!("Failed to build refresh token header.");
672      }
673    }
674
675    if let Some(ref csrf) = tokens.csrf_token {
676      if let Ok(value) = HeaderValue::from_str(csrf) {
677        base.insert("CSRF-Token", value);
678      } else {
679        log::error!("Failed to build refresh token header.");
680      }
681    }
682  }
683
684  return base;
685}
686
/// Returns the current Unix time in whole seconds.
///
/// # Panics
///
/// Panics if the system clock reports a time before the Unix epoch.
fn now() -> u64 {
  let elapsed = std::time::UNIX_EPOCH
    .elapsed()
    .expect("Duration since epoch");
  return elapsed.as_secs();
}
693
/// Path prefix of the auth API, without the leading `/`.
const AUTH_API: &str = "api/auth/v1";
/// Path prefix of the record API, without the leading `/`.
const RECORD_API: &str = "api/records/v1";
696
#[cfg(test)]
mod tests {
  use super::*;

  /// Spawns `read` calls onto the Tokio runtime.
  ///
  /// This only compiles if the returned futures are `Send`, i.e. no lock guard is held across an
  /// `.await`. Each call is expected to fail (presumably no server is listening at the address).
  #[tokio::test]
  async fn is_send_test() {
    let client = Client::new("http://127.0.0.1:4000", None).unwrap();
    let records = client.records("simple_strict_table");

    for _attempt in 0..2 {
      let api = records.clone();
      let handle = tokio::spawn(async move {
        let result = api.read::<serde_json::Value>(0).await;
        assert!(result.is_err());
      });
      handle.await.unwrap();
    }
  }
}