// sea_orm/driver/sqlx_sqlite.rs

1use futures_util::lock::Mutex;
2use log::LevelFilter;
3use sea_query::Values;
4use std::{future::Future, pin::Pin, sync::Arc};
5
6use sqlx::{
7    pool::PoolConnection,
8    sqlite::{SqliteConnectOptions, SqliteQueryResult, SqliteRow},
9    Connection, Executor, Sqlite, SqlitePool,
10};
11
12use sea_query_binder::SqlxValues;
13use tracing::{instrument, warn};
14
15use crate::{
16    debug_print, error::*, executor::*, sqlx_error_to_exec_err, AccessMode, ConnectOptions,
17    DatabaseConnection, DatabaseTransaction, IsolationLevel, QueryStream, Statement,
18    TransactionError,
19};
20
21use super::sqlx_common::*;
22
/// Entry point for opening SQLite [DatabaseConnection]s through [sqlx::sqlite].
///
/// This is a field-less marker type; everything it offers is exposed as associated
/// functions (`accepts`, `connect`, `from_sqlx_sqlite_pool`).
#[derive(Debug)]
pub struct SqlxSqliteConnector;
26
27/// Defines a sqlx SQLite pool
28#[derive(Clone)]
29pub struct SqlxSqlitePoolConnection {
30    pub(crate) pool: SqlitePool,
31    metric_callback: Option<crate::metric::Callback>,
32}
33
34impl std::fmt::Debug for SqlxSqlitePoolConnection {
35    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36        write!(f, "SqlxSqlitePoolConnection {{ pool: {:?} }}", self.pool)
37    }
38}
39
40impl From<SqlitePool> for SqlxSqlitePoolConnection {
41    fn from(pool: SqlitePool) -> Self {
42        SqlxSqlitePoolConnection {
43            pool,
44            metric_callback: None,
45        }
46    }
47}
48
49impl From<SqlitePool> for DatabaseConnection {
50    fn from(pool: SqlitePool) -> Self {
51        DatabaseConnection::SqlxSqlitePoolConnection(pool.into())
52    }
53}
54
55impl SqlxSqliteConnector {
56    /// Check if the URI provided corresponds to `sqlite:` for a SQLite database
57    pub fn accepts(string: &str) -> bool {
58        string.starts_with("sqlite:") && string.parse::<SqliteConnectOptions>().is_ok()
59    }
60
61    /// Add configuration options for the SQLite database
62    #[instrument(level = "trace")]
63    pub async fn connect(options: ConnectOptions) -> Result<DatabaseConnection, DbErr> {
64        let mut options = options;
65        let mut sqlx_opts = options
66            .url
67            .parse::<SqliteConnectOptions>()
68            .map_err(sqlx_error_to_conn_err)?;
69        if let Some(sqlcipher_key) = &options.sqlcipher_key {
70            sqlx_opts = sqlx_opts.pragma("key", sqlcipher_key.clone());
71        }
72        use sqlx::ConnectOptions;
73        if !options.sqlx_logging {
74            sqlx_opts = sqlx_opts.disable_statement_logging();
75        } else {
76            sqlx_opts = sqlx_opts.log_statements(options.sqlx_logging_level);
77            if options.sqlx_slow_statements_logging_level != LevelFilter::Off {
78                sqlx_opts = sqlx_opts.log_slow_statements(
79                    options.sqlx_slow_statements_logging_level,
80                    options.sqlx_slow_statements_logging_threshold,
81                );
82            }
83        }
84
85        if options.get_max_connections().is_none() {
86            options.max_connections(1);
87        }
88
89        if let Some(f) = &options.sqlite_opts_fn {
90            sqlx_opts = f(sqlx_opts);
91        }
92
93        let pool = if options.connect_lazy {
94            options.sqlx_pool_options().connect_lazy_with(sqlx_opts)
95        } else {
96            options
97                .sqlx_pool_options()
98                .connect_with(sqlx_opts)
99                .await
100                .map_err(sqlx_error_to_conn_err)?
101        };
102
103        let pool = SqlxSqlitePoolConnection {
104            pool,
105            metric_callback: None,
106        };
107
108        #[cfg(feature = "sqlite-use-returning-for-3_35")]
109        {
110            let version = get_version(&pool).await?;
111            ensure_returning_version(&version)?;
112        }
113
114        Ok(DatabaseConnection::SqlxSqlitePoolConnection(pool))
115    }
116}
117
118impl SqlxSqliteConnector {
119    /// Instantiate a sqlx pool connection to a [DatabaseConnection]
120    pub fn from_sqlx_sqlite_pool(pool: SqlitePool) -> DatabaseConnection {
121        DatabaseConnection::SqlxSqlitePoolConnection(SqlxSqlitePoolConnection {
122            pool,
123            metric_callback: None,
124        })
125    }
126}
127
128impl SqlxSqlitePoolConnection {
129    /// Execute a [Statement] on a SQLite backend
130    #[instrument(level = "trace")]
131    pub async fn execute(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
132        debug_print!("{}", stmt);
133
134        let query = sqlx_query(&stmt);
135        let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
136        crate::metric::metric!(self.metric_callback, &stmt, {
137            match query.execute(&mut *conn).await {
138                Ok(res) => Ok(res.into()),
139                Err(err) => Err(sqlx_error_to_exec_err(err)),
140            }
141        })
142    }
143
144    /// Execute an unprepared SQL statement on a SQLite backend
145    #[instrument(level = "trace")]
146    pub async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
147        debug_print!("{}", sql);
148
149        let conn = &mut self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
150        match conn.execute(sql).await {
151            Ok(res) => Ok(res.into()),
152            Err(err) => Err(sqlx_error_to_exec_err(err)),
153        }
154    }
155
156    /// Get one result from a SQL query. Returns [Option::None] if no match was found
157    #[instrument(level = "trace")]
158    pub async fn query_one(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
159        debug_print!("{}", stmt);
160
161        let query = sqlx_query(&stmt);
162        let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
163        crate::metric::metric!(self.metric_callback, &stmt, {
164            match query.fetch_one(&mut *conn).await {
165                Ok(row) => Ok(Some(row.into())),
166                Err(err) => match err {
167                    sqlx::Error::RowNotFound => Ok(None),
168                    _ => Err(sqlx_error_to_query_err(err)),
169                },
170            }
171        })
172    }
173
174    /// Get the results of a query returning them as a Vec<[QueryResult]>
175    #[instrument(level = "trace")]
176    pub async fn query_all(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
177        debug_print!("{}", stmt);
178
179        let query = sqlx_query(&stmt);
180        let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
181        crate::metric::metric!(self.metric_callback, &stmt, {
182            match query.fetch_all(&mut *conn).await {
183                Ok(rows) => Ok(rows.into_iter().map(|r| r.into()).collect()),
184                Err(err) => Err(sqlx_error_to_query_err(err)),
185            }
186        })
187    }
188
189    /// Stream the results of executing a SQL query
190    #[instrument(level = "trace")]
191    pub async fn stream(&self, stmt: Statement) -> Result<QueryStream, DbErr> {
192        debug_print!("{}", stmt);
193
194        let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
195        Ok(QueryStream::from((
196            conn,
197            stmt,
198            self.metric_callback.clone(),
199        )))
200    }
201
202    /// Bundle a set of SQL statements that execute together.
203    #[instrument(level = "trace")]
204    pub async fn begin(
205        &self,
206        isolation_level: Option<IsolationLevel>,
207        access_mode: Option<AccessMode>,
208    ) -> Result<DatabaseTransaction, DbErr> {
209        let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
210        DatabaseTransaction::new_sqlite(
211            conn,
212            self.metric_callback.clone(),
213            isolation_level,
214            access_mode,
215        )
216        .await
217    }
218
219    /// Create a MySQL transaction
220    #[instrument(level = "trace", skip(callback))]
221    pub async fn transaction<F, T, E>(
222        &self,
223        callback: F,
224        isolation_level: Option<IsolationLevel>,
225        access_mode: Option<AccessMode>,
226    ) -> Result<T, TransactionError<E>>
227    where
228        F: for<'b> FnOnce(
229                &'b DatabaseTransaction,
230            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'b>>
231            + Send,
232        T: Send,
233        E: std::fmt::Display + std::fmt::Debug + Send,
234    {
235        let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
236        let transaction = DatabaseTransaction::new_sqlite(
237            conn,
238            self.metric_callback.clone(),
239            isolation_level,
240            access_mode,
241        )
242        .await
243        .map_err(|e| TransactionError::Connection(e))?;
244        transaction.run(callback).await
245    }
246
247    pub(crate) fn set_metric_callback<F>(&mut self, callback: F)
248    where
249        F: Fn(&crate::metric::Info<'_>) + Send + Sync + 'static,
250    {
251        self.metric_callback = Some(Arc::new(callback));
252    }
253
254    /// Checks if a connection to the database is still valid.
255    pub async fn ping(&self) -> Result<(), DbErr> {
256        let conn = &mut self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
257        match conn.ping().await {
258            Ok(_) => Ok(()),
259            Err(err) => Err(sqlx_error_to_conn_err(err)),
260        }
261    }
262
263    /// Explicitly close the SQLite connection.
264    /// See [`Self::close_by_ref`] for usage with references.
265    pub async fn close(self) -> Result<(), DbErr> {
266        self.close_by_ref().await
267    }
268
269    /// Explicitly close the SQLite connection
270    pub async fn close_by_ref(&self) -> Result<(), DbErr> {
271        self.pool.close().await;
272        Ok(())
273    }
274}
275
276impl From<SqliteRow> for QueryResult {
277    fn from(row: SqliteRow) -> QueryResult {
278        QueryResult {
279            row: QueryResultRow::SqlxSqlite(row),
280        }
281    }
282}
283
284impl From<SqliteQueryResult> for ExecResult {
285    fn from(result: SqliteQueryResult) -> ExecResult {
286        ExecResult {
287            result: ExecResultHolder::SqlxSqlite(result),
288        }
289    }
290}
291
292pub(crate) fn sqlx_query(stmt: &Statement) -> sqlx::query::Query<'_, Sqlite, SqlxValues> {
293    let values = stmt
294        .values
295        .as_ref()
296        .map_or(Values(Vec::new()), |values| values.clone());
297    sqlx::query_with(&stmt.sql, SqlxValues(values))
298}
299
300pub(crate) async fn set_transaction_config(
301    _conn: &mut PoolConnection<Sqlite>,
302    isolation_level: Option<IsolationLevel>,
303    access_mode: Option<AccessMode>,
304) -> Result<(), DbErr> {
305    if isolation_level.is_some() {
306        warn!("Setting isolation level in a SQLite transaction isn't supported");
307    }
308    if access_mode.is_some() {
309        warn!("Setting access mode in a SQLite transaction isn't supported");
310    }
311    Ok(())
312}
313
/// Read the SQLite library version string by running `SELECT sqlite_version()`.
///
/// # Errors
///
/// Fails if the query fails, if it yields no row ("Error reading SQLite version"), or if
/// the first column cannot be read as a `String`.
#[cfg(feature = "sqlite-use-returning-for-3_35")]
async fn get_version(conn: &SqlxSqlitePoolConnection) -> Result<String, DbErr> {
    let stmt = Statement {
        sql: "SELECT sqlite_version()".to_string(),
        values: None,
        db_backend: crate::DbBackend::Sqlite,
    };

    let row = match conn.query_one(stmt).await? {
        Some(row) => row,
        None => {
            return Err(DbErr::Conn(RuntimeErr::Internal(
                "Error reading SQLite version".to_string(),
            )))
        }
    };

    row.try_get_by(0)
}
330
/// Check that a SQLite version string denotes 3.35 or newer, the first release with
/// `RETURNING` support.
///
/// The input is trimmed and split on `.`. Only the major and minor components are parsed;
/// any further components are ignored.
///
/// # Errors
///
/// Returns a connection [DbErr] in three cases:
/// - a component is not a `u32` ("Error parsing SQLite version");
/// - the minor component is missing ("SQLite version too short");
/// - the version is older than 3.35 ("SQLite version does not support returning").
#[cfg(feature = "sqlite-use-returning-for-3_35")]
fn ensure_returning_version(version: &str) -> Result<(), DbErr> {
    fn internal(message: &str) -> DbErr {
        DbErr::Conn(RuntimeErr::Internal(message.to_string()))
    }

    let mut components = version.trim().split('.');
    let mut next_number = || -> Result<u32, DbErr> {
        let raw = components
            .next()
            .ok_or_else(|| internal("SQLite version too short"))?;
        raw.parse::<u32>()
            .map_err(|_| internal("Error parsing SQLite version"))
    };

    let major = next_number()?;
    let minor = next_number()?;

    // Lexicographic tuple comparison: any major above 3, or major 3 with minor >= 35.
    if (major, minor) >= (3, 35) {
        Ok(())
    } else {
        Err(internal("SQLite version does not support returning"))
    }
}
360
361impl
362    From<(
363        PoolConnection<sqlx::Sqlite>,
364        Statement,
365        Option<crate::metric::Callback>,
366    )> for crate::QueryStream
367{
368    fn from(
369        (conn, stmt, metric_callback): (
370            PoolConnection<sqlx::Sqlite>,
371            Statement,
372            Option<crate::metric::Callback>,
373        ),
374    ) -> Self {
375        crate::QueryStream::build(stmt, crate::InnerConnection::Sqlite(conn), metric_callback)
376    }
377}
378
379impl crate::DatabaseTransaction {
380    pub(crate) async fn new_sqlite(
381        inner: PoolConnection<sqlx::Sqlite>,
382        metric_callback: Option<crate::metric::Callback>,
383        isolation_level: Option<IsolationLevel>,
384        access_mode: Option<AccessMode>,
385    ) -> Result<crate::DatabaseTransaction, DbErr> {
386        Self::begin(
387            Arc::new(Mutex::new(crate::InnerConnection::Sqlite(inner))),
388            crate::DbBackend::Sqlite,
389            metric_callback,
390            isolation_level,
391            access_mode,
392        )
393        .await
394    }
395}
396
/// Convert a sqlx [sqlx::sqlite::SqliteRow] into a [crate::ProxyRow] keyed by column name.
///
/// Each cell is decoded according to its column's SQLite type name and stored as the
/// matching [sea_query::Value] variant. Cells are decoded through `Option<T>`, so a SQL
/// `NULL` becomes the `None` form of that variant. Decoding straight into `T` would
/// instead panic or be coerced into a non-null placeholder value.
///
/// # Panics
///
/// Panics if a cell cannot be decoded into the Rust type implied by its column type name,
/// or if the column type name is not one of the names matched below.
/// NOTE(review): columns whose reported type name is not listed (sqlx can report e.g.
/// "NULL" for untyped expression columns) still panic — confirm whether that needs handling.
#[cfg(feature = "proxy")]
pub(crate) fn from_sqlx_sqlite_row_to_proxy_row(row: &sqlx::sqlite::SqliteRow) -> crate::ProxyRow {
    // https://docs.rs/sqlx-sqlite/0.7.2/src/sqlx_sqlite/type_info.rs.html
    // https://docs.rs/sqlx-sqlite/0.7.2/sqlx_sqlite/types/index.html
    use sea_query::Value;
    use sqlx::{Column, Row, TypeInfo};
    crate::ProxyRow {
        values: row
            .columns()
            .iter()
            .map(|c| {
                let idx = c.ordinal();
                let value = match c.type_info().name() {
                    // Unboxed variants: the variant's `Option<T>` payload drives inference.
                    "BOOLEAN" => Value::Bool(row.try_get(idx).expect("Failed to get boolean")),

                    "INTEGER" => Value::Int(row.try_get(idx).expect("Failed to get integer")),

                    "BIGINT" | "INT8" => {
                        Value::BigInt(row.try_get(idx).expect("Failed to get big integer"))
                    }

                    "REAL" => Value::Double(row.try_get(idx).expect("Failed to get double")),

                    // Boxed variants: decode as `Option<_>` first, then box the inner value.
                    "TEXT" => Value::String(
                        row.try_get::<Option<_>, _>(idx)
                            .expect("Failed to get string")
                            .map(Box::new),
                    ),

                    "BLOB" => Value::Bytes(
                        row.try_get::<Option<_>, _>(idx)
                            .expect("Failed to get bytes")
                            .map(Box::new),
                    ),

                    #[cfg(feature = "with-chrono")]
                    "DATETIME" => Value::ChronoDateTimeUtc(
                        row.try_get::<Option<_>, _>(idx)
                            .expect("Failed to get timestamp")
                            .map(Box::new),
                    ),
                    #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                    "DATETIME" => Value::TimeDateTimeWithTimeZone(
                        row.try_get::<Option<_>, _>(idx)
                            .expect("Failed to get timestamp")
                            .map(Box::new),
                    ),

                    #[cfg(feature = "with-chrono")]
                    "DATE" => Value::ChronoDate(
                        row.try_get::<Option<_>, _>(idx)
                            .expect("Failed to get date")
                            .map(Box::new),
                    ),
                    #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                    "DATE" => Value::TimeDate(
                        row.try_get::<Option<_>, _>(idx)
                            .expect("Failed to get date")
                            .map(Box::new),
                    ),

                    #[cfg(feature = "with-chrono")]
                    "TIME" => Value::ChronoTime(
                        row.try_get::<Option<_>, _>(idx)
                            .expect("Failed to get time")
                            .map(Box::new),
                    ),
                    #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                    "TIME" => Value::TimeTime(
                        row.try_get::<Option<_>, _>(idx)
                            .expect("Failed to get time")
                            .map(Box::new),
                    ),

                    // Reachable from data (any unlisted column type), so this is an explicit
                    // panic rather than `unreachable!`.
                    other => panic!("Unknown column type: {}", other),
                };
                (c.name().to_string(), value)
            })
            .collect(),
    }
}
469
#[cfg(all(test, feature = "sqlite-use-returning-for-3_35"))]
mod tests {
    use super::*;

    #[test]
    fn test_ensure_returning_version() {
        // Malformed or truncated version strings are rejected rather than guessed at,
        // including a bare major version and a non-numeric minor component.
        let malformed = ["", ".", ".a", ".4.9", "a", "1.", "1.a", "3", "3.a"];
        for version in malformed.iter() {
            assert!(
                ensure_returning_version(version).is_err(),
                "expected {:?} to be rejected as malformed",
                version
            );
        }

        // Well-formed versions older than 3.35 lack RETURNING support.
        let too_old = ["1.1", "1.0.", "1.0.0", "2.0.0", "3.34", "3.34.0", "3.34.999"];
        for version in too_old.iter() {
            assert!(
                ensure_returning_version(version).is_err(),
                "expected {:?} to be rejected as too old",
                version
            );
        }

        // 3.35 and later are accepted. Only major/minor are inspected, and surrounding
        // whitespace (e.g. a trailing newline) is trimmed.
        let supported = [
            "3.35", "3.35.0", "3.35.1", "3.36.0", "4.0.0", "99.0.0", " 3.35.0\n",
        ];
        for version in supported.iter() {
            assert!(
                ensure_returning_version(version).is_ok(),
                "expected {:?} to be accepted",
                version
            );
        }
    }
}