Skip to main content

vortex_array/aggregate_fn/fns/sum/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod bool;
5mod constant;
6mod decimal;
7mod grouped;
8mod primitive;
9pub(crate) use grouped::PrimitiveGroupedSumEncodingKernel;
10use vortex_error::VortexExpect;
11use vortex_error::VortexResult;
12use vortex_error::vortex_bail;
13use vortex_error::vortex_err;
14use vortex_error::vortex_panic;
15use vortex_session::VortexSession;
16use vortex_session::registry::CachedId;
17
18use self::bool::accumulate_bool;
19use self::constant::multiply_constant;
20use self::decimal::accumulate_decimal;
21use self::primitive::accumulate_primitive;
22use crate::ArrayRef;
23use crate::Canonical;
24use crate::Columnar;
25use crate::ExecutionCtx;
26use crate::aggregate_fn::Accumulator;
27use crate::aggregate_fn::AggregateFnId;
28use crate::aggregate_fn::AggregateFnVTable;
29use crate::aggregate_fn::DynAccumulator;
30use crate::aggregate_fn::NumericalAggregateOpts;
31use crate::dtype::DType;
32use crate::dtype::DecimalDType;
33use crate::dtype::MAX_PRECISION;
34use crate::dtype::Nullability;
35use crate::dtype::PType;
36use crate::expr::stats::Precision;
37use crate::expr::stats::Stat;
38use crate::expr::stats::StatsProvider;
39use crate::expr::stats::StatsProviderExt;
40use crate::scalar::DecimalValue;
41use crate::scalar::Scalar;
42
43/// Return the sum of an array.
44///
45/// See [`Sum`] for details.
46pub fn sum(array: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<Scalar> {
47    // Short-circuit using cached array statistics.
48    if let Precision::Exact(sum_scalar) = array.statistics().get(Stat::Sum) {
49        return Ok(sum_scalar);
50    }
51
52    // Compute using Accumulator<Sum>.
53    // TODO(ngates): we may want to wrap this three-step dance up into an extension crate maybe.
54    let mut acc = Accumulator::try_new(
55        Sum,
56        NumericalAggregateOpts::default(),
57        array.dtype().clone(),
58    )?;
59    acc.accumulate(array, ctx)?;
60    let result = acc.finish()?;
61
62    // Cache the computed sum as a statistic (only if non-null, i.e. no overflow).
63    if let Some(val) = result.value().cloned() {
64        array.statistics().set(Stat::Sum, Precision::Exact(val));
65    }
66
67    Ok(result)
68}
69
/// Aggregate function that adds up the valid values of an array, starting from zero.
///
/// Supported inputs and their (always nullable) result dtypes:
/// - booleans (counting `true` values) and unsigned integers produce `u64`,
/// - signed integers produce `i64`,
/// - floats produce `f64`,
/// - decimals produce a decimal with ten extra digits of precision (capped at the maximum).
///
/// An overflowing sum yields a null scalar, while an array with no valid values sums to zero.
///
/// Float NaNs follow [`NumericalAggregateOpts`]: with `skip_nans` (the default) they are ignored;
/// otherwise a single NaN makes the whole sum NaN.
#[derive(Debug, Clone)]
pub struct Sum;
79
80impl AggregateFnVTable for Sum {
81    type Options = NumericalAggregateOpts;
82    type Partial = SumPartial;
83
84    fn id(&self) -> AggregateFnId {
85        static ID: CachedId = CachedId::new("vortex.sum");
86        *ID
87    }
88
89    fn serialize(&self, options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
90        Ok(Some(options.serialize()))
91    }
92
93    fn deserialize(
94        &self,
95        metadata: &[u8],
96        _session: &VortexSession,
97    ) -> VortexResult<Self::Options> {
98        NumericalAggregateOpts::deserialize(metadata)
99    }
100
101    fn return_dtype(&self, _options: &Self::Options, input_dtype: &DType) -> Option<DType> {
102        // When a sum overflows, we return a sum _value_ of null. Therefore, we all return dtypes
103        // are nullable.
104        use Nullability::Nullable;
105
106        Some(match input_dtype {
107            DType::Bool(_) => DType::Primitive(PType::U64, Nullable),
108            DType::Primitive(ptype, _) => match ptype {
109                PType::U8 | PType::U16 | PType::U32 | PType::U64 => {
110                    DType::Primitive(PType::U64, Nullable)
111                }
112                PType::I8 | PType::I16 | PType::I32 | PType::I64 => {
113                    DType::Primitive(PType::I64, Nullable)
114                }
115                PType::F16 | PType::F32 | PType::F64 => {
116                    // Float sums cannot overflow, but all null floats still end up as null
117                    DType::Primitive(PType::F64, Nullable)
118                }
119            },
120            DType::Decimal(decimal_dtype, _) => {
121                // Both Spark and DataFusion use this heuristic.
122                // - https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
123                // - https://github.com/apache/datafusion/blob/4153adf2c0f6e317ef476febfdc834208bd46622/datafusion/functions-aggregate/src/sum.rs#L188
124                let precision = u8::min(MAX_PRECISION, decimal_dtype.precision() + 10);
125                DType::Decimal(
126                    DecimalDType::new(precision, decimal_dtype.scale()),
127                    Nullable,
128                )
129            }
130            // Unsupported types
131            _ => return None,
132        })
133    }
134
135    fn partial_dtype(&self, options: &Self::Options, input_dtype: &DType) -> Option<DType> {
136        self.return_dtype(options, input_dtype)
137    }
138
139    fn empty_partial(
140        &self,
141        options: &Self::Options,
142        input_dtype: &DType,
143    ) -> VortexResult<Self::Partial> {
144        let return_dtype = self
145            .return_dtype(options, input_dtype)
146            .ok_or_else(|| vortex_err!("Unsupported sum dtype: {}", input_dtype))?;
147        let initial = make_zero_state(&return_dtype);
148
149        Ok(SumPartial {
150            return_dtype,
151            current: Some(initial),
152            skip_nans: options.skip_nans,
153        })
154    }
155
156    fn combine_partials(&self, partial: &mut Self::Partial, other: Scalar) -> VortexResult<()> {
157        if other.is_null() {
158            // A null partial means the sub-accumulator saturated (overflow).
159            partial.current = None;
160            return Ok(());
161        }
162        let Some(ref mut inner) = partial.current else {
163            return Ok(());
164        };
165        let saturated = match inner {
166            SumState::Unsigned(acc) => {
167                let val = other
168                    .as_primitive()
169                    .typed_value::<u64>()
170                    .vortex_expect("checked non-null");
171                checked_add_u64(acc, val)
172            }
173            SumState::Signed(acc) => {
174                let val = other
175                    .as_primitive()
176                    .typed_value::<i64>()
177                    .vortex_expect("checked non-null");
178                checked_add_i64(acc, val)
179            }
180            SumState::Float(acc) => {
181                let val = other
182                    .as_primitive()
183                    .typed_value::<f64>()
184                    .vortex_expect("checked non-null");
185                *acc += val;
186                false
187            }
188            SumState::Decimal { value, dtype } => {
189                let val = other
190                    .as_decimal()
191                    .decimal_value()
192                    .vortex_expect("checked non-null");
193                match value.checked_add(&val) {
194                    Some(r) => {
195                        *value = r;
196                        !value.fits_in_precision(*dtype)
197                    }
198                    None => true,
199                }
200            }
201        };
202        if saturated {
203            partial.current = None;
204        }
205        Ok(())
206    }
207
208    fn to_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
209        Ok(match &partial.current {
210            None => Scalar::null(partial.return_dtype.as_nullable()),
211            Some(SumState::Unsigned(v)) => Scalar::primitive(*v, Nullability::Nullable),
212            Some(SumState::Signed(v)) => Scalar::primitive(*v, Nullability::Nullable),
213            Some(SumState::Float(v)) => Scalar::primitive(*v, Nullability::Nullable),
214            Some(SumState::Decimal { value, .. }) => {
215                let decimal_dtype = *partial
216                    .return_dtype
217                    .as_decimal_opt()
218                    .vortex_expect("return dtype must be decimal");
219                Scalar::decimal(*value, decimal_dtype, Nullability::Nullable)
220            }
221        })
222    }
223
224    fn reset(&self, partial: &mut Self::Partial) {
225        partial.current = Some(make_zero_state(&partial.return_dtype));
226    }
227
228    #[inline]
229    fn is_saturated(&self, partial: &Self::Partial) -> bool {
230        match partial.current.as_ref() {
231            None => true,
232            Some(SumState::Float(v)) => v.is_nan(),
233            Some(_) => false,
234        }
235    }
236
237    fn try_accumulate(
238        &self,
239        partial: &mut Self::Partial,
240        batch: &ArrayRef,
241        _ctx: &mut ExecutionCtx,
242    ) -> VortexResult<bool> {
243        // NaN-aware shortcircuits only apply to NaN-including float sums; everything else takes
244        // the default dispatch path.
245        if partial.skip_nans || !matches!(partial.current, Some(SumState::Float(_))) {
246            return Ok(false);
247        }
248        match batch.statistics().get_as::<u64>(Stat::NaNCount) {
249            Precision::Exact(0) => {
250                // NaN-free batch: the cached NaN-skipping sum (if any) equals the
251                // NaN-including sum.
252                if let Precision::Exact(sum) = batch.statistics().get(Stat::Sum) {
253                    let sum = if sum.dtype() == &partial.return_dtype {
254                        sum
255                    } else {
256                        sum.cast(&partial.return_dtype)?
257                    };
258                    self.combine_partials(partial, sum)?;
259                    return Ok(true);
260                }
261                Ok(false)
262            }
263            Precision::Exact(_) => {
264                // At least one NaN value: the sum is NaN without scanning the batch.
265                if let Some(SumState::Float(acc)) = partial.current.as_mut() {
266                    *acc = f64::NAN;
267                }
268                Ok(true)
269            }
270            _ => Ok(false),
271        }
272    }
273
274    fn accumulate(
275        &self,
276        partial: &mut Self::Partial,
277        batch: &Columnar,
278        ctx: &mut ExecutionCtx,
279    ) -> VortexResult<()> {
280        // Constants compute scalar * len and combine via combine_partials.
281        if let Columnar::Constant(c) = batch {
282            // NaN constants are treated as missing when skipping NaNs.
283            if partial.skip_nans && c.scalar().as_primitive_opt().is_some_and(|p| p.is_nan()) {
284                return Ok(());
285            }
286            if let Some(product) = multiply_constant(c.scalar(), c.len(), &partial.return_dtype)? {
287                self.combine_partials(partial, product)?;
288            }
289            return Ok(());
290        }
291
292        let skip_nans = partial.skip_nans;
293        let mut inner = match partial.current.take() {
294            Some(inner) => inner,
295            None => return Ok(()),
296        };
297
298        let result = match batch {
299            Columnar::Canonical(c) => match c {
300                Canonical::Primitive(p) => accumulate_primitive(&mut inner, p, ctx, skip_nans),
301                Canonical::Bool(b) => accumulate_bool(&mut inner, b, ctx),
302                Canonical::Decimal(d) => accumulate_decimal(&mut inner, d, ctx),
303                _ => vortex_bail!("Unsupported canonical type for sum: {}", batch.dtype()),
304            },
305            Columnar::Constant(_) => unreachable!(),
306        };
307
308        match result {
309            Ok(false) => partial.current = Some(inner),
310            Ok(true) => {} // saturated: current stays None
311            Err(e) => {
312                partial.current = Some(inner);
313                return Err(e);
314            }
315        }
316        Ok(())
317    }
318
319    fn finalize(&self, partials: ArrayRef) -> VortexResult<ArrayRef> {
320        Ok(partials)
321    }
322
323    fn finalize_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
324        self.to_scalar(partial)
325    }
326}
327
328/// The group state for a sum aggregate, containing the accumulated value and configuration
329/// needed for reset/result without external context.
330pub struct SumPartial {
331    return_dtype: DType,
332    /// The current accumulated state, or `None` if saturated (checked overflow).
333    current: Option<SumState>,
334    /// Whether NaN values in float inputs are skipped.
335    skip_nans: bool,
336}
337
338/// The accumulated sum value.
339// TODO(ngates): instead of an enum, we should use a Box<dyn State> to avoid dispatcher over the
340//  input type every time? Perhaps?
341pub enum SumState {
342    Unsigned(u64),
343    Signed(i64),
344    Float(f64),
345    Decimal {
346        value: DecimalValue,
347        dtype: DecimalDType,
348    },
349}
350
351fn make_zero_state(return_dtype: &DType) -> SumState {
352    match return_dtype {
353        DType::Primitive(ptype, _) => match ptype {
354            PType::U8 | PType::U16 | PType::U32 | PType::U64 => SumState::Unsigned(0),
355            PType::I8 | PType::I16 | PType::I32 | PType::I64 => SumState::Signed(0),
356            PType::F16 | PType::F32 | PType::F64 => SumState::Float(0.0),
357        },
358        DType::Decimal(decimal, _) => SumState::Decimal {
359            value: DecimalValue::zero(decimal),
360            dtype: *decimal,
361        },
362        _ => vortex_panic!("Unsupported sum type"),
363    }
364}
365
/// Adds `val` into `acc` with overflow checking.
///
/// Returns `true` if the addition overflowed, in which case `acc` is left unchanged; otherwise
/// stores the sum in `acc` and returns `false`.
#[inline(always)]
fn checked_add_u64(acc: &mut u64, val: u64) -> bool {
    let Some(total) = acc.checked_add(val) else {
        return true;
    };
    *acc = total;
    false
}
377
/// Adds `val` into `acc` with overflow checking in either direction.
///
/// Returns `true` if the addition overflowed or underflowed, in which case `acc` is left
/// unchanged; otherwise stores the sum in `acc` and returns `false`.
#[inline(always)]
fn checked_add_i64(acc: &mut i64, val: i64) -> bool {
    let Some(total) = acc.checked_add(val) else {
        return true;
    };
    *acc = total;
    false
}
389
390#[cfg(test)]
391mod tests {
392    use num_traits::CheckedAdd;
393    use vortex_buffer::buffer;
394    use vortex_error::VortexExpect;
395    use vortex_error::VortexResult;
396
397    use crate::ArrayRef;
398    use crate::IntoArray;
399    use crate::VortexSessionExecute;
400    use crate::aggregate_fn::Accumulator;
401    use crate::aggregate_fn::AggregateFnVTable;
402    use crate::aggregate_fn::DynAccumulator;
403    use crate::aggregate_fn::DynGroupedAccumulator;
404    use crate::aggregate_fn::GroupedAccumulator;
405    use crate::aggregate_fn::NumericalAggregateOpts;
406    use crate::aggregate_fn::fns::sum::Sum;
407    use crate::aggregate_fn::fns::sum::sum;
408    use crate::array_session;
409    use crate::arrays::BoolArray;
410    use crate::arrays::ChunkedArray;
411    use crate::arrays::ConstantArray;
412    use crate::arrays::DecimalArray;
413    use crate::arrays::FixedSizeListArray;
414    use crate::arrays::ListViewArray;
415    use crate::arrays::PrimitiveArray;
416    use crate::assert_arrays_eq;
417    use crate::dtype::DType;
418    use crate::dtype::DecimalDType;
419    use crate::dtype::Nullability;
420    use crate::dtype::Nullability::Nullable;
421    use crate::dtype::PType;
422    use crate::dtype::i256;
423    use crate::expr::stats::Precision;
424    use crate::expr::stats::Stat;
425    use crate::expr::stats::StatsProvider;
426    use crate::scalar::DecimalValue;
427    use crate::scalar::NumericOperator;
428    use crate::scalar::Scalar;
429    use crate::validity::Validity;
430
431    /// Sum an array with an initial value (test-only helper).
432    fn sum_with_accumulator(array: &ArrayRef, accumulator: &Scalar) -> VortexResult<Scalar> {
433        let mut ctx = array_session().create_execution_ctx();
434        if accumulator.is_null() {
435            return Ok(accumulator.clone());
436        }
437        if accumulator.is_zero() == Some(true) {
438            return sum(array, &mut ctx);
439        }
440
441        let sum_dtype = Stat::Sum.dtype(array.dtype()).ok_or_else(|| {
442            vortex_error::vortex_err!("Sum not supported for dtype: {}", array.dtype())
443        })?;
444
445        // For non-float types, try statistics short-circuit with accumulator.
446        if !matches!(&sum_dtype, DType::Primitive(p, _) if p.is_float())
447            && let Precision::Exact(sum_scalar) = array.statistics().get(Stat::Sum)
448        {
449            return add_scalars(&sum_dtype, &sum_scalar, accumulator);
450        }
451
452        // Compute array sum from zero (also caches stats).
453        let array_sum = sum(array, &mut ctx)?;
454
455        // Combine with the accumulator.
456        add_scalars(&sum_dtype, &array_sum, accumulator)
457    }
458
459    /// Add two sum scalars with overflow checking.
460    fn add_scalars(sum_dtype: &DType, lhs: &Scalar, rhs: &Scalar) -> VortexResult<Scalar> {
461        if lhs.is_null() || rhs.is_null() {
462            return Ok(Scalar::null(sum_dtype.as_nullable()));
463        }
464
465        Ok(match sum_dtype {
466            DType::Primitive(ptype, _) if ptype.is_float() => {
467                let lhs_val = f64::try_from(lhs)?;
468                let rhs_val = f64::try_from(rhs)?;
469                Scalar::primitive(lhs_val + rhs_val, Nullable)
470            }
471            DType::Primitive(..) => lhs
472                .as_primitive()
473                .checked_add(&rhs.as_primitive())
474                .map(Scalar::from)
475                .unwrap_or_else(|| Scalar::null(sum_dtype.as_nullable())),
476            DType::Decimal(..) => lhs
477                .as_decimal()
478                .checked_binary_numeric(&rhs.as_decimal(), NumericOperator::Add)
479                .map(Scalar::from)
480                .unwrap_or_else(|| Scalar::null(sum_dtype.as_nullable())),
481            _ => unreachable!("Sum will always be a decimal or a primitive dtype"),
482        })
483    }
484
485    // Multi-batch and reset tests
486
487    #[test]
488    fn sum_multi_batch() -> VortexResult<()> {
489        let mut ctx = array_session().create_execution_ctx();
490        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
491        let mut acc = Accumulator::try_new(Sum, NumericalAggregateOpts::default(), dtype)?;
492
493        let batch1 = PrimitiveArray::new(buffer![10i32, 20], Validity::NonNullable).into_array();
494        acc.accumulate(&batch1, &mut ctx)?;
495
496        let batch2 = PrimitiveArray::new(buffer![3i32, 6, 9], Validity::NonNullable).into_array();
497        acc.accumulate(&batch2, &mut ctx)?;
498
499        let result = acc.finish()?;
500        assert_eq!(result.as_primitive().typed_value::<i64>(), Some(48));
501        Ok(())
502    }
503
504    #[test]
505    fn sum_finish_resets_state() -> VortexResult<()> {
506        let mut ctx = array_session().create_execution_ctx();
507        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
508        let mut acc = Accumulator::try_new(Sum, NumericalAggregateOpts::default(), dtype)?;
509
510        let batch1 = PrimitiveArray::new(buffer![10i32, 20], Validity::NonNullable).into_array();
511        acc.accumulate(&batch1, &mut ctx)?;
512        let result1 = acc.finish()?;
513        assert_eq!(result1.as_primitive().typed_value::<i64>(), Some(30));
514
515        let batch2 = PrimitiveArray::new(buffer![3i32, 6, 9], Validity::NonNullable).into_array();
516        acc.accumulate(&batch2, &mut ctx)?;
517        let result2 = acc.finish()?;
518        assert_eq!(result2.as_primitive().typed_value::<i64>(), Some(18));
519        Ok(())
520    }
521
522    // State merge tests (vtable-level)
523
524    #[test]
525    fn sum_state_merge() -> VortexResult<()> {
526        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
527        let mut state = Sum.empty_partial(&NumericalAggregateOpts::default(), &dtype)?;
528
529        let scalar1 = Scalar::primitive(100i64, Nullable);
530        Sum.combine_partials(&mut state, scalar1)?;
531
532        let scalar2 = Scalar::primitive(50i64, Nullable);
533        Sum.combine_partials(&mut state, scalar2)?;
534
535        let result = Sum.to_scalar(&state)?;
536        Sum.reset(&mut state);
537        assert_eq!(result.as_primitive().typed_value::<i64>(), Some(150));
538        Ok(())
539    }
540
541    // Stats caching test
542
543    #[test]
544    fn sum_stats() -> VortexResult<()> {
545        let array = ChunkedArray::try_new(
546            vec![
547                PrimitiveArray::from_iter([1, 1, 1]).into_array(),
548                PrimitiveArray::from_iter([2, 2, 2]).into_array(),
549            ],
550            DType::Primitive(PType::I32, Nullability::NonNullable),
551        )
552        .vortex_expect("operation should succeed in test");
553        let array = array.into_array();
554        // compute sum with accumulator to populate stats
555        sum_with_accumulator(&array, &Scalar::primitive(2i64, Nullable))?;
556
557        let sum_without_acc = sum(&array, &mut array_session().create_execution_ctx())?;
558        assert_eq!(sum_without_acc, Scalar::primitive(9i64, Nullable));
559        Ok(())
560    }
561
562    // Constant float non-multiply test
563
564    #[test]
565    fn sum_constant_float_non_multiply() -> VortexResult<()> {
566        let acc = -2048669276050936500000000000f64;
567        let array = ConstantArray::new(6.1811675e16f64, 25);
568        let result = sum_with_accumulator(&array.into_array(), &Scalar::primitive(acc, Nullable))
569            .vortex_expect("operation should succeed in test");
570        assert_eq!(
571            f64::try_from(&result).vortex_expect("operation should succeed in test"),
572            -2048669274505644600000000000f64
573        );
574        Ok(())
575    }
576
577    // Grouped sum tests
578
579    fn run_grouped_sum(groups: &ArrayRef, elem_dtype: &DType) -> VortexResult<ArrayRef> {
580        let mut acc = GroupedAccumulator::try_new(
581            Sum,
582            NumericalAggregateOpts::default(),
583            elem_dtype.clone(),
584        )?;
585        acc.accumulate_list(groups, &mut array_session().create_execution_ctx())?;
586        acc.finish()
587    }
588
589    #[test]
590    fn grouped_sum_fixed_size_list() -> VortexResult<()> {
591        let mut ctx = array_session().create_execution_ctx();
592        let elements =
593            PrimitiveArray::new(buffer![1i32, 2, 3, 4, 5, 6], Validity::NonNullable).into_array();
594        let groups = FixedSizeListArray::try_new(elements, 3, Validity::NonNullable, 2)?;
595
596        let elem_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
597        let result = run_grouped_sum(&groups.into_array(), &elem_dtype)?;
598
599        let expected = PrimitiveArray::from_option_iter([Some(6i64), Some(15i64)]).into_array();
600        assert_arrays_eq!(&result, &expected, &mut ctx);
601        Ok(())
602    }
603
604    #[test]
605    fn grouped_sum_with_null_elements() -> VortexResult<()> {
606        let mut ctx = array_session().create_execution_ctx();
607        let elements =
608            PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), None, Some(5), Some(6)])
609                .into_array();
610        let groups = FixedSizeListArray::try_new(elements, 3, Validity::NonNullable, 2)?;
611
612        let elem_dtype = DType::Primitive(PType::I32, Nullable);
613        let result = run_grouped_sum(&groups.into_array(), &elem_dtype)?;
614
615        let expected = PrimitiveArray::from_option_iter([Some(4i64), Some(11i64)]).into_array();
616        assert_arrays_eq!(&result, &expected, &mut ctx);
617        Ok(())
618    }
619
620    #[test]
621    fn grouped_sum_with_null_group() -> VortexResult<()> {
622        let mut ctx = array_session().create_execution_ctx();
623        let elements =
624            PrimitiveArray::new(buffer![1i32, 2, 3, 4, 5, 6, 7, 8, 9], Validity::NonNullable)
625                .into_array();
626        let validity = Validity::from_iter([true, false, true]);
627        let groups = FixedSizeListArray::try_new(elements, 3, validity, 3)?;
628
629        let elem_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
630        let result = run_grouped_sum(&groups.into_array(), &elem_dtype)?;
631
632        let expected =
633            PrimitiveArray::from_option_iter([Some(6i64), None, Some(24i64)]).into_array();
634        assert_arrays_eq!(&result, &expected, &mut ctx);
635        Ok(())
636    }
637
638    #[test]
639    fn grouped_sum_all_null_elements_in_group() -> VortexResult<()> {
640        let mut ctx = array_session().create_execution_ctx();
641        let elements =
642            PrimitiveArray::from_option_iter([None::<i32>, None, Some(3), Some(4)]).into_array();
643        let groups = FixedSizeListArray::try_new(elements, 2, Validity::NonNullable, 2)?;
644
645        let elem_dtype = DType::Primitive(PType::I32, Nullable);
646        let result = run_grouped_sum(&groups.into_array(), &elem_dtype)?;
647
648        let expected = PrimitiveArray::from_option_iter([Some(0i64), Some(7i64)]).into_array();
649        assert_arrays_eq!(&result, &expected, &mut ctx);
650        Ok(())
651    }
652
653    #[test]
654    fn grouped_sum_bool() -> VortexResult<()> {
655        let mut ctx = array_session().create_execution_ctx();
656        let elements: BoolArray = [true, false, true, true, true, true].into_iter().collect();
657        let groups =
658            FixedSizeListArray::try_new(elements.into_array(), 3, Validity::NonNullable, 2)?;
659
660        let elem_dtype = DType::Bool(Nullability::NonNullable);
661        let result = run_grouped_sum(&groups.into_array(), &elem_dtype)?;
662
663        let expected = PrimitiveArray::from_option_iter([Some(2u64), Some(3u64)]).into_array();
664        assert_arrays_eq!(&result, &expected, &mut ctx);
665        Ok(())
666    }
667
668    #[test]
669    fn grouped_sum_finish_resets() -> VortexResult<()> {
670        let mut ctx = array_session().create_execution_ctx();
671        let elem_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
672        let mut acc =
673            GroupedAccumulator::try_new(Sum, NumericalAggregateOpts::default(), elem_dtype)?;
674
675        let elements1 =
676            PrimitiveArray::new(buffer![1i32, 2, 3, 4], Validity::NonNullable).into_array();
677        let groups1 = FixedSizeListArray::try_new(elements1, 2, Validity::NonNullable, 2)?;
678        acc.accumulate_list(&groups1.into_array(), &mut ctx)?;
679        let result1 = acc.finish()?;
680
681        let expected1 = PrimitiveArray::from_option_iter([Some(3i64), Some(7i64)]).into_array();
682        assert_arrays_eq!(&result1, &expected1, &mut ctx);
683
684        let elements2 = PrimitiveArray::new(buffer![10i32, 20], Validity::NonNullable).into_array();
685        let groups2 = FixedSizeListArray::try_new(elements2, 2, Validity::NonNullable, 1)?;
686        acc.accumulate_list(&groups2.into_array(), &mut ctx)?;
687        let result2 = acc.finish()?;
688
689        let expected2 = PrimitiveArray::from_option_iter([Some(30i64)]).into_array();
690        assert_arrays_eq!(&result2, &expected2, &mut ctx);
691        Ok(())
692    }
693
694    #[test]
695    fn grouped_sum_listview_out_of_order_offsets_with_null_group() -> VortexResult<()> {
696        let mut ctx = array_session().create_execution_ctx();
697        let elements =
698            PrimitiveArray::new(buffer![100i32, 200, 300], Validity::NonNullable).into_array();
699        let offsets = PrimitiveArray::new(buffer![2i32, 0, 1], Validity::NonNullable).into_array();
700        let sizes = PrimitiveArray::new(buffer![1i32, 1, 1], Validity::NonNullable).into_array();
701        let validity = Validity::from_iter([true, false, true]);
702        let groups = ListViewArray::try_new(elements, offsets, sizes, validity)?.into_array();
703
704        let elem_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
705        let result = run_grouped_sum(&groups, &elem_dtype)?;
706
707        // group 0 -> elements[2..3] = 300; group 1 -> null; group 2 -> elements[1..2] = 200.
708        let expected =
709            PrimitiveArray::from_option_iter([Some(300i64), None, Some(200i64)]).into_array();
710        assert_arrays_eq!(&result, &expected, &mut ctx);
711        Ok(())
712    }
713
714    // Chunked array tests
715
716    #[test]
717    fn sum_chunked_floats_with_nulls() -> VortexResult<()> {
718        let chunk1 =
719            PrimitiveArray::from_option_iter(vec![Some(1.5f64), None, Some(3.2), Some(4.8)]);
720        let chunk2 = PrimitiveArray::from_option_iter(vec![Some(2.1f64), Some(5.7), None]);
721        let chunk3 = PrimitiveArray::from_option_iter(vec![None, Some(1.0f64), Some(2.5), None]);
722        let dtype = chunk1.dtype().clone();
723        let chunked = ChunkedArray::try_new(
724            vec![
725                chunk1.into_array(),
726                chunk2.into_array(),
727                chunk3.into_array(),
728            ],
729            dtype,
730        )?;
731
732        let result = sum(
733            &chunked.into_array(),
734            &mut array_session().create_execution_ctx(),
735        )?;
736        assert_eq!(result.as_primitive().as_::<f64>(), Some(20.8));
737        Ok(())
738    }
739
740    #[test]
741    fn sum_chunked_floats_all_nulls_is_zero() -> VortexResult<()> {
742        let chunk1 = PrimitiveArray::from_option_iter::<f32, _>(vec![None, None, None]);
743        let chunk2 = PrimitiveArray::from_option_iter::<f32, _>(vec![None, None]);
744        let dtype = chunk1.dtype().clone();
745        let chunked = ChunkedArray::try_new(vec![chunk1.into_array(), chunk2.into_array()], dtype)?;
746        let result = sum(
747            &chunked.into_array(),
748            &mut array_session().create_execution_ctx(),
749        )?;
750        assert_eq!(result, Scalar::primitive(0f64, Nullable));
751        Ok(())
752    }
753
754    #[test]
755    fn sum_chunked_floats_empty_chunks() -> VortexResult<()> {
756        let chunk1 = PrimitiveArray::from_option_iter(vec![Some(10.5f64), Some(20.3)]);
757        let chunk2 = ConstantArray::new(Scalar::primitive(0f64, Nullable), 0);
758        let chunk3 = PrimitiveArray::from_option_iter(vec![Some(5.2f64)]);
759        let dtype = chunk1.dtype().clone();
760        let chunked = ChunkedArray::try_new(
761            vec![
762                chunk1.into_array(),
763                chunk2.into_array(),
764                chunk3.into_array(),
765            ],
766            dtype,
767        )?;
768
769        let result = sum(
770            &chunked.into_array(),
771            &mut array_session().create_execution_ctx(),
772        )?;
773        assert_eq!(result.as_primitive().as_::<f64>(), Some(36.0));
774        Ok(())
775    }
776
777    #[test]
778    fn sum_chunked_int_almost_all_null() -> VortexResult<()> {
779        let chunk1 = PrimitiveArray::from_option_iter::<u32, _>(vec![Some(1)]);
780        let chunk2 = PrimitiveArray::from_option_iter::<u32, _>(vec![None]);
781        let dtype = chunk1.dtype().clone();
782        let chunked = ChunkedArray::try_new(vec![chunk1.into_array(), chunk2.into_array()], dtype)?;
783
784        let result = sum(
785            &chunked.into_array(),
786            &mut array_session().create_execution_ctx(),
787        )?;
788        assert_eq!(result.as_primitive().as_::<u64>(), Some(1));
789        Ok(())
790    }
791
792    #[test]
793    fn sum_chunked_decimals() -> VortexResult<()> {
794        let decimal_dtype = DecimalDType::new(10, 2);
795        let chunk1 = DecimalArray::new(
796            buffer![100i32, 100i32, 100i32, 100i32, 100i32],
797            decimal_dtype,
798            Validity::AllValid,
799        );
800        let chunk2 = DecimalArray::new(
801            buffer![200i32, 200i32, 200i32],
802            decimal_dtype,
803            Validity::AllValid,
804        );
805        let chunk3 = DecimalArray::new(buffer![300i32, 300i32], decimal_dtype, Validity::AllValid);
806        let dtype = chunk1.dtype().clone();
807        let chunked = ChunkedArray::try_new(
808            vec![
809                chunk1.into_array(),
810                chunk2.into_array(),
811                chunk3.into_array(),
812            ],
813            dtype,
814        )?;
815
816        let result = sum(
817            &chunked.into_array(),
818            &mut array_session().create_execution_ctx(),
819        )?;
820        let decimal_result = result.as_decimal();
821        assert_eq!(
822            decimal_result.decimal_value(),
823            Some(DecimalValue::I256(i256::from_i128(1700)))
824        );
825        Ok(())
826    }
827
828    #[test]
829    fn sum_chunked_decimals_with_nulls() -> VortexResult<()> {
830        let decimal_dtype = DecimalDType::new(10, 2);
831        let chunk1 = DecimalArray::new(
832            buffer![100i32, 100i32, 100i32],
833            decimal_dtype,
834            Validity::AllValid,
835        );
836        let chunk2 = DecimalArray::new(
837            buffer![0i32, 0i32],
838            decimal_dtype,
839            Validity::from_iter([false, false]),
840        );
841        let chunk3 = DecimalArray::new(buffer![200i32, 200i32], decimal_dtype, Validity::AllValid);
842        let dtype = chunk1.dtype().clone();
843        let chunked = ChunkedArray::try_new(
844            vec![
845                chunk1.into_array(),
846                chunk2.into_array(),
847                chunk3.into_array(),
848            ],
849            dtype,
850        )?;
851
852        let result = sum(
853            &chunked.into_array(),
854            &mut array_session().create_execution_ctx(),
855        )?;
856        let decimal_result = result.as_decimal();
857        assert_eq!(
858            decimal_result.decimal_value(),
859            Some(DecimalValue::I256(i256::from_i128(700)))
860        );
861        Ok(())
862    }
863
864    #[test]
865    fn sum_chunked_decimals_large() -> VortexResult<()> {
866        let decimal_dtype = DecimalDType::new(3, 0);
867        let chunk1 = ConstantArray::new(
868            Scalar::decimal(
869                DecimalValue::I16(500),
870                decimal_dtype,
871                Nullability::NonNullable,
872            ),
873            1,
874        );
875        let chunk2 = ConstantArray::new(
876            Scalar::decimal(
877                DecimalValue::I16(600),
878                decimal_dtype,
879                Nullability::NonNullable,
880            ),
881            1,
882        );
883        let dtype = chunk1.dtype().clone();
884        let chunked = ChunkedArray::try_new(vec![chunk1.into_array(), chunk2.into_array()], dtype)?;
885
886        let result = sum(
887            &chunked.into_array(),
888            &mut array_session().create_execution_ctx(),
889        )?;
890        let decimal_result = result.as_decimal();
891        assert_eq!(
892            decimal_result.decimal_value(),
893            Some(DecimalValue::I256(i256::from_i128(1100)))
894        );
895        assert_eq!(
896            result.dtype(),
897            &DType::Decimal(DecimalDType::new(13, 0), Nullable)
898        );
899        Ok(())
900    }
901}