use p3_field::{AbstractField, Field};
use sp1_recursion_core::runtime::NUM_BITS;
use super::{Array, Builder, Config, DslIr, Felt, Usize, Var};
impl<C: Config> Builder<C> {
    /// Decomposes `num` into an array of `NUM_BITS` bit variables, least-significant bit first.
    ///
    /// The bits are requested through a `DslIr::HintBitsV` instruction and are then constrained
    /// here: every entry must satisfy `bit * (bit - 1) == 0`, the weighted sum
    /// `sum(bit_i * 2^i)` must equal `num`, and the bit pattern must pass
    /// `less_than_bb_modulus`.
    ///
    /// Returns the array of bit variables.
    ///
    /// # Panics
    ///
    /// Panics if `C::N::bits()` differs from `NUM_BITS`.
    pub fn num2bits_v(&mut self, num: Var<C::N>) -> Array<C, Var<C::N>> {
        assert!(C::N::bits() == NUM_BITS);
        let output = self.dyn_array::<Var<_>>(NUM_BITS);
        // NOTE(review): presumably the hint fills `output` with the bits of `num`; the
        // assertions below are what actually tie `output` to `num`.
        self.push(DslIr::HintBitsV(output.clone(), num));
        // Running reconstruction of `num` from the bits.
        let sum: Var<_> = self.eval(C::N::zero());
        for i in 0..NUM_BITS {
            let bit = self.get(&output, i);
            // Booleanity check: `bit * (bit - 1)` is zero only for `bit` in {0, 1}.
            self.assert_var_eq(bit * (bit - C::N::one()), C::N::zero());
            // Bit `i` carries weight `2^i`, i.e. the array is little-endian.
            self.assign(sum, sum + bit * C::N::from_canonical_u32(1 << i));
        }
        self.assert_var_eq(sum, num);
        // Additional check on the bit pattern itself (see `less_than_bb_modulus`).
        self.less_than_bb_modulus(output.clone());
        output
    }
    /// Constrains the variable `num` to be representable in `num_bits` bits.
    ///
    /// `num` is decomposed with `num2bits_v`, and every bit at index `num_bits` or higher is
    /// asserted to be zero.
    pub fn range_check_v(&mut self, num: Var<C::N>, num_bits: usize) {
        let decomposition = self.num2bits_v(num);
        let total_bits = decomposition.len();
        // Only the bits above the allowed width need to be pinned to zero.
        self.range(num_bits, total_bits).for_each(|idx, b| {
            let high_bit = b.get(&decomposition, idx);
            b.assert_var_eq(high_bit, C::N::zero());
        });
    }
    /// Allocates `bits` uninitialized variables and emits a `DslIr::CircuitNum2BitsV`
    /// instruction carrying `num`, the requested width and those variables.
    ///
    /// Returns the freshly allocated variables. No further constraints are added by this
    /// method; NOTE(review): the decomposition is presumably enforced by whatever consumes
    /// `CircuitNum2BitsV` — confirm against the backend.
    pub fn num2bits_v_circuit(&mut self, num: Var<C::N>, bits: usize) -> Vec<Var<C::N>> {
        // One fresh variable per requested bit, allocated in index order.
        let bit_vars: Vec<Var<C::N>> = (0..bits).map(|_| self.uninit()).collect();
        self.push(DslIr::CircuitNum2BitsV(num, bits, bit_vars.clone()));
        bit_vars
    }
    /// Constrains the field element `num` to be representable in `num_bits` bits.
    ///
    /// `num` is decomposed with `num2bits_f`, and every bit at index `num_bits` or higher is
    /// asserted to be zero.
    pub fn range_check_f(&mut self, num: Felt<C::F>, num_bits: usize) {
        let decomposition = self.num2bits_f(num);
        let total_bits = decomposition.len();
        // Only the bits above the allowed width need to be pinned to zero.
        self.range(num_bits, total_bits).for_each(|idx, b| {
            let high_bit = b.get(&decomposition, idx);
            b.assert_var_eq(high_bit, C::N::zero());
        });
    }
    /// Decomposes the field element `num` into an array of `NUM_BITS` bit variables,
    /// least-significant bit first.
    ///
    /// The bits are requested through a `DslIr::HintBitsF` instruction and are then
    /// constrained here: every entry must satisfy `bit * (bit - 1) == 0`, the value rebuilt
    /// from the bits must equal `num`, and the bit pattern must pass `less_than_bb_modulus`.
    ///
    /// Returns the array of bit variables.
    pub fn num2bits_f(&mut self, num: Felt<C::F>) -> Array<C, Var<C::N>> {
        let output = self.dyn_array::<Var<_>>(NUM_BITS);
        // NOTE(review): presumably the hint fills `output` with the bits of `num`; the
        // assertions below are what actually tie `output` to `num`.
        self.push(DslIr::HintBitsF(output.clone(), num));
        // Running reconstruction of `num`, accumulated as a `Felt`.
        let sum: Felt<_> = self.eval(C::F::zero());
        for i in 0..NUM_BITS {
            let bit = self.get(&output, i);
            // Booleanity check: `bit * (bit - 1)` is zero only for `bit` in {0, 1}.
            self.assert_var_eq(bit * (bit - C::N::one()), C::N::zero());
            // The bit is a `Var` while the sum is a `Felt`, so the weight `2^i` is added
            // under an `if_eq(bit, 1)` branch rather than through a multiplication.
            self.if_eq(bit, C::N::one()).then(|builder| {
                builder.assign(sum, sum + C::F::from_canonical_u32(1 << i));
            });
        }
        self.assert_felt_eq(sum, num);
        // Additional check on the bit pattern itself (see `less_than_bb_modulus`).
        self.less_than_bb_modulus(output.clone());
        output
    }
    /// Allocates `NUM_BITS` uninitialized variables, emits a `DslIr::CircuitNum2BitsF`
    /// instruction carrying `num` and those variables, and then runs `less_than_bb_modulus`
    /// on them.
    ///
    /// Returns the allocated bit variables. NOTE(review): booleanity and the recomposition
    /// check are not emitted here; presumably `CircuitNum2BitsF` covers them — confirm
    /// against the backend.
    pub fn num2bits_f_circuit(&mut self, num: Felt<C::F>) -> Vec<Var<C::N>> {
        // One fresh variable per bit, allocated in index order.
        let bit_vars: Vec<Var<C::N>> = (0..NUM_BITS).map(|_| self.uninit()).collect();
        self.push(DslIr::CircuitNum2BitsF(num, bit_vars.clone()));
        // `less_than_bb_modulus` works on an `Array`, so wrap a copy of the variables.
        let as_array = self.vec(bit_vars.clone());
        self.less_than_bb_modulus(as_array);
        bit_vars
    }
    /// Recomposes a variable from `bits`, treating index `i` as the bit of weight `2^i`.
    ///
    /// Iterates over `bits.len()` entries and returns `sum(bits[i] * 2^i)`. The entries are
    /// not checked to be boolean here.
    pub fn bits2num_v(&mut self, bits: &Array<C, Var<C::N>>) -> Var<C::N> {
        let acc: Var<C::N> = self.eval(C::N::zero());
        // Current power of two, doubled after every consumed bit.
        let weight: Var<C::N> = self.eval(C::N::one());
        let two = C::N::from_canonical_u32(2);
        let bit_count = bits.len();
        self.range(0, bit_count).for_each(|idx, b| {
            let current = b.get(bits, idx);
            b.assign(acc, acc + current * weight);
            b.assign(weight, weight * two);
        });
        acc
    }
    /// Recomposes a variable from a slice of bit variables, treating `bits[i]` as the bit of
    /// weight `2^i`.
    ///
    /// Returns a variable constrained to `sum(bits[i] * 2^i)`. The entries of `bits` are not
    /// checked to be boolean here.
    ///
    /// The weight is tracked as a field element that is doubled on every step instead of
    /// being built from `1 << i`: a `u32` shift by 32 or more panics in debug builds and
    /// uses a wrapped shift amount in release builds, which would silently produce a wrong
    /// weight for slices longer than 32 bits.
    pub fn bits2num_v_circuit(&mut self, bits: &[Var<C::N>]) -> Var<C::N> {
        let result: Var<_> = self.eval(C::N::zero());
        // Invariant: at the start of iteration `i`, `weight == 2^i` in `C::N`.
        let mut weight = C::N::one();
        for &bit in bits {
            self.assign(result, result + bit * weight);
            weight += weight;
        }
        result
    }
    /// Recomposes a field element from `bits`, treating index `i` as the bit of weight `2^i`.
    ///
    /// Exactly `NUM_BITS` entries are read, independent of `bits.len()`. For every entry equal
    /// to one, `2^i` is added to the result; entries are not checked to be boolean here.
    pub fn bits2num_f(&mut self, bits: &Array<C, Var<C::N>>) -> Felt<C::F> {
        let acc: Felt<C::F> = self.eval(C::F::zero());
        for pos in 0..NUM_BITS {
            let weight = C::F::from_canonical_u32(1 << pos);
            let current = self.get(bits, pos);
            // The bit is a `Var` and the accumulator a `Felt`, so add the weight under a branch.
            self.if_eq(current, C::N::one()).then(|b| {
                b.assign(acc, acc + weight);
            });
        }
        acc
    }
    /// Reverses the lowest `bit_len` entries of `index_bits`.
    ///
    /// Returns a new array of length `NUM_BITS` in which entry `i` (for `i < bit_len`) is
    /// `index_bits[bit_len - 1 - i]` and every entry from `bit_len` up to `NUM_BITS` is zero.
    pub fn reverse_bits_len(
        &mut self,
        index_bits: &Array<C, Var<C::N>>,
        bit_len: impl Into<Usize<C::N>>,
    ) -> Array<C, Var<C::N>> {
        let bit_len: Usize<C::N> = bit_len.into();
        let mut reversed = self.dyn_array::<Var<_>>(NUM_BITS);

        // Mirror the low `bit_len` bits: destination `dst` reads from `bit_len - 1 - dst`.
        self.range(0, bit_len).for_each(|dst, b| {
            let src: Var<C::N> = b.eval(bit_len - dst - C::N::one());
            let moved = b.get(index_bits, src);
            b.set_value(&mut reversed, dst, moved);
        });

        // Fill the remaining high positions with zero.
        let padding: Var<C::N> = self.eval(C::N::zero());
        self.range(bit_len, NUM_BITS).for_each(|dst, b| {
            b.set_value(&mut reversed, dst, padding);
        });

        reversed
    }
    /// Returns the first `bit_len` entries of `index_bits` in reversed order.
    ///
    /// Unlike `reverse_bits_len`, the result has exactly `bit_len` entries and no zero padding.
    ///
    /// # Panics
    ///
    /// Panics if `bit_len > NUM_BITS`, or if `index_bits` holds fewer than `bit_len` entries.
    pub fn reverse_bits_len_circuit(
        &mut self,
        index_bits: Vec<Var<C::N>>,
        bit_len: usize,
    ) -> Vec<Var<C::N>> {
        assert!(bit_len <= NUM_BITS);
        // Walk the source indices from `bit_len - 1` down to `0`.
        (0..bit_len).rev().map(|src| index_bits[src]).collect()
    }
    /// Constrains a little-endian bit array so that, whenever its four highest bits
    /// (indices `NUM_BITS - 4 .. NUM_BITS`) are all one, every lower bit is zero.
    ///
    /// NOTE(review): presumably this encodes "value < BabyBear modulus", with the modulus
    /// being `2^31 - 2^27 + 1` and `NUM_BITS == 31` — neither is visible in this file, confirm.
    /// The check also relies on every entry of `num_bits` being 0 or 1; that is not enforced
    /// here and must be guaranteed by the caller.
    fn less_than_bb_modulus(&mut self, num_bits: Array<C, Var<C::N>>) {
        let one: Var<_> = self.eval(C::N::one());
        let zero: Var<_> = self.eval(C::N::zero());
        // Product of the four highest bits: equals one only when all of them are one.
        let mut most_sig_4_bits = one;
        for i in (NUM_BITS - 4)..NUM_BITS {
            let bit = self.get(&num_bits, i);
            most_sig_4_bits = self.eval(bit * most_sig_4_bits);
        }
        // Sum of the remaining low bits: for boolean bits it is zero only when all are zero.
        let mut sum_least_sig_bits = zero;
        for i in 0..(NUM_BITS - 4) {
            let bit = self.get(&num_bits, i);
            sum_least_sig_bits = self.eval(bit + sum_least_sig_bits);
        }
        // Written as a selection between two branches: the low-bit sum when the top four bits
        // are all set, `zero` otherwise. The second term is multiplied by `zero`, so it
        // contributes nothing to the value.
        let check: Var<_> =
            self.eval(most_sig_4_bits * sum_least_sig_bits + (one - most_sig_4_bits) * zero);
        self.assert_var_eq(check, zero);
    }
}