1use futures_util::lock::Mutex;
2use log::LevelFilter;
3use sea_query::Values;
4use std::{future::Future, pin::Pin, sync::Arc};
5
6use sqlx::{
7 pool::PoolConnection,
8 sqlite::{SqliteConnectOptions, SqliteQueryResult, SqliteRow},
9 Connection, Executor, Sqlite, SqlitePool,
10};
11
12use sea_query_binder::SqlxValues;
13use tracing::{instrument, warn};
14
15use crate::{
16 debug_print, error::*, executor::*, sqlx_error_to_exec_err, AccessMode, ConnectOptions,
17 DatabaseConnection, DatabaseTransaction, IsolationLevel, QueryStream, Statement,
18 TransactionError,
19};
20
21use super::sqlx_common::*;
22
/// Marker type that groups the functions used to build SQLite-backed connections via `sqlx`.
#[derive(Debug)]
pub struct SqlxSqliteConnector;
26
27#[derive(Clone)]
29pub struct SqlxSqlitePoolConnection {
30 pub(crate) pool: SqlitePool,
31 metric_callback: Option<crate::metric::Callback>,
32}
33
34impl std::fmt::Debug for SqlxSqlitePoolConnection {
35 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36 write!(f, "SqlxSqlitePoolConnection {{ pool: {:?} }}", self.pool)
37 }
38}
39
40impl From<SqlitePool> for SqlxSqlitePoolConnection {
41 fn from(pool: SqlitePool) -> Self {
42 SqlxSqlitePoolConnection {
43 pool,
44 metric_callback: None,
45 }
46 }
47}
48
49impl From<SqlitePool> for DatabaseConnection {
50 fn from(pool: SqlitePool) -> Self {
51 DatabaseConnection::SqlxSqlitePoolConnection(pool.into())
52 }
53}
54
55impl SqlxSqliteConnector {
56 pub fn accepts(string: &str) -> bool {
58 string.starts_with("sqlite:") && string.parse::<SqliteConnectOptions>().is_ok()
59 }
60
61 #[instrument(level = "trace")]
63 pub async fn connect(options: ConnectOptions) -> Result<DatabaseConnection, DbErr> {
64 let mut options = options;
65 let mut opt = options
66 .url
67 .parse::<SqliteConnectOptions>()
68 .map_err(sqlx_error_to_conn_err)?;
69 if let Some(sqlcipher_key) = &options.sqlcipher_key {
70 opt = opt.pragma("key", sqlcipher_key.clone());
71 }
72 use sqlx::ConnectOptions;
73 if !options.sqlx_logging {
74 opt = opt.disable_statement_logging();
75 } else {
76 opt = opt.log_statements(options.sqlx_logging_level);
77 if options.sqlx_slow_statements_logging_level != LevelFilter::Off {
78 opt = opt.log_slow_statements(
79 options.sqlx_slow_statements_logging_level,
80 options.sqlx_slow_statements_logging_threshold,
81 );
82 }
83 }
84
85 if options.get_max_connections().is_none() {
86 options.max_connections(1);
87 }
88
89 let pool = if options.connect_lazy {
90 options.sqlx_pool_options().connect_lazy_with(opt)
91 } else {
92 options
93 .sqlx_pool_options()
94 .connect_with(opt)
95 .await
96 .map_err(sqlx_error_to_conn_err)?
97 };
98
99 let pool = SqlxSqlitePoolConnection {
100 pool,
101 metric_callback: None,
102 };
103
104 #[cfg(feature = "sqlite-use-returning-for-3_35")]
105 {
106 let version = get_version(&pool).await?;
107 ensure_returning_version(&version)?;
108 }
109
110 Ok(DatabaseConnection::SqlxSqlitePoolConnection(pool))
111 }
112}
113
114impl SqlxSqliteConnector {
115 pub fn from_sqlx_sqlite_pool(pool: SqlitePool) -> DatabaseConnection {
117 DatabaseConnection::SqlxSqlitePoolConnection(SqlxSqlitePoolConnection {
118 pool,
119 metric_callback: None,
120 })
121 }
122}
123
124impl SqlxSqlitePoolConnection {
125 #[instrument(level = "trace")]
127 pub async fn execute(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
128 debug_print!("{}", stmt);
129
130 let query = sqlx_query(&stmt);
131 let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
132 crate::metric::metric!(self.metric_callback, &stmt, {
133 match query.execute(&mut *conn).await {
134 Ok(res) => Ok(res.into()),
135 Err(err) => Err(sqlx_error_to_exec_err(err)),
136 }
137 })
138 }
139
140 #[instrument(level = "trace")]
142 pub async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
143 debug_print!("{}", sql);
144
145 let conn = &mut self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
146 match conn.execute(sql).await {
147 Ok(res) => Ok(res.into()),
148 Err(err) => Err(sqlx_error_to_exec_err(err)),
149 }
150 }
151
152 #[instrument(level = "trace")]
154 pub async fn query_one(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
155 debug_print!("{}", stmt);
156
157 let query = sqlx_query(&stmt);
158 let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
159 crate::metric::metric!(self.metric_callback, &stmt, {
160 match query.fetch_one(&mut *conn).await {
161 Ok(row) => Ok(Some(row.into())),
162 Err(err) => match err {
163 sqlx::Error::RowNotFound => Ok(None),
164 _ => Err(sqlx_error_to_query_err(err)),
165 },
166 }
167 })
168 }
169
170 #[instrument(level = "trace")]
172 pub async fn query_all(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
173 debug_print!("{}", stmt);
174
175 let query = sqlx_query(&stmt);
176 let mut conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
177 crate::metric::metric!(self.metric_callback, &stmt, {
178 match query.fetch_all(&mut *conn).await {
179 Ok(rows) => Ok(rows.into_iter().map(|r| r.into()).collect()),
180 Err(err) => Err(sqlx_error_to_query_err(err)),
181 }
182 })
183 }
184
185 #[instrument(level = "trace")]
187 pub async fn stream(&self, stmt: Statement) -> Result<QueryStream, DbErr> {
188 debug_print!("{}", stmt);
189
190 let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
191 Ok(QueryStream::from((
192 conn,
193 stmt,
194 self.metric_callback.clone(),
195 )))
196 }
197
198 #[instrument(level = "trace")]
200 pub async fn begin(
201 &self,
202 isolation_level: Option<IsolationLevel>,
203 access_mode: Option<AccessMode>,
204 ) -> Result<DatabaseTransaction, DbErr> {
205 let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
206 DatabaseTransaction::new_sqlite(
207 conn,
208 self.metric_callback.clone(),
209 isolation_level,
210 access_mode,
211 )
212 .await
213 }
214
215 #[instrument(level = "trace", skip(callback))]
217 pub async fn transaction<F, T, E>(
218 &self,
219 callback: F,
220 isolation_level: Option<IsolationLevel>,
221 access_mode: Option<AccessMode>,
222 ) -> Result<T, TransactionError<E>>
223 where
224 F: for<'b> FnOnce(
225 &'b DatabaseTransaction,
226 ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'b>>
227 + Send,
228 T: Send,
229 E: std::fmt::Display + std::fmt::Debug + Send,
230 {
231 let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
232 let transaction = DatabaseTransaction::new_sqlite(
233 conn,
234 self.metric_callback.clone(),
235 isolation_level,
236 access_mode,
237 )
238 .await
239 .map_err(|e| TransactionError::Connection(e))?;
240 transaction.run(callback).await
241 }
242
243 pub(crate) fn set_metric_callback<F>(&mut self, callback: F)
244 where
245 F: Fn(&crate::metric::Info<'_>) + Send + Sync + 'static,
246 {
247 self.metric_callback = Some(Arc::new(callback));
248 }
249
250 pub async fn ping(&self) -> Result<(), DbErr> {
252 let conn = &mut self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?;
253 match conn.ping().await {
254 Ok(_) => Ok(()),
255 Err(err) => Err(sqlx_error_to_conn_err(err)),
256 }
257 }
258
259 pub async fn close(self) -> Result<(), DbErr> {
262 self.close_by_ref().await
263 }
264
265 pub async fn close_by_ref(&self) -> Result<(), DbErr> {
267 self.pool.close().await;
268 Ok(())
269 }
270}
271
272impl From<SqliteRow> for QueryResult {
273 fn from(row: SqliteRow) -> QueryResult {
274 QueryResult {
275 row: QueryResultRow::SqlxSqlite(row),
276 }
277 }
278}
279
280impl From<SqliteQueryResult> for ExecResult {
281 fn from(result: SqliteQueryResult) -> ExecResult {
282 ExecResult {
283 result: ExecResultHolder::SqlxSqlite(result),
284 }
285 }
286}
287
288pub(crate) fn sqlx_query(stmt: &Statement) -> sqlx::query::Query<'_, Sqlite, SqlxValues> {
289 let values = stmt
290 .values
291 .as_ref()
292 .map_or(Values(Vec::new()), |values| values.clone());
293 sqlx::query_with(&stmt.sql, SqlxValues(values))
294}
295
296pub(crate) async fn set_transaction_config(
297 _conn: &mut PoolConnection<Sqlite>,
298 isolation_level: Option<IsolationLevel>,
299 access_mode: Option<AccessMode>,
300) -> Result<(), DbErr> {
301 if isolation_level.is_some() {
302 warn!("Setting isolation level in a SQLite transaction isn't supported");
303 }
304 if access_mode.is_some() {
305 warn!("Setting access mode in a SQLite transaction isn't supported");
306 }
307 Ok(())
308}
309
/// Queries `SELECT sqlite_version()` on `conn` and returns the first column of the result
/// as a string.
///
/// # Errors
///
/// Returns a [`DbErr`] when the query fails, when it yields no row
/// ("Error reading SQLite version"), or when the first column cannot be read as a `String`.
#[cfg(feature = "sqlite-use-returning-for-3_35")]
async fn get_version(conn: &SqlxSqlitePoolConnection) -> Result<String, DbErr> {
    let stmt = Statement {
        sql: "SELECT sqlite_version()".to_string(),
        values: None,
        db_backend: crate::DbBackend::Sqlite,
    };
    conn.query_one(stmt)
        .await?
        .ok_or_else(|| {
            DbErr::Conn(RuntimeErr::Internal(
                "Error reading SQLite version".to_string(),
            ))
        })?
        .try_get_by(0)
}
326
/// Checks that `version` (a dotted string such as `"3.35.0"`) is at least 3.35.
///
/// Only the first two dot-separated components are inspected; anything after the minor
/// component is never parsed. Surrounding whitespace is trimmed first.
///
/// # Errors
///
/// Returns `DbErr::Conn(RuntimeErr::Internal(..))` when:
/// - the major or minor component is not a valid `u32` ("Error parsing SQLite version"),
/// - fewer than two components are present ("SQLite version too short"),
/// - the version is older than 3.35 ("SQLite version does not support returning").
#[cfg(feature = "sqlite-use-returning-for-3_35")]
fn ensure_returning_version(version: &str) -> Result<(), DbErr> {
    let mut components = version.trim().split('.');

    // Pulls the next component and parses it. A missing component and an unparsable one
    // produce different messages.
    let mut next_component = || -> Result<u32, DbErr> {
        components
            .next()
            .ok_or_else(|| {
                DbErr::Conn(RuntimeErr::Internal("SQLite version too short".to_string()))
            })?
            .parse::<u32>()
            .map_err(|_| {
                DbErr::Conn(RuntimeErr::Internal(
                    "Error parsing SQLite version".to_string(),
                ))
            })
    };

    let major = next_component()?;
    let minor = next_component()?;

    if major > 3 || (major == 3 && minor >= 35) {
        Ok(())
    } else {
        Err(DbErr::Conn(RuntimeErr::Internal(
            "SQLite version does not support returning".to_string(),
        )))
    }
}
356
357impl
358 From<(
359 PoolConnection<sqlx::Sqlite>,
360 Statement,
361 Option<crate::metric::Callback>,
362 )> for crate::QueryStream
363{
364 fn from(
365 (conn, stmt, metric_callback): (
366 PoolConnection<sqlx::Sqlite>,
367 Statement,
368 Option<crate::metric::Callback>,
369 ),
370 ) -> Self {
371 crate::QueryStream::build(stmt, crate::InnerConnection::Sqlite(conn), metric_callback)
372 }
373}
374
375impl crate::DatabaseTransaction {
376 pub(crate) async fn new_sqlite(
377 inner: PoolConnection<sqlx::Sqlite>,
378 metric_callback: Option<crate::metric::Callback>,
379 isolation_level: Option<IsolationLevel>,
380 access_mode: Option<AccessMode>,
381 ) -> Result<crate::DatabaseTransaction, DbErr> {
382 Self::begin(
383 Arc::new(Mutex::new(crate::InnerConnection::Sqlite(inner))),
384 crate::DbBackend::Sqlite,
385 metric_callback,
386 isolation_level,
387 access_mode,
388 )
389 .await
390 }
391}
392
/// Converts a SQLite row into a [`crate::ProxyRow`].
///
/// Each column becomes a `(column name, Value)` pair. The `Value` variant is chosen by
/// matching on the column's reported type name (`"BOOLEAN"`, `"INTEGER"`, `"TEXT"`, ...).
/// The `"DATETIME"`, `"DATE"` and `"TIME"` arms exist only when the `with-chrono` feature,
/// or `with-time` without `with-chrono`, is enabled.
///
/// # Panics
///
/// - When a column value cannot be read as the Rust type selected for its type name
///   (the `expect` calls).
/// - When the type name matches none of the arms. Despite the `unreachable!`, that arm is
///   taken for every type name not listed below, including the date/time names when
///   neither date/time feature is enabled.
#[cfg(feature = "proxy")]
pub(crate) fn from_sqlx_sqlite_row_to_proxy_row(row: &sqlx::sqlite::SqliteRow) -> crate::ProxyRow {
    use sea_query::Value;
    use sqlx::{Column, Row, TypeInfo};
    crate::ProxyRow {
        values: row
            .columns()
            .iter()
            .map(|c| {
                (
                    c.name().to_string(),
                    match c.type_info().name() {
                        "BOOLEAN" => Value::Bool(Some(
                            row.try_get(c.ordinal()).expect("Failed to get boolean"),
                        )),

                        "INTEGER" => Value::Int(Some(
                            row.try_get(c.ordinal()).expect("Failed to get integer"),
                        )),

                        "BIGINT" | "INT8" => Value::BigInt(Some(
                            row.try_get(c.ordinal()).expect("Failed to get big integer"),
                        )),

                        "REAL" => Value::Double(Some(
                            row.try_get(c.ordinal()).expect("Failed to get double"),
                        )),

                        "TEXT" => Value::String(Some(Box::new(
                            row.try_get(c.ordinal()).expect("Failed to get string"),
                        ))),

                        "BLOB" => Value::Bytes(Some(Box::new(
                            row.try_get(c.ordinal()).expect("Failed to get bytes"),
                        ))),

                        // Date/time arms: chrono takes precedence when both features are on.
                        #[cfg(feature = "with-chrono")]
                        "DATETIME" => Value::ChronoDateTimeUtc(Some(Box::new(
                            row.try_get(c.ordinal()).expect("Failed to get timestamp"),
                        ))),
                        #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                        "DATETIME" => Value::TimeDateTimeWithTimeZone(Some(Box::new(
                            row.try_get(c.ordinal()).expect("Failed to get timestamp"),
                        ))),

                        #[cfg(feature = "with-chrono")]
                        "DATE" => Value::ChronoDate(Some(Box::new(
                            row.try_get(c.ordinal()).expect("Failed to get date"),
                        ))),
                        #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                        "DATE" => Value::TimeDate(Some(Box::new(
                            row.try_get(c.ordinal()).expect("Failed to get date"),
                        ))),

                        #[cfg(feature = "with-chrono")]
                        "TIME" => Value::ChronoTime(Some(Box::new(
                            row.try_get(c.ordinal()).expect("Failed to get time"),
                        ))),
                        #[cfg(all(feature = "with-time", not(feature = "with-chrono")))]
                        "TIME" => Value::TimeTime(Some(Box::new(
                            row.try_get(c.ordinal()).expect("Failed to get time"),
                        ))),

                        // NOTE(review): reachable for any unlisted type name; panics.
                        _ => unreachable!("Unknown column type: {}", c.type_info().name()),
                    },
                )
            })
            .collect(),
    }
}
465
#[cfg(all(test, feature = "sqlite-use-returning-for-3_35"))]
mod tests {
    use super::*;

