rust_tvtf_api/
lib.rs

1use crate::arg::{Arg, ArgType, NamedArg};
2use anyhow::Context;
3use arg::Args;
4use arrow::array::RecordBatch;
5use arrow::ffi_stream::FFI_ArrowArrayStream;
6use arrow_utils::{DynamicArrowArrayStreamReader, VecRecordBatchReader};
7use derive_builder::Builder;
8use serde::Serialize;
9use std::borrow::Cow;
10use std::ffi::c_char;
11use std::ptr::null_mut;
12use std::sync::Arc;
13
14pub mod arg;
15mod arrow_utils;
16
17/// # SAFETY
18///
19/// This function is unsafe because it dereferences raw pointers.
20///
21/// For `arguments`, it expects a null-terminated UTF-8 string, it may be `nullptr`.
22///
23/// For `named_arguments`, it expects a null-terminated UTF-8 string, it may be `nullptr`.
24///
25/// For `timezone`, it expects a null-terminated UTF-8 string, it must be valid.
26pub unsafe fn create_raw(
27    registry: &FunctionRegistry,
28    arguments: *const i8,
29    named_arguments: *const i8,
30    timezone: *const i8,
31) -> anyhow::Result<Box<dyn TableFunction>> {
32    let arguments = if arguments.is_null() {
33        None
34    } else {
35        unsafe {
36            Some(std::str::from_utf8_unchecked(
37                std::ffi::CStr::from_ptr(arguments as *const c_char).to_bytes(),
38            ))
39        }
40    };
41    let named_arguments = if named_arguments.is_null() {
42        None
43    } else {
44        unsafe {
45            Some(std::str::from_utf8_unchecked(
46                std::ffi::CStr::from_ptr(named_arguments as *const c_char).to_bytes(),
47            ))
48        }
49    };
50    let timezone = unsafe {
51        std::str::from_utf8_unchecked(
52            std::ffi::CStr::from_ptr(timezone as *const c_char).to_bytes(),
53        )
54    };
55    create(registry, arguments, named_arguments, timezone)
56}
57
58pub fn create(
59    registry: &FunctionRegistry,
60    arguments: Option<&str>,
61    named_arguments: Option<&str>,
62    timezone: &str,
63) -> anyhow::Result<Box<dyn TableFunction>> {
64    let create_closure = &(registry.init);
65    let arguments = if let Some(arg) = arguments {
66        serde_json::from_str(arg).context("Failed to parse arguments from JSON")?
67    } else {
68        None
69    };
70    let named_arguments: Vec<NamedArg> = if let Some(arg) = named_arguments {
71        serde_json::from_str(arg).context("Failed to parse named arguments from JSON")?
72    } else {
73        vec![]
74    };
75    let ctx = FunctionContext {
76        arguments,
77        named_arguments: named_arguments
78            .into_iter()
79            .map(|named| (named.name, named.arg))
80            .collect(),
81        // TODO: parse as value instead of string
82        timezone: String::from(timezone),
83    };
84    create_closure(ctx)
85}
86
87type TableFunctionInitialize =
88    Arc<dyn Fn(FunctionContext) -> anyhow::Result<Box<dyn TableFunction>>>;
89
/// Registration record for a table function.
///
/// It holds the function's name, the initializer that builds instances, optional
/// parameter signatures, and the `require_ordered` flag.
///
/// Construct it with [`FunctionRegistry::builder`].
#[derive(Builder, Clone)]
pub struct FunctionRegistry {
    /// Function name, returned by [`FunctionRegistry::name`].
    #[builder(setter(into))]
    name: &'static str,
    /// Initializer invoked by [`create`] with the parsed [`FunctionContext`].
    init: TableFunctionInitialize,
    /// Optional signatures, serialized to JSON by [`FunctionRegistry::signatures`].
    /// The builder's per-item `signature` setter adds one entry at a time.
    #[builder(setter(strip_option, each(name = "signature", into)))]
    signatures: Option<Vec<Signature>>,
    /// Flag exposed via [`FunctionRegistry::require_ordered`]; `false` unless set on the
    /// builder.
    #[builder(default = false)]
    require_ordered: bool,
}
100
101impl std::fmt::Debug for FunctionRegistry {
102    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
103        f.debug_struct("FunctionRegistry")
104            .field("name", &self.name)
105            .field("init", &Arc::as_ptr(&self.init))
106            .field("signatures", &self.signatures)
107            .field("require_ordered", &self.require_ordered)
108            .finish()
109    }
110}
111
112impl FunctionRegistry {
113    pub fn name(&self) -> &'static str {
114        self.name
115    }
116
117    pub fn require_ordered(&self) -> bool {
118        self.require_ordered
119    }
120
121    pub fn signatures(&self) -> anyhow::Result<String> {
122        serde_json::to_string(&self.signatures).context("Failed to get signatures")
123    }
124
125    pub fn builder() -> FunctionRegistryBuilder {
126        FunctionRegistryBuilder::default()
127    }
128}
129
/// One parameter list of a table function, serialized as part of
/// [`FunctionRegistry::signatures`].
///
/// Build it with [`Signature::builder`], convert it from a `Vec` of parameter-like values,
/// or use [`Signature::empty`].
#[derive(Clone, Debug, Serialize, Builder)]
pub struct Signature {
    /// Parameters in order.
    /// The builder's per-item `parameter` setter appends one entry and accepts anything
    /// `Into<Parameter>`.
    #[builder(setter(each(name = "parameter", into)))]
    pub(crate) parameters: Vec<Parameter>,
}
135
136impl Signature {
137    pub fn builder() -> SignatureBuilder {
138        SignatureBuilder::default()
139    }
140
141    pub fn empty() -> Signature {
142        Signature { parameters: vec![] }
143    }
144}
145
/// A single parameter of a [`Signature`].
///
/// The derived `Serialize` emits fields in declaration order, so reordering the fields
/// changes the JSON produced by [`FunctionRegistry::signatures`].
#[derive(Clone, Debug, Serialize)]
pub struct Parameter {
    /// Parameter name, if any (`None` when built from a bare [`ArgType`]).
    pub(crate) name: Option<String>,
    /// Default value, if any.
    pub(crate) default: Option<Arg>,
    /// Expected argument type.
    pub(crate) arg_type: ArgType,
}
152
153impl From<ArgType> for Parameter {
154    fn from(value: ArgType) -> Self {
155        Parameter {
156            name: None,
157            default: None,
158            arg_type: value,
159        }
160    }
161}
162
163impl<NAME, ARG> From<(Option<NAME>, ArgType, Option<ARG>)> for Parameter
164where
165    NAME: Into<Cow<'static, str>>,
166    ARG: Into<Arg>,
167{
168    fn from((name, arg_type, default): (Option<NAME>, ArgType, Option<ARG>)) -> Self {
169        Parameter {
170            name: name.map(|x| x.into().into_owned()),
171            default: default.map(|x| x.into()),
172            arg_type,
173        }
174    }
175}
176
177impl<P> From<Vec<P>> for Signature
178where
179    P: Into<Parameter>,
180{
181    fn from(value: Vec<P>) -> Self {
182        Signature {
183            parameters: value.into_iter().map(|x| x.into()).collect(),
184        }
185    }
186}
187
/// Inputs handed to a [`FunctionRegistry`]'s initializer when a table function is created.
pub struct FunctionContext {
    /// Positional arguments.
    /// [`create`] sets this to `None` when no JSON was supplied or the JSON was `null`.
    pub arguments: Option<Args>,
    /// `(name, value)` pairs.
    /// [`create`] keeps them in the order they appear in the named-arguments JSON.
    pub named_arguments: Vec<(String, Arg)>,
    /// Timezone string.
    /// [`create`] copies it verbatim, without parsing or validating it.
    pub timezone: String,
}
193
/// A table function instance, as produced by a [`FunctionRegistry`] initializer.
///
/// Both methods take `&mut self`, so implementations may keep state between calls.
pub trait TableFunction {
    /// Consumes one input batch and optionally returns an output batch.
    ///
    /// `Ok(None)` means this call produced no output.
    fn process(&mut self, input: RecordBatch) -> anyhow::Result<Option<RecordBatch>>;

