vortex_array/expr/stats/
mod.rs1use std::fmt::Debug;
5use std::fmt::Display;
6use std::fmt::Formatter;
7
8use enum_iterator::Sequence;
9use enum_iterator::all;
10use num_enum::IntoPrimitive;
11use num_enum::TryFromPrimitive;
12
13use crate::dtype::DType;
14use crate::dtype::Nullability::NonNullable;
15
16mod bound;
17mod precision;
18mod provider;
19mod stat_bound;
20
21pub use bound::*;
22pub use precision::*;
23pub use provider::*;
24pub use stat_bound::*;
25
26use crate::aggregate_fn;
27use crate::aggregate_fn::AggregateFnRef;
28use crate::aggregate_fn::AggregateFnVTable;
29use crate::aggregate_fn::AggregateFnVTableExt;
30use crate::aggregate_fn::EmptyOptions;
31use crate::aggregate_fn::NumericalAggregateOpts;
32
/// Identifies one of the statistics that can be tracked for an array.
///
/// The enum is `#[repr(u8)]` with explicit discriminants, and the `IntoPrimitive` /
/// `TryFromPrimitive` derives convert between a variant and its `u8` code. The `Sequence`
/// derive is what lets [`Stat::all`] enumerate every variant.
// NOTE(review): the explicit discriminants look like a stable on-disk/wire encoding —
// confirm before renumbering or reusing a value.
#[derive(
    Debug,
    Clone,
    Copy,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Sequence,
    IntoPrimitive,
    TryFromPrimitive,
)]
#[repr(u8)]
pub enum Stat {
    /// Boolean statistic (`Stat::dtype` reports a non-nullable `Bool`).
    // NOTE(review): presumably "all values in the array are equal" — confirm.
    IsConstant = 0,
    /// Boolean statistic (`Stat::dtype` reports a non-nullable `Bool`).
    // NOTE(review): presumably "values are in sorted order" — confirm direction/null handling.
    IsSorted = 1,
    /// Boolean statistic (`Stat::dtype` reports a non-nullable `Bool`).
    // NOTE(review): presumably sorted with no duplicates — confirm.
    IsStrictSorted = 2,
    /// Maximum value; has the same dtype as the array and is undefined for `DType::Null`.
    Max = 3,
    /// Minimum value; has the same dtype as the array and is undefined for `DType::Null`.
    Min = 4,
    /// Sum of the values; its dtype is decided by the `sum` aggregate function.
    Sum = 5,
    /// Null count; its dtype is decided by the `null_count` aggregate function.
    NullCount = 6,
    /// Uncompressed size in bytes; its dtype is decided by the
    /// `uncompressed_size_in_bytes` aggregate function.
    UncompressedSizeInBytes = 7,
    /// NaN count; its dtype is decided by the `nan_count` aggregate function.
    NaNCount = 8,
}
70
/// Zero-sized marker type that selects [`Stat::Max`] through its `StatType` implementation.
pub struct Max;
74
/// Zero-sized marker type that selects [`Stat::Min`] through its `StatType` implementation.
pub struct Min;
76
/// Zero-sized marker type that selects [`Stat::Sum`] through its `StatType` implementation.
pub struct Sum;
78
/// Zero-sized marker type that selects [`Stat::IsConstant`] through its `StatType`
/// implementation.
pub struct IsConstant;
80
/// Zero-sized marker type that selects [`Stat::IsSorted`] through its `StatType`
/// implementation.
pub struct IsSorted;
82
/// Zero-sized marker type that selects [`Stat::IsStrictSorted`] through its `StatType`
/// implementation.
pub struct IsStrictSorted;
84
/// Zero-sized marker type that selects [`Stat::NullCount`] through its `StatType`
/// implementation.
pub struct NullCount;
86
/// Zero-sized marker type that selects [`Stat::UncompressedSizeInBytes`] through its
/// `StatType` implementation.
pub struct UncompressedSizeInBytes;
88
/// Zero-sized marker type that selects [`Stat::NaNCount`] through its `StatType`
/// implementation.
pub struct NaNCount;
90
/// Ties the `IsConstant` marker to [`Stat::IsConstant`]; the value is a `bool` and its bound
/// type is `Precision<bool>`.
impl StatType<bool> for IsConstant {
    type Bound = Precision<bool>;

