1#![forbid(unsafe_code, clippy::unwrap_used)]
7#![allow(clippy::needless_return)]
8#![warn(clippy::await_holding_lock, clippy::inefficient_to_string)]
9
10use eventsource_stream::Eventsource;
11use futures_lite::StreamExt;
12use parking_lot::RwLock;
13use reqwest::header::{self, HeaderMap, HeaderValue};
14use reqwest::{Method, StatusCode};
15use serde::de::DeserializeOwned;
16use serde::{Deserialize, Serialize};
17use std::borrow::Cow;
18use std::sync::Arc;
19use thiserror::Error;
20use tracing::*;
21
22pub use futures_lite::Stream;
23
24#[derive(Debug, Error)]
25#[non_exhaustive]
26pub enum Error {
27 #[error("HTTP status: {0}")]
28 HttpStatus(StatusCode),
29
30 #[error("MissingRefreshToken")]
31 MissingRefreshToken,
32
33 #[error("RecordSerialization: {0}")]
34 RecordSerialization(serde_json::Error),
35
36 #[error("InvalidToken: {0}")]
37 InvalidToken(jsonwebtoken::errors::Error),
38
39 #[error("InvalidUrl: {0}")]
40 InvalidUrl(url::ParseError),
41
42 #[error("Reqwest: {0}")]
44 OtherReqwest(reqwest::Error),
45
46 #[cfg(feature = "ws")]
47 #[error("WebSocket: {0}")]
48 WebSocket(#[from] reqwest_websocket::Error),
49}
50
51impl From<reqwest::Error> for Error {
52 fn from(err: reqwest::Error) -> Self {
53 match err.status() {
54 Some(code) => Self::HttpStatus(code),
55 _ => Self::OtherReqwest(err),
56 }
57 }
58}
59
/// Identity of the logged-in user, taken from the `sub` and `email` claims of the auth token.
#[derive(Clone, Debug)]
pub struct User {
  /// Subject claim (`sub`) of the auth token.
  pub sub: String,
  /// E-mail claim of the auth token.
  pub email: String,
}
66
67#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
71pub struct Tokens {
72 pub auth_token: String,
73 pub refresh_token: Option<String>,
74 pub csrf_token: Option<String>,
75}
76
77#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
78pub struct MultiFactorAuthToken {
79 mfa_token: String,
80}
81
/// Pagination options for listing records, built with the `with_*` builder methods.
#[derive(Clone, Debug, Default, PartialEq)]
pub struct Pagination {
  cursor: Option<String>,
  limit: Option<usize>,
  offset: Option<usize>,
}

impl Pagination {
  /// Creates empty pagination options (no cursor, limit or offset).
  pub fn new() -> Self {
    return Pagination::default();
  }

  /// Sets (or clears, with `None`) the maximum number of records to return.
  pub fn with_limit(self, limit: impl Into<Option<usize>>) -> Pagination {
    return Pagination {
      limit: limit.into(),
      ..self
    };
  }

  /// Sets (or clears, with `None`) the cursor to continue from.
  pub fn with_cursor(self, cursor: impl Into<Option<String>>) -> Pagination {
    return Pagination {
      cursor: cursor.into(),
      ..self
    };
  }

  /// Sets (or clears, with `None`) the number of records to skip.
  pub fn with_offset(self, offset: impl Into<Option<usize>>) -> Pagination {
    return Pagination {
      offset: offset.into(),
      ..self
    };
  }
}
109
110#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
111pub enum DbEvent {
112 Update(Option<serde_json::Value>),
113 Insert(Option<serde_json::Value>),
114 Delete(Option<serde_json::Value>),
115 Error(String),
116}
117
118#[derive(Clone, Debug, Deserialize)]
119pub struct ListResponse<T> {
120 pub cursor: Option<String>,
121 pub total_count: Option<usize>,
122 pub records: Vec<T>,
123}
124
/// Values usable as a record id in URL paths, converted to their string form.
pub trait RecordId<'a> {
  fn serialized_id(self) -> Cow<'a, str>;
}

/// Owned strings are moved into the result without copying.
impl RecordId<'_> for String {
  fn serialized_id(self) -> Cow<'static, str> {
    return Cow::from(self);
  }
}

