1use core::ffi::{c_char, c_void};
2use core::ptr::NonNull;
3
4use crate::error::{Error, Result};
5use crate::provider::{
6 ColumnMetadata, FeatureSet, OpenOptions, OwnedBytes, RawBytes, Sqlite3Api, Sqlite3Backup,
7 Sqlite3BlobIo, Sqlite3Hooks, Sqlite3Keying, Sqlite3Metadata, Sqlite3Serialize, Sqlite3Wal,
8};
9use crate::statement::Statement;
10
11pub struct Connection<'p, P: Sqlite3Api> {
13 pub(crate) api: &'p P,
14 pub(crate) db: NonNull<P::Db>,
15}
16
17impl<'p, P: Sqlite3Api> Connection<'p, P> {
18 pub fn open(api: &'p P, filename: &str, options: OpenOptions<'_>) -> Result<Self> {
20 let db = unsafe { api.open(filename, options)? };
21 Ok(Self { api, db })
22 }
23
24 pub fn prepare(&self, sql: &str) -> Result<Statement<'_, 'p, P>> {
26 let stmt = unsafe {
27 if self.api.feature_set().contains(FeatureSet::PREPARE_V3) {
28 self.api.prepare_v3(self.db, sql, 0)?
29 } else {
30 self.api.prepare_v2(self.db, sql)?
31 }
32 };
33 Ok(Statement::new(self, stmt))
34 }
35
36 pub fn prepare_with_flags(&self, sql: &str, flags: u32) -> Result<Statement<'_, 'p, P>> {
38 if !self.api.feature_set().contains(FeatureSet::PREPARE_V3) {
39 return Err(Error::feature_unavailable("prepare_v3 unsupported"));
40 }
41 let stmt = unsafe { self.api.prepare_v3(self.db, sql, flags)? };
42 Ok(Statement::new(self, stmt))
43 }
44
45 pub fn raw_handle(&self) -> NonNull<P::Db> {
47 self.db
48 }
49}
50
51impl<'p, P: Sqlite3Keying> Connection<'p, P> {
52 pub fn open_with_key(
54 api: &'p P,
55 filename: &str,
56 options: OpenOptions<'_>,
57 key: &[u8],
58 ) -> Result<Self> {
59 let db = unsafe { api.open(filename, options)? };
60 if let Err(err) = unsafe { api.key(db, key) } {
61 let _ = unsafe { api.close(db) };
62 return Err(err);
63 }
64 Ok(Self { api, db })
65 }
66
67 pub fn rekey(&self, key: &[u8]) -> Result<()> {
69 unsafe { self.api.rekey(self.db, key) }
70 }
71}
72
73impl<'p, P: Sqlite3Api> Drop for Connection<'p, P> {
74 fn drop(&mut self) {
75 let _ = unsafe { self.api.close(self.db) };
76 }
77}
78
/// Bit set selecting which trace events a trace callback receives.
///
/// The individual bits are the ones tested when raw trace arguments are
/// decoded into a `TraceEvent`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct TraceMask {
    bits: u32,
}

impl TraceMask {
    /// Selects `TraceEvent::Stmt` events.
    pub const STMT: TraceMask = TraceMask { bits: 0b0001 };
    /// Selects `TraceEvent::Profile` events.
    pub const PROFILE: TraceMask = TraceMask { bits: 0b0010 };
    /// Selects `TraceEvent::Row` events.
    pub const ROW: TraceMask = TraceMask { bits: 0b0100 };
    /// Selects `TraceEvent::Close` events.
    pub const CLOSE: TraceMask = TraceMask { bits: 0b1000 };

    /// A mask with no bits set (disables tracing when registered).
    pub const fn empty() -> Self {
        TraceMask { bits: 0 }
    }

    /// Raw bit representation.
    pub const fn bits(self) -> u32 {
        self.bits
    }

