1use core::ffi::{c_char, c_void};
2use core::marker::PhantomData;
3use core::ptr::NonNull;
4use std::collections::HashMap;
5use std::sync::{Mutex, MutexGuard, OnceLock};
6
7use crate::error::Result;
8use crate::provider::{Sqlite3Api, Sqlite3Hooks};
9
10use super::core::Connection;
11
/// Bit set selecting which trace events a callback registered with
/// `Connection::register_trace` should receive.
///
/// The flag values are the same as SQLite's `SQLITE_TRACE_*` event codes.
/// Masks are combined with `|`; [`TraceMask::default`] is the empty mask.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct TraceMask {
    bits: u32,
}

impl TraceMask {
    /// Flag (`0x01`) for events decoded as `TraceEvent::Stmt`.
    pub const STMT: TraceMask = TraceMask { bits: 0x01 };
    /// Flag (`0x02`) for events decoded as `TraceEvent::Profile`.
    pub const PROFILE: TraceMask = TraceMask { bits: 0x02 };
    /// Flag (`0x04`) for events decoded as `TraceEvent::Row`.
    pub const ROW: TraceMask = TraceMask { bits: 0x04 };
    /// Flag (`0x08`) for events decoded as `TraceEvent::Close`.
    pub const CLOSE: TraceMask = TraceMask { bits: 0x08 };

    /// Returns a mask with no flags set.
    #[must_use]
    pub const fn empty() -> Self {
        Self { bits: 0 }
    }

    /// Returns the raw bit pattern, in the form handed to the provider.
    #[must_use]
    pub const fn bits(self) -> u32 {
        self.bits
    }

    /// Returns `true` when every flag set in `other` is also set in `self`.
    ///
    /// An empty `other` is contained in every mask.
    #[must_use]
    pub const fn contains(self, other: TraceMask) -> bool {
        (self.bits & other.bits) == other.bits
    }
}
43
44impl core::ops::BitOr for TraceMask {
45 type Output = TraceMask;
46
47 fn bitor(self, rhs: TraceMask) -> TraceMask {
48 TraceMask {
49 bits: self.bits | rhs.bits,
50 }
51 }
52}
53
54impl core::ops::BitOrAssign for TraceMask {
55 fn bitor_assign(&mut self, rhs: TraceMask) {
56 self.bits |= rhs.bits;
57 }
58}
59
60pub enum TraceEvent<'a, P: Sqlite3Api> {
62 Stmt {
64 stmt: NonNull<P::Stmt>,
66 sql: Option<&'a str>,
68 },
69 Profile {
71 stmt: NonNull<P::Stmt>,
73 nsec: i64,
75 },
76 Row {
78 stmt: NonNull<P::Stmt>,
80 },
81 Close {
83 db: NonNull<P::Db>,
85 },
86 Raw {
88 mask: u32,
90 p1: *mut c_void,
92 p2: *mut c_void,
94 },
95}
96
97type TraceCallback<P> = dyn for<'a> FnMut(TraceEvent<'a, P>) + Send;
98
99struct TraceState<P: Sqlite3Api> {
100 cb: Box<TraceCallback<P>>,
101}
102
/// Object-safe signature of a user progress closure; its `i32` result is
/// returned to the provider unchanged by `progress_trampoline`.
type ProgressCallback = dyn FnMut() -> i32 + Send;

