1use futures_util::lock::Mutex;
2use log::LevelFilter;
3use sea_query::Values;
4use std::{future::Future, pin::Pin, sync::Arc};
5
6use sqlx::{
7 pool::PoolConnection,
8 sqlite::{SqliteConnectOptions, SqliteQueryResult, SqliteRow},
9 Connection, Executor, Sqlite, SqlitePool,
10};
11
12use sea_query_binder::SqlxValues;
13use tracing::{instrument, warn};
14
15use crate::{
16 debug_print, error::*, executor::*, sqlx_error_to_exec_err, AccessMode, ConnectOptions,
17 DatabaseConnection, DatabaseTransaction, IsolationLevel, QueryStream, Statement,
18 TransactionError,
19};
20
21use super::sqlx_common::*;
22
/// Stateless connector for SQLite databases backed by `sqlx`.
///
/// Its associated functions check whether a URL is a SQLite URL (`accepts`) and build a
/// [`DatabaseConnection`] either from connection options (`connect`) or from an existing
/// `SqlitePool` (`from_sqlx_sqlite_pool`).
#[derive(Debug)]
pub struct SqlxSqliteConnector;
26
27#[derive(Clone)]
29pub struct SqlxSqlitePoolConnection {
30 pub(crate) pool: SqlitePool,
31 metric_callback: Option<crate::metric::Callback>,
32}
33
34impl std::fmt::Debug for SqlxSqlitePoolConnection {
35 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36 write!(f, "SqlxSqlitePoolConnection {{ pool: {:?} }}", self.pool)
37 }
38}
39
40impl From<SqlitePool> for SqlxSqlitePoolConnection {
41 fn from(pool: SqlitePool) -> Self {
42 SqlxSqlitePoolConnection {
43 pool,
44 metric_callback: None,
45 }
46 }
47}
48
49impl From<SqlitePool> for DatabaseConnection {
50 fn from(pool: SqlitePool) -> Self {
51 DatabaseConnection::SqlxSqlitePoolConnection(pool.into())
52 }
53}
54
55impl SqlxSqliteConnector {
56 pub fn accepts(string: &str) -> bool {
58 string.starts_with("sqlite:") && string.parse::<SqliteConnectOptions>().is_ok()
59 }
60
61 #[instrument(level = "trace")]
63 pub async fn connect(options: ConnectOptions) -> Result<DatabaseConnection, DbErr> {
64 let mut options = options;
65 let mut sqlx_opts = options
66 .url
67 .parse::<SqliteConnectOptions>()
68 .map_err(sqlx_error_to_conn_err)?;
69 if let Some(sqlcipher_key) = &options.sqlcipher_key {
70 sqlx_opts = sqlx_opts.pragma("key", sqlcipher_key.clone());
71 }
72 use sqlx::ConnectOptions;
73 if !options.sqlx_logging {
74 sqlx_opts = sqlx_opts.disable_statement_logging();
75 } else {
76 sqlx_opts = sqlx_opts.log_statements(options.sqlx_logging_level);
77 if options.sqlx_slow_statements_logging_level != LevelFilter::Off {
78 sqlx_opts = sqlx_opts.log_slow_statements(
79 options.sqlx_slow_statements_logging_level,
80 options.sqlx_slow_statements_logging_threshold,
81 );
82 }
83 }
84
85 if options.get_max_connections().is_none() {
86 options.max_connections(1);
87 }
88
89 if let Some(f) = &options.sqlite_opts_fn {
90 sqlx_opts = f(sqlx_opts);
91 }
92
93 let pool = if options.connect_lazy {
94 options.sqlx_pool_options().connect_lazy_with(sqlx_opts)
95 } else {
96 options
97 .sqlx_pool_options()
98 .connect_with(sqlx_opts)
99 .await
100 .map_err(sqlx_error_to_conn_err)?
101 };
102
103 let pool = SqlxSqlitePoolConnection {
104 pool,
105 metric_callback: None,
106 };
107
108 #[cfg(feature = "sqlite-use-returning-for-3_35")]
109 {
110 let version = get_version(&pool).await?;
111 ensure_returning_version(&version)?;
112 }
113
114 Ok(DatabaseConnection::SqlxSqlitePoolConnection(pool))
115 }
116}
117
118impl SqlxSqliteConnector {
119 pub fn from_sqlx_sqlite_pool(pool: SqlitePool) -> DatabaseConnection {
121 DatabaseConnection::SqlxSqlitePoolConnection(SqlxSqlitePoolConnection {
122 pool,
123 metric_callback: None,
124 })
125 }
126}
127
128impl SqlxSqlitePoolConnection {
129 #[instrument(level = "trace")]
131 pub async fn execute(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
132 debug_print!("{}", stmt);
133
134 let query = sqlx_query(&stmt);
135 let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
136 crate::metric::metric!(self.metric_callback, &stmt, {
137 match query.execute(&mut *conn).await {
138 Ok(res) => Ok(res.into()),
139 Err(err) => Err(sqlx_error_to_exec_err(err)),
140 }
141 })
142 }
143
144 #[instrument(level = "trace")]
146 pub async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
147 debug_print!("{}", sql);
148
149 let conn = &mut self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
150 match conn.execute(sql).await {
151 Ok(res) => Ok(res.into()),
152 Err(err) => Err(sqlx_error_to_exec_err(err)),
153 }
154 }
155
156 #[instrument(level = "trace")]
158 pub async fn query_one(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
159 debug_print!("{}", stmt);
160
161 let query = sqlx_query(&stmt);
162 let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
163 crate::metric::metric!(self.metric_callback, &stmt, {
164 match query.fetch_one(&mut *conn).await {
165 Ok(row) => Ok(Some(row.into())),
166 Err(err) => match err {
167 sqlx::Error::RowNotFound => Ok(None),
168 _ => Err(sqlx_error_to_query_err(err)),
169 },
170 }
171 })
172 }
173
174 #[instrument(level = "trace")]
176 pub async fn query_all(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
177 debug_print!("{}", stmt);
178
179 let query = sqlx_query(&stmt);
180 let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
181 crate::metric::metric!(self.metric_callback, &stmt, {
182 match query.fetch_all(&mut *conn).await {
183 Ok(rows) => Ok(rows.into_iter().map(|r| r.into()).collect()),
184 Err(err) => Err(sqlx_error_to_query_err(err)),
185 }
186 })
187 }
188
189 #[instrument(level = "trace")]
191 pub async fn stream(&self, stmt: Statement) -> Result<QueryStream, DbErr> {
192 debug_print!("{}", stmt);
193
194 let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
195 Ok(QueryStream::from((
196 conn,
197 stmt,
198 self.metric_callback.clone(),
199 )))
200 }
201
202 #[instrument(level = "trace")]
204 pub async fn begin(
205 &self,
206 isolation_level: Option<IsolationLevel>,
207 access_mode: Option<AccessMode>,
208 ) -> Result<DatabaseTransaction, DbErr> {
209 let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
210 DatabaseTransaction::new_sqlite(
211 conn,
212 self.metric_callback.clone(),
213 isolation_level,
214 access_mode,
215 )
216 .await
217 }
218
219 #[instrument(level = "trace", skip(callback))]
221 pub async fn transaction<F, T, E>(
222 &self,
223 callback: F,
224 isolation_level: Option<IsolationLevel>,
225 access_mode: Option<AccessMode>,
226 ) -> Result<T, TransactionError<E>>
227 where
228 F: for<'b> FnOnce(
229 &'b DatabaseTransaction,
230 ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'b>>
231 + Send,
232 T: Send,
233 E: std::fmt::Display + std::fmt::Debug + Send,
234 {
235 let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
236 let transaction = DatabaseTransaction::new_sqlite(
237 conn,
238 self.metric_callback.clone(),
239 isolation_level,
240 access_mode,
241 )
242 .await
243 .map_err(|e| TransactionError::Connection(e))?;
244 transaction.run(callback).await
245 }
246
247 pub(crate) fn set_metric_callback<F>(&mut self, callback: F)
248 where
249 F: Fn(&crate::metric::Info<'_>) + Send + Sync + 'static,
250 {
251 self.metric_callback = Some(Arc::new(callback));
252 }
253
254 pub async fn ping(&self) -> Result<(), DbErr> {
256 let conn = &mut self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
257 match conn.ping().await {
258 Ok(_) => Ok(()),
259 Err(err) => Err(sqlx_error_to_conn_err(err)),
260 }
261 }
262
263 pub async fn close(self) -> Result<(), DbErr> {
266 self.close_by_ref().await
267 }
268
269 pub async fn close_by_ref(&self) -> Result<(), DbErr> {
271 self.pool.close().await;
272 Ok(())
273 }
274}
275
276impl From<SqliteRow> for QueryResult {
277 fn from(row: SqliteRow) -> QueryResult {
278 QueryResult {
279 row: QueryResultRow::SqlxSqlite(row),
280 }
281 }
282}
283
284impl From<SqliteQueryResult> for ExecResult {
285 fn from(result: SqliteQueryResult) -> ExecResult {
286 ExecResult {
287 result: ExecResultHolder::SqlxSqlite(result),
288 }
289 }
290}
291
292pub(crate) fn sqlx_query(stmt: &Statement) -> sqlx::query::Query<'_, Sqlite, SqlxValues> {
293 let values = stmt
294 .values
295 .as_ref()
296 .map_or(Values(Vec::new()), |values| values.clone());
297 sqlx::query_with(&stmt.sql, SqlxValues(values))
298}
299
300pub(crate) async fn set_transaction_config(
301 _conn: &mut PoolConnection<Sqlite>,
302 isolation_level: Option<IsolationLevel>,
303 access_mode: Option<AccessMode>,
304) -> Result<(), DbErr> {
305 if isolation_level.is_some() {
306 warn!("Setting isolation level in a SQLite transaction isn't supported");
307 }
308 if access_mode.is_some() {
309 warn!("Setting access mode in a SQLite transaction isn't supported");
310 }
311 Ok(())
312}
313
/// Reads the SQLite library version string via `SELECT sqlite_version()`.
///
/// # Errors
///
/// Fails if the query fails, returns no row, or the first column cannot be read as a string.
#[cfg(feature = "sqlite-use-returning-for-3_35")]
async fn get_version(conn: &SqlxSqlitePoolConnection) -> Result<String, DbErr> {
    let stmt = Statement {
        sql: String::from("SELECT sqlite_version()"),
        values: None,
        db_backend: crate::DbBackend::Sqlite,
    };
    match conn.query_one(stmt).await? {
        Some(row) => row.try_get_by(0),
        None => Err(DbErr::Conn(RuntimeErr::Internal(
            "Error reading SQLite version".to_string(),
        ))),
    }
}
330
/// Checks that `version` (e.g. `"3.35.0"`) is at least 3.35, the first SQLite release with
/// `RETURNING` support.
///
/// Surrounding whitespace is trimmed, and only the first two dot-separated components are
/// inspected.
///
/// # Errors
///
/// - "SQLite version too short" if fewer than two components are present.
/// - "Error parsing SQLite version" if a component read before that point is not a `u32`.
/// - "SQLite version does not support returning" if the version is older than 3.35.
#[cfg(feature = "sqlite-use-returning-for-3_35")]
fn ensure_returning_version(version: &str) -> Result<(), DbErr> {
    let mut pieces = version.trim().split('.');

    // Components are parsed lazily, so anything after the minor number is never inspected.
    let mut next_component = || -> Result<u32, DbErr> {
        let piece = pieces.next().ok_or_else(|| {
            DbErr::Conn(RuntimeErr::Internal("SQLite version too short".to_string()))
        })?;
        piece.parse::<u32>().map_err(|_| {
            DbErr::Conn(RuntimeErr::Internal(
                "Error parsing SQLite version".to_string(),
            ))
        })
    };

    let major = next_component()?;
    let minor = next_component()?;

    // Lexicographic tuple ordering: major first, then minor.
    if (major, minor) >= (3, 35) {
        Ok(())
    } else {
        Err(DbErr::Conn(RuntimeErr::Internal(
            "SQLite version does not support returning".to_string(),
        )))
    }
}
360
361impl
362 From<(
363 PoolConnection<sqlx::Sqlite>,
364 Statement,
365 Option<crate::metric::Callback>,
366 )> for crate::QueryStream
367{
368 fn from(
369 (conn, stmt, metric_callback): (
370 PoolConnection<sqlx::Sqlite>,
371 Statement,
372 Option<crate::metric::Callback>,
373 ),
374 ) -> Self {
375 crate::QueryStream::build(stmt, crate::InnerConnection::Sqlite(conn), metric_callback)
376 }
377}
378
379impl crate::DatabaseTransaction {
380 pub(crate) async fn new_sqlite(
381 inner: PoolConnection<sqlx::Sqlite>,
382 metric_callback: Option<crate::metric::Callback>,
383 isolation_level: Option<IsolationLevel>,
384 access_mode: Option<AccessMode>,
385 ) -> Result<crate::DatabaseTransaction, DbErr> {
386 Self::begin(
387 Arc::new(Mutex::new(crate::InnerConnection::Sqlite(inner))),
388 crate::DbBackend::Sqlite,
389 metric_callback,
390 isolation_level,
391 access_mode,
392 )
393 .await
394 }
395}
396
/// Converts a `sqlx` SQLite row into a [`crate::ProxyRow`].
///
/// Each column becomes a `(column name, Value)` pair. The `Value` variant is chosen from the
/// type name that `sqlx` reports for the column (`c.type_info().name()`).
///
/// # Panics
///
/// - If a column cannot be decoded into the type selected for it (every arm uses `expect`).
/// - If the column's type name matches none of the arms below (`unreachable!`). This includes
///   `DATETIME`, `DATE` and `TIME` when neither the `with-chrono` nor the `with-time` feature
///   is enabled, because those arms are compiled out.
#[cfg(feature = "proxy")]
pub(crate) fn from_sqlx_sqlite_row_to_proxy_row(row: &sqlx::sqlite::SqliteRow) -> crate::ProxyRow {
    use sea_query::Value;
    use sqlx::{Column, Row, TypeInfo};
    crate::ProxyRow {
        values: row
            .columns()
            .iter()
            .map(|c| {
                (
                    c.name().to_string(),
                    match c.type_info().name() {
                        // For BOOLEAN / INTEGER / BIGINT / REAL the decode target type is
                        // inferred from the `Value` variant the result is wrapped in.
                        "BOOLEAN" => {
                            Value::Bool(row.try_get(c.ordinal()).expect("Failed to get boolean"))
                        }

                        // NOTE(review): INTEGER maps to `Value::Int` (32-bit), while SQLite
                        // integers can be up to 64 bits wide — confirm how the decoder treats
                        // values outside the i32 range.
                        "INTEGER" => {
                            Value::Int(row.try_get(c.ordinal()).expect("Failed to get integer"))
                        }

                        "BIGINT" | "INT8" => Value::BigInt(
                            row.try_get(c.ordinal()).expect("Failed to get big integer"),
                        ),

                        "REAL" => {
                            Value::Double(row.try_get(c.ordinal()).expect("Failed to get double"))
                        }

                        "TEXT" => Value::String(
                            row.try_get::<Option<String>, _>(c.ordinal())
                                .expect("Failed to get string")
                                .map(Box::new),
                        ),

                        "BLOB" => Value::Bytes(
                            row.try_get::<Option<Vec<u8>>, _>(c.ordinal())
                                .expect("Failed to get bytes")
                                .map(Box::new),
                        ),

                        // Date/time arms: when both `with-chrono` and `with-time` are enabled,
                        // the chrono variants win, because the `with-time` arms require
                        // `not(feature = "with-chrono")`.
                        #[cfg(feature = "with-chrono")]
                        "DATETIME" => {
                            use chrono::{DateTime, Utc};

                            Value::ChronoDateTimeUtc(
                                row.try_get::<Option<DateTime<Utc>>, _>(c.ordinal())
                                    .expect("Failed to get timestamp")
                                    .map(Box::new),
                            )
                        }
                        #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                        "DATETIME" => {
                            use time::OffsetDateTime;
                            Value::TimeDateTimeWithTimeZone(
                                row.try_get::<Option<OffsetDateTime>, _>(c.ordinal())
                                    .expect("Failed to get timestamp")
                                    .map(Box::new),
                            )
                        }
                        #[cfg(feature = "with-chrono")]
                        "DATE" => {
                            use chrono::NaiveDate;
                            Value::ChronoDate(
                                row.try_get::<Option<NaiveDate>, _>(c.ordinal())
                                    .expect("Failed to get date")
                                    .map(Box::new),
                            )
                        }
                        #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                        "DATE" => {
                            use time::Date;
                            Value::TimeDate(
                                row.try_get::<Option<Date>, _>(c.ordinal())
                                    .expect("Failed to get date")
                                    .map(Box::new),
                            )
                        }

                        #[cfg(feature = "with-chrono")]
                        "TIME" => {
                            use chrono::NaiveTime;
                            Value::ChronoTime(
                                row.try_get::<Option<NaiveTime>, _>(c.ordinal())
                                    .expect("Failed to get time")
                                    .map(Box::new),
                            )
                        }
                        #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                        "TIME" => {
                            use time::Time;
                            Value::TimeTime(
                                row.try_get::<Option<Time>, _>(c.ordinal())
                                    .expect("Failed to get time")
                                    .map(Box::new),
                            )
                        }

                        // NOTE(review): despite `unreachable!`, this arm IS reachable for any
                        // type name not listed above (e.g. the date/time names when no date/time
                        // feature is enabled); it panics rather than returning an error.
                        _ => unreachable!("Unknown column type: {}", c.type_info().name()),
                    },
                )
            })
            .collect(),
    }
}
503
#[cfg(all(test, feature = "sqlite-use-returning-for-3_35"))]
mod tests {
    use super::*;

