1use std::ops::BitAnd;
5
6use itertools::Itertools;
7use num_traits::ToPrimitive;
8use vortex_error::VortexExpect;
9use vortex_error::VortexResult;
10use vortex_error::vortex_bail;
11use vortex_error::vortex_err;
12use vortex_error::vortex_panic;
13use vortex_mask::AllOr;
14
15use crate::ArrayRef;
16use crate::Canonical;
17use crate::Columnar;
18use crate::ExecutionCtx;
19use crate::aggregate_fn::AggregateFnId;
20use crate::aggregate_fn::AggregateFnVTable;
21use crate::aggregate_fn::EmptyOptions;
22use crate::arrays::BoolArray;
23use crate::arrays::ConstantArray;
24use crate::arrays::DecimalArray;
25use crate::arrays::PrimitiveArray;
26use crate::dtype::DType;
27use crate::dtype::Nullability;
28use crate::dtype::PType;
29use crate::expr::stats::Stat;
30use crate::match_each_decimal_value_type;
31use crate::match_each_native_ptype;
32use crate::scalar::DecimalValue;
33use crate::scalar::Scalar;
34
/// The `vortex.sum` aggregate function.
///
/// Sums primitive numeric values, booleans (counted as the number of `true` values) and
/// decimals. Null elements are skipped, so an all-null input sums to zero. If an integer or
/// decimal addition overflows, the partial saturates and the flushed result is null.
#[derive(Clone, Debug)]
pub struct Sum;
37
38impl AggregateFnVTable for Sum {
39 type Options = EmptyOptions;
40 type Partial = SumPartial;
41
42 fn id(&self) -> AggregateFnId {
43 AggregateFnId::new_ref("vortex.sum")
44 }
45
46 fn return_dtype(&self, _options: &Self::Options, input_dtype: &DType) -> VortexResult<DType> {
47 Stat::Sum
48 .dtype(input_dtype)
49 .ok_or_else(|| vortex_err!("Cannot sum {}", input_dtype))
50 }
51
52 fn partial_dtype(&self, options: &Self::Options, input_dtype: &DType) -> VortexResult<DType> {
53 self.return_dtype(options, input_dtype)
54 }
55
56 fn empty_partial(
57 &self,
58 _options: &Self::Options,
59 input_dtype: &DType,
60 ) -> VortexResult<Self::Partial> {
61 let return_dtype = Stat::Sum
62 .dtype(input_dtype)
63 .ok_or_else(|| vortex_err!("Cannot sum {}", input_dtype))?;
64
65 let initial = make_zero_state(&return_dtype);
66
67 Ok(SumPartial {
68 return_dtype,
69 current: Some(initial),
70 })
71 }
72
73 fn combine_partials(&self, partial: &mut Self::Partial, other: Scalar) -> VortexResult<()> {
74 if other.is_null() {
75 partial.current = None;
77 return Ok(());
78 }
79 let Some(ref mut inner) = partial.current else {
80 return Ok(());
81 };
82 let saturated = match inner {
83 SumState::Unsigned(acc) => {
84 let val = other
85 .as_primitive()
86 .typed_value::<u64>()
87 .vortex_expect("checked non-null");
88 checked_add_u64(acc, val)
89 }
90 SumState::Signed(acc) => {
91 let val = other
92 .as_primitive()
93 .typed_value::<i64>()
94 .vortex_expect("checked non-null");
95 checked_add_i64(acc, val)
96 }
97 SumState::Float(acc) => {
98 let val = other
99 .as_primitive()
100 .typed_value::<f64>()
101 .vortex_expect("checked non-null");
102 *acc += val;
103 false
104 }
105 SumState::Decimal(acc) => {
106 let val = other
107 .as_decimal()
108 .decimal_value()
109 .vortex_expect("checked non-null");
110 match acc.checked_add(&val) {
111 Some(r) => {
112 *acc = r;
113 false
114 }
115 None => true,
116 }
117 }
118 };
119 if saturated {
120 partial.current = None;
121 }
122 Ok(())
123 }
124
125 fn flush(&self, partial: &mut Self::Partial) -> VortexResult<Scalar> {
126 let result = match &partial.current {
127 None => Scalar::null(partial.return_dtype.as_nullable()),
128 Some(SumState::Unsigned(v)) => Scalar::primitive(*v, Nullability::Nullable),
129 Some(SumState::Signed(v)) => Scalar::primitive(*v, Nullability::Nullable),
130 Some(SumState::Float(v)) => Scalar::primitive(*v, Nullability::Nullable),
131 Some(SumState::Decimal(v)) => {
132 let decimal_dtype = *partial
133 .return_dtype
134 .as_decimal_opt()
135 .vortex_expect("return dtype must be decimal");
136 Scalar::decimal(*v, decimal_dtype, Nullability::Nullable)
137 }
138 };
139
140 partial.current = Some(make_zero_state(&partial.return_dtype));
142
143 Ok(result)
144 }
145
146 #[inline]
147 fn is_saturated(&self, partial: &Self::Partial) -> bool {
148 partial.current.is_none()
149 }
150
151 fn accumulate(
152 &self,
153 partial: &mut Self::Partial,
154 batch: &Columnar,
155 _ctx: &mut ExecutionCtx,
156 ) -> VortexResult<()> {
157 let mut inner = match partial.current.take() {
158 Some(inner) => inner,
159 None => return Ok(()),
160 };
161
162 let result = match batch {
163 Columnar::Canonical(c) => match c {
164 Canonical::Primitive(p) => accumulate_primitive(&mut inner, p),
165 Canonical::Bool(b) => accumulate_bool(&mut inner, b),
166 Canonical::Decimal(d) => accumulate_decimal(&mut inner, d),
167 _ => vortex_bail!("Unsupported canonical type for sum: {}", batch.dtype()),
168 },
169 Columnar::Constant(c) => accumulate_constant(&mut inner, c),
170 };
171
172 match result {
173 Ok(false) => partial.current = Some(inner),
174 Ok(true) => {} Err(e) => {
176 partial.current = Some(inner);
177 return Err(e);
178 }
179 }
180 Ok(())
181 }
182
183 fn finalize(&self, partials: ArrayRef) -> VortexResult<ArrayRef> {
184 Ok(partials)
185 }
186
187 fn finalize_scalar(&self, partial: Scalar) -> VortexResult<Scalar> {
188 Ok(partial)
189 }
190}
191
192pub struct SumPartial {
195 return_dtype: DType,
196 current: Option<SumState>,
198}
199
200pub enum SumState {
205 Unsigned(u64),
206 Signed(i64),
207 Float(f64),
208 Decimal(DecimalValue),
209}
210
211fn make_zero_state(return_dtype: &DType) -> SumState {
212 match return_dtype {
213 DType::Primitive(ptype, _) => match ptype {
214 PType::U8 | PType::U16 | PType::U32 | PType::U64 => SumState::Unsigned(0),
215 PType::I8 | PType::I16 | PType::I32 | PType::I64 => SumState::Signed(0),
216 PType::F16 | PType::F32 | PType::F64 => SumState::Float(0.0),
217 },
218 DType::Decimal(decimal, _) => SumState::Decimal(DecimalValue::zero(decimal)),
219 _ => vortex_panic!("Unsupported sum type"),
220 }
221}
222
/// Adds `val` into `acc` unless the addition would overflow.
///
/// Returns `true` when the addition overflowed (leaving `acc` unchanged). Returns `false`
/// when `acc` was updated.
#[inline(always)]
fn checked_add_u64(acc: &mut u64, val: u64) -> bool {
    if let Some(sum) = acc.checked_add(val) {
        *acc = sum;
        return false;
    }
    true
}
234
/// Adds `val` into `acc` unless the addition would overflow in either direction.
///
/// Returns `true` when the addition overflowed (leaving `acc` unchanged). Returns `false`
/// when `acc` was updated.
#[inline(always)]
fn checked_add_i64(acc: &mut i64, val: i64) -> bool {
    let Some(sum) = acc.checked_add(val) else {
        return true;
    };
    *acc = sum;
    false
}
246
247fn accumulate_primitive(inner: &mut SumState, p: &PrimitiveArray) -> VortexResult<bool> {
248 let mask = p.validity_mask()?;
249 match mask.bit_buffer() {
250 AllOr::None => Ok(false),
251 AllOr::All => accumulate_primitive_all(inner, p),
252 AllOr::Some(validity) => accumulate_primitive_valid(inner, p, validity),
253 }
254}
255
256fn accumulate_primitive_all(inner: &mut SumState, p: &PrimitiveArray) -> VortexResult<bool> {
257 match inner {
258 SumState::Unsigned(acc) => match_each_native_ptype!(p.ptype(),
259 unsigned: |T| {
260 for &v in p.as_slice::<T>() {
261 if checked_add_u64(acc, v.to_u64().vortex_expect("unsigned to u64")) {
262 return Ok(true);
263 }
264 }
265 Ok(false)
266 },
267 signed: |_T| { vortex_panic!("unsigned sum state with signed input") },
268 floating: |_T| { vortex_panic!("unsigned sum state with float input") }
269 ),
270 SumState::Signed(acc) => match_each_native_ptype!(p.ptype(),
271 unsigned: |_T| { vortex_panic!("signed sum state with unsigned input") },
272 signed: |T| {
273 for &v in p.as_slice::<T>() {
274 if checked_add_i64(acc, v.to_i64().vortex_expect("signed to i64")) {
275 return Ok(true);
276 }
277 }
278 Ok(false)
279 },
280 floating: |_T| { vortex_panic!("signed sum state with float input") }
281 ),
282 SumState::Float(acc) => match_each_native_ptype!(p.ptype(),
283 unsigned: |_T| { vortex_panic!("float sum state with unsigned input") },
284 signed: |_T| { vortex_panic!("float sum state with signed input") },
285 floating: |T| {
286 for &v in p.as_slice::<T>() {
287 *acc += ToPrimitive::to_f64(&v).vortex_expect("float to f64");
288 }
289 Ok(false)
290 }
291 ),
292 SumState::Decimal(_) => vortex_panic!("decimal sum state with primitive input"),
293 }
294}
295
296fn accumulate_primitive_valid(
297 inner: &mut SumState,
298 p: &PrimitiveArray,
299 validity: &vortex_buffer::BitBuffer,
300) -> VortexResult<bool> {
301 match inner {
302 SumState::Unsigned(acc) => match_each_native_ptype!(p.ptype(),
303 unsigned: |T| {
304 for (&v, valid) in p.as_slice::<T>().iter().zip_eq(validity.iter()) {
305 if valid && checked_add_u64(acc, v.to_u64().vortex_expect("unsigned to u64")) {
306 return Ok(true);
307 }
308 }
309 Ok(false)
310 },
311 signed: |_T| { vortex_panic!("unsigned sum state with signed input") },
312 floating: |_T| { vortex_panic!("unsigned sum state with float input") }
313 ),
314 SumState::Signed(acc) => match_each_native_ptype!(p.ptype(),
315 unsigned: |_T| { vortex_panic!("signed sum state with unsigned input") },
316 signed: |T| {
317 for (&v, valid) in p.as_slice::<T>().iter().zip_eq(validity.iter()) {
318 if valid && checked_add_i64(acc, v.to_i64().vortex_expect("signed to i64")) {
319 return Ok(true);
320 }
321 }
322 Ok(false)
323 },
324 floating: |_T| { vortex_panic!("signed sum state with float input") }
325 ),
326 SumState::Float(acc) => match_each_native_ptype!(p.ptype(),
327 unsigned: |_T| { vortex_panic!("float sum state with unsigned input") },
328 signed: |_T| { vortex_panic!("float sum state with signed input") },
329 floating: |T| {
330 for (&v, valid) in p.as_slice::<T>().iter().zip_eq(validity.iter()) {
331 if valid {
332 *acc += ToPrimitive::to_f64(&v).vortex_expect("float to f64");
333 }
334 }
335 Ok(false)
336 }
337 ),
338 SumState::Decimal(_) => vortex_panic!("decimal sum state with primitive input"),
339 }
340}
341
342fn accumulate_bool(inner: &mut SumState, b: &BoolArray) -> VortexResult<bool> {
343 let SumState::Unsigned(acc) = inner else {
344 vortex_panic!("expected unsigned sum state for bool input");
345 };
346
347 let mask = b.validity_mask()?;
348 let true_count = match mask.bit_buffer() {
349 AllOr::None => return Ok(false),
350 AllOr::All => b.to_bit_buffer().true_count() as u64,
351 AllOr::Some(validity) => b.to_bit_buffer().bitand(validity).true_count() as u64,
352 };
353
354 Ok(checked_add_u64(acc, true_count))
355}
356
357fn accumulate_constant(inner: &mut SumState, c: &ConstantArray) -> VortexResult<bool> {
361 let scalar = c.scalar();
362 if scalar.is_null() || c.is_empty() {
363 return Ok(false);
364 }
365 let len = c.len();
366
367 match scalar.dtype() {
368 DType::Bool(_) => {
369 let SumState::Unsigned(acc) = inner else {
370 vortex_panic!("expected unsigned sum state for bool input");
371 };
372 let val = scalar
373 .as_bool()
374 .value()
375 .ok_or_else(|| vortex_err!("Expected non-null bool scalar for sum"))?;
376 if val {
377 Ok(checked_add_u64(acc, len as u64))
378 } else {
379 Ok(false)
380 }
381 }
382 DType::Primitive(..) => {
383 let pvalue = scalar
384 .as_primitive()
385 .pvalue()
386 .ok_or_else(|| vortex_err!("Expected non-null primitive scalar for sum"))?;
387 match inner {
388 SumState::Unsigned(acc) => {
389 let val = pvalue.cast::<u64>()?;
390 match val.checked_mul(len as u64) {
391 Some(product) => Ok(checked_add_u64(acc, product)),
392 None => Ok(true),
393 }
394 }
395 SumState::Signed(acc) => {
396 let val = pvalue.cast::<i64>()?;
397 match i64::try_from(len).ok().and_then(|l| val.checked_mul(l)) {
398 Some(product) => Ok(checked_add_i64(acc, product)),
399 None => Ok(true),
400 }
401 }
402 SumState::Float(acc) => {
403 let val = pvalue.cast::<f64>()?;
404 *acc += val * len as f64;
405 Ok(false)
406 }
407 SumState::Decimal(_) => {
408 vortex_panic!("decimal sum state with primitive input")
409 }
410 }
411 }
412 DType::Decimal(..) => {
413 let SumState::Decimal(acc) = inner else {
414 vortex_panic!("expected decimal sum state for decimal input");
415 };
416 let val = scalar
417 .as_decimal()
418 .decimal_value()
419 .ok_or_else(|| vortex_err!("Expected non-null decimal scalar for sum"))?;
420 let len_decimal = DecimalValue::from(len as i128);
421 match val.checked_mul(&len_decimal) {
422 Some(product) => match acc.checked_add(&product) {
423 Some(r) => {
424 *acc = r;
425 Ok(false)
426 }
427 None => Ok(true),
428 },
429 None => Ok(true),
430 }
431 }
432 _ => vortex_bail!("Unsupported constant type for sum: {}", scalar.dtype()),
433 }
434}
435
436fn accumulate_decimal(inner: &mut SumState, d: &DecimalArray) -> VortexResult<bool> {
439 let SumState::Decimal(acc) = inner else {
440 vortex_panic!("expected decimal sum state for decimal input");
441 };
442
443 let mask = d.validity_mask()?;
444 match mask.bit_buffer() {
445 AllOr::None => Ok(false),
446 AllOr::All => match_each_decimal_value_type!(d.values_type(), |T| {
447 for &v in d.buffer::<T>().iter() {
448 match acc.checked_add(&DecimalValue::from(v)) {
449 Some(r) => *acc = r,
450 None => return Ok(true),
451 }
452 }
453 Ok(false)
454 }),
455 AllOr::Some(validity) => match_each_decimal_value_type!(d.values_type(), |T| {
456 for (&v, valid) in d.buffer::<T>().iter().zip_eq(validity.iter()) {
457 if valid {
458 match acc.checked_add(&DecimalValue::from(v)) {
459 Some(r) => *acc = r,
460 None => return Ok(true),
461 }
462 }
463 }
464 Ok(false)
465 }),
466 }
467}
468
//! Unit tests for the `Sum` aggregate: ungrouped accumulation over primitive and bool arrays,
//! partial merging, overflow saturation, and grouped accumulation over fixed-size lists.
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;
    use vortex_session::VortexSession;

    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::aggregate_fn::Accumulator;
    use crate::aggregate_fn::AggregateFnVTable;
    use crate::aggregate_fn::DynAccumulator;
    use crate::aggregate_fn::DynGroupedAccumulator;
    use crate::aggregate_fn::EmptyOptions;
    use crate::aggregate_fn::GroupedAccumulator;
    use crate::aggregate_fn::fns::sum::Sum;
    use crate::arrays::BoolArray;
    use crate::arrays::FixedSizeListArray;
    use crate::arrays::PrimitiveArray;
    use crate::assert_arrays_eq;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::scalar::Scalar;
    use crate::validity::Validity;

    /// Fresh, empty session for each accumulator under test.
    fn session() -> VortexSession {
        VortexSession::empty()
    }

    /// Sums a single batch with a new ungrouped accumulator and returns the finished scalar.
    fn run_sum(batch: &ArrayRef) -> VortexResult<Scalar> {
        let mut acc = Accumulator::try_new(Sum, EmptyOptions, batch.dtype().clone(), session())?;
        acc.accumulate(batch)?;
        acc.finish()
    }

    // ---- Ungrouped primitive sums ----

    // Signed inputs are widened to an i64 result.
    #[test]
    fn sum_i32() -> VortexResult<()> {
        let arr = PrimitiveArray::new(buffer![1i32, 2, 3, 4], Validity::NonNullable).into_array();
        let result = run_sum(&arr)?;
        assert_eq!(result.as_primitive().typed_value::<i64>(), Some(10));
        Ok(())
    }

    // Unsigned inputs are widened to a u64 result.
    #[test]
    fn sum_u8() -> VortexResult<()> {
        let arr = PrimitiveArray::new(buffer![10u8, 20, 30], Validity::NonNullable).into_array();
        let result = run_sum(&arr)?;
        assert_eq!(result.as_primitive().typed_value::<u64>(), Some(60));
        Ok(())
    }

    #[test]
    fn sum_f64() -> VortexResult<()> {
        let arr =
            PrimitiveArray::new(buffer![1.5f64, 2.5, 3.0], Validity::NonNullable).into_array();
        let result = run_sum(&arr)?;
        assert_eq!(result.as_primitive().typed_value::<f64>(), Some(7.0));
        Ok(())
    }

    // Null elements are skipped.
    #[test]
    fn sum_with_nulls() -> VortexResult<()> {
        let arr = PrimitiveArray::from_option_iter([Some(2i32), None, Some(4)]).into_array();
        let result = run_sum(&arr)?;
        assert_eq!(result.as_primitive().typed_value::<i64>(), Some(6));
        Ok(())
    }

    // An all-null input sums to zero, not null.
    #[test]
    fn sum_all_null() -> VortexResult<()> {
        let arr = PrimitiveArray::from_option_iter([None::<i32>, None, None]).into_array();
        let result = run_sum(&arr)?;
        assert_eq!(result.as_primitive().typed_value::<i64>(), Some(0));
        Ok(())
    }

    // Finishing without any input yields the zero of the result type.
    #[test]
    fn sum_empty_produces_zero() -> VortexResult<()> {
        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let mut acc = Accumulator::try_new(Sum, EmptyOptions, dtype, session())?;
        let result = acc.finish()?;
        assert_eq!(result.as_primitive().typed_value::<i64>(), Some(0));
        Ok(())
    }

    #[test]
    fn sum_empty_f64_produces_zero() -> VortexResult<()> {
        let dtype = DType::Primitive(PType::F64, Nullability::NonNullable);
        let mut acc = Accumulator::try_new(Sum, EmptyOptions, dtype, session())?;
        let result = acc.finish()?;
        assert_eq!(result.as_primitive().typed_value::<f64>(), Some(0.0));
        Ok(())
    }

    // Batches accumulate into one running total: 10 + 20 + 3 + 6 + 9.
    #[test]
    fn sum_multi_batch() -> VortexResult<()> {
        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let mut acc = Accumulator::try_new(Sum, EmptyOptions, dtype, session())?;

        let batch1 = PrimitiveArray::new(buffer![10i32, 20], Validity::NonNullable).into_array();
        acc.accumulate(&batch1)?;

        let batch2 = PrimitiveArray::new(buffer![3i32, 6, 9], Validity::NonNullable).into_array();
        acc.accumulate(&batch2)?;

        let result = acc.finish()?;
        assert_eq!(result.as_primitive().typed_value::<i64>(), Some(48));
        Ok(())
    }

    // `finish` resets the total, so the second result only covers batch2.
    #[test]
    fn sum_finish_resets_state() -> VortexResult<()> {
        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let mut acc = Accumulator::try_new(Sum, EmptyOptions, dtype, session())?;

        let batch1 = PrimitiveArray::new(buffer![10i32, 20], Validity::NonNullable).into_array();
        acc.accumulate(&batch1)?;
        let result1 = acc.finish()?;
        assert_eq!(result1.as_primitive().typed_value::<i64>(), Some(30));

        let batch2 = PrimitiveArray::new(buffer![3i32, 6, 9], Validity::NonNullable).into_array();
        acc.accumulate(&batch2)?;
        let result2 = acc.finish()?;
        assert_eq!(result2.as_primitive().typed_value::<i64>(), Some(18));
        Ok(())
    }

    // Partials (already in the i64 result dtype) are merged via `combine_partials`.
    #[test]
    fn sum_state_merge() -> VortexResult<()> {
        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let mut state = Sum.empty_partial(&EmptyOptions, &dtype)?;

        let scalar1 = Scalar::primitive(100i64, Nullability::Nullable);
        Sum.combine_partials(&mut state, scalar1)?;

        let scalar2 = Scalar::primitive(50i64, Nullability::Nullable);
        Sum.combine_partials(&mut state, scalar2)?;

        let result = Sum.flush(&mut state)?;
        assert_eq!(result.as_primitive().typed_value::<i64>(), Some(150));
        Ok(())
    }

    // ---- Overflow handling ----

    // Integer overflow is reported as a null result rather than wrapping or panicking.
    #[test]
    fn sum_checked_overflow() -> VortexResult<()> {
        let arr = PrimitiveArray::new(buffer![i64::MAX, 1i64], Validity::NonNullable).into_array();
        let result = run_sum(&arr)?;
        assert!(result.is_null());
        Ok(())
    }

    // Overflow marks the accumulator saturated; `finish` clears the saturation.
    #[test]
    fn sum_checked_overflow_is_saturated() -> VortexResult<()> {
        let dtype = DType::Primitive(PType::I64, Nullability::NonNullable);
        let mut acc = Accumulator::try_new(Sum, EmptyOptions, dtype, session())?;
        assert!(!acc.is_saturated());

        let batch =
            PrimitiveArray::new(buffer![i64::MAX, 1i64], Validity::NonNullable).into_array();
        acc.accumulate(&batch)?;
        assert!(acc.is_saturated());

        drop(acc.finish()?);
        assert!(!acc.is_saturated());
        Ok(())
    }

    // ---- Boolean sums (count of `true` values, as u64) ----

    #[test]
    fn sum_bool_all_true() -> VortexResult<()> {
        let arr: BoolArray = [true, true, true].into_iter().collect();
        let result = run_sum(&arr.into_array())?;
        assert_eq!(result.as_primitive().typed_value::<u64>(), Some(3));
        Ok(())
    }

    #[test]
    fn sum_bool_mixed() -> VortexResult<()> {
        let arr: BoolArray = [true, false, true, false, true].into_iter().collect();
        let result = run_sum(&arr.into_array())?;
        assert_eq!(result.as_primitive().typed_value::<u64>(), Some(3));
        Ok(())
    }

    #[test]
    fn sum_bool_all_false() -> VortexResult<()> {
        let arr: BoolArray = [false, false, false].into_iter().collect();
        let result = run_sum(&arr.into_array())?;
        assert_eq!(result.as_primitive().typed_value::<u64>(), Some(0));
        Ok(())
    }

    // Null slots are not counted even if their underlying bit might be set.
    #[test]
    fn sum_bool_with_nulls() -> VortexResult<()> {
        let arr = BoolArray::from_iter([Some(true), None, Some(true), Some(false)]);
        let result = run_sum(&arr.into_array())?;
        assert_eq!(result.as_primitive().typed_value::<u64>(), Some(2));
        Ok(())
    }

    #[test]
    fn sum_bool_all_null() -> VortexResult<()> {
        let arr = BoolArray::from_iter([None::<bool>, None, None]);
        let result = run_sum(&arr.into_array())?;
        assert_eq!(result.as_primitive().typed_value::<u64>(), Some(0));
        Ok(())
    }

    #[test]
    fn sum_bool_empty_produces_zero() -> VortexResult<()> {
        let dtype = DType::Bool(Nullability::NonNullable);
        let mut acc = Accumulator::try_new(Sum, EmptyOptions, dtype, session())?;
        let result = acc.finish()?;
        assert_eq!(result.as_primitive().typed_value::<u64>(), Some(0));
        Ok(())
    }

    #[test]
    fn sum_bool_finish_resets_state() -> VortexResult<()> {
        let dtype = DType::Bool(Nullability::NonNullable);
        let mut acc = Accumulator::try_new(Sum, EmptyOptions, dtype, session())?;

        let batch1: BoolArray = [true, true, false].into_iter().collect();
        acc.accumulate(&batch1.into_array())?;
        let result1 = acc.finish()?;
        assert_eq!(result1.as_primitive().typed_value::<u64>(), Some(2));

        let batch2: BoolArray = [false, true].into_iter().collect();
        acc.accumulate(&batch2.into_array())?;
        let result2 = acc.finish()?;
        assert_eq!(result2.as_primitive().typed_value::<u64>(), Some(1));
        Ok(())
    }

    // The result dtype of a bool sum is a nullable u64.
    #[test]
    fn sum_bool_return_dtype() -> VortexResult<()> {
        let dtype = Sum.return_dtype(&EmptyOptions, &DType::Bool(Nullability::NonNullable))?;
        assert_eq!(dtype, DType::Primitive(PType::U64, Nullability::Nullable));
        Ok(())
    }

    // ---- Grouped sums over list arrays ----

    /// Sums each list in `groups` with a new grouped accumulator and returns one value per group.
    fn run_grouped_sum(groups: &ArrayRef, elem_dtype: &DType) -> VortexResult<ArrayRef> {
        let mut acc =
            GroupedAccumulator::try_new(Sum, EmptyOptions, elem_dtype.clone(), session())?;
        acc.accumulate_list(groups)?;
        acc.finish()
    }

    // Two groups of three: [1, 2, 3] and [4, 5, 6].
    #[test]
    fn grouped_sum_fixed_size_list() -> VortexResult<()> {
        let elements =
            PrimitiveArray::new(buffer![1i32, 2, 3, 4, 5, 6], Validity::NonNullable).into_array();
        let groups = FixedSizeListArray::try_new(elements, 3, Validity::NonNullable, 2)?;

        let elem_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let result = run_grouped_sum(&groups.into_array(), &elem_dtype)?;

        let expected = PrimitiveArray::from_option_iter([Some(6i64), Some(15i64)]).into_array();
        assert_arrays_eq!(&result, &expected);
        Ok(())
    }

    // Null elements inside a group are skipped.
    #[test]
    fn grouped_sum_with_null_elements() -> VortexResult<()> {
        let elements =
            PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), None, Some(5), Some(6)])
                .into_array();
        let groups = FixedSizeListArray::try_new(elements, 3, Validity::NonNullable, 2)?;

        let elem_dtype = DType::Primitive(PType::I32, Nullability::Nullable);
        let result = run_grouped_sum(&groups.into_array(), &elem_dtype)?;

        let expected = PrimitiveArray::from_option_iter([Some(4i64), Some(11i64)]).into_array();
        assert_arrays_eq!(&result, &expected);
        Ok(())
    }

    // A null group (the list itself is null) produces a null output, unlike null elements.
    #[test]
    fn grouped_sum_with_null_group() -> VortexResult<()> {
        let elements =
            PrimitiveArray::new(buffer![1i32, 2, 3, 4, 5, 6, 7, 8, 9], Validity::NonNullable)
                .into_array();
        let validity = Validity::from_iter([true, false, true]);
        let groups = FixedSizeListArray::try_new(elements, 3, validity, 3)?;

        let elem_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let result = run_grouped_sum(&groups.into_array(), &elem_dtype)?;

        let expected =
            PrimitiveArray::from_option_iter([Some(6i64), None, Some(24i64)]).into_array();
        assert_arrays_eq!(&result, &expected);
        Ok(())
    }

    // A valid group whose elements are all null sums to zero, not null.
    #[test]
    fn grouped_sum_all_null_elements_in_group() -> VortexResult<()> {
        let elements =
            PrimitiveArray::from_option_iter([None::<i32>, None, Some(3), Some(4)]).into_array();
        let groups = FixedSizeListArray::try_new(elements, 2, Validity::NonNullable, 2)?;

        let elem_dtype = DType::Primitive(PType::I32, Nullability::Nullable);
        let result = run_grouped_sum(&groups.into_array(), &elem_dtype)?;

        let expected = PrimitiveArray::from_option_iter([Some(0i64), Some(7i64)]).into_array();
        assert_arrays_eq!(&result, &expected);
        Ok(())
    }

    // Grouped bool sums count `true` values per group.
    #[test]
    fn grouped_sum_bool() -> VortexResult<()> {
        let elements: BoolArray = [true, false, true, true, true, true].into_iter().collect();
        let groups =
            FixedSizeListArray::try_new(elements.into_array(), 3, Validity::NonNullable, 2)?;

        let elem_dtype = DType::Bool(Nullability::NonNullable);
        let result = run_grouped_sum(&groups.into_array(), &elem_dtype)?;

        let expected = PrimitiveArray::from_option_iter([Some(2u64), Some(3u64)]).into_array();
        assert_arrays_eq!(&result, &expected);
        Ok(())
    }

    // `finish` on a grouped accumulator resets it; the second run only sees groups2.
    #[test]
    fn grouped_sum_finish_resets() -> VortexResult<()> {
        let elem_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let mut acc = GroupedAccumulator::try_new(Sum, EmptyOptions, elem_dtype, session())?;

        let elements1 =
            PrimitiveArray::new(buffer![1i32, 2, 3, 4], Validity::NonNullable).into_array();
        let groups1 = FixedSizeListArray::try_new(elements1, 2, Validity::NonNullable, 2)?;
        acc.accumulate_list(&groups1.into_array())?;
        let result1 = acc.finish()?;

        let expected1 = PrimitiveArray::from_option_iter([Some(3i64), Some(7i64)]).into_array();
        assert_arrays_eq!(&result1, &expected1);

        let elements2 = PrimitiveArray::new(buffer![10i32, 20], Validity::NonNullable).into_array();
        let groups2 = FixedSizeListArray::try_new(elements2, 2, Validity::NonNullable, 1)?;
        acc.accumulate_list(&groups2.into_array())?;
        let result2 = acc.finish()?;

        let expected2 = PrimitiveArray::from_option_iter([Some(30i64)]).into_array();
        assert_arrays_eq!(&result2, &expected2);
        Ok(())
    }
}