/// Borrowed strings are passed through without allocating.
impl<'a> RecordId<'a> for &'a String {
  fn serialized_id(self) -> Cow<'a, str> {
    return Cow::from(self);
  }
}

impl<'a> RecordId<'a> for &'a str {
  fn serialized_id(self) -> Cow<'a, str> {
    return Cow::from(self);
  }
}

/// Integer ids are rendered in decimal.
impl RecordId<'_> for i64 {
  fn serialized_id(self) -> Cow<'static, str> {
    let rendered = self.to_string();
    return Cow::from(rendered);
  }
}
152
153pub trait ReadArgumentsTrait<'a> {
154 fn serialized_id(self) -> Cow<'a, str>;
155 fn expand(&self) -> Option<&Vec<&'a str>>;
156}
157
158impl<'a, T: RecordId<'a>> ReadArgumentsTrait<'a> for T {
159 fn serialized_id(self) -> Cow<'a, str> {
160 return self.serialized_id();
161 }
162
163 fn expand(&self) -> Option<&Vec<&'a str>> {
164 return None;
165 }
166}
167
168#[derive(Clone, Debug, PartialEq)]
169pub struct ReadArguments<'a, T: RecordId<'a>> {
170 id: T,
171 expand: Option<Vec<&'a str>>,
172}
173
174impl<'a, T: RecordId<'a>> ReadArguments<'a, T> {
175 pub fn new(id: T) -> Self {
176 return Self { id, expand: None };
177 }
178
179 pub fn with_expand(mut self, expand: impl AsRef<[&'a str]>) -> Self {
180 self.expand = Some(expand.as_ref().to_vec());
181 return self;
182 }
183}
184
185impl<'a, T: RecordId<'a>> ReadArgumentsTrait<'a> for ReadArguments<'a, T> {
186 fn serialized_id(self) -> Cow<'a, str> {
187 return self.id.serialized_id();
188 }
189
190 fn expand(&self) -> Option<&Vec<&'a str>> {
191 return self.expand.as_ref();
192 }
193}
194
195struct ThinClient {
196 client: reqwest::Client,
197 url: url::Url,
198}
199
200impl ThinClient {
201 async fn fetch_impl<T: Serialize>(
202 &self,
203 path: &str,
204 headers: HeaderMap,
205 method: Method,
206 body: Option<&T>,
207 query_params: Option<&[(Cow<'static, str>, Cow<'static, str>)]>,
208 ) -> Result<reqwest::Response, Error> {
209 assert!(path.starts_with("/"));
210
211 let mut url = self.url.clone();
212 url.set_path(path);
213
214 if let Some(query_params) = query_params {
215 let mut params = url.query_pairs_mut();
216 for (key, value) in query_params {
217 params.append_pair(key, value);
218 }
219 }
220
221 let request = {
222 let mut builder = self.client.request(method, url).headers(headers);
223 if let Some(ref body) = body {
224 let json = serde_json::to_string(body).map_err(Error::RecordSerialization)?;
225 builder = builder.body(json);
226 }
227 builder.build()?
228 };
229
230 return Ok(self.client.execute(request).await?);
231 }
232
233 #[cfg(feature = "ws")]
234 async fn upgrade_ws_impl<T: Serialize>(
235 &self,
236 path: &str,
237 headers: HeaderMap,
238 method: Method,
239 body: Option<&T>,
240 query_params: Option<&[(Cow<'static, str>, Cow<'static, str>)]>,
241 ) -> Result<reqwest_websocket::UpgradeResponse, Error> {
242 use reqwest_websocket::Upgrade;
243
244 assert!(path.starts_with("/"));
245
246 let mut url = self.url.clone();
247 url.set_path(path);
248
249 if let Some(query_params) = query_params {
250 let mut params = url.query_pairs_mut();
251 for (key, value) in query_params {
252 params.append_pair(key, value);
253 }
254 }
255
256 let request = {
257 let mut builder = self.client.request(method, url).headers(headers);
258 if let Some(ref body) = body {
259 let json = serde_json::to_string(body).map_err(Error::RecordSerialization)?;
260 builder = builder.body(json);
261 }
262 builder.upgrade()
263 };
264
265 return Ok(request.send().await?);
266 }
267}
268
269#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
270struct JwtTokenClaims {
271 sub: String,
272 iat: i64,
273 exp: i64,
274 email: String,
275 csrf_token: String,
276}
277
278fn decode_auth_token<T: DeserializeOwned + Clone>(token: &str) -> Result<T, Error> {
279 return jsonwebtoken::dangerous::insecure_decode::<T>(token)
280 .map(|data| data.claims)
281 .map_err(Error::InvalidToken);
282}
283
284#[derive(Clone)]
285pub struct RecordApi {
286 client: Arc<ClientState>,
287 name: String,
288}
289
290#[derive(Clone, Debug, Default, PartialEq)]
291pub struct ListArguments<'a> {
292 pagination: Pagination,
293 order: Option<Vec<&'a str>>,
294 filters: Option<ValueOrFilterGroup>,
295 expand: Option<Vec<&'a str>>,
296 count: bool,
297}
298
/// Comparison operators available in record filters.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum CompareOp {
  Equal,
  NotEqual,
  GreaterThanEqual,
  GreaterThan,
  LessThanEqual,
  LessThan,
  Like,
  Regexp,
  StWithin,
  StIntersects,
  StContains,
}

