Skip to main content

vortex_array/expr/stats/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5use std::fmt::Display;
6use std::fmt::Formatter;
7
8use enum_iterator::Sequence;
9use enum_iterator::all;
10use num_enum::IntoPrimitive;
11use num_enum::TryFromPrimitive;
12
13use crate::dtype::DType;
14use crate::dtype::Nullability::NonNullable;
15
16mod bound;
17mod precision;
18mod provider;
19mod stat_bound;
20
21pub use bound::*;
22pub use precision::*;
23pub use provider::*;
24pub use stat_bound::*;
25
26use crate::aggregate_fn;
27use crate::aggregate_fn::AggregateFnRef;
28use crate::aggregate_fn::AggregateFnVTable;
29use crate::aggregate_fn::AggregateFnVTableExt;
30use crate::aggregate_fn::EmptyOptions;
31
32#[derive(
33    Debug,
34    Clone,
35    Copy,
36    PartialEq,
37    Eq,
38    PartialOrd,
39    Ord,
40    Hash,
41    Sequence,
42    IntoPrimitive,
43    TryFromPrimitive,
44)]
45#[repr(u8)]
46pub enum Stat {
47    /// Whether all values are the same (nulls are not equal to other non-null values,
48    /// so this is true iff all values are null or all values are the same non-null value)
49    IsConstant = 0,
50    /// Whether the non-null values in the array are sorted in ascending order (i.e., we skip nulls)
51    /// This may later be extended to support descending order, but for now we only support ascending order.
52    IsSorted = 1,
53    /// Whether the non-null values in the array are strictly sorted in ascending order (i.e., sorted with no duplicates)
54    /// This may later be extended to support descending order, but for now we only support ascending order.
55    IsStrictSorted = 2,
56    /// The maximum value in the array (ignoring nulls, unless all values are null)
57    Max = 3,
58    /// The minimum value in the array (ignoring nulls, unless all values are null)
59    Min = 4,
60    /// The sum of the non-null values of the array.
61    Sum = 5,
62    /// The number of null values in the array
63    NullCount = 6,
64    /// The uncompressed size of the array in bytes
65    UncompressedSizeInBytes = 7,
66    /// The number of NaN values in the array
67    NaNCount = 8,
68}
69
// Marker types, one per `Stat` variant. Each implements `StatType`, which ties the `Stat`
// to its `StatBound` so the bound can be extracted from a `Precision` value.
//
// They are zero-sized, so `Debug`, `Clone` and `Copy` are derived to keep them usable in
// generic and diagnostic contexts at no cost.

/// Marker type for [`Stat::Max`].
#[derive(Debug, Clone, Copy)]
pub struct Max;

/// Marker type for [`Stat::Min`].
#[derive(Debug, Clone, Copy)]
pub struct Min;

/// Marker type for [`Stat::Sum`].
#[derive(Debug, Clone, Copy)]
pub struct Sum;

/// Marker type for [`Stat::IsConstant`].
#[derive(Debug, Clone, Copy)]
pub struct IsConstant;

/// Marker type for [`Stat::IsSorted`].
#[derive(Debug, Clone, Copy)]
pub struct IsSorted;

/// Marker type for [`Stat::IsStrictSorted`].
#[derive(Debug, Clone, Copy)]
pub struct IsStrictSorted;

/// Marker type for [`Stat::NullCount`].
#[derive(Debug, Clone, Copy)]
pub struct NullCount;

/// Marker type for [`Stat::UncompressedSizeInBytes`].
#[derive(Debug, Clone, Copy)]
pub struct UncompressedSizeInBytes;

