Skip to main content

rustbasic_core/sql/
any.rs

1use serde_json::Value;
2pub use super::error::Error;
3use super::row::{AnyColumn, AnyTypeInfo, DbValue};
4pub use super::row::AnyRow;
5use sqlx::{AnyPool as SqlxAnyPool, Executor as SqlxExecutor, Row, Column, TypeInfo, ValueRef};
6use sqlx::any::{AnyConnectOptions, AnyRow as SqlxAnyRow};
7use std::str::FromStr;
8
// Re-export sqlx::Any so it's accessible through this module (the `sql::Any` path assumes
// the parent `sql` module re-exports this file's items — TODO confirm).
/// Alias for [`sqlx::Any`], the runtime-selected database type that every [`Executor`]
/// impl and query helper in this module is bound to.
pub type Any = sqlx::Any;
/// Alias for sqlx's owned `Any` connection; this is the type wrapped by [`AnyConnection`].
pub type SqlxAnyConnection = sqlx::AnyConnection;

/// Marker trait for types usable as [`Executor::Database`].
///
/// In this file only [`Any`] implements it.
pub trait Database {}
impl Database for Any {}
15
/// With the `mysql` feature enabled, `MySqlPool` is a plain alias of [`AnyPool`]; the
/// concrete backend is still chosen at runtime from the URL given to `AnyPool::connect`.
#[cfg(feature = "mysql")]
pub type MySqlPool = AnyPool;

/// With the `sqlite` feature enabled, `SqlitePool` is a plain alias of [`AnyPool`]; the
/// concrete backend is still chosen at runtime from the URL given to `AnyPool::connect`.
#[cfg(feature = "sqlite")]
pub type SqlitePool = AnyPool;
21
/// Placeholder arguments type that carries only a lifetime marker and no values.
///
/// NOTE(review): nothing in this file uses it — queries here bind through
/// `sqlx::any::AnyArguments` directly. Confirm whether external callers still depend on it.
pub struct AnyArguments<'q> {
    // Ties the struct to `'q` without storing any borrowed data.
    pub _marker: std::marker::PhantomData<&'q ()>,
}
25
/// Registers sqlx's compiled-in default drivers for the `Any` backend.
///
/// Thin wrapper over `sqlx::any::install_default_drivers`. `AnyPool::connect` calls it on
/// every connection attempt, so repeated calls are expected to be harmless (sqlx guards this
/// internally — confirm against the sqlx version in use).
pub fn install_default_drivers() {
    sqlx::any::install_default_drivers();
}
29
30#[derive(Clone)]
31pub enum AnyPool {
32    Sqlx(SqlxAnyPool),
33}
34
35impl std::fmt::Debug for AnyPool {
36    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
37        write!(f, "AnyPool")
38    }
39}
40
41pub struct AnyConnection {
42    pub conn: SqlxAnyConnection,
43}
44
45impl std::fmt::Debug for AnyConnection {
46    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
47        write!(f, "AnyConnection")
48    }
49}
50
/// Outcome of a statement run through [`Executor::execute`]: how many rows it changed and,
/// when the driver reports it, the id of the last inserted row.
///
/// Plain data, so it derives the usual value traits and can be logged, compared and copied.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct AnyQueryResult {
    /// Number of rows changed by the statement.
    pub rows_affected: u64,
    /// Id of the last inserted row, or `None` when the driver does not report one.
    pub last_insert_id: Option<i64>,
}

impl AnyQueryResult {
    /// Returns the number of rows changed by the statement.
    #[must_use]
    pub fn rows_affected(&self) -> u64 {
        self.rows_affected
    }

