Skip to main content

vortex_array/aggregate_fn/
accumulator.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexResult;
5use vortex_error::vortex_ensure;
6use vortex_error::vortex_err;
7
8use crate::ArrayRef;
9use crate::Columnar;
10use crate::ExecutionCtx;
11use crate::aggregate_fn::AggregateFn;
12use crate::aggregate_fn::AggregateFnRef;
13use crate::aggregate_fn::AggregateFnVTable;
14use crate::aggregate_fn::session::AggregateFnSessionExt;
15use crate::columnar::AnyColumnar;
16use crate::dtype::DType;
17use crate::executor::max_iterations;
18use crate::expr::stats::Precision;
19use crate::expr::stats::Stat;
20use crate::expr::stats::StatsProvider;
21use crate::scalar::Scalar;
22
23/// Reference-counted type-erased accumulator.
24pub type AccumulatorRef = Box<dyn DynAccumulator>;
25
26/// An accumulator used for computing aggregates over an entire stream of arrays.
27pub struct Accumulator<V: AggregateFnVTable> {
28    /// The vtable of the aggregate function.
29    vtable: V,
30    /// Type-erased aggregate function used for kernel dispatch.
31    aggregate_fn: AggregateFnRef,
32    /// The DType of the input.
33    dtype: DType,
34    /// The DType of the aggregate.
35    return_dtype: DType,
36    /// The DType of the accumulator state.
37    partial_dtype: DType,
38    /// The partial state of the accumulator, updated after each accumulate/merge call.
39    partial: V::Partial,
40}
41
42impl<V: AggregateFnVTable> Accumulator<V> {
43    pub fn try_new(vtable: V, options: V::Options, dtype: DType) -> VortexResult<Self> {
44        let return_dtype = vtable.return_dtype(&options, &dtype).ok_or_else(|| {
45            vortex_err!(
46                "Aggregate function {} cannot be applied to dtype {}",
47                vtable.id(),
48                dtype
49            )
50        })?;
51        let partial_dtype = vtable.partial_dtype(&options, &dtype).ok_or_else(|| {
52            vortex_err!(
53                "Aggregate function {} cannot be applied to dtype {}",
54                vtable.id(),
55                dtype
56            )
57        })?;
58        let partial = vtable.empty_partial(&options, &dtype)?;
59        let aggregate_fn = AggregateFn::new(vtable.clone(), options).erased();
60
61        Ok(Self {
62            vtable,
63            aggregate_fn,
64            dtype,
65            return_dtype,
66            partial_dtype,
67            partial,
68        })
69    }
70}
71
72/// A trait object for type-erased accumulators, used for dynamic dispatch when the aggregate
73/// function is not known at compile time.
74pub trait DynAccumulator: 'static + Send {
75    /// Accumulate a new array into the accumulator's state.
76    fn accumulate(&mut self, batch: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<()>;
77
78    /// Fold an external partial-state scalar into this accumulator's state.
79    ///
80    /// The scalar must have the dtype reported by the vtable's `partial_dtype` for the
81    /// options and input dtype used to construct this accumulator.
82    fn combine_partials(&mut self, other: Scalar) -> VortexResult<()>;
83
84    /// Whether the accumulator's result is fully determined.
85    fn is_saturated(&self) -> bool;
86
87    /// Reset the accumulator's state to the empty group.
88    fn reset(&mut self);
89
90    /// Read the current partial state as a scalar without resetting it.
91    ///
92    /// The returned scalar has the dtype reported by the vtable's `partial_dtype`.
93    fn partial_scalar(&self) -> VortexResult<Scalar>;
94
95    /// Compute the final aggregate result as a scalar without resetting state.
96    fn final_scalar(&self) -> VortexResult<Scalar>;
97
98    /// Flush the accumulation state and return the partial aggregate result as a scalar.
99    ///
100    /// Resets the accumulator state back to the initial state.
101    fn flush(&mut self) -> VortexResult<Scalar>;
102
103    /// Finish the accumulation and return the final aggregate result as a scalar.
104    ///
105    /// Resets the accumulator state back to the initial state.
106    fn finish(&mut self) -> VortexResult<Scalar>;
107}
108
109impl<V: AggregateFnVTable> DynAccumulator for Accumulator<V> {
110    fn accumulate(&mut self, batch: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<()> {
111        if self.is_saturated() {
112            return Ok(());
113        }
114
115        vortex_ensure!(
116            batch.dtype() == &self.dtype,
117            "Input DType mismatch: expected {}, got {}",
118            self.dtype,
119            batch.dtype()
120        );
121
122        // 0. Legacy stats bridge: if this aggregate is still cached under a legacy Stat slot,
123        //    consume that exact stat before kernel dispatch or decode.
124        if let Some(stat) = Stat::from_aggregate_fn(&self.aggregate_fn)
125            && let Some(Precision::Exact(partial)) = batch.statistics().get(stat)
126        {
127            vortex_ensure!(
128                partial.dtype() == &self.partial_dtype,
129                "Aggregate {} read legacy stat {} with dtype {}, expected {}",
130                self.aggregate_fn,
131                stat,
132                partial.dtype(),
133                self.partial_dtype,
134            );
135            self.vtable.combine_partials(&mut self.partial, partial)?;
136            return Ok(());
137        }
138
139        let session = ctx.session().clone();
140        let kernels = &session.aggregate_fns().kernels;
141
142        // 1. Kernel registry first: a registered `(encoding, aggregate_fn)` kernel is strictly
143        //    more specific than the vtable's `try_accumulate` short-circuit. Checking the
144        //    registry first gives kernels for `Combined<V>` aggregates a chance to fire —
145        //    `Combined::try_accumulate` always returns true, so a later kernel check would be
146        //    unreachable.
147        {
148            let kernels_r = kernels.read();
149            let batch_id = batch.encoding_id();
150            let kernel = kernels_r
151                .get(&(batch_id, Some(self.aggregate_fn.id())))
152                .or_else(|| kernels_r.get(&(batch_id, None)))
153                .copied();
154            drop(kernels_r);
155            if let Some(kernel) = kernel
156                && let Some(result) = kernel.aggregate(&self.aggregate_fn, batch, ctx)?
157            {
158                vortex_ensure!(
159                    result.dtype() == &self.partial_dtype,
160                    "Aggregate kernel returned {}, expected {}",
161                    result.dtype(),
162                    self.partial_dtype,
163                );
164                self.vtable.combine_partials(&mut self.partial, result)?;
165                return Ok(());
166            }
167        }
168
169        // 2. Allow the vtable to short-circuit on the raw array before decompression.
170        if self.vtable.try_accumulate(&mut self.partial, batch, ctx)? {
171            return Ok(());
172        }
173
174        // 3. Iteratively check the registry against each intermediate encoding, executing one
175        //    step between checks. Mirrors the loop in `GroupedAccumulator::accumulate_list_view`.
176        //    Iteration 0 re-checks the initial encoding — a redundant HashMap miss, the price of
177        //    keeping the loop body uniform. Terminates on `AnyColumnar` (Canonical or Constant)
178        //    since the vtable's `accumulate(&Columnar)` handles both cases directly.
179        let mut batch = batch.clone();
180        for _ in 0..max_iterations() {
181            if batch.is::<AnyColumnar>() {
182                break;
183            }
184
185            let kernels_r = kernels.read();
186            let batch_id = batch.encoding_id();
187            let kernel = kernels_r
188                .get(&(batch_id, Some(self.aggregate_fn.id())))
189                .or_else(|| kernels_r.get(&(batch_id, None)))
190                .copied();
191            drop(kernels_r);
192            if let Some(kernel) = kernel
193                && let Some(result) = kernel.aggregate(&self.aggregate_fn, &batch, ctx)?
194            {
195                vortex_ensure!(
196                    result.dtype() == &self.partial_dtype,
197                    "Aggregate kernel returned {}, expected {}",
198                    result.dtype(),
199                    self.partial_dtype,
200                );
201                self.vtable.combine_partials(&mut self.partial, result)?;
202                return Ok(());
203            }
204
205            batch = batch.execute(ctx)?;
206        }
207
208        // 4. Otherwise, execute the batch until it is columnar and accumulate it into the state.
209        let columnar = batch.execute::<Columnar>(ctx)?;
210
211        self.vtable.accumulate(&mut self.partial, &columnar, ctx)
212    }
213
214    fn combine_partials(&mut self, other: Scalar) -> VortexResult<()> {
215        self.vtable.combine_partials(&mut self.partial, other)
216    }
217
218    fn is_saturated(&self) -> bool {
219        self.vtable.is_saturated(&self.partial)
220    }
221
222    fn reset(&mut self) {
223        self.vtable.reset(&mut self.partial);
224    }
225
226    fn partial_scalar(&self) -> VortexResult<Scalar> {
227        let partial = self.vtable.to_scalar(&self.partial)?;
228
229        #[cfg(debug_assertions)]
230        {
231            vortex_ensure!(
232                partial.dtype() == &self.partial_dtype,
233                "Aggregate returned incorrect DType on partial_scalar: expected {}, got {}",
234                self.partial_dtype,
235                partial.dtype(),
236            );
237        }
238
239        Ok(partial)
240    }
241
242    fn final_scalar(&self) -> VortexResult<Scalar> {
243        let result = self.vtable.finalize_scalar(&self.partial)?;
244
245        vortex_ensure!(
246            result.dtype() == &self.return_dtype,
247            "Aggregate returned incorrect DType on final_scalar: expected {}, got {}",
248            self.return_dtype,
249            result.dtype(),
250        );
251
252        Ok(result)
253    }
254
255    fn flush(&mut self) -> VortexResult<Scalar> {
256        let partial = self.partial_scalar()?;
257        self.reset();
258        Ok(partial)
259    }
260
261    fn finish(&mut self) -> VortexResult<Scalar> {
262        let result = self.final_scalar()?;
263        self.reset();
264        Ok(result)
265    }
266}
267
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;
    use vortex_session::SessionExt;
    use vortex_session::VortexSession;

    use crate::ArrayRef;
    use crate::ExecutionCtx;
    use crate::IntoArray;
    use crate::VortexSessionExecute;
    use crate::aggregate_fn::Accumulator;
    use crate::aggregate_fn::AggregateFnRef;
    use crate::aggregate_fn::AggregateFnVTable;
    use crate::aggregate_fn::DynAccumulator;
    use crate::aggregate_fn::EmptyOptions;
    use crate::aggregate_fn::combined::Combined;
    use crate::aggregate_fn::combined::PairOptions;
    use crate::aggregate_fn::fns::mean::Mean;
    use crate::aggregate_fn::fns::sum::Sum;
    use crate::aggregate_fn::kernels::DynAggregateKernel;
    use crate::aggregate_fn::session::AggregateFnSession;
    use crate::array::VTable;
    use crate::arrays::Dict;
    use crate::arrays::DictArray;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::scalar::Scalar;
    use crate::session::ArraySession;

    /// Kernel that always answers with the Mean partial sentinel `{sum: 42.0, count: 1}`.
    ///
    /// The sentinel cannot be confused with `{sum: 7.0, count: 1}`, the natural fan-out
    /// result `Combined::try_accumulate` would produce for `dict_of_seven()`.
    #[derive(Debug)]
    struct SentinelMeanPartialKernel;
    impl DynAggregateKernel for SentinelMeanPartialKernel {
        fn aggregate(
            &self,
            _aggregate_fn: &AggregateFnRef,
            _batch: &ArrayRef,
            _ctx: &mut ExecutionCtx,
        ) -> VortexResult<Option<Scalar>> {
            Ok(Some(sentinel_partial()))
        }
    }

    /// Kernel that always declines by returning `Ok(None)`, so dispatch falls through.
    #[derive(Debug)]
    struct DeclineKernel;
    impl DynAggregateKernel for DeclineKernel {
        fn aggregate(
            &self,
            _aggregate_fn: &AggregateFnRef,
            _batch: &ArrayRef,
            _ctx: &mut ExecutionCtx,
        ) -> VortexResult<Option<Scalar>> {
            Ok(None)
        }
    }

    /// Kernel that always answers with the Sum partial sentinel `42.0`.
    ///
    /// The sentinel cannot be confused with `7.0`, the natural Sum of `dict_of_seven()`.
    #[derive(Debug)]
    struct SentinelSumPartialKernel;
    impl DynAggregateKernel for SentinelSumPartialKernel {
        fn aggregate(
            &self,
            _aggregate_fn: &AggregateFnRef,
            _batch: &ArrayRef,
            _ctx: &mut ExecutionCtx,
        ) -> VortexResult<Option<Scalar>> {
            Ok(Some(Scalar::primitive(42.0f64, Nullability::Nullable)))
        }
    }

    /// Empty session with only the array session installed.
    fn fresh_session() -> VortexSession {
        VortexSession::empty().with::<ArraySession>()
    }

    /// One-element dictionary array: codes `[0]` over values `[7.0]`.
    fn dict_of_seven() -> ArrayRef {
        DictArray::try_new(buffer![0u32].into_array(), buffer![7.0f64].into_array())
            .expect("valid dictionary")
            .into_array()
    }

    /// `Combined<Mean>` accumulator over non-nullable `f64` input.
    fn mean_f64_accumulator() -> VortexResult<Accumulator<Combined<Mean>>> {
        let dtype = DType::Primitive(PType::F64, Nullability::NonNullable);
        Accumulator::try_new(
            Mean::combined(),
            PairOptions(EmptyOptions, EmptyOptions),
            dtype,
        )
    }

    /// Mean partial struct `{sum: 42.0, count: 1}` typed with the accumulator's partial dtype.
    fn sentinel_partial() -> Scalar {
        let acc = mean_f64_accumulator().expect("build accumulator");
        let sum = Scalar::primitive(42.0f64, Nullability::Nullable);
        let count = Scalar::primitive(1u64, Nullability::NonNullable);
        Scalar::struct_(acc.partial_dtype, vec![sum, count])
    }

    /// Accumulates `dict_of_seven()` into a fresh mean accumulator under `session` and returns
    /// the flushed partial.
    fn flush_mean_of_dict(session: VortexSession) -> VortexResult<Scalar> {
        let mut ctx = session.create_execution_ctx();

        let mut acc = mean_f64_accumulator()?;
        acc.accumulate(&dict_of_seven(), &mut ctx)?;
        acc.flush()
    }

    /// Asserts that a Mean partial struct carries the given `sum` and `count` fields.
    #[track_caller]
    fn assert_mean_partial(partial: &Scalar, expected_sum: f64, expected_count: u64) {
        let fields = partial.as_struct();
        assert_eq!(
            fields.field("sum").unwrap().as_primitive().as_::<f64>(),
            Some(expected_sum)
        );
        assert_eq!(
            fields.field("count").unwrap().as_primitive().as_::<u64>(),
            Some(expected_count)
        );
    }

    /// Kernel registered for `(Dict, Combined<Mean>)` fires in preference to
    /// `Combined::try_accumulate`'s fan-out path — proves the dispatch reorder.
    #[test]
    fn combined_kernel_fires() -> VortexResult<()> {
        static KERNEL: SentinelMeanPartialKernel = SentinelMeanPartialKernel;
        let session = fresh_session();
        session
            .get::<AggregateFnSession>()
            .register_aggregate_kernel(Dict.id(), Some(Mean::combined().id()), &KERNEL);

        let partial = flush_mean_of_dict(session)?;

        assert_mean_partial(&partial, 42.0, 1);
        Ok(())
    }

    /// Kernel returns `Ok(None)` => dispatch falls through to `Combined::try_accumulate`'s
    /// natural fan-out. The natural partial is `{sum: 7.0, count: 1}`.
    #[test]
    fn fallback_when_kernel_declines() -> VortexResult<()> {
        static KERNEL: DeclineKernel = DeclineKernel;
        let session = fresh_session();
        session
            .get::<AggregateFnSession>()
            .register_aggregate_kernel(Dict.id(), Some(Mean::combined().id()), &KERNEL);

        let partial = flush_mean_of_dict(session)?;

        assert_mean_partial(&partial, 7.0, 1);
        Ok(())
    }

    /// A kernel registered for the inner `(Dict, Sum)` child fires when accumulating a
    /// Dict batch through `Combined<Mean>`. This is the reusable-primitive case the
    /// refactor enables: no `(Dict, Combined<Mean>)` kernel is needed.
    #[test]
    fn child_kernel_fires_through_combined() -> VortexResult<()> {
        static KERNEL: SentinelSumPartialKernel = SentinelSumPartialKernel;
        let session = fresh_session();
        session
            .get::<AggregateFnSession>()
            .register_aggregate_kernel(Dict.id(), Some(Sum.id()), &KERNEL);

        let partial = flush_mean_of_dict(session)?;

        // `Sum` child returned the sentinel 42.0 — proves the (Dict, Sum) kernel fired
        // via `Combined<Mean>`'s fan-out. `Count`'s native `try_accumulate` reads the
        // batch's valid_count, so count is the real 1.
        assert_mean_partial(&partial, 42.0, 1);
        Ok(())
    }
}