1use std::ops::Mul;
2use std::ops::MulAssign;
3
4use num_traits::ConstOne;
5use num_traits::One;
6use num_traits::Zero;
7use rayon::prelude::*;
8use twenty_first::math::traits::FiniteField;
9use twenty_first::math::traits::PrimitiveRootOfUnity;
10use twenty_first::prelude::*;
11
12use crate::error::ArithmeticDomainError;
13
14type Result<T> = std::result::Result<T, ArithmeticDomainError>;
15
16#[derive(Debug, Copy, Clone, Eq, PartialEq)]
17pub struct ArithmeticDomain {
18 pub offset: BFieldElement,
19 pub generator: BFieldElement,
20 pub length: usize,
21}
22
23impl ArithmeticDomain {
24 pub fn of_length(length: usize) -> Result<Self> {
31 let domain = Self {
32 offset: BFieldElement::ONE,
33 generator: Self::generator_for_length(length as u64)?,
34 length,
35 };
36 Ok(domain)
37 }
38
39 #[must_use]
41 pub fn with_offset(mut self, offset: BFieldElement) -> Self {
42 self.offset = offset;
43 self
44 }
45
46 pub fn generator_for_length(domain_length: u64) -> Result<BFieldElement> {
52 let error = ArithmeticDomainError::PrimitiveRootNotSupported(domain_length);
53 BFieldElement::primitive_root_of_unity(domain_length).ok_or(error)
54 }
55
56 pub fn evaluate<FF>(&self, polynomial: &Polynomial<FF>) -> Vec<FF>
57 where
58 FF: FiniteField
59 + MulAssign<BFieldElement>
60 + Mul<BFieldElement, Output = FF>
61 + From<BFieldElement>
62 + 'static,
63 {
64 let (offset, length) = (self.offset, self.length);
65 let evaluate_from = |chunk| Polynomial::from(chunk).fast_coset_evaluate(offset, length);
66
67 let mut indexed_chunks = (0..).zip(polynomial.coefficients().chunks(length));
69
70 let mut values = indexed_chunks.next().map_or_else(
72 || vec![FF::ZERO; length],
73 |(_, first_chunk)| evaluate_from(first_chunk),
74 );
75 for (chunk_index, chunk) in indexed_chunks {
76 let coefficient_index = chunk_index * u64::try_from(length).unwrap();
77 let scaled_offset = offset.mod_pow(coefficient_index);
78 values
79 .par_iter_mut()
80 .zip(evaluate_from(chunk))
81 .for_each(|(value, evaluation)| *value += evaluation * scaled_offset);
82 }
83
84 values
85 }
86
87 pub fn interpolate<FF>(&self, values: &[FF]) -> Polynomial<'static, FF>
91 where
92 FF: FiniteField + MulAssign<BFieldElement> + Mul<BFieldElement, Output = FF>,
93 {
94 debug_assert_eq!(self.length, values.len());
96
97 Polynomial::fast_coset_interpolate::<BFieldElement>(self.offset, values)
99 }
100
101 pub fn low_degree_extension<FF>(&self, codeword: &[FF], target_domain: Self) -> Vec<FF>
102 where
103 FF: FiniteField
104 + MulAssign<BFieldElement>
105 + Mul<BFieldElement, Output = FF>
106 + From<BFieldElement>
107 + 'static,
108 {
109 target_domain.evaluate(&self.interpolate(codeword))
110 }
111
112 pub fn domain_value(&self, n: u32) -> BFieldElement {
114 self.generator.mod_pow_u32(n) * self.offset
115 }
116
117 pub fn domain_values(&self) -> Vec<BFieldElement> {
118 let mut accumulator = bfe!(1);
119 let mut domain_values = Vec::with_capacity(self.length);
120
121 for _ in 0..self.length {
122 domain_values.push(accumulator * self.offset);
123 accumulator *= self.generator;
124 }
125 assert!(
126 accumulator.is_one(),
127 "length must be the order of the generator"
128 );
129 domain_values
130 }
131
132 pub fn zerofier(&self) -> Polynomial<BFieldElement> {
135 if self.offset.is_zero() {
136 return Polynomial::x_to_the(1);
137 }
138
139 Polynomial::x_to_the(self.length)
140 - Polynomial::from_constant(self.offset.mod_pow(self.length as u64))
141 }
142
143 pub fn mul_zerofier_with<FF>(&self, polynomial: Polynomial<FF>) -> Polynomial<'static, FF>
147 where
148 FF: FiniteField + Mul<BFieldElement, Output = FF>,
149 {
150 polynomial.clone().shift_coefficients(self.length)
152 - polynomial.scalar_mul(self.offset.mod_pow(self.length as u64))
153 }
154
155 pub(crate) fn halve(&self) -> Result<Self> {
156 if self.length < 2 {
157 return Err(ArithmeticDomainError::TooSmallForHalving(self.length));
158 }
159 let domain = Self {
160 offset: self.offset.square(),
161 generator: self.generator.square(),
162 length: self.length / 2,
163 };
164 Ok(domain)
165 }
166}
167
#[cfg(test)]
mod tests {
    use assert2::let_assert;
    use itertools::Itertools;
    use proptest::collection::vec;
    use proptest::prelude::*;
    use proptest_arbitrary_interop::arb;
    use test_strategy::proptest;

