Skip to main content

trailbase_client/
lib.rs

1//! A client library to connect to a TrailBase server via HTTP.
2//!
3//! TrailBase is a sub-millisecond, open-source application server with type-safe APIs, built-in
4//! WASM runtime, realtime, auth, and admin UI built on Rust, SQLite & Wasmtime.
5
6#![forbid(unsafe_code, clippy::unwrap_used)]
7#![allow(clippy::needless_return)]
8#![warn(clippy::await_holding_lock, clippy::inefficient_to_string)]
9
10use eventsource_stream::Eventsource;
11use futures_lite::StreamExt;
12use parking_lot::RwLock;
13use reqwest::header::{self, HeaderMap, HeaderValue};
14use reqwest::{Method, StatusCode};
15use serde::de::DeserializeOwned;
16use serde::{Deserialize, Serialize};
17use std::borrow::Cow;
18use std::sync::Arc;
19use thiserror::Error;
20use tracing::*;
21
22pub use futures_lite::Stream;
23
24#[derive(Debug, Error)]
25#[non_exhaustive]
26pub enum Error {
27  #[error("HTTP status: {0}")]
28  HttpStatus(StatusCode),
29
30  #[error("RecordSerialization: {0}")]
31  RecordSerialization(serde_json::Error),
32
33  #[error("InvalidToken: {0}")]
34  InvalidToken(jsonwebtoken::errors::Error),
35
36  #[error("InvalidUrl: {0}")]
37  InvalidUrl(url::ParseError),
38
39  // NOTE: This error is leaky but comprehensively unpacking reqwest is unsustainable.
40  #[error("Reqwest: {0}")]
41  OtherReqwest(reqwest::Error),
42
43  #[cfg(feature = "ws")]
44  #[error("WebSocket: {0}")]
45  WebSocket(#[from] reqwest_websocket::Error),
46}
47
48impl From<reqwest::Error> for Error {
49  fn from(err: reqwest::Error) -> Self {
50    match err.status() {
51      Some(code) => Self::HttpStatus(code),
52      _ => Self::OtherReqwest(err),
53    }
54  }
55}
56
/// Represents the currently logged-in user.
///
/// Built by [`Client::user`] from the claims of the current auth token.
#[derive(Clone, Debug)]
pub struct User {
  /// The `sub` (subject) claim of the auth token.
  pub sub: String,
  /// The `email` claim of the auth token.
  pub email: String,
}
63
/// Holds the tokens minted by the server on login.
///
/// It is also the exact JSON serialization format.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct Tokens {
  /// JWT sent as `Authorization: Bearer <token>`; its claims are also decoded client-side.
  pub auth_token: String,
  /// Token used to obtain a new `auth_token` from the refresh endpoint, if issued. Also sent as
  /// the `Refresh-Token` header.
  pub refresh_token: Option<String>,
  /// Value sent in the `CSRF-Token` header, if present.
  pub csrf_token: Option<String>,
}
73
/// Second-factor challenge returned by [`Client::login`] when the login endpoint answers with
/// `403 Forbidden`.
///
/// Pass it, together with a TOTP code, to [`Client::login_second`] to finish logging in.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct MultiFactorAuthToken {
  mfa_token: String,
}
78
/// Pagination settings used by [`ListArguments::with_pagination`].
///
/// Every field is optional; unset fields are left out of the request.
#[derive(Clone, Debug, Default, PartialEq)]
pub struct Pagination {
  cursor: Option<String>,
  limit: Option<usize>,
  offset: Option<usize>,
}

impl Pagination {
  /// Creates pagination settings with no cursor, limit or offset.
  pub fn new() -> Self {
    return Pagination {
      cursor: None,
      limit: None,
      offset: None,
    };
  }

  /// Sets (or clears, when given `None`) the maximum number of records per page.
  pub fn with_limit(self, limit: impl Into<Option<usize>>) -> Pagination {
    return Pagination {
      limit: limit.into(),
      ..self
    };
  }

  /// Sets (or clears, when given `None`) the cursor to continue listing from.
  pub fn with_cursor(self, cursor: impl Into<Option<String>>) -> Pagination {
    return Pagination {
      cursor: cursor.into(),
      ..self
    };
  }

  /// Sets (or clears, when given `None`) the number of records to skip.
  pub fn with_offset(self, offset: impl Into<Option<usize>>) -> Pagination {
    return Pagination {
      offset: offset.into(),
      ..self
    };
  }
}
106
/// A change notification produced by [`RecordApi::subscribe`] (or `subscribe_ws`).
///
/// Each change variant carries the affected record as JSON, which may be absent.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum DbEvent {
  /// A record was updated.
  Update(Option<serde_json::Value>),
  /// A record was inserted.
  Insert(Option<serde_json::Value>),
  /// A record was deleted.
  Delete(Option<serde_json::Value>),
  /// An error message delivered in the event stream instead of a change.
  Error(String),
}
114
/// One page of records returned by [`RecordApi::list`].
#[derive(Clone, Debug, Deserialize)]
pub struct ListResponse<T> {
  /// Cursor reported by the server, typically passed back via [`Pagination::with_cursor`] to
  /// fetch the following page.
  pub cursor: Option<String>,
  /// Total number of matching records. Presumably only populated when the request set
  /// [`ListArguments::with_count`]; confirm against the server.
  pub total_count: Option<usize>,
  /// The records of this page.
  pub records: Vec<T>,
}
121
/// Values that can identify a record in a record API request path.
pub trait RecordId<'a> {
  /// Consumes the id and returns its string form used to build the record URL.
  fn serialized_id(self) -> Cow<'a, str>;
}

