vortex_btrblocks/float/
stats.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::hash::Hash;
5
6use itertools::Itertools;
7use num_traits::Float;
8use rustc_hash::FxBuildHasher;
9use vortex_array::ToCanonical;
10use vortex_array::arrays::NativeValue;
11use vortex_array::arrays::PrimitiveArray;
12use vortex_array::arrays::PrimitiveVTable;
13use vortex_dtype::NativePType;
14use vortex_dtype::PType;
15use vortex_dtype::half::f16;
16use vortex_error::VortexExpect;
17use vortex_error::vortex_panic;
18use vortex_mask::AllOr;
19use vortex_utils::aliases::hash_set::HashSet;
20
21use crate::CompressorStats;
22use crate::GenerateStatsOptions;
23use crate::rle::RLEStats;
24use crate::sample::sample;
25
/// Set of distinct values of element type `T`.
///
/// Elements are stored wrapped in [`NativeValue`] and the set hashes with [`FxBuildHasher`].
#[derive(Debug, Clone)]
pub struct DistinctValues<T> {
    /// The distinct values, each wrapped in [`NativeValue`].
    pub values: HashSet<NativeValue<T>, FxBuildHasher>,
}
30
/// Type-erased wrapper around [`DistinctValues`], with one variant for each of the
/// `f16`, `f32` and `f64` element types.
#[derive(Debug, Clone)]
pub enum ErasedDistinctValues {
    /// Distinct values of an `f16` array.
    F16(DistinctValues<f16>),
    /// Distinct values of an `f32` array.
    F32(DistinctValues<f32>),
    /// Distinct values of an `f64` array.
    F64(DistinctValues<f64>),
}
37
38macro_rules! impl_from_typed {
39    ($typ:ty, $variant:path) => {
40        impl From<DistinctValues<$typ>> for ErasedDistinctValues {
41            fn from(value: DistinctValues<$typ>) -> Self {
42                $variant(value)
43            }
44        }
45    };
46}
47
48impl_from_typed!(f16, ErasedDistinctValues::F16);
49impl_from_typed!(f32, ErasedDistinctValues::F32);
50impl_from_typed!(f64, ErasedDistinctValues::F64);
51
/// Array of floating-point numbers and relevant stats for compression.
#[derive(Debug, Clone)]
pub struct FloatStats {
    // The array these stats were computed from.
    pub(super) src: PrimitiveArray,
    // cache for validity.false_count()
    pub(super) null_count: u32,
    // cache for validity.true_count()
    pub(super) value_count: u32,
    // `value_count` divided (integer division) by the number of runs of equal consecutive
    // valid values; 0 for empty and all-null arrays.
    #[allow(dead_code)]
    pub(super) average_run_length: u32,
    // Distinct valid values of `src`. Left empty when distinct counting was not requested,
    // and for empty and all-null arrays.
    pub(super) distinct_values: ErasedDistinctValues,
    // Number of entries in `distinct_values` when distinct counting was requested.
    // When it was not requested this is `u32::MAX`, except for empty and all-null arrays,
    // which always report 0.
    pub(super) distinct_values_count: u32,
}
65
66impl CompressorStats for FloatStats {
67    type ArrayVTable = PrimitiveVTable;
68
69    fn generate_opts(input: &PrimitiveArray, opts: GenerateStatsOptions) -> Self {
70        match input.ptype() {
71            PType::F16 => typed_float_stats::<f16>(input, opts.count_distinct_values),
72            PType::F32 => typed_float_stats::<f32>(input, opts.count_distinct_values),
73            PType::F64 => typed_float_stats::<f64>(input, opts.count_distinct_values),
74            _ => vortex_panic!("cannot generate FloatStats from ptype {}", input.ptype()),
75        }
76    }
77
78    fn source(&self) -> &PrimitiveArray {
79        &self.src
80    }
81
82    fn sample_opts(&self, sample_size: u32, sample_count: u32, opts: GenerateStatsOptions) -> Self {
83        let sampled = sample(self.src.as_ref(), sample_size, sample_count).to_primitive();
84
85        Self::generate_opts(&sampled, opts)
86    }
87}
88
89impl RLEStats for FloatStats {
90    fn value_count(&self) -> u32 {
91        self.value_count
92    }
93
94    fn average_run_length(&self) -> u32 {
95        self.average_run_length
96    }
97
98    fn source(&self) -> &PrimitiveArray {
99        &self.src
100    }
101}
102
103fn typed_float_stats<T: NativePType + Float>(
104    array: &PrimitiveArray,
105    count_distinct_values: bool,
106) -> FloatStats
107where
108    DistinctValues<T>: Into<ErasedDistinctValues>,
109    NativeValue<T>: Hash + Eq,
110{
111    // Special case: empty array
112    if array.is_empty() {
113        return FloatStats {
114            src: array.clone(),
115            null_count: 0,
116            value_count: 0,
117            average_run_length: 0,
118            distinct_values_count: 0,
119            distinct_values: DistinctValues {
120                values: HashSet::<NativeValue<T>, FxBuildHasher>::with_hasher(FxBuildHasher),
121            }
122            .into(),
123        };
124    } else if array.all_invalid() {
125        return FloatStats {
126            src: array.clone(),
127            null_count: array.len().try_into().vortex_expect("null_count"),
128            value_count: 0,
129            average_run_length: 0,
130            distinct_values_count: 0,
131            distinct_values: DistinctValues {
132                values: HashSet::<NativeValue<T>, FxBuildHasher>::with_hasher(FxBuildHasher),
133            }
134            .into(),
135        };
136    }
137
138    let null_count = array
139        .statistics()
140        .compute_null_count()
141        .vortex_expect("null count");
142    let value_count = array.len() - null_count;
143
144    // Keep a HashMap of T, then convert the keys into PValue afterward since value is
145    // so much more efficient to hash and search for.
146    let mut distinct_values = if count_distinct_values {
147        HashSet::with_capacity_and_hasher(array.len() / 2, FxBuildHasher)
148    } else {
149        HashSet::with_hasher(FxBuildHasher)
150    };
151
152    let validity = array.validity_mask();
153
154    let mut runs = 1;
155    let head_idx = validity
156        .first()
157        .vortex_expect("All null masks have been handled before");
158    let buff = array.buffer::<T>();
159    let mut prev = buff[head_idx];
160
161    let first_valid_buff = buff.slice(head_idx..array.len());
162    match validity.bit_buffer() {
163        AllOr::All => {
164            for value in first_valid_buff {
165                if count_distinct_values {
166                    distinct_values.insert(NativeValue(value));
167                }
168
169                if value != prev {
170                    prev = value;
171                    runs += 1;
172                }
173            }
174        }
175        AllOr::None => unreachable!("All invalid arrays have been handled earlier"),
176        AllOr::Some(v) => {
177            for (&value, valid) in first_valid_buff
178                .iter()
179                .zip_eq(v.slice(head_idx..array.len()).iter())
180            {
181                if valid {
182                    if count_distinct_values {
183                        distinct_values.insert(NativeValue(value));
184                    }
185
186                    if value != prev {
187                        prev = value;
188                        runs += 1;
189                    }
190                }
191            }
192        }
193    }
194
195    let null_count = null_count
196        .try_into()
197        .vortex_expect("null_count must fit in u32");
198    let value_count = value_count
199        .try_into()
200        .vortex_expect("null_count must fit in u32");
201    let distinct_values_count = if count_distinct_values {
202        distinct_values
203            .len()
204            .try_into()
205            .vortex_expect("distinct values count must fit in u32")
206    } else {
207        u32::MAX
208    };
209
210    FloatStats {
211        null_count,
212        value_count,
213        distinct_values_count,
214        src: array.clone(),
215        average_run_length: value_count / runs,
216        distinct_values: DistinctValues {
217            values: distinct_values,
218        }
219        .into(),
220    }
221}
222
#[cfg(test)]
mod tests {
    use vortex_array::IntoArray;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::validity::Validity;
    use vortex_buffer::buffer;