    /// Returns the id of the last inserted row, if the driver reported one.
    #[must_use]
    pub fn last_insert_id(&self) -> Option<i64> {
        self.last_insert_id
    }
}
65
66impl AnyPool {
67    pub async fn connect(url: &str) -> Result<Self, Error> {
68        install_default_drivers();
69        
70        let url = if url.starts_with("sqlite:") {
71            let path = url.trim_start_matches("sqlite:")
72                .split('?')
73                .next()
74                .unwrap_or(url);
75            if let Some(parent) = std::path::Path::new(path).parent() {
76                let _ = std::fs::create_dir_all(parent);
77            }
78            url.to_string()
79        } else {
80            url.to_string()
81        };
82
83        let options = AnyConnectOptions::from_str(&url)
84            .map_err(|e| Error::Database(e.to_string()))?;
85        
86        let pool = SqlxAnyPool::connect_with(options)
87            .await
88            .map_err(|e| Error::Database(e.to_string()))?;
89            
90        Ok(AnyPool::Sqlx(pool))
91    }
92
93    pub async fn acquire(&self) -> Result<PoolConnection, Error> {
94        match self {
95            AnyPool::Sqlx(pool) => {
96                let conn = pool.acquire()
97                    .await
98                    .map_err(|e| Error::Database(e.to_string()))?;
99                Ok(PoolConnection { conn: AnyConnection { conn: conn.detach() } })
100            }
101        }
102    }
103
104    pub fn backend_name(&self) -> &str {
105        "SQLx"
106    }
107}
108
impl AnyConnection {
    /// Returns a fixed label naming the backend implementation, always `"SQLx"`.
    pub fn backend_name(&self) -> &str {
        "SQLx"
    }
}
114
115pub struct PoolConnection {
116    pub conn: AnyConnection,
117}
118
119impl std::ops::Deref for PoolConnection {
120    type Target = AnyConnection;
121    fn deref(&self) -> &Self::Target {
122        &self.conn
123    }
124}
125
126impl std::ops::DerefMut for PoolConnection {
127    fn deref_mut(&mut self) -> &mut Self::Target {
128        &mut self.conn
129    }
130}
131
/// Minimal query interface implemented in this file for `&AnyPool` and `&mut AnyConnection`.
///
/// Every method takes the SQL text plus JSON `arguments`. The implementations bind the
/// arguments positionally, in order, before running the statement. Methods consume `self`,
/// which for the provided impls is a reference.
// `async fn` in a public trait triggers the `async_fn_in_trait` lint, because the returned
// futures carry no `Send` bound that callers could rely on; the lint is silenced here.
// NOTE(review): generic callers therefore cannot require these futures to be `Send`
// (e.g. to spawn them on a multi-threaded runtime) — confirm that is acceptable.
#[allow(async_fn_in_trait)]
pub trait Executor {
    /// Database type the executor targets.
    type Database: Database;

