vortex_array/aggregate_fn/
accumulator_grouped.rs1use arrow_buffer::ArrowNativeType;
5use vortex_buffer::Buffer;
6use vortex_error::VortexExpect;
7use vortex_error::VortexResult;
8use vortex_error::vortex_bail;
9use vortex_error::vortex_ensure;
10use vortex_error::vortex_err;
11use vortex_error::vortex_panic;
12use vortex_mask::Mask;
13
14use crate::ArrayRef;
15use crate::Canonical;
16use crate::Columnar;
17use crate::ExecutionCtx;
18use crate::IntoArray;
19use crate::aggregate_fn::Accumulator;
20use crate::aggregate_fn::AggregateFn;
21use crate::aggregate_fn::AggregateFnRef;
22use crate::aggregate_fn::AggregateFnVTable;
23use crate::aggregate_fn::DynAccumulator;
24use crate::aggregate_fn::session::AggregateFnSessionExt;
25use crate::arrays::ChunkedArray;
26use crate::arrays::FixedSizeListArray;
27use crate::arrays::ListViewArray;
28use crate::arrays::fixed_size_list::FixedSizeListArrayExt;
29use crate::arrays::listview::ListViewArrayExt;
30use crate::builders::builder_with_capacity;
31use crate::builtins::ArrayBuiltins;
32use crate::columnar::AnyColumnar;
33use crate::dtype::DType;
34use crate::executor::max_iterations;
35use crate::match_each_integer_ptype;
36
37pub type GroupedAccumulatorRef = Box<dyn DynGroupedAccumulator>;
39
40pub enum GroupedArray {
46 ListView(ListViewArray),
48 FixedSizeList(FixedSizeListArray),
50}
51
52impl From<ListViewArray> for GroupedArray {
53 fn from(groups: ListViewArray) -> Self {
54 Self::ListView(groups)
55 }
56}
57
58impl From<FixedSizeListArray> for GroupedArray {
59 fn from(groups: FixedSizeListArray) -> Self {
60 Self::FixedSizeList(groups)
61 }
62}
63
64impl GroupedArray {
65 pub fn elements(&self) -> &ArrayRef {
67 match self {
68 Self::ListView(groups) => groups.elements(),
69 Self::FixedSizeList(groups) => groups.elements(),
70 }
71 }
72
73 pub fn group_ranges(&self, ctx: &mut ExecutionCtx) -> VortexResult<GroupRanges> {
75 match self {
76 Self::ListView(groups) => list_view_group_ranges(groups, ctx),
77 Self::FixedSizeList(groups) => Ok(fixed_size_list_group_ranges(groups)),
78 }
79 }
80
81 pub fn group_validity(&self, ctx: &mut ExecutionCtx) -> VortexResult<Mask> {
83 match self {
84 Self::ListView(groups) => groups.validity()?.execute_mask(groups.len(), ctx),
85 Self::FixedSizeList(groups) => groups.validity()?.execute_mask(groups.len(), ctx),
86 }
87 }
88
89 pub fn len(&self) -> usize {
91 match self {
92 Self::ListView(groups) => groups.len(),
93 Self::FixedSizeList(groups) => groups.len(),
94 }
95 }
96
97 pub fn is_empty(&self) -> bool {
99 self.len() == 0
100 }
101
102 pub fn all_groups_valid(&self, ctx: &mut ExecutionCtx) -> VortexResult<bool> {
104 Ok(self.group_validity(ctx)?.all_true())
105 }
106
107 unsafe fn with_elements_unchecked(&self, elements: ArrayRef) -> VortexResult<Self> {
108 Ok(match self {
109 Self::ListView(groups) => unsafe {
110 ListViewArray::new_unchecked(
111 elements,
112 groups.offsets().clone(),
113 groups.sizes().clone(),
114 groups.validity()?,
115 )
116 }
117 .into(),
118 Self::FixedSizeList(groups) => unsafe {
119 FixedSizeListArray::new_unchecked(
120 elements,
121 groups.list_size(),
122 groups.validity()?,
123 groups.len(),
124 )
125 }
126 .into(),
127 })
128 }
129}
130
/// Location of every group inside a flat elements array, as `(offset, size)`
/// pairs.
pub enum GroupRanges {
    /// Explicit ranges, one `(offset, size)` pair per group.
    ListView {
        /// The `(offset, size)` pair of each group, in group order.
        ranges: Vec<(usize, usize)>,
    },
    /// Implicit ranges: `len` groups of `size` elements each, laid out back to
    /// back.
    FixedSizeList {
        /// Number of groups.
        len: usize,
        /// Number of elements in every group.
        size: usize,
    },
}

