1use super::*;
2
/// Releases registration `user_data` by calling `drop_user_data` on it.
///
/// Intended for failure paths where a registration is abandoned before the
/// pointer was handed over to SQLite. When no destructor was supplied there is
/// nothing to release and this does nothing.
fn drop_registration_user_data(
    user_data: *mut c_void,
    drop_user_data: Option<extern "C" fn(*mut c_void)>,
) {
    let Some(destructor) = drop_user_data else {
        return;
    };
    destructor(user_data);
}
11
12fn registration_name_or_drop(
13 name: &str,
14 user_data: *mut c_void,
15 drop_user_data: Option<extern "C" fn(*mut c_void)>,
16 error_message: &'static str,
17) -> Result<CString> {
18 match CString::new(name) {
19 Ok(name) => Ok(name),
20 Err(_) => {
21 drop_registration_user_data(user_data, drop_user_data);
22 Err(Error::with_message(ErrorCode::Misuse, error_message))
23 }
24 }
25}
26
27#[allow(unsafe_op_in_unsafe_fn)]
28unsafe impl Sqlite3Api for LibSqlite3 {
29 type Db = sqlite3;
30 type Stmt = sqlite3_stmt;
31 type Value = sqlite3_value;
32 type Context = sqlite3_context;
33 type VTab = c_void;
34 type VTabCursor = c_void;
35
36 fn api_version(&self) -> ApiVersion {
37 self.api_version
38 }
39
40 fn feature_set(&self) -> FeatureSet {
41 self.features
42 }
43
44 fn backend_name(&self) -> &'static str {
45 "libsqlite3"
46 }
47
48 fn backend_version(&self) -> Option<ApiVersion> {
49 Some(self.api_version)
50 }
51
52 unsafe fn malloc(&self, size: usize) -> *mut c_void {
53 if size > i32::MAX as usize {
54 return null_mut();
55 }
56 (self.fns.malloc)(size as c_int)
57 }
58
59 unsafe fn free(&self, ptr: *mut c_void) {
60 (self.fns.free)(ptr);
61 }
62
63 fn threadsafe(&self) -> i32 {
64 self.fns.threadsafe.map(|f| unsafe { f() }).unwrap_or(0)
65 }
66
67 unsafe fn open(&self, filename: &str, options: OpenOptions<'_>) -> Result<NonNull<Self::Db>> {
68 let filename = CString::new(filename)
69 .map_err(|_| Error::with_message(ErrorCode::Misuse, "filename contains NUL"))?;
70 let vfs = match options.vfs {
71 Some(vfs) => Some(
72 CString::new(vfs)
73 .map_err(|_| Error::with_message(ErrorCode::Misuse, "vfs contains NUL"))?,
74 ),
75 None => None,
76 };
77 let mut db = null_mut();
78 let flags = map_open_flags(options.flags);
79 let vfs_ptr = vfs.as_ref().map(|s| s.as_ptr()).unwrap_or(null());
80 let rc = (self.fns.open_v2)(filename.as_ptr(), &mut db, flags, vfs_ptr);
81 if rc != SQLITE_OK {
82 let err = self.error_from_rc(rc, NonNull::new(db));
83 if !db.is_null() {
84 let _ = (self.fns.close)(db);
85 }
86 return Err(err);
87 }
88 Ok(NonNull::new_unchecked(db))
89 }
90
91 unsafe fn close(&self, db: NonNull<Self::Db>) -> Result<()> {
92 let rc = (self.fns.close)(db.as_ptr());
93 if rc == SQLITE_OK {
94 Ok(())
95 } else {
96 Err(self.error_from_rc(rc, Some(db)))
97 }
98 }
99
100 unsafe fn prepare_v2(&self, db: NonNull<Self::Db>, sql: &str) -> Result<NonNull<Self::Stmt>> {
101 let mut stmt = null_mut();
102 let sql_ptr = sql.as_ptr() as *const c_char;
103 let sql_len = clamp_len(sql.len());
104 let rc = (self.fns.prepare_v2)(db.as_ptr(), sql_ptr, sql_len, &mut stmt, null_mut());
105 if rc != SQLITE_OK {
106 return Err(self.error_from_rc(rc, Some(db)));
107 }
108 Ok(NonNull::new_unchecked(stmt))
109 }
110
111 unsafe fn prepare_v3(
112 &self,
113 db: NonNull<Self::Db>,
114 sql: &str,
115 flags: u32,
116 ) -> Result<NonNull<Self::Stmt>> {
117 let prepare = match self.fns.prepare_v3 {
118 Some(prepare) => prepare,
119 None => return Err(Error::feature_unavailable("prepare_v3 not available")),
120 };
121 let mut stmt = null_mut();
122 let sql_ptr = sql.as_ptr() as *const c_char;
123 let sql_len = clamp_len(sql.len());
124 let rc = prepare(db.as_ptr(), sql_ptr, sql_len, flags, &mut stmt, null_mut());
125 if rc != SQLITE_OK {
126 return Err(self.error_from_rc(rc, Some(db)));
127 }
128 Ok(NonNull::new_unchecked(stmt))
129 }
130
131 unsafe fn step(&self, stmt: NonNull<Self::Stmt>) -> Result<StepResult> {
132 match (self.fns.step)(stmt.as_ptr()) {
133 SQLITE_ROW => Ok(StepResult::Row),
134 SQLITE_DONE => Ok(StepResult::Done),
135 rc => Err(self.error_from_rc(rc, None)),
136 }
137 }
138
139 unsafe fn reset(&self, stmt: NonNull<Self::Stmt>) -> Result<()> {
140 let rc = (self.fns.reset)(stmt.as_ptr());
141 if rc == SQLITE_OK {
142 Ok(())
143 } else {
144 Err(self.error_from_rc(rc, None))
145 }
146 }
147
148 unsafe fn finalize(&self, stmt: NonNull<Self::Stmt>) -> Result<()> {
149 let rc = (self.fns.finalize)(stmt.as_ptr());
150 if rc == SQLITE_OK {
151 Ok(())
152 } else {
153 Err(self.error_from_rc(rc, None))
154 }
155 }
156
157 unsafe fn bind_null(&self, stmt: NonNull<Self::Stmt>, idx: i32) -> Result<()> {
158 let rc = (self.fns.bind_null)(stmt.as_ptr(), idx);
159 if rc == SQLITE_OK {
160 Ok(())
161 } else {
162 Err(self.error_from_rc(rc, None))
163 }
164 }
165
166 unsafe fn bind_int64(&self, stmt: NonNull<Self::Stmt>, idx: i32, v: i64) -> Result<()> {
167 let rc = (self.fns.bind_int64)(stmt.as_ptr(), idx, v);
168 if rc == SQLITE_OK {
169 Ok(())
170 } else {
171 Err(self.error_from_rc(rc, None))
172 }
173 }
174
175 unsafe fn bind_double(&self, stmt: NonNull<Self::Stmt>, idx: i32, v: f64) -> Result<()> {
176 let rc = (self.fns.bind_double)(stmt.as_ptr(), idx, v);
177 if rc == SQLITE_OK {
178 Ok(())
179 } else {
180 Err(self.error_from_rc(rc, None))
181 }
182 }
183
184 unsafe fn bind_text(&self, stmt: NonNull<Self::Stmt>, idx: i32, v: &str) -> Result<()> {
185 unsafe { self.bind_text_bytes(stmt, idx, v.as_bytes()) }
186 }
187
188 unsafe fn bind_text_bytes(&self, stmt: NonNull<Self::Stmt>, idx: i32, v: &[u8]) -> Result<()> {
189 let (ptr, dtor) = self.alloc_copy(v)?;
190 let rc = (self.fns.bind_text)(
191 stmt.as_ptr(),
192 idx,
193 ptr as *const c_char,
194 clamp_len(v.len()),
195 dtor,
196 );
197 if rc != SQLITE_OK {
198 return Err(self.error_from_rc(rc, None));
199 }
200 Ok(())
201 }
202
203 unsafe fn bind_blob(&self, stmt: NonNull<Self::Stmt>, idx: i32, v: &[u8]) -> Result<()> {
204 let (ptr, dtor) = self.alloc_copy(v)?;
205 let rc = (self.fns.bind_blob)(stmt.as_ptr(), idx, ptr, clamp_len(v.len()), dtor);
206 if rc != SQLITE_OK {
207 return Err(self.error_from_rc(rc, None));
208 }
209 Ok(())
210 }
211
212 unsafe fn column_count(&self, stmt: NonNull<Self::Stmt>) -> i32 {
213 (self.fns.column_count)(stmt.as_ptr())
214 }
215
216 unsafe fn column_type(&self, stmt: NonNull<Self::Stmt>, col: i32) -> ValueType {
217 ValueType::from_code((self.fns.column_type)(stmt.as_ptr(), col))
218 }
219
220 unsafe fn column_int64(&self, stmt: NonNull<Self::Stmt>, col: i32) -> i64 {
221 (self.fns.column_int64)(stmt.as_ptr(), col)
222 }
223
224 unsafe fn column_double(&self, stmt: NonNull<Self::Stmt>, col: i32) -> f64 {
225 (self.fns.column_double)(stmt.as_ptr(), col)
226 }
227
228 unsafe fn column_text(&self, stmt: NonNull<Self::Stmt>, col: i32) -> RawBytes {
229 let ptr = (self.fns.column_text)(stmt.as_ptr(), col);
230 if ptr.is_null() {
231 return RawBytes::empty();
232 }
233 let len = (self.fns.column_bytes)(stmt.as_ptr(), col);
234 RawBytes {
235 ptr,
236 len: len as usize,
237 }
238 }
239
240 unsafe fn column_blob(&self, stmt: NonNull<Self::Stmt>, col: i32) -> RawBytes {
241 let ptr = (self.fns.column_blob)(stmt.as_ptr(), col) as *const u8;
242 if ptr.is_null() {
243 return RawBytes::empty();
244 }
245 let len = (self.fns.column_bytes)(stmt.as_ptr(), col);
246 RawBytes {
247 ptr,
248 len: len as usize,
249 }
250 }
251
252 unsafe fn errcode(&self, db: NonNull<Self::Db>) -> i32 {
253 (self.fns.errcode)(db.as_ptr())
254 }
255
256 unsafe fn errmsg(&self, db: NonNull<Self::Db>) -> *const c_char {
257 (self.fns.errmsg)(db.as_ptr())
258 }
259
260 unsafe fn extended_errcode(&self, db: NonNull<Self::Db>) -> Option<i32> {
261 self.fns.extended_errcode.map(|f| f(db.as_ptr()))
262 }
263
264 unsafe fn create_function_v2(
265 &self,
266 db: NonNull<Self::Db>,
267 name: &str,
268 n_args: i32,
269 flags: FunctionFlags,
270 x_func: Option<extern "C" fn(*mut Self::Context, i32, *mut *mut Self::Value)>,
271 x_step: Option<extern "C" fn(*mut Self::Context, i32, *mut *mut Self::Value)>,
272 x_final: Option<extern "C" fn(*mut Self::Context)>,
273 user_data: *mut c_void,
274 drop_user_data: Option<extern "C" fn(*mut c_void)>,
275 ) -> Result<()> {
276 let name = registration_name_or_drop(
277 name,
278 user_data,
279 drop_user_data,
280 "function name contains NUL",
281 )?;
282 let flags = map_function_flags(flags);
283 let rc = (self.fns.create_function_v2)(
284 db.as_ptr(),
285 name.as_ptr(),
286 n_args,
287 flags,
288 user_data,
289 x_func,
290 x_step,
291 x_final,
292 drop_user_data,
293 );
294 if rc == SQLITE_OK {
295 Ok(())
296 } else {
297 Err(self.error_from_rc(rc, Some(db)))
298 }
299 }
300
301 unsafe fn create_window_function(
302 &self,
303 db: NonNull<Self::Db>,
304 name: &str,
305 n_args: i32,
306 flags: FunctionFlags,
307 x_step: Option<extern "C" fn(*mut Self::Context, i32, *mut *mut Self::Value)>,
308 x_final: Option<extern "C" fn(*mut Self::Context)>,
309 x_value: Option<extern "C" fn(*mut Self::Context)>,
310 x_inverse: Option<extern "C" fn(*mut Self::Context, i32, *mut *mut Self::Value)>,
311 user_data: *mut c_void,
312 drop_user_data: Option<extern "C" fn(*mut c_void)>,
313 ) -> Result<()> {
314 let create = match self.fns.create_window_function {
315 Some(create) => create,
316 None => {
317 drop_registration_user_data(user_data, drop_user_data);
318 return Err(Error::feature_unavailable(
319 "create_window_function not available",
320 ));
321 }
322 };
323 let name = registration_name_or_drop(
324 name,
325 user_data,
326 drop_user_data,
327 "function name contains NUL",
328 )?;
329 let flags = map_function_flags(flags);
330 let rc = create(
331 db.as_ptr(),
332 name.as_ptr(),
333 n_args,
334 flags,
335 user_data,
336 x_step,
337 x_final,
338 x_value,
339 x_inverse,
340 drop_user_data,
341 );
342 if rc == SQLITE_OK {
343 Ok(())
344 } else {
345 Err(self.error_from_rc(rc, Some(db)))
346 }
347 }
348
349 unsafe fn create_collation_v2(
350 &self,
351 db: NonNull<Self::Db>,
352 name: &str,
353 enc: i32,
354 context: *mut c_void,
355 cmp: Option<extern "C" fn(*mut c_void, i32, *const c_void, i32, *const c_void) -> i32>,
356 destroy: Option<extern "C" fn(*mut c_void)>,
357 ) -> Result<()> {
358 let create = match self.fns.create_collation_v2 {
359 Some(create) => create,
360 None => {
361 return Err(Error::feature_unavailable(
362 "create_collation_v2 not available",
363 ));
364 }
365 };
366 let name = CString::new(name)
367 .map_err(|_| Error::with_message(ErrorCode::Misuse, "collation name contains NUL"))?;
368 let rc = create(db.as_ptr(), name.as_ptr(), enc, context, cmp, destroy);
369 if rc == SQLITE_OK {
370 Ok(())
371 } else {
372 Err(self.error_from_rc(rc, Some(db)))
373 }
374 }
375
376 unsafe fn aggregate_context(&self, ctx: NonNull<Self::Context>, bytes: usize) -> *mut c_void {
377 (self.fns.aggregate_context)(ctx.as_ptr(), clamp_len(bytes) as c_int)
378 }
379
380 unsafe fn result_null(&self, ctx: NonNull<Self::Context>) {
381 (self.fns.result_null)(ctx.as_ptr());
382 }
383
384 unsafe fn result_int64(&self, ctx: NonNull<Self::Context>, v: i64) {
385 (self.fns.result_int64)(ctx.as_ptr(), v);
386 }
387
388 unsafe fn result_double(&self, ctx: NonNull<Self::Context>, v: f64) {
389 (self.fns.result_double)(ctx.as_ptr(), v);
390 }
391
392 unsafe fn result_text(&self, ctx: NonNull<Self::Context>, v: &str) {
393 unsafe { self.result_text_bytes(ctx, v.as_bytes()) }
394 }
395
396 unsafe fn result_text_bytes(&self, ctx: NonNull<Self::Context>, v: &[u8]) {
397 match self.alloc_copy(v) {
398 Ok((ptr, dtor)) => {
399 (self.fns.result_text)(
400 ctx.as_ptr(),
401 ptr as *const c_char,
402 clamp_len(v.len()),
403 dtor,
404 );
405 }
406 Err(_) => {
407 const OOM: &str = "out of memory";
408 (self.fns.result_error)(
409 ctx.as_ptr(),
410 OOM.as_ptr() as *const c_char,
411 clamp_len(OOM.len()),
412 );
413 }
414 }
415 }
416
417 unsafe fn result_blob(&self, ctx: NonNull<Self::Context>, v: &[u8]) {
418 match self.alloc_copy(v) {
419 Ok((ptr, dtor)) => {
420 (self.fns.result_blob)(ctx.as_ptr(), ptr, clamp_len(v.len()), dtor);
421 }
422 Err(_) => {
423 const OOM: &str = "out of memory";
424 (self.fns.result_error)(
425 ctx.as_ptr(),
426 OOM.as_ptr() as *const c_char,
427 clamp_len(OOM.len()),
428 );
429 }
430 }
431 }
432
433 unsafe fn result_error(&self, ctx: NonNull<Self::Context>, msg: &str) {
434 (self.fns.result_error)(
435 ctx.as_ptr(),
436 msg.as_ptr() as *const c_char,
437 clamp_len(msg.len()),
438 );
439 }
440
441 unsafe fn user_data(ctx: NonNull<Self::Context>) -> *mut c_void {
442 match USER_DATA_FN.get() {
443 Some(f) => f(ctx.as_ptr()),
444 None => null_mut(),
445 }
446 }
447
448 unsafe fn value_type(&self, v: NonNull<Self::Value>) -> ValueType {
449 ValueType::from_code((self.fns.value_type)(v.as_ptr()))
450 }
451
452 unsafe fn value_int64(&self, v: NonNull<Self::Value>) -> i64 {
453 (self.fns.value_int64)(v.as_ptr())
454 }
455
456 unsafe fn value_double(&self, v: NonNull<Self::Value>) -> f64 {
457 (self.fns.value_double)(v.as_ptr())
458 }
459
460 unsafe fn value_text(&self, v: NonNull<Self::Value>) -> RawBytes {
461 let ptr = (self.fns.value_text)(v.as_ptr());
462 if ptr.is_null() {
463 return RawBytes::empty();
464 }
465 let len = (self.fns.value_bytes)(v.as_ptr());
466 RawBytes {
467 ptr,
468 len: len as usize,
469 }
470 }
471
472 unsafe fn value_blob(&self, v: NonNull<Self::Value>) -> RawBytes {
473 let ptr = (self.fns.value_blob)(v.as_ptr()) as *const u8;
474 if ptr.is_null() {
475 return RawBytes::empty();
476 }
477 let len = (self.fns.value_bytes)(v.as_ptr());
478 RawBytes {
479 ptr,
480 len: len as usize,
481 }
482 }
483
484 unsafe fn declare_vtab(&self, db: NonNull<Self::Db>, schema: &str) -> Result<()> {
485 let schema = CString::new(schema)
486 .map_err(|_| Error::with_message(ErrorCode::Misuse, "schema contains NUL"))?;
487 let rc = (self.fns.declare_vtab)(db.as_ptr(), schema.as_ptr());
488 if rc == SQLITE_OK {
489 Ok(())
490 } else {
491 Err(self.error_from_rc(rc, Some(db)))
492 }
493 }
494
495 unsafe fn create_module_v2(
496 &self,
497 db: NonNull<Self::Db>,
498 name: &str,
499 module: &'static sqlite_provider::sqlite3_module<Self>,
500 user_data: *mut c_void,
501 drop_user_data: Option<extern "C" fn(*mut c_void)>,
502 ) -> Result<()> {
503 let create = match self.fns.create_module_v2 {
504 Some(create) => create,
505 None => return Err(Error::feature_unavailable("create_module_v2 not available")),
506 };
507 let name = CString::new(name)
508 .map_err(|_| Error::with_message(ErrorCode::Misuse, "module name contains NUL"))?;
509 let rc = create(
510 db.as_ptr(),
511 name.as_ptr(),
512 module as *const _ as *const c_void,
513 user_data,
514 drop_user_data,
515 );
516 if rc == SQLITE_OK {
517 Ok(())
518 } else {
519 Err(self.error_from_rc(rc, Some(db)))
520 }
521 }
522}
523
#[cfg(test)]
mod tests {
    use super::{drop_registration_user_data, registration_name_or_drop};
    use sqlite_provider::ErrorCode;
    use std::ffi::c_void;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::{Mutex, MutexGuard};

