// tasm_lib/verifier/fri/number_of_rounds.rs

1use triton_vm::prelude::*;
2
3use crate::field;
4use crate::prelude::*;
5use crate::verifier::fri::verify::FriVerify;
6
/// TASM snippet that computes the number of FRI rounds for a `FriVerify` struct stored in
/// memory.
///
/// ```text
/// BEFORE: _ *fri_verify
/// AFTER:  _ num_rounds
/// ```
///
/// The result is
/// `floor(log2(domain_length / expansion_factor))`
/// saturating-minus
/// `floor(log2(num_collinearity_checks)) + 1`,
/// i.e. the same value as `FriVerify::num_rounds`, which serves as this snippet's Rust shadow.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub struct NumberOfRounds;
9
10impl BasicSnippet for NumberOfRounds {
11    fn inputs(&self) -> Vec<(DataType, String)> {
12        vec![(DataType::VoidPointer, "*fri".to_string())]
13    }
14
15    fn outputs(&self) -> Vec<(DataType, String)> {
16        vec![(DataType::U32, "num_rounds".to_string())]
17    }
18
19    fn entrypoint(&self) -> String {
20        "tasmlib_verifier_fri_number_of_rounds".to_string()
21    }
22
23    fn code(&self, _library: &mut Library) -> Vec<LabelledInstruction> {
24        let entrypoint = self.entrypoint();
25        let domain_length = field!(FriVerify::domain_length);
26        let expansion_factor = field!(FriVerify::expansion_factor);
27        let num_collinearity_checks = field!(FriVerify::num_collinearity_checks);
28
29        triton_asm! {
30            // BEFORE: _ *fri_verify
31            // AFTER:  _  num_rounds
32            {entrypoint}:
33                // calculate number of rounds
34                dup 0 {&domain_length}      // _ *fri_verify *domain_length
35                read_mem 1 pop 1            // _ *fri_verify domain_length
36                hint domain_length = stack[0]
37
38                dup 1 {&expansion_factor}   // _ *fri_verify domain_length *expansion_factor
39                read_mem 1 pop 1            // _ *fri_verify domain_length expansion_factor
40                hint expansion_factor = stack[0]
41
42                swap 1 div_mod pop 1        // _ *fri_verify first_round_code_dimension
43                log_2_floor                 // _ *fri_verify max_num_rounds
44                hint max_num_rounds = stack[0]
45
46                dup 1 {&num_collinearity_checks}
47                read_mem 1 pop 1            // _ *fri_verify max_num_rounds num_collinearity_checks
48                hint num_collinearity_checks = stack[0]
49
50                log_2_floor push 1 add      // _ *fri_verify max_num_rounds num_rounds_checking_most_locations
51
52                dup 1 dup 1 lt              // _ *fri_verify max_num_rounds num_rounds_checking_most_locations (num_rounds_checking_most_locations<max_num_rounds)
53                swap 2 push -1 mul add      // _ *fri_verify (num_rounds_checking_most_locations<max_num_rounds) num_rounds_checking_most_locations-max_num_rounds
54                mul push -1 mul             // _ *fri_verify if(num_rounds_checking_most_locations<max_num_rounds){max_num_rounds-num_rounds_checking_most_locations}else{0}
55                                            // _ *fri_verify num_rounds
56                hint num_rounds = stack[0]
57
58                swap 1 pop 1                // _ num_rounds
59                return
60        }
61    }
62}
63
#[cfg(test)]
mod tests {
    use triton_vm::arithmetic_domain::ArithmeticDomain;
    use triton_vm::fri::Fri;
    use twenty_first::math::traits::PrimitiveRootOfUnity;

    use super::*;
    use crate::memory::dyn_malloc::DYN_MALLOC_ADDRESS;
    use crate::test_prelude::*;

    /// Largest `log2(domain_length)` exercised by these tests (also the worst benchmark case).
    const MAX_LOG2_DOMAIN_LENGTH: u32 = 22;

