1use futures_util::lock::Mutex;
2use log::LevelFilter;
3use sea_query::Values;
4use std::{future::Future, pin::Pin, sync::Arc};
5
6use sqlx::{
7 Connection, Executor, Sqlite, SqlitePool,
8 pool::PoolConnection,
9 sqlite::{SqliteConnectOptions, SqliteQueryResult, SqliteRow},
10};
11
12use sea_query_sqlx::SqlxValues;
13use tracing::{instrument, warn};
14
15use crate::{
16 AccessMode, ConnectOptions, DatabaseConnection, DatabaseConnectionType, DatabaseTransaction,
17 IsolationLevel, QueryStream, Statement, TransactionError, debug_print, error::*, executor::*,
18 sqlx_error_to_exec_err,
19};
20
21use super::sqlx_common::*;
22
/// Entry point for opening SQLite databases through sqlx.
///
/// Stateless: all functionality lives in associated functions such as
/// `accepts`, `connect` and `from_sqlx_sqlite_pool`.
#[derive(Debug)]
pub struct SqlxSqliteConnector;
26
/// A sqlx SQLite connection pool plus an optional metric callback.
///
/// Cloning is cheap in the sense that it clones the pool handle and the
/// (optional) shared callback; it does not open new connections by itself.
#[derive(Clone)]
pub struct SqlxSqlitePoolConnection {
    // The underlying sqlx pool; visible crate-wide so other modules can reach it.
    pub(crate) pool: SqlitePool,
    // Passed to `metric!` by `execute`/`query_one`/`query_all`, and cloned into
    // streams and transactions. `None` unless `set_metric_callback` was called.
    metric_callback: Option<crate::metric::Callback>,
}
33
34impl std::fmt::Debug for SqlxSqlitePoolConnection {
35 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36 write!(f, "SqlxSqlitePoolConnection {{ pool: {:?} }}", self.pool)
37 }
38}
39
40impl From<SqlitePool> for SqlxSqlitePoolConnection {
41 fn from(pool: SqlitePool) -> Self {
42 SqlxSqlitePoolConnection {
43 pool,
44 metric_callback: None,
45 }
46 }
47}
48
49impl From<SqlitePool> for DatabaseConnection {
50 fn from(pool: SqlitePool) -> Self {
51 DatabaseConnectionType::SqlxSqlitePoolConnection(pool.into()).into()
52 }
53}
54
55impl SqlxSqliteConnector {
56 pub fn accepts(string: &str) -> bool {
58 string.starts_with("sqlite:") && string.parse::<SqliteConnectOptions>().is_ok()
59 }
60
61 #[instrument(level = "trace")]
63 pub async fn connect(options: ConnectOptions) -> Result<DatabaseConnection, DbErr> {
64 let mut options = options;
65 let mut sqlx_opts = options
66 .url
67 .parse::<SqliteConnectOptions>()
68 .map_err(sqlx_error_to_conn_err)?;
69 if let Some(sqlcipher_key) = &options.sqlcipher_key {
70 sqlx_opts = sqlx_opts.pragma("key", sqlcipher_key.clone());
71 }
72 use sqlx::ConnectOptions;
73 if !options.sqlx_logging {
74 sqlx_opts = sqlx_opts.disable_statement_logging();
75 } else {
76 sqlx_opts = sqlx_opts.log_statements(options.sqlx_logging_level);
77 if options.sqlx_slow_statements_logging_level != LevelFilter::Off {
78 sqlx_opts = sqlx_opts.log_slow_statements(
79 options.sqlx_slow_statements_logging_level,
80 options.sqlx_slow_statements_logging_threshold,
81 );
82 }
83 }
84
85 if options.get_max_connections().is_none() {
86 options.max_connections(1);
87 }
88
89 if let Some(f) = &options.sqlite_opts_fn {
90 sqlx_opts = f(sqlx_opts);
91 }
92
93 let pool = if options.connect_lazy {
94 options.sqlx_pool_options().connect_lazy_with(sqlx_opts)
95 } else {
96 options
97 .sqlx_pool_options()
98 .connect_with(sqlx_opts)
99 .await
100 .map_err(sqlx_error_to_conn_err)?
101 };
102
103 let pool = SqlxSqlitePoolConnection {
104 pool,
105 metric_callback: None,
106 };
107
108 #[cfg(feature = "sqlite-use-returning-for-3_35")]
109 {
110 let version = get_version(&pool).await?;
111 ensure_returning_version(&version)?;
112 }
113
114 Ok(DatabaseConnectionType::SqlxSqlitePoolConnection(pool).into())
115 }
116}
117
118impl SqlxSqliteConnector {
119 pub fn from_sqlx_sqlite_pool(pool: SqlitePool) -> DatabaseConnection {
121 DatabaseConnectionType::SqlxSqlitePoolConnection(SqlxSqlitePoolConnection {
122 pool,
123 metric_callback: None,
124 })
125 .into()
126 }
127}
128
129impl SqlxSqlitePoolConnection {
130 #[instrument(level = "trace")]
132 pub async fn execute(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
133 debug_print!("{}", stmt);
134
135 let query = sqlx_query(&stmt);
136 let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
137 crate::metric::metric!(self.metric_callback, &stmt, {
138 match query.execute(&mut *conn).await {
139 Ok(res) => Ok(res.into()),
140 Err(err) => Err(sqlx_error_to_exec_err(err)),
141 }
142 })
143 }
144
145 #[instrument(level = "trace")]
147 pub async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
148 debug_print!("{}", sql);
149
150 let conn = &mut self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
151 match conn.execute(sql).await {
152 Ok(res) => Ok(res.into()),
153 Err(err) => Err(sqlx_error_to_exec_err(err)),
154 }
155 }
156
157 #[instrument(level = "trace")]
159 pub async fn query_one(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
160 debug_print!("{}", stmt);
161
162 let query = sqlx_query(&stmt);
163 let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
164 crate::metric::metric!(self.metric_callback, &stmt, {
165 match query.fetch_one(&mut *conn).await {
166 Ok(row) => Ok(Some(row.into())),
167 Err(err) => match err {
168 sqlx::Error::RowNotFound => Ok(None),
169 _ => Err(sqlx_error_to_query_err(err)),
170 },
171 }
172 })
173 }
174
175 #[instrument(level = "trace")]
177 pub async fn query_all(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
178 debug_print!("{}", stmt);
179
180 let query = sqlx_query(&stmt);
181 let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
182 crate::metric::metric!(self.metric_callback, &stmt, {
183 match query.fetch_all(&mut *conn).await {
184 Ok(rows) => Ok(rows.into_iter().map(|r| r.into()).collect()),
185 Err(err) => Err(sqlx_error_to_query_err(err)),
186 }
187 })
188 }
189
190 #[instrument(level = "trace")]
192 pub async fn stream(&self, stmt: Statement) -> Result<QueryStream, DbErr> {
193 debug_print!("{}", stmt);
194
195 let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
196 Ok(QueryStream::from((
197 conn,
198 stmt,
199 self.metric_callback.clone(),
200 )))
201 }
202
203 #[instrument(level = "trace")]
205 pub async fn begin(
206 &self,
207 isolation_level: Option<IsolationLevel>,
208 access_mode: Option<AccessMode>,
209 ) -> Result<DatabaseTransaction, DbErr> {
210 let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
211 DatabaseTransaction::new_sqlite(
212 conn,
213 self.metric_callback.clone(),
214 isolation_level,
215 access_mode,
216 )
217 .await
218 }
219
220 #[instrument(level = "trace", skip(callback))]
222 pub async fn transaction<F, T, E>(
223 &self,
224 callback: F,
225 isolation_level: Option<IsolationLevel>,
226 access_mode: Option<AccessMode>,
227 ) -> Result<T, TransactionError<E>>
228 where
229 F: for<'b> FnOnce(
230 &'b DatabaseTransaction,
231 ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'b>>
232 + Send,
233 T: Send,
234 E: std::fmt::Display + std::fmt::Debug + Send,
235 {
236 let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
237 let transaction = DatabaseTransaction::new_sqlite(
238 conn,
239 self.metric_callback.clone(),
240 isolation_level,
241 access_mode,
242 )
243 .await
244 .map_err(|e| TransactionError::Connection(e))?;
245 transaction.run(callback).await
246 }
247
248 pub(crate) fn set_metric_callback<F>(&mut self, callback: F)
249 where
250 F: Fn(&crate::metric::Info<'_>) + Send + Sync + 'static,
251 {
252 self.metric_callback = Some(Arc::new(callback));
253 }
254
255 pub async fn ping(&self) -> Result<(), DbErr> {
257 let conn = &mut self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
258 match conn.ping().await {
259 Ok(_) => Ok(()),
260 Err(err) => Err(sqlx_error_to_conn_err(err)),
261 }
262 }
263
264 pub async fn close(self) -> Result<(), DbErr> {
267 self.close_by_ref().await
268 }
269
270 pub async fn close_by_ref(&self) -> Result<(), DbErr> {
272 self.pool.close().await;
273 Ok(())
274 }
275}
276
277impl From<SqliteRow> for QueryResult {
278 fn from(row: SqliteRow) -> QueryResult {
279 QueryResult {
280 row: QueryResultRow::SqlxSqlite(row),
281 }
282 }
283}
284
285impl From<SqliteQueryResult> for ExecResult {
286 fn from(result: SqliteQueryResult) -> ExecResult {
287 ExecResult {
288 result: ExecResultHolder::SqlxSqlite(result),
289 }
290 }
291}
292
293pub(crate) fn sqlx_query(stmt: &Statement) -> sqlx::query::Query<'_, Sqlite, SqlxValues> {
294 let values = stmt
295 .values
296 .as_ref()
297 .map_or(Values(Vec::new()), |values| values.clone());
298 sqlx::query_with(&stmt.sql, SqlxValues(values))
299}
300
301pub(crate) async fn set_transaction_config(
302 _conn: &mut PoolConnection<Sqlite>,
303 isolation_level: Option<IsolationLevel>,
304 access_mode: Option<AccessMode>,
305) -> Result<(), DbErr> {
306 if isolation_level.is_some() {
307 warn!("Setting isolation level in a SQLite transaction isn't supported");
308 }
309 if access_mode.is_some() {
310 warn!("Setting access mode in a SQLite transaction isn't supported");
311 }
312 Ok(())
313}
314
/// Queries `sqlite_version()` through `conn` and returns it as a string.
///
/// An empty result is reported as an internal connection error.
#[cfg(feature = "sqlite-use-returning-for-3_35")]
async fn get_version(conn: &SqlxSqlitePoolConnection) -> Result<String, DbErr> {
    let stmt = Statement {
        sql: String::from("SELECT sqlite_version()"),
        values: None,
        db_backend: crate::DbBackend::Sqlite,
    };
    let row = conn.query_one(stmt).await?.ok_or_else(|| {
        DbErr::Conn(RuntimeErr::Internal(
            "Error reading SQLite version".to_string(),
        ))
    })?;
    row.try_get_by(0)
}
331
/// Checks that a dotted SQLite version string (e.g. `"3.35.0"`) is at least 3.35.
///
/// Only the first two dot-separated components are parsed, after trimming
/// surrounding whitespace. Anything past the minor component is ignored.
///
/// # Errors
///
/// Returns an internal connection error if:
/// - a parsed component is not a `u32`;
/// - the minor component is missing;
/// - the version is older than 3.35.
#[cfg(feature = "sqlite-use-returning-for-3_35")]
fn ensure_returning_version(version: &str) -> Result<(), DbErr> {
    fn internal(msg: &str) -> DbErr {
        DbErr::Conn(RuntimeErr::Internal(msg.to_string()))
    }

    let mut components = version.trim().split('.');
    let mut next_number = || -> Result<u32, DbErr> {
        let raw = components
            .next()
            .ok_or_else(|| internal("SQLite version too short"))?;
        raw.parse::<u32>()
            .map_err(|_| internal("Error parsing SQLite version"))
    };

    let major = next_number()?;
    let minor = next_number()?;

    // Lexicographic tuple comparison: major > 3, or major == 3 with minor >= 35.
    if (major, minor) >= (3, 35) {
        Ok(())
    } else {
        Err(internal("SQLite version does not support returning"))
    }
}
361
362impl
363 From<(
364 PoolConnection<sqlx::Sqlite>,
365 Statement,
366 Option<crate::metric::Callback>,
367 )> for crate::QueryStream
368{
369 fn from(
370 (conn, stmt, metric_callback): (
371 PoolConnection<sqlx::Sqlite>,
372 Statement,
373 Option<crate::metric::Callback>,
374 ),
375 ) -> Self {
376 crate::QueryStream::build(stmt, crate::InnerConnection::Sqlite(conn), metric_callback)
377 }
378}
379
380impl crate::DatabaseTransaction {
381 pub(crate) async fn new_sqlite(
382 inner: PoolConnection<sqlx::Sqlite>,
383 metric_callback: Option<crate::metric::Callback>,
384 isolation_level: Option<IsolationLevel>,
385 access_mode: Option<AccessMode>,
386 ) -> Result<crate::DatabaseTransaction, DbErr> {
387 Self::begin(
388 Arc::new(Mutex::new(crate::InnerConnection::Sqlite(inner))),
389 crate::DbBackend::Sqlite,
390 metric_callback,
391 isolation_level,
392 access_mode,
393 )
394 .await
395 }
396}
397
/// Converts a sqlx SQLite row into a [`crate::ProxyRow`], keyed by column name.
///
/// Each column is decoded according to the type name reported by
/// `c.type_info().name()`. SQL `NULL` cells decode to the `None` form of the
/// matching `Value` variant. Previously they failed the non-optional decode and
/// panicked through `expect`.
///
/// # Panics
///
/// Panics if:
/// - a column reports a type name not listed below;
/// - a non-`NULL` value cannot be decoded as the Rust type chosen for its type name.
#[cfg(feature = "proxy")]
pub(crate) fn from_sqlx_sqlite_row_to_proxy_row(row: &sqlx::sqlite::SqliteRow) -> crate::ProxyRow {
    use sea_query::Value;
    use sqlx::{Column, Row, TypeInfo};
    crate::ProxyRow {
        values: row
            .columns()
            .iter()
            .map(|c| {
                (
                    c.name().to_string(),
                    // Every arm decodes as `Option<_>`. A declared column type
                    // (e.g. BOOLEAN) can still hold NULL, and decoding NULL into a
                    // non-optional type is an error in sqlx.
                    match c.type_info().name() {
                        "BOOLEAN" => Value::Bool(
                            row.try_get::<Option<_>, _>(c.ordinal())
                                .expect("Failed to get boolean"),
                        ),

                        "INTEGER" => Value::Int(
                            row.try_get::<Option<_>, _>(c.ordinal())
                                .expect("Failed to get integer"),
                        ),

                        "BIGINT" | "INT8" => Value::BigInt(
                            row.try_get::<Option<_>, _>(c.ordinal())
                                .expect("Failed to get big integer"),
                        ),

                        "REAL" => Value::Double(
                            row.try_get::<Option<_>, _>(c.ordinal())
                                .expect("Failed to get double"),
                        ),

                        "TEXT" => Value::String(
                            row.try_get::<Option<_>, _>(c.ordinal())
                                .expect("Failed to get string")
                                .map(Box::new),
                        ),

                        "BLOB" => Value::Bytes(
                            row.try_get::<Option<_>, _>(c.ordinal())
                                .expect("Failed to get bytes")
                                .map(Box::new),
                        ),

                        #[cfg(feature = "with-chrono")]
                        "DATETIME" => Value::ChronoDateTimeUtc(
                            row.try_get::<Option<_>, _>(c.ordinal())
                                .expect("Failed to get timestamp")
                                .map(Box::new),
                        ),
                        #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                        "DATETIME" => Value::TimeDateTimeWithTimeZone(
                            row.try_get::<Option<_>, _>(c.ordinal())
                                .expect("Failed to get timestamp")
                                .map(Box::new),
                        ),

                        #[cfg(feature = "with-chrono")]
                        "DATE" => Value::ChronoDate(
                            row.try_get::<Option<_>, _>(c.ordinal())
                                .expect("Failed to get date")
                                .map(Box::new),
                        ),
                        #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                        "DATE" => Value::TimeDate(
                            row.try_get::<Option<_>, _>(c.ordinal())
                                .expect("Failed to get date")
                                .map(Box::new),
                        ),

                        #[cfg(feature = "with-chrono")]
                        "TIME" => Value::ChronoTime(
                            row.try_get::<Option<_>, _>(c.ordinal())
                                .expect("Failed to get time")
                                .map(Box::new),
                        ),
                        #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                        "TIME" => Value::TimeTime(
                            row.try_get::<Option<_>, _>(c.ordinal())
                                .expect("Failed to get time")
                                .map(Box::new),
                        ),

                        _ => unreachable!("Unknown column type: {}", c.type_info().name()),
                    },
                )
            })
            .collect(),
    }
}
470
#[cfg(all(test, feature = "sqlite-use-returning-for-3_35"))]
mod tests {
    use super::*;