    const STAT: Stat = Stat::IsConstant;
}
96
/// Ties the `IsSorted` marker to [`Stat::IsSorted`]; the value is a `bool` and its bound type
/// is `Precision<bool>`.
impl StatType<bool> for IsSorted {
    type Bound = Precision<bool>;

    const STAT: Stat = Stat::IsSorted;
}
102
/// Ties the `IsStrictSorted` marker to [`Stat::IsStrictSorted`]; the value is a `bool` and its
/// bound type is `Precision<bool>`.
impl StatType<bool> for IsStrictSorted {
    type Bound = Precision<bool>;

    const STAT: Stat = Stat::IsStrictSorted;
}
108
/// Ties the `NullCount` marker to [`Stat::NullCount`] for any ordered, clonable value type `T`.
// NOTE(review): `UpperBound` presumably means an inexact value still bounds the true count
// from above — confirm against the `bound` module.
impl<T: PartialOrd + Clone> StatType<T> for NullCount {
    type Bound = UpperBound<T>;

    const STAT: Stat = Stat::NullCount;
}
114
/// Ties the `UncompressedSizeInBytes` marker to [`Stat::UncompressedSizeInBytes`] for any
/// ordered, clonable value type `T`.
// NOTE(review): `UpperBound` presumably means an inexact value still bounds the true size
// from above — confirm against the `bound` module.
impl<T: PartialOrd + Clone> StatType<T> for UncompressedSizeInBytes {
    type Bound = UpperBound<T>;

    const STAT: Stat = Stat::UncompressedSizeInBytes;
}
120
/// Ties the `Max` marker to [`Stat::Max`] for any ordered, clonable, debuggable value type `T`.
// NOTE(review): `UpperBound` presumably means an inexact maximum is still >= every value —
// confirm against the `bound` module.
impl<T: PartialOrd + Clone + Debug> StatType<T> for Max {
    type Bound = UpperBound<T>;

    const STAT: Stat = Stat::Max;
}
126
/// Ties the `Min` marker to [`Stat::Min`] for any ordered, clonable, debuggable value type `T`.
///
/// This is the only statistic here whose bound type is `LowerBound`.
// NOTE(review): `LowerBound` presumably means an inexact minimum is still <= every value —
// confirm against the `bound` module.
impl<T: PartialOrd + Clone + Debug> StatType<T> for Min {
    type Bound = LowerBound<T>;

    const STAT: Stat = Stat::Min;
}
132
/// Ties the `Sum` marker to [`Stat::Sum`] for any ordered, clonable, debuggable value type `T`.
///
/// Unlike `Min`/`Max`, the bound type is a plain `Precision<T>` rather than a directional
/// bound.
impl<T: PartialOrd + Clone + Debug> StatType<T> for Sum {
    type Bound = Precision<T>;

    const STAT: Stat = Stat::Sum;
}
138
/// Ties the `NaNCount` marker to [`Stat::NaNCount`] for any ordered, clonable value type `T`.
// NOTE(review): `UpperBound` presumably means an inexact value still bounds the true count
// from above — confirm against the `bound` module.
impl<T: PartialOrd + Clone> StatType<T> for NaNCount {
    type Bound = UpperBound<T>;

