Skip to main content

tasm_lib/verifier/fri/
verify_fri_authentication_paths.rs

1use triton_vm::prelude::*;
2use twenty_first::math::x_field_element::EXTENSION_DEGREE;
3
4use crate::prelude::*;
5
/// Verify Merkle authentication paths in a FRI context.
///
/// Verify a batch of Merkle membership claims in a FRI context where only the
/// a-indices are known and the b-indices must be calculated on the fly. This
/// snippet can be used for both a- and b-indices. For a-indices the
/// `xor_bitflag` argument must be set to the domain length, and for b-indices,
/// `xor_bitflag` must be set to 3/2 times the domain length. The
/// `xor_bitflag` is used to convert a leaf index into a Merkle tree node
/// index: `node_index = (a_index & dom_len_minus_one) ^ xor_bitflag`.
///
/// The authentication paths are read from the non-deterministic digest input,
/// one full path per index, starting with the path for the *last* index of the
/// a-indices list and proceeding towards the first one.
///
/// Behavior: crashes the VM (assertion error id 30) if just one of the
/// authentication paths is invalid. Goes into an infinite loop if a node index
/// value is initialized to 0 or 1 through wrong domain-length values. Also
/// cannot handle empty lists, so this snippet must verify at least one
/// authentication path.
#[derive(Debug, Default, Copy, Clone, Eq, PartialEq, Hash)]
pub struct VerifyFriAuthenticationPaths;
22
23impl BasicSnippet for VerifyFriAuthenticationPaths {
24    fn parameters(&self) -> Vec<(DataType, String)> {
25        vec![
26            (DataType::U32, "dom_len_minus_one".to_owned()),
27            (DataType::U32, "xor_bitflag".to_owned()),
28            (
29                DataType::List(Box::new(DataType::Xfe)),
30                "*values_last_word".to_owned(),
31            ),
32            (
33                DataType::List(Box::new(DataType::U32)),
34                "*a_indices".to_owned(),
35            ),
36            (
37                DataType::List(Box::new(DataType::U32)),
38                "*a_indices_last_word".to_owned(),
39            ),
40            (DataType::Digest, "root".to_string()),
41        ]
42    }
43
44    fn return_values(&self) -> Vec<(DataType, String)> {
45        vec![]
46    }
47
48    fn entrypoint(&self) -> String {
49        "tasmlib_verifier_fri_verify_fri_authentication_paths".into()
50    }
51
52    fn code(&self, _library: &mut Library) -> Vec<LabelledInstruction> {
53        let entrypoint = self.entrypoint();
54        let main_loop = format!("{entrypoint}_main_loop");
55
56        let loop_over_auth_paths_label = format!("{entrypoint}_loop_over_auth_path_elements");
57        let loop_over_auth_paths_code = triton_asm!(
58            {loop_over_auth_paths_label}:
59                merkle_step                         // move up one level in the Merkle tree
60                recurse_or_return                   // break loop if node_index is 1
61        );
62
63        triton_asm!(
64            // BEFORE: _ dom_len_minus_one xor_bitflag *values_last_word *idx_end_cond *a_indices_last_word [root]
65            // AFTER : _
66
67            {entrypoint}:
68                call {main_loop}
69                // _ dom_len_minus_one xor_bitflag *values_last_word *a_indices *a_indices_last_word [root]
70
71                /* Cleanup stack */
72                pop 5
73                pop 5
74                // _
75
76                return
77
78
79            // Invariant: _ dom_len_minus_one xor_bitflag *value[n]_last_word *a_indices *a_indices[n] [root]
80            {main_loop}:
81                // _ dom_len_minus_one xor_bitflag *value[n] *a_indices *a_indices[n] [root]
82
83                push 1
84                // _ dom_len_minus_one xor_bitflag *value *a_indices *a_indices[n] [root] 1
85
86                pick 6
87                read_mem 1
88                place 7
89                // _ dom_len_minus_one xor_bitflag *value *a_indices *a_indices[n]' [root] 1 ia_0[n]
90
91                dup 11
92                and
93                dup 10
94                xor
95                // _ dom_len_minus_one xor_bitflag *value *a_indices *a_indices[n]' [root] 1 ((ia_0[n] & dom_len_minus_one) ^ xor_bitflag)
96                // _ dom_len_minus_one xor_bitflag *value *a_indices *a_indices[n]' [root] 1 (i_r[n] + dom_len)
97                // _ dom_len_minus_one xor_bitflag *value *a_indices *a_indices[n]' [root] 1 node_index_i_r[n]
98
99                push 0
100                push 0
101                // _ dom_len_minus_one xor_bitflag *value *a_indices *a_indices[n]' [root] 1 i_r[n] 0 0
102
103                pick 11
104                read_mem {EXTENSION_DEGREE}
105                place 14
106                // _ dom_len_minus_one xor_bitflag *value' *a_indices *a_indices[n]' [root] 1 i_r[n] 0 0 [xfe]
107
108                call {loop_over_auth_paths_label}
109                // _ dom_len_minus_one xor_bitflag *value' *a_indices *a_indices[n]' [root] 1 1 [calculated_root]
110                // _ dom_len_minus_one xor_bitflag *value' *a_indices *a_indices[n]' [root] 1 1 cr4 cr3 cr2 cr1 cr0
111
112                pick 5 pick 6
113                pop 2
114                // _ dom_len_minus_one xor_bitflag *value' *a_indices *a_indices[n]' [root] cr4 cr3 cr2 cr1 cr0
115                // _ dom_len_minus_one xor_bitflag *value' *a_indices *a_indices[n]' [root] [calculated_root]
116
117                assert_vector
118                    error_id 30
119                // _ dom_len_minus_one xor_bitflag *value *a_indices *a_indices[n]' [root]
120
121                recurse_or_return
122
123            {&loop_over_auth_paths_code}
124        )
125    }
126}
127
#[cfg(test)]
mod tests {
    use rand::distr::Distribution;
    use rand::distr::StandardUniform;
    use strum::EnumIter;
    use strum::IntoEnumIterator;
    use twenty_first::prelude::*;

