triton_vm/
arithmetic_domain.rs

1use std::ops::Mul;
2use std::ops::MulAssign;
3
4use num_traits::ConstOne;
5use num_traits::One;
6use num_traits::Zero;
7use rayon::prelude::*;
8use twenty_first::math::traits::FiniteField;
9use twenty_first::math::traits::PrimitiveRootOfUnity;
10use twenty_first::prelude::*;
11
12use crate::error::ArithmeticDomainError;
13
14type Result<T> = std::result::Result<T, ArithmeticDomainError>;
15
16#[derive(Debug, Copy, Clone, Eq, PartialEq)]
17pub struct ArithmeticDomain {
18    pub offset: BFieldElement,
19    pub generator: BFieldElement,
20    pub length: usize,
21}
22
23impl ArithmeticDomain {
24    /// Create a new domain with the given length.
25    /// No offset is applied, but can be added through [`with_offset()`](Self::with_offset).
26    ///
27    /// # Errors
28    ///
29    /// Errors if the domain length is not a power of 2.
30    pub fn of_length(length: usize) -> Result<Self> {
31        let domain = Self {
32            offset: BFieldElement::ONE,
33            generator: Self::generator_for_length(length as u64)?,
34            length,
35        };
36        Ok(domain)
37    }
38
39    /// Set the offset of the domain.
40    #[must_use]
41    pub fn with_offset(mut self, offset: BFieldElement) -> Self {
42        self.offset = offset;
43        self
44    }
45
46    /// Derive a generator for a domain of the given length.
47    ///
48    /// # Errors
49    ///
50    /// Errors if the domain length is not a power of 2.
51    pub fn generator_for_length(domain_length: u64) -> Result<BFieldElement> {
52        let error = ArithmeticDomainError::PrimitiveRootNotSupported(domain_length);
53        BFieldElement::primitive_root_of_unity(domain_length).ok_or(error)
54    }
55
56    pub fn evaluate<FF>(&self, polynomial: &Polynomial<FF>) -> Vec<FF>
57    where
58        FF: FiniteField
59            + MulAssign<BFieldElement>
60            + Mul<BFieldElement, Output = FF>
61            + From<BFieldElement>
62            + 'static,
63    {
64        let (offset, length) = (self.offset, self.length);
65        let evaluate_from = |chunk| Polynomial::from(chunk).fast_coset_evaluate(offset, length);
66
67        // avoid `enumerate` to directly get index of the right type
68        let mut indexed_chunks = (0..).zip(polynomial.coefficients().chunks(length));
69
70        // only allocate a bunch of zeros if there are no chunks
71        let mut values = indexed_chunks.next().map_or_else(
72            || vec![FF::ZERO; length],
73            |(_, first_chunk)| evaluate_from(first_chunk),
74        );
75        for (chunk_index, chunk) in indexed_chunks {
76            let coefficient_index = chunk_index * u64::try_from(length).unwrap();
77            let scaled_offset = offset.mod_pow(coefficient_index);
78            values
79                .par_iter_mut()
80                .zip(evaluate_from(chunk))
81                .for_each(|(value, evaluation)| *value += evaluation * scaled_offset);
82        }
83
84        values
85    }
86
87    /// # Panics
88    ///
89    /// Panics if the length of the argument does not match the length of `self`.
90    pub fn interpolate<FF>(&self, values: &[FF]) -> Polynomial<'static, FF>
91    where
92        FF: FiniteField + MulAssign<BFieldElement> + Mul<BFieldElement, Output = FF>,
93    {
94        // required by `fast_coset_interpolate`
95        debug_assert_eq!(self.length, values.len());
96
97        // generic type made explicit to avoid performance regressions due to auto-conversion
98        Polynomial::fast_coset_interpolate::<BFieldElement>(self.offset, values)
99    }
100
101    pub fn low_degree_extension<FF>(&self, codeword: &[FF], target_domain: Self) -> Vec<FF>
102    where
103        FF: FiniteField
104            + MulAssign<BFieldElement>
105            + Mul<BFieldElement, Output = FF>
106            + From<BFieldElement>
107            + 'static,
108    {
109        target_domain.evaluate(&self.interpolate(codeword))
110    }
111
112    /// Compute the `n`th element of the domain.
113    pub fn domain_value(&self, n: u32) -> BFieldElement {
114        self.generator.mod_pow_u32(n) * self.offset
115    }
116
117    pub fn domain_values(&self) -> Vec<BFieldElement> {
118        let mut accumulator = bfe!(1);
119        let mut domain_values = Vec::with_capacity(self.length);
120
121        for _ in 0..self.length {
122            domain_values.push(accumulator * self.offset);
123            accumulator *= self.generator;
124        }
125        assert!(
126            accumulator.is_one(),
127            "length must be the order of the generator"
128        );
129        domain_values
130    }
131
132    /// A polynomial that evaluates to 0 on (and only on)
133    /// a [domain value][Self::domain_values].
134    pub fn zerofier(&self) -> Polynomial<BFieldElement> {
135        if self.offset.is_zero() {
136            return Polynomial::x_to_the(1);
137        }
138
139        Polynomial::x_to_the(self.length)
140            - Polynomial::from_constant(self.offset.mod_pow(self.length as u64))
141    }
142
143    /// [`Self::zerofier`] times the argument.
144    /// More performant than polynomial multiplication.
145    /// See [`Self::zerofier`] for details.
146    pub fn mul_zerofier_with<FF>(&self, polynomial: Polynomial<FF>) -> Polynomial<'static, FF>
147    where
148        FF: FiniteField + Mul<BFieldElement, Output = FF>,
149    {
150        // use knowledge of zerofier's shape for faster multiplication
151        polynomial.clone().shift_coefficients(self.length)
152            - polynomial.scalar_mul(self.offset.mod_pow(self.length as u64))
153    }
154
155    pub(crate) fn halve(&self) -> Result<Self> {
156        if self.length < 2 {
157            return Err(ArithmeticDomainError::TooSmallForHalving(self.length));
158        }
159        let domain = Self {
160            offset: self.offset.square(),
161            generator: self.generator.square(),
162            length: self.length / 2,
163        };
164        Ok(domain)
165    }
166}
167
#[cfg(test)]
mod tests {
    use assert2::let_assert;
    use itertools::Itertools;
    use proptest::collection::vec;
    use proptest::prelude::*;
    use proptest_arbitrary_interop::arb;
    use test_strategy::proptest;

