Skip to main content

vortex_array/expr/stats/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5use std::fmt::Display;
6use std::fmt::Formatter;
7
8use enum_iterator::Sequence;
9use enum_iterator::all;
10use num_enum::IntoPrimitive;
11use num_enum::TryFromPrimitive;
12
13use crate::dtype::DType;
14use crate::dtype::Nullability::NonNullable;
15use crate::dtype::PType;
16
17mod bound;
18mod precision;
19mod provider;
20mod stat_bound;
21
22pub use bound::*;
23pub use precision::*;
24pub use provider::*;
25pub use stat_bound::*;
26
27use crate::aggregate_fn;
28use crate::aggregate_fn::AggregateFnVTable;
29use crate::aggregate_fn::EmptyOptions;
30
31#[derive(
32    Debug,
33    Clone,
34    Copy,
35    PartialEq,
36    Eq,
37    PartialOrd,
38    Ord,
39    Hash,
40    Sequence,
41    IntoPrimitive,
42    TryFromPrimitive,
43)]
44#[repr(u8)]
45pub enum Stat {
46    /// Whether all values are the same (nulls are not equal to other non-null values,
47    /// so this is true iff all values are null or all values are the same non-null value)
48    IsConstant = 0,
49    /// Whether the non-null values in the array are sorted in ascending order (i.e., we skip nulls)
50    /// This may later be extended to support descending order, but for now we only support ascending order.
51    IsSorted = 1,
52    /// Whether the non-null values in the array are strictly sorted in ascending order (i.e., sorted with no duplicates)
53    /// This may later be extended to support descending order, but for now we only support ascending order.
54    IsStrictSorted = 2,
55    /// The maximum value in the array (ignoring nulls, unless all values are null)
56    Max = 3,
57    /// The minimum value in the array (ignoring nulls, unless all values are null)
58    Min = 4,
59    /// The sum of the non-null values of the array.
60    Sum = 5,
61    /// The number of null values in the array
62    NullCount = 6,
63    /// The uncompressed size of the array in bytes
64    UncompressedSizeInBytes = 7,
65    /// The number of NaN values in the array
66    NaNCount = 8,
67}
68
// The unit structs below are type-level tags. Each one ties a `Stat` variant to the bound
// type (declared by its `StatType` impl) used to extract that statistic from a `Precision`.

/// Type-level tag for [`Stat::Max`]; its [`StatType`] bound is an [`UpperBound`].
pub struct Max;

/// Type-level tag for [`Stat::Min`]; its [`StatType`] bound is a [`LowerBound`].
pub struct Min;

/// Type-level tag for [`Stat::Sum`]; its [`StatType`] bound is a [`Precision`].
pub struct Sum;

/// Type-level tag for [`Stat::IsConstant`]; its [`StatType`] bound is a `Precision<bool>`.
pub struct IsConstant;

/// Type-level tag for [`Stat::IsSorted`]; its [`StatType`] bound is a `Precision<bool>`.
pub struct IsSorted;

/// Type-level tag for [`Stat::IsStrictSorted`]; its [`StatType`] bound is a `Precision<bool>`.
pub struct IsStrictSorted;

/// Type-level tag for [`Stat::NullCount`]; its [`StatType`] bound is an [`UpperBound`].
pub struct NullCount;

/// Type-level tag for [`Stat::UncompressedSizeInBytes`]; its [`StatType`] bound is an
/// [`UpperBound`].
pub struct UncompressedSizeInBytes;

