vortex_array/expr/stats/
mod.rs1use std::fmt::Debug;
5use std::fmt::Display;
6use std::fmt::Formatter;
7
8use enum_iterator::Sequence;
9use enum_iterator::all;
10use num_enum::IntoPrimitive;
11use num_enum::TryFromPrimitive;
12
13use crate::dtype::DType;
14use crate::dtype::Nullability::NonNullable;
15
16mod bound;
17mod precision;
18mod provider;
19mod stat_bound;
20
21pub use bound::*;
22pub use precision::*;
23pub use provider::*;
24pub use stat_bound::*;
25
26use crate::aggregate_fn;
27use crate::aggregate_fn::AggregateFnRef;
28use crate::aggregate_fn::AggregateFnVTable;
29use crate::aggregate_fn::AggregateFnVTableExt;
30use crate::aggregate_fn::EmptyOptions;
31
32#[derive(
33 Debug,
34 Clone,
35 Copy,
36 PartialEq,
37 Eq,
38 PartialOrd,
39 Ord,
40 Hash,
41 Sequence,
42 IntoPrimitive,
43 TryFromPrimitive,
44)]
45#[repr(u8)]
46pub enum Stat {
47 IsConstant = 0,
50 IsSorted = 1,
53 IsStrictSorted = 2,
56 Max = 3,
58 Min = 4,
60 Sum = 5,
62 NullCount = 6,
64 UncompressedSizeInBytes = 7,
66 NaNCount = 8,
68}
69
// Unit marker types, one per `Stat` variant. The `StatType` impls in this module tie each
// marker to its `Stat` variant (`STAT`) and to the wrapper type used for its value (`Bound`).

/// Marker type for [`Stat::IsConstant`].
pub struct IsConstant;

/// Marker type for [`Stat::IsSorted`].
pub struct IsSorted;

/// Marker type for [`Stat::IsStrictSorted`].
pub struct IsStrictSorted;

/// Marker type for [`Stat::Max`].
pub struct Max;

/// Marker type for [`Stat::Min`].
pub struct Min;

/// Marker type for [`Stat::Sum`].
pub struct Sum;

/// Marker type for [`Stat::NullCount`].
pub struct NullCount;

/// Marker type for [`Stat::UncompressedSizeInBytes`].
pub struct UncompressedSizeInBytes;

