use crate::fft::{multicore::Worker, SparsePolynomial};
use snarkvm_fields::{batch_inversion, FieldParameters, PrimeField};
use snarkvm_utilities::{errors::SerializationError, serialize::*};
use rand::Rng;
use std::fmt;
#[cfg(feature = "parallel")]
use rayon::prelude::*;
#[cfg(feature = "parallel")]
/// Log2 size threshold used by `roots_of_unity` and `roots_of_unity_recursive`:
/// inputs whose log-size is at or below this value are computed serially
/// instead of being split further across threads.
const LOG_ROOTS_OF_UNITY_PARALLEL_SIZE: usize = 7;
#[cfg(feature = "parallel")]
/// Returns `floor(log2(number))`, or `0` when `number` is `0`.
///
/// Uses integer bit arithmetic rather than an `f64` round-trip: converting a
/// large `usize` to `f64` loses precision above 2^53 and can round the
/// logarithm up (e.g. `usize::MAX` would yield the bit width, not one less).
fn log2(number: usize) -> usize {
    if number == 0 {
        // Matches the result of the previous float-based implementation
        // (`-inf as usize` saturates to 0).
        return 0;
    }
    let bit_width = 8 * std::mem::size_of::<usize>();
    // Index of the highest set bit.
    bit_width - 1 - number.leading_zeros() as usize
}
/// Parameters of a power-of-two sized multiplicative subgroup of `F`, used as
/// the evaluation domain for the FFT routines in this file.
#[derive(Copy, Clone, Hash, Eq, PartialEq, CanonicalSerialize, CanonicalDeserialize)]
pub struct EvaluationDomain<F: PrimeField> {
/// Number of elements in the domain; `new` sets it to a power of two.
pub size: u64,
/// Base-2 logarithm of `size` (set in `new` as `size.trailing_zeros()`).
pub log_size_of_group: u32,
/// `size` converted to a field element.
pub size_as_field_element: F,
/// Multiplicative inverse of `size_as_field_element`; used to scale inverse FFT output.
pub size_inv: F,
/// Generator of the domain; `elements()` yields its successive powers.
/// Derived in `new` by repeatedly squaring `F::root_of_unity()`.
pub group_gen: F,
/// Multiplicative inverse of `group_gen`; used as the root for the inverse FFT.
pub group_gen_inv: F,
/// Multiplicative inverse of `F::multiplicative_generator()`; used to undo the
/// coset shift in `coset_ifft_in_place`.
pub generator_inv: F,
}
/// Debug output only reports the domain size, not the individual field elements.
impl<F: PrimeField> fmt::Debug for EvaluationDomain<F> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_fmt(format_args!("Multiplicative subgroup of size {}", self.size))
    }
}
impl<F: PrimeField> EvaluationDomain<F> {
/// Returns the per-thread chunk length for splitting `size` items, never less than 1.
///
/// With the `parallel` feature the thread count comes from rayon; without it a
/// single thread is assumed, so this function no longer references `rayon`
/// in serial builds (every other rayon use in this file is feature-gated too).
fn calculate_chunk_size(size: usize) -> usize {
    #[cfg(feature = "parallel")]
    let num_threads = rayon::current_num_threads();
    #[cfg(not(feature = "parallel"))]
    let num_threads = 1;

    // A chunk length of zero would be rejected by `chunks`/`chunks_mut`.
    (size / num_threads).max(1)
}
/// Samples random field elements from `rng` until one is found on which the
/// vanishing polynomial of this domain is non-zero, and returns it.
pub fn sample_element_outside_domain<R: Rng>(&self, rng: &mut R) -> F {
    loop {
        let candidate = F::rand(rng);
        if !self.evaluate_vanishing_polynomial(candidate).is_zero() {
            return candidate;
        }
    }
}
/// Builds the smallest power-of-two domain that can hold `num_coeffs` coefficients.
///
/// Returns `None` when:
/// - rounding `num_coeffs` up to a power of two overflows `usize`
///   (previously this panicked in debug builds),
/// - the base-2 log of the size is not strictly below `F::Parameters::TWO_ADICITY`,
/// - the size cannot be converted to a field element, or
/// - any of the required field inversions fails.
pub fn new(num_coeffs: usize) -> Option<Self> {
    let size = num_coeffs.checked_next_power_of_two()? as u64;
    let log_size_of_group = size.trailing_zeros();
    if log_size_of_group >= F::Parameters::TWO_ADICITY {
        return None;
    }

    // Square the 2^TWO_ADICITY-th root of unity once per missing power of two
    // to obtain the generator used for a domain of `size` elements.
    let mut group_gen = F::root_of_unity();
    for _ in log_size_of_group..F::Parameters::TWO_ADICITY {
        group_gen.square_in_place();
    }

    let size_as_field_element = F::from_repr(F::BigInteger::from(size))?;
    let size_inv = size_as_field_element.inverse()?;
    let group_gen_inv = group_gen.inverse()?;
    let generator_inv = F::multiplicative_generator().inverse()?;

    Some(EvaluationDomain {
        size,
        log_size_of_group,
        size_as_field_element,
        size_inv,
        group_gen,
        group_gen_inv,
        generator_inv,
    })
}
/// Returns the size of the domain that `new(num_coeffs)` would use, without
/// building the domain.
///
/// Returns `None` if rounding `num_coeffs` up to a power of two overflows
/// `usize` (previously a debug-build panic), or if the base-2 log of that size
/// is not strictly below `F::Parameters::TWO_ADICITY`.
pub fn compute_size_of_domain(num_coeffs: usize) -> Option<usize> {
    num_coeffs
        .checked_next_power_of_two()
        .filter(|size| size.trailing_zeros() < F::Parameters::TWO_ADICITY)
}
/// Returns the number of elements in the domain as a `usize`.
pub fn size(&self) -> usize {
    let element_count = self.size;
    element_count as usize
}
/// Runs `fft_in_place` on a copy of `coeffs` and returns the transformed copy.
pub fn fft(&self, coeffs: &[F]) -> Vec<F> {
    let mut evaluations = coeffs.to_vec();
    self.fft_in_place(&mut evaluations);
    evaluations
}
/// Transforms `coeffs` in place using `group_gen` as the root.
///
/// `coeffs` is first resized to the domain size: shorter inputs are padded
/// with zeros, and longer inputs are truncated.
pub fn fft_in_place(&self, coeffs: &mut Vec<F>) {
    let domain_size = self.size();
    coeffs.resize(domain_size, F::zero());
    let worker = Worker::new();
    best_fft(coeffs, &worker, self.group_gen, self.log_size_of_group)
}
/// Runs `ifft_in_place` on a copy of `evals` and returns the transformed copy.
pub fn ifft(&self, evals: &[F]) -> Vec<F> {
    let mut coefficients = evals.to_vec();
    self.ifft_in_place(&mut coefficients);
    coefficients
}
/// Inverse transform in place: resizes `evals` to the domain size, runs the
/// FFT with `group_gen_inv` as the root, then scales every entry by `size_inv`.
#[inline]
pub fn ifft_in_place(&self, evals: &mut Vec<F>) {
    let domain_size = self.size();
    evals.resize(domain_size, F::zero());
    let worker = Worker::new();
    best_fft(evals, &worker, self.group_gen_inv, self.log_size_of_group);
    cfg_iter_mut!(evals).for_each(|entry| *entry *= &self.size_inv);
}
/// Runs `coset_fft_in_place` on a copy of `coeffs` and returns the transformed copy.
pub fn coset_fft(&self, coeffs: &[F]) -> Vec<F> {
    let mut coset_evaluations = coeffs.to_vec();
    self.coset_fft_in_place(&mut coset_evaluations);
    coset_evaluations
}
/// Multiplies the `i`-th coefficient by the `i`-th power of
/// `F::multiplicative_generator()` and then applies `fft_in_place`.
pub fn coset_fft_in_place(&self, coeffs: &mut Vec<F>) {
    let coset_shift = F::multiplicative_generator();
    Self::distribute_powers(coeffs, coset_shift);
    self.fft_in_place(coeffs);
}
/// Runs `coset_ifft_in_place` on a copy of `evals` and returns the transformed copy.
pub fn coset_ifft(&self, evals: &[F]) -> Vec<F> {
    let mut coefficients = evals.to_vec();
    self.coset_ifft_in_place(&mut coefficients);
    coefficients
}
/// Applies `ifft_in_place` and then multiplies the `i`-th result by the
/// `i`-th power of `generator_inv`.
pub fn coset_ifft_in_place(&self, evals: &mut Vec<F>) {
    self.ifft_in_place(evals);
    let inverse_shift = self.generator_inv;
    Self::distribute_powers(evals, inverse_shift);
}
/// Multiplies `coeffs[i]` by `g^i` for every index `i`, splitting the slice
/// into chunks that are processed as separate tasks on a `Worker`.
fn distribute_powers(coeffs: &mut Vec<F>, g: F) {
    let worker = Worker::new();
    worker.scope(coeffs.len(), |scope, chunk_len| {
        for (chunk_index, chunk_slice) in coeffs.chunks_mut(chunk_len).enumerate() {
            scope.spawn(move |_| {
                // Each chunk starts from g^(index of its first element).
                let first_exponent = (chunk_index * chunk_len) as u64;
                let mut power = g.pow(&[first_exponent]);
                chunk_slice.iter_mut().for_each(|coeff| {
                    *coeff *= &power;
                    power *= &g;
                });
            });
        }
    });
}
/// Evaluates every Lagrange basis polynomial of this domain at `tau`.
///
/// The returned vector has `size` entries, ordered like `elements()`.
/// If `tau^size == 1`, the result is all zeros except for a one at the
/// position of the domain element equal to `tau` (if one is found).
/// Otherwise entry `i` is `(tau^size - 1) * size_inv * w^i / (tau - w^i)`
/// with `w = group_gen`, computed with a single batch inversion.
pub fn evaluate_all_lagrange_coefficients(&self, tau: F) -> Vec<F> {
    let size = self.size as usize;
    let tau_to_size = tau.pow(&[self.size]);
    let one = F::one();

    if tau_to_size.is_one() {
        // `tau` satisfies the vanishing equation: scan the domain for it.
        let mut coefficients = vec![F::zero(); size];
        let mut domain_element = one;
        for coefficient in coefficients.iter_mut() {
            if domain_element == tau {
                *coefficient = one;
                break;
            }
            domain_element *= &self.group_gen;
        }
        return coefficients;
    }

    let mut numerator = (tau_to_size - &one) * &self.size_inv;
    let mut domain_element = one;
    let mut denominators = Vec::with_capacity(size);
    let mut numerators = Vec::with_capacity(size);
    for _ in 0..size {
        denominators.push(tau - &domain_element);
        numerators.push(numerator);
        numerator *= &self.group_gen;
        domain_element *= &self.group_gen;
    }

    // Invert all denominators at once, then scale by the matching numerator.
    batch_inversion(denominators.as_mut_slice());
    cfg_iter_mut!(denominators)
        .zip(numerators)
        .for_each(|(inverse, numerator)| {
            *inverse = numerator * inverse;
        });
    denominators
}
/// Returns the sparse polynomial `X^size - 1`.
pub fn vanishing_polynomial(&self) -> SparsePolynomial<F> {
    let constant_term = (0, -F::one());
    let leading_term = (self.size(), F::one());
    SparsePolynomial::from_coefficients_vec(vec![constant_term, leading_term])
}
/// Returns `tau^size - 1`, the vanishing polynomial of this domain evaluated at `tau`.
pub fn evaluate_vanishing_polynomial(&self, tau: F) -> F {
    let tau_to_size = tau.pow(&[self.size]);
    tau_to_size - &F::one()
}
/// Returns an iterator over the domain elements, starting at one and
/// advancing by a factor of `group_gen` on each step.
pub fn elements(&self) -> Elements<F> {
    let domain = *self;
    Elements {
        cur_elem: F::one(),
        cur_pow: 0,
        domain,
    }
}
/// Multiplies every entry of `evals` by the inverse of the vanishing
/// polynomial evaluated at `F::multiplicative_generator()`.
///
/// # Panics
///
/// Panics if that evaluation has no inverse.
pub fn divide_by_vanishing_poly_on_coset_in_place(&self, evals: &mut [F]) {
    let vanishing_at_generator = self.evaluate_vanishing_polynomial(F::multiplicative_generator());
    let inverse = vanishing_at_generator.inverse().unwrap();
    let worker = Worker::new();
    worker.scope(evals.len(), |scope, chunk_len| {
        for chunk in evals.chunks_mut(chunk_len) {
            scope.spawn(move |_| {
                for eval in chunk.iter_mut() {
                    *eval *= &inverse;
                }
            });
        }
    });
}
/// Maps `index` to an index into this domain, given the smaller domain `other`.
///
/// With `period = self.size() / other.size()`: indices below `other.size()`
/// map to `index * period`; the remaining indices map, in order, to the
/// positions that are not multiples of `period`.
///
/// # Panics
///
/// Panics if `other` is larger than `self`, and (division by zero) if
/// `index >= other.size()` while both domains have the same size.
pub fn reindex_by_subdomain(&self, other: Self, index: usize) -> usize {
    assert!(self.size() >= other.size());
    let subdomain_size = other.size();
    let period = self.size() / subdomain_size;
    if index >= subdomain_size {
        let offset = index - subdomain_size;
        let gap = period - 1;
        return offset + (offset / gap) + 1;
    }
    index * period
}
/// Returns the element-wise product of `self_evals` and `other_evals`.
///
/// The work is split into chunks whose length is derived from the domain size.
///
/// # Panics
///
/// Panics if the two slices have different lengths.
#[must_use]
pub fn mul_polynomials_in_evaluation_domain(&self, self_evals: &[F], other_evals: &[F]) -> Vec<F> {
    assert_eq!(self_evals.len(), other_evals.len());
    let chunk_size = Self::calculate_chunk_size(self.size());
    let mut product = self_evals.to_vec();
    cfg_chunks_mut!(product, chunk_size)
        .zip(cfg_chunks!(other_evals, chunk_size))
        .for_each(|(lhs_chunk, rhs_chunk)| {
            lhs_chunk
                .iter_mut()
                .zip(rhs_chunk)
                .for_each(|(lhs, rhs)| *lhs *= rhs);
        });
    product
}
/// Returns the first `size / 2` powers of `root`, starting at `root^0`
/// (serial build).
#[cfg(not(feature = "parallel"))]
pub fn roots_of_unity(&self, root: F) -> Vec<F> {
    let half_size = (self.size as usize) / 2;
    Self::compute_powers_serial(half_size, root)
}
/// Returns the first `size / 2` powers of `root`, starting at `root^0`
/// (parallel build).
///
/// Small domains are computed serially; larger ones precompute
/// `root^(2^k)` for each `k` and hand them to `roots_of_unity_recursive`.
#[cfg(feature = "parallel")]
pub fn roots_of_unity(&self, root: F) -> Vec<F> {
    let log_size = log2(self.size as usize);
    if log_size <= LOG_ROOTS_OF_UNITY_PARALLEL_SIZE {
        return Self::compute_powers_serial((self.size as usize) / 2, root);
    }

    // log_powers[k] = root^(2^k)
    let num_log_powers = log_size - 1;
    let mut log_powers: Vec<F> = Vec::with_capacity(num_log_powers);
    let mut current = root;
    for _ in 0..num_log_powers {
        log_powers.push(current);
        current.square_in_place();
    }

    let mut powers = vec![F::zero(); 1 << num_log_powers];
    Self::roots_of_unity_recursive(&mut powers, &log_powers);
    powers
}
/// Fills `powers` with successive powers of `log_powers[0]`, given that
/// `log_powers[k]` holds `log_powers[0]^(2^k)`.
///
/// Small inputs are filled serially. Larger inputs are split into a low and a
/// high half that are solved concurrently and then combined by multiplying
/// every low power with every high power.
///
/// # Panics
///
/// Panics if `powers.len() != 1 << log_powers.len()`.
#[cfg(feature = "parallel")]
fn roots_of_unity_recursive(powers: &mut [F], log_powers: &[F]) {
    assert_eq!(powers.len(), 1 << log_powers.len());
    let num_log_powers = log_powers.len();

    // Base case: compute the powers one after another.
    if num_log_powers <= LOG_ROOTS_OF_UNITY_PARALLEL_SIZE {
        powers[0] = F::one();
        for index in 1..powers.len() {
            powers[index] = powers[index - 1] * &log_powers[0];
        }
        return;
    }

    let split_point = (1 + num_log_powers) / 2;
    let low_log_powers = &log_powers[..split_point];
    let high_log_powers = &log_powers[split_point..];
    let mut low_powers = vec![F::default(); 1 << low_log_powers.len()];
    let mut high_powers = vec![F::default(); 1 << high_log_powers.len()];
    rayon::join(
        || Self::roots_of_unity_recursive(&mut low_powers, low_log_powers),
        || Self::roots_of_unity_recursive(&mut high_powers, high_log_powers),
    );

    // Chunk `j` of the output is `high_powers[j]` times each low power.
    powers
        .par_chunks_mut(low_powers.len())
        .zip(&high_powers)
        .for_each(|(chunk, high)| {
            chunk
                .iter_mut()
                .zip(&low_powers)
                .for_each(|(power, low)| *power = *high * low);
        });
}
/// Returns `[root^0, root^1, ..., root^(size - 1)]`, computed sequentially.
fn compute_powers_serial(size: usize, root: F) -> Vec<F> {
    let mut powers = Vec::with_capacity(size);
    let mut current = F::one();
    for _ in 0..size {
        powers.push(current);
        current *= &root;
    }
    powers
}
}
/// Serial build: always runs the serial radix-2 FFT; `worker` is ignored.
#[allow(unused_variables)]
#[cfg(not(feature = "parallel"))]
fn best_fft<F: PrimeField>(a: &mut [F], worker: &Worker, omega: F, log_n: u32) {
    serial_radix2_fft(a, omega, log_n)
}
/// Parallel build: uses the parallel radix-2 FFT when the input's log-size
/// exceeds the worker's log CPU count, and the serial FFT otherwise.
#[cfg(feature = "parallel")]
fn best_fft<F: PrimeField>(a: &mut [F], worker: &Worker, omega: F, log_n: u32) {
    let log_cpus = worker.log_num_cpus();
    if log_n > log_cpus {
        parallel_radix2_fft(a, worker, omega, log_n, log_cpus);
    } else {
        serial_radix2_fft(a, omega, log_n);
    }
}
/// In-place iterative radix-2 FFT of `a` using `omega` as the root.
///
/// The input is first permuted into bit-reversed order, then `log_n`
/// butterfly passes are applied with the block size doubling on each pass.
///
/// # Panics
///
/// Panics if `a.len()` (as `u32`) is not equal to `1 << log_n`.
#[allow(clippy::many_single_char_names)]
pub(crate) fn serial_radix2_fft<F: PrimeField>(a: &mut [F], omega: F, log_n: u32) {
    /// Reverses the lowest `l` bits of `n`.
    #[inline]
    fn bitreverse(mut n: u32, l: u32) -> u32 {
        let mut reversed = 0;
        for _ in 0..l {
            reversed = (reversed << 1) | (n & 1);
            n >>= 1;
        }
        reversed
    }

    let n = a.len() as u32;
    assert_eq!(n, 1 << log_n);

    // Bit-reversal permutation; the `<` guard swaps each pair only once.
    for index in 0..n {
        let reversed_index = bitreverse(index, log_n);
        if index < reversed_index {
            a.swap(reversed_index as usize, index as usize);
        }
    }

    // `half_block` is half the butterfly block size for the current pass.
    let mut half_block = 1;
    for _ in 0..log_n {
        let block = 2 * half_block;
        let twiddle_step = omega.pow(&[(n / block) as u64]);
        for block_start in (0..n).step_by(block as usize) {
            let mut twiddle = F::one();
            for offset in 0..half_block {
                let lower = (block_start + offset) as usize;
                let upper = (block_start + offset + half_block) as usize;

                let mut scaled = a[upper];
                scaled *= &twiddle;
                let mut difference = a[lower];
                difference -= &scaled;
                a[upper] = difference;
                a[lower] += &scaled;

                twiddle *= &twiddle_step;
            }
        }
        half_block *= 2;
    }
}
/// Radix-2 FFT of `a` that splits the work into `2^log_cpus` sub-problems of
/// length `2^(log_n - log_cpus)`, each finished with `serial_radix2_fft`,
/// and then writes the interleaved results back into `a`.
///
/// Panics if `log_n < log_cpus`.
#[cfg(feature = "parallel")]
pub(crate) fn parallel_radix2_fft<F: PrimeField>(a: &mut [F], worker: &Worker, omega: F, log_n: u32, log_cpus: u32) {
assert!(log_n >= log_cpus);
let num_cpus = 1 << log_cpus;
let log_new_n = log_n - log_cpus;
// One zero-initialized scratch buffer of length 2^log_new_n per sub-problem.
let mut tmp = vec![vec![F::zero(); 1 << log_new_n]; num_cpus];
// Root handed to each sub-FFT: omega^num_cpus.
let new_omega = omega.pow(&[num_cpus as u64]);
// Phase 1: one task per scratch buffer.
worker.scope(0, |scope, _| {
// Shared reborrow of `a` so every spawned task can read the input.
let a = &*a;
for (j, tmp) in tmp.iter_mut().enumerate() {
scope.spawn(move |_| {
let omega_j = omega.pow(&[j as u64]);
let omega_step = omega.pow(&[(j as u64) << log_new_n]);
// Running twiddle factor: multiplied by `omega_step` after every `s`
// and by `omega_j` after every `i`; it is never reset in between.
let mut elt = F::one();
for (i, x) in tmp.iter_mut().enumerate().take(1 << log_new_n) {
for s in 0..num_cpus {
// Accumulate the inputs at stride 2^log_new_n starting from `i`, each scaled by `elt`.
let idx = (i + (s << log_new_n)) % (1 << log_n);
let mut t = a[idx];
t *= &elt;
*x += &t;
elt *= &omega_step;
}
elt *= &omega_j;
}
serial_radix2_fft(tmp, new_omega, log_new_n);
});
}
});
// Phase 2: copy results back; a[idx] is read from tmp[idx & mask][idx >> log_cpus].
worker.scope(a.len(), |scope, chunk| {
let tmp = &tmp;
for (idx, a) in a.chunks_mut(chunk).enumerate() {
scope.spawn(move |_| {
// Absolute index in `a` of the first element of this chunk.
let mut idx = idx * chunk;
let mask = (1 << log_cpus) - 1;
for a in a {
*a = tmp[idx & mask][idx >> log_cpus];
idx += 1;
}
});
}
});
}
/// Iterator over the elements of an `EvaluationDomain`, as returned by
/// `EvaluationDomain::elements`.
pub struct Elements<F: PrimeField> {
// Element that the next call to `next` will yield.
cur_elem: F,
// Number of elements yielded so far; iteration stops once it equals `domain.size`.
cur_pow: u64,
// Copy of the domain being iterated.
domain: EvaluationDomain<F>,
}
/// Yields `group_gen^0, group_gen^1, ...` until `domain.size` elements have
/// been produced.
impl<F: PrimeField> Iterator for Elements<F> {
    type Item = F;

