vortex_array/aggregate_fn/
accumulator_grouped.rs1use arrow_buffer::ArrowNativeType;
5use vortex_buffer::Buffer;
6use vortex_error::VortexExpect;
7use vortex_error::VortexResult;
8use vortex_error::vortex_bail;
9use vortex_error::vortex_ensure;
10use vortex_error::vortex_panic;
11use vortex_mask::Mask;
12use vortex_session::VortexSession;
13
14use crate::AnyCanonical;
15use crate::ArrayRef;
16use crate::Canonical;
17use crate::Columnar;
18use crate::DynArray;
19use crate::ExecutionCtx;
20use crate::IntoArray;
21use crate::VortexSessionExecute;
22use crate::aggregate_fn::Accumulator;
23use crate::aggregate_fn::AggregateFn;
24use crate::aggregate_fn::AggregateFnRef;
25use crate::aggregate_fn::AggregateFnVTable;
26use crate::aggregate_fn::DynAccumulator;
27use crate::aggregate_fn::session::AggregateFnSessionExt;
28use crate::arrays::ChunkedArray;
29use crate::arrays::FixedSizeListArray;
30use crate::arrays::ListViewArray;
31use crate::builders::builder_with_capacity;
32use crate::builtins::ArrayBuiltins;
33use crate::dtype::DType;
34use crate::dtype::IntegerPType;
35use crate::executor::MAX_ITERATIONS;
36use crate::match_each_integer_ptype;
37use crate::vtable::ValidityHelper;
38
/// A boxed, type-erased grouped accumulator.
///
/// This lets accumulators for different aggregate functions be stored and driven uniformly.
pub type GroupedAccumulatorRef = Box<dyn DynGroupedAccumulator>;

/// Accumulates per-group partial aggregate states for the aggregate function described by `V`.
///
/// Each successful `accumulate_list` call pushes one partial-state array (one entry per group)
/// onto `partials`. `flush` concatenates and drains them, and `finish` finalizes them.
pub struct GroupedAccumulator<V: AggregateFnVTable> {
    /// The aggregate function's vtable.
    vtable: V,
    /// Options passed to every vtable call.
    options: V::Options,
    /// Type-erased handle to the aggregate function; its id is used to look up grouped kernels.
    aggregate_fn: AggregateFnRef,
    /// DType of the list elements being aggregated; input lists must have exactly this element dtype.
    dtype: DType,
    /// DType that `finish` checks the finalized result against.
    return_dtype: DType,
    /// DType every pushed partial-state array must have.
    partial_dtype: DType,
    /// Partial-state arrays accumulated since the last flush.
    partials: Vec<ArrayRef>,
    /// Session used to create execution contexts and to reach the grouped-kernel registry.
    session: VortexSession,
}

65impl<V: AggregateFnVTable> GroupedAccumulator<V> {
66 pub fn try_new(
67 vtable: V,
68 options: V::Options,
69 dtype: DType,
70 session: VortexSession,
71 ) -> VortexResult<Self> {
72 let aggregate_fn = AggregateFn::new(vtable.clone(), options.clone()).erased();
73 let return_dtype = vtable.return_dtype(&options, &dtype)?;
74 let partial_dtype = vtable.partial_dtype(&options, &dtype)?;
75
76 Ok(Self {
77 vtable,
78 options,
79 aggregate_fn,
80 dtype,
81 return_dtype,
82 partial_dtype,
83 partials: vec![],
84 session,
85 })
86 }
87}
88
/// Object-safe interface to a grouped accumulator, usable behind a [`GroupedAccumulatorRef`].
pub trait DynGroupedAccumulator: 'static + Send {
    /// Aggregates each list in `groups` into one partial state per list.
    ///
    /// In [`GroupedAccumulator`], `groups` must be a `List` or `FixedSizeList` array whose
    /// element dtype matches the accumulator's input dtype; otherwise an error is returned.
    fn accumulate_list(&mut self, groups: &ArrayRef) -> VortexResult<()>;

    /// Returns all partial states accumulated so far as one array and clears the internal buffer.
    fn flush(&mut self) -> VortexResult<ArrayRef>;

    /// Flushes the accumulated partial states and finalizes them into the result array.
    fn finish(&mut self) -> VortexResult<ArrayRef>;
}