    use super::*;
    use crate::U32_TO_USIZE_ERR;
    use crate::rust_shadowing_helper_functions;
    use crate::test_prelude::*;

    /// Which set of FRI indices a generated test state opens.
    #[derive(Clone, Debug, EnumIter, Copy)]
    enum IndexType {
        /// Open the leaves at the a-indices themselves.
        A,
        /// Open the leaves at the b-indices, i.e., the a-indices shifted by half
        /// the domain length, wrapping around the domain.
        B,
    }

    impl Distribution<IndexType> for StandardUniform {
        fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> IndexType {
            if rng.random() {
                IndexType::A
            } else {
                IndexType::B
            }
        }
    }

    impl Algorithm for VerifyFriAuthenticationPaths {
        fn rust_shadow(
            &self,
            stack: &mut Vec<BFieldElement>,
            memory: &mut HashMap<BFieldElement, BFieldElement>,
            nondeterminism: &NonDeterminism,
        ) -> Result<(), RustShadowError> {
            // Arguments are popped in reverse order of `BasicSnippet::parameters`.
            let root = pop_encodable(stack)?;
            let a_indices_last_word = pop_encodable(stack)?;
            let a_indices_pointer = pop_encodable(stack)?;
            let values_last_word = pop_encodable(stack)?;
            let xor_bitflag = pop_encodable::<u32>(stack)?;
            let dom_len_minus_one = pop_encodable::<u32>(stack)?;

            let dom_len = dom_len_minus_one + 1;
            let tree_height = dom_len.ilog2();

            // Loop invariants: every path has one digest per tree level, and
            // memory is only ever read.
            let auth_path_len = usize::try_from(tree_height).expect(U32_TO_USIZE_ERR);
            let read_word_from_mem =
                |pointer: BFieldElement| memory.get(&pointer).copied().unwrap_or_default();

            let mut auth_path_counter = 0;
            let mut idx_element_pointer = a_indices_last_word;
            let mut leaf_pointer = values_last_word;
            while idx_element_pointer != a_indices_pointer {
                let auth_path_start = auth_path_counter * auth_path_len;
                let auth_path_end = auth_path_start + auth_path_len;
                let authentication_path =
                    nondeterminism.digests[auth_path_start..auth_path_end].to_vec();

                let leaf_index_a_round_0: u32 = memory
                    .get(&idx_element_pointer)
                    .map(|x| x.value())
                    .unwrap_or_default()
                    .try_into()
                    .map_err(|_| RustShadowError::U64ToU32Error)?;
                // Same node-index computation as the TASM code.
                let node_index = (leaf_index_a_round_0 & dom_len_minus_one) ^ xor_bitflag;
                // Dropping the `dom_len` bit turns the node index into a leaf index.
                let leaf_index = node_index ^ dom_len;

                let leaf = XFieldElement::new([
                    read_word_from_mem(leaf_pointer - bfe!(2)),
                    read_word_from_mem(leaf_pointer - bfe!(1)),
                    read_word_from_mem(leaf_pointer),
                ]);
                let inclusion_proof = MerkleTreeInclusionProof {
                    tree_height,
                    indexed_leafs: vec![(leaf_index as usize, leaf.into())],
                    authentication_structure: authentication_path,
                };
                if !inclusion_proof.verify(root) {
                    return Err(RustShadowError::InvalidProof);
                }

                idx_element_pointer.decrement();
                auth_path_counter += 1;
                leaf_pointer -= bfe!(EXTENSION_DEGREE as u64);
            }
            Ok(())
        }

