uni_plugin/traits/scalar.rs
1//! Scalar plugin functions — Cypher `RETURN myfn(x)`.
2//!
3//! Scalar plugins are the bread and butter of the plugin framework: pure (or
4//! session-scoped) functions that map a row of input columns to one output
5//! value per row. They are columnar by default; a [`RowFn`] adapter provides
6//! the row-at-a-time convenience for plugin authors who don't need
7//! vectorization.
8
9use std::sync::Arc;
10
11use arrow_schema::DataType;
12use datafusion::logical_expr::{ColumnarValue, Volatility};
13
14use crate::errors::FnError;
15
16/// A scalar plugin function — `(rows of columnar input) → 1 columnar output`.
17///
18/// Implementations are expected to be `Send + Sync` because they are shared
19/// across query workers; non-thread-safe state (e.g., a mutable cache) must
20/// be wrapped behind `Mutex` or `RwLock`.
21pub trait ScalarPluginFn: Send + Sync {
22 /// The function's static signature (arg types, return type, volatility).
23 fn signature(&self) -> &FnSignature;
24
25 /// Invoke the function on a batch of inputs.
26 ///
27 /// `args[i]` is the `i`-th argument's column-or-scalar; `rows` is the
28 /// number of rows the caller expects to be produced. Implementations
29 /// should produce exactly `rows` values when returning a column.
30 ///
31 /// # Errors
32 ///
33 /// Returns an [`FnError`] for any per-call failure; this is wrapped into
34 /// `UniError::Plugin` at the call site.
35 fn invoke(&self, args: &[ColumnarValue], rows: usize) -> Result<ColumnarValue, FnError>;
36}
37
38/// Static signature of a scalar plugin function.
39#[derive(Clone, Debug)]
40pub struct FnSignature {
41 /// Argument types, in order.
42 pub args: Vec<ArgType>,
43 /// Return type.
44 pub returns: ArgType,
45 /// DataFusion volatility (drives caching and hoisting).
46 pub volatility: Volatility,
47 /// Null-handling policy.
48 pub null_handling: NullHandling,
49}
50
51impl FnSignature {
52 /// Convenience constructor for the common case: known args/returns,
53 /// derived volatility, propagate-nulls semantics.
54 #[must_use]
55 pub fn new(args: Vec<ArgType>, returns: ArgType, volatility: Volatility) -> Self {
56 Self {
57 args,
58 returns,
59 volatility,
60 null_handling: NullHandling::PropagateNulls,
61 }
62 }
63}
64
/// How the framework handles `NULL` values in scalar-fn arguments.
///
/// The default is [`NullHandling::PropagateNulls`], the same policy
/// `FnSignature::new` picks. `Hash` is derived so the policy can take part
/// in map/set keys.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum NullHandling {
    /// Any `NULL` in an arg short-circuits to `NULL` output (standard Cypher).
    #[default]
    PropagateNulls,
    /// The function handles `NULL` explicitly via `Option<T>` semantics.
    UserHandled,
}
73
74/// Logical type of a scalar function argument or return.
75///
76/// `Primitive` arguments take the **native Arrow fast path** — no
77/// `LargeBinary` round-trip. `CypherValue` arguments go through the
78/// legacy `LargeBinary` transport for fns that genuinely need to see
79/// `Node` / `Relationship` / `Path` values.
80#[derive(Clone, Debug)]
81pub enum ArgType {
82 /// Native Arrow primitive type (`Float64`, `Int64`, `Utf8`, etc.).
83 Primitive(DataType),
84 /// Full `CypherValue` (serialized as `LargeBinary` opaque to the plugin).
85 CypherValue,
86 /// Fixed-size list of `element` with declared `len`.
87 Vector {
88 /// Number of elements per row.
89 len: usize,
90 /// Element type.
91 element: DataType,
92 },
93 /// Variadic — repeats the inner `ArgType` zero or more times.
94 Variadic(Box<ArgType>),
95}
96
97/// Row-at-a-time adapter wrapping a closure into a [`ScalarPluginFn`].
98///
99/// This is the *convenience* path for plugin authors who don't care about
100/// vectorization. The default columnar contract is preferred for hot-path
101/// UDFs; use `RowFn` for one-off ad-hoc fns where per-row author ergonomics
102/// matter more than per-row performance.
103pub struct RowFn<F> {
104 signature: FnSignature,
105 #[allow(
106 dead_code,
107 reason = "row evaluation is wired by uni-query host adapter; field held for downstream extraction"
108 )]
109 inner: Arc<F>,
110}
111
112impl<F> std::fmt::Debug for RowFn<F> {
113 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
114 f.debug_struct("RowFn")
115 .field("signature", &self.signature)
116 .finish_non_exhaustive()
117 }
118}
119
120impl<F> RowFn<F> {
121 /// Wrap a row-shaped closure into a scalar plugin fn.
122 #[must_use]
123 pub fn new(signature: FnSignature, f: F) -> Self {
124 Self {
125 signature,
126 inner: Arc::new(f),
127 }
128 }
129}
130
131// Note: actual row-by-row invocation requires the Value type from
132// uni-common, which we cannot reference here without a cyclic dep.
133// The real `RowFn::invoke` implementation lives in `uni-query` where
134// `Value` is available; this struct is the type-level placeholder.
135
136impl<F> ScalarPluginFn for RowFn<F>
137where
138 F: Send + Sync + 'static,
139{
140 fn signature(&self) -> &FnSignature {
141 &self.signature
142 }
143
144 fn invoke(&self, _args: &[ColumnarValue], _rows: usize) -> Result<ColumnarValue, FnError> {
145 // RowFn::invoke is implemented by the host adapter in uni-query,
146 // which knows how to deserialize ColumnarValue → Value rows and
147 // re-serialize the result. The trait impl here exists so RowFn
148 // implements ScalarPluginFn (for type-erasure into Arc<dyn>); the
149 // actual row evaluation is replaced at the registration boundary
150 // with a closure that has access to uni-common's Value type.
151 Err(FnError::new(
152 0xDEAD,
153 "RowFn::invoke must be intercepted by the host adapter; \
154 see uni-query::custom_functions::register_row_fn",
155 ))
156 }
157}
158
#[cfg(test)]
mod tests {
    use super::*;

