1#![allow(clippy::not_unsafe_ptr_arg_deref)]
4use std::{
5 ffi::CString,
6 os::raw::{c_int, c_void},
7 slice,
8};
9
10use crate::{
11 api,
12 constants::{SQLITE_INTERNAL, SQLITE_OKAY},
13 errors::{Error, ErrorKind, Result},
14 ext::sqlite3ext_create_function_v2,
15};
16use sqlite3ext_sys::{sqlite3, sqlite3_context, sqlite3_user_data, sqlite3_value};
17
18use bitflags::bitflags;
19
20use sqlite3ext_sys::{
21 SQLITE_DETERMINISTIC, SQLITE_DIRECTONLY, SQLITE_INNOCUOUS, SQLITE_SUBTYPE, SQLITE_UTF16,
22 SQLITE_UTF16BE, SQLITE_UTF16LE, SQLITE_UTF8,
23};
24
bitflags! {
    /// Flags forwarded as the `eTextRep` argument of `sqlite3_create_function_v2`
    /// when a function is defined or deleted in this module.
    ///
    /// Every flag takes its value from the `sqlite3ext_sys` constant of the same name,
    /// so the combined `bits` can be handed to SQLite unchanged. See SQLite's
    /// "Function Flags" / "Text Encodings" documentation for the exact semantics.
    pub struct FunctionFlags: i32 {
        // --- Preferred text encoding of the function's arguments ---

        /// Mirrors `SQLITE_UTF8`.
        const UTF8 = SQLITE_UTF8 as i32;
        /// Mirrors `SQLITE_UTF16LE`.
        const UTF16LE = SQLITE_UTF16LE as i32;
        /// Mirrors `SQLITE_UTF16BE`.
        const UTF16BE = SQLITE_UTF16BE as i32;
        /// Mirrors `SQLITE_UTF16` (per the SQLite docs, native-byte-order UTF-16).
        const UTF16 = SQLITE_UTF16 as i32;

        // --- Function behaviour flags, OR-ed together with the encoding ---

        /// Mirrors `SQLITE_DETERMINISTIC`; per the SQLite docs this declares that the
        /// function always returns the same result for the same inputs.
        const DETERMINISTIC = SQLITE_DETERMINISTIC as i32;
        /// Mirrors `SQLITE_DIRECTONLY`; per the SQLite docs the function may then only
        /// be invoked from top-level SQL, not from views, triggers or schema objects.
        const DIRECTONLY = SQLITE_DIRECTONLY as i32;
        /// Mirrors `SQLITE_SUBTYPE`; per the SQLite docs this declares that the
        /// function may call `sqlite3_value_subtype()` on its arguments.
        const SUBTYPE = SQLITE_SUBTYPE as i32;
        /// Mirrors `SQLITE_INNOCUOUS`; per the SQLite docs this marks the function as
        /// safe to use even when invoked by untrusted schema content.
        const INNOCUOUS = SQLITE_INNOCUOUS as i32;
    }
}
45
46pub fn define_scalar_function<F>(
58 db: *mut sqlite3,
59 name: &str,
60 num_args: c_int,
61 x_func: F,
62 func_flags: FunctionFlags,
63) -> Result<()>
64where
65 F: Fn(*mut sqlite3_context, &[*mut sqlite3_value]) -> Result<()>,
70{
71 let function_pointer: *mut F = Box::into_raw(Box::new(x_func));
72
73 unsafe extern "C" fn x_func_wrapper<F>(
74 context: *mut sqlite3_context,
75 argc: c_int,
76 argv: *mut *mut sqlite3_value,
77 ) where
78 F: Fn(*mut sqlite3_context, &[*mut sqlite3_value]) -> Result<()>,
79 {
80 let boxed_function: *mut F = sqlite3_user_data(context).cast::<F>();
81 let args = slice::from_raw_parts(argv, argc as usize);
83 match (*boxed_function)(context, args) {
84 Ok(()) => (),
85 Err(e) => {
86 if api::result_error(context, &e.result_error_message()).is_err() {
87 api::result_error_code(context, SQLITE_INTERNAL);
88 }
89 }
90 }
91 }
92 let cname = CString::new(name)?;
93 let result = unsafe {
94 sqlite3ext_create_function_v2(
95 db,
96 cname.as_ptr(),
97 num_args,
98 func_flags.bits,
99 function_pointer.cast::<c_void>(),
100 Some(x_func_wrapper::<F>),
101 None,
102 None,
103 None,
104 )
105 };
106
107 if result != SQLITE_OKAY {
108 Err(Error::new(ErrorKind::DefineScalarFunction(result)))
109 } else {
110 Ok(())
111 }
112}
113
114pub fn delete_scalar_function(
115 db: *mut sqlite3,
116 name: &str,
117 num_args: c_int,
118 func_flags: FunctionFlags,
119) -> Result<()> {
120 let cname = CString::new(name)?;
121 let result = unsafe {
122 sqlite3ext_create_function_v2(
123 db,
124 cname.as_ptr(),
125 num_args,
126 func_flags.bits,
127 std::ptr::null_mut(),
128 None,
129 None,
130 None,
131 None,
132 )
133 };
134
135 if result != SQLITE_OKAY {
136 println!("failed with {result}");
137 Err(Error::new(ErrorKind::DefineScalarFunction(result)))
138 } else {
139 Ok(())
140 }
141}
142
143pub fn define_scalar_function_with_aux<F, T>(
148 db: *mut sqlite3,
149 name: &str,
150 num_args: c_int,
151 x_func: F,
152 func_flags: FunctionFlags,
153 aux: T,
154) -> Result<()>
155where
156 F: Fn(*mut sqlite3_context, &[*mut sqlite3_value], &T) -> Result<()>,
157{
158 let function_pointer: *mut F = Box::into_raw(Box::new(x_func));
159 let aux_pointer: *mut T = Box::into_raw(Box::new(aux));
160 let app_pointer = Box::into_raw(Box::new((function_pointer, aux_pointer)));
161
162 unsafe extern "C" fn x_func_wrapper<F, T>(
163 context: *mut sqlite3_context,
164 argc: c_int,
165 argv: *mut *mut sqlite3_value,
166 ) where
167 F: Fn(*mut sqlite3_context, &[*mut sqlite3_value], &T) -> Result<()>,
168 {
169 let x = sqlite3_user_data(context).cast::<(*mut F, *mut T)>();
170 let boxed_function = (*x).0;
171 let aux = (*x).1;
172 let args = slice::from_raw_parts(argv, argc as usize);
174 let b = Box::from_raw(aux);
175 match (*boxed_function)(context, args, &*b) {
176 Ok(()) => (),
177 Err(e) => {
178 if api::result_error(context, &e.result_error_message()).is_err() {
179 api::result_error_code(context, SQLITE_INTERNAL);
180 }
181 }
182 }
183 Box::into_raw(b);
184 }
185 let cname = CString::new(name)?;
186
187 let result = unsafe {
188 sqlite3ext_create_function_v2(
189 db,
190 cname.as_ptr(),
191 num_args,
192 func_flags.bits,
193 app_pointer.cast::<c_void>(),
194 Some(x_func_wrapper::<F, T>),
195 None,
196 None,
197 None,
198 )
199 };
200
201 if result != SQLITE_OKAY {
202 Err(Error::new(ErrorKind::DefineScalarFunction(result)))
203 } else {
204 Ok(())
205 }
206}