1use std::hash::Hash;
5
6use itertools::Itertools;
7use num_traits::Float;
8use rustc_hash::FxBuildHasher;
9use vortex_array::ToCanonical;
10use vortex_array::arrays::{NativeValue, PrimitiveArray, PrimitiveVTable};
11use vortex_dtype::half::f16;
12use vortex_dtype::{NativePType, PType};
13use vortex_error::{VortexExpect, VortexUnwrap, vortex_panic};
14use vortex_mask::AllOr;
15use vortex_utils::aliases::hash_set::HashSet;
16
17use crate::rle::RLEStats;
18use crate::sample::sample;
19use crate::{CompressorStats, GenerateStatsOptions};
20
/// A set of distinct native values of type `T`.
///
/// Values are wrapped in [`NativeValue`] so they can be used as hash-set keys
/// (`typed_float_stats` requires `NativeValue<T>: Hash + Eq`). In this module the set is
/// populated only with the non-null values of an array.
#[derive(Debug, Clone)]
pub struct DistinctValues<T> {
    /// The distinct values, hashed with `FxBuildHasher`.
    pub values: HashSet<NativeValue<T>, FxBuildHasher>,
}
25
/// Type-erased [`DistinctValues`], with one variant per supported float width.
#[derive(Debug, Clone)]
pub enum ErasedDistinctValues {
    /// Distinct values of an `f16` array.
    F16(DistinctValues<f16>),
    /// Distinct values of an `f32` array.
    F32(DistinctValues<f32>),
    /// Distinct values of an `f64` array.
    F64(DistinctValues<f64>),
}
32
/// Generates `From<DistinctValues<$native>> for ErasedDistinctValues`, wrapping the typed
/// set in the enum variant named by `$variant` (a full path such as
/// `ErasedDistinctValues::F32`).
macro_rules! impl_from_typed {
    ($native:ty, $variant:path) => {
        impl From<DistinctValues<$native>> for ErasedDistinctValues {
            fn from(distinct: DistinctValues<$native>) -> Self {
                $variant(distinct)
            }
        }
    };
}
42
// One `From` conversion per float width that `ErasedDistinctValues` can hold.
impl_from_typed!(f16, ErasedDistinctValues::F16);
impl_from_typed!(f32, ErasedDistinctValues::F32);
impl_from_typed!(f64, ErasedDistinctValues::F64);
46
/// Statistics over a float (`f16`/`f32`/`f64`) [`PrimitiveArray`], as computed by
/// `typed_float_stats`.
#[derive(Debug, Clone)]
pub struct FloatStats {
    /// The array the statistics were computed over.
    pub(super) src: PrimitiveArray,
    /// Number of null slots in `src`.
    pub(super) null_count: u32,
    /// Number of non-null values in `src` (its length minus `null_count`).
    pub(super) value_count: u32,
    /// `value_count` divided (integer division) by the number of runs of equal consecutive
    /// non-null values; 0 when there are no non-null values.
    // NOTE(review): this field is read by the `RLEStats` impl and the tests in this file,
    // so this `dead_code` allow may be stale — confirm before removing.
    #[allow(dead_code)]
    pub(super) average_run_length: u32,
    /// The distinct non-null values; left empty when distinct counting was not requested.
    pub(super) distinct_values: ErasedDistinctValues,
    /// Number of distinct non-null values, or `u32::MAX` when distinct counting was not
    /// requested.
    pub(super) distinct_values_count: u32,
}
60
61impl CompressorStats for FloatStats {
62 type ArrayVTable = PrimitiveVTable;
63
64 fn generate_opts(input: &PrimitiveArray, opts: GenerateStatsOptions) -> Self {
65 match input.ptype() {
66 PType::F16 => typed_float_stats::<f16>(input, opts.count_distinct_values),
67 PType::F32 => typed_float_stats::<f32>(input, opts.count_distinct_values),
68 PType::F64 => typed_float_stats::<f64>(input, opts.count_distinct_values),
69 _ => vortex_panic!("cannot generate FloatStats from ptype {}", input.ptype()),
70 }
71 }
72
73 fn source(&self) -> &PrimitiveArray {
74 &self.src
75 }
76
77 fn sample_opts(&self, sample_size: u32, sample_count: u32, opts: GenerateStatsOptions) -> Self {
78 let sampled = sample(self.src.as_ref(), sample_size, sample_count).to_primitive();
79
80 Self::generate_opts(&sampled, opts)
81 }
82}
83
84impl RLEStats for FloatStats {
85 fn value_count(&self) -> u32 {
86 self.value_count
87 }
88
89 fn average_run_length(&self) -> u32 {
90 self.average_run_length
91 }
92
93 fn source(&self) -> &PrimitiveArray {
94 &self.src
95 }
96}
97
98fn typed_float_stats<T: NativePType + Float>(
99 array: &PrimitiveArray,
100 count_distinct_values: bool,
101) -> FloatStats
102where
103 DistinctValues<T>: Into<ErasedDistinctValues>,
104 NativeValue<T>: Hash + Eq,
105{
106 if array.is_empty() {
108 return FloatStats {
109 src: array.clone(),
110 null_count: 0,
111 value_count: 0,
112 average_run_length: 0,
113 distinct_values_count: 0,
114 distinct_values: DistinctValues {
115 values: HashSet::<NativeValue<T>, FxBuildHasher>::with_hasher(FxBuildHasher),
116 }
117 .into(),
118 };
119 } else if array.all_invalid() {
120 return FloatStats {
121 src: array.clone(),
122 null_count: array.len().try_into().vortex_expect("null_count"),
123 value_count: 0,
124 average_run_length: 0,
125 distinct_values_count: 0,
126 distinct_values: DistinctValues {
127 values: HashSet::<NativeValue<T>, FxBuildHasher>::with_hasher(FxBuildHasher),
128 }
129 .into(),
130 };
131 }
132
133 let null_count = array
134 .statistics()
135 .compute_null_count()
136 .vortex_expect("null count");
137 let value_count = array.len() - null_count;
138
139 let mut distinct_values = if count_distinct_values {
142 HashSet::with_capacity_and_hasher(array.len() / 2, FxBuildHasher)
143 } else {
144 HashSet::with_hasher(FxBuildHasher)
145 };
146
147 let validity = array.validity_mask();
148
149 let mut runs = 1;
150 let head_idx = validity
151 .first()
152 .vortex_expect("All null masks have been handled before");
153 let buff = array.buffer::<T>();
154 let mut prev = buff[head_idx];
155
156 let first_valid_buff = buff.slice(head_idx..array.len());
157 match validity.bit_buffer() {
158 AllOr::All => {
159 for value in first_valid_buff {
160 if count_distinct_values {
161 distinct_values.insert(NativeValue(value));
162 }
163
164 if value != prev {
165 prev = value;
166 runs += 1;
167 }
168 }
169 }
170 AllOr::None => unreachable!("All invalid arrays have been handled earlier"),
171 AllOr::Some(v) => {
172 for (&value, valid) in first_valid_buff
173 .iter()
174 .zip_eq(v.slice(head_idx..array.len()).iter())
175 {
176 if valid {
177 if count_distinct_values {
178 distinct_values.insert(NativeValue(value));
179 }
180
181 if value != prev {
182 prev = value;
183 runs += 1;
184 }
185 }
186 }
187 }
188 }
189
190 let null_count = null_count
191 .try_into()
192 .vortex_expect("null_count must fit in u32");
193 let value_count = value_count
194 .try_into()
195 .vortex_expect("null_count must fit in u32");
196 let distinct_values_count = if count_distinct_values {
197 distinct_values.len().try_into().vortex_unwrap()
198 } else {
199 u32::MAX
200 };
201
202 FloatStats {
203 null_count,
204 value_count,
205 distinct_values_count,
206 src: array.clone(),
207 average_run_length: value_count / runs,
208 distinct_values: DistinctValues {
209 values: distinct_values,
210 }
211 .into(),
212 }
213}
214
#[cfg(test)]
mod tests {
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::validity::Validity;
    use vortex_array::{IntoArray, ToCanonical};
    use vortex_buffer::buffer;

