Skip to main content

trailbase_client/
lib.rs

1//! A client library to connect to a TrailBase server via HTTP.
2//!
3//! TrailBase is a sub-millisecond, open-source application server with type-safe APIs, built-in
4//! WASM runtime, realtime, auth, and admin UI built on Rust, SQLite & Wasmtime.
5
6#![forbid(unsafe_code, clippy::unwrap_used)]
7#![allow(clippy::needless_return)]
8#![warn(clippy::await_holding_lock, clippy::inefficient_to_string)]
9
10use eventsource_stream::Eventsource;
11use futures_lite::StreamExt;
12use parking_lot::RwLock;
13use reqwest::header::{self, HeaderMap, HeaderValue};
14use reqwest::{Method, StatusCode};
15use serde::de::DeserializeOwned;
16use serde::{Deserialize, Serialize};
17use serde_repr::{Deserialize_repr, Serialize_repr};
18use std::borrow::Cow;
19use std::sync::Arc;
20use thiserror::Error;
21use tracing::*;
22
23pub use futures_lite::Stream;
24
25#[derive(Debug, Error)]
26#[non_exhaustive]
27pub enum Error {
28  #[error("HTTP status: {0}")]
29  HttpStatus(StatusCode),
30
31  #[error("RecordSerialization: {0}")]
32  RecordSerialization(serde_json::Error),
33
34  #[error("InvalidToken: {0}")]
35  InvalidToken(jsonwebtoken::errors::Error),
36
37  #[error("InvalidUrl: {0}")]
38  InvalidUrl(url::ParseError),
39
40  // NOTE: This error is leaky but comprehensively unpacking reqwest is unsustainable.
41  #[error("Reqwest: {0}")]
42  OtherReqwest(reqwest::Error),
43
44  #[cfg(feature = "ws")]
45  #[error("WebSocket: {0}")]
46  WebSocket(#[from] reqwest_websocket::Error),
47}
48
49impl From<reqwest::Error> for Error {
50  fn from(err: reqwest::Error) -> Self {
51    match err.status() {
52      Some(code) => Self::HttpStatus(code),
53      _ => Self::OtherReqwest(err),
54    }
55  }
56}
57
/// The currently logged-in user, taken from the claims of the auth token.
#[derive(Debug, Clone)]
pub struct User {
  /// The user id (the token's `sub` claim).
  pub sub: String,
  /// The user's e-mail address (the token's `email` claim).
  pub email: String,
}
64
65/// Holds the tokens minted by the server on login.
66///
67/// It is also the exact JSON serialization format.
68#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
69pub struct Tokens {
70  pub auth_token: String,
71  pub refresh_token: Option<String>,
72  pub csrf_token: Option<String>,
73}
74
75#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
76pub struct MultiFactorAuthToken {
77  mfa_token: String,
78}
79
/// Pagination options for [`RecordApi::list`], built with the `with_*` methods.
#[derive(Clone, Debug, Default, PartialEq)]
pub struct Pagination {
  /// Cursor to continue listing from (see [`ListResponse::cursor`]).
  cursor: Option<String>,
  /// Maximum number of records to return.
  limit: Option<usize>,
  /// Offset into the result set.
  offset: Option<usize>,
}

impl Pagination {
  /// Creates pagination options with nothing set.
  pub fn new() -> Self {
    return Pagination::default();
  }

  /// Sets (or, with `None`, clears) the record limit.
  pub fn with_limit(self, limit: impl Into<Option<usize>>) -> Pagination {
    return Pagination {
      limit: limit.into(),
      ..self
    };
  }

  /// Sets (or, with `None`, clears) the cursor.
  pub fn with_cursor(self, cursor: impl Into<Option<String>>) -> Pagination {
    return Pagination {
      cursor: cursor.into(),
      ..self
    };
  }

  /// Sets (or, with `None`, clears) the offset.
  pub fn with_offset(self, offset: impl Into<Option<usize>>) -> Pagination {
    return Pagination {
      offset: offset.into(),
      ..self
    };
  }
}
107
108type JsonObject = serde_json::value::Map<String, serde_json::Value>;
109
110#[derive(Debug, Clone, Copy, Deserialize_repr, Serialize_repr, PartialEq)]
111#[repr(i64)]
112pub enum EventErrorStatus {
113  /// Unknown or unspecified error.
114  Unknown = 0,
115  /// Access forbidden.
116  Forbidden = 1,
117  /// Server-side event-loss, e.g. a buffer ran out of capacity. This does not account for
118  /// additional losses that may happen between the TrailBase server and the client. This
119  /// needs to be determined client-side based on event `seq` numbers.
120  Loss = 2,
121}
122
123#[derive(Debug, Clone, Deserialize, PartialEq)]
124pub enum EventPayload {
125  Update(JsonObject),
126  Insert(JsonObject),
127  Delete(JsonObject),
128  Error {
129    status: EventErrorStatus,
130    message: Option<String>,
131  },
132}
133
134#[derive(Debug, Clone, Deserialize, PartialEq)]
135pub struct ChangeEvent {
136  #[serde(flatten)]
137  pub event: Arc<EventPayload>,
138  pub seq: Option<i64>,
139}
140
141impl ChangeEvent {
142  fn from_str(msg: &str) -> Result<ChangeEvent, serde_json::Error> {
143    return serde_json::from_str::<ChangeEvent>(msg);
144  }
145}
146
147#[derive(Clone, Debug, Deserialize)]
148pub struct ListResponse<T> {
149  pub cursor: Option<String>,
150  pub total_count: Option<usize>,
151  pub records: Vec<T>,
152}
153
/// A value identifying a single record, renderable as a URL path segment.
pub trait RecordId<'a> {
  /// Renders the id as a string, borrowing where possible.
  fn serialized_id(self) -> Cow<'a, str>;
}

