vortex_array/compute/
sum.rs

1use vortex_dtype::PType;
2use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err, vortex_panic};
3use vortex_scalar::Scalar;
4
5use crate::Array;
6use crate::encoding::Encoding;
7use crate::stats::{Precision, Stat, StatsProvider};
8
9pub trait SumFn<A> {
10    /// # Preconditions
11    ///
12    /// * The array's DType is summable
13    /// * The array is not all-null
14    fn sum(&self, array: A) -> VortexResult<Scalar>;
15}
16
17impl<E: Encoding> SumFn<&dyn Array> for E
18where
19    E: for<'a> SumFn<&'a E::Array>,
20{
21    fn sum(&self, array: &dyn Array) -> VortexResult<Scalar> {
22        let array_ref = array
23            .as_any()
24            .downcast_ref::<E::Array>()
25            .vortex_expect("Failed to downcast array");
26        SumFn::sum(self, array_ref)
27    }
28}
29
30/// Sum an array.
31///
32/// If the sum overflows, a null scalar will be returned.
33/// If the sum is not supported for the array's dtype, an error will be raised.
34/// If the array is all-invalid, the sum will be zero.
35pub fn sum(array: &dyn Array) -> VortexResult<Scalar> {
36    // Compute the expected dtype of the sum.
37    let sum_dtype = Stat::Sum
38        .dtype(array.dtype())
39        .ok_or_else(|| vortex_err!("Sum not supported for dtype: {}", array.dtype()))?;
40
41    // Short-circuit using array statistics.
42    if let Some(Precision::Exact(sum)) = array.statistics().get(Stat::Sum) {
43        return Ok(Scalar::new(sum_dtype, sum));
44    }
45
46    // If the array is constant, we can compute the sum directly.
47    if let Some(mut constant) = array.as_constant() {
48        if constant.is_null() {
49            // An all-null constant array has a sum of 0.
50            return if PType::try_from(&sum_dtype)
51                .vortex_expect("must be primitive")
52                .is_float()
53            {
54                Ok(Scalar::new(sum_dtype, 0.0.into()))
55            } else {
56                Ok(Scalar::new(sum_dtype, 0.into()))
57            };
58        }
59
60        // If it's an extension array, then unwrap it into the storage scalar.
61        if let Some(extension) = constant.as_extension_opt() {
62            constant = extension.storage();
63        }
64
65        // If it's a boolean array, then the true count is the sum, which is the length.
66        if let Some(bool) = constant.as_bool_opt() {
67            return if bool.value().vortex_expect("already checked for null value") {
68                // Constant true
69                Ok(Scalar::new(sum_dtype, array.len().into()))
70            } else {
71                // Constant false
72                Ok(Scalar::new(sum_dtype, 0.into()))
73            };
74        }
75
76        // If it's a primitive array, then the sum is the constant value times the length.
77        if let Some(primitive) = constant.as_primitive_opt() {
78            match primitive.ptype() {
79                PType::U8 | PType::U16 | PType::U32 | PType::U64 => {
80                    let value = primitive
81                        .pvalue()
82                        .vortex_expect("already checked for null value")
83                        .as_u64()
84                        .vortex_expect("Failed to cast constant value to u64");
85
86                    // Overflow results in a null sum.
87                    let sum = value.checked_mul(array.len() as u64);
88
89                    return Ok(Scalar::new(sum_dtype, sum.into()));
90                }
91                PType::I8 | PType::I16 | PType::I32 | PType::I64 => {
92                    let value = primitive
93                        .pvalue()
94                        .vortex_expect("already checked for null value")
95                        .as_i64()
96                        .vortex_expect("Failed to cast constant value to i64");
97
98                    // Overflow results in a null sum.
99                    let sum = value.checked_mul(array.len() as i64);
100
101                    return Ok(Scalar::new(sum_dtype, sum.into()));
102                }
103                PType::F16 | PType::F32 | PType::F64 => {
104                    let value = primitive
105                        .pvalue()
106                        .vortex_expect("already checked for null value")
107                        .as_f64()
108                        .vortex_expect("Failed to cast constant value to f64");
109
110                    let sum = value * (array.len() as f64);
111
112                    return Ok(Scalar::new(sum_dtype, sum.into()));
113                }
114            }
115        }
116
117        // For the unsupported types, we should have exited earlier.
118        unreachable!("Unsupported sum constant: {}", constant.dtype());
119    }
120
121    // Try to use the sum function from the vtable.
122    let sum = if let Some(f) = array.vtable().sum_fn() {
123        f.sum(array)?
124    } else {
125        // Otherwise, canonicalize and try again.
126        log::debug!("No sum implementation found for {}", array.encoding());
127
128        let array = array.to_canonical()?;
129        if let Some(f) = array.as_ref().vtable().sum_fn() {
130            f.sum(array.as_ref())?
131        } else {
132            vortex_bail!(
133                "No sum function for canonical array: {}",
134                array.as_ref().encoding(),
135            )
136        }
137    };
138
139    if sum.dtype() != &sum_dtype {
140        vortex_panic!(
141            "Sum function of {} returned scalar with wrong dtype: {:?}",
142            array.encoding(),
143            sum.dtype()
144        );
145    }
146
147    // Update the statistics with the computed sum.
148    array
149        .statistics()
150        .set(Stat::Sum, Precision::Exact(sum.value().clone()));
151
152    Ok(sum)
153}
154
#[cfg(test)]
mod test {
    use crate::arrays::{BoolArray, PrimitiveArray};
    use crate::compute::sum;

    /// An all-null integer array sums to zero.
    #[test]
    fn sum_all_invalid() {
        let nulls = PrimitiveArray::from_option_iter::<i32, _>([None; 3]);
        let total = sum(&nulls).unwrap();
        assert_eq!(total.as_primitive().as_::<i32>().unwrap(), Some(0));
    }

    /// An all-null float array sums to zero.
    #[test]
    fn sum_all_invalid_float() {
        let nulls = PrimitiveArray::from_option_iter::<f32, _>([None; 3]);
        let total = sum(&nulls).unwrap();
        assert_eq!(total.as_primitive().as_::<f32>().unwrap(), Some(0.0));
    }

    /// Four integer ones sum to four.
    #[test]
    fn sum_constant() {
        let ones = PrimitiveArray::from_iter([1; 4]);
        let total = sum(&ones).unwrap();
        assert_eq!(total.as_primitive().as_::<i32>().unwrap(), Some(4));
    }

    /// Four float ones sum to four.
    #[test]
    fn sum_constant_float() {
        let ones = PrimitiveArray::from_iter([1.; 4]);
        let total = sum(&ones).unwrap();
        assert_eq!(total.as_primitive().as_::<f32>().unwrap(), Some(4.));
    }

    /// The sum of a boolean array is its count of `true` values.
    #[test]
    fn sum_boolean() {
        let flags = BoolArray::from_iter([true, false, false, true]);
        let total = sum(&flags).unwrap();
        assert_eq!(total.as_primitive().as_::<i32>().unwrap(), Some(2));
    }
}