Skip to main content

vortex_array/aggregate_fn/
accumulator_grouped.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use arrow_buffer::ArrowNativeType;
5use vortex_buffer::Buffer;
6use vortex_error::VortexExpect;
7use vortex_error::VortexResult;
8use vortex_error::vortex_bail;
9use vortex_error::vortex_ensure;
10use vortex_error::vortex_panic;
11use vortex_mask::Mask;
12use vortex_session::VortexSession;
13
14use crate::AnyCanonical;
15use crate::ArrayRef;
16use crate::Canonical;
17use crate::Columnar;
18use crate::DynArray;
19use crate::ExecutionCtx;
20use crate::IntoArray;
21use crate::VortexSessionExecute;
22use crate::aggregate_fn::Accumulator;
23use crate::aggregate_fn::AggregateFn;
24use crate::aggregate_fn::AggregateFnRef;
25use crate::aggregate_fn::AggregateFnVTable;
26use crate::aggregate_fn::DynAccumulator;
27use crate::aggregate_fn::session::AggregateFnSessionExt;
28use crate::arrays::ChunkedArray;
29use crate::arrays::FixedSizeListArray;
30use crate::arrays::ListViewArray;
31use crate::builders::builder_with_capacity;
32use crate::builtins::ArrayBuiltins;
33use crate::dtype::DType;
34use crate::dtype::IntegerPType;
35use crate::executor::MAX_ITERATIONS;
36use crate::match_each_integer_ptype;
37use crate::vtable::ValidityHelper;
38
39/// Reference-counted type-erased grouped accumulator.
40pub type GroupedAccumulatorRef = Box<dyn DynGroupedAccumulator>;
41
42/// An accumulator used for computing grouped aggregates.
43///
44/// Note that the groups must be processed in order, and the accumulator does not support random
45/// access to groups.
46pub struct GroupedAccumulator<V: AggregateFnVTable> {
47    /// The vtable of the aggregate function.
48    vtable: V,
49    /// The options of the aggregate function.
50    options: V::Options,
51    /// Type-erased aggregate function used for kernel dispatch.
52    aggregate_fn: AggregateFnRef,
53    /// The DType of the input.
54    dtype: DType,
55    /// The DType of the aggregate.
56    return_dtype: DType,
57    /// The DType of the partial accumulator state.
58    partial_dtype: DType,
59    /// The accumulated state for prior batches of groups.
60    partials: Vec<ArrayRef>,
61    /// A session used to lookup custom aggregate kernels.
62    session: VortexSession,
63}
64
65impl<V: AggregateFnVTable> GroupedAccumulator<V> {
66    pub fn try_new(
67        vtable: V,
68        options: V::Options,
69        dtype: DType,
70        session: VortexSession,
71    ) -> VortexResult<Self> {
72        let aggregate_fn = AggregateFn::new(vtable.clone(), options.clone()).erased();
73        let return_dtype = vtable.return_dtype(&options, &dtype)?;
74        let partial_dtype = vtable.partial_dtype(&options, &dtype)?;
75
76        Ok(Self {
77            vtable,
78            options,
79            aggregate_fn,
80            dtype,
81            return_dtype,
82            partial_dtype,
83            partials: vec![],
84            session,
85        })
86    }
87}
88
89/// A trait object for type-erased grouped accumulators, used for dynamic dispatch when the aggregate
90/// function is not known at compile time.
91pub trait DynGroupedAccumulator: 'static + Send {
92    /// Accumulate a list of groups into the accumulator.
93    fn accumulate_list(&mut self, groups: &ArrayRef) -> VortexResult<()>;
94
95    /// Finish the accumulation and return the partial aggregate results for all groups.
96    /// Resets the accumulator state for the next round of accumulation.
97    fn flush(&mut self) -> VortexResult<ArrayRef>;
98
99    /// Finish the accumulation and return the final aggregate results for all groups.
100    /// Resets the accumulator state for the next round of accumulation.
101    fn finish(&mut self) -> VortexResult<ArrayRef>;
102}
103
104impl<V: AggregateFnVTable> DynGroupedAccumulator for GroupedAccumulator<V> {
105    fn accumulate_list(&mut self, groups: &ArrayRef) -> VortexResult<()> {
106        let elements_dtype = match groups.dtype() {
107            DType::List(elem, _) => elem,
108            DType::FixedSizeList(elem, ..) => elem,
109            _ => vortex_bail!(
110                "Input DType mismatch: expected List or FixedSizeList, got {}",
111                groups.dtype()
112            ),
113        };
114        vortex_ensure!(
115            elements_dtype.as_ref() == &self.dtype,
116            "Input DType mismatch: expected {}, got {}",
117            self.dtype,
118            elements_dtype
119        );
120
121        let mut ctx = self.session.create_execution_ctx();
122
123        // We first execute the groups until it is a ListView or FixedSizeList, since we only
124        // dispatch the aggregate kernel over the elements of these arrays.
125        let canonical = match groups.clone().execute::<Columnar>(&mut ctx)? {
126            Columnar::Canonical(c) => c,
127            Columnar::Constant(c) => c.into_array().execute::<Canonical>(&mut ctx)?,
128        };
129        match canonical {
130            Canonical::List(groups) => self.accumulate_list_view(&groups, &mut ctx),
131            Canonical::FixedSizeList(groups) => self.accumulate_fixed_size_list(&groups, &mut ctx),
132            _ => vortex_panic!("We checked the DType above, so this should never happen"),
133        }
134    }
135
136    fn flush(&mut self) -> VortexResult<ArrayRef> {
137        let states = std::mem::take(&mut self.partials);
138        Ok(ChunkedArray::try_new(states, self.partial_dtype.clone())?.into_array())
139    }
140
141    fn finish(&mut self) -> VortexResult<ArrayRef> {
142        let states = self.flush()?;
143        let results = self.vtable.finalize(states)?;
144
145        vortex_ensure!(
146            results.dtype() == &self.return_dtype,
147            "Return DType mismatch: expected {}, got {}",
148            self.return_dtype,
149            results.dtype()
150        );
151
152        Ok(results)
153    }
154}
155
156impl<V: AggregateFnVTable> GroupedAccumulator<V> {
157    fn accumulate_list_view(
158        &mut self,
159        groups: &ListViewArray,
160        ctx: &mut ExecutionCtx,
161    ) -> VortexResult<()> {
162        let mut elements = groups.elements().clone();
163        let session = self.session.clone();
164
165        let kernels = &session.aggregate_fns().grouped_kernels;
166
167        for _ in 0..*MAX_ITERATIONS {
168            if elements.is::<AnyCanonical>() {
169                break;
170            }
171
172            let kernels_r = kernels.read();
173            if let Some(result) = kernels_r
174                .get(&(elements.encoding_id(), Some(self.aggregate_fn.id())))
175                .or_else(|| kernels_r.get(&(elements.encoding_id(), None)))
176                .and_then(|kernel| {
177                    // SAFETY: we assume that elements execution is safe
178                    let groups = unsafe {
179                        ListViewArray::new_unchecked(
180                            elements.clone(),
181                            groups.offsets().clone(),
182                            groups.sizes().clone(),
183                            groups.validity().clone(),
184                        )
185                    };
186                    kernel
187                        .grouped_aggregate(&self.aggregate_fn, &groups)
188                        .transpose()
189                })
190                .transpose()?
191            {
192                return self.push_result(result);
193            }
194
195            // Execute one step and try again
196            elements = elements.execute(ctx)?;
197        }
198
199        // Otherwise, we iterate the offsets and sizes and accumulate each group one by one.
200        let elements = elements.execute::<Columnar>(ctx)?.into_array();
201        let offsets = groups.offsets();
202        let sizes = groups.sizes().cast(offsets.dtype().clone())?;
203        let validity = groups.validity().to_mask(offsets.len());
204
205        match_each_integer_ptype!(offsets.dtype().as_ptype(), |O| {
206            let offsets = offsets.clone().execute::<Buffer<O>>(ctx)?;
207            let sizes = sizes.execute::<Buffer<O>>(ctx)?;
208            self.accumulate_list_view_typed(&elements, offsets.as_ref(), sizes.as_ref(), &validity)
209        })
210    }
211
212    fn accumulate_list_view_typed<O: IntegerPType>(
213        &mut self,
214        elements: &ArrayRef,
215        offsets: &[O],
216        sizes: &[O],
217        validity: &Mask,
218    ) -> VortexResult<()> {
219        let mut accumulator = Accumulator::try_new(
220            self.vtable.clone(),
221            self.options.clone(),
222            self.dtype.clone(),
223            self.session.clone(),
224        )?;
225        let mut states = builder_with_capacity(&self.partial_dtype, offsets.len());
226
227        for (offset, size) in offsets.iter().zip(sizes.iter()) {
228            let offset = offset.to_usize().vortex_expect("Offset value is not usize");
229            let size = size.to_usize().vortex_expect("Size value is not usize");
230
231            if validity.value(offset) {
232                let group = elements.slice(offset..offset + size)?;
233                accumulator.accumulate(&group)?;
234                states.append_scalar(&accumulator.finish()?)?;
235            } else {
236                states.append_null()
237            }
238        }
239
240        self.push_result(states.finish())
241    }
242
243    fn accumulate_fixed_size_list(
244        &mut self,
245        groups: &FixedSizeListArray,
246        ctx: &mut ExecutionCtx,
247    ) -> VortexResult<()> {
248        let mut elements = groups.elements().clone();
249
250        let session = self.session.clone();
251        let kernels = &session.aggregate_fns().grouped_kernels;
252
253        for _ in 0..64 {
254            if elements.is::<AnyCanonical>() {
255                break;
256            }
257
258            let kernels_r = kernels.read();
259            if let Some(result) = kernels_r
260                .get(&(elements.encoding_id(), Some(self.aggregate_fn.id())))
261                .or_else(|| kernels_r.get(&(elements.encoding_id(), None)))
262                .and_then(|kernel| {
263                    // SAFETY: we assume that elements execution is safe
264                    let groups = unsafe {
265                        FixedSizeListArray::new_unchecked(
266                            elements.clone(),
267                            groups.list_size(),
268                            groups.validity().clone(),
269                            groups.len(),
270                        )
271                    };
272
273                    kernel
274                        .grouped_aggregate_fixed_size(&self.aggregate_fn, &groups)
275                        .transpose()
276                })
277                .transpose()?
278            {
279                return self.push_result(result);
280            }
281
282            // Execute one step and try again
283            elements = elements.execute(ctx)?;
284        }
285
286        // Otherwise, we iterate the offsets and sizes and accumulate each group one by one.
287        let elements = elements.execute::<Columnar>(ctx)?.into_array();
288        let validity = groups.validity().to_mask(groups.len());
289
290        let mut accumulator = Accumulator::try_new(
291            self.vtable.clone(),
292            self.options.clone(),
293            self.dtype.clone(),
294            self.session.clone(),
295        )?;
296        let mut states = builder_with_capacity(&self.partial_dtype, groups.len());
297
298        let mut offset = 0;
299        let size = groups
300            .list_size()
301            .to_usize()
302            .vortex_expect("List size is not usize");
303
304        for i in 0..groups.len() {
305            if validity.value(i) {
306                let group = elements.slice(offset..offset + size)?;
307                accumulator.accumulate(&group)?;
308                states.append_scalar(&accumulator.finish()?)?;
309            } else {
310                states.append_null()
311            }
312            offset += size;
313        }
314
315        self.push_result(states.finish())
316    }
317
318    fn push_result(&mut self, state: ArrayRef) -> VortexResult<()> {
319        vortex_ensure!(
320            state.dtype() == &self.partial_dtype,
321            "State DType mismatch: expected {}, got {}",
322            self.partial_dtype,
323            state.dtype()
324        );
325        self.partials.push(state);
326        Ok(())
327    }
328}