1use std::ffi::CString;
2
3use function::{ScalarFunction, ScalarFunctionSet};
4use libduckdb_sys::{
5 duckdb_data_chunk, duckdb_function_info, duckdb_scalar_function_get_extra_info, duckdb_scalar_function_set_error,
6 duckdb_vector,
7};
8
9use crate::{
10 core::{DataChunkHandle, LogicalTypeHandle},
11 inner_connection::InnerConnection,
12 vtab::arrow::WritableVector,
13 Connection,
14};
15mod function;
16
17#[cfg(feature = "vscalar-arrow")]
19pub mod arrow;
20
21#[cfg(feature = "vscalar-arrow")]
22pub use arrow::{ArrowFunctionSignature, ArrowScalarParams, VArrowScalar};
23
24pub trait VScalar: Sized {
26 type State: Sized + Send + Sync + 'static;
30 unsafe fn invoke(
39 state: &Self::State,
40 input: &mut DataChunkHandle,
41 output: &mut dyn WritableVector,
42 ) -> Result<(), Box<dyn std::error::Error>>;
43
44 fn signatures() -> Vec<ScalarFunctionSignature>;
48
49 fn volatile() -> bool {
61 false
62 }
63}
64
/// The parameter list of a [`ScalarFunctionSignature`].
pub enum ScalarParams {
    /// A fixed list of parameters, registered one by one in order.
    Exact(Vec<LogicalTypeHandle>),
    /// A single variadic parameter of the given type (registered via `add_variadic_parameter`).
    Variadic(LogicalTypeHandle),
}
72
/// One overload of a scalar function: its parameter list and its return type.
pub struct ScalarFunctionSignature {
    // `None` registers no parameters at all; the public constructors below always set `Some`.
    parameters: Option<ScalarParams>,
    return_type: LogicalTypeHandle,
}
78
79impl ScalarFunctionSignature {
80 pub fn exact(params: Vec<LogicalTypeHandle>, return_type: LogicalTypeHandle) -> Self {
82 Self {
83 parameters: Some(ScalarParams::Exact(params)),
84 return_type,
85 }
86 }
87
88 pub fn variadic(param: LogicalTypeHandle, return_type: LogicalTypeHandle) -> Self {
90 Self {
91 parameters: Some(ScalarParams::Variadic(param)),
92 return_type,
93 }
94 }
95}
96
97impl ScalarFunctionSignature {
98 pub(crate) fn register_with_scalar(&self, f: &ScalarFunction) {
99 f.set_return_type(&self.return_type);
100
101 match &self.parameters {
102 Some(ScalarParams::Exact(params)) => {
103 for param in params.iter() {
104 f.add_parameter(param);
105 }
106 }
107 Some(ScalarParams::Variadic(param)) => {
108 f.add_variadic_parameter(param);
109 }
110 None => {
111 }
113 }
114 }
115}
116
/// Thin wrapper around the `duckdb_function_info` handle passed to the scalar-function callback.
#[derive(Debug)]
struct ScalarFunctionInfo(duckdb_function_info);
120
121impl From<duckdb_function_info> for ScalarFunctionInfo {
122 fn from(ptr: duckdb_function_info) -> Self {
123 Self(ptr)
124 }
125}
126
127impl ScalarFunctionInfo {
128 pub unsafe fn get_extra_info<T>(&self) -> &T {
129 &*(duckdb_scalar_function_get_extra_info(self.0).cast())
130 }
131
132 pub unsafe fn set_error(&self, error: &str) {
133 let c_str = CString::new(error).unwrap();
134 duckdb_scalar_function_set_error(self.0, c_str.as_ptr());
135 }
136}
137
138unsafe extern "C" fn scalar_func<T>(info: duckdb_function_info, input: duckdb_data_chunk, mut output: duckdb_vector)
139where
140 T: VScalar,
141{
142 let info = ScalarFunctionInfo::from(info);
143 let mut input = DataChunkHandle::new_unowned(input);
144 let result = T::invoke(info.get_extra_info(), &mut input, &mut output);
145 if let Err(e) = result {
146 info.set_error(&e.to_string());
147 }
148}
149
150impl Connection {
151 #[inline]
153 pub fn register_scalar_function<S: VScalar>(&self, name: &str) -> crate::Result<()>
154 where
155 S::State: Default,
156 {
157 let set = ScalarFunctionSet::new(name);
158 for signature in S::signatures() {
159 let scalar_function = ScalarFunction::new(name)?;
160 signature.register_with_scalar(&scalar_function);
161 scalar_function.set_function(Some(scalar_func::<S>));
162 if S::volatile() {
163 scalar_function.set_volatile();
164 }
165 scalar_function.set_extra_info(S::State::default());
166 set.add_function(scalar_function)?;
167 }
168 self.db.borrow_mut().register_scalar_function_set(set)
169 }
170
171 #[inline]
175 pub fn register_scalar_function_with_state<S: VScalar>(&self, name: &str, state: &S::State) -> crate::Result<()>
176 where
177 S::State: Clone,
178 {
179 let set = ScalarFunctionSet::new(name);
180 for signature in S::signatures() {
181 let scalar_function = ScalarFunction::new(name)?;
182 signature.register_with_scalar(&scalar_function);
183 scalar_function.set_function(Some(scalar_func::<S>));
184 if S::volatile() {
185 scalar_function.set_volatile();
186 }
187 scalar_function.set_extra_info(state.clone());
188 set.add_function(scalar_function)?;
189 }
190 self.db.borrow_mut().register_scalar_function_set(set)
191 }
192}
193
194impl InnerConnection {
195 pub fn register_scalar_function_set(&mut self, f: ScalarFunctionSet) -> crate::Result<()> {
197 f.register_with_connection(self.con)
198 }
199}
200
#[cfg(test)]
mod test {
    use std::{
        error::Error,
        sync::atomic::{AtomicU64, Ordering},
    };