impl RecordId<'_> for String {
  fn serialized_id(self) -> Cow<'static, str> {
    // `From<String>` yields `Cow::Owned` without copying.
    return self.into();
  }
}

impl<'a> RecordId<'a> for &'a String {
  fn serialized_id(self) -> Cow<'a, str> {
    return Cow::Borrowed(self.as_str());
  }
}

impl<'a> RecordId<'a> for &'a str {
  fn serialized_id(self) -> Cow<'a, str> {
    // `From<&str>` yields `Cow::Borrowed`.
    return self.into();
  }
}

impl RecordId<'_> for i64 {
  fn serialized_id(self) -> Cow<'static, str> {
    return self.to_string().into();
  }
}

/// Arguments accepted by [`RecordApi::read`]: a record id plus optional expansions.
///
/// Implemented for every [`RecordId`] (meaning "no expansion") and for [`ReadArguments`].
pub trait ReadArgumentsTrait<'a> {
  /// Consumes the arguments and returns the serialized record id.
  fn serialized_id(self) -> Cow<'a, str>;
  /// Returns the names sent as the `expand` query parameter, if any.
  fn expand(&self) -> Option<&Vec<&'a str>>;
}

impl<'a, T: RecordId<'a>> ReadArgumentsTrait<'a> for T {
  fn serialized_id(self) -> Cow<'a, str> {
    // Fully qualified so it is obvious this delegates to `RecordId` instead of recursing.
    return <T as RecordId<'a>>::serialized_id(self);
  }

  fn expand(&self) -> Option<&Vec<&'a str>> {
    return None;
  }
}

/// Read arguments pairing a record id with a list of names to expand.
#[derive(Clone, Debug, PartialEq)]
pub struct ReadArguments<'a, T: RecordId<'a>> {
  id: T,
  expand: Option<Vec<&'a str>>,
}

impl<'a, T: RecordId<'a>> ReadArguments<'a, T> {
  /// Creates arguments for `id` without any expansion.
  pub fn new(id: T) -> Self {
    return ReadArguments { id, expand: None };
  }

  /// Sets the names to expand, replacing any previously set list.
  pub fn with_expand(self, expand: impl AsRef<[&'a str]>) -> Self {
    let expand = Some(Vec::from(expand.as_ref()));
    return ReadArguments { expand, ..self };
  }
}

impl<'a, T: RecordId<'a>> ReadArgumentsTrait<'a> for ReadArguments<'a, T> {
  fn serialized_id(self) -> Cow<'a, str> {
    let ReadArguments { id, .. } = self;
    return <T as RecordId<'a>>::serialized_id(id);
  }

  fn expand(&self) -> Option<&Vec<&'a str>> {
    return self.expand.as_ref();
  }
}
191
/// Pluggable HTTP transport through which the client sends all requests.
///
/// [`DefaultTransport`] (backed by `reqwest`) is used unless a custom implementation is passed
/// via [`ClientOptions::transport`].
#[async_trait::async_trait]
pub trait Transport {
  /// Sends a request and returns the raw response.
  ///
  /// * `path` - absolute request path starting with `/`, e.g. `/api/records/v1/{name}`.
  /// * `headers` - complete request headers (content type and, when logged in, auth tokens).
  /// * `method` - HTTP method.
  /// * `body` - optional, already serialized request body.
  /// * `query_params` - optional key/value pairs appended to the URL query.
  ///
  /// Responses with error statuses should be returned rather than turned into errors: callers
  /// inspect the status themselves (e.g. login treats `403 Forbidden` as an MFA challenge).
  async fn fetch(
    &self,
    path: &str,
    headers: HeaderMap,
    method: Method,
    body: Option<Vec<u8>>,
    query_params: Option<&[(Cow<'static, str>, Cow<'static, str>)]>,
  ) -> Result<http::Response<reqwest::Body>, Error>;

