triton_vm/
fri.rs

1use itertools::Itertools;
2use num_traits::Zero;
3use rayon::prelude::*;
4use twenty_first::math::polynomial::barycentric_evaluate;
5use twenty_first::math::traits::FiniteField;
6use twenty_first::prelude::*;
7
8use crate::arithmetic_domain::ArithmeticDomain;
9use crate::error::FriProvingError;
10use crate::error::FriSetupError;
11use crate::error::FriValidationError;
12use crate::profiler::profiler;
13use crate::proof_item::FriResponse;
14use crate::proof_item::ProofItem;
15use crate::proof_stream::ProofStream;
16
17pub(crate) type SetupResult<T> = Result<T, FriSetupError>;
18pub(crate) type ProverResult<T> = Result<T, FriProvingError>;
19pub(crate) type VerifierResult<T> = Result<T, FriValidationError>;
20
21pub type AuthenticationStructure = Vec<Digest>;
22
23#[derive(Debug, Copy, Clone)]
24pub struct Fri {
25    pub expansion_factor: usize,
26    pub num_collinearity_checks: usize,
27    pub domain: ArithmeticDomain,
28}
29
30#[derive(Debug, Eq, PartialEq)]
31struct FriProver<'stream> {
32    proof_stream: &'stream mut ProofStream,
33    rounds: Vec<ProverRound>,
34    first_round_domain: ArithmeticDomain,
35    num_rounds: usize,
36    num_collinearity_checks: usize,
37    first_round_collinearity_check_indices: Vec<usize>,
38}
39
40#[derive(Debug, Clone, Eq, PartialEq)]
41struct ProverRound {
42    domain: ArithmeticDomain,
43    codeword: Vec<XFieldElement>,
44    merkle_tree: MerkleTree,
45}
46
47impl FriProver<'_> {
48    fn commit(&mut self, codeword: &[XFieldElement]) -> ProverResult<()> {
49        self.commit_to_first_round(codeword)?;
50        for _ in 0..self.num_rounds {
51            self.commit_to_next_round()?;
52        }
53        self.send_last_codeword();
54        self.send_last_polynomial();
55        Ok(())
56    }
57
58    fn commit_to_first_round(&mut self, codeword: &[XFieldElement]) -> ProverResult<()> {
59        let first_round = ProverRound::new(self.first_round_domain, codeword)?;
60        self.commit_to_round(&first_round);
61        self.store_round(first_round);
62        Ok(())
63    }
64
65    fn commit_to_next_round(&mut self) -> ProverResult<()> {
66        let next_round = self.construct_next_round()?;
67        self.commit_to_round(&next_round);
68        self.store_round(next_round);
69        Ok(())
70    }
71
72    fn commit_to_round(&mut self, round: &ProverRound) {
73        let merkle_root = round.merkle_tree.root();
74        let proof_item = ProofItem::MerkleRoot(merkle_root);
75        self.proof_stream.enqueue(proof_item);
76    }
77
78    fn store_round(&mut self, round: ProverRound) {
79        self.rounds.push(round);
80    }
81
82    fn construct_next_round(&mut self) -> ProverResult<ProverRound> {
83        let previous_round = self.rounds.last().unwrap();
84        let folding_challenge = self.proof_stream.sample_scalars(1)[0];
85        let codeword = previous_round.split_and_fold(folding_challenge);
86        let domain = previous_round.domain.halve()?;
87        ProverRound::new(domain, &codeword)
88    }
89
90    fn send_last_codeword(&mut self) {
91        let last_codeword = self.rounds.last().unwrap().codeword.clone();
92        let proof_item = ProofItem::FriCodeword(last_codeword);
93        self.proof_stream.enqueue(proof_item);
94    }
95
96    fn send_last_polynomial(&mut self) {
97        let last_codeword = &self.rounds.last().unwrap().codeword;
98        let last_polynomial = ArithmeticDomain::of_length(last_codeword.len())
99            .unwrap()
100            .interpolate(last_codeword);
101        let proof_item = ProofItem::FriPolynomial(last_polynomial);
102        self.proof_stream.enqueue(proof_item);
103    }
104
105    fn query(&mut self) -> ProverResult<()> {
106        self.sample_first_round_collinearity_check_indices();
107
108        let initial_a_indices = self.first_round_collinearity_check_indices.clone();
109        self.authentically_reveal_codeword_of_round_at_indices(0, &initial_a_indices)?;
110
111        let num_rounds_that_have_a_next_round = self.rounds.len() - 1;
112        for round_number in 0..num_rounds_that_have_a_next_round {
113            let b_indices = self.collinearity_check_b_indices_for_round(round_number);
114            self.authentically_reveal_codeword_of_round_at_indices(round_number, &b_indices)?;
115        }
116
117        Ok(())
118    }
119
120    fn sample_first_round_collinearity_check_indices(&mut self) {
121        let indices_upper_bound = self.first_round_domain.length;
122        self.first_round_collinearity_check_indices = self
123            .proof_stream
124            .sample_indices(indices_upper_bound, self.num_collinearity_checks);
125    }
126
127    fn collinearity_check_b_indices_for_round(&self, round_number: usize) -> Vec<usize> {
128        let domain_length = self.rounds[round_number].domain.length;
129        self.first_round_collinearity_check_indices
130            .iter()
131            .map(|&a_index| (a_index + domain_length / 2) % domain_length)
132            .collect()
133    }
134
135    fn authentically_reveal_codeword_of_round_at_indices(
136        &mut self,
137        round_number: usize,
138        indices: &[usize],
139    ) -> ProverResult<()> {
140        let codeword = &self.rounds[round_number].codeword;
141        let revealed_leaves = indices.iter().map(|&i| codeword[i]).collect_vec();
142
143        let merkle_tree = &self.rounds[round_number].merkle_tree;
144        let auth_structure = merkle_tree.authentication_structure(indices)?;
145
146        let fri_response = FriResponse {
147            auth_structure,
148            revealed_leaves,
149        };
150        let proof_item = ProofItem::FriResponse(fri_response);
151        self.proof_stream.enqueue(proof_item);
152        Ok(())
153    }
154}
155
156impl ProverRound {
157    fn new(domain: ArithmeticDomain, codeword: &[XFieldElement]) -> ProverResult<Self> {
158        debug_assert_eq!(domain.length, codeword.len());
159        let merkle_tree = Self::merkle_tree_from_codeword(codeword)?;
160        let round = Self {
161            domain,
162            codeword: codeword.to_vec(),
163            merkle_tree,
164        };
165        Ok(round)
166    }
167
168    fn merkle_tree_from_codeword(codeword: &[XFieldElement]) -> ProverResult<MerkleTree> {
169        let digests: Vec<_> = codeword.par_iter().map(|&xfe| xfe.into()).collect();
170        MerkleTree::par_new(&digests).map_err(FriProvingError::MerkleTreeError)
171    }
172
173    fn split_and_fold(&self, folding_challenge: XFieldElement) -> Vec<XFieldElement> {
174        let one = xfe!(1);
175        let two_inverse = xfe!(2).inverse();
176
177        let domain_points = self.domain.domain_values();
178        let domain_point_inverses = BFieldElement::batch_inversion(domain_points);
179
180        let n = self.codeword.len();
181        (0..n / 2)
182            .into_par_iter()
183            .map(|i| {
184                let scaled_offset_inv = folding_challenge * domain_point_inverses[i];
185                let left_summand = (one + scaled_offset_inv) * self.codeword[i];
186                let right_summand = (one - scaled_offset_inv) * self.codeword[n / 2 + i];
187                (left_summand + right_summand) * two_inverse
188            })
189            .collect()
190    }
191}
192
193#[derive(Debug, Eq, PartialEq)]
194struct FriVerifier<'stream> {
195    proof_stream: &'stream mut ProofStream,
196    rounds: Vec<VerifierRound>,
197    first_round_domain: ArithmeticDomain,
198    last_round_codeword: Vec<XFieldElement>,
199    last_round_polynomial: Polynomial<'static, XFieldElement>,
200    last_round_max_degree: usize,
201    num_rounds: usize,
202    num_collinearity_checks: usize,
203    first_round_collinearity_check_indices: Vec<usize>,
204}
205
206#[derive(Debug, Clone, Eq, PartialEq)]
207struct VerifierRound {
208    domain: ArithmeticDomain,
209    partial_codeword_a: Vec<XFieldElement>,
210    partial_codeword_b: Vec<XFieldElement>,
211    merkle_root: Digest,
212    folding_challenge: Option<XFieldElement>,
213}
214
215impl FriVerifier<'_> {
216    fn initialize(&mut self) -> VerifierResult<()> {
217        let domain = self.first_round_domain;
218        let first_round = self.construct_round_with_domain(domain)?;
219        self.rounds.push(first_round);
220
221        for _ in 0..self.num_rounds {
222            let previous_round = self.rounds.last().unwrap();
223            let domain = previous_round.domain.halve()?;
224            let next_round = self.construct_round_with_domain(domain)?;
225            self.rounds.push(next_round);
226        }
227
228        self.last_round_codeword = self.proof_stream.dequeue()?.try_into_fri_codeword()?;
229        self.last_round_polynomial = self.proof_stream.dequeue()?.try_into_fri_polynomial()?;
230        Ok(())
231    }
232
233    fn construct_round_with_domain(
234        &mut self,
235        domain: ArithmeticDomain,
236    ) -> VerifierResult<VerifierRound> {
237        let merkle_root = self.proof_stream.dequeue()?.try_into_merkle_root()?;
238        let folding_challenge = self
239            .need_more_folding_challenges()
240            .then(|| self.proof_stream.sample_scalars(1)[0]);
241
242        let verifier_round = VerifierRound {
243            domain,
244            partial_codeword_a: vec![],
245            partial_codeword_b: vec![],
246            merkle_root,
247            folding_challenge,
248        };
249        Ok(verifier_round)
250    }
251
252    fn need_more_folding_challenges(&self) -> bool {
253        if self.num_rounds == 0 {
254            return false;
255        }
256
257        let num_initialized_rounds = self.rounds.len();
258        let num_rounds_that_have_a_next_round = self.num_rounds - 1;
259        num_initialized_rounds <= num_rounds_that_have_a_next_round
260    }
261
262    fn compute_last_round_folded_partial_codeword(&mut self) -> VerifierResult<()> {
263        self.sample_first_round_collinearity_check_indices();
264        self.receive_authentic_partially_revealed_codewords()?;
265        self.successively_fold_partial_codeword_of_each_round();
266        Ok(())
267    }
268
269    fn sample_first_round_collinearity_check_indices(&mut self) {
270        let upper_bound = self.first_round_domain.length;
271        self.first_round_collinearity_check_indices = self
272            .proof_stream
273            .sample_indices(upper_bound, self.num_collinearity_checks);
274    }
275
276    fn receive_authentic_partially_revealed_codewords(&mut self) -> VerifierResult<()> {
277        let auth_structure = self.receive_partial_codeword_a_for_first_round()?;
278        self.authenticate_partial_codeword_a_for_first_round(auth_structure)?;
279
280        let num_rounds_that_have_a_next_round = self.rounds.len() - 1;
281        for round_number in 0..num_rounds_that_have_a_next_round {
282            let auth_structure = self.receive_partial_codeword_b_for_round(round_number)?;
283            self.authenticate_partial_codeword_b_for_round(round_number, auth_structure)?;
284        }
285        Ok(())
286    }
287
288    fn receive_partial_codeword_a_for_first_round(
289        &mut self,
290    ) -> VerifierResult<AuthenticationStructure> {
291        let fri_response = self.proof_stream.dequeue()?.try_into_fri_response()?;
292        let FriResponse {
293            auth_structure,
294            revealed_leaves,
295        } = fri_response;
296
297        self.assert_enough_leaves_were_received(&revealed_leaves)?;
298        self.rounds[0].partial_codeword_a = revealed_leaves;
299        Ok(auth_structure)
300    }
301
302    fn receive_partial_codeword_b_for_round(
303        &mut self,
304        round_number: usize,
305    ) -> VerifierResult<AuthenticationStructure> {
306        let fri_response = self.proof_stream.dequeue()?.try_into_fri_response()?;
307        let FriResponse {
308            auth_structure,
309            revealed_leaves,
310        } = fri_response;
311
312        self.assert_enough_leaves_were_received(&revealed_leaves)?;
313        self.rounds[round_number].partial_codeword_b = revealed_leaves;
314        Ok(auth_structure)
315    }
316
317    fn assert_enough_leaves_were_received(&self, leaves: &[XFieldElement]) -> VerifierResult<()> {
318        match self.num_collinearity_checks == leaves.len() {
319            true => Ok(()),
320            false => Err(FriValidationError::IncorrectNumberOfRevealedLeaves),
321        }
322    }
323
324    fn authenticate_partial_codeword_a_for_first_round(
325        &self,
326        authentication_structure: AuthenticationStructure,
327    ) -> VerifierResult<()> {
328        let round = &self.rounds[0];
329        let revealed_leaves = &round.partial_codeword_a;
330        let revealed_digests = codeword_as_digests(revealed_leaves);
331
332        let leaf_indices = self.collinearity_check_a_indices_for_round(0);
333        let indexed_leafs = leaf_indices.into_iter().zip_eq(revealed_digests).collect();
334
335        let inclusion_proof = MerkleTreeInclusionProof {
336            tree_height: round.merkle_tree_height(),
337            indexed_leafs,
338            authentication_structure,
339        };
340        match inclusion_proof.verify(round.merkle_root) {
341            true => Ok(()),
342            false => Err(FriValidationError::BadMerkleAuthenticationPath),
343        }
344    }
345
346    fn authenticate_partial_codeword_b_for_round(
347        &self,
348        round_number: usize,
349        authentication_structure: AuthenticationStructure,
350    ) -> VerifierResult<()> {
351        let round = &self.rounds[round_number];
352        let revealed_leaves = &round.partial_codeword_b;
353        let revealed_digests = codeword_as_digests(revealed_leaves);
354
355        let leaf_indices = self.collinearity_check_b_indices_for_round(round_number);
356        let indexed_leafs = leaf_indices.into_iter().zip_eq(revealed_digests).collect();
357
358        let inclusion_proof = MerkleTreeInclusionProof {
359            tree_height: round.merkle_tree_height(),
360            indexed_leafs,
361            authentication_structure,
362        };
363        match inclusion_proof.verify(round.merkle_root) {
364            true => Ok(()),
365            false => Err(FriValidationError::BadMerkleAuthenticationPath),
366        }
367    }
368
369    fn successively_fold_partial_codeword_of_each_round(&mut self) {
370        let num_rounds_that_have_a_next_round = self.rounds.len() - 1;
371        for round_number in 0..num_rounds_that_have_a_next_round {
372            let folded_partial_codeword = self.fold_partial_codeword_of_round(round_number);
373            let next_round = &mut self.rounds[round_number + 1];
374            next_round.partial_codeword_a = folded_partial_codeword;
375        }
376    }
377
378    fn fold_partial_codeword_of_round(&self, round_number: usize) -> Vec<XFieldElement> {
379        let round = &self.rounds[round_number];
380        let a_indices = self.collinearity_check_a_indices_for_round(round_number);
381        let b_indices = self.collinearity_check_b_indices_for_round(round_number);
382        let partial_codeword_a = &round.partial_codeword_a;
383        let partial_codeword_b = &round.partial_codeword_b;
384        let domain = round.domain;
385        let folding_challenge = round.folding_challenge.unwrap();
386
387        (0..self.num_collinearity_checks)
388            .map(|i| {
389                let point_a_x = domain.domain_value(a_indices[i] as u32).lift();
390                let point_b_x = domain.domain_value(b_indices[i] as u32).lift();
391                let point_a = (point_a_x, partial_codeword_a[i]);
392                let point_b = (point_b_x, partial_codeword_b[i]);
393                Polynomial::get_colinear_y(point_a, point_b, folding_challenge)
394            })
395            .collect()
396    }
397
398    fn collinearity_check_a_indices_for_round(&self, round_number: usize) -> Vec<usize> {
399        let domain_length = self.rounds[round_number].domain.length;
400        let a_offset = 0;
401        self.collinearity_check_indices_with_offset_and_modulus(a_offset, domain_length)
402    }
403
404    fn collinearity_check_b_indices_for_round(&self, round_number: usize) -> Vec<usize> {
405        let domain_length = self.rounds[round_number].domain.length;
406        let b_offset = domain_length / 2;
407        self.collinearity_check_indices_with_offset_and_modulus(b_offset, domain_length)
408    }
409
410    fn collinearity_check_indices_with_offset_and_modulus(
411        &self,
412        offset: usize,
413        modulus: usize,
414    ) -> Vec<usize> {
415        self.first_round_collinearity_check_indices
416            .iter()
417            .map(|&i| (i + offset) % modulus)
418            .collect()
419    }
420
421    fn authenticate_last_round_codeword(&mut self) -> VerifierResult<()> {
422        self.assert_last_round_codeword_matches_last_round_commitment()?;
423        self.assert_last_round_codeword_agrees_with_last_round_folded_codeword()?;
424        self.assert_last_round_codeword_corresponds_to_low_degree_polynomial()
425    }
426
427    fn assert_last_round_codeword_matches_last_round_commitment(&self) -> VerifierResult<()> {
428        match self.last_round_merkle_root() == self.last_round_codeword_merkle_root()? {
429            true => Ok(()),
430            false => Err(FriValidationError::BadMerkleRootForLastCodeword),
431        }
432    }
433
434    fn last_round_codeword_merkle_root(&self) -> VerifierResult<Digest> {
435        let codeword_digests = codeword_as_digests(&self.last_round_codeword);
436        let merkle_tree = MerkleTree::sequential_new(&codeword_digests)
437            .map_err(FriValidationError::MerkleTreeError)?;
438
439        Ok(merkle_tree.root())
440    }
441
442    fn last_round_merkle_root(&self) -> Digest {
443        self.rounds.last().unwrap().merkle_root
444    }
445
446    fn assert_last_round_codeword_agrees_with_last_round_folded_codeword(
447        &self,
448    ) -> VerifierResult<()> {
449        let partial_folded_codeword = self.folded_last_round_codeword_at_indices_a();
450        let partial_received_codeword = self.received_last_round_codeword_at_indices_a();
451        match partial_received_codeword == partial_folded_codeword {
452            true => Ok(()),
453            false => Err(FriValidationError::LastCodewordMismatch),
454        }
455    }
456
457    fn folded_last_round_codeword_at_indices_a(&self) -> &[XFieldElement] {
458        &self.rounds.last().unwrap().partial_codeword_a
459    }
460
461    fn received_last_round_codeword_at_indices_a(&self) -> Vec<XFieldElement> {
462        let last_round_number = self.rounds.len() - 1;
463        let last_round_indices_a = self.collinearity_check_a_indices_for_round(last_round_number);
464        last_round_indices_a
465            .iter()
466            .map(|&last_round_index_a| self.last_round_codeword[last_round_index_a])
467            .collect()
468    }
469
470    fn assert_last_round_codeword_corresponds_to_low_degree_polynomial(
471        &mut self,
472    ) -> VerifierResult<()> {
473        if self.last_round_polynomial.degree() > self.last_round_max_degree.try_into().unwrap() {
474            return Err(FriValidationError::LastRoundPolynomialHasTooHighDegree);
475        }
476
477        let indeterminate = self.proof_stream.sample_scalars(1)[0];
478        let horner_evaluation = self
479            .last_round_polynomial
480            .evaluate_in_same_field(indeterminate);
481        let barycentric_evaluation = barycentric_evaluate(&self.last_round_codeword, indeterminate);
482        if horner_evaluation != barycentric_evaluation {
483            return Err(FriValidationError::LastRoundPolynomialEvaluationMismatch);
484        }
485
486        Ok(())
487    }
488
489    fn first_round_partially_revealed_codeword(&self) -> Vec<(usize, XFieldElement)> {
490        self.collinearity_check_a_indices_for_round(0)
491            .into_iter()
492            .zip_eq(self.rounds[0].partial_codeword_a.clone())
493            .collect()
494    }
495}
496
497impl VerifierRound {
498    fn merkle_tree_height(&self) -> u32 {
499        self.domain.length.ilog2()
500    }
501}
502
503impl Fri {
504    pub fn new(
505        domain: ArithmeticDomain,
506        expansion_factor: usize,
507        num_collinearity_checks: usize,
508    ) -> SetupResult<Self> {
509        match expansion_factor {
510            ef if ef <= 1 => return Err(FriSetupError::ExpansionFactorTooSmall),
511            ef if !ef.is_power_of_two() => return Err(FriSetupError::ExpansionFactorUnsupported),
512            ef if ef > domain.length => return Err(FriSetupError::ExpansionFactorMismatch),
513            _ => (),
514        };
515
516        Ok(Self {
517            expansion_factor,
518            num_collinearity_checks,
519            domain,
520        })
521    }
522
523    /// Create a FRI proof and return a-indices of revealed elements of round 0.
524    pub fn prove(
525        &self,
526        codeword: &[XFieldElement],
527        proof_stream: &mut ProofStream,
528    ) -> ProverResult<Vec<usize>> {
529        let mut prover = self.prover(proof_stream);
530
531        prover.commit(codeword)?;
532        prover.query()?;
533
534        // Sample one XFieldElement from Fiat-Shamir and then throw it away.
535        // This scalar is the indeterminate for the low degree test using the
536        // barycentric evaluation formula. This indeterminate is used only by
537        // the verifier, but it is important to modify the sponge state the same
538        // way.
539        prover.proof_stream.sample_scalars(1);
540
541        Ok(prover.first_round_collinearity_check_indices)
542    }
543
544    fn prover<'stream>(
545        &'stream self,
546        proof_stream: &'stream mut ProofStream,
547    ) -> FriProver<'stream> {
548        FriProver {
549            proof_stream,
550            rounds: vec![],
551            first_round_domain: self.domain,
552            num_rounds: self.num_rounds(),
553            num_collinearity_checks: self.num_collinearity_checks,
554            first_round_collinearity_check_indices: vec![],
555        }
556    }
557
558    /// Verify low-degreeness of the polynomial on the proof stream.
559    /// Returns the indices and revealed elements of the codeword at the top
560    /// level of the FRI proof.
561    pub fn verify(
562        &self,
563        proof_stream: &mut ProofStream,
564    ) -> VerifierResult<Vec<(usize, XFieldElement)>> {
565        profiler!(start "init");
566        let mut verifier = self.verifier(proof_stream);
567        verifier.initialize()?;
568        profiler!(stop "init");
569
570        profiler!(start "fold all rounds");
571        verifier.compute_last_round_folded_partial_codeword()?;
572        profiler!(stop "fold all rounds");
573
574        profiler!(start "authenticate last round codeword");
575        verifier.authenticate_last_round_codeword()?;
576        profiler!(stop "authenticate last round codeword");
577
578        Ok(verifier.first_round_partially_revealed_codeword())
579    }
580
581    fn verifier<'stream>(
582        &'stream self,
583        proof_stream: &'stream mut ProofStream,
584    ) -> FriVerifier<'stream> {
585        FriVerifier {
586            proof_stream,
587            rounds: vec![],
588            first_round_domain: self.domain,
589            last_round_codeword: vec![],
590            last_round_polynomial: Polynomial::zero(),
591            last_round_max_degree: self.last_round_max_degree(),
592            num_rounds: self.num_rounds(),
593            num_collinearity_checks: self.num_collinearity_checks,
594            first_round_collinearity_check_indices: vec![],
595        }
596    }
597
598    pub fn num_rounds(&self) -> usize {
599        let first_round_code_dimension = self.first_round_max_degree() + 1;
600        let max_num_rounds = first_round_code_dimension.next_power_of_two().ilog2();
601
602        // Skip rounds for which Merkle tree verification cost exceeds
603        // arithmetic cost, because more than half the codeword's locations are
604        // queried.
605        let num_rounds_checking_all_locations = self.num_collinearity_checks.ilog2();
606        let num_rounds_checking_most_locations = num_rounds_checking_all_locations + 1;
607
608        let num_rounds = max_num_rounds.saturating_sub(num_rounds_checking_most_locations);
609        num_rounds.try_into().unwrap()
610    }
611
612    pub fn last_round_max_degree(&self) -> usize {
613        self.first_round_max_degree() >> self.num_rounds()
614    }
615
616    pub fn first_round_max_degree(&self) -> usize {
617        assert!(self.domain.length >= self.expansion_factor);
618        (self.domain.length / self.expansion_factor) - 1
619    }
620}
621
622fn codeword_as_digests(codeword: &[XFieldElement]) -> Vec<Digest> {
623    codeword.iter().map(|&xfe| xfe.into()).collect()
624}
625
626#[cfg(test)]
627#[cfg_attr(coverage_nightly, coverage(off))]
628mod tests {
629    use std::cmp::max;
630    use std::cmp::min;
631
632    use assert2::assert;
633    use assert2::let_assert;
634    use itertools::Itertools;
635    use proptest::prelude::*;
636    use proptest_arbitrary_interop::arb;
637    use rand::prelude::*;
638    use test_strategy::proptest;
639
640    use crate::error::FriValidationError;
641    use crate::shared_tests::arbitrary_polynomial;
642    use crate::shared_tests::arbitrary_polynomial_of_degree;
643
644    use super::*;
645
646    /// A type alias exclusive to this test module.
647    type XfePoly = Polynomial<'static, XFieldElement>;
648
649    prop_compose! {
650        fn arbitrary_fri_supporting_degree(min_supported_degree: i64)(
651            log_2_expansion_factor in 1_usize..=8
652        )(
653            log_2_expansion_factor in Just(log_2_expansion_factor),
654            log_2_domain_length in log_2_expansion_factor..=18,
655            num_collinearity_checks in 1_usize..=320,
656            offset in arb(),
657        ) -> Fri {
658            let expansion_factor = (1 << log_2_expansion_factor) as usize;
659            let sampled_domain_length = (1 << log_2_domain_length) as usize;
660
661            let min_domain_length = match min_supported_degree {
662                d if d <= -1 => 0,
663                _ => (min_supported_degree as u64 + 1).next_power_of_two() as usize,
664            };
665            let min_expanded_domain_length = min_domain_length * expansion_factor;
666            let domain_length = max(sampled_domain_length, min_expanded_domain_length);
667
668            let maybe_domain = ArithmeticDomain::of_length(domain_length);
669            let fri_domain = maybe_domain.unwrap().with_offset(offset);
670
671            Fri::new(fri_domain, expansion_factor, num_collinearity_checks).unwrap()
672        }
673    }
674
675    impl Arbitrary for Fri {
676        type Parameters = ();
677
678        fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
679            arbitrary_fri_supporting_degree(-1).boxed()
680        }
681
682        type Strategy = BoxedStrategy<Self>;
683    }
684
685    #[proptest]
686    fn sample_indices(fri: Fri, #[strategy(arb())] initial_absorb: [BFieldElement; tip5::RATE]) {
687        let mut sponge = Tip5::init();
688        sponge.absorb(initial_absorb);
689
690        // todo: Figure out by how much to oversample for the given parameters.
691        let oversampling_summand = 1 << 13;
692        let num_indices_to_sample = fri.num_collinearity_checks + oversampling_summand;
693        let indices = sponge.sample_indices(fri.domain.length as u32, num_indices_to_sample);
694        let num_unique_indices = indices.iter().unique().count();
695
696        let required_unique_indices = min(fri.domain.length, fri.num_collinearity_checks);
697        prop_assert!(num_unique_indices >= required_unique_indices);
698    }
699
700    #[proptest]
701    fn num_rounds_are_reasonable(fri: Fri) {
702        let expected_last_round_max_degree = fri.first_round_max_degree() >> fri.num_rounds();
703        prop_assert_eq!(expected_last_round_max_degree, fri.last_round_max_degree());
704        if fri.num_rounds() > 0 {
705            prop_assert!(fri.num_collinearity_checks <= expected_last_round_max_degree);
706            prop_assert!(expected_last_round_max_degree < 2 * fri.num_collinearity_checks);
707        }
708    }
709
710    #[proptest(cases = 20)]
711    fn prove_and_verify_low_degree_of_twice_cubing_plus_one(
712        #[strategy(arbitrary_fri_supporting_degree(3))] fri: Fri,
713    ) {
714        let coefficients = [1, 0, 0, 2].map(|c| c.into()).to_vec();
715        let polynomial = Polynomial::new(coefficients);
716        let codeword = fri.domain.evaluate(&polynomial);
717
718        let mut proof_stream = ProofStream::new();
719        fri.prove(&codeword, &mut proof_stream).unwrap();
720
721        let mut proof_stream = prepare_proof_stream_for_verification(proof_stream);
722        let verdict = fri.verify(&mut proof_stream);
723        prop_assert!(verdict.is_ok());
724    }
725
726    #[proptest(cases = 50)]
727    fn prove_and_verify_low_degree_polynomial(
728        fri: Fri,
729        #[strategy(-1_i64..=#fri.first_round_max_degree() as i64)] _degree: i64,
730        #[strategy(arbitrary_polynomial_of_degree(#_degree))] polynomial: XfePoly,
731    ) {
732        debug_assert!(polynomial.degree() <= fri.first_round_max_degree() as isize);
733        let codeword = fri.domain.evaluate(&polynomial);
734        let mut proof_stream = ProofStream::new();
735        fri.prove(&codeword, &mut proof_stream).unwrap();
736
737        let mut proof_stream = prepare_proof_stream_for_verification(proof_stream);
738        let verdict = fri.verify(&mut proof_stream);
739        prop_assert!(verdict.is_ok());
740    }
741
742    #[proptest(cases = 50)]
743    fn prove_and_fail_to_verify_high_degree_polynomial(
744        fri: Fri,
745        #[strategy(Just((1 + #fri.first_round_max_degree()) as i64))] _too_high_degree: i64,
746        #[strategy(#_too_high_degree..2 * #_too_high_degree)] _degree: i64,
747        #[strategy(arbitrary_polynomial_of_degree(#_degree))] polynomial: XfePoly,
748    ) {
749        debug_assert!(polynomial.degree() > fri.first_round_max_degree() as isize);
750        let codeword = fri.domain.evaluate(&polynomial);
751        let mut proof_stream = ProofStream::new();
752        fri.prove(&codeword, &mut proof_stream).unwrap();
753
754        let mut proof_stream = prepare_proof_stream_for_verification(proof_stream);
755        let verdict = fri.verify(&mut proof_stream);
756        prop_assert!(verdict.is_err());
757    }
758
759    #[test]
760    fn smallest_possible_fri_has_no_rounds() {
761        assert_eq!(0, smallest_fri().num_rounds());
762    }
763
764    #[test]
765    fn smallest_possible_fri_can_only_verify_constant_polynomials() {
766        assert_eq!(0, smallest_fri().first_round_max_degree());
767    }
768
769    fn smallest_fri() -> Fri {
770        let domain = ArithmeticDomain::of_length(2).unwrap();
771        let expansion_factor = 2;
772        let num_collinearity_checks = 1;
773        Fri::new(domain, expansion_factor, num_collinearity_checks).unwrap()
774    }
775
776    #[test]
777    fn too_small_expansion_factor_is_rejected() {
778        let domain = ArithmeticDomain::of_length(2).unwrap();
779        let expansion_factor = 1;
780        let num_collinearity_checks = 1;
781        let err = Fri::new(domain, expansion_factor, num_collinearity_checks).unwrap_err();
782        assert_eq!(FriSetupError::ExpansionFactorTooSmall, err);
783    }
784
785    #[proptest]
786    fn expansion_factor_not_a_power_of_two_is_rejected(
787        #[strategy(2_usize..(1 << 32))]
788        #[filter(!#expansion_factor.is_power_of_two())]
789        expansion_factor: usize,
790    ) {
791        let largest_supported_domain_size = 1 << 32;
792        let domain = ArithmeticDomain::of_length(largest_supported_domain_size).unwrap();
793        let num_collinearity_checks = 1;
794        let err = Fri::new(domain, expansion_factor, num_collinearity_checks).unwrap_err();
795        prop_assert_eq!(FriSetupError::ExpansionFactorUnsupported, err);
796    }
797
798    #[proptest]
799    fn domain_size_smaller_than_expansion_factor_is_rejected(
800        #[strategy(1_usize..32)] log_2_expansion_factor: usize,
801        #[strategy(..#log_2_expansion_factor)] log_2_domain_length: usize,
802    ) {
803        let expansion_factor = 1 << log_2_expansion_factor;
804        let domain_length = 1 << log_2_domain_length;
805        let domain = ArithmeticDomain::of_length(domain_length).unwrap();
806        let num_collinearity_checks = 1;
807        let err = Fri::new(domain, expansion_factor, num_collinearity_checks).unwrap_err();
808        prop_assert_eq!(FriSetupError::ExpansionFactorMismatch, err);
809    }
810
811    // todo: add test fuzzing proof_stream
812
813    #[proptest(cases = 50)]
814    fn serialization(
815        fri: Fri,
816        #[strategy(-1_i64..=#fri.first_round_max_degree() as i64)] _degree: i64,
817        #[strategy(arbitrary_polynomial_of_degree(#_degree))] polynomial: XfePoly,
818    ) {
819        let codeword = fri.domain.evaluate(&polynomial);
820        let mut prover_proof_stream = ProofStream::new();
821        fri.prove(&codeword, &mut prover_proof_stream).unwrap();
822
823        let proof = (&prover_proof_stream).into();
824        let verifier_proof_stream = ProofStream::try_from(&proof).unwrap();
825
826        let prover_items = prover_proof_stream.items.iter();
827        let verifier_items = verifier_proof_stream.items.iter();
828        for (prover_item, verifier_item) in prover_items.zip_eq(verifier_items) {
829            use ProofItem as PI;
830            match (prover_item, verifier_item) {
831                (PI::MerkleRoot(p), PI::MerkleRoot(v)) => prop_assert_eq!(p, v),
832                (PI::FriResponse(p), PI::FriResponse(v)) => prop_assert_eq!(p, v),
833                (PI::FriCodeword(p), PI::FriCodeword(v)) => prop_assert_eq!(p, v),
834                (PI::FriPolynomial(p), PI::FriPolynomial(v)) => prop_assert_eq!(p, v),
835                _ => panic!("Unknown items.\nProver: {prover_item:?}\nVerifier: {verifier_item:?}"),
836            }
837        }
838    }
839
840    #[proptest(cases = 50)]
841    fn last_round_codeword_unequal_to_last_round_commitment_results_in_validation_failure(
842        fri: Fri,
843        #[strategy(arbitrary_polynomial())] polynomial: XfePoly,
844        rng_seed: u64,
845    ) {
846        let codeword = fri.domain.evaluate(&polynomial);
847        let mut proof_stream = ProofStream::new();
848        fri.prove(&codeword, &mut proof_stream).unwrap();
849
850        let proof_stream = prepare_proof_stream_for_verification(proof_stream);
851        let mut proof_stream =
852            modify_last_round_codeword_in_proof_stream_using_seed(proof_stream, rng_seed);
853
854        let verdict = fri.verify(&mut proof_stream);
855        let err = verdict.unwrap_err();
856        let FriValidationError::BadMerkleRootForLastCodeword = err else {
857            return Err(TestCaseError::Fail("validation must fail".into()));
858        };
859    }
860
861    #[must_use]
862    fn prepare_proof_stream_for_verification(mut proof_stream: ProofStream) -> ProofStream {
863        proof_stream.items_index = 0;
864        proof_stream.sponge = Tip5::init();
865        proof_stream
866    }
867
868    #[must_use]
869    fn modify_last_round_codeword_in_proof_stream_using_seed(
870        mut proof_stream: ProofStream,
871        seed: u64,
872    ) -> ProofStream {
873        let mut proof_items = proof_stream.items.iter_mut();
874        let last_round_codeword = proof_items.find_map(fri_codeword_filter()).unwrap();
875
876        let mut rng = StdRng::seed_from_u64(seed);
877        let modification_index = rng.random_range(0..last_round_codeword.len());
878        let replacement_element = rng.random();
879
880        last_round_codeword[modification_index] = replacement_element;
881        proof_stream
882    }
883
884    fn fri_codeword_filter() -> fn(&mut ProofItem) -> Option<&mut Vec<XFieldElement>> {
885        |proof_item| match proof_item {
886            ProofItem::FriCodeword(codeword) => Some(codeword),
887            _ => None,
888        }
889    }
890
891    #[proptest(cases = 50)]
892    fn revealing_wrong_number_of_leaves_results_in_validation_failure(
893        fri: Fri,
894        #[strategy(arbitrary_polynomial())] polynomial: XfePoly,
895        rng_seed: u64,
896    ) {
897        let codeword = fri.domain.evaluate(&polynomial);
898        let mut proof_stream = ProofStream::new();
899        fri.prove(&codeword, &mut proof_stream).unwrap();
900
901        let proof_stream = prepare_proof_stream_for_verification(proof_stream);
902        let mut proof_stream =
903            change_size_of_some_fri_response_in_proof_stream_using_seed(proof_stream, rng_seed);
904
905        let verdict = fri.verify(&mut proof_stream);
906        let err = verdict.unwrap_err();
907        let FriValidationError::IncorrectNumberOfRevealedLeaves = err else {
908            return Err(TestCaseError::Fail("validation must fail".into()));
909        };
910    }
911
912    #[must_use]
913    fn change_size_of_some_fri_response_in_proof_stream_using_seed(
914        mut proof_stream: ProofStream,
915        seed: u64,
916    ) -> ProofStream {
917        let proof_items = proof_stream.items.iter_mut();
918        let fri_responses = proof_items.filter_map(fri_response_filter());
919
920        let mut rng = StdRng::seed_from_u64(seed);
921        let fri_response = fri_responses.choose(&mut rng).unwrap();
922        let revealed_leaves = &mut fri_response.revealed_leaves;
923        let modification_index = rng.random_range(0..revealed_leaves.len());
924        if rng.random() {
925            revealed_leaves.remove(modification_index);
926        } else {
927            revealed_leaves.insert(modification_index, rng.random());
928        };
929
930        proof_stream
931    }
932
933    fn fri_response_filter() -> fn(&mut ProofItem) -> Option<&mut super::FriResponse> {
934        |proof_item| match proof_item {
935            ProofItem::FriResponse(fri_response) => Some(fri_response),
936            _ => None,
937        }
938    }
939
940    #[proptest(cases = 50)]
941    fn incorrect_authentication_structure_results_in_validation_failure(
942        fri: Fri,
943        #[strategy(arbitrary_polynomial())] polynomial: XfePoly,
944        rng_seed: u64,
945    ) {
946        let all_authentication_structures_are_trivial =
947            fri.num_collinearity_checks >= fri.domain.length;
948        if all_authentication_structures_are_trivial {
949            return Ok(());
950        }
951
952        let codeword = fri.domain.evaluate(&polynomial);
953        let mut proof_stream = ProofStream::new();
954        fri.prove(&codeword, &mut proof_stream).unwrap();
955
956        let proof_stream = prepare_proof_stream_for_verification(proof_stream);
957        let mut proof_stream =
958            modify_some_auth_structure_in_proof_stream_using_seed(proof_stream, rng_seed);
959
960        let verdict = fri.verify(&mut proof_stream);
961        let_assert!(Err(err) = verdict);
962        assert!(let FriValidationError::BadMerkleAuthenticationPath = err);
963    }
964
965    #[must_use]
966    fn modify_some_auth_structure_in_proof_stream_using_seed(
967        mut proof_stream: ProofStream,
968        seed: u64,
969    ) -> ProofStream {
970        let proof_items = proof_stream.items.iter_mut();
971        let auth_structures = proof_items.filter_map(non_trivial_auth_structure_filter());
972
973        let mut rng = StdRng::seed_from_u64(seed);
974        let auth_structure = auth_structures.choose(&mut rng).unwrap();
975        let modification_index = rng.random_range(0..auth_structure.len());
976        match rng.random_range(0..3) {
977            0 => _ = auth_structure.remove(modification_index),
978            1 => auth_structure.insert(modification_index, rng.random()),
979            2 => auth_structure[modification_index] = rng.random(),
980            _ => unreachable!(),
981        };
982
983        proof_stream
984    }
985
986    fn non_trivial_auth_structure_filter()
987    -> fn(&mut ProofItem) -> Option<&mut AuthenticationStructure> {
988        |proof_item| match proof_item {
989            ProofItem::FriResponse(fri_response) if fri_response.auth_structure.is_empty() => None,
990            ProofItem::FriResponse(fri_response) => Some(&mut fri_response.auth_structure),
991            _ => None,
992        }
993    }
994
995    #[proptest]
996    fn incorrect_last_round_polynomial_results_in_verification_failure(
997        fri: Fri,
998        #[strategy(arbitrary_polynomial())] fri_polynomial: XfePoly,
999        #[strategy(arbitrary_polynomial_of_degree(#fri.last_round_max_degree() as i64))]
1000        incorrect_polynomial: XfePoly,
1001    ) {
1002        let codeword = fri.domain.evaluate(&fri_polynomial);
1003        let mut proof_stream = ProofStream::new();
1004        fri.prove(&codeword, &mut proof_stream).unwrap();
1005
1006        let mut proof_stream = prepare_proof_stream_for_verification(proof_stream);
1007        proof_stream.items.iter_mut().for_each(|item| {
1008            if let ProofItem::FriPolynomial(polynomial) = item {
1009                *polynomial = incorrect_polynomial.clone();
1010            }
1011        });
1012
1013        let verdict = fri.verify(&mut proof_stream);
1014        let_assert!(Err(err) = verdict);
1015        assert!(let FriValidationError::LastRoundPolynomialEvaluationMismatch = err);
1016    }
1017
1018    #[proptest]
1019    fn codeword_corresponding_to_high_degree_polynomial_results_in_verification_failure(
1020        fri: Fri,
1021        #[strategy(Just(#fri.first_round_max_degree() as i64 + 1))] _min_fail_deg: i64,
1022        #[strategy(#_min_fail_deg..2 * #_min_fail_deg)] _degree: i64,
1023        #[strategy(arbitrary_polynomial_of_degree(#_degree))] poly: XfePoly,
1024    ) {
1025        let codeword = fri.domain.evaluate(&poly);
1026        let mut proof_stream = ProofStream::new();
1027        fri.prove(&codeword, &mut proof_stream).unwrap();
1028
1029        let mut proof_stream = prepare_proof_stream_for_verification(proof_stream);
1030        let verdict = fri.verify(&mut proof_stream);
1031        let_assert!(Err(err) = verdict);
1032        assert!(let FriValidationError::LastRoundPolynomialHasTooHighDegree = err);
1033    }
1034
1035    #[proptest]
1036    fn verifying_arbitrary_proof_does_not_panic(
1037        fri: Fri,
1038        #[strategy(arb())] mut proof_stream: ProofStream,
1039    ) {
1040        let _verdict = fri.verify(&mut proof_stream);
1041    }
1042}