pub use num_traits::{AsPrimitive, PrimInt, Signed};
use std::fmt::{Binary, Debug, LowerHex};
use super::math::*;
macro_rules! impl_fixed_point_func_unary {
    ($func_name: ident) => {
        #[allow(dead_code)]
        pub fn $func_name(&self) -> Self {
            Self::from_raw($func_name(self.as_raw()))
        }
    };
}
macro_rules! impl_fixed_point_func_binary {
    ($func_name: ident) => {
        pub fn $func_name(&self, b: Self) -> Self {
            Self::from_raw($func_name(self.as_raw(), b.as_raw()))
        }
    };
}
pub type Q0_31 = FixedPoint<i32, 0>;
pub type Q1_30 = FixedPoint<i32, 1>;
pub type Q2_29 = FixedPoint<i32, 2>;
pub type Q5_26 = FixedPoint<i32, 5>;
#[derive(PartialEq, Eq,PartialOrd, Copy, Clone)]
pub struct FixedPoint<T: PrimInt, const INTEGER_BITS: usize>(T);
impl<T, const INTEGER_BITS: usize> FixedPoint<T, INTEGER_BITS>
where
    T: PrimInt,
{
    pub fn from_raw(x: T) -> Self {
        Self(x)
    }
    pub fn one() -> Self {
        if INTEGER_BITS == 0 {
            Self(T::max_value())
        } else {
            Self(T::one() << Self::fractional_bits())
        }
    }
    pub fn fractional_bits() -> usize {
        if Self::is_signed() {
            std::mem::size_of::<T>() * 8 - 1 - INTEGER_BITS
        } else {
            std::mem::size_of::<T>() * 8 - INTEGER_BITS
        }
    }
    #[allow(dead_code)]
    pub fn zero() -> Self {
        Self(T::zero())
    }
    pub fn as_raw(&self) -> T {
        self.0
    }
    pub fn is_signed() -> bool {
        is_signed::<T>()
    }
}
impl<T: 'static, const INTEGER_BITS: usize> FixedPoint<T, INTEGER_BITS>
where
    T: PrimInt + Debug,
    usize: AsPrimitive<T>,
{
    pub fn constant_pot(exponent: isize) -> Self {
        let offset = (Self::fractional_bits() as isize + exponent) as usize;
        assert!(offset < 31);
        Self(1_usize.as_() << offset)
    }
}
impl FixedPoint<i32, 0> {
    impl_fixed_point_func_unary!(exp_on_interval_between_negative_one_quarter_and_0_excl);
    impl_fixed_point_func_unary!(one_over_one_plus_x_for_x_in_0_1);
}
impl FixedPoint<i32, 5> {
    #[allow(dead_code)]
    pub fn exp_on_negative_values(&self) -> FixedPoint<i32, 0> {
        FixedPoint::<i32, 0>::from_raw(exp_on_negative_values(self.as_raw()))
    }
}
impl<const INTEGER_BITS: usize> FixedPoint<i32, INTEGER_BITS> {
    impl_fixed_point_func_unary!(mask_if_non_zero);
    impl_fixed_point_func_unary!(mask_if_zero);
    impl_fixed_point_func_binary!(rounding_half_sum);
    pub fn saturating_rounding_multiply_by_pot(&self, exponent: i32) -> Self {
        Self::from_raw(saturating_rounding_multiply_by_pot(self.as_raw(), exponent))
    }
    #[allow(dead_code)]
    pub fn rounding_divide_by_pot(&self, exponent: i32) -> Self {
        Self::from_raw(rounding_divide_by_pot(self.as_raw(), exponent))
    }
    pub fn select_using_mask(mask: i32, a: Self, b: Self) -> Self {
        Self::from_raw(select_using_mask(mask, a.as_raw(), b.as_raw()))
    }
    pub fn rescale<const DST_INTEGER_BITS: usize>(&self) -> FixedPoint<i32, DST_INTEGER_BITS> {
        FixedPoint::<i32, DST_INTEGER_BITS>::from_raw(rescale(
            self.as_raw(),
            INTEGER_BITS,
            DST_INTEGER_BITS,
        ))
    }
    #[allow(dead_code)]
    pub fn get_reciprocal(&self) -> (FixedPoint<i32, 0>, usize) {
        let (raw_res, num_bits_over_units) = get_reciprocal(self.as_raw(), INTEGER_BITS);
        (FixedPoint::<i32, 0>::from_raw(raw_res), num_bits_over_units)
    }
}
impl<T, const INTEGER_BITS: usize> Debug for FixedPoint<T, INTEGER_BITS>
where
    T: AsPrimitive<f32> + PrimInt + LowerHex + Debug + Binary,
    f32: AsPrimitive<T>,
{
    fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(fmt, "{:032b}({:?})({})", self.0, self.0, self.as_f32())
    }
}
impl<T, const INTEGER_BITS: usize> FixedPoint<T, INTEGER_BITS>
where
    T: AsPrimitive<f32> + PrimInt,
{
    pub fn as_f32(&self) -> f32 {
        self.0.as_() / 2_f32.powi(Self::fractional_bits() as i32)
    }
}
impl<T, const INTEGER_BITS: usize> FixedPoint<T, INTEGER_BITS>
where
    T: AsPrimitive<f32> + PrimInt,
    f32: AsPrimitive<T>,
{
    #[allow(dead_code)]
    pub fn from_f32(x: f32) -> Self {
        Self::from_raw(
            f32::min(
                f32::max(
                    f32::round(x * 2f32.powi(Self::fractional_bits().as_())),
                    T::min_value().as_(),
                ),
                T::max_value().as_(),
            )
            .as_(),
        )
    }
}
impl<T: PrimInt, const INTEGER_BITS: usize> std::ops::Add for FixedPoint<T, INTEGER_BITS> {
    type Output = FixedPoint<T, INTEGER_BITS>;
    fn add(self, rhs: Self) -> Self::Output {
        Self::from_raw(self.0 + rhs.0)
    }
}
impl<T: PrimInt, const INTEGER_BITS: usize> std::ops::Sub for FixedPoint<T, INTEGER_BITS> {
    type Output = FixedPoint<T, INTEGER_BITS>;
    fn sub(self, rhs: Self) -> Self::Output {
        Self::from_raw(self.0 - rhs.0)
    }
}
impl<T: PrimInt, const INTEGER_BITS: usize> std::ops::Shl<usize> for FixedPoint<T, INTEGER_BITS> {
    type Output = FixedPoint<T, INTEGER_BITS>;
    fn shl(self, rhs: usize) -> Self::Output {
        Self::from_raw(self.0 << rhs)
    }
}
impl<T: PrimInt, const INTEGER_BITS: usize> std::ops::Shr<usize> for FixedPoint<T, INTEGER_BITS> {
    type Output = FixedPoint<T, INTEGER_BITS>;
    fn shr(self, rhs: usize) -> Self::Output {
        Self::from_raw(self.0 >> rhs)
    }
}
impl<T: PrimInt, const INTEGER_BITS: usize> std::ops::BitAnd for FixedPoint<T, INTEGER_BITS> {
    type Output = FixedPoint<T, INTEGER_BITS>;
    fn bitand(self, rhs: Self) -> Self::Output {
        Self::from_raw(self.0 & rhs.0)
    }
}
macro_rules! impl_mul {
    ($T: ty, $LHS_INTEGER_BITS: literal, $RHS_INTEGER_BITS: literal, $OUT_INTEGER_BITS: literal) => {
        impl std::ops::Mul<FixedPoint<$T, $RHS_INTEGER_BITS>>
            for FixedPoint<$T, $LHS_INTEGER_BITS>
        {
            type Output = FixedPoint<$T, $OUT_INTEGER_BITS>;
            fn mul(self, rhs: FixedPoint<$T, $RHS_INTEGER_BITS>) -> Self::Output {
                Self::Output::from_raw(saturating_rounding_doubling_high_mul(self.0, rhs.0))
            }
        }
    };
}
impl_mul!(i32, 0, 0, 0);
impl_mul!(i32, 0, 2, 2);
impl_mul!(i32, 2, 0, 2);
impl_mul!(i32, 2, 2, 4);
impl_mul!(i32, 5, 5, 10);
#[cfg(test)]
mod test {
    use super::*;
    use approx::assert_abs_diff_eq;
    pub type Q10_21 = FixedPoint<i32, 10>;
    pub type Q12_19 = FixedPoint<i32, 12>;
    pub type Q26_5 = FixedPoint<i32, 26>;
    type Q0_7 = FixedPoint<i8, 0>;
    #[test]
    fn test_to_f32() {
        let x = Q26_5::from_raw(32);
        assert_eq!(x.as_f32(), 1.0);
    }
    #[test]
    fn test_to_f32_1() {
        let x = Q0_7::from_raw(32);
        assert_eq!(x.as_f32(), 0.25);
    }
    #[test]
    fn test_one() {
        let x = Q26_5::one();
        assert_eq!(x, Q26_5::from_raw(32));
    }
    #[test]
    fn test_one_limit() {
        let x = Q0_31::one();
        assert_eq!(x, Q0_31::from_raw(i32::MAX));
    }
    #[test]
    fn test_mul_1() {
        let a = Q5_26::from_f32(8.0); let b = Q5_26::from_f32(3.0); let product = a * b;
        let expected = Q10_21::from_f32(24.0);
        assert_eq!(product, expected);
    }
    #[test]
    fn test_add() {
        let a = Q5_26::from_f32(16.0);
        let b = Q5_26::from_f32(5.0);
        let sum = a + b;
        let expected = Q5_26::from_f32(21.0);
        assert_eq!(sum, expected);
    }
    #[test]
    fn test_one_over_one_plus_x_for_x_in_0_1() {
        let a = Q0_31::from_f32(0.75);
        let expected_res = Q0_31::from_f32(1.0 / 1.75);
        let res = a.one_over_one_plus_x_for_x_in_0_1();
        assert_eq!(res.as_f32(), expected_res.as_f32());
    }
    #[test]
    fn test_one_over_one_plus_x_for_x_in_0_1_1() {
        let a = Q0_31::from_f32(0.0);
        let expected_res = Q0_31::from_f32(1.0 / 1.0);
        let res = a.one_over_one_plus_x_for_x_in_0_1();
        assert_eq!(res.as_f32(), expected_res.as_f32());
    }
    #[test]
    fn test_get_reciprocal_1() {
        let a = Q5_26::from_f32(4.5);
        let expected_res = Q0_31::from_f32(1.0 / 4.5);
        let (shifted_res, num_bits_over_unit) = a.get_reciprocal();
        let res = shifted_res.rounding_divide_by_pot(num_bits_over_unit as i32);
        assert_eq!(res.as_f32(), expected_res.as_f32());
        assert_eq!(num_bits_over_unit, 2);
    }
    #[test]
    fn test_get_reciprocal_2() {
        let a = Q5_26::from_f32(4.5);
        let expected_res = Q0_31::from_f32(1.0 / 4.5);
        let (shifted_res, num_bits_over_unit) = a.get_reciprocal();
        let res = shifted_res.rounding_divide_by_pot(num_bits_over_unit as i32);
        assert_eq!(res.as_f32(), expected_res.as_f32());
        assert_eq!(num_bits_over_unit, 2);
    }
    #[test]
    fn test_get_reciprocal_3() {
        let a = Q12_19::from_f32(2.0);
        let expected_res = Q0_31::from_f32(1.0 / 2.0);
        let (shifted_res, num_bits_over_unit) = a.get_reciprocal();
        let res = shifted_res.rounding_divide_by_pot(num_bits_over_unit as i32);
        assert_eq!(res.as_f32(), expected_res.as_f32());
        assert_eq!(num_bits_over_unit, 1);
    }
    #[test]
    fn test_rescale_1() {
        let a = Q0_31::from_f32(0.75);
        let expeted_res = Q12_19::from_f32(0.75);
        let res = a.rescale::<12>();
        assert_eq!(res, expeted_res);
    }
    #[test]
    fn test_exp_on_interval_between_negative_one_quarter_and_0_excl() {
        let a = Q0_31::from_f32(-0.125);
        let expected_res = Q0_31::from_f32((-0.125_f32).exp());
        let res = a.exp_on_interval_between_negative_one_quarter_and_0_excl();
        assert_eq!(res.as_f32(), expected_res.as_f32());
    }
    #[test]
    fn test_exp_on_negative_values_1() {
        let a = Q5_26::from_f32(-0.125);
        let expected_res = Q0_31::from_f32((-0.125_f32).exp());
        let res = a.exp_on_negative_values();
        assert_abs_diff_eq!(res.as_f32(), expected_res.as_f32(), epsilon = 0.00001);
    }
    #[test]
    fn test_exp_on_negative_values_2() {
        let a = Q5_26::from_f32(0.0);
        let expected_res = Q0_31::from_f32((0_f32).exp());
        let res = a.exp_on_negative_values();
        assert_abs_diff_eq!(res.as_f32(), expected_res.as_f32(), epsilon = 0.00001);
    }
    #[test]
    fn test_exp_on_negative_values_3() {
        let a = Q5_26::from_f32(-0.25);
        let expected_res = Q0_31::from_f32((-0.25_f32).exp());
        let res = a.exp_on_negative_values();
        assert_abs_diff_eq!(res.as_f32(), expected_res.as_f32(), epsilon = 0.00001);
    }
    #[test]
    fn test_exp_on_negative_values_4() {
        let a = Q5_26::from_f32(-1.1875_f32);
        let expected_res = Q0_31::from_f32((-1.1875_f32).exp());
        let res = a.exp_on_negative_values();
        assert_abs_diff_eq!(res.as_f32(), expected_res.as_f32(), epsilon = 0.00001);
    }
}