Skip to main content

tasm_lib/mmr/
calculate_new_peaks_from_leaf_mutation.rs

1use triton_vm::prelude::*;
2
3use super::leaf_index_to_mt_index_and_peak_index::MmrLeafIndexToMtIndexAndPeakIndex;
4use crate::arithmetic::u32::is_odd::IsOdd;
5use crate::arithmetic::u64::div2::Div2;
6use crate::list::get::Get;
7use crate::list::set::Set;
8use crate::prelude::*;
9
/// Recompute the affected MMR peak after one leaf has been replaced.
///
/// Starting from the new leaf digest, the snippet hashes its way up the
/// Merkle tree along the supplied authentication path, tracking its position
/// through the Merkle tree node index, and finally writes the resulting digest
/// into the list of peaks at the corresponding peak index.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub struct MmrCalculateNewPeaksFromLeafMutationMtIndices;
13
14impl BasicSnippet for MmrCalculateNewPeaksFromLeafMutationMtIndices {
15    fn parameters(&self) -> Vec<(DataType, String)> {
16        let list_type = DataType::List(Box::new(DataType::Digest));
17        vec![
18            (list_type.clone(), "*auth_path".to_string()),
19            (DataType::U64, "leaf_index".to_string()),
20            (list_type, "*peaks".to_string()),
21            (DataType::Digest, "digest".to_string()),
22            (DataType::U64, "leaf_count".to_string()),
23        ]
24    }
25
26    fn return_values(&self) -> Vec<(DataType, String)> {
27        let list_type = DataType::List(Box::new(DataType::Digest));
28        vec![
29            (list_type, "*auth_path".to_string()),
30            (DataType::U64, "leaf_index".to_string()),
31        ]
32    }
33
34    fn entrypoint(&self) -> String {
35        "tasmlib_mmr_calculate_new_peaks_from_leaf_mutation".into()
36    }
37
38    fn code(&self, library: &mut Library) -> Vec<LabelledInstruction> {
39        let leaf_index_to_mt_index = library.import(Box::new(MmrLeafIndexToMtIndexAndPeakIndex));
40        let u32_is_odd = library.import(Box::new(IsOdd));
41        let get = library.import(Box::new(Get::new(DataType::Digest)));
42        let set = library.import(Box::new(Set::new(DataType::Digest)));
43        let div_2 = library.import(Box::new(Div2));
44
45        let entrypoint = self.entrypoint();
46        let while_loop = format!("{entrypoint}_while");
47        let swap_digests = format!("{entrypoint}_swap_digests");
48
49        triton_asm!(
50            // BEFORE: _ *auth_path leaf_index_hi leaf_index_lo *peaks [digest (leaf_digest)] leaf_count_hi leaf_count_lo
51            // AFTER:  _ *auth_path leaf_index_hi leaf_index_lo
52            {entrypoint}:
53                dup 9 dup 9
54                call {leaf_index_to_mt_index}
55                // stack: _ *auth_path leaf_index_hi leaf_index_lo *peaks [digest (leaf_digest)] mt_index_hi mt_index_lo peak_index
56
57                push 0
58                // stack: _ *auth_path leaf_index_hi leaf_index_lo *peaks [digest (leaf_digest)] mt_index_hi mt_index_lo peak_index i
59
60                swap 8 swap 4 swap 1 swap 7 swap 3 swap 6 swap 2 swap 5 swap 1
61                // stack: _ *auth_path leaf_index_hi leaf_index_lo *peaks i peak_index mt_index_hi mt_index_lo [digest (leaf_digest)]
62                // rename: leaf_digest -> acc_hash
63
64                call {while_loop}
65                // _ _ *auth_path leaf_index_hi leaf_index_lo *peaks i peak_index mt_index_hi mt_index_lo [digest (acc_hash)]
66
67                dup 9 dup 8
68                // _ *auth_path leaf_index_hi leaf_index_lo *peaks i peak_index mt_index_hi mt_index_lo [digest [digest (acc_hash)] *peaks peak_index
69
70                call {set}
71                // _ *auth_path leaf_index_hi leaf_index_lo *peaks i peak_index mt_index_hi mt_index_lo
72
73                pop 5
74
75                return
76
77            // Note that this while loop is the same as one in `verify_from_memory`
78            // INVARIANT: _ *auth_path leaf_index_hi leaf_index_lo *peaks i peak_index mt_index_hi mt_index_lo [digest (leaf_digest)]
79            {while_loop}:
80                dup 6 dup 6 push 0 push 1 {&DataType::U64.compare()}
81                // _ *auth_path leaf_index_hi leaf_index_lo *peaks i peak_index mt_index_hi mt_index_lo [digest (leaf_digest)] (mt_index == 1)
82
83                skiz return
84                // _ *auth_path leaf_index_hi leaf_index_lo *peaks i peak_index mt_index_hi mt_index_lo [digest (leaf_digest)]
85
86                // declare `ap_element = auth_path[i]`
87                dup 12 dup 9 call {get}
88                // _ *auth_path leaf_index_hi leaf_index_lo *peaks i peak_index mt_index_hi mt_index_lo [digest (leaf_digest)] [digest (ap_element)]
89
90                dup 10 call {u32_is_odd} push 0 eq
91                // _ *auth_path leaf_index_hi leaf_index_lo *peaks i peak_index mt_index_hi mt_index_lo [digest (acc_hash)] [digest (ap_element)] (mt_index % 2 == 0)
92
93                skiz call {swap_digests}
94                // _ *auth_path leaf_index_hi leaf_index_lo *peaks i peak_index mt_index_hi mt_index_lo [digest (right_node)] [digest (left_node)]
95
96                hash
97                // _ *auth_path leaf_index_hi leaf_index_lo *peaks i peak_index mt_index_hi mt_index_lo [digest (new_acc_hash)]
98
99                // i -> i + 1
100                swap 8 push 1 add swap 8
101                // _ *auth_path leaf_index_hi leaf_index_lo *peaks (i + 1) peak_index mt_index_hi mt_index_lo [digest (new_acc_hash)]
102
103                // mt_index -> mt_index / 2
104                swap 6 swap 1 swap 5
105                // _ *auth_path [digest (leaf_digest)] *peaks peak_index acc_hash_0 acc_hash_1 (i + 1) acc_hash_4 acc_hash_3 acc_hash_2 mt_index_hi mt_index_lo
106
107                call {div_2}
108                // _ *auth_path [digest (leaf_digest)] *peaks peak_index acc_hash_0 acc_hash_1 (i + 1) acc_hash_4 acc_hash_3 acc_hash_2 (mt_index / 2)_hi (mt_index / 2)_lo
109
110                swap 5 swap 1 swap 6
111                // _ *auth_path [digest (leaf_digest)] *peaks (mt_index / 2)_hi (mt_index / 2)_lo peak_index (i + 1) acc_hash_4 acc_hash_3 acc_hash_2 acc_hash_1 acc_hash_0
112
113                recurse
114
115            // purpose: swap the two digests `i` (node with `acc_hash`) is left child
116            {swap_digests}:
117                // _ *auth_path leaf_index_hi leaf_index_lo *peaks i peak_index mt_index_hi mt_index_lo [digest (acc_hash)] [digest (ap_element)]
118                pick 9
119                pick 9
120                pick 9
121                pick 9
122                pick 9
123                // _ *auth_path leaf_index_hi leaf_index_lo *peaks i peak_index mt_index_hi mt_index_lo [digest (ap_element)] [digest (acc_hash)]
124
125                return
126
127        )
128    }
129}
130
131#[cfg(test)]
132mod tests {
133    use itertools::Itertools;
134
135    use super::*;
136    use crate::empty_stack;
137    use crate::mmr::MAX_MMR_HEIGHT;
138    use crate::test_prelude::*;
139    use crate::twenty_first::math::other::random_elements;
140    use crate::twenty_first::prelude::*;
141    use crate::twenty_first::util_types::mmr;
142    use crate::twenty_first::util_types::mmr::mmr_accumulator::MmrAccumulator;
143    use crate::twenty_first::util_types::mmr::mmr_accumulator::util::mmra_with_mps;
144    use crate::twenty_first::util_types::mmr::mmr_trait::LeafMutation;
145
146    // These consts are an improvement to the previous situation.
147    // I'll be the first to admit that this is not pretty, either.
148    // At least they highlight that there is an issue that should be resolved.
149    const AUTH_PATH_POINTER: BFieldElement =
150        BFieldElement::new((MAX_MMR_HEIGHT * Digest::LEN + 2) as u64);
151    const PEAKS_POINTER: BFieldElement = BFieldElement::new(1);
152
153    impl MmrCalculateNewPeaksFromLeafMutationMtIndices {
154        fn prepare_state_with_mmra(
155            &self,
156            start_mmr: &mut MmrAccumulator,
157            leaf_index: u64,
158            new_leaf: Digest,
159            auth_path: Vec<Digest>,
160        ) -> FunctionInitialState {
161            let mut stack = empty_stack();
162            stack.push(AUTH_PATH_POINTER);
163            push_encodable(&mut stack, &leaf_index);
164            stack.push(PEAKS_POINTER);
165            push_encodable(&mut stack, &new_leaf);
166            push_encodable(&mut stack, &start_mmr.num_leafs());
167
168            let mut memory = HashMap::default();
169            encode_to_memory(&mut memory, PEAKS_POINTER, &start_mmr.peaks());
170            encode_to_memory(&mut memory, AUTH_PATH_POINTER, &auth_path);
171
172            FunctionInitialState { stack, memory }
173        }
174    }
175
176    impl Function for MmrCalculateNewPeaksFromLeafMutationMtIndices {
177        fn rust_shadow(
178            &self,
179            stack: &mut Vec<BFieldElement>,
180            memory: &mut HashMap<BFieldElement, BFieldElement>,
181        ) -> Result<(), RustShadowError> {
182            let leaf_count = pop_encodable(stack)?;
183            let new_leaf = pop_encodable(stack)?;
184            let peaks_pointer = stack.pop().ok_or(RustShadowError::StackUnderflow)?;
185            let leaf_index = pop_encodable(stack)?;
186            let auth_path_pointer = stack.pop().ok_or(RustShadowError::StackUnderflow)?;
187
188            let peaks = *Vec::decode_from_memory(memory, peaks_pointer)
189                .map_err(|_| RustShadowError::DecodingError)?;
190            let auth_path = *Vec::decode_from_memory(memory, auth_path_pointer)
191                .map_err(|_| RustShadowError::DecodingError)?;
192            let mmr_mp = MmrMembershipProof::new(auth_path);
193            let new_peaks = mmr::shared_basic::calculate_new_peaks_from_leaf_mutation(
194                &peaks, leaf_count, new_leaf, leaf_index, &mmr_mp,
195            );
196            encode_to_memory(memory, peaks_pointer, &new_peaks);
197
198            stack.push(auth_path_pointer);
199            push_encodable(stack, &leaf_index);
200            Ok(())
201        }
202
203        fn pseudorandom_initial_state(
204            &self,
205            seed: [u8; 32],
206            bench_case: Option<BenchmarkCase>,
207        ) -> FunctionInitialState {
208            let mut rng = StdRng::from_seed(seed);
209            let (leaf_index, num_leafs) = match bench_case {
210                Some(BenchmarkCase::CommonCase) => ((1 << 31) - 32, 1 << 31),
211                Some(BenchmarkCase::WorstCase) => ((1 << 62) - 63, 1 << 62),
212                None => {
213                    let num_leafs = rng.random_range(1..=1 << 62);
214                    let leaf_index = rng.random_range(0..num_leafs);
215                    (leaf_index, num_leafs)
216                }
217            };
218
219            let leaf = rng.random();
220            let (mut mmra, mps) = mmra_with_mps(num_leafs, vec![(leaf_index, leaf)]);
221            let auth_path = mps[0].clone();
222            let new_leaf = rng.random();
223
224            self.prepare_state_with_mmra(
225                &mut mmra,
226                leaf_index,
227                new_leaf,
228                auth_path.authentication_path,
229            )
230        }
231    }
232
233    #[macro_rules_attr::apply(test)]
234    fn rust_shadow() {
235        ShadowedFunction::new(MmrCalculateNewPeaksFromLeafMutationMtIndices).test();
236    }
237
238    #[macro_rules_attr::apply(test)]
239    fn mmra_leaf_mutate_test_single() {
240        let digest0 = Tip5::hash(&BFieldElement::new(4545));
241        let digest1 = Tip5::hash(&BFieldElement::new(12345));
242        let mut mmr = MmrAccumulator::new_from_leafs(vec![]);
243        mmr.append(digest0);
244        let expected_final_mmra = MmrAccumulator::new_from_leafs(vec![digest1]);
245        let mutated_index = 0;
246        prop_calculate_new_peaks_from_leaf_mutation(
247            &mut mmr,
248            digest1,
249            mutated_index,
250            expected_final_mmra,
251            vec![],
252        );
253    }
254
255    fn mmra_leaf_mutate_test_n_leafs(leaf_count: usize) {
256        let init_leaf_digests: Vec<Digest> = random_elements(leaf_count);
257        let new_leaf: Digest = rand::rng().random();
258
259        let (mmra, mps) = mmra_with_mps(
260            leaf_count as u64,
261            init_leaf_digests
262                .iter()
263                .clone()
264                .enumerate()
265                .map(|(i, &d)| (i as u64, d))
266                .collect_vec(),
267        );
268
269        for mutated_index in 0..leaf_count {
270            let auth_path = mps[mutated_index].authentication_path.clone();
271            let mut final_digests = init_leaf_digests.clone();
272            final_digests[mutated_index] = new_leaf;
273            let expected_final_mmra = MmrAccumulator::new_from_leafs(final_digests);
274            prop_calculate_new_peaks_from_leaf_mutation(
275                &mut mmra.clone(),
276                new_leaf,
277                mutated_index as u64,
278                expected_final_mmra,
279                auth_path,
280            );
281        }
282    }
283
284    #[macro_rules_attr::apply(test)]
285    fn mmra_leaf_mutate_test_many_leaf_sizes() {
286        for leaf_count in 1..30 {
287            mmra_leaf_mutate_test_n_leafs(leaf_count);
288        }
289    }
290
291    #[macro_rules_attr::apply(test)]
292    fn mmra_leaf_mutate_test_other_leaf_sizes() {
293        for leaf_count in [127, 128] {
294            mmra_leaf_mutate_test_n_leafs(leaf_count);
295        }
296    }
297
298    #[macro_rules_attr::apply(test)]
299    fn mmra_leaf_mutate_big() {
300        for log_sizes in [15u64, 20, 25, 32, 35, 40, 45, 50, 55, 60, 62, 63] {
301            println!("log_sizes = {log_sizes}");
302            let init_peak_digests: Vec<Digest> = random_elements(log_sizes as usize);
303            let new_leaf: Digest = rand::rng().random();
304            let mut init_mmr =
305                MmrAccumulator::init(init_peak_digests.clone(), (1u64 << log_sizes) - 1);
306
307            let mut final_peaks = init_peak_digests.clone();
308            final_peaks[log_sizes as usize - 1] = new_leaf;
309            let expected_final_mmra = MmrAccumulator::init(final_peaks, (1u64 << log_sizes) - 1);
310            prop_calculate_new_peaks_from_leaf_mutation(
311                &mut init_mmr,
312                new_leaf,
313                (1u64 << log_sizes) - 2,
314                expected_final_mmra,
315                vec![],
316            );
317        }
318    }
319
320    #[macro_rules_attr::apply(test)]
321    fn mmra_leaf_mutate_advanced() {
322        for log_size in [31, 63] {
323            println!("log_sizes = {log_size}");
324            let init_peak_digests: Vec<Digest> = random_elements(log_size as usize);
325            let new_leaf: Digest = rand::rng().random();
326            let before_insertion_mmr =
327                MmrAccumulator::init(init_peak_digests.clone(), (1u64 << log_size) - 1);
328
329            // Insert a leaf such that a very long (log_size long) auth path is returned
330            let mut init_mmr = before_insertion_mmr.clone();
331            let mp = init_mmr.append(rand::rng().random());
332
333            let mut final_mmr = init_mmr.clone();
334            let leaf_mutation =
335                LeafMutation::new(before_insertion_mmr.num_leafs(), new_leaf, mp.clone());
336            final_mmr.mutate_leaf(leaf_mutation);
337
338            // Mutate the last element for which we just acquired an authentication path
339            prop_calculate_new_peaks_from_leaf_mutation(
340                &mut init_mmr,
341                new_leaf,
342                (1u64 << log_size) - 1,
343                final_mmr,
344                mp.authentication_path,
345            );
346        }
347    }
348
349    fn prop_calculate_new_peaks_from_leaf_mutation(
350        start_mmr: &mut MmrAccumulator,
351        new_leaf: Digest,
352        new_leaf_index: u64,
353        expected_mmr: MmrAccumulator,
354        auth_path: Vec<Digest>,
355    ) {
356        let mmr_new_peaks = MmrCalculateNewPeaksFromLeafMutationMtIndices;
357        let init_exec_state =
358            mmr_new_peaks.prepare_state_with_mmra(start_mmr, new_leaf_index, new_leaf, auth_path);
359
360        // AFTER: _ *auth_path leaf_index_hi leaf_index_lo
361        let mut expected_final_stack = empty_stack();
362        expected_final_stack.push(AUTH_PATH_POINTER);
363        expected_final_stack.push(BFieldElement::new(new_leaf_index >> 32));
364        expected_final_stack.push(BFieldElement::new(new_leaf_index & u32::MAX as u64));
365
366        let vm_output = test_rust_equivalence_given_complete_state(
367            &ShadowedFunction::new(MmrCalculateNewPeaksFromLeafMutationMtIndices),
368            &init_exec_state.stack,
369            &[],
370            &NonDeterminism::default().with_ram(init_exec_state.memory),
371            &None,
372            Some(&expected_final_stack),
373        );
374
375        // Find produced MMR
376        let final_memory = vm_output.ram;
377        let produced_peaks = *Vec::decode_from_memory(&final_memory, PEAKS_POINTER).unwrap();
378        let produced_mmr = MmrAccumulator::init(produced_peaks, start_mmr.num_leafs());
379
380        // Verify that both code paths produce the same MMR
381        assert_eq!(expected_mmr, produced_mmr);
382
383        // Verify that auth paths is still value
384        let auth_path = *Vec::decode_from_memory(&final_memory, AUTH_PATH_POINTER).unwrap();
385        let mmr_mp = MmrMembershipProof::new(auth_path);
386        assert!(
387            mmr_mp.verify(
388                new_leaf_index,
389                new_leaf,
390                &produced_mmr.peaks(),
391                produced_mmr.num_leafs(),
392            ),
393            "TASM-produced authentication path must be valid"
394        );
395
396        // Extra checks because paranoia
397        let mut expected_final_mmra_double_check = start_mmr.to_accumulator();
398        expected_final_mmra_double_check.mutate_leaf(LeafMutation::new(
399            new_leaf_index,
400            new_leaf,
401            mmr_mp.clone(),
402        ));
403        assert_eq!(expected_final_mmra_double_check, produced_mmr);
404        assert!(mmr_mp.verify(
405            new_leaf_index,
406            new_leaf,
407            &expected_final_mmra_double_check.peaks(),
408            expected_final_mmra_double_check.num_leafs()
409        ));
410    }
411}
412
#[cfg(test)]
mod benches {
    use super::*;
    use crate::test_prelude::*;

    /// Benchmark the snippet through its shadowed-function wrapper.
    #[macro_rules_attr::apply(test)]
    fn benchmark() {
        let snippet = MmrCalculateNewPeaksFromLeafMutationMtIndices;
        ShadowedFunction::new(snippet).bench();
    }
}