1#![forbid(unsafe_code, clippy::unwrap_used)]
7#![allow(clippy::needless_return)]
8#![warn(clippy::await_holding_lock, clippy::inefficient_to_string)]
9
10use eventsource_stream::Eventsource;
11use futures_lite::StreamExt;
12use parking_lot::RwLock;
13use reqwest::header::{self, HeaderMap, HeaderValue};
14use reqwest::{Method, StatusCode};
15use serde::de::DeserializeOwned;
16use serde::{Deserialize, Serialize};
17use std::borrow::Cow;
18use std::sync::Arc;
19use thiserror::Error;
20use tracing::*;
21
22pub use futures_lite::Stream;
23
/// Errors returned by the client.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum Error {
  /// The server answered with a 4xx/5xx status, or a reqwest error carried a
  /// status code.
  #[error("HTTP status: {0}")]
  HttpStatus(StatusCode),

  /// Serializing a request body or deserializing a response body as JSON failed.
  #[error("RecordSerialization: {0}")]
  RecordSerialization(serde_json::Error),

  /// The auth token could not be decoded as a JWT.
  #[error("InvalidToken: {0}")]
  InvalidToken(jsonwebtoken::errors::Error),

  /// The base URL passed to `Client::new` could not be parsed.
  #[error("InvalidUrl: {0}")]
  InvalidUrl(url::ParseError),

  /// Any reqwest error that does not carry an HTTP status.
  #[error("Reqwest: {0}")]
  OtherReqwest(reqwest::Error),

  /// WebSocket upgrade or stream error (only with the `ws` feature).
  #[cfg(feature = "ws")]
  #[error("WebSocket: {0}")]
  WebSocket(#[from] reqwest_websocket::Error),
}
47
48impl From<reqwest::Error> for Error {
49 fn from(err: reqwest::Error) -> Self {
50 match err.status() {
51 Some(code) => Self::HttpStatus(code),
52 _ => Self::OtherReqwest(err),
53 }
54 }
55}
56
/// The logged-in user, as described by the claims of the current auth token.
#[derive(Clone, Debug)]
pub struct User {
  /// The token's `sub` claim.
  pub sub: String,
  /// The token's `email` claim.
  pub email: String,
}
63
/// Authentication tokens held by the client.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct Tokens {
  /// JWT sent as `Authorization: Bearer <auth_token>`.
  pub auth_token: String,
  /// Sent as the `Refresh-Token` header and used to obtain new auth tokens.
  pub refresh_token: Option<String>,
  /// Sent as the `CSRF-Token` header when present.
  pub csrf_token: Option<String>,
}
73
/// Token returned by `Client::login` when the server asks for a second factor
/// (it responded with 403). Pass it to `Client::login_second` together with the
/// TOTP code.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct MultiFactorAuthToken {
  mfa_token: String,
}
78
/// Pagination options for listing records.
///
/// Every field is optional; unset fields are not sent to the server.
#[derive(Clone, Debug, Default, PartialEq)]
pub struct Pagination {
  cursor: Option<String>,
  limit: Option<usize>,
  offset: Option<usize>,
}

impl Pagination {
  /// Returns pagination with no cursor, limit or offset.
  pub fn new() -> Self {
    Self::default()
  }

  /// Sets (or clears, with `None`) the maximum number of records to return.
  pub fn with_limit(self, limit: impl Into<Option<usize>>) -> Pagination {
    Pagination {
      limit: limit.into(),
      ..self
    }
  }

  /// Sets (or clears, with `None`) the cursor to continue listing from.
  pub fn with_cursor(self, cursor: impl Into<Option<String>>) -> Pagination {
    Pagination {
      cursor: cursor.into(),
      ..self
    }
  }

  /// Sets (or clears, with `None`) the number of records to skip.
  pub fn with_offset(self, offset: impl Into<Option<usize>>) -> Pagination {
    Pagination {
      offset: offset.into(),
      ..self
    }
  }
}
106
/// A change event delivered by `RecordApi::subscribe` / `subscribe_ws`.
///
/// Parsed from the JSON payload of each event. There are no serde attributes,
/// so serde's default externally-tagged enum representation applies, e.g.
/// `{"Insert": {...}}`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum DbEvent {
  /// A record was updated; carries the record payload if provided.
  Update(Option<serde_json::Value>),
  /// A record was inserted; carries the record payload if provided.
  Insert(Option<serde_json::Value>),
  /// A record was deleted; carries the record payload if provided.
  Delete(Option<serde_json::Value>),
  /// An error message delivered in the event stream (presumably from the
  /// server; confirm).
  Error(String),
}
114
/// Response body of `RecordApi::list`.
#[derive(Clone, Debug, Deserialize)]
pub struct ListResponse<T> {
  /// Cursor for fetching the next page, if any (see `Pagination::with_cursor`).
  pub cursor: Option<String>,
  /// Total number of matching records. Presumably only populated when
  /// `ListArguments::with_count(true)` was used (TODO confirm with server).
  pub total_count: Option<usize>,
  /// The records of this page.
  pub records: Vec<T>,
}
121
/// A value usable as a record id in record-API URLs.
///
/// Strings are used verbatim (borrowed when possible), integers via
/// `to_string()`.
pub trait RecordId<'a> {
  /// Renders the id for use in a URL path.
  fn serialized_id(self) -> Cow<'a, str>;
}

