vortex_array/expr/stats/
mod.rs1use std::fmt::Debug;
5use std::fmt::Display;
6use std::fmt::Formatter;
7
8use enum_iterator::Sequence;
9use enum_iterator::all;
10use num_enum::IntoPrimitive;
11use num_enum::TryFromPrimitive;
12
13use crate::dtype::DType;
14use crate::dtype::DecimalDType;
15use crate::dtype::MAX_PRECISION;
16use crate::dtype::Nullability::NonNullable;
17use crate::dtype::Nullability::Nullable;
18use crate::dtype::PType;
19
20mod bound;
21mod precision;
22mod provider;
23mod stat_bound;
24
25pub use bound::*;
26pub use precision::*;
27pub use provider::*;
28pub use stat_bound::*;
29
30#[derive(
31 Debug,
32 Clone,
33 Copy,
34 PartialEq,
35 Eq,
36 PartialOrd,
37 Ord,
38 Hash,
39 Sequence,
40 IntoPrimitive,
41 TryFromPrimitive,
42)]
43#[repr(u8)]
44pub enum Stat {
45 IsConstant = 0,
48 IsSorted = 1,
51 IsStrictSorted = 2,
54 Max = 3,
56 Min = 4,
58 Sum = 5,
60 NullCount = 6,
62 UncompressedSizeInBytes = 7,
64 NaNCount = 8,
66}
67
// Zero-sized marker types, one per `Stat` variant. Each one is tied to its
// variant by a `StatType` implementation further down in this file.
//
// They derive the common traits so they can be printed, copied and compared
// like any other public type.

/// Marker type for [`Stat::Max`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Max;

/// Marker type for [`Stat::Min`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Min;

/// Marker type for [`Stat::Sum`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Sum;

/// Marker type for [`Stat::IsConstant`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct IsConstant;

/// Marker type for [`Stat::IsSorted`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct IsSorted;

/// Marker type for [`Stat::IsStrictSorted`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct IsStrictSorted;

/// Marker type for [`Stat::NullCount`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct NullCount;

/// Marker type for [`Stat::UncompressedSizeInBytes`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct UncompressedSizeInBytes;

