uni_query/query/executor/
plugin_adapter.rs1use std::sync::{Arc, OnceLock};
18
19use arrow_array::{Array, BooleanArray, Float64Array, Int64Array, LargeBinaryArray, StringArray};
20use arrow_schema::DataType;
21use datafusion::logical_expr::{ColumnarValue, Volatility};
22use uni_common::Value;
23use uni_plugin::FnError;
24use uni_plugin::traits::scalar::{ArgType, FnSignature, NullHandling, ScalarPluginFn};
25
26use uni_query_functions::custom_functions::CustomScalarFn;
27
28pub struct ValueRowFn {
40 name: String,
41 signature: OnceLock<FnSignature>,
42 inner: CustomScalarFn,
43}
44
45impl ValueRowFn {
46 #[must_use]
48 pub fn new(name: impl Into<String>, inner: CustomScalarFn) -> Self {
49 Self {
50 name: name.into(),
51 signature: OnceLock::new(),
52 inner,
53 }
54 }
55}
56
57impl std::fmt::Debug for ValueRowFn {
58 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59 f.debug_struct("ValueRowFn")
60 .field("name", &self.name)
61 .finish_non_exhaustive()
62 }
63}
64
65impl ScalarPluginFn for ValueRowFn {
66 fn signature(&self) -> &FnSignature {
67 self.signature.get_or_init(|| FnSignature {
71 args: vec![ArgType::Variadic(Box::new(ArgType::CypherValue))],
74 returns: ArgType::CypherValue,
75 volatility: Volatility::Volatile,
76 null_handling: NullHandling::UserHandled,
77 })
78 }
79
80 fn invoke(&self, args: &[ColumnarValue], rows: usize) -> Result<ColumnarValue, FnError> {
81 let materialized: Vec<Vec<Value>> = args
83 .iter()
84 .map(|c| columnar_to_values(c, rows))
85 .collect::<Result<Vec<_>, _>>()?;
86
87 let mut out_values: Vec<Value> = Vec::with_capacity(rows);
88 for row in 0..rows {
89 let mut row_args: Vec<Value> = Vec::with_capacity(materialized.len());
90 for col in &materialized {
91 row_args.push(col[row].clone());
92 }
93 let v = (self.inner)(&row_args).map_err(|e| {
94 FnError::new(
95 0x1000,
96 format!("legacy scalar fn `{}` failed: {e}", self.name),
97 )
98 })?;
99 out_values.push(v);
100 }
101
102 values_to_large_binary(&out_values)
104 }
105}
106
107fn columnar_to_values(c: &ColumnarValue, rows: usize) -> Result<Vec<Value>, FnError> {
109 match c {
110 ColumnarValue::Scalar(s) => {
111 let v = scalar_to_value(s);
112 Ok(vec![v; rows])
113 }
114 ColumnarValue::Array(arr) => array_to_values(arr.as_ref()),
115 }
116}
117
118fn scalar_to_value(s: &datafusion::scalar::ScalarValue) -> Value {
119 use datafusion::scalar::ScalarValue;
120 match s {
121 ScalarValue::Null => Value::Null,
122 ScalarValue::Boolean(Some(b)) => Value::Bool(*b),
123 ScalarValue::Boolean(None) => Value::Null,
124 ScalarValue::Int64(Some(i)) => Value::Int(*i),
125 ScalarValue::Int64(None) => Value::Null,
126 ScalarValue::Float64(Some(f)) => Value::Float(*f),
127 ScalarValue::Float64(None) => Value::Null,
128 ScalarValue::Utf8(Some(s)) => Value::String(s.clone()),
129 ScalarValue::Utf8(None) => Value::Null,
130 ScalarValue::LargeBinary(Some(bytes)) => decode_cypher_value(bytes).unwrap_or(Value::Null),
131 ScalarValue::LargeBinary(None) => Value::Null,
132 _ => Value::String(s.to_string()),
136 }
137}
138
139fn array_to_values(arr: &dyn Array) -> Result<Vec<Value>, FnError> {
140 let n = arr.len();
141 let mut out = Vec::with_capacity(n);
142
143 match arr.data_type() {
144 DataType::Boolean => {
145 let a = arr.as_any().downcast_ref::<BooleanArray>().ok_or_else(|| {
146 FnError::new(FnError::CODE_TYPE_COERCION, "expected BooleanArray")
147 })?;
148 for i in 0..n {
149 out.push(if a.is_null(i) {
150 Value::Null
151 } else {
152 Value::Bool(a.value(i))
153 });
154 }
155 }
156 DataType::Int64 => {
157 let a = arr
158 .as_any()
159 .downcast_ref::<Int64Array>()
160 .ok_or_else(|| FnError::new(FnError::CODE_TYPE_COERCION, "expected Int64Array"))?;
161 for i in 0..n {
162 out.push(if a.is_null(i) {
163 Value::Null
164 } else {
165 Value::Int(a.value(i))
166 });
167 }
168 }
169 DataType::Float64 => {
170 let a = arr.as_any().downcast_ref::<Float64Array>().ok_or_else(|| {
171 FnError::new(FnError::CODE_TYPE_COERCION, "expected Float64Array")
172 })?;
173 for i in 0..n {
174 out.push(if a.is_null(i) {
175 Value::Null
176 } else {
177 Value::Float(a.value(i))
178 });
179 }
180 }
181 DataType::Utf8 => {
182 let a = arr
183 .as_any()
184 .downcast_ref::<StringArray>()
185 .ok_or_else(|| FnError::new(FnError::CODE_TYPE_COERCION, "expected StringArray"))?;
186 for i in 0..n {
187 out.push(if a.is_null(i) {
188 Value::Null
189 } else {
190 Value::String(a.value(i).to_owned())
191 });
192 }
193 }
194 DataType::LargeBinary => {
195 let a = arr
196 .as_any()
197 .downcast_ref::<LargeBinaryArray>()
198 .ok_or_else(|| {
199 FnError::new(FnError::CODE_TYPE_COERCION, "expected LargeBinaryArray")
200 })?;
201 for i in 0..n {
202 out.push(if a.is_null(i) {
203 Value::Null
204 } else {
205 decode_cypher_value(a.value(i)).unwrap_or(Value::Null)
206 });
207 }
208 }
209 other => {
210 return Err(FnError::new(
211 FnError::CODE_TYPE_COERCION,
212 format!("unsupported arrow type in legacy adapter: {other:?}"),
213 ));
214 }
215 }
216
217 Ok(out)
218}
219
220fn values_to_large_binary(values: &[Value]) -> Result<ColumnarValue, FnError> {
221 let mut builder = arrow_array::builder::LargeBinaryBuilder::with_capacity(values.len(), 0);
222 for v in values {
223 match v {
224 Value::Null => builder.append_null(),
225 _ => {
226 let bytes = encode_cypher_value(v)?;
227 builder.append_value(&bytes);
228 }
229 }
230 }
231 Ok(ColumnarValue::Array(Arc::new(builder.finish())))
232}
233
234fn encode_cypher_value(v: &Value) -> Result<Vec<u8>, FnError> {
235 Ok(uni_common::cypher_value_codec::encode(v))
243}
244
245fn decode_cypher_value(bytes: &[u8]) -> Option<Value> {
246 uni_common::cypher_value_codec::decode(bytes).ok()
247}
248
#[cfg(test)]
mod tests {
    use super::*;
    use uni_common::Value;

