vortex_array/expr/stats/
mod.rs1use std::fmt::Debug;
5use std::fmt::Display;
6use std::fmt::Formatter;
7
8use enum_iterator::Sequence;
9use enum_iterator::all;
10use num_enum::IntoPrimitive;
11use num_enum::TryFromPrimitive;
12use vortex_dtype::DType;
13use vortex_dtype::DecimalDType;
14use vortex_dtype::MAX_PRECISION;
15use vortex_dtype::Nullability::NonNullable;
16use vortex_dtype::Nullability::Nullable;
17use vortex_dtype::PType;
18
19mod bound;
20mod precision;
21mod provider;
22mod stat_bound;
23
24pub use bound::*;
25pub use precision::*;
26pub use provider::*;
27pub use stat_bound::*;
28
29#[derive(
30 Debug,
31 Clone,
32 Copy,
33 PartialEq,
34 Eq,
35 PartialOrd,
36 Ord,
37 Hash,
38 Sequence,
39 IntoPrimitive,
40 TryFromPrimitive,
41)]
42#[repr(u8)]
43pub enum Stat {
44 IsConstant = 0,
47 IsSorted = 1,
50 IsStrictSorted = 2,
53 Max = 3,
55 Min = 4,
57 Sum = 5,
59 NullCount = 6,
61 UncompressedSizeInBytes = 7,
63 NaNCount = 8,
65}
66
/// Zero-sized marker for [`Stat::IsConstant`]; its `StatType` impl binds it to that stat.
pub struct IsConstant;

/// Zero-sized marker for [`Stat::IsSorted`]; its `StatType` impl binds it to that stat.
pub struct IsSorted;

/// Zero-sized marker for [`Stat::IsStrictSorted`]; its `StatType` impl binds it to that stat.
pub struct IsStrictSorted;

/// Zero-sized marker for [`Stat::Max`]; its `StatType` impl binds it to that stat.
pub struct Max;

/// Zero-sized marker for [`Stat::Min`]; its `StatType` impl binds it to that stat.
pub struct Min;

/// Zero-sized marker for [`Stat::Sum`]; its `StatType` impl binds it to that stat.
pub struct Sum;

/// Zero-sized marker for [`Stat::NullCount`]; its `StatType` impl binds it to that stat.
pub struct NullCount;

/// Zero-sized marker for [`Stat::UncompressedSizeInBytes`]; its `StatType` impl binds it to
/// that stat.
pub struct UncompressedSizeInBytes;