    /// Returns `true` when every bit of `other` is also set in `self`.
    ///
    /// An empty `other` is always contained.
    pub const fn contains(self, other: TraceMask) -> bool {
        (other.bits & !self.bits) == 0
    }
}

impl core::ops::BitOr for TraceMask {
    type Output = TraceMask;

    /// Union of two masks.
    fn bitor(mut self, rhs: TraceMask) -> TraceMask {
        self |= rhs;
        self
    }
}

impl core::ops::BitOrAssign for TraceMask {
    /// Adds every bit of `rhs` to `self`.
    fn bitor_assign(&mut self, rhs: TraceMask) {
        *self = TraceMask { bits: self.bits | rhs.bits };
    }
}
117
118pub enum TraceEvent<'a, P: Sqlite3Api> {
120 Stmt { stmt: NonNull<P::Stmt>, sql: Option<&'a str> },
121 Profile { stmt: NonNull<P::Stmt>, nsec: i64 },
122 Row { stmt: NonNull<P::Stmt> },
123 Close { db: NonNull<P::Db> },
124 Raw { mask: u32, p1: *mut c_void, p2: *mut c_void },
125}
126
127type TraceCallback<P> = dyn for<'a> FnMut(TraceEvent<'a, P>) + Send;
128
129struct TraceState<P: Sqlite3Api> {
130 cb: Box<TraceCallback<P>>,
131}
132
133extern "C" fn trace_trampoline<P: Sqlite3Api>(
134 mask: u32,
135 ctx: *mut c_void,
136 p1: *mut c_void,
137 p2: *mut c_void,
138) {
139 let state = unsafe { &mut *(ctx as *mut TraceState<P>) };
140 let _ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
141 let event = decode_trace::<P>(mask, p1, p2);
142 (state.cb)(event);
143 }));
144}
145
146fn decode_trace<'a, P: Sqlite3Api>(
147 mask: u32,
148 p1: *mut c_void,
149 p2: *mut c_void,
150) -> TraceEvent<'a, P> {
151 if (mask & TraceMask::STMT.bits()) != 0 {
152 let stmt = match NonNull::new(p1 as *mut P::Stmt) {
153 Some(stmt) => stmt,
154 None => return TraceEvent::Raw { mask, p1, p2 },
155 };
156 let sql = unsafe { cstr_to_opt(p2 as *const c_char) };
157 return TraceEvent::Stmt { stmt, sql };
158 }
159 if (mask & TraceMask::PROFILE.bits()) != 0 {
160 let stmt = match NonNull::new(p1 as *mut P::Stmt) {
161 Some(stmt) => stmt,
162 None => return TraceEvent::Raw { mask, p1, p2 },
163 };
164 if p2.is_null() {
165 return TraceEvent::Raw { mask, p1, p2 };
166 }
167 let nsec = unsafe { *(p2 as *const i64) };
168 return TraceEvent::Profile { stmt, nsec };
169 }
170 if (mask & TraceMask::ROW.bits()) != 0 {
171 let stmt = match NonNull::new(p1 as *mut P::Stmt) {
172 Some(stmt) => stmt,
173 None => return TraceEvent::Raw { mask, p1, p2 },
174 };
175 return TraceEvent::Row { stmt };
176 }
177 if (mask & TraceMask::CLOSE.bits()) != 0 {
178 let db = match NonNull::new(p1 as *mut P::Db) {
179 Some(db) => db,
180 None => return TraceEvent::Raw { mask, p1, p2 },
181 };
182 return TraceEvent::Close { db };
183 }
184 TraceEvent::Raw { mask, p1, p2 }
185}
186
/// Raw action codes passed as the second argument of the authorizer callback.
///
/// The values mirror SQLite's `SQLITE_CREATE_INDEX` (1) through
/// `SQLITE_RECURSIVE` (33) authorizer action codes.
pub mod authorizer {
    pub const CREATE_INDEX: i32 = 1;
    pub const CREATE_TABLE: i32 = 2;
    pub const CREATE_TEMP_INDEX: i32 = 3;
    pub const CREATE_TEMP_TABLE: i32 = 4;
    pub const CREATE_TEMP_TRIGGER: i32 = 5;
    pub const CREATE_TEMP_VIEW: i32 = 6;
    pub const CREATE_TRIGGER: i32 = 7;
    pub const CREATE_VIEW: i32 = 8;
    pub const DELETE: i32 = 9;
    pub const DROP_INDEX: i32 = 10;
    pub const DROP_TABLE: i32 = 11;
    pub const DROP_TEMP_INDEX: i32 = 12;
    pub const DROP_TEMP_TABLE: i32 = 13;
    pub const DROP_TEMP_TRIGGER: i32 = 14;
    pub const DROP_TEMP_VIEW: i32 = 15;
    pub const DROP_TRIGGER: i32 = 16;
    pub const DROP_VIEW: i32 = 17;
    pub const INSERT: i32 = 18;
    pub const PRAGMA: i32 = 19;
    pub const READ: i32 = 20;
    pub const SELECT: i32 = 21;
    pub const TRANSACTION: i32 = 22;
    pub const UPDATE: i32 = 23;
    pub const ATTACH: i32 = 24;
    pub const DETACH: i32 = 25;
    pub const ALTER_TABLE: i32 = 26;
    pub const REINDEX: i32 = 27;
    pub const ANALYZE: i32 = 28;
    pub const CREATE_VTABLE: i32 = 29;
    pub const DROP_VTABLE: i32 = 30;
    pub const FUNCTION: i32 = 31;
    pub const SAVEPOINT: i32 = 32;
    pub const RECURSIVE: i32 = 33;
}

