1use std::{any::Any, marker::PhantomData, ops::Deref};
2
3use futures_util::TryStreamExt;
4use sqlx_core::{
5 Either, Error,
6 any::{AnyConnection, AnyPool},
7 arguments::Arguments,
8 database::Database,
9 describe::Describe,
10 encode::Encode,
11 executor::{Execute, Executor},
12 pool::PoolConnection,
13 sql_str::SqlStr,
14 try_stream,
15 types::Type,
16};
17pub trait DatabaseDialect {
22 fn backend_name(&self) -> &str;
24 fn get_encode_placeholder_fn(&self) -> Option<fn(usize, &mut String)>;
33 fn write_count_sql(&self, sql: &mut String);
38 fn write_page_sql<'c, 'q, DB>(
49 &self,
50 sql: &mut String,
51 page_size: i64,
52 page_no: i64,
53 arg: &mut DB::Arguments,
54 ) -> Result<(), Error>
55 where
56 DB: Database,
57 i64: Encode<'q, DB> + Type<DB>;
58}
59
/// The database backends this module knows how to paginate for.
///
/// The type is a plain tag without payload, so it is cheap to copy, compare
/// and hash.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum DBType {
    /// PostgreSQL backend.
    PostgreSQL,
    /// MySQL backend.
    MySQL,
    /// SQLite backend.
    SQLite,
}
70impl DBType {
71 pub fn new(db_name: &str) -> Result<Self, Error> {
84 match db_name {
85 "PostgreSQL" => Ok(Self::PostgreSQL),
86 "MySQL" => Ok(Self::MySQL),
87 "SQLite" => Ok(Self::SQLite),
88 _ => Err(Error::Protocol(format!("unsupport db `{db_name}`"))),
89 }
90 }
91}
92
93impl DatabaseDialect for DBType {
94 fn backend_name(&self) -> &str {
95 match self {
96 Self::PostgreSQL => "PostgreSQL",
97 Self::MySQL => "MySQL",
98 Self::SQLite => "SQLite",
99 }
100 }
101 fn get_encode_placeholder_fn(&self) -> Option<fn(usize, &mut String)> {
110 match self {
111 Self::PostgreSQL => Some(|i: usize, s: &mut String| s.push_str(&format!("${i}"))),
112 Self::MySQL | Self::SQLite => Some(|_: usize, s: &mut String| s.push('?')),
113 }
114 }
115 fn write_count_sql(&self, sql: &mut String) {
120 match self {
121 Self::PostgreSQL | DBType::MySQL | DBType::SQLite => {
122 pg_mysql_sqlite_count_sql(sql);
123 }
124 }
125 }
126 fn write_page_sql<'c, 'q, DB>(
137 &self,
138 sql: &mut String,
139 page_size: i64,
140 page_no: i64,
141
142 arg: &mut DB::Arguments,
143 ) -> Result<(), Error>
144 where
145 DB: Database,
146 i64: Encode<'q, DB> + Type<DB>,
147 {
148 let f = self.get_encode_placeholder_fn();
149 match self {
150 Self::PostgreSQL | DBType::MySQL | DBType::SQLite => {
151 pg_mysql_sqlite_page_sql(sql, page_size, page_no, f, arg)?;
152 Ok(())
153 }
154 }
155 }
156}
/// Wraps the query in `sql` as a sub-select and counts its rows.
///
/// After the call `sql` reads `select count(1) from (<original sql>) t`.
fn pg_mysql_sqlite_count_sql(sql: &mut String) {
    // Edit the buffer in place: prefix first, then close the sub-select.
    sql.insert_str(0, "select count(1) from (");
    sql.push_str(") t");
}
160fn pg_mysql_sqlite_page_sql<'c, 'q, DB>(
161 sql: &mut String,
162 mut page_size: i64,
163 mut page_no: i64,
164 f: Option<fn(usize, &mut String)>,
165 arg: &mut DB::Arguments,
166) -> Result<(), Error>
167where
168 DB: Database,
169 i64: Encode<'q, DB> + Type<DB>,
170{
171 if page_size < 1 {
172 page_size = 1
173 }
174 if page_no < 1 {
175 page_no = 1
176 }
177 let offset = (page_no - 1) * page_size;
178 if let Some(f) = f {
179 sql.push_str(" limit ");
180 arg.add(page_size).map_err(Error::Encode)?;
181 f(arg.len(), sql);
182 sql.push_str(" offset ");
183 arg.add(offset).map_err(Error::Encode)?;
184 f(arg.len(), sql);
185 } else {
186 sql.push_str(" limit ");
187 arg.add(page_size).map_err(Error::Encode)?;
188 arg.format_placeholder(sql)
189 .map_err(|e| Error::Encode(Box::new(e)))?;
190
191 sql.push_str(" offset ");
192 arg.add(offset).map_err(Error::Encode)?;
193 arg.format_placeholder(sql)
194 .map_err(|e| Error::Encode(Box::new(e)))?;
195 }
196
197 Ok(())
198}
199
200pub trait BackendDB<'c, DB>
214where
215 DB: Database,
216{
217 type Executor: Executor<'c, Database = DB> + 'c;
218 type DatabaseDialect: DatabaseDialect;
219 fn backend_db(
220 self,
221 ) -> impl std::future::Future<Output = Result<(Self::DatabaseDialect, Self::Executor), Error>> + Send;
222}
223impl<'c, DB, C, C1> BackendDB<'c, DB> for C
224where
225 DB: Database,
226 C: Executor<'c, Database = DB> + 'c + Deref<Target = C1>,
227 C1: Any,
228 for<'c1> &'c1 mut DB::Connection: Executor<'c1, Database = DB>,
229{
230 type DatabaseDialect = DBType;
231 type Executor = AdapterExecutor<'c, DB, C>;
232 async fn backend_db(self) -> Result<(Self::DatabaseDialect, Self::Executor), Error> {
233 backend_db(self).await
234 }
235}
236#[derive(Debug)]
237pub struct AdapterExecutor<'c, DB: Database, C: Executor<'c, Database = DB>> {
238 executor: Either<C, PoolConnection<DB>>,
239 _m: PhantomData<&'c ()>,
240}
241impl<'c, DB, C> AdapterExecutor<'c, DB, C>
242where
243 DB: Database,
244 C: Executor<'c, Database = DB>,
245{
246 fn new(executor: Either<C, PoolConnection<DB>>) -> Self {
247 Self {
248 executor,
249 _m: PhantomData,
250 }
251 }
252}
253
254impl<'c, DB, C> Executor<'c> for AdapterExecutor<'c, DB, C>
255where
256 DB: Database,
257 C: Executor<'c, Database = DB>,
258 for<'c1> &'c1 mut DB::Connection: Executor<'c1, Database = DB>,
259{
260 type Database = DB;
261
262 fn fetch_many<'e, 'q: 'e, E>(
263 self,
264 query: E,
265 ) -> futures_core::stream::BoxStream<
266 'e,
267 Result<
268 Either<<Self::Database as Database>::QueryResult, <Self::Database as Database>::Row>,
269 Error,
270 >,
271 >
272 where
273 'c: 'e,
274 E: 'q + Execute<'q, Self::Database>,
275 {
276 match self.executor {
277 Either::Left(executor) => executor.fetch_many(query),
278 Either::Right(mut conn) => Box::pin(try_stream! {
279
280
281 let mut s = conn.fetch_many(query);
282
283 while let Some(v) = s.try_next().await? {
284 r#yield!(v);
285 }
286
287 Ok(())
288 }),
289 }
290 }
291
292 fn fetch_optional<'e, 'q: 'e, E>(
293 self,
294 query: E,
295 ) -> futures_core::future::BoxFuture<'e, Result<Option<<Self::Database as Database>::Row>, Error>>
296 where
297 'c: 'e,
298 E: 'q + Execute<'q, Self::Database>,
299 {
300 match self.executor {
301 Either::Left(executor) => executor.fetch_optional(query),
302 Either::Right(mut conn) => Box::pin(async move { conn.fetch_optional(query).await }),
303 }
304 }
305
306 fn prepare_with<'e>(
307 self,
308 sql: SqlStr,
309 parameters: &'e [<Self::Database as Database>::TypeInfo],
310 ) -> futures_core::future::BoxFuture<'e, Result<<Self::Database as Database>::Statement, Error>>
311 where
312 'c: 'e,
313 {
314 match self.executor {
315 Either::Left(executor) => executor.prepare_with(sql, parameters),
316 Either::Right(mut conn) => {
317 Box::pin(async move { conn.prepare_with(sql, parameters).await })
318 }
319 }
320 }
321
322 fn describe<'e>(
323 self,
324 sql: SqlStr,
325 ) -> futures_core::future::BoxFuture<'e, Result<Describe<Self::Database>, Error>>
326 where
327 'c: 'e,
328 {
329 match self.executor {
330 Either::Left(executor) => executor.describe(sql),
331 Either::Right(mut conn) => Box::pin(async move { conn.describe(sql).await }),
332 }
333 }
334}
335pub async fn backend_db<'c, DB, C, C1>(c: C) -> Result<(DBType, AdapterExecutor<'c, DB, C>), Error>
336where
337 DB: Database,
338 C: Executor<'c, Database = DB> + 'c + Deref<Target = C1>,
339 C1: Any + 'static,
340{
341 if DB::NAME != sqlx_core::any::Any::NAME {
342 return Ok((
343 DBType::new(DB::NAME)?,
344 AdapterExecutor::new(Either::Left(c)),
345 ));
346 }
347
348 let any_ref = c.deref() as &dyn Any;
349 if let Some(conn) = any_ref.downcast_ref::<AnyConnection>() {
351 return Ok((
352 DBType::new(conn.backend_name())?,
353 AdapterExecutor::new(Either::Left(c)),
354 ));
355 }
356
357 if let Some(pool) = any_ref.downcast_ref::<AnyPool>() {
359 let conn = pool.acquire().await?;
360
361 let db_type = DBType::new(conn.backend_name())?;
362 let db_con: Box<dyn Any> = Box::new(conn);
363 let return_con = db_con
364 .downcast::<PoolConnection<DB>>()
365 .map_err(|_| Error::Protocol(format!("unsupport db `{}`", DB::NAME)))?;
366
367 return Ok((db_type, AdapterExecutor::new(Either::Right(*return_con))));
368 }
369 Err(Error::Protocol(format!("unsupport db `{}`", DB::NAME)))
370}