use std::collections::HashMap;
use std::mem::size_of;
use arrow_buffer::buffer::BooleanBuffer;
use num_traits::PrimInt;
use vortex_dtype::half::f16;
use vortex_dtype::{match_each_native_ptype, NativePType};
use vortex_error::VortexResult;
use vortex_scalar::Scalar;
use crate::array::primitive::PrimitiveArray;
use crate::stats::{ArrayStatisticsCompute, Stat, StatsSet};
use crate::validity::{ArrayValidity, LogicalValidity};
use crate::{ArrayDType, IntoArrayVariant};
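
/// Primitive types we can compute statistics for: native primitive types that
/// convert into a `Scalar` and expose a per-value bit width.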
trait PStatsType: NativePType + Into<Scalar> + BitWidth {}
impl<T: NativePType + Into<Scalar> + BitWidth> PStatsType for T {}
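
// Dispatch on the array's logical validity: all-valid arrays use the plain slice
// path, all-null arrays short-circuit to a null StatsSet, and mixed arrays pair
// the values with their validity buffer.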
impl ArrayStatisticsCompute for PrimitiveArray {
fn compute_statistics(&self, stat: Stat) -> VortexResult<StatsSet> {
match_each_native_ptype!(self.ptype(), |$P| {
match self.logical_validity() {
LogicalValidity::AllValid(_) => self.maybe_null_slice::<$P>().compute_statistics(stat),
LogicalValidity::AllInvalid(v) => Ok(StatsSet::nulls(v, self.dtype())),
LogicalValidity::Array(a) => NullableValues(
self.maybe_null_slice::<$P>(),
&a.clone().into_bool()?.boolean_buffer(),
)
.compute_statistics(stat),
}
})
}
}
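
// Fast path for slices with no nulls. The requested `Stat` is ignored: all
// primitive statistics are computed in a single pass and returned together.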
impl<T: PStatsType> ArrayStatisticsCompute for &[T] {
fn compute_statistics(&self, _stat: Stat) -> VortexResult<StatsSet> {
if self.is_empty() {
return Ok(StatsSet::new());
}
let mut stats = StatsAccumulator::new(self[0]);
self.iter().skip(1).for_each(|next| stats.next(*next));
Ok(stats.into_map())
}
}
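
/// A slice of values paired with a validity buffer (a `true` bit marks a
/// non-null element).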
struct NullableValues<'a, T: PStatsType>(&'a [T], &'a BooleanBuffer);
impl<'a, T: PStatsType> ArrayStatisticsCompute for NullableValues<'a, T> {
fn compute_statistics(&self, _stat: Stat) -> VortexResult<StatsSet> {
let values = self.0;
if values.is_empty() {
return Ok(StatsSet::new());
}
        // Find the first non-null element; the all-null case is already handled by
        // the `LogicalValidity::AllInvalid` branch, so at least one valid bit exists.
        let first_non_null_idx = self
            .1
            .iter()
            .position(|valid| valid)
            .expect("Must be at least one non-null value");
let mut stats = StatsAccumulator::new_with_leading_nulls(
values[first_non_null_idx],
first_non_null_idx,
);
values
.iter()
.zip(self.1.iter())
.skip(first_non_null_idx + 1)
.map(|(next, valid)| valid.then_some(*next))
.for_each(|next| stats.nullable_next(next));
Ok(stats.into_map())
}
}
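
/// Per-value bit width and trailing-zero count, feeding the `BitWidthFreq` and
/// `TrailingZeroFreq` histograms.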
trait BitWidth {
fn bit_width(self) -> u32;
fn trailing_zeros(self) -> u32;
}
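
// Integers: bit width is the total width minus the number of leading zeros;
// trailing zeros come straight from `PrimInt`.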
macro_rules! int_bit_width {
($T:ty) => {
impl BitWidth for $T {
fn bit_width(self) -> u32 {
Self::BITS - PrimInt::leading_zeros(self)
}
fn trailing_zeros(self) -> u32 {
PrimInt::trailing_zeros(self)
}
}
};
}
int_bit_width!(u8);
int_bit_width!(u16);
int_bit_width!(u32);
int_bit_width!(u64);
int_bit_width!(i8);
int_bit_width!(i16);
int_bit_width!(i32);
int_bit_width!(i64);
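
// Floats carry no leading/trailing-zero information, so they always report
// their full width and zero trailing zeros.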
macro_rules! float_bit_width {
($T:ty) => {
impl BitWidth for $T {
fn bit_width(self) -> u32 {
(size_of::<Self>() * 8) as u32
}
fn trailing_zeros(self) -> u32 {
0
}
}
};
}
float_bit_width!(f16);
float_bit_width!(f32);
float_bit_width!(f64);
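
/// Single-pass accumulator for min/max, sortedness, run count, null count, and
/// the bit-width / trailing-zero histograms.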
struct StatsAccumulator<T: PStatsType> {
prev: T,
min: T,
max: T,
is_sorted: bool,
is_strict_sorted: bool,
run_count: usize,
null_count: usize,
bit_widths: Vec<usize>,
trailing_zeros: Vec<usize>,
}
impl<T: PStatsType> StatsAccumulator<T> {
fn new(first_value: T) -> Self {
let mut stats = Self {
prev: first_value,
min: first_value,
max: first_value,
is_sorted: true,
is_strict_sorted: true,
run_count: 1,
null_count: 0,
bit_widths: vec![0; size_of::<T>() * 8 + 1],
trailing_zeros: vec![0; size_of::<T>() * 8 + 1],
};
stats.bit_widths[first_value.bit_width() as usize] += 1;
stats.trailing_zeros[first_value.trailing_zeros() as usize] += 1;
stats
}
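
    /// Like [`Self::new`], but additionally records `leading_null_count` nulls that
    /// preceded the first non-null value.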
fn new_with_leading_nulls(first_value: T, leading_null_count: usize) -> Self {
let mut stats = Self::new(first_value);
stats.null_count += leading_null_count;
stats.bit_widths[0] += leading_null_count;
stats.trailing_zeros[T::PTYPE.bit_width()] += leading_null_count;
stats
}
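
    /// Fold an optional value into the statistics; `None` only updates the null
    /// bookkeeping.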
pub fn nullable_next(&mut self, next: Option<T>) {
match next {
Some(n) => self.next(n),
None => {
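                // A null counts as bit width 0 and a full run of trailing zeros,
                // mirroring `new_with_leading_nulls` above.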
self.bit_widths[0] += 1;
self.trailing_zeros[T::PTYPE.bit_width()] += 1;
self.null_count += 1;
}
}
}
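
    /// Fold a non-null value into the histograms, min/max, sortedness flags, and
    /// run count.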
pub fn next(&mut self, next: T) {
self.bit_widths[next.bit_width() as usize] += 1;
self.trailing_zeros[next.trailing_zeros() as usize] += 1;
if self.prev == next {
self.is_strict_sorted = false;
} else {
if next < self.prev {
self.is_sorted = false;
}
self.run_count += 1;
}
if next < self.min {
self.min = next;
} else if next > self.max {
self.max = next;
}
self.prev = next;
}
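
    /// Finalize the accumulator into a `StatsSet`.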
pub fn into_map(self) -> StatsSet {
StatsSet::from(HashMap::from([
(Stat::Min, self.min.into()),
(Stat::Max, self.max.into()),
(Stat::NullCount, self.null_count.into()),
(Stat::IsConstant, (self.min == self.max).into()),
(Stat::BitWidthFreq, self.bit_widths.into()),
(Stat::TrailingZeroFreq, self.trailing_zeros.into()),
(Stat::IsSorted, self.is_sorted.into()),
(
Stat::IsStrictSorted,
(self.is_sorted && self.is_strict_sorted).into(),
),
(Stat::RunCount, self.run_count.into()),
]))
}
}
#[cfg(test)]
mod test {
use vortex_dtype::{DType, Nullability, PType};
use vortex_scalar::Scalar;
use crate::array::primitive::PrimitiveArray;
use crate::stats::{ArrayStatistics, Stat};
#[test]
fn stats() {
let arr = PrimitiveArray::from(vec![1, 2, 3, 4, 5]);
let min: i32 = arr.statistics().compute_min().unwrap();
let max: i32 = arr.statistics().compute_max().unwrap();
let is_sorted = arr.statistics().compute_is_sorted().unwrap();
let is_strict_sorted = arr.statistics().compute_is_strict_sorted().unwrap();
let is_constant = arr.statistics().compute_is_constant().unwrap();
let bit_width_freq = arr.statistics().compute_bit_width_freq().unwrap();
let trailing_zeros_freq = arr.statistics().compute_trailing_zero_freq().unwrap();
let run_count = arr.statistics().compute_run_count().unwrap();
assert_eq!(min, 1);
assert_eq!(max, 5);
assert!(is_sorted);
assert!(is_strict_sorted);
assert!(!is_constant);
assert_eq!(
bit_width_freq,
vec![
0usize, 1, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0,
]
);
assert_eq!(
trailing_zeros_freq,
vec![
3usize, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0,
]
);
assert_eq!(run_count, 5);
}
#[test]
fn stats_u8() {
let arr = PrimitiveArray::from(vec![1u8, 2, 3, 4, 5]);
let min: u8 = arr.statistics().compute_min().unwrap();
let max: u8 = arr.statistics().compute_max().unwrap();
assert_eq!(min, 1);
assert_eq!(max, 5);
}
#[test]
    fn nullable_stats_i32() {
let arr = PrimitiveArray::from_nullable_vec(vec![None, None, Some(1i32), Some(2), None]);
let min: i32 = arr.statistics().compute_min().unwrap();
let max: i32 = arr.statistics().compute_max().unwrap();
let null_count: usize = arr.statistics().compute_null_count().unwrap();
let is_strict_sorted: bool = arr.statistics().compute_is_strict_sorted().unwrap();
assert_eq!(min, 1);
assert_eq!(max, 2);
assert_eq!(null_count, 3);
assert!(is_strict_sorted);
}
#[test]
fn all_null() {
let arr = PrimitiveArray::from_nullable_vec(vec![Option::<i32>::None, None, None]);
let min: Option<Scalar> = arr.statistics().compute(Stat::Min);
let max: Option<Scalar> = arr.statistics().compute(Stat::Max);
let null_i32 = Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable));
assert_eq!(min, Some(null_i32.clone()));
assert_eq!(max, Some(null_i32));
}
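
    // Supplementary check (a sketch using only the constructors and accessors already
    // exercised above): repeated values keep the array sorted but not strictly sorted,
    // and collapse into fewer runs.
    #[test]
    fn stats_repeated_values() {
        let arr = PrimitiveArray::from(vec![1i32, 1, 2, 2, 3]);
        assert!(arr.statistics().compute_is_sorted().unwrap());
        assert!(!arr.statistics().compute_is_strict_sorted().unwrap());
        // Runs are [1, 1], [2, 2], [3].
        assert_eq!(arr.statistics().compute_run_count().unwrap(), 3);
    }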
}