Skip to main content

vortex_array/aggregate_fn/fns/
sum.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::ops::BitAnd;
5
6use itertools::Itertools;
7use num_traits::ToPrimitive;
8use vortex_error::VortexExpect;
9use vortex_error::VortexResult;
10use vortex_error::vortex_bail;
11use vortex_error::vortex_err;
12use vortex_error::vortex_panic;
13use vortex_mask::AllOr;
14
15use crate::ArrayRef;
16use crate::Canonical;
17use crate::Columnar;
18use crate::ExecutionCtx;
19use crate::aggregate_fn::AggregateFnId;
20use crate::aggregate_fn::AggregateFnVTable;
21use crate::aggregate_fn::EmptyOptions;
22use crate::arrays::BoolArray;
23use crate::arrays::ConstantArray;
24use crate::arrays::DecimalArray;
25use crate::arrays::PrimitiveArray;
26use crate::dtype::DType;
27use crate::dtype::Nullability;
28use crate::dtype::PType;
29use crate::expr::stats::Stat;
30use crate::match_each_decimal_value_type;
31use crate::match_each_native_ptype;
32use crate::scalar::DecimalValue;
33use crate::scalar::Scalar;
34
/// The sum aggregate function.
///
/// This is a stateless marker type; all behavior lives in its [`AggregateFnVTable`]
/// implementation.
#[derive(Debug, Clone)]
pub struct Sum;
37
38impl AggregateFnVTable for Sum {
39    type Options = EmptyOptions;
40    type Partial = SumPartial;
41
42    fn id(&self) -> AggregateFnId {
43        AggregateFnId::new_ref("vortex.sum")
44    }
45
46    fn return_dtype(&self, _options: &Self::Options, input_dtype: &DType) -> VortexResult<DType> {
47        Stat::Sum
48            .dtype(input_dtype)
49            .ok_or_else(|| vortex_err!("Cannot sum {}", input_dtype))
50    }
51
52    fn partial_dtype(&self, options: &Self::Options, input_dtype: &DType) -> VortexResult<DType> {
53        self.return_dtype(options, input_dtype)
54    }
55
56    fn empty_partial(
57        &self,
58        _options: &Self::Options,
59        input_dtype: &DType,
60    ) -> VortexResult<Self::Partial> {
61        let return_dtype = Stat::Sum
62            .dtype(input_dtype)
63            .ok_or_else(|| vortex_err!("Cannot sum {}", input_dtype))?;
64
65        let initial = make_zero_state(&return_dtype);
66
67        Ok(SumPartial {
68            return_dtype,
69            current: Some(initial),
70        })
71    }
72
73    fn combine_partials(&self, partial: &mut Self::Partial, other: Scalar) -> VortexResult<()> {
74        if other.is_null() {
75            // A null partial means the sub-accumulator saturated (overflow).
76            partial.current = None;
77            return Ok(());
78        }
79        let Some(ref mut inner) = partial.current else {
80            return Ok(());
81        };
82        let saturated = match inner {
83            SumState::Unsigned(acc) => {
84                let val = other
85                    .as_primitive()
86                    .typed_value::<u64>()
87                    .vortex_expect("checked non-null");
88                checked_add_u64(acc, val)
89            }
90            SumState::Signed(acc) => {
91                let val = other
92                    .as_primitive()
93                    .typed_value::<i64>()
94                    .vortex_expect("checked non-null");
95                checked_add_i64(acc, val)
96            }
97            SumState::Float(acc) => {
98                let val = other
99                    .as_primitive()
100                    .typed_value::<f64>()
101                    .vortex_expect("checked non-null");
102                *acc += val;
103                false
104            }
105            SumState::Decimal(acc) => {
106                let val = other
107                    .as_decimal()
108                    .decimal_value()
109                    .vortex_expect("checked non-null");
110                match acc.checked_add(&val) {
111                    Some(r) => {
112                        *acc = r;
113                        false
114                    }
115                    None => true,
116                }
117            }
118        };
119        if saturated {
120            partial.current = None;
121        }
122        Ok(())
123    }
124
125    fn flush(&self, partial: &mut Self::Partial) -> VortexResult<Scalar> {
126        let result = match &partial.current {
127            None => Scalar::null(partial.return_dtype.as_nullable()),
128            Some(SumState::Unsigned(v)) => Scalar::primitive(*v, Nullability::Nullable),
129            Some(SumState::Signed(v)) => Scalar::primitive(*v, Nullability::Nullable),
130            Some(SumState::Float(v)) => Scalar::primitive(*v, Nullability::Nullable),
131            Some(SumState::Decimal(v)) => {
132                let decimal_dtype = *partial
133                    .return_dtype
134                    .as_decimal_opt()
135                    .vortex_expect("return dtype must be decimal");
136                Scalar::decimal(*v, decimal_dtype, Nullability::Nullable)
137            }
138        };
139
140        // Reset the state
141        partial.current = Some(make_zero_state(&partial.return_dtype));
142
143        Ok(result)
144    }
145
146    #[inline]
147    fn is_saturated(&self, partial: &Self::Partial) -> bool {
148        partial.current.is_none()
149    }
150
151    fn accumulate(
152        &self,
153        partial: &mut Self::Partial,
154        batch: &Columnar,
155        _ctx: &mut ExecutionCtx,
156    ) -> VortexResult<()> {
157        let mut inner = match partial.current.take() {
158            Some(inner) => inner,
159            None => return Ok(()),
160        };
161
162        let result = match batch {
163            Columnar::Canonical(c) => match c {
164                Canonical::Primitive(p) => accumulate_primitive(&mut inner, p),
165                Canonical::Bool(b) => accumulate_bool(&mut inner, b),
166                Canonical::Decimal(d) => accumulate_decimal(&mut inner, d),
167                _ => vortex_bail!("Unsupported canonical type for sum: {}", batch.dtype()),
168            },
169            Columnar::Constant(c) => accumulate_constant(&mut inner, c),
170        };
171
172        match result {
173            Ok(false) => partial.current = Some(inner),
174            Ok(true) => {} // saturated: current stays None
175            Err(e) => {
176                partial.current = Some(inner);
177                return Err(e);
178            }
179        }
180        Ok(())
181    }
182
183    fn finalize(&self, partials: ArrayRef) -> VortexResult<ArrayRef> {
184        Ok(partials)
185    }
186
187    fn finalize_scalar(&self, partial: Scalar) -> VortexResult<Scalar> {
188        Ok(partial)
189    }
190}
191
192/// The group state for a sum aggregate, containing the accumulated value and configuration
193/// needed for reset/result without external context.
194pub struct SumPartial {
195    return_dtype: DType,
196    /// The current accumulated state, or `None` if saturated (checked overflow).
197    current: Option<SumState>,
198}
199
200/// The accumulated sum value.
201///
202// TODO(ngates): instead of an enum, we should use a Box<dyn State> to avoid dispatcher over the
203//  input type every time? Perhaps?
204pub enum SumState {
205    Unsigned(u64),
206    Signed(i64),
207    Float(f64),
208    Decimal(DecimalValue),
209}
210
211fn make_zero_state(return_dtype: &DType) -> SumState {
212    match return_dtype {
213        DType::Primitive(ptype, _) => match ptype {
214            PType::U8 | PType::U16 | PType::U32 | PType::U64 => SumState::Unsigned(0),
215            PType::I8 | PType::I16 | PType::I32 | PType::I64 => SumState::Signed(0),
216            PType::F16 | PType::F32 | PType::F64 => SumState::Float(0.0),
217        },
218        DType::Decimal(decimal, _) => SumState::Decimal(DecimalValue::zero(decimal)),
219        _ => vortex_panic!("Unsupported sum type"),
220    }
221}
222
/// Adds `val` to `*acc` unless the addition would overflow.
///
/// Returns `true` if the addition overflowed, in which case `*acc` is left unchanged.
#[inline(always)]
fn checked_add_u64(acc: &mut u64, val: u64) -> bool {
    let (sum, overflowed) = acc.overflowing_add(val);
    if !overflowed {
        *acc = sum;
    }
    overflowed
}
234
/// Adds `val` to `*acc` unless the addition would overflow in either direction.
///
/// Returns `true` if the addition overflowed, in which case `*acc` is left unchanged.
#[inline(always)]
fn checked_add_i64(acc: &mut i64, val: i64) -> bool {
    let (sum, overflowed) = acc.overflowing_add(val);
    if !overflowed {
        *acc = sum;
    }
    overflowed
}
246
247fn accumulate_primitive(inner: &mut SumState, p: &PrimitiveArray) -> VortexResult<bool> {
248    let mask = p.validity_mask()?;
249    match mask.bit_buffer() {
250        AllOr::None => Ok(false),
251        AllOr::All => accumulate_primitive_all(inner, p),
252        AllOr::Some(validity) => accumulate_primitive_valid(inner, p, validity),
253    }
254}
255
256fn accumulate_primitive_all(inner: &mut SumState, p: &PrimitiveArray) -> VortexResult<bool> {
257    match inner {
258        SumState::Unsigned(acc) => match_each_native_ptype!(p.ptype(),
259            unsigned: |T| {
260                for &v in p.as_slice::<T>() {
261                    if checked_add_u64(acc, v.to_u64().vortex_expect("unsigned to u64")) {
262                        return Ok(true);
263                    }
264                }
265                Ok(false)
266            },
267            signed: |_T| { vortex_panic!("unsigned sum state with signed input") },
268            floating: |_T| { vortex_panic!("unsigned sum state with float input") }
269        ),
270        SumState::Signed(acc) => match_each_native_ptype!(p.ptype(),
271            unsigned: |_T| { vortex_panic!("signed sum state with unsigned input") },
272            signed: |T| {
273                for &v in p.as_slice::<T>() {
274                    if checked_add_i64(acc, v.to_i64().vortex_expect("signed to i64")) {
275                        return Ok(true);
276                    }
277                }
278                Ok(false)
279            },
280            floating: |_T| { vortex_panic!("signed sum state with float input") }
281        ),
282        SumState::Float(acc) => match_each_native_ptype!(p.ptype(),
283            unsigned: |_T| { vortex_panic!("float sum state with unsigned input") },
284            signed: |_T| { vortex_panic!("float sum state with signed input") },
285            floating: |T| {
286                for &v in p.as_slice::<T>() {
287                    *acc += ToPrimitive::to_f64(&v).vortex_expect("float to f64");
288                }
289                Ok(false)
290            }
291        ),
292        SumState::Decimal(_) => vortex_panic!("decimal sum state with primitive input"),
293    }
294}
295
296fn accumulate_primitive_valid(
297    inner: &mut SumState,
298    p: &PrimitiveArray,
299    validity: &vortex_buffer::BitBuffer,
300) -> VortexResult<bool> {
301    match inner {
302        SumState::Unsigned(acc) => match_each_native_ptype!(p.ptype(),
303            unsigned: |T| {
304                for (&v, valid) in p.as_slice::<T>().iter().zip_eq(validity.iter()) {
305                    if valid && checked_add_u64(acc, v.to_u64().vortex_expect("unsigned to u64")) {
306                        return Ok(true);
307                    }
308                }
309                Ok(false)
310            },
311            signed: |_T| { vortex_panic!("unsigned sum state with signed input") },
312            floating: |_T| { vortex_panic!("unsigned sum state with float input") }
313        ),
314        SumState::Signed(acc) => match_each_native_ptype!(p.ptype(),
315            unsigned: |_T| { vortex_panic!("signed sum state with unsigned input") },
316            signed: |T| {
317                for (&v, valid) in p.as_slice::<T>().iter().zip_eq(validity.iter()) {
318                    if valid && checked_add_i64(acc, v.to_i64().vortex_expect("signed to i64")) {
319                        return Ok(true);
320                    }
321                }
322                Ok(false)
323            },
324            floating: |_T| { vortex_panic!("signed sum state with float input") }
325        ),
326        SumState::Float(acc) => match_each_native_ptype!(p.ptype(),
327            unsigned: |_T| { vortex_panic!("float sum state with unsigned input") },
328            signed: |_T| { vortex_panic!("float sum state with signed input") },
329            floating: |T| {
330                for (&v, valid) in p.as_slice::<T>().iter().zip_eq(validity.iter()) {
331                    if valid {
332                        *acc += ToPrimitive::to_f64(&v).vortex_expect("float to f64");
333                    }
334                }
335                Ok(false)
336            }
337        ),
338        SumState::Decimal(_) => vortex_panic!("decimal sum state with primitive input"),
339    }
340}
341
342fn accumulate_bool(inner: &mut SumState, b: &BoolArray) -> VortexResult<bool> {
343    let SumState::Unsigned(acc) = inner else {
344        vortex_panic!("expected unsigned sum state for bool input");
345    };
346
347    let mask = b.validity_mask()?;
348    let true_count = match mask.bit_buffer() {
349        AllOr::None => return Ok(false),
350        AllOr::All => b.to_bit_buffer().true_count() as u64,
351        AllOr::Some(validity) => b.to_bit_buffer().bitand(validity).true_count() as u64,
352    };
353
354    Ok(checked_add_u64(acc, true_count))
355}
356
357/// Accumulate a constant array into the sum state.
358/// Computes `scalar * len` and adds to the accumulator.
359/// Returns Ok(true) if saturated (overflow), Ok(false) if not.
360fn accumulate_constant(inner: &mut SumState, c: &ConstantArray) -> VortexResult<bool> {
361    let scalar = c.scalar();
362    if scalar.is_null() || c.is_empty() {
363        return Ok(false);
364    }
365    let len = c.len();
366
367    match scalar.dtype() {
368        DType::Bool(_) => {
369            let SumState::Unsigned(acc) = inner else {
370                vortex_panic!("expected unsigned sum state for bool input");
371            };
372            let val = scalar
373                .as_bool()
374                .value()
375                .ok_or_else(|| vortex_err!("Expected non-null bool scalar for sum"))?;
376            if val {
377                Ok(checked_add_u64(acc, len as u64))
378            } else {
379                Ok(false)
380            }
381        }
382        DType::Primitive(..) => {
383            let pvalue = scalar
384                .as_primitive()
385                .pvalue()
386                .ok_or_else(|| vortex_err!("Expected non-null primitive scalar for sum"))?;
387            match inner {
388                SumState::Unsigned(acc) => {
389                    let val = pvalue.cast::<u64>()?;
390                    match val.checked_mul(len as u64) {
391                        Some(product) => Ok(checked_add_u64(acc, product)),
392                        None => Ok(true),
393                    }
394                }
395                SumState::Signed(acc) => {
396                    let val = pvalue.cast::<i64>()?;
397                    match i64::try_from(len).ok().and_then(|l| val.checked_mul(l)) {
398                        Some(product) => Ok(checked_add_i64(acc, product)),
399                        None => Ok(true),
400                    }
401                }
402                SumState::Float(acc) => {
403                    let val = pvalue.cast::<f64>()?;
404                    *acc += val * len as f64;
405                    Ok(false)
406                }
407                SumState::Decimal(_) => {
408                    vortex_panic!("decimal sum state with primitive input")
409                }
410            }
411        }
412        DType::Decimal(..) => {
413            let SumState::Decimal(acc) = inner else {
414                vortex_panic!("expected decimal sum state for decimal input");
415            };
416            let val = scalar
417                .as_decimal()
418                .decimal_value()
419                .ok_or_else(|| vortex_err!("Expected non-null decimal scalar for sum"))?;
420            let len_decimal = DecimalValue::from(len as i128);
421            match val.checked_mul(&len_decimal) {
422                Some(product) => match acc.checked_add(&product) {
423                    Some(r) => {
424                        *acc = r;
425                        Ok(false)
426                    }
427                    None => Ok(true),
428                },
429                None => Ok(true),
430            }
431        }
432        _ => vortex_bail!("Unsupported constant type for sum: {}", scalar.dtype()),
433    }
434}
435
436/// Accumulate a decimal array into the sum state.
437/// Returns Ok(true) if saturated (overflow), Ok(false) if not.
438fn accumulate_decimal(inner: &mut SumState, d: &DecimalArray) -> VortexResult<bool> {
439    let SumState::Decimal(acc) = inner else {
440        vortex_panic!("expected decimal sum state for decimal input");
441    };
442
443    let mask = d.validity_mask()?;
444    match mask.bit_buffer() {
445        AllOr::None => Ok(false),
446        AllOr::All => match_each_decimal_value_type!(d.values_type(), |T| {
447            for &v in d.buffer::<T>().iter() {
448                match acc.checked_add(&DecimalValue::from(v)) {
449                    Some(r) => *acc = r,
450                    None => return Ok(true),
451                }
452            }
453            Ok(false)
454        }),
455        AllOr::Some(validity) => match_each_decimal_value_type!(d.values_type(), |T| {
456            for (&v, valid) in d.buffer::<T>().iter().zip_eq(validity.iter()) {
457                if valid {
458                    match acc.checked_add(&DecimalValue::from(v)) {
459                        Some(r) => *acc = r,
460                        None => return Ok(true),
461                    }
462                }
463            }
464            Ok(false)
465        }),
466    }
467}
468
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;
    use vortex_session::VortexSession;

    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::aggregate_fn::Accumulator;
    use crate::aggregate_fn::AggregateFnVTable;
    use crate::aggregate_fn::DynAccumulator;
    use crate::aggregate_fn::DynGroupedAccumulator;
    use crate::aggregate_fn::EmptyOptions;
    use crate::aggregate_fn::GroupedAccumulator;
    use crate::aggregate_fn::fns::sum::Sum;
    use crate::arrays::BoolArray;
    use crate::arrays::FixedSizeListArray;
    use crate::arrays::PrimitiveArray;
    use crate::assert_arrays_eq;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::scalar::Scalar;
    use crate::validity::Validity;

    /// Creates the empty session handed to every accumulator in these tests.
    fn session() -> VortexSession {
        VortexSession::empty()
    }

    /// Sums a single batch with a fresh accumulator and returns the finished scalar.
    fn run_sum(batch: &ArrayRef) -> VortexResult<Scalar> {
        let input_dtype = batch.dtype().clone();
        let mut accumulator = Accumulator::try_new(Sum, EmptyOptions, input_dtype, session())?;
        accumulator.accumulate(batch)?;
        accumulator.finish()
    }

    // ---- Primitive inputs ----

    #[test]
    fn sum_i32() -> VortexResult<()> {
        let input = PrimitiveArray::new(buffer![1i32, 2, 3, 4], Validity::NonNullable);
        let total = run_sum(&input.into_array())?;
        assert_eq!(total.as_primitive().typed_value::<i64>(), Some(10));
        Ok(())
    }

    #[test]
    fn sum_u8() -> VortexResult<()> {
        let input = PrimitiveArray::new(buffer![10u8, 20, 30], Validity::NonNullable);
        let total = run_sum(&input.into_array())?;
        assert_eq!(total.as_primitive().typed_value::<u64>(), Some(60));
        Ok(())
    }

    #[test]
    fn sum_f64() -> VortexResult<()> {
        let input = PrimitiveArray::new(buffer![1.5f64, 2.5, 3.0], Validity::NonNullable);
        let total = run_sum(&input.into_array())?;
        assert_eq!(total.as_primitive().typed_value::<f64>(), Some(7.0));
        Ok(())
    }

    #[test]
    fn sum_with_nulls() -> VortexResult<()> {
        let input = PrimitiveArray::from_option_iter([Some(2i32), None, Some(4)]);
        let total = run_sum(&input.into_array())?;
        assert_eq!(total.as_primitive().typed_value::<i64>(), Some(6));
        Ok(())
    }

    #[test]
    fn sum_all_null() -> VortexResult<()> {
        // Arrow semantics: the sum of only nulls is zero, the identity element.
        let input = PrimitiveArray::from_option_iter([None::<i32>, None, None]);
        let total = run_sum(&input.into_array())?;
        assert_eq!(total.as_primitive().typed_value::<i64>(), Some(0));
        Ok(())
    }

    // ---- Accumulators that never saw a batch ----

    #[test]
    fn sum_empty_produces_zero() -> VortexResult<()> {
        let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let mut accumulator = Accumulator::try_new(Sum, EmptyOptions, input_dtype, session())?;
        let total = accumulator.finish()?;
        assert_eq!(total.as_primitive().typed_value::<i64>(), Some(0));
        Ok(())
    }

    #[test]
    fn sum_empty_f64_produces_zero() -> VortexResult<()> {
        let input_dtype = DType::Primitive(PType::F64, Nullability::NonNullable);
        let mut accumulator = Accumulator::try_new(Sum, EmptyOptions, input_dtype, session())?;
        let total = accumulator.finish()?;
        assert_eq!(total.as_primitive().typed_value::<f64>(), Some(0.0));
        Ok(())
    }

    // ---- Several batches and reset on finish ----

    #[test]
    fn sum_multi_batch() -> VortexResult<()> {
        let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let mut accumulator = Accumulator::try_new(Sum, EmptyOptions, input_dtype, session())?;

        let first = PrimitiveArray::new(buffer![10i32, 20], Validity::NonNullable).into_array();
        accumulator.accumulate(&first)?;

        let second = PrimitiveArray::new(buffer![3i32, 6, 9], Validity::NonNullable).into_array();
        accumulator.accumulate(&second)?;

        let total = accumulator.finish()?;
        assert_eq!(total.as_primitive().typed_value::<i64>(), Some(48));
        Ok(())
    }

    #[test]
    fn sum_finish_resets_state() -> VortexResult<()> {
        let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let mut accumulator = Accumulator::try_new(Sum, EmptyOptions, input_dtype, session())?;

        let first = PrimitiveArray::new(buffer![10i32, 20], Validity::NonNullable).into_array();
        accumulator.accumulate(&first)?;
        let first_total = accumulator.finish()?;
        assert_eq!(first_total.as_primitive().typed_value::<i64>(), Some(30));

        // The second sum must start from zero again.
        let second = PrimitiveArray::new(buffer![3i32, 6, 9], Validity::NonNullable).into_array();
        accumulator.accumulate(&second)?;
        let second_total = accumulator.finish()?;
        assert_eq!(second_total.as_primitive().typed_value::<i64>(), Some(18));
        Ok(())
    }

    // ---- Merging partials directly through the vtable ----

    #[test]
    fn sum_state_merge() -> VortexResult<()> {
        let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let mut partial = Sum.empty_partial(&EmptyOptions, &input_dtype)?;

        Sum.combine_partials(
            &mut partial,
            Scalar::primitive(100i64, Nullability::Nullable),
        )?;
        Sum.combine_partials(
            &mut partial,
            Scalar::primitive(50i64, Nullability::Nullable),
        )?;

        let merged = Sum.flush(&mut partial)?;
        assert_eq!(merged.as_primitive().typed_value::<i64>(), Some(150));
        Ok(())
    }

    // ---- Overflow ----

    #[test]
    fn sum_checked_overflow() -> VortexResult<()> {
        let input = PrimitiveArray::new(buffer![i64::MAX, 1i64], Validity::NonNullable);
        let total = run_sum(&input.into_array())?;
        assert!(total.is_null());
        Ok(())
    }

    #[test]
    fn sum_checked_overflow_is_saturated() -> VortexResult<()> {
        let input_dtype = DType::Primitive(PType::I64, Nullability::NonNullable);
        let mut accumulator = Accumulator::try_new(Sum, EmptyOptions, input_dtype, session())?;
        assert!(!accumulator.is_saturated());

        let overflowing =
            PrimitiveArray::new(buffer![i64::MAX, 1i64], Validity::NonNullable).into_array();
        accumulator.accumulate(&overflowing)?;
        assert!(accumulator.is_saturated());

        // Finishing resets the state, which also clears saturation.
        drop(accumulator.finish()?);
        assert!(!accumulator.is_saturated());
        Ok(())
    }

    // ---- Boolean inputs ----

    #[test]
    fn sum_bool_all_true() -> VortexResult<()> {
        let input = [true, true, true].into_iter().collect::<BoolArray>();
        let total = run_sum(&input.into_array())?;
        assert_eq!(total.as_primitive().typed_value::<u64>(), Some(3));
        Ok(())
    }

    #[test]
    fn sum_bool_mixed() -> VortexResult<()> {
        let input = [true, false, true, false, true]
            .into_iter()
            .collect::<BoolArray>();
        let total = run_sum(&input.into_array())?;
        assert_eq!(total.as_primitive().typed_value::<u64>(), Some(3));
        Ok(())
    }

    #[test]
    fn sum_bool_all_false() -> VortexResult<()> {
        let input = [false, false, false].into_iter().collect::<BoolArray>();
        let total = run_sum(&input.into_array())?;
        assert_eq!(total.as_primitive().typed_value::<u64>(), Some(0));
        Ok(())
    }

    #[test]
    fn sum_bool_with_nulls() -> VortexResult<()> {
        let input = BoolArray::from_iter([Some(true), None, Some(true), Some(false)]);
        let total = run_sum(&input.into_array())?;
        assert_eq!(total.as_primitive().typed_value::<u64>(), Some(2));
        Ok(())
    }

    #[test]
    fn sum_bool_all_null() -> VortexResult<()> {
        // Arrow semantics: the sum of only nulls is zero, the identity element.
        let input = BoolArray::from_iter([None::<bool>, None, None]);
        let total = run_sum(&input.into_array())?;
        assert_eq!(total.as_primitive().typed_value::<u64>(), Some(0));
        Ok(())
    }

    #[test]
    fn sum_bool_empty_produces_zero() -> VortexResult<()> {
        let input_dtype = DType::Bool(Nullability::NonNullable);
        let mut accumulator = Accumulator::try_new(Sum, EmptyOptions, input_dtype, session())?;
        let total = accumulator.finish()?;
        assert_eq!(total.as_primitive().typed_value::<u64>(), Some(0));
        Ok(())
    }

    #[test]
    fn sum_bool_finish_resets_state() -> VortexResult<()> {
        let input_dtype = DType::Bool(Nullability::NonNullable);
        let mut accumulator = Accumulator::try_new(Sum, EmptyOptions, input_dtype, session())?;

        let first = [true, true, false].into_iter().collect::<BoolArray>();
        accumulator.accumulate(&first.into_array())?;
        let first_total = accumulator.finish()?;
        assert_eq!(first_total.as_primitive().typed_value::<u64>(), Some(2));

        // The second sum must start from zero again.
        let second = [false, true].into_iter().collect::<BoolArray>();
        accumulator.accumulate(&second.into_array())?;
        let second_total = accumulator.finish()?;
        assert_eq!(second_total.as_primitive().typed_value::<u64>(), Some(1));
        Ok(())
    }

    #[test]
    fn sum_bool_return_dtype() -> VortexResult<()> {
        let bool_dtype = DType::Bool(Nullability::NonNullable);
        let sum_dtype = Sum.return_dtype(&EmptyOptions, &bool_dtype)?;
        assert_eq!(
            sum_dtype,
            DType::Primitive(PType::U64, Nullability::Nullable)
        );
        Ok(())
    }

    // ---- Grouped sums ----

    /// Sums every list of `groups` with a fresh grouped accumulator over `elem_dtype`.
    fn run_grouped_sum(groups: &ArrayRef, elem_dtype: &DType) -> VortexResult<ArrayRef> {
        let mut accumulator =
            GroupedAccumulator::try_new(Sum, EmptyOptions, elem_dtype.clone(), session())?;
        accumulator.accumulate_list(groups)?;
        accumulator.finish()
    }

    #[test]
    fn grouped_sum_fixed_size_list() -> VortexResult<()> {
        // Groups: [[1,2,3], [4,5,6]] -> sums [6, 15]
        let values =
            PrimitiveArray::new(buffer![1i32, 2, 3, 4, 5, 6], Validity::NonNullable).into_array();
        let lists = FixedSizeListArray::try_new(values, 3, Validity::NonNullable, 2)?;

        let elem_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let sums = run_grouped_sum(&lists.into_array(), &elem_dtype)?;

        let expected = PrimitiveArray::from_option_iter([Some(6i64), Some(15i64)]).into_array();
        assert_arrays_eq!(&sums, &expected);
        Ok(())
    }

    #[test]
    fn grouped_sum_with_null_elements() -> VortexResult<()> {
        // Groups: [[Some(1), None, Some(3)], [None, Some(5), Some(6)]] -> sums [4, 11]
        let values =
            PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), None, Some(5), Some(6)])
                .into_array();
        let lists = FixedSizeListArray::try_new(values, 3, Validity::NonNullable, 2)?;

        let elem_dtype = DType::Primitive(PType::I32, Nullability::Nullable);
        let sums = run_grouped_sum(&lists.into_array(), &elem_dtype)?;

        let expected = PrimitiveArray::from_option_iter([Some(4i64), Some(11i64)]).into_array();
        assert_arrays_eq!(&sums, &expected);
        Ok(())
    }

    #[test]
    fn grouped_sum_with_null_group() -> VortexResult<()> {
        // Groups: [[1,2,3], null, [7,8,9]] -> sums [6, null, 24]
        let values =
            PrimitiveArray::new(buffer![1i32, 2, 3, 4, 5, 6, 7, 8, 9], Validity::NonNullable)
                .into_array();
        let group_validity = Validity::from_iter([true, false, true]);
        let lists = FixedSizeListArray::try_new(values, 3, group_validity, 3)?;

        let elem_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let sums = run_grouped_sum(&lists.into_array(), &elem_dtype)?;

        let expected =
            PrimitiveArray::from_option_iter([Some(6i64), None, Some(24i64)]).into_array();
        assert_arrays_eq!(&sums, &expected);
        Ok(())
    }

    #[test]
    fn grouped_sum_all_null_elements_in_group() -> VortexResult<()> {
        // Groups: [[None, None], [Some(3), Some(4)]] -> sums [0, 7] (Arrow semantics)
        let values =
            PrimitiveArray::from_option_iter([None::<i32>, None, Some(3), Some(4)]).into_array();
        let lists = FixedSizeListArray::try_new(values, 2, Validity::NonNullable, 2)?;

        let elem_dtype = DType::Primitive(PType::I32, Nullability::Nullable);
        let sums = run_grouped_sum(&lists.into_array(), &elem_dtype)?;

        let expected = PrimitiveArray::from_option_iter([Some(0i64), Some(7i64)]).into_array();
        assert_arrays_eq!(&sums, &expected);
        Ok(())
    }

    #[test]
    fn grouped_sum_bool() -> VortexResult<()> {
        // Groups: [[true, false, true], [true, true, true]] -> sums [2, 3]
        let values = [true, false, true, true, true, true]
            .into_iter()
            .collect::<BoolArray>();
        let lists = FixedSizeListArray::try_new(values.into_array(), 3, Validity::NonNullable, 2)?;

        let elem_dtype = DType::Bool(Nullability::NonNullable);
        let sums = run_grouped_sum(&lists.into_array(), &elem_dtype)?;

        let expected = PrimitiveArray::from_option_iter([Some(2u64), Some(3u64)]).into_array();
        assert_arrays_eq!(&sums, &expected);
        Ok(())
    }

    #[test]
    fn grouped_sum_finish_resets() -> VortexResult<()> {
        let elem_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let mut accumulator =
            GroupedAccumulator::try_new(Sum, EmptyOptions, elem_dtype, session())?;

        // First batch: [[1, 2], [3, 4]]
        let first_values =
            PrimitiveArray::new(buffer![1i32, 2, 3, 4], Validity::NonNullable).into_array();
        let first_lists = FixedSizeListArray::try_new(first_values, 2, Validity::NonNullable, 2)?;
        accumulator.accumulate_list(&first_lists.into_array())?;
        let first_sums = accumulator.finish()?;

        let first_expected =
            PrimitiveArray::from_option_iter([Some(3i64), Some(7i64)]).into_array();
        assert_arrays_eq!(&first_sums, &first_expected);

        // Second batch after reset: [[10, 20]]
        let second_values =
            PrimitiveArray::new(buffer![10i32, 20], Validity::NonNullable).into_array();
        let second_lists =
            FixedSizeListArray::try_new(second_values, 2, Validity::NonNullable, 1)?;
        accumulator.accumulate_list(&second_lists.into_array())?;
        let second_sums = accumulator.finish()?;

        let second_expected = PrimitiveArray::from_option_iter([Some(30i64)]).into_array();
        assert_arrays_eq!(&second_sums, &second_expected);
        Ok(())
    }
}