tasm_lib/mmr/calculate_new_peaks_from_leaf_mutation.rs

1use triton_vm::prelude::*;
2
3use super::leaf_index_to_mt_index_and_peak_index::MmrLeafIndexToMtIndexAndPeakIndex;
4use crate::arithmetic::u32::is_odd::IsOdd;
5use crate::arithmetic::u64::div2::Div2;
6use crate::list::get::Get;
7use crate::list::set::Set;
8use crate::prelude::*;
9
/// Computes the new MMR peaks after a single leaf is mutated, by hashing the new
/// leaf up its Merkle tree while tracking the node's Merkle-tree index.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct MmrCalculateNewPeaksFromLeafMutationMtIndices;
13
14impl BasicSnippet for MmrCalculateNewPeaksFromLeafMutationMtIndices {
15    fn inputs(&self) -> Vec<(DataType, String)> {
16        let list_type = DataType::List(Box::new(DataType::Digest));
17        vec![
18            (list_type.clone(), "*auth_path".to_string()),
19            (DataType::U64, "leaf_index".to_string()),
20            (list_type, "*peaks".to_string()),
21            (DataType::Digest, "digest".to_string()),
22            (DataType::U64, "leaf_count".to_string()),
23        ]
24    }
25
26    fn outputs(&self) -> Vec<(DataType, String)> {
27        let list_type = DataType::List(Box::new(DataType::Digest));
28        vec![
29            (list_type, "*auth_path".to_string()),
30            (DataType::U64, "leaf_index".to_string()),
31        ]
32    }
33
34    fn entrypoint(&self) -> String {
35        "tasmlib_mmr_calculate_new_peaks_from_leaf_mutation".into()
36    }
37
38    fn code(&self, library: &mut Library) -> Vec<LabelledInstruction> {
39        let leaf_index_to_mt_index = library.import(Box::new(MmrLeafIndexToMtIndexAndPeakIndex));
40        let u32_is_odd = library.import(Box::new(IsOdd));
41        let get = library.import(Box::new(Get::new(DataType::Digest)));
42        let set = library.import(Box::new(Set::new(DataType::Digest)));
43        let div_2 = library.import(Box::new(Div2));
44
45        let entrypoint = self.entrypoint();
46        let while_loop = format!("{entrypoint}_while");
47        let swap_digests = format!("{entrypoint}_swap_digests");
48
49        triton_asm!(
50            // BEFORE: _ *auth_path leaf_index_hi leaf_index_lo *peaks [digest (leaf_digest)] leaf_count_hi leaf_count_lo
51            // AFTER:  _ *auth_path leaf_index_hi leaf_index_lo
52            {entrypoint}:
53                dup 9 dup 9
54                call {leaf_index_to_mt_index}
55                // stack: _ *auth_path leaf_index_hi leaf_index_lo *peaks [digest (leaf_digest)] mt_index_hi mt_index_lo peak_index
56
57                push 0
58                // stack: _ *auth_path leaf_index_hi leaf_index_lo *peaks [digest (leaf_digest)] mt_index_hi mt_index_lo peak_index i
59
60                swap 8 swap 4 swap 1 swap 7 swap 3 swap 6 swap 2 swap 5 swap 1
61                // stack: _ *auth_path leaf_index_hi leaf_index_lo *peaks i peak_index mt_index_hi mt_index_lo [digest (leaf_digest)]
62                // rename: leaf_digest -> acc_hash
63
64                call {while_loop}
65                // _ _ *auth_path leaf_index_hi leaf_index_lo *peaks i peak_index mt_index_hi mt_index_lo [digest (acc_hash)]
66
67                dup 9 dup 8
68                // _ *auth_path leaf_index_hi leaf_index_lo *peaks i peak_index mt_index_hi mt_index_lo [digest [digest (acc_hash)] *peaks peak_index
69
70                call {set}
71                // _ *auth_path leaf_index_hi leaf_index_lo *peaks i peak_index mt_index_hi mt_index_lo
72
73                pop 5
74
75                return
76
77            // Note that this while loop is the same as one in `verify_from_memory`
78            // INVARIANT: _ *auth_path leaf_index_hi leaf_index_lo *peaks i peak_index mt_index_hi mt_index_lo [digest (leaf_digest)]
79            {while_loop}:
80                dup 6 dup 6 push 0 push 1 {&DataType::U64.compare()}
81                // _ *auth_path leaf_index_hi leaf_index_lo *peaks i peak_index mt_index_hi mt_index_lo [digest (leaf_digest)] (mt_index == 1)
82
83                skiz return
84                // _ *auth_path leaf_index_hi leaf_index_lo *peaks i peak_index mt_index_hi mt_index_lo [digest (leaf_digest)]
85
86                // declare `ap_element = auth_path[i]`
87                dup 12 dup 9 call {get}
88                // _ *auth_path leaf_index_hi leaf_index_lo *peaks i peak_index mt_index_hi mt_index_lo [digest (leaf_digest)] [digest (ap_element)]
89
90                dup 10 call {u32_is_odd} push 0 eq
91                // _ *auth_path leaf_index_hi leaf_index_lo *peaks i peak_index mt_index_hi mt_index_lo [digest (acc_hash)] [digest (ap_element)] (mt_index % 2 == 0)
92
93                skiz call {swap_digests}
94                // _ *auth_path leaf_index_hi leaf_index_lo *peaks i peak_index mt_index_hi mt_index_lo [digest (right_node)] [digest (left_node)]
95
96                hash
97                // _ *auth_path leaf_index_hi leaf_index_lo *peaks i peak_index mt_index_hi mt_index_lo [digest (new_acc_hash)]
98
99                // i -> i + 1
100                swap 8 push 1 add swap 8
101                // _ *auth_path leaf_index_hi leaf_index_lo *peaks (i + 1) peak_index mt_index_hi mt_index_lo [digest (new_acc_hash)]
102
103                // mt_index -> mt_index / 2
104                swap 6 swap 1 swap 5
105                // _ *auth_path [digest (leaf_digest)] *peaks peak_index acc_hash_0 acc_hash_1 (i + 1) acc_hash_4 acc_hash_3 acc_hash_2 mt_index_hi mt_index_lo
106
107                call {div_2}
108                // _ *auth_path [digest (leaf_digest)] *peaks peak_index acc_hash_0 acc_hash_1 (i + 1) acc_hash_4 acc_hash_3 acc_hash_2 (mt_index / 2)_hi (mt_index / 2)_lo
109
110                swap 5 swap 1 swap 6
111                // _ *auth_path [digest (leaf_digest)] *peaks (mt_index / 2)_hi (mt_index / 2)_lo peak_index (i + 1) acc_hash_4 acc_hash_3 acc_hash_2 acc_hash_1 acc_hash_0
112
113                recurse
114
115            // purpose: swap the two digests `i` (node with `acc_hash`) is left child
116            {swap_digests}:
117                // _ *auth_path leaf_index_hi leaf_index_lo *peaks i peak_index mt_index_hi mt_index_lo [digest (acc_hash)] [digest (ap_element)]
118                pick 9
119                pick 9
120                pick 9
121                pick 9
122                pick 9
123                // _ *auth_path leaf_index_hi leaf_index_lo *peaks i peak_index mt_index_hi mt_index_lo [digest (ap_element)] [digest (acc_hash)]
124
125                return
126
127        )
128    }
129}
130
#[cfg(test)]
mod tests {
    use itertools::Itertools;

