tasm_lib/verifier/fri/
number_of_rounds.rs1use triton_vm::prelude::*;
2
3use crate::field;
4use crate::prelude::*;
5use crate::verifier::fri::verify::FriVerify;
6
/// Stateless tasm snippet that computes how many folding rounds a FRI verifier performs.
///
/// The parameters are read from a [`FriVerify`] struct that lives in memory.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct NumberOfRounds;
9
10impl BasicSnippet for NumberOfRounds {
11 fn inputs(&self) -> Vec<(DataType, String)> {
12 vec![(DataType::VoidPointer, "*fri".to_string())]
13 }
14
15 fn outputs(&self) -> Vec<(DataType, String)> {
16 vec![(DataType::U32, "num_rounds".to_string())]
17 }
18
19 fn entrypoint(&self) -> String {
20 "tasmlib_verifier_fri_number_of_rounds".to_string()
21 }
22
23 fn code(&self, _library: &mut Library) -> Vec<LabelledInstruction> {
24 let entrypoint = self.entrypoint();
25 let domain_length = field!(FriVerify::domain_length);
26 let expansion_factor = field!(FriVerify::expansion_factor);
27 let num_collinearity_checks = field!(FriVerify::num_collinearity_checks);
28
29 triton_asm! {
30 {entrypoint}:
33 dup 0 {&domain_length} read_mem 1 pop 1 hint domain_length = stack[0]
37
38 dup 1 {&expansion_factor} read_mem 1 pop 1 hint expansion_factor = stack[0]
41
42 swap 1 div_mod pop 1 log_2_floor hint max_num_rounds = stack[0]
45
46 dup 1 {&num_collinearity_checks}
47 read_mem 1 pop 1 hint num_collinearity_checks = stack[0]
49
50 log_2_floor push 1 add dup 1 dup 1 lt swap 2 push -1 mul add mul push -1 mul hint num_rounds = stack[0]
57
58 swap 1 pop 1 return
60 }
61 }
62}
63
#[cfg(test)]
mod tests {
    use triton_vm::arithmetic_domain::ArithmeticDomain;
    use triton_vm::fri::Fri;
    use twenty_first::math::traits::PrimitiveRootOfUnity;

    use super::*;
    use crate::memory::dyn_malloc::DYN_MALLOC_ADDRESS;
    use crate::test_prelude::*;

    impl Function for NumberOfRounds {
        /// Reference implementation: pop `*fri` and decode the `FriVerify` it points to.
        /// Then push that struct's `num_rounds()`.
        fn rust_shadow(
            &self,
            stack: &mut Vec<BFieldElement>,
            memory: &mut HashMap<BFieldElement, BFieldElement>,
        ) {
            let fri_pointer = stack.pop().unwrap();
            let fri_verify = FriVerify::decode_from_memory(memory, fri_pointer).unwrap();
            let num_rounds = fri_verify.num_rounds() as u64;
            stack.push(BFieldElement::new(num_rounds));
        }

        /// Builds a pseudorandom `FriVerify` and encodes it at `DYN_MALLOC_ADDRESS`.
        ///
        /// - The expansion factor is `2^rate_entropy`, with `rate_entropy` in `1..16`.
        /// - `num_collinearity_checks` is `ceil(160 / rate_entropy)`.
        /// - The domain length is `2^17` for the common case and `2^22` for the worst case.
        ///   Otherwise it is `2^r` with `r` in `rate_entropy..=22`.
        ///
        /// The returned stack ends with the struct's address.
        fn pseudorandom_initial_state(
            &self,
            seed: [u8; 32],
            bench_case: Option<BenchmarkCase>,
        ) -> FunctionInitialState {
            let mut rng = StdRng::from_seed(seed);

            // log2 of the expansion factor
            let rate_entropy = rng.random_range(1..16);
            let num_collinearity_checks = f64::ceil(160.0 / (rate_entropy as f64)) as usize;
            let domain_length = match bench_case {
                Some(BenchmarkCase::CommonCase) => 1 << 17,
                Some(BenchmarkCase::WorstCase) => 1 << 22,
                None => 1 << rng.random_range(rate_entropy..=22),
            };

            // Keep `domain_length` ahead of `domain_generator`: the former fixes the
            // integer type that `.into()` in the latter relies on.
            let fri_verify = FriVerify {
                expansion_factor: 1 << rate_entropy,
                num_collinearity_checks: num_collinearity_checks as u32,
                domain_length,
                domain_offset: BFieldElement::new(7),
                domain_generator: BFieldElement::primitive_root_of_unity(domain_length.into())
                    .unwrap(),
            };

            let fri_pointer = DYN_MALLOC_ADDRESS;
            let mut memory: HashMap<BFieldElement, BFieldElement> = HashMap::default();
            encode_to_memory(&mut memory, fri_pointer, &fri_verify);

            let mut stack = self.init_stack_for_isolated_run();
            stack.push(fri_pointer);

            FunctionInitialState { stack, memory }
        }
    }

    /// The tasm snippet must agree with its Rust shadow.
    #[test]
    fn test_shadow() {
        ShadowedFunction::new(NumberOfRounds).test();
    }

    /// `FriVerify::num_rounds` (the shadow) must agree with `triton_vm::fri::Fri::num_rounds`
    /// across random parameter sets.
    #[test]
    fn shadow_agrees_with_canon() {
        let mut rng = rand::rng();
        for _trial in 0..50 {
            let rate_entropy = rng.random_range(1..16);
            let expansion_factor = 1 << rate_entropy;
            let num_collinearity_checks = f64::ceil(160.0 / (rate_entropy as f64)) as usize;
            let domain_length = 1 << rng.random_range(rate_entropy..=22);
            let domain_offset = BFieldElement::new(7);

            // `expansion_factor` above is consumed by `Fri::new` as a `usize`,
            // so the `FriVerify` field gets its own shift expression.
            let fri_verify = FriVerify {
                expansion_factor: 1 << rate_entropy,
                num_collinearity_checks: num_collinearity_checks as u32,
                domain_length,
                domain_offset,
                domain_generator: BFieldElement::primitive_root_of_unity(domain_length.into())
                    .unwrap(),
            };

            let domain = ArithmeticDomain::of_length(domain_length.try_into().unwrap())
                .unwrap()
                .with_offset(domain_offset);
            let canonical_fri =
                Fri::new(domain, expansion_factor, num_collinearity_checks).unwrap();

            assert_eq!(canonical_fri.num_rounds(), fri_verify.num_rounds());
        }
    }
}