Skip to main content

vortex_array/expr/stats/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5use std::fmt::Display;
6use std::fmt::Formatter;
7
8use enum_iterator::Sequence;
9use enum_iterator::all;
10use num_enum::IntoPrimitive;
11use num_enum::TryFromPrimitive;
12
13use crate::dtype::DType;
14use crate::dtype::Nullability::NonNullable;
15
16mod bound;
17mod precision;
18mod provider;
19mod stat_bound;
20
21pub use bound::*;
22pub use precision::*;
23pub use provider::*;
24pub use stat_bound::*;
25
26use crate::aggregate_fn;
27use crate::aggregate_fn::AggregateFnRef;
28use crate::aggregate_fn::AggregateFnVTable;
29use crate::aggregate_fn::AggregateFnVTableExt;
30use crate::aggregate_fn::EmptyOptions;
31use crate::aggregate_fn::NumericalAggregateOpts;
32
33#[derive(
34    Debug,
35    Clone,
36    Copy,
37    PartialEq,
38    Eq,
39    PartialOrd,
40    Ord,
41    Hash,
42    Sequence,
43    IntoPrimitive,
44    TryFromPrimitive,
45)]
46#[repr(u8)]
47pub enum Stat {
48    /// Whether all values are the same (nulls are not equal to other non-null values,
49    /// so this is true iff all values are null or all values are the same non-null value)
50    IsConstant = 0,
51    /// Whether the non-null values in the array are sorted in ascending order (i.e., we skip nulls)
52    /// This may later be extended to support descending order, but for now we only support ascending order.
53    IsSorted = 1,
54    /// Whether the non-null values in the array are strictly sorted in ascending order (i.e., sorted with no duplicates)
55    /// This may later be extended to support descending order, but for now we only support ascending order.
56    IsStrictSorted = 2,
57    /// The maximum value in the array (ignoring nulls, unless all values are null)
58    Max = 3,
59    /// The minimum value in the array (ignoring nulls, unless all values are null)
60    Min = 4,
61    /// The sum of the non-null values of the array.
62    Sum = 5,
63    /// The number of null values in the array
64    NullCount = 6,
65    /// The uncompressed size of the array in bytes
66    UncompressedSizeInBytes = 7,
67    /// The number of NaN values in the array
68    NaNCount = 8,
69}
70
// The unit structs below are type-level markers. Each one ties a `Stat` variant to its
// `StatBound` through a `StatType` implementation, which allows the bound to be extracted
// from a `Precision` value.

/// Marker type for [`Stat::Max`]; its `StatType` bound is an `UpperBound`.
#[derive(Debug, Clone, Copy)]
pub struct Max;

/// Marker type for [`Stat::Min`]; its `StatType` bound is a `LowerBound`.
#[derive(Debug, Clone, Copy)]
pub struct Min;

/// Marker type for [`Stat::Sum`]; its `StatType` bound is a `Precision`.
#[derive(Debug, Clone, Copy)]
pub struct Sum;

/// Marker type for [`Stat::IsConstant`]; its `StatType` bound is a `Precision<bool>`.
#[derive(Debug, Clone, Copy)]
pub struct IsConstant;

/// Marker type for [`Stat::IsSorted`]; its `StatType` bound is a `Precision<bool>`.
#[derive(Debug, Clone, Copy)]
pub struct IsSorted;

/// Marker type for [`Stat::IsStrictSorted`]; its `StatType` bound is a `Precision<bool>`.
#[derive(Debug, Clone, Copy)]
pub struct IsStrictSorted;

/// Marker type for [`Stat::NullCount`]; its `StatType` bound is an `UpperBound`.
#[derive(Debug, Clone, Copy)]
pub struct NullCount;

/// Marker type for [`Stat::UncompressedSizeInBytes`]; its `StatType` bound is an `UpperBound`.
#[derive(Debug, Clone, Copy)]
pub struct UncompressedSizeInBytes;