    use super::*;
    use crate::empty_stack;
    use crate::mmr::MAX_MMR_HEIGHT;
    use crate::test_prelude::*;
    use crate::twenty_first::math::other::random_elements;
    use crate::twenty_first::prelude::*;
    use crate::twenty_first::util_types::mmr;
    use crate::twenty_first::util_types::mmr::mmr_accumulator::MmrAccumulator;
    use crate::twenty_first::util_types::mmr::mmr_accumulator::util::mmra_with_mps;
    use crate::twenty_first::util_types::mmr::mmr_trait::LeafMutation;

    // Fixed RAM locations for the two lists the snippet reads. The peaks list starts
    // at address 1. The authentication path is placed just past the largest peaks
    // list that can start there (assumed layout: one length word plus
    // `MAX_MMR_HEIGHT` digests), so the two lists cannot overlap.
    //
    // These hard-coded pointers improve on the previous situation, but they are
    // still not pretty. They mark an open issue: test memory should be allocated
    // properly.
    const AUTH_PATH_POINTER: BFieldElement =
        BFieldElement::new((MAX_MMR_HEIGHT * Digest::LEN + 2) as u64);
    const PEAKS_POINTER: BFieldElement = BFieldElement::new(1);

    impl MmrCalculateNewPeaksFromLeafMutationMtIndices {
        /// Builds the initial VM state for one invocation of the snippet.
        ///
        /// Stack: `*auth_path leaf_index *peaks new_leaf leaf_count`, matching the
        /// snippet's `inputs`. Memory: `start_mmr`'s peaks at `PEAKS_POINTER` and
        /// `auth_path` at `AUTH_PATH_POINTER`. `start_mmr` is only read.
        fn prepare_state_with_mmra(
            &self,
            start_mmr: &MmrAccumulator,
            leaf_index: u64,
            new_leaf: Digest,
            auth_path: Vec<Digest>,
        ) -> FunctionInitialState {
            let mut stack = empty_stack();
            stack.push(AUTH_PATH_POINTER);
            push_encodable(&mut stack, &leaf_index);
            stack.push(PEAKS_POINTER);
            push_encodable(&mut stack, &new_leaf);
            push_encodable(&mut stack, &start_mmr.num_leafs());

            let mut memory = HashMap::default();
            encode_to_memory(&mut memory, PEAKS_POINTER, &start_mmr.peaks());
            encode_to_memory(&mut memory, AUTH_PATH_POINTER, &auth_path);

            FunctionInitialState { stack, memory }
        }
    }

