// sqlx_askama_template/db_adapter.rs

use std::{any::Any, fmt::Write, marker::PhantomData, ops::Deref};

use futures_util::TryStreamExt;
use sqlx_core::{
    Either, Error,
    any::{AnyConnection, AnyPool},
    arguments::Arguments,
    database::Database,
    describe::Describe,
    encode::Encode,
    executor::{Execute, Executor},
    pool::PoolConnection,
    sql_str::SqlStr,
    try_stream,
    types::Type,
};
17/// Abstracts SQL dialect differences across database systems
18///
19/// Provides a unified interface for handling database-specific SQL syntax variations,
20/// particularly for parameter binding, count queries, and pagination.
21pub trait DatabaseDialect {
22    /// Returns the name of the database backend in use (e.g. PostgreSQL, MySQL, SQLite, etc.)
23    fn backend_name(&self) -> &str;
24    /// Gets placeholder generation function for parameter binding
25    ///
26    /// Database-specific placeholder formats:
27    /// - PostgreSQL: $1, $2...
28    /// - MySQL/SQLite: ?
29    ///
30    /// # Returns
31    /// Option<fn(usize, &mut String)> placeholder generation function
32    fn get_encode_placeholder_fn(&self) -> Option<fn(usize, &mut String)>;
33    /// Wraps SQL in count query
34    ///
35    /// # Arguments
36    /// * `sql` - Original SQL to modify
37    fn write_count_sql(&self, sql: &mut String);
38    /// Generates pagination SQL clause
39    ///
40    /// # Arguments
41    /// * `sql` - Original SQL statement to modify
42    /// * `page_size` - Items per page
43    /// * `page_no` - Page number (auto-corrected to >=1)
44    /// * `arg` - SQL arguments container
45    ///
46    /// # Note
47    /// Automatically handles invalid page numbers
48    fn write_page_sql<'c, 'q, DB>(
49        &self,
50        sql: &mut String,
51        page_size: i64,
52        page_no: i64,
53        arg: &mut DB::Arguments,
54    ) -> Result<(), Error>
55    where
56        DB: Database,
57        i64: Encode<'q, DB> + Type<DB>;
58}
59
/// Database backends supported by this module's SQL dialect handling.
///
/// The enum is fieldless, so it is cheap to copy, compare and hash (e.g. as a map key).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DBType {
    /// PostgreSQL database
    PostgreSQL,
    /// MySQL database
    MySQL,
    /// SQLite database
    SQLite,
}
70impl DBType {
71    /// Creates a DBType instance from database name
72    ///
73    /// # Arguments
74    /// * `db_name` - Database identifier ("PostgreSQL"|"MySQL"|"SQLite")
75    ///
76    /// # Errors
77    /// Returns Error::Protocol for unsupported database types
78    ///
79    /// # Example
80    /// ```
81    /// let db_type = DBType::new("PostgreSQL")?;
82    /// ```
83    pub fn new(db_name: &str) -> Result<Self, Error> {
84        match db_name {
85            "PostgreSQL" => Ok(Self::PostgreSQL),
86            "MySQL" => Ok(Self::MySQL),
87            "SQLite" => Ok(Self::SQLite),
88            _ => Err(Error::Protocol(format!("unsupport db `{db_name}`"))),
89        }
90    }
91}
92
93impl DatabaseDialect for DBType {
94    fn backend_name(&self) -> &str {
95        match self {
96            Self::PostgreSQL => "PostgreSQL",
97            Self::MySQL => "MySQL",
98            Self::SQLite => "SQLite",
99        }
100    }
101    /// Gets placeholder generation function for parameter binding
102    ///
103    /// Database-specific placeholder formats:
104    /// - PostgreSQL: $1, $2...
105    /// - MySQL/SQLite: ?
106    ///
107    /// # Returns
108    /// Option<fn(usize, &mut String)> placeholder generation function
109    fn get_encode_placeholder_fn(&self) -> Option<fn(usize, &mut String)> {
110        match self {
111            Self::PostgreSQL => Some(|i: usize, s: &mut String| s.push_str(&format!("${i}"))),
112            Self::MySQL | Self::SQLite => Some(|_: usize, s: &mut String| s.push('?')),
113        }
114    }
115    /// Wraps SQL in count query
116    ///
117    /// # Arguments
118    /// * `sql` - Original SQL to modify
119    fn write_count_sql(&self, sql: &mut String) {
120        match self {
121            Self::PostgreSQL | DBType::MySQL | DBType::SQLite => {
122                pg_mysql_sqlite_count_sql(sql);
123            }
124        }
125    }
126    /// Generates pagination SQL clause
127    ///
128    /// # Arguments
129    /// * `sql` - Original SQL statement to modify
130    /// * `page_size` - Items per page
131    /// * `page_no` - Page number (auto-corrected to >=1)
132    /// * `arg` - SQL arguments container
133    ///
134    /// # Note
135    /// Automatically handles invalid page numbers
136    fn write_page_sql<'c, 'q, DB>(
137        &self,
138        sql: &mut String,
139        page_size: i64,
140        page_no: i64,
141
142        arg: &mut DB::Arguments,
143    ) -> Result<(), Error>
144    where
145        DB: Database,
146        i64: Encode<'q, DB> + Type<DB>,
147    {
148        let f = self.get_encode_placeholder_fn();
149        match self {
150            Self::PostgreSQL | DBType::MySQL | DBType::SQLite => {
151                pg_mysql_sqlite_page_sql(sql, page_size, page_no, f, arg)?;
152                Ok(())
153            }
154        }
155    }
156}
/// Rewrites `sql` in place as `select count(1) from (<sql>) t`, a query returning the number
/// of rows the original statement would produce.
fn pg_mysql_sqlite_count_sql(sql: &mut String) {
    const HEAD: &str = "select count(1) from (";
    const TAIL: &str = ") t";
    sql.reserve(HEAD.len() + TAIL.len());
    sql.insert_str(0, HEAD);
    sql.push_str(TAIL);
}
160fn pg_mysql_sqlite_page_sql<'c, 'q, DB>(
161    sql: &mut String,
162    mut page_size: i64,
163    mut page_no: i64,
164    f: Option<fn(usize, &mut String)>,
165    arg: &mut DB::Arguments,
166) -> Result<(), Error>
167where
168    DB: Database,
169    i64: Encode<'q, DB> + Type<DB>,
170{
171    if page_size < 1 {
172        page_size = 1
173    }
174    if page_no < 1 {
175        page_no = 1
176    }
177    let offset = (page_no - 1) * page_size;
178    if let Some(f) = f {
179        sql.push_str(" limit ");
180        arg.add(page_size).map_err(Error::Encode)?;
181        f(arg.len(), sql);
182        sql.push_str(" offset ");
183        arg.add(offset).map_err(Error::Encode)?;
184        f(arg.len(), sql);
185    } else {
186        sql.push_str(" limit ");
187        arg.add(page_size).map_err(Error::Encode)?;
188        arg.format_placeholder(sql)
189            .map_err(|e| Error::Encode(Box::new(e)))?;
190
191        sql.push_str(" offset ");
192        arg.add(offset).map_err(Error::Encode)?;
193        arg.format_placeholder(sql)
194            .map_err(|e| Error::Encode(Box::new(e)))?;
195    }
196
197    Ok(())
198}
199
200/// Trait for database connections/pools that can detect their backend type
201///
202/// # Type Parameters
203/// - `'c`: Connection lifetime
204/// - `DB`: Database type implementing [`sqlx::Database`]
205///
206/// # Required Implementations
207/// Automatically implemented for types that:
208/// - Implement [`Executor`] for database operations
209/// - Implement [`Deref`] to an [`Any`] type
210///
211/// # Provided Methods
212/// [`backend_db`]: Default implementation using the module-level function
213pub trait BackendDB<'c, DB>
214where
215    DB: Database,
216{
217    type Executor: Executor<'c, Database = DB> + 'c;
218    type DatabaseDialect: DatabaseDialect;
219    fn backend_db(
220        self,
221    ) -> impl std::future::Future<Output = Result<(Self::DatabaseDialect, Self::Executor), Error>> + Send;
222}
223impl<'c, DB, C, C1> BackendDB<'c, DB> for C
224where
225    DB: Database,
226    C: Executor<'c, Database = DB> + 'c + Deref<Target = C1>,
227    C1: Any,
228    for<'c1> &'c1 mut DB::Connection: Executor<'c1, Database = DB>,
229{
230    type DatabaseDialect = DBType;
231    type Executor = AdapterExecutor<'c, DB, C>;
232    async fn backend_db(self) -> Result<(Self::DatabaseDialect, Self::Executor), Error> {
233        backend_db(self).await
234    }
235}
236#[derive(Debug)]
237pub struct AdapterExecutor<'c, DB: Database, C: Executor<'c, Database = DB>> {
238    executor: Either<C, PoolConnection<DB>>,
239    _m: PhantomData<&'c ()>,
240}
241impl<'c, DB, C> AdapterExecutor<'c, DB, C>
242where
243    DB: Database,
244    C: Executor<'c, Database = DB>,
245{
246    fn new(executor: Either<C, PoolConnection<DB>>) -> Self {
247        Self {
248            executor,
249            _m: PhantomData,
250        }
251    }
252}
253
254impl<'c, DB, C> Executor<'c> for AdapterExecutor<'c, DB, C>
255where
256    DB: Database,
257    C: Executor<'c, Database = DB>,
258    for<'c1> &'c1 mut DB::Connection: Executor<'c1, Database = DB>,
259{
260    type Database = DB;
261
262    fn fetch_many<'e, 'q: 'e, E>(
263        self,
264        query: E,
265    ) -> futures_core::stream::BoxStream<
266        'e,
267        Result<
268            Either<<Self::Database as Database>::QueryResult, <Self::Database as Database>::Row>,
269            Error,
270        >,
271    >
272    where
273        'c: 'e,
274        E: 'q + Execute<'q, Self::Database>,
275    {
276        match self.executor {
277            Either::Left(executor) => executor.fetch_many(query),
278            Either::Right(mut conn) => Box::pin(try_stream! {
279
280
281                let mut s = conn.fetch_many(query);
282
283                while let Some(v) = s.try_next().await? {
284                    r#yield!(v);
285                }
286
287                Ok(())
288            }),
289        }
290    }
291
292    fn fetch_optional<'e, 'q: 'e, E>(
293        self,
294        query: E,
295    ) -> futures_core::future::BoxFuture<'e, Result<Option<<Self::Database as Database>::Row>, Error>>
296    where
297        'c: 'e,
298        E: 'q + Execute<'q, Self::Database>,
299    {
300        match self.executor {
301            Either::Left(executor) => executor.fetch_optional(query),
302            Either::Right(mut conn) => Box::pin(async move { conn.fetch_optional(query).await }),
303        }
304    }
305
306    fn prepare_with<'e>(
307        self,
308        sql: SqlStr,
309        parameters: &'e [<Self::Database as Database>::TypeInfo],
310    ) -> futures_core::future::BoxFuture<'e, Result<<Self::Database as Database>::Statement, Error>>
311    where
312        'c: 'e,
313    {
314        match self.executor {
315            Either::Left(executor) => executor.prepare_with(sql, parameters),
316            Either::Right(mut conn) => {
317                Box::pin(async move { conn.prepare_with(sql, parameters).await })
318            }
319        }
320    }
321
322    fn describe<'e>(
323        self,
324        sql: SqlStr,
325    ) -> futures_core::future::BoxFuture<'e, Result<Describe<Self::Database>, Error>>
326    where
327        'c: 'e,
328    {
329        match self.executor {
330            Either::Left(executor) => executor.describe(sql),
331            Either::Right(mut conn) => Box::pin(async move { conn.describe(sql).await }),
332        }
333    }
334}
335pub async fn backend_db<'c, DB, C, C1>(c: C) -> Result<(DBType, AdapterExecutor<'c, DB, C>), Error>
336where
337    DB: Database,
338    C: Executor<'c, Database = DB> + 'c + Deref<Target = C1>,
339    C1: Any + 'static,
340{
341    if DB::NAME != sqlx_core::any::Any::NAME {
342        return Ok((
343            DBType::new(DB::NAME)?,
344            AdapterExecutor::new(Either::Left(c)),
345        ));
346    }
347
348    let any_ref = c.deref() as &dyn Any;
349    // 处理 AnyConnection
350    if let Some(conn) = any_ref.downcast_ref::<AnyConnection>() {
351        return Ok((
352            DBType::new(conn.backend_name())?,
353            AdapterExecutor::new(Either::Left(c)),
354        ));
355    }
356
357    // 处理 AnyPool
358    if let Some(pool) = any_ref.downcast_ref::<AnyPool>() {
359        let conn = pool.acquire().await?;
360
361        let db_type = DBType::new(conn.backend_name())?;
362        let db_con: Box<dyn Any> = Box::new(conn);
363        let return_con = db_con
364            .downcast::<PoolConnection<DB>>()
365            .map_err(|_| Error::Protocol(format!("unsupport db `{}`", DB::NAME)))?;
366
367        return Ok((db_type, AdapterExecutor::new(Either::Right(*return_con))));
368    }
369    Err(Error::Protocol(format!("unsupport db `{}`", DB::NAME)))
370}