tasm_lib/hashing/
merkle_root.rs

1use std::collections::HashMap;
2
3use triton_vm::prelude::*;
4
5use crate::prelude::*;
6use crate::traits::basic_snippet::Reviewer;
7use crate::traits::basic_snippet::SignOffFingerprint;
8
/// Compute the Merkle root of a slice of `Digest`s. Corresponds to
/// `MerkleTree::`[`sequential_new`][new]`(leafs).`[`root`][root]`()`.
///
/// ### Behavior
///
/// ```text
/// BEFORE: _ *leafs
/// AFTER:  _ [root: Digest]
/// ```
///
/// ### Preconditions
///
/// - `*leafs` points to a list of Digests
/// - the length of the pointed-to list is greater than 0
/// - the length of the pointed-to list is a power of 2
/// - the length of the pointed-to list is a u32
///
/// If the length of the pointed-to list is not a power of 2 (which includes
/// a length of 0), the snippet crashes the VM with error ID
/// [`MerkleRoot::NUM_LEAFS_NOT_POWER_OF_2_ERROR_ID`].
///
/// ### Postconditions
///
/// None.
///
/// ### Memory
///
/// The snippet dynamically allocates memory and writes every parent digest
/// it computes, up to and including the root, into that allocation. The leafs
/// themselves are only read.
///
/// [new]: twenty_first::prelude::MerkleTree::sequential_new
/// [root]: twenty_first::prelude::MerkleTree::root
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub struct MerkleRoot;
34
impl MerkleRoot {
    /// Error ID of the assertion that fails when the number of leafs is not
    /// a power of 2. An empty list of leafs triggers the same assertion.
    pub const NUM_LEAFS_NOT_POWER_OF_2_ERROR_ID: i128 = 431;
}
38
impl BasicSnippet for MerkleRoot {
    /// The only input is a pointer to the list of leaf digests.
    fn inputs(&self) -> Vec<(DataType, String)> {
        vec![(
            DataType::List(Box::new(DataType::Digest)),
            "*leafs".to_string(),
        )]
    }

    /// The only output is the Merkle root.
    fn outputs(&self) -> Vec<(DataType, String)> {
        vec![(DataType::Digest, "root".to_string())]
    }

    /// Label of the snippet; also the prefix of the labels of both helper
    /// loops defined in [`Self::code`].
    fn entrypoint(&self) -> String {
        "tasmlib_hashing_merkle_root".to_string()
    }

