trailbase_client/
lib.rs

1//! A client library to connect to a TrailBase server via HTTP.
2//!
3//! TrailBase is a sub-millisecond, open-source application server with type-safe APIs, built-in
4//! JS/ES6/TS runtime, realtime, auth, and admin UI built on Rust, SQLite & V8.
5
6#![forbid(unsafe_code, clippy::unwrap_used)]
7#![allow(clippy::needless_return)]
8#![warn(clippy::await_holding_lock, clippy::inefficient_to_string)]
9
10use eventsource_stream::Eventsource;
11pub use futures::Stream;
12use futures::StreamExt;
13use parking_lot::RwLock;
14use reqwest::header::{self, HeaderMap, HeaderValue};
15use reqwest::{Method, StatusCode};
16use std::borrow::Cow;
17use std::sync::Arc;
18use thiserror::Error;
19use tracing::*;
20
21use serde::de::DeserializeOwned;
22use serde::{Deserialize, Serialize};
23
24// TODO: Don't leak internals and make this non_exhaustive.
25#[derive(Debug, Error)]
26#[non_exhaustive]
27pub enum Error {
28  #[error("HTTP status: {0}")]
29  HttpStatus(StatusCode),
30
31  #[error("MissingRefreshToken")]
32  MissingRefreshToken,
33
34  #[error("RecordSerialization: {0}")]
35  RecordSerialization(serde_json::Error),
36
37  #[error("InvalidToken: {0}")]
38  InvalidToken(jsonwebtoken::errors::Error),
39
40  #[error("InvalidUrl: {0}")]
41  InvalidUrl(url::ParseError),
42
43  // NOTE: This error is leaky but comprehensively unpacking reqwest is unsustainable.
44  #[error("Reqwest: {0}")]
45  OtherReqwest(reqwest::Error),
46}
47
48impl From<reqwest::Error> for Error {
49  fn from(err: reqwest::Error) -> Self {
50    match err.status() {
51      Some(code) => Self::HttpStatus(code),
52      _ => Self::OtherReqwest(err),
53    }
54  }
55}
56
/// Represents the currently logged-in user.
///
/// Built from the claims of the decoded auth token (see [`Client::user`]).
#[derive(Clone, Debug)]
pub struct User {
  /// The `sub` (subject) claim of the auth token.
  pub sub: String,
  /// The `email` claim of the auth token.
  pub email: String,
}
63
/// Holds the tokens minted by the server on login.
///
/// It is also the exact JSON serialization format.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct Tokens {
  /// JWT auth token; sent as `Authorization: Bearer <auth_token>`.
  pub auth_token: String,
  /// Optional refresh token used to mint new auth tokens; also sent as the `Refresh-Token` header.
  pub refresh_token: Option<String>,
  /// Optional CSRF token; sent as the `CSRF-Token` header.
  pub csrf_token: Option<String>,
}
73
/// Cursor/limit pagination options for [`RecordApi::list`].
#[derive(Clone, Debug, Default)]
pub struct Pagination {
  /// Cursor from a previous [`ListResponse`]; sent as the `cursor` query parameter.
  pub cursor: Option<String>,
  /// Maximum number of records to return; sent as the `limit` query parameter.
  pub limit: Option<usize>,
}

impl Pagination {
  /// Creates pagination options from an optional `limit` and an optional `cursor`.
  pub fn with(limit: impl Into<Option<usize>>, cursor: impl Into<Option<String>>) -> Pagination {
    let limit: Option<usize> = limit.into();
    let cursor: Option<String> = cursor.into();
    return Pagination { cursor, limit };
  }

  /// Creates pagination options that only set a `limit`.
  pub fn with_limit(limit: impl Into<Option<usize>>) -> Pagination {
    return Pagination {
      limit: limit.into(),
      ..Default::default()
    };
  }

