sqlx_sqlite/connection/collation.rs

1use std::cmp::Ordering;
2use std::ffi::CString;
3use std::fmt::{self, Debug, Formatter};
4use std::os::raw::{c_int, c_void};
5use std::slice;
6use std::str::from_utf8_unchecked;
7use std::sync::Arc;
8
9use libsqlite3_sys::{sqlite3_create_collation_v2, SQLITE_OK, SQLITE_UTF8};
10
11use crate::connection::handle::ConnectionHandle;
12use crate::error::Error;
13
/// A named, user-defined SQLite collation backed by a Rust comparison closure.
///
/// Cloning is cheap: both the name and the closure are reference-counted, so clones
/// share the same closure allocation.
#[derive(Clone)]
pub struct Collation {
    /// Name under which the collation is registered with SQLite.
    name: Arc<str>,
    /// The comparison closure; a raw pointer to it is what SQLite receives as user data.
    #[allow(clippy::type_complexity)] // the trait-object type is spelled out inline on purpose
    collate: Arc<dyn Fn(&str, &str) -> Ordering + Send + Sync + 'static>,
    // SAFETY: `call` and `free` reinterpret the type-erased user-data pointer, so both must
    // be instantiated for the same concrete closure type that `collate` was built from.
    /// Comparison callback (`xCompare`) handed to SQLite.
    call: unsafe extern "C" fn(
        data: *mut c_void,
        left_len: c_int,
        left_ptr: *const c_void,
        right_len: c_int,
        right_ptr: *const c_void,
    ) -> c_int,
    /// Destroy callback (`xDestroy`) handed to SQLite for the user-data pointer.
    free: unsafe extern "C" fn(*mut c_void),
}
29
30impl Collation {
31    pub fn new<N, F>(name: N, collate: F) -> Self
32    where
33        N: Into<Arc<str>>,
34        F: Fn(&str, &str) -> Ordering + Send + Sync + 'static,
35    {
36        unsafe extern "C" fn drop_arc_value<T>(p: *mut c_void) {
37            drop(Arc::from_raw(p as *mut T));
38        }
39
40        Collation {
41            name: name.into(),
42            collate: Arc::new(collate),
43            call: call_boxed_closure::<F>,
44            free: drop_arc_value::<F>,
45        }
46    }
47
48    pub(crate) fn create(&self, handle: &mut ConnectionHandle) -> Result<(), Error> {
49        let raw_f = Arc::into_raw(Arc::clone(&self.collate));
50        let c_name = CString::new(&*self.name)
51            .map_err(|_| err_protocol!("invalid collation name: {:?}", self.name))?;
52        let flags = SQLITE_UTF8;
53        let r = unsafe {
54            sqlite3_create_collation_v2(
55                handle.as_ptr(),
56                c_name.as_ptr(),
57                flags,
58                raw_f as *mut c_void,
59                Some(self.call),
60                Some(self.free),
61            )
62        };
63
64        if r == SQLITE_OK {
65            Ok(())
66        } else {
67            // The xDestroy callback is not called if the sqlite3_create_collation_v2() function fails.
68            drop(unsafe { Arc::from_raw(raw_f) });
69            Err(handle.expect_error().into())
70        }
71    }
72}
73
74impl Debug for Collation {
75    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
76        f.debug_struct("Collation")
77            .field("name", &self.name)
78            .finish_non_exhaustive()
79    }
80}
81
82pub(crate) fn create_collation<F>(
83    handle: &mut ConnectionHandle,
84    name: &str,
85    compare: F,
86) -> Result<(), Error>
87where
88    F: Fn(&str, &str) -> Ordering + Send + Sync + 'static,
89{
90    unsafe extern "C" fn free_boxed_value<T>(p: *mut c_void) {
91        drop(Box::from_raw(p as *mut T));
92    }
93
94    let boxed_f: *mut F = Box::into_raw(Box::new(compare));
95    let c_name =
96        CString::new(name).map_err(|_| err_protocol!("invalid collation name: {}", name))?;
97    let flags = SQLITE_UTF8;
98    let r = unsafe {
99        sqlite3_create_collation_v2(
100            handle.as_ptr(),
101            c_name.as_ptr(),
102            flags,
103            boxed_f as *mut c_void,
104            Some(call_boxed_closure::<F>),
105            Some(free_boxed_value::<F>),
106        )
107    };
108
109    if r == SQLITE_OK {
110        Ok(())
111    } else {
112        // The xDestroy callback is not called if the sqlite3_create_collation_v2() function fails.
113        drop(unsafe { Box::from_raw(boxed_f) });
114        Err(handle.expect_error().into())
115    }
116}
117
/// `xCompare` trampoline: forwards one SQLite collation request to the Rust closure `C`
/// behind `data` and converts its [`Ordering`] to the `-1` / `0` / `1` SQLite expects.
///
/// SQLite does not guarantee that TEXT values are well-formed UTF-8 (for example
/// `CAST(x'ff' AS TEXT)`). The bytes are therefore decoded with
/// [`String::from_utf8_lossy`]: invalid sequences reach the closure as U+FFFD instead of
/// forming an invalid `&str`, which would be undefined behavior. Valid input is borrowed
/// without allocating.
///
/// # Safety
///
/// `data` must point to a live `C`. Each `(len, ptr)` pair must describe `len` readable
/// bytes that stay valid for the duration of the call. `ptr` may be null when `len` is 0.
///
/// # Panics
///
/// Panics if `data` is null or either length is negative. Because this is an
/// `extern "C"` function, such a panic must not unwind into SQLite (see the nomicon link
/// below).
unsafe extern "C" fn call_boxed_closure<C>(
    data: *mut c_void,
    left_len: c_int,
    left_ptr: *const c_void,
    right_len: c_int,
    right_ptr: *const c_void,
) -> c_int
where
    C: Fn(&str, &str) -> Ordering,
{
    /// Views one collation argument as bytes; a zero-length argument never touches `ptr`.
    unsafe fn arg_bytes<'a>(ptr: *const c_void, len: usize) -> &'a [u8] {
        if len == 0 {
            // `slice::from_raw_parts` requires a non-null pointer even for empty slices.
            &[]
        } else {
            // SAFETY: the caller guarantees `ptr` points to `len` readable bytes.
            slice::from_raw_parts(ptr.cast::<u8>(), len)
        }
    }

    let f = data as *const C;

    // Note: a panic cannot unwind out of this `extern "C"` function:
    // https://doc.rust-lang.org/nomicon/ffi.html#ffi-and-unwinding
    assert!(!f.is_null());

    let left_len =
        usize::try_from(left_len).unwrap_or_else(|_| panic!("left_len out of range: {left_len}"));
    let right_len = usize::try_from(right_len)
        .unwrap_or_else(|_| panic!("right_len out of range: {right_len}"));

    // SAFETY: per this function's contract, each pointer covers `len` readable bytes.
    let s1 = String::from_utf8_lossy(arg_bytes(left_ptr, left_len));
    let s2 = String::from_utf8_lossy(arg_bytes(right_ptr, right_len));

    // SAFETY: `data` points to a live `C`, and calling an `Fn` needs only shared access.
    let t = (*f)(&*s1, &*s2);

    match t {
        Ordering::Less => -1,
        Ordering::Equal => 0,
        Ordering::Greater => 1,
    }
}
155}