Skip to main content

sqlite3_ext/function/
mod.rs

//! Create application-defined functions.
//!
//! The functionality in this module is primarily exposed through
//! [Connection::create_scalar_function] and [Connection::create_aggregate_function].
5use super::{ffi, sqlite3_match_version, types::*, value::*, Connection, RiskLevel};
6pub use context::*;
7use std::{cmp::Ordering, ffi::CString, ptr::null_mut};
8
9mod context;
10mod stubs;
11mod test;
12
/// Builds an aggregate function instance from the user data supplied at registration.
///
/// The user data is provided once, when the aggregate function is registered, and each
/// aggregate instance is created from a reference to it. Types that need no user data
/// only have to implement [Default]; a blanket implementation covers them.
pub trait FromUserData<UserData> {
    /// Create a new instance from a borrowed reference to the registered user data.
    fn from_user_data(user_data: &UserData) -> Self;
}
22
23/// Trait for scalar functions. This trait is used with
24/// [Connection::create_scalar_function_object] to implement scalar functions that have a
25/// lifetime smaller than `'static`. It is also possible to use closures and avoid implementing
26/// this trait, see [Connection::create_scalar_function] for details.
27pub trait ScalarFunction<'db> {
28    /// Perform a single invocation. The function will be invoked with a [Context] and an
29    /// array of [ValueRef] objects. The function is required to set its output using
30    /// [Context::set_result]. If no result is set, SQL NULL is returned. If the function
31    /// returns an Err value, the SQL statement will fail, even if a result had been set.
32    fn call(&self, context: &Context, args: &mut [&mut ValueRef]) -> Result<()>;
33}
34
/// Adapter that lets a plain closure act as a [ScalarFunction].
///
/// The closure's required signature is enforced on the `ScalarFunction` impl, not on the
/// struct definition, so the wrapper itself carries no trait bounds.
struct ScalarClosure<F>(F);
38
39impl<F> ScalarFunction<'_> for ScalarClosure<F>
40where
41    F: Fn(&Context, &mut [&mut ValueRef]) -> Result<()> + 'static,
42{
43    fn call(&self, ctx: &Context, args: &mut [&mut ValueRef]) -> Result<()> {
44        self.0(ctx, args)
45    }
46}
47
48/// Implement an application-defined aggregate function which cannot be used as a window
49/// function.
50///
51/// In general, there is no reason to implement this trait instead of [AggregateFunction],
52/// because the latter provides a blanket implementation of the former.
53pub trait LegacyAggregateFunction<UserData>: FromUserData<UserData> {
54    /// Assign the default value of the aggregate function to the context using
55    /// [Context::set_result].
56    ///
57    /// This method is called when the aggregate function is invoked over an empty set of
58    /// rows. The default implementation is equivalent to
59    /// `Self::from_user_data(user_data).value(context)`.
60    fn default_value(user_data: &UserData, context: &Context) -> Result<()>
61    where
62        Self: Sized,
63    {
64        Self::from_user_data(user_data).value(context)
65    }
66
67    /// Add a new row to the aggregate.
68    fn step(&mut self, context: &Context, args: &mut [&mut ValueRef]) -> Result<()>;
69
70    /// Assign the current value of the aggregate function to the context using
71    /// [Context::set_result]. If no result is set, SQL NULL is returned. If the function returns
72    /// an Err value, the SQL statement will fail, even if a result had been set before the
73    /// failure.
74    fn value(&self, context: &Context) -> Result<()>;
75}
76
77/// Implement an application-defined aggregate window function.
78///
79/// The function can be registered with a database connection using
80/// [Connection::create_aggregate_function].
81pub trait AggregateFunction<UserData>: FromUserData<UserData> {
82    /// Assign the default value of the aggregate function to the context using
83    /// [Context::set_result].
84    ///
85    /// This method is called when the aggregate function is invoked over an empty set of
86    /// rows. The default implementation is equivalent to
87    /// `Self::from_user_data(user_data).value(context)`.
88    fn default_value(user_data: &UserData, context: &Context) -> Result<()>
89    where
90        Self: Sized,
91    {
92        Self::from_user_data(user_data).value(context)
93    }
94
95    /// Add a new row to the aggregate.
96    fn step(&mut self, context: &Context, args: &mut [&mut ValueRef]) -> Result<()>;
97
98    /// Assign the current value of the aggregate function to the context using
99    /// [Context::set_result]. If no result is set, SQL NULL is returned. If the function returns
100    /// an Err value, the SQL statement will fail, even if a result had been set before the
101    /// failure.
102    fn value(&self, context: &Context) -> Result<()>;
103
104    /// Remove the oldest presently aggregated row.
105    ///
106    /// The args are the same that were passed to [AggregateFunction::step] when this row
107    /// was added.
108    fn inverse(&mut self, context: &Context, args: &mut [&mut ValueRef]) -> Result<()>;
109}
110
111impl<U, F: Default> FromUserData<U> for F {
112    fn from_user_data(_: &U) -> F {
113        F::default()
114    }
115}
116
117impl<U, T: AggregateFunction<U>> LegacyAggregateFunction<U> for T {
118    fn default_value(user_data: &U, context: &Context) -> Result<()> {
119        <T as AggregateFunction<U>>::default_value(user_data, context)
120    }
121
122    fn step(&mut self, context: &Context, args: &mut [&mut ValueRef]) -> Result<()> {
123        <T as AggregateFunction<U>>::step(self, context, args)
124    }
125
126    fn value(&self, context: &Context) -> Result<()> {
127        <T as AggregateFunction<U>>::value(self, context)
128    }
129}
130
/// Options that control how an application-defined function is registered.
///
/// Start from [FunctionOptions::default] and refine the options with the `set_*` builder
/// methods.
#[derive(Clone, Debug)]
pub struct FunctionOptions {
    /// Number of accepted arguments, or -1 for "any number".
    n_args: i32,
    /// Flag bits handed to SQLite when the function is created (e.g. the deterministic
    /// flag).
    flags: i32,
}
136
137impl Default for FunctionOptions {
138    fn default() -> Self {
139        FunctionOptions::default()
140    }
141}
142
143impl FunctionOptions {
144    pub const fn default() -> Self {
145        FunctionOptions {
146            n_args: -1,
147            flags: 0,
148        }
149    }
150
151    /// Set the number of parameters accepted by this function. Multiple functions may be
152    /// provided under the same name with different n_args values; the implementation will
153    /// be chosen by SQLite based on the number of parameters at the call site. The value
154    /// may also be -1, which means that the function accepts any number of parameters.
155    /// Functions which take a specific number of parameters take precedence over functions
156    /// which take any number.
157    ///
158    /// # Panics
159    ///
160    /// This function panics if n_args is outside the range -1..128. This limitation is
161    /// imposed by SQLite.
162    pub const fn set_n_args(mut self, n_args: i32) -> Self {
163        assert!(n_args >= -1 && n_args < 128, "n_args invalid");
164        self.n_args = n_args;
165        self
166    }
167
168    /// Enable or disable the deterministic flag. This flag indicates that the function is
169    /// pure. It must have no side effects and the value must be determined solely its the
170    /// parameters.
171    ///
172    /// The SQLite query planner is able to perform additional optimizations on
173    /// deterministic functions, so use of this flag is recommended where possible.
174    pub const fn set_deterministic(mut self, val: bool) -> Self {
175        if val {
176            self.flags |= ffi::SQLITE_DETERMINISTIC;
177        } else {
178            self.flags &= !ffi::SQLITE_DETERMINISTIC;
179        }
180        self
181    }
182
183    /// Set the level of risk for this function. See the [RiskLevel] enum for details about
184    /// what the individual options mean.
185    ///
186    /// Requires SQLite 3.31.0. On earlier versions of SQLite, this function is a harmless no-op.
187    pub const fn set_risk_level(
188        #[cfg_attr(not(modern_sqlite), allow(unused_mut))] mut self,
189        level: RiskLevel,
190    ) -> Self {
191        let _ = level;
192        #[cfg(modern_sqlite)]
193        {
194            self.flags |= match level {
195                RiskLevel::Innocuous => ffi::SQLITE_INNOCUOUS,
196                RiskLevel::DirectOnly => ffi::SQLITE_DIRECTONLY,
197            };
198            self.flags &= match level {
199                RiskLevel::Innocuous => !ffi::SQLITE_DIRECTONLY,
200                RiskLevel::DirectOnly => !ffi::SQLITE_INNOCUOUS,
201            };
202        }
203        self
204    }
205}
206
207impl Connection {
208    /// Create a stub function that always fails.
209    ///
210    /// This API makes sure a global version of a function with a particular name and
211    /// number of parameters exists. If no such function exists before this API is called,
212    /// a new function is created. The implementation of the new function always causes an
213    /// exception to be thrown. So the new function is not good for anything by itself. Its
214    /// only purpose is to be a placeholder function that can be overloaded by a virtual
215    /// table.
216    ///
217    /// For more information, see [vtab::FindFunctionVTab](super::vtab::FindFunctionVTab).
218    pub fn create_overloaded_function(&self, name: &str, opts: &FunctionOptions) -> Result<()> {
219        let guard = self.lock();
220        let name = unsafe { CString::from_vec_unchecked(name.as_bytes().into()) };
221        unsafe {
222            Error::from_sqlite_desc(
223                ffi::sqlite3_overload_function(self.as_mut_ptr(), name.as_ptr() as _, opts.n_args),
224                guard,
225            )
226        }
227    }
228
229    /// Create a new scalar function. The function will be invoked with a [Context] and an array of
230    /// [ValueRef] objects. The function is required to set its output using [Context::set_result].
231    /// If no result is set, SQL NULL is returned. If the function returns an Err value, the SQL
232    /// statement will fail, even if a result had been set.
233    ///
234    /// The passed function can be a closure, however the lifetime of the closure must be
235    /// `'static` due to limitations in the Rust borrow checker. The
236    /// [Self::create_scalar_function_object] function is an alternative that allows using an
237    /// alternative lifetime.
238    ///
239    /// # Compatibility
240    ///
241    /// On versions of SQLite earlier than 3.7.3, this function will leak the function and
242    /// all bound variables. This is because these versions of SQLite did not provide the
243    /// ability to specify a destructor function.
244    pub fn create_scalar_function<F>(
245        &self,
246        name: &str,
247        opts: &FunctionOptions,
248        func: F,
249    ) -> Result<()>
250    where
251        F: Fn(&Context, &mut [&mut ValueRef]) -> Result<()> + 'static,
252    {
253        self.create_scalar_function_object(name, &opts, ScalarClosure(func))
254    }
255
256    /// Create a new scalar function using a struct. This function is identical to
257    /// [Self::create_scalar_function], but uses a trait object instead of a closure. This enables
258    /// creating scalar functions that maintain references with a lifetime smaller than `'static`.
259    pub fn create_scalar_function_object<'db, F>(
260        &'db self,
261        name: &str,
262        opts: &FunctionOptions,
263        func: F,
264    ) -> Result<()>
265    where
266        F: ScalarFunction<'db>,
267    {
268        let guard = self.lock();
269        let name = unsafe { CString::from_vec_unchecked(name.as_bytes().into()) };
270        let func = Box::new(func);
271        unsafe {
272            Error::from_sqlite_desc(
273                sqlite3_match_version! {
274                    3_007_003 => ffi::sqlite3_create_function_v2(
275                        self.as_mut_ptr(),
276                        name.as_ptr() as _,
277                        opts.n_args,
278                        opts.flags,
279                        Box::into_raw(func) as _,
280                        Some(stubs::call_scalar::<F>),
281                        None,
282                        None,
283                        Some(ffi::drop_boxed::<F>),
284                    ),
285                    _ => ffi::sqlite3_create_function(
286                        self.as_mut_ptr(),
287                        name.as_ptr() as _,
288                        opts.n_args,
289                        opts.flags,
290                        Box::into_raw(func) as _,
291                        Some(stubs::call_scalar::<F>),
292                        None,
293                        None,
294                    ),
295                },
296                guard,
297            )
298        }
299    }
300
301    /// Create a new aggregate function which cannot be used as a window function.
302    ///
303    /// In general, you should use
304    /// [create_aggregate_function](Connection::create_aggregate_function) instead, which
305    /// provides all of the same features as legacy aggregate functions but also support
306    /// WINDOW.
307    ///
308    /// # Compatibility
309    ///
310    /// On versions of SQLite earlier than 3.7.3, this function will leak the user data.
311    /// This is because these versions of SQLite did not provide the ability to specify a
312    /// destructor function.
313    pub fn create_legacy_aggregate_function<U, F: LegacyAggregateFunction<U>>(
314        &self,
315        name: &str,
316        opts: &FunctionOptions,
317        user_data: U,
318    ) -> Result<()> {
319        let guard = self.lock();
320        let name = unsafe { CString::from_vec_unchecked(name.as_bytes().into()) };
321        let user_data = Box::new(user_data);
322        unsafe {
323            Error::from_sqlite_desc(
324                sqlite3_match_version! {
325                    3_007_003 => ffi::sqlite3_create_function_v2(
326                        self.as_mut_ptr(),
327                        name.as_ptr() as _,
328                        opts.n_args,
329                        opts.flags,
330                        Box::into_raw(user_data) as _,
331                        None,
332                        Some(stubs::aggregate_step::<U, F>),
333                        Some(stubs::aggregate_final::<U, F>),
334                        Some(ffi::drop_boxed::<U>),
335                    ),
336                    _ => ffi::sqlite3_create_function(
337                        self.as_mut_ptr(),
338                        name.as_ptr() as _,
339                        opts.n_args,
340                        opts.flags,
341                        Box::into_raw(user_data) as _,
342                        None,
343                        Some(stubs::aggregate_step::<U, F>),
344                        Some(stubs::aggregate_final::<U, F>),
345                    ),
346                },
347                guard,
348            )
349        }
350    }
351
352    /// Create a new aggregate function.
353    ///
354    /// # Compatibility
355    ///
356    /// Window functions require SQLite 3.25.0. On earlier versions of SQLite, this
357    /// function will automatically fall back to
358    /// [create_legacy_aggregate_function](Connection::create_legacy_aggregate_function).
359    pub fn create_aggregate_function<U, F: AggregateFunction<U>>(
360        &self,
361        name: &str,
362        opts: &FunctionOptions,
363        user_data: U,
364    ) -> Result<()> {
365        sqlite3_match_version! {
366            3_025_000 => {
367                let name = unsafe { CString::from_vec_unchecked(name.as_bytes().into()) };
368                let user_data = Box::new(user_data);
369                let guard = self.lock();
370                unsafe {
371                    Error::from_sqlite_desc(ffi::sqlite3_create_window_function(
372                        self.as_mut_ptr(),
373                        name.as_ptr() as _,
374                        opts.n_args,
375                        opts.flags,
376                        Box::into_raw(user_data) as _,
377                        Some(stubs::aggregate_step::<U, F>),
378                        Some(stubs::aggregate_final::<U, F>),
379                        Some(stubs::aggregate_value::<U, F>),
380                        Some(stubs::aggregate_inverse::<U, F>),
381                        Some(ffi::drop_boxed::<U>),
382                    ), guard)
383                }
384            },
385            _ => self.create_legacy_aggregate_function::<U, F>(name, opts, user_data),
386        }
387    }
388
389    /// Remove an application-defined scalar or aggregate function. The name and n_args
390    /// parameters must match the values used when the function was created.
391    pub fn remove_function(&self, name: &str, n_args: i32) -> Result<()> {
392        let name = unsafe { CString::from_vec_unchecked(name.as_bytes().into()) };
393        let guard = self.lock();
394        unsafe {
395            Error::from_sqlite_desc(
396                ffi::sqlite3_create_function(
397                    self.as_mut_ptr(),
398                    name.as_ptr() as _,
399                    n_args,
400                    0,
401                    null_mut(),
402                    None,
403                    None,
404                    None,
405                ),
406                guard,
407            )
408        }
409    }
410
411    /// Register a new collating sequence.
412    pub fn create_collation<F: Fn(&str, &str) -> Ordering>(
413        &self,
414        name: &str,
415        func: F,
416    ) -> Result<()> {
417        let name = unsafe { CString::from_vec_unchecked(name.as_bytes().into()) };
418        let func = Box::into_raw(Box::new(func));
419        let guard = self.lock();
420        unsafe {
421            let rc = ffi::sqlite3_create_collation_v2(
422                self.as_mut_ptr(),
423                name.as_ptr() as _,
424                ffi::SQLITE_UTF8,
425                func as _,
426                Some(stubs::compare::<F>),
427                Some(ffi::drop_boxed::<F>),
428            );
429            if rc != ffi::SQLITE_OK {
430                // The xDestroy callback is not called if the
431                // sqlite3_create_collation_v2() function fails.
432                drop(Box::from_raw(func));
433            }
434            Error::from_sqlite_desc(rc, guard)
435        }
436    }
437
438    /// Register a callback for when SQLite needs a collation sequence. The function will
439    /// be invoked when a collation sequence is needed, and
440    /// [create_collation](Connection::create_collation) can be used to provide the needed
441    /// sequence.
442    ///
443    /// Note: the provided function and any captured variables will be leaked. SQLite does
444    /// not provide any facilities for cleaning up this data.
445    pub fn set_collation_needed_func<F: Fn(&str)>(&self, func: F) -> Result<()> {
446        let func = Box::new(func);
447        let guard = self.lock();
448        unsafe {
449            Error::from_sqlite_desc(
450                ffi::sqlite3_collation_needed(
451                    self.as_mut_ptr(),
452                    Box::into_raw(func) as _,
453                    Some(stubs::collation_needed::<F>),
454                ),
455                guard,
456            )
457        }
458    }
459}