/// Heap-allocated context whose address is passed to the provider as the
/// progress handler's `ctx` pointer.
struct ProgressState {
    /// The user closure invoked on every progress notification.
    cb: Box<ProgressCallback>,
}
108
109extern "C" fn trace_trampoline<P: Sqlite3Api>(
110 mask: u32,
111 ctx: *mut c_void,
112 p1: *mut c_void,
113 p2: *mut c_void,
114) {
115 let state = unsafe { &mut *(ctx as *mut TraceState<P>) };
116 let _ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
117 let event = decode_trace::<P>(mask, p1, p2);
118 (state.cb)(event);
119 }));
120}
121
122fn decode_trace<'a, P: Sqlite3Api>(
123 mask: u32,
124 p1: *mut c_void,
125 p2: *mut c_void,
126) -> TraceEvent<'a, P> {
127 if (mask & TraceMask::STMT.bits()) != 0 {
128 let stmt = match NonNull::new(p1 as *mut P::Stmt) {
129 Some(stmt) => stmt,
130 None => return TraceEvent::Raw { mask, p1, p2 },
131 };
132 let sql = unsafe { cstr_to_opt(p2 as *const c_char) };
133 return TraceEvent::Stmt { stmt, sql };
134 }
135 if (mask & TraceMask::PROFILE.bits()) != 0 {
136 let stmt = match NonNull::new(p1 as *mut P::Stmt) {
137 Some(stmt) => stmt,
138 None => return TraceEvent::Raw { mask, p1, p2 },
139 };
140 if p2.is_null() {
141 return TraceEvent::Raw { mask, p1, p2 };
142 }
143 let nsec = unsafe { *(p2 as *const i64) };
144 return TraceEvent::Profile { stmt, nsec };
145 }
146 if (mask & TraceMask::ROW.bits()) != 0 {
147 let stmt = match NonNull::new(p1 as *mut P::Stmt) {
148 Some(stmt) => stmt,
149 None => return TraceEvent::Raw { mask, p1, p2 },
150 };
151 return TraceEvent::Row { stmt };
152 }
153 if (mask & TraceMask::CLOSE.bits()) != 0 {
154 let db = match NonNull::new(p1 as *mut P::Db) {
155 Some(db) => db,
156 None => return TraceEvent::Raw { mask, p1, p2 },
157 };
158 return TraceEvent::Close { db };
159 }
160 TraceEvent::Raw { mask, p1, p2 }
161}
162
/// Raw action codes received by an authorizer callback.
///
/// Each constant carries the value of the SQLite macro named in its doc
/// comment; `AuthorizerAction::from_code` maps them onto enum variants.
pub mod authorizer {
    /// `SQLITE_CREATE_INDEX`
    pub const CREATE_INDEX: i32 = 1;
    /// `SQLITE_CREATE_TABLE`
    pub const CREATE_TABLE: i32 = 2;
    /// `SQLITE_CREATE_TEMP_INDEX`
    pub const CREATE_TEMP_INDEX: i32 = 3;
    /// `SQLITE_CREATE_TEMP_TABLE`
    pub const CREATE_TEMP_TABLE: i32 = 4;
    /// `SQLITE_CREATE_TEMP_TRIGGER`
    pub const CREATE_TEMP_TRIGGER: i32 = 5;
    /// `SQLITE_CREATE_TEMP_VIEW`
    pub const CREATE_TEMP_VIEW: i32 = 6;
    /// `SQLITE_CREATE_TRIGGER`
    pub const CREATE_TRIGGER: i32 = 7;
    /// `SQLITE_CREATE_VIEW`
    pub const CREATE_VIEW: i32 = 8;
    /// `SQLITE_DELETE`
    pub const DELETE: i32 = 9;
    /// `SQLITE_DROP_INDEX`
    pub const DROP_INDEX: i32 = 10;
    /// `SQLITE_DROP_TABLE`
    pub const DROP_TABLE: i32 = 11;
    /// `SQLITE_DROP_TEMP_INDEX`
    pub const DROP_TEMP_INDEX: i32 = 12;
    /// `SQLITE_DROP_TEMP_TABLE`
    pub const DROP_TEMP_TABLE: i32 = 13;
    /// `SQLITE_DROP_TEMP_TRIGGER`
    pub const DROP_TEMP_TRIGGER: i32 = 14;
    /// `SQLITE_DROP_TEMP_VIEW`
    pub const DROP_TEMP_VIEW: i32 = 15;
    /// `SQLITE_DROP_TRIGGER`
    pub const DROP_TRIGGER: i32 = 16;
    /// `SQLITE_DROP_VIEW`
    pub const DROP_VIEW: i32 = 17;
    /// `SQLITE_INSERT`
    pub const INSERT: i32 = 18;
    /// `SQLITE_PRAGMA`
    pub const PRAGMA: i32 = 19;
    /// `SQLITE_READ`
    pub const READ: i32 = 20;
    /// `SQLITE_SELECT`
    pub const SELECT: i32 = 21;
    /// `SQLITE_TRANSACTION`
    pub const TRANSACTION: i32 = 22;
    /// `SQLITE_UPDATE`
    pub const UPDATE: i32 = 23;
    /// `SQLITE_ATTACH`
    pub const ATTACH: i32 = 24;
    /// `SQLITE_DETACH`
    pub const DETACH: i32 = 25;
    /// `SQLITE_ALTER_TABLE`
    pub const ALTER_TABLE: i32 = 26;
    /// `SQLITE_REINDEX`
    pub const REINDEX: i32 = 27;
    /// `SQLITE_ANALYZE`
    pub const ANALYZE: i32 = 28;
    /// `SQLITE_CREATE_VTABLE`
    pub const CREATE_VTABLE: i32 = 29;
    /// `SQLITE_DROP_VTABLE`
    pub const DROP_VTABLE: i32 = 30;
    /// `SQLITE_FUNCTION`
    pub const FUNCTION: i32 = 31;
    /// `SQLITE_SAVEPOINT`
    pub const SAVEPOINT: i32 = 32;
    /// `SQLITE_RECURSIVE`
    pub const RECURSIVE: i32 = 33;
}
232
/// Typed form of the raw action code received by an authorizer callback.
///
/// Every named variant corresponds to the like-named constant in the
/// `authorizer` module; codes without a constant are kept in
/// [`AuthorizerAction::Unknown`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AuthorizerAction {
    /// `authorizer::CREATE_INDEX`
    CreateIndex,
    /// `authorizer::CREATE_TABLE`
    CreateTable,
    /// `authorizer::CREATE_TEMP_INDEX`
    CreateTempIndex,
    /// `authorizer::CREATE_TEMP_TABLE`
    CreateTempTable,
    /// `authorizer::CREATE_TEMP_TRIGGER`
    CreateTempTrigger,
    /// `authorizer::CREATE_TEMP_VIEW`
    CreateTempView,
    /// `authorizer::CREATE_TRIGGER`
    CreateTrigger,
    /// `authorizer::CREATE_VIEW`
    CreateView,
    /// `authorizer::DELETE`
    Delete,
    /// `authorizer::DROP_INDEX`
    DropIndex,
    /// `authorizer::DROP_TABLE`
    DropTable,
    /// `authorizer::DROP_TEMP_INDEX`
    DropTempIndex,
    /// `authorizer::DROP_TEMP_TABLE`
    DropTempTable,
    /// `authorizer::DROP_TEMP_TRIGGER`
    DropTempTrigger,
    /// `authorizer::DROP_TEMP_VIEW`
    DropTempView,
    /// `authorizer::DROP_TRIGGER`
    DropTrigger,
    /// `authorizer::DROP_VIEW`
    DropView,
    /// `authorizer::INSERT`
    Insert,
    /// `authorizer::PRAGMA`
    Pragma,
    /// `authorizer::READ`
    Read,
    /// `authorizer::SELECT`
    Select,
    /// `authorizer::TRANSACTION`
    Transaction,
    /// `authorizer::UPDATE`
    Update,
    /// `authorizer::ATTACH`
    Attach,
    /// `authorizer::DETACH`
    Detach,
    /// `authorizer::ALTER_TABLE`
    AlterTable,
    /// `authorizer::REINDEX`
    Reindex,
    /// `authorizer::ANALYZE`
    Analyze,
    /// `authorizer::CREATE_VTABLE`
    CreateVTable,
    /// `authorizer::DROP_VTABLE`
    DropVTable,
    /// `authorizer::FUNCTION`
    Function,
    /// `authorizer::SAVEPOINT`
    Savepoint,
    /// `authorizer::RECURSIVE`
    Recursive,
    /// Any code with no dedicated variant; the raw value is preserved.
    Unknown(i32),
}
305
306impl AuthorizerAction {
307 pub fn from_code(code: i32) -> Self {
309 match code {
310 authorizer::CREATE_INDEX => Self::CreateIndex,
311 authorizer::CREATE_TABLE => Self::CreateTable,
312 authorizer::CREATE_TEMP_INDEX => Self::CreateTempIndex,
313 authorizer::CREATE_TEMP_TABLE => Self::CreateTempTable,
314 authorizer::CREATE_TEMP_TRIGGER => Self::CreateTempTrigger,
315 authorizer::CREATE_TEMP_VIEW => Self::CreateTempView,
316 authorizer::CREATE_TRIGGER => Self::CreateTrigger,
317 authorizer::CREATE_VIEW => Self::CreateView,
318 authorizer::DELETE => Self::Delete,
319 authorizer::DROP_INDEX => Self::DropIndex,
320 authorizer::DROP_TABLE => Self::DropTable,
321 authorizer::DROP_TEMP_INDEX => Self::DropTempIndex,
322 authorizer::DROP_TEMP_TABLE => Self::DropTempTable,
323 authorizer::DROP_TEMP_TRIGGER => Self::DropTempTrigger,
324 authorizer::DROP_TEMP_VIEW => Self::DropTempView,
325 authorizer::DROP_TRIGGER => Self::DropTrigger,
326 authorizer::DROP_VIEW => Self::DropView,
327 authorizer::INSERT => Self::Insert,
328 authorizer::PRAGMA => Self::Pragma,
329 authorizer::READ => Self::Read,
330 authorizer::SELECT => Self::Select,
331 authorizer::TRANSACTION => Self::Transaction,
332 authorizer::UPDATE => Self::Update,
333 authorizer::ATTACH => Self::Attach,
334 authorizer::DETACH => Self::Detach,
335 authorizer::ALTER_TABLE => Self::AlterTable,
336 authorizer::REINDEX => Self::Reindex,
337 authorizer::ANALYZE => Self::Analyze,
338 authorizer::CREATE_VTABLE => Self::CreateVTable,
339 authorizer::DROP_VTABLE => Self::DropVTable,
340 authorizer::FUNCTION => Self::Function,
341 authorizer::SAVEPOINT => Self::Savepoint,
342 authorizer::RECURSIVE => Self::Recursive,
343 other => Self::Unknown(other),
344 }
345 }
346}
347
/// Verdict returned by an authorizer closure.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AuthorizerResult {
    /// Encoded as `0` (`SQLITE_OK`).
    Ok,
    /// Encoded as `2` (`SQLITE_IGNORE`).
    Ignore,
    /// Encoded as `1` (`SQLITE_DENY`).
    Deny,
}

