rust_tvtf_api/
lib.rs

1use anyhow::Context;
2use arg::Args;
3use arrow::array::RecordBatch;
4use arrow::ffi_stream::FFI_ArrowArrayStream;
5use arrow_utils::{DynamicArrowArrayStreamReader, VecRecordBatchReader};
6use derive_builder::Builder;
7use serde::Serialize;
8use std::ffi::c_char;
9use std::ptr::null_mut;
10use std::sync::Arc;
11
12use crate::arg::ArgType;
13
14pub mod arg;
15mod arrow_utils;
16
17/// # SAFETY
18///
19/// This function is unsafe because it dereferences raw pointers.
20///
21/// For `parameters`, it expects a null-terminated UTF-8 string, it may be `nullptr`.
22///
23/// For `timezone`, it expects a null-terminated UTF-8 string, it must be valid.
24pub unsafe fn create_raw(
25    registry: &FunctionRegistry,
26    parameters: *const i8,
27    timezone: *const i8,
28) -> anyhow::Result<Box<dyn TableFunction>> {
29    let parameters = if parameters.is_null() {
30        None
31    } else {
32        unsafe {
33            Some(std::str::from_utf8_unchecked(
34                std::ffi::CStr::from_ptr(parameters as *const c_char).to_bytes(),
35            ))
36        }
37    };
38    let timezone = unsafe {
39        std::str::from_utf8_unchecked(
40            std::ffi::CStr::from_ptr(timezone as *const c_char).to_bytes(),
41        )
42    };
43    create(registry, parameters, timezone)
44}
45
46pub fn create(
47    registry: &FunctionRegistry,
48    parameters: Option<&str>,
49    timezone: &str,
50) -> anyhow::Result<Box<dyn TableFunction>> {
51    let create_closure = &(registry.init);
52    let parameters = if let Some(param) = parameters {
53        serde_json::from_str(param).context("serde json failed")?
54    } else {
55        None
56    };
57    let ctx = FunctionContext {
58        parameters,
59        // TODO: parse as value instead of string
60        timezone: String::from(timezone),
61    };
62    create_closure(ctx)
63}
64
65type TableFunctionInitialize =
66    Arc<dyn Fn(FunctionContext) -> anyhow::Result<Box<dyn TableFunction>>>;
67
/// Registration record for one table function: its name, the factory that creates
/// instances of it, and its optional argument signatures.
///
/// Constructed through the `derive_builder`-generated builder (see `FunctionRegistry::builder`).
#[derive(Builder)]
pub struct FunctionRegistry {
    /// Name of the table function, returned by `FunctionRegistry::name`.
    #[builder(setter(into))]
    name: &'static str,
    /// Factory invoked by `create` to build a new function instance.
    init: TableFunctionInitialize,
    /// Accepted argument signatures; the builder exposes a per-item `signature` setter.
    /// Serialized to JSON by `FunctionRegistry::signatures` (`null` when unset).
    #[builder(setter(strip_option, each(name = "signature", into)))]
    signatures: Option<Vec<Signature>>,
}
76
77impl std::fmt::Debug for FunctionRegistry {
78    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
79        f.debug_struct("FunctionRegistry")
80            .field("name", &self.name)
81            .field("init", &Arc::as_ptr(&self.init))
82            .field("signatures", &self.signatures)
83            .finish()
84    }
85}
86
87impl FunctionRegistry {
88    pub fn name(&self) -> &'static str {
89        self.name
90    }
91
92    pub fn signatures(&self) -> anyhow::Result<String> {
93        serde_json::to_string(&self.signatures).context("Failed to get signatures")
94    }
95
96    pub fn builder() -> FunctionRegistryBuilder {
97        FunctionRegistryBuilder::default()
98    }
99}
100
101#[derive(Clone, Debug, Serialize)]
102pub struct Signature {
103    pub args: Vec<ArgType>,
104}
105
106impl From<Vec<ArgType>> for Signature {
107    fn from(value: Vec<ArgType>) -> Self {
108        Signature { args: value }
109    }
110}
111
/// Per-instance context that `create` hands to a registry's initializer.
pub struct FunctionContext {
    /// Parameters deserialized by `create` from its JSON `parameters` string; `None` when no
    /// string was given or the string was the JSON literal `null`.
    pub parameters: Option<Args>,
    /// Timezone exactly as passed to `create`; it is not parsed or validated there.
    pub timezone: String,
}
116
/// A table function that consumes Arrow record batches and may emit record batches.
///
/// Both methods take `&mut self`, so implementations may keep state between calls.
pub trait TableFunction {
    /// Processes one input batch. `Ok(None)` means this input produced no output batch.
    fn process(&mut self, input: RecordBatch) -> anyhow::Result<Option<RecordBatch>>;

    /// Emits any trailing output. Presumably called once after the last `process` call to
    /// flush buffered state — confirm with the host.
    ///
    /// The default implementation produces no output.
    fn finalize(&mut self) -> anyhow::Result<Option<RecordBatch>> {
        Ok(None)
    }
}
124
125/// Wrapper over `process` method, the type of `input_stream` is `*mut FFI_ArrowArrayStream`.
126/// Due to the limitation of zngur, use `i64` here as `void*` and casting to `*mut FFI_ArrowArrayStream`.
127///
128/// The input_stream should contain 0 or 1 RecordBatch
129///
130/// Returns may be `nullptr`. Otherwise, returns `*mut FFI_ArrowArrayStream`
131///
132/// # SAFETY
133///
134/// This function is unsafe because it dereferences a raw pointer and
135/// expects the caller to ensure that the pointer is valid and
136/// points to a `Box<dyn TableFunction>`.
137pub unsafe fn process_raw(
138    func: &mut Box<dyn TableFunction>,
139    input_stream: i64,
140) -> anyhow::Result<i64> {
141    let mut stream_reader: DynamicArrowArrayStreamReader = unsafe {
142        DynamicArrowArrayStreamReader::from_raw(
143            input_stream as *mut arrow::ffi_stream::FFI_ArrowArrayStream,
144        )
145        .expect("Failed to construct DynamicArrowArrayStreamReader")
146    };
147    if let Some(record_batch) = stream_reader.next() {
148        let record_batch = record_batch.expect("cannot iterate over record batch");
149        let Some(output) = func.process(record_batch)? else {
150            return Ok(null_mut::<FFI_ArrowArrayStream>() as i64);
151        };
152        let boxed = Box::new(FFI_ArrowArrayStream::new(VecRecordBatchReader::new(vec![
153            output,
154        ])));
155        return Ok(Box::into_raw(boxed) as i64);
156    }
157
158    Ok(null_mut::<FFI_ArrowArrayStream>() as i64)
159}
160
161/// Wrapper over `finalize` method.
162///
163/// Returns may be `nullptr`. Otherwise, returns `i64` as `*mut FFI_ArrowArrayStream`.
164///
165/// # SAFETY
166///
167/// This function is unsafe because it dereferences a raw pointer and
168/// expects the caller to ensure that the pointer is valid and
169/// points to a `Box<dyn TableFunction>`.
170pub unsafe fn finalize_raw(func: &mut Box<dyn TableFunction>) -> anyhow::Result<i64> {
171    let Some(output) = func.finalize()? else {
172        return Ok(null_mut::<FFI_ArrowArrayStream>() as i64);
173    };
174    let boxed = Box::new(FFI_ArrowArrayStream::new(VecRecordBatchReader::new(vec![
175        output,
176    ])));
177    Ok(Box::into_raw(boxed) as i64)
178}