Skip to main content

uni_query/query/executor/
plugin_adapter.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2024-2026 Dragonscale Team
3
4//! Adapter bridging legacy `CustomScalarFn` closures to `uni-plugin`'s
5//! `ScalarPluginFn` trait.
6//!
7//! M2's facade keeps the public `CustomFunctionRegistry::register` API
8//! intact while routing registrations into a shadow `PluginRegistry`. This
9//! adapter is the bridge: it wraps a `Fn(&[Value]) -> Result<Value>`
10//! closure into a type implementing `ScalarPluginFn` so it can live in the
11//! plugin registry.
12//!
13//! As subsequent M2 commits migrate built-ins to native Arrow signatures
14//! (`ArgType::Primitive`), this row-per-call adapter remains as the slow
15//! path for legacy registrations declaring `ArgType::CypherValue`.
16
17use std::sync::{Arc, OnceLock};
18
19use arrow_array::{Array, BooleanArray, Float64Array, Int64Array, LargeBinaryArray, StringArray};
20use arrow_schema::DataType;
21use datafusion::logical_expr::{ColumnarValue, Volatility};
22use uni_common::Value;
23use uni_plugin::FnError;
24use uni_plugin::traits::scalar::{ArgType, FnSignature, NullHandling, ScalarPluginFn};
25
26use uni_query_functions::custom_functions::CustomScalarFn;
27
28/// `ScalarPluginFn` impl that wraps a legacy `Fn(&[Value]) -> Result<Value>`
29/// closure.
30///
31/// Used by `CustomFunctionRegistry` to populate its shadow `PluginRegistry`.
32/// Iterates rows, converts each row's columns to `Value`s, invokes the
33/// closure, and collects results into a `LargeBinary` column (the legacy
34/// CypherValue transport).
35///
36/// This is the *slow path* — primitive-typed UDFs will go through a
37/// `NativeArrowUdf` (M2 follow-up) that skips the per-row `Value`
38/// round-trip entirely.
39pub struct ValueRowFn {
40    name: String,
41    signature: OnceLock<FnSignature>,
42    inner: CustomScalarFn,
43}
44
45impl ValueRowFn {
46    /// Wrap a legacy closure into a plugin-compatible scalar function.
47    #[must_use]
48    pub fn new(name: impl Into<String>, inner: CustomScalarFn) -> Self {
49        Self {
50            name: name.into(),
51            signature: OnceLock::new(),
52            inner,
53        }
54    }
55}
56
57impl std::fmt::Debug for ValueRowFn {
58    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59        f.debug_struct("ValueRowFn")
60            .field("name", &self.name)
61            .finish_non_exhaustive()
62    }
63}
64
65impl ScalarPluginFn for ValueRowFn {
66    fn signature(&self) -> &FnSignature {
67        // The legacy registry has no signature metadata; we synthesize a
68        // generic `CypherValue → CypherValue` signature that goes through
69        // the LargeBinary transport on the DataFusion side.
70        self.signature.get_or_init(|| FnSignature {
71            // Variadic CypherValue input (the legacy closure shape never
72            // declared arities).
73            args: vec![ArgType::Variadic(Box::new(ArgType::CypherValue))],
74            returns: ArgType::CypherValue,
75            volatility: Volatility::Volatile,
76            null_handling: NullHandling::UserHandled,
77        })
78    }
79
80    fn invoke(&self, args: &[ColumnarValue], rows: usize) -> Result<ColumnarValue, FnError> {
81        // Materialize each ColumnarValue into a row-major Vec<Vec<Value>>.
82        let materialized: Vec<Vec<Value>> = args
83            .iter()
84            .map(|c| columnar_to_values(c, rows))
85            .collect::<Result<Vec<_>, _>>()?;
86
87        let mut out_values: Vec<Value> = Vec::with_capacity(rows);
88        for row in 0..rows {
89            let mut row_args: Vec<Value> = Vec::with_capacity(materialized.len());
90            for col in &materialized {
91                row_args.push(col[row].clone());
92            }
93            let v = (self.inner)(&row_args).map_err(|e| {
94                FnError::new(
95                    0x1000,
96                    format!("legacy scalar fn `{}` failed: {e}", self.name),
97                )
98            })?;
99            out_values.push(v);
100        }
101
102        // Serialize as LargeBinary (the legacy CypherValue transport).
103        values_to_large_binary(&out_values)
104    }
105}
106
107/// Convert a [`ColumnarValue`] to a `Vec<Value>` of length `rows`.
108fn columnar_to_values(c: &ColumnarValue, rows: usize) -> Result<Vec<Value>, FnError> {
109    match c {
110        ColumnarValue::Scalar(s) => {
111            let v = scalar_to_value(s);
112            Ok(vec![v; rows])
113        }
114        ColumnarValue::Array(arr) => array_to_values(arr.as_ref()),
115    }
116}
117
118fn scalar_to_value(s: &datafusion::scalar::ScalarValue) -> Value {
119    use datafusion::scalar::ScalarValue;
120    match s {
121        ScalarValue::Null => Value::Null,
122        ScalarValue::Boolean(Some(b)) => Value::Bool(*b),
123        ScalarValue::Boolean(None) => Value::Null,
124        ScalarValue::Int64(Some(i)) => Value::Int(*i),
125        ScalarValue::Int64(None) => Value::Null,
126        ScalarValue::Float64(Some(f)) => Value::Float(*f),
127        ScalarValue::Float64(None) => Value::Null,
128        ScalarValue::Utf8(Some(s)) => Value::String(s.clone()),
129        ScalarValue::Utf8(None) => Value::Null,
130        ScalarValue::LargeBinary(Some(bytes)) => decode_cypher_value(bytes).unwrap_or(Value::Null),
131        ScalarValue::LargeBinary(None) => Value::Null,
132        // Other types: fall back to displaying as a String so the closure
133        // sees something coherent. A future commit narrows this once the
134        // legacy adapter is purely a transitional code path.
135        _ => Value::String(s.to_string()),
136    }
137}
138
139fn array_to_values(arr: &dyn Array) -> Result<Vec<Value>, FnError> {
140    let n = arr.len();
141    let mut out = Vec::with_capacity(n);
142
143    match arr.data_type() {
144        DataType::Boolean => {
145            let a = arr.as_any().downcast_ref::<BooleanArray>().ok_or_else(|| {
146                FnError::new(FnError::CODE_TYPE_COERCION, "expected BooleanArray")
147            })?;
148            for i in 0..n {
149                out.push(if a.is_null(i) {
150                    Value::Null
151                } else {
152                    Value::Bool(a.value(i))
153                });
154            }
155        }
156        DataType::Int64 => {
157            let a = arr
158                .as_any()
159                .downcast_ref::<Int64Array>()
160                .ok_or_else(|| FnError::new(FnError::CODE_TYPE_COERCION, "expected Int64Array"))?;
161            for i in 0..n {
162                out.push(if a.is_null(i) {
163                    Value::Null
164                } else {
165                    Value::Int(a.value(i))
166                });
167            }
168        }
169        DataType::Float64 => {
170            let a = arr.as_any().downcast_ref::<Float64Array>().ok_or_else(|| {
171                FnError::new(FnError::CODE_TYPE_COERCION, "expected Float64Array")
172            })?;
173            for i in 0..n {
174                out.push(if a.is_null(i) {
175                    Value::Null
176                } else {
177                    Value::Float(a.value(i))
178                });
179            }
180        }
181        DataType::Utf8 => {
182            let a = arr
183                .as_any()
184                .downcast_ref::<StringArray>()
185                .ok_or_else(|| FnError::new(FnError::CODE_TYPE_COERCION, "expected StringArray"))?;
186            for i in 0..n {
187                out.push(if a.is_null(i) {
188                    Value::Null
189                } else {
190                    Value::String(a.value(i).to_owned())
191                });
192            }
193        }
194        DataType::LargeBinary => {
195            let a = arr
196                .as_any()
197                .downcast_ref::<LargeBinaryArray>()
198                .ok_or_else(|| {
199                    FnError::new(FnError::CODE_TYPE_COERCION, "expected LargeBinaryArray")
200                })?;
201            for i in 0..n {
202                out.push(if a.is_null(i) {
203                    Value::Null
204                } else {
205                    decode_cypher_value(a.value(i)).unwrap_or(Value::Null)
206                });
207            }
208        }
209        other => {
210            return Err(FnError::new(
211                FnError::CODE_TYPE_COERCION,
212                format!("unsupported arrow type in legacy adapter: {other:?}"),
213            ));
214        }
215    }
216
217    Ok(out)
218}
219
220fn values_to_large_binary(values: &[Value]) -> Result<ColumnarValue, FnError> {
221    let mut builder = arrow_array::builder::LargeBinaryBuilder::with_capacity(values.len(), 0);
222    for v in values {
223        match v {
224            Value::Null => builder.append_null(),
225            _ => {
226                let bytes = encode_cypher_value(v)?;
227                builder.append_value(&bytes);
228            }
229        }
230    }
231    Ok(ColumnarValue::Array(Arc::new(builder.finish())))
232}
233
234fn encode_cypher_value(v: &Value) -> Result<Vec<u8>, FnError> {
235    // Use the canonical tagged codec — the same encoding every other
236    // consumer in `uni-query` reads via `cypher_value_codec::decode`
237    // (`scan.rs`, `apply.rs`, `df_expr.rs`, `similar_to_expr.rs`).
238    // Previously this was `serde_json::to_vec(v)` which produced raw
239    // textual bytes that downstream readers misinterpreted as tag
240    // bytes (e.g. for `Value::Int(42)` the first byte was ASCII '4' =
241    // 0x34 = 52, surfacing as "unknown CypherValue tag: 52").
242    Ok(uni_common::cypher_value_codec::encode(v))
243}
244
245fn decode_cypher_value(bytes: &[u8]) -> Option<Value> {
246    uni_common::cypher_value_codec::decode(bytes).ok()
247}
248
#[cfg(test)]
mod tests {
    use super::*;
    use uni_common::Value;

