sqlx_core_oldapi/sqlite/connection/
function.rs1use std::ffi::{c_char, CString};
2use std::os::raw::{c_int, c_void};
3use std::sync::Arc;
4
use libsqlite3_sys::{
    sqlite3_context, sqlite3_create_function_v2, sqlite3_result_blob, sqlite3_result_double,
    sqlite3_result_error, sqlite3_result_error_toobig, sqlite3_result_int, sqlite3_result_int64,
    sqlite3_result_null, sqlite3_result_text, sqlite3_user_data, sqlite3_value,
    sqlite3_value_type, SQLITE_DETERMINISTIC, SQLITE_DIRECTONLY, SQLITE_OK, SQLITE_TRANSIENT,
    SQLITE_UTF8,
};
11
12use crate::decode::Decode;
13use crate::encode::{Encode, IsNull};
14use crate::error::{BoxDynError, Error};
15use crate::sqlite::type_info::DataType;
16use crate::sqlite::Sqlite;
17use crate::sqlite::SqliteArgumentValue;
18use crate::sqlite::SqliteTypeInfo;
19use crate::sqlite::SqliteValue;
20use crate::sqlite::{connection::handle::ConnectionHandle, SqliteError};
21use crate::value::Value;
22
23pub trait SqliteCallable: Send + Sync {
24 unsafe fn call_boxed_closure(
25 &self,
26 ctx: *mut sqlite3_context,
27 argc: c_int,
28 argv: *mut *mut sqlite3_value,
29 );
30 fn arg_count(&self) -> i32;
32}
33
34pub struct SqliteFunctionCtx {
35 ctx: *mut sqlite3_context,
36 argument_values: Vec<SqliteValue>,
37}
38
39impl SqliteFunctionCtx {
40 unsafe fn new(ctx: *mut sqlite3_context, argc: c_int, argv: *mut *mut sqlite3_value) -> Self {
44 let count = usize::try_from(argc).expect("invalid argument count");
45 let argument_values = (0..count)
46 .map(|i| {
47 let raw = *argv.add(i);
48 let data_type_code = sqlite3_value_type(raw);
49 let value_type_info = SqliteTypeInfo(DataType::from_code(data_type_code));
50 SqliteValue::new(raw, value_type_info)
51 })
52 .collect::<Vec<_>>();
53 Self {
54 ctx,
55 argument_values,
56 }
57 }
58
59 pub fn get_arg<'q, T: Decode<'q, Sqlite>>(&'q self, index: usize) -> T {
62 self.try_get_arg::<T>(index)
63 .expect("invalid argument index")
64 }
65
66 pub fn try_get_arg<'q, T: Decode<'q, Sqlite>>(
69 &'q self,
70 index: usize,
71 ) -> Result<T, BoxDynError> {
72 if let Some(value) = self.argument_values.get(index) {
73 let value_ref = value.as_ref();
74 T::decode(value_ref)
75 } else {
76 Err("invalid argument index".into())
77 }
78 }
79
80 pub fn set_result<'q, R: Encode<'q, Sqlite>>(&self, result: R) {
81 unsafe {
82 let mut arg_buffer: Vec<SqliteArgumentValue<'q>> = Vec::with_capacity(1);
83 if let IsNull::Yes = result.encode(&mut arg_buffer) {
84 sqlite3_result_null(self.ctx);
85 } else {
86 let arg = arg_buffer.pop().unwrap();
87 match arg {
88 SqliteArgumentValue::Null => {
89 sqlite3_result_null(self.ctx);
90 }
91 SqliteArgumentValue::Text(text) => {
92 sqlite3_result_text(
93 self.ctx,
94 text.as_ptr() as *const c_char,
95 text.len() as c_int,
96 SQLITE_TRANSIENT(),
97 );
98 }
99 SqliteArgumentValue::Blob(blob) => {
100 sqlite3_result_blob(
101 self.ctx,
102 blob.as_ptr() as *const c_void,
103 blob.len() as c_int,
104 SQLITE_TRANSIENT(),
105 );
106 }
107 SqliteArgumentValue::Double(double) => {
108 sqlite3_result_double(self.ctx, double);
109 }
110 SqliteArgumentValue::Int(int) => {
111 sqlite3_result_int(self.ctx, int);
112 }
113 SqliteArgumentValue::Int64(int64) => {
114 sqlite3_result_int64(self.ctx, int64);
115 }
116 }
117 }
118 }
119 }
120
121 pub fn set_error(&self, error_str: &str) {
122 let error_str = CString::new(error_str).expect("invalid error string");
123 unsafe {
124 sqlite3_result_error(
125 self.ctx,
126 error_str.as_ptr(),
127 error_str.as_bytes().len() as c_int,
128 );
129 }
130 }
131}
132
133impl<F: Fn(&SqliteFunctionCtx) + Send + Sync> SqliteCallable for F {
134 unsafe fn call_boxed_closure(
135 &self,
136 ctx: *mut sqlite3_context,
137 argc: c_int,
138 argv: *mut *mut sqlite3_value,
139 ) {
140 let ctx = SqliteFunctionCtx::new(ctx, argc, argv);
141 (*self)(&ctx);
142 }
143 fn arg_count(&self) -> i32 {
144 -1
145 }
146}
147
148#[derive(Clone)]
149pub struct Function {
150 name: CString,
151 func: Arc<dyn SqliteCallable>,
152 pub deterministic: bool,
154 pub direct_only: bool,
156 call:
157 unsafe extern "C" fn(ctx: *mut sqlite3_context, argc: c_int, argv: *mut *mut sqlite3_value),
158}
159
160impl std::fmt::Debug for Function {
161 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
162 f.debug_struct("Function")
163 .field("name", &self.name)
164 .field("deterministic", &self.deterministic)
165 .finish_non_exhaustive()
166 }
167}
168
169impl Function {
170 pub fn new<N, F>(name: N, func: F) -> Self
171 where
172 N: Into<Vec<u8>>,
173 F: SqliteCallable + Send + Sync + 'static,
174 {
175 Function {
176 name: CString::new(name).expect("invalid function name"),
177 func: Arc::new(func),
178 deterministic: false,
179 direct_only: false,
180 call: call_boxed_closure::<F>,
181 }
182 }
183
184 pub(crate) fn create(&self, handle: &mut ConnectionHandle) -> Result<(), Error> {
185 let raw_f = Arc::into_raw(Arc::clone(&self.func));
186 let r = unsafe {
187 sqlite3_create_function_v2(
188 handle.as_ptr(),
189 self.name.as_ptr(),
190 self.func.arg_count(), self.sqlite_flags(),
192 raw_f as *mut c_void,
193 Some(self.call),
194 None, None, None, )
198 };
199
200 if r == SQLITE_OK {
201 Ok(())
202 } else {
203 Err(Error::Database(Box::new(SqliteError::new(handle.as_ptr()))))
204 }
205 }
206
207 fn sqlite_flags(&self) -> c_int {
208 let mut flags = SQLITE_UTF8;
209 if self.deterministic {
210 flags |= SQLITE_DETERMINISTIC;
211 }
212 if self.direct_only {
213 flags |= SQLITE_DIRECTONLY;
214 }
215 flags
216 }
217
218 pub fn deterministic(mut self) -> Self {
219 self.deterministic = true;
220 self
221 }
222
223 pub fn direct_only(mut self) -> Self {
224 self.direct_only = true;
225 self
226 }
227}
228
229unsafe extern "C" fn call_boxed_closure<F: SqliteCallable>(
230 ctx: *mut sqlite3_context,
231 argc: c_int,
232 argv: *mut *mut sqlite3_value,
233) {
234 let data = sqlite3_user_data(ctx);
235 let boxed_f: *mut F = data as *mut F;
236 debug_assert!(!boxed_f.is_null());
237 let expected_argc = (*boxed_f).arg_count();
238 debug_assert!(expected_argc == -1 || argc == expected_argc);
239 (*boxed_f).call_boxed_closure(ctx, argc, argv);
240}