Skip to main content

vortex_btrblocks/compressor/integer/
stats.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::hash::Hash;
5
6use num_traits::PrimInt;
7use rustc_hash::FxBuildHasher;
8use vortex_array::ToCanonical;
9use vortex_array::arrays::NativeValue;
10use vortex_array::arrays::PrimitiveArray;
11use vortex_array::arrays::PrimitiveVTable;
12use vortex_array::expr::stats::Stat;
13use vortex_array::scalar::PValue;
14use vortex_array::scalar::Scalar;
15use vortex_buffer::BitBuffer;
16use vortex_dtype::IntegerPType;
17use vortex_dtype::match_each_integer_ptype;
18use vortex_error::VortexError;
19use vortex_error::VortexExpect;
20use vortex_error::VortexResult;
21use vortex_mask::AllOr;
22use vortex_utils::aliases::hash_map::HashMap;
23
24use crate::CompressorStats;
25use crate::GenerateStatsOptions;
26use crate::compressor::rle::RLEStats;
27use crate::sample::sample;
28
29#[derive(Clone, Debug)]
30pub struct TypedStats<T> {
31    pub min: T,
32    pub max: T,
33    pub top_value: T,
34    pub top_count: u32,
35    pub distinct_values: HashMap<NativeValue<T>, u32, FxBuildHasher>,
36}
37
38/// Type-erased container for one of the [TypedStats] variants.
39///
40/// Building the `TypedStats` is considerably faster and cheaper than building a type-erased
41/// set of stats. We then perform a variety of access methods on them.
42#[derive(Clone, Debug)]
43pub enum ErasedStats {
44    U8(TypedStats<u8>),
45    U16(TypedStats<u16>),
46    U32(TypedStats<u32>),
47    U64(TypedStats<u64>),
48    I8(TypedStats<i8>),
49    I16(TypedStats<i16>),
50    I32(TypedStats<i32>),
51    I64(TypedStats<i64>),
52}
53
54impl ErasedStats {
55    pub fn min_is_zero(&self) -> bool {
56        match &self {
57            ErasedStats::U8(x) => x.min == 0,
58            ErasedStats::U16(x) => x.min == 0,
59            ErasedStats::U32(x) => x.min == 0,
60            ErasedStats::U64(x) => x.min == 0,
61            ErasedStats::I8(x) => x.min == 0,
62            ErasedStats::I16(x) => x.min == 0,
63            ErasedStats::I32(x) => x.min == 0,
64            ErasedStats::I64(x) => x.min == 0,
65        }
66    }
67
68    pub fn min_is_negative(&self) -> bool {
69        match &self {
70            ErasedStats::U8(_)
71            | ErasedStats::U16(_)
72            | ErasedStats::U32(_)
73            | ErasedStats::U64(_) => false,
74            ErasedStats::I8(x) => x.min < 0,
75            ErasedStats::I16(x) => x.min < 0,
76            ErasedStats::I32(x) => x.min < 0,
77            ErasedStats::I64(x) => x.min < 0,
78        }
79    }
80
81    // Difference between max and min.
82    pub fn max_minus_min(&self) -> u64 {
83        match &self {
84            ErasedStats::U8(x) => (x.max - x.min) as u64,
85            ErasedStats::U16(x) => (x.max - x.min) as u64,
86            ErasedStats::U32(x) => (x.max - x.min) as u64,
87            ErasedStats::U64(x) => x.max - x.min,
88            ErasedStats::I8(x) => (x.max as i16 - x.min as i16) as u64,
89            ErasedStats::I16(x) => (x.max as i32 - x.min as i32) as u64,
90            ErasedStats::I32(x) => (x.max as i64 - x.min as i64) as u64,
91            ErasedStats::I64(x) => u64::try_from(x.max as i128 - x.min as i128)
92                .vortex_expect("max minus min result bigger than u64"),
93        }
94    }
95
96    /// Get the most commonly occurring value and its count
97    pub fn top_value_and_count(&self) -> (PValue, u32) {
98        match &self {
99            ErasedStats::U8(x) => (x.top_value.into(), x.top_count),
100            ErasedStats::U16(x) => (x.top_value.into(), x.top_count),
101            ErasedStats::U32(x) => (x.top_value.into(), x.top_count),
102            ErasedStats::U64(x) => (x.top_value.into(), x.top_count),
103            ErasedStats::I8(x) => (x.top_value.into(), x.top_count),
104            ErasedStats::I16(x) => (x.top_value.into(), x.top_count),
105            ErasedStats::I32(x) => (x.top_value.into(), x.top_count),
106            ErasedStats::I64(x) => (x.top_value.into(), x.top_count),
107        }
108    }
109}
110
111macro_rules! impl_from_typed {
112    ($T:ty, $variant:path) => {
113        impl From<TypedStats<$T>> for ErasedStats {
114            fn from(typed: TypedStats<$T>) -> Self {
115                $variant(typed)
116            }
117        }
118    };
119}
120
121impl_from_typed!(u8, ErasedStats::U8);
122impl_from_typed!(u16, ErasedStats::U16);
123impl_from_typed!(u32, ErasedStats::U32);
124impl_from_typed!(u64, ErasedStats::U64);
125impl_from_typed!(i8, ErasedStats::I8);
126impl_from_typed!(i16, ErasedStats::I16);
127impl_from_typed!(i32, ErasedStats::I32);
128impl_from_typed!(i64, ErasedStats::I64);
129
130/// Array of integers and relevant stats for compression.
131#[derive(Clone, Debug)]
132pub struct IntegerStats {
133    pub(super) src: PrimitiveArray,
134    // cache for validity.false_count()
135    pub(super) null_count: u32,
136    // cache for validity.true_count()
137    pub(super) value_count: u32,
138    pub(super) average_run_length: u32,
139    pub(super) distinct_values_count: u32,
140    pub(crate) typed: ErasedStats,
141}
142
143impl IntegerStats {
144    fn generate_opts_fallible(
145        input: &PrimitiveArray,
146        opts: GenerateStatsOptions,
147    ) -> VortexResult<Self> {
148        match_each_integer_ptype!(input.ptype(), |T| {
149            typed_int_stats::<T>(input, opts.count_distinct_values)
150        })
151    }
152}
153
154impl CompressorStats for IntegerStats {
155    type ArrayVTable = PrimitiveVTable;
156
157    fn generate_opts(input: &PrimitiveArray, opts: GenerateStatsOptions) -> Self {
158        Self::generate_opts_fallible(input, opts)
159            .vortex_expect("IntegerStats::generate_opts should not fail")
160    }
161
162    fn source(&self) -> &PrimitiveArray {
163        &self.src
164    }
165
166    fn sample_opts(&self, sample_size: u32, sample_count: u32, opts: GenerateStatsOptions) -> Self {
167        let sampled = sample(self.src.as_ref(), sample_size, sample_count).to_primitive();
168
169        Self::generate_opts(&sampled, opts)
170    }
171}
172
173impl RLEStats for IntegerStats {
174    fn value_count(&self) -> u32 {
175        self.value_count
176    }
177
178    fn average_run_length(&self) -> u32 {
179        self.average_run_length
180    }
181
182    fn source(&self) -> &PrimitiveArray {
183        &self.src
184    }
185}
186
187fn typed_int_stats<T>(
188    array: &PrimitiveArray,
189    count_distinct_values: bool,
190) -> VortexResult<IntegerStats>
191where
192    T: IntegerPType + PrimInt + for<'a> TryFrom<&'a Scalar, Error = VortexError>,
193    TypedStats<T>: Into<ErasedStats>,
194    NativeValue<T>: Eq + Hash,
195{
196    // Special case: empty array
197    if array.is_empty() {
198        return Ok(IntegerStats {
199            src: array.clone(),
200            null_count: 0,
201            value_count: 0,
202            average_run_length: 0,
203            distinct_values_count: 0,
204            typed: TypedStats {
205                min: T::max_value(),
206                max: T::min_value(),
207                top_value: T::default(),
208                top_count: 0,
209                distinct_values: HashMap::with_hasher(FxBuildHasher),
210            }
211            .into(),
212        });
213    } else if array.all_invalid()? {
214        return Ok(IntegerStats {
215            src: array.clone(),
216            null_count: u32::try_from(array.len())?,
217            value_count: 0,
218            average_run_length: 0,
219            distinct_values_count: 0,
220            typed: TypedStats {
221                min: T::max_value(),
222                max: T::min_value(),
223                top_value: T::default(),
224                top_count: 0,
225                distinct_values: HashMap::with_hasher(FxBuildHasher),
226            }
227            .into(),
228        });
229    }
230
231    let validity = array.validity_mask()?;
232    let null_count = validity.false_count();
233    let value_count = validity.true_count();
234
235    // Initialize loop state
236    let head_idx = validity
237        .first()
238        .vortex_expect("All null masks have been handled before");
239    let buffer = array.to_buffer::<T>();
240    let head = buffer[head_idx];
241
242    let mut loop_state = LoopState {
243        distinct_values: if count_distinct_values {
244            HashMap::with_capacity_and_hasher(array.len() / 2, FxBuildHasher)
245        } else {
246            HashMap::with_hasher(FxBuildHasher)
247        },
248        prev: head,
249        runs: 1,
250    };
251
252    let sliced = buffer.slice(head_idx..array.len());
253    let mut chunks = sliced.as_slice().chunks_exact(64);
254    match validity.bit_buffer() {
255        AllOr::All => {
256            for chunk in &mut chunks {
257                inner_loop_nonnull(
258                    chunk.try_into().ok().vortex_expect("chunk size must be 64"),
259                    count_distinct_values,
260                    &mut loop_state,
261                )
262            }
263            let remainder = chunks.remainder();
264            inner_loop_naive(
265                remainder,
266                count_distinct_values,
267                &BitBuffer::new_set(remainder.len()),
268                &mut loop_state,
269            );
270        }
271        AllOr::None => unreachable!("All invalid arrays have been handled before"),
272        AllOr::Some(v) => {
273            let mask = v.slice(head_idx..array.len());
274            let mut offset = 0;
275            for chunk in &mut chunks {
276                let validity = mask.slice(offset..(offset + 64));
277                offset += 64;
278
279                match validity.true_count() {
280                    // All nulls -> no stats to update
281                    0 => continue,
282                    // Inner loop for when validity check can be elided
283                    64 => inner_loop_nonnull(
284                        chunk.try_into().ok().vortex_expect("chunk size must be 64"),
285                        count_distinct_values,
286                        &mut loop_state,
287                    ),
288                    // Inner loop for when we need to check validity
289                    _ => inner_loop_nullable(
290                        chunk.try_into().ok().vortex_expect("chunk size must be 64"),
291                        count_distinct_values,
292                        &validity,
293                        &mut loop_state,
294                    ),
295                }
296            }
297            // Final iteration, run naive loop
298            let remainder = chunks.remainder();
299            inner_loop_naive(
300                remainder,
301                count_distinct_values,
302                &mask.slice(offset..(offset + remainder.len())),
303                &mut loop_state,
304            );
305        }
306    }
307
308    let (top_value, top_count) = if count_distinct_values {
309        let (&top_value, &top_count) = loop_state
310            .distinct_values
311            .iter()
312            .max_by_key(|&(_, &count)| count)
313            .vortex_expect("non-empty");
314        (top_value.0, top_count)
315    } else {
316        (T::default(), 0)
317    };
318
319    let runs = loop_state.runs;
320    let distinct_values_count = if count_distinct_values {
321        u32::try_from(loop_state.distinct_values.len())?
322    } else {
323        u32::MAX
324    };
325
326    let min = array
327        .statistics()
328        .compute_as::<T>(Stat::Min)
329        .vortex_expect("min should be computed");
330
331    let max = array
332        .statistics()
333        .compute_as::<T>(Stat::Max)
334        .vortex_expect("max should be computed");
335
336    let typed = TypedStats {
337        min,
338        max,
339        distinct_values: loop_state.distinct_values,
340        top_value,
341        top_count,
342    };
343
344    let null_count = u32::try_from(null_count)?;
345    let value_count = u32::try_from(value_count)?;
346
347    Ok(IntegerStats {
348        src: array.clone(),
349        null_count,
350        value_count,
351        average_run_length: value_count / runs,
352        distinct_values_count,
353        typed: typed.into(),
354    })
355}
356
357struct LoopState<T> {
358    prev: T,
359    runs: u32,
360    distinct_values: HashMap<NativeValue<T>, u32, FxBuildHasher>,
361}
362
363#[inline(always)]
364fn inner_loop_nonnull<T: IntegerPType>(
365    values: &[T; 64],
366    count_distinct_values: bool,
367    state: &mut LoopState<T>,
368) where
369    NativeValue<T>: Eq + Hash,
370{
371    for &value in values {
372        if count_distinct_values {
373            *state.distinct_values.entry(NativeValue(value)).or_insert(0) += 1;
374        }
375
376        if value != state.prev {
377            state.prev = value;
378            state.runs += 1;
379        }
380    }
381}
382
383#[inline(always)]
384fn inner_loop_nullable<T: IntegerPType>(
385    values: &[T; 64],
386    count_distinct_values: bool,
387    is_valid: &BitBuffer,
388    state: &mut LoopState<T>,
389) where
390    NativeValue<T>: Eq + Hash,
391{
392    for (idx, &value) in values.iter().enumerate() {
393        if is_valid.value(idx) {
394            if count_distinct_values {
395                *state.distinct_values.entry(NativeValue(value)).or_insert(0) += 1;
396            }
397
398            if value != state.prev {
399                state.prev = value;
400                state.runs += 1;
401            }
402        }
403    }
404}
405
406#[inline(always)]
407fn inner_loop_naive<T: IntegerPType>(
408    values: &[T],
409    count_distinct_values: bool,
410    is_valid: &BitBuffer,
411    state: &mut LoopState<T>,
412) where
413    NativeValue<T>: Eq + Hash,
414{
415    for (idx, &value) in values.iter().enumerate() {
416        if is_valid.value(idx) {
417            if count_distinct_values {
418                *state.distinct_values.entry(NativeValue(value)).or_insert(0) += 1;
419            }
420
421            if value != state.prev {
422                state.prev = value;
423                state.runs += 1;
424            }
425        }
426    }
427}
428
#[cfg(test)]
mod tests {
    use std::iter;

    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::validity::Validity;
    use vortex_buffer::BitBuffer;
    use vortex_buffer::Buffer;
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;

