rusqlite/
hooks.rs

//! Commit, data change, rollback, progress and authorizer callbacks
2#![allow(non_camel_case_types)]
3
4use std::os::raw::{c_char, c_int, c_void};
5use std::panic::{catch_unwind, RefUnwindSafe};
6use std::ptr;
7
8use crate::ffi;
9
10use crate::{Connection, InnerConnection};
11
12/// Action Codes
13#[derive(Clone, Copy, Debug, Eq, PartialEq)]
14#[repr(i32)]
15#[non_exhaustive]
16#[allow(clippy::upper_case_acronyms)]
17pub enum Action {
18    /// Unsupported / unexpected action
19    UNKNOWN = -1,
20    /// DELETE command
21    SQLITE_DELETE = ffi::SQLITE_DELETE,
22    /// INSERT command
23    SQLITE_INSERT = ffi::SQLITE_INSERT,
24    /// UPDATE command
25    SQLITE_UPDATE = ffi::SQLITE_UPDATE,
26}
27
28impl From<i32> for Action {
29    #[inline]
30    fn from(code: i32) -> Action {
31        match code {
32            ffi::SQLITE_DELETE => Action::SQLITE_DELETE,
33            ffi::SQLITE_INSERT => Action::SQLITE_INSERT,
34            ffi::SQLITE_UPDATE => Action::SQLITE_UPDATE,
35            _ => Action::UNKNOWN,
36        }
37    }
38}
39
40/// The context received by an authorizer hook.
41///
42/// See <https://sqlite.org/c3ref/set_authorizer.html> for more info.
43#[derive(Clone, Copy, Debug, Eq, PartialEq)]
44pub struct AuthContext<'c> {
45    /// The action to be authorized.
46    pub action: AuthAction<'c>,
47
48    /// The database name, if applicable.
49    pub database_name: Option<&'c str>,
50
51    /// The inner-most trigger or view responsible for the access attempt.
52    /// `None` if the access attempt was made by top-level SQL code.
53    pub accessor: Option<&'c str>,
54}
55
56/// Actions and arguments found within a statement during
57/// preparation.
58///
59/// See <https://sqlite.org/c3ref/c_alter_table.html> for more info.
60#[derive(Clone, Copy, Debug, Eq, PartialEq)]
61#[non_exhaustive]
62#[allow(missing_docs)]
63pub enum AuthAction<'c> {
64    /// This variant is not normally produced by SQLite. You may encounter it
65    // if you're using a different version than what's supported by this library.
66    Unknown {
67        /// The unknown authorization action code.
68        code: i32,
69        /// The third arg to the authorizer callback.
70        arg1: Option<&'c str>,
71        /// The fourth arg to the authorizer callback.
72        arg2: Option<&'c str>,
73    },
74    CreateIndex {
75        index_name: &'c str,
76        table_name: &'c str,
77    },
78    CreateTable {
79        table_name: &'c str,
80    },
81    CreateTempIndex {
82        index_name: &'c str,
83        table_name: &'c str,
84    },
85    CreateTempTable {
86        table_name: &'c str,
87    },
88    CreateTempTrigger {
89        trigger_name: &'c str,
90        table_name: &'c str,
91    },
92    CreateTempView {
93        view_name: &'c str,
94    },
95    CreateTrigger {
96        trigger_name: &'c str,
97        table_name: &'c str,
98    },
99    CreateView {
100        view_name: &'c str,
101    },
102    Delete {
103        table_name: &'c str,
104    },
105    DropIndex {
106        index_name: &'c str,
107        table_name: &'c str,
108    },
109    DropTable {
110        table_name: &'c str,
111    },
112    DropTempIndex {
113        index_name: &'c str,
114        table_name: &'c str,
115    },
116    DropTempTable {
117        table_name: &'c str,
118    },
119    DropTempTrigger {
120        trigger_name: &'c str,
121        table_name: &'c str,
122    },
123    DropTempView {
124        view_name: &'c str,
125    },
126    DropTrigger {
127        trigger_name: &'c str,
128        table_name: &'c str,
129    },
130    DropView {
131        view_name: &'c str,
132    },
133    Insert {
134        table_name: &'c str,
135    },
136    Pragma {
137        pragma_name: &'c str,
138        /// The pragma value, if present (e.g., `PRAGMA name = value;`).
139        pragma_value: Option<&'c str>,
140    },
141    Read {
142        table_name: &'c str,
143        column_name: &'c str,
144    },
145    Select,
146    Transaction {
147        operation: TransactionOperation,
148    },
149    Update {
150        table_name: &'c str,
151        column_name: &'c str,
152    },
153    Attach {
154        filename: &'c str,
155    },
156    Detach {
157        database_name: &'c str,
158    },
159    AlterTable {
160        database_name: &'c str,
161        table_name: &'c str,
162    },
163    Reindex {
164        index_name: &'c str,
165    },
166    Analyze {
167        table_name: &'c str,
168    },
169    CreateVtable {
170        table_name: &'c str,
171        module_name: &'c str,
172    },
173    DropVtable {
174        table_name: &'c str,
175        module_name: &'c str,
176    },
177    Function {
178        function_name: &'c str,
179    },
180    Savepoint {
181        operation: TransactionOperation,
182        savepoint_name: &'c str,
183    },
184    #[cfg(feature = "modern_sqlite")]
185    Recursive,
186}
187
188impl<'c> AuthAction<'c> {
189    fn from_raw(code: i32, arg1: Option<&'c str>, arg2: Option<&'c str>) -> Self {
190        match (code, arg1, arg2) {
191            (ffi::SQLITE_CREATE_INDEX, Some(index_name), Some(table_name)) => Self::CreateIndex {
192                index_name,
193                table_name,
194            },
195            (ffi::SQLITE_CREATE_TABLE, Some(table_name), _) => Self::CreateTable { table_name },
196            (ffi::SQLITE_CREATE_TEMP_INDEX, Some(index_name), Some(table_name)) => {
197                Self::CreateTempIndex {
198                    index_name,
199                    table_name,
200                }
201            }
202            (ffi::SQLITE_CREATE_TEMP_TABLE, Some(table_name), _) => {
203                Self::CreateTempTable { table_name }
204            }
205            (ffi::SQLITE_CREATE_TEMP_TRIGGER, Some(trigger_name), Some(table_name)) => {
206                Self::CreateTempTrigger {
207                    trigger_name,
208                    table_name,
209                }
210            }
211            (ffi::SQLITE_CREATE_TEMP_VIEW, Some(view_name), _) => {
212                Self::CreateTempView { view_name }
213            }
214            (ffi::SQLITE_CREATE_TRIGGER, Some(trigger_name), Some(table_name)) => {
215                Self::CreateTrigger {
216                    trigger_name,
217                    table_name,
218                }
219            }
220            (ffi::SQLITE_CREATE_VIEW, Some(view_name), _) => Self::CreateView { view_name },
221            (ffi::SQLITE_DELETE, Some(table_name), None) => Self::Delete { table_name },
222            (ffi::SQLITE_DROP_INDEX, Some(index_name), Some(table_name)) => Self::DropIndex {
223                index_name,
224                table_name,
225            },
226            (ffi::SQLITE_DROP_TABLE, Some(table_name), _) => Self::DropTable { table_name },
227            (ffi::SQLITE_DROP_TEMP_INDEX, Some(index_name), Some(table_name)) => {
228                Self::DropTempIndex {
229                    index_name,
230                    table_name,
231                }
232            }
233            (ffi::SQLITE_DROP_TEMP_TABLE, Some(table_name), _) => {
234                Self::DropTempTable { table_name }
235            }
236            (ffi::SQLITE_DROP_TEMP_TRIGGER, Some(trigger_name), Some(table_name)) => {
237                Self::DropTempTrigger {
238                    trigger_name,
239                    table_name,
240                }
241            }
242            (ffi::SQLITE_DROP_TEMP_VIEW, Some(view_name), _) => Self::DropTempView { view_name },
243            (ffi::SQLITE_DROP_TRIGGER, Some(trigger_name), Some(table_name)) => Self::DropTrigger {
244                trigger_name,
245                table_name,
246            },
247            (ffi::SQLITE_DROP_VIEW, Some(view_name), _) => Self::DropView { view_name },
248            (ffi::SQLITE_INSERT, Some(table_name), _) => Self::Insert { table_name },
249            (ffi::SQLITE_PRAGMA, Some(pragma_name), pragma_value) => Self::Pragma {
250                pragma_name,
251                pragma_value,
252            },
253            (ffi::SQLITE_READ, Some(table_name), Some(column_name)) => Self::Read {
254                table_name,
255                column_name,
256            },
257            (ffi::SQLITE_SELECT, ..) => Self::Select,
258            (ffi::SQLITE_TRANSACTION, Some(operation_str), _) => Self::Transaction {
259                operation: TransactionOperation::from_str(operation_str),
260            },
261            (ffi::SQLITE_UPDATE, Some(table_name), Some(column_name)) => Self::Update {
262                table_name,
263                column_name,
264            },
265            (ffi::SQLITE_ATTACH, Some(filename), _) => Self::Attach { filename },
266            (ffi::SQLITE_DETACH, Some(database_name), _) => Self::Detach { database_name },
267            (ffi::SQLITE_ALTER_TABLE, Some(database_name), Some(table_name)) => Self::AlterTable {
268                database_name,
269                table_name,
270            },
271            (ffi::SQLITE_REINDEX, Some(index_name), _) => Self::Reindex { index_name },
272            (ffi::SQLITE_ANALYZE, Some(table_name), _) => Self::Analyze { table_name },
273            (ffi::SQLITE_CREATE_VTABLE, Some(table_name), Some(module_name)) => {
274                Self::CreateVtable {
275                    table_name,
276                    module_name,
277                }
278            }
279            (ffi::SQLITE_DROP_VTABLE, Some(table_name), Some(module_name)) => Self::DropVtable {
280                table_name,
281                module_name,
282            },
283            (ffi::SQLITE_FUNCTION, _, Some(function_name)) => Self::Function { function_name },
284            (ffi::SQLITE_SAVEPOINT, Some(operation_str), Some(savepoint_name)) => Self::Savepoint {
285                operation: TransactionOperation::from_str(operation_str),
286                savepoint_name,
287            },
288            #[cfg(feature = "modern_sqlite")] // 3.8.3
289            (ffi::SQLITE_RECURSIVE, ..) => Self::Recursive,
290            (code, arg1, arg2) => Self::Unknown { code, arg1, arg2 },
291        }
292    }
293}
294
295pub(crate) type BoxedAuthorizer =
296    Box<dyn for<'c> FnMut(AuthContext<'c>) -> Authorization + Send + 'static>;
297
/// A transaction operation.
///
/// This is the operation string SQLite passes as the first argument of the
/// `SQLITE_TRANSACTION` and `SQLITE_SAVEPOINT` authorizer actions.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum TransactionOperation {
    /// An operation string this library does not recognize.
    Unknown,
    /// `BEGIN`: a transaction or savepoint is being started.
    Begin,
    /// `RELEASE`: a savepoint is being released.
    Release,
    /// `ROLLBACK`: a transaction or savepoint is being rolled back.
    Rollback,
    /// `COMMIT`: a transaction is being committed.
    Commit,
}

