uni_query/query/
df_udaf_plugin.rs1use std::any::Any;
23use std::hash::{Hash, Hasher};
24use std::sync::{Arc, Mutex};
25
26use arrow::array::ArrayRef;
27use arrow::datatypes::Field;
28use arrow_schema::DataType;
29use datafusion::error::{DataFusionError, Result as DFResult};
30use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs};
31use datafusion::logical_expr::{
32 Accumulator as DfAccumulator, AggregateUDFImpl, Signature, TypeSignature,
33};
34use datafusion::scalar::ScalarValue;
35use uni_plugin::traits::aggregate::{AggSignature, AggregatePluginFn, PluginAccumulator};
36use uni_plugin::traits::scalar::ArgType;
37use uni_plugin::{PluginRegistry, QName};
38
/// DataFusion aggregate UDF backed by an aggregate registered in a
/// [`PluginRegistry`].
///
/// The plugin implementation itself is not stored here; it is looked up in
/// `registry` by `qname` whenever DataFusion asks for a new accumulator.
pub struct PluginAggregateUdaf {
    /// Qualified name: the registry lookup key, and the field compared by `PartialEq`.
    qname: QName,
    /// Name reported to DataFusion, built as `<namespace>.<local>` from `qname`.
    name: String,
    /// Registry consulted when an accumulator is requested.
    registry: Arc<PluginRegistry>,
    /// Plugin-declared signature (arguments, return type, state fields,
    /// volatility, partial-aggregation flag).
    sig: AggSignature,
    /// DataFusion signature derived from `sig` at construction time.
    df_signature: Signature,
}
51
52impl PluginAggregateUdaf {
53 #[must_use]
59 pub fn new(qname: QName, registry: Arc<PluginRegistry>, sig: AggSignature) -> Self {
60 let arity = sig.args.len();
61 let df_signature = Signature::new(TypeSignature::Any(arity), sig.volatility);
62 let name = format!("{}.{}", qname.namespace(), qname.local());
63 Self {
64 qname,
65 name,
66 registry,
67 sig,
68 df_signature,
69 }
70 }
71
72 fn fetch(&self) -> DFResult<Arc<dyn AggregatePluginFn>> {
73 self.registry
74 .aggregate(&self.qname)
75 .map(|e| Arc::clone(&e.aggregate))
76 .ok_or_else(|| {
77 DataFusionError::Execution(format!(
78 "PluginAggregateUdaf: registry entry for `{}` disappeared",
79 self.name
80 ))
81 })
82 }
83}
84
85impl std::fmt::Debug for PluginAggregateUdaf {
86 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
87 f.debug_struct("PluginAggregateUdaf")
88 .field("qname", &self.qname)
89 .field("supports_partial", &self.sig.supports_partial)
90 .finish_non_exhaustive()
91 }
92}
93
94impl PartialEq for PluginAggregateUdaf {
95 fn eq(&self, other: &Self) -> bool {
96 self.qname == other.qname
97 }
98}
99
// Marker impl: equality is decided by the `PartialEq` impl, which compares `qname`.
impl Eq for PluginAggregateUdaf {}
101
102impl Hash for PluginAggregateUdaf {
103 fn hash<H: Hasher>(&self, state: &mut H) {
104 self.name.hash(state);
105 }
106}
107
108impl AggregateUDFImpl for PluginAggregateUdaf {
109 fn as_any(&self) -> &dyn Any {
110 self
111 }
112
113 fn name(&self) -> &str {
114 &self.name
115 }
116
117 fn signature(&self) -> &Signature {
118 &self.df_signature
119 }
120
121 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
122 arg_type_to_arrow(&self.sig.returns).ok_or_else(|| {
123 DataFusionError::Execution(format!(
124 "PluginAggregateUdaf `{}`: non-Arrow return type",
125 self.name
126 ))
127 })
128 }
129
130 fn accumulator(&self, _args: AccumulatorArgs<'_>) -> DFResult<Box<dyn DfAccumulator>> {
131 let agg = self.fetch()?;
132 Ok(Box::new(PluginAccumulatorAdapter {
133 inner: Mutex::new(agg.create_accumulator()),
134 }))
135 }
136
137 fn state_fields(&self, _args: StateFieldsArgs<'_>) -> DFResult<Vec<Arc<Field>>> {
138 Ok(self
139 .sig
140 .state_fields
141 .iter()
142 .map(|f| Arc::new(f.clone()))
143 .collect())
144 }
145}
146
/// Adapter that exposes a boxed [`PluginAccumulator`] through DataFusion's
/// accumulator trait (imported in this file as `DfAccumulator`).
struct PluginAccumulatorAdapter {
    /// The wrapped plugin accumulator.
    // NOTE(review): the `Mutex` is presumably here to satisfy a `Sync` bound
    // on DataFusion accumulators — confirm against the trait definitions.
    inner: Mutex<Box<dyn PluginAccumulator>>,
}
157
158impl std::fmt::Debug for PluginAccumulatorAdapter {
159 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
160 f.debug_struct("PluginAccumulatorAdapter")
161 .finish_non_exhaustive()
162 }
163}
164
165impl PluginAccumulatorAdapter {
166 fn with_inner<F, R>(&self, f: F) -> DFResult<R>
167 where
168 F: FnOnce(&mut dyn PluginAccumulator) -> Result<R, uni_plugin::FnError>,
169 {
170 let mut guard = self
171 .inner
172 .lock()
173 .map_err(|e| DataFusionError::Execution(format!("plugin accumulator lock: {e}")))?;
174 f(guard.as_mut()).map_err(fn_err_to_df)
175 }
176}
177
178impl DfAccumulator for PluginAccumulatorAdapter {
179 fn update_batch(&mut self, values: &[ArrayRef]) -> DFResult<()> {
180 self.with_inner(|acc| acc.update_batch(values))
181 }
182
183 fn merge_batch(&mut self, states: &[ArrayRef]) -> DFResult<()> {
184 self.with_inner(|acc| acc.merge_batch(states))
185 }
186
187 fn evaluate(&mut self) -> DFResult<ScalarValue> {
188 self.with_inner(|acc| acc.evaluate())
189 }
190
191 fn state(&mut self) -> DFResult<Vec<ScalarValue>> {
192 self.with_inner(|acc| acc.state())
193 }
194
195 fn size(&self) -> usize {
196 self.inner
197 .lock()
198 .map(|g| g.size())
199 .unwrap_or(std::mem::size_of::<Self>())
200 }
201}
202
203fn fn_err_to_df(e: uni_plugin::FnError) -> DataFusionError {
204 DataFusionError::Execution(format!("plugin aggregate: {e}"))
205}
206
207fn arg_type_to_arrow(a: &ArgType) -> Option<DataType> {
210 match a {
211 ArgType::Primitive(dt) => Some(dt.clone()),
212 ArgType::CypherValue => Some(DataType::LargeBinary),
214 ArgType::Vector { len, element } => Some(DataType::FixedSizeList(
215 Arc::new(Field::new("item", element.clone(), true)),
216 i32::try_from(*len).ok()?,
217 )),
218 ArgType::Variadic(_) => None,
219 }
220}