sea_orm/driver/
sqlx_sqlite.rs

1use futures_util::lock::Mutex;
2use log::LevelFilter;
3use sea_query::Values;
4use std::{future::Future, pin::Pin, sync::Arc};
5
6use sqlx::{
7    Connection, Executor, Sqlite, SqlitePool,
8    pool::PoolConnection,
9    sqlite::{SqliteConnectOptions, SqliteQueryResult, SqliteRow},
10};
11
12use sea_query_sqlx::SqlxValues;
13use tracing::{instrument, warn};
14
15use crate::{
16    AccessMode, ConnectOptions, DatabaseConnection, DatabaseConnectionType, DatabaseTransaction,
17    IsolationLevel, QueryStream, Statement, TransactionError, debug_print, error::*, executor::*,
18    sqlx_error_to_exec_err,
19};
20
21use super::sqlx_common::*;
22
/// Defines the [sqlx::sqlite] connector
///
/// A stateless marker type: its associated functions recognise `sqlite:` URLs and build a
/// [DatabaseConnection] backed by a sqlx SQLite pool.
#[derive(Debug)]
pub struct SqlxSqliteConnector;
26
/// Defines a sqlx SQLite pool
///
/// Wraps a [SqlitePool]; each query method acquires a connection from it for the duration
/// of a single call.
#[derive(Clone)]
pub struct SqlxSqlitePoolConnection {
    /// The underlying sqlx pool (crate-visible).
    pub(crate) pool: SqlitePool,
    /// Optional callback handed to `crate::metric::metric!` by the query methods.
    /// The constructors in this module leave it `None`; `set_metric_callback` installs one.
    metric_callback: Option<crate::metric::Callback>,
}
33
34impl std::fmt::Debug for SqlxSqlitePoolConnection {
35    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36        write!(f, "SqlxSqlitePoolConnection {{ pool: {:?} }}", self.pool)
37    }
38}
39
40impl From<SqlitePool> for SqlxSqlitePoolConnection {
41    fn from(pool: SqlitePool) -> Self {
42        SqlxSqlitePoolConnection {
43            pool,
44            metric_callback: None,
45        }
46    }
47}
48
49impl From<SqlitePool> for DatabaseConnection {
50    fn from(pool: SqlitePool) -> Self {
51        DatabaseConnectionType::SqlxSqlitePoolConnection(pool.into()).into()
52    }
53}
54
55impl SqlxSqliteConnector {
56    /// Check if the URI provided corresponds to `sqlite:` for a SQLite database
57    pub fn accepts(string: &str) -> bool {
58        string.starts_with("sqlite:") && string.parse::<SqliteConnectOptions>().is_ok()
59    }
60
61    /// Add configuration options for the SQLite database
62    #[instrument(level = "trace")]
63    pub async fn connect(options: ConnectOptions) -> Result<DatabaseConnection, DbErr> {
64        let mut options = options;
65        let mut sqlx_opts = options
66            .url
67            .parse::<SqliteConnectOptions>()
68            .map_err(sqlx_error_to_conn_err)?;
69        if let Some(sqlcipher_key) = &options.sqlcipher_key {
70            sqlx_opts = sqlx_opts.pragma("key", sqlcipher_key.clone());
71        }
72        use sqlx::ConnectOptions;
73        if !options.sqlx_logging {
74            sqlx_opts = sqlx_opts.disable_statement_logging();
75        } else {
76            sqlx_opts = sqlx_opts.log_statements(options.sqlx_logging_level);
77            if options.sqlx_slow_statements_logging_level != LevelFilter::Off {
78                sqlx_opts = sqlx_opts.log_slow_statements(
79                    options.sqlx_slow_statements_logging_level,
80                    options.sqlx_slow_statements_logging_threshold,
81                );
82            }
83        }
84
85        if options.get_max_connections().is_none() {
86            options.max_connections(1);
87        }
88
89        if let Some(f) = &options.sqlite_opts_fn {
90            sqlx_opts = f(sqlx_opts);
91        }
92
93        let pool = if options.connect_lazy {
94            options.sqlx_pool_options().connect_lazy_with(sqlx_opts)
95        } else {
96            options
97                .sqlx_pool_options()
98                .connect_with(sqlx_opts)
99                .await
100                .map_err(sqlx_error_to_conn_err)?
101        };
102
103        let pool = SqlxSqlitePoolConnection {
104            pool,
105            metric_callback: None,
106        };
107
108        #[cfg(feature = "sqlite-use-returning-for-3_35")]
109        {
110            let version = get_version(&pool).await?;
111            ensure_returning_version(&version)?;
112        }
113
114        Ok(DatabaseConnectionType::SqlxSqlitePoolConnection(pool).into())
115    }
116}
117
118impl SqlxSqliteConnector {
119    /// Instantiate a sqlx pool connection to a [DatabaseConnection]
120    pub fn from_sqlx_sqlite_pool(pool: SqlitePool) -> DatabaseConnection {
121        DatabaseConnectionType::SqlxSqlitePoolConnection(SqlxSqlitePoolConnection {
122            pool,
123            metric_callback: None,
124        })
125        .into()
126    }
127}
128
129impl SqlxSqlitePoolConnection {
130    /// Execute a [Statement] on a SQLite backend
131    #[instrument(level = "trace")]
132    pub async fn execute(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
133        debug_print!("{}", stmt);
134
135        let query = sqlx_query(&stmt);
136        let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
137        crate::metric::metric!(self.metric_callback, &stmt, {
138            match query.execute(&mut *conn).await {
139                Ok(res) => Ok(res.into()),
140                Err(err) => Err(sqlx_error_to_exec_err(err)),
141            }
142        })
143    }
144
145    /// Execute an unprepared SQL statement on a SQLite backend
146    #[instrument(level = "trace")]
147    pub async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
148        debug_print!("{}", sql);
149
150        let conn = &mut self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
151        match conn.execute(sql).await {
152            Ok(res) => Ok(res.into()),
153            Err(err) => Err(sqlx_error_to_exec_err(err)),
154        }
155    }
156
157    /// Get one result from a SQL query. Returns [Option::None] if no match was found
158    #[instrument(level = "trace")]
159    pub async fn query_one(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
160        debug_print!("{}", stmt);
161
162        let query = sqlx_query(&stmt);
163        let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
164        crate::metric::metric!(self.metric_callback, &stmt, {
165            match query.fetch_one(&mut *conn).await {
166                Ok(row) => Ok(Some(row.into())),
167                Err(err) => match err {
168                    sqlx::Error::RowNotFound => Ok(None),
169                    _ => Err(sqlx_error_to_query_err(err)),
170                },
171            }
172        })
173    }
174
175    /// Get the results of a query returning them as a Vec<[QueryResult]>
176    #[instrument(level = "trace")]
177    pub async fn query_all(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
178        debug_print!("{}", stmt);
179
180        let query = sqlx_query(&stmt);
181        let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
182        crate::metric::metric!(self.metric_callback, &stmt, {
183            match query.fetch_all(&mut *conn).await {
184                Ok(rows) => Ok(rows.into_iter().map(|r| r.into()).collect()),
185                Err(err) => Err(sqlx_error_to_query_err(err)),
186            }
187        })
188    }
189
190    /// Stream the results of executing a SQL query
191    #[instrument(level = "trace")]
192    pub async fn stream(&self, stmt: Statement) -> Result<QueryStream, DbErr> {
193        debug_print!("{}", stmt);
194
195        let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
196        Ok(QueryStream::from((
197            conn,
198            stmt,
199            self.metric_callback.clone(),
200        )))
201    }
202
203    /// Bundle a set of SQL statements that execute together.
204    #[instrument(level = "trace")]
205    pub async fn begin(
206        &self,
207        isolation_level: Option<IsolationLevel>,
208        access_mode: Option<AccessMode>,
209    ) -> Result<DatabaseTransaction, DbErr> {
210        let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
211        DatabaseTransaction::new_sqlite(
212            conn,
213            self.metric_callback.clone(),
214            isolation_level,
215            access_mode,
216        )
217        .await
218    }
219
220    /// Create a MySQL transaction
221    #[instrument(level = "trace", skip(callback))]
222    pub async fn transaction<F, T, E>(
223        &self,
224        callback: F,
225        isolation_level: Option<IsolationLevel>,
226        access_mode: Option<AccessMode>,
227    ) -> Result<T, TransactionError<E>>
228    where
229        F: for<'b> FnOnce(
230                &'b DatabaseTransaction,
231            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'b>>
232            + Send,
233        T: Send,
234        E: std::fmt::Display + std::fmt::Debug + Send,
235    {
236        let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
237        let transaction = DatabaseTransaction::new_sqlite(
238            conn,
239            self.metric_callback.clone(),
240            isolation_level,
241            access_mode,
242        )
243        .await
244        .map_err(|e| TransactionError::Connection(e))?;
245        transaction.run(callback).await
246    }
247
248    pub(crate) fn set_metric_callback<F>(&mut self, callback: F)
249    where
250        F: Fn(&crate::metric::Info<'_>) + Send + Sync + 'static,
251    {
252        self.metric_callback = Some(Arc::new(callback));
253    }
254
255    /// Checks if a connection to the database is still valid.
256    pub async fn ping(&self) -> Result<(), DbErr> {
257        let conn = &mut self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
258        match conn.ping().await {
259            Ok(_) => Ok(()),
260            Err(err) => Err(sqlx_error_to_conn_err(err)),
261        }
262    }
263
264    /// Explicitly close the SQLite connection.
265    /// See [`Self::close_by_ref`] for usage with references.
266    pub async fn close(self) -> Result<(), DbErr> {
267        self.close_by_ref().await
268    }
269
270    /// Explicitly close the SQLite connection
271    pub async fn close_by_ref(&self) -> Result<(), DbErr> {
272        self.pool.close().await;
273        Ok(())
274    }
275}
276
277impl From<SqliteRow> for QueryResult {
278    fn from(row: SqliteRow) -> QueryResult {
279        QueryResult {
280            row: QueryResultRow::SqlxSqlite(row),
281        }
282    }
283}
284
285impl From<SqliteQueryResult> for ExecResult {
286    fn from(result: SqliteQueryResult) -> ExecResult {
287        ExecResult {
288            result: ExecResultHolder::SqlxSqlite(result),
289        }
290    }
291}
292
293pub(crate) fn sqlx_query(stmt: &Statement) -> sqlx::query::Query<'_, Sqlite, SqlxValues> {
294    let values = stmt
295        .values
296        .as_ref()
297        .map_or(Values(Vec::new()), |values| values.clone());
298    sqlx::query_with(&stmt.sql, SqlxValues(values))
299}
300
301pub(crate) async fn set_transaction_config(
302    _conn: &mut PoolConnection<Sqlite>,
303    isolation_level: Option<IsolationLevel>,
304    access_mode: Option<AccessMode>,
305) -> Result<(), DbErr> {
306    if isolation_level.is_some() {
307        warn!("Setting isolation level in a SQLite transaction isn't supported");
308    }
309    if access_mode.is_some() {
310        warn!("Setting access mode in a SQLite transaction isn't supported");
311    }
312    Ok(())
313}
314
#[cfg(feature = "sqlite-use-returning-for-3_35")]
/// Reads the library version string via `SELECT sqlite_version()`.
///
/// # Errors
/// Fails if the query fails, returns no row, or the first column is not a string.
async fn get_version(conn: &SqlxSqlitePoolConnection) -> Result<String, DbErr> {
    let stmt = Statement {
        sql: String::from("SELECT sqlite_version()"),
        values: None,
        db_backend: crate::DbBackend::Sqlite,
    };
    let row = match conn.query_one(stmt).await? {
        Some(row) => row,
        None => {
            return Err(DbErr::Conn(RuntimeErr::Internal(
                "Error reading SQLite version".to_string(),
            )));
        }
    };
    row.try_get_by(0)
}
331
#[cfg(feature = "sqlite-use-returning-for-3_35")]
/// Checks that a `major.minor[.patch...]` SQLite version string is at least 3.35.
///
/// Only the first two dot-separated components are parsed (after trimming surrounding
/// whitespace); anything after the minor number is ignored.
///
/// # Errors
/// - "SQLite version too short" when there is no minor component;
/// - "Error parsing SQLite version" when major or minor is not a `u32`;
/// - "SQLite version does not support returning" when the version is below 3.35.
fn ensure_returning_version(version: &str) -> Result<(), DbErr> {
    let internal = |msg: &str| DbErr::Conn(RuntimeErr::Internal(msg.to_string()));

    let mut components = version.trim().split('.');
    let mut next_number = || -> Result<u32, DbErr> {
        let text = components
            .next()
            .ok_or_else(|| internal("SQLite version too short"))?;
        text.parse::<u32>()
            .map_err(|_| internal("Error parsing SQLite version"))
    };

    let major = next_number()?;
    let minor = next_number()?;

    // Tuples compare lexicographically: any major above 3, or 3.x with x >= 35.
    if (major, minor) >= (3, 35) {
        Ok(())
    } else {
        Err(internal("SQLite version does not support returning"))
    }
}
361
362impl
363    From<(
364        PoolConnection<sqlx::Sqlite>,
365        Statement,
366        Option<crate::metric::Callback>,
367    )> for crate::QueryStream
368{
369    fn from(
370        (conn, stmt, metric_callback): (
371            PoolConnection<sqlx::Sqlite>,
372            Statement,
373            Option<crate::metric::Callback>,
374        ),
375    ) -> Self {
376        crate::QueryStream::build(stmt, crate::InnerConnection::Sqlite(conn), metric_callback)
377    }
378}
379
380impl crate::DatabaseTransaction {
381    pub(crate) async fn new_sqlite(
382        inner: PoolConnection<sqlx::Sqlite>,
383        metric_callback: Option<crate::metric::Callback>,
384        isolation_level: Option<IsolationLevel>,
385        access_mode: Option<AccessMode>,
386    ) -> Result<crate::DatabaseTransaction, DbErr> {
387        Self::begin(
388            Arc::new(Mutex::new(crate::InnerConnection::Sqlite(inner))),
389            crate::DbBackend::Sqlite,
390            metric_callback,
391            isolation_level,
392            access_mode,
393        )
394        .await
395    }
396}
397
#[cfg(feature = "proxy")]
/// Converts a sqlx SQLite row into a [crate::ProxyRow] of `(column name, Value)` pairs.
///
/// The `Value` variant is chosen from the column's sqlx type name. Each value is decoded
/// as `Option<_>`, so a SQL `NULL` becomes the matching variant holding `None` rather than
/// a decode error.
///
/// # Panics
/// - If a column's type name is not one of the handled names. This includes
///   `DATETIME`/`DATE`/`TIME` when neither `with-chrono` nor `with-time` is enabled.
/// - If a non-NULL value cannot be decoded as the Rust type chosen for its column.
pub(crate) fn from_sqlx_sqlite_row_to_proxy_row(row: &sqlx::sqlite::SqliteRow) -> crate::ProxyRow {
    // https://docs.rs/sqlx-sqlite/0.7.2/src/sqlx_sqlite/type_info.rs.html
    // https://docs.rs/sqlx-sqlite/0.7.2/sqlx_sqlite/types/index.html
    use sea_query::Value;
    use sqlx::{Column, Row, TypeInfo};
    crate::ProxyRow {
        values: row
            .columns()
            .iter()
            .map(|c| {
                let idx = c.ordinal();
                // Decoding as `Option<_>` maps SQL NULL to `None` instead of an error that
                // the `expect` calls below would turn into a panic.
                let value = match c.type_info().name() {
                    "BOOLEAN" => Value::Bool(
                        row.try_get::<Option<_>, _>(idx)
                            .expect("Failed to get boolean"),
                    ),

                    "INTEGER" => Value::Int(
                        row.try_get::<Option<_>, _>(idx)
                            .expect("Failed to get integer"),
                    ),

                    "BIGINT" | "INT8" => Value::BigInt(
                        row.try_get::<Option<_>, _>(idx)
                            .expect("Failed to get big integer"),
                    ),

                    "REAL" => Value::Double(
                        row.try_get::<Option<_>, _>(idx)
                            .expect("Failed to get double"),
                    ),

                    "TEXT" => Value::String(
                        row.try_get::<Option<_>, _>(idx)
                            .expect("Failed to get string")
                            .map(Box::new),
                    ),

                    "BLOB" => Value::Bytes(
                        row.try_get::<Option<_>, _>(idx)
                            .expect("Failed to get bytes")
                            .map(Box::new),
                    ),

                    #[cfg(feature = "with-chrono")]
                    "DATETIME" => Value::ChronoDateTimeUtc(
                        row.try_get::<Option<_>, _>(idx)
                            .expect("Failed to get timestamp")
                            .map(Box::new),
                    ),
                    #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                    "DATETIME" => Value::TimeDateTimeWithTimeZone(
                        row.try_get::<Option<_>, _>(idx)
                            .expect("Failed to get timestamp")
                            .map(Box::new),
                    ),

                    #[cfg(feature = "with-chrono")]
                    "DATE" => Value::ChronoDate(
                        row.try_get::<Option<_>, _>(idx)
                            .expect("Failed to get date")
                            .map(Box::new),
                    ),
                    #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                    "DATE" => Value::TimeDate(
                        row.try_get::<Option<_>, _>(idx)
                            .expect("Failed to get date")
                            .map(Box::new),
                    ),

                    #[cfg(feature = "with-chrono")]
                    "TIME" => Value::ChronoTime(
                        row.try_get::<Option<_>, _>(idx)
                            .expect("Failed to get time")
                            .map(Box::new),
                    ),
                    #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                    "TIME" => Value::TimeTime(
                        row.try_get::<Option<_>, _>(idx)
                            .expect("Failed to get time")
                            .map(Box::new),
                    ),

                    // Reachable (e.g. date/time names without a date/time feature), so this
                    // is a documented panic rather than `unreachable!`.
                    other => panic!("Unknown column type: {}", other),
                };
                (c.name().to_string(), value)
            })
            .collect(),
    }
}
470
#[cfg(all(test, feature = "sqlite-use-returning-for-3_35"))]
mod tests {
    use super::*;

    /// Table-driven check of `ensure_returning_version`, covering malformed input, the
    /// "too short" path, whitespace trimming, and both sides of the 3.35 boundary.
    #[test]
    fn test_ensure_returning_version() {
        let rejected = [
            // malformed
            "", ".", ".a", ".4.9", "a", "1.", "1.a", "3.x.0",
            // no minor component ("too short")
            "3",
            // well-formed but older than 3.35
            "1.1", "1.0.", "1.0.0", "2.0.0", "3.34.0", "3.34.999",
        ];
        for version in rejected.iter() {
            assert!(
                ensure_returning_version(version).is_err(),
                "expected {:?} to be rejected",
                version
            );
        }

        let accepted = [
            "3.35", "3.35.0", "3.35.1", "3.36.0", "4.0.0", "99.0.0",
            // surrounding whitespace is trimmed before parsing
            " 3.35.0\n",
        ];
        for version in accepted.iter() {
            assert!(
                ensure_returning_version(version).is_ok(),
                "expected {:?} to be accepted",
                version
            );
        }
    }
}