Skip to main content

trailbase_client/
lib.rs

1//! A client library to connect to a TrailBase server via HTTP.
2//!
3//! TrailBase is a sub-millisecond, open-source application server with type-safe APIs, built-in
4//! WASM runtime, realtime, auth, and admin UI built on Rust, SQLite & Wasmtime.
5
6#![forbid(unsafe_code, clippy::unwrap_used)]
7#![allow(clippy::needless_return)]
8#![warn(clippy::await_holding_lock, clippy::inefficient_to_string)]
9
10use eventsource_stream::Eventsource;
11use futures_lite::StreamExt;
12use parking_lot::RwLock;
13use reqwest::header::{self, HeaderMap, HeaderValue};
14use reqwest::{Method, StatusCode};
15use serde::de::DeserializeOwned;
16use serde::{Deserialize, Serialize};
17use serde_repr::{Deserialize_repr, Serialize_repr};
18use std::borrow::Cow;
19use std::sync::Arc;
20use thiserror::Error;
21use tracing::*;
22
23pub use futures_lite::Stream;
24
25#[derive(Debug, Error)]
26#[non_exhaustive]
27pub enum Error {
28  #[error("HTTP status: {0}")]
29  HttpStatus(StatusCode),
30
31  #[error("RecordSerialization: {0}")]
32  RecordSerialization(serde_json::Error),
33
34  #[error("InvalidToken: {0}")]
35  InvalidToken(jsonwebtoken::errors::Error),
36
37  #[error("InvalidUrl: {0}")]
38  InvalidUrl(url::ParseError),
39
40  // NOTE: This error is leaky but comprehensively unpacking reqwest is unsustainable.
41  #[error("Reqwest: {0}")]
42  OtherReqwest(reqwest::Error),
43
44  #[cfg(feature = "ws")]
45  #[error("WebSocket: {0}")]
46  WebSocket(#[from] reqwest_websocket::Error),
47}
48
49impl From<reqwest::Error> for Error {
50  fn from(err: reqwest::Error) -> Self {
51    match err.status() {
52      Some(code) => Self::HttpStatus(code),
53      _ => Self::OtherReqwest(err),
54    }
55  }
56}
57
/// Represents the currently logged-in user.
///
/// Both fields are copied from the claims of the decoded auth token.
#[derive(Debug, Clone)]
pub struct User {
  /// The `sub` claim of the auth token.
  pub sub: String,
  /// The `email` claim of the auth token.
  pub email: String,
}
64
65/// Holds the tokens minted by the server on login.
66///
67/// It is also the exact JSON serialization format.
68#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
69pub struct Tokens {
70  pub auth_token: String,
71  pub refresh_token: Option<String>,
72  pub csrf_token: Option<String>,
73}
74
75#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
76pub struct MultiFactorAuthToken {
77  mfa_token: String,
78}
79
/// Paging parameters for list requests.
///
/// Start from [`Pagination::new`] and chain the `with_*` setters; every field starts unset.
#[derive(Clone, Debug, Default, PartialEq)]
pub struct Pagination {
  cursor: Option<String>,
  limit: Option<usize>,
  offset: Option<usize>,
}

impl Pagination {
  /// Creates a pagination with no cursor, limit or offset.
  pub fn new() -> Self {
    return Self {
      cursor: None,
      limit: None,
      offset: None,
    };
  }

  /// Sets (or, with `None`, clears) the maximum number of records.
  pub fn with_limit(self, limit: impl Into<Option<usize>>) -> Pagination {
    return Pagination {
      limit: limit.into(),
      ..self
    };
  }

  /// Sets (or, with `None`, clears) the cursor.
  pub fn with_cursor(self, cursor: impl Into<Option<String>>) -> Pagination {
    return Pagination {
      cursor: cursor.into(),
      ..self
    };
  }

  /// Sets (or, with `None`, clears) the offset.
  pub fn with_offset(self, offset: impl Into<Option<usize>>) -> Pagination {
    return Pagination {
      offset: offset.into(),
      ..self
    };
  }
}
107
108type JsonObject = serde_json::value::Map<String, serde_json::Value>;
109
110#[derive(Debug, Clone, Copy, Deserialize_repr, Serialize_repr, PartialEq)]
111#[repr(i64)]
112pub enum EventErrorStatus {
113  /// Unknown or unspecified error.
114  Unknown = 0,
115  /// Access forbidden.
116  Forbidden = 1,
117  /// Server-side event-loss, e.g. a buffer ran out of capacity. This does not account for
118  /// additional losses that may happen between the TrailBase server and the client. This
119  /// needs to be determined client-side based on event `seq` numbers.
120  Loss = 2,
121}
122
123#[derive(Debug, Clone, Deserialize, PartialEq)]
124pub enum EventPayload {
125  Update(JsonObject),
126  Insert(JsonObject),
127  Delete(JsonObject),
128  Error {
129    status: EventErrorStatus,
130    message: Option<String>,
131  },
132}
133
134#[derive(Debug, Clone, Deserialize, PartialEq)]
135pub struct ChangeEvent {
136  #[serde(flatten)]
137  pub event: Arc<EventPayload>,
138  pub seq: Option<i64>,
139}
140
141impl ChangeEvent {
142  fn from_str(msg: &str) -> Result<ChangeEvent, serde_json::Error> {
143    return serde_json::from_str::<ChangeEvent>(msg);
144  }
145}
146
147#[derive(Clone, Debug, Deserialize)]
148pub struct ListResponse<T> {
149  pub cursor: Option<String>,
150  pub total_count: Option<usize>,
151  pub records: Vec<T>,
152}
153
/// A value usable as a record identifier in request paths.
pub trait RecordId<'a> {
  /// Consumes the id and returns its textual form, borrowing where possible.
  fn serialized_id(self) -> Cow<'a, str>;
}