/// Typed form of an authorizer action code (see [`authorizer`]).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum AuthorizerAction {
    CreateIndex,
    CreateTable,
    CreateTempIndex,
    CreateTempTable,
    CreateTempTrigger,
    CreateTempView,
    CreateTrigger,
    CreateView,
    Delete,
    DropIndex,
    DropTable,
    DropTempIndex,
    DropTempTable,
    DropTempTrigger,
    DropTempView,
    DropTrigger,
    DropView,
    Insert,
    Pragma,
    Read,
    Select,
    Transaction,
    Update,
    Attach,
    Detach,
    AlterTable,
    Reindex,
    Analyze,
    CreateVTable,
    DropVTable,
    Function,
    Savepoint,
    Recursive,
    /// A code not listed in [`authorizer`], kept verbatim.
    Unknown(i32),
}

impl AuthorizerAction {
    /// Decodes a raw action code.
    ///
    /// Codes not listed in [`authorizer`] become [`AuthorizerAction::Unknown`].
    pub fn from_code(code: i32) -> Self {
        match code {
            authorizer::CREATE_INDEX => Self::CreateIndex,
            authorizer::CREATE_TABLE => Self::CreateTable,
            authorizer::CREATE_TEMP_INDEX => Self::CreateTempIndex,
            authorizer::CREATE_TEMP_TABLE => Self::CreateTempTable,
            authorizer::CREATE_TEMP_TRIGGER => Self::CreateTempTrigger,
            authorizer::CREATE_TEMP_VIEW => Self::CreateTempView,
            authorizer::CREATE_TRIGGER => Self::CreateTrigger,
            authorizer::CREATE_VIEW => Self::CreateView,
            authorizer::DELETE => Self::Delete,
            authorizer::DROP_INDEX => Self::DropIndex,
            authorizer::DROP_TABLE => Self::DropTable,
            authorizer::DROP_TEMP_INDEX => Self::DropTempIndex,
            authorizer::DROP_TEMP_TABLE => Self::DropTempTable,
            authorizer::DROP_TEMP_TRIGGER => Self::DropTempTrigger,
            authorizer::DROP_TEMP_VIEW => Self::DropTempView,
            authorizer::DROP_TRIGGER => Self::DropTrigger,
            authorizer::DROP_VIEW => Self::DropView,
            authorizer::INSERT => Self::Insert,
            authorizer::PRAGMA => Self::Pragma,
            authorizer::READ => Self::Read,
            authorizer::SELECT => Self::Select,
            authorizer::TRANSACTION => Self::Transaction,
            authorizer::UPDATE => Self::Update,
            authorizer::ATTACH => Self::Attach,
            authorizer::DETACH => Self::Detach,
            authorizer::ALTER_TABLE => Self::AlterTable,
            authorizer::REINDEX => Self::Reindex,
            authorizer::ANALYZE => Self::Analyze,
            authorizer::CREATE_VTABLE => Self::CreateVTable,
            authorizer::DROP_VTABLE => Self::DropVTable,
            authorizer::FUNCTION => Self::Function,
            authorizer::SAVEPOINT => Self::Savepoint,
            authorizer::RECURSIVE => Self::Recursive,
            other => Self::Unknown(other),
        }
    }