    use arrow::array::Array;
    use libduckdb_sys::duckdb_string_t;

    use crate::{
        core::{DataChunkHandle, Inserter, LogicalTypeHandle, LogicalTypeId},
        types::DuckString,
        vtab::arrow::WritableVector,
        Connection,
    };

    use super::{ScalarFunctionSignature, VScalar};

    /// Always fails, quoting the first row's VARCHAR argument in the error message.
    struct ErrorScalar {}

    impl VScalar for ErrorScalar {
        type State = ();

        unsafe fn invoke(
            _: &Self::State,
            input: &mut DataChunkHandle,
            _: &mut dyn WritableVector,
        ) -> Result<(), Box<dyn std::error::Error>> {
            let mut msg = input.flat_vector(0).as_slice_with_len::<duckdb_string_t>(input.len())[0];
            let string = DuckString::new(&mut msg).as_str();
            Err(format!("Error: {string}").into())
        }

        fn signatures() -> Vec<ScalarFunctionSignature> {
            vec![ScalarFunctionSignature::exact(
                vec![LogicalTypeId::Varchar.into()],
                LogicalTypeId::Varchar.into(),
            )]
        }
    }

    /// State for [`EchoScalar`]: how often to repeat the input and what to prefix it with.
    #[derive(Debug, Clone)]
    struct TestState {
        multiplier: usize,
        prefix: String,
    }

    impl Default for TestState {
        fn default() -> Self {
            Self {
                multiplier: 3,
                prefix: "default".to_string(),
            }
        }
    }

    /// Returns `"{prefix}: {input repeated multiplier times}"` for every row.
    struct EchoScalar {}

    impl VScalar for EchoScalar {
        type State = TestState;

        unsafe fn invoke(
            state: &Self::State,
            input: &mut DataChunkHandle,
            output: &mut dyn WritableVector,
        ) -> Result<(), Box<dyn std::error::Error>> {
            let values = input.flat_vector(0);
            let values = values.as_slice_with_len::<duckdb_string_t>(input.len());
            let output = output.flat_vector();

            // Each result must be written to the row it was computed from.
            for (row, ptr) in values.iter().enumerate() {
                let s = DuckString::new(&mut { *ptr }).as_str().to_string();
                let res = format!("{}: {}", state.prefix, s.repeat(state.multiplier));
                output.insert(row, res.as_str());
            }
            Ok(())
        }

