Skip to main content

uni_query/query/
df_udaf_plugin.rs

1// Rust guideline compliant
2//! DataFusion adapter for [`uni_plugin::traits::aggregate::AggregatePluginFn`].
3//!
4//! Bridges plugin-registered aggregates (`AggregatePluginFn`) into the
5//! DataFusion `AggregateUDFImpl` surface so the Cypher planner can
6//! invoke `RETURN myAgg(x)` against any registry entry, not just the
7//! handful of built-ins hard-coded in `df_planner.rs::translate_aggregates`.
8//!
9//! M9 ships this in support of `uni.plugin.declareAggregate` (see
10//! `uni-plugin-custom::DeclaredAggregateFn`). The adapter is generic
11//! across any `AggregatePluginFn` source — it does not assume the
12//! declared shape.
13//!
14//! # State model
15//!
16//! Plugin aggregates' `AggSignature.state_fields` declares the schema
17//! of partial state for distributed aggregation. The M9 declared
18//! aggregates ship with `state_fields: vec![]` and
19//! `supports_partial: false`; the adapter respects whatever the
20//! registry entry declares.
21
22use std::any::Any;
23use std::hash::{Hash, Hasher};
24use std::sync::{Arc, Mutex};
25
26use arrow::array::ArrayRef;
27use arrow::datatypes::Field;
28use arrow_schema::DataType;
29use datafusion::error::{DataFusionError, Result as DFResult};
30use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs};
31use datafusion::logical_expr::{
32    Accumulator as DfAccumulator, AggregateUDFImpl, Signature, TypeSignature,
33};
34use datafusion::scalar::ScalarValue;
35use uni_plugin::traits::aggregate::{AggSignature, AggregatePluginFn, PluginAccumulator};
36use uni_plugin::traits::scalar::ArgType;
37use uni_plugin::{PluginRegistry, QName};
38
/// DataFusion `AggregateUDFImpl` wrapping a plugin-registered
/// aggregate looked up by [`QName`] in the shared [`PluginRegistry`].
///
/// The plugin implementation itself is not stored here. Each call to
/// `accumulator` re-fetches the registry entry, so hot-reload swaps
/// land for the next group; in-flight groups keep the accumulator
/// they were already given.
pub struct PluginAggregateUdaf {
    /// Registry key used to look the aggregate up; also the basis of
    /// equality between two adapters.
    qname: QName,
    /// `"<namespace>.<local>"` rendering of `qname`, reported as the
    /// UDAF name, hashed, and embedded in error messages.
    name: String,
    /// Shared registry queried whenever a new accumulator is built.
    registry: Arc<PluginRegistry>,
    /// Plugin-side signature captured at construction time; supplies
    /// the return type and the partial-state field list.
    sig: AggSignature,
    /// DataFusion-side signature derived from `sig` (argument count
    /// and volatility).
    df_signature: Signature,
}
51
52impl PluginAggregateUdaf {
53    /// Construct an adapter over the named registry entry.
54    ///
55    /// `qname` and `sig` are captured at planner time; the actual
56    /// `AggregatePluginFn` is fetched per-accumulator-construction from
57    /// `registry` so reloads pick up.
58    #[must_use]
59    pub fn new(qname: QName, registry: Arc<PluginRegistry>, sig: AggSignature) -> Self {
60        let arity = sig.args.len();
61        let df_signature = Signature::new(TypeSignature::Any(arity), sig.volatility);
62        let name = format!("{}.{}", qname.namespace(), qname.local());
63        Self {
64            qname,
65            name,
66            registry,
67            sig,
68            df_signature,
69        }
70    }
71
72    fn fetch(&self) -> DFResult<Arc<dyn AggregatePluginFn>> {
73        self.registry
74            .aggregate(&self.qname)
75            .map(|e| Arc::clone(&e.aggregate))
76            .ok_or_else(|| {
77                DataFusionError::Execution(format!(
78                    "PluginAggregateUdaf: registry entry for `{}` disappeared",
79                    self.name
80                ))
81            })
82    }
83}
84
85impl std::fmt::Debug for PluginAggregateUdaf {
86    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
87        f.debug_struct("PluginAggregateUdaf")
88            .field("qname", &self.qname)
89            .field("supports_partial", &self.sig.supports_partial)
90            .finish_non_exhaustive()
91    }
92}
93
94impl PartialEq for PluginAggregateUdaf {
95    fn eq(&self, other: &Self) -> bool {
96        self.qname == other.qname
97    }
98}
99
100impl Eq for PluginAggregateUdaf {}
101
102impl Hash for PluginAggregateUdaf {
103    fn hash<H: Hasher>(&self, state: &mut H) {
104        self.name.hash(state);
105    }
106}
107
108impl AggregateUDFImpl for PluginAggregateUdaf {
109    fn as_any(&self) -> &dyn Any {
110        self
111    }
112
113    fn name(&self) -> &str {
114        &self.name
115    }
116
117    fn signature(&self) -> &Signature {
118        &self.df_signature
119    }
120
121    fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
122        arg_type_to_arrow(&self.sig.returns).ok_or_else(|| {
123            DataFusionError::Execution(format!(
124                "PluginAggregateUdaf `{}`: non-Arrow return type",
125                self.name
126            ))
127        })
128    }
129
130    fn accumulator(&self, _args: AccumulatorArgs<'_>) -> DFResult<Box<dyn DfAccumulator>> {
131        let agg = self.fetch()?;
132        Ok(Box::new(PluginAccumulatorAdapter {
133            inner: Mutex::new(agg.create_accumulator()),
134        }))
135    }
136
137    fn state_fields(&self, _args: StateFieldsArgs<'_>) -> DFResult<Vec<Arc<Field>>> {
138        Ok(self
139            .sig
140            .state_fields
141            .iter()
142            .map(|f| Arc::new(f.clone()))
143            .collect())
144    }
145}
146
/// DataFusion `Accumulator` that forwards to a [`PluginAccumulator`].
///
/// DataFusion's [`DfAccumulator`] trait requires `Send + Sync`, while
/// the plugin trait only requires `Send`. The `Mutex` provides the
/// `Sync` upgrade without modifying the plugin ABI; under
/// DataFusion's `&mut self`-only call pattern the lock is uncontended
/// in practice.
struct PluginAccumulatorAdapter {
    /// The wrapped plugin accumulator; the mutex exists only to make
    /// the adapter `Sync`.
    inner: Mutex<Box<dyn PluginAccumulator>>,
}
157
158impl std::fmt::Debug for PluginAccumulatorAdapter {
159    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
160        f.debug_struct("PluginAccumulatorAdapter")
161            .finish_non_exhaustive()
162    }
163}
164
165impl PluginAccumulatorAdapter {
166    fn with_inner<F, R>(&self, f: F) -> DFResult<R>
167    where
168        F: FnOnce(&mut dyn PluginAccumulator) -> Result<R, uni_plugin::FnError>,
169    {
170        let mut guard = self
171            .inner
172            .lock()
173            .map_err(|e| DataFusionError::Execution(format!("plugin accumulator lock: {e}")))?;
174        f(guard.as_mut()).map_err(fn_err_to_df)
175    }
176}
177
178impl DfAccumulator for PluginAccumulatorAdapter {
179    fn update_batch(&mut self, values: &[ArrayRef]) -> DFResult<()> {
180        self.with_inner(|acc| acc.update_batch(values))
181    }
182
183    fn merge_batch(&mut self, states: &[ArrayRef]) -> DFResult<()> {
184        self.with_inner(|acc| acc.merge_batch(states))
185    }
186
187    fn evaluate(&mut self) -> DFResult<ScalarValue> {
188        self.with_inner(|acc| acc.evaluate())
189    }
190
191    fn state(&mut self) -> DFResult<Vec<ScalarValue>> {
192        self.with_inner(|acc| acc.state())
193    }
194
195    fn size(&self) -> usize {
196        self.inner
197            .lock()
198            .map(|g| g.size())
199            .unwrap_or(std::mem::size_of::<Self>())
200    }
201}
202
203fn fn_err_to_df(e: uni_plugin::FnError) -> DataFusionError {
204    DataFusionError::Execution(format!("plugin aggregate: {e}"))
205}
206
207/// Map a plugin [`ArgType`] to a concrete Arrow [`DataType`]. Returns
208/// `None` for non-Arrow shapes (`Variadic`).
209fn arg_type_to_arrow(a: &ArgType) -> Option<DataType> {
210    match a {
211        ArgType::Primitive(dt) => Some(dt.clone()),
212        // `CypherValue` plugins ride through `LargeBinary` opaquely.
213        ArgType::CypherValue => Some(DataType::LargeBinary),
214        ArgType::Vector { len, element } => Some(DataType::FixedSizeList(
215            Arc::new(Field::new("item", element.clone(), true)),
216            i32::try_from(*len).ok()?,
217        )),
218        ArgType::Variadic(_) => None,
219    }
220}