/// Owned strings are handed over without copying.
impl RecordId<'_> for String {
  fn serialized_id(self) -> Cow<'static, str> {
    return Cow::from(self);
  }
}

/// Borrowed `String`s are passed through as a borrowed `str`.
impl<'a> RecordId<'a> for &'a String {
  fn serialized_id(self) -> Cow<'a, str> {
    return Cow::from(self.as_str());
  }
}

/// String slices are passed through as-is.
impl<'a> RecordId<'a> for &'a str {
  fn serialized_id(self) -> Cow<'a, str> {
    return Cow::from(self);
  }
}

/// Integer ids are rendered in decimal.
impl RecordId<'_> for i64 {
  fn serialized_id(self) -> Cow<'static, str> {
    let decimal = self.to_string();
    return Cow::from(decimal);
  }
}
181
182pub trait ReadArgumentsTrait<'a> {
183  fn serialized_id(self) -> Cow<'a, str>;
184  fn expand(&self) -> Option<&Vec<&'a str>>;
185}
186
187impl<'a, T: RecordId<'a>> ReadArgumentsTrait<'a> for T {
188  fn serialized_id(self) -> Cow<'a, str> {
189    return self.serialized_id();
190  }
191
192  fn expand(&self) -> Option<&Vec<&'a str>> {
193    return None;
194  }
195}
196
197#[derive(Clone, Debug, PartialEq)]
198pub struct ReadArguments<'a, T: RecordId<'a>> {
199  id: T,
200  expand: Option<Vec<&'a str>>,
201}
202
203impl<'a, T: RecordId<'a>> ReadArguments<'a, T> {
204  pub fn new(id: T) -> Self {
205    return Self { id, expand: None };
206  }
207
208  pub fn with_expand(mut self, expand: impl AsRef<[&'a str]>) -> Self {
209    self.expand = Some(expand.as_ref().to_vec());
210    return self;
211  }
212}
213
214impl<'a, T: RecordId<'a>> ReadArgumentsTrait<'a> for ReadArguments<'a, T> {
215  fn serialized_id(self) -> Cow<'a, str> {
216    return self.id.serialized_id();
217  }
218
219  fn expand(&self) -> Option<&Vec<&'a str>> {
220    return self.expand.as_ref();
221  }
222}
223
224#[async_trait::async_trait]
225pub trait Transport {
226  async fn fetch(
227    &self,
228    path: &str,
229    headers: HeaderMap,
230    method: Method,
231    body: Option<Vec<u8>>,
232    query_params: Option<&[(Cow<'static, str>, Cow<'static, str>)]>,
233  ) -> Result<http::Response<reqwest::Body>, Error>;
234
235  #[cfg(feature = "ws")]
236  async fn upgrade_ws(
237    &self,
238    path: &str,
239    headers: HeaderMap,
240    method: Method,
241    query_params: Option<&[(Cow<'static, str>, Cow<'static, str>)]>,
242  ) -> Result<reqwest_websocket::UpgradeResponse, Error>;
243}
244
245pub struct DefaultTransport {
246  client: reqwest::Client,
247  url: url::Url,
248}
249
250impl DefaultTransport {
251  pub fn new(url: url::Url) -> Self {
252    return Self {
253      client: reqwest::Client::new(),
254      url,
255    };
256  }
257}
258
259#[async_trait::async_trait]
260impl Transport for DefaultTransport {
261  async fn fetch(
262    &self,
263    path: &str,
264    headers: HeaderMap,
265    method: Method,
266    body: Option<Vec<u8>>,
267    query_params: Option<&[(Cow<'static, str>, Cow<'static, str>)]>,
268  ) -> Result<http::Response<reqwest::Body>, Error> {
269    assert!(path.starts_with("/"));
270
271    let mut url = self.url.clone();
272    url.set_path(path);
273
274    if let Some(query_params) = query_params {
275      let mut params = url.query_pairs_mut();
276      for (key, value) in query_params {
277        params.append_pair(key, value);
278      }
279    }
280
281    let request = {
282      let mut builder = self.client.request(method, url).headers(headers);
283      if let Some(body) = body {
284        // let json = serde_json::to_string(body).map_err(Error::RecordSerialization)?;
285        // builder = builder.body(json);
286        builder = builder.body(body);
287      }
288      builder.build()?
289    };
290
291    return Ok(self.client.execute(request).await?.into());
292  }
293
294  #[cfg(feature = "ws")]
295  async fn upgrade_ws(
296    &self,
297    path: &str,
298    headers: HeaderMap,
299    method: Method,
300    query_params: Option<&[(Cow<'static, str>, Cow<'static, str>)]>,
301  ) -> Result<reqwest_websocket::UpgradeResponse, Error> {
302    use reqwest_websocket::Upgrade;
303
304    assert!(path.starts_with("/"));
305
306    let mut url = self.url.clone();
307    url.set_path(path);
308
309    if let Some(query_params) = query_params {
310      let mut params = url.query_pairs_mut();
311      for (key, value) in query_params {
312        params.append_pair(key, value);
313      }
314    }
315
316    return Ok(
317      self
318        .client
319        .request(method, url)
320        .headers(headers)
321        .upgrade()
322        .send()
323        .await?,
324    );
325  }
326}
327
328#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
329struct JwtTokenClaims {
330  sub: String,
331  iat: i64,
332  exp: i64,
333  email: String,
334  csrf_token: String,
335}
336
337fn decode_auth_token<T: DeserializeOwned + Clone>(token: &str) -> Result<T, Error> {
338  return jsonwebtoken::dangerous::insecure_decode::<T>(token)
339    .map(|data| data.claims)
340    .map_err(Error::InvalidToken);
341}
342
343#[derive(Clone)]
344pub struct RecordApi {
345  client: Arc<ClientState>,
346  name: String,
347}
348
349#[derive(Clone, Debug, Default, PartialEq)]
350pub struct ListArguments<'a> {
351  pagination: Pagination,
352  order: Option<Vec<&'a str>>,
353  filters: Option<ValueOrFilterGroup>,
354  expand: Option<Vec<&'a str>>,
355  count: bool,
356}
357
/// Comparison operators usable in a [`Filter`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum CompareOp {
  Equal,
  NotEqual,
  GreaterThanEqual,
  GreaterThan,
  LessThanEqual,
  LessThan,
  Like,
  Regexp,
  StWithin,
  StIntersects,
  StContains,
}