    #[test]
    fn test_ensure_returning_version() {
        // Malformed input: empty or non-numeric components.
        assert!(ensure_returning_version("").is_err());
        assert!(ensure_returning_version(".").is_err());
        assert!(ensure_returning_version(".a").is_err());
        assert!(ensure_returning_version(".4.9").is_err());
        assert!(ensure_returning_version("a").is_err());
        assert!(ensure_returning_version("1.").is_err());
        assert!(ensure_returning_version("1.a").is_err());

        // Well-formed but older than 3.35.
        assert!(ensure_returning_version("1.1").is_err());
        assert!(ensure_returning_version("1.0.").is_err());
        assert!(ensure_returning_version("1.0.0").is_err());
        assert!(ensure_returning_version("2.0.0").is_err());
        assert!(ensure_returning_version("3.34.0").is_err());
        assert!(ensure_returning_version("3.34.999").is_err());

        // 3.35 and newer.
        assert!(ensure_returning_version("3.35.0").is_ok());
        assert!(ensure_returning_version("3.35.1").is_ok());
        assert!(ensure_returning_version("3.36.0").is_ok());
        assert!(ensure_returning_version("4.0.0").is_ok());
        assert!(ensure_returning_version("99.0.0").is_ok());
    }

    #[test]
    fn test_ensure_returning_version_edge_cases() {
        // A lone major component is rejected as too short, even when it is large.
        assert!(ensure_returning_version("3").is_err());
        assert!(ensure_returning_version("4").is_err());

        // Two components are enough; only major/minor are inspected.
        assert!(ensure_returning_version("3.34").is_err());
        assert!(ensure_returning_version("3.35").is_ok());

        // Surrounding whitespace (e.g. a trailing newline) is trimmed.
        assert!(ensure_returning_version(" 3.35.0\n").is_ok());
        assert!(ensure_returning_version("\t3.34.9 ").is_err());

        // Negative components do not parse as u32.
        assert!(ensure_returning_version("-3.35").is_err());
    }
}