Skip to main content

uni_query/query/
df_udfs_plugin.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2024-2026 Dragonscale Team
3
4//! Plugin-registry scalar-UDF integration for DataFusion.
5//!
6//! This module bridges the `uni-plugin` plugin system into DataFusion's
7//! scalar-UDF surface. It lives in `uni-query` (not the dependency-light
8//! `uni-query-functions` leaf crate) because it depends on `uni-plugin`
9//! and `tokio` task-locals.
10//!
11//! Responsibilities:
12//!
13//! - The `SESSION_PLUGIN_REGISTRY` tokio task-local and its scope/read
14//!   helpers ([`scoped_with_session_plugin_registry`],
15//!   [`current_session_plugin_registry`]).
16//! - Re-export of the principal task-local helpers from
17//!   `uni_plugin::host::principal` so callers keep their existing
18//!   `uni_query::scoped_with_principal` / `current_principal` paths.
19//! - [`scoped_with_session_context`] combining both scopes.
20//! - [`register_plugin_scalar_udfs`] / [`register_plugin_scalar_udfs_pair`]
21//!   and the `PluginScalarUdf` DataFusion adapter (private).
22
23use std::any::Any;
24use std::hash::{Hash, Hasher};
25use std::sync::Arc;
26
27use arrow::datatypes::DataType;
28use datafusion::error::Result as DFResult;
29use datafusion::logical_expr::{
30    ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, TypeSignature,
31};
32use datafusion::prelude::SessionContext;
33use uni_query_functions::custom_functions::{CustomFunctionRegistry, LEGACY_USER_PLUGIN_ID};
34
35use crate::query::executor::plugin_adapter::ValueRowFn;
36
37tokio::task_local! {
38    /// Tokio task-local carrying the current `Session`'s
39    /// **session-local** plugin registry across the per-query
40    /// executor scope. Set by host-crate session execute paths via
41    /// [`scoped_with_session_plugin_registry`]; read at the UDF
42    /// registration site (`register_plugin_scalar_udfs_pair`) and at
43    /// the procedure / Locy-aggregate dual-consult helpers.
44    ///
45    /// Propagates across `.await` points within the same task tree;
46    /// does NOT propagate across `tokio::spawn` (which is fine — the
47    /// per-query executor runs everything in the same task).
48    pub static SESSION_PLUGIN_REGISTRY:
49        std::sync::Arc<uni_plugin::PluginRegistry>;
50}
51
52/// Run `fut` inside a scope that exposes `registry` as the current
53/// session-local plugin registry. Returns the future's output.
54///
55/// Use this at every uni-db host-crate boundary where a `Session`
56/// dispatches into the executor.
57pub fn scoped_with_session_plugin_registry<F: std::future::Future>(
58    registry: std::sync::Arc<uni_plugin::PluginRegistry>,
59    fut: F,
60) -> tokio::task::futures::TaskLocalFuture<std::sync::Arc<uni_plugin::PluginRegistry>, F> {
61    SESSION_PLUGIN_REGISTRY.scope(registry, fut)
62}
63
64/// Borrow the current session-local plugin registry, if any. Returns
65/// `None` when the call is not inside a
66/// [`scoped_with_session_plugin_registry`] scope (e.g., a query
67/// against `Uni` directly with no Session in flight, or a unit test
68/// invoking the executor outside the host crate).
69#[must_use]
70pub fn current_session_plugin_registry() -> Option<std::sync::Arc<uni_plugin::PluginRegistry>> {
71    SESSION_PLUGIN_REGISTRY.try_with(|r| r.clone()).ok()
72}
73
74// §1.2 / Phase 5: the principal task-local + scope helpers moved to
75// `uni_plugin::host::principal`. Re-exported here so external callers
76// (`uni::api::{session,transaction}`, downstream embedders) keep their
77// existing `uni_query::scoped_with_principal` / `current_principal`
78// paths.
79pub use uni_plugin::host::principal::{
80    CURRENT_PRINCIPAL, current_principal, maybe_scope_with_principal, scoped_with_principal,
81};
82
83/// Run `fut` inside both [`scoped_with_session_plugin_registry`] and
84/// the principal task-local scope in a single call.
85///
86/// `principal` is optional — when `None`, only the plugin-registry
87/// scope is installed and [`current_principal`] returns `None` inside
88/// `fut`. This matches the legacy behavior for sessions that haven't
89/// authenticated.
90pub async fn scoped_with_session_context<F: std::future::Future>(
91    registry: std::sync::Arc<uni_plugin::PluginRegistry>,
92    principal: Option<std::sync::Arc<uni_plugin::traits::connector::Principal>>,
93    fut: F,
94) -> F::Output {
95    scoped_with_session_plugin_registry(registry, maybe_scope_with_principal(principal, fut)).await
96}
97
98/// Two-registry variant of [`register_plugin_scalar_udfs`] — registers
99/// the instance registry's scalars first, then the session registry's
100/// (if present) on top. DataFusion's `register_udf` is last-write-wins
101/// by registered name, so session entries shadow instance entries
102/// without any explicit ordering logic.
103///
104/// This is the resolution path used per-query when a `Session` carries
105/// a session-local plugin registry. See `proposal §5.4.2` for the
106/// session-scope contract and the M8.6 follow-up plan.
107///
108/// # Errors
109///
110/// Returns an error if any UDF registration fails.
111pub fn register_plugin_scalar_udfs_pair(
112    ctx: &SessionContext,
113    instance: &uni_plugin::PluginRegistry,
114    session: Option<&uni_plugin::PluginRegistry>,
115) -> DFResult<()> {
116    register_plugin_scalar_udfs(ctx, instance)?;
117    if let Some(session_reg) = session {
118        register_plugin_scalar_udfs(ctx, session_reg)?;
119    }
120    Ok(())
121}
122
123/// Register every scalar function in a `PluginRegistry` as a DataFusion UDF.
124///
125/// M2's plugin-path DataFusion adapter — iterates
126/// [`uni_plugin::PluginRegistry`] directly.
127///
128/// Registers each scalar as both lowercase and uppercase local-name
129/// variants so Cypher's case-insensitive function-name match resolves.
130/// The qname's namespace is preserved (Cypher syntax uses dotted names for
131/// qualified callable references).
132///
133/// # Errors
134///
135/// Returns an error if any UDF registration fails.
136pub fn register_plugin_scalar_udfs(
137    ctx: &SessionContext,
138    plugin_registry: &uni_plugin::PluginRegistry,
139) -> DFResult<()> {
140    for (qname, entry) in plugin_registry.iter_scalars() {
141        let local = qname.local();
142        let lower_local = local.to_lowercase();
143        let upper_local = local.to_uppercase();
144
145        // Local-name registrations — what Cypher's case-insensitive
146        // lookup hits.
147        if lower_local != upper_local {
148            ctx.register_udf(ScalarUDF::new_from_impl(PluginScalarUdf::new(
149                lower_local.clone(),
150                Arc::clone(&entry),
151            )));
152        }
153        ctx.register_udf(ScalarUDF::new_from_impl(PluginScalarUdf::new(
154            upper_local,
155            Arc::clone(&entry),
156        )));
157
158        // Also register under the fully-qualified name (`namespace.local`)
159        // so dotted-name dispatch works.
160        ctx.register_udf(ScalarUDF::new_from_impl(PluginScalarUdf::new(
161            qname.to_string(),
162            Arc::clone(&entry),
163        )));
164    }
165    Ok(())
166}
167
168/// Build a shadow [`uni_plugin::PluginRegistry`] from the pure
169/// [`CustomFunctionRegistry`] leaf type.
170///
171/// `CustomFunctionRegistry` lives in the dependency-light
172/// `uni-query-functions` crate and stores only `(name, fn)` pairs. The
173/// plugin-framework dispatch path (`register_plugin_scalar_udfs`) consumes a
174/// `PluginRegistry`, so we mirror every legacy registration into one under
175/// the reserved [`LEGACY_USER_PLUGIN_ID`] here, where the `uni-plugin`
176/// dependency is available.
177///
178/// Each entry is wrapped in a [`ValueRowFn`] adapter and given a permissive
179/// `CypherValue` signature — the actual coercion happens at the DataFusion
180/// adapter ([`PluginScalarUdf`]) site.
181fn plugin_registry_for_custom_functions(
182    registry: &CustomFunctionRegistry,
183) -> uni_plugin::PluginRegistry {
184    use datafusion::logical_expr::Volatility;
185    use uni_plugin::traits::scalar::{ArgType, FnSignature, NullHandling};
186    use uni_plugin::{Capability, CapabilitySet, PluginId, PluginRegistrar, PluginRegistry, QName};
187
188    let pr = PluginRegistry::new();
189    let plugin_id = PluginId::new(LEGACY_USER_PLUGIN_ID);
190    let caps = CapabilitySet::from_iter_of([Capability::ScalarFn]);
191
192    for (name, func) in registry.iter() {
193        // `CustomFunctionRegistry` already uppercases names on registration,
194        // but normalize defensively so the qname matches the plugin id's
195        // namespace expectations.
196        let upper = name.to_uppercase();
197        let mut r = PluginRegistrar::new(plugin_id.clone(), &caps, &pr);
198        let qname = QName::new(LEGACY_USER_PLUGIN_ID, &upper);
199        let adapter = Arc::new(ValueRowFn::new(upper.clone(), Arc::clone(func)));
200        let sig = FnSignature {
201            args: vec![ArgType::Variadic(Box::new(ArgType::CypherValue))],
202            returns: ArgType::CypherValue,
203            volatility: Volatility::Volatile,
204            null_handling: NullHandling::UserHandled,
205        };
206        if let Err(e) = r.scalar_fn(qname, sig, adapter) {
207            tracing::warn!(error = ?e, fn_name = %upper, "shadow registration failed");
208            continue;
209        }
210        if let Err(e) = r.commit_to_registry() {
211            tracing::warn!(error = ?e, fn_name = %upper, "shadow commit failed");
212        }
213    }
214    pr
215}
216
217/// Register the legacy [`CustomFunctionRegistry`] entries as DataFusion
218/// scalar UDFs by mirroring them through the plugin-framework adapter.
219///
220/// This is the instance-scope, legacy `CustomFunctionRegistry` shadow path
221/// (e.g. `db.register_function()` entries + apoc-core mirrors).
222///
223/// # Errors
224///
225/// Returns an error if any UDF registration fails.
226pub fn register_custom_functions_as_plugin_scalars(
227    ctx: &SessionContext,
228    registry: &CustomFunctionRegistry,
229) -> DFResult<()> {
230    let shadow = plugin_registry_for_custom_functions(registry);
231    register_plugin_scalar_udfs(ctx, &shadow)
232}
233
234/// DataFusion adapter wrapping a [`uni_plugin::registry::ScalarEntry`].
235///
236/// Inspects the plugin's `signature.returns` at construction time to pick
237/// the DataFusion return type:
238///
239/// - `ArgType::Primitive(T)` → declares `T` directly to DataFusion. The
240///   plugin's `invoke()` returns Arrow data in `T`'s native type, no
241///   LargeBinary round-trip. This is the ≥ 20% perf target path for
242///   primitively-typed UDFs.
243/// - `ArgType::CypherValue` → declares `LargeBinary` (legacy transport).
244/// - `ArgType::Vector { .. }` / `Variadic(..)` → `LargeBinary` for now.
245///
246/// The same adapter is used for both the local-name and qualified-name
247/// registrations.
248struct PluginScalarUdf {
249    name: String,
250    entry: Arc<uni_plugin::registry::ScalarEntry>,
251    signature: Signature,
252    return_type: DataType,
253}
254
255impl PluginScalarUdf {
256    fn new(name: String, entry: Arc<uni_plugin::registry::ScalarEntry>) -> Self {
257        // Derive volatility from the plugin's declared volatility, falling
258        // back to Volatile if signature inspection fails. (The plugin's
259        // FnSignature is the canonical source of truth.)
260        let volatility = entry.signature.volatility;
261        let return_type = derive_return_type(&entry);
262        Self {
263            signature: Signature::new(TypeSignature::VariadicAny, volatility),
264            name,
265            entry,
266            return_type,
267        }
268    }
269}
270
271/// Derive the DataFusion return type from the plugin's declared signature.
272fn derive_return_type(entry: &uni_plugin::registry::ScalarEntry) -> DataType {
273    use uni_plugin::traits::scalar::ArgType;
274    match &entry.signature.returns {
275        ArgType::Primitive(t) => t.clone(),
276        // CypherValue + Vector + Variadic stay on the LargeBinary path
277        // (the latter two are uncommon for return types; CypherValue is
278        // explicit opt-in to the legacy transport).
279        _ => DataType::LargeBinary,
280    }
281}
282
283impl std::fmt::Debug for PluginScalarUdf {
284    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
285        f.debug_struct("PluginScalarUdf")
286            .field("name", &self.name)
287            .finish()
288    }
289}
290
291/// Resolve the declared [`ArgType`] for positional argument `i`, transparently
292/// unwrapping a trailing `Variadic(..)` so arguments beyond the fixed prefix
293/// inherit the variadic element type.
294fn declared_arg_type(
295    args: &[uni_plugin::traits::scalar::ArgType],
296    i: usize,
297) -> Option<&uni_plugin::traits::scalar::ArgType> {
298    use uni_plugin::traits::scalar::ArgType;
299    let raw = match args.get(i) {
300        Some(a) => Some(a),
301        // Past the fixed args: only a trailing variadic keeps matching.
302        None => match args.last() {
303            Some(v @ ArgType::Variadic(_)) => Some(v),
304            _ => None,
305        },
306    };
307    match raw {
308        Some(ArgType::Variadic(inner)) => Some(inner.as_ref()),
309        other => other,
310    }
311}
312
313/// Coerce one plugin-scalar argument column from the schemaless `LargeBinary`
314/// CypherValue transport into the primitive Arrow type the manifest declares
315/// (`Int64`/`Float64`) for that argument.
316///
317/// A raw integer/float node or edge property reaches expression evaluation as a
318/// `LargeBinary` variant column (schemaless storage). A plugin scalar that
319/// declares `Primitive(Int64)` downcasts its argument to `Int64Array` and fails
320/// on `LargeBinary` — previously forcing callers to wrap the property in
321/// `toInteger(...)`/`toFloat(...)`. This performs that coercion automatically
322/// (REQ-4). A value that genuinely cannot be coerced (e.g. a string where an
323/// integer is declared) yields a precise error naming the argument, the declared
324/// type, and the explicit-coercion hint, instead of an opaque downcast failure.
325///
326/// Columns whose declared type is not a numeric primitive, or that already
327/// arrive as a non-`LargeBinary` (i.e. natively typed) array, pass through
328/// untouched.
329fn coerce_plugin_scalar_arg(
330    col: ColumnarValue,
331    declared: Option<&uni_plugin::traits::scalar::ArgType>,
332    rows: usize,
333    arg_idx: usize,
334    fn_name: &str,
335) -> DFResult<ColumnarValue> {
336    use arrow::array::{Array, ArrayRef, Float64Array, Int64Array, LargeBinaryArray};
337    use uni_common::Value;
338    use uni_plugin::traits::scalar::ArgType;
339
340    let target = match declared {
341        Some(ArgType::Primitive(t @ (DataType::Int64 | DataType::Float64))) => t.clone(),
342        _ => return Ok(col),
343    };
344
345    let array = col.to_array(rows)?;
346    // Already natively typed (or some other non-variant transport): nothing to do.
347    if array.data_type() != &DataType::LargeBinary {
348        return Ok(col);
349    }
350    let lb = array
351        .as_any()
352        .downcast_ref::<LargeBinaryArray>()
353        .expect("data_type checked to be LargeBinary");
354
355    let non_numeric_err = |row: usize, got: &Value| {
356        let hint = if target == DataType::Int64 {
357            "toInteger(...)"
358        } else {
359            "toFloat(...)"
360        };
361        datafusion::error::DataFusionError::Execution(format!(
362            "plugin fn `{fn_name}`: argument {} declares {target} but row {row} carried a \
363             non-numeric value ({got:?}); wrap the property with {hint}",
364            arg_idx + 1,
365        ))
366    };
367
368    let decoded =
369        |row: usize| uni_store::storage::arrow_convert::arrow_to_value(array.as_ref(), row, None);
370
371    let out: ArrayRef = match target {
372        DataType::Int64 => {
373            let mut b = Int64Array::builder(array.len());
374            for row in 0..array.len() {
375                if lb.is_null(row) || lb.value(row).is_empty() {
376                    b.append_null();
377                    continue;
378                }
379                match decoded(row) {
380                    Value::Int(i) => b.append_value(i),
381                    Value::Float(f) => b.append_value(f as i64),
382                    Value::Null => b.append_null(),
383                    other => return Err(non_numeric_err(row, &other)),
384                }
385            }
386            Arc::new(b.finish())
387        }
388        DataType::Float64 => {
389            let mut b = Float64Array::builder(array.len());
390            for row in 0..array.len() {
391                if lb.is_null(row) || lb.value(row).is_empty() {
392                    b.append_null();
393                    continue;
394                }
395                match decoded(row) {
396                    Value::Float(f) => b.append_value(f),
397                    Value::Int(i) => b.append_value(i as f64),
398                    Value::Null => b.append_null(),
399                    other => return Err(non_numeric_err(row, &other)),
400                }
401            }
402            Arc::new(b.finish())
403        }
404        _ => unreachable!("target restricted to Int64/Float64 above"),
405    };
406    Ok(ColumnarValue::Array(out))
407}
408
409impl PartialEq for PluginScalarUdf {
410    fn eq(&self, other: &Self) -> bool {
411        self.signature == other.signature
412    }
413}
414
415impl Eq for PluginScalarUdf {}
416
417impl Hash for PluginScalarUdf {
418    fn hash<H: Hasher>(&self, state: &mut H) {
419        self.name().hash(state);
420    }
421}
422
423impl ScalarUDFImpl for PluginScalarUdf {
424    fn as_any(&self) -> &dyn Any {
425        self
426    }
427
428    fn name(&self) -> &str {
429        &self.name
430    }
431
432    fn signature(&self) -> &Signature {
433        &self.signature
434    }
435
436    fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
437        Ok(self.return_type.clone())
438    }
439
440    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
441        let entry = Arc::clone(&self.entry);
442        let rows = args.number_rows;
443        // Auto-coerce raw schemaless (LargeBinary) numeric property args to the
444        // primitive type the plugin's manifest declares, so a property can be
445        // passed without an explicit toInteger()/toFloat() wrapper (REQ-4).
446        let declared = &entry.signature.args;
447        let cols = args
448            .args
449            .into_iter()
450            .enumerate()
451            .map(|(i, col)| {
452                coerce_plugin_scalar_arg(col, declared_arg_type(declared, i), rows, i, &self.name)
453            })
454            .collect::<DFResult<Vec<_>>>()?;
455        entry.function.invoke(&cols, rows).map_err(|e| {
456            datafusion::error::DataFusionError::Execution(format!(
457                "plugin `{}` fn `{}` failed: {e}",
458                entry.plugin, self.name
459            ))
460        })
461    }
462}
463
#[cfg(test)]
mod tests {
    use super::*;
    // `SessionContext::udf` is a `FunctionRegistry` trait method; bring the
    // trait into scope so the registration assertions resolve. `Volatility`
    // is used only by the in-test plugin fixtures.
    use datafusion::execution::FunctionRegistry;
    use datafusion::logical_expr::Volatility;

