tasm_lib/verifier/fri/
derive_from_stark.rs1use triton_vm::prelude::*;
2
3use crate::arithmetic::bfe::primitive_root_of_unity::PrimitiveRootOfUnity;
4use crate::arithmetic::u32::next_power_of_two::NextPowerOfTwo;
5use crate::prelude::*;
6use crate::verifier::fri::verify::fri_verify_type;
7
8#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
11pub struct DeriveFriFromStark {
12 pub stark: Stark,
13}
14
15impl DeriveFriFromStark {
16 fn derive_fri_field_values(&self, library: &mut Library) -> Vec<LabelledInstruction> {
17 let next_power_of_two = library.import(Box::new(NextPowerOfTwo));
18 let domain_generator = library.import(Box::new(PrimitiveRootOfUnity));
19
20 let num_trace_randomizers = self.stark.num_trace_randomizers;
21 let fri_expansion_factor = self.stark.fri_expansion_factor;
22 let interpolant_codeword_length_code = triton_asm!(
23 push {num_trace_randomizers}
26 add
27 call {next_power_of_two}
30 );
33 let fri_domain_length = triton_asm!(
34 push {fri_expansion_factor}
36 mul
37 );
40
41 let domain_offset = BFieldElement::generator();
42 let num_collinearity_checks = self.stark.num_collinearity_checks;
43 let expansion_factor = self.stark.fri_expansion_factor;
44 triton_asm!(
45 {&interpolant_codeword_length_code}
48 {&fri_domain_length}
49 push {num_collinearity_checks}
52 push {expansion_factor}
55 swap 2
58 push {domain_offset}
61 dup 1
64 split
65 call {domain_generator}
66 )
68 }
69}
70
71impl BasicSnippet for DeriveFriFromStark {
72 fn parameters(&self) -> Vec<(DataType, String)> {
73 vec![(DataType::U32, "padded_height".to_owned())]
74 }
75
76 fn return_values(&self) -> Vec<(DataType, String)> {
77 vec![(
78 DataType::StructRef(fri_verify_type()),
79 "*fri_verify".to_owned(),
80 )]
81 }
82
83 fn entrypoint(&self) -> String {
84 "tasmlib_verifier_fri_derive_from_stark".to_owned()
85 }
86
87 fn code(&self, library: &mut Library) -> Vec<LabelledInstruction> {
88 let entrypoint = self.entrypoint();
89 let derive_fri_field_values = self.derive_fri_field_values(library);
90 let dyn_malloc = library.import(Box::new(DynMalloc));
91
92 triton_asm!(
93 {entrypoint}:
94 {&derive_fri_field_values}
97 call {dyn_malloc}
100 write_mem 5
103 push -5
106 add
107 return
110 )
111 }
112}
113
#[cfg(test)]
mod tests {
    use super::*;
    use crate::U32_TO_USIZE_ERR;
    use crate::rust_shadowing_helper_functions;
    use crate::test_prelude::*;
    use crate::verifier::fri::verify::FriVerify;

    /// Runs the shadowed-function test for the default STARK parameters.
    #[macro_rules_attr::apply(test)]
    fn fri_param_derivation_default_stark_pbt() {
        let snippet = DeriveFriFromStark {
            stark: Stark::default(),
        };
        ShadowedFunction::new(snippet).test();
    }

    /// Runs the shadowed-function test for arbitrary STARK parameters.
    #[macro_rules_attr::apply(proptest(cases = 10))]
    fn fri_param_derivation_pbt_pbt(#[strategy(arb())] stark: Stark) {
        let snippet = DeriveFriFromStark { stark };
        ShadowedFunction::new(snippet).test();
    }

    impl Function for DeriveFriFromStark {
        /// Rust reference implementation: pops the padded height, asks
        /// `Stark::fri` for the FRI parameters, converts them to a
        /// [`FriVerify`], encodes that at a freshly allocated address and
        /// pushes the address.
        fn rust_shadow(
            &self,
            stack: &mut Vec<BFieldElement>,
            memory: &mut HashMap<BFieldElement, BFieldElement>,
        ) -> Result<(), RustShadowError> {
            let top_of_stack = stack.pop().ok_or(RustShadowError::StackUnderflow)?;
            let padded_height: u32 = top_of_stack
                .try_into()
                .map_err(|_| RustShadowError::U64ToU32Error)?;
            let padded_height = usize::try_from(padded_height).expect(U32_TO_USIZE_ERR);

            let fri_from_tvm = self
                .stark
                .fri(padded_height)
                .map_err(|_| RustShadowError::Other)?;
            let local_fri: FriVerify = fri_from_tvm.into();

            let fri_pointer =
                rust_shadowing_helper_functions::dyn_malloc::dynamic_allocator(memory);
            encode_to_memory(memory, fri_pointer, &local_fri);
            stack.push(fri_pointer);

            Ok(())
        }

        /// Produces an initial state whose stack ends with a power-of-two
        /// padded height: fixed for the benchmark cases, seeded-random
        /// otherwise. Memory starts out empty.
        fn pseudorandom_initial_state(
            &self,
            seed: [u8; 32],
            bench_case: Option<BenchmarkCase>,
        ) -> FunctionInitialState {
            // Exponents are chosen per pointer width.
            // NOTE(review): presumably smaller on 32-bit targets to stay
            // within `usize`-dependent limits; confirm.
            #[cfg(target_pointer_width = "32")]
            const WORST_CASE_BENCH_SIZE: u32 = 21;
            #[cfg(target_pointer_width = "64")]
            const WORST_CASE_BENCH_SIZE: u32 = 23;

            #[cfg(target_pointer_width = "32")]
            const MAX_BENCH_SIZE: u32 = WORST_CASE_BENCH_SIZE;
            #[cfg(target_pointer_width = "64")]
            const MAX_BENCH_SIZE: u32 = 25;

            let padded_height: u32 = match bench_case {
                Some(BenchmarkCase::CommonCase) => 2u32.pow(21),
                Some(BenchmarkCase::WorstCase) => 2u32.pow(WORST_CASE_BENCH_SIZE),
                None => {
                    let mut rng = StdRng::from_seed(seed);
                    let log2_height = rng.random_range(8..=MAX_BENCH_SIZE);
                    let mut padded_height = 2u32.pow(log2_height);

                    // Halve the height until `Stark::fri` accepts twice that
                    // height.
                    // NOTE(review): the reason for the factor 2 is not visible
                    // here; confirm.
                    while self.stark.fri(padded_height as usize * 2).is_err() {
                        padded_height /= 2;
                    }

                    assert!(padded_height >= 2u32.pow(8));

                    padded_height
                }
            };

            let mut stack = self.init_stack_for_isolated_run();
            stack.push(padded_height.into());

            FunctionInitialState {
                stack,
                memory: HashMap::default(),
            }
        }
    }
}