    impl Function for MmrCalculateNewPeaksFromLeafMutationMtIndices {
        /// Reference implementation. It recomputes the peaks with twenty-first's
        /// `calculate_new_peaks_from_leaf_mutation`, overwrites the peaks list in
        /// memory, and leaves `*auth_path leaf_index` on the stack.
        fn rust_shadow(
            &self,
            stack: &mut Vec<BFieldElement>,
            memory: &mut HashMap<BFieldElement, BFieldElement>,
        ) {
            let leaf_count = pop_encodable(stack);
            let new_leaf = pop_encodable(stack);
            let peaks_pointer = stack.pop().unwrap();
            let leaf_index = pop_encodable(stack);
            let auth_path_pointer = stack.pop().unwrap();

            let peaks = *Vec::decode_from_memory(memory, peaks_pointer).unwrap();
            let auth_path = *Vec::decode_from_memory(memory, auth_path_pointer).unwrap();
            let mmr_mp = MmrMembershipProof::new(auth_path);
            let new_peaks = mmr::shared_basic::calculate_new_peaks_from_leaf_mutation(
                &peaks, leaf_count, new_leaf, leaf_index, &mmr_mp,
            );
            encode_to_memory(memory, peaks_pointer, &new_peaks);

            stack.push(auth_path_pointer);
            push_encodable(stack, &leaf_index);
        }

        /// Benchmark cases mutate a leaf near the end of an MMR with 2^31 (common) or
        /// 2^62 (worst) leafs. Otherwise the leaf count and leaf index are random.
        fn pseudorandom_initial_state(
            &self,
            seed: [u8; 32],
            bench_case: Option<BenchmarkCase>,
        ) -> FunctionInitialState {
            let mut rng = StdRng::from_seed(seed);
            let (leaf_index, num_leafs) = match bench_case {
                Some(BenchmarkCase::CommonCase) => ((1 << 31) - 32, 1 << 31),
                Some(BenchmarkCase::WorstCase) => ((1 << 62) - 63, 1 << 62),
                None => {
                    let num_leafs = rng.random_range(1..=1 << 62);
                    let leaf_index = rng.random_range(0..num_leafs);
                    (leaf_index, num_leafs)
                }
            };

            let leaf = rng.random();
            let (mmra, mps) = mmra_with_mps(num_leafs, vec![(leaf_index, leaf)]);
            let auth_path = mps[0].authentication_path.clone();
            let new_leaf = rng.random();

            self.prepare_state_with_mmra(&mmra, leaf_index, new_leaf, auth_path)
        }
    }

