Skip to main content

vortex_array/expr/stats/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5use std::fmt::Display;
6use std::fmt::Formatter;
7
8use enum_iterator::Sequence;
9use enum_iterator::all;
10use num_enum::IntoPrimitive;
11use num_enum::TryFromPrimitive;
12
13use crate::dtype::DType;
14use crate::dtype::DecimalDType;
15use crate::dtype::MAX_PRECISION;
16use crate::dtype::Nullability::NonNullable;
17use crate::dtype::Nullability::Nullable;
18use crate::dtype::PType;
19
20mod bound;
21mod precision;
22mod provider;
23mod stat_bound;
24
25pub use bound::*;
26pub use precision::*;
27pub use provider::*;
28pub use stat_bound::*;
29
30#[derive(
31    Debug,
32    Clone,
33    Copy,
34    PartialEq,
35    Eq,
36    PartialOrd,
37    Ord,
38    Hash,
39    Sequence,
40    IntoPrimitive,
41    TryFromPrimitive,
42)]
43#[repr(u8)]
44pub enum Stat {
45    /// Whether all values are the same (nulls are not equal to other non-null values,
46    /// so this is true iff all values are null or all values are the same non-null value)
47    IsConstant = 0,
48    /// Whether the non-null values in the array are sorted in ascending order (i.e., we skip nulls)
49    /// This may later be extended to support descending order, but for now we only support ascending order.
50    IsSorted = 1,
51    /// Whether the non-null values in the array are strictly sorted in ascending order (i.e., sorted with no duplicates)
52    /// This may later be extended to support descending order, but for now we only support ascending order.
53    IsStrictSorted = 2,
54    /// The maximum value in the array (ignoring nulls, unless all values are null)
55    Max = 3,
56    /// The minimum value in the array (ignoring nulls, unless all values are null)
57    Min = 4,
58    /// The sum of the non-null values of the array.
59    Sum = 5,
60    /// The number of null values in the array
61    NullCount = 6,
62    /// The uncompressed size of the array in bytes
63    UncompressedSizeInBytes = 7,
64    /// The number of NaN values in the array
65    NaNCount = 8,
66}
67
// Marker types, one per `Stat` variant, declared in the same order as the enum.
//
// These structs allow the extraction of the bound from the `Precision` value: through their
// `StatType` impls they tie together the `Stat` and the `StatBound`, which is what allows
// the bound to be extracted.

/// Marker type for [`Stat::IsConstant`].
pub struct IsConstant;

/// Marker type for [`Stat::IsSorted`].
pub struct IsSorted;

/// Marker type for [`Stat::IsStrictSorted`].
pub struct IsStrictSorted;

/// Marker type for [`Stat::Max`].
pub struct Max;

/// Marker type for [`Stat::Min`].
pub struct Min;

/// Marker type for [`Stat::Sum`].
pub struct Sum;

/// Marker type for [`Stat::NullCount`].
pub struct NullCount;

/// Marker type for [`Stat::UncompressedSizeInBytes`].
pub struct UncompressedSizeInBytes;