impl RecordId<'_> for String {
  fn serialized_id(self) -> Cow<'static, str> {
    return Cow::from(self);
  }
}

impl<'a> RecordId<'a> for &'a String {
  fn serialized_id(self) -> Cow<'a, str> {
    return Cow::Borrowed(self.as_str());
  }
}

impl<'a> RecordId<'a> for &'a str {
  fn serialized_id(self) -> Cow<'a, str> {
    return Cow::from(self);
  }
}

impl RecordId<'_> for i64 {
  fn serialized_id(self) -> Cow<'static, str> {
    let rendered: String = self.to_string();
    return Cow::from(rendered);
  }
}

/// Arguments accepted by [`RecordApi::read`]: a record id plus an optional `expand` list.
///
/// Implemented for every [`RecordId`] (without expansion) and for [`ReadArguments`].
pub trait ReadArgumentsTrait<'a> {
  /// Consumes the arguments and returns the serialized record id.
  fn serialized_id(self) -> Cow<'a, str>;
  /// Names to send as the comma-separated `expand` query parameter, if any.
  fn expand(&self) -> Option<&Vec<&'a str>>;
}

impl<'a, T: RecordId<'a>> ReadArgumentsTrait<'a> for T {
  fn serialized_id(self) -> Cow<'a, str> {
    // Spelled out explicitly since both traits define a `serialized_id` method.
    return <T as RecordId<'a>>::serialized_id(self);
  }

  fn expand(&self) -> Option<&Vec<&'a str>> {
    // A bare id never asks for expansion.
    return None;
  }
}

/// A record id together with an optional `expand` list, for [`RecordApi::read`].
#[derive(Clone, Debug, PartialEq)]
pub struct ReadArguments<'a, T: RecordId<'a>> {
  id: T,
  expand: Option<Vec<&'a str>>,
}

impl<'a, T: RecordId<'a>> ReadArguments<'a, T> {
  /// Creates arguments reading record `id` without expansion.
  pub fn new(id: T) -> Self {
    return Self { id, expand: None };
  }

  /// Sets the names to expand, replacing any previously set list.
  pub fn with_expand(self, expand: impl AsRef<[&'a str]>) -> Self {
    let expand = Some(expand.as_ref().to_vec());
    return Self { expand, ..self };
  }
}

impl<'a, T: RecordId<'a>> ReadArgumentsTrait<'a> for ReadArguments<'a, T> {
  fn serialized_id(self) -> Cow<'a, str> {
    let ReadArguments { id, .. } = self;
    return RecordId::serialized_id(id);
  }

  fn expand(&self) -> Option<&Vec<&'a str>> {
    return self.expand.as_ref();
  }
}
223
/// Abstraction over the HTTP layer used by [`Client`].
///
/// A custom implementation can be supplied through [`ClientOptions::transport`]; otherwise
/// [`DefaultTransport`] is used.
#[async_trait::async_trait]
pub trait Transport {
  /// Performs a single HTTP request and returns the raw response.
  ///
  /// - `path`: absolute request path (starting with `/`), resolved against the server URL.
  /// - `headers`: request headers, including the auth headers built by the client.
  /// - `body`: raw request body; callers in this crate pass JSON-encoded bytes.
  /// - `query_params`: key/value pairs to append to the URL query.
  ///
  /// Callers in this crate inspect the response status themselves, so a non-2xx status is not
  /// itself an error here.
  async fn fetch(
    &self,
    path: &str,
    headers: HeaderMap,
    method: Method,
    body: Option<Vec<u8>>,
    query_params: Option<&[(Cow<'static, str>, Cow<'static, str>)]>,
  ) -> Result<http::Response<reqwest::Body>, Error>;

