Skip to main content

vortex_array/expr/stats/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5use std::fmt::Display;
6use std::fmt::Formatter;
7
8use enum_iterator::Sequence;
9use enum_iterator::all;
10use num_enum::IntoPrimitive;
11use num_enum::TryFromPrimitive;
12
13use crate::dtype::DType;
14use crate::dtype::Nullability::NonNullable;
15
16mod bound;
17mod precision;
18mod provider;
19mod stat_bound;
20
21pub use bound::*;
22pub use precision::*;
23pub use provider::*;
24pub use stat_bound::*;
25
26use crate::aggregate_fn;
27use crate::aggregate_fn::AggregateFnRef;
28use crate::aggregate_fn::AggregateFnVTable;
29use crate::aggregate_fn::AggregateFnVTableExt;
30use crate::aggregate_fn::EmptyOptions;
31
32#[derive(
33    Debug,
34    Clone,
35    Copy,
36    PartialEq,
37    Eq,
38    PartialOrd,
39    Ord,
40    Hash,
41    Sequence,
42    IntoPrimitive,
43    TryFromPrimitive,
44)]
45#[repr(u8)]
46pub enum Stat {
47    /// Whether all values are the same (nulls are not equal to other non-null values,
48    /// so this is true iff all values are null or all values are the same non-null value)
49    IsConstant = 0,
50    /// Whether the non-null values in the array are sorted in ascending order (i.e., we skip nulls)
51    /// This may later be extended to support descending order, but for now we only support ascending order.
52    IsSorted = 1,
53    /// Whether the non-null values in the array are strictly sorted in ascending order (i.e., sorted with no duplicates)
54    /// This may later be extended to support descending order, but for now we only support ascending order.
55    IsStrictSorted = 2,
56    /// The maximum value in the array (ignoring nulls, unless all values are null)
57    Max = 3,
58    /// The minimum value in the array (ignoring nulls, unless all values are null)
59    Min = 4,
60    /// The sum of the non-null values of the array.
61    Sum = 5,
62    /// The number of null values in the array
63    NullCount = 6,
64    /// The uncompressed size of the array in bytes
65    UncompressedSizeInBytes = 7,
66    /// The number of NaN values in the array
67    NaNCount = 8,
68}
69
// Marker types, one per `Stat` variant. Each one implements `StatType`, which ties the
// `Stat` to its `StatBound` so the bound can be extracted from a `Precision` value.

/// Marker type for [`Stat::Max`].
pub struct Max;

/// Marker type for [`Stat::Min`].
pub struct Min;

/// Marker type for [`Stat::Sum`].
pub struct Sum;

/// Marker type for [`Stat::IsConstant`].
pub struct IsConstant;

/// Marker type for [`Stat::IsSorted`].
pub struct IsSorted;

/// Marker type for [`Stat::IsStrictSorted`].
pub struct IsStrictSorted;

/// Marker type for [`Stat::NullCount`].
pub struct NullCount;

/// Marker type for [`Stat::UncompressedSizeInBytes`].
pub struct UncompressedSizeInBytes;

