// tasm_lib/verifier/fri/derive_from_stark.rs

use triton_vm::prelude::*;

use crate::arithmetic::bfe::primitive_root_of_unity::PrimitiveRootOfUnity;
use crate::arithmetic::u32::next_power_of_two::NextPowerOfTwo;
use crate::prelude::*;
use crate::verifier::fri::verify::fri_verify_type;

8/// Mimics Triton-VM's FRI parameter-derivation method, but doesn't allow for a FRI-domain length
9/// of 2^32 bc the domain length is stored in a single word/a `u32`.
10#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
11pub struct DeriveFriFromStark {
12    pub stark: Stark,
13}
14
15impl DeriveFriFromStark {
16    fn derive_fri_field_values(&self, library: &mut Library) -> Vec<LabelledInstruction> {
17        let next_power_of_two = library.import(Box::new(NextPowerOfTwo));
18        let domain_generator = library.import(Box::new(PrimitiveRootOfUnity));
19
20        let num_trace_randomizers = self.stark.num_trace_randomizers;
21        let fri_expansion_factor = self.stark.fri_expansion_factor;
22        let interpolant_codeword_length_code = triton_asm!(
23            // _ padded_height
24
25            push {num_trace_randomizers}
26            add
27            // _ (padded_height + num_trace_randomizers)
28
29            call {next_power_of_two}
30            // _ next_pow2(padded_height + num_trace_randomizers)
31            // _ interpolant_codeword_length
32        );
33        let fri_domain_length = triton_asm!(
34            // _ interpolant_codeword_length
35            push {fri_expansion_factor}
36            mul
37            // _ (interpolant_codeword_length * fri_expansion_factor)
38            // _ fri_domain_length
39        );
40
41        let domain_offset = BFieldElement::generator();
42        let num_collinearity_checks = self.stark.num_collinearity_checks;
43        let expansion_factor = self.stark.fri_expansion_factor;
44        triton_asm!(
45            // _ padded_height
46
47            {&interpolant_codeword_length_code}
48            {&fri_domain_length}
49            // _ fri_domain_length
50
51            push {num_collinearity_checks}
52            // _ fri_domain_length num_collinearity_checks
53
54            push {expansion_factor}
55            // _ fri_domain_length num_collinearity_checks expansion_factor
56
57            swap 2
58            // _ expansion_factor num_collinearity_checks fri_domain_length
59
60            push {domain_offset}
61            // _ expansion_factor num_collinearity_checks fri_domain_length domain_offset
62
63            dup 1
64            split
65            call {domain_generator}
66            // _ expansion_factor num_collinearity_checks fri_domain_length domain_offset domain_generator
67        )
68    }
69}
70
71impl BasicSnippet for DeriveFriFromStark {
72    fn parameters(&self) -> Vec<(DataType, String)> {
73        vec![(DataType::U32, "padded_height".to_owned())]
74    }
75
76    fn return_values(&self) -> Vec<(DataType, String)> {
77        vec![(
78            DataType::StructRef(fri_verify_type()),
79            "*fri_verify".to_owned(),
80        )]
81    }
82
83    fn entrypoint(&self) -> String {
84        "tasmlib_verifier_fri_derive_from_stark".to_owned()
85    }
86
87    fn code(&self, library: &mut Library) -> Vec<LabelledInstruction> {
88        let entrypoint = self.entrypoint();
89        let derive_fri_field_values = self.derive_fri_field_values(library);
90        let dyn_malloc = library.import(Box::new(DynMalloc));
91
92        triton_asm!(
93            {entrypoint}:
94                // _ padded_height
95
96                {&derive_fri_field_values}
97                // _ fri_domain_length domain_offset domain_generator num_collinearity_checks expansion_factor
98
99                call {dyn_malloc}
100                // _ fri_domain_length domain_offset domain_generator num_collinearity_checks expansion_factor *fri_verify
101
102                write_mem 5
103                // _ (*fri_verify + 5)
104
105                push -5
106                add
107                // _ *fri_verify
108
109                return
110        )
111    }
112}
113
#[cfg(test)]
mod tests {
    use super::*;
    use crate::U32_TO_USIZE_ERR;
    use crate::rust_shadowing_helper_functions;
    use crate::test_prelude::*;
    use crate::verifier::fri::verify::FriVerify;