  /// Performs an HTTP request meant to be upgraded to a WebSocket connection.
  ///
  /// Parameters have the same meaning as for [`Transport::fetch`].
  #[cfg(feature = "ws")]
  async fn upgrade_ws(
    &self,
    path: &str,
    headers: HeaderMap,
    method: Method,
    query_params: Option<&[(Cow<'static, str>, Cow<'static, str>)]>,
  ) -> Result<reqwest_websocket::UpgradeResponse, Error>;
}
244
245pub struct DefaultTransport {
246  client: reqwest::Client,
247  url: url::Url,
248}
249
250impl DefaultTransport {
251  pub fn new(url: url::Url) -> Self {
252    return Self {
253      client: reqwest::Client::new(),
254      url,
255    };
256  }
257}
258
259#[async_trait::async_trait]
260impl Transport for DefaultTransport {
261  async fn fetch(
262    &self,
263    path: &str,
264    headers: HeaderMap,
265    method: Method,
266    body: Option<Vec<u8>>,
267    query_params: Option<&[(Cow<'static, str>, Cow<'static, str>)]>,
268  ) -> Result<http::Response<reqwest::Body>, Error> {
269    assert!(path.starts_with("/"));
270
271    let mut url = self.url.clone();
272    url.set_path(path);
273
274    if let Some(query_params) = query_params {
275      let mut params = url.query_pairs_mut();
276      for (key, value) in query_params {
277        params.append_pair(key, value);
278      }
279    }
280
281    let request = {
282      let mut builder = self.client.request(method, url).headers(headers);
283      if let Some(body) = body {
284        // let json = serde_json::to_string(body).map_err(Error::RecordSerialization)?;
285        // builder = builder.body(json);
286        builder = builder.body(body);
287      }
288      builder.build()?
289    };
290
291    return Ok(self.client.execute(request).await?.into());
292  }
293
294  #[cfg(feature = "ws")]
295  async fn upgrade_ws(
296    &self,
297    path: &str,
298    headers: HeaderMap,
299    method: Method,
300    query_params: Option<&[(Cow<'static, str>, Cow<'static, str>)]>,
301  ) -> Result<reqwest_websocket::UpgradeResponse, Error> {
302    use reqwest_websocket::Upgrade;
303
304    assert!(path.starts_with("/"));
305
306    let mut url = self.url.clone();
307    url.set_path(path);
308
309    if let Some(query_params) = query_params {
310      let mut params = url.query_pairs_mut();
311      for (key, value) in query_params {
312        params.append_pair(key, value);
313      }
314    }
315
316    return Ok(
317      self
318        .client
319        .request(method, url)
320        .headers(headers)
321        .upgrade()
322        .send()
323        .await?,
324    );
325  }
326}
327
328#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
329struct JwtTokenClaims {
330  sub: String,
331  iat: i64,
332  exp: i64,
333  email: String,
334  csrf_token: String,
335}
336
337fn decode_auth_token<T: DeserializeOwned + Clone>(token: &str) -> Result<T, Error> {
338  return jsonwebtoken::dangerous::insecure_decode::<T>(token)
339    .map(|data| data.claims)
340    .map_err(Error::InvalidToken);
341}
342
343#[derive(Clone)]
344pub struct RecordApi {
345  client: Arc<ClientState>,
346  name: String,
347}
348
349#[derive(Clone, Debug, Default, PartialEq)]
350pub struct ListArguments<'a> {
351  pagination: Pagination,
352  order: Option<Vec<&'a str>>,
353  filters: Option<ValueOrFilterGroup>,
354  expand: Option<Vec<&'a str>>,
355  count: bool,
356}
357
/// Comparison operator of a list [`Filter`].
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum CompareOp {
  Equal,
  NotEqual,
  GreaterThanEqual,
  GreaterThan,
  LessThanEqual,
  LessThan,
  Like,
  Regexp,
  StWithin,
  StIntersects,
  StContains,
}

impl CompareOp {
  /// Returns the operator token used in `filter[...]` query parameter keys.
  fn format(&self) -> &'static str {
    let token = match *self {
      CompareOp::Equal => "$eq",
      CompareOp::NotEqual => "$ne",
      CompareOp::GreaterThanEqual => "$gte",
      CompareOp::GreaterThan => "$gt",
      CompareOp::LessThanEqual => "$lte",
      CompareOp::LessThan => "$lt",
      CompareOp::Like => "$like",
      CompareOp::Regexp => "$re",
      CompareOp::StWithin => "@within",
      CompareOp::StIntersects => "@intersects",
      CompareOp::StContains => "@contains",
    };
    return token;
  }
}

/// A single column filter for [`RecordApi::list`].
///
/// With `op: None` the filter is sent as `filter[column]=value`, without an explicit operator.
#[derive(Clone, Default, Debug, PartialEq)]
pub struct Filter {
  pub column: String,
  pub op: Option<CompareOp>,
  pub value: String,
}

impl Filter {
  /// Creates a filter comparing `column` against `value` with `op`.
  pub fn new(column: impl Into<String>, op: CompareOp, value: impl Into<String>) -> Self {
    let column = column.into();
    let value = value.into();
    return Filter {
      column,
      op: Some(op),
      value,
    };
  }
}

impl From<Filter> for ValueOrFilterGroup {
  fn from(value: Filter) -> Self {
    return Self::Filter(value);
  }
}

/// A tree of filters combined with logical AND / OR.
#[derive(Clone, Debug, PartialEq)]
pub enum ValueOrFilterGroup {
  Filter(Filter),
  And(Vec<ValueOrFilterGroup>),
  Or(Vec<ValueOrFilterGroup>),
}