/// Marker type for [`Stat::NaNCount`].
#[derive(Debug, Clone, Copy)]
pub struct NaNCount;
89
90impl StatType<bool> for IsConstant {
91    type Bound = Precision<bool>;
92
93    const STAT: Stat = Stat::IsConstant;
94}
95
96impl StatType<bool> for IsSorted {
97    type Bound = Precision<bool>;
98
99    const STAT: Stat = Stat::IsSorted;
100}
101
102impl StatType<bool> for IsStrictSorted {
103    type Bound = Precision<bool>;
104
105    const STAT: Stat = Stat::IsStrictSorted;
106}
107
108impl<T: PartialOrd + Clone> StatType<T> for NullCount {
109    type Bound = UpperBound<T>;
110
111    const STAT: Stat = Stat::NullCount;
112}
113
114impl<T: PartialOrd + Clone> StatType<T> for UncompressedSizeInBytes {
115    type Bound = UpperBound<T>;
116
117    const STAT: Stat = Stat::UncompressedSizeInBytes;
118}
119
120impl<T: PartialOrd + Clone + Debug> StatType<T> for Max {
121    type Bound = UpperBound<T>;
122
123    const STAT: Stat = Stat::Max;
124}
125
126impl<T: PartialOrd + Clone + Debug> StatType<T> for Min {
127    type Bound = LowerBound<T>;
128
129    const STAT: Stat = Stat::Min;
130}
131
132impl<T: PartialOrd + Clone + Debug> StatType<T> for Sum {
133    type Bound = Precision<T>;
134
135    const STAT: Stat = Stat::Sum;
136}
137
138impl<T: PartialOrd + Clone> StatType<T> for NaNCount {
139    type Bound = UpperBound<T>;
140
141    const STAT: Stat = Stat::NaNCount;
142}
143
144impl Stat {
145    /// Whether the statistic is commutative (i.e., whether merging can be done independently of ordering)
146    /// e.g., min/max are commutative, but is_sorted is not
147    pub fn is_commutative(&self) -> bool {
148        // NOTE: we prefer this syntax to force a compile error if we add a new stat
149        match self {
150            Self::IsConstant
151            | Self::Max
152            | Self::Min
153            | Self::NullCount
154            | Self::Sum
155            | Self::NaNCount
156            | Self::UncompressedSizeInBytes => true,
157            Self::IsSorted | Self::IsStrictSorted => false,
158        }
159    }
160
161    /// Whether the statistic has the same dtype as the array it's computed on
162    pub fn has_same_dtype_as_array(&self) -> bool {
163        matches!(self, Stat::Min | Stat::Max)
164    }
165
166    /// Return the [`DType`] of the statistic scalar assuming the array is of the given [`DType`].
167    pub fn dtype(&self, data_type: &DType) -> Option<DType> {
168        Some(match self {
169            Self::IsConstant => DType::Bool(NonNullable),
170            Self::IsSorted => DType::Bool(NonNullable),
171            Self::IsStrictSorted => DType::Bool(NonNullable),
172            Self::Max if matches!(data_type, DType::Null) => return None,
173            Self::Max => data_type.clone(),
174            Self::Min if matches!(data_type, DType::Null) => return None,
175            Self::Min => data_type.clone(),
176            Self::NullCount => {
177                return aggregate_fn::fns::null_count::NullCount
178                    .return_dtype(&EmptyOptions, data_type);
179            }
180            Self::UncompressedSizeInBytes => {
181                return aggregate_fn::fns::uncompressed_size_in_bytes::UncompressedSizeInBytes
182                    .return_dtype(&EmptyOptions, data_type);
183            }
184            Self::NaNCount => {
185                return aggregate_fn::fns::nan_count::NanCount
186                    .return_dtype(&EmptyOptions, data_type);
187            }
188            Self::Sum => {
189                return aggregate_fn::fns::sum::Sum.return_dtype(&EmptyOptions, data_type);
190            }
191        })
192    }
193
194    /// Return the built-in aggregate function corresponding to this statistic, if one exists.
195    pub fn aggregate_fn(&self) -> Option<AggregateFnRef> {
196        Some(match self {
197            Self::Sum => aggregate_fn::fns::sum::Sum.bind(EmptyOptions),
198            Self::NullCount => aggregate_fn::fns::null_count::NullCount.bind(EmptyOptions),
199            Self::NaNCount => aggregate_fn::fns::nan_count::NanCount.bind(EmptyOptions),
200            Self::UncompressedSizeInBytes => {
201                aggregate_fn::fns::uncompressed_size_in_bytes::UncompressedSizeInBytes
202                    .bind(EmptyOptions)
203            }
204            Self::IsConstant | Self::IsSorted | Self::IsStrictSorted | Self::Max | Self::Min => {
205                return None;
206            }
207        })
208    }
209
210    /// Return the statistic represented by `aggregate_fn`, if it has a legacy stat slot.
211    pub fn from_aggregate_fn(aggregate_fn: &AggregateFnRef) -> Option<Self> {
212        if aggregate_fn.is::<aggregate_fn::fns::sum::Sum>() {
213            return Some(Self::Sum);
214        }
215        if aggregate_fn.is::<aggregate_fn::fns::nan_count::NanCount>() {
216            return Some(Self::NaNCount);
217        }
218        if aggregate_fn.is::<aggregate_fn::fns::null_count::NullCount>() {
219            return Some(Self::NullCount);
220        }
221        if aggregate_fn
222            .is::<aggregate_fn::fns::uncompressed_size_in_bytes::UncompressedSizeInBytes>()
223        {
224            return Some(Self::UncompressedSizeInBytes);
225        }
226        None
227    }
228
229    pub fn name(&self) -> &str {
230        match self {
231            Self::IsConstant => "is_constant",
232            Self::IsSorted => "is_sorted",
233            Self::IsStrictSorted => "is_strict_sorted",
234            Self::Max => "max",
235            Self::Min => "min",
236            Self::NullCount => "null_count",
237            Self::UncompressedSizeInBytes => "uncompressed_size_in_bytes",
238            Self::Sum => "sum",
239            Self::NaNCount => "nan_count",
240        }
241    }
242
243    pub fn all() -> impl Iterator<Item = Stat> {
244        all::<Self>()
245    }
246}
247
248impl Display for Stat {
249    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
250        write!(f, "{}", self.name())
251    }
252}
253
#[cfg(test)]
mod test {
    use enum_iterator::all;

    use crate::LEGACY_SESSION;
    use crate::VortexSessionExecute;
    use crate::arrays::PrimitiveArray;
    use crate::expr::stats::Stat;

    /// Computing `Min` over an all-null array must return `None` rather than panic.
    #[test]
    fn min_of_nulls_is_not_panic() {
        let all_nulls = PrimitiveArray::from_option_iter::<i32, _>([None, None, None, None]);
        let mut ctx = LEGACY_SESSION.create_execution_ctx();

        let min = all_nulls
            .statistics()
            .compute_as::<i64>(Stat::Min, &mut ctx);

        assert_eq!(min, None);
    }

    /// Only `Min` and `Max` report the same dtype as the array; every other stat must not.
    #[test]
    fn has_same_dtype_as_array() {
        for stat in all::<Stat>() {
            let expected = matches!(stat, Stat::Min | Stat::Max);
            assert_eq!(stat.has_same_dtype_as_array(), expected);
        }
    }
}