    /// Returns the raw action code for this action.
    ///
    /// This is the inverse of [`AuthorizerAction::from_code`]: for every `c`,
    /// `AuthorizerAction::from_code(c).code() == c`. `Unknown(c)` yields `c`.
    pub const fn code(self) -> i32 {
        match self {
            Self::CreateIndex => authorizer::CREATE_INDEX,
            Self::CreateTable => authorizer::CREATE_TABLE,
            Self::CreateTempIndex => authorizer::CREATE_TEMP_INDEX,
            Self::CreateTempTable => authorizer::CREATE_TEMP_TABLE,
            Self::CreateTempTrigger => authorizer::CREATE_TEMP_TRIGGER,
            Self::CreateTempView => authorizer::CREATE_TEMP_VIEW,
            Self::CreateTrigger => authorizer::CREATE_TRIGGER,
            Self::CreateView => authorizer::CREATE_VIEW,
            Self::Delete => authorizer::DELETE,
            Self::DropIndex => authorizer::DROP_INDEX,
            Self::DropTable => authorizer::DROP_TABLE,
            Self::DropTempIndex => authorizer::DROP_TEMP_INDEX,
            Self::DropTempTable => authorizer::DROP_TEMP_TABLE,
            Self::DropTempTrigger => authorizer::DROP_TEMP_TRIGGER,
            Self::DropTempView => authorizer::DROP_TEMP_VIEW,
            Self::DropTrigger => authorizer::DROP_TRIGGER,
            Self::DropView => authorizer::DROP_VIEW,
            Self::Insert => authorizer::INSERT,
            Self::Pragma => authorizer::PRAGMA,
            Self::Read => authorizer::READ,
            Self::Select => authorizer::SELECT,
            Self::Transaction => authorizer::TRANSACTION,
            Self::Update => authorizer::UPDATE,
            Self::Attach => authorizer::ATTACH,
            Self::Detach => authorizer::DETACH,
            Self::AlterTable => authorizer::ALTER_TABLE,
            Self::Reindex => authorizer::REINDEX,
            Self::Analyze => authorizer::ANALYZE,
            Self::CreateVTable => authorizer::CREATE_VTABLE,
            Self::DropVTable => authorizer::DROP_VTABLE,
            Self::Function => authorizer::FUNCTION,
            Self::Savepoint => authorizer::SAVEPOINT,
            Self::Recursive => authorizer::RECURSIVE,
            Self::Unknown(other) => other,
        }
    }
}

/// Decision returned by a user authorizer callback.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum AuthorizerResult {
    /// Allow the action (SQLite `SQLITE_OK`).
    Ok,
    /// Disallow without raising an error (SQLite `SQLITE_IGNORE`).
    Ignore,
    /// Reject the statement with an authorization error (SQLite `SQLITE_DENY`).
    Deny,
}

