// tasm_lib/verifier/fri/derive_from_stark.rs

1use triton_vm::prelude::*;
2
3use crate::arithmetic::bfe::primitive_root_of_unity::PrimitiveRootOfUnity;
4use crate::arithmetic::u32::next_power_of_two::NextPowerOfTwo;
5use crate::prelude::*;
6use crate::verifier::fri::verify::fri_verify_type;
7
8/// Mimics Triton-VM's FRI parameter-derivation method, but doesn't allow for a FRI-domain length
9/// of 2^32 bc the domain length is stored in a single word/a `u32`.
10#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
11pub struct DeriveFriFromStark {
12    pub stark: Stark,
13}
14
15impl DeriveFriFromStark {
16    fn derive_fri_field_values(&self, library: &mut Library) -> Vec<LabelledInstruction> {
17        let next_power_of_two = library.import(Box::new(NextPowerOfTwo));
18        let domain_generator = library.import(Box::new(PrimitiveRootOfUnity));
19
20        let num_trace_randomizers = self.stark.num_trace_randomizers;
21        let fri_expansion_factor = self.stark.fri_expansion_factor;
22        let interpolant_codeword_length_code = triton_asm!(
23            // _ padded_height
24
25            push {num_trace_randomizers}
26            add
27            // _ (padded_height + num_trace_randomizers)
28
29            call {next_power_of_two}
30            // _ next_pow2(padded_height + num_trace_randomizers)
31            // _ interpolant_codeword_length
32        );
33        let fri_domain_length = triton_asm!(
34            // _ interpolant_codeword_length
35            push {fri_expansion_factor}
36            mul
37            // _ (interpolant_codeword_length * fri_expansion_factor)
38            // _ fri_domain_length
39        );
40
41        let domain_offset = BFieldElement::generator();
42        let num_collinearity_checks = self.stark.num_collinearity_checks;
43        let expansion_factor = self.stark.fri_expansion_factor;
44        triton_asm!(
45            // _ padded_height
46
47            {&interpolant_codeword_length_code}
48            {&fri_domain_length}
49            // _ fri_domain_length
50
51            push {num_collinearity_checks}
52            // _ fri_domain_length num_collinearity_checks
53
54            push {expansion_factor}
55            // _ fri_domain_length num_collinearity_checks expansion_factor
56
57            swap 2
58            // _ expansion_factor num_collinearity_checks fri_domain_length
59
60            push {domain_offset}
61            // _ expansion_factor num_collinearity_checks fri_domain_length domain_offset
62
63            dup 1
64            split
65            call {domain_generator}
66            // _ expansion_factor num_collinearity_checks fri_domain_length domain_offset domain_generator
67        )
68    }
69}
70
71impl BasicSnippet for DeriveFriFromStark {
72    fn inputs(&self) -> Vec<(DataType, String)> {
73        vec![(DataType::U32, "padded_height".to_owned())]
74    }
75
76    fn outputs(&self) -> Vec<(DataType, String)> {
77        vec![(
78            DataType::StructRef(fri_verify_type()),
79            "*fri_verify".to_owned(),
80        )]
81    }
82
83    fn entrypoint(&self) -> String {
84        "tasmlib_verifier_fri_derive_from_stark".to_owned()
85    }
86
87    fn code(&self, library: &mut Library) -> Vec<LabelledInstruction> {
88        let entrypoint = self.entrypoint();
89        let derive_fri_field_values = self.derive_fri_field_values(library);
90        let dyn_malloc = library.import(Box::new(DynMalloc));
91
92        triton_asm!(
93            {entrypoint}:
94                // _ padded_height
95
96                {&derive_fri_field_values}
97                // _ fri_domain_length domain_offset domain_generator num_collinearity_checks expansion_factor
98
99                call {dyn_malloc}
100                // _ fri_domain_length domain_offset domain_generator num_collinearity_checks expansion_factor *fri_verify
101
102                write_mem 5
103                // _ (*fri_verify + 5)
104
105                push -5
106                add
107                // _ *fri_verify
108
109                return
110        )
111    }
112}
113
#[cfg(test)]
mod tests {
    use super::*;
    use crate::rust_shadowing_helper_functions;
    use crate::test_prelude::*;
    use crate::verifier::fri::verify::FriVerify;

    #[test]
    fn fri_param_derivation_default_stark_pbt() {
        ShadowedFunction::new(DeriveFriFromStark {
            stark: Stark::default(),
        })
        .test();
    }

    #[proptest(cases = 10)]
    fn fri_param_derivation_pbt(#[strategy(arb())] stark: Stark) {
        ShadowedFunction::new(DeriveFriFromStark { stark }).test();
    }

    impl Function for DeriveFriFromStark {
        /// Derives the FRI parameters using Triton VM's own method, converts them into the
        /// local `FriVerify` type, and stores them in freshly allocated memory. The pointer
        /// to that memory is pushed onto the stack.
        fn rust_shadow(
            &self,
            stack: &mut Vec<BFieldElement>,
            memory: &mut HashMap<BFieldElement, BFieldElement>,
        ) {
            let padded_height: u32 = stack
                .pop()
                .expect("stack must contain the padded height")
                .try_into()
                .expect("padded height must be a u32");
            let padded_height: usize = padded_height
                .try_into()
                .expect("a u32 must fit into a usize");

            let fri_from_tvm = self
                .stark
                .fri(padded_height)
                .expect("Triton VM must be able to derive FRI parameters");
            let local_fri: FriVerify = fri_from_tvm.into();

            let fri_pointer =
                rust_shadowing_helper_functions::dyn_malloc::dynamic_allocator(memory);
            encode_to_memory(memory, fri_pointer, &local_fri);
            stack.push(fri_pointer);
        }

        fn pseudorandom_initial_state(
            &self,
            seed: [u8; 32],
            bench_case: Option<BenchmarkCase>,
        ) -> FunctionInitialState {
            const MIN_LOG2_PADDED_HEIGHT: u32 = 8;
            const MAX_LOG2_PADDED_HEIGHT: u32 = 25;

            let padded_height: u32 = match bench_case {
                Some(BenchmarkCase::CommonCase) => 2u32.pow(21),
                Some(BenchmarkCase::WorstCase) => 2u32.pow(23),
                None => {
                    let min_padded_height = 2u32.pow(MIN_LOG2_PADDED_HEIGHT);
                    let mut rng = StdRng::from_seed(seed);
                    let mut padded_height = 2u32.pow(
                        rng.random_range(MIN_LOG2_PADDED_HEIGHT..=MAX_LOG2_PADDED_HEIGHT),
                    );

                    // Don't test parameters that result in FRI domains that are too large. A
                    // domain length of 2^32 or more is excluded, because the type used to hold
                    // this value in this repo is a `u32`; such large FRI domains are deemed
                    // infeasible anyway. Checking twice the padded height keeps the domain
                    // strictly below the limit. Halving stops once the height drops below the
                    // minimum, so that unsatisfiable parameters fail the assertion below instead
                    // of looping forever at a height of 0.
                    while padded_height >= min_padded_height
                        && self.stark.fri(padded_height as usize * 2).is_err()
                    {
                        padded_height /= 2;
                    }

                    assert!(
                        padded_height >= min_padded_height,
                        "STARK parameters admit no padded height of at least {min_padded_height}"
                    );

                    padded_height
                }
            };

            FunctionInitialState {
                stack: [
                    self.init_stack_for_isolated_run(),
                    vec![padded_height.into()],
                ]
                .concat(),
                memory: HashMap::default(),
            }
        }
    }
}