vortex_array/stats/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Traits and utilities to compute and access array statistics.
5
6use std::fmt::{Debug, Display, Formatter};
7use std::hash::Hash;
8
9use arrow_buffer::bit_iterator::BitIterator;
10use arrow_buffer::{BooleanBufferBuilder, MutableBuffer};
11use enum_iterator::{Sequence, all, last};
12use log::debug;
13use num_enum::{IntoPrimitive, TryFromPrimitive};
14pub use stats_set::*;
15use vortex_dtype::Nullability::{NonNullable, Nullable};
16use vortex_dtype::{DECIMAL256_MAX_PRECISION, DType, DecimalDType, PType};
17
18mod array;
19mod bound;
20pub mod flatbuffers;
21mod precision;
22mod provider;
23mod stat_bound;
24mod stats_set;
25
26pub use array::*;
27pub use bound::{LowerBound, UpperBound};
28pub use precision::Precision;
29pub use provider::*;
30pub use stat_bound::*;
31use vortex_error::VortexExpect;
32
33/// Statistics that are used for pruning files (i.e., we want to ensure they are computed when compressing/writing).
34/// Sum is included for boolean arrays.
35pub const PRUNING_STATS: &[Stat] = &[
36    Stat::Min,
37    Stat::Max,
38    Stat::Sum,
39    Stat::NullCount,
40    Stat::NaNCount,
41];
42
43#[derive(
44    Debug,
45    Clone,
46    Copy,
47    PartialEq,
48    Eq,
49    PartialOrd,
50    Ord,
51    Hash,
52    Sequence,
53    IntoPrimitive,
54    TryFromPrimitive,
55)]
56#[repr(u8)]
57pub enum Stat {
58    /// Whether all values are the same (nulls are not equal to other non-null values,
59    /// so this is true iff all values are null or all values are the same non-null value)
60    IsConstant = 0,
61    /// Whether the non-null values in the array are sorted (i.e., we skip nulls)
62    IsSorted = 1,
63    /// Whether the non-null values in the array are strictly sorted (i.e., sorted with no duplicates)
64    IsStrictSorted = 2,
65    /// The maximum value in the array (ignoring nulls, unless all values are null)
66    Max = 3,
67    /// The minimum value in the array (ignoring nulls, unless all values are null)
68    Min = 4,
69    /// The sum of the non-null values of the array.
70    Sum = 5,
71    /// The number of null values in the array
72    NullCount = 6,
73    /// The uncompressed size of the array in bytes
74    UncompressedSizeInBytes = 7,
75    /// The number of NaN values in the array
76    NaNCount = 8,
77}
78
// Marker types, one per `Stat` variant.
//
// They allow the extraction of the bound from the `Precision` value: each one ties together a
// `Stat` and its `StatBound`, which allows the bound to be extracted.