        fn signatures() -> Vec<ScalarFunctionSignature> {
            vec![ScalarFunctionSignature::exact(
                vec![LogicalTypeId::Varchar.into()],
                LogicalTypeId::Varchar.into(),
            )]
        }
    }

    /// `nobie_repeat(text, count)`: repeats `text` `count` times (empty for negative counts).
    struct Repeat {}

    impl VScalar for Repeat {
        type State = ();

        unsafe fn invoke(
            _: &Self::State,
            input: &mut DataChunkHandle,
            output: &mut dyn WritableVector,
        ) -> Result<(), Box<dyn std::error::Error>> {
            let output = output.flat_vector();
            let counts = input.flat_vector(1);
            let values = input.flat_vector(0);
            let values = values.as_slice_with_len::<duckdb_string_t>(input.len());
            let strings = values
                .iter()
                .map(|ptr| DuckString::new(&mut { *ptr }).as_str().to_string());
            let counts = counts.as_slice_with_len::<i32>(input.len());
            for (row, (count, value)) in counts.iter().zip(strings).enumerate() {
                // A plain `as usize` would turn a negative count into a huge repeat factor.
                let times = usize::try_from(*count).unwrap_or(0);
                output.insert(row, value.repeat(times).as_str());
            }

            Ok(())
        }

        fn signatures() -> Vec<ScalarFunctionSignature> {
            vec![ScalarFunctionSignature::exact(
                vec![
                    LogicalTypeHandle::from(LogicalTypeId::Varchar),
                    LogicalTypeHandle::from(LogicalTypeId::Integer),
                ],
                LogicalTypeHandle::from(LogicalTypeId::Varchar),
            )]
        }
    }

    #[test]
    fn test_scalar() -> Result<(), Box<dyn Error>> {
        let conn = Connection::open_in_memory()?;

        {
            conn.register_scalar_function::<EchoScalar>("echo")?;

            let mut stmt = conn.prepare("select echo('x')")?;
            let mut rows = stmt.query([])?;

            while let Some(row) = rows.next()? {
                let res: String = row.get(0)?;
                assert_eq!(res, "default: xxx");
            }
        }

        {
            conn.register_scalar_function_with_state::<EchoScalar>(
                "echo2",
                &TestState {
                    multiplier: 5,
                    prefix: "custom".to_string(),
                },
            )?;

            let mut stmt = conn.prepare("select echo2('y')")?;
            let mut rows = stmt.query([])?;

            while let Some(row) = rows.next()? {
                let res: String = row.get(0)?;
                assert_eq!(res, "custom: yyyyy");
            }
        }

        Ok(())
    }

    /// Non-constant arguments produce chunks with several rows, so every row's result must be
    /// written to its own index.
    #[test]
    fn test_scalar_multiple_rows() -> Result<(), Box<dyn Error>> {
        let conn = Connection::open_in_memory()?;
        conn.register_scalar_function::<EchoScalar>("echo")?;
        conn.register_scalar_function::<Repeat>("nobie_repeat")?;

        let echoed: Vec<String> = conn
            .prepare("SELECT echo(v) FROM (VALUES ('a'), ('b'), ('c')) t(v) ORDER BY v")?
            .query_map([], |row| row.get(0))?
            .collect::<Result<_, _>>()?;
        assert_eq!(echoed, ["default: aaa", "default: bbb", "default: ccc"]);

        let repeated: Vec<String> = conn
            .prepare("SELECT nobie_repeat(v, n) FROM (VALUES ('a', 1), ('b', 2), ('c', 3)) t(v, n) ORDER BY v")?
            .query_map([], |row| row.get(0))?
            .collect::<Result<_, _>>()?;
        assert_eq!(repeated, ["a", "bb", "ccc"]);

        Ok(())
    }

    #[test]
    fn test_scalar_error() -> Result<(), Box<dyn Error>> {
        let conn = Connection::open_in_memory()?;
        conn.register_scalar_function::<ErrorScalar>("error_udf")?;

        let mut stmt = conn.prepare("select error_udf('blurg') as hello")?;
        if let Err(err) = stmt.query([]) {
            assert!(err.to_string().contains("Error: blurg"));
        } else {
            panic!("Expected an error");
        }

        Ok(())
    }

