use std::fmt::{Display, Formatter};
use std::hash::Hash;
use enum_iterator::Sequence;
use itertools::Itertools;
pub use statsset::*;
use vortex_dtype::Nullability::NonNullable;
use vortex_dtype::{DType, NativePType};
use vortex_error::{vortex_panic, VortexError, VortexResult};
use vortex_scalar::Scalar;
use crate::Array;
pub mod flatbuffers;
mod statsset;
/// Statistics used for pruning, in this fixed order: min, max, true count, null count.
pub(crate) const PRUNING_STATS: &[Stat] = &[
    Stat::Min,
    Stat::Max,
    Stat::TrueCount,
    Stat::NullCount,
];
/// The kinds of statistics that can be recorded or computed for an array.
///
/// NOTE(review): variant declaration order determines the iteration order of
/// `enum_iterator::all::<Stat>()` (via `Sequence`) and the implicit discriminants; other code
/// (e.g. the `flatbuffers` module) may depend on it — confirm before reordering.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Sequence)]
#[non_exhaustive]
pub enum Stat {
    /// Bit-width frequency histogram, read as `Vec<usize>` by `compute_bit_width_freq`
    /// (presumably indexed by bit width — TODO confirm).
    BitWidthFreq,
    /// Trailing-zero frequency histogram, read as `Vec<usize>`; `trailing_zeros` treats the
    /// index as a trailing-zero count and the entry as how often it occurs.
    TrailingZeroFreq,
    /// Boolean: whether the array is constant.
    IsConstant,
    /// Boolean: whether the array is sorted.
    IsSorted,
    /// Boolean: whether the array is strictly sorted.
    IsStrictSorted,
    /// Maximum value; has the same dtype as the array.
    Max,
    /// Minimum value; has the same dtype as the array.
    Min,
    /// Number of runs, read as `usize`.
    RunCount,
    /// Number of `true` values, read as `usize`.
    TrueCount,
    /// Number of null values, read as `usize`.
    NullCount,
}
impl Stat {
    pub fn is_commutative(&self) -> bool {
        matches!(
            self,
            Stat::BitWidthFreq
                | Stat::TrailingZeroFreq
                | Stat::IsConstant
                | Stat::Max
                | Stat::Min
                | Stat::TrueCount
                | Stat::NullCount
        )
    }
    pub fn has_same_dtype_as_array(&self) -> bool {
        matches!(self, Stat::Min | Stat::Max)
    }
}
impl Display for Stat {
    /// Writes the stat's snake_case name (e.g. `max`, `null_count`).
    ///
    /// Width/fill flags are not applied; the name is written verbatim.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::BitWidthFreq => "bit_width_frequency",
            Self::TrailingZeroFreq => "trailing_zero_frequency",
            Self::IsConstant => "is_constant",
            Self::IsSorted => "is_sorted",
            Self::IsStrictSorted => "is_strict_sorted",
            Self::Max => "max",
            Self::Min => "min",
            Self::RunCount => "run_count",
            Self::TrueCount => "true_count",
            Self::NullCount => "null_count",
        };
        f.write_str(name)
    }
}
/// Access to an array's statistics: reading, recording and computing [`Stat`] values.
///
/// NOTE(review): `set` takes `&self`, so implementors presumably use interior mutability —
/// confirm against implementations.
pub trait Statistics {
    /// Returns the value of `stat` if one is currently available.
    ///
    /// NOTE(review): presumably does not trigger computation (contrast with `compute`) — confirm.
    fn get(&self, stat: Stat) -> Option<Scalar>;

    /// Returns the statistics as a [`StatsSet`].
    fn to_set(&self) -> StatsSet;

    /// Records `value` for `stat`.
    fn set(&self, stat: Stat, value: Scalar);

    /// Returns the value of `stat`, computing it if needed; `None` if it cannot be produced.
    fn compute(&self, stat: Stat) -> Option<Scalar>;