    const STAT: Stat = Stat::NaNCount;
}
144
145impl Stat {
146 pub fn is_commutative(&self) -> bool {
149 match self {
151 Self::IsConstant
152 | Self::Max
153 | Self::Min
154 | Self::NullCount
155 | Self::Sum
156 | Self::NaNCount
157 | Self::UncompressedSizeInBytes => true,
158 Self::IsSorted | Self::IsStrictSorted => false,
159 }
160 }
161
162 pub fn has_same_dtype_as_array(&self) -> bool {
164 matches!(self, Stat::Min | Stat::Max)
165 }
166
167 pub fn dtype(&self, data_type: &DType) -> Option<DType> {
169 Some(match self {
170 Self::IsConstant => DType::Bool(NonNullable),
171 Self::IsSorted => DType::Bool(NonNullable),
172 Self::IsStrictSorted => DType::Bool(NonNullable),
173 Self::Max if matches!(data_type, DType::Null) => return None,
174 Self::Max => data_type.clone(),
175 Self::Min if matches!(data_type, DType::Null) => return None,
176 Self::Min => data_type.clone(),
177 Self::NullCount => {
178 return aggregate_fn::fns::null_count::NullCount
179 .return_dtype(&EmptyOptions, data_type);
180 }
181 Self::UncompressedSizeInBytes => {
182 return aggregate_fn::fns::uncompressed_size_in_bytes::UncompressedSizeInBytes
183 .return_dtype(&EmptyOptions, data_type);
184 }
185 Self::NaNCount => {
186 return aggregate_fn::fns::nan_count::NanCount
187 .return_dtype(&EmptyOptions, data_type);
188 }
189 Self::Sum => {
190 return aggregate_fn::fns::sum::Sum
192 .return_dtype(&NumericalAggregateOpts::skip_nans(), data_type);
193 }
194 })
195 }
196
197 pub fn aggregate_fn(&self) -> Option<AggregateFnRef> {
199 Some(match self {
201 Self::Max => aggregate_fn::fns::max::Max.bind(NumericalAggregateOpts::skip_nans()),
202 Self::Min => aggregate_fn::fns::min::Min.bind(NumericalAggregateOpts::skip_nans()),
203 Self::Sum => aggregate_fn::fns::sum::Sum.bind(NumericalAggregateOpts::skip_nans()),
204 Self::NullCount => aggregate_fn::fns::null_count::NullCount.bind(EmptyOptions),
205 Self::NaNCount => aggregate_fn::fns::nan_count::NanCount.bind(EmptyOptions),
206 Self::UncompressedSizeInBytes => {
207 aggregate_fn::fns::uncompressed_size_in_bytes::UncompressedSizeInBytes
208 .bind(EmptyOptions)
209 }
210 Self::IsConstant | Self::IsSorted | Self::IsStrictSorted => return None,
211 })
212 }
213
214 pub fn from_aggregate_fn(aggregate_fn: &AggregateFnRef) -> Option<Self> {
219 if let Some(options) = aggregate_fn.as_opt::<aggregate_fn::fns::sum::Sum>() {
220 return options.skip_nans.then_some(Self::Sum);
221 }
222 if aggregate_fn.is::<aggregate_fn::fns::nan_count::NanCount>() {
223 return Some(Self::NaNCount);
224 }
225 if aggregate_fn.is::<aggregate_fn::fns::null_count::NullCount>() {
226 return Some(Self::NullCount);
227 }
228 if let Some(options) = aggregate_fn.as_opt::<aggregate_fn::fns::min::Min>() {
229 return options.skip_nans.then_some(Self::Min);
230 }
231 if let Some(options) = aggregate_fn.as_opt::<aggregate_fn::fns::max::Max>() {
232 return options.skip_nans.then_some(Self::Max);
233 }
234 if aggregate_fn
235 .is::<aggregate_fn::fns::uncompressed_size_in_bytes::UncompressedSizeInBytes>()
236 {
237 return Some(Self::UncompressedSizeInBytes);
238 }
239 None
240 }
241
242 pub fn name(&self) -> &str {
243 match self {
244 Self::IsConstant => "is_constant",
245 Self::IsSorted => "is_sorted",
246 Self::IsStrictSorted => "is_strict_sorted",
247 Self::Max => "max",
248 Self::Min => "min",
249 Self::NullCount => "null_count",
250 Self::UncompressedSizeInBytes => "uncompressed_size_in_bytes",
251 Self::Sum => "sum",
252 Self::NaNCount => "nan_count",
253 }
254 }
255
256 pub fn all() -> impl Iterator<Item = Stat> {
257 all::<Self>()
258 }
259}
260
261impl Display for Stat {
262 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
263 write!(f, "{}", self.name())
264 }
265}
266
#[cfg(test)]
mod test {
    use enum_iterator::all;

    use crate::VortexSessionExecute;
    use crate::array_session;
    use crate::arrays::PrimitiveArray;
    use crate::expr::stats::Stat;

    /// Asking for `Min` over an array that contains only nulls must return `None`
    /// instead of panicking.
    #[test]
    fn min_of_nulls_is_not_panic() {
        let only_nulls = PrimitiveArray::from_option_iter::<i32, _>([None; 4]);
        let min = only_nulls
            .statistics()
            .compute_as::<i64>(Stat::Min, &mut array_session().create_execution_ctx());

        assert_eq!(min, None);
    }

    /// `Min` and `Max` are the only statistics that share the array's dtype.
    #[test]
    fn has_same_dtype_as_array() {
        for stat in all::<Stat>() {
            match stat {
                Stat::Min | Stat::Max => assert!(stat.has_same_dtype_as_array()),
                _ => assert!(!stat.has_same_dtype_as_array()),
            }
        }
    }
}