impl CompareOp {
  /// Returns the operator token used inside filter query-parameter keys.
  fn format(&self) -> &'static str {
    let token = match *self {
      CompareOp::Equal => "$eq",
      CompareOp::NotEqual => "$ne",
      CompareOp::GreaterThanEqual => "$gte",
      CompareOp::GreaterThan => "$gt",
      CompareOp::LessThanEqual => "$lte",
      CompareOp::LessThan => "$lt",
      CompareOp::Like => "$like",
      CompareOp::Regexp => "$re",
      CompareOp::StWithin => "@within",
      CompareOp::StIntersects => "@intersects",
      CompareOp::StContains => "@contains",
    };
    return token;
  }
}
390
391#[derive(Clone, Default, Debug, PartialEq)]
392pub struct Filter {
393  pub column: String,
394  pub op: Option<CompareOp>,
395  pub value: String,
396}
397
398impl Filter {
399  pub fn new(column: impl Into<String>, op: CompareOp, value: impl Into<String>) -> Self {
400    return Self {
401      column: column.into(),
402      op: Some(op),
403      value: value.into(),
404    };
405  }
406}
407
408impl From<Filter> for ValueOrFilterGroup {
409  fn from(value: Filter) -> Self {
410    return ValueOrFilterGroup::Filter(value);
411  }
412}
413
414#[derive(Clone, Debug, PartialEq)]
415pub enum ValueOrFilterGroup {
416  Filter(Filter),
417  And(Vec<ValueOrFilterGroup>),
418  Or(Vec<ValueOrFilterGroup>),
419}
420
421impl<F> From<F> for ValueOrFilterGroup
422where
423  F: Into<Vec<Filter>>,
424{
425  fn from(filters: F) -> Self {
426    return ValueOrFilterGroup::And(
427      filters
428        .into()
429        .into_iter()
430        .map(ValueOrFilterGroup::Filter)
431        .collect(),
432    );
433  }
434}
435
436impl<'a> ListArguments<'a> {
437  pub fn new() -> Self {
438    return ListArguments::default();
439  }
440
441  pub fn with_pagination(mut self, pagination: Pagination) -> Self {
442    self.pagination = pagination;
443    return self;
444  }
445
446  pub fn with_order(mut self, order: impl AsRef<[&'a str]>) -> Self {
447    self.order = Some(order.as_ref().to_vec());
448    return self;
449  }
450
451  pub fn with_filters(mut self, filters: impl Into<ValueOrFilterGroup>) -> Self {
452    self.filters = Some(filters.into());
453    return self;
454  }
455
456  pub fn with_expand(mut self, expand: impl AsRef<[&'a str]>) -> Self {
457    self.expand = Some(expand.as_ref().to_vec());
458    return self;
459  }
460
461  pub fn with_count(mut self, count: bool) -> Self {
462    self.count = count;
463    return self;
464  }
465}
466
467impl RecordApi {
468  pub async fn list<T: DeserializeOwned>(
469    &self,
470    args: ListArguments<'_>,
471  ) -> Result<ListResponse<T>, Error> {
472    type Param = (Cow<'static, str>, Cow<'static, str>);
473    let mut params: Vec<Param> = vec![];
474    if let Some(cursor) = args.pagination.cursor {
475      params.push((Cow::Borrowed("cursor"), Cow::Owned(cursor)));
476    }
477
478    if let Some(limit) = args.pagination.limit {
479      params.push((Cow::Borrowed("limit"), Cow::Owned(limit.to_string())));
480    }
481
482    #[inline]
483    fn to_list(slice: &[&str]) -> String {
484      return slice.join(",");
485    }
486
487    if let Some(order) = args.order
488      && !order.is_empty()
489    {
490      params.push((Cow::Borrowed("order"), Cow::Owned(to_list(&order))));
491    }
492
493    if let Some(expand) = args.expand
494      && !expand.is_empty()
495    {
496      params.push((Cow::Borrowed("expand"), Cow::Owned(to_list(&expand))));
497    }
498
499    if args.count {
500      params.push((Cow::Borrowed("count"), Cow::Borrowed("true")));
501    }
502
503    fn traverse_filters(params: &mut Vec<Param>, path: String, filter: ValueOrFilterGroup) {
504      match filter {
505        ValueOrFilterGroup::Filter(filter) => {
506          if let Some(op) = filter.op {
507            params.push((
508              Cow::Owned(format!(
509                "{path}[{col}][{op}]",
510                col = filter.column,
511                op = op.format()
512              )),
513              Cow::Owned(filter.value),
514            ));
515          } else {
516            params.push((
517              Cow::Owned(format!("{path}[{col}]", col = filter.column)),
518              Cow::Owned(filter.value),
519            ));
520          }
521        }
522        ValueOrFilterGroup::And(vec) => {
523          for (i, f) in vec.into_iter().enumerate() {
524            traverse_filters(params, format!("{path}[$and][{i}]"), f);
525          }
526        }
527        ValueOrFilterGroup::Or(vec) => {
528          for (i, f) in vec.into_iter().enumerate() {
529            traverse_filters(params, format!("{path}[$or][{i}]"), f);
530          }
531        }
532      }
533    }
534
535    if let Some(filters) = args.filters {
536      traverse_filters(&mut params, "filter".to_string(), filters);
537    }
538
539    let response = self
540      .client
541      .fetch(
542        &format!("/{RECORD_API}/{}", self.name),
543        Method::GET,
544        None,
545        Some(&params),
546        /* error_for_status= */ true,
547      )
548      .await?;
549
550    return json(response).await;
551  }
552
553  pub async fn read<'a, T: DeserializeOwned>(
554    &self,
555    args: impl ReadArgumentsTrait<'a>,
556  ) -> Result<T, Error> {
557    let expand = args
558      .expand()
559      .map(|e| vec![(Cow::Borrowed("expand"), Cow::Owned(e.join(",")))]);
560
561    let response = self
562      .client
563      .fetch(
564        &format!(
565          "/{RECORD_API}/{name}/{id}",
566          name = self.name,
567          id = args.serialized_id()
568        ),
569        Method::GET,
570        None,
571        expand.as_deref(),
572        /* error_for_status= */ true,
573      )
574      .await?;
575
576    return json(response).await;
577  }
578
579  pub async fn create<T: Serialize>(&self, record: T) -> Result<String, Error> {
580    return Ok(self.create_impl(record).await?.swap_remove(0));
581  }
582
583  pub async fn create_bulk<T: Serialize>(&self, record: &[T]) -> Result<Vec<String>, Error> {
584    return self.create_impl(record).await;
585  }
586
587  async fn create_impl<T: Serialize>(&self, record: T) -> Result<Vec<String>, Error> {
588    let response = self
589      .client
590      .fetch(
591        &format!("/{RECORD_API}/{name}", name = self.name),
592        Method::POST,
593        Some(serde_json::to_vec(&record).map_err(Error::RecordSerialization)?),
594        None,
595        /* error_for_status= */ true,
596      )
597      .await?;
598
599    #[derive(Deserialize)]
600    pub struct RecordIdResponse {
601      pub ids: Vec<String>,
602    }
603
604    return Ok(json::<RecordIdResponse>(response).await?.ids);
605  }
606
607  pub async fn update<'a, T: Serialize>(
608    &self,
609    id: impl RecordId<'a>,
610    record: T,
611  ) -> Result<(), Error> {
612    self
613      .client
614      .fetch(
615        &format!(
616          "/{RECORD_API}/{name}/{id}",
617          name = self.name,
618          id = id.serialized_id()
619        ),
620        Method::PATCH,
621        Some(serde_json::to_vec(&record).map_err(Error::RecordSerialization)?),
622        None,
623        /* error_for_status= */ true,
624      )
625      .await?;
626
627    return Ok(());
628  }
629
630  pub async fn delete<'a>(&self, id: impl RecordId<'a>) -> Result<(), Error> {
631    self
632      .client
633      .fetch(
634        &format!(
635          "/{RECORD_API}/{name}/{id}",
636          name = self.name,
637          id = id.serialized_id()
638        ),
639        Method::DELETE,
640        None,
641        None,
642        /* error_for_status= */ true,
643      )
644      .await?;
645
646    return Ok(());
647  }
648
649  pub async fn subscribe<'a, T: RecordId<'a>>(
650    &self,
651    id: T,
652  ) -> Result<impl Stream<Item = ChangeEvent> + use<T>, Error> {
653    // TODO: Might have to add HeaderValue::from_static("text/event-stream").
654    let response = self
655      .client
656      .fetch(
657        &format!(
658          "/{RECORD_API}/{name}/subscribe/{id}",
659          name = self.name,
660          id = id.serialized_id()
661        ),
662        Method::GET,
663        None,
664        None,
665        /* error_for_status= */ true,
666      )
667      .await?;
668
669    return Ok(
670      http_body_util::BodyDataStream::new(response.into_body())
671        .eventsource()
672        .filter_map(|event_or| {
673          // QUESTION: Should we instead return a `Stream<Item = Result<ChangeEvent, _>>` to allow
674          // for better error handling here.
675          if let Ok(event) = event_or {
676            return ChangeEvent::from_str(&event.data)
677              .map_err(|err| {
678                warn!("Failed to parse change event: {}", event.data);
679                return err;
680              })
681              .ok();
682          }
683          return None;
684        }),
685    );
686  }
687
688  #[cfg(feature = "ws")]
689  pub async fn subscribe_ws<'a, T: RecordId<'a>>(
690    &self,
691    id: T,
692  ) -> Result<impl Stream<Item = ChangeEvent> + use<T>, Error> {
693    let response = self
694      .client
695      .upgrade_ws(
696        &format!(
697          "/{RECORD_API}/{name}/subscribe/{id}",
698          name = self.name,
699          id = id.serialized_id()
700        ),
701        Method::GET,
702        Some(&[("ws".into(), "true".into())]),
703      )
704      .await?;
705
706    let websocket = response.into_websocket().await?;
707
708    return Ok(websocket.filter_map(|message| {
709      use reqwest_websocket::Message;
710
711      return match message {
712        Ok(Message::Text(msg)) => serde_json::from_str::<ChangeEvent>(&msg)
713          .map_err(|err| {
714            warn!("json error: {err}");
715            return err;
716          })
717          .ok(),
718        msg => {
719          warn!("unexpected msg: {msg:?}");
720          None
721        }
722      };
723    }));
724  }
725}
726
727#[derive(Clone, Debug)]
728struct TokenState {
729  state: Option<(Tokens, JwtTokenClaims)>,
730  headers: HeaderMap,
731}
732
733impl TokenState {
734  fn build(tokens: Option<&Tokens>) -> TokenState {
735    let headers = build_headers(tokens);
736    return TokenState {
737      state: tokens.and_then(|tokens| {
738        return match decode_auth_token::<JwtTokenClaims>(&tokens.auth_token) {
739          Ok(jwt_token) => Some((tokens.clone(), jwt_token)),
740          Err(err) => {
741            error!("Failed to decode auth token: {err}");
742            None
743          }
744        };
745      }),
746      headers,
747    };
748  }
749}
750
751struct ClientState {
752  transport: Box<dyn Transport + Send + Sync>,
753  base_url: url::Url,
754  tokens: RwLock<TokenState>,
755}
756
757impl ClientState {
758  #[inline]
759  async fn fetch(
760    &self,
761    path: &str,
762    method: Method,
763    body: Option<Vec<u8>>,
764    query_params: Option<&[(Cow<'static, str>, Cow<'static, str>)]>,
765    error_for_status: bool,
766  ) -> Result<http::Response<reqwest::Body>, Error> {
767    let (mut headers, refresh_token) = self.extract_headers_and_refresh_token_if_exp();
768    if let Some(refresh_token) = refresh_token {
769      let new_tokens =
770        ClientState::refresh_tokens_impl(&*self.transport, headers, refresh_token).await?;
771
772      headers = new_tokens.headers.clone();
773      *self.tokens.write() = new_tokens;
774    }
775
776    let response = self
777      .transport
778      .fetch(path, headers, method, body, query_params)
779      .await?;
780
781    if error_for_status {
782      return error_for_status_unpack(response);
783    }
784    return Ok(response);
785  }
786
787  #[cfg(feature = "ws")]
788  #[inline]
789  async fn upgrade_ws(
790    &self,
791    path: &str,
792    method: Method,
793    query_params: Option<&[(Cow<'static, str>, Cow<'static, str>)]>,
794  ) -> Result<reqwest_websocket::UpgradeResponse, Error> {
795    let (mut headers, refresh_token) = self.extract_headers_and_refresh_token_if_exp();
796    if let Some(refresh_token) = refresh_token {
797      let new_tokens =
798        ClientState::refresh_tokens_impl(&*self.transport, headers, refresh_token).await?;
799
800      headers = new_tokens.headers.clone();
801      *self.tokens.write() = new_tokens;
802    }
803
804    return self
805      .transport
806      .upgrade_ws(path, headers, method, query_params)
807      .await;
808  }
809
810  #[inline]
811  fn extract_headers_and_refresh_token_if_exp(&self) -> (HeaderMap, Option<String>) {
812    #[inline]
813    fn should_refresh(jwt: &JwtTokenClaims) -> bool {
814      return jwt.exp - 60 < now() as i64;
815    }
816
817    let tokens = self.tokens.read();
818    let headers = tokens.headers.clone();
819    return match tokens.state {
820      Some(ref state) if should_refresh(&state.1) => (headers, state.0.refresh_token.clone()),
821      _ => (headers, None),
822    };
823  }
824
825  fn extract_headers_refresh_token(&self) -> Option<(HeaderMap, String)> {
826    let tokens = self.tokens.read();
827    let state = tokens.state.as_ref()?;
828
829    if let Some(ref refresh_token) = state.0.refresh_token {
830      return Some((tokens.headers.clone(), refresh_token.clone()));
831    }
832    return None;
833  }
834
835  async fn refresh_tokens_impl(
836    transport: &(dyn Transport + Send + Sync),
837    headers: HeaderMap,
838    refresh_token: String,
839  ) -> Result<TokenState, Error> {
840    #[derive(Serialize)]
841    struct RefreshRequest<'a> {
842      refresh_token: &'a str,
843    }
844
845    // NOTE: Do not use `ClientState::fetch`, which may do token refreshing to avoid loops.
846    let response = transport
847      .fetch(
848        &format!("/{AUTH_API}/refresh"),
849        headers,
850        Method::POST,
851        Some(
852          serde_json::to_vec(&RefreshRequest {
853            refresh_token: &refresh_token,
854          })
855          .map_err(Error::RecordSerialization)?,
856        ),
857        None,
858      )
859      .await?;
860
861    #[derive(Deserialize)]
862    struct RefreshResponse {
863      auth_token: String,
864      csrf_token: Option<String>,
865    }
866
867    let refresh_response: RefreshResponse = json(response).await?;
868    return Ok(TokenState::build(Some(&Tokens {
869      auth_token: refresh_response.auth_token,
870      refresh_token: Some(refresh_token),
871      csrf_token: refresh_response.csrf_token,
872    })));
873  }
874}
875
876#[derive(Clone)]
877pub struct Client {
878  state: Arc<ClientState>,
879}
880
881#[derive(Default)]
882pub struct ClientOptions {
883  pub tokens: Option<Tokens>,
884  pub transport: Option<Box<dyn Transport + Send + Sync>>,
885}
886
887impl Client {
888  pub fn new(
889    base_url: impl TryInto<url::Url, Error = url::ParseError>,
890    opts: Option<ClientOptions>,
891  ) -> Result<Client, Error> {
892    let opts = opts.unwrap_or_default();
893    let base_url = base_url.try_into().map_err(Error::InvalidUrl)?;
894    return Ok(Client {
895      state: Arc::new(ClientState {
896        transport: opts.transport.unwrap_or_else(|| {
897          return Box::new(DefaultTransport::new(base_url.clone()));
898        }),
899        base_url,
900        tokens: RwLock::new(TokenState::build(opts.tokens.as_ref())),
901      }),
902    });
903  }
904
905  pub fn base_url(&self) -> &url::Url {
906    return &self.state.base_url;
907  }
908
909  pub fn tokens(&self) -> Option<Tokens> {
910    return self.state.tokens.read().state.as_ref().map(|x| x.0.clone());
911  }
912
913  pub fn user(&self) -> Option<User> {
914    if let Some(state) = &self.state.tokens.read().state {
915      return Some(User {
916        sub: state.1.sub.clone(),
917        email: state.1.email.clone(),
918      });
919    }
920    return None;
921  }
922
923  pub fn records(&self, api_name: &str) -> RecordApi {
924    return RecordApi {
925      client: self.state.clone(),
926      name: api_name.to_string(),
927    };
928  }
929
930  pub async fn refresh(&self) -> Result<(), Error> {
931    let Some((headers, refresh_token)) = self.state.extract_headers_refresh_token() else {
932      // Not logged in - nothing to do.
933      return Ok(());
934    };
935
936    let new_tokens =
937      ClientState::refresh_tokens_impl(&*self.state.transport, headers, refresh_token).await?;
938
939    *self.state.tokens.write() = new_tokens;
940    return Ok(());
941  }
942
943  pub async fn login(
944    &self,
945    email: &str,
946    password: &str,
947  ) -> Result<Option<MultiFactorAuthToken>, Error> {
948    #[derive(Serialize)]
949    struct Credentials<'a> {
950      email: &'a str,
951      password: &'a str,
952    }
953
954    let response = self
955      .state
956      .fetch(
957        &format!("/{AUTH_API}/login"),
958        Method::POST,
959        Some(
960          serde_json::to_vec(&Credentials { email, password })
961            .map_err(Error::RecordSerialization)?,
962        ),
963        None,
964        /* error_for_status= */ false,
965      )
966      .await?;
967
968    if response.status() == StatusCode::FORBIDDEN {
969      let mfa_token: MultiFactorAuthToken = json(response).await?;
970      return Ok(Some(mfa_token));
971    }
972
973    let tokens: Tokens = json(error_for_status_unpack(response)?).await?;
974    self.update_tokens(Some(&tokens));
975
976    return Ok(None);
977  }
978
979  pub async fn login_second(
980    &self,
981    mfa_token: &MultiFactorAuthToken,
982    totp_code: &str,
983  ) -> Result<(), Error> {
984    #[derive(Serialize)]
985    struct Credentials<'a> {
986      mfa_token: &'a str,
987      totp: &'a str,
988    }
989
990    let response = self
991      .state
992      .fetch(
993        &format!("/{AUTH_API}/login_mfa"),
994        Method::POST,
995        Some(
996          serde_json::to_vec(&Credentials {
997            mfa_token: &mfa_token.mfa_token,
998            totp: totp_code,
999          })
1000          .map_err(Error::RecordSerialization)?,
1001        ),
1002        None,
1003        /* error_for_status= */ true,
1004      )
1005      .await?;
1006
1007    let tokens: Tokens = json(error_for_status_unpack(response)?).await?;
1008    self.update_tokens(Some(&tokens));
1009
1010    return Ok(());
1011  }
1012
1013  pub async fn request_otp(&self, email: &str, redirect_uri: Option<&str>) -> Result<(), Error> {
1014    #[derive(Serialize)]
1015    struct Credentials<'a> {
1016      email: &'a str,
1017      redirect_uri: Option<&'a str>,
1018    }
1019
1020    let _response = self
1021      .state
1022      .fetch(
1023        &format!("/{AUTH_API}/otp/request"),
1024        Method::POST,
1025        Some(
1026          serde_json::to_vec(&Credentials {
1027            email,
1028            redirect_uri,
1029          })
1030          .map_err(Error::RecordSerialization)?,
1031        ),
1032        None,
1033        /* error_for_status= */ true,
1034      )
1035      .await?;
1036
1037    return Ok(());
1038  }
1039
1040  pub async fn login_otp(&self, email: &str, code: &str) -> Result<(), Error> {
1041    #[derive(Serialize)]
1042    struct Credentials<'a> {
1043      email: &'a str,
1044      code: &'a str,
1045    }
1046
1047    let response = self
1048      .state
1049      .fetch(
1050        &format!("/{AUTH_API}/otp/login"),
1051        Method::POST,
1052        Some(serde_json::to_vec(&Credentials { email, code }).map_err(Error::RecordSerialization)?),
1053        None,
1054        /* error_for_status= */ true,
1055      )
1056      .await?;
1057
1058    let tokens: Tokens = json(error_for_status_unpack(response)?).await?;
1059    self.update_tokens(Some(&tokens));
1060
1061    return Ok(());
1062  }
1063
1064  pub async fn logout(&self) -> Result<(), Error> {
1065    #[derive(Serialize)]
1066    struct LogoutRequest {
1067      refresh_token: String,
1068    }
1069
1070    let response_or = match self.state.extract_headers_refresh_token() {
1071      Some((_headers, refresh_token)) => {
1072        self
1073          .state
1074          .fetch(
1075            &format!("/{AUTH_API}/logout"),
1076            Method::POST,
1077            Some(
1078              serde_json::to_vec(&LogoutRequest { refresh_token })
1079                .map_err(Error::RecordSerialization)?,
1080            ),
1081            None,
1082            /* error_for_status= */ true,
1083          )
1084          .await
1085      }
1086      _ => {
1087        self
1088          .state
1089          .fetch(
1090            &format!("/{AUTH_API}/logout"),
1091            Method::GET,
1092            None,
1093            None,
1094            /* error_for_status= */ true,
1095          )
1096          .await
1097      }
1098    };
1099
1100    self.update_tokens(None);
1101
1102    return response_or.map(|_| ());
1103  }
1104
1105  fn update_tokens(&self, tokens: Option<&Tokens>) -> TokenState {
1106    let state = TokenState::build(tokens);
1107
1108    *self.state.tokens.write() = state.clone();
1109    // _authChange?.call(this, state.state?.$1);
1110
1111    if let Some(ref s) = state.state {
1112      let now = now();
1113      if s.1.exp < now as i64 {
1114        warn!("Token expired");
1115      }
1116    }
1117
1118    return state;
1119  }
1120}
1121
1122fn build_headers(tokens: Option<&Tokens>) -> HeaderMap {
1123  let mut base = HeaderMap::with_capacity(5);
1124  base.insert(
1125    header::CONTENT_TYPE,
1126    HeaderValue::from_static("application/json"),
1127  );
1128
1129  if let Some(tokens) = tokens {
1130    if let Ok(value) = HeaderValue::from_str(&format!("Bearer {}", tokens.auth_token)) {
1131      base.insert(header::AUTHORIZATION, value);
1132    } else {
1133      error!("Failed to build bearer token.");
1134    }
1135
1136    if let Some(ref refresh) = tokens.refresh_token {
1137      if let Ok(value) = HeaderValue::from_str(refresh) {
1138        base.insert("Refresh-Token", value);
1139      } else {
1140        error!("Failed to build refresh token header.");
1141      }
1142    }
1143
1144    if let Some(ref csrf) = tokens.csrf_token {
1145      if let Ok(value) = HeaderValue::from_str(csrf) {
1146        base.insert("CSRF-Token", value);
1147      } else {
1148        error!("Failed to build refresh token header.");
1149      }
1150    }
1151  }
1152
1153  return base;
1154}
1155
/// Returns the current time as whole seconds since the Unix epoch.
///
/// # Panics
///
/// Panics if the system clock is set to a time before the Unix epoch.
fn now() -> u64 {
  let since_epoch = std::time::SystemTime::now()
    .duration_since(std::time::UNIX_EPOCH)
    .expect("Duration since epoch");
  return since_epoch.as_secs();
}
1162
1163#[inline]
1164async fn json<T: DeserializeOwned>(resp: http::Response<reqwest::Body>) -> Result<T, Error> {
1165  let full = into_bytes(resp).await?;
1166  return serde_json::from_slice(&full).map_err(Error::RecordSerialization);
1167}
1168
1169#[inline]
1170async fn into_bytes(resp: http::Response<reqwest::Body>) -> Result<bytes::Bytes, Error> {
1171  return Ok(
1172    http_body_util::BodyExt::collect(resp.into_body())
1173      .await
1174      .map(|buf| buf.to_bytes())?,
1175  );
1176}
1177
1178fn error_for_status_unpack(
1179  resp: http::Response<reqwest::Body>,
1180) -> Result<http::Response<reqwest::Body>, Error> {
1181  let status = resp.status();
1182  if status.is_client_error() || status.is_server_error() {
1183    return Err(Error::HttpStatus(status));
1184  }
1185  return Ok(resp);
1186}
1187
/// Path prefix of the auth endpoints (without leading slash).
const AUTH_API: &str = "api/auth/v1";
/// Path prefix of the record endpoints (without leading slash).
const RECORD_API: &str = "api/records/v1";
1190
#[cfg(test)]
mod tests {
  use super::*;

