1#![forbid(unsafe_code, clippy::unwrap_used)]
7#![allow(clippy::needless_return)]
8#![warn(clippy::await_holding_lock, clippy::inefficient_to_string)]
9
10use eventsource_stream::Eventsource;
11use futures_lite::StreamExt;
12use parking_lot::RwLock;
13use reqwest::header::{self, HeaderMap, HeaderValue};
14use reqwest::{Method, StatusCode};
15use serde::de::DeserializeOwned;
16use serde::{Deserialize, Serialize};
17use serde_repr::{Deserialize_repr, Serialize_repr};
18use std::borrow::Cow;
19use std::sync::Arc;
20use thiserror::Error;
21use tracing::*;
22
23pub use futures_lite::Stream;
24
/// Errors returned by this client library.
///
/// The enum is `#[non_exhaustive]`: matches outside this crate need a wildcard arm.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum Error {
  /// A response carried a client-error (4xx) or server-error (5xx) status code,
  /// or a `reqwest` error exposed a status code.
  #[error("HTTP status: {0}")]
  HttpStatus(StatusCode),

  /// JSON encoding of a request body or JSON decoding of a response body failed.
  #[error("RecordSerialization: {0}")]
  RecordSerialization(serde_json::Error),

  /// The auth token could not be decoded.
  #[error("InvalidToken: {0}")]
  InvalidToken(jsonwebtoken::errors::Error),

  /// The base URL handed to the client could not be parsed.
  #[error("InvalidUrl: {0}")]
  InvalidUrl(url::ParseError),

  /// Any `reqwest` error that does not expose an HTTP status code.
  #[error("Reqwest: {0}")]
  OtherReqwest(reqwest::Error),

  /// A WebSocket error; only available with the `ws` feature.
  #[cfg(feature = "ws")]
  #[error("WebSocket: {0}")]
  WebSocket(#[from] reqwest_websocket::Error),
}
48
49impl From<reqwest::Error> for Error {
50 fn from(err: reqwest::Error) -> Self {
51 match err.status() {
52 Some(code) => Self::HttpStatus(code),
53 _ => Self::OtherReqwest(err),
54 }
55 }
56}
57
/// Identity of the currently authenticated user, copied from the auth token's claims.
#[derive(Clone, Debug)]
pub struct User {
  /// The `sub` claim of the auth token.
  pub sub: String,
  /// The `email` claim of the auth token.
  pub email: String,
}
64
/// Credentials held by the client after a successful login.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct Tokens {
  /// Token sent with requests as the `Authorization: Bearer …` header.
  pub auth_token: String,
  /// Optional token used to obtain a fresh auth token; sent as the `Refresh-Token` header.
  pub refresh_token: Option<String>,
  /// Optional token sent as the `CSRF-Token` header.
  pub csrf_token: Option<String>,
}
74
/// Intermediate token returned by `Client::login` when the server answers with
/// `403 Forbidden`; it is passed on to `Client::login_second` together with a TOTP code.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct MultiFactorAuthToken {
  // Opaque to the client; forwarded verbatim as `mfa_token` in the second login step.
  mfa_token: String,
}
79
/// Paging options for list requests.
///
/// Every field is optional; the builder methods accept either a plain value or an
/// `Option`, so passing `None` clears a previously set value.
#[derive(Clone, Debug, Default, PartialEq)]
pub struct Pagination {
  cursor: Option<String>,
  limit: Option<usize>,
  offset: Option<usize>,
}

impl Pagination {
  /// Creates pagination options with nothing set.
  pub fn new() -> Self {
    Self::default()
  }

  /// Returns the options with the record limit replaced.
  pub fn with_limit(self, limit: impl Into<Option<usize>>) -> Pagination {
    Pagination {
      limit: limit.into(),
      ..self
    }
  }

  /// Returns the options with the cursor replaced.
  pub fn with_cursor(self, cursor: impl Into<Option<String>>) -> Pagination {
    Pagination {
      cursor: cursor.into(),
      ..self
    }
  }

