rust_tvtf_api/
lib.rs

1use crate::arg::{Arg, ArgType, NamedArg};
2use anyhow::Context;
3use arg::Args;
4use arrow::array::RecordBatch;
5use arrow::ffi_stream::FFI_ArrowArrayStream;
6use arrow_utils::{DynamicArrowArrayStreamReader, VecRecordBatchReader};
7use derive_builder::Builder;
8use serde::Serialize;
9use std::borrow::Cow;
10use std::ffi::c_char;
11use std::ptr::null_mut;
12use std::sync::Arc;
13
14pub mod arg;
15mod arrow_utils;
16
17/// # SAFETY
18///
19/// This function is unsafe because it dereferences raw pointers.
20///
21/// For `arguments`, it expects a null-terminated UTF-8 string, it may be `nullptr`.
22///
23/// For `named_arguments`, it expects a null-terminated UTF-8 string, it may be `nullptr`.
24///
25/// For `timezone`, it expects a null-terminated UTF-8 string, it must be valid.
26pub unsafe fn create_raw(
27    registry: &FunctionRegistry,
28    arguments: *const i8,
29    named_arguments: *const i8,
30    timezone: *const i8,
31) -> anyhow::Result<Box<dyn TableFunction>> {
32    let arguments = if arguments.is_null() {
33        None
34    } else {
35        unsafe {
36            Some(std::str::from_utf8_unchecked(
37                std::ffi::CStr::from_ptr(arguments as *const c_char).to_bytes(),
38            ))
39        }
40    };
41    let named_arguments = if named_arguments.is_null() {
42        None
43    } else {
44        unsafe {
45            Some(std::str::from_utf8_unchecked(
46                std::ffi::CStr::from_ptr(named_arguments as *const c_char).to_bytes(),
47            ))
48        }
49    };
50    let timezone = unsafe {
51        std::str::from_utf8_unchecked(
52            std::ffi::CStr::from_ptr(timezone as *const c_char).to_bytes(),
53        )
54    };
55    create(registry, arguments, named_arguments, timezone)
56}
57
58pub fn create(
59    registry: &FunctionRegistry,
60    arguments: Option<&str>,
61    named_arguments: Option<&str>,
62    timezone: &str,
63) -> anyhow::Result<Box<dyn TableFunction>> {
64    let create_closure = &(registry.init);
65    let arguments = if let Some(arg) = arguments {
66        serde_json::from_str(arg).context("Failed to parse arguments from JSON")?
67    } else {
68        None
69    };
70    let named_arguments: Vec<NamedArg> = if let Some(arg) = named_arguments {
71        serde_json::from_str(arg).context("Failed to parse named arguments from JSON")?
72    } else {
73        vec![]
74    };
75    let ctx = FunctionContext {
76        arguments,
77        named_arguments: named_arguments
78            .into_iter()
79            .map(|named| (named.name, named.arg))
80            .collect(),
81        // TODO: parse as value instead of string
82        timezone: String::from(timezone),
83    };
84    create_closure(ctx)
85}
86
/// Factory stored in a [`FunctionRegistry`]: builds a new [`TableFunction`] from a
/// [`FunctionContext`].
type TableFunctionInitialize =
    Arc<dyn Fn(FunctionContext) -> anyhow::Result<Box<dyn TableFunction>>>;

