1use std::hash::Hash;
5
6use itertools::Itertools;
7use num_traits::Float;
8use rustc_hash::FxBuildHasher;
9use vortex_array::ToCanonical;
10use vortex_array::arrays::{NativeValue, PrimitiveArray, PrimitiveVTable};
11use vortex_dtype::half::f16;
12use vortex_dtype::{NativePType, PType};
13use vortex_error::{VortexExpect, VortexUnwrap, vortex_panic};
14use vortex_mask::AllOr;
15use vortex_utils::aliases::hash_set::HashSet;
16
17use crate::sample::sample;
18use crate::{CompressorStats, GenerateStatsOptions};
19
20#[derive(Debug, Clone)]
21pub struct DistinctValues<T> {
22 pub values: HashSet<NativeValue<T>, FxBuildHasher>,
23}
24
25#[derive(Debug, Clone)]
26pub enum ErasedDistinctValues {
27 F16(DistinctValues<f16>),
28 F32(DistinctValues<f32>),
29 F64(DistinctValues<f64>),
30}
31
32macro_rules! impl_from_typed {
33 ($typ:ty, $variant:path) => {
34 impl From<DistinctValues<$typ>> for ErasedDistinctValues {
35 fn from(value: DistinctValues<$typ>) -> Self {
36 $variant(value)
37 }
38 }
39 };
40}
41
42impl_from_typed!(f16, ErasedDistinctValues::F16);
43impl_from_typed!(f32, ErasedDistinctValues::F32);
44impl_from_typed!(f64, ErasedDistinctValues::F64);
45
46#[derive(Debug, Clone)]
48pub struct FloatStats {
49 pub(super) src: PrimitiveArray,
50 pub(super) null_count: u32,
52 pub(super) value_count: u32,
54 #[allow(dead_code)]
55 pub(super) average_run_length: u32,
56 pub(super) distinct_values: ErasedDistinctValues,
57 pub(super) distinct_values_count: u32,
58}
59
60impl CompressorStats for FloatStats {
61 type ArrayVTable = PrimitiveVTable;
62
63 fn generate_opts(input: &PrimitiveArray, opts: GenerateStatsOptions) -> Self {
64 match input.ptype() {
65 PType::F16 => typed_float_stats::<f16>(input, opts.count_distinct_values),
66 PType::F32 => typed_float_stats::<f32>(input, opts.count_distinct_values),
67 PType::F64 => typed_float_stats::<f64>(input, opts.count_distinct_values),
68 _ => vortex_panic!("cannot generate FloatStats from ptype {}", input.ptype()),
69 }
70 }
71
72 fn source(&self) -> &PrimitiveArray {
73 &self.src
74 }
75
76 fn sample_opts(&self, sample_size: u32, sample_count: u32, opts: GenerateStatsOptions) -> Self {
77 let sampled = sample(self.src.as_ref(), sample_size, sample_count).to_primitive();
78
79 Self::generate_opts(&sampled, opts)
80 }
81}
82
83fn typed_float_stats<T: NativePType + Float>(
84 array: &PrimitiveArray,
85 count_distinct_values: bool,
86) -> FloatStats
87where
88 DistinctValues<T>: Into<ErasedDistinctValues>,
89 NativeValue<T>: Hash + Eq,
90{
91 if array.is_empty() {
93 return FloatStats {
94 src: array.clone(),
95 null_count: 0,
96 value_count: 0,
97 average_run_length: 0,
98 distinct_values_count: 0,
99 distinct_values: DistinctValues {
100 values: HashSet::<NativeValue<T>, FxBuildHasher>::with_hasher(FxBuildHasher),
101 }
102 .into(),
103 };
104 } else if array.all_invalid() {
105 return FloatStats {
106 src: array.clone(),
107 null_count: array.len().try_into().vortex_expect("null_count"),
108 value_count: 0,
109 average_run_length: 0,
110 distinct_values_count: 0,
111 distinct_values: DistinctValues {
112 values: HashSet::<NativeValue<T>, FxBuildHasher>::with_hasher(FxBuildHasher),
113 }
114 .into(),
115 };
116 }
117
118 let null_count = array
119 .statistics()
120 .compute_null_count()
121 .vortex_expect("null count");
122 let value_count = array.len() - null_count;
123
124 let mut distinct_values = if count_distinct_values {
127 HashSet::with_capacity_and_hasher(array.len() / 2, FxBuildHasher)
128 } else {
129 HashSet::with_hasher(FxBuildHasher)
130 };
131
132 let validity = array.validity_mask();
133
134 let mut runs = 1;
135 let head_idx = validity
136 .first()
137 .vortex_expect("All null masks have been handled before");
138 let buff = array.buffer::<T>();
139 let mut prev = buff[head_idx];
140
141 let first_valid_buff = buff.slice(head_idx..array.len());
142 match validity.boolean_buffer() {
143 AllOr::All => {
144 for value in first_valid_buff {
145 if count_distinct_values {
146 distinct_values.insert(NativeValue(value));
147 }
148
149 if value != prev {
150 prev = value;
151 runs += 1;
152 }
153 }
154 }
155 AllOr::None => unreachable!("All invalid arrays have been handled earlier"),
156 AllOr::Some(v) => {
157 for (&value, valid) in first_valid_buff
158 .iter()
159 .zip_eq(v.slice(head_idx, array.len() - head_idx).iter())
160 {
161 if valid {
162 if count_distinct_values {
163 distinct_values.insert(NativeValue(value));
164 }
165
166 if value != prev {
167 prev = value;
168 runs += 1;
169 }
170 }
171 }
172 }
173 }
174
175 let null_count = null_count
176 .try_into()
177 .vortex_expect("null_count must fit in u32");
178 let value_count = value_count
179 .try_into()
180 .vortex_expect("null_count must fit in u32");
181 let distinct_values_count = if count_distinct_values {
182 distinct_values.len().try_into().vortex_unwrap()
183 } else {
184 u32::MAX
185 };
186
187 FloatStats {
188 null_count,
189 value_count,
190 distinct_values_count,
191 src: array.clone(),
192 average_run_length: value_count / runs,
193 distinct_values: DistinctValues {
194 values: distinct_values,
195 }
196 .into(),
197 }
198}
199
#[cfg(test)]
mod tests {
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::validity::Validity;
    use vortex_array::{IntoArray, ToCanonical};
    use vortex_buffer::buffer;