impl AuthorizerResult {
    /// Returns this type's numeric code: `Ok` → 0, `Ignore` → 1, `Deny` → 2.
    ///
    /// NOTE(review): SQLite defines `SQLITE_DENY` as 1 and `SQLITE_IGNORE` as
    /// 2, so these values must NOT be returned to SQLite directly. The
    /// authorizer trampoline translates through `sqlite_auth_code` instead.
    /// The values are left unchanged because the module's unit tests pin them.
    /// Confirm whether any external caller relies on them before aligning them
    /// with SQLite.
    pub fn into_code(self) -> i32 {
        match self {
            Self::Ok => 0,
            Self::Ignore => 1,
            Self::Deny => 2,
        }
    }
}

/// Translates a decision into the status SQLite expects from an authorizer
/// callback: `SQLITE_OK` (0), `SQLITE_DENY` (1) or `SQLITE_IGNORE` (2).
fn sqlite_auth_code(result: AuthorizerResult) -> i32 {
    match result {
        AuthorizerResult::Ok => 0,
        AuthorizerResult::Deny => 1,
        AuthorizerResult::Ignore => 2,
    }
}

/// Arguments of one authorizer invocation, decoded for the user callback.
///
/// Each string field is `None` when the corresponding pointer was null or the
/// bytes were not valid UTF-8.
pub struct AuthorizerEvent<'a> {
    /// Typed action decoded from `code`.
    pub action: AuthorizerAction,
    /// Raw action code as received.
    pub code: i32,
    /// First action-specific detail string (meaning depends on `action`).
    pub arg1: Option<&'a str>,
    /// Second action-specific detail string (meaning depends on `action`).
    pub arg2: Option<&'a str>,
    /// Database name argument, if supplied.
    pub db_name: Option<&'a str>,
    /// Trigger or view name argument, if supplied.
    pub trigger_or_view: Option<&'a str>,
}

/// Heap-allocated state whose address is registered as the authorizer context.
struct AuthorizerState {
    /// User callback invoked by `authorizer_trampoline`.
    cb: Box<dyn for<'a> FnMut(AuthorizerEvent<'a>) -> AuthorizerResult + Send>,
}

/// C-ABI authorizer callback registered with the provider's `set_authorizer`.
///
/// Decodes the raw arguments, runs the user callback, and returns the SQLite
/// status code for its decision. It fails closed: a null context or a
/// panicking callback yields `SQLITE_DENY` instead of allowing the action.
extern "C" fn authorizer_trampoline(
    ctx: *mut c_void,
    action: i32,
    arg1: *const c_char,
    arg2: *const c_char,
    db_name: *const c_char,
    trigger_or_view: *const c_char,
) -> i32 {
    let deny = sqlite_auth_code(AuthorizerResult::Deny);
    if ctx.is_null() {
        return deny;
    }
    // SAFETY: a non-null `ctx` is the `Box::into_raw` pointer of an
    // `AuthorizerState` created at registration and freed only after the
    // authorizer has been unregistered.
    let state = unsafe { &mut *(ctx as *mut AuthorizerState) };
    let decision = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
        // SAFETY (each `cstr_to_opt`): every string argument is null or a
        // NUL-terminated string valid for the duration of this callback.
        let event = AuthorizerEvent {
            action: AuthorizerAction::from_code(action),
            code: action,
            arg1: unsafe { cstr_to_opt(arg1) },
            arg2: unsafe { cstr_to_opt(arg2) },
            db_name: unsafe { cstr_to_opt(db_name) },
            trigger_or_view: unsafe { cstr_to_opt(trigger_or_view) },
        };
        (state.cb)(event)
    }));
    // A panic must never be read as permission: deny rather than allow.
    decision.map_or(deny, sqlite_auth_code)
}

