Skip to main content

sea_orm/driver/
sqlx_sqlite.rs

1use futures_util::lock::Mutex;
2use log::LevelFilter;
3use sea_query::Values;
4use std::{future::Future, pin::Pin, sync::Arc};
5
6use sqlx::{
7    Connection, Executor, Sqlite, SqlitePool,
8    pool::PoolConnection,
9    sqlite::{SqliteConnectOptions, SqliteQueryResult, SqliteRow},
10};
11
12use sea_query_sqlx::SqlxValues;
13use tracing::{instrument, warn};
14
15use crate::{
16    AccessMode, ConnectOptions, DatabaseConnection, DatabaseConnectionType, DatabaseTransaction,
17    IsolationLevel, QueryStream, SqliteTransactionMode, Statement, TransactionError, debug_print,
18    error::*, executor::*, sqlx_error_to_exec_err,
19};
20
21use super::sqlx_common::*;
22
/// Defines the [sqlx::sqlite] connector.
///
/// This is a stateless entry point: [`SqlxSqliteConnector::accepts`] recognises `sqlite:` URLs,
/// and [`SqlxSqliteConnector::connect`] turns [`ConnectOptions`] into a [`DatabaseConnection`].
#[derive(Debug)]
pub struct SqlxSqliteConnector;
26
27/// Defines a sqlx SQLite pool
28#[derive(Clone)]
29pub struct SqlxSqlitePoolConnection {
30    pub(crate) pool: SqlitePool,
31    metric_callback: Option<crate::metric::Callback>,
32}
33
34impl std::fmt::Debug for SqlxSqlitePoolConnection {
35    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36        write!(f, "SqlxSqlitePoolConnection {{ pool: {:?} }}", self.pool)
37    }
38}
39
40impl From<SqlitePool> for SqlxSqlitePoolConnection {
41    fn from(pool: SqlitePool) -> Self {
42        SqlxSqlitePoolConnection {
43            pool,
44            metric_callback: None,
45        }
46    }
47}
48
49impl From<SqlitePool> for DatabaseConnection {
50    fn from(pool: SqlitePool) -> Self {
51        DatabaseConnectionType::SqlxSqlitePoolConnection(pool.into()).into()
52    }
53}
54
55impl SqlxSqliteConnector {
56    /// Check if the URI provided corresponds to `sqlite:` for a SQLite database
57    pub fn accepts(string: &str) -> bool {
58        string.starts_with("sqlite:") && string.parse::<SqliteConnectOptions>().is_ok()
59    }
60
61    /// Add configuration options for the SQLite database
62    #[instrument(level = "trace")]
63    pub async fn connect(options: ConnectOptions) -> Result<DatabaseConnection, DbErr> {
64        let mut options = options;
65        let mut sqlx_opts = options
66            .url
67            .parse::<SqliteConnectOptions>()
68            .map_err(sqlx_error_to_conn_err)?;
69        if let Some(sqlcipher_key) = &options.sqlcipher_key {
70            sqlx_opts = sqlx_opts.pragma("key", sqlcipher_key.clone());
71        }
72        use sqlx::ConnectOptions;
73        if !options.sqlx_logging {
74            sqlx_opts = sqlx_opts.disable_statement_logging();
75        } else {
76            sqlx_opts = sqlx_opts.log_statements(options.sqlx_logging_level);
77            if options.sqlx_slow_statements_logging_level != LevelFilter::Off {
78                sqlx_opts = sqlx_opts.log_slow_statements(
79                    options.sqlx_slow_statements_logging_level,
80                    options.sqlx_slow_statements_logging_threshold,
81                );
82            }
83        }
84
85        if options.get_max_connections().is_none() {
86            options.max_connections(1);
87        }
88
89        if let Some(f) = &options.sqlite_opts_fn {
90            sqlx_opts = f(sqlx_opts);
91        }
92
93        let after_conn = options.after_connect.clone();
94        let connect_lazy = options.connect_lazy;
95        let sqlite_pool_opts_fn = options.sqlite_pool_opts_fn.clone();
96        let mut pool_options = options.sqlx_pool_options();
97
98        if let Some(f) = &sqlite_pool_opts_fn {
99            pool_options = f(pool_options);
100        }
101
102        let pool = if connect_lazy {
103            pool_options.connect_lazy_with(sqlx_opts)
104        } else {
105            pool_options
106                .connect_with(sqlx_opts)
107                .await
108                .map_err(sqlx_error_to_conn_err)?
109        };
110
111        let pool = SqlxSqlitePoolConnection {
112            pool,
113            metric_callback: None,
114        };
115
116        #[cfg(feature = "sqlite-use-returning-for-3_35")]
117        {
118            let version = get_version(&pool).await?;
119            super::sqlite::ensure_returning_version(&version)?;
120        }
121
122        let conn: DatabaseConnection =
123            DatabaseConnectionType::SqlxSqlitePoolConnection(pool).into();
124
125        if let Some(cb) = after_conn {
126            cb(conn.clone()).await?;
127        }
128
129        Ok(conn)
130    }
131}
132
133impl SqlxSqliteConnector {
134    /// Instantiate a sqlx pool connection to a [DatabaseConnection]
135    pub fn from_sqlx_sqlite_pool(pool: SqlitePool) -> DatabaseConnection {
136        DatabaseConnectionType::SqlxSqlitePoolConnection(SqlxSqlitePoolConnection {
137            pool,
138            metric_callback: None,
139        })
140        .into()
141    }
142}
143
144impl SqlxSqlitePoolConnection {
145    /// Execute a [Statement] on a SQLite backend
146    #[instrument(level = "trace")]
147    pub async fn execute(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
148        debug_print!("{}", stmt);
149
150        let query = sqlx_query(&stmt);
151        let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
152        crate::metric::metric!(self.metric_callback, &stmt, {
153            match query.execute(&mut *conn).await {
154                Ok(res) => Ok(res.into()),
155                Err(err) => Err(sqlx_error_to_exec_err(err)),
156            }
157        })
158    }
159
160    /// Execute an unprepared SQL statement on a SQLite backend
161    #[instrument(level = "trace")]
162    pub async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
163        debug_print!("{}", sql);
164
165        let conn = &mut self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
166        match conn.execute(sql).await {
167            Ok(res) => Ok(res.into()),
168            Err(err) => Err(sqlx_error_to_exec_err(err)),
169        }
170    }
171
172    /// Get one result from a SQL query. Returns [Option::None] if no match was found
173    #[instrument(level = "trace")]
174    pub async fn query_one(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
175        debug_print!("{}", stmt);
176
177        let query = sqlx_query(&stmt);
178        let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
179        crate::metric::metric!(self.metric_callback, &stmt, {
180            match query.fetch_one(&mut *conn).await {
181                Ok(row) => Ok(Some(row.into())),
182                Err(err) => match err {
183                    sqlx::Error::RowNotFound => Ok(None),
184                    _ => Err(sqlx_error_to_query_err(err)),
185                },
186            }
187        })
188    }
189
190    /// Get the results of a query returning them as a Vec<[QueryResult]>
191    #[instrument(level = "trace")]
192    pub async fn query_all(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
193        debug_print!("{}", stmt);
194
195        let query = sqlx_query(&stmt);
196        let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
197        crate::metric::metric!(self.metric_callback, &stmt, {
198            match query.fetch_all(&mut *conn).await {
199                Ok(rows) => Ok(rows.into_iter().map(|r| r.into()).collect()),
200                Err(err) => Err(sqlx_error_to_query_err(err)),
201            }
202        })
203    }
204
205    /// Stream the results of executing a SQL query
206    #[instrument(level = "trace")]
207    pub async fn stream(&self, stmt: Statement) -> Result<QueryStream, DbErr> {
208        debug_print!("{}", stmt);
209
210        let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
211        Ok(QueryStream::from((
212            conn,
213            stmt,
214            self.metric_callback.clone(),
215        )))
216    }
217
218    /// Bundle a set of SQL statements that execute together.
219    #[instrument(level = "trace")]
220    pub async fn begin(
221        &self,
222        isolation_level: Option<IsolationLevel>,
223        access_mode: Option<AccessMode>,
224        sqlite_transaction_mode: Option<SqliteTransactionMode>,
225    ) -> Result<DatabaseTransaction, DbErr> {
226        let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
227        DatabaseTransaction::new_sqlite(
228            conn,
229            self.metric_callback.clone(),
230            isolation_level,
231            access_mode,
232            sqlite_transaction_mode,
233        )
234        .await
235    }
236
237    /// Create a SQLite transaction
238    #[instrument(level = "trace", skip(callback))]
239    pub async fn transaction<F, T, E>(
240        &self,
241        callback: F,
242        isolation_level: Option<IsolationLevel>,
243        access_mode: Option<AccessMode>,
244    ) -> Result<T, TransactionError<E>>
245    where
246        F: for<'b> FnOnce(
247                &'b DatabaseTransaction,
248            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'b>>
249            + Send,
250        T: Send,
251        E: std::fmt::Display + std::fmt::Debug + Send,
252    {
253        let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
254        let transaction = DatabaseTransaction::new_sqlite(
255            conn,
256            self.metric_callback.clone(),
257            isolation_level,
258            access_mode,
259            None,
260        )
261        .await
262        .map_err(|e| TransactionError::Connection(e))?;
263        transaction.run(callback).await
264    }
265
266    pub(crate) fn set_metric_callback<F>(&mut self, callback: F)
267    where
268        F: Fn(&crate::metric::Info<'_>) + Send + Sync + 'static,
269    {
270        self.metric_callback = Some(Arc::new(callback));
271    }
272
273    /// Checks if a connection to the database is still valid.
274    pub async fn ping(&self) -> Result<(), DbErr> {
275        let conn = &mut self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
276        match conn.ping().await {
277            Ok(_) => Ok(()),
278            Err(err) => Err(sqlx_error_to_conn_err(err)),
279        }
280    }
281
282    /// Explicitly close the SQLite connection.
283    /// See [`Self::close_by_ref`] for usage with references.
284    pub async fn close(self) -> Result<(), DbErr> {
285        self.close_by_ref().await
286    }
287
288    /// Explicitly close the SQLite connection
289    pub async fn close_by_ref(&self) -> Result<(), DbErr> {
290        self.pool.close().await;
291        Ok(())
292    }
293}
294
295impl From<SqliteRow> for QueryResult {
296    fn from(row: SqliteRow) -> QueryResult {
297        QueryResult {
298            row: QueryResultRow::SqlxSqlite(row),
299        }
300    }
301}
302
303impl From<SqliteQueryResult> for ExecResult {
304    fn from(result: SqliteQueryResult) -> ExecResult {
305        ExecResult {
306            result: ExecResultHolder::SqlxSqlite(result),
307        }
308    }
309}
310
311pub(crate) fn sqlx_query(stmt: &Statement) -> sqlx::query::Query<'_, Sqlite, SqlxValues> {
312    let values = stmt
313        .values
314        .as_ref()
315        .map_or(Values(Vec::new()), |values| values.clone());
316    sqlx::query_with(&stmt.sql, SqlxValues(values))
317}
318
319pub(crate) async fn set_transaction_config(
320    _conn: &mut PoolConnection<Sqlite>,
321    isolation_level: Option<IsolationLevel>,
322    access_mode: Option<AccessMode>,
323) -> Result<(), DbErr> {
324    if isolation_level.is_some() {
325        warn!("Setting isolation level in a SQLite transaction isn't supported");
326    }
327    if access_mode.is_some() {
328        warn!("Setting access mode in a SQLite transaction isn't supported");
329    }
330    Ok(())
331}
332
/// Reads the SQLite library version via `SELECT sqlite_version()`.
///
/// # Errors
///
/// Returns [`DbErr::Conn`] if the query yields no row. Also propagates query or decoding
/// errors.
#[cfg(feature = "sqlite-use-returning-for-3_35")]
async fn get_version(conn: &SqlxSqlitePoolConnection) -> Result<String, DbErr> {
    let stmt = Statement {
        sql: "SELECT sqlite_version()".to_string(),
        values: None,
        db_backend: crate::DbBackend::Sqlite,
    };
    match conn.query_one(stmt).await? {
        Some(row) => row.try_get_by(0),
        None => Err(DbErr::Conn(RuntimeErr::Internal(
            "Error reading SQLite version".to_string(),
        ))),
    }
}
349
350impl
351    From<(
352        PoolConnection<sqlx::Sqlite>,
353        Statement,
354        Option<crate::metric::Callback>,
355    )> for crate::QueryStream
356{
357    fn from(
358        (conn, stmt, metric_callback): (
359            PoolConnection<sqlx::Sqlite>,
360            Statement,
361            Option<crate::metric::Callback>,
362        ),
363    ) -> Self {
364        crate::QueryStream::build(stmt, crate::InnerConnection::Sqlite(conn), metric_callback)
365    }
366}
367
368impl crate::DatabaseTransaction {
369    pub(crate) async fn new_sqlite(
370        inner: PoolConnection<sqlx::Sqlite>,
371        metric_callback: Option<crate::metric::Callback>,
372        isolation_level: Option<IsolationLevel>,
373        access_mode: Option<AccessMode>,
374        sqlite_transaction_mode: Option<SqliteTransactionMode>,
375    ) -> Result<crate::DatabaseTransaction, DbErr> {
376        Self::begin(
377            Arc::new(Mutex::new(crate::InnerConnection::Sqlite(inner))),
378            crate::DbBackend::Sqlite,
379            metric_callback,
380            isolation_level,
381            access_mode,
382            sqlite_transaction_mode,
383        )
384        .await
385    }
386}
387
/// Converts a sqlx SQLite row into a [`crate::ProxyRow`] of `(column name, Value)` pairs.
///
/// Each cell is decoded according to the type name sqlx reports for its column.
///
/// # Panics
///
/// Panics if a cell cannot be decoded as the expected type. Also panics (via `unreachable!`)
/// on any type name not handled below, including the date/time names when neither
/// `with-chrono` nor `with-time` is enabled.
#[cfg(feature = "proxy")]
pub(crate) fn from_sqlx_sqlite_row_to_proxy_row(row: &sqlx::sqlite::SqliteRow) -> crate::ProxyRow {
    // Type names reported by sqlx-sqlite:
    // https://docs.rs/sqlx-sqlite/0.7.2/src/sqlx_sqlite/type_info.rs.html
    // https://docs.rs/sqlx-sqlite/0.7.2/sqlx_sqlite/types/index.html
    use sea_query::Value;
    use sqlx::{Column, Row, TypeInfo};

    // Decodes the cell at `idx` according to the column's reported type name.
    // NOTE(review): "INTEGER" is decoded as i32 (`Value::Int`). SQLite integers can be 64-bit,
    // so large values may fail or be truncated depending on the sqlx version — confirm.
    // NOTE(review): sqlx-sqlite may also report names such as "NULL" or "NUMERIC", which
    // fall through to `unreachable!`. "BIGINT"/"INT8" may never be reported — confirm.
    let decode = |idx: usize, type_name: &str| -> Value {
        match type_name {
            "BOOLEAN" => Value::Bool(row.try_get(idx).expect("Failed to get boolean")),

            "INTEGER" => Value::Int(row.try_get(idx).expect("Failed to get integer")),

            "BIGINT" | "INT8" => {
                Value::BigInt(row.try_get(idx).expect("Failed to get big integer"))
            }

            "REAL" => Value::Double(row.try_get(idx).expect("Failed to get double")),

            "TEXT" => {
                let text = row
                    .try_get::<Option<String>, _>(idx)
                    .expect("Failed to get string");
                Value::String(text.map(Box::new))
            }

            "BLOB" => {
                let bytes = row
                    .try_get::<Option<Vec<u8>>, _>(idx)
                    .expect("Failed to get bytes");
                Value::Bytes(bytes.map(Box::new))
            }

            #[cfg(feature = "with-chrono")]
            "DATETIME" => {
                use chrono::{DateTime, Utc};
                let stamp = row
                    .try_get::<Option<DateTime<Utc>>, _>(idx)
                    .expect("Failed to get timestamp");
                Value::ChronoDateTimeUtc(stamp.map(Box::new))
            }
            #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
            "DATETIME" => {
                use time::OffsetDateTime;
                let stamp = row
                    .try_get::<Option<OffsetDateTime>, _>(idx)
                    .expect("Failed to get timestamp");
                Value::TimeDateTimeWithTimeZone(stamp.map(Box::new))
            }

            #[cfg(feature = "with-chrono")]
            "DATE" => {
                use chrono::NaiveDate;
                let date = row
                    .try_get::<Option<NaiveDate>, _>(idx)
                    .expect("Failed to get date");
                Value::ChronoDate(date.map(Box::new))
            }
            #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
            "DATE" => {
                use time::Date;
                let date = row
                    .try_get::<Option<Date>, _>(idx)
                    .expect("Failed to get date");
                Value::TimeDate(date.map(Box::new))
            }

            #[cfg(feature = "with-chrono")]
            "TIME" => {
                use chrono::NaiveTime;
                let clock = row
                    .try_get::<Option<NaiveTime>, _>(idx)
                    .expect("Failed to get time");
                Value::ChronoTime(clock.map(Box::new))
            }
            #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
            "TIME" => {
                use time::Time;
                let clock = row
                    .try_get::<Option<Time>, _>(idx)
                    .expect("Failed to get time");
                Value::TimeTime(clock.map(Box::new))
            }

            _ => unreachable!("Unknown column type: {}", type_name),
        }
    };

    let values = row
        .columns()
        .iter()
        .map(|col| {
            let name = col.name().to_string();
            let value = decode(col.ordinal(), col.type_info().name());
            (name, value)
        })
        .collect();

    crate::ProxyRow { values }
}