Skip to main content

vortex_array/expr/stats/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5use std::fmt::Display;
6use std::fmt::Formatter;
7
8use enum_iterator::Sequence;
9use enum_iterator::all;
10use num_enum::IntoPrimitive;
11use num_enum::TryFromPrimitive;
12use vortex_dtype::DType;
13use vortex_dtype::DecimalDType;
14use vortex_dtype::MAX_PRECISION;
15use vortex_dtype::Nullability::NonNullable;
16use vortex_dtype::Nullability::Nullable;
17use vortex_dtype::PType;
18
19mod bound;
20mod precision;
21mod provider;
22mod stat_bound;
23
24pub use bound::*;
25pub use precision::*;
26pub use provider::*;
27pub use stat_bound::*;
28
29#[derive(
30    Debug,
31    Clone,
32    Copy,
33    PartialEq,
34    Eq,
35    PartialOrd,
36    Ord,
37    Hash,
38    Sequence,
39    IntoPrimitive,
40    TryFromPrimitive,
41)]
42#[repr(u8)]
43pub enum Stat {
44    /// Whether all values are the same (nulls are not equal to other non-null values,
45    /// so this is true iff all values are null or all values are the same non-null value)
46    IsConstant = 0,
47    /// Whether the non-null values in the array are sorted in ascending order (i.e., we skip nulls)
48    /// This may later be extended to support descending order, but for now we only support ascending order.
49    IsSorted = 1,
50    /// Whether the non-null values in the array are strictly sorted in ascending order (i.e., sorted with no duplicates)
51    /// This may later be extended to support descending order, but for now we only support ascending order.
52    IsStrictSorted = 2,
53    /// The maximum value in the array (ignoring nulls, unless all values are null)
54    Max = 3,
55    /// The minimum value in the array (ignoring nulls, unless all values are null)
56    Min = 4,
57    /// The sum of the non-null values of the array.
58    Sum = 5,
59    /// The number of null values in the array
60    NullCount = 6,
61    /// The uncompressed size of the array in bytes
62    UncompressedSizeInBytes = 7,
63    /// The number of NaN values in the array
64    NaNCount = 8,
65}
66
// Marker types, one per `Stat` variant. Each implements `StatType`, which ties the `Stat` to the
// bound type used for its values so that the bound can be extracted from a `Precision` value.
// (A plain comment, not `///`: rustdoc would otherwise attach this text to `Max` alone.)

/// Marker for [`Stat::Max`]; its bound type is an `UpperBound`.
#[derive(Debug, Clone, Copy)]
pub struct Max;

/// Marker for [`Stat::Min`]; its bound type is a `LowerBound`.
#[derive(Debug, Clone, Copy)]
pub struct Min;

/// Marker for [`Stat::Sum`]; its bound type is a `Precision`.
#[derive(Debug, Clone, Copy)]
pub struct Sum;

/// Marker for [`Stat::IsConstant`]; its bound type is a `Precision<bool>`.
#[derive(Debug, Clone, Copy)]
pub struct IsConstant;

/// Marker for [`Stat::IsSorted`]; its bound type is a `Precision<bool>`.
#[derive(Debug, Clone, Copy)]
pub struct IsSorted;

/// Marker for [`Stat::IsStrictSorted`]; its bound type is a `Precision<bool>`.
#[derive(Debug, Clone, Copy)]
pub struct IsStrictSorted;

/// Marker for [`Stat::NullCount`]; its bound type is an `UpperBound`.
#[derive(Debug, Clone, Copy)]
pub struct NullCount;

/// Marker for [`Stat::UncompressedSizeInBytes`]; its bound type is an `UpperBound`.
#[derive(Debug, Clone, Copy)]
pub struct UncompressedSizeInBytes;

