tasm_lib/verifier/fri/
number_of_rounds.rs1use triton_vm::prelude::*;
2
3use crate::field;
4use crate::prelude::*;
5use crate::verifier::fri::verify::FriVerify;
6
/// Snippet that computes the number of FRI rounds from the FRI parameters
/// found behind a pointer.
///
/// The type carries no state; all behavior lives in its trait implementations.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub struct NumberOfRounds;
9
10impl BasicSnippet for NumberOfRounds {
11 fn parameters(&self) -> Vec<(DataType, String)> {
12 vec![(DataType::VoidPointer, "*fri".to_string())]
13 }
14
15 fn return_values(&self) -> Vec<(DataType, String)> {
16 vec![(DataType::U32, "num_rounds".to_string())]
17 }
18
19 fn entrypoint(&self) -> String {
20 "tasmlib_verifier_fri_number_of_rounds".to_string()
21 }
22
23 fn code(&self, _library: &mut Library) -> Vec<LabelledInstruction> {
24 let entrypoint = self.entrypoint();
25 let domain_length = field!(FriVerify::domain_length);
26 let expansion_factor = field!(FriVerify::expansion_factor);
27 let num_collinearity_checks = field!(FriVerify::num_collinearity_checks);
28
29 triton_asm! {
30 {entrypoint}:
33 dup 0 {&domain_length} read_mem 1 pop 1 hint domain_length = stack[0]
37
38 dup 1 {&expansion_factor} read_mem 1 pop 1 hint expansion_factor = stack[0]
41
42 swap 1 div_mod pop 1 log_2_floor hint max_num_rounds = stack[0]
45
46 dup 1 {&num_collinearity_checks}
47 read_mem 1 pop 1 hint num_collinearity_checks = stack[0]
49
50 log_2_floor push 1 add dup 1 dup 1 lt swap 2 push -1 mul add mul push -1 mul hint num_rounds = stack[0]
57
58 swap 1 pop 1 return
60 }
61 }
62}
63
#[cfg(test)]
mod tests {
    use triton_vm::arithmetic_domain::ArithmeticDomain;
    use triton_vm::fri::Fri;
    use twenty_first::math::traits::PrimitiveRootOfUnity;

    use super::*;
    use crate::memory::dyn_malloc::DYN_MALLOC_ADDRESS;
    use crate::test_prelude::*;

    impl Function for NumberOfRounds {
        /// Reference implementation: pop the pointer, decode the `FriVerify`
        /// stored behind it, and push the result of its `num_rounds()`.
        ///
        /// Fails with `StackUnderflow` if the stack is empty and with
        /// `DecodingError` if the struct cannot be decoded from memory.
        fn rust_shadow(
            &self,
            stack: &mut Vec<BFieldElement>,
            memory: &mut HashMap<BFieldElement, BFieldElement>,
        ) -> Result<(), RustShadowError> {
            let fri_verify = FriVerify::decode_from_memory(
                memory,
                stack.pop().ok_or(RustShadowError::StackUnderflow)?,
            )
            .map_err(|_| RustShadowError::DecodingError)?;
            stack.push(BFieldElement::new(fri_verify.num_rounds() as u64));
            Ok(())
        }

        /// Build an initial state holding a randomly parameterized `FriVerify`
        /// in memory and a pointer to it on top of the stack.
        ///
        /// Benchmark cases pin the domain length (2^17 for the common case,
        /// 2^22 for the worst case); otherwise it is drawn at random.
        fn pseudorandom_initial_state(
            &self,
            seed: [u8; 32],
            bench_case: Option<BenchmarkCase>,
        ) -> FunctionInitialState {
            let mut rng = StdRng::from_seed(seed);
            let rate_entropy = rng.random_range(1..16);
            // NOTE(review): 160 is presumably the targeted security level in
            // bits — confirm.
            let num_collinearity_checks = f64::ceil(160.0 / (rate_entropy as f64)) as usize;
            let domain_length = match bench_case {
                Some(BenchmarkCase::CommonCase) => 1 << 17,
                Some(BenchmarkCase::WorstCase) => 1 << 22,
                // never shorter than the expansion factor `1 << rate_entropy`
                None => 1 << rng.random_range(rate_entropy..=22),
            };
            let fri_verify = FriVerify {
                expansion_factor: 1 << rate_entropy,
                num_collinearity_checks: num_collinearity_checks as u32,
                domain_length,
                domain_offset: BFieldElement::new(7),
                domain_generator: BFieldElement::primitive_root_of_unity(domain_length.into())
                    .unwrap(),
            };
            let mut stack = self.init_stack_for_isolated_run();
            let mut memory: HashMap<BFieldElement, BFieldElement> = HashMap::new();

            let address = DYN_MALLOC_ADDRESS;
            encode_to_memory(&mut memory, address, &fri_verify);
            stack.push(address);

            FunctionInitialState { stack, memory }
        }
    }

    /// The assembly and the Rust shadow must agree.
    #[macro_rules_attr::apply(test)]
    fn test_shadow() {
        ShadowedFunction::new(NumberOfRounds {}).test()
    }

    /// `FriVerify::num_rounds` must agree with `Fri::num_rounds` for identical
    /// parameters.
    #[macro_rules_attr::apply(test)]
    fn shadow_agrees_with_canon() {
        let mut rng = rand::rng();
        let num_trials = 50;
        for _ in 0..num_trials {
            let rate_entropy = rng.random_range(1..16);
            let expansion_factor = 1 << rate_entropy;
            let num_collinearity_checks = f64::ceil(160.0 / (rate_entropy as f64)) as usize;
            let domain_length = 1 << rng.random_range(rate_entropy..=22);
            let domain_offset = BFieldElement::new(7);
            let fri_verify = FriVerify {
                // Deliberately not the local `expansion_factor`: that one is
                // handed to `Fri::new`, and the struct field presumably has a
                // different integer type, so the two must be inferred
                // independently.
                expansion_factor: 1 << rate_entropy,
                num_collinearity_checks: num_collinearity_checks as u32,
                domain_length,
                domain_offset,
                domain_generator: BFieldElement::primitive_root_of_unity(domain_length.into())
                    .unwrap(),
            };

            let arithmetic_domain = ArithmeticDomain::of_length(domain_length.try_into().unwrap())
                .unwrap()
                .with_offset(domain_offset);
            let fri =
                Fri::new(arithmetic_domain, expansion_factor, num_collinearity_checks).unwrap();

            assert_eq!(fri.num_rounds(), fri_verify.num_rounds());
        }
    }
}