Skip to main content

duckdb/vtab/
mod.rs

1// #![warn(unsafe_op_in_unsafe_fn)]
2
3use std::ffi::c_void;
4
5use crate::{error::Error, inner_connection::InnerConnection, Connection, Result};
6
7use super::ffi;
8
9mod function;
10mod value;
11
12/// The duckdb Arrow table function interface
13#[cfg(feature = "vtab-arrow")]
14pub mod arrow;
15#[cfg(feature = "vtab-arrow")]
16pub use self::arrow::{
17    arrow_arraydata_to_query_params, arrow_ffi_to_query_params, arrow_recordbatch_to_query_params,
18    record_batch_to_duckdb_data_chunk, to_duckdb_logical_type, to_duckdb_type_id,
19};
20#[cfg(feature = "vtab-excel")]
21mod excel;
22
23pub use function::{BindInfo, InitInfo, TableFunction, TableFunctionInfo};
24pub use value::Value;
25
26use crate::core::{DataChunkHandle, LogicalTypeHandle};
27use ffi::{duckdb_bind_info, duckdb_data_chunk, duckdb_function_info, duckdb_init_info};
28
/// Frees a heap allocation that was previously leaked through `Box::into_raw`.
///
/// The `Box<T>` behind `v` is reconstructed and dropped, which also drops the `T`
/// stored inside it.
///
/// # Safety
/// `v` must be a valid pointer to a `Box<T>` created by `Box::into_raw`.
unsafe extern "C" fn drop_boxed<T>(v: *mut c_void) {
    // SAFETY: the caller guarantees that `v` originates from `Box::into_raw` on a `Box<T>`.
    let reclaimed: Box<T> = unsafe { Box::from_raw(v as *mut T) };
    drop(reclaimed);
}
36
37/// Duckdb table function trait
38///
39/// See to the HelloVTab example for more details
40/// <https://duckdb.org/docs/api/c/table_functions>
41pub trait VTab: Sized {
42    /// The data type of the init data.
43    ///
44    /// The init data tracks the state of the table function and is global across threads.
45    ///
46    /// The init data is shared across threads so must be `Send + Sync`.
47    type InitData: Sized + Send + Sync;
48
49    /// The data type of the bind data.
50    ///
51    /// The bind data is shared across threads so must be `Send + Sync`.
52    type BindData: Sized + Send + Sync;
53
54    /// Bind data to the table function
55    ///
56    /// This function is used for determining the return type of a table producing function and returning bind data
57    fn bind(bind: &BindInfo) -> Result<Self::BindData, Box<dyn std::error::Error>>;
58
59    /// Initialize the table function
60    fn init(init: &InitInfo) -> Result<Self::InitData, Box<dyn std::error::Error>>;
61
62    /// Generate rows from the table function.
63    ///
64    /// The implementation should populate the `output` parameter with the rows to be returned.
65    ///
66    /// When the table function is done, the implementation should set the length of the output to 0.
67    fn func(func: &TableFunctionInfo<Self>, output: &mut DataChunkHandle) -> Result<(), Box<dyn std::error::Error>>;
68
69    /// Does the table function support pushdown
70    /// default is false
71    fn supports_pushdown() -> bool {
72        false
73    }
74    /// The parameters of the table function
75    /// default is None
76    fn parameters() -> Option<Vec<LogicalTypeHandle>> {
77        None
78    }
79    /// The named parameters of the table function
80    /// default is None
81    fn named_parameters() -> Option<Vec<(String, LogicalTypeHandle)>> {
82        None
83    }
84}
85
86unsafe extern "C" fn func<T>(info: duckdb_function_info, output: duckdb_data_chunk)
87where
88    T: VTab,
89{
90    let info = TableFunctionInfo::<T>::from(info);
91    let mut data_chunk_handle = DataChunkHandle::new_unowned(output);
92    let result = T::func(&info, &mut data_chunk_handle);
93    if let Err(e) = result {
94        info.set_error(&e.to_string());
95    }
96}
97
98unsafe extern "C" fn init<T>(info: duckdb_init_info)
99where
100    T: VTab,
101{
102    let info = InitInfo::from(info);
103    match T::init(&info) {
104        Ok(init_data) => {
105            info.set_init_data(
106                Box::into_raw(Box::new(init_data)) as *mut c_void,
107                Some(drop_boxed::<T::InitData>),
108            );
109        }
110        Err(e) => {
111            info.set_error(&e.to_string());
112        }
113    }
114}
115
116unsafe extern "C" fn bind<T>(info: duckdb_bind_info)
117where
118    T: VTab,
119{
120    let info = BindInfo::from(info);
121    match T::bind(&info) {
122        Ok(bind_data) => {
123            info.set_bind_data(
124                Box::into_raw(Box::new(bind_data)) as *mut c_void,
125                Some(drop_boxed::<T::BindData>),
126            );
127        }
128        Err(e) => {
129            info.set_error(&e.to_string());
130        }
131    }
132}
133
134impl Connection {
135    /// Register the given TableFunction with the current db
136    #[inline]
137    pub fn register_table_function<T: VTab>(&self, name: &str) -> Result<()> {
138        let table_function = TableFunction::default();
139        table_function
140            .set_name(name)
141            .supports_pushdown(T::supports_pushdown())
142            .set_bind(Some(bind::<T>))
143            .set_init(Some(init::<T>))
144            .set_function(Some(func::<T>));
145        for ty in T::parameters().unwrap_or_default() {
146            table_function.add_parameter(&ty);
147        }
148        for (name, ty) in T::named_parameters().unwrap_or_default() {
149            table_function.add_named_parameter(&name, &ty);
150        }
151        self.db.borrow_mut().register_table_function(table_function)
152    }
153
154    /// Register the given TableFunction with custom extra info.
155    ///
156    /// This allows you to pass extra info that can be accessed during bind, init, and execution
157    /// via `BindInfo::get_extra_info`, `InitInfo::get_extra_info`, or `TableFunctionInfo::get_extra_info`.
158    ///
159    /// The extra info is cloned once during registration and stored in DuckDB's catalog.
160    #[inline]
161    pub fn register_table_function_with_extra_info<T: VTab, E>(&self, name: &str, extra_info: &E) -> Result<()>
162    where
163        E: Clone + Send + Sync + 'static,
164    {
165        let table_function = TableFunction::default();
166        table_function
167            .set_name(name)
168            .supports_pushdown(T::supports_pushdown())
169            .set_bind(Some(bind::<T>))
170            .set_init(Some(init::<T>))
171            .set_function(Some(func::<T>))
172            .set_extra_info(extra_info.clone());
173        for ty in T::parameters().unwrap_or_default() {
174            table_function.add_parameter(&ty);
175        }
176        for (name, ty) in T::named_parameters().unwrap_or_default() {
177            table_function.add_named_parameter(&name, &ty);
178        }
179        self.db.borrow_mut().register_table_function(table_function)
180    }
181}
182
183impl InnerConnection {
184    /// Register the given TableFunction with the current db
185    pub fn register_table_function(&mut self, table_function: TableFunction) -> Result<()> {
186        unsafe {
187            let rc = ffi::duckdb_register_table_function(self.con, table_function.ptr);
188            if rc != ffi::DuckDBSuccess {
189                return Err(Error::DuckDBFailure(ffi::Error::new(rc), None));
190            }
191        }
192        Ok(())
193    }
194}
195
#[cfg(test)]
mod test {
    use super::*;
    use crate::core::{Inserter, LogicalTypeId};
    use std::{
        error::Error,
        ffi::CString,
        sync::atomic::{AtomicBool, Ordering},
    };

