1use futures_util::lock::Mutex;
2use log::LevelFilter;
3use sea_query::Values;
4use std::{future::Future, pin::Pin, sync::Arc};
5
6use sqlx::{
7 Connection, Executor, Sqlite, SqlitePool,
8 pool::PoolConnection,
9 sqlite::{SqliteConnectOptions, SqliteQueryResult, SqliteRow},
10};
11
12use sea_query_sqlx::SqlxValues;
13use tracing::{instrument, warn};
14
15use crate::{
16 AccessMode, ConnectOptions, DatabaseConnection, DatabaseConnectionType, DatabaseTransaction,
17 IsolationLevel, QueryStream, SqliteTransactionMode, Statement, TransactionError, debug_print,
18 error::*, executor::*, sqlx_error_to_exec_err,
19};
20
21use super::sqlx_common::*;
22
/// Entry point for opening SQLite databases through sqlx.
///
/// This is a stateless marker type; everything it offers is provided by
/// associated functions (`accepts`, `connect`, `from_sqlx_sqlite_pool`).
#[derive(Debug)]
pub struct SqlxSqliteConnector;
26
27#[derive(Clone)]
29pub struct SqlxSqlitePoolConnection {
30 pub(crate) pool: SqlitePool,
31 metric_callback: Option<crate::metric::Callback>,
32 pub(crate) record_stmt_in_spans: bool,
33}
34
35impl std::fmt::Debug for SqlxSqlitePoolConnection {
36 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
37 write!(f, "SqlxSqlitePoolConnection {{ pool: {:?} }}", self.pool)
38 }
39}
40
41impl From<SqlitePool> for SqlxSqlitePoolConnection {
42 fn from(pool: SqlitePool) -> Self {
43 SqlxSqlitePoolConnection {
44 pool,
45 metric_callback: None,
46 record_stmt_in_spans: true,
47 }
48 }
49}
50
51impl From<SqlitePool> for DatabaseConnection {
52 fn from(pool: SqlitePool) -> Self {
53 DatabaseConnectionType::SqlxSqlitePoolConnection(pool.into()).into()
54 }
55}
56
57impl SqlxSqliteConnector {
58 pub fn accepts(string: &str) -> bool {
60 string.starts_with("sqlite:") && string.parse::<SqliteConnectOptions>().is_ok()
61 }
62
63 #[instrument(level = "trace")]
65 pub async fn connect(options: ConnectOptions) -> Result<DatabaseConnection, DbErr> {
66 let mut options = options;
67 let record_stmt_in_spans = options.get_record_stmt_in_spans();
68 let mut sqlx_opts = options
69 .url
70 .parse::<SqliteConnectOptions>()
71 .map_err(sqlx_error_to_conn_err)?;
72 if let Some(sqlcipher_key) = &options.sqlcipher_key {
73 sqlx_opts = sqlx_opts.pragma("key", sqlcipher_key.clone());
74 }
75 use sqlx::ConnectOptions;
76 if !options.sqlx_logging {
77 sqlx_opts = sqlx_opts.disable_statement_logging();
78 } else {
79 sqlx_opts = sqlx_opts.log_statements(options.sqlx_logging_level);
80 if options.sqlx_slow_statements_logging_level != LevelFilter::Off {
81 sqlx_opts = sqlx_opts.log_slow_statements(
82 options.sqlx_slow_statements_logging_level,
83 options.sqlx_slow_statements_logging_threshold,
84 );
85 }
86 }
87
88 if options.get_max_connections().is_none() {
89 options.max_connections(1);
90 }
91
92 if let Some(f) = &options.sqlite_opts_fn {
93 sqlx_opts = f(sqlx_opts);
94 }
95
96 let after_conn = options.after_connect.clone();
97 let connect_lazy = options.connect_lazy;
98 let sqlite_pool_opts_fn = options.sqlite_pool_opts_fn.clone();
99 let mut pool_options = options.sqlx_pool_options();
100
101 if let Some(f) = &sqlite_pool_opts_fn {
102 pool_options = f(pool_options);
103 }
104
105 let pool = if connect_lazy {
106 pool_options.connect_lazy_with(sqlx_opts)
107 } else {
108 pool_options
109 .connect_with(sqlx_opts)
110 .await
111 .map_err(sqlx_error_to_conn_err)?
112 };
113
114 let pool = SqlxSqlitePoolConnection {
115 pool,
116 metric_callback: None,
117 record_stmt_in_spans,
118 };
119
120 #[cfg(feature = "sqlite-use-returning-for-3_35")]
121 {
122 let version = get_version(&pool).await?;
123 super::sqlite::ensure_returning_version(&version)?;
124 }
125
126 let conn: DatabaseConnection =
127 DatabaseConnectionType::SqlxSqlitePoolConnection(pool).into();
128
129 if let Some(cb) = after_conn {
130 cb(conn.clone()).await?;
131 }
132
133 Ok(conn)
134 }
135}
136
137impl SqlxSqliteConnector {
138 pub fn from_sqlx_sqlite_pool(pool: SqlitePool) -> DatabaseConnection {
140 DatabaseConnectionType::SqlxSqlitePoolConnection(SqlxSqlitePoolConnection {
141 pool,
142 metric_callback: None,
143 record_stmt_in_spans: true,
144 })
145 .into()
146 }
147}
148
149impl SqlxSqlitePoolConnection {
150 #[instrument(level = "trace")]
152 pub async fn execute(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
153 debug_print!("{}", stmt);
154
155 let query = sqlx_query(&stmt);
156 let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
157 crate::metric::metric!(self.metric_callback, &stmt, {
158 match query.execute(&mut *conn).await {
159 Ok(res) => Ok(res.into()),
160 Err(err) => Err(sqlx_error_to_exec_err(err)),
161 }
162 })
163 }
164
165 #[instrument(level = "trace")]
167 pub async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
168 debug_print!("{}", sql);
169
170 let conn = &mut self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
171 match conn.execute(sqlx::AssertSqlSafe(sql.to_owned())).await {
172 Ok(res) => Ok(res.into()),
173 Err(err) => Err(sqlx_error_to_exec_err(err)),
174 }
175 }
176
177 #[instrument(level = "trace")]
179 pub async fn query_one(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
180 debug_print!("{}", stmt);
181
182 let query = sqlx_query(&stmt);
183 let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
184 crate::metric::metric!(self.metric_callback, &stmt, {
185 match query.fetch_one(&mut *conn).await {
186 Ok(row) => Ok(Some(row.into())),
187 Err(err) => match err {
188 sqlx::Error::RowNotFound => Ok(None),
189 _ => Err(sqlx_error_to_query_err(err)),
190 },
191 }
192 })
193 }
194
195 #[instrument(level = "trace")]
197 pub async fn query_all(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
198 debug_print!("{}", stmt);
199
200 let query = sqlx_query(&stmt);
201 let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
202 crate::metric::metric!(self.metric_callback, &stmt, {
203 match query.fetch_all(&mut *conn).await {
204 Ok(rows) => Ok(rows.into_iter().map(|r| r.into()).collect()),
205 Err(err) => Err(sqlx_error_to_query_err(err)),
206 }
207 })
208 }
209
210 #[instrument(level = "trace")]
212 pub async fn stream(&self, stmt: Statement) -> Result<QueryStream, DbErr> {
213 debug_print!("{}", stmt);
214
215 let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
216 Ok(QueryStream::from((
217 conn,
218 stmt,
219 self.metric_callback.clone(),
220 )))
221 }
222
223 #[instrument(level = "trace")]
225 pub async fn begin(
226 &self,
227 isolation_level: Option<IsolationLevel>,
228 access_mode: Option<AccessMode>,
229 sqlite_transaction_mode: Option<SqliteTransactionMode>,
230 ) -> Result<DatabaseTransaction, DbErr> {
231 let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
232 DatabaseTransaction::new_sqlite(
233 conn,
234 self.metric_callback.clone(),
235 self.record_stmt_in_spans,
236 isolation_level,
237 access_mode,
238 sqlite_transaction_mode,
239 )
240 .await
241 }
242
243 #[instrument(level = "trace", skip(callback))]
245 pub async fn transaction<F, T, E>(
246 &self,
247 callback: F,
248 isolation_level: Option<IsolationLevel>,
249 access_mode: Option<AccessMode>,
250 ) -> Result<T, TransactionError<E>>
251 where
252 F: for<'b> FnOnce(
253 &'b DatabaseTransaction,
254 ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'b>>
255 + Send,
256 T: Send,
257 E: std::fmt::Display + std::fmt::Debug + Send,
258 {
259 let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
260 let transaction = DatabaseTransaction::new_sqlite(
261 conn,
262 self.metric_callback.clone(),
263 self.record_stmt_in_spans,
264 isolation_level,
265 access_mode,
266 None,
267 )
268 .await
269 .map_err(|e| TransactionError::Connection(e))?;
270 transaction.run(callback).await
271 }
272
273 pub(crate) fn set_metric_callback<F>(&mut self, callback: F)
274 where
275 F: Fn(&crate::metric::Info<'_>) + Send + Sync + 'static,
276 {
277 self.metric_callback = Some(Arc::new(callback));
278 }
279
280 pub async fn ping(&self) -> Result<(), DbErr> {
282 let conn = &mut self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
283 match conn.ping().await {
284 Ok(_) => Ok(()),
285 Err(err) => Err(sqlx_error_to_conn_err(err)),
286 }
287 }
288
289 pub async fn close(self) -> Result<(), DbErr> {
292 self.close_by_ref().await
293 }
294
295 pub async fn close_by_ref(&self) -> Result<(), DbErr> {
297 self.pool.close().await;
298 Ok(())
299 }
300}
301
302impl From<SqliteRow> for QueryResult {
303 fn from(row: SqliteRow) -> QueryResult {
304 QueryResult {
305 row: QueryResultRow::SqlxSqlite(row),
306 }
307 }
308}
309
310impl From<SqliteQueryResult> for ExecResult {
311 fn from(result: SqliteQueryResult) -> ExecResult {
312 ExecResult {
313 result: ExecResultHolder::SqlxSqlite(result),
314 }
315 }
316}
317
318pub(crate) fn sqlx_query(stmt: &Statement) -> sqlx::query::Query<'_, Sqlite, SqlxValues> {
319 let values = stmt
320 .values
321 .as_ref()
322 .map_or(Values(Vec::new()), |values| values.clone());
323 sqlx::query_with(sqlx::AssertSqlSafe(stmt.sql.as_str()), SqlxValues(values))
324}
325
326pub(crate) async fn set_transaction_config(
327 _conn: &mut PoolConnection<Sqlite>,
328 isolation_level: Option<IsolationLevel>,
329 access_mode: Option<AccessMode>,
330) -> Result<(), DbErr> {
331 if isolation_level.is_some() {
332 warn!("Setting isolation level in a SQLite transaction isn't supported");
333 }
334 if access_mode.is_some() {
335 warn!("Setting access mode in a SQLite transaction isn't supported");
336 }
337 Ok(())
338}
339
/// Reads the SQLite library version by running `SELECT sqlite_version()`
/// through the pool.
///
/// # Errors
///
/// Returns `DbErr::Conn` if the query produces no row. Query and decoding
/// errors are propagated as-is.
#[cfg(feature = "sqlite-use-returning-for-3_35")]
async fn get_version(conn: &SqlxSqlitePoolConnection) -> Result<String, DbErr> {
    let version_query = Statement {
        sql: "SELECT sqlite_version()".to_string(),
        values: None,
        db_backend: crate::DbBackend::Sqlite,
    };
    let row = match conn.query_one(version_query).await? {
        Some(row) => row,
        None => {
            return Err(DbErr::Conn(RuntimeErr::Internal(
                "Error reading SQLite version".to_string(),
            )));
        }
    };
    row.try_get_by(0)
}
356
357impl
358 From<(
359 PoolConnection<sqlx::Sqlite>,
360 Statement,
361 Option<crate::metric::Callback>,
362 )> for crate::QueryStream
363{
364 fn from(
365 (conn, stmt, metric_callback): (
366 PoolConnection<sqlx::Sqlite>,
367 Statement,
368 Option<crate::metric::Callback>,
369 ),
370 ) -> Self {
371 crate::QueryStream::build(stmt, crate::InnerConnection::Sqlite(conn), metric_callback)
372 }
373}
374
375impl crate::DatabaseTransaction {
376 pub(crate) async fn new_sqlite(
377 inner: PoolConnection<sqlx::Sqlite>,
378 metric_callback: Option<crate::metric::Callback>,
379 record_stmt_in_spans: bool,
380 isolation_level: Option<IsolationLevel>,
381 access_mode: Option<AccessMode>,
382 sqlite_transaction_mode: Option<SqliteTransactionMode>,
383 ) -> Result<crate::DatabaseTransaction, DbErr> {
384 Self::begin(
385 Arc::new(Mutex::new(crate::InnerConnection::Sqlite(inner))),
386 crate::DbBackend::Sqlite,
387 metric_callback,
388 record_stmt_in_spans,
389 isolation_level,
390 access_mode,
391 sqlite_transaction_mode,
392 )
393 .await
394 }
395}
396
/// Converts a sqlx SQLite row into a [`crate::ProxyRow`] of `sea_query` values,
/// keyed by column name.
///
/// The [`sea_query::Value`] variant is chosen from the column's declared type.
/// Some columns have no recognised declared type:
/// - expression columns such as `COUNT(*)` or `MAX(x)`, which sqlx reports as `NULL`;
/// - `NUMERIC` columns;
/// - date/time columns when neither the `with-chrono` nor the `with-time`
///   feature is enabled.
///
/// Previously these hit an `unreachable!` and panicked. They now fall back to
/// the storage class of the value actually present in this row.
///
/// # Panics
///
/// Panics in either of these cases:
/// - a value cannot be decoded as the type chosen for its column;
/// - a fallback value has a storage class other than INTEGER, REAL, TEXT,
///   BLOB or NULL.
#[cfg(feature = "proxy")]
pub(crate) fn from_sqlx_sqlite_row_to_proxy_row(row: &sqlx::sqlite::SqliteRow) -> crate::ProxyRow {
    use sea_query::Value;
    use sqlx::{Column, Row, TypeInfo, ValueRef};

    /// Decodes column `ordinal` according to the runtime storage class of its value.
    fn value_by_storage_class(row: &sqlx::sqlite::SqliteRow, ordinal: usize) -> Value {
        let raw = row.try_get_raw(ordinal).expect("Failed to get raw value");
        if raw.is_null() {
            // There is neither a declared type nor a value to infer from. Report
            // an integer NULL, matching the INTEGER branch below, since untyped
            // nullable columns are most often aggregates such as MAX/SUM.
            return Value::BigInt(None);
        }
        // For non-NULL values sqlx reports the value's own storage class here,
        // not the column's declared type.
        let storage_class = raw.type_info().name().to_owned();
        match storage_class.as_str() {
            "INTEGER" | "BIGINT" => {
                Value::BigInt(row.try_get(ordinal).expect("Failed to get big integer"))
            }
            "REAL" => Value::Double(row.try_get(ordinal).expect("Failed to get double")),
            "TEXT" => Value::String(
                row.try_get::<Option<String>, _>(ordinal)
                    .expect("Failed to get string")
                    .map(Box::new),
            ),
            "BLOB" => Value::Bytes(
                row.try_get::<Option<Vec<u8>>, _>(ordinal)
                    .expect("Failed to get bytes")
                    .map(Box::new),
            ),
            other => panic!("Unknown column type: {}", other),
        }
    }

    crate::ProxyRow {
        values: row
            .columns()
            .iter()
            .map(|c| {
                (
                    c.name().to_string(),
                    match c.type_info().name() {
                        "BOOLEAN" => {
                            Value::Bool(row.try_get(c.ordinal()).expect("Failed to get boolean"))
                        }

                        "INTEGER" => {
                            Value::Int(row.try_get(c.ordinal()).expect("Failed to get integer"))
                        }

                        "BIGINT" | "INT8" => Value::BigInt(
                            row.try_get(c.ordinal()).expect("Failed to get big integer"),
                        ),

                        "REAL" => {
                            Value::Double(row.try_get(c.ordinal()).expect("Failed to get double"))
                        }

                        "TEXT" => Value::String(
                            row.try_get::<Option<String>, _>(c.ordinal())
                                .expect("Failed to get string")
                                .map(Box::new),
                        ),

                        "BLOB" => Value::Bytes(
                            row.try_get::<Option<Vec<u8>>, _>(c.ordinal())
                                .expect("Failed to get bytes")
                                .map(Box::new),
                        ),

                        #[cfg(feature = "with-chrono")]
                        "DATETIME" => {
                            use chrono::{DateTime, Utc};

                            Value::ChronoDateTimeUtc(
                                row.try_get::<Option<DateTime<Utc>>, _>(c.ordinal())
                                    .expect("Failed to get timestamp")
                                    .map(Box::new),
                            )
                        }
                        #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                        "DATETIME" => {
                            use time::OffsetDateTime;
                            Value::TimeDateTimeWithTimeZone(
                                row.try_get::<Option<OffsetDateTime>, _>(c.ordinal())
                                    .expect("Failed to get timestamp")
                                    .map(Box::new),
                            )
                        }
                        #[cfg(feature = "with-chrono")]
                        "DATE" => {
                            use chrono::NaiveDate;
                            Value::ChronoDate(
                                row.try_get::<Option<NaiveDate>, _>(c.ordinal())
                                    .expect("Failed to get date")
                                    .map(Box::new),
                            )
                        }
                        #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                        "DATE" => {
                            use time::Date;
                            Value::TimeDate(
                                row.try_get::<Option<Date>, _>(c.ordinal())
                                    .expect("Failed to get date")
                                    .map(Box::new),
                            )
                        }

                        #[cfg(feature = "with-chrono")]
                        "TIME" => {
                            use chrono::NaiveTime;
                            Value::ChronoTime(
                                row.try_get::<Option<NaiveTime>, _>(c.ordinal())
                                    .expect("Failed to get time")
                                    .map(Box::new),
                            )
                        }
                        #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                        "TIME" => {
                            use time::Time;
                            Value::TimeTime(
                                row.try_get::<Option<Time>, _>(c.ordinal())
                                    .expect("Failed to get time")
                                    .map(Box::new),
                            )
                        }

                        // No usable declared type (e.g. `NULL` for expression
                        // columns, `NUMERIC`, or date/time without a date
                        // feature): decide from the value itself.
                        _ => value_by_storage_class(row, c.ordinal()),
                    },
                )
            })
            .collect(),
    }
}