/// Any list of filters converts into an AND group of those filters.
impl<F> From<F> for ValueOrFilterGroup
where
  F: Into<Vec<Filter>>,
{
  fn from(filters: F) -> Self {
    let filters: Vec<Filter> = filters.into();
    let mut group = Vec::with_capacity(filters.len());
    for filter in filters {
      group.push(ValueOrFilterGroup::Filter(filter));
    }
    return ValueOrFilterGroup::And(group);
  }
}
435
436impl<'a> ListArguments<'a> {
437  pub fn new() -> Self {
438    return ListArguments::default();
439  }
440
441  pub fn with_pagination(mut self, pagination: Pagination) -> Self {
442    self.pagination = pagination;
443    return self;
444  }
445
446  pub fn with_order(mut self, order: impl AsRef<[&'a str]>) -> Self {
447    self.order = Some(order.as_ref().to_vec());
448    return self;
449  }
450
451  pub fn with_filters(mut self, filters: impl Into<ValueOrFilterGroup>) -> Self {
452    self.filters = Some(filters.into());
453    return self;
454  }
455
456  pub fn with_expand(mut self, expand: impl AsRef<[&'a str]>) -> Self {
457    self.expand = Some(expand.as_ref().to_vec());
458    return self;
459  }
460
461  pub fn with_count(mut self, count: bool) -> Self {
462    self.count = count;
463    return self;
464  }
465}
466
467impl RecordApi {
468  pub async fn list<T: DeserializeOwned>(
469    &self,
470    args: ListArguments<'_>,
471  ) -> Result<ListResponse<T>, Error> {
472    type Param = (Cow<'static, str>, Cow<'static, str>);
473    let mut params: Vec<Param> = vec![];
474    if let Some(cursor) = args.pagination.cursor {
475      params.push((Cow::Borrowed("cursor"), Cow::Owned(cursor)));
476    }
477
478    if let Some(limit) = args.pagination.limit {
479      params.push((Cow::Borrowed("limit"), Cow::Owned(limit.to_string())));
480    }
481
482    #[inline]
483    fn to_list(slice: &[&str]) -> String {
484      return slice.join(",");
485    }
486
487    if let Some(order) = args.order
488      && !order.is_empty()
489    {
490      params.push((Cow::Borrowed("order"), Cow::Owned(to_list(&order))));
491    }
492
493    if let Some(expand) = args.expand
494      && !expand.is_empty()
495    {
496      params.push((Cow::Borrowed("expand"), Cow::Owned(to_list(&expand))));
497    }
498
499    if args.count {
500      params.push((Cow::Borrowed("count"), Cow::Borrowed("true")));
501    }
502
503    fn traverse_filters(params: &mut Vec<Param>, path: String, filter: ValueOrFilterGroup) {
504      match filter {
505        ValueOrFilterGroup::Filter(filter) => {
506          if let Some(op) = filter.op {
507            params.push((
508              Cow::Owned(format!(
509                "{path}[{col}][{op}]",
510                col = filter.column,
511                op = op.format()
512              )),
513              Cow::Owned(filter.value),
514            ));
515          } else {
516            params.push((
517              Cow::Owned(format!("{path}[{col}]", col = filter.column)),
518              Cow::Owned(filter.value),
519            ));
520          }
521        }
522        ValueOrFilterGroup::And(vec) => {
523          for (i, f) in vec.into_iter().enumerate() {
524            traverse_filters(params, format!("{path}[$and][{i}]"), f);
525          }
526        }
527        ValueOrFilterGroup::Or(vec) => {
528          for (i, f) in vec.into_iter().enumerate() {
529            traverse_filters(params, format!("{path}[$or][{i}]"), f);
530          }
531        }
532      }
533    }
534
535    if let Some(filters) = args.filters {
536      traverse_filters(&mut params, "filter".to_string(), filters);
537    }
538
539    let response = self
540      .client
541      .fetch(
542        &format!("/{RECORD_API}/{}", self.name),
543        Method::GET,
544        None,
545        Some(&params),
546        /* error_for_status= */ true,
547      )
548      .await?;
549
550    return json(response).await;
551  }
552
553  pub async fn read<'a, T: DeserializeOwned>(
554    &self,
555    args: impl ReadArgumentsTrait<'a>,
556  ) -> Result<T, Error> {
557    let expand = args
558      .expand()
559      .map(|e| vec![(Cow::Borrowed("expand"), Cow::Owned(e.join(",")))]);
560
561    let response = self
562      .client
563      .fetch(
564        &format!(
565          "/{RECORD_API}/{name}/{id}",
566          name = self.name,
567          id = args.serialized_id()
568        ),
569        Method::GET,
570        None,
571        expand.as_deref(),
572        /* error_for_status= */ true,
573      )
574      .await?;
575
576    return json(response).await;
577  }
578
579  pub async fn create<T: Serialize>(&self, record: T) -> Result<String, Error> {
580    return Ok(self.create_impl(record).await?.swap_remove(0));
581  }
582
583  pub async fn create_bulk<T: Serialize>(&self, record: &[T]) -> Result<Vec<String>, Error> {
584    return self.create_impl(record).await;
585  }
586
587  async fn create_impl<T: Serialize>(&self, record: T) -> Result<Vec<String>, Error> {
588    let response = self
589      .client
590      .fetch(
591        &format!("/{RECORD_API}/{name}", name = self.name),
592        Method::POST,
593        Some(serde_json::to_vec(&record).map_err(Error::RecordSerialization)?),
594        None,
595        /* error_for_status= */ true,
596      )
597      .await?;
598
599    #[derive(Deserialize)]
600    pub struct RecordIdResponse {
601      pub ids: Vec<String>,
602    }
603
604    return Ok(json::<RecordIdResponse>(response).await?.ids);
605  }
606
607  pub async fn update<'a, T: Serialize>(
608    &self,
609    id: impl RecordId<'a>,
610    record: T,
611  ) -> Result<(), Error> {
612    self
613      .client
614      .fetch(
615        &format!(
616          "/{RECORD_API}/{name}/{id}",
617          name = self.name,
618          id = id.serialized_id()
619        ),
620        Method::PATCH,
621        Some(serde_json::to_vec(&record).map_err(Error::RecordSerialization)?),
622        None,
623        /* error_for_status= */ true,
624      )
625      .await?;
626
627    return Ok(());
628  }
629
630  pub async fn delete<'a>(&self, id: impl RecordId<'a>) -> Result<(), Error> {
631    self
632      .client
633      .fetch(
634        &format!(
635          "/{RECORD_API}/{name}/{id}",
636          name = self.name,
637          id = id.serialized_id()
638        ),
639        Method::DELETE,
640        None,
641        None,
642        /* error_for_status= */ true,
643      )
644      .await?;
645
646    return Ok(());
647  }
648
649  pub async fn subscribe<'a, T: RecordId<'a>>(
650    &self,
651    id: T,
652  ) -> Result<impl Stream<Item = ChangeEvent> + use<T>, Error> {
653    // TODO: Might have to add HeaderValue::from_static("text/event-stream").
654    let response = self
655      .client
656      .fetch(
657        &format!(
658          "/{RECORD_API}/{name}/subscribe/{id}",
659          name = self.name,
660          id = id.serialized_id()
661        ),
662        Method::GET,
663        None,
664        None,
665        /* error_for_status= */ true,
666      )
667      .await?;
668
669    return Ok(
670      http_body_util::BodyDataStream::new(response.into_body())
671        .eventsource()
672        .filter_map(|event_or| {
673          // QUESTION: Should we instead return a `Stream<Item = Result<ChangeEvent, _>>` to allow
674          // for better error handling here.
675          if let Ok(event) = event_or {
676            return ChangeEvent::from_str(&event.data)
677              .map_err(|err| {
678                warn!("Failed to parse change event: {}", event.data);
679                return err;
680              })
681              .ok();
682          }
683          return None;
684        }),
685    );
686  }
687
688  #[cfg(feature = "ws")]
689  pub async fn subscribe_ws<'a, T: RecordId<'a>>(
690    &self,
691    id: T,
692  ) -> Result<impl Stream<Item = ChangeEvent> + use<T>, Error> {
693    let response = self
694      .client
695      .upgrade_ws(
696        &format!(
697          "/{RECORD_API}/{name}/subscribe/{id}",
698          name = self.name,
699          id = id.serialized_id()
700        ),
701        Method::GET,
702        Some(&[("ws".into(), "true".into())]),
703      )
704      .await?;
705
706    let websocket = response.into_websocket().await?;
707
708    return Ok(websocket.filter_map(|message| {
709      use reqwest_websocket::Message;
710
711      return match message {
712        Ok(Message::Text(msg)) => serde_json::from_str::<ChangeEvent>(&msg)
713          .map_err(|err| {
714            warn!("json error: {err}");
715            return err;
716          })
717          .ok(),
718        msg => {
719          warn!("unexpected msg: {msg:?}");
720          None
721        }
722      };
723    }));
724  }
725}
726
727#[derive(Clone, Debug)]
728struct TokenState {
729  state: Option<(Tokens, JwtTokenClaims)>,
730  headers: HeaderMap,
731}
732
733impl TokenState {
734  fn build(tokens: Option<&Tokens>) -> TokenState {
735    let headers = build_headers(tokens);
736    return TokenState {
737      state: tokens.and_then(|tokens| {
738        return match decode_auth_token::<JwtTokenClaims>(&tokens.auth_token) {
739          Ok(jwt_token) => Some((tokens.clone(), jwt_token)),
740          Err(err) => {
741            error!("Failed to decode auth token: {err}");
742            None
743          }
744        };
745      }),
746      headers,
747    };
748  }
749}
750
751struct ClientState {
752  transport: Box<dyn Transport + Send + Sync>,
753  base_url: url::Url,
754  tokens: RwLock<TokenState>,
755}
756
757impl ClientState {
758  #[inline]
759  async fn fetch(
760    &self,
761    path: &str,
762    method: Method,
763    body: Option<Vec<u8>>,
764    query_params: Option<&[(Cow<'static, str>, Cow<'static, str>)]>,
765    error_for_status: bool,
766  ) -> Result<http::Response<reqwest::Body>, Error> {
767    let (mut headers, refresh_token) = self.extract_headers_and_refresh_token_if_exp();
768    if let Some(refresh_token) = refresh_token {
769      let new_tokens = refresh_tokens_impl(&*self.transport, headers, refresh_token).await?;
770
771      headers = new_tokens.headers.clone();
772      *self.tokens.write() = new_tokens;
773    }
774
775    let response = self
776      .transport
777      .fetch(path, headers, method, body, query_params)
778      .await?;
779
780    if error_for_status {
781      return error_for_status_unpack(response);
782    }
783    return Ok(response);
784  }
785
786  #[cfg(feature = "ws")]
787  #[inline]
788  async fn upgrade_ws(
789    &self,
790    path: &str,
791    method: Method,
792    query_params: Option<&[(Cow<'static, str>, Cow<'static, str>)]>,
793  ) -> Result<reqwest_websocket::UpgradeResponse, Error> {
794    let (mut headers, refresh_token) = self.extract_headers_and_refresh_token_if_exp();
795    if let Some(refresh_token) = refresh_token {
796      let new_tokens = refresh_tokens_impl(&*self.transport, headers, refresh_token).await?;
797
798      headers = new_tokens.headers.clone();
799      *self.tokens.write() = new_tokens;
800    }
801
802    return self
803      .transport
804      .upgrade_ws(path, headers, method, query_params)
805      .await;
806  }
807
808  #[inline]
809  fn extract_headers_and_refresh_token_if_exp(&self) -> (HeaderMap, Option<String>) {
810    #[inline]
811    fn should_refresh(jwt: &JwtTokenClaims) -> bool {
812      return jwt.exp - 60 < now() as i64;
813    }
814
815    let tokens = self.tokens.read();
816    let headers = tokens.headers.clone();
817    return match tokens.state {
818      Some(ref state) if should_refresh(&state.1) => (headers, state.0.refresh_token.clone()),
819      _ => (headers, None),
820    };
821  }
822
823  fn extract_headers_refresh_token(&self) -> Option<(HeaderMap, String)> {
824    let tokens = self.tokens.read();
825    let state = tokens.state.as_ref()?;
826
827    if let Some(ref refresh_token) = state.0.refresh_token {
828      return Some((tokens.headers.clone(), refresh_token.clone()));
829    }
830    return None;
831  }
832}
833
834#[derive(Clone)]
835pub struct Client {
836  state: Arc<ClientState>,
837}
838
839#[derive(Default)]
840pub struct ClientOptions {
841  pub tokens: Option<Tokens>,
842  pub transport: Option<Box<dyn Transport + Send + Sync>>,
843}
844
845impl Client {
846  pub fn new(
847    base_url: impl TryInto<url::Url, Error = url::ParseError>,
848    opts: Option<ClientOptions>,
849  ) -> Result<Client, Error> {
850    let opts = opts.unwrap_or_default();
851    let base_url = base_url.try_into().map_err(Error::InvalidUrl)?;
852    return Ok(Client {
853      state: Arc::new(ClientState {
854        transport: opts.transport.unwrap_or_else(|| {
855          return Box::new(DefaultTransport::new(base_url.clone()));
856        }),
857        base_url,
858        tokens: RwLock::new(TokenState::build(opts.tokens.as_ref())),
859      }),
860    });
861  }
862
863  pub fn base_url(&self) -> &url::Url {
864    return &self.state.base_url;
865  }
866
867  pub fn tokens(&self) -> Option<Tokens> {
868    return self.state.tokens.read().state.as_ref().map(|x| x.0.clone());
869  }
870
871  pub fn user(&self) -> Option<User> {
872    if let Some(state) = &self.state.tokens.read().state {
873      return Some(User {
874        sub: state.1.sub.clone(),
875        email: state.1.email.clone(),
876      });
877    }
878    return None;
879  }
880
881  pub fn records(&self, api_name: &str) -> RecordApi {
882    return RecordApi {
883      client: self.state.clone(),
884      name: api_name.to_string(),
885    };
886  }
887
888  pub async fn refresh(&self) -> Result<(), Error> {
889    let Some((headers, refresh_token)) = self.state.extract_headers_refresh_token() else {
890      // Not logged in - nothing to do.
891      return Ok(());
892    };
893
894    let new_tokens = refresh_tokens_impl(&*self.state.transport, headers, refresh_token).await?;
895
896    *self.state.tokens.write() = new_tokens;
897    return Ok(());
898  }
899
900  pub async fn login(
901    &self,
902    email: &str,
903    password: &str,
904  ) -> Result<Option<MultiFactorAuthToken>, Error> {
905    #[derive(Serialize)]
906    struct Credentials<'a> {
907      email: &'a str,
908      password: &'a str,
909    }
910
911    let response = self
912      .state
913      .fetch(
914        &format!("/{AUTH_API}/login"),
915        Method::POST,
916        Some(
917          serde_json::to_vec(&Credentials { email, password })
918            .map_err(Error::RecordSerialization)?,
919        ),
920        None,
921        /* error_for_status= */ false,
922      )
923      .await?;
924
925    if response.status() == StatusCode::FORBIDDEN {
926      let mfa_token: MultiFactorAuthToken = json(response).await?;
927      return Ok(Some(mfa_token));
928    }
929
930    let tokens: Tokens = json(error_for_status_unpack(response)?).await?;
931    self.update_tokens(Some(&tokens));
932
933    return Ok(None);
934  }
935
936  pub async fn login_second(
937    &self,
938    mfa_token: &MultiFactorAuthToken,
939    totp_code: &str,
940  ) -> Result<(), Error> {
941    #[derive(Serialize)]
942    struct Credentials<'a> {
943      mfa_token: &'a str,
944      totp: &'a str,
945    }
946
947    let response = self
948      .state
949      .fetch(
950        &format!("/{AUTH_API}/login_mfa"),
951        Method::POST,
952        Some(
953          serde_json::to_vec(&Credentials {
954            mfa_token: &mfa_token.mfa_token,
955            totp: totp_code,
956          })
957          .map_err(Error::RecordSerialization)?,
958        ),
959        None,
960        /* error_for_status= */ true,
961      )
962      .await?;
963
964    let tokens: Tokens = json(error_for_status_unpack(response)?).await?;
965    self.update_tokens(Some(&tokens));
966
967    return Ok(());
968  }
969
970  pub async fn request_otp(&self, email: &str, redirect_uri: Option<&str>) -> Result<(), Error> {
971    #[derive(Serialize)]
972    struct Credentials<'a> {
973      email: &'a str,
974      redirect_uri: Option<&'a str>,
975    }
976
977    let _response = self
978      .state
979      .fetch(
980        &format!("/{AUTH_API}/otp/request"),
981        Method::POST,
982        Some(
983          serde_json::to_vec(&Credentials {
984            email,
985            redirect_uri,
986          })
987          .map_err(Error::RecordSerialization)?,
988        ),
989        None,
990        /* error_for_status= */ true,
991      )
992      .await?;
993
994    return Ok(());
995  }
996
997  pub async fn login_otp(&self, email: &str, code: &str) -> Result<(), Error> {
998    #[derive(Serialize)]
999    struct Credentials<'a> {
1000      email: &'a str,
1001      code: &'a str,
1002    }
1003
1004    let response = self
1005      .state
1006      .fetch(
1007        &format!("/{AUTH_API}/otp/login"),
1008        Method::POST,
1009        Some(serde_json::to_vec(&Credentials { email, code }).map_err(Error::RecordSerialization)?),
1010        None,
1011        /* error_for_status= */ true,
1012      )
1013      .await?;
1014
1015    let tokens: Tokens = json(error_for_status_unpack(response)?).await?;
1016    self.update_tokens(Some(&tokens));
1017
1018    return Ok(());
1019  }
1020
1021  pub async fn logout(&self) -> Result<(), Error> {
1022    #[derive(Serialize)]
1023    struct LogoutRequest {
1024      refresh_token: String,
1025    }
1026
1027    let response_or = match self.state.extract_headers_refresh_token() {
1028      Some((_headers, refresh_token)) => {
1029        self
1030          .state
1031          .fetch(
1032            &format!("/{AUTH_API}/logout"),
1033            Method::POST,
1034            Some(
1035              serde_json::to_vec(&LogoutRequest { refresh_token })
1036                .map_err(Error::RecordSerialization)?,
1037            ),
1038            None,
1039            /* error_for_status= */ true,
1040          )
1041          .await
1042      }
1043      _ => {
1044        self
1045          .state
1046          .fetch(
1047            &format!("/{AUTH_API}/logout"),
1048            Method::GET,
1049            None,
1050            None,
1051            /* error_for_status= */ true,
1052          )
1053          .await
1054      }
1055    };
1056
1057    self.update_tokens(None);
1058
1059    return response_or.map(|_| ());
1060  }
1061
1062  fn update_tokens(&self, tokens: Option<&Tokens>) -> TokenState {
1063    let state = TokenState::build(tokens);
1064
1065    *self.state.tokens.write() = state.clone();
1066    // _authChange?.call(this, state.state?.$1);
1067
1068    if let Some(ref s) = state.state {
1069      let now = now();
1070      if s.1.exp < now as i64 {
1071        warn!("Token expired");
1072      }
1073    }
1074
1075    return state;
1076  }
1077}
1078
1079fn build_headers(tokens: Option<&Tokens>) -> HeaderMap {
1080  let mut base = HeaderMap::with_capacity(5);
1081  base.insert(
1082    header::CONTENT_TYPE,
1083    HeaderValue::from_static("application/json"),
1084  );
1085
1086  if let Some(tokens) = tokens {
1087    if let Ok(value) = HeaderValue::from_str(&format!("Bearer {}", tokens.auth_token)) {
1088      base.insert(header::AUTHORIZATION, value);
1089    } else {
1090      error!("Failed to build bearer token.");
1091    }
1092
1093    if let Some(ref refresh) = tokens.refresh_token {
1094      if let Ok(value) = HeaderValue::from_str(refresh) {
1095        base.insert("Refresh-Token", value);
1096      } else {
1097        error!("Failed to build refresh token header.");
1098      }
1099    }
1100
1101    if let Some(ref csrf) = tokens.csrf_token {
1102      if let Ok(value) = HeaderValue::from_str(csrf) {
1103        base.insert("CSRF-Token", value);
1104      } else {
1105        error!("Failed to build refresh token header.");
1106      }
1107    }
1108  }
1109
1110  return base;
1111}
1112
1113async fn refresh_tokens_impl(
1114  transport: &(dyn Transport + Send + Sync),
1115  headers: HeaderMap,
1116  refresh_token: String,
1117) -> Result<TokenState, Error> {
1118  #[derive(Serialize)]
1119  struct RefreshRequest<'a> {
1120    refresh_token: &'a str,
1121  }
1122
1123  // NOTE: Do not use `ClientState::fetch`, which may do token refreshing to avoid loops.
1124  let response = transport
1125    .fetch(
1126      &format!("/{AUTH_API}/refresh"),
1127      headers,
1128      Method::POST,
1129      Some(
1130        serde_json::to_vec(&RefreshRequest {
1131          refresh_token: &refresh_token,
1132        })
1133        .map_err(Error::RecordSerialization)?,
1134      ),
1135      None,
1136    )
1137    .await?;
1138
1139  return match response.status() {
1140    StatusCode::OK => {
1141      #[derive(Deserialize)]
1142      struct RefreshResponse {
1143        auth_token: String,
1144        csrf_token: Option<String>,
1145      }
1146
1147      let refresh_response: RefreshResponse = json(response).await?;
1148
1149      Ok(TokenState::build(Some(&Tokens {
1150        auth_token: refresh_response.auth_token,
1151        refresh_token: Some(refresh_token),
1152        csrf_token: refresh_response.csrf_token,
1153      })))
1154    }
1155    StatusCode::UNAUTHORIZED => Ok(TokenState::build(None)),
1156    status => Err(Error::HttpStatus(status)),
1157  };
1158}
1159
/// Returns the current Unix time in whole seconds.
///
/// Returns `0` instead of panicking if the system clock reports a time before the Unix epoch.
/// Callers only use this value for token-expiry checks.
fn now() -> u64 {
  return std::time::SystemTime::now()
    .duration_since(std::time::UNIX_EPOCH)
    .map_or(0, |elapsed| elapsed.as_secs());
}
1166
1167#[inline]
1168async fn json<T: DeserializeOwned>(resp: http::Response<reqwest::Body>) -> Result<T, Error> {
1169  let full = into_bytes(resp).await?;
1170  return serde_json::from_slice(&full).map_err(Error::RecordSerialization);
1171}
1172
1173#[inline]
1174async fn into_bytes(resp: http::Response<reqwest::Body>) -> Result<bytes::Bytes, Error> {
1175  return Ok(
1176    http_body_util::BodyExt::collect(resp.into_body())
1177      .await
1178      .map(|buf| buf.to_bytes())?,
1179  );
1180}
1181
1182fn error_for_status_unpack(
1183  resp: http::Response<reqwest::Body>,
1184) -> Result<http::Response<reqwest::Body>, Error> {
1185  let status = resp.status();
1186  if status.is_client_error() || status.is_server_error() {
1187    return Err(Error::HttpStatus(status));
1188  }
1189  return Ok(resp);
1190}
1191
/// Path prefix (without leading slash) of the auth API endpoints.
const AUTH_API: &str = "api/auth/v1";
/// Path prefix (without leading slash) of the record API endpoints.
const RECORD_API: &str = "api/records/v1";
1194
#[cfg(test)]
mod tests {
  use super::*;

