Skip to main content

vortex_array/aggregate_fn/fns/sum/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod bool;
5mod constant;
6mod decimal;
7mod primitive;
8
9use vortex_error::VortexExpect;
10use vortex_error::VortexResult;
11use vortex_error::vortex_bail;
12use vortex_error::vortex_err;
13use vortex_error::vortex_panic;
14
15use self::bool::accumulate_bool;
16use self::constant::multiply_constant;
17use self::decimal::accumulate_decimal;
18use self::primitive::accumulate_primitive;
19use crate::ArrayRef;
20use crate::Canonical;
21use crate::Columnar;
22use crate::ExecutionCtx;
23use crate::aggregate_fn::Accumulator;
24use crate::aggregate_fn::AggregateFnId;
25use crate::aggregate_fn::AggregateFnVTable;
26use crate::aggregate_fn::DynAccumulator;
27use crate::aggregate_fn::EmptyOptions;
28use crate::dtype::DType;
29use crate::dtype::DecimalDType;
30use crate::dtype::MAX_PRECISION;
31use crate::dtype::Nullability;
32use crate::dtype::PType;
33use crate::expr::stats::Precision;
34use crate::expr::stats::Stat;
35use crate::expr::stats::StatsProvider;
36use crate::scalar::DecimalValue;
37use crate::scalar::Scalar;
38
39/// Return the sum of an array.
40///
41/// See [`Sum`] for details.
42pub fn sum(array: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<Scalar> {
43    // Short-circuit using cached array statistics.
44    if let Some(Precision::Exact(sum_scalar)) = array.statistics().get(Stat::Sum) {
45        return Ok(sum_scalar);
46    }
47
48    // Compute using Accumulator<Sum>.
49    // TODO(ngates): we may want to wrap this three-step dance up into an extension crate maybe.
50    let mut acc = Accumulator::try_new(Sum, EmptyOptions, array.dtype().clone())?;
51    acc.accumulate(array, ctx)?;
52    let result = acc.finish()?;
53
54    // Cache the computed sum as a statistic (only if non-null, i.e. no overflow).
55    if let Some(val) = result.value().cloned() {
56        array.statistics().set(Stat::Sum, Precision::Exact(val));
57    }
58
59    Ok(result)
60}
61
/// Aggregate function that sums an array, starting from zero.
///
/// Booleans and unsigned integers are summed as `u64`, signed integers as `i64`, floats as
/// `f64`, and decimals as a decimal with up to 10 extra digits of precision.
///
/// If the sum overflows, a null scalar will be returned.
/// If the array is all-invalid, the sum will be zero.
#[derive(Clone, Debug)]
pub struct Sum;
68
69impl AggregateFnVTable for Sum {
70    type Options = EmptyOptions;
71    type Partial = SumPartial;
72
73    fn id(&self) -> AggregateFnId {
74        AggregateFnId::new_ref("vortex.sum")
75    }
76
77    fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
78        Ok(Some(vec![]))
79    }
80
81    fn deserialize(
82        &self,
83        _metadata: &[u8],
84        _session: &vortex_session::VortexSession,
85    ) -> VortexResult<Self::Options> {
86        Ok(EmptyOptions)
87    }
88
89    fn return_dtype(&self, _options: &Self::Options, input_dtype: &DType) -> Option<DType> {
90        // When a sum overflows, we return a sum _value_ of null. Therefore, we all return dtypes
91        // are nullable.
92        use Nullability::Nullable;
93
94        Some(match input_dtype {
95            DType::Bool(_) => DType::Primitive(PType::U64, Nullable),
96            DType::Primitive(ptype, _) => match ptype {
97                PType::U8 | PType::U16 | PType::U32 | PType::U64 => {
98                    DType::Primitive(PType::U64, Nullable)
99                }
100                PType::I8 | PType::I16 | PType::I32 | PType::I64 => {
101                    DType::Primitive(PType::I64, Nullable)
102                }
103                PType::F16 | PType::F32 | PType::F64 => {
104                    // Float sums cannot overflow, but all null floats still end up as null
105                    DType::Primitive(PType::F64, Nullable)
106                }
107            },
108            DType::Decimal(decimal_dtype, _) => {
109                // Both Spark and DataFusion use this heuristic.
110                // - https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
111                // - https://github.com/apache/datafusion/blob/4153adf2c0f6e317ef476febfdc834208bd46622/datafusion/functions-aggregate/src/sum.rs#L188
112                let precision = u8::min(MAX_PRECISION, decimal_dtype.precision() + 10);
113                DType::Decimal(
114                    DecimalDType::new(precision, decimal_dtype.scale()),
115                    Nullable,
116                )
117            }
118            // Unsupported types
119            _ => return None,
120        })
121    }
122
123    fn partial_dtype(&self, options: &Self::Options, input_dtype: &DType) -> Option<DType> {
124        self.return_dtype(options, input_dtype)
125    }
126
127    fn empty_partial(
128        &self,
129        options: &Self::Options,
130        input_dtype: &DType,
131    ) -> VortexResult<Self::Partial> {
132        let return_dtype = self
133            .return_dtype(options, input_dtype)
134            .ok_or_else(|| vortex_err!("Unsupported sum dtype: {}", input_dtype))?;
135        let initial = make_zero_state(&return_dtype);
136
137        Ok(SumPartial {
138            return_dtype,
139            current: Some(initial),
140        })
141    }
142
143    fn combine_partials(&self, partial: &mut Self::Partial, other: Scalar) -> VortexResult<()> {
144        if other.is_null() {
145            // A null partial means the sub-accumulator saturated (overflow).
146            partial.current = None;
147            return Ok(());
148        }
149        let Some(ref mut inner) = partial.current else {
150            return Ok(());
151        };
152        let saturated = match inner {
153            SumState::Unsigned(acc) => {
154                let val = other
155                    .as_primitive()
156                    .typed_value::<u64>()
157                    .vortex_expect("checked non-null");
158                checked_add_u64(acc, val)
159            }
160            SumState::Signed(acc) => {
161                let val = other
162                    .as_primitive()
163                    .typed_value::<i64>()
164                    .vortex_expect("checked non-null");
165                checked_add_i64(acc, val)
166            }
167            SumState::Float(acc) => {
168                let val = other
169                    .as_primitive()
170                    .typed_value::<f64>()
171                    .vortex_expect("checked non-null");
172                *acc += val;
173                false
174            }
175            SumState::Decimal { value, dtype } => {
176                let val = other
177                    .as_decimal()
178                    .decimal_value()
179                    .vortex_expect("checked non-null");
180                match value.checked_add(&val) {
181                    Some(r) => {
182                        *value = r;
183                        !value.fits_in_precision(*dtype)
184                    }
185                    None => true,
186                }
187            }
188        };
189        if saturated {
190            partial.current = None;
191        }
192        Ok(())
193    }
194
195    fn to_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
196        Ok(match &partial.current {
197            None => Scalar::null(partial.return_dtype.as_nullable()),
198            Some(SumState::Unsigned(v)) => Scalar::primitive(*v, Nullability::Nullable),
199            Some(SumState::Signed(v)) => Scalar::primitive(*v, Nullability::Nullable),
200            Some(SumState::Float(v)) => Scalar::primitive(*v, Nullability::Nullable),
201            Some(SumState::Decimal { value, .. }) => {
202                let decimal_dtype = *partial
203                    .return_dtype
204                    .as_decimal_opt()
205                    .vortex_expect("return dtype must be decimal");
206                Scalar::decimal(*value, decimal_dtype, Nullability::Nullable)
207            }
208        })
209    }
210
211    fn reset(&self, partial: &mut Self::Partial) {
212        partial.current = Some(make_zero_state(&partial.return_dtype));
213    }
214
215    #[inline]
216    fn is_saturated(&self, partial: &Self::Partial) -> bool {
217        match partial.current.as_ref() {
218            None => true,
219            Some(SumState::Float(v)) => v.is_nan(),
220            Some(_) => false,
221        }
222    }
223
224    fn accumulate(
225        &self,
226        partial: &mut Self::Partial,
227        batch: &Columnar,
228        _ctx: &mut ExecutionCtx,
229    ) -> VortexResult<()> {
230        // Constants compute scalar * len and combine via combine_partials.
231        if let Columnar::Constant(c) = batch {
232            if let Some(product) = multiply_constant(c.scalar(), c.len(), &partial.return_dtype)? {
233                self.combine_partials(partial, product)?;
234            }
235            return Ok(());
236        }
237
238        let mut inner = match partial.current.take() {
239            Some(inner) => inner,
240            None => return Ok(()),
241        };
242
243        let result = match batch {
244            Columnar::Canonical(c) => match c {
245                Canonical::Primitive(p) => accumulate_primitive(&mut inner, p),
246                Canonical::Bool(b) => accumulate_bool(&mut inner, b),
247                Canonical::Decimal(d) => accumulate_decimal(&mut inner, d),
248                _ => vortex_bail!("Unsupported canonical type for sum: {}", batch.dtype()),
249            },
250            Columnar::Constant(_) => unreachable!(),
251        };
252
253        match result {
254            Ok(false) => partial.current = Some(inner),
255            Ok(true) => {} // saturated: current stays None
256            Err(e) => {
257                partial.current = Some(inner);
258                return Err(e);
259            }
260        }
261        Ok(())
262    }
263
264    fn finalize(&self, partials: ArrayRef) -> VortexResult<ArrayRef> {
265        Ok(partials)
266    }
267
268    fn finalize_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
269        self.to_scalar(partial)
270    }
271}
272
/// The group state for a sum aggregate, containing the accumulated value and configuration
/// needed for reset/result without external context.
pub struct SumPartial {
    /// Dtype of the final sum; used to rebuild a zero state on reset and to type null results.
    return_dtype: DType,
    /// The current accumulated state, or `None` if saturated (checked overflow).
    current: Option<SumState>,
}
280
/// The accumulated sum value, with one variant per kind of result dtype.
// TODO(ngates): instead of an enum, we should use a Box<dyn State> to avoid dispatching over the
//  input type every time? Perhaps?
pub enum SumState {
    /// Running total for sums whose result dtype is an unsigned integer.
    Unsigned(u64),
    /// Running total for sums whose result dtype is a signed integer.
    Signed(i64),
    /// Running total for sums whose result dtype is a float.
    Float(f64),
    /// Running total for decimal sums.
    Decimal {
        /// The accumulated decimal value.
        value: DecimalValue,
        /// The decimal dtype of the result, against which the value's precision is checked.
        dtype: DecimalDType,
    },
}
294
295fn make_zero_state(return_dtype: &DType) -> SumState {
296    match return_dtype {
297        DType::Primitive(ptype, _) => match ptype {
298            PType::U8 | PType::U16 | PType::U32 | PType::U64 => SumState::Unsigned(0),
299            PType::I8 | PType::I16 | PType::I32 | PType::I64 => SumState::Signed(0),
300            PType::F16 | PType::F32 | PType::F64 => SumState::Float(0.0),
301        },
302        DType::Decimal(decimal, _) => SumState::Decimal {
303            value: DecimalValue::zero(decimal),
304            dtype: *decimal,
305        },
306        _ => vortex_panic!("Unsupported sum type"),
307    }
308}
309
/// Add `val` to `acc` in place, reporting overflow.
///
/// Returns `true` if the addition overflowed, in which case `acc` is left unchanged.
#[inline(always)]
fn checked_add_u64(acc: &mut u64, val: u64) -> bool {
    let Some(total) = acc.checked_add(val) else {
        return true;
    };
    *acc = total;
    false
}
321
/// Add `val` to `acc` in place, reporting overflow in either direction.
///
/// Returns `true` if the addition overflowed, in which case `acc` is left unchanged.
#[inline(always)]
fn checked_add_i64(acc: &mut i64, val: i64) -> bool {
    let Some(total) = acc.checked_add(val) else {
        return true;
    };
    *acc = total;
    false
}
333
334#[cfg(test)]
335mod tests {
336    use num_traits::CheckedAdd;
337    use vortex_buffer::buffer;
338    use vortex_error::VortexExpect;
339    use vortex_error::VortexResult;
340
341    use crate::ArrayRef;
342    use crate::DynArray;
343    use crate::IntoArray;
344    use crate::LEGACY_SESSION;
345    use crate::VortexSessionExecute;
346    use crate::aggregate_fn::Accumulator;
347    use crate::aggregate_fn::AggregateFnVTable;
348    use crate::aggregate_fn::DynAccumulator;
349    use crate::aggregate_fn::DynGroupedAccumulator;
350    use crate::aggregate_fn::EmptyOptions;
351    use crate::aggregate_fn::GroupedAccumulator;
352    use crate::aggregate_fn::fns::sum::Sum;
353    use crate::aggregate_fn::fns::sum::sum;
354    use crate::arrays::BoolArray;
355    use crate::arrays::ChunkedArray;
356    use crate::arrays::ConstantArray;
357    use crate::arrays::DecimalArray;
358    use crate::arrays::FixedSizeListArray;
359    use crate::arrays::PrimitiveArray;
360    use crate::assert_arrays_eq;
361    use crate::dtype::DType;
362    use crate::dtype::DecimalDType;
363    use crate::dtype::Nullability;
364    use crate::dtype::Nullability::Nullable;
365    use crate::dtype::PType;
366    use crate::dtype::i256;
367    use crate::expr::stats::Precision;
368    use crate::expr::stats::Stat;
369    use crate::expr::stats::StatsProvider;
370    use crate::scalar::DecimalValue;
371    use crate::scalar::NumericOperator;
372    use crate::scalar::Scalar;
373    use crate::validity::Validity;
374
375    /// Sum an array with an initial value (test-only helper).
376    fn sum_with_accumulator(array: &ArrayRef, accumulator: &Scalar) -> VortexResult<Scalar> {
377        let mut ctx = LEGACY_SESSION.create_execution_ctx();
378        if accumulator.is_null() {
379            return Ok(accumulator.clone());
380        }
381        if accumulator.is_zero() == Some(true) {
382            return sum(array, &mut ctx);
383        }
384
385        let sum_dtype = Stat::Sum.dtype(array.dtype()).ok_or_else(|| {
386            vortex_error::vortex_err!("Sum not supported for dtype: {}", array.dtype())
387        })?;
388
389        // For non-float types, try statistics short-circuit with accumulator.
390        if !matches!(&sum_dtype, DType::Primitive(p, _) if p.is_float())
391            && let Some(Precision::Exact(sum_scalar)) = array.statistics().get(Stat::Sum)
392        {
393            return add_scalars(&sum_dtype, &sum_scalar, accumulator);
394        }
395
396        // Compute array sum from zero (also caches stats).
397        let array_sum = sum(array, &mut ctx)?;
398
399        // Combine with the accumulator.
400        add_scalars(&sum_dtype, &array_sum, accumulator)
401    }
402
403    /// Add two sum scalars with overflow checking.
404    fn add_scalars(sum_dtype: &DType, lhs: &Scalar, rhs: &Scalar) -> VortexResult<Scalar> {
405        if lhs.is_null() || rhs.is_null() {
406            return Ok(Scalar::null(sum_dtype.as_nullable()));
407        }
408
409        Ok(match sum_dtype {
410            DType::Primitive(ptype, _) if ptype.is_float() => {
411                let lhs_val = f64::try_from(lhs)?;
412                let rhs_val = f64::try_from(rhs)?;
413                Scalar::primitive(lhs_val + rhs_val, Nullable)
414            }
415            DType::Primitive(..) => lhs
416                .as_primitive()
417                .checked_add(&rhs.as_primitive())
418                .map(Scalar::from)
419                .unwrap_or_else(|| Scalar::null(sum_dtype.as_nullable())),
420            DType::Decimal(..) => lhs
421                .as_decimal()
422                .checked_binary_numeric(&rhs.as_decimal(), NumericOperator::Add)
423                .map(Scalar::from)
424                .unwrap_or_else(|| Scalar::null(sum_dtype.as_nullable())),
425            _ => unreachable!("Sum will always be a decimal or a primitive dtype"),
426        })
427    }
428
429    // Multi-batch and reset tests
430
431    #[test]
432    fn sum_multi_batch() -> VortexResult<()> {
433        let mut ctx = LEGACY_SESSION.create_execution_ctx();
434        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
435        let mut acc = Accumulator::try_new(Sum, EmptyOptions, dtype)?;
436
437        let batch1 = PrimitiveArray::new(buffer![10i32, 20], Validity::NonNullable).into_array();
438        acc.accumulate(&batch1, &mut ctx)?;
439
440        let batch2 = PrimitiveArray::new(buffer![3i32, 6, 9], Validity::NonNullable).into_array();
441        acc.accumulate(&batch2, &mut ctx)?;
442
443        let result = acc.finish()?;
444        assert_eq!(result.as_primitive().typed_value::<i64>(), Some(48));
445        Ok(())
446    }
447
448    #[test]
449    fn sum_finish_resets_state() -> VortexResult<()> {
450        let mut ctx = LEGACY_SESSION.create_execution_ctx();
451        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
452        let mut acc = Accumulator::try_new(Sum, EmptyOptions, dtype)?;
453
454        let batch1 = PrimitiveArray::new(buffer![10i32, 20], Validity::NonNullable).into_array();
455        acc.accumulate(&batch1, &mut ctx)?;
456        let result1 = acc.finish()?;
457        assert_eq!(result1.as_primitive().typed_value::<i64>(), Some(30));
458
459        let batch2 = PrimitiveArray::new(buffer![3i32, 6, 9], Validity::NonNullable).into_array();
460        acc.accumulate(&batch2, &mut ctx)?;
461        let result2 = acc.finish()?;
462        assert_eq!(result2.as_primitive().typed_value::<i64>(), Some(18));
463        Ok(())
464    }
465
466    // State merge tests (vtable-level)
467
468    #[test]
469    fn sum_state_merge() -> VortexResult<()> {
470        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
471        let mut state = Sum.empty_partial(&EmptyOptions, &dtype)?;
472
473        let scalar1 = Scalar::primitive(100i64, Nullable);
474        Sum.combine_partials(&mut state, scalar1)?;
475
476        let scalar2 = Scalar::primitive(50i64, Nullable);
477        Sum.combine_partials(&mut state, scalar2)?;
478
479        let result = Sum.to_scalar(&state)?;
480        Sum.reset(&mut state);
481        assert_eq!(result.as_primitive().typed_value::<i64>(), Some(150));
482        Ok(())
483    }
484
485    // Stats caching test
486
487    #[test]
488    fn sum_stats() -> VortexResult<()> {
489        let array = ChunkedArray::try_new(
490            vec![
491                PrimitiveArray::from_iter([1, 1, 1]).into_array(),
492                PrimitiveArray::from_iter([2, 2, 2]).into_array(),
493            ],
494            DType::Primitive(PType::I32, Nullability::NonNullable),
495        )
496        .vortex_expect("operation should succeed in test");
497        let array = array.into_array();
498        // compute sum with accumulator to populate stats
499        sum_with_accumulator(&array, &Scalar::primitive(2i64, Nullable))?;
500
501        let sum_without_acc = sum(&array, &mut LEGACY_SESSION.create_execution_ctx())?;
502        assert_eq!(sum_without_acc, Scalar::primitive(9i64, Nullable));
503        Ok(())
504    }
505
506    // Constant float non-multiply test
507
508    #[test]
509    fn sum_constant_float_non_multiply() -> VortexResult<()> {
510        let acc = -2048669276050936500000000000f64;
511        let array = ConstantArray::new(6.1811675e16f64, 25);
512        let result = sum_with_accumulator(&array.into_array(), &Scalar::primitive(acc, Nullable))
513            .vortex_expect("operation should succeed in test");
514        assert_eq!(
515            f64::try_from(&result).vortex_expect("operation should succeed in test"),
516            -2048669274505644600000000000f64
517        );
518        Ok(())
519    }
520
521    // Grouped sum tests
522
523    fn run_grouped_sum(groups: &ArrayRef, elem_dtype: &DType) -> VortexResult<ArrayRef> {
524        let mut acc = GroupedAccumulator::try_new(Sum, EmptyOptions, elem_dtype.clone())?;
525        acc.accumulate_list(groups, &mut LEGACY_SESSION.create_execution_ctx())?;
526        acc.finish()
527    }
528
529    #[test]
530    fn grouped_sum_fixed_size_list() -> VortexResult<()> {
531        let elements =
532            PrimitiveArray::new(buffer![1i32, 2, 3, 4, 5, 6], Validity::NonNullable).into_array();
533        let groups = FixedSizeListArray::try_new(elements, 3, Validity::NonNullable, 2)?;
534
535        let elem_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
536        let result = run_grouped_sum(&groups.into_array(), &elem_dtype)?;
537
538        let expected = PrimitiveArray::from_option_iter([Some(6i64), Some(15i64)]).into_array();
539        assert_arrays_eq!(&result, &expected);
540        Ok(())
541    }
542
543    #[test]
544    fn grouped_sum_with_null_elements() -> VortexResult<()> {
545        let elements =
546            PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), None, Some(5), Some(6)])
547                .into_array();
548        let groups = FixedSizeListArray::try_new(elements, 3, Validity::NonNullable, 2)?;
549
550        let elem_dtype = DType::Primitive(PType::I32, Nullable);
551        let result = run_grouped_sum(&groups.into_array(), &elem_dtype)?;
552
553        let expected = PrimitiveArray::from_option_iter([Some(4i64), Some(11i64)]).into_array();
554        assert_arrays_eq!(&result, &expected);
555        Ok(())
556    }
557
558    #[test]
559    fn grouped_sum_with_null_group() -> VortexResult<()> {
560        let elements =
561            PrimitiveArray::new(buffer![1i32, 2, 3, 4, 5, 6, 7, 8, 9], Validity::NonNullable)
562                .into_array();
563        let validity = Validity::from_iter([true, false, true]);
564        let groups = FixedSizeListArray::try_new(elements, 3, validity, 3)?;
565
566        let elem_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
567        let result = run_grouped_sum(&groups.into_array(), &elem_dtype)?;
568
569        let expected =
570            PrimitiveArray::from_option_iter([Some(6i64), None, Some(24i64)]).into_array();
571        assert_arrays_eq!(&result, &expected);
572        Ok(())
573    }
574
575    #[test]
576    fn grouped_sum_all_null_elements_in_group() -> VortexResult<()> {
577        let elements =
578            PrimitiveArray::from_option_iter([None::<i32>, None, Some(3), Some(4)]).into_array();
579        let groups = FixedSizeListArray::try_new(elements, 2, Validity::NonNullable, 2)?;
580
581        let elem_dtype = DType::Primitive(PType::I32, Nullable);
582        let result = run_grouped_sum(&groups.into_array(), &elem_dtype)?;
583
584        let expected = PrimitiveArray::from_option_iter([Some(0i64), Some(7i64)]).into_array();
585        assert_arrays_eq!(&result, &expected);
586        Ok(())
587    }
588
589    #[test]
590    fn grouped_sum_bool() -> VortexResult<()> {
591        let elements: BoolArray = [true, false, true, true, true, true].into_iter().collect();
592        let groups =
593            FixedSizeListArray::try_new(elements.into_array(), 3, Validity::NonNullable, 2)?;
594
595        let elem_dtype = DType::Bool(Nullability::NonNullable);
596        let result = run_grouped_sum(&groups.into_array(), &elem_dtype)?;
597
598        let expected = PrimitiveArray::from_option_iter([Some(2u64), Some(3u64)]).into_array();
599        assert_arrays_eq!(&result, &expected);
600        Ok(())
601    }
602
603    #[test]
604    fn grouped_sum_finish_resets() -> VortexResult<()> {
605        let mut ctx = LEGACY_SESSION.create_execution_ctx();
606        let elem_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
607        let mut acc = GroupedAccumulator::try_new(Sum, EmptyOptions, elem_dtype)?;
608
609        let elements1 =
610            PrimitiveArray::new(buffer![1i32, 2, 3, 4], Validity::NonNullable).into_array();
611        let groups1 = FixedSizeListArray::try_new(elements1, 2, Validity::NonNullable, 2)?;
612        acc.accumulate_list(&groups1.into_array(), &mut ctx)?;
613        let result1 = acc.finish()?;
614
615        let expected1 = PrimitiveArray::from_option_iter([Some(3i64), Some(7i64)]).into_array();
616        assert_arrays_eq!(&result1, &expected1);
617
618        let elements2 = PrimitiveArray::new(buffer![10i32, 20], Validity::NonNullable).into_array();
619        let groups2 = FixedSizeListArray::try_new(elements2, 2, Validity::NonNullable, 1)?;
620        acc.accumulate_list(&groups2.into_array(), &mut ctx)?;
621        let result2 = acc.finish()?;
622
623        let expected2 = PrimitiveArray::from_option_iter([Some(30i64)]).into_array();
624        assert_arrays_eq!(&result2, &expected2);
625        Ok(())
626    }
627
628    // Chunked array tests
629
630    #[test]
631    fn sum_chunked_floats_with_nulls() -> VortexResult<()> {
632        let chunk1 =
633            PrimitiveArray::from_option_iter(vec![Some(1.5f64), None, Some(3.2), Some(4.8)]);
634        let chunk2 = PrimitiveArray::from_option_iter(vec![Some(2.1f64), Some(5.7), None]);
635        let chunk3 = PrimitiveArray::from_option_iter(vec![None, Some(1.0f64), Some(2.5), None]);
636        let dtype = chunk1.dtype().clone();
637        let chunked = ChunkedArray::try_new(
638            vec![
639                chunk1.into_array(),
640                chunk2.into_array(),
641                chunk3.into_array(),
642            ],
643            dtype,
644        )?;
645
646        let result = sum(
647            &chunked.into_array(),
648            &mut LEGACY_SESSION.create_execution_ctx(),
649        )?;
650        assert_eq!(result.as_primitive().as_::<f64>(), Some(20.8));
651        Ok(())
652    }
653
654    #[test]
655    fn sum_chunked_floats_all_nulls_is_zero() -> VortexResult<()> {
656        let chunk1 = PrimitiveArray::from_option_iter::<f32, _>(vec![None, None, None]);
657        let chunk2 = PrimitiveArray::from_option_iter::<f32, _>(vec![None, None]);
658        let dtype = chunk1.dtype().clone();
659        let chunked = ChunkedArray::try_new(vec![chunk1.into_array(), chunk2.into_array()], dtype)?;
660        let result = sum(
661            &chunked.into_array(),
662            &mut LEGACY_SESSION.create_execution_ctx(),
663        )?;
664        assert_eq!(result, Scalar::primitive(0f64, Nullable));
665        Ok(())
666    }
667
668    #[test]
669    fn sum_chunked_floats_empty_chunks() -> VortexResult<()> {
670        let chunk1 = PrimitiveArray::from_option_iter(vec![Some(10.5f64), Some(20.3)]);
671        let chunk2 = ConstantArray::new(Scalar::primitive(0f64, Nullable), 0);
672        let chunk3 = PrimitiveArray::from_option_iter(vec![Some(5.2f64)]);
673        let dtype = chunk1.dtype().clone();
674        let chunked = ChunkedArray::try_new(
675            vec![
676                chunk1.into_array(),
677                chunk2.into_array(),
678                chunk3.into_array(),
679            ],
680            dtype,
681        )?;
682
683        let result = sum(
684            &chunked.into_array(),
685            &mut LEGACY_SESSION.create_execution_ctx(),
686        )?;
687        assert_eq!(result.as_primitive().as_::<f64>(), Some(36.0));
688        Ok(())
689    }
690
691    #[test]
692    fn sum_chunked_int_almost_all_null() -> VortexResult<()> {
693        let chunk1 = PrimitiveArray::from_option_iter::<u32, _>(vec![Some(1)]);
694        let chunk2 = PrimitiveArray::from_option_iter::<u32, _>(vec![None]);
695        let dtype = chunk1.dtype().clone();
696        let chunked = ChunkedArray::try_new(vec![chunk1.into_array(), chunk2.into_array()], dtype)?;
697
698        let result = sum(
699            &chunked.into_array(),
700            &mut LEGACY_SESSION.create_execution_ctx(),
701        )?;
702        assert_eq!(result.as_primitive().as_::<u64>(), Some(1));
703        Ok(())
704    }
705
706    #[test]
707    fn sum_chunked_decimals() -> VortexResult<()> {
708        let decimal_dtype = DecimalDType::new(10, 2);
709        let chunk1 = DecimalArray::new(
710            buffer![100i32, 100i32, 100i32, 100i32, 100i32],
711            decimal_dtype,
712            Validity::AllValid,
713        );
714        let chunk2 = DecimalArray::new(
715            buffer![200i32, 200i32, 200i32],
716            decimal_dtype,
717            Validity::AllValid,
718        );
719        let chunk3 = DecimalArray::new(buffer![300i32, 300i32], decimal_dtype, Validity::AllValid);
720        let dtype = chunk1.dtype().clone();
721        let chunked = ChunkedArray::try_new(
722            vec![
723                chunk1.into_array(),
724                chunk2.into_array(),
725                chunk3.into_array(),
726            ],
727            dtype,
728        )?;
729
730        let result = sum(
731            &chunked.into_array(),
732            &mut LEGACY_SESSION.create_execution_ctx(),
733        )?;
734        let decimal_result = result.as_decimal();
735        assert_eq!(
736            decimal_result.decimal_value(),
737            Some(DecimalValue::I256(i256::from_i128(1700)))
738        );
739        Ok(())
740    }
741
742    #[test]
743    fn sum_chunked_decimals_with_nulls() -> VortexResult<()> {
744        let decimal_dtype = DecimalDType::new(10, 2);
745        let chunk1 = DecimalArray::new(
746            buffer![100i32, 100i32, 100i32],
747            decimal_dtype,
748            Validity::AllValid,
749        );
750        let chunk2 = DecimalArray::new(
751            buffer![0i32, 0i32],
752            decimal_dtype,
753            Validity::from_iter([false, false]),
754        );
755        let chunk3 = DecimalArray::new(buffer![200i32, 200i32], decimal_dtype, Validity::AllValid);
756        let dtype = chunk1.dtype().clone();
757        let chunked = ChunkedArray::try_new(
758            vec![
759                chunk1.into_array(),
760                chunk2.into_array(),
761                chunk3.into_array(),
762            ],
763            dtype,
764        )?;
765
766        let result = sum(
767            &chunked.into_array(),
768            &mut LEGACY_SESSION.create_execution_ctx(),
769        )?;
770        let decimal_result = result.as_decimal();
771        assert_eq!(
772            decimal_result.decimal_value(),
773            Some(DecimalValue::I256(i256::from_i128(700)))
774        );
775        Ok(())
776    }
777
778    #[test]
779    fn sum_chunked_decimals_large() -> VortexResult<()> {
780        let decimal_dtype = DecimalDType::new(3, 0);
781        let chunk1 = ConstantArray::new(
782            Scalar::decimal(
783                DecimalValue::I16(500),
784                decimal_dtype,
785                Nullability::NonNullable,
786            ),
787            1,
788        );
789        let chunk2 = ConstantArray::new(
790            Scalar::decimal(
791                DecimalValue::I16(600),
792                decimal_dtype,
793                Nullability::NonNullable,
794            ),
795            1,
796        );
797        let dtype = chunk1.dtype().clone();
798        let chunked = ChunkedArray::try_new(vec![chunk1.into_array(), chunk2.into_array()], dtype)?;
799
800        let result = sum(
801            &chunked.into_array(),
802            &mut LEGACY_SESSION.create_execution_ctx(),
803        )?;
804        let decimal_result = result.as_decimal();
805        assert_eq!(
806            decimal_result.decimal_value(),
807            Some(DecimalValue::I256(i256::from_i128(1100)))
808        );
809        assert_eq!(
810            result.dtype(),
811            &DType::Decimal(DecimalDType::new(13, 0), Nullable)
812        );
813        Ok(())
814    }
815}