    use crate::shared_tests::arbitrary_polynomial;
    use crate::shared_tests::arbitrary_polynomial_of_degree;

    use super::*;

    // A domain of power-of-two length in `1..=2^16` with an arbitrary offset.
    prop_compose! {
        fn arbitrary_domain()(
            length in (0_usize..17).prop_map(|x| 1 << x),
        )(
            domain in arbitrary_domain_of_length(length),
        ) -> ArithmeticDomain {
            domain
        }
    }

    // A domain of power-of-two length in `4..=2^16` with an arbitrary offset.
    prop_compose! {
        fn arbitrary_halveable_domain()(
            length in (2_usize..17).prop_map(|x| 1 << x),
        )(
            domain in arbitrary_domain_of_length(length),
        ) -> ArithmeticDomain {
            domain
        }
    }

    // A domain of the given length with an arbitrary offset.
    prop_compose! {
        fn arbitrary_domain_of_length(length: usize)(
            offset in arb(),
        ) -> ArithmeticDomain {
            ArithmeticDomain::of_length(length).unwrap().with_offset(offset)
        }
    }

    #[proptest]
    fn evaluate_empty_polynomial(
        #[strategy(arbitrary_domain())] domain: ArithmeticDomain,
        #[strategy(arbitrary_polynomial_of_degree(-1))] poly: Polynomial<'static, XFieldElement>,
    ) {
        prop_assert_eq!(domain.length, domain.evaluate(&poly).len());
    }

    #[proptest]
    fn evaluate_constant_polynomial(
        #[strategy(arbitrary_domain())] domain: ArithmeticDomain,
        #[strategy(arbitrary_polynomial_of_degree(0))] poly: Polynomial<'static, XFieldElement>,
    ) {
        prop_assert_eq!(domain.length, domain.evaluate(&poly).len());
    }

    #[proptest]
    fn evaluate_linear_polynomial(
        #[strategy(arbitrary_domain())] domain: ArithmeticDomain,
        #[strategy(arbitrary_polynomial_of_degree(1))] poly: Polynomial<'static, XFieldElement>,
    ) {
        prop_assert_eq!(domain.length, domain.evaluate(&poly).len());
    }

    #[proptest]
    fn evaluate_polynomial(
        #[strategy(arbitrary_domain())] domain: ArithmeticDomain,
        #[strategy(arbitrary_polynomial())] polynomial: Polynomial<'static, XFieldElement>,
    ) {
        prop_assert_eq!(domain.length, domain.evaluate(&polynomial).len());
    }