    use crate::CompressorStats;
    use crate::float::stats::FloatStats;

    #[test]
    fn test_float_stats() {
        let floats = buffer![0.0f32, 1.0f32, 2.0f32].into_array();
        let floats = floats.to_primitive();

        let stats = FloatStats::generate(&floats);

        assert_eq!(stats.value_count, 3);
        assert_eq!(stats.null_count, 0);
        assert_eq!(stats.average_run_length, 1);
        assert_eq!(stats.distinct_values_count, 3);
    }

    #[test]
    fn test_float_stats_leading_nulls() {
        let floats = PrimitiveArray::new(
            buffer![0.0f32, 1.0f32, 2.0f32],
            Validity::from_iter([false, true, true]),
        );

        let stats = FloatStats::generate(&floats);

        assert_eq!(stats.value_count, 2);
        assert_eq!(stats.null_count, 1);
        assert_eq!(stats.average_run_length, 1);
        assert_eq!(stats.distinct_values_count, 2);
    }

    /// An all-null array takes the early-return path: every entry counts as null and no value
    /// statistics are produced.
    #[test]
    fn test_float_stats_all_nulls() {
        let floats = PrimitiveArray::new(buffer![0.0f32, 1.0f32, 2.0f32], Validity::AllInvalid);

        let stats = FloatStats::generate(&floats);

        assert_eq!(stats.value_count, 0);
        assert_eq!(stats.null_count, 3);
        assert_eq!(stats.average_run_length, 0);
        assert_eq!(stats.distinct_values_count, 0);
    }

    /// Nulls are skipped when counting runs, so the three valid `1.0`s separated by a null form a
    /// single run: 4 valid values / 2 runs = 2.
    #[test]
    fn test_float_stats_interior_nulls_do_not_break_runs() {
        let floats = PrimitiveArray::new(
            buffer![1.0f32, 1.0f32, 0.0f32, 1.0f32, 2.0f32],
            Validity::from_iter([true, true, false, true, true]),
        );

        let stats = FloatStats::generate(&floats);

        assert_eq!(stats.value_count, 4);
        assert_eq!(stats.null_count, 1);
        assert_eq!(stats.average_run_length, 2);
        assert_eq!(stats.distinct_values_count, 2);
    }
}