impl CompareOp {
  /// Returns the operator token used inside `filter[...]` query parameters.
  fn format(&self) -> &'static str {
    let token = match *self {
      CompareOp::Equal => "$eq",
      CompareOp::NotEqual => "$ne",
      CompareOp::GreaterThanEqual => "$gte",
      CompareOp::GreaterThan => "$gt",
      CompareOp::LessThanEqual => "$lte",
      CompareOp::LessThan => "$lt",
      CompareOp::Like => "$like",
      CompareOp::Regexp => "$re",
      CompareOp::StWithin => "@within",
      CompareOp::StIntersects => "@intersects",
      CompareOp::StContains => "@contains",
    };
    return token;
  }
}

/// A single column filter. With `op == None` the filter is encoded without an operator
/// (`filter[<column>]=<value>`).
#[derive(Clone, Default, Debug, PartialEq)]
pub struct Filter {
  pub column: String,
  pub op: Option<CompareOp>,
  pub value: String,
}

impl Filter {
  /// Creates a filter comparing `column` against `value` using `op`.
  pub fn new(column: impl Into<String>, op: CompareOp, value: impl Into<String>) -> Self {
    let column = column.into();
    let value = value.into();
    return Filter {
      column,
      op: Some(op),
      value,
    };
  }
}

/// A filter tree: single filters combined with nested `And` / `Or` groups.
#[derive(Clone, Debug, PartialEq)]
pub enum ValueOrFilterGroup {
  Filter(Filter),
  And(Vec<ValueOrFilterGroup>),
  Or(Vec<ValueOrFilterGroup>),
}

/// A lone filter becomes a leaf of the tree.
impl From<Filter> for ValueOrFilterGroup {
  fn from(value: Filter) -> Self {
    return Self::Filter(value);
  }
}