    /// Builds a `FriVerify` with expansion factor `2^log2_expansion_factor` and domain length
    /// `2^log2_domain_length`, using `ceil(160 / log2_expansion_factor)` collinearity checks
    /// and domain offset 7.
    ///
    /// # Panics
    ///
    /// Panics if `log2_expansion_factor` is 0 (integer division by zero), or if no primitive
    /// root of unity of order `2^log2_domain_length` is available.
    fn fri_verify_for(log2_expansion_factor: u32, log2_domain_length: u32) -> FriVerify {
        let domain_length: u32 = 1 << log2_domain_length;
        FriVerify {
            expansion_factor: 1 << log2_expansion_factor,
            // Integer form of `ceil(160 / log2_expansion_factor)`; exact, unlike a float detour.
            num_collinearity_checks: 160_u32.div_ceil(log2_expansion_factor),
            domain_length,
            domain_offset: BFieldElement::new(7),
            domain_generator: BFieldElement::primitive_root_of_unity(domain_length.into())
                .unwrap(),
        }
    }

    impl Function for NumberOfRounds {
        fn rust_shadow(
            &self,
            stack: &mut Vec<BFieldElement>,
            memory: &mut HashMap<BFieldElement, BFieldElement>,
        ) {
            let fri_verify = FriVerify::decode_from_memory(memory, stack.pop().unwrap()).unwrap();
            stack.push(BFieldElement::new(fri_verify.num_rounds() as u64));
        }

        fn pseudorandom_initial_state(
            &self,
            seed: [u8; 32],
            bench_case: Option<BenchmarkCase>,
        ) -> FunctionInitialState {
            let mut rng = StdRng::from_seed(seed);
            let log2_expansion_factor = rng.random_range(1..16);
            let log2_domain_length = match bench_case {
                Some(BenchmarkCase::CommonCase) => 17,
                Some(BenchmarkCase::WorstCase) => MAX_LOG2_DOMAIN_LENGTH,
                None => rng.random_range(log2_expansion_factor..=MAX_LOG2_DOMAIN_LENGTH),
            };
            let fri_verify = fri_verify_for(log2_expansion_factor, log2_domain_length);

            let mut stack = self.init_stack_for_isolated_run();
            let mut memory: HashMap<BFieldElement, BFieldElement> = HashMap::new();

            let address = DYN_MALLOC_ADDRESS;
            encode_to_memory(&mut memory, address, &fri_verify);
            stack.push(address);

            FunctionInitialState { stack, memory }
        }
    }

    #[test]
    fn test_shadow() {
        ShadowedFunction::new(NumberOfRounds).test()
    }

    /// Compares `FriVerify::num_rounds` with Triton VM's `Fri::num_rounds` for every parameter
    /// combination `pseudorandom_initial_state` can produce, deterministically and exhaustively.
    #[test]
    fn shadow_agrees_with_canon() {
        for log2_expansion_factor in 1..16 {
            for log2_domain_length in log2_expansion_factor..=MAX_LOG2_DOMAIN_LENGTH {
                let fri_verify = fri_verify_for(log2_expansion_factor, log2_domain_length);

                let domain_length = usize::try_from(fri_verify.domain_length).unwrap();
                let arithmetic_domain = ArithmeticDomain::of_length(domain_length)
                    .unwrap()
                    .with_offset(fri_verify.domain_offset);
                let expansion_factor = usize::try_from(fri_verify.expansion_factor).unwrap();
                let num_collinearity_checks =
                    usize::try_from(fri_verify.num_collinearity_checks).unwrap();
                let fri =
                    Fri::new(arithmetic_domain, expansion_factor, num_collinearity_checks).unwrap();

                assert_eq!(
                    fri.num_rounds(),
                    fri_verify.num_rounds(),
                    "mismatch for log2_expansion_factor = {log2_expansion_factor}, \
                     log2_domain_length = {log2_domain_length}",
                );
            }
        }
    }
}
149}