1use std::marker::PhantomData;
2use std::ops::Deref;
3use sqlx::{Transaction, Connection, Executor, Database};
4use sqlx::pool::PoolConnection;
5use sqlx::Execute;
6pub use sqlx::Row as SqlRow;
7use crate::errors::{sql_err, SqlError, SqlErrorCode};
8
/// Converts an incoming error (`InError`) into the caller's error type
/// (`OutError`), attaching a short context message.
///
/// The trait has no `self` receiver: implementors are used purely at the type
/// level (this file carries them as `PhantomData<EM>`), and only the associated
/// function [`ErrorMap::map`] is ever called.
pub trait ErrorMap: Clone + Send + Sync + 'static {
    /// Error type produced by [`ErrorMap::map`].
    type OutError;
    /// Error type consumed by [`ErrorMap::map`].
    type InError;

    /// Wraps `e`, using `msg` as additional context.
    fn map(e: Self::InError, msg: &str) -> Self::OutError;
}
14
/// Forwards to the `query!` macro exposed at this crate's root (presumably
/// sqlx's compile-time-checked `query!`, re-exported — confirm in `lib.rs`).
///
/// The path is spelled with `$crate` rather than a hard-coded crate name so the
/// macro keeps resolving when a dependent renames this crate in its
/// `Cargo.toml`, and when the macro is invoked from inside this crate itself.
#[macro_export]
macro_rules! sql_query {
    ($query:expr) => {{
        $crate::query!($query)
    }};

    ($query:expr, $($args:tt)*) => {{
        $crate::query!($query, $($args)*)
    }};
}
25
26pub struct SqlPool<DB: sqlx::Database, EM: ErrorMap<InError = sqlx::Error>>
27where for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, {
28 pub(crate) pool: sqlx::pool::Pool<DB>,
29 pub(crate) uri: String,
30 pub(crate) _em: PhantomData<EM>,
31}
32
33impl<DB: sqlx::Database, EM: ErrorMap<InError = sqlx::Error>> Clone for SqlPool<DB, EM>
34where for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, {
35
36 fn clone(&self) -> Self {
37 Self {
38 pool: self.pool.clone(),
39 uri: self.uri.clone(),
40 _em: self._em.clone()
41 }
42 }
43}
44
45impl <DB: sqlx::Database, EM: ErrorMap<InError = sqlx::Error>> Deref for SqlPool<DB, EM>
46where for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, {
47 type Target = sqlx::pool::Pool<DB>;
48
49 fn deref(&self) -> &Self::Target {
50 &self.pool
51 }
52}
53
54impl<DB: sqlx::Database, EM: 'static + ErrorMap<InError = sqlx::Error>> SqlPool<DB, EM>
55where for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, {
56 pub fn from_raw_pool(pool: sqlx::pool::Pool<DB>) -> Self {
57 Self { pool, uri: "".to_string(), _em: Default::default() }
58 }
59
60 pub async fn raw_pool(&self) -> sqlx::pool::Pool<DB> {
61 self.pool.clone()
62 }
63
64 pub async fn get_conn(&self) -> Result<SqlConnection<DB, EM>, EM::OutError> {
65 let conn = self.pool.acquire().await.map_err(|e| EM::map(e, format!("[{} {}]", line!(), self.uri.as_str()).as_str()))?;
66 Ok(SqlConnection::<DB, EM>::from(conn))
67 }
68}
69
70pub fn sql_query<DB: Database>(sql: &str) -> sqlx::query::Query<DB, DB::Arguments<'_>>
71where for<'b> <DB as sqlx::Database>::Arguments<'b>: sqlx::IntoArguments<'b, DB>,{
72 sqlx::query(sql)
73}
74
75pub fn sql_query_with<'a, DB: Database>(sql: &'a str, arguments: DB::Arguments<'a>) -> sqlx::query::Query<'a, DB, DB::Arguments<'a>>
76where for<'b> <DB as sqlx::Database>::Arguments<'b>: sqlx::IntoArguments<'b, DB>,{
77 sqlx::query_with(sql, arguments)
78}
79
80pub enum SqlConnectionType<DB: Database>
81where for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>,{
82 PoolConn(PoolConnection<DB>),
83 Conn(DB::Connection),
84}
85pub struct SqlConnection<DB: Database, EM: ErrorMap<InError = sqlx::Error>>
86where for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, {
87 pub(crate) trans: Option<Transaction<'static, DB>>,
88 pub(crate) conn: SqlConnectionType<DB>,
89 pub(crate) _em: PhantomData<EM>,
90}
91
92impl <DB: Database, EM: 'static + ErrorMap<InError = sqlx::Error>> From<sqlx::pool::PoolConnection<DB>> for SqlConnection<DB, EM>
93where for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, {
94 fn from(conn: sqlx::pool::PoolConnection<DB>) -> Self {
95 Self { conn: SqlConnectionType::PoolConn(conn), _em: Default::default(), trans: None }
96 }
97}
98
99impl<DB: Database, EM: 'static + ErrorMap<InError = sqlx::Error>> SqlConnection<DB, EM>
100where for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>,
101 for<'b> <DB as sqlx::Database>::Arguments<'b>: sqlx::IntoArguments<'b, DB>, {
102 pub async fn execute_sql<'a>(&mut self, query: sqlx::query::Query<'a, DB, <DB as Database>::Arguments<'a>>) -> Result<DB::QueryResult, EM::OutError>
103 {
104 let sql = query.sql();
105 match &mut self.conn {
106 SqlConnectionType::PoolConn(conn) => {
107 conn.execute(query).await.map_err(|e| EM::map(e, format!("[{} {}]", line!(), sql).as_str()))
108 },
109 SqlConnectionType::Conn(conn) => {
110 conn.execute(query).await.map_err(|e| EM::map(e, format!("[{} {}]", line!(), sql).as_str()))
111 }
112 }
113 }
114
115 pub async fn query_one<'a>(&mut self, query: sqlx::query::Query<'a, DB, DB::Arguments<'a>>) -> Result<DB::Row, EM::OutError> {
116 let sql = query.sql();
117 match &mut self.conn {
118 SqlConnectionType::PoolConn(conn) => {
119 conn.fetch_one(query).await.map_err(|e| EM::map(e, format!("[{} {}]", line!(), sql).as_str()))
120 },
121 SqlConnectionType::Conn(conn) => {
122 conn.fetch_one(query).await.map_err(|e| EM::map(e, format!("[{} {}]", line!(), sql).as_str()))
123 }
124 }
125 }
126
127 pub async fn query_all<'a>(&mut self, query: sqlx::query::Query<'a, DB, DB::Arguments<'a>>) -> Result<Vec<DB::Row>, EM::OutError> {
128 let sql = query.sql();
129 match &mut self.conn {
130 SqlConnectionType::PoolConn(conn) => {
131 conn.fetch_all(query).await.map_err(|e| EM::map(e, format!("[{} {}]", line!(), sql).as_str()))
132 },
133 SqlConnectionType::Conn(conn) => {
134 conn.fetch_all(query).await.map_err(|e| EM::map(e, format!("[{} {}]", line!(), sql).as_str()))
135 }
136 }
137 }
138
139 pub async fn begin_transaction(&mut self) -> Result<(), EM::OutError> {
140 let this: &'static mut Self = unsafe {std::mem::transmute(self)};
141 let trans = match &mut this.conn {
142 SqlConnectionType::PoolConn(conn) => {
143 conn.begin().await.map_err(|e| EM::map(e, format!("[{} {}]", line!(), "begin trans").as_str()))
144 },
145 SqlConnectionType::Conn(conn) => {
146 conn.begin().await.map_err(|e| EM::map(e, format!("[{} {}]", line!(), "begin trans").as_str()))
147 }
148 }?;
149 this.trans = Some(trans);
150 Ok(())
151 }
152
153 pub async fn rollback_transaction(&mut self) -> Result<(), EM::OutError> {
154 if self.trans.is_none() {
155 Ok(())
156 } else {
157 self.trans.take().unwrap().rollback().await.map_err(|e| EM::map(e, format!("[{} {}]", line!(), "rollback trans").as_str()))
158 }
159 }
160
161 pub async fn commit_transaction(&mut self) -> Result<(), EM::OutError> {
162 if self.trans.is_none() {
163 return Ok(())
164 } else {
165 self.trans.take().unwrap().commit().await.map_err(|e| EM::map(e, format!("[{} {}]", line!(), "commit trans").as_str()))
166 }
167 }
168
169}
170
171impl<DB: sqlx::Database,EM: ErrorMap<InError=sqlx::Error>> Drop for SqlConnection<DB, EM>
172where for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, {
173 fn drop(&mut self) {
174 if self.trans.is_some() {
175 let _ = self.trans.take();
176 }
177 }
178}