  /// Returns the options with the offset replaced.
  pub fn with_offset(self, offset: impl Into<Option<usize>>) -> Pagination {
    Pagination {
      offset: offset.into(),
      ..self
    }
  }
}
107
/// A JSON object: string keys mapped to arbitrary JSON values.
type JsonObject = serde_json::value::Map<String, serde_json::Value>;
109
/// Status carried by an [`EventPayload::Error`] event.
///
/// (De)serialized as its integer discriminant via `serde_repr`.
#[derive(Debug, Clone, Copy, Deserialize_repr, Serialize_repr, PartialEq)]
#[repr(i64)]
pub enum EventErrorStatus {
  /// Wire value `0`.
  Unknown = 0,
  /// Wire value `1`. NOTE(review): presumably signals denied access — confirm against the server.
  Forbidden = 1,
  /// Wire value `2`. NOTE(review): presumably signals that events were lost — confirm against the server.
  Loss = 2,
}
122
/// Payload of a change event.
///
/// Deserialized using serde's default enum representation, i.e. keyed by the variant
/// name, e.g. `{"Update": {...}}`.
#[derive(Debug, Clone, Deserialize, PartialEq)]
pub enum EventPayload {
  /// JSON object attached to an update event (presumably the updated record — confirm).
  Update(JsonObject),
  /// JSON object attached to an insert event (presumably the inserted record — confirm).
  Insert(JsonObject),
  /// JSON object attached to a delete event (presumably the deleted record — confirm).
  Delete(JsonObject),
  /// An error reported through the event stream.
  Error {
    /// Numeric status of the error.
    status: EventErrorStatus,
    /// Optional human-readable message.
    message: Option<String>,
  },
}
133
134#[derive(Debug, Clone, Deserialize, PartialEq)]
135pub struct ChangeEvent {
136 #[serde(flatten)]
137 pub event: Arc<EventPayload>,
138 pub seq: Option<i64>,
139}
140
141impl ChangeEvent {
142 fn from_str(msg: &str) -> Result<ChangeEvent, serde_json::Error> {
143 return serde_json::from_str::<ChangeEvent>(msg);
144 }
145}
146
/// Response body of a list request.
#[derive(Clone, Debug, Deserialize)]
pub struct ListResponse<T> {
  /// Optional cursor returned by the server (presumably for fetching the next page — confirm).
  pub cursor: Option<String>,
  /// Optional total record count (presumably only present when a count was requested — confirm).
  pub total_count: Option<usize>,
  /// The records of this page.
  pub records: Vec<T>,
}
153
/// A value that can be used as a record id in request paths.
pub trait RecordId<'a> {
  /// Consumes the id and returns its string form, borrowing where possible.
  fn serialized_id(self) -> Cow<'a, str>;
}

/// Owned strings are handed over without copying.
impl RecordId<'_> for String {
  fn serialized_id(self) -> Cow<'static, str> {
    self.into()
  }
}

/// Borrowed `String`s are passed through as borrowed string slices.
impl<'a> RecordId<'a> for &'a String {
  fn serialized_id(self) -> Cow<'a, str> {
    Cow::Borrowed(self.as_str())
  }
}

/// String slices are passed through as-is.
impl<'a> RecordId<'a> for &'a str {
  fn serialized_id(self) -> Cow<'a, str> {
    self.into()
  }
}

/// Integer ids are rendered in decimal.
impl RecordId<'_> for i64 {
  fn serialized_id(self) -> Cow<'static, str> {
    self.to_string().into()
  }
}
181
182pub trait ReadArgumentsTrait<'a> {
183 fn serialized_id(self) -> Cow<'a, str>;
184 fn expand(&self) -> Option<&Vec<&'a str>>;
185}
186
187impl<'a, T: RecordId<'a>> ReadArgumentsTrait<'a> for T {
188 fn serialized_id(self) -> Cow<'a, str> {
189 return self.serialized_id();
190 }
191
192 fn expand(&self) -> Option<&Vec<&'a str>> {
193 return None;
194 }
195}
196
197#[derive(Clone, Debug, PartialEq)]
198pub struct ReadArguments<'a, T: RecordId<'a>> {
199 id: T,
200 expand: Option<Vec<&'a str>>,
201}
202
203impl<'a, T: RecordId<'a>> ReadArguments<'a, T> {
204 pub fn new(id: T) -> Self {
205 return Self { id, expand: None };
206 }
207
208 pub fn with_expand(mut self, expand: impl AsRef<[&'a str]>) -> Self {
209 self.expand = Some(expand.as_ref().to_vec());
210 return self;
211 }
212}
213
214impl<'a, T: RecordId<'a>> ReadArgumentsTrait<'a> for ReadArguments<'a, T> {
215 fn serialized_id(self) -> Cow<'a, str> {
216 return self.id.serialized_id();
217 }
218
219 fn expand(&self) -> Option<&Vec<&'a str>> {
220 return self.expand.as_ref();
221 }
222}
223
/// Abstraction over the HTTP layer used by the client.
///
/// A custom implementation can be supplied through `ClientOptions::transport`;
/// otherwise `DefaultTransport` is used.
#[async_trait::async_trait]
pub trait Transport {
  /// Sends a request to `path` with the given headers, method, optional body and
  /// optional query parameters, and returns the response.
  async fn fetch(
    &self,
    path: &str,
    headers: HeaderMap,
    method: Method,
    body: Option<Vec<u8>>,
    query_params: Option<&[(Cow<'static, str>, Cow<'static, str>)]>,
  ) -> Result<http::Response<reqwest::Body>, Error>;

