sqlite3_ext/function/
mod.rs

1//! Create application-defined functions.
2//!
3//! The functionality in this module is primarily exposed through
4//! [Connection::create_scalar_function] and [Connection::create_aggregate_function].
5use super::{ffi, sqlite3_match_version, types::*, value::*, Connection, RiskLevel};
6pub use context::*;
7use std::{cmp::Ordering, ffi::CString, ptr::null_mut};
8
9mod context;
10mod stubs;
11mod test;
12
/// Builds an aggregate function instance from registration-time user data.
///
/// The user data is the value that was supplied when the function was registered.
/// Types implementing [Default] receive this trait through a blanket implementation,
/// which covers aggregates that do not need any user data.
pub trait FromUserData<T> {
    /// Create a new instance, using `data` as the user data to construct it from.
    fn from_user_data(data: &T) -> Self;
}
22
23/// Implement an application-defined aggregate function which cannot be used as a window
24/// function.
25///
26/// In general, there is no reason to implement this trait instead of [AggregateFunction],
27/// because the latter provides a blanket implementation of the former.
28pub trait LegacyAggregateFunction<UserData>: FromUserData<UserData> {
29    /// Assign the default value of the aggregate function to the context using
30    /// [Context::set_result].
31    ///
32    /// This method is called when the aggregate function is invoked over an empty set of
33    /// rows. The default implementation is equivalent to
34    /// `Self::from_user_data(user_data).value(context)`.
35    fn default_value(user_data: &UserData, context: &Context) -> Result<()>
36    where
37        Self: Sized,
38    {
39        Self::from_user_data(user_data).value(context)
40    }
41
42    /// Add a new row to the aggregate.
43    fn step(&mut self, context: &Context, args: &mut [&mut ValueRef]) -> Result<()>;
44
45    /// Assign the current value of the aggregate function to the context using
46    /// [Context::set_result]. If no result is set, SQL NULL is returned. If the function returns
47    /// an Err value, the SQL statement will fail, even if a result had been set before the
48    /// failure.
49    fn value(&self, context: &Context) -> Result<()>;
50}
51
52/// Implement an application-defined aggregate window function.
53///
54/// The function can be registered with a database connection using
55/// [Connection::create_aggregate_function].
56pub trait AggregateFunction<UserData>: FromUserData<UserData> {
57    /// Assign the default value of the aggregate function to the context using
58    /// [Context::set_result].
59    ///
60    /// This method is called when the aggregate function is invoked over an empty set of
61    /// rows. The default implementation is equivalent to
62    /// `Self::from_user_data(user_data).value(context)`.
63    fn default_value(user_data: &UserData, context: &Context) -> Result<()>
64    where
65        Self: Sized,
66    {
67        Self::from_user_data(user_data).value(context)
68    }
69
70    /// Add a new row to the aggregate.
71    fn step(&mut self, context: &Context, args: &mut [&mut ValueRef]) -> Result<()>;
72
73    /// Assign the current value of the aggregate function to the context using
74    /// [Context::set_result]. If no result is set, SQL NULL is returned. If the function returns
75    /// an Err value, the SQL statement will fail, even if a result had been set before the
76    /// failure.
77    fn value(&self, context: &Context) -> Result<()>;
78
79    /// Remove the oldest presently aggregated row.
80    ///
81    /// The args are the same that were passed to [AggregateFunction::step] when this row
82    /// was added.
83    fn inverse(&mut self, context: &Context, args: &mut [&mut ValueRef]) -> Result<()>;
84}
85
86impl<U, F: Default> FromUserData<U> for F {
87    fn from_user_data(_: &U) -> F {
88        F::default()
89    }
90}
91
92impl<U, T: AggregateFunction<U>> LegacyAggregateFunction<U> for T {
93    fn default_value(user_data: &U, context: &Context) -> Result<()> {
94        <T as AggregateFunction<U>>::default_value(user_data, context)
95    }
96
97    fn step(&mut self, context: &Context, args: &mut [&mut ValueRef]) -> Result<()> {
98        <T as AggregateFunction<U>>::step(self, context, args)
99    }
100
101    fn value(&self, context: &Context) -> Result<()> {
102        <T as AggregateFunction<U>>::value(self, context)
103    }
104}
105
/// Options used when registering an application-defined function.
///
/// Values are built by starting from [FunctionOptions::default] and chaining the `set_*`
/// methods, each of which takes the options by value and returns the updated copy.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct FunctionOptions {
    /// Number of parameters the function accepts; -1 means any number.
    n_args: i32,
    /// Flag bits passed to SQLite when the function is registered.
    flags: i32,
}
111
112impl Default for FunctionOptions {
113    fn default() -> Self {
114        FunctionOptions::default()
115    }
116}
117
118impl FunctionOptions {
119    pub const fn default() -> Self {
120        FunctionOptions {
121            n_args: -1,
122            flags: 0,
123        }
124    }
125
126    /// Set the number of parameters accepted by this function. Multiple functions may be
127    /// provided under the same name with different n_args values; the implementation will
128    /// be chosen by SQLite based on the number of parameters at the call site. The value
129    /// may also be -1, which means that the function accepts any number of parameters.
130    /// Functions which take a specific number of parameters take precedence over functions
131    /// which take any number.
132    ///
133    /// # Panics
134    ///
135    /// This function panics if n_args is outside the range -1..128. This limitation is
136    /// imposed by SQLite.
137    pub const fn set_n_args(mut self, n_args: i32) -> Self {
138        assert!(n_args >= -1 && n_args < 128, "n_args invalid");
139        self.n_args = n_args;
140        self
141    }
142
143    /// Enable or disable the deterministic flag. This flag indicates that the function is
144    /// pure. It must have no side effects and the value must be determined solely its the
145    /// parameters.
146    ///
147    /// The SQLite query planner is able to perform additional optimizations on
148    /// deterministic functions, so use of this flag is recommended where possible.
149    pub const fn set_deterministic(mut self, val: bool) -> Self {
150        if val {
151            self.flags |= ffi::SQLITE_DETERMINISTIC;
152        } else {
153            self.flags &= !ffi::SQLITE_DETERMINISTIC;
154        }
155        self
156    }
157
158    /// Set the level of risk for this function. See the [RiskLevel] enum for details about
159    /// what the individual options mean.
160    ///
161    /// Requires SQLite 3.31.0. On earlier versions of SQLite, this function is a harmless no-op.
162    pub const fn set_risk_level(
163        #[cfg_attr(not(modern_sqlite), allow(unused_mut))] mut self,
164        level: RiskLevel,
165    ) -> Self {
166        let _ = level;
167        #[cfg(modern_sqlite)]
168        {
169            self.flags |= match level {
170                RiskLevel::Innocuous => ffi::SQLITE_INNOCUOUS,
171                RiskLevel::DirectOnly => ffi::SQLITE_DIRECTONLY,
172            };
173            self.flags &= match level {
174                RiskLevel::Innocuous => !ffi::SQLITE_DIRECTONLY,
175                RiskLevel::DirectOnly => !ffi::SQLITE_INNOCUOUS,
176            };
177        }
178        self
179    }
180}
181
182impl Connection {
183    /// Create a stub function that always fails.
184    ///
185    /// This API makes sure a global version of a function with a particular name and
186    /// number of parameters exists. If no such function exists before this API is called,
187    /// a new function is created. The implementation of the new function always causes an
188    /// exception to be thrown. So the new function is not good for anything by itself. Its
189    /// only purpose is to be a placeholder function that can be overloaded by a virtual
190    /// table.
191    ///
192    /// For more information, see [vtab::FindFunctionVTab](super::vtab::FindFunctionVTab).
193    pub fn create_overloaded_function(&self, name: &str, opts: &FunctionOptions) -> Result<()> {
194        let guard = self.lock();
195        let name = unsafe { CString::from_vec_unchecked(name.as_bytes().into()) };
196        unsafe {
197            Error::from_sqlite_desc(
198                ffi::sqlite3_overload_function(self.as_mut_ptr(), name.as_ptr() as _, opts.n_args),
199                guard,
200            )
201        }
202    }
203
204    /// Create a new scalar function. The function will be invoked with a [Context] and an array of
205    /// [ValueRef] objects. The function is required to set its output using [Context::set_result].
206    /// If no result is set, SQL NULL is returned. If the function returns an Err value, the SQL
207    /// statement will fail, even if a result had been set before the failure.
208    ///
209    /// # Compatibility
210    ///
211    /// On versions of SQLite earlier than 3.7.3, this function will leak the function and
212    /// all bound variables. This is because these versions of SQLite did not provide the
213    /// ability to specify a destructor function.
214    pub fn create_scalar_function<F>(
215        &self,
216        name: &str,
217        opts: &FunctionOptions,
218        func: F,
219    ) -> Result<()>
220    where
221        F: Fn(&Context, &mut [&mut ValueRef]) -> Result<()> + 'static,
222    {
223        let guard = self.lock();
224        let name = unsafe { CString::from_vec_unchecked(name.as_bytes().into()) };
225        let func = Box::new(func);
226        unsafe {
227            Error::from_sqlite_desc(
228                sqlite3_match_version! {
229                    3_007_003 => ffi::sqlite3_create_function_v2(
230                        self.as_mut_ptr(),
231                        name.as_ptr() as _,
232                        opts.n_args,
233                        opts.flags,
234                        Box::into_raw(func) as _,
235                        Some(stubs::call_scalar::<F>),
236                        None,
237                        None,
238                        Some(ffi::drop_boxed::<F>),
239                    ),
240                    _ => ffi::sqlite3_create_function(
241                        self.as_mut_ptr(),
242                        name.as_ptr() as _,
243                        opts.n_args,
244                        opts.flags,
245                        Box::into_raw(func) as _,
246                        Some(stubs::call_scalar::<F>),
247                        None,
248                        None,
249                    ),
250                },
251                guard,
252            )
253        }
254    }
255
256    /// Create a new aggregate function which cannot be used as a window function.
257    ///
258    /// In general, you should use
259    /// [create_aggregate_function](Connection::create_aggregate_function) instead, which
260    /// provides all of the same features as legacy aggregate functions but also support
261    /// WINDOW.
262    ///
263    /// # Compatibility
264    ///
265    /// On versions of SQLite earlier than 3.7.3, this function will leak the user data.
266    /// This is because these versions of SQLite did not provide the ability to specify a
267    /// destructor function.
268    pub fn create_legacy_aggregate_function<U, F: LegacyAggregateFunction<U>>(
269        &self,
270        name: &str,
271        opts: &FunctionOptions,
272        user_data: U,
273    ) -> Result<()> {
274        let guard = self.lock();
275        let name = unsafe { CString::from_vec_unchecked(name.as_bytes().into()) };
276        let user_data = Box::new(user_data);
277        unsafe {
278            Error::from_sqlite_desc(
279                sqlite3_match_version! {
280                    3_007_003 => ffi::sqlite3_create_function_v2(
281                        self.as_mut_ptr(),
282                        name.as_ptr() as _,
283                        opts.n_args,
284                        opts.flags,
285                        Box::into_raw(user_data) as _,
286                        None,
287                        Some(stubs::aggregate_step::<U, F>),
288                        Some(stubs::aggregate_final::<U, F>),
289                        Some(ffi::drop_boxed::<U>),
290                    ),
291                    _ => ffi::sqlite3_create_function(
292                        self.as_mut_ptr(),
293                        name.as_ptr() as _,
294                        opts.n_args,
295                        opts.flags,
296                        Box::into_raw(user_data) as _,
297                        None,
298                        Some(stubs::aggregate_step::<U, F>),
299                        Some(stubs::aggregate_final::<U, F>),
300                    ),
301                },
302                guard,
303            )
304        }
305    }
306
307    /// Create a new aggregate function.
308    ///
309    /// # Compatibility
310    ///
311    /// Window functions require SQLite 3.25.0. On earlier versions of SQLite, this
312    /// function will automatically fall back to
313    /// [create_legacy_aggregate_function](Connection::create_legacy_aggregate_function).
314    pub fn create_aggregate_function<U, F: AggregateFunction<U>>(
315        &self,
316        name: &str,
317        opts: &FunctionOptions,
318        user_data: U,
319    ) -> Result<()> {
320        sqlite3_match_version! {
321            3_025_000 => {
322                let name = unsafe { CString::from_vec_unchecked(name.as_bytes().into()) };
323                let user_data = Box::new(user_data);
324                let guard = self.lock();
325                unsafe {
326                    Error::from_sqlite_desc(ffi::sqlite3_create_window_function(
327                        self.as_mut_ptr(),
328                        name.as_ptr() as _,
329                        opts.n_args,
330                        opts.flags,
331                        Box::into_raw(user_data) as _,
332                        Some(stubs::aggregate_step::<U, F>),
333                        Some(stubs::aggregate_final::<U, F>),
334                        Some(stubs::aggregate_value::<U, F>),
335                        Some(stubs::aggregate_inverse::<U, F>),
336                        Some(ffi::drop_boxed::<U>),
337                    ), guard)
338                }
339            },
340            _ => self.create_legacy_aggregate_function::<U, F>(name, opts, user_data),
341        }
342    }
343
344    /// Remove an application-defined scalar or aggregate function. The name and n_args
345    /// parameters must match the values used when the function was created.
346    pub fn remove_function(&self, name: &str, n_args: i32) -> Result<()> {
347        let name = unsafe { CString::from_vec_unchecked(name.as_bytes().into()) };
348        let guard = self.lock();
349        unsafe {
350            Error::from_sqlite_desc(
351                ffi::sqlite3_create_function(
352                    self.as_mut_ptr(),
353                    name.as_ptr() as _,
354                    n_args,
355                    0,
356                    null_mut(),
357                    None,
358                    None,
359                    None,
360                ),
361                guard,
362            )
363        }
364    }
365
366    /// Register a new collating sequence.
367    pub fn create_collation<F: Fn(&str, &str) -> Ordering>(
368        &self,
369        name: &str,
370        func: F,
371    ) -> Result<()> {
372        let name = unsafe { CString::from_vec_unchecked(name.as_bytes().into()) };
373        let func = Box::into_raw(Box::new(func));
374        let guard = self.lock();
375        unsafe {
376            let rc = ffi::sqlite3_create_collation_v2(
377                self.as_mut_ptr(),
378                name.as_ptr() as _,
379                ffi::SQLITE_UTF8,
380                func as _,
381                Some(stubs::compare::<F>),
382                Some(ffi::drop_boxed::<F>),
383            );
384            if rc != ffi::SQLITE_OK {
385                // The xDestroy callback is not called if the
386                // sqlite3_create_collation_v2() function fails.
387                drop(Box::from_raw(func));
388            }
389            Error::from_sqlite_desc(rc, guard)
390        }
391    }
392
393    /// Register a callback for when SQLite needs a collation sequence. The function will
394    /// be invoked when a collation sequence is needed, and
395    /// [create_collation](Connection::create_collation) can be used to provide the needed
396    /// sequence.
397    ///
398    /// Note: the provided function and any captured variables will be leaked. SQLite does
399    /// not provide any facilities for cleaning up this data.
400    pub fn set_collation_needed_func<F: Fn(&str)>(&self, func: F) -> Result<()> {
401        let func = Box::new(func);
402        let guard = self.lock();
403        unsafe {
404            Error::from_sqlite_desc(
405                ffi::sqlite3_collation_needed(
406                    self.as_mut_ptr(),
407                    Box::into_raw(func) as _,
408                    Some(stubs::collation_needed::<F>),
409                ),
410                guard,
411            )
412        }
413    }
414}