vortex_btrblocks/float/
stats.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::hash::Hash;
5
6use itertools::Itertools;
7use num_traits::Float;
8use rustc_hash::FxBuildHasher;
9use vortex_array::ToCanonical;
10use vortex_array::arrays::NativeValue;
11use vortex_array::arrays::PrimitiveArray;
12use vortex_array::arrays::PrimitiveVTable;
13use vortex_dtype::NativePType;
14use vortex_dtype::PType;
15use vortex_dtype::half::f16;
16use vortex_error::VortexExpect;
17use vortex_error::VortexUnwrap;
18use vortex_error::vortex_panic;
19use vortex_mask::AllOr;
20use vortex_utils::aliases::hash_set::HashSet;
21
22use crate::CompressorStats;
23use crate::GenerateStatsOptions;
24use crate::rle::RLEStats;
25use crate::sample::sample;
26
27#[derive(Debug, Clone)]
28pub struct DistinctValues<T> {
29    pub values: HashSet<NativeValue<T>, FxBuildHasher>,
30}
31
32#[derive(Debug, Clone)]
33pub enum ErasedDistinctValues {
34    F16(DistinctValues<f16>),
35    F32(DistinctValues<f32>),
36    F64(DistinctValues<f64>),
37}
38
39macro_rules! impl_from_typed {
40    ($typ:ty, $variant:path) => {
41        impl From<DistinctValues<$typ>> for ErasedDistinctValues {
42            fn from(value: DistinctValues<$typ>) -> Self {
43                $variant(value)
44            }
45        }
46    };
47}
48
49impl_from_typed!(f16, ErasedDistinctValues::F16);
50impl_from_typed!(f32, ErasedDistinctValues::F32);
51impl_from_typed!(f64, ErasedDistinctValues::F64);
52
53/// Array of floating-point numbers and relevant stats for compression.
54#[derive(Debug, Clone)]
55pub struct FloatStats {
56    pub(super) src: PrimitiveArray,
57    // cache for validity.false_count()
58    pub(super) null_count: u32,
59    // cache for validity.true_count()
60    pub(super) value_count: u32,
61    #[allow(dead_code)]
62    pub(super) average_run_length: u32,
63    pub(super) distinct_values: ErasedDistinctValues,
64    pub(super) distinct_values_count: u32,
65}
66
67impl CompressorStats for FloatStats {
68    type ArrayVTable = PrimitiveVTable;
69
70    fn generate_opts(input: &PrimitiveArray, opts: GenerateStatsOptions) -> Self {
71        match input.ptype() {
72            PType::F16 => typed_float_stats::<f16>(input, opts.count_distinct_values),
73            PType::F32 => typed_float_stats::<f32>(input, opts.count_distinct_values),
74            PType::F64 => typed_float_stats::<f64>(input, opts.count_distinct_values),
75            _ => vortex_panic!("cannot generate FloatStats from ptype {}", input.ptype()),
76        }
77    }
78
79    fn source(&self) -> &PrimitiveArray {
80        &self.src
81    }
82
83    fn sample_opts(&self, sample_size: u32, sample_count: u32, opts: GenerateStatsOptions) -> Self {
84        let sampled = sample(self.src.as_ref(), sample_size, sample_count).to_primitive();
85
86        Self::generate_opts(&sampled, opts)
87    }
88}
89
90impl RLEStats for FloatStats {
91    fn value_count(&self) -> u32 {
92        self.value_count
93    }
94
95    fn average_run_length(&self) -> u32 {
96        self.average_run_length
97    }
98
99    fn source(&self) -> &PrimitiveArray {
100        &self.src
101    }
102}
103
104fn typed_float_stats<T: NativePType + Float>(
105    array: &PrimitiveArray,
106    count_distinct_values: bool,
107) -> FloatStats
108where
109    DistinctValues<T>: Into<ErasedDistinctValues>,
110    NativeValue<T>: Hash + Eq,
111{
112    // Special case: empty array
113    if array.is_empty() {
114        return FloatStats {
115            src: array.clone(),
116            null_count: 0,
117            value_count: 0,
118            average_run_length: 0,
119            distinct_values_count: 0,
120            distinct_values: DistinctValues {
121                values: HashSet::<NativeValue<T>, FxBuildHasher>::with_hasher(FxBuildHasher),
122            }
123            .into(),
124        };
125    } else if array.all_invalid() {
126        return FloatStats {
127            src: array.clone(),
128            null_count: array.len().try_into().vortex_expect("null_count"),
129            value_count: 0,
130            average_run_length: 0,
131            distinct_values_count: 0,
132            distinct_values: DistinctValues {
133                values: HashSet::<NativeValue<T>, FxBuildHasher>::with_hasher(FxBuildHasher),
134            }
135            .into(),
136        };
137    }
138
139    let null_count = array
140        .statistics()
141        .compute_null_count()
142        .vortex_expect("null count");
143    let value_count = array.len() - null_count;
144
145    // Keep a HashMap of T, then convert the keys into PValue afterward since value is
146    // so much more efficient to hash and search for.
147    let mut distinct_values = if count_distinct_values {
148        HashSet::with_capacity_and_hasher(array.len() / 2, FxBuildHasher)
149    } else {
150        HashSet::with_hasher(FxBuildHasher)
151    };
152
153    let validity = array.validity_mask();
154
155    let mut runs = 1;
156    let head_idx = validity
157        .first()
158        .vortex_expect("All null masks have been handled before");
159    let buff = array.buffer::<T>();
160    let mut prev = buff[head_idx];
161
162    let first_valid_buff = buff.slice(head_idx..array.len());
163    match validity.bit_buffer() {
164        AllOr::All => {
165            for value in first_valid_buff {
166                if count_distinct_values {
167                    distinct_values.insert(NativeValue(value));
168                }
169
170                if value != prev {
171                    prev = value;
172                    runs += 1;
173                }
174            }
175        }
176        AllOr::None => unreachable!("All invalid arrays have been handled earlier"),
177        AllOr::Some(v) => {
178            for (&value, valid) in first_valid_buff
179                .iter()
180                .zip_eq(v.slice(head_idx..array.len()).iter())
181            {
182                if valid {
183                    if count_distinct_values {
184                        distinct_values.insert(NativeValue(value));
185                    }
186
187                    if value != prev {
188                        prev = value;
189                        runs += 1;
190                    }
191                }
192            }
193        }
194    }
195
196    let null_count = null_count
197        .try_into()
198        .vortex_expect("null_count must fit in u32");
199    let value_count = value_count
200        .try_into()
201        .vortex_expect("null_count must fit in u32");
202    let distinct_values_count = if count_distinct_values {
203        distinct_values.len().try_into().vortex_unwrap()
204    } else {
205        u32::MAX
206    };
207
208    FloatStats {
209        null_count,
210        value_count,
211        distinct_values_count,
212        src: array.clone(),
213        average_run_length: value_count / runs,
214        distinct_values: DistinctValues {
215            values: distinct_values,
216        }
217        .into(),
218    }
219}
220
#[cfg(test)]
mod tests {
    use vortex_array::IntoArray;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::validity::Validity;
    use vortex_buffer::buffer;

    use crate::CompressorStats;
    use crate::float::stats::FloatStats;

    /// Three distinct, all-valid values: three runs of length one.
    #[test]
    fn test_float_stats() {
        let array = buffer![0.0f32, 1.0f32, 2.0f32]
            .into_array()
            .to_primitive();

        let computed = FloatStats::generate(&array);

        assert_eq!(computed.null_count, 0);
        assert_eq!(computed.value_count, 3);
        assert_eq!(computed.distinct_values_count, 3);
        assert_eq!(computed.average_run_length, 1);
    }

    /// A leading null is skipped: only the two valid values contribute to the stats.
    #[test]
    fn test_float_stats_leading_nulls() {
        let values = buffer![0.0f32, 1.0f32, 2.0f32];
        let validity = Validity::from_iter([false, true, true]);
        let array = PrimitiveArray::new(values, validity);

        let computed = FloatStats::generate(&array);

        assert_eq!(computed.null_count, 1);
        assert_eq!(computed.value_count, 2);
        assert_eq!(computed.distinct_values_count, 2);
        assert_eq!(computed.average_run_length, 1);
    }
}
259}