  /// Requests a WebSocket upgrade for `path`; only available with the `ws` feature.
  #[cfg(feature = "ws")]
  async fn upgrade_ws(
    &self,
    path: &str,
    headers: HeaderMap,
    method: Method,
    query_params: Option<&[(Cow<'static, str>, Cow<'static, str>)]>,
  ) -> Result<reqwest_websocket::UpgradeResponse, Error>;
}
244
245pub struct DefaultTransport {
246 client: reqwest::Client,
247 url: url::Url,
248}
249
250impl DefaultTransport {
251 pub fn new(url: url::Url) -> Self {
252 return Self {
253 client: reqwest::Client::new(),
254 url,
255 };
256 }
257}
258
259#[async_trait::async_trait]
260impl Transport for DefaultTransport {
261 async fn fetch(
262 &self,
263 path: &str,
264 headers: HeaderMap,
265 method: Method,
266 body: Option<Vec<u8>>,
267 query_params: Option<&[(Cow<'static, str>, Cow<'static, str>)]>,
268 ) -> Result<http::Response<reqwest::Body>, Error> {
269 assert!(path.starts_with("/"));
270
271 let mut url = self.url.clone();
272 url.set_path(path);
273
274 if let Some(query_params) = query_params {
275 let mut params = url.query_pairs_mut();
276 for (key, value) in query_params {
277 params.append_pair(key, value);
278 }
279 }
280
281 let request = {
282 let mut builder = self.client.request(method, url).headers(headers);
283 if let Some(body) = body {
284 builder = builder.body(body);
287 }
288 builder.build()?
289 };
290
291 return Ok(self.client.execute(request).await?.into());
292 }
293
294 #[cfg(feature = "ws")]
295 async fn upgrade_ws(
296 &self,
297 path: &str,
298 headers: HeaderMap,
299 method: Method,
300 query_params: Option<&[(Cow<'static, str>, Cow<'static, str>)]>,
301 ) -> Result<reqwest_websocket::UpgradeResponse, Error> {
302 use reqwest_websocket::Upgrade;
303
304 assert!(path.starts_with("/"));
305
306 let mut url = self.url.clone();
307 url.set_path(path);
308
309 if let Some(query_params) = query_params {
310 let mut params = url.query_pairs_mut();
311 for (key, value) in query_params {
312 params.append_pair(key, value);
313 }
314 }
315
316 return Ok(
317 self
318 .client
319 .request(method, url)
320 .headers(headers)
321 .upgrade()
322 .send()
323 .await?,
324 );
325 }
326}
327
/// Claims read from the auth token.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
struct JwtTokenClaims {
  /// Subject; exposed as `User::sub`.
  sub: String,
  /// `iat` claim (presumably the issue time in seconds since the Unix epoch — confirm).
  iat: i64,
  /// Expiry; compared against the current Unix time in seconds to decide on a refresh.
  exp: i64,
  /// E-mail address; exposed as `User::email`.
  email: String,
  /// CSRF token embedded in the claims.
  csrf_token: String,
}
336
337fn decode_auth_token<T: DeserializeOwned + Clone>(token: &str) -> Result<T, Error> {
338 return jsonwebtoken::dangerous::insecure_decode::<T>(token)
339 .map(|data| data.claims)
340 .map_err(Error::InvalidToken);
341}
342
/// Handle for one named record API; created through `Client::records`.
#[derive(Clone)]
pub struct RecordApi {
  // Client state shared with the `Client` this handle was created from.
  client: Arc<ClientState>,
  // API name, used as a path segment of every request.
  name: String,
}
348
/// Options for `RecordApi::list`; built with the `with_*` methods.
#[derive(Clone, Debug, Default, PartialEq)]
pub struct ListArguments<'a> {
  // Cursor/limit/offset paging options.
  pagination: Pagination,
  // Entries joined with `,` into the `order` query parameter.
  order: Option<Vec<&'a str>>,
  // Filter tree flattened into `filter…` query parameters.
  filters: Option<ValueOrFilterGroup>,
  // Entries joined with `,` into the `expand` query parameter.
  expand: Option<Vec<&'a str>>,
  // When set, `count=true` is sent.
  count: bool,
}
357
/// Comparison operator of a [`Filter`].
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum CompareOp {
  Equal,
  NotEqual,
  GreaterThanEqual,
  GreaterThan,
  LessThanEqual,
  LessThan,
  Like,
  Regexp,
  StWithin,
  StIntersects,
  StContains,
}

