1#![forbid(unsafe_code, clippy::unwrap_used)]
7#![allow(clippy::needless_return)]
8#![warn(clippy::await_holding_lock, clippy::inefficient_to_string)]
9
10use eventsource_stream::Eventsource;
11pub use futures::Stream;
12use futures::StreamExt;
13use parking_lot::RwLock;
14use reqwest::header::{self, HeaderMap, HeaderValue};
15use reqwest::{Method, StatusCode};
16use std::borrow::Cow;
17use std::sync::Arc;
18use thiserror::Error;
19use tracing::*;
20
21use serde::de::DeserializeOwned;
22use serde::{Deserialize, Serialize};
23
24#[derive(Debug, Error)]
26#[non_exhaustive]
27pub enum Error {
28 #[error("HTTP status: {0}")]
29 HttpStatus(StatusCode),
30
31 #[error("MissingRefreshToken")]
32 MissingRefreshToken,
33
34 #[error("RecordSerialization: {0}")]
35 RecordSerialization(serde_json::Error),
36
37 #[error("InvalidToken: {0}")]
38 InvalidToken(jsonwebtoken::errors::Error),
39
40 #[error("InvalidUrl: {0}")]
41 InvalidUrl(url::ParseError),
42
43 #[error("Reqwest: {0}")]
45 OtherReqwest(reqwest::Error),
46}
47
48impl From<reqwest::Error> for Error {
49 fn from(err: reqwest::Error) -> Self {
50 match err.status() {
51 Some(code) => Self::HttpStatus(code),
52 _ => Self::OtherReqwest(err),
53 }
54 }
55}
56
/// The signed-in user, as described by the claims of the current auth token.
#[derive(Debug, Clone)]
pub struct User {
  /// The `sub` claim of the auth token.
  pub sub: String,
  /// The `email` claim of the auth token.
  pub email: String,
}
63
64#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
68pub struct Tokens {
69 pub auth_token: String,
70 pub refresh_token: Option<String>,
71 pub csrf_token: Option<String>,
72}
73
/// Pagination options for [`RecordApi::list`]; both fields are optional.
#[derive(Clone, Debug, Default)]
pub struct Pagination {
  /// Sent as the `cursor` query parameter when set.
  pub cursor: Option<String>,
  /// Sent as the `limit` query parameter when set.
  pub limit: Option<usize>,
}

impl Pagination {
  /// Builds options from a `limit` and a `cursor`; either may be `None`.
  pub fn with(limit: impl Into<Option<usize>>, cursor: impl Into<Option<String>>) -> Pagination {
    let limit = limit.into();
    let cursor = cursor.into();
    Self { cursor, limit }
  }

  /// Builds options with only a `limit` (no cursor).
  pub fn with_limit(limit: impl Into<Option<usize>>) -> Pagination {
    Self {
      cursor: None,
      limit: limit.into(),
    }
  }

  /// Builds options with only a `cursor` (no limit).
  pub fn with_cursor(cursor: impl Into<Option<String>>) -> Pagination {
    Self {
      cursor: cursor.into(),
      limit: None,
    }
  }
}
96
97#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
98pub enum DbEvent {
99 Update(Option<serde_json::Value>),
100 Insert(Option<serde_json::Value>),
101 Delete(Option<serde_json::Value>),
102 Error(String),
103}
104
105#[derive(Clone, Debug, Deserialize)]
106pub struct ListResponse<T> {
107 pub cursor: Option<String>,
108 pub total_count: Option<usize>,
109 pub records: Vec<T>,
110}
111
/// A value that can identify a record in a URL path.
pub trait RecordId<'a> {
  /// Consumes the id and returns its path representation, borrowing where possible.
  fn serialized_id(self) -> Cow<'a, str>;
}

impl RecordId<'_> for String {
  fn serialized_id(self) -> Cow<'static, str> {
    Cow::from(self)
  }
}

impl<'a> RecordId<'a> for &'a String {
  fn serialized_id(self) -> Cow<'a, str> {
    Cow::from(self)
  }
}

impl<'a> RecordId<'a> for &'a str {
  fn serialized_id(self) -> Cow<'a, str> {
    Cow::from(self)
  }
}