/// Marker for [`Stat::NaNCount`]; its bound type is an `UpperBound`.
#[derive(Debug, Clone, Copy)]
pub struct NaNCount;
86
87impl StatType<bool> for IsConstant {
88    type Bound = Precision<bool>;
89
90    const STAT: Stat = Stat::IsConstant;
91}
92
93impl StatType<bool> for IsSorted {
94    type Bound = Precision<bool>;
95
96    const STAT: Stat = Stat::IsSorted;
97}
98
99impl StatType<bool> for IsStrictSorted {
100    type Bound = Precision<bool>;
101
102    const STAT: Stat = Stat::IsStrictSorted;
103}
104
105impl<T: PartialOrd + Clone> StatType<T> for NullCount {
106    type Bound = UpperBound<T>;
107
108    const STAT: Stat = Stat::NullCount;
109}
110
111impl<T: PartialOrd + Clone> StatType<T> for UncompressedSizeInBytes {
112    type Bound = UpperBound<T>;
113
114    const STAT: Stat = Stat::UncompressedSizeInBytes;
115}
116
117impl<T: PartialOrd + Clone + Debug> StatType<T> for Max {
118    type Bound = UpperBound<T>;
119
120    const STAT: Stat = Stat::Max;
121}
122
123impl<T: PartialOrd + Clone + Debug> StatType<T> for Min {
124    type Bound = LowerBound<T>;
125
126    const STAT: Stat = Stat::Min;
127}
128
129impl<T: PartialOrd + Clone + Debug> StatType<T> for Sum {
130    type Bound = Precision<T>;
131
132    const STAT: Stat = Stat::Sum;
133}
134
135impl<T: PartialOrd + Clone> StatType<T> for NaNCount {
136    type Bound = UpperBound<T>;
137
138    const STAT: Stat = Stat::NaNCount;
139}
140
141impl Stat {
142    /// Whether the statistic is commutative (i.e., whether merging can be done independently of ordering)
143    /// e.g., min/max are commutative, but is_sorted is not
144    pub fn is_commutative(&self) -> bool {
145        // NOTE: we prefer this syntax to force a compile error if we add a new stat
146        match self {
147            Self::IsConstant
148            | Self::Max
149            | Self::Min
150            | Self::NullCount
151            | Self::Sum
152            | Self::NaNCount
153            | Self::UncompressedSizeInBytes => true,
154            Self::IsSorted | Self::IsStrictSorted => false,
155        }
156    }
157
158    /// Whether the statistic has the same dtype as the array it's computed on
159    pub fn has_same_dtype_as_array(&self) -> bool {
160        matches!(self, Stat::Min | Stat::Max)
161    }
162
163    /// Return the [`DType`] of the statistic scalar assuming the array is of the given [`DType`].
164    pub fn dtype(&self, data_type: &DType) -> Option<DType> {
165        Some(match self {
166            Self::IsConstant => DType::Bool(NonNullable),
167            Self::IsSorted => DType::Bool(NonNullable),
168            Self::IsStrictSorted => DType::Bool(NonNullable),
169            Self::Max if matches!(data_type, DType::Null) => return None,
170            Self::Max => data_type.clone(),
171            Self::Min if matches!(data_type, DType::Null) => return None,
172            Self::Min => data_type.clone(),
173            Self::NullCount => DType::Primitive(PType::U64, NonNullable),
174            Self::UncompressedSizeInBytes => DType::Primitive(PType::U64, NonNullable),
175            Self::NaNCount => {
176                // Only floating points support NaN counts.
177                if let DType::Primitive(ptype, ..) = data_type
178                    && ptype.is_float()
179                {
180                    DType::Primitive(PType::U64, NonNullable)
181                } else {
182                    return None;
183                }
184            }
185            Self::Sum => {
186                // Any array that cannot be summed has a sum DType of null.
187                // Any array that can be summed, but overflows, has a sum _value_ of null.
188                // Therefore, we make integer sum stats nullable.
189                match data_type {
190                    DType::Bool(_) => DType::Primitive(PType::U64, Nullable),
191                    DType::Primitive(ptype, _) => match ptype {
192                        PType::U8 | PType::U16 | PType::U32 | PType::U64 => {
193                            DType::Primitive(PType::U64, Nullable)
194                        }
195                        PType::I8 | PType::I16 | PType::I32 | PType::I64 => {
196                            DType::Primitive(PType::I64, Nullable)
197                        }
198                        PType::F16 | PType::F32 | PType::F64 => {
199                            // Float sums cannot overflow, but all null floats still end up as null
200                            DType::Primitive(PType::F64, Nullable)
201                        }
202                    },
203                    DType::Extension(ext_dtype) => self.dtype(ext_dtype.storage_dtype())?,
204                    DType::Decimal(decimal_dtype, _) => {
205                        // Both Spark and DataFusion use this heuristic.
206                        // - https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
207                        // - https://github.com/apache/datafusion/blob/4153adf2c0f6e317ef476febfdc834208bd46622/datafusion/functions-aggregate/src/sum.rs#L188
208                        let precision = u8::min(MAX_PRECISION, decimal_dtype.precision() + 10);
209                        DType::Decimal(
210                            DecimalDType::new(precision, decimal_dtype.scale()),
211                            Nullable,
212                        )
213                    }
214                    // Unsupported types
215                    _ => return None,
216                }
217            }
218        })
219    }
220
221    pub fn name(&self) -> &str {
222        match self {
223            Self::IsConstant => "is_constant",
224            Self::IsSorted => "is_sorted",
225            Self::IsStrictSorted => "is_strict_sorted",
226            Self::Max => "max",
227            Self::Min => "min",
228            Self::NullCount => "null_count",
229            Self::UncompressedSizeInBytes => "uncompressed_size_in_bytes",
230            Self::Sum => "sum",
231            Self::NaNCount => "nan_count",
232        }
233    }
234
235    pub fn all() -> impl Iterator<Item = Stat> {
236        all::<Self>()
237    }
238}
239
240impl Display for Stat {
241    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
242        write!(f, "{}", self.name())
243    }
244}
245
#[cfg(test)]
mod test {
    use crate::arrays::PrimitiveArray;
    use crate::expr::stats::Stat;

    /// Computing `Min` over an array that is entirely null yields no value instead of panicking.
    #[test]
    fn min_of_nulls_is_not_panic() {
        let all_nulls = PrimitiveArray::from_option_iter::<i32, _>([None; 4]);
        let min = all_nulls.statistics().compute_as::<i64>(Stat::Min);

        assert_eq!(min, None);
    }

    /// Exactly `Min` and `Max` report sharing the dtype of the array they describe.
    #[test]
    fn has_same_dtype_as_array() {
        for stat in Stat::all() {
            let expected = matches!(stat, Stat::Min | Stat::Max);
            assert_eq!(stat.has_same_dtype_as_array(), expected);
        }
    }
}