    use crate::CompressorStats;
    use crate::float::stats::FloatStats;

    /// Three distinct non-null values: three runs of length one.
    #[test]
    fn test_float_stats() {
        let array = buffer![0.0f32, 1.0f32, 2.0f32]
            .into_array()
            .to_primitive();

        let FloatStats {
            value_count,
            null_count,
            average_run_length,
            distinct_values_count,
            ..
        } = FloatStats::generate(&array);

        assert_eq!(value_count, 3);
        assert_eq!(null_count, 0);
        assert_eq!(average_run_length, 1);
        assert_eq!(distinct_values_count, 3);
    }

    /// A leading null is counted as a null and excluded from the value, run and distinct
    /// stats.
    #[test]
    fn test_float_stats_leading_nulls() {
        let values = buffer![0.0f32, 1.0f32, 2.0f32];
        let validity = Validity::from_iter([false, true, true]);
        let array = PrimitiveArray::new(values, validity);

        let FloatStats {
            value_count,
            null_count,
            average_run_length,
            distinct_values_count,
            ..
        } = FloatStats::generate(&array);

        assert_eq!(value_count, 2);
        assert_eq!(null_count, 1);
        assert_eq!(average_run_length, 1);
        assert_eq!(distinct_values_count, 2);
    }
}
261}