1mod bool;
5mod constant;
6mod decimal;
7mod grouped;
8mod primitive;
9pub(crate) use grouped::PrimitiveGroupedSumEncodingKernel;
10use vortex_error::VortexExpect;
11use vortex_error::VortexResult;
12use vortex_error::vortex_bail;
13use vortex_error::vortex_err;
14use vortex_error::vortex_panic;
15
16use self::bool::accumulate_bool;
17use self::constant::multiply_constant;
18use self::decimal::accumulate_decimal;
19use self::primitive::accumulate_primitive;
20use crate::ArrayRef;
21use crate::Canonical;
22use crate::Columnar;
23use crate::ExecutionCtx;
24use crate::aggregate_fn::Accumulator;
25use crate::aggregate_fn::AggregateFnId;
26use crate::aggregate_fn::AggregateFnVTable;
27use crate::aggregate_fn::DynAccumulator;
28use crate::aggregate_fn::EmptyOptions;
29use crate::dtype::DType;
30use crate::dtype::DecimalDType;
31use crate::dtype::MAX_PRECISION;
32use crate::dtype::Nullability;
33use crate::dtype::PType;
34use crate::expr::stats::Precision;
35use crate::expr::stats::Stat;
36use crate::expr::stats::StatsProvider;
37use crate::scalar::DecimalValue;
38use crate::scalar::Scalar;
39
40pub fn sum(array: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<Scalar> {
44 if let Precision::Exact(sum_scalar) = array.statistics().get(Stat::Sum) {
46 return Ok(sum_scalar);
47 }
48
49 let mut acc = Accumulator::try_new(Sum, EmptyOptions, array.dtype().clone())?;
52 acc.accumulate(array, ctx)?;
53 let result = acc.finish()?;
54
55 if let Some(val) = result.value().cloned() {
57 array.statistics().set(Stat::Sum, Precision::Exact(val));
58 }
59
60 Ok(result)
61}
62
/// The `vortex.sum` aggregate function.
///
/// Booleans and unsigned integers are summed as `u64`, signed integers as `i64`, floats as
/// `f64`, and decimals keep their scale while gaining precision. Every result dtype is
/// nullable; integer and decimal sums become null when they overflow.
#[derive(Debug, Clone)]
pub struct Sum;
69
70impl AggregateFnVTable for Sum {
71 type Options = EmptyOptions;
72 type Partial = SumPartial;
73
74 fn id(&self) -> AggregateFnId {
75 AggregateFnId::new("vortex.sum")
76 }
77
78 fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
79 unimplemented!("Sum is not yet serializable");
80 }
81
82 fn return_dtype(&self, _options: &Self::Options, input_dtype: &DType) -> Option<DType> {
83 use Nullability::Nullable;
86
87 Some(match input_dtype {
88 DType::Bool(_) => DType::Primitive(PType::U64, Nullable),
89 DType::Primitive(ptype, _) => match ptype {
90 PType::U8 | PType::U16 | PType::U32 | PType::U64 => {
91 DType::Primitive(PType::U64, Nullable)
92 }
93 PType::I8 | PType::I16 | PType::I32 | PType::I64 => {
94 DType::Primitive(PType::I64, Nullable)
95 }
96 PType::F16 | PType::F32 | PType::F64 => {
97 DType::Primitive(PType::F64, Nullable)
99 }
100 },
101 DType::Decimal(decimal_dtype, _) => {
102 let precision = u8::min(MAX_PRECISION, decimal_dtype.precision() + 10);
106 DType::Decimal(
107 DecimalDType::new(precision, decimal_dtype.scale()),
108 Nullable,
109 )
110 }
111 _ => return None,
113 })
114 }
115
116 fn partial_dtype(&self, options: &Self::Options, input_dtype: &DType) -> Option<DType> {
117 self.return_dtype(options, input_dtype)
118 }
119
120 fn empty_partial(
121 &self,
122 options: &Self::Options,
123 input_dtype: &DType,
124 ) -> VortexResult<Self::Partial> {
125 let return_dtype = self
126 .return_dtype(options, input_dtype)
127 .ok_or_else(|| vortex_err!("Unsupported sum dtype: {}", input_dtype))?;
128 let initial = make_zero_state(&return_dtype);
129
130 Ok(SumPartial {
131 return_dtype,
132 current: Some(initial),
133 })
134 }
135
136 fn combine_partials(&self, partial: &mut Self::Partial, other: Scalar) -> VortexResult<()> {
137 if other.is_null() {
138 partial.current = None;
140 return Ok(());
141 }
142 let Some(ref mut inner) = partial.current else {
143 return Ok(());
144 };
145 let saturated = match inner {
146 SumState::Unsigned(acc) => {
147 let val = other
148 .as_primitive()
149 .typed_value::<u64>()
150 .vortex_expect("checked non-null");
151 checked_add_u64(acc, val)
152 }
153 SumState::Signed(acc) => {
154 let val = other
155 .as_primitive()
156 .typed_value::<i64>()
157 .vortex_expect("checked non-null");
158 checked_add_i64(acc, val)
159 }
160 SumState::Float(acc) => {
161 let val = other
162 .as_primitive()
163 .typed_value::<f64>()
164 .vortex_expect("checked non-null");
165 *acc += val;
166 false
167 }
168 SumState::Decimal { value, dtype } => {
169 let val = other
170 .as_decimal()
171 .decimal_value()
172 .vortex_expect("checked non-null");
173 match value.checked_add(&val) {
174 Some(r) => {
175 *value = r;
176 !value.fits_in_precision(*dtype)
177 }
178 None => true,
179 }
180 }
181 };
182 if saturated {
183 partial.current = None;
184 }
185 Ok(())
186 }
187
188 fn to_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
189 Ok(match &partial.current {
190 None => Scalar::null(partial.return_dtype.as_nullable()),
191 Some(SumState::Unsigned(v)) => Scalar::primitive(*v, Nullability::Nullable),
192 Some(SumState::Signed(v)) => Scalar::primitive(*v, Nullability::Nullable),
193 Some(SumState::Float(v)) => Scalar::primitive(*v, Nullability::Nullable),
194 Some(SumState::Decimal { value, .. }) => {
195 let decimal_dtype = *partial
196 .return_dtype
197 .as_decimal_opt()
198 .vortex_expect("return dtype must be decimal");
199 Scalar::decimal(*value, decimal_dtype, Nullability::Nullable)
200 }
201 })
202 }
203
204 fn reset(&self, partial: &mut Self::Partial) {
205 partial.current = Some(make_zero_state(&partial.return_dtype));
206 }
207
208 #[inline]
209 fn is_saturated(&self, partial: &Self::Partial) -> bool {
210 match partial.current.as_ref() {
211 None => true,
212 Some(SumState::Float(v)) => v.is_nan(),
213 Some(_) => false,
214 }
215 }
216
217 fn accumulate(
218 &self,
219 partial: &mut Self::Partial,
220 batch: &Columnar,
221 ctx: &mut ExecutionCtx,
222 ) -> VortexResult<()> {
223 if let Columnar::Constant(c) = batch {
225 if let Some(product) = multiply_constant(c.scalar(), c.len(), &partial.return_dtype)? {
226 self.combine_partials(partial, product)?;
227 }
228 return Ok(());
229 }
230
231 let mut inner = match partial.current.take() {
232 Some(inner) => inner,
233 None => return Ok(()),
234 };
235
236 let result = match batch {
237 Columnar::Canonical(c) => match c {
238 Canonical::Primitive(p) => accumulate_primitive(&mut inner, p, ctx),
239 Canonical::Bool(b) => accumulate_bool(&mut inner, b, ctx),
240 Canonical::Decimal(d) => accumulate_decimal(&mut inner, d, ctx),
241 _ => vortex_bail!("Unsupported canonical type for sum: {}", batch.dtype()),
242 },
243 Columnar::Constant(_) => unreachable!(),
244 };
245
246 match result {
247 Ok(false) => partial.current = Some(inner),
248 Ok(true) => {} Err(e) => {
250 partial.current = Some(inner);
251 return Err(e);
252 }
253 }
254 Ok(())
255 }
256
257 fn finalize(&self, partials: ArrayRef) -> VortexResult<ArrayRef> {
258 Ok(partials)
259 }
260
261 fn finalize_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
262 self.to_scalar(partial)
263 }
264}
265
266pub struct SumPartial {
269 return_dtype: DType,
270 current: Option<SumState>,
272}
273
274pub enum SumState {
279 Unsigned(u64),
280 Signed(i64),
281 Float(f64),
282 Decimal {
283 value: DecimalValue,
284 dtype: DecimalDType,
285 },
286}
287
288fn make_zero_state(return_dtype: &DType) -> SumState {
289 match return_dtype {
290 DType::Primitive(ptype, _) => match ptype {
291 PType::U8 | PType::U16 | PType::U32 | PType::U64 => SumState::Unsigned(0),
292 PType::I8 | PType::I16 | PType::I32 | PType::I64 => SumState::Signed(0),
293 PType::F16 | PType::F32 | PType::F64 => SumState::Float(0.0),
294 },
295 DType::Decimal(decimal, _) => SumState::Decimal {
296 value: DecimalValue::zero(decimal),
297 dtype: *decimal,
298 },
299 _ => vortex_panic!("Unsupported sum type"),
300 }
301}
302
/// Adds `val` into `acc`, returning `true` if the addition overflowed.
///
/// On overflow `acc` is left unchanged.
#[inline(always)]
fn checked_add_u64(acc: &mut u64, val: u64) -> bool {
    let Some(total) = acc.checked_add(val) else {
        return true;
    };
    *acc = total;
    false
}
314
/// Adds `val` into `acc`, returning `true` if the addition overflowed in either direction.
///
/// On overflow `acc` is left unchanged.
#[inline(always)]
fn checked_add_i64(acc: &mut i64, val: i64) -> bool {
    let Some(total) = acc.checked_add(val) else {
        return true;
    };
    *acc = total;
    false
}
326
327#[cfg(test)]
328mod tests {
329 use num_traits::CheckedAdd;
330 use vortex_buffer::buffer;
331 use vortex_error::VortexExpect;
332 use vortex_error::VortexResult;
333
334 use crate::ArrayRef;
335 use crate::IntoArray;
336 use crate::LEGACY_SESSION;
337 use crate::VortexSessionExecute;
338 use crate::aggregate_fn::Accumulator;
339 use crate::aggregate_fn::AggregateFnVTable;
340 use crate::aggregate_fn::DynAccumulator;
341 use crate::aggregate_fn::DynGroupedAccumulator;
342 use crate::aggregate_fn::EmptyOptions;
343 use crate::aggregate_fn::GroupedAccumulator;
344 use crate::aggregate_fn::fns::sum::Sum;
345 use crate::aggregate_fn::fns::sum::sum;
346 use crate::arrays::BoolArray;
347 use crate::arrays::ChunkedArray;
348 use crate::arrays::ConstantArray;
349 use crate::arrays::DecimalArray;
350 use crate::arrays::FixedSizeListArray;
351 use crate::arrays::ListViewArray;
352 use crate::arrays::PrimitiveArray;
353 use crate::assert_arrays_eq;
354 use crate::dtype::DType;
355 use crate::dtype::DecimalDType;
356 use crate::dtype::Nullability;
357 use crate::dtype::Nullability::Nullable;
358 use crate::dtype::PType;
359 use crate::dtype::i256;
360 use crate::expr::stats::Precision;
361 use crate::expr::stats::Stat;
362 use crate::expr::stats::StatsProvider;
363 use crate::scalar::DecimalValue;
364 use crate::scalar::NumericOperator;
365 use crate::scalar::Scalar;
366 use crate::validity::Validity;
367
368 fn sum_with_accumulator(array: &ArrayRef, accumulator: &Scalar) -> VortexResult<Scalar> {
370 let mut ctx = LEGACY_SESSION.create_execution_ctx();
371 if accumulator.is_null() {
372 return Ok(accumulator.clone());
373 }
374 if accumulator.is_zero() == Some(true) {
375 return sum(array, &mut ctx);
376 }
377
378 let sum_dtype = Stat::Sum.dtype(array.dtype()).ok_or_else(|| {
379 vortex_error::vortex_err!("Sum not supported for dtype: {}", array.dtype())
380 })?;
381
382 if !matches!(&sum_dtype, DType::Primitive(p, _) if p.is_float())
384 && let Precision::Exact(sum_scalar) = array.statistics().get(Stat::Sum)
385 {
386 return add_scalars(&sum_dtype, &sum_scalar, accumulator);
387 }
388
389 let array_sum = sum(array, &mut ctx)?;
391
392 add_scalars(&sum_dtype, &array_sum, accumulator)
394 }
395
396 fn add_scalars(sum_dtype: &DType, lhs: &Scalar, rhs: &Scalar) -> VortexResult<Scalar> {
398 if lhs.is_null() || rhs.is_null() {
399 return Ok(Scalar::null(sum_dtype.as_nullable()));
400 }
401
402 Ok(match sum_dtype {
403 DType::Primitive(ptype, _) if ptype.is_float() => {
404 let lhs_val = f64::try_from(lhs)?;
405 let rhs_val = f64::try_from(rhs)?;
406 Scalar::primitive(lhs_val + rhs_val, Nullable)
407 }
408 DType::Primitive(..) => lhs
409 .as_primitive()
410 .checked_add(&rhs.as_primitive())
411 .map(Scalar::from)
412 .unwrap_or_else(|| Scalar::null(sum_dtype.as_nullable())),
413 DType::Decimal(..) => lhs
414 .as_decimal()
415 .checked_binary_numeric(&rhs.as_decimal(), NumericOperator::Add)
416 .map(Scalar::from)
417 .unwrap_or_else(|| Scalar::null(sum_dtype.as_nullable())),
418 _ => unreachable!("Sum will always be a decimal or a primitive dtype"),
419 })
420 }
421
422 #[test]
425 fn sum_multi_batch() -> VortexResult<()> {
426 let mut ctx = LEGACY_SESSION.create_execution_ctx();
427 let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
428 let mut acc = Accumulator::try_new(Sum, EmptyOptions, dtype)?;
429
430 let batch1 = PrimitiveArray::new(buffer![10i32, 20], Validity::NonNullable).into_array();
431 acc.accumulate(&batch1, &mut ctx)?;
432
433 let batch2 = PrimitiveArray::new(buffer![3i32, 6, 9], Validity::NonNullable).into_array();
434 acc.accumulate(&batch2, &mut ctx)?;
435
436 let result = acc.finish()?;
437 assert_eq!(result.as_primitive().typed_value::<i64>(), Some(48));
438 Ok(())
439 }
440
441 #[test]
442 fn sum_finish_resets_state() -> VortexResult<()> {
443 let mut ctx = LEGACY_SESSION.create_execution_ctx();
444 let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
445 let mut acc = Accumulator::try_new(Sum, EmptyOptions, dtype)?;
446
447 let batch1 = PrimitiveArray::new(buffer![10i32, 20], Validity::NonNullable).into_array();
448 acc.accumulate(&batch1, &mut ctx)?;
449 let result1 = acc.finish()?;
450 assert_eq!(result1.as_primitive().typed_value::<i64>(), Some(30));
451
452 let batch2 = PrimitiveArray::new(buffer![3i32, 6, 9], Validity::NonNullable).into_array();
453 acc.accumulate(&batch2, &mut ctx)?;
454 let result2 = acc.finish()?;
455 assert_eq!(result2.as_primitive().typed_value::<i64>(), Some(18));
456 Ok(())
457 }
458
459 #[test]
462 fn sum_state_merge() -> VortexResult<()> {
463 let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
464 let mut state = Sum.empty_partial(&EmptyOptions, &dtype)?;
465
466 let scalar1 = Scalar::primitive(100i64, Nullable);
467 Sum.combine_partials(&mut state, scalar1)?;
468
469 let scalar2 = Scalar::primitive(50i64, Nullable);
470 Sum.combine_partials(&mut state, scalar2)?;
471
472 let result = Sum.to_scalar(&state)?;
473 Sum.reset(&mut state);
474 assert_eq!(result.as_primitive().typed_value::<i64>(), Some(150));
475 Ok(())
476 }
477
478 #[test]
481 fn sum_stats() -> VortexResult<()> {
482 let array = ChunkedArray::try_new(
483 vec![
484 PrimitiveArray::from_iter([1, 1, 1]).into_array(),
485 PrimitiveArray::from_iter([2, 2, 2]).into_array(),
486 ],
487 DType::Primitive(PType::I32, Nullability::NonNullable),
488 )
489 .vortex_expect("operation should succeed in test");
490 let array = array.into_array();
491 sum_with_accumulator(&array, &Scalar::primitive(2i64, Nullable))?;
493
494 let sum_without_acc = sum(&array, &mut LEGACY_SESSION.create_execution_ctx())?;
495 assert_eq!(sum_without_acc, Scalar::primitive(9i64, Nullable));
496 Ok(())
497 }
498
499 #[test]
502 fn sum_constant_float_non_multiply() -> VortexResult<()> {
503 let acc = -2048669276050936500000000000f64;
504 let array = ConstantArray::new(6.1811675e16f64, 25);
505 let result = sum_with_accumulator(&array.into_array(), &Scalar::primitive(acc, Nullable))
506 .vortex_expect("operation should succeed in test");
507 assert_eq!(
508 f64::try_from(&result).vortex_expect("operation should succeed in test"),
509 -2048669274505644600000000000f64
510 );
511 Ok(())
512 }
513
514 fn run_grouped_sum(groups: &ArrayRef, elem_dtype: &DType) -> VortexResult<ArrayRef> {
517 let mut acc = GroupedAccumulator::try_new(Sum, EmptyOptions, elem_dtype.clone())?;
518 acc.accumulate_list(groups, &mut LEGACY_SESSION.create_execution_ctx())?;
519 acc.finish()
520 }
521
522 #[test]
523 fn grouped_sum_fixed_size_list() -> VortexResult<()> {
524 let elements =
525 PrimitiveArray::new(buffer![1i32, 2, 3, 4, 5, 6], Validity::NonNullable).into_array();
526 let groups = FixedSizeListArray::try_new(elements, 3, Validity::NonNullable, 2)?;
527
528 let elem_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
529 let result = run_grouped_sum(&groups.into_array(), &elem_dtype)?;
530
531 let expected = PrimitiveArray::from_option_iter([Some(6i64), Some(15i64)]).into_array();
532 assert_arrays_eq!(&result, &expected);
533 Ok(())
534 }
535
536 #[test]
537 fn grouped_sum_with_null_elements() -> VortexResult<()> {
538 let elements =
539 PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), None, Some(5), Some(6)])
540 .into_array();
541 let groups = FixedSizeListArray::try_new(elements, 3, Validity::NonNullable, 2)?;
542
543 let elem_dtype = DType::Primitive(PType::I32, Nullable);
544 let result = run_grouped_sum(&groups.into_array(), &elem_dtype)?;
545
546 let expected = PrimitiveArray::from_option_iter([Some(4i64), Some(11i64)]).into_array();
547 assert_arrays_eq!(&result, &expected);
548 Ok(())
549 }
550
551 #[test]
552 fn grouped_sum_with_null_group() -> VortexResult<()> {
553 let elements =
554 PrimitiveArray::new(buffer![1i32, 2, 3, 4, 5, 6, 7, 8, 9], Validity::NonNullable)
555 .into_array();
556 let validity = Validity::from_iter([true, false, true]);
557 let groups = FixedSizeListArray::try_new(elements, 3, validity, 3)?;
558
559 let elem_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
560 let result = run_grouped_sum(&groups.into_array(), &elem_dtype)?;
561
562 let expected =
563 PrimitiveArray::from_option_iter([Some(6i64), None, Some(24i64)]).into_array();
564 assert_arrays_eq!(&result, &expected);
565 Ok(())
566 }
567
568 #[test]
569 fn grouped_sum_all_null_elements_in_group() -> VortexResult<()> {
570 let elements =
571 PrimitiveArray::from_option_iter([None::<i32>, None, Some(3), Some(4)]).into_array();
572 let groups = FixedSizeListArray::try_new(elements, 2, Validity::NonNullable, 2)?;
573
574 let elem_dtype = DType::Primitive(PType::I32, Nullable);
575 let result = run_grouped_sum(&groups.into_array(), &elem_dtype)?;
576
577 let expected = PrimitiveArray::from_option_iter([Some(0i64), Some(7i64)]).into_array();
578 assert_arrays_eq!(&result, &expected);
579 Ok(())
580 }
581
582 #[test]
583 fn grouped_sum_bool() -> VortexResult<()> {
584 let elements: BoolArray = [true, false, true, true, true, true].into_iter().collect();
585 let groups =
586 FixedSizeListArray::try_new(elements.into_array(), 3, Validity::NonNullable, 2)?;
587
588 let elem_dtype = DType::Bool(Nullability::NonNullable);
589 let result = run_grouped_sum(&groups.into_array(), &elem_dtype)?;
590
591 let expected = PrimitiveArray::from_option_iter([Some(2u64), Some(3u64)]).into_array();
592 assert_arrays_eq!(&result, &expected);
593 Ok(())
594 }
595
596 #[test]
597 fn grouped_sum_finish_resets() -> VortexResult<()> {
598 let mut ctx = LEGACY_SESSION.create_execution_ctx();
599 let elem_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
600 let mut acc = GroupedAccumulator::try_new(Sum, EmptyOptions, elem_dtype)?;
601
602 let elements1 =
603 PrimitiveArray::new(buffer![1i32, 2, 3, 4], Validity::NonNullable).into_array();
604 let groups1 = FixedSizeListArray::try_new(elements1, 2, Validity::NonNullable, 2)?;
605 acc.accumulate_list(&groups1.into_array(), &mut ctx)?;
606 let result1 = acc.finish()?;
607
608 let expected1 = PrimitiveArray::from_option_iter([Some(3i64), Some(7i64)]).into_array();
609 assert_arrays_eq!(&result1, &expected1);
610
611 let elements2 = PrimitiveArray::new(buffer![10i32, 20], Validity::NonNullable).into_array();
612 let groups2 = FixedSizeListArray::try_new(elements2, 2, Validity::NonNullable, 1)?;
613 acc.accumulate_list(&groups2.into_array(), &mut ctx)?;
614 let result2 = acc.finish()?;
615
616 let expected2 = PrimitiveArray::from_option_iter([Some(30i64)]).into_array();
617 assert_arrays_eq!(&result2, &expected2);
618 Ok(())
619 }
620
621 #[test]
622 fn grouped_sum_listview_out_of_order_offsets_with_null_group() -> VortexResult<()> {
623 let elements =
624 PrimitiveArray::new(buffer![100i32, 200, 300], Validity::NonNullable).into_array();
625 let offsets = PrimitiveArray::new(buffer![2i32, 0, 1], Validity::NonNullable).into_array();
626 let sizes = PrimitiveArray::new(buffer![1i32, 1, 1], Validity::NonNullable).into_array();
627 let validity = Validity::from_iter([true, false, true]);
628 let groups = ListViewArray::try_new(elements, offsets, sizes, validity)?.into_array();
629
630 let elem_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
631 let result = run_grouped_sum(&groups, &elem_dtype)?;
632
633 let expected =
635 PrimitiveArray::from_option_iter([Some(300i64), None, Some(200i64)]).into_array();
636 assert_arrays_eq!(&result, &expected);
637 Ok(())
638 }
639
640 #[test]
643 fn sum_chunked_floats_with_nulls() -> VortexResult<()> {
644 let chunk1 =
645 PrimitiveArray::from_option_iter(vec![Some(1.5f64), None, Some(3.2), Some(4.8)]);
646 let chunk2 = PrimitiveArray::from_option_iter(vec![Some(2.1f64), Some(5.7), None]);
647 let chunk3 = PrimitiveArray::from_option_iter(vec![None, Some(1.0f64), Some(2.5), None]);
648 let dtype = chunk1.dtype().clone();
649 let chunked = ChunkedArray::try_new(
650 vec![
651 chunk1.into_array(),
652 chunk2.into_array(),
653 chunk3.into_array(),
654 ],
655 dtype,
656 )?;
657
658 let result = sum(
659 &chunked.into_array(),
660 &mut LEGACY_SESSION.create_execution_ctx(),
661 )?;
662 assert_eq!(result.as_primitive().as_::<f64>(), Some(20.8));
663 Ok(())
664 }
665
666 #[test]
667 fn sum_chunked_floats_all_nulls_is_zero() -> VortexResult<()> {
668 let chunk1 = PrimitiveArray::from_option_iter::<f32, _>(vec![None, None, None]);
669 let chunk2 = PrimitiveArray::from_option_iter::<f32, _>(vec![None, None]);
670 let dtype = chunk1.dtype().clone();
671 let chunked = ChunkedArray::try_new(vec![chunk1.into_array(), chunk2.into_array()], dtype)?;
672 let result = sum(
673 &chunked.into_array(),
674 &mut LEGACY_SESSION.create_execution_ctx(),
675 )?;
676 assert_eq!(result, Scalar::primitive(0f64, Nullable));
677 Ok(())
678 }
679
680 #[test]
681 fn sum_chunked_floats_empty_chunks() -> VortexResult<()> {
682 let chunk1 = PrimitiveArray::from_option_iter(vec![Some(10.5f64), Some(20.3)]);
683 let chunk2 = ConstantArray::new(Scalar::primitive(0f64, Nullable), 0);
684 let chunk3 = PrimitiveArray::from_option_iter(vec![Some(5.2f64)]);
685 let dtype = chunk1.dtype().clone();
686 let chunked = ChunkedArray::try_new(
687 vec![
688 chunk1.into_array(),
689 chunk2.into_array(),
690 chunk3.into_array(),
691 ],
692 dtype,
693 )?;
694
695 let result = sum(
696 &chunked.into_array(),
697 &mut LEGACY_SESSION.create_execution_ctx(),
698 )?;
699 assert_eq!(result.as_primitive().as_::<f64>(), Some(36.0));
700 Ok(())
701 }
702
703 #[test]
704 fn sum_chunked_int_almost_all_null() -> VortexResult<()> {
705 let chunk1 = PrimitiveArray::from_option_iter::<u32, _>(vec![Some(1)]);
706 let chunk2 = PrimitiveArray::from_option_iter::<u32, _>(vec![None]);
707 let dtype = chunk1.dtype().clone();
708 let chunked = ChunkedArray::try_new(vec![chunk1.into_array(), chunk2.into_array()], dtype)?;
709
710 let result = sum(
711 &chunked.into_array(),
712 &mut LEGACY_SESSION.create_execution_ctx(),
713 )?;
714 assert_eq!(result.as_primitive().as_::<u64>(), Some(1));
715 Ok(())
716 }
717
718 #[test]
719 fn sum_chunked_decimals() -> VortexResult<()> {
720 let decimal_dtype = DecimalDType::new(10, 2);
721 let chunk1 = DecimalArray::new(
722 buffer![100i32, 100i32, 100i32, 100i32, 100i32],
723 decimal_dtype,
724 Validity::AllValid,
725 );
726 let chunk2 = DecimalArray::new(
727 buffer![200i32, 200i32, 200i32],
728 decimal_dtype,
729 Validity::AllValid,
730 );
731 let chunk3 = DecimalArray::new(buffer![300i32, 300i32], decimal_dtype, Validity::AllValid);
732 let dtype = chunk1.dtype().clone();
733 let chunked = ChunkedArray::try_new(
734 vec![
735 chunk1.into_array(),
736 chunk2.into_array(),
737 chunk3.into_array(),
738 ],
739 dtype,
740 )?;
741
742 let result = sum(
743 &chunked.into_array(),
744 &mut LEGACY_SESSION.create_execution_ctx(),
745 )?;
746 let decimal_result = result.as_decimal();
747 assert_eq!(
748 decimal_result.decimal_value(),
749 Some(DecimalValue::I256(i256::from_i128(1700)))
750 );
751 Ok(())
752 }
753
754 #[test]
755 fn sum_chunked_decimals_with_nulls() -> VortexResult<()> {
756 let decimal_dtype = DecimalDType::new(10, 2);
757 let chunk1 = DecimalArray::new(
758 buffer![100i32, 100i32, 100i32],
759 decimal_dtype,
760 Validity::AllValid,
761 );
762 let chunk2 = DecimalArray::new(
763 buffer![0i32, 0i32],
764 decimal_dtype,
765 Validity::from_iter([false, false]),
766 );
767 let chunk3 = DecimalArray::new(buffer![200i32, 200i32], decimal_dtype, Validity::AllValid);
768 let dtype = chunk1.dtype().clone();
769 let chunked = ChunkedArray::try_new(
770 vec![
771 chunk1.into_array(),
772 chunk2.into_array(),
773 chunk3.into_array(),
774 ],
775 dtype,
776 )?;
777
778 let result = sum(
779 &chunked.into_array(),
780 &mut LEGACY_SESSION.create_execution_ctx(),
781 )?;
782 let decimal_result = result.as_decimal();
783 assert_eq!(
784 decimal_result.decimal_value(),
785 Some(DecimalValue::I256(i256::from_i128(700)))
786 );
787 Ok(())
788 }
789
790 #[test]
791 fn sum_chunked_decimals_large() -> VortexResult<()> {
792 let decimal_dtype = DecimalDType::new(3, 0);
793 let chunk1 = ConstantArray::new(
794 Scalar::decimal(
795 DecimalValue::I16(500),
796 decimal_dtype,
797 Nullability::NonNullable,
798 ),
799 1,
800 );
801 let chunk2 = ConstantArray::new(
802 Scalar::decimal(
803 DecimalValue::I16(600),
804 decimal_dtype,
805 Nullability::NonNullable,
806 ),
807 1,
808 );
809 let dtype = chunk1.dtype().clone();
810 let chunked = ChunkedArray::try_new(vec![chunk1.into_array(), chunk2.into_array()], dtype)?;
811
812 let result = sum(
813 &chunked.into_array(),
814 &mut LEGACY_SESSION.create_execution_ctx(),
815 )?;
816 let decimal_result = result.as_decimal();
817 assert_eq!(
818 decimal_result.decimal_value(),
819 Some(DecimalValue::I256(i256::from_i128(1100)))
820 );
821 assert_eq!(
822 result.dtype(),
823 &DType::Decimal(DecimalDType::new(13, 0), Nullable)
824 );
825 Ok(())
826 }
827}