Skip to main content

vortex_array/aggregate_fn/
accumulator_grouped.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use arrow_buffer::ArrowNativeType;
5use vortex_buffer::Buffer;
6use vortex_error::VortexExpect;
7use vortex_error::VortexResult;
8use vortex_error::vortex_bail;
9use vortex_error::vortex_ensure;
10use vortex_error::vortex_err;
11use vortex_error::vortex_panic;
12use vortex_mask::Mask;
13
14use crate::ArrayRef;
15use crate::Canonical;
16use crate::Columnar;
17use crate::ExecutionCtx;
18use crate::IntoArray;
19use crate::aggregate_fn::Accumulator;
20use crate::aggregate_fn::AggregateFn;
21use crate::aggregate_fn::AggregateFnRef;
22use crate::aggregate_fn::AggregateFnVTable;
23use crate::aggregate_fn::DynAccumulator;
24use crate::aggregate_fn::session::AggregateFnSessionExt;
25use crate::arrays::ChunkedArray;
26use crate::arrays::FixedSizeListArray;
27use crate::arrays::ListViewArray;
28use crate::arrays::fixed_size_list::FixedSizeListArrayExt;
29use crate::arrays::listview::ListViewArrayExt;
30use crate::builders::builder_with_capacity;
31use crate::builtins::ArrayBuiltins;
32use crate::columnar::AnyColumnar;
33use crate::dtype::DType;
34use crate::executor::max_iterations;
35use crate::match_each_integer_ptype;
36
37/// Reference-counted type-erased grouped accumulator.
38pub type GroupedAccumulatorRef = Box<dyn DynGroupedAccumulator>;
39
40/// A batch of grouped values to aggregate.
41///
42/// Each outer list value is one group, and the inner element array is shared by all groups.
43/// Aggregate implementations can inspect the concrete grouped representation directly, or ask for
44/// derived ranges when their algorithm is expressed in terms of `(offset, size)` pairs.
45pub enum GroupedArray {
46    /// Groups represented as a list-view array with per-group offsets and sizes.
47    ListView(ListViewArray),
48    /// Groups represented as a fixed-size list array.
49    FixedSizeList(FixedSizeListArray),
50}
51
52impl From<ListViewArray> for GroupedArray {
53    fn from(groups: ListViewArray) -> Self {
54        Self::ListView(groups)
55    }
56}
57
58impl From<FixedSizeListArray> for GroupedArray {
59    fn from(groups: FixedSizeListArray) -> Self {
60        Self::FixedSizeList(groups)
61    }
62}
63
64impl GroupedArray {
65    /// The inner element array shared by all groups.
66    pub fn elements(&self) -> &ArrayRef {
67        match self {
68            Self::ListView(groups) => groups.elements(),
69            Self::FixedSizeList(groups) => groups.elements(),
70        }
71    }
72
73    /// Return the `(offset, size)` ranges describing each group in `elements`.
74    pub fn group_ranges(&self, ctx: &mut ExecutionCtx) -> VortexResult<GroupRanges> {
75        match self {
76            Self::ListView(groups) => list_view_group_ranges(groups, ctx),
77            Self::FixedSizeList(groups) => Ok(fixed_size_list_group_ranges(groups)),
78        }
79    }
80
81    /// Return the per-group validity mask.
82    pub fn group_validity(&self, ctx: &mut ExecutionCtx) -> VortexResult<Mask> {
83        match self {
84            Self::ListView(groups) => groups.validity()?.execute_mask(groups.len(), ctx),
85            Self::FixedSizeList(groups) => groups.validity()?.execute_mask(groups.len(), ctx),
86        }
87    }
88
89    /// The number of groups in this batch.
90    pub fn len(&self) -> usize {
91        match self {
92            Self::ListView(groups) => groups.len(),
93            Self::FixedSizeList(groups) => groups.len(),
94        }
95    }
96
97    /// Returns true when this batch contains no groups.
98    pub fn is_empty(&self) -> bool {
99        self.len() == 0
100    }
101
102    /// Returns true when every group is valid.
103    pub fn all_groups_valid(&self, ctx: &mut ExecutionCtx) -> VortexResult<bool> {
104        Ok(self.group_validity(ctx)?.all_true())
105    }
106
107    unsafe fn with_elements_unchecked(&self, elements: ArrayRef) -> VortexResult<Self> {
108        Ok(match self {
109            Self::ListView(groups) => unsafe {
110                ListViewArray::new_unchecked(
111                    elements,
112                    groups.offsets().clone(),
113                    groups.sizes().clone(),
114                    groups.validity()?,
115                )
116            }
117            .into(),
118            Self::FixedSizeList(groups) => unsafe {
119                FixedSizeListArray::new_unchecked(
120                    elements,
121                    groups.list_size(),
122                    groups.validity()?,
123                    groups.len(),
124                )
125            }
126            .into(),
127        })
128    }
129}
130
/// The physical layout of the groups in a grouped array, expressed as `(offset, size)` element
/// ranges.
pub enum GroupRanges {
    /// One explicit range per group, extracted from a list-view array.
    ListView {
        /// The `(offset, size)` range of each group, in group order.
        ranges: Vec<(usize, usize)>,
    },
    /// Contiguous, equally sized ranges derived from a fixed-size list array.
    FixedSizeList {
        /// How many groups there are.
        len: usize,
        /// How many elements every group contains.
        size: usize,
    },
}

