1use futures_util::lock::Mutex;
2use log::LevelFilter;
3use sea_query::Values;
4use std::{future::Future, pin::Pin, sync::Arc};
5
6use sqlx::{
7 Connection, Executor, Sqlite, SqlitePool,
8 pool::PoolConnection,
9 sqlite::{SqliteConnectOptions, SqliteQueryResult, SqliteRow},
10};
11
12use sea_query_sqlx::SqlxValues;
13use tracing::{instrument, warn};
14
15use crate::{
16 AccessMode, ConnectOptions, DatabaseConnection, DatabaseConnectionType, DatabaseTransaction,
17 IsolationLevel, SqliteTransactionMode, Statement, TransactionError, debug_print, error::*,
18 executor::*, sqlx_error_to_exec_err,
19};
20
21use super::sqlx_common::*;
22
23#[cfg(feature = "stream")]
24use crate::QueryStream;
25
/// Entry point for building SQLite-backed [`DatabaseConnection`]s through sqlx.
///
/// Stateless: all functionality lives in associated functions such as
/// `accepts`, `connect` and `from_sqlx_sqlite_pool`.
#[derive(Debug)]
pub struct SqlxSqliteConnector;
29
/// A sqlx [`SqlitePool`] together with the settings sea-orm forwards to the
/// queries, streams and transactions started from it.
#[derive(Clone)]
pub struct SqlxSqlitePoolConnection {
    /// The underlying sqlx pool; each operation acquires a connection from it.
    pub(crate) pool: SqlitePool,
    /// Optional metrics hook; passed to `crate::metric::metric!` around prepared
    /// queries and forwarded to streams and transactions.
    metric_callback: Option<crate::metric::Callback>,
    /// Flag forwarded to `DatabaseTransaction::new_sqlite`.
    // NOTE(review): presumably controls whether statement text is recorded in
    // tracing spans — confirm against `DatabaseTransaction::begin`.
    pub(crate) record_stmt_in_spans: bool,
}
37
38impl std::fmt::Debug for SqlxSqlitePoolConnection {
39 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
40 write!(f, "SqlxSqlitePoolConnection {{ pool: {:?} }}", self.pool)
41 }
42}
43
44impl From<SqlitePool> for SqlxSqlitePoolConnection {
45 fn from(pool: SqlitePool) -> Self {
46 SqlxSqlitePoolConnection {
47 pool,
48 metric_callback: None,
49 record_stmt_in_spans: true,
50 }
51 }
52}
53
54impl From<SqlitePool> for DatabaseConnection {
55 fn from(pool: SqlitePool) -> Self {
56 DatabaseConnectionType::SqlxSqlitePoolConnection(pool.into()).into()
57 }
58}
59
60impl SqlxSqliteConnector {
61 pub fn accepts(string: &str) -> bool {
63 string.starts_with("sqlite:") && string.parse::<SqliteConnectOptions>().is_ok()
64 }
65
66 #[instrument(level = "trace")]
68 pub async fn connect(options: ConnectOptions) -> Result<DatabaseConnection, DbErr> {
69 let mut options = options;
70 let record_stmt_in_spans = options.get_record_stmt_in_spans();
71 let mut sqlx_opts = options
72 .url
73 .parse::<SqliteConnectOptions>()
74 .map_err(sqlx_error_to_conn_err)?;
75 if let Some(sqlcipher_key) = &options.sqlcipher_key {
76 sqlx_opts = sqlx_opts.pragma("key", sqlcipher_key.clone());
77 }
78 use sqlx::ConnectOptions;
79 if !options.sqlx_logging {
80 sqlx_opts = sqlx_opts.disable_statement_logging();
81 } else {
82 sqlx_opts = sqlx_opts.log_statements(options.sqlx_logging_level);
83 if options.sqlx_slow_statements_logging_level != LevelFilter::Off {
84 sqlx_opts = sqlx_opts.log_slow_statements(
85 options.sqlx_slow_statements_logging_level,
86 options.sqlx_slow_statements_logging_threshold,
87 );
88 }
89 }
90
91 if options.get_max_connections().is_none() {
92 options.max_connections(1);
93 }
94
95 if let Some(f) = &options.sqlite_opts_fn {
96 sqlx_opts = f(sqlx_opts);
97 }
98
99 let after_conn = options.after_connect.clone();
100 let connect_lazy = options.connect_lazy;
101 let sqlite_pool_opts_fn = options.sqlite_pool_opts_fn.clone();
102 let mut pool_options = options.sqlx_pool_options();
103
104 if let Some(f) = &sqlite_pool_opts_fn {
105 pool_options = f(pool_options);
106 }
107
108 let pool = if connect_lazy {
109 pool_options.connect_lazy_with(sqlx_opts)
110 } else {
111 pool_options
112 .connect_with(sqlx_opts)
113 .await
114 .map_err(sqlx_error_to_conn_err)?
115 };
116
117 let pool = SqlxSqlitePoolConnection {
118 pool,
119 metric_callback: None,
120 record_stmt_in_spans,
121 };
122
123 #[cfg(feature = "sqlite-use-returning-for-3_35")]
124 {
125 let version = get_version(&pool).await?;
126 super::sqlite::ensure_returning_version(&version)?;
127 }
128
129 let conn: DatabaseConnection =
130 DatabaseConnectionType::SqlxSqlitePoolConnection(pool).into();
131
132 if let Some(cb) = after_conn {
133 cb(conn.clone()).await?;
134 }
135
136 Ok(conn)
137 }
138}
139
140impl SqlxSqliteConnector {
141 pub fn from_sqlx_sqlite_pool(pool: SqlitePool) -> DatabaseConnection {
143 DatabaseConnectionType::SqlxSqlitePoolConnection(SqlxSqlitePoolConnection {
144 pool,
145 metric_callback: None,
146 record_stmt_in_spans: true,
147 })
148 .into()
149 }
150}
151
152impl SqlxSqlitePoolConnection {
153 #[instrument(level = "trace", skip(stmt))]
155 pub async fn execute(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
156 debug_print!("{}", stmt);
157
158 let query = sqlx_query(&stmt);
159 let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
160 crate::metric::metric!(self.metric_callback, &stmt, {
161 match query.execute(&mut *conn).await {
162 Ok(res) => Ok(res.into()),
163 Err(err) => Err(sqlx_error_to_exec_err(err)),
164 }
165 })
166 }
167
168 #[instrument(level = "trace", skip(sql))]
170 pub async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
171 debug_print!("{}", sql);
172
173 let conn = &mut self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
174 match conn.execute(sqlx::AssertSqlSafe(sql.to_owned())).await {
175 Ok(res) => Ok(res.into()),
176 Err(err) => Err(sqlx_error_to_exec_err(err)),
177 }
178 }
179
180 #[instrument(level = "trace", skip(stmt))]
182 pub async fn query_one(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
183 debug_print!("{}", stmt);
184
185 let query = sqlx_query(&stmt);
186 let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
187 crate::metric::metric!(self.metric_callback, &stmt, {
188 match query.fetch_one(&mut *conn).await {
189 Ok(row) => Ok(Some(row.into())),
190 Err(err) => match err {
191 sqlx::Error::RowNotFound => Ok(None),
192 _ => Err(sqlx_error_to_query_err(err)),
193 },
194 }
195 })
196 }
197
198 #[instrument(level = "trace", skip(stmt))]
200 pub async fn query_all(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
201 debug_print!("{}", stmt);
202
203 let query = sqlx_query(&stmt);
204 let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
205 crate::metric::metric!(self.metric_callback, &stmt, {
206 match query.fetch_all(&mut *conn).await {
207 Ok(rows) => Ok(rows.into_iter().map(|r| r.into()).collect()),
208 Err(err) => Err(sqlx_error_to_query_err(err)),
209 }
210 })
211 }
212
213 #[instrument(level = "trace", skip(stmt))]
215 #[cfg(feature = "stream")]
216 pub async fn stream(&self, stmt: Statement) -> Result<QueryStream, DbErr> {
217 debug_print!("{}", stmt);
218
219 let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
220 Ok(QueryStream::from((
221 conn,
222 stmt,
223 self.metric_callback.clone(),
224 )))
225 }
226
227 #[instrument(level = "trace")]
229 pub async fn begin(
230 &self,
231 isolation_level: Option<IsolationLevel>,
232 access_mode: Option<AccessMode>,
233 sqlite_transaction_mode: Option<SqliteTransactionMode>,
234 ) -> Result<DatabaseTransaction, DbErr> {
235 let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
236 DatabaseTransaction::new_sqlite(
237 conn,
238 self.metric_callback.clone(),
239 self.record_stmt_in_spans,
240 isolation_level,
241 access_mode,
242 sqlite_transaction_mode,
243 )
244 .await
245 }
246
247 #[instrument(level = "trace", skip(callback))]
249 pub async fn transaction<F, T, E>(
250 &self,
251 callback: F,
252 isolation_level: Option<IsolationLevel>,
253 access_mode: Option<AccessMode>,
254 ) -> Result<T, TransactionError<E>>
255 where
256 F: for<'b> FnOnce(
257 &'b DatabaseTransaction,
258 ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'b>>
259 + Send,
260 T: Send,
261 E: std::fmt::Display + std::fmt::Debug + Send,
262 {
263 let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
264 let transaction = DatabaseTransaction::new_sqlite(
265 conn,
266 self.metric_callback.clone(),
267 self.record_stmt_in_spans,
268 isolation_level,
269 access_mode,
270 None,
271 )
272 .await
273 .map_err(|e| TransactionError::Connection(e))?;
274 transaction.run(callback).await
275 }
276
277 pub(crate) fn set_metric_callback<F>(&mut self, callback: F)
278 where
279 F: Fn(&crate::metric::Info<'_>) + Send + Sync + 'static,
280 {
281 self.metric_callback = Some(Arc::new(callback));
282 }
283
284 pub async fn ping(&self) -> Result<(), DbErr> {
286 let conn = &mut self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
287 match conn.ping().await {
288 Ok(_) => Ok(()),
289 Err(err) => Err(sqlx_error_to_conn_err(err)),
290 }
291 }
292
293 pub async fn close(self) -> Result<(), DbErr> {
296 self.close_by_ref().await
297 }
298
299 pub async fn close_by_ref(&self) -> Result<(), DbErr> {
301 self.pool.close().await;
302 Ok(())
303 }
304}
305
306impl From<SqliteRow> for QueryResult {
307 fn from(row: SqliteRow) -> QueryResult {
308 QueryResult {
309 row: QueryResultRow::SqlxSqlite(row),
310 }
311 }
312}
313
314impl From<SqliteQueryResult> for ExecResult {
315 fn from(result: SqliteQueryResult) -> ExecResult {
316 ExecResult {
317 result: ExecResultHolder::SqlxSqlite(result),
318 }
319 }
320}
321
322pub(crate) fn sqlx_query(stmt: &Statement) -> sqlx::query::Query<'_, Sqlite, SqlxValues> {
323 let values = stmt
324 .values
325 .as_ref()
326 .map_or(Values(Vec::new()), |values| values.clone());
327 sqlx::query_with(sqlx::AssertSqlSafe(stmt.sql.as_str()), SqlxValues(values))
328}
329
330pub(crate) async fn set_transaction_config(
331 _conn: &mut PoolConnection<Sqlite>,
332 isolation_level: Option<IsolationLevel>,
333 access_mode: Option<AccessMode>,
334) -> Result<(), DbErr> {
335 if isolation_level.is_some() {
336 warn!("Setting isolation level in a SQLite transaction isn't supported");
337 }
338 if access_mode.is_some() {
339 warn!("Setting access mode in a SQLite transaction isn't supported");
340 }
341 Ok(())
342}
343
/// Reads the SQLite library version string via `SELECT sqlite_version()`.
///
/// # Errors
/// Returns `DbErr::Conn` if the query yields no row. Otherwise propagates the
/// query or column-decoding error.
#[cfg(feature = "sqlite-use-returning-for-3_35")]
async fn get_version(conn: &SqlxSqlitePoolConnection) -> Result<String, DbErr> {
    let stmt = Statement {
        sql: String::from("SELECT sqlite_version()"),
        values: None,
        db_backend: crate::DbBackend::Sqlite,
    };

    let maybe_row = conn.query_one(stmt).await?;
    let row = match maybe_row {
        Some(row) => row,
        None => {
            return Err(DbErr::Conn(RuntimeErr::Internal(
                "Error reading SQLite version".to_string(),
            )));
        }
    };
    row.try_get_by(0)
}
360
/// Builds a [`crate::QueryStream`] from a pooled SQLite connection, the
/// statement to run, and an optional metric callback.
#[cfg(feature = "stream")]
impl
    From<(
        PoolConnection<sqlx::Sqlite>,
        Statement,
        Option<crate::metric::Callback>,
    )> for crate::QueryStream
{
    fn from(
        parts: (
            PoolConnection<sqlx::Sqlite>,
            Statement,
            Option<crate::metric::Callback>,
        ),
    ) -> Self {
        let (conn, stmt, metric_callback) = parts;
        let inner = crate::InnerConnection::Sqlite(conn);
        crate::QueryStream::build(stmt, inner, metric_callback)
    }
}
379
380impl crate::DatabaseTransaction {
381 pub(crate) async fn new_sqlite(
382 inner: PoolConnection<sqlx::Sqlite>,
383 metric_callback: Option<crate::metric::Callback>,
384 record_stmt_in_spans: bool,
385 isolation_level: Option<IsolationLevel>,
386 access_mode: Option<AccessMode>,
387 sqlite_transaction_mode: Option<SqliteTransactionMode>,
388 ) -> Result<crate::DatabaseTransaction, DbErr> {
389 Self::begin(
390 Arc::new(Mutex::new(crate::InnerConnection::Sqlite(inner))),
391 crate::DbBackend::Sqlite,
392 metric_callback,
393 record_stmt_in_spans,
394 isolation_level,
395 access_mode,
396 sqlite_transaction_mode,
397 )
398 .await
399 }
400}
401
/// Converts a sqlx SQLite row into a [`crate::ProxyRow`] keyed by column name.
///
/// Each column's sqlx type name (`type_info().name()`) selects the
/// `sea_query::Value` variant used to decode it. The DATETIME/DATE/TIME arms
/// are compiled in only with `with-chrono`, or with `with-time` when
/// `with-chrono` is off.
///
/// # Panics
/// - Panics (`expect`) if a column value cannot be decoded as the type chosen
///   for it.
/// - Hits `unreachable!` for any type name not matched below. Without
///   `with-chrono` or `with-time`, "DATETIME", "DATE" and "TIME" columns also
///   fall through to that arm.
// NOTE(review): sqlx may report other SQLite type names (e.g. "NULL" for an
// untyped NULL expression, or "NUMERIC"), which would make the catch-all arm
// reachable and panic — confirm against sqlx's SqliteTypeInfo names.
#[cfg(feature = "proxy")]
pub(crate) fn from_sqlx_sqlite_row_to_proxy_row(row: &sqlx::sqlite::SqliteRow) -> crate::ProxyRow {
    use sea_query::Value;
    use sqlx::{Column, Row, TypeInfo};
    crate::ProxyRow {
        values: row
            .columns()
            .iter()
            .map(|c| {
                (
                    c.name().to_string(),
                    match c.type_info().name() {
                        "BOOLEAN" => {
                            Value::Bool(row.try_get(c.ordinal()).expect("Failed to get boolean"))
                        }

                        // "INTEGER" decodes into `Value::Int`; 64-bit names map to `BigInt` below.
                        "INTEGER" => {
                            Value::Int(row.try_get(c.ordinal()).expect("Failed to get integer"))
                        }

                        "BIGINT" | "INT8" => Value::BigInt(
                            row.try_get(c.ordinal()).expect("Failed to get big integer"),
                        ),

                        "REAL" => {
                            Value::Double(row.try_get(c.ordinal()).expect("Failed to get double"))
                        }

                        // Text and blob values are decoded as `Option` so SQL NULL becomes `None`.
                        "TEXT" => Value::String(
                            row.try_get::<Option<String>, _>(c.ordinal())
                                .expect("Failed to get string")
                                .map(Box::new),
                        ),

                        "BLOB" => Value::Bytes(
                            row.try_get::<Option<Vec<u8>>, _>(c.ordinal())
                                .expect("Failed to get bytes")
                                .map(Box::new),
                        ),

                        // Date/time arms: chrono types take precedence when both
                        // `with-chrono` and `with-time` are enabled.
                        #[cfg(feature = "with-chrono")]
                        "DATETIME" => {
                            use chrono::{DateTime, Utc};

                            Value::ChronoDateTimeUtc(
                                row.try_get::<Option<DateTime<Utc>>, _>(c.ordinal())
                                    .expect("Failed to get timestamp")
                                    .map(Box::new),
                            )
                        }
                        #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                        "DATETIME" => {
                            use time::OffsetDateTime;
                            Value::TimeDateTimeWithTimeZone(
                                row.try_get::<Option<OffsetDateTime>, _>(c.ordinal())
                                    .expect("Failed to get timestamp")
                                    .map(Box::new),
                            )
                        }
                        #[cfg(feature = "with-chrono")]
                        "DATE" => {
                            use chrono::NaiveDate;
                            Value::ChronoDate(
                                row.try_get::<Option<NaiveDate>, _>(c.ordinal())
                                    .expect("Failed to get date")
                                    .map(Box::new),
                            )
                        }
                        #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                        "DATE" => {
                            use time::Date;
                            Value::TimeDate(
                                row.try_get::<Option<Date>, _>(c.ordinal())
                                    .expect("Failed to get date")
                                    .map(Box::new),
                            )
                        }

                        #[cfg(feature = "with-chrono")]
                        "TIME" => {
                            use chrono::NaiveTime;
                            Value::ChronoTime(
                                row.try_get::<Option<NaiveTime>, _>(c.ordinal())
                                    .expect("Failed to get time")
                                    .map(Box::new),
                            )
                        }
                        #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                        "TIME" => {
                            use time::Time;
                            Value::TimeTime(
                                row.try_get::<Option<Time>, _>(c.ordinal())
                                    .expect("Failed to get time")
                                    .map(Box::new),
                            )
                        }

                        // Any type name not listed above panics here (see the
                        // review note on this function).
                        _ => unreachable!("Unknown column type: {}", c.type_info().name()),
                    },
                )
            })
            .collect(),
    }
}