    use super::IntegerStats;
    use super::typed_int_stats;
    use crate::CompressorStats;

    #[test]
    fn test_naive_count_distinct_values() -> VortexResult<()> {
        let array = PrimitiveArray::new(buffer![217u8, 0], Validity::NonNullable);
        let stats = typed_int_stats::<u8>(&array, true)?;
        assert_eq!(stats.distinct_values_count, 2);
        Ok(())
    }

    #[test]
    fn test_naive_count_distinct_values_nullable() -> VortexResult<()> {
        let array = PrimitiveArray::new(
            buffer![217u8, 0],
            Validity::from(BitBuffer::from(vec![true, false])),
        );
        let stats = typed_int_stats::<u8>(&array, true)?;
        assert_eq!(stats.distinct_values_count, 1);
        Ok(())
    }

    #[test]
    fn test_count_distinct_values() -> VortexResult<()> {
        let array = PrimitiveArray::new((0..128u8).collect::<Buffer<u8>>(), Validity::NonNullable);
        let stats = typed_int_stats::<u8>(&array, true)?;
        assert_eq!(stats.distinct_values_count, 128);
        Ok(())
    }

    #[test]
    fn test_count_distinct_values_nullable() -> VortexResult<()> {
        let array = PrimitiveArray::new(
            (0..128u8).collect::<Buffer<u8>>(),
            Validity::from(BitBuffer::from_iter(
                iter::repeat_n(vec![true, false], 64).flatten(),
            )),
        );
        let stats = typed_int_stats::<u8>(&array, true)?;
        assert_eq!(stats.distinct_values_count, 64);
        Ok(())
    }

