Skip to main content

vortex_compressor/stats/
float.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Float compression statistics.
5
6use std::hash::Hash;
7
8use itertools::Itertools;
9use num_traits::Float;
10use rustc_hash::FxBuildHasher;
11use vortex_array::LEGACY_SESSION;
12use vortex_array::VortexSessionExecute;
13use vortex_array::arrays::PrimitiveArray;
14use vortex_array::arrays::primitive::NativeValue;
15use vortex_array::dtype::NativePType;
16use vortex_array::dtype::PType;
17use vortex_array::dtype::half::f16;
18use vortex_error::VortexExpect;
19use vortex_error::VortexResult;
20use vortex_error::vortex_err;
21use vortex_error::vortex_panic;
22use vortex_mask::AllOr;
23use vortex_utils::aliases::hash_set::HashSet;
24
25use super::GenerateStatsOptions;
26
27/// Information about the distinct values in a float array.
28#[derive(Debug, Clone)]
29pub struct DistinctInfo<T> {
30    /// The set of distinct float values.
31    distinct_values: HashSet<NativeValue<T>, FxBuildHasher>,
32    /// The count of unique values. This _must_ be non-zero.
33    distinct_count: u32,
34}
35
36impl<T> DistinctInfo<T> {
37    /// Returns a reference to the distinct values set.
38    pub fn distinct_values(&self) -> &HashSet<NativeValue<T>, FxBuildHasher> {
39        &self.distinct_values
40    }
41}
42
43/// Typed statistics for a specific float type.
44#[derive(Debug, Clone)]
45pub struct TypedStats<T> {
46    /// Distinct value information, or `None` if not computed.
47    distinct: Option<DistinctInfo<T>>,
48}
49
50impl<T> TypedStats<T> {
51    /// Returns the distinct value information, if computed.
52    pub fn distinct(&self) -> Option<&DistinctInfo<T>> {
53        self.distinct.as_ref()
54    }
55}
56
57/// Type-erased container for one of the [`TypedStats`] variants.
58#[derive(Debug, Clone)]
59pub enum ErasedStats {
60    /// Stats for `f16` arrays.
61    F16(TypedStats<f16>),
62    /// Stats for `f32` arrays.
63    F32(TypedStats<f32>),
64    /// Stats for `f64` arrays.
65    F64(TypedStats<f64>),
66}
67
68impl ErasedStats {
69    /// Get the count of distinct values, if we have computed it already.
70    fn distinct_count(&self) -> Option<u32> {
71        match self {
72            ErasedStats::F16(x) => x.distinct.as_ref().map(|d| d.distinct_count),
73            ErasedStats::F32(x) => x.distinct.as_ref().map(|d| d.distinct_count),
74            ErasedStats::F64(x) => x.distinct.as_ref().map(|d| d.distinct_count),
75        }
76    }
77}
78
79/// Implements `From<TypedStats<$T>>` for [`ErasedStats`].
80macro_rules! impl_from_typed {
81    ($T:ty, $variant:path) => {
82        impl From<TypedStats<$T>> for ErasedStats {
83            fn from(typed: TypedStats<$T>) -> Self {
84                $variant(typed)
85            }
86        }
87    };
88}
89
90impl_from_typed!(f16, ErasedStats::F16);
91impl_from_typed!(f32, ErasedStats::F32);
92impl_from_typed!(f64, ErasedStats::F64);
93
94/// Array of floating-point numbers and relevant stats for compression.
95#[derive(Debug, Clone)]
96pub struct FloatStats {
97    /// Cache for `validity.false_count()`.
98    null_count: u32,
99    /// Cache for `validity.true_count()`.
100    value_count: u32,
101    /// The average run length.
102    average_run_length: u32,
103    /// Type-erased typed statistics.
104    erased: ErasedStats,
105}
106
107impl FloatStats {
108    /// Generates stats, returning an error on failure.
109    fn generate_opts_fallible(
110        input: &PrimitiveArray,
111        opts: GenerateStatsOptions,
112    ) -> VortexResult<Self> {
113        match input.ptype() {
114            PType::F16 => typed_float_stats::<f16>(input, opts.count_distinct_values),
115            PType::F32 => typed_float_stats::<f32>(input, opts.count_distinct_values),
116            PType::F64 => typed_float_stats::<f64>(input, opts.count_distinct_values),
117            _ => vortex_panic!("cannot generate FloatStats from ptype {}", input.ptype()),
118        }
119    }
120
121    /// Get the count of distinct values, if we have computed it already.
122    pub fn distinct_count(&self) -> Option<u32> {
123        self.erased.distinct_count()
124    }
125}
126
127impl FloatStats {
128    /// Generates stats with default options.
129    pub fn generate(input: &PrimitiveArray) -> Self {
130        Self::generate_opts(input, GenerateStatsOptions::default())
131    }
132
133    /// Generates stats with provided options.
134    pub fn generate_opts(input: &PrimitiveArray, opts: GenerateStatsOptions) -> Self {
135        Self::generate_opts_fallible(input, opts)
136            .vortex_expect("FloatStats::generate_opts should not fail")
137    }
138
139    /// Returns the number of null values.
140    pub fn null_count(&self) -> u32 {
141        self.null_count
142    }
143
144    /// Returns the number of non-null values.
145    pub fn value_count(&self) -> u32 {
146        self.value_count
147    }
148
149    /// Returns the average run length.
150    pub fn average_run_length(&self) -> u32 {
151        self.average_run_length
152    }
153
154    /// Returns the type-erased typed statistics.
155    pub fn erased(&self) -> &ErasedStats {
156        &self.erased
157    }
158}
159
160/// Computes typed float statistics for a specific float type.
161fn typed_float_stats<T: NativePType + Float>(
162    array: &PrimitiveArray,
163    count_distinct_values: bool,
164) -> VortexResult<FloatStats>
165where
166    NativeValue<T>: Hash + Eq,
167    TypedStats<T>: Into<ErasedStats>,
168{
169    // Special case: empty array.
170    if array.is_empty() {
171        return Ok(FloatStats {
172            null_count: 0,
173            value_count: 0,
174            average_run_length: 0,
175            erased: TypedStats { distinct: None }.into(),
176        });
177    }
178
179    let mut ctx = LEGACY_SESSION.create_execution_ctx();
180    if array.all_invalid(&mut ctx)? {
181        return Ok(FloatStats {
182            null_count: u32::try_from(array.len())?,
183            value_count: 0,
184            average_run_length: 0,
185            erased: TypedStats {
186                distinct: Some(DistinctInfo {
187                    distinct_values: HashSet::with_capacity_and_hasher(0, FxBuildHasher),
188                    distinct_count: 0,
189                }),
190            }
191            .into(),
192        });
193    }
194
195    let null_count = array
196        .statistics()
197        .compute_null_count(&mut ctx)
198        .ok_or_else(|| vortex_err!("Failed to compute null_count"))?;
199    let value_count = array.len() - null_count;
200
201    // Keep a HashMap of T, then convert the keys into PValue afterward since value is
202    // so much more efficient to hash and search for.
203    let mut distinct_values = if count_distinct_values {
204        HashSet::with_capacity_and_hasher(array.len() / 2, FxBuildHasher)
205    } else {
206        HashSet::with_hasher(FxBuildHasher)
207    };
208
209    let validity = array.as_ref().validity()?.to_mask(
210        array.as_ref().len(),
211        &mut LEGACY_SESSION.create_execution_ctx(),
212    )?;
213
214    let mut runs = 1;
215    let head_idx = validity
216        .first()
217        .vortex_expect("All null masks have been handled before");
218    let buff = array.to_buffer::<T>();
219    let mut prev = buff[head_idx];
220
221    let first_valid_buff = buff.slice(head_idx..array.len());
222    match validity.bit_buffer() {
223        AllOr::All => {
224            for value in first_valid_buff {
225                if count_distinct_values {
226                    distinct_values.insert(NativeValue(value));
227                }
228
229                if value != prev {
230                    prev = value;
231                    runs += 1;
232                }
233            }
234        }
235        AllOr::None => unreachable!("All invalid arrays have been handled earlier"),
236        AllOr::Some(v) => {
237            for (&value, valid) in first_valid_buff
238                .iter()
239                .zip_eq(v.slice(head_idx..array.len()).iter())
240            {
241                if valid {
242                    if count_distinct_values {
243                        distinct_values.insert(NativeValue(value));
244                    }
245
246                    if value != prev {
247                        prev = value;
248                        runs += 1;
249                    }
250                }
251            }
252        }
253    }
254
255    let null_count = u32::try_from(null_count)?;
256    let value_count = u32::try_from(value_count)?;
257
258    let distinct = count_distinct_values.then(|| DistinctInfo {
259        distinct_count: u32::try_from(distinct_values.len())
260            .vortex_expect("more than u32::MAX distinct values"),
261        distinct_values,
262    });
263
264    Ok(FloatStats {
265        null_count,
266        value_count,
267        average_run_length: value_count / runs,
268        erased: TypedStats { distinct }.into(),
269    })
270}
271
#[cfg(test)]
mod tests {
    use vortex_array::IntoArray;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::validity::Validity;
    use vortex_buffer::buffer;

