sea_orm/driver/
sqlx_sqlite.rs

1use futures_util::lock::Mutex;
2use log::LevelFilter;
3use sea_query::Values;
4use std::{future::Future, pin::Pin, sync::Arc};
5
6use sqlx::{
7    Connection, Executor, Sqlite, SqlitePool,
8    pool::PoolConnection,
9    sqlite::{SqliteConnectOptions, SqliteQueryResult, SqliteRow},
10};
11
12use sea_query_sqlx::SqlxValues;
13use tracing::{instrument, warn};
14
15use crate::{
16    AccessMode, ConnectOptions, DatabaseConnection, DatabaseConnectionType, DatabaseTransaction,
17    IsolationLevel, QueryStream, Statement, TransactionError, debug_print, error::*, executor::*,
18    sqlx_error_to_exec_err,
19};
20
21use super::sqlx_common::*;
22
/// Connector for SQLite databases backed by [sqlx::sqlite].
///
/// This is a stateless unit type; its functionality is exposed through
/// associated functions such as `accepts` and `connect`.
#[derive(Debug)]
pub struct SqlxSqliteConnector;
26
27/// Defines a sqlx SQLite pool
28#[derive(Clone)]
29pub struct SqlxSqlitePoolConnection {
30    pub(crate) pool: SqlitePool,
31    metric_callback: Option<crate::metric::Callback>,
32}
33
34impl std::fmt::Debug for SqlxSqlitePoolConnection {
35    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36        write!(f, "SqlxSqlitePoolConnection {{ pool: {:?} }}", self.pool)
37    }
38}
39
40impl From<SqlitePool> for SqlxSqlitePoolConnection {
41    fn from(pool: SqlitePool) -> Self {
42        SqlxSqlitePoolConnection {
43            pool,
44            metric_callback: None,
45        }
46    }
47}
48
49impl From<SqlitePool> for DatabaseConnection {
50    fn from(pool: SqlitePool) -> Self {
51        DatabaseConnectionType::SqlxSqlitePoolConnection(pool.into()).into()
52    }
53}
54
55impl SqlxSqliteConnector {
56    /// Check if the URI provided corresponds to `sqlite:` for a SQLite database
57    pub fn accepts(string: &str) -> bool {
58        string.starts_with("sqlite:") && string.parse::<SqliteConnectOptions>().is_ok()
59    }
60
61    /// Add configuration options for the SQLite database
62    #[instrument(level = "trace")]
63    pub async fn connect(options: ConnectOptions) -> Result<DatabaseConnection, DbErr> {
64        let mut options = options;
65        let mut sqlx_opts = options
66            .url
67            .parse::<SqliteConnectOptions>()
68            .map_err(sqlx_error_to_conn_err)?;
69        if let Some(sqlcipher_key) = &options.sqlcipher_key {
70            sqlx_opts = sqlx_opts.pragma("key", sqlcipher_key.clone());
71        }
72        use sqlx::ConnectOptions;
73        if !options.sqlx_logging {
74            sqlx_opts = sqlx_opts.disable_statement_logging();
75        } else {
76            sqlx_opts = sqlx_opts.log_statements(options.sqlx_logging_level);
77            if options.sqlx_slow_statements_logging_level != LevelFilter::Off {
78                sqlx_opts = sqlx_opts.log_slow_statements(
79                    options.sqlx_slow_statements_logging_level,
80                    options.sqlx_slow_statements_logging_threshold,
81                );
82            }
83        }
84
85        if options.get_max_connections().is_none() {
86            options.max_connections(1);
87        }
88
89        if let Some(f) = &options.sqlite_opts_fn {
90            sqlx_opts = f(sqlx_opts);
91        }
92
93        let after_conn = options.after_connect.clone();
94
95        let pool = if options.connect_lazy {
96            options.sqlx_pool_options().connect_lazy_with(sqlx_opts)
97        } else {
98            options
99                .sqlx_pool_options()
100                .connect_with(sqlx_opts)
101                .await
102                .map_err(sqlx_error_to_conn_err)?
103        };
104
105        let pool = SqlxSqlitePoolConnection {
106            pool,
107            metric_callback: None,
108        };
109
110        #[cfg(feature = "sqlite-use-returning-for-3_35")]
111        {
112            let version = get_version(&pool).await?;
113            super::sqlite::ensure_returning_version(&version)?;
114        }
115
116        let conn: DatabaseConnection =
117            DatabaseConnectionType::SqlxSqlitePoolConnection(pool).into();
118
119        if let Some(cb) = after_conn {
120            cb(conn.clone()).await?;
121        }
122
123        Ok(conn)
124    }
125}
126
127impl SqlxSqliteConnector {
128    /// Instantiate a sqlx pool connection to a [DatabaseConnection]
129    pub fn from_sqlx_sqlite_pool(pool: SqlitePool) -> DatabaseConnection {
130        DatabaseConnectionType::SqlxSqlitePoolConnection(SqlxSqlitePoolConnection {
131            pool,
132            metric_callback: None,
133        })
134        .into()
135    }
136}
137
138impl SqlxSqlitePoolConnection {
139    /// Execute a [Statement] on a SQLite backend
140    #[instrument(level = "trace")]
141    pub async fn execute(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
142        debug_print!("{}", stmt);
143
144        let query = sqlx_query(&stmt);
145        let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
146        crate::metric::metric!(self.metric_callback, &stmt, {
147            match query.execute(&mut *conn).await {
148                Ok(res) => Ok(res.into()),
149                Err(err) => Err(sqlx_error_to_exec_err(err)),
150            }
151        })
152    }
153
154    /// Execute an unprepared SQL statement on a SQLite backend
155    #[instrument(level = "trace")]
156    pub async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
157        debug_print!("{}", sql);
158
159        let conn = &mut self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
160        match conn.execute(sql).await {
161            Ok(res) => Ok(res.into()),
162            Err(err) => Err(sqlx_error_to_exec_err(err)),
163        }
164    }
165
166    /// Get one result from a SQL query. Returns [Option::None] if no match was found
167    #[instrument(level = "trace")]
168    pub async fn query_one(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
169        debug_print!("{}", stmt);
170
171        let query = sqlx_query(&stmt);
172        let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
173        crate::metric::metric!(self.metric_callback, &stmt, {
174            match query.fetch_one(&mut *conn).await {
175                Ok(row) => Ok(Some(row.into())),
176                Err(err) => match err {
177                    sqlx::Error::RowNotFound => Ok(None),
178                    _ => Err(sqlx_error_to_query_err(err)),
179                },
180            }
181        })
182    }
183
184    /// Get the results of a query returning them as a Vec<[QueryResult]>
185    #[instrument(level = "trace")]
186    pub async fn query_all(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
187        debug_print!("{}", stmt);
188
189        let query = sqlx_query(&stmt);
190        let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
191        crate::metric::metric!(self.metric_callback, &stmt, {
192            match query.fetch_all(&mut *conn).await {
193                Ok(rows) => Ok(rows.into_iter().map(|r| r.into()).collect()),
194                Err(err) => Err(sqlx_error_to_query_err(err)),
195            }
196        })
197    }
198
199    /// Stream the results of executing a SQL query
200    #[instrument(level = "trace")]
201    pub async fn stream(&self, stmt: Statement) -> Result<QueryStream, DbErr> {
202        debug_print!("{}", stmt);
203
204        let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
205        Ok(QueryStream::from((
206            conn,
207            stmt,
208            self.metric_callback.clone(),
209        )))
210    }
211
212    /// Bundle a set of SQL statements that execute together.
213    #[instrument(level = "trace")]
214    pub async fn begin(
215        &self,
216        isolation_level: Option<IsolationLevel>,
217        access_mode: Option<AccessMode>,
218    ) -> Result<DatabaseTransaction, DbErr> {
219        let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
220        DatabaseTransaction::new_sqlite(
221            conn,
222            self.metric_callback.clone(),
223            isolation_level,
224            access_mode,
225        )
226        .await
227    }
228
229    /// Create a SQLite transaction
230    #[instrument(level = "trace", skip(callback))]
231    pub async fn transaction<F, T, E>(
232        &self,
233        callback: F,
234        isolation_level: Option<IsolationLevel>,
235        access_mode: Option<AccessMode>,
236    ) -> Result<T, TransactionError<E>>
237    where
238        F: for<'b> FnOnce(
239                &'b DatabaseTransaction,
240            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'b>>
241            + Send,
242        T: Send,
243        E: std::fmt::Display + std::fmt::Debug + Send,
244    {
245        let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
246        let transaction = DatabaseTransaction::new_sqlite(
247            conn,
248            self.metric_callback.clone(),
249            isolation_level,
250            access_mode,
251        )
252        .await
253        .map_err(|e| TransactionError::Connection(e))?;
254        transaction.run(callback).await
255    }
256
257    pub(crate) fn set_metric_callback<F>(&mut self, callback: F)
258    where
259        F: Fn(&crate::metric::Info<'_>) + Send + Sync + 'static,
260    {
261        self.metric_callback = Some(Arc::new(callback));
262    }
263
264    /// Checks if a connection to the database is still valid.
265    pub async fn ping(&self) -> Result<(), DbErr> {
266        let conn = &mut self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
267        match conn.ping().await {
268            Ok(_) => Ok(()),
269            Err(err) => Err(sqlx_error_to_conn_err(err)),
270        }
271    }
272
273    /// Explicitly close the SQLite connection.
274    /// See [`Self::close_by_ref`] for usage with references.
275    pub async fn close(self) -> Result<(), DbErr> {
276        self.close_by_ref().await
277    }
278
279    /// Explicitly close the SQLite connection
280    pub async fn close_by_ref(&self) -> Result<(), DbErr> {
281        self.pool.close().await;
282        Ok(())
283    }
284}
285
286impl From<SqliteRow> for QueryResult {
287    fn from(row: SqliteRow) -> QueryResult {
288        QueryResult {
289            row: QueryResultRow::SqlxSqlite(row),
290        }
291    }
292}
293
294impl From<SqliteQueryResult> for ExecResult {
295    fn from(result: SqliteQueryResult) -> ExecResult {
296        ExecResult {
297            result: ExecResultHolder::SqlxSqlite(result),
298        }
299    }
300}
301
302pub(crate) fn sqlx_query(stmt: &Statement) -> sqlx::query::Query<'_, Sqlite, SqlxValues> {
303    let values = stmt
304        .values
305        .as_ref()
306        .map_or(Values(Vec::new()), |values| values.clone());
307    sqlx::query_with(&stmt.sql, SqlxValues(values))
308}
309
310pub(crate) async fn set_transaction_config(
311    _conn: &mut PoolConnection<Sqlite>,
312    isolation_level: Option<IsolationLevel>,
313    access_mode: Option<AccessMode>,
314) -> Result<(), DbErr> {
315    if isolation_level.is_some() {
316        warn!("Setting isolation level in a SQLite transaction isn't supported");
317    }
318    if access_mode.is_some() {
319        warn!("Setting access mode in a SQLite transaction isn't supported");
320    }
321    Ok(())
322}
323
/// Query the SQLite library version via `SELECT sqlite_version()`.
///
/// # Errors
///
/// Returns a [DbErr] when the query fails, when it yields no row, or when
/// the first column cannot be read as a `String`.
#[cfg(feature = "sqlite-use-returning-for-3_35")]
async fn get_version(conn: &SqlxSqlitePoolConnection) -> Result<String, DbErr> {
    let version_query = Statement {
        sql: String::from("SELECT sqlite_version()"),
        values: None,
        db_backend: crate::DbBackend::Sqlite,
    };

    match conn.query_one(version_query).await? {
        Some(row) => row.try_get_by(0),
        None => Err(DbErr::Conn(RuntimeErr::Internal(
            "Error reading SQLite version".to_string(),
        ))),
    }
}
340
341impl
342    From<(
343        PoolConnection<sqlx::Sqlite>,
344        Statement,
345        Option<crate::metric::Callback>,
346    )> for crate::QueryStream
347{
348    fn from(
349        (conn, stmt, metric_callback): (
350            PoolConnection<sqlx::Sqlite>,
351            Statement,
352            Option<crate::metric::Callback>,
353        ),
354    ) -> Self {
355        crate::QueryStream::build(stmt, crate::InnerConnection::Sqlite(conn), metric_callback)
356    }
357}
358
359impl crate::DatabaseTransaction {
360    pub(crate) async fn new_sqlite(
361        inner: PoolConnection<sqlx::Sqlite>,
362        metric_callback: Option<crate::metric::Callback>,
363        isolation_level: Option<IsolationLevel>,
364        access_mode: Option<AccessMode>,
365    ) -> Result<crate::DatabaseTransaction, DbErr> {
366        Self::begin(
367            Arc::new(Mutex::new(crate::InnerConnection::Sqlite(inner))),
368            crate::DbBackend::Sqlite,
369            metric_callback,
370            isolation_level,
371            access_mode,
372        )
373        .await
374    }
375}
376
/// Convert a sqlx SQLite row into a [`crate::ProxyRow`].
///
/// Each column becomes a `(name, value)` entry, with the `Value` variant
/// chosen from the type name the column reports.
///
/// # Panics
///
/// Panics when a column cannot be decoded as the type picked for it, and
/// when a column reports a type name that has no arm below. The date and
/// time arms exist only when the `with-chrono` or `with-time` feature is
/// enabled.
#[cfg(feature = "proxy")]
pub(crate) fn from_sqlx_sqlite_row_to_proxy_row(row: &sqlx::sqlite::SqliteRow) -> crate::ProxyRow {
    // https://docs.rs/sqlx-sqlite/0.7.2/src/sqlx_sqlite/type_info.rs.html
    // https://docs.rs/sqlx-sqlite/0.7.2/sqlx_sqlite/types/index.html
    use sea_query::Value;
    use sqlx::{Column, Row, TypeInfo};

    let values = row
        .columns()
        .iter()
        .map(|column| {
            let name = column.name().to_string();
            let index = column.ordinal();

            let value = match column.type_info().name() {
                "BOOLEAN" => Value::Bool(row.try_get(index).expect("Failed to get boolean")),

                "INTEGER" => Value::Int(row.try_get(index).expect("Failed to get integer")),

                "BIGINT" | "INT8" => {
                    Value::BigInt(row.try_get(index).expect("Failed to get big integer"))
                }

                "REAL" => Value::Double(row.try_get(index).expect("Failed to get double")),

                "TEXT" => {
                    let text = row
                        .try_get::<Option<String>, _>(index)
                        .expect("Failed to get string");
                    Value::String(text.map(Box::new))
                }

                "BLOB" => {
                    let bytes = row
                        .try_get::<Option<Vec<u8>>, _>(index)
                        .expect("Failed to get bytes");
                    Value::Bytes(bytes.map(Box::new))
                }

                #[cfg(feature = "with-chrono")]
                "DATETIME" => {
                    use chrono::{DateTime, Utc};

                    let timestamp = row
                        .try_get::<Option<DateTime<Utc>>, _>(index)
                        .expect("Failed to get timestamp");
                    Value::ChronoDateTimeUtc(timestamp.map(Box::new))
                }
                #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                "DATETIME" => {
                    use time::OffsetDateTime;

                    let timestamp = row
                        .try_get::<Option<OffsetDateTime>, _>(index)
                        .expect("Failed to get timestamp");
                    Value::TimeDateTimeWithTimeZone(timestamp.map(Box::new))
                }

                #[cfg(feature = "with-chrono")]
                "DATE" => {
                    use chrono::NaiveDate;

                    let date = row
                        .try_get::<Option<NaiveDate>, _>(index)
                        .expect("Failed to get date");
                    Value::ChronoDate(date.map(Box::new))
                }
                #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                "DATE" => {
                    use time::Date;

                    let date = row
                        .try_get::<Option<Date>, _>(index)
                        .expect("Failed to get date");
                    Value::TimeDate(date.map(Box::new))
                }

                #[cfg(feature = "with-chrono")]
                "TIME" => {
                    use chrono::NaiveTime;

                    let time_of_day = row
                        .try_get::<Option<NaiveTime>, _>(index)
                        .expect("Failed to get time");
                    Value::ChronoTime(time_of_day.map(Box::new))
                }
                #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                "TIME" => {
                    use time::Time;

                    let time_of_day = row
                        .try_get::<Option<Time>, _>(index)
                        .expect("Failed to get time");
                    Value::TimeTime(time_of_day.map(Box::new))
                }

                other => unreachable!("Unknown column type: {}", other),
            };

            (name, value)
        })
        .collect();

    crate::ProxyRow { values }
}