// sqlx_core_oldapi/sqlite/connection/function.rs

1use std::ffi::{c_char, CString};
2use std::os::raw::{c_int, c_void};
3use std::sync::Arc;
4
5use libsqlite3_sys::{
6    sqlite3_context, sqlite3_create_function_v2, sqlite3_result_blob, sqlite3_result_double,
7    sqlite3_result_error, sqlite3_result_int, sqlite3_result_int64, sqlite3_result_null,
8    sqlite3_result_text, sqlite3_user_data, sqlite3_value, sqlite3_value_type,
9    SQLITE_DETERMINISTIC, SQLITE_DIRECTONLY, SQLITE_OK, SQLITE_TRANSIENT, SQLITE_UTF8,
10};
11
12use crate::decode::Decode;
13use crate::encode::{Encode, IsNull};
14use crate::error::{BoxDynError, Error};
15use crate::sqlite::type_info::DataType;
16use crate::sqlite::Sqlite;
17use crate::sqlite::SqliteArgumentValue;
18use crate::sqlite::SqliteTypeInfo;
19use crate::sqlite::SqliteValue;
20use crate::sqlite::{connection::handle::ConnectionHandle, SqliteError};
21use crate::value::Value;
22
23pub trait SqliteCallable: Send + Sync {
24    unsafe fn call_boxed_closure(
25        &self,
26        ctx: *mut sqlite3_context,
27        argc: c_int,
28        argv: *mut *mut sqlite3_value,
29    );
30    // number of arguments
31    fn arg_count(&self) -> i32;
32}
33
34pub struct SqliteFunctionCtx {
35    ctx: *mut sqlite3_context,
36    argument_values: Vec<SqliteValue>,
37}
38
39impl SqliteFunctionCtx {
40    /// Creates a new `SqliteFunctionCtx` from the given raw SQLite function context.
41    /// The context is used to access the arguments passed to the function.
42    /// Safety: the context must be valid and argc must be the number of arguments passed to the function.
43    unsafe fn new(ctx: *mut sqlite3_context, argc: c_int, argv: *mut *mut sqlite3_value) -> Self {
44        let count = usize::try_from(argc).expect("invalid argument count");
45        let argument_values = (0..count)
46            .map(|i| {
47                let raw = *argv.add(i);
48                let data_type_code = sqlite3_value_type(raw);
49                let value_type_info = SqliteTypeInfo(DataType::from_code(data_type_code));
50                SqliteValue::new(raw, value_type_info)
51            })
52            .collect::<Vec<_>>();
53        Self {
54            ctx,
55            argument_values,
56        }
57    }
58
59    /// Returns the argument at the given index, or panics if the argument number is out of bounds or
60    /// the argument cannot be decoded as the requested type.
61    pub fn get_arg<'q, T: Decode<'q, Sqlite>>(&'q self, index: usize) -> T {
62        self.try_get_arg::<T>(index)
63            .expect("invalid argument index")
64    }
65
66    /// Returns the argument at the given index, or `None` if the argument number is out of bounds or
67    /// the argument cannot be decoded as the requested type.
68    pub fn try_get_arg<'q, T: Decode<'q, Sqlite>>(
69        &'q self,
70        index: usize,
71    ) -> Result<T, BoxDynError> {
72        if let Some(value) = self.argument_values.get(index) {
73            let value_ref = value.as_ref();
74            T::decode(value_ref)
75        } else {
76            Err("invalid argument index".into())
77        }
78    }
79
80    pub fn set_result<'q, R: Encode<'q, Sqlite>>(&self, result: R) {
81        unsafe {
82            let mut arg_buffer: Vec<SqliteArgumentValue<'q>> = Vec::with_capacity(1);
83            if let IsNull::Yes = result.encode(&mut arg_buffer) {
84                sqlite3_result_null(self.ctx);
85            } else {
86                let arg = arg_buffer.pop().unwrap();
87                match arg {
88                    SqliteArgumentValue::Null => {
89                        sqlite3_result_null(self.ctx);
90                    }
91                    SqliteArgumentValue::Text(text) => {
92                        sqlite3_result_text(
93                            self.ctx,
94                            text.as_ptr() as *const c_char,
95                            text.len() as c_int,
96                            SQLITE_TRANSIENT(),
97                        );
98                    }
99                    SqliteArgumentValue::Blob(blob) => {
100                        sqlite3_result_blob(
101                            self.ctx,
102                            blob.as_ptr() as *const c_void,
103                            blob.len() as c_int,
104                            SQLITE_TRANSIENT(),
105                        );
106                    }
107                    SqliteArgumentValue::Double(double) => {
108                        sqlite3_result_double(self.ctx, double);
109                    }
110                    SqliteArgumentValue::Int(int) => {
111                        sqlite3_result_int(self.ctx, int);
112                    }
113                    SqliteArgumentValue::Int64(int64) => {
114                        sqlite3_result_int64(self.ctx, int64);
115                    }
116                }
117            }
118        }
119    }
120
121    pub fn set_error(&self, error_str: &str) {
122        let error_str = CString::new(error_str).expect("invalid error string");
123        unsafe {
124            sqlite3_result_error(
125                self.ctx,
126                error_str.as_ptr(),
127                error_str.as_bytes().len() as c_int,
128            );
129        }
130    }
131}
132
133impl<F: Fn(&SqliteFunctionCtx) + Send + Sync> SqliteCallable for F {
134    unsafe fn call_boxed_closure(
135        &self,
136        ctx: *mut sqlite3_context,
137        argc: c_int,
138        argv: *mut *mut sqlite3_value,
139    ) {
140        let ctx = SqliteFunctionCtx::new(ctx, argc, argv);
141        (*self)(&ctx);
142    }
143    fn arg_count(&self) -> i32 {
144        -1
145    }
146}
147
148#[derive(Clone)]
149pub struct Function {
150    name: CString,
151    func: Arc<dyn SqliteCallable>,
152    /// the function always returns the same result given the same inputs
153    pub deterministic: bool,
154    /// the function may only be invoked from top-level SQL, and cannot be used in VIEWs or TRIGGERs nor in schema structures such as CHECK constraints, DEFAULT clauses, expression indexes, partial indexes, or generated columns.
155    pub direct_only: bool,
156    call:
157        unsafe extern "C" fn(ctx: *mut sqlite3_context, argc: c_int, argv: *mut *mut sqlite3_value),
158}
159
160impl std::fmt::Debug for Function {
161    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
162        f.debug_struct("Function")
163            .field("name", &self.name)
164            .field("deterministic", &self.deterministic)
165            .finish_non_exhaustive()
166    }
167}
168
169impl Function {
170    pub fn new<N, F>(name: N, func: F) -> Self
171    where
172        N: Into<Vec<u8>>,
173        F: SqliteCallable + Send + Sync + 'static,
174    {
175        Function {
176            name: CString::new(name).expect("invalid function name"),
177            func: Arc::new(func),
178            deterministic: false,
179            direct_only: false,
180            call: call_boxed_closure::<F>,
181        }
182    }
183
184    pub(crate) fn create(&self, handle: &mut ConnectionHandle) -> Result<(), Error> {
185        let raw_f = Arc::into_raw(Arc::clone(&self.func));
186        let r = unsafe {
187            sqlite3_create_function_v2(
188                handle.as_ptr(),
189                self.name.as_ptr(),
190                self.func.arg_count(), // number of arguments
191                self.sqlite_flags(),
192                raw_f as *mut c_void,
193                Some(self.call),
194                None, // no step function for scalar functions
195                None, // no final function for scalar functions
196                None, // no need to free the function
197            )
198        };
199
200        if r == SQLITE_OK {
201            Ok(())
202        } else {
203            Err(Error::Database(Box::new(SqliteError::new(handle.as_ptr()))))
204        }
205    }
206
207    fn sqlite_flags(&self) -> c_int {
208        let mut flags = SQLITE_UTF8;
209        if self.deterministic {
210            flags |= SQLITE_DETERMINISTIC;
211        }
212        if self.direct_only {
213            flags |= SQLITE_DIRECTONLY;
214        }
215        flags
216    }
217
218    pub fn deterministic(mut self) -> Self {
219        self.deterministic = true;
220        self
221    }
222
223    pub fn direct_only(mut self) -> Self {
224        self.direct_only = true;
225        self
226    }
227}
228
229unsafe extern "C" fn call_boxed_closure<F: SqliteCallable>(
230    ctx: *mut sqlite3_context,
231    argc: c_int,
232    argv: *mut *mut sqlite3_value,
233) {
234    let data = sqlite3_user_data(ctx);
235    let boxed_f: *mut F = data as *mut F;
236    debug_assert!(!boxed_f.is_null());
237    let expected_argc = (*boxed_f).arg_count();
238    debug_assert!(expected_argc == -1 || argc == expected_argc);
239    (*boxed_f).call_boxed_closure(ctx, argc, argv);
240}