  /// Creates pagination options that only set a `cursor`.
  pub fn with_cursor(cursor: impl Into<Option<String>>) -> Pagination {
    return Pagination {
      cursor: cursor.into(),
      ..Pagination::default()
    };
  }
}
96
/// A change event delivered by [`RecordApi::subscribe`].
///
/// Uses serde's default (externally tagged) enum representation, e.g. `{"Insert": {...}}`.
/// NOTE(review): the `Option<Value>` payloads presumably hold the affected record — confirm
/// against the server's event format.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum DbEvent {
  Update(Option<serde_json::Value>),
  Insert(Option<serde_json::Value>),
  Delete(Option<serde_json::Value>),
  /// An error message sent on the subscription stream.
  Error(String),
}
104
/// Response of [`RecordApi::list`].
#[derive(Clone, Debug, Deserialize)]
pub struct ListResponse<T> {
  /// Cursor presumably pointing at the next page; pass it back via [`Pagination`].
  pub cursor: Option<String>,
  /// Total number of matching records; presumably only set when `ListArguments::count` is true.
  pub total_count: Option<usize>,
  /// The records of this page.
  pub records: Vec<T>,
}
111
/// A value that identifies a record and can be rendered into a request path.
pub trait RecordId<'a> {
  /// Consumes the id and returns its textual form for use in URLs.
  fn serialized_id(self) -> Cow<'a, str>;
}

impl RecordId<'_> for String {
  fn serialized_id(self) -> Cow<'static, str> {
    return Cow::from(self);
  }
}

impl<'a> RecordId<'a> for &'a String {
  fn serialized_id(self) -> Cow<'a, str> {
    return Cow::from(self.as_str());
  }
}

impl<'a> RecordId<'a> for &'a str {
  fn serialized_id(self) -> Cow<'a, str> {
    return Cow::from(self);
  }
}

impl RecordId<'_> for i64 {
  fn serialized_id(self) -> Cow<'static, str> {
    let id = self.to_string();
    return Cow::from(id);
  }
}

/// Arguments accepted by [`RecordApi::read`]: a record id plus an optional `expand` list.
pub trait ReadArgumentsTrait<'a> {
  /// Consumes the arguments and returns the serialized record id.
  fn serialized_id(self) -> Cow<'a, str>;
  /// Returns the optional list forwarded as the `expand` query parameter.
  fn expand(&self) -> Option<&'a [&'a str]>;
}

// Any plain record id doubles as read arguments without expansion.
impl<'a, T: RecordId<'a>> ReadArgumentsTrait<'a> for T {
  fn serialized_id(self) -> Cow<'a, str> {
    // Fully qualified to make explicit that this delegates to `RecordId`.
    return <T as RecordId<'a>>::serialized_id(self);
  }

  fn expand(&self) -> Option<&'a [&'a str]> {
    return None;
  }
}

/// Read arguments combining a record id with an optional `expand` list.
#[derive(Debug, Default)]
pub struct ReadArguments<'a, T: RecordId<'a>> {
  /// The record to read.
  pub id: T,
  /// Optional list sent as the comma-joined `expand` query parameter.
  pub expand: Option<&'a [&'a str]>,
}