    #[test]
    fn rust_shadow() {
        ShadowedFunction::new(MmrCalculateNewPeaksFromLeafMutationMtIndices).test();
    }

    /// A one-leaf MMR: the leaf is its own peak, so the authentication path is empty.
    #[test]
    fn mmra_leaf_mutate_test_single() {
        let digest0 = Tip5::hash(&BFieldElement::new(4545));
        let digest1 = Tip5::hash(&BFieldElement::new(12345));
        let mut mmr = MmrAccumulator::new_from_leafs(vec![]);
        mmr.append(digest0);
        let expected_final_mmra = MmrAccumulator::new_from_leafs(vec![digest1]);
        let mutated_index = 0;
        prop_calculate_new_peaks_from_leaf_mutation(
            &mmr,
            digest1,
            mutated_index,
            expected_final_mmra,
            vec![],
        );
    }

    /// Mutates every leaf of a random `leaf_count`-leaf MMR, one at a time.
    /// Each result is compared against an MMR rebuilt from the modified leaf list.
    fn mmra_leaf_mutate_test_n_leafs(leaf_count: usize) {
        let init_leaf_digests: Vec<Digest> = random_elements(leaf_count);
        let new_leaf: Digest = rand::rng().random();

        let (mmra, mps) = mmra_with_mps(
            leaf_count as u64,
            init_leaf_digests
                .iter()
                .enumerate()
                .map(|(i, &d)| (i as u64, d))
                .collect_vec(),
        );

        for mutated_index in 0..leaf_count {
            let auth_path = mps[mutated_index].authentication_path.clone();
            let mut final_digests = init_leaf_digests.clone();
            final_digests[mutated_index] = new_leaf;
            let expected_final_mmra = MmrAccumulator::new_from_leafs(final_digests);
            prop_calculate_new_peaks_from_leaf_mutation(
                &mmra,
                new_leaf,
                mutated_index as u64,
                expected_final_mmra,
                auth_path,
            );
        }
    }

    #[test]
    fn mmra_leaf_mutate_test_many_leaf_sizes() {
        for leaf_count in 1..30 {
            mmra_leaf_mutate_test_n_leafs(leaf_count);
        }
    }

    #[test]
    fn mmra_leaf_mutate_test_other_leaf_sizes() {
        for leaf_count in [127, 128] {
            mmra_leaf_mutate_test_n_leafs(leaf_count);
        }
    }

    /// With `2^log_size - 1` leafs, the last peak is a lone leaf. Mutating that leaf
    /// needs an empty authentication path and replaces only the last peak.
    #[test]
    fn mmra_leaf_mutate_big() {
        for log_size in [15u64, 20, 25, 32, 35, 40, 45, 50, 55, 60, 62, 63] {
            println!("log_size = {log_size}");
            let num_leafs = (1u64 << log_size) - 1;
            let init_peak_digests: Vec<Digest> = random_elements(log_size as usize);
            let new_leaf: Digest = rand::rng().random();

            let mut final_peaks = init_peak_digests.clone();
            final_peaks[log_size as usize - 1] = new_leaf;
            let init_mmr = MmrAccumulator::init(init_peak_digests, num_leafs);
            let expected_final_mmra = MmrAccumulator::init(final_peaks, num_leafs);
            prop_calculate_new_peaks_from_leaf_mutation(
                &init_mmr,
                new_leaf,
                num_leafs - 1,
                expected_final_mmra,
                vec![],
            );
        }
    }

