1use serde_json::Value;
2pub use super::error::Error;
3use super::row::{AnyColumn, AnyTypeInfo, DbValue};
4pub use super::row::AnyRow;
5use sqlx::{AnyPool as SqlxAnyPool, Executor as SqlxExecutor, Row, Column, TypeInfo, ValueRef};
6use sqlx::any::{AnyConnectOptions, AnyRow as SqlxAnyRow};
7use std::str::FromStr;
8
9pub type Any = sqlx::Any;
11pub type SqlxAnyConnection = sqlx::AnyConnection;
12
13pub trait Database {}
14impl Database for Any {}
15
/// Alias of [`AnyPool`]; only compiled when the `mysql` feature is enabled.
#[cfg(feature = "mysql")]
pub type MySqlPool = AnyPool;

/// Alias of [`AnyPool`]; only compiled when the `sqlite` feature is enabled.
#[cfg(feature = "sqlite")]
pub type SqlitePool = AnyPool;
21
/// Placeholder arguments type: it stores no data, only a `'q` lifetime marker.
pub struct AnyArguments<'q> {
    /// Zero-sized field tying the struct to `'q`.
    pub _marker: std::marker::PhantomData<&'q ()>,
}
25
26pub fn install_default_drivers() {
27 sqlx::any::install_default_drivers();
28}
29
30#[derive(Clone)]
31pub enum AnyPool {
32 Sqlx(SqlxAnyPool),
33}
34
35impl std::fmt::Debug for AnyPool {
36 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
37 write!(f, "AnyPool")
38 }
39}
40
41pub struct AnyConnection {
42 pub conn: SqlxAnyConnection,
43}
44
45impl std::fmt::Debug for AnyConnection {
46 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
47 write!(f, "AnyConnection")
48 }
49}
50
/// Outcome of executing a statement that returns no rows.
///
/// A small plain-data value, so it derives the usual value traits (`Debug` for
/// logging and assertions, `Copy`/`Clone`, equality, and an all-zero `Default`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct AnyQueryResult {
    /// Number of rows the statement changed.
    pub rows_affected: u64,
    /// Row id generated by the statement, when the driver reports one.
    pub last_insert_id: Option<i64>,
}

impl AnyQueryResult {
    /// Number of rows the statement changed.
    pub fn rows_affected(&self) -> u64 {
        self.rows_affected
    }

