Skip to main content

uni_plugin/traits/
scalar.rs

1//! Scalar plugin functions — Cypher `RETURN myfn(x)`.
2//!
3//! Scalar plugins are the bread and butter of the plugin framework: pure (or
4//! session-scoped) functions that map a row of input columns to one output
5//! value per row. They are columnar by default; a [`RowFn`] adapter provides
6//! the row-at-a-time convenience for plugin authors who don't need
7//! vectorization.
8
9use std::sync::Arc;
10
11use arrow_schema::DataType;
12use datafusion::logical_expr::{ColumnarValue, Volatility};
13
14use crate::errors::FnError;
15
16/// A scalar plugin function — `(rows of columnar input) → 1 columnar output`.
17///
18/// Implementations are expected to be `Send + Sync` because they are shared
19/// across query workers; non-thread-safe state (e.g., a mutable cache) must
20/// be wrapped behind `Mutex` or `RwLock`.
21pub trait ScalarPluginFn: Send + Sync {
22    /// The function's static signature (arg types, return type, volatility).
23    fn signature(&self) -> &FnSignature;
24
25    /// Invoke the function on a batch of inputs.
26    ///
27    /// `args[i]` is the `i`-th argument's column-or-scalar; `rows` is the
28    /// number of rows the caller expects to be produced. Implementations
29    /// should produce exactly `rows` values when returning a column.
30    ///
31    /// # Errors
32    ///
33    /// Returns an [`FnError`] for any per-call failure; this is wrapped into
34    /// `UniError::Plugin` at the call site.
35    fn invoke(&self, args: &[ColumnarValue], rows: usize) -> Result<ColumnarValue, FnError>;
36}
37
38/// Static signature of a scalar plugin function.
39#[derive(Clone, Debug)]
40pub struct FnSignature {
41    /// Argument types, in order.
42    pub args: Vec<ArgType>,
43    /// Return type.
44    pub returns: ArgType,
45    /// DataFusion volatility (drives caching and hoisting).
46    pub volatility: Volatility,
47    /// Null-handling policy.
48    pub null_handling: NullHandling,
49}
50
51impl FnSignature {
52    /// Convenience constructor for the common case: known args/returns,
53    /// derived volatility, propagate-nulls semantics.
54    #[must_use]
55    pub fn new(args: Vec<ArgType>, returns: ArgType, volatility: Volatility) -> Self {
56        Self {
57            args,
58            returns,
59            volatility,
60            null_handling: NullHandling::PropagateNulls,
61        }
62    }
63}
64
/// Policy for `NULL` values among a scalar function's arguments.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NullHandling {
    /// Standard Cypher semantics: a `NULL` in any argument short-circuits the
    /// call and the output is `NULL`.
    PropagateNulls,
    /// The function sees `NULL`s and handles them itself, with
    /// `Option<T>`-style semantics.
    UserHandled,
}
73
74/// Logical type of a scalar function argument or return.
75///
76/// `Primitive` arguments take the **native Arrow fast path** — no
77/// `LargeBinary` round-trip. `CypherValue` arguments go through the
78/// legacy `LargeBinary` transport for fns that genuinely need to see
79/// `Node` / `Relationship` / `Path` values.
80#[derive(Clone, Debug)]
81pub enum ArgType {
82    /// Native Arrow primitive type (`Float64`, `Int64`, `Utf8`, etc.).
83    Primitive(DataType),
84    /// Full `CypherValue` (serialized as `LargeBinary` opaque to the plugin).
85    CypherValue,
86    /// Fixed-size list of `element` with declared `len`.
87    Vector {
88        /// Number of elements per row.
89        len: usize,
90        /// Element type.
91        element: DataType,
92    },
93    /// Variadic — repeats the inner `ArgType` zero or more times.
94    Variadic(Box<ArgType>),
95}
96
97/// Row-at-a-time adapter wrapping a closure into a [`ScalarPluginFn`].
98///
99/// This is the *convenience* path for plugin authors who don't care about
100/// vectorization. The default columnar contract is preferred for hot-path
101/// UDFs; use `RowFn` for one-off ad-hoc fns where per-row author ergonomics
102/// matter more than per-row performance.
103pub struct RowFn<F> {
104    signature: FnSignature,
105    #[allow(
106        dead_code,
107        reason = "row evaluation is wired by uni-query host adapter; field held for downstream extraction"
108    )]
109    inner: Arc<F>,
110}
111
112impl<F> std::fmt::Debug for RowFn<F> {
113    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
114        f.debug_struct("RowFn")
115            .field("signature", &self.signature)
116            .finish_non_exhaustive()
117    }
118}
119
120impl<F> RowFn<F> {
121    /// Wrap a row-shaped closure into a scalar plugin fn.
122    #[must_use]
123    pub fn new(signature: FnSignature, f: F) -> Self {
124        Self {
125            signature,
126            inner: Arc::new(f),
127        }
128    }
129}
130
131// Note: actual row-by-row invocation requires the Value type from
132// uni-common, which we cannot reference here without a cyclic dep.
133// The real `RowFn::invoke` implementation lives in `uni-query` where
134// `Value` is available; this struct is the type-level placeholder.
135
136impl<F> ScalarPluginFn for RowFn<F>
137where
138    F: Send + Sync + 'static,
139{
140    fn signature(&self) -> &FnSignature {
141        &self.signature
142    }
143
144    fn invoke(&self, _args: &[ColumnarValue], _rows: usize) -> Result<ColumnarValue, FnError> {
145        // RowFn::invoke is implemented by the host adapter in uni-query,
146        // which knows how to deserialize ColumnarValue → Value rows and
147        // re-serialize the result. The trait impl here exists so RowFn
148        // implements ScalarPluginFn (for type-erasure into Arc<dyn>); the
149        // actual row evaluation is replaced at the registration boundary
150        // with a closure that has access to uni-common's Value type.
151        Err(FnError::new(
152            0xDEAD,
153            "RowFn::invoke must be intercepted by the host adapter; \
154             see uni-query::custom_functions::register_row_fn",
155        ))
156    }
157}
158
#[cfg(test)]
mod tests {
    use super::*;

    /// `Float64 -> Float64`, immutable, default null handling.
    fn f64_unary_sig() -> FnSignature {
        FnSignature::new(
            vec![ArgType::Primitive(DataType::Float64)],
            ArgType::Primitive(DataType::Float64),
            Volatility::Immutable,
        )
    }

    #[test]
    fn signature_constructor() {
        let sig = f64_unary_sig();
        assert_eq!(sig.args.len(), 1);
        assert_eq!(sig.volatility, Volatility::Immutable);
        assert_eq!(sig.null_handling, NullHandling::PropagateNulls);
    }

    #[test]
    fn arg_type_variants_round_trip_in_debug() {
        let t = ArgType::Vector {
            len: 384,
            element: DataType::Float32,
        };
        let s = format!("{t:?}");
        assert!(s.contains("Vector"));
        assert!(s.contains("384"));
    }

    #[test]
    fn row_fn_exposes_signature_and_rejects_direct_invoke() {
        let row_fn = RowFn::new(f64_unary_sig(), |x: f64| x * 2.0);

        // The trait impl hands back the signature it was built with.
        assert_eq!(row_fn.signature().args.len(), 1);

        // Direct invocation is a placeholder that must always error out; the
        // real evaluation path is installed by the host adapter.
        assert!(row_fn.invoke(&[], 0).is_err());

        // Debug must work even though the closure type is not `Debug`.
        assert!(format!("{row_fn:?}").starts_with("RowFn"));
    }
}