1use std::{any::Any, marker::PhantomData, ops::Deref};
2
3use futures_util::TryStreamExt;
4use sqlx_core::{
5 Either, Error,
6 any::{AnyConnection, AnyPool},
7 arguments::Arguments,
8 database::Database,
9 describe::Describe,
10 encode::Encode,
11 executor::{Execute, Executor},
12 pool::PoolConnection,
13 try_stream,
14 types::Type,
15};
16pub trait DatabaseDialect: Send {
21 fn backend_name(&self) -> &str;
23 fn get_encode_placeholder_fn(&self) -> Option<fn(usize, &mut String)>;
32 fn write_count_sql(&self, sql: &mut String);
37 fn write_page_sql<'c, 'q, DB>(
48 &self,
49 sql: &mut String,
50 page_size: i64,
51 page_no: i64,
52 arg: &mut DB::Arguments<'q>,
53 ) -> Result<(), Error>
54 where
55 DB: Database,
56 i64: Encode<'q, DB> + Type<DB>;
57}
58
/// Database backends supported by the dialect helpers in this module.
///
/// Values are parsed from backend names with `DBType::new`, which accepts exactly
/// `"PostgreSQL"`, `"MySQL"` and `"SQLite"`.
///
/// The enum is fieldless, so it derives `Clone`/`Copy` (cheap to pass by value) and
/// `Eq`/`Hash` (usable as a map/set key).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DBType {
    /// PostgreSQL; placeholders are numbered (`$1`, `$2`, ...).
    PostgreSQL,
    /// MySQL; placeholders are positional `?`.
    MySQL,
    /// SQLite; placeholders are positional `?`.
    SQLite,
}
69impl DBType {
70 pub fn new(db_name: &str) -> Result<Self, Error> {
83 match db_name {
84 "PostgreSQL" => Ok(Self::PostgreSQL),
85 "MySQL" => Ok(Self::MySQL),
86 "SQLite" => Ok(Self::SQLite),
87 _ => Err(Error::Protocol(format!("unsupport db `{db_name}`"))),
88 }
89 }
90}
91
92impl DatabaseDialect for DBType {
93 fn backend_name(&self) -> &str {
94 match self {
95 Self::PostgreSQL => "PostgreSQL",
96 Self::MySQL => "MySQL",
97 Self::SQLite => "SQLite",
98 }
99 }
100 fn get_encode_placeholder_fn(&self) -> Option<fn(usize, &mut String)> {
109 match self {
110 Self::PostgreSQL => Some(|i: usize, s: &mut String| s.push_str(&format!("${i}"))),
111 Self::MySQL | Self::SQLite => Some(|_: usize, s: &mut String| s.push('?')),
112 }
113 }
114 fn write_count_sql(&self, sql: &mut String) {
119 match self {
120 Self::PostgreSQL | DBType::MySQL | DBType::SQLite => {
121 pg_mysql_sqlite_count_sql(sql);
122 }
123 }
124 }
125 fn write_page_sql<'c, 'q, DB>(
136 &self,
137 sql: &mut String,
138 page_size: i64,
139 page_no: i64,
140
141 arg: &mut DB::Arguments<'q>,
142 ) -> Result<(), Error>
143 where
144 DB: Database,
145 i64: Encode<'q, DB> + Type<DB>,
146 {
147 let f = self.get_encode_placeholder_fn();
148 match self {
149 Self::PostgreSQL | DBType::MySQL | DBType::SQLite => {
150 pg_mysql_sqlite_page_sql(sql, page_size, page_no, f, arg)?;
151 Ok(())
152 }
153 }
154 }
155}
/// Rewrites `sql` in place into `select count(1) from (<sql>) t`.
///
/// The original query becomes a derived table aliased `t`, so its own filters and
/// joins are preserved while only the row count is returned.
fn pg_mysql_sqlite_count_sql(sql: &mut String) {
    sql.insert_str(0, "select count(1) from (");
    sql.push_str(") t");
}
159fn pg_mysql_sqlite_page_sql<'c, 'q, DB>(
160 sql: &mut String,
161 mut page_size: i64,
162 mut page_no: i64,
163 f: Option<fn(usize, &mut String)>,
164 arg: &mut DB::Arguments<'q>,
165) -> Result<(), Error>
166where
167 DB: Database,
168 i64: Encode<'q, DB> + Type<DB>,
169{
170 if page_size < 1 {
171 page_size = 1
172 }
173 if page_no < 1 {
174 page_no = 1
175 }
176 let offset = (page_no - 1) * page_size;
177 if let Some(f) = f {
178 sql.push_str(" limit ");
179 arg.add(page_size).map_err(Error::Encode)?;
180 f(arg.len(), sql);
181 sql.push_str(" offset ");
182 arg.add(offset).map_err(Error::Encode)?;
183 f(arg.len(), sql);
184 } else {
185 sql.push_str(" limit ");
186 arg.add(page_size).map_err(Error::Encode)?;
187 arg.format_placeholder(sql)
188 .map_err(|e| Error::Encode(Box::new(e)))?;
189
190 sql.push_str(" offset ");
191 arg.add(offset).map_err(Error::Encode)?;
192 arg.format_placeholder(sql)
193 .map_err(|e| Error::Encode(Box::new(e)))?;
194 }
195
196 Ok(())
197}
198
199pub trait BackendDB<'c, DB>: Send
201where
202 DB: Database,
203{
204 type Executor: Executor<'c, Database = DB> + 'c;
205 type DatabaseDialect: DatabaseDialect;
206 fn backend_db(
207 self,
208 ) -> impl std::future::Future<Output = Result<(Self::DatabaseDialect, Self::Executor), Error>> + Send;
209}
210impl<'c, DB, C, C1> BackendDB<'c, DB> for C
211where
212 DB: Database,
213 C: Executor<'c, Database = DB> + 'c + Deref<Target = C1>,
214 C1: Any,
215 for<'c1> &'c1 mut DB::Connection: Executor<'c1, Database = DB>,
216{
217 type DatabaseDialect = DBType;
218 type Executor = AdapterExecutor<'c, DB, C>;
219 async fn backend_db(self) -> Result<(Self::DatabaseDialect, Self::Executor), Error> {
220 backend_db(self).await
221 }
222}
223#[derive(Debug)]
224pub struct AdapterExecutor<'c, DB: Database, C: Executor<'c, Database = DB>> {
225 executor: Either<C, PoolConnection<DB>>,
226 _m: PhantomData<&'c ()>,
227}
228impl<'c, DB, C> AdapterExecutor<'c, DB, C>
229where
230 DB: Database,
231 C: Executor<'c, Database = DB>,
232{
233 fn new(executor: Either<C, PoolConnection<DB>>) -> Self {
234 Self {
235 executor,
236 _m: PhantomData,
237 }
238 }
239}
240
241impl<'c, DB, C> Executor<'c> for AdapterExecutor<'c, DB, C>
242where
243 DB: Database,
244 C: Executor<'c, Database = DB>,
245 for<'c1> &'c1 mut DB::Connection: Executor<'c1, Database = DB> + 'c1,
246{
247 type Database = DB;
248
249 fn fetch_many<'e, 'q: 'e, E>(
250 self,
251 query: E,
252 ) -> futures_core::stream::BoxStream<
253 'e,
254 Result<
255 Either<<Self::Database as Database>::QueryResult, <Self::Database as Database>::Row>,
256 Error,
257 >,
258 >
259 where
260 'c: 'e,
261 E: 'q + Execute<'q, Self::Database>,
262 {
263 match self.executor {
264 Either::Left(executor) => executor.fetch_many(query),
265 Either::Right(mut conn) => Box::pin(try_stream! {
266
267
268 let mut s = conn.fetch_many(query);
269
270 while let Some(v) = s.try_next().await? {
271 r#yield!(v);
272 }
273
274 Ok(())
275 }),
276 }
277 }
278
279 fn fetch_optional<'e, 'q: 'e, E>(
280 self,
281 query: E,
282 ) -> futures_core::future::BoxFuture<'e, Result<Option<<Self::Database as Database>::Row>, Error>>
283 where
284 'c: 'e,
285 E: 'q + Execute<'q, Self::Database>,
286 {
287 match self.executor {
288 Either::Left(executor) => executor.fetch_optional(query),
289 Either::Right(mut conn) => Box::pin(async move { conn.fetch_optional(query).await }),
290 }
291 }
292
293 fn prepare_with<'e, 'q: 'e>(
294 self,
295 sql: &'q str,
296 parameters: &'e [<Self::Database as Database>::TypeInfo],
297 ) -> futures_core::future::BoxFuture<
298 'e,
299 Result<<Self::Database as Database>::Statement<'q>, Error>,
300 >
301 where
302 'c: 'e,
303 {
304 match self.executor {
305 Either::Left(executor) => executor.prepare_with(sql, parameters),
306 Either::Right(mut conn) => {
307 Box::pin(async move { conn.prepare_with(sql, parameters).await })
308 }
309 }
310 }
311
312 fn describe<'e, 'q: 'e>(
313 self,
314 sql: &'q str,
315 ) -> futures_core::future::BoxFuture<'e, Result<Describe<Self::Database>, Error>>
316 where
317 'c: 'e,
318 {
319 match self.executor {
320 Either::Left(executor) => executor.describe(sql),
321 Either::Right(mut conn) => Box::pin(async move { conn.describe(sql).await }),
322 }
323 }
324}
325pub async fn backend_db<'c, DB, C, C1>(c: C) -> Result<(DBType, AdapterExecutor<'c, DB, C>), Error>
326where
327 DB: Database,
328 C: Executor<'c, Database = DB> + 'c + Deref<Target = C1>,
329 C1: Any + 'static,
330{
331 if DB::NAME != sqlx_core::any::Any::NAME {
332 return Ok((
333 DBType::new(DB::NAME)?,
334 AdapterExecutor::new(Either::Left(c)),
335 ));
336 }
337
338 let any_ref = c.deref() as &dyn Any;
339 if let Some(conn) = any_ref.downcast_ref::<AnyConnection>() {
341 return Ok((
342 DBType::new(conn.backend_name())?,
343 AdapterExecutor::new(Either::Left(c)),
344 ));
345 }
346
347 if let Some(pool) = any_ref.downcast_ref::<AnyPool>() {
349 let conn = pool.acquire().await?;
350
351 let db_type = DBType::new(conn.backend_name())?;
352 let db_con: Box<dyn Any> = Box::new(conn);
353 let return_con = db_con
354 .downcast::<PoolConnection<DB>>()
355 .map_err(|_| Error::Protocol(format!("unsupport db `{}`", DB::NAME)))?;
356
357 return Ok((db_type, AdapterExecutor::new(Either::Right(*return_con))));
358 }
359 Err(Error::Protocol(format!("unsupport db `{}`", DB::NAME)))
360}