/// Marker type for [`Stat::NaNCount`]; its `StatType` bound is an `UpperBound`.
#[derive(Debug, Clone, Copy)]
pub struct NaNCount;
90
91impl StatType<bool> for IsConstant {
92    type Bound = Precision<bool>;
93
94    const STAT: Stat = Stat::IsConstant;
95}
96
97impl StatType<bool> for IsSorted {
98    type Bound = Precision<bool>;
99
100    const STAT: Stat = Stat::IsSorted;
101}
102
103impl StatType<bool> for IsStrictSorted {
104    type Bound = Precision<bool>;
105
106    const STAT: Stat = Stat::IsStrictSorted;
107}
108
109impl<T: PartialOrd + Clone> StatType<T> for NullCount {
110    type Bound = UpperBound<T>;
111
112    const STAT: Stat = Stat::NullCount;
113}
114
115impl<T: PartialOrd + Clone> StatType<T> for UncompressedSizeInBytes {
116    type Bound = UpperBound<T>;
117
118    const STAT: Stat = Stat::UncompressedSizeInBytes;
119}
120
121impl<T: PartialOrd + Clone + Debug> StatType<T> for Max {
122    type Bound = UpperBound<T>;
123
124    const STAT: Stat = Stat::Max;
125}
126
127impl<T: PartialOrd + Clone + Debug> StatType<T> for Min {
128    type Bound = LowerBound<T>;
129
130    const STAT: Stat = Stat::Min;
131}
132
133impl<T: PartialOrd + Clone + Debug> StatType<T> for Sum {
134    type Bound = Precision<T>;
135
136    const STAT: Stat = Stat::Sum;
137}
138
139impl<T: PartialOrd + Clone> StatType<T> for NaNCount {
140    type Bound = UpperBound<T>;
141
142    const STAT: Stat = Stat::NaNCount;
143}
144
145impl Stat {
146    /// Whether the statistic is commutative (i.e., whether merging can be done independently of ordering)
147    /// e.g., min/max are commutative, but is_sorted is not
148    pub fn is_commutative(&self) -> bool {
149        // NOTE: we prefer this syntax to force a compile error if we add a new stat
150        match self {
151            Self::IsConstant
152            | Self::Max
153            | Self::Min
154            | Self::NullCount
155            | Self::Sum
156            | Self::NaNCount
157            | Self::UncompressedSizeInBytes => true,
158            Self::IsSorted | Self::IsStrictSorted => false,
159        }
160    }
161
162    /// Whether the statistic has the same dtype as the array it's computed on
163    pub fn has_same_dtype_as_array(&self) -> bool {
164        matches!(self, Stat::Min | Stat::Max)
165    }
166
167    /// Return the [`DType`] of the statistic scalar assuming the array is of the given [`DType`].
168    pub fn dtype(&self, data_type: &DType) -> Option<DType> {
169        Some(match self {
170            Self::IsConstant => DType::Bool(NonNullable),
171            Self::IsSorted => DType::Bool(NonNullable),
172            Self::IsStrictSorted => DType::Bool(NonNullable),
173            Self::Max if matches!(data_type, DType::Null) => return None,
174            Self::Max => data_type.clone(),
175            Self::Min if matches!(data_type, DType::Null) => return None,
176            Self::Min => data_type.clone(),
177            Self::NullCount => {
178                return aggregate_fn::fns::null_count::NullCount
179                    .return_dtype(&EmptyOptions, data_type);
180            }
181            Self::UncompressedSizeInBytes => {
182                return aggregate_fn::fns::uncompressed_size_in_bytes::UncompressedSizeInBytes
183                    .return_dtype(&EmptyOptions, data_type);
184            }
185            Self::NaNCount => {
186                return aggregate_fn::fns::nan_count::NanCount
187                    .return_dtype(&EmptyOptions, data_type);
188            }
189            Self::Sum => {
190                // Statistics follow NaN-skipping semantics; request it explicitly.
191                return aggregate_fn::fns::sum::Sum
192                    .return_dtype(&NumericalAggregateOpts::skip_nans(), data_type);
193            }
194        })
195    }
196
197    /// Return the built-in aggregate function corresponding to this statistic, if one exists.
198    pub fn aggregate_fn(&self) -> Option<AggregateFnRef> {
199        // Statistics follow NaN-skipping semantics; request it explicitly rather than the default.
200        Some(match self {
201            Self::Max => aggregate_fn::fns::max::Max.bind(NumericalAggregateOpts::skip_nans()),
202            Self::Min => aggregate_fn::fns::min::Min.bind(NumericalAggregateOpts::skip_nans()),
203            Self::Sum => aggregate_fn::fns::sum::Sum.bind(NumericalAggregateOpts::skip_nans()),
204            Self::NullCount => aggregate_fn::fns::null_count::NullCount.bind(EmptyOptions),
205            Self::NaNCount => aggregate_fn::fns::nan_count::NanCount.bind(EmptyOptions),
206            Self::UncompressedSizeInBytes => {
207                aggregate_fn::fns::uncompressed_size_in_bytes::UncompressedSizeInBytes
208                    .bind(EmptyOptions)
209            }
210            Self::IsConstant | Self::IsSorted | Self::IsStrictSorted => return None,
211        })
212    }
213
214    /// Return the statistic represented by `aggregate_fn`, if it has a legacy stat slot.
215    ///
216    /// Min/max/sum statistics skip NaN values, so NaN-including configurations of those
217    /// aggregates have no stat slot.
218    pub fn from_aggregate_fn(aggregate_fn: &AggregateFnRef) -> Option<Self> {
219        if let Some(options) = aggregate_fn.as_opt::<aggregate_fn::fns::sum::Sum>() {
220            return options.skip_nans.then_some(Self::Sum);
221        }
222        if aggregate_fn.is::<aggregate_fn::fns::nan_count::NanCount>() {
223            return Some(Self::NaNCount);
224        }
225        if aggregate_fn.is::<aggregate_fn::fns::null_count::NullCount>() {
226            return Some(Self::NullCount);
227        }
228        if let Some(options) = aggregate_fn.as_opt::<aggregate_fn::fns::min::Min>() {
229            return options.skip_nans.then_some(Self::Min);
230        }
231        if let Some(options) = aggregate_fn.as_opt::<aggregate_fn::fns::max::Max>() {
232            return options.skip_nans.then_some(Self::Max);
233        }
234        if aggregate_fn
235            .is::<aggregate_fn::fns::uncompressed_size_in_bytes::UncompressedSizeInBytes>()
236        {
237            return Some(Self::UncompressedSizeInBytes);
238        }
239        None
240    }
241
242    pub fn name(&self) -> &str {
243        match self {
244            Self::IsConstant => "is_constant",
245            Self::IsSorted => "is_sorted",
246            Self::IsStrictSorted => "is_strict_sorted",
247            Self::Max => "max",
248            Self::Min => "min",
249            Self::NullCount => "null_count",
250            Self::UncompressedSizeInBytes => "uncompressed_size_in_bytes",
251            Self::Sum => "sum",
252            Self::NaNCount => "nan_count",
253        }
254    }
255
256    pub fn all() -> impl Iterator<Item = Stat> {
257        all::<Self>()
258    }
259}
260
261impl Display for Stat {
262    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
263        write!(f, "{}", self.name())
264    }
265}
266
#[cfg(test)]
mod test {
    use enum_iterator::all;

    use crate::VortexSessionExecute;
    use crate::array_session;
    use crate::arrays::PrimitiveArray;
    use crate::expr::stats::Stat;

    /// Computing `Min` over an all-null array must produce `None` rather than panic.
    #[test]
    fn min_of_nulls_is_not_panic() {
        let all_nulls = PrimitiveArray::from_option_iter::<i32, _>([None, None, None, None]);

        let min = all_nulls
            .statistics()
            .compute_as::<i64>(Stat::Min, &mut array_session().create_execution_ctx());

        assert_eq!(min, None);
    }

    /// `Min` and `Max` are the only stats sharing the array's dtype.
    #[test]
    fn has_same_dtype_as_array() {
        for stat in all::<Stat>() {
            let expected = matches!(stat, Stat::Min | Stat::Max);
            assert_eq!(stat.has_same_dtype_as_array(), expected);
        }
    }
}