1use std::{any::Any, marker::PhantomData, ops::Deref};
2
3use futures_util::TryStreamExt;
4use sqlx_core::{
5 Either, Error,
6 any::{AnyConnection, AnyPool},
7 arguments::Arguments,
8 database::Database,
9 describe::Describe,
10 encode::Encode,
11 executor::{Execute, Executor},
12 pool::PoolConnection,
13 sql_str::SqlStr,
14 try_stream,
15 types::Type,
16};
17pub trait DatabaseDialect: Send + Sync {
22 fn backend_name(&self) -> &str;
24 fn placeholder_fn(&self) -> Option<fn(usize, &mut String)>;
33 fn write_count_sql(&self, sql: &mut String);
38 fn write_pagination_sql<'q, DB>(
49 &self,
50 sql: &mut String,
51 pagination_size: i64,
52 pagination_no: i64,
53 arg: &mut DB::Arguments,
54 ) -> Result<(), Error>
55 where
56 DB: Database,
57 i64: Encode<'q, DB> + Type<DB>;
58}
59
/// Database backends this module knows how to generate SQL for.
///
/// The variant names match the backend name strings accepted by `DBType::new`
/// (`"PostgreSQL"`, `"MySQL"`, `"SQLite"`).
///
/// This is a plain fieldless tag, so it is `Copy` and can be used as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DBType {
    /// PostgreSQL; bind placeholders are written as `$N`.
    PostgreSQL,
    /// MySQL; bind placeholders are written as `?`.
    MySQL,
    /// SQLite; bind placeholders are written as `?`.
    SQLite,
}
70impl DBType {
71 pub fn new(db_name: &str) -> Result<Self, Error> {
84 match db_name {
85 "PostgreSQL" => Ok(Self::PostgreSQL),
86 "MySQL" => Ok(Self::MySQL),
87 "SQLite" => Ok(Self::SQLite),
88 _ => Err(Error::Protocol(format!("unsupported db `{db_name}`"))),
89 }
90 }
91}
92
93impl DatabaseDialect for DBType {
94 fn backend_name(&self) -> &str {
95 match self {
96 Self::PostgreSQL => "PostgreSQL",
97 Self::MySQL => "MySQL",
98 Self::SQLite => "SQLite",
99 }
100 }
101 fn placeholder_fn(&self) -> Option<fn(usize, &mut String)> {
110 match self {
111 Self::PostgreSQL => Some(|i: usize, s: &mut String| s.push_str(&format!("${i}"))),
112 Self::MySQL | Self::SQLite => Some(|_: usize, s: &mut String| s.push('?')),
113 }
114 }
115 fn write_count_sql(&self, sql: &mut String) {
120 match self {
121 Self::PostgreSQL | DBType::MySQL | DBType::SQLite => {
122 pg_mysql_sqlite_count_sql(sql);
123 }
124 }
125 }
126 fn write_pagination_sql<'q, DB>(
137 &self,
138 sql: &mut String,
139 pagination_size: i64,
140 pagination_no: i64,
141
142 arg: &mut DB::Arguments,
143 ) -> Result<(), Error>
144 where
145 DB: Database,
146 i64: Encode<'q, DB> + Type<DB>,
147 {
148 let f = self.placeholder_fn();
149 match self {
150 Self::PostgreSQL | DBType::MySQL | DBType::SQLite => {
151 pg_mysql_sqlite_pagination_sql(sql, pagination_size, pagination_no, f, arg)?;
152 Ok(())
153 }
154 }
155 }
156}
157
/// Rewrites `sql` in place into `select count(1) from (<sql>) t`.
///
/// The original query becomes a derived table aliased `t`, and its rows are counted.
fn pg_mysql_sqlite_count_sql(sql: &mut String) {
    sql.insert_str(0, "select count(1) from (");
    sql.push_str(") t");
}
162fn pg_mysql_sqlite_pagination_sql<'q, DB>(
164 sql: &mut String,
165 mut pagination_size: i64,
166 mut pagination_no: i64,
167 f: Option<fn(usize, &mut String)>,
168 arg: &mut DB::Arguments,
169) -> Result<(), Error>
170where
171 DB: Database,
172 i64: Encode<'q, DB> + Type<DB>,
173{
174 if pagination_size < 1 {
175 pagination_size = 1
176 }
177 if pagination_no < 1 {
178 pagination_no = 1
179 }
180 let offset = (pagination_no - 1) * pagination_size;
181 if let Some(f) = f {
182 sql.push_str(" limit ");
183 arg.add(pagination_size).map_err(Error::Encode)?;
184 f(arg.len(), sql);
185 sql.push_str(" offset ");
186 arg.add(offset).map_err(Error::Encode)?;
187 f(arg.len(), sql);
188 } else {
189 sql.push_str(" limit ");
190 arg.add(pagination_size).map_err(Error::Encode)?;
191 arg.format_placeholder(sql)
192 .map_err(|e| Error::Encode(Box::new(e)))?;
193
194 sql.push_str(" offset ");
195 arg.add(offset).map_err(Error::Encode)?;
196 arg.format_placeholder(sql)
197 .map_err(|e| Error::Encode(Box::new(e)))?;
198 }
199
200 Ok(())
201}
202
203pub trait BackendDB<'c, DB>: Send
217where
218 DB: Database,
219{
220 type Executor: Executor<'c, Database = DB> + 'c;
221 type DatabaseDialect: DatabaseDialect;
222 fn backend_db(
223 self,
224 ) -> impl std::future::Future<Output = Result<(Self::DatabaseDialect, Self::Executor), Error>> + Send;
225}
226impl<'c, DB, C, C1> BackendDB<'c, DB> for C
227where
228 DB: Database,
229 C: Executor<'c, Database = DB> + 'c + Deref<Target = C1>,
230 C1: Any,
231 for<'c1> &'c1 mut DB::Connection: Executor<'c1, Database = DB>,
232{
233 type DatabaseDialect = DBType;
234 type Executor = AdapterExecutor<'c, DB, C>;
235 async fn backend_db(self) -> Result<(Self::DatabaseDialect, Self::Executor), Error> {
236 detect_backend_db(self).await
237 }
238}
239
240#[derive(Debug)]
241pub struct AdapterExecutor<'c, DB: Database, C: Executor<'c, Database = DB>> {
242 executor: Either<C, PoolConnection<DB>>,
243 _m: PhantomData<&'c ()>,
244}
245impl<'c, DB, C> AdapterExecutor<'c, DB, C>
246where
247 DB: Database,
248 C: Executor<'c, Database = DB>,
249{
250 fn new(executor: Either<C, PoolConnection<DB>>) -> Self {
251 Self {
252 executor,
253 _m: PhantomData,
254 }
255 }
256}
257
258impl<'c, DB, C> Executor<'c> for AdapterExecutor<'c, DB, C>
259where
260 DB: Database,
261 C: Executor<'c, Database = DB>,
262 for<'c1> &'c1 mut DB::Connection: Executor<'c1, Database = DB>,
263{
264 type Database = DB;
265
266 fn fetch_many<'e, 'q: 'e, E>(
267 self,
268 query: E,
269 ) -> futures_core::stream::BoxStream<
270 'e,
271 Result<
272 Either<<Self::Database as Database>::QueryResult, <Self::Database as Database>::Row>,
273 Error,
274 >,
275 >
276 where
277 'c: 'e,
278 E: 'q + Execute<'q, Self::Database>,
279 {
280 match self.executor {
281 Either::Left(executor) => executor.fetch_many(query),
282 Either::Right(mut conn) => Box::pin(try_stream! {
283
284
285 let mut s = conn.fetch_many(query);
286
287 while let Some(v) = s.try_next().await? {
288 r#yield!(v);
289 }
290
291 Ok(())
292 }),
293 }
294 }
295
296 fn fetch_optional<'e, 'q: 'e, E>(
297 self,
298 query: E,
299 ) -> futures_core::future::BoxFuture<'e, Result<Option<<Self::Database as Database>::Row>, Error>>
300 where
301 'c: 'e,
302 E: 'q + Execute<'q, Self::Database>,
303 {
304 match self.executor {
305 Either::Left(executor) => executor.fetch_optional(query),
306 Either::Right(mut conn) => Box::pin(async move { conn.fetch_optional(query).await }),
307 }
308 }
309
310 fn prepare_with<'e>(
311 self,
312 sql: SqlStr,
313 parameters: &'e [<Self::Database as Database>::TypeInfo],
314 ) -> futures_core::future::BoxFuture<'e, Result<<Self::Database as Database>::Statement, Error>>
315 where
316 'c: 'e,
317 {
318 match self.executor {
319 Either::Left(executor) => executor.prepare_with(sql, parameters),
320 Either::Right(mut conn) => {
321 Box::pin(async move { conn.prepare_with(sql, parameters).await })
322 }
323 }
324 }
325
326 fn describe<'e>(
327 self,
328 sql: SqlStr,
329 ) -> futures_core::future::BoxFuture<'e, Result<Describe<Self::Database>, Error>>
330 where
331 'c: 'e,
332 {
333 match self.executor {
334 Either::Left(executor) => executor.describe(sql),
335 Either::Right(mut conn) => Box::pin(async move { conn.describe(sql).await }),
336 }
337 }
338}
339pub async fn detect_backend_db<'c, DB, C, C1>(
346 c: C,
347) -> Result<(DBType, AdapterExecutor<'c, DB, C>), Error>
348where
349 DB: Database,
350 C: Executor<'c, Database = DB> + 'c + Deref<Target = C1>,
351 C1: Any,
352{
353 if DB::NAME != sqlx_core::any::Any::NAME {
354 return Ok((
355 DBType::new(DB::NAME)?,
356 AdapterExecutor::new(Either::Left(c)),
357 ));
358 }
359
360 let any_ref = c.deref() as &dyn Any;
361 if let Some(conn) = any_ref.downcast_ref::<AnyConnection>() {
363 return Ok((
364 DBType::new(conn.backend_name())?,
365 AdapterExecutor::new(Either::Left(c)),
366 ));
367 }
368
369 if let Some(pool) = any_ref.downcast_ref::<AnyPool>() {
371 let conn = pool.acquire().await?;
372
373 let db_type = DBType::new(conn.backend_name())?;
374 let db_con: Box<dyn Any> = Box::new(conn);
375 let return_con = db_con
376 .downcast::<PoolConnection<DB>>()
377 .map_err(|_| Error::Protocol(format!("unsupported db `{}`", DB::NAME)))?;
378
379 return Ok((db_type, AdapterExecutor::new(Either::Right(*return_con))));
380 }
381 Err(Error::Protocol(format!("unsupported db `{}`", DB::NAME)))
382}