1use std::{any::Any, marker::PhantomData, ops::Deref};
2
3use futures_util::TryStreamExt;
4use sqlx_core::{
5 Either, Error,
6 any::{AnyConnection, AnyPool},
7 arguments::Arguments,
8 database::Database,
9 describe::Describe,
10 encode::Encode,
11 executor::{Execute, Executor},
12 pool::PoolConnection,
13 try_stream,
14 types::Type,
15};
/// Backend-specific SQL generation hooks (placeholders, row counting, pagination).
pub trait DatabaseDialect {
    /// Name of the backend this dialect targets.
    fn backend_name(&self) -> &str;
    /// Returns a function that appends the placeholder for a bind parameter to a
    /// SQL string, given the parameter's index (callers in this module pass the
    /// argument count right after binding, i.e. a 1-based index).
    ///
    /// `None` means callers should fall back to `Arguments::format_placeholder`.
    fn get_encode_placeholder_fn(&self) -> Option<fn(usize, &mut String)>;
    /// Rewrites `sql` in place into a query that counts the rows `sql` returns.
    fn write_count_sql(&self, sql: &mut String);
    /// Appends pagination to `sql` for page `page_no` of `page_size` rows,
    /// binding the pagination values into `arg`.
    ///
    /// # Errors
    ///
    /// Returns an error if a value cannot be bound or a placeholder cannot be
    /// written.
    fn write_page_sql<'c, 'q, DB>(
        &self,
        sql: &mut String,
        page_size: i64,
        page_no: i64,
        arg: &mut DB::Arguments<'q>,
    ) -> Result<(), Error>
    where
        DB: Database,
        i64: Encode<'q, DB> + Type<DB>;
}
57
/// A database backend supported by the SQL helpers in this module.
///
/// The type is a plain fieldless enum, so it is cheap to copy and can be used
/// as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DBType {
    /// PostgreSQL (numbered `$n` bind placeholders).
    PostgreSQL,
    /// MySQL (positional `?` bind placeholders).
    MySQL,
    /// SQLite (positional `?` bind placeholders).
    SQLite,
}
68impl DBType {
69 pub fn new(db_name: &str) -> Result<Self, Error> {
82 match db_name {
83 "PostgreSQL" => Ok(Self::PostgreSQL),
84 "MySQL" => Ok(Self::MySQL),
85 "SQLite" => Ok(Self::SQLite),
86 _ => Err(Error::Protocol(format!("unsupport db `{}`", db_name))),
87 }
88 }
89}
90
91impl DatabaseDialect for DBType {
92 fn backend_name(&self) -> &str {
93 match self {
94 Self::PostgreSQL => "PostgreSQL",
95 Self::MySQL => "MySQL",
96 Self::SQLite => "SQLite",
97 }
98 }
99 fn get_encode_placeholder_fn(&self) -> Option<fn(usize, &mut String)> {
108 match self {
109 Self::PostgreSQL => Some(|i: usize, s: &mut String| s.push_str(&format!("${}", i))),
110 Self::MySQL | Self::SQLite => Some(|_: usize, s: &mut String| s.push('?')),
111 }
112 }
113 fn write_count_sql(&self, sql: &mut String) {
118 match self {
119 Self::PostgreSQL | DBType::MySQL | DBType::SQLite => {
120 pg_mysql_sqlite_count_sql(sql);
121 }
122 }
123 }
124 fn write_page_sql<'c, 'q, DB>(
135 &self,
136 sql: &mut String,
137 page_size: i64,
138 page_no: i64,
139
140 arg: &mut DB::Arguments<'q>,
141 ) -> Result<(), Error>
142 where
143 DB: Database,
144 i64: Encode<'q, DB> + Type<DB>,
145 {
146 let f = self.get_encode_placeholder_fn();
147 match self {
148 Self::PostgreSQL | DBType::MySQL | DBType::SQLite => {
149 pg_mysql_sqlite_page_sql(sql, page_size, page_no, f, arg)?;
150 Ok(())
151 }
152 }
153 }
154}
/// Wraps `sql` as a subquery so the statement yields the number of rows the
/// original query returns: `select count(1) from (<sql>) t`.
fn pg_mysql_sqlite_count_sql(sql: &mut String) {
    const PREFIX: &str = "select count(1) from (";
    const SUFFIX: &str = ") t";
    sql.insert_str(0, PREFIX);
    sql.push_str(SUFFIX);
}
158fn pg_mysql_sqlite_page_sql<'c, 'q, DB>(
159 sql: &mut String,
160 mut page_size: i64,
161 mut page_no: i64,
162 f: Option<fn(usize, &mut String)>,
163 arg: &mut DB::Arguments<'q>,
164) -> Result<(), Error>
165where
166 DB: Database,
167 i64: Encode<'q, DB> + Type<DB>,
168{
169 if page_size < 1 {
170 page_size = 1
171 }
172 if page_no < 1 {
173 page_no = 1
174 }
175 let offset = (page_no - 1) * page_size;
176 if let Some(f) = f {
177 sql.push_str(" limit ");
178 arg.add(page_size).map_err(Error::Encode)?;
179 f(arg.len(), sql);
180 sql.push_str(" offset ");
181 arg.add(offset).map_err(Error::Encode)?;
182 f(arg.len(), sql);
183 } else {
184 sql.push_str(" limit ");
185 arg.add(page_size).map_err(Error::Encode)?;
186 arg.format_placeholder(sql)
187 .map_err(|e| Error::Encode(Box::new(e)))?;
188
189 sql.push_str(" offset ");
190 arg.add(offset).map_err(Error::Encode)?;
191 arg.format_placeholder(sql)
192 .map_err(|e| Error::Encode(Box::new(e)))?;
193 }
194
195 Ok(())
196}
197
/// Executors that can report which SQL dialect they are connected to.
pub trait BackendDB<'c, DB>
where
    DB: Database,
{
    /// Consumes the executor and resolves its dialect, returning it together
    /// with an executor to run queries against the same database.
    ///
    /// # Errors
    ///
    /// Implementations return an error when the backend cannot be identified
    /// or is not supported.
    fn backend_db(
        self,
    ) -> impl std::future::Future<
        Output = Result<(impl DatabaseDialect, impl Executor<'c, Database = DB> + 'c), Error>,
    > + Send;
}
221impl<'c, DB, C, C1> BackendDB<'c, DB> for C
222where
223 DB: Database,
224 C: Executor<'c, Database = DB> + 'c + Deref<Target = C1>,
225 C1: Any,
226 for<'c1> &'c1 mut DB::Connection: Executor<'c1, Database = DB>,
227{
228 async fn backend_db(
229 self,
230 ) -> Result<(impl DatabaseDialect, impl Executor<'c, Database = DB> + 'c), Error> {
231 backend_db(self).await
232 }
233}
/// Executor wrapper produced by [`backend_db`].
///
/// Holds either the caller's own executor (`Left`) or a pool connection that
/// was acquired on the caller's behalf (`Right`).
#[derive(Debug)]
pub struct AdapterExecutor<'c, DB: Database, C: Executor<'c, Database = DB>> {
    /// The wrapped executor, or an owned pool connection.
    executor: Either<C, PoolConnection<DB>>,
    /// Carries the `'c` lifetime parameter, which no field otherwise uses.
    _m: PhantomData<&'c ()>,
}
239impl<'c, DB, C> AdapterExecutor<'c, DB, C>
240where
241 DB: Database,
242 C: Executor<'c, Database = DB>,
243{
244 fn new(executor: Either<C, PoolConnection<DB>>) -> Self {
245 Self {
246 executor,
247 _m: PhantomData,
248 }
249 }
250}
251
/// Forwards every `Executor` operation to the wrapped executor.
///
/// A `Left` executor is delegated to directly. A `Right` pool connection is
/// captured by the returned stream/future and queried through
/// `&mut DB::Connection` (hence the higher-ranked bound below), so it stays
/// alive for the whole operation and is dropped together with it.
impl<'c, DB, C> Executor<'c> for AdapterExecutor<'c, DB, C>
where
    DB: Database,
    C: Executor<'c, Database = DB>,
    for<'c1> &'c1 mut DB::Connection: Executor<'c1, Database = DB>,
{
    type Database = DB;

    /// Streams query results and row counts from the wrapped executor.
    fn fetch_many<'e, 'q: 'e, E>(
        self,
        query: E,
    ) -> futures_core::stream::BoxStream<
        'e,
        Result<
            Either<<Self::Database as Database>::QueryResult, <Self::Database as Database>::Row>,
            Error,
        >,
    >
    where
        'c: 'e,
        E: 'q + Execute<'q, Self::Database>,
    {
        match self.executor {
            Either::Left(executor) => executor.fetch_many(query),
            // The owned connection lives inside the stream; items from the
            // inner stream are re-yielded one by one and the first error ends it.
            Either::Right(mut conn) => Box::pin(try_stream! {


                let mut s = conn.fetch_many(query);

                while let Some(v) = s.try_next().await? {
                    r#yield!(v);
                }

                Ok(())
            }),
        }
    }

    /// Fetches at most one row.
    fn fetch_optional<'e, 'q: 'e, E>(
        self,
        query: E,
    ) -> futures_core::future::BoxFuture<'e, Result<Option<<Self::Database as Database>::Row>, Error>>
    where
        'c: 'e,
        E: 'q + Execute<'q, Self::Database>,
    {
        match self.executor {
            Either::Left(executor) => executor.fetch_optional(query),
            Either::Right(mut conn) => Box::pin(async move { conn.fetch_optional(query).await }),
        }
    }

    /// Prepares `sql` with the given parameter types.
    fn prepare_with<'e, 'q: 'e>(
        self,
        sql: &'q str,
        parameters: &'e [<Self::Database as Database>::TypeInfo],
    ) -> futures_core::future::BoxFuture<
        'e,
        Result<<Self::Database as Database>::Statement<'q>, Error>,
    >
    where
        'c: 'e,
    {
        match self.executor {
            Either::Left(executor) => executor.prepare_with(sql, parameters),
            Either::Right(mut conn) => {
                Box::pin(async move { conn.prepare_with(sql, parameters).await })
            }
        }
    }

    /// Describes the statement `sql`.
    fn describe<'e, 'q: 'e>(
        self,
        sql: &'q str,
    ) -> futures_core::future::BoxFuture<'e, Result<Describe<Self::Database>, Error>>
    where
        'c: 'e,
    {
        match self.executor {
            Either::Left(executor) => executor.describe(sql),
            Either::Right(mut conn) => Box::pin(async move { conn.describe(sql).await }),
        }
    }
}
336pub async fn backend_db<'c, DB, C, C1>(c: C) -> Result<(DBType, AdapterExecutor<'c, DB, C>), Error>
337where
338 DB: Database,
339 C: Executor<'c, Database = DB> + 'c + Deref<Target = C1>,
340 C1: Any + 'static,
341{
342 if DB::NAME != sqlx_core::any::Any::NAME {
343 return Ok((
344 DBType::new(DB::NAME)?,
345 AdapterExecutor::new(Either::Left(c)),
346 ));
347 }
348
349 let any_ref = c.deref() as &dyn Any;
350 if let Some(conn) = any_ref.downcast_ref::<AnyConnection>() {
352 return Ok((
353 DBType::new(conn.backend_name())?,
354 AdapterExecutor::new(Either::Left(c)),
355 ));
356 }
357
358 if let Some(pool) = any_ref.downcast_ref::<AnyPool>() {
360 let conn = pool.acquire().await?;
361
362 let db_type = DBType::new(conn.backend_name())?;
363 let db_con: Box<dyn Any> = Box::new(conn);
364 let return_con = db_con
365 .downcast::<PoolConnection<DB>>()
366 .map_err(|_| Error::Protocol(format!("unsupport db `{}`", DB::NAME)))?;
367
368 return Ok((db_type, AdapterExecutor::new(Either::Right(*return_con))));
369 }
370 Err(Error::Protocol(format!("unsupport db `{}`", DB::NAME)))
371}