    /// The snippet's assembly. It consists of three parts:
    ///
    /// 1. the entrypoint, which validates the number of leafs, allocates memory
    ///    for the parent digests, and reads the root once it is computed,
    /// 2. an outer loop that runs once per tree layer, halving the layer
    ///    length each time until it equals 1, and
    /// 3. an inner loop that hashes pairs of sibling digests of the current
    ///    layer into the next layer, walking both layers from their last
    ///    element towards their first.
    fn code(&self, library: &mut Library) -> Vec<LabelledInstruction> {
        let dyn_malloc = library.import(Box::new(DynMalloc));

        let entrypoint = self.entrypoint();
        let calculate_parent_digests = format!("{entrypoint}_calculate_parent_digests");
        let next_layer_loop = format!("{entrypoint}_next_layer_loop");

        triton_asm!(
            {entrypoint}:
                // _ *leafs

                // The length of the list is stored at `*leafs`. Reading it
                // moves the pointer down by one word, which `addi 1` undoes.
                read_mem 1
                addi 1
                // _ leafs_len *leafs

                /* assert the number of leafs is some power of 2 */
                // A number is a power of 2 iff exactly one of its bits is set.
                // In particular, this check rejects a length of 0.
                dup 1
                pop_count
                push 1
                eq
                assert error_id {Self::NUM_LEAFS_NOT_POWER_OF_2_ERROR_ID}

                // allocate the memory all parent digests are written to
                call {dyn_malloc}
                // _ leafs_len *leafs *parent_level

                /* adjust `*parent_level` to point to last element, first word */
                // Parent digests are written back to front, starting at
                // offset `(leafs_len - 1) * Digest::LEN` in the allocation.
                dup 2
                addi -1
                push {Digest::LEN}
                mul
                add
                // _ leafs_len *leafs (*parent_level + (leafs_len - 1) * Digest::LEN)
                // _ leafs_len *leafs *parent_level'

                /* adjust `*leafs` to point to last element, last word */
                // The list's elements start one word after `*leafs`, so the
                // last word of the last element is at `*leafs + leafs_len * Digest::LEN`.
                pick 1
                dup 2
                push {Digest::LEN}
                mul
                add
                // _ leafs_len *parent_level' (*leafs + leafs_len * Digest::LEN)
                // _ leafs_len *parent_level' *leafs'

                call {next_layer_loop}
                // The loop ends with the top of the stack pointing to the
                // last word of the only digest in the final layer: the root.
                // _ 1 *address (*root + Digest::LEN - 1)

                // drop the two elements below the pointer
                place 2
                pop 2
                // _ (*root + Digest::LEN - 1)

                read_mem {Digest::LEN}
                // _ [root: Digest] (*root - 1)

                pop 1
                // _ [root: Digest]

               return

            // Computes one layer of the tree per iteration.
            // INVARIANT: _ current_len *next_level[last]_first_word *current_level[last]_last_word
            {next_layer_loop}:
                // _ current_len *next_level *current_level

                /* end loop if `current_len == 1` */
                dup 2
                push 1
                eq
                skiz
                    return
                // _ current_len *next_level *current_level

                /* update `current_len` */
                // Multiplying with the field inverse of 2 is an exact division
                // by 2 here: `current_len` is a power of 2 and greater than 1.
                pick 2
                push {bfe!(2).inverse()}
                        hint one_half = stack[0]
                mul
                place 2
                // _ (current_len/2) *next_level *current_level

                /* set up termination condition for parent calculation loop:
                 * `*next_level - current_len / 2 * Digest::LEN`
                 * (`current_len` refers to the length before the update above;
                 * the halved value is what is on the stack now)
                 */
                dup 1
                dup 3
                push {-(Digest::LEN as isize)}
                mul
                add
                // _ (current_len/2) *next_level *current_level *next_level_stop
                // _ (current_len/2) *next_level *current_elem  *next_elem_stop

                // Arrange the stack for the inner loop. The four zeros are
                // padding: they put `*next_elem` and `*next_elem_stop` at the
                // stack positions that `recurse_or_return` compares.
                dup 2
                push 0
                push 0
                push 0
                push 0
                pick 6
                // _ (current_len/2) *next_level *next_elem_stop *next_level 0 0 0 0 *current_elem

                call {calculate_parent_digests}
                // remove the inner loop's working elements, i.e., everything
                // above `*next_elem_stop`
                pop 5
                pop 1
                // _ (current_len/2) *next_level *next_elem_stop

                /* Update `*current_level` based on `*next_level` */
                // The layer just written becomes the current layer, while
                // `*next_elem_stop` already points to the first word of the
                // last element of the layer to be written next.
                pick 1
                // _ (current_len/2) *next_elem_stop *next_level

                // move from the first to the last word of the last element
                addi {Digest::LEN - 1}
                // _ (current_len/2) *next_level' *current_level'

                recurse

            // Populate the `*next` digest list
            // INVARIANT: _ *next_elem_stop *next_elem 0 0 0 0 *curr_elem
            {calculate_parent_digests}:
                // read two sibling digests, back to front: first the right
                // one, then the left one
                read_mem {Digest::LEN}
                read_mem {Digest::LEN}
                // _ *next_elem_stop *next_elem 0 0 0 0 [right] [left] (*curr_elem[n] - 10)
                // _ *next_elem_stop *next_elem 0 0 0 0 [right] [left] *curr_elem[n - 2]
                // _ *next_elem_stop *next_elem 0 0 0 0 [right] [left] *curr_elem'

                // stash the read pointer below the two digests to hash
                place 10
                // _ *next_elem_stop *next_elem 0 0 0 0 *curr_elem' [right] [left]

                hash
                // _ *next_elem_stop *next_elem 0 0 0 0 *curr_elem' [parent_digest]

                pick 10
                // _ *next_elem_stop 0 0 0 0 *curr_elem' [parent_digest] *next_elem

                write_mem {Digest::LEN}
                // _ *next_elem_stop 0 0 0 0 *curr_elem' (*next_elem + 5)

                // Writing advanced the pointer by one digest. Go back by two
                // digests to point to the preceding element of the next layer.
                addi -10
                // _ *next_elem_stop 0 0 0 0 *curr_elem' (*next_elem - 5)
                // _ *next_elem_stop 0 0 0 0 *curr_elem' *next_elem[n-1]
                // _ *next_elem_stop 0 0 0 0 *curr_elem' *next_elem'

                // restore the stack layout required by the invariant
                place 5
                // _ *next_elem_stop *next_elem' 0 0 0 0 *curr_elem'

                // returns once `*next_elem'` equals `*next_elem_stop`
                recurse_or_return
        )
    }

    /// Reviewers that have signed off on this snippet, each with the
    /// fingerprint that was signed off on.
    ///
    /// NOTE(review): presumably, the fingerprint is derived from the
    /// instructions returned by [`Self::code`], in which case any change to
    /// those instructions invalidates the sign-off. Confirm before editing.
    fn sign_offs(&self) -> HashMap<Reviewer, SignOffFingerprint> {
        let mut sign_offs = HashMap::new();
        sign_offs.insert(Reviewer("ferdinand"), 0x1c30ac983fdca9da.into());
        sign_offs
    }
}
205
#[cfg(test)]
mod tests {
    use proptest::collection::vec;
    use twenty_first::util_types::merkle_tree::MerkleTree;