    /// Calls [`Statistics::compute`] for every entry of `stats`, then returns
    /// [`Statistics::to_set`].
    ///
    /// Individual `compute` results are discarded, and this default implementation never
    /// returns `Err`. The result comes from `to_set`, so it may include stats that were not
    /// requested.
    fn compute_all(&self, stats: &[Stat]) -> VortexResult<StatsSet> {
        stats.iter().copied().for_each(|requested| {
            // Called for its side effect only (presumably caching); values are re-read below.
            let _ = self.compute(requested);
        });
        Ok(self.to_set())
    }
}
/// Implemented by arrays that expose a [`Statistics`] accessor.
pub trait ArrayStatistics {
    /// Returns the statistics accessor for this array.
    fn statistics(&self) -> &dyn Statistics;
}
/// Hook through which an array computes a requested statistic.
pub trait ArrayStatisticsCompute {
    /// Computes `_stat` and returns the resulting statistics.
    ///
    /// The default implementation computes nothing: it returns an empty [`StatsSet`] and never
    /// fails. Implementors override it to supply the stats they can produce.
    fn compute_statistics(&self, _stat: Stat) -> VortexResult<StatsSet> {
        let nothing_computed = StatsSet::new();
        Ok(nothing_computed)
    }
}
impl dyn Statistics + '_ {
    /// Returns the currently available value of `stat` converted to `U`, or `None` if absent.
    ///
    /// # Panics
    ///
    /// Panics if the stat is present but cannot be converted to `U`.
    pub fn get_as<U: for<'a> TryFrom<&'a Scalar, Error = VortexError>>(
        &self,
        stat: Stat,
    ) -> Option<U> {
        let scalar = self.get(stat)?;
        match U::try_from(&scalar) {
            Ok(value) => Some(value),
            Err(err) => vortex_panic!(
                err,
                "Failed to cast stat {} to {}",
                stat,
                std::any::type_name::<U>()
            ),
        }
    }

    /// Like [`get_as`](Self::get_as), but first casts the scalar to the non-nullable primitive
    /// dtype of `U`.
    ///
    /// Returns `None` when the stat is absent or its scalar is not valid (null).
    ///
    /// # Panics
    ///
    /// Panics if the cast or the conversion to `U` fails.
    pub fn get_as_cast<U: NativePType + for<'a> TryFrom<&'a Scalar, Error = VortexError>>(
        &self,
        stat: Stat,
    ) -> Option<U> {
        let scalar = self.get(stat).filter(|s| s.is_valid())?;
        let target = DType::Primitive(U::PTYPE, NonNullable);
        match scalar.cast(&target).and_then(|casted| U::try_from(&casted)) {
            Ok(value) => Some(value),
            Err(err) => vortex_panic!(err, "Failed to cast stat {} to {}", stat, U::PTYPE),
        }
    }

    /// Computes `stat` (via [`Statistics::compute`]) and converts it to `U`; `None` if it
    /// could not be produced.
    ///
    /// # Panics
    ///
    /// Panics if the computed scalar cannot be converted to `U`.
    pub fn compute_as<U: for<'a> TryFrom<&'a Scalar, Error = VortexError>>(
        &self,
        stat: Stat,
    ) -> Option<U> {
        let scalar = self.compute(stat)?;
        match U::try_from(&scalar) {
            Ok(value) => Some(value),
            Err(err) => vortex_panic!(
                err,
                "Failed to compute stat {} as {}",
                stat,
                std::any::type_name::<U>()
            ),
        }
    }

    /// Like [`compute_as`](Self::compute_as), but first casts the scalar to the non-nullable
    /// primitive dtype of `U`.
    ///
    /// Returns `None` when the stat cannot be produced or its scalar is not valid (null).
    ///
    /// # Panics
    ///
    /// Panics if the cast or the conversion to `U` fails.
    pub fn compute_as_cast<U: NativePType + for<'a> TryFrom<&'a Scalar, Error = VortexError>>(
        &self,
        stat: Stat,
    ) -> Option<U> {
        let scalar = self.compute(stat).filter(|s| s.is_valid())?;
        let target = DType::Primitive(U::PTYPE, NonNullable);
        match scalar.cast(&target).and_then(|casted| U::try_from(&casted)) {
            Ok(value) => Some(value),
            Err(err) => {
                vortex_panic!(err, "Failed to compute stat {} as cast {}", stat, U::PTYPE)
            }
        }
    }

    /// Computes [`Stat::Min`] as `U` (see [`compute_as`](Self::compute_as)).
    pub fn compute_min<U: for<'a> TryFrom<&'a Scalar, Error = VortexError>>(&self) -> Option<U> {
        self.compute_as::<U>(Stat::Min)
    }

    /// Computes [`Stat::Max`] as `U` (see [`compute_as`](Self::compute_as)).
    pub fn compute_max<U: for<'a> TryFrom<&'a Scalar, Error = VortexError>>(&self) -> Option<U> {
        self.compute_as::<U>(Stat::Max)
    }

    /// Computes [`Stat::IsStrictSorted`] as a `bool`.
    pub fn compute_is_strict_sorted(&self) -> Option<bool> {
        self.compute_as::<bool>(Stat::IsStrictSorted)
    }

    /// Computes [`Stat::IsSorted`] as a `bool`.
    pub fn compute_is_sorted(&self) -> Option<bool> {
        self.compute_as::<bool>(Stat::IsSorted)
    }

    /// Computes [`Stat::IsConstant`] as a `bool`.
    pub fn compute_is_constant(&self) -> Option<bool> {
        self.compute_as::<bool>(Stat::IsConstant)
    }

    /// Computes [`Stat::TrueCount`] as a `usize`.
    pub fn compute_true_count(&self) -> Option<usize> {
        self.compute_as::<usize>(Stat::TrueCount)
    }

    /// Computes [`Stat::NullCount`] as a `usize`.
    pub fn compute_null_count(&self) -> Option<usize> {
        self.compute_as::<usize>(Stat::NullCount)
    }

    /// Computes [`Stat::RunCount`] as a `usize`.
    pub fn compute_run_count(&self) -> Option<usize> {
        self.compute_as::<usize>(Stat::RunCount)
    }

    /// Computes [`Stat::BitWidthFreq`] as a `Vec<usize>`.
    pub fn compute_bit_width_freq(&self) -> Option<Vec<usize>> {
        self.compute_as::<Vec<usize>>(Stat::BitWidthFreq)
    }

    /// Computes [`Stat::TrailingZeroFreq`] as a `Vec<usize>`.
    pub fn compute_trailing_zero_freq(&self) -> Option<Vec<usize>> {
        self.compute_as::<Vec<usize>>(Stat::TrailingZeroFreq)
    }
}
/// Returns the index of the first non-zero entry in `array`'s trailing-zero frequency stat.
///
/// That index is the smallest trailing-zero count observed. Returns `0` when the stat cannot be
/// computed, is empty, or has no non-zero entry.
pub fn trailing_zeros(array: &Array) -> u8 {
    let first_nonzero = array
        .statistics()
        .compute_trailing_zero_freq()
        .and_then(|freq| {
            freq.iter()
                .find_position(|&&count| count > 0)
                .map(|(idx, _)| idx)
        });
    // NOTE(review): assumes the frequency vector has at most 256 entries (one per possible
    // trailing-zero count of a primitive), so the index fits in a u8 — confirm.
    first_nonzero.unwrap_or(0) as u8
}
#[cfg(test)]
mod test {
    use enum_iterator::all;

    use crate::array::PrimitiveArray;
    use crate::stats::{ArrayStatistics, Stat};

    /// Casting the min of an all-null array must yield `None` rather than panicking.
    #[test]
    fn min_of_nulls_is_not_panic() {
        let all_null = PrimitiveArray::from_nullable_vec::<i32>(vec![None; 4]);
        let min = all_null.statistics().compute_as_cast::<i64>(Stat::Min);
        assert!(min.is_none());
    }

    /// Only the order-dependent stats are non-commutative.
    #[test]
    fn commutativity() {
        let order_dependent = [Stat::IsSorted, Stat::IsStrictSorted, Stat::RunCount];
        for stat in all::<Stat>() {
            assert_eq!(stat.is_commutative(), !order_dependent.contains(&stat));
        }
    }

    /// Only `Min` and `Max` share the array's dtype.
    #[test]
    fn has_same_dtype_as_array() {
        for stat in all::<Stat>() {
            let expected = matches!(stat, Stat::Min | Stat::Max);
            assert_eq!(stat.has_same_dtype_as_array(), expected);
        }
    }
}