impl RecordId<'_> for i64 {
  fn serialized_id(self) -> Cow<'static, str> {
    Cow::from(self.to_string())
  }
}
139
140pub trait ReadArgumentsTrait<'a> {
141 fn serialized_id(self) -> Cow<'a, str>;
142 fn expand(&self) -> Option<&'a [&'a str]>;
143}
144
145impl<'a, T: RecordId<'a>> ReadArgumentsTrait<'a> for T {
146 fn serialized_id(self) -> Cow<'a, str> {
147 return self.serialized_id();
148 }
149
150 fn expand(&self) -> Option<&'a [&'a str]> {
151 return None;
152 }
153}
154
155#[derive(Debug, Default)]
156pub struct ReadArguments<'a, T: RecordId<'a>> {
157 pub id: T,
158 pub expand: Option<&'a [&'a str]>,
159}
160
161impl<'a, T: RecordId<'a>> ReadArgumentsTrait<'a> for ReadArguments<'a, T> {
162 fn serialized_id(self) -> Cow<'a, str> {
163 return self.id.serialized_id();
164 }
165
166 fn expand(&self) -> Option<&'a [&'a str]> {
167 return self.expand;
168 }
169}
170
171struct ThinClient {
172 client: reqwest::Client,
173 url: url::Url,
174}
175
176impl ThinClient {
177 async fn fetch<T: Serialize>(
178 &self,
179 path: &str,
180 headers: HeaderMap,
181 method: Method,
182 body: Option<&T>,
183 query_params: Option<&[(Cow<'static, str>, Cow<'static, str>)]>,
184 ) -> Result<reqwest::Response, Error> {
185 assert!(path.starts_with("/"));
186
187 let mut url = self.url.clone();
188 url.set_path(path);
189
190 if let Some(query_params) = query_params {
191 let mut params = url.query_pairs_mut();
192 for (key, value) in query_params {
193 params.append_pair(key, value);
194 }
195 }
196
197 let request = {
198 let mut builder = self.client.request(method, url).headers(headers);
199 if let Some(ref body) = body {
200 let json = serde_json::to_string(body).map_err(Error::RecordSerialization)?;
201 builder = builder.body(json);
202 }
203 builder.build()?
204 };
205
206 return Ok(self.client.execute(request).await?);
207 }
208}
209
210#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
211struct JwtTokenClaims {
212 sub: String,
213 iat: i64,
214 exp: i64,
215 email: String,
216 csrf_token: String,
217}
218
219fn decode_auth_token<T: DeserializeOwned>(token: &str) -> Result<T, Error> {
220 let decoding_key = jsonwebtoken::DecodingKey::from_secret(&[]);
221
222 let mut validation = jsonwebtoken::Validation::new(jsonwebtoken::Algorithm::EdDSA);
224 validation.insecure_disable_signature_validation();
225
226 return jsonwebtoken::decode::<T>(token, &decoding_key, &validation)
227 .map(|data| data.claims)
228 .map_err(Error::InvalidToken);
229}
230
231#[derive(Clone)]
232pub struct RecordApi {
233 client: Arc<ClientState>,
234 name: String,
235}
236
237#[derive(Default)]
238pub struct ListArguments<'a> {
239 pub pagination: Pagination,
240 pub order: Option<&'a [&'a str]>,
241 pub filters: Option<&'a [&'a str]>,
242 pub expand: Option<&'a [&'a str]>,
243 pub count: bool,
244}
245
246impl RecordApi {
247 pub async fn list<T: DeserializeOwned>(
248 &self,
249 args: ListArguments<'_>,
250 ) -> Result<ListResponse<T>, Error> {
251 let mut params: Vec<(Cow<'static, str>, Cow<'static, str>)> = vec![];
252 if let Some(cursor) = args.pagination.cursor {
253 params.push((Cow::Borrowed("cursor"), Cow::Owned(cursor)));
254 }
255
256 if let Some(limit) = args.pagination.limit {
257 params.push((Cow::Borrowed("limit"), Cow::Owned(limit.to_string())));
258 }
259
260 #[inline]
261 fn to_list(slice: &[&str]) -> String {
262 return slice.join(",");
263 }
264
265 if let Some(order) = args.order {
266 if !order.is_empty() {
267 params.push((Cow::Borrowed("order"), Cow::Owned(to_list(order))));
268 }
269 }
270
271 if let Some(expand) = args.expand {
272 if !expand.is_empty() {
273 params.push((Cow::Borrowed("expand"), Cow::Owned(to_list(expand))));
274 }
275 }
276
277 if args.count {
278 params.push((Cow::Borrowed("count"), Cow::Borrowed("true")));
279 }
280
281 if let Some(filters) = args.filters {
282 for filter in filters {
283 let Some((name_op, value)) = filter.split_once("=") else {
284 panic!("Filter '{filter}' does not match: 'name[op]=value'");
285 };
286
287 params.push((
288 Cow::Owned(name_op.to_string()),
289 Cow::Owned(value.to_string()),
290 ));
291 }
292 }
293
294 let response = self
295 .client
296 .fetch(
297 &format!("/{RECORD_API}/{}", self.name),
298 Method::GET,
299 None::<&()>,
300 Some(¶ms),
301 )
302 .await?;
303
304 return json(response).await;
305 }
306
307 pub async fn read<'a, T: DeserializeOwned>(
308 &self,
309 args: impl ReadArgumentsTrait<'a>,
310 ) -> Result<T, Error> {
311 let expand = args
312 .expand()
313 .map(|e| vec![(Cow::Borrowed("expand"), Cow::Owned(e.join(",")))]);
314
315 let response = self
316 .client
317 .fetch(
318 &format!(
319 "/{RECORD_API}/{name}/{id}",
320 name = self.name,
321 id = args.serialized_id()
322 ),
323 Method::GET,
324 None::<&()>,
325 expand.as_deref(),
326 )
327 .await?;
328
329 return json(response).await;
330 }
331
332 pub async fn create<T: Serialize>(&self, record: T) -> Result<String, Error> {
333 return Ok(self.create_impl(record).await?.swap_remove(0));
334 }
335
336 pub async fn create_bulk<T: Serialize>(&self, record: &[T]) -> Result<Vec<String>, Error> {
337 return self.create_impl(record).await;
338 }
339
340 async fn create_impl<T: Serialize>(&self, record: T) -> Result<Vec<String>, Error> {
341 let response = self
342 .client
343 .fetch(
344 &format!("/{RECORD_API}/{name}", name = self.name),
345 Method::POST,
346 Some(&record),
347 None,
348 )
349 .await?;
350
351 #[derive(Deserialize)]
352 pub struct RecordIdResponse {
353 pub ids: Vec<String>,
354 }
355
356 return Ok(json::<RecordIdResponse>(response).await?.ids);
357 }
358
359 pub async fn update<'a, T: Serialize>(
360 &self,
361 id: impl RecordId<'a>,
362 record: T,
363 ) -> Result<(), Error> {
364 self
365 .client
366 .fetch(
367 &format!(
368 "/{RECORD_API}/{name}/{id}",
369 name = self.name,
370 id = id.serialized_id()
371 ),
372 Method::PATCH,
373 Some(&record),
374 None,
375 )
376 .await?;
377
378 return Ok(());
379 }
380
381 pub async fn delete<'a>(&self, id: impl RecordId<'a>) -> Result<(), Error> {
382 self
383 .client
384 .fetch(
385 &format!(
386 "/{RECORD_API}/{name}/{id}",
387 name = self.name,
388 id = id.serialized_id()
389 ),
390 Method::DELETE,
391 None::<&()>,
392 None,
393 )
394 .await?;
395
396 return Ok(());
397 }
398
399 pub async fn subscribe<'a>(
400 &self,
401 id: impl RecordId<'a>,
402 ) -> Result<impl Stream<Item = DbEvent>, Error> {
403 let response = self
405 .client
406 .fetch(
407 &format!(
408 "/{RECORD_API}/{name}/subscribe/{id}",
409 name = self.name,
410 id = id.serialized_id()
411 ),
412 Method::GET,
413 None::<&()>,
414 None,
415 )
416 .await?;
417
418 return Ok(
419 response
420 .bytes_stream()
421 .eventsource()
422 .filter_map(|event_or| async {
423 if let Ok(event) = event_or {
424 if let Ok(db_event) = serde_json::from_str::<DbEvent>(&event.data) {
425 return Some(db_event);
426 }
427 }
428 return None;
429 }),
430 );
431 }
432}
433
434#[derive(Clone, Debug)]
435struct TokenState {
436 state: Option<(Tokens, JwtTokenClaims)>,
437 headers: HeaderMap,
438}
439
440impl TokenState {
441 fn build(tokens: Option<&Tokens>) -> TokenState {
442 let headers = build_headers(tokens);
443 return TokenState {
444 state: tokens.and_then(|tokens| {
445 let Ok(jwt_token) = decode_auth_token::<JwtTokenClaims>(&tokens.auth_token) else {
446 error!("Failed to decode auth token.");
447 return None;
448 };
449 return Some((tokens.clone(), jwt_token));
450 }),
451 headers,
452 };
453 }
454}
455
456struct ClientState {
457 client: ThinClient,
458 site: String,
459 tokens: RwLock<TokenState>,
460}
461
462impl ClientState {
463 #[inline]
464 async fn fetch<T: Serialize>(
465 &self,
466 path: &str,
467 method: Method,
468 body: Option<&T>,
469 query_params: Option<&[(Cow<'static, str>, Cow<'static, str>)]>,
470 ) -> Result<reqwest::Response, Error> {
471 let (mut headers, refresh_token) = self.extract_headers_and_refresh_token_if_exp();
472 if let Some(refresh_token) = refresh_token {
473 let new_tokens = ClientState::refresh_tokens(&self.client, headers, refresh_token).await?;
474
475 headers = new_tokens.headers.clone();
476 *self.tokens.write() = new_tokens;
477 }
478
479 return Ok(
480 self
481 .client
482 .fetch(path, headers, method, body, query_params)
483 .await?
484 .error_for_status()?,
485 );
486 }
487
488 #[inline]
489 fn extract_headers_and_refresh_token_if_exp(&self) -> (HeaderMap, Option<String>) {
490 #[inline]
491 fn should_refresh(jwt: &JwtTokenClaims) -> bool {
492 return jwt.exp - 60 < now() as i64;
493 }
494
495 let tokens = self.tokens.read();
496 let headers = tokens.headers.clone();
497 return match tokens.state {
498 Some(ref state) if should_refresh(&state.1) => (headers, state.0.refresh_token.clone()),
499 _ => (headers, None),
500 };
501 }
502
503 fn extract_headers_refresh_token(&self) -> Option<(HeaderMap, String)> {
504 let tokens = self.tokens.read();
505 let state = tokens.state.as_ref()?;
506
507 if let Some(ref refresh_token) = state.0.refresh_token {
508 return Some((tokens.headers.clone(), refresh_token.clone()));
509 }
510 return None;
511 }
512
513 async fn refresh_tokens(
514 client: &ThinClient,
515 headers: HeaderMap,
516 refresh_token: String,
517 ) -> Result<TokenState, Error> {
518 #[derive(Serialize)]
519 struct RefreshRequest<'a> {
520 refresh_token: &'a str,
521 }
522
523 let response = client
524 .fetch(
525 &format!("/{AUTH_API}/refresh"),
526 headers,
527 Method::POST,
528 Some(&RefreshRequest {
529 refresh_token: &refresh_token,
530 }),
531 None,
532 )
533 .await?;
534
535 #[derive(Deserialize)]
536 struct RefreshResponse {
537 auth_token: String,
538 csrf_token: Option<String>,
539 }
540
541 let refresh_response: RefreshResponse = json(response).await?;
542 return Ok(TokenState::build(Some(&Tokens {
543 auth_token: refresh_response.auth_token,
544 refresh_token: Some(refresh_token),
545 csrf_token: refresh_response.csrf_token,
546 })));
547 }
548}
549
550#[derive(Clone)]
551pub struct Client {
552 state: Arc<ClientState>,
553}
554
555impl Client {
556 pub fn new(site: &str, tokens: Option<Tokens>) -> Result<Client, Error> {
557 return Ok(Client {
558 state: Arc::new(ClientState {
559 client: ThinClient {
560 client: reqwest::Client::new(),
561 url: url::Url::parse(site).map_err(Error::InvalidUrl)?,
562 },
563 site: site.to_string(),
564 tokens: RwLock::new(TokenState::build(tokens.as_ref())),
565 }),
566 });
567 }
568
569 pub fn site(&self) -> String {
570 return self.state.site.clone();
571 }
572
573 pub fn tokens(&self) -> Option<Tokens> {
574 return self.state.tokens.read().state.as_ref().map(|x| x.0.clone());
575 }
576
577 pub fn user(&self) -> Option<User> {
578 if let Some(state) = &self.state.tokens.read().state {
579 return Some(User {
580 sub: state.1.sub.clone(),
581 email: state.1.email.clone(),
582 });
583 }
584 return None;
585 }
586
587 pub fn records(&self, api_name: &str) -> RecordApi {
588 return RecordApi {
589 client: self.state.clone(),
590 name: api_name.to_string(),
591 };
592 }
593
594 pub async fn refresh(&self) -> Result<(), Error> {
595 let Some((headers, refresh_token)) = self.state.extract_headers_refresh_token() else {
596 return Err(Error::MissingRefreshToken);
597 };
598
599 let new_tokens =
600 ClientState::refresh_tokens(&self.state.client, headers, refresh_token).await?;
601
602 *self.state.tokens.write() = new_tokens;
603 return Ok(());
604 }
605
606 pub async fn login(&self, email: &str, password: &str) -> Result<Tokens, Error> {
607 #[derive(Serialize)]
608 struct Credentials<'a> {
609 email: &'a str,
610 password: &'a str,
611 }
612
613 let response = self
614 .state
615 .fetch(
616 &format!("/{AUTH_API}/login"),
617 Method::POST,
618 Some(&Credentials { email, password }),
619 None,
620 )
621 .await?;
622
623 let tokens: Tokens = json(response).await?;
624 self.update_tokens(Some(&tokens));
625 return Ok(tokens);
626 }
627
628 pub async fn logout(&self) -> Result<(), Error> {
629 #[derive(Serialize)]
630 struct LogoutRequest {
631 refresh_token: String,
632 }
633
634 let response_or = match self.state.extract_headers_refresh_token() {
635 Some((_headers, refresh_token)) => {
636 self
637 .state
638 .fetch(
639 &format!("/{AUTH_API}/logout"),
640 Method::POST,
641 Some(&LogoutRequest { refresh_token }),
642 None,
643 )
644 .await
645 }
646 _ => {
647 self
648 .state
649 .fetch(
650 &format!("/{AUTH_API}/logout"),
651 Method::GET,
652 None::<&()>,
653 None,
654 )
655 .await
656 }
657 };
658
659 self.update_tokens(None);
660
661 return response_or.map(|_| ());
662 }
663
664 fn update_tokens(&self, tokens: Option<&Tokens>) -> TokenState {
665 let state = TokenState::build(tokens);
666
667 *self.state.tokens.write() = state.clone();
668 if let Some(ref s) = state.state {
671 let now = now();
672 if s.1.exp < now as i64 {
673 warn!("Token expired");
674 }
675 }
676
677 return state;
678 }
679}
680
681fn build_headers(tokens: Option<&Tokens>) -> HeaderMap {
682 let mut base = HeaderMap::with_capacity(5);
683 base.insert(
684 header::CONTENT_TYPE,
685 HeaderValue::from_static("application/json"),
686 );
687
688 if let Some(tokens) = tokens {
689 if let Ok(value) = HeaderValue::from_str(&format!("Bearer {}", tokens.auth_token)) {
690 base.insert(header::AUTHORIZATION, value);
691 } else {
692 error!("Failed to build bearer token.");
693 }
694
695 if let Some(ref refresh) = tokens.refresh_token {
696 if let Ok(value) = HeaderValue::from_str(refresh) {
697 base.insert("Refresh-Token", value);
698 } else {
699 error!("Failed to build refresh token header.");
700 }
701 }
702
703 if let Some(ref csrf) = tokens.csrf_token {
704 if let Ok(value) = HeaderValue::from_str(csrf) {
705 base.insert("CSRF-Token", value);
706 } else {
707 error!("Failed to build refresh token header.");
708 }
709 }
710 }
711
712 return base;
713}
714
/// Current wall-clock time in whole seconds since the Unix epoch.
///
/// # Panics
/// If the system clock reports a time before the Unix epoch.
fn now() -> u64 {
  use std::time::{SystemTime, UNIX_EPOCH};

  let since_epoch = SystemTime::now()
    .duration_since(UNIX_EPOCH)
    .expect("Duration since epoch");
  since_epoch.as_secs()
}
721
722#[inline]
723async fn json<T: DeserializeOwned>(resp: reqwest::Response) -> Result<T, Error> {
724 let full = resp.bytes().await?;
725 return serde_json::from_slice(&full).map_err(Error::RecordSerialization);
726}
727
/// Path prefix of the auth API; request paths are built as `/{AUTH_API}/...`.
const AUTH_API: &str = "api/auth/v1";
/// Path prefix of the record API; request paths are built as `/{RECORD_API}/{name}/...`.
const RECORD_API: &str = "api/records/v1";
730
#[cfg(test)]
mod tests {
  use super::*;

  #[tokio::test]
  async fn is_send_test() {
    let client = Client::new("http://127.0.0.1:4000", None).unwrap();
    let api = client.records("simple_strict_table");

    // `tokio::spawn` requires a `Send` future, so this only compiles if `RecordApi::read`'s
    // future is `Send`.
    for _ in 0..2 {
      let task_api = api.clone();
      let handle = tokio::spawn(async move {
        let response = task_api.read::<serde_json::Value>(0).await;
        assert!(response.is_err());
      });
      handle.await.unwrap();
    }
  }
}