    #[test]
    fn test_repeat_scalar() -> Result<(), Box<dyn Error>> {
        let conn = Connection::open_in_memory()?;
        conn.register_scalar_function::<Repeat>("nobie_repeat")?;

        let batches = conn
            .prepare("select nobie_repeat('Ho ho ho 🎅🎄', 3) as message from range(5)")?
            .query_arrow([])?
            .collect::<Vec<_>>();

        for batch in batches.iter() {
            let array = batch.column(0);
            let array = array.as_any().downcast_ref::<::arrow::array::StringArray>().unwrap();
            for i in 0..array.len() {
                assert_eq!(array.value(i), "Ho ho ho 🎅🎄Ho ho ho 🎅🎄Ho ho ho 🎅🎄");
            }
        }

        Ok(())
    }

    static VOLATILE_COUNTER: AtomicU64 = AtomicU64::new(0);
    static NON_VOLATILE_COUNTER: AtomicU64 = AtomicU64::new(0);

    /// Non-volatile function that returns an increasing counter value per output row.
    struct CounterScalar {}

    impl VScalar for CounterScalar {
        type State = ();

        unsafe fn invoke(
            _: &Self::State,
            input: &mut DataChunkHandle,
            output: &mut dyn WritableVector,
        ) -> Result<(), Box<dyn std::error::Error>> {
            let len = input.len();
            let mut output_vec = output.flat_vector();
            let data = output_vec.as_mut_slice::<i64>();

            for item in data.iter_mut().take(len) {
                *item = NON_VOLATILE_COUNTER.fetch_add(1, Ordering::SeqCst) as i64;
            }
            Ok(())
        }

        fn signatures() -> Vec<ScalarFunctionSignature> {
            vec![ScalarFunctionSignature::exact(
                vec![],
                LogicalTypeHandle::from(LogicalTypeId::Bigint),
            )]
        }
    }

    /// Volatile counterpart of [`CounterScalar`], backed by its own counter.
    struct VolatileCounterScalar {}

    impl VScalar for VolatileCounterScalar {
        type State = ();

        unsafe fn invoke(
            _: &Self::State,
            input: &mut DataChunkHandle,
            output: &mut dyn WritableVector,
        ) -> Result<(), Box<dyn std::error::Error>> {
            let len = input.len();
            let mut output_vec = output.flat_vector();
            let data = output_vec.as_mut_slice::<i64>();

            for item in data.iter_mut().take(len) {
                *item = VOLATILE_COUNTER.fetch_add(1, Ordering::SeqCst) as i64;
            }
            Ok(())
        }

        fn signatures() -> Vec<ScalarFunctionSignature> {
            vec![ScalarFunctionSignature::exact(
                vec![],
                LogicalTypeHandle::from(LogicalTypeId::Bigint),
            )]
        }

        fn volatile() -> bool {
            true
        }
    }

    #[test]
    fn test_volatile_scalar() -> Result<(), Box<dyn Error>> {
        let conn = Connection::open_in_memory()?;

        VOLATILE_COUNTER.store(0, Ordering::SeqCst);
        conn.register_scalar_function::<VolatileCounterScalar>("volatile_counter")?;

        let values: Vec<i64> = conn
            .prepare("SELECT volatile_counter() FROM generate_series(1, 5)")?
            .query_map([], |row| row.get(0))?
            .collect::<Result<_, _>>()?;

        assert_eq!(values, [0, 1, 2, 3, 4]);

        Ok(())
    }

    #[test]
    fn test_non_volatile_scalar() -> Result<(), Box<dyn Error>> {
        let conn = Connection::open_in_memory()?;

        NON_VOLATILE_COUNTER.store(0, Ordering::SeqCst);
        conn.register_scalar_function::<CounterScalar>("non_volatile_counter")?;

        // A non-volatile call with no arguments may be evaluated once and reused for all rows.
        let distinct_count: i64 = conn
            .prepare("SELECT COUNT(DISTINCT non_volatile_counter()) FROM generate_series(1, 5)")?
            .query_row([], |row| row.get(0))?;

        assert_eq!(distinct_count, 1);

        Ok(())
    }
}