1use std::hash::Hash;
5
6use num_traits::PrimInt;
7use rustc_hash::FxBuildHasher;
8use vortex_array::ToCanonical;
9use vortex_array::arrays::{NativeValue, PrimitiveArray, PrimitiveVTable};
10use vortex_array::stats::Stat;
11use vortex_buffer::BitBuffer;
12use vortex_dtype::{IntegerPType, match_each_integer_ptype};
13use vortex_error::{VortexError, VortexExpect, VortexUnwrap};
14use vortex_mask::AllOr;
15use vortex_scalar::{PValue, Scalar};
16use vortex_utils::aliases::hash_map::HashMap;
17
18use crate::rle::RLEStats;
19use crate::sample::sample;
20use crate::{CompressorStats, GenerateStatsOptions};
21
22#[derive(Clone, Debug)]
23pub struct TypedStats<T> {
24 pub min: T,
25 pub max: T,
26 pub top_value: T,
27 pub top_count: u32,
28 pub distinct_values: HashMap<NativeValue<T>, u32, FxBuildHasher>,
29}
30
31#[derive(Clone, Debug)]
36pub enum ErasedStats {
37 U8(TypedStats<u8>),
38 U16(TypedStats<u16>),
39 U32(TypedStats<u32>),
40 U64(TypedStats<u64>),
41 I8(TypedStats<i8>),
42 I16(TypedStats<i16>),
43 I32(TypedStats<i32>),
44 I64(TypedStats<i64>),
45}
46
47impl ErasedStats {
48 pub fn min_is_zero(&self) -> bool {
49 match &self {
50 ErasedStats::U8(x) => x.min == 0,
51 ErasedStats::U16(x) => x.min == 0,
52 ErasedStats::U32(x) => x.min == 0,
53 ErasedStats::U64(x) => x.min == 0,
54 ErasedStats::I8(x) => x.min == 0,
55 ErasedStats::I16(x) => x.min == 0,
56 ErasedStats::I32(x) => x.min == 0,
57 ErasedStats::I64(x) => x.min == 0,
58 }
59 }
60
61 pub fn min_is_negative(&self) -> bool {
62 match &self {
63 ErasedStats::U8(_)
64 | ErasedStats::U16(_)
65 | ErasedStats::U32(_)
66 | ErasedStats::U64(_) => false,
67 ErasedStats::I8(x) => x.min < 0,
68 ErasedStats::I16(x) => x.min < 0,
69 ErasedStats::I32(x) => x.min < 0,
70 ErasedStats::I64(x) => x.min < 0,
71 }
72 }
73
74 pub fn max_minus_min(&self) -> u64 {
76 match &self {
77 ErasedStats::U8(x) => (x.max - x.min) as u64,
78 ErasedStats::U16(x) => (x.max - x.min) as u64,
79 ErasedStats::U32(x) => (x.max - x.min) as u64,
80 ErasedStats::U64(x) => x.max - x.min,
81 ErasedStats::I8(x) => (x.max as i16 - x.min as i16) as u64,
82 ErasedStats::I16(x) => (x.max as i32 - x.min as i32) as u64,
83 ErasedStats::I32(x) => (x.max as i64 - x.min as i64) as u64,
84 ErasedStats::I64(x) => u64::try_from(x.max as i128 - x.min as i128)
85 .vortex_expect("max minus min result bigger than u64"),
86 }
87 }
88
89 pub fn top_value_and_count(&self) -> (PValue, u32) {
91 match &self {
92 ErasedStats::U8(x) => (x.top_value.into(), x.top_count),
93 ErasedStats::U16(x) => (x.top_value.into(), x.top_count),
94 ErasedStats::U32(x) => (x.top_value.into(), x.top_count),
95 ErasedStats::U64(x) => (x.top_value.into(), x.top_count),
96 ErasedStats::I8(x) => (x.top_value.into(), x.top_count),
97 ErasedStats::I16(x) => (x.top_value.into(), x.top_count),
98 ErasedStats::I32(x) => (x.top_value.into(), x.top_count),
99 ErasedStats::I64(x) => (x.top_value.into(), x.top_count),
100 }
101 }
102}
103
/// Generates `From<TypedStats<$native>> for ErasedStats` by wrapping the typed statistics in the
/// given enum variant constructor.
macro_rules! impl_from_typed {
    ($native:ty, $wrap:path) => {
        impl From<TypedStats<$native>> for ErasedStats {
            fn from(stats: TypedStats<$native>) -> ErasedStats {
                $wrap(stats)
            }
        }
    };
}
113
114impl_from_typed!(u8, ErasedStats::U8);
115impl_from_typed!(u16, ErasedStats::U16);
116impl_from_typed!(u32, ErasedStats::U32);
117impl_from_typed!(u64, ErasedStats::U64);
118impl_from_typed!(i8, ErasedStats::I8);
119impl_from_typed!(i16, ErasedStats::I16);
120impl_from_typed!(i32, ErasedStats::I32);
121impl_from_typed!(i64, ErasedStats::I64);
122
123#[derive(Clone, Debug)]
125pub struct IntegerStats {
126 pub(super) src: PrimitiveArray,
127 pub(super) null_count: u32,
129 pub(super) value_count: u32,
131 pub(super) average_run_length: u32,
132 pub(super) distinct_values_count: u32,
133 pub(crate) typed: ErasedStats,
134}
135
136impl CompressorStats for IntegerStats {
137 type ArrayVTable = PrimitiveVTable;
138
139 fn generate_opts(input: &PrimitiveArray, opts: GenerateStatsOptions) -> Self {
140 match_each_integer_ptype!(input.ptype(), |T| {
141 typed_int_stats::<T>(input, opts.count_distinct_values)
142 })
143 }
144
145 fn source(&self) -> &PrimitiveArray {
146 &self.src
147 }
148
149 fn sample_opts(&self, sample_size: u32, sample_count: u32, opts: GenerateStatsOptions) -> Self {
150 let sampled = sample(self.src.as_ref(), sample_size, sample_count).to_primitive();
151
152 Self::generate_opts(&sampled, opts)
153 }
154}
155
156impl RLEStats for IntegerStats {
157 fn value_count(&self) -> u32 {
158 self.value_count
159 }
160
161 fn average_run_length(&self) -> u32 {
162 self.average_run_length
163 }
164
165 fn source(&self) -> &PrimitiveArray {
166 &self.src
167 }
168}
169
170fn typed_int_stats<T>(array: &PrimitiveArray, count_distinct_values: bool) -> IntegerStats
171where
172 T: IntegerPType + PrimInt + for<'a> TryFrom<&'a Scalar, Error = VortexError>,
173 TypedStats<T>: Into<ErasedStats>,
174 NativeValue<T>: Eq + Hash,
175{
176 if array.is_empty() {
178 return IntegerStats {
179 src: array.clone(),
180 null_count: 0,
181 value_count: 0,
182 average_run_length: 0,
183 distinct_values_count: 0,
184 typed: TypedStats {
185 min: T::max_value(),
186 max: T::min_value(),
187 top_value: T::default(),
188 top_count: 0,
189 distinct_values: HashMap::with_hasher(FxBuildHasher),
190 }
191 .into(),
192 };
193 } else if array.all_invalid() {
194 return IntegerStats {
195 src: array.clone(),
196 null_count: array.len().try_into().vortex_expect("null_count"),
197 value_count: 0,
198 average_run_length: 0,
199 distinct_values_count: 0,
200 typed: TypedStats {
201 min: T::max_value(),
202 max: T::min_value(),
203 top_value: T::default(),
204 top_count: 0,
205 distinct_values: HashMap::with_hasher(FxBuildHasher),
206 }
207 .into(),
208 };
209 }
210
211 let validity = array.validity_mask();
212 let null_count = validity.false_count();
213 let value_count = validity.true_count();
214
215 let head_idx = validity
217 .first()
218 .vortex_expect("All null masks have been handled before");
219 let buffer = array.buffer::<T>();
220 let head = buffer[head_idx];
221
222 let mut loop_state = LoopState {
223 distinct_values: if count_distinct_values {
224 HashMap::with_capacity_and_hasher(array.len() / 2, FxBuildHasher)
225 } else {
226 HashMap::with_hasher(FxBuildHasher)
227 },
228 prev: head,
229 runs: 1,
230 };
231
232 let sliced = buffer.slice(head_idx..array.len());
233 let mut chunks = sliced.as_slice().chunks_exact(64);
234 match validity.bit_buffer() {
235 AllOr::All => {
236 for chunk in &mut chunks {
237 inner_loop_nonnull(
238 chunk.try_into().vortex_unwrap(),
239 count_distinct_values,
240 &mut loop_state,
241 )
242 }
243 let remainder = chunks.remainder();
244 inner_loop_naive(
245 remainder,
246 count_distinct_values,
247 &BitBuffer::new_set(remainder.len()),
248 &mut loop_state,
249 );
250 }
251 AllOr::None => unreachable!("All invalid arrays have been handled before"),
252 AllOr::Some(v) => {
253 let mask = v.slice(head_idx..array.len());
254 let mut offset = 0;
255 for chunk in &mut chunks {
256 let validity = mask.slice(offset..(offset + 64));
257 offset += 64;
258
259 match validity.true_count() {
260 0 => continue,
262 64 => inner_loop_nonnull(
264 chunk.try_into().vortex_unwrap(),
265 count_distinct_values,
266 &mut loop_state,
267 ),
268 _ => inner_loop_nullable(
270 chunk.try_into().vortex_unwrap(),
271 count_distinct_values,
272 &validity,
273 &mut loop_state,
274 ),
275 }
276 }
277 let remainder = chunks.remainder();
279 inner_loop_naive(
280 remainder,
281 count_distinct_values,
282 &mask.slice(offset..(offset + remainder.len())),
283 &mut loop_state,
284 );
285 }
286 }
287
288 let (top_value, top_count) = if count_distinct_values {
289 let (&top_value, &top_count) = loop_state
290 .distinct_values
291 .iter()
292 .max_by_key(|&(_, &count)| count)
293 .vortex_expect("non-empty");
294 (top_value.0, top_count)
295 } else {
296 (T::default(), 0)
297 };
298
299 let runs = loop_state.runs;
300 let distinct_values_count = if count_distinct_values {
301 loop_state.distinct_values.len().try_into().vortex_unwrap()
302 } else {
303 u32::MAX
304 };
305
306 let min = array
307 .statistics()
308 .compute_as::<T>(Stat::Min)
309 .vortex_expect("min should be computed");
310
311 let max = array
312 .statistics()
313 .compute_as::<T>(Stat::Max)
314 .vortex_expect("max should be computed");
315
316 let typed = TypedStats {
317 min,
318 max,
319 distinct_values: loop_state.distinct_values,
320 top_value,
321 top_count,
322 };
323
324 let null_count = null_count
325 .try_into()
326 .vortex_expect("null_count must fit in u32");
327 let value_count = value_count
328 .try_into()
329 .vortex_expect("value_count must fit in u32");
330
331 IntegerStats {
332 src: array.clone(),
333 null_count,
334 value_count,
335 average_run_length: value_count / runs,
336 distinct_values_count,
337 typed: typed.into(),
338 }
339}
340
341struct LoopState<T> {
342 prev: T,
343 runs: u32,
344 distinct_values: HashMap<NativeValue<T>, u32, FxBuildHasher>,
345}
346
347#[inline(always)]
348fn inner_loop_nonnull<T: IntegerPType>(
349 values: &[T; 64],
350 count_distinct_values: bool,
351 state: &mut LoopState<T>,
352) where
353 NativeValue<T>: Eq + Hash,
354{
355 for &value in values {
356 if count_distinct_values {
357 *state.distinct_values.entry(NativeValue(value)).or_insert(0) += 1;
358 }
359
360 if value != state.prev {
361 state.prev = value;
362 state.runs += 1;
363 }
364 }
365}
366
367#[inline(always)]
368fn inner_loop_nullable<T: IntegerPType>(
369 values: &[T; 64],
370 count_distinct_values: bool,
371 is_valid: &BitBuffer,
372 state: &mut LoopState<T>,
373) where
374 NativeValue<T>: Eq + Hash,
375{
376 for (idx, &value) in values.iter().enumerate() {
377 if is_valid.value(idx) {
378 if count_distinct_values {
379 *state.distinct_values.entry(NativeValue(value)).or_insert(0) += 1;
380 }
381
382 if value != state.prev {
383 state.prev = value;
384 state.runs += 1;
385 }
386 }
387 }
388}
389
390#[inline(always)]
391fn inner_loop_naive<T: IntegerPType>(
392 values: &[T],
393 count_distinct_values: bool,
394 is_valid: &BitBuffer,
395 state: &mut LoopState<T>,
396) where
397 NativeValue<T>: Eq + Hash,
398{
399 for (idx, &value) in values.iter().enumerate() {
400 if is_valid.value(idx) {
401 if count_distinct_values {
402 *state.distinct_values.entry(NativeValue(value)).or_insert(0) += 1;
403 }
404
405 if value != state.prev {
406 state.prev = value;
407 state.runs += 1;
408 }
409 }
410 }
411}
412
#[cfg(test)]
mod tests {
    use std::iter;

    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::validity::Validity;
    use vortex_buffer::{BitBuffer, Buffer, buffer};