    #[test]
    fn domain_values() {
        let poly = Polynomial::<BFieldElement>::x_to_the(3);
        let x_cubed_coefficients = poly.clone().into_coefficients();

        for order in [4, 8, 32] {
            let generator = BFieldElement::primitive_root_of_unity(order).unwrap();
            let offset = BFieldElement::generator();
            let b_domain = ArithmeticDomain::of_length(order as usize)
                .unwrap()
                .with_offset(offset);

            let expected_b_values = (0..order)
                .map(|i| offset * generator.mod_pow(i))
                .collect_vec();
            let actual_b_values_1 = b_domain.domain_values();
            let actual_b_values_2 = (0..order as u32)
                .map(|i| b_domain.domain_value(i))
                .collect_vec();
            assert_eq!(expected_b_values, actual_b_values_1);
            assert_eq!(expected_b_values, actual_b_values_2);

            let values = b_domain.evaluate(&poly);
            assert_ne!(values, x_cubed_coefficients);

            let interpolant = b_domain.interpolate(&values);
            assert_eq!(poly, interpolant);

            // Verify that batch-evaluated values match a manual evaluation
            for i in 0..order {
                let indeterminate = b_domain.domain_value(i as u32);
                let evaluation: BFieldElement = poly.evaluate(indeterminate);
                assert_eq!(evaluation, values[i as usize]);
            }
        }
    }

    #[test]
    fn low_degree_extension() {
        let short_domain_len = 32;
        let long_domain_len = 128;
        let unit_distance = long_domain_len / short_domain_len;

        let short_domain = ArithmeticDomain::of_length(short_domain_len).unwrap();
        let long_domain = ArithmeticDomain::of_length(long_domain_len).unwrap();

        let polynomial = Polynomial::new(bfe_vec![1, 2, 3, 4]);
        let short_codeword = short_domain.evaluate(&polynomial);
        let long_codeword = short_domain.low_degree_extension(&short_codeword, long_domain);

        assert_eq!(long_codeword.len(), long_domain_len);

        let long_codeword_sub_view = long_codeword
            .into_iter()
            .step_by(unit_distance)
            .collect_vec();
        assert_eq!(short_codeword, long_codeword_sub_view);
    }

    #[proptest]
    fn halving_domain_squares_all_points(
        #[strategy(arbitrary_halveable_domain())] domain: ArithmeticDomain,
    ) {
        let half_domain = domain.halve()?;
        prop_assert_eq!(domain.length / 2, half_domain.length);

        let domain_points = domain.domain_values();
        let half_domain_points = half_domain.domain_values();

        // Squaring maps the domain 2-to-1 onto the half domain: points `i` and `i + length / 2`
        // both square to half-domain point `i`. Cycling the half domain checks every point of
        // the full domain, not just the first half.
        for (domain_point, halved_domain_point) in domain_points
            .into_iter()
            .zip(half_domain_points.into_iter().cycle())
        {
            prop_assert_eq!(domain_point.square(), halved_domain_point);
        }
    }

    #[test]
    fn too_small_domains_cannot_be_halved() {
        for i in [0, 1] {
            let domain = ArithmeticDomain::of_length(i).unwrap();
            let_assert!(Err(err) = domain.halve());
            assert!(ArithmeticDomainError::TooSmallForHalving(i) == err);
        }
    }

    #[proptest]
    fn can_evaluate_polynomial_larger_than_domain(
        #[strategy(1_usize..10)] _log_domain_length: usize,
        #[strategy(1_usize..5)] _expansion_factor: usize,
        #[strategy(Just(1 << #_log_domain_length))] domain_length: usize,
        #[strategy(vec(arb(),#domain_length*#_expansion_factor))] coefficients: Vec<BFieldElement>,
        #[strategy(arb())] offset: BFieldElement,
    ) {
        let domain = ArithmeticDomain::of_length(domain_length)
            .unwrap()
            .with_offset(offset);
        let polynomial = Polynomial::new(coefficients);

        let values0 = domain.evaluate(&polynomial);
        let values1 = polynomial.batch_evaluate(&domain.domain_values());
        assert_eq!(values0, values1);
    }

    #[proptest]
    fn zerofier_is_actually_zerofier(#[strategy(arbitrary_domain())] domain: ArithmeticDomain) {
        let actual_zerofier = Polynomial::zerofier(&domain.domain_values());
        prop_assert_eq!(actual_zerofier, domain.zerofier());
    }

    #[proptest]
    fn multiplication_with_zerofier_is_identical_to_method_mul_with_zerofier(
        #[strategy(arbitrary_domain())] domain: ArithmeticDomain,
        #[strategy(arbitrary_polynomial())] polynomial: Polynomial<'static, XFieldElement>,
    ) {
        let mul = domain.zerofier() * polynomial.clone();
        let mul_with = domain.mul_zerofier_with(polynomial);
        prop_assert_eq!(mul, mul_with);
    }
}