vortex_array/aggregate_fn/
accumulator.rs1use vortex_error::VortexResult;
5use vortex_error::vortex_ensure;
6use vortex_error::vortex_err;
7
8use crate::ArrayRef;
9use crate::Columnar;
10use crate::ExecutionCtx;
11use crate::aggregate_fn::AggregateFn;
12use crate::aggregate_fn::AggregateFnRef;
13use crate::aggregate_fn::AggregateFnVTable;
14use crate::aggregate_fn::session::AggregateFnSessionExt;
15use crate::columnar::AnyColumnar;
16use crate::dtype::DType;
17use crate::executor::max_iterations;
18use crate::expr::stats::Precision;
19use crate::expr::stats::Stat;
20use crate::expr::stats::StatsProvider;
21use crate::scalar::Scalar;
22
23pub type AccumulatorRef = Box<dyn DynAccumulator>;
25
26pub struct Accumulator<V: AggregateFnVTable> {
28 vtable: V,
30 aggregate_fn: AggregateFnRef,
32 dtype: DType,
34 return_dtype: DType,
36 partial_dtype: DType,
38 partial: V::Partial,
40}
41
42impl<V: AggregateFnVTable> Accumulator<V> {
43 pub fn try_new(vtable: V, options: V::Options, dtype: DType) -> VortexResult<Self> {
44 let return_dtype = vtable.return_dtype(&options, &dtype).ok_or_else(|| {
45 vortex_err!(
46 "Aggregate function {} cannot be applied to dtype {}",
47 vtable.id(),
48 dtype
49 )
50 })?;
51 let partial_dtype = vtable.partial_dtype(&options, &dtype).ok_or_else(|| {
52 vortex_err!(
53 "Aggregate function {} cannot be applied to dtype {}",
54 vtable.id(),
55 dtype
56 )
57 })?;
58 let partial = vtable.empty_partial(&options, &dtype)?;
59 let aggregate_fn = AggregateFn::new(vtable.clone(), options).erased();
60
61 Ok(Self {
62 vtable,
63 aggregate_fn,
64 dtype,
65 return_dtype,
66 partial_dtype,
67 partial,
68 })
69 }
70}
71
72pub trait DynAccumulator: 'static + Send {
75 fn accumulate(&mut self, batch: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<()>;
77
78 fn combine_partials(&mut self, other: Scalar) -> VortexResult<()>;
83
84 fn is_saturated(&self) -> bool;
86
87 fn reset(&mut self);
89
90 fn partial_scalar(&self) -> VortexResult<Scalar>;
94
95 fn final_scalar(&self) -> VortexResult<Scalar>;
97
98 fn flush(&mut self) -> VortexResult<Scalar>;
102
103 fn finish(&mut self) -> VortexResult<Scalar>;
107}
108
109impl<V: AggregateFnVTable> DynAccumulator for Accumulator<V> {
110 fn accumulate(&mut self, batch: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<()> {
111 if self.is_saturated() {
112 return Ok(());
113 }
114
115 vortex_ensure!(
116 batch.dtype() == &self.dtype,
117 "Input DType mismatch: expected {}, got {}",
118 self.dtype,
119 batch.dtype()
120 );
121
122 if let Some(stat) = Stat::from_aggregate_fn(&self.aggregate_fn)
125 && let Precision::Exact(partial) = batch.statistics().get(stat)
126 {
127 let partial = if partial.dtype() == &self.partial_dtype {
128 partial
129 } else {
130 vortex_ensure!(
131 partial.dtype().eq_ignore_nullability(&self.partial_dtype),
132 "Aggregate {} read legacy stat {} with dtype {}, expected {}",
133 self.aggregate_fn,
134 stat,
135 partial.dtype(),
136 self.partial_dtype,
137 );
138 partial.cast(&self.partial_dtype)?
139 };
140 self.vtable.combine_partials(&mut self.partial, partial)?;
141 return Ok(());
142 }
143
144 let session = ctx.session().clone();
145
146 {
152 let kernel = session
153 .aggregate_fns()
154 .find_aggregate_kernel(batch.encoding_id(), self.aggregate_fn.id());
155 if let Some(kernel) = kernel
156 && let Some(result) = kernel.aggregate(&self.aggregate_fn, batch, ctx)?
157 {
158 vortex_ensure!(
159 result.dtype() == &self.partial_dtype,
160 "Aggregate kernel returned {}, expected {}",
161 result.dtype(),
162 self.partial_dtype,
163 );
164 self.vtable.combine_partials(&mut self.partial, result)?;
165 return Ok(());
166 }
167 }
168
169 if self.vtable.try_accumulate(&mut self.partial, batch, ctx)? {
171 return Ok(());
172 }
173
174 let mut batch = batch.clone();
180 for _ in 0..max_iterations() {
181 if batch.is::<AnyColumnar>() {
182 break;
183 }
184
185 if let Some(kernel) = session
186 .aggregate_fns()
187 .find_aggregate_kernel(batch.encoding_id(), self.aggregate_fn.id())
188 && let Some(result) = kernel.aggregate(&self.aggregate_fn, &batch, ctx)?
189 {
190 vortex_ensure!(
191 result.dtype() == &self.partial_dtype,
192 "Aggregate kernel returned {}, expected {}",
193 result.dtype(),
194 self.partial_dtype,
195 );
196 self.vtable.combine_partials(&mut self.partial, result)?;
197 return Ok(());
198 }
199
200 batch = batch.execute(ctx)?;
201 }
202
203 let columnar = batch.execute::<Columnar>(ctx)?;
205
206 self.vtable.accumulate(&mut self.partial, &columnar, ctx)
207 }
208
209 fn combine_partials(&mut self, other: Scalar) -> VortexResult<()> {
210 self.vtable.combine_partials(&mut self.partial, other)
211 }
212
213 fn is_saturated(&self) -> bool {
214 self.vtable.is_saturated(&self.partial)
215 }
216
217 fn reset(&mut self) {
218 self.vtable.reset(&mut self.partial);
219 }
220
221 fn partial_scalar(&self) -> VortexResult<Scalar> {
222 let partial = self.vtable.to_scalar(&self.partial)?;
223
224 #[cfg(debug_assertions)]
225 {
226 vortex_ensure!(
227 partial.dtype() == &self.partial_dtype,
228 "Aggregate returned incorrect DType on partial_scalar: expected {}, got {}",
229 self.partial_dtype,
230 partial.dtype(),
231 );
232 }
233
234 Ok(partial)
235 }
236
237 fn final_scalar(&self) -> VortexResult<Scalar> {
238 let result = self.vtable.finalize_scalar(&self.partial)?;
239
240 vortex_ensure!(
241 result.dtype() == &self.return_dtype,
242 "Aggregate returned incorrect DType on final_scalar: expected {}, got {}",
243 self.return_dtype,
244 result.dtype(),
245 );
246
247 Ok(result)
248 }
249
250 fn flush(&mut self) -> VortexResult<Scalar> {
251 let partial = self.partial_scalar()?;
252 self.reset();
253 Ok(partial)
254 }
255
256 fn finish(&mut self) -> VortexResult<Scalar> {
257 let result = self.final_scalar()?;
258 self.reset();
259 Ok(result)
260 }
261}
262
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;
    use vortex_session::SessionExt;
    use vortex_session::VortexSession;

    use crate::ArrayRef;
    use crate::ExecutionCtx;
    use crate::IntoArray;
    use crate::VortexSessionExecute;
    use crate::aggregate_fn::Accumulator;
    use crate::aggregate_fn::AggregateFnRef;
    use crate::aggregate_fn::AggregateFnVTable;
    use crate::aggregate_fn::DynAccumulator;
    use crate::aggregate_fn::NumericalAggregateOpts;
    use crate::aggregate_fn::combined::Combined;
    use crate::aggregate_fn::combined::PairOptions;
    use crate::aggregate_fn::fns::mean::Mean;
    use crate::aggregate_fn::fns::sum::Sum;
    use crate::aggregate_fn::kernels::DynAggregateKernel;
    use crate::aggregate_fn::session::AggregateFnSession;
    use crate::array::VTable;
    use crate::arrays::Dict;
    use crate::arrays::DictArray;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::scalar::Scalar;

    /// Kernel that ignores its input and always answers with the sentinel mean partial
    /// (`sum = 42.0`, `count = 1`).
    #[derive(Debug)]
    struct SentinelMeanPartialKernel;
    impl DynAggregateKernel for SentinelMeanPartialKernel {
        fn aggregate(
            &self,
            _: &AggregateFnRef,
            _: &ArrayRef,
            _: &mut ExecutionCtx,
        ) -> VortexResult<Option<Scalar>> {
            Ok(Some(sentinel_partial()))
        }
    }

    /// Kernel that always declines by returning `None`.
    #[derive(Debug)]
    struct DeclineKernel;
    impl DynAggregateKernel for DeclineKernel {
        fn aggregate(
            &self,
            _: &AggregateFnRef,
            _: &ArrayRef,
            _: &mut ExecutionCtx,
        ) -> VortexResult<Option<Scalar>> {
            Ok(None)
        }
    }

    /// Kernel that ignores its input and always answers with a nullable `42.0f64` scalar.
    #[derive(Debug)]
    struct SentinelSumPartialKernel;
    impl DynAggregateKernel for SentinelSumPartialKernel {
        fn aggregate(
            &self,
            _: &AggregateFnRef,
            _: &ArrayRef,
            _: &mut ExecutionCtx,
        ) -> VortexResult<Option<Scalar>> {
            Ok(Some(Scalar::primitive(42.0f64, Nullability::Nullable)))
        }
    }

    /// A new session for each test, so kernel registrations do not leak between tests.
    fn fresh_session() -> VortexSession {
        crate::array_session()
    }

    /// A one-element dictionary array whose single code points at the value `7.0`.
    fn dict_of_seven() -> ArrayRef {
        let codes = buffer![0u32].into_array();
        let values = buffer![7.0f64].into_array();
        DictArray::try_new(codes, values)
            .expect("valid dictionary")
            .into_array()
    }

    /// An empty mean accumulator over non-nullable `f64` input with default options.
    fn mean_f64_accumulator() -> VortexResult<Accumulator<Combined<Mean>>> {
        let options = PairOptions(
            NumericalAggregateOpts::default(),
            NumericalAggregateOpts::default(),
        );
        Accumulator::try_new(
            Mean::combined(),
            options,
            DType::Primitive(PType::F64, Nullability::NonNullable),
        )
    }

    /// The mean partial `{ sum: 42.0, count: 1 }`, typed with the accumulator's partial dtype.
    fn sentinel_partial() -> Scalar {
        let partial_dtype = mean_f64_accumulator()
            .expect("build accumulator")
            .partial_dtype;
        let fields = vec![
            Scalar::primitive(42.0f64, Nullability::Nullable),
            Scalar::primitive(1u64, Nullability::NonNullable),
        ];
        Scalar::struct_(partial_dtype, fields)
    }

    /// Accumulates the one-element dictionary into a new mean accumulator and flushes it.
    fn mean_partial_over_dict(ctx: &mut ExecutionCtx) -> VortexResult<Scalar> {
        let mut accumulator = mean_f64_accumulator()?;
        accumulator.accumulate(&dict_of_seven(), ctx)?;
        accumulator.flush()
    }

    /// Asserts the `sum` and `count` fields of a mean partial.
    #[track_caller]
    fn assert_mean_partial(partial: &Scalar, expected_sum: f64, expected_count: u64) {
        let fields = partial.as_struct();
        assert_eq!(
            fields.field("sum").unwrap().as_primitive().as_::<f64>(),
            Some(expected_sum)
        );
        assert_eq!(
            fields.field("count").unwrap().as_primitive().as_::<u64>(),
            Some(expected_count)
        );
    }

    /// A kernel registered for (Dict, combined mean) supplies the partial directly.
    #[test]
    fn combined_kernel_fires() -> VortexResult<()> {
        static KERNEL: SentinelMeanPartialKernel = SentinelMeanPartialKernel;
        let session = fresh_session();
        session
            .get::<AggregateFnSession>()
            .register_aggregate_kernel(Dict.id(), Some(Mean::combined().id()), &KERNEL);
        let mut ctx = session.create_execution_ctx();

        let partial = mean_partial_over_dict(&mut ctx)?;

        assert_mean_partial(&partial, 42.0, 1);
        Ok(())
    }

    /// When the registered kernel declines, the real data (`7.0`) is aggregated instead.
    #[test]
    fn fallback_when_kernel_declines() -> VortexResult<()> {
        static KERNEL: DeclineKernel = DeclineKernel;
        let session = fresh_session();
        session
            .get::<AggregateFnSession>()
            .register_aggregate_kernel(Dict.id(), Some(Mean::combined().id()), &KERNEL);
        let mut ctx = session.create_execution_ctx();

        let partial = mean_partial_over_dict(&mut ctx)?;

        assert_mean_partial(&partial, 7.0, 1);
        Ok(())
    }

    /// A kernel registered only for (Dict, sum) still determines the `sum` field of the
    /// combined mean partial.
    #[test]
    fn child_kernel_fires_through_combined() -> VortexResult<()> {
        static KERNEL: SentinelSumPartialKernel = SentinelSumPartialKernel;
        let session = fresh_session();
        session
            .get::<AggregateFnSession>()
            .register_aggregate_kernel(Dict.id(), Some(Sum.id()), &KERNEL);
        let mut ctx = session.create_execution_ctx();

        let partial = mean_partial_over_dict(&mut ctx)?;

        assert_mean_partial(&partial, 42.0, 1);
        Ok(())
    }
}