/// Marker type for [`Stat::NaNCount`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct NaNCount;
87
88impl StatType<bool> for IsConstant {
89 type Bound = Precision<bool>;
90
91 const STAT: Stat = Stat::IsConstant;
92}
93
94impl StatType<bool> for IsSorted {
95 type Bound = Precision<bool>;
96
97 const STAT: Stat = Stat::IsSorted;
98}
99
100impl StatType<bool> for IsStrictSorted {
101 type Bound = Precision<bool>;
102
103 const STAT: Stat = Stat::IsStrictSorted;
104}
105
106impl<T: PartialOrd + Clone> StatType<T> for NullCount {
107 type Bound = UpperBound<T>;
108
109 const STAT: Stat = Stat::NullCount;
110}
111
112impl<T: PartialOrd + Clone> StatType<T> for UncompressedSizeInBytes {
113 type Bound = UpperBound<T>;
114
115 const STAT: Stat = Stat::UncompressedSizeInBytes;
116}
117
118impl<T: PartialOrd + Clone + Debug> StatType<T> for Max {
119 type Bound = UpperBound<T>;
120
121 const STAT: Stat = Stat::Max;
122}
123
124impl<T: PartialOrd + Clone + Debug> StatType<T> for Min {
125 type Bound = LowerBound<T>;
126
127 const STAT: Stat = Stat::Min;
128}
129
130impl<T: PartialOrd + Clone + Debug> StatType<T> for Sum {
131 type Bound = Precision<T>;
132
133 const STAT: Stat = Stat::Sum;
134}
135
136impl<T: PartialOrd + Clone> StatType<T> for NaNCount {
137 type Bound = UpperBound<T>;
138
139 const STAT: Stat = Stat::NaNCount;
140}
141
142impl Stat {
143 pub fn is_commutative(&self) -> bool {
146 match self {
148 Self::IsConstant
149 | Self::Max
150 | Self::Min
151 | Self::NullCount
152 | Self::Sum
153 | Self::NaNCount
154 | Self::UncompressedSizeInBytes => true,
155 Self::IsSorted | Self::IsStrictSorted => false,
156 }
157 }
158
159 pub fn has_same_dtype_as_array(&self) -> bool {
161 matches!(self, Stat::Min | Stat::Max)
162 }
163
164 pub fn dtype(&self, data_type: &DType) -> Option<DType> {
166 Some(match self {
167 Self::IsConstant => DType::Bool(NonNullable),
168 Self::IsSorted => DType::Bool(NonNullable),
169 Self::IsStrictSorted => DType::Bool(NonNullable),
170 Self::Max if matches!(data_type, DType::Null) => return None,
171 Self::Max => data_type.clone(),
172 Self::Min if matches!(data_type, DType::Null) => return None,
173 Self::Min => data_type.clone(),
174 Self::NullCount => DType::Primitive(PType::U64, NonNullable),
175 Self::UncompressedSizeInBytes => DType::Primitive(PType::U64, NonNullable),
176 Self::NaNCount => {
177 if let DType::Primitive(ptype, ..) = data_type
179 && ptype.is_float()
180 {
181 DType::Primitive(PType::U64, NonNullable)
182 } else {
183 return None;
184 }
185 }
186 Self::Sum => {
187 match data_type {
191 DType::Bool(_) => DType::Primitive(PType::U64, Nullable),
192 DType::Primitive(ptype, _) => match ptype {
193 PType::U8 | PType::U16 | PType::U32 | PType::U64 => {
194 DType::Primitive(PType::U64, Nullable)
195 }
196 PType::I8 | PType::I16 | PType::I32 | PType::I64 => {
197 DType::Primitive(PType::I64, Nullable)
198 }
199 PType::F16 | PType::F32 | PType::F64 => {
200 DType::Primitive(PType::F64, Nullable)
202 }
203 },
204 DType::Extension(ext_dtype) => self.dtype(ext_dtype.storage_dtype())?,
205 DType::Decimal(decimal_dtype, _) => {
206 let precision = u8::min(MAX_PRECISION, decimal_dtype.precision() + 10);
210 DType::Decimal(
211 DecimalDType::new(precision, decimal_dtype.scale()),
212 Nullable,
213 )
214 }
215 _ => return None,
217 }
218 }
219 })
220 }
221
222 pub fn name(&self) -> &str {
223 match self {
224 Self::IsConstant => "is_constant",
225 Self::IsSorted => "is_sorted",
226 Self::IsStrictSorted => "is_strict_sorted",
227 Self::Max => "max",
228 Self::Min => "min",
229 Self::NullCount => "null_count",
230 Self::UncompressedSizeInBytes => "uncompressed_size_in_bytes",
231 Self::Sum => "sum",
232 Self::NaNCount => "nan_count",
233 }
234 }
235
236 pub fn all() -> impl Iterator<Item = Stat> {
237 all::<Self>()
238 }
239}
240
241impl Display for Stat {
242 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
243 write!(f, "{}", self.name())
244 }
245}
246
#[cfg(test)]
mod test {
    use enum_iterator::all;

    use crate::arrays::PrimitiveArray;
    use crate::expr::stats::Stat;

    /// Computing `Min` over an all-null array must yield `None` rather than panic.
    #[test]
    fn min_of_nulls_is_not_panic() {
        let all_null = PrimitiveArray::from_option_iter::<i32, _>([None, None, None, None]);
        let computed = all_null.statistics().compute_as::<i64>(Stat::Min);

        assert_eq!(computed, None);
    }

    /// Only `Min` and `Max` report the same dtype as the array; every other
    /// statistic must report `false`.
    #[test]
    fn has_same_dtype_as_array() {
        for stat in all::<Stat>() {
            let expected = matches!(stat, Stat::Min | Stat::Max);
            assert_eq!(stat.has_same_dtype_as_array(), expected);
        }
    }
}