vortex_array/compute/
sum.rs

1use vortex_dtype::PType;
2use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err, vortex_panic};
3use vortex_scalar::Scalar;
4
5use crate::Array;
6use crate::encoding::Encoding;
7use crate::stats::{Precision, Stat, StatsProvider};
8
9pub trait SumFn<A> {
10    /// # Preconditions
11    ///
12    /// * The array's DType is summable
13    /// * The array is not all-null
14    fn sum(&self, array: A) -> VortexResult<Scalar>;
15}
16
17impl<E: Encoding> SumFn<&dyn Array> for E
18where
19    E: for<'a> SumFn<&'a E::Array>,
20{
21    fn sum(&self, array: &dyn Array) -> VortexResult<Scalar> {
22        let array_ref = array
23            .as_any()
24            .downcast_ref::<E::Array>()
25            .vortex_expect("Failed to downcast array");
26        SumFn::sum(self, array_ref)
27    }
28}
29
30/// Sum an array.
31///
32/// If the sum overflows, a null scalar will be returned.
33/// If the sum is not supported for the array's dtype, an error will be raised.
34/// If the array is all-invalid, the sum will be zero.
35pub fn sum(array: &dyn Array) -> VortexResult<Scalar> {
36    // Compute the expected dtype of the sum.
37    let sum_dtype = Stat::Sum
38        .dtype(array.dtype())
39        .ok_or_else(|| vortex_err!("Sum not supported for dtype: {}", array.dtype()))?;
40
41    // Short-circuit using array statistics.
42    if let Some(Precision::Exact(sum)) = array.statistics().get(Stat::Sum) {
43        return Ok(Scalar::new(sum_dtype, sum));
44    }
45
46    if array.is_empty() {
47        return if sum_dtype.is_float() {
48            Ok(Scalar::new(sum_dtype, 0.0.into()))
49        } else {
50            Ok(Scalar::new(sum_dtype, 0.into()))
51        };
52    }
53
54    // If the array is constant, we can compute the sum directly.
55    if let Some(mut constant) = array.as_constant() {
56        if constant.is_null() {
57            // An all-null constant array has a sum of 0.
58            return if sum_dtype.is_float() {
59                Ok(Scalar::new(sum_dtype, 0.0.into()))
60            } else {
61                Ok(Scalar::new(sum_dtype, 0.into()))
62            };
63        }
64
65        // If it's an extension array, then unwrap it into the storage scalar.
66        if let Some(extension) = constant.as_extension_opt() {
67            constant = extension.storage();
68        }
69
70        // If it's a boolean array, then the true count is the sum, which is the length.
71        if let Some(bool) = constant.as_bool_opt() {
72            return if bool.value().vortex_expect("already checked for null value") {
73                // Constant true
74                Ok(Scalar::new(sum_dtype, array.len().into()))
75            } else {
76                // Constant false
77                Ok(Scalar::new(sum_dtype, 0.into()))
78            };
79        }
80
81        // If it's a primitive array, then the sum is the constant value times the length.
82        if let Some(primitive) = constant.as_primitive_opt() {
83            match primitive.ptype() {
84                PType::U8 | PType::U16 | PType::U32 | PType::U64 => {
85                    let value = primitive
86                        .pvalue()
87                        .vortex_expect("already checked for null value")
88                        .as_u64()
89                        .vortex_expect("Failed to cast constant value to u64");
90
91                    // Overflow results in a null sum.
92                    let sum = value.checked_mul(array.len() as u64);
93
94                    return Ok(Scalar::new(sum_dtype, sum.into()));
95                }
96                PType::I8 | PType::I16 | PType::I32 | PType::I64 => {
97                    let value = primitive
98                        .pvalue()
99                        .vortex_expect("already checked for null value")
100                        .as_i64()
101                        .vortex_expect("Failed to cast constant value to i64");
102
103                    // Overflow results in a null sum.
104                    let sum = value.checked_mul(array.len() as i64);
105
106                    return Ok(Scalar::new(sum_dtype, sum.into()));
107                }
108                PType::F16 | PType::F32 | PType::F64 => {
109                    let value = primitive
110                        .pvalue()
111                        .vortex_expect("already checked for null value")
112                        .as_f64()
113                        .vortex_expect("Failed to cast constant value to f64");
114
115                    let sum = value * (array.len() as f64);
116
117                    return Ok(Scalar::new(sum_dtype, sum.into()));
118                }
119            }
120        }
121
122        // For the unsupported types, we should have exited earlier.
123        unreachable!("Unsupported sum constant: {}", constant.dtype());
124    }
125
126    // Try to use the sum function from the vtable.
127    let sum = if let Some(f) = array.vtable().sum_fn() {
128        f.sum(array)?
129    } else {
130        // Otherwise, canonicalize and try again.
131        log::debug!("No sum implementation found for {}", array.encoding());
132
133        let array = array.to_canonical()?;
134        if let Some(f) = array.as_ref().vtable().sum_fn() {
135            f.sum(array.as_ref())?
136        } else {
137            vortex_bail!(
138                "No sum function for canonical array: {}",
139                array.as_ref().encoding(),
140            )
141        }
142    };
143
144    if sum.dtype() != &sum_dtype {
145        vortex_panic!(
146            "Sum function of {} returned scalar with wrong dtype: {:?}",
147            array.encoding(),
148            sum.dtype()
149        );
150    }
151
152    // Update the statistics with the computed sum.
153    array
154        .statistics()
155        .set(Stat::Sum, Precision::Exact(sum.value().clone()));
156
157    Ok(sum)
158}
159
#[cfg(test)]
mod test {
    use crate::arrays::{BoolArray, PrimitiveArray};
    use crate::compute::sum;

    /// An integer array containing only nulls sums to zero.
    #[test]
    fn sum_all_invalid() {
        let nulls = PrimitiveArray::from_option_iter::<i32, _>([None, None, None]);
        let total = sum(&nulls).unwrap();
        let actual = total.as_primitive().as_::<i32>().unwrap();
        assert_eq!(actual, Some(0));
    }

    /// A float array containing only nulls sums to zero.
    #[test]
    fn sum_all_invalid_float() {
        let nulls = PrimitiveArray::from_option_iter::<f32, _>([None, None, None]);
        let total = sum(&nulls).unwrap();
        let actual = total.as_primitive().as_::<f32>().unwrap();
        assert_eq!(actual, Some(0.0));
    }

    /// Four equal integers sum to four times the value.
    #[test]
    fn sum_constant() {
        let ones = PrimitiveArray::from_iter([1, 1, 1, 1]);
        let total = sum(&ones).unwrap();
        let actual = total.as_primitive().as_::<i32>().unwrap();
        assert_eq!(actual, Some(4));
    }

    /// Four equal floats sum to four times the value.
    #[test]
    fn sum_constant_float() {
        let ones = PrimitiveArray::from_iter([1., 1., 1., 1.]);
        let total = sum(&ones).unwrap();
        let actual = total.as_primitive().as_::<f32>().unwrap();
        assert_eq!(actual, Some(4.));
    }

    /// The sum of a boolean array is its number of `true` values.
    #[test]
    fn sum_boolean() {
        let flags = BoolArray::from_iter([true, false, false, true]);
        let total = sum(&flags).unwrap();
        let actual = total.as_primitive().as_::<i32>().unwrap();
        assert_eq!(actual, Some(2));
    }
}
199}