sqlx_plus_core/
sqlx_base.rs

1use std::sync::Once;
2
3use async_trait::async_trait;
4use futures::executor::block_on;
5use sqlx::{postgres::PgRow, query, query_as, Error, FromRow, Pool, Row, Transaction};
6
7#[cfg(feature = "mysql")]
8use sqlx::MySql;
9#[cfg(feature = "postgres")]
10use sqlx::Postgres;
11#[cfg(feature = "sqlite")]
12use sqlx::Sqlite;
13use tracing::debug;
14
15// use crate::logger;
16
// The aliases below are defined once per backend feature, so at most one of
// these features may be active. Without this guard, combining them (e.g. via
// `cargo build --all-features`) only surfaces as "the name `DbPool` is defined
// multiple times", which does not say how to fix the build.
#[cfg(any(
    all(feature = "sqlite", feature = "postgres"),
    all(feature = "sqlite", feature = "mysql"),
    all(feature = "postgres", feature = "mysql")
))]
compile_error!(
    "sqlx_plus_core: the `sqlite`, `postgres` and `mysql` features are mutually exclusive; enable exactly one"
);

/// Connection pool for the backend selected by the enabled Cargo feature.
#[cfg(feature = "sqlite")]
pub type DbPool = Pool<Sqlite>;
/// Connection pool for the backend selected by the enabled Cargo feature.
#[cfg(feature = "postgres")]
pub type DbPool = Pool<Postgres>;
/// Connection pool for the backend selected by the enabled Cargo feature.
#[cfg(feature = "mysql")]
pub type DbPool = Pool<MySql>;

/// Result row type of the selected backend.
#[cfg(feature = "sqlite")]
pub type DbRow = sqlx::sqlite::SqliteRow;
/// Result row type of the selected backend.
#[cfg(feature = "postgres")]
pub type DbRow = sqlx::postgres::PgRow;
/// Result row type of the selected backend.
#[cfg(feature = "mysql")]
pub type DbRow = sqlx::mysql::MySqlRow;

/// Transaction handle of the selected backend, borrowing its connection for `'a`.
#[cfg(feature = "sqlite")]
pub type DbTransaction<'a> = Transaction<'a, Sqlite>;
/// Transaction handle of the selected backend, borrowing its connection for `'a`.
#[cfg(feature = "postgres")]
pub type DbTransaction<'a> = Transaction<'a, Postgres>;
/// Transaction handle of the selected backend, borrowing its connection for `'a`.
#[cfg(feature = "mysql")]
pub type DbTransaction<'a> = Transaction<'a, MySql>;
37
/// Database backends this crate can be built against.
///
/// The active backend is chosen at compile time by Cargo feature; see
/// `get_db_type`. The extra derives let callers copy, compare and hash the
/// value (e.g. `get_db_type() == DbType::Postgres`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DbType {
    /// SQLite, selected by the `sqlite` feature.
    Sqlite,
    /// PostgreSQL, selected by the `postgres` feature.
    Postgres,
    /// MySQL, selected by the `mysql` feature.
    Mysql,
}
44
45pub fn get_db_type() -> DbType {
46    #[cfg(feature = "sqlite")]
47    {
48        return DbType::Sqlite;
49    }
50
51    #[cfg(feature = "postgres")]
52    {
53        return DbType::Postgres;
54    }
55
56    #[cfg(feature = "mysql")]
57    {
58        return DbType::Mysql;
59    }
60}
61
/// One-shot guard that the (currently disabled) `initialize_logger` hook below
/// uses to run logger setup only once.
///
/// Nothing references it while that hook stays commented out; the allowance
/// silences the resulting `dead_code` warning without losing the guard.
#[allow(dead_code)]
static INIT: Once = Once::new();

