vortex_array/aggregate_fn/
accumulator_grouped.rs1use arrow_buffer::ArrowNativeType;
5use vortex_buffer::Buffer;
6use vortex_error::VortexExpect;
7use vortex_error::VortexResult;
8use vortex_error::vortex_bail;
9use vortex_error::vortex_ensure;
10use vortex_error::vortex_err;
11use vortex_error::vortex_panic;
12use vortex_mask::Mask;
13
14use crate::AnyCanonical;
15use crate::ArrayRef;
16use crate::Canonical;
17use crate::Columnar;
18use crate::DynArray;
19use crate::ExecutionCtx;
20use crate::IntoArray;
21use crate::aggregate_fn::Accumulator;
22use crate::aggregate_fn::AggregateFn;
23use crate::aggregate_fn::AggregateFnRef;
24use crate::aggregate_fn::AggregateFnVTable;
25use crate::aggregate_fn::DynAccumulator;
26use crate::aggregate_fn::session::AggregateFnSessionExt;
27use crate::arrays::ChunkedArray;
28use crate::arrays::FixedSizeListArray;
29use crate::arrays::ListViewArray;
30use crate::builders::builder_with_capacity;
31use crate::builtins::ArrayBuiltins;
32use crate::dtype::DType;
33use crate::dtype::IntegerPType;
34use crate::executor::MAX_ITERATIONS;
35use crate::match_each_integer_ptype;
36use crate::vtable::ValidityHelper;
37
/// A boxed, type-erased grouped accumulator (any [`DynGroupedAccumulator`] implementation).
pub type GroupedAccumulatorRef = Box<dyn DynGroupedAccumulator>;
40
/// Accumulates per-group partial aggregate states for a single aggregate function.
///
/// Each successful call to `accumulate_list` pushes one partial-state array (of
/// `partial_dtype`) onto `partials`; `flush` drains them into a chunked array and `finish`
/// additionally finalizes them into values of `return_dtype`.
pub struct GroupedAccumulator<V: AggregateFnVTable> {
    /// The aggregate function's vtable.
    vtable: V,
    /// The options the aggregate function was configured with.
    options: V::Options,
    /// Type-erased handle built from `vtable` + `options`; its id is part of the grouped-kernel
    /// lookup key and it is passed to matching kernels.
    aggregate_fn: AggregateFnRef,
    /// DType of the elements inside each group; input lists must have this element dtype.
    dtype: DType,
    /// DType that `finish` checks the finalized result against.
    return_dtype: DType,
    /// DType of every partial-state array pushed onto `partials`.
    partial_dtype: DType,
    /// Partial-state arrays accumulated since the last flush.
    partials: Vec<ArrayRef>,
}
61
62impl<V: AggregateFnVTable> GroupedAccumulator<V> {
63 pub fn try_new(vtable: V, options: V::Options, dtype: DType) -> VortexResult<Self> {
64 let aggregate_fn = AggregateFn::new(vtable.clone(), options.clone()).erased();
65 let return_dtype = vtable.return_dtype(&options, &dtype).ok_or_else(|| {
66 vortex_err!(
67 "Aggregate function {} cannot be applied to dtype {}",
68 vtable.id(),
69 dtype
70 )
71 })?;
72 let partial_dtype = vtable.partial_dtype(&options, &dtype).ok_or_else(|| {
73 vortex_err!(
74 "Aggregate function {} cannot be applied to dtype {}",
75 vtable.id(),
76 dtype
77 )
78 })?;
79
80 Ok(Self {
81 vtable,
82 options,
83 aggregate_fn,
84 dtype,
85 return_dtype,
86 partial_dtype,
87 partials: vec![],
88 })
89 }
90}
91
/// Object-safe grouped-accumulation interface, used behind [`GroupedAccumulatorRef`].
pub trait DynGroupedAccumulator: 'static + Send {
    /// Accumulates one partial state per group of `groups`.
    ///
    /// For [`GroupedAccumulator`], `groups` must be a `List` or `FixedSizeList` array whose
    /// element dtype equals the accumulator's input dtype; otherwise an error is returned.
    fn accumulate_list(&mut self, groups: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<()>;

    /// Drains the partial states accumulated so far and returns them as a single array.
    fn flush(&mut self) -> VortexResult<ArrayRef>;

    /// Flushes the accumulated partial states and finalizes them into result values.
    fn finish(&mut self) -> VortexResult<ArrayRef>;
}
106
107impl<V: AggregateFnVTable> DynGroupedAccumulator for GroupedAccumulator<V> {
108 fn accumulate_list(&mut self, groups: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<()> {
109 let elements_dtype = match groups.dtype() {
110 DType::List(elem, _) => elem,
111 DType::FixedSizeList(elem, ..) => elem,
112 _ => vortex_bail!(
113 "Input DType mismatch: expected List or FixedSizeList, got {}",
114 groups.dtype()
115 ),
116 };
117 vortex_ensure!(
118 elements_dtype.as_ref() == &self.dtype,
119 "Input DType mismatch: expected {}, got {}",
120 self.dtype,
121 elements_dtype
122 );
123
124 let canonical = match groups.clone().execute::<Columnar>(ctx)? {
127 Columnar::Canonical(c) => c,
128 Columnar::Constant(c) => c.into_array().execute::<Canonical>(ctx)?,
129 };
130 match canonical {
131 Canonical::List(groups) => self.accumulate_list_view(&groups, ctx),
132 Canonical::FixedSizeList(groups) => self.accumulate_fixed_size_list(&groups, ctx),
133 _ => vortex_panic!("We checked the DType above, so this should never happen"),
134 }
135 }
136
137 fn flush(&mut self) -> VortexResult<ArrayRef> {
138 let states = std::mem::take(&mut self.partials);
139 Ok(ChunkedArray::try_new(states, self.partial_dtype.clone())?.into_array())
140 }
141
142 fn finish(&mut self) -> VortexResult<ArrayRef> {
143 let states = self.flush()?;
144 let results = self.vtable.finalize(states)?;
145
146 vortex_ensure!(
147 results.dtype() == &self.return_dtype,
148 "Return DType mismatch: expected {}, got {}",
149 self.return_dtype,
150 results.dtype()
151 );
152
153 Ok(results)
154 }
155}
156
157impl<V: AggregateFnVTable> GroupedAccumulator<V> {
158 fn accumulate_list_view(
159 &mut self,
160 groups: &ListViewArray,
161 ctx: &mut ExecutionCtx,
162 ) -> VortexResult<()> {
163 let mut elements = groups.elements().clone();
164 let session = ctx.session().clone();
165 let kernels = &session.aggregate_fns().grouped_kernels;
166
167 for _ in 0..*MAX_ITERATIONS {
168 if elements.is::<AnyCanonical>() {
169 break;
170 }
171
172 let kernels_r = kernels.read();
173 if let Some(result) = kernels_r
174 .get(&(elements.encoding_id(), Some(self.aggregate_fn.id())))
175 .or_else(|| kernels_r.get(&(elements.encoding_id(), None)))
176 .and_then(|kernel| {
177 let groups = unsafe {
179 ListViewArray::new_unchecked(
180 elements.clone(),
181 groups.offsets().clone(),
182 groups.sizes().clone(),
183 groups.validity().clone(),
184 )
185 };
186 kernel
187 .grouped_aggregate(&self.aggregate_fn, &groups)
188 .transpose()
189 })
190 .transpose()?
191 {
192 return self.push_result(result);
193 }
194
195 elements = elements.execute(ctx)?;
197 }
198
199 let elements = elements.execute::<Columnar>(ctx)?.into_array();
201 let offsets = groups.offsets();
202 let sizes = groups.sizes().cast(offsets.dtype().clone())?;
203 let validity = groups.validity().execute_mask(offsets.len(), ctx)?;
204
205 match_each_integer_ptype!(offsets.dtype().as_ptype(), |O| {
206 let offsets = offsets.clone().execute::<Buffer<O>>(ctx)?;
207 let sizes = sizes.execute::<Buffer<O>>(ctx)?;
208 self.accumulate_list_view_typed(
209 &elements,
210 offsets.as_ref(),
211 sizes.as_ref(),
212 &validity,
213 ctx,
214 )
215 })
216 }
217
218 fn accumulate_list_view_typed<O: IntegerPType>(
219 &mut self,
220 elements: &ArrayRef,
221 offsets: &[O],
222 sizes: &[O],
223 validity: &Mask,
224 ctx: &mut ExecutionCtx,
225 ) -> VortexResult<()> {
226 let mut accumulator = Accumulator::try_new(
227 self.vtable.clone(),
228 self.options.clone(),
229 self.dtype.clone(),
230 )?;
231 let mut states = builder_with_capacity(&self.partial_dtype, offsets.len());
232
233 for (offset, size) in offsets.iter().zip(sizes.iter()) {
234 let offset = offset.to_usize().vortex_expect("Offset value is not usize");
235 let size = size.to_usize().vortex_expect("Size value is not usize");
236
237 if validity.value(offset) {
238 let group = elements.slice(offset..offset + size)?;
239 accumulator.accumulate(&group, ctx)?;
240 states.append_scalar(&accumulator.finish()?)?;
241 } else {
242 states.append_null()
243 }
244 }
245
246 self.push_result(states.finish())
247 }
248
249 fn accumulate_fixed_size_list(
250 &mut self,
251 groups: &FixedSizeListArray,
252 ctx: &mut ExecutionCtx,
253 ) -> VortexResult<()> {
254 let mut elements = groups.elements().clone();
255 let session = ctx.session().clone();
256 let kernels = &session.aggregate_fns().grouped_kernels;
257
258 for _ in 0..64 {
259 if elements.is::<AnyCanonical>() {
260 break;
261 }
262
263 let kernels_r = kernels.read();
264 if let Some(result) = kernels_r
265 .get(&(elements.encoding_id(), Some(self.aggregate_fn.id())))
266 .or_else(|| kernels_r.get(&(elements.encoding_id(), None)))
267 .and_then(|kernel| {
268 let groups = unsafe {
270 FixedSizeListArray::new_unchecked(
271 elements.clone(),
272 groups.list_size(),
273 groups.validity().clone(),
274 groups.len(),
275 )
276 };
277
278 kernel
279 .grouped_aggregate_fixed_size(&self.aggregate_fn, &groups)
280 .transpose()
281 })
282 .transpose()?
283 {
284 return self.push_result(result);
285 }
286
287 elements = elements.execute(ctx)?;
289 }
290
291 let elements = elements.execute::<Columnar>(ctx)?.into_array();
293 let validity = groups.validity().execute_mask(groups.len(), ctx)?;
294
295 let mut accumulator = Accumulator::try_new(
296 self.vtable.clone(),
297 self.options.clone(),
298 self.dtype.clone(),
299 )?;
300 let mut states = builder_with_capacity(&self.partial_dtype, groups.len());
301
302 let mut offset = 0;
303 let size = groups
304 .list_size()
305 .to_usize()
306 .vortex_expect("List size is not usize");
307
308 for i in 0..groups.len() {
309 if validity.value(i) {
310 let group = elements.slice(offset..offset + size)?;
311 accumulator.accumulate(&group, ctx)?;
312 states.append_scalar(&accumulator.finish()?)?;
313 } else {
314 states.append_null()
315 }
316 offset += size;
317 }
318
319 self.push_result(states.finish())
320 }
321
322 fn push_result(&mut self, state: ArrayRef) -> VortexResult<()> {
323 vortex_ensure!(
324 state.dtype() == &self.partial_dtype,
325 "State DType mismatch: expected {}, got {}",
326 self.partial_dtype,
327 state.dtype()
328 );
329 self.partials.push(state);
330 Ok(())
331 }
332}