1pub use async_trait::async_trait;
2pub use sqlx;
3pub use sqlx_models_derive::model;
4use std::ops::DerefMut;
5
6pub trait SqlxModel: Send + Sync + Sized {
7 type State: Send + Sync;
8 type Id: Send + Sync;
9 type ModelHub: SqlxModelHub<Self>;
10 type SelectModelHub: SqlxSelectModelHub<Self>;
11 type SelectModel: std::default::Default + Send;
12 type ModelOrderBy: std::fmt::Debug + Send;
13}
14
15#[async_trait]
16pub trait SqlxModelHub<Model: SqlxModel>: Send + Sync + Sized {
17 fn from_state(state: Model::State) -> Self;
18 fn select(&self) -> Model::SelectModelHub;
19 async fn find(&self, id: &Model::Id) -> sqlx::Result<Model>;
20 async fn find_optional(&self, id: &Model::Id) -> sqlx::Result<Option<Model>>;
21}
22
23#[async_trait]
24pub trait SqlxSelectModelHub<Model: SqlxModel>: Send + Sync + Sized {
25 fn from_state(state: Model::State) -> Self;
26 fn order_by(self, val: Model::ModelOrderBy) -> Self;
27 fn maybe_order_by(self, val: Option<Model::ModelOrderBy>) -> Self;
28 fn desc(self, val: bool) -> Self;
29 fn limit(self, val: i64) -> Self;
30 fn offset(self, val: i64) -> Self;
31 fn use_struct(self, value: Model::SelectModel) -> Self;
32 async fn all(&self) -> sqlx::Result<Vec<Model>>;
33 async fn count(&self) -> sqlx::Result<i64>;
34 async fn one(&self) -> sqlx::Result<Model>;
35 async fn optional(&self) -> sqlx::Result<Option<Model>>;
36}
37
38pub use sqlx::{postgres::*, query::*, Error, Postgres, Transaction};
39
40pub type PgTx =
41 Option<std::sync::Arc<futures_util::lock::Mutex<Option<Transaction<'static, Postgres>>>>>;
42pub type PgQuery<'q> = Query<'q, Postgres, PgArguments>;
43pub type PgMap<'q, O> = Map<'q, Postgres, O, PgArguments>;
44pub type PgQueryScalar<'q, O> = QueryScalar<'q, Postgres, O, PgArguments>;
45
/// Handle to a Postgres database: a connection pool plus an optional shared
/// transaction.
///
/// Clones share the same transaction slot, because `PgTx` stores it behind an `Arc`.
#[derive(Clone, Debug)]
pub struct Db {
    /// Pool used whenever no live transaction is attached to this handle.
    pub pool: PgPool,
    /// `Some` only for handles produced by `Db::transaction`; see `PgTx` for the
    /// meaning of the inner `Option`.
    pub transaction: PgTx,
}
51
/// Runs `$query.$method(executor)` on the appropriate executor and evaluates to
/// the resulting `sqlx::Result`.
///
/// * If `$self` carries a transaction slot that still holds a transaction, the query
///   runs inside that transaction. The slot's lock is held until the query completes,
///   so clones sharing the slot are serialized.
/// * Otherwise (no slot, or the slot was emptied by `commit`) it runs on `$self.pool`,
///   after any lock has been released.
///
/// The macro expands to a plain expression. Earlier versions hid a `return` inside it,
/// which was only correct when the macro was the function's tail expression.
///
/// Requires `std::ops::DerefMut` in scope at the expansion site.
macro_rules! choose_executor {
    ($self:ident, $query:ident, $method:ident) => {{
        match $self.transaction.as_ref() {
            Some(shared) => {
                let mut guard = shared.lock().await;
                match &mut *guard {
                    // Live transaction: run inside it while holding the lock.
                    Some(tx) => $query.$method(tx.deref_mut()).await,
                    // Slot already emptied: release the lock before falling back
                    // to the pool, as the original implementation did.
                    None => {
                        drop(guard);
                        $query.$method(&$self.pool).await
                    }
                }
            }
            None => $query.$method(&$self.pool).await,
        }
    }};
}
63
/// Expands to a `Db` method named `$name` that runs a row-mapped query (`PgMap`)
/// through the sqlx method of the same name, on the live transaction if there
/// is one and on the pool otherwise.
///
/// `$ret` is the method's success type and may mention the row type `T`. `T` is
/// declared by this expansion, which is why that generic must keep its name.
macro_rules! define_query_method {
    ($name:ident, $ret:ty) => {
        pub async fn $name<'q, T, F>(&self, query: PgMap<'q, F>) -> sqlx::Result<$ret>
        where
            T: Send + Unpin,
            F: Send + FnMut(PgRow) -> Result<T, Error>,
        {
            choose_executor!(self, query, $name)
        }
    };
}
75
/// Expands to a `Db` method named `$name` that runs a scalar query
/// (`PgQueryScalar`) through the sqlx method `$delegate`, on the live transaction
/// if there is one and on the pool otherwise.
///
/// `$ret` is the method's success type and may mention the scalar type `T`, which
/// is declared by this expansion.
macro_rules! define_query_scalar_method {
    ($name:ident, $delegate:ident, $ret:ty) => {
        pub async fn $name<'q, T>(&self, query: PgQueryScalar<'q, T>) -> sqlx::Result<$ret>
        where
            T: Send + Unpin,
            (T,): for<'r> sqlx::FromRow<'r, PgRow>,
        {
            choose_executor!(self, query, $delegate)
        }
    };
}
87
88impl Db {
89 pub async fn connect(connection_string: &str) -> sqlx::Result<Self> {
90 let pool = PgPoolOptions::new().connect(connection_string).await?;
91 Ok(Self {
92 pool,
93 transaction: None,
94 })
95 }
96
97 pub async fn transaction(&self) -> sqlx::Result<Self> {
98 let tx = self.pool.begin().await?;
99 Ok(Self {
100 pool: self.pool.clone(),
101 transaction: Some(std::sync::Arc::new(futures_util::lock::Mutex::new(Some(
102 tx,
103 )))),
104 })
105 }
106
107 pub async fn execute<'a>(&self, query: PgQuery<'a>) -> sqlx::Result<PgQueryResult> {
108 choose_executor!(self, query, execute)
109 }
110
111 define_query_method! {fetch_one, T}
112 define_query_method! {fetch_all, Vec<T>}
113 define_query_method! {fetch_optional, Option<T>}
114
115 define_query_scalar_method! {fetch_one_scalar, fetch_one, T}
116 define_query_scalar_method! {fetch_all_scalar, fetch_all, Vec<T>}
117 define_query_scalar_method! {fetch_optional_scalar, fetch_optional, Option<T>}
118
119 pub async fn commit(&self) -> sqlx::Result<()> {
120 if let Some(arc) = self.transaction.as_ref() {
121 let mut mutex = arc.lock().await;
122 let maybe_tx = (*mutex).take();
123 if let Some(tx) = maybe_tx {
124 tx.commit().await?;
125 }
126 }
127 Ok(())
128 }
129}