sfo_sql/
db_helper.rs

1use std::marker::PhantomData;
2use std::ops::Deref;
3use sqlx::{Transaction, Connection, Executor, Database};
4use sqlx::pool::PoolConnection;
5pub use sqlx::Row as SqlRow;
6
/// Strategy for converting a lower-level error (`InError`) into the error type a
/// caller works with (`OutError`), attaching a short context message.
///
/// Used as the `EM` type parameter of the pool and connection wrappers in this
/// module; it is never instantiated, only called through `EM::map`.
pub trait ErrorMap: Send + Sync + Clone + 'static {
    /// Error type produced by [`ErrorMap::map`].
    type OutError;
    /// Error type consumed by [`ErrorMap::map`].
    type InError;

    /// Converts `e` into `Self::OutError`.
    ///
    /// `msg` is a free-form context string; the callers in this module pass a
    /// bracketed source line number, sometimes followed by extra detail.
    fn map(e: Self::InError, msg: &str) -> Self::OutError;
}
12
/// Forwards to the `query!` macro exported at this crate's root.
///
/// The target is named with `$crate` rather than a hard-coded `sfo_sql::` path.
/// Expansion therefore still resolves when a downstream crate renames the
/// dependency in its `Cargo.toml`, and also when the macro is used from inside
/// this crate itself.
#[macro_export]
macro_rules! sql_query {
    ($query:expr) => ({
        $crate::query!($query)
    });

    ($query:expr, $($args:tt)*) => ({
        $crate::query!($query, $($args)*)
    })
}
23
24pub struct SqlPool<DB: sqlx::Database, EM: ErrorMap<InError = sqlx::Error>>
25where for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, {
26    pub(crate) pool: sqlx::pool::Pool<DB>,
27    pub(crate) uri: String,
28    pub(crate) _em: PhantomData<EM>,
29}
30
31impl<DB: sqlx::Database, EM: ErrorMap<InError = sqlx::Error>> Clone for SqlPool<DB, EM>
32where for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, {
33
34    fn clone(&self) -> Self {
35        Self {
36            pool: self.pool.clone(),
37            uri: self.uri.clone(),
38            _em: self._em.clone()
39        }
40    }
41}
42
43impl <DB: sqlx::Database, EM: ErrorMap<InError = sqlx::Error>> Deref for SqlPool<DB, EM>
44where for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, {
45    type Target = sqlx::pool::Pool<DB>;
46
47    fn deref(&self) -> &Self::Target {
48        &self.pool
49    }
50}
51
52impl<DB: sqlx::Database, EM: 'static + ErrorMap<InError = sqlx::Error>> SqlPool<DB, EM>
53where for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, {
54    pub fn from_raw_pool(pool: sqlx::pool::Pool<DB>) -> Self {
55        Self { pool, uri: "".to_string(), _em: Default::default() }
56    }
57
58    pub async fn raw_pool(&self) -> sqlx::pool::Pool<DB> {
59        self.pool.clone()
60    }
61
62    pub async fn get_conn(&self) -> Result<SqlConnection<DB, EM>, EM::OutError> {
63        let conn = self.pool.acquire().await.map_err(|e| EM::map(e, format!("[{} {}]", line!(), self.uri.as_str()).as_str()))?;
64        Ok(SqlConnection::<DB, EM>::from(conn))
65    }
66}
67
68pub fn sql_query<DB: Database>(sql: &str) -> sqlx::query::Query<'_, DB, DB::Arguments>
69where <DB as sqlx::Database>::Arguments: sqlx::IntoArguments<DB>,{
70    sqlx::query(sqlx::AssertSqlSafe(sql))
71}
72
73pub fn sql_query_with<DB: Database>(sql: &str, arguments: DB::Arguments) -> sqlx::query::Query<'_, DB, DB::Arguments>
74where <DB as sqlx::Database>::Arguments: sqlx::IntoArguments<DB>,{
75    sqlx::query_with(sqlx::AssertSqlSafe(sql), arguments)
76}
77
78pub enum SqlConnectionType<DB: Database>
79where for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>,{
80    PoolConn(PoolConnection<DB>),
81    Conn(DB::Connection),
82}
83pub struct SqlConnection<DB: Database, EM: ErrorMap<InError = sqlx::Error>>
84where for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, {
85    pub(crate) trans: Option<Transaction<'static, DB>>,
86    pub(crate) conn: SqlConnectionType<DB>,
87    pub(crate) _em: PhantomData<EM>,
88}
89
90impl <DB: Database, EM: 'static + ErrorMap<InError = sqlx::Error>> From<sqlx::pool::PoolConnection<DB>> for SqlConnection<DB, EM>
91where for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, {
92    fn from(conn: sqlx::pool::PoolConnection<DB>) -> Self {
93        Self { conn: SqlConnectionType::PoolConn(conn), _em: Default::default(), trans: None }
94    }
95}
96
97impl<DB: Database, EM: 'static + ErrorMap<InError = sqlx::Error>> SqlConnection<DB, EM>
98where for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>,
99      for<'b> <DB as sqlx::Database>::Arguments: sqlx::IntoArguments<DB>, {
100    pub async fn execute_sql<'a>(&mut self, query: sqlx::query::Query<'a, DB, <DB as Database>::Arguments>) -> Result<DB::QueryResult, EM::OutError>
101    {
102        match &mut self.conn {
103            SqlConnectionType::PoolConn(conn) => {
104                conn.execute(query).await.map_err(|e| EM::map(e, format!("[{}]", line!()).as_str()))
105            },
106            SqlConnectionType::Conn(conn) => {
107                conn.execute(query).await.map_err(|e| EM::map(e, format!("[{}]", line!()).as_str()))
108            }
109        }
110    }
111
112    pub async fn query_one<'a>(&mut self, query: sqlx::query::Query<'a, DB, DB::Arguments>) -> Result<DB::Row, EM::OutError> {
113        match &mut self.conn {
114            SqlConnectionType::PoolConn(conn) => {
115                conn.fetch_one(query).await.map_err(|e| EM::map(e, format!("[{}]", line!()).as_str()))
116            },
117            SqlConnectionType::Conn(conn) => {
118                conn.fetch_one(query).await.map_err(|e| EM::map(e, format!("[{}]", line!()).as_str()))
119            }
120        }
121    }
122
123    pub async fn query_all<'a>(&mut self, query: sqlx::query::Query<'a, DB, DB::Arguments>) -> Result<Vec<DB::Row>, EM::OutError> {
124        match &mut self.conn {
125            SqlConnectionType::PoolConn(conn) => {
126                conn.fetch_all(query).await.map_err(|e| EM::map(e, format!("[{}]", line!()).as_str()))
127            },
128            SqlConnectionType::Conn(conn) => {
129                conn.fetch_all(query).await.map_err(|e| EM::map(e, format!("[{}]", line!()).as_str()))
130            }
131        }
132    }
133
134    pub async fn begin_transaction(&mut self) -> Result<(), EM::OutError> {
135        let this: &'static mut Self = unsafe {std::mem::transmute(self)};
136        let trans = match &mut this.conn {
137            SqlConnectionType::PoolConn(conn) => {
138                conn.begin().await.map_err(|e| EM::map(e, format!("[{} {}]", line!(), "begin trans").as_str()))
139            },
140            SqlConnectionType::Conn(conn) => {
141                conn.begin().await.map_err(|e| EM::map(e, format!("[{} {}]", line!(), "begin trans").as_str()))
142            }
143        }?;
144        this.trans = Some(trans);
145        Ok(())
146    }
147
148    pub async fn rollback_transaction(&mut self) -> Result<(), EM::OutError> {
149        if self.trans.is_none() {
150            Ok(())
151        } else {
152            self.trans.take().unwrap().rollback().await.map_err(|e| EM::map(e, format!("[{} {}]", line!(), "rollback trans").as_str()))
153        }
154    }
155
156    pub async fn commit_transaction(&mut self) -> Result<(), EM::OutError> {
157        if self.trans.is_none() {
158            return Ok(())
159        } else {
160            self.trans.take().unwrap().commit().await.map_err(|e| EM::map(e, format!("[{} {}]", line!(), "commit trans").as_str()))
161        }
162    }
163
164}
165
166impl<DB: sqlx::Database,EM: ErrorMap<InError=sqlx::Error>> Drop for SqlConnection<DB, EM>
167where for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, {
168    fn drop(&mut self) {
169        if self.trans.is_some() {
170            let _ = self.trans.take();
171        }
172    }
173}