    use crate::shared_tests::arbitrary_polynomial;
    use crate::shared_tests::arbitrary_polynomial_of_degree;

    use super::*;

    prop_compose! {
        fn arbitrary_domain()(
            length in (0_usize..17).prop_map(|x| 1 << x),
        )(
            domain in arbitrary_domain_of_length(length),
        ) -> ArithmeticDomain {
            domain
        }
    }

    prop_compose! {
        // Any power of two of at least 2 can be halved, so include length 2.
        fn arbitrary_halveable_domain()(
            length in (1_usize..17).prop_map(|x| 1 << x),
        )(
            domain in arbitrary_domain_of_length(length),
        ) -> ArithmeticDomain {
            domain
        }
    }

    prop_compose! {
        fn arbitrary_domain_of_length(length: usize)(
            offset in arb(),
        ) -> ArithmeticDomain {
            ArithmeticDomain::of_length(length).unwrap().with_offset(offset)
        }
    }

    #[proptest]
    fn evaluate_empty_polynomial(
        #[strategy(arbitrary_domain())] domain: ArithmeticDomain,
        #[strategy(arbitrary_polynomial_of_degree(-1))] poly: Polynomial<'static, XFieldElement>,
    ) {
        let values = domain.evaluate(&poly);
        prop_assert_eq!(domain.length, values.len());
        prop_assert!(values.iter().all(|value| value.is_zero()));
    }

    #[proptest]
    fn evaluate_constant_polynomial(
        #[strategy(arbitrary_domain())] domain: ArithmeticDomain,
        #[strategy(arbitrary_polynomial_of_degree(0))] poly: Polynomial<'static, XFieldElement>,
    ) {
        let values = domain.evaluate(&poly);
        prop_assert_eq!(domain.length, values.len());
        prop_assert!(values.iter().all_equal());
    }

    #[proptest]
    fn evaluate_linear_polynomial(
        #[strategy(arbitrary_domain())] domain: ArithmeticDomain,
        #[strategy(arbitrary_polynomial_of_degree(1))] poly: Polynomial<'static, XFieldElement>,
    ) {
        let values = domain.evaluate(&poly);
        prop_assert_eq!(domain.length, values.len());
        for (point, value) in domain.domain_values().into_iter().zip_eq(values) {
            let expected: XFieldElement = poly.evaluate(point);
            prop_assert_eq!(expected, value);
        }
    }

    #[proptest]
    fn evaluate_polynomial(
        #[strategy(arbitrary_domain())] domain: ArithmeticDomain,
        #[strategy(arbitrary_polynomial())] polynomial: Polynomial<'static, XFieldElement>,
    ) {
        let values = domain.evaluate(&polynomial);
        prop_assert_eq!(domain.length, values.len());
    }