  /// Compile-time check that the record API's futures are `Send`: they are handed to
  /// `tokio::spawn`.
  #[tokio::test]
  async fn is_send_test() {
    let client = Client::new("http://127.0.0.1:4000", None).unwrap();
    let api = client.records("simple_strict_table");

    for _ in 0..2 {
      let api = api.clone();
      let task = tokio::spawn(async move {
        // This would not compile if locks would be held across async function calls.
        let response = api.read::<serde_json::Value>(0).await;
        assert!(response.is_err());
      });
      task.await.unwrap();
    }
  }

  /// Parses one error event and one update event from their JSON form.
  #[test]
  fn parse_change_event_test() {
    const ERROR_EVENT: &str = r#"
        {
          "Error": {
            "status": 1,
            "message": "test"
          },
          "seq": 3
        }"#;
    const UPDATE_EVENT: &str = r#"
        {
          "Update": {
            "col0": "val0",
            "col1": 4
          }
        }"#;

    let ev0 = ChangeEvent::from_str(ERROR_EVENT).unwrap();
    assert_eq!(ev0.seq, Some(3));
    match &*ev0.event {
      EventPayload::Error { status, message } => {
        assert_eq!(*status, EventErrorStatus::Forbidden);
        assert_eq!(message.as_deref().unwrap(), "test");
      }
      _ => panic!("expected error payload, got {:?}", ev0.event),
    }

    let ev1 = ChangeEvent::from_str(UPDATE_EVENT).unwrap();
    assert_eq!(ev1.seq, None);
    match &*ev1.event {
      EventPayload::Update(obj) => {
        let expected = serde_json::json!({
            "col0": "val0",
            "col1": 4,
        });
        assert_eq!(serde_json::Value::Object(obj.clone()), expected);
      }
      _ => panic!("expected update payload, got {:?}", ev1.event),
    }
  }
}