Skip to main content

vortex_array/aggregate_fn/
accumulator_grouped.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use arrow_buffer::ArrowNativeType;
5use vortex_buffer::Buffer;
6use vortex_error::VortexExpect;
7use vortex_error::VortexResult;
8use vortex_error::vortex_bail;
9use vortex_error::vortex_ensure;
10use vortex_error::vortex_err;
11use vortex_error::vortex_panic;
12use vortex_mask::Mask;
13
14use crate::AnyCanonical;
15use crate::ArrayRef;
16use crate::Canonical;
17use crate::Columnar;
18use crate::DynArray;
19use crate::ExecutionCtx;
20use crate::IntoArray;
21use crate::aggregate_fn::Accumulator;
22use crate::aggregate_fn::AggregateFn;
23use crate::aggregate_fn::AggregateFnRef;
24use crate::aggregate_fn::AggregateFnVTable;
25use crate::aggregate_fn::DynAccumulator;
26use crate::aggregate_fn::session::AggregateFnSessionExt;
27use crate::arrays::ChunkedArray;
28use crate::arrays::FixedSizeListArray;
29use crate::arrays::ListViewArray;
30use crate::builders::builder_with_capacity;
31use crate::builtins::ArrayBuiltins;
32use crate::dtype::DType;
33use crate::dtype::IntegerPType;
34use crate::executor::MAX_ITERATIONS;
35use crate::match_each_integer_ptype;
36use crate::vtable::ValidityHelper;
37
38/// Reference-counted type-erased grouped accumulator.
39pub type GroupedAccumulatorRef = Box<dyn DynGroupedAccumulator>;
40
41/// An accumulator used for computing grouped aggregates.
42///
43/// Note that the groups must be processed in order, and the accumulator does not support random
44/// access to groups.
45pub struct GroupedAccumulator<V: AggregateFnVTable> {
46    /// The vtable of the aggregate function.
47    vtable: V,
48    /// The options of the aggregate function.
49    options: V::Options,
50    /// Type-erased aggregate function used for kernel dispatch.
51    aggregate_fn: AggregateFnRef,
52    /// The DType of the input.
53    dtype: DType,
54    /// The DType of the aggregate.
55    return_dtype: DType,
56    /// The DType of the partial accumulator state.
57    partial_dtype: DType,
58    /// The accumulated state for prior batches of groups.
59    partials: Vec<ArrayRef>,
60}
61
62impl<V: AggregateFnVTable> GroupedAccumulator<V> {
63    pub fn try_new(vtable: V, options: V::Options, dtype: DType) -> VortexResult<Self> {
64        let aggregate_fn = AggregateFn::new(vtable.clone(), options.clone()).erased();
65        let return_dtype = vtable.return_dtype(&options, &dtype).ok_or_else(|| {
66            vortex_err!(
67                "Aggregate function {} cannot be applied to dtype {}",
68                vtable.id(),
69                dtype
70            )
71        })?;
72        let partial_dtype = vtable.partial_dtype(&options, &dtype).ok_or_else(|| {
73            vortex_err!(
74                "Aggregate function {} cannot be applied to dtype {}",
75                vtable.id(),
76                dtype
77            )
78        })?;
79
80        Ok(Self {
81            vtable,
82            options,
83            aggregate_fn,
84            dtype,
85            return_dtype,
86            partial_dtype,
87            partials: vec![],
88        })
89    }
90}
91
92/// A trait object for type-erased grouped accumulators, used for dynamic dispatch when the aggregate
93/// function is not known at compile time.
94pub trait DynGroupedAccumulator: 'static + Send {
95    /// Accumulate a list of groups into the accumulator.
96    fn accumulate_list(&mut self, groups: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<()>;
97
98    /// Finish the accumulation and return the partial aggregate results for all groups.
99    /// Resets the accumulator state for the next round of accumulation.
100    fn flush(&mut self) -> VortexResult<ArrayRef>;
101
102    /// Finish the accumulation and return the final aggregate results for all groups.
103    /// Resets the accumulator state for the next round of accumulation.
104    fn finish(&mut self) -> VortexResult<ArrayRef>;
105}
106
107impl<V: AggregateFnVTable> DynGroupedAccumulator for GroupedAccumulator<V> {
108    fn accumulate_list(&mut self, groups: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<()> {
109        let elements_dtype = match groups.dtype() {
110            DType::List(elem, _) => elem,
111            DType::FixedSizeList(elem, ..) => elem,
112            _ => vortex_bail!(
113                "Input DType mismatch: expected List or FixedSizeList, got {}",
114                groups.dtype()
115            ),
116        };
117        vortex_ensure!(
118            elements_dtype.as_ref() == &self.dtype,
119            "Input DType mismatch: expected {}, got {}",
120            self.dtype,
121            elements_dtype
122        );
123
124        // We first execute the groups until it is a ListView or FixedSizeList, since we only
125        // dispatch the aggregate kernel over the elements of these arrays.
126        let canonical = match groups.clone().execute::<Columnar>(ctx)? {
127            Columnar::Canonical(c) => c,
128            Columnar::Constant(c) => c.into_array().execute::<Canonical>(ctx)?,
129        };
130        match canonical {
131            Canonical::List(groups) => self.accumulate_list_view(&groups, ctx),
132            Canonical::FixedSizeList(groups) => self.accumulate_fixed_size_list(&groups, ctx),
133            _ => vortex_panic!("We checked the DType above, so this should never happen"),
134        }
135    }
136
137    fn flush(&mut self) -> VortexResult<ArrayRef> {
138        let states = std::mem::take(&mut self.partials);
139        Ok(ChunkedArray::try_new(states, self.partial_dtype.clone())?.into_array())
140    }
141
142    fn finish(&mut self) -> VortexResult<ArrayRef> {
143        let states = self.flush()?;
144        let results = self.vtable.finalize(states)?;
145
146        vortex_ensure!(
147            results.dtype() == &self.return_dtype,
148            "Return DType mismatch: expected {}, got {}",
149            self.return_dtype,
150            results.dtype()
151        );
152
153        Ok(results)
154    }
155}
156
157impl<V: AggregateFnVTable> GroupedAccumulator<V> {
158    fn accumulate_list_view(
159        &mut self,
160        groups: &ListViewArray,
161        ctx: &mut ExecutionCtx,
162    ) -> VortexResult<()> {
163        let mut elements = groups.elements().clone();
164        let session = ctx.session().clone();
165        let kernels = &session.aggregate_fns().grouped_kernels;
166
167        for _ in 0..*MAX_ITERATIONS {
168            if elements.is::<AnyCanonical>() {
169                break;
170            }
171
172            let kernels_r = kernels.read();
173            if let Some(result) = kernels_r
174                .get(&(elements.encoding_id(), Some(self.aggregate_fn.id())))
175                .or_else(|| kernels_r.get(&(elements.encoding_id(), None)))
176                .and_then(|kernel| {
177                    // SAFETY: we assume that elements execution is safe
178                    let groups = unsafe {
179                        ListViewArray::new_unchecked(
180                            elements.clone(),
181                            groups.offsets().clone(),
182                            groups.sizes().clone(),
183                            groups.validity().clone(),
184                        )
185                    };
186                    kernel
187                        .grouped_aggregate(&self.aggregate_fn, &groups)
188                        .transpose()
189                })
190                .transpose()?
191            {
192                return self.push_result(result);
193            }
194
195            // Execute one step and try again
196            elements = elements.execute(ctx)?;
197        }
198
199        // Otherwise, we iterate the offsets and sizes and accumulate each group one by one.
200        let elements = elements.execute::<Columnar>(ctx)?.into_array();
201        let offsets = groups.offsets();
202        let sizes = groups.sizes().cast(offsets.dtype().clone())?;
203        let validity = groups.validity().execute_mask(offsets.len(), ctx)?;
204
205        match_each_integer_ptype!(offsets.dtype().as_ptype(), |O| {
206            let offsets = offsets.clone().execute::<Buffer<O>>(ctx)?;
207            let sizes = sizes.execute::<Buffer<O>>(ctx)?;
208            self.accumulate_list_view_typed(
209                &elements,
210                offsets.as_ref(),
211                sizes.as_ref(),
212                &validity,
213                ctx,
214            )
215        })
216    }
217
218    fn accumulate_list_view_typed<O: IntegerPType>(
219        &mut self,
220        elements: &ArrayRef,
221        offsets: &[O],
222        sizes: &[O],
223        validity: &Mask,
224        ctx: &mut ExecutionCtx,
225    ) -> VortexResult<()> {
226        let mut accumulator = Accumulator::try_new(
227            self.vtable.clone(),
228            self.options.clone(),
229            self.dtype.clone(),
230        )?;
231        let mut states = builder_with_capacity(&self.partial_dtype, offsets.len());
232
233        for (offset, size) in offsets.iter().zip(sizes.iter()) {
234            let offset = offset.to_usize().vortex_expect("Offset value is not usize");
235            let size = size.to_usize().vortex_expect("Size value is not usize");
236
237            if validity.value(offset) {
238                let group = elements.slice(offset..offset + size)?;
239                accumulator.accumulate(&group, ctx)?;
240                states.append_scalar(&accumulator.finish()?)?;
241            } else {
242                states.append_null()
243            }
244        }
245
246        self.push_result(states.finish())
247    }
248
249    fn accumulate_fixed_size_list(
250        &mut self,
251        groups: &FixedSizeListArray,
252        ctx: &mut ExecutionCtx,
253    ) -> VortexResult<()> {
254        let mut elements = groups.elements().clone();
255        let session = ctx.session().clone();
256        let kernels = &session.aggregate_fns().grouped_kernels;
257
258        for _ in 0..64 {
259            if elements.is::<AnyCanonical>() {
260                break;
261            }
262
263            let kernels_r = kernels.read();
264            if let Some(result) = kernels_r
265                .get(&(elements.encoding_id(), Some(self.aggregate_fn.id())))
266                .or_else(|| kernels_r.get(&(elements.encoding_id(), None)))
267                .and_then(|kernel| {
268                    // SAFETY: we assume that elements execution is safe
269                    let groups = unsafe {
270                        FixedSizeListArray::new_unchecked(
271                            elements.clone(),
272                            groups.list_size(),
273                            groups.validity().clone(),
274                            groups.len(),
275                        )
276                    };
277
278                    kernel
279                        .grouped_aggregate_fixed_size(&self.aggregate_fn, &groups)
280                        .transpose()
281                })
282                .transpose()?
283            {
284                return self.push_result(result);
285            }
286
287            // Execute one step and try again
288            elements = elements.execute(ctx)?;
289        }
290
291        // Otherwise, we iterate the offsets and sizes and accumulate each group one by one.
292        let elements = elements.execute::<Columnar>(ctx)?.into_array();
293        let validity = groups.validity().execute_mask(groups.len(), ctx)?;
294
295        let mut accumulator = Accumulator::try_new(
296            self.vtable.clone(),
297            self.options.clone(),
298            self.dtype.clone(),
299        )?;
300        let mut states = builder_with_capacity(&self.partial_dtype, groups.len());
301
302        let mut offset = 0;
303        let size = groups
304            .list_size()
305            .to_usize()
306            .vortex_expect("List size is not usize");
307
308        for i in 0..groups.len() {
309            if validity.value(i) {
310                let group = elements.slice(offset..offset + size)?;
311                accumulator.accumulate(&group, ctx)?;
312                states.append_scalar(&accumulator.finish()?)?;
313            } else {
314                states.append_null()
315            }
316            offset += size;
317        }
318
319        self.push_result(states.finish())
320    }
321
322    fn push_result(&mut self, state: ArrayRef) -> VortexResult<()> {
323        vortex_ensure!(
324            state.dtype() == &self.partial_dtype,
325            "State DType mismatch: expected {}, got {}",
326            self.partial_dtype,
327            state.dtype()
328        );
329        self.partials.push(state);
330        Ok(())
331    }
332}