Skip to main content

// tasm_lib/verifier/fri/number_of_rounds.rs

1use triton_vm::prelude::*;
2
3use crate::field;
4use crate::prelude::*;
5use crate::verifier::fri::verify::FriVerify;
6
/// Snippet that reads a `FriVerify` struct from memory and leaves the number
/// of rounds derived from its parameters on the stack.
///
/// The type carries no state; it only serves as the handle on which the
/// snippet's traits are implemented.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct NumberOfRounds;
9
10impl BasicSnippet for NumberOfRounds {
11    fn parameters(&self) -> Vec<(DataType, String)> {
12        vec![(DataType::VoidPointer, "*fri".to_string())]
13    }
14
15    fn return_values(&self) -> Vec<(DataType, String)> {
16        vec![(DataType::U32, "num_rounds".to_string())]
17    }
18
19    fn entrypoint(&self) -> String {
20        "tasmlib_verifier_fri_number_of_rounds".to_string()
21    }
22
23    fn code(&self, _library: &mut Library) -> Vec<LabelledInstruction> {
24        let entrypoint = self.entrypoint();
25        let domain_length = field!(FriVerify::domain_length);
26        let expansion_factor = field!(FriVerify::expansion_factor);
27        let num_collinearity_checks = field!(FriVerify::num_collinearity_checks);
28
29        triton_asm! {
30            // BEFORE: _ *fri_verify
31            // AFTER:  _  num_rounds
32            {entrypoint}:
33                // calculate number of rounds
34                dup 0 {&domain_length}      // _ *fri_verify *domain_length
35                read_mem 1 pop 1            // _ *fri_verify domain_length
36                hint domain_length = stack[0]
37
38                dup 1 {&expansion_factor}   // _ *fri_verify domain_length *expansion_factor
39                read_mem 1 pop 1            // _ *fri_verify domain_length expansion_factor
40                hint expansion_factor = stack[0]
41
42                swap 1 div_mod pop 1        // _ *fri_verify first_round_code_dimension
43                log_2_floor                 // _ *fri_verify max_num_rounds
44                hint max_num_rounds = stack[0]
45
46                dup 1 {&num_collinearity_checks}
47                read_mem 1 pop 1            // _ *fri_verify max_num_rounds num_collinearity_checks
48                hint num_collinearity_checks = stack[0]
49
50                log_2_floor push 1 add      // _ *fri_verify max_num_rounds num_rounds_checking_most_locations
51
52                dup 1 dup 1 lt              // _ *fri_verify max_num_rounds num_rounds_checking_most_locations (num_rounds_checking_most_locations<max_num_rounds)
53                swap 2 push -1 mul add      // _ *fri_verify (num_rounds_checking_most_locations<max_num_rounds) num_rounds_checking_most_locations-max_num_rounds
54                mul push -1 mul             // _ *fri_verify if(num_rounds_checking_most_locations<max_num_rounds){max_num_rounds-num_rounds_checking_most_locations}else{0}
55                                            // _ *fri_verify num_rounds
56                hint num_rounds = stack[0]
57
58                swap 1 pop 1                // _ num_rounds
59                return
60        }
61    }
62}
63
#[cfg(test)]
mod tests {
    use triton_vm::arithmetic_domain::ArithmeticDomain;
    use triton_vm::fri::Fri;
    use twenty_first::math::traits::PrimitiveRootOfUnity;

    use super::*;
    use crate::memory::dyn_malloc::DYN_MALLOC_ADDRESS;
    use crate::test_prelude::*;

