// vortex_btrblocks/integer/stats.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

4use std::hash::Hash;
5
6use num_traits::PrimInt;
7use rustc_hash::FxBuildHasher;
8use vortex_array::ToCanonical;
9use vortex_array::arrays::{NativeValue, PrimitiveArray, PrimitiveVTable};
10use vortex_array::stats::Stat;
11use vortex_buffer::BitBuffer;
12use vortex_dtype::{IntegerPType, match_each_integer_ptype};
13use vortex_error::{VortexError, VortexExpect, VortexUnwrap};
14use vortex_mask::AllOr;
15use vortex_scalar::{PValue, Scalar};
16use vortex_utils::aliases::hash_map::HashMap;
17
18use crate::rle::RLEStats;
19use crate::sample::sample;
20use crate::{CompressorStats, GenerateStatsOptions};
21
22#[derive(Clone, Debug)]
23pub struct TypedStats<T> {
24    pub min: T,
25    pub max: T,
26    pub top_value: T,
27    pub top_count: u32,
28    pub distinct_values: HashMap<NativeValue<T>, u32, FxBuildHasher>,
29}
30
31/// Type-erased container for one of the [TypedStats] variants.
32///
33/// Building the `TypedStats` is considerably faster and cheaper than building a type-erased
34/// set of stats. We then perform a variety of access methods on them.
35#[derive(Clone, Debug)]
36pub enum ErasedStats {
37    U8(TypedStats<u8>),
38    U16(TypedStats<u16>),
39    U32(TypedStats<u32>),
40    U64(TypedStats<u64>),
41    I8(TypedStats<i8>),
42    I16(TypedStats<i16>),
43    I32(TypedStats<i32>),
44    I64(TypedStats<i64>),
45}
46
47impl ErasedStats {
48    pub fn min_is_zero(&self) -> bool {
49        match &self {
50            ErasedStats::U8(x) => x.min == 0,
51            ErasedStats::U16(x) => x.min == 0,
52            ErasedStats::U32(x) => x.min == 0,
53            ErasedStats::U64(x) => x.min == 0,
54            ErasedStats::I8(x) => x.min == 0,
55            ErasedStats::I16(x) => x.min == 0,
56            ErasedStats::I32(x) => x.min == 0,
57            ErasedStats::I64(x) => x.min == 0,
58        }
59    }
60
61    pub fn min_is_negative(&self) -> bool {
62        match &self {
63            ErasedStats::U8(_)
64            | ErasedStats::U16(_)
65            | ErasedStats::U32(_)
66            | ErasedStats::U64(_) => false,
67            ErasedStats::I8(x) => x.min < 0,
68            ErasedStats::I16(x) => x.min < 0,
69            ErasedStats::I32(x) => x.min < 0,
70            ErasedStats::I64(x) => x.min < 0,
71        }
72    }
73
74    // Difference between max and min.
75    pub fn max_minus_min(&self) -> u64 {
76        match &self {
77            ErasedStats::U8(x) => (x.max - x.min) as u64,
78            ErasedStats::U16(x) => (x.max - x.min) as u64,
79            ErasedStats::U32(x) => (x.max - x.min) as u64,
80            ErasedStats::U64(x) => x.max - x.min,
81            ErasedStats::I8(x) => (x.max as i16 - x.min as i16) as u64,
82            ErasedStats::I16(x) => (x.max as i32 - x.min as i32) as u64,
83            ErasedStats::I32(x) => (x.max as i64 - x.min as i64) as u64,
84            ErasedStats::I64(x) => u64::try_from(x.max as i128 - x.min as i128)
85                .vortex_expect("max minus min result bigger than u64"),
86        }
87    }
88
89    /// Get the most commonly occurring value and its count
90    pub fn top_value_and_count(&self) -> (PValue, u32) {
91        match &self {
92            ErasedStats::U8(x) => (x.top_value.into(), x.top_count),
93            ErasedStats::U16(x) => (x.top_value.into(), x.top_count),
94            ErasedStats::U32(x) => (x.top_value.into(), x.top_count),
95            ErasedStats::U64(x) => (x.top_value.into(), x.top_count),
96            ErasedStats::I8(x) => (x.top_value.into(), x.top_count),
97            ErasedStats::I16(x) => (x.top_value.into(), x.top_count),
98            ErasedStats::I32(x) => (x.top_value.into(), x.top_count),
99            ErasedStats::I64(x) => (x.top_value.into(), x.top_count),
100        }
101    }
102}
103
104macro_rules! impl_from_typed {
105    ($T:ty, $variant:path) => {
106        impl From<TypedStats<$T>> for ErasedStats {
107            fn from(typed: TypedStats<$T>) -> Self {
108                $variant(typed)
109            }
110        }
111    };
112}
113
114impl_from_typed!(u8, ErasedStats::U8);
115impl_from_typed!(u16, ErasedStats::U16);
116impl_from_typed!(u32, ErasedStats::U32);
117impl_from_typed!(u64, ErasedStats::U64);
118impl_from_typed!(i8, ErasedStats::I8);
119impl_from_typed!(i16, ErasedStats::I16);
120impl_from_typed!(i32, ErasedStats::I32);
121impl_from_typed!(i64, ErasedStats::I64);
122
123/// Array of integers and relevant stats for compression.
124#[derive(Clone, Debug)]
125pub struct IntegerStats {
126    pub(super) src: PrimitiveArray,
127    // cache for validity.false_count()
128    pub(super) null_count: u32,
129    // cache for validity.true_count()
130    pub(super) value_count: u32,
131    pub(super) average_run_length: u32,
132    pub(super) distinct_values_count: u32,
133    pub(crate) typed: ErasedStats,
134}
135
136impl CompressorStats for IntegerStats {
137    type ArrayVTable = PrimitiveVTable;
138
139    fn generate_opts(input: &PrimitiveArray, opts: GenerateStatsOptions) -> Self {
140        match_each_integer_ptype!(input.ptype(), |T| {
141            typed_int_stats::<T>(input, opts.count_distinct_values)
142        })
143    }
144
145    fn source(&self) -> &PrimitiveArray {
146        &self.src
147    }
148
149    fn sample_opts(&self, sample_size: u32, sample_count: u32, opts: GenerateStatsOptions) -> Self {
150        let sampled = sample(self.src.as_ref(), sample_size, sample_count).to_primitive();
151
152        Self::generate_opts(&sampled, opts)
153    }
154}
155
156impl RLEStats for IntegerStats {
157    fn value_count(&self) -> u32 {
158        self.value_count
159    }
160
161    fn average_run_length(&self) -> u32 {
162        self.average_run_length
163    }
164
165    fn source(&self) -> &PrimitiveArray {
166        &self.src
167    }
168}
169
170fn typed_int_stats<T>(array: &PrimitiveArray, count_distinct_values: bool) -> IntegerStats
171where
172    T: IntegerPType + PrimInt + for<'a> TryFrom<&'a Scalar, Error = VortexError>,
173    TypedStats<T>: Into<ErasedStats>,
174    NativeValue<T>: Eq + Hash,
175{
176    // Special case: empty array
177    if array.is_empty() {
178        return IntegerStats {
179            src: array.clone(),
180            null_count: 0,
181            value_count: 0,
182            average_run_length: 0,
183            distinct_values_count: 0,
184            typed: TypedStats {
185                min: T::max_value(),
186                max: T::min_value(),
187                top_value: T::default(),
188                top_count: 0,
189                distinct_values: HashMap::with_hasher(FxBuildHasher),
190            }
191            .into(),
192        };
193    } else if array.all_invalid() {
194        return IntegerStats {
195            src: array.clone(),
196            null_count: array.len().try_into().vortex_expect("null_count"),
197            value_count: 0,
198            average_run_length: 0,
199            distinct_values_count: 0,
200            typed: TypedStats {
201                min: T::max_value(),
202                max: T::min_value(),
203                top_value: T::default(),
204                top_count: 0,
205                distinct_values: HashMap::with_hasher(FxBuildHasher),
206            }
207            .into(),
208        };
209    }
210
211    let validity = array.validity_mask();
212    let null_count = validity.false_count();
213    let value_count = validity.true_count();
214
215    // Initialize loop state
216    let head_idx = validity
217        .first()
218        .vortex_expect("All null masks have been handled before");
219    let buffer = array.buffer::<T>();
220    let head = buffer[head_idx];
221
222    let mut loop_state = LoopState {
223        distinct_values: if count_distinct_values {
224            HashMap::with_capacity_and_hasher(array.len() / 2, FxBuildHasher)
225        } else {
226            HashMap::with_hasher(FxBuildHasher)
227        },
228        prev: head,
229        runs: 1,
230    };
231
232    let sliced = buffer.slice(head_idx..array.len());
233    let mut chunks = sliced.as_slice().chunks_exact(64);
234    match validity.bit_buffer() {
235        AllOr::All => {
236            for chunk in &mut chunks {
237                inner_loop_nonnull(
238                    chunk.try_into().vortex_unwrap(),
239                    count_distinct_values,
240                    &mut loop_state,
241                )
242            }
243            let remainder = chunks.remainder();
244            inner_loop_naive(
245                remainder,
246                count_distinct_values,
247                &BitBuffer::new_set(remainder.len()),
248                &mut loop_state,
249            );
250        }
251        AllOr::None => unreachable!("All invalid arrays have been handled before"),
252        AllOr::Some(v) => {
253            let mask = v.slice(head_idx..array.len());
254            let mut offset = 0;
255            for chunk in &mut chunks {
256                let validity = mask.slice(offset..(offset + 64));
257                offset += 64;
258
259                match validity.true_count() {
260                    // All nulls -> no stats to update
261                    0 => continue,
262                    // Inner loop for when validity check can be elided
263                    64 => inner_loop_nonnull(
264                        chunk.try_into().vortex_unwrap(),
265                        count_distinct_values,
266                        &mut loop_state,
267                    ),
268                    // Inner loop for when we need to check validity
269                    _ => inner_loop_nullable(
270                        chunk.try_into().vortex_unwrap(),
271                        count_distinct_values,
272                        &validity,
273                        &mut loop_state,
274                    ),
275                }
276            }
277            // Final iteration, run naive loop
278            let remainder = chunks.remainder();
279            inner_loop_naive(
280                remainder,
281                count_distinct_values,
282                &mask.slice(offset..(offset + remainder.len())),
283                &mut loop_state,
284            );
285        }
286    }
287
288    let (top_value, top_count) = if count_distinct_values {
289        let (&top_value, &top_count) = loop_state
290            .distinct_values
291            .iter()
292            .max_by_key(|&(_, &count)| count)
293            .vortex_expect("non-empty");
294        (top_value.0, top_count)
295    } else {
296        (T::default(), 0)
297    };
298
299    let runs = loop_state.runs;
300    let distinct_values_count = if count_distinct_values {
301        loop_state.distinct_values.len().try_into().vortex_unwrap()
302    } else {
303        u32::MAX
304    };
305
306    let min = array
307        .statistics()
308        .compute_as::<T>(Stat::Min)
309        .vortex_expect("min should be computed");
310
311    let max = array
312        .statistics()
313        .compute_as::<T>(Stat::Max)
314        .vortex_expect("max should be computed");
315
316    let typed = TypedStats {
317        min,
318        max,
319        distinct_values: loop_state.distinct_values,
320        top_value,
321        top_count,
322    };
323
324    let null_count = null_count
325        .try_into()
326        .vortex_expect("null_count must fit in u32");
327    let value_count = value_count
328        .try_into()
329        .vortex_expect("value_count must fit in u32");
330
331    IntegerStats {
332        src: array.clone(),
333        null_count,
334        value_count,
335        average_run_length: value_count / runs,
336        distinct_values_count,
337        typed: typed.into(),
338    }
339}
340
341struct LoopState<T> {
342    prev: T,
343    runs: u32,
344    distinct_values: HashMap<NativeValue<T>, u32, FxBuildHasher>,
345}
346
347#[inline(always)]
348fn inner_loop_nonnull<T: IntegerPType>(
349    values: &[T; 64],
350    count_distinct_values: bool,
351    state: &mut LoopState<T>,
352) where
353    NativeValue<T>: Eq + Hash,
354{
355    for &value in values {
356        if count_distinct_values {
357            *state.distinct_values.entry(NativeValue(value)).or_insert(0) += 1;
358        }
359
360        if value != state.prev {
361            state.prev = value;
362            state.runs += 1;
363        }
364    }
365}
366
367#[inline(always)]
368fn inner_loop_nullable<T: IntegerPType>(
369    values: &[T; 64],
370    count_distinct_values: bool,
371    is_valid: &BitBuffer,
372    state: &mut LoopState<T>,
373) where
374    NativeValue<T>: Eq + Hash,
375{
376    for (idx, &value) in values.iter().enumerate() {
377        if is_valid.value(idx) {
378            if count_distinct_values {
379                *state.distinct_values.entry(NativeValue(value)).or_insert(0) += 1;
380            }
381
382            if value != state.prev {
383                state.prev = value;
384                state.runs += 1;
385            }
386        }
387    }
388}
389
390#[inline(always)]
391fn inner_loop_naive<T: IntegerPType>(
392    values: &[T],
393    count_distinct_values: bool,
394    is_valid: &BitBuffer,
395    state: &mut LoopState<T>,
396) where
397    NativeValue<T>: Eq + Hash,
398{
399    for (idx, &value) in values.iter().enumerate() {
400        if is_valid.value(idx) {
401            if count_distinct_values {
402                *state.distinct_values.entry(NativeValue(value)).or_insert(0) += 1;
403            }
404
405            if value != state.prev {
406                state.prev = value;
407                state.runs += 1;
408            }
409        }
410    }
411}
412
#[cfg(test)]
mod tests {
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::validity::Validity;
    use vortex_buffer::{BitBuffer, Buffer, buffer};

