Skip to main content

vortex_array/aggregate_fn/
accumulator_grouped.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use arrow_buffer::ArrowNativeType;
5use vortex_buffer::Buffer;
6use vortex_error::VortexExpect;
7use vortex_error::VortexResult;
8use vortex_error::vortex_bail;
9use vortex_error::vortex_ensure;
10use vortex_error::vortex_err;
11use vortex_error::vortex_panic;
12use vortex_mask::Mask;
13
14use crate::AnyCanonical;
15use crate::ArrayRef;
16use crate::Canonical;
17use crate::Columnar;
18use crate::ExecutionCtx;
19use crate::IntoArray;
20use crate::aggregate_fn::Accumulator;
21use crate::aggregate_fn::AggregateFn;
22use crate::aggregate_fn::AggregateFnRef;
23use crate::aggregate_fn::AggregateFnVTable;
24use crate::aggregate_fn::DynAccumulator;
25use crate::aggregate_fn::session::AggregateFnSessionExt;
26use crate::arrays::ChunkedArray;
27use crate::arrays::FixedSizeListArray;
28use crate::arrays::ListViewArray;
29use crate::arrays::fixed_size_list::FixedSizeListArrayExt;
30use crate::arrays::listview::ListViewArrayExt;
31use crate::builders::builder_with_capacity;
32use crate::builtins::ArrayBuiltins;
33use crate::dtype::DType;
34use crate::dtype::IntegerPType;
35use crate::executor::MAX_ITERATIONS;
36use crate::match_each_integer_ptype;
37
38/// Reference-counted type-erased grouped accumulator.
39pub type GroupedAccumulatorRef = Box<dyn DynGroupedAccumulator>;
40
/// An accumulator used for computing grouped aggregates.
///
/// Each call to [`DynGroupedAccumulator::accumulate_list`] computes one partial state per input
/// group and stores them as a batch. [`DynGroupedAccumulator::flush`] returns all stored partial
/// states (as a chunked array, in accumulation order). [`DynGroupedAccumulator::finish`]
/// additionally finalizes them into results of the aggregate's return dtype.
///
/// Note that the groups must be processed in order, and the accumulator does not support random
/// access to groups.
pub struct GroupedAccumulator<V: AggregateFnVTable> {
    /// The vtable of the aggregate function.
    vtable: V,
    /// The options of the aggregate function.
    options: V::Options,
    /// Type-erased aggregate function used for kernel dispatch.
    ///
    /// Its id is part of the key used to look up grouped kernels in the session.
    aggregate_fn: AggregateFnRef,
    /// The DType of the input (the element type of each group).
    dtype: DType,
    /// The DType of the aggregate.
    return_dtype: DType,
    /// The DType of the partial accumulator state.
    partial_dtype: DType,
    /// The accumulated state for prior batches of groups.
    ///
    /// Every entry has dtype `partial_dtype` (enforced by `push_result`). NOTE(review): each
    /// entry presumably holds one row per group of the corresponding input batch — the fallback
    /// paths build it that way; confirm for kernel-produced results.
    partials: Vec<ArrayRef>,
}
61
62impl<V: AggregateFnVTable> GroupedAccumulator<V> {
63    pub fn try_new(vtable: V, options: V::Options, dtype: DType) -> VortexResult<Self> {
64        let aggregate_fn = AggregateFn::new(vtable.clone(), options.clone()).erased();
65        let return_dtype = vtable.return_dtype(&options, &dtype).ok_or_else(|| {
66            vortex_err!(
67                "Aggregate function {} cannot be applied to dtype {}",
68                vtable.id(),
69                dtype
70            )
71        })?;
72        let partial_dtype = vtable.partial_dtype(&options, &dtype).ok_or_else(|| {
73            vortex_err!(
74                "Aggregate function {} cannot be applied to dtype {}",
75                vtable.id(),
76                dtype
77            )
78        })?;
79
80        Ok(Self {
81            vtable,
82            options,
83            aggregate_fn,
84            dtype,
85            return_dtype,
86            partial_dtype,
87            partials: vec![],
88        })
89    }
90}
91
/// A trait object for type-erased grouped accumulators, used for dynamic dispatch when the aggregate
/// function is not known at compile time.
///
/// Implementors must be `Send + 'static` so boxed accumulators can be moved across threads.
pub trait DynGroupedAccumulator: 'static + Send {
    /// Accumulate a list of groups into the accumulator.
    ///
    /// `groups` is a list-typed array in which each list entry holds the input values of one
    /// group. NOTE(review): the accepted input dtypes are defined by the implementation —
    /// `GroupedAccumulator` accepts `List` / `FixedSizeList` whose element dtype equals its
    /// input dtype, and returns an error otherwise.
    fn accumulate_list(&mut self, groups: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<()>;

    /// Finish the accumulation and return the partial aggregate results for all groups.
    /// Resets the accumulator state for the next round of accumulation.
    ///
    /// The partial results are intended to be finalized (or merged) later; they are not the
    /// final aggregate values.
    fn flush(&mut self) -> VortexResult<ArrayRef>;

    /// Finish the accumulation and return the final aggregate results for all groups.
    /// Resets the accumulator state for the next round of accumulation.
    fn finish(&mut self) -> VortexResult<ArrayRef>;
}
106
107impl<V: AggregateFnVTable> DynGroupedAccumulator for GroupedAccumulator<V> {
108    fn accumulate_list(&mut self, groups: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<()> {
109        let elements_dtype = match groups.dtype() {
110            DType::List(elem, _) => elem,
111            DType::FixedSizeList(elem, ..) => elem,
112            _ => vortex_bail!(
113                "Input DType mismatch: expected List or FixedSizeList, got {}",
114                groups.dtype()
115            ),
116        };
117        vortex_ensure!(
118            elements_dtype.as_ref() == &self.dtype,
119            "Input DType mismatch: expected {}, got {}",
120            self.dtype,
121            elements_dtype
122        );
123
124        // We first execute the groups until it is a ListView or FixedSizeList, since we only
125        // dispatch the aggregate kernel over the elements of these arrays.
126        let canonical = match groups.clone().execute::<Columnar>(ctx)? {
127            Columnar::Canonical(c) => c,
128            Columnar::Constant(c) => c.into_array().execute::<Canonical>(ctx)?,
129        };
130        match canonical {
131            Canonical::List(groups) => self.accumulate_list_view(&groups, ctx),
132            Canonical::FixedSizeList(groups) => self.accumulate_fixed_size_list(&groups, ctx),
133            _ => vortex_panic!("We checked the DType above, so this should never happen"),
134        }
135    }
136
137    fn flush(&mut self) -> VortexResult<ArrayRef> {
138        let states = std::mem::take(&mut self.partials);
139        Ok(ChunkedArray::try_new(states, self.partial_dtype.clone())?.into_array())
140    }
141
142    fn finish(&mut self) -> VortexResult<ArrayRef> {
143        let states = self.flush()?;
144        let results = self.vtable.finalize(states)?;
145
146        vortex_ensure!(
147            results.dtype() == &self.return_dtype,
148            "Return DType mismatch: expected {}, got {}",
149            self.return_dtype,
150            results.dtype()
151        );
152
153        Ok(results)
154    }
155}
156
157impl<V: AggregateFnVTable> GroupedAccumulator<V> {
158    fn accumulate_list_view(
159        &mut self,
160        groups: &ListViewArray,
161        ctx: &mut ExecutionCtx,
162    ) -> VortexResult<()> {
163        let mut elements = groups.elements().clone();
164        let groups_validity = groups.validity()?;
165        let session = ctx.session().clone();
166        let kernels = &session.aggregate_fns().grouped_kernels;
167
168        for _ in 0..*MAX_ITERATIONS {
169            if elements.is::<AnyCanonical>() {
170                break;
171            }
172
173            let kernels_r = kernels.read();
174            if let Some(result) = kernels_r
175                .get(&(elements.encoding_id(), Some(self.aggregate_fn.id())))
176                .or_else(|| kernels_r.get(&(elements.encoding_id(), None)))
177                .and_then(|kernel| {
178                    // SAFETY: we assume that elements execution is safe
179                    let groups = unsafe {
180                        ListViewArray::new_unchecked(
181                            elements.clone(),
182                            groups.offsets().clone(),
183                            groups.sizes().clone(),
184                            groups_validity.clone(),
185                        )
186                    };
187                    kernel
188                        .grouped_aggregate(&self.aggregate_fn, &groups)
189                        .transpose()
190                })
191                .transpose()?
192            {
193                return self.push_result(result);
194            }
195
196            // Execute one step and try again
197            elements = elements.execute(ctx)?;
198        }
199
200        // Otherwise, we iterate the offsets and sizes and accumulate each group one by one.
201        let elements = elements.execute::<Columnar>(ctx)?.into_array();
202        let offsets = groups.offsets();
203        let sizes = groups.sizes().cast(offsets.dtype().clone())?;
204        let validity = groups_validity.execute_mask(offsets.len(), ctx)?;
205
206        match_each_integer_ptype!(offsets.dtype().as_ptype(), |O| {
207            let offsets = offsets.clone().execute::<Buffer<O>>(ctx)?;
208            let sizes = sizes.execute::<Buffer<O>>(ctx)?;
209            self.accumulate_list_view_typed(
210                &elements,
211                offsets.as_ref(),
212                sizes.as_ref(),
213                &validity,
214                ctx,
215            )
216        })
217    }
218
219    fn accumulate_list_view_typed<O: IntegerPType>(
220        &mut self,
221        elements: &ArrayRef,
222        offsets: &[O],
223        sizes: &[O],
224        validity: &Mask,
225        ctx: &mut ExecutionCtx,
226    ) -> VortexResult<()> {
227        let mut accumulator = Accumulator::try_new(
228            self.vtable.clone(),
229            self.options.clone(),
230            self.dtype.clone(),
231        )?;
232        let mut states = builder_with_capacity(&self.partial_dtype, offsets.len());
233
234        for (offset, size) in offsets.iter().zip(sizes.iter()) {
235            let offset = offset.to_usize().vortex_expect("Offset value is not usize");
236            let size = size.to_usize().vortex_expect("Size value is not usize");
237
238            if validity.value(offset) {
239                let group = elements.slice(offset..offset + size)?;
240                accumulator.accumulate(&group, ctx)?;
241                states.append_scalar(&accumulator.flush()?)?;
242            } else {
243                states.append_null()
244            }
245        }
246
247        self.push_result(states.finish())
248    }
249
250    fn accumulate_fixed_size_list(
251        &mut self,
252        groups: &FixedSizeListArray,
253        ctx: &mut ExecutionCtx,
254    ) -> VortexResult<()> {
255        let mut elements = groups.elements().clone();
256        let groups_validity = groups.validity()?;
257        let session = ctx.session().clone();
258        let kernels = &session.aggregate_fns().grouped_kernels;
259
260        for _ in 0..64 {
261            if elements.is::<AnyCanonical>() {
262                break;
263            }
264
265            let kernels_r = kernels.read();
266            if let Some(result) = kernels_r
267                .get(&(elements.encoding_id(), Some(self.aggregate_fn.id())))
268                .or_else(|| kernels_r.get(&(elements.encoding_id(), None)))
269                .and_then(|kernel| {
270                    // SAFETY: we assume that elements execution is safe
271                    let groups = unsafe {
272                        FixedSizeListArray::new_unchecked(
273                            elements.clone(),
274                            groups.list_size(),
275                            groups_validity.clone(),
276                            groups.len(),
277                        )
278                    };
279
280                    kernel
281                        .grouped_aggregate_fixed_size(&self.aggregate_fn, &groups)
282                        .transpose()
283                })
284                .transpose()?
285            {
286                return self.push_result(result);
287            }
288
289            // Execute one step and try again
290            elements = elements.execute(ctx)?;
291        }
292
293        // Otherwise, we iterate the offsets and sizes and accumulate each group one by one.
294        let elements = elements.execute::<Columnar>(ctx)?.into_array();
295        let validity = groups_validity.execute_mask(groups.len(), ctx)?;
296
297        let mut accumulator = Accumulator::try_new(
298            self.vtable.clone(),
299            self.options.clone(),
300            self.dtype.clone(),
301        )?;
302        let mut states = builder_with_capacity(&self.partial_dtype, groups.len());
303
304        let mut offset = 0;
305        let size = groups
306            .list_size()
307            .to_usize()
308            .vortex_expect("List size is not usize");
309
310        for i in 0..groups.len() {
311            if validity.value(i) {
312                let group = elements.slice(offset..offset + size)?;
313                accumulator.accumulate(&group, ctx)?;
314                states.append_scalar(&accumulator.finish()?)?;
315            } else {
316                states.append_null()
317            }
318            offset += size;
319        }
320
321        self.push_result(states.finish())
322    }
323
324    fn push_result(&mut self, state: ArrayRef) -> VortexResult<()> {
325        vortex_ensure!(
326            state.dtype() == &self.partial_dtype,
327            "State DType mismatch: expected {}, got {}",
328            self.partial_dtype,
329            state.dtype()
330        );
331        self.partials.push(state);
332        Ok(())
333    }
334}