    /// Runs a statement and reports affected rows and the last insert id.
    async fn execute(self, sql: &str, arguments: &[Value]) -> Result<AnyQueryResult, Error>;
    /// Runs a query and returns every row it produces.
    async fn fetch_all(self, sql: &str, arguments: &[Value]) -> Result<Vec<AnyRow>, Error>;
    /// Runs a query and returns its first row, or `None` if it produced none.
    async fn fetch_optional(self, sql: &str, arguments: &[Value]) -> Result<Option<AnyRow>, Error>;
    /// Runs a query and returns its first row; producing no row is reported as an error.
    async fn fetch_one(self, sql: &str, arguments: &[Value]) -> Result<AnyRow, Error>;
}
141
142impl Executor for &AnyPool {
143    type Database = Any;
144
145    async fn execute(self, sql: &str, arguments: &[Value]) -> Result<AnyQueryResult, Error> {
146        let AnyPool::Sqlx(pool) = self;
147        execute_sqlx(pool, sql, arguments).await
148    }
149
150    async fn fetch_all(self, sql: &str, arguments: &[Value]) -> Result<Vec<AnyRow>, Error> {
151        let AnyPool::Sqlx(pool) = self;
152        fetch_all_sqlx(pool, sql, arguments).await
153    }
154
155    async fn fetch_optional(self, sql: &str, arguments: &[Value]) -> Result<Option<AnyRow>, Error> {
156        let AnyPool::Sqlx(pool) = self;
157        fetch_optional_sqlx(pool, sql, arguments).await
158    }
159
160    async fn fetch_one(self, sql: &str, arguments: &[Value]) -> Result<AnyRow, Error> {
161        let AnyPool::Sqlx(pool) = self;
162        fetch_one_sqlx(pool, sql, arguments).await
163    }
164}
165
166impl Executor for &mut AnyConnection {
167    type Database = Any;
168
169    async fn execute(self, sql: &str, arguments: &[Value]) -> Result<AnyQueryResult, Error> {
170        execute_sqlx(&mut self.conn, sql, arguments).await
171    }
172
173    async fn fetch_all(self, sql: &str, arguments: &[Value]) -> Result<Vec<AnyRow>, Error> {
174        fetch_all_sqlx(&mut self.conn, sql, arguments).await
175    }
176
177    async fn fetch_optional(self, sql: &str, arguments: &[Value]) -> Result<Option<AnyRow>, Error> {
178        fetch_optional_sqlx(&mut self.conn, sql, arguments).await
179    }
180
181    async fn fetch_one(self, sql: &str, arguments: &[Value]) -> Result<AnyRow, Error> {
182        fetch_one_sqlx(&mut self.conn, sql, arguments).await
183    }
184}
185
186async fn execute_sqlx<'e, E>(executor: E, sql: &str, arguments: &[Value]) -> Result<AnyQueryResult, Error> 
187where E: SqlxExecutor<'e, Database = Any>
188{
189    let mut query = sqlx::query(sql);
190    for arg in arguments {
191        query = bind_json_value(query, arg);
192    }
193    let res = query.execute(executor).await.map_err(|e| Error::Database(e.to_string()))?;
194    Ok(AnyQueryResult {
195        rows_affected: res.rows_affected(),
196        last_insert_id: res.last_insert_id(),
197    })
198}
199
200async fn fetch_all_sqlx<'e, E>(executor: E, sql: &str, arguments: &[Value]) -> Result<Vec<AnyRow>, Error>
201where E: SqlxExecutor<'e, Database = Any>
202{
203    let mut query = sqlx::query(sql);
204    for arg in arguments {
205        query = bind_json_value(query, arg);
206    }
207    let rows = query.fetch_all(executor).await.map_err(|e| Error::Database(e.to_string()))?;
208    Ok(rows.into_iter().map(sqlx_row_to_any_row).collect())
209}
210
211async fn fetch_optional_sqlx<'e, E>(executor: E, sql: &str, arguments: &[Value]) -> Result<Option<AnyRow>, Error>
212where E: SqlxExecutor<'e, Database = Any>
213{
214    let mut query = sqlx::query(sql);
215    for arg in arguments {
216        query = bind_json_value(query, arg);
217    }
218    let row = query.fetch_optional(executor).await.map_err(|e| Error::Database(e.to_string()))?;
219    Ok(row.map(sqlx_row_to_any_row))
220}
221
222async fn fetch_one_sqlx<'e, E>(executor: E, sql: &str, arguments: &[Value]) -> Result<AnyRow, Error>
223where E: SqlxExecutor<'e, Database = Any>
224{
225    let mut query = sqlx::query(sql);
226    for arg in arguments {
227        query = bind_json_value(query, arg);
228    }
229    let row = query.fetch_one(executor).await.map_err(|e| Error::Database(e.to_string()))?;
230    Ok(sqlx_row_to_any_row(row))
231}
232
233fn bind_json_value<'q>(query: sqlx::query::Query<'q, Any, sqlx::any::AnyArguments<'q>>, val: &'q Value) -> sqlx::query::Query<'q, Any, sqlx::any::AnyArguments<'q>> {
234    match val {
235        Value::Null => query.bind(None::<String>),
236        Value::Bool(b) => query.bind(*b),
237        Value::Number(n) => {
238            if let Some(i) = n.as_i64() {
239                query.bind(i)
240            } else if let Some(f) = n.as_f64() {
241                query.bind(f)
242            } else {
243                query.bind(0.0)
244            }
245        }
246        Value::String(s) => query.bind(s.as_str()),
247        _ => query.bind(val.to_string()),
248    }
249}
250
251fn sqlx_row_to_any_row(row: SqlxAnyRow) -> AnyRow {
252    let mut columns = Vec::new();
253    let mut values = Vec::new();
254
255    for col in row.columns() {
256        columns.push(AnyColumn {
257            name: col.name().to_string(),
258            type_info: AnyTypeInfo {
259                name: col.type_info().name().to_string(),
260            },
261        });
262        
263        let val: DbValue = match row.try_get_raw(col.ordinal()) {
264            Ok(raw_val) => {
265                if raw_val.is_null() {
266                    DbValue::Null
267                } else {
268                    if let Ok(v) = row.try_get::<i64, _>(col.ordinal()) {
269                        DbValue::Integer(v)
270                    } else if let Ok(v) = row.try_get::<f64, _>(col.ordinal()) {
271                        DbValue::Real(v)
272                    } else if let Ok(v) = row.try_get::<bool, _>(col.ordinal()) {
273                        DbValue::Bool(v)
274                    } else if let Ok(v) = row.try_get::<String, _>(col.ordinal()) {
275                        DbValue::Text(v)
276                    } else if let Ok(v) = row.try_get::<Vec<u8>, _>(col.ordinal()) {
277                        DbValue::Blob(v)
278                    } else {
279                        DbValue::Text(format!("{:?}", raw_val))
280                    }
281                }
282            }
283            Err(_) => DbValue::Null,
284        };
285        values.push(val);
286    }
287
288    AnyRow { columns, values }
289}