/// Zero-sized marker for [`Stat::NaNCount`]; its `StatType` impl binds it to that stat.
pub struct NaNCount;
86
87impl StatType<bool> for IsConstant {
88 type Bound = Precision<bool>;
89
90 const STAT: Stat = Stat::IsConstant;
91}
92
93impl StatType<bool> for IsSorted {
94 type Bound = Precision<bool>;
95
96 const STAT: Stat = Stat::IsSorted;
97}
98
99impl StatType<bool> for IsStrictSorted {
100 type Bound = Precision<bool>;
101
102 const STAT: Stat = Stat::IsStrictSorted;
103}
104
105impl<T: PartialOrd + Clone> StatType<T> for NullCount {
106 type Bound = UpperBound<T>;
107
108 const STAT: Stat = Stat::NullCount;
109}
110
111impl<T: PartialOrd + Clone> StatType<T> for UncompressedSizeInBytes {
112 type Bound = UpperBound<T>;
113
114 const STAT: Stat = Stat::UncompressedSizeInBytes;
115}
116
117impl<T: PartialOrd + Clone + Debug> StatType<T> for Max {
118 type Bound = UpperBound<T>;
119
120 const STAT: Stat = Stat::Max;
121}
122
123impl<T: PartialOrd + Clone + Debug> StatType<T> for Min {
124 type Bound = LowerBound<T>;
125
126 const STAT: Stat = Stat::Min;
127}
128
129impl<T: PartialOrd + Clone + Debug> StatType<T> for Sum {
130 type Bound = Precision<T>;
131
132 const STAT: Stat = Stat::Sum;
133}
134
135impl<T: PartialOrd + Clone> StatType<T> for NaNCount {
136 type Bound = UpperBound<T>;
137
138 const STAT: Stat = Stat::NaNCount;
139}
140
141impl Stat {
142 pub fn is_commutative(&self) -> bool {
145 match self {
147 Self::IsConstant
148 | Self::Max
149 | Self::Min
150 | Self::NullCount
151 | Self::Sum
152 | Self::NaNCount
153 | Self::UncompressedSizeInBytes => true,
154 Self::IsSorted | Self::IsStrictSorted => false,
155 }
156 }
157
158 pub fn has_same_dtype_as_array(&self) -> bool {
160 matches!(self, Stat::Min | Stat::Max)
161 }
162
163 pub fn dtype(&self, data_type: &DType) -> Option<DType> {
165 Some(match self {
166 Self::IsConstant => DType::Bool(NonNullable),
167 Self::IsSorted => DType::Bool(NonNullable),
168 Self::IsStrictSorted => DType::Bool(NonNullable),
169 Self::Max if matches!(data_type, DType::Null) => return None,
170 Self::Max => data_type.clone(),
171 Self::Min if matches!(data_type, DType::Null) => return None,
172 Self::Min => data_type.clone(),
173 Self::NullCount => DType::Primitive(PType::U64, NonNullable),
174 Self::UncompressedSizeInBytes => DType::Primitive(PType::U64, NonNullable),
175 Self::NaNCount => {
176 if let DType::Primitive(ptype, ..) = data_type
178 && ptype.is_float()
179 {
180 DType::Primitive(PType::U64, NonNullable)
181 } else {
182 return None;
183 }
184 }
185 Self::Sum => {
186 match data_type {
190 DType::Bool(_) => DType::Primitive(PType::U64, Nullable),
191 DType::Primitive(ptype, _) => match ptype {
192 PType::U8 | PType::U16 | PType::U32 | PType::U64 => {
193 DType::Primitive(PType::U64, Nullable)
194 }
195 PType::I8 | PType::I16 | PType::I32 | PType::I64 => {
196 DType::Primitive(PType::I64, Nullable)
197 }
198 PType::F16 | PType::F32 | PType::F64 => {
199 DType::Primitive(PType::F64, Nullable)
201 }
202 },
203 DType::Extension(ext_dtype) => self.dtype(ext_dtype.storage_dtype())?,
204 DType::Decimal(decimal_dtype, _) => {
205 let precision = u8::min(MAX_PRECISION, decimal_dtype.precision() + 10);
209 DType::Decimal(
210 DecimalDType::new(precision, decimal_dtype.scale()),
211 Nullable,
212 )
213 }
214 _ => return None,
216 }
217 }
218 })
219 }
220
221 pub fn name(&self) -> &str {
222 match self {
223 Self::IsConstant => "is_constant",
224 Self::IsSorted => "is_sorted",
225 Self::IsStrictSorted => "is_strict_sorted",
226 Self::Max => "max",
227 Self::Min => "min",
228 Self::NullCount => "null_count",
229 Self::UncompressedSizeInBytes => "uncompressed_size_in_bytes",
230 Self::Sum => "sum",
231 Self::NaNCount => "nan_count",
232 }
233 }
234
235 pub fn all() -> impl Iterator<Item = Stat> {
236 all::<Self>()
237 }
238}
239
240impl Display for Stat {
241 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
242 write!(f, "{}", self.name())
243 }
244}
245
#[cfg(test)]
mod test {
    use enum_iterator::all;

    use crate::arrays::PrimitiveArray;
    use crate::expr::stats::Stat;

    /// Computing `Min` over an array that is entirely null yields `None` instead of panicking.
    #[test]
    fn min_of_nulls_is_not_panic() {
        let all_null = PrimitiveArray::from_option_iter::<i32, _>([None, None, None, None]);
        let min = all_null.statistics().compute_as::<i64>(Stat::Min);

        assert_eq!(min, None);
    }

    /// Exactly `Min` and `Max` report sharing the array's dtype; every other stat does not.
    #[test]
    fn has_same_dtype_as_array() {
        for stat in all::<Stat>() {
            let expected = matches!(stat, Stat::Min | Stat::Max);
            assert_eq!(
                stat.has_same_dtype_as_array(),
                expected,
                "unexpected has_same_dtype_as_array() for {stat}"
            );
        }
    }
}