104impl<V: AggregateFnVTable> DynGroupedAccumulator for GroupedAccumulator<V> {
105 fn accumulate_list(&mut self, groups: &ArrayRef) -> VortexResult<()> {
106 let elements_dtype = match groups.dtype() {
107 DType::List(elem, _) => elem,
108 DType::FixedSizeList(elem, ..) => elem,
109 _ => vortex_bail!(
110 "Input DType mismatch: expected List or FixedSizeList, got {}",
111 groups.dtype()
112 ),
113 };
114 vortex_ensure!(
115 elements_dtype.as_ref() == &self.dtype,
116 "Input DType mismatch: expected {}, got {}",
117 self.dtype,
118 elements_dtype
119 );
120
121 let mut ctx = self.session.create_execution_ctx();
122
123 let canonical = match groups.clone().execute::<Columnar>(&mut ctx)? {
126 Columnar::Canonical(c) => c,
127 Columnar::Constant(c) => c.into_array().execute::<Canonical>(&mut ctx)?,
128 };
129 match canonical {
130 Canonical::List(groups) => self.accumulate_list_view(&groups, &mut ctx),
131 Canonical::FixedSizeList(groups) => self.accumulate_fixed_size_list(&groups, &mut ctx),
132 _ => vortex_panic!("We checked the DType above, so this should never happen"),
133 }
134 }
135
136 fn flush(&mut self) -> VortexResult<ArrayRef> {
137 let states = std::mem::take(&mut self.partials);
138 Ok(ChunkedArray::try_new(states, self.partial_dtype.clone())?.into_array())
139 }
140
141 fn finish(&mut self) -> VortexResult<ArrayRef> {
142 let states = self.flush()?;
143 let results = self.vtable.finalize(states)?;
144
145 vortex_ensure!(
146 results.dtype() == &self.return_dtype,
147 "Return DType mismatch: expected {}, got {}",
148 self.return_dtype,
149 results.dtype()
150 );
151
152 Ok(results)
153 }
154}
155
156impl<V: AggregateFnVTable> GroupedAccumulator<V> {
157 fn accumulate_list_view(
158 &mut self,
159 groups: &ListViewArray,
160 ctx: &mut ExecutionCtx,
161 ) -> VortexResult<()> {
162 let mut elements = groups.elements().clone();
163 let session = self.session.clone();
164
165 let kernels = &session.aggregate_fns().grouped_kernels;
166
167 for _ in 0..*MAX_ITERATIONS {
168 if elements.is::<AnyCanonical>() {
169 break;
170 }
171
172 let kernels_r = kernels.read();
173 if let Some(result) = kernels_r
174 .get(&(elements.encoding_id(), Some(self.aggregate_fn.id())))
175 .or_else(|| kernels_r.get(&(elements.encoding_id(), None)))
176 .and_then(|kernel| {
177 let groups = unsafe {
179 ListViewArray::new_unchecked(
180 elements.clone(),
181 groups.offsets().clone(),
182 groups.sizes().clone(),
183 groups.validity().clone(),
184 )
185 };
186 kernel
187 .grouped_aggregate(&self.aggregate_fn, &groups)
188 .transpose()
189 })
190 .transpose()?
191 {
192 return self.push_result(result);
193 }
194
195 elements = elements.execute(ctx)?;
197 }
198
199 let elements = elements.execute::<Columnar>(ctx)?.into_array();
201 let offsets = groups.offsets();
202 let sizes = groups.sizes().cast(offsets.dtype().clone())?;
203 let validity = groups.validity().to_mask(offsets.len());
204
205 match_each_integer_ptype!(offsets.dtype().as_ptype(), |O| {
206 let offsets = offsets.clone().execute::<Buffer<O>>(ctx)?;
207 let sizes = sizes.execute::<Buffer<O>>(ctx)?;
208 self.accumulate_list_view_typed(&elements, offsets.as_ref(), sizes.as_ref(), &validity)
209 })
210 }
211
212 fn accumulate_list_view_typed<O: IntegerPType>(
213 &mut self,
214 elements: &ArrayRef,
215 offsets: &[O],
216 sizes: &[O],
217 validity: &Mask,
218 ) -> VortexResult<()> {
219 let mut accumulator = Accumulator::try_new(
220 self.vtable.clone(),
221 self.options.clone(),
222 self.dtype.clone(),
223 self.session.clone(),
224 )?;
225 let mut states = builder_with_capacity(&self.partial_dtype, offsets.len());
226
227 for (offset, size) in offsets.iter().zip(sizes.iter()) {
228 let offset = offset.to_usize().vortex_expect("Offset value is not usize");
229 let size = size.to_usize().vortex_expect("Size value is not usize");
230
231 if validity.value(offset) {
232 let group = elements.slice(offset..offset + size)?;
233 accumulator.accumulate(&group)?;
234 states.append_scalar(&accumulator.finish()?)?;
235 } else {
236 states.append_null()
237 }
238 }
239
240 self.push_result(states.finish())
241 }
242
243 fn accumulate_fixed_size_list(
244 &mut self,
245 groups: &FixedSizeListArray,
246 ctx: &mut ExecutionCtx,
247 ) -> VortexResult<()> {
248 let mut elements = groups.elements().clone();
249
250 let session = self.session.clone();
251 let kernels = &session.aggregate_fns().grouped_kernels;
252
253 for _ in 0..64 {
254 if elements.is::<AnyCanonical>() {
255 break;
256 }
257
258 let kernels_r = kernels.read();
259 if let Some(result) = kernels_r
260 .get(&(elements.encoding_id(), Some(self.aggregate_fn.id())))
261 .or_else(|| kernels_r.get(&(elements.encoding_id(), None)))
262 .and_then(|kernel| {
263 let groups = unsafe {
265 FixedSizeListArray::new_unchecked(
266 elements.clone(),
267 groups.list_size(),
268 groups.validity().clone(),
269 groups.len(),
270 )
271 };
272
273 kernel
274 .grouped_aggregate_fixed_size(&self.aggregate_fn, &groups)
275 .transpose()
276 })
277 .transpose()?
278 {
279 return self.push_result(result);
280 }
281
282 elements = elements.execute(ctx)?;
284 }
285
286 let elements = elements.execute::<Columnar>(ctx)?.into_array();
288 let validity = groups.validity().to_mask(groups.len());
289
290 let mut accumulator = Accumulator::try_new(
291 self.vtable.clone(),
292 self.options.clone(),
293 self.dtype.clone(),
294 self.session.clone(),
295 )?;
296 let mut states = builder_with_capacity(&self.partial_dtype, groups.len());
297
298 let mut offset = 0;
299 let size = groups
300 .list_size()
301 .to_usize()
302 .vortex_expect("List size is not usize");
303
304 for i in 0..groups.len() {
305 if validity.value(i) {
306 let group = elements.slice(offset..offset + size)?;
307 accumulator.accumulate(&group)?;
308 states.append_scalar(&accumulator.finish()?)?;
309 } else {
310 states.append_null()
311 }
312 offset += size;
313 }
314
315 self.push_result(states.finish())
316 }
317
318 fn push_result(&mut self, state: ArrayRef) -> VortexResult<()> {
319 vortex_ensure!(
320 state.dtype() == &self.partial_dtype,
321 "State DType mismatch: expected {}, got {}",
322 self.partial_dtype,
323 state.dtype()
324 );
325 self.partials.push(state);
326 Ok(())
327 }
328}