    /// Number of times `drop_counting_box` has run since the last reset.
    static DROP_CALLS: AtomicUsize = AtomicUsize::new(0);

    /// Serializes the tests in this module, which all share `DROP_CALLS`.
    static DROP_TEST_LOCK: Mutex<()> = Mutex::new(());

    /// Takes `DROP_TEST_LOCK`, recovering the guard if an earlier test
    /// panicked while holding it.
    fn exclusive() -> MutexGuard<'static, ()> {
        DROP_TEST_LOCK
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Heap-allocates `value` and leaks it as an opaque user-data pointer
    /// that `drop_counting_box` frees.
    fn leak_user_data(value: usize) -> *mut c_void {
        Box::into_raw(Box::new(value)).cast::<c_void>()
    }

    /// Destructor handed to the code under test: counts the call and frees
    /// the `Box<usize>` behind a non-null pointer.
    extern "C" fn drop_counting_box(ptr: *mut c_void) {
        DROP_CALLS.fetch_add(1, Ordering::SeqCst);
        if ptr.is_null() {
            return;
        }
        // SAFETY: every non-null pointer passed here comes from
        // `leak_user_data` (`Box::<usize>::into_raw`) and is freed exactly once.
        drop(unsafe { Box::from_raw(ptr.cast::<usize>()) });
    }

