Skip to main content

vortex_array/aggregate_fn/
accumulator_grouped.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use arrow_buffer::ArrowNativeType;
5use vortex_buffer::Buffer;
6use vortex_error::VortexExpect;
7use vortex_error::VortexResult;
8use vortex_error::vortex_bail;
9use vortex_error::vortex_ensure;
10use vortex_error::vortex_panic;
11use vortex_mask::Mask;
12use vortex_session::VortexSession;
13
14use crate::AnyCanonical;
15use crate::ArrayRef;
16use crate::Canonical;
17use crate::DynArray;
18use crate::ExecutionCtx;
19use crate::IntoArray;
20use crate::VortexSessionExecute;
21use crate::aggregate_fn::Accumulator;
22use crate::aggregate_fn::AggregateFn;
23use crate::aggregate_fn::AggregateFnRef;
24use crate::aggregate_fn::AggregateFnVTable;
25use crate::aggregate_fn::DynAccumulator;
26use crate::aggregate_fn::session::AggregateFnSessionExt;
27use crate::arrays::ChunkedArray;
28use crate::arrays::FixedSizeListArray;
29use crate::arrays::ListViewArray;
30use crate::builders::builder_with_capacity;
31use crate::builtins::ArrayBuiltins;
32use crate::dtype::DType;
33use crate::dtype::IntegerPType;
34use crate::executor::MAX_ITERATIONS;
35use crate::match_each_integer_ptype;
36use crate::vtable::ValidityHelper;
37
38/// Reference-counted type-erased grouped accumulator.
39pub type GroupedAccumulatorRef = Box<dyn DynGroupedAccumulator>;
40
41/// An accumulator used for computing grouped aggregates.
42///
43/// Note that the groups must be processed in order, and the accumulator does not support random
44/// access to groups.
45pub struct GroupedAccumulator<V: AggregateFnVTable> {
46    /// The vtable of the aggregate function.
47    vtable: V,
48    /// The options of the aggregate function.
49    options: V::Options,
50    /// Type-erased aggregate function used for kernel dispatch.
51    aggregate_fn: AggregateFnRef,
52    /// The DType of the input.
53    dtype: DType,
54    /// The DType of the aggregate.
55    return_dtype: DType,
56    /// The DType of the partial accumulator state.
57    partial_dtype: DType,
58    /// The accumulated state for prior batches of groups.
59    partials: Vec<ArrayRef>,
60    /// A session used to lookup custom aggregate kernels.
61    session: VortexSession,
62}
63
64impl<V: AggregateFnVTable> GroupedAccumulator<V> {
65    pub fn try_new(
66        vtable: V,
67        options: V::Options,
68        dtype: DType,
69        session: VortexSession,
70    ) -> VortexResult<Self> {
71        let aggregate_fn = AggregateFn::new(vtable.clone(), options.clone()).erased();
72        let return_dtype = vtable.return_dtype(&options, &dtype)?;
73        let partial_dtype = vtable.partial_dtype(&options, &dtype)?;
74
75        Ok(Self {
76            vtable,
77            options,
78            aggregate_fn,
79            dtype,
80            return_dtype,
81            partial_dtype,
82            partials: vec![],
83            session,
84        })
85    }
86}
87
88/// A trait object for type-erased grouped accumulators, used for dynamic dispatch when the aggregate
89/// function is not known at compile time.
90pub trait DynGroupedAccumulator: 'static + Send {
91    /// Accumulate a list of groups into the accumulator.
92    fn accumulate_list(&mut self, groups: &ArrayRef) -> VortexResult<()>;
93
94    /// Finish the accumulation and return the partial aggregate results for all groups.
95    /// Resets the accumulator state for the next round of accumulation.
96    fn flush(&mut self) -> VortexResult<ArrayRef>;
97
98    /// Finish the accumulation and return the final aggregate results for all groups.
99    /// Resets the accumulator state for the next round of accumulation.
100    fn finish(&mut self) -> VortexResult<ArrayRef>;
101}
102
103impl<V: AggregateFnVTable> DynGroupedAccumulator for GroupedAccumulator<V> {
104    fn accumulate_list(&mut self, groups: &ArrayRef) -> VortexResult<()> {
105        let elements_dtype = match groups.dtype() {
106            DType::List(elem, _) => elem,
107            DType::FixedSizeList(elem, ..) => elem,
108            _ => vortex_bail!(
109                "Input DType mismatch: expected List or FixedSizeList, got {}",
110                groups.dtype()
111            ),
112        };
113        vortex_ensure!(
114            elements_dtype.as_ref() == &self.dtype,
115            "Input DType mismatch: expected {}, got {}",
116            self.dtype,
117            elements_dtype
118        );
119
120        let mut ctx = self.session.create_execution_ctx();
121
122        // We first execute the groups until it is a ListView or FixedSizeList, since we only
123        // dispatch the aggregate kernel over the elements of these arrays.
124        match groups.clone().execute::<Canonical>(&mut ctx)? {
125            Canonical::List(groups) => self.accumulate_list_view(&groups, &mut ctx),
126            Canonical::FixedSizeList(groups) => self.accumulate_fixed_size_list(&groups, &mut ctx),
127            _ => vortex_panic!("We checked the DType above, so this should never happen"),
128        }
129    }
130
131    fn flush(&mut self) -> VortexResult<ArrayRef> {
132        let states = std::mem::take(&mut self.partials);
133        Ok(ChunkedArray::try_new(states, self.partial_dtype.clone())?.into_array())
134    }
135
136    fn finish(&mut self) -> VortexResult<ArrayRef> {
137        let states = self.flush()?;
138        let results = self.vtable.finalize(states)?;
139
140        vortex_ensure!(
141            results.dtype() == &self.return_dtype,
142            "Return DType mismatch: expected {}, got {}",
143            self.return_dtype,
144            results.dtype()
145        );
146
147        Ok(results)
148    }
149}
150
151impl<V: AggregateFnVTable> GroupedAccumulator<V> {
152    fn accumulate_list_view(
153        &mut self,
154        groups: &ListViewArray,
155        ctx: &mut ExecutionCtx,
156    ) -> VortexResult<()> {
157        let mut elements = groups.elements().clone();
158        let session = self.session.clone();
159
160        let kernels = &session.aggregate_fns().grouped_kernels;
161
162        for _ in 0..*MAX_ITERATIONS {
163            if elements.is::<AnyCanonical>() {
164                break;
165            }
166
167            let kernels_r = kernels.read();
168            if let Some(result) = kernels_r
169                .get(&(elements.encoding_id(), Some(self.aggregate_fn.id())))
170                .or_else(|| kernels_r.get(&(elements.encoding_id(), None)))
171                .and_then(|kernel| {
172                    // SAFETY: we assume that elements execution is safe
173                    let groups = unsafe {
174                        ListViewArray::new_unchecked(
175                            elements.clone(),
176                            groups.offsets().clone(),
177                            groups.sizes().clone(),
178                            groups.validity().clone(),
179                        )
180                    };
181                    kernel
182                        .grouped_aggregate(&self.aggregate_fn, &groups)
183                        .transpose()
184                })
185                .transpose()?
186            {
187                return self.push_result(result);
188            }
189
190            // Execute one step and try again
191            elements = elements.execute(ctx)?;
192        }
193
194        // Otherwise, we iterate the offsets and sizes and accumulate each group one by one.
195        let elements = elements.execute::<Canonical>(ctx)?.into_array();
196        let offsets = groups.offsets();
197        let sizes = groups.sizes().cast(offsets.dtype().clone())?;
198        let validity = groups.validity().to_mask(offsets.len());
199
200        match_each_integer_ptype!(offsets.dtype().as_ptype(), |O| {
201            let offsets = offsets.clone().execute::<Buffer<O>>(ctx)?;
202            let sizes = sizes.execute::<Buffer<O>>(ctx)?;
203            self.accumulate_list_view_typed(&elements, offsets.as_ref(), sizes.as_ref(), &validity)
204        })
205    }
206
207    fn accumulate_list_view_typed<O: IntegerPType>(
208        &mut self,
209        elements: &ArrayRef,
210        offsets: &[O],
211        sizes: &[O],
212        validity: &Mask,
213    ) -> VortexResult<()> {
214        let mut accumulator = Accumulator::try_new(
215            self.vtable.clone(),
216            self.options.clone(),
217            self.dtype.clone(),
218            self.session.clone(),
219        )?;
220        let mut states = builder_with_capacity(&self.partial_dtype, offsets.len());
221
222        for (offset, size) in offsets.iter().zip(sizes.iter()) {
223            let offset = offset.to_usize().vortex_expect("Offset value is not usize");
224            let size = size.to_usize().vortex_expect("Size value is not usize");
225
226            if validity.value(offset) {
227                let group = elements.slice(offset..offset + size)?;
228                accumulator.accumulate(&group)?;
229                states.append_scalar(&accumulator.finish()?)?;
230            } else {
231                states.append_null()
232            }
233        }
234
235        self.push_result(states.finish())
236    }
237
238    fn accumulate_fixed_size_list(
239        &mut self,
240        groups: &FixedSizeListArray,
241        ctx: &mut ExecutionCtx,
242    ) -> VortexResult<()> {
243        let mut elements = groups.elements().clone();
244
245        let session = self.session.clone();
246        let kernels = &session.aggregate_fns().grouped_kernels;
247
248        for _ in 0..64 {
249            if elements.is::<AnyCanonical>() {
250                break;
251            }
252
253            let kernels_r = kernels.read();
254            if let Some(result) = kernels_r
255                .get(&(elements.encoding_id(), Some(self.aggregate_fn.id())))
256                .or_else(|| kernels_r.get(&(elements.encoding_id(), None)))
257                .and_then(|kernel| {
258                    // SAFETY: we assume that elements execution is safe
259                    let groups = unsafe {
260                        FixedSizeListArray::new_unchecked(
261                            elements.clone(),
262                            groups.list_size(),
263                            groups.validity().clone(),
264                            groups.len(),
265                        )
266                    };
267
268                    kernel
269                        .grouped_aggregate_fixed_size(&self.aggregate_fn, &groups)
270                        .transpose()
271                })
272                .transpose()?
273            {
274                return self.push_result(result);
275            }
276
277            // Execute one step and try again
278            elements = elements.execute(ctx)?;
279        }
280
281        // Otherwise, we iterate the offsets and sizes and accumulate each group one by one.
282        let elements = elements.execute::<Canonical>(ctx)?.into_array();
283        let validity = groups.validity().to_mask(groups.len());
284
285        let mut accumulator = Accumulator::try_new(
286            self.vtable.clone(),
287            self.options.clone(),
288            self.dtype.clone(),
289            self.session.clone(),
290        )?;
291        let mut states = builder_with_capacity(&self.partial_dtype, groups.len());
292
293        let mut offset = 0;
294        let size = groups
295            .list_size()
296            .to_usize()
297            .vortex_expect("List size is not usize");
298
299        for i in 0..groups.len() {
300            if validity.value(i) {
301                let group = elements.slice(offset..offset + size)?;
302                accumulator.accumulate(&group)?;
303                states.append_scalar(&accumulator.finish()?)?;
304            } else {
305                states.append_null()
306            }
307            offset += size;
308        }
309
310        self.push_result(states.finish())
311    }
312
313    fn push_result(&mut self, state: ArrayRef) -> VortexResult<()> {
314        vortex_ensure!(
315            state.dtype() == &self.partial_dtype,
316            "State DType mismatch: expected {}, got {}",
317            self.partial_dtype,
318            state.dtype()
319        );
320        self.partials.push(state);
321        Ok(())
322    }
323}