Skip to main content

vortex_array/aggregate_fn/
accumulator.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexResult;
5use vortex_error::vortex_ensure;
6use vortex_error::vortex_err;
7
8use crate::AnyCanonical;
9use crate::ArrayRef;
10use crate::Columnar;
11use crate::DynArray;
12use crate::ExecutionCtx;
13use crate::aggregate_fn::AggregateFn;
14use crate::aggregate_fn::AggregateFnRef;
15use crate::aggregate_fn::AggregateFnVTable;
16use crate::aggregate_fn::session::AggregateFnSessionExt;
17use crate::dtype::DType;
18use crate::executor::MAX_ITERATIONS;
19use crate::scalar::Scalar;
20
21/// Reference-counted type-erased accumulator.
22pub type AccumulatorRef = Box<dyn DynAccumulator>;
23
24/// An accumulator used for computing aggregates over an entire stream of arrays.
25pub struct Accumulator<V: AggregateFnVTable> {
26    /// The vtable of the aggregate function.
27    vtable: V,
28    /// Type-erased aggregate function used for kernel dispatch.
29    aggregate_fn: AggregateFnRef,
30    /// The DType of the input.
31    dtype: DType,
32    /// The DType of the aggregate.
33    return_dtype: DType,
34    /// The DType of the accumulator state.
35    partial_dtype: DType,
36    /// The partial state of the accumulator, updated after each accumulate/merge call.
37    partial: V::Partial,
38}
39
40impl<V: AggregateFnVTable> Accumulator<V> {
41    pub fn try_new(vtable: V, options: V::Options, dtype: DType) -> VortexResult<Self> {
42        let return_dtype = vtable.return_dtype(&options, &dtype).ok_or_else(|| {
43            vortex_err!(
44                "Aggregate function {} cannot be applied to dtype {}",
45                vtable.id(),
46                dtype
47            )
48        })?;
49        let partial_dtype = vtable.partial_dtype(&options, &dtype).ok_or_else(|| {
50            vortex_err!(
51                "Aggregate function {} cannot be applied to dtype {}",
52                vtable.id(),
53                dtype
54            )
55        })?;
56        let partial = vtable.empty_partial(&options, &dtype)?;
57        let aggregate_fn = AggregateFn::new(vtable.clone(), options).erased();
58
59        Ok(Self {
60            vtable,
61            aggregate_fn,
62            dtype,
63            return_dtype,
64            partial_dtype,
65            partial,
66        })
67    }
68}
69
70/// A trait object for type-erased accumulators, used for dynamic dispatch when the aggregate
71/// function is not known at compile time.
72pub trait DynAccumulator: 'static + Send {
73    /// Accumulate a new array into the accumulator's state.
74    fn accumulate(&mut self, batch: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<()>;
75
76    /// Whether the accumulator's result is fully determined.
77    fn is_saturated(&self) -> bool;
78
79    /// Flush the accumulation state and return the partial aggregate result as a scalar.
80    ///
81    /// Resets the accumulator state back to the initial state.
82    fn flush(&mut self) -> VortexResult<Scalar>;
83
84    /// Finish the accumulation and return the final aggregate result as a scalar.
85    ///
86    /// Resets the accumulator state back to the initial state.
87    fn finish(&mut self) -> VortexResult<Scalar>;
88}
89
90impl<V: AggregateFnVTable> DynAccumulator for Accumulator<V> {
91    fn accumulate(&mut self, batch: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<()> {
92        if self.is_saturated() {
93            return Ok(());
94        }
95
96        vortex_ensure!(
97            batch.dtype() == &self.dtype,
98            "Input DType mismatch: expected {}, got {}",
99            self.dtype,
100            batch.dtype()
101        );
102
103        let session = ctx.session().clone();
104        let kernels = &session.aggregate_fns().kernels;
105
106        let mut batch = batch.clone();
107        for _ in 0..*MAX_ITERATIONS {
108            if batch.is::<AnyCanonical>() {
109                break;
110            }
111
112            let kernels_r = kernels.read();
113            let batch_id = batch.encoding_id();
114            if let Some(result) = kernels_r
115                .get(&(batch_id.clone(), Some(self.aggregate_fn.id())))
116                .or_else(|| kernels_r.get(&(batch_id, None)))
117                .and_then(|kernel| {
118                    kernel
119                        .aggregate(&self.aggregate_fn, &batch, ctx)
120                        .transpose()
121                })
122                .transpose()?
123            {
124                vortex_ensure!(
125                    result.dtype() == &self.partial_dtype,
126                    "Aggregate kernel returned {}, expected {}",
127                    result.dtype(),
128                    self.partial_dtype,
129                );
130                self.vtable.combine_partials(&mut self.partial, result)?;
131                return Ok(());
132            }
133
134            // Execute one step and try again
135            batch = batch.execute(ctx)?;
136        }
137
138        // Otherwise, execute the batch until it is columnar and accumulate it into the state.
139        let columnar = batch.execute::<Columnar>(ctx)?;
140
141        self.vtable.accumulate(&mut self.partial, &columnar, ctx)
142    }
143
144    fn is_saturated(&self) -> bool {
145        self.vtable.is_saturated(&self.partial)
146    }
147
148    fn flush(&mut self) -> VortexResult<Scalar> {
149        let partial = self.vtable.to_scalar(&self.partial)?;
150        self.vtable.reset(&mut self.partial);
151
152        #[cfg(debug_assertions)]
153        {
154            vortex_ensure!(
155                partial.dtype() == &self.partial_dtype,
156                "Aggregate kernel returned incorrect DType on flush: expected {}, got {}",
157                self.partial_dtype,
158                partial.dtype(),
159            );
160        }
161
162        Ok(partial)
163    }
164
165    fn finish(&mut self) -> VortexResult<Scalar> {
166        let result = self.vtable.finalize_scalar(&self.partial)?;
167        self.vtable.reset(&mut self.partial);
168
169        vortex_ensure!(
170            result.dtype() == &self.return_dtype,
171            "Aggregate kernel returned incorrect DType on finalize: expected {}, got {}",
172            self.return_dtype,
173            result.dtype(),
174        );
175
176        Ok(result)
177    }
178}