sea_orm/driver/sqlx_sqlite.rs

1use futures_util::lock::Mutex;
2use log::LevelFilter;
3use sea_query::Values;
4use std::{future::Future, pin::Pin, sync::Arc};
5
6use sqlx::{
7    Connection, Executor, Sqlite, SqlitePool,
8    pool::PoolConnection,
9    sqlite::{SqliteConnectOptions, SqliteQueryResult, SqliteRow},
10};
11
12use sea_query_sqlx::SqlxValues;
13use tracing::{instrument, warn};
14
15use crate::{
16    AccessMode, ConnectOptions, DatabaseConnection, DatabaseConnectionType, DatabaseTransaction,
17    IsolationLevel, QueryStream, Statement, TransactionError, debug_print, error::*, executor::*,
18    sqlx_error_to_exec_err,
19};
20
21use super::sqlx_common::*;
22
/// Connector that opens SQLite databases through [sqlx::sqlite].
///
/// This is a stateless unit type; its associated functions do the work.
#[derive(Debug)]
pub struct SqlxSqliteConnector;
26
27/// Defines a sqlx SQLite pool
28#[derive(Clone)]
29pub struct SqlxSqlitePoolConnection {
30    pub(crate) pool: SqlitePool,
31    metric_callback: Option<crate::metric::Callback>,
32}
33
34impl std::fmt::Debug for SqlxSqlitePoolConnection {
35    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36        write!(f, "SqlxSqlitePoolConnection {{ pool: {:?} }}", self.pool)
37    }
38}
39
40impl From<SqlitePool> for SqlxSqlitePoolConnection {
41    fn from(pool: SqlitePool) -> Self {
42        SqlxSqlitePoolConnection {
43            pool,
44            metric_callback: None,
45        }
46    }
47}
48
49impl From<SqlitePool> for DatabaseConnection {
50    fn from(pool: SqlitePool) -> Self {
51        DatabaseConnectionType::SqlxSqlitePoolConnection(pool.into()).into()
52    }
53}
54
55impl SqlxSqliteConnector {
56    /// Check if the URI provided corresponds to `sqlite:` for a SQLite database
57    pub fn accepts(string: &str) -> bool {
58        string.starts_with("sqlite:") && string.parse::<SqliteConnectOptions>().is_ok()
59    }
60
61    /// Add configuration options for the SQLite database
62    #[instrument(level = "trace")]
63    pub async fn connect(options: ConnectOptions) -> Result<DatabaseConnection, DbErr> {
64        let mut options = options;
65        let mut opt = options
66            .url
67            .parse::<SqliteConnectOptions>()
68            .map_err(sqlx_error_to_conn_err)?;
69        if let Some(sqlcipher_key) = &options.sqlcipher_key {
70            opt = opt.pragma("key", sqlcipher_key.clone());
71        }
72        use sqlx::ConnectOptions;
73        if !options.sqlx_logging {
74            opt = opt.disable_statement_logging();
75        } else {
76            opt = opt.log_statements(options.sqlx_logging_level);
77            if options.sqlx_slow_statements_logging_level != LevelFilter::Off {
78                opt = opt.log_slow_statements(
79                    options.sqlx_slow_statements_logging_level,
80                    options.sqlx_slow_statements_logging_threshold,
81                );
82            }
83        }
84
85        if options.get_max_connections().is_none() {
86            options.max_connections(1);
87        }
88
89        let pool = if options.connect_lazy {
90            options.sqlx_pool_options().connect_lazy_with(opt)
91        } else {
92            options
93                .sqlx_pool_options()
94                .connect_with(opt)
95                .await
96                .map_err(sqlx_error_to_conn_err)?
97        };
98
99        let pool = SqlxSqlitePoolConnection {
100            pool,
101            metric_callback: None,
102        };
103
104        #[cfg(feature = "sqlite-use-returning-for-3_35")]
105        {
106            let version = get_version(&pool).await?;
107            ensure_returning_version(&version)?;
108        }
109
110        Ok(DatabaseConnectionType::SqlxSqlitePoolConnection(pool).into())
111    }
112}
113
114impl SqlxSqliteConnector {
115    /// Instantiate a sqlx pool connection to a [DatabaseConnection]
116    pub fn from_sqlx_sqlite_pool(pool: SqlitePool) -> DatabaseConnection {
117        DatabaseConnectionType::SqlxSqlitePoolConnection(SqlxSqlitePoolConnection {
118            pool,
119            metric_callback: None,
120        })
121        .into()
122    }
123}
124
125impl SqlxSqlitePoolConnection {
126    /// Execute a [Statement] on a SQLite backend
127    #[instrument(level = "trace")]
128    pub async fn execute(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
129        debug_print!("{}", stmt);
130
131        let query = sqlx_query(&stmt);
132        let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
133        crate::metric::metric!(self.metric_callback, &stmt, {
134            match query.execute(&mut *conn).await {
135                Ok(res) => Ok(res.into()),
136                Err(err) => Err(sqlx_error_to_exec_err(err)),
137            }
138        })
139    }
140
141    /// Execute an unprepared SQL statement on a SQLite backend
142    #[instrument(level = "trace")]
143    pub async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
144        debug_print!("{}", sql);
145
146        let conn = &mut self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
147        match conn.execute(sql).await {
148            Ok(res) => Ok(res.into()),
149            Err(err) => Err(sqlx_error_to_exec_err(err)),
150        }
151    }
152
153    /// Get one result from a SQL query. Returns [Option::None] if no match was found
154    #[instrument(level = "trace")]
155    pub async fn query_one(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
156        debug_print!("{}", stmt);
157
158        let query = sqlx_query(&stmt);
159        let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
160        crate::metric::metric!(self.metric_callback, &stmt, {
161            match query.fetch_one(&mut *conn).await {
162                Ok(row) => Ok(Some(row.into())),
163                Err(err) => match err {
164                    sqlx::Error::RowNotFound => Ok(None),
165                    _ => Err(sqlx_error_to_query_err(err)),
166                },
167            }
168        })
169    }
170
171    /// Get the results of a query returning them as a Vec<[QueryResult]>
172    #[instrument(level = "trace")]
173    pub async fn query_all(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
174        debug_print!("{}", stmt);
175
176        let query = sqlx_query(&stmt);
177        let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
178        crate::metric::metric!(self.metric_callback, &stmt, {
179            match query.fetch_all(&mut *conn).await {
180                Ok(rows) => Ok(rows.into_iter().map(|r| r.into()).collect()),
181                Err(err) => Err(sqlx_error_to_query_err(err)),
182            }
183        })
184    }
185
186    /// Stream the results of executing a SQL query
187    #[instrument(level = "trace")]
188    pub async fn stream(&self, stmt: Statement) -> Result<QueryStream, DbErr> {
189        debug_print!("{}", stmt);
190
191        let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
192        Ok(QueryStream::from((
193            conn,
194            stmt,
195            self.metric_callback.clone(),
196        )))
197    }
198
199    /// Bundle a set of SQL statements that execute together.
200    #[instrument(level = "trace")]
201    pub async fn begin(
202        &self,
203        isolation_level: Option<IsolationLevel>,
204        access_mode: Option<AccessMode>,
205    ) -> Result<DatabaseTransaction, DbErr> {
206        let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
207        DatabaseTransaction::new_sqlite(
208            conn,
209            self.metric_callback.clone(),
210            isolation_level,
211            access_mode,
212        )
213        .await
214    }
215
216    /// Create a MySQL transaction
217    #[instrument(level = "trace", skip(callback))]
218    pub async fn transaction<F, T, E>(
219        &self,
220        callback: F,
221        isolation_level: Option<IsolationLevel>,
222        access_mode: Option<AccessMode>,
223    ) -> Result<T, TransactionError<E>>
224    where
225        F: for<'b> FnOnce(
226                &'b DatabaseTransaction,
227            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'b>>
228            + Send,
229        T: Send,
230        E: std::fmt::Display + std::fmt::Debug + Send,
231    {
232        let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
233        let transaction = DatabaseTransaction::new_sqlite(
234            conn,
235            self.metric_callback.clone(),
236            isolation_level,
237            access_mode,
238        )
239        .await
240        .map_err(|e| TransactionError::Connection(e))?;
241        transaction.run(callback).await
242    }
243
244    pub(crate) fn set_metric_callback<F>(&mut self, callback: F)
245    where
246        F: Fn(&crate::metric::Info<'_>) + Send + Sync + 'static,
247    {
248        self.metric_callback = Some(Arc::new(callback));
249    }
250
251    /// Checks if a connection to the database is still valid.
252    pub async fn ping(&self) -> Result<(), DbErr> {
253        let conn = &mut self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
254        match conn.ping().await {
255            Ok(_) => Ok(()),
256            Err(err) => Err(sqlx_error_to_conn_err(err)),
257        }
258    }
259
260    /// Explicitly close the SQLite connection.
261    /// See [`Self::close_by_ref`] for usage with references.
262    pub async fn close(self) -> Result<(), DbErr> {
263        self.close_by_ref().await
264    }
265
266    /// Explicitly close the SQLite connection
267    pub async fn close_by_ref(&self) -> Result<(), DbErr> {
268        self.pool.close().await;
269        Ok(())
270    }
271}
272
273impl From<SqliteRow> for QueryResult {
274    fn from(row: SqliteRow) -> QueryResult {
275        QueryResult {
276            row: QueryResultRow::SqlxSqlite(row),
277        }
278    }
279}
280
281impl From<SqliteQueryResult> for ExecResult {
282    fn from(result: SqliteQueryResult) -> ExecResult {
283        ExecResult {
284            result: ExecResultHolder::SqlxSqlite(result),
285        }
286    }
287}
288
289pub(crate) fn sqlx_query(stmt: &Statement) -> sqlx::query::Query<'_, Sqlite, SqlxValues> {
290    let values = stmt
291        .values
292        .as_ref()
293        .map_or(Values(Vec::new()), |values| values.clone());
294    sqlx::query_with(&stmt.sql, SqlxValues(values))
295}
296
297pub(crate) async fn set_transaction_config(
298    _conn: &mut PoolConnection<Sqlite>,
299    isolation_level: Option<IsolationLevel>,
300    access_mode: Option<AccessMode>,
301) -> Result<(), DbErr> {
302    if isolation_level.is_some() {
303        warn!("Setting isolation level in a SQLite transaction isn't supported");
304    }
305    if access_mode.is_some() {
306        warn!("Setting access mode in a SQLite transaction isn't supported");
307    }
308    Ok(())
309}
310
#[cfg(feature = "sqlite-use-returning-for-3_35")]
/// Reads the SQLite library version string via `SELECT sqlite_version()`.
///
/// # Errors
///
/// Returns a [DbErr] if the query fails, if it yields no row, or if the first
/// column cannot be read as a `String`.
async fn get_version(conn: &SqlxSqlitePoolConnection) -> Result<String, DbErr> {
    let version_query = Statement {
        sql: "SELECT sqlite_version()".to_string(),
        values: None,
        db_backend: crate::DbBackend::Sqlite,
    };
    match conn.query_one(version_query).await? {
        Some(row) => row.try_get_by(0),
        None => Err(DbErr::Conn(RuntimeErr::Internal(
            "Error reading SQLite version".to_string(),
        ))),
    }
}
327
#[cfg(feature = "sqlite-use-returning-for-3_35")]
/// Checks that a SQLite version string is at least 3.35 (needed for `RETURNING`).
///
/// The string is trimmed and split on `.`. Only the major and minor components
/// are parsed; anything after them is ignored.
///
/// # Errors
///
/// Returns `DbErr::Conn` when:
/// - a major or minor component is not a `u32` ("Error parsing SQLite version");
/// - the minor component is missing ("SQLite version too short");
/// - the version is below 3.35 ("SQLite version does not support returning").
fn ensure_returning_version(version: &str) -> Result<(), DbErr> {
    let internal = |msg: &str| DbErr::Conn(RuntimeErr::Internal(msg.to_string()));

    let mut components = version.trim().split('.');
    // Reads the next dot-separated component; a missing one is "too short",
    // a non-numeric one is a parse error.
    let mut next_number = || -> Result<u32, DbErr> {
        let component = components
            .next()
            .ok_or_else(|| internal("SQLite version too short"))?;
        component
            .parse::<u32>()
            .map_err(|_| internal("Error parsing SQLite version"))
    };

    let major = next_number()?;
    let minor = next_number()?;

    let supports_returning = match major {
        0..=2 => false,
        3 => minor >= 35,
        _ => true,
    };

    if supports_returning {
        Ok(())
    } else {
        Err(internal("SQLite version does not support returning"))
    }
}
357
358impl
359    From<(
360        PoolConnection<sqlx::Sqlite>,
361        Statement,
362        Option<crate::metric::Callback>,
363    )> for crate::QueryStream
364{
365    fn from(
366        (conn, stmt, metric_callback): (
367            PoolConnection<sqlx::Sqlite>,
368            Statement,
369            Option<crate::metric::Callback>,
370        ),
371    ) -> Self {
372        crate::QueryStream::build(stmt, crate::InnerConnection::Sqlite(conn), metric_callback)
373    }
374}
375
376impl crate::DatabaseTransaction {
377    pub(crate) async fn new_sqlite(
378        inner: PoolConnection<sqlx::Sqlite>,
379        metric_callback: Option<crate::metric::Callback>,
380        isolation_level: Option<IsolationLevel>,
381        access_mode: Option<AccessMode>,
382    ) -> Result<crate::DatabaseTransaction, DbErr> {
383        Self::begin(
384            Arc::new(Mutex::new(crate::InnerConnection::Sqlite(inner))),
385            crate::DbBackend::Sqlite,
386            metric_callback,
387            isolation_level,
388            access_mode,
389        )
390        .await
391    }
392}
393
#[cfg(feature = "proxy")]
/// Converts a sqlx SQLite row into a [crate::ProxyRow] keyed by column name.
///
/// Each column is decoded according to the type name sqlx reports for it. SQL
/// `NULL` values are decoded as `None`, so they survive as `None` in the
/// resulting `sea_query::Value` instead of being forced into `Some(..)`.
///
/// # Panics
///
/// Panics if a value cannot be decoded as the type implied by its column type
/// name. Also panics if the type name has no arm below, e.g. `DATETIME`, `DATE`
/// or `TIME` when neither `with-chrono` nor `with-time` is enabled.
pub(crate) fn from_sqlx_sqlite_row_to_proxy_row(row: &sqlx::sqlite::SqliteRow) -> crate::ProxyRow {
    // https://docs.rs/sqlx-sqlite/0.7.2/src/sqlx_sqlite/type_info.rs.html
    // https://docs.rs/sqlx-sqlite/0.7.2/sqlx_sqlite/types/index.html
    use sea_query::Value;
    use sqlx::{Column, Row, TypeInfo};

    // Boxes a present value for the boxed `Value` variants; NULL stays `None`.
    // Also lets the decoded `Option<T>` type be inferred from the variant.
    fn boxed<T>(value: Option<T>) -> Option<Box<T>> {
        value.map(Box::new)
    }

    crate::ProxyRow {
        values: row
            .columns()
            .iter()
            .map(|c| {
                let idx = c.ordinal();
                // Every arm decodes `Option<_>` so that SQL NULL maps to `None`.
                (
                    c.name().to_string(),
                    match c.type_info().name() {
                        "BOOLEAN" => {
                            Value::Bool(row.try_get(idx).expect("Failed to get boolean"))
                        }

                        "INTEGER" => {
                            Value::Int(row.try_get(idx).expect("Failed to get integer"))
                        }

                        // NOTE(review): sqlx-sqlite may report 64-bit integer columns as
                        // "INTEGER" too; confirm whether this arm is ever reached.
                        "BIGINT" | "INT8" => {
                            Value::BigInt(row.try_get(idx).expect("Failed to get big integer"))
                        }

                        "REAL" => Value::Double(row.try_get(idx).expect("Failed to get double")),

                        "TEXT" => Value::String(boxed(
                            row.try_get(idx).expect("Failed to get string"),
                        )),

                        "BLOB" => Value::Bytes(boxed(
                            row.try_get(idx).expect("Failed to get bytes"),
                        )),

                        #[cfg(feature = "with-chrono")]
                        "DATETIME" => Value::ChronoDateTimeUtc(boxed(
                            row.try_get(idx).expect("Failed to get timestamp"),
                        )),
                        #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                        "DATETIME" => Value::TimeDateTimeWithTimeZone(boxed(
                            row.try_get(idx).expect("Failed to get timestamp"),
                        )),

                        #[cfg(feature = "with-chrono")]
                        "DATE" => Value::ChronoDate(boxed(
                            row.try_get(idx).expect("Failed to get date"),
                        )),
                        #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                        "DATE" => Value::TimeDate(boxed(
                            row.try_get(idx).expect("Failed to get date"),
                        )),

                        #[cfg(feature = "with-chrono")]
                        "TIME" => Value::ChronoTime(boxed(
                            row.try_get(idx).expect("Failed to get time"),
                        )),
                        #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                        "TIME" => Value::TimeTime(boxed(
                            row.try_get(idx).expect("Failed to get time"),
                        )),

                        _ => unreachable!("Unknown column type: {}", c.type_info().name()),
                    },
                )
            })
            .collect(),
    }
}
466
#[cfg(all(test, feature = "sqlite-use-returning-for-3_35"))]
mod tests {
    use super::*;

