sqlite3_ext/function/mod.rs
1//! Create application-defined functions.
2//!
3//! The functionality in this module is primarily exposed through
4//! [Connection::create_scalar_function] and [Connection::create_aggregate_function].
5use super::{ffi, sqlite3_match_version, types::*, value::*, Connection, RiskLevel};
6pub use context::*;
7use std::{cmp::Ordering, ffi::CString, ptr::null_mut};
8
9mod context;
10mod stubs;
11mod test;
12
/// Builds an aggregate function instance from the user data supplied at registration.
///
/// Every aggregate instance is created from the user data that was handed over when the
/// function was registered. Any type implementing [Default] receives this trait
/// automatically (ignoring the user data), which covers aggregates that need no user data.
pub trait FromUserData<T> {
    /// Create a fresh instance from `user_data`.
    fn from_user_data(user_data: &T) -> Self;
}
22
23/// Implement an application-defined aggregate function which cannot be used as a window
24/// function.
25///
26/// In general, there is no reason to implement this trait instead of [AggregateFunction],
27/// because the latter provides a blanket implementation of the former.
28pub trait LegacyAggregateFunction<UserData>: FromUserData<UserData> {
29 /// Assign the default value of the aggregate function to the context using
30 /// [Context::set_result].
31 ///
32 /// This method is called when the aggregate function is invoked over an empty set of
33 /// rows. The default implementation is equivalent to
34 /// `Self::from_user_data(user_data).value(context)`.
35 fn default_value(user_data: &UserData, context: &Context) -> Result<()>
36 where
37 Self: Sized,
38 {
39 Self::from_user_data(user_data).value(context)
40 }
41
42 /// Add a new row to the aggregate.
43 fn step(&mut self, context: &Context, args: &mut [&mut ValueRef]) -> Result<()>;
44
45 /// Assign the current value of the aggregate function to the context using
46 /// [Context::set_result]. If no result is set, SQL NULL is returned. If the function returns
47 /// an Err value, the SQL statement will fail, even if a result had been set before the
48 /// failure.
49 fn value(&self, context: &Context) -> Result<()>;
50}
51
52/// Implement an application-defined aggregate window function.
53///
54/// The function can be registered with a database connection using
55/// [Connection::create_aggregate_function].
56pub trait AggregateFunction<UserData>: FromUserData<UserData> {
57 /// Assign the default value of the aggregate function to the context using
58 /// [Context::set_result].
59 ///
60 /// This method is called when the aggregate function is invoked over an empty set of
61 /// rows. The default implementation is equivalent to
62 /// `Self::from_user_data(user_data).value(context)`.
63 fn default_value(user_data: &UserData, context: &Context) -> Result<()>
64 where
65 Self: Sized,
66 {
67 Self::from_user_data(user_data).value(context)
68 }
69
70 /// Add a new row to the aggregate.
71 fn step(&mut self, context: &Context, args: &mut [&mut ValueRef]) -> Result<()>;
72
73 /// Assign the current value of the aggregate function to the context using
74 /// [Context::set_result]. If no result is set, SQL NULL is returned. If the function returns
75 /// an Err value, the SQL statement will fail, even if a result had been set before the
76 /// failure.
77 fn value(&self, context: &Context) -> Result<()>;
78
79 /// Remove the oldest presently aggregated row.
80 ///
81 /// The args are the same that were passed to [AggregateFunction::step] when this row
82 /// was added.
83 fn inverse(&mut self, context: &Context, args: &mut [&mut ValueRef]) -> Result<()>;
84}
85
86impl<U, F: Default> FromUserData<U> for F {
87 fn from_user_data(_: &U) -> F {
88 F::default()
89 }
90}
91
92impl<U, T: AggregateFunction<U>> LegacyAggregateFunction<U> for T {
93 fn default_value(user_data: &U, context: &Context) -> Result<()> {
94 <T as AggregateFunction<U>>::default_value(user_data, context)
95 }
96
97 fn step(&mut self, context: &Context, args: &mut [&mut ValueRef]) -> Result<()> {
98 <T as AggregateFunction<U>>::step(self, context, args)
99 }
100
101 fn value(&self, context: &Context) -> Result<()> {
102 <T as AggregateFunction<U>>::value(self, context)
103 }
104}
105
/// Options used when registering an application-defined function.
///
/// Start from [FunctionOptions::default] and adjust it with the `set_*` builder methods,
/// then pass it to a registration method such as [Connection::create_scalar_function].
#[derive(Debug, Clone)]
pub struct FunctionOptions {
    // Number of SQL arguments the function accepts; -1 means any number.
    n_args: i32,
    // SQLite function flag bits (e.g. SQLITE_DETERMINISTIC) passed through at registration.
    flags: i32,
}
111
112impl Default for FunctionOptions {
113 fn default() -> Self {
114 FunctionOptions::default()
115 }
116}
117
118impl FunctionOptions {
119 pub const fn default() -> Self {
120 FunctionOptions {
121 n_args: -1,
122 flags: 0,
123 }
124 }
125
126 /// Set the number of parameters accepted by this function. Multiple functions may be
127 /// provided under the same name with different n_args values; the implementation will
128 /// be chosen by SQLite based on the number of parameters at the call site. The value
129 /// may also be -1, which means that the function accepts any number of parameters.
130 /// Functions which take a specific number of parameters take precedence over functions
131 /// which take any number.
132 ///
133 /// # Panics
134 ///
135 /// This function panics if n_args is outside the range -1..128. This limitation is
136 /// imposed by SQLite.
137 pub const fn set_n_args(mut self, n_args: i32) -> Self {
138 assert!(n_args >= -1 && n_args < 128, "n_args invalid");
139 self.n_args = n_args;
140 self
141 }
142
143 /// Enable or disable the deterministic flag. This flag indicates that the function is
144 /// pure. It must have no side effects and the value must be determined solely its the
145 /// parameters.
146 ///
147 /// The SQLite query planner is able to perform additional optimizations on
148 /// deterministic functions, so use of this flag is recommended where possible.
149 pub const fn set_deterministic(mut self, val: bool) -> Self {
150 if val {
151 self.flags |= ffi::SQLITE_DETERMINISTIC;
152 } else {
153 self.flags &= !ffi::SQLITE_DETERMINISTIC;
154 }
155 self
156 }
157
158 /// Set the level of risk for this function. See the [RiskLevel] enum for details about
159 /// what the individual options mean.
160 ///
161 /// Requires SQLite 3.31.0. On earlier versions of SQLite, this function is a harmless no-op.
162 pub const fn set_risk_level(
163 #[cfg_attr(not(modern_sqlite), allow(unused_mut))] mut self,
164 level: RiskLevel,
165 ) -> Self {
166 let _ = level;
167 #[cfg(modern_sqlite)]
168 {
169 self.flags |= match level {
170 RiskLevel::Innocuous => ffi::SQLITE_INNOCUOUS,
171 RiskLevel::DirectOnly => ffi::SQLITE_DIRECTONLY,
172 };
173 self.flags &= match level {
174 RiskLevel::Innocuous => !ffi::SQLITE_DIRECTONLY,
175 RiskLevel::DirectOnly => !ffi::SQLITE_INNOCUOUS,
176 };
177 }
178 self
179 }
180}
181
182impl Connection {
183 /// Create a stub function that always fails.
184 ///
185 /// This API makes sure a global version of a function with a particular name and
186 /// number of parameters exists. If no such function exists before this API is called,
187 /// a new function is created. The implementation of the new function always causes an
188 /// exception to be thrown. So the new function is not good for anything by itself. Its
189 /// only purpose is to be a placeholder function that can be overloaded by a virtual
190 /// table.
191 ///
192 /// For more information, see [vtab::FindFunctionVTab](super::vtab::FindFunctionVTab).
193 pub fn create_overloaded_function(&self, name: &str, opts: &FunctionOptions) -> Result<()> {
194 let guard = self.lock();
195 let name = unsafe { CString::from_vec_unchecked(name.as_bytes().into()) };
196 unsafe {
197 Error::from_sqlite_desc(
198 ffi::sqlite3_overload_function(self.as_mut_ptr(), name.as_ptr() as _, opts.n_args),
199 guard,
200 )
201 }
202 }
203
204 /// Create a new scalar function. The function will be invoked with a [Context] and an array of
205 /// [ValueRef] objects. The function is required to set its output using [Context::set_result].
206 /// If no result is set, SQL NULL is returned. If the function returns an Err value, the SQL
207 /// statement will fail, even if a result had been set before the failure.
208 ///
209 /// # Compatibility
210 ///
211 /// On versions of SQLite earlier than 3.7.3, this function will leak the function and
212 /// all bound variables. This is because these versions of SQLite did not provide the
213 /// ability to specify a destructor function.
214 pub fn create_scalar_function<F>(
215 &self,
216 name: &str,
217 opts: &FunctionOptions,
218 func: F,
219 ) -> Result<()>
220 where
221 F: Fn(&Context, &mut [&mut ValueRef]) -> Result<()> + 'static,
222 {
223 let guard = self.lock();
224 let name = unsafe { CString::from_vec_unchecked(name.as_bytes().into()) };
225 let func = Box::new(func);
226 unsafe {
227 Error::from_sqlite_desc(
228 sqlite3_match_version! {
229 3_007_003 => ffi::sqlite3_create_function_v2(
230 self.as_mut_ptr(),
231 name.as_ptr() as _,
232 opts.n_args,
233 opts.flags,
234 Box::into_raw(func) as _,
235 Some(stubs::call_scalar::<F>),
236 None,
237 None,
238 Some(ffi::drop_boxed::<F>),
239 ),
240 _ => ffi::sqlite3_create_function(
241 self.as_mut_ptr(),
242 name.as_ptr() as _,
243 opts.n_args,
244 opts.flags,
245 Box::into_raw(func) as _,
246 Some(stubs::call_scalar::<F>),
247 None,
248 None,
249 ),
250 },
251 guard,
252 )
253 }
254 }
255
256 /// Create a new aggregate function which cannot be used as a window function.
257 ///
258 /// In general, you should use
259 /// [create_aggregate_function](Connection::create_aggregate_function) instead, which
260 /// provides all of the same features as legacy aggregate functions but also support
261 /// WINDOW.
262 ///
263 /// # Compatibility
264 ///
265 /// On versions of SQLite earlier than 3.7.3, this function will leak the user data.
266 /// This is because these versions of SQLite did not provide the ability to specify a
267 /// destructor function.
268 pub fn create_legacy_aggregate_function<U, F: LegacyAggregateFunction<U>>(
269 &self,
270 name: &str,
271 opts: &FunctionOptions,
272 user_data: U,
273 ) -> Result<()> {
274 let guard = self.lock();
275 let name = unsafe { CString::from_vec_unchecked(name.as_bytes().into()) };
276 let user_data = Box::new(user_data);
277 unsafe {
278 Error::from_sqlite_desc(
279 sqlite3_match_version! {
280 3_007_003 => ffi::sqlite3_create_function_v2(
281 self.as_mut_ptr(),
282 name.as_ptr() as _,
283 opts.n_args,
284 opts.flags,
285 Box::into_raw(user_data) as _,
286 None,
287 Some(stubs::aggregate_step::<U, F>),
288 Some(stubs::aggregate_final::<U, F>),
289 Some(ffi::drop_boxed::<U>),
290 ),
291 _ => ffi::sqlite3_create_function(
292 self.as_mut_ptr(),
293 name.as_ptr() as _,
294 opts.n_args,
295 opts.flags,
296 Box::into_raw(user_data) as _,
297 None,
298 Some(stubs::aggregate_step::<U, F>),
299 Some(stubs::aggregate_final::<U, F>),
300 ),
301 },
302 guard,
303 )
304 }
305 }
306
307 /// Create a new aggregate function.
308 ///
309 /// # Compatibility
310 ///
311 /// Window functions require SQLite 3.25.0. On earlier versions of SQLite, this
312 /// function will automatically fall back to
313 /// [create_legacy_aggregate_function](Connection::create_legacy_aggregate_function).
314 pub fn create_aggregate_function<U, F: AggregateFunction<U>>(
315 &self,
316 name: &str,
317 opts: &FunctionOptions,
318 user_data: U,
319 ) -> Result<()> {
320 sqlite3_match_version! {
321 3_025_000 => {
322 let name = unsafe { CString::from_vec_unchecked(name.as_bytes().into()) };
323 let user_data = Box::new(user_data);
324 let guard = self.lock();
325 unsafe {
326 Error::from_sqlite_desc(ffi::sqlite3_create_window_function(
327 self.as_mut_ptr(),
328 name.as_ptr() as _,
329 opts.n_args,
330 opts.flags,
331 Box::into_raw(user_data) as _,
332 Some(stubs::aggregate_step::<U, F>),
333 Some(stubs::aggregate_final::<U, F>),
334 Some(stubs::aggregate_value::<U, F>),
335 Some(stubs::aggregate_inverse::<U, F>),
336 Some(ffi::drop_boxed::<U>),
337 ), guard)
338 }
339 },
340 _ => self.create_legacy_aggregate_function::<U, F>(name, opts, user_data),
341 }
342 }
343
344 /// Remove an application-defined scalar or aggregate function. The name and n_args
345 /// parameters must match the values used when the function was created.
346 pub fn remove_function(&self, name: &str, n_args: i32) -> Result<()> {
347 let name = unsafe { CString::from_vec_unchecked(name.as_bytes().into()) };
348 let guard = self.lock();
349 unsafe {
350 Error::from_sqlite_desc(
351 ffi::sqlite3_create_function(
352 self.as_mut_ptr(),
353 name.as_ptr() as _,
354 n_args,
355 0,
356 null_mut(),
357 None,
358 None,
359 None,
360 ),
361 guard,
362 )
363 }
364 }
365
366 /// Register a new collating sequence.
367 pub fn create_collation<F: Fn(&str, &str) -> Ordering>(
368 &self,
369 name: &str,
370 func: F,
371 ) -> Result<()> {
372 let name = unsafe { CString::from_vec_unchecked(name.as_bytes().into()) };
373 let func = Box::into_raw(Box::new(func));
374 let guard = self.lock();
375 unsafe {
376 let rc = ffi::sqlite3_create_collation_v2(
377 self.as_mut_ptr(),
378 name.as_ptr() as _,
379 ffi::SQLITE_UTF8,
380 func as _,
381 Some(stubs::compare::<F>),
382 Some(ffi::drop_boxed::<F>),
383 );
384 if rc != ffi::SQLITE_OK {
385 // The xDestroy callback is not called if the
386 // sqlite3_create_collation_v2() function fails.
387 drop(Box::from_raw(func));
388 }
389 Error::from_sqlite_desc(rc, guard)
390 }
391 }
392
393 /// Register a callback for when SQLite needs a collation sequence. The function will
394 /// be invoked when a collation sequence is needed, and
395 /// [create_collation](Connection::create_collation) can be used to provide the needed
396 /// sequence.
397 ///
398 /// Note: the provided function and any captured variables will be leaked. SQLite does
399 /// not provide any facilities for cleaning up this data.
400 pub fn set_collation_needed_func<F: Fn(&str)>(&self, func: F) -> Result<()> {
401 let func = Box::new(func);
402 let guard = self.lock();
403 unsafe {
404 Error::from_sqlite_desc(
405 ffi::sqlite3_collation_needed(
406 self.as_mut_ptr(),
407 Box::into_raw(func) as _,
408 Some(stubs::collation_needed::<F>),
409 ),
410 guard,
411 )
412 }
413 }
414}