impl RecordId<'_> for String {
  fn serialized_id(self) -> Cow<'static, str> {
    Cow::from(self)
  }
}

impl<'a> RecordId<'a> for &'a String {
  fn serialized_id(self) -> Cow<'a, str> {
    Cow::from(self)
  }
}

impl<'a> RecordId<'a> for &'a str {
  fn serialized_id(self) -> Cow<'a, str> {
    Cow::from(self)
  }
}

impl RecordId<'_> for i64 {
  fn serialized_id(self) -> Cow<'static, str> {
    Cow::from(self.to_string())
  }
}

/// Arguments accepted by `RecordApi::read`: a record id plus optional
/// foreign-key expansions.
pub trait ReadArgumentsTrait<'a> {
  /// Renders the record id for use in a URL path.
  fn serialized_id(self) -> Cow<'a, str>;
  /// Columns to expand, if any.
  fn expand(&self) -> Option<&Vec<&'a str>>;
}

/// A bare record id is a valid read argument without any expansion.
impl<'a, T: RecordId<'a>> ReadArgumentsTrait<'a> for T {
  fn serialized_id(self) -> Cow<'a, str> {
    // Explicitly delegate to `RecordId` (both traits define `serialized_id`).
    <T as RecordId<'a>>::serialized_id(self)
  }

  fn expand(&self) -> Option<&Vec<&'a str>> {
    None
  }
}

/// Read arguments with an explicit list of columns to expand.
#[derive(Clone, Debug, PartialEq)]
pub struct ReadArguments<'a, T: RecordId<'a>> {
  id: T,
  expand: Option<Vec<&'a str>>,
}

impl<'a, T: RecordId<'a>> ReadArguments<'a, T> {
  /// Creates read arguments for `id` without expansion.
  pub fn new(id: T) -> Self {
    ReadArguments { id, expand: None }
  }

  /// Replaces the list of columns to expand.
  pub fn with_expand(self, expand: impl AsRef<[&'a str]>) -> Self {
    ReadArguments {
      expand: Some(Vec::from(expand.as_ref())),
      ..self
    }
  }
}

impl<'a, T: RecordId<'a>> ReadArgumentsTrait<'a> for ReadArguments<'a, T> {
  fn serialized_id(self) -> Cow<'a, str> {
    let ReadArguments { id, .. } = self;
    <T as RecordId<'a>>::serialized_id(id)
  }

  fn expand(&self) -> Option<&Vec<&'a str>> {
    self.expand.as_ref()
  }
}
191
/// Pluggable HTTP transport used by the client to talk to the server.
///
/// [`DefaultTransport`] (backed by `reqwest`) is used unless a custom
/// implementation is passed via [`ClientOptions::transport`].
#[async_trait::async_trait]
pub trait Transport {
  /// Sends one HTTP request for `path` and returns the raw response.
  ///
  /// `headers` are the client's default headers (content type and any auth
  /// tokens). Implementations should not treat 4xx/5xx as errors: the client
  /// decides per call whether to check the status.
  async fn fetch(
    &self,
    path: &str,
    headers: HeaderMap,
    method: Method,
    body: Option<Vec<u8>>,
    query_params: Option<&[(Cow<'static, str>, Cow<'static, str>)]>,
  ) -> Result<http::Response<reqwest::Body>, Error>;