    #[test]
    fn test_register_plugin_scalars_routes_through_plugin_registry() {
        // Register a scalar in the legacy `CustomFunctionRegistry` and expose it
        // via `register_custom_functions_as_plugin_scalars`. That call mirrors
        // it into a shadow `PluginRegistry` and hands that registry to
        // `register_plugin_scalar_udfs`. Then check that DataFusion resolves
        // the function under both local-name case-folds and the fully-qualified
        // `LEGACY_USER_PLUGIN_ID.MYFN` name.
        use uni_common::Value;
        use uni_query_functions::custom_functions::{CustomFunctionRegistry, CustomScalarFn};

        let mut reg = CustomFunctionRegistry::new();
        let f: CustomScalarFn =
            Arc::new(|_args: &[Value]| Ok(Value::String("plugin-path".to_owned())));
        reg.register("MYFN".to_owned(), f);

        let ctx = SessionContext::new();
        register_custom_functions_as_plugin_scalars(&ctx, &reg).unwrap();

        // Local-name lowercase form (Cypher case-insensitive dispatch).
        assert!(ctx.udf("myfn").is_ok());
        // Uppercase local name.
        assert!(ctx.udf("MYFN").is_ok());
        // Fully-qualified namespace form.
        let qname = format!("{LEGACY_USER_PLUGIN_ID}.MYFN");
        assert!(ctx.udf(&qname).is_ok());
    }

