vortex_btrblocks/float/
stats.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::hash::Hash;
5
6use itertools::Itertools;
7use num_traits::Float;
8use rustc_hash::FxBuildHasher;
9use vortex_array::ToCanonical;
10use vortex_array::arrays::{NativeValue, PrimitiveArray, PrimitiveVTable};
11use vortex_dtype::half::f16;
12use vortex_dtype::{NativePType, PType};
13use vortex_error::{VortexExpect, VortexUnwrap, vortex_panic};
14use vortex_mask::AllOr;
15use vortex_utils::aliases::hash_set::HashSet;
16
17use crate::rle::RLEStats;
18use crate::sample::sample;
19use crate::{CompressorStats, GenerateStatsOptions};
20
21#[derive(Debug, Clone)]
22pub struct DistinctValues<T> {
23    pub values: HashSet<NativeValue<T>, FxBuildHasher>,
24}
25
26#[derive(Debug, Clone)]
27pub enum ErasedDistinctValues {
28    F16(DistinctValues<f16>),
29    F32(DistinctValues<f32>),
30    F64(DistinctValues<f64>),
31}
32
33macro_rules! impl_from_typed {
34    ($typ:ty, $variant:path) => {
35        impl From<DistinctValues<$typ>> for ErasedDistinctValues {
36            fn from(value: DistinctValues<$typ>) -> Self {
37                $variant(value)
38            }
39        }
40    };
41}
42
43impl_from_typed!(f16, ErasedDistinctValues::F16);
44impl_from_typed!(f32, ErasedDistinctValues::F32);
45impl_from_typed!(f64, ErasedDistinctValues::F64);
46
47/// Array of floating-point numbers and relevant stats for compression.
48#[derive(Debug, Clone)]
49pub struct FloatStats {
50    pub(super) src: PrimitiveArray,
51    // cache for validity.false_count()
52    pub(super) null_count: u32,
53    // cache for validity.true_count()
54    pub(super) value_count: u32,
55    #[allow(dead_code)]
56    pub(super) average_run_length: u32,
57    pub(super) distinct_values: ErasedDistinctValues,
58    pub(super) distinct_values_count: u32,
59}
60
61impl CompressorStats for FloatStats {
62    type ArrayVTable = PrimitiveVTable;
63
64    fn generate_opts(input: &PrimitiveArray, opts: GenerateStatsOptions) -> Self {
65        match input.ptype() {
66            PType::F16 => typed_float_stats::<f16>(input, opts.count_distinct_values),
67            PType::F32 => typed_float_stats::<f32>(input, opts.count_distinct_values),
68            PType::F64 => typed_float_stats::<f64>(input, opts.count_distinct_values),
69            _ => vortex_panic!("cannot generate FloatStats from ptype {}", input.ptype()),
70        }
71    }
72
73    fn source(&self) -> &PrimitiveArray {
74        &self.src
75    }
76
77    fn sample_opts(&self, sample_size: u32, sample_count: u32, opts: GenerateStatsOptions) -> Self {
78        let sampled = sample(self.src.as_ref(), sample_size, sample_count).to_primitive();
79
80        Self::generate_opts(&sampled, opts)
81    }
82}
83
84impl RLEStats for FloatStats {
85    fn value_count(&self) -> u32 {
86        self.value_count
87    }
88
89    fn average_run_length(&self) -> u32 {
90        self.average_run_length
91    }
92
93    fn source(&self) -> &PrimitiveArray {
94        &self.src
95    }
96}
97
98fn typed_float_stats<T: NativePType + Float>(
99    array: &PrimitiveArray,
100    count_distinct_values: bool,
101) -> FloatStats
102where
103    DistinctValues<T>: Into<ErasedDistinctValues>,
104    NativeValue<T>: Hash + Eq,
105{
106    // Special case: empty array
107    if array.is_empty() {
108        return FloatStats {
109            src: array.clone(),
110            null_count: 0,
111            value_count: 0,
112            average_run_length: 0,
113            distinct_values_count: 0,
114            distinct_values: DistinctValues {
115                values: HashSet::<NativeValue<T>, FxBuildHasher>::with_hasher(FxBuildHasher),
116            }
117            .into(),
118        };
119    } else if array.all_invalid() {
120        return FloatStats {
121            src: array.clone(),
122            null_count: array.len().try_into().vortex_expect("null_count"),
123            value_count: 0,
124            average_run_length: 0,
125            distinct_values_count: 0,
126            distinct_values: DistinctValues {
127                values: HashSet::<NativeValue<T>, FxBuildHasher>::with_hasher(FxBuildHasher),
128            }
129            .into(),
130        };
131    }
132
133    let null_count = array
134        .statistics()
135        .compute_null_count()
136        .vortex_expect("null count");
137    let value_count = array.len() - null_count;
138
139    // Keep a HashMap of T, then convert the keys into PValue afterward since value is
140    // so much more efficient to hash and search for.
141    let mut distinct_values = if count_distinct_values {
142        HashSet::with_capacity_and_hasher(array.len() / 2, FxBuildHasher)
143    } else {
144        HashSet::with_hasher(FxBuildHasher)
145    };
146
147    let validity = array.validity_mask();
148
149    let mut runs = 1;
150    let head_idx = validity
151        .first()
152        .vortex_expect("All null masks have been handled before");
153    let buff = array.buffer::<T>();
154    let mut prev = buff[head_idx];
155
156    let first_valid_buff = buff.slice(head_idx..array.len());
157    match validity.bit_buffer() {
158        AllOr::All => {
159            for value in first_valid_buff {
160                if count_distinct_values {
161                    distinct_values.insert(NativeValue(value));
162                }
163
164                if value != prev {
165                    prev = value;
166                    runs += 1;
167                }
168            }
169        }
170        AllOr::None => unreachable!("All invalid arrays have been handled earlier"),
171        AllOr::Some(v) => {
172            for (&value, valid) in first_valid_buff
173                .iter()
174                .zip_eq(v.slice(head_idx..array.len()).iter())
175            {
176                if valid {
177                    if count_distinct_values {
178                        distinct_values.insert(NativeValue(value));
179                    }
180
181                    if value != prev {
182                        prev = value;
183                        runs += 1;
184                    }
185                }
186            }
187        }
188    }
189
190    let null_count = null_count
191        .try_into()
192        .vortex_expect("null_count must fit in u32");
193    let value_count = value_count
194        .try_into()
195        .vortex_expect("null_count must fit in u32");
196    let distinct_values_count = if count_distinct_values {
197        distinct_values.len().try_into().vortex_unwrap()
198    } else {
199        u32::MAX
200    };
201
202    FloatStats {
203        null_count,
204        value_count,
205        distinct_values_count,
206        src: array.clone(),
207        average_run_length: value_count / runs,
208        distinct_values: DistinctValues {
209            values: distinct_values,
210        }
211        .into(),
212    }
213}
214
#[cfg(test)]
mod tests {
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::validity::Validity;
    use vortex_array::{IntoArray, ToCanonical};
    use vortex_buffer::buffer;

