1use std::ops::BitAnd;
5
6use itertools::Itertools;
7use num_traits::ToPrimitive;
8use vortex_error::VortexExpect;
9use vortex_error::VortexResult;
10use vortex_error::vortex_bail;
11use vortex_error::vortex_err;
12use vortex_error::vortex_panic;
13use vortex_mask::AllOr;
14
15use crate::ArrayRef;
16use crate::Canonical;
17use crate::ExecutionCtx;
18use crate::aggregate_fn::AggregateFnId;
19use crate::aggregate_fn::AggregateFnVTable;
20use crate::aggregate_fn::EmptyOptions;
21use crate::arrays::BoolArray;
22use crate::arrays::DecimalArray;
23use crate::arrays::PrimitiveArray;
24use crate::dtype::DType;
25use crate::dtype::Nullability;
26use crate::dtype::PType;
27use crate::expr::stats::Stat;
28use crate::match_each_decimal_value_type;
29use crate::match_each_native_ptype;
30use crate::scalar::DecimalValue;
31use crate::scalar::Scalar;
32
/// The `vortex.sum` aggregate function.
///
/// A stateless marker type: all behavior lives in its `AggregateFnVTable` implementation.
#[derive(Debug, Clone)]
pub struct Sum;
35
36impl AggregateFnVTable for Sum {
37 type Options = EmptyOptions;
38 type Partial = SumPartial;
39
40 fn id(&self) -> AggregateFnId {
41 AggregateFnId::new_ref("vortex.sum")
42 }
43
44 fn return_dtype(&self, _options: &Self::Options, input_dtype: &DType) -> VortexResult<DType> {
45 Stat::Sum
46 .dtype(input_dtype)
47 .ok_or_else(|| vortex_err!("Cannot sum {}", input_dtype))
48 }
49
50 fn partial_dtype(&self, options: &Self::Options, input_dtype: &DType) -> VortexResult<DType> {
51 self.return_dtype(options, input_dtype)
52 }
53
54 fn empty_partial(
55 &self,
56 _options: &Self::Options,
57 input_dtype: &DType,
58 ) -> VortexResult<Self::Partial> {
59 let return_dtype = Stat::Sum
60 .dtype(input_dtype)
61 .ok_or_else(|| vortex_err!("Cannot sum {}", input_dtype))?;
62
63 let initial = make_zero_state(&return_dtype);
64
65 Ok(SumPartial {
66 return_dtype,
67 current: Some(initial),
68 })
69 }
70
71 fn combine_partials(&self, partial: &mut Self::Partial, other: Scalar) -> VortexResult<()> {
72 if other.is_null() {
73 partial.current = None;
75 return Ok(());
76 }
77 let Some(ref mut inner) = partial.current else {
78 return Ok(());
79 };
80 let saturated = match inner {
81 SumState::Unsigned(acc) => {
82 let val = other
83 .as_primitive()
84 .typed_value::<u64>()
85 .vortex_expect("checked non-null");
86 checked_add_u64(acc, val)
87 }
88 SumState::Signed(acc) => {
89 let val = other
90 .as_primitive()
91 .typed_value::<i64>()
92 .vortex_expect("checked non-null");
93 checked_add_i64(acc, val)
94 }
95 SumState::Float(acc) => {
96 let val = other
97 .as_primitive()
98 .typed_value::<f64>()
99 .vortex_expect("checked non-null");
100 *acc += val;
101 false
102 }
103 SumState::Decimal(acc) => {
104 let val = other
105 .as_decimal()
106 .decimal_value()
107 .vortex_expect("checked non-null");
108 match acc.checked_add(&val) {
109 Some(r) => {
110 *acc = r;
111 false
112 }
113 None => true,
114 }
115 }
116 };
117 if saturated {
118 partial.current = None;
119 }
120 Ok(())
121 }
122
123 fn flush(&self, partial: &mut Self::Partial) -> VortexResult<Scalar> {
124 let result = match &partial.current {
125 None => Scalar::null(partial.return_dtype.as_nullable()),
126 Some(SumState::Unsigned(v)) => Scalar::primitive(*v, Nullability::Nullable),
127 Some(SumState::Signed(v)) => Scalar::primitive(*v, Nullability::Nullable),
128 Some(SumState::Float(v)) => Scalar::primitive(*v, Nullability::Nullable),
129 Some(SumState::Decimal(v)) => {
130 let decimal_dtype = *partial
131 .return_dtype
132 .as_decimal_opt()
133 .vortex_expect("return dtype must be decimal");
134 Scalar::decimal(*v, decimal_dtype, Nullability::Nullable)
135 }
136 };
137
138 partial.current = Some(make_zero_state(&partial.return_dtype));
140
141 Ok(result)
142 }
143
144 #[inline]
145 fn is_saturated(&self, partial: &Self::Partial) -> bool {
146 partial.current.is_none()
147 }
148
149 fn accumulate(
150 &self,
151 partial: &mut Self::Partial,
152 batch: &Canonical,
153 _ctx: &mut ExecutionCtx,
154 ) -> VortexResult<()> {
155 let mut inner = match partial.current.take() {
156 Some(inner) => inner,
157 None => return Ok(()),
158 };
159
160 let result = match batch {
161 Canonical::Primitive(p) => accumulate_primitive(&mut inner, p),
162 Canonical::Bool(b) => accumulate_bool(&mut inner, b),
163 Canonical::Decimal(d) => accumulate_decimal(&mut inner, d),
164 _ => vortex_bail!("Unsupported canonical type for sum: {}", batch.dtype()),
165 };
166
167 match result {
168 Ok(false) => partial.current = Some(inner),
169 Ok(true) => {} Err(e) => {
171 partial.current = Some(inner);
172 return Err(e);
173 }
174 }
175 Ok(())
176 }
177
178 fn finalize(&self, partials: ArrayRef) -> VortexResult<ArrayRef> {
179 Ok(partials)
180 }
181
182 fn finalize_scalar(&self, partial: Scalar) -> VortexResult<Scalar> {
183 Ok(partial)
184 }
185}
186
187pub struct SumPartial {
190 return_dtype: DType,
191 current: Option<SumState>,
193}
194
/// A running sum total; the variant is chosen by `make_zero_state` from the sum's return dtype.
pub enum SumState {
    /// Total of unsigned integer inputs (boolean inputs also count into this variant).
    Unsigned(u64),
    /// Total of signed integer inputs.
    Signed(i64),
    /// Total of floating-point inputs.
    Float(f64),
    /// Total of decimal inputs.
    Decimal(DecimalValue),
}
205
206fn make_zero_state(return_dtype: &DType) -> SumState {
207 match return_dtype {
208 DType::Primitive(ptype, _) => match ptype {
209 PType::U8 | PType::U16 | PType::U32 | PType::U64 => SumState::Unsigned(0),
210 PType::I8 | PType::I16 | PType::I32 | PType::I64 => SumState::Signed(0),
211 PType::F16 | PType::F32 | PType::F64 => SumState::Float(0.0),
212 },
213 DType::Decimal(decimal, _) => SumState::Decimal(DecimalValue::zero(decimal)),
214 _ => vortex_panic!("Unsupported sum type"),
215 }
216}
217
/// Adds `val` to `*acc` with overflow checking.
///
/// Returns `true` if the addition overflowed, leaving `*acc` unchanged; otherwise stores the
/// new total and returns `false`.
#[inline(always)]
fn checked_add_u64(acc: &mut u64, val: u64) -> bool {
    let Some(total) = acc.checked_add(val) else {
        return true;
    };
    *acc = total;
    false
}
229
/// Adds `val` to `*acc` with overflow checking.
///
/// Returns `true` if the addition overflowed (in either direction), leaving `*acc` unchanged;
/// otherwise stores the new total and returns `false`.
#[inline(always)]
fn checked_add_i64(acc: &mut i64, val: i64) -> bool {
    if let Some(total) = acc.checked_add(val) {
        *acc = total;
        false
    } else {
        true
    }
}
241
242fn accumulate_primitive(inner: &mut SumState, p: &PrimitiveArray) -> VortexResult<bool> {
243 let mask = p.validity_mask()?;
244 match mask.bit_buffer() {
245 AllOr::None => Ok(false),
246 AllOr::All => accumulate_primitive_all(inner, p),
247 AllOr::Some(validity) => accumulate_primitive_valid(inner, p, validity),
248 }
249}
250
251fn accumulate_primitive_all(inner: &mut SumState, p: &PrimitiveArray) -> VortexResult<bool> {
252 match inner {
253 SumState::Unsigned(acc) => match_each_native_ptype!(p.ptype(),
254 unsigned: |T| {
255 for &v in p.as_slice::<T>() {
256 if checked_add_u64(acc, v.to_u64().vortex_expect("unsigned to u64")) {
257 return Ok(true);
258 }
259 }
260 Ok(false)
261 },
262 signed: |_T| { vortex_panic!("unsigned sum state with signed input") },
263 floating: |_T| { vortex_panic!("unsigned sum state with float input") }
264 ),
265 SumState::Signed(acc) => match_each_native_ptype!(p.ptype(),
266 unsigned: |_T| { vortex_panic!("signed sum state with unsigned input") },
267 signed: |T| {
268 for &v in p.as_slice::<T>() {
269 if checked_add_i64(acc, v.to_i64().vortex_expect("signed to i64")) {
270 return Ok(true);
271 }
272 }
273 Ok(false)
274 },
275 floating: |_T| { vortex_panic!("signed sum state with float input") }
276 ),
277 SumState::Float(acc) => match_each_native_ptype!(p.ptype(),
278 unsigned: |_T| { vortex_panic!("float sum state with unsigned input") },
279 signed: |_T| { vortex_panic!("float sum state with signed input") },
280 floating: |T| {
281 for &v in p.as_slice::<T>() {
282 *acc += ToPrimitive::to_f64(&v).vortex_expect("float to f64");
283 }
284 Ok(false)
285 }
286 ),
287 SumState::Decimal(_) => vortex_panic!("decimal sum state with primitive input"),
288 }
289}
290
291fn accumulate_primitive_valid(
292 inner: &mut SumState,
293 p: &PrimitiveArray,
294 validity: &vortex_buffer::BitBuffer,
295) -> VortexResult<bool> {
296 match inner {
297 SumState::Unsigned(acc) => match_each_native_ptype!(p.ptype(),
298 unsigned: |T| {
299 for (&v, valid) in p.as_slice::<T>().iter().zip_eq(validity.iter()) {
300 if valid && checked_add_u64(acc, v.to_u64().vortex_expect("unsigned to u64")) {
301 return Ok(true);
302 }
303 }
304 Ok(false)
305 },
306 signed: |_T| { vortex_panic!("unsigned sum state with signed input") },
307 floating: |_T| { vortex_panic!("unsigned sum state with float input") }
308 ),
309 SumState::Signed(acc) => match_each_native_ptype!(p.ptype(),
310 unsigned: |_T| { vortex_panic!("signed sum state with unsigned input") },
311 signed: |T| {
312 for (&v, valid) in p.as_slice::<T>().iter().zip_eq(validity.iter()) {
313 if valid && checked_add_i64(acc, v.to_i64().vortex_expect("signed to i64")) {
314 return Ok(true);
315 }
316 }
317 Ok(false)
318 },
319 floating: |_T| { vortex_panic!("signed sum state with float input") }
320 ),
321 SumState::Float(acc) => match_each_native_ptype!(p.ptype(),
322 unsigned: |_T| { vortex_panic!("float sum state with unsigned input") },
323 signed: |_T| { vortex_panic!("float sum state with signed input") },
324 floating: |T| {
325 for (&v, valid) in p.as_slice::<T>().iter().zip_eq(validity.iter()) {
326 if valid {
327 *acc += ToPrimitive::to_f64(&v).vortex_expect("float to f64");
328 }
329 }
330 Ok(false)
331 }
332 ),
333 SumState::Decimal(_) => vortex_panic!("decimal sum state with primitive input"),
334 }
335}
336
337fn accumulate_bool(inner: &mut SumState, b: &BoolArray) -> VortexResult<bool> {
338 let SumState::Unsigned(acc) = inner else {
339 vortex_panic!("expected unsigned sum state for bool input");
340 };
341
342 let mask = b.validity_mask()?;
343 let true_count = match mask.bit_buffer() {
344 AllOr::None => return Ok(false),
345 AllOr::All => b.to_bit_buffer().true_count() as u64,
346 AllOr::Some(validity) => b.to_bit_buffer().bitand(validity).true_count() as u64,
347 };
348
349 Ok(checked_add_u64(acc, true_count))
350}
351
352fn accumulate_decimal(inner: &mut SumState, d: &DecimalArray) -> VortexResult<bool> {
355 let SumState::Decimal(acc) = inner else {
356 vortex_panic!("expected decimal sum state for decimal input");
357 };
358
359 let mask = d.validity_mask()?;
360 match mask.bit_buffer() {
361 AllOr::None => Ok(false),
362 AllOr::All => match_each_decimal_value_type!(d.values_type(), |T| {
363 for &v in d.buffer::<T>().iter() {
364 match acc.checked_add(&DecimalValue::from(v)) {
365 Some(r) => *acc = r,
366 None => return Ok(true),
367 }
368 }
369 Ok(false)
370 }),
371 AllOr::Some(validity) => match_each_decimal_value_type!(d.values_type(), |T| {
372 for (&v, valid) in d.buffer::<T>().iter().zip_eq(validity.iter()) {
373 if valid {
374 match acc.checked_add(&DecimalValue::from(v)) {
375 Some(r) => *acc = r,
376 None => return Ok(true),
377 }
378 }
379 }
380 Ok(false)
381 }),
382 }
383}
384
// Unit tests for the `Sum` aggregate: ungrouped (`Accumulator`) and grouped
// (`GroupedAccumulator`) usage over primitive and boolean inputs, null handling,
// state reset on `finish`, partial merging and overflow saturation.
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;
    use vortex_session::VortexSession;

    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::aggregate_fn::Accumulator;
    use crate::aggregate_fn::AggregateFnVTable;
    use crate::aggregate_fn::DynAccumulator;
    use crate::aggregate_fn::DynGroupedAccumulator;
    use crate::aggregate_fn::EmptyOptions;
    use crate::aggregate_fn::GroupedAccumulator;
    use crate::aggregate_fn::fns::sum::Sum;
    use crate::arrays::BoolArray;
    use crate::arrays::FixedSizeListArray;
    use crate::arrays::PrimitiveArray;
    use crate::assert_arrays_eq;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::scalar::Scalar;
    use crate::validity::Validity;

    // Every accumulator in these tests is built on a fresh, empty session.
    fn session() -> VortexSession {
        VortexSession::empty()
    }

    // Sums a single batch with an ungrouped accumulator and returns the finished scalar.
    fn run_sum(batch: &ArrayRef) -> VortexResult<Scalar> {
        let mut acc = Accumulator::try_new(Sum, EmptyOptions, batch.dtype().clone(), session())?;
        acc.accumulate(batch)?;
        acc.finish()
    }

    // Signed integer input is summed into an `i64` result.
    #[test]
    fn sum_i32() -> VortexResult<()> {
        let arr = PrimitiveArray::new(buffer![1i32, 2, 3, 4], Validity::NonNullable).into_array();
        let result = run_sum(&arr)?;
        assert_eq!(result.as_primitive().typed_value::<i64>(), Some(10));
        Ok(())
    }

    // Unsigned integer input is summed into a `u64` result.
    #[test]
    fn sum_u8() -> VortexResult<()> {
        let arr = PrimitiveArray::new(buffer![10u8, 20, 30], Validity::NonNullable).into_array();
        let result = run_sum(&arr)?;
        assert_eq!(result.as_primitive().typed_value::<u64>(), Some(60));
        Ok(())
    }

    // Floating-point input is summed into an `f64` result.
    #[test]
    fn sum_f64() -> VortexResult<()> {
        let arr =
            PrimitiveArray::new(buffer![1.5f64, 2.5, 3.0], Validity::NonNullable).into_array();
        let result = run_sum(&arr)?;
        assert_eq!(result.as_primitive().typed_value::<f64>(), Some(7.0));
        Ok(())
    }

    // Null elements are skipped.
    #[test]
    fn sum_with_nulls() -> VortexResult<()> {
        let arr = PrimitiveArray::from_option_iter([Some(2i32), None, Some(4)]).into_array();
        let result = run_sum(&arr)?;
        assert_eq!(result.as_primitive().typed_value::<i64>(), Some(6));
        Ok(())
    }

    // An all-null input sums to zero rather than null.
    #[test]
    fn sum_all_null() -> VortexResult<()> {
        let arr = PrimitiveArray::from_option_iter([None::<i32>, None, None]).into_array();
        let result = run_sum(&arr)?;
        assert_eq!(result.as_primitive().typed_value::<i64>(), Some(0));
        Ok(())
    }

    // Finishing an accumulator that never saw input yields zero.
    #[test]
    fn sum_empty_produces_zero() -> VortexResult<()> {
        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let mut acc = Accumulator::try_new(Sum, EmptyOptions, dtype, session())?;
        let result = acc.finish()?;
        assert_eq!(result.as_primitive().typed_value::<i64>(), Some(0));
        Ok(())
    }

    // Same as above for a floating-point input dtype.
    #[test]
    fn sum_empty_f64_produces_zero() -> VortexResult<()> {
        let dtype = DType::Primitive(PType::F64, Nullability::NonNullable);
        let mut acc = Accumulator::try_new(Sum, EmptyOptions, dtype, session())?;
        let result = acc.finish()?;
        assert_eq!(result.as_primitive().typed_value::<f64>(), Some(0.0));
        Ok(())
    }

    // The total accumulates across several batches until `finish`.
    #[test]
    fn sum_multi_batch() -> VortexResult<()> {
        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let mut acc = Accumulator::try_new(Sum, EmptyOptions, dtype, session())?;

        let batch1 = PrimitiveArray::new(buffer![10i32, 20], Validity::NonNullable).into_array();
        acc.accumulate(&batch1)?;

        let batch2 = PrimitiveArray::new(buffer![3i32, 6, 9], Validity::NonNullable).into_array();
        acc.accumulate(&batch2)?;

        let result = acc.finish()?;
        assert_eq!(result.as_primitive().typed_value::<i64>(), Some(48));
        Ok(())
    }

    // `finish` resets the total, so the next batch starts again from zero.
    #[test]
    fn sum_finish_resets_state() -> VortexResult<()> {
        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let mut acc = Accumulator::try_new(Sum, EmptyOptions, dtype, session())?;

        let batch1 = PrimitiveArray::new(buffer![10i32, 20], Validity::NonNullable).into_array();
        acc.accumulate(&batch1)?;
        let result1 = acc.finish()?;
        assert_eq!(result1.as_primitive().typed_value::<i64>(), Some(30));

        let batch2 = PrimitiveArray::new(buffer![3i32, 6, 9], Validity::NonNullable).into_array();
        acc.accumulate(&batch2)?;
        let result2 = acc.finish()?;
        assert_eq!(result2.as_primitive().typed_value::<i64>(), Some(18));
        Ok(())
    }

    // Partial scalars merged with `combine_partials` are added together by `flush`.
    #[test]
    fn sum_state_merge() -> VortexResult<()> {
        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let mut state = Sum.empty_partial(&EmptyOptions, &dtype)?;

        let scalar1 = Scalar::primitive(100i64, Nullability::Nullable);
        Sum.combine_partials(&mut state, scalar1)?;

        let scalar2 = Scalar::primitive(50i64, Nullability::Nullable);
        Sum.combine_partials(&mut state, scalar2)?;

        let result = Sum.flush(&mut state)?;
        assert_eq!(result.as_primitive().typed_value::<i64>(), Some(150));
        Ok(())
    }

    // Integer overflow yields a null result instead of wrapping.
    #[test]
    fn sum_checked_overflow() -> VortexResult<()> {
        let arr = PrimitiveArray::new(buffer![i64::MAX, 1i64], Validity::NonNullable).into_array();
        let result = run_sum(&arr)?;
        assert!(result.is_null());
        Ok(())
    }

    // Overflow saturates the accumulator; `finish` clears the saturation.
    #[test]
    fn sum_checked_overflow_is_saturated() -> VortexResult<()> {
        let dtype = DType::Primitive(PType::I64, Nullability::NonNullable);
        let mut acc = Accumulator::try_new(Sum, EmptyOptions, dtype, session())?;
        assert!(!acc.is_saturated());

        let batch =
            PrimitiveArray::new(buffer![i64::MAX, 1i64], Validity::NonNullable).into_array();
        acc.accumulate(&batch)?;
        assert!(acc.is_saturated());

        drop(acc.finish()?);
        assert!(!acc.is_saturated());
        Ok(())
    }

    // Summing booleans counts the `true` values into a `u64`.
    #[test]
    fn sum_bool_all_true() -> VortexResult<()> {
        let arr: BoolArray = [true, true, true].into_iter().collect();
        let result = run_sum(&arr.into_array())?;
        assert_eq!(result.as_primitive().typed_value::<u64>(), Some(3));
        Ok(())
    }

    // `false` values contribute nothing to a boolean sum.
    #[test]
    fn sum_bool_mixed() -> VortexResult<()> {
        let arr: BoolArray = [true, false, true, false, true].into_iter().collect();
        let result = run_sum(&arr.into_array())?;
        assert_eq!(result.as_primitive().typed_value::<u64>(), Some(3));
        Ok(())
    }

    // An all-false input sums to zero.
    #[test]
    fn sum_bool_all_false() -> VortexResult<()> {
        let arr: BoolArray = [false, false, false].into_iter().collect();
        let result = run_sum(&arr.into_array())?;
        assert_eq!(result.as_primitive().typed_value::<u64>(), Some(0));
        Ok(())
    }

    // Null booleans are not counted.
    #[test]
    fn sum_bool_with_nulls() -> VortexResult<()> {
        let arr = BoolArray::from_iter([Some(true), None, Some(true), Some(false)]);
        let result = run_sum(&arr.into_array())?;
        assert_eq!(result.as_primitive().typed_value::<u64>(), Some(2));
        Ok(())
    }

    // An all-null boolean input sums to zero rather than null.
    #[test]
    fn sum_bool_all_null() -> VortexResult<()> {
        let arr = BoolArray::from_iter([None::<bool>, None, None]);
        let result = run_sum(&arr.into_array())?;
        assert_eq!(result.as_primitive().typed_value::<u64>(), Some(0));
        Ok(())
    }

    // Finishing a boolean accumulator that never saw input yields zero.
    #[test]
    fn sum_bool_empty_produces_zero() -> VortexResult<()> {
        let dtype = DType::Bool(Nullability::NonNullable);
        let mut acc = Accumulator::try_new(Sum, EmptyOptions, dtype, session())?;
        let result = acc.finish()?;
        assert_eq!(result.as_primitive().typed_value::<u64>(), Some(0));
        Ok(())
    }

    // `finish` resets a boolean sum back to zero.
    #[test]
    fn sum_bool_finish_resets_state() -> VortexResult<()> {
        let dtype = DType::Bool(Nullability::NonNullable);
        let mut acc = Accumulator::try_new(Sum, EmptyOptions, dtype, session())?;

        let batch1: BoolArray = [true, true, false].into_iter().collect();
        acc.accumulate(&batch1.into_array())?;
        let result1 = acc.finish()?;
        assert_eq!(result1.as_primitive().typed_value::<u64>(), Some(2));

        let batch2: BoolArray = [false, true].into_iter().collect();
        acc.accumulate(&batch2.into_array())?;
        let result2 = acc.finish()?;
        assert_eq!(result2.as_primitive().typed_value::<u64>(), Some(1));
        Ok(())
    }

    // The sum of a boolean input is typed as a nullable `u64`.
    #[test]
    fn sum_bool_return_dtype() -> VortexResult<()> {
        let dtype = Sum.return_dtype(&EmptyOptions, &DType::Bool(Nullability::NonNullable))?;
        assert_eq!(dtype, DType::Primitive(PType::U64, Nullability::Nullable));
        Ok(())
    }

    // Sums each list in `groups` with a grouped accumulator, one result per group.
    fn run_grouped_sum(groups: &ArrayRef, elem_dtype: &DType) -> VortexResult<ArrayRef> {
        let mut acc =
            GroupedAccumulator::try_new(Sum, EmptyOptions, elem_dtype.clone(), session())?;
        acc.accumulate_list(groups)?;
        acc.finish()
    }

    // Each fixed-size list group is summed independently (1+2+3 and 4+5+6).
    #[test]
    fn grouped_sum_fixed_size_list() -> VortexResult<()> {
        let elements =
            PrimitiveArray::new(buffer![1i32, 2, 3, 4, 5, 6], Validity::NonNullable).into_array();
        let groups = FixedSizeListArray::try_new(elements, 3, Validity::NonNullable, 2)?;

        let elem_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let result = run_grouped_sum(&groups.into_array(), &elem_dtype)?;

        let expected = PrimitiveArray::from_option_iter([Some(6i64), Some(15i64)]).into_array();
        assert_arrays_eq!(&result, &expected);
        Ok(())
    }

    // Null elements inside a group are skipped.
    #[test]
    fn grouped_sum_with_null_elements() -> VortexResult<()> {
        let elements =
            PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), None, Some(5), Some(6)])
                .into_array();
        let groups = FixedSizeListArray::try_new(elements, 3, Validity::NonNullable, 2)?;

        let elem_dtype = DType::Primitive(PType::I32, Nullability::Nullable);
        let result = run_grouped_sum(&groups.into_array(), &elem_dtype)?;

        let expected = PrimitiveArray::from_option_iter([Some(4i64), Some(11i64)]).into_array();
        assert_arrays_eq!(&result, &expected);
        Ok(())
    }

    // A null group yields a null result; the other groups are unaffected.
    #[test]
    fn grouped_sum_with_null_group() -> VortexResult<()> {
        let elements =
            PrimitiveArray::new(buffer![1i32, 2, 3, 4, 5, 6, 7, 8, 9], Validity::NonNullable)
                .into_array();
        let validity = Validity::from_iter([true, false, true]);
        let groups = FixedSizeListArray::try_new(elements, 3, validity, 3)?;

        let elem_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let result = run_grouped_sum(&groups.into_array(), &elem_dtype)?;

        let expected =
            PrimitiveArray::from_option_iter([Some(6i64), None, Some(24i64)]).into_array();
        assert_arrays_eq!(&result, &expected);
        Ok(())
    }

    // A (valid) group whose elements are all null sums to zero, not null.
    #[test]
    fn grouped_sum_all_null_elements_in_group() -> VortexResult<()> {
        let elements =
            PrimitiveArray::from_option_iter([None::<i32>, None, Some(3), Some(4)]).into_array();
        let groups = FixedSizeListArray::try_new(elements, 2, Validity::NonNullable, 2)?;

        let elem_dtype = DType::Primitive(PType::I32, Nullability::Nullable);
        let result = run_grouped_sum(&groups.into_array(), &elem_dtype)?;

        let expected = PrimitiveArray::from_option_iter([Some(0i64), Some(7i64)]).into_array();
        assert_arrays_eq!(&result, &expected);
        Ok(())
    }

    // Grouped boolean sums count the `true` values of each group.
    #[test]
    fn grouped_sum_bool() -> VortexResult<()> {
        let elements: BoolArray = [true, false, true, true, true, true].into_iter().collect();
        let groups =
            FixedSizeListArray::try_new(elements.into_array(), 3, Validity::NonNullable, 2)?;

        let elem_dtype = DType::Bool(Nullability::NonNullable);
        let result = run_grouped_sum(&groups.into_array(), &elem_dtype)?;

        let expected = PrimitiveArray::from_option_iter([Some(2u64), Some(3u64)]).into_array();
        assert_arrays_eq!(&result, &expected);
        Ok(())
    }

    // `finish` on a grouped accumulator resets it; later groups start from zero.
    #[test]
    fn grouped_sum_finish_resets() -> VortexResult<()> {
        let elem_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let mut acc = GroupedAccumulator::try_new(Sum, EmptyOptions, elem_dtype, session())?;

        let elements1 =
            PrimitiveArray::new(buffer![1i32, 2, 3, 4], Validity::NonNullable).into_array();
        let groups1 = FixedSizeListArray::try_new(elements1, 2, Validity::NonNullable, 2)?;
        acc.accumulate_list(&groups1.into_array())?;
        let result1 = acc.finish()?;

        let expected1 = PrimitiveArray::from_option_iter([Some(3i64), Some(7i64)]).into_array();
        assert_arrays_eq!(&result1, &expected1);

        let elements2 = PrimitiveArray::new(buffer![10i32, 20], Validity::NonNullable).into_array();
        let groups2 = FixedSizeListArray::try_new(elements2, 2, Validity::NonNullable, 1)?;
        acc.accumulate_list(&groups2.into_array())?;
        let result2 = acc.finish()?;

        let expected2 = PrimitiveArray::from_option_iter([Some(30i64)]).into_array();
        assert_arrays_eq!(&result2, &expected2);
        Ok(())
    }
}