use std::cmp::Ordering;
use std::collections::HashMap;
use std::mem::size_of;
use arrow_buffer::buffer::BooleanBuffer;
use num_traits::PrimInt;
use vortex_dtype::half::f16;
use vortex_dtype::{match_each_native_ptype, DType, NativePType, Nullability};
use vortex_error::VortexResult;
use vortex_scalar::Scalar;
use crate::array::primitive::PrimitiveArray;
use crate::stats::{ArrayStatisticsCompute, Stat, StatsSet};
use crate::validity::{ArrayValidity, LogicalValidity};
use crate::{ArrayDType, IntoArrayVariant};
trait PStatsType: NativePType + Into<Scalar> + BitWidth {}
impl<T: NativePType + Into<Scalar> + BitWidth> PStatsType for T {}
impl ArrayStatisticsCompute for PrimitiveArray {
fn compute_statistics(&self, stat: Stat) -> VortexResult<StatsSet> {
match_each_native_ptype!(self.ptype(), |$P| {
match self.logical_validity() {
LogicalValidity::AllValid(_) => self.maybe_null_slice::<$P>().compute_statistics(stat),
LogicalValidity::AllInvalid(v) => Ok(StatsSet::nulls(v, self.dtype())),
LogicalValidity::Array(a) => NullableValues(
self.maybe_null_slice::<$P>(),
&a.clone().into_bool()?.boolean_buffer(),
)
.compute_statistics(stat),
}
})
}
}
impl<T: PStatsType> ArrayStatisticsCompute for &[T] {
fn compute_statistics(&self, _stat: Stat) -> VortexResult<StatsSet> {
if self.is_empty() {
return Ok(StatsSet::new());
}
let mut stats = StatsAccumulator::new(self[0]);
self.iter().skip(1).for_each(|next| stats.next(*next));
Ok(stats.finish())
}
}
struct NullableValues<'a, T: PStatsType>(&'a [T], &'a BooleanBuffer);
impl<'a, T: PStatsType> ArrayStatisticsCompute for NullableValues<'a, T> {
fn compute_statistics(&self, _stat: Stat) -> VortexResult<StatsSet> {
let values = self.0;
if values.is_empty() {
return Ok(StatsSet::new());
}
let first_non_null_idx = self
.1
.iter()
.enumerate()
.skip_while(|(_, valid)| !*valid)
.map(|(idx, _)| idx)
.next();
if let Some(first_non_null) = first_non_null_idx {
let mut acc = StatsAccumulator::new(values[first_non_null]);
acc.n_nulls(first_non_null);
self.0
.iter()
.zip(self.1.iter())
.skip(first_non_null + 1)
.map(|(next, valid)| valid.then_some(*next))
.for_each(|next| acc.nullable_next(next));
Ok(acc.finish())
} else {
Ok(StatsSet::nulls(
self.0.len(),
&DType::Primitive(T::PTYPE, Nullability::Nullable),
))
}
}
}
trait BitWidth {
fn bit_width(self) -> u32;
fn trailing_zeros(self) -> u32;
}
macro_rules! int_bit_width {
($T:ty) => {
impl BitWidth for $T {
fn bit_width(self) -> u32 {
Self::BITS - PrimInt::leading_zeros(self)
}
fn trailing_zeros(self) -> u32 {
PrimInt::trailing_zeros(self)
}
}
};
}
int_bit_width!(u8);
int_bit_width!(u16);
int_bit_width!(u32);
int_bit_width!(u64);
int_bit_width!(i8);
int_bit_width!(i16);
int_bit_width!(i32);
int_bit_width!(i64);
macro_rules! float_bit_width {
($T:ty) => {
impl BitWidth for $T {
fn bit_width(self) -> u32 {
(size_of::<Self>() * 8) as u32
}
fn trailing_zeros(self) -> u32 {
0
}
}
};
}
float_bit_width!(f16);
float_bit_width!(f32);
float_bit_width!(f64);
struct StatsAccumulator<T: PStatsType> {
prev: T,
min: T,
max: T,
is_sorted: bool,
is_strict_sorted: bool,
run_count: usize,
null_count: usize,
nan_count: usize,
bit_widths: Vec<usize>,
trailing_zeros: Vec<usize>,
len: usize,
}
impl<T: PStatsType> StatsAccumulator<T> {
fn new(first_value: T) -> Self {
let mut stats = Self {
prev: first_value,
min: first_value,
max: first_value,
is_sorted: true,
is_strict_sorted: true,
run_count: 1,
null_count: 0,
bit_widths: vec![0; size_of::<T>() * 8 + 1],
trailing_zeros: vec![0; size_of::<T>() * 8 + 1],
len: 1,
nan_count: first_value.is_nan().then_some(1).unwrap_or_default(),
};
stats.bit_widths[first_value.bit_width() as usize] += 1;
stats.trailing_zeros[first_value.trailing_zeros() as usize] += 1;
stats
}
fn n_nulls(&mut self, n_nulls: usize) {
self.null_count += n_nulls;
self.bit_widths[0] += n_nulls;
self.trailing_zeros[T::PTYPE.bit_width()] += n_nulls;
self.len += n_nulls;
}
pub fn nullable_next(&mut self, next: Option<T>) {
match next {
Some(n) => self.next(n),
None => {
self.bit_widths[0] += 1;
self.trailing_zeros[T::PTYPE.bit_width()] += 1;
self.null_count += 1;
self.len += 1;
}
}
}
pub fn next(&mut self, next: T) {
self.bit_widths[next.bit_width() as usize] += 1;
self.trailing_zeros[next.trailing_zeros() as usize] += 1;
self.len += 1;
if next.is_nan() {
self.nan_count += 1;
}
if next.is_eq(self.prev) {
self.is_strict_sorted = false;
} else {
if matches!(next.compare(self.prev), Ordering::Less) {
self.is_sorted = false;
}
self.run_count += 1;
}
if matches!(next.compare(self.min), Ordering::Less) {
self.min = next;
} else if matches!(next.compare(self.max), Ordering::Greater) {
self.max = next;
}
self.prev = next;
}
pub fn finish(self) -> StatsSet {
let is_constant = (self.min == self.max && self.null_count == 0 && self.nan_count == 0)
|| self.null_count == self.len;
StatsSet::from(HashMap::from([
(Stat::Min, self.min.into()),
(Stat::Max, self.max.into()),
(Stat::NullCount, self.null_count.into()),
(Stat::IsConstant, is_constant.into()),
(Stat::BitWidthFreq, self.bit_widths.into()),
(Stat::TrailingZeroFreq, self.trailing_zeros.into()),
(Stat::IsSorted, self.is_sorted.into()),
(
Stat::IsStrictSorted,
(self.is_sorted && self.is_strict_sorted).into(),
),
(Stat::RunCount, self.run_count.into()),
]))
}
}
#[cfg(test)]
mod test {
use vortex_dtype::{DType, Nullability, PType};
use vortex_scalar::Scalar;
use crate::array::primitive::PrimitiveArray;
use crate::stats::{ArrayStatistics, Stat};
#[test]
fn stats() {
let arr = PrimitiveArray::from(vec![1, 2, 3, 4, 5]);
let min: i32 = arr.statistics().compute_min().unwrap();
let max: i32 = arr.statistics().compute_max().unwrap();
let is_sorted = arr.statistics().compute_is_sorted().unwrap();
let is_strict_sorted = arr.statistics().compute_is_strict_sorted().unwrap();
let is_constant = arr.statistics().compute_is_constant().unwrap();
let bit_width_freq = arr.statistics().compute_bit_width_freq().unwrap();
let trailing_zeros_freq = arr.statistics().compute_trailing_zero_freq().unwrap();
let run_count = arr.statistics().compute_run_count().unwrap();
assert_eq!(min, 1);
assert_eq!(max, 5);
assert!(is_sorted);
assert!(is_strict_sorted);
assert!(!is_constant);
assert_eq!(
bit_width_freq,
vec![
0usize, 1, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0,
]
);
assert_eq!(
trailing_zeros_freq,
vec![
3usize, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0,
]
);
assert_eq!(run_count, 5);
}
#[test]
fn stats_u8() {
let arr = PrimitiveArray::from(vec![1u8, 2, 3, 4, 5]);
let min: u8 = arr.statistics().compute_min().unwrap();
let max: u8 = arr.statistics().compute_max().unwrap();
assert_eq!(min, 1);
assert_eq!(max, 5);
}
#[test]
fn nullable_stats_u8() {
let arr = PrimitiveArray::from_nullable_vec(vec![None, None, Some(1i32), Some(2), None]);
let min: i32 = arr.statistics().compute_min().unwrap();
let max: i32 = arr.statistics().compute_max().unwrap();
let null_count: usize = arr.statistics().compute_null_count().unwrap();
let is_strict_sorted: bool = arr.statistics().compute_is_strict_sorted().unwrap();
assert_eq!(min, 1);
assert_eq!(max, 2);
assert_eq!(null_count, 3);
assert!(is_strict_sorted);
}
#[test]
fn all_null() {
let arr = PrimitiveArray::from_nullable_vec(vec![Option::<i32>::None, None, None]);
let min: Option<Scalar> = arr.statistics().compute(Stat::Min);
let max: Option<Scalar> = arr.statistics().compute(Stat::Max);
let null_i32 = Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable));
assert_eq!(min, Some(null_i32.clone()));
assert_eq!(max, Some(null_i32));
}
}