    fn next(&mut self) -> Option<F> {
        if self.cur_pow == self.domain.size {
            return None;
        }
        let current = self.cur_elem;
        self.cur_elem *= &self.domain.group_gen;
        self.cur_pow += 1;
        Some(current)
    }
}
// Unit tests for `EvaluationDomain`, run over the BLS12-377 scalar field `Fr`.
#[cfg(test)]
mod tests {
use crate::fft::{DensePolynomial, EvaluationDomain};
use snarkvm_curves::bls12_377::Fr;
use snarkvm_fields::{Field, One, PrimeField, Zero};
use snarkvm_utilities::UniformRand;
use rand::{thread_rng, Rng};
// The sparse vanishing polynomial and `evaluate_vanishing_polynomial` must
// agree at random points.
#[test]
fn vanishing_polynomial_evaluation() {
let rng = &mut thread_rng();
for coeffs in 0..10 {
let domain = EvaluationDomain::<Fr>::new(coeffs).unwrap();
let z = domain.vanishing_polynomial();
for _ in 0..100 {
let point = rng.gen();
assert_eq!(z.evaluate(point), domain.evaluate_vanishing_polynomial(point))
}
}
}
// The vanishing polynomial must evaluate to zero on every domain element.
#[test]
fn vanishing_polynomial_vanishes_on_domain() {
for coeffs in 0..1000 {
let domain = EvaluationDomain::<Fr>::new(coeffs).unwrap();
let z = domain.vanishing_polynomial();
for point in domain.elements() {
assert!(z.evaluate(point).is_zero())
}
}
}
// `elements()` must yield exactly `size()` items.
#[test]
fn size_of_elements() {
for coeffs in 1..10 {
let size = 1 << coeffs;
let domain = EvaluationDomain::<Fr>::new(size).unwrap();
let domain_size = domain.size();
assert_eq!(domain_size, domain.elements().count());
}
}
// The i-th item of `elements()` must equal `group_gen^i`.
#[test]
fn elements_contents() {
for coeffs in 1..10 {
let size = 1 << coeffs;
let domain = EvaluationDomain::<Fr>::new(size).unwrap();
for (i, element) in domain.elements().enumerate() {
assert_eq!(element, domain.group_gen.pow([i as u64]));
}
}
}
// Interpolating a random polynomial's FFT evaluations with the Lagrange
// coefficients at a random point must match evaluating the polynomial directly.
#[test]
fn non_systematic_lagrange_coefficients_test() {
for domain_dimension in 1..10 {
let domain_size = 1 << domain_dimension;
let domain = EvaluationDomain::<Fr>::new(domain_size).unwrap();
let random_point = Fr::rand(&mut thread_rng());
let lagrange_coefficients = domain.evaluate_all_lagrange_coefficients(random_point);
let random_polynomial = DensePolynomial::<Fr>::rand(domain_size - 1, &mut thread_rng());
let polynomial_evaluations = domain.fft(random_polynomial.coeffs());
let actual_evaluations = random_polynomial.evaluate(random_point);
let mut interpolated_evaluation = Fr::zero();
for i in 0..domain_size {
interpolated_evaluation += &(lagrange_coefficients[i] * &polynomial_evaluations[i]);
}
assert_eq!(actual_evaluations, interpolated_evaluation);
}
}
// At the i-th domain element, the Lagrange coefficients must be one at
// position i and zero everywhere else.
#[test]
fn systematic_lagrange_coefficients_test() {
for domain_dimension in 1..5 {
let domain_size = 1 << domain_dimension;
let domain = EvaluationDomain::<Fr>::new(domain_size).unwrap();
let all_domain_elements: Vec<Fr> = domain.elements().collect();
for (i, domain_element) in all_domain_elements.iter().enumerate().take(domain_size) {
let lagrange_coefficients = domain.evaluate_all_lagrange_coefficients(*domain_element);
for (j, lagrange_coefficient) in lagrange_coefficients.iter().enumerate().take(domain_size) {
if i == j {
assert_eq!(*lagrange_coefficient, Fr::one());
} else {
assert_eq!(*lagrange_coefficient, Fr::zero());
}
}
}
}
}
// `roots_of_unity(group_gen)` must return the first half of the domain
// elements, in order.
#[test]
fn test_roots_of_unity() {
let max_degree = 10;
for log_domain_size in 0..max_degree {
let domain_size = 1 << log_domain_size;
let domain = EvaluationDomain::<Fr>::new(domain_size).unwrap();
let actual_roots = domain.roots_of_unity(domain.group_gen);
for &value in &actual_roots {
assert!(domain.evaluate_vanishing_polynomial(value).is_zero());
}
let expected_roots_elements = domain.elements();
for (expected, &actual) in expected_roots_elements.zip(&actual_roots) {
assert_eq!(expected, actual);
}
assert_eq!(actual_roots.len(), domain_size / 2);
}
}
// `fft` / `coset_fft` must match direct evaluation on the domain / its coset,
// and `ifft` / `coset_ifft` must recover the original polynomial.
#[test]
fn test_fft_correctness() {
let log_degree = 5;
let degree = 1 << log_degree;
let random_polynomial = DensePolynomial::<Fr>::rand(degree - 1, &mut thread_rng());
for log_domain_size in log_degree..(log_degree + 2) {
let domain_size = 1 << log_domain_size;
let domain = EvaluationDomain::<Fr>::new(domain_size).unwrap();
let polynomial_evaluations = domain.fft(&random_polynomial.coeffs);
let polynomial_coset_evaluations = domain.coset_fft(&random_polynomial.coeffs);
for (i, x) in domain.elements().enumerate() {
let coset_x = Fr::multiplicative_generator() * &x;
assert_eq!(polynomial_evaluations[i], random_polynomial.evaluate(x));
assert_eq!(polynomial_coset_evaluations[i], random_polynomial.evaluate(coset_x));
}
let randon_polynomial_from_subgroup =
DensePolynomial::from_coefficients_vec(domain.ifft(&polynomial_evaluations));
let random_polynomial_from_coset =
DensePolynomial::from_coefficients_vec(domain.coset_ifft(&polynomial_coset_evaluations));
assert_eq!(
random_polynomial, randon_polynomial_from_subgroup,
"degree = {}, domain size = {}",
degree, domain_size
);
assert_eq!(
random_polynomial, random_polynomial_from_coset,
"degree = {}, domain size = {}",
degree, domain_size
);
}
}
}