Skip to main content

sqlite_provider_sqlite3/
core_impl.rs

1use super::*;
2
/// Runs the registration destructor `drop_user_data` on `user_data`, when one was supplied.
///
/// Used on failure paths where a registration call bails out before handing
/// `user_data` to SQLite, so the payload is released instead of leaked.
/// With `None` this does nothing.
fn drop_registration_user_data(
    user_data: *mut c_void,
    drop_user_data: Option<extern "C" fn(*mut c_void)>,
) {
    let Some(release) = drop_user_data else {
        return;
    };
    release(user_data);
}
11
12fn registration_name_or_drop(
13    name: &str,
14    user_data: *mut c_void,
15    drop_user_data: Option<extern "C" fn(*mut c_void)>,
16    error_message: &'static str,
17) -> Result<CString> {
18    match CString::new(name) {
19        Ok(name) => Ok(name),
20        Err(_) => {
21            drop_registration_user_data(user_data, drop_user_data);
22            Err(Error::with_message(ErrorCode::Misuse, error_message))
23        }
24    }
25}
26
27#[allow(unsafe_op_in_unsafe_fn)]
28unsafe impl Sqlite3Api for LibSqlite3 {
29    type Db = sqlite3;
30    type Stmt = sqlite3_stmt;
31    type Value = sqlite3_value;
32    type Context = sqlite3_context;
33    type VTab = c_void;
34    type VTabCursor = c_void;
35
36    fn api_version(&self) -> ApiVersion {
37        self.api_version
38    }
39
40    fn feature_set(&self) -> FeatureSet {
41        self.features
42    }
43
44    fn backend_name(&self) -> &'static str {
45        "libsqlite3"
46    }
47
48    fn backend_version(&self) -> Option<ApiVersion> {
49        Some(self.api_version)
50    }
51
52    unsafe fn malloc(&self, size: usize) -> *mut c_void {
53        if size > i32::MAX as usize {
54            return null_mut();
55        }
56        (self.fns.malloc)(size as c_int)
57    }
58
59    unsafe fn free(&self, ptr: *mut c_void) {
60        (self.fns.free)(ptr);
61    }
62
63    fn threadsafe(&self) -> i32 {
64        self.fns.threadsafe.map(|f| unsafe { f() }).unwrap_or(0)
65    }
66
67    unsafe fn open(&self, filename: &str, options: OpenOptions<'_>) -> Result<NonNull<Self::Db>> {
68        let filename = CString::new(filename)
69            .map_err(|_| Error::with_message(ErrorCode::Misuse, "filename contains NUL"))?;
70        let vfs = match options.vfs {
71            Some(vfs) => Some(
72                CString::new(vfs)
73                    .map_err(|_| Error::with_message(ErrorCode::Misuse, "vfs contains NUL"))?,
74            ),
75            None => None,
76        };
77        let mut db = null_mut();
78        let flags = map_open_flags(options.flags);
79        let vfs_ptr = vfs.as_ref().map(|s| s.as_ptr()).unwrap_or(null());
80        let rc = (self.fns.open_v2)(filename.as_ptr(), &mut db, flags, vfs_ptr);
81        if rc != SQLITE_OK {
82            let err = self.error_from_rc(rc, NonNull::new(db));
83            if !db.is_null() {
84                let _ = (self.fns.close)(db);
85            }
86            return Err(err);
87        }
88        Ok(NonNull::new_unchecked(db))
89    }
90
91    unsafe fn close(&self, db: NonNull<Self::Db>) -> Result<()> {
92        let rc = (self.fns.close)(db.as_ptr());
93        if rc == SQLITE_OK {
94            Ok(())
95        } else {
96            Err(self.error_from_rc(rc, Some(db)))
97        }
98    }
99
100    unsafe fn prepare_v2(&self, db: NonNull<Self::Db>, sql: &str) -> Result<NonNull<Self::Stmt>> {
101        let mut stmt = null_mut();
102        let sql_ptr = sql.as_ptr() as *const c_char;
103        let sql_len = clamp_len(sql.len());
104        let rc = (self.fns.prepare_v2)(db.as_ptr(), sql_ptr, sql_len, &mut stmt, null_mut());
105        if rc != SQLITE_OK {
106            return Err(self.error_from_rc(rc, Some(db)));
107        }
108        Ok(NonNull::new_unchecked(stmt))
109    }
110
111    unsafe fn prepare_v3(
112        &self,
113        db: NonNull<Self::Db>,
114        sql: &str,
115        flags: u32,
116    ) -> Result<NonNull<Self::Stmt>> {
117        let prepare = match self.fns.prepare_v3 {
118            Some(prepare) => prepare,
119            None => return Err(Error::feature_unavailable("prepare_v3 not available")),
120        };
121        let mut stmt = null_mut();
122        let sql_ptr = sql.as_ptr() as *const c_char;
123        let sql_len = clamp_len(sql.len());
124        let rc = prepare(db.as_ptr(), sql_ptr, sql_len, flags, &mut stmt, null_mut());
125        if rc != SQLITE_OK {
126            return Err(self.error_from_rc(rc, Some(db)));
127        }
128        Ok(NonNull::new_unchecked(stmt))
129    }
130
131    unsafe fn step(&self, stmt: NonNull<Self::Stmt>) -> Result<StepResult> {
132        match (self.fns.step)(stmt.as_ptr()) {
133            SQLITE_ROW => Ok(StepResult::Row),
134            SQLITE_DONE => Ok(StepResult::Done),
135            rc => Err(self.error_from_rc(rc, None)),
136        }
137    }
138
139    unsafe fn reset(&self, stmt: NonNull<Self::Stmt>) -> Result<()> {
140        let rc = (self.fns.reset)(stmt.as_ptr());
141        if rc == SQLITE_OK {
142            Ok(())
143        } else {
144            Err(self.error_from_rc(rc, None))
145        }
146    }
147
148    unsafe fn finalize(&self, stmt: NonNull<Self::Stmt>) -> Result<()> {
149        let rc = (self.fns.finalize)(stmt.as_ptr());
150        if rc == SQLITE_OK {
151            Ok(())
152        } else {
153            Err(self.error_from_rc(rc, None))
154        }
155    }
156
157    unsafe fn bind_null(&self, stmt: NonNull<Self::Stmt>, idx: i32) -> Result<()> {
158        let rc = (self.fns.bind_null)(stmt.as_ptr(), idx);
159        if rc == SQLITE_OK {
160            Ok(())
161        } else {
162            Err(self.error_from_rc(rc, None))
163        }
164    }
165
166    unsafe fn bind_int64(&self, stmt: NonNull<Self::Stmt>, idx: i32, v: i64) -> Result<()> {
167        let rc = (self.fns.bind_int64)(stmt.as_ptr(), idx, v);
168        if rc == SQLITE_OK {
169            Ok(())
170        } else {
171            Err(self.error_from_rc(rc, None))
172        }
173    }
174
175    unsafe fn bind_double(&self, stmt: NonNull<Self::Stmt>, idx: i32, v: f64) -> Result<()> {
176        let rc = (self.fns.bind_double)(stmt.as_ptr(), idx, v);
177        if rc == SQLITE_OK {
178            Ok(())
179        } else {
180            Err(self.error_from_rc(rc, None))
181        }
182    }
183
184    unsafe fn bind_text(&self, stmt: NonNull<Self::Stmt>, idx: i32, v: &str) -> Result<()> {
185        unsafe { self.bind_text_bytes(stmt, idx, v.as_bytes()) }
186    }
187
188    unsafe fn bind_text_bytes(&self, stmt: NonNull<Self::Stmt>, idx: i32, v: &[u8]) -> Result<()> {
189        let (ptr, dtor) = self.alloc_copy(v)?;
190        let rc = (self.fns.bind_text)(
191            stmt.as_ptr(),
192            idx,
193            ptr as *const c_char,
194            clamp_len(v.len()),
195            dtor,
196        );
197        if rc != SQLITE_OK {
198            return Err(self.error_from_rc(rc, None));
199        }
200        Ok(())
201    }
202
203    unsafe fn bind_blob(&self, stmt: NonNull<Self::Stmt>, idx: i32, v: &[u8]) -> Result<()> {
204        let (ptr, dtor) = self.alloc_copy(v)?;
205        let rc = (self.fns.bind_blob)(stmt.as_ptr(), idx, ptr, clamp_len(v.len()), dtor);
206        if rc != SQLITE_OK {
207            return Err(self.error_from_rc(rc, None));
208        }
209        Ok(())
210    }
211
212    unsafe fn column_count(&self, stmt: NonNull<Self::Stmt>) -> i32 {
213        (self.fns.column_count)(stmt.as_ptr())
214    }
215
216    unsafe fn column_type(&self, stmt: NonNull<Self::Stmt>, col: i32) -> ValueType {
217        ValueType::from_code((self.fns.column_type)(stmt.as_ptr(), col))
218    }
219
220    unsafe fn column_int64(&self, stmt: NonNull<Self::Stmt>, col: i32) -> i64 {
221        (self.fns.column_int64)(stmt.as_ptr(), col)
222    }
223
224    unsafe fn column_double(&self, stmt: NonNull<Self::Stmt>, col: i32) -> f64 {
225        (self.fns.column_double)(stmt.as_ptr(), col)
226    }
227
228    unsafe fn column_text(&self, stmt: NonNull<Self::Stmt>, col: i32) -> RawBytes {
229        let ptr = (self.fns.column_text)(stmt.as_ptr(), col);
230        if ptr.is_null() {
231            return RawBytes::empty();
232        }
233        let len = (self.fns.column_bytes)(stmt.as_ptr(), col);
234        RawBytes {
235            ptr,
236            len: len as usize,
237        }
238    }
239
240    unsafe fn column_blob(&self, stmt: NonNull<Self::Stmt>, col: i32) -> RawBytes {
241        let ptr = (self.fns.column_blob)(stmt.as_ptr(), col) as *const u8;
242        if ptr.is_null() {
243            return RawBytes::empty();
244        }
245        let len = (self.fns.column_bytes)(stmt.as_ptr(), col);
246        RawBytes {
247            ptr,
248            len: len as usize,
249        }
250    }
251
252    unsafe fn errcode(&self, db: NonNull<Self::Db>) -> i32 {
253        (self.fns.errcode)(db.as_ptr())
254    }
255
256    unsafe fn errmsg(&self, db: NonNull<Self::Db>) -> *const c_char {
257        (self.fns.errmsg)(db.as_ptr())
258    }
259
260    unsafe fn extended_errcode(&self, db: NonNull<Self::Db>) -> Option<i32> {
261        self.fns.extended_errcode.map(|f| f(db.as_ptr()))
262    }
263
264    unsafe fn create_function_v2(
265        &self,
266        db: NonNull<Self::Db>,
267        name: &str,
268        n_args: i32,
269        flags: FunctionFlags,
270        x_func: Option<extern "C" fn(*mut Self::Context, i32, *mut *mut Self::Value)>,
271        x_step: Option<extern "C" fn(*mut Self::Context, i32, *mut *mut Self::Value)>,
272        x_final: Option<extern "C" fn(*mut Self::Context)>,
273        user_data: *mut c_void,
274        drop_user_data: Option<extern "C" fn(*mut c_void)>,
275    ) -> Result<()> {
276        let name = registration_name_or_drop(
277            name,
278            user_data,
279            drop_user_data,
280            "function name contains NUL",
281        )?;
282        let flags = map_function_flags(flags);
283        let rc = (self.fns.create_function_v2)(
284            db.as_ptr(),
285            name.as_ptr(),
286            n_args,
287            flags,
288            user_data,
289            x_func,
290            x_step,
291            x_final,
292            drop_user_data,
293        );
294        if rc == SQLITE_OK {
295            Ok(())
296        } else {
297            Err(self.error_from_rc(rc, Some(db)))
298        }
299    }
300
301    unsafe fn create_window_function(
302        &self,
303        db: NonNull<Self::Db>,
304        name: &str,
305        n_args: i32,
306        flags: FunctionFlags,
307        x_step: Option<extern "C" fn(*mut Self::Context, i32, *mut *mut Self::Value)>,
308        x_final: Option<extern "C" fn(*mut Self::Context)>,
309        x_value: Option<extern "C" fn(*mut Self::Context)>,
310        x_inverse: Option<extern "C" fn(*mut Self::Context, i32, *mut *mut Self::Value)>,
311        user_data: *mut c_void,
312        drop_user_data: Option<extern "C" fn(*mut c_void)>,
313    ) -> Result<()> {
314        let create = match self.fns.create_window_function {
315            Some(create) => create,
316            None => {
317                drop_registration_user_data(user_data, drop_user_data);
318                return Err(Error::feature_unavailable(
319                    "create_window_function not available",
320                ));
321            }
322        };
323        let name = registration_name_or_drop(
324            name,
325            user_data,
326            drop_user_data,
327            "function name contains NUL",
328        )?;
329        let flags = map_function_flags(flags);
330        let rc = create(
331            db.as_ptr(),
332            name.as_ptr(),
333            n_args,
334            flags,
335            user_data,
336            x_step,
337            x_final,
338            x_value,
339            x_inverse,
340            drop_user_data,
341        );
342        if rc == SQLITE_OK {
343            Ok(())
344        } else {
345            Err(self.error_from_rc(rc, Some(db)))
346        }
347    }
348
349    unsafe fn create_collation_v2(
350        &self,
351        db: NonNull<Self::Db>,
352        name: &str,
353        enc: i32,
354        context: *mut c_void,
355        cmp: Option<extern "C" fn(*mut c_void, i32, *const c_void, i32, *const c_void) -> i32>,
356        destroy: Option<extern "C" fn(*mut c_void)>,
357    ) -> Result<()> {
358        let create = match self.fns.create_collation_v2 {
359            Some(create) => create,
360            None => {
361                return Err(Error::feature_unavailable(
362                    "create_collation_v2 not available",
363                ));
364            }
365        };
366        let name = CString::new(name)
367            .map_err(|_| Error::with_message(ErrorCode::Misuse, "collation name contains NUL"))?;
368        let rc = create(db.as_ptr(), name.as_ptr(), enc, context, cmp, destroy);
369        if rc == SQLITE_OK {
370            Ok(())
371        } else {
372            Err(self.error_from_rc(rc, Some(db)))
373        }
374    }
375
376    unsafe fn aggregate_context(&self, ctx: NonNull<Self::Context>, bytes: usize) -> *mut c_void {
377        (self.fns.aggregate_context)(ctx.as_ptr(), clamp_len(bytes) as c_int)
378    }
379
380    unsafe fn result_null(&self, ctx: NonNull<Self::Context>) {
381        (self.fns.result_null)(ctx.as_ptr());
382    }
383
384    unsafe fn result_int64(&self, ctx: NonNull<Self::Context>, v: i64) {
385        (self.fns.result_int64)(ctx.as_ptr(), v);
386    }
387
388    unsafe fn result_double(&self, ctx: NonNull<Self::Context>, v: f64) {
389        (self.fns.result_double)(ctx.as_ptr(), v);
390    }
391
392    unsafe fn result_text(&self, ctx: NonNull<Self::Context>, v: &str) {
393        unsafe { self.result_text_bytes(ctx, v.as_bytes()) }
394    }
395
396    unsafe fn result_text_bytes(&self, ctx: NonNull<Self::Context>, v: &[u8]) {
397        match self.alloc_copy(v) {
398            Ok((ptr, dtor)) => {
399                (self.fns.result_text)(
400                    ctx.as_ptr(),
401                    ptr as *const c_char,
402                    clamp_len(v.len()),
403                    dtor,
404                );
405            }
406            Err(_) => {
407                const OOM: &str = "out of memory";
408                (self.fns.result_error)(
409                    ctx.as_ptr(),
410                    OOM.as_ptr() as *const c_char,
411                    clamp_len(OOM.len()),
412                );
413            }
414        }
415    }
416
417    unsafe fn result_blob(&self, ctx: NonNull<Self::Context>, v: &[u8]) {
418        match self.alloc_copy(v) {
419            Ok((ptr, dtor)) => {
420                (self.fns.result_blob)(ctx.as_ptr(), ptr, clamp_len(v.len()), dtor);
421            }
422            Err(_) => {
423                const OOM: &str = "out of memory";
424                (self.fns.result_error)(
425                    ctx.as_ptr(),
426                    OOM.as_ptr() as *const c_char,
427                    clamp_len(OOM.len()),
428                );
429            }
430        }
431    }
432
433    unsafe fn result_error(&self, ctx: NonNull<Self::Context>, msg: &str) {
434        (self.fns.result_error)(
435            ctx.as_ptr(),
436            msg.as_ptr() as *const c_char,
437            clamp_len(msg.len()),
438        );
439    }
440
441    unsafe fn user_data(ctx: NonNull<Self::Context>) -> *mut c_void {
442        match USER_DATA_FN.get() {
443            Some(f) => f(ctx.as_ptr()),
444            None => null_mut(),
445        }
446    }
447
448    unsafe fn value_type(&self, v: NonNull<Self::Value>) -> ValueType {
449        ValueType::from_code((self.fns.value_type)(v.as_ptr()))
450    }
451
452    unsafe fn value_int64(&self, v: NonNull<Self::Value>) -> i64 {
453        (self.fns.value_int64)(v.as_ptr())
454    }
455
456    unsafe fn value_double(&self, v: NonNull<Self::Value>) -> f64 {
457        (self.fns.value_double)(v.as_ptr())
458    }
459
460    unsafe fn value_text(&self, v: NonNull<Self::Value>) -> RawBytes {
461        let ptr = (self.fns.value_text)(v.as_ptr());
462        if ptr.is_null() {
463            return RawBytes::empty();
464        }
465        let len = (self.fns.value_bytes)(v.as_ptr());
466        RawBytes {
467            ptr,
468            len: len as usize,
469        }
470    }
471
472    unsafe fn value_blob(&self, v: NonNull<Self::Value>) -> RawBytes {
473        let ptr = (self.fns.value_blob)(v.as_ptr()) as *const u8;
474        if ptr.is_null() {
475            return RawBytes::empty();
476        }
477        let len = (self.fns.value_bytes)(v.as_ptr());
478        RawBytes {
479            ptr,
480            len: len as usize,
481        }
482    }
483
484    unsafe fn declare_vtab(&self, db: NonNull<Self::Db>, schema: &str) -> Result<()> {
485        let schema = CString::new(schema)
486            .map_err(|_| Error::with_message(ErrorCode::Misuse, "schema contains NUL"))?;
487        let rc = (self.fns.declare_vtab)(db.as_ptr(), schema.as_ptr());
488        if rc == SQLITE_OK {
489            Ok(())
490        } else {
491            Err(self.error_from_rc(rc, Some(db)))
492        }
493    }
494
495    unsafe fn create_module_v2(
496        &self,
497        db: NonNull<Self::Db>,
498        name: &str,
499        module: &'static sqlite_provider::sqlite3_module<Self>,
500        user_data: *mut c_void,
501        drop_user_data: Option<extern "C" fn(*mut c_void)>,
502    ) -> Result<()> {
503        let create = match self.fns.create_module_v2 {
504            Some(create) => create,
505            None => return Err(Error::feature_unavailable("create_module_v2 not available")),
506        };
507        let name = CString::new(name)
508            .map_err(|_| Error::with_message(ErrorCode::Misuse, "module name contains NUL"))?;
509        let rc = create(
510            db.as_ptr(),
511            name.as_ptr(),
512            module as *const _ as *const c_void,
513            user_data,
514            drop_user_data,
515        );
516        if rc == SQLITE_OK {
517            Ok(())
518        } else {
519            Err(self.error_from_rc(rc, Some(db)))
520        }
521    }
522}
523
#[cfg(test)]
mod tests {
    use super::{drop_registration_user_data, registration_name_or_drop};
    use sqlite_provider::ErrorCode;
    use std::ffi::c_void;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::{Mutex, MutexGuard};

