// sqlx_askama_template/db_adapter.rs

1use std::{any::Any, marker::PhantomData, ops::Deref};
2
3use futures_util::TryStreamExt;
4use sqlx_core::{
5    Either, Error,
6    any::{AnyConnection, AnyPool},
7    arguments::Arguments,
8    database::Database,
9    describe::Describe,
10    encode::Encode,
11    executor::{Execute, Executor},
12    pool::PoolConnection,
13    sql_str::SqlStr,
14    try_stream,
15    types::Type,
16};
17/// Abstracts SQL dialect differences across database systems
18///
19/// Provides a unified interface for handling database-specific SQL syntax variations,
20/// particularly for parameter binding, count queries, and pagination.
21pub trait DatabaseDialect: Send + Sync {
22    /// Returns the name of the database backend in use (e.g. PostgreSQL, MySQL, SQLite, etc.)
23    fn backend_name(&self) -> &str;
24    /// Gets placeholder generation function for parameter binding
25    ///
26    /// Database-specific placeholder formats:
27    /// - PostgreSQL: $1, $2...
28    /// - MySQL/SQLite: ?
29    ///
30    /// # Returns
31    /// Option<fn(usize, &mut String)> placeholder generation function
32    fn placeholder_fn(&self) -> Option<fn(usize, &mut String)>;
33    /// Wraps SQL in count query
34    ///
35    /// # Arguments
36    /// * `sql` - Original SQL to modify
37    fn write_count_sql(&self, sql: &mut String);
38    /// Generates pagination SQL clause
39    ///
40    /// # Arguments
41    /// * `sql` - Original SQL statement to modify
42    /// * `pagination_size` - Items per pagination
43    /// * `pagination_no` - Pagination number (auto-corrected to >=1)
44    /// * `arg` - SQL arguments container
45    ///
46    /// # Note
47    /// Automatically handles invalid pagination numbers
48    fn write_pagination_sql<'q, DB>(
49        &self,
50        sql: &mut String,
51        pagination_size: i64,
52        pagination_no: i64,
53        arg: &mut DB::Arguments,
54    ) -> Result<(), Error>
55    where
56        DB: Database,
57        i64: Encode<'q, DB> + Type<DB>;
58}
59
/// Database type enumeration supporting major database systems.
///
/// This is a small fieldless enum, so it is `Copy` and can be used as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DBType {
    /// PostgreSQL database
    PostgreSQL,
    /// MySQL database
    MySQL,
    /// SQLite database
    SQLite,
}
70impl DBType {
71    /// Creates a DBType instance from database name
72    ///
73    /// # Arguments
74    /// * `db_name` - Database identifier ("PostgreSQL"|"MySQL"|"SQLite")
75    ///
76    /// # Errors
77    /// Returns Error::Protocol for unsupported database types
78    ///
79    /// # Example
80    /// ```
81    /// let db_type = DBType::new("PostgreSQL")?;
82    /// ```
83    pub fn new(db_name: &str) -> Result<Self, Error> {
84        match db_name {
85            "PostgreSQL" => Ok(Self::PostgreSQL),
86            "MySQL" => Ok(Self::MySQL),
87            "SQLite" => Ok(Self::SQLite),
88            _ => Err(Error::Protocol(format!("unsupported db `{db_name}`"))),
89        }
90    }
91}
92
93impl DatabaseDialect for DBType {
94    fn backend_name(&self) -> &str {
95        match self {
96            Self::PostgreSQL => "PostgreSQL",
97            Self::MySQL => "MySQL",
98            Self::SQLite => "SQLite",
99        }
100    }
101    /// Gets placeholder generation function for parameter binding
102    ///
103    /// Database-specific placeholder formats:
104    /// - PostgreSQL: $1, $2...
105    /// - MySQL/SQLite: ?
106    ///
107    /// # Returns
108    /// Option<fn(usize, &mut String)> placeholder generation function
109    fn placeholder_fn(&self) -> Option<fn(usize, &mut String)> {
110        match self {
111            Self::PostgreSQL => Some(|i: usize, s: &mut String| s.push_str(&format!("${i}"))),
112            Self::MySQL | Self::SQLite => Some(|_: usize, s: &mut String| s.push('?')),
113        }
114    }
115    /// Wraps SQL in count query
116    ///
117    /// # Arguments
118    /// * `sql` - Original SQL to modify
119    fn write_count_sql(&self, sql: &mut String) {
120        match self {
121            Self::PostgreSQL | DBType::MySQL | DBType::SQLite => {
122                pg_mysql_sqlite_count_sql(sql);
123            }
124        }
125    }
126    /// Generates pagination SQL clause
127    ///
128    /// # Arguments
129    /// * `sql` - Original SQL statement to modify
130    /// * `pagination_size` - Items per pagination
131    /// * `pagination_no` - Pagination number (auto-corrected to >=1)
132    /// * `arg` - SQL arguments container
133    ///
134    /// # Note
135    /// Automatically handles invalid pagination numbers
136    fn write_pagination_sql<'q, DB>(
137        &self,
138        sql: &mut String,
139        pagination_size: i64,
140        pagination_no: i64,
141
142        arg: &mut DB::Arguments,
143    ) -> Result<(), Error>
144    where
145        DB: Database,
146        i64: Encode<'q, DB> + Type<DB>,
147    {
148        let f = self.placeholder_fn();
149        match self {
150            Self::PostgreSQL | DBType::MySQL | DBType::SQLite => {
151                pg_mysql_sqlite_pagination_sql(sql, pagination_size, pagination_no, f, arg)?;
152                Ok(())
153            }
154        }
155    }
156}
157
/// Turns `sql` into `select count(1) from (<sql>) t`, the row-count query shared by
/// PostgreSQL, MySQL and SQLite. The buffer is modified in place.
fn pg_mysql_sqlite_count_sql(sql: &mut String) {
    const PREFIX: &str = "select count(1) from (";
    const SUFFIX: &str = ") t";
    sql.reserve(PREFIX.len() + SUFFIX.len());
    sql.insert_str(0, PREFIX);
    sql.push_str(SUFFIX);
}
162/// Generates pagination SQL clause for PostgreSQL/MySQL/SQLite databases
163fn pg_mysql_sqlite_pagination_sql<'q, DB>(
164    sql: &mut String,
165    mut pagination_size: i64,
166    mut pagination_no: i64,
167    f: Option<fn(usize, &mut String)>,
168    arg: &mut DB::Arguments,
169) -> Result<(), Error>
170where
171    DB: Database,
172    i64: Encode<'q, DB> + Type<DB>,
173{
174    if pagination_size < 1 {
175        pagination_size = 1
176    }
177    if pagination_no < 1 {
178        pagination_no = 1
179    }
180    let offset = (pagination_no - 1) * pagination_size;
181    if let Some(f) = f {
182        sql.push_str(" limit ");
183        arg.add(pagination_size).map_err(Error::Encode)?;
184        f(arg.len(), sql);
185        sql.push_str(" offset ");
186        arg.add(offset).map_err(Error::Encode)?;
187        f(arg.len(), sql);
188    } else {
189        sql.push_str(" limit ");
190        arg.add(pagination_size).map_err(Error::Encode)?;
191        arg.format_placeholder(sql)
192            .map_err(|e| Error::Encode(Box::new(e)))?;
193
194        sql.push_str(" offset ");
195        arg.add(offset).map_err(Error::Encode)?;
196        arg.format_placeholder(sql)
197            .map_err(|e| Error::Encode(Box::new(e)))?;
198    }
199
200    Ok(())
201}
202
203/// Trait for database connections/pools that can detect their backend type
204///
205/// # Type Parameters
206/// - `'c`: Connection lifetime
207/// - `DB`: Database type implementing [`sqlx::Database`]
208///
209/// # Required Implementations
210/// Automatically implemented for types that:
211/// - Implement [`Executor`] for database operations
212/// - Implement [`Deref`] to an [`Any`] type
213///
214/// # Provided Methods
215/// [`backend_db`]: Default implementation using the module-level function
216pub trait BackendDB<'c, DB>: Send
217where
218    DB: Database,
219{
220    type Executor: Executor<'c, Database = DB> + 'c;
221    type DatabaseDialect: DatabaseDialect;
222    fn backend_db(
223        self,
224    ) -> impl std::future::Future<Output = Result<(Self::DatabaseDialect, Self::Executor), Error>> + Send;
225}
226impl<'c, DB, C, C1> BackendDB<'c, DB> for C
227where
228    DB: Database,
229    C: Executor<'c, Database = DB> + 'c + Deref<Target = C1>,
230    C1: Any,
231    for<'c1> &'c1 mut DB::Connection: Executor<'c1, Database = DB>,
232{
233    type DatabaseDialect = DBType;
234    type Executor = AdapterExecutor<'c, DB, C>;
235    async fn backend_db(self) -> Result<(Self::DatabaseDialect, Self::Executor), Error> {
236        detect_backend_db(self).await
237    }
238}
239
240#[derive(Debug)]
241pub struct AdapterExecutor<'c, DB: Database, C: Executor<'c, Database = DB>> {
242    executor: Either<C, PoolConnection<DB>>,
243    _m: PhantomData<&'c ()>,
244}
245impl<'c, DB, C> AdapterExecutor<'c, DB, C>
246where
247    DB: Database,
248    C: Executor<'c, Database = DB>,
249{
250    fn new(executor: Either<C, PoolConnection<DB>>) -> Self {
251        Self {
252            executor,
253            _m: PhantomData,
254        }
255    }
256}
257
258impl<'c, DB, C> Executor<'c> for AdapterExecutor<'c, DB, C>
259where
260    DB: Database,
261    C: Executor<'c, Database = DB>,
262    for<'c1> &'c1 mut DB::Connection: Executor<'c1, Database = DB>,
263{
264    type Database = DB;
265
266    fn fetch_many<'e, 'q: 'e, E>(
267        self,
268        query: E,
269    ) -> futures_core::stream::BoxStream<
270        'e,
271        Result<
272            Either<<Self::Database as Database>::QueryResult, <Self::Database as Database>::Row>,
273            Error,
274        >,
275    >
276    where
277        'c: 'e,
278        E: 'q + Execute<'q, Self::Database>,
279    {
280        match self.executor {
281            Either::Left(executor) => executor.fetch_many(query),
282            Either::Right(mut conn) => Box::pin(try_stream! {
283
284
285                let mut s = conn.fetch_many(query);
286
287                while let Some(v) = s.try_next().await? {
288                    r#yield!(v);
289                }
290
291                Ok(())
292            }),
293        }
294    }
295
296    fn fetch_optional<'e, 'q: 'e, E>(
297        self,
298        query: E,
299    ) -> futures_core::future::BoxFuture<'e, Result<Option<<Self::Database as Database>::Row>, Error>>
300    where
301        'c: 'e,
302        E: 'q + Execute<'q, Self::Database>,
303    {
304        match self.executor {
305            Either::Left(executor) => executor.fetch_optional(query),
306            Either::Right(mut conn) => Box::pin(async move { conn.fetch_optional(query).await }),
307        }
308    }
309
310    fn prepare_with<'e>(
311        self,
312        sql: SqlStr,
313        parameters: &'e [<Self::Database as Database>::TypeInfo],
314    ) -> futures_core::future::BoxFuture<'e, Result<<Self::Database as Database>::Statement, Error>>
315    where
316        'c: 'e,
317    {
318        match self.executor {
319            Either::Left(executor) => executor.prepare_with(sql, parameters),
320            Either::Right(mut conn) => {
321                Box::pin(async move { conn.prepare_with(sql, parameters).await })
322            }
323        }
324    }
325
326    fn describe<'e>(
327        self,
328        sql: SqlStr,
329    ) -> futures_core::future::BoxFuture<'e, Result<Describe<Self::Database>, Error>>
330    where
331        'c: 'e,
332    {
333        match self.executor {
334            Either::Left(executor) => executor.describe(sql),
335            Either::Right(mut conn) => Box::pin(async move { conn.describe(sql).await }),
336        }
337    }
338}
339/// Detect the real database type from the executor.
340/// params
341///  - c: The executor.
342///
343/// returns
344///  - (DBType, AdapterExecutor<'c, DB, C>): The database type and the adapter executor.
345pub async fn detect_backend_db<'c, DB, C, C1>(
346    c: C,
347) -> Result<(DBType, AdapterExecutor<'c, DB, C>), Error>
348where
349    DB: Database,
350    C: Executor<'c, Database = DB> + 'c + Deref<Target = C1>,
351    C1: Any,
352{
353    if DB::NAME != sqlx_core::any::Any::NAME {
354        return Ok((
355            DBType::new(DB::NAME)?,
356            AdapterExecutor::new(Either::Left(c)),
357        ));
358    }
359
360    let any_ref = c.deref() as &dyn Any;
361    //处理 AnyConnection
362    if let Some(conn) = any_ref.downcast_ref::<AnyConnection>() {
363        return Ok((
364            DBType::new(conn.backend_name())?,
365            AdapterExecutor::new(Either::Left(c)),
366        ));
367    }
368
369    //处理 AnyPool
370    if let Some(pool) = any_ref.downcast_ref::<AnyPool>() {
371        let conn = pool.acquire().await?;
372
373        let db_type = DBType::new(conn.backend_name())?;
374        let db_con: Box<dyn Any> = Box::new(conn);
375        let return_con = db_con
376            .downcast::<PoolConnection<DB>>()
377            .map_err(|_| Error::Protocol(format!("unsupported db `{}`", DB::NAME)))?;
378
379        return Ok((db_type, AdapterExecutor::new(Either::Right(*return_con))));
380    }
381    Err(Error::Protocol(format!("unsupported db `{}`", DB::NAME)))
382}