// Source: sea_orm/driver/sqlx_sqlite.rs

1use futures_util::lock::Mutex;
2use log::LevelFilter;
3use sea_query::Values;
4use std::{future::Future, pin::Pin, sync::Arc};
5
6use sqlx::{
7    Connection, Executor, Sqlite, SqlitePool,
8    pool::PoolConnection,
9    sqlite::{SqliteConnectOptions, SqliteQueryResult, SqliteRow},
10};
11
12use sea_query_sqlx::SqlxValues;
13use tracing::{instrument, warn};
14
15use crate::{
16    AccessMode, ConnectOptions, DatabaseConnection, DatabaseConnectionType, DatabaseTransaction,
17    IsolationLevel, SqliteTransactionMode, Statement, TransactionError, debug_print, error::*,
18    executor::*, sqlx_error_to_exec_err,
19};
20
21use super::sqlx_common::*;
22
23#[cfg(feature = "stream")]
24use crate::QueryStream;
25
/// Defines the [sqlx::sqlite] connector
///
/// This is a zero-sized namespace type: all of its functionality is exposed
/// through associated functions (`accepts`, `connect`, `from_sqlx_sqlite_pool`).
#[derive(Debug, Clone, Copy, Default)]
pub struct SqlxSqliteConnector;
29
30/// Defines a sqlx SQLite pool
31#[derive(Clone)]
32pub struct SqlxSqlitePoolConnection {
33    pub(crate) pool: SqlitePool,
34    metric_callback: Option<crate::metric::Callback>,
35    pub(crate) record_stmt_in_spans: bool,
36}
37
38impl std::fmt::Debug for SqlxSqlitePoolConnection {
39    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
40        write!(f, "SqlxSqlitePoolConnection {{ pool: {:?} }}", self.pool)
41    }
42}
43
44impl From<SqlitePool> for SqlxSqlitePoolConnection {
45    fn from(pool: SqlitePool) -> Self {
46        SqlxSqlitePoolConnection {
47            pool,
48            metric_callback: None,
49            record_stmt_in_spans: true,
50        }
51    }
52}
53
54impl From<SqlitePool> for DatabaseConnection {
55    fn from(pool: SqlitePool) -> Self {
56        DatabaseConnectionType::SqlxSqlitePoolConnection(pool.into()).into()
57    }
58}
59
60impl SqlxSqliteConnector {
61    /// Check if the URI provided corresponds to `sqlite:` for a SQLite database
62    pub fn accepts(string: &str) -> bool {
63        string.starts_with("sqlite:") && string.parse::<SqliteConnectOptions>().is_ok()
64    }
65
66    /// Add configuration options for the SQLite database
67    #[instrument(level = "trace")]
68    pub async fn connect(options: ConnectOptions) -> Result<DatabaseConnection, DbErr> {
69        let mut options = options;
70        let record_stmt_in_spans = options.get_record_stmt_in_spans();
71        let mut sqlx_opts = options
72            .url
73            .parse::<SqliteConnectOptions>()
74            .map_err(sqlx_error_to_conn_err)?;
75        if let Some(sqlcipher_key) = &options.sqlcipher_key {
76            sqlx_opts = sqlx_opts.pragma("key", sqlcipher_key.clone());
77        }
78        use sqlx::ConnectOptions;
79        if !options.sqlx_logging {
80            sqlx_opts = sqlx_opts.disable_statement_logging();
81        } else {
82            sqlx_opts = sqlx_opts.log_statements(options.sqlx_logging_level);
83            if options.sqlx_slow_statements_logging_level != LevelFilter::Off {
84                sqlx_opts = sqlx_opts.log_slow_statements(
85                    options.sqlx_slow_statements_logging_level,
86                    options.sqlx_slow_statements_logging_threshold,
87                );
88            }
89        }
90
91        if options.get_max_connections().is_none() {
92            options.max_connections(1);
93        }
94
95        if let Some(f) = &options.sqlite_opts_fn {
96            sqlx_opts = f(sqlx_opts);
97        }
98
99        let after_conn = options.after_connect.clone();
100        let connect_lazy = options.connect_lazy;
101        let sqlite_pool_opts_fn = options.sqlite_pool_opts_fn.clone();
102        let mut pool_options = options.sqlx_pool_options();
103
104        if let Some(f) = &sqlite_pool_opts_fn {
105            pool_options = f(pool_options);
106        }
107
108        let pool = if connect_lazy {
109            pool_options.connect_lazy_with(sqlx_opts)
110        } else {
111            pool_options
112                .connect_with(sqlx_opts)
113                .await
114                .map_err(sqlx_error_to_conn_err)?
115        };
116
117        let pool = SqlxSqlitePoolConnection {
118            pool,
119            metric_callback: None,
120            record_stmt_in_spans,
121        };
122
123        #[cfg(feature = "sqlite-use-returning-for-3_35")]
124        {
125            let version = get_version(&pool).await?;
126            super::sqlite::ensure_returning_version(&version)?;
127        }
128
129        let conn: DatabaseConnection =
130            DatabaseConnectionType::SqlxSqlitePoolConnection(pool).into();
131
132        if let Some(cb) = after_conn {
133            cb(conn.clone()).await?;
134        }
135
136        Ok(conn)
137    }
138}
139
140impl SqlxSqliteConnector {
141    /// Instantiate a sqlx pool connection to a [DatabaseConnection]
142    pub fn from_sqlx_sqlite_pool(pool: SqlitePool) -> DatabaseConnection {
143        DatabaseConnectionType::SqlxSqlitePoolConnection(SqlxSqlitePoolConnection {
144            pool,
145            metric_callback: None,
146            record_stmt_in_spans: true,
147        })
148        .into()
149    }
150}
151
152impl SqlxSqlitePoolConnection {
153    /// Execute a [Statement] on a SQLite backend
154    #[instrument(level = "trace", skip(stmt))]
155    pub async fn execute(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
156        debug_print!("{}", stmt);
157
158        let query = sqlx_query(&stmt);
159        let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
160        crate::metric::metric!(self.metric_callback, &stmt, {
161            match query.execute(&mut *conn).await {
162                Ok(res) => Ok(res.into()),
163                Err(err) => Err(sqlx_error_to_exec_err(err)),
164            }
165        })
166    }
167
168    /// Execute an unprepared SQL statement on a SQLite backend
169    #[instrument(level = "trace", skip(sql))]
170    pub async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
171        debug_print!("{}", sql);
172
173        let conn = &mut self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
174        match conn.execute(sqlx::AssertSqlSafe(sql.to_owned())).await {
175            Ok(res) => Ok(res.into()),
176            Err(err) => Err(sqlx_error_to_exec_err(err)),
177        }
178    }
179
180    /// Get one result from a SQL query. Returns [Option::None] if no match was found
181    #[instrument(level = "trace", skip(stmt))]
182    pub async fn query_one(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
183        debug_print!("{}", stmt);
184
185        let query = sqlx_query(&stmt);
186        let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
187        crate::metric::metric!(self.metric_callback, &stmt, {
188            match query.fetch_one(&mut *conn).await {
189                Ok(row) => Ok(Some(row.into())),
190                Err(err) => match err {
191                    sqlx::Error::RowNotFound => Ok(None),
192                    _ => Err(sqlx_error_to_query_err(err)),
193                },
194            }
195        })
196    }
197
198    /// Get the results of a query returning them as a Vec<[QueryResult]>
199    #[instrument(level = "trace", skip(stmt))]
200    pub async fn query_all(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
201        debug_print!("{}", stmt);
202
203        let query = sqlx_query(&stmt);
204        let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
205        crate::metric::metric!(self.metric_callback, &stmt, {
206            match query.fetch_all(&mut *conn).await {
207                Ok(rows) => Ok(rows.into_iter().map(|r| r.into()).collect()),
208                Err(err) => Err(sqlx_error_to_query_err(err)),
209            }
210        })
211    }
212
213    /// Stream the results of executing a SQL query
214    #[instrument(level = "trace", skip(stmt))]
215    #[cfg(feature = "stream")]
216    pub async fn stream(&self, stmt: Statement) -> Result<QueryStream, DbErr> {
217        debug_print!("{}", stmt);
218
219        let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
220        Ok(QueryStream::from((
221            conn,
222            stmt,
223            self.metric_callback.clone(),
224        )))
225    }
226
227    /// Bundle a set of SQL statements that execute together.
228    #[instrument(level = "trace")]
229    pub async fn begin(
230        &self,
231        isolation_level: Option<IsolationLevel>,
232        access_mode: Option<AccessMode>,
233        sqlite_transaction_mode: Option<SqliteTransactionMode>,
234    ) -> Result<DatabaseTransaction, DbErr> {
235        let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
236        DatabaseTransaction::new_sqlite(
237            conn,
238            self.metric_callback.clone(),
239            self.record_stmt_in_spans,
240            isolation_level,
241            access_mode,
242            sqlite_transaction_mode,
243        )
244        .await
245    }
246
247    /// Create a SQLite transaction
248    #[instrument(level = "trace", skip(callback))]
249    pub async fn transaction<F, T, E>(
250        &self,
251        callback: F,
252        isolation_level: Option<IsolationLevel>,
253        access_mode: Option<AccessMode>,
254    ) -> Result<T, TransactionError<E>>
255    where
256        F: for<'b> FnOnce(
257                &'b DatabaseTransaction,
258            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'b>>
259            + Send,
260        T: Send,
261        E: std::fmt::Display + std::fmt::Debug + Send,
262    {
263        let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
264        let transaction = DatabaseTransaction::new_sqlite(
265            conn,
266            self.metric_callback.clone(),
267            self.record_stmt_in_spans,
268            isolation_level,
269            access_mode,
270            None,
271        )
272        .await
273        .map_err(|e| TransactionError::Connection(e))?;
274        transaction.run(callback).await
275    }
276
277    pub(crate) fn set_metric_callback<F>(&mut self, callback: F)
278    where
279        F: Fn(&crate::metric::Info<'_>) + Send + Sync + 'static,
280    {
281        self.metric_callback = Some(Arc::new(callback));
282    }
283
284    /// Checks if a connection to the database is still valid.
285    pub async fn ping(&self) -> Result<(), DbErr> {
286        let conn = &mut self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
287        match conn.ping().await {
288            Ok(_) => Ok(()),
289            Err(err) => Err(sqlx_error_to_conn_err(err)),
290        }
291    }
292
293    /// Explicitly close the SQLite connection.
294    /// See [`Self::close_by_ref`] for usage with references.
295    pub async fn close(self) -> Result<(), DbErr> {
296        self.close_by_ref().await
297    }
298
299    /// Explicitly close the SQLite connection
300    pub async fn close_by_ref(&self) -> Result<(), DbErr> {
301        self.pool.close().await;
302        Ok(())
303    }
304}
305
306impl From<SqliteRow> for QueryResult {
307    fn from(row: SqliteRow) -> QueryResult {
308        QueryResult {
309            row: QueryResultRow::SqlxSqlite(row),
310        }
311    }
312}
313
314impl From<SqliteQueryResult> for ExecResult {
315    fn from(result: SqliteQueryResult) -> ExecResult {
316        ExecResult {
317            result: ExecResultHolder::SqlxSqlite(result),
318        }
319    }
320}
321
322pub(crate) fn sqlx_query(stmt: &Statement) -> sqlx::query::Query<'_, Sqlite, SqlxValues> {
323    let values = stmt
324        .values
325        .as_ref()
326        .map_or(Values(Vec::new()), |values| values.clone());
327    sqlx::query_with(sqlx::AssertSqlSafe(stmt.sql.as_str()), SqlxValues(values))
328}
329
330pub(crate) async fn set_transaction_config(
331    _conn: &mut PoolConnection<Sqlite>,
332    isolation_level: Option<IsolationLevel>,
333    access_mode: Option<AccessMode>,
334) -> Result<(), DbErr> {
335    if isolation_level.is_some() {
336        warn!("Setting isolation level in a SQLite transaction isn't supported");
337    }
338    if access_mode.is_some() {
339        warn!("Setting access mode in a SQLite transaction isn't supported");
340    }
341    Ok(())
342}
343
/// Query `SELECT sqlite_version()` over `conn` and return the version string.
///
/// # Errors
/// Returns `DbErr::Conn(RuntimeErr::Internal(..))` when the query yields no row.
/// Query and decoding errors are propagated.
#[cfg(feature = "sqlite-use-returning-for-3_35")]
async fn get_version(conn: &SqlxSqlitePoolConnection) -> Result<String, DbErr> {
    let version_query = Statement {
        sql: "SELECT sqlite_version()".to_string(),
        values: None,
        db_backend: crate::DbBackend::Sqlite,
    };
    match conn.query_one(version_query).await? {
        Some(row) => row.try_get_by(0),
        None => Err(DbErr::Conn(RuntimeErr::Internal(
            "Error reading SQLite version".to_string(),
        ))),
    }
}
360
#[cfg(feature = "stream")]
impl
    From<(
        PoolConnection<sqlx::Sqlite>,
        Statement,
        Option<crate::metric::Callback>,
    )> for crate::QueryStream
{
    /// Build a [crate::QueryStream] that takes ownership of the pooled connection.
    ///
    /// The tuple is unpacked, the connection is wrapped as
    /// `InnerConnection::Sqlite`, and everything is passed to `QueryStream::build`.
    fn from(
        parts: (
            PoolConnection<sqlx::Sqlite>,
            Statement,
            Option<crate::metric::Callback>,
        ),
    ) -> Self {
        let (conn, stmt, metric_callback) = parts;
        let inner = crate::InnerConnection::Sqlite(conn);
        crate::QueryStream::build(stmt, inner, metric_callback)
    }
}
374        ),
375    ) -> Self {
376        crate::QueryStream::build(stmt, crate::InnerConnection::Sqlite(conn), metric_callback)
377    }
378}
379
380impl crate::DatabaseTransaction {
381    pub(crate) async fn new_sqlite(
382        inner: PoolConnection<sqlx::Sqlite>,
383        metric_callback: Option<crate::metric::Callback>,
384        record_stmt_in_spans: bool,
385        isolation_level: Option<IsolationLevel>,
386        access_mode: Option<AccessMode>,
387        sqlite_transaction_mode: Option<SqliteTransactionMode>,
388    ) -> Result<crate::DatabaseTransaction, DbErr> {
389        Self::begin(
390            Arc::new(Mutex::new(crate::InnerConnection::Sqlite(inner))),
391            crate::DbBackend::Sqlite,
392            metric_callback,
393            record_stmt_in_spans,
394            isolation_level,
395            access_mode,
396            sqlite_transaction_mode,
397        )
398        .await
399    }
400}
401
402#[cfg(feature = "proxy")]
403pub(crate) fn from_sqlx_sqlite_row_to_proxy_row(row: &sqlx::sqlite::SqliteRow) -> crate::ProxyRow {
404    // https://docs.rs/sqlx-sqlite/0.7.2/src/sqlx_sqlite/type_info.rs.html
405    // https://docs.rs/sqlx-sqlite/0.7.2/sqlx_sqlite/types/index.html
406    use sea_query::Value;
407    use sqlx::{Column, Row, TypeInfo};
408    crate::ProxyRow {
409        values: row
410            .columns()
411            .iter()
412            .map(|c| {
413                (
414                    c.name().to_string(),
415                    match c.type_info().name() {
416                        "BOOLEAN" => {
417                            Value::Bool(row.try_get(c.ordinal()).expect("Failed to get boolean"))
418                        }
419
420                        "INTEGER" => {
421                            Value::Int(row.try_get(c.ordinal()).expect("Failed to get integer"))
422                        }
423
424                        "BIGINT" | "INT8" => Value::BigInt(
425                            row.try_get(c.ordinal()).expect("Failed to get big integer"),
426                        ),
427
428                        "REAL" => {
429                            Value::Double(row.try_get(c.ordinal()).expect("Failed to get double"))
430                        }
431
432                        "TEXT" => Value::String(
433                            row.try_get::<Option<String>, _>(c.ordinal())
434                                .expect("Failed to get string")
435                                .map(Box::new),
436                        ),
437
438                        "BLOB" => Value::Bytes(
439                            row.try_get::<Option<Vec<u8>>, _>(c.ordinal())
440                                .expect("Failed to get bytes")
441                                .map(Box::new),
442                        ),
443
444                        #[cfg(feature = "with-chrono")]
445                        "DATETIME" => {
446                            use chrono::{DateTime, Utc};
447
448                            Value::ChronoDateTimeUtc(
449                                row.try_get::<Option<DateTime<Utc>>, _>(c.ordinal())
450                                    .expect("Failed to get timestamp")
451                                    .map(Box::new),
452                            )
453                        }
454                        #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
455                        "DATETIME" => {
456                            use time::OffsetDateTime;
457                            Value::TimeDateTimeWithTimeZone(
458                                row.try_get::<Option<OffsetDateTime>, _>(c.ordinal())
459                                    .expect("Failed to get timestamp")
460                                    .map(Box::new),
461                            )
462                        }
463                        #[cfg(feature = "with-chrono")]
464                        "DATE" => {
465                            use chrono::NaiveDate;
466                            Value::ChronoDate(
467                                row.try_get::<Option<NaiveDate>, _>(c.ordinal())
468                                    .expect("Failed to get date")
469                                    .map(Box::new),
470                            )
471                        }
472                        #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
473                        "DATE" => {
474                            use time::Date;
475                            Value::TimeDate(
476                                row.try_get::<Option<Date>, _>(c.ordinal())
477                                    .expect("Failed to get date")
478                                    .map(Box::new),
479                            )
480                        }
481
482                        #[cfg(feature = "with-chrono")]
483                        "TIME" => {
484                            use chrono::NaiveTime;
485                            Value::ChronoTime(
486                                row.try_get::<Option<NaiveTime>, _>(c.ordinal())
487                                    .expect("Failed to get time")
488                                    .map(Box::new),
489                            )
490                        }
491                        #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
492                        "TIME" => {
493                            use time::Time;
494                            Value::TimeTime(
495                                row.try_get::<Option<Time>, _>(c.ordinal())
496                                    .expect("Failed to get time")
497                                    .map(Box::new),
498                            )
499                        }
500
501                        _ => unreachable!("Unknown column type: {}", c.type_info().name()),
502                    },
503                )
504            })
505            .collect(),
506    }
507}