// duckdb/vscalar/mod.rs

1use std::ffi::CString;
2
3use function::{ScalarFunction, ScalarFunctionSet};
4use libduckdb_sys::{
5    duckdb_data_chunk, duckdb_function_info, duckdb_scalar_function_get_extra_info, duckdb_scalar_function_set_error,
6    duckdb_vector,
7};
8
9use crate::{
10    core::{DataChunkHandle, LogicalTypeHandle},
11    inner_connection::InnerConnection,
12    vtab::arrow::WritableVector,
13    Connection,
14};
15mod function;
16
17/// The duckdb Arrow scalar function interface
18#[cfg(feature = "vscalar-arrow")]
19pub mod arrow;
20
21#[cfg(feature = "vscalar-arrow")]
22pub use arrow::{ArrowFunctionSignature, ArrowScalarParams, VArrowScalar};
23
24/// Duckdb scalar function trait
25pub trait VScalar: Sized {
26    /// State set at registration time. Persists for the lifetime of the catalog entry.
27    /// Shared across worker threads and invocations — must not be modified during execution.
28    /// Must be `'static` as it is stored in DuckDB and may outlive the current stack frame.
29    type State: Sized + Send + Sync + 'static;
30    /// The actual function
31    ///
32    /// # Safety
33    ///
34    /// This function is unsafe because it:
35    ///
36    /// - Dereferences multiple raw pointers (`func`).
37    ///
38    unsafe fn invoke(
39        state: &Self::State,
40        input: &mut DataChunkHandle,
41        output: &mut dyn WritableVector,
42    ) -> Result<(), Box<dyn std::error::Error>>;
43
44    /// The possible signatures of the scalar function.
45    /// These will result in DuckDB scalar function overloads.
46    /// The invoke method should be able to handle all of these signatures.
47    fn signatures() -> Vec<ScalarFunctionSignature>;
48
49    /// Whether the scalar function is volatile.
50    ///
51    /// Volatile functions are re-evaluated for each row, even if they have no parameters.
52    /// This is useful for functions that generate random or unique values, such as random
53    /// number generators, UUID generators, or fake data generators.
54    ///
55    /// By default, DuckDB optimizes zero-argument scalar functions as constants, evaluating
56    /// them only once. Returning true from this method prevents this optimization.
57    ///
58    /// # Default
59    /// Returns `false` by default, meaning the function is not volatile.
60    fn volatile() -> bool {
61        false
62    }
63}
64
65/// Duckdb scalar function parameters
66pub enum ScalarParams {
67    /// Exact parameters
68    Exact(Vec<LogicalTypeHandle>),
69    /// Variadic parameters
70    Variadic(LogicalTypeHandle),
71}
72
73/// Duckdb scalar function signature
74pub struct ScalarFunctionSignature {
75    parameters: Option<ScalarParams>,
76    return_type: LogicalTypeHandle,
77}
78
79impl ScalarFunctionSignature {
80    /// Create an exact function signature
81    pub fn exact(params: Vec<LogicalTypeHandle>, return_type: LogicalTypeHandle) -> Self {
82        Self {
83            parameters: Some(ScalarParams::Exact(params)),
84            return_type,
85        }
86    }
87
88    /// Create a variadic function signature
89    pub fn variadic(param: LogicalTypeHandle, return_type: LogicalTypeHandle) -> Self {
90        Self {
91            parameters: Some(ScalarParams::Variadic(param)),
92            return_type,
93        }
94    }
95}
96
97impl ScalarFunctionSignature {
98    pub(crate) fn register_with_scalar(&self, f: &ScalarFunction) {
99        f.set_return_type(&self.return_type);
100
101        match &self.parameters {
102            Some(ScalarParams::Exact(params)) => {
103                for param in params.iter() {
104                    f.add_parameter(param);
105                }
106            }
107            Some(ScalarParams::Variadic(param)) => {
108                f.add_variadic_parameter(param);
109            }
110            None => {
111                // do nothing
112            }
113        }
114    }
115}
116
117/// An interface to store and retrieve data during the function execution stage
118#[derive(Debug)]
119struct ScalarFunctionInfo(duckdb_function_info);
120
121impl From<duckdb_function_info> for ScalarFunctionInfo {
122    fn from(ptr: duckdb_function_info) -> Self {
123        Self(ptr)
124    }
125}
126
127impl ScalarFunctionInfo {
128    pub unsafe fn get_extra_info<T>(&self) -> &T {
129        &*(duckdb_scalar_function_get_extra_info(self.0).cast())
130    }
131
132    pub unsafe fn set_error(&self, error: &str) {
133        let c_str = CString::new(error).unwrap();
134        duckdb_scalar_function_set_error(self.0, c_str.as_ptr());
135    }
136}
137
138unsafe extern "C" fn scalar_func<T>(info: duckdb_function_info, input: duckdb_data_chunk, mut output: duckdb_vector)
139where
140    T: VScalar,
141{
142    let info = ScalarFunctionInfo::from(info);
143    let mut input = DataChunkHandle::new_unowned(input);
144    let result = T::invoke(info.get_extra_info(), &mut input, &mut output);
145    if let Err(e) = result {
146        info.set_error(&e.to_string());
147    }
148}
149
150impl Connection {
151    /// Register the given ScalarFunction with default state.
152    #[inline]
153    pub fn register_scalar_function<S: VScalar>(&self, name: &str) -> crate::Result<()>
154    where
155        S::State: Default,
156    {
157        let set = ScalarFunctionSet::new(name);
158        for signature in S::signatures() {
159            let scalar_function = ScalarFunction::new(name)?;
160            signature.register_with_scalar(&scalar_function);
161            scalar_function.set_function(Some(scalar_func::<S>));
162            if S::volatile() {
163                scalar_function.set_volatile();
164            }
165            scalar_function.set_extra_info(S::State::default());
166            set.add_function(scalar_function)?;
167        }
168        self.db.borrow_mut().register_scalar_function_set(set)
169    }
170
171    /// Register the given ScalarFunction with custom state.
172    ///
173    /// The state is cloned once per function signature (overload) and stored in DuckDB's catalog.
174    #[inline]
175    pub fn register_scalar_function_with_state<S: VScalar>(&self, name: &str, state: &S::State) -> crate::Result<()>
176    where
177        S::State: Clone,
178    {
179        let set = ScalarFunctionSet::new(name);
180        for signature in S::signatures() {
181            let scalar_function = ScalarFunction::new(name)?;
182            signature.register_with_scalar(&scalar_function);
183            scalar_function.set_function(Some(scalar_func::<S>));
184            if S::volatile() {
185                scalar_function.set_volatile();
186            }
187            scalar_function.set_extra_info(state.clone());
188            set.add_function(scalar_function)?;
189        }
190        self.db.borrow_mut().register_scalar_function_set(set)
191    }
192}
193
194impl InnerConnection {
195    /// Register the given ScalarFunction with the current db
196    pub fn register_scalar_function_set(&mut self, f: ScalarFunctionSet) -> crate::Result<()> {
197        f.register_with_connection(self.con)
198    }
199}
200
201#[cfg(test)]
202mod test {
203    use std::error::Error;
204
205    use arrow::array::Array;
206    use libduckdb_sys::duckdb_string_t;
207
208    use crate::{
209        core::{DataChunkHandle, Inserter, LogicalTypeHandle, LogicalTypeId},
210        types::DuckString,
211        vtab::arrow::WritableVector,
212        Connection,
213    };
214
215    use super::{ScalarFunctionSignature, VScalar};
216
217    struct ErrorScalar {}
218
219    impl VScalar for ErrorScalar {
220        type State = ();
221
222        unsafe fn invoke(
223            _: &Self::State,
224            input: &mut DataChunkHandle,
225            _: &mut dyn WritableVector,
226        ) -> Result<(), Box<dyn std::error::Error>> {
227            let mut msg = input.flat_vector(0).as_slice_with_len::<duckdb_string_t>(input.len())[0];
228            let string = DuckString::new(&mut msg).as_str();
229            Err(format!("Error: {string}").into())
230        }
231
232        fn signatures() -> Vec<ScalarFunctionSignature> {
233            vec![ScalarFunctionSignature::exact(
234                vec![LogicalTypeId::Varchar.into()],
235                LogicalTypeId::Varchar.into(),
236            )]
237        }
238    }
239
240    #[derive(Debug, Clone)]
241    struct TestState {
242        multiplier: usize,
243        prefix: String,
244    }
245
246    impl Default for TestState {
247        fn default() -> Self {
248            Self {
249                multiplier: 3,
250                prefix: "default".to_string(),
251            }
252        }
253    }
254
255    struct EchoScalar {}
256
257    impl VScalar for EchoScalar {
258        type State = TestState;
259
260        unsafe fn invoke(
261            state: &Self::State,
262            input: &mut DataChunkHandle,
263            output: &mut dyn WritableVector,
264        ) -> Result<(), Box<dyn std::error::Error>> {
265            let values = input.flat_vector(0);
266            let values = values.as_slice_with_len::<duckdb_string_t>(input.len());
267            let strings = values
268                .iter()
269                .map(|ptr| DuckString::new(&mut { *ptr }).as_str().to_string())
270                .take(input.len());
271            let output = output.flat_vector();
272
273            for s in strings {
274                let res = format!("{}: {}", state.prefix, s.repeat(state.multiplier));
275                output.insert(0, res.as_str());
276            }
277            Ok(())
278        }
279
280        fn signatures() -> Vec<ScalarFunctionSignature> {
281            vec![ScalarFunctionSignature::exact(
282                vec![LogicalTypeId::Varchar.into()],
283                LogicalTypeId::Varchar.into(),
284            )]
285        }
286    }
287
288    struct Repeat {}
289
290    impl VScalar for Repeat {
291        type State = ();
292
293        unsafe fn invoke(
294            _: &Self::State,
295            input: &mut DataChunkHandle,
296            output: &mut dyn WritableVector,
297        ) -> Result<(), Box<dyn std::error::Error>> {
298            let output = output.flat_vector();
299            let counts = input.flat_vector(1);
300            let values = input.flat_vector(0);
301            let values = values.as_slice_with_len::<duckdb_string_t>(input.len());
302            let strings = values
303                .iter()
304                .map(|ptr| DuckString::new(&mut { *ptr }).as_str().to_string());
305            let counts = counts.as_slice_with_len::<i32>(input.len());
306            for (count, value) in counts.iter().zip(strings).take(input.len()) {
307                output.insert(0, value.repeat((*count) as usize).as_str());
308            }
309
310            Ok(())
311        }
312
313        fn signatures() -> Vec<ScalarFunctionSignature> {
314            vec![ScalarFunctionSignature::exact(
315                vec![
316                    LogicalTypeHandle::from(LogicalTypeId::Varchar),
317                    LogicalTypeHandle::from(LogicalTypeId::Integer),
318                ],
319                LogicalTypeHandle::from(LogicalTypeId::Varchar),
320            )]
321        }
322    }
323
324    #[test]
325    fn test_scalar() -> Result<(), Box<dyn Error>> {
326        let conn = Connection::open_in_memory()?;
327
328        // Test with default state
329        {
330            conn.register_scalar_function::<EchoScalar>("echo")?;
331
332            let mut stmt = conn.prepare("select echo('x')")?;
333            let mut rows = stmt.query([])?;
334
335            while let Some(row) = rows.next()? {
336                let res: String = row.get(0)?;
337                assert_eq!(res, "default: xxx");
338            }
339        }
340
341        // Test with custom state
342        {
343            conn.register_scalar_function_with_state::<EchoScalar>(
344                "echo2",
345                &TestState {
346                    multiplier: 5,
347                    prefix: "custom".to_string(),
348                },
349            )?;
350
351            let mut stmt = conn.prepare("select echo2('y')")?;
352            let mut rows = stmt.query([])?;
353
354            while let Some(row) = rows.next()? {
355                let res: String = row.get(0)?;
356                assert_eq!(res, "custom: yyyyy");
357            }
358        }
359
360        Ok(())
361    }
362
363    #[test]
364    fn test_scalar_error() -> Result<(), Box<dyn Error>> {
365        let conn = Connection::open_in_memory()?;
366        conn.register_scalar_function::<ErrorScalar>("error_udf")?;
367
368        let mut stmt = conn.prepare("select error_udf('blurg') as hello")?;
369        if let Err(err) = stmt.query([]) {
370            assert!(err.to_string().contains("Error: blurg"));
371        } else {
372            panic!("Expected an error");
373        }
374
375        Ok(())
376    }
377
378    #[test]
379    fn test_repeat_scalar() -> Result<(), Box<dyn Error>> {
380        let conn = Connection::open_in_memory()?;
381        conn.register_scalar_function::<Repeat>("nobie_repeat")?;
382
383        let batches = conn
384            .prepare("select nobie_repeat('Ho ho ho 🎅🎄', 3) as message from range(5)")?
385            .query_arrow([])?
386            .collect::<Vec<_>>();
387
388        for batch in batches.iter() {
389            let array = batch.column(0);
390            let array = array.as_any().downcast_ref::<::arrow::array::StringArray>().unwrap();
391            for i in 0..array.len() {
392                assert_eq!(array.value(i), "Ho ho ho 🎅🎄Ho ho ho 🎅🎄Ho ho ho 🎅🎄");
393            }
394        }
395
396        Ok(())
397    }
398
399    // Counters for testing volatile functions
400    use std::sync::atomic::{AtomicU64, Ordering};
401    static VOLATILE_COUNTER: AtomicU64 = AtomicU64::new(0);
402    static NON_VOLATILE_COUNTER: AtomicU64 = AtomicU64::new(0);
403
404    struct CounterScalar {}
405
406    impl VScalar for CounterScalar {
407        type State = ();
408
409        unsafe fn invoke(
410            _: &Self::State,
411            input: &mut DataChunkHandle,
412            output: &mut dyn WritableVector,
413        ) -> Result<(), Box<dyn std::error::Error>> {
414            let len = input.len();
415            let mut output_vec = output.flat_vector();
416            let data = output_vec.as_mut_slice::<i64>();
417
418            for item in data.iter_mut().take(len) {
419                *item = NON_VOLATILE_COUNTER.fetch_add(1, Ordering::SeqCst) as i64;
420            }
421            Ok(())
422        }
423
424        fn signatures() -> Vec<ScalarFunctionSignature> {
425            vec![ScalarFunctionSignature::exact(
426                vec![],
427                LogicalTypeHandle::from(LogicalTypeId::Bigint),
428            )]
429        }
430    }
431
432    struct VolatileCounterScalar {}
433
434    impl VScalar for VolatileCounterScalar {
435        type State = ();
436
437        unsafe fn invoke(
438            _: &Self::State,
439            input: &mut DataChunkHandle,
440            output: &mut dyn WritableVector,
441        ) -> Result<(), Box<dyn std::error::Error>> {
442            let len = input.len();
443            let mut output_vec = output.flat_vector();
444            let data = output_vec.as_mut_slice::<i64>();
445
446            for item in data.iter_mut().take(len) {
447                *item = VOLATILE_COUNTER.fetch_add(1, Ordering::SeqCst) as i64;
448            }
449            Ok(())
450        }
451
452        fn signatures() -> Vec<ScalarFunctionSignature> {
453            vec![ScalarFunctionSignature::exact(
454                vec![],
455                LogicalTypeHandle::from(LogicalTypeId::Bigint),
456            )]
457        }
458
459        fn volatile() -> bool {
460            true
461        }
462    }
463
464    #[test]
465    fn test_volatile_scalar() -> Result<(), Box<dyn Error>> {
466        let conn = Connection::open_in_memory()?;
467
468        VOLATILE_COUNTER.store(0, Ordering::SeqCst);
469        conn.register_scalar_function::<VolatileCounterScalar>("volatile_counter")?;
470
471        let values: Vec<i64> = conn
472            .prepare("SELECT volatile_counter() FROM generate_series(1, 5)")?
473            .query_map([], |row| row.get(0))?
474            .collect::<Result<_, _>>()?;
475
476        assert_eq!(values, [0, 1, 2, 3, 4]);
477
478        Ok(())
479    }
480
481    #[test]
482    fn test_non_volatile_scalar() -> Result<(), Box<dyn Error>> {
483        let conn = Connection::open_in_memory()?;
484
485        NON_VOLATILE_COUNTER.store(0, Ordering::SeqCst);
486        conn.register_scalar_function::<CounterScalar>("non_volatile_counter")?;
487
488        // Constant folding should make every row identical
489        let distinct_count: i64 = conn
490            .prepare("SELECT COUNT(DISTINCT non_volatile_counter()) FROM generate_series(1, 5)")?
491            .query_row([], |row| row.get(0))?;
492
493        assert_eq!(distinct_count, 1);
494
495        Ok(())
496    }
497}