/// Marker type for [`Stat::NaNCount`].
pub struct NaNCount;
87
88impl StatType<bool> for IsConstant {
89    type Bound = Precision<bool>;
90
91    const STAT: Stat = Stat::IsConstant;
92}
93
94impl StatType<bool> for IsSorted {
95    type Bound = Precision<bool>;
96
97    const STAT: Stat = Stat::IsSorted;
98}
99
100impl StatType<bool> for IsStrictSorted {
101    type Bound = Precision<bool>;
102
103    const STAT: Stat = Stat::IsStrictSorted;
104}
105
106impl<T: PartialOrd + Clone> StatType<T> for NullCount {
107    type Bound = UpperBound<T>;
108
109    const STAT: Stat = Stat::NullCount;
110}
111
112impl<T: PartialOrd + Clone> StatType<T> for UncompressedSizeInBytes {
113    type Bound = UpperBound<T>;
114
115    const STAT: Stat = Stat::UncompressedSizeInBytes;
116}
117
118impl<T: PartialOrd + Clone + Debug> StatType<T> for Max {
119    type Bound = UpperBound<T>;
120
121    const STAT: Stat = Stat::Max;
122}
123
124impl<T: PartialOrd + Clone + Debug> StatType<T> for Min {
125    type Bound = LowerBound<T>;
126
127    const STAT: Stat = Stat::Min;
128}
129
130impl<T: PartialOrd + Clone + Debug> StatType<T> for Sum {
131    type Bound = Precision<T>;
132
133    const STAT: Stat = Stat::Sum;
134}
135
136impl<T: PartialOrd + Clone> StatType<T> for NaNCount {
137    type Bound = UpperBound<T>;
138
139    const STAT: Stat = Stat::NaNCount;
140}
141
142impl Stat {
143    /// Whether the statistic is commutative (i.e., whether merging can be done independently of ordering)
144    /// e.g., min/max are commutative, but is_sorted is not
145    pub fn is_commutative(&self) -> bool {
146        // NOTE: we prefer this syntax to force a compile error if we add a new stat
147        match self {
148            Self::IsConstant
149            | Self::Max
150            | Self::Min
151            | Self::NullCount
152            | Self::Sum
153            | Self::NaNCount
154            | Self::UncompressedSizeInBytes => true,
155            Self::IsSorted | Self::IsStrictSorted => false,
156        }
157    }
158
159    /// Whether the statistic has the same dtype as the array it's computed on
160    pub fn has_same_dtype_as_array(&self) -> bool {
161        matches!(self, Stat::Min | Stat::Max)
162    }
163
164    /// Return the [`DType`] of the statistic scalar assuming the array is of the given [`DType`].
165    pub fn dtype(&self, data_type: &DType) -> Option<DType> {
166        Some(match self {
167            Self::IsConstant => DType::Bool(NonNullable),
168            Self::IsSorted => DType::Bool(NonNullable),
169            Self::IsStrictSorted => DType::Bool(NonNullable),
170            Self::Max if matches!(data_type, DType::Null) => return None,
171            Self::Max => data_type.clone(),
172            Self::Min if matches!(data_type, DType::Null) => return None,
173            Self::Min => data_type.clone(),
174            Self::NullCount => DType::Primitive(PType::U64, NonNullable),
175            Self::UncompressedSizeInBytes => DType::Primitive(PType::U64, NonNullable),
176            Self::NaNCount => {
177                // Only floating points support NaN counts.
178                if let DType::Primitive(ptype, ..) = data_type
179                    && ptype.is_float()
180                {
181                    DType::Primitive(PType::U64, NonNullable)
182                } else {
183                    return None;
184                }
185            }
186            Self::Sum => {
187                // Any array that cannot be summed has a sum DType of null.
188                // Any array that can be summed, but overflows, has a sum _value_ of null.
189                // Therefore, we make integer sum stats nullable.
190                match data_type {
191                    DType::Bool(_) => DType::Primitive(PType::U64, Nullable),
192                    DType::Primitive(ptype, _) => match ptype {
193                        PType::U8 | PType::U16 | PType::U32 | PType::U64 => {
194                            DType::Primitive(PType::U64, Nullable)
195                        }
196                        PType::I8 | PType::I16 | PType::I32 | PType::I64 => {
197                            DType::Primitive(PType::I64, Nullable)
198                        }
199                        PType::F16 | PType::F32 | PType::F64 => {
200                            // Float sums cannot overflow, but all null floats still end up as null
201                            DType::Primitive(PType::F64, Nullable)
202                        }
203                    },
204                    DType::Extension(ext_dtype) => self.dtype(ext_dtype.storage_dtype())?,
205                    DType::Decimal(decimal_dtype, _) => {
206                        // Both Spark and DataFusion use this heuristic.
207                        // - https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
208                        // - https://github.com/apache/datafusion/blob/4153adf2c0f6e317ef476febfdc834208bd46622/datafusion/functions-aggregate/src/sum.rs#L188
209                        let precision = u8::min(MAX_PRECISION, decimal_dtype.precision() + 10);
210                        DType::Decimal(
211                            DecimalDType::new(precision, decimal_dtype.scale()),
212                            Nullable,
213                        )
214                    }
215                    // Unsupported types
216                    _ => return None,
217                }
218            }
219        })
220    }
221
222    pub fn name(&self) -> &str {
223        match self {
224            Self::IsConstant => "is_constant",
225            Self::IsSorted => "is_sorted",
226            Self::IsStrictSorted => "is_strict_sorted",
227            Self::Max => "max",
228            Self::Min => "min",
229            Self::NullCount => "null_count",
230            Self::UncompressedSizeInBytes => "uncompressed_size_in_bytes",
231            Self::Sum => "sum",
232            Self::NaNCount => "nan_count",
233        }
234    }
235
236    pub fn all() -> impl Iterator<Item = Stat> {
237        all::<Self>()
238    }
239}
240
241impl Display for Stat {
242    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
243        write!(f, "{}", self.name())
244    }
245}
246
#[cfg(test)]
mod test {
    use enum_iterator::all;

    use crate::arrays::PrimitiveArray;
    use crate::expr::stats::Stat;

    /// Computing the minimum of an all-null array must yield `None` rather than panic.
    #[test]
    fn min_of_nulls_is_not_panic() {
        let all_nulls = PrimitiveArray::from_option_iter::<i32, _>([None, None, None, None]);
        let min = all_nulls.statistics().compute_as::<i64>(Stat::Min);

        assert_eq!(min, None);
    }

    /// Only `Min` and `Max` share the dtype of the array they are computed on.
    #[test]
    fn has_same_dtype_as_array() {
        for stat in all::<Stat>() {
            let expected = matches!(stat, Stat::Min | Stat::Max);
            assert_eq!(
                stat.has_same_dtype_as_array(),
                expected,
                "unexpected has_same_dtype_as_array() for {stat}"
            );
        }
    }
}