    /// Gives the function a chance to emit trailing output.
    ///
    /// NOTE(review): presumably called once after the last `process` call; confirm with
    /// the host. The default implementation emits nothing.
    fn finalize(&mut self) -> anyhow::Result<Option<RecordBatch>> {
        Ok(None)
    }
}
201
202/// Wrapper over `process` method, the type of `input_stream` is `*mut FFI_ArrowArrayStream`.
203/// Due to the limitation of zngur, use `i64` here as `void*` and casting to `*mut FFI_ArrowArrayStream`.
204///
205/// The input_stream should contain 0 or 1 RecordBatch
206///
207/// Returns may be `nullptr`. Otherwise, returns `*mut FFI_ArrowArrayStream`
208///
209/// # SAFETY
210///
211/// This function is unsafe because it dereferences a raw pointer and
212/// expects the caller to ensure that the pointer is valid and
213/// points to a `Box<dyn TableFunction>`.
214pub unsafe fn process_raw(
215    func: &mut Box<dyn TableFunction>,
216    input_stream: i64,
217) -> anyhow::Result<i64> {
218    let mut stream_reader: DynamicArrowArrayStreamReader = unsafe {
219        DynamicArrowArrayStreamReader::from_raw(
220            input_stream as *mut arrow::ffi_stream::FFI_ArrowArrayStream,
221        )
222        .expect("Failed to construct DynamicArrowArrayStreamReader")
223    };
224    if let Some(record_batch) = stream_reader.next() {
225        let record_batch = record_batch.expect("cannot iterate over record batch");
226        let Some(output) = func.process(record_batch)? else {
227            return Ok(null_mut::<FFI_ArrowArrayStream>() as i64);
228        };
229        let boxed = Box::new(FFI_ArrowArrayStream::new(VecRecordBatchReader::new(vec![
230            output,
231        ])));
232        return Ok(Box::into_raw(boxed) as i64);
233    }
234
235    Ok(null_mut::<FFI_ArrowArrayStream>() as i64)
236}
237
238/// Wrapper over `finalize` method.
239///
240/// Returns may be `nullptr`. Otherwise, returns `i64` as `*mut FFI_ArrowArrayStream`.
241///
242/// # SAFETY
243///
244/// This function is unsafe because it dereferences a raw pointer and
245/// expects the caller to ensure that the pointer is valid and
246/// points to a `Box<dyn TableFunction>`.
247pub unsafe fn finalize_raw(func: &mut Box<dyn TableFunction>) -> anyhow::Result<i64> {
248    let Some(output) = func.finalize()? else {
249        return Ok(null_mut::<FFI_ArrowArrayStream>() as i64);
250    };
251    let boxed = Box::new(FFI_ArrowArrayStream::new(VecRecordBatchReader::new(vec![
252        output,
253    ])));
254    Ok(Box::into_raw(boxed) as i64)
255}
256
/// Formats `error` with its `Debug` representation (`{error:?}`).
///
/// Per anyhow's documentation, `Debug` on `anyhow::Error` renders a multi-line report.
/// It includes the chain of causes, plus a backtrace if one was captured. `Display`
/// shows only the outermost message.
pub fn anyhow_error_to_string(error: &anyhow::Error) -> String {
    format!("{error:?}")
}