    #[test]
    fn domain_values() {
        let poly = Polynomial::<BFieldElement>::x_to_the(3);
        let x_cubed_coefficients = poly.clone().into_coefficients();

        for order in [4, 8, 32] {
            let generator = BFieldElement::primitive_root_of_unity(order).unwrap();
            let offset = BFieldElement::generator();
            let b_domain = ArithmeticDomain::of_length(order as usize)
                .unwrap()
                .with_offset(offset);

            let expected_b_values = (0..order)
                .map(|i| offset * generator.mod_pow(i))
                .collect_vec();
            let actual_b_values_1 = b_domain.domain_values();
            let actual_b_values_2 = (0..order as u32)
                .map(|i| b_domain.domain_value(i))
                .collect_vec();
            assert_eq!(expected_b_values, actual_b_values_1);
            assert_eq!(expected_b_values, actual_b_values_2);

            let values = b_domain.evaluate(&poly);
            assert_ne!(values, x_cubed_coefficients);

            let interpolant = b_domain.interpolate(&values);
            assert_eq!(poly, interpolant);

            for i in 0..order {
                let indeterminate = b_domain.domain_value(i as u32);
                let evaluation: BFieldElement = poly.evaluate(indeterminate);
                assert_eq!(evaluation, values[i as usize]);
            }
        }
    }

    #[test]
    fn low_degree_extension() {
        let short_domain_len = 32;
        let long_domain_len = 128;
        let unit_distance = long_domain_len / short_domain_len;

        let short_domain = ArithmeticDomain::of_length(short_domain_len).unwrap();
        let long_domain = ArithmeticDomain::of_length(long_domain_len).unwrap();

        let polynomial = Polynomial::new(bfe_vec![1, 2, 3, 4]);
        let short_codeword = short_domain.evaluate(&polynomial);
        let long_codeword = short_domain.low_degree_extension(&short_codeword, long_domain);

        assert_eq!(long_codeword.len(), long_domain_len);

        let long_codeword_sub_view = long_codeword
            .into_iter()
            .step_by(unit_distance)
            .collect_vec();
        assert_eq!(short_codeword, long_codeword_sub_view);
    }

    #[proptest]
    fn halving_domain_squares_all_points(
        #[strategy(arbitrary_halveable_domain())] domain: ArithmeticDomain,
    ) {
        let half_domain = domain.halve()?;
        prop_assert_eq!(domain.length / 2, half_domain.length);

        let domain_points = domain.domain_values();
        let half_domain_points = half_domain.domain_values();

        // Squaring maps the second half of the domain onto the half domain
        // again, so cycle the half domain to check every original point.
        for (domain_point, halved_domain_point) in domain_points
            .into_iter()
            .zip(half_domain_points.into_iter().cycle())
        {
            prop_assert_eq!(domain_point.square(), halved_domain_point);
        }
    }

    #[test]
    fn too_small_domains_cannot_be_halved() {
        for i in [0, 1] {
            let domain = ArithmeticDomain::of_length(i).unwrap();
            let_assert!(Err(err) = domain.halve());
            assert!(ArithmeticDomainError::TooSmallForHalving(i) == err);
        }
    }

    #[proptest]
    fn can_evaluate_polynomial_larger_than_domain(
        #[strategy(1_usize..10)] _log_domain_length: usize,
        #[strategy(1_usize..5)] _expansion_factor: usize,
        #[strategy(Just(1 << #_log_domain_length))] domain_length: usize,
        #[strategy(vec(arb(),#domain_length*#_expansion_factor))] coefficients: Vec<BFieldElement>,
        #[strategy(arb())] offset: BFieldElement,
    ) {
        let domain = ArithmeticDomain::of_length(domain_length)
            .unwrap()
            .with_offset(offset);
        let polynomial = Polynomial::new(coefficients);

        let values0 = domain.evaluate(&polynomial);
        let values1 = polynomial.batch_evaluate(&domain.domain_values());
        prop_assert_eq!(values0, values1);
    }

    #[proptest]
    fn zerofier_is_actually_zerofier(#[strategy(arbitrary_domain())] domain: ArithmeticDomain) {
        let actual_zerofier = Polynomial::zerofier(&domain.domain_values());
        prop_assert_eq!(actual_zerofier, domain.zerofier());
    }

    #[proptest]
    fn multiplication_with_zerofier_is_identical_to_method_mul_with_zerofier(
        #[strategy(arbitrary_domain())] domain: ArithmeticDomain,
        #[strategy(arbitrary_polynomial())] polynomial: Polynomial<'static, XFieldElement>,
    ) {
        let mul = domain.zerofier() * polynomial.clone();
        let mul_with = domain.mul_zerofier_with(polynomial);
        prop_assert_eq!(mul, mul_with);
    }
}