1mod bool;
5mod constant;
6mod decimal;
7mod grouped;
8mod primitive;
9pub(crate) use grouped::PrimitiveGroupedSumEncodingKernel;
10use vortex_error::VortexExpect;
11use vortex_error::VortexResult;
12use vortex_error::vortex_bail;
13use vortex_error::vortex_err;
14use vortex_error::vortex_panic;
15use vortex_session::VortexSession;
16use vortex_session::registry::CachedId;
17
18use self::bool::accumulate_bool;
19use self::constant::multiply_constant;
20use self::decimal::accumulate_decimal;
21use self::primitive::accumulate_primitive;
22use crate::ArrayRef;
23use crate::Canonical;
24use crate::Columnar;
25use crate::ExecutionCtx;
26use crate::aggregate_fn::Accumulator;
27use crate::aggregate_fn::AggregateFnId;
28use crate::aggregate_fn::AggregateFnVTable;
29use crate::aggregate_fn::DynAccumulator;
30use crate::aggregate_fn::NumericalAggregateOpts;
31use crate::dtype::DType;
32use crate::dtype::DecimalDType;
33use crate::dtype::MAX_PRECISION;
34use crate::dtype::Nullability;
35use crate::dtype::PType;
36use crate::expr::stats::Precision;
37use crate::expr::stats::Stat;
38use crate::expr::stats::StatsProvider;
39use crate::expr::stats::StatsProviderExt;
40use crate::scalar::DecimalValue;
41use crate::scalar::Scalar;
42
43pub fn sum(array: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<Scalar> {
47 if let Precision::Exact(sum_scalar) = array.statistics().get(Stat::Sum) {
49 return Ok(sum_scalar);
50 }
51
52 let mut acc = Accumulator::try_new(
55 Sum,
56 NumericalAggregateOpts::default(),
57 array.dtype().clone(),
58 )?;
59 acc.accumulate(array, ctx)?;
60 let result = acc.finish()?;
61
62 if let Some(val) = result.value().cloned() {
64 array.statistics().set(Stat::Sum, Precision::Exact(val));
65 }
66
67 Ok(result)
68}
69
/// Aggregate function summing boolean, primitive, or decimal values.
///
/// The widened result type for each input dtype is defined by `return_dtype` in this type's
/// [`AggregateFnVTable`] implementation.
#[derive(Debug, Clone)]
pub struct Sum;
79
80impl AggregateFnVTable for Sum {
81 type Options = NumericalAggregateOpts;
82 type Partial = SumPartial;
83
84 fn id(&self) -> AggregateFnId {
85 static ID: CachedId = CachedId::new("vortex.sum");
86 *ID
87 }
88
89 fn serialize(&self, options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
90 Ok(Some(options.serialize()))
91 }
92
93 fn deserialize(
94 &self,
95 metadata: &[u8],
96 _session: &VortexSession,
97 ) -> VortexResult<Self::Options> {
98 NumericalAggregateOpts::deserialize(metadata)
99 }
100
101 fn return_dtype(&self, _options: &Self::Options, input_dtype: &DType) -> Option<DType> {
102 use Nullability::Nullable;
105
106 Some(match input_dtype {
107 DType::Bool(_) => DType::Primitive(PType::U64, Nullable),
108 DType::Primitive(ptype, _) => match ptype {
109 PType::U8 | PType::U16 | PType::U32 | PType::U64 => {
110 DType::Primitive(PType::U64, Nullable)
111 }
112 PType::I8 | PType::I16 | PType::I32 | PType::I64 => {
113 DType::Primitive(PType::I64, Nullable)
114 }
115 PType::F16 | PType::F32 | PType::F64 => {
116 DType::Primitive(PType::F64, Nullable)
118 }
119 },
120 DType::Decimal(decimal_dtype, _) => {
121 let precision = u8::min(MAX_PRECISION, decimal_dtype.precision() + 10);
125 DType::Decimal(
126 DecimalDType::new(precision, decimal_dtype.scale()),
127 Nullable,
128 )
129 }
130 _ => return None,
132 })
133 }
134
135 fn partial_dtype(&self, options: &Self::Options, input_dtype: &DType) -> Option<DType> {
136 self.return_dtype(options, input_dtype)
137 }
138
139 fn empty_partial(
140 &self,
141 options: &Self::Options,
142 input_dtype: &DType,
143 ) -> VortexResult<Self::Partial> {
144 let return_dtype = self
145 .return_dtype(options, input_dtype)
146 .ok_or_else(|| vortex_err!("Unsupported sum dtype: {}", input_dtype))?;
147 let initial = make_zero_state(&return_dtype);
148
149 Ok(SumPartial {
150 return_dtype,
151 current: Some(initial),
152 skip_nans: options.skip_nans,
153 })
154 }
155
156 fn combine_partials(&self, partial: &mut Self::Partial, other: Scalar) -> VortexResult<()> {
157 if other.is_null() {
158 partial.current = None;
160 return Ok(());
161 }
162 let Some(ref mut inner) = partial.current else {
163 return Ok(());
164 };
165 let saturated = match inner {
166 SumState::Unsigned(acc) => {
167 let val = other
168 .as_primitive()
169 .typed_value::<u64>()
170 .vortex_expect("checked non-null");
171 checked_add_u64(acc, val)
172 }
173 SumState::Signed(acc) => {
174 let val = other
175 .as_primitive()
176 .typed_value::<i64>()
177 .vortex_expect("checked non-null");
178 checked_add_i64(acc, val)
179 }
180 SumState::Float(acc) => {
181 let val = other
182 .as_primitive()
183 .typed_value::<f64>()
184 .vortex_expect("checked non-null");
185 *acc += val;
186 false
187 }
188 SumState::Decimal { value, dtype } => {
189 let val = other
190 .as_decimal()
191 .decimal_value()
192 .vortex_expect("checked non-null");
193 match value.checked_add(&val) {
194 Some(r) => {
195 *value = r;
196 !value.fits_in_precision(*dtype)
197 }
198 None => true,
199 }
200 }
201 };
202 if saturated {
203 partial.current = None;
204 }
205 Ok(())
206 }
207
208 fn to_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
209 Ok(match &partial.current {
210 None => Scalar::null(partial.return_dtype.as_nullable()),
211 Some(SumState::Unsigned(v)) => Scalar::primitive(*v, Nullability::Nullable),
212 Some(SumState::Signed(v)) => Scalar::primitive(*v, Nullability::Nullable),
213 Some(SumState::Float(v)) => Scalar::primitive(*v, Nullability::Nullable),
214 Some(SumState::Decimal { value, .. }) => {
215 let decimal_dtype = *partial
216 .return_dtype
217 .as_decimal_opt()
218 .vortex_expect("return dtype must be decimal");
219 Scalar::decimal(*value, decimal_dtype, Nullability::Nullable)
220 }
221 })
222 }
223
224 fn reset(&self, partial: &mut Self::Partial) {
225 partial.current = Some(make_zero_state(&partial.return_dtype));
226 }
227
228 #[inline]
229 fn is_saturated(&self, partial: &Self::Partial) -> bool {
230 match partial.current.as_ref() {
231 None => true,
232 Some(SumState::Float(v)) => v.is_nan(),
233 Some(_) => false,
234 }
235 }
236
237 fn try_accumulate(
238 &self,
239 partial: &mut Self::Partial,
240 batch: &ArrayRef,
241 _ctx: &mut ExecutionCtx,
242 ) -> VortexResult<bool> {
243 if partial.skip_nans || !matches!(partial.current, Some(SumState::Float(_))) {
246 return Ok(false);
247 }
248 match batch.statistics().get_as::<u64>(Stat::NaNCount) {
249 Precision::Exact(0) => {
250 if let Precision::Exact(sum) = batch.statistics().get(Stat::Sum) {
253 let sum = if sum.dtype() == &partial.return_dtype {
254 sum
255 } else {
256 sum.cast(&partial.return_dtype)?
257 };
258 self.combine_partials(partial, sum)?;
259 return Ok(true);
260 }
261 Ok(false)
262 }
263 Precision::Exact(_) => {
264 if let Some(SumState::Float(acc)) = partial.current.as_mut() {
266 *acc = f64::NAN;
267 }
268 Ok(true)
269 }
270 _ => Ok(false),
271 }
272 }
273
274 fn accumulate(
275 &self,
276 partial: &mut Self::Partial,
277 batch: &Columnar,
278 ctx: &mut ExecutionCtx,
279 ) -> VortexResult<()> {
280 if let Columnar::Constant(c) = batch {
282 if partial.skip_nans && c.scalar().as_primitive_opt().is_some_and(|p| p.is_nan()) {
284 return Ok(());
285 }
286 if let Some(product) = multiply_constant(c.scalar(), c.len(), &partial.return_dtype)? {
287 self.combine_partials(partial, product)?;
288 }
289 return Ok(());
290 }
291
292 let skip_nans = partial.skip_nans;
293 let mut inner = match partial.current.take() {
294 Some(inner) => inner,
295 None => return Ok(()),
296 };
297
298 let result = match batch {
299 Columnar::Canonical(c) => match c {
300 Canonical::Primitive(p) => accumulate_primitive(&mut inner, p, ctx, skip_nans),
301 Canonical::Bool(b) => accumulate_bool(&mut inner, b, ctx),
302 Canonical::Decimal(d) => accumulate_decimal(&mut inner, d, ctx),
303 _ => vortex_bail!("Unsupported canonical type for sum: {}", batch.dtype()),
304 },
305 Columnar::Constant(_) => unreachable!(),
306 };
307
308 match result {
309 Ok(false) => partial.current = Some(inner),
310 Ok(true) => {} Err(e) => {
312 partial.current = Some(inner);
313 return Err(e);
314 }
315 }
316 Ok(())
317 }
318
319 fn finalize(&self, partials: ArrayRef) -> VortexResult<ArrayRef> {
320 Ok(partials)
321 }
322
323 fn finalize_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
324 self.to_scalar(partial)
325 }
326}
327
328pub struct SumPartial {
331 return_dtype: DType,
332 current: Option<SumState>,
334 skip_nans: bool,
336}
337
338pub enum SumState {
342 Unsigned(u64),
343 Signed(i64),
344 Float(f64),
345 Decimal {
346 value: DecimalValue,
347 dtype: DecimalDType,
348 },
349}
350
351fn make_zero_state(return_dtype: &DType) -> SumState {
352 match return_dtype {
353 DType::Primitive(ptype, _) => match ptype {
354 PType::U8 | PType::U16 | PType::U32 | PType::U64 => SumState::Unsigned(0),
355 PType::I8 | PType::I16 | PType::I32 | PType::I64 => SumState::Signed(0),
356 PType::F16 | PType::F32 | PType::F64 => SumState::Float(0.0),
357 },
358 DType::Decimal(decimal, _) => SumState::Decimal {
359 value: DecimalValue::zero(decimal),
360 dtype: *decimal,
361 },
362 _ => vortex_panic!("Unsupported sum type"),
363 }
364}
365
/// Adds `val` to `*acc` in place and reports whether the addition overflowed.
///
/// Returns `true` on overflow, in which case `*acc` is left unchanged; returns `false` after
/// storing the new total otherwise.
#[inline(always)]
fn checked_add_u64(acc: &mut u64, val: u64) -> bool {
    let Some(total) = acc.checked_add(val) else {
        return true;
    };
    *acc = total;
    false
}
377
/// Adds `val` to `*acc` in place and reports whether the addition overflowed.
///
/// Returns `true` on overflow in either direction, in which case `*acc` is left unchanged;
/// returns `false` after storing the new total otherwise.
#[inline(always)]
fn checked_add_i64(acc: &mut i64, val: i64) -> bool {
    let Some(total) = acc.checked_add(val) else {
        return true;
    };
    *acc = total;
    false
}
389
390#[cfg(test)]
391mod tests {
392 use num_traits::CheckedAdd;
393 use vortex_buffer::buffer;
394 use vortex_error::VortexExpect;
395 use vortex_error::VortexResult;
396
397 use crate::ArrayRef;
398 use crate::IntoArray;
399 use crate::VortexSessionExecute;
400 use crate::aggregate_fn::Accumulator;
401 use crate::aggregate_fn::AggregateFnVTable;
402 use crate::aggregate_fn::DynAccumulator;
403 use crate::aggregate_fn::DynGroupedAccumulator;
404 use crate::aggregate_fn::GroupedAccumulator;
405 use crate::aggregate_fn::NumericalAggregateOpts;
406 use crate::aggregate_fn::fns::sum::Sum;
407 use crate::aggregate_fn::fns::sum::sum;
408 use crate::array_session;
409 use crate::arrays::BoolArray;
410 use crate::arrays::ChunkedArray;
411 use crate::arrays::ConstantArray;
412 use crate::arrays::DecimalArray;
413 use crate::arrays::FixedSizeListArray;
414 use crate::arrays::ListViewArray;
415 use crate::arrays::PrimitiveArray;
416 use crate::assert_arrays_eq;
417 use crate::dtype::DType;
418 use crate::dtype::DecimalDType;
419 use crate::dtype::Nullability;
420 use crate::dtype::Nullability::Nullable;
421 use crate::dtype::PType;
422 use crate::dtype::i256;
423 use crate::expr::stats::Precision;
424 use crate::expr::stats::Stat;
425 use crate::expr::stats::StatsProvider;
426 use crate::scalar::DecimalValue;
427 use crate::scalar::NumericOperator;
428 use crate::scalar::Scalar;
429 use crate::validity::Validity;
430
431 fn sum_with_accumulator(array: &ArrayRef, accumulator: &Scalar) -> VortexResult<Scalar> {
433 let mut ctx = array_session().create_execution_ctx();
434 if accumulator.is_null() {
435 return Ok(accumulator.clone());
436 }
437 if accumulator.is_zero() == Some(true) {
438 return sum(array, &mut ctx);
439 }
440
441 let sum_dtype = Stat::Sum.dtype(array.dtype()).ok_or_else(|| {
442 vortex_error::vortex_err!("Sum not supported for dtype: {}", array.dtype())
443 })?;
444
445 if !matches!(&sum_dtype, DType::Primitive(p, _) if p.is_float())
447 && let Precision::Exact(sum_scalar) = array.statistics().get(Stat::Sum)
448 {
449 return add_scalars(&sum_dtype, &sum_scalar, accumulator);
450 }
451
452 let array_sum = sum(array, &mut ctx)?;
454
455 add_scalars(&sum_dtype, &array_sum, accumulator)
457 }
458
459 fn add_scalars(sum_dtype: &DType, lhs: &Scalar, rhs: &Scalar) -> VortexResult<Scalar> {
461 if lhs.is_null() || rhs.is_null() {
462 return Ok(Scalar::null(sum_dtype.as_nullable()));
463 }
464
465 Ok(match sum_dtype {
466 DType::Primitive(ptype, _) if ptype.is_float() => {
467 let lhs_val = f64::try_from(lhs)?;
468 let rhs_val = f64::try_from(rhs)?;
469 Scalar::primitive(lhs_val + rhs_val, Nullable)
470 }
471 DType::Primitive(..) => lhs
472 .as_primitive()
473 .checked_add(&rhs.as_primitive())
474 .map(Scalar::from)
475 .unwrap_or_else(|| Scalar::null(sum_dtype.as_nullable())),
476 DType::Decimal(..) => lhs
477 .as_decimal()
478 .checked_binary_numeric(&rhs.as_decimal(), NumericOperator::Add)
479 .map(Scalar::from)
480 .unwrap_or_else(|| Scalar::null(sum_dtype.as_nullable())),
481 _ => unreachable!("Sum will always be a decimal or a primitive dtype"),
482 })
483 }
484
485 #[test]
488 fn sum_multi_batch() -> VortexResult<()> {
489 let mut ctx = array_session().create_execution_ctx();
490 let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
491 let mut acc = Accumulator::try_new(Sum, NumericalAggregateOpts::default(), dtype)?;
492
493 let batch1 = PrimitiveArray::new(buffer![10i32, 20], Validity::NonNullable).into_array();
494 acc.accumulate(&batch1, &mut ctx)?;
495
496 let batch2 = PrimitiveArray::new(buffer![3i32, 6, 9], Validity::NonNullable).into_array();
497 acc.accumulate(&batch2, &mut ctx)?;
498
499 let result = acc.finish()?;
500 assert_eq!(result.as_primitive().typed_value::<i64>(), Some(48));
501 Ok(())
502 }
503
504 #[test]
505 fn sum_finish_resets_state() -> VortexResult<()> {
506 let mut ctx = array_session().create_execution_ctx();
507 let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
508 let mut acc = Accumulator::try_new(Sum, NumericalAggregateOpts::default(), dtype)?;
509
510 let batch1 = PrimitiveArray::new(buffer![10i32, 20], Validity::NonNullable).into_array();
511 acc.accumulate(&batch1, &mut ctx)?;
512 let result1 = acc.finish()?;
513 assert_eq!(result1.as_primitive().typed_value::<i64>(), Some(30));
514
515 let batch2 = PrimitiveArray::new(buffer![3i32, 6, 9], Validity::NonNullable).into_array();
516 acc.accumulate(&batch2, &mut ctx)?;
517 let result2 = acc.finish()?;
518 assert_eq!(result2.as_primitive().typed_value::<i64>(), Some(18));
519 Ok(())
520 }
521
522 #[test]
525 fn sum_state_merge() -> VortexResult<()> {
526 let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
527 let mut state = Sum.empty_partial(&NumericalAggregateOpts::default(), &dtype)?;
528
529 let scalar1 = Scalar::primitive(100i64, Nullable);
530 Sum.combine_partials(&mut state, scalar1)?;
531
532 let scalar2 = Scalar::primitive(50i64, Nullable);
533 Sum.combine_partials(&mut state, scalar2)?;
534
535 let result = Sum.to_scalar(&state)?;
536 Sum.reset(&mut state);
537 assert_eq!(result.as_primitive().typed_value::<i64>(), Some(150));
538 Ok(())
539 }
540
541 #[test]
544 fn sum_stats() -> VortexResult<()> {
545 let array = ChunkedArray::try_new(
546 vec![
547 PrimitiveArray::from_iter([1, 1, 1]).into_array(),
548 PrimitiveArray::from_iter([2, 2, 2]).into_array(),
549 ],
550 DType::Primitive(PType::I32, Nullability::NonNullable),
551 )
552 .vortex_expect("operation should succeed in test");
553 let array = array.into_array();
554 sum_with_accumulator(&array, &Scalar::primitive(2i64, Nullable))?;
556
557 let sum_without_acc = sum(&array, &mut array_session().create_execution_ctx())?;
558 assert_eq!(sum_without_acc, Scalar::primitive(9i64, Nullable));
559 Ok(())
560 }
561
562 #[test]
565 fn sum_constant_float_non_multiply() -> VortexResult<()> {
566 let acc = -2048669276050936500000000000f64;
567 let array = ConstantArray::new(6.1811675e16f64, 25);
568 let result = sum_with_accumulator(&array.into_array(), &Scalar::primitive(acc, Nullable))
569 .vortex_expect("operation should succeed in test");
570 assert_eq!(
571 f64::try_from(&result).vortex_expect("operation should succeed in test"),
572 -2048669274505644600000000000f64
573 );
574 Ok(())
575 }
576
577 fn run_grouped_sum(groups: &ArrayRef, elem_dtype: &DType) -> VortexResult<ArrayRef> {
580 let mut acc = GroupedAccumulator::try_new(
581 Sum,
582 NumericalAggregateOpts::default(),
583 elem_dtype.clone(),
584 )?;
585 acc.accumulate_list(groups, &mut array_session().create_execution_ctx())?;
586 acc.finish()
587 }
588
589 #[test]
590 fn grouped_sum_fixed_size_list() -> VortexResult<()> {
591 let mut ctx = array_session().create_execution_ctx();
592 let elements =
593 PrimitiveArray::new(buffer![1i32, 2, 3, 4, 5, 6], Validity::NonNullable).into_array();
594 let groups = FixedSizeListArray::try_new(elements, 3, Validity::NonNullable, 2)?;
595
596 let elem_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
597 let result = run_grouped_sum(&groups.into_array(), &elem_dtype)?;
598
599 let expected = PrimitiveArray::from_option_iter([Some(6i64), Some(15i64)]).into_array();
600 assert_arrays_eq!(&result, &expected, &mut ctx);
601 Ok(())
602 }
603
604 #[test]
605 fn grouped_sum_with_null_elements() -> VortexResult<()> {
606 let mut ctx = array_session().create_execution_ctx();
607 let elements =
608 PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), None, Some(5), Some(6)])
609 .into_array();
610 let groups = FixedSizeListArray::try_new(elements, 3, Validity::NonNullable, 2)?;
611
612 let elem_dtype = DType::Primitive(PType::I32, Nullable);
613 let result = run_grouped_sum(&groups.into_array(), &elem_dtype)?;
614
615 let expected = PrimitiveArray::from_option_iter([Some(4i64), Some(11i64)]).into_array();
616 assert_arrays_eq!(&result, &expected, &mut ctx);
617 Ok(())
618 }
619
620 #[test]
621 fn grouped_sum_with_null_group() -> VortexResult<()> {
622 let mut ctx = array_session().create_execution_ctx();
623 let elements =
624 PrimitiveArray::new(buffer![1i32, 2, 3, 4, 5, 6, 7, 8, 9], Validity::NonNullable)
625 .into_array();
626 let validity = Validity::from_iter([true, false, true]);
627 let groups = FixedSizeListArray::try_new(elements, 3, validity, 3)?;
628
629 let elem_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
630 let result = run_grouped_sum(&groups.into_array(), &elem_dtype)?;
631
632 let expected =
633 PrimitiveArray::from_option_iter([Some(6i64), None, Some(24i64)]).into_array();
634 assert_arrays_eq!(&result, &expected, &mut ctx);
635 Ok(())
636 }
637
638 #[test]
639 fn grouped_sum_all_null_elements_in_group() -> VortexResult<()> {
640 let mut ctx = array_session().create_execution_ctx();
641 let elements =
642 PrimitiveArray::from_option_iter([None::<i32>, None, Some(3), Some(4)]).into_array();
643 let groups = FixedSizeListArray::try_new(elements, 2, Validity::NonNullable, 2)?;
644
645 let elem_dtype = DType::Primitive(PType::I32, Nullable);
646 let result = run_grouped_sum(&groups.into_array(), &elem_dtype)?;
647
648 let expected = PrimitiveArray::from_option_iter([Some(0i64), Some(7i64)]).into_array();
649 assert_arrays_eq!(&result, &expected, &mut ctx);
650 Ok(())
651 }
652
653 #[test]
654 fn grouped_sum_bool() -> VortexResult<()> {
655 let mut ctx = array_session().create_execution_ctx();
656 let elements: BoolArray = [true, false, true, true, true, true].into_iter().collect();
657 let groups =
658 FixedSizeListArray::try_new(elements.into_array(), 3, Validity::NonNullable, 2)?;
659
660 let elem_dtype = DType::Bool(Nullability::NonNullable);
661 let result = run_grouped_sum(&groups.into_array(), &elem_dtype)?;
662
663 let expected = PrimitiveArray::from_option_iter([Some(2u64), Some(3u64)]).into_array();
664 assert_arrays_eq!(&result, &expected, &mut ctx);
665 Ok(())
666 }
667
668 #[test]
669 fn grouped_sum_finish_resets() -> VortexResult<()> {
670 let mut ctx = array_session().create_execution_ctx();
671 let elem_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
672 let mut acc =
673 GroupedAccumulator::try_new(Sum, NumericalAggregateOpts::default(), elem_dtype)?;
674
675 let elements1 =
676 PrimitiveArray::new(buffer![1i32, 2, 3, 4], Validity::NonNullable).into_array();
677 let groups1 = FixedSizeListArray::try_new(elements1, 2, Validity::NonNullable, 2)?;
678 acc.accumulate_list(&groups1.into_array(), &mut ctx)?;
679 let result1 = acc.finish()?;
680
681 let expected1 = PrimitiveArray::from_option_iter([Some(3i64), Some(7i64)]).into_array();
682 assert_arrays_eq!(&result1, &expected1, &mut ctx);
683
684 let elements2 = PrimitiveArray::new(buffer![10i32, 20], Validity::NonNullable).into_array();
685 let groups2 = FixedSizeListArray::try_new(elements2, 2, Validity::NonNullable, 1)?;
686 acc.accumulate_list(&groups2.into_array(), &mut ctx)?;
687 let result2 = acc.finish()?;
688
689 let expected2 = PrimitiveArray::from_option_iter([Some(30i64)]).into_array();
690 assert_arrays_eq!(&result2, &expected2, &mut ctx);
691 Ok(())
692 }
693
694 #[test]
695 fn grouped_sum_listview_out_of_order_offsets_with_null_group() -> VortexResult<()> {
696 let mut ctx = array_session().create_execution_ctx();
697 let elements =
698 PrimitiveArray::new(buffer![100i32, 200, 300], Validity::NonNullable).into_array();
699 let offsets = PrimitiveArray::new(buffer![2i32, 0, 1], Validity::NonNullable).into_array();
700 let sizes = PrimitiveArray::new(buffer![1i32, 1, 1], Validity::NonNullable).into_array();
701 let validity = Validity::from_iter([true, false, true]);
702 let groups = ListViewArray::try_new(elements, offsets, sizes, validity)?.into_array();
703
704 let elem_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
705 let result = run_grouped_sum(&groups, &elem_dtype)?;
706
707 let expected =
709 PrimitiveArray::from_option_iter([Some(300i64), None, Some(200i64)]).into_array();
710 assert_arrays_eq!(&result, &expected, &mut ctx);
711 Ok(())
712 }
713
714 #[test]
717 fn sum_chunked_floats_with_nulls() -> VortexResult<()> {
718 let chunk1 =
719 PrimitiveArray::from_option_iter(vec![Some(1.5f64), None, Some(3.2), Some(4.8)]);
720 let chunk2 = PrimitiveArray::from_option_iter(vec![Some(2.1f64), Some(5.7), None]);
721 let chunk3 = PrimitiveArray::from_option_iter(vec![None, Some(1.0f64), Some(2.5), None]);
722 let dtype = chunk1.dtype().clone();
723 let chunked = ChunkedArray::try_new(
724 vec![
725 chunk1.into_array(),
726 chunk2.into_array(),
727 chunk3.into_array(),
728 ],
729 dtype,
730 )?;
731
732 let result = sum(
733 &chunked.into_array(),
734 &mut array_session().create_execution_ctx(),
735 )?;
736 assert_eq!(result.as_primitive().as_::<f64>(), Some(20.8));
737 Ok(())
738 }
739
740 #[test]
741 fn sum_chunked_floats_all_nulls_is_zero() -> VortexResult<()> {
742 let chunk1 = PrimitiveArray::from_option_iter::<f32, _>(vec![None, None, None]);
743 let chunk2 = PrimitiveArray::from_option_iter::<f32, _>(vec![None, None]);
744 let dtype = chunk1.dtype().clone();
745 let chunked = ChunkedArray::try_new(vec![chunk1.into_array(), chunk2.into_array()], dtype)?;
746 let result = sum(
747 &chunked.into_array(),
748 &mut array_session().create_execution_ctx(),
749 )?;
750 assert_eq!(result, Scalar::primitive(0f64, Nullable));
751 Ok(())
752 }
753
754 #[test]
755 fn sum_chunked_floats_empty_chunks() -> VortexResult<()> {
756 let chunk1 = PrimitiveArray::from_option_iter(vec![Some(10.5f64), Some(20.3)]);
757 let chunk2 = ConstantArray::new(Scalar::primitive(0f64, Nullable), 0);
758 let chunk3 = PrimitiveArray::from_option_iter(vec![Some(5.2f64)]);
759 let dtype = chunk1.dtype().clone();
760 let chunked = ChunkedArray::try_new(
761 vec![
762 chunk1.into_array(),
763 chunk2.into_array(),
764 chunk3.into_array(),
765 ],
766 dtype,
767 )?;
768
769 let result = sum(
770 &chunked.into_array(),
771 &mut array_session().create_execution_ctx(),
772 )?;
773 assert_eq!(result.as_primitive().as_::<f64>(), Some(36.0));
774 Ok(())
775 }
776
777 #[test]
778 fn sum_chunked_int_almost_all_null() -> VortexResult<()> {
779 let chunk1 = PrimitiveArray::from_option_iter::<u32, _>(vec![Some(1)]);
780 let chunk2 = PrimitiveArray::from_option_iter::<u32, _>(vec![None]);
781 let dtype = chunk1.dtype().clone();
782 let chunked = ChunkedArray::try_new(vec![chunk1.into_array(), chunk2.into_array()], dtype)?;
783
784 let result = sum(
785 &chunked.into_array(),
786 &mut array_session().create_execution_ctx(),
787 )?;
788 assert_eq!(result.as_primitive().as_::<u64>(), Some(1));
789 Ok(())
790 }
791
792 #[test]
793 fn sum_chunked_decimals() -> VortexResult<()> {
794 let decimal_dtype = DecimalDType::new(10, 2);
795 let chunk1 = DecimalArray::new(
796 buffer![100i32, 100i32, 100i32, 100i32, 100i32],
797 decimal_dtype,
798 Validity::AllValid,
799 );
800 let chunk2 = DecimalArray::new(
801 buffer![200i32, 200i32, 200i32],
802 decimal_dtype,
803 Validity::AllValid,
804 );
805 let chunk3 = DecimalArray::new(buffer![300i32, 300i32], decimal_dtype, Validity::AllValid);
806 let dtype = chunk1.dtype().clone();
807 let chunked = ChunkedArray::try_new(
808 vec![
809 chunk1.into_array(),
810 chunk2.into_array(),
811 chunk3.into_array(),
812 ],
813 dtype,
814 )?;
815
816 let result = sum(
817 &chunked.into_array(),
818 &mut array_session().create_execution_ctx(),
819 )?;
820 let decimal_result = result.as_decimal();
821 assert_eq!(
822 decimal_result.decimal_value(),
823 Some(DecimalValue::I256(i256::from_i128(1700)))
824 );
825 Ok(())
826 }
827
828 #[test]
829 fn sum_chunked_decimals_with_nulls() -> VortexResult<()> {
830 let decimal_dtype = DecimalDType::new(10, 2);
831 let chunk1 = DecimalArray::new(
832 buffer![100i32, 100i32, 100i32],
833 decimal_dtype,
834 Validity::AllValid,
835 );
836 let chunk2 = DecimalArray::new(
837 buffer![0i32, 0i32],
838 decimal_dtype,
839 Validity::from_iter([false, false]),
840 );
841 let chunk3 = DecimalArray::new(buffer![200i32, 200i32], decimal_dtype, Validity::AllValid);
842 let dtype = chunk1.dtype().clone();
843 let chunked = ChunkedArray::try_new(
844 vec![
845 chunk1.into_array(),
846 chunk2.into_array(),
847 chunk3.into_array(),
848 ],
849 dtype,
850 )?;
851
852 let result = sum(
853 &chunked.into_array(),
854 &mut array_session().create_execution_ctx(),
855 )?;
856 let decimal_result = result.as_decimal();
857 assert_eq!(
858 decimal_result.decimal_value(),
859 Some(DecimalValue::I256(i256::from_i128(700)))
860 );
861 Ok(())
862 }
863
864 #[test]
865 fn sum_chunked_decimals_large() -> VortexResult<()> {
866 let decimal_dtype = DecimalDType::new(3, 0);
867 let chunk1 = ConstantArray::new(
868 Scalar::decimal(
869 DecimalValue::I16(500),
870 decimal_dtype,
871 Nullability::NonNullable,
872 ),
873 1,
874 );
875 let chunk2 = ConstantArray::new(
876 Scalar::decimal(
877 DecimalValue::I16(600),
878 decimal_dtype,
879 Nullability::NonNullable,
880 ),
881 1,
882 );
883 let dtype = chunk1.dtype().clone();
884 let chunked = ChunkedArray::try_new(vec![chunk1.into_array(), chunk2.into_array()], dtype)?;
885
886 let result = sum(
887 &chunked.into_array(),
888 &mut array_session().create_execution_ctx(),
889 )?;
890 let decimal_result = result.as_decimal();
891 assert_eq!(
892 decimal_result.decimal_value(),
893 Some(DecimalValue::I256(i256::from_i128(1100)))
894 );
895 assert_eq!(
896 result.dtype(),
897 &DType::Decimal(DecimalDType::new(13, 0), Nullable)
898 );
899 Ok(())
900 }
901}