    #[macro_rules_attr::apply(test)]
    fn fri_param_derivation_default_stark_pbt() {
        ShadowedFunction::new(DeriveFriFromStark {
            stark: Stark::default(),
        })
        .test();
    }

    #[macro_rules_attr::apply(proptest(cases = 10))]
    fn fri_param_derivation_pbt_pbt(#[strategy(arb())] stark: Stark) {
        ShadowedFunction::new(DeriveFriFromStark { stark }).test();
    }

    impl Function for DeriveFriFromStark {
        /// Reference implementation: derive FRI via Triton VM's `Stark::fri`, convert it to
        /// the local `FriVerify`, encode it into freshly allocated memory, and push the pointer.
        fn rust_shadow(
            &self,
            stack: &mut Vec<BFieldElement>,
            memory: &mut HashMap<BFieldElement, BFieldElement>,
        ) -> Result<(), RustShadowError> {
            let padded_height: u32 = stack
                .pop()
                .ok_or(RustShadowError::StackUnderflow)?
                .try_into()
                .map_err(|_| RustShadowError::U64ToU32Error)?;
            let fri_from_tvm = self
                .stark
                .fri(padded_height.try_into().expect(U32_TO_USIZE_ERR))
                .map_err(|_| RustShadowError::Other)?;
            let local_fri: FriVerify = fri_from_tvm.into();
            let fri_pointer =
                rust_shadowing_helper_functions::dyn_malloc::dynamic_allocator(memory);
            encode_to_memory(memory, fri_pointer, &local_fri);
            stack.push(fri_pointer);

            Ok(())
        }

        fn pseudorandom_initial_state(
            &self,
            seed: [u8; 32],
            bench_case: Option<BenchmarkCase>,
        ) -> FunctionInitialState {
            // Due to an arithmetic-overflow bug in Triton VM v3.0.0, deriving a FRI instance
            // from values that are too close to `usize::MAX` results in a `panic!`, not an
            // `Err`. What "too close" means depends on the FRI expansion factor. The workaround
            // is not pretty but should be temporary.

            #[cfg(target_pointer_width = "32")]
            const WORST_CASE_BENCH_SIZE: u32 = 21;
            #[cfg(target_pointer_width = "64")]
            const WORST_CASE_BENCH_SIZE: u32 = 23;

            #[cfg(target_pointer_width = "32")]
            const MAX_BENCH_SIZE: u32 = WORST_CASE_BENCH_SIZE;
            #[cfg(target_pointer_width = "64")]
            const MAX_BENCH_SIZE: u32 = 25;

            // log2 of the smallest padded height exercised by randomized tests.
            const MIN_BENCH_SIZE: u32 = 8;

            let padded_height: u32 = match bench_case {
                Some(BenchmarkCase::CommonCase) => 2u32.pow(21),
                Some(BenchmarkCase::WorstCase) => 2u32.pow(WORST_CASE_BENCH_SIZE),
                None => {
                    let mut rng = StdRng::from_seed(seed);
                    let mut padded_height =
                        2u32.pow(rng.random_range(MIN_BENCH_SIZE..=MAX_BENCH_SIZE));

                    // Don't test parameters that result in FRI domains that are too big, i.e.,
                    // larger than 2^32. This also excludes 2^32 itself as the domain length,
                    // because the type used to hold this value is a `u32` in this repo. Such
                    // a large FRI domain is infeasible anyway, so excluding it is acceptable.
                    //
                    // The lower bound in the loop condition guarantees termination even if
                    // FRI derivation keeps failing; the assertion below then reports it.
                    while padded_height >= 2u32.pow(MIN_BENCH_SIZE)
                        && self.stark.fri(padded_height as usize * 2).is_err()
                    {
                        padded_height /= 2;
                    }

                    assert!(
                        padded_height >= 2u32.pow(MIN_BENCH_SIZE),
                        "no padded height of at least 2^{MIN_BENCH_SIZE} yields a valid FRI instance"
                    );

                    padded_height
                }
            };

            FunctionInitialState {
                stack: [
                    self.init_stack_for_isolated_run(),
                    vec![padded_height.into()],
                ]
                .concat(),
                memory: HashMap::default(),
            }
        }
    }
}