// sqlx_askama_template/v3/db_adapter.rs

use std::{any::Any, fmt::Write as _, marker::PhantomData, ops::Deref};

use futures_util::TryStreamExt;
use sqlx_core::{
    Either, Error,
    any::{AnyConnection, AnyPool},
    arguments::Arguments,
    database::Database,
    describe::Describe,
    encode::Encode,
    executor::{Execute, Executor},
    pool::PoolConnection,
    try_stream,
    types::Type,
};
16/// Abstracts SQL dialect differences across database systems
17///
18/// Provides a unified interface for handling database-specific SQL syntax variations,
19/// particularly for parameter binding, count queries, and pagination.
20pub trait DatabaseDialect {
21    fn backend_name(&self) -> &str;
22    /// Gets placeholder generation function for parameter binding
23    ///
24    /// Database-specific placeholder formats:
25    /// - PostgreSQL: $1, $2...
26    /// - MySQL/SQLite: ?
27    ///
28    /// # Returns
29    /// Option<fn(usize, &mut String)> placeholder generation function
30    fn get_encode_placeholder_fn(&self) -> Option<fn(usize, &mut String)>;
31    /// Wraps SQL in count query
32    ///
33    /// # Arguments
34    /// * `sql` - Original SQL to modify
35    fn write_count_sql(&self, sql: &mut String);
36    /// Generates pagination SQL clause
37    ///
38    /// # Arguments
39    /// * `sql` - Original SQL statement to modify
40    /// * `page_size` - Items per page
41    /// * `page_no` - Page number (auto-corrected to >=1)
42    /// * `arg` - SQL arguments container
43    ///
44    /// # Note
45    /// Automatically handles invalid page numbers
46    fn write_page_sql<'c, 'q, DB>(
47        &self,
48        sql: &mut String,
49        page_size: i64,
50        page_no: i64,
51        arg: &mut DB::Arguments<'q>,
52    ) -> Result<(), Error>
53    where
54        DB: Database,
55        i64: Encode<'q, DB> + Type<DB>;
56}
57
58/// Database type enumeration supporting major database systems
59#[derive(Debug, PartialEq)]
60pub enum DBType {
61    /// PostgreSQL database
62    PostgreSQL,
63    /// MySQL database
64    MySQL,
65    /// SQLite database
66    SQLite,
67}
68impl DBType {
69    /// Creates a DBType instance from database name
70    ///
71    /// # Arguments
72    /// * `db_name` - Database identifier ("PostgreSQL"|"MySQL"|"SQLite")
73    ///
74    /// # Errors
75    /// Returns Error::Protocol for unsupported database types
76    ///
77    /// # Example
78    /// ```
79    /// let db_type = DBType::new("PostgreSQL")?;
80    /// ```
81    pub fn new(db_name: &str) -> Result<Self, Error> {
82        match db_name {
83            "PostgreSQL" => Ok(Self::PostgreSQL),
84            "MySQL" => Ok(Self::MySQL),
85            "SQLite" => Ok(Self::SQLite),
86            _ => Err(Error::Protocol(format!("unsupport db `{}`", db_name))),
87        }
88    }
89}
90
91impl DatabaseDialect for DBType {
92    fn backend_name(&self) -> &str {
93        match self {
94            Self::PostgreSQL => "PostgreSQL",
95            Self::MySQL => "MySQL",
96            Self::SQLite => "SQLite",
97        }
98    }
99    /// Gets placeholder generation function for parameter binding
100    ///
101    /// Database-specific placeholder formats:
102    /// - PostgreSQL: $1, $2...
103    /// - MySQL/SQLite: ?
104    ///
105    /// # Returns
106    /// Option<fn(usize, &mut String)> placeholder generation function
107    fn get_encode_placeholder_fn(&self) -> Option<fn(usize, &mut String)> {
108        match self {
109            Self::PostgreSQL => Some(|i: usize, s: &mut String| s.push_str(&format!("${}", i))),
110            Self::MySQL | Self::SQLite => Some(|_: usize, s: &mut String| s.push('?')),
111        }
112    }
113    /// Wraps SQL in count query
114    ///
115    /// # Arguments
116    /// * `sql` - Original SQL to modify
117    fn write_count_sql(&self, sql: &mut String) {
118        match self {
119            Self::PostgreSQL | DBType::MySQL | DBType::SQLite => {
120                pg_mysql_sqlite_count_sql(sql);
121            }
122        }
123    }
124    /// Generates pagination SQL clause
125    ///
126    /// # Arguments
127    /// * `sql` - Original SQL statement to modify
128    /// * `page_size` - Items per page
129    /// * `page_no` - Page number (auto-corrected to >=1)
130    /// * `arg` - SQL arguments container
131    ///
132    /// # Note
133    /// Automatically handles invalid page numbers
134    fn write_page_sql<'c, 'q, DB>(
135        &self,
136        sql: &mut String,
137        page_size: i64,
138        page_no: i64,
139
140        arg: &mut DB::Arguments<'q>,
141    ) -> Result<(), Error>
142    where
143        DB: Database,
144        i64: Encode<'q, DB> + Type<DB>,
145    {
146        let f = self.get_encode_placeholder_fn();
147        match self {
148            Self::PostgreSQL | DBType::MySQL | DBType::SQLite => {
149                pg_mysql_sqlite_page_sql(sql, page_size, page_no, f, arg)?;
150                Ok(())
151            }
152        }
153    }
154}
/// Replaces `sql` with a query counting the rows it yields, by wrapping it as the
/// derived table `t`. Shared by PostgreSQL, MySQL and SQLite.
fn pg_mysql_sqlite_count_sql(sql: &mut String) {
    let inner = std::mem::take(sql);
    *sql = format!("select count(1) from ({inner}) t");
}
158fn pg_mysql_sqlite_page_sql<'c, 'q, DB>(
159    sql: &mut String,
160    mut page_size: i64,
161    mut page_no: i64,
162    f: Option<fn(usize, &mut String)>,
163    arg: &mut DB::Arguments<'q>,
164) -> Result<(), Error>
165where
166    DB: Database,
167    i64: Encode<'q, DB> + Type<DB>,
168{
169    if page_size < 1 {
170        page_size = 1
171    }
172    if page_no < 1 {
173        page_no = 1
174    }
175    let offset = (page_no - 1) * page_size;
176    if let Some(f) = f {
177        sql.push_str(" limit ");
178        arg.add(page_size).map_err(Error::Encode)?;
179        f(arg.len(), sql);
180        sql.push_str(" offset ");
181        arg.add(offset).map_err(Error::Encode)?;
182        f(arg.len(), sql);
183    } else {
184        sql.push_str(" limit ");
185        arg.add(page_size).map_err(Error::Encode)?;
186        arg.format_placeholder(sql)
187            .map_err(|e| Error::Encode(Box::new(e)))?;
188
189        sql.push_str(" offset ");
190        arg.add(offset).map_err(Error::Encode)?;
191        arg.format_placeholder(sql)
192            .map_err(|e| Error::Encode(Box::new(e)))?;
193    }
194
195    Ok(())
196}
197
198/// Trait for database connections/pools that can detect their backend type
199///
200/// # Type Parameters
201/// - `'c`: Connection lifetime
202/// - `DB`: Database type implementing [`sqlx::Database`]
203///
204/// # Required Implementations
205/// Automatically implemented for types that:
206/// - Implement [`Executor`] for database operations
207/// - Implement [`Deref`] to an [`Any`] type
208///
209/// # Provided Methods
210/// [`backend_db`]: Default implementation using the module-level function
211pub trait BackendDB<'c, DB>
212where
213    DB: Database,
214{
215    fn backend_db(
216        self,
217    ) -> impl std::future::Future<
218        Output = Result<(impl DatabaseDialect, impl Executor<'c, Database = DB> + 'c), Error>,
219    > + Send;
220}
221impl<'c, DB, C, C1> BackendDB<'c, DB> for C
222where
223    DB: Database,
224    C: Executor<'c, Database = DB> + 'c + Deref<Target = C1>,
225    C1: Any,
226    for<'c1> &'c1 mut DB::Connection: Executor<'c1, Database = DB>,
227{
228    async fn backend_db(
229        self,
230    ) -> Result<(impl DatabaseDialect, impl Executor<'c, Database = DB> + 'c), Error> {
231        backend_db(self).await
232    }
233}
234#[derive(Debug)]
235pub struct AdapterExecutor<'c, DB: Database, C: Executor<'c, Database = DB>> {
236    executor: Either<C, PoolConnection<DB>>,
237    _m: PhantomData<&'c ()>,
238}
239impl<'c, DB, C> AdapterExecutor<'c, DB, C>
240where
241    DB: Database,
242    C: Executor<'c, Database = DB>,
243{
244    fn new(executor: Either<C, PoolConnection<DB>>) -> Self {
245        Self {
246            executor,
247            _m: PhantomData,
248        }
249    }
250}
251
252impl<'c, DB, C> Executor<'c> for AdapterExecutor<'c, DB, C>
253where
254    DB: Database,
255    C: Executor<'c, Database = DB>,
256    for<'c1> &'c1 mut DB::Connection: Executor<'c1, Database = DB>,
257{
258    type Database = DB;
259
260    fn fetch_many<'e, 'q: 'e, E>(
261        self,
262        query: E,
263    ) -> futures_core::stream::BoxStream<
264        'e,
265        Result<
266            Either<<Self::Database as Database>::QueryResult, <Self::Database as Database>::Row>,
267            Error,
268        >,
269    >
270    where
271        'c: 'e,
272        E: 'q + Execute<'q, Self::Database>,
273    {
274        match self.executor {
275            Either::Left(executor) => executor.fetch_many(query),
276            Either::Right(mut conn) => Box::pin(try_stream! {
277
278
279                let mut s = conn.fetch_many(query);
280
281                while let Some(v) = s.try_next().await? {
282                    r#yield!(v);
283                }
284
285                Ok(())
286            }),
287        }
288    }
289
290    fn fetch_optional<'e, 'q: 'e, E>(
291        self,
292        query: E,
293    ) -> futures_core::future::BoxFuture<'e, Result<Option<<Self::Database as Database>::Row>, Error>>
294    where
295        'c: 'e,
296        E: 'q + Execute<'q, Self::Database>,
297    {
298        match self.executor {
299            Either::Left(executor) => executor.fetch_optional(query),
300            Either::Right(mut conn) => Box::pin(async move { conn.fetch_optional(query).await }),
301        }
302    }
303
304    fn prepare_with<'e, 'q: 'e>(
305        self,
306        sql: &'q str,
307        parameters: &'e [<Self::Database as Database>::TypeInfo],
308    ) -> futures_core::future::BoxFuture<
309        'e,
310        Result<<Self::Database as Database>::Statement<'q>, Error>,
311    >
312    where
313        'c: 'e,
314    {
315        match self.executor {
316            Either::Left(executor) => executor.prepare_with(sql, parameters),
317            Either::Right(mut conn) => {
318                Box::pin(async move { conn.prepare_with(sql, parameters).await })
319            }
320        }
321    }
322
323    fn describe<'e, 'q: 'e>(
324        self,
325        sql: &'q str,
326    ) -> futures_core::future::BoxFuture<'e, Result<Describe<Self::Database>, Error>>
327    where
328        'c: 'e,
329    {
330        match self.executor {
331            Either::Left(executor) => executor.describe(sql),
332            Either::Right(mut conn) => Box::pin(async move { conn.describe(sql).await }),
333        }
334    }
335}
336pub async fn backend_db<'c, DB, C, C1>(c: C) -> Result<(DBType, AdapterExecutor<'c, DB, C>), Error>
337where
338    DB: Database,
339    C: Executor<'c, Database = DB> + 'c + Deref<Target = C1>,
340    C1: Any + 'static,
341{
342    if DB::NAME != sqlx_core::any::Any::NAME {
343        return Ok((
344            DBType::new(DB::NAME)?,
345            AdapterExecutor::new(Either::Left(c)),
346        ));
347    }
348
349    let any_ref = c.deref() as &dyn Any;
350    // 处理 AnyConnection
351    if let Some(conn) = any_ref.downcast_ref::<AnyConnection>() {
352        return Ok((
353            DBType::new(conn.backend_name())?,
354            AdapterExecutor::new(Either::Left(c)),
355        ));
356    }
357
358    // 处理 AnyPool
359    if let Some(pool) = any_ref.downcast_ref::<AnyPool>() {
360        let conn = pool.acquire().await?;
361
362        let db_type = DBType::new(conn.backend_name())?;
363        let db_con: Box<dyn Any> = Box::new(conn);
364        let return_con = db_con
365            .downcast::<PoolConnection<DB>>()
366            .map_err(|_| Error::Protocol(format!("unsupport db `{}`", DB::NAME)))?;
367
368        return Ok((db_type, AdapterExecutor::new(Either::Right(*return_con))));
369    }
370    Err(Error::Protocol(format!("unsupport db `{}`", DB::NAME)))
371}