/// Marker type for [`Stat::IsConstant`].
pub struct IsConstant;
/// Marker type for [`Stat::IsSorted`].
pub struct IsSorted;
/// Marker type for [`Stat::IsStrictSorted`].
pub struct IsStrictSorted;
/// Marker type for [`Stat::Max`].
pub struct Max;
/// Marker type for [`Stat::Min`].
pub struct Min;
/// Marker type for [`Stat::Sum`].
pub struct Sum;
/// Marker type for [`Stat::NullCount`].
pub struct NullCount;
/// Marker type for [`Stat::UncompressedSizeInBytes`].
pub struct UncompressedSizeInBytes;
/// Marker type for [`Stat::NaNCount`].
pub struct NaNCount;
90
91impl StatType<bool> for IsConstant {
92    type Bound = Precision<bool>;
93
94    const STAT: Stat = Stat::IsConstant;
95}
96
97impl StatType<bool> for IsSorted {
98    type Bound = Precision<bool>;
99
100    const STAT: Stat = Stat::IsSorted;
101}
102
103impl StatType<bool> for IsStrictSorted {
104    type Bound = Precision<bool>;
105
106    const STAT: Stat = Stat::IsStrictSorted;
107}
108
109impl<T: PartialOrd + Clone> StatType<T> for NullCount {
110    type Bound = UpperBound<T>;
111
112    const STAT: Stat = Stat::NullCount;
113}
114
115impl<T: PartialOrd + Clone> StatType<T> for UncompressedSizeInBytes {
116    type Bound = UpperBound<T>;
117
118    const STAT: Stat = Stat::UncompressedSizeInBytes;
119}
120
121impl<T: PartialOrd + Clone + Debug> StatType<T> for Max {
122    type Bound = UpperBound<T>;
123
124    const STAT: Stat = Stat::Max;
125}
126
127impl<T: PartialOrd + Clone + Debug> StatType<T> for Min {
128    type Bound = LowerBound<T>;
129
130    const STAT: Stat = Stat::Min;
131}
132
133impl<T: PartialOrd + Clone + Debug> StatType<T> for Sum {
134    type Bound = Precision<T>;
135
136    const STAT: Stat = Stat::Sum;
137}
138
139impl<T: PartialOrd + Clone> StatType<T> for NaNCount {
140    type Bound = UpperBound<T>;
141
142    const STAT: Stat = Stat::NaNCount;
143}
144
145impl Stat {
146    /// Whether the statistic is commutative (i.e., whether merging can be done independently of ordering)
147    /// e.g., min/max are commutative, but is_sorted is not
148    pub fn is_commutative(&self) -> bool {
149        // NOTE: we prefer this syntax to force a compile error if we add a new stat
150        match self {
151            Self::IsConstant
152            | Self::Max
153            | Self::Min
154            | Self::NullCount
155            | Self::Sum
156            | Self::NaNCount
157            | Self::UncompressedSizeInBytes => true,
158            Self::IsSorted | Self::IsStrictSorted => false,
159        }
160    }
161
162    /// Whether the statistic has the same dtype as the array it's computed on
163    pub fn has_same_dtype_as_array(&self) -> bool {
164        matches!(self, Stat::Min | Stat::Max)
165    }
166
167    /// Return the [`DType`] of the statistic scalar assuming the array is of the given [`DType`].
168    pub fn dtype(&self, data_type: &DType) -> Option<DType> {
169        Some(match self {
170            Self::IsConstant => DType::Bool(NonNullable),
171            Self::IsSorted => DType::Bool(NonNullable),
172            Self::IsStrictSorted => DType::Bool(NonNullable),
173            Self::Max if matches!(data_type, DType::Null) => return None,
174            Self::Max => data_type.clone(),
175            Self::Min if matches!(data_type, DType::Null) => return None,
176            Self::Min => data_type.clone(),
177            Self::NullCount => DType::Primitive(PType::U64, NonNullable),
178            Self::UncompressedSizeInBytes => DType::Primitive(PType::U64, NonNullable),
179            Self::NaNCount => {
180                // Only floating points support NaN counts.
181                if let DType::Primitive(ptype, ..) = data_type
182                    && ptype.is_float()
183                {
184                    DType::Primitive(PType::U64, NonNullable)
185                } else {
186                    return None;
187                }
188            }
189            Self::Sum => {
190                // Any array that cannot be summed has a sum DType of null.
191                // Any array that can be summed, but overflows, has a sum _value_ of null.
192                // Therefore, we make integer sum stats nullable.
193                match data_type {
194                    DType::Bool(_) => DType::Primitive(PType::U64, Nullable),
195                    DType::Primitive(ptype, _) => match ptype {
196                        PType::U8 | PType::U16 | PType::U32 | PType::U64 => {
197                            DType::Primitive(PType::U64, Nullable)
198                        }
199                        PType::I8 | PType::I16 | PType::I32 | PType::I64 => {
200                            DType::Primitive(PType::I64, Nullable)
201                        }
202                        PType::F16 | PType::F32 | PType::F64 => {
203                            // Float sums cannot overflow, but all null floats still end up as null
204                            DType::Primitive(PType::F64, Nullable)
205                        }
206                    },
207                    DType::Extension(ext_dtype) => self.dtype(ext_dtype.storage_dtype())?,
208                    DType::Decimal(decimal_dtype, nullability) => {
209                        // Both Spark and DataFusion use this heuristic.
210                        // - https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
211                        // - https://github.com/apache/datafusion/blob/4153adf2c0f6e317ef476febfdc834208bd46622/datafusion/functions-aggregate/src/sum.rs#L188
212                        let precision =
213                            u8::min(DECIMAL256_MAX_PRECISION, decimal_dtype.precision() + 10);
214                        DType::Decimal(
215                            DecimalDType::new(precision, decimal_dtype.scale()),
216                            *nullability,
217                        )
218                    }
219                    // Unsupported types
220                    _ => return None,
221                }
222            }
223        })
224    }
225
226    pub fn name(&self) -> &str {
227        match self {
228            Self::IsConstant => "is_constant",
229            Self::IsSorted => "is_sorted",
230            Self::IsStrictSorted => "is_strict_sorted",
231            Self::Max => "max",
232            Self::Min => "min",
233            Self::NullCount => "null_count",
234            Self::UncompressedSizeInBytes => "uncompressed_size_in_bytes",
235            Self::Sum => "sum",
236            Self::NaNCount => "nan_count",
237        }
238    }
239
240    pub fn all() -> impl Iterator<Item = Stat> {
241        all::<Self>()
242    }
243}
244
245pub fn as_stat_bitset_bytes(stats: &[Stat]) -> Vec<u8> {
246    let max_stat = u8::from(last::<Stat>().vortex_expect("last stat")) as usize + 1;
247    // TODO(ngates): use vortex-buffer::BitBuffer
248    let mut stat_bitset = BooleanBufferBuilder::new_from_buffer(
249        MutableBuffer::from_len_zeroed(max_stat.div_ceil(8)),
250        max_stat,
251    );
252    for stat in stats {
253        stat_bitset.set_bit(u8::from(*stat) as usize, true);
254    }
255
256    stat_bitset
257        .finish()
258        .into_inner()
259        .into_vec()
260        .unwrap_or_else(|b| b.to_vec())
261}
262
263pub fn stats_from_bitset_bytes(bytes: &[u8]) -> Vec<Stat> {
264    BitIterator::new(bytes, 0, bytes.len() * 8)
265        .enumerate()
266        .filter_map(|(i, b)| b.then_some(i))
267        // Filter out indices failing conversion, these are stats written by newer version of library
268        .filter_map(|i| {
269            let Ok(stat) = u8::try_from(i) else {
270                debug!("invalid stat encountered: {i}");
271                return None;
272            };
273            Stat::try_from(stat).ok()
274        })
275        .collect::<Vec<_>>()
276}
277
278impl Display for Stat {
279    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
280        write!(f, "{}", self.name())
281    }
282}
283
#[cfg(test)]
mod test {
    use enum_iterator::all;

    use crate::arrays::PrimitiveArray;
    use crate::stats::Stat;

    /// Computing the minimum of an all-null array must yield `None` instead of panicking.
    #[test]
    fn min_of_nulls_is_not_panic() {
        let all_null = PrimitiveArray::from_option_iter::<i32, _>([None, None, None, None]);
        let min = all_null.statistics().compute_as::<i64>(Stat::Min);

        assert_eq!(min, None);
    }

    /// `Min` and `Max` are the only stats that share the dtype of their array.
    #[test]
    fn has_same_dtype_as_array() {
        for stat in all::<Stat>() {
            let expected = matches!(stat, Stat::Min | Stat::Max);
            assert_eq!(stat.has_same_dtype_as_array(), expected);
        }
    }
}