1use crate::arg::{Arg, ArgType, NamedArg};
2use anyhow::Context;
3use arg::Args;
4use arrow::array::RecordBatch;
5use arrow::ffi_stream::FFI_ArrowArrayStream;
6use arrow_utils::{DynamicArrowArrayStreamReader, VecRecordBatchReader};
7use derive_builder::Builder;
8use serde::Serialize;
9use std::borrow::Cow;
10use std::ffi::c_char;
11use std::ptr::null_mut;
12use std::sync::Arc;
13
14pub mod arg;
15mod arrow_utils;
16
17pub unsafe fn create_raw(
27 registry: &FunctionRegistry,
28 arguments: *const i8,
29 named_arguments: *const i8,
30 timezone: *const i8,
31) -> anyhow::Result<Box<dyn TableFunction>> {
32 let arguments = if arguments.is_null() {
33 None
34 } else {
35 unsafe {
36 Some(std::str::from_utf8_unchecked(
37 std::ffi::CStr::from_ptr(arguments as *const c_char).to_bytes(),
38 ))
39 }
40 };
41 let named_arguments = if named_arguments.is_null() {
42 None
43 } else {
44 unsafe {
45 Some(std::str::from_utf8_unchecked(
46 std::ffi::CStr::from_ptr(named_arguments as *const c_char).to_bytes(),
47 ))
48 }
49 };
50 let timezone = unsafe {
51 std::str::from_utf8_unchecked(
52 std::ffi::CStr::from_ptr(timezone as *const c_char).to_bytes(),
53 )
54 };
55 create(registry, arguments, named_arguments, timezone)
56}
57
58pub fn create(
59 registry: &FunctionRegistry,
60 arguments: Option<&str>,
61 named_arguments: Option<&str>,
62 timezone: &str,
63) -> anyhow::Result<Box<dyn TableFunction>> {
64 let create_closure = &(registry.init);
65 let arguments = if let Some(arg) = arguments {
66 serde_json::from_str(arg).context("Failed to parse arguments from JSON")?
67 } else {
68 None
69 };
70 let named_arguments: Vec<NamedArg> = if let Some(arg) = named_arguments {
71 serde_json::from_str(arg).context("Failed to parse named arguments from JSON")?
72 } else {
73 vec![]
74 };
75 let ctx = FunctionContext {
76 arguments,
77 named_arguments: named_arguments
78 .into_iter()
79 .map(|named| (named.name, named.arg))
80 .collect(),
81 timezone: String::from(timezone),
83 };
84 create_closure(ctx)
85}
86
87type TableFunctionInitialize =
88 Arc<dyn Fn(FunctionContext) -> anyhow::Result<Box<dyn TableFunction>>>;
89
90#[derive(Builder, Clone)]
91pub struct FunctionRegistry {
92 #[builder(setter(into))]
93 name: &'static str,
94 init: TableFunctionInitialize,
95 #[builder(setter(strip_option, each(name = "signature", into)))]
96 signatures: Option<Vec<Signature>>,
97 #[builder(default = false)]
98 require_ordered: bool,
99}
100
101impl std::fmt::Debug for FunctionRegistry {
102 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
103 f.debug_struct("FunctionRegistry")
104 .field("name", &self.name)
105 .field("init", &Arc::as_ptr(&self.init))
106 .field("signatures", &self.signatures)
107 .field("require_ordered", &self.require_ordered)
108 .finish()
109 }
110}
111
112impl FunctionRegistry {
113 pub fn name(&self) -> &'static str {
114 self.name
115 }
116
117 pub fn require_ordered(&self) -> bool {
118 self.require_ordered
119 }
120
121 pub fn signatures(&self) -> anyhow::Result<String> {
122 serde_json::to_string(&self.signatures).context("Failed to get signatures")
123 }
124
125 pub fn builder() -> FunctionRegistryBuilder {
126 FunctionRegistryBuilder::default()
127 }
128}
129
130#[derive(Clone, Debug, Serialize, Builder)]
131pub struct Signature {
132 #[builder(setter(each(name = "parameter", into)))]
133 pub(crate) parameters: Vec<Parameter>,
134}
135
136impl Signature {
137 pub fn builder() -> SignatureBuilder {
138 SignatureBuilder::default()
139 }
140
141 pub fn empty() -> Signature {
142 Signature { parameters: vec![] }
143 }
144}
145
146#[derive(Clone, Debug, Serialize)]
147pub struct Parameter {
148 pub(crate) name: Option<String>,
149 pub(crate) default: Option<Arg>,
150 pub(crate) arg_type: ArgType,
151}
152
153impl From<ArgType> for Parameter {
154 fn from(value: ArgType) -> Self {
155 Parameter {
156 name: None,
157 default: None,
158 arg_type: value,
159 }
160 }
161}
162
163impl<NAME, ARG> From<(Option<NAME>, ArgType, Option<ARG>)> for Parameter
164where
165 NAME: Into<Cow<'static, str>>,
166 ARG: Into<Arg>,
167{
168 fn from((name, arg_type, default): (Option<NAME>, ArgType, Option<ARG>)) -> Self {
169 Parameter {
170 name: name.map(|x| x.into().into_owned()),
171 default: default.map(|x| x.into()),
172 arg_type,
173 }
174 }
175}
176
177impl<P> From<Vec<P>> for Signature
178where
179 P: Into<Parameter>,
180{
181 fn from(value: Vec<P>) -> Self {
182 Signature {
183 parameters: value.into_iter().map(|x| x.into()).collect(),
184 }
185 }
186}
187
188pub struct FunctionContext {
189 pub arguments: Option<Args>,
190 pub named_arguments: Vec<(String, Arg)>,
191 pub timezone: String,
192}
193
194pub trait TableFunction {
195 fn process(&mut self, input: RecordBatch) -> anyhow::Result<Option<RecordBatch>>;
196
197 fn finalize(&mut self) -> anyhow::Result<Option<RecordBatch>> {
198 Ok(None)
199 }
200}
201
202pub unsafe fn process_raw(
215 func: &mut Box<dyn TableFunction>,
216 input_stream: i64,
217) -> anyhow::Result<i64> {
218 let mut stream_reader: DynamicArrowArrayStreamReader = unsafe {
219 DynamicArrowArrayStreamReader::from_raw(
220 input_stream as *mut arrow::ffi_stream::FFI_ArrowArrayStream,
221 )
222 .expect("Failed to construct DynamicArrowArrayStreamReader")
223 };
224 if let Some(record_batch) = stream_reader.next() {
225 let record_batch = record_batch.expect("cannot iterate over record batch");
226 let Some(output) = func.process(record_batch)? else {
227 return Ok(null_mut::<FFI_ArrowArrayStream>() as i64);
228 };
229 let boxed = Box::new(FFI_ArrowArrayStream::new(VecRecordBatchReader::new(vec![
230 output,
231 ])));
232 return Ok(Box::into_raw(boxed) as i64);
233 }
234
235 Ok(null_mut::<FFI_ArrowArrayStream>() as i64)
236}
237
238pub unsafe fn finalize_raw(func: &mut Box<dyn TableFunction>) -> anyhow::Result<i64> {
248 let Some(output) = func.finalize()? else {
249 return Ok(null_mut::<FFI_ArrowArrayStream>() as i64);
250 };
251 let boxed = Box::new(FFI_ArrowArrayStream::new(VecRecordBatchReader::new(vec![
252 output,
253 ])));
254 Ok(Box::into_raw(boxed) as i64)
255}
256
257pub fn anyhow_error_to_string(error: &anyhow::Error) -> String {
258 format!("{error:?}")
259}