/// Borrows a C string as `&str`.
///
/// Returns `None` both when `ptr` is null and when the bytes are not valid
/// UTF-8; callers cannot tell these cases apart.
///
/// # Safety
/// `ptr` must be null or point to a NUL-terminated string that remains valid
/// and unmodified for the caller-chosen lifetime `'a`.
unsafe fn cstr_to_opt<'a>(ptr: *const c_char) -> Option<&'a str> {
    if ptr.is_null() {
        return None;
    }
    // SAFETY: upheld by the caller (see `# Safety`).
    unsafe { core::ffi::CStr::from_ptr(ptr) }.to_str().ok()
}
366
367pub struct CallbackHandle<'p, P: Sqlite3Hooks> {
369 api: &'p P,
370 db: NonNull<P::Db>,
371 kind: CallbackKind,
372 ctx: *mut c_void,
373}
374
375enum CallbackKind {
376 Trace,
377 Authorizer,
378}
379
380impl<'p, P: Sqlite3Hooks> CallbackHandle<'p, P> {
381 fn new_trace(api: &'p P, db: NonNull<P::Db>, ctx: *mut c_void) -> Self {
382 Self { api, db, kind: CallbackKind::Trace, ctx }
383 }
384
385 fn new_authorizer(api: &'p P, db: NonNull<P::Db>, ctx: *mut c_void) -> Self {
386 Self { api, db, kind: CallbackKind::Authorizer, ctx }
387 }
388}
389
390impl<'p, P: Sqlite3Hooks> Drop for CallbackHandle<'p, P> {
391 fn drop(&mut self) {
392 unsafe {
393 match self.kind {
394 CallbackKind::Trace => {
395 let _ = self.api.trace_v2(self.db, 0, None, core::ptr::null_mut());
396 drop(Box::from_raw(self.ctx as *mut TraceState<P>));
397 }
398 CallbackKind::Authorizer => {
399 let _ = self.api.set_authorizer(self.db, None, core::ptr::null_mut());
400 drop(Box::from_raw(self.ctx as *mut AuthorizerState));
401 }
402 }
403 }
404 }
405}
406
407impl<'p, P: Sqlite3Hooks> Connection<'p, P> {
408 pub fn busy_timeout(&self, ms: i32) -> Result<()> {
410 unsafe { self.api.busy_timeout(self.db, ms) }
411 }
412
413 pub unsafe fn progress_handler(
419 &self,
420 n: i32,
421 cb: Option<extern "C" fn() -> i32>,
422 ctx: *mut c_void,
423 ) -> Result<()> {
424 unsafe { self.api.progress_handler(self.db, n, cb, ctx) }
425 }
426
427 pub fn register_trace<F>(&self, mask: TraceMask, f: F) -> Result<CallbackHandle<'p, P>>
429 where
430 F: for<'a> FnMut(TraceEvent<'a, P>) + Send + 'static,
431 {
432 let state = Box::new(TraceState::<P> { cb: Box::new(f) });
433 let ctx = Box::into_raw(state) as *mut c_void;
434 unsafe { self.api.trace_v2(self.db, mask.bits(), Some(trace_trampoline::<P>), ctx)? };
435 Ok(CallbackHandle::new_trace(self.api, self.db, ctx))
436 }
437
438 pub fn register_authorizer<F>(&self, f: F) -> Result<CallbackHandle<'p, P>>
440 where
441 F: FnMut(AuthorizerEvent<'_>) -> AuthorizerResult + Send + 'static,
442 {
443 let state = Box::new(AuthorizerState { cb: Box::new(f) });
444 let ctx = Box::into_raw(state) as *mut c_void;
445 unsafe { self.api.set_authorizer(self.db, Some(authorizer_trampoline), ctx)? };
446 Ok(CallbackHandle::new_authorizer(self.api, self.db, ctx))
447 }
448}
449
450pub struct Backup<'p, P: Sqlite3Backup> {
452 api: &'p P,
453 handle: NonNull<P::Backup>,
454}
455
456impl<'p, P: Sqlite3Backup> Backup<'p, P> {
457 pub fn step(&self, pages: i32) -> Result<()> {
459 unsafe { self.api.backup_step(self.handle, pages) }
460 }
461
462 pub fn remaining(&self) -> i32 {
464 unsafe { self.api.backup_remaining(self.handle) }
465 }
466
467 pub fn pagecount(&self) -> i32 {
469 unsafe { self.api.backup_pagecount(self.handle) }
470 }
471}
472
473impl<'p, P: Sqlite3Backup> Drop for Backup<'p, P> {
474 fn drop(&mut self) {
475 let _ = unsafe { self.api.backup_finish(self.handle) };
476 }
477}
478
479impl<'p, P: Sqlite3Backup> Connection<'p, P> {
480 pub fn backup_to(&self, dest: &Connection<'p, P>, name: &str) -> Result<Backup<'p, P>> {
482 let handle = unsafe { self.api.backup_init(dest.db, name, self.db, "main")? };
483 Ok(Backup { api: self.api, handle })
484 }
485}
486
487pub struct Blob<'p, P: Sqlite3BlobIo> {
489 api: &'p P,
490 handle: NonNull<P::Blob>,
491}
492
493impl<'p, P: Sqlite3BlobIo> Blob<'p, P> {
494 pub fn read(&self, buf: &mut [u8], offset: i32) -> Result<()> {
496 unsafe { self.api.blob_read(self.handle, buf, offset) }
497 }
498
499 pub fn write(&self, buf: &[u8], offset: i32) -> Result<()> {
501 unsafe { self.api.blob_write(self.handle, buf, offset) }
502 }
503
504 pub fn len(&self) -> i32 {
506 unsafe { self.api.blob_bytes(self.handle) }
507 }
508
509 pub fn is_empty(&self) -> bool {
511 self.len() == 0
512 }
513}
514
515impl<'p, P: Sqlite3BlobIo> Drop for Blob<'p, P> {
516 fn drop(&mut self) {
517 let _ = unsafe { self.api.blob_close(self.handle) };
518 }
519}
520
521impl<'p, P: Sqlite3BlobIo> Connection<'p, P> {
522 pub fn open_blob(
524 &self,
525 db_name: &str,
526 table: &str,
527 column: &str,
528 rowid: i64,
529 flags: u32,
530 ) -> Result<Blob<'p, P>> {
531 let handle = unsafe { self.api.blob_open(self.db, db_name, table, column, rowid, flags)? };
532 Ok(Blob { api: self.api, handle })
533 }
534}
535
536pub struct SerializedDb<'p, P: Sqlite3Serialize> {
538 api: &'p P,
539 bytes: OwnedBytes,
540}
541
542impl<'p, P: Sqlite3Serialize> SerializedDb<'p, P> {
543 pub fn as_slice(&self) -> &[u8] {
545 unsafe { core::slice::from_raw_parts(self.bytes.ptr.as_ptr(), self.bytes.len) }
546 }
547
548 pub fn into_vec(self) -> Vec<u8> {
550 let me = core::mem::ManuallyDrop::new(self);
551 let vec = me.as_slice().to_vec();
552 unsafe { me.api.free(me.bytes) };
553 vec
554 }
555}
556
557impl<'p, P: Sqlite3Serialize> Drop for SerializedDb<'p, P> {
558 fn drop(&mut self) {
559 unsafe { self.api.free(self.bytes) };
560 }
561}
562
563impl<'p, P: Sqlite3Serialize> Connection<'p, P> {
564 pub fn serialize(&self, schema: Option<&str>, flags: u32) -> Result<SerializedDb<'p, P>> {
566 let bytes = unsafe { self.api.serialize(self.db, schema, flags)? };
567 Ok(SerializedDb { api: self.api, bytes })
568 }
569
570 pub fn deserialize(&self, schema: Option<&str>, data: &[u8], flags: u32) -> Result<()> {
572 unsafe { self.api.deserialize(self.db, schema, data, flags) }
573 }
574}
575
576impl<'p, P: Sqlite3Wal> Connection<'p, P> {
577 pub fn wal_checkpoint(&self, db_name: Option<&str>) -> Result<()> {
579 unsafe { self.api.wal_checkpoint(self.db, db_name) }
580 }
581
582 pub fn wal_checkpoint_v2(&self, db_name: Option<&str>, mode: i32) -> Result<(i32, i32)> {
584 unsafe { self.api.wal_checkpoint_v2(self.db, db_name, mode) }
585 }
586
587 pub fn wal_frame_count(&self) -> Result<Option<u32>> {
589 unsafe { self.api.wal_frame_count(self.db) }
590 }
591}
592
593impl<'p, P: Sqlite3Metadata> Connection<'p, P> {
594 pub fn table_column_metadata(
596 &self,
597 db_name: Option<&str>,
598 table: &str,
599 column: &str,
600 ) -> Result<ColumnMetadata> {
601 unsafe { self.api.table_column_metadata(self.db, db_name, table, column) }
602 }
603}
604
605impl<'p, P: Sqlite3Metadata> Statement<'_, 'p, P> {
606 pub fn column_decltype_raw(&self, col: i32) -> Option<RawBytes> {
608 unsafe { self.conn.api.column_decltype(self.stmt, col) }
609 }
610
611 pub fn column_name_raw(&self, col: i32) -> Option<RawBytes> {
613 unsafe { self.conn.api.column_name(self.stmt, col) }
614 }
615
616 pub fn column_table_name_raw(&self, col: i32) -> Option<RawBytes> {
618 unsafe { self.conn.api.column_table_name(self.stmt, col) }
619 }
620}
621
#[cfg(test)]
mod tests {
    use super::authorizer;
    use super::{AuthorizerAction, AuthorizerResult, TraceMask};