// Disabled logger bootstrap. Re-enabling it also requires the commented-out
// `use crate::logger;` import and the commented-out call in
// `impl SqlxBase for DbPool`.
// fn initialize_logger() {
//     INIT.call_once(|| {
//         block_on(logger::run());
//     });
// }
69
70/// 定义 XPlusCore trait
71#[async_trait]
72pub trait SqlxBase {
73    fn pool(&self) -> &DbPool;
74
75    async fn create_table_if_not_exists(
76        &self,
77        table_name: &str,
78        create_sql: &str,
79    ) -> Result<(), Error> {
80        let db_type = get_db_type();
81        let table_exists_sql = match db_type {
82            DbType::Sqlite => format!(
83                "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='{}'",
84                table_name
85            ),
86            DbType::Postgres => format!(
87                "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_schema = 'public' AND table_name = '{}')",
88                table_name
89            ),
90            DbType::Mysql => format!(
91                "SELECT COUNT(*) FROM information_schema.tables WHERE table_schema = DATABASE() AND table_name = '{}'",
92                table_name
93            ),
94        };
95
96        let table_exists: (bool,) = query_as(&table_exists_sql).fetch_one(self.pool()).await?;
97        if !table_exists.0 {
98            query(create_sql).execute(self.pool()).await?;
99        }
100        Ok(())
101    }
102
103    async fn insert(
104        &self,
105        sql: &str,
106        transaction: Option<&mut DbTransaction<'_>>,
107    ) -> Result<u64, Error> {
108        let sql_lower = sql.to_lowercase();
109        if sql_lower.contains("returning") {
110            return Err(Error::Protocol("The 'insert' method does not support 'RETURNING' clause. Use 'insert_and_return_key' instead.".into()));
111        }
112        debug!("The 「Sqlx pPlus」「Insert)」 sql:{}", sql_lower);
113        let result = match transaction {
114            Some(tx) => query(sql).execute(&mut **tx).await?,
115            None => query(sql).execute(self.pool()).await?,
116        };
117        let rows_affected = result.rows_affected();
118        Ok(rows_affected)
119    }
120
121    async fn insert_and_return_key<T>(
122        &self,
123        sql: &str,
124        transaction: Option<&mut DbTransaction<'_>>,
125    ) -> Result<T, Error>
126    where
127        T: for<'r> FromRow<'r, DbRow> + Unpin + Send,
128    {
129        let sql_lower = sql.to_lowercase();
130        if !sql_lower.contains("returning") {
131            return Err(Error::Protocol(
132                "The 'insert_and_return_key' method requires a 'RETURNING' clause.".into(),
133            ));
134        }
135        let result = match transaction {
136            Some(tx) => query_as::<_, T>(&sql).fetch_one(&mut **tx).await?,
137            None => query_as::<_, T>(&sql).fetch_one(self.pool()).await?,
138        };
139        Ok(result)
140    }
141
142    async fn update(
143        &self,
144        sql: &str,
145        transaction: Option<&mut DbTransaction<'_>>,
146    ) -> Result<u64, Error> {
147        let result = match transaction {
148            Some(tx) => query(sql).execute(&mut **tx).await?,
149            None => query(sql).execute(self.pool()).await?,
150        };
151        Ok(result.rows_affected())
152    }
153
154    async fn delete(
155        &self,
156        sql: &str,
157        transaction: Option<&mut DbTransaction<'_>>,
158    ) -> Result<u64, Error> {
159        let result = match transaction {
160            Some(tx) => query(sql).execute(&mut **tx).await?,
161            None => query(sql).execute(self.pool()).await?,
162        };
163        Ok(result.rows_affected())
164    }
165
166    async fn select_one<T>(&self, sql: &str) -> Result<Option<T>, Error>
167    where
168        T: for<'r> FromRow<'r, DbRow> + Unpin + Send,
169    {
170        let result = query_as::<_, T>(sql).fetch_optional(self.pool()).await?;
171        Ok(result)
172    }
173
174    async fn query_one(&self, sql: &str) -> Result<PgRow, Error> {
175        let row = query(sql).fetch_one(self.pool()).await?;
176        Ok(row)
177    }
178
179    async fn select_all<T>(&self, sql: &str) -> Result<Vec<T>, Error>
180    where
181        T: for<'r> FromRow<'r, DbRow> + Unpin + Send,
182    {
183        let result = query_as::<_, T>(sql).fetch_all(self.pool()).await?;
184        Ok(result)
185    }
186
187    async fn select_page<T>(&self, sql: &str) -> Result<Vec<T>, Error>
188    where
189        T: for<'r> FromRow<'r, DbRow> + Unpin + Send,
190    {
191        let sql_lower = sql.to_lowercase();
192        if !sql_lower.contains("limit") || !sql_lower.contains("offset") {
193            return Err(Error::Protocol("SQL must contain LIMIT and OFFSET".into()));
194        }
195
196        let result = query_as::<_, T>(sql).fetch_all(self.pool()).await?;
197        Ok(result)
198    }
199
200    async fn count(&self, sql: &str) -> Result<i64, Error> {
201        let row = query(sql).fetch_one(self.pool()).await?;
202        let count: i64 = row.try_get(0)?;
203        Ok(count)
204    }
205
206    async fn health_check(&self) -> Result<bool, Error> {
207        let row: (i64,) = sqlx::query_as("SELECT $1")
208            .bind(150_i64)
209            .fetch_one(self.pool())
210            .await
211            .unwrap_or_else(|_| panic!("Sqlx Plus Failed to to execute query"));
212        Ok(row.0 == 150)
213    }
214
215    async fn begin(&self) -> Result<DbTransaction<'_>, Error> {
216        let transaction = self.pool().begin().await?;
217        Ok(transaction)
218    }
219
220    async fn commit(transaction: DbTransaction<'_>) -> Result<(), Error> {
221        transaction.commit().await?;
222        Ok(())
223    }
224
225    async fn rollback(transaction: DbTransaction<'_>) -> Result<(), Error> {
226        transaction.rollback().await?;
227        Ok(())
228    }
229}
230
231#[async_trait]
232impl SqlxBase for DbPool {
233    fn pool(&self) -> &DbPool {
234        // initialize_logger();
235        self
236    }
237}