    #[test]
    fn test_integer_stats_leading_nulls() {
        let ints = PrimitiveArray::new(buffer![0, 1, 2], Validity::from_iter([false, true, true]));

        let stats = IntegerStats::generate(&ints);

        assert_eq!(stats.value_count, 2);
        assert_eq!(stats.null_count, 1);
        assert_eq!(stats.average_run_length, 1);
        assert_eq!(stats.distinct_values_count, 2);
    }

    #[test]
    fn test_empty_array() {
        let array = PrimitiveArray::new((0..0u32).collect::<Buffer<u32>>(), Validity::NonNullable);

        let stats = IntegerStats::generate(&array);

        assert_eq!(stats.value_count, 0);
        assert_eq!(stats.null_count, 0);
        assert_eq!(stats.average_run_length, 0);
        assert_eq!(stats.distinct_values_count, 0);
    }

    #[test]
    fn test_all_null_array() {
        let array = PrimitiveArray::new(buffer![1u32, 2, 3], Validity::AllInvalid);

        let stats = IntegerStats::generate(&array);

        assert_eq!(stats.value_count, 0);
        assert_eq!(stats.null_count, 3);
        assert_eq!(stats.average_run_length, 0);
        assert_eq!(stats.distinct_values_count, 0);
    }

    #[test]
    fn test_average_run_length_and_top_count() -> VortexResult<()> {
        let array = PrimitiveArray::new(buffer![1u32, 1, 1, 2, 2], Validity::NonNullable);
        let stats = typed_int_stats::<u32>(&array, true)?;

        // Two runs over five values; integer division rounds down.
        assert_eq!(stats.average_run_length, 2);
        assert_eq!(stats.distinct_values_count, 2);
        assert_eq!(stats.typed.top_value_and_count().1, 3);
        assert_eq!(stats.typed.max_minus_min(), 1);
        Ok(())
    }