        fn pseudorandom_initial_state(
            &self,
            seed: [u8; 32],
            bench_case: Option<BenchmarkCase>,
        ) -> AlgorithmInitialState {
            let mut rng = StdRng::from_seed(seed);

            // determine sizes
            let (height, num_indices) = match bench_case {
                Some(BenchmarkCase::CommonCase) => (10, 80),
                Some(BenchmarkCase::WorstCase) => (20, 80),
                None => (rng.random_range(6..=15), rng.random_range(2..10) as usize),
            };

            let index_type = rng.random();

            self.prepare_state(&mut rng, height, num_indices, index_type)
        }

        fn corner_case_initial_states(&self) -> Vec<AlgorithmInitialState> {
            // (tree height, number of indices) of every generated state, per index
            // type. The repeated entries are intentional: each call draws fresh
            // leaves and indices from `rng`.
            const SHAPES: [(u32, usize); 10] = [
                (1, 1),
                (1, 1),
                (1, 1),
                (1, 1),
                (1, 1),
                (1, 2),
                (2, 1),
                (2, 2),
                (2, 3),
                (2, 4),
            ];

            let mut rng = StdRng::from_seed([42u8; 32]);

            let mut test_cases = vec![];
            for index_type in IndexType::iter() {
                for &(height, num_indices) in &SHAPES {
                    test_cases.push(self.prepare_state(&mut rng, height, num_indices, index_type));
                }
            }

            test_cases
        }
    }