  /// Sends a WebSocket upgrade request to `path` (only with the `ws` feature).
  #[cfg(feature = "ws")]
  async fn upgrade_ws(
    &self,
    path: &str,
    headers: HeaderMap,
    method: Method,
    query_params: Option<&[(Cow<'static, str>, Cow<'static, str>)]>,
  ) -> Result<reqwest_websocket::UpgradeResponse, Error>;
}
212
213pub struct DefaultTransport {
214  client: reqwest::Client,
215  url: url::Url,
216}
217
218impl DefaultTransport {
219  pub fn new(url: url::Url) -> Self {
220    return Self {
221      client: reqwest::Client::new(),
222      url,
223    };
224  }
225}
226
227#[async_trait::async_trait]
228impl Transport for DefaultTransport {
229  async fn fetch(
230    &self,
231    path: &str,
232    headers: HeaderMap,
233    method: Method,
234    body: Option<Vec<u8>>,
235    query_params: Option<&[(Cow<'static, str>, Cow<'static, str>)]>,
236  ) -> Result<http::Response<reqwest::Body>, Error> {
237    assert!(path.starts_with("/"));
238
239    let mut url = self.url.clone();
240    url.set_path(path);
241
242    if let Some(query_params) = query_params {
243      let mut params = url.query_pairs_mut();
244      for (key, value) in query_params {
245        params.append_pair(key, value);
246      }
247    }
248
249    let request = {
250      let mut builder = self.client.request(method, url).headers(headers);
251      if let Some(body) = body {
252        // let json = serde_json::to_string(body).map_err(Error::RecordSerialization)?;
253        // builder = builder.body(json);
254        builder = builder.body(body);
255      }
256      builder.build()?
257    };
258
259    return Ok(self.client.execute(request).await?.into());
260  }
261
262  #[cfg(feature = "ws")]
263  async fn upgrade_ws(
264    &self,
265    path: &str,
266    headers: HeaderMap,
267    method: Method,
268    query_params: Option<&[(Cow<'static, str>, Cow<'static, str>)]>,
269  ) -> Result<reqwest_websocket::UpgradeResponse, Error> {
270    use reqwest_websocket::Upgrade;
271
272    assert!(path.starts_with("/"));
273
274    let mut url = self.url.clone();
275    url.set_path(path);
276
277    if let Some(query_params) = query_params {
278      let mut params = url.query_pairs_mut();
279      for (key, value) in query_params {
280        params.append_pair(key, value);
281      }
282    }
283
284    return Ok(
285      self
286        .client
287        .request(method, url)
288        .headers(headers)
289        .upgrade()
290        .send()
291        .await?,
292    );
293  }
294}
295
/// Claims decoded (without signature verification) from the auth token.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
struct JwtTokenClaims {
  /// Subject; exposed as [`User::sub`].
  sub: String,
  /// Issued-at timestamp, presumably seconds since the Unix epoch (not read by this client).
  iat: i64,
  /// Expiry in seconds since the Unix epoch; compared against `now()` to decide when to
  /// refresh.
  exp: i64,
  /// Exposed as [`User::email`].
  email: String,
  csrf_token: String,
}
304
305fn decode_auth_token<T: DeserializeOwned + Clone>(token: &str) -> Result<T, Error> {
306  return jsonwebtoken::dangerous::insecure_decode::<T>(token)
307    .map(|data| data.claims)
308    .map_err(Error::InvalidToken);
309}
310
/// Handle to a single record API, obtained via [`Client::records`].
///
/// Cheap to clone: all handles share the client's state through an [`Arc`].
#[derive(Clone)]
pub struct RecordApi {
  client: Arc<ClientState>,
  /// Name of the record API, used as a path segment after the record API prefix.
  name: String,
}
316
/// Arguments for [`RecordApi::list`], assembled with the `with_*` builder methods.
#[derive(Clone, Debug, Default, PartialEq)]
pub struct ListArguments<'a> {
  /// Cursor/limit/offset settings.
  pagination: Pagination,
  /// Entries sent as the comma-separated `order` query parameter.
  order: Option<Vec<&'a str>>,
  /// Filter tree, flattened into `filter[...]` query parameters.
  filters: Option<ValueOrFilterGroup>,
  /// Names sent as the comma-separated `expand` query parameter.
  expand: Option<Vec<&'a str>>,
  /// When `true`, `count=true` is sent to request a total count.
  count: bool,
}
325
/// Comparison operator applied by a [`Filter`].
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum CompareOp {
  Equal,
  NotEqual,
  GreaterThanEqual,
  GreaterThan,
  LessThanEqual,
  LessThan,
  Like,
  Regexp,
  StWithin,
  StIntersects,
  StContains,
}

impl CompareOp {
  /// Returns the operator token used inside `filter[...]` query parameter keys.
  fn format(&self) -> &'static str {
    let token = match *self {
      // Equality and ordering.
      CompareOp::Equal => "$eq",
      CompareOp::NotEqual => "$ne",
      CompareOp::GreaterThanEqual => "$gte",
      CompareOp::GreaterThan => "$gt",
      CompareOp::LessThanEqual => "$lte",
      CompareOp::LessThan => "$lt",
      // Pattern matching.
      CompareOp::Like => "$like",
      CompareOp::Regexp => "$re",
      // Spatial predicates (presumably evaluated as `ST_*` functions server-side).
      CompareOp::StWithin => "@within",
      CompareOp::StIntersects => "@intersects",
      CompareOp::StContains => "@contains",
    };
    return token;
  }
}

/// A single `column <op> value` condition.
///
/// With `op == None` the condition is sent without an operator segment.
#[derive(Clone, Default, Debug, PartialEq)]
pub struct Filter {
  pub column: String,
  pub op: Option<CompareOp>,
  pub value: String,
}

