sea_orm/driver/sqlx_sqlite.rs

1use futures_util::lock::Mutex;
2use log::LevelFilter;
3use sea_query::Values;
4use std::{future::Future, pin::Pin, sync::Arc};
5
6use sqlx::{
7    pool::PoolConnection,
8    sqlite::{SqliteConnectOptions, SqliteQueryResult, SqliteRow},
9    Connection, Executor, Sqlite, SqlitePool,
10};
11
12use sea_query_binder::SqlxValues;
13use tracing::{instrument, warn};
14
15use crate::{
16    debug_print, error::*, executor::*, sqlx_error_to_exec_err, AccessMode, ConnectOptions,
17    DatabaseConnection, DatabaseTransaction, IsolationLevel, QueryStream, Statement,
18    TransactionError,
19};
20
21use super::sqlx_common::*;
22
/// Connector for SQLite databases reached through [sqlx::sqlite].
///
/// This is a stateless unit type; its behavior lives in its associated functions.
#[derive(Debug)]
pub struct SqlxSqliteConnector;
26
27/// Defines a sqlx SQLite pool
28#[derive(Clone)]
29pub struct SqlxSqlitePoolConnection {
30    pub(crate) pool: SqlitePool,
31    metric_callback: Option<crate::metric::Callback>,
32}
33
34impl std::fmt::Debug for SqlxSqlitePoolConnection {
35    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36        write!(f, "SqlxSqlitePoolConnection {{ pool: {:?} }}", self.pool)
37    }
38}
39
40impl From<SqlitePool> for SqlxSqlitePoolConnection {
41    fn from(pool: SqlitePool) -> Self {
42        SqlxSqlitePoolConnection {
43            pool,
44            metric_callback: None,
45        }
46    }
47}
48
49impl From<SqlitePool> for DatabaseConnection {
50    fn from(pool: SqlitePool) -> Self {
51        DatabaseConnection::SqlxSqlitePoolConnection(pool.into())
52    }
53}
54
55impl SqlxSqliteConnector {
56    /// Check if the URI provided corresponds to `sqlite:` for a SQLite database
57    pub fn accepts(string: &str) -> bool {
58        string.starts_with("sqlite:") && string.parse::<SqliteConnectOptions>().is_ok()
59    }
60
61    /// Add configuration options for the SQLite database
62    #[instrument(level = "trace")]
63    pub async fn connect(options: ConnectOptions) -> Result<DatabaseConnection, DbErr> {
64        let mut options = options;
65        let mut opt = options
66            .url
67            .parse::<SqliteConnectOptions>()
68            .map_err(sqlx_error_to_conn_err)?;
69        if let Some(sqlcipher_key) = &options.sqlcipher_key {
70            opt = opt.pragma("key", sqlcipher_key.clone());
71        }
72        use sqlx::ConnectOptions;
73        if !options.sqlx_logging {
74            opt = opt.disable_statement_logging();
75        } else {
76            opt = opt.log_statements(options.sqlx_logging_level);
77            if options.sqlx_slow_statements_logging_level != LevelFilter::Off {
78                opt = opt.log_slow_statements(
79                    options.sqlx_slow_statements_logging_level,
80                    options.sqlx_slow_statements_logging_threshold,
81                );
82            }
83        }
84
85        if options.get_max_connections().is_none() {
86            options.max_connections(1);
87        }
88
89        let pool = if options.connect_lazy {
90            options.sqlx_pool_options().connect_lazy_with(opt)
91        } else {
92            options
93                .sqlx_pool_options()
94                .connect_with(opt)
95                .await
96                .map_err(sqlx_error_to_conn_err)?
97        };
98
99        let pool = SqlxSqlitePoolConnection {
100            pool,
101            metric_callback: None,
102        };
103
104        #[cfg(feature = "sqlite-use-returning-for-3_35")]
105        {
106            let version = get_version(&pool).await?;
107            ensure_returning_version(&version)?;
108        }
109
110        Ok(DatabaseConnection::SqlxSqlitePoolConnection(pool))
111    }
112}
113
114impl SqlxSqliteConnector {
115    /// Instantiate a sqlx pool connection to a [DatabaseConnection]
116    pub fn from_sqlx_sqlite_pool(pool: SqlitePool) -> DatabaseConnection {
117        DatabaseConnection::SqlxSqlitePoolConnection(SqlxSqlitePoolConnection {
118            pool,
119            metric_callback: None,
120        })
121    }
122}
123
124impl SqlxSqlitePoolConnection {
125    /// Execute a [Statement] on a SQLite backend
126    #[instrument(level = "trace")]
127    pub async fn execute(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
128        debug_print!("{}", stmt);
129
130        let query = sqlx_query(&stmt);
131        let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
132        crate::metric::metric!(self.metric_callback, &stmt, {
133            match query.execute(&mut *conn).await {
134                Ok(res) => Ok(res.into()),
135                Err(err) => Err(sqlx_error_to_exec_err(err)),
136            }
137        })
138    }
139
140    /// Execute an unprepared SQL statement on a SQLite backend
141    #[instrument(level = "trace")]
142    pub async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
143        debug_print!("{}", sql);
144
145        let conn = &mut self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
146        match conn.execute(sql).await {
147            Ok(res) => Ok(res.into()),
148            Err(err) => Err(sqlx_error_to_exec_err(err)),
149        }
150    }
151
152    /// Get one result from a SQL query. Returns [Option::None] if no match was found
153    #[instrument(level = "trace")]
154    pub async fn query_one(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
155        debug_print!("{}", stmt);
156
157        let query = sqlx_query(&stmt);
158        let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
159        crate::metric::metric!(self.metric_callback, &stmt, {
160            match query.fetch_one(&mut *conn).await {
161                Ok(row) => Ok(Some(row.into())),
162                Err(err) => match err {
163                    sqlx::Error::RowNotFound => Ok(None),
164                    _ => Err(sqlx_error_to_query_err(err)),
165                },
166            }
167        })
168    }
169
170    /// Get the results of a query returning them as a Vec<[QueryResult]>
171    #[instrument(level = "trace")]
172    pub async fn query_all(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
173        debug_print!("{}", stmt);
174
175        let query = sqlx_query(&stmt);
176        let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
177        crate::metric::metric!(self.metric_callback, &stmt, {
178            match query.fetch_all(&mut *conn).await {
179                Ok(rows) => Ok(rows.into_iter().map(|r| r.into()).collect()),
180                Err(err) => Err(sqlx_error_to_query_err(err)),
181            }
182        })
183    }
184
185    /// Stream the results of executing a SQL query
186    #[instrument(level = "trace")]
187    pub async fn stream(&self, stmt: Statement) -> Result<QueryStream, DbErr> {
188        debug_print!("{}", stmt);
189
190        let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
191        Ok(QueryStream::from((
192            conn,
193            stmt,
194            self.metric_callback.clone(),
195        )))
196    }
197
198    /// Bundle a set of SQL statements that execute together.
199    #[instrument(level = "trace")]
200    pub async fn begin(
201        &self,
202        isolation_level: Option<IsolationLevel>,
203        access_mode: Option<AccessMode>,
204    ) -> Result<DatabaseTransaction, DbErr> {
205        let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
206        DatabaseTransaction::new_sqlite(
207            conn,
208            self.metric_callback.clone(),
209            isolation_level,
210            access_mode,
211        )
212        .await
213    }
214
215    /// Create a MySQL transaction
216    #[instrument(level = "trace", skip(callback))]
217    pub async fn transaction<F, T, E>(
218        &self,
219        callback: F,
220        isolation_level: Option<IsolationLevel>,
221        access_mode: Option<AccessMode>,
222    ) -> Result<T, TransactionError<E>>
223    where
224        F: for<'b> FnOnce(
225                &'b DatabaseTransaction,
226            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'b>>
227            + Send,
228        T: Send,
229        E: std::error::Error + Send,
230    {
231        let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
232        let transaction = DatabaseTransaction::new_sqlite(
233            conn,
234            self.metric_callback.clone(),
235            isolation_level,
236            access_mode,
237        )
238        .await
239        .map_err(|e| TransactionError::Connection(e))?;
240        transaction.run(callback).await
241    }
242
243    pub(crate) fn set_metric_callback<F>(&mut self, callback: F)
244    where
245        F: Fn(&crate::metric::Info<'_>) + Send + Sync + 'static,
246    {
247        self.metric_callback = Some(Arc::new(callback));
248    }
249
250    /// Checks if a connection to the database is still valid.
251    pub async fn ping(&self) -> Result<(), DbErr> {
252        let conn = &mut self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
253        match conn.ping().await {
254            Ok(_) => Ok(()),
255            Err(err) => Err(sqlx_error_to_conn_err(err)),
256        }
257    }
258
259    /// Explicitly close the SQLite connection
260    pub async fn close(self) -> Result<(), DbErr> {
261        self.pool.close().await;
262        Ok(())
263    }
264}
265
266impl From<SqliteRow> for QueryResult {
267    fn from(row: SqliteRow) -> QueryResult {
268        QueryResult {
269            row: QueryResultRow::SqlxSqlite(row),
270        }
271    }
272}
273
274impl From<SqliteQueryResult> for ExecResult {
275    fn from(result: SqliteQueryResult) -> ExecResult {
276        ExecResult {
277            result: ExecResultHolder::SqlxSqlite(result),
278        }
279    }
280}
281
282pub(crate) fn sqlx_query(stmt: &Statement) -> sqlx::query::Query<'_, Sqlite, SqlxValues> {
283    let values = stmt
284        .values
285        .as_ref()
286        .map_or(Values(Vec::new()), |values| values.clone());
287    sqlx::query_with(&stmt.sql, SqlxValues(values))
288}
289
290pub(crate) async fn set_transaction_config(
291    _conn: &mut PoolConnection<Sqlite>,
292    isolation_level: Option<IsolationLevel>,
293    access_mode: Option<AccessMode>,
294) -> Result<(), DbErr> {
295    if isolation_level.is_some() {
296        warn!("Setting isolation level in a SQLite transaction isn't supported");
297    }
298    if access_mode.is_some() {
299        warn!("Setting access mode in a SQLite transaction isn't supported");
300    }
301    Ok(())
302}
303
/// Run `SELECT sqlite_version()` through `conn` and return the first column as a `String`.
///
/// # Errors
/// Fails if the query fails, if it returns no row, or if the column cannot be
/// read as a `String`.
#[cfg(feature = "sqlite-use-returning-for-3_35")]
async fn get_version(conn: &SqlxSqlitePoolConnection) -> Result<String, DbErr> {
    let version_stmt = Statement {
        sql: "SELECT sqlite_version()".to_string(),
        values: None,
        db_backend: crate::DbBackend::Sqlite,
    };
    match conn.query_one(version_stmt).await? {
        Some(row) => row.try_get_by(0),
        None => Err(DbErr::Conn(RuntimeErr::Internal(
            "Error reading SQLite version".to_string(),
        ))),
    }
}
320
/// Check that a SQLite `version` string (e.g. `"3.35.0"`) is at least 3.35.
///
/// Only the first two dot-separated components (major, minor) are parsed and
/// compared; anything after them is ignored. Surrounding whitespace is trimmed.
///
/// # Errors
/// - `"Error parsing SQLite version"` if major or minor is not a `u32`;
/// - `"SQLite version too short"` if the minor component is missing;
/// - `"SQLite version does not support returning"` if the version is below 3.35.
#[cfg(feature = "sqlite-use-returning-for-3_35")]
fn ensure_returning_version(version: &str) -> Result<(), DbErr> {
    let internal = |msg: &str| DbErr::Conn(RuntimeErr::Internal(msg.to_string()));

    let mut components = version.trim().split('.');
    let mut next_number = || -> Result<u32, DbErr> {
        match components.next() {
            None => Err(internal("SQLite version too short")),
            Some(text) => text
                .parse::<u32>()
                .map_err(|_| internal("Error parsing SQLite version")),
        }
    };

    let major = next_number()?;
    let minor = next_number()?;

    // Lexicographic tuple comparison: a higher major wins outright, otherwise the
    // minor must be at least 35.
    if (major, minor) >= (3, 35) {
        Ok(())
    } else {
        Err(internal("SQLite version does not support returning"))
    }
}
350
351impl
352    From<(
353        PoolConnection<sqlx::Sqlite>,
354        Statement,
355        Option<crate::metric::Callback>,
356    )> for crate::QueryStream
357{
358    fn from(
359        (conn, stmt, metric_callback): (
360            PoolConnection<sqlx::Sqlite>,
361            Statement,
362            Option<crate::metric::Callback>,
363        ),
364    ) -> Self {
365        crate::QueryStream::build(stmt, crate::InnerConnection::Sqlite(conn), metric_callback)
366    }
367}
368
369impl crate::DatabaseTransaction {
370    pub(crate) async fn new_sqlite(
371        inner: PoolConnection<sqlx::Sqlite>,
372        metric_callback: Option<crate::metric::Callback>,
373        isolation_level: Option<IsolationLevel>,
374        access_mode: Option<AccessMode>,
375    ) -> Result<crate::DatabaseTransaction, DbErr> {
376        Self::begin(
377            Arc::new(Mutex::new(crate::InnerConnection::Sqlite(inner))),
378            crate::DbBackend::Sqlite,
379            metric_callback,
380            isolation_level,
381            access_mode,
382        )
383        .await
384    }
385}
386
/// Convert a sqlx SQLite row into a [crate::ProxyRow], keyed by column name.
///
/// Each column's sqlx type name selects the [sea_query::Value] variant. Cells are
/// decoded as `Option<T>`, so a SQL `NULL` becomes the matching variant holding
/// `None` (e.g. `Value::Int(None)`). It is no longer decoded as a non-optional type
/// and wrapped in `Some(..)`.
///
/// # Panics
/// Panics if a column reports a type name not handled below, or if a non-NULL
/// value cannot be decoded as the Rust type chosen for that column type.
#[cfg(feature = "proxy")]
pub(crate) fn from_sqlx_sqlite_row_to_proxy_row(row: &sqlx::sqlite::SqliteRow) -> crate::ProxyRow {
    // https://docs.rs/sqlx-sqlite/0.7.2/src/sqlx_sqlite/type_info.rs.html
    // https://docs.rs/sqlx-sqlite/0.7.2/sqlx_sqlite/types/index.html
    use sea_query::Value;
    use sqlx::{Column, Row, TypeInfo};
    crate::ProxyRow {
        values: row
            .columns()
            .iter()
            .map(|c| {
                (
                    c.name().to_string(),
                    // For the boxed variants, `try_get::<Option<_>, _>` fixes the
                    // outer `Option` so `.map(Box::new)` can infer the inner type
                    // from the `Value` variant.
                    match c.type_info().name() {
                        "BOOLEAN" => {
                            Value::Bool(row.try_get(c.ordinal()).expect("Failed to get boolean"))
                        }

                        "INTEGER" => {
                            Value::Int(row.try_get(c.ordinal()).expect("Failed to get integer"))
                        }

                        "BIGINT" | "INT8" => Value::BigInt(
                            row.try_get(c.ordinal())
                                .expect("Failed to get big integer"),
                        ),

                        "REAL" => {
                            Value::Double(row.try_get(c.ordinal()).expect("Failed to get double"))
                        }

                        "TEXT" => Value::String(
                            row.try_get::<Option<_>, _>(c.ordinal())
                                .expect("Failed to get string")
                                .map(Box::new),
                        ),

                        "BLOB" => Value::Bytes(
                            row.try_get::<Option<_>, _>(c.ordinal())
                                .expect("Failed to get bytes")
                                .map(Box::new),
                        ),

                        #[cfg(feature = "with-chrono")]
                        "DATETIME" => Value::ChronoDateTimeUtc(
                            row.try_get::<Option<_>, _>(c.ordinal())
                                .expect("Failed to get timestamp")
                                .map(Box::new),
                        ),
                        #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                        "DATETIME" => Value::TimeDateTimeWithTimeZone(
                            row.try_get::<Option<_>, _>(c.ordinal())
                                .expect("Failed to get timestamp")
                                .map(Box::new),
                        ),

                        #[cfg(feature = "with-chrono")]
                        "DATE" => Value::ChronoDate(
                            row.try_get::<Option<_>, _>(c.ordinal())
                                .expect("Failed to get date")
                                .map(Box::new),
                        ),
                        #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                        "DATE" => Value::TimeDate(
                            row.try_get::<Option<_>, _>(c.ordinal())
                                .expect("Failed to get date")
                                .map(Box::new),
                        ),

                        #[cfg(feature = "with-chrono")]
                        "TIME" => Value::ChronoTime(
                            row.try_get::<Option<_>, _>(c.ordinal())
                                .expect("Failed to get time")
                                .map(Box::new),
                        ),
                        #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                        "TIME" => Value::TimeTime(
                            row.try_get::<Option<_>, _>(c.ordinal())
                                .expect("Failed to get time")
                                .map(Box::new),
                        ),

                        // NOTE(review): reachable for any type name not listed above
                        // (e.g. a column whose sqlx type name is "NULL").
                        _ => unreachable!("Unknown column type: {}", c.type_info().name()),
                    },
                )
            })
            .collect(),
    }
}
459
#[cfg(all(test, feature = "sqlite-use-returning-for-3_35"))]
mod tests {
    use super::*;

    /// Covers malformed strings, well-formed but too-old versions, and supported ones.
    #[test]
    fn test_ensure_returning_version() {
        // Named helpers keep the input visible in the assertion text on failure.
        let rejects = |version: &str| ensure_returning_version(version).is_err();
        let accepts = |version: &str| ensure_returning_version(version).is_ok();

        // Malformed version strings.
        assert!(rejects(""));
        assert!(rejects("."));
        assert!(rejects(".a"));
        assert!(rejects(".4.9"));
        assert!(rejects("a"));
        assert!(rejects("1."));
        assert!(rejects("1.a"));

        // Parsable, but older than 3.35.
        assert!(rejects("1.1"));
        assert!(rejects("1.0."));
        assert!(rejects("1.0.0"));
        assert!(rejects("2.0.0"));
        assert!(rejects("3.34.0"));
        assert!(rejects("3.34.999"));

        // 3.35 and newer.
        assert!(accepts("3.35.0"));
        assert!(accepts("3.35.1"));
        assert!(accepts("3.36.0"));
        assert!(accepts("4.0.0"));
        assert!(accepts("99.0.0"));
    }
}