uni_query/query/df_udfs_plugin.rs
1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2024-2026 Dragonscale Team
3
4//! Plugin-registry scalar-UDF integration for DataFusion.
5//!
6//! This module bridges the `uni-plugin` plugin system into DataFusion's
7//! scalar-UDF surface. It lives in `uni-query` (not the dependency-light
8//! `uni-query-functions` leaf crate) because it depends on `uni-plugin`
9//! and `tokio` task-locals.
10//!
11//! Responsibilities:
12//!
13//! - The `SESSION_PLUGIN_REGISTRY` tokio task-local and its scope/read
14//! helpers ([`scoped_with_session_plugin_registry`],
15//! [`current_session_plugin_registry`]).
16//! - Re-export of the principal task-local helpers from
17//! `uni_plugin::host::principal` so callers keep their existing
18//! `uni_query::scoped_with_principal` / `current_principal` paths.
19//! - [`scoped_with_session_context`] combining both scopes.
20//! - [`register_plugin_scalar_udfs`] / [`register_plugin_scalar_udfs_pair`]
21//! and the `PluginScalarUdf` DataFusion adapter (private).
22
23use std::any::Any;
24use std::hash::{Hash, Hasher};
25use std::sync::Arc;
26
27use arrow::datatypes::DataType;
28use datafusion::error::Result as DFResult;
29use datafusion::logical_expr::{
30 ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, TypeSignature,
31};
32use datafusion::prelude::SessionContext;
33use uni_query_functions::custom_functions::{CustomFunctionRegistry, LEGACY_USER_PLUGIN_ID};
34
35use crate::query::executor::plugin_adapter::ValueRowFn;
36
37tokio::task_local! {
38 /// Tokio task-local carrying the current `Session`'s
39 /// **session-local** plugin registry across the per-query
40 /// executor scope. Set by host-crate session execute paths via
41 /// [`scoped_with_session_plugin_registry`]; read at the UDF
42 /// registration site (`register_plugin_scalar_udfs_pair`) and at
43 /// the procedure / Locy-aggregate dual-consult helpers.
44 ///
45 /// Propagates across `.await` points within the same task tree;
46 /// does NOT propagate across `tokio::spawn` (which is fine — the
47 /// per-query executor runs everything in the same task).
48 pub static SESSION_PLUGIN_REGISTRY:
49 std::sync::Arc<uni_plugin::PluginRegistry>;
50}
51
52/// Run `fut` inside a scope that exposes `registry` as the current
53/// session-local plugin registry. Returns the future's output.
54///
55/// Use this at every uni-db host-crate boundary where a `Session`
56/// dispatches into the executor.
57pub fn scoped_with_session_plugin_registry<F: std::future::Future>(
58 registry: std::sync::Arc<uni_plugin::PluginRegistry>,
59 fut: F,
60) -> tokio::task::futures::TaskLocalFuture<std::sync::Arc<uni_plugin::PluginRegistry>, F> {
61 SESSION_PLUGIN_REGISTRY.scope(registry, fut)
62}
63
64/// Borrow the current session-local plugin registry, if any. Returns
65/// `None` when the call is not inside a
66/// [`scoped_with_session_plugin_registry`] scope (e.g., a query
67/// against `Uni` directly with no Session in flight, or a unit test
68/// invoking the executor outside the host crate).
69#[must_use]
70pub fn current_session_plugin_registry() -> Option<std::sync::Arc<uni_plugin::PluginRegistry>> {
71 SESSION_PLUGIN_REGISTRY.try_with(|r| r.clone()).ok()
72}
73
74// §1.2 / Phase 5: the principal task-local + scope helpers moved to
75// `uni_plugin::host::principal`. Re-exported here so external callers
76// (`uni::api::{session,transaction}`, downstream embedders) keep their
77// existing `uni_query::scoped_with_principal` / `current_principal`
78// paths.
79pub use uni_plugin::host::principal::{
80 CURRENT_PRINCIPAL, current_principal, maybe_scope_with_principal, scoped_with_principal,
81};
82
83/// Run `fut` inside both [`scoped_with_session_plugin_registry`] and
84/// the principal task-local scope in a single call.
85///
86/// `principal` is optional — when `None`, only the plugin-registry
87/// scope is installed and [`current_principal`] returns `None` inside
88/// `fut`. This matches the legacy behavior for sessions that haven't
89/// authenticated.
90pub async fn scoped_with_session_context<F: std::future::Future>(
91 registry: std::sync::Arc<uni_plugin::PluginRegistry>,
92 principal: Option<std::sync::Arc<uni_plugin::traits::connector::Principal>>,
93 fut: F,
94) -> F::Output {
95 scoped_with_session_plugin_registry(registry, maybe_scope_with_principal(principal, fut)).await
96}
97
98/// Two-registry variant of [`register_plugin_scalar_udfs`] — registers
99/// the instance registry's scalars first, then the session registry's
100/// (if present) on top. DataFusion's `register_udf` is last-write-wins
101/// by registered name, so session entries shadow instance entries
102/// without any explicit ordering logic.
103///
104/// This is the resolution path used per-query when a `Session` carries
105/// a session-local plugin registry. See `proposal §5.4.2` for the
106/// session-scope contract and the M8.6 follow-up plan.
107///
108/// # Errors
109///
110/// Returns an error if any UDF registration fails.
111pub fn register_plugin_scalar_udfs_pair(
112 ctx: &SessionContext,
113 instance: &uni_plugin::PluginRegistry,
114 session: Option<&uni_plugin::PluginRegistry>,
115) -> DFResult<()> {
116 register_plugin_scalar_udfs(ctx, instance)?;
117 if let Some(session_reg) = session {
118 register_plugin_scalar_udfs(ctx, session_reg)?;
119 }
120 Ok(())
121}
122
123/// Register every scalar function in a `PluginRegistry` as a DataFusion UDF.
124///
125/// M2's plugin-path DataFusion adapter — iterates
126/// [`uni_plugin::PluginRegistry`] directly.
127///
128/// Registers each scalar as both lowercase and uppercase local-name
129/// variants so Cypher's case-insensitive function-name match resolves.
130/// The qname's namespace is preserved (Cypher syntax uses dotted names for
131/// qualified callable references).
132///
133/// # Errors
134///
135/// Returns an error if any UDF registration fails.
136pub fn register_plugin_scalar_udfs(
137 ctx: &SessionContext,
138 plugin_registry: &uni_plugin::PluginRegistry,
139) -> DFResult<()> {
140 for (qname, entry) in plugin_registry.iter_scalars() {
141 let local = qname.local();
142 let lower_local = local.to_lowercase();
143 let upper_local = local.to_uppercase();
144
145 // Local-name registrations — what Cypher's case-insensitive
146 // lookup hits.
147 if lower_local != upper_local {
148 ctx.register_udf(ScalarUDF::new_from_impl(PluginScalarUdf::new(
149 lower_local.clone(),
150 Arc::clone(&entry),
151 )));
152 }
153 ctx.register_udf(ScalarUDF::new_from_impl(PluginScalarUdf::new(
154 upper_local,
155 Arc::clone(&entry),
156 )));
157
158 // Also register under the fully-qualified name (`namespace.local`)
159 // so dotted-name dispatch works.
160 ctx.register_udf(ScalarUDF::new_from_impl(PluginScalarUdf::new(
161 qname.to_string(),
162 Arc::clone(&entry),
163 )));
164 }
165 Ok(())
166}
167
168/// Build a shadow [`uni_plugin::PluginRegistry`] from the pure
169/// [`CustomFunctionRegistry`] leaf type.
170///
171/// `CustomFunctionRegistry` lives in the dependency-light
172/// `uni-query-functions` crate and stores only `(name, fn)` pairs. The
173/// plugin-framework dispatch path (`register_plugin_scalar_udfs`) consumes a
174/// `PluginRegistry`, so we mirror every legacy registration into one under
175/// the reserved [`LEGACY_USER_PLUGIN_ID`] here, where the `uni-plugin`
176/// dependency is available.
177///
178/// Each entry is wrapped in a [`ValueRowFn`] adapter and given a permissive
179/// `CypherValue` signature — the actual coercion happens at the DataFusion
180/// adapter ([`PluginScalarUdf`]) site.
181fn plugin_registry_for_custom_functions(
182 registry: &CustomFunctionRegistry,
183) -> uni_plugin::PluginRegistry {
184 use datafusion::logical_expr::Volatility;
185 use uni_plugin::traits::scalar::{ArgType, FnSignature, NullHandling};
186 use uni_plugin::{Capability, CapabilitySet, PluginId, PluginRegistrar, PluginRegistry, QName};
187
188 let pr = PluginRegistry::new();
189 let plugin_id = PluginId::new(LEGACY_USER_PLUGIN_ID);
190 let caps = CapabilitySet::from_iter_of([Capability::ScalarFn]);
191
192 for (name, func) in registry.iter() {
193 // `CustomFunctionRegistry` already uppercases names on registration,
194 // but normalize defensively so the qname matches the plugin id's
195 // namespace expectations.
196 let upper = name.to_uppercase();
197 let mut r = PluginRegistrar::new(plugin_id.clone(), &caps, &pr);
198 let qname = QName::new(LEGACY_USER_PLUGIN_ID, &upper);
199 let adapter = Arc::new(ValueRowFn::new(upper.clone(), Arc::clone(func)));
200 let sig = FnSignature {
201 args: vec![ArgType::Variadic(Box::new(ArgType::CypherValue))],
202 returns: ArgType::CypherValue,
203 volatility: Volatility::Volatile,
204 null_handling: NullHandling::UserHandled,
205 };
206 if let Err(e) = r.scalar_fn(qname, sig, adapter) {
207 tracing::warn!(error = ?e, fn_name = %upper, "shadow registration failed");
208 continue;
209 }
210 if let Err(e) = r.commit_to_registry() {
211 tracing::warn!(error = ?e, fn_name = %upper, "shadow commit failed");
212 }
213 }
214 pr
215}
216
217/// Register the legacy [`CustomFunctionRegistry`] entries as DataFusion
218/// scalar UDFs by mirroring them through the plugin-framework adapter.
219///
220/// This is the instance-scope, legacy `CustomFunctionRegistry` shadow path
221/// (e.g. `db.register_function()` entries + apoc-core mirrors).
222///
223/// # Errors
224///
225/// Returns an error if any UDF registration fails.
226pub fn register_custom_functions_as_plugin_scalars(
227 ctx: &SessionContext,
228 registry: &CustomFunctionRegistry,
229) -> DFResult<()> {
230 let shadow = plugin_registry_for_custom_functions(registry);
231 register_plugin_scalar_udfs(ctx, &shadow)
232}
233
234/// DataFusion adapter wrapping a [`uni_plugin::registry::ScalarEntry`].
235///
236/// Inspects the plugin's `signature.returns` at construction time to pick
237/// the DataFusion return type:
238///
239/// - `ArgType::Primitive(T)` → declares `T` directly to DataFusion. The
240/// plugin's `invoke()` returns Arrow data in `T`'s native type, no
241/// LargeBinary round-trip. This is the ≥ 20% perf target path for
242/// primitively-typed UDFs.
243/// - `ArgType::CypherValue` → declares `LargeBinary` (legacy transport).
244/// - `ArgType::Vector { .. }` / `Variadic(..)` → `LargeBinary` for now.
245///
246/// The same adapter is used for both the local-name and qualified-name
247/// registrations.
248struct PluginScalarUdf {
249 name: String,
250 entry: Arc<uni_plugin::registry::ScalarEntry>,
251 signature: Signature,
252 return_type: DataType,
253}
254
255impl PluginScalarUdf {
256 fn new(name: String, entry: Arc<uni_plugin::registry::ScalarEntry>) -> Self {
257 // Derive volatility from the plugin's declared volatility, falling
258 // back to Volatile if signature inspection fails. (The plugin's
259 // FnSignature is the canonical source of truth.)
260 let volatility = entry.signature.volatility;
261 let return_type = derive_return_type(&entry);
262 Self {
263 signature: Signature::new(TypeSignature::VariadicAny, volatility),
264 name,
265 entry,
266 return_type,
267 }
268 }
269}
270
271/// Derive the DataFusion return type from the plugin's declared signature.
272fn derive_return_type(entry: &uni_plugin::registry::ScalarEntry) -> DataType {
273 use uni_plugin::traits::scalar::ArgType;
274 match &entry.signature.returns {
275 ArgType::Primitive(t) => t.clone(),
276 // CypherValue + Vector + Variadic stay on the LargeBinary path
277 // (the latter two are uncommon for return types; CypherValue is
278 // explicit opt-in to the legacy transport).
279 _ => DataType::LargeBinary,
280 }
281}
282
283impl std::fmt::Debug for PluginScalarUdf {
284 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
285 f.debug_struct("PluginScalarUdf")
286 .field("name", &self.name)
287 .finish()
288 }
289}
290
291impl PartialEq for PluginScalarUdf {
292 fn eq(&self, other: &Self) -> bool {
293 self.signature == other.signature
294 }
295}
296
297impl Eq for PluginScalarUdf {}
298
299impl Hash for PluginScalarUdf {
300 fn hash<H: Hasher>(&self, state: &mut H) {
301 self.name().hash(state);
302 }
303}
304
305impl ScalarUDFImpl for PluginScalarUdf {
306 fn as_any(&self) -> &dyn Any {
307 self
308 }
309
310 fn name(&self) -> &str {
311 &self.name
312 }
313
314 fn signature(&self) -> &Signature {
315 &self.signature
316 }
317
318 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
319 Ok(self.return_type.clone())
320 }
321
322 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
323 let entry = Arc::clone(&self.entry);
324 let rows = args.number_rows;
325 let cols = args.args;
326 entry.function.invoke(&cols, rows).map_err(|e| {
327 datafusion::error::DataFusionError::Execution(format!(
328 "plugin `{}` fn `{}` failed: {e}",
329 entry.plugin, self.name
330 ))
331 })
332 }
333}
334
#[cfg(test)]
mod tests {
    use super::*;
    // `SessionContext::udf` is a `FunctionRegistry` trait method; bring the
    // trait into scope so the registration assertions resolve. `Volatility`
    // is used only by the in-test plugin fixtures.
    use datafusion::execution::FunctionRegistry;
    use datafusion::logical_expr::Volatility;

