1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
pub use sqlx_models_derive::model;
pub use async_trait::async_trait;
pub use sqlx;

/// A database-backed model together with the companion types used to query it.
///
/// NOTE(review): implementations are presumably generated by the re-exported `model`
/// attribute macro — confirm.
pub trait SqlxModel
where
  Self: Send + Sync + Sized,
{
  /// State value handed to `from_state` on both hub types.
  type State: Send + Sync;
  /// Key type accepted by `SqlxModelHub::find` / `find_optional`.
  type Id: Send + Sync;
  /// Entry-point hub for this model.
  type ModelHub: SqlxModelHub<Self>;
  /// Select-query builder returned by `SqlxModelHub::select`.
  type SelectModelHub: SqlxSelectModelHub<Self>;
  /// Value accepted by `SqlxSelectModelHub::use_struct`; must implement `Default`.
  type SelectModel: Default + Send;
  /// Ordering criterion accepted by `SqlxSelectModelHub::order_by`.
  type ModelOrderBy: std::fmt::Debug + Send;
}

/// Entry point for querying `Model`.
///
/// A hub is built from the model's state. It can look rows up by id or start a
/// select query through `select`.
#[async_trait]
pub trait SqlxModelHub<Model>: Send + Sync + Sized
where
  Model: SqlxModel,
{
  /// Builds the hub from the model's shared state.
  fn from_state(state: Model::State) -> Self;

  /// Starts a new select-query builder for `Model`.
  fn select(&self) -> Model::SelectModelHub;

  /// Loads the model identified by `id`.
  ///
  /// NOTE(review): presumably errors when no row matches (contrast with
  /// `find_optional`) — confirm in implementations.
  async fn find(&self, id: &Model::Id) -> sqlx::Result<Model>;

  /// Loads the model identified by `id`, returning `None` when it does not exist.
  async fn find_optional(&self, id: &Model::Id) -> sqlx::Result<Option<Model>>;
}

/// Chainable select-query builder for `Model`.
///
/// Configuration methods take `self` by value and return the updated builder. The async
/// terminal methods borrow `&self` and run the configured query.
#[async_trait]
pub trait SqlxSelectModelHub<Model>: Send + Sync + Sized
where
  Model: SqlxModel,
{
  /// Builds a fresh builder from the model's shared state.
  fn from_state(state: Model::State) -> Self;

  // ---- configuration (consumes and returns the builder) ----

  /// Sets the ordering criterion.
  fn order_by(self, val: Model::ModelOrderBy) -> Self;
  /// Optional form of `order_by`.
  ///
  /// NOTE(review): presumably `None` leaves ordering unchanged — confirm.
  fn maybe_order_by(self, val: Option<Model::ModelOrderBy>) -> Self;
  /// Sets the descending-order flag.
  ///
  /// NOTE(review): presumably applies to the `order_by` criterion.
  fn desc(self, val: bool) -> Self;
  /// Sets the row limit.
  ///
  /// NOTE(review): presumably maps to SQL `LIMIT`.
  fn limit(self, val: i64) -> Self;
  /// Sets the row offset.
  ///
  /// NOTE(review): presumably maps to SQL `OFFSET`.
  fn offset(self, val: i64) -> Self;
  /// Applies settings/filters carried by a `Model::SelectModel` value.
  fn use_struct(self, value: Model::SelectModel) -> Self;

  // ---- terminal operations (borrow the builder and run the query) ----

  /// Fetches every matching model.
  async fn all(&self) -> sqlx::Result<Vec<Model>>;
  /// Counts the matching rows.
  async fn count(&self) -> sqlx::Result<i64>;
  /// Fetches exactly one model.
  ///
  /// NOTE(review): presumably errors when nothing matches.
  async fn one(&self) -> sqlx::Result<Model>;
  /// Fetches at most one model, returning `None` when nothing matches.
  async fn optional(&self) -> sqlx::Result<Option<Model>>;
}

pub use sqlx::{Postgres, Transaction, query::*, postgres::*, Error};

pub type PgTx = Option<std::sync::Arc<futures_util::lock::Mutex<Option<Transaction<'static, Postgres>>>>>;
pub type PgQuery<'q> = Query<'q, Postgres, PgArguments>;
pub type PgMap<'q, O> = Map<'q, Postgres, O, PgArguments>;
pub type PgQueryScalar<'q, O> = QueryScalar<'q, Postgres, O, PgArguments>;

#[derive(Clone, Debug)]
pub struct Db {
  pub pool: PgPool,
  pub transaction: PgTx,
}

// Expands to an expression that awaits `$query.$method(executor)`, picking the executor:
//
// * If `$self.transaction` is `Some` and its shared slot still holds a transaction, the
//   query runs on that transaction. The async mutex guard stays alive while the query is
//   awaited, so queries through the same shared transaction are serialized.
// * Otherwise the query runs on `$self.pool`. This covers both "no transaction" and "slot
//   already emptied", e.g. by `Db::commit`. Any guard has been dropped by this point,
//   because it is scoped to the outer `if let`.
//
// NOTE: the transaction branch uses `return`, which exits the *enclosing function*, not
// just this expression. The macro is therefore meant to be the tail expression of an
// `async fn` whose return type is exactly the query method's output, as every visible
// caller uses it.
macro_rules! choose_executor {
  ($self:ident, $query:ident, $method:ident) => ({
    if let Some(a) = $self.transaction.as_ref() {
      // Held until this block ends — i.e. across the awaited query below.
      let mut mutex = a.lock().await;
      if let Some(tx) = &mut *mutex {
        return $query.$method(&mut *tx).await;
      }
    }
    // No live transaction: run directly on the pool.
    $query.$method(&$self.pool).await
  })
}

