Skip to main content

sqlx_askama_template/v3/
db_adapter.rs

1use std::{any::Any, marker::PhantomData, ops::Deref};
2
3use futures_core::future::BoxFuture;
4use futures_util::{FutureExt, TryStreamExt};
5use sqlx_core::{
6    Either, Error,
7    any::{AnyConnection, AnyPool},
8    arguments::Arguments,
9    database::Database,
10    describe::Describe,
11    encode::Encode,
12    executor::{Execute, Executor},
13    pool::PoolConnection,
14    try_stream,
15    types::Type,
16};
17/// Abstracts SQL dialect differences across database systems
18///
19/// Provides a unified interface for handling database-specific SQL syntax variations,
20/// particularly for parameter binding, count queries, and pagination.
21pub trait DatabaseDialect: Send {
22    /// Returns the name of the database backend in use (e.g. PostgreSQL, MySQL, SQLite, etc.)
23    fn backend_name(&self) -> &str;
24    /// Gets placeholder generation function for parameter binding
25    ///
26    /// Database-specific placeholder formats:
27    /// - PostgreSQL: $1, $2...
28    /// - MySQL/SQLite: ?
29    ///
30    /// # Returns
31    /// Option<fn(usize, &mut String)> placeholder generation function
32    fn get_encode_placeholder_fn(&self) -> Option<fn(usize, &mut String)>;
33    /// Wraps SQL in count query
34    ///
35    /// # Arguments
36    /// * `sql` - Original SQL to modify
37    fn write_count_sql(&self, sql: &mut String);
38    /// Generates pagination SQL clause
39    ///
40    /// # Arguments
41    /// * `sql` - Original SQL statement to modify
42    /// * `page_size` - Items per page
43    /// * `page_no` - Page number (auto-corrected to >=1)
44    /// * `arg` - SQL arguments container
45    ///
46    /// # Note
47    /// Automatically handles invalid page numbers
48    fn write_page_sql<'c, 'q, DB>(
49        &self,
50        sql: &mut String,
51        page_size: i64,
52        page_no: i64,
53        arg: &mut DB::Arguments<'q>,
54    ) -> Result<(), Error>
55    where
56        DB: Database,
57        i64: Encode<'q, DB> + Type<DB>;
58}
59
/// Database backends supported by the SQL dialect adapter.
///
/// This is a plain fieldless enum, so it is cheap to copy, compare and hash.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DBType {
    /// PostgreSQL database
    PostgreSQL,
    /// MySQL database
    MySQL,
    /// SQLite database
    SQLite,
}
70impl DBType {
71    /// Creates a DBType instance from database name
72    ///
73    /// # Arguments
74    /// * `db_name` - Database identifier ("PostgreSQL"|"MySQL"|"SQLite")
75    ///
76    /// # Errors
77    /// Returns Error::Protocol for unsupported database types
78    ///
79    /// # Example
80    /// ```
81    /// let db_type = DBType::new("PostgreSQL")?;
82    /// ```
83    pub fn new(db_name: &str) -> Result<Self, Error> {
84        match db_name {
85            "PostgreSQL" => Ok(Self::PostgreSQL),
86            "MySQL" => Ok(Self::MySQL),
87            "SQLite" => Ok(Self::SQLite),
88            _ => Err(Error::Protocol(format!("unsupport db `{db_name}`"))),
89        }
90    }
91}
92
93impl DatabaseDialect for DBType {
94    fn backend_name(&self) -> &str {
95        match self {
96            Self::PostgreSQL => "PostgreSQL",
97            Self::MySQL => "MySQL",
98            Self::SQLite => "SQLite",
99        }
100    }
101    /// Gets placeholder generation function for parameter binding
102    ///
103    /// Database-specific placeholder formats:
104    /// - PostgreSQL: $1, $2...
105    /// - MySQL/SQLite: ?
106    ///
107    /// # Returns
108    /// Option<fn(usize, &mut String)> placeholder generation function
109    fn get_encode_placeholder_fn(&self) -> Option<fn(usize, &mut String)> {
110        match self {
111            Self::PostgreSQL => Some(|i: usize, s: &mut String| s.push_str(&format!("${i}"))),
112            Self::MySQL | Self::SQLite => Some(|_: usize, s: &mut String| s.push('?')),
113        }
114    }
115    /// Wraps SQL in count query
116    ///
117    /// # Arguments
118    /// * `sql` - Original SQL to modify
119    fn write_count_sql(&self, sql: &mut String) {
120        match self {
121            Self::PostgreSQL | DBType::MySQL | DBType::SQLite => {
122                pg_mysql_sqlite_count_sql(sql);
123            }
124        }
125    }
126    /// Generates pagination SQL clause
127    ///
128    /// # Arguments
129    /// * `sql` - Original SQL statement to modify
130    /// * `page_size` - Items per page
131    /// * `page_no` - Page number (auto-corrected to >=1)
132    /// * `arg` - SQL arguments container
133    ///
134    /// # Note
135    /// Automatically handles invalid page numbers
136    fn write_page_sql<'c, 'q, DB>(
137        &self,
138        sql: &mut String,
139        page_size: i64,
140        page_no: i64,
141
142        arg: &mut DB::Arguments<'q>,
143    ) -> Result<(), Error>
144    where
145        DB: Database,
146        i64: Encode<'q, DB> + Type<DB>,
147    {
148        let f = self.get_encode_placeholder_fn();
149        match self {
150            Self::PostgreSQL | DBType::MySQL | DBType::SQLite => {
151                pg_mysql_sqlite_page_sql(sql, page_size, page_no, f, arg)?;
152                Ok(())
153            }
154        }
155    }
156}
/// Wraps `sql` in place as a derived table and counts its rows:
/// `select count(1) from (<sql>) t`.
fn pg_mysql_sqlite_count_sql(sql: &mut String) {
    sql.insert_str(0, "select count(1) from (");
    sql.push_str(") t");
}
160fn pg_mysql_sqlite_page_sql<'c, 'q, DB>(
161    sql: &mut String,
162    mut page_size: i64,
163    mut page_no: i64,
164    f: Option<fn(usize, &mut String)>,
165    arg: &mut DB::Arguments<'q>,
166) -> Result<(), Error>
167where
168    DB: Database,
169    i64: Encode<'q, DB> + Type<DB>,
170{
171    if page_size < 1 {
172        page_size = 1
173    }
174    if page_no < 1 {
175        page_no = 1
176    }
177    let offset = (page_no - 1) * page_size;
178    if let Some(f) = f {
179        sql.push_str(" limit ");
180        arg.add(page_size).map_err(Error::Encode)?;
181        f(arg.len(), sql);
182        sql.push_str(" offset ");
183        arg.add(offset).map_err(Error::Encode)?;
184        f(arg.len(), sql);
185    } else {
186        sql.push_str(" limit ");
187        arg.add(page_size).map_err(Error::Encode)?;
188        arg.format_placeholder(sql)
189            .map_err(|e| Error::Encode(Box::new(e)))?;
190
191        sql.push_str(" offset ");
192        arg.add(offset).map_err(Error::Encode)?;
193        arg.format_placeholder(sql)
194            .map_err(|e| Error::Encode(Box::new(e)))?;
195    }
196
197    Ok(())
198}
199
200/// Trait for database connections/pools that can detect their backend type
201#[allow(clippy::type_complexity)]
202pub trait BackendDB<'c, DB: Database, C: Executor<'c, Database = DB> + 'c>: Send {
203    fn backend_db(self) -> BoxFuture<'c, Result<(DBType, AdapterExecutor<'c, DB, C>), Error>>;
204}
205impl<'c, DB, C, C1> BackendDB<'c, DB, C> for C
206where
207    DB: Database,
208    C: Executor<'c, Database = DB> + 'c + Deref<Target = C1>,
209    C1: Any,
210    for<'c1> &'c1 mut DB::Connection: Executor<'c1, Database = DB>,
211{
212    fn backend_db(self) -> BoxFuture<'c, Result<(DBType, AdapterExecutor<'c, DB, C>), Error>> {
213        backend_db(self).boxed()
214    }
215}
216#[derive(Debug)]
217pub struct AdapterExecutor<'c, DB: Database, C: Executor<'c, Database = DB>> {
218    executor: Either<C, PoolConnection<DB>>,
219    _m: PhantomData<&'c ()>,
220}
221impl<'c, DB, C> AdapterExecutor<'c, DB, C>
222where
223    DB: Database,
224    C: Executor<'c, Database = DB>,
225{
226    fn new(executor: Either<C, PoolConnection<DB>>) -> Self {
227        Self {
228            executor,
229            _m: PhantomData,
230        }
231    }
232}
233
/// Forwards each [`Executor`] operation to whichever executor is wrapped.
///
/// `Left` delegates directly to the caller's executor. `Right` owns a
/// `PoolConnection<DB>`: each operation moves it into the returned future/stream
/// and issues the same call on it there, so the connection lives as long as that
/// future/stream does.
impl<'c, DB, C> Executor<'c> for AdapterExecutor<'c, DB, C>
where
    DB: Database,
    C: Executor<'c, Database = DB>,
    // Lets the owned pool connection act as an executor in the `Right` arms
    // (presumably via `PoolConnection`'s deref to `DB::Connection` — confirm).
    for<'c1> &'c1 mut DB::Connection: Executor<'c1, Database = DB> + 'c1,
{
    type Database = DB;

    /// Delegates `fetch_many`; for an owned connection, re-yields every item of the
    /// inner stream from a stream that owns the connection.
    fn fetch_many<'e, 'q: 'e, E>(
        self,
        query: E,
    ) -> futures_core::stream::BoxStream<
        'e,
        Result<
            Either<<Self::Database as Database>::QueryResult, <Self::Database as Database>::Row>,
            Error,
        >,
    >
    where
        'c: 'e,
        E: 'q + Execute<'q, Self::Database>,
    {
        match self.executor {
            Either::Left(executor) => executor.fetch_many(query),
            // `conn` is moved into the stream so the inner stream can borrow it for
            // the outer stream's whole lifetime; the first error ends the stream.
            Either::Right(mut conn) => Box::pin(try_stream! {


                let mut s = conn.fetch_many(query);

                while let Some(v) = s.try_next().await? {
                    r#yield!(v);
                }

                Ok(())
            }),
        }
    }

    /// Delegates `fetch_optional`, moving an owned connection into the returned future.
    fn fetch_optional<'e, 'q: 'e, E>(
        self,
        query: E,
    ) -> futures_core::future::BoxFuture<'e, Result<Option<<Self::Database as Database>::Row>, Error>>
    where
        'c: 'e,
        E: 'q + Execute<'q, Self::Database>,
    {
        match self.executor {
            Either::Left(executor) => executor.fetch_optional(query),
            Either::Right(mut conn) => Box::pin(async move { conn.fetch_optional(query).await }),
        }
    }

    /// Delegates `prepare_with`, moving an owned connection into the returned future.
    fn prepare_with<'e, 'q: 'e>(
        self,
        sql: &'q str,
        parameters: &'e [<Self::Database as Database>::TypeInfo],
    ) -> futures_core::future::BoxFuture<
        'e,
        Result<<Self::Database as Database>::Statement<'q>, Error>,
    >
    where
        'c: 'e,
    {
        match self.executor {
            Either::Left(executor) => executor.prepare_with(sql, parameters),
            Either::Right(mut conn) => {
                Box::pin(async move { conn.prepare_with(sql, parameters).await })
            }
        }
    }

    /// Delegates `describe`, moving an owned connection into the returned future.
    fn describe<'e, 'q: 'e>(
        self,
        sql: &'q str,
    ) -> futures_core::future::BoxFuture<'e, Result<Describe<Self::Database>, Error>>
    where
        'c: 'e,
    {
        match self.executor {
            Either::Left(executor) => executor.describe(sql),
            Either::Right(mut conn) => Box::pin(async move { conn.describe(sql).await }),
        }
    }
}
318pub async fn backend_db<'c, DB, C, C1>(c: C) -> Result<(DBType, AdapterExecutor<'c, DB, C>), Error>
319where
320    DB: Database,
321    C: Executor<'c, Database = DB> + 'c + Deref<Target = C1>,
322    C1: Any + 'static,
323{
324    if DB::NAME != sqlx_core::any::Any::NAME {
325        return Ok((
326            DBType::new(DB::NAME)?,
327            AdapterExecutor::new(Either::Left(c)),
328        ));
329    }
330
331    let any_ref = c.deref() as &dyn Any;
332    // 处理 AnyConnection
333    if let Some(conn) = any_ref.downcast_ref::<AnyConnection>() {
334        return Ok((
335            DBType::new(conn.backend_name())?,
336            AdapterExecutor::new(Either::Left(c)),
337        ));
338    }
339
340    // 处理 AnyPool
341    if let Some(pool) = any_ref.downcast_ref::<AnyPool>() {
342        let conn = pool.acquire().await?;
343
344        let db_type = DBType::new(conn.backend_name())?;
345        let db_con: Box<dyn Any> = Box::new(conn);
346        let return_con = db_con
347            .downcast::<PoolConnection<DB>>()
348            .map_err(|_| Error::Protocol(format!("unsupport db `{}`", DB::NAME)))?;
349
350        return Ok((db_type, AdapterExecutor::new(Either::Right(*return_con))));
351    }
352    Err(Error::Protocol(format!("unsupport db `{}`", DB::NAME)))
353}