/// Marker type for [`Stat::NaNCount`].
pub struct NaNCount;
89
90impl StatType<bool> for IsConstant {
91    type Bound = Precision<bool>;
92
93    const STAT: Stat = Stat::IsConstant;
94}
95
96impl StatType<bool> for IsSorted {
97    type Bound = Precision<bool>;
98
99    const STAT: Stat = Stat::IsSorted;
100}
101
102impl StatType<bool> for IsStrictSorted {
103    type Bound = Precision<bool>;
104
105    const STAT: Stat = Stat::IsStrictSorted;
106}
107
108impl<T: PartialOrd + Clone> StatType<T> for NullCount {
109    type Bound = UpperBound<T>;
110
111    const STAT: Stat = Stat::NullCount;
112}
113
114impl<T: PartialOrd + Clone> StatType<T> for UncompressedSizeInBytes {
115    type Bound = UpperBound<T>;
116
117    const STAT: Stat = Stat::UncompressedSizeInBytes;
118}
119
120impl<T: PartialOrd + Clone + Debug> StatType<T> for Max {
121    type Bound = UpperBound<T>;
122
123    const STAT: Stat = Stat::Max;
124}
125
126impl<T: PartialOrd + Clone + Debug> StatType<T> for Min {
127    type Bound = LowerBound<T>;
128
129    const STAT: Stat = Stat::Min;
130}
131
132impl<T: PartialOrd + Clone + Debug> StatType<T> for Sum {
133    type Bound = Precision<T>;
134
135    const STAT: Stat = Stat::Sum;
136}
137
138impl<T: PartialOrd + Clone> StatType<T> for NaNCount {
139    type Bound = UpperBound<T>;
140
141    const STAT: Stat = Stat::NaNCount;
142}
143
144impl Stat {
145    /// Whether the statistic is commutative (i.e., whether merging can be done independently of ordering)
146    /// e.g., min/max are commutative, but is_sorted is not
147    pub fn is_commutative(&self) -> bool {
148        // NOTE: we prefer this syntax to force a compile error if we add a new stat
149        match self {
150            Self::IsConstant
151            | Self::Max
152            | Self::Min
153            | Self::NullCount
154            | Self::Sum
155            | Self::NaNCount
156            | Self::UncompressedSizeInBytes => true,
157            Self::IsSorted | Self::IsStrictSorted => false,
158        }
159    }
160
161    /// Whether the statistic has the same dtype as the array it's computed on
162    pub fn has_same_dtype_as_array(&self) -> bool {
163        matches!(self, Stat::Min | Stat::Max)
164    }
165
166    /// Return the [`DType`] of the statistic scalar assuming the array is of the given [`DType`].
167    pub fn dtype(&self, data_type: &DType) -> Option<DType> {
168        Some(match self {
169            Self::IsConstant => DType::Bool(NonNullable),
170            Self::IsSorted => DType::Bool(NonNullable),
171            Self::IsStrictSorted => DType::Bool(NonNullable),
172            Self::Max if matches!(data_type, DType::Null) => return None,
173            Self::Max => data_type.clone(),
174            Self::Min if matches!(data_type, DType::Null) => return None,
175            Self::Min => data_type.clone(),
176            Self::NullCount => {
177                return aggregate_fn::fns::null_count::NullCount
178                    .return_dtype(&EmptyOptions, data_type);
179            }
180            Self::UncompressedSizeInBytes => {
181                return aggregate_fn::fns::uncompressed_size_in_bytes::UncompressedSizeInBytes
182                    .return_dtype(&EmptyOptions, data_type);
183            }
184            Self::NaNCount => {
185                return aggregate_fn::fns::nan_count::NanCount
186                    .return_dtype(&EmptyOptions, data_type);
187            }
188            Self::Sum => {
189                return aggregate_fn::fns::sum::Sum.return_dtype(&EmptyOptions, data_type);
190            }
191        })
192    }
193
194    /// Return the built-in aggregate function corresponding to this statistic, if one exists.
195    pub fn aggregate_fn(&self) -> Option<AggregateFnRef> {
196        Some(match self {
197            Self::Max => aggregate_fn::fns::max::Max.bind(EmptyOptions),
198            Self::Min => aggregate_fn::fns::min::Min.bind(EmptyOptions),
199            Self::Sum => aggregate_fn::fns::sum::Sum.bind(EmptyOptions),
200            Self::NullCount => aggregate_fn::fns::null_count::NullCount.bind(EmptyOptions),
201            Self::NaNCount => aggregate_fn::fns::nan_count::NanCount.bind(EmptyOptions),
202            Self::UncompressedSizeInBytes => {
203                aggregate_fn::fns::uncompressed_size_in_bytes::UncompressedSizeInBytes
204                    .bind(EmptyOptions)
205            }
206            Self::IsConstant | Self::IsSorted | Self::IsStrictSorted => return None,
207        })
208    }
209
210    /// Return the statistic represented by `aggregate_fn`, if it has a legacy stat slot.
211    pub fn from_aggregate_fn(aggregate_fn: &AggregateFnRef) -> Option<Self> {
212        if aggregate_fn.is::<aggregate_fn::fns::sum::Sum>() {
213            return Some(Self::Sum);
214        }
215        if aggregate_fn.is::<aggregate_fn::fns::nan_count::NanCount>() {
216            return Some(Self::NaNCount);
217        }
218        if aggregate_fn.is::<aggregate_fn::fns::null_count::NullCount>() {
219            return Some(Self::NullCount);
220        }
221        if aggregate_fn.is::<aggregate_fn::fns::min::Min>() {
222            return Some(Self::Min);
223        }
224        if aggregate_fn.is::<aggregate_fn::fns::max::Max>() {
225            return Some(Self::Max);
226        }
227        if aggregate_fn
228            .is::<aggregate_fn::fns::uncompressed_size_in_bytes::UncompressedSizeInBytes>()
229        {
230            return Some(Self::UncompressedSizeInBytes);
231        }
232        None
233    }
234
235    pub fn name(&self) -> &str {
236        match self {
237            Self::IsConstant => "is_constant",
238            Self::IsSorted => "is_sorted",
239            Self::IsStrictSorted => "is_strict_sorted",
240            Self::Max => "max",
241            Self::Min => "min",
242            Self::NullCount => "null_count",
243            Self::UncompressedSizeInBytes => "uncompressed_size_in_bytes",
244            Self::Sum => "sum",
245            Self::NaNCount => "nan_count",
246        }
247    }
248
249    pub fn all() -> impl Iterator<Item = Stat> {
250        all::<Self>()
251    }
252}
253
254impl Display for Stat {
255    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
256        write!(f, "{}", self.name())
257    }
258}
259
#[cfg(test)]
mod test {
    use enum_iterator::all;

    use crate::LEGACY_SESSION;
    use crate::VortexSessionExecute;
    use crate::arrays::PrimitiveArray;
    use crate::expr::stats::Stat;

    /// Computing `Min` over an all-null array must yield `None` rather than panic.
    #[test]
    fn min_of_nulls_is_not_panic() {
        let all_nulls = PrimitiveArray::from_option_iter::<i32, _>([None, None, None, None]);
        let mut ctx = LEGACY_SESSION.create_execution_ctx();

        let min = all_nulls
            .statistics()
            .compute_as::<i64>(Stat::Min, &mut ctx);

        assert_eq!(min, None);
    }

    /// Exactly `Min` and `Max` share the array's dtype; every other stat does not.
    #[test]
    fn has_same_dtype_as_array() {
        for stat in all::<Stat>() {
            let expected = matches!(stat, Stat::Min | Stat::Max);
            assert_eq!(stat.has_same_dtype_as_array(), expected);
        }
    }
}