impl CompareOp {
  /// Returns the operator token used inside filter query-parameter keys.
  fn format(&self) -> &'static str {
    match *self {
      CompareOp::Equal => "$eq",
      CompareOp::NotEqual => "$ne",
      CompareOp::GreaterThanEqual => "$gte",
      CompareOp::GreaterThan => "$gt",
      CompareOp::LessThanEqual => "$lte",
      CompareOp::LessThan => "$lt",
      CompareOp::Like => "$like",
      CompareOp::Regexp => "$re",
      CompareOp::StWithin => "@within",
      CompareOp::StIntersects => "@intersects",
      CompareOp::StContains => "@contains",
    }
  }
}
390
391#[derive(Clone, Default, Debug, PartialEq)]
392pub struct Filter {
393 pub column: String,
394 pub op: Option<CompareOp>,
395 pub value: String,
396}
397
398impl Filter {
399 pub fn new(column: impl Into<String>, op: CompareOp, value: impl Into<String>) -> Self {
400 return Self {
401 column: column.into(),
402 op: Some(op),
403 value: value.into(),
404 };
405 }
406}
407
408impl From<Filter> for ValueOrFilterGroup {
409 fn from(value: Filter) -> Self {
410 return ValueOrFilterGroup::Filter(value);
411 }
412}
413
414#[derive(Clone, Debug, PartialEq)]
415pub enum ValueOrFilterGroup {
416 Filter(Filter),
417 And(Vec<ValueOrFilterGroup>),
418 Or(Vec<ValueOrFilterGroup>),
419}
420
421impl<F> From<F> for ValueOrFilterGroup
422where
423 F: Into<Vec<Filter>>,
424{
425 fn from(filters: F) -> Self {
426 return ValueOrFilterGroup::And(
427 filters
428 .into()
429 .into_iter()
430 .map(ValueOrFilterGroup::Filter)
431 .collect(),
432 );
433 }
434}
435
436impl<'a> ListArguments<'a> {
437 pub fn new() -> Self {
438 return ListArguments::default();
439 }
440
441 pub fn with_pagination(mut self, pagination: Pagination) -> Self {
442 self.pagination = pagination;
443 return self;
444 }
445
446 pub fn with_order(mut self, order: impl AsRef<[&'a str]>) -> Self {
447 self.order = Some(order.as_ref().to_vec());
448 return self;
449 }
450
451 pub fn with_filters(mut self, filters: impl Into<ValueOrFilterGroup>) -> Self {
452 self.filters = Some(filters.into());
453 return self;
454 }
455
456 pub fn with_expand(mut self, expand: impl AsRef<[&'a str]>) -> Self {
457 self.expand = Some(expand.as_ref().to_vec());
458 return self;
459 }
460
461 pub fn with_count(mut self, count: bool) -> Self {
462 self.count = count;
463 return self;
464 }
465}
466
467impl RecordApi {
468 pub async fn list<T: DeserializeOwned>(
469 &self,
470 args: ListArguments<'_>,
471 ) -> Result<ListResponse<T>, Error> {
472 type Param = (Cow<'static, str>, Cow<'static, str>);
473 let mut params: Vec<Param> = vec![];
474 if let Some(cursor) = args.pagination.cursor {
475 params.push((Cow::Borrowed("cursor"), Cow::Owned(cursor)));
476 }
477
478 if let Some(limit) = args.pagination.limit {
479 params.push((Cow::Borrowed("limit"), Cow::Owned(limit.to_string())));
480 }
481
482 #[inline]
483 fn to_list(slice: &[&str]) -> String {
484 return slice.join(",");
485 }
486
487 if let Some(order) = args.order
488 && !order.is_empty()
489 {
490 params.push((Cow::Borrowed("order"), Cow::Owned(to_list(&order))));
491 }
492
493 if let Some(expand) = args.expand
494 && !expand.is_empty()
495 {
496 params.push((Cow::Borrowed("expand"), Cow::Owned(to_list(&expand))));
497 }
498
499 if args.count {
500 params.push((Cow::Borrowed("count"), Cow::Borrowed("true")));
501 }
502
503 fn traverse_filters(params: &mut Vec<Param>, path: String, filter: ValueOrFilterGroup) {
504 match filter {
505 ValueOrFilterGroup::Filter(filter) => {
506 if let Some(op) = filter.op {
507 params.push((
508 Cow::Owned(format!(
509 "{path}[{col}][{op}]",
510 col = filter.column,
511 op = op.format()
512 )),
513 Cow::Owned(filter.value),
514 ));
515 } else {
516 params.push((
517 Cow::Owned(format!("{path}[{col}]", col = filter.column)),
518 Cow::Owned(filter.value),
519 ));
520 }
521 }
522 ValueOrFilterGroup::And(vec) => {
523 for (i, f) in vec.into_iter().enumerate() {
524 traverse_filters(params, format!("{path}[$and][{i}]"), f);
525 }
526 }
527 ValueOrFilterGroup::Or(vec) => {
528 for (i, f) in vec.into_iter().enumerate() {
529 traverse_filters(params, format!("{path}[$or][{i}]"), f);
530 }
531 }
532 }
533 }
534
535 if let Some(filters) = args.filters {
536 traverse_filters(&mut params, "filter".to_string(), filters);
537 }
538
539 let response = self
540 .client
541 .fetch(
542 &format!("/{RECORD_API}/{}", self.name),
543 Method::GET,
544 None,
545 Some(¶ms),
546 true,
547 )
548 .await?;
549
550 return json(response).await;
551 }
552
553 pub async fn read<'a, T: DeserializeOwned>(
554 &self,
555 args: impl ReadArgumentsTrait<'a>,
556 ) -> Result<T, Error> {
557 let expand = args
558 .expand()
559 .map(|e| vec![(Cow::Borrowed("expand"), Cow::Owned(e.join(",")))]);
560
561 let response = self
562 .client
563 .fetch(
564 &format!(
565 "/{RECORD_API}/{name}/{id}",
566 name = self.name,
567 id = args.serialized_id()
568 ),
569 Method::GET,
570 None,
571 expand.as_deref(),
572 true,
573 )
574 .await?;
575
576 return json(response).await;
577 }
578
579 pub async fn create<T: Serialize>(&self, record: T) -> Result<String, Error> {
580 return Ok(self.create_impl(record).await?.swap_remove(0));
581 }
582
583 pub async fn create_bulk<T: Serialize>(&self, record: &[T]) -> Result<Vec<String>, Error> {
584 return self.create_impl(record).await;
585 }
586
587 async fn create_impl<T: Serialize>(&self, record: T) -> Result<Vec<String>, Error> {
588 let response = self
589 .client
590 .fetch(
591 &format!("/{RECORD_API}/{name}", name = self.name),
592 Method::POST,
593 Some(serde_json::to_vec(&record).map_err(Error::RecordSerialization)?),
594 None,
595 true,
596 )
597 .await?;
598
599 #[derive(Deserialize)]
600 pub struct RecordIdResponse {
601 pub ids: Vec<String>,
602 }
603
604 return Ok(json::<RecordIdResponse>(response).await?.ids);
605 }
606
607 pub async fn update<'a, T: Serialize>(
608 &self,
609 id: impl RecordId<'a>,
610 record: T,
611 ) -> Result<(), Error> {
612 self
613 .client
614 .fetch(
615 &format!(
616 "/{RECORD_API}/{name}/{id}",
617 name = self.name,
618 id = id.serialized_id()
619 ),
620 Method::PATCH,
621 Some(serde_json::to_vec(&record).map_err(Error::RecordSerialization)?),
622 None,
623 true,
624 )
625 .await?;
626
627 return Ok(());
628 }
629
630 pub async fn delete<'a>(&self, id: impl RecordId<'a>) -> Result<(), Error> {
631 self
632 .client
633 .fetch(
634 &format!(
635 "/{RECORD_API}/{name}/{id}",
636 name = self.name,
637 id = id.serialized_id()
638 ),
639 Method::DELETE,
640 None,
641 None,
642 true,
643 )
644 .await?;
645
646 return Ok(());
647 }
648
649 pub async fn subscribe<'a, T: RecordId<'a>>(
650 &self,
651 id: T,
652 ) -> Result<impl Stream<Item = ChangeEvent> + use<T>, Error> {
653 let response = self
655 .client
656 .fetch(
657 &format!(
658 "/{RECORD_API}/{name}/subscribe/{id}",
659 name = self.name,
660 id = id.serialized_id()
661 ),
662 Method::GET,
663 None,
664 None,
665 true,
666 )
667 .await?;
668
669 return Ok(
670 http_body_util::BodyDataStream::new(response.into_body())
671 .eventsource()
672 .filter_map(|event_or| {
673 if let Ok(event) = event_or {
676 return ChangeEvent::from_str(&event.data)
677 .map_err(|err| {
678 warn!("Failed to parse change event: {}", event.data);
679 return err;
680 })
681 .ok();
682 }
683 return None;
684 }),
685 );
686 }
687
688 #[cfg(feature = "ws")]
689 pub async fn subscribe_ws<'a, T: RecordId<'a>>(
690 &self,
691 id: T,
692 ) -> Result<impl Stream<Item = ChangeEvent> + use<T>, Error> {
693 let response = self
694 .client
695 .upgrade_ws(
696 &format!(
697 "/{RECORD_API}/{name}/subscribe/{id}",
698 name = self.name,
699 id = id.serialized_id()
700 ),
701 Method::GET,
702 Some(&[("ws".into(), "true".into())]),
703 )
704 .await?;
705
706 let websocket = response.into_websocket().await?;
707
708 return Ok(websocket.filter_map(|message| {
709 use reqwest_websocket::Message;
710
711 return match message {
712 Ok(Message::Text(msg)) => serde_json::from_str::<ChangeEvent>(&msg)
713 .map_err(|err| {
714 warn!("json error: {err}");
715 return err;
716 })
717 .ok(),
718 msg => {
719 warn!("unexpected msg: {msg:?}");
720 None
721 }
722 };
723 }));
724 }
725}
726
727#[derive(Clone, Debug)]
728struct TokenState {
729 state: Option<(Tokens, JwtTokenClaims)>,
730 headers: HeaderMap,
731}
732
733impl TokenState {
734 fn build(tokens: Option<&Tokens>) -> TokenState {
735 let headers = build_headers(tokens);
736 return TokenState {
737 state: tokens.and_then(|tokens| {
738 return match decode_auth_token::<JwtTokenClaims>(&tokens.auth_token) {
739 Ok(jwt_token) => Some((tokens.clone(), jwt_token)),
740 Err(err) => {
741 error!("Failed to decode auth token: {err}");
742 None
743 }
744 };
745 }),
746 headers,
747 };
748 }
749}
750
/// State shared between a `Client` and the `RecordApi` handles created from it.
struct ClientState {
  // HTTP layer used for all requests.
  transport: Box<dyn Transport + Send + Sync>,
  // Base URL the client was created with.
  base_url: url::Url,
  // Current tokens and derived headers; replaced on login, logout and refresh.
  tokens: RwLock<TokenState>,
}
756
757impl ClientState {
758 #[inline]
759 async fn fetch(
760 &self,
761 path: &str,
762 method: Method,
763 body: Option<Vec<u8>>,
764 query_params: Option<&[(Cow<'static, str>, Cow<'static, str>)]>,
765 error_for_status: bool,
766 ) -> Result<http::Response<reqwest::Body>, Error> {
767 let (mut headers, refresh_token) = self.extract_headers_and_refresh_token_if_exp();
768 if let Some(refresh_token) = refresh_token {
769 let new_tokens =
770 ClientState::refresh_tokens_impl(&*self.transport, headers, refresh_token).await?;
771
772 headers = new_tokens.headers.clone();
773 *self.tokens.write() = new_tokens;
774 }
775
776 let response = self
777 .transport
778 .fetch(path, headers, method, body, query_params)
779 .await?;
780
781 if error_for_status {
782 return error_for_status_unpack(response);
783 }
784 return Ok(response);
785 }
786
787 #[cfg(feature = "ws")]
788 #[inline]
789 async fn upgrade_ws(
790 &self,
791 path: &str,
792 method: Method,
793 query_params: Option<&[(Cow<'static, str>, Cow<'static, str>)]>,
794 ) -> Result<reqwest_websocket::UpgradeResponse, Error> {
795 let (mut headers, refresh_token) = self.extract_headers_and_refresh_token_if_exp();
796 if let Some(refresh_token) = refresh_token {
797 let new_tokens =
798 ClientState::refresh_tokens_impl(&*self.transport, headers, refresh_token).await?;
799
800 headers = new_tokens.headers.clone();
801 *self.tokens.write() = new_tokens;
802 }
803
804 return self
805 .transport
806 .upgrade_ws(path, headers, method, query_params)
807 .await;
808 }
809
810 #[inline]
811 fn extract_headers_and_refresh_token_if_exp(&self) -> (HeaderMap, Option<String>) {
812 #[inline]
813 fn should_refresh(jwt: &JwtTokenClaims) -> bool {
814 return jwt.exp - 60 < now() as i64;
815 }
816
817 let tokens = self.tokens.read();
818 let headers = tokens.headers.clone();
819 return match tokens.state {
820 Some(ref state) if should_refresh(&state.1) => (headers, state.0.refresh_token.clone()),
821 _ => (headers, None),
822 };
823 }
824
825 fn extract_headers_refresh_token(&self) -> Option<(HeaderMap, String)> {
826 let tokens = self.tokens.read();
827 let state = tokens.state.as_ref()?;
828
829 if let Some(ref refresh_token) = state.0.refresh_token {
830 return Some((tokens.headers.clone(), refresh_token.clone()));
831 }
832 return None;
833 }
834
835 async fn refresh_tokens_impl(
836 transport: &(dyn Transport + Send + Sync),
837 headers: HeaderMap,
838 refresh_token: String,
839 ) -> Result<TokenState, Error> {
840 #[derive(Serialize)]
841 struct RefreshRequest<'a> {
842 refresh_token: &'a str,
843 }
844
845 let response = transport
847 .fetch(
848 &format!("/{AUTH_API}/refresh"),
849 headers,
850 Method::POST,
851 Some(
852 serde_json::to_vec(&RefreshRequest {
853 refresh_token: &refresh_token,
854 })
855 .map_err(Error::RecordSerialization)?,
856 ),
857 None,
858 )
859 .await?;
860
861 #[derive(Deserialize)]
862 struct RefreshResponse {
863 auth_token: String,
864 csrf_token: Option<String>,
865 }
866
867 let refresh_response: RefreshResponse = json(response).await?;
868 return Ok(TokenState::build(Some(&Tokens {
869 auth_token: refresh_response.auth_token,
870 refresh_token: Some(refresh_token),
871 csrf_token: refresh_response.csrf_token,
872 })));
873 }
874}
875
/// Entry point of the library: handles authentication and hands out `RecordApi` handles.
///
/// Cloning shares the underlying state.
#[derive(Clone)]
pub struct Client {
  // State shared with all clones and with `RecordApi` handles.
  state: Arc<ClientState>,
}
880
/// Optional settings for `Client::new`.
#[derive(Default)]
pub struct ClientOptions {
  /// Tokens to start out with, e.g. from an earlier session.
  pub tokens: Option<Tokens>,
  /// Custom transport; when `None`, a `DefaultTransport` for the base URL is used.
  pub transport: Option<Box<dyn Transport + Send + Sync>>,
}
886
887impl Client {
888 pub fn new(
889 base_url: impl TryInto<url::Url, Error = url::ParseError>,
890 opts: Option<ClientOptions>,
891 ) -> Result<Client, Error> {
892 let opts = opts.unwrap_or_default();
893 let base_url = base_url.try_into().map_err(Error::InvalidUrl)?;
894 return Ok(Client {
895 state: Arc::new(ClientState {
896 transport: opts.transport.unwrap_or_else(|| {
897 return Box::new(DefaultTransport::new(base_url.clone()));
898 }),
899 base_url,
900 tokens: RwLock::new(TokenState::build(opts.tokens.as_ref())),
901 }),
902 });
903 }
904
905 pub fn base_url(&self) -> &url::Url {
906 return &self.state.base_url;
907 }
908
909 pub fn tokens(&self) -> Option<Tokens> {
910 return self.state.tokens.read().state.as_ref().map(|x| x.0.clone());
911 }
912
913 pub fn user(&self) -> Option<User> {
914 if let Some(state) = &self.state.tokens.read().state {
915 return Some(User {
916 sub: state.1.sub.clone(),
917 email: state.1.email.clone(),
918 });
919 }
920 return None;
921 }
922
923 pub fn records(&self, api_name: &str) -> RecordApi {
924 return RecordApi {
925 client: self.state.clone(),
926 name: api_name.to_string(),
927 };
928 }
929
930 pub async fn refresh(&self) -> Result<(), Error> {
931 let Some((headers, refresh_token)) = self.state.extract_headers_refresh_token() else {
932 return Ok(());
934 };
935
936 let new_tokens =
937 ClientState::refresh_tokens_impl(&*self.state.transport, headers, refresh_token).await?;
938
939 *self.state.tokens.write() = new_tokens;
940 return Ok(());
941 }
942
943 pub async fn login(
944 &self,
945 email: &str,
946 password: &str,
947 ) -> Result<Option<MultiFactorAuthToken>, Error> {
948 #[derive(Serialize)]
949 struct Credentials<'a> {
950 email: &'a str,
951 password: &'a str,
952 }
953
954 let response = self
955 .state
956 .fetch(
957 &format!("/{AUTH_API}/login"),
958 Method::POST,
959 Some(
960 serde_json::to_vec(&Credentials { email, password })
961 .map_err(Error::RecordSerialization)?,
962 ),
963 None,
964 false,
965 )
966 .await?;
967
968 if response.status() == StatusCode::FORBIDDEN {
969 let mfa_token: MultiFactorAuthToken = json(response).await?;
970 return Ok(Some(mfa_token));
971 }
972
973 let tokens: Tokens = json(error_for_status_unpack(response)?).await?;
974 self.update_tokens(Some(&tokens));
975
976 return Ok(None);
977 }
978
979 pub async fn login_second(
980 &self,
981 mfa_token: &MultiFactorAuthToken,
982 totp_code: &str,
983 ) -> Result<(), Error> {
984 #[derive(Serialize)]
985 struct Credentials<'a> {
986 mfa_token: &'a str,
987 totp: &'a str,
988 }
989
990 let response = self
991 .state
992 .fetch(
993 &format!("/{AUTH_API}/login_mfa"),
994 Method::POST,
995 Some(
996 serde_json::to_vec(&Credentials {
997 mfa_token: &mfa_token.mfa_token,
998 totp: totp_code,
999 })
1000 .map_err(Error::RecordSerialization)?,
1001 ),
1002 None,
1003 true,
1004 )
1005 .await?;
1006
1007 let tokens: Tokens = json(error_for_status_unpack(response)?).await?;
1008 self.update_tokens(Some(&tokens));
1009
1010 return Ok(());
1011 }
1012
1013 pub async fn request_otp(&self, email: &str, redirect_uri: Option<&str>) -> Result<(), Error> {
1014 #[derive(Serialize)]
1015 struct Credentials<'a> {
1016 email: &'a str,
1017 redirect_uri: Option<&'a str>,
1018 }
1019
1020 let _response = self
1021 .state
1022 .fetch(
1023 &format!("/{AUTH_API}/otp/request"),
1024 Method::POST,
1025 Some(
1026 serde_json::to_vec(&Credentials {
1027 email,
1028 redirect_uri,
1029 })
1030 .map_err(Error::RecordSerialization)?,
1031 ),
1032 None,
1033 true,
1034 )
1035 .await?;
1036
1037 return Ok(());
1038 }
1039
1040 pub async fn login_otp(&self, email: &str, code: &str) -> Result<(), Error> {
1041 #[derive(Serialize)]
1042 struct Credentials<'a> {
1043 email: &'a str,
1044 code: &'a str,
1045 }
1046
1047 let response = self
1048 .state
1049 .fetch(
1050 &format!("/{AUTH_API}/otp/login"),
1051 Method::POST,
1052 Some(serde_json::to_vec(&Credentials { email, code }).map_err(Error::RecordSerialization)?),
1053 None,
1054 true,
1055 )
1056 .await?;
1057
1058 let tokens: Tokens = json(error_for_status_unpack(response)?).await?;
1059 self.update_tokens(Some(&tokens));
1060
1061 return Ok(());
1062 }
1063
1064 pub async fn logout(&self) -> Result<(), Error> {
1065 #[derive(Serialize)]
1066 struct LogoutRequest {
1067 refresh_token: String,
1068 }
1069
1070 let response_or = match self.state.extract_headers_refresh_token() {
1071 Some((_headers, refresh_token)) => {
1072 self
1073 .state
1074 .fetch(
1075 &format!("/{AUTH_API}/logout"),
1076 Method::POST,
1077 Some(
1078 serde_json::to_vec(&LogoutRequest { refresh_token })
1079 .map_err(Error::RecordSerialization)?,
1080 ),
1081 None,
1082 true,
1083 )
1084 .await
1085 }
1086 _ => {
1087 self
1088 .state
1089 .fetch(
1090 &format!("/{AUTH_API}/logout"),
1091 Method::GET,
1092 None,
1093 None,
1094 true,
1095 )
1096 .await
1097 }
1098 };
1099
1100 self.update_tokens(None);
1101
1102 return response_or.map(|_| ());
1103 }
1104
1105 fn update_tokens(&self, tokens: Option<&Tokens>) -> TokenState {
1106 let state = TokenState::build(tokens);
1107
1108 *self.state.tokens.write() = state.clone();
1109 if let Some(ref s) = state.state {
1112 let now = now();
1113 if s.1.exp < now as i64 {
1114 warn!("Token expired");
1115 }
1116 }
1117
1118 return state;
1119 }
1120}
1121
1122fn build_headers(tokens: Option<&Tokens>) -> HeaderMap {
1123 let mut base = HeaderMap::with_capacity(5);
1124 base.insert(
1125 header::CONTENT_TYPE,
1126 HeaderValue::from_static("application/json"),
1127 );
1128
1129 if let Some(tokens) = tokens {
1130 if let Ok(value) = HeaderValue::from_str(&format!("Bearer {}", tokens.auth_token)) {
1131 base.insert(header::AUTHORIZATION, value);
1132 } else {
1133 error!("Failed to build bearer token.");
1134 }
1135
1136 if let Some(ref refresh) = tokens.refresh_token {
1137 if let Ok(value) = HeaderValue::from_str(refresh) {
1138 base.insert("Refresh-Token", value);
1139 } else {
1140 error!("Failed to build refresh token header.");
1141 }
1142 }
1143
1144 if let Some(ref csrf) = tokens.csrf_token {
1145 if let Ok(value) = HeaderValue::from_str(csrf) {
1146 base.insert("CSRF-Token", value);
1147 } else {
1148 error!("Failed to build refresh token header.");
1149 }
1150 }
1151 }
1152
1153 return base;
1154}
1155
/// Returns the current time as whole seconds since the Unix epoch.
///
/// # Panics
/// Panics if the system clock reports a time before the Unix epoch.
fn now() -> u64 {
  use std::time::{SystemTime, UNIX_EPOCH};

  let since_epoch = SystemTime::now()
    .duration_since(UNIX_EPOCH)
    .expect("Duration since epoch");
  since_epoch.as_secs()
}
1162
1163#[inline]
1164async fn json<T: DeserializeOwned>(resp: http::Response<reqwest::Body>) -> Result<T, Error> {
1165 let full = into_bytes(resp).await?;
1166 return serde_json::from_slice(&full).map_err(Error::RecordSerialization);
1167}
1168
1169#[inline]
1170async fn into_bytes(resp: http::Response<reqwest::Body>) -> Result<bytes::Bytes, Error> {
1171 return Ok(
1172 http_body_util::BodyExt::collect(resp.into_body())
1173 .await
1174 .map(|buf| buf.to_bytes())?,
1175 );
1176}
1177
1178fn error_for_status_unpack(
1179 resp: http::Response<reqwest::Body>,
1180) -> Result<http::Response<reqwest::Body>, Error> {
1181 let status = resp.status();
1182 if status.is_client_error() || status.is_server_error() {
1183 return Err(Error::HttpStatus(status));
1184 }
1185 return Ok(resp);
1186}
1187
/// Path prefix (without leading slash) of the auth endpoints.
const AUTH_API: &str = "api/auth/v1";
/// Path prefix (without leading slash) of the record endpoints.
const RECORD_API: &str = "api/records/v1";
1190
#[cfg(test)]
mod tests {
  use super::*;