impl Filter {
  /// Creates a filter comparing `column` against `value` using `op`.
  pub fn new(column: impl Into<String>, op: CompareOp, value: impl Into<String>) -> Self {
    let column = column.into();
    let value = value.into();
    return Filter {
      column,
      op: Some(op),
      value,
    };
  }
}

impl From<Filter> for ValueOrFilterGroup {
  fn from(filter: Filter) -> Self {
    return Self::Filter(filter);
  }
}

/// A filter tree: a single condition, or a conjunction/disjunction of sub-trees.
#[derive(Clone, Debug, PartialEq)]
pub enum ValueOrFilterGroup {
  Filter(Filter),
  And(Vec<ValueOrFilterGroup>),
  Or(Vec<ValueOrFilterGroup>),
}

impl<F> From<F> for ValueOrFilterGroup
where
  F: Into<Vec<Filter>>,
{
  /// Wraps a list of filters into an `And` group (sent as `$and`).
  fn from(filters: F) -> Self {
    let filters: Vec<Filter> = filters.into();
    let mut group = Vec::with_capacity(filters.len());
    for filter in filters {
      group.push(Self::Filter(filter));
    }
    return Self::And(group);
  }
}
403
404impl<'a> ListArguments<'a> {
405  pub fn new() -> Self {
406    return ListArguments::default();
407  }
408
409  pub fn with_pagination(mut self, pagination: Pagination) -> Self {
410    self.pagination = pagination;
411    return self;
412  }
413
414  pub fn with_order(mut self, order: impl AsRef<[&'a str]>) -> Self {
415    self.order = Some(order.as_ref().to_vec());
416    return self;
417  }
418
419  pub fn with_filters(mut self, filters: impl Into<ValueOrFilterGroup>) -> Self {
420    self.filters = Some(filters.into());
421    return self;
422  }
423
424  pub fn with_expand(mut self, expand: impl AsRef<[&'a str]>) -> Self {
425    self.expand = Some(expand.as_ref().to_vec());
426    return self;
427  }
428
429  pub fn with_count(mut self, count: bool) -> Self {
430    self.count = count;
431    return self;
432  }
433}
434
435impl RecordApi {
436  pub async fn list<T: DeserializeOwned>(
437    &self,
438    args: ListArguments<'_>,
439  ) -> Result<ListResponse<T>, Error> {
440    type Param = (Cow<'static, str>, Cow<'static, str>);
441    let mut params: Vec<Param> = vec![];
442    if let Some(cursor) = args.pagination.cursor {
443      params.push((Cow::Borrowed("cursor"), Cow::Owned(cursor)));
444    }
445
446    if let Some(limit) = args.pagination.limit {
447      params.push((Cow::Borrowed("limit"), Cow::Owned(limit.to_string())));
448    }
449
450    #[inline]
451    fn to_list(slice: &[&str]) -> String {
452      return slice.join(",");
453    }
454
455    if let Some(order) = args.order
456      && !order.is_empty()
457    {
458      params.push((Cow::Borrowed("order"), Cow::Owned(to_list(&order))));
459    }
460
461    if let Some(expand) = args.expand
462      && !expand.is_empty()
463    {
464      params.push((Cow::Borrowed("expand"), Cow::Owned(to_list(&expand))));
465    }
466
467    if args.count {
468      params.push((Cow::Borrowed("count"), Cow::Borrowed("true")));
469    }
470
471    fn traverse_filters(params: &mut Vec<Param>, path: String, filter: ValueOrFilterGroup) {
472      match filter {
473        ValueOrFilterGroup::Filter(filter) => {
474          if let Some(op) = filter.op {
475            params.push((
476              Cow::Owned(format!(
477                "{path}[{col}][{op}]",
478                col = filter.column,
479                op = op.format()
480              )),
481              Cow::Owned(filter.value),
482            ));
483          } else {
484            params.push((
485              Cow::Owned(format!("{path}[{col}]", col = filter.column)),
486              Cow::Owned(filter.value),
487            ));
488          }
489        }
490        ValueOrFilterGroup::And(vec) => {
491          for (i, f) in vec.into_iter().enumerate() {
492            traverse_filters(params, format!("{path}[$and][{i}]"), f);
493          }
494        }
495        ValueOrFilterGroup::Or(vec) => {
496          for (i, f) in vec.into_iter().enumerate() {
497            traverse_filters(params, format!("{path}[$or][{i}]"), f);
498          }
499        }
500      }
501    }
502
503    if let Some(filters) = args.filters {
504      traverse_filters(&mut params, "filter".to_string(), filters);
505    }
506
507    let response = self
508      .client
509      .fetch(
510        &format!("/{RECORD_API}/{}", self.name),
511        Method::GET,
512        None,
513        Some(&params),
514        /* error_for_status= */ true,
515      )
516      .await?;
517
518    return json(response).await;
519  }
520
521  pub async fn read<'a, T: DeserializeOwned>(
522    &self,
523    args: impl ReadArgumentsTrait<'a>,
524  ) -> Result<T, Error> {
525    let expand = args
526      .expand()
527      .map(|e| vec![(Cow::Borrowed("expand"), Cow::Owned(e.join(",")))]);
528
529    let response = self
530      .client
531      .fetch(
532        &format!(
533          "/{RECORD_API}/{name}/{id}",
534          name = self.name,
535          id = args.serialized_id()
536        ),
537        Method::GET,
538        None,
539        expand.as_deref(),
540        /* error_for_status= */ true,
541      )
542      .await?;
543
544    return json(response).await;
545  }
546
547  pub async fn create<T: Serialize>(&self, record: T) -> Result<String, Error> {
548    return Ok(self.create_impl(record).await?.swap_remove(0));
549  }
550
551  pub async fn create_bulk<T: Serialize>(&self, record: &[T]) -> Result<Vec<String>, Error> {
552    return self.create_impl(record).await;
553  }
554
555  async fn create_impl<T: Serialize>(&self, record: T) -> Result<Vec<String>, Error> {
556    let response = self
557      .client
558      .fetch(
559        &format!("/{RECORD_API}/{name}", name = self.name),
560        Method::POST,
561        Some(serde_json::to_vec(&record).map_err(Error::RecordSerialization)?),
562        None,
563        /* error_for_status= */ true,
564      )
565      .await?;
566
567    #[derive(Deserialize)]
568    pub struct RecordIdResponse {
569      pub ids: Vec<String>,
570    }
571
572    return Ok(json::<RecordIdResponse>(response).await?.ids);
573  }
574
575  pub async fn update<'a, T: Serialize>(
576    &self,
577    id: impl RecordId<'a>,
578    record: T,
579  ) -> Result<(), Error> {
580    self
581      .client
582      .fetch(
583        &format!(
584          "/{RECORD_API}/{name}/{id}",
585          name = self.name,
586          id = id.serialized_id()
587        ),
588        Method::PATCH,
589        Some(serde_json::to_vec(&record).map_err(Error::RecordSerialization)?),
590        None,
591        /* error_for_status= */ true,
592      )
593      .await?;
594
595    return Ok(());
596  }
597
598  pub async fn delete<'a>(&self, id: impl RecordId<'a>) -> Result<(), Error> {
599    self
600      .client
601      .fetch(
602        &format!(
603          "/{RECORD_API}/{name}/{id}",
604          name = self.name,
605          id = id.serialized_id()
606        ),
607        Method::DELETE,
608        None,
609        None,
610        /* error_for_status= */ true,
611      )
612      .await?;
613
614    return Ok(());
615  }
616
617  pub async fn subscribe<'a, T: RecordId<'a>>(
618    &self,
619    id: T,
620  ) -> Result<impl Stream<Item = DbEvent> + use<T>, Error> {
621    // TODO: Might have to add HeaderValue::from_static("text/event-stream").
622    let response = self
623      .client
624      .fetch(
625        &format!(
626          "/{RECORD_API}/{name}/subscribe/{id}",
627          name = self.name,
628          id = id.serialized_id()
629        ),
630        Method::GET,
631        None,
632        None,
633        /* error_for_status= */ true,
634      )
635      .await?;
636
637    return Ok(
638      http_body_util::BodyDataStream::new(response.into_body())
639        .eventsource()
640        .filter_map(|event_or| {
641          // QUESTION: Should we instead return a `Stream<Item = Result<DbEvent, _>>` to allow for
642          // better error handling here.
643          if let Ok(event) = event_or {
644            return serde_json::from_str::<DbEvent>(&event.data).ok();
645          }
646          return None;
647        }),
648    );
649  }
650
651  #[cfg(feature = "ws")]
652  pub async fn subscribe_ws<'a, T: RecordId<'a>>(
653    &self,
654    id: T,
655  ) -> Result<impl Stream<Item = DbEvent> + use<T>, Error> {
656    let response = self
657      .client
658      .upgrade_ws(
659        &format!(
660          "/{RECORD_API}/{name}/subscribe/{id}",
661          name = self.name,
662          id = id.serialized_id()
663        ),
664        Method::GET,
665        Some(&[("ws".into(), "true".into())]),
666      )
667      .await?;
668
669    let websocket = response.into_websocket().await?;
670
671    return Ok(websocket.filter_map(|message| {
672      use reqwest_websocket::Message;
673
674      return match message {
675        Ok(Message::Text(msg)) => serde_json::from_str::<DbEvent>(&msg)
676          .map_err(|err| {
677            warn!("json error: {err}");
678            return err;
679          })
680          .ok(),
681        msg => {
682          warn!("unexpected msg: {msg:?}");
683          None
684        }
685      };
686    }));
687  }
688}
689
690#[derive(Clone, Debug)]
691struct TokenState {
692  state: Option<(Tokens, JwtTokenClaims)>,
693  headers: HeaderMap,
694}
695
696impl TokenState {
697  fn build(tokens: Option<&Tokens>) -> TokenState {
698    let headers = build_headers(tokens);
699    return TokenState {
700      state: tokens.and_then(|tokens| {
701        return match decode_auth_token::<JwtTokenClaims>(&tokens.auth_token) {
702          Ok(jwt_token) => Some((tokens.clone(), jwt_token)),
703          Err(err) => {
704            error!("Failed to decode auth token: {err}");
705            None
706          }
707        };
708      }),
709      headers,
710    };
711  }
712}
713
/// State shared by [`Client`] and every [`RecordApi`] handle.
struct ClientState {
  /// Transport used for all HTTP (and WebSocket) requests.
  transport: Box<dyn Transport + Send + Sync>,
  /// URL the client was created with; returned by [`Client::base_url`].
  base_url: url::Url,
  /// Current tokens and derived headers. This is a `parking_lot` lock: never hold a guard
  /// across an `.await`.
  tokens: RwLock<TokenState>,
}
719
720impl ClientState {
721  #[inline]
722  async fn fetch(
723    &self,
724    path: &str,
725    method: Method,
726    body: Option<Vec<u8>>,
727    query_params: Option<&[(Cow<'static, str>, Cow<'static, str>)]>,
728    error_for_status: bool,
729  ) -> Result<http::Response<reqwest::Body>, Error> {
730    let (mut headers, refresh_token) = self.extract_headers_and_refresh_token_if_exp();
731    if let Some(refresh_token) = refresh_token {
732      let new_tokens =
733        ClientState::refresh_tokens_impl(&*self.transport, headers, refresh_token).await?;
734
735      headers = new_tokens.headers.clone();
736      *self.tokens.write() = new_tokens;
737    }
738
739    let response = self
740      .transport
741      .fetch(path, headers, method, body, query_params)
742      .await?;
743
744    if error_for_status {
745      return error_for_status_unpack(response);
746    }
747    return Ok(response);
748  }
749
750  #[cfg(feature = "ws")]
751  #[inline]
752  async fn upgrade_ws(
753    &self,
754    path: &str,
755    method: Method,
756    query_params: Option<&[(Cow<'static, str>, Cow<'static, str>)]>,
757  ) -> Result<reqwest_websocket::UpgradeResponse, Error> {
758    let (mut headers, refresh_token) = self.extract_headers_and_refresh_token_if_exp();
759    if let Some(refresh_token) = refresh_token {
760      let new_tokens =
761        ClientState::refresh_tokens_impl(&*self.transport, headers, refresh_token).await?;
762
763      headers = new_tokens.headers.clone();
764      *self.tokens.write() = new_tokens;
765    }
766
767    return self
768      .client
769      .upgrade_ws(path, headers, method, query_params)
770      .await;
771  }
772
773  #[inline]
774  fn extract_headers_and_refresh_token_if_exp(&self) -> (HeaderMap, Option<String>) {
775    #[inline]
776    fn should_refresh(jwt: &JwtTokenClaims) -> bool {
777      return jwt.exp - 60 < now() as i64;
778    }
779
780    let tokens = self.tokens.read();
781    let headers = tokens.headers.clone();
782    return match tokens.state {
783      Some(ref state) if should_refresh(&state.1) => (headers, state.0.refresh_token.clone()),
784      _ => (headers, None),
785    };
786  }
787
788  fn extract_headers_refresh_token(&self) -> Option<(HeaderMap, String)> {
789    let tokens = self.tokens.read();
790    let state = tokens.state.as_ref()?;
791
792    if let Some(ref refresh_token) = state.0.refresh_token {
793      return Some((tokens.headers.clone(), refresh_token.clone()));
794    }
795    return None;
796  }
797
798  async fn refresh_tokens_impl(
799    transport: &(dyn Transport + Send + Sync),
800    headers: HeaderMap,
801    refresh_token: String,
802  ) -> Result<TokenState, Error> {
803    #[derive(Serialize)]
804    struct RefreshRequest<'a> {
805      refresh_token: &'a str,
806    }
807
808    // NOTE: Do not use `ClientState::fetch`, which may do token refreshing to avoid loops.
809    let response = transport
810      .fetch(
811        &format!("/{AUTH_API}/refresh"),
812        headers,
813        Method::POST,
814        Some(
815          serde_json::to_vec(&RefreshRequest {
816            refresh_token: &refresh_token,
817          })
818          .map_err(Error::RecordSerialization)?,
819        ),
820        None,
821      )
822      .await?;
823
824    #[derive(Deserialize)]
825    struct RefreshResponse {
826      auth_token: String,
827      csrf_token: Option<String>,
828    }
829
830    let refresh_response: RefreshResponse = json(response).await?;
831    return Ok(TokenState::build(Some(&Tokens {
832      auth_token: refresh_response.auth_token,
833      refresh_token: Some(refresh_token),
834      csrf_token: refresh_response.csrf_token,
835    })));
836  }
837}
838
/// Client for a TrailBase server: authentication plus access to record APIs.
///
/// Cheap to clone; clones share tokens and transport.
#[derive(Clone)]
pub struct Client {
  state: Arc<ClientState>,
}
843
/// Optional settings for [`Client::new`].
#[derive(Default)]
pub struct ClientOptions {
  /// Initial tokens, e.g. ones previously obtained from [`Client::tokens`].
  pub tokens: Option<Tokens>,
  /// Custom transport; when `None`, a [`DefaultTransport`] for the base URL is used.
  pub transport: Option<Box<dyn Transport + Send + Sync>>,
}
849
850impl Client {
851  pub fn new(
852    base_url: impl TryInto<url::Url, Error = url::ParseError>,
853    opts: Option<ClientOptions>,
854  ) -> Result<Client, Error> {
855    let opts = opts.unwrap_or_default();
856    let base_url = base_url.try_into().map_err(Error::InvalidUrl)?;
857    return Ok(Client {
858      state: Arc::new(ClientState {
859        transport: opts.transport.unwrap_or_else(|| {
860          return Box::new(DefaultTransport::new(base_url.clone()));
861        }),
862        base_url,
863        tokens: RwLock::new(TokenState::build(opts.tokens.as_ref())),
864      }),
865    });
866  }
867
868  pub fn base_url(&self) -> &url::Url {
869    return &self.state.base_url;
870  }
871
872  pub fn tokens(&self) -> Option<Tokens> {
873    return self.state.tokens.read().state.as_ref().map(|x| x.0.clone());
874  }
875
876  pub fn user(&self) -> Option<User> {
877    if let Some(state) = &self.state.tokens.read().state {
878      return Some(User {
879        sub: state.1.sub.clone(),
880        email: state.1.email.clone(),
881      });
882    }
883    return None;
884  }
885
886  pub fn records(&self, api_name: &str) -> RecordApi {
887    return RecordApi {
888      client: self.state.clone(),
889      name: api_name.to_string(),
890    };
891  }
892
893  pub async fn refresh(&self) -> Result<(), Error> {
894    let Some((headers, refresh_token)) = self.state.extract_headers_refresh_token() else {
895      // Not logged in - nothing to do.
896      return Ok(());
897    };
898
899    let new_tokens =
900      ClientState::refresh_tokens_impl(&*self.state.transport, headers, refresh_token).await?;
901
902    *self.state.tokens.write() = new_tokens;
903    return Ok(());
904  }
905
906  pub async fn login(
907    &self,
908    email: &str,
909    password: &str,
910  ) -> Result<Option<MultiFactorAuthToken>, Error> {
911    #[derive(Serialize)]
912    struct Credentials<'a> {
913      email: &'a str,
914      password: &'a str,
915    }
916
917    let response = self
918      .state
919      .fetch(
920        &format!("/{AUTH_API}/login"),
921        Method::POST,
922        Some(
923          serde_json::to_vec(&Credentials { email, password })
924            .map_err(Error::RecordSerialization)?,
925        ),
926        None,
927        /* error_for_status= */ false,
928      )
929      .await?;
930
931    if response.status() == StatusCode::FORBIDDEN {
932      let mfa_token: MultiFactorAuthToken = json(response).await?;
933      return Ok(Some(mfa_token));
934    }
935
936    let tokens: Tokens = json(error_for_status_unpack(response)?).await?;
937    self.update_tokens(Some(&tokens));
938
939    return Ok(None);
940  }
941
942  pub async fn login_second(
943    &self,
944    mfa_token: &MultiFactorAuthToken,
945    totp_code: &str,
946  ) -> Result<(), Error> {
947    #[derive(Serialize)]
948    struct Credentials<'a> {
949      mfa_token: &'a str,
950      totp: &'a str,
951    }
952
953    let response = self
954      .state
955      .fetch(
956        &format!("/{AUTH_API}/login_mfa"),
957        Method::POST,
958        Some(
959          serde_json::to_vec(&Credentials {
960            mfa_token: &mfa_token.mfa_token,
961            totp: totp_code,
962          })
963          .map_err(Error::RecordSerialization)?,
964        ),
965        None,
966        /* error_for_status= */ true,
967      )
968      .await?;
969
970    let tokens: Tokens = json(error_for_status_unpack(response)?).await?;
971    self.update_tokens(Some(&tokens));
972
973    return Ok(());
974  }
975
976  pub async fn request_otp(&self, email: &str, redirect_uri: Option<&str>) -> Result<(), Error> {
977    #[derive(Serialize)]
978    struct Credentials<'a> {
979      email: &'a str,
980      redirect_uri: Option<&'a str>,
981    }
982
983    let _response = self
984      .state
985      .fetch(
986        &format!("/{AUTH_API}/otp/request"),
987        Method::POST,
988        Some(
989          serde_json::to_vec(&Credentials {
990            email,
991            redirect_uri,
992          })
993          .map_err(Error::RecordSerialization)?,
994        ),
995        None,
996        /* error_for_status= */ true,
997      )
998      .await?;
999
1000    return Ok(());
1001  }
1002
1003  pub async fn login_otp(&self, email: &str, code: &str) -> Result<(), Error> {
1004    #[derive(Serialize)]
1005    struct Credentials<'a> {
1006      email: &'a str,
1007      code: &'a str,
1008    }
1009
1010    let response = self
1011      .state
1012      .fetch(
1013        &format!("/{AUTH_API}/otp/login"),
1014        Method::POST,
1015        Some(serde_json::to_vec(&Credentials { email, code }).map_err(Error::RecordSerialization)?),
1016        None,
1017        /* error_for_status= */ true,
1018      )
1019      .await?;
1020
1021    let tokens: Tokens = json(error_for_status_unpack(response)?).await?;
1022    self.update_tokens(Some(&tokens));
1023
1024    return Ok(());
1025  }
1026
1027  pub async fn logout(&self) -> Result<(), Error> {
1028    #[derive(Serialize)]
1029    struct LogoutRequest {
1030      refresh_token: String,
1031    }
1032
1033    let response_or = match self.state.extract_headers_refresh_token() {
1034      Some((_headers, refresh_token)) => {
1035        self
1036          .state
1037          .fetch(
1038            &format!("/{AUTH_API}/logout"),
1039            Method::POST,
1040            Some(
1041              serde_json::to_vec(&LogoutRequest { refresh_token })
1042                .map_err(Error::RecordSerialization)?,
1043            ),
1044            None,
1045            /* error_for_status= */ true,
1046          )
1047          .await
1048      }
1049      _ => {
1050        self
1051          .state
1052          .fetch(
1053            &format!("/{AUTH_API}/logout"),
1054            Method::GET,
1055            None,
1056            None,
1057            /* error_for_status= */ true,
1058          )
1059          .await
1060      }
1061    };
1062
1063    self.update_tokens(None);
1064
1065    return response_or.map(|_| ());
1066  }
1067
1068  fn update_tokens(&self, tokens: Option<&Tokens>) -> TokenState {
1069    let state = TokenState::build(tokens);
1070
1071    *self.state.tokens.write() = state.clone();
1072    // _authChange?.call(this, state.state?.$1);
1073
1074    if let Some(ref s) = state.state {
1075      let now = now();
1076      if s.1.exp < now as i64 {
1077        warn!("Token expired");
1078      }
1079    }
1080
1081    return state;
1082  }
1083}
1084
1085fn build_headers(tokens: Option<&Tokens>) -> HeaderMap {
1086  let mut base = HeaderMap::with_capacity(5);
1087  base.insert(
1088    header::CONTENT_TYPE,
1089    HeaderValue::from_static("application/json"),
1090  );
1091
1092  if let Some(tokens) = tokens {
1093    if let Ok(value) = HeaderValue::from_str(&format!("Bearer {}", tokens.auth_token)) {
1094      base.insert(header::AUTHORIZATION, value);
1095    } else {
1096      error!("Failed to build bearer token.");
1097    }
1098
1099    if let Some(ref refresh) = tokens.refresh_token {
1100      if let Ok(value) = HeaderValue::from_str(refresh) {
1101        base.insert("Refresh-Token", value);
1102      } else {
1103        error!("Failed to build refresh token header.");
1104      }
1105    }
1106
1107    if let Some(ref csrf) = tokens.csrf_token {
1108      if let Ok(value) = HeaderValue::from_str(csrf) {
1109        base.insert("CSRF-Token", value);
1110      } else {
1111        error!("Failed to build refresh token header.");
1112      }
1113    }
1114  }
1115
1116  return base;
1117}
1118
/// Returns the current wall-clock time in whole seconds since the Unix epoch.
///
/// # Panics
///
/// Panics if the system clock reports a time before the Unix epoch.
fn now() -> u64 {
  use std::time::{SystemTime, UNIX_EPOCH};

  let since_epoch = SystemTime::now()
    .duration_since(UNIX_EPOCH)
    .expect("Duration since epoch");
  return since_epoch.as_secs();
}
1125
1126#[inline]
1127async fn json<T: DeserializeOwned>(resp: http::Response<reqwest::Body>) -> Result<T, Error> {
1128  let full = into_bytes(resp).await?;
1129  return serde_json::from_slice(&full).map_err(Error::RecordSerialization);
1130}
1131
1132#[inline]
1133async fn into_bytes(resp: http::Response<reqwest::Body>) -> Result<bytes::Bytes, Error> {
1134  return Ok(
1135    http_body_util::BodyExt::collect(resp.into_body())
1136      .await
1137      .map(|buf| buf.to_bytes())?,
1138  );
1139}
1140
1141fn error_for_status_unpack(
1142  resp: http::Response<reqwest::Body>,
1143) -> Result<http::Response<reqwest::Body>, Error> {
1144  let status = resp.status();
1145  if status.is_client_error() || status.is_server_error() {
1146    return Err(Error::HttpStatus(status));
1147  }
1148  return Ok(resp);
1149}
1150
/// Path prefix (without leading slash) of the authentication endpoints.
const AUTH_API: &str = "api/auth/v1";
/// Path prefix (without leading slash) of the record API endpoints.
const RECORD_API: &str = "api/records/v1";
1153
#[cfg(test)]
mod tests {
  use super::*;

  /// Compile-time check that `RecordApi::read` futures are `Send`.
  ///
  /// `tokio::spawn` requires `Send` futures, which in turn requires that no lock guard is held
  /// across an `.await`. No server is started here, so the read itself is expected to fail.
  #[tokio::test]
  async fn is_send_test() {
    let client = Client::new("http://127.0.0.1:4000", None).unwrap();
    let records = client.records("simple_strict_table");

    for _attempt in 0..2 {
      let records = records.clone();
      let handle = tokio::spawn(async move {
        let result = records.read::<serde_json::Value>(0).await;
        assert!(result.is_err());
      });
      handle.await.unwrap();
    }
  }
}