1use futures_util::lock::Mutex;
2use log::LevelFilter;
3use sea_query::Values;
4use std::{future::Future, pin::Pin, sync::Arc};
5
6use sqlx::{
7 Connection, Executor, Sqlite, SqlitePool,
8 pool::PoolConnection,
9 sqlite::{SqliteConnectOptions, SqliteQueryResult, SqliteRow},
10};
11
12use sea_query_sqlx::SqlxValues;
13use tracing::{instrument, warn};
14
15use crate::{
16 AccessMode, ConnectOptions, DatabaseConnection, DatabaseConnectionType, DatabaseTransaction,
17 IsolationLevel, QueryStream, Statement, TransactionError, debug_print, error::*, executor::*,
18 sqlx_error_to_exec_err,
19};
20
21use super::sqlx_common::*;
22
/// Connector for SQLite databases backed by `sqlx`.
///
/// Use [`SqlxSqliteConnector::accepts`] to check whether a URL is a parseable
/// `sqlite:` URL, and [`SqlxSqliteConnector::connect`] to open a
/// [`DatabaseConnection`] from [`ConnectOptions`].
#[derive(Debug)]
pub struct SqlxSqliteConnector;
26
/// A `sqlx` [`SqlitePool`] used as the backing store of a SQLite
/// [`DatabaseConnection`].
#[derive(Clone)]
pub struct SqlxSqlitePoolConnection {
    /// The sqlx pool that statements are executed against.
    pub(crate) pool: SqlitePool,
    /// Optional metrics hook; constructors in this module start it as `None`
    /// and `set_metric_callback` installs one.
    metric_callback: Option<crate::metric::Callback>,
}
33
34impl std::fmt::Debug for SqlxSqlitePoolConnection {
35 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36 write!(f, "SqlxSqlitePoolConnection {{ pool: {:?} }}", self.pool)
37 }
38}
39
40impl From<SqlitePool> for SqlxSqlitePoolConnection {
41 fn from(pool: SqlitePool) -> Self {
42 SqlxSqlitePoolConnection {
43 pool,
44 metric_callback: None,
45 }
46 }
47}
48
49impl From<SqlitePool> for DatabaseConnection {
50 fn from(pool: SqlitePool) -> Self {
51 DatabaseConnectionType::SqlxSqlitePoolConnection(pool.into()).into()
52 }
53}
54
55impl SqlxSqliteConnector {
56 pub fn accepts(string: &str) -> bool {
58 string.starts_with("sqlite:") && string.parse::<SqliteConnectOptions>().is_ok()
59 }
60
61 #[instrument(level = "trace")]
63 pub async fn connect(options: ConnectOptions) -> Result<DatabaseConnection, DbErr> {
64 let mut options = options;
65 let mut sqlx_opts = options
66 .url
67 .parse::<SqliteConnectOptions>()
68 .map_err(sqlx_error_to_conn_err)?;
69 if let Some(sqlcipher_key) = &options.sqlcipher_key {
70 sqlx_opts = sqlx_opts.pragma("key", sqlcipher_key.clone());
71 }
72 use sqlx::ConnectOptions;
73 if !options.sqlx_logging {
74 sqlx_opts = sqlx_opts.disable_statement_logging();
75 } else {
76 sqlx_opts = sqlx_opts.log_statements(options.sqlx_logging_level);
77 if options.sqlx_slow_statements_logging_level != LevelFilter::Off {
78 sqlx_opts = sqlx_opts.log_slow_statements(
79 options.sqlx_slow_statements_logging_level,
80 options.sqlx_slow_statements_logging_threshold,
81 );
82 }
83 }
84
85 if options.get_max_connections().is_none() {
86 options.max_connections(1);
87 }
88
89 if let Some(f) = &options.sqlite_opts_fn {
90 sqlx_opts = f(sqlx_opts);
91 }
92
93 let after_conn = options.after_connect.clone();
94
95 let pool = if options.connect_lazy {
96 options.sqlx_pool_options().connect_lazy_with(sqlx_opts)
97 } else {
98 options
99 .sqlx_pool_options()
100 .connect_with(sqlx_opts)
101 .await
102 .map_err(sqlx_error_to_conn_err)?
103 };
104
105 let pool = SqlxSqlitePoolConnection {
106 pool,
107 metric_callback: None,
108 };
109
110 #[cfg(feature = "sqlite-use-returning-for-3_35")]
111 {
112 let version = get_version(&pool).await?;
113 ensure_returning_version(&version)?;
114 }
115
116 let conn: DatabaseConnection =
117 DatabaseConnectionType::SqlxSqlitePoolConnection(pool).into();
118
119 if let Some(cb) = after_conn {
120 cb(conn.clone()).await?;
121 }
122
123 Ok(conn)
124 }
125}
126
127impl SqlxSqliteConnector {
128 pub fn from_sqlx_sqlite_pool(pool: SqlitePool) -> DatabaseConnection {
130 DatabaseConnectionType::SqlxSqlitePoolConnection(SqlxSqlitePoolConnection {
131 pool,
132 metric_callback: None,
133 })
134 .into()
135 }
136}
137
138impl SqlxSqlitePoolConnection {
139 #[instrument(level = "trace")]
141 pub async fn execute(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
142 debug_print!("{}", stmt);
143
144 let query = sqlx_query(&stmt);
145 let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
146 crate::metric::metric!(self.metric_callback, &stmt, {
147 match query.execute(&mut *conn).await {
148 Ok(res) => Ok(res.into()),
149 Err(err) => Err(sqlx_error_to_exec_err(err)),
150 }
151 })
152 }
153
154 #[instrument(level = "trace")]
156 pub async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
157 debug_print!("{}", sql);
158
159 let conn = &mut self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
160 match conn.execute(sql).await {
161 Ok(res) => Ok(res.into()),
162 Err(err) => Err(sqlx_error_to_exec_err(err)),
163 }
164 }
165
166 #[instrument(level = "trace")]
168 pub async fn query_one(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
169 debug_print!("{}", stmt);
170
171 let query = sqlx_query(&stmt);
172 let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
173 crate::metric::metric!(self.metric_callback, &stmt, {
174 match query.fetch_one(&mut *conn).await {
175 Ok(row) => Ok(Some(row.into())),
176 Err(err) => match err {
177 sqlx::Error::RowNotFound => Ok(None),
178 _ => Err(sqlx_error_to_query_err(err)),
179 },
180 }
181 })
182 }
183
184 #[instrument(level = "trace")]
186 pub async fn query_all(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
187 debug_print!("{}", stmt);
188
189 let query = sqlx_query(&stmt);
190 let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
191 crate::metric::metric!(self.metric_callback, &stmt, {
192 match query.fetch_all(&mut *conn).await {
193 Ok(rows) => Ok(rows.into_iter().map(|r| r.into()).collect()),
194 Err(err) => Err(sqlx_error_to_query_err(err)),
195 }
196 })
197 }
198
199 #[instrument(level = "trace")]
201 pub async fn stream(&self, stmt: Statement) -> Result<QueryStream, DbErr> {
202 debug_print!("{}", stmt);
203
204 let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
205 Ok(QueryStream::from((
206 conn,
207 stmt,
208 self.metric_callback.clone(),
209 )))
210 }
211
212 #[instrument(level = "trace")]
214 pub async fn begin(
215 &self,
216 isolation_level: Option<IsolationLevel>,
217 access_mode: Option<AccessMode>,
218 ) -> Result<DatabaseTransaction, DbErr> {
219 let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
220 DatabaseTransaction::new_sqlite(
221 conn,
222 self.metric_callback.clone(),
223 isolation_level,
224 access_mode,
225 )
226 .await
227 }
228
229 #[instrument(level = "trace", skip(callback))]
231 pub async fn transaction<F, T, E>(
232 &self,
233 callback: F,
234 isolation_level: Option<IsolationLevel>,
235 access_mode: Option<AccessMode>,
236 ) -> Result<T, TransactionError<E>>
237 where
238 F: for<'b> FnOnce(
239 &'b DatabaseTransaction,
240 ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'b>>
241 + Send,
242 T: Send,
243 E: std::fmt::Display + std::fmt::Debug + Send,
244 {
245 let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
246 let transaction = DatabaseTransaction::new_sqlite(
247 conn,
248 self.metric_callback.clone(),
249 isolation_level,
250 access_mode,
251 )
252 .await
253 .map_err(|e| TransactionError::Connection(e))?;
254 transaction.run(callback).await
255 }
256
257 pub(crate) fn set_metric_callback<F>(&mut self, callback: F)
258 where
259 F: Fn(&crate::metric::Info<'_>) + Send + Sync + 'static,
260 {
261 self.metric_callback = Some(Arc::new(callback));
262 }
263
264 pub async fn ping(&self) -> Result<(), DbErr> {
266 let conn = &mut self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
267 match conn.ping().await {
268 Ok(_) => Ok(()),
269 Err(err) => Err(sqlx_error_to_conn_err(err)),
270 }
271 }
272
273 pub async fn close(self) -> Result<(), DbErr> {
276 self.close_by_ref().await
277 }
278
279 pub async fn close_by_ref(&self) -> Result<(), DbErr> {
281 self.pool.close().await;
282 Ok(())
283 }
284}
285
286impl From<SqliteRow> for QueryResult {
287 fn from(row: SqliteRow) -> QueryResult {
288 QueryResult {
289 row: QueryResultRow::SqlxSqlite(row),
290 }
291 }
292}
293
294impl From<SqliteQueryResult> for ExecResult {
295 fn from(result: SqliteQueryResult) -> ExecResult {
296 ExecResult {
297 result: ExecResultHolder::SqlxSqlite(result),
298 }
299 }
300}
301
302pub(crate) fn sqlx_query(stmt: &Statement) -> sqlx::query::Query<'_, Sqlite, SqlxValues> {
303 let values = stmt
304 .values
305 .as_ref()
306 .map_or(Values(Vec::new()), |values| values.clone());
307 sqlx::query_with(&stmt.sql, SqlxValues(values))
308}
309
310pub(crate) async fn set_transaction_config(
311 _conn: &mut PoolConnection<Sqlite>,
312 isolation_level: Option<IsolationLevel>,
313 access_mode: Option<AccessMode>,
314) -> Result<(), DbErr> {
315 if isolation_level.is_some() {
316 warn!("Setting isolation level in a SQLite transaction isn't supported");
317 }
318 if access_mode.is_some() {
319 warn!("Setting access mode in a SQLite transaction isn't supported");
320 }
321 Ok(())
322}
323
/// Reads the SQLite library version string via `SELECT sqlite_version()`.
///
/// # Errors
///
/// Returns `DbErr::Conn` if the query yields no row. Errors from the query and
/// from decoding column 0 are propagated.
#[cfg(feature = "sqlite-use-returning-for-3_35")]
async fn get_version(conn: &SqlxSqlitePoolConnection) -> Result<String, DbErr> {
    let version_query = Statement {
        sql: "SELECT sqlite_version()".to_string(),
        values: None,
        db_backend: crate::DbBackend::Sqlite,
    };
    match conn.query_one(version_query).await? {
        Some(row) => row.try_get_by(0),
        None => Err(DbErr::Conn(RuntimeErr::Internal(
            "Error reading SQLite version".to_string(),
        ))),
    }
}
340
341#[cfg(feature = "sqlite-use-returning-for-3_35")]
342fn ensure_returning_version(version: &str) -> Result<(), DbErr> {
343 let mut parts = version.trim().split('.').map(|part| {
344 part.parse::<u32>().map_err(|_| {
345 DbErr::Conn(RuntimeErr::Internal(
346 "Error parsing SQLite version".to_string(),
347 ))
348 })
349 });
350
351 let mut extract_next = || {
352 parts.next().transpose().and_then(|part| {
353 part.ok_or_else(|| {
354 DbErr::Conn(RuntimeErr::Internal("SQLite version too short".to_string()))
355 })
356 })
357 };
358
359 let major = extract_next()?;
360 let minor = extract_next()?;
361
362 if major > 3 || (major == 3 && minor >= 35) {
363 Ok(())
364 } else {
365 Err(DbErr::Conn(RuntimeErr::Internal(
366 "SQLite version does not support returning".to_string(),
367 )))
368 }
369}
370
371impl
372 From<(
373 PoolConnection<sqlx::Sqlite>,
374 Statement,
375 Option<crate::metric::Callback>,
376 )> for crate::QueryStream
377{
378 fn from(
379 (conn, stmt, metric_callback): (
380 PoolConnection<sqlx::Sqlite>,
381 Statement,
382 Option<crate::metric::Callback>,
383 ),
384 ) -> Self {
385 crate::QueryStream::build(stmt, crate::InnerConnection::Sqlite(conn), metric_callback)
386 }
387}
388
389impl crate::DatabaseTransaction {
390 pub(crate) async fn new_sqlite(
391 inner: PoolConnection<sqlx::Sqlite>,
392 metric_callback: Option<crate::metric::Callback>,
393 isolation_level: Option<IsolationLevel>,
394 access_mode: Option<AccessMode>,
395 ) -> Result<crate::DatabaseTransaction, DbErr> {
396 Self::begin(
397 Arc::new(Mutex::new(crate::InnerConnection::Sqlite(inner))),
398 crate::DbBackend::Sqlite,
399 metric_callback,
400 isolation_level,
401 access_mode,
402 )
403 .await
404 }
405}
406
/// Converts a sqlx SQLite row into a [`crate::ProxyRow`] of `(column name, value)`
/// pairs.
///
/// Each column's sqlx type name (`c.type_info().name()`) selects which
/// [`sea_query::Value`] variant its value is decoded into.
///
/// # Panics
///
/// Panics in two cases:
/// - a value cannot be decoded into the Rust type chosen for its column;
/// - the column's type name is not one of the names matched below. This includes
///   `DATE`, `TIME` and `DATETIME` when neither the `with-chrono` nor the
///   `with-time` feature is enabled.
#[cfg(feature = "proxy")]
pub(crate) fn from_sqlx_sqlite_row_to_proxy_row(row: &sqlx::sqlite::SqliteRow) -> crate::ProxyRow {
    use sea_query::Value;
    use sqlx::{Column, Row, TypeInfo};
    crate::ProxyRow {
        values: row
            .columns()
            .iter()
            .map(|c| {
                let idx = c.ordinal();
                let type_name = c.type_info().name();
                let value = match type_name {
                    "BOOLEAN" => Value::Bool(row.try_get(idx).expect("Failed to get boolean")),

                    // NOTE(review): `Value::Int` holds an `i32`, while SQLite INTEGER
                    // values can be 64-bit — confirm large integers cannot reach here.
                    "INTEGER" => Value::Int(row.try_get(idx).expect("Failed to get integer")),

                    "BIGINT" | "INT8" => {
                        Value::BigInt(row.try_get(idx).expect("Failed to get big integer"))
                    }

                    "REAL" => Value::Double(row.try_get(idx).expect("Failed to get double")),

                    "TEXT" => Value::String(
                        row.try_get::<Option<String>, _>(idx)
                            .expect("Failed to get string")
                            .map(Box::new),
                    ),

                    "BLOB" => Value::Bytes(
                        row.try_get::<Option<Vec<u8>>, _>(idx)
                            .expect("Failed to get bytes")
                            .map(Box::new),
                    ),

                    #[cfg(feature = "with-chrono")]
                    "DATETIME" => {
                        use chrono::{DateTime, Utc};

                        Value::ChronoDateTimeUtc(
                            row.try_get::<Option<DateTime<Utc>>, _>(idx)
                                .expect("Failed to get timestamp")
                                .map(Box::new),
                        )
                    }
                    #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                    "DATETIME" => {
                        use time::OffsetDateTime;
                        Value::TimeDateTimeWithTimeZone(
                            row.try_get::<Option<OffsetDateTime>, _>(idx)
                                .expect("Failed to get timestamp")
                                .map(Box::new),
                        )
                    }
                    #[cfg(feature = "with-chrono")]
                    "DATE" => {
                        use chrono::NaiveDate;
                        Value::ChronoDate(
                            row.try_get::<Option<NaiveDate>, _>(idx)
                                .expect("Failed to get date")
                                .map(Box::new),
                        )
                    }
                    #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                    "DATE" => {
                        use time::Date;
                        Value::TimeDate(
                            row.try_get::<Option<Date>, _>(idx)
                                .expect("Failed to get date")
                                .map(Box::new),
                        )
                    }

                    #[cfg(feature = "with-chrono")]
                    "TIME" => {
                        use chrono::NaiveTime;
                        Value::ChronoTime(
                            row.try_get::<Option<NaiveTime>, _>(idx)
                                .expect("Failed to get time")
                                .map(Box::new),
                        )
                    }
                    #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                    "TIME" => {
                        use time::Time;
                        Value::TimeTime(
                            row.try_get::<Option<Time>, _>(idx)
                                .expect("Failed to get time")
                                .map(Box::new),
                        )
                    }

                    // Reachable for any type name not listed above, and for the
                    // date/time names when their features are off. `panic!` states
                    // that honestly; `unreachable!` claimed it could not happen.
                    _ => panic!("Unknown column type: {}", type_name),
                };
                (c.name().to_string(), value)
            })
            .collect(),
    }
}
513
#[cfg(all(test, feature = "sqlite-use-returning-for-3_35"))]
mod tests {
    use super::*;

    #[test]
    fn test_ensure_returning_version() {
        // Empty, non-numeric or truncated version strings.
        let malformed = ["", ".", ".a", ".4.9", "a", "1.", "1.a"];
        // Well-formed enough to read major/minor, but older than 3.35.
        let too_old = ["1.1", "1.0.", "1.0.0", "2.0.0", "3.34.0", "3.34.999"];
        // 3.35 and later.
        let supported = ["3.35.0", "3.35.1", "3.36.0", "4.0.0", "99.0.0"];

        for version in malformed.iter().chain(too_old.iter()) {
            assert!(
                ensure_returning_version(version).is_err(),
                "expected {:?} to be rejected",
                version
            );
        }
        for version in supported.iter() {
            assert!(
                ensure_returning_version(version).is_ok(),
                "expected {:?} to be accepted",
                version
            );
        }
    }
}