  /// Sends a WebSocket upgrade request for `path` (only with the `ws` feature).
  #[cfg(feature = "ws")]
  async fn upgrade_ws(
    &self,
    path: &str,
    headers: HeaderMap,
    method: Method,
    query_params: Option<&[(Cow<'static, str>, Cow<'static, str>)]>,
  ) -> Result<reqwest_websocket::UpgradeResponse, Error>;
}
212
213pub struct DefaultTransport {
214 client: reqwest::Client,
215 url: url::Url,
216}
217
218impl DefaultTransport {
219 pub fn new(url: url::Url) -> Self {
220 return Self {
221 client: reqwest::Client::new(),
222 url,
223 };
224 }
225}
226
227#[async_trait::async_trait]
228impl Transport for DefaultTransport {
229 async fn fetch(
230 &self,
231 path: &str,
232 headers: HeaderMap,
233 method: Method,
234 body: Option<Vec<u8>>,
235 query_params: Option<&[(Cow<'static, str>, Cow<'static, str>)]>,
236 ) -> Result<http::Response<reqwest::Body>, Error> {
237 assert!(path.starts_with("/"));
238
239 let mut url = self.url.clone();
240 url.set_path(path);
241
242 if let Some(query_params) = query_params {
243 let mut params = url.query_pairs_mut();
244 for (key, value) in query_params {
245 params.append_pair(key, value);
246 }
247 }
248
249 let request = {
250 let mut builder = self.client.request(method, url).headers(headers);
251 if let Some(body) = body {
252 builder = builder.body(body);
255 }
256 builder.build()?
257 };
258
259 return Ok(self.client.execute(request).await?.into());
260 }
261
262 #[cfg(feature = "ws")]
263 async fn upgrade_ws(
264 &self,
265 path: &str,
266 headers: HeaderMap,
267 method: Method,
268 query_params: Option<&[(Cow<'static, str>, Cow<'static, str>)]>,
269 ) -> Result<reqwest_websocket::UpgradeResponse, Error> {
270 use reqwest_websocket::Upgrade;
271
272 assert!(path.starts_with("/"));
273
274 let mut url = self.url.clone();
275 url.set_path(path);
276
277 if let Some(query_params) = query_params {
278 let mut params = url.query_pairs_mut();
279 for (key, value) in query_params {
280 params.append_pair(key, value);
281 }
282 }
283
284 return Ok(
285 self
286 .client
287 .request(method, url)
288 .headers(headers)
289 .upgrade()
290 .send()
291 .await?,
292 );
293 }
294}
295
/// Claims decoded from the auth token JWT.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
struct JwtTokenClaims {
  /// Subject; exposed as `User::sub`.
  sub: String,
  /// Issued-at claim.
  iat: i64,
  /// Expiry in seconds since the Unix epoch (compared against `now()`).
  exp: i64,
  /// Exposed as `User::email`.
  email: String,
  csrf_token: String,
}
304
305fn decode_auth_token<T: DeserializeOwned + Clone>(token: &str) -> Result<T, Error> {
306 return jsonwebtoken::dangerous::insecure_decode::<T>(token)
307 .map(|data| data.claims)
308 .map_err(Error::InvalidToken);
309}
310
/// Handle for a single named record API, obtained via `Client::records`.
///
/// Cloning is cheap: clones share the client state through an `Arc`.
#[derive(Clone)]
pub struct RecordApi {
  client: Arc<ClientState>,
  /// API name, used as the path segment after `RECORD_API`.
  name: String,
}
316
/// Arguments for `RecordApi::list`; build with the `with_*` methods.
#[derive(Clone, Debug, Default, PartialEq)]
pub struct ListArguments<'a> {
  pagination: Pagination,
  /// Sent comma-joined as the `order` query parameter (skipped when empty).
  order: Option<Vec<&'a str>>,
  /// Sent as nested `filter[...]` query parameters.
  filters: Option<ValueOrFilterGroup>,
  /// Sent comma-joined as the `expand` query parameter (skipped when empty).
  expand: Option<Vec<&'a str>>,
  /// When true, `count=true` is added to the query.
  count: bool,
}
325
/// Comparison operator of a [`Filter`], rendered as a `$…` / `@…` token in
/// `filter[...]` query parameters. The `St*` variants map to `@within`,
/// `@intersects` and `@contains`, presumably spatial predicates.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum CompareOp {
  Equal,
  NotEqual,
  GreaterThanEqual,
  GreaterThan,
  LessThanEqual,
  LessThan,
  Like,
  Regexp,
  StWithin,
  StIntersects,
  StContains,
}

impl CompareOp {
  /// Returns the query-string token for this operator, e.g. `$eq`.
  fn format(&self) -> &'static str {
    match *self {
      CompareOp::Equal => "$eq",
      CompareOp::NotEqual => "$ne",
      CompareOp::GreaterThanEqual => "$gte",
      CompareOp::GreaterThan => "$gt",
      CompareOp::LessThanEqual => "$lte",
      CompareOp::LessThan => "$lt",
      CompareOp::Like => "$like",
      CompareOp::Regexp => "$re",
      CompareOp::StWithin => "@within",
      CompareOp::StIntersects => "@intersects",
      CompareOp::StContains => "@contains",
    }
  }
}

/// A single column filter. With `op == None` it is sent as `filter[column]=value`.
#[derive(Clone, Default, Debug, PartialEq)]
pub struct Filter {
  pub column: String,
  pub op: Option<CompareOp>,
  pub value: String,
}

impl Filter {
  /// Creates a filter comparing `column` against `value` with `op`.
  pub fn new(column: impl Into<String>, op: CompareOp, value: impl Into<String>) -> Self {
    let column: String = column.into();
    let value: String = value.into();
    Filter {
      column,
      op: Some(op),
      value,
    }
  }
}

/// A filter tree: a single filter, or an AND/OR group of sub-trees.
#[derive(Clone, Debug, PartialEq)]
pub enum ValueOrFilterGroup {
  Filter(Filter),
  And(Vec<ValueOrFilterGroup>),
  Or(Vec<ValueOrFilterGroup>),
}