    /// Exercises `ensure_returning_version` with malformed, too-old and accepted versions.
    #[test]
    fn test_ensure_returning_version() {
        // Malformed or incomplete version strings.
        assert!(ensure_returning_version("").is_err());
        assert!(ensure_returning_version(".").is_err());
        assert!(ensure_returning_version(".a").is_err());
        assert!(ensure_returning_version(".4.9").is_err());
        assert!(ensure_returning_version("a").is_err());
        assert!(ensure_returning_version("1.").is_err());
        assert!(ensure_returning_version("1.a").is_err());

        // Well-formed (or tolerated) versions older than 3.35.
        assert!(ensure_returning_version("1.1").is_err());
        assert!(ensure_returning_version("1.0.").is_err());
        assert!(ensure_returning_version("1.0.0").is_err());
        assert!(ensure_returning_version("2.0.0").is_err());
        assert!(ensure_returning_version("3.34.0").is_err());
        assert!(ensure_returning_version("3.34.999").is_err());

        // Versions at or above 3.35.
        assert!(ensure_returning_version("3.35.0").is_ok());
        assert!(ensure_returning_version("3.35.1").is_ok());
        assert!(ensure_returning_version("3.36.0").is_ok());
        assert!(ensure_returning_version("4.0.0").is_ok());
        assert!(ensure_returning_version("99.0.0").is_ok());
    }
}