    #[test]
    fn test_nulls_do_not_break_runs() -> VortexResult<()> {
        let array = PrimitiveArray::new(
            buffer![5u32, 0, 5, 5],
            Validity::from_iter([true, false, true, true]),
        );
        let stats = typed_int_stats::<u32>(&array, true)?;

        assert_eq!(stats.value_count, 3);
        assert_eq!(stats.null_count, 1);
        assert_eq!(stats.average_run_length, 3);
        assert_eq!(stats.distinct_values_count, 1);
        Ok(())
    }

    #[test]
    fn test_fully_null_chunk_is_skipped() -> VortexResult<()> {
        // Only the first and last of 130 values are valid, so the second 64-value chunk
        // (relative to the first valid value) is entirely null.
        let array = PrimitiveArray::new(
            iter::repeat_n(7u32, 130).collect::<Buffer<u32>>(),
            Validity::from_iter((0..130).map(|i| i == 0 || i == 129)),
        );
        let stats = typed_int_stats::<u32>(&array, true)?;

        assert_eq!(stats.value_count, 2);
        assert_eq!(stats.null_count, 128);
        assert_eq!(stats.average_run_length, 2);
        assert_eq!(stats.distinct_values_count, 1);
        Ok(())
    }

    #[test]
    fn test_max_minus_min_full_i64_range() -> VortexResult<()> {
        let array = PrimitiveArray::new(buffer![i64::MIN, 0, i64::MAX], Validity::NonNullable);
        let stats = typed_int_stats::<i64>(&array, true)?;

        assert_eq!(stats.typed.max_minus_min(), u64::MAX);
        assert!(stats.typed.min_is_negative());
        assert!(!stats.typed.min_is_zero());
        Ok(())
    }
}