1use std::any::Any;
56use std::marker::PhantomData;
57use std::ops::Deref;
58use std::os::raw::{c_int, c_void};
59use std::panic::{catch_unwind, RefUnwindSafe, UnwindSafe};
60use std::ptr;
61use std::slice;
62use std::sync::Arc;
63
64use crate::ffi;
65use crate::ffi::sqlite3_context;
66use crate::ffi::sqlite3_value;
67
68use crate::context::set_result;
69use crate::types::{FromSql, FromSqlError, ToSql, ValueRef};
70
71use crate::{str_to_cstring, Connection, Error, InnerConnection, Result};
72
73unsafe fn report_error(ctx: *mut sqlite3_context, err: &Error) {
74 #[cfg(feature = "modern_sqlite")]
79 fn constraint_error_code() -> i32 {
80 ffi::SQLITE_CONSTRAINT_FUNCTION
81 }
82 #[cfg(not(feature = "modern_sqlite"))]
83 fn constraint_error_code() -> i32 {
84 ffi::SQLITE_CONSTRAINT
85 }
86
87 if let Error::SqliteFailure(ref err, ref s) = *err {
88 ffi::sqlite3_result_error_code(ctx, err.extended_code);
89 if let Some(Ok(cstr)) = s.as_ref().map(|s| str_to_cstring(s)) {
90 ffi::sqlite3_result_error(ctx, cstr.as_ptr(), -1);
91 }
92 } else {
93 ffi::sqlite3_result_error_code(ctx, constraint_error_code());
94 if let Ok(cstr) = str_to_cstring(&err.to_string()) {
95 ffi::sqlite3_result_error(ctx, cstr.as_ptr(), -1);
96 }
97 }
98}
99
/// C-ABI destructor handed to SQLite: reclaims a `Box<T>` that was leaked with
/// `Box::into_raw` and drops it.
///
/// # Safety
///
/// `p` must have been produced by `Box::<T>::into_raw` (possibly cast to
/// `*mut c_void`) and must not be used again after this call.
unsafe extern "C" fn free_boxed_value<T>(p: *mut c_void) {
    let owned: Box<T> = Box::from_raw(p as *mut T);
    std::mem::drop(owned);
}
103
104pub struct Context<'a> {
107 ctx: *mut sqlite3_context,
108 args: &'a [*mut sqlite3_value],
109}
110
111impl Context<'_> {
112 #[inline]
114 #[must_use]
115 pub fn len(&self) -> usize {
116 self.args.len()
117 }
118
119 #[inline]
121 #[must_use]
122 pub fn is_empty(&self) -> bool {
123 self.args.is_empty()
124 }
125
126 pub fn get<T: FromSql>(&self, idx: usize) -> Result<T> {
136 let arg = self.args[idx];
137 let value = unsafe { ValueRef::from_value(arg) };
138 FromSql::column_result(value).map_err(|err| match err {
139 FromSqlError::InvalidType => {
140 Error::InvalidFunctionParameterType(idx, value.data_type())
141 }
142 FromSqlError::OutOfRange(i) => Error::IntegralValueOutOfRange(idx, i),
143 FromSqlError::Other(err) => {
144 Error::FromSqlConversionFailure(idx, value.data_type(), err)
145 }
146 FromSqlError::InvalidBlobSize { .. } => {
147 Error::FromSqlConversionFailure(idx, value.data_type(), Box::new(err))
148 }
149 })
150 }
151
152 #[inline]
159 #[must_use]
160 pub fn get_raw(&self, idx: usize) -> ValueRef<'_> {
161 let arg = self.args[idx];
162 unsafe { ValueRef::from_value(arg) }
163 }
164
165 #[cfg(feature = "modern_sqlite")] #[cfg_attr(docsrs, doc(cfg(feature = "modern_sqlite")))]
173 pub fn get_subtype(&self, idx: usize) -> std::os::raw::c_uint {
174 let arg = self.args[idx];
175 unsafe { ffi::sqlite3_value_subtype(arg) }
176 }
177
178 pub fn get_or_create_aux<T, E, F>(&self, arg: c_int, func: F) -> Result<Arc<T>>
186 where
187 T: Send + Sync + 'static,
188 E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
189 F: FnOnce(ValueRef<'_>) -> Result<T, E>,
190 {
191 if let Some(v) = self.get_aux(arg)? {
192 Ok(v)
193 } else {
194 let vr = self.get_raw(arg as usize);
195 self.set_aux(
196 arg,
197 func(vr).map_err(|e| Error::UserFunctionError(e.into()))?,
198 )
199 }
200 }
201
202 pub fn set_aux<T: Send + Sync + 'static>(&self, arg: c_int, value: T) -> Result<Arc<T>> {
206 let orig: Arc<T> = Arc::new(value);
207 let inner: AuxInner = orig.clone();
208 let outer = Box::new(inner);
209 let raw: *mut AuxInner = Box::into_raw(outer);
210 unsafe {
211 ffi::sqlite3_set_auxdata(
212 self.ctx,
213 arg,
214 raw.cast(),
215 Some(free_boxed_value::<AuxInner>),
216 );
217 };
218 Ok(orig)
219 }
220
221 pub fn get_aux<T: Send + Sync + 'static>(&self, arg: c_int) -> Result<Option<Arc<T>>> {
226 let p = unsafe { ffi::sqlite3_get_auxdata(self.ctx, arg) as *const AuxInner };
227 if p.is_null() {
228 Ok(None)
229 } else {
230 let v: AuxInner = AuxInner::clone(unsafe { &*p });
231 v.downcast::<T>()
232 .map(Some)
233 .map_err(|_| Error::GetAuxWrongType)
234 }
235 }
236
237 pub unsafe fn get_connection(&self) -> Result<ConnectionRef<'_>> {
244 let handle = ffi::sqlite3_context_db_handle(self.ctx);
245 Ok(ConnectionRef {
246 conn: Connection::from_handle(handle)?,
247 phantom: PhantomData,
248 })
249 }
250
251 #[cfg(feature = "modern_sqlite")] #[cfg_attr(docsrs, doc(cfg(feature = "modern_sqlite")))]
254 pub fn set_result_subtype(&self, sub_type: std::os::raw::c_uint) {
255 unsafe { ffi::sqlite3_result_subtype(self.ctx, sub_type) };
256 }
257}
258
259pub struct ConnectionRef<'ctx> {
261 conn: Connection,
264 phantom: PhantomData<&'ctx Context<'ctx>>,
265}
266
267impl Deref for ConnectionRef<'_> {
268 type Target = Connection;
269
270 #[inline]
271 fn deref(&self) -> &Connection {
272 &self.conn
273 }
274}
275
/// Type-erased, thread-safe shared pointer stored as SQLite auxiliary data;
/// `Context::get_aux` downcasts it back to the concrete type.
type AuxInner = Arc<dyn Any + Send + Sync>;
277
278pub trait Aggregate<A, T>
284where
285 A: RefUnwindSafe + UnwindSafe,
286 T: ToSql,
287{
288 fn init(&self, _: &mut Context<'_>) -> Result<A>;
293
294 fn step(&self, _: &mut Context<'_>, _: &mut A) -> Result<()>;
297
298 fn finalize(&self, _: &mut Context<'_>, _: Option<A>) -> Result<T>;
308}
309
/// Extra callbacks that turn an [`Aggregate`] into an aggregate window function.
#[cfg(feature = "window")]
#[cfg_attr(docsrs, doc(cfg(feature = "window")))]
pub trait WindowAggregate<A, T>: Aggregate<A, T>
where
    A: RefUnwindSafe + UnwindSafe,
    T: ToSql,
{
    /// Returns the aggregate's current value without consuming the accumulator.
    ///
    /// `acc` is `None` if no row has been stepped yet.
    fn value(&self, acc: Option<&A>) -> Result<T>;

    /// Removes the current row's arguments (available through `ctx`) from `acc`,
    /// i.e. undoes a previous `step` as the window frame slides.
    fn inverse(&self, ctx: &mut Context<'_>, acc: &mut A) -> Result<()>;
}
326
327bitflags::bitflags! {
328 #[repr(C)]
332 pub struct FunctionFlags: ::std::os::raw::c_int {
333 const SQLITE_UTF8 = ffi::SQLITE_UTF8;
335 const SQLITE_UTF16LE = ffi::SQLITE_UTF16LE;
337 const SQLITE_UTF16BE = ffi::SQLITE_UTF16BE;
339 const SQLITE_UTF16 = ffi::SQLITE_UTF16;
341 const SQLITE_DETERMINISTIC = ffi::SQLITE_DETERMINISTIC; const SQLITE_DIRECTONLY = 0x0000_0008_0000; const SQLITE_SUBTYPE = 0x0000_0010_0000; const SQLITE_INNOCUOUS = 0x0000_0020_0000; }
350}
351
352impl Default for FunctionFlags {
353 #[inline]
354 fn default() -> FunctionFlags {
355 FunctionFlags::SQLITE_UTF8
356 }
357}
358
359impl Connection {
360 #[inline]
398 pub fn create_scalar_function<F, T>(
399 &self,
400 fn_name: &str,
401 n_arg: c_int,
402 flags: FunctionFlags,
403 x_func: F,
404 ) -> Result<()>
405 where
406 F: FnMut(&Context<'_>) -> Result<T> + Send + UnwindSafe + 'static,
407 T: ToSql,
408 {
409 self.db
410 .borrow_mut()
411 .create_scalar_function(fn_name, n_arg, flags, x_func)
412 }
413
414 #[inline]
421 pub fn create_aggregate_function<A, D, T>(
422 &self,
423 fn_name: &str,
424 n_arg: c_int,
425 flags: FunctionFlags,
426 aggr: D,
427 ) -> Result<()>
428 where
429 A: RefUnwindSafe + UnwindSafe,
430 D: Aggregate<A, T> + 'static,
431 T: ToSql,
432 {
433 self.db
434 .borrow_mut()
435 .create_aggregate_function(fn_name, n_arg, flags, aggr)
436 }
437
438 #[cfg(feature = "window")]
444 #[cfg_attr(docsrs, doc(cfg(feature = "window")))]
445 #[inline]
446 pub fn create_window_function<A, W, T>(
447 &self,
448 fn_name: &str,
449 n_arg: c_int,
450 flags: FunctionFlags,
451 aggr: W,
452 ) -> Result<()>
453 where
454 A: RefUnwindSafe + UnwindSafe,
455 W: WindowAggregate<A, T> + 'static,
456 T: ToSql,
457 {
458 self.db
459 .borrow_mut()
460 .create_window_function(fn_name, n_arg, flags, aggr)
461 }
462
463 #[inline]
474 pub fn remove_function(&self, fn_name: &str, n_arg: c_int) -> Result<()> {
475 self.db.borrow_mut().remove_function(fn_name, n_arg)
476 }
477}
478
479impl InnerConnection {
480 fn create_scalar_function<F, T>(
481 &mut self,
482 fn_name: &str,
483 n_arg: c_int,
484 flags: FunctionFlags,
485 x_func: F,
486 ) -> Result<()>
487 where
488 F: FnMut(&Context<'_>) -> Result<T> + Send + UnwindSafe + 'static,
489 T: ToSql,
490 {
491 unsafe extern "C" fn call_boxed_closure<F, T>(
492 ctx: *mut sqlite3_context,
493 argc: c_int,
494 argv: *mut *mut sqlite3_value,
495 ) where
496 F: FnMut(&Context<'_>) -> Result<T>,
497 T: ToSql,
498 {
499 let r = catch_unwind(|| {
500 let boxed_f: *mut F = ffi::sqlite3_user_data(ctx).cast::<F>();
501 assert!(!boxed_f.is_null(), "Internal error - null function pointer");
502 let ctx = Context {
503 ctx,
504 args: slice::from_raw_parts(argv, argc as usize),
505 };
506 (*boxed_f)(&ctx)
507 });
508 let t = match r {
509 Err(_) => {
510 report_error(ctx, &Error::UnwindingPanic);
511 return;
512 }
513 Ok(r) => r,
514 };
515 let t = t.as_ref().map(|t| ToSql::to_sql(t));
516
517 match t {
518 Ok(Ok(ref value)) => set_result(ctx, value),
519 Ok(Err(err)) => report_error(ctx, &err),
520 Err(err) => report_error(ctx, err),
521 }
522 }
523
524 let boxed_f: *mut F = Box::into_raw(Box::new(x_func));
525 let c_name = str_to_cstring(fn_name)?;
526 let r = unsafe {
527 ffi::sqlite3_create_function_v2(
528 self.db(),
529 c_name.as_ptr(),
530 n_arg,
531 flags.bits(),
532 boxed_f.cast::<c_void>(),
533 Some(call_boxed_closure::<F, T>),
534 None,
535 None,
536 Some(free_boxed_value::<F>),
537 )
538 };
539 self.decode_result(r)
540 }
541
542 fn create_aggregate_function<A, D, T>(
543 &mut self,
544 fn_name: &str,
545 n_arg: c_int,
546 flags: FunctionFlags,
547 aggr: D,
548 ) -> Result<()>
549 where
550 A: RefUnwindSafe + UnwindSafe,
551 D: Aggregate<A, T> + 'static,
552 T: ToSql,
553 {
554 let boxed_aggr: *mut D = Box::into_raw(Box::new(aggr));
555 let c_name = str_to_cstring(fn_name)?;
556 let r = unsafe {
557 ffi::sqlite3_create_function_v2(
558 self.db(),
559 c_name.as_ptr(),
560 n_arg,
561 flags.bits(),
562 boxed_aggr.cast::<c_void>(),
563 None,
564 Some(call_boxed_step::<A, D, T>),
565 Some(call_boxed_final::<A, D, T>),
566 Some(free_boxed_value::<D>),
567 )
568 };
569 self.decode_result(r)
570 }
571
572 #[cfg(feature = "window")]
573 fn create_window_function<A, W, T>(
574 &mut self,
575 fn_name: &str,
576 n_arg: c_int,
577 flags: FunctionFlags,
578 aggr: W,
579 ) -> Result<()>
580 where
581 A: RefUnwindSafe + UnwindSafe,
582 W: WindowAggregate<A, T> + 'static,
583 T: ToSql,
584 {
585 let boxed_aggr: *mut W = Box::into_raw(Box::new(aggr));
586 let c_name = str_to_cstring(fn_name)?;
587 let r = unsafe {
588 ffi::sqlite3_create_window_function(
589 self.db(),
590 c_name.as_ptr(),
591 n_arg,
592 flags.bits(),
593 boxed_aggr.cast::<c_void>(),
594 Some(call_boxed_step::<A, W, T>),
595 Some(call_boxed_final::<A, W, T>),
596 Some(call_boxed_value::<A, W, T>),
597 Some(call_boxed_inverse::<A, W, T>),
598 Some(free_boxed_value::<W>),
599 )
600 };
601 self.decode_result(r)
602 }
603
604 fn remove_function(&mut self, fn_name: &str, n_arg: c_int) -> Result<()> {
605 let c_name = str_to_cstring(fn_name)?;
606 let r = unsafe {
607 ffi::sqlite3_create_function_v2(
608 self.db(),
609 c_name.as_ptr(),
610 n_arg,
611 ffi::SQLITE_UTF8,
612 ptr::null_mut(),
613 None,
614 None,
615 None,
616 None,
617 )
618 };
619 self.decode_result(r)
620 }
621}
622
623unsafe fn aggregate_context<A>(ctx: *mut sqlite3_context, bytes: usize) -> Option<*mut *mut A> {
624 let pac = ffi::sqlite3_aggregate_context(ctx, bytes as c_int) as *mut *mut A;
625 if pac.is_null() {
626 return None;
627 }
628 Some(pac)
629}
630
631unsafe extern "C" fn call_boxed_step<A, D, T>(
632 ctx: *mut sqlite3_context,
633 argc: c_int,
634 argv: *mut *mut sqlite3_value,
635) where
636 A: RefUnwindSafe + UnwindSafe,
637 D: Aggregate<A, T>,
638 T: ToSql,
639{
640 let pac = if let Some(pac) = aggregate_context(ctx, std::mem::size_of::<*mut A>()) {
641 pac
642 } else {
643 ffi::sqlite3_result_error_nomem(ctx);
644 return;
645 };
646
647 let r = catch_unwind(|| {
648 let boxed_aggr: *mut D = ffi::sqlite3_user_data(ctx).cast::<D>();
649 assert!(
650 !boxed_aggr.is_null(),
651 "Internal error - null aggregate pointer"
652 );
653 let mut ctx = Context {
654 ctx,
655 args: slice::from_raw_parts(argv, argc as usize),
656 };
657
658 if (*pac as *mut A).is_null() {
659 *pac = Box::into_raw(Box::new((*boxed_aggr).init(&mut ctx)?));
660 }
661
662 (*boxed_aggr).step(&mut ctx, &mut **pac)
663 });
664 let r = match r {
665 Err(_) => {
666 report_error(ctx, &Error::UnwindingPanic);
667 return;
668 }
669 Ok(r) => r,
670 };
671 match r {
672 Ok(_) => {}
673 Err(err) => report_error(ctx, &err),
674 };
675}
676
/// `xInverse` trampoline for window functions: hands the row leaving the
/// window frame to [`WindowAggregate::inverse`] together with the group's
/// accumulator. Panics and errors are reported back to SQLite.
#[cfg(feature = "window")]
unsafe extern "C" fn call_boxed_inverse<A, W, T>(
    ctx: *mut sqlite3_context,
    argc: c_int,
    argv: *mut *mut sqlite3_value,
) where
    A: RefUnwindSafe + UnwindSafe,
    W: WindowAggregate<A, T>,
    T: ToSql,
{
    let slot = match aggregate_context::<A>(ctx, std::mem::size_of::<*mut A>()) {
        Some(slot) => slot,
        None => {
            ffi::sqlite3_result_error_nomem(ctx);
            return;
        }
    };

    let outcome = catch_unwind(|| {
        let aggr: *mut W = ffi::sqlite3_user_data(ctx).cast::<W>();
        assert!(!aggr.is_null(), "Internal error - null aggregate pointer");
        let mut fn_ctx = Context {
            ctx,
            args: slice::from_raw_parts(argv, argc as usize),
        };
        // NOTE(review): this dereferences the slot without a null check and so
        // relies on SQLite only calling xInverse after xStep has created the
        // accumulator.
        (*aggr).inverse(&mut fn_ctx, &mut **slot)
    });

    match outcome {
        Ok(Ok(())) => {}
        Ok(Err(err)) => report_error(ctx, &err),
        Err(_) => report_error(ctx, &Error::UnwindingPanic),
    }
}
718
719unsafe extern "C" fn call_boxed_final<A, D, T>(ctx: *mut sqlite3_context)
720where
721 A: RefUnwindSafe + UnwindSafe,
722 D: Aggregate<A, T>,
723 T: ToSql,
724{
725 let a: Option<A> = match aggregate_context(ctx, 0) {
728 Some(pac) => {
729 if (*pac as *mut A).is_null() {
730 None
731 } else {
732 let a = Box::from_raw(*pac);
733 Some(*a)
734 }
735 }
736 None => None,
737 };
738
739 let r = catch_unwind(|| {
740 let boxed_aggr: *mut D = ffi::sqlite3_user_data(ctx).cast::<D>();
741 assert!(
742 !boxed_aggr.is_null(),
743 "Internal error - null aggregate pointer"
744 );
745 let mut ctx = Context { ctx, args: &mut [] };
746 (*boxed_aggr).finalize(&mut ctx, a)
747 });
748 let t = match r {
749 Err(_) => {
750 report_error(ctx, &Error::UnwindingPanic);
751 return;
752 }
753 Ok(r) => r,
754 };
755 let t = t.as_ref().map(|t| ToSql::to_sql(t));
756 match t {
757 Ok(Ok(ref value)) => set_result(ctx, value),
758 Ok(Err(err)) => report_error(ctx, &err),
759 Err(err) => report_error(ctx, err),
760 }
761}
762
/// `xValue` trampoline for window functions.
///
/// Lends the group's accumulator (without taking ownership) to
/// [`WindowAggregate::value`]. The returned value is converted with `ToSql`
/// and set as the SQL result; panics and errors are reported back to SQLite.
#[cfg(feature = "window")]
unsafe extern "C" fn call_boxed_value<A, W, T>(ctx: *mut sqlite3_context)
where
    A: RefUnwindSafe + UnwindSafe,
    W: WindowAggregate<A, T>,
    T: ToSql,
{
    // Requesting 0 bytes never allocates: `None` (or a null slot) means no row
    // has been stepped yet.
    let acc: Option<&A> = aggregate_context::<A>(ctx, 0)
        .map(|slot| *slot)
        .filter(|raw| !raw.is_null())
        .map(|raw| &*raw);

    let outcome = catch_unwind(|| {
        let aggr: *mut W = ffi::sqlite3_user_data(ctx).cast::<W>();
        assert!(!aggr.is_null(), "Internal error - null aggregate pointer");
        (*aggr).value(acc)
    });

    let value = match outcome {
        Ok(value) => value,
        Err(_) => {
            report_error(ctx, &Error::UnwindingPanic);
            return;
        }
    };

    match value.as_ref().map(|v| ToSql::to_sql(v)) {
        Ok(Ok(ref output)) => set_result(ctx, output),
        Ok(Err(err)) => report_error(ctx, &err),
        Err(err) => report_error(ctx, err),
    }
}
806
#[cfg(test)]
mod test {
    use regex::Regex;
    use std::os::raw::c_double;

    #[cfg(feature = "window")]
    use crate::functions::WindowAggregate;
    use crate::functions::{Aggregate, Context, FunctionFlags};
    use crate::{Connection, Error, Result};

    /// Scalar test function: returns half of its single floating-point argument.
    fn half(ctx: &Context<'_>) -> Result<c_double> {
        assert_eq!(ctx.len(), 1, "called with unexpected number of arguments");
        let value = ctx.get::<c_double>(0)?;
        Ok(value / 2f64)
    }

    #[test]
    fn test_function_half() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.create_scalar_function(
            "half",
            1,
            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
            half,
        )?;
        let result: Result<f64> = db.query_row("SELECT half(6)", [], |r| r.get(0));

        assert!((3f64 - result?).abs() < f64::EPSILON);
        Ok(())
    }

    #[test]
    fn test_remove_function() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.create_scalar_function(
            "half",
            1,
            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
            half,
        )?;
        let result: Result<f64> = db.query_row("SELECT half(6)", [], |r| r.get(0));
        assert!((3f64 - result?).abs() < f64::EPSILON);

        db.remove_function("half", 1)?;
        // Once removed, calling the function must fail.
        let result: Result<f64> = db.query_row("SELECT half(6)", [], |r| r.get(0));
        assert!(result.is_err());
        Ok(())
    }

    /// `regexp(pattern, text)` that caches the compiled pattern as auxiliary
    /// data on argument 0. SQLite can then reuse it across rows instead of
    /// recompiling it on every call.
    fn regexp_with_auxiliary(ctx: &Context<'_>) -> Result<bool> {
        assert_eq!(ctx.len(), 2, "called with unexpected number of arguments");
        type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>;
        let regexp: std::sync::Arc<Regex> = ctx
            .get_or_create_aux(0, |vr| -> Result<_, BoxError> {
                Ok(Regex::new(vr.as_str()?)?)
            })?;

        let is_match = {
            let text = ctx
                .get_raw(1)
                .as_str()
                .map_err(|e| Error::UserFunctionError(e.into()))?;

            regexp.is_match(text)
        };

        Ok(is_match)
    }

    #[test]
    fn test_function_regexp_with_auxiliary() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.execute_batch(
            "BEGIN;
             CREATE TABLE foo (x string);
             INSERT INTO foo VALUES ('lisa');
             INSERT INTO foo VALUES ('lXsi');
             INSERT INTO foo VALUES ('lisX');
             END;",
        )?;
        db.create_scalar_function(
            "regexp",
            2,
            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
            regexp_with_auxiliary,
        )?;

        let result: Result<bool> =
            db.query_row("SELECT regexp('l.s[aeiouy]', 'lisa')", [], |r| r.get(0));

        assert!(result?);

        let result: Result<i64> = db.query_row(
            "SELECT COUNT(*) FROM foo WHERE regexp('l.s[aeiouy]', x) == 1",
            [],
            |r| r.get(0),
        );

        assert_eq!(2, result?);
        Ok(())
    }

    #[test]
    fn test_varargs_function() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.create_scalar_function(
            "my_concat",
            -1,
            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
            |ctx| {
                let mut ret = String::new();

                for idx in 0..ctx.len() {
                    let s = ctx.get::<String>(idx)?;
                    ret.push_str(&s);
                }

                Ok(ret)
            },
        )?;

        for &(expected, query) in &[
            ("", "SELECT my_concat()"),
            ("onetwo", "SELECT my_concat('one', 'two')"),
            ("abc", "SELECT my_concat('a', 'b', 'c')"),
        ] {
            let result: String = db.query_row(query, [], |r| r.get(0))?;
            assert_eq!(expected, result);
        }
        Ok(())
    }

    #[test]
    fn test_get_aux_type_checking() -> Result<()> {
        let db = Connection::open_in_memory()?;
        // The first row stores an i64; the second checks that a wrong-typed
        // lookup fails while a correctly-typed one succeeds.
        db.create_scalar_function("example", 2, FunctionFlags::default(), |ctx| {
            if !ctx.get::<bool>(1)? {
                ctx.set_aux::<i64>(0, 100)?;
            } else {
                assert_eq!(ctx.get_aux::<String>(0), Err(Error::GetAuxWrongType));
                assert_eq!(*ctx.get_aux::<i64>(0)?.unwrap(), 100);
            }
            Ok(true)
        })?;

        let res: bool = db.query_row(
            "SELECT example(0, i) FROM (SELECT 0 as i UNION SELECT 1)",
            [],
            |r| r.get(0),
        )?;
        assert!(res);
        Ok(())
    }

    /// Aggregate that sums integer arguments; `NULL` over an empty set.
    struct Sum;
    /// Aggregate that counts rows; `0` over an empty set.
    struct Count;

    impl Aggregate<i64, Option<i64>> for Sum {
        fn init(&self, _: &mut Context<'_>) -> Result<i64> {
            Ok(0)
        }

        fn step(&self, ctx: &mut Context<'_>, sum: &mut i64) -> Result<()> {
            *sum += ctx.get::<i64>(0)?;
            Ok(())
        }

        fn finalize(&self, _: &mut Context<'_>, sum: Option<i64>) -> Result<Option<i64>> {
            Ok(sum)
        }
    }

    impl Aggregate<i64, i64> for Count {
        fn init(&self, _: &mut Context<'_>) -> Result<i64> {
            Ok(0)
        }

        fn step(&self, _ctx: &mut Context<'_>, sum: &mut i64) -> Result<()> {
            *sum += 1;
            Ok(())
        }

        fn finalize(&self, _: &mut Context<'_>, sum: Option<i64>) -> Result<i64> {
            Ok(sum.unwrap_or(0))
        }
    }

    #[test]
    fn test_sum() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.create_aggregate_function(
            "my_sum",
            1,
            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
            Sum,
        )?;

        // An empty input set yields NULL.
        let no_result = "SELECT my_sum(i) FROM (SELECT 2 AS i WHERE 1 <> 1)";
        let result: Option<i64> = db.query_row(no_result, [], |r| r.get(0))?;
        assert!(result.is_none());

        let single_sum = "SELECT my_sum(i) FROM (SELECT 2 AS i UNION ALL SELECT 2)";
        let result: i64 = db.query_row(single_sum, [], |r| r.get(0))?;
        assert_eq!(4, result);

        let dual_sum = "SELECT my_sum(i), my_sum(j) FROM (SELECT 2 AS i, 1 AS j UNION ALL SELECT \
                        2, 1)";
        let result: (i64, i64) = db.query_row(dual_sum, [], |r| Ok((r.get(0)?, r.get(1)?)))?;
        assert_eq!((4, 2), result);
        Ok(())
    }

    #[test]
    fn test_count() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.create_aggregate_function(
            "my_count",
            -1,
            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
            Count,
        )?;

        // An empty input set yields 0.
        let no_result = "SELECT my_count(i) FROM (SELECT 2 AS i WHERE 1 <> 1)";
        let result: i64 = db.query_row(no_result, [], |r| r.get(0))?;
        assert_eq!(result, 0);

        let single_sum = "SELECT my_count(i) FROM (SELECT 2 AS i UNION ALL SELECT 2)";
        let result: i64 = db.query_row(single_sum, [], |r| r.get(0))?;
        assert_eq!(2, result);
        Ok(())
    }

    #[cfg(feature = "window")]
    impl WindowAggregate<i64, Option<i64>> for Sum {
        fn inverse(&self, ctx: &mut Context<'_>, sum: &mut i64) -> Result<()> {
            *sum -= ctx.get::<i64>(0)?;
            Ok(())
        }

        fn value(&self, sum: Option<&i64>) -> Result<Option<i64>> {
            Ok(sum.copied())
        }
    }

    #[test]
    #[cfg(feature = "window")]
    fn test_window() -> Result<()> {
        use fallible_iterator::FallibleIterator;

        let db = Connection::open_in_memory()?;
        db.create_window_function(
            "sumint",
            1,
            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
            Sum,
        )?;
        db.execute_batch(
            "CREATE TABLE t3(x, y);
             INSERT INTO t3 VALUES('a', 4),
                     ('b', 5),
                     ('c', 3),
                     ('d', 8),
                     ('e', 1);",
        )?;

        let mut stmt = db.prepare(
            "SELECT x, sumint(y) OVER (
                   ORDER BY x ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING
                 ) AS sum_y
                 FROM t3 ORDER BY x;",
        )?;

        let results: Vec<(String, i64)> = stmt
            .query([])?
            .map(|row| Ok((row.get("x")?, row.get("sum_y")?)))
            .collect()?;
        let expected = vec![
            ("a".to_owned(), 9),
            ("b".to_owned(), 12),
            ("c".to_owned(), 16),
            ("d".to_owned(), 12),
            ("e".to_owned(), 9),
        ];
        assert_eq!(expected, results);
        Ok(())
    }
}