1mod bool;
5mod constant;
6mod decimal;
7mod primitive;
8
9use vortex_error::VortexExpect;
10use vortex_error::VortexResult;
11use vortex_error::vortex_bail;
12use vortex_error::vortex_err;
13use vortex_error::vortex_panic;
14
15use self::bool::accumulate_bool;
16use self::constant::multiply_constant;
17use self::decimal::accumulate_decimal;
18use self::primitive::accumulate_primitive;
19use crate::ArrayRef;
20use crate::Canonical;
21use crate::Columnar;
22use crate::ExecutionCtx;
23use crate::aggregate_fn::Accumulator;
24use crate::aggregate_fn::AggregateFnId;
25use crate::aggregate_fn::AggregateFnVTable;
26use crate::aggregate_fn::DynAccumulator;
27use crate::aggregate_fn::EmptyOptions;
28use crate::dtype::DType;
29use crate::dtype::DecimalDType;
30use crate::dtype::MAX_PRECISION;
31use crate::dtype::Nullability;
32use crate::dtype::PType;
33use crate::expr::stats::Precision;
34use crate::expr::stats::Stat;
35use crate::expr::stats::StatsProvider;
36use crate::scalar::DecimalValue;
37use crate::scalar::Scalar;
38
39pub fn sum(array: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<Scalar> {
43 if let Some(Precision::Exact(sum_scalar)) = array.statistics().get(Stat::Sum) {
45 return Ok(sum_scalar);
46 }
47
48 let mut acc = Accumulator::try_new(Sum, EmptyOptions, array.dtype().clone())?;
51 acc.accumulate(array, ctx)?;
52 let result = acc.finish()?;
53
54 if let Some(val) = result.value().cloned() {
56 array.statistics().set(Stat::Sum, Precision::Exact(val));
57 }
58
59 Ok(result)
60}
61
/// The `vortex.sum` aggregate function.
///
/// Sums booleans (counting `true` values), integers, floats and decimals into a widened,
/// always-nullable result type. A null result signals that the sum could not be represented
/// (for example, integer overflow).
#[derive(Debug, Clone)]
pub struct Sum;
68
69impl AggregateFnVTable for Sum {
70 type Options = EmptyOptions;
71 type Partial = SumPartial;
72
73 fn id(&self) -> AggregateFnId {
74 AggregateFnId::new_ref("vortex.sum")
75 }
76
77 fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
78 unimplemented!("Sum is not yet serializable");
79 }
80
81 fn return_dtype(&self, _options: &Self::Options, input_dtype: &DType) -> Option<DType> {
82 use Nullability::Nullable;
85
86 Some(match input_dtype {
87 DType::Bool(_) => DType::Primitive(PType::U64, Nullable),
88 DType::Primitive(ptype, _) => match ptype {
89 PType::U8 | PType::U16 | PType::U32 | PType::U64 => {
90 DType::Primitive(PType::U64, Nullable)
91 }
92 PType::I8 | PType::I16 | PType::I32 | PType::I64 => {
93 DType::Primitive(PType::I64, Nullable)
94 }
95 PType::F16 | PType::F32 | PType::F64 => {
96 DType::Primitive(PType::F64, Nullable)
98 }
99 },
100 DType::Decimal(decimal_dtype, _) => {
101 let precision = u8::min(MAX_PRECISION, decimal_dtype.precision() + 10);
105 DType::Decimal(
106 DecimalDType::new(precision, decimal_dtype.scale()),
107 Nullable,
108 )
109 }
110 _ => return None,
112 })
113 }
114
115 fn partial_dtype(&self, options: &Self::Options, input_dtype: &DType) -> Option<DType> {
116 self.return_dtype(options, input_dtype)
117 }
118
119 fn empty_partial(
120 &self,
121 options: &Self::Options,
122 input_dtype: &DType,
123 ) -> VortexResult<Self::Partial> {
124 let return_dtype = self
125 .return_dtype(options, input_dtype)
126 .ok_or_else(|| vortex_err!("Unsupported sum dtype: {}", input_dtype))?;
127 let initial = make_zero_state(&return_dtype);
128
129 Ok(SumPartial {
130 return_dtype,
131 current: Some(initial),
132 })
133 }
134
135 fn combine_partials(&self, partial: &mut Self::Partial, other: Scalar) -> VortexResult<()> {
136 if other.is_null() {
137 partial.current = None;
139 return Ok(());
140 }
141 let Some(ref mut inner) = partial.current else {
142 return Ok(());
143 };
144 let saturated = match inner {
145 SumState::Unsigned(acc) => {
146 let val = other
147 .as_primitive()
148 .typed_value::<u64>()
149 .vortex_expect("checked non-null");
150 checked_add_u64(acc, val)
151 }
152 SumState::Signed(acc) => {
153 let val = other
154 .as_primitive()
155 .typed_value::<i64>()
156 .vortex_expect("checked non-null");
157 checked_add_i64(acc, val)
158 }
159 SumState::Float(acc) => {
160 let val = other
161 .as_primitive()
162 .typed_value::<f64>()
163 .vortex_expect("checked non-null");
164 *acc += val;
165 false
166 }
167 SumState::Decimal { value, dtype } => {
168 let val = other
169 .as_decimal()
170 .decimal_value()
171 .vortex_expect("checked non-null");
172 match value.checked_add(&val) {
173 Some(r) => {
174 *value = r;
175 !value.fits_in_precision(*dtype)
176 }
177 None => true,
178 }
179 }
180 };
181 if saturated {
182 partial.current = None;
183 }
184 Ok(())
185 }
186
187 fn to_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
188 Ok(match &partial.current {
189 None => Scalar::null(partial.return_dtype.as_nullable()),
190 Some(SumState::Unsigned(v)) => Scalar::primitive(*v, Nullability::Nullable),
191 Some(SumState::Signed(v)) => Scalar::primitive(*v, Nullability::Nullable),
192 Some(SumState::Float(v)) => Scalar::primitive(*v, Nullability::Nullable),
193 Some(SumState::Decimal { value, .. }) => {
194 let decimal_dtype = *partial
195 .return_dtype
196 .as_decimal_opt()
197 .vortex_expect("return dtype must be decimal");
198 Scalar::decimal(*value, decimal_dtype, Nullability::Nullable)
199 }
200 })
201 }
202
203 fn reset(&self, partial: &mut Self::Partial) {
204 partial.current = Some(make_zero_state(&partial.return_dtype));
205 }
206
207 #[inline]
208 fn is_saturated(&self, partial: &Self::Partial) -> bool {
209 match partial.current.as_ref() {
210 None => true,
211 Some(SumState::Float(v)) => v.is_nan(),
212 Some(_) => false,
213 }
214 }
215
216 fn accumulate(
217 &self,
218 partial: &mut Self::Partial,
219 batch: &Columnar,
220 _ctx: &mut ExecutionCtx,
221 ) -> VortexResult<()> {
222 if let Columnar::Constant(c) = batch {
224 if let Some(product) = multiply_constant(c.scalar(), c.len(), &partial.return_dtype)? {
225 self.combine_partials(partial, product)?;
226 }
227 return Ok(());
228 }
229
230 let mut inner = match partial.current.take() {
231 Some(inner) => inner,
232 None => return Ok(()),
233 };
234
235 let result = match batch {
236 Columnar::Canonical(c) => match c {
237 Canonical::Primitive(p) => accumulate_primitive(&mut inner, p),
238 Canonical::Bool(b) => accumulate_bool(&mut inner, b),
239 Canonical::Decimal(d) => accumulate_decimal(&mut inner, d),
240 _ => vortex_bail!("Unsupported canonical type for sum: {}", batch.dtype()),
241 },
242 Columnar::Constant(_) => unreachable!(),
243 };
244
245 match result {
246 Ok(false) => partial.current = Some(inner),
247 Ok(true) => {} Err(e) => {
249 partial.current = Some(inner);
250 return Err(e);
251 }
252 }
253 Ok(())
254 }
255
256 fn finalize(&self, partials: ArrayRef) -> VortexResult<ArrayRef> {
257 Ok(partials)
258 }
259
260 fn finalize_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
261 self.to_scalar(partial)
262 }
263}
264
265pub struct SumPartial {
268 return_dtype: DType,
269 current: Option<SumState>,
271}
272
273pub enum SumState {
278 Unsigned(u64),
279 Signed(i64),
280 Float(f64),
281 Decimal {
282 value: DecimalValue,
283 dtype: DecimalDType,
284 },
285}
286
287fn make_zero_state(return_dtype: &DType) -> SumState {
288 match return_dtype {
289 DType::Primitive(ptype, _) => match ptype {
290 PType::U8 | PType::U16 | PType::U32 | PType::U64 => SumState::Unsigned(0),
291 PType::I8 | PType::I16 | PType::I32 | PType::I64 => SumState::Signed(0),
292 PType::F16 | PType::F32 | PType::F64 => SumState::Float(0.0),
293 },
294 DType::Decimal(decimal, _) => SumState::Decimal {
295 value: DecimalValue::zero(decimal),
296 dtype: *decimal,
297 },
298 _ => vortex_panic!("Unsupported sum type"),
299 }
300}
301
/// Adds `val` into `acc` in place and returns `true` if the addition overflowed.
///
/// On overflow `acc` is left unchanged.
#[inline(always)]
fn checked_add_u64(acc: &mut u64, val: u64) -> bool {
    let Some(total) = acc.checked_add(val) else {
        return true;
    };
    *acc = total;
    false
}
313
/// Adds `val` into `acc` in place and returns `true` if the addition overflowed in either
/// direction.
///
/// On overflow `acc` is left unchanged.
#[inline(always)]
fn checked_add_i64(acc: &mut i64, val: i64) -> bool {
    let Some(total) = acc.checked_add(val) else {
        return true;
    };
    *acc = total;
    false
}
325
326#[cfg(test)]
327mod tests {
328 use num_traits::CheckedAdd;
329 use vortex_buffer::buffer;
330 use vortex_error::VortexExpect;
331 use vortex_error::VortexResult;
332
333 use crate::ArrayRef;
334 use crate::IntoArray;
335 use crate::LEGACY_SESSION;
336 use crate::VortexSessionExecute;
337 use crate::aggregate_fn::Accumulator;
338 use crate::aggregate_fn::AggregateFnVTable;
339 use crate::aggregate_fn::DynAccumulator;
340 use crate::aggregate_fn::DynGroupedAccumulator;
341 use crate::aggregate_fn::EmptyOptions;
342 use crate::aggregate_fn::GroupedAccumulator;
343 use crate::aggregate_fn::fns::sum::Sum;
344 use crate::aggregate_fn::fns::sum::sum;
345 use crate::arrays::BoolArray;
346 use crate::arrays::ChunkedArray;
347 use crate::arrays::ConstantArray;
348 use crate::arrays::DecimalArray;
349 use crate::arrays::FixedSizeListArray;
350 use crate::arrays::PrimitiveArray;
351 use crate::assert_arrays_eq;
352 use crate::dtype::DType;
353 use crate::dtype::DecimalDType;
354 use crate::dtype::Nullability;
355 use crate::dtype::Nullability::Nullable;
356 use crate::dtype::PType;
357 use crate::dtype::i256;
358 use crate::expr::stats::Precision;
359 use crate::expr::stats::Stat;
360 use crate::expr::stats::StatsProvider;
361 use crate::scalar::DecimalValue;
362 use crate::scalar::NumericOperator;
363 use crate::scalar::Scalar;
364 use crate::validity::Validity;
365
366 fn sum_with_accumulator(array: &ArrayRef, accumulator: &Scalar) -> VortexResult<Scalar> {
368 let mut ctx = LEGACY_SESSION.create_execution_ctx();
369 if accumulator.is_null() {
370 return Ok(accumulator.clone());
371 }
372 if accumulator.is_zero() == Some(true) {
373 return sum(array, &mut ctx);
374 }
375
376 let sum_dtype = Stat::Sum.dtype(array.dtype()).ok_or_else(|| {
377 vortex_error::vortex_err!("Sum not supported for dtype: {}", array.dtype())
378 })?;
379
380 if !matches!(&sum_dtype, DType::Primitive(p, _) if p.is_float())
382 && let Some(Precision::Exact(sum_scalar)) = array.statistics().get(Stat::Sum)
383 {
384 return add_scalars(&sum_dtype, &sum_scalar, accumulator);
385 }
386
387 let array_sum = sum(array, &mut ctx)?;
389
390 add_scalars(&sum_dtype, &array_sum, accumulator)
392 }
393
394 fn add_scalars(sum_dtype: &DType, lhs: &Scalar, rhs: &Scalar) -> VortexResult<Scalar> {
396 if lhs.is_null() || rhs.is_null() {
397 return Ok(Scalar::null(sum_dtype.as_nullable()));
398 }
399
400 Ok(match sum_dtype {
401 DType::Primitive(ptype, _) if ptype.is_float() => {
402 let lhs_val = f64::try_from(lhs)?;
403 let rhs_val = f64::try_from(rhs)?;
404 Scalar::primitive(lhs_val + rhs_val, Nullable)
405 }
406 DType::Primitive(..) => lhs
407 .as_primitive()
408 .checked_add(&rhs.as_primitive())
409 .map(Scalar::from)
410 .unwrap_or_else(|| Scalar::null(sum_dtype.as_nullable())),
411 DType::Decimal(..) => lhs
412 .as_decimal()
413 .checked_binary_numeric(&rhs.as_decimal(), NumericOperator::Add)
414 .map(Scalar::from)
415 .unwrap_or_else(|| Scalar::null(sum_dtype.as_nullable())),
416 _ => unreachable!("Sum will always be a decimal or a primitive dtype"),
417 })
418 }
419
420 #[test]
423 fn sum_multi_batch() -> VortexResult<()> {
424 let mut ctx = LEGACY_SESSION.create_execution_ctx();
425 let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
426 let mut acc = Accumulator::try_new(Sum, EmptyOptions, dtype)?;
427
428 let batch1 = PrimitiveArray::new(buffer![10i32, 20], Validity::NonNullable).into_array();
429 acc.accumulate(&batch1, &mut ctx)?;
430
431 let batch2 = PrimitiveArray::new(buffer![3i32, 6, 9], Validity::NonNullable).into_array();
432 acc.accumulate(&batch2, &mut ctx)?;
433
434 let result = acc.finish()?;
435 assert_eq!(result.as_primitive().typed_value::<i64>(), Some(48));
436 Ok(())
437 }
438
439 #[test]
440 fn sum_finish_resets_state() -> VortexResult<()> {
441 let mut ctx = LEGACY_SESSION.create_execution_ctx();
442 let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
443 let mut acc = Accumulator::try_new(Sum, EmptyOptions, dtype)?;
444
445 let batch1 = PrimitiveArray::new(buffer![10i32, 20], Validity::NonNullable).into_array();
446 acc.accumulate(&batch1, &mut ctx)?;
447 let result1 = acc.finish()?;
448 assert_eq!(result1.as_primitive().typed_value::<i64>(), Some(30));
449
450 let batch2 = PrimitiveArray::new(buffer![3i32, 6, 9], Validity::NonNullable).into_array();
451 acc.accumulate(&batch2, &mut ctx)?;
452 let result2 = acc.finish()?;
453 assert_eq!(result2.as_primitive().typed_value::<i64>(), Some(18));
454 Ok(())
455 }
456
457 #[test]
460 fn sum_state_merge() -> VortexResult<()> {
461 let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
462 let mut state = Sum.empty_partial(&EmptyOptions, &dtype)?;
463
464 let scalar1 = Scalar::primitive(100i64, Nullable);
465 Sum.combine_partials(&mut state, scalar1)?;
466
467 let scalar2 = Scalar::primitive(50i64, Nullable);
468 Sum.combine_partials(&mut state, scalar2)?;
469
470 let result = Sum.to_scalar(&state)?;
471 Sum.reset(&mut state);
472 assert_eq!(result.as_primitive().typed_value::<i64>(), Some(150));
473 Ok(())
474 }
475
476 #[test]
479 fn sum_stats() -> VortexResult<()> {
480 let array = ChunkedArray::try_new(
481 vec![
482 PrimitiveArray::from_iter([1, 1, 1]).into_array(),
483 PrimitiveArray::from_iter([2, 2, 2]).into_array(),
484 ],
485 DType::Primitive(PType::I32, Nullability::NonNullable),
486 )
487 .vortex_expect("operation should succeed in test");
488 let array = array.into_array();
489 sum_with_accumulator(&array, &Scalar::primitive(2i64, Nullable))?;
491
492 let sum_without_acc = sum(&array, &mut LEGACY_SESSION.create_execution_ctx())?;
493 assert_eq!(sum_without_acc, Scalar::primitive(9i64, Nullable));
494 Ok(())
495 }
496
497 #[test]
500 fn sum_constant_float_non_multiply() -> VortexResult<()> {
501 let acc = -2048669276050936500000000000f64;
502 let array = ConstantArray::new(6.1811675e16f64, 25);
503 let result = sum_with_accumulator(&array.into_array(), &Scalar::primitive(acc, Nullable))
504 .vortex_expect("operation should succeed in test");
505 assert_eq!(
506 f64::try_from(&result).vortex_expect("operation should succeed in test"),
507 -2048669274505644600000000000f64
508 );
509 Ok(())
510 }
511
512 fn run_grouped_sum(groups: &ArrayRef, elem_dtype: &DType) -> VortexResult<ArrayRef> {
515 let mut acc = GroupedAccumulator::try_new(Sum, EmptyOptions, elem_dtype.clone())?;
516 acc.accumulate_list(groups, &mut LEGACY_SESSION.create_execution_ctx())?;
517 acc.finish()
518 }
519
520 #[test]
521 fn grouped_sum_fixed_size_list() -> VortexResult<()> {
522 let elements =
523 PrimitiveArray::new(buffer![1i32, 2, 3, 4, 5, 6], Validity::NonNullable).into_array();
524 let groups = FixedSizeListArray::try_new(elements, 3, Validity::NonNullable, 2)?;
525
526 let elem_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
527 let result = run_grouped_sum(&groups.into_array(), &elem_dtype)?;
528
529 let expected = PrimitiveArray::from_option_iter([Some(6i64), Some(15i64)]).into_array();
530 assert_arrays_eq!(&result, &expected);
531 Ok(())
532 }
533
534 #[test]
535 fn grouped_sum_with_null_elements() -> VortexResult<()> {
536 let elements =
537 PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), None, Some(5), Some(6)])
538 .into_array();
539 let groups = FixedSizeListArray::try_new(elements, 3, Validity::NonNullable, 2)?;
540
541 let elem_dtype = DType::Primitive(PType::I32, Nullable);
542 let result = run_grouped_sum(&groups.into_array(), &elem_dtype)?;
543
544 let expected = PrimitiveArray::from_option_iter([Some(4i64), Some(11i64)]).into_array();
545 assert_arrays_eq!(&result, &expected);
546 Ok(())
547 }
548
549 #[test]
550 fn grouped_sum_with_null_group() -> VortexResult<()> {
551 let elements =
552 PrimitiveArray::new(buffer![1i32, 2, 3, 4, 5, 6, 7, 8, 9], Validity::NonNullable)
553 .into_array();
554 let validity = Validity::from_iter([true, false, true]);
555 let groups = FixedSizeListArray::try_new(elements, 3, validity, 3)?;
556
557 let elem_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
558 let result = run_grouped_sum(&groups.into_array(), &elem_dtype)?;
559
560 let expected =
561 PrimitiveArray::from_option_iter([Some(6i64), None, Some(24i64)]).into_array();
562 assert_arrays_eq!(&result, &expected);
563 Ok(())
564 }
565
566 #[test]
567 fn grouped_sum_all_null_elements_in_group() -> VortexResult<()> {
568 let elements =
569 PrimitiveArray::from_option_iter([None::<i32>, None, Some(3), Some(4)]).into_array();
570 let groups = FixedSizeListArray::try_new(elements, 2, Validity::NonNullable, 2)?;
571
572 let elem_dtype = DType::Primitive(PType::I32, Nullable);
573 let result = run_grouped_sum(&groups.into_array(), &elem_dtype)?;
574
575 let expected = PrimitiveArray::from_option_iter([Some(0i64), Some(7i64)]).into_array();
576 assert_arrays_eq!(&result, &expected);
577 Ok(())
578 }
579
580 #[test]
581 fn grouped_sum_bool() -> VortexResult<()> {
582 let elements: BoolArray = [true, false, true, true, true, true].into_iter().collect();
583 let groups =
584 FixedSizeListArray::try_new(elements.into_array(), 3, Validity::NonNullable, 2)?;
585
586 let elem_dtype = DType::Bool(Nullability::NonNullable);
587 let result = run_grouped_sum(&groups.into_array(), &elem_dtype)?;
588
589 let expected = PrimitiveArray::from_option_iter([Some(2u64), Some(3u64)]).into_array();
590 assert_arrays_eq!(&result, &expected);
591 Ok(())
592 }
593
594 #[test]
595 fn grouped_sum_finish_resets() -> VortexResult<()> {
596 let mut ctx = LEGACY_SESSION.create_execution_ctx();
597 let elem_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
598 let mut acc = GroupedAccumulator::try_new(Sum, EmptyOptions, elem_dtype)?;
599
600 let elements1 =
601 PrimitiveArray::new(buffer![1i32, 2, 3, 4], Validity::NonNullable).into_array();
602 let groups1 = FixedSizeListArray::try_new(elements1, 2, Validity::NonNullable, 2)?;
603 acc.accumulate_list(&groups1.into_array(), &mut ctx)?;
604 let result1 = acc.finish()?;
605
606 let expected1 = PrimitiveArray::from_option_iter([Some(3i64), Some(7i64)]).into_array();
607 assert_arrays_eq!(&result1, &expected1);
608
609 let elements2 = PrimitiveArray::new(buffer![10i32, 20], Validity::NonNullable).into_array();
610 let groups2 = FixedSizeListArray::try_new(elements2, 2, Validity::NonNullable, 1)?;
611 acc.accumulate_list(&groups2.into_array(), &mut ctx)?;
612 let result2 = acc.finish()?;
613
614 let expected2 = PrimitiveArray::from_option_iter([Some(30i64)]).into_array();
615 assert_arrays_eq!(&result2, &expected2);
616 Ok(())
617 }
618
619 #[test]
622 fn sum_chunked_floats_with_nulls() -> VortexResult<()> {
623 let chunk1 =
624 PrimitiveArray::from_option_iter(vec![Some(1.5f64), None, Some(3.2), Some(4.8)]);
625 let chunk2 = PrimitiveArray::from_option_iter(vec![Some(2.1f64), Some(5.7), None]);
626 let chunk3 = PrimitiveArray::from_option_iter(vec![None, Some(1.0f64), Some(2.5), None]);
627 let dtype = chunk1.dtype().clone();
628 let chunked = ChunkedArray::try_new(
629 vec![
630 chunk1.into_array(),
631 chunk2.into_array(),
632 chunk3.into_array(),
633 ],
634 dtype,
635 )?;
636
637 let result = sum(
638 &chunked.into_array(),
639 &mut LEGACY_SESSION.create_execution_ctx(),
640 )?;
641 assert_eq!(result.as_primitive().as_::<f64>(), Some(20.8));
642 Ok(())
643 }
644
645 #[test]
646 fn sum_chunked_floats_all_nulls_is_zero() -> VortexResult<()> {
647 let chunk1 = PrimitiveArray::from_option_iter::<f32, _>(vec![None, None, None]);
648 let chunk2 = PrimitiveArray::from_option_iter::<f32, _>(vec![None, None]);
649 let dtype = chunk1.dtype().clone();
650 let chunked = ChunkedArray::try_new(vec![chunk1.into_array(), chunk2.into_array()], dtype)?;
651 let result = sum(
652 &chunked.into_array(),
653 &mut LEGACY_SESSION.create_execution_ctx(),
654 )?;
655 assert_eq!(result, Scalar::primitive(0f64, Nullable));
656 Ok(())
657 }
658
659 #[test]
660 fn sum_chunked_floats_empty_chunks() -> VortexResult<()> {
661 let chunk1 = PrimitiveArray::from_option_iter(vec![Some(10.5f64), Some(20.3)]);
662 let chunk2 = ConstantArray::new(Scalar::primitive(0f64, Nullable), 0);
663 let chunk3 = PrimitiveArray::from_option_iter(vec![Some(5.2f64)]);
664 let dtype = chunk1.dtype().clone();
665 let chunked = ChunkedArray::try_new(
666 vec![
667 chunk1.into_array(),
668 chunk2.into_array(),
669 chunk3.into_array(),
670 ],
671 dtype,
672 )?;
673
674 let result = sum(
675 &chunked.into_array(),
676 &mut LEGACY_SESSION.create_execution_ctx(),
677 )?;
678 assert_eq!(result.as_primitive().as_::<f64>(), Some(36.0));
679 Ok(())
680 }
681
682 #[test]
683 fn sum_chunked_int_almost_all_null() -> VortexResult<()> {
684 let chunk1 = PrimitiveArray::from_option_iter::<u32, _>(vec![Some(1)]);
685 let chunk2 = PrimitiveArray::from_option_iter::<u32, _>(vec![None]);
686 let dtype = chunk1.dtype().clone();
687 let chunked = ChunkedArray::try_new(vec![chunk1.into_array(), chunk2.into_array()], dtype)?;
688
689 let result = sum(
690 &chunked.into_array(),
691 &mut LEGACY_SESSION.create_execution_ctx(),
692 )?;
693 assert_eq!(result.as_primitive().as_::<u64>(), Some(1));
694 Ok(())
695 }
696
697 #[test]
698 fn sum_chunked_decimals() -> VortexResult<()> {
699 let decimal_dtype = DecimalDType::new(10, 2);
700 let chunk1 = DecimalArray::new(
701 buffer![100i32, 100i32, 100i32, 100i32, 100i32],
702 decimal_dtype,
703 Validity::AllValid,
704 );
705 let chunk2 = DecimalArray::new(
706 buffer![200i32, 200i32, 200i32],
707 decimal_dtype,
708 Validity::AllValid,
709 );
710 let chunk3 = DecimalArray::new(buffer![300i32, 300i32], decimal_dtype, Validity::AllValid);
711 let dtype = chunk1.dtype().clone();
712 let chunked = ChunkedArray::try_new(
713 vec![
714 chunk1.into_array(),
715 chunk2.into_array(),
716 chunk3.into_array(),
717 ],
718 dtype,
719 )?;
720
721 let result = sum(
722 &chunked.into_array(),
723 &mut LEGACY_SESSION.create_execution_ctx(),
724 )?;
725 let decimal_result = result.as_decimal();
726 assert_eq!(
727 decimal_result.decimal_value(),
728 Some(DecimalValue::I256(i256::from_i128(1700)))
729 );
730 Ok(())
731 }
732
733 #[test]
734 fn sum_chunked_decimals_with_nulls() -> VortexResult<()> {
735 let decimal_dtype = DecimalDType::new(10, 2);
736 let chunk1 = DecimalArray::new(
737 buffer![100i32, 100i32, 100i32],
738 decimal_dtype,
739 Validity::AllValid,
740 );
741 let chunk2 = DecimalArray::new(
742 buffer![0i32, 0i32],
743 decimal_dtype,
744 Validity::from_iter([false, false]),
745 );
746 let chunk3 = DecimalArray::new(buffer![200i32, 200i32], decimal_dtype, Validity::AllValid);
747 let dtype = chunk1.dtype().clone();
748 let chunked = ChunkedArray::try_new(
749 vec![
750 chunk1.into_array(),
751 chunk2.into_array(),
752 chunk3.into_array(),
753 ],
754 dtype,
755 )?;
756
757 let result = sum(
758 &chunked.into_array(),
759 &mut LEGACY_SESSION.create_execution_ctx(),
760 )?;
761 let decimal_result = result.as_decimal();
762 assert_eq!(
763 decimal_result.decimal_value(),
764 Some(DecimalValue::I256(i256::from_i128(700)))
765 );
766 Ok(())
767 }
768
769 #[test]
770 fn sum_chunked_decimals_large() -> VortexResult<()> {
771 let decimal_dtype = DecimalDType::new(3, 0);
772 let chunk1 = ConstantArray::new(
773 Scalar::decimal(
774 DecimalValue::I16(500),
775 decimal_dtype,
776 Nullability::NonNullable,
777 ),
778 1,
779 );
780 let chunk2 = ConstantArray::new(
781 Scalar::decimal(
782 DecimalValue::I16(600),
783 decimal_dtype,
784 Nullability::NonNullable,
785 ),
786 1,
787 );
788 let dtype = chunk1.dtype().clone();
789 let chunked = ChunkedArray::try_new(vec![chunk1.into_array(), chunk2.into_array()], dtype)?;
790
791 let result = sum(
792 &chunked.into_array(),
793 &mut LEGACY_SESSION.create_execution_ctx(),
794 )?;
795 let decimal_result = result.as_decimal();
796 assert_eq!(
797 decimal_result.decimal_value(),
798 Some(DecimalValue::I256(i256::from_i128(1100)))
799 );
800 assert_eq!(
801 result.dtype(),
802 &DType::Decimal(DecimalDType::new(13, 0), Nullable)
803 );
804 Ok(())
805 }
806}