    #[test]
    fn trace_mask_bits() {
        let mask = TraceMask::STMT | TraceMask::PROFILE;
        assert!(mask.contains(TraceMask::STMT));
        assert!(mask.contains(TraceMask::PROFILE));
        assert_eq!(mask.bits(), TraceMask::STMT.bits() | TraceMask::PROFILE.bits());
    }

    #[test]
    fn trace_mask_rejects_missing_bits() {
        let mask = TraceMask::STMT | TraceMask::PROFILE;
        assert!(!mask.contains(TraceMask::ROW));
        // A mask that only partially overlaps is not contained.
        assert!(!mask.contains(TraceMask::STMT | TraceMask::CLOSE));
        // The empty mask is contained in every mask.
        assert!(mask.contains(TraceMask::empty()));
        assert_eq!(TraceMask::empty().bits(), 0);
    }

    #[test]
    fn trace_mask_bitor_assign_matches_bitor() {
        let mut acc = TraceMask::empty();
        acc |= TraceMask::ROW;
        acc |= TraceMask::CLOSE;
        assert_eq!(acc, TraceMask::ROW | TraceMask::CLOSE);
    }

    #[test]
    fn authorizer_action_from_code() {
        assert_eq!(AuthorizerAction::from_code(authorizer::READ), AuthorizerAction::Read);
        assert_eq!(AuthorizerAction::from_code(999), AuthorizerAction::Unknown(999));
    }

    #[test]
    fn authorizer_action_known_codes_are_decoded() {
        for code in authorizer::CREATE_INDEX..=authorizer::RECURSIVE {
            assert!(
                !matches!(AuthorizerAction::from_code(code), AuthorizerAction::Unknown(_)),
                "code {} decoded as Unknown",
                code
            );
        }
        assert_eq!(AuthorizerAction::from_code(0), AuthorizerAction::Unknown(0));
        assert_eq!(AuthorizerAction::from_code(-1), AuthorizerAction::Unknown(-1));
    }

    #[test]
    fn authorizer_result_codes() {
        assert_eq!(AuthorizerResult::Ok.into_code(), 0);
        assert_eq!(AuthorizerResult::Ignore.into_code(), 1);
        assert_eq!(AuthorizerResult::Deny.into_code(), 2);
    }
}