    /// `FnSignature::new` stores every argument verbatim and defaults to
    /// propagate-nulls semantics.
    #[test]
    fn signature_constructor() {
        let sig = FnSignature::new(
            vec![ArgType::Primitive(DataType::Float64)],
            ArgType::Primitive(DataType::Float64),
            Volatility::Immutable,
        );
        assert_eq!(sig.args.len(), 1);
        assert!(matches!(sig.args[0], ArgType::Primitive(DataType::Float64)));
        assert!(matches!(sig.returns, ArgType::Primitive(DataType::Float64)));
        assert_eq!(sig.volatility, Volatility::Immutable);
        assert_eq!(sig.null_handling, NullHandling::PropagateNulls);
    }

    /// The derived `Debug` output names the variant and all of its fields.
    #[test]
    fn arg_type_variants_round_trip_in_debug() {
        let t = ArgType::Vector {
            len: 384,
            element: DataType::Float32,
        };
        let s = format!("{t:?}");
        assert!(s.contains("Vector"));
        assert!(s.contains("384"));
        assert!(s.contains("Float32"));
    }

    /// A `RowFn` exposes its signature through the trait and formats
    /// without requiring `F: Debug`. Invoking it without the uni-query host
    /// adapter must fail loudly rather than return data.
    #[test]
    fn row_fn_reports_signature_and_refuses_direct_invoke() {
        let sig = FnSignature::new(
            vec![ArgType::Primitive(DataType::Int64)],
            ArgType::Primitive(DataType::Int64),
            Volatility::Immutable,
        );
        let row_fn = RowFn::new(sig, |x: i64| x + 1);
        assert_eq!(row_fn.signature().args.len(), 1);
        assert!(format!("{row_fn:?}").starts_with("RowFn"));
        assert!(row_fn.invoke(&[], 0).is_err());
    }
}