Skip to main content

vortex_array/aggregate_fn/fns/
sum.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::ops::BitAnd;
5
6use itertools::Itertools;
7use num_traits::ToPrimitive;
8use vortex_error::VortexExpect;
9use vortex_error::VortexResult;
10use vortex_error::vortex_bail;
11use vortex_error::vortex_err;
12use vortex_error::vortex_panic;
13use vortex_mask::AllOr;
14
15use crate::ArrayRef;
16use crate::Canonical;
17use crate::ExecutionCtx;
18use crate::aggregate_fn::AggregateFnId;
19use crate::aggregate_fn::AggregateFnVTable;
20use crate::aggregate_fn::EmptyOptions;
21use crate::arrays::BoolArray;
22use crate::arrays::DecimalArray;
23use crate::arrays::PrimitiveArray;
24use crate::dtype::DType;
25use crate::dtype::Nullability;
26use crate::dtype::PType;
27use crate::expr::stats::Stat;
28use crate::match_each_decimal_value_type;
29use crate::match_each_native_ptype;
30use crate::scalar::DecimalValue;
31use crate::scalar::Scalar;
32
/// The `vortex.sum` aggregate function: adds up the non-null values of its input.
///
/// Integer and decimal sums use checked arithmetic and produce a null result on overflow.
#[derive(Debug, Clone)]
pub struct Sum;
35
36impl AggregateFnVTable for Sum {
37    type Options = EmptyOptions;
38    type Partial = SumPartial;
39
40    fn id(&self) -> AggregateFnId {
41        AggregateFnId::new_ref("vortex.sum")
42    }
43
44    fn return_dtype(&self, _options: &Self::Options, input_dtype: &DType) -> VortexResult<DType> {
45        Stat::Sum
46            .dtype(input_dtype)
47            .ok_or_else(|| vortex_err!("Cannot sum {}", input_dtype))
48    }
49
50    fn partial_dtype(&self, options: &Self::Options, input_dtype: &DType) -> VortexResult<DType> {
51        self.return_dtype(options, input_dtype)
52    }
53
54    fn empty_partial(
55        &self,
56        _options: &Self::Options,
57        input_dtype: &DType,
58    ) -> VortexResult<Self::Partial> {
59        let return_dtype = Stat::Sum
60            .dtype(input_dtype)
61            .ok_or_else(|| vortex_err!("Cannot sum {}", input_dtype))?;
62
63        let initial = make_zero_state(&return_dtype);
64
65        Ok(SumPartial {
66            return_dtype,
67            current: Some(initial),
68        })
69    }
70
71    fn combine_partials(&self, partial: &mut Self::Partial, other: Scalar) -> VortexResult<()> {
72        if other.is_null() {
73            // A null partial means the sub-accumulator saturated (overflow).
74            partial.current = None;
75            return Ok(());
76        }
77        let Some(ref mut inner) = partial.current else {
78            return Ok(());
79        };
80        let saturated = match inner {
81            SumState::Unsigned(acc) => {
82                let val = other
83                    .as_primitive()
84                    .typed_value::<u64>()
85                    .vortex_expect("checked non-null");
86                checked_add_u64(acc, val)
87            }
88            SumState::Signed(acc) => {
89                let val = other
90                    .as_primitive()
91                    .typed_value::<i64>()
92                    .vortex_expect("checked non-null");
93                checked_add_i64(acc, val)
94            }
95            SumState::Float(acc) => {
96                let val = other
97                    .as_primitive()
98                    .typed_value::<f64>()
99                    .vortex_expect("checked non-null");
100                *acc += val;
101                false
102            }
103            SumState::Decimal(acc) => {
104                let val = other
105                    .as_decimal()
106                    .decimal_value()
107                    .vortex_expect("checked non-null");
108                match acc.checked_add(&val) {
109                    Some(r) => {
110                        *acc = r;
111                        false
112                    }
113                    None => true,
114                }
115            }
116        };
117        if saturated {
118            partial.current = None;
119        }
120        Ok(())
121    }
122
123    fn flush(&self, partial: &mut Self::Partial) -> VortexResult<Scalar> {
124        let result = match &partial.current {
125            None => Scalar::null(partial.return_dtype.as_nullable()),
126            Some(SumState::Unsigned(v)) => Scalar::primitive(*v, Nullability::Nullable),
127            Some(SumState::Signed(v)) => Scalar::primitive(*v, Nullability::Nullable),
128            Some(SumState::Float(v)) => Scalar::primitive(*v, Nullability::Nullable),
129            Some(SumState::Decimal(v)) => {
130                let decimal_dtype = *partial
131                    .return_dtype
132                    .as_decimal_opt()
133                    .vortex_expect("return dtype must be decimal");
134                Scalar::decimal(*v, decimal_dtype, Nullability::Nullable)
135            }
136        };
137
138        // Reset the state
139        partial.current = Some(make_zero_state(&partial.return_dtype));
140
141        Ok(result)
142    }
143
144    #[inline]
145    fn is_saturated(&self, partial: &Self::Partial) -> bool {
146        partial.current.is_none()
147    }
148
149    fn accumulate(
150        &self,
151        partial: &mut Self::Partial,
152        batch: &Canonical,
153        _ctx: &mut ExecutionCtx,
154    ) -> VortexResult<()> {
155        let mut inner = match partial.current.take() {
156            Some(inner) => inner,
157            None => return Ok(()),
158        };
159
160        let result = match batch {
161            Canonical::Primitive(p) => accumulate_primitive(&mut inner, p),
162            Canonical::Bool(b) => accumulate_bool(&mut inner, b),
163            Canonical::Decimal(d) => accumulate_decimal(&mut inner, d),
164            _ => vortex_bail!("Unsupported canonical type for sum: {}", batch.dtype()),
165        };
166
167        match result {
168            Ok(false) => partial.current = Some(inner),
169            Ok(true) => {} // saturated: current stays None
170            Err(e) => {
171                partial.current = Some(inner);
172                return Err(e);
173            }
174        }
175        Ok(())
176    }
177
178    fn finalize(&self, partials: ArrayRef) -> VortexResult<ArrayRef> {
179        Ok(partials)
180    }
181
182    fn finalize_scalar(&self, partial: Scalar) -> VortexResult<Scalar> {
183        Ok(partial)
184    }
185}
186
/// Partial (intermediate) state for the [`Sum`] aggregate.
///
/// Carries the result dtype alongside the running total so the state can be flushed to a
/// [`Scalar`] and reset to zero without any external context.
pub struct SumPartial {
    /// The dtype of the flushed result; also used to rebuild the zero state on reset.
    return_dtype: DType,
    /// The current accumulated state, or `None` once saturated: either a checked addition
    /// overflowed, or a null (saturated) partial was merged in.
    current: Option<SumState>,
}
194
/// The running total of a sum aggregate, held in the accumulator type chosen from the
/// sum's result dtype.
///
// TODO(ngates): instead of an enum, we should use a Box<dyn State> to avoid dispatching over the
//  input type every time? Perhaps?
pub enum SumState {
    /// Total for unsigned-integer results (also used for boolean inputs, counting `true`s).
    Unsigned(u64),
    /// Total for signed-integer results.
    Signed(i64),
    /// Total for floating-point results, accumulated in `f64`.
    Float(f64),
    /// Total for decimal results.
    Decimal(DecimalValue),
}
205
206fn make_zero_state(return_dtype: &DType) -> SumState {
207    match return_dtype {
208        DType::Primitive(ptype, _) => match ptype {
209            PType::U8 | PType::U16 | PType::U32 | PType::U64 => SumState::Unsigned(0),
210            PType::I8 | PType::I16 | PType::I32 | PType::I64 => SumState::Signed(0),
211            PType::F16 | PType::F32 | PType::F64 => SumState::Float(0.0),
212        },
213        DType::Decimal(decimal, _) => SumState::Decimal(DecimalValue::zero(decimal)),
214        _ => vortex_panic!("Unsupported sum type"),
215    }
216}
217
/// Adds `val` into `acc` with overflow checking.
///
/// On success the sum is stored in `acc` and `false` is returned. On overflow `acc` is left
/// untouched and `true` is returned.
#[inline(always)]
fn checked_add_u64(acc: &mut u64, val: u64) -> bool {
    let Some(sum) = acc.checked_add(val) else {
        return true;
    };
    *acc = sum;
    false
}
229
/// Adds `val` into `acc` with overflow checking.
///
/// On success the sum is stored in `acc` and `false` is returned. On overflow (in either
/// direction) `acc` is left untouched and `true` is returned.
#[inline(always)]
fn checked_add_i64(acc: &mut i64, val: i64) -> bool {
    if let Some(sum) = acc.checked_add(val) {
        *acc = sum;
        false
    } else {
        true
    }
}
241
242fn accumulate_primitive(inner: &mut SumState, p: &PrimitiveArray) -> VortexResult<bool> {
243    let mask = p.validity_mask()?;
244    match mask.bit_buffer() {
245        AllOr::None => Ok(false),
246        AllOr::All => accumulate_primitive_all(inner, p),
247        AllOr::Some(validity) => accumulate_primitive_valid(inner, p, validity),
248    }
249}
250
251fn accumulate_primitive_all(inner: &mut SumState, p: &PrimitiveArray) -> VortexResult<bool> {
252    match inner {
253        SumState::Unsigned(acc) => match_each_native_ptype!(p.ptype(),
254            unsigned: |T| {
255                for &v in p.as_slice::<T>() {
256                    if checked_add_u64(acc, v.to_u64().vortex_expect("unsigned to u64")) {
257                        return Ok(true);
258                    }
259                }
260                Ok(false)
261            },
262            signed: |_T| { vortex_panic!("unsigned sum state with signed input") },
263            floating: |_T| { vortex_panic!("unsigned sum state with float input") }
264        ),
265        SumState::Signed(acc) => match_each_native_ptype!(p.ptype(),
266            unsigned: |_T| { vortex_panic!("signed sum state with unsigned input") },
267            signed: |T| {
268                for &v in p.as_slice::<T>() {
269                    if checked_add_i64(acc, v.to_i64().vortex_expect("signed to i64")) {
270                        return Ok(true);
271                    }
272                }
273                Ok(false)
274            },
275            floating: |_T| { vortex_panic!("signed sum state with float input") }
276        ),
277        SumState::Float(acc) => match_each_native_ptype!(p.ptype(),
278            unsigned: |_T| { vortex_panic!("float sum state with unsigned input") },
279            signed: |_T| { vortex_panic!("float sum state with signed input") },
280            floating: |T| {
281                for &v in p.as_slice::<T>() {
282                    *acc += ToPrimitive::to_f64(&v).vortex_expect("float to f64");
283                }
284                Ok(false)
285            }
286        ),
287        SumState::Decimal(_) => vortex_panic!("decimal sum state with primitive input"),
288    }
289}
290
291fn accumulate_primitive_valid(
292    inner: &mut SumState,
293    p: &PrimitiveArray,
294    validity: &vortex_buffer::BitBuffer,
295) -> VortexResult<bool> {
296    match inner {
297        SumState::Unsigned(acc) => match_each_native_ptype!(p.ptype(),
298            unsigned: |T| {
299                for (&v, valid) in p.as_slice::<T>().iter().zip_eq(validity.iter()) {
300                    if valid && checked_add_u64(acc, v.to_u64().vortex_expect("unsigned to u64")) {
301                        return Ok(true);
302                    }
303                }
304                Ok(false)
305            },
306            signed: |_T| { vortex_panic!("unsigned sum state with signed input") },
307            floating: |_T| { vortex_panic!("unsigned sum state with float input") }
308        ),
309        SumState::Signed(acc) => match_each_native_ptype!(p.ptype(),
310            unsigned: |_T| { vortex_panic!("signed sum state with unsigned input") },
311            signed: |T| {
312                for (&v, valid) in p.as_slice::<T>().iter().zip_eq(validity.iter()) {
313                    if valid && checked_add_i64(acc, v.to_i64().vortex_expect("signed to i64")) {
314                        return Ok(true);
315                    }
316                }
317                Ok(false)
318            },
319            floating: |_T| { vortex_panic!("signed sum state with float input") }
320        ),
321        SumState::Float(acc) => match_each_native_ptype!(p.ptype(),
322            unsigned: |_T| { vortex_panic!("float sum state with unsigned input") },
323            signed: |_T| { vortex_panic!("float sum state with signed input") },
324            floating: |T| {
325                for (&v, valid) in p.as_slice::<T>().iter().zip_eq(validity.iter()) {
326                    if valid {
327                        *acc += ToPrimitive::to_f64(&v).vortex_expect("float to f64");
328                    }
329                }
330                Ok(false)
331            }
332        ),
333        SumState::Decimal(_) => vortex_panic!("decimal sum state with primitive input"),
334    }
335}
336
337fn accumulate_bool(inner: &mut SumState, b: &BoolArray) -> VortexResult<bool> {
338    let SumState::Unsigned(acc) = inner else {
339        vortex_panic!("expected unsigned sum state for bool input");
340    };
341
342    let mask = b.validity_mask()?;
343    let true_count = match mask.bit_buffer() {
344        AllOr::None => return Ok(false),
345        AllOr::All => b.to_bit_buffer().true_count() as u64,
346        AllOr::Some(validity) => b.to_bit_buffer().bitand(validity).true_count() as u64,
347    };
348
349    Ok(checked_add_u64(acc, true_count))
350}
351
352/// Accumulate a decimal array into the sum state.
353/// Returns Ok(true) if saturated (overflow), Ok(false) if not.
354fn accumulate_decimal(inner: &mut SumState, d: &DecimalArray) -> VortexResult<bool> {
355    let SumState::Decimal(acc) = inner else {
356        vortex_panic!("expected decimal sum state for decimal input");
357    };
358
359    let mask = d.validity_mask()?;
360    match mask.bit_buffer() {
361        AllOr::None => Ok(false),
362        AllOr::All => match_each_decimal_value_type!(d.values_type(), |T| {
363            for &v in d.buffer::<T>().iter() {
364                match acc.checked_add(&DecimalValue::from(v)) {
365                    Some(r) => *acc = r,
366                    None => return Ok(true),
367                }
368            }
369            Ok(false)
370        }),
371        AllOr::Some(validity) => match_each_decimal_value_type!(d.values_type(), |T| {
372            for (&v, valid) in d.buffer::<T>().iter().zip_eq(validity.iter()) {
373                if valid {
374                    match acc.checked_add(&DecimalValue::from(v)) {
375                        Some(r) => *acc = r,
376                        None => return Ok(true),
377                    }
378                }
379            }
380            Ok(false)
381        }),
382    }
383}
384
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;
    use vortex_session::VortexSession;

    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::aggregate_fn::Accumulator;
    use crate::aggregate_fn::AggregateFnVTable;
    use crate::aggregate_fn::DynAccumulator;
    use crate::aggregate_fn::DynGroupedAccumulator;
    use crate::aggregate_fn::EmptyOptions;
    use crate::aggregate_fn::GroupedAccumulator;
    use crate::aggregate_fn::fns::sum::Sum;
    use crate::arrays::BoolArray;
    use crate::arrays::FixedSizeListArray;
    use crate::arrays::PrimitiveArray;
    use crate::assert_arrays_eq;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::scalar::Scalar;
    use crate::validity::Validity;

    fn session() -> VortexSession {
        VortexSession::empty()
    }

    /// Sums a single batch through a fresh ungrouped accumulator.
    fn run_sum(batch: &ArrayRef) -> VortexResult<Scalar> {
        let mut acc = Accumulator::try_new(Sum, EmptyOptions, batch.dtype().clone(), session())?;
        acc.accumulate(batch)?;
        acc.finish()
    }

    // Primitive sum tests

    #[test]
    fn sum_i32() -> VortexResult<()> {
        let arr = PrimitiveArray::new(buffer![1i32, 2, 3, 4], Validity::NonNullable).into_array();
        let result = run_sum(&arr)?;
        assert_eq!(result.as_primitive().typed_value::<i64>(), Some(10));
        Ok(())
    }

    #[test]
    fn sum_u8() -> VortexResult<()> {
        let arr = PrimitiveArray::new(buffer![10u8, 20, 30], Validity::NonNullable).into_array();
        let result = run_sum(&arr)?;
        assert_eq!(result.as_primitive().typed_value::<u64>(), Some(60));
        Ok(())
    }

    #[test]
    fn sum_u8_widens_past_input_range() -> VortexResult<()> {
        // The sum of u8 values is accumulated in u64, so exceeding u8::MAX is not an overflow.
        let arr = PrimitiveArray::new(buffer![200u8, 100], Validity::NonNullable).into_array();
        let result = run_sum(&arr)?;
        assert_eq!(result.as_primitive().typed_value::<u64>(), Some(300));
        Ok(())
    }

    #[test]
    fn sum_f64() -> VortexResult<()> {
        let arr =
            PrimitiveArray::new(buffer![1.5f64, 2.5, 3.0], Validity::NonNullable).into_array();
        let result = run_sum(&arr)?;
        assert_eq!(result.as_primitive().typed_value::<f64>(), Some(7.0));
        Ok(())
    }

    #[test]
    fn sum_with_nulls() -> VortexResult<()> {
        let arr = PrimitiveArray::from_option_iter([Some(2i32), None, Some(4)]).into_array();
        let result = run_sum(&arr)?;
        assert_eq!(result.as_primitive().typed_value::<i64>(), Some(6));
        Ok(())
    }

    #[test]
    fn sum_f64_with_nulls() -> VortexResult<()> {
        let arr = PrimitiveArray::from_option_iter([Some(1.5f64), None, Some(2.5)]).into_array();
        let result = run_sum(&arr)?;
        assert_eq!(result.as_primitive().typed_value::<f64>(), Some(4.0));
        Ok(())
    }

    #[test]
    fn sum_all_null() -> VortexResult<()> {
        // Arrow semantics: sum of all nulls is zero (identity element)
        let arr = PrimitiveArray::from_option_iter([None::<i32>, None, None]).into_array();
        let result = run_sum(&arr)?;
        assert_eq!(result.as_primitive().typed_value::<i64>(), Some(0));
        Ok(())
    }

    // Empty accumulator tests

    #[test]
    fn sum_empty_produces_zero() -> VortexResult<()> {
        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let mut acc = Accumulator::try_new(Sum, EmptyOptions, dtype, session())?;
        let result = acc.finish()?;
        assert_eq!(result.as_primitive().typed_value::<i64>(), Some(0));
        Ok(())
    }

    #[test]
    fn sum_empty_f64_produces_zero() -> VortexResult<()> {
        let dtype = DType::Primitive(PType::F64, Nullability::NonNullable);
        let mut acc = Accumulator::try_new(Sum, EmptyOptions, dtype, session())?;
        let result = acc.finish()?;
        assert_eq!(result.as_primitive().typed_value::<f64>(), Some(0.0));
        Ok(())
    }

    // Multi-batch and reset tests

    #[test]
    fn sum_multi_batch() -> VortexResult<()> {
        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let mut acc = Accumulator::try_new(Sum, EmptyOptions, dtype, session())?;

        let batch1 = PrimitiveArray::new(buffer![10i32, 20], Validity::NonNullable).into_array();
        acc.accumulate(&batch1)?;

        let batch2 = PrimitiveArray::new(buffer![3i32, 6, 9], Validity::NonNullable).into_array();
        acc.accumulate(&batch2)?;

        let result = acc.finish()?;
        assert_eq!(result.as_primitive().typed_value::<i64>(), Some(48));
        Ok(())
    }

    #[test]
    fn sum_finish_resets_state() -> VortexResult<()> {
        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let mut acc = Accumulator::try_new(Sum, EmptyOptions, dtype, session())?;

        let batch1 = PrimitiveArray::new(buffer![10i32, 20], Validity::NonNullable).into_array();
        acc.accumulate(&batch1)?;
        let result1 = acc.finish()?;
        assert_eq!(result1.as_primitive().typed_value::<i64>(), Some(30));

        let batch2 = PrimitiveArray::new(buffer![3i32, 6, 9], Validity::NonNullable).into_array();
        acc.accumulate(&batch2)?;
        let result2 = acc.finish()?;
        assert_eq!(result2.as_primitive().typed_value::<i64>(), Some(18));
        Ok(())
    }

    // State merge tests (vtable-level)

    #[test]
    fn sum_state_merge() -> VortexResult<()> {
        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let mut state = Sum.empty_partial(&EmptyOptions, &dtype)?;

        let scalar1 = Scalar::primitive(100i64, Nullability::Nullable);
        Sum.combine_partials(&mut state, scalar1)?;

        let scalar2 = Scalar::primitive(50i64, Nullability::Nullable);
        Sum.combine_partials(&mut state, scalar2)?;

        let result = Sum.flush(&mut state)?;
        assert_eq!(result.as_primitive().typed_value::<i64>(), Some(150));
        Ok(())
    }

    #[test]
    fn sum_state_merge_null_partial_saturates() -> VortexResult<()> {
        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let mut state = Sum.empty_partial(&EmptyOptions, &dtype)?;

        Sum.combine_partials(&mut state, Scalar::primitive(7i64, Nullability::Nullable))?;
        // A null partial comes from a saturated sub-accumulator and must saturate this one too.
        Sum.combine_partials(
            &mut state,
            Scalar::null(DType::Primitive(PType::I64, Nullability::Nullable)),
        )?;
        assert!(Sum.is_saturated(&state));

        // Further merges are ignored once saturated.
        Sum.combine_partials(&mut state, Scalar::primitive(1i64, Nullability::Nullable))?;
        let result = Sum.flush(&mut state)?;
        assert!(result.is_null());

        // Flushing resets the state.
        assert!(!Sum.is_saturated(&state));
        Ok(())
    }

    #[test]
    fn sum_state_merge_overflow_saturates() -> VortexResult<()> {
        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let mut state = Sum.empty_partial(&EmptyOptions, &dtype)?;

        Sum.combine_partials(&mut state, Scalar::primitive(i64::MAX, Nullability::Nullable))?;
        assert!(!Sum.is_saturated(&state));
        Sum.combine_partials(&mut state, Scalar::primitive(1i64, Nullability::Nullable))?;
        assert!(Sum.is_saturated(&state));

        let result = Sum.flush(&mut state)?;
        assert!(result.is_null());
        Ok(())
    }

    // Overflow tests

    #[test]
    fn sum_checked_overflow() -> VortexResult<()> {
        let arr = PrimitiveArray::new(buffer![i64::MAX, 1i64], Validity::NonNullable).into_array();
        let result = run_sum(&arr)?;
        assert!(result.is_null());
        Ok(())
    }

    #[test]
    fn sum_unsigned_checked_overflow() -> VortexResult<()> {
        let arr = PrimitiveArray::new(buffer![u64::MAX, 1u64], Validity::NonNullable).into_array();
        let result = run_sum(&arr)?;
        assert!(result.is_null());
        Ok(())
    }

    #[test]
    fn sum_checked_overflow_with_nulls() -> VortexResult<()> {
        // Exercises the validity-mask path: nulls are skipped, but overflow still saturates.
        let arr =
            PrimitiveArray::from_option_iter([Some(i64::MAX), None, Some(1i64)]).into_array();
        let result = run_sum(&arr)?;
        assert!(result.is_null());
        Ok(())
    }

    #[test]
    fn sum_checked_overflow_is_saturated() -> VortexResult<()> {
        let dtype = DType::Primitive(PType::I64, Nullability::NonNullable);
        let mut acc = Accumulator::try_new(Sum, EmptyOptions, dtype, session())?;
        assert!(!acc.is_saturated());

        let batch =
            PrimitiveArray::new(buffer![i64::MAX, 1i64], Validity::NonNullable).into_array();
        acc.accumulate(&batch)?;
        assert!(acc.is_saturated());

        // finish resets state, clearing saturation
        drop(acc.finish()?);
        assert!(!acc.is_saturated());
        Ok(())
    }

    // Boolean sum tests

    #[test]
    fn sum_bool_all_true() -> VortexResult<()> {
        let arr: BoolArray = [true, true, true].into_iter().collect();
        let result = run_sum(&arr.into_array())?;
        assert_eq!(result.as_primitive().typed_value::<u64>(), Some(3));
        Ok(())
    }

    #[test]
    fn sum_bool_mixed() -> VortexResult<()> {
        let arr: BoolArray = [true, false, true, false, true].into_iter().collect();
        let result = run_sum(&arr.into_array())?;
        assert_eq!(result.as_primitive().typed_value::<u64>(), Some(3));
        Ok(())
    }

    #[test]
    fn sum_bool_all_false() -> VortexResult<()> {
        let arr: BoolArray = [false, false, false].into_iter().collect();
        let result = run_sum(&arr.into_array())?;
        assert_eq!(result.as_primitive().typed_value::<u64>(), Some(0));
        Ok(())
    }

    #[test]
    fn sum_bool_with_nulls() -> VortexResult<()> {
        let arr = BoolArray::from_iter([Some(true), None, Some(true), Some(false)]);
        let result = run_sum(&arr.into_array())?;
        assert_eq!(result.as_primitive().typed_value::<u64>(), Some(2));
        Ok(())
    }

    #[test]
    fn sum_bool_all_null() -> VortexResult<()> {
        // Arrow semantics: sum of all nulls is zero (identity element)
        let arr = BoolArray::from_iter([None::<bool>, None, None]);
        let result = run_sum(&arr.into_array())?;
        assert_eq!(result.as_primitive().typed_value::<u64>(), Some(0));
        Ok(())
    }

    #[test]
    fn sum_bool_empty_produces_zero() -> VortexResult<()> {
        let dtype = DType::Bool(Nullability::NonNullable);
        let mut acc = Accumulator::try_new(Sum, EmptyOptions, dtype, session())?;
        let result = acc.finish()?;
        assert_eq!(result.as_primitive().typed_value::<u64>(), Some(0));
        Ok(())
    }

    #[test]
    fn sum_bool_finish_resets_state() -> VortexResult<()> {
        let dtype = DType::Bool(Nullability::NonNullable);
        let mut acc = Accumulator::try_new(Sum, EmptyOptions, dtype, session())?;

        let batch1: BoolArray = [true, true, false].into_iter().collect();
        acc.accumulate(&batch1.into_array())?;
        let result1 = acc.finish()?;
        assert_eq!(result1.as_primitive().typed_value::<u64>(), Some(2));

        let batch2: BoolArray = [false, true].into_iter().collect();
        acc.accumulate(&batch2.into_array())?;
        let result2 = acc.finish()?;
        assert_eq!(result2.as_primitive().typed_value::<u64>(), Some(1));
        Ok(())
    }

    #[test]
    fn sum_bool_return_dtype() -> VortexResult<()> {
        let dtype = Sum.return_dtype(&EmptyOptions, &DType::Bool(Nullability::NonNullable))?;
        assert_eq!(dtype, DType::Primitive(PType::U64, Nullability::Nullable));
        Ok(())
    }

    // Grouped sum tests

    /// Sums each list of `groups` through a fresh grouped accumulator.
    fn run_grouped_sum(groups: &ArrayRef, elem_dtype: &DType) -> VortexResult<ArrayRef> {
        let mut acc =
            GroupedAccumulator::try_new(Sum, EmptyOptions, elem_dtype.clone(), session())?;
        acc.accumulate_list(groups)?;
        acc.finish()
    }

    #[test]
    fn grouped_sum_fixed_size_list() -> VortexResult<()> {
        // Groups: [[1,2,3], [4,5,6]] -> sums [6, 15]
        let elements =
            PrimitiveArray::new(buffer![1i32, 2, 3, 4, 5, 6], Validity::NonNullable).into_array();
        let groups = FixedSizeListArray::try_new(elements, 3, Validity::NonNullable, 2)?;

        let elem_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let result = run_grouped_sum(&groups.into_array(), &elem_dtype)?;

        let expected = PrimitiveArray::from_option_iter([Some(6i64), Some(15i64)]).into_array();
        assert_arrays_eq!(&result, &expected);
        Ok(())
    }

    #[test]
    fn grouped_sum_with_null_elements() -> VortexResult<()> {
        // Groups: [[Some(1), None, Some(3)], [None, Some(5), Some(6)]] -> sums [4, 11]
        let elements =
            PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), None, Some(5), Some(6)])
                .into_array();
        let groups = FixedSizeListArray::try_new(elements, 3, Validity::NonNullable, 2)?;

        let elem_dtype = DType::Primitive(PType::I32, Nullability::Nullable);
        let result = run_grouped_sum(&groups.into_array(), &elem_dtype)?;

        let expected = PrimitiveArray::from_option_iter([Some(4i64), Some(11i64)]).into_array();
        assert_arrays_eq!(&result, &expected);
        Ok(())
    }

    #[test]
    fn grouped_sum_with_null_group() -> VortexResult<()> {
        // Groups: [[1,2,3], null, [7,8,9]] -> sums [6, null, 24]
        let elements =
            PrimitiveArray::new(buffer![1i32, 2, 3, 4, 5, 6, 7, 8, 9], Validity::NonNullable)
                .into_array();
        let validity = Validity::from_iter([true, false, true]);
        let groups = FixedSizeListArray::try_new(elements, 3, validity, 3)?;

        let elem_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let result = run_grouped_sum(&groups.into_array(), &elem_dtype)?;

        let expected =
            PrimitiveArray::from_option_iter([Some(6i64), None, Some(24i64)]).into_array();
        assert_arrays_eq!(&result, &expected);
        Ok(())
    }

    #[test]
    fn grouped_sum_all_null_elements_in_group() -> VortexResult<()> {
        // Groups: [[None, None], [Some(3), Some(4)]] -> sums [0, 7] (Arrow semantics)
        let elements =
            PrimitiveArray::from_option_iter([None::<i32>, None, Some(3), Some(4)]).into_array();
        let groups = FixedSizeListArray::try_new(elements, 2, Validity::NonNullable, 2)?;

        let elem_dtype = DType::Primitive(PType::I32, Nullability::Nullable);
        let result = run_grouped_sum(&groups.into_array(), &elem_dtype)?;

        let expected = PrimitiveArray::from_option_iter([Some(0i64), Some(7i64)]).into_array();
        assert_arrays_eq!(&result, &expected);
        Ok(())
    }

    #[test]
    fn grouped_sum_bool() -> VortexResult<()> {
        // Groups: [[true, false, true], [true, true, true]] -> sums [2, 3]
        let elements: BoolArray = [true, false, true, true, true, true].into_iter().collect();
        let groups =
            FixedSizeListArray::try_new(elements.into_array(), 3, Validity::NonNullable, 2)?;

        let elem_dtype = DType::Bool(Nullability::NonNullable);
        let result = run_grouped_sum(&groups.into_array(), &elem_dtype)?;

        let expected = PrimitiveArray::from_option_iter([Some(2u64), Some(3u64)]).into_array();
        assert_arrays_eq!(&result, &expected);
        Ok(())
    }

    #[test]
    fn grouped_sum_finish_resets() -> VortexResult<()> {
        let elem_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let mut acc = GroupedAccumulator::try_new(Sum, EmptyOptions, elem_dtype, session())?;

        // First batch: [[1, 2], [3, 4]]
        let elements1 =
            PrimitiveArray::new(buffer![1i32, 2, 3, 4], Validity::NonNullable).into_array();
        let groups1 = FixedSizeListArray::try_new(elements1, 2, Validity::NonNullable, 2)?;
        acc.accumulate_list(&groups1.into_array())?;
        let result1 = acc.finish()?;

        let expected1 = PrimitiveArray::from_option_iter([Some(3i64), Some(7i64)]).into_array();
        assert_arrays_eq!(&result1, &expected1);

        // Second batch after reset: [[10, 20]]
        let elements2 = PrimitiveArray::new(buffer![10i32, 20], Validity::NonNullable).into_array();
        let groups2 = FixedSizeListArray::try_new(elements2, 2, Validity::NonNullable, 1)?;
        acc.accumulate_list(&groups2.into_array())?;
        let result2 = acc.finish()?;

        let expected2 = PrimitiveArray::from_option_iter([Some(30i64)]).into_array();
        assert_arrays_eq!(&result2, &expected2);
        Ok(())
    }
}