1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
pub use async_trait::async_trait;
pub use sqlx;
pub use sqlx_models_derive::model;
use std::ops::DerefMut;

/// Describes a database model together with the companion types used to query it.
///
/// Each associated type names one piece of the query machinery for the model:
/// the state handed to its hubs, its identifier type, the hub used for lookups
/// by id, the builder-style hub used for selects, the value accepted by that
/// builder's `use_struct`, and the orderings it accepts.
pub trait SqlxModel: Sized + Send + Sync {
    /// State consumed by the hubs' `from_state` constructors.
    type State: Send + Sync;
    /// Identifier type accepted by `SqlxModelHub::find` / `find_optional`.
    type Id: Send + Sync;
    /// Entry-point hub for fetching this model.
    type ModelHub: SqlxModelHub<Self>;
    /// Query-builder hub returned by `SqlxModelHub::select`.
    type SelectModelHub: SqlxSelectModelHub<Self>;
    /// Value passed to `SqlxSelectModelHub::use_struct`; must have a `Default`.
    type SelectModel: Default + Send;
    /// Ordering accepted by `SqlxSelectModelHub::order_by`.
    type ModelOrderBy: std::fmt::Debug + Send;
}

/// Entry point for loading instances of model `M`.
///
/// A hub is built from the model's state and can either look up a single
/// instance by id or hand off to the model's select hub for broader queries.
#[async_trait]
pub trait SqlxModelHub<M: SqlxModel>: Sized + Send + Sync {
    /// Builds the hub from the model's state.
    fn from_state(state: M::State) -> Self;

    /// Starts a select query over `M`, returning the model's select hub.
    fn select(&self) -> M::SelectModelHub;

    /// Looks up the instance identified by `id`.
    ///
    /// NOTE(review): by its `Option` return, `Ok(None)` presumably means no row
    /// matched — confirm against the generated implementations.
    async fn find_optional(&self, id: &M::Id) -> sqlx::Result<Option<M>>;

    /// Looks up the instance identified by `id`.
    ///
    /// NOTE(review): presumably returns an error when no row matches — confirm.
    async fn find(&self, id: &M::Id) -> sqlx::Result<M>;
}

/// Builder-style query over rows of model `M`.
///
/// Configuration methods consume the builder and return it, so they can be
/// chained; the async terminal methods borrow it and run the query. Per-method
/// descriptions below are inferred from names and signatures — the generated
/// implementations are authoritative.
#[async_trait]
pub trait SqlxSelectModelHub<M: SqlxModel>: Sized + Send + Sync {
    /// Builds an unconfigured selector from the model's state.
    fn from_state(state: M::State) -> Self;

    /// Applies the select description in `value`.
    fn use_struct(self, value: M::SelectModel) -> Self;

    /// Sets the ordering.
    fn order_by(self, val: M::ModelOrderBy) -> Self;
    /// Optional variant of `order_by`; presumably a no-op for `None`.
    fn maybe_order_by(self, val: Option<M::ModelOrderBy>) -> Self;
    /// Chooses descending (`true`) or ascending ordering.
    fn desc(self, val: bool) -> Self;

    /// Limits how many rows are returned.
    fn limit(self, val: i64) -> Self;
    /// Skips the first `val` rows.
    fn offset(self, val: i64) -> Self;

    /// Runs the query and returns every matching row.
    async fn all(&self) -> sqlx::Result<Vec<M>>;
    /// Runs the query and returns a single row.
    async fn one(&self) -> sqlx::Result<M>;
    /// Runs the query and returns a row if one matches.
    async fn optional(&self) -> sqlx::Result<Option<M>>;
    /// Counts the rows matched by the query.
    async fn count(&self) -> sqlx::Result<i64>;
}

pub use sqlx::{postgres::*, query::*, Error, Postgres, Transaction};

/// Transaction slot carried by [`Db`].
///
/// The outer `Option` is `None` for handles that only use the pool. The `Arc`
/// lets every clone of a `Db` share one transaction, and the async `Mutex`
/// serialises access to it. The inner `Option` is emptied when the transaction
/// is taken out to finish it (e.g. by `Db::commit`).
pub type PgTx =
    Option<std::sync::Arc<futures_util::lock::Mutex<Option<Transaction<'static, Postgres>>>>>;
/// Postgres query with bound arguments.
pub type PgQuery<'args> = Query<'args, Postgres, PgArguments>;
/// Postgres query whose rows are converted by the mapper `F`.
pub type PgMap<'args, F> = Map<'args, Postgres, F, PgArguments>;
/// Postgres query whose rows are decoded into the single-column value `O`.
pub type PgQueryScalar<'args, O> = QueryScalar<'args, Postgres, O, PgArguments>;

/// Handle to a Postgres connection pool, optionally bound to a transaction.
///
/// Cloning copies the pool handle and the `Arc` inside [`PgTx`], so every
/// clone shares the same transaction slot.
#[derive(Debug, Clone)]
pub struct Db {
    /// Pool used whenever no open transaction is attached.
    pub pool: PgPool,
    /// Shared transaction slot; queries are routed through it while it holds a transaction.
    pub transaction: PgTx,
}

/// Runs `$query.$method(...)` on the appropriate executor and evaluates to its result.
///
/// If `$self` carries a transaction that is still open, the query runs on that
/// transaction. The *enclosing function* then returns the result immediately,
/// because the expansion uses `return`. Otherwise — no transaction, or one
/// already taken out of the slot (e.g. by `commit`) — the query runs on
/// `$self.pool`.
///
/// Must be expanded inside an `async fn` whose return type matches the query
/// result, with `std::ops::DerefMut` in scope.
macro_rules! choose_executor {
    ($self:ident, $query:ident, $method:ident) => {{
        if let Some(shared_tx) = &$self.transaction {
            // The guard is scoped to this block:
            // - With an open transaction, the lock is held for the whole query,
            //   so clones cannot interleave on the same connection.
            // - Otherwise, the lock is released before the pool fallback below.
            let mut guard = shared_tx.lock().await;
            if let Some(open_tx) = guard.as_mut() {
                return $query.$method(open_tx.deref_mut()).await;
            }
        }
        $query.$method(&$self.pool).await
    }};
}