    #[test]
    fn test_ensure_returning_version() {
        // Malformed or truncated version strings.
        assert!(ensure_returning_version("").is_err());
        assert!(ensure_returning_version(".").is_err());
        assert!(ensure_returning_version(".a").is_err());
        assert!(ensure_returning_version(".4.9").is_err());
        assert!(ensure_returning_version("a").is_err());
        assert!(ensure_returning_version("1.").is_err());
        assert!(ensure_returning_version("1.a").is_err());
        assert!(ensure_returning_version("3").is_err());
        assert!(ensure_returning_version("-3.35").is_err());

        // Well-formed but too old.
        assert!(ensure_returning_version("1.1").is_err());
        assert!(ensure_returning_version("1.0.").is_err());
        assert!(ensure_returning_version("1.0.0").is_err());
        assert!(ensure_returning_version("2.0.0").is_err());
        assert!(ensure_returning_version("3.34.0").is_err());
        assert!(ensure_returning_version("3.34.999").is_err());
        // The minor version is compared numerically, not lexically ("9" < "35").
        assert!(ensure_returning_version("3.9.0").is_err());

        // Supported versions.
        assert!(ensure_returning_version("3.35.0").is_ok());
        assert!(ensure_returning_version("3.35.1").is_ok());
        assert!(ensure_returning_version("3.35").is_ok());
        assert!(ensure_returning_version("3.36.0").is_ok());
        assert!(ensure_returning_version("3.100.0").is_ok());
        assert!(ensure_returning_version("4.0.0").is_ok());
        assert!(ensure_returning_version("99.0.0").is_ok());

        // Surrounding whitespace (e.g. a trailing newline) is ignored.
        assert!(ensure_returning_version(" 3.35.0\n").is_ok());
    }
}
532}