    /// The closure runs once per input row and each result round-trips
    /// through the LargeBinary transport.
    #[test]
    fn value_row_fn_invokes_closure_for_each_row() {
        // Doubles the first argument when it is an integer, else yields null.
        let doubler: CustomScalarFn = Arc::new(|args: &[Value]| match args.first() {
            Some(Value::Int(i)) => Ok(Value::Int(i * 2)),
            _ => Ok(Value::Null),
        });
        let adapter = ValueRowFn::new("double", doubler);

        let column: Arc<dyn Array> = Arc::new(Int64Array::from(vec![1_i64, 2, 3]));
        let result = adapter
            .invoke(&[ColumnarValue::Array(column)], 3)
            .expect("invoke");

        // Output is LargeBinary; decode each value.
        let ColumnarValue::Array(array) = result else {
            panic!("expected array output");
        };
        let binary = array
            .as_any()
            .downcast_ref::<LargeBinaryArray>()
            .expect("LargeBinaryArray");
        let mut decoded: Vec<Value> = Vec::with_capacity(binary.len());
        for i in 0..binary.len() {
            decoded.push(decode_cypher_value(binary.value(i)).unwrap());
        }
        assert_eq!(decoded, vec![Value::Int(2), Value::Int(4), Value::Int(6)]);
    }

    /// A null input slot handed back unchanged by the closure comes out as
    /// a null slot in the output column.
    #[test]
    fn value_row_fn_handles_nulls() {
        let identity: CustomScalarFn =
            Arc::new(|args: &[Value]| Ok(args.first().cloned().unwrap_or(Value::Null)));
        let adapter = ValueRowFn::new("identity", identity);

        let column: Arc<dyn Array> = Arc::new(Int64Array::from(vec![Some(1), None, Some(3)]));
        let result = adapter
            .invoke(&[ColumnarValue::Array(column)], 3)
            .expect("invoke");

        let ColumnarValue::Array(array) = result else {
            panic!();
        };
        let binary = array.as_any().downcast_ref::<LargeBinaryArray>().unwrap();
        assert!(!binary.is_null(0));
        assert!(binary.is_null(1));
        assert!(!binary.is_null(2));
    }
}