vortex_array/aggregate_fn/
accumulator.rs1use vortex_error::VortexResult;
5use vortex_error::vortex_ensure;
6use vortex_error::vortex_err;
7
8use crate::ArrayRef;
9use crate::Columnar;
10use crate::ExecutionCtx;
11use crate::aggregate_fn::AggregateFn;
12use crate::aggregate_fn::AggregateFnRef;
13use crate::aggregate_fn::AggregateFnVTable;
14use crate::aggregate_fn::session::AggregateFnSessionExt;
15use crate::columnar::AnyColumnar;
16use crate::dtype::DType;
17use crate::executor::max_iterations;
18use crate::expr::stats::Precision;
19use crate::expr::stats::Stat;
20use crate::expr::stats::StatsProvider;
21use crate::scalar::Scalar;
22
/// Boxed, type-erased accumulator: any [`DynAccumulator`] implementation behind a `Box`.
23pub type AccumulatorRef = Box<dyn DynAccumulator>;
25
26pub struct Accumulator<V: AggregateFnVTable> {
28 vtable: V,
30 aggregate_fn: AggregateFnRef,
32 dtype: DType,
34 return_dtype: DType,
36 partial_dtype: DType,
38 partial: V::Partial,
40}
41
42impl<V: AggregateFnVTable> Accumulator<V> {
43 pub fn try_new(vtable: V, options: V::Options, dtype: DType) -> VortexResult<Self> {
44 let return_dtype = vtable.return_dtype(&options, &dtype).ok_or_else(|| {
45 vortex_err!(
46 "Aggregate function {} cannot be applied to dtype {}",
47 vtable.id(),
48 dtype
49 )
50 })?;
51 let partial_dtype = vtable.partial_dtype(&options, &dtype).ok_or_else(|| {
52 vortex_err!(
53 "Aggregate function {} cannot be applied to dtype {}",
54 vtable.id(),
55 dtype
56 )
57 })?;
58 let partial = vtable.empty_partial(&options, &dtype)?;
59 let aggregate_fn = AggregateFn::new(vtable.clone(), options).erased();
60
61 Ok(Self {
62 vtable,
63 aggregate_fn,
64 dtype,
65 return_dtype,
66 partial_dtype,
67 partial,
68 })
69 }
70}
71
72pub trait DynAccumulator: 'static + Send {
75 fn accumulate(&mut self, batch: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<()>;
77
78 fn combine_partials(&mut self, other: Scalar) -> VortexResult<()>;
83
84 fn is_saturated(&self) -> bool;
86
87 fn reset(&mut self);
89
90 fn partial_scalar(&self) -> VortexResult<Scalar>;
94
95 fn final_scalar(&self) -> VortexResult<Scalar>;
97
98 fn flush(&mut self) -> VortexResult<Scalar>;
102
103 fn finish(&mut self) -> VortexResult<Scalar>;
107}
108
109impl<V: AggregateFnVTable> DynAccumulator for Accumulator<V> {
110 fn accumulate(&mut self, batch: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<()> {
111 if self.is_saturated() {
112 return Ok(());
113 }
114
115 vortex_ensure!(
116 batch.dtype() == &self.dtype,
117 "Input DType mismatch: expected {}, got {}",
118 self.dtype,
119 batch.dtype()
120 );
121
122 if let Some(stat) = Stat::from_aggregate_fn(&self.aggregate_fn)
125 && let Precision::Exact(partial) = batch.statistics().get(stat)
126 {
127 let partial = if partial.dtype() == &self.partial_dtype {
128 partial
129 } else {
130 vortex_ensure!(
131 partial.dtype().eq_ignore_nullability(&self.partial_dtype),
132 "Aggregate {} read legacy stat {} with dtype {}, expected {}",
133 self.aggregate_fn,
134 stat,
135 partial.dtype(),
136 self.partial_dtype,
137 );
138 partial.cast(&self.partial_dtype)?
139 };
140 self.vtable.combine_partials(&mut self.partial, partial)?;
141 return Ok(());
142 }
143
144 let session = ctx.session().clone();
145
146 {
152 let kernel = session
153 .aggregate_fns()
154 .find_aggregate_kernel(batch.encoding_id(), self.aggregate_fn.id());
155 if let Some(kernel) = kernel
156 && let Some(result) = kernel.aggregate(&self.aggregate_fn, batch, ctx)?
157 {
158 vortex_ensure!(
159 result.dtype() == &self.partial_dtype,
160 "Aggregate kernel returned {}, expected {}",
161 result.dtype(),
162 self.partial_dtype,
163 );
164 self.vtable.combine_partials(&mut self.partial, result)?;
165 return Ok(());
166 }
167 }
168
169 if self.vtable.try_accumulate(&mut self.partial, batch, ctx)? {
171 return Ok(());
172 }
173
174 let mut batch = batch.clone();
180 for _ in 0..max_iterations() {
181 if batch.is::<AnyColumnar>() {
182 break;
183 }
184
185 if let Some(kernel) = session
186 .aggregate_fns()
187 .find_aggregate_kernel(batch.encoding_id(), self.aggregate_fn.id())
188 && let Some(result) = kernel.aggregate(&self.aggregate_fn, &batch, ctx)?
189 {
190 vortex_ensure!(
191 result.dtype() == &self.partial_dtype,
192 "Aggregate kernel returned {}, expected {}",
193 result.dtype(),
194 self.partial_dtype,
195 );
196 self.vtable.combine_partials(&mut self.partial, result)?;
197 return Ok(());
198 }
199
200 batch = batch.execute(ctx)?;
201 }
202
203 let columnar = batch.execute::<Columnar>(ctx)?;
205
206 self.vtable.accumulate(&mut self.partial, &columnar, ctx)
207 }
208
209 fn combine_partials(&mut self, other: Scalar) -> VortexResult<()> {
210 self.vtable.combine_partials(&mut self.partial, other)
211 }
212
213 fn is_saturated(&self) -> bool {
214 self.vtable.is_saturated(&self.partial)
215 }
216
217 fn reset(&mut self) {
218 self.vtable.reset(&mut self.partial);
219 }
220
221 fn partial_scalar(&self) -> VortexResult<Scalar> {
222 let partial = self.vtable.to_scalar(&self.partial)?;
223
224 #[cfg(debug_assertions)]
225 {
226 vortex_ensure!(
227 partial.dtype() == &self.partial_dtype,
228 "Aggregate returned incorrect DType on partial_scalar: expected {}, got {}",
229 self.partial_dtype,
230 partial.dtype(),
231 );
232 }
233
234 Ok(partial)
235 }
236
237 fn final_scalar(&self) -> VortexResult<Scalar> {
238 let result = self.vtable.finalize_scalar(&self.partial)?;
239
240 vortex_ensure!(
241 result.dtype() == &self.return_dtype,
242 "Aggregate returned incorrect DType on final_scalar: expected {}, got {}",
243 self.return_dtype,
244 result.dtype(),
245 );
246
247 Ok(result)
248 }
249
250 fn flush(&mut self) -> VortexResult<Scalar> {
251 let partial = self.partial_scalar()?;
252 self.reset();
253 Ok(partial)
254 }
255
256 fn finish(&mut self) -> VortexResult<Scalar> {
257 let result = self.final_scalar()?;
258 self.reset();
259 Ok(result)
260 }
261}
262
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;
    use vortex_session::SessionExt;
    use vortex_session::VortexSession;

    use crate::ArrayRef;
    use crate::ExecutionCtx;
    use crate::IntoArray;
    use crate::VortexSessionExecute;
    use crate::aggregate_fn::Accumulator;
    use crate::aggregate_fn::AggregateFnRef;
    use crate::aggregate_fn::AggregateFnVTable;
    use crate::aggregate_fn::DynAccumulator;
    use crate::aggregate_fn::EmptyOptions;
    use crate::aggregate_fn::combined::Combined;
    use crate::aggregate_fn::combined::PairOptions;
    use crate::aggregate_fn::fns::mean::Mean;
    use crate::aggregate_fn::fns::sum::Sum;
    use crate::aggregate_fn::kernels::DynAggregateKernel;
    use crate::aggregate_fn::session::AggregateFnSession;
    use crate::array::VTable;
    use crate::arrays::Dict;
    use crate::arrays::DictArray;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::scalar::Scalar;
    use crate::session::ArraySession;

    /// Kernel that ignores its input and always answers with [`sentinel_partial`].
    #[derive(Debug)]
    struct SentinelMeanPartialKernel;

    impl DynAggregateKernel for SentinelMeanPartialKernel {
        fn aggregate(
            &self,
            _aggregate_fn: &AggregateFnRef,
            _batch: &ArrayRef,
            _ctx: &mut ExecutionCtx,
        ) -> VortexResult<Option<Scalar>> {
            Ok(Some(sentinel_partial()))
        }
    }

    /// Kernel that always declines by returning `Ok(None)`.
    #[derive(Debug)]
    struct DeclineKernel;

    impl DynAggregateKernel for DeclineKernel {
        fn aggregate(
            &self,
            _aggregate_fn: &AggregateFnRef,
            _batch: &ArrayRef,
            _ctx: &mut ExecutionCtx,
        ) -> VortexResult<Option<Scalar>> {
            Ok(None)
        }
    }

    /// Kernel that ignores its input and always answers with a nullable `f64` of `42.0`.
    #[derive(Debug)]
    struct SentinelSumPartialKernel;

    impl DynAggregateKernel for SentinelSumPartialKernel {
        fn aggregate(
            &self,
            _aggregate_fn: &AggregateFnRef,
            _batch: &ArrayRef,
            _ctx: &mut ExecutionCtx,
        ) -> VortexResult<Option<Scalar>> {
            Ok(Some(Scalar::primitive(42.0f64, Nullability::Nullable)))
        }
    }

    /// Empty session with only `ArraySession` added, so each test registers its own kernel.
    fn fresh_session() -> VortexSession {
        VortexSession::empty().with::<ArraySession>()
    }

    /// Dictionary array with the single code `0` over the single value `7.0`.
    fn dict_of_seven() -> ArrayRef {
        let codes = buffer![0u32].into_array();
        let values = buffer![7.0f64].into_array();
        DictArray::try_new(codes, values)
            .expect("valid dictionary")
            .into_array()
    }

    /// Combined-mean accumulator over non-nullable `f64` input.
    fn mean_f64_accumulator() -> VortexResult<Accumulator<Combined<Mean>>> {
        Accumulator::try_new(
            Mean::combined(),
            PairOptions(EmptyOptions, EmptyOptions),
            DType::Primitive(PType::F64, Nullability::NonNullable),
        )
    }

    /// Struct scalar in the mean accumulator's partial dtype holding sum `42.0` and count `1`.
    fn sentinel_partial() -> Scalar {
        let partial_dtype = mean_f64_accumulator()
            .expect("build accumulator")
            .partial_dtype;
        let sum = Scalar::primitive(42.0f64, Nullability::Nullable);
        let count = Scalar::primitive(1u64, Nullability::NonNullable);
        Scalar::struct_(partial_dtype, vec![sum, count])
    }

    /// Accumulates [`dict_of_seven`] into a fresh mean accumulator using `session`
    /// and returns the flushed partial.
    fn accumulate_dict_and_flush(session: VortexSession) -> VortexResult<Scalar> {
        let mut ctx = session.create_execution_ctx();
        let mut acc = mean_f64_accumulator()?;
        acc.accumulate(&dict_of_seven(), &mut ctx)?;
        acc.flush()
    }

    /// Asserts the `sum` and `count` fields of a mean partial.
    #[track_caller]
    fn assert_sum_and_count(partial: &Scalar, expected_sum: f64, expected_count: u64) {
        let fields = partial.as_struct();
        assert_eq!(
            fields.field("sum").unwrap().as_primitive().as_::<f64>(),
            Some(expected_sum)
        );
        assert_eq!(
            fields.field("count").unwrap().as_primitive().as_::<u64>(),
            Some(expected_count)
        );
    }

    /// A kernel registered for the combined mean on `Dict` supplies the partial:
    /// the sentinel sum `42.0` shows up instead of the real value `7.0`.
    #[test]
    fn combined_kernel_fires() -> VortexResult<()> {
        static KERNEL: SentinelMeanPartialKernel = SentinelMeanPartialKernel;
        let session = fresh_session();
        session
            .get::<AggregateFnSession>()
            .register_aggregate_kernel(Dict.id(), Some(Mean::combined().id()), &KERNEL);

        let partial = accumulate_dict_and_flush(session)?;

        assert_sum_and_count(&partial, 42.0, 1);
        Ok(())
    }

    /// When the registered kernel declines, accumulation falls back to the regular
    /// path and sees the real data: sum `7.0`, count `1`.
    #[test]
    fn fallback_when_kernel_declines() -> VortexResult<()> {
        static KERNEL: DeclineKernel = DeclineKernel;
        let session = fresh_session();
        session
            .get::<AggregateFnSession>()
            .register_aggregate_kernel(Dict.id(), Some(Mean::combined().id()), &KERNEL);

        let partial = accumulate_dict_and_flush(session)?;

        assert_sum_and_count(&partial, 7.0, 1);
        Ok(())
    }

    /// A kernel registered only for `Sum` on `Dict` is still reached through the
    /// combined mean accumulator: the sum is the sentinel `42.0` while the count is `1`.
    #[test]
    fn child_kernel_fires_through_combined() -> VortexResult<()> {
        static KERNEL: SentinelSumPartialKernel = SentinelSumPartialKernel;
        let session = fresh_session();
        session
            .get::<AggregateFnSession>()
            .register_aggregate_kernel(Dict.id(), Some(Sum.id()), &KERNEL);

        let partial = accumulate_dict_and_flush(session)?;

        assert_sum_and_count(&partial, 42.0, 1);
        Ok(())
    }
}