vortex_array/aggregate_fn/
accumulator.rs1use vortex_error::VortexResult;
5use vortex_error::vortex_ensure;
6use vortex_error::vortex_err;
7
8use crate::ArrayRef;
9use crate::Columnar;
10use crate::ExecutionCtx;
11use crate::aggregate_fn::AggregateFn;
12use crate::aggregate_fn::AggregateFnRef;
13use crate::aggregate_fn::AggregateFnVTable;
14use crate::aggregate_fn::session::AggregateFnSessionExt;
15use crate::columnar::AnyColumnar;
16use crate::dtype::DType;
17use crate::executor::max_iterations;
18use crate::expr::stats::Precision;
19use crate::expr::stats::Stat;
20use crate::expr::stats::StatsProvider;
21use crate::scalar::Scalar;
22
/// An owned, type-erased accumulator: a boxed [`DynAccumulator`] trait object.
pub type AccumulatorRef = Box<dyn DynAccumulator>;
25
/// Running state for a single aggregate function `V` applied to inputs of one fixed dtype.
///
/// Built with [`Accumulator::try_new`] and driven through the [`DynAccumulator`] trait.
pub struct Accumulator<V: AggregateFnVTable> {
    /// The vtable that implements the aggregate's logic.
    vtable: V,
    /// Type-erased handle built from a clone of `vtable` plus the options; used to look up
    /// statistics and kernels for this aggregate.
    aggregate_fn: AggregateFnRef,
    /// The dtype every accumulated batch must have.
    dtype: DType,
    /// The dtype the finalized result is checked against.
    return_dtype: DType,
    /// The dtype that partial results (from kernels, stats and `partial_scalar`) are checked
    /// against.
    partial_dtype: DType,
    /// The current partial aggregation state.
    partial: V::Partial,
}
41
42impl<V: AggregateFnVTable> Accumulator<V> {
43 pub fn try_new(vtable: V, options: V::Options, dtype: DType) -> VortexResult<Self> {
44 let return_dtype = vtable.return_dtype(&options, &dtype).ok_or_else(|| {
45 vortex_err!(
46 "Aggregate function {} cannot be applied to dtype {}",
47 vtable.id(),
48 dtype
49 )
50 })?;
51 let partial_dtype = vtable.partial_dtype(&options, &dtype).ok_or_else(|| {
52 vortex_err!(
53 "Aggregate function {} cannot be applied to dtype {}",
54 vtable.id(),
55 dtype
56 )
57 })?;
58 let partial = vtable.empty_partial(&options, &dtype)?;
59 let aggregate_fn = AggregateFn::new(vtable.clone(), options).erased();
60
61 Ok(Self {
62 vtable,
63 aggregate_fn,
64 dtype,
65 return_dtype,
66 partial_dtype,
67 partial,
68 })
69 }
70}
71
/// Type-erased accumulator interface, usable as `Box<dyn DynAccumulator>` (see
/// [`AccumulatorRef`]).
///
/// The method notes below describe what the [`Accumulator`] implementation in this file does.
pub trait DynAccumulator: 'static + Send {
    /// Folds one batch of input into the accumulator's partial state.
    fn accumulate(&mut self, batch: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<()>;

    /// Merges an externally computed partial result into the accumulator's partial state.
    fn combine_partials(&mut self, other: Scalar) -> VortexResult<()>;

    /// Reports whether the accumulator is saturated; `Accumulator::accumulate` returns
    /// immediately without reading the batch when this is `true`.
    fn is_saturated(&self) -> bool;

    /// Resets the partial state.
    fn reset(&mut self);

    /// Returns the current partial state as a scalar without modifying the accumulator.
    fn partial_scalar(&self) -> VortexResult<Scalar>;

    /// Returns the finalized result as a scalar without modifying the accumulator.
    fn final_scalar(&self) -> VortexResult<Scalar>;

    /// Returns the current partial state as a scalar, then resets the accumulator.
    fn flush(&mut self) -> VortexResult<Scalar>;

    /// Returns the finalized result as a scalar, then resets the accumulator.
    fn finish(&mut self) -> VortexResult<Scalar>;
}
108
109impl<V: AggregateFnVTable> DynAccumulator for Accumulator<V> {
110 fn accumulate(&mut self, batch: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<()> {
111 if self.is_saturated() {
112 return Ok(());
113 }
114
115 vortex_ensure!(
116 batch.dtype() == &self.dtype,
117 "Input DType mismatch: expected {}, got {}",
118 self.dtype,
119 batch.dtype()
120 );
121
122 if let Some(stat) = Stat::from_aggregate_fn(&self.aggregate_fn)
125 && let Precision::Exact(partial) = batch.statistics().get(stat)
126 {
127 let partial = if partial.dtype() == &self.partial_dtype {
128 partial
129 } else {
130 vortex_ensure!(
131 partial.dtype().eq_ignore_nullability(&self.partial_dtype),
132 "Aggregate {} read legacy stat {} with dtype {}, expected {}",
133 self.aggregate_fn,
134 stat,
135 partial.dtype(),
136 self.partial_dtype,
137 );
138 partial.cast(&self.partial_dtype)?
139 };
140 self.vtable.combine_partials(&mut self.partial, partial)?;
141 return Ok(());
142 }
143
144 let session = ctx.session().clone();
145 let kernels = &session.aggregate_fns().kernels;
146
147 {
153 let kernels_r = kernels.read();
154 let batch_id = batch.encoding_id();
155 let kernel = kernels_r
156 .get(&(batch_id, Some(self.aggregate_fn.id())))
157 .or_else(|| kernels_r.get(&(batch_id, None)))
158 .copied();
159 drop(kernels_r);
160 if let Some(kernel) = kernel
161 && let Some(result) = kernel.aggregate(&self.aggregate_fn, batch, ctx)?
162 {
163 vortex_ensure!(
164 result.dtype() == &self.partial_dtype,
165 "Aggregate kernel returned {}, expected {}",
166 result.dtype(),
167 self.partial_dtype,
168 );
169 self.vtable.combine_partials(&mut self.partial, result)?;
170 return Ok(());
171 }
172 }
173
174 if self.vtable.try_accumulate(&mut self.partial, batch, ctx)? {
176 return Ok(());
177 }
178
179 let mut batch = batch.clone();
185 for _ in 0..max_iterations() {
186 if batch.is::<AnyColumnar>() {
187 break;
188 }
189
190 let kernels_r = kernels.read();
191 let batch_id = batch.encoding_id();
192 let kernel = kernels_r
193 .get(&(batch_id, Some(self.aggregate_fn.id())))
194 .or_else(|| kernels_r.get(&(batch_id, None)))
195 .copied();
196 drop(kernels_r);
197 if let Some(kernel) = kernel
198 && let Some(result) = kernel.aggregate(&self.aggregate_fn, &batch, ctx)?
199 {
200 vortex_ensure!(
201 result.dtype() == &self.partial_dtype,
202 "Aggregate kernel returned {}, expected {}",
203 result.dtype(),
204 self.partial_dtype,
205 );
206 self.vtable.combine_partials(&mut self.partial, result)?;
207 return Ok(());
208 }
209
210 batch = batch.execute(ctx)?;
211 }
212
213 let columnar = batch.execute::<Columnar>(ctx)?;
215
216 self.vtable.accumulate(&mut self.partial, &columnar, ctx)
217 }
218
219 fn combine_partials(&mut self, other: Scalar) -> VortexResult<()> {
220 self.vtable.combine_partials(&mut self.partial, other)
221 }
222
223 fn is_saturated(&self) -> bool {
224 self.vtable.is_saturated(&self.partial)
225 }
226
227 fn reset(&mut self) {
228 self.vtable.reset(&mut self.partial);
229 }
230
231 fn partial_scalar(&self) -> VortexResult<Scalar> {
232 let partial = self.vtable.to_scalar(&self.partial)?;
233
234 #[cfg(debug_assertions)]
235 {
236 vortex_ensure!(
237 partial.dtype() == &self.partial_dtype,
238 "Aggregate returned incorrect DType on partial_scalar: expected {}, got {}",
239 self.partial_dtype,
240 partial.dtype(),
241 );
242 }
243
244 Ok(partial)
245 }
246
247 fn final_scalar(&self) -> VortexResult<Scalar> {
248 let result = self.vtable.finalize_scalar(&self.partial)?;
249
250 vortex_ensure!(
251 result.dtype() == &self.return_dtype,
252 "Aggregate returned incorrect DType on final_scalar: expected {}, got {}",
253 self.return_dtype,
254 result.dtype(),
255 );
256
257 Ok(result)
258 }
259
260 fn flush(&mut self) -> VortexResult<Scalar> {
261 let partial = self.partial_scalar()?;
262 self.reset();
263 Ok(partial)
264 }
265
266 fn finish(&mut self) -> VortexResult<Scalar> {
267 let result = self.final_scalar()?;
268 self.reset();
269 Ok(result)
270 }
271}
272
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;
    use vortex_session::SessionExt;
    use vortex_session::VortexSession;

    use crate::ArrayRef;
    use crate::ExecutionCtx;
    use crate::IntoArray;
    use crate::VortexSessionExecute;
    use crate::aggregate_fn::Accumulator;
    use crate::aggregate_fn::AggregateFnRef;
    use crate::aggregate_fn::AggregateFnVTable;
    use crate::aggregate_fn::DynAccumulator;
    use crate::aggregate_fn::EmptyOptions;
    use crate::aggregate_fn::combined::Combined;
    use crate::aggregate_fn::combined::PairOptions;
    use crate::aggregate_fn::fns::mean::Mean;
    use crate::aggregate_fn::fns::sum::Sum;
    use crate::aggregate_fn::kernels::DynAggregateKernel;
    use crate::aggregate_fn::session::AggregateFnSession;
    use crate::array::VTable;
    use crate::arrays::Dict;
    use crate::arrays::DictArray;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::scalar::Scalar;
    use crate::session::ArraySession;

    /// Kernel that ignores its input and always answers with the sentinel mean partial
    /// (sum = 42.0, count = 1).
    #[derive(Debug)]
    struct SentinelMeanPartialKernel;
    impl DynAggregateKernel for SentinelMeanPartialKernel {
        fn aggregate(
            &self,
            _aggregate_fn: &AggregateFnRef,
            _batch: &ArrayRef,
            _ctx: &mut ExecutionCtx,
        ) -> VortexResult<Option<Scalar>> {
            Ok(Some(sentinel_partial()))
        }
    }

    /// Kernel that always declines by returning `None`.
    #[derive(Debug)]
    struct DeclineKernel;
    impl DynAggregateKernel for DeclineKernel {
        fn aggregate(
            &self,
            _aggregate_fn: &AggregateFnRef,
            _batch: &ArrayRef,
            _ctx: &mut ExecutionCtx,
        ) -> VortexResult<Option<Scalar>> {
            Ok(None)
        }
    }

    /// Kernel that ignores its input and always answers with a nullable `f64` of 42.0.
    #[derive(Debug)]
    struct SentinelSumPartialKernel;
    impl DynAggregateKernel for SentinelSumPartialKernel {
        fn aggregate(
            &self,
            _aggregate_fn: &AggregateFnRef,
            _batch: &ArrayRef,
            _ctx: &mut ExecutionCtx,
        ) -> VortexResult<Option<Scalar>> {
            Ok(Some(Scalar::primitive(42.0f64, Nullability::Nullable)))
        }
    }

    /// An empty session extended with an `ArraySession`.
    fn fresh_session() -> VortexSession {
        VortexSession::empty().with::<ArraySession>()
    }

    /// A one-element dictionary array: codes `[0]` over values `[7.0]`.
    fn dict_of_seven() -> ArrayRef {
        let codes = buffer![0u32].into_array();
        let values = buffer![7.0f64].into_array();
        DictArray::try_new(codes, values)
            .expect("valid dictionary")
            .into_array()
    }

    /// A combined-mean accumulator over non-nullable `f64` input.
    fn mean_f64_accumulator() -> VortexResult<Accumulator<Combined<Mean>>> {
        Accumulator::try_new(
            Mean::combined(),
            PairOptions(EmptyOptions, EmptyOptions),
            DType::Primitive(PType::F64, Nullability::NonNullable),
        )
    }

    /// The struct partial (sum = 42.0, count = 1) typed with the mean accumulator's partial
    /// dtype.
    fn sentinel_partial() -> Scalar {
        let partial_dtype = mean_f64_accumulator()
            .expect("build accumulator")
            .partial_dtype;
        let fields = vec![
            Scalar::primitive(42.0f64, Nullability::Nullable),
            Scalar::primitive(1u64, Nullability::NonNullable),
        ];
        Scalar::struct_(partial_dtype, fields)
    }

    /// Accumulates the one-element dictionary into a fresh mean accumulator and flushes it.
    fn mean_partial_over_dict(ctx: &mut ExecutionCtx) -> VortexResult<Scalar> {
        let mut accumulator = mean_f64_accumulator()?;
        accumulator.accumulate(&dict_of_seven(), ctx)?;
        accumulator.flush()
    }

    /// Asserts the `sum` and `count` fields of a mean partial.
    fn assert_sum_and_count(partial: &Scalar, expected_sum: f64, expected_count: u64) {
        let fields = partial.as_struct();
        assert_eq!(
            fields.field("sum").unwrap().as_primitive().as_::<f64>(),
            Some(expected_sum)
        );
        assert_eq!(
            fields.field("count").unwrap().as_primitive().as_::<u64>(),
            Some(expected_count)
        );
    }

    /// A kernel registered for (Dict, combined mean) supplies the partial directly.
    #[test]
    fn combined_kernel_fires() -> VortexResult<()> {
        static KERNEL: SentinelMeanPartialKernel = SentinelMeanPartialKernel;
        let session = fresh_session();
        session
            .get::<AggregateFnSession>()
            .register_aggregate_kernel(Dict.id(), Some(Mean::combined().id()), &KERNEL);
        let mut ctx = session.create_execution_ctx();

        let partial = mean_partial_over_dict(&mut ctx)?;

        assert_sum_and_count(&partial, 42.0, 1);
        Ok(())
    }

    /// When the registered kernel declines, the real data (7.0) is aggregated instead.
    #[test]
    fn fallback_when_kernel_declines() -> VortexResult<()> {
        static KERNEL: DeclineKernel = DeclineKernel;
        let session = fresh_session();
        session
            .get::<AggregateFnSession>()
            .register_aggregate_kernel(Dict.id(), Some(Mean::combined().id()), &KERNEL);
        let mut ctx = session.create_execution_ctx();

        let partial = mean_partial_over_dict(&mut ctx)?;

        assert_sum_and_count(&partial, 7.0, 1);
        Ok(())
    }

    /// A kernel registered only for (Dict, Sum) still determines the mean partial's sum.
    #[test]
    fn child_kernel_fires_through_combined() -> VortexResult<()> {
        static KERNEL: SentinelSumPartialKernel = SentinelSumPartialKernel;
        let session = fresh_session();
        session
            .get::<AggregateFnSession>()
            .register_aggregate_kernel(Dict.id(), Some(Sum.id()), &KERNEL);
        let mut ctx = session.create_execution_ctx();

        let partial = mean_partial_over_dict(&mut ctx)?;

        // The sentinel 42.0 in `sum` can only come from the Sum kernel.
        assert_sum_and_count(&partial, 42.0, 1);
        Ok(())
    }
}