vortex_array/arrays/primitive/
stats.rs

1use std::cmp::Ordering;
2
3use arrow_buffer::buffer::BooleanBuffer;
4use vortex_dtype::{NativePType, match_each_native_ptype};
5use vortex_error::{VortexError, VortexResult};
6use vortex_mask::Mask;
7use vortex_scalar::ScalarValue;
8
9use crate::Array;
10use crate::arrays::PrimitiveEncoding;
11use crate::arrays::primitive::PrimitiveArray;
12use crate::compute::min_max;
13use crate::stats::{Precision, Stat, StatsSet};
14use crate::variants::PrimitiveArrayTrait;
15use crate::vtable::StatisticsVTable;
16
17trait PStatsType:
18    NativePType + Into<ScalarValue> + for<'a> TryFrom<&'a ScalarValue, Error = VortexError>
19{
20}
21
22impl<T> PStatsType for T where
23    T: NativePType + Into<ScalarValue> + for<'a> TryFrom<&'a ScalarValue, Error = VortexError>
24{
25}
26
27impl StatisticsVTable<&PrimitiveArray> for PrimitiveEncoding {
28    fn compute_statistics(&self, array: &PrimitiveArray, stat: Stat) -> VortexResult<StatsSet> {
29        if stat == Stat::Max || stat == Stat::Min {
30            min_max(array)?;
31            return Ok(array.statistics().to_owned());
32        }
33
34        match_each_native_ptype!(array.ptype(), |$P| {
35            self.compute_stats_with_validity::<$P>(array, stat)
36        })
37    }
38}
39
40impl PrimitiveEncoding {
41    #[inline]
42    fn compute_stats_with_validity<P: NativePType + PStatsType>(
43        &self,
44        array: &PrimitiveArray,
45        stat: Stat,
46    ) -> VortexResult<StatsSet> {
47        match array.validity_mask()? {
48            Mask::AllTrue(_) => self.compute_statistics(array.as_slice::<P>(), stat),
49            Mask::AllFalse(len) => Ok(StatsSet::nulls(len)),
50            Mask::Values(v) => self.compute_statistics(
51                &NullableValues(array.as_slice::<P>(), v.boolean_buffer()),
52                stat,
53            ),
54        }
55    }
56}
57
58impl<T: PStatsType + PartialEq> StatisticsVTable<&[T]> for PrimitiveEncoding {
59    fn compute_statistics(&self, array: &[T], stat: Stat) -> VortexResult<StatsSet> {
60        if array.is_empty() {
61            return Ok(StatsSet::default());
62        }
63
64        Ok(match stat {
65            Stat::NullCount => StatsSet::of(Stat::NullCount, Precision::exact(0u64)),
66            Stat::IsSorted => compute_is_sorted(array.iter().copied()),
67            Stat::IsStrictSorted => compute_is_strict_sorted(array.iter().copied()),
68            _ => unreachable!("already handled above"),
69        })
70    }
71}
72
73struct NullableValues<'a, T: PStatsType>(&'a [T], &'a BooleanBuffer);
74
75impl<T: PStatsType> StatisticsVTable<&NullableValues<'_, T>> for PrimitiveEncoding {
76    fn compute_statistics(
77        &self,
78        nulls: &NullableValues<'_, T>,
79        stat: Stat,
80    ) -> VortexResult<StatsSet> {
81        let values = nulls.0;
82        if values.is_empty() {
83            return Ok(StatsSet::default());
84        }
85
86        let null_count = values.len() - nulls.1.count_set_bits();
87        if null_count == 0 {
88            // no nulls, use the fast path on the values
89            return self.compute_statistics(values, stat);
90        } else if null_count == values.len() {
91            // all nulls!
92            return Ok(StatsSet::nulls(values.len()));
93        }
94
95        let mut stats = StatsSet::new_unchecked(vec![
96            (Stat::NullCount, Precision::exact(null_count)),
97            (Stat::IsConstant, Precision::exact(false)),
98        ]);
99        // we know that there is at least one null, but not all nulls, so it's not constant
100        if stat == Stat::IsConstant {
101            return Ok(stats);
102        }
103
104        let set_indices = nulls.1.set_indices();
105        if stat == Stat::IsSorted {
106            stats.extend(compute_is_sorted(set_indices.map(|next| values[next])));
107        } else if stat == Stat::IsStrictSorted {
108            stats.extend(compute_is_strict_sorted(
109                set_indices.map(|next| values[next]),
110            ));
111        }
112
113        Ok(stats)
114    }
115}
116
117fn compute_is_sorted<T: PStatsType>(mut iter: impl Iterator<Item = T>) -> StatsSet {
118    let mut sorted = true;
119    let Some(mut prev) = iter.next() else {
120        return StatsSet::default();
121    };
122    for next in iter {
123        if matches!(next.total_compare(prev), Ordering::Less) {
124            sorted = false;
125            break;
126        }
127        prev = next;
128    }
129
130    if sorted {
131        StatsSet::of(Stat::IsSorted, Precision::exact(true))
132    } else {
133        StatsSet::new_unchecked(vec![
134            (Stat::IsSorted, Precision::exact(false)),
135            (Stat::IsStrictSorted, Precision::exact(false)),
136        ])
137    }
138}
139
140fn compute_is_strict_sorted<T: PStatsType>(mut iter: impl Iterator<Item = T>) -> StatsSet {
141    let mut strict_sorted = true;
142    let Some(mut prev) = iter.next() else {
143        return StatsSet::default();
144    };
145
146    for next in iter {
147        if !matches!(prev.total_compare(next), Ordering::Less) {
148            strict_sorted = false;
149            break;
150        }
151        prev = next;
152    }
153
154    if strict_sorted {
155        StatsSet::new_unchecked(vec![
156            (Stat::IsSorted, Precision::exact(true)),
157            (Stat::IsStrictSorted, Precision::exact(true)),
158        ])
159    } else {
160        StatsSet::of(Stat::IsStrictSorted, Precision::exact(false))
161    }
162}
163
#[cfg(test)]
mod test {
    use crate::array::Array;
    use crate::arrays::primitive::PrimitiveArray;
    use crate::stats::Stat;

