vortex_array/aggregate_fn/
accumulator.rs1use vortex_error::VortexResult;
5use vortex_error::vortex_ensure;
6use vortex_error::vortex_err;
7
8use crate::ArrayRef;
9use crate::Columnar;
10use crate::ExecutionCtx;
11use crate::aggregate_fn::AggregateFn;
12use crate::aggregate_fn::AggregateFnRef;
13use crate::aggregate_fn::AggregateFnVTable;
14use crate::aggregate_fn::session::AggregateFnSessionExt;
15use crate::columnar::AnyColumnar;
16use crate::dtype::DType;
17use crate::executor::max_iterations;
18use crate::expr::stats::Precision;
19use crate::expr::stats::Stat;
20use crate::expr::stats::StatsProvider;
21use crate::scalar::Scalar;
22
23pub type AccumulatorRef = Box<dyn DynAccumulator>;
25
26pub struct Accumulator<V: AggregateFnVTable> {
28 vtable: V,
30 aggregate_fn: AggregateFnRef,
32 dtype: DType,
34 return_dtype: DType,
36 partial_dtype: DType,
38 partial: V::Partial,
40}
41
42impl<V: AggregateFnVTable> Accumulator<V> {
43 pub fn try_new(vtable: V, options: V::Options, dtype: DType) -> VortexResult<Self> {
44 let return_dtype = vtable.return_dtype(&options, &dtype).ok_or_else(|| {
45 vortex_err!(
46 "Aggregate function {} cannot be applied to dtype {}",
47 vtable.id(),
48 dtype
49 )
50 })?;
51 let partial_dtype = vtable.partial_dtype(&options, &dtype).ok_or_else(|| {
52 vortex_err!(
53 "Aggregate function {} cannot be applied to dtype {}",
54 vtable.id(),
55 dtype
56 )
57 })?;
58 let partial = vtable.empty_partial(&options, &dtype)?;
59 let aggregate_fn = AggregateFn::new(vtable.clone(), options).erased();
60
61 Ok(Self {
62 vtable,
63 aggregate_fn,
64 dtype,
65 return_dtype,
66 partial_dtype,
67 partial,
68 })
69 }
70}
71
72pub trait DynAccumulator: 'static + Send {
75 fn accumulate(&mut self, batch: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<()>;
77
78 fn combine_partials(&mut self, other: Scalar) -> VortexResult<()>;
83
84 fn is_saturated(&self) -> bool;
86
87 fn reset(&mut self);
89
90 fn partial_scalar(&self) -> VortexResult<Scalar>;
94
95 fn final_scalar(&self) -> VortexResult<Scalar>;
97
98 fn flush(&mut self) -> VortexResult<Scalar>;
102
103 fn finish(&mut self) -> VortexResult<Scalar>;
107}
108
109impl<V: AggregateFnVTable> DynAccumulator for Accumulator<V> {
110 fn accumulate(&mut self, batch: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<()> {
111 if self.is_saturated() {
112 return Ok(());
113 }
114
115 vortex_ensure!(
116 batch.dtype() == &self.dtype,
117 "Input DType mismatch: expected {}, got {}",
118 self.dtype,
119 batch.dtype()
120 );
121
122 if let Some(stat) = Stat::from_aggregate_fn(&self.aggregate_fn)
125 && let Some(Precision::Exact(partial)) = batch.statistics().get(stat)
126 {
127 vortex_ensure!(
128 partial.dtype() == &self.partial_dtype,
129 "Aggregate {} read legacy stat {} with dtype {}, expected {}",
130 self.aggregate_fn,
131 stat,
132 partial.dtype(),
133 self.partial_dtype,
134 );
135 self.vtable.combine_partials(&mut self.partial, partial)?;
136 return Ok(());
137 }
138
139 let session = ctx.session().clone();
140 let kernels = &session.aggregate_fns().kernels;
141
142 {
148 let kernels_r = kernels.read();
149 let batch_id = batch.encoding_id();
150 let kernel = kernels_r
151 .get(&(batch_id, Some(self.aggregate_fn.id())))
152 .or_else(|| kernels_r.get(&(batch_id, None)))
153 .copied();
154 drop(kernels_r);
155 if let Some(kernel) = kernel
156 && let Some(result) = kernel.aggregate(&self.aggregate_fn, batch, ctx)?
157 {
158 vortex_ensure!(
159 result.dtype() == &self.partial_dtype,
160 "Aggregate kernel returned {}, expected {}",
161 result.dtype(),
162 self.partial_dtype,
163 );
164 self.vtable.combine_partials(&mut self.partial, result)?;
165 return Ok(());
166 }
167 }
168
169 if self.vtable.try_accumulate(&mut self.partial, batch, ctx)? {
171 return Ok(());
172 }
173
174 let mut batch = batch.clone();
180 for _ in 0..max_iterations() {
181 if batch.is::<AnyColumnar>() {
182 break;
183 }
184
185 let kernels_r = kernels.read();
186 let batch_id = batch.encoding_id();
187 let kernel = kernels_r
188 .get(&(batch_id, Some(self.aggregate_fn.id())))
189 .or_else(|| kernels_r.get(&(batch_id, None)))
190 .copied();
191 drop(kernels_r);
192 if let Some(kernel) = kernel
193 && let Some(result) = kernel.aggregate(&self.aggregate_fn, &batch, ctx)?
194 {
195 vortex_ensure!(
196 result.dtype() == &self.partial_dtype,
197 "Aggregate kernel returned {}, expected {}",
198 result.dtype(),
199 self.partial_dtype,
200 );
201 self.vtable.combine_partials(&mut self.partial, result)?;
202 return Ok(());
203 }
204
205 batch = batch.execute(ctx)?;
206 }
207
208 let columnar = batch.execute::<Columnar>(ctx)?;
210
211 self.vtable.accumulate(&mut self.partial, &columnar, ctx)
212 }
213
214 fn combine_partials(&mut self, other: Scalar) -> VortexResult<()> {
215 self.vtable.combine_partials(&mut self.partial, other)
216 }
217
218 fn is_saturated(&self) -> bool {
219 self.vtable.is_saturated(&self.partial)
220 }
221
222 fn reset(&mut self) {
223 self.vtable.reset(&mut self.partial);
224 }
225
226 fn partial_scalar(&self) -> VortexResult<Scalar> {
227 let partial = self.vtable.to_scalar(&self.partial)?;
228
229 #[cfg(debug_assertions)]
230 {
231 vortex_ensure!(
232 partial.dtype() == &self.partial_dtype,
233 "Aggregate returned incorrect DType on partial_scalar: expected {}, got {}",
234 self.partial_dtype,
235 partial.dtype(),
236 );
237 }
238
239 Ok(partial)
240 }
241
242 fn final_scalar(&self) -> VortexResult<Scalar> {
243 let result = self.vtable.finalize_scalar(&self.partial)?;
244
245 vortex_ensure!(
246 result.dtype() == &self.return_dtype,
247 "Aggregate returned incorrect DType on final_scalar: expected {}, got {}",
248 self.return_dtype,
249 result.dtype(),
250 );
251
252 Ok(result)
253 }
254
255 fn flush(&mut self) -> VortexResult<Scalar> {
256 let partial = self.partial_scalar()?;
257 self.reset();
258 Ok(partial)
259 }
260
261 fn finish(&mut self) -> VortexResult<Scalar> {
262 let result = self.final_scalar()?;
263 self.reset();
264 Ok(result)
265 }
266}
267
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;
    use vortex_session::SessionExt;
    use vortex_session::VortexSession;

    use crate::ArrayRef;
    use crate::ExecutionCtx;
    use crate::IntoArray;
    use crate::VortexSessionExecute;
    use crate::aggregate_fn::Accumulator;
    use crate::aggregate_fn::AggregateFnRef;
    use crate::aggregate_fn::AggregateFnVTable;
    use crate::aggregate_fn::DynAccumulator;
    use crate::aggregate_fn::EmptyOptions;
    use crate::aggregate_fn::combined::Combined;
    use crate::aggregate_fn::combined::PairOptions;
    use crate::aggregate_fn::fns::mean::Mean;
    use crate::aggregate_fn::fns::sum::Sum;
    use crate::aggregate_fn::kernels::DynAggregateKernel;
    use crate::aggregate_fn::session::AggregateFnSession;
    use crate::array::VTable;
    use crate::arrays::Dict;
    use crate::arrays::DictArray;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::scalar::Scalar;
    use crate::session::ArraySession;

    /// Kernel that ignores its input and always answers with [`sentinel_partial`]
    /// (sum = 42.0, count = 1).
    #[derive(Debug)]
    struct SentinelMeanPartialKernel;
    impl DynAggregateKernel for SentinelMeanPartialKernel {
        fn aggregate(
            &self,
            _aggregate_fn: &AggregateFnRef,
            _batch: &ArrayRef,
            _ctx: &mut ExecutionCtx,
        ) -> VortexResult<Option<Scalar>> {
            Ok(Some(sentinel_partial()))
        }
    }

    /// Kernel that always declines by returning `Ok(None)`.
    #[derive(Debug)]
    struct DeclineKernel;
    impl DynAggregateKernel for DeclineKernel {
        fn aggregate(
            &self,
            _aggregate_fn: &AggregateFnRef,
            _batch: &ArrayRef,
            _ctx: &mut ExecutionCtx,
        ) -> VortexResult<Option<Scalar>> {
            Ok(None)
        }
    }

    /// Kernel that ignores its input and always answers with a nullable `f64` scalar
    /// holding 42.0.
    #[derive(Debug)]
    struct SentinelSumPartialKernel;
    impl DynAggregateKernel for SentinelSumPartialKernel {
        fn aggregate(
            &self,
            _aggregate_fn: &AggregateFnRef,
            _batch: &ArrayRef,
            _ctx: &mut ExecutionCtx,
        ) -> VortexResult<Option<Scalar>> {
            Ok(Some(Scalar::primitive(42.0f64, Nullability::Nullable)))
        }
    }

    /// An empty session with only the array session attached, so each test starts with
    /// its own kernel registrations.
    fn fresh_session() -> VortexSession {
        VortexSession::empty().with::<ArraySession>()
    }

    /// A one-element dictionary array: codes `[0]` over values `[7.0]`.
    fn dict_of_seven() -> ArrayRef {
        DictArray::try_new(buffer![0u32].into_array(), buffer![7.0f64].into_array())
            .expect("valid dictionary")
            .into_array()
    }

    /// A combined-mean accumulator over non-nullable `f64` input.
    fn mean_f64_accumulator() -> VortexResult<Accumulator<Combined<Mean>>> {
        Accumulator::try_new(
            Mean::combined(),
            PairOptions(EmptyOptions, EmptyOptions),
            DType::Primitive(PType::F64, Nullability::NonNullable),
        )
    }

    /// A mean partial struct with sum = 42.0 and count = 1, typed with the partial dtype
    /// of a freshly built mean accumulator.
    fn sentinel_partial() -> Scalar {
        let acc = mean_f64_accumulator().expect("build accumulator");
        let fields = vec![
            Scalar::primitive(42.0f64, Nullability::Nullable),
            Scalar::primitive(1u64, Nullability::NonNullable),
        ];
        Scalar::struct_(acc.partial_dtype, fields)
    }

    /// Accumulate [`dict_of_seven`] into a fresh mean accumulator under `session` and
    /// return the flushed partial.
    fn flush_mean_of_dict(session: VortexSession) -> VortexResult<Scalar> {
        let mut ctx = session.create_execution_ctx();
        let mut acc = mean_f64_accumulator()?;
        acc.accumulate(&dict_of_seven(), &mut ctx)?;
        acc.flush()
    }

    /// Assert the `sum` and `count` fields of a mean partial.
    #[track_caller]
    fn assert_mean_partial(partial: &Scalar, expected_sum: f64, expected_count: u64) {
        let fields = partial.as_struct();
        assert_eq!(
            fields.field("sum").unwrap().as_primitive().as_::<f64>(),
            Some(expected_sum)
        );
        assert_eq!(
            fields.field("count").unwrap().as_primitive().as_::<u64>(),
            Some(expected_count)
        );
    }

    /// A kernel registered for (Dict, combined mean) supplies the partial: the sentinel
    /// sum of 42.0 is observed, not the 7.0 stored in the array.
    #[test]
    fn combined_kernel_fires() -> VortexResult<()> {
        static KERNEL: SentinelMeanPartialKernel = SentinelMeanPartialKernel;
        let session = fresh_session();
        session
            .get::<AggregateFnSession>()
            .register_aggregate_kernel(Dict.id(), Some(Mean::combined().id()), &KERNEL);

        let partial = flush_mean_of_dict(session)?;
        assert_mean_partial(&partial, 42.0, 1);
        Ok(())
    }

    /// When the registered kernel declines, accumulation still succeeds and reflects the
    /// array's real contents (sum 7.0 over one element).
    #[test]
    fn fallback_when_kernel_declines() -> VortexResult<()> {
        static KERNEL: DeclineKernel = DeclineKernel;
        let session = fresh_session();
        session
            .get::<AggregateFnSession>()
            .register_aggregate_kernel(Dict.id(), Some(Mean::combined().id()), &KERNEL);

        let partial = flush_mean_of_dict(session)?;
        assert_mean_partial(&partial, 7.0, 1);
        Ok(())
    }

    /// A kernel registered only for (Dict, Sum) still influences the combined mean: the
    /// partial's sum is the sentinel 42.0 while the count is 1.
    #[test]
    fn child_kernel_fires_through_combined() -> VortexResult<()> {
        static KERNEL: SentinelSumPartialKernel = SentinelSumPartialKernel;
        let session = fresh_session();
        session
            .get::<AggregateFnSession>()
            .register_aggregate_kernel(Dict.id(), Some(Sum.id()), &KERNEL);

        let partial = flush_mean_of_dict(session)?;
        assert_mean_partial(&partial, 42.0, 1);
        Ok(())
    }
}