/// Registration entry for a table function: its name, how to instantiate it, and the
/// signatures it accepts.
///
/// Construct it with [`FunctionRegistry::builder`] (builder generated by `derive_builder`).
#[derive(Builder)]
pub struct FunctionRegistry {
    /// Function name, exposed through [`FunctionRegistry::name`].
    #[builder(setter(into))]
    name: &'static str,
    /// Initializer that `create` invokes to build each new instance.
    init: TableFunctionInitialize,
    /// Accepted signatures, serialized to JSON by [`FunctionRegistry::signatures`].
    /// The builder adds them one at a time through its `signature` setter.
    #[builder(setter(strip_option, each(name = "signature", into)))]
    signatures: Option<Vec<Signature>>,
}
98
99impl std::fmt::Debug for FunctionRegistry {
100    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
101        f.debug_struct("FunctionRegistry")
102            .field("name", &self.name)
103            .field("init", &Arc::as_ptr(&self.init))
104            .field("signatures", &self.signatures)
105            .finish()
106    }
107}
108
109impl FunctionRegistry {
110    pub fn name(&self) -> &'static str {
111        self.name
112    }
113
114    pub fn signatures(&self) -> anyhow::Result<String> {
115        serde_json::to_string(&self.signatures).context("Failed to get signatures")
116    }
117
118    pub fn builder() -> FunctionRegistryBuilder {
119        FunctionRegistryBuilder::default()
120    }
121}
122
/// One accepted parameter list of a table function.
///
/// Serialized as part of `FunctionRegistry::signatures`. The generated builder adds
/// parameters one at a time through its `parameter` setter.
#[derive(Clone, Debug, Serialize, Builder)]
pub struct Signature {
    /// Parameters, in the order they were added.
    #[builder(setter(each(name = "parameter", into)))]
    pub(crate) parameters: Vec<Parameter>,
}
128
129impl Signature {
130    pub fn builder() -> SignatureBuilder {
131        SignatureBuilder::default()
132    }
133
134    pub fn empty() -> Signature {
135        Signature { parameters: vec![] }
136    }
137}
138
/// A single parameter in a [`Signature`].
#[derive(Clone, Debug, Serialize)]
pub struct Parameter {
    /// Optional parameter name; `None` when unnamed (e.g. when built from a bare `ArgType`).
    pub(crate) name: Option<String>,
    /// Optional default value.
    pub(crate) default: Option<Arg>,
    /// Expected argument type.
    pub(crate) arg_type: ArgType,
}
145
146impl From<ArgType> for Parameter {
147    fn from(value: ArgType) -> Self {
148        Parameter {
149            name: None,
150            default: None,
151            arg_type: value,
152        }
153    }
154}
155
156impl<NAME, ARG> From<(Option<NAME>, ArgType, Option<ARG>)> for Parameter
157where
158    NAME: Into<Cow<'static, str>>,
159    ARG: Into<Arg>,
160{
161    fn from((name, arg_type, default): (Option<NAME>, ArgType, Option<ARG>)) -> Self {
162        Parameter {
163            name: name.map(|x| x.into().into_owned()),
164            default: default.map(|x| x.into()),
165            arg_type,
166        }
167    }
168}
169
170impl<P> From<Vec<P>> for Signature
171where
172    P: Into<Parameter>,
173{
174    fn from(value: Vec<P>) -> Self {
175        Signature {
176            parameters: value.into_iter().map(|x| x.into()).collect(),
177        }
178    }
179}
180
/// Inputs handed to a registry's initializer when `create` instantiates a table function.
pub struct FunctionContext {
    /// Positional arguments decoded from JSON; `None` when `create` received no
    /// arguments JSON.
    pub arguments: Option<Args>,
    /// Named arguments as `(name, value)` pairs, in the order they appeared in the JSON.
    pub named_arguments: Vec<(String, Arg)>,
    /// Timezone string, passed through verbatim (not parsed yet).
    pub timezone: String,
}
186
/// A table function that consumes [`RecordBatch`]es one at a time and may keep state
/// between calls (`&mut self`).
pub trait TableFunction {
    /// Processes one input batch; `Ok(None)` means this input produced no output.
    fn process(&mut self, input: RecordBatch) -> anyhow::Result<Option<RecordBatch>>;

    /// Emits any remaining output. NOTE(review): presumably called once after the last
    /// `process` call — confirm with the host. The default implementation emits nothing.
    fn finalize(&mut self) -> anyhow::Result<Option<RecordBatch>> {
        Ok(None)
    }
}
194
195/// Wrapper over `process` method, the type of `input_stream` is `*mut FFI_ArrowArrayStream`.
196/// Due to the limitation of zngur, use `i64` here as `void*` and casting to `*mut FFI_ArrowArrayStream`.
197///
198/// The input_stream should contain 0 or 1 RecordBatch
199///
200/// Returns may be `nullptr`. Otherwise, returns `*mut FFI_ArrowArrayStream`
201///
202/// # SAFETY
203///
204/// This function is unsafe because it dereferences a raw pointer and
205/// expects the caller to ensure that the pointer is valid and
206/// points to a `Box<dyn TableFunction>`.
207pub unsafe fn process_raw(
208    func: &mut Box<dyn TableFunction>,
209    input_stream: i64,
210) -> anyhow::Result<i64> {
211    let mut stream_reader: DynamicArrowArrayStreamReader = unsafe {
212        DynamicArrowArrayStreamReader::from_raw(
213            input_stream as *mut arrow::ffi_stream::FFI_ArrowArrayStream,
214        )
215        .expect("Failed to construct DynamicArrowArrayStreamReader")
216    };
217    if let Some(record_batch) = stream_reader.next() {
218        let record_batch = record_batch.expect("cannot iterate over record batch");
219        let Some(output) = func.process(record_batch)? else {
220            return Ok(null_mut::<FFI_ArrowArrayStream>() as i64);
221        };
222        let boxed = Box::new(FFI_ArrowArrayStream::new(VecRecordBatchReader::new(vec![
223            output,
224        ])));
225        return Ok(Box::into_raw(boxed) as i64);
226    }
227
228    Ok(null_mut::<FFI_ArrowArrayStream>() as i64)
229}
230
231/// Wrapper over `finalize` method.
232///
233/// Returns may be `nullptr`. Otherwise, returns `i64` as `*mut FFI_ArrowArrayStream`.
234///
235/// # SAFETY
236///
237/// This function is unsafe because it dereferences a raw pointer and
238/// expects the caller to ensure that the pointer is valid and
239/// points to a `Box<dyn TableFunction>`.
240pub unsafe fn finalize_raw(func: &mut Box<dyn TableFunction>) -> anyhow::Result<i64> {
241    let Some(output) = func.finalize()? else {
242        return Ok(null_mut::<FFI_ArrowArrayStream>() as i64);
243    };
244    let boxed = Box::new(FFI_ArrowArrayStream::new(VecRecordBatchReader::new(vec![
245        output,
246    ])));
247    Ok(Box::into_raw(boxed) as i64)
248}