impl GroupRanges {
    /// Number of groups described.
    pub fn len(&self) -> usize {
        match *self {
            GroupRanges::FixedSizeList { len, .. } => len,
            GroupRanges::ListView { ref ranges } => ranges.len(),
        }
    }

    /// Returns `true` when no groups are described.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns the `(offset, size)` pair of the group at `index`.
    ///
    /// # Panics
    ///
    /// Panics if `index` is not smaller than [`GroupRanges::len`].
    fn range(&self, index: usize) -> (usize, usize) {
        match *self {
            GroupRanges::FixedSizeList { len, size } => {
                assert!(index < len, "range index out of bounds");
                // Groups are contiguous, so the i-th group starts at i * size.
                (index * size, size)
            }
            GroupRanges::ListView { ref ranges } => ranges[index],
        }
    }

    /// Iterates over the `(offset, size)` pair of every group, in group order.
    pub fn iter(&self) -> impl Iterator<Item = (usize, usize)> + '_ {
        let group_count = self.len();
        (0..group_count).map(move |group| self.range(group))
    }
}
177
178pub struct GroupedAccumulator<V: AggregateFnVTable> {
183 vtable: V,
185 options: V::Options,
187 aggregate_fn: AggregateFnRef,
189 dtype: DType,
191 return_dtype: DType,
193 partial_dtype: DType,
195 partials: Vec<ArrayRef>,
197}
198
199impl<V: AggregateFnVTable> GroupedAccumulator<V> {
200 pub fn try_new(vtable: V, options: V::Options, dtype: DType) -> VortexResult<Self> {
201 let aggregate_fn = AggregateFn::new(vtable.clone(), options.clone()).erased();
202 let return_dtype = vtable.return_dtype(&options, &dtype).ok_or_else(|| {
203 vortex_err!(
204 "Aggregate function {} cannot be applied to dtype {}",
205 vtable.id(),
206 dtype
207 )
208 })?;
209 let partial_dtype = vtable.partial_dtype(&options, &dtype).ok_or_else(|| {
210 vortex_err!(
211 "Aggregate function {} cannot be applied to dtype {}",
212 vtable.id(),
213 dtype
214 )
215 })?;
216
217 Ok(Self {
218 vtable,
219 options,
220 aggregate_fn,
221 dtype,
222 return_dtype,
223 partial_dtype,
224 partials: vec![],
225 })
226 }
227}
228
229pub trait DynGroupedAccumulator: 'static + Send {
232 fn accumulate_list(&mut self, groups: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<()>;
234
235 fn flush(&mut self) -> VortexResult<ArrayRef>;
238
239 fn finish(&mut self) -> VortexResult<ArrayRef>;
242}
243
244impl<V: AggregateFnVTable> DynGroupedAccumulator for GroupedAccumulator<V> {
245 fn accumulate_list(&mut self, groups: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<()> {
246 let elements_dtype = match groups.dtype() {
247 DType::List(elem, _) => elem,
248 DType::FixedSizeList(elem, ..) => elem,
249 _ => vortex_bail!(
250 "Input DType mismatch: expected List or FixedSizeList, got {}",
251 groups.dtype()
252 ),
253 };
254 vortex_ensure!(
255 elements_dtype.as_ref() == &self.dtype,
256 "Input DType mismatch: expected {}, got {}",
257 self.dtype,
258 elements_dtype
259 );
260
261 let canonical = match groups.clone().execute::<Columnar>(ctx)? {
264 Columnar::Canonical(c) => c,
265 Columnar::Constant(c) => c.into_array().execute::<Canonical>(ctx)?,
266 };
267 match canonical {
268 Canonical::List(groups) => self.accumulate_grouped_array(groups.into(), ctx),
269 Canonical::FixedSizeList(groups) => self.accumulate_grouped_array(groups.into(), ctx),
270 _ => vortex_panic!("We checked the DType above, so this should never happen"),
271 }
272 }
273
274 fn flush(&mut self) -> VortexResult<ArrayRef> {
275 let states = std::mem::take(&mut self.partials);
276 Ok(ChunkedArray::try_new(states, self.partial_dtype.clone())?.into_array())
277 }
278
279 fn finish(&mut self) -> VortexResult<ArrayRef> {
280 let states = self.flush()?;
281 let results = self.vtable.finalize(states)?;
282
283 vortex_ensure!(
284 results.dtype() == &self.return_dtype,
285 "Return DType mismatch: expected {}, got {}",
286 self.return_dtype,
287 results.dtype()
288 );
289
290 Ok(results)
291 }
292}
293
294impl<V: AggregateFnVTable> GroupedAccumulator<V> {
295 fn accumulate_grouped_array(
296 &mut self,
297 groups: GroupedArray,
298 ctx: &mut ExecutionCtx,
299 ) -> VortexResult<()> {
300 let mut elements = groups.elements().clone();
301 let session = ctx.session().clone();
302
303 for _ in 0..max_iterations() {
304 if let Some(kernel) = session
306 .aggregate_fns()
307 .find_grouped_encoding_kernel(elements.encoding_id(), self.aggregate_fn.id())
308 {
309 let kernel_groups = unsafe { groups.with_elements_unchecked(elements.clone())? };
311 if let Some(result) =
312 kernel.grouped_aggregate(&self.aggregate_fn, &kernel_groups, ctx)?
313 {
314 return self.push_result(result);
315 }
316 }
317
318 if let Some(kernel) = session
320 .aggregate_fns()
321 .find_grouped_kernel(self.aggregate_fn.id())
322 {
323 let kernel_groups = unsafe { groups.with_elements_unchecked(elements.clone())? };
326 if let Some(result) =
327 kernel.grouped_aggregate(&self.aggregate_fn, &kernel_groups, ctx)?
328 {
329 return self.push_result(result);
330 }
331 }
332
333 if elements.is::<AnyColumnar>() {
334 break;
335 }
336
337 elements = elements.execute(ctx)?;
339 }
340
341 let elements = elements.execute::<Columnar>(ctx)?.into_array();
342 let grouped = unsafe { groups.with_elements_unchecked(elements)? };
345
346 self.accumulate_grouped_fallback(&grouped, ctx)
348 }
349
350 fn accumulate_grouped_fallback(
351 &mut self,
352 grouped: &GroupedArray,
353 ctx: &mut ExecutionCtx,
354 ) -> VortexResult<()> {
355 let mut accumulator = Accumulator::try_new(
356 self.vtable.clone(),
357 self.options.clone(),
358 self.dtype.clone(),
359 )?;
360 let mut states = builder_with_capacity(&self.partial_dtype, grouped.len());
361 let group_ranges = grouped.group_ranges(ctx)?;
362 let group_validity = grouped.group_validity(ctx)?;
363
364 for ((offset, size), valid) in group_ranges.iter().zip(group_validity.iter()) {
365 if valid {
366 let group = grouped.elements().slice(offset..offset + size)?;
367 accumulator.accumulate(&group, ctx)?;
368 states.append_scalar(&accumulator.flush()?)?;
369 } else {
370 states.append_null()
371 }
372 }
373
374 self.push_result(states.finish())
375 }
376
377 fn push_result(&mut self, state: ArrayRef) -> VortexResult<()> {
378 vortex_ensure!(
379 state.dtype() == &self.partial_dtype,
380 "State DType mismatch: expected {}, got {}",
381 self.partial_dtype,
382 state.dtype()
383 );
384 self.partials.push(state);
385 Ok(())
386 }
387}
388fn list_view_group_ranges(
389 groups: &ListViewArray,
390 ctx: &mut ExecutionCtx,
391) -> VortexResult<GroupRanges> {
392 let offsets = groups.offsets();
393 let sizes = groups.sizes().cast(offsets.dtype().clone())?;
394
395 let ranges = match_each_integer_ptype!(offsets.dtype().as_ptype(), |O| {
396 let offsets = offsets.clone().execute::<Buffer<O>>(ctx)?;
397 let sizes = sizes.execute::<Buffer<O>>(ctx)?;
398 offsets
399 .as_ref()
400 .iter()
401 .zip(sizes.as_ref().iter())
402 .map(|(offset, size)| {
403 (
404 offset.to_usize().vortex_expect("Offset value is not usize"),
405 size.to_usize().vortex_expect("Size value is not usize"),
406 )
407 })
408 .collect::<Vec<_>>()
409 });
410
411 Ok(GroupRanges::ListView { ranges })
412}
413
414fn fixed_size_list_group_ranges(groups: &FixedSizeListArray) -> GroupRanges {
415 GroupRanges::FixedSizeList {
416 len: groups.len(),
417 size: groups.list_size() as usize,
418 }
419}