1use std::{any::Any, marker::PhantomData, ops::Deref};
2
3use futures_util::TryStreamExt;
4use sqlx_core::{
5 Either, Error,
6 any::{AnyConnection, AnyPool},
7 arguments::Arguments,
8 database::Database,
9 describe::Describe,
10 encode::Encode,
11 executor::{Execute, Executor},
12 pool::PoolConnection,
13 try_stream,
14 types::Type,
15};
16pub trait DatabaseDialect {
21 fn backend_name(&self) -> &str;
23 fn get_encode_placeholder_fn(&self) -> Option<fn(usize, &mut String)>;
32 fn write_count_sql(&self, sql: &mut String);
37 fn write_page_sql<'c, 'q, DB>(
48 &self,
49 sql: &mut String,
50 page_size: i64,
51 page_no: i64,
52 arg: &mut DB::Arguments<'q>,
53 ) -> Result<(), Error>
54 where
55 DB: Database,
56 i64: Encode<'q, DB> + Type<DB>;
57}
58
/// SQL backends whose dialect this module knows how to generate.
///
/// The enum is fieldless, so it is `Copy` and can be used as a set/map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DBType {
    /// PostgreSQL: numbered placeholders (`$1`, `$2`, ...).
    PostgreSQL,
    /// MySQL: positional `?` placeholders.
    MySQL,
    /// SQLite: positional `?` placeholders.
    SQLite,
}
69impl DBType {
70 pub fn new(db_name: &str) -> Result<Self, Error> {
83 match db_name {
84 "PostgreSQL" => Ok(Self::PostgreSQL),
85 "MySQL" => Ok(Self::MySQL),
86 "SQLite" => Ok(Self::SQLite),
87 _ => Err(Error::Protocol(format!("unsupport db `{db_name}`"))),
88 }
89 }
90}
91
92impl DatabaseDialect for DBType {
93 fn backend_name(&self) -> &str {
94 match self {
95 Self::PostgreSQL => "PostgreSQL",
96 Self::MySQL => "MySQL",
97 Self::SQLite => "SQLite",
98 }
99 }
100 fn get_encode_placeholder_fn(&self) -> Option<fn(usize, &mut String)> {
109 match self {
110 Self::PostgreSQL => Some(|i: usize, s: &mut String| s.push_str(&format!("${i}"))),
111 Self::MySQL | Self::SQLite => Some(|_: usize, s: &mut String| s.push('?')),
112 }
113 }
114 fn write_count_sql(&self, sql: &mut String) {
119 match self {
120 Self::PostgreSQL | DBType::MySQL | DBType::SQLite => {
121 pg_mysql_sqlite_count_sql(sql);
122 }
123 }
124 }
125 fn write_page_sql<'c, 'q, DB>(
136 &self,
137 sql: &mut String,
138 page_size: i64,
139 page_no: i64,
140
141 arg: &mut DB::Arguments<'q>,
142 ) -> Result<(), Error>
143 where
144 DB: Database,
145 i64: Encode<'q, DB> + Type<DB>,
146 {
147 let f = self.get_encode_placeholder_fn();
148 match self {
149 Self::PostgreSQL | DBType::MySQL | DBType::SQLite => {
150 pg_mysql_sqlite_page_sql(sql, page_size, page_no, f, arg)?;
151 Ok(())
152 }
153 }
154 }
155}
/// Turns `sql` into a row-count query, in place.
///
/// `sql` is wrapped as a derived table aliased `t`:
/// `select count(1) from (<sql>) t`.
fn pg_mysql_sqlite_count_sql(sql: &mut String) {
    sql.insert_str(0, "select count(1) from (");
    sql.push_str(") t");
}
159fn pg_mysql_sqlite_page_sql<'c, 'q, DB>(
160 sql: &mut String,
161 mut page_size: i64,
162 mut page_no: i64,
163 f: Option<fn(usize, &mut String)>,
164 arg: &mut DB::Arguments<'q>,
165) -> Result<(), Error>
166where
167 DB: Database,
168 i64: Encode<'q, DB> + Type<DB>,
169{
170 if page_size < 1 {
171 page_size = 1
172 }
173 if page_no < 1 {
174 page_no = 1
175 }
176 let offset = (page_no - 1) * page_size;
177 if let Some(f) = f {
178 sql.push_str(" limit ");
179 arg.add(page_size).map_err(Error::Encode)?;
180 f(arg.len(), sql);
181 sql.push_str(" offset ");
182 arg.add(offset).map_err(Error::Encode)?;
183 f(arg.len(), sql);
184 } else {
185 sql.push_str(" limit ");
186 arg.add(page_size).map_err(Error::Encode)?;
187 arg.format_placeholder(sql)
188 .map_err(|e| Error::Encode(Box::new(e)))?;
189
190 sql.push_str(" offset ");
191 arg.add(offset).map_err(Error::Encode)?;
192 arg.format_placeholder(sql)
193 .map_err(|e| Error::Encode(Box::new(e)))?;
194 }
195
196 Ok(())
197}
198
199pub trait BackendDB<'c, DB>
213where
214 DB: Database,
215{
216 type Executor: Executor<'c, Database = DB> + 'c;
217 type DatabaseDialect: DatabaseDialect;
218 fn backend_db(
219 self,
220 ) -> impl std::future::Future<Output = Result<(Self::DatabaseDialect, Self::Executor), Error>> + Send;
221}
222impl<'c, DB, C, C1> BackendDB<'c, DB> for C
223where
224 DB: Database,
225 C: Executor<'c, Database = DB> + 'c + Deref<Target = C1>,
226 C1: Any,
227 for<'c1> &'c1 mut DB::Connection: Executor<'c1, Database = DB>,
228{
229 type DatabaseDialect = DBType;
230 type Executor = AdapterExecutor<'c, DB, C>;
231 async fn backend_db(self) -> Result<(Self::DatabaseDialect, Self::Executor), Error> {
232 backend_db(self).await
233 }
234}
/// Executor handed back by `backend_db`.
///
/// It is either the caller's own executor (`Either::Left`) or, when the caller
/// passed an `AnyPool`, a connection checked out of that pool
/// (`Either::Right`).
#[derive(Debug)]
pub struct AdapterExecutor<'c, DB: Database, C: Executor<'c, Database = DB>> {
    /// Target that every `Executor` call is forwarded to.
    executor: Either<C, PoolConnection<DB>>,
    // No field type mentions `'c` directly. This marker keeps the lifetime
    // parameter, which is required by the `Executor<'c>` bound on `C`, in use.
    _m: PhantomData<&'c ()>,
}
240impl<'c, DB, C> AdapterExecutor<'c, DB, C>
241where
242 DB: Database,
243 C: Executor<'c, Database = DB>,
244{
245 fn new(executor: Either<C, PoolConnection<DB>>) -> Self {
246 Self {
247 executor,
248 _m: PhantomData,
249 }
250 }
251}
252
/// Forwards every `Executor` operation to the wrapped target.
///
/// `Either::Left` calls straight through to the caller's executor.
/// `Either::Right` runs the operation on the owned pool connection inside a
/// boxed future/stream.
impl<'c, DB, C> Executor<'c> for AdapterExecutor<'c, DB, C>
where
    DB: Database,
    C: Executor<'c, Database = DB>,
    // Lets the connection reached through `PoolConnection`'s deref act as an
    // executor on the `Either::Right` path.
    for<'c1> &'c1 mut DB::Connection: Executor<'c1, Database = DB>,
{
    type Database = DB;

    /// Streams results from the wrapped target.
    ///
    /// On the `Right` path, `conn` is owned by the returned stream (presumably
    /// moved in by `try_stream!`). Each item is re-yielded, and the first error
    /// ends the stream.
    fn fetch_many<'e, 'q: 'e, E>(
        self,
        query: E,
    ) -> futures_core::stream::BoxStream<
        'e,
        Result<
            Either<<Self::Database as Database>::QueryResult, <Self::Database as Database>::Row>,
            Error,
        >,
    >
    where
        'c: 'e,
        E: 'q + Execute<'q, Self::Database>,
    {
        match self.executor {
            Either::Left(executor) => executor.fetch_many(query),
            Either::Right(mut conn) => Box::pin(try_stream! {
                let mut s = conn.fetch_many(query);

                while let Some(v) = s.try_next().await? {
                    r#yield!(v);
                }

                Ok(())
            }),
        }
    }

    /// Fetches at most one row. On the `Right` path, the `async move` block
    /// owns `conn` for the lifetime of the returned future.
    fn fetch_optional<'e, 'q: 'e, E>(
        self,
        query: E,
    ) -> futures_core::future::BoxFuture<'e, Result<Option<<Self::Database as Database>::Row>, Error>>
    where
        'c: 'e,
        E: 'q + Execute<'q, Self::Database>,
    {
        match self.executor {
            Either::Left(executor) => executor.fetch_optional(query),
            Either::Right(mut conn) => Box::pin(async move { conn.fetch_optional(query).await }),
        }
    }

    /// Prepares `sql` with the given parameter type hints on the wrapped target.
    fn prepare_with<'e, 'q: 'e>(
        self,
        sql: &'q str,
        parameters: &'e [<Self::Database as Database>::TypeInfo],
    ) -> futures_core::future::BoxFuture<
        'e,
        Result<<Self::Database as Database>::Statement<'q>, Error>,
    >
    where
        'c: 'e,
    {
        match self.executor {
            Either::Left(executor) => executor.prepare_with(sql, parameters),
            Either::Right(mut conn) => {
                Box::pin(async move { conn.prepare_with(sql, parameters).await })
            }
        }
    }

    /// Describes `sql` using the wrapped target.
    fn describe<'e, 'q: 'e>(
        self,
        sql: &'q str,
    ) -> futures_core::future::BoxFuture<'e, Result<Describe<Self::Database>, Error>>
    where
        'c: 'e,
    {
        match self.executor {
            Either::Left(executor) => executor.describe(sql),
            Either::Right(mut conn) => Box::pin(async move { conn.describe(sql).await }),
        }
    }
}
337pub async fn backend_db<'c, DB, C, C1>(c: C) -> Result<(DBType, AdapterExecutor<'c, DB, C>), Error>
338where
339 DB: Database,
340 C: Executor<'c, Database = DB> + 'c + Deref<Target = C1>,
341 C1: Any + 'static,
342{
343 if DB::NAME != sqlx_core::any::Any::NAME {
344 return Ok((
345 DBType::new(DB::NAME)?,
346 AdapterExecutor::new(Either::Left(c)),
347 ));
348 }
349
350 let any_ref = c.deref() as &dyn Any;
351 if let Some(conn) = any_ref.downcast_ref::<AnyConnection>() {
353 return Ok((
354 DBType::new(conn.backend_name())?,
355 AdapterExecutor::new(Either::Left(c)),
356 ));
357 }
358
359 if let Some(pool) = any_ref.downcast_ref::<AnyPool>() {
361 let conn = pool.acquire().await?;
362
363 let db_type = DBType::new(conn.backend_name())?;
364 let db_con: Box<dyn Any> = Box::new(conn);
365 let return_con = db_con
366 .downcast::<PoolConnection<DB>>()
367 .map_err(|_| Error::Protocol(format!("unsupport db `{}`", DB::NAME)))?;
368
369 return Ok((db_type, AdapterExecutor::new(Either::Right(*return_con))));
370 }
371 Err(Error::Protocol(format!("unsupport db `{}`", DB::NAME)))
372}