1use std::borrow::Cow;
2use std::cmp::Ordering;
3use std::ffi::CStr;
4use std::fmt::Write;
5use std::fmt::{self, Debug, Formatter};
6use std::os::raw::{c_char, c_int, c_void};
7use std::panic::catch_unwind;
8use std::ptr;
9use std::ptr::NonNull;
10
11use futures_core::future::BoxFuture;
12use futures_intrusive::sync::MutexGuard;
13use futures_util::future;
14use libsqlite3_sys::{
15 sqlite3, sqlite3_commit_hook, sqlite3_get_autocommit, sqlite3_progress_handler,
16 sqlite3_rollback_hook, sqlite3_update_hook, SQLITE_DELETE, SQLITE_INSERT, SQLITE_UPDATE,
17};
18#[cfg(feature = "preupdate-hook")]
19pub use preupdate_hook::*;
20
21pub(crate) use handle::ConnectionHandle;
22use sqlx_core::common::StatementCache;
23pub(crate) use sqlx_core::connection::*;
24use sqlx_core::error::Error;
25use sqlx_core::executor::Executor;
26use sqlx_core::transaction::Transaction;
27
28use crate::connection::establish::EstablishParams;
29use crate::connection::worker::ConnectionWorker;
30use crate::options::OptimizeOnClose;
31use crate::statement::VirtualStatement;
32use crate::{Sqlite, SqliteConnectOptions, SqliteError};
33
34pub(crate) mod collation;
35pub(crate) mod describe;
36pub(crate) mod establish;
37pub(crate) mod execute;
38mod executor;
39mod explain;
40mod handle;
41pub(crate) mod intmap;
42#[cfg(feature = "preupdate-hook")]
43mod preupdate_hook;
44pub(crate) mod serialize;
45
46mod worker;
47
48pub struct SqliteConnection {
59 optimize_on_close: OptimizeOnClose,
60 pub(crate) worker: ConnectionWorker,
61 pub(crate) row_channel_size: usize,
62}
63
64pub struct LockedSqliteHandle<'a> {
65 pub(crate) guard: MutexGuard<'a, ConnectionState>,
66}
67
/// Owning pointer to the boxed closure installed as the SQLite progress handler.
///
/// The pointer is produced with `Box::into_raw` in
/// `LockedSqliteHandle::set_progress_handler` and converted back into a `Box` (and
/// thereby freed) by `ConnectionState::remove_progress_handler`.
pub(crate) struct Handler(NonNull<dyn FnMut() -> bool + Send + 'static>);

// `NonNull` is not `Send` by itself, so the impl has to be spelled out. The pointee is
// bounded by `Send`.
// NOTE(review): soundness also relies on this wrapper being the only owner of the
// pointee — confirm no other copies of the pointer outlive it.
unsafe impl Send for Handler {}
71
/// The kind of row change reported to an update hook.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum SqliteOperation {
    /// A row was inserted.
    Insert,
    /// A row was updated.
    Update,
    /// A row was deleted.
    Delete,
    /// An operation code that is none of the three above; the raw code is preserved.
    Unknown(i32),
}
79
80impl From<i32> for SqliteOperation {
81 fn from(value: i32) -> Self {
82 match value {
83 SQLITE_INSERT => SqliteOperation::Insert,
84 SQLITE_UPDATE => SqliteOperation::Update,
85 SQLITE_DELETE => SqliteOperation::Delete,
86 code => SqliteOperation::Unknown(code),
87 }
88 }
89}
90
91pub struct UpdateHookResult<'a> {
92 pub operation: SqliteOperation,
93 pub database: &'a str,
94 pub table: &'a str,
95 pub rowid: i64,
96}
97
98pub(crate) struct UpdateHookHandler(NonNull<dyn FnMut(UpdateHookResult) + Send + 'static>);
99unsafe impl Send for UpdateHookHandler {}
100
/// Owning pointer to the boxed closure installed as the SQLite commit hook.
///
/// Created with `Box::into_raw` in `LockedSqliteHandle::set_commit_hook`; freed by
/// `ConnectionState::remove_commit_hook`.
pub(crate) struct CommitHookHandler(NonNull<dyn FnMut() -> bool + Send + 'static>);

// `NonNull` is not `Send` by itself; the pointee is bounded by `Send`.
// NOTE(review): also relies on this wrapper being the pointee's only owner — confirm.
unsafe impl Send for CommitHookHandler {}
103
/// Owning pointer to the boxed closure installed as the SQLite rollback hook.
///
/// Created with `Box::into_raw` in `LockedSqliteHandle::set_rollback_hook`; freed by
/// `ConnectionState::remove_rollback_hook`.
pub(crate) struct RollbackHookHandler(NonNull<dyn FnMut() + Send + 'static>);

// `NonNull` is not `Send` by itself; the pointee is bounded by `Send`.
// NOTE(review): also relies on this wrapper being the pointee's only owner — confirm.
unsafe impl Send for RollbackHookHandler {}
106
107pub(crate) struct ConnectionState {
108 pub(crate) handle: ConnectionHandle,
109
110 pub(crate) statements: Statements,
111
112 log_settings: LogSettings,
113
114 progress_handler_callback: Option<Handler>,
117
118 update_hook_callback: Option<UpdateHookHandler>,
119 #[cfg(feature = "preupdate-hook")]
120 preupdate_hook_callback: Option<preupdate_hook::PreupdateHookHandler>,
121
122 commit_hook_callback: Option<CommitHookHandler>,
123
124 rollback_hook_callback: Option<RollbackHookHandler>,
125}
126
127impl ConnectionState {
128 pub(crate) fn remove_progress_handler(&mut self) {
130 if let Some(mut handler) = self.progress_handler_callback.take() {
131 unsafe {
132 sqlite3_progress_handler(self.handle.as_ptr(), 0, None, ptr::null_mut());
133 let _ = { Box::from_raw(handler.0.as_mut()) };
134 }
135 }
136 }
137
138 pub(crate) fn remove_update_hook(&mut self) {
139 if let Some(mut handler) = self.update_hook_callback.take() {
140 unsafe {
141 sqlite3_update_hook(self.handle.as_ptr(), None, ptr::null_mut());
142 let _ = { Box::from_raw(handler.0.as_mut()) };
143 }
144 }
145 }
146
147 #[cfg(feature = "preupdate-hook")]
148 pub(crate) fn remove_preupdate_hook(&mut self) {
149 if let Some(mut handler) = self.preupdate_hook_callback.take() {
150 unsafe {
151 libsqlite3_sys::sqlite3_preupdate_hook(self.handle.as_ptr(), None, ptr::null_mut());
152 let _ = { Box::from_raw(handler.0.as_mut()) };
153 }
154 }
155 }
156
157 pub(crate) fn remove_commit_hook(&mut self) {
158 if let Some(mut handler) = self.commit_hook_callback.take() {
159 unsafe {
160 sqlite3_commit_hook(self.handle.as_ptr(), None, ptr::null_mut());
161 let _ = { Box::from_raw(handler.0.as_mut()) };
162 }
163 }
164 }
165
166 pub(crate) fn remove_rollback_hook(&mut self) {
167 if let Some(mut handler) = self.rollback_hook_callback.take() {
168 unsafe {
169 sqlite3_rollback_hook(self.handle.as_ptr(), None, ptr::null_mut());
170 let _ = { Box::from_raw(handler.0.as_mut()) };
171 }
172 }
173 }
174}
175
176pub(crate) struct Statements {
177 cached: StatementCache<VirtualStatement>,
179 temp: Option<VirtualStatement>,
181}
182
183impl SqliteConnection {
184 pub(crate) async fn establish(options: &SqliteConnectOptions) -> Result<Self, Error> {
185 let params = EstablishParams::from_options(options)?;
186 let worker = ConnectionWorker::establish(params).await?;
187 Ok(Self {
188 optimize_on_close: options.optimize_on_close.clone(),
189 worker,
190 row_channel_size: options.row_channel_size,
191 })
192 }
193
194 pub async fn lock_handle(&mut self) -> Result<LockedSqliteHandle<'_>, Error> {
199 let guard = self.worker.unlock_db().await?;
200
201 Ok(LockedSqliteHandle { guard })
202 }
203}
204
205impl Debug for SqliteConnection {
206 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
207 f.debug_struct("SqliteConnection")
208 .field("row_channel_size", &self.row_channel_size)
209 .field("cached_statements_size", &self.cached_statements_size())
210 .finish()
211 }
212}
213
214impl Connection for SqliteConnection {
215 type Database = Sqlite;
216
217 type Options = SqliteConnectOptions;
218
219 fn close(mut self) -> BoxFuture<'static, Result<(), Error>> {
220 Box::pin(async move {
221 if let OptimizeOnClose::Enabled { analysis_limit } = self.optimize_on_close {
222 let mut pragma_string = String::new();
223 if let Some(limit) = analysis_limit {
224 write!(pragma_string, "PRAGMA analysis_limit = {limit}; ").ok();
225 }
226 pragma_string.push_str("PRAGMA optimize;");
227 self.execute(&*pragma_string).await?;
228 }
229 let shutdown = self.worker.shutdown();
230 drop(self);
233 shutdown.await
235 })
236 }
237
238 fn close_hard(self) -> BoxFuture<'static, Result<(), Error>> {
239 Box::pin(async move {
240 drop(self);
241 Ok(())
242 })
243 }
244
245 fn ping(&mut self) -> BoxFuture<'_, Result<(), Error>> {
247 Box::pin(self.worker.ping())
248 }
249
250 fn begin(&mut self) -> BoxFuture<'_, Result<Transaction<'_, Self::Database>, Error>>
251 where
252 Self: Sized,
253 {
254 Transaction::begin(self, None)
255 }
256
257 fn begin_with(
258 &mut self,
259 statement: impl Into<Cow<'static, str>>,
260 ) -> BoxFuture<'_, Result<Transaction<'_, Self::Database>, Error>>
261 where
262 Self: Sized,
263 {
264 Transaction::begin(self, Some(statement.into()))
265 }
266
267 fn cached_statements_size(&self) -> usize {
268 self.worker.shared.get_cached_statements_size()
269 }
270
271 fn clear_cached_statements(&mut self) -> BoxFuture<'_, Result<(), Error>> {
272 Box::pin(async move {
273 self.worker.clear_cache().await?;
274 Ok(())
275 })
276 }
277
278 #[inline]
279 fn shrink_buffers(&mut self) {
280 }
282
283 #[doc(hidden)]
284 fn flush(&mut self) -> BoxFuture<'_, Result<(), Error>> {
285 Box::pin(future::ok(()))
290 }
291
292 #[doc(hidden)]
293 fn should_flush(&self) -> bool {
294 false
295 }
296}
297
/// C trampoline for `sqlite3_progress_handler`.
///
/// `callback` must point to a live closure of type `F`. The closure's answer is inverted
/// for SQLite: `true` ("keep going") becomes `0`, while `false` becomes `1`, which SQLite
/// treats as a request to interrupt the running operation. A panic inside the closure is
/// caught (it must not unwind across the FFI boundary) and is reported like `false`.
extern "C" fn progress_callback<F>(callback: *mut c_void) -> c_int
where
    F: FnMut() -> bool,
{
    let outcome = catch_unwind(|| {
        // The cast happens inside the closure so that only the untyped pointer is
        // captured; capturing a `*mut F` would demand unwind-safety bounds on `F`.
        let handler = callback.cast::<F>();
        // SAFETY: the caller registered `callback` as a pointer to a boxed `F`.
        unsafe { (*handler)() }
    });

    match outcome {
        Ok(true) => 0,
        Ok(false) | Err(_) => 1,
    }
}
312
313extern "C" fn update_hook<F>(
314 callback: *mut c_void,
315 op_code: c_int,
316 database: *const c_char,
317 table: *const c_char,
318 rowid: i64,
319) where
320 F: FnMut(UpdateHookResult),
321{
322 unsafe {
323 let _ = catch_unwind(|| {
324 let callback: *mut F = callback.cast::<F>();
325 let operation: SqliteOperation = op_code.into();
326 let database = CStr::from_ptr(database).to_str().unwrap_or_default();
327 let table = CStr::from_ptr(table).to_str().unwrap_or_default();
328 (*callback)(UpdateHookResult {
329 operation,
330 database,
331 table,
332 rowid,
333 })
334 });
335 }
336}
337
/// C trampoline for `sqlite3_commit_hook`.
///
/// `callback` must point to a live closure of type `F`. The closure's answer is inverted
/// for SQLite: `true` ("go ahead and commit") becomes `0`, while `false` becomes `1`,
/// which SQLite treats as a request to turn the commit into a rollback. A panic inside
/// the closure is caught (it must not unwind across the FFI boundary) and is reported
/// like `false`.
extern "C" fn commit_hook<F>(callback: *mut c_void) -> c_int
where
    F: FnMut() -> bool,
{
    let outcome = catch_unwind(|| {
        // The cast happens inside the closure so that only the untyped pointer is
        // captured; capturing a `*mut F` would demand unwind-safety bounds on `F`.
        let handler = callback.cast::<F>();
        // SAFETY: the caller registered `callback` as a pointer to a boxed `F`.
        unsafe { (*handler)() }
    });

    match outcome {
        Ok(true) => 0,
        Ok(false) | Err(_) => 1,
    }
}
350
/// C trampoline for `sqlite3_rollback_hook`.
///
/// `callback` must point to a live closure of type `F`. The closure is invoked once per
/// call; a panic inside it is caught and discarded so that it never unwinds across the
/// FFI boundary.
extern "C" fn rollback_hook<F>(callback: *mut c_void)
where
    F: FnMut(),
{
    let outcome = catch_unwind(|| {
        // The cast happens inside the closure so that only the untyped pointer is
        // captured; capturing a `*mut F` would demand unwind-safety bounds on `F`.
        let handler = callback.cast::<F>();
        // SAFETY: the caller registered `callback` as a pointer to a boxed `F`.
        unsafe { (*handler)() }
    });

    // This hook has no return value, so a panic is simply dropped.
    drop(outcome);
}
362
363impl LockedSqliteHandle<'_> {
364 pub fn as_raw_handle(&mut self) -> NonNull<sqlite3> {
379 self.guard.handle.as_non_null_ptr()
380 }
381
382 pub fn create_collation(
386 &mut self,
387 name: &str,
388 compare: impl Fn(&str, &str) -> Ordering + Send + Sync + 'static,
389 ) -> Result<(), Error> {
390 collation::create_collation(&mut self.guard.handle, name, compare)
391 }
392
393 pub fn set_progress_handler<F>(&mut self, num_ops: i32, callback: F)
407 where
408 F: FnMut() -> bool + Send + 'static,
409 {
410 unsafe {
411 let callback_boxed = Box::new(callback);
412 let callback = NonNull::new_unchecked(Box::into_raw(callback_boxed));
414 let handler = callback.as_ptr() as *mut _;
415 self.guard.remove_progress_handler();
416 self.guard.progress_handler_callback = Some(Handler(callback));
417
418 sqlite3_progress_handler(
419 self.as_raw_handle().as_mut(),
420 num_ops,
421 Some(progress_callback::<F>),
422 handler,
423 );
424 }
425 }
426
427 pub fn set_update_hook<F>(&mut self, callback: F)
428 where
429 F: FnMut(UpdateHookResult) + Send + 'static,
430 {
431 unsafe {
432 let callback_boxed = Box::new(callback);
433 let callback = NonNull::new_unchecked(Box::into_raw(callback_boxed));
435 let handler = callback.as_ptr() as *mut _;
436 self.guard.remove_update_hook();
437 self.guard.update_hook_callback = Some(UpdateHookHandler(callback));
438
439 sqlite3_update_hook(
440 self.as_raw_handle().as_mut(),
441 Some(update_hook::<F>),
442 handler,
443 );
444 }
445 }
446
447 #[cfg(feature = "preupdate-hook")]
455 pub fn set_preupdate_hook<F>(&mut self, callback: F)
456 where
457 F: FnMut(PreupdateHookResult) + Send + 'static,
458 {
459 unsafe {
460 let callback_boxed = Box::new(callback);
461 let callback = NonNull::new_unchecked(Box::into_raw(callback_boxed));
463 let handler = callback.as_ptr() as *mut _;
464 self.guard.remove_preupdate_hook();
465 self.guard.preupdate_hook_callback = Some(PreupdateHookHandler(callback));
466
467 libsqlite3_sys::sqlite3_preupdate_hook(
468 self.as_raw_handle().as_mut(),
469 Some(preupdate_hook::<F>),
470 handler,
471 );
472 }
473 }
474
475 pub fn set_commit_hook<F>(&mut self, callback: F)
487 where
488 F: FnMut() -> bool + Send + 'static,
489 {
490 unsafe {
491 let callback_boxed = Box::new(callback);
492 let callback = NonNull::new_unchecked(Box::into_raw(callback_boxed));
494 let handler = callback.as_ptr() as *mut _;
495 self.guard.remove_commit_hook();
496 self.guard.commit_hook_callback = Some(CommitHookHandler(callback));
497
498 sqlite3_commit_hook(
499 self.as_raw_handle().as_mut(),
500 Some(commit_hook::<F>),
501 handler,
502 );
503 }
504 }
505
506 pub fn set_rollback_hook<F>(&mut self, callback: F)
511 where
512 F: FnMut() + Send + 'static,
513 {
514 unsafe {
515 let callback_boxed = Box::new(callback);
516 let callback = NonNull::new_unchecked(Box::into_raw(callback_boxed));
518 let handler = callback.as_ptr() as *mut _;
519 self.guard.remove_rollback_hook();
520 self.guard.rollback_hook_callback = Some(RollbackHookHandler(callback));
521
522 sqlite3_rollback_hook(
523 self.as_raw_handle().as_mut(),
524 Some(rollback_hook::<F>),
525 handler,
526 );
527 }
528 }
529
530 pub fn remove_progress_handler(&mut self) {
532 self.guard.remove_progress_handler();
533 }
534
535 pub fn remove_update_hook(&mut self) {
536 self.guard.remove_update_hook();
537 }
538
539 #[cfg(feature = "preupdate-hook")]
540 pub fn remove_preupdate_hook(&mut self) {
541 self.guard.remove_preupdate_hook();
542 }
543
544 pub fn remove_commit_hook(&mut self) {
545 self.guard.remove_commit_hook();
546 }
547
548 pub fn remove_rollback_hook(&mut self) {
549 self.guard.remove_rollback_hook();
550 }
551
552 pub fn last_error(&mut self) -> Option<SqliteError> {
553 self.guard.handle.last_error()
554 }
555
556 pub(crate) fn in_transaction(&mut self) -> bool {
557 let ret = unsafe { sqlite3_get_autocommit(self.as_raw_handle().as_ptr()) };
558 ret == 0
559 }
560}
561
562impl Drop for ConnectionState {
563 fn drop(&mut self) {
564 self.statements.clear();
566 self.remove_progress_handler();
567 self.remove_update_hook();
568 self.remove_commit_hook();
569 self.remove_rollback_hook();
570 }
571}
572
573impl Statements {
574 fn new(capacity: usize) -> Self {
575 Statements {
576 cached: StatementCache::new(capacity),
577 temp: None,
578 }
579 }
580
581 fn get(&mut self, query: &str, persistent: bool) -> Result<&mut VirtualStatement, Error> {
582 if !persistent || !self.cached.is_enabled() {
583 return Ok(self.temp.insert(VirtualStatement::new(query, false)?));
584 }
585
586 let exists = self.cached.contains_key(query);
587
588 if !exists {
589 let statement = VirtualStatement::new(query, true)?;
590 self.cached.insert(query, statement);
591 }
592
593 let statement = self.cached.get_mut(query).unwrap();
594
595 if exists {
596 statement.reset()?;
598 }
599
600 Ok(statement)
601 }
602
603 fn len(&self) -> usize {
604 self.cached.len()
605 }
606
607 fn clear(&mut self) {
608 self.cached.clear();
609 self.temp = None;
610 }
611}