    use crate::CompressorStats;
    use crate::float::stats::FloatStats;

    #[test]
    fn test_float_stats() {
        let floats = buffer![0.0f32, 1.0f32, 2.0f32].into_array();
        let floats = floats.to_primitive();

        let stats = FloatStats::generate(&floats);

        assert_eq!(stats.value_count, 3);
        assert_eq!(stats.null_count, 0);
        assert_eq!(stats.average_run_length, 1);
        assert_eq!(stats.distinct_values_count, 3);
    }

    #[test]
    fn test_float_stats_leading_nulls() {
        let floats = PrimitiveArray::new(
            buffer![0.0f32, 1.0f32, 2.0f32],
            Validity::from_iter([false, true, true]),
        );

        let stats = FloatStats::generate(&floats);

        assert_eq!(stats.value_count, 2);
        assert_eq!(stats.null_count, 1);
        assert_eq!(stats.average_run_length, 1);
        assert_eq!(stats.distinct_values_count, 2);
    }

    /// A null slot between equal values does not split their run, and its (garbage)
    /// payload is not counted as a distinct value.
    #[test]
    fn test_float_stats_runs_span_nulls() {
        let floats = PrimitiveArray::new(
            buffer![1.0f32, 5.0f32, 1.0f32, 1.0f32, 2.0f32, 2.0f32],
            Validity::from_iter([true, false, true, true, true, true]),
        );

        let stats = FloatStats::generate(&floats);

        assert_eq!(stats.value_count, 5);
        assert_eq!(stats.null_count, 1);
        // Two runs ([1.0, 1.0, 1.0] and [2.0, 2.0]) over five values.
        assert_eq!(stats.average_run_length, 2);
        assert_eq!(stats.distinct_values_count, 2);
    }

    /// All-null arrays take the early-return path and report no values.
    #[test]
    fn test_float_stats_all_null() {
        let floats = PrimitiveArray::new(
            buffer![0.0f64, 1.0f64, 2.0f64],
            Validity::from_iter([false, false, false]),
        );

        let stats = FloatStats::generate(&floats);

        assert_eq!(stats.value_count, 0);
        assert_eq!(stats.null_count, 3);
        assert_eq!(stats.average_run_length, 0);
        assert_eq!(stats.distinct_values_count, 0);
    }
}