    /// Bind data of the test table functions: the name to greet.
    struct HelloBindData {
        name: String,
    }

    /// Init data of the test table functions: whether the single row was already emitted.
    struct HelloInitData {
        done: AtomicBool,
    }

    /// Declares the single VARCHAR result column used by every test table function.
    fn add_greeting_column(bind: &BindInfo) {
        bind.add_result_column("column0", LogicalTypeHandle::from(LogicalTypeId::Varchar));
    }

    /// Writes `"{greeting} {name}"` as the only row on the first call and an empty
    /// chunk on every later call, using `done` to remember that the row was emitted.
    fn emit_greeting_once(
        done: &AtomicBool,
        output: &mut DataChunkHandle,
        greeting: &str,
        name: &str,
    ) -> Result<(), Box<dyn Error>> {
        if done.swap(true, Ordering::Relaxed) {
            output.set_len(0);
            return Ok(());
        }

        let vector = output.flat_vector(0);
        let text = CString::new(format!("{greeting} {name}"))?;
        vector.insert(0, text);
        output.set_len(1);
        Ok(())
    }

    /// Table function taking the name as its first positional parameter.
    struct HelloVTab;

    impl VTab for HelloVTab {
        type InitData = HelloInitData;
        type BindData = HelloBindData;

        fn bind(bind: &BindInfo) -> Result<Self::BindData, Box<dyn Error>> {
            add_greeting_column(bind);
            Ok(HelloBindData {
                name: bind.get_parameter(0).to_string(),
            })
        }