impl<'a, T: RecordId<'a>> ReadArgumentsTrait<'a> for ReadArguments<'a, T> {
  fn serialized_id(self) -> Cow<'a, str> {
    let ReadArguments { id, .. } = self;
    return <T as RecordId<'a>>::serialized_id(id);
  }

  fn expand(&self) -> Option<&'a [&'a str]> {
    return self.expand;
  }
}
170
/// Minimal HTTP layer: a reqwest client plus the server's base URL.
struct ThinClient {
  client: reqwest::Client,
  /// Base URL parsed from the site passed to `Client::new`; each request replaces its path.
  url: url::Url,
}
175
176impl ThinClient {
177  async fn fetch<T: Serialize>(
178    &self,
179    path: &str,
180    headers: HeaderMap,
181    method: Method,
182    body: Option<&T>,
183    query_params: Option<&[(Cow<'static, str>, Cow<'static, str>)]>,
184  ) -> Result<reqwest::Response, Error> {
185    assert!(path.starts_with("/"));
186
187    let mut url = self.url.clone();
188    url.set_path(path);
189
190    if let Some(query_params) = query_params {
191      let mut params = url.query_pairs_mut();
192      for (key, value) in query_params {
193        params.append_pair(key, value);
194      }
195    }
196
197    let request = {
198      let mut builder = self.client.request(method, url).headers(headers);
199      if let Some(ref body) = body {
200        let json = serde_json::to_string(body).map_err(Error::RecordSerialization)?;
201        builder = builder.body(json);
202      }
203      builder.build()?
204    };
205
206    return Ok(self.client.execute(request).await?);
207  }
208}
209
/// Claims expected in the auth token's payload.
///
/// None of the fields are `Option`s, so decoding fails if any claim is missing.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
struct JwtTokenClaims {
  /// Subject; exposed as `User::sub`.
  sub: String,
  /// Issued-at time; presumably seconds since the UNIX epoch per JWT convention — confirm.
  iat: i64,
  /// Expiry in seconds since the UNIX epoch (compared against `now()` to decide on refreshes).
  exp: i64,
  /// Exposed as `User::email`.
  email: String,
  csrf_token: String,
}
218
219fn decode_auth_token<T: DeserializeOwned>(token: &str) -> Result<T, Error> {
220  let decoding_key = jsonwebtoken::DecodingKey::from_secret(&[]);
221
222  // Don't validate the token, we don't have the secret key. Just deserialize the claims/contents.
223  let mut validation = jsonwebtoken::Validation::new(jsonwebtoken::Algorithm::EdDSA);
224  validation.insecure_disable_signature_validation();
225
226  return jsonwebtoken::decode::<T>(token, &decoding_key, &validation)
227    .map(|data| data.claims)
228    .map_err(Error::InvalidToken);
229}
230
/// Handle for a single record API, obtained via [`Client::records`].
///
/// Cheap to clone: clones share the client state through an `Arc`.
#[derive(Clone)]
pub struct RecordApi {
  client: Arc<ClientState>,
  /// Name of the record API; used as a path segment in request URLs.
  name: String,
}
236
/// Arguments for [`RecordApi::list`].
#[derive(Default)]
pub struct ListArguments<'a> {
  /// Cursor and limit, sent as the `cursor` / `limit` query parameters.
  pub pagination: Pagination,
  /// Order terms, comma-joined into the `order` query parameter (omitted if empty).
  pub order: Option<&'a [&'a str]>,
  /// Filters of the form `name[op]=value`, each sent as its own query parameter.
  /// A filter without `=` makes [`RecordApi::list`] panic.
  pub filters: Option<&'a [&'a str]>,
  /// Columns to expand, comma-joined into the `expand` query parameter (omitted if empty).
  pub expand: Option<&'a [&'a str]>,
  /// If `true`, sends `count=true`; presumably asks the server to fill `ListResponse::total_count`.
  pub count: bool,
}
245
246impl RecordApi {
247  pub async fn list<T: DeserializeOwned>(
248    &self,
249    args: ListArguments<'_>,
250  ) -> Result<ListResponse<T>, Error> {
251    let mut params: Vec<(Cow<'static, str>, Cow<'static, str>)> = vec![];
252    if let Some(cursor) = args.pagination.cursor {
253      params.push((Cow::Borrowed("cursor"), Cow::Owned(cursor)));
254    }
255
256    if let Some(limit) = args.pagination.limit {
257      params.push((Cow::Borrowed("limit"), Cow::Owned(limit.to_string())));
258    }
259
260    #[inline]
261    fn to_list(slice: &[&str]) -> String {
262      return slice.join(",");
263    }
264
265    if let Some(order) = args.order {
266      if !order.is_empty() {
267        params.push((Cow::Borrowed("order"), Cow::Owned(to_list(order))));
268      }
269    }
270
271    if let Some(expand) = args.expand {
272      if !expand.is_empty() {
273        params.push((Cow::Borrowed("expand"), Cow::Owned(to_list(expand))));
274      }
275    }
276
277    if args.count {
278      params.push((Cow::Borrowed("count"), Cow::Borrowed("true")));
279    }
280
281    if let Some(filters) = args.filters {
282      for filter in filters {
283        let Some((name_op, value)) = filter.split_once("=") else {
284          panic!("Filter '{filter}' does not match: 'name[op]=value'");
285        };
286
287        params.push((
288          Cow::Owned(name_op.to_string()),
289          Cow::Owned(value.to_string()),
290        ));
291      }
292    }
293
294    let response = self
295      .client
296      .fetch(
297        &format!("/{RECORD_API}/{}", self.name),
298        Method::GET,
299        None::<&()>,
300        Some(&params),
301      )
302      .await?;
303
304    return json(response).await;
305  }
306
307  pub async fn read<'a, T: DeserializeOwned>(
308    &self,
309    args: impl ReadArgumentsTrait<'a>,
310  ) -> Result<T, Error> {
311    let expand = args
312      .expand()
313      .map(|e| vec![(Cow::Borrowed("expand"), Cow::Owned(e.join(",")))]);
314
315    let response = self
316      .client
317      .fetch(
318        &format!(
319          "/{RECORD_API}/{name}/{id}",
320          name = self.name,
321          id = args.serialized_id()
322        ),
323        Method::GET,
324        None::<&()>,
325        expand.as_deref(),
326      )
327      .await?;
328
329    return json(response).await;
330  }
331
332  pub async fn create<T: Serialize>(&self, record: T) -> Result<String, Error> {
333    return Ok(self.create_impl(record).await?.swap_remove(0));
334  }
335
336  pub async fn create_bulk<T: Serialize>(&self, record: &[T]) -> Result<Vec<String>, Error> {
337    return self.create_impl(record).await;
338  }
339
340  async fn create_impl<T: Serialize>(&self, record: T) -> Result<Vec<String>, Error> {
341    let response = self
342      .client
343      .fetch(
344        &format!("/{RECORD_API}/{name}", name = self.name),
345        Method::POST,
346        Some(&record),
347        None,
348      )
349      .await?;
350
351    #[derive(Deserialize)]
352    pub struct RecordIdResponse {
353      pub ids: Vec<String>,
354    }
355
356    return Ok(json::<RecordIdResponse>(response).await?.ids);
357  }
358
359  pub async fn update<'a, T: Serialize>(
360    &self,
361    id: impl RecordId<'a>,
362    record: T,
363  ) -> Result<(), Error> {
364    self
365      .client
366      .fetch(
367        &format!(
368          "/{RECORD_API}/{name}/{id}",
369          name = self.name,
370          id = id.serialized_id()
371        ),
372        Method::PATCH,
373        Some(&record),
374        None,
375      )
376      .await?;
377
378    return Ok(());
379  }
380
381  pub async fn delete<'a>(&self, id: impl RecordId<'a>) -> Result<(), Error> {
382    self
383      .client
384      .fetch(
385        &format!(
386          "/{RECORD_API}/{name}/{id}",
387          name = self.name,
388          id = id.serialized_id()
389        ),
390        Method::DELETE,
391        None::<&()>,
392        None,
393      )
394      .await?;
395
396    return Ok(());
397  }
398
399  pub async fn subscribe<'a>(
400    &self,
401    id: impl RecordId<'a>,
402  ) -> Result<impl Stream<Item = DbEvent>, Error> {
403    // TODO: Might have to add HeaderValue::from_static("text/event-stream").
404    let response = self
405      .client
406      .fetch(
407        &format!(
408          "/{RECORD_API}/{name}/subscribe/{id}",
409          name = self.name,
410          id = id.serialized_id()
411        ),
412        Method::GET,
413        None::<&()>,
414        None,
415      )
416      .await?;
417
418    return Ok(
419      response
420        .bytes_stream()
421        .eventsource()
422        .filter_map(|event_or| async {
423          if let Ok(event) = event_or {
424            if let Ok(db_event) = serde_json::from_str::<DbEvent>(&event.data) {
425              return Some(db_event);
426            }
427          }
428          return None;
429        }),
430    );
431  }
432}
433
/// Snapshot of the client's authentication state.
#[derive(Clone, Debug)]
struct TokenState {
  /// The tokens together with their decoded auth-token claims. `None` when logged out or
  /// when the auth token failed to decode.
  state: Option<(Tokens, JwtTokenClaims)>,
  /// Headers to send with every request, derived from the tokens by `build_headers`.
  headers: HeaderMap,
}
439
440impl TokenState {
441  fn build(tokens: Option<&Tokens>) -> TokenState {
442    let headers = build_headers(tokens);
443    return TokenState {
444      state: tokens.and_then(|tokens| {
445        let Ok(jwt_token) = decode_auth_token::<JwtTokenClaims>(&tokens.auth_token) else {
446          error!("Failed to decode auth token.");
447          return None;
448        };
449        return Some((tokens.clone(), jwt_token));
450      }),
451      headers,
452    };
453  }
454}
455
/// State shared (via `Arc`) by a [`Client`], its clones, and every [`RecordApi`] it creates.
struct ClientState {
  client: ThinClient,
  /// The site string passed to `Client::new`, returned verbatim by `Client::site`.
  site: String,
  /// Current auth state. Lock guards must not be held across `.await` points (see the
  /// `is_send_test` test).
  tokens: RwLock<TokenState>,
}
461
462impl ClientState {
463  #[inline]
464  async fn fetch<T: Serialize>(
465    &self,
466    path: &str,
467    method: Method,
468    body: Option<&T>,
469    query_params: Option<&[(Cow<'static, str>, Cow<'static, str>)]>,
470  ) -> Result<reqwest::Response, Error> {
471    let (mut headers, refresh_token) = self.extract_headers_and_refresh_token_if_exp();
472    if let Some(refresh_token) = refresh_token {
473      let new_tokens = ClientState::refresh_tokens(&self.client, headers, refresh_token).await?;
474
475      headers = new_tokens.headers.clone();
476      *self.tokens.write() = new_tokens;
477    }
478
479    return Ok(
480      self
481        .client
482        .fetch(path, headers, method, body, query_params)
483        .await?
484        .error_for_status()?,
485    );
486  }
487
488  #[inline]
489  fn extract_headers_and_refresh_token_if_exp(&self) -> (HeaderMap, Option<String>) {
490    #[inline]
491    fn should_refresh(jwt: &JwtTokenClaims) -> bool {
492      return jwt.exp - 60 < now() as i64;
493    }
494
495    let tokens = self.tokens.read();
496    let headers = tokens.headers.clone();
497    return match tokens.state {
498      Some(ref state) if should_refresh(&state.1) => (headers, state.0.refresh_token.clone()),
499      _ => (headers, None),
500    };
501  }
502
503  fn extract_headers_refresh_token(&self) -> Option<(HeaderMap, String)> {
504    let tokens = self.tokens.read();
505    let state = tokens.state.as_ref()?;
506
507    if let Some(ref refresh_token) = state.0.refresh_token {
508      return Some((tokens.headers.clone(), refresh_token.clone()));
509    }
510    return None;
511  }
512
513  async fn refresh_tokens(
514    client: &ThinClient,
515    headers: HeaderMap,
516    refresh_token: String,
517  ) -> Result<TokenState, Error> {
518    #[derive(Serialize)]
519    struct RefreshRequest<'a> {
520      refresh_token: &'a str,
521    }
522
523    let response = client
524      .fetch(
525        &format!("/{AUTH_API}/refresh"),
526        headers,
527        Method::POST,
528        Some(&RefreshRequest {
529          refresh_token: &refresh_token,
530        }),
531        None,
532      )
533      .await?;
534
535    #[derive(Deserialize)]
536    struct RefreshResponse {
537      auth_token: String,
538      csrf_token: Option<String>,
539    }
540
541    let refresh_response: RefreshResponse = json(response).await?;
542    return Ok(TokenState::build(Some(&Tokens {
543      auth_token: refresh_response.auth_token,
544      refresh_token: Some(refresh_token),
545      csrf_token: refresh_response.csrf_token,
546    })));
547  }
548}
549
/// A TrailBase client.
///
/// Cheap to clone: all clones share the same state (and thus the same tokens) via an `Arc`.
#[derive(Clone)]
pub struct Client {
  state: Arc<ClientState>,
}
554
555impl Client {
556  pub fn new(site: &str, tokens: Option<Tokens>) -> Result<Client, Error> {
557    return Ok(Client {
558      state: Arc::new(ClientState {
559        client: ThinClient {
560          client: reqwest::Client::new(),
561          url: url::Url::parse(site).map_err(Error::InvalidUrl)?,
562        },
563        site: site.to_string(),
564        tokens: RwLock::new(TokenState::build(tokens.as_ref())),
565      }),
566    });
567  }
568
569  pub fn site(&self) -> String {
570    return self.state.site.clone();
571  }
572
573  pub fn tokens(&self) -> Option<Tokens> {
574    return self.state.tokens.read().state.as_ref().map(|x| x.0.clone());
575  }
576
577  pub fn user(&self) -> Option<User> {
578    if let Some(state) = &self.state.tokens.read().state {
579      return Some(User {
580        sub: state.1.sub.clone(),
581        email: state.1.email.clone(),
582      });
583    }
584    return None;
585  }
586
587  pub fn records(&self, api_name: &str) -> RecordApi {
588    return RecordApi {
589      client: self.state.clone(),
590      name: api_name.to_string(),
591    };
592  }
593
594  pub async fn refresh(&self) -> Result<(), Error> {
595    let Some((headers, refresh_token)) = self.state.extract_headers_refresh_token() else {
596      return Err(Error::MissingRefreshToken);
597    };
598
599    let new_tokens =
600      ClientState::refresh_tokens(&self.state.client, headers, refresh_token).await?;
601
602    *self.state.tokens.write() = new_tokens;
603    return Ok(());
604  }
605
606  pub async fn login(&self, email: &str, password: &str) -> Result<Tokens, Error> {
607    #[derive(Serialize)]
608    struct Credentials<'a> {
609      email: &'a str,
610      password: &'a str,
611    }
612
613    let response = self
614      .state
615      .fetch(
616        &format!("/{AUTH_API}/login"),
617        Method::POST,
618        Some(&Credentials { email, password }),
619        None,
620      )
621      .await?;
622
623    let tokens: Tokens = json(response).await?;
624    self.update_tokens(Some(&tokens));
625    return Ok(tokens);
626  }
627
628  pub async fn logout(&self) -> Result<(), Error> {
629    #[derive(Serialize)]
630    struct LogoutRequest {
631      refresh_token: String,
632    }
633
634    let response_or = match self.state.extract_headers_refresh_token() {
635      Some((_headers, refresh_token)) => {
636        self
637          .state
638          .fetch(
639            &format!("/{AUTH_API}/logout"),
640            Method::POST,
641            Some(&LogoutRequest { refresh_token }),
642            None,
643          )
644          .await
645      }
646      _ => {
647        self
648          .state
649          .fetch(
650            &format!("/{AUTH_API}/logout"),
651            Method::GET,
652            None::<&()>,
653            None,
654          )
655          .await
656      }
657    };
658
659    self.update_tokens(None);
660
661    return response_or.map(|_| ());
662  }
663
664  fn update_tokens(&self, tokens: Option<&Tokens>) -> TokenState {
665    let state = TokenState::build(tokens);
666
667    *self.state.tokens.write() = state.clone();
668    // _authChange?.call(this, state.state?.$1);
669
670    if let Some(ref s) = state.state {
671      let now = now();
672      if s.1.exp < now as i64 {
673        warn!("Token expired");
674      }
675    }
676
677    return state;
678  }
679}
680
681fn build_headers(tokens: Option<&Tokens>) -> HeaderMap {
682  let mut base = HeaderMap::with_capacity(5);
683  base.insert(
684    header::CONTENT_TYPE,
685    HeaderValue::from_static("application/json"),
686  );
687
688  if let Some(tokens) = tokens {
689    if let Ok(value) = HeaderValue::from_str(&format!("Bearer {}", tokens.auth_token)) {
690      base.insert(header::AUTHORIZATION, value);
691    } else {
692      error!("Failed to build bearer token.");
693    }
694
695    if let Some(ref refresh) = tokens.refresh_token {
696      if let Ok(value) = HeaderValue::from_str(refresh) {
697        base.insert("Refresh-Token", value);
698      } else {
699        error!("Failed to build refresh token header.");
700      }
701    }
702
703    if let Some(ref csrf) = tokens.csrf_token {
704      if let Ok(value) = HeaderValue::from_str(csrf) {
705        base.insert("CSRF-Token", value);
706      } else {
707        error!("Failed to build refresh token header.");
708      }
709    }
710  }
711
712  return base;
713}
714
/// Current wall-clock time in whole seconds since the UNIX epoch.
///
/// # Panics
/// If the system clock reports a time before the UNIX epoch.
fn now() -> u64 {
  let elapsed = std::time::UNIX_EPOCH
    .elapsed()
    .expect("Duration since epoch");
  return elapsed.as_secs();
}
721
722#[inline]
723async fn json<T: DeserializeOwned>(resp: reqwest::Response) -> Result<T, Error> {
724  let full = resp.bytes().await?;
725  return serde_json::from_slice(&full).map_err(Error::RecordSerialization);
726}
727
/// Path prefix of the auth API (without leading `/`; call sites prepend it).
const AUTH_API: &str = "api/auth/v1";
/// Path prefix of the record API (without leading `/`; call sites prepend it).
const RECORD_API: &str = "api/records/v1";
730
#[cfg(test)]
mod tests {
  use super::*;

  // Checks that the client's futures are `Send` (required by `tokio::spawn`), i.e. that no
  // lock guard is held across an `.await`.
  // NOTE(review): the `is_err` assertion assumes no TrailBase server serving this table/record
  // is listening on 127.0.0.1:4000.
  #[tokio::test]
  async fn is_send_test() {
    let client = Client::new("http://127.0.0.1:4000", None).unwrap();

    let api = client.records("simple_strict_table");

    for _ in 0..2 {
      let api = api.clone();
      tokio::spawn(async move {
        // This would not compile if locks would be held across async function calls.
        let response = api.read::<serde_json::Value>(0).await;
        assert!(response.is_err());
      })
      .await
      .unwrap();
    }
  }
}