Skip to main content

uni_plugin_custom/
scalar.rs

1// Rust guideline compliant
2//! `DeclaredScalarFn` — a [`ScalarPluginFn`] that evaluates a parsed
3//! Cypher expression body row-by-row.
4//!
5//! Constructed by the `uni.plugin.declareFunction` procedure with a
6//! pre-parsed [`Expr`] body and a list of declared argument names. On
7//! every invocation, each row's input columns are decoded into
8//! `uni_common::Value`, bound to the declared parameter names, fed
9//! through the [`crate::eval::eval_expr`] interpreter, and re-encoded
10//! into the output Arrow column.
11
12use std::sync::Arc;
13
14use arrow_array::ArrayRef;
15use arrow_array::builder::{BooleanBuilder, Float64Builder, Int64Builder, StringBuilder};
16use arrow_schema::DataType;
17use datafusion::logical_expr::{ColumnarValue, Volatility};
18use uni_common::Value;
19use uni_cypher::ast::Expr;
20use uni_plugin::FnError;
21use uni_plugin::traits::scalar::{ArgType, FnSignature, NullHandling, ScalarPluginFn};
22
23use crate::decode::{array_value_at, eval_err_to_fn, stringify};
24use crate::eval::eval_expr;
25
26/// A scalar function declared from Cypher via
27/// `uni.plugin.declareFunction`.
28///
29/// Holds a parsed [`Expr`] body, the declared argument names (in
30/// positional order — same order as the columns passed to
31/// [`ScalarPluginFn::invoke`]), and a precomputed [`FnSignature`].
32pub struct DeclaredScalarFn {
33    body: Expr,
34    arg_names: Vec<String>,
35    signature: FnSignature,
36}
37
38impl std::fmt::Debug for DeclaredScalarFn {
39    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
40        f.debug_struct("DeclaredScalarFn")
41            .field("arg_names", &self.arg_names)
42            .field("return_type", &self.signature.returns)
43            .finish_non_exhaustive()
44    }
45}
46
47impl DeclaredScalarFn {
48    /// Construct a declared scalar function.
49    ///
50    /// `arg_names` must align with the positional arguments — element
51    /// 0 of `arg_names` corresponds to column 0 in the invocation's
52    /// `args` slice.
53    #[must_use]
54    pub fn new(body: Expr, arg_names: Vec<String>, signature: FnSignature) -> Self {
55        Self {
56            body,
57            arg_names,
58            signature,
59        }
60    }
61
62    /// Construct a default [`FnSignature`] given an Arrow return type
63    /// and a list of `(name, type)` pairs for arguments.
64    #[must_use]
65    pub fn build_signature(returns: DataType, args: &[(String, DataType)]) -> FnSignature {
66        FnSignature {
67            args: args
68                .iter()
69                .map(|(_, t)| ArgType::Primitive(t.clone()))
70                .collect(),
71            returns: ArgType::Primitive(returns),
72            volatility: Volatility::Volatile,
73            null_handling: NullHandling::UserHandled,
74        }
75    }
76}
77
78impl ScalarPluginFn for DeclaredScalarFn {
79    fn signature(&self) -> &FnSignature {
80        &self.signature
81    }
82
83    fn invoke(&self, args: &[ColumnarValue], rows: usize) -> Result<ColumnarValue, FnError> {
84        if args.len() != self.arg_names.len() {
85            return Err(FnError::new(
86                FnError::CODE_TYPE_COERCION,
87                format!(
88                    "declared scalar fn expected {} args, got {}",
89                    self.arg_names.len(),
90                    args.len()
91                ),
92            ));
93        }
94        let row_count = rows.max(1);
95        let columns: Vec<ArrayRef> = args
96            .iter()
97            .map(|cv| columnar_to_array(cv, row_count))
98            .collect::<Result<_, _>>()?;
99
100        let return_dt = match &self.signature.returns {
101            ArgType::Primitive(dt) => dt.clone(),
102            other => {
103                return Err(FnError::new(
104                    FnError::CODE_TYPE_COERCION,
105                    format!("declared fn return type not supported: {other:?}"),
106                ));
107            }
108        };
109
110        let out = build_output(&return_dt, row_count, |row| {
111            let mut bindings = std::collections::HashMap::with_capacity(columns.len());
112            for (i, col) in columns.iter().enumerate() {
113                bindings.insert(self.arg_names[i].clone(), array_value_at(col, row)?);
114            }
115            eval_expr(&self.body, &bindings).map_err(eval_err_to_fn)
116        })?;
117
118        Ok(ColumnarValue::Array(out))
119    }
120}
121
122fn columnar_to_array(cv: &ColumnarValue, rows: usize) -> Result<ArrayRef, FnError> {
123    match cv {
124        ColumnarValue::Array(a) => Ok(Arc::clone(a)),
125        ColumnarValue::Scalar(s) => s
126            .to_array_of_size(rows)
127            .map_err(|e| FnError::new(FnError::CODE_TYPE_COERCION, format!("scalar→array: {e}"))),
128    }
129}
130
131fn build_output(
132    dt: &DataType,
133    rows: usize,
134    mut row_value: impl FnMut(usize) -> Result<Value, FnError>,
135) -> Result<ArrayRef, FnError> {
136    match dt {
137        DataType::Utf8 => {
138            let mut b = StringBuilder::with_capacity(rows, rows * 8);
139            for row in 0..rows {
140                match row_value(row)? {
141                    Value::Null => b.append_null(),
142                    Value::String(s) => b.append_value(s),
143                    other => b.append_value(stringify(&other)),
144                }
145            }
146            Ok(Arc::new(b.finish()))
147        }
148        DataType::Int64 => {
149            let mut b = Int64Builder::with_capacity(rows);
150            for row in 0..rows {
151                match row_value(row)? {
152                    Value::Null => b.append_null(),
153                    Value::Int(i) => b.append_value(i),
154                    Value::Float(f) => b.append_value(f as i64),
155                    other => {
156                        return Err(FnError::new(
157                            FnError::CODE_TYPE_COERCION,
158                            format!("expected Int64, got {other:?}"),
159                        ));
160                    }
161                }
162            }
163            Ok(Arc::new(b.finish()))
164        }
165        DataType::Float64 => {
166            let mut b = Float64Builder::with_capacity(rows);
167            for row in 0..rows {
168                match row_value(row)? {
169                    Value::Null => b.append_null(),
170                    Value::Int(i) => b.append_value(i as f64),
171                    Value::Float(f) => b.append_value(f),
172                    other => {
173                        return Err(FnError::new(
174                            FnError::CODE_TYPE_COERCION,
175                            format!("expected Float64, got {other:?}"),
176                        ));
177                    }
178                }
179            }
180            Ok(Arc::new(b.finish()))
181        }
182        DataType::Boolean => {
183            let mut b = BooleanBuilder::with_capacity(rows);
184            for row in 0..rows {
185                match row_value(row)? {
186                    Value::Null => b.append_null(),
187                    Value::Bool(v) => b.append_value(v),
188                    other => {
189                        return Err(FnError::new(
190                            FnError::CODE_TYPE_COERCION,
191                            format!("expected Boolean, got {other:?}"),
192                        ));
193                    }
194                }
195            }
196            Ok(Arc::new(b.finish()))
197        }
198        other => Err(FnError::new(
199            FnError::CODE_TYPE_COERCION,
200            format!("declared fn return type {other:?} not supported"),
201        )),
202    }
203}
204
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::{Array, StringArray};
    use datafusion::scalar::ScalarValue;
    use uni_cypher::parse_expression;

    /// Build a declared function whose arguments and return value are
    /// all typed as `Utf8`.
    fn fn_string(body: &str, arg_names: &[&str]) -> DeclaredScalarFn {
        let parsed = parse_expression(body).unwrap();
        let mut names: Vec<String> = Vec::with_capacity(arg_names.len());
        let mut typed: Vec<(String, DataType)> = Vec::with_capacity(arg_names.len());
        for name in arg_names {
            names.push((*name).to_owned());
            typed.push(((*name).to_owned(), DataType::Utf8));
        }
        let signature = DeclaredScalarFn::build_signature(DataType::Utf8, &typed);
        DeclaredScalarFn::new(parsed, names, signature)
    }

    /// Wrap a string as a scalar `Utf8` argument.
    fn utf8_scalar(text: &str) -> ColumnarValue {
        ColumnarValue::Scalar(ScalarValue::Utf8(Some(text.to_owned())))
    }

    #[test]
    fn invoke_string_concat_via_scalars() {
        let declared = fn_string("$first + ' ' + $last", &["first", "last"]);
        let inputs = vec![utf8_scalar("Ada"), utf8_scalar("Lovelace")];
        let array = match declared.invoke(&inputs, 1).unwrap() {
            ColumnarValue::Scalar(_) => panic!("expected array"),
            ColumnarValue::Array(array) => array,
        };
        let strings = array.as_any().downcast_ref::<StringArray>().unwrap();
        assert_eq!(strings.value(0), "Ada Lovelace");
    }

    #[test]
    fn invoke_arity_mismatch() {
        let declared = fn_string("$first + ' ' + $last", &["first", "last"]);
        let inputs = vec![utf8_scalar("a")];
        let failure = declared.invoke(&inputs, 1).unwrap_err();
        assert_eq!(failure.code, FnError::CODE_TYPE_COERCION);
    }
}