    #[test]
    fn mmra_leaf_mutate_advanced() {
        for log_size in [31u64, 63] {
            println!("log_size = {log_size}");
            let init_peak_digests: Vec<Digest> = random_elements(log_size as usize);
            let new_leaf: Digest = rand::rng().random();
            let before_insertion_mmr =
                MmrAccumulator::init(init_peak_digests, (1u64 << log_size) - 1);

            // Insert a leaf such that a very long (log_size long) auth path is returned
            let mut init_mmr = before_insertion_mmr.clone();
            let mp = init_mmr.append(rand::rng().random());
            let mutated_index = before_insertion_mmr.num_leafs();

            let mut final_mmr = init_mmr.clone();
            final_mmr.mutate_leaf(LeafMutation::new(mutated_index, new_leaf, mp.clone()));

            // Mutate the last element for which we just acquired an authentication path
            prop_calculate_new_peaks_from_leaf_mutation(
                &init_mmr,
                new_leaf,
                mutated_index,
                final_mmr,
                mp.authentication_path,
            );
        }
    }

    /// Runs the snippet on `start_mmr` and checks the outcome.
    ///
    /// Asserts that:
    /// - TASM and the Rust shadow agree, and the final stack is
    ///   `_ *auth_path leaf_index_hi leaf_index_lo`;
    /// - the peaks written to memory equal `expected_mmr`'s;
    /// - the authentication path left in memory still verifies `new_leaf`;
    /// - mutating a copy of `start_mmr` directly yields the same accumulator.
    fn prop_calculate_new_peaks_from_leaf_mutation(
        start_mmr: &MmrAccumulator,
        new_leaf: Digest,
        new_leaf_index: u64,
        expected_mmr: MmrAccumulator,
        auth_path: Vec<Digest>,
    ) {
        let snippet = MmrCalculateNewPeaksFromLeafMutationMtIndices;
        let init_exec_state =
            snippet.prepare_state_with_mmra(start_mmr, new_leaf_index, new_leaf, auth_path);

        // AFTER: _ *auth_path leaf_index_hi leaf_index_lo
        let mut expected_final_stack = empty_stack();
        expected_final_stack.push(AUTH_PATH_POINTER);
        expected_final_stack.push(BFieldElement::new(new_leaf_index >> 32));
        expected_final_stack.push(BFieldElement::new(new_leaf_index & u32::MAX as u64));

        let vm_output = test_rust_equivalence_given_complete_state(
            &ShadowedFunction::new(snippet),
            &init_exec_state.stack,
            &[],
            &NonDeterminism::default().with_ram(init_exec_state.memory),
            &None,
            Some(&expected_final_stack),
        );

        // Read back the peaks the snippet wrote to memory.
        let final_memory = vm_output.ram;
        let produced_peaks = *Vec::decode_from_memory(&final_memory, PEAKS_POINTER).unwrap();
        let produced_mmr = MmrAccumulator::init(produced_peaks, start_mmr.num_leafs());

        // The TASM-produced MMR must match the independently computed expectation.
        assert_eq!(expected_mmr, produced_mmr);

        // The authentication path must still be valid for the new leaf.
        let auth_path = *Vec::decode_from_memory(&final_memory, AUTH_PATH_POINTER).unwrap();
        let mmr_mp = MmrMembershipProof::new(auth_path);
        assert!(
            mmr_mp.verify(
                new_leaf_index,
                new_leaf,
                &produced_mmr.peaks(),
                produced_mmr.num_leafs(),
            ),
            "TASM-produced authentication path must be valid"
        );

        // Extra checks because paranoia: mutate the start MMR directly and compare.
        let mut expected_final_mmra_double_check = start_mmr.to_accumulator();
        expected_final_mmra_double_check.mutate_leaf(LeafMutation::new(
            new_leaf_index,
            new_leaf,
            mmr_mp.clone(),
        ));
        assert_eq!(expected_final_mmra_double_check, produced_mmr);
        assert!(mmr_mp.verify(
            new_leaf_index,
            new_leaf,
            &expected_final_mmra_double_check.peaks(),
            expected_final_mmra_double_check.num_leafs()
        ));
    }
}
409
#[cfg(test)]
mod benches {
    use super::*;
    use crate::test_prelude::*;

    /// Runs the tasm-lib benchmark harness on this snippet. It presumably uses the
    /// `CommonCase`/`WorstCase` states built by `pseudorandom_initial_state`.
    #[test]
    fn benchmark() {
        let snippet = MmrCalculateNewPeaksFromLeafMutationMtIndices;
        let shadowed = ShadowedFunction::new(snippet);
        shadowed.bench();
    }
}
419}