    /// A function registered in the legacy `CustomFunctionRegistry` must be
    /// reachable through DataFusion under every name the plugin path
    /// registers.
    #[test]
    fn test_register_plugin_scalars_routes_through_plugin_registry() {
        // M2 facade test: register a scalar via the legacy
        // `CustomFunctionRegistry`, mirror it through
        // `register_custom_functions_as_plugin_scalars`, then verify that
        // DataFusion exposes the same fn under both case-folds and the
        // fully-qualified namespace.
        use uni_common::Value;
        use uni_query_functions::custom_functions::{CustomFunctionRegistry, CustomScalarFn};

        let mut reg = CustomFunctionRegistry::new();
        let f: CustomScalarFn =
            Arc::new(|_args: &[Value]| Ok(Value::String("plugin-path".to_owned())));
        reg.register("MYFN".to_owned(), f);

        let ctx = SessionContext::new();
        register_custom_functions_as_plugin_scalars(&ctx, &reg).unwrap();

        // Local-name lowercase form (Cypher case-insensitive dispatch).
        assert!(ctx.udf("myfn").is_ok());
        // Uppercase local name.
        assert!(ctx.udf("MYFN").is_ok());
        // Fully-qualified namespace form.
        let qname = format!("{LEGACY_USER_PLUGIN_ID}.MYFN");
        assert!(ctx.udf(&qname).is_ok());
    }