    impl Function for NumberOfRounds {
        /// Reference implementation: pops the struct pointer, decodes the
        /// `FriVerify` struct from `memory`, and pushes its `num_rounds()`.
        ///
        /// Fails with `StackUnderflow` if the stack is empty and with
        /// `DecodingError` if the struct cannot be decoded.
        fn rust_shadow(
            &self,
            stack: &mut Vec<BFieldElement>,
            memory: &mut HashMap<BFieldElement, BFieldElement>,
        ) -> Result<(), RustShadowError> {
            let fri_pointer = stack.pop().ok_or(RustShadowError::StackUnderflow)?;
            let fri_verify = FriVerify::decode_from_memory(memory, fri_pointer)
                .map_err(|_| RustShadowError::DecodingError)?;

            let num_rounds = fri_verify.num_rounds() as u64;
            stack.push(BFieldElement::new(num_rounds));

            Ok(())
        }

        /// Builds a seeded, pseudorandom `FriVerify`, writes it to memory at
        /// `DYN_MALLOC_ADDRESS`, and puts that address on top of the stack.
        ///
        /// The expansion factor is `2^k` for a random `k` in `1..16`, and the
        /// number of collinearity checks is `ceil(160 / k)`. Benchmark cases
        /// pin the domain length to `2^17` (common) or `2^22` (worst);
        /// otherwise it is a random power of two between `2^k` and `2^22`.
        fn pseudorandom_initial_state(
            &self,
            seed: [u8; 32],
            bench_case: Option<BenchmarkCase>,
        ) -> FunctionInitialState {
            let mut rng = StdRng::from_seed(seed);

            let log2_expansion_factor = rng.random_range(1..16);
            let num_collinearity_checks =
                (160.0_f64 / log2_expansion_factor as f64).ceil() as usize;
            let domain_length = match bench_case {
                Some(BenchmarkCase::CommonCase) => 1 << 17,
                Some(BenchmarkCase::WorstCase) => 1 << 22,
                None => 1 << rng.random_range(log2_expansion_factor..=22),
            };

            let fri_verify = FriVerify {
                expansion_factor: 1 << log2_expansion_factor,
                num_collinearity_checks: num_collinearity_checks as u32,
                domain_length,
                domain_offset: BFieldElement::new(7),
                domain_generator: BFieldElement::primitive_root_of_unity(domain_length.into())
                    .unwrap(),
            };

            let fri_pointer = DYN_MALLOC_ADDRESS;
            let mut memory = HashMap::<BFieldElement, BFieldElement>::new();
            encode_to_memory(&mut memory, fri_pointer, &fri_verify);

            let mut stack = self.init_stack_for_isolated_run();
            stack.push(fri_pointer);

            FunctionInitialState { stack, memory }
        }
    }

    // Runs the snippet against its Rust shadow via the `ShadowedFunction`
    // test harness.
    #[macro_rules_attr::apply(test)]
    fn test_shadow() {
        ShadowedFunction::new(NumberOfRounds).test()
    }

    // Checks, for 50 random parameter sets, that `FriVerify::num_rounds` agrees
    // with `Fri::num_rounds` from `triton_vm` when both are given the same
    // domain, expansion factor, and number of collinearity checks.
    #[macro_rules_attr::apply(test)]
    fn shadow_agrees_with_canon() {
        let mut rng = rand::rng();

        for _ in 0..50 {
            let log2_expansion_factor = rng.random_range(1..16);
            let expansion_factor = 1 << log2_expansion_factor;
            let num_collinearity_checks =
                (160.0_f64 / log2_expansion_factor as f64).ceil() as usize;
            let domain_length = 1 << rng.random_range(log2_expansion_factor..=22);
            let domain_offset = BFieldElement::new(7);

            let fri_verify = FriVerify {
                expansion_factor: 1 << log2_expansion_factor,
                num_collinearity_checks: num_collinearity_checks as u32,
                domain_length,
                domain_offset,
                domain_generator: BFieldElement::primitive_root_of_unity(domain_length.into())
                    .unwrap(),
            };

            let canonical_domain = ArithmeticDomain::of_length(domain_length.try_into().unwrap())
                .unwrap()
                .with_offset(domain_offset);
            let canonical_fri =
                Fri::new(canonical_domain, expansion_factor, num_collinearity_checks).unwrap();

            assert_eq!(canonical_fri.num_rounds(), fri_verify.num_rounds());
        }
    }
}