Skip to main content

vortex_array/aggregate_fn/
accumulator.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexResult;
5use vortex_error::vortex_ensure;
6use vortex_session::VortexSession;
7
8use crate::AnyCanonical;
9use crate::ArrayRef;
10use crate::Canonical;
11use crate::DynArray;
12use crate::VortexSessionExecute;
13use crate::aggregate_fn::AggregateFn;
14use crate::aggregate_fn::AggregateFnRef;
15use crate::aggregate_fn::AggregateFnVTable;
16use crate::aggregate_fn::session::AggregateFnSessionExt;
17use crate::dtype::DType;
18use crate::executor::MAX_ITERATIONS;
19use crate::scalar::Scalar;
20
21/// Owned (boxed), type-erased accumulator.
22pub type AccumulatorRef = Box<dyn DynAccumulator>;
23
24/// An accumulator used for computing aggregates over an entire stream of arrays.
25pub struct Accumulator<V: AggregateFnVTable> {
26    /// The vtable of the aggregate function.
27    vtable: V,
28    /// Type-erased aggregate function used for kernel dispatch.
29    aggregate_fn: AggregateFnRef,
30    /// The DType of the input.
31    dtype: DType,
32    /// The DType of the aggregate.
33    return_dtype: DType,
34    /// The DType of the accumulator state.
35    partial_dtype: DType,
36    /// The partial state of the accumulator, updated after each accumulate/merge call.
37    partial: V::Partial,
38    /// A session used to lookup custom aggregate kernels.
39    session: VortexSession,
40}
41
42impl<V: AggregateFnVTable> Accumulator<V> {
43    pub fn try_new(
44        vtable: V,
45        options: V::Options,
46        dtype: DType,
47        session: VortexSession,
48    ) -> VortexResult<Self> {
49        let return_dtype = vtable.return_dtype(&options, &dtype)?;
50        let partial_dtype = vtable.partial_dtype(&options, &dtype)?;
51        let partial = vtable.empty_partial(&options, &dtype)?;
52        let aggregate_fn = AggregateFn::new(vtable.clone(), options).erased();
53
54        Ok(Self {
55            vtable,
56            aggregate_fn,
57            dtype,
58            return_dtype,
59            partial_dtype,
60            partial,
61            session,
62        })
63    }
64}
65
66/// A trait object for type-erased accumulators, used for dynamic dispatch when the aggregate
67/// function is not known at compile time.
68pub trait DynAccumulator: 'static + Send {
69    /// Accumulate a new array into the accumulator's state.
70    fn accumulate(&mut self, batch: &ArrayRef) -> VortexResult<()>;
71
72    /// Whether the accumulator's result is fully determined.
73    fn is_saturated(&self) -> bool;
74
75    /// Flush the accumulation state and return the partial aggregate result as a scalar.
76    ///
77    /// Resets the accumulator state back to the initial state.
78    fn flush(&mut self) -> VortexResult<Scalar>;
79
80    /// Finish the accumulation and return the final aggregate result as a scalar.
81    ///
82    /// Resets the accumulator state back to the initial state.
83    fn finish(&mut self) -> VortexResult<Scalar>;
84}
85
86impl<V: AggregateFnVTable> DynAccumulator for Accumulator<V> {
87    fn accumulate(&mut self, batch: &ArrayRef) -> VortexResult<()> {
88        if self.is_saturated() {
89            return Ok(());
90        }
91
92        vortex_ensure!(
93            batch.dtype() == &self.dtype,
94            "Input DType mismatch: expected {}, got {}",
95            self.dtype,
96            batch.dtype()
97        );
98
99        let kernels = &self.session.aggregate_fns().kernels;
100
101        let mut ctx = self.session.create_execution_ctx();
102        let mut batch = batch.clone();
103        for _ in 0..*MAX_ITERATIONS {
104            if batch.is::<AnyCanonical>() {
105                break;
106            }
107
108            let kernels_r = kernels.read();
109            let batch_id = batch.encoding_id();
110            if let Some(result) = kernels_r
111                .get(&(batch_id.clone(), Some(self.aggregate_fn.id())))
112                .or_else(|| kernels_r.get(&(batch_id, None)))
113                .and_then(|kernel| {
114                    kernel
115                        .aggregate(&self.aggregate_fn, &batch, &mut ctx)
116                        .transpose()
117                })
118                .transpose()?
119            {
120                vortex_ensure!(
121                    result.dtype() == &self.partial_dtype,
122                    "Aggregate kernel returned {}, expected {}",
123                    result.dtype(),
124                    self.partial_dtype,
125                );
126                self.vtable.combine_partials(&mut self.partial, result)?;
127                return Ok(());
128            }
129
130            // Execute one step and try again
131            batch = batch.execute(&mut ctx)?;
132        }
133
134        // Otherwise, execute the batch until it is canonical and accumulate it into the state.
135        let canonical = batch.execute::<Canonical>(&mut ctx)?;
136
137        self.vtable
138            .accumulate(&mut self.partial, &canonical, &mut ctx)
139    }
140
141    fn is_saturated(&self) -> bool {
142        self.vtable.is_saturated(&self.partial)
143    }
144
145    fn flush(&mut self) -> VortexResult<Scalar> {
146        let partial = self.vtable.flush(&mut self.partial)?;
147
148        #[cfg(debug_assertions)]
149        {
150            vortex_ensure!(
151                partial.dtype() == &self.partial_dtype,
152                "Aggregate kernel returned incorrect DType on flush: expected {}, got {}",
153                self.partial_dtype,
154                partial.dtype(),
155            );
156        }
157
158        Ok(partial)
159    }
160
161    fn finish(&mut self) -> VortexResult<Scalar> {
162        let partial = self.flush()?;
163        let result = self.vtable.finalize_scalar(partial)?;
164
165        vortex_ensure!(
166            result.dtype() == &self.return_dtype,
167            "Aggregate kernel returned incorrect DType on finalize: expected {}, got {}",
168            self.return_dtype,
169            result.dtype(),
170        );
171
172        Ok(result)
173    }
174}