    use crate::CompressorStats;
    use crate::integer::IntegerStats;
    use crate::integer::stats::typed_int_stats;

    /// Fewer than 64 values: everything goes through the naive remainder loop.
    #[test]
    fn test_naive_count_distinct_values() {
        let array = PrimitiveArray::new(buffer![217u8, 0], Validity::NonNullable);
        assert_eq!(typed_int_stats::<u8>(&array, true).distinct_values_count, 2);
    }

    /// The null second entry must not be counted as a distinct value.
    #[test]
    fn test_naive_count_distinct_values_nullable() {
        let validity = Validity::from(BitBuffer::from(vec![true, false]));
        let array = PrimitiveArray::new(buffer![217u8, 0], validity);
        assert_eq!(typed_int_stats::<u8>(&array, true).distinct_values_count, 1);
    }

    /// Two full 64-value chunks of all-valid, all-distinct values.
    #[test]
    fn test_count_distinct_values() {
        let values: Buffer<u8> = (0..128u8).collect();
        let array = PrimitiveArray::new(values, Validity::NonNullable);
        assert_eq!(typed_int_stats::<u8>(&array, true).distinct_values_count, 128);
    }

    /// Alternating validity forces the mixed-validity chunk loop; only even indices are valid.
    #[test]
    fn test_count_distinct_values_nullable() {
        let values: Buffer<u8> = (0..128u8).collect();
        let validity = Validity::from(BitBuffer::from_iter((0..128).map(|i| i % 2 == 0)));
        let array = PrimitiveArray::new(values, validity);
        assert_eq!(typed_int_stats::<u8>(&array, true).distinct_values_count, 64);
    }

    /// A leading null must be skipped when seeding the run/distinct scan.
    #[test]
    fn test_integer_stats_leading_nulls() {
        let ints = PrimitiveArray::new(buffer![0, 1, 2], Validity::from_iter([false, true, true]));

        let stats = IntegerStats::generate(&ints);

        assert_eq!(stats.value_count, 2);
        assert_eq!(stats.null_count, 1);
        assert_eq!(stats.average_run_length, 1);
        assert_eq!(stats.distinct_values_count, 2);
    }
}