1use futures_util::lock::Mutex;
2use log::LevelFilter;
3use sea_query::Values;
4use std::{future::Future, pin::Pin, sync::Arc};
5
6use sqlx::{
7 Connection, Executor, Sqlite, SqlitePool,
8 pool::PoolConnection,
9 sqlite::{SqliteConnectOptions, SqliteQueryResult, SqliteRow},
10};
11
12use sea_query_sqlx::SqlxValues;
13use tracing::{instrument, warn};
14
15use crate::{
16 AccessMode, ConnectOptions, DatabaseConnection, DatabaseConnectionType, DatabaseTransaction,
17 IsolationLevel, QueryStream, Statement, TransactionError, debug_print, error::*, executor::*,
18 sqlx_error_to_exec_err,
19};
20
21use super::sqlx_common::*;
22
/// Entry point for building SQLite [`DatabaseConnection`]s backed by an `sqlx` [`SqlitePool`].
///
/// Stateless marker type; see its associated functions `accepts`, `connect` and
/// `from_sqlx_sqlite_pool`.
#[derive(Debug)]
pub struct SqlxSqliteConnector;
26
/// A pooled SQLite connection, used as the backing of a [`DatabaseConnection`].
#[derive(Clone)]
pub struct SqlxSqlitePoolConnection {
    /// The underlying `sqlx` SQLite pool; each operation acquires a connection from it.
    pub(crate) pool: SqlitePool,
    /// Optional callback handed to `crate::metric::metric!` and to the streams and
    /// transactions created from this pool; installed via `set_metric_callback`.
    metric_callback: Option<crate::metric::Callback>,
}
33
34impl std::fmt::Debug for SqlxSqlitePoolConnection {
35 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36 write!(f, "SqlxSqlitePoolConnection {{ pool: {:?} }}", self.pool)
37 }
38}
39
40impl From<SqlitePool> for SqlxSqlitePoolConnection {
41 fn from(pool: SqlitePool) -> Self {
42 SqlxSqlitePoolConnection {
43 pool,
44 metric_callback: None,
45 }
46 }
47}
48
49impl From<SqlitePool> for DatabaseConnection {
50 fn from(pool: SqlitePool) -> Self {
51 DatabaseConnectionType::SqlxSqlitePoolConnection(pool.into()).into()
52 }
53}
54
55impl SqlxSqliteConnector {
56 pub fn accepts(string: &str) -> bool {
58 string.starts_with("sqlite:") && string.parse::<SqliteConnectOptions>().is_ok()
59 }
60
61 #[instrument(level = "trace")]
63 pub async fn connect(options: ConnectOptions) -> Result<DatabaseConnection, DbErr> {
64 let mut options = options;
65 let mut opt = options
66 .url
67 .parse::<SqliteConnectOptions>()
68 .map_err(sqlx_error_to_conn_err)?;
69 if let Some(sqlcipher_key) = &options.sqlcipher_key {
70 opt = opt.pragma("key", sqlcipher_key.clone());
71 }
72 use sqlx::ConnectOptions;
73 if !options.sqlx_logging {
74 opt = opt.disable_statement_logging();
75 } else {
76 opt = opt.log_statements(options.sqlx_logging_level);
77 if options.sqlx_slow_statements_logging_level != LevelFilter::Off {
78 opt = opt.log_slow_statements(
79 options.sqlx_slow_statements_logging_level,
80 options.sqlx_slow_statements_logging_threshold,
81 );
82 }
83 }
84
85 if options.get_max_connections().is_none() {
86 options.max_connections(1);
87 }
88
89 let pool = if options.connect_lazy {
90 options.sqlx_pool_options().connect_lazy_with(opt)
91 } else {
92 options
93 .sqlx_pool_options()
94 .connect_with(opt)
95 .await
96 .map_err(sqlx_error_to_conn_err)?
97 };
98
99 let pool = SqlxSqlitePoolConnection {
100 pool,
101 metric_callback: None,
102 };
103
104 #[cfg(feature = "sqlite-use-returning-for-3_35")]
105 {
106 let version = get_version(&pool).await?;
107 ensure_returning_version(&version)?;
108 }
109
110 Ok(DatabaseConnectionType::SqlxSqlitePoolConnection(pool).into())
111 }
112}
113
114impl SqlxSqliteConnector {
115 pub fn from_sqlx_sqlite_pool(pool: SqlitePool) -> DatabaseConnection {
117 DatabaseConnectionType::SqlxSqlitePoolConnection(SqlxSqlitePoolConnection {
118 pool,
119 metric_callback: None,
120 })
121 .into()
122 }
123}
124
125impl SqlxSqlitePoolConnection {
126 #[instrument(level = "trace")]
128 pub async fn execute(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
129 debug_print!("{}", stmt);
130
131 let query = sqlx_query(&stmt);
132 let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
133 crate::metric::metric!(self.metric_callback, &stmt, {
134 match query.execute(&mut *conn).await {
135 Ok(res) => Ok(res.into()),
136 Err(err) => Err(sqlx_error_to_exec_err(err)),
137 }
138 })
139 }
140
141 #[instrument(level = "trace")]
143 pub async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
144 debug_print!("{}", sql);
145
146 let conn = &mut self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
147 match conn.execute(sql).await {
148 Ok(res) => Ok(res.into()),
149 Err(err) => Err(sqlx_error_to_exec_err(err)),
150 }
151 }
152
153 #[instrument(level = "trace")]
155 pub async fn query_one(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
156 debug_print!("{}", stmt);
157
158 let query = sqlx_query(&stmt);
159 let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
160 crate::metric::metric!(self.metric_callback, &stmt, {
161 match query.fetch_one(&mut *conn).await {
162 Ok(row) => Ok(Some(row.into())),
163 Err(err) => match err {
164 sqlx::Error::RowNotFound => Ok(None),
165 _ => Err(sqlx_error_to_query_err(err)),
166 },
167 }
168 })
169 }
170
171 #[instrument(level = "trace")]
173 pub async fn query_all(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
174 debug_print!("{}", stmt);
175
176 let query = sqlx_query(&stmt);
177 let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
178 crate::metric::metric!(self.metric_callback, &stmt, {
179 match query.fetch_all(&mut *conn).await {
180 Ok(rows) => Ok(rows.into_iter().map(|r| r.into()).collect()),
181 Err(err) => Err(sqlx_error_to_query_err(err)),
182 }
183 })
184 }
185
186 #[instrument(level = "trace")]
188 pub async fn stream(&self, stmt: Statement) -> Result<QueryStream, DbErr> {
189 debug_print!("{}", stmt);
190
191 let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
192 Ok(QueryStream::from((
193 conn,
194 stmt,
195 self.metric_callback.clone(),
196 )))
197 }
198
199 #[instrument(level = "trace")]
201 pub async fn begin(
202 &self,
203 isolation_level: Option<IsolationLevel>,
204 access_mode: Option<AccessMode>,
205 ) -> Result<DatabaseTransaction, DbErr> {
206 let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
207 DatabaseTransaction::new_sqlite(
208 conn,
209 self.metric_callback.clone(),
210 isolation_level,
211 access_mode,
212 )
213 .await
214 }
215
216 #[instrument(level = "trace", skip(callback))]
218 pub async fn transaction<F, T, E>(
219 &self,
220 callback: F,
221 isolation_level: Option<IsolationLevel>,
222 access_mode: Option<AccessMode>,
223 ) -> Result<T, TransactionError<E>>
224 where
225 F: for<'b> FnOnce(
226 &'b DatabaseTransaction,
227 ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'b>>
228 + Send,
229 T: Send,
230 E: std::fmt::Display + std::fmt::Debug + Send,
231 {
232 let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
233 let transaction = DatabaseTransaction::new_sqlite(
234 conn,
235 self.metric_callback.clone(),
236 isolation_level,
237 access_mode,
238 )
239 .await
240 .map_err(|e| TransactionError::Connection(e))?;
241 transaction.run(callback).await
242 }
243
244 pub(crate) fn set_metric_callback<F>(&mut self, callback: F)
245 where
246 F: Fn(&crate::metric::Info<'_>) + Send + Sync + 'static,
247 {
248 self.metric_callback = Some(Arc::new(callback));
249 }
250
251 pub async fn ping(&self) -> Result<(), DbErr> {
253 let conn = &mut self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
254 match conn.ping().await {
255 Ok(_) => Ok(()),
256 Err(err) => Err(sqlx_error_to_conn_err(err)),
257 }
258 }
259
260 pub async fn close(self) -> Result<(), DbErr> {
263 self.close_by_ref().await
264 }
265
266 pub async fn close_by_ref(&self) -> Result<(), DbErr> {
268 self.pool.close().await;
269 Ok(())
270 }
271}
272
273impl From<SqliteRow> for QueryResult {
274 fn from(row: SqliteRow) -> QueryResult {
275 QueryResult {
276 row: QueryResultRow::SqlxSqlite(row),
277 }
278 }
279}
280
281impl From<SqliteQueryResult> for ExecResult {
282 fn from(result: SqliteQueryResult) -> ExecResult {
283 ExecResult {
284 result: ExecResultHolder::SqlxSqlite(result),
285 }
286 }
287}
288
289pub(crate) fn sqlx_query(stmt: &Statement) -> sqlx::query::Query<'_, Sqlite, SqlxValues> {
290 let values = stmt
291 .values
292 .as_ref()
293 .map_or(Values(Vec::new()), |values| values.clone());
294 sqlx::query_with(&stmt.sql, SqlxValues(values))
295}
296
297pub(crate) async fn set_transaction_config(
298 _conn: &mut PoolConnection<Sqlite>,
299 isolation_level: Option<IsolationLevel>,
300 access_mode: Option<AccessMode>,
301) -> Result<(), DbErr> {
302 if isolation_level.is_some() {
303 warn!("Setting isolation level in a SQLite transaction isn't supported");
304 }
305 if access_mode.is_some() {
306 warn!("Setting access mode in a SQLite transaction isn't supported");
307 }
308 Ok(())
309}
310
/// Reads the SQLite library version string via `SELECT sqlite_version()`.
///
/// # Errors
///
/// Returns a [`DbErr`] if the query fails, returns no row, or the first column cannot be read
/// as a `String`.
#[cfg(feature = "sqlite-use-returning-for-3_35")]
async fn get_version(conn: &SqlxSqlitePoolConnection) -> Result<String, DbErr> {
    let stmt = Statement {
        sql: String::from("SELECT sqlite_version()"),
        values: None,
        db_backend: crate::DbBackend::Sqlite,
    };
    match conn.query_one(stmt).await? {
        Some(row) => row.try_get_by(0),
        None => Err(DbErr::Conn(RuntimeErr::Internal(
            "Error reading SQLite version".to_string(),
        ))),
    }
}
327
/// Checks that a SQLite version string (`"major.minor[.patch...]"`) is at least 3.35, the
/// first release supporting `RETURNING`.
///
/// Only the first two dot-separated components are inspected; surrounding whitespace is
/// trimmed first.
///
/// # Errors
///
/// - "Error parsing SQLite version" if the major or minor component is not a `u32`;
/// - "SQLite version too short" if the minor component is missing;
/// - "SQLite version does not support returning" if the version is below 3.35.
#[cfg(feature = "sqlite-use-returning-for-3_35")]
fn ensure_returning_version(version: &str) -> Result<(), DbErr> {
    let mut components = version.trim().split('.');

    // Pulls the next component: a missing one is "too short", a non-numeric one is a parse error.
    let mut next_number = || -> Result<u32, DbErr> {
        let component = components.next().ok_or_else(|| {
            DbErr::Conn(RuntimeErr::Internal("SQLite version too short".to_string()))
        })?;
        component.parse::<u32>().map_err(|_| {
            DbErr::Conn(RuntimeErr::Internal(
                "Error parsing SQLite version".to_string(),
            ))
        })
    };

    let major = next_number()?;
    let minor = next_number()?;

    // Lexicographic tuple order: major > 3, or major == 3 with minor >= 35.
    if (major, minor) >= (3, 35) {
        Ok(())
    } else {
        Err(DbErr::Conn(RuntimeErr::Internal(
            "SQLite version does not support returning".to_string(),
        )))
    }
}
357
358impl
359 From<(
360 PoolConnection<sqlx::Sqlite>,
361 Statement,
362 Option<crate::metric::Callback>,
363 )> for crate::QueryStream
364{
365 fn from(
366 (conn, stmt, metric_callback): (
367 PoolConnection<sqlx::Sqlite>,
368 Statement,
369 Option<crate::metric::Callback>,
370 ),
371 ) -> Self {
372 crate::QueryStream::build(stmt, crate::InnerConnection::Sqlite(conn), metric_callback)
373 }
374}
375
376impl crate::DatabaseTransaction {
377 pub(crate) async fn new_sqlite(
378 inner: PoolConnection<sqlx::Sqlite>,
379 metric_callback: Option<crate::metric::Callback>,
380 isolation_level: Option<IsolationLevel>,
381 access_mode: Option<AccessMode>,
382 ) -> Result<crate::DatabaseTransaction, DbErr> {
383 Self::begin(
384 Arc::new(Mutex::new(crate::InnerConnection::Sqlite(inner))),
385 crate::DbBackend::Sqlite,
386 metric_callback,
387 isolation_level,
388 access_mode,
389 )
390 .await
391 }
392}
393
/// Converts an `sqlx` SQLite row into a [`crate::ProxyRow`] of `(column name, Value)` pairs.
///
/// The `sea_query::Value` variant is chosen from the name of each column's SQLite type info:
/// `BOOLEAN`, `INTEGER` (decoded as `i32`), `BIGINT`/`INT8`, `REAL`, `TEXT`, `BLOB`, and — with
/// the `with-chrono` feature (preferred) or otherwise the `with-time` feature — `DATETIME`,
/// `DATE` and `TIME`.
///
/// Every column is decoded as `Option<T>`, so a SQL `NULL` becomes the matching variant holding
/// `None`. Previously each column was decoded as a non-nullable `T`, so NULL could not be
/// represented at all.
///
/// # Panics
///
/// Panics if a non-NULL value cannot be decoded as the Rust type selected for its column type,
/// or if the column's type name is not one of those listed above.
#[cfg(feature = "proxy")]
pub(crate) fn from_sqlx_sqlite_row_to_proxy_row(row: &sqlx::sqlite::SqliteRow) -> crate::ProxyRow {
    use sea_query::Value;
    use sqlx::{Column, Row, TypeInfo};

    /// Boxes the payload of a nullable value, matching the boxed `Value` variants.
    /// (A plain function lets the payload type be inferred from the target variant.)
    fn boxed<T>(value: Option<T>) -> Option<Box<T>> {
        value.map(Box::new)
    }

    crate::ProxyRow {
        values: row
            .columns()
            .iter()
            .map(|c| {
                let idx = c.ordinal();
                let value = match c.type_info().name() {
                    "BOOLEAN" => Value::Bool(row.try_get(idx).expect("Failed to get boolean")),

                    // NOTE(review): SQLite integers can be 64-bit, but this arm decodes into
                    // `Value::Int` (i32); kept as-is for compatibility — confirm whether
                    // `Value::BigInt` is intended here.
                    "INTEGER" => Value::Int(row.try_get(idx).expect("Failed to get integer")),

                    "BIGINT" | "INT8" => {
                        Value::BigInt(row.try_get(idx).expect("Failed to get big integer"))
                    }

                    "REAL" => Value::Double(row.try_get(idx).expect("Failed to get double")),

                    "TEXT" => Value::String(boxed(row.try_get(idx).expect("Failed to get string"))),

                    "BLOB" => Value::Bytes(boxed(row.try_get(idx).expect("Failed to get bytes"))),

                    #[cfg(feature = "with-chrono")]
                    "DATETIME" => Value::ChronoDateTimeUtc(boxed(
                        row.try_get(idx).expect("Failed to get timestamp"),
                    )),
                    #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                    "DATETIME" => Value::TimeDateTimeWithTimeZone(boxed(
                        row.try_get(idx).expect("Failed to get timestamp"),
                    )),

                    #[cfg(feature = "with-chrono")]
                    "DATE" => Value::ChronoDate(boxed(row.try_get(idx).expect("Failed to get date"))),
                    #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                    "DATE" => Value::TimeDate(boxed(row.try_get(idx).expect("Failed to get date"))),

                    #[cfg(feature = "with-chrono")]
                    "TIME" => Value::ChronoTime(boxed(row.try_get(idx).expect("Failed to get time"))),
                    #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                    "TIME" => Value::TimeTime(boxed(row.try_get(idx).expect("Failed to get time"))),

                    // NOTE(review): columns without a declared type (e.g. computed expressions)
                    // may report another type name here and hit this panic — confirm against the
                    // `sqlx-sqlite` type mapping.
                    _ => unreachable!("Unknown column type: {}", c.type_info().name()),
                };
                (c.name().to_string(), value)
            })
            .collect(),
    }
}
466
#[cfg(all(test, feature = "sqlite-use-returning-for-3_35"))]
mod tests {
    use super::*;

    #[test]
    fn test_ensure_returning_version() {
        // Strings whose major/minor components are missing or non-numeric.
        let malformed = ["", ".", ".a", ".4.9", "a", "1.", "1.a"];
        // Well-formed versions older than 3.35.
        let too_old = [
            "1.1", "1.0.", "1.0.0", "2.0.0", "3.34.0", "3.34.999",
        ];
        // Versions at or beyond 3.35, where `RETURNING` is available.
        let supported = ["3.35.0", "3.35.1", "3.36.0", "4.0.0", "99.0.0"];

        for version in malformed.iter().chain(too_old.iter()) {
            assert!(
                ensure_returning_version(version).is_err(),
                "expected {:?} to be rejected",
                version
            );
        }

        for version in supported.iter() {
            assert!(
                ensure_returning_version(version).is_ok(),
                "expected {:?} to be accepted",
                version
            );
        }
    }
}