/// Marker type for [`Stat::NaNCount`].
pub struct NaNCount;
89
90impl StatType<bool> for IsConstant {
91 type Bound = Precision<bool>;
92
93 const STAT: Stat = Stat::IsConstant;
94}
95
96impl StatType<bool> for IsSorted {
97 type Bound = Precision<bool>;
98
99 const STAT: Stat = Stat::IsSorted;
100}
101
102impl StatType<bool> for IsStrictSorted {
103 type Bound = Precision<bool>;
104
105 const STAT: Stat = Stat::IsStrictSorted;
106}
107
108impl<T: PartialOrd + Clone> StatType<T> for NullCount {
109 type Bound = UpperBound<T>;
110
111 const STAT: Stat = Stat::NullCount;
112}
113
114impl<T: PartialOrd + Clone> StatType<T> for UncompressedSizeInBytes {
115 type Bound = UpperBound<T>;
116
117 const STAT: Stat = Stat::UncompressedSizeInBytes;
118}
119
120impl<T: PartialOrd + Clone + Debug> StatType<T> for Max {
121 type Bound = UpperBound<T>;
122
123 const STAT: Stat = Stat::Max;
124}
125
126impl<T: PartialOrd + Clone + Debug> StatType<T> for Min {
127 type Bound = LowerBound<T>;
128
129 const STAT: Stat = Stat::Min;
130}
131
132impl<T: PartialOrd + Clone + Debug> StatType<T> for Sum {
133 type Bound = Precision<T>;
134
135 const STAT: Stat = Stat::Sum;
136}
137
138impl<T: PartialOrd + Clone> StatType<T> for NaNCount {
139 type Bound = UpperBound<T>;
140
141 const STAT: Stat = Stat::NaNCount;
142}
143
144impl Stat {
145 pub fn is_commutative(&self) -> bool {
148 match self {
150 Self::IsConstant
151 | Self::Max
152 | Self::Min
153 | Self::NullCount
154 | Self::Sum
155 | Self::NaNCount
156 | Self::UncompressedSizeInBytes => true,
157 Self::IsSorted | Self::IsStrictSorted => false,
158 }
159 }
160
161 pub fn has_same_dtype_as_array(&self) -> bool {
163 matches!(self, Stat::Min | Stat::Max)
164 }
165
166 pub fn dtype(&self, data_type: &DType) -> Option<DType> {
168 Some(match self {
169 Self::IsConstant => DType::Bool(NonNullable),
170 Self::IsSorted => DType::Bool(NonNullable),
171 Self::IsStrictSorted => DType::Bool(NonNullable),
172 Self::Max if matches!(data_type, DType::Null) => return None,
173 Self::Max => data_type.clone(),
174 Self::Min if matches!(data_type, DType::Null) => return None,
175 Self::Min => data_type.clone(),
176 Self::NullCount => {
177 return aggregate_fn::fns::null_count::NullCount
178 .return_dtype(&EmptyOptions, data_type);
179 }
180 Self::UncompressedSizeInBytes => {
181 return aggregate_fn::fns::uncompressed_size_in_bytes::UncompressedSizeInBytes
182 .return_dtype(&EmptyOptions, data_type);
183 }
184 Self::NaNCount => {
185 return aggregate_fn::fns::nan_count::NanCount
186 .return_dtype(&EmptyOptions, data_type);
187 }
188 Self::Sum => {
189 return aggregate_fn::fns::sum::Sum.return_dtype(&EmptyOptions, data_type);
190 }
191 })
192 }
193
194 pub fn aggregate_fn(&self) -> Option<AggregateFnRef> {
196 Some(match self {
197 Self::Sum => aggregate_fn::fns::sum::Sum.bind(EmptyOptions),
198 Self::NullCount => aggregate_fn::fns::null_count::NullCount.bind(EmptyOptions),
199 Self::NaNCount => aggregate_fn::fns::nan_count::NanCount.bind(EmptyOptions),
200 Self::UncompressedSizeInBytes => {
201 aggregate_fn::fns::uncompressed_size_in_bytes::UncompressedSizeInBytes
202 .bind(EmptyOptions)
203 }
204 Self::IsConstant | Self::IsSorted | Self::IsStrictSorted | Self::Max | Self::Min => {
205 return None;
206 }
207 })
208 }
209
210 pub fn from_aggregate_fn(aggregate_fn: &AggregateFnRef) -> Option<Self> {
212 if aggregate_fn.is::<aggregate_fn::fns::sum::Sum>() {
213 return Some(Self::Sum);
214 }
215 if aggregate_fn.is::<aggregate_fn::fns::nan_count::NanCount>() {
216 return Some(Self::NaNCount);
217 }
218 if aggregate_fn.is::<aggregate_fn::fns::null_count::NullCount>() {
219 return Some(Self::NullCount);
220 }
221 if aggregate_fn
222 .is::<aggregate_fn::fns::uncompressed_size_in_bytes::UncompressedSizeInBytes>()
223 {
224 return Some(Self::UncompressedSizeInBytes);
225 }
226 None
227 }
228
229 pub fn name(&self) -> &str {
230 match self {
231 Self::IsConstant => "is_constant",
232 Self::IsSorted => "is_sorted",
233 Self::IsStrictSorted => "is_strict_sorted",
234 Self::Max => "max",
235 Self::Min => "min",
236 Self::NullCount => "null_count",
237 Self::UncompressedSizeInBytes => "uncompressed_size_in_bytes",
238 Self::Sum => "sum",
239 Self::NaNCount => "nan_count",
240 }
241 }
242
243 pub fn all() -> impl Iterator<Item = Stat> {
244 all::<Self>()
245 }
246}
247
248impl Display for Stat {
249 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
250 write!(f, "{}", self.name())
251 }
252}
253
#[cfg(test)]
mod test {
    use crate::LEGACY_SESSION;
    use crate::VortexSessionExecute;
    use crate::arrays::PrimitiveArray;
    use crate::expr::stats::Stat;

    /// Computing `Min` over an all-null array must yield `None` rather than panic.
    #[test]
    fn min_of_nulls_is_not_panic() {
        let all_null = PrimitiveArray::from_option_iter::<i32, _>([None, None, None, None]);
        let mut ctx = LEGACY_SESSION.create_execution_ctx();

        let min = all_null.statistics().compute_as::<i64>(Stat::Min, &mut ctx);

        assert_eq!(min, None);
    }

    /// Only `Min` and `Max` report the same dtype as the array; every other stat must not.
    #[test]
    fn has_same_dtype_as_array() {
        for stat in Stat::all() {
            let expected = matches!(stat, Stat::Min | Stat::Max);
            assert_eq!(stat.has_same_dtype_as_array(), expected);
        }
    }
}