    #[test]
    fn registration_name_or_drop_invokes_drop_on_interior_nul() {
        let _lock = exclusive();
        DROP_CALLS.store(0, Ordering::SeqCst);

        let err = registration_name_or_drop(
            "bad\0name",
            leak_user_data(7),
            Some(drop_counting_box),
            "function name contains NUL",
        )
        .expect_err("interior NUL should fail");

        assert_eq!(err.code, ErrorCode::Misuse);
        assert_eq!(DROP_CALLS.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn registration_name_or_drop_keeps_user_data_on_success() {
        let _lock = exclusive();
        DROP_CALLS.store(0, Ordering::SeqCst);
        let user_data = leak_user_data(9);

        let name = registration_name_or_drop(
            "ok_name",
            user_data,
            Some(drop_counting_box),
            "function name contains NUL",
        )
        .expect("valid name should pass");
        assert_eq!(name.to_str().unwrap(), "ok_name");
        assert_eq!(DROP_CALLS.load(Ordering::SeqCst), 0);

        // The caller still owns the data on success, so release it here.
        drop_registration_user_data(user_data, Some(drop_counting_box));
        assert_eq!(DROP_CALLS.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn drop_registration_user_data_is_noop_without_callback() {
        let _lock = exclusive();
        drop_registration_user_data(std::ptr::null_mut(), None);
    }
}