    impl VerifyFriAuthenticationPaths {
        /// Build a valid initial state.
        ///
        /// The state consists of:
        /// - a random Merkle tree over `2^height` XFE leaves;
        /// - `num_indices` random a-indices, stored as a list in memory;
        /// - the opened leaf values (at the a- or b-indices, per `index_type`),
        ///   stored as a second list in memory;
        /// - the matching authentication paths as non-deterministic digests, in
        ///   reverse index order.
        fn prepare_state(
            &self,
            rng: &mut StdRng,
            height: u32,
            num_indices: usize,
            index_type: IndexType,
        ) -> AlgorithmInitialState {
            // generate data structure
            let dom_len: u32 = 1 << height;
            let dom_len_minus_one = dom_len - 1;
            let dom_len_half = dom_len / 2;

            let xfe_leafs = (0..dom_len)
                .map(|_| rng.random::<XFieldElement>())
                .collect_vec();
            let leafs_as_digest: Vec<Digest> =
                xfe_leafs.iter().map(|&xfe| xfe.into()).collect_vec();
            let tree = MerkleTree::par_new(&leafs_as_digest).unwrap();
            let root = tree.root();

            let a_indices = (0..num_indices)
                .map(|_| rng.random_range(0..dom_len) as usize)
                .collect_vec();

            // TODO: Generalize for other values than round=0
            // A b-index is its a-index shifted by half the domain, wrapping around.
            let indices_revealed = match index_type {
                IndexType::A => a_indices.clone(),
                IndexType::B => a_indices
                    .iter()
                    .map(|&x| (x + dom_len_half as usize) & dom_len_minus_one as usize)
                    .collect_vec(),
            };
            let opened_leafs = indices_revealed.iter().map(|&i| xfe_leafs[i]).collect_vec();
            // The snippet consumes paths starting with the last index, hence `rev`.
            let authentication_paths = indices_revealed
                .iter()
                .rev()
                .map(|&i| tree.authentication_structure(&[i]).unwrap())
                .collect_vec();
            let a_indices: Vec<u32> = a_indices.into_iter().map(|idx| idx as u32).collect_vec();

            // prepare memory + stack + nondeterminism
            let mut memory: HashMap<BFieldElement, BFieldElement> = HashMap::default();

            let a_indices_pointer = BFieldElement::new(rng.next_u64() % (1 << 20));
            rust_shadowing_helper_functions::list::list_insert(
                a_indices_pointer,
                a_indices,
                &mut memory,
            );

            // Offset by 2^32 so the leaf list never overlaps the index list.
            let leaf_pointer = BFieldElement::new(rng.next_u64() % (1 << 20) + (1 << 32));
            rust_shadowing_helper_functions::list::list_insert(
                leaf_pointer,
                opened_leafs,
                &mut memory,
            );

            // The snippet walks both lists downwards, starting at their last word.
            let a_indices_last_word = a_indices_pointer + bfe!(num_indices as u64);
            let leaf_pointer_last_word =
                leaf_pointer + bfe!((EXTENSION_DEGREE * num_indices) as u64);
            let xor_bitflag: u32 = match index_type {
                IndexType::A => dom_len,
                IndexType::B => dom_len_half + dom_len,
            };

            let mut stack = self.init_stack_for_isolated_run();
            stack.push(bfe!(dom_len_minus_one));
            stack.push(bfe!(xor_bitflag));
            stack.push(leaf_pointer_last_word);
            stack.push(a_indices_pointer);
            stack.push(a_indices_last_word);
            // root[4] deepest, root[0] on top
            stack.extend(root.0.iter().rev().copied());
            let nondeterminism = NonDeterminism::default()
                .with_digests(authentication_paths.into_iter().flatten().collect_vec())
                .with_ram(memory);

            AlgorithmInitialState {
                stack,
                nondeterminism,
            }
        }
    }

    #[macro_rules_attr::apply(test)]
    fn test() {
        ShadowedAlgorithm::new(VerifyFriAuthenticationPaths).test();
    }

    // Perturbing any word of the expected root must trip assertion 30.
    #[macro_rules_attr::apply(proptest)]
    fn fri_authentication_fails_if_root_is_disturbed_slightly(
        seed: [u8; 32],
        #[strategy(0_usize..5)] perturbation_index: usize,
        #[filter(#perturbation != 0)] perturbation: i8,
    ) {
        let mut initial_state = VerifyFriAuthenticationPaths.pseudorandom_initial_state(seed, None);
        let top_of_stack = initial_state.stack.len() - 1;
        initial_state.stack[top_of_stack - perturbation_index] += bfe!(perturbation);

        test_assertion_failure(
            &ShadowedAlgorithm::new(VerifyFriAuthenticationPaths),
            initial_state.into(),
            &[30],
        );
    }