        fn init(_: &InitInfo) -> Result<Self::InitData, Box<dyn Error>> {
            Ok(HelloInitData {
                done: AtomicBool::new(false),
            })
        }

        fn func(func: &TableFunctionInfo<Self>, output: &mut DataChunkHandle) -> Result<(), Box<dyn Error>> {
            let init_data = func.get_init_data();
            let bind_data = func.get_bind_data();
            emit_greeting_once(&init_data.done, output, "Hello", &bind_data.name)
        }

        fn parameters() -> Option<Vec<LogicalTypeHandle>> {
            Some(vec![LogicalTypeHandle::from(LogicalTypeId::Varchar)])
        }
    }

    /// Table function taking the name through the named parameter `name`.
    struct HelloWithNamedVTab;

    impl VTab for HelloWithNamedVTab {
        type InitData = HelloInitData;
        type BindData = HelloBindData;

        fn bind(bind: &BindInfo) -> Result<Self::BindData, Box<dyn Error>> {
            add_greeting_column(bind);
            let name = bind.get_named_parameter("name").unwrap().to_string();
            // A named parameter that was never declared must not be found.
            assert!(bind.get_named_parameter("unknown_name").is_none());
            Ok(HelloBindData { name })
        }

        fn init(init_info: &InitInfo) -> Result<Self::InitData, Box<dyn Error>> {
            HelloVTab::init(init_info)
        }

        fn func(func: &TableFunctionInfo<Self>, output: &mut DataChunkHandle) -> Result<(), Box<dyn Error>> {
            let init_data = func.get_init_data();
            let bind_data = func.get_bind_data();
            emit_greeting_once(&init_data.done, output, "Hello", &bind_data.name)
        }

        fn named_parameters() -> Option<Vec<(String, LogicalTypeHandle)>> {
            let name_parameter = ("name".to_string(), LogicalTypeHandle::from(LogicalTypeId::Varchar));
            Some(vec![name_parameter])
        }
    }

    #[test]
    fn test_table_function() -> Result<(), Box<dyn Error>> {
        let conn = Connection::open_in_memory()?;
        conn.register_table_function::<HelloVTab>("hello")?;

        let (greeting,) = conn.query_row("select * from hello('duckdb')", [], |row| <(String,)>::try_from(row))?;
        assert_eq!(greeting, "Hello duckdb");

        Ok(())
    }

    #[test]
    fn test_named_table_function() -> Result<(), Box<dyn Error>> {
        let conn = Connection::open_in_memory()?;
        conn.register_table_function::<HelloWithNamedVTab>("hello_named")?;

        let (greeting,) = conn.query_row("select * from hello_named(name = 'duckdb')", [], |row| {
            <(String,)>::try_from(row)
        })?;
        assert_eq!(greeting, "Hello duckdb");

        Ok(())
    }

    /// Table function whose greeting word comes from the `String` extra info it was
    /// registered with.
    struct PrefixVTab;

    impl VTab for PrefixVTab {
        type InitData = HelloInitData;
        type BindData = HelloBindData;

        fn bind(bind: &BindInfo) -> Result<Self::BindData, Box<dyn Error>> {
            add_greeting_column(bind);
            Ok(HelloBindData {
                name: bind.get_parameter(0).to_string(),
            })
        }

        fn init(_: &InitInfo) -> Result<Self::InitData, Box<dyn Error>> {
            Ok(HelloInitData {
                done: AtomicBool::new(false),
            })
        }

        fn func(func: &TableFunctionInfo<Self>, output: &mut DataChunkHandle) -> Result<(), Box<dyn Error>> {
            let init_data = func.get_init_data();
            let bind_data = func.get_bind_data();
            // SAFETY: this table function is only registered with a `String` as its
            // extra info (see `test_table_function_with_extra_info`).
            let prefix = unsafe { &*func.get_extra_info::<String>() };
            emit_greeting_once(&init_data.done, output, prefix, &bind_data.name)
        }

        fn parameters() -> Option<Vec<LogicalTypeHandle>> {
            Some(vec![LogicalTypeHandle::from(LogicalTypeId::Varchar)])
        }
    }

    #[test]
    fn test_table_function_with_extra_info() -> Result<(), Box<dyn Error>> {
        let conn = Connection::open_in_memory()?;
        conn.register_table_function_with_extra_info::<PrefixVTab, _>("greet", &"Howdy".to_string())?;

        let (greeting,) = conn.query_row("select * from greet('partner')", [], |row| <(String,)>::try_from(row))?;
        assert_eq!(greeting, "Howdy partner");

        Ok(())
    }
}