    /// Extracts the `LargeBinaryArray` that every `ValueRowFn::invoke` must return.
    fn expect_large_binary(out: ColumnarValue) -> LargeBinaryArray {
        let ColumnarValue::Array(arr) = out else {
            panic!("expected array output");
        };
        arr.as_any()
            .downcast_ref::<LargeBinaryArray>()
            .expect("LargeBinaryArray")
            .clone()
    }

    /// Decodes every slot, mapping Arrow null slots to `Value::Null`.
    fn decode_all(lb: &LargeBinaryArray) -> Vec<Value> {
        (0..lb.len())
            .map(|i| {
                if lb.is_null(i) {
                    Value::Null
                } else {
                    decode_cypher_value(lb.value(i)).expect("decodable payload")
                }
            })
            .collect()
    }

    #[test]
    fn value_row_fn_invokes_closure_for_each_row() {
        let closure: CustomScalarFn = Arc::new(|args: &[Value]| {
            match args.first() {
                Some(Value::Int(i)) => Ok(Value::Int(i * 2)),
                _ => Ok(Value::Null),
            }
        });
        let f = ValueRowFn::new("double", closure);
        let input =
            ColumnarValue::Array(Arc::new(Int64Array::from(vec![1_i64, 2, 3])) as Arc<dyn Array>);
        let lb = expect_large_binary(f.invoke(&[input], 3).expect("invoke"));
        assert_eq!(
            decode_all(&lb),
            vec![Value::Int(2), Value::Int(4), Value::Int(6)]
        );
    }

    #[test]
    fn value_row_fn_handles_nulls() {
        let closure: CustomScalarFn =
            Arc::new(|args: &[Value]| Ok(args.first().cloned().unwrap_or(Value::Null)));
        let f = ValueRowFn::new("identity", closure);
        let input = ColumnarValue::Array(
            Arc::new(Int64Array::from(vec![Some(1), None, Some(3)])) as Arc<dyn Array>
        );
        let lb = expect_large_binary(f.invoke(&[input], 3).expect("invoke"));
        assert!(!lb.is_null(0));
        assert!(lb.is_null(1));
        assert!(!lb.is_null(2));
    }

    #[test]
    fn value_row_fn_broadcasts_scalar_arguments() {
        let closure: CustomScalarFn = Arc::new(|args: &[Value]| {
            match args.first() {
                Some(Value::Int(i)) => Ok(Value::Int(i * 2)),
                _ => Ok(Value::Null),
            }
        });
        let f = ValueRowFn::new("double", closure);
        let input = ColumnarValue::Scalar(datafusion::scalar::ScalarValue::Int64(Some(7)));
        let lb = expect_large_binary(f.invoke(&[input], 2).expect("invoke"));
        assert_eq!(decode_all(&lb), vec![Value::Int(14), Value::Int(14)]);
    }

    #[test]
    fn value_row_fn_handles_zero_rows() {
        let closure: CustomScalarFn =
            Arc::new(|args: &[Value]| Ok(args.first().cloned().unwrap_or(Value::Null)));
        let f = ValueRowFn::new("identity", closure);
        let input =
            ColumnarValue::Array(Arc::new(Int64Array::from(Vec::<i64>::new())) as Arc<dyn Array>);
        let lb = expect_large_binary(f.invoke(&[input], 0).expect("invoke"));
        assert_eq!(lb.len(), 0);
    }
}