    /// Min, max, sortedness and constancy of a small strictly increasing `i32` array.
    #[test]
    fn stats() {
        let arr = PrimitiveArray::from_iter([1, 2, 3, 4, 5]);
        let min: i32 = arr.statistics().compute_min().unwrap();
        let max: i32 = arr.statistics().compute_max().unwrap();
        let is_sorted = arr.statistics().compute_is_sorted().unwrap();
        let is_strict_sorted = arr.statistics().compute_is_strict_sorted().unwrap();
        let is_constant = arr.statistics().compute_is_constant().unwrap();
        assert_eq!(min, 1);
        assert_eq!(max, 5);
        assert!(is_sorted);
        assert!(is_strict_sorted);
        assert!(!is_constant);
    }

    /// Min and max of a `u8` array.
    #[test]
    fn stats_u8() {
        let arr = PrimitiveArray::from_iter([1u8, 2, 3, 4, 5]);
        let min: u8 = arr.statistics().compute_min().unwrap();
        let max: u8 = arr.statistics().compute_max().unwrap();
        assert_eq!(min, 1);
        assert_eq!(max, 5);
    }

    /// Statistics of a nullable `i32` array: nulls are counted and are skipped when computing
    /// min, max and strict sortedness.
    #[test]
    fn nullable_stats_i32() {
        let arr = PrimitiveArray::from_option_iter([None, None, Some(1i32), Some(2), None]);
        let min: i32 = arr.statistics().compute_min().unwrap();
        let max: i32 = arr.statistics().compute_max().unwrap();
        let null_count: usize = arr.statistics().compute_null_count().unwrap();
        let is_strict_sorted: bool = arr.statistics().compute_is_strict_sorted().unwrap();
        assert_eq!(min, 1);
        assert_eq!(max, 2);
        assert_eq!(null_count, 3);
        assert!(is_strict_sorted);
    }

    /// An all-null array has neither a min nor a max.
    #[test]
    fn all_null() {
        let arr = PrimitiveArray::from_option_iter([Option::<i32>::None, None, None]);
        let arr_stats = arr.statistics();
        let min = arr_stats.compute_stat(Stat::Min).unwrap();
        let max = arr_stats.compute_stat(Stat::Max).unwrap();
        assert!(min.is_none());
        assert!(max.is_none());
    }
}