/// Generates a `Db` method named `$method` that runs a row-mapping query
/// ([`PgMap`]) and yields `$return`.
///
/// The query goes through `choose_executor!`: it runs on the open transaction
/// if there is one, otherwise on the pool.
macro_rules! define_query_method {
    ($method:ident, $return:ty) => {
        /// Runs a row-mapping query on the open transaction if there is one,
        /// otherwise on the pool.
        pub async fn $method<'q, T, F>(&self, query: PgMap<'q, F>) -> sqlx::Result<$return>
        where
            T: Send + Unpin,
            F: FnMut(PgRow) -> sqlx::Result<T> + Send,
        {
            choose_executor!(self, query, $method)
        }
    };
}

/// Generates a `Db` method named `$method` that runs a single-column query
/// ([`PgQueryScalar`]) by delegating to its `$inner_method`, yielding `$return`.
///
/// The query goes through `choose_executor!`: it runs on the open transaction
/// if there is one, otherwise on the pool.
macro_rules! define_query_scalar_method {
    ($method:ident, $inner_method:ident, $return:ty) => {
        /// Runs a single-column query on the open transaction if there is one,
        /// otherwise on the pool.
        pub async fn $method<'q, T>(&self, query: PgQueryScalar<'q, T>) -> sqlx::Result<$return>
        where
            T: Send + Unpin,
            (T,): for<'row> sqlx::FromRow<'row, PgRow>,
        {
            choose_executor!(self, query, $inner_method)
        }
    };
}

impl Db {
    /// Creates a handle backed by a new connection pool for `connection_string`.
    ///
    /// The handle starts outside any transaction, so queries run directly on
    /// pooled connections until [`Db::transaction`] is used.
    ///
    /// # Errors
    ///
    /// Returns the error from `PgPoolOptions::connect` if the pool cannot be created.
    pub async fn connect(connection_string: &str) -> sqlx::Result<Self> {
        let pool = PgPoolOptions::new().connect(connection_string).await?;
        Ok(Self {
            pool,
            transaction: None,
        })
    }

    /// Begins a transaction and returns a new handle that routes queries through it.
    ///
    /// The receiver is left untouched. The transaction is stored behind an
    /// `Arc`, so every clone of the returned handle shares it.
    ///
    /// Calling this on a handle that already carries a transaction begins a
    /// *separate* transaction from the pool; it does not nest.
    ///
    /// Finish the transaction with [`Db::commit`] or [`Db::rollback`]. If neither
    /// is called, the `Transaction` is only dropped (and rolled back by sqlx)
    /// once the last clone sharing it is dropped.
    ///
    /// # Errors
    ///
    /// Returns the error from `PgPool::begin` if the transaction cannot be started.
    pub async fn transaction(&self) -> sqlx::Result<Self> {
        let tx = self.pool.begin().await?;
        Ok(Self {
            pool: self.pool.clone(),
            transaction: Some(std::sync::Arc::new(futures_util::lock::Mutex::new(Some(tx)))),
        })
    }

    /// Executes a statement without fetching rows, on the open transaction if
    /// there is one and on the pool otherwise.
    pub async fn execute<'a>(&self, query: PgQuery<'a>) -> sqlx::Result<PgQueryResult> {
        choose_executor!(self, query, execute)
    }

    // Row-mapping fetchers routed through the open transaction or the pool.
    define_query_method! {fetch_one, T}
    define_query_method! {fetch_all, Vec<T>}
    define_query_method! {fetch_optional, Option<T>}

    // Single-column fetchers, delegating to the matching `QueryScalar` method.
    define_query_scalar_method! {fetch_one_scalar, fetch_one, T}
    define_query_scalar_method! {fetch_all_scalar, fetch_all, Vec<T>}
    define_query_scalar_method! {fetch_optional_scalar, fetch_optional, Option<T>}

    /// Commits the open transaction, if any.
    ///
    /// The transaction is taken out of the shared slot before committing. After
    /// this call — whether it succeeds or fails — this handle and all of its
    /// clones run later queries directly on the pool. On a handle with no open
    /// transaction this is a no-op.
    ///
    /// # Errors
    ///
    /// Returns the error reported by `Transaction::commit`.
    pub async fn commit(&self) -> sqlx::Result<()> {
        if let Some(shared) = self.transaction.as_ref() {
            // Hold the lock until the commit completes, so queries queued on
            // other clones run after it rather than concurrently with it.
            let mut guard = shared.lock().await;
            if let Some(tx) = guard.take() {
                tx.commit().await?;
            }
        }
        Ok(())
    }

    /// Rolls back the open transaction, if any.
    ///
    /// Mirrors [`Db::commit`]: the transaction is taken out of the shared slot,
    /// so afterwards this handle and all of its clones run later queries on the
    /// pool. This lets an error path release the transaction immediately,
    /// instead of waiting for every clone to be dropped. On a handle with no
    /// open transaction this is a no-op.
    ///
    /// # Errors
    ///
    /// Returns the error reported by `Transaction::rollback`.
    pub async fn rollback(&self) -> sqlx::Result<()> {
        if let Some(shared) = self.transaction.as_ref() {
            // Same locking discipline as `commit`.
            let mut guard = shared.lock().await;
            if let Some(tx) = guard.take() {
                tx.rollback().await?;
            }
        }
        Ok(())
    }
}