    /// Row id generated by the statement, or `None` if the driver reported none.
    pub fn last_insert_id(&self) -> Option<i64> {
        self.last_insert_id
    }
}
65
66impl AnyPool {
67 pub async fn connect(url: &str) -> Result<Self, Error> {
68 install_default_drivers();
69
70 let url = if url.starts_with("sqlite:") {
71 let path = url.trim_start_matches("sqlite:")
72 .split('?')
73 .next()
74 .unwrap_or(url);
75 if let Some(parent) = std::path::Path::new(path).parent() {
76 let _ = std::fs::create_dir_all(parent);
77 }
78 url.to_string()
79 } else {
80 url.to_string()
81 };
82
83 let options = AnyConnectOptions::from_str(&url)
84 .map_err(|e| Error::Database(e.to_string()))?;
85
86 let pool = SqlxAnyPool::connect_with(options)
87 .await
88 .map_err(|e| Error::Database(e.to_string()))?;
89
90 Ok(AnyPool::Sqlx(pool))
91 }
92
93 pub async fn acquire(&self) -> Result<PoolConnection, Error> {
94 match self {
95 AnyPool::Sqlx(pool) => {
96 let conn = pool.acquire()
97 .await
98 .map_err(|e| Error::Database(e.to_string()))?;
99 Ok(PoolConnection { conn: AnyConnection { conn: conn.detach() } })
100 }
101 }
102 }
103
104 pub fn backend_name(&self) -> &str {
105 "SQLx"
106 }
107}
108
109impl AnyConnection {
110 pub fn backend_name(&self) -> &str {
111 "SQLx"
112 }
113}
114
115pub struct PoolConnection {
116 pub conn: AnyConnection,
117}
118
119impl std::ops::Deref for PoolConnection {
120 type Target = AnyConnection;
121 fn deref(&self) -> &Self::Target {
122 &self.conn
123 }
124}
125
126impl std::ops::DerefMut for PoolConnection {
127 fn deref_mut(&mut self) -> &mut Self::Target {
128 &mut self.conn
129 }
130}
131
132#[allow(async_fn_in_trait)]
133pub trait Executor {
134 type Database: Database;
135
136 async fn execute(self, sql: &str, arguments: &[Value]) -> Result<AnyQueryResult, Error>;
137 async fn fetch_all(self, sql: &str, arguments: &[Value]) -> Result<Vec<AnyRow>, Error>;
138 async fn fetch_optional(self, sql: &str, arguments: &[Value]) -> Result<Option<AnyRow>, Error>;
139 async fn fetch_one(self, sql: &str, arguments: &[Value]) -> Result<AnyRow, Error>;
140}
141
142impl Executor for &AnyPool {
143 type Database = Any;
144
145 async fn execute(self, sql: &str, arguments: &[Value]) -> Result<AnyQueryResult, Error> {
146 let AnyPool::Sqlx(pool) = self;
147 execute_sqlx(pool, sql, arguments).await
148 }
149
150 async fn fetch_all(self, sql: &str, arguments: &[Value]) -> Result<Vec<AnyRow>, Error> {
151 let AnyPool::Sqlx(pool) = self;
152 fetch_all_sqlx(pool, sql, arguments).await
153 }
154
155 async fn fetch_optional(self, sql: &str, arguments: &[Value]) -> Result<Option<AnyRow>, Error> {
156 let AnyPool::Sqlx(pool) = self;
157 fetch_optional_sqlx(pool, sql, arguments).await
158 }
159
160 async fn fetch_one(self, sql: &str, arguments: &[Value]) -> Result<AnyRow, Error> {
161 let AnyPool::Sqlx(pool) = self;
162 fetch_one_sqlx(pool, sql, arguments).await
163 }
164}
165
166impl Executor for &mut AnyConnection {
167 type Database = Any;
168
169 async fn execute(self, sql: &str, arguments: &[Value]) -> Result<AnyQueryResult, Error> {
170 execute_sqlx(&mut self.conn, sql, arguments).await
171 }
172
173 async fn fetch_all(self, sql: &str, arguments: &[Value]) -> Result<Vec<AnyRow>, Error> {
174 fetch_all_sqlx(&mut self.conn, sql, arguments).await
175 }
176
177 async fn fetch_optional(self, sql: &str, arguments: &[Value]) -> Result<Option<AnyRow>, Error> {
178 fetch_optional_sqlx(&mut self.conn, sql, arguments).await
179 }
180
181 async fn fetch_one(self, sql: &str, arguments: &[Value]) -> Result<AnyRow, Error> {
182 fetch_one_sqlx(&mut self.conn, sql, arguments).await
183 }
184}
185
186async fn execute_sqlx<'e, E>(executor: E, sql: &str, arguments: &[Value]) -> Result<AnyQueryResult, Error>
187where E: SqlxExecutor<'e, Database = Any>
188{
189 let mut query = sqlx::query(sql);
190 for arg in arguments {
191 query = bind_json_value(query, arg);
192 }
193 let res = query.execute(executor).await.map_err(|e| Error::Database(e.to_string()))?;
194 Ok(AnyQueryResult {
195 rows_affected: res.rows_affected(),
196 last_insert_id: res.last_insert_id(),
197 })
198}
199
200async fn fetch_all_sqlx<'e, E>(executor: E, sql: &str, arguments: &[Value]) -> Result<Vec<AnyRow>, Error>
201where E: SqlxExecutor<'e, Database = Any>
202{
203 let mut query = sqlx::query(sql);
204 for arg in arguments {
205 query = bind_json_value(query, arg);
206 }
207 let rows = query.fetch_all(executor).await.map_err(|e| Error::Database(e.to_string()))?;
208 Ok(rows.into_iter().map(sqlx_row_to_any_row).collect())
209}
210
211async fn fetch_optional_sqlx<'e, E>(executor: E, sql: &str, arguments: &[Value]) -> Result<Option<AnyRow>, Error>
212where E: SqlxExecutor<'e, Database = Any>
213{
214 let mut query = sqlx::query(sql);
215 for arg in arguments {
216 query = bind_json_value(query, arg);
217 }
218 let row = query.fetch_optional(executor).await.map_err(|e| Error::Database(e.to_string()))?;
219 Ok(row.map(sqlx_row_to_any_row))
220}
221
222async fn fetch_one_sqlx<'e, E>(executor: E, sql: &str, arguments: &[Value]) -> Result<AnyRow, Error>
223where E: SqlxExecutor<'e, Database = Any>
224{
225 let mut query = sqlx::query(sql);
226 for arg in arguments {
227 query = bind_json_value(query, arg);
228 }
229 let row = query.fetch_one(executor).await.map_err(|e| Error::Database(e.to_string()))?;
230 Ok(sqlx_row_to_any_row(row))
231}
232
233fn bind_json_value<'q>(query: sqlx::query::Query<'q, Any, sqlx::any::AnyArguments<'q>>, val: &'q Value) -> sqlx::query::Query<'q, Any, sqlx::any::AnyArguments<'q>> {
234 match val {
235 Value::Null => query.bind(None::<String>),
236 Value::Bool(b) => query.bind(*b),
237 Value::Number(n) => {
238 if let Some(i) = n.as_i64() {
239 query.bind(i)
240 } else if let Some(f) = n.as_f64() {
241 query.bind(f)
242 } else {
243 query.bind(0.0)
244 }
245 }
246 Value::String(s) => query.bind(s.as_str()),
247 _ => query.bind(val.to_string()),
248 }
249}
250
251fn sqlx_row_to_any_row(row: SqlxAnyRow) -> AnyRow {
252 let mut columns = Vec::new();
253 let mut values = Vec::new();
254
255 for col in row.columns() {
256 columns.push(AnyColumn {
257 name: col.name().to_string(),
258 type_info: AnyTypeInfo {
259 name: col.type_info().name().to_string(),
260 },
261 });
262
263 let val: DbValue = match row.try_get_raw(col.ordinal()) {
264 Ok(raw_val) => {
265 if raw_val.is_null() {
266 DbValue::Null
267 } else {
268 if let Ok(v) = row.try_get::<i64, _>(col.ordinal()) {
269 DbValue::Integer(v)
270 } else if let Ok(v) = row.try_get::<f64, _>(col.ordinal()) {
271 DbValue::Real(v)
272 } else if let Ok(v) = row.try_get::<bool, _>(col.ordinal()) {
273 DbValue::Bool(v)
274 } else if let Ok(v) = row.try_get::<String, _>(col.ordinal()) {
275 DbValue::Text(v)
276 } else if let Ok(v) = row.try_get::<Vec<u8>, _>(col.ordinal()) {
277 DbValue::Blob(v)
278 } else {
279 DbValue::Text(format!("{:?}", raw_val))
280 }
281 }
282 }
283 Err(_) => DbValue::Null,
284 };
285 values.push(val);
286 }
287
288 AnyRow { columns, values }
289}