Skip to main content

vortex_array/aggregate_fn/
accumulator.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexResult;
5use vortex_error::vortex_ensure;
6use vortex_error::vortex_err;
7
8use crate::ArrayRef;
9use crate::Columnar;
10use crate::ExecutionCtx;
11use crate::aggregate_fn::AggregateFn;
12use crate::aggregate_fn::AggregateFnRef;
13use crate::aggregate_fn::AggregateFnVTable;
14use crate::aggregate_fn::session::AggregateFnSessionExt;
15use crate::columnar::AnyColumnar;
16use crate::dtype::DType;
17use crate::executor::max_iterations;
18use crate::expr::stats::Precision;
19use crate::expr::stats::Stat;
20use crate::expr::stats::StatsProvider;
21use crate::scalar::Scalar;
22
23/// Reference-counted type-erased accumulator.
24pub type AccumulatorRef = Box<dyn DynAccumulator>;
25
26/// An accumulator used for computing aggregates over an entire stream of arrays.
27pub struct Accumulator<V: AggregateFnVTable> {
28    /// The vtable of the aggregate function.
29    vtable: V,
30    /// Type-erased aggregate function used for kernel dispatch.
31    aggregate_fn: AggregateFnRef,
32    /// The DType of the input.
33    dtype: DType,
34    /// The DType of the aggregate.
35    return_dtype: DType,
36    /// The DType of the accumulator state.
37    partial_dtype: DType,
38    /// The partial state of the accumulator, updated after each accumulate/merge call.
39    partial: V::Partial,
40}
41
42impl<V: AggregateFnVTable> Accumulator<V> {
43    pub fn try_new(vtable: V, options: V::Options, dtype: DType) -> VortexResult<Self> {
44        let return_dtype = vtable.return_dtype(&options, &dtype).ok_or_else(|| {
45            vortex_err!(
46                "Aggregate function {} cannot be applied to dtype {}",
47                vtable.id(),
48                dtype
49            )
50        })?;
51        let partial_dtype = vtable.partial_dtype(&options, &dtype).ok_or_else(|| {
52            vortex_err!(
53                "Aggregate function {} cannot be applied to dtype {}",
54                vtable.id(),
55                dtype
56            )
57        })?;
58        let partial = vtable.empty_partial(&options, &dtype)?;
59        let aggregate_fn = AggregateFn::new(vtable.clone(), options).erased();
60
61        Ok(Self {
62            vtable,
63            aggregate_fn,
64            dtype,
65            return_dtype,
66            partial_dtype,
67            partial,
68        })
69    }
70}
71
72/// A trait object for type-erased accumulators, used for dynamic dispatch when the aggregate
73/// function is not known at compile time.
74pub trait DynAccumulator: 'static + Send {
75    /// Accumulate a new array into the accumulator's state.
76    fn accumulate(&mut self, batch: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<()>;
77
78    /// Fold an external partial-state scalar into this accumulator's state.
79    ///
80    /// The scalar must have the dtype reported by the vtable's `partial_dtype` for the
81    /// options and input dtype used to construct this accumulator.
82    fn combine_partials(&mut self, other: Scalar) -> VortexResult<()>;
83
84    /// Whether the accumulator's result is fully determined.
85    fn is_saturated(&self) -> bool;
86
87    /// Reset the accumulator's state to the empty group.
88    fn reset(&mut self);
89
90    /// Read the current partial state as a scalar without resetting it.
91    ///
92    /// The returned scalar has the dtype reported by the vtable's `partial_dtype`.
93    fn partial_scalar(&self) -> VortexResult<Scalar>;
94
95    /// Compute the final aggregate result as a scalar without resetting state.
96    fn final_scalar(&self) -> VortexResult<Scalar>;
97
98    /// Flush the accumulation state and return the partial aggregate result as a scalar.
99    ///
100    /// Resets the accumulator state back to the initial state.
101    fn flush(&mut self) -> VortexResult<Scalar>;
102
103    /// Finish the accumulation and return the final aggregate result as a scalar.
104    ///
105    /// Resets the accumulator state back to the initial state.
106    fn finish(&mut self) -> VortexResult<Scalar>;
107}
108
109impl<V: AggregateFnVTable> DynAccumulator for Accumulator<V> {
110    fn accumulate(&mut self, batch: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<()> {
111        if self.is_saturated() {
112            return Ok(());
113        }
114
115        vortex_ensure!(
116            batch.dtype() == &self.dtype,
117            "Input DType mismatch: expected {}, got {}",
118            self.dtype,
119            batch.dtype()
120        );
121
122        // 0. Legacy stats bridge: if this aggregate is still cached under a legacy Stat slot,
123        //    consume that exact stat before kernel dispatch or decode.
124        if let Some(stat) = Stat::from_aggregate_fn(&self.aggregate_fn)
125            && let Precision::Exact(partial) = batch.statistics().get(stat)
126        {
127            let partial = if partial.dtype() == &self.partial_dtype {
128                partial
129            } else {
130                vortex_ensure!(
131                    partial.dtype().eq_ignore_nullability(&self.partial_dtype),
132                    "Aggregate {} read legacy stat {} with dtype {}, expected {}",
133                    self.aggregate_fn,
134                    stat,
135                    partial.dtype(),
136                    self.partial_dtype,
137                );
138                partial.cast(&self.partial_dtype)?
139            };
140            self.vtable.combine_partials(&mut self.partial, partial)?;
141            return Ok(());
142        }
143
144        let session = ctx.session().clone();
145        let kernels = &session.aggregate_fns().kernels;
146
147        // 1. Kernel registry first: a registered `(encoding, aggregate_fn)` kernel is strictly
148        //    more specific than the vtable's `try_accumulate` short-circuit. Checking the
149        //    registry first gives kernels for `Combined<V>` aggregates a chance to fire —
150        //    `Combined::try_accumulate` always returns true, so a later kernel check would be
151        //    unreachable.
152        {
153            let kernels_r = kernels.read();
154            let batch_id = batch.encoding_id();
155            let kernel = kernels_r
156                .get(&(batch_id, Some(self.aggregate_fn.id())))
157                .or_else(|| kernels_r.get(&(batch_id, None)))
158                .copied();
159            drop(kernels_r);
160            if let Some(kernel) = kernel
161                && let Some(result) = kernel.aggregate(&self.aggregate_fn, batch, ctx)?
162            {
163                vortex_ensure!(
164                    result.dtype() == &self.partial_dtype,
165                    "Aggregate kernel returned {}, expected {}",
166                    result.dtype(),
167                    self.partial_dtype,
168                );
169                self.vtable.combine_partials(&mut self.partial, result)?;
170                return Ok(());
171            }
172        }
173
174        // 2. Allow the vtable to short-circuit on the raw array before decompression.
175        if self.vtable.try_accumulate(&mut self.partial, batch, ctx)? {
176            return Ok(());
177        }
178
179        // 3. Iteratively check the registry against each intermediate encoding, executing one
180        //    step between checks. Mirrors the loop in `GroupedAccumulator::accumulate_list_view`.
181        //    Iteration 0 re-checks the initial encoding — a redundant HashMap miss, the price of
182        //    keeping the loop body uniform. Terminates on `AnyColumnar` (Canonical or Constant)
183        //    since the vtable's `accumulate(&Columnar)` handles both cases directly.
184        let mut batch = batch.clone();
185        for _ in 0..max_iterations() {
186            if batch.is::<AnyColumnar>() {
187                break;
188            }
189
190            let kernels_r = kernels.read();
191            let batch_id = batch.encoding_id();
192            let kernel = kernels_r
193                .get(&(batch_id, Some(self.aggregate_fn.id())))
194                .or_else(|| kernels_r.get(&(batch_id, None)))
195                .copied();
196            drop(kernels_r);
197            if let Some(kernel) = kernel
198                && let Some(result) = kernel.aggregate(&self.aggregate_fn, &batch, ctx)?
199            {
200                vortex_ensure!(
201                    result.dtype() == &self.partial_dtype,
202                    "Aggregate kernel returned {}, expected {}",
203                    result.dtype(),
204                    self.partial_dtype,
205                );
206                self.vtable.combine_partials(&mut self.partial, result)?;
207                return Ok(());
208            }
209
210            batch = batch.execute(ctx)?;
211        }
212
213        // 4. Otherwise, execute the batch until it is columnar and accumulate it into the state.
214        let columnar = batch.execute::<Columnar>(ctx)?;
215
216        self.vtable.accumulate(&mut self.partial, &columnar, ctx)
217    }
218
219    fn combine_partials(&mut self, other: Scalar) -> VortexResult<()> {
220        self.vtable.combine_partials(&mut self.partial, other)
221    }
222
223    fn is_saturated(&self) -> bool {
224        self.vtable.is_saturated(&self.partial)
225    }
226
227    fn reset(&mut self) {
228        self.vtable.reset(&mut self.partial);
229    }
230
231    fn partial_scalar(&self) -> VortexResult<Scalar> {
232        let partial = self.vtable.to_scalar(&self.partial)?;
233
234        #[cfg(debug_assertions)]
235        {
236            vortex_ensure!(
237                partial.dtype() == &self.partial_dtype,
238                "Aggregate returned incorrect DType on partial_scalar: expected {}, got {}",
239                self.partial_dtype,
240                partial.dtype(),
241            );
242        }
243
244        Ok(partial)
245    }
246
247    fn final_scalar(&self) -> VortexResult<Scalar> {
248        let result = self.vtable.finalize_scalar(&self.partial)?;
249
250        vortex_ensure!(
251            result.dtype() == &self.return_dtype,
252            "Aggregate returned incorrect DType on final_scalar: expected {}, got {}",
253            self.return_dtype,
254            result.dtype(),
255        );
256
257        Ok(result)
258    }
259
260    fn flush(&mut self) -> VortexResult<Scalar> {
261        let partial = self.partial_scalar()?;
262        self.reset();
263        Ok(partial)
264    }
265
266    fn finish(&mut self) -> VortexResult<Scalar> {
267        let result = self.final_scalar()?;
268        self.reset();
269        Ok(result)
270    }
271}
272
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;
    use vortex_session::SessionExt;
    use vortex_session::VortexSession;

    use crate::ArrayRef;
    use crate::ExecutionCtx;
    use crate::IntoArray;
    use crate::VortexSessionExecute;
    use crate::aggregate_fn::Accumulator;
    use crate::aggregate_fn::AggregateFnRef;
    use crate::aggregate_fn::AggregateFnVTable;
    use crate::aggregate_fn::DynAccumulator;
    use crate::aggregate_fn::EmptyOptions;
    use crate::aggregate_fn::combined::Combined;
    use crate::aggregate_fn::combined::PairOptions;
    use crate::aggregate_fn::fns::mean::Mean;
    use crate::aggregate_fn::fns::sum::Sum;
    use crate::aggregate_fn::kernels::DynAggregateKernel;
    use crate::aggregate_fn::session::AggregateFnSession;
    use crate::array::VTable;
    use crate::arrays::Dict;
    use crate::arrays::DictArray;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::scalar::Scalar;
    use crate::session::ArraySession;

    /// Always answers with the Mean partial `{sum: 42.0, count: 1}`. It cannot be confused with
    /// the `{sum: 7.0, count: 1}` that `Combined::try_accumulate`'s fan-out yields for
    /// `dict_of_seven()`.
    #[derive(Debug)]
    struct SentinelMeanPartialKernel;

    impl DynAggregateKernel for SentinelMeanPartialKernel {
        fn aggregate(
            &self,
            _aggregate_fn: &AggregateFnRef,
            _batch: &ArrayRef,
            _ctx: &mut ExecutionCtx,
        ) -> VortexResult<Option<Scalar>> {
            let sentinel = sentinel_partial();
            Ok(Some(sentinel))
        }
    }

    /// Always declines (`Ok(None)`), so dispatch must fall through to the next strategy.
    #[derive(Debug)]
    struct DeclineKernel;

    impl DynAggregateKernel for DeclineKernel {
        fn aggregate(
            &self,
            _aggregate_fn: &AggregateFnRef,
            _batch: &ArrayRef,
            _ctx: &mut ExecutionCtx,
        ) -> VortexResult<Option<Scalar>> {
            Ok(None)
        }
    }

    /// Always answers with the Sum partial `42.0`. The genuine Sum of `dict_of_seven()` is
    /// `7.0`.
    #[derive(Debug)]
    struct SentinelSumPartialKernel;

    impl DynAggregateKernel for SentinelSumPartialKernel {
        fn aggregate(
            &self,
            _aggregate_fn: &AggregateFnRef,
            _batch: &ArrayRef,
            _ctx: &mut ExecutionCtx,
        ) -> VortexResult<Option<Scalar>> {
            let sentinel = Scalar::primitive(42.0f64, Nullability::Nullable);
            Ok(Some(sentinel))
        }
    }

    /// A session with only the array session registered, and no aggregate kernels.
    fn fresh_session() -> VortexSession {
        VortexSession::empty().with::<ArraySession>()
    }

    /// A one-element dictionary array whose single decoded value is `7.0`.
    fn dict_of_seven() -> ArrayRef {
        let codes = buffer![0u32].into_array();
        let values = buffer![7.0f64].into_array();
        DictArray::try_new(codes, values)
            .expect("valid dictionary")
            .into_array()
    }

    /// A `Combined<Mean>` accumulator over non-nullable `f64` input.
    fn mean_f64_accumulator() -> VortexResult<Accumulator<Combined<Mean>>> {
        let input_dtype = DType::Primitive(PType::F64, Nullability::NonNullable);
        let options = PairOptions(EmptyOptions, EmptyOptions);
        Accumulator::try_new(Mean::combined(), options, input_dtype)
    }

    /// Builds the `{sum: 42.0, count: 1}` Mean partial with the accumulator's partial dtype.
    fn sentinel_partial() -> Scalar {
        let accumulator = mean_f64_accumulator().expect("build accumulator");
        let fields = vec![
            Scalar::primitive(42.0f64, Nullability::Nullable),
            Scalar::primitive(1u64, Nullability::NonNullable),
        ];
        Scalar::struct_(accumulator.partial_dtype, fields)
    }

    /// Asserts that a Mean partial struct holds exactly `expected_sum` and `expected_count`.
    fn assert_mean_partial(partial: &Scalar, expected_sum: f64, expected_count: u64) {
        let fields = partial.as_struct();
        assert_eq!(
            fields.field("sum").unwrap().as_primitive().as_::<f64>(),
            Some(expected_sum)
        );
        assert_eq!(
            fields.field("count").unwrap().as_primitive().as_::<u64>(),
            Some(expected_count)
        );
    }

    /// A `(Dict, Combined<Mean>)` kernel wins over `Combined::try_accumulate`'s fan-out, which
    /// proves the registry is consulted first.
    #[test]
    fn combined_kernel_fires() -> VortexResult<()> {
        static KERNEL: SentinelMeanPartialKernel = SentinelMeanPartialKernel;
        let session = fresh_session();
        session
            .get::<AggregateFnSession>()
            .register_aggregate_kernel(Dict.id(), Some(Mean::combined().id()), &KERNEL);
        let mut ctx = session.create_execution_ctx();

        let mut accumulator = mean_f64_accumulator()?;
        accumulator.accumulate(&dict_of_seven(), &mut ctx)?;

        assert_mean_partial(&accumulator.flush()?, 42.0, 1);
        Ok(())
    }

    /// A declining kernel lets dispatch fall through to `Combined::try_accumulate`. The natural
    /// fan-out result is `{sum: 7.0, count: 1}`.
    #[test]
    fn fallback_when_kernel_declines() -> VortexResult<()> {
        static KERNEL: DeclineKernel = DeclineKernel;
        let session = fresh_session();
        session
            .get::<AggregateFnSession>()
            .register_aggregate_kernel(Dict.id(), Some(Mean::combined().id()), &KERNEL);
        let mut ctx = session.create_execution_ctx();

        let mut accumulator = mean_f64_accumulator()?;
        accumulator.accumulate(&dict_of_seven(), &mut ctx)?;

        assert_mean_partial(&accumulator.flush()?, 7.0, 1);
        Ok(())
    }

    /// A kernel registered only for the inner `(Dict, Sum)` child still fires when a Dict batch
    /// goes through `Combined<Mean>`. No `(Dict, Combined<Mean>)` kernel is required.
    #[test]
    fn child_kernel_fires_through_combined() -> VortexResult<()> {
        static KERNEL: SentinelSumPartialKernel = SentinelSumPartialKernel;
        let session = fresh_session();
        session
            .get::<AggregateFnSession>()
            .register_aggregate_kernel(Dict.id(), Some(Sum.id()), &KERNEL);
        let mut ctx = session.create_execution_ctx();

        let mut accumulator = mean_f64_accumulator()?;
        accumulator.accumulate(&dict_of_seven(), &mut ctx)?;

        // The sum of 42.0 can only have come from the `(Dict, Sum)` sentinel kernel, reached via
        // the `Combined<Mean>` fan-out. `Count`'s own `try_accumulate` reads the batch's
        // valid_count, so the count is the real 1.
        assert_mean_partial(&accumulator.flush()?, 42.0, 1);
        Ok(())
    }
}