1use core::ffi::c_void;
2use core::mem::MaybeUninit;
3use core::ptr::NonNull;
4
5use crate::connection::Connection;
6use crate::error::{Error, Result};
7use crate::provider::{FeatureSet, FunctionFlags, Sqlite3Api, ValueType};
8use crate::value::{Value, ValueRef};
9
10pub struct Context<'p, P: Sqlite3Api> {
12 api: &'p P,
13 ctx: NonNull<P::Context>,
14}
15
16impl<'p, P: Sqlite3Api> Context<'p, P> {
17 pub(crate) fn new(api: &'p P, ctx: NonNull<P::Context>) -> Self {
18 Self { api, ctx }
19 }
20
21 pub fn result_null(&self) {
23 unsafe { self.api.result_null(self.ctx) }
24 }
25
26 pub fn result_int64(&self, v: i64) {
28 unsafe { self.api.result_int64(self.ctx, v) }
29 }
30
31 pub fn result_double(&self, v: f64) {
33 unsafe { self.api.result_double(self.ctx, v) }
34 }
35
36 pub fn result_text(&self, v: &str) {
38 unsafe { self.api.result_text(self.ctx, v) }
39 }
40
41 pub fn result_blob(&self, v: &[u8]) {
43 unsafe { self.api.result_blob(self.ctx, v) }
44 }
45
46 pub fn result_error(&self, msg: &str) {
48 unsafe { self.api.result_error(self.ctx, msg) }
49 }
50
51 pub fn result_value(&self, value: Value) {
53 match value {
54 Value::Null => self.result_null(),
55 Value::Integer(v) => self.result_int64(v),
56 Value::Float(v) => self.result_double(v),
57 Value::Text(v) => self.result_text(&v),
58 Value::Blob(v) => self.result_blob(&v),
59 }
60 }
61}
62
63const INLINE_ARGS: usize = 8;
64
65struct ArgBuffer<'a> {
66 inline: [MaybeUninit<ValueRef<'a>>; INLINE_ARGS],
67 len: usize,
68 heap: Option<Vec<ValueRef<'a>>>,
69}
70
71impl<'a> ArgBuffer<'a> {
72 fn new(argc: usize) -> Self {
73 let inline = unsafe { MaybeUninit::<[MaybeUninit<ValueRef<'a>>; INLINE_ARGS]>::uninit().assume_init() };
74 let heap = if argc > INLINE_ARGS { Some(Vec::with_capacity(argc)) } else { None };
75 Self { inline, len: 0, heap }
76 }
77
78 fn push(&mut self, value: ValueRef<'a>) {
79 if let Some(heap) = &mut self.heap {
80 heap.push(value);
81 return;
82 }
83 let slot = &mut self.inline[self.len];
84 slot.write(value);
85 self.len += 1;
86 }
87
88 fn as_slice(&self) -> &[ValueRef<'a>] {
89 if let Some(heap) = &self.heap {
90 return heap.as_slice();
91 }
92 unsafe { core::slice::from_raw_parts(self.inline.as_ptr() as *const ValueRef<'a>, self.len) }
93 }
94}
95
96unsafe fn value_ref_from_raw<'a, P: Sqlite3Api>(api: &P, value: NonNull<P::Value>) -> ValueRef<'a> {
97 match unsafe { api.value_type(value) } {
98 ValueType::Null => ValueRef::Null,
99 ValueType::Integer => ValueRef::Integer(unsafe { api.value_int64(value) }),
100 ValueType::Float => ValueRef::Float(unsafe { api.value_double(value) }),
101 ValueType::Text => unsafe { ValueRef::from_raw_text(api.value_text(value)) },
102 ValueType::Blob => unsafe { ValueRef::from_raw_blob(api.value_blob(value)) },
103 }
104}
105
106fn args_from_raw<'a, P: Sqlite3Api>(api: &P, argc: i32, argv: *mut *mut P::Value) -> ArgBuffer<'a> {
107 let argc = if argc < 0 { 0 } else { argc as usize };
108 let mut args = ArgBuffer::new(argc);
109 if argc == 0 || argv.is_null() {
110 return args;
111 }
112 let values = unsafe { core::slice::from_raw_parts(argv, argc) };
113 for value in values {
114 if let Some(ptr) = NonNull::new(*value) {
115 let arg = unsafe { value_ref_from_raw(api, ptr) };
116 args.push(arg);
117 } else {
118 args.push(ValueRef::Null);
119 }
120 }
121 args
122}
123
124fn set_error<P: Sqlite3Api>(ctx: &Context<'_, P>, err: &Error) {
125 let msg = err.message.as_deref().unwrap_or("sqlite function error");
126 ctx.result_error(msg);
127}
128
129struct ScalarState<P: Sqlite3Api, F> {
130 api: *const P,
131 func: F,
132}
133
134extern "C" fn scalar_trampoline<P, F>(
135 ctx: *mut P::Context,
136 argc: i32,
137 argv: *mut *mut P::Value,
138) where
139 P: Sqlite3Api,
140 F: for<'a> FnMut(&Context<'a, P>, &[ValueRef<'a>]) -> Result<Value> + Send + 'static,
141{
142 let ctx = match NonNull::new(ctx) {
143 Some(ctx) => ctx,
144 None => return,
145 };
146 let user_data = unsafe { P::user_data(ctx) };
147 if user_data.is_null() {
148 return;
149 }
150 let state = unsafe { &mut *(user_data as *mut ScalarState<P, F>) };
151 let api = unsafe { &*state.api };
152 let context = Context { api, ctx };
153 let args = args_from_raw(api, argc, argv);
154 let out = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
155 (state.func)(&context, args.as_slice())
156 }));
157 match out {
158 Ok(Ok(value)) => context.result_value(value),
159 Ok(Err(err)) => set_error(&context, &err),
160 Err(_) => context.result_error("panic in sqlite function"),
161 }
162}
163
164struct AggregateState<P: Sqlite3Api, T, Init, Step, Final> {
165 api: *const P,
166 init: Init,
167 step: Step,
168 final_fn: Final,
169 _marker: core::marker::PhantomData<T>,
170}
171
172struct AggCell<T> {
173 initialized: bool,
174 value: MaybeUninit<T>,
175}
176
177unsafe fn get_agg_cell<P: Sqlite3Api, T>(
178 api: &P,
179 ctx: NonNull<P::Context>,
180 allocate: bool,
181) -> *mut AggCell<T> {
182 let bytes = if allocate { core::mem::size_of::<AggCell<T>>() } else { 0 };
183 unsafe { api.aggregate_context(ctx, bytes) as *mut AggCell<T> }
184}
185
186extern "C" fn aggregate_step_trampoline<P, T, Init, Step, Final>(
187 ctx: *mut P::Context,
188 argc: i32,
189 argv: *mut *mut P::Value,
190) where
191 P: Sqlite3Api,
192 T: Send + 'static,
193 Init: Fn() -> T + Send + 'static,
194 Step: for<'a> FnMut(&Context<'a, P>, &mut T, &[ValueRef<'a>]) -> Result<()> + Send + 'static,
195 Final: for<'a> FnMut(&Context<'a, P>, T) -> Result<Value> + Send + 'static,
196{
197 let ctx = match NonNull::new(ctx) {
198 Some(ctx) => ctx,
199 None => return,
200 };
201 let user_data = unsafe { P::user_data(ctx) };
202 if user_data.is_null() {
203 return;
204 }
205 let state = unsafe { &mut *(user_data as *mut AggregateState<P, T, Init, Step, Final>) };
206 let api = unsafe { &*state.api };
207 let context = Context { api, ctx };
208 let cell = unsafe { get_agg_cell::<P, T>(api, ctx, true) };
209 if cell.is_null() {
210 context.result_error("sqlite aggregate no memory");
211 return;
212 }
213 let cell = unsafe { &mut *cell };
214 if !cell.initialized {
215 cell.value.write((state.init)());
216 cell.initialized = true;
217 }
218 let args = args_from_raw(api, argc, argv);
219 let out = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
220 let value = unsafe { &mut *cell.value.as_mut_ptr() };
221 (state.step)(&context, value, args.as_slice())
222 }));
223 match out {
224 Ok(Ok(())) => {}
225 Ok(Err(err)) => set_error(&context, &err),
226 Err(_) => context.result_error("panic in sqlite aggregate"),
227 }
228}
229
230extern "C" fn aggregate_final_trampoline<P, T, Init, Step, Final>(ctx: *mut P::Context)
231where
232 P: Sqlite3Api,
233 T: Send + 'static,
234 Init: Fn() -> T + Send + 'static,
235 Step: for<'a> FnMut(&Context<'a, P>, &mut T, &[ValueRef<'a>]) -> Result<()> + Send + 'static,
236 Final: for<'a> FnMut(&Context<'a, P>, T) -> Result<Value> + Send + 'static,
237{
238 let ctx = match NonNull::new(ctx) {
239 Some(ctx) => ctx,
240 None => return,
241 };
242 let user_data = unsafe { P::user_data(ctx) };
243 if user_data.is_null() {
244 return;
245 }
246 let state = unsafe { &mut *(user_data as *mut AggregateState<P, T, Init, Step, Final>) };
247 let api = unsafe { &*state.api };
248 let context = Context { api, ctx };
249 let cell = unsafe { get_agg_cell::<P, T>(api, ctx, false) };
250 if cell.is_null() {
251 context.result_null();
252 return;
253 }
254 let cell = unsafe { &mut *cell };
255 if !cell.initialized {
256 context.result_null();
257 return;
258 }
259 let value = unsafe { cell.value.assume_init_read() };
260 cell.initialized = false;
261 let out = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
262 (state.final_fn)(&context, value)
263 }));
264 match out {
265 Ok(Ok(result)) => context.result_value(result),
266 Ok(Err(err)) => set_error(&context, &err),
267 Err(_) => context.result_error("panic in sqlite aggregate final"),
268 }
269}
270
271struct WindowState<P: Sqlite3Api, T, Init, Step, Inverse, ValueFn, Final> {
272 api: *const P,
273 init: Init,
274 step: Step,
275 inverse: Inverse,
276 value_fn: ValueFn,
277 final_fn: Final,
278 _marker: core::marker::PhantomData<T>,
279}
280
281extern "C" fn window_step_trampoline<P, T, Init, Step, Inverse, ValueFn, Final>(
282 ctx: *mut P::Context,
283 argc: i32,
284 argv: *mut *mut P::Value,
285) where
286 P: Sqlite3Api,
287 T: Send + 'static,
288 Init: Fn() -> T + Send + 'static,
289 Step: for<'a> FnMut(&Context<'a, P>, &mut T, &[ValueRef<'a>]) -> Result<()> + Send + 'static,
290 Inverse: for<'a> FnMut(&Context<'a, P>, &mut T, &[ValueRef<'a>]) -> Result<()> + Send + 'static,
291 ValueFn: for<'a> FnMut(&Context<'a, P>, &mut T) -> Result<Value> + Send + 'static,
292 Final: for<'a> FnMut(&Context<'a, P>, T) -> Result<Value> + Send + 'static,
293{
294 let ctx = match NonNull::new(ctx) {
295 Some(ctx) => ctx,
296 None => return,
297 };
298 let user_data = unsafe { P::user_data(ctx) };
299 if user_data.is_null() {
300 return;
301 }
302 let state = unsafe { &mut *(user_data as *mut WindowState<P, T, Init, Step, Inverse, ValueFn, Final>) };
303 let api = unsafe { &*state.api };
304 let context = Context { api, ctx };
305 let cell = unsafe { get_agg_cell::<P, T>(api, ctx, true) };
306 if cell.is_null() {
307 context.result_error("sqlite window no memory");
308 return;
309 }
310 let cell = unsafe { &mut *cell };
311 if !cell.initialized {
312 cell.value.write((state.init)());
313 cell.initialized = true;
314 }
315 let args = args_from_raw(api, argc, argv);
316 let out = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
317 let value = unsafe { &mut *cell.value.as_mut_ptr() };
318 (state.step)(&context, value, args.as_slice())
319 }));
320 match out {
321 Ok(Ok(())) => {}
322 Ok(Err(err)) => set_error(&context, &err),
323 Err(_) => context.result_error("panic in sqlite window step"),
324 }
325}
326
327extern "C" fn window_inverse_trampoline<P, T, Init, Step, Inverse, ValueFn, Final>(
328 ctx: *mut P::Context,
329 argc: i32,
330 argv: *mut *mut P::Value,
331) where
332 P: Sqlite3Api,
333 T: Send + 'static,
334 Init: Fn() -> T + Send + 'static,
335 Step: for<'a> FnMut(&Context<'a, P>, &mut T, &[ValueRef<'a>]) -> Result<()> + Send + 'static,
336 Inverse: for<'a> FnMut(&Context<'a, P>, &mut T, &[ValueRef<'a>]) -> Result<()> + Send + 'static,
337 ValueFn: for<'a> FnMut(&Context<'a, P>, &mut T) -> Result<Value> + Send + 'static,
338 Final: for<'a> FnMut(&Context<'a, P>, T) -> Result<Value> + Send + 'static,
339{
340 let ctx = match NonNull::new(ctx) {
341 Some(ctx) => ctx,
342 None => return,
343 };
344 let user_data = unsafe { P::user_data(ctx) };
345 if user_data.is_null() {
346 return;
347 }
348 let state = unsafe { &mut *(user_data as *mut WindowState<P, T, Init, Step, Inverse, ValueFn, Final>) };
349 let api = unsafe { &*state.api };
350 let context = Context { api, ctx };
351 let cell = unsafe { get_agg_cell::<P, T>(api, ctx, true) };
352 if cell.is_null() {
353 context.result_error("sqlite window no memory");
354 return;
355 }
356 let cell = unsafe { &mut *cell };
357 if !cell.initialized {
358 cell.value.write((state.init)());
359 cell.initialized = true;
360 }
361 let args = args_from_raw(api, argc, argv);
362 let out = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
363 let value = unsafe { &mut *cell.value.as_mut_ptr() };
364 (state.inverse)(&context, value, args.as_slice())
365 }));
366 match out {
367 Ok(Ok(())) => {}
368 Ok(Err(err)) => set_error(&context, &err),
369 Err(_) => context.result_error("panic in sqlite window inverse"),
370 }
371}
372
373extern "C" fn window_value_trampoline<P, T, Init, Step, Inverse, ValueFn, Final>(
374 ctx: *mut P::Context,
375) where
376 P: Sqlite3Api,
377 T: Send + 'static,
378 Init: Fn() -> T + Send + 'static,
379 Step: for<'a> FnMut(&Context<'a, P>, &mut T, &[ValueRef<'a>]) -> Result<()> + Send + 'static,
380 Inverse: for<'a> FnMut(&Context<'a, P>, &mut T, &[ValueRef<'a>]) -> Result<()> + Send + 'static,
381 ValueFn: for<'a> FnMut(&Context<'a, P>, &mut T) -> Result<Value> + Send + 'static,
382 Final: for<'a> FnMut(&Context<'a, P>, T) -> Result<Value> + Send + 'static,
383{
384 let ctx = match NonNull::new(ctx) {
385 Some(ctx) => ctx,
386 None => return,
387 };
388 let user_data = unsafe { P::user_data(ctx) };
389 if user_data.is_null() {
390 return;
391 }
392 let state = unsafe { &mut *(user_data as *mut WindowState<P, T, Init, Step, Inverse, ValueFn, Final>) };
393 let api = unsafe { &*state.api };
394 let context = Context { api, ctx };
395 let cell = unsafe { get_agg_cell::<P, T>(api, ctx, false) };
396 if cell.is_null() {
397 context.result_null();
398 return;
399 }
400 let cell = unsafe { &mut *cell };
401 if !cell.initialized {
402 context.result_null();
403 return;
404 }
405 let out = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
406 let value = unsafe { &mut *cell.value.as_mut_ptr() };
407 (state.value_fn)(&context, value)
408 }));
409 match out {
410 Ok(Ok(result)) => context.result_value(result),
411 Ok(Err(err)) => set_error(&context, &err),
412 Err(_) => context.result_error("panic in sqlite window value"),
413 }
414}
415
416extern "C" fn window_final_trampoline<P, T, Init, Step, Inverse, ValueFn, Final>(
417 ctx: *mut P::Context,
418) where
419 P: Sqlite3Api,
420 T: Send + 'static,
421 Init: Fn() -> T + Send + 'static,
422 Step: for<'a> FnMut(&Context<'a, P>, &mut T, &[ValueRef<'a>]) -> Result<()> + Send + 'static,
423 Inverse: for<'a> FnMut(&Context<'a, P>, &mut T, &[ValueRef<'a>]) -> Result<()> + Send + 'static,
424 ValueFn: for<'a> FnMut(&Context<'a, P>, &mut T) -> Result<Value> + Send + 'static,
425 Final: for<'a> FnMut(&Context<'a, P>, T) -> Result<Value> + Send + 'static,
426{
427 let ctx = match NonNull::new(ctx) {
428 Some(ctx) => ctx,
429 None => return,
430 };
431 let user_data = unsafe { P::user_data(ctx) };
432 if user_data.is_null() {
433 return;
434 }
435 let state = unsafe { &mut *(user_data as *mut WindowState<P, T, Init, Step, Inverse, ValueFn, Final>) };
436 let api = unsafe { &*state.api };
437 let context = Context { api, ctx };
438 let cell = unsafe { get_agg_cell::<P, T>(api, ctx, false) };
439 if cell.is_null() {
440 context.result_null();
441 return;
442 }
443 let cell = unsafe { &mut *cell };
444 if !cell.initialized {
445 context.result_null();
446 return;
447 }
448 let value = unsafe { cell.value.assume_init_read() };
449 cell.initialized = false;
450 let out = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
451 (state.final_fn)(&context, value)
452 }));
453 match out {
454 Ok(Ok(result)) => context.result_value(result),
455 Ok(Err(err)) => set_error(&context, &err),
456 Err(_) => context.result_error("panic in sqlite window final"),
457 }
458}
459
/// C-ABI destructor that frees a `Box<T>` previously leaked with `Box::into_raw`.
///
/// It is registered as the `xDestroy` callback for function user data. A null
/// pointer is ignored.
///
/// `T`'s destructor runs user code (the captured closures). A panic from it
/// is contained with `catch_unwind` and discarded, because there is no way to
/// report it to SQLite and unwinding out of an `extern "C"` fn is UB/abort.
extern "C" fn drop_boxed<T>(ptr: *mut c_void) {
    if ptr.is_null() {
        return;
    }
    let _ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
        // SAFETY: `ptr` came from `Box::into_raw(Box<T>)` at registration and
        // is released exactly once, by this destructor.
        drop(unsafe { Box::from_raw(ptr as *mut T) });
    }));
}
465
466impl<'p, P: Sqlite3Api> Connection<'p, P> {
467 pub fn create_scalar_function<F>(
469 &self,
470 name: &str,
471 n_args: i32,
472 func: F,
473 ) -> Result<()>
474 where
475 F: for<'a> FnMut(&Context<'a, P>, &[ValueRef<'a>]) -> Result<Value> + Send + 'static,
476 {
477 if !self.api.feature_set().contains(FeatureSet::CREATE_FUNCTION_V2) {
478 return Err(Error::feature_unavailable("create_function_v2 unsupported"));
479 }
480 let state = Box::new(ScalarState { api: self.api as *const P, func });
481 let user_data = Box::into_raw(state) as *mut c_void;
482 unsafe {
483 self.api.create_function_v2(
484 self.db,
485 name,
486 n_args,
487 FunctionFlags::empty(),
488 Some(scalar_trampoline::<P, F>),
489 None,
490 None,
491 user_data,
492 Some(drop_boxed::<ScalarState<P, F>>),
493 )
494 }
495 }
496
497 pub fn create_aggregate_function<T, Init, Step, Final>(
499 &self,
500 name: &str,
501 n_args: i32,
502 init: Init,
503 step: Step,
504 final_fn: Final,
505 ) -> Result<()>
506 where
507 T: Send + 'static,
508 Init: Fn() -> T + Send + 'static,
509 Step: for<'a> FnMut(&Context<'a, P>, &mut T, &[ValueRef<'a>]) -> Result<()> + Send + 'static,
510 Final: for<'a> FnMut(&Context<'a, P>, T) -> Result<Value> + Send + 'static,
511 {
512 if !self.api.feature_set().contains(FeatureSet::CREATE_FUNCTION_V2) {
513 return Err(Error::feature_unavailable("create_function_v2 unsupported"));
514 }
515 let state = Box::new(AggregateState::<P, T, Init, Step, Final> {
516 api: self.api as *const P,
517 init,
518 step,
519 final_fn,
520 _marker: core::marker::PhantomData,
521 });
522 let user_data = Box::into_raw(state) as *mut c_void;
523 unsafe {
524 self.api.create_function_v2(
525 self.db,
526 name,
527 n_args,
528 FunctionFlags::empty(),
529 None,
530 Some(aggregate_step_trampoline::<P, T, Init, Step, Final>),
531 Some(aggregate_final_trampoline::<P, T, Init, Step, Final>),
532 user_data,
533 Some(drop_boxed::<AggregateState<P, T, Init, Step, Final>>),
534 )
535 }
536 }
537
538 #[allow(clippy::too_many_arguments)]
540 pub fn create_window_function<T, Init, Step, Inverse, ValueFn, Final>(
541 &self,
542 name: &str,
543 n_args: i32,
544 init: Init,
545 step: Step,
546 inverse: Inverse,
547 value_fn: ValueFn,
548 final_fn: Final,
549 ) -> Result<()>
550 where
551 T: Send + 'static,
552 Init: Fn() -> T + Send + 'static,
553 Step: for<'a> FnMut(&Context<'a, P>, &mut T, &[ValueRef<'a>]) -> Result<()> + Send + 'static,
554 Inverse: for<'a> FnMut(&Context<'a, P>, &mut T, &[ValueRef<'a>]) -> Result<()> + Send + 'static,
555 ValueFn: for<'a> FnMut(&Context<'a, P>, &mut T) -> Result<Value> + Send + 'static,
556 Final: for<'a> FnMut(&Context<'a, P>, T) -> Result<Value> + Send + 'static,
557 {
558 if !self.api.feature_set().contains(FeatureSet::WINDOW_FUNCTIONS) {
559 return Err(Error::feature_unavailable("window functions unsupported"));
560 }
561 let state = Box::new(WindowState::<P, T, Init, Step, Inverse, ValueFn, Final> {
562 api: self.api as *const P,
563 init,
564 step,
565 inverse,
566 value_fn,
567 final_fn,
568 _marker: core::marker::PhantomData,
569 });
570 let user_data = Box::into_raw(state) as *mut c_void;
571 unsafe {
572 self.api.create_window_function(
573 self.db,
574 name,
575 n_args,
576 FunctionFlags::empty(),
577 Some(window_step_trampoline::<P, T, Init, Step, Inverse, ValueFn, Final>),
578 Some(window_final_trampoline::<P, T, Init, Step, Inverse, ValueFn, Final>),
579 Some(window_value_trampoline::<P, T, Init, Step, Inverse, ValueFn, Final>),
580 Some(window_inverse_trampoline::<P, T, Init, Step, Inverse, ValueFn, Final>),
581 user_data,
582 Some(drop_boxed::<WindowState<P, T, Init, Step, Inverse, ValueFn, Final>>),
583 )
584 }
585 }
586}
587
#[cfg(test)]
mod tests {
    use super::{ArgBuffer, INLINE_ARGS};
    use crate::value::ValueRef;

    /// Two arguments fit inline and come back in push order.
    #[test]
    fn arg_buffer_inline() {
        let mut buf = ArgBuffer::new(2);
        buf.push(ValueRef::Integer(1));
        buf.push(ValueRef::Integer(2));
        assert_eq!(buf.as_slice(), &[ValueRef::Integer(1), ValueRef::Integer(2)]);
    }

    /// A buffer with no pushes yields an empty slice.
    #[test]
    fn arg_buffer_empty() {
        let buf = ArgBuffer::new(0);
        assert!(buf.as_slice().is_empty());
    }

    /// Exactly `INLINE_ARGS` arguments still use the inline storage.
    #[test]
    fn arg_buffer_exactly_full_inline() {
        let mut buf = ArgBuffer::new(INLINE_ARGS);
        let mut expected = Vec::new();
        for i in 0..INLINE_ARGS as i64 {
            buf.push(ValueRef::Integer(i));
            expected.push(ValueRef::Integer(i));
        }
        assert_eq!(buf.as_slice(), expected.as_slice());
    }

    /// More than `INLINE_ARGS` arguments go through the heap-backed path.
    #[test]
    fn arg_buffer_heap() {
        let argc = INLINE_ARGS + 3;
        let mut buf = ArgBuffer::new(argc);
        let mut expected = Vec::new();
        for i in 0..argc as i64 {
            buf.push(ValueRef::Integer(i));
            expected.push(ValueRef::Integer(i));
        }
        assert_eq!(buf.as_slice(), expected.as_slice());
    }
}