Skip to main content

uni_query/query/
df_udfs_plugin.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2024-2026 Dragonscale Team
3
4//! Plugin-registry scalar-UDF integration for DataFusion.
5//!
6//! This module bridges the `uni-plugin` plugin system into DataFusion's
7//! scalar-UDF surface. It lives in `uni-query` (not the dependency-light
8//! `uni-query-functions` leaf crate) because it depends on `uni-plugin`
9//! and `tokio` task-locals.
10//!
11//! Responsibilities:
12//!
13//! - The `SESSION_PLUGIN_REGISTRY` tokio task-local and its scope/read
14//!   helpers ([`scoped_with_session_plugin_registry`],
15//!   [`current_session_plugin_registry`]).
16//! - Re-export of the principal task-local helpers from
17//!   `uni_plugin::host::principal` so callers keep their existing
18//!   `uni_query::scoped_with_principal` / `current_principal` paths.
19//! - [`scoped_with_session_context`] combining both scopes.
20//! - [`register_plugin_scalar_udfs`] / [`register_plugin_scalar_udfs_pair`]
21//!   and the `PluginScalarUdf` DataFusion adapter (private).
22
23use std::any::Any;
24use std::hash::{Hash, Hasher};
25use std::sync::Arc;
26
27use arrow::datatypes::DataType;
28use datafusion::error::Result as DFResult;
29use datafusion::logical_expr::{
30    ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, TypeSignature,
31};
32use datafusion::prelude::SessionContext;
33use uni_query_functions::custom_functions::{CustomFunctionRegistry, LEGACY_USER_PLUGIN_ID};
34
35use crate::query::executor::plugin_adapter::ValueRowFn;
36
37tokio::task_local! {
38    /// Tokio task-local carrying the current `Session`'s
39    /// **session-local** plugin registry across the per-query
40    /// executor scope. Set by host-crate session execute paths via
41    /// [`scoped_with_session_plugin_registry`]; read at the UDF
42    /// registration site (`register_plugin_scalar_udfs_pair`) and at
43    /// the procedure / Locy-aggregate dual-consult helpers.
44    ///
45    /// Propagates across `.await` points within the same task tree;
46    /// does NOT propagate across `tokio::spawn` (which is fine — the
47    /// per-query executor runs everything in the same task).
48    pub static SESSION_PLUGIN_REGISTRY:
49        std::sync::Arc<uni_plugin::PluginRegistry>;
50}
51
52/// Run `fut` inside a scope that exposes `registry` as the current
53/// session-local plugin registry. Returns the future's output.
54///
55/// Use this at every uni-db host-crate boundary where a `Session`
56/// dispatches into the executor.
57pub fn scoped_with_session_plugin_registry<F: std::future::Future>(
58    registry: std::sync::Arc<uni_plugin::PluginRegistry>,
59    fut: F,
60) -> tokio::task::futures::TaskLocalFuture<std::sync::Arc<uni_plugin::PluginRegistry>, F> {
61    SESSION_PLUGIN_REGISTRY.scope(registry, fut)
62}
63
64/// Borrow the current session-local plugin registry, if any. Returns
65/// `None` when the call is not inside a
66/// [`scoped_with_session_plugin_registry`] scope (e.g., a query
67/// against `Uni` directly with no Session in flight, or a unit test
68/// invoking the executor outside the host crate).
69#[must_use]
70pub fn current_session_plugin_registry() -> Option<std::sync::Arc<uni_plugin::PluginRegistry>> {
71    SESSION_PLUGIN_REGISTRY.try_with(|r| r.clone()).ok()
72}
73
74// §1.2 / Phase 5: the principal task-local + scope helpers moved to
75// `uni_plugin::host::principal`. Re-exported here so external callers
76// (`uni::api::{session,transaction}`, downstream embedders) keep their
77// existing `uni_query::scoped_with_principal` / `current_principal`
78// paths.
79pub use uni_plugin::host::principal::{
80    CURRENT_PRINCIPAL, current_principal, maybe_scope_with_principal, scoped_with_principal,
81};
82
83/// Run `fut` inside both [`scoped_with_session_plugin_registry`] and
84/// the principal task-local scope in a single call.
85///
86/// `principal` is optional — when `None`, only the plugin-registry
87/// scope is installed and [`current_principal`] returns `None` inside
88/// `fut`. This matches the legacy behavior for sessions that haven't
89/// authenticated.
90pub async fn scoped_with_session_context<F: std::future::Future>(
91    registry: std::sync::Arc<uni_plugin::PluginRegistry>,
92    principal: Option<std::sync::Arc<uni_plugin::traits::connector::Principal>>,
93    fut: F,
94) -> F::Output {
95    scoped_with_session_plugin_registry(registry, maybe_scope_with_principal(principal, fut)).await
96}
97
98/// Two-registry variant of [`register_plugin_scalar_udfs`] — registers
99/// the instance registry's scalars first, then the session registry's
100/// (if present) on top. DataFusion's `register_udf` is last-write-wins
101/// by registered name, so session entries shadow instance entries
102/// without any explicit ordering logic.
103///
104/// This is the resolution path used per-query when a `Session` carries
105/// a session-local plugin registry. See `proposal §5.4.2` for the
106/// session-scope contract and the M8.6 follow-up plan.
107///
108/// # Errors
109///
110/// Returns an error if any UDF registration fails.
111pub fn register_plugin_scalar_udfs_pair(
112    ctx: &SessionContext,
113    instance: &uni_plugin::PluginRegistry,
114    session: Option<&uni_plugin::PluginRegistry>,
115) -> DFResult<()> {
116    register_plugin_scalar_udfs(ctx, instance)?;
117    if let Some(session_reg) = session {
118        register_plugin_scalar_udfs(ctx, session_reg)?;
119    }
120    Ok(())
121}
122
123/// Register every scalar function in a `PluginRegistry` as a DataFusion UDF.
124///
125/// M2's plugin-path DataFusion adapter — iterates
126/// [`uni_plugin::PluginRegistry`] directly.
127///
128/// Registers each scalar as both lowercase and uppercase local-name
129/// variants so Cypher's case-insensitive function-name match resolves.
130/// The qname's namespace is preserved (Cypher syntax uses dotted names for
131/// qualified callable references).
132///
133/// # Errors
134///
135/// Returns an error if any UDF registration fails.
136pub fn register_plugin_scalar_udfs(
137    ctx: &SessionContext,
138    plugin_registry: &uni_plugin::PluginRegistry,
139) -> DFResult<()> {
140    for (qname, entry) in plugin_registry.iter_scalars() {
141        let local = qname.local();
142        let lower_local = local.to_lowercase();
143        let upper_local = local.to_uppercase();
144
145        // Local-name registrations — what Cypher's case-insensitive
146        // lookup hits.
147        if lower_local != upper_local {
148            ctx.register_udf(ScalarUDF::new_from_impl(PluginScalarUdf::new(
149                lower_local.clone(),
150                Arc::clone(&entry),
151            )));
152        }
153        ctx.register_udf(ScalarUDF::new_from_impl(PluginScalarUdf::new(
154            upper_local,
155            Arc::clone(&entry),
156        )));
157
158        // Also register under the fully-qualified name (`namespace.local`)
159        // so dotted-name dispatch works.
160        ctx.register_udf(ScalarUDF::new_from_impl(PluginScalarUdf::new(
161            qname.to_string(),
162            Arc::clone(&entry),
163        )));
164    }
165    Ok(())
166}
167
168/// Build a shadow [`uni_plugin::PluginRegistry`] from the pure
169/// [`CustomFunctionRegistry`] leaf type.
170///
171/// `CustomFunctionRegistry` lives in the dependency-light
172/// `uni-query-functions` crate and stores only `(name, fn)` pairs. The
173/// plugin-framework dispatch path (`register_plugin_scalar_udfs`) consumes a
174/// `PluginRegistry`, so we mirror every legacy registration into one under
175/// the reserved [`LEGACY_USER_PLUGIN_ID`] here, where the `uni-plugin`
176/// dependency is available.
177///
178/// Each entry is wrapped in a [`ValueRowFn`] adapter and given a permissive
179/// `CypherValue` signature — the actual coercion happens at the DataFusion
180/// adapter ([`PluginScalarUdf`]) site.
181fn plugin_registry_for_custom_functions(
182    registry: &CustomFunctionRegistry,
183) -> uni_plugin::PluginRegistry {
184    use datafusion::logical_expr::Volatility;
185    use uni_plugin::traits::scalar::{ArgType, FnSignature, NullHandling};
186    use uni_plugin::{Capability, CapabilitySet, PluginId, PluginRegistrar, PluginRegistry, QName};
187
188    let pr = PluginRegistry::new();
189    let plugin_id = PluginId::new(LEGACY_USER_PLUGIN_ID);
190    let caps = CapabilitySet::from_iter_of([Capability::ScalarFn]);
191
192    for (name, func) in registry.iter() {
193        // `CustomFunctionRegistry` already uppercases names on registration,
194        // but normalize defensively so the qname matches the plugin id's
195        // namespace expectations.
196        let upper = name.to_uppercase();
197        let mut r = PluginRegistrar::new(plugin_id.clone(), &caps, &pr);
198        let qname = QName::new(LEGACY_USER_PLUGIN_ID, &upper);
199        let adapter = Arc::new(ValueRowFn::new(upper.clone(), Arc::clone(func)));
200        let sig = FnSignature {
201            args: vec![ArgType::Variadic(Box::new(ArgType::CypherValue))],
202            returns: ArgType::CypherValue,
203            volatility: Volatility::Volatile,
204            null_handling: NullHandling::UserHandled,
205        };
206        if let Err(e) = r.scalar_fn(qname, sig, adapter) {
207            tracing::warn!(error = ?e, fn_name = %upper, "shadow registration failed");
208            continue;
209        }
210        if let Err(e) = r.commit_to_registry() {
211            tracing::warn!(error = ?e, fn_name = %upper, "shadow commit failed");
212        }
213    }
214    pr
215}
216
217/// Register the legacy [`CustomFunctionRegistry`] entries as DataFusion
218/// scalar UDFs by mirroring them through the plugin-framework adapter.
219///
220/// This is the instance-scope, legacy `CustomFunctionRegistry` shadow path
221/// (e.g. `db.register_function()` entries + apoc-core mirrors).
222///
223/// # Errors
224///
225/// Returns an error if any UDF registration fails.
226pub fn register_custom_functions_as_plugin_scalars(
227    ctx: &SessionContext,
228    registry: &CustomFunctionRegistry,
229) -> DFResult<()> {
230    let shadow = plugin_registry_for_custom_functions(registry);
231    register_plugin_scalar_udfs(ctx, &shadow)
232}
233
234/// DataFusion adapter wrapping a [`uni_plugin::registry::ScalarEntry`].
235///
236/// Inspects the plugin's `signature.returns` at construction time to pick
237/// the DataFusion return type:
238///
239/// - `ArgType::Primitive(T)` → declares `T` directly to DataFusion. The
240///   plugin's `invoke()` returns Arrow data in `T`'s native type, no
241///   LargeBinary round-trip. This is the ≥ 20% perf target path for
242///   primitively-typed UDFs.
243/// - `ArgType::CypherValue` → declares `LargeBinary` (legacy transport).
244/// - `ArgType::Vector { .. }` / `Variadic(..)` → `LargeBinary` for now.
245///
246/// The same adapter is used for both the local-name and qualified-name
247/// registrations.
248struct PluginScalarUdf {
249    name: String,
250    entry: Arc<uni_plugin::registry::ScalarEntry>,
251    signature: Signature,
252    return_type: DataType,
253}
254
255impl PluginScalarUdf {
256    fn new(name: String, entry: Arc<uni_plugin::registry::ScalarEntry>) -> Self {
257        // Derive volatility from the plugin's declared volatility, falling
258        // back to Volatile if signature inspection fails. (The plugin's
259        // FnSignature is the canonical source of truth.)
260        let volatility = entry.signature.volatility;
261        let return_type = derive_return_type(&entry);
262        Self {
263            signature: Signature::new(TypeSignature::VariadicAny, volatility),
264            name,
265            entry,
266            return_type,
267        }
268    }
269}
270
271/// Derive the DataFusion return type from the plugin's declared signature.
272fn derive_return_type(entry: &uni_plugin::registry::ScalarEntry) -> DataType {
273    use uni_plugin::traits::scalar::ArgType;
274    match &entry.signature.returns {
275        ArgType::Primitive(t) => t.clone(),
276        // CypherValue + Vector + Variadic stay on the LargeBinary path
277        // (the latter two are uncommon for return types; CypherValue is
278        // explicit opt-in to the legacy transport).
279        _ => DataType::LargeBinary,
280    }
281}
282
283impl std::fmt::Debug for PluginScalarUdf {
284    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
285        f.debug_struct("PluginScalarUdf")
286            .field("name", &self.name)
287            .finish()
288    }
289}
290
291impl PartialEq for PluginScalarUdf {
292    fn eq(&self, other: &Self) -> bool {
293        self.signature == other.signature
294    }
295}
296
297impl Eq for PluginScalarUdf {}
298
299impl Hash for PluginScalarUdf {
300    fn hash<H: Hasher>(&self, state: &mut H) {
301        self.name().hash(state);
302    }
303}
304
305impl ScalarUDFImpl for PluginScalarUdf {
306    fn as_any(&self) -> &dyn Any {
307        self
308    }
309
310    fn name(&self) -> &str {
311        &self.name
312    }
313
314    fn signature(&self) -> &Signature {
315        &self.signature
316    }
317
318    fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
319        Ok(self.return_type.clone())
320    }
321
322    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
323        let entry = Arc::clone(&self.entry);
324        let rows = args.number_rows;
325        let cols = args.args;
326        entry.function.invoke(&cols, rows).map_err(|e| {
327            datafusion::error::DataFusionError::Execution(format!(
328                "plugin `{}` fn `{}` failed: {e}",
329                entry.plugin, self.name
330            ))
331        })
332    }
333}
334
#[cfg(test)]
mod tests {
    use super::*;
    // `SessionContext::udf` is a `FunctionRegistry` trait method; bring the
    // trait into scope so the registration assertions resolve. `Volatility`
    // is used only by the in-test plugin fixtures.
    use datafusion::execution::FunctionRegistry;
    use datafusion::logical_expr::Volatility;