    /// A plugin with a primitive return type must surface that type to
    /// DataFusion unchanged.
    #[test]
    fn test_native_arrow_udf_declares_primitive_return_type() {
        // M2 fast path: a plugin declaring `ArgType::Primitive(Float64)` as
        // its return type should produce a DataFusion UDF whose
        // `return_type` is `Float64`, not `LargeBinary`. This eliminates
        // the per-row LargeBinary round-trip.
        use std::sync::OnceLock;
        use uni_plugin::FnError;
        use uni_plugin::traits::scalar::{ArgType, FnSignature, NullHandling, ScalarPluginFn};
        use uni_plugin::{
            Capability, CapabilitySet, PluginId, PluginRegistrar, PluginRegistry, QName,
        };

        // Minimal fixture: echoes its first argument back.
        struct DoubleIt;
        impl ScalarPluginFn for DoubleIt {
            fn signature(&self) -> &FnSignature {
                static S: OnceLock<FnSignature> = OnceLock::new();
                S.get_or_init(|| FnSignature {
                    args: vec![ArgType::Primitive(DataType::Float64)],
                    returns: ArgType::Primitive(DataType::Float64),
                    volatility: Volatility::Immutable,
                    null_handling: NullHandling::PropagateNulls,
                })
            }
            fn invoke(
                &self,
                args: &[ColumnarValue],
                _rows: usize,
            ) -> Result<ColumnarValue, FnError> {
                Ok(args.first().cloned().unwrap())
            }
        }

        let pr = PluginRegistry::new();
        let caps = CapabilitySet::from_iter_of([Capability::ScalarFn]);
        let mut r = PluginRegistrar::new(PluginId::new("test.fast"), &caps, &pr);
        r.scalar_fn(
            QName::new("test.fast", "double"),
            FnSignature {
                args: vec![ArgType::Primitive(DataType::Float64)],
                returns: ArgType::Primitive(DataType::Float64),
                volatility: Volatility::Immutable,
                null_handling: NullHandling::PropagateNulls,
            },
            Arc::new(DoubleIt),
        )
        .unwrap();
        r.commit_to_registry().unwrap();

        let ctx = SessionContext::new();
        register_plugin_scalar_udfs(&ctx, &pr).unwrap();

        // Resolve the UDF and ask DataFusion for its return type.
        let udf = ctx.udf("double").expect("udf registered");
        let rt = udf.return_type(&[DataType::Float64]).unwrap();
        assert_eq!(
            rt,
            DataType::Float64,
            "primitive-typed plugin should declare Float64 directly, not LargeBinary"
        );
    }
}