// Generates an inherent method named `$method` that runs a row-mapped query (`PgMap`).
//
// - It forwards to the sqlx method of the same name.
// - Execution uses the live transaction when there is one and the pool otherwise
//   (see `choose_executor!`).
// - `$output` is the success type and may mention the generic `T` declared by the
//   expansion (e.g. `Vec<T>`). That is why `T` must keep its name here.
macro_rules! define_query_method {
  ($method:ident, $output:ty) => {
    /// Runs a row-mapped query, inside the shared transaction when one is live.
    pub async fn $method<'a, T, F>(&self, query: PgMap<'a, F>) -> sqlx::Result<$output>
    where
      T: Unpin + Send,
      F: FnMut(PgRow) -> Result<T, sqlx::Error> + Send,
    {
      choose_executor!(self, query, $method)
    }
  };
}

// Generates an inherent method named `$method` that runs a single-value query
// (`PgQueryScalar`).
//
// - It forwards to the sqlx method `$delegate`.
// - Execution uses the live transaction when there is one and the pool otherwise
//   (see `choose_executor!`).
// - `$output` may mention the generic `T` declared by the expansion (e.g. `Option<T>`).
macro_rules! define_query_scalar_method {
  ($method:ident, $delegate:ident, $output:ty) => {
    /// Runs a single-value query, inside the shared transaction when one is live.
    pub async fn $method<'a, T>(&self, query: PgQueryScalar<'a, T>) -> sqlx::Result<$output>
    where
      T: Unpin + Send,
      (T,): for<'r> sqlx::FromRow<'r, PgRow>,
    {
      choose_executor!(self, query, $delegate)
    }
  };
}

impl Db {
  /// Opens a Postgres connection pool for `connection_string`.
  ///
  /// The returned handle has no transaction, so every statement runs directly on the pool.
  ///
  /// # Errors
  /// Propagates any error from establishing the pool.
  pub async fn connect(connection_string: &str) -> sqlx::Result<Self> {
    PgPoolOptions::new()
      .connect(connection_string)
      .await
      .map(|pool| Self { pool, transaction: None })
  }

  /// Begins a transaction on a connection from `self.pool` and returns a handle bound to it.
  ///
  /// The returned `Db` and its clones share the transaction through an `Arc`. Statements
  /// issued through them run inside it until `commit` takes it out of the shared slot.
  ///
  /// NOTE(review): this always calls `self.pool.begin()` and ignores `self.transaction`.
  /// Calling it on a handle that is already in a transaction therefore starts a second,
  /// independent transaction, not a nested one or a savepoint.
  ///
  /// # Errors
  /// Propagates any error from beginning the transaction.
  pub async fn transaction(&self) -> sqlx::Result<Self> {
    let begun = self.pool.begin().await?;
    let slot = futures_util::lock::Mutex::new(Some(begun));
    Ok(Self {
      pool: self.pool.clone(),
      transaction: Some(std::sync::Arc::new(slot)),
    })
  }

  /// Executes `query` and returns its `PgQueryResult`.
  ///
  /// The query runs inside the shared transaction when one is live.
  pub async fn execute<'a>(&self, query: PgQuery<'a>) -> sqlx::Result<PgQueryResult> {
    choose_executor!(self, query, execute)
  }

  // Row-mapped fetchers (`PgMap`) next to their single-value (`PgQueryScalar`) twins.
  // Each one delegates to the sqlx method with the same base name.
  define_query_method!(fetch_one, T);
  define_query_scalar_method!(fetch_one_scalar, fetch_one, T);
  define_query_method!(fetch_all, Vec<T>);
  define_query_scalar_method!(fetch_all_scalar, fetch_all, Vec<T>);
  define_query_method!(fetch_optional, Option<T>);
  define_query_scalar_method!(fetch_optional_scalar, fetch_optional, Option<T>);

  /// Commits the shared transaction, if this handle has one that has not been taken yet.
  ///
  /// The transaction is removed from the shared slot, so this handle and its clones fall
  /// back to the pool afterwards. This is a no-op returning `Ok(())` when the handle has
  /// no transaction, or when the transaction was already taken.
  ///
  /// # Errors
  /// Propagates any error from committing.
  pub async fn commit(&self) -> sqlx::Result<()> {
    let shared = match &self.transaction {
      Some(shared) => shared,
      None => return Ok(()),
    };
    // The guard stays held while committing, so concurrent queries wait for the outcome.
    let mut slot = shared.lock().await;
    match slot.take() {
      Some(tx) => tx.commit().await,
      None => Ok(()),
    }
  }
}