sqlx_askama_template/v3/
db_adapter.rs

1use std::{any::Any, marker::PhantomData, ops::Deref};
2
3use futures_util::TryStreamExt;
4use sqlx_core::{
5    Either, Error,
6    any::{AnyConnection, AnyPool},
7    arguments::Arguments,
8    database::Database,
9    describe::Describe,
10    encode::Encode,
11    executor::{Execute, Executor},
12    pool::PoolConnection,
13    try_stream,
14    types::Type,
15};
/// Abstracts SQL dialect differences across database systems.
///
/// Provides a unified interface for the database-specific parts of SQL generation:
/// parameter placeholders, count queries, and pagination clauses.
pub trait DatabaseDialect {
    /// Returns the name of the database backend in use (e.g. PostgreSQL, MySQL, SQLite).
    fn backend_name(&self) -> &str;
    /// Returns the function used to write a bind-parameter placeholder.
    ///
    /// Database-specific placeholder formats:
    /// - PostgreSQL: `$1`, `$2`, ...
    /// - MySQL/SQLite: `?`
    ///
    /// # Returns
    /// `Option<fn(usize, &mut String)>`: the function receives a parameter index and
    /// the SQL buffer to append the placeholder to. The pagination helper in this
    /// file passes the argument count after the value was added (so the index is
    /// 1-based) and, when `None` is returned, falls back to
    /// `Arguments::format_placeholder`.
    fn get_encode_placeholder_fn(&self) -> Option<fn(usize, &mut String)>;
    /// Rewrites `sql` in place into a query that counts the rows of the original query.
    ///
    /// # Arguments
    /// * `sql` - Original SQL; replaced by the count query
    fn write_count_sql(&self, sql: &mut String);
    /// Appends a pagination clause to `sql` and binds its values into `arg`.
    ///
    /// # Arguments
    /// * `sql` - SQL statement to append the pagination clause to
    /// * `page_size` - Items per page
    /// * `page_no` - Page number (auto-corrected to >= 1)
    /// * `arg` - SQL arguments container that receives the bound pagination values
    ///
    /// # Errors
    /// Returns an [`Error`] when a pagination value cannot be bound or its
    /// placeholder cannot be written.
    ///
    /// # Note
    /// Invalid page numbers are corrected rather than rejected.
    fn write_page_sql<'c, 'q, DB>(
        &self,
        sql: &mut String,
        page_size: i64,
        page_no: i64,
        arg: &mut DB::Arguments<'q>,
    ) -> Result<(), Error>
    where
        DB: Database,
        i64: Encode<'q, DB> + Type<DB>;
}
58
/// Database type enumeration supporting major database systems.
///
/// The enum is a plain, fieldless tag, so it is `Copy` and can be compared,
/// hashed, and passed by value freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DBType {
    /// PostgreSQL database
    PostgreSQL,
    /// MySQL database
    MySQL,
    /// SQLite database
    SQLite,
}
69impl DBType {
70    /// Creates a DBType instance from database name
71    ///
72    /// # Arguments
73    /// * `db_name` - Database identifier ("PostgreSQL"|"MySQL"|"SQLite")
74    ///
75    /// # Errors
76    /// Returns Error::Protocol for unsupported database types
77    ///
78    /// # Example
79    /// ```
80    /// let db_type = DBType::new("PostgreSQL")?;
81    /// ```
82    pub fn new(db_name: &str) -> Result<Self, Error> {
83        match db_name {
84            "PostgreSQL" => Ok(Self::PostgreSQL),
85            "MySQL" => Ok(Self::MySQL),
86            "SQLite" => Ok(Self::SQLite),
87            _ => Err(Error::Protocol(format!("unsupport db `{db_name}`"))),
88        }
89    }
90}
91
92impl DatabaseDialect for DBType {
93    fn backend_name(&self) -> &str {
94        match self {
95            Self::PostgreSQL => "PostgreSQL",
96            Self::MySQL => "MySQL",
97            Self::SQLite => "SQLite",
98        }
99    }
100    /// Gets placeholder generation function for parameter binding
101    ///
102    /// Database-specific placeholder formats:
103    /// - PostgreSQL: $1, $2...
104    /// - MySQL/SQLite: ?
105    ///
106    /// # Returns
107    /// Option<fn(usize, &mut String)> placeholder generation function
108    fn get_encode_placeholder_fn(&self) -> Option<fn(usize, &mut String)> {
109        match self {
110            Self::PostgreSQL => Some(|i: usize, s: &mut String| s.push_str(&format!("${i}"))),
111            Self::MySQL | Self::SQLite => Some(|_: usize, s: &mut String| s.push('?')),
112        }
113    }
114    /// Wraps SQL in count query
115    ///
116    /// # Arguments
117    /// * `sql` - Original SQL to modify
118    fn write_count_sql(&self, sql: &mut String) {
119        match self {
120            Self::PostgreSQL | DBType::MySQL | DBType::SQLite => {
121                pg_mysql_sqlite_count_sql(sql);
122            }
123        }
124    }
125    /// Generates pagination SQL clause
126    ///
127    /// # Arguments
128    /// * `sql` - Original SQL statement to modify
129    /// * `page_size` - Items per page
130    /// * `page_no` - Page number (auto-corrected to >=1)
131    /// * `arg` - SQL arguments container
132    ///
133    /// # Note
134    /// Automatically handles invalid page numbers
135    fn write_page_sql<'c, 'q, DB>(
136        &self,
137        sql: &mut String,
138        page_size: i64,
139        page_no: i64,
140
141        arg: &mut DB::Arguments<'q>,
142    ) -> Result<(), Error>
143    where
144        DB: Database,
145        i64: Encode<'q, DB> + Type<DB>,
146    {
147        let f = self.get_encode_placeholder_fn();
148        match self {
149            Self::PostgreSQL | DBType::MySQL | DBType::SQLite => {
150                pg_mysql_sqlite_page_sql(sql, page_size, page_no, f, arg)?;
151                Ok(())
152            }
153        }
154    }
155}
/// Wraps `sql` in place as `select count(1) from (<sql>) t`.
///
/// The original statement becomes a sub-select aliased `t`.
fn pg_mysql_sqlite_count_sql(sql: &mut String) {
    sql.insert_str(0, "select count(1) from (");
    sql.push_str(") t");
}
159fn pg_mysql_sqlite_page_sql<'c, 'q, DB>(
160    sql: &mut String,
161    mut page_size: i64,
162    mut page_no: i64,
163    f: Option<fn(usize, &mut String)>,
164    arg: &mut DB::Arguments<'q>,
165) -> Result<(), Error>
166where
167    DB: Database,
168    i64: Encode<'q, DB> + Type<DB>,
169{
170    if page_size < 1 {
171        page_size = 1
172    }
173    if page_no < 1 {
174        page_no = 1
175    }
176    let offset = (page_no - 1) * page_size;
177    if let Some(f) = f {
178        sql.push_str(" limit ");
179        arg.add(page_size).map_err(Error::Encode)?;
180        f(arg.len(), sql);
181        sql.push_str(" offset ");
182        arg.add(offset).map_err(Error::Encode)?;
183        f(arg.len(), sql);
184    } else {
185        sql.push_str(" limit ");
186        arg.add(page_size).map_err(Error::Encode)?;
187        arg.format_placeholder(sql)
188            .map_err(|e| Error::Encode(Box::new(e)))?;
189
190        sql.push_str(" offset ");
191        arg.add(offset).map_err(Error::Encode)?;
192        arg.format_placeholder(sql)
193            .map_err(|e| Error::Encode(Box::new(e)))?;
194    }
195
196    Ok(())
197}
198
/// Trait for database connections/pools that can detect their backend type.
///
/// # Type Parameters
/// - `'c`: Lifetime bound of the resulting executor
/// - `DB`: Database type implementing [`Database`]
///
/// # Implementations
/// A blanket implementation in this file covers types that:
/// - Implement [`Executor`] for `DB`
/// - Implement [`Deref`] to an [`Any`] type
///
/// That implementation delegates to the module-level `backend_db` function.
///
/// # Required Methods
/// `backend_db` consumes `self` and resolves to the detected dialect together
/// with an executor to run queries on. The returned future must be `Send`.
pub trait BackendDB<'c, DB>
where
    DB: Database,
{
    /// Executor handed back to the caller after detection.
    type Executor: Executor<'c, Database = DB> + 'c;
    /// Dialect type describing the detected backend.
    type DatabaseDialect: DatabaseDialect;
    /// Detects the backend and returns `(dialect, executor)`, or an [`Error`]
    /// when detection fails.
    fn backend_db(
        self,
    ) -> impl std::future::Future<Output = Result<(Self::DatabaseDialect, Self::Executor), Error>> + Send;
}
222impl<'c, DB, C, C1> BackendDB<'c, DB> for C
223where
224    DB: Database,
225    C: Executor<'c, Database = DB> + 'c + Deref<Target = C1>,
226    C1: Any,
227    for<'c1> &'c1 mut DB::Connection: Executor<'c1, Database = DB>,
228{
229    type DatabaseDialect = DBType;
230    type Executor = AdapterExecutor<'c, DB, C>;
231    async fn backend_db(self) -> Result<(Self::DatabaseDialect, Self::Executor), Error> {
232        backend_db(self).await
233    }
234}
/// Executor wrapper produced by backend detection.
///
/// Holds either the caller's original executor or a connection acquired from a
/// pool, and forwards every [`Executor`] call to whichever one it holds.
#[derive(Debug)]
pub struct AdapterExecutor<'c, DB: Database, C: Executor<'c, Database = DB>> {
    /// `Left`: the executor supplied by the caller.
    /// `Right`: a pooled connection acquired during detection.
    executor: Either<C, PoolConnection<DB>>,
    /// Marker that uses the `'c` lifetime parameter without storing a reference.
    _m: PhantomData<&'c ()>,
}
240impl<'c, DB, C> AdapterExecutor<'c, DB, C>
241where
242    DB: Database,
243    C: Executor<'c, Database = DB>,
244{
245    fn new(executor: Either<C, PoolConnection<DB>>) -> Self {
246        Self {
247            executor,
248            _m: PhantomData,
249        }
250    }
251}
252
/// Forwards every [`Executor`] method to the wrapped executor.
///
/// For `Either::Left` the call goes straight to the caller's executor. For
/// `Either::Right` the pooled connection is bound by value and the call is made
/// on it inside a boxed stream/future.
impl<'c, DB, C> Executor<'c> for AdapterExecutor<'c, DB, C>
where
    DB: Database,
    C: Executor<'c, Database = DB>,
    for<'c1> &'c1 mut DB::Connection: Executor<'c1, Database = DB>,
{
    type Database = DB;

    /// Runs `query` and streams back its results (query results and rows).
    fn fetch_many<'e, 'q: 'e, E>(
        self,
        query: E,
    ) -> futures_core::stream::BoxStream<
        'e,
        Result<
            Either<<Self::Database as Database>::QueryResult, <Self::Database as Database>::Row>,
            Error,
        >,
    >
    where
        'c: 'e,
        E: 'q + Execute<'q, Self::Database>,
    {
        match self.executor {
            Either::Left(executor) => executor.fetch_many(query),
            Either::Right(mut conn) => Box::pin(try_stream! {
                // `conn` is bound by value above and used inside this stream.
                // NOTE(review): presumably this keeps the pooled connection
                // checked out until the stream is dropped; confirm against the
                // `try_stream!` expansion.
                let mut s = conn.fetch_many(query);

                // Re-yield every item of the inner stream; `?` ends the stream
                // on the first error.
                while let Some(v) = s.try_next().await? {
                    r#yield!(v);
                }

                Ok(())
            }),
        }
    }

    /// Runs `query` and resolves to at most one row.
    fn fetch_optional<'e, 'q: 'e, E>(
        self,
        query: E,
    ) -> futures_core::future::BoxFuture<'e, Result<Option<<Self::Database as Database>::Row>, Error>>
    where
        'c: 'e,
        E: 'q + Execute<'q, Self::Database>,
    {
        match self.executor {
            Either::Left(executor) => executor.fetch_optional(query),
            // The connection is moved into the future that performs the call.
            Either::Right(mut conn) => Box::pin(async move { conn.fetch_optional(query).await }),
        }
    }

    /// Prepares `sql` with the given parameter type hints.
    fn prepare_with<'e, 'q: 'e>(
        self,
        sql: &'q str,
        parameters: &'e [<Self::Database as Database>::TypeInfo],
    ) -> futures_core::future::BoxFuture<
        'e,
        Result<<Self::Database as Database>::Statement<'q>, Error>,
    >
    where
        'c: 'e,
    {
        match self.executor {
            Either::Left(executor) => executor.prepare_with(sql, parameters),
            // The connection is moved into the future that performs the call.
            Either::Right(mut conn) => {
                Box::pin(async move { conn.prepare_with(sql, parameters).await })
            }
        }
    }

    /// Describes `sql`, resolving to a [`Describe`] for the statement.
    fn describe<'e, 'q: 'e>(
        self,
        sql: &'q str,
    ) -> futures_core::future::BoxFuture<'e, Result<Describe<Self::Database>, Error>>
    where
        'c: 'e,
    {
        match self.executor {
            Either::Left(executor) => executor.describe(sql),
            // The connection is moved into the future that performs the call.
            Either::Right(mut conn) => Box::pin(async move { conn.describe(sql).await }),
        }
    }
}
337pub async fn backend_db<'c, DB, C, C1>(c: C) -> Result<(DBType, AdapterExecutor<'c, DB, C>), Error>
338where
339    DB: Database,
340    C: Executor<'c, Database = DB> + 'c + Deref<Target = C1>,
341    C1: Any + 'static,
342{
343    if DB::NAME != sqlx_core::any::Any::NAME {
344        return Ok((
345            DBType::new(DB::NAME)?,
346            AdapterExecutor::new(Either::Left(c)),
347        ));
348    }
349
350    let any_ref = c.deref() as &dyn Any;
351    // 处理 AnyConnection
352    if let Some(conn) = any_ref.downcast_ref::<AnyConnection>() {
353        return Ok((
354            DBType::new(conn.backend_name())?,
355            AdapterExecutor::new(Either::Left(c)),
356        ));
357    }
358
359    // 处理 AnyPool
360    if let Some(pool) = any_ref.downcast_ref::<AnyPool>() {
361        let conn = pool.acquire().await?;
362
363        let db_type = DBType::new(conn.backend_name())?;
364        let db_con: Box<dyn Any> = Box::new(conn);
365        let return_con = db_con
366            .downcast::<PoolConnection<DB>>()
367            .map_err(|_| Error::Protocol(format!("unsupport db `{}`", DB::NAME)))?;
368
369        return Ok((db_type, AdapterExecutor::new(Either::Right(*return_con))));
370    }
371    Err(Error::Protocol(format!("unsupport db `{}`", DB::NAME)))
372}