/// A single filter becomes a leaf of the tree.
impl From<Filter> for ValueOrFilterGroup {
  fn from(value: Filter) -> Self {
    Self::Filter(value)
  }
}

/// A list of filters (e.g. a `Vec` or an array) becomes an AND group.
impl<F> From<F> for ValueOrFilterGroup
where
  F: Into<Vec<Filter>>,
{
  fn from(filters: F) -> Self {
    let filters: Vec<Filter> = filters.into();
    let mut group = Vec::with_capacity(filters.len());
    for filter in filters {
      group.push(Self::Filter(filter));
    }
    Self::And(group)
  }
}
403
404impl<'a> ListArguments<'a> {
405 pub fn new() -> Self {
406 return ListArguments::default();
407 }
408
409 pub fn with_pagination(mut self, pagination: Pagination) -> Self {
410 self.pagination = pagination;
411 return self;
412 }
413
414 pub fn with_order(mut self, order: impl AsRef<[&'a str]>) -> Self {
415 self.order = Some(order.as_ref().to_vec());
416 return self;
417 }
418
419 pub fn with_filters(mut self, filters: impl Into<ValueOrFilterGroup>) -> Self {
420 self.filters = Some(filters.into());
421 return self;
422 }
423
424 pub fn with_expand(mut self, expand: impl AsRef<[&'a str]>) -> Self {
425 self.expand = Some(expand.as_ref().to_vec());
426 return self;
427 }
428
429 pub fn with_count(mut self, count: bool) -> Self {
430 self.count = count;
431 return self;
432 }
433}
434
435impl RecordApi {
436 pub async fn list<T: DeserializeOwned>(
437 &self,
438 args: ListArguments<'_>,
439 ) -> Result<ListResponse<T>, Error> {
440 type Param = (Cow<'static, str>, Cow<'static, str>);
441 let mut params: Vec<Param> = vec![];
442 if let Some(cursor) = args.pagination.cursor {
443 params.push((Cow::Borrowed("cursor"), Cow::Owned(cursor)));
444 }
445
446 if let Some(limit) = args.pagination.limit {
447 params.push((Cow::Borrowed("limit"), Cow::Owned(limit.to_string())));
448 }
449
450 #[inline]
451 fn to_list(slice: &[&str]) -> String {
452 return slice.join(",");
453 }
454
455 if let Some(order) = args.order
456 && !order.is_empty()
457 {
458 params.push((Cow::Borrowed("order"), Cow::Owned(to_list(&order))));
459 }
460
461 if let Some(expand) = args.expand
462 && !expand.is_empty()
463 {
464 params.push((Cow::Borrowed("expand"), Cow::Owned(to_list(&expand))));
465 }
466
467 if args.count {
468 params.push((Cow::Borrowed("count"), Cow::Borrowed("true")));
469 }
470
471 fn traverse_filters(params: &mut Vec<Param>, path: String, filter: ValueOrFilterGroup) {
472 match filter {
473 ValueOrFilterGroup::Filter(filter) => {
474 if let Some(op) = filter.op {
475 params.push((
476 Cow::Owned(format!(
477 "{path}[{col}][{op}]",
478 col = filter.column,
479 op = op.format()
480 )),
481 Cow::Owned(filter.value),
482 ));
483 } else {
484 params.push((
485 Cow::Owned(format!("{path}[{col}]", col = filter.column)),
486 Cow::Owned(filter.value),
487 ));
488 }
489 }
490 ValueOrFilterGroup::And(vec) => {
491 for (i, f) in vec.into_iter().enumerate() {
492 traverse_filters(params, format!("{path}[$and][{i}]"), f);
493 }
494 }
495 ValueOrFilterGroup::Or(vec) => {
496 for (i, f) in vec.into_iter().enumerate() {
497 traverse_filters(params, format!("{path}[$or][{i}]"), f);
498 }
499 }
500 }
501 }
502
503 if let Some(filters) = args.filters {
504 traverse_filters(&mut params, "filter".to_string(), filters);
505 }
506
507 let response = self
508 .client
509 .fetch(
510 &format!("/{RECORD_API}/{}", self.name),
511 Method::GET,
512 None,
513 Some(¶ms),
514 true,
515 )
516 .await?;
517
518 return json(response).await;
519 }
520
521 pub async fn read<'a, T: DeserializeOwned>(
522 &self,
523 args: impl ReadArgumentsTrait<'a>,
524 ) -> Result<T, Error> {
525 let expand = args
526 .expand()
527 .map(|e| vec![(Cow::Borrowed("expand"), Cow::Owned(e.join(",")))]);
528
529 let response = self
530 .client
531 .fetch(
532 &format!(
533 "/{RECORD_API}/{name}/{id}",
534 name = self.name,
535 id = args.serialized_id()
536 ),
537 Method::GET,
538 None,
539 expand.as_deref(),
540 true,
541 )
542 .await?;
543
544 return json(response).await;
545 }
546
547 pub async fn create<T: Serialize>(&self, record: T) -> Result<String, Error> {
548 return Ok(self.create_impl(record).await?.swap_remove(0));
549 }
550
551 pub async fn create_bulk<T: Serialize>(&self, record: &[T]) -> Result<Vec<String>, Error> {
552 return self.create_impl(record).await;
553 }
554
555 async fn create_impl<T: Serialize>(&self, record: T) -> Result<Vec<String>, Error> {
556 let response = self
557 .client
558 .fetch(
559 &format!("/{RECORD_API}/{name}", name = self.name),
560 Method::POST,
561 Some(serde_json::to_vec(&record).map_err(Error::RecordSerialization)?),
562 None,
563 true,
564 )
565 .await?;
566
567 #[derive(Deserialize)]
568 pub struct RecordIdResponse {
569 pub ids: Vec<String>,
570 }
571
572 return Ok(json::<RecordIdResponse>(response).await?.ids);
573 }
574
575 pub async fn update<'a, T: Serialize>(
576 &self,
577 id: impl RecordId<'a>,
578 record: T,
579 ) -> Result<(), Error> {
580 self
581 .client
582 .fetch(
583 &format!(
584 "/{RECORD_API}/{name}/{id}",
585 name = self.name,
586 id = id.serialized_id()
587 ),
588 Method::PATCH,
589 Some(serde_json::to_vec(&record).map_err(Error::RecordSerialization)?),
590 None,
591 true,
592 )
593 .await?;
594
595 return Ok(());
596 }
597
598 pub async fn delete<'a>(&self, id: impl RecordId<'a>) -> Result<(), Error> {
599 self
600 .client
601 .fetch(
602 &format!(
603 "/{RECORD_API}/{name}/{id}",
604 name = self.name,
605 id = id.serialized_id()
606 ),
607 Method::DELETE,
608 None,
609 None,
610 true,
611 )
612 .await?;
613
614 return Ok(());
615 }
616
617 pub async fn subscribe<'a, T: RecordId<'a>>(
618 &self,
619 id: T,
620 ) -> Result<impl Stream<Item = DbEvent> + use<T>, Error> {
621 let response = self
623 .client
624 .fetch(
625 &format!(
626 "/{RECORD_API}/{name}/subscribe/{id}",
627 name = self.name,
628 id = id.serialized_id()
629 ),
630 Method::GET,
631 None,
632 None,
633 true,
634 )
635 .await?;
636
637 return Ok(
638 http_body_util::BodyDataStream::new(response.into_body())
639 .eventsource()
640 .filter_map(|event_or| {
641 if let Ok(event) = event_or {
644 return serde_json::from_str::<DbEvent>(&event.data).ok();
645 }
646 return None;
647 }),
648 );
649 }
650
651 #[cfg(feature = "ws")]
652 pub async fn subscribe_ws<'a, T: RecordId<'a>>(
653 &self,
654 id: T,
655 ) -> Result<impl Stream<Item = DbEvent> + use<T>, Error> {
656 let response = self
657 .client
658 .upgrade_ws(
659 &format!(
660 "/{RECORD_API}/{name}/subscribe/{id}",
661 name = self.name,
662 id = id.serialized_id()
663 ),
664 Method::GET,
665 Some(&[("ws".into(), "true".into())]),
666 )
667 .await?;
668
669 let websocket = response.into_websocket().await?;
670
671 return Ok(websocket.filter_map(|message| {
672 use reqwest_websocket::Message;
673
674 return match message {
675 Ok(Message::Text(msg)) => serde_json::from_str::<DbEvent>(&msg)
676 .map_err(|err| {
677 warn!("json error: {err}");
678 return err;
679 })
680 .ok(),
681 msg => {
682 warn!("unexpected msg: {msg:?}");
683 None
684 }
685 };
686 }));
687 }
688}
689
690#[derive(Clone, Debug)]
691struct TokenState {
692 state: Option<(Tokens, JwtTokenClaims)>,
693 headers: HeaderMap,
694}
695
696impl TokenState {
697 fn build(tokens: Option<&Tokens>) -> TokenState {
698 let headers = build_headers(tokens);
699 return TokenState {
700 state: tokens.and_then(|tokens| {
701 return match decode_auth_token::<JwtTokenClaims>(&tokens.auth_token) {
702 Ok(jwt_token) => Some((tokens.clone(), jwt_token)),
703 Err(err) => {
704 error!("Failed to decode auth token: {err}");
705 None
706 }
707 };
708 }),
709 headers,
710 };
711 }
712}
713
/// State shared (via `Arc`) by a [`Client`] and all its [`RecordApi`] handles.
struct ClientState {
  /// Transport used for all requests.
  transport: Box<dyn Transport + Send + Sync>,
  /// Base URL the client was created with.
  base_url: url::Url,
  /// Current tokens and derived headers; replaced on login/refresh/logout.
  tokens: RwLock<TokenState>,
}
719
720impl ClientState {
721 #[inline]
722 async fn fetch(
723 &self,
724 path: &str,
725 method: Method,
726 body: Option<Vec<u8>>,
727 query_params: Option<&[(Cow<'static, str>, Cow<'static, str>)]>,
728 error_for_status: bool,
729 ) -> Result<http::Response<reqwest::Body>, Error> {
730 let (mut headers, refresh_token) = self.extract_headers_and_refresh_token_if_exp();
731 if let Some(refresh_token) = refresh_token {
732 let new_tokens =
733 ClientState::refresh_tokens_impl(&*self.transport, headers, refresh_token).await?;
734
735 headers = new_tokens.headers.clone();
736 *self.tokens.write() = new_tokens;
737 }
738
739 let response = self
740 .transport
741 .fetch(path, headers, method, body, query_params)
742 .await?;
743
744 if error_for_status {
745 return error_for_status_unpack(response);
746 }
747 return Ok(response);
748 }
749
750 #[cfg(feature = "ws")]
751 #[inline]
752 async fn upgrade_ws(
753 &self,
754 path: &str,
755 method: Method,
756 query_params: Option<&[(Cow<'static, str>, Cow<'static, str>)]>,
757 ) -> Result<reqwest_websocket::UpgradeResponse, Error> {
758 let (mut headers, refresh_token) = self.extract_headers_and_refresh_token_if_exp();
759 if let Some(refresh_token) = refresh_token {
760 let new_tokens =
761 ClientState::refresh_tokens_impl(&*self.transport, headers, refresh_token).await?;
762
763 headers = new_tokens.headers.clone();
764 *self.tokens.write() = new_tokens;
765 }
766
767 return self
768 .client
769 .upgrade_ws(path, headers, method, query_params)
770 .await;
771 }
772
773 #[inline]
774 fn extract_headers_and_refresh_token_if_exp(&self) -> (HeaderMap, Option<String>) {
775 #[inline]
776 fn should_refresh(jwt: &JwtTokenClaims) -> bool {
777 return jwt.exp - 60 < now() as i64;
778 }
779
780 let tokens = self.tokens.read();
781 let headers = tokens.headers.clone();
782 return match tokens.state {
783 Some(ref state) if should_refresh(&state.1) => (headers, state.0.refresh_token.clone()),
784 _ => (headers, None),
785 };
786 }
787
788 fn extract_headers_refresh_token(&self) -> Option<(HeaderMap, String)> {
789 let tokens = self.tokens.read();
790 let state = tokens.state.as_ref()?;
791
792 if let Some(ref refresh_token) = state.0.refresh_token {
793 return Some((tokens.headers.clone(), refresh_token.clone()));
794 }
795 return None;
796 }
797
798 async fn refresh_tokens_impl(
799 transport: &(dyn Transport + Send + Sync),
800 headers: HeaderMap,
801 refresh_token: String,
802 ) -> Result<TokenState, Error> {
803 #[derive(Serialize)]
804 struct RefreshRequest<'a> {
805 refresh_token: &'a str,
806 }
807
808 let response = transport
810 .fetch(
811 &format!("/{AUTH_API}/refresh"),
812 headers,
813 Method::POST,
814 Some(
815 serde_json::to_vec(&RefreshRequest {
816 refresh_token: &refresh_token,
817 })
818 .map_err(Error::RecordSerialization)?,
819 ),
820 None,
821 )
822 .await?;
823
824 #[derive(Deserialize)]
825 struct RefreshResponse {
826 auth_token: String,
827 csrf_token: Option<String>,
828 }
829
830 let refresh_response: RefreshResponse = json(response).await?;
831 return Ok(TokenState::build(Some(&Tokens {
832 auth_token: refresh_response.auth_token,
833 refresh_token: Some(refresh_token),
834 csrf_token: refresh_response.csrf_token,
835 })));
836 }
837}
838
/// Client for the auth and record APIs.
///
/// Cloning is cheap: clones share the same state (including tokens) via `Arc`.
#[derive(Clone)]
pub struct Client {
  state: Arc<ClientState>,
}
843
/// Optional settings for `Client::new`.
#[derive(Default)]
pub struct ClientOptions {
  /// Initial tokens (e.g. restored from a previous session).
  pub tokens: Option<Tokens>,
  /// Custom transport; defaults to [`DefaultTransport`] for the base URL.
  pub transport: Option<Box<dyn Transport + Send + Sync>>,
}
849
850impl Client {
851 pub fn new(
852 base_url: impl TryInto<url::Url, Error = url::ParseError>,
853 opts: Option<ClientOptions>,
854 ) -> Result<Client, Error> {
855 let opts = opts.unwrap_or_default();
856 let base_url = base_url.try_into().map_err(Error::InvalidUrl)?;
857 return Ok(Client {
858 state: Arc::new(ClientState {
859 transport: opts.transport.unwrap_or_else(|| {
860 return Box::new(DefaultTransport::new(base_url.clone()));
861 }),
862 base_url,
863 tokens: RwLock::new(TokenState::build(opts.tokens.as_ref())),
864 }),
865 });
866 }
867
868 pub fn base_url(&self) -> &url::Url {
869 return &self.state.base_url;
870 }
871
872 pub fn tokens(&self) -> Option<Tokens> {
873 return self.state.tokens.read().state.as_ref().map(|x| x.0.clone());
874 }
875
876 pub fn user(&self) -> Option<User> {
877 if let Some(state) = &self.state.tokens.read().state {
878 return Some(User {
879 sub: state.1.sub.clone(),
880 email: state.1.email.clone(),
881 });
882 }
883 return None;
884 }
885
886 pub fn records(&self, api_name: &str) -> RecordApi {
887 return RecordApi {
888 client: self.state.clone(),
889 name: api_name.to_string(),
890 };
891 }
892
893 pub async fn refresh(&self) -> Result<(), Error> {
894 let Some((headers, refresh_token)) = self.state.extract_headers_refresh_token() else {
895 return Ok(());
897 };
898
899 let new_tokens =
900 ClientState::refresh_tokens_impl(&*self.state.transport, headers, refresh_token).await?;
901
902 *self.state.tokens.write() = new_tokens;
903 return Ok(());
904 }
905
906 pub async fn login(
907 &self,
908 email: &str,
909 password: &str,
910 ) -> Result<Option<MultiFactorAuthToken>, Error> {
911 #[derive(Serialize)]
912 struct Credentials<'a> {
913 email: &'a str,
914 password: &'a str,
915 }
916
917 let response = self
918 .state
919 .fetch(
920 &format!("/{AUTH_API}/login"),
921 Method::POST,
922 Some(
923 serde_json::to_vec(&Credentials { email, password })
924 .map_err(Error::RecordSerialization)?,
925 ),
926 None,
927 false,
928 )
929 .await?;
930
931 if response.status() == StatusCode::FORBIDDEN {
932 let mfa_token: MultiFactorAuthToken = json(response).await?;
933 return Ok(Some(mfa_token));
934 }
935
936 let tokens: Tokens = json(error_for_status_unpack(response)?).await?;
937 self.update_tokens(Some(&tokens));
938
939 return Ok(None);
940 }
941
942 pub async fn login_second(
943 &self,
944 mfa_token: &MultiFactorAuthToken,
945 totp_code: &str,
946 ) -> Result<(), Error> {
947 #[derive(Serialize)]
948 struct Credentials<'a> {
949 mfa_token: &'a str,
950 totp: &'a str,
951 }
952
953 let response = self
954 .state
955 .fetch(
956 &format!("/{AUTH_API}/login_mfa"),
957 Method::POST,
958 Some(
959 serde_json::to_vec(&Credentials {
960 mfa_token: &mfa_token.mfa_token,
961 totp: totp_code,
962 })
963 .map_err(Error::RecordSerialization)?,
964 ),
965 None,
966 true,
967 )
968 .await?;
969
970 let tokens: Tokens = json(error_for_status_unpack(response)?).await?;
971 self.update_tokens(Some(&tokens));
972
973 return Ok(());
974 }
975
976 pub async fn request_otp(&self, email: &str, redirect_uri: Option<&str>) -> Result<(), Error> {
977 #[derive(Serialize)]
978 struct Credentials<'a> {
979 email: &'a str,
980 redirect_uri: Option<&'a str>,
981 }
982
983 let _response = self
984 .state
985 .fetch(
986 &format!("/{AUTH_API}/otp/request"),
987 Method::POST,
988 Some(
989 serde_json::to_vec(&Credentials {
990 email,
991 redirect_uri,
992 })
993 .map_err(Error::RecordSerialization)?,
994 ),
995 None,
996 true,
997 )
998 .await?;
999
1000 return Ok(());
1001 }
1002
1003 pub async fn login_otp(&self, email: &str, code: &str) -> Result<(), Error> {
1004 #[derive(Serialize)]
1005 struct Credentials<'a> {
1006 email: &'a str,
1007 code: &'a str,
1008 }
1009
1010 let response = self
1011 .state
1012 .fetch(
1013 &format!("/{AUTH_API}/otp/login"),
1014 Method::POST,
1015 Some(serde_json::to_vec(&Credentials { email, code }).map_err(Error::RecordSerialization)?),
1016 None,
1017 true,
1018 )
1019 .await?;
1020
1021 let tokens: Tokens = json(error_for_status_unpack(response)?).await?;
1022 self.update_tokens(Some(&tokens));
1023
1024 return Ok(());
1025 }
1026
1027 pub async fn logout(&self) -> Result<(), Error> {
1028 #[derive(Serialize)]
1029 struct LogoutRequest {
1030 refresh_token: String,
1031 }
1032
1033 let response_or = match self.state.extract_headers_refresh_token() {
1034 Some((_headers, refresh_token)) => {
1035 self
1036 .state
1037 .fetch(
1038 &format!("/{AUTH_API}/logout"),
1039 Method::POST,
1040 Some(
1041 serde_json::to_vec(&LogoutRequest { refresh_token })
1042 .map_err(Error::RecordSerialization)?,
1043 ),
1044 None,
1045 true,
1046 )
1047 .await
1048 }
1049 _ => {
1050 self
1051 .state
1052 .fetch(
1053 &format!("/{AUTH_API}/logout"),
1054 Method::GET,
1055 None,
1056 None,
1057 true,
1058 )
1059 .await
1060 }
1061 };
1062
1063 self.update_tokens(None);
1064
1065 return response_or.map(|_| ());
1066 }
1067
1068 fn update_tokens(&self, tokens: Option<&Tokens>) -> TokenState {
1069 let state = TokenState::build(tokens);
1070
1071 *self.state.tokens.write() = state.clone();
1072 if let Some(ref s) = state.state {
1075 let now = now();
1076 if s.1.exp < now as i64 {
1077 warn!("Token expired");
1078 }
1079 }
1080
1081 return state;
1082 }
1083}
1084
1085fn build_headers(tokens: Option<&Tokens>) -> HeaderMap {
1086 let mut base = HeaderMap::with_capacity(5);
1087 base.insert(
1088 header::CONTENT_TYPE,
1089 HeaderValue::from_static("application/json"),
1090 );
1091
1092 if let Some(tokens) = tokens {
1093 if let Ok(value) = HeaderValue::from_str(&format!("Bearer {}", tokens.auth_token)) {
1094 base.insert(header::AUTHORIZATION, value);
1095 } else {
1096 error!("Failed to build bearer token.");
1097 }
1098
1099 if let Some(ref refresh) = tokens.refresh_token {
1100 if let Ok(value) = HeaderValue::from_str(refresh) {
1101 base.insert("Refresh-Token", value);
1102 } else {
1103 error!("Failed to build refresh token header.");
1104 }
1105 }
1106
1107 if let Some(ref csrf) = tokens.csrf_token {
1108 if let Ok(value) = HeaderValue::from_str(csrf) {
1109 base.insert("CSRF-Token", value);
1110 } else {
1111 error!("Failed to build refresh token header.");
1112 }
1113 }
1114 }
1115
1116 return base;
1117}
1118
/// Current Unix time in whole seconds.
///
/// # Panics
///
/// Panics if the system clock is set before the Unix epoch.
fn now() -> u64 {
  use std::time::{SystemTime, UNIX_EPOCH};

  let since_epoch = SystemTime::now()
    .duration_since(UNIX_EPOCH)
    .expect("Duration since epoch");
  since_epoch.as_secs()
}
1125
1126#[inline]
1127async fn json<T: DeserializeOwned>(resp: http::Response<reqwest::Body>) -> Result<T, Error> {
1128 let full = into_bytes(resp).await?;
1129 return serde_json::from_slice(&full).map_err(Error::RecordSerialization);
1130}
1131
1132#[inline]
1133async fn into_bytes(resp: http::Response<reqwest::Body>) -> Result<bytes::Bytes, Error> {
1134 return Ok(
1135 http_body_util::BodyExt::collect(resp.into_body())
1136 .await
1137 .map(|buf| buf.to_bytes())?,
1138 );
1139}
1140
1141fn error_for_status_unpack(
1142 resp: http::Response<reqwest::Body>,
1143) -> Result<http::Response<reqwest::Body>, Error> {
1144 let status = resp.status();
1145 if status.is_client_error() || status.is_server_error() {
1146 return Err(Error::HttpStatus(status));
1147 }
1148 return Ok(resp);
1149}
1150
/// Path prefix of the auth API (without a leading `/`; callers prepend it).
const AUTH_API: &str = "api/auth/v1";
/// Path prefix of the record API (without a leading `/`; callers prepend it).
const RECORD_API: &str = "api/records/v1";
1153
#[cfg(test)]
mod tests {
  use super::*;

  /// Spawning `RecordApi::read` on tokio tasks only compiles if its future is
  /// `Send`. The read itself is expected to fail; presumably nothing listens
  /// on 127.0.0.1:4000 during tests.
  #[tokio::test]
  async fn is_send_test() {
    let client = Client::new("http://127.0.0.1:4000", None).unwrap();

    let api = client.records("simple_strict_table");

    for _ in 0..2 {
      let api = api.clone();
      tokio::spawn(async move {
        let response = api.read::<serde_json::Value>(0).await;
        assert!(response.is_err());
      })
      .await
      .unwrap();
    }
  }
}