// vortex_array/aggregate_fn/accumulator_grouped.rs

use arrow_buffer::ArrowNativeType;
use vortex_buffer::Buffer;
use vortex_error::VortexExpect;
use vortex_error::VortexResult;
use vortex_error::vortex_bail;
use vortex_error::vortex_ensure;
use vortex_error::vortex_err;
use vortex_error::vortex_panic;
use vortex_mask::Mask;

use crate::AnyCanonical;
use crate::ArrayRef;
use crate::Canonical;
use crate::Columnar;
use crate::ExecutionCtx;
use crate::IntoArray;
use crate::aggregate_fn::Accumulator;
use crate::aggregate_fn::AggregateFn;
use crate::aggregate_fn::AggregateFnRef;
use crate::aggregate_fn::AggregateFnVTable;
use crate::aggregate_fn::DynAccumulator;
use crate::aggregate_fn::session::AggregateFnSessionExt;
use crate::arrays::ChunkedArray;
use crate::arrays::FixedSizeListArray;
use crate::arrays::ListViewArray;
use crate::arrays::fixed_size_list::FixedSizeListArrayExt;
use crate::arrays::listview::ListViewArrayExt;
use crate::builders::builder_with_capacity;
use crate::builtins::ArrayBuiltins;
use crate::dtype::DType;
use crate::dtype::IntegerPType;
use crate::executor::MAX_ITERATIONS;
use crate::match_each_integer_ptype;