    #[test]
    fn test_register_plugin_scalars_routes_through_plugin_registry() {
        // M2 facade test. Register a scalar via the legacy
        // `CustomFunctionRegistry` and push it through
        // `register_custom_functions_as_plugin_scalars`, which mirrors it
        // into a shadow `PluginRegistry` and calls
        // `register_plugin_scalar_udfs`. Then verify that DataFusion exposes
        // the fn under both case-folds and the fully-qualified legacy
        // namespace.
        use uni_common::Value;
        use uni_query_functions::custom_functions::{CustomFunctionRegistry, CustomScalarFn};

        let mut reg = CustomFunctionRegistry::new();
        let f: CustomScalarFn =
            Arc::new(|_args: &[Value]| Ok(Value::String("plugin-path".to_owned())));
        reg.register("MYFN".to_owned(), f);

        let ctx = SessionContext::new();
        register_custom_functions_as_plugin_scalars(&ctx, &reg)
            .expect("legacy functions should register as plugin scalars");

        // Local-name lowercase form (Cypher case-insensitive dispatch).
        assert!(ctx.udf("myfn").is_ok());
        // Uppercase local name.
        assert!(ctx.udf("MYFN").is_ok());
        // Fully-qualified namespace form.
        let qname = format!("{LEGACY_USER_PLUGIN_ID}.MYFN");
        assert!(ctx.udf(&qname).is_ok());
    }