    /// How many times `drop_counting_box` has run since the last reset.
    static DROP_CALLS: AtomicUsize = AtomicUsize::new(0);
    /// Serialises the tests below, since they all share `DROP_CALLS`.
    static DROP_TEST_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires `DROP_TEST_LOCK`, recovering the guard if an earlier test panicked while holding it.
    fn serialize_drop_tests() -> MutexGuard<'static, ()> {
        DROP_TEST_LOCK
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Heap-allocates `value` and leaks it as an opaque user-data pointer.
    fn boxed_user_data(value: usize) -> *mut c_void {
        Box::into_raw(Box::new(value)).cast::<c_void>()
    }

    fn drop_count() -> usize {
        DROP_CALLS.load(Ordering::SeqCst)
    }

    /// Destructor that records each call and frees a pointer made by `boxed_user_data`.
    extern "C" fn drop_counting_box(ptr: *mut c_void) {
        DROP_CALLS.fetch_add(1, Ordering::SeqCst);
        if ptr.is_null() {
            return;
        }
        // SAFETY: every non-null pointer passed here came from `Box::into_raw(Box<usize>)`.
        drop(unsafe { Box::from_raw(ptr.cast::<usize>()) });
    }

    #[test]
    fn registration_name_or_drop_invokes_drop_on_interior_nul() {
        let _guard = serialize_drop_tests();
        DROP_CALLS.store(0, Ordering::SeqCst);
        let outcome = registration_name_or_drop(
            "bad\0name",
            boxed_user_data(7),
            Some(drop_counting_box),
            "function name contains NUL",
        );
        let err = outcome.expect_err("interior NUL should fail");
        assert_eq!(err.code, ErrorCode::Misuse);
        assert_eq!(drop_count(), 1);
    }

    #[test]
    fn registration_name_or_drop_keeps_user_data_on_success() {
        let _guard = serialize_drop_tests();
        DROP_CALLS.store(0, Ordering::SeqCst);
        let user_data = boxed_user_data(9);
        let c_name = registration_name_or_drop(
            "ok_name",
            user_data,
            Some(drop_counting_box),
            "function name contains NUL",
        )
        .expect("valid name should pass");
        assert_eq!(c_name.to_str().unwrap(), "ok_name");
        assert_eq!(drop_count(), 0);
        // Success leaves the payload with us, so release it explicitly.
        drop_registration_user_data(user_data, Some(drop_counting_box));
        assert_eq!(drop_count(), 1);
    }

    #[test]
    fn drop_registration_user_data_is_noop_without_callback() {
        let _guard = serialize_drop_tests();
        drop_registration_user_data(std::ptr::null_mut(), None);
    }
}