  /// Spawning the read future on the runtime only compiles if that future is `Send`.
  #[tokio::test]
  async fn is_send_test() {
    let client = Client::new("http://127.0.0.1:4000", None).unwrap();
    let api = client.records("simple_strict_table");

    for _ in 0..2 {
      let api = api.clone();
      let handle = tokio::spawn(async move {
        let response = api.read::<serde_json::Value>(0).await;
        // NOTE(review): presumably fails because no server listens on that address — confirm.
        assert!(response.is_err());
      });
      handle.await.unwrap();
    }
  }

  /// Change events parse both with and without a sequence number.
  #[test]
  fn parse_change_event_test() {
    let ev0 = ChangeEvent::from_str(
      r#"
      {
        "Error": {
          "status": 1,
          "message": "test"
        },
        "seq": 3
      }"#,
    )
    .unwrap();

    assert_eq!(ev0.seq, Some(3));
    match &*ev0.event {
      EventPayload::Error { status, message } => {
        assert_eq!(*status, EventErrorStatus::Forbidden);
        assert_eq!(message.as_deref(), Some("test"));
      }
      _ => panic!("expected error payload, got {:?}", ev0.event),
    }

    let ev1 = ChangeEvent::from_str(
      r#"
      {
        "Update": {
          "col0": "val0",
          "col1": 4
        }
      }"#,
    )
    .unwrap();

    assert_eq!(ev1.seq, None);
    match &*ev1.event {
      EventPayload::Update(obj) => {
        let expected = serde_json::json!({
          "col0": "val0",
          "col1": 4,
        });
        assert_eq!(serde_json::Value::Object(obj.clone()), expected);
      }
      _ => panic!("expected update payload, got {:?}", ev1.event),
    }
  }
}