Skip to main content

vortex_array/aggregate_fn/
accumulator.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexResult;
5use vortex_error::vortex_ensure;
6use vortex_error::vortex_err;
7
8use crate::AnyCanonical;
9use crate::ArrayRef;
10use crate::Columnar;
11use crate::ExecutionCtx;
12use crate::aggregate_fn::AggregateFn;
13use crate::aggregate_fn::AggregateFnRef;
14use crate::aggregate_fn::AggregateFnVTable;
15use crate::aggregate_fn::session::AggregateFnSessionExt;
16use crate::dtype::DType;
17use crate::executor::MAX_ITERATIONS;
18use crate::scalar::Scalar;
19
20/// Reference-counted type-erased accumulator.
21pub type AccumulatorRef = Box<dyn DynAccumulator>;
22
23/// An accumulator used for computing aggregates over an entire stream of arrays.
24pub struct Accumulator<V: AggregateFnVTable> {
25    /// The vtable of the aggregate function.
26    vtable: V,
27    /// Type-erased aggregate function used for kernel dispatch.
28    aggregate_fn: AggregateFnRef,
29    /// The DType of the input.
30    dtype: DType,
31    /// The DType of the aggregate.
32    return_dtype: DType,
33    /// The DType of the accumulator state.
34    partial_dtype: DType,
35    /// The partial state of the accumulator, updated after each accumulate/merge call.
36    partial: V::Partial,
37}
38
39impl<V: AggregateFnVTable> Accumulator<V> {
40    pub fn try_new(vtable: V, options: V::Options, dtype: DType) -> VortexResult<Self> {
41        let return_dtype = vtable.return_dtype(&options, &dtype).ok_or_else(|| {
42            vortex_err!(
43                "Aggregate function {} cannot be applied to dtype {}",
44                vtable.id(),
45                dtype
46            )
47        })?;
48        let partial_dtype = vtable.partial_dtype(&options, &dtype).ok_or_else(|| {
49            vortex_err!(
50                "Aggregate function {} cannot be applied to dtype {}",
51                vtable.id(),
52                dtype
53            )
54        })?;
55        let partial = vtable.empty_partial(&options, &dtype)?;
56        let aggregate_fn = AggregateFn::new(vtable.clone(), options).erased();
57
58        Ok(Self {
59            vtable,
60            aggregate_fn,
61            dtype,
62            return_dtype,
63            partial_dtype,
64            partial,
65        })
66    }
67}
68
69/// A trait object for type-erased accumulators, used for dynamic dispatch when the aggregate
70/// function is not known at compile time.
71pub trait DynAccumulator: 'static + Send {
72    /// Accumulate a new array into the accumulator's state.
73    fn accumulate(&mut self, batch: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<()>;
74
75    /// Whether the accumulator's result is fully determined.
76    fn is_saturated(&self) -> bool;
77
78    /// Flush the accumulation state and return the partial aggregate result as a scalar.
79    ///
80    /// Resets the accumulator state back to the initial state.
81    fn flush(&mut self) -> VortexResult<Scalar>;
82
83    /// Finish the accumulation and return the final aggregate result as a scalar.
84    ///
85    /// Resets the accumulator state back to the initial state.
86    fn finish(&mut self) -> VortexResult<Scalar>;
87}
88
89impl<V: AggregateFnVTable> DynAccumulator for Accumulator<V> {
90    fn accumulate(&mut self, batch: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<()> {
91        if self.is_saturated() {
92            return Ok(());
93        }
94
95        vortex_ensure!(
96            batch.dtype() == &self.dtype,
97            "Input DType mismatch: expected {}, got {}",
98            self.dtype,
99            batch.dtype()
100        );
101
102        // Allow the vtable to short-circuit on the raw array before decompression.
103        if self.vtable.try_accumulate(&mut self.partial, batch, ctx)? {
104            return Ok(());
105        }
106
107        let session = ctx.session().clone();
108        let kernels = &session.aggregate_fns().kernels;
109
110        let mut batch = batch.clone();
111        for _ in 0..*MAX_ITERATIONS {
112            if batch.is::<AnyCanonical>() {
113                break;
114            }
115
116            let kernels_r = kernels.read();
117            let batch_id = batch.encoding_id();
118            if let Some(result) = kernels_r
119                .get(&(batch_id.clone(), Some(self.aggregate_fn.id())))
120                .or_else(|| kernels_r.get(&(batch_id, None)))
121                .and_then(|kernel| {
122                    kernel
123                        .aggregate(&self.aggregate_fn, &batch, ctx)
124                        .transpose()
125                })
126                .transpose()?
127            {
128                vortex_ensure!(
129                    result.dtype() == &self.partial_dtype,
130                    "Aggregate kernel returned {}, expected {}",
131                    result.dtype(),
132                    self.partial_dtype,
133                );
134                self.vtable.combine_partials(&mut self.partial, result)?;
135                return Ok(());
136            }
137
138            // Execute one step and try again
139            batch = batch.execute(ctx)?;
140        }
141
142        // Otherwise, execute the batch until it is columnar and accumulate it into the state.
143        let columnar = batch.execute::<Columnar>(ctx)?;
144
145        self.vtable.accumulate(&mut self.partial, &columnar, ctx)
146    }
147
148    fn is_saturated(&self) -> bool {
149        self.vtable.is_saturated(&self.partial)
150    }
151
152    fn flush(&mut self) -> VortexResult<Scalar> {
153        let partial = self.vtable.to_scalar(&self.partial)?;
154        self.vtable.reset(&mut self.partial);
155
156        #[cfg(debug_assertions)]
157        {
158            vortex_ensure!(
159                partial.dtype() == &self.partial_dtype,
160                "Aggregate kernel returned incorrect DType on flush: expected {}, got {}",
161                self.partial_dtype,
162                partial.dtype(),
163            );
164        }
165
166        Ok(partial)
167    }
168
169    fn finish(&mut self) -> VortexResult<Scalar> {
170        let result = self.vtable.finalize_scalar(&self.partial)?;
171        self.vtable.reset(&mut self.partial);
172
173        vortex_ensure!(
174            result.dtype() == &self.return_dtype,
175            "Aggregate kernel returned incorrect DType on finalize: expected {}, got {}",
176            self.return_dtype,
177            result.dtype(),
178        );
179
180        Ok(result)
181    }
182}