    use super::*;
    use crate::rust_shadowing_helper_functions::dyn_malloc::dynamic_allocator;
    use crate::test_prelude::*;

    impl MerkleRoot {
        /// Set up an initial state for the snippet: `leafs` are encoded to
        /// memory at `digests_pointer`, and `digests_pointer` is the only
        /// element on top of the otherwise default initial stack.
        fn init_state(
            &self,
            leafs: Vec<Digest>,
            digests_pointer: BFieldElement,
        ) -> FunctionInitialState {
            let mut memory = HashMap::new();
            encode_to_memory(&mut memory, digests_pointer, &leafs);
            let mut stack = self.init_stack_for_isolated_run();
            stack.push(digests_pointer);

            FunctionInitialState { stack, memory }
        }
    }

    impl Function for MerkleRoot {
        /// Rust reference implementation of the snippet. Apart from the root
        /// left on the stack, it also reproduces the snippet's memory
        /// footprint, i.e., the intermediate tree nodes.
        fn rust_shadow(
            &self,
            stack: &mut Vec<BFieldElement>,
            memory: &mut HashMap<BFieldElement, BFieldElement>,
        ) {
            let leafs_pointer = stack.pop().unwrap();
            let leafs = *Vec::decode_from_memory(memory, leafs_pointer).unwrap();
            let mt = MerkleTree::par_new(&leafs).unwrap();

            // mimic snippet: write internal nodes to memory, skipping (dummy) node 0
            let tree_pointer = dynamic_allocator(memory);
            let num_internal_nodes = leafs.len();

            // node `i` lives at offset `i * Digest::LEN` in the allocation;
            // the root is node 1
            for node_index in 1..num_internal_nodes {
                let node = mt.node(node_index).unwrap();
                let node_address = tree_pointer + bfe!(node_index * Digest::LEN);
                encode_to_memory(memory, node_address, &node);
            }

            // the digest's first word ends up on top of the stack
            stack.extend(mt.root().reversed().values());
        }

        /// Random leafs at a random address. Benchmarks use a fixed number of
        /// leafs; otherwise, the number of leafs is a random power of 2
        /// between 2^0 and 2^8 (inclusive).
        fn pseudorandom_initial_state(
            &self,
            seed: [u8; 32],
            bench_case: Option<BenchmarkCase>,
        ) -> FunctionInitialState {
            let mut rng = StdRng::from_seed(seed);
            let num_leafs = match bench_case {
                Some(BenchmarkCase::CommonCase) => 512,
                Some(BenchmarkCase::WorstCase) => 1024,
                None => 1 << rng.random_range(0..=8),
            };
            let leafs = (0..num_leafs).map(|_| rng.random()).collect_vec();
            let digests_pointer = rng.random();

            self.init_state(leafs, digests_pointer)
        }

        /// The two smallest valid trees: a single leaf (height 0), for which
        /// the root is the leaf itself, and two leafs (height 1).
        fn corner_case_initial_states(&self) -> Vec<FunctionInitialState> {
            let height_0 = self.init_state(vec![Digest::default()], bfe!(0));
            let height_1 = self.init_state(vec![Digest::default(), Digest::default()], bfe!(0));

            vec![height_0, height_1]
        }
    }

    #[test]
    fn rust_shadow() {
        ShadowedFunction::new(MerkleRoot).test();
    }

    /// An empty list of leafs does not describe any tree, not even one of
    /// height 0 (which has exactly one leaf), and must be rejected.
    #[test]
    fn computing_root_of_empty_leaf_list_crashes_vm() {
        test_assertion_failure(
            &ShadowedFunction::new(MerkleRoot),
            MerkleRoot.init_state(vec![], bfe!(0)).into(),
            &[MerkleRoot::NUM_LEAFS_NOT_POWER_OF_2_ERROR_ID],
        );
    }

    /// Any number of leafs that is not a power of 2, including 0, must be
    /// rejected, no matter where in memory the leafs live.
    #[proptest(cases = 100)]
    fn computing_root_of_tree_with_number_of_leafs_not_power_of_2_crashes_vm(
        #[strategy(vec(arb(), 0..2048))]
        #[filter(!#leafs.len().is_power_of_two())]
        leafs: Vec<Digest>,
        #[strategy(arb())] address: BFieldElement,
    ) {
        test_assertion_failure(
            &ShadowedFunction::new(MerkleRoot),
            MerkleRoot.init_state(leafs, address).into(),
            &[MerkleRoot::NUM_LEAFS_NOT_POWER_OF_2_ERROR_ID],
        );
    }
}
306
#[cfg(test)]
mod benches {
    use super::*;
    use crate::test_prelude::*;

    /// Benchmark the snippet through its `ShadowedFunction` wrapper.
    #[test]
    fn benchmark() {
        ShadowedFunction::new(MerkleRoot).bench();
    }
}
316}