rusqlite/
functions.rs

1//! Create or redefine SQL functions.
2//!
3//! # Example
4//!
//! Adding a `regexp` function to a connection. Compiled regular expressions
//! are cached with SQLite's
//! [Function Auxiliary Data](https://www.sqlite.org/c3ref/get_auxdata.html)
//! interface (through `Context::get_or_create_aux`), so the same pattern need
//! not be recompiled on every call within a query. The unit tests for this
//! module contain a similar implementation.
10//!
11//! ```rust
12//! use regex::Regex;
13//! use rusqlite::functions::FunctionFlags;
14//! use rusqlite::{Connection, Error, Result};
15//! use std::sync::Arc;
16//! type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>;
17//!
18//! fn add_regexp_function(db: &Connection) -> Result<()> {
19//!     db.create_scalar_function(
20//!         "regexp",
21//!         2,
22//!         FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
23//!         move |ctx| {
24//!             assert_eq!(ctx.len(), 2, "called with unexpected number of arguments");
25//!             let regexp: Arc<Regex> = ctx.get_or_create_aux(0, |vr| -> Result<_, BoxError> {
26//!                 Ok(Regex::new(vr.as_str()?)?)
27//!             })?;
28//!             let is_match = {
29//!                 let text = ctx
30//!                     .get_raw(1)
31//!                     .as_str()
32//!                     .map_err(|e| Error::UserFunctionError(e.into()))?;
33//!
34//!                 regexp.is_match(text)
35//!             };
36//!
37//!             Ok(is_match)
38//!         },
39//!     )
40//! }
41//!
42//! fn main() -> Result<()> {
43//!     let db = Connection::open_in_memory()?;
44//!     add_regexp_function(&db)?;
45//!
46//!     let is_match: bool =
47//!         db.query_row("SELECT regexp('[aeiou]*', 'aaaaeeeiii')", [], |row| {
48//!             row.get(0)
49//!         })?;
50//!
51//!     assert!(is_match);
52//!     Ok(())
53//! }
54//! ```
55use std::any::Any;
56use std::marker::PhantomData;
57use std::ops::Deref;
58use std::os::raw::{c_int, c_void};
59use std::panic::{catch_unwind, RefUnwindSafe, UnwindSafe};
60use std::ptr;
61use std::slice;
62use std::sync::Arc;
63
64use crate::ffi;
65use crate::ffi::sqlite3_context;
66use crate::ffi::sqlite3_value;
67
68use crate::context::set_result;
69use crate::types::{FromSql, FromSqlError, ToSql, ValueRef};
70
71use crate::{str_to_cstring, Connection, Error, InnerConnection, Result};
72
73unsafe fn report_error(ctx: *mut sqlite3_context, err: &Error) {
74    // Extended constraint error codes were added in SQLite 3.7.16. We don't have
75    // an explicit feature check for that, and this doesn't really warrant one.
76    // We'll use the extended code if we're on the bundled version (since it's
77    // at least 3.17.0) and the normal constraint error code if not.
78    #[cfg(feature = "modern_sqlite")]
79    fn constraint_error_code() -> i32 {
80        ffi::SQLITE_CONSTRAINT_FUNCTION
81    }
82    #[cfg(not(feature = "modern_sqlite"))]
83    fn constraint_error_code() -> i32 {
84        ffi::SQLITE_CONSTRAINT
85    }
86
87    if let Error::SqliteFailure(ref err, ref s) = *err {
88        ffi::sqlite3_result_error_code(ctx, err.extended_code);
89        if let Some(Ok(cstr)) = s.as_ref().map(|s| str_to_cstring(s)) {
90            ffi::sqlite3_result_error(ctx, cstr.as_ptr(), -1);
91        }
92    } else {
93        ffi::sqlite3_result_error_code(ctx, constraint_error_code());
94        if let Ok(cstr) = str_to_cstring(&err.to_string()) {
95            ffi::sqlite3_result_error(ctx, cstr.as_ptr(), -1);
96        }
97    }
98}
99
/// C-ABI destructor that reclaims a value previously handed to SQLite as
/// `Box::into_raw(Box::new(value))` cast to `*mut c_void`.
///
/// # Safety
///
/// `p` must come from `Box::<T>::into_raw` and must not be used again after
/// this call.
unsafe extern "C" fn free_boxed_value<T>(p: *mut c_void) {
    let reclaimed: Box<T> = Box::from_raw(p as *mut T);
    drop(reclaimed);
}
103
/// Context is a wrapper for the SQLite function
/// evaluation context.
///
/// It pairs the raw `sqlite3_context` pointer that SQLite hands to a function
/// callback with that callback's argument values, borrowed for `'a`.
pub struct Context<'a> {
    // Raw SQLite function context passed to the callback.
    ctx: *mut sqlite3_context,
    // The `argv[0..argc]` argument values passed to the callback; empty when
    // the context is built for the finalize callback.
    args: &'a [*mut sqlite3_value],
}
110
111impl Context<'_> {
112    /// Returns the number of arguments to the function.
113    #[inline]
114    #[must_use]
115    pub fn len(&self) -> usize {
116        self.args.len()
117    }
118
119    /// Returns `true` when there is no argument.
120    #[inline]
121    #[must_use]
122    pub fn is_empty(&self) -> bool {
123        self.args.is_empty()
124    }
125
126    /// Returns the `idx`th argument as a `T`.
127    ///
128    /// # Failure
129    ///
130    /// Will panic if `idx` is greater than or equal to
131    /// [`self.len()`](Context::len).
132    ///
133    /// Will return Err if the underlying SQLite type cannot be converted to a
134    /// `T`.
135    pub fn get<T: FromSql>(&self, idx: usize) -> Result<T> {
136        let arg = self.args[idx];
137        let value = unsafe { ValueRef::from_value(arg) };
138        FromSql::column_result(value).map_err(|err| match err {
139            FromSqlError::InvalidType => {
140                Error::InvalidFunctionParameterType(idx, value.data_type())
141            }
142            FromSqlError::OutOfRange(i) => Error::IntegralValueOutOfRange(idx, i),
143            FromSqlError::Other(err) => {
144                Error::FromSqlConversionFailure(idx, value.data_type(), err)
145            }
146            FromSqlError::InvalidBlobSize { .. } => {
147                Error::FromSqlConversionFailure(idx, value.data_type(), Box::new(err))
148            }
149        })
150    }
151
152    /// Returns the `idx`th argument as a `ValueRef`.
153    ///
154    /// # Failure
155    ///
156    /// Will panic if `idx` is greater than or equal to
157    /// [`self.len()`](Context::len).
158    #[inline]
159    #[must_use]
160    pub fn get_raw(&self, idx: usize) -> ValueRef<'_> {
161        let arg = self.args[idx];
162        unsafe { ValueRef::from_value(arg) }
163    }
164
165    /// Returns the subtype of `idx`th argument.
166    ///
167    /// # Failure
168    ///
169    /// Will panic if `idx` is greater than or equal to
170    /// [`self.len()`](Context::len).
171    #[cfg(feature = "modern_sqlite")] // 3.9.0
172    #[cfg_attr(docsrs, doc(cfg(feature = "modern_sqlite")))]
173    pub fn get_subtype(&self, idx: usize) -> std::os::raw::c_uint {
174        let arg = self.args[idx];
175        unsafe { ffi::sqlite3_value_subtype(arg) }
176    }
177
178    /// Fetch or insert the auxiliary data associated with a particular
179    /// parameter. This is intended to be an easier-to-use way of fetching it
180    /// compared to calling [`get_aux`](Context::get_aux) and
181    /// [`set_aux`](Context::set_aux) separately.
182    ///
183    /// See `https://www.sqlite.org/c3ref/get_auxdata.html` for a discussion of
184    /// this feature, or the unit tests of this module for an example.
185    pub fn get_or_create_aux<T, E, F>(&self, arg: c_int, func: F) -> Result<Arc<T>>
186    where
187        T: Send + Sync + 'static,
188        E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
189        F: FnOnce(ValueRef<'_>) -> Result<T, E>,
190    {
191        if let Some(v) = self.get_aux(arg)? {
192            Ok(v)
193        } else {
194            let vr = self.get_raw(arg as usize);
195            self.set_aux(
196                arg,
197                func(vr).map_err(|e| Error::UserFunctionError(e.into()))?,
198            )
199        }
200    }
201
202    /// Sets the auxiliary data associated with a particular parameter. See
203    /// `https://www.sqlite.org/c3ref/get_auxdata.html` for a discussion of
204    /// this feature, or the unit tests of this module for an example.
205    pub fn set_aux<T: Send + Sync + 'static>(&self, arg: c_int, value: T) -> Result<Arc<T>> {
206        let orig: Arc<T> = Arc::new(value);
207        let inner: AuxInner = orig.clone();
208        let outer = Box::new(inner);
209        let raw: *mut AuxInner = Box::into_raw(outer);
210        unsafe {
211            ffi::sqlite3_set_auxdata(
212                self.ctx,
213                arg,
214                raw.cast(),
215                Some(free_boxed_value::<AuxInner>),
216            );
217        };
218        Ok(orig)
219    }
220
221    /// Gets the auxiliary data that was associated with a given parameter via
222    /// [`set_aux`](Context::set_aux). Returns `Ok(None)` if no data has been
223    /// associated, and Ok(Some(v)) if it has. Returns an error if the
224    /// requested type does not match.
225    pub fn get_aux<T: Send + Sync + 'static>(&self, arg: c_int) -> Result<Option<Arc<T>>> {
226        let p = unsafe { ffi::sqlite3_get_auxdata(self.ctx, arg) as *const AuxInner };
227        if p.is_null() {
228            Ok(None)
229        } else {
230            let v: AuxInner = AuxInner::clone(unsafe { &*p });
231            v.downcast::<T>()
232                .map(Some)
233                .map_err(|_| Error::GetAuxWrongType)
234        }
235    }
236
237    /// Get the db connection handle via [sqlite3_context_db_handle](https://www.sqlite.org/c3ref/context_db_handle.html)
238    ///
239    /// # Safety
240    ///
241    /// This function is marked unsafe because there is a potential for other
242    /// references to the connection to be sent across threads, [see this comment](https://github.com/rusqlite/rusqlite/issues/643#issuecomment-640181213).
243    pub unsafe fn get_connection(&self) -> Result<ConnectionRef<'_>> {
244        let handle = ffi::sqlite3_context_db_handle(self.ctx);
245        Ok(ConnectionRef {
246            conn: Connection::from_handle(handle)?,
247            phantom: PhantomData,
248        })
249    }
250
251    /// Set the Subtype of an SQL function
252    #[cfg(feature = "modern_sqlite")] // 3.9.0
253    #[cfg_attr(docsrs, doc(cfg(feature = "modern_sqlite")))]
254    pub fn set_result_subtype(&self, sub_type: std::os::raw::c_uint) {
255        unsafe { ffi::sqlite3_result_subtype(self.ctx, sub_type) };
256    }
257}
258
259/// A reference to a connection handle with a lifetime bound to something.
260pub struct ConnectionRef<'ctx> {
261    // comes from Connection::from_handle(sqlite3_context_db_handle(...))
262    // and is non-owning
263    conn: Connection,
264    phantom: PhantomData<&'ctx Context<'ctx>>,
265}
266
267impl Deref for ConnectionRef<'_> {
268    type Target = Connection;
269
270    #[inline]
271    fn deref(&self) -> &Connection {
272        &self.conn
273    }
274}
275
276type AuxInner = Arc<dyn Any + Send + Sync + 'static>;
277
/// Aggregate is the callback interface for user-defined
/// aggregate function.
///
/// `A` is the type of the aggregation context and `T` is the type of the final
/// result. Implementations should be stateless: anything accumulated across
/// rows belongs in the `A` value, which is created per invocation of the
/// function rather than stored on the implementor.
///
/// A panic raised from any of these callbacks is caught and reported to
/// SQLite as an error instead of unwinding into C code.
pub trait Aggregate<A, T>
where
    A: RefUnwindSafe + UnwindSafe,
    T: ToSql,
{
    /// Initializes the aggregation context. Will be called prior to the first
    /// call to [`step()`](Aggregate::step) to set up the context for an
    /// invocation of the function. (Note: `init()` will not be called if
    /// there are no rows.)
    ///
    /// It receives the same context (and arguments) as that first `step()`
    /// call. If it returns an error, that `step()` call is skipped and the
    /// error is reported.
    fn init(&self, _: &mut Context<'_>) -> Result<A>;

    /// "step" function called once for each row in an aggregate group. May be
    /// called 0 times if there are no rows.
    ///
    /// The context carries the current row's arguments; the `A` is the value
    /// produced by [`init`](Aggregate::init), updated in place.
    fn step(&self, _: &mut Context<'_>, _: &mut A) -> Result<()>;

    /// Computes and returns the final result. Will be called exactly once for
    /// each invocation of the function. If [`step()`](Aggregate::step) was
    /// called at least once, will be given `Some(A)` (the same `A` as was
    /// created by [`init`](Aggregate::init) and given to
    /// [`step`](Aggregate::step)); if [`step()`](Aggregate::step) was not
    /// called (because the function is running against 0 rows), will be
    /// given `None`.
    ///
    /// The passed context will have no arguments.
    fn finalize(&self, _: &mut Context<'_>, _: Option<A>) -> Result<T>;
}
309
/// `WindowAggregate` is the callback interface for
/// user-defined aggregate window function.
///
/// It extends [`Aggregate`]: `init`, `step` and `finalize` keep their meaning,
/// and SQLite additionally calls [`value`](WindowAggregate::value) and
/// [`inverse`](WindowAggregate::inverse) while evaluating the window.
#[cfg(feature = "window")]
#[cfg_attr(docsrs, doc(cfg(feature = "window")))]
pub trait WindowAggregate<A, T>: Aggregate<A, T>
where
    A: RefUnwindSafe + UnwindSafe,
    T: ToSql,
{
    /// Returns the current value of the aggregate. Unlike xFinal, the
    /// implementation should not delete any context.
    ///
    /// Receives `None` while no aggregation context has been created yet,
    /// i.e. before [`init`](Aggregate::init) has produced one.
    fn value(&self, _: Option<&A>) -> Result<T>;

    /// Removes a row from the current window.
    ///
    /// The context carries the arguments SQLite passes to `xInverse`; the `A`
    /// is the same value that [`step`](Aggregate::step) updates.
    fn inverse(&self, _: &mut Context<'_>, _: &mut A) -> Result<()>;
}
326
bitflags::bitflags! {
    /// Function Flags.
    /// See [sqlite3_create_function](https://sqlite.org/c3ref/create_function.html)
    /// and [Function Flags](https://sqlite.org/c3ref/c_deterministic.html) for details.
    ///
    /// Flags are combined with `|` and their `bits()` are passed to SQLite as-is
    /// when a function is registered.
    #[repr(C)]
    pub struct FunctionFlags: ::std::os::raw::c_int {
        /// Specifies UTF-8 as the text encoding this SQL function prefers for its parameters.
        const SQLITE_UTF8     = ffi::SQLITE_UTF8;
        /// Specifies UTF-16 using little-endian byte order as the text encoding this SQL function prefers for its parameters.
        const SQLITE_UTF16LE  = ffi::SQLITE_UTF16LE;
        /// Specifies UTF-16 using big-endian byte order as the text encoding this SQL function prefers for its parameters.
        const SQLITE_UTF16BE  = ffi::SQLITE_UTF16BE;
        /// Specifies UTF-16 using native byte order as the text encoding this SQL function prefers for its parameters.
        const SQLITE_UTF16    = ffi::SQLITE_UTF16;
        /// Means that the function always gives the same output when the input parameters are the same.
        const SQLITE_DETERMINISTIC = ffi::SQLITE_DETERMINISTIC; // 3.8.3
        // NOTE(review): the three flags below are literals rather than `ffi::`
        // constants, presumably so they exist even with bindings generated for
        // SQLite versions before 3.30/3.31. They should equal
        // SQLITE_DIRECTONLY, SQLITE_SUBTYPE and SQLITE_INNOCUOUS in sqlite3.h.
        /// Means that the function may only be invoked from top-level SQL.
        const SQLITE_DIRECTONLY    = 0x0000_0008_0000; // 3.30.0
        /// Indicates to SQLite that a function may call `sqlite3_value_subtype()` to inspect the sub-types of its arguments.
        const SQLITE_SUBTYPE       = 0x0000_0010_0000; // 3.30.0
        /// Means that the function is unlikely to cause problems even if misused.
        const SQLITE_INNOCUOUS     = 0x0000_0020_0000; // 3.31.0
    }
}
351
352impl Default for FunctionFlags {
353    #[inline]
354    fn default() -> FunctionFlags {
355        FunctionFlags::SQLITE_UTF8
356    }
357}
358
359impl Connection {
360    /// Attach a user-defined scalar function to
361    /// this database connection.
362    ///
363    /// `fn_name` is the name the function will be accessible from SQL.
364    /// `n_arg` is the number of arguments to the function. Use `-1` for a
365    /// variable number. If the function always returns the same value
366    /// given the same input, `deterministic` should be `true`.
367    ///
368    /// The function will remain available until the connection is closed or
369    /// until it is explicitly removed via
370    /// [`remove_function`](Connection::remove_function).
371    ///
372    /// # Example
373    ///
374    /// ```rust
375    /// # use rusqlite::{Connection, Result};
376    /// # use rusqlite::functions::FunctionFlags;
377    /// fn scalar_function_example(db: Connection) -> Result<()> {
378    ///     db.create_scalar_function(
379    ///         "halve",
380    ///         1,
381    ///         FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
382    ///         |ctx| {
383    ///             let value = ctx.get::<f64>(0)?;
384    ///             Ok(value / 2f64)
385    ///         },
386    ///     )?;
387    ///
388    ///     let six_halved: f64 = db.query_row("SELECT halve(6)", [], |r| r.get(0))?;
389    ///     assert_eq!(six_halved, 3f64);
390    ///     Ok(())
391    /// }
392    /// ```
393    ///
394    /// # Failure
395    ///
396    /// Will return Err if the function could not be attached to the connection.
397    #[inline]
398    pub fn create_scalar_function<F, T>(
399        &self,
400        fn_name: &str,
401        n_arg: c_int,
402        flags: FunctionFlags,
403        x_func: F,
404    ) -> Result<()>
405    where
406        F: FnMut(&Context<'_>) -> Result<T> + Send + UnwindSafe + 'static,
407        T: ToSql,
408    {
409        self.db
410            .borrow_mut()
411            .create_scalar_function(fn_name, n_arg, flags, x_func)
412    }
413
414    /// Attach a user-defined aggregate function to this
415    /// database connection.
416    ///
417    /// # Failure
418    ///
419    /// Will return Err if the function could not be attached to the connection.
420    #[inline]
421    pub fn create_aggregate_function<A, D, T>(
422        &self,
423        fn_name: &str,
424        n_arg: c_int,
425        flags: FunctionFlags,
426        aggr: D,
427    ) -> Result<()>
428    where
429        A: RefUnwindSafe + UnwindSafe,
430        D: Aggregate<A, T> + 'static,
431        T: ToSql,
432    {
433        self.db
434            .borrow_mut()
435            .create_aggregate_function(fn_name, n_arg, flags, aggr)
436    }
437
438    /// Attach a user-defined aggregate window function to
439    /// this database connection.
440    ///
441    /// See `https://sqlite.org/windowfunctions.html#udfwinfunc` for more
442    /// information.
443    #[cfg(feature = "window")]
444    #[cfg_attr(docsrs, doc(cfg(feature = "window")))]
445    #[inline]
446    pub fn create_window_function<A, W, T>(
447        &self,
448        fn_name: &str,
449        n_arg: c_int,
450        flags: FunctionFlags,
451        aggr: W,
452    ) -> Result<()>
453    where
454        A: RefUnwindSafe + UnwindSafe,
455        W: WindowAggregate<A, T> + 'static,
456        T: ToSql,
457    {
458        self.db
459            .borrow_mut()
460            .create_window_function(fn_name, n_arg, flags, aggr)
461    }
462
463    /// Removes a user-defined function from this
464    /// database connection.
465    ///
466    /// `fn_name` and `n_arg` should match the name and number of arguments
467    /// given to [`create_scalar_function`](Connection::create_scalar_function)
468    /// or [`create_aggregate_function`](Connection::create_aggregate_function).
469    ///
470    /// # Failure
471    ///
472    /// Will return Err if the function could not be removed.
473    #[inline]
474    pub fn remove_function(&self, fn_name: &str, n_arg: c_int) -> Result<()> {
475        self.db.borrow_mut().remove_function(fn_name, n_arg)
476    }
477}
478
479impl InnerConnection {
480    fn create_scalar_function<F, T>(
481        &mut self,
482        fn_name: &str,
483        n_arg: c_int,
484        flags: FunctionFlags,
485        x_func: F,
486    ) -> Result<()>
487    where
488        F: FnMut(&Context<'_>) -> Result<T> + Send + UnwindSafe + 'static,
489        T: ToSql,
490    {
491        unsafe extern "C" fn call_boxed_closure<F, T>(
492            ctx: *mut sqlite3_context,
493            argc: c_int,
494            argv: *mut *mut sqlite3_value,
495        ) where
496            F: FnMut(&Context<'_>) -> Result<T>,
497            T: ToSql,
498        {
499            let r = catch_unwind(|| {
500                let boxed_f: *mut F = ffi::sqlite3_user_data(ctx).cast::<F>();
501                assert!(!boxed_f.is_null(), "Internal error - null function pointer");
502                let ctx = Context {
503                    ctx,
504                    args: slice::from_raw_parts(argv, argc as usize),
505                };
506                (*boxed_f)(&ctx)
507            });
508            let t = match r {
509                Err(_) => {
510                    report_error(ctx, &Error::UnwindingPanic);
511                    return;
512                }
513                Ok(r) => r,
514            };
515            let t = t.as_ref().map(|t| ToSql::to_sql(t));
516
517            match t {
518                Ok(Ok(ref value)) => set_result(ctx, value),
519                Ok(Err(err)) => report_error(ctx, &err),
520                Err(err) => report_error(ctx, err),
521            }
522        }
523
524        let boxed_f: *mut F = Box::into_raw(Box::new(x_func));
525        let c_name = str_to_cstring(fn_name)?;
526        let r = unsafe {
527            ffi::sqlite3_create_function_v2(
528                self.db(),
529                c_name.as_ptr(),
530                n_arg,
531                flags.bits(),
532                boxed_f.cast::<c_void>(),
533                Some(call_boxed_closure::<F, T>),
534                None,
535                None,
536                Some(free_boxed_value::<F>),
537            )
538        };
539        self.decode_result(r)
540    }
541
542    fn create_aggregate_function<A, D, T>(
543        &mut self,
544        fn_name: &str,
545        n_arg: c_int,
546        flags: FunctionFlags,
547        aggr: D,
548    ) -> Result<()>
549    where
550        A: RefUnwindSafe + UnwindSafe,
551        D: Aggregate<A, T> + 'static,
552        T: ToSql,
553    {
554        let boxed_aggr: *mut D = Box::into_raw(Box::new(aggr));
555        let c_name = str_to_cstring(fn_name)?;
556        let r = unsafe {
557            ffi::sqlite3_create_function_v2(
558                self.db(),
559                c_name.as_ptr(),
560                n_arg,
561                flags.bits(),
562                boxed_aggr.cast::<c_void>(),
563                None,
564                Some(call_boxed_step::<A, D, T>),
565                Some(call_boxed_final::<A, D, T>),
566                Some(free_boxed_value::<D>),
567            )
568        };
569        self.decode_result(r)
570    }
571
572    #[cfg(feature = "window")]
573    fn create_window_function<A, W, T>(
574        &mut self,
575        fn_name: &str,
576        n_arg: c_int,
577        flags: FunctionFlags,
578        aggr: W,
579    ) -> Result<()>
580    where
581        A: RefUnwindSafe + UnwindSafe,
582        W: WindowAggregate<A, T> + 'static,
583        T: ToSql,
584    {
585        let boxed_aggr: *mut W = Box::into_raw(Box::new(aggr));
586        let c_name = str_to_cstring(fn_name)?;
587        let r = unsafe {
588            ffi::sqlite3_create_window_function(
589                self.db(),
590                c_name.as_ptr(),
591                n_arg,
592                flags.bits(),
593                boxed_aggr.cast::<c_void>(),
594                Some(call_boxed_step::<A, W, T>),
595                Some(call_boxed_final::<A, W, T>),
596                Some(call_boxed_value::<A, W, T>),
597                Some(call_boxed_inverse::<A, W, T>),
598                Some(free_boxed_value::<W>),
599            )
600        };
601        self.decode_result(r)
602    }
603
604    fn remove_function(&mut self, fn_name: &str, n_arg: c_int) -> Result<()> {
605        let c_name = str_to_cstring(fn_name)?;
606        let r = unsafe {
607            ffi::sqlite3_create_function_v2(
608                self.db(),
609                c_name.as_ptr(),
610                n_arg,
611                ffi::SQLITE_UTF8,
612                ptr::null_mut(),
613                None,
614                None,
615                None,
616                None,
617            )
618        };
619        self.decode_result(r)
620    }
621}
622
623unsafe fn aggregate_context<A>(ctx: *mut sqlite3_context, bytes: usize) -> Option<*mut *mut A> {
624    let pac = ffi::sqlite3_aggregate_context(ctx, bytes as c_int) as *mut *mut A;
625    if pac.is_null() {
626        return None;
627    }
628    Some(pac)
629}
630
631unsafe extern "C" fn call_boxed_step<A, D, T>(
632    ctx: *mut sqlite3_context,
633    argc: c_int,
634    argv: *mut *mut sqlite3_value,
635) where
636    A: RefUnwindSafe + UnwindSafe,
637    D: Aggregate<A, T>,
638    T: ToSql,
639{
640    let pac = if let Some(pac) = aggregate_context(ctx, std::mem::size_of::<*mut A>()) {
641        pac
642    } else {
643        ffi::sqlite3_result_error_nomem(ctx);
644        return;
645    };
646
647    let r = catch_unwind(|| {
648        let boxed_aggr: *mut D = ffi::sqlite3_user_data(ctx).cast::<D>();
649        assert!(
650            !boxed_aggr.is_null(),
651            "Internal error - null aggregate pointer"
652        );
653        let mut ctx = Context {
654            ctx,
655            args: slice::from_raw_parts(argv, argc as usize),
656        };
657
658        if (*pac as *mut A).is_null() {
659            *pac = Box::into_raw(Box::new((*boxed_aggr).init(&mut ctx)?));
660        }
661
662        (*boxed_aggr).step(&mut ctx, &mut **pac)
663    });
664    let r = match r {
665        Err(_) => {
666            report_error(ctx, &Error::UnwindingPanic);
667            return;
668        }
669        Ok(r) => r,
670    };
671    match r {
672        Ok(_) => {}
673        Err(err) => report_error(ctx, &err),
674    };
675}
676
/// `xInverse` trampoline for window functions: hands the aggregation state
/// created by `call_boxed_step` to `WindowAggregate::inverse`, reporting
/// errors and panics to SQLite.
#[cfg(feature = "window")]
unsafe extern "C" fn call_boxed_inverse<A, W, T>(
    ctx: *mut sqlite3_context,
    argc: c_int,
    argv: *mut *mut sqlite3_value,
) where
    A: RefUnwindSafe + UnwindSafe,
    W: WindowAggregate<A, T>,
    T: ToSql,
{
    let pac = if let Some(pac) = aggregate_context(ctx, std::mem::size_of::<*mut A>()) {
        pac
    } else {
        ffi::sqlite3_result_error_nomem(ctx);
        return;
    };

    let r = catch_unwind(|| {
        let boxed_aggr: *mut W = ffi::sqlite3_user_data(ctx).cast::<W>();
        assert!(
            !boxed_aggr.is_null(),
            "Internal error - null aggregate pointer"
        );
        // The state is only created by `call_boxed_step` (via `init`). If it is
        // missing, fail loudly (the panic is caught and reported below) rather
        // than forming a `&mut A` from a null pointer, which would be UB.
        assert!(
            !(*pac).is_null(),
            "Internal error - null aggregate context"
        );
        let mut ctx = Context {
            ctx,
            args: slice::from_raw_parts(argv, argc as usize),
        };
        (*boxed_aggr).inverse(&mut ctx, &mut **pac)
    });
    let r = match r {
        Err(_) => {
            report_error(ctx, &Error::UnwindingPanic);
            return;
        }
        Ok(r) => r,
    };
    match r {
        Ok(_) => {}
        Err(err) => report_error(ctx, &err),
    };
}
718
719unsafe extern "C" fn call_boxed_final<A, D, T>(ctx: *mut sqlite3_context)
720where
721    A: RefUnwindSafe + UnwindSafe,
722    D: Aggregate<A, T>,
723    T: ToSql,
724{
725    // Within the xFinal callback, it is customary to set N=0 in calls to
726    // sqlite3_aggregate_context(C,N) so that no pointless memory allocations occur.
727    let a: Option<A> = match aggregate_context(ctx, 0) {
728        Some(pac) => {
729            if (*pac as *mut A).is_null() {
730                None
731            } else {
732                let a = Box::from_raw(*pac);
733                Some(*a)
734            }
735        }
736        None => None,
737    };
738
739    let r = catch_unwind(|| {
740        let boxed_aggr: *mut D = ffi::sqlite3_user_data(ctx).cast::<D>();
741        assert!(
742            !boxed_aggr.is_null(),
743            "Internal error - null aggregate pointer"
744        );
745        let mut ctx = Context { ctx, args: &mut [] };
746        (*boxed_aggr).finalize(&mut ctx, a)
747    });
748    let t = match r {
749        Err(_) => {
750            report_error(ctx, &Error::UnwindingPanic);
751            return;
752        }
753        Ok(r) => r,
754    };
755    let t = t.as_ref().map(|t| ToSql::to_sql(t));
756    match t {
757        Ok(Ok(ref value)) => set_result(ctx, value),
758        Ok(Err(err)) => report_error(ctx, &err),
759        Err(err) => report_error(ctx, err),
760    }
761}
762
/// `xValue` trampoline for window functions.
///
/// It lends the current aggregation state (if any) to `WindowAggregate::value`
/// without taking ownership of it, then sets the converted result or reports
/// the error/panic to SQLite.
#[cfg(feature = "window")]
unsafe extern "C" fn call_boxed_value<A, W, T>(ctx: *mut sqlite3_context)
where
    A: RefUnwindSafe + UnwindSafe,
    W: WindowAggregate<A, T>,
    T: ToSql,
{
    // As in xFinal, request 0 bytes so no context is allocated here; the state
    // is only borrowed, because the context slot keeps owning it.
    let a: Option<&A> = match aggregate_context::<A>(ctx, 0) {
        Some(pac) if !(*pac).is_null() => Some(&**pac),
        _ => None,
    };

    let outcome = catch_unwind(|| {
        let boxed_aggr: *mut W = ffi::sqlite3_user_data(ctx).cast::<W>();
        assert!(
            !boxed_aggr.is_null(),
            "Internal error - null aggregate pointer"
        );
        (*boxed_aggr).value(a)
    });

    let current = match outcome {
        Ok(current) => current,
        Err(_) => {
            report_error(ctx, &Error::UnwindingPanic);
            return;
        }
    };
    match current.as_ref().map(|t| ToSql::to_sql(t)) {
        Ok(Ok(ref output)) => set_result(ctx, output),
        Ok(Err(err)) => report_error(ctx, &err),
        Err(err) => report_error(ctx, err),
    }
}
806
807#[cfg(test)]
808mod test {
809    use regex::Regex;
810    use std::os::raw::c_double;
811
812    #[cfg(feature = "window")]
813    use crate::functions::WindowAggregate;
814    use crate::functions::{Aggregate, Context, FunctionFlags};
815    use crate::{Connection, Error, Result};
816
817    fn half(ctx: &Context<'_>) -> Result<c_double> {
818        assert_eq!(ctx.len(), 1, "called with unexpected number of arguments");
819        let value = ctx.get::<c_double>(0)?;
820        Ok(value / 2f64)
821    }
822
823    #[test]
824    fn test_function_half() -> Result<()> {
825        let db = Connection::open_in_memory()?;
826        db.create_scalar_function(
827            "half",
828            1,
829            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
830            half,
831        )?;
832        let result: Result<f64> = db.query_row("SELECT half(6)", [], |r| r.get(0));
833
834        assert!((3f64 - result?).abs() < f64::EPSILON);
835        Ok(())
836    }
837
838    #[test]
839    fn test_remove_function() -> Result<()> {
840        let db = Connection::open_in_memory()?;
841        db.create_scalar_function(
842            "half",
843            1,
844            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
845            half,
846        )?;
847        let result: Result<f64> = db.query_row("SELECT half(6)", [], |r| r.get(0));
848        assert!((3f64 - result?).abs() < f64::EPSILON);
849
850        db.remove_function("half", 1)?;
851        let result: Result<f64> = db.query_row("SELECT half(6)", [], |r| r.get(0));
852        assert!(result.is_err());
853        Ok(())
854    }
855
856    // This implementation of a regexp scalar function uses SQLite's auxiliary data
857    // (https://www.sqlite.org/c3ref/get_auxdata.html) to avoid recompiling the regular
858    // expression multiple times within one query.
859    fn regexp_with_auxilliary(ctx: &Context<'_>) -> Result<bool> {
860        assert_eq!(ctx.len(), 2, "called with unexpected number of arguments");
861        type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>;
862        let regexp: std::sync::Arc<Regex> = ctx
863            .get_or_create_aux(0, |vr| -> Result<_, BoxError> {
864                Ok(Regex::new(vr.as_str()?)?)
865            })?;
866
867        let is_match = {
868            let text = ctx
869                .get_raw(1)
870                .as_str()
871                .map_err(|e| Error::UserFunctionError(e.into()))?;
872
873            regexp.is_match(text)
874        };
875
876        Ok(is_match)
877    }
878
879    #[test]
880    fn test_function_regexp_with_auxilliary() -> Result<()> {
881        let db = Connection::open_in_memory()?;
882        db.execute_batch(
883            "BEGIN;
884             CREATE TABLE foo (x string);
885             INSERT INTO foo VALUES ('lisa');
886             INSERT INTO foo VALUES ('lXsi');
887             INSERT INTO foo VALUES ('lisX');
888             END;",
889        )?;
890        db.create_scalar_function(
891            "regexp",
892            2,
893            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
894            regexp_with_auxilliary,
895        )?;
896
897        let result: Result<bool> =
898            db.query_row("SELECT regexp('l.s[aeiouy]', 'lisa')", [], |r| r.get(0));
899
900        assert!(result?);
901
902        let result: Result<i64> = db.query_row(
903            "SELECT COUNT(*) FROM foo WHERE regexp('l.s[aeiouy]', x) == 1",
904            [],
905            |r| r.get(0),
906        );
907
908        assert_eq!(2, result?);
909        Ok(())
910    }
911
912    #[test]
913    fn test_varargs_function() -> Result<()> {
914        let db = Connection::open_in_memory()?;
915        db.create_scalar_function(
916            "my_concat",
917            -1,
918            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
919            |ctx| {
920                let mut ret = String::new();
921
922                for idx in 0..ctx.len() {
923                    let s = ctx.get::<String>(idx)?;
924                    ret.push_str(&s);
925                }
926
927                Ok(ret)
928            },
929        )?;
930
931        for &(expected, query) in &[
932            ("", "SELECT my_concat()"),
933            ("onetwo", "SELECT my_concat('one', 'two')"),
934            ("abc", "SELECT my_concat('a', 'b', 'c')"),
935        ] {
936            let result: String = db.query_row(query, [], |r| r.get(0))?;
937            assert_eq!(expected, result);
938        }
939        Ok(())
940    }
941
942    #[test]
943    fn test_get_aux_type_checking() -> Result<()> {
944        let db = Connection::open_in_memory()?;
945        db.create_scalar_function("example", 2, FunctionFlags::default(), |ctx| {
946            if !ctx.get::<bool>(1)? {
947                ctx.set_aux::<i64>(0, 100)?;
948            } else {
949                assert_eq!(ctx.get_aux::<String>(0), Err(Error::GetAuxWrongType));
950                assert_eq!(*ctx.get_aux::<i64>(0)?.unwrap(), 100);
951            }
952            Ok(true)
953        })?;
954
955        let res: bool = db.query_row(
956            "SELECT example(0, i) FROM (SELECT 0 as i UNION SELECT 1)",
957            [],
958            |r| r.get(0),
959        )?;
960        // Doesn't actually matter, we'll assert in the function if there's a problem.
961        assert!(res);
962        Ok(())
963    }
964
965    struct Sum;
966    struct Count;
967
968    impl Aggregate<i64, Option<i64>> for Sum {
969        fn init(&self, _: &mut Context<'_>) -> Result<i64> {
970            Ok(0)
971        }
972
973        fn step(&self, ctx: &mut Context<'_>, sum: &mut i64) -> Result<()> {
974            *sum += ctx.get::<i64>(0)?;
975            Ok(())
976        }
977
978        fn finalize(&self, _: &mut Context<'_>, sum: Option<i64>) -> Result<Option<i64>> {
979            Ok(sum)
980        }
981    }
982
983    impl Aggregate<i64, i64> for Count {
984        fn init(&self, _: &mut Context<'_>) -> Result<i64> {
985            Ok(0)
986        }
987
988        fn step(&self, _ctx: &mut Context<'_>, sum: &mut i64) -> Result<()> {
989            *sum += 1;
990            Ok(())
991        }
992
993        fn finalize(&self, _: &mut Context<'_>, sum: Option<i64>) -> Result<i64> {
994            Ok(sum.unwrap_or(0))
995        }
996    }
997
998    #[test]
999    fn test_sum() -> Result<()> {
1000        let db = Connection::open_in_memory()?;
1001        db.create_aggregate_function(
1002            "my_sum",
1003            1,
1004            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
1005            Sum,
1006        )?;
1007
1008        // sum should return NULL when given no columns (contrast with count below)
1009        let no_result = "SELECT my_sum(i) FROM (SELECT 2 AS i WHERE 1 <> 1)";
1010        let result: Option<i64> = db.query_row(no_result, [], |r| r.get(0))?;
1011        assert!(result.is_none());
1012
1013        let single_sum = "SELECT my_sum(i) FROM (SELECT 2 AS i UNION ALL SELECT 2)";
1014        let result: i64 = db.query_row(single_sum, [], |r| r.get(0))?;
1015        assert_eq!(4, result);
1016
1017        let dual_sum = "SELECT my_sum(i), my_sum(j) FROM (SELECT 2 AS i, 1 AS j UNION ALL SELECT \
1018                        2, 1)";
1019        let result: (i64, i64) = db.query_row(dual_sum, [], |r| Ok((r.get(0)?, r.get(1)?)))?;
1020        assert_eq!((4, 2), result);
1021        Ok(())
1022    }
1023
1024    #[test]
1025    fn test_count() -> Result<()> {
1026        let db = Connection::open_in_memory()?;
1027        db.create_aggregate_function(
1028            "my_count",
1029            -1,
1030            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
1031            Count,
1032        )?;
1033
1034        // count should return 0 when given no columns (contrast with sum above)
1035        let no_result = "SELECT my_count(i) FROM (SELECT 2 AS i WHERE 1 <> 1)";
1036        let result: i64 = db.query_row(no_result, [], |r| r.get(0))?;
1037        assert_eq!(result, 0);
1038
1039        let single_sum = "SELECT my_count(i) FROM (SELECT 2 AS i UNION ALL SELECT 2)";
1040        let result: i64 = db.query_row(single_sum, [], |r| r.get(0))?;
1041        assert_eq!(2, result);
1042        Ok(())
1043    }
1044
1045    #[cfg(feature = "window")]
1046    impl WindowAggregate<i64, Option<i64>> for Sum {
1047        fn inverse(&self, ctx: &mut Context<'_>, sum: &mut i64) -> Result<()> {
1048            *sum -= ctx.get::<i64>(0)?;
1049            Ok(())
1050        }
1051
1052        fn value(&self, sum: Option<&i64>) -> Result<Option<i64>> {
1053            Ok(sum.copied())
1054        }
1055    }
1056
1057    #[test]
1058    #[cfg(feature = "window")]
1059    fn test_window() -> Result<()> {
1060        use fallible_iterator::FallibleIterator;
1061
1062        let db = Connection::open_in_memory()?;
1063        db.create_window_function(
1064            "sumint",
1065            1,
1066            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
1067            Sum,
1068        )?;
1069        db.execute_batch(
1070            "CREATE TABLE t3(x, y);
1071             INSERT INTO t3 VALUES('a', 4),
1072                     ('b', 5),
1073                     ('c', 3),
1074                     ('d', 8),
1075                     ('e', 1);",
1076        )?;
1077
1078        let mut stmt = db.prepare(
1079            "SELECT x, sumint(y) OVER (
1080                   ORDER BY x ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING
1081                 ) AS sum_y
1082                 FROM t3 ORDER BY x;",
1083        )?;
1084
1085        let results: Vec<(String, i64)> = stmt
1086            .query([])?
1087            .map(|row| Ok((row.get("x")?, row.get("sum_y")?)))
1088            .collect()?;
1089        let expected = vec![
1090            ("a".to_owned(), 9),
1091            ("b".to_owned(), 12),
1092            ("c".to_owned(), 16),
1093            ("d".to_owned(), 12),
1094            ("e".to_owned(), 9),
1095        ];
1096        assert_eq!(expected, results);
1097        Ok(())
1098    }
1099}