    use crate::CompressorStats;
    use crate::float::stats::FloatStats;

    #[test]
    fn test_float_stats() {
        let floats = buffer![0.0f32, 1.0f32, 2.0f32].into_array();
        let floats = floats.to_primitive();

        let stats = FloatStats::generate(&floats);

        assert_eq!(stats.value_count, 3);
        assert_eq!(stats.null_count, 0);
        assert_eq!(stats.average_run_length, 1);
        assert_eq!(stats.distinct_values_count, 3);
    }

    #[test]
    fn test_float_stats_leading_nulls() {
        let floats = PrimitiveArray::new(
            buffer![0.0f32, 1.0f32, 2.0f32],
            Validity::from_iter([false, true, true]),
        );

        let stats = FloatStats::generate(&floats);

        assert_eq!(stats.value_count, 2);
        assert_eq!(stats.null_count, 1);
        assert_eq!(stats.average_run_length, 1);
        assert_eq!(stats.distinct_values_count, 2);
    }

    #[test]
    fn test_float_stats_repeated_values() {
        // Two runs of length two: [1.0, 1.0], [2.0, 2.0].
        let floats = buffer![1.0f64, 1.0f64, 2.0f64, 2.0f64]
            .into_array()
            .to_primitive();

        let stats = FloatStats::generate(&floats);

        assert_eq!(stats.value_count, 4);
        assert_eq!(stats.null_count, 0);
        assert_eq!(stats.average_run_length, 2);
        assert_eq!(stats.distinct_values_count, 2);
    }

    #[test]
    fn test_float_stats_nulls_do_not_break_runs() {
        // Valid values are 1.0, 1.0, 1.0, 2.0. The null between the 1.0s is skipped, so there
        // are two runs rather than three.
        let floats = PrimitiveArray::new(
            buffer![1.0f32, 9.0f32, 1.0f32, 1.0f32, 2.0f32],
            Validity::from_iter([true, false, true, true, true]),
        );

        let stats = FloatStats::generate(&floats);

        assert_eq!(stats.value_count, 4);
        assert_eq!(stats.null_count, 1);
        assert_eq!(stats.average_run_length, 2);
        assert_eq!(stats.distinct_values_count, 2);
    }

    #[test]
    fn test_float_stats_all_null() {
        let floats = PrimitiveArray::new(
            buffer![1.0f32, 2.0f32],
            Validity::from_iter([false, false]),
        );

        let stats = FloatStats::generate(&floats);

        assert_eq!(stats.value_count, 0);
        assert_eq!(stats.null_count, 2);
        assert_eq!(stats.average_run_length, 0);
        assert_eq!(stats.distinct_values_count, 0);
    }
}