Skip to main content

sqlite_provider/
function.rs

1use core::ffi::c_void;
2use core::mem::MaybeUninit;
3use core::ptr::NonNull;
4
5use crate::connection::Connection;
6use crate::error::{Error, Result};
7use crate::provider::{FeatureSet, FunctionFlags, Sqlite3Api, ValueType};
8use crate::value::{Value, ValueRef};
9
10/// Context wrapper passed to user-defined functions.
11pub struct Context<'p, P: Sqlite3Api> {
12    api: &'p P,
13    ctx: NonNull<P::Context>,
14}
15
16impl<'p, P: Sqlite3Api> Context<'p, P> {
17    pub(crate) fn new(api: &'p P, ctx: NonNull<P::Context>) -> Self {
18        Self { api, ctx }
19    }
20
21    /// Set NULL result.
22    pub fn result_null(&self) {
23        unsafe { self.api.result_null(self.ctx) }
24    }
25
26    /// Set integer result.
27    pub fn result_int64(&self, v: i64) {
28        unsafe { self.api.result_int64(self.ctx, v) }
29    }
30
31    /// Set floating result.
32    pub fn result_double(&self, v: f64) {
33        unsafe { self.api.result_double(self.ctx, v) }
34    }
35
36    /// Set text result (provider must copy or retain the bytes as needed).
37    pub fn result_text(&self, v: &str) {
38        unsafe { self.api.result_text(self.ctx, v) }
39    }
40
41    /// Set blob result (provider must copy or retain the bytes as needed).
42    pub fn result_blob(&self, v: &[u8]) {
43        unsafe { self.api.result_blob(self.ctx, v) }
44    }
45
46    /// Set error result.
47    pub fn result_error(&self, msg: &str) {
48        unsafe { self.api.result_error(self.ctx, msg) }
49    }
50
51    /// Set result from an owned `Value`.
52    pub fn result_value(&self, value: Value) {
53        match value {
54            Value::Null => self.result_null(),
55            Value::Integer(v) => self.result_int64(v),
56            Value::Float(v) => self.result_double(v),
57            Value::Text(v) => self.result_text(&v),
58            Value::Blob(v) => self.result_blob(&v),
59        }
60    }
61}
62
63const INLINE_ARGS: usize = 8;
64
65struct ArgBuffer<'a> {
66    inline: [MaybeUninit<ValueRef<'a>>; INLINE_ARGS],
67    len: usize,
68    heap: Option<Vec<ValueRef<'a>>>,
69}
70
71impl<'a> ArgBuffer<'a> {
72    fn new(argc: usize) -> Self {
73        let inline = unsafe { MaybeUninit::<[MaybeUninit<ValueRef<'a>>; INLINE_ARGS]>::uninit().assume_init() };
74        let heap = if argc > INLINE_ARGS { Some(Vec::with_capacity(argc)) } else { None };
75        Self { inline, len: 0, heap }
76    }
77
78    fn push(&mut self, value: ValueRef<'a>) {
79        if let Some(heap) = &mut self.heap {
80            heap.push(value);
81            return;
82        }
83        let slot = &mut self.inline[self.len];
84        slot.write(value);
85        self.len += 1;
86    }
87
88    fn as_slice(&self) -> &[ValueRef<'a>] {
89        if let Some(heap) = &self.heap {
90            return heap.as_slice();
91        }
92        unsafe { core::slice::from_raw_parts(self.inline.as_ptr() as *const ValueRef<'a>, self.len) }
93    }
94}
95
96unsafe fn value_ref_from_raw<'a, P: Sqlite3Api>(api: &P, value: NonNull<P::Value>) -> ValueRef<'a> {
97    match unsafe { api.value_type(value) } {
98        ValueType::Null => ValueRef::Null,
99        ValueType::Integer => ValueRef::Integer(unsafe { api.value_int64(value) }),
100        ValueType::Float => ValueRef::Float(unsafe { api.value_double(value) }),
101        ValueType::Text => unsafe { ValueRef::from_raw_text(api.value_text(value)) },
102        ValueType::Blob => unsafe { ValueRef::from_raw_blob(api.value_blob(value)) },
103    }
104}
105
106fn args_from_raw<'a, P: Sqlite3Api>(api: &P, argc: i32, argv: *mut *mut P::Value) -> ArgBuffer<'a> {
107    let argc = if argc < 0 { 0 } else { argc as usize };
108    let mut args = ArgBuffer::new(argc);
109    if argc == 0 || argv.is_null() {
110        return args;
111    }
112    let values = unsafe { core::slice::from_raw_parts(argv, argc) };
113    for value in values {
114        if let Some(ptr) = NonNull::new(*value) {
115            let arg = unsafe { value_ref_from_raw(api, ptr) };
116            args.push(arg);
117        } else {
118            args.push(ValueRef::Null);
119        }
120    }
121    args
122}
123
124fn set_error<P: Sqlite3Api>(ctx: &Context<'_, P>, err: &Error) {
125    let msg = err.message.as_deref().unwrap_or("sqlite function error");
126    ctx.result_error(msg);
127}
128
129struct ScalarState<P: Sqlite3Api, F> {
130    api: *const P,
131    func: F,
132}
133
134extern "C" fn scalar_trampoline<P, F>(
135    ctx: *mut P::Context,
136    argc: i32,
137    argv: *mut *mut P::Value,
138) where
139    P: Sqlite3Api,
140    F: for<'a> FnMut(&Context<'a, P>, &[ValueRef<'a>]) -> Result<Value> + Send + 'static,
141{
142    let ctx = match NonNull::new(ctx) {
143        Some(ctx) => ctx,
144        None => return,
145    };
146    let user_data = unsafe { P::user_data(ctx) };
147    if user_data.is_null() {
148        return;
149    }
150    let state = unsafe { &mut *(user_data as *mut ScalarState<P, F>) };
151    let api = unsafe { &*state.api };
152    let context = Context { api, ctx };
153    let args = args_from_raw(api, argc, argv);
154    let out = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
155        (state.func)(&context, args.as_slice())
156    }));
157    match out {
158        Ok(Ok(value)) => context.result_value(value),
159        Ok(Err(err)) => set_error(&context, &err),
160        Err(_) => context.result_error("panic in sqlite function"),
161    }
162}
163
164struct AggregateState<P: Sqlite3Api, T, Init, Step, Final> {
165    api: *const P,
166    init: Init,
167    step: Step,
168    final_fn: Final,
169    _marker: core::marker::PhantomData<T>,
170}
171
/// Per-group accumulator slot placed inside the provider's aggregate-context memory.
///
/// `initialized` tracks whether `value` holds a live `T`. A freshly allocated slot is
/// expected to read as `initialized == false`. This presumably relies on the provider
/// zero-filling new aggregate contexts, as `sqlite3_aggregate_context` does; confirm for
/// other providers.
struct AggCell<T> {
    /// Accumulator storage; only valid while `initialized` is true.
    value: MaybeUninit<T>,
    /// Whether `value` currently holds an initialized `T`.
    initialized: bool,
}
176
177unsafe fn get_agg_cell<P: Sqlite3Api, T>(
178    api: &P,
179    ctx: NonNull<P::Context>,
180    allocate: bool,
181) -> *mut AggCell<T> {
182    let bytes = if allocate { core::mem::size_of::<AggCell<T>>() } else { 0 };
183    unsafe { api.aggregate_context(ctx, bytes) as *mut AggCell<T> }
184}
185
186extern "C" fn aggregate_step_trampoline<P, T, Init, Step, Final>(
187    ctx: *mut P::Context,
188    argc: i32,
189    argv: *mut *mut P::Value,
190) where
191    P: Sqlite3Api,
192    T: Send + 'static,
193    Init: Fn() -> T + Send + 'static,
194    Step: for<'a> FnMut(&Context<'a, P>, &mut T, &[ValueRef<'a>]) -> Result<()> + Send + 'static,
195    Final: for<'a> FnMut(&Context<'a, P>, T) -> Result<Value> + Send + 'static,
196{
197    let ctx = match NonNull::new(ctx) {
198        Some(ctx) => ctx,
199        None => return,
200    };
201    let user_data = unsafe { P::user_data(ctx) };
202    if user_data.is_null() {
203        return;
204    }
205    let state = unsafe { &mut *(user_data as *mut AggregateState<P, T, Init, Step, Final>) };
206    let api = unsafe { &*state.api };
207    let context = Context { api, ctx };
208    let cell = unsafe { get_agg_cell::<P, T>(api, ctx, true) };
209    if cell.is_null() {
210        context.result_error("sqlite aggregate no memory");
211        return;
212    }
213    let cell = unsafe { &mut *cell };
214    if !cell.initialized {
215        cell.value.write((state.init)());
216        cell.initialized = true;
217    }
218    let args = args_from_raw(api, argc, argv);
219    let out = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
220        let value = unsafe { &mut *cell.value.as_mut_ptr() };
221        (state.step)(&context, value, args.as_slice())
222    }));
223    match out {
224        Ok(Ok(())) => {}
225        Ok(Err(err)) => set_error(&context, &err),
226        Err(_) => context.result_error("panic in sqlite aggregate"),
227    }
228}
229
230extern "C" fn aggregate_final_trampoline<P, T, Init, Step, Final>(ctx: *mut P::Context)
231where
232    P: Sqlite3Api,
233    T: Send + 'static,
234    Init: Fn() -> T + Send + 'static,
235    Step: for<'a> FnMut(&Context<'a, P>, &mut T, &[ValueRef<'a>]) -> Result<()> + Send + 'static,
236    Final: for<'a> FnMut(&Context<'a, P>, T) -> Result<Value> + Send + 'static,
237{
238    let ctx = match NonNull::new(ctx) {
239        Some(ctx) => ctx,
240        None => return,
241    };
242    let user_data = unsafe { P::user_data(ctx) };
243    if user_data.is_null() {
244        return;
245    }
246    let state = unsafe { &mut *(user_data as *mut AggregateState<P, T, Init, Step, Final>) };
247    let api = unsafe { &*state.api };
248    let context = Context { api, ctx };
249    let cell = unsafe { get_agg_cell::<P, T>(api, ctx, false) };
250    if cell.is_null() {
251        context.result_null();
252        return;
253    }
254    let cell = unsafe { &mut *cell };
255    if !cell.initialized {
256        context.result_null();
257        return;
258    }
259    let value = unsafe { cell.value.assume_init_read() };
260    cell.initialized = false;
261    let out = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
262        (state.final_fn)(&context, value)
263    }));
264    match out {
265        Ok(Ok(result)) => context.result_value(result),
266        Ok(Err(err)) => set_error(&context, &err),
267        Err(_) => context.result_error("panic in sqlite aggregate final"),
268    }
269}
270
271struct WindowState<P: Sqlite3Api, T, Init, Step, Inverse, ValueFn, Final> {
272    api: *const P,
273    init: Init,
274    step: Step,
275    inverse: Inverse,
276    value_fn: ValueFn,
277    final_fn: Final,
278    _marker: core::marker::PhantomData<T>,
279}
280
281extern "C" fn window_step_trampoline<P, T, Init, Step, Inverse, ValueFn, Final>(
282    ctx: *mut P::Context,
283    argc: i32,
284    argv: *mut *mut P::Value,
285) where
286    P: Sqlite3Api,
287    T: Send + 'static,
288    Init: Fn() -> T + Send + 'static,
289    Step: for<'a> FnMut(&Context<'a, P>, &mut T, &[ValueRef<'a>]) -> Result<()> + Send + 'static,
290    Inverse: for<'a> FnMut(&Context<'a, P>, &mut T, &[ValueRef<'a>]) -> Result<()> + Send + 'static,
291    ValueFn: for<'a> FnMut(&Context<'a, P>, &mut T) -> Result<Value> + Send + 'static,
292    Final: for<'a> FnMut(&Context<'a, P>, T) -> Result<Value> + Send + 'static,
293{
294    let ctx = match NonNull::new(ctx) {
295        Some(ctx) => ctx,
296        None => return,
297    };
298    let user_data = unsafe { P::user_data(ctx) };
299    if user_data.is_null() {
300        return;
301    }
302    let state = unsafe { &mut *(user_data as *mut WindowState<P, T, Init, Step, Inverse, ValueFn, Final>) };
303    let api = unsafe { &*state.api };
304    let context = Context { api, ctx };
305    let cell = unsafe { get_agg_cell::<P, T>(api, ctx, true) };
306    if cell.is_null() {
307        context.result_error("sqlite window no memory");
308        return;
309    }
310    let cell = unsafe { &mut *cell };
311    if !cell.initialized {
312        cell.value.write((state.init)());
313        cell.initialized = true;
314    }
315    let args = args_from_raw(api, argc, argv);
316    let out = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
317        let value = unsafe { &mut *cell.value.as_mut_ptr() };
318        (state.step)(&context, value, args.as_slice())
319    }));
320    match out {
321        Ok(Ok(())) => {}
322        Ok(Err(err)) => set_error(&context, &err),
323        Err(_) => context.result_error("panic in sqlite window step"),
324    }
325}
326
327extern "C" fn window_inverse_trampoline<P, T, Init, Step, Inverse, ValueFn, Final>(
328    ctx: *mut P::Context,
329    argc: i32,
330    argv: *mut *mut P::Value,
331) where
332    P: Sqlite3Api,
333    T: Send + 'static,
334    Init: Fn() -> T + Send + 'static,
335    Step: for<'a> FnMut(&Context<'a, P>, &mut T, &[ValueRef<'a>]) -> Result<()> + Send + 'static,
336    Inverse: for<'a> FnMut(&Context<'a, P>, &mut T, &[ValueRef<'a>]) -> Result<()> + Send + 'static,
337    ValueFn: for<'a> FnMut(&Context<'a, P>, &mut T) -> Result<Value> + Send + 'static,
338    Final: for<'a> FnMut(&Context<'a, P>, T) -> Result<Value> + Send + 'static,
339{
340    let ctx = match NonNull::new(ctx) {
341        Some(ctx) => ctx,
342        None => return,
343    };
344    let user_data = unsafe { P::user_data(ctx) };
345    if user_data.is_null() {
346        return;
347    }
348    let state = unsafe { &mut *(user_data as *mut WindowState<P, T, Init, Step, Inverse, ValueFn, Final>) };
349    let api = unsafe { &*state.api };
350    let context = Context { api, ctx };
351    let cell = unsafe { get_agg_cell::<P, T>(api, ctx, true) };
352    if cell.is_null() {
353        context.result_error("sqlite window no memory");
354        return;
355    }
356    let cell = unsafe { &mut *cell };
357    if !cell.initialized {
358        cell.value.write((state.init)());
359        cell.initialized = true;
360    }
361    let args = args_from_raw(api, argc, argv);
362    let out = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
363        let value = unsafe { &mut *cell.value.as_mut_ptr() };
364        (state.inverse)(&context, value, args.as_slice())
365    }));
366    match out {
367        Ok(Ok(())) => {}
368        Ok(Err(err)) => set_error(&context, &err),
369        Err(_) => context.result_error("panic in sqlite window inverse"),
370    }
371}
372
373extern "C" fn window_value_trampoline<P, T, Init, Step, Inverse, ValueFn, Final>(
374    ctx: *mut P::Context,
375) where
376    P: Sqlite3Api,
377    T: Send + 'static,
378    Init: Fn() -> T + Send + 'static,
379    Step: for<'a> FnMut(&Context<'a, P>, &mut T, &[ValueRef<'a>]) -> Result<()> + Send + 'static,
380    Inverse: for<'a> FnMut(&Context<'a, P>, &mut T, &[ValueRef<'a>]) -> Result<()> + Send + 'static,
381    ValueFn: for<'a> FnMut(&Context<'a, P>, &mut T) -> Result<Value> + Send + 'static,
382    Final: for<'a> FnMut(&Context<'a, P>, T) -> Result<Value> + Send + 'static,
383{
384    let ctx = match NonNull::new(ctx) {
385        Some(ctx) => ctx,
386        None => return,
387    };
388    let user_data = unsafe { P::user_data(ctx) };
389    if user_data.is_null() {
390        return;
391    }
392    let state = unsafe { &mut *(user_data as *mut WindowState<P, T, Init, Step, Inverse, ValueFn, Final>) };
393    let api = unsafe { &*state.api };
394    let context = Context { api, ctx };
395    let cell = unsafe { get_agg_cell::<P, T>(api, ctx, false) };
396    if cell.is_null() {
397        context.result_null();
398        return;
399    }
400    let cell = unsafe { &mut *cell };
401    if !cell.initialized {
402        context.result_null();
403        return;
404    }
405    let out = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
406        let value = unsafe { &mut *cell.value.as_mut_ptr() };
407        (state.value_fn)(&context, value)
408    }));
409    match out {
410        Ok(Ok(result)) => context.result_value(result),
411        Ok(Err(err)) => set_error(&context, &err),
412        Err(_) => context.result_error("panic in sqlite window value"),
413    }
414}
415
416extern "C" fn window_final_trampoline<P, T, Init, Step, Inverse, ValueFn, Final>(
417    ctx: *mut P::Context,
418) where
419    P: Sqlite3Api,
420    T: Send + 'static,
421    Init: Fn() -> T + Send + 'static,
422    Step: for<'a> FnMut(&Context<'a, P>, &mut T, &[ValueRef<'a>]) -> Result<()> + Send + 'static,
423    Inverse: for<'a> FnMut(&Context<'a, P>, &mut T, &[ValueRef<'a>]) -> Result<()> + Send + 'static,
424    ValueFn: for<'a> FnMut(&Context<'a, P>, &mut T) -> Result<Value> + Send + 'static,
425    Final: for<'a> FnMut(&Context<'a, P>, T) -> Result<Value> + Send + 'static,
426{
427    let ctx = match NonNull::new(ctx) {
428        Some(ctx) => ctx,
429        None => return,
430    };
431    let user_data = unsafe { P::user_data(ctx) };
432    if user_data.is_null() {
433        return;
434    }
435    let state = unsafe { &mut *(user_data as *mut WindowState<P, T, Init, Step, Inverse, ValueFn, Final>) };
436    let api = unsafe { &*state.api };
437    let context = Context { api, ctx };
438    let cell = unsafe { get_agg_cell::<P, T>(api, ctx, false) };
439    if cell.is_null() {
440        context.result_null();
441        return;
442    }
443    let cell = unsafe { &mut *cell };
444    if !cell.initialized {
445        context.result_null();
446        return;
447    }
448    let value = unsafe { cell.value.assume_init_read() };
449    cell.initialized = false;
450    let out = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
451        (state.final_fn)(&context, value)
452    }));
453    match out {
454        Ok(Ok(result)) => context.result_value(result),
455        Ok(Err(err)) => set_error(&context, &err),
456        Err(_) => context.result_error("panic in sqlite window final"),
457    }
458}
459
/// Destructor handed to the provider together with a `Box::into_raw`-leaked user-data
/// pointer.
///
/// Rebuilds the `Box<T>` and drops it; a null pointer is ignored.
extern "C" fn drop_boxed<T>(ptr: *mut c_void) {
    if let Some(owned) = NonNull::new(ptr as *mut T) {
        // SAFETY: non-null pointers given to this destructor come from `Box::<T>::into_raw`
        // in the registration functions; the provider is expected to call it only once.
        let boxed = unsafe { Box::from_raw(owned.as_ptr()) };
        drop(boxed);
    }
}
465
466impl<'p, P: Sqlite3Api> Connection<'p, P> {
467    /// Register a scalar function (xFunc).
468    pub fn create_scalar_function<F>(
469        &self,
470        name: &str,
471        n_args: i32,
472        func: F,
473    ) -> Result<()>
474    where
475        F: for<'a> FnMut(&Context<'a, P>, &[ValueRef<'a>]) -> Result<Value> + Send + 'static,
476    {
477        if !self.api.feature_set().contains(FeatureSet::CREATE_FUNCTION_V2) {
478            return Err(Error::feature_unavailable("create_function_v2 unsupported"));
479        }
480        let state = Box::new(ScalarState { api: self.api as *const P, func });
481        let user_data = Box::into_raw(state) as *mut c_void;
482        unsafe {
483            self.api.create_function_v2(
484                self.db,
485                name,
486                n_args,
487                FunctionFlags::empty(),
488                Some(scalar_trampoline::<P, F>),
489                None,
490                None,
491                user_data,
492                Some(drop_boxed::<ScalarState<P, F>>),
493            )
494        }
495    }
496
497    /// Register an aggregate function (xStep/xFinal).
498    pub fn create_aggregate_function<T, Init, Step, Final>(
499        &self,
500        name: &str,
501        n_args: i32,
502        init: Init,
503        step: Step,
504        final_fn: Final,
505    ) -> Result<()>
506    where
507        T: Send + 'static,
508        Init: Fn() -> T + Send + 'static,
509        Step: for<'a> FnMut(&Context<'a, P>, &mut T, &[ValueRef<'a>]) -> Result<()> + Send + 'static,
510        Final: for<'a> FnMut(&Context<'a, P>, T) -> Result<Value> + Send + 'static,
511    {
512        if !self.api.feature_set().contains(FeatureSet::CREATE_FUNCTION_V2) {
513            return Err(Error::feature_unavailable("create_function_v2 unsupported"));
514        }
515        let state = Box::new(AggregateState::<P, T, Init, Step, Final> {
516            api: self.api as *const P,
517            init,
518            step,
519            final_fn,
520            _marker: core::marker::PhantomData,
521        });
522        let user_data = Box::into_raw(state) as *mut c_void;
523        unsafe {
524            self.api.create_function_v2(
525                self.db,
526                name,
527                n_args,
528                FunctionFlags::empty(),
529                None,
530                Some(aggregate_step_trampoline::<P, T, Init, Step, Final>),
531                Some(aggregate_final_trampoline::<P, T, Init, Step, Final>),
532                user_data,
533                Some(drop_boxed::<AggregateState<P, T, Init, Step, Final>>),
534            )
535        }
536    }
537
538    /// Register a window function (xStep/xInverse/xValue/xFinal).
539    #[allow(clippy::too_many_arguments)]
540    pub fn create_window_function<T, Init, Step, Inverse, ValueFn, Final>(
541        &self,
542        name: &str,
543        n_args: i32,
544        init: Init,
545        step: Step,
546        inverse: Inverse,
547        value_fn: ValueFn,
548        final_fn: Final,
549    ) -> Result<()>
550    where
551        T: Send + 'static,
552        Init: Fn() -> T + Send + 'static,
553        Step: for<'a> FnMut(&Context<'a, P>, &mut T, &[ValueRef<'a>]) -> Result<()> + Send + 'static,
554        Inverse: for<'a> FnMut(&Context<'a, P>, &mut T, &[ValueRef<'a>]) -> Result<()> + Send + 'static,
555        ValueFn: for<'a> FnMut(&Context<'a, P>, &mut T) -> Result<Value> + Send + 'static,
556        Final: for<'a> FnMut(&Context<'a, P>, T) -> Result<Value> + Send + 'static,
557    {
558        if !self.api.feature_set().contains(FeatureSet::WINDOW_FUNCTIONS) {
559            return Err(Error::feature_unavailable("window functions unsupported"));
560        }
561        let state = Box::new(WindowState::<P, T, Init, Step, Inverse, ValueFn, Final> {
562            api: self.api as *const P,
563            init,
564            step,
565            inverse,
566            value_fn,
567            final_fn,
568            _marker: core::marker::PhantomData,
569        });
570        let user_data = Box::into_raw(state) as *mut c_void;
571        unsafe {
572            self.api.create_window_function(
573                self.db,
574                name,
575                n_args,
576                FunctionFlags::empty(),
577                Some(window_step_trampoline::<P, T, Init, Step, Inverse, ValueFn, Final>),
578                Some(window_final_trampoline::<P, T, Init, Step, Inverse, ValueFn, Final>),
579                Some(window_value_trampoline::<P, T, Init, Step, Inverse, ValueFn, Final>),
580                Some(window_inverse_trampoline::<P, T, Init, Step, Inverse, ValueFn, Final>),
581                user_data,
582                Some(drop_boxed::<WindowState<P, T, Init, Step, Inverse, ValueFn, Final>>),
583            )
584        }
585    }
586}
587
#[cfg(test)]
mod tests {
    use super::{ArgBuffer, INLINE_ARGS};
    use crate::value::ValueRef;

    /// Values that fit inline are returned in insertion order.
    #[test]
    fn arg_buffer_inline() {
        let mut buf = ArgBuffer::new(2);
        buf.push(ValueRef::Integer(1));
        buf.push(ValueRef::Integer(2));
        assert_eq!(buf.as_slice(), &[ValueRef::Integer(1), ValueRef::Integer(2)]);
    }

    /// A buffer with nothing pushed yields an empty slice.
    #[test]
    fn arg_buffer_empty() {
        let buf = ArgBuffer::new(0);
        assert!(buf.as_slice().is_empty());
    }

    /// More than `INLINE_ARGS` expected values are stored on the heap, in order.
    #[test]
    fn arg_buffer_heap() {
        let count = INLINE_ARGS + 3;
        let mut buf = ArgBuffer::new(count);
        for i in 0..count {
            buf.push(ValueRef::Integer(i as i64));
        }
        let values = buf.as_slice();
        assert_eq!(values.len(), count);
        for (i, v) in values.iter().enumerate() {
            assert_eq!(v, &ValueRef::Integer(i as i64));
        }
    }
}