/// Type-level tag for [`Stat::NaNCount`]; its [`StatType`] bound is an [`UpperBound`].
pub struct NaNCount;
88
89impl StatType<bool> for IsConstant {
90    type Bound = Precision<bool>;
91
92    const STAT: Stat = Stat::IsConstant;
93}
94
95impl StatType<bool> for IsSorted {
96    type Bound = Precision<bool>;
97
98    const STAT: Stat = Stat::IsSorted;
99}
100
101impl StatType<bool> for IsStrictSorted {
102    type Bound = Precision<bool>;
103
104    const STAT: Stat = Stat::IsStrictSorted;
105}
106
107impl<T: PartialOrd + Clone> StatType<T> for NullCount {
108    type Bound = UpperBound<T>;
109
110    const STAT: Stat = Stat::NullCount;
111}
112
113impl<T: PartialOrd + Clone> StatType<T> for UncompressedSizeInBytes {
114    type Bound = UpperBound<T>;
115
116    const STAT: Stat = Stat::UncompressedSizeInBytes;
117}
118
119impl<T: PartialOrd + Clone + Debug> StatType<T> for Max {
120    type Bound = UpperBound<T>;
121
122    const STAT: Stat = Stat::Max;
123}
124
125impl<T: PartialOrd + Clone + Debug> StatType<T> for Min {
126    type Bound = LowerBound<T>;
127
128    const STAT: Stat = Stat::Min;
129}
130
131impl<T: PartialOrd + Clone + Debug> StatType<T> for Sum {
132    type Bound = Precision<T>;
133
134    const STAT: Stat = Stat::Sum;
135}
136
137impl<T: PartialOrd + Clone> StatType<T> for NaNCount {
138    type Bound = UpperBound<T>;
139
140    const STAT: Stat = Stat::NaNCount;
141}
142
143impl Stat {
144    /// Whether the statistic is commutative (i.e., whether merging can be done independently of ordering)
145    /// e.g., min/max are commutative, but is_sorted is not
146    pub fn is_commutative(&self) -> bool {
147        // NOTE: we prefer this syntax to force a compile error if we add a new stat
148        match self {
149            Self::IsConstant
150            | Self::Max
151            | Self::Min
152            | Self::NullCount
153            | Self::Sum
154            | Self::NaNCount
155            | Self::UncompressedSizeInBytes => true,
156            Self::IsSorted | Self::IsStrictSorted => false,
157        }
158    }
159
160    /// Whether the statistic has the same dtype as the array it's computed on
161    pub fn has_same_dtype_as_array(&self) -> bool {
162        matches!(self, Stat::Min | Stat::Max)
163    }
164
165    /// Return the [`DType`] of the statistic scalar assuming the array is of the given [`DType`].
166    pub fn dtype(&self, data_type: &DType) -> Option<DType> {
167        Some(match self {
168            Self::IsConstant => DType::Bool(NonNullable),
169            Self::IsSorted => DType::Bool(NonNullable),
170            Self::IsStrictSorted => DType::Bool(NonNullable),
171            Self::Max if matches!(data_type, DType::Null) => return None,
172            Self::Max => data_type.clone(),
173            Self::Min if matches!(data_type, DType::Null) => return None,
174            Self::Min => data_type.clone(),
175            Self::NullCount => DType::Primitive(PType::U64, NonNullable),
176            Self::UncompressedSizeInBytes => DType::Primitive(PType::U64, NonNullable),
177            Self::NaNCount => {
178                return aggregate_fn::fns::nan_count::NanCount
179                    .return_dtype(&EmptyOptions, data_type);
180            }
181            Self::Sum => {
182                return aggregate_fn::fns::sum::Sum.return_dtype(&EmptyOptions, data_type);
183            }
184        })
185    }
186
187    pub fn name(&self) -> &str {
188        match self {
189            Self::IsConstant => "is_constant",
190            Self::IsSorted => "is_sorted",
191            Self::IsStrictSorted => "is_strict_sorted",
192            Self::Max => "max",
193            Self::Min => "min",
194            Self::NullCount => "null_count",
195            Self::UncompressedSizeInBytes => "uncompressed_size_in_bytes",
196            Self::Sum => "sum",
197            Self::NaNCount => "nan_count",
198        }
199    }
200
201    pub fn all() -> impl Iterator<Item = Stat> {
202        all::<Self>()
203    }
204}
205
206impl Display for Stat {
207    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
208        write!(f, "{}", self.name())
209    }
210}
211
#[cfg(test)]
mod test {
    use std::collections::HashSet;

    use crate::arrays::PrimitiveArray;
    use crate::expr::stats::Stat;

    #[test]
    fn min_of_nulls_is_not_panic() {
        // An all-null array has no minimum: computing it must yield `None`, not panic.
        let min = PrimitiveArray::from_option_iter::<i32, _>([None, None, None, None])
            .statistics()
            .compute_as::<i64>(Stat::Min);

        assert_eq!(min, None);
    }

    #[test]
    fn has_same_dtype_as_array() {
        assert!(Stat::Min.has_same_dtype_as_array());
        assert!(Stat::Max.has_same_dtype_as_array());
        for stat in Stat::all().filter(|s| !matches!(s, Stat::Min | Stat::Max)) {
            assert!(
                !stat.has_same_dtype_as_array(),
                "{stat} should not share the array dtype"
            );
        }
    }

    #[test]
    fn discriminants_follow_declaration_order() {
        // `Stat::all()` iterates in declaration order; the explicit `u8` discriminants must
        // agree with it and round-trip through the primitive conversions.
        for (idx, stat) in Stat::all().enumerate() {
            let code = u8::from(stat);
            assert_eq!(usize::from(code), idx, "{stat} has an unexpected discriminant");
            assert_eq!(Stat::try_from(code).ok(), Some(stat));
        }
    }

    #[test]
    fn names_are_unique_and_match_display() {
        let mut seen = HashSet::new();
        for stat in Stat::all() {
            assert_eq!(stat.to_string(), stat.name());
            assert!(
                seen.insert(stat.name().to_string()),
                "duplicate stat name {}",
                stat.name()
            );
        }
    }
}