1use std::{any::Any, marker::PhantomData, ops::Deref};
2
3use futures_core::future::BoxFuture;
4use futures_util::{FutureExt, TryStreamExt};
5use sqlx_core::{
6 Either, Error,
7 any::{AnyConnection, AnyPool},
8 arguments::Arguments,
9 database::Database,
10 describe::Describe,
11 encode::Encode,
12 executor::{Execute, Executor},
13 pool::PoolConnection,
14 try_stream,
15 types::Type,
16};
17pub trait DatabaseDialect: Send {
22 fn backend_name(&self) -> &str;
24 fn get_encode_placeholder_fn(&self) -> Option<fn(usize, &mut String)>;
33 fn write_count_sql(&self, sql: &mut String);
38 fn write_page_sql<'c, 'q, DB>(
49 &self,
50 sql: &mut String,
51 page_size: i64,
52 page_no: i64,
53 arg: &mut DB::Arguments<'q>,
54 ) -> Result<(), Error>
55 where
56 DB: Database,
57 i64: Encode<'q, DB> + Type<DB>;
58}
59
/// Database backends supported by the SQL dialect helpers.
///
/// The enum is fieldless, so it is `Copy`. It is also `Eq`/`Hash`, so callers can
/// compare it freely, copy it out of the `(DBType, AdapterExecutor)` tuple returned by
/// `backend_db`, or use it as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DBType {
    /// PostgreSQL; binds use numbered `$n` placeholders.
    PostgreSQL,
    /// MySQL; binds use positional `?` placeholders.
    MySQL,
    /// SQLite; binds use positional `?` placeholders.
    SQLite,
}
70impl DBType {
71 pub fn new(db_name: &str) -> Result<Self, Error> {
84 match db_name {
85 "PostgreSQL" => Ok(Self::PostgreSQL),
86 "MySQL" => Ok(Self::MySQL),
87 "SQLite" => Ok(Self::SQLite),
88 _ => Err(Error::Protocol(format!("unsupport db `{db_name}`"))),
89 }
90 }
91}
92
93impl DatabaseDialect for DBType {
94 fn backend_name(&self) -> &str {
95 match self {
96 Self::PostgreSQL => "PostgreSQL",
97 Self::MySQL => "MySQL",
98 Self::SQLite => "SQLite",
99 }
100 }
101 fn get_encode_placeholder_fn(&self) -> Option<fn(usize, &mut String)> {
110 match self {
111 Self::PostgreSQL => Some(|i: usize, s: &mut String| s.push_str(&format!("${i}"))),
112 Self::MySQL | Self::SQLite => Some(|_: usize, s: &mut String| s.push('?')),
113 }
114 }
115 fn write_count_sql(&self, sql: &mut String) {
120 match self {
121 Self::PostgreSQL | DBType::MySQL | DBType::SQLite => {
122 pg_mysql_sqlite_count_sql(sql);
123 }
124 }
125 }
126 fn write_page_sql<'c, 'q, DB>(
137 &self,
138 sql: &mut String,
139 page_size: i64,
140 page_no: i64,
141
142 arg: &mut DB::Arguments<'q>,
143 ) -> Result<(), Error>
144 where
145 DB: Database,
146 i64: Encode<'q, DB> + Type<DB>,
147 {
148 let f = self.get_encode_placeholder_fn();
149 match self {
150 Self::PostgreSQL | DBType::MySQL | DBType::SQLite => {
151 pg_mysql_sqlite_page_sql(sql, page_size, page_no, f, arg)?;
152 Ok(())
153 }
154 }
155 }
156}
/// Rewrites `sql` in place as `select count(1) from (<sql>) t`.
///
/// The original query becomes a sub-query aliased `t`, so the result is a single row
/// holding the number of rows the original query would return.
fn pg_mysql_sqlite_count_sql(sql: &mut String) {
    const HEAD: &str = "select count(1) from (";
    const TAIL: &str = ") t";

    sql.reserve(HEAD.len() + TAIL.len());
    sql.insert_str(0, HEAD);
    sql.push_str(TAIL);
}
160fn pg_mysql_sqlite_page_sql<'c, 'q, DB>(
161 sql: &mut String,
162 mut page_size: i64,
163 mut page_no: i64,
164 f: Option<fn(usize, &mut String)>,
165 arg: &mut DB::Arguments<'q>,
166) -> Result<(), Error>
167where
168 DB: Database,
169 i64: Encode<'q, DB> + Type<DB>,
170{
171 if page_size < 1 {
172 page_size = 1
173 }
174 if page_no < 1 {
175 page_no = 1
176 }
177 let offset = (page_no - 1) * page_size;
178 if let Some(f) = f {
179 sql.push_str(" limit ");
180 arg.add(page_size).map_err(Error::Encode)?;
181 f(arg.len(), sql);
182 sql.push_str(" offset ");
183 arg.add(offset).map_err(Error::Encode)?;
184 f(arg.len(), sql);
185 } else {
186 sql.push_str(" limit ");
187 arg.add(page_size).map_err(Error::Encode)?;
188 arg.format_placeholder(sql)
189 .map_err(|e| Error::Encode(Box::new(e)))?;
190
191 sql.push_str(" offset ");
192 arg.add(offset).map_err(Error::Encode)?;
193 arg.format_placeholder(sql)
194 .map_err(|e| Error::Encode(Box::new(e)))?;
195 }
196
197 Ok(())
198}
199
200#[allow(clippy::type_complexity)]
202pub trait BackendDB<'c, DB: Database, C: Executor<'c, Database = DB> + 'c>: Send {
203 fn backend_db(self) -> BoxFuture<'c, Result<(DBType, AdapterExecutor<'c, DB, C>), Error>>;
204}
205impl<'c, DB, C, C1> BackendDB<'c, DB, C> for C
206where
207 DB: Database,
208 C: Executor<'c, Database = DB> + 'c + Deref<Target = C1>,
209 C1: Any,
210 for<'c1> &'c1 mut DB::Connection: Executor<'c1, Database = DB>,
211{
212 fn backend_db(self) -> BoxFuture<'c, Result<(DBType, AdapterExecutor<'c, DB, C>), Error>> {
213 backend_db(self).boxed()
214 }
215}
216#[derive(Debug)]
217pub struct AdapterExecutor<'c, DB: Database, C: Executor<'c, Database = DB>> {
218 executor: Either<C, PoolConnection<DB>>,
219 _m: PhantomData<&'c ()>,
220}
221impl<'c, DB, C> AdapterExecutor<'c, DB, C>
222where
223 DB: Database,
224 C: Executor<'c, Database = DB>,
225{
226 fn new(executor: Either<C, PoolConnection<DB>>) -> Self {
227 Self {
228 executor,
229 _m: PhantomData,
230 }
231 }
232}
233
/// Routes every `Executor` call to whichever executor `backend_db` chose:
/// - `Either::Left`: the caller's own executor;
/// - `Either::Right`: an owned connection acquired from an `AnyPool`.
impl<'c, DB, C> Executor<'c> for AdapterExecutor<'c, DB, C>
where
    DB: Database,
    C: Executor<'c, Database = DB>,
    // Lets the `Right` branches call executor methods on the pooled connection,
    // reached through auto-deref to `DB::Connection`.
    for<'c1> &'c1 mut DB::Connection: Executor<'c1, Database = DB> + 'c1,
{
    type Database = DB;

    /// Streams every query result and row produced by `query`.
    fn fetch_many<'e, 'q: 'e, E>(
        self,
        query: E,
    ) -> futures_core::stream::BoxStream<
        'e,
        Result<
            Either<<Self::Database as Database>::QueryResult, <Self::Database as Database>::Row>,
            Error,
        >,
    >
    where
        'c: 'e,
        E: 'q + Execute<'q, Self::Database>,
    {
        match self.executor {
            Either::Left(executor) => executor.fetch_many(query),
            // The owned pooled connection lives inside the returned stream. Items from
            // the connection's own stream are re-yielded one by one; the first error is
            // propagated via `?`.
            Either::Right(mut conn) => Box::pin(try_stream! {
                let mut s = conn.fetch_many(query);

                while let Some(v) = s.try_next().await? {
                    r#yield!(v);
                }

                Ok(())
            }),
        }
    }

    /// Fetches at most one row produced by `query`.
    fn fetch_optional<'e, 'q: 'e, E>(
        self,
        query: E,
    ) -> futures_core::future::BoxFuture<'e, Result<Option<<Self::Database as Database>::Row>, Error>>
    where
        'c: 'e,
        E: 'q + Execute<'q, Self::Database>,
    {
        match self.executor {
            Either::Left(executor) => executor.fetch_optional(query),
            // The async block owns `conn`, so the borrow it hands out outlives this call.
            Either::Right(mut conn) => Box::pin(async move { conn.fetch_optional(query).await }),
        }
    }

    /// Prepares `sql`, passing `parameters` as type hints to the underlying executor.
    fn prepare_with<'e, 'q: 'e>(
        self,
        sql: &'q str,
        parameters: &'e [<Self::Database as Database>::TypeInfo],
    ) -> futures_core::future::BoxFuture<
        'e,
        Result<<Self::Database as Database>::Statement<'q>, Error>,
    >
    where
        'c: 'e,
    {
        match self.executor {
            Either::Left(executor) => executor.prepare_with(sql, parameters),
            Either::Right(mut conn) => {
                Box::pin(async move { conn.prepare_with(sql, parameters).await })
            }
        }
    }

    /// Returns the underlying executor's `Describe` information for `sql`.
    fn describe<'e, 'q: 'e>(
        self,
        sql: &'q str,
    ) -> futures_core::future::BoxFuture<'e, Result<Describe<Self::Database>, Error>>
    where
        'c: 'e,
    {
        match self.executor {
            Either::Left(executor) => executor.describe(sql),
            Either::Right(mut conn) => Box::pin(async move { conn.describe(sql).await }),
        }
    }
}
318pub async fn backend_db<'c, DB, C, C1>(c: C) -> Result<(DBType, AdapterExecutor<'c, DB, C>), Error>
319where
320 DB: Database,
321 C: Executor<'c, Database = DB> + 'c + Deref<Target = C1>,
322 C1: Any + 'static,
323{
324 if DB::NAME != sqlx_core::any::Any::NAME {
325 return Ok((
326 DBType::new(DB::NAME)?,
327 AdapterExecutor::new(Either::Left(c)),
328 ));
329 }
330
331 let any_ref = c.deref() as &dyn Any;
332 if let Some(conn) = any_ref.downcast_ref::<AnyConnection>() {
334 return Ok((
335 DBType::new(conn.backend_name())?,
336 AdapterExecutor::new(Either::Left(c)),
337 ));
338 }
339
340 if let Some(pool) = any_ref.downcast_ref::<AnyPool>() {
342 let conn = pool.acquire().await?;
343
344 let db_type = DBType::new(conn.backend_name())?;
345 let db_con: Box<dyn Any> = Box::new(conn);
346 let return_con = db_con
347 .downcast::<PoolConnection<DB>>()
348 .map_err(|_| Error::Protocol(format!("unsupport db `{}`", DB::NAME)))?;
349
350 return Ok((db_type, AdapterExecutor::new(Either::Right(*return_con))));
351 }
352 Err(Error::Protocol(format!("unsupport db `{}`", DB::NAME)))
353}