    use super::FloatStats;

    /// Options that request distinct-value counting.
    fn with_distinct_counting() -> crate::stats::GenerateStatsOptions {
        crate::stats::GenerateStatsOptions {
            count_distinct_values: true,
        }
    }

    /// Three different non-null values: three runs of length one, three distinct values.
    #[test]
    fn test_float_stats() {
        let floats = buffer![0.0f32, 1.0f32, 2.0f32]
            .into_array()
            .to_primitive();

        let stats = FloatStats::generate_opts(&floats, with_distinct_counting());

        assert_eq!(stats.value_count(), 3);
        assert_eq!(stats.null_count(), 0);
        assert_eq!(stats.average_run_length(), 1);
        assert_eq!(stats.distinct_count(), Some(3));
    }

    /// A leading null must be excluded from the value, run and distinct counts.
    #[test]
    fn test_float_stats_leading_nulls() {
        let validity = Validity::from_iter([false, true, true]);
        let floats = PrimitiveArray::new(buffer![0.0f32, 1.0f32, 2.0f32], validity);

        let stats = FloatStats::generate_opts(&floats, with_distinct_counting());

        assert_eq!(stats.value_count(), 2);
        assert_eq!(stats.null_count(), 1);
        assert_eq!(stats.average_run_length(), 1);
        assert_eq!(stats.distinct_count(), Some(2));
    }
}