impl TransactionOperation {
    /// Parse the operation string SQLite hands to the authorizer.
    ///
    /// Matching is exact (SQLite supplies upper-case strings); anything
    /// unrecognized maps to [`TransactionOperation::Unknown`].
    fn from_str(op_str: &str) -> Self {
        match op_str {
            "BEGIN" => Self::Begin,
            // SQLite reports both `COMMIT` and `END` statements as "COMMIT";
            // previously this fell through to `Unknown`.
            "COMMIT" => Self::Commit,
            "RELEASE" => Self::Release,
            "ROLLBACK" => Self::Rollback,
            _ => Self::Unknown,
        }
    }
}
319
320/// [`authorizer`](Connection::authorizer) return code
321#[derive(Clone, Copy, Debug, Eq, PartialEq)]
322#[non_exhaustive]
323pub enum Authorization {
324    /// Authorize the action.
325    Allow,
326    /// Don't allow access, but don't trigger an error either.
327    Ignore,
328    /// Trigger an error.
329    Deny,
330}
331
332impl Authorization {
333    fn into_raw(self) -> c_int {
334        match self {
335            Self::Allow => ffi::SQLITE_OK,
336            Self::Ignore => ffi::SQLITE_IGNORE,
337            Self::Deny => ffi::SQLITE_DENY,
338        }
339    }
340}
341
342impl Connection {
343    /// Register a callback function to be invoked whenever
344    /// a transaction is committed.
345    ///
346    /// The callback returns `true` to rollback.
347    #[inline]
348    pub fn commit_hook<F>(&self, hook: Option<F>)
349    where
350        F: FnMut() -> bool + Send + 'static,
351    {
352        self.db.borrow_mut().commit_hook(hook);
353    }
354
355    /// Register a callback function to be invoked whenever
356    /// a transaction is committed.
357    #[inline]
358    pub fn rollback_hook<F>(&self, hook: Option<F>)
359    where
360        F: FnMut() + Send + 'static,
361    {
362        self.db.borrow_mut().rollback_hook(hook);
363    }
364
365    /// Register a callback function to be invoked whenever
366    /// a row is updated, inserted or deleted in a rowid table.
367    ///
368    /// The callback parameters are:
369    ///
370    /// - the type of database update (`SQLITE_INSERT`, `SQLITE_UPDATE` or
371    /// `SQLITE_DELETE`),
372    /// - the name of the database ("main", "temp", ...),
373    /// - the name of the table that is updated,
374    /// - the ROWID of the row that is updated.
375    #[inline]
376    pub fn update_hook<F>(&self, hook: Option<F>)
377    where
378        F: FnMut(Action, &str, &str, i64) + Send + 'static,
379    {
380        self.db.borrow_mut().update_hook(hook);
381    }
382
383    /// Register a query progress callback.
384    ///
385    /// The parameter `num_ops` is the approximate number of virtual machine
386    /// instructions that are evaluated between successive invocations of the
387    /// `handler`. If `num_ops` is less than one then the progress handler
388    /// is disabled.
389    ///
390    /// If the progress callback returns `true`, the operation is interrupted.
391    pub fn progress_handler<F>(&self, num_ops: c_int, handler: Option<F>)
392    where
393        F: FnMut() -> bool + Send + RefUnwindSafe + 'static,
394    {
395        self.db.borrow_mut().progress_handler(num_ops, handler);
396    }
397
398    /// Register an authorizer callback that's invoked
399    /// as a statement is being prepared.
400    #[inline]
401    pub fn authorizer<'c, F>(&self, hook: Option<F>)
402    where
403        F: for<'r> FnMut(AuthContext<'r>) -> Authorization + Send + RefUnwindSafe + 'static,
404    {
405        self.db.borrow_mut().authorizer(hook);
406    }
407}
408
409impl InnerConnection {
410    #[inline]
411    pub fn remove_hooks(&mut self) {
412        self.update_hook(None::<fn(Action, &str, &str, i64)>);
413        self.commit_hook(None::<fn() -> bool>);
414        self.rollback_hook(None::<fn()>);
415        self.progress_handler(0, None::<fn() -> bool>);
416        self.authorizer(None::<fn(AuthContext<'_>) -> Authorization>);
417    }
418
419    fn commit_hook<F>(&mut self, hook: Option<F>)
420    where
421        F: FnMut() -> bool + Send + 'static,
422    {
423        unsafe extern "C" fn call_boxed_closure<F>(p_arg: *mut c_void) -> c_int
424        where
425            F: FnMut() -> bool,
426        {
427            let r = catch_unwind(|| {
428                let boxed_hook: *mut F = p_arg.cast::<F>();
429                (*boxed_hook)()
430            });
431            if let Ok(true) = r {
432                1
433            } else {
434                0
435            }
436        }
437
438        // unlike `sqlite3_create_function_v2`, we cannot specify a `xDestroy` with
439        // `sqlite3_commit_hook`. so we keep the `xDestroy` function in
440        // `InnerConnection.free_boxed_hook`.
441        let free_commit_hook = if hook.is_some() {
442            Some(free_boxed_hook::<F> as unsafe fn(*mut c_void))
443        } else {
444            None
445        };
446
447        let previous_hook = match hook {
448            Some(hook) => {
449                let boxed_hook: *mut F = Box::into_raw(Box::new(hook));
450                unsafe {
451                    ffi::sqlite3_commit_hook(
452                        self.db(),
453                        Some(call_boxed_closure::<F>),
454                        boxed_hook.cast(),
455                    )
456                }
457            }
458            _ => unsafe { ffi::sqlite3_commit_hook(self.db(), None, ptr::null_mut()) },
459        };
460        if !previous_hook.is_null() {
461            if let Some(free_boxed_hook) = self.free_commit_hook {
462                unsafe { free_boxed_hook(previous_hook) };
463            }
464        }
465        self.free_commit_hook = free_commit_hook;
466    }
467
468    fn rollback_hook<F>(&mut self, hook: Option<F>)
469    where
470        F: FnMut() + Send + 'static,
471    {
472        unsafe extern "C" fn call_boxed_closure<F>(p_arg: *mut c_void)
473        where
474            F: FnMut(),
475        {
476            drop(catch_unwind(|| {
477                let boxed_hook: *mut F = p_arg.cast::<F>();
478                (*boxed_hook)();
479            }));
480        }
481
482        let free_rollback_hook = if hook.is_some() {
483            Some(free_boxed_hook::<F> as unsafe fn(*mut c_void))
484        } else {
485            None
486        };
487
488        let previous_hook = match hook {
489            Some(hook) => {
490                let boxed_hook: *mut F = Box::into_raw(Box::new(hook));
491                unsafe {
492                    ffi::sqlite3_rollback_hook(
493                        self.db(),
494                        Some(call_boxed_closure::<F>),
495                        boxed_hook.cast(),
496                    )
497                }
498            }
499            _ => unsafe { ffi::sqlite3_rollback_hook(self.db(), None, ptr::null_mut()) },
500        };
501        if !previous_hook.is_null() {
502            if let Some(free_boxed_hook) = self.free_rollback_hook {
503                unsafe { free_boxed_hook(previous_hook) };
504            }
505        }
506        self.free_rollback_hook = free_rollback_hook;
507    }
508
509    fn update_hook<F>(&mut self, hook: Option<F>)
510    where
511        F: FnMut(Action, &str, &str, i64) + Send + 'static,
512    {
513        unsafe extern "C" fn call_boxed_closure<F>(
514            p_arg: *mut c_void,
515            action_code: c_int,
516            p_db_name: *const c_char,
517            p_table_name: *const c_char,
518            row_id: i64,
519        ) where
520            F: FnMut(Action, &str, &str, i64),
521        {
522            let action = Action::from(action_code);
523            drop(catch_unwind(|| {
524                let boxed_hook: *mut F = p_arg.cast::<F>();
525                (*boxed_hook)(
526                    action,
527                    expect_utf8(p_db_name, "database name"),
528                    expect_utf8(p_table_name, "table name"),
529                    row_id,
530                );
531            }));
532        }
533
534        let free_update_hook = if hook.is_some() {
535            Some(free_boxed_hook::<F> as unsafe fn(*mut c_void))
536        } else {
537            None
538        };
539
540        let previous_hook = match hook {
541            Some(hook) => {
542                let boxed_hook: *mut F = Box::into_raw(Box::new(hook));
543                unsafe {
544                    ffi::sqlite3_update_hook(
545                        self.db(),
546                        Some(call_boxed_closure::<F>),
547                        boxed_hook.cast(),
548                    )
549                }
550            }
551            _ => unsafe { ffi::sqlite3_update_hook(self.db(), None, ptr::null_mut()) },
552        };
553        if !previous_hook.is_null() {
554            if let Some(free_boxed_hook) = self.free_update_hook {
555                unsafe { free_boxed_hook(previous_hook) };
556            }
557        }
558        self.free_update_hook = free_update_hook;
559    }
560
561    fn progress_handler<F>(&mut self, num_ops: c_int, handler: Option<F>)
562    where
563        F: FnMut() -> bool + Send + RefUnwindSafe + 'static,
564    {
565        unsafe extern "C" fn call_boxed_closure<F>(p_arg: *mut c_void) -> c_int
566        where
567            F: FnMut() -> bool,
568        {
569            let r = catch_unwind(|| {
570                let boxed_handler: *mut F = p_arg.cast::<F>();
571                (*boxed_handler)()
572            });
573            if let Ok(true) = r {
574                1
575            } else {
576                0
577            }
578        }
579
580        if let Some(handler) = handler {
581            let boxed_handler = Box::new(handler);
582            unsafe {
583                ffi::sqlite3_progress_handler(
584                    self.db(),
585                    num_ops,
586                    Some(call_boxed_closure::<F>),
587                    &*boxed_handler as *const F as *mut _,
588                );
589            }
590            self.progress_handler = Some(boxed_handler);
591        } else {
592            unsafe { ffi::sqlite3_progress_handler(self.db(), num_ops, None, ptr::null_mut()) }
593            self.progress_handler = None;
594        };
595    }
596
597    fn authorizer<'c, F>(&'c mut self, authorizer: Option<F>)
598    where
599        F: for<'r> FnMut(AuthContext<'r>) -> Authorization + Send + RefUnwindSafe + 'static,
600    {
601        unsafe extern "C" fn call_boxed_closure<'c, F>(
602            p_arg: *mut c_void,
603            action_code: c_int,
604            param1: *const c_char,
605            param2: *const c_char,
606            db_name: *const c_char,
607            trigger_or_view_name: *const c_char,
608        ) -> c_int
609        where
610            F: FnMut(AuthContext<'c>) -> Authorization + Send + 'static,
611        {
612            catch_unwind(|| {
613                let action = AuthAction::from_raw(
614                    action_code,
615                    expect_optional_utf8(param1, "authorizer param 1"),
616                    expect_optional_utf8(param2, "authorizer param 2"),
617                );
618                let auth_ctx = AuthContext {
619                    action,
620                    database_name: expect_optional_utf8(db_name, "database name"),
621                    accessor: expect_optional_utf8(
622                        trigger_or_view_name,
623                        "accessor (inner-most trigger or view)",
624                    ),
625                };
626                let boxed_hook: *mut F = p_arg.cast::<F>();
627                (*boxed_hook)(auth_ctx)
628            })
629            .map_or_else(|_| ffi::SQLITE_ERROR, Authorization::into_raw)
630        }
631
632        let callback_fn = authorizer
633            .as_ref()
634            .map(|_| call_boxed_closure::<'c, F> as unsafe extern "C" fn(_, _, _, _, _, _) -> _);
635        let boxed_authorizer = authorizer.map(Box::new);
636
637        match unsafe {
638            ffi::sqlite3_set_authorizer(
639                self.db(),
640                callback_fn,
641                boxed_authorizer
642                    .as_ref()
643                    .map_or_else(ptr::null_mut, |f| &**f as *const F as *mut _),
644            )
645        } {
646            ffi::SQLITE_OK => {
647                self.authorizer = boxed_authorizer.map(|ba| ba as _);
648            }
649            err_code => {
650                // The only error that `sqlite3_set_authorizer` returns is `SQLITE_MISUSE`
651                // when compiled with `ENABLE_API_ARMOR` and the db pointer is invalid.
652                // This library does not allow constructing a null db ptr, so if this branch
653                // is hit, something very bad has happened. Panicking instead of returning
654                // `Result` keeps this hook's API consistent with the others.
655                panic!("unexpectedly failed to set_authorizer: {}", unsafe {
656                    crate::error::error_from_handle(self.db(), err_code)
657                });
658            }
659        }
660    }
661}
662
/// Reclaim and drop a closure that was leaked with `Box::into_raw`.
///
/// # Safety
///
/// `p` must have been produced by `Box::<F>::into_raw` for this exact `F`,
/// and must not be used or freed again after this call.
unsafe fn free_boxed_hook<F>(p: *mut c_void) {
    let reclaimed: Box<F> = Box::from_raw(p as *mut F);
    drop(reclaimed);
}
666
667unsafe fn expect_utf8<'a>(p_str: *const c_char, description: &'static str) -> &'a str {
668    expect_optional_utf8(p_str, description)
669        .unwrap_or_else(|| panic!("received empty {}", description))
670}
671
/// Borrow an optional C string supplied by SQLite as UTF-8.
///
/// Returns `None` for a null pointer.
///
/// # Safety
///
/// `p_str` must be null or point to a NUL-terminated string that stays valid
/// and unmodified for the chosen lifetime `'a`.
///
/// # Panics
///
/// Panics if the string is not valid UTF-8; `description` names the value in
/// the panic message.
unsafe fn expect_optional_utf8<'a>(
    p_str: *const c_char,
    description: &'static str,
) -> Option<&'a str> {
    if p_str.is_null() {
        None
    } else {
        match std::ffi::CStr::from_ptr(p_str).to_str() {
            Ok(text) => Some(text),
            Err(_) => panic!("received non-utf8 string as {}", description),
        }
    }
}
683
#[cfg(test)]
mod test {
    use super::Action;
    use crate::{Connection, Result};
    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};

    #[test]
    fn test_commit_hook() -> Result<()> {
        let db = Connection::open_in_memory()?;

        static CALLED: AtomicBool = AtomicBool::new(false);
        db.commit_hook(Some(|| {
            CALLED.store(true, Ordering::Relaxed);
            false
        }));
        db.execute_batch("BEGIN; CREATE TABLE foo (t TEXT); COMMIT;")?;
        assert!(CALLED.load(Ordering::Relaxed));
        Ok(())
    }

    #[test]
    fn test_fn_commit_hook() -> Result<()> {
        let db = Connection::open_in_memory()?;

        fn hook() -> bool {
            true
        }

        db.commit_hook(Some(hook));
        db.execute_batch("BEGIN; CREATE TABLE foo (t TEXT); COMMIT;")
            .unwrap_err();
        Ok(())
    }

    /// Unregistering a hook (the `None` path) must stop it from firing.
    #[test]
    fn test_remove_commit_hook() -> Result<()> {
        let db = Connection::open_in_memory()?;

        // A hook that vetoes every commit...
        db.commit_hook(Some(|| true));
        db.execute_batch("BEGIN; CREATE TABLE foo (t TEXT); COMMIT;")
            .unwrap_err();

        // ...no longer applies once it has been unregistered.
        db.commit_hook(None::<fn() -> bool>);
        db.execute_batch("BEGIN; CREATE TABLE bar (t TEXT); COMMIT;")?;
        Ok(())
    }

    #[test]
    fn test_rollback_hook() -> Result<()> {
        let db = Connection::open_in_memory()?;

        static CALLED: AtomicBool = AtomicBool::new(false);
        db.rollback_hook(Some(|| {
            CALLED.store(true, Ordering::Relaxed);
        }));
        db.execute_batch("BEGIN; CREATE TABLE foo (t TEXT); ROLLBACK;")?;
        assert!(CALLED.load(Ordering::Relaxed));
        Ok(())
    }

    #[test]
    fn test_update_hook() -> Result<()> {
        let db = Connection::open_in_memory()?;

        static CALLED: AtomicBool = AtomicBool::new(false);
        db.update_hook(Some(|action, db: &str, tbl: &str, row_id| {
            assert_eq!(Action::SQLITE_INSERT, action);
            assert_eq!("main", db);
            assert_eq!("foo", tbl);
            assert_eq!(1, row_id);
            CALLED.store(true, Ordering::Relaxed);
        }));
        db.execute_batch("CREATE TABLE foo (t TEXT)")?;
        db.execute_batch("INSERT INTO foo VALUES ('lisa')")?;
        assert!(CALLED.load(Ordering::Relaxed));
        Ok(())
    }

    #[test]
    fn test_progress_handler() -> Result<()> {
        let db = Connection::open_in_memory()?;

        static CALLED: AtomicBool = AtomicBool::new(false);
        db.progress_handler(
            1,
            Some(|| {
                CALLED.store(true, Ordering::Relaxed);
                false
            }),
        );
        db.execute_batch("BEGIN; CREATE TABLE foo (t TEXT); COMMIT;")?;
        assert!(CALLED.load(Ordering::Relaxed));
        Ok(())
    }

    /// A handler that mutates its own captured state is called through the
    /// raw pointer handed to SQLite; that state must persist across calls.
    #[test]
    fn test_progress_handler_stateful() -> Result<()> {
        let db = Connection::open_in_memory()?;

        static CALLS: AtomicUsize = AtomicUsize::new(0);
        let mut calls = 0usize;
        db.progress_handler(
            1,
            Some(move || {
                calls += 1;
                CALLS.store(calls, Ordering::Relaxed);
                false
            }),
        );
        let n: i64 = db.query_row(
            "WITH RECURSIVE c(x) AS (SELECT 1 UNION ALL SELECT x + 1 FROM c WHERE x < 100) \
             SELECT count(*) FROM c",
            [],
            |row| row.get(0),
        )?;
        assert_eq!(100, n);
        // The counter lives inside the closure, so a value above 1 shows the
        // same closure state was updated on successive invocations.
        assert!(CALLS.load(Ordering::Relaxed) > 1);
        Ok(())
    }

    #[test]
    fn test_progress_handler_interrupt() -> Result<()> {
        let db = Connection::open_in_memory()?;

        fn handler() -> bool {
            true
        }

        db.progress_handler(1, Some(handler));
        db.execute_batch("BEGIN; CREATE TABLE foo (t TEXT); COMMIT;")
            .unwrap_err();
        Ok(())
    }

    #[test]
    fn test_authorizer() -> Result<()> {
        use super::{AuthAction, AuthContext, Authorization};

        let db = Connection::open_in_memory()?;
        db.execute_batch("CREATE TABLE foo (public TEXT, private TEXT)")
            .unwrap();

        let authorizer = move |ctx: AuthContext<'_>| match ctx.action {
            AuthAction::Read { column_name, .. } if column_name == "private" => {
                Authorization::Ignore
            }
            AuthAction::DropTable { .. } => Authorization::Deny,
            AuthAction::Pragma { .. } => panic!("shouldn't be called"),
            _ => Authorization::Allow,
        };

        db.authorizer(Some(authorizer));
        db.execute_batch(
            "BEGIN TRANSACTION; INSERT INTO foo VALUES ('pub txt', 'priv txt'); COMMIT;",
        )
        .unwrap();
        db.query_row_and_then("SELECT * FROM foo", [], |row| -> Result<()> {
            assert_eq!(row.get::<_, String>("public")?, "pub txt");
            assert!(row.get::<_, Option<String>>("private")?.is_none());
            Ok(())
        })
        .unwrap();
        db.execute_batch("DROP TABLE foo").unwrap_err();

        db.authorizer(None::<fn(AuthContext<'_>) -> Authorization>);
        db.execute_batch("PRAGMA user_version=1").unwrap(); // Disallowed by first authorizer, but it's now removed.

        Ok(())
    }
}