    use crate::CompressorStats;
    use crate::integer::IntegerStats;
    use crate::integer::stats::typed_int_stats;

    /// Input shorter than one 64-value block, no nulls: only the remainder loop runs.
    #[test]
    fn test_naive_count_distinct_values() {
        let values = buffer![217u8, 0];
        let array = PrimitiveArray::new(values, Validity::NonNullable);

        let stats = typed_int_stats::<u8>(&array, true);

        assert_eq!(stats.distinct_values_count, 2);
    }

    /// Input shorter than one block with a null entry: the null must not be counted.
    #[test]
    fn test_naive_count_distinct_values_nullable() {
        let values = buffer![217u8, 0];
        let validity = Validity::from(BitBuffer::from(vec![true, false]));
        let array = PrimitiveArray::new(values, validity);

        let stats = typed_int_stats::<u8>(&array, true);

        assert_eq!(stats.distinct_values_count, 1);
    }

    /// Two full 64-value blocks without nulls.
    #[test]
    fn test_count_distinct_values() {
        let values: Buffer<u8> = (0..128u8).collect();
        let array = PrimitiveArray::new(values, Validity::NonNullable);

        let stats = typed_int_stats::<u8>(&array, true);

        assert_eq!(stats.distinct_values_count, 128);
    }

    /// Two full blocks in which every second entry is null.
    #[test]
    fn test_count_distinct_values_nullable() {
        let values: Buffer<u8> = (0..128u8).collect();
        let alternating = iter::repeat_n(vec![true, false], 64).flatten();
        let validity = Validity::from(BitBuffer::from_iter(alternating));
        let array = PrimitiveArray::new(values, validity);

        let stats = typed_int_stats::<u8>(&array, true);

        assert_eq!(stats.distinct_values_count, 64);
    }

    /// A leading null must be skipped when seeding the run tracker.
    #[test]
    fn test_integer_stats_leading_nulls() {
        let ints = PrimitiveArray::new(buffer![0, 1, 2], Validity::from_iter([false, true, true]));

        let stats = IntegerStats::generate(&ints);

        assert_eq!(stats.value_count, 2);
        assert_eq!(stats.null_count, 1);
        assert_eq!(stats.average_run_length, 1);
        assert_eq!(stats.distinct_values_count, 2);
    }
}