  /// Spawning onto the tokio runtime requires the client's futures to be `Send`.
  ///
  /// The request is expected to fail, presumably because no server runs at that address
  /// during tests.
  #[tokio::test]
  async fn is_send_test() {
    let client = Client::new("http://127.0.0.1:4000", None).unwrap();
    let api = client.records("simple_strict_table");

    for _ in 0..2 {
      let task_api = api.clone();
      let handle = tokio::spawn(async move {
        // This would not compile if locks were held across `.await` points.
        let result = task_api.read::<serde_json::Value>(0).await;
        assert!(result.is_err());
      });
      handle.await.unwrap();
    }
  }

  /// Parses an error event with a sequence number and an update event without one.
  #[test]
  fn parse_change_event_test() {
    let error_event = ChangeEvent::from_str(
      r#"
        {
          "Error": {
            "status": 1,
            "message": "test"
          },
          "seq": 3
        }"#,
    )
    .unwrap();

    assert_eq!(error_event.seq, Some(3));
    match &*error_event.event {
      EventPayload::Error { status, message } => {
        assert_eq!(*status, EventErrorStatus::Forbidden);
        assert_eq!(message.as_deref(), Some("test"));
      }
      other => panic!("expected error payload, got {other:?}"),
    }

    let update_event = ChangeEvent::from_str(
      r#"
        {
          "Update": {
            "col0": "val0",
            "col1": 4
          }
        }"#,
    )
    .unwrap();

    assert_eq!(update_event.seq, None);
    let record = match &*update_event.event {
      EventPayload::Update(record) => record.clone(),
      other => panic!("expected update payload, got {other:?}"),
    };

    let expected = serde_json::json!({
        "col0": "val0",
        "col1": 4,
    });
    assert_eq!(serde_json::Value::Object(record), expected);
  }
}