Skip to main content

vortex_array/aggregate_fn/fns/sum/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod bool;
5mod constant;
6mod decimal;
7mod grouped;
8mod primitive;
9pub(crate) use grouped::PrimitiveGroupedSumEncodingKernel;
10use vortex_error::VortexExpect;
11use vortex_error::VortexResult;
12use vortex_error::vortex_bail;
13use vortex_error::vortex_err;
14use vortex_error::vortex_panic;
15
16use self::bool::accumulate_bool;
17use self::constant::multiply_constant;
18use self::decimal::accumulate_decimal;
19use self::primitive::accumulate_primitive;
20use crate::ArrayRef;
21use crate::Canonical;
22use crate::Columnar;
23use crate::ExecutionCtx;
24use crate::aggregate_fn::Accumulator;
25use crate::aggregate_fn::AggregateFnId;
26use crate::aggregate_fn::AggregateFnVTable;
27use crate::aggregate_fn::DynAccumulator;
28use crate::aggregate_fn::EmptyOptions;
29use crate::dtype::DType;
30use crate::dtype::DecimalDType;
31use crate::dtype::MAX_PRECISION;
32use crate::dtype::Nullability;
33use crate::dtype::PType;
34use crate::expr::stats::Precision;
35use crate::expr::stats::Stat;
36use crate::expr::stats::StatsProvider;
37use crate::scalar::DecimalValue;
38use crate::scalar::Scalar;
39
40/// Return the sum of an array.
41///
42/// See [`Sum`] for details.
43pub fn sum(array: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<Scalar> {
44    // Short-circuit using cached array statistics.
45    if let Precision::Exact(sum_scalar) = array.statistics().get(Stat::Sum) {
46        return Ok(sum_scalar);
47    }
48
49    // Compute using Accumulator<Sum>.
50    // TODO(ngates): we may want to wrap this three-step dance up into an extension crate maybe.
51    let mut acc = Accumulator::try_new(Sum, EmptyOptions, array.dtype().clone())?;
52    acc.accumulate(array, ctx)?;
53    let result = acc.finish()?;
54
55    // Cache the computed sum as a statistic (only if non-null, i.e. no overflow).
56    if let Some(val) = result.value().cloned() {
57        array.statistics().set(Stat::Sum, Precision::Exact(val));
58    }
59
60    Ok(result)
61}
62
63/// Sum an array, starting from zero.
64///
65/// If the sum overflows, a null scalar will be returned.
66/// If the array is all-invalid, the sum will be zero.
67#[derive(Clone, Debug)]
68pub struct Sum;
69
70impl AggregateFnVTable for Sum {
71    type Options = EmptyOptions;
72    type Partial = SumPartial;
73
74    fn id(&self) -> AggregateFnId {
75        AggregateFnId::new("vortex.sum")
76    }
77
78    fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
79        unimplemented!("Sum is not yet serializable");
80    }
81
82    fn return_dtype(&self, _options: &Self::Options, input_dtype: &DType) -> Option<DType> {
83        // When a sum overflows, we return a sum _value_ of null. Therefore, we all return dtypes
84        // are nullable.
85        use Nullability::Nullable;
86
87        Some(match input_dtype {
88            DType::Bool(_) => DType::Primitive(PType::U64, Nullable),
89            DType::Primitive(ptype, _) => match ptype {
90                PType::U8 | PType::U16 | PType::U32 | PType::U64 => {
91                    DType::Primitive(PType::U64, Nullable)
92                }
93                PType::I8 | PType::I16 | PType::I32 | PType::I64 => {
94                    DType::Primitive(PType::I64, Nullable)
95                }
96                PType::F16 | PType::F32 | PType::F64 => {
97                    // Float sums cannot overflow, but all null floats still end up as null
98                    DType::Primitive(PType::F64, Nullable)
99                }
100            },
101            DType::Decimal(decimal_dtype, _) => {
102                // Both Spark and DataFusion use this heuristic.
103                // - https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
104                // - https://github.com/apache/datafusion/blob/4153adf2c0f6e317ef476febfdc834208bd46622/datafusion/functions-aggregate/src/sum.rs#L188
105                let precision = u8::min(MAX_PRECISION, decimal_dtype.precision() + 10);
106                DType::Decimal(
107                    DecimalDType::new(precision, decimal_dtype.scale()),
108                    Nullable,
109                )
110            }
111            // Unsupported types
112            _ => return None,
113        })
114    }
115
116    fn partial_dtype(&self, options: &Self::Options, input_dtype: &DType) -> Option<DType> {
117        self.return_dtype(options, input_dtype)
118    }
119
120    fn empty_partial(
121        &self,
122        options: &Self::Options,
123        input_dtype: &DType,
124    ) -> VortexResult<Self::Partial> {
125        let return_dtype = self
126            .return_dtype(options, input_dtype)
127            .ok_or_else(|| vortex_err!("Unsupported sum dtype: {}", input_dtype))?;
128        let initial = make_zero_state(&return_dtype);
129
130        Ok(SumPartial {
131            return_dtype,
132            current: Some(initial),
133        })
134    }
135
136    fn combine_partials(&self, partial: &mut Self::Partial, other: Scalar) -> VortexResult<()> {
137        if other.is_null() {
138            // A null partial means the sub-accumulator saturated (overflow).
139            partial.current = None;
140            return Ok(());
141        }
142        let Some(ref mut inner) = partial.current else {
143            return Ok(());
144        };
145        let saturated = match inner {
146            SumState::Unsigned(acc) => {
147                let val = other
148                    .as_primitive()
149                    .typed_value::<u64>()
150                    .vortex_expect("checked non-null");
151                checked_add_u64(acc, val)
152            }
153            SumState::Signed(acc) => {
154                let val = other
155                    .as_primitive()
156                    .typed_value::<i64>()
157                    .vortex_expect("checked non-null");
158                checked_add_i64(acc, val)
159            }
160            SumState::Float(acc) => {
161                let val = other
162                    .as_primitive()
163                    .typed_value::<f64>()
164                    .vortex_expect("checked non-null");
165                *acc += val;
166                false
167            }
168            SumState::Decimal { value, dtype } => {
169                let val = other
170                    .as_decimal()
171                    .decimal_value()
172                    .vortex_expect("checked non-null");
173                match value.checked_add(&val) {
174                    Some(r) => {
175                        *value = r;
176                        !value.fits_in_precision(*dtype)
177                    }
178                    None => true,
179                }
180            }
181        };
182        if saturated {
183            partial.current = None;
184        }
185        Ok(())
186    }
187
188    fn to_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
189        Ok(match &partial.current {
190            None => Scalar::null(partial.return_dtype.as_nullable()),
191            Some(SumState::Unsigned(v)) => Scalar::primitive(*v, Nullability::Nullable),
192            Some(SumState::Signed(v)) => Scalar::primitive(*v, Nullability::Nullable),
193            Some(SumState::Float(v)) => Scalar::primitive(*v, Nullability::Nullable),
194            Some(SumState::Decimal { value, .. }) => {
195                let decimal_dtype = *partial
196                    .return_dtype
197                    .as_decimal_opt()
198                    .vortex_expect("return dtype must be decimal");
199                Scalar::decimal(*value, decimal_dtype, Nullability::Nullable)
200            }
201        })
202    }
203
204    fn reset(&self, partial: &mut Self::Partial) {
205        partial.current = Some(make_zero_state(&partial.return_dtype));
206    }
207
208    #[inline]
209    fn is_saturated(&self, partial: &Self::Partial) -> bool {
210        match partial.current.as_ref() {
211            None => true,
212            Some(SumState::Float(v)) => v.is_nan(),
213            Some(_) => false,
214        }
215    }
216
217    fn accumulate(
218        &self,
219        partial: &mut Self::Partial,
220        batch: &Columnar,
221        ctx: &mut ExecutionCtx,
222    ) -> VortexResult<()> {
223        // Constants compute scalar * len and combine via combine_partials.
224        if let Columnar::Constant(c) = batch {
225            if let Some(product) = multiply_constant(c.scalar(), c.len(), &partial.return_dtype)? {
226                self.combine_partials(partial, product)?;
227            }
228            return Ok(());
229        }
230
231        let mut inner = match partial.current.take() {
232            Some(inner) => inner,
233            None => return Ok(()),
234        };
235
236        let result = match batch {
237            Columnar::Canonical(c) => match c {
238                Canonical::Primitive(p) => accumulate_primitive(&mut inner, p, ctx),
239                Canonical::Bool(b) => accumulate_bool(&mut inner, b, ctx),
240                Canonical::Decimal(d) => accumulate_decimal(&mut inner, d, ctx),
241                _ => vortex_bail!("Unsupported canonical type for sum: {}", batch.dtype()),
242            },
243            Columnar::Constant(_) => unreachable!(),
244        };
245
246        match result {
247            Ok(false) => partial.current = Some(inner),
248            Ok(true) => {} // saturated: current stays None
249            Err(e) => {
250                partial.current = Some(inner);
251                return Err(e);
252            }
253        }
254        Ok(())
255    }
256
257    fn finalize(&self, partials: ArrayRef) -> VortexResult<ArrayRef> {
258        Ok(partials)
259    }
260
261    fn finalize_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
262        self.to_scalar(partial)
263    }
264}
265
266/// The group state for a sum aggregate, containing the accumulated value and configuration
267/// needed for reset/result without external context.
268pub struct SumPartial {
269    return_dtype: DType,
270    /// The current accumulated state, or `None` if saturated (checked overflow).
271    current: Option<SumState>,
272}
273
274/// The accumulated sum value.
275///
276// TODO(ngates): instead of an enum, we should use a Box<dyn State> to avoid dispatcher over the
277//  input type every time? Perhaps?
278pub enum SumState {
279    Unsigned(u64),
280    Signed(i64),
281    Float(f64),
282    Decimal {
283        value: DecimalValue,
284        dtype: DecimalDType,
285    },
286}
287
288fn make_zero_state(return_dtype: &DType) -> SumState {
289    match return_dtype {
290        DType::Primitive(ptype, _) => match ptype {
291            PType::U8 | PType::U16 | PType::U32 | PType::U64 => SumState::Unsigned(0),
292            PType::I8 | PType::I16 | PType::I32 | PType::I64 => SumState::Signed(0),
293            PType::F16 | PType::F32 | PType::F64 => SumState::Float(0.0),
294        },
295        DType::Decimal(decimal, _) => SumState::Decimal {
296            value: DecimalValue::zero(decimal),
297            dtype: *decimal,
298        },
299        _ => vortex_panic!("Unsupported sum type"),
300    }
301}
302
/// Adds `val` to `acc` in place with overflow checking.
///
/// Returns `true` if the addition would overflow, in which case `acc` is left unchanged.
/// Returns `false` after storing the sum in `acc`.
#[inline(always)]
fn checked_add_u64(acc: &mut u64, val: u64) -> bool {
    let Some(total) = acc.checked_add(val) else {
        return true;
    };
    *acc = total;
    false
}
314
/// Adds `val` to `acc` in place with overflow checking.
///
/// Returns `true` if the addition would overflow in either direction, in which case `acc` is
/// left unchanged. Returns `false` after storing the sum in `acc`.
#[inline(always)]
fn checked_add_i64(acc: &mut i64, val: i64) -> bool {
    let Some(total) = acc.checked_add(val) else {
        return true;
    };
    *acc = total;
    false
}
326
327#[cfg(test)]
328mod tests {
329    use num_traits::CheckedAdd;
330    use vortex_buffer::buffer;
331    use vortex_error::VortexExpect;
332    use vortex_error::VortexResult;
333
334    use crate::ArrayRef;
335    use crate::IntoArray;
336    use crate::LEGACY_SESSION;
337    use crate::VortexSessionExecute;
338    use crate::aggregate_fn::Accumulator;
339    use crate::aggregate_fn::AggregateFnVTable;
340    use crate::aggregate_fn::DynAccumulator;
341    use crate::aggregate_fn::DynGroupedAccumulator;
342    use crate::aggregate_fn::EmptyOptions;
343    use crate::aggregate_fn::GroupedAccumulator;
344    use crate::aggregate_fn::fns::sum::Sum;
345    use crate::aggregate_fn::fns::sum::sum;
346    use crate::arrays::BoolArray;
347    use crate::arrays::ChunkedArray;
348    use crate::arrays::ConstantArray;
349    use crate::arrays::DecimalArray;
350    use crate::arrays::FixedSizeListArray;
351    use crate::arrays::ListViewArray;
352    use crate::arrays::PrimitiveArray;
353    use crate::assert_arrays_eq;
354    use crate::dtype::DType;
355    use crate::dtype::DecimalDType;
356    use crate::dtype::Nullability;
357    use crate::dtype::Nullability::Nullable;
358    use crate::dtype::PType;
359    use crate::dtype::i256;
360    use crate::expr::stats::Precision;
361    use crate::expr::stats::Stat;
362    use crate::expr::stats::StatsProvider;
363    use crate::scalar::DecimalValue;
364    use crate::scalar::NumericOperator;
365    use crate::scalar::Scalar;
366    use crate::validity::Validity;
367
368    /// Sum an array with an initial value (test-only helper).
369    fn sum_with_accumulator(array: &ArrayRef, accumulator: &Scalar) -> VortexResult<Scalar> {
370        let mut ctx = LEGACY_SESSION.create_execution_ctx();
371        if accumulator.is_null() {
372            return Ok(accumulator.clone());
373        }
374        if accumulator.is_zero() == Some(true) {
375            return sum(array, &mut ctx);
376        }
377
378        let sum_dtype = Stat::Sum.dtype(array.dtype()).ok_or_else(|| {
379            vortex_error::vortex_err!("Sum not supported for dtype: {}", array.dtype())
380        })?;
381
382        // For non-float types, try statistics short-circuit with accumulator.
383        if !matches!(&sum_dtype, DType::Primitive(p, _) if p.is_float())
384            && let Precision::Exact(sum_scalar) = array.statistics().get(Stat::Sum)
385        {
386            return add_scalars(&sum_dtype, &sum_scalar, accumulator);
387        }
388
389        // Compute array sum from zero (also caches stats).
390        let array_sum = sum(array, &mut ctx)?;
391
392        // Combine with the accumulator.
393        add_scalars(&sum_dtype, &array_sum, accumulator)
394    }
395
396    /// Add two sum scalars with overflow checking.
397    fn add_scalars(sum_dtype: &DType, lhs: &Scalar, rhs: &Scalar) -> VortexResult<Scalar> {
398        if lhs.is_null() || rhs.is_null() {
399            return Ok(Scalar::null(sum_dtype.as_nullable()));
400        }
401
402        Ok(match sum_dtype {
403            DType::Primitive(ptype, _) if ptype.is_float() => {
404                let lhs_val = f64::try_from(lhs)?;
405                let rhs_val = f64::try_from(rhs)?;
406                Scalar::primitive(lhs_val + rhs_val, Nullable)
407            }
408            DType::Primitive(..) => lhs
409                .as_primitive()
410                .checked_add(&rhs.as_primitive())
411                .map(Scalar::from)
412                .unwrap_or_else(|| Scalar::null(sum_dtype.as_nullable())),
413            DType::Decimal(..) => lhs
414                .as_decimal()
415                .checked_binary_numeric(&rhs.as_decimal(), NumericOperator::Add)
416                .map(Scalar::from)
417                .unwrap_or_else(|| Scalar::null(sum_dtype.as_nullable())),
418            _ => unreachable!("Sum will always be a decimal or a primitive dtype"),
419        })
420    }
421
422    // Multi-batch and reset tests
423
424    #[test]
425    fn sum_multi_batch() -> VortexResult<()> {
426        let mut ctx = LEGACY_SESSION.create_execution_ctx();
427        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
428        let mut acc = Accumulator::try_new(Sum, EmptyOptions, dtype)?;
429
430        let batch1 = PrimitiveArray::new(buffer![10i32, 20], Validity::NonNullable).into_array();
431        acc.accumulate(&batch1, &mut ctx)?;
432
433        let batch2 = PrimitiveArray::new(buffer![3i32, 6, 9], Validity::NonNullable).into_array();
434        acc.accumulate(&batch2, &mut ctx)?;
435
436        let result = acc.finish()?;
437        assert_eq!(result.as_primitive().typed_value::<i64>(), Some(48));
438        Ok(())
439    }
440
441    #[test]
442    fn sum_finish_resets_state() -> VortexResult<()> {
443        let mut ctx = LEGACY_SESSION.create_execution_ctx();
444        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
445        let mut acc = Accumulator::try_new(Sum, EmptyOptions, dtype)?;
446
447        let batch1 = PrimitiveArray::new(buffer![10i32, 20], Validity::NonNullable).into_array();
448        acc.accumulate(&batch1, &mut ctx)?;
449        let result1 = acc.finish()?;
450        assert_eq!(result1.as_primitive().typed_value::<i64>(), Some(30));
451
452        let batch2 = PrimitiveArray::new(buffer![3i32, 6, 9], Validity::NonNullable).into_array();
453        acc.accumulate(&batch2, &mut ctx)?;
454        let result2 = acc.finish()?;
455        assert_eq!(result2.as_primitive().typed_value::<i64>(), Some(18));
456        Ok(())
457    }
458
459    // State merge tests (vtable-level)
460
461    #[test]
462    fn sum_state_merge() -> VortexResult<()> {
463        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
464        let mut state = Sum.empty_partial(&EmptyOptions, &dtype)?;
465
466        let scalar1 = Scalar::primitive(100i64, Nullable);
467        Sum.combine_partials(&mut state, scalar1)?;
468
469        let scalar2 = Scalar::primitive(50i64, Nullable);
470        Sum.combine_partials(&mut state, scalar2)?;
471
472        let result = Sum.to_scalar(&state)?;
473        Sum.reset(&mut state);
474        assert_eq!(result.as_primitive().typed_value::<i64>(), Some(150));
475        Ok(())
476    }
477
478    // Stats caching test
479
480    #[test]
481    fn sum_stats() -> VortexResult<()> {
482        let array = ChunkedArray::try_new(
483            vec![
484                PrimitiveArray::from_iter([1, 1, 1]).into_array(),
485                PrimitiveArray::from_iter([2, 2, 2]).into_array(),
486            ],
487            DType::Primitive(PType::I32, Nullability::NonNullable),
488        )
489        .vortex_expect("operation should succeed in test");
490        let array = array.into_array();
491        // compute sum with accumulator to populate stats
492        sum_with_accumulator(&array, &Scalar::primitive(2i64, Nullable))?;
493
494        let sum_without_acc = sum(&array, &mut LEGACY_SESSION.create_execution_ctx())?;
495        assert_eq!(sum_without_acc, Scalar::primitive(9i64, Nullable));
496        Ok(())
497    }
498
499    // Constant float non-multiply test
500
501    #[test]
502    fn sum_constant_float_non_multiply() -> VortexResult<()> {
503        let acc = -2048669276050936500000000000f64;
504        let array = ConstantArray::new(6.1811675e16f64, 25);
505        let result = sum_with_accumulator(&array.into_array(), &Scalar::primitive(acc, Nullable))
506            .vortex_expect("operation should succeed in test");
507        assert_eq!(
508            f64::try_from(&result).vortex_expect("operation should succeed in test"),
509            -2048669274505644600000000000f64
510        );
511        Ok(())
512    }
513
514    // Grouped sum tests
515
516    fn run_grouped_sum(groups: &ArrayRef, elem_dtype: &DType) -> VortexResult<ArrayRef> {
517        let mut acc = GroupedAccumulator::try_new(Sum, EmptyOptions, elem_dtype.clone())?;
518        acc.accumulate_list(groups, &mut LEGACY_SESSION.create_execution_ctx())?;
519        acc.finish()
520    }
521
522    #[test]
523    fn grouped_sum_fixed_size_list() -> VortexResult<()> {
524        let elements =
525            PrimitiveArray::new(buffer![1i32, 2, 3, 4, 5, 6], Validity::NonNullable).into_array();
526        let groups = FixedSizeListArray::try_new(elements, 3, Validity::NonNullable, 2)?;
527
528        let elem_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
529        let result = run_grouped_sum(&groups.into_array(), &elem_dtype)?;
530
531        let expected = PrimitiveArray::from_option_iter([Some(6i64), Some(15i64)]).into_array();
532        assert_arrays_eq!(&result, &expected);
533        Ok(())
534    }
535
536    #[test]
537    fn grouped_sum_with_null_elements() -> VortexResult<()> {
538        let elements =
539            PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), None, Some(5), Some(6)])
540                .into_array();
541        let groups = FixedSizeListArray::try_new(elements, 3, Validity::NonNullable, 2)?;
542
543        let elem_dtype = DType::Primitive(PType::I32, Nullable);
544        let result = run_grouped_sum(&groups.into_array(), &elem_dtype)?;
545
546        let expected = PrimitiveArray::from_option_iter([Some(4i64), Some(11i64)]).into_array();
547        assert_arrays_eq!(&result, &expected);
548        Ok(())
549    }
550
551    #[test]
552    fn grouped_sum_with_null_group() -> VortexResult<()> {
553        let elements =
554            PrimitiveArray::new(buffer![1i32, 2, 3, 4, 5, 6, 7, 8, 9], Validity::NonNullable)
555                .into_array();
556        let validity = Validity::from_iter([true, false, true]);
557        let groups = FixedSizeListArray::try_new(elements, 3, validity, 3)?;
558
559        let elem_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
560        let result = run_grouped_sum(&groups.into_array(), &elem_dtype)?;
561
562        let expected =
563            PrimitiveArray::from_option_iter([Some(6i64), None, Some(24i64)]).into_array();
564        assert_arrays_eq!(&result, &expected);
565        Ok(())
566    }
567
568    #[test]
569    fn grouped_sum_all_null_elements_in_group() -> VortexResult<()> {
570        let elements =
571            PrimitiveArray::from_option_iter([None::<i32>, None, Some(3), Some(4)]).into_array();
572        let groups = FixedSizeListArray::try_new(elements, 2, Validity::NonNullable, 2)?;
573
574        let elem_dtype = DType::Primitive(PType::I32, Nullable);
575        let result = run_grouped_sum(&groups.into_array(), &elem_dtype)?;
576
577        let expected = PrimitiveArray::from_option_iter([Some(0i64), Some(7i64)]).into_array();
578        assert_arrays_eq!(&result, &expected);
579        Ok(())
580    }
581
582    #[test]
583    fn grouped_sum_bool() -> VortexResult<()> {
584        let elements: BoolArray = [true, false, true, true, true, true].into_iter().collect();
585        let groups =
586            FixedSizeListArray::try_new(elements.into_array(), 3, Validity::NonNullable, 2)?;
587
588        let elem_dtype = DType::Bool(Nullability::NonNullable);
589        let result = run_grouped_sum(&groups.into_array(), &elem_dtype)?;
590
591        let expected = PrimitiveArray::from_option_iter([Some(2u64), Some(3u64)]).into_array();
592        assert_arrays_eq!(&result, &expected);
593        Ok(())
594    }
595
596    #[test]
597    fn grouped_sum_finish_resets() -> VortexResult<()> {
598        let mut ctx = LEGACY_SESSION.create_execution_ctx();
599        let elem_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
600        let mut acc = GroupedAccumulator::try_new(Sum, EmptyOptions, elem_dtype)?;
601
602        let elements1 =
603            PrimitiveArray::new(buffer![1i32, 2, 3, 4], Validity::NonNullable).into_array();
604        let groups1 = FixedSizeListArray::try_new(elements1, 2, Validity::NonNullable, 2)?;
605        acc.accumulate_list(&groups1.into_array(), &mut ctx)?;
606        let result1 = acc.finish()?;
607
608        let expected1 = PrimitiveArray::from_option_iter([Some(3i64), Some(7i64)]).into_array();
609        assert_arrays_eq!(&result1, &expected1);
610
611        let elements2 = PrimitiveArray::new(buffer![10i32, 20], Validity::NonNullable).into_array();
612        let groups2 = FixedSizeListArray::try_new(elements2, 2, Validity::NonNullable, 1)?;
613        acc.accumulate_list(&groups2.into_array(), &mut ctx)?;
614        let result2 = acc.finish()?;
615
616        let expected2 = PrimitiveArray::from_option_iter([Some(30i64)]).into_array();
617        assert_arrays_eq!(&result2, &expected2);
618        Ok(())
619    }
620
621    #[test]
622    fn grouped_sum_listview_out_of_order_offsets_with_null_group() -> VortexResult<()> {
623        let elements =
624            PrimitiveArray::new(buffer![100i32, 200, 300], Validity::NonNullable).into_array();
625        let offsets = PrimitiveArray::new(buffer![2i32, 0, 1], Validity::NonNullable).into_array();
626        let sizes = PrimitiveArray::new(buffer![1i32, 1, 1], Validity::NonNullable).into_array();
627        let validity = Validity::from_iter([true, false, true]);
628        let groups = ListViewArray::try_new(elements, offsets, sizes, validity)?.into_array();
629
630        let elem_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
631        let result = run_grouped_sum(&groups, &elem_dtype)?;
632
633        // group 0 -> elements[2..3] = 300; group 1 -> null; group 2 -> elements[1..2] = 200.
634        let expected =
635            PrimitiveArray::from_option_iter([Some(300i64), None, Some(200i64)]).into_array();
636        assert_arrays_eq!(&result, &expected);
637        Ok(())
638    }
639
640    // Chunked array tests
641
642    #[test]
643    fn sum_chunked_floats_with_nulls() -> VortexResult<()> {
644        let chunk1 =
645            PrimitiveArray::from_option_iter(vec![Some(1.5f64), None, Some(3.2), Some(4.8)]);
646        let chunk2 = PrimitiveArray::from_option_iter(vec![Some(2.1f64), Some(5.7), None]);
647        let chunk3 = PrimitiveArray::from_option_iter(vec![None, Some(1.0f64), Some(2.5), None]);
648        let dtype = chunk1.dtype().clone();
649        let chunked = ChunkedArray::try_new(
650            vec![
651                chunk1.into_array(),
652                chunk2.into_array(),
653                chunk3.into_array(),
654            ],
655            dtype,
656        )?;
657
658        let result = sum(
659            &chunked.into_array(),
660            &mut LEGACY_SESSION.create_execution_ctx(),
661        )?;
662        assert_eq!(result.as_primitive().as_::<f64>(), Some(20.8));
663        Ok(())
664    }
665
666    #[test]
667    fn sum_chunked_floats_all_nulls_is_zero() -> VortexResult<()> {
668        let chunk1 = PrimitiveArray::from_option_iter::<f32, _>(vec![None, None, None]);
669        let chunk2 = PrimitiveArray::from_option_iter::<f32, _>(vec![None, None]);
670        let dtype = chunk1.dtype().clone();
671        let chunked = ChunkedArray::try_new(vec![chunk1.into_array(), chunk2.into_array()], dtype)?;
672        let result = sum(
673            &chunked.into_array(),
674            &mut LEGACY_SESSION.create_execution_ctx(),
675        )?;
676        assert_eq!(result, Scalar::primitive(0f64, Nullable));
677        Ok(())
678    }
679
680    #[test]
681    fn sum_chunked_floats_empty_chunks() -> VortexResult<()> {
682        let chunk1 = PrimitiveArray::from_option_iter(vec![Some(10.5f64), Some(20.3)]);
683        let chunk2 = ConstantArray::new(Scalar::primitive(0f64, Nullable), 0);
684        let chunk3 = PrimitiveArray::from_option_iter(vec![Some(5.2f64)]);
685        let dtype = chunk1.dtype().clone();
686        let chunked = ChunkedArray::try_new(
687            vec![
688                chunk1.into_array(),
689                chunk2.into_array(),
690                chunk3.into_array(),
691            ],
692            dtype,
693        )?;
694
695        let result = sum(
696            &chunked.into_array(),
697            &mut LEGACY_SESSION.create_execution_ctx(),
698        )?;
699        assert_eq!(result.as_primitive().as_::<f64>(), Some(36.0));
700        Ok(())
701    }
702
703    #[test]
704    fn sum_chunked_int_almost_all_null() -> VortexResult<()> {
705        let chunk1 = PrimitiveArray::from_option_iter::<u32, _>(vec![Some(1)]);
706        let chunk2 = PrimitiveArray::from_option_iter::<u32, _>(vec![None]);
707        let dtype = chunk1.dtype().clone();
708        let chunked = ChunkedArray::try_new(vec![chunk1.into_array(), chunk2.into_array()], dtype)?;
709
710        let result = sum(
711            &chunked.into_array(),
712            &mut LEGACY_SESSION.create_execution_ctx(),
713        )?;
714        assert_eq!(result.as_primitive().as_::<u64>(), Some(1));
715        Ok(())
716    }
717
718    #[test]
719    fn sum_chunked_decimals() -> VortexResult<()> {
720        let decimal_dtype = DecimalDType::new(10, 2);
721        let chunk1 = DecimalArray::new(
722            buffer![100i32, 100i32, 100i32, 100i32, 100i32],
723            decimal_dtype,
724            Validity::AllValid,
725        );
726        let chunk2 = DecimalArray::new(
727            buffer![200i32, 200i32, 200i32],
728            decimal_dtype,
729            Validity::AllValid,
730        );
731        let chunk3 = DecimalArray::new(buffer![300i32, 300i32], decimal_dtype, Validity::AllValid);
732        let dtype = chunk1.dtype().clone();
733        let chunked = ChunkedArray::try_new(
734            vec![
735                chunk1.into_array(),
736                chunk2.into_array(),
737                chunk3.into_array(),
738            ],
739            dtype,
740        )?;
741
742        let result = sum(
743            &chunked.into_array(),
744            &mut LEGACY_SESSION.create_execution_ctx(),
745        )?;
746        let decimal_result = result.as_decimal();
747        assert_eq!(
748            decimal_result.decimal_value(),
749            Some(DecimalValue::I256(i256::from_i128(1700)))
750        );
751        Ok(())
752    }
753
754    #[test]
755    fn sum_chunked_decimals_with_nulls() -> VortexResult<()> {
756        let decimal_dtype = DecimalDType::new(10, 2);
757        let chunk1 = DecimalArray::new(
758            buffer![100i32, 100i32, 100i32],
759            decimal_dtype,
760            Validity::AllValid,
761        );
762        let chunk2 = DecimalArray::new(
763            buffer![0i32, 0i32],
764            decimal_dtype,
765            Validity::from_iter([false, false]),
766        );
767        let chunk3 = DecimalArray::new(buffer![200i32, 200i32], decimal_dtype, Validity::AllValid);
768        let dtype = chunk1.dtype().clone();
769        let chunked = ChunkedArray::try_new(
770            vec![
771                chunk1.into_array(),
772                chunk2.into_array(),
773                chunk3.into_array(),
774            ],
775            dtype,
776        )?;
777
778        let result = sum(
779            &chunked.into_array(),
780            &mut LEGACY_SESSION.create_execution_ctx(),
781        )?;
782        let decimal_result = result.as_decimal();
783        assert_eq!(
784            decimal_result.decimal_value(),
785            Some(DecimalValue::I256(i256::from_i128(700)))
786        );
787        Ok(())
788    }
789
790    #[test]
791    fn sum_chunked_decimals_large() -> VortexResult<()> {
792        let decimal_dtype = DecimalDType::new(3, 0);
793        let chunk1 = ConstantArray::new(
794            Scalar::decimal(
795                DecimalValue::I16(500),
796                decimal_dtype,
797                Nullability::NonNullable,
798            ),
799            1,
800        );
801        let chunk2 = ConstantArray::new(
802            Scalar::decimal(
803                DecimalValue::I16(600),
804                decimal_dtype,
805                Nullability::NonNullable,
806            ),
807            1,
808        );
809        let dtype = chunk1.dtype().clone();
810        let chunked = ChunkedArray::try_new(vec![chunk1.into_array(), chunk2.into_array()], dtype)?;
811
812        let result = sum(
813            &chunked.into_array(),
814            &mut LEGACY_SESSION.create_execution_ctx(),
815        )?;
816        let decimal_result = result.as_decimal();
817        assert_eq!(
818            decimal_result.decimal_value(),
819            Some(DecimalValue::I256(i256::from_i128(1100)))
820        );
821        assert_eq!(
822            result.dtype(),
823            &DType::Decimal(DecimalDType::new(13, 0), Nullable)
824        );
825        Ok(())
826    }
827}