1use std::any::Any;
24use std::hash::{Hash, Hasher};
25use std::sync::Arc;
26
27use arrow::datatypes::DataType;
28use datafusion::error::Result as DFResult;
29use datafusion::logical_expr::{
30 ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, TypeSignature,
31};
32use datafusion::prelude::SessionContext;
33use uni_query_functions::custom_functions::{CustomFunctionRegistry, LEGACY_USER_PLUGIN_ID};
34
35use crate::query::executor::plugin_adapter::ValueRowFn;
36
tokio::task_local! {
    /// Plugin registry bound to the current session's task.
    ///
    /// Installed with [`scoped_with_session_plugin_registry`] (also used by
    /// [`scoped_with_session_context`]) and read with
    /// [`current_session_plugin_registry`], which yields `None` when no
    /// scope is active.
    pub static SESSION_PLUGIN_REGISTRY:
        std::sync::Arc<uni_plugin::PluginRegistry>;
}
51
52pub fn scoped_with_session_plugin_registry<F: std::future::Future>(
58 registry: std::sync::Arc<uni_plugin::PluginRegistry>,
59 fut: F,
60) -> tokio::task::futures::TaskLocalFuture<std::sync::Arc<uni_plugin::PluginRegistry>, F> {
61 SESSION_PLUGIN_REGISTRY.scope(registry, fut)
62}
63
64#[must_use]
70pub fn current_session_plugin_registry() -> Option<std::sync::Arc<uni_plugin::PluginRegistry>> {
71 SESSION_PLUGIN_REGISTRY.try_with(|r| r.clone()).ok()
72}
73
74pub use uni_plugin::host::principal::{
80 CURRENT_PRINCIPAL, current_principal, maybe_scope_with_principal, scoped_with_principal,
81};
82
83pub async fn scoped_with_session_context<F: std::future::Future>(
91 registry: std::sync::Arc<uni_plugin::PluginRegistry>,
92 principal: Option<std::sync::Arc<uni_plugin::traits::connector::Principal>>,
93 fut: F,
94) -> F::Output {
95 scoped_with_session_plugin_registry(registry, maybe_scope_with_principal(principal, fut)).await
96}
97
98pub fn register_plugin_scalar_udfs_pair(
112 ctx: &SessionContext,
113 instance: &uni_plugin::PluginRegistry,
114 session: Option<&uni_plugin::PluginRegistry>,
115) -> DFResult<()> {
116 register_plugin_scalar_udfs(ctx, instance)?;
117 if let Some(session_reg) = session {
118 register_plugin_scalar_udfs(ctx, session_reg)?;
119 }
120 Ok(())
121}
122
123pub fn register_plugin_scalar_udfs(
137 ctx: &SessionContext,
138 plugin_registry: &uni_plugin::PluginRegistry,
139) -> DFResult<()> {
140 for (qname, entry) in plugin_registry.iter_scalars() {
141 let local = qname.local();
142 let lower_local = local.to_lowercase();
143 let upper_local = local.to_uppercase();
144
145 if lower_local != upper_local {
148 ctx.register_udf(ScalarUDF::new_from_impl(PluginScalarUdf::new(
149 lower_local.clone(),
150 Arc::clone(&entry),
151 )));
152 }
153 ctx.register_udf(ScalarUDF::new_from_impl(PluginScalarUdf::new(
154 upper_local,
155 Arc::clone(&entry),
156 )));
157
158 ctx.register_udf(ScalarUDF::new_from_impl(PluginScalarUdf::new(
161 qname.to_string(),
162 Arc::clone(&entry),
163 )));
164 }
165 Ok(())
166}
167
168fn plugin_registry_for_custom_functions(
182 registry: &CustomFunctionRegistry,
183) -> uni_plugin::PluginRegistry {
184 use datafusion::logical_expr::Volatility;
185 use uni_plugin::traits::scalar::{ArgType, FnSignature, NullHandling};
186 use uni_plugin::{Capability, CapabilitySet, PluginId, PluginRegistrar, PluginRegistry, QName};
187
188 let pr = PluginRegistry::new();
189 let plugin_id = PluginId::new(LEGACY_USER_PLUGIN_ID);
190 let caps = CapabilitySet::from_iter_of([Capability::ScalarFn]);
191
192 for (name, func) in registry.iter() {
193 let upper = name.to_uppercase();
197 let mut r = PluginRegistrar::new(plugin_id.clone(), &caps, &pr);
198 let qname = QName::new(LEGACY_USER_PLUGIN_ID, &upper);
199 let adapter = Arc::new(ValueRowFn::new(upper.clone(), Arc::clone(func)));
200 let sig = FnSignature {
201 args: vec![ArgType::Variadic(Box::new(ArgType::CypherValue))],
202 returns: ArgType::CypherValue,
203 volatility: Volatility::Volatile,
204 null_handling: NullHandling::UserHandled,
205 };
206 if let Err(e) = r.scalar_fn(qname, sig, adapter) {
207 tracing::warn!(error = ?e, fn_name = %upper, "shadow registration failed");
208 continue;
209 }
210 if let Err(e) = r.commit_to_registry() {
211 tracing::warn!(error = ?e, fn_name = %upper, "shadow commit failed");
212 }
213 }
214 pr
215}
216
217pub fn register_custom_functions_as_plugin_scalars(
227 ctx: &SessionContext,
228 registry: &CustomFunctionRegistry,
229) -> DFResult<()> {
230 let shadow = plugin_registry_for_custom_functions(registry);
231 register_plugin_scalar_udfs(ctx, &shadow)
232}
233
234struct PluginScalarUdf {
249 name: String,
250 entry: Arc<uni_plugin::registry::ScalarEntry>,
251 signature: Signature,
252 return_type: DataType,
253}
254
255impl PluginScalarUdf {
256 fn new(name: String, entry: Arc<uni_plugin::registry::ScalarEntry>) -> Self {
257 let volatility = entry.signature.volatility;
261 let return_type = derive_return_type(&entry);
262 Self {
263 signature: Signature::new(TypeSignature::VariadicAny, volatility),
264 name,
265 entry,
266 return_type,
267 }
268 }
269}
270
271fn derive_return_type(entry: &uni_plugin::registry::ScalarEntry) -> DataType {
273 use uni_plugin::traits::scalar::ArgType;
274 match &entry.signature.returns {
275 ArgType::Primitive(t) => t.clone(),
276 _ => DataType::LargeBinary,
280 }
281}
282
283impl std::fmt::Debug for PluginScalarUdf {
284 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
285 f.debug_struct("PluginScalarUdf")
286 .field("name", &self.name)
287 .finish()
288 }
289}
290
291fn declared_arg_type(
295 args: &[uni_plugin::traits::scalar::ArgType],
296 i: usize,
297) -> Option<&uni_plugin::traits::scalar::ArgType> {
298 use uni_plugin::traits::scalar::ArgType;
299 let raw = match args.get(i) {
300 Some(a) => Some(a),
301 None => match args.last() {
303 Some(v @ ArgType::Variadic(_)) => Some(v),
304 _ => None,
305 },
306 };
307 match raw {
308 Some(ArgType::Variadic(inner)) => Some(inner.as_ref()),
309 other => other,
310 }
311}
312
313fn coerce_plugin_scalar_arg(
330 col: ColumnarValue,
331 declared: Option<&uni_plugin::traits::scalar::ArgType>,
332 rows: usize,
333 arg_idx: usize,
334 fn_name: &str,
335) -> DFResult<ColumnarValue> {
336 use arrow::array::{Array, ArrayRef, Float64Array, Int64Array, LargeBinaryArray};
337 use uni_common::Value;
338 use uni_plugin::traits::scalar::ArgType;
339
340 let target = match declared {
341 Some(ArgType::Primitive(t @ (DataType::Int64 | DataType::Float64))) => t.clone(),
342 _ => return Ok(col),
343 };
344
345 let array = col.to_array(rows)?;
346 if array.data_type() != &DataType::LargeBinary {
348 return Ok(col);
349 }
350 let lb = array
351 .as_any()
352 .downcast_ref::<LargeBinaryArray>()
353 .expect("data_type checked to be LargeBinary");
354
355 let non_numeric_err = |row: usize, got: &Value| {
356 let hint = if target == DataType::Int64 {
357 "toInteger(...)"
358 } else {
359 "toFloat(...)"
360 };
361 datafusion::error::DataFusionError::Execution(format!(
362 "plugin fn `{fn_name}`: argument {} declares {target} but row {row} carried a \
363 non-numeric value ({got:?}); wrap the property with {hint}",
364 arg_idx + 1,
365 ))
366 };
367
368 let decoded =
369 |row: usize| uni_store::storage::arrow_convert::arrow_to_value(array.as_ref(), row, None);
370
371 let out: ArrayRef = match target {
372 DataType::Int64 => {
373 let mut b = Int64Array::builder(array.len());
374 for row in 0..array.len() {
375 if lb.is_null(row) || lb.value(row).is_empty() {
376 b.append_null();
377 continue;
378 }
379 match decoded(row) {
380 Value::Int(i) => b.append_value(i),
381 Value::Float(f) => b.append_value(f as i64),
382 Value::Null => b.append_null(),
383 other => return Err(non_numeric_err(row, &other)),
384 }
385 }
386 Arc::new(b.finish())
387 }
388 DataType::Float64 => {
389 let mut b = Float64Array::builder(array.len());
390 for row in 0..array.len() {
391 if lb.is_null(row) || lb.value(row).is_empty() {
392 b.append_null();
393 continue;
394 }
395 match decoded(row) {
396 Value::Float(f) => b.append_value(f),
397 Value::Int(i) => b.append_value(i as f64),
398 Value::Null => b.append_null(),
399 other => return Err(non_numeric_err(row, &other)),
400 }
401 }
402 Arc::new(b.finish())
403 }
404 _ => unreachable!("target restricted to Int64/Float64 above"),
405 };
406 Ok(ColumnarValue::Array(out))
407}
408
409impl PartialEq for PluginScalarUdf {
410 fn eq(&self, other: &Self) -> bool {
411 self.signature == other.signature
412 }
413}
414
415impl Eq for PluginScalarUdf {}
416
417impl Hash for PluginScalarUdf {
418 fn hash<H: Hasher>(&self, state: &mut H) {
419 self.name().hash(state);
420 }
421}
422
423impl ScalarUDFImpl for PluginScalarUdf {
424 fn as_any(&self) -> &dyn Any {
425 self
426 }
427
428 fn name(&self) -> &str {
429 &self.name
430 }
431
432 fn signature(&self) -> &Signature {
433 &self.signature
434 }
435
436 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
437 Ok(self.return_type.clone())
438 }
439
440 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
441 let entry = Arc::clone(&self.entry);
442 let rows = args.number_rows;
443 let declared = &entry.signature.args;
447 let cols = args
448 .args
449 .into_iter()
450 .enumerate()
451 .map(|(i, col)| {
452 coerce_plugin_scalar_arg(col, declared_arg_type(declared, i), rows, i, &self.name)
453 })
454 .collect::<DFResult<Vec<_>>>()?;
455 entry.function.invoke(&cols, rows).map_err(|e| {
456 datafusion::error::DataFusionError::Execution(format!(
457 "plugin `{}` fn `{}` failed: {e}",
458 entry.plugin, self.name
459 ))
460 })
461 }
462}
463
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion::execution::FunctionRegistry;
    use datafusion::logical_expr::Volatility;

    /// Legacy custom functions routed through the shadow plugin registry must
    /// be reachable by lower-case, upper-case and plugin-qualified names.
    #[test]
    fn test_register_plugin_scalars_routes_through_plugin_registry() {
        use uni_common::Value;
        use uni_query_functions::custom_functions::{CustomFunctionRegistry, CustomScalarFn};

        let mut reg = CustomFunctionRegistry::new();
        let f: CustomScalarFn =
            Arc::new(|_args: &[Value]| Ok(Value::String("plugin-path".to_owned())));
        reg.register("MYFN".to_owned(), f);

        let ctx = SessionContext::new();
        register_custom_functions_as_plugin_scalars(&ctx, &reg).unwrap();

        assert!(ctx.udf("myfn").is_ok());
        assert!(ctx.udf("MYFN").is_ok());
        let qname = format!("{LEGACY_USER_PLUGIN_ID}.MYFN");
        assert!(ctx.udf(&qname).is_ok());
    }

    /// A plugin declaring a primitive return type must surface that Arrow
    /// type directly rather than the `LargeBinary` fallback.
    #[test]
    fn test_native_arrow_udf_declares_primitive_return_type() {
        use std::sync::OnceLock;
        use uni_plugin::FnError;
        use uni_plugin::traits::scalar::{ArgType, FnSignature, NullHandling, ScalarPluginFn};
        use uni_plugin::{
            Capability, CapabilitySet, PluginId, PluginRegistrar, PluginRegistry, QName,
        };

        // Test stub: returns its first argument unchanged. Only the declared
        // return type is under test here, not the computation.
        struct DoubleIt;
        impl ScalarPluginFn for DoubleIt {
            fn signature(&self) -> &FnSignature {
                static S: OnceLock<FnSignature> = OnceLock::new();
                S.get_or_init(|| FnSignature {
                    args: vec![ArgType::Primitive(DataType::Float64)],
                    returns: ArgType::Primitive(DataType::Float64),
                    volatility: Volatility::Immutable,
                    null_handling: NullHandling::PropagateNulls,
                })
            }
            fn invoke(
                &self,
                args: &[ColumnarValue],
                _rows: usize,
            ) -> Result<ColumnarValue, FnError> {
                Ok(args.first().cloned().unwrap())
            }
        }

        let pr = PluginRegistry::new();
        let caps = CapabilitySet::from_iter_of([Capability::ScalarFn]);
        let mut r = PluginRegistrar::new(PluginId::new("test.fast"), &caps, &pr);
        r.scalar_fn(
            QName::new("test.fast", "double"),
            FnSignature {
                args: vec![ArgType::Primitive(DataType::Float64)],
                returns: ArgType::Primitive(DataType::Float64),
                volatility: Volatility::Immutable,
                null_handling: NullHandling::PropagateNulls,
            },
            Arc::new(DoubleIt),
        )
        .unwrap();
        r.commit_to_registry().unwrap();

        let ctx = SessionContext::new();
        register_plugin_scalar_udfs(&ctx, &pr).unwrap();

        let udf = ctx.udf("double").expect("udf registered");
        let rt = udf.return_type(&[DataType::Float64]).unwrap();
        assert_eq!(
            rt,
            DataType::Float64,
            "primitive-typed plugin should declare Float64 directly, not LargeBinary"
        );
    }
}