sqlx_plus_core/
sqlx_base.rs1use std::sync::Once;
2
3use async_trait::async_trait;
4use futures::executor::block_on;
5use sqlx::{postgres::PgRow, query, query_as, Error, FromRow, Pool, Row, Transaction};
6
7#[cfg(feature = "mysql")]
8use sqlx::MySql;
9#[cfg(feature = "postgres")]
10use sqlx::Postgres;
11#[cfg(feature = "sqlite")]
12use sqlx::Sqlite;
13use tracing::debug;
14
/// Connection pool for the MySQL backend (`mysql` feature).
#[cfg(feature = "mysql")]
pub type DbPool = Pool<MySql>;

/// Connection pool for the Postgres backend (`postgres` feature).
#[cfg(feature = "postgres")]
pub type DbPool = Pool<Postgres>;

/// Connection pool for the SQLite backend (`sqlite` feature).
#[cfg(feature = "sqlite")]
pub type DbPool = Pool<Sqlite>;
23
/// Row type produced by queries on the MySQL backend (`mysql` feature).
#[cfg(feature = "mysql")]
pub type DbRow = sqlx::mysql::MySqlRow;

/// Row type produced by queries on the Postgres backend (`postgres` feature).
#[cfg(feature = "postgres")]
pub type DbRow = sqlx::postgres::PgRow;

/// Row type produced by queries on the SQLite backend (`sqlite` feature).
#[cfg(feature = "sqlite")]
pub type DbRow = sqlx::sqlite::SqliteRow;
30
/// Transaction handle for the MySQL backend (`mysql` feature).
#[cfg(feature = "mysql")]
pub type DbTransaction<'a> = Transaction<'a, MySql>;

/// Transaction handle for the Postgres backend (`postgres` feature).
#[cfg(feature = "postgres")]
pub type DbTransaction<'a> = Transaction<'a, Postgres>;

