Skip to main content

sea_orm/driver/
sqlx_sqlite.rs

1use futures_util::lock::Mutex;
2use log::LevelFilter;
3use sea_query::Values;
4use std::{future::Future, pin::Pin, sync::Arc};
5
6use sqlx::{
7    Connection, Executor, Sqlite, SqlitePool,
8    pool::PoolConnection,
9    sqlite::{SqliteConnectOptions, SqliteQueryResult, SqliteRow},
10};
11
12use sea_query_sqlx::SqlxValues;
13use tracing::{instrument, warn};
14
15use crate::{
16    AccessMode, ConnectOptions, DatabaseConnection, DatabaseConnectionType, DatabaseTransaction,
17    IsolationLevel, QueryStream, SqliteTransactionMode, Statement, TransactionError, debug_print,
18    error::*, executor::*, sqlx_error_to_exec_err,
19};
20
21use super::sqlx_common::*;
22
/// Entry point for opening SQLite databases through the [sqlx::sqlite] driver.
///
/// This is a stateless marker type; all functionality lives in its associated functions.
#[derive(Debug)]
pub struct SqlxSqliteConnector;
26
27/// Defines a sqlx SQLite pool
28#[derive(Clone)]
29pub struct SqlxSqlitePoolConnection {
30    pub(crate) pool: SqlitePool,
31    metric_callback: Option<crate::metric::Callback>,
32    pub(crate) record_stmt_in_spans: bool,
33}
34
35impl std::fmt::Debug for SqlxSqlitePoolConnection {
36    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
37        write!(f, "SqlxSqlitePoolConnection {{ pool: {:?} }}", self.pool)
38    }
39}
40
41impl From<SqlitePool> for SqlxSqlitePoolConnection {
42    fn from(pool: SqlitePool) -> Self {
43        SqlxSqlitePoolConnection {
44            pool,
45            metric_callback: None,
46            record_stmt_in_spans: true,
47        }
48    }
49}
50
51impl From<SqlitePool> for DatabaseConnection {
52    fn from(pool: SqlitePool) -> Self {
53        DatabaseConnectionType::SqlxSqlitePoolConnection(pool.into()).into()
54    }
55}
56
57impl SqlxSqliteConnector {
58    /// Check if the URI provided corresponds to `sqlite:` for a SQLite database
59    pub fn accepts(string: &str) -> bool {
60        string.starts_with("sqlite:") && string.parse::<SqliteConnectOptions>().is_ok()
61    }
62
63    /// Add configuration options for the SQLite database
64    #[instrument(level = "trace")]
65    pub async fn connect(options: ConnectOptions) -> Result<DatabaseConnection, DbErr> {
66        let mut options = options;
67        let record_stmt_in_spans = options.get_record_stmt_in_spans();
68        let mut sqlx_opts = options
69            .url
70            .parse::<SqliteConnectOptions>()
71            .map_err(sqlx_error_to_conn_err)?;
72        if let Some(sqlcipher_key) = &options.sqlcipher_key {
73            sqlx_opts = sqlx_opts.pragma("key", sqlcipher_key.clone());
74        }
75        use sqlx::ConnectOptions;
76        if !options.sqlx_logging {
77            sqlx_opts = sqlx_opts.disable_statement_logging();
78        } else {
79            sqlx_opts = sqlx_opts.log_statements(options.sqlx_logging_level);
80            if options.sqlx_slow_statements_logging_level != LevelFilter::Off {
81                sqlx_opts = sqlx_opts.log_slow_statements(
82                    options.sqlx_slow_statements_logging_level,
83                    options.sqlx_slow_statements_logging_threshold,
84                );
85            }
86        }
87
88        if options.get_max_connections().is_none() {
89            options.max_connections(1);
90        }
91
92        if let Some(f) = &options.sqlite_opts_fn {
93            sqlx_opts = f(sqlx_opts);
94        }
95
96        let after_conn = options.after_connect.clone();
97        let connect_lazy = options.connect_lazy;
98        let sqlite_pool_opts_fn = options.sqlite_pool_opts_fn.clone();
99        let mut pool_options = options.sqlx_pool_options();
100
101        if let Some(f) = &sqlite_pool_opts_fn {
102            pool_options = f(pool_options);
103        }
104
105        let pool = if connect_lazy {
106            pool_options.connect_lazy_with(sqlx_opts)
107        } else {
108            pool_options
109                .connect_with(sqlx_opts)
110                .await
111                .map_err(sqlx_error_to_conn_err)?
112        };
113
114        let pool = SqlxSqlitePoolConnection {
115            pool,
116            metric_callback: None,
117            record_stmt_in_spans,
118        };
119
120        #[cfg(feature = "sqlite-use-returning-for-3_35")]
121        {
122            let version = get_version(&pool).await?;
123            super::sqlite::ensure_returning_version(&version)?;
124        }
125
126        let conn: DatabaseConnection =
127            DatabaseConnectionType::SqlxSqlitePoolConnection(pool).into();
128
129        if let Some(cb) = after_conn {
130            cb(conn.clone()).await?;
131        }
132
133        Ok(conn)
134    }
135}
136
137impl SqlxSqliteConnector {
138    /// Instantiate a sqlx pool connection to a [DatabaseConnection]
139    pub fn from_sqlx_sqlite_pool(pool: SqlitePool) -> DatabaseConnection {
140        DatabaseConnectionType::SqlxSqlitePoolConnection(SqlxSqlitePoolConnection {
141            pool,
142            metric_callback: None,
143            record_stmt_in_spans: true,
144        })
145        .into()
146    }
147}
148
149impl SqlxSqlitePoolConnection {
150    /// Execute a [Statement] on a SQLite backend
151    #[instrument(level = "trace")]
152    pub async fn execute(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
153        debug_print!("{}", stmt);
154
155        let query = sqlx_query(&stmt);
156        let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
157        crate::metric::metric!(self.metric_callback, &stmt, {
158            match query.execute(&mut *conn).await {
159                Ok(res) => Ok(res.into()),
160                Err(err) => Err(sqlx_error_to_exec_err(err)),
161            }
162        })
163    }
164
165    /// Execute an unprepared SQL statement on a SQLite backend
166    #[instrument(level = "trace")]
167    pub async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
168        debug_print!("{}", sql);
169
170        let conn = &mut self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
171        match conn.execute(sqlx::AssertSqlSafe(sql.to_owned())).await {
172            Ok(res) => Ok(res.into()),
173            Err(err) => Err(sqlx_error_to_exec_err(err)),
174        }
175    }
176
177    /// Get one result from a SQL query. Returns [Option::None] if no match was found
178    #[instrument(level = "trace")]
179    pub async fn query_one(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
180        debug_print!("{}", stmt);
181
182        let query = sqlx_query(&stmt);
183        let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
184        crate::metric::metric!(self.metric_callback, &stmt, {
185            match query.fetch_one(&mut *conn).await {
186                Ok(row) => Ok(Some(row.into())),
187                Err(err) => match err {
188                    sqlx::Error::RowNotFound => Ok(None),
189                    _ => Err(sqlx_error_to_query_err(err)),
190                },
191            }
192        })
193    }
194
195    /// Get the results of a query returning them as a Vec<[QueryResult]>
196    #[instrument(level = "trace")]
197    pub async fn query_all(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
198        debug_print!("{}", stmt);
199
200        let query = sqlx_query(&stmt);
201        let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
202        crate::metric::metric!(self.metric_callback, &stmt, {
203            match query.fetch_all(&mut *conn).await {
204                Ok(rows) => Ok(rows.into_iter().map(|r| r.into()).collect()),
205                Err(err) => Err(sqlx_error_to_query_err(err)),
206            }
207        })
208    }
209
210    /// Stream the results of executing a SQL query
211    #[instrument(level = "trace")]
212    pub async fn stream(&self, stmt: Statement) -> Result<QueryStream, DbErr> {
213        debug_print!("{}", stmt);
214
215        let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
216        Ok(QueryStream::from((
217            conn,
218            stmt,
219            self.metric_callback.clone(),
220        )))
221    }
222
223    /// Bundle a set of SQL statements that execute together.
224    #[instrument(level = "trace")]
225    pub async fn begin(
226        &self,
227        isolation_level: Option<IsolationLevel>,
228        access_mode: Option<AccessMode>,
229        sqlite_transaction_mode: Option<SqliteTransactionMode>,
230    ) -> Result<DatabaseTransaction, DbErr> {
231        let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
232        DatabaseTransaction::new_sqlite(
233            conn,
234            self.metric_callback.clone(),
235            self.record_stmt_in_spans,
236            isolation_level,
237            access_mode,
238            sqlite_transaction_mode,
239        )
240        .await
241    }
242
243    /// Create a SQLite transaction
244    #[instrument(level = "trace", skip(callback))]
245    pub async fn transaction<F, T, E>(
246        &self,
247        callback: F,
248        isolation_level: Option<IsolationLevel>,
249        access_mode: Option<AccessMode>,
250    ) -> Result<T, TransactionError<E>>
251    where
252        F: for<'b> FnOnce(
253                &'b DatabaseTransaction,
254            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'b>>
255            + Send,
256        T: Send,
257        E: std::fmt::Display + std::fmt::Debug + Send,
258    {
259        let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
260        let transaction = DatabaseTransaction::new_sqlite(
261            conn,
262            self.metric_callback.clone(),
263            self.record_stmt_in_spans,
264            isolation_level,
265            access_mode,
266            None,
267        )
268        .await
269        .map_err(|e| TransactionError::Connection(e))?;
270        transaction.run(callback).await
271    }
272
273    pub(crate) fn set_metric_callback<F>(&mut self, callback: F)
274    where
275        F: Fn(&crate::metric::Info<'_>) + Send + Sync + 'static,
276    {
277        self.metric_callback = Some(Arc::new(callback));
278    }
279
280    /// Checks if a connection to the database is still valid.
281    pub async fn ping(&self) -> Result<(), DbErr> {
282        let conn = &mut self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
283        match conn.ping().await {
284            Ok(_) => Ok(()),
285            Err(err) => Err(sqlx_error_to_conn_err(err)),
286        }
287    }
288
289    /// Explicitly close the SQLite connection.
290    /// See [`Self::close_by_ref`] for usage with references.
291    pub async fn close(self) -> Result<(), DbErr> {
292        self.close_by_ref().await
293    }
294
295    /// Explicitly close the SQLite connection
296    pub async fn close_by_ref(&self) -> Result<(), DbErr> {
297        self.pool.close().await;
298        Ok(())
299    }
300}
301
302impl From<SqliteRow> for QueryResult {
303    fn from(row: SqliteRow) -> QueryResult {
304        QueryResult {
305            row: QueryResultRow::SqlxSqlite(row),
306        }
307    }
308}
309
310impl From<SqliteQueryResult> for ExecResult {
311    fn from(result: SqliteQueryResult) -> ExecResult {
312        ExecResult {
313            result: ExecResultHolder::SqlxSqlite(result),
314        }
315    }
316}
317
318pub(crate) fn sqlx_query(stmt: &Statement) -> sqlx::query::Query<'_, Sqlite, SqlxValues> {
319    let values = stmt
320        .values
321        .as_ref()
322        .map_or(Values(Vec::new()), |values| values.clone());
323    sqlx::query_with(sqlx::AssertSqlSafe(stmt.sql.as_str()), SqlxValues(values))
324}
325
326pub(crate) async fn set_transaction_config(
327    _conn: &mut PoolConnection<Sqlite>,
328    isolation_level: Option<IsolationLevel>,
329    access_mode: Option<AccessMode>,
330) -> Result<(), DbErr> {
331    if isolation_level.is_some() {
332        warn!("Setting isolation level in a SQLite transaction isn't supported");
333    }
334    if access_mode.is_some() {
335        warn!("Setting access mode in a SQLite transaction isn't supported");
336    }
337    Ok(())
338}
339
/// Reads the SQLite library version string via `SELECT sqlite_version()`.
///
/// # Errors
/// Propagates query errors, returns a connection error if no row comes back, and
/// propagates any error from decoding column 0 as a `String`.
#[cfg(feature = "sqlite-use-returning-for-3_35")]
async fn get_version(conn: &SqlxSqlitePoolConnection) -> Result<String, DbErr> {
    let version_query = Statement {
        sql: "SELECT sqlite_version()".to_string(),
        values: None,
        db_backend: crate::DbBackend::Sqlite,
    };
    match conn.query_one(version_query).await? {
        Some(row) => row.try_get_by(0),
        None => Err(DbErr::Conn(RuntimeErr::Internal(
            "Error reading SQLite version".to_string(),
        ))),
    }
}
356
357impl
358    From<(
359        PoolConnection<sqlx::Sqlite>,
360        Statement,
361        Option<crate::metric::Callback>,
362    )> for crate::QueryStream
363{
364    fn from(
365        (conn, stmt, metric_callback): (
366            PoolConnection<sqlx::Sqlite>,
367            Statement,
368            Option<crate::metric::Callback>,
369        ),
370    ) -> Self {
371        crate::QueryStream::build(stmt, crate::InnerConnection::Sqlite(conn), metric_callback)
372    }
373}
374
375impl crate::DatabaseTransaction {
376    pub(crate) async fn new_sqlite(
377        inner: PoolConnection<sqlx::Sqlite>,
378        metric_callback: Option<crate::metric::Callback>,
379        record_stmt_in_spans: bool,
380        isolation_level: Option<IsolationLevel>,
381        access_mode: Option<AccessMode>,
382        sqlite_transaction_mode: Option<SqliteTransactionMode>,
383    ) -> Result<crate::DatabaseTransaction, DbErr> {
384        Self::begin(
385            Arc::new(Mutex::new(crate::InnerConnection::Sqlite(inner))),
386            crate::DbBackend::Sqlite,
387            metric_callback,
388            record_stmt_in_spans,
389            isolation_level,
390            access_mode,
391            sqlite_transaction_mode,
392        )
393        .await
394    }
395}
396
/// Converts a sqlx SQLite row into a [crate::ProxyRow] of `(column name, Value)` pairs.
///
/// The `Value` variant is chosen from the column's sqlx type name.
///
/// # Panics
/// - Panics if a column value cannot be decoded as the type implied by its type name.
/// - Panics on any type name not handled below. This includes DATETIME, DATE and TIME
///   when neither the `with-chrono` nor the `with-time` feature is enabled, because those
///   arms are compiled out.
#[cfg(feature = "proxy")]
pub(crate) fn from_sqlx_sqlite_row_to_proxy_row(row: &sqlx::sqlite::SqliteRow) -> crate::ProxyRow {
    // https://docs.rs/sqlx-sqlite/0.7.2/src/sqlx_sqlite/type_info.rs.html
    // https://docs.rs/sqlx-sqlite/0.7.2/sqlx_sqlite/types/index.html
    use sea_query::Value;
    use sqlx::{Column, Row, TypeInfo};
    crate::ProxyRow {
        values: row
            .columns()
            .iter()
            .map(|c| {
                let name = c.name().to_string();
                let idx = c.ordinal();
                let value = match c.type_info().name() {
                    "BOOLEAN" => Value::Bool(row.try_get(idx).expect("Failed to get boolean")),

                    // NOTE(review): SQLite INTEGER columns may hold 64-bit values, but this
                    // arm decodes into `Value::Int` — confirm large values are handled as
                    // intended.
                    "INTEGER" => Value::Int(row.try_get(idx).expect("Failed to get integer")),

                    "BIGINT" | "INT8" => {
                        Value::BigInt(row.try_get(idx).expect("Failed to get big integer"))
                    }

                    "REAL" => Value::Double(row.try_get(idx).expect("Failed to get double")),

                    "TEXT" => Value::String(
                        row.try_get::<Option<String>, _>(idx)
                            .expect("Failed to get string")
                            .map(Box::new),
                    ),

                    "BLOB" => Value::Bytes(
                        row.try_get::<Option<Vec<u8>>, _>(idx)
                            .expect("Failed to get bytes")
                            .map(Box::new),
                    ),

                    #[cfg(feature = "with-chrono")]
                    "DATETIME" => {
                        use chrono::{DateTime, Utc};

                        Value::ChronoDateTimeUtc(
                            row.try_get::<Option<DateTime<Utc>>, _>(idx)
                                .expect("Failed to get timestamp")
                                .map(Box::new),
                        )
                    }
                    #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                    "DATETIME" => {
                        use time::OffsetDateTime;
                        Value::TimeDateTimeWithTimeZone(
                            row.try_get::<Option<OffsetDateTime>, _>(idx)
                                .expect("Failed to get timestamp")
                                .map(Box::new),
                        )
                    }
                    #[cfg(feature = "with-chrono")]
                    "DATE" => {
                        use chrono::NaiveDate;
                        Value::ChronoDate(
                            row.try_get::<Option<NaiveDate>, _>(idx)
                                .expect("Failed to get date")
                                .map(Box::new),
                        )
                    }
                    #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                    "DATE" => {
                        use time::Date;
                        Value::TimeDate(
                            row.try_get::<Option<Date>, _>(idx)
                                .expect("Failed to get date")
                                .map(Box::new),
                        )
                    }

                    #[cfg(feature = "with-chrono")]
                    "TIME" => {
                        use chrono::NaiveTime;
                        Value::ChronoTime(
                            row.try_get::<Option<NaiveTime>, _>(idx)
                                .expect("Failed to get time")
                                .map(Box::new),
                        )
                    }
                    #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                    "TIME" => {
                        use time::Time;
                        Value::TimeTime(
                            row.try_get::<Option<Time>, _>(idx)
                                .expect("Failed to get time")
                                .map(Box::new),
                        )
                    }

                    // This arm is reachable (e.g. date/time columns with both date/time
                    // features disabled), so it is an explicit panic rather than
                    // `unreachable!`.
                    other => panic!("Unknown column type: {}", other),
                };
                (name, value)
            })
            .collect(),
    }
}