    #[test]
    fn test_native_arrow_udf_declares_primitive_return_type() {
        // Fast path: a plugin declaring `ArgType::Primitive(Float64)` as its
        // return type must surface a DataFusion UDF whose `return_type` is
        // `Float64`, not the `LargeBinary` CypherValue transport. That avoids a
        // per-row LargeBinary round-trip.
        use std::sync::OnceLock;
        use uni_plugin::FnError;
        use uni_plugin::traits::scalar::{ArgType, FnSignature, NullHandling, ScalarPluginFn};
        use uni_plugin::{
            Capability, CapabilitySet, PluginId, PluginRegistrar, PluginRegistry, QName,
        };

        /// Fixture that returns its first argument unchanged. Only its declared
        /// signature matters here; this test never invokes it.
        struct EchoFirstArg;
        impl ScalarPluginFn for EchoFirstArg {
            fn signature(&self) -> &FnSignature {
                // Mirrors the signature passed to `scalar_fn` below.
                static S: OnceLock<FnSignature> = OnceLock::new();
                S.get_or_init(|| FnSignature {
                    args: vec![ArgType::Primitive(DataType::Float64)],
                    returns: ArgType::Primitive(DataType::Float64),
                    volatility: Volatility::Immutable,
                    null_handling: NullHandling::PropagateNulls,
                })
            }
            fn invoke(
                &self,
                args: &[ColumnarValue],
                _rows: usize,
            ) -> Result<ColumnarValue, FnError> {
                Ok(args.first().cloned().expect("fixture is called with one argument"))
            }
        }

        let pr = PluginRegistry::new();
        let caps = CapabilitySet::from_iter_of([Capability::ScalarFn]);
        let mut r = PluginRegistrar::new(PluginId::new("test.fast"), &caps, &pr);
        r.scalar_fn(
            QName::new("test.fast", "double"),
            FnSignature {
                args: vec![ArgType::Primitive(DataType::Float64)],
                returns: ArgType::Primitive(DataType::Float64),
                volatility: Volatility::Immutable,
                null_handling: NullHandling::PropagateNulls,
            },
            Arc::new(EchoFirstArg),
        )
        .unwrap();
        r.commit_to_registry().unwrap();

        let ctx = SessionContext::new();
        register_plugin_scalar_udfs(&ctx, &pr).unwrap();

        // Resolve the UDF and ask DataFusion for its return type.
        let udf = ctx.udf("double").expect("udf registered");
        let rt = udf.return_type(&[DataType::Float64]).unwrap();
        assert_eq!(
            rt,
            DataType::Float64,
            "primitive-typed plugin should declare Float64 directly, not LargeBinary"
        );
    }
}