/// A list of filters is combined with `And`.
impl<F> From<F> for ValueOrFilterGroup
where
  F: Into<Vec<Filter>>,
{
  fn from(filters: F) -> Self {
    let filters: Vec<Filter> = filters.into();
    let leaves: Vec<ValueOrFilterGroup> = filters
      .into_iter()
      .map(|filter| ValueOrFilterGroup::Filter(filter))
      .collect();
    return Self::And(leaves);
  }
}
376
377impl<'a> ListArguments<'a> {
378 pub fn new() -> Self {
379 return ListArguments::default();
380 }
381
382 pub fn with_pagination(mut self, pagination: Pagination) -> Self {
383 self.pagination = pagination;
384 return self;
385 }
386
387 pub fn with_order(mut self, order: impl AsRef<[&'a str]>) -> Self {
388 self.order = Some(order.as_ref().to_vec());
389 return self;
390 }
391
392 pub fn with_filters(mut self, filters: impl Into<ValueOrFilterGroup>) -> Self {
393 self.filters = Some(filters.into());
394 return self;
395 }
396
397 pub fn with_expand(mut self, expand: impl AsRef<[&'a str]>) -> Self {
398 self.expand = Some(expand.as_ref().to_vec());
399 return self;
400 }
401
402 pub fn with_count(mut self, count: bool) -> Self {
403 self.count = count;
404 return self;
405 }
406}
407
408impl RecordApi {
409 pub async fn list<T: DeserializeOwned>(
410 &self,
411 args: ListArguments<'_>,
412 ) -> Result<ListResponse<T>, Error> {
413 type Param = (Cow<'static, str>, Cow<'static, str>);
414 let mut params: Vec<Param> = vec![];
415 if let Some(cursor) = args.pagination.cursor {
416 params.push((Cow::Borrowed("cursor"), Cow::Owned(cursor)));
417 }
418
419 if let Some(limit) = args.pagination.limit {
420 params.push((Cow::Borrowed("limit"), Cow::Owned(limit.to_string())));
421 }
422
423 #[inline]
424 fn to_list(slice: &[&str]) -> String {
425 return slice.join(",");
426 }
427
428 if let Some(order) = args.order
429 && !order.is_empty()
430 {
431 params.push((Cow::Borrowed("order"), Cow::Owned(to_list(&order))));
432 }
433
434 if let Some(expand) = args.expand
435 && !expand.is_empty()
436 {
437 params.push((Cow::Borrowed("expand"), Cow::Owned(to_list(&expand))));
438 }
439
440 if args.count {
441 params.push((Cow::Borrowed("count"), Cow::Borrowed("true")));
442 }
443
444 fn traverse_filters(params: &mut Vec<Param>, path: String, filter: ValueOrFilterGroup) {
445 match filter {
446 ValueOrFilterGroup::Filter(filter) => {
447 if let Some(op) = filter.op {
448 params.push((
449 Cow::Owned(format!(
450 "{path}[{col}][{op}]",
451 col = filter.column,
452 op = op.format()
453 )),
454 Cow::Owned(filter.value),
455 ));
456 } else {
457 params.push((
458 Cow::Owned(format!("{path}[{col}]", col = filter.column)),
459 Cow::Owned(filter.value),
460 ));
461 }
462 }
463 ValueOrFilterGroup::And(vec) => {
464 for (i, f) in vec.into_iter().enumerate() {
465 traverse_filters(params, format!("{path}[$and][{i}]"), f);
466 }
467 }
468 ValueOrFilterGroup::Or(vec) => {
469 for (i, f) in vec.into_iter().enumerate() {
470 traverse_filters(params, format!("{path}[$or][{i}]"), f);
471 }
472 }
473 }
474 }
475
476 if let Some(filters) = args.filters {
477 traverse_filters(&mut params, "filter".to_string(), filters);
478 }
479
480 let response = self
481 .client
482 .fetch(
483 &format!("/{RECORD_API}/{}", self.name),
484 Method::GET,
485 None::<&()>,
486 Some(¶ms),
487 true,
488 )
489 .await?;
490
491 return json(response).await;
492 }
493
494 pub async fn read<'a, T: DeserializeOwned>(
495 &self,
496 args: impl ReadArgumentsTrait<'a>,
497 ) -> Result<T, Error> {
498 let expand = args
499 .expand()
500 .map(|e| vec![(Cow::Borrowed("expand"), Cow::Owned(e.join(",")))]);
501
502 let response = self
503 .client
504 .fetch(
505 &format!(
506 "/{RECORD_API}/{name}/{id}",
507 name = self.name,
508 id = args.serialized_id()
509 ),
510 Method::GET,
511 None::<&()>,
512 expand.as_deref(),
513 true,
514 )
515 .await?;
516
517 return json(response).await;
518 }
519
520 pub async fn create<T: Serialize>(&self, record: T) -> Result<String, Error> {
521 return Ok(self.create_impl(record).await?.swap_remove(0));
522 }
523
524 pub async fn create_bulk<T: Serialize>(&self, record: &[T]) -> Result<Vec<String>, Error> {
525 return self.create_impl(record).await;
526 }
527
528 async fn create_impl<T: Serialize>(&self, record: T) -> Result<Vec<String>, Error> {
529 let response = self
530 .client
531 .fetch(
532 &format!("/{RECORD_API}/{name}", name = self.name),
533 Method::POST,
534 Some(&record),
535 None,
536 true,
537 )
538 .await?;
539
540 #[derive(Deserialize)]
541 pub struct RecordIdResponse {
542 pub ids: Vec<String>,
543 }
544
545 return Ok(json::<RecordIdResponse>(response).await?.ids);
546 }
547
548 pub async fn update<'a, T: Serialize>(
549 &self,
550 id: impl RecordId<'a>,
551 record: T,
552 ) -> Result<(), Error> {
553 self
554 .client
555 .fetch(
556 &format!(
557 "/{RECORD_API}/{name}/{id}",
558 name = self.name,
559 id = id.serialized_id()
560 ),
561 Method::PATCH,
562 Some(&record),
563 None,
564 true,
565 )
566 .await?;
567
568 return Ok(());
569 }
570
571 pub async fn delete<'a>(&self, id: impl RecordId<'a>) -> Result<(), Error> {
572 self
573 .client
574 .fetch(
575 &format!(
576 "/{RECORD_API}/{name}/{id}",
577 name = self.name,
578 id = id.serialized_id()
579 ),
580 Method::DELETE,
581 None::<&()>,
582 None,
583 true,
584 )
585 .await?;
586
587 return Ok(());
588 }
589
590 pub async fn subscribe<'a, T: RecordId<'a>>(
591 &self,
592 id: T,
593 ) -> Result<impl Stream<Item = DbEvent> + use<T>, Error> {
594 let response = self
596 .client
597 .fetch(
598 &format!(
599 "/{RECORD_API}/{name}/subscribe/{id}",
600 name = self.name,
601 id = id.serialized_id()
602 ),
603 Method::GET,
604 None::<&()>,
605 None,
606 true,
607 )
608 .await?;
609
610 return Ok(
611 response
612 .bytes_stream()
613 .eventsource()
614 .filter_map(|event_or| {
615 if let Ok(event) = event_or {
618 return serde_json::from_str::<DbEvent>(&event.data).ok();
619 }
620 return None;
621 }),
622 );
623 }
624
625 #[cfg(feature = "ws")]
626 pub async fn subscribe_ws<'a, T: RecordId<'a>>(
627 &self,
628 id: T,
629 ) -> Result<impl Stream<Item = DbEvent> + use<T>, Error> {
630 let response = self
631 .client
632 .upgrade_ws(
633 &format!(
634 "/{RECORD_API}/{name}/subscribe/{id}",
635 name = self.name,
636 id = id.serialized_id()
637 ),
638 Method::GET,
639 None::<&()>,
640 Some(&[("ws".into(), "true".into())]),
641 )
642 .await?;
643
644 let websocket = response.into_websocket().await?;
645
646 return Ok(websocket.filter_map(|message| {
647 use reqwest_websocket::Message;
648
649 return match message {
650 Ok(Message::Text(msg)) => serde_json::from_str::<DbEvent>(&msg)
651 .map_err(|err| {
652 warn!("json error: {err}");
653 return err;
654 })
655 .ok(),
656 msg => {
657 warn!("unexpected msg: {msg:?}");
658 None
659 }
660 };
661 }));
662 }
663}
664
665#[derive(Clone, Debug)]
666struct TokenState {
667 state: Option<(Tokens, JwtTokenClaims)>,
668 headers: HeaderMap,
669}
670
671impl TokenState {
672 fn build(tokens: Option<&Tokens>) -> TokenState {
673 let headers = build_headers(tokens);
674 return TokenState {
675 state: tokens.and_then(|tokens| {
676 return match decode_auth_token::<JwtTokenClaims>(&tokens.auth_token) {
677 Ok(jwt_token) => Some((tokens.clone(), jwt_token)),
678 Err(err) => {
679 error!("Failed to decode auth token: {err}");
680 None
681 }
682 };
683 }),
684 headers,
685 };
686 }
687}
688
689struct ClientState {
690 client: ThinClient,
691 site: String,
692 tokens: RwLock<TokenState>,
693}
694
695impl ClientState {
696 #[inline]
697 async fn fetch<T: Serialize>(
698 &self,
699 path: &str,
700 method: Method,
701 body: Option<&T>,
702 query_params: Option<&[(Cow<'static, str>, Cow<'static, str>)]>,
703 error_for_status: bool,
704 ) -> Result<reqwest::Response, Error> {
705 let (mut headers, refresh_token) = self.extract_headers_and_refresh_token_if_exp();
706 if let Some(refresh_token) = refresh_token {
707 let new_tokens = ClientState::refresh_tokens(&self.client, headers, refresh_token).await?;
708
709 headers = new_tokens.headers.clone();
710 *self.tokens.write() = new_tokens;
711 }
712
713 let response = self
714 .client
715 .fetch_impl(path, headers, method, body, query_params)
716 .await?;
717
718 if error_for_status {
719 return Ok(response.error_for_status()?);
720 }
721 return Ok(response);
722 }
723
724 #[cfg(feature = "ws")]
725 #[inline]
726 async fn upgrade_ws<T: Serialize>(
727 &self,
728 path: &str,
729 method: Method,
730 body: Option<&T>,
731 query_params: Option<&[(Cow<'static, str>, Cow<'static, str>)]>,
732 ) -> Result<reqwest_websocket::UpgradeResponse, Error> {
733 let (mut headers, refresh_token) = self.extract_headers_and_refresh_token_if_exp();
734 if let Some(refresh_token) = refresh_token {
735 let new_tokens = ClientState::refresh_tokens(&self.client, headers, refresh_token).await?;
736
737 headers = new_tokens.headers.clone();
738 *self.tokens.write() = new_tokens;
739 }
740
741 return self
742 .client
743 .upgrade_ws_impl(path, headers, method, body, query_params)
744 .await;
745 }
746
747 #[inline]
748 fn extract_headers_and_refresh_token_if_exp(&self) -> (HeaderMap, Option<String>) {
749 #[inline]
750 fn should_refresh(jwt: &JwtTokenClaims) -> bool {
751 return jwt.exp - 60 < now() as i64;
752 }
753
754 let tokens = self.tokens.read();
755 let headers = tokens.headers.clone();
756 return match tokens.state {
757 Some(ref state) if should_refresh(&state.1) => (headers, state.0.refresh_token.clone()),
758 _ => (headers, None),
759 };
760 }
761
762 fn extract_headers_refresh_token(&self) -> Option<(HeaderMap, String)> {
763 let tokens = self.tokens.read();
764 let state = tokens.state.as_ref()?;
765
766 if let Some(ref refresh_token) = state.0.refresh_token {
767 return Some((tokens.headers.clone(), refresh_token.clone()));
768 }
769 return None;
770 }
771
772 async fn refresh_tokens(
773 client: &ThinClient,
774 headers: HeaderMap,
775 refresh_token: String,
776 ) -> Result<TokenState, Error> {
777 #[derive(Serialize)]
778 struct RefreshRequest<'a> {
779 refresh_token: &'a str,
780 }
781
782 let response = client
783 .fetch_impl(
784 &format!("/{AUTH_API}/refresh"),
785 headers,
786 Method::POST,
787 Some(&RefreshRequest {
788 refresh_token: &refresh_token,
789 }),
790 None,
791 )
792 .await?;
793
794 #[derive(Deserialize)]
795 struct RefreshResponse {
796 auth_token: String,
797 csrf_token: Option<String>,
798 }
799
800 let refresh_response: RefreshResponse = json(response).await?;
801 return Ok(TokenState::build(Some(&Tokens {
802 auth_token: refresh_response.auth_token,
803 refresh_token: Some(refresh_token),
804 csrf_token: refresh_response.csrf_token,
805 })));
806 }
807}
808
809#[derive(Clone)]
810pub struct Client {
811 state: Arc<ClientState>,
812}
813
814impl Client {
815 pub fn new(site: &str, tokens: Option<Tokens>) -> Result<Client, Error> {
816 return Ok(Client {
817 state: Arc::new(ClientState {
818 client: ThinClient {
819 client: reqwest::Client::new(),
820 url: url::Url::parse(site).map_err(Error::InvalidUrl)?,
821 },
822 site: site.to_string(),
823 tokens: RwLock::new(TokenState::build(tokens.as_ref())),
824 }),
825 });
826 }
827
828 pub fn site(&self) -> String {
829 return self.state.site.clone();
830 }
831
832 pub fn tokens(&self) -> Option<Tokens> {
833 return self.state.tokens.read().state.as_ref().map(|x| x.0.clone());
834 }
835
836 pub fn user(&self) -> Option<User> {
837 if let Some(state) = &self.state.tokens.read().state {
838 return Some(User {
839 sub: state.1.sub.clone(),
840 email: state.1.email.clone(),
841 });
842 }
843 return None;
844 }
845
846 pub fn records(&self, api_name: &str) -> RecordApi {
847 return RecordApi {
848 client: self.state.clone(),
849 name: api_name.to_string(),
850 };
851 }
852
853 pub async fn refresh(&self) -> Result<(), Error> {
854 let Some((headers, refresh_token)) = self.state.extract_headers_refresh_token() else {
855 return Err(Error::MissingRefreshToken);
856 };
857
858 let new_tokens =
859 ClientState::refresh_tokens(&self.state.client, headers, refresh_token).await?;
860
861 *self.state.tokens.write() = new_tokens;
862 return Ok(());
863 }
864
865 pub async fn login(
866 &self,
867 email: &str,
868 password: &str,
869 ) -> Result<Option<MultiFactorAuthToken>, Error> {
870 #[derive(Serialize)]
871 struct Credentials<'a> {
872 email: &'a str,
873 password: &'a str,
874 }
875
876 let response = self
877 .state
878 .fetch(
879 &format!("/{AUTH_API}/login"),
880 Method::POST,
881 Some(&Credentials { email, password }),
882 None,
883 false,
884 )
885 .await?;
886
887 if response.status() == StatusCode::FORBIDDEN {
888 let mfa_token: MultiFactorAuthToken = json(response).await?;
889 return Ok(Some(mfa_token));
890 }
891
892 let tokens: Tokens = json(response.error_for_status()?).await?;
893 self.update_tokens(Some(&tokens));
894
895 return Ok(None);
896 }
897
898 pub async fn login_second(
899 &self,
900 mfa_token: &MultiFactorAuthToken,
901 totp_code: &str,
902 ) -> Result<(), Error> {
903 #[derive(Serialize)]
904 struct Credentials<'a> {
905 mfa_token: &'a str,
906 totp: &'a str,
907 }
908
909 let response = self
910 .state
911 .fetch(
912 &format!("/{AUTH_API}/login_mfa"),
913 Method::POST,
914 Some(&Credentials {
915 mfa_token: &mfa_token.mfa_token,
916 totp: totp_code,
917 }),
918 None,
919 true,
920 )
921 .await?;
922
923 let tokens: Tokens = json(response.error_for_status()?).await?;
924 self.update_tokens(Some(&tokens));
925
926 return Ok(());
927 }
928
929 pub async fn request_otp(&self, email: &str, redirect_uri: Option<&str>) -> Result<(), Error> {
930 #[derive(Serialize)]
931 struct Credentials<'a> {
932 email: &'a str,
933 redirect_uri: Option<&'a str>,
934 }
935
936 let _response = self
937 .state
938 .fetch(
939 &format!("/{AUTH_API}/otp/request"),
940 Method::POST,
941 Some(&Credentials {
942 email,
943 redirect_uri,
944 }),
945 None,
946 true,
947 )
948 .await?;
949
950 return Ok(());
951 }
952
953 pub async fn login_otp(&self, email: &str, code: &str) -> Result<(), Error> {
954 #[derive(Serialize)]
955 struct Credentials<'a> {
956 email: &'a str,
957 code: &'a str,
958 }
959
960 let response = self
961 .state
962 .fetch(
963 &format!("/{AUTH_API}/otp/login"),
964 Method::POST,
965 Some(&Credentials { email, code }),
966 None,
967 true,
968 )
969 .await?;
970
971 let tokens: Tokens = json(response.error_for_status()?).await?;
972 self.update_tokens(Some(&tokens));
973
974 return Ok(());
975 }
976
977 pub async fn logout(&self) -> Result<(), Error> {
978 #[derive(Serialize)]
979 struct LogoutRequest {
980 refresh_token: String,
981 }
982
983 let response_or = match self.state.extract_headers_refresh_token() {
984 Some((_headers, refresh_token)) => {
985 self
986 .state
987 .fetch(
988 &format!("/{AUTH_API}/logout"),
989 Method::POST,
990 Some(&LogoutRequest { refresh_token }),
991 None,
992 true,
993 )
994 .await
995 }
996 _ => {
997 self
998 .state
999 .fetch(
1000 &format!("/{AUTH_API}/logout"),
1001 Method::GET,
1002 None::<&()>,
1003 None,
1004 true,
1005 )
1006 .await
1007 }
1008 };
1009
1010 self.update_tokens(None);
1011
1012 return response_or.map(|_| ());
1013 }
1014
1015 fn update_tokens(&self, tokens: Option<&Tokens>) -> TokenState {
1016 let state = TokenState::build(tokens);
1017
1018 *self.state.tokens.write() = state.clone();
1019 if let Some(ref s) = state.state {
1022 let now = now();
1023 if s.1.exp < now as i64 {
1024 warn!("Token expired");
1025 }
1026 }
1027
1028 return state;
1029 }
1030}
1031
1032fn build_headers(tokens: Option<&Tokens>) -> HeaderMap {
1033 let mut base = HeaderMap::with_capacity(5);
1034 base.insert(
1035 header::CONTENT_TYPE,
1036 HeaderValue::from_static("application/json"),
1037 );
1038
1039 if let Some(tokens) = tokens {
1040 if let Ok(value) = HeaderValue::from_str(&format!("Bearer {}", tokens.auth_token)) {
1041 base.insert(header::AUTHORIZATION, value);
1042 } else {
1043 error!("Failed to build bearer token.");
1044 }
1045
1046 if let Some(ref refresh) = tokens.refresh_token {
1047 if let Ok(value) = HeaderValue::from_str(refresh) {
1048 base.insert("Refresh-Token", value);
1049 } else {
1050 error!("Failed to build refresh token header.");
1051 }
1052 }
1053
1054 if let Some(ref csrf) = tokens.csrf_token {
1055 if let Ok(value) = HeaderValue::from_str(csrf) {
1056 base.insert("CSRF-Token", value);
1057 } else {
1058 error!("Failed to build refresh token header.");
1059 }
1060 }
1061 }
1062
1063 return base;
1064}
1065
/// Returns the current Unix time in whole seconds.
///
/// If the system clock reports a time before the Unix epoch, this returns 0 instead of
/// panicking; a misconfigured clock should not crash library users.
fn now() -> u64 {
  return std::time::SystemTime::now()
    .duration_since(std::time::UNIX_EPOCH)
    .map_or(0, |elapsed| elapsed.as_secs());
}
1072
1073#[inline]
1074async fn json<T: DeserializeOwned>(resp: reqwest::Response) -> Result<T, Error> {
1075 let full = resp.bytes().await?;
1076 return serde_json::from_slice(&full).map_err(Error::RecordSerialization);
1077}
1078
/// Path prefix of the auth API, without a leading slash (callers format it as `/{AUTH_API}/...`).
const AUTH_API: &str = "api/auth/v1";
/// Path prefix of the record API, without a leading slash.
const RECORD_API: &str = "api/records/v1";
1081
#[cfg(test)]
mod tests {
  use super::*;

  /// Checks that `RecordApi` futures can be spawned onto a multi-threaded runtime, which
  /// requires them to be `Send`. The read itself is expected to fail; presumably nothing
  /// is listening on this address during tests.
  ///
  /// Uses `expect` rather than `unwrap`, which the crate forbids via `clippy::unwrap_used`.
  #[tokio::test]
  async fn is_send_test() {
    let client = Client::new("http://127.0.0.1:4000", None).expect("valid site URL");

    let api = client.records("simple_strict_table");

    for _ in 0..2 {
      let api = api.clone();
      tokio::spawn(async move {
        let response = api.read::<serde_json::Value>(0).await;
        assert!(response.is_err());
      })
      .await
      .expect("spawned task panicked");
    }
  }
}