sqlite_loadable/
scalar.rs

1//! Define scalar functions on sqlite3 database connections.
2
3#![allow(clippy::not_unsafe_ptr_arg_deref)]
4use std::{
5    ffi::CString,
6    os::raw::{c_int, c_void},
7    slice,
8};
9
10use crate::{
11    api,
12    constants::{SQLITE_INTERNAL, SQLITE_OKAY},
13    errors::{Error, ErrorKind, Result},
14    ext::sqlite3ext_create_function_v2,
15};
16use sqlite3ext_sys::{sqlite3, sqlite3_context, sqlite3_user_data, sqlite3_value};
17
18use bitflags::bitflags;
19
20use sqlite3ext_sys::{
21    SQLITE_DETERMINISTIC, SQLITE_DIRECTONLY, SQLITE_INNOCUOUS, SQLITE_SUBTYPE, SQLITE_UTF16,
22    SQLITE_UTF16BE, SQLITE_UTF16LE, SQLITE_UTF8,
23};
24
25bitflags! {
26    /// Represents the possible flag values that can be passed into sqlite3_create_function_v2
27    /// or sqlite3_create_window_function, as the 4th "eTextRep" parameter.
28    /// Includes both the encoding options (utf8, utf16, etc.) and function-level parameters
29    /// (deterministion, innocuous, etc.).
30    pub struct FunctionFlags: i32 {
31        const UTF8 = SQLITE_UTF8 as i32;
32        const UTF16LE = SQLITE_UTF16LE as i32;
33        const UTF16BE = SQLITE_UTF16BE as i32;
34        const UTF16 = SQLITE_UTF16 as i32;
35
36        /// "... to signal that the function will always return the same result given the same
37        /// inputs within a single SQL statement."
38        /// <https://www.sqlite.org/c3ref/create_function.html#:~:text=ORed%20with%20SQLITE_DETERMINISTIC>
39        const DETERMINISTIC = SQLITE_DETERMINISTIC as i32;
40        const DIRECTONLY = SQLITE_DIRECTONLY as i32;
41        const SUBTYPE = SQLITE_SUBTYPE as i32;
42        const INNOCUOUS = SQLITE_INNOCUOUS as i32;
43    }
44}
45
46/// Defines a new scalar function on the given database connection.
47///
48/// # Example
49/// ```rs
50/// fn xyz_version(context: *mut sqlite3_context, _values: &[*mut sqlite3_value]) -> Result<()> {
51///   context_result_text(context, &format!("v{}", env!("CARGO_PKG_VERSION")))?;
52///   Ok(())
53/// }
54///
55/// define_scalar_function(db, "xyz_version", 0, xyz_version)?;
56/// ```
57pub fn define_scalar_function<F>(
58    db: *mut sqlite3,
59    name: &str,
60    num_args: c_int,
61    x_func: F,
62    func_flags: FunctionFlags,
63) -> Result<()>
64where
65    // TODO - can we wrap the context arg with a safe/ergonomic struct?
66    // calling `context_result_text(context, "foo")` is long, but maybe
67    // `context.result_text("foo")` with a special wrapper struct can be
68    // as fast
69    F: Fn(*mut sqlite3_context, &[*mut sqlite3_value]) -> Result<()>,
70{
71    let function_pointer: *mut F = Box::into_raw(Box::new(x_func));
72
73    unsafe extern "C" fn x_func_wrapper<F>(
74        context: *mut sqlite3_context,
75        argc: c_int,
76        argv: *mut *mut sqlite3_value,
77    ) where
78        F: Fn(*mut sqlite3_context, &[*mut sqlite3_value]) -> Result<()>,
79    {
80        let boxed_function: *mut F = sqlite3_user_data(context).cast::<F>();
81        // .collect slows things waaaay down, so stick with slice for now
82        let args = slice::from_raw_parts(argv, argc as usize);
83        match (*boxed_function)(context, args) {
84            Ok(()) => (),
85            Err(e) => {
86                if api::result_error(context, &e.result_error_message()).is_err() {
87                    api::result_error_code(context, SQLITE_INTERNAL);
88                }
89            }
90        }
91    }
92    let cname = CString::new(name)?;
93    let result = unsafe {
94        sqlite3ext_create_function_v2(
95            db,
96            cname.as_ptr(),
97            num_args,
98            func_flags.bits,
99            function_pointer.cast::<c_void>(),
100            Some(x_func_wrapper::<F>),
101            None,
102            None,
103            None,
104        )
105    };
106
107    if result != SQLITE_OKAY {
108        Err(Error::new(ErrorKind::DefineScalarFunction(result)))
109    } else {
110        Ok(())
111    }
112}
113
114pub fn delete_scalar_function(
115    db: *mut sqlite3,
116    name: &str,
117    num_args: c_int,
118    func_flags: FunctionFlags,
119) -> Result<()> {
120    let cname = CString::new(name)?;
121    let result = unsafe {
122        sqlite3ext_create_function_v2(
123            db,
124            cname.as_ptr(),
125            num_args,
126            func_flags.bits,
127            std::ptr::null_mut(),
128            None,
129            None,
130            None,
131            None,
132        )
133    };
134
135    if result != SQLITE_OKAY {
136        println!("failed with {result}");
137        Err(Error::new(ErrorKind::DefineScalarFunction(result)))
138    } else {
139        Ok(())
140    }
141}
142
143/// Defines a new scalar function, but with the added ability to pass in an arbritary
144/// application "pointer" as any rust type. Can be accessed in the callback
145/// function as the 3rd argument, as a reference.
146/// <https://www.sqlite.org/c3ref/create_function.html#:~:text=The%20fifth%20parameter%20is%20an%20arbitrary%20pointer.>
147pub fn define_scalar_function_with_aux<F, T>(
148    db: *mut sqlite3,
149    name: &str,
150    num_args: c_int,
151    x_func: F,
152    func_flags: FunctionFlags,
153    aux: T,
154) -> Result<()>
155where
156    F: Fn(*mut sqlite3_context, &[*mut sqlite3_value], &T) -> Result<()>,
157{
158    let function_pointer: *mut F = Box::into_raw(Box::new(x_func));
159    let aux_pointer: *mut T = Box::into_raw(Box::new(aux));
160    let app_pointer = Box::into_raw(Box::new((function_pointer, aux_pointer)));
161
162    unsafe extern "C" fn x_func_wrapper<F, T>(
163        context: *mut sqlite3_context,
164        argc: c_int,
165        argv: *mut *mut sqlite3_value,
166    ) where
167        F: Fn(*mut sqlite3_context, &[*mut sqlite3_value], &T) -> Result<()>,
168    {
169        let x = sqlite3_user_data(context).cast::<(*mut F, *mut T)>();
170        let boxed_function = (*x).0;
171        let aux = (*x).1;
172        // .collect slows things waaaay down, so stick with slice for now
173        let args = slice::from_raw_parts(argv, argc as usize);
174        let b = Box::from_raw(aux);
175        match (*boxed_function)(context, args, &*b) {
176            Ok(()) => (),
177            Err(e) => {
178                if api::result_error(context, &e.result_error_message()).is_err() {
179                    api::result_error_code(context, SQLITE_INTERNAL);
180                }
181            }
182        }
183        Box::into_raw(b);
184    }
185    let cname = CString::new(name)?;
186
187    let result = unsafe {
188        sqlite3ext_create_function_v2(
189            db,
190            cname.as_ptr(),
191            num_args,
192            func_flags.bits,
193            app_pointer.cast::<c_void>(),
194            Some(x_func_wrapper::<F, T>),
195            None,
196            None,
197            None,
198        )
199    };
200
201    if result != SQLITE_OKAY {
202        Err(Error::new(ErrorKind::DefineScalarFunction(result)))
203    } else {
204        Ok(())
205    }
206}