    // A wrong `xor_bitflag` selects the wrong node index and must trip assertion 30.
    #[macro_rules_attr::apply(proptest)]
    fn fri_authentication_fails_if_xor_bitflag_is_disturbed_slightly(seed: [u8; 32]) {
        let mut initial_state = VerifyFriAuthenticationPaths.pseudorandom_initial_state(seed, None);
        let top_of_stack = initial_state.stack.len() - 1;
        let xor_bitflag = initial_state.stack.get_mut(top_of_stack - 8).unwrap();
        *xor_bitflag *= bfe!(2); // todo: generalize this perturbation
        prop_assume!(u32::try_from(*xor_bitflag).is_ok());

        test_assertion_failure(
            &ShadowedAlgorithm::new(VerifyFriAuthenticationPaths),
            initial_state.into(),
            &[30],
        );
    }

    // Perturbing any authentication-path digest must trip assertion 30.
    #[macro_rules_attr::apply(proptest)]
    fn fri_authentication_fails_if_authentication_path_is_disturbed_slightly(
        seed: [u8; 32],
        digest_index: usize,
        #[strategy(0_usize..5)] perturbation_index: usize,
        #[filter(#perturbation != 0)] perturbation: i8,
    ) {
        let mut initial_state = VerifyFriAuthenticationPaths.pseudorandom_initial_state(seed, None);
        let auth_paths = &mut initial_state.nondeterminism.digests;
        let digest_index = digest_index % auth_paths.len();
        let Digest(ref mut auth_path_element_innards) = auth_paths[digest_index];
        auth_path_element_innards[perturbation_index] += bfe!(perturbation);

        test_assertion_failure(
            &ShadowedAlgorithm::new(VerifyFriAuthenticationPaths),
            initial_state.into(),
            &[30],
        );
    }

    // Changing an a-index so that it maps to a different node must trip assertion 30.
    #[macro_rules_attr::apply(proptest)]
    fn fri_authentication_fails_if_a_index_is_disturbed_slightly(
        seed: [u8; 32],
        perturbation_index: usize,
        #[filter(#perturbation != 0)] perturbation: i8,
    ) {
        let perturbation = bfe!(perturbation);

        let mut initial_state = VerifyFriAuthenticationPaths.pseudorandom_initial_state(seed, None);
        let top_of_stack = initial_state.stack.len() - 1;

        let a_indices_pointer = initial_state.stack[top_of_stack - 6];
        let a_indices_len = initial_state.nondeterminism.ram[&a_indices_pointer].value() as usize;
        let perturbation_index = bfe!(perturbation_index % a_indices_len);

        let perturbation_pointer = a_indices_pointer + bfe!(1) + perturbation_index;
        let ram = &mut initial_state.nondeterminism.ram;
        let a_index = ram.get_mut(&perturbation_pointer).unwrap();

        let old_a_index = a_index.value();
        *a_index += perturbation;
        let new_a_index = a_index.value();
        prop_assume!(u32::try_from(*a_index).is_ok());

        // ensure meaningful perturbation
        let dom_len_minus_one = initial_state.stack[top_of_stack - 9].value();
        let xor_bitflag = initial_state.stack[top_of_stack - 8].value();
        let node_index = |i| (i & dom_len_minus_one) ^ xor_bitflag;
        prop_assume!(node_index(old_a_index) != node_index(new_a_index));

        test_assertion_failure(
            &ShadowedAlgorithm::new(VerifyFriAuthenticationPaths),
            initial_state.into(),
            &[30],
        );
    }
}
440
#[cfg(test)]
mod benches {
    use super::*;
    use crate::test_prelude::*;

    // Benchmarks the snippet through the shadowed-algorithm harness.
    #[macro_rules_attr::apply(test)]
    fn benchmark() {
        let algorithm = ShadowedAlgorithm::new(VerifyFriAuthenticationPaths);
        algorithm.bench();
    }
}
450}