Skip to main content

vortex_array/aggregate_fn/fns/sum/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod bool;
5mod constant;
6mod decimal;
7mod primitive;
8
9use vortex_error::VortexExpect;
10use vortex_error::VortexResult;
11use vortex_error::vortex_bail;
12use vortex_error::vortex_err;
13use vortex_error::vortex_panic;
14
15use self::bool::accumulate_bool;
16use self::constant::multiply_constant;
17use self::decimal::accumulate_decimal;
18use self::primitive::accumulate_primitive;
19use crate::ArrayRef;
20use crate::Canonical;
21use crate::Columnar;
22use crate::ExecutionCtx;
23use crate::aggregate_fn::Accumulator;
24use crate::aggregate_fn::AggregateFnId;
25use crate::aggregate_fn::AggregateFnVTable;
26use crate::aggregate_fn::DynAccumulator;
27use crate::aggregate_fn::EmptyOptions;
28use crate::dtype::DType;
29use crate::dtype::DecimalDType;
30use crate::dtype::MAX_PRECISION;
31use crate::dtype::Nullability;
32use crate::dtype::PType;
33use crate::expr::stats::Precision;
34use crate::expr::stats::Stat;
35use crate::expr::stats::StatsProvider;
36use crate::scalar::DecimalValue;
37use crate::scalar::Scalar;
38
39/// Return the sum of an array.
40///
41/// See [`Sum`] for details.
42pub fn sum(array: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<Scalar> {
43    // Short-circuit using cached array statistics.
44    if let Some(Precision::Exact(sum_scalar)) = array.statistics().get(Stat::Sum) {
45        return Ok(sum_scalar);
46    }
47
48    // Compute using Accumulator<Sum>.
49    // TODO(ngates): we may want to wrap this three-step dance up into an extension crate maybe.
50    let mut acc = Accumulator::try_new(Sum, EmptyOptions, array.dtype().clone())?;
51    acc.accumulate(array, ctx)?;
52    let result = acc.finish()?;
53
54    // Cache the computed sum as a statistic (only if non-null, i.e. no overflow).
55    if let Some(val) = result.value().cloned() {
56        array.statistics().set(Stat::Sum, Precision::Exact(val));
57    }
58
59    Ok(result)
60}
61
/// Sum an array, starting from zero.
///
/// The result dtype is always nullable and depends on the input dtype:
/// - booleans (counting `true` values) and unsigned integers sum into `u64`,
/// - signed integers sum into `i64`,
/// - floats sum into `f64`,
/// - decimals keep their scale and gain 10 digits of precision (capped at `MAX_PRECISION`).
///
/// Other input dtypes are not supported.
///
/// Null elements are skipped. If the array is all-invalid, the sum will be zero.
/// If the sum overflows (checked integer or decimal arithmetic), a null scalar will be returned.
#[derive(Clone, Debug)]
pub struct Sum;
68
69impl AggregateFnVTable for Sum {
70    type Options = EmptyOptions;
71    type Partial = SumPartial;
72
73    fn id(&self) -> AggregateFnId {
74        AggregateFnId::new_ref("vortex.sum")
75    }
76
77    fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
78        unimplemented!("Sum is not yet serializable");
79    }
80
81    fn return_dtype(&self, _options: &Self::Options, input_dtype: &DType) -> Option<DType> {
82        // When a sum overflows, we return a sum _value_ of null. Therefore, we all return dtypes
83        // are nullable.
84        use Nullability::Nullable;
85
86        Some(match input_dtype {
87            DType::Bool(_) => DType::Primitive(PType::U64, Nullable),
88            DType::Primitive(ptype, _) => match ptype {
89                PType::U8 | PType::U16 | PType::U32 | PType::U64 => {
90                    DType::Primitive(PType::U64, Nullable)
91                }
92                PType::I8 | PType::I16 | PType::I32 | PType::I64 => {
93                    DType::Primitive(PType::I64, Nullable)
94                }
95                PType::F16 | PType::F32 | PType::F64 => {
96                    // Float sums cannot overflow, but all null floats still end up as null
97                    DType::Primitive(PType::F64, Nullable)
98                }
99            },
100            DType::Decimal(decimal_dtype, _) => {
101                // Both Spark and DataFusion use this heuristic.
102                // - https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
103                // - https://github.com/apache/datafusion/blob/4153adf2c0f6e317ef476febfdc834208bd46622/datafusion/functions-aggregate/src/sum.rs#L188
104                let precision = u8::min(MAX_PRECISION, decimal_dtype.precision() + 10);
105                DType::Decimal(
106                    DecimalDType::new(precision, decimal_dtype.scale()),
107                    Nullable,
108                )
109            }
110            // Unsupported types
111            _ => return None,
112        })
113    }
114
115    fn partial_dtype(&self, options: &Self::Options, input_dtype: &DType) -> Option<DType> {
116        self.return_dtype(options, input_dtype)
117    }
118
119    fn empty_partial(
120        &self,
121        options: &Self::Options,
122        input_dtype: &DType,
123    ) -> VortexResult<Self::Partial> {
124        let return_dtype = self
125            .return_dtype(options, input_dtype)
126            .ok_or_else(|| vortex_err!("Unsupported sum dtype: {}", input_dtype))?;
127        let initial = make_zero_state(&return_dtype);
128
129        Ok(SumPartial {
130            return_dtype,
131            current: Some(initial),
132        })
133    }
134
135    fn combine_partials(&self, partial: &mut Self::Partial, other: Scalar) -> VortexResult<()> {
136        if other.is_null() {
137            // A null partial means the sub-accumulator saturated (overflow).
138            partial.current = None;
139            return Ok(());
140        }
141        let Some(ref mut inner) = partial.current else {
142            return Ok(());
143        };
144        let saturated = match inner {
145            SumState::Unsigned(acc) => {
146                let val = other
147                    .as_primitive()
148                    .typed_value::<u64>()
149                    .vortex_expect("checked non-null");
150                checked_add_u64(acc, val)
151            }
152            SumState::Signed(acc) => {
153                let val = other
154                    .as_primitive()
155                    .typed_value::<i64>()
156                    .vortex_expect("checked non-null");
157                checked_add_i64(acc, val)
158            }
159            SumState::Float(acc) => {
160                let val = other
161                    .as_primitive()
162                    .typed_value::<f64>()
163                    .vortex_expect("checked non-null");
164                *acc += val;
165                false
166            }
167            SumState::Decimal { value, dtype } => {
168                let val = other
169                    .as_decimal()
170                    .decimal_value()
171                    .vortex_expect("checked non-null");
172                match value.checked_add(&val) {
173                    Some(r) => {
174                        *value = r;
175                        !value.fits_in_precision(*dtype)
176                    }
177                    None => true,
178                }
179            }
180        };
181        if saturated {
182            partial.current = None;
183        }
184        Ok(())
185    }
186
187    fn to_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
188        Ok(match &partial.current {
189            None => Scalar::null(partial.return_dtype.as_nullable()),
190            Some(SumState::Unsigned(v)) => Scalar::primitive(*v, Nullability::Nullable),
191            Some(SumState::Signed(v)) => Scalar::primitive(*v, Nullability::Nullable),
192            Some(SumState::Float(v)) => Scalar::primitive(*v, Nullability::Nullable),
193            Some(SumState::Decimal { value, .. }) => {
194                let decimal_dtype = *partial
195                    .return_dtype
196                    .as_decimal_opt()
197                    .vortex_expect("return dtype must be decimal");
198                Scalar::decimal(*value, decimal_dtype, Nullability::Nullable)
199            }
200        })
201    }
202
203    fn reset(&self, partial: &mut Self::Partial) {
204        partial.current = Some(make_zero_state(&partial.return_dtype));
205    }
206
207    #[inline]
208    fn is_saturated(&self, partial: &Self::Partial) -> bool {
209        match partial.current.as_ref() {
210            None => true,
211            Some(SumState::Float(v)) => v.is_nan(),
212            Some(_) => false,
213        }
214    }
215
216    fn accumulate(
217        &self,
218        partial: &mut Self::Partial,
219        batch: &Columnar,
220        _ctx: &mut ExecutionCtx,
221    ) -> VortexResult<()> {
222        // Constants compute scalar * len and combine via combine_partials.
223        if let Columnar::Constant(c) = batch {
224            if let Some(product) = multiply_constant(c.scalar(), c.len(), &partial.return_dtype)? {
225                self.combine_partials(partial, product)?;
226            }
227            return Ok(());
228        }
229
230        let mut inner = match partial.current.take() {
231            Some(inner) => inner,
232            None => return Ok(()),
233        };
234
235        let result = match batch {
236            Columnar::Canonical(c) => match c {
237                Canonical::Primitive(p) => accumulate_primitive(&mut inner, p),
238                Canonical::Bool(b) => accumulate_bool(&mut inner, b),
239                Canonical::Decimal(d) => accumulate_decimal(&mut inner, d),
240                _ => vortex_bail!("Unsupported canonical type for sum: {}", batch.dtype()),
241            },
242            Columnar::Constant(_) => unreachable!(),
243        };
244
245        match result {
246            Ok(false) => partial.current = Some(inner),
247            Ok(true) => {} // saturated: current stays None
248            Err(e) => {
249                partial.current = Some(inner);
250                return Err(e);
251            }
252        }
253        Ok(())
254    }
255
256    fn finalize(&self, partials: ArrayRef) -> VortexResult<ArrayRef> {
257        Ok(partials)
258    }
259
260    fn finalize_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
261        self.to_scalar(partial)
262    }
263}
264
/// The group state for a sum aggregate.
///
/// Besides the running value, it stores the sum's return dtype so that resetting the state and
/// producing a result need no external context.
pub struct SumPartial {
    /// The nullable result dtype of the sum. It is used to rebuild the zero state on reset and to
    /// type the null scalar produced after overflow.
    return_dtype: DType,
    /// The current accumulated state, or `None` if saturated (checked overflow, or a null
    /// partial was merged in).
    current: Option<SumState>,
}
272
/// The accumulated sum value.
///
/// Booleans and unsigned integers accumulate into `u64`, signed integers into `i64`, floats into
/// `f64`, and decimals into a [`DecimalValue`] paired with the result [`DecimalDType`].
///
// TODO(ngates): instead of an enum, we should use a Box<dyn State> to avoid dispatching over the
//  input type every time? Perhaps?
pub enum SumState {
    /// Running total for boolean and unsigned integer inputs.
    Unsigned(u64),
    /// Running total for signed integer inputs.
    Signed(i64),
    /// Running total for floating-point inputs.
    Float(f64),
    /// Running total for decimal inputs.
    Decimal {
        /// The accumulated decimal value.
        value: DecimalValue,
        /// The result decimal dtype; a total that no longer fits its precision saturates.
        dtype: DecimalDType,
    },
}
286
287fn make_zero_state(return_dtype: &DType) -> SumState {
288    match return_dtype {
289        DType::Primitive(ptype, _) => match ptype {
290            PType::U8 | PType::U16 | PType::U32 | PType::U64 => SumState::Unsigned(0),
291            PType::I8 | PType::I16 | PType::I32 | PType::I64 => SumState::Signed(0),
292            PType::F16 | PType::F32 | PType::F64 => SumState::Float(0.0),
293        },
294        DType::Decimal(decimal, _) => SumState::Decimal {
295            value: DecimalValue::zero(decimal),
296            dtype: *decimal,
297        },
298        _ => vortex_panic!("Unsupported sum type"),
299    }
300}
301
/// Add `val` into the running unsigned total `acc`.
///
/// Returns `true` if the addition overflowed `u64`, in which case `acc` keeps its previous
/// value. Otherwise `acc` holds the new total and `false` is returned.
#[inline(always)]
fn checked_add_u64(acc: &mut u64, val: u64) -> bool {
    let Some(total) = acc.checked_add(val) else {
        return true;
    };
    *acc = total;
    false
}
313
/// Add `val` into the running signed total `acc`.
///
/// Returns `true` if the addition overflowed or underflowed `i64`, in which case `acc` keeps its
/// previous value. Otherwise `acc` holds the new total and `false` is returned.
#[inline(always)]
fn checked_add_i64(acc: &mut i64, val: i64) -> bool {
    if let Some(total) = acc.checked_add(val) {
        *acc = total;
        return false;
    }
    true
}
325
326#[cfg(test)]
327mod tests {
328    use num_traits::CheckedAdd;
329    use vortex_buffer::buffer;
330    use vortex_error::VortexExpect;
331    use vortex_error::VortexResult;
332
333    use crate::ArrayRef;
334    use crate::IntoArray;
335    use crate::LEGACY_SESSION;
336    use crate::VortexSessionExecute;
337    use crate::aggregate_fn::Accumulator;
338    use crate::aggregate_fn::AggregateFnVTable;
339    use crate::aggregate_fn::DynAccumulator;
340    use crate::aggregate_fn::DynGroupedAccumulator;
341    use crate::aggregate_fn::EmptyOptions;
342    use crate::aggregate_fn::GroupedAccumulator;
343    use crate::aggregate_fn::fns::sum::Sum;
344    use crate::aggregate_fn::fns::sum::sum;
345    use crate::arrays::BoolArray;
346    use crate::arrays::ChunkedArray;
347    use crate::arrays::ConstantArray;
348    use crate::arrays::DecimalArray;
349    use crate::arrays::FixedSizeListArray;
350    use crate::arrays::PrimitiveArray;
351    use crate::assert_arrays_eq;
352    use crate::dtype::DType;
353    use crate::dtype::DecimalDType;
354    use crate::dtype::Nullability;
355    use crate::dtype::Nullability::Nullable;
356    use crate::dtype::PType;
357    use crate::dtype::i256;
358    use crate::expr::stats::Precision;
359    use crate::expr::stats::Stat;
360    use crate::expr::stats::StatsProvider;
361    use crate::scalar::DecimalValue;
362    use crate::scalar::NumericOperator;
363    use crate::scalar::Scalar;
364    use crate::validity::Validity;
365
366    /// Sum an array with an initial value (test-only helper).
367    fn sum_with_accumulator(array: &ArrayRef, accumulator: &Scalar) -> VortexResult<Scalar> {
368        let mut ctx = LEGACY_SESSION.create_execution_ctx();
369        if accumulator.is_null() {
370            return Ok(accumulator.clone());
371        }
372        if accumulator.is_zero() == Some(true) {
373            return sum(array, &mut ctx);
374        }
375
376        let sum_dtype = Stat::Sum.dtype(array.dtype()).ok_or_else(|| {
377            vortex_error::vortex_err!("Sum not supported for dtype: {}", array.dtype())
378        })?;
379
380        // For non-float types, try statistics short-circuit with accumulator.
381        if !matches!(&sum_dtype, DType::Primitive(p, _) if p.is_float())
382            && let Some(Precision::Exact(sum_scalar)) = array.statistics().get(Stat::Sum)
383        {
384            return add_scalars(&sum_dtype, &sum_scalar, accumulator);
385        }
386
387        // Compute array sum from zero (also caches stats).
388        let array_sum = sum(array, &mut ctx)?;
389
390        // Combine with the accumulator.
391        add_scalars(&sum_dtype, &array_sum, accumulator)
392    }
393
394    /// Add two sum scalars with overflow checking.
395    fn add_scalars(sum_dtype: &DType, lhs: &Scalar, rhs: &Scalar) -> VortexResult<Scalar> {
396        if lhs.is_null() || rhs.is_null() {
397            return Ok(Scalar::null(sum_dtype.as_nullable()));
398        }
399
400        Ok(match sum_dtype {
401            DType::Primitive(ptype, _) if ptype.is_float() => {
402                let lhs_val = f64::try_from(lhs)?;
403                let rhs_val = f64::try_from(rhs)?;
404                Scalar::primitive(lhs_val + rhs_val, Nullable)
405            }
406            DType::Primitive(..) => lhs
407                .as_primitive()
408                .checked_add(&rhs.as_primitive())
409                .map(Scalar::from)
410                .unwrap_or_else(|| Scalar::null(sum_dtype.as_nullable())),
411            DType::Decimal(..) => lhs
412                .as_decimal()
413                .checked_binary_numeric(&rhs.as_decimal(), NumericOperator::Add)
414                .map(Scalar::from)
415                .unwrap_or_else(|| Scalar::null(sum_dtype.as_nullable())),
416            _ => unreachable!("Sum will always be a decimal or a primitive dtype"),
417        })
418    }
419
420    // Multi-batch and reset tests
421
422    #[test]
423    fn sum_multi_batch() -> VortexResult<()> {
424        let mut ctx = LEGACY_SESSION.create_execution_ctx();
425        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
426        let mut acc = Accumulator::try_new(Sum, EmptyOptions, dtype)?;
427
428        let batch1 = PrimitiveArray::new(buffer![10i32, 20], Validity::NonNullable).into_array();
429        acc.accumulate(&batch1, &mut ctx)?;
430
431        let batch2 = PrimitiveArray::new(buffer![3i32, 6, 9], Validity::NonNullable).into_array();
432        acc.accumulate(&batch2, &mut ctx)?;
433
434        let result = acc.finish()?;
435        assert_eq!(result.as_primitive().typed_value::<i64>(), Some(48));
436        Ok(())
437    }
438
439    #[test]
440    fn sum_finish_resets_state() -> VortexResult<()> {
441        let mut ctx = LEGACY_SESSION.create_execution_ctx();
442        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
443        let mut acc = Accumulator::try_new(Sum, EmptyOptions, dtype)?;
444
445        let batch1 = PrimitiveArray::new(buffer![10i32, 20], Validity::NonNullable).into_array();
446        acc.accumulate(&batch1, &mut ctx)?;
447        let result1 = acc.finish()?;
448        assert_eq!(result1.as_primitive().typed_value::<i64>(), Some(30));
449
450        let batch2 = PrimitiveArray::new(buffer![3i32, 6, 9], Validity::NonNullable).into_array();
451        acc.accumulate(&batch2, &mut ctx)?;
452        let result2 = acc.finish()?;
453        assert_eq!(result2.as_primitive().typed_value::<i64>(), Some(18));
454        Ok(())
455    }
456
457    // State merge tests (vtable-level)
458
459    #[test]
460    fn sum_state_merge() -> VortexResult<()> {
461        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
462        let mut state = Sum.empty_partial(&EmptyOptions, &dtype)?;
463
464        let scalar1 = Scalar::primitive(100i64, Nullable);
465        Sum.combine_partials(&mut state, scalar1)?;
466
467        let scalar2 = Scalar::primitive(50i64, Nullable);
468        Sum.combine_partials(&mut state, scalar2)?;
469
470        let result = Sum.to_scalar(&state)?;
471        Sum.reset(&mut state);
472        assert_eq!(result.as_primitive().typed_value::<i64>(), Some(150));
473        Ok(())
474    }
475
476    // Stats caching test
477
478    #[test]
479    fn sum_stats() -> VortexResult<()> {
480        let array = ChunkedArray::try_new(
481            vec![
482                PrimitiveArray::from_iter([1, 1, 1]).into_array(),
483                PrimitiveArray::from_iter([2, 2, 2]).into_array(),
484            ],
485            DType::Primitive(PType::I32, Nullability::NonNullable),
486        )
487        .vortex_expect("operation should succeed in test");
488        let array = array.into_array();
489        // compute sum with accumulator to populate stats
490        sum_with_accumulator(&array, &Scalar::primitive(2i64, Nullable))?;
491
492        let sum_without_acc = sum(&array, &mut LEGACY_SESSION.create_execution_ctx())?;
493        assert_eq!(sum_without_acc, Scalar::primitive(9i64, Nullable));
494        Ok(())
495    }
496
497    // Constant float non-multiply test
498
499    #[test]
500    fn sum_constant_float_non_multiply() -> VortexResult<()> {
501        let acc = -2048669276050936500000000000f64;
502        let array = ConstantArray::new(6.1811675e16f64, 25);
503        let result = sum_with_accumulator(&array.into_array(), &Scalar::primitive(acc, Nullable))
504            .vortex_expect("operation should succeed in test");
505        assert_eq!(
506            f64::try_from(&result).vortex_expect("operation should succeed in test"),
507            -2048669274505644600000000000f64
508        );
509        Ok(())
510    }
511
512    // Grouped sum tests
513
514    fn run_grouped_sum(groups: &ArrayRef, elem_dtype: &DType) -> VortexResult<ArrayRef> {
515        let mut acc = GroupedAccumulator::try_new(Sum, EmptyOptions, elem_dtype.clone())?;
516        acc.accumulate_list(groups, &mut LEGACY_SESSION.create_execution_ctx())?;
517        acc.finish()
518    }
519
520    #[test]
521    fn grouped_sum_fixed_size_list() -> VortexResult<()> {
522        let elements =
523            PrimitiveArray::new(buffer![1i32, 2, 3, 4, 5, 6], Validity::NonNullable).into_array();
524        let groups = FixedSizeListArray::try_new(elements, 3, Validity::NonNullable, 2)?;
525
526        let elem_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
527        let result = run_grouped_sum(&groups.into_array(), &elem_dtype)?;
528
529        let expected = PrimitiveArray::from_option_iter([Some(6i64), Some(15i64)]).into_array();
530        assert_arrays_eq!(&result, &expected);
531        Ok(())
532    }
533
534    #[test]
535    fn grouped_sum_with_null_elements() -> VortexResult<()> {
536        let elements =
537            PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), None, Some(5), Some(6)])
538                .into_array();
539        let groups = FixedSizeListArray::try_new(elements, 3, Validity::NonNullable, 2)?;
540
541        let elem_dtype = DType::Primitive(PType::I32, Nullable);
542        let result = run_grouped_sum(&groups.into_array(), &elem_dtype)?;
543
544        let expected = PrimitiveArray::from_option_iter([Some(4i64), Some(11i64)]).into_array();
545        assert_arrays_eq!(&result, &expected);
546        Ok(())
547    }
548
549    #[test]
550    fn grouped_sum_with_null_group() -> VortexResult<()> {
551        let elements =
552            PrimitiveArray::new(buffer![1i32, 2, 3, 4, 5, 6, 7, 8, 9], Validity::NonNullable)
553                .into_array();
554        let validity = Validity::from_iter([true, false, true]);
555        let groups = FixedSizeListArray::try_new(elements, 3, validity, 3)?;
556
557        let elem_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
558        let result = run_grouped_sum(&groups.into_array(), &elem_dtype)?;
559
560        let expected =
561            PrimitiveArray::from_option_iter([Some(6i64), None, Some(24i64)]).into_array();
562        assert_arrays_eq!(&result, &expected);
563        Ok(())
564    }
565
566    #[test]
567    fn grouped_sum_all_null_elements_in_group() -> VortexResult<()> {
568        let elements =
569            PrimitiveArray::from_option_iter([None::<i32>, None, Some(3), Some(4)]).into_array();
570        let groups = FixedSizeListArray::try_new(elements, 2, Validity::NonNullable, 2)?;
571
572        let elem_dtype = DType::Primitive(PType::I32, Nullable);
573        let result = run_grouped_sum(&groups.into_array(), &elem_dtype)?;
574
575        let expected = PrimitiveArray::from_option_iter([Some(0i64), Some(7i64)]).into_array();
576        assert_arrays_eq!(&result, &expected);
577        Ok(())
578    }
579
580    #[test]
581    fn grouped_sum_bool() -> VortexResult<()> {
582        let elements: BoolArray = [true, false, true, true, true, true].into_iter().collect();
583        let groups =
584            FixedSizeListArray::try_new(elements.into_array(), 3, Validity::NonNullable, 2)?;
585
586        let elem_dtype = DType::Bool(Nullability::NonNullable);
587        let result = run_grouped_sum(&groups.into_array(), &elem_dtype)?;
588
589        let expected = PrimitiveArray::from_option_iter([Some(2u64), Some(3u64)]).into_array();
590        assert_arrays_eq!(&result, &expected);
591        Ok(())
592    }
593
594    #[test]
595    fn grouped_sum_finish_resets() -> VortexResult<()> {
596        let mut ctx = LEGACY_SESSION.create_execution_ctx();
597        let elem_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
598        let mut acc = GroupedAccumulator::try_new(Sum, EmptyOptions, elem_dtype)?;
599
600        let elements1 =
601            PrimitiveArray::new(buffer![1i32, 2, 3, 4], Validity::NonNullable).into_array();
602        let groups1 = FixedSizeListArray::try_new(elements1, 2, Validity::NonNullable, 2)?;
603        acc.accumulate_list(&groups1.into_array(), &mut ctx)?;
604        let result1 = acc.finish()?;
605
606        let expected1 = PrimitiveArray::from_option_iter([Some(3i64), Some(7i64)]).into_array();
607        assert_arrays_eq!(&result1, &expected1);
608
609        let elements2 = PrimitiveArray::new(buffer![10i32, 20], Validity::NonNullable).into_array();
610        let groups2 = FixedSizeListArray::try_new(elements2, 2, Validity::NonNullable, 1)?;
611        acc.accumulate_list(&groups2.into_array(), &mut ctx)?;
612        let result2 = acc.finish()?;
613
614        let expected2 = PrimitiveArray::from_option_iter([Some(30i64)]).into_array();
615        assert_arrays_eq!(&result2, &expected2);
616        Ok(())
617    }
618
619    // Chunked array tests
620
621    #[test]
622    fn sum_chunked_floats_with_nulls() -> VortexResult<()> {
623        let chunk1 =
624            PrimitiveArray::from_option_iter(vec![Some(1.5f64), None, Some(3.2), Some(4.8)]);
625        let chunk2 = PrimitiveArray::from_option_iter(vec![Some(2.1f64), Some(5.7), None]);
626        let chunk3 = PrimitiveArray::from_option_iter(vec![None, Some(1.0f64), Some(2.5), None]);
627        let dtype = chunk1.dtype().clone();
628        let chunked = ChunkedArray::try_new(
629            vec![
630                chunk1.into_array(),
631                chunk2.into_array(),
632                chunk3.into_array(),
633            ],
634            dtype,
635        )?;
636
637        let result = sum(
638            &chunked.into_array(),
639            &mut LEGACY_SESSION.create_execution_ctx(),
640        )?;
641        assert_eq!(result.as_primitive().as_::<f64>(), Some(20.8));
642        Ok(())
643    }
644
645    #[test]
646    fn sum_chunked_floats_all_nulls_is_zero() -> VortexResult<()> {
647        let chunk1 = PrimitiveArray::from_option_iter::<f32, _>(vec![None, None, None]);
648        let chunk2 = PrimitiveArray::from_option_iter::<f32, _>(vec![None, None]);
649        let dtype = chunk1.dtype().clone();
650        let chunked = ChunkedArray::try_new(vec![chunk1.into_array(), chunk2.into_array()], dtype)?;
651        let result = sum(
652            &chunked.into_array(),
653            &mut LEGACY_SESSION.create_execution_ctx(),
654        )?;
655        assert_eq!(result, Scalar::primitive(0f64, Nullable));
656        Ok(())
657    }
658
659    #[test]
660    fn sum_chunked_floats_empty_chunks() -> VortexResult<()> {
661        let chunk1 = PrimitiveArray::from_option_iter(vec![Some(10.5f64), Some(20.3)]);
662        let chunk2 = ConstantArray::new(Scalar::primitive(0f64, Nullable), 0);
663        let chunk3 = PrimitiveArray::from_option_iter(vec![Some(5.2f64)]);
664        let dtype = chunk1.dtype().clone();
665        let chunked = ChunkedArray::try_new(
666            vec![
667                chunk1.into_array(),
668                chunk2.into_array(),
669                chunk3.into_array(),
670            ],
671            dtype,
672        )?;
673
674        let result = sum(
675            &chunked.into_array(),
676            &mut LEGACY_SESSION.create_execution_ctx(),
677        )?;
678        assert_eq!(result.as_primitive().as_::<f64>(), Some(36.0));
679        Ok(())
680    }
681
682    #[test]
683    fn sum_chunked_int_almost_all_null() -> VortexResult<()> {
684        let chunk1 = PrimitiveArray::from_option_iter::<u32, _>(vec![Some(1)]);
685        let chunk2 = PrimitiveArray::from_option_iter::<u32, _>(vec![None]);
686        let dtype = chunk1.dtype().clone();
687        let chunked = ChunkedArray::try_new(vec![chunk1.into_array(), chunk2.into_array()], dtype)?;
688
689        let result = sum(
690            &chunked.into_array(),
691            &mut LEGACY_SESSION.create_execution_ctx(),
692        )?;
693        assert_eq!(result.as_primitive().as_::<u64>(), Some(1));
694        Ok(())
695    }
696
697    #[test]
698    fn sum_chunked_decimals() -> VortexResult<()> {
699        let decimal_dtype = DecimalDType::new(10, 2);
700        let chunk1 = DecimalArray::new(
701            buffer![100i32, 100i32, 100i32, 100i32, 100i32],
702            decimal_dtype,
703            Validity::AllValid,
704        );
705        let chunk2 = DecimalArray::new(
706            buffer![200i32, 200i32, 200i32],
707            decimal_dtype,
708            Validity::AllValid,
709        );
710        let chunk3 = DecimalArray::new(buffer![300i32, 300i32], decimal_dtype, Validity::AllValid);
711        let dtype = chunk1.dtype().clone();
712        let chunked = ChunkedArray::try_new(
713            vec![
714                chunk1.into_array(),
715                chunk2.into_array(),
716                chunk3.into_array(),
717            ],
718            dtype,
719        )?;
720
721        let result = sum(
722            &chunked.into_array(),
723            &mut LEGACY_SESSION.create_execution_ctx(),
724        )?;
725        let decimal_result = result.as_decimal();
726        assert_eq!(
727            decimal_result.decimal_value(),
728            Some(DecimalValue::I256(i256::from_i128(1700)))
729        );
730        Ok(())
731    }
732
733    #[test]
734    fn sum_chunked_decimals_with_nulls() -> VortexResult<()> {
735        let decimal_dtype = DecimalDType::new(10, 2);
736        let chunk1 = DecimalArray::new(
737            buffer![100i32, 100i32, 100i32],
738            decimal_dtype,
739            Validity::AllValid,
740        );
741        let chunk2 = DecimalArray::new(
742            buffer![0i32, 0i32],
743            decimal_dtype,
744            Validity::from_iter([false, false]),
745        );
746        let chunk3 = DecimalArray::new(buffer![200i32, 200i32], decimal_dtype, Validity::AllValid);
747        let dtype = chunk1.dtype().clone();
748        let chunked = ChunkedArray::try_new(
749            vec![
750                chunk1.into_array(),
751                chunk2.into_array(),
752                chunk3.into_array(),
753            ],
754            dtype,
755        )?;
756
757        let result = sum(
758            &chunked.into_array(),
759            &mut LEGACY_SESSION.create_execution_ctx(),
760        )?;
761        let decimal_result = result.as_decimal();
762        assert_eq!(
763            decimal_result.decimal_value(),
764            Some(DecimalValue::I256(i256::from_i128(700)))
765        );
766        Ok(())
767    }
768
769    #[test]
770    fn sum_chunked_decimals_large() -> VortexResult<()> {
771        let decimal_dtype = DecimalDType::new(3, 0);
772        let chunk1 = ConstantArray::new(
773            Scalar::decimal(
774                DecimalValue::I16(500),
775                decimal_dtype,
776                Nullability::NonNullable,
777            ),
778            1,
779        );
780        let chunk2 = ConstantArray::new(
781            Scalar::decimal(
782                DecimalValue::I16(600),
783                decimal_dtype,
784                Nullability::NonNullable,
785            ),
786            1,
787        );
788        let dtype = chunk1.dtype().clone();
789        let chunked = ChunkedArray::try_new(vec![chunk1.into_array(), chunk2.into_array()], dtype)?;
790
791        let result = sum(
792            &chunked.into_array(),
793            &mut LEGACY_SESSION.create_execution_ctx(),
794        )?;
795        let decimal_result = result.as_decimal();
796        assert_eq!(
797            decimal_result.decimal_value(),
798            Some(DecimalValue::I256(i256::from_i128(1100)))
799        );
800        assert_eq!(
801            result.dtype(),
802            &DType::Decimal(DecimalDType::new(13, 0), Nullable)
803        );
804        Ok(())
805    }
806}