Skip to main content

vortex_array/aggregate_fn/
accumulator.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexResult;
5use vortex_error::vortex_ensure;
6use vortex_error::vortex_err;
7
8use crate::ArrayRef;
9use crate::Columnar;
10use crate::ExecutionCtx;
11use crate::aggregate_fn::AggregateFn;
12use crate::aggregate_fn::AggregateFnRef;
13use crate::aggregate_fn::AggregateFnVTable;
14use crate::aggregate_fn::session::AggregateFnSessionExt;
15use crate::columnar::AnyColumnar;
16use crate::dtype::DType;
17use crate::executor::max_iterations;
18use crate::expr::stats::Precision;
19use crate::expr::stats::Stat;
20use crate::expr::stats::StatsProvider;
21use crate::scalar::Scalar;
22
23/// Reference-counted type-erased accumulator.
24pub type AccumulatorRef = Box<dyn DynAccumulator>;
25
26/// An accumulator used for computing aggregates over an entire stream of arrays.
27pub struct Accumulator<V: AggregateFnVTable> {
28    /// The vtable of the aggregate function.
29    vtable: V,
30    /// Type-erased aggregate function used for kernel dispatch.
31    aggregate_fn: AggregateFnRef,
32    /// The DType of the input.
33    dtype: DType,
34    /// The DType of the aggregate.
35    return_dtype: DType,
36    /// The DType of the accumulator state.
37    partial_dtype: DType,
38    /// The partial state of the accumulator, updated after each accumulate/merge call.
39    partial: V::Partial,
40}
41
42impl<V: AggregateFnVTable> Accumulator<V> {
43    pub fn try_new(vtable: V, options: V::Options, dtype: DType) -> VortexResult<Self> {
44        let return_dtype = vtable.return_dtype(&options, &dtype).ok_or_else(|| {
45            vortex_err!(
46                "Aggregate function {} cannot be applied to dtype {}",
47                vtable.id(),
48                dtype
49            )
50        })?;
51        let partial_dtype = vtable.partial_dtype(&options, &dtype).ok_or_else(|| {
52            vortex_err!(
53                "Aggregate function {} cannot be applied to dtype {}",
54                vtable.id(),
55                dtype
56            )
57        })?;
58        let partial = vtable.empty_partial(&options, &dtype)?;
59        let aggregate_fn = AggregateFn::new(vtable.clone(), options).erased();
60
61        Ok(Self {
62            vtable,
63            aggregate_fn,
64            dtype,
65            return_dtype,
66            partial_dtype,
67            partial,
68        })
69    }
70}
71
72/// A trait object for type-erased accumulators, used for dynamic dispatch when the aggregate
73/// function is not known at compile time.
74pub trait DynAccumulator: 'static + Send {
75    /// Accumulate a new array into the accumulator's state.
76    fn accumulate(&mut self, batch: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<()>;
77
78    /// Fold an external partial-state scalar into this accumulator's state.
79    ///
80    /// The scalar must have the dtype reported by the vtable's `partial_dtype` for the
81    /// options and input dtype used to construct this accumulator.
82    fn combine_partials(&mut self, other: Scalar) -> VortexResult<()>;
83
84    /// Whether the accumulator's result is fully determined.
85    fn is_saturated(&self) -> bool;
86
87    /// Reset the accumulator's state to the empty group.
88    fn reset(&mut self);
89
90    /// Read the current partial state as a scalar without resetting it.
91    ///
92    /// The returned scalar has the dtype reported by the vtable's `partial_dtype`.
93    fn partial_scalar(&self) -> VortexResult<Scalar>;
94
95    /// Compute the final aggregate result as a scalar without resetting state.
96    fn final_scalar(&self) -> VortexResult<Scalar>;
97
98    /// Flush the accumulation state and return the partial aggregate result as a scalar.
99    ///
100    /// Resets the accumulator state back to the initial state.
101    fn flush(&mut self) -> VortexResult<Scalar>;
102
103    /// Finish the accumulation and return the final aggregate result as a scalar.
104    ///
105    /// Resets the accumulator state back to the initial state.
106    fn finish(&mut self) -> VortexResult<Scalar>;
107}
108
109impl<V: AggregateFnVTable> DynAccumulator for Accumulator<V> {
110    fn accumulate(&mut self, batch: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<()> {
111        if self.is_saturated() {
112            return Ok(());
113        }
114
115        vortex_ensure!(
116            batch.dtype() == &self.dtype,
117            "Input DType mismatch: expected {}, got {}",
118            self.dtype,
119            batch.dtype()
120        );
121
122        // 0. Legacy stats bridge: if this aggregate is still cached under a legacy Stat slot,
123        //    consume that exact stat before kernel dispatch or decode.
124        if let Some(stat) = Stat::from_aggregate_fn(&self.aggregate_fn)
125            && let Precision::Exact(partial) = batch.statistics().get(stat)
126        {
127            let partial = if partial.dtype() == &self.partial_dtype {
128                partial
129            } else {
130                vortex_ensure!(
131                    partial.dtype().eq_ignore_nullability(&self.partial_dtype),
132                    "Aggregate {} read legacy stat {} with dtype {}, expected {}",
133                    self.aggregate_fn,
134                    stat,
135                    partial.dtype(),
136                    self.partial_dtype,
137                );
138                partial.cast(&self.partial_dtype)?
139            };
140            self.vtable.combine_partials(&mut self.partial, partial)?;
141            return Ok(());
142        }
143
144        let session = ctx.session().clone();
145
146        // 1. Kernel registry first: a registered `(encoding, aggregate_fn)` kernel is strictly
147        //    more specific than the vtable's `try_accumulate` short-circuit. Checking the
148        //    registry first gives kernels for `Combined<V>` aggregates a chance to fire —
149        //    `Combined::try_accumulate` always returns true, so a later kernel check would be
150        //    unreachable.
151        {
152            let kernel = session
153                .aggregate_fns()
154                .find_aggregate_kernel(batch.encoding_id(), self.aggregate_fn.id());
155            if let Some(kernel) = kernel
156                && let Some(result) = kernel.aggregate(&self.aggregate_fn, batch, ctx)?
157            {
158                vortex_ensure!(
159                    result.dtype() == &self.partial_dtype,
160                    "Aggregate kernel returned {}, expected {}",
161                    result.dtype(),
162                    self.partial_dtype,
163                );
164                self.vtable.combine_partials(&mut self.partial, result)?;
165                return Ok(());
166            }
167        }
168
169        // 2. Allow the vtable to short-circuit on the raw array before decompression.
170        if self.vtable.try_accumulate(&mut self.partial, batch, ctx)? {
171            return Ok(());
172        }
173
174        // 3. Iteratively check the registry against each intermediate encoding, executing one
175        //    step between checks. Mirrors the loop in `GroupedAccumulator::accumulate_list_view`.
176        //    Iteration 0 re-checks the initial encoding — a redundant HashMap miss, the price of
177        //    keeping the loop body uniform. Terminates on `AnyColumnar` (Canonical or Constant)
178        //    since the vtable's `accumulate(&Columnar)` handles both cases directly.
179        let mut batch = batch.clone();
180        for _ in 0..max_iterations() {
181            if batch.is::<AnyColumnar>() {
182                break;
183            }
184
185            if let Some(kernel) = session
186                .aggregate_fns()
187                .find_aggregate_kernel(batch.encoding_id(), self.aggregate_fn.id())
188                && let Some(result) = kernel.aggregate(&self.aggregate_fn, &batch, ctx)?
189            {
190                vortex_ensure!(
191                    result.dtype() == &self.partial_dtype,
192                    "Aggregate kernel returned {}, expected {}",
193                    result.dtype(),
194                    self.partial_dtype,
195                );
196                self.vtable.combine_partials(&mut self.partial, result)?;
197                return Ok(());
198            }
199
200            batch = batch.execute(ctx)?;
201        }
202
203        // 4. Otherwise, execute the batch until it is columnar and accumulate it into the state.
204        let columnar = batch.execute::<Columnar>(ctx)?;
205
206        self.vtable.accumulate(&mut self.partial, &columnar, ctx)
207    }
208
209    fn combine_partials(&mut self, other: Scalar) -> VortexResult<()> {
210        self.vtable.combine_partials(&mut self.partial, other)
211    }
212
213    fn is_saturated(&self) -> bool {
214        self.vtable.is_saturated(&self.partial)
215    }
216
217    fn reset(&mut self) {
218        self.vtable.reset(&mut self.partial);
219    }
220
221    fn partial_scalar(&self) -> VortexResult<Scalar> {
222        let partial = self.vtable.to_scalar(&self.partial)?;
223
224        #[cfg(debug_assertions)]
225        {
226            vortex_ensure!(
227                partial.dtype() == &self.partial_dtype,
228                "Aggregate returned incorrect DType on partial_scalar: expected {}, got {}",
229                self.partial_dtype,
230                partial.dtype(),
231            );
232        }
233
234        Ok(partial)
235    }
236
237    fn final_scalar(&self) -> VortexResult<Scalar> {
238        let result = self.vtable.finalize_scalar(&self.partial)?;
239
240        vortex_ensure!(
241            result.dtype() == &self.return_dtype,
242            "Aggregate returned incorrect DType on final_scalar: expected {}, got {}",
243            self.return_dtype,
244            result.dtype(),
245        );
246
247        Ok(result)
248    }
249
250    fn flush(&mut self) -> VortexResult<Scalar> {
251        let partial = self.partial_scalar()?;
252        self.reset();
253        Ok(partial)
254    }
255
256    fn finish(&mut self) -> VortexResult<Scalar> {
257        let result = self.final_scalar()?;
258        self.reset();
259        Ok(result)
260    }
261}
262
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;
    use vortex_session::SessionExt;
    use vortex_session::VortexSession;

    use crate::ArrayRef;
    use crate::ExecutionCtx;
    use crate::IntoArray;
    use crate::VortexSessionExecute;
    use crate::aggregate_fn::Accumulator;
    use crate::aggregate_fn::AggregateFnRef;
    use crate::aggregate_fn::AggregateFnVTable;
    use crate::aggregate_fn::DynAccumulator;
    use crate::aggregate_fn::EmptyOptions;
    use crate::aggregate_fn::combined::Combined;
    use crate::aggregate_fn::combined::PairOptions;
    use crate::aggregate_fn::fns::mean::Mean;
    use crate::aggregate_fn::fns::sum::Sum;
    use crate::aggregate_fn::kernels::DynAggregateKernel;
    use crate::aggregate_fn::session::AggregateFnSession;
    use crate::array::VTable;
    use crate::arrays::Dict;
    use crate::arrays::DictArray;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::scalar::Scalar;
    use crate::session::ArraySession;

    /// Always answers with the mean partial `{sum: 42.0, count: 1}`, which cannot be confused
    /// with the `{sum: 7.0, count: 1}` that `Combined::try_accumulate`'s fan-out yields for
    /// `dict_of_seven()`.
    #[derive(Debug)]
    struct SentinelMeanPartialKernel;
    impl DynAggregateKernel for SentinelMeanPartialKernel {
        fn aggregate(
            &self,
            _aggregate_fn: &AggregateFnRef,
            _batch: &ArrayRef,
            _ctx: &mut ExecutionCtx,
        ) -> VortexResult<Option<Scalar>> {
            Ok(Some(sentinel_partial()))
        }
    }

    /// Always declines with `Ok(None)`, forcing dispatch to fall through to the next strategy.
    #[derive(Debug)]
    struct DeclineKernel;
    impl DynAggregateKernel for DeclineKernel {
        fn aggregate(
            &self,
            _aggregate_fn: &AggregateFnRef,
            _batch: &ArrayRef,
            _ctx: &mut ExecutionCtx,
        ) -> VortexResult<Option<Scalar>> {
            Ok(None)
        }
    }

    /// Always answers with the Sum partial `42.0`, distinct from the real Sum of
    /// `dict_of_seven()`, which is `7.0`.
    #[derive(Debug)]
    struct SentinelSumPartialKernel;
    impl DynAggregateKernel for SentinelSumPartialKernel {
        fn aggregate(
            &self,
            _aggregate_fn: &AggregateFnRef,
            _batch: &ArrayRef,
            _ctx: &mut ExecutionCtx,
        ) -> VortexResult<Option<Scalar>> {
            Ok(Some(Scalar::primitive(42.0f64, Nullability::Nullable)))
        }
    }

    /// A session that has only the array session registered and no aggregate kernels.
    fn fresh_session() -> VortexSession {
        VortexSession::empty().with::<ArraySession>()
    }

    /// A one-element dictionary array whose single value is `7.0`.
    fn dict_of_seven() -> ArrayRef {
        let codes = buffer![0u32].into_array();
        let values = buffer![7.0f64].into_array();
        DictArray::try_new(codes, values)
            .expect("valid dictionary")
            .into_array()
    }

    /// A `Combined<Mean>` accumulator over non-nullable `f64` input.
    fn mean_f64_accumulator() -> VortexResult<Accumulator<Combined<Mean>>> {
        let input_dtype = DType::Primitive(PType::F64, Nullability::NonNullable);
        let options = PairOptions(EmptyOptions, EmptyOptions);
        Accumulator::try_new(Mean::combined(), options, input_dtype)
    }

    /// The mean partial struct `{sum: 42.0, count: 1}` returned by `SentinelMeanPartialKernel`.
    fn sentinel_partial() -> Scalar {
        let template = mean_f64_accumulator().expect("build accumulator");
        let fields = vec![
            Scalar::primitive(42.0f64, Nullability::Nullable),
            Scalar::primitive(1u64, Nullability::NonNullable),
        ];
        Scalar::struct_(template.partial_dtype, fields)
    }

    /// Feeds `dict_of_seven()` through a fresh `Combined<Mean>` accumulator executing under
    /// `session`, and returns the flushed partial.
    fn flush_mean_of_dict_of_seven(session: VortexSession) -> VortexResult<Scalar> {
        let mut ctx = session.create_execution_ctx();
        let mut accumulator = mean_f64_accumulator()?;
        accumulator.accumulate(&dict_of_seven(), &mut ctx)?;
        accumulator.flush()
    }

    /// Asserts that the mean partial struct `partial` holds exactly `sum` and `count`.
    fn assert_mean_partial(partial: Scalar, sum: f64, count: u64) {
        let fields = partial.as_struct();
        assert_eq!(
            fields.field("sum").unwrap().as_primitive().as_::<f64>(),
            Some(sum)
        );
        assert_eq!(
            fields.field("count").unwrap().as_primitive().as_::<u64>(),
            Some(count)
        );
    }

    /// With a `(Dict, Combined<Mean>)` kernel registered, the kernel's sentinel beats
    /// `Combined::try_accumulate`'s fan-out, showing that the registry is consulted first.
    #[test]
    fn combined_kernel_fires() -> VortexResult<()> {
        static KERNEL: SentinelMeanPartialKernel = SentinelMeanPartialKernel;
        let session = fresh_session();
        session
            .get::<AggregateFnSession>()
            .register_aggregate_kernel(Dict.id(), Some(Mean::combined().id()), &KERNEL);

        let partial = flush_mean_of_dict_of_seven(session)?;
        assert_mean_partial(partial, 42.0, 1);
        Ok(())
    }

    /// A registered kernel that declines leaves `Combined::try_accumulate`'s natural fan-out
    /// in charge, which yields `{sum: 7.0, count: 1}`.
    #[test]
    fn fallback_when_kernel_declines() -> VortexResult<()> {
        static KERNEL: DeclineKernel = DeclineKernel;
        let session = fresh_session();
        session
            .get::<AggregateFnSession>()
            .register_aggregate_kernel(Dict.id(), Some(Mean::combined().id()), &KERNEL);

        let partial = flush_mean_of_dict_of_seven(session)?;
        assert_mean_partial(partial, 7.0, 1);
        Ok(())
    }

    /// A kernel registered only for the inner `(Dict, Sum)` child fires when a Dict batch is
    /// accumulated through `Combined<Mean>`. This is the reusable-primitive case: no
    /// `(Dict, Combined<Mean>)` kernel is needed.
    #[test]
    fn child_kernel_fires_through_combined() -> VortexResult<()> {
        static KERNEL: SentinelSumPartialKernel = SentinelSumPartialKernel;
        let session = fresh_session();
        session
            .get::<AggregateFnSession>()
            .register_aggregate_kernel(Dict.id(), Some(Sum.id()), &KERNEL);

        // The sentinel sum proves the (Dict, Sum) kernel was reached through the fan-out.
        // `Count`'s native `try_accumulate` reads the batch's valid count, so count is the
        // real 1.
        let partial = flush_mean_of_dict_of_seven(session)?;
        assert_mean_partial(partial, 42.0, 1);
        Ok(())
    }
}