vortex_btrblocks/integer/
stats.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::hash::Hash;
5
6use num_traits::PrimInt;
7use rustc_hash::FxBuildHasher;
8use vortex_array::ToCanonical;
9use vortex_array::arrays::NativeValue;
10use vortex_array::arrays::PrimitiveArray;
11use vortex_array::arrays::PrimitiveVTable;
12use vortex_array::expr::stats::Stat;
13use vortex_buffer::BitBuffer;
14use vortex_dtype::IntegerPType;
15use vortex_dtype::match_each_integer_ptype;
16use vortex_error::VortexError;
17use vortex_error::VortexExpect;
18use vortex_mask::AllOr;
19use vortex_scalar::PValue;
20use vortex_scalar::Scalar;
21use vortex_utils::aliases::hash_map::HashMap;
22
23use crate::CompressorStats;
24use crate::GenerateStatsOptions;
25use crate::rle::RLEStats;
26use crate::sample::sample;
27
28#[derive(Clone, Debug)]
29pub struct TypedStats<T> {
30    pub min: T,
31    pub max: T,
32    pub top_value: T,
33    pub top_count: u32,
34    pub distinct_values: HashMap<NativeValue<T>, u32, FxBuildHasher>,
35}
36
37/// Type-erased container for one of the [TypedStats] variants.
38///
39/// Building the `TypedStats` is considerably faster and cheaper than building a type-erased
40/// set of stats. We then perform a variety of access methods on them.
41#[derive(Clone, Debug)]
42pub enum ErasedStats {
43    U8(TypedStats<u8>),
44    U16(TypedStats<u16>),
45    U32(TypedStats<u32>),
46    U64(TypedStats<u64>),
47    I8(TypedStats<i8>),
48    I16(TypedStats<i16>),
49    I32(TypedStats<i32>),
50    I64(TypedStats<i64>),
51}
52
53impl ErasedStats {
54    pub fn min_is_zero(&self) -> bool {
55        match &self {
56            ErasedStats::U8(x) => x.min == 0,
57            ErasedStats::U16(x) => x.min == 0,
58            ErasedStats::U32(x) => x.min == 0,
59            ErasedStats::U64(x) => x.min == 0,
60            ErasedStats::I8(x) => x.min == 0,
61            ErasedStats::I16(x) => x.min == 0,
62            ErasedStats::I32(x) => x.min == 0,
63            ErasedStats::I64(x) => x.min == 0,
64        }
65    }
66
67    pub fn min_is_negative(&self) -> bool {
68        match &self {
69            ErasedStats::U8(_)
70            | ErasedStats::U16(_)
71            | ErasedStats::U32(_)
72            | ErasedStats::U64(_) => false,
73            ErasedStats::I8(x) => x.min < 0,
74            ErasedStats::I16(x) => x.min < 0,
75            ErasedStats::I32(x) => x.min < 0,
76            ErasedStats::I64(x) => x.min < 0,
77        }
78    }
79
80    // Difference between max and min.
81    pub fn max_minus_min(&self) -> u64 {
82        match &self {
83            ErasedStats::U8(x) => (x.max - x.min) as u64,
84            ErasedStats::U16(x) => (x.max - x.min) as u64,
85            ErasedStats::U32(x) => (x.max - x.min) as u64,
86            ErasedStats::U64(x) => x.max - x.min,
87            ErasedStats::I8(x) => (x.max as i16 - x.min as i16) as u64,
88            ErasedStats::I16(x) => (x.max as i32 - x.min as i32) as u64,
89            ErasedStats::I32(x) => (x.max as i64 - x.min as i64) as u64,
90            ErasedStats::I64(x) => u64::try_from(x.max as i128 - x.min as i128)
91                .vortex_expect("max minus min result bigger than u64"),
92        }
93    }
94
95    /// Get the most commonly occurring value and its count
96    pub fn top_value_and_count(&self) -> (PValue, u32) {
97        match &self {
98            ErasedStats::U8(x) => (x.top_value.into(), x.top_count),
99            ErasedStats::U16(x) => (x.top_value.into(), x.top_count),
100            ErasedStats::U32(x) => (x.top_value.into(), x.top_count),
101            ErasedStats::U64(x) => (x.top_value.into(), x.top_count),
102            ErasedStats::I8(x) => (x.top_value.into(), x.top_count),
103            ErasedStats::I16(x) => (x.top_value.into(), x.top_count),
104            ErasedStats::I32(x) => (x.top_value.into(), x.top_count),
105            ErasedStats::I64(x) => (x.top_value.into(), x.top_count),
106        }
107    }
108}
109
110macro_rules! impl_from_typed {
111    ($T:ty, $variant:path) => {
112        impl From<TypedStats<$T>> for ErasedStats {
113            fn from(typed: TypedStats<$T>) -> Self {
114                $variant(typed)
115            }
116        }
117    };
118}
119
120impl_from_typed!(u8, ErasedStats::U8);
121impl_from_typed!(u16, ErasedStats::U16);
122impl_from_typed!(u32, ErasedStats::U32);
123impl_from_typed!(u64, ErasedStats::U64);
124impl_from_typed!(i8, ErasedStats::I8);
125impl_from_typed!(i16, ErasedStats::I16);
126impl_from_typed!(i32, ErasedStats::I32);
127impl_from_typed!(i64, ErasedStats::I64);
128
129/// Array of integers and relevant stats for compression.
130#[derive(Clone, Debug)]
131pub struct IntegerStats {
132    pub(super) src: PrimitiveArray,
133    // cache for validity.false_count()
134    pub(super) null_count: u32,
135    // cache for validity.true_count()
136    pub(super) value_count: u32,
137    pub(super) average_run_length: u32,
138    pub(super) distinct_values_count: u32,
139    pub(crate) typed: ErasedStats,
140}
141
142impl CompressorStats for IntegerStats {
143    type ArrayVTable = PrimitiveVTable;
144
145    fn generate_opts(input: &PrimitiveArray, opts: GenerateStatsOptions) -> Self {
146        match_each_integer_ptype!(input.ptype(), |T| {
147            typed_int_stats::<T>(input, opts.count_distinct_values)
148        })
149    }
150
151    fn source(&self) -> &PrimitiveArray {
152        &self.src
153    }
154
155    fn sample_opts(&self, sample_size: u32, sample_count: u32, opts: GenerateStatsOptions) -> Self {
156        let sampled = sample(self.src.as_ref(), sample_size, sample_count).to_primitive();
157
158        Self::generate_opts(&sampled, opts)
159    }
160}
161
162impl RLEStats for IntegerStats {
163    fn value_count(&self) -> u32 {
164        self.value_count
165    }
166
167    fn average_run_length(&self) -> u32 {
168        self.average_run_length
169    }
170
171    fn source(&self) -> &PrimitiveArray {
172        &self.src
173    }
174}
175
176fn typed_int_stats<T>(array: &PrimitiveArray, count_distinct_values: bool) -> IntegerStats
177where
178    T: IntegerPType + PrimInt + for<'a> TryFrom<&'a Scalar, Error = VortexError>,
179    TypedStats<T>: Into<ErasedStats>,
180    NativeValue<T>: Eq + Hash,
181{
182    // Special case: empty array
183    if array.is_empty() {
184        return IntegerStats {
185            src: array.clone(),
186            null_count: 0,
187            value_count: 0,
188            average_run_length: 0,
189            distinct_values_count: 0,
190            typed: TypedStats {
191                min: T::max_value(),
192                max: T::min_value(),
193                top_value: T::default(),
194                top_count: 0,
195                distinct_values: HashMap::with_hasher(FxBuildHasher),
196            }
197            .into(),
198        };
199    } else if array.all_invalid() {
200        return IntegerStats {
201            src: array.clone(),
202            null_count: array.len().try_into().vortex_expect("null_count"),
203            value_count: 0,
204            average_run_length: 0,
205            distinct_values_count: 0,
206            typed: TypedStats {
207                min: T::max_value(),
208                max: T::min_value(),
209                top_value: T::default(),
210                top_count: 0,
211                distinct_values: HashMap::with_hasher(FxBuildHasher),
212            }
213            .into(),
214        };
215    }
216
217    let validity = array.validity_mask();
218    let null_count = validity.false_count();
219    let value_count = validity.true_count();
220
221    // Initialize loop state
222    let head_idx = validity
223        .first()
224        .vortex_expect("All null masks have been handled before");
225    let buffer = array.buffer::<T>();
226    let head = buffer[head_idx];
227
228    let mut loop_state = LoopState {
229        distinct_values: if count_distinct_values {
230            HashMap::with_capacity_and_hasher(array.len() / 2, FxBuildHasher)
231        } else {
232            HashMap::with_hasher(FxBuildHasher)
233        },
234        prev: head,
235        runs: 1,
236    };
237
238    let sliced = buffer.slice(head_idx..array.len());
239    let mut chunks = sliced.as_slice().chunks_exact(64);
240    match validity.bit_buffer() {
241        AllOr::All => {
242            for chunk in &mut chunks {
243                inner_loop_nonnull(
244                    chunk.try_into().vortex_expect("chunk size must be 64"),
245                    count_distinct_values,
246                    &mut loop_state,
247                )
248            }
249            let remainder = chunks.remainder();
250            inner_loop_naive(
251                remainder,
252                count_distinct_values,
253                &BitBuffer::new_set(remainder.len()),
254                &mut loop_state,
255            );
256        }
257        AllOr::None => unreachable!("All invalid arrays have been handled before"),
258        AllOr::Some(v) => {
259            let mask = v.slice(head_idx..array.len());
260            let mut offset = 0;
261            for chunk in &mut chunks {
262                let validity = mask.slice(offset..(offset + 64));
263                offset += 64;
264
265                match validity.true_count() {
266                    // All nulls -> no stats to update
267                    0 => continue,
268                    // Inner loop for when validity check can be elided
269                    64 => inner_loop_nonnull(
270                        chunk.try_into().vortex_expect("chunk size must be 64"),
271                        count_distinct_values,
272                        &mut loop_state,
273                    ),
274                    // Inner loop for when we need to check validity
275                    _ => inner_loop_nullable(
276                        chunk.try_into().vortex_expect("chunk size must be 64"),
277                        count_distinct_values,
278                        &validity,
279                        &mut loop_state,
280                    ),
281                }
282            }
283            // Final iteration, run naive loop
284            let remainder = chunks.remainder();
285            inner_loop_naive(
286                remainder,
287                count_distinct_values,
288                &mask.slice(offset..(offset + remainder.len())),
289                &mut loop_state,
290            );
291        }
292    }
293
294    let (top_value, top_count) = if count_distinct_values {
295        let (&top_value, &top_count) = loop_state
296            .distinct_values
297            .iter()
298            .max_by_key(|&(_, &count)| count)
299            .vortex_expect("non-empty");
300        (top_value.0, top_count)
301    } else {
302        (T::default(), 0)
303    };
304
305    let runs = loop_state.runs;
306    let distinct_values_count = if count_distinct_values {
307        loop_state
308            .distinct_values
309            .len()
310            .try_into()
311            .vortex_expect("distinct values count must fit in u32")
312    } else {
313        u32::MAX
314    };
315
316    let min = array
317        .statistics()
318        .compute_as::<T>(Stat::Min)
319        .vortex_expect("min should be computed");
320
321    let max = array
322        .statistics()
323        .compute_as::<T>(Stat::Max)
324        .vortex_expect("max should be computed");
325
326    let typed = TypedStats {
327        min,
328        max,
329        distinct_values: loop_state.distinct_values,
330        top_value,
331        top_count,
332    };
333
334    let null_count = null_count
335        .try_into()
336        .vortex_expect("null_count must fit in u32");
337    let value_count = value_count
338        .try_into()
339        .vortex_expect("value_count must fit in u32");
340
341    IntegerStats {
342        src: array.clone(),
343        null_count,
344        value_count,
345        average_run_length: value_count / runs,
346        distinct_values_count,
347        typed: typed.into(),
348    }
349}
350
351struct LoopState<T> {
352    prev: T,
353    runs: u32,
354    distinct_values: HashMap<NativeValue<T>, u32, FxBuildHasher>,
355}
356
357#[inline(always)]
358fn inner_loop_nonnull<T: IntegerPType>(
359    values: &[T; 64],
360    count_distinct_values: bool,
361    state: &mut LoopState<T>,
362) where
363    NativeValue<T>: Eq + Hash,
364{
365    for &value in values {
366        if count_distinct_values {
367            *state.distinct_values.entry(NativeValue(value)).or_insert(0) += 1;
368        }
369
370        if value != state.prev {
371            state.prev = value;
372            state.runs += 1;
373        }
374    }
375}
376
377#[inline(always)]
378fn inner_loop_nullable<T: IntegerPType>(
379    values: &[T; 64],
380    count_distinct_values: bool,
381    is_valid: &BitBuffer,
382    state: &mut LoopState<T>,
383) where
384    NativeValue<T>: Eq + Hash,
385{
386    for (idx, &value) in values.iter().enumerate() {
387        if is_valid.value(idx) {
388            if count_distinct_values {
389                *state.distinct_values.entry(NativeValue(value)).or_insert(0) += 1;
390            }
391
392            if value != state.prev {
393                state.prev = value;
394                state.runs += 1;
395            }
396        }
397    }
398}
399
400#[inline(always)]
401fn inner_loop_naive<T: IntegerPType>(
402    values: &[T],
403    count_distinct_values: bool,
404    is_valid: &BitBuffer,
405    state: &mut LoopState<T>,
406) where
407    NativeValue<T>: Eq + Hash,
408{
409    for (idx, &value) in values.iter().enumerate() {
410        if is_valid.value(idx) {
411            if count_distinct_values {
412                *state.distinct_values.entry(NativeValue(value)).or_insert(0) += 1;
413            }
414
415            if value != state.prev {
416                state.prev = value;
417                state.runs += 1;
418            }
419        }
420    }
421}
422
#[cfg(test)]
mod tests {
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::validity::Validity;
    use vortex_buffer::BitBuffer;
    use vortex_buffer::Buffer;
    use vortex_buffer::buffer;

