1use futures_util::lock::Mutex;
2use log::LevelFilter;
3use sea_query::Values;
4use std::{future::Future, pin::Pin, sync::Arc};
5
6use sqlx::{
7 Connection, Executor, Sqlite, SqlitePool,
8 pool::PoolConnection,
9 sqlite::{SqliteConnectOptions, SqliteQueryResult, SqliteRow},
10};
11
12use sea_query_sqlx::SqlxValues;
13use tracing::{instrument, warn};
14
15use crate::{
16 AccessMode, ConnectOptions, DatabaseConnection, DatabaseConnectionType, DatabaseTransaction,
17 IsolationLevel, QueryStream, SqliteTransactionMode, Statement, TransactionError, debug_print,
18 error::*, executor::*, sqlx_error_to_exec_err,
19};
20
21use super::sqlx_common::*;
22
/// Connector for SQLite databases accessed through `sqlx`.
///
/// This is a unit struct carrying no state; it only namespaces the associated
/// functions (`accepts`, `connect`, `from_sqlx_sqlite_pool`) implemented for it
/// in this file.
#[derive(Debug)]
pub struct SqlxSqliteConnector;
26
/// A `sqlx` SQLite connection pool paired with an optional metric callback.
///
/// `Debug` is implemented by hand for this type (only the pool is printed),
/// which is why it is not derived here.
#[derive(Clone)]
pub struct SqlxSqlitePoolConnection {
    /// The underlying `sqlx` pool every operation acquires its connection from.
    pub(crate) pool: SqlitePool,
    /// Optional callback passed along to the crate's metric machinery;
    /// `None` until `set_metric_callback` is called.
    metric_callback: Option<crate::metric::Callback>,
}
33
34impl std::fmt::Debug for SqlxSqlitePoolConnection {
35 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36 write!(f, "SqlxSqlitePoolConnection {{ pool: {:?} }}", self.pool)
37 }
38}
39
40impl From<SqlitePool> for SqlxSqlitePoolConnection {
41 fn from(pool: SqlitePool) -> Self {
42 SqlxSqlitePoolConnection {
43 pool,
44 metric_callback: None,
45 }
46 }
47}
48
49impl From<SqlitePool> for DatabaseConnection {
50 fn from(pool: SqlitePool) -> Self {
51 DatabaseConnectionType::SqlxSqlitePoolConnection(pool.into()).into()
52 }
53}
54
55impl SqlxSqliteConnector {
56 pub fn accepts(string: &str) -> bool {
58 string.starts_with("sqlite:") && string.parse::<SqliteConnectOptions>().is_ok()
59 }
60
61 #[instrument(level = "trace")]
63 pub async fn connect(options: ConnectOptions) -> Result<DatabaseConnection, DbErr> {
64 let mut options = options;
65 let mut sqlx_opts = options
66 .url
67 .parse::<SqliteConnectOptions>()
68 .map_err(sqlx_error_to_conn_err)?;
69 if let Some(sqlcipher_key) = &options.sqlcipher_key {
70 sqlx_opts = sqlx_opts.pragma("key", sqlcipher_key.clone());
71 }
72 use sqlx::ConnectOptions;
73 if !options.sqlx_logging {
74 sqlx_opts = sqlx_opts.disable_statement_logging();
75 } else {
76 sqlx_opts = sqlx_opts.log_statements(options.sqlx_logging_level);
77 if options.sqlx_slow_statements_logging_level != LevelFilter::Off {
78 sqlx_opts = sqlx_opts.log_slow_statements(
79 options.sqlx_slow_statements_logging_level,
80 options.sqlx_slow_statements_logging_threshold,
81 );
82 }
83 }
84
85 if options.get_max_connections().is_none() {
86 options.max_connections(1);
87 }
88
89 if let Some(f) = &options.sqlite_opts_fn {
90 sqlx_opts = f(sqlx_opts);
91 }
92
93 let after_conn = options.after_connect.clone();
94 let connect_lazy = options.connect_lazy;
95 let sqlite_pool_opts_fn = options.sqlite_pool_opts_fn.clone();
96 let mut pool_options = options.sqlx_pool_options();
97
98 if let Some(f) = &sqlite_pool_opts_fn {
99 pool_options = f(pool_options);
100 }
101
102 let pool = if connect_lazy {
103 pool_options.connect_lazy_with(sqlx_opts)
104 } else {
105 pool_options
106 .connect_with(sqlx_opts)
107 .await
108 .map_err(sqlx_error_to_conn_err)?
109 };
110
111 let pool = SqlxSqlitePoolConnection {
112 pool,
113 metric_callback: None,
114 };
115
116 #[cfg(feature = "sqlite-use-returning-for-3_35")]
117 {
118 let version = get_version(&pool).await?;
119 super::sqlite::ensure_returning_version(&version)?;
120 }
121
122 let conn: DatabaseConnection =
123 DatabaseConnectionType::SqlxSqlitePoolConnection(pool).into();
124
125 if let Some(cb) = after_conn {
126 cb(conn.clone()).await?;
127 }
128
129 Ok(conn)
130 }
131}
132
133impl SqlxSqliteConnector {
134 pub fn from_sqlx_sqlite_pool(pool: SqlitePool) -> DatabaseConnection {
136 DatabaseConnectionType::SqlxSqlitePoolConnection(SqlxSqlitePoolConnection {
137 pool,
138 metric_callback: None,
139 })
140 .into()
141 }
142}
143
144impl SqlxSqlitePoolConnection {
145 #[instrument(level = "trace")]
147 pub async fn execute(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
148 debug_print!("{}", stmt);
149
150 let query = sqlx_query(&stmt);
151 let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
152 crate::metric::metric!(self.metric_callback, &stmt, {
153 match query.execute(&mut *conn).await {
154 Ok(res) => Ok(res.into()),
155 Err(err) => Err(sqlx_error_to_exec_err(err)),
156 }
157 })
158 }
159
160 #[instrument(level = "trace")]
162 pub async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
163 debug_print!("{}", sql);
164
165 let conn = &mut self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
166 match conn.execute(sql).await {
167 Ok(res) => Ok(res.into()),
168 Err(err) => Err(sqlx_error_to_exec_err(err)),
169 }
170 }
171
172 #[instrument(level = "trace")]
174 pub async fn query_one(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
175 debug_print!("{}", stmt);
176
177 let query = sqlx_query(&stmt);
178 let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
179 crate::metric::metric!(self.metric_callback, &stmt, {
180 match query.fetch_one(&mut *conn).await {
181 Ok(row) => Ok(Some(row.into())),
182 Err(err) => match err {
183 sqlx::Error::RowNotFound => Ok(None),
184 _ => Err(sqlx_error_to_query_err(err)),
185 },
186 }
187 })
188 }
189
190 #[instrument(level = "trace")]
192 pub async fn query_all(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
193 debug_print!("{}", stmt);
194
195 let query = sqlx_query(&stmt);
196 let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
197 crate::metric::metric!(self.metric_callback, &stmt, {
198 match query.fetch_all(&mut *conn).await {
199 Ok(rows) => Ok(rows.into_iter().map(|r| r.into()).collect()),
200 Err(err) => Err(sqlx_error_to_query_err(err)),
201 }
202 })
203 }
204
205 #[instrument(level = "trace")]
207 pub async fn stream(&self, stmt: Statement) -> Result<QueryStream, DbErr> {
208 debug_print!("{}", stmt);
209
210 let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
211 Ok(QueryStream::from((
212 conn,
213 stmt,
214 self.metric_callback.clone(),
215 )))
216 }
217
218 #[instrument(level = "trace")]
220 pub async fn begin(
221 &self,
222 isolation_level: Option<IsolationLevel>,
223 access_mode: Option<AccessMode>,
224 sqlite_transaction_mode: Option<SqliteTransactionMode>,
225 ) -> Result<DatabaseTransaction, DbErr> {
226 let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
227 DatabaseTransaction::new_sqlite(
228 conn,
229 self.metric_callback.clone(),
230 isolation_level,
231 access_mode,
232 sqlite_transaction_mode,
233 )
234 .await
235 }
236
237 #[instrument(level = "trace", skip(callback))]
239 pub async fn transaction<F, T, E>(
240 &self,
241 callback: F,
242 isolation_level: Option<IsolationLevel>,
243 access_mode: Option<AccessMode>,
244 ) -> Result<T, TransactionError<E>>
245 where
246 F: for<'b> FnOnce(
247 &'b DatabaseTransaction,
248 ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'b>>
249 + Send,
250 T: Send,
251 E: std::fmt::Display + std::fmt::Debug + Send,
252 {
253 let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
254 let transaction = DatabaseTransaction::new_sqlite(
255 conn,
256 self.metric_callback.clone(),
257 isolation_level,
258 access_mode,
259 None,
260 )
261 .await
262 .map_err(|e| TransactionError::Connection(e))?;
263 transaction.run(callback).await
264 }
265
266 pub(crate) fn set_metric_callback<F>(&mut self, callback: F)
267 where
268 F: Fn(&crate::metric::Info<'_>) + Send + Sync + 'static,
269 {
270 self.metric_callback = Some(Arc::new(callback));
271 }
272
273 pub async fn ping(&self) -> Result<(), DbErr> {
275 let conn = &mut self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
276 match conn.ping().await {
277 Ok(_) => Ok(()),
278 Err(err) => Err(sqlx_error_to_conn_err(err)),
279 }
280 }
281
282 pub async fn close(self) -> Result<(), DbErr> {
285 self.close_by_ref().await
286 }
287
288 pub async fn close_by_ref(&self) -> Result<(), DbErr> {
290 self.pool.close().await;
291 Ok(())
292 }
293}
294
295impl From<SqliteRow> for QueryResult {
296 fn from(row: SqliteRow) -> QueryResult {
297 QueryResult {
298 row: QueryResultRow::SqlxSqlite(row),
299 }
300 }
301}
302
303impl From<SqliteQueryResult> for ExecResult {
304 fn from(result: SqliteQueryResult) -> ExecResult {
305 ExecResult {
306 result: ExecResultHolder::SqlxSqlite(result),
307 }
308 }
309}
310
311pub(crate) fn sqlx_query(stmt: &Statement) -> sqlx::query::Query<'_, Sqlite, SqlxValues> {
312 let values = stmt
313 .values
314 .as_ref()
315 .map_or(Values(Vec::new()), |values| values.clone());
316 sqlx::query_with(&stmt.sql, SqlxValues(values))
317}
318
319pub(crate) async fn set_transaction_config(
320 _conn: &mut PoolConnection<Sqlite>,
321 isolation_level: Option<IsolationLevel>,
322 access_mode: Option<AccessMode>,
323) -> Result<(), DbErr> {
324 if isolation_level.is_some() {
325 warn!("Setting isolation level in a SQLite transaction isn't supported");
326 }
327 if access_mode.is_some() {
328 warn!("Setting access mode in a SQLite transaction isn't supported");
329 }
330 Ok(())
331}
332
/// Queries `SELECT sqlite_version()` through `conn` and returns the first
/// column of the resulting row as a `String`.
///
/// # Errors
///
/// Propagates any error from the query or from decoding the column, and
/// returns a connection-level internal error if the query yields no row.
#[cfg(feature = "sqlite-use-returning-for-3_35")]
async fn get_version(conn: &SqlxSqlitePoolConnection) -> Result<String, DbErr> {
    let stmt = Statement {
        sql: "SELECT sqlite_version()".to_string(),
        values: None,
        db_backend: crate::DbBackend::Sqlite,
    };

    match conn.query_one(stmt).await? {
        Some(row) => row.try_get_by(0),
        None => Err(DbErr::Conn(RuntimeErr::Internal(
            "Error reading SQLite version".to_string(),
        ))),
    }
}
349
350impl
351 From<(
352 PoolConnection<sqlx::Sqlite>,
353 Statement,
354 Option<crate::metric::Callback>,
355 )> for crate::QueryStream
356{
357 fn from(
358 (conn, stmt, metric_callback): (
359 PoolConnection<sqlx::Sqlite>,
360 Statement,
361 Option<crate::metric::Callback>,
362 ),
363 ) -> Self {
364 crate::QueryStream::build(stmt, crate::InnerConnection::Sqlite(conn), metric_callback)
365 }
366}
367
368impl crate::DatabaseTransaction {
369 pub(crate) async fn new_sqlite(
370 inner: PoolConnection<sqlx::Sqlite>,
371 metric_callback: Option<crate::metric::Callback>,
372 isolation_level: Option<IsolationLevel>,
373 access_mode: Option<AccessMode>,
374 sqlite_transaction_mode: Option<SqliteTransactionMode>,
375 ) -> Result<crate::DatabaseTransaction, DbErr> {
376 Self::begin(
377 Arc::new(Mutex::new(crate::InnerConnection::Sqlite(inner))),
378 crate::DbBackend::Sqlite,
379 metric_callback,
380 isolation_level,
381 access_mode,
382 sqlite_transaction_mode,
383 )
384 .await
385 }
386}
387
/// Converts a `sqlx` SQLite row into a [`crate::ProxyRow`], pairing each
/// column name with a `sea_query::Value` chosen from the column's type name.
///
/// Date and time columns are decoded with `chrono` when the `with-chrono`
/// feature is enabled, and with `time` when only `with-time` is enabled.
///
/// # Panics
///
/// Panics if a column cannot be decoded as the Rust type its type name maps
/// to, or if the column's type name is not one of the names matched below.
#[cfg(feature = "proxy")]
pub(crate) fn from_sqlx_sqlite_row_to_proxy_row(row: &sqlx::sqlite::SqliteRow) -> crate::ProxyRow {
    use sea_query::Value;
    use sqlx::{Column, Row, TypeInfo};

    crate::ProxyRow {
        values: row
            .columns()
            .iter()
            .map(|column| {
                let name = column.name().to_string();
                let ordinal = column.ordinal();
                let type_name = column.type_info().name();

                let value = match type_name {
                    "BOOLEAN" => {
                        Value::Bool(row.try_get(ordinal).expect("Failed to get boolean"))
                    }

                    "INTEGER" => {
                        Value::Int(row.try_get(ordinal).expect("Failed to get integer"))
                    }

                    "BIGINT" | "INT8" => {
                        Value::BigInt(row.try_get(ordinal).expect("Failed to get big integer"))
                    }

                    "REAL" => {
                        Value::Double(row.try_get(ordinal).expect("Failed to get double"))
                    }

                    "TEXT" => Value::String(
                        row.try_get::<Option<String>, _>(ordinal)
                            .expect("Failed to get string")
                            .map(Box::new),
                    ),

                    "BLOB" => Value::Bytes(
                        row.try_get::<Option<Vec<u8>>, _>(ordinal)
                            .expect("Failed to get bytes")
                            .map(Box::new),
                    ),

                    #[cfg(feature = "with-chrono")]
                    "DATETIME" => {
                        use chrono::{DateTime, Utc};

                        Value::ChronoDateTimeUtc(
                            row.try_get::<Option<DateTime<Utc>>, _>(ordinal)
                                .expect("Failed to get timestamp")
                                .map(Box::new),
                        )
                    }
                    #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                    "DATETIME" => {
                        use time::OffsetDateTime;

                        Value::TimeDateTimeWithTimeZone(
                            row.try_get::<Option<OffsetDateTime>, _>(ordinal)
                                .expect("Failed to get timestamp")
                                .map(Box::new),
                        )
                    }

                    #[cfg(feature = "with-chrono")]
                    "DATE" => {
                        use chrono::NaiveDate;

                        Value::ChronoDate(
                            row.try_get::<Option<NaiveDate>, _>(ordinal)
                                .expect("Failed to get date")
                                .map(Box::new),
                        )
                    }
                    #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                    "DATE" => {
                        use time::Date;

                        Value::TimeDate(
                            row.try_get::<Option<Date>, _>(ordinal)
                                .expect("Failed to get date")
                                .map(Box::new),
                        )
                    }

                    #[cfg(feature = "with-chrono")]
                    "TIME" => {
                        use chrono::NaiveTime;

                        Value::ChronoTime(
                            row.try_get::<Option<NaiveTime>, _>(ordinal)
                                .expect("Failed to get time")
                                .map(Box::new),
                        )
                    }
                    #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                    "TIME" => {
                        use time::Time;

                        Value::TimeTime(
                            row.try_get::<Option<Time>, _>(ordinal)
                                .expect("Failed to get time")
                                .map(Box::new),
                        )
                    }

                    // NOTE(review): any type name not listed above panics here;
                    // confirm callers never pass rows with other column types.
                    _ => unreachable!("Unknown column type: {}", type_name),
                };

                (name, value)
            })
            .collect(),
    }
}