38pub type GroupedAccumulatorRef = Box<dyn DynGroupedAccumulator>;
40
41pub struct GroupedAccumulator<V: AggregateFnVTable> {
46 vtable: V,
48 options: V::Options,
50 aggregate_fn: AggregateFnRef,
52 dtype: DType,
54 return_dtype: DType,
56 partial_dtype: DType,
58 partials: Vec<ArrayRef>,
60}
61
62impl<V: AggregateFnVTable> GroupedAccumulator<V> {
63 pub fn try_new(vtable: V, options: V::Options, dtype: DType) -> VortexResult<Self> {
64 let aggregate_fn = AggregateFn::new(vtable.clone(), options.clone()).erased();
65 let return_dtype = vtable.return_dtype(&options, &dtype).ok_or_else(|| {
66 vortex_err!(
67 "Aggregate function {} cannot be applied to dtype {}",
68 vtable.id(),
69 dtype
70 )
71 })?;
72 let partial_dtype = vtable.partial_dtype(&options, &dtype).ok_or_else(|| {
73 vortex_err!(
74 "Aggregate function {} cannot be applied to dtype {}",
75 vtable.id(),
76 dtype
77 )
78 })?;
79
80 Ok(Self {
81 vtable,
82 options,
83 aggregate_fn,
84 dtype,
85 return_dtype,
86 partial_dtype,
87 partials: vec![],
88 })
89 }
90}
91
92pub trait DynGroupedAccumulator: 'static + Send {
95 fn accumulate_list(&mut self, groups: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<()>;
97
98 fn flush(&mut self) -> VortexResult<ArrayRef>;
101
102 fn finish(&mut self) -> VortexResult<ArrayRef>;
105}
106
107impl<V: AggregateFnVTable> DynGroupedAccumulator for GroupedAccumulator<V> {
108 fn accumulate_list(&mut self, groups: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<()> {
109 let elements_dtype = match groups.dtype() {
110 DType::List(elem, _) => elem,
111 DType::FixedSizeList(elem, ..) => elem,
112 _ => vortex_bail!(
113 "Input DType mismatch: expected List or FixedSizeList, got {}",
114 groups.dtype()
115 ),
116 };
117 vortex_ensure!(
118 elements_dtype.as_ref() == &self.dtype,
119 "Input DType mismatch: expected {}, got {}",
120 self.dtype,
121 elements_dtype
122 );
123
124 let canonical = match groups.clone().execute::<Columnar>(ctx)? {
127 Columnar::Canonical(c) => c,
128 Columnar::Constant(c) => c.into_array().execute::<Canonical>(ctx)?,
129 };
130 match canonical {
131 Canonical::List(groups) => self.accumulate_list_view(&groups, ctx),
132 Canonical::FixedSizeList(groups) => self.accumulate_fixed_size_list(&groups, ctx),
133 _ => vortex_panic!("We checked the DType above, so this should never happen"),
134 }
135 }
136
137 fn flush(&mut self) -> VortexResult<ArrayRef> {
138 let states = std::mem::take(&mut self.partials);
139 Ok(ChunkedArray::try_new(states, self.partial_dtype.clone())?.into_array())
140 }
141
142 fn finish(&mut self) -> VortexResult<ArrayRef> {
143 let states = self.flush()?;
144 let results = self.vtable.finalize(states)?;
145
146 vortex_ensure!(
147 results.dtype() == &self.return_dtype,
148 "Return DType mismatch: expected {}, got {}",
149 self.return_dtype,
150 results.dtype()
151 );
152
153 Ok(results)
154 }
155}
156
157impl<V: AggregateFnVTable> GroupedAccumulator<V> {
158 fn accumulate_list_view(
159 &mut self,
160 groups: &ListViewArray,
161 ctx: &mut ExecutionCtx,
162 ) -> VortexResult<()> {
163 let mut elements = groups.elements().clone();
164 let groups_validity = groups.validity()?;
165 let session = ctx.session().clone();
166 let kernels = &session.aggregate_fns().grouped_kernels;
167
168 for _ in 0..*MAX_ITERATIONS {
169 if elements.is::<AnyCanonical>() {
170 break;
171 }
172
173 let kernels_r = kernels.read();
174 if let Some(result) = kernels_r
175 .get(&(elements.encoding_id(), Some(self.aggregate_fn.id())))
176 .or_else(|| kernels_r.get(&(elements.encoding_id(), None)))
177 .and_then(|kernel| {
178 let groups = unsafe {
180 ListViewArray::new_unchecked(
181 elements.clone(),
182 groups.offsets().clone(),
183 groups.sizes().clone(),
184 groups_validity.clone(),
185 )
186 };
187 kernel
188 .grouped_aggregate(&self.aggregate_fn, &groups)
189 .transpose()
190 })
191 .transpose()?
192 {
193 return self.push_result(result);
194 }
195
196 elements = elements.execute(ctx)?;
198 }
199
200 let elements = elements.execute::<Columnar>(ctx)?.into_array();
202 let offsets = groups.offsets();
203 let sizes = groups.sizes().cast(offsets.dtype().clone())?;
204 let validity = groups_validity.execute_mask(offsets.len(), ctx)?;
205
206 match_each_integer_ptype!(offsets.dtype().as_ptype(), |O| {
207 let offsets = offsets.clone().execute::<Buffer<O>>(ctx)?;
208 let sizes = sizes.execute::<Buffer<O>>(ctx)?;
209 self.accumulate_list_view_typed(
210 &elements,
211 offsets.as_ref(),
212 sizes.as_ref(),
213 &validity,
214 ctx,
215 )
216 })
217 }
218
219 fn accumulate_list_view_typed<O: IntegerPType>(
220 &mut self,
221 elements: &ArrayRef,
222 offsets: &[O],
223 sizes: &[O],
224 validity: &Mask,
225 ctx: &mut ExecutionCtx,
226 ) -> VortexResult<()> {
227 let mut accumulator = Accumulator::try_new(
228 self.vtable.clone(),
229 self.options.clone(),
230 self.dtype.clone(),
231 )?;
232 let mut states = builder_with_capacity(&self.partial_dtype, offsets.len());
233
234 for (offset, size) in offsets.iter().zip(sizes.iter()) {
235 let offset = offset.to_usize().vortex_expect("Offset value is not usize");
236 let size = size.to_usize().vortex_expect("Size value is not usize");
237
238 if validity.value(offset) {
239 let group = elements.slice(offset..offset + size)?;
240 accumulator.accumulate(&group, ctx)?;
241 states.append_scalar(&accumulator.flush()?)?;
242 } else {
243 states.append_null()
244 }
245 }
246
247 self.push_result(states.finish())
248 }
249
250 fn accumulate_fixed_size_list(
251 &mut self,
252 groups: &FixedSizeListArray,
253 ctx: &mut ExecutionCtx,
254 ) -> VortexResult<()> {
255 let mut elements = groups.elements().clone();
256 let groups_validity = groups.validity()?;
257 let session = ctx.session().clone();
258 let kernels = &session.aggregate_fns().grouped_kernels;
259
260 for _ in 0..64 {
261 if elements.is::<AnyCanonical>() {
262 break;
263 }
264
265 let kernels_r = kernels.read();
266 if let Some(result) = kernels_r
267 .get(&(elements.encoding_id(), Some(self.aggregate_fn.id())))
268 .or_else(|| kernels_r.get(&(elements.encoding_id(), None)))
269 .and_then(|kernel| {
270 let groups = unsafe {
272 FixedSizeListArray::new_unchecked(
273 elements.clone(),
274 groups.list_size(),
275 groups_validity.clone(),
276 groups.len(),
277 )
278 };
279
280 kernel
281 .grouped_aggregate_fixed_size(&self.aggregate_fn, &groups)
282 .transpose()
283 })
284 .transpose()?
285 {
286 return self.push_result(result);
287 }
288
289 elements = elements.execute(ctx)?;
291 }
292
293 let elements = elements.execute::<Columnar>(ctx)?.into_array();
295 let validity = groups_validity.execute_mask(groups.len(), ctx)?;
296
297 let mut accumulator = Accumulator::try_new(
298 self.vtable.clone(),
299 self.options.clone(),
300 self.dtype.clone(),
301 )?;
302 let mut states = builder_with_capacity(&self.partial_dtype, groups.len());
303
304 let mut offset = 0;
305 let size = groups
306 .list_size()
307 .to_usize()
308 .vortex_expect("List size is not usize");
309
310 for i in 0..groups.len() {
311 if validity.value(i) {
312 let group = elements.slice(offset..offset + size)?;
313 accumulator.accumulate(&group, ctx)?;
314 states.append_scalar(&accumulator.finish()?)?;
315 } else {
316 states.append_null()
317 }
318 offset += size;
319 }
320
321 self.push_result(states.finish())
322 }
323
324 fn push_result(&mut self, state: ArrayRef) -> VortexResult<()> {
325 vortex_ensure!(
326 state.dtype() == &self.partial_dtype,
327 "State DType mismatch: expected {}, got {}",
328 self.partial_dtype,
329 state.dtype()
330 );
331 self.partials.push(state);
332 Ok(())
333 }
334}