    use crate::CompressorStats;
    use crate::integer::IntegerStats;
    use crate::integer::stats::typed_int_stats;

    // Two valid values: shorter than one 64-element chunk, so only the remainder loop runs.
    #[test]
    fn test_naive_count_distinct_values() {
        let values = PrimitiveArray::new(buffer![217u8, 0], Validity::NonNullable);
        let computed = typed_int_stats::<u8>(&values, true);
        assert_eq!(computed.distinct_values_count, 2);
    }

    // The trailing null must not be counted as a distinct value.
    #[test]
    fn test_naive_count_distinct_values_nullable() {
        let mask = Validity::from(BitBuffer::from(vec![true, false]));
        let values = PrimitiveArray::new(buffer![217u8, 0], mask);
        assert_eq!(typed_int_stats::<u8>(&values, true).distinct_values_count, 1);
    }

    // Two full chunks of all-distinct, all-valid values.
    #[test]
    fn test_count_distinct_values() {
        let data: Buffer<u8> = (0..128u8).collect();
        let values = PrimitiveArray::new(data, Validity::NonNullable);
        assert_eq!(typed_int_stats::<u8>(&values, true).distinct_values_count, 128);
    }

    // Only even indices are valid, so half of the 128 distinct values are counted.
    #[test]
    fn test_count_distinct_values_nullable() {
        let mask = Validity::from(BitBuffer::from_iter((0..128).map(|idx| idx % 2 == 0)));
        let data: Buffer<u8> = (0..128u8).collect();
        let values = PrimitiveArray::new(data, mask);
        assert_eq!(typed_int_stats::<u8>(&values, true).distinct_values_count, 64);
    }

    // A leading null must be skipped before seeding the run/distinct state.
    #[test]
    fn test_integer_stats_leading_nulls() {
        let values = PrimitiveArray::new(buffer![0, 1, 2], Validity::from_iter([false, true, true]));

        let computed = IntegerStats::generate(&values);

        assert_eq!(computed.value_count, 2);
        assert_eq!(computed.null_count, 1);
        assert_eq!(computed.average_run_length, 1);
        assert_eq!(computed.distinct_values_count, 2);
    }
}