/// Transaction handle for the SQLite backend (`sqlite` feature).
#[cfg(feature = "sqlite")]
pub type DbTransaction<'a> = Transaction<'a, Sqlite>;
37
/// Identifies which database backend this build was compiled for.
///
/// Returned by [`get_db_type`]; used to pick backend-specific SQL (placeholder
/// syntax, catalog queries). The extra derives let callers copy and compare the
/// tag (e.g. `get_db_type() == DbType::Postgres`) instead of pattern-matching.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DbType {
    /// SQLite backend (`sqlite` feature).
    Sqlite,
    /// Postgres backend (`postgres` feature).
    Postgres,
    /// MySQL backend (`mysql` feature).
    Mysql,
}
44
45pub fn get_db_type() -> DbType {
46 #[cfg(feature = "sqlite")]
47 {
48 return DbType::Sqlite;
49 }
50
51 #[cfg(feature = "postgres")]
52 {
53 return DbType::Postgres;
54 }
55
56 #[cfg(feature = "mysql")]
57 {
58 return DbType::Mysql;
59 }
60}
61
/// One-shot initialization guard.
///
/// NOTE(review): not referenced by any code visible in this module; it looks
/// like a leftover and is a candidate for removal (together with the `Once` import).
static INIT: Once = Once::new();
63
64#[async_trait]
72pub trait SqlxBase {
73 fn pool(&self) -> &DbPool;
74
75 async fn create_table_if_not_exists(
76 &self,
77 table_name: &str,
78 create_sql: &str,
79 ) -> Result<(), Error> {
80 let db_type = get_db_type();
81 let table_exists_sql = match db_type {
82 DbType::Sqlite => format!(
83 "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='{}'",
84 table_name
85 ),
86 DbType::Postgres => format!(
87 "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_schema = 'public' AND table_name = '{}')",
88 table_name
89 ),
90 DbType::Mysql => format!(
91 "SELECT COUNT(*) FROM information_schema.tables WHERE table_schema = DATABASE() AND table_name = '{}'",
92 table_name
93 ),
94 };
95
96 let table_exists: (bool,) = query_as(&table_exists_sql).fetch_one(self.pool()).await?;
97 if !table_exists.0 {
98 query(create_sql).execute(self.pool()).await?;
99 }
100 Ok(())
101 }
102
103 async fn insert(
104 &self,
105 sql: &str,
106 transaction: Option<&mut DbTransaction<'_>>,
107 ) -> Result<u64, Error> {
108 let sql_lower = sql.to_lowercase();
109 if sql_lower.contains("returning") {
110 return Err(Error::Protocol("The 'insert' method does not support 'RETURNING' clause. Use 'insert_and_return_key' instead.".into()));
111 }
112 debug!("The 「Sqlx pPlus」「Insert)」 sql:{}", sql_lower);
113 let result = match transaction {
114 Some(tx) => query(sql).execute(&mut **tx).await?,
115 None => query(sql).execute(self.pool()).await?,
116 };
117 let rows_affected = result.rows_affected();
118 Ok(rows_affected)
119 }
120
121 async fn insert_and_return_key<T>(
122 &self,
123 sql: &str,
124 transaction: Option<&mut DbTransaction<'_>>,
125 ) -> Result<T, Error>
126 where
127 T: for<'r> FromRow<'r, DbRow> + Unpin + Send,
128 {
129 let sql_lower = sql.to_lowercase();
130 if !sql_lower.contains("returning") {
131 return Err(Error::Protocol(
132 "The 'insert_and_return_key' method requires a 'RETURNING' clause.".into(),
133 ));
134 }
135 let result = match transaction {
136 Some(tx) => query_as::<_, T>(&sql).fetch_one(&mut **tx).await?,
137 None => query_as::<_, T>(&sql).fetch_one(self.pool()).await?,
138 };
139 Ok(result)
140 }
141
142 async fn update(
143 &self,
144 sql: &str,
145 transaction: Option<&mut DbTransaction<'_>>,
146 ) -> Result<u64, Error> {
147 let result = match transaction {
148 Some(tx) => query(sql).execute(&mut **tx).await?,
149 None => query(sql).execute(self.pool()).await?,
150 };
151 Ok(result.rows_affected())
152 }
153
154 async fn delete(
155 &self,
156 sql: &str,
157 transaction: Option<&mut DbTransaction<'_>>,
158 ) -> Result<u64, Error> {
159 let result = match transaction {
160 Some(tx) => query(sql).execute(&mut **tx).await?,
161 None => query(sql).execute(self.pool()).await?,
162 };
163 Ok(result.rows_affected())
164 }
165
166 async fn select_one<T>(&self, sql: &str) -> Result<Option<T>, Error>
167 where
168 T: for<'r> FromRow<'r, DbRow> + Unpin + Send,
169 {
170 let result = query_as::<_, T>(sql).fetch_optional(self.pool()).await?;
171 Ok(result)
172 }
173
174 async fn query_one(&self, sql: &str) -> Result<PgRow, Error> {
175 let row = query(sql).fetch_one(self.pool()).await?;
176 Ok(row)
177 }
178
179 async fn select_all<T>(&self, sql: &str) -> Result<Vec<T>, Error>
180 where
181 T: for<'r> FromRow<'r, DbRow> + Unpin + Send,
182 {
183 let result = query_as::<_, T>(sql).fetch_all(self.pool()).await?;
184 Ok(result)
185 }
186
187 async fn select_page<T>(&self, sql: &str) -> Result<Vec<T>, Error>
188 where
189 T: for<'r> FromRow<'r, DbRow> + Unpin + Send,
190 {
191 let sql_lower = sql.to_lowercase();
192 if !sql_lower.contains("limit") || !sql_lower.contains("offset") {
193 return Err(Error::Protocol("SQL must contain LIMIT and OFFSET".into()));
194 }
195
196 let result = query_as::<_, T>(sql).fetch_all(self.pool()).await?;
197 Ok(result)
198 }
199
200 async fn count(&self, sql: &str) -> Result<i64, Error> {
201 let row = query(sql).fetch_one(self.pool()).await?;
202 let count: i64 = row.try_get(0)?;
203 Ok(count)
204 }
205
206 async fn health_check(&self) -> Result<bool, Error> {
207 let row: (i64,) = sqlx::query_as("SELECT $1")
208 .bind(150_i64)
209 .fetch_one(self.pool())
210 .await
211 .unwrap_or_else(|_| panic!("Sqlx Plus Failed to to execute query"));
212 Ok(row.0 == 150)
213 }
214
215 async fn begin(&self) -> Result<DbTransaction<'_>, Error> {
216 let transaction = self.pool().begin().await?;
217 Ok(transaction)
218 }
219
220 async fn commit(transaction: DbTransaction<'_>) -> Result<(), Error> {
221 transaction.commit().await?;
222 Ok(())
223 }
224
225 async fn rollback(transaction: DbTransaction<'_>) -> Result<(), Error> {
226 transaction.rollback().await?;
227 Ok(())
228 }
229}
230
231#[async_trait]
232impl SqlxBase for DbPool {
233 fn pool(&self) -> &DbPool {
234 self
236 }
237}