impl GroupRanges {
    /// How many groups these ranges describe.
    pub fn len(&self) -> usize {
        match *self {
            Self::ListView { ref ranges } => ranges.len(),
            Self::FixedSizeList { len, .. } => len,
        }
    }

    /// Whether these ranges describe zero groups.
    pub fn is_empty(&self) -> bool {
        0 == self.len()
    }

    /// The `(offset, size)` range of group `index`.
    ///
    /// # Panics
    ///
    /// Panics if `index` is not less than [`Self::len`].
    fn range(&self, index: usize) -> (usize, usize) {
        match *self {
            Self::ListView { ref ranges } => ranges[index],
            Self::FixedSizeList { len, size } => {
                assert!(index < len, "range index out of bounds");
                (size * index, size)
            }
        }
    }

    /// Iterate over every group's `(offset, size)` range, in group order.
    pub fn iter(&self) -> impl Iterator<Item = (usize, usize)> + '_ {
        let group_count = self.len();
        (0..group_count).map(move |index| self.range(index))
    }
}
177
178/// An accumulator used for computing grouped aggregates.
179///
180/// Note that the groups must be processed in order, and the accumulator does not support random
181/// access to groups.
182pub struct GroupedAccumulator<V: AggregateFnVTable> {
183    /// The vtable of the aggregate function.
184    vtable: V,
185    /// The options of the aggregate function.
186    options: V::Options,
187    /// Type-erased aggregate function used for kernel dispatch.
188    aggregate_fn: AggregateFnRef,
189    /// The DType of the input.
190    dtype: DType,
191    /// The DType of the aggregate.
192    return_dtype: DType,
193    /// The DType of the partial accumulator state.
194    partial_dtype: DType,
195    /// The accumulated state for prior batches of groups.
196    partials: Vec<ArrayRef>,
197}
198
199impl<V: AggregateFnVTable> GroupedAccumulator<V> {
200    pub fn try_new(vtable: V, options: V::Options, dtype: DType) -> VortexResult<Self> {
201        let aggregate_fn = AggregateFn::new(vtable.clone(), options.clone()).erased();
202        let return_dtype = vtable.return_dtype(&options, &dtype).ok_or_else(|| {
203            vortex_err!(
204                "Aggregate function {} cannot be applied to dtype {}",
205                vtable.id(),
206                dtype
207            )
208        })?;
209        let partial_dtype = vtable.partial_dtype(&options, &dtype).ok_or_else(|| {
210            vortex_err!(
211                "Aggregate function {} cannot be applied to dtype {}",
212                vtable.id(),
213                dtype
214            )
215        })?;
216
217        Ok(Self {
218            vtable,
219            options,
220            aggregate_fn,
221            dtype,
222            return_dtype,
223            partial_dtype,
224            partials: vec![],
225        })
226    }
227}
228
229/// A trait object for type-erased grouped accumulators, used for dynamic dispatch when the aggregate
230/// function is not known at compile time.
231pub trait DynGroupedAccumulator: 'static + Send {
232    /// Accumulate a list of groups into the accumulator.
233    fn accumulate_list(&mut self, groups: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<()>;
234
235    /// Finish the accumulation and return the partial aggregate results for all groups.
236    /// Resets the accumulator state for the next round of accumulation.
237    fn flush(&mut self) -> VortexResult<ArrayRef>;
238
239    /// Finish the accumulation and return the final aggregate results for all groups.
240    /// Resets the accumulator state for the next round of accumulation.
241    fn finish(&mut self) -> VortexResult<ArrayRef>;
242}
243
244impl<V: AggregateFnVTable> DynGroupedAccumulator for GroupedAccumulator<V> {
245    fn accumulate_list(&mut self, groups: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<()> {
246        let elements_dtype = match groups.dtype() {
247            DType::List(elem, _) => elem,
248            DType::FixedSizeList(elem, ..) => elem,
249            _ => vortex_bail!(
250                "Input DType mismatch: expected List or FixedSizeList, got {}",
251                groups.dtype()
252            ),
253        };
254        vortex_ensure!(
255            elements_dtype.as_ref() == &self.dtype,
256            "Input DType mismatch: expected {}, got {}",
257            self.dtype,
258            elements_dtype
259        );
260
261        // We first execute the groups until it is a ListView or FixedSizeList, since we only
262        // dispatch the aggregate kernel over the elements of these arrays.
263        let canonical = match groups.clone().execute::<Columnar>(ctx)? {
264            Columnar::Canonical(c) => c,
265            Columnar::Constant(c) => c.into_array().execute::<Canonical>(ctx)?,
266        };
267        match canonical {
268            Canonical::List(groups) => self.accumulate_grouped_array(groups.into(), ctx),
269            Canonical::FixedSizeList(groups) => self.accumulate_grouped_array(groups.into(), ctx),
270            _ => vortex_panic!("We checked the DType above, so this should never happen"),
271        }
272    }
273
274    fn flush(&mut self) -> VortexResult<ArrayRef> {
275        let states = std::mem::take(&mut self.partials);
276        Ok(ChunkedArray::try_new(states, self.partial_dtype.clone())?.into_array())
277    }
278
279    fn finish(&mut self) -> VortexResult<ArrayRef> {
280        let states = self.flush()?;
281        let results = self.vtable.finalize(states)?;
282
283        vortex_ensure!(
284            results.dtype() == &self.return_dtype,
285            "Return DType mismatch: expected {}, got {}",
286            self.return_dtype,
287            results.dtype()
288        );
289
290        Ok(results)
291    }
292}
293
294impl<V: AggregateFnVTable> GroupedAccumulator<V> {
295    fn accumulate_grouped_array(
296        &mut self,
297        groups: GroupedArray,
298        ctx: &mut ExecutionCtx,
299    ) -> VortexResult<()> {
300        let mut elements = groups.elements().clone();
301        let session = ctx.session().clone();
302
303        for _ in 0..max_iterations() {
304            // Try a registered grouped kernel for the current element encoding.
305            if let Some(kernel) = session
306                .aggregate_fns()
307                .find_grouped_encoding_kernel(elements.encoding_id(), self.aggregate_fn.id())
308            {
309                // SAFETY: we assume that elements execution is safe
310                let kernel_groups = unsafe { groups.with_elements_unchecked(elements.clone())? };
311                if let Some(result) =
312                    kernel.grouped_aggregate(&self.aggregate_fn, &kernel_groups, ctx)?
313                {
314                    return self.push_result(result);
315                }
316            }
317
318            // Try a grouped kernel for the current aggregate regardless of element encoding.
319            if let Some(kernel) = session
320                .aggregate_fns()
321                .find_grouped_kernel(self.aggregate_fn.id())
322            {
323                // SAFETY: we preserve the grouped shape and validity while replacing the
324                // elements with another representation of the same logical array.
325                let kernel_groups = unsafe { groups.with_elements_unchecked(elements.clone())? };
326                if let Some(result) =
327                    kernel.grouped_aggregate(&self.aggregate_fn, &kernel_groups, ctx)?
328                {
329                    return self.push_result(result);
330                }
331            }
332
333            if elements.is::<AnyColumnar>() {
334                break;
335            }
336
337            // Execute one step and try again
338            elements = elements.execute(ctx)?;
339        }
340
341        let elements = elements.execute::<Columnar>(ctx)?.into_array();
342        // SAFETY: we preserve the grouped shape and validity while replacing the elements with an
343        // executed form of the same logical array.
344        let grouped = unsafe { groups.with_elements_unchecked(elements)? };
345
346        // Otherwise, we iterate the offsets and sizes and accumulate each group one by one.
347        self.accumulate_grouped_fallback(&grouped, ctx)
348    }
349
350    fn accumulate_grouped_fallback(
351        &mut self,
352        grouped: &GroupedArray,
353        ctx: &mut ExecutionCtx,
354    ) -> VortexResult<()> {
355        let mut accumulator = Accumulator::try_new(
356            self.vtable.clone(),
357            self.options.clone(),
358            self.dtype.clone(),
359        )?;
360        let mut states = builder_with_capacity(&self.partial_dtype, grouped.len());
361        let group_ranges = grouped.group_ranges(ctx)?;
362        let group_validity = grouped.group_validity(ctx)?;
363
364        for ((offset, size), valid) in group_ranges.iter().zip(group_validity.iter()) {
365            if valid {
366                let group = grouped.elements().slice(offset..offset + size)?;
367                accumulator.accumulate(&group, ctx)?;
368                states.append_scalar(&accumulator.flush()?)?;
369            } else {
370                states.append_null()
371            }
372        }
373
374        self.push_result(states.finish())
375    }
376
377    fn push_result(&mut self, state: ArrayRef) -> VortexResult<()> {
378        vortex_ensure!(
379            state.dtype() == &self.partial_dtype,
380            "State DType mismatch: expected {}, got {}",
381            self.partial_dtype,
382            state.dtype()
383        );
384        self.partials.push(state);
385        Ok(())
386    }
387}
388fn list_view_group_ranges(
389    groups: &ListViewArray,
390    ctx: &mut ExecutionCtx,
391) -> VortexResult<GroupRanges> {
392    let offsets = groups.offsets();
393    let sizes = groups.sizes().cast(offsets.dtype().clone())?;
394
395    let ranges = match_each_integer_ptype!(offsets.dtype().as_ptype(), |O| {
396        let offsets = offsets.clone().execute::<Buffer<O>>(ctx)?;
397        let sizes = sizes.execute::<Buffer<O>>(ctx)?;
398        offsets
399            .as_ref()
400            .iter()
401            .zip(sizes.as_ref().iter())
402            .map(|(offset, size)| {
403                (
404                    offset.to_usize().vortex_expect("Offset value is not usize"),
405                    size.to_usize().vortex_expect("Size value is not usize"),
406                )
407            })
408            .collect::<Vec<_>>()
409    });
410
411    Ok(GroupRanges::ListView { ranges })
412}
413
414fn fixed_size_list_group_ranges(groups: &FixedSizeListArray) -> GroupRanges {
415    GroupRanges::FixedSizeList {
416        len: groups.len(),
417        size: groups.list_size() as usize,
418    }
419}