    #[test]
    fn test_native_arrow_udf_declares_primitive_return_type() {
        // M2 fast path: a plugin declaring `ArgType::Primitive(Float64)` as
        // its return type should produce a DataFusion UDF whose
        // `return_type` is `Float64`, not `LargeBinary`. This eliminates
        // the per-row LargeBinary round-trip.
        use std::sync::OnceLock;
        use uni_plugin::FnError;
        use uni_plugin::traits::scalar::{ArgType, FnSignature, NullHandling, ScalarPluginFn};
        use uni_plugin::{
            Capability, CapabilitySet, PluginId, PluginRegistrar, PluginRegistry, QName,
        };

        /// Fixture whose `invoke` echoes its first argument. Only its
        /// declared signature matters here; this test never invokes it.
        struct Passthrough;
        impl ScalarPluginFn for Passthrough {
            fn signature(&self) -> &FnSignature {
                static S: OnceLock<FnSignature> = OnceLock::new();
                S.get_or_init(|| FnSignature {
                    args: vec![ArgType::Primitive(DataType::Float64)],
                    returns: ArgType::Primitive(DataType::Float64),
                    volatility: Volatility::Immutable,
                    null_handling: NullHandling::PropagateNulls,
                })
            }
            fn invoke(
                &self,
                args: &[ColumnarValue],
                _rows: usize,
            ) -> Result<ColumnarValue, FnError> {
                Ok(args
                    .first()
                    .cloned()
                    .expect("fixture is declared with exactly one argument"))
            }
        }

        let pr = PluginRegistry::new();
        let caps = CapabilitySet::from_iter_of([Capability::ScalarFn]);
        let mut r = PluginRegistrar::new(PluginId::new("test.fast"), &caps, &pr);
        r.scalar_fn(
            QName::new("test.fast", "double"),
            FnSignature {
                args: vec![ArgType::Primitive(DataType::Float64)],
                returns: ArgType::Primitive(DataType::Float64),
                volatility: Volatility::Immutable,
                null_handling: NullHandling::PropagateNulls,
            },
            Arc::new(Passthrough),
        )
        .expect("fixture scalar should register");
        r.commit_to_registry()
            .expect("fixture registration should commit");

        let ctx = SessionContext::new();
        register_plugin_scalar_udfs(&ctx, &pr).expect("plugin scalars should register");

        // Resolve the UDF and ask DataFusion for its return type.
        let udf = ctx.udf("double").expect("udf registered");
        let rt = udf
            .return_type(&[DataType::Float64])
            .expect("return type should resolve");
        assert_eq!(
            rt,
            DataType::Float64,
            "primitive-typed plugin should declare Float64 directly, not LargeBinary"
        );
    }
}