    #[test]
    fn test_ensure_returning_version() {
        // malformed: empty, missing or non-numeric components
        assert!(ensure_returning_version("").is_err());
        assert!(ensure_returning_version(".").is_err());
        assert!(ensure_returning_version(".a").is_err());
        assert!(ensure_returning_version(".4.9").is_err());
        assert!(ensure_returning_version("a").is_err());
        assert!(ensure_returning_version("1.").is_err());
        assert!(ensure_returning_version("1.a").is_err());

        // too short: a major component alone is rejected
        assert!(ensure_returning_version("3").is_err());

        // too old for RETURNING
        assert!(ensure_returning_version("1.1").is_err());
        assert!(ensure_returning_version("1.0.").is_err());
        assert!(ensure_returning_version("1.0.0").is_err());
        assert!(ensure_returning_version("2.0.0").is_err());
        assert!(ensure_returning_version("3.34.0").is_err());
        assert!(ensure_returning_version("3.34.999").is_err());

        // valid version
        assert!(ensure_returning_version("3.35.0").is_ok());
        assert!(ensure_returning_version("3.35.1").is_ok());
        assert!(ensure_returning_version("3.36.0").is_ok());
        assert!(ensure_returning_version("4.0.0").is_ok());
        assert!(ensure_returning_version("99.0.0").is_ok());

        // the patch component is optional and surrounding whitespace is trimmed
        assert!(ensure_returning_version("3.35").is_ok());
        assert!(ensure_returning_version(" 3.35.0\n").is_ok());
    }
}