1use anyhow::Context;
2use arg::Args;
3use arrow::array::RecordBatch;
4use arrow::ffi_stream::FFI_ArrowArrayStream;
5use arrow_utils::{DynamicArrowArrayStreamReader, VecRecordBatchReader};
6use derive_builder::Builder;
7use serde::Serialize;
8use std::ffi::c_char;
9use std::ptr::null_mut;
10use std::sync::Arc;
11
12use crate::arg::ArgType;
13
14pub mod arg;
15mod arrow_utils;
16
17pub unsafe fn create_raw(
25 registry: &FunctionRegistry,
26 parameters: *const i8,
27 timezone: *const i8,
28) -> anyhow::Result<Box<dyn TableFunction>> {
29 let parameters = if parameters.is_null() {
30 None
31 } else {
32 unsafe {
33 Some(std::str::from_utf8_unchecked(
34 std::ffi::CStr::from_ptr(parameters as *const c_char).to_bytes(),
35 ))
36 }
37 };
38 let timezone = unsafe {
39 std::str::from_utf8_unchecked(
40 std::ffi::CStr::from_ptr(timezone as *const c_char).to_bytes(),
41 )
42 };
43 create(registry, parameters, timezone)
44}
45
46pub fn create(
47 registry: &FunctionRegistry,
48 parameters: Option<&str>,
49 timezone: &str,
50) -> anyhow::Result<Box<dyn TableFunction>> {
51 let create_closure = &(registry.init);
52 let parameters = if let Some(param) = parameters {
53 serde_json::from_str(param).context("serde json failed")?
54 } else {
55 None
56 };
57 let ctx = FunctionContext {
58 parameters,
59 timezone: String::from(timezone),
61 };
62 create_closure(ctx)
63}
64
/// Reference-counted factory closure: receives a [`FunctionContext`] and
/// returns either a boxed [`TableFunction`] or an error.
type TableFunctionInitialize =
    Arc<dyn Fn(FunctionContext) -> anyhow::Result<Box<dyn TableFunction>>>;
67
/// Registration record for one table function: its name, the closure that
/// instantiates it, and an optional list of signatures.
///
/// `#[derive(Builder)]` generates the `FunctionRegistryBuilder` type that
/// `FunctionRegistry::builder` returns.
#[derive(Builder)]
pub struct FunctionRegistry {
    /// Name of the function.
    #[builder(setter(into))]
    name: &'static str,
    /// Closure called by `create` to build a new table function instance.
    init: TableFunctionInitialize,
    /// Signatures serialized by `FunctionRegistry::signatures`.
    // NOTE(review): the `each(name = "signature", ...)` option is expected to add a
    // one-at-a-time `signature` setter on the builder; confirm against the
    // derive_builder version in use.
    #[builder(setter(strip_option, each(name = "signature", into)))]
    signatures: Option<Vec<Signature>>,
}
76
77impl std::fmt::Debug for FunctionRegistry {
78 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
79 f.debug_struct("FunctionRegistry")
80 .field("name", &self.name)
81 .field("init", &Arc::as_ptr(&self.init))
82 .field("signatures", &self.signatures)
83 .finish()
84 }
85}
86
87impl FunctionRegistry {
88 pub fn name(&self) -> &'static str {
89 self.name
90 }
91
92 pub fn signatures(&self) -> anyhow::Result<String> {
93 serde_json::to_string(&self.signatures).context("Failed to get signatures")
94 }
95
96 pub fn builder() -> FunctionRegistryBuilder {
97 FunctionRegistryBuilder::default()
98 }
99}
100
/// One function signature, expressed as a list of argument types.
#[derive(Clone, Debug, Serialize)]
pub struct Signature {
    /// Argument types that make up this signature.
    pub args: Vec<ArgType>,
}
105
106impl From<Vec<ArgType>> for Signature {
107 fn from(value: Vec<ArgType>) -> Self {
108 Signature { args: value }
109 }
110}
111
/// Values passed to a function's `init` closure when it is created.
pub struct FunctionContext {
    /// Parsed parameters; `None` when no parameters were supplied.
    pub parameters: Option<Args>,
    /// Timezone string exactly as received by `create`.
    pub timezone: String,
}
116
/// A stateful function that consumes record batches and may emit record
/// batches in return.
pub trait TableFunction {
    /// Handles one input batch.
    ///
    /// Returns `Ok(Some(batch))` when there is output, `Ok(None)` when there
    /// is none, or an error.
    fn process(&mut self, input: RecordBatch) -> anyhow::Result<Option<RecordBatch>>;

    /// Produces any final output.
    ///
    /// The default implementation emits nothing.
    // NOTE(review): presumably called once after the last `process` call; confirm
    // against the caller.
    fn finalize(&mut self) -> anyhow::Result<Option<RecordBatch>> {
        Ok(None)
    }
}
124
125pub unsafe fn process_raw(
138 func: &mut Box<dyn TableFunction>,
139 input_stream: i64,
140) -> anyhow::Result<i64> {
141 let mut stream_reader: DynamicArrowArrayStreamReader = unsafe {
142 DynamicArrowArrayStreamReader::from_raw(
143 input_stream as *mut arrow::ffi_stream::FFI_ArrowArrayStream,
144 )
145 .expect("Failed to construct DynamicArrowArrayStreamReader")
146 };
147 if let Some(record_batch) = stream_reader.next() {
148 let record_batch = record_batch.expect("cannot iterate over record batch");
149 let Some(output) = func.process(record_batch)? else {
150 return Ok(null_mut::<FFI_ArrowArrayStream>() as i64);
151 };
152 let boxed = Box::new(FFI_ArrowArrayStream::new(VecRecordBatchReader::new(vec![
153 output,
154 ])));
155 return Ok(Box::into_raw(boxed) as i64);
156 }
157
158 Ok(null_mut::<FFI_ArrowArrayStream>() as i64)
159}
160
161pub unsafe fn finalize_raw(func: &mut Box<dyn TableFunction>) -> anyhow::Result<i64> {
171 let Some(output) = func.finalize()? else {
172 return Ok(null_mut::<FFI_ArrowArrayStream>() as i64);
173 };
174 let boxed = Box::new(FFI_ArrowArrayStream::new(VecRecordBatchReader::new(vec![
175 output,
176 ])));
177 Ok(Box::into_raw(boxed) as i64)
178}