vortex_array/aggregate_fn/
accumulator_grouped.rs1use arrow_buffer::ArrowNativeType;
5use vortex_buffer::Buffer;
6use vortex_error::VortexExpect;
7use vortex_error::VortexResult;
8use vortex_error::vortex_bail;
9use vortex_error::vortex_ensure;
10use vortex_error::vortex_panic;
11use vortex_mask::Mask;
12use vortex_session::VortexSession;
13
14use crate::AnyCanonical;
15use crate::ArrayRef;
16use crate::Canonical;
17use crate::DynArray;
18use crate::ExecutionCtx;
19use crate::IntoArray;
20use crate::VortexSessionExecute;
21use crate::aggregate_fn::Accumulator;
22use crate::aggregate_fn::AggregateFn;
23use crate::aggregate_fn::AggregateFnRef;
24use crate::aggregate_fn::AggregateFnVTable;
25use crate::aggregate_fn::DynAccumulator;
26use crate::aggregate_fn::session::AggregateFnSessionExt;
27use crate::arrays::ChunkedArray;
28use crate::arrays::FixedSizeListArray;
29use crate::arrays::ListViewArray;
30use crate::builders::builder_with_capacity;
31use crate::builtins::ArrayBuiltins;
32use crate::dtype::DType;
33use crate::dtype::IntegerPType;
34use crate::executor::MAX_ITERATIONS;
35use crate::match_each_integer_ptype;
36use crate::vtable::ValidityHelper;
37
38pub type GroupedAccumulatorRef = Box<dyn DynGroupedAccumulator>;
40
41pub struct GroupedAccumulator<V: AggregateFnVTable> {
46 vtable: V,
48 options: V::Options,
50 aggregate_fn: AggregateFnRef,
52 dtype: DType,
54 return_dtype: DType,
56 partial_dtype: DType,
58 partials: Vec<ArrayRef>,
60 session: VortexSession,
62}
63
64impl<V: AggregateFnVTable> GroupedAccumulator<V> {
65 pub fn try_new(
66 vtable: V,
67 options: V::Options,
68 dtype: DType,
69 session: VortexSession,
70 ) -> VortexResult<Self> {
71 let aggregate_fn = AggregateFn::new(vtable.clone(), options.clone()).erased();
72 let return_dtype = vtable.return_dtype(&options, &dtype)?;
73 let partial_dtype = vtable.partial_dtype(&options, &dtype)?;
74
75 Ok(Self {
76 vtable,
77 options,
78 aggregate_fn,
79 dtype,
80 return_dtype,
81 partial_dtype,
82 partials: vec![],
83 session,
84 })
85 }
86}
87
88pub trait DynGroupedAccumulator: 'static + Send {
91 fn accumulate_list(&mut self, groups: &ArrayRef) -> VortexResult<()>;
93
94 fn flush(&mut self) -> VortexResult<ArrayRef>;
97
98 fn finish(&mut self) -> VortexResult<ArrayRef>;
101}
102
103impl<V: AggregateFnVTable> DynGroupedAccumulator for GroupedAccumulator<V> {
104 fn accumulate_list(&mut self, groups: &ArrayRef) -> VortexResult<()> {
105 let elements_dtype = match groups.dtype() {
106 DType::List(elem, _) => elem,
107 DType::FixedSizeList(elem, ..) => elem,
108 _ => vortex_bail!(
109 "Input DType mismatch: expected List or FixedSizeList, got {}",
110 groups.dtype()
111 ),
112 };
113 vortex_ensure!(
114 elements_dtype.as_ref() == &self.dtype,
115 "Input DType mismatch: expected {}, got {}",
116 self.dtype,
117 elements_dtype
118 );
119
120 let mut ctx = self.session.create_execution_ctx();
121
122 match groups.clone().execute::<Canonical>(&mut ctx)? {
125 Canonical::List(groups) => self.accumulate_list_view(&groups, &mut ctx),
126 Canonical::FixedSizeList(groups) => self.accumulate_fixed_size_list(&groups, &mut ctx),
127 _ => vortex_panic!("We checked the DType above, so this should never happen"),
128 }
129 }
130
131 fn flush(&mut self) -> VortexResult<ArrayRef> {
132 let states = std::mem::take(&mut self.partials);
133 Ok(ChunkedArray::try_new(states, self.partial_dtype.clone())?.into_array())
134 }
135
136 fn finish(&mut self) -> VortexResult<ArrayRef> {
137 let states = self.flush()?;
138 let results = self.vtable.finalize(states)?;
139
140 vortex_ensure!(
141 results.dtype() == &self.return_dtype,
142 "Return DType mismatch: expected {}, got {}",
143 self.return_dtype,
144 results.dtype()
145 );
146
147 Ok(results)
148 }
149}
150
151impl<V: AggregateFnVTable> GroupedAccumulator<V> {
152 fn accumulate_list_view(
153 &mut self,
154 groups: &ListViewArray,
155 ctx: &mut ExecutionCtx,
156 ) -> VortexResult<()> {
157 let mut elements = groups.elements().clone();
158 let session = self.session.clone();
159
160 let kernels = &session.aggregate_fns().grouped_kernels;
161
162 for _ in 0..*MAX_ITERATIONS {
163 if elements.is::<AnyCanonical>() {
164 break;
165 }
166
167 let kernels_r = kernels.read();
168 if let Some(result) = kernels_r
169 .get(&(elements.encoding_id(), Some(self.aggregate_fn.id())))
170 .or_else(|| kernels_r.get(&(elements.encoding_id(), None)))
171 .and_then(|kernel| {
172 let groups = unsafe {
174 ListViewArray::new_unchecked(
175 elements.clone(),
176 groups.offsets().clone(),
177 groups.sizes().clone(),
178 groups.validity().clone(),
179 )
180 };
181 kernel
182 .grouped_aggregate(&self.aggregate_fn, &groups)
183 .transpose()
184 })
185 .transpose()?
186 {
187 return self.push_result(result);
188 }
189
190 elements = elements.execute(ctx)?;
192 }
193
194 let elements = elements.execute::<Canonical>(ctx)?.into_array();
196 let offsets = groups.offsets();
197 let sizes = groups.sizes().cast(offsets.dtype().clone())?;
198 let validity = groups.validity().to_mask(offsets.len());
199
200 match_each_integer_ptype!(offsets.dtype().as_ptype(), |O| {
201 let offsets = offsets.clone().execute::<Buffer<O>>(ctx)?;
202 let sizes = sizes.execute::<Buffer<O>>(ctx)?;
203 self.accumulate_list_view_typed(&elements, offsets.as_ref(), sizes.as_ref(), &validity)
204 })
205 }
206
207 fn accumulate_list_view_typed<O: IntegerPType>(
208 &mut self,
209 elements: &ArrayRef,
210 offsets: &[O],
211 sizes: &[O],
212 validity: &Mask,
213 ) -> VortexResult<()> {
214 let mut accumulator = Accumulator::try_new(
215 self.vtable.clone(),
216 self.options.clone(),
217 self.dtype.clone(),
218 self.session.clone(),
219 )?;
220 let mut states = builder_with_capacity(&self.partial_dtype, offsets.len());
221
222 for (offset, size) in offsets.iter().zip(sizes.iter()) {
223 let offset = offset.to_usize().vortex_expect("Offset value is not usize");
224 let size = size.to_usize().vortex_expect("Size value is not usize");
225
226 if validity.value(offset) {
227 let group = elements.slice(offset..offset + size)?;
228 accumulator.accumulate(&group)?;
229 states.append_scalar(&accumulator.finish()?)?;
230 } else {
231 states.append_null()
232 }
233 }
234
235 self.push_result(states.finish())
236 }
237
238 fn accumulate_fixed_size_list(
239 &mut self,
240 groups: &FixedSizeListArray,
241 ctx: &mut ExecutionCtx,
242 ) -> VortexResult<()> {
243 let mut elements = groups.elements().clone();
244
245 let session = self.session.clone();
246 let kernels = &session.aggregate_fns().grouped_kernels;
247
248 for _ in 0..64 {
249 if elements.is::<AnyCanonical>() {
250 break;
251 }
252
253 let kernels_r = kernels.read();
254 if let Some(result) = kernels_r
255 .get(&(elements.encoding_id(), Some(self.aggregate_fn.id())))
256 .or_else(|| kernels_r.get(&(elements.encoding_id(), None)))
257 .and_then(|kernel| {
258 let groups = unsafe {
260 FixedSizeListArray::new_unchecked(
261 elements.clone(),
262 groups.list_size(),
263 groups.validity().clone(),
264 groups.len(),
265 )
266 };
267
268 kernel
269 .grouped_aggregate_fixed_size(&self.aggregate_fn, &groups)
270 .transpose()
271 })
272 .transpose()?
273 {
274 return self.push_result(result);
275 }
276
277 elements = elements.execute(ctx)?;
279 }
280
281 let elements = elements.execute::<Canonical>(ctx)?.into_array();
283 let validity = groups.validity().to_mask(groups.len());
284
285 let mut accumulator = Accumulator::try_new(
286 self.vtable.clone(),
287 self.options.clone(),
288 self.dtype.clone(),
289 self.session.clone(),
290 )?;
291 let mut states = builder_with_capacity(&self.partial_dtype, groups.len());
292
293 let mut offset = 0;
294 let size = groups
295 .list_size()
296 .to_usize()
297 .vortex_expect("List size is not usize");
298
299 for i in 0..groups.len() {
300 if validity.value(i) {
301 let group = elements.slice(offset..offset + size)?;
302 accumulator.accumulate(&group)?;
303 states.append_scalar(&accumulator.finish()?)?;
304 } else {
305 states.append_null()
306 }
307 offset += size;
308 }
309
310 self.push_result(states.finish())
311 }
312
313 fn push_result(&mut self, state: ArrayRef) -> VortexResult<()> {
314 vortex_ensure!(
315 state.dtype() == &self.partial_dtype,
316 "State DType mismatch: expected {}, got {}",
317 self.partial_dtype,
318 state.dtype()
319 );
320 self.partials.push(state);
321 Ok(())
322 }
323}