impl AuthorizerResult {
    /// Returns the integer code that the authorizer trampoline hands back to
    /// the provider for this verdict.
    pub fn into_code(self) -> i32 {
        const SQLITE_OK: i32 = 0;
        const SQLITE_DENY: i32 = 1;
        const SQLITE_IGNORE: i32 = 2;

        match self {
            Self::Ok => SQLITE_OK,
            Self::Deny => SQLITE_DENY,
            Self::Ignore => SQLITE_IGNORE,
        }
    }
}
369
370pub struct AuthorizerEvent<'a> {
372 pub action: AuthorizerAction,
374 pub code: i32,
376 pub arg1: Option<&'a str>,
378 pub arg2: Option<&'a str>,
380 pub db_name: Option<&'a str>,
382 pub trigger_or_view: Option<&'a str>,
384}
385
386struct AuthorizerState {
387 cb: Box<dyn for<'a> FnMut(AuthorizerEvent<'a>) -> AuthorizerResult + Send>,
388}
389
390extern "C" fn authorizer_trampoline(
391 ctx: *mut c_void,
392 action: i32,
393 arg1: *const c_char,
394 arg2: *const c_char,
395 db_name: *const c_char,
396 trigger_or_view: *const c_char,
397) -> i32 {
398 if ctx.is_null() {
399 return AuthorizerResult::Deny.into_code();
400 }
401 let state = unsafe { &mut *(ctx as *mut AuthorizerState) };
402 let mut out = AuthorizerResult::Deny.into_code();
404 let _ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
405 let event = AuthorizerEvent {
406 action: AuthorizerAction::from_code(action),
407 code: action,
408 arg1: unsafe { cstr_to_opt(arg1) },
409 arg2: unsafe { cstr_to_opt(arg2) },
410 db_name: unsafe { cstr_to_opt(db_name) },
411 trigger_or_view: unsafe { cstr_to_opt(trigger_or_view) },
412 };
413 out = (state.cb)(event).into_code();
414 }));
415 out
416}
417
/// Borrows a C string as `&str`.
///
/// Returns `None` for a null pointer and also for text that is not valid
/// UTF-8, so callers cannot tell those two cases apart.
///
/// # Safety
///
/// A non-null `ptr` must point to a NUL-terminated string that remains valid
/// and unmodified for the whole of `'a`, which the caller chooses freely.
unsafe fn cstr_to_opt<'a>(ptr: *const c_char) -> Option<&'a str> {
    if ptr.is_null() {
        None
    } else {
        // SAFETY: `ptr` is non-null here; validity is the caller's contract.
        let text = unsafe { core::ffi::CStr::from_ptr(ptr) };
        match text.to_str() {
            Ok(utf8) => Some(utf8),
            Err(_) => None,
        }
    }
}
424
425extern "C" fn progress_trampoline(ctx: *mut c_void) -> i32 {
426 if ctx.is_null() {
427 return 0;
428 }
429 let state = unsafe { &mut *(ctx as *mut ProgressState) };
430 let mut out = 1;
432 let _ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
433 out = (state.cb)();
434 }));
435 out
436}
437
438type CallbackRegistry = HashMap<CallbackRegistryKey, usize>;
439
440#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
441struct CallbackRegistryKey {
442 db_addr: usize,
443 kind: CallbackKind,
444}
445
446fn callback_registry() -> &'static Mutex<CallbackRegistry> {
447 static REGISTRY: OnceLock<Mutex<CallbackRegistry>> = OnceLock::new();
448 REGISTRY.get_or_init(|| Mutex::new(HashMap::new()))
449}
450
451fn lock_callback_registry() -> MutexGuard<'static, CallbackRegistry> {
452 callback_registry()
453 .lock()
454 .unwrap_or_else(|poisoned| poisoned.into_inner())
455}
456
457fn callback_registry_key<T>(db: NonNull<T>, kind: CallbackKind) -> CallbackRegistryKey {
458 CallbackRegistryKey {
459 db_addr: db.as_ptr() as usize,
460 kind,
461 }
462}
463
/// The callback slots of a connection that this module manages.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum CallbackKind {
    /// Slot set through the provider's `trace_v2`.
    Trace,
    /// Slot set through the provider's `set_authorizer`.
    Authorizer,
    /// Slot set through the provider's `progress_handler`.
    Progress,
}
470
471pub struct CallbackHandle<'c, 'p, P: Sqlite3Hooks> {
476 api: &'p P,
477 db: NonNull<P::Db>,
478 kind: CallbackKind,
479 ctx: *mut c_void,
480 _conn: PhantomData<&'c Connection<'p, P>>,
481}
482
483unsafe fn drop_callback_context<P: Sqlite3Hooks>(kind: CallbackKind, ctx: *mut c_void) {
484 if ctx.is_null() {
485 return;
486 }
487 match kind {
488 CallbackKind::Trace => unsafe { drop(Box::from_raw(ctx as *mut TraceState<P>)) },
489 CallbackKind::Authorizer => unsafe { drop(Box::from_raw(ctx as *mut AuthorizerState)) },
490 CallbackKind::Progress => unsafe { drop(Box::from_raw(ctx as *mut ProgressState)) },
491 }
492}
493
494unsafe fn unregister_callback<P: Sqlite3Hooks>(
495 api: &P,
496 db: NonNull<P::Db>,
497 kind: CallbackKind,
498) -> Result<()> {
499 match kind {
500 CallbackKind::Trace => unsafe { api.trace_v2(db, 0, None, core::ptr::null_mut()) },
501 CallbackKind::Authorizer => unsafe { api.set_authorizer(db, None, core::ptr::null_mut()) },
502 CallbackKind::Progress => unsafe {
503 api.progress_handler(db, 0, None, core::ptr::null_mut())
504 },
505 }
506}
507
508impl<'c, 'p, P: Sqlite3Hooks> CallbackHandle<'c, 'p, P> {
509 fn new_trace(conn: &'c Connection<'p, P>, ctx: *mut c_void) -> Self {
510 Self {
511 api: conn.api,
512 db: conn.db,
513 kind: CallbackKind::Trace,
514 ctx,
515 _conn: PhantomData,
516 }
517 }
518
519 fn new_authorizer(conn: &'c Connection<'p, P>, ctx: *mut c_void) -> Self {
520 Self {
521 api: conn.api,
522 db: conn.db,
523 kind: CallbackKind::Authorizer,
524 ctx,
525 _conn: PhantomData,
526 }
527 }
528
529 fn new_progress(conn: &'c Connection<'p, P>, ctx: *mut c_void) -> Self {
530 Self {
531 api: conn.api,
532 db: conn.db,
533 kind: CallbackKind::Progress,
534 ctx,
535 _conn: PhantomData,
536 }
537 }
538}
539
540impl<'c, 'p, P: Sqlite3Hooks> Drop for CallbackHandle<'c, 'p, P> {
541 fn drop(&mut self) {
542 let key = callback_registry_key(self.db, self.kind);
543 let mut registry = lock_callback_registry();
544 let is_active = registry.get(&key).copied() == Some(self.ctx as usize);
545 if !is_active {
546 drop(registry);
547 unsafe { drop_callback_context::<P>(self.kind, self.ctx) };
548 return;
549 }
550
551 if unsafe { unregister_callback(self.api, self.db, self.kind) }.is_ok() {
553 if registry.get(&key).copied() == Some(self.ctx as usize) {
554 registry.remove(&key);
555 }
556 drop(registry);
557 unsafe { drop_callback_context::<P>(self.kind, self.ctx) };
558 }
559 }
560}
561
562impl<'p, P: Sqlite3Hooks> Connection<'p, P> {
563 pub fn busy_timeout(&self, ms: i32) -> Result<()> {
565 unsafe { self.api.busy_timeout(self.db, ms) }
566 }
567
568 pub unsafe fn progress_handler_raw(
574 &self,
575 n: i32,
576 cb: Option<extern "C" fn(*mut c_void) -> i32>,
577 context: *mut c_void,
578 ) -> Result<()> {
579 let key = callback_registry_key(self.db, CallbackKind::Progress);
580 let mut registry = lock_callback_registry();
581 let out = unsafe { self.api.progress_handler(self.db, n, cb, context) };
582 if out.is_ok() {
583 registry.remove(&key);
584 }
585 out
586 }
587
588 pub fn register_progress_handler<'c, F>(
593 &'c self,
594 n: i32,
595 f: F,
596 ) -> Result<CallbackHandle<'c, 'p, P>>
597 where
598 F: FnMut() -> i32 + Send + 'static,
599 {
600 let state = Box::new(ProgressState { cb: Box::new(f) });
601 let ctx = Box::into_raw(state) as *mut c_void;
602 let key = callback_registry_key(self.db, CallbackKind::Progress);
603 let mut registry = lock_callback_registry();
604 if let Err(err) = unsafe {
605 self.api
606 .progress_handler(self.db, n, Some(progress_trampoline), ctx)
607 } {
608 unsafe { drop(Box::from_raw(ctx as *mut ProgressState)) };
609 return Err(err);
610 }
611 registry.insert(key, ctx as usize);
612 Ok(CallbackHandle::new_progress(self, ctx))
613 }
614
615 pub fn clear_progress_handler(&self) -> Result<()> {
617 let key = callback_registry_key(self.db, CallbackKind::Progress);
618 let mut registry = lock_callback_registry();
619 let out = unsafe {
620 self.api
621 .progress_handler(self.db, 0, None, core::ptr::null_mut())
622 };
623 if out.is_ok() {
624 registry.remove(&key);
625 }
626 out
627 }
628
629 pub fn register_trace<'c, F>(
631 &'c self,
632 mask: TraceMask,
633 f: F,
634 ) -> Result<CallbackHandle<'c, 'p, P>>
635 where
636 F: for<'a> FnMut(TraceEvent<'a, P>) + Send + 'static,
637 {
638 let state = Box::new(TraceState::<P> { cb: Box::new(f) });
639 let ctx = Box::into_raw(state) as *mut c_void;
640 let key = callback_registry_key(self.db, CallbackKind::Trace);
641 let mut registry = lock_callback_registry();
642 if let Err(err) = unsafe {
643 self.api
644 .trace_v2(self.db, mask.bits(), Some(trace_trampoline::<P>), ctx)
645 } {
646 unsafe { drop(Box::from_raw(ctx as *mut TraceState<P>)) };
647 return Err(err);
648 }
649 registry.insert(key, ctx as usize);
650 Ok(CallbackHandle::new_trace(self, ctx))
651 }
652
653 pub fn register_authorizer<'c, F>(&'c self, f: F) -> Result<CallbackHandle<'c, 'p, P>>
655 where
656 F: FnMut(AuthorizerEvent<'_>) -> AuthorizerResult + Send + 'static,
657 {
658 let state = Box::new(AuthorizerState { cb: Box::new(f) });
659 let ctx = Box::into_raw(state) as *mut c_void;
660 let key = callback_registry_key(self.db, CallbackKind::Authorizer);
661 let mut registry = lock_callback_registry();
662 if let Err(err) = unsafe {
663 self.api
664 .set_authorizer(self.db, Some(authorizer_trampoline), ctx)
665 } {
666 unsafe { drop(Box::from_raw(ctx as *mut AuthorizerState)) };
667 return Err(err);
668 }
669 registry.insert(key, ctx as usize);
670 Ok(CallbackHandle::new_authorizer(self, ctx))
671 }
672}