use crate::{
    cfg_chunks_mut,
    cfg_into_iter,
    cfg_iter,
    cfg_iter_mut,
    fft::{DomainCoeff, SparsePolynomial},
};
use snarkvm_fields::{batch_inversion, FftField, FftParameters, Field};
#[cfg(not(feature = "serial"))]
use snarkvm_utilities::max_available_threads;
use snarkvm_utilities::{execute_with_max_available_threads, serialize::*};
use rand::Rng;
use std::{borrow::Cow, fmt};
#[cfg(not(feature = "serial"))]
use rayon::prelude::*;
#[cfg(feature = "serial")]
use itertools::Itertools;
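/// Returns the ceiling of the base-2 logarithm of `x`, or 0 if `x == 0`.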
pub fn log2(x: usize) -> u32 {
    if x == 0 {
        0
    } else if x.is_power_of_two() {
        1usize.leading_zeros() - x.leading_zeros()
    } else {
        0usize.leading_zeros() - x.leading_zeros()
    }
}
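/// Below this size, powers are computed serially (see `compute_powers`).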
#[allow(unused)]
#[cfg(not(feature = "serial"))]
const MIN_PARALLEL_CHUNK_SIZE: usize = 1 << 7;
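/// Defines a domain over which finite-field (I)FFTs can be performed. Works only
/// for fields that have a multiplicative subgroup whose size is a power of two
/// no larger than `2^TWO_ADICITY`.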
#[derive(Copy, Clone, Hash, Eq, PartialEq, CanonicalSerialize, CanonicalDeserialize)]
pub struct EvaluationDomain<F: FftField> {
    /// The size of the domain.
    pub size: u64,
    /// `log_2(self.size)`.
    pub log_size_of_group: u32,
    /// The size of the domain as a field element.
    pub size_as_field_element: F,
    /// The inverse of the size of the domain in the field.
    pub size_inv: F,
    /// A generator of the multiplicative subgroup.
    pub group_gen: F,
    /// The inverse of the subgroup generator.
    pub group_gen_inv: F,
    /// The inverse of the field's multiplicative generator.
    pub generator_inv: F,
}
impl<F: FftField> fmt::Debug for EvaluationDomain<F> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "Multiplicative subgroup of size {}", self.size)
    }
}
impl<F: FftField> EvaluationDomain<F> {
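    /// Samples, by rejection, a field element that does not lie in the domain.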
    pub fn sample_element_outside_domain<R: Rng>(&self, rng: &mut R) -> F {
        let mut t = F::rand(rng);
        while self.evaluate_vanishing_polynomial(t).is_zero() {
            t = F::rand(rng);
        }
        t
    }
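    /// Constructs a domain large enough for evaluations of a polynomial with
    /// `num_coeffs` coefficients, i.e. of size `num_coeffs` rounded up to the
    /// next power of two. Returns `None` if the required two-adicity exceeds
    /// that of `F`.
    ///
    /// A minimal round-trip sketch (not compiled as a doc-test), assuming the
    /// `bls12_377::Fr` field from `snarkvm_curves` is available:
    /// ```ignore
    /// use snarkvm_curves::bls12_377::Fr;
    /// use snarkvm_fields::One;
    ///
    /// let domain = EvaluationDomain::<Fr>::new(4).unwrap();
    /// let coeffs = vec![Fr::one(); 3];
    /// // `fft` zero-pads to the domain size; `ifft` returns `domain.size()` coefficients.
    /// let evals = domain.fft(&coeffs);
    /// assert_eq!(&domain.ifft(&evals)[..3], &coeffs[..]);
    /// ```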
    pub fn new(num_coeffs: usize) -> Option<Self> {
        let size = num_coeffs.checked_next_power_of_two()? as u64;
        let log_size_of_group = size.trailing_zeros();
        if log_size_of_group > F::FftParameters::TWO_ADICITY {
            return None;
        }
        let group_gen = F::get_root_of_unity(size as usize)?;
        debug_assert_eq!(group_gen.pow([size]), F::one());
        let size_as_field_element = F::from(size);
        let size_inv = size_as_field_element.inverse()?;
        Some(EvaluationDomain {
            size,
            log_size_of_group,
            size_as_field_element,
            size_inv,
            group_gen,
            group_gen_inv: group_gen.inverse()?,
            generator_inv: F::multiplicative_generator().inverse()?,
        })
    }
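    /// Returns the size of the domain that `Self::new(num_coeffs)` would
    /// construct, without constructing it.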
    pub fn compute_size_of_domain(num_coeffs: usize) -> Option<usize> {
        let size = num_coeffs.checked_next_power_of_two()?;
        if size.trailing_zeros() <= F::FftParameters::TWO_ADICITY { Some(size) } else { None }
    }
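    /// Returns the size of the domain.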
    pub fn size(&self) -> usize {
        self.size as usize
    }
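    /// Computes the evaluations over the domain of the polynomial with
    /// coefficients `coeffs`.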
    pub fn fft<T: DomainCoeff<F>>(&self, coeffs: &[T]) -> Vec<T> {
        let mut coeffs = coeffs.to_vec();
        self.fft_in_place(&mut coeffs);
        coeffs
    }
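    /// In-place variant of [`Self::fft`]; zero-pads `coeffs` to the domain size.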
    pub fn fft_in_place<T: DomainCoeff<F>>(&self, coeffs: &mut Vec<T>) {
        execute_with_max_available_threads(|| {
            coeffs.resize(self.size(), T::zero());
            self.in_order_fft_in_place(&mut *coeffs);
        });
    }
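    /// Interpolates the coefficients of the polynomial whose evaluations over
    /// the domain are `evals`.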
    pub fn ifft<T: DomainCoeff<F>>(&self, evals: &[T]) -> Vec<T> {
        let mut evals = evals.to_vec();
        self.ifft_in_place(&mut evals);
        evals
    }
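    /// In-place variant of [`Self::ifft`]; zero-pads `evals` to the domain size.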
    #[inline]
    pub fn ifft_in_place<T: DomainCoeff<F>>(&self, evals: &mut Vec<T>) {
        execute_with_max_available_threads(|| {
            evals.resize(self.size(), T::zero());
            self.in_order_ifft_in_place(&mut *evals);
        });
    }
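    /// Computes the evaluations of the polynomial `coeffs` over the coset
    /// `g * H`, where `g` is the field's multiplicative generator and `H` is
    /// this domain.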
    pub fn coset_fft<T: DomainCoeff<F>>(&self, coeffs: &[T]) -> Vec<T> {
        let mut coeffs = coeffs.to_vec();
        self.coset_fft_in_place(&mut coeffs);
        coeffs
    }
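    /// In-place variant of [`Self::coset_fft`].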
    pub fn coset_fft_in_place<T: DomainCoeff<F>>(&self, coeffs: &mut Vec<T>) {
        execute_with_max_available_threads(|| {
            Self::distribute_powers(coeffs, F::multiplicative_generator());
            self.fft_in_place(coeffs);
        });
    }
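    /// Interpolates the polynomial whose evaluations over the coset `g * H`
    /// are `evals`.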
    pub fn coset_ifft<T: DomainCoeff<F>>(&self, evals: &[T]) -> Vec<T> {
        let mut evals = evals.to_vec();
        self.coset_ifft_in_place(&mut evals);
        evals
    }
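    /// In-place variant of [`Self::coset_ifft`]; zero-pads `evals` to the
    /// domain size.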
    pub fn coset_ifft_in_place<T: DomainCoeff<F>>(&self, evals: &mut Vec<T>) {
        execute_with_max_available_threads(|| {
            evals.resize(self.size(), T::zero());
            self.in_order_coset_ifft_in_place(&mut *evals);
        });
    }
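    /// Multiplies the `i`-th element of `coeffs` by `g^i` (a coset shift).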
    fn distribute_powers<T: DomainCoeff<F>>(coeffs: &mut [T], g: F) {
        Self::distribute_powers_and_mul_by_const(coeffs, g, F::one());
    }
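    /// Multiplies the `i`-th element of `coeffs` by `c * g^i`, serially.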
    #[cfg(feature = "serial")]
    fn distribute_powers_and_mul_by_const<T: DomainCoeff<F>>(coeffs: &mut [T], g: F, c: F) {
        let mut pow = c;
        coeffs.iter_mut().for_each(|coeff| {
            *coeff *= pow;
            pow *= &g
        })
    }
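    /// Multiplies the `i`-th element of `coeffs` by `c * g^i`, chunking the
    /// work so that each thread starts from its own offset `c * g^{chunk_start}`.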
    #[cfg(not(feature = "serial"))]
    fn distribute_powers_and_mul_by_const<T: DomainCoeff<F>>(coeffs: &mut [T], g: F, c: F) {
        let min_parallel_chunk_size = 1024;
        let num_cpus_available = max_available_threads();
        let num_elem_per_thread = core::cmp::max(coeffs.len() / num_cpus_available, min_parallel_chunk_size);
        cfg_chunks_mut!(coeffs, num_elem_per_thread).enumerate().for_each(|(i, chunk)| {
            let offset = c * g.pow([(i * num_elem_per_thread) as u64]);
            let mut pow = offset;
            chunk.iter_mut().for_each(|coeff| {
                *coeff *= pow;
                pow *= &g
            })
        });
    }
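    /// Evaluates all `size` Lagrange coefficients of the domain at `tau`, i.e.
    /// `L_i(tau) = Z_H(tau) * g^i / (size * (tau - g^i))`, amortized with a
    /// single batch inversion. If `tau` lies in the domain, the result is the
    /// corresponding unit vector.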
    pub fn evaluate_all_lagrange_coefficients(&self, tau: F) -> Vec<F> {
        let size = self.size as usize;
        let t_size = tau.pow([self.size]);
        let one = F::one();
        if t_size.is_one() {
            let mut u = vec![F::zero(); size];
            let mut omega_i = one;
            for x in u.iter_mut().take(size) {
                if omega_i == tau {
                    *x = one;
                    break;
                }
                omega_i *= &self.group_gen;
            }
            u
        } else {
            let mut l = (t_size - one) * self.size_inv;
            let mut r = one;
            let mut u = vec![F::zero(); size];
            let mut ls = vec![F::zero(); size];
            for i in 0..size {
                u[i] = tau - r;
                ls[i] = l;
                l *= &self.group_gen;
                r *= &self.group_gen;
            }
            batch_inversion(u.as_mut_slice());
            cfg_iter_mut!(u).zip_eq(ls).for_each(|(tau_minus_r, l)| {
                *tau_minus_r = l * *tau_minus_r;
            });
            u
        }
    }
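    /// Returns the vanishing polynomial of the domain,
    /// `Z_H(X) = X^size - 1`, as a sparse polynomial.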
    pub fn vanishing_polynomial(&self) -> SparsePolynomial<F> {
        let coeffs = [(0, -F::one()), (self.size(), F::one())];
        SparsePolynomial::from_coefficients(coeffs)
    }
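    /// Evaluates the vanishing polynomial at `tau`, i.e. `tau^size - 1`.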
    pub fn evaluate_vanishing_polynomial(&self, tau: F) -> F {
        tau.pow([self.size]) - F::one()
    }
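    /// Returns an iterator over the elements `1, g, g^2, ...` of the domain.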
    pub fn elements(&self) -> Elements<F> {
        Elements { cur_elem: F::one(), cur_pow: 0, domain: *self }
    }
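    /// Divides evaluations over the coset `g * H` by the vanishing polynomial
    /// `Z_H`. Since `Z_H(g * x) = g^size - 1` is constant on the coset, this is
    /// a single scalar multiplication by its inverse.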
    pub fn divide_by_vanishing_poly_on_coset_in_place(&self, evals: &mut [F]) {
        let i = self.evaluate_vanishing_polynomial(F::multiplicative_generator()).inverse().unwrap();
        cfg_iter_mut!(evals).for_each(|eval| *eval *= &i);
    }
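    /// Given an `index` into a list that places the elements of the subdomain
    /// `other` first, followed by the remaining elements of `self`, returns the
    /// corresponding index in `self`'s natural ordering. Elements of `other`
    /// occur in `self` at multiples of `period = self.size() / other.size()`;
    /// e.g. for `self.size() == 8` and `other.size() == 2`, indices `0, 1` map
    /// to `0, 4`, and indices `2..8` map to `1, 2, 3, 5, 6, 7`.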
    pub fn reindex_by_subdomain(&self, other: &Self, index: usize) -> usize {
        assert!(self.size() >= other.size());
        let period = self.size() / other.size();
        if index < other.size() {
            index * period
        } else {
            let i = index - other.size();
            let x = period - 1;
            i + (i / x) + 1
        }
    }
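    /// Multiplies two polynomials, given by their evaluations over this domain,
    /// pointwise. Panics if the two evaluation vectors differ in length.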
    #[must_use]
    pub fn mul_polynomials_in_evaluation_domain(&self, self_evals: Vec<F>, other_evals: &[F]) -> Vec<F> {
        let mut result = self_evals;
        cfg_iter_mut!(result).zip_eq(other_evals).for_each(|(a, b)| *a *= b);
        result
    }
}
impl<F: FftField> EvaluationDomain<F> {
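    /// Precomputes the first `size / 2` roots of unity so that repeated FFTs
    /// over this domain can reuse them. A usage sketch, again assuming the
    /// `bls12_377::Fr` field:
    /// ```ignore
    /// let domain = EvaluationDomain::<Fr>::new(8).unwrap();
    /// let pc = domain.precompute_fft();
    /// let coeffs = vec![Fr::one(); 8];
    /// let evals = domain.in_order_fft_with_pc(&coeffs, &pc);
    /// ```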
    pub fn precompute_fft(&self) -> FFTPrecomputation<F> {
        execute_with_max_available_threads(|| FFTPrecomputation {
            roots: self.roots_of_unity(self.group_gen),
            domain: *self,
        })
    }
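    /// Precomputes the first `size / 2` inverse roots of unity for IFFTs over
    /// this domain.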
    pub fn precompute_ifft(&self) -> IFFTPrecomputation<F> {
        execute_with_max_available_threads(|| IFFTPrecomputation {
            inverse_roots: self.roots_of_unity(self.group_gen_inv),
            domain: *self,
        })
    }
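    /// Performs an in-order FFT in place. When the `cuda` feature is enabled on
    /// `x86_64` and `T` is a 32-byte field element, the NTT is first attempted
    /// on the GPU, falling back to the CPU path on error; the other
    /// `in_order_*` routines below follow the same pattern.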
    pub(crate) fn in_order_fft_in_place<T: DomainCoeff<F>>(&self, x_s: &mut [T]) {
        #[cfg(all(feature = "cuda", target_arch = "x86_64"))]
        if self.size >= 32 && std::mem::size_of::<T>() == 32 {
            let result = snarkvm_algorithms_cuda::NTT(
                self.size as usize,
                x_s,
                snarkvm_algorithms_cuda::NTTInputOutputOrder::NN,
                snarkvm_algorithms_cuda::NTTDirection::Forward,
                snarkvm_algorithms_cuda::NTTType::Standard,
            );
            if result.is_ok() {
                return;
            }
        }
        let pc = self.precompute_fft();
        self.fft_helper_in_place_with_pc(x_s, FFTOrder::II, &pc)
    }
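    /// Computes an in-order FFT of `x_s` (zero-padded to the domain size),
    /// reusing a previously computed [`FFTPrecomputation`].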
    pub fn in_order_fft_with_pc<T: DomainCoeff<F>>(&self, x_s: &[T], pc: &FFTPrecomputation<F>) -> Vec<T> {
        let mut x_s = x_s.to_vec();
        if self.size() != x_s.len() {
            x_s.extend(core::iter::repeat(T::zero()).take(self.size() - x_s.len()));
        }
        self.fft_helper_in_place_with_pc(&mut x_s, FFTOrder::II, pc);
        x_s
    }
    pub(crate) fn in_order_ifft_in_place<T: DomainCoeff<F>>(&self, x_s: &mut [T]) {
        #[cfg(all(feature = "cuda", target_arch = "x86_64"))]
        if self.size >= 32 && std::mem::size_of::<T>() == 32 {
            let result = snarkvm_algorithms_cuda::NTT(
                self.size as usize,
                x_s,
                snarkvm_algorithms_cuda::NTTInputOutputOrder::NN,
                snarkvm_algorithms_cuda::NTTDirection::Inverse,
                snarkvm_algorithms_cuda::NTTType::Standard,
            );
            if result.is_ok() {
                return;
            }
        }
        let pc = self.precompute_ifft();
        self.ifft_helper_in_place_with_pc(x_s, FFTOrder::II, &pc);
        cfg_iter_mut!(x_s).for_each(|val| *val *= self.size_inv);
    }
    pub(crate) fn in_order_coset_ifft_in_place<T: DomainCoeff<F>>(&self, x_s: &mut [T]) {
        #[cfg(all(feature = "cuda", target_arch = "x86_64"))]
        if self.size >= 32 && std::mem::size_of::<T>() == 32 {
            let result = snarkvm_algorithms_cuda::NTT(
                self.size as usize,
                x_s,
                snarkvm_algorithms_cuda::NTTInputOutputOrder::NN,
                snarkvm_algorithms_cuda::NTTDirection::Inverse,
                snarkvm_algorithms_cuda::NTTType::Coset,
            );
            if result.is_ok() {
                return;
            }
        }
        let pc = self.precompute_ifft();
        self.ifft_helper_in_place_with_pc(x_s, FFTOrder::II, &pc);
        let coset_shift = self.generator_inv;
        Self::distribute_powers_and_mul_by_const(x_s, coset_shift, self.size_inv);
    }
    #[allow(unused)]
    pub(crate) fn in_order_fft_in_place_with_pc<T: DomainCoeff<F>>(
        &self,
        x_s: &mut [T],
        pre_comp: &FFTPrecomputation<F>,
    ) {
        #[cfg(all(feature = "cuda", target_arch = "x86_64"))]
        if self.size >= 32 && std::mem::size_of::<T>() == 32 {
            let result = snarkvm_algorithms_cuda::NTT(
                self.size as usize,
                x_s,
                snarkvm_algorithms_cuda::NTTInputOutputOrder::NN,
                snarkvm_algorithms_cuda::NTTDirection::Forward,
                snarkvm_algorithms_cuda::NTTType::Standard,
            );
            if result.is_ok() {
                return;
            }
        }
        self.fft_helper_in_place_with_pc(x_s, FFTOrder::II, pre_comp)
    }
    pub(crate) fn out_order_fft_in_place_with_pc<T: DomainCoeff<F>>(
        &self,
        x_s: &mut [T],
        pre_comp: &FFTPrecomputation<F>,
    ) {
        self.fft_helper_in_place_with_pc(x_s, FFTOrder::IO, pre_comp)
    }
    pub(crate) fn in_order_ifft_in_place_with_pc<T: DomainCoeff<F>>(
        &self,
        x_s: &mut [T],
        pre_comp: &IFFTPrecomputation<F>,
    ) {
        #[cfg(all(feature = "cuda", target_arch = "x86_64"))]
        if self.size >= 32 && std::mem::size_of::<T>() == 32 {
            let result = snarkvm_algorithms_cuda::NTT(
                self.size as usize,
                x_s,
                snarkvm_algorithms_cuda::NTTInputOutputOrder::NN,
                snarkvm_algorithms_cuda::NTTDirection::Inverse,
                snarkvm_algorithms_cuda::NTTType::Standard,
            );
            if result.is_ok() {
                return;
            }
        }
        self.ifft_helper_in_place_with_pc(x_s, FFTOrder::II, pre_comp);
        cfg_iter_mut!(x_s).for_each(|val| *val *= self.size_inv);
    }
    pub(crate) fn out_order_ifft_in_place_with_pc<T: DomainCoeff<F>>(
        &self,
        x_s: &mut [T],
        pre_comp: &IFFTPrecomputation<F>,
    ) {
        self.ifft_helper_in_place_with_pc(x_s, FFTOrder::OI, pre_comp);
        cfg_iter_mut!(x_s).for_each(|val| *val *= self.size_inv);
    }
    #[allow(unused)]
    pub(crate) fn in_order_coset_ifft_in_place_with_pc<T: DomainCoeff<F>>(
        &self,
        x_s: &mut [T],
        pre_comp: &IFFTPrecomputation<F>,
    ) {
        #[cfg(all(feature = "cuda", target_arch = "x86_64"))]
        if self.size >= 32 && std::mem::size_of::<T>() == 32 {
            let result = snarkvm_algorithms_cuda::NTT(
                self.size as usize,
                x_s,
                snarkvm_algorithms_cuda::NTTInputOutputOrder::NN,
                snarkvm_algorithms_cuda::NTTDirection::Inverse,
                snarkvm_algorithms_cuda::NTTType::Coset,
            );
            if result.is_ok() {
                return;
            }
        }
        self.ifft_helper_in_place_with_pc(x_s, FFTOrder::II, pre_comp);
        let coset_shift = self.generator_inv;
        Self::distribute_powers_and_mul_by_const(x_s, coset_shift, self.size_inv);
    }
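    /// Dispatches a forward FFT according to `ord`: `OI` runs the
    /// out-of-order-input passes, otherwise the in-order-input passes run, and
    /// for `II` the bit-reversed output is deranged back into natural order.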
    fn fft_helper_in_place_with_pc<T: DomainCoeff<F>>(
        &self,
        x_s: &mut [T],
        ord: FFTOrder,
        pre_comp: &FFTPrecomputation<F>,
    ) {
        use FFTOrder::*;
        let pc = pre_comp.precomputation_for_subdomain(self).unwrap();
        let log_len = log2(x_s.len());
        if ord == OI {
            self.oi_helper_with_roots(x_s, &pc.roots);
        } else {
            self.io_helper_with_roots(x_s, &pc.roots);
        }
        if ord == II {
            derange_helper(x_s, log_len);
        }
    }
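    /// Dispatches an inverse FFT according to `ord`: for `II` the input is
    /// first deranged into bit-reversed order, then either the in-order (`IO`)
    /// or out-of-order passes run. Note that this does *not* divide by the
    /// domain size; callers multiply by `size_inv` afterwards.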
    fn ifft_helper_in_place_with_pc<T: DomainCoeff<F>>(
        &self,
        x_s: &mut [T],
        ord: FFTOrder,
        pre_comp: &IFFTPrecomputation<F>,
    ) {
        use FFTOrder::*;
        let pc = pre_comp.precomputation_for_subdomain(self).unwrap();
        let log_len = log2(x_s.len());
        if ord == II {
            derange_helper(x_s, log_len);
        }
        if ord == IO {
            self.io_helper_with_roots(x_s, &pc.inverse_roots);
        } else {
            self.oi_helper_with_roots(x_s, &pc.inverse_roots);
        }
    }
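    /// Computes the first `size / 2` powers of `root`: for the domain
    /// `[1, g, g^2, ..., g^{n-1}]` and `root = g`, this returns
    /// `[1, g, ..., g^{n/2 - 1}]`.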
    #[cfg(feature = "serial")]
    pub fn roots_of_unity(&self, root: F) -> Vec<F> {
        compute_powers_serial((self.size as usize) / 2, root)
    }
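    /// Parallel variant of `roots_of_unity`: small domains are handled
    /// serially, larger ones by a recursive computation over the repeated
    /// squarings of `root`.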
    #[cfg(not(feature = "serial"))]
    pub fn roots_of_unity(&self, root: F) -> Vec<F> {
        let log_size = log2(self.size as usize);
        if log_size <= LOG_ROOTS_OF_UNITY_PARALLEL_SIZE {
            compute_powers_serial((self.size as usize) / 2, root)
        } else {
            let mut temp = root;
            let log_powers: Vec<F> = (0..(log_size - 1))
                .map(|_| {
                    let old_value = temp;
                    temp.square_in_place();
                    old_value
                })
                .collect();
            let mut powers = vec![F::zero(); 1 << (log_size - 1)];
            Self::roots_of_unity_recursive(&mut powers, &log_powers);
            powers
        }
    }
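    /// Fills `out` with `out[i] = root^i`, where
    /// `log_powers = [root, root^2, root^4, ...]`: the low and high halves of
    /// the exponent bits are computed recursively in parallel and combined
    /// with an outer product.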
    #[cfg(not(feature = "serial"))]
    fn roots_of_unity_recursive(out: &mut [F], log_powers: &[F]) {
        assert_eq!(out.len(), 1 << log_powers.len());
        if log_powers.len() <= LOG_ROOTS_OF_UNITY_PARALLEL_SIZE as usize {
            out[0] = F::one();
            for idx in 1..out.len() {
                out[idx] = out[idx - 1] * log_powers[0];
            }
            return;
        }
        let (lr_lo, lr_hi) = log_powers.split_at((1 + log_powers.len()) / 2);
        let mut scr_lo = vec![F::default(); 1 << lr_lo.len()];
        let mut scr_hi = vec![F::default(); 1 << lr_hi.len()];
        rayon::join(
            || Self::roots_of_unity_recursive(&mut scr_lo, lr_lo),
            || Self::roots_of_unity_recursive(&mut scr_hi, lr_hi),
        );
        out.par_chunks_mut(scr_lo.len()).zip(&scr_hi).for_each(|(out_chunk, scr_hi)| {
            for (out_elem, scr_lo) in out_chunk.iter_mut().zip(&scr_lo) {
                *out_elem = *scr_hi * scr_lo;
            }
        });
    }
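    /// Decimation-in-frequency butterfly:
    /// `(lo, hi) <- (lo + hi, (lo - hi) * root)`.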
    #[inline(always)]
    fn butterfly_fn_io<T: DomainCoeff<F>>(((lo, hi), root): ((&mut T, &mut T), &F)) {
        let neg = *lo - *hi;
        *lo += *hi;
        *hi = neg;
        *hi *= *root;
    }
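    /// Decimation-in-time butterfly:
    /// `(lo, hi) <- (lo + hi * root, lo - hi * root)`.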
    #[inline(always)]
    fn butterfly_fn_oi<T: DomainCoeff<F>>(((lo, hi), root): ((&mut T, &mut T), &F)) {
        *hi *= *root;
        let neg = *lo - *hi;
        *lo += *hi;
        *hi = neg;
    }
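    /// Applies the butterfly `g` over `xi` in chunks of `chunk_size = 2 * gap`,
    /// pairing each element of a chunk's lower half with the element `gap`
    /// positions later and with every `step`-th root. The inner loop is
    /// parallelized only when the gap is large and there are fewer chunks than
    /// available threads.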
    #[allow(clippy::too_many_arguments)]
    fn apply_butterfly<T: DomainCoeff<F>, G: Fn(((&mut T, &mut T), &F)) + Copy + Sync + Send>(
        g: G,
        xi: &mut [T],
        roots: &[F],
        step: usize,
        chunk_size: usize,
        num_chunks: usize,
        max_threads: usize,
        gap: usize,
    ) {
        cfg_chunks_mut!(xi, chunk_size).for_each(|cxi| {
            let (lo, hi) = cxi.split_at_mut(gap);
            if gap > MIN_GAP_SIZE_FOR_PARALLELISATION && num_chunks < max_threads {
                cfg_iter_mut!(lo).zip(hi).zip(cfg_iter!(roots).step_by(step)).for_each(g);
            } else {
                lo.iter_mut().zip(hi).zip(roots.iter().step_by(step)).for_each(g);
            }
        });
    }
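    /// In-order-input FFT passes (the gap halves each round), producing
    /// bit-reversed output. Once the number of chunks is large enough, the
    /// root table is compacted (keeping every `2 * step`-th entry) to improve
    /// cache locality.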
    #[allow(clippy::unnecessary_to_owned)]
    fn io_helper_with_roots<T: DomainCoeff<F>>(&self, xi: &mut [T], roots: &[F]) {
        let mut roots = std::borrow::Cow::Borrowed(roots);
        let mut step = 1;
        let mut first = true;
        #[cfg(not(feature = "serial"))]
        let max_threads = snarkvm_utilities::parallel::max_available_threads();
        #[cfg(feature = "serial")]
        let max_threads = 1;
        let mut gap = xi.len() / 2;
        while gap > 0 {
            let chunk_size = 2 * gap;
            let num_chunks = xi.len() / chunk_size;
            if num_chunks >= MIN_NUM_CHUNKS_FOR_COMPACTION {
                if !first {
                    roots = Cow::Owned(cfg_into_iter!(roots.into_owned()).step_by(step * 2).collect());
                }
                step = 1;
                roots.to_mut().shrink_to_fit();
            } else {
                step = num_chunks;
            }
            first = false;
            Self::apply_butterfly(
                Self::butterfly_fn_io,
                xi,
                &roots[..],
                step,
                chunk_size,
                num_chunks,
                max_threads,
                gap,
            );
            gap /= 2;
        }
    }
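    /// Out-of-order-input FFT passes (the gap doubles each round): expects
    /// bit-reversed input and produces in-order output. Roots are copied into
    /// a compact scratch buffer when there are many chunks.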
    fn oi_helper_with_roots<T: DomainCoeff<F>>(&self, xi: &mut [T], roots_cache: &[F]) {
        let compaction_max_size =
            core::cmp::min(roots_cache.len() / 2, roots_cache.len() / MIN_NUM_CHUNKS_FOR_COMPACTION);
        let mut compacted_roots = vec![F::default(); compaction_max_size];
        #[cfg(not(feature = "serial"))]
        let max_threads = snarkvm_utilities::parallel::max_available_threads();
        #[cfg(feature = "serial")]
        let max_threads = 1;
        let mut gap = 1;
        while gap < xi.len() {
            let chunk_size = 2 * gap;
            let num_chunks = xi.len() / chunk_size;
            let (roots, step) = if num_chunks >= MIN_NUM_CHUNKS_FOR_COMPACTION && gap < xi.len() / 2 {
                cfg_iter_mut!(compacted_roots[..gap])
                    .zip(cfg_iter!(roots_cache[..(gap * num_chunks)]).step_by(num_chunks))
                    .for_each(|(a, b)| *a = *b);
                (&compacted_roots[..gap], 1)
            } else {
                (roots_cache, num_chunks)
            };
            Self::apply_butterfly(Self::butterfly_fn_oi, xi, roots, step, chunk_size, num_chunks, max_threads, gap);
            gap *= 2;
        }
    }
}
/// The minimum number of chunks at which root compaction is beneficial.
const MIN_NUM_CHUNKS_FOR_COMPACTION: usize = 1 << 7;
/// The minimum size of a chunk at which parallelization of butterflies is beneficial.
const MIN_GAP_SIZE_FOR_PARALLELISATION: usize = 1 << 10;
/// The log of the domain size below which roots of unity are computed serially.
#[cfg(not(feature = "serial"))]
const LOG_ROOTS_OF_UNITY_PARALLEL_SIZE: u32 = 7;
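/// Reverses the lowest `log_len` bits of `a`.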
#[inline]
pub(super) fn bitrev(a: u64, log_len: u32) -> u64 {
    a.reverse_bits() >> (64 - log_len)
}
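/// Swaps the elements of `xi` into bit-reversed-index order.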
pub(crate) fn derange<T>(xi: &mut [T]) {
    derange_helper(xi, log2(xi.len()))
}
fn derange_helper<T>(xi: &mut [T], log_len: u32) {
    for idx in 1..(xi.len() as u64 - 1) {
        let ridx = bitrev(idx, log_len);
        if idx < ridx {
            xi.swap(idx as usize, ridx as usize);
        }
    }
}
/// The order of the input and output of an FFT pass.
#[derive(PartialEq, Eq, Debug)]
enum FFTOrder {
    /// Both the input and the output of the FFT must be in order.
    II,
    /// The input of the FFT must be in order, but the output does not have to be.
    IO,
    /// The input of the FFT can be out of order, but the output must be in order.
    OI,
}
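/// Serially computes `[1, root, root^2, ..., root^{size - 1}]`.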
pub(crate) fn compute_powers_serial<F: Field>(size: usize, root: F) -> Vec<F> {
    compute_powers_and_mul_by_const_serial(size, root, F::one())
}
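/// Serially computes `[c, c * root, c * root^2, ..., c * root^{size - 1}]`.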
pub(crate) fn compute_powers_and_mul_by_const_serial<F: Field>(size: usize, root: F, c: F) -> Vec<F> {
    let mut value = c;
    (0..size)
        .map(|_| {
            let old_value = value;
            value *= root;
            old_value
        })
        .collect()
}
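/// Computes `[1, g, g^2, ..., g^{size - 1}]` in parallel, with each thread
/// starting from its own offset `g^{chunk_start}`.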
#[allow(unused)]
#[cfg(not(feature = "serial"))]
pub(crate) fn compute_powers<F: Field>(size: usize, g: F) -> Vec<F> {
    if size < MIN_PARALLEL_CHUNK_SIZE {
        return compute_powers_serial(size, g);
    }
    let num_cpus_available = max_available_threads();
    let num_elem_per_thread = core::cmp::max(size / num_cpus_available, MIN_PARALLEL_CHUNK_SIZE);
    let num_cpus_used = size / num_elem_per_thread;
    let res: Vec<F> = (0..num_cpus_used)
        .into_par_iter()
        .flat_map(|i| {
            let offset = g.pow([(i * num_elem_per_thread) as u64]);
            let num_elements_to_compute = core::cmp::min(size - i * num_elem_per_thread, num_elem_per_thread);
            compute_powers_and_mul_by_const_serial(num_elements_to_compute, g, offset)
        })
        .collect();
    res
}
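/// An iterator over the elements `1, g, g^2, ...` of an `EvaluationDomain`.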
#[derive(Clone)]
pub struct Elements<F: FftField> {
    cur_elem: F,
    cur_pow: u64,
    domain: EvaluationDomain<F>,
}
impl<F: FftField> Iterator for Elements<F> {
    type Item = F;
    fn next(&mut self) -> Option<F> {
        if self.cur_pow == self.domain.size {
            None
        } else {
            let cur_elem = self.cur_elem;
            self.cur_elem *= &self.domain.group_gen;
            self.cur_pow += 1;
            Some(cur_elem)
        }
    }
}
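/// A pre-computed table of the first `size / 2` roots of unity of a domain,
/// reusable across FFTs.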
#[derive(Clone, Eq, PartialEq, Debug, CanonicalDeserialize, CanonicalSerialize)]
pub struct FFTPrecomputation<F: FftField> {
    roots: Vec<F>,
    domain: EvaluationDomain<F>,
}
impl<F: FftField> FFTPrecomputation<F> {
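    /// Converts this FFT precomputation into an IFFT precomputation by
    /// batch-inverting the roots.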
    pub fn to_ifft_precomputation(&self) -> IFFTPrecomputation<F> {
        let mut inverse_roots = self.roots.clone();
        snarkvm_fields::batch_inversion(&mut inverse_roots);
        IFFTPrecomputation { inverse_roots, domain: self.domain }
    }
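    /// Restricts this precomputation to `domain`: borrowed if the domains
    /// match, every `size_ratio`-th root for a strictly smaller domain, and
    /// `None` if `domain` is larger than `self.domain`.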
    pub fn precomputation_for_subdomain<'a>(&'a self, domain: &EvaluationDomain<F>) -> Option<Cow<'a, Self>> {
        if domain.size() == 1 {
            return Some(Cow::Owned(Self { roots: vec![], domain: *domain }));
        }
        if &self.domain == domain {
            Some(Cow::Borrowed(self))
        } else if domain.size() < self.domain.size() {
            let size_ratio = self.domain.size() / domain.size();
            let roots = self.roots.iter().step_by(size_ratio).copied().collect();
            Some(Cow::Owned(Self { roots, domain: *domain }))
        } else {
            None
        }
    }
}
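/// A pre-computed table of inverse roots of unity of a domain, reusable
/// across IFFTs.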
#[derive(Clone, Eq, PartialEq, Debug, CanonicalSerialize, CanonicalDeserialize)]
pub struct IFFTPrecomputation<F: FftField> {
    inverse_roots: Vec<F>,
    domain: EvaluationDomain<F>,
}
impl<F: FftField> IFFTPrecomputation<F> {
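    /// Restricts this precomputation to `domain`; see
    /// [`FFTPrecomputation::precomputation_for_subdomain`].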
    pub fn precomputation_for_subdomain<'a>(&'a self, domain: &EvaluationDomain<F>) -> Option<Cow<'a, Self>> {
        if domain.size() == 1 {
            return Some(Cow::Owned(Self { inverse_roots: vec![], domain: *domain }));
        }
        if &self.domain == domain {
            Some(Cow::Borrowed(self))
        } else if domain.size() < self.domain.size() {
            let size_ratio = self.domain.size() / domain.size();
            let inverse_roots = self.inverse_roots.iter().step_by(size_ratio).copied().collect();
            Some(Cow::Owned(Self { inverse_roots, domain: *domain }))
        } else {
            None
        }
    }
}
#[cfg(test)]
mod tests {
    #[cfg(all(feature = "cuda", target_arch = "x86_64"))]
    use crate::fft::domain::FFTOrder;
    use crate::fft::{DensePolynomial, EvaluationDomain};
    use rand::Rng;
    use snarkvm_curves::bls12_377::Fr;
    use snarkvm_fields::{FftField, Field, One, Zero};
    use snarkvm_utilities::{TestRng, Uniform};
    #[test]
    fn vanishing_polynomial_evaluation() {
        let rng = &mut TestRng::default();
        for coeffs in 0..10 {
            let domain = EvaluationDomain::<Fr>::new(coeffs).unwrap();
            let z = domain.vanishing_polynomial();
            for _ in 0..100 {
                let point = rng.gen();
                assert_eq!(z.evaluate(point), domain.evaluate_vanishing_polynomial(point))
            }
        }
    }
    #[test]
    fn vanishing_polynomial_vanishes_on_domain() {
        for coeffs in 0..1000 {
            let domain = EvaluationDomain::<Fr>::new(coeffs).unwrap();
            let z = domain.vanishing_polynomial();
            for point in domain.elements() {
                assert!(z.evaluate(point).is_zero())
            }
        }
    }
    #[test]
    fn size_of_elements() {
        for coeffs in 1..10 {
            let size = 1 << coeffs;
            let domain = EvaluationDomain::<Fr>::new(size).unwrap();
            let domain_size = domain.size();
            assert_eq!(domain_size, domain.elements().count());
        }
    }
    #[test]
    fn elements_contents() {
        for coeffs in 1..10 {
            let size = 1 << coeffs;
            let domain = EvaluationDomain::<Fr>::new(size).unwrap();
            for (i, element) in domain.elements().enumerate() {
                assert_eq!(element, domain.group_gen.pow([i as u64]));
            }
        }
    }
    #[test]
    fn non_systematic_lagrange_coefficients_test() {
        let mut rng = TestRng::default();
        for domain_dimension in 1..10 {
            let domain_size = 1 << domain_dimension;
            let domain = EvaluationDomain::<Fr>::new(domain_size).unwrap();
            let random_point = Fr::rand(&mut rng);
            let lagrange_coefficients = domain.evaluate_all_lagrange_coefficients(random_point);
            let random_polynomial = DensePolynomial::<Fr>::rand(domain_size - 1, &mut rng);
            let polynomial_evaluations = domain.fft(random_polynomial.coeffs());
            let actual_evaluation = random_polynomial.evaluate(random_point);
            let mut interpolated_evaluation = Fr::zero();
            for i in 0..domain_size {
                interpolated_evaluation += lagrange_coefficients[i] * polynomial_evaluations[i];
            }
            assert_eq!(actual_evaluation, interpolated_evaluation);
        }
    }
    #[test]
    fn systematic_lagrange_coefficients_test() {
        for domain_dimension in 1..5 {
            let domain_size = 1 << domain_dimension;
            let domain = EvaluationDomain::<Fr>::new(domain_size).unwrap();
            let all_domain_elements: Vec<Fr> = domain.elements().collect();
            for (i, domain_element) in all_domain_elements.iter().enumerate().take(domain_size) {
                let lagrange_coefficients = domain.evaluate_all_lagrange_coefficients(*domain_element);
                for (j, lagrange_coefficient) in lagrange_coefficients.iter().enumerate().take(domain_size) {
                    if i == j {
                        assert_eq!(*lagrange_coefficient, Fr::one());
                    } else {
                        assert_eq!(*lagrange_coefficient, Fr::zero());
                    }
                }
            }
        }
    }
    #[test]
    fn test_roots_of_unity() {
        let max_degree = 10;
        for log_domain_size in 0..max_degree {
            let domain_size = 1 << log_domain_size;
            let domain = EvaluationDomain::<Fr>::new(domain_size).unwrap();
            let actual_roots = domain.roots_of_unity(domain.group_gen);
            for &value in &actual_roots {
                assert!(domain.evaluate_vanishing_polynomial(value).is_zero());
            }
            let expected_roots_elements = domain.elements();
            for (expected, &actual) in expected_roots_elements.zip(&actual_roots) {
                assert_eq!(expected, actual);
            }
            assert_eq!(actual_roots.len(), domain_size / 2);
        }
    }
    #[test]
    fn test_fft_correctness() {
        let mut rng = TestRng::default();
        let log_degree = 5;
        let degree = 1 << log_degree;
        let random_polynomial = DensePolynomial::<Fr>::rand(degree - 1, &mut rng);
        for log_domain_size in log_degree..(log_degree + 2) {
            let domain_size = 1 << log_domain_size;
            let domain = EvaluationDomain::<Fr>::new(domain_size).unwrap();
            let polynomial_evaluations = domain.fft(&random_polynomial.coeffs);
            let polynomial_coset_evaluations = domain.coset_fft(&random_polynomial.coeffs);
            for (i, x) in domain.elements().enumerate() {
                let coset_x = Fr::multiplicative_generator() * x;
                assert_eq!(polynomial_evaluations[i], random_polynomial.evaluate(x));
                assert_eq!(polynomial_coset_evaluations[i], random_polynomial.evaluate(coset_x));
            }
            let random_polynomial_from_subgroup =
                DensePolynomial::from_coefficients_vec(domain.ifft(&polynomial_evaluations));
            let random_polynomial_from_coset =
                DensePolynomial::from_coefficients_vec(domain.coset_ifft(&polynomial_coset_evaluations));
            assert_eq!(
                random_polynomial, random_polynomial_from_subgroup,
                "degree = {degree}, domain size = {domain_size}"
            );
            assert_eq!(
                random_polynomial, random_polynomial_from_coset,
                "degree = {degree}, domain size = {domain_size}"
            );
        }
    }
    #[test]
    fn test_fft_precomputation() {
        for i in 1..10 {
            let big_domain = EvaluationDomain::<Fr>::new(i).unwrap();
            let pc = big_domain.precompute_fft();
            for j in 1..i {
                let small_domain = EvaluationDomain::<Fr>::new(j).unwrap();
                let small_pc = small_domain.precompute_fft();
                assert_eq!(pc.precomputation_for_subdomain(&small_domain).unwrap().as_ref(), &small_pc);
            }
        }
    }
    #[test]
    fn test_ifft_precomputation() {
        for i in 1..10 {
            let big_domain = EvaluationDomain::<Fr>::new(i).unwrap();
            let pc = big_domain.precompute_ifft();
            for j in 1..i {
                let small_domain = EvaluationDomain::<Fr>::new(j).unwrap();
                let small_pc = small_domain.precompute_ifft();
                assert_eq!(pc.precomputation_for_subdomain(&small_domain).unwrap().as_ref(), &small_pc);
            }
        }
    }
    #[test]
    fn test_ifft_precomputation_from_fft() {
        for i in 1..10 {
            let domain = EvaluationDomain::<Fr>::new(i).unwrap();
            let pc = domain.precompute_ifft();
            let fft_pc = domain.precompute_fft();
            assert_eq!(pc, fft_pc.to_ifft_precomputation())
        }
    }
    #[cfg(all(feature = "cuda", target_arch = "x86_64"))]
    #[test]
    fn test_fft_correctness_cuda() {
        let mut rng = TestRng::default();
        for log_domain in 2..20 {
            println!("Testing domain size {log_domain}");
            let domain_size = 1 << log_domain;
            let random_polynomial = DensePolynomial::<Fr>::rand(domain_size - 1, &mut rng);
            let mut polynomial_evaluations = random_polynomial.coeffs.clone();
            let mut polynomial_evaluations_cuda = random_polynomial.coeffs.clone();
            let domain = EvaluationDomain::<Fr>::new(domain_size).unwrap();
            let pc = domain.precompute_fft();
            domain.fft_helper_in_place_with_pc(&mut polynomial_evaluations, FFTOrder::II, &pc);
            if snarkvm_algorithms_cuda::NTT::<Fr>(
                domain_size,
                &mut polynomial_evaluations_cuda,
                snarkvm_algorithms_cuda::NTTInputOutputOrder::NN,
                snarkvm_algorithms_cuda::NTTDirection::Forward,
                snarkvm_algorithms_cuda::NTTType::Standard,
            )
            .is_err()
            {
                println!("cuda error!");
            }
            assert_eq!(polynomial_evaluations, polynomial_evaluations_cuda, "domain size = {domain_size}");
            if snarkvm_algorithms_cuda::NTT::<Fr>(
                domain_size,
                &mut polynomial_evaluations_cuda,
                snarkvm_algorithms_cuda::NTTInputOutputOrder::NN,
                snarkvm_algorithms_cuda::NTTDirection::Inverse,
                snarkvm_algorithms_cuda::NTTType::Standard,
            )
            .is_err()
            {
                println!("cuda error!");
            }
            assert_eq!(random_polynomial.coeffs, polynomial_evaluations_cuda, "domain size = {domain_size}");
            polynomial_evaluations = random_polynomial.coeffs.clone();
            let domain = EvaluationDomain::<Fr>::new(domain_size).unwrap();
            let pc = domain.precompute_fft();
            EvaluationDomain::<Fr>::distribute_powers(&mut polynomial_evaluations, Fr::multiplicative_generator());
            domain.fft_helper_in_place_with_pc(&mut polynomial_evaluations, FFTOrder::II, &pc);
            if snarkvm_algorithms_cuda::NTT::<Fr>(
                domain_size,
                &mut polynomial_evaluations_cuda,
                snarkvm_algorithms_cuda::NTTInputOutputOrder::NN,
                snarkvm_algorithms_cuda::NTTDirection::Forward,
                snarkvm_algorithms_cuda::NTTType::Coset,
            )
            .is_err()
            {
                println!("cuda error!");
            }
            assert_eq!(polynomial_evaluations, polynomial_evaluations_cuda, "domain size = {domain_size}");
            if snarkvm_algorithms_cuda::NTT::<Fr>(
                domain_size,
                &mut polynomial_evaluations_cuda,
                snarkvm_algorithms_cuda::NTTInputOutputOrder::NN,
                snarkvm_algorithms_cuda::NTTDirection::Inverse,
                snarkvm_algorithms_cuda::NTTType::Coset,
            )
            .is_err()
            {
                println!("cuda error!");
            }
            assert_eq!(random_polynomial.coeffs, polynomial_evaluations_cuda, "domain size = {domain_size}");
        }
    }
}