// tasm_lib/hashing/merkle_root_from_xfes.rs

1use std::collections::HashMap;
2
3use triton_vm::prelude::*;
4use twenty_first::math::x_field_element::EXTENSION_DEGREE;
5
6use crate::hashing::merkle_root::MerkleRoot;
7use crate::prelude::*;
8use crate::traits::basic_snippet::Reviewer;
9use crate::traits::basic_snippet::SignOffFingerprint;
10
/// Calculate a Merkle root from a list of extension-field elements.
///
/// Each leaf is an [`XFieldElement`] that is zero-padded to a [`Digest`] before
/// hashing. For a list with exactly one element, the root is that single
/// zero-padded element; no hashing takes place.
///
/// ### Behavior
///
/// ```text
/// BEFORE: _ *leafs
/// AFTER:  _ [root: Digest]
/// ```
///
/// ### Preconditions
///
/// - `*leafs` points to a list of [`XFieldElement`]s
/// - the length of the pointed-to list is greater than 0
/// - the length of the pointed-to list is a power of 2
/// - the length of the pointed-to list is a u32
///
/// ### Postconditions
///
/// None.
#[derive(Debug, Default, Copy, Clone, Eq, PartialEq, Hash)]
pub struct MerkleRootFromXfes;
32
impl MerkleRootFromXfes {
    /// Error ID attached to the assertion that the number of leafs is a power
    /// of 2. The snippet's code checks this by requiring the length's
    /// `pop_count` to equal 1.
    pub const NUM_ELEMENTS_NOT_POWER_OF_2_ERROR_ID: i128 = 90;
}
36
impl BasicSnippet for MerkleRootFromXfes {
    /// The single argument: a pointer to a list of extension-field elements.
    fn parameters(&self) -> Vec<(DataType, String)> {
        let list_type = DataType::List(Box::new(DataType::Xfe));
        vec![(list_type, "*leafs".to_string())]
    }

    /// The single return value: the Merkle root, as a digest.
    fn return_values(&self) -> Vec<(DataType, String)> {
        vec![(DataType::Digest, "root".to_string())]
    }

    /// The label at which this snippet's code starts. It is also used as the
    /// prefix for the labels of the snippet's internal subroutines.
    fn entrypoint(&self) -> String {
        "tasmlib_hashing_merkle_root_from_xfes".to_string()
    }

    /// Emit the snippet's assembly.
    ///
    /// Overview of the emitted code:
    /// 1. Read the list length and assert that it is a power of 2.
    /// 2. If the length is 1, return the single element, zero-padded to a
    ///    digest, without hashing.
    /// 3. Otherwise, allocate a new list of `len / 2` digests, fill it from
    ///    back to front by hashing pairs of zero-padded leafs, and hand that
    ///    list to the imported `MerkleRoot` snippet.
    fn code(&self, library: &mut Library) -> Vec<LabelledInstruction> {
        let dyn_malloc = library.import(Box::new(DynMalloc));
        let merkle_root = library.import(Box::new(MerkleRoot));

        let entrypoint = self.entrypoint();
        let list_len_is_1 = format!("{entrypoint}_list_len_is_1");
        let build_1st_layer = format!("{entrypoint}_build_1st_layer");

        triton_asm!(
            // BEFORE: _ *leafs
            {entrypoint}:
                /* read the list's length, then restore the pointer */
                read_mem 1
                addi 1
                pick 1
                // _ *xfes len

                /* assert the number of elements is some power of 2 */
                dup 0
                pop_count
                push 1
                eq
                assert error_id {Self::NUM_ELEMENTS_NOT_POWER_OF_2_ERROR_ID}

                /* special case: list length is 1 */
                /* The `return_early` flag starts out as 0. The subroutine replaces */
                /* it with 1, which makes the second `skiz` execute the `return`.   */
                push 0      hint return_early: bool = stack[0]
                dup 1
                push 1
                eq
                skiz call {list_len_is_1}
                skiz return
                // _ *xfes len

                /* Strategy: Construct the 1st parent layer and store it as a list in memory. */
                /* compute len / 2; the remainder is discarded */
                push 2
                dup 1
                div_mod
                pop 1

                dup 0
                call {dyn_malloc}
                // _ *xfes len (len / 2) (len / 2) *parent_nodes

                /* store the parent list's length in its first word */
                write_mem 1
                // _ *xfes len (len / 2) *parent_nodes[0]

                pick 1
                // _ *xfes len *parent_nodes[0] (len / 2)

                addi -1
                // _ *xfes len *parent_nodes[0] (len / 2 - 1)

                push {Digest::LEN}
                mul
                // _ *xfes len *parent_nodes[0] parent_offset_last_element

                dup 1
                add
                // _ *xfes len *parent_nodes[0] *parent_nodes[last]

                place 2
                // _ *xfes *parent_nodes[last] len *parent_nodes[0]

                /* one digest before the first parent node: the loop's stop marker */
                addi {-(Digest::LEN as isize)}
                // _ *xfes *parent_nodes[last] len (*parent_nodes - 4)

                place 3
                // _ (*parent_nodes - 4) *xfes *parent_nodes[last] len

                push {EXTENSION_DEGREE}
                mul
                // _ (*parent_nodes - 4) *xfes *parent_nodes[last] (len·3)

                pick 2
                add
                // _ (*parent_nodes - 4) *parent_nodes[last] *xfes[last]_last_word

                push 0
                push 0
                push 0
                push 0
                pick 4
                // _ (*parent_nodes - 4) *parent_nodes[last] 0 0 0 0 *xfes[last]_last_word

                call {build_1st_layer}
                // _ (*parent_nodes - 4) *parent_digests[n] 0 0 0 0 *xfes[2*n]_last_word

                pop 5
                pop 1
                // _ (*parent_nodes - 4)

                addi {Digest::LEN - 1}
                // _ *parent_digests

                call {merkle_root}
                // _ [merkle_root]

                return

            // BEFORE: _ *xfes 1 0
            // AFTER:  _ [0 0 xfes[0]] 1
            {list_len_is_1}:
                            hint filler = stack[0]
                            hint return_early: bool = stack[1]
                push 0      hint filler = stack[0]
                // _ *xfes 1 0 0

                /* read the only element, starting at its last word */
                pick 3
                addi {EXTENSION_DEGREE}
                read_mem {EXTENSION_DEGREE}
                            hint root: Digest = stack[1..6]
                pop 1
                // _ 1 [0 0 xfes[0]]

                /* move the 1 on top so that the caller returns early */
                pick 5
                return


            // INVARIANT: _ (*parent_nodes - 4) *parent_digests[n] 0 0 0 0 *xfes[2*n]_last_word
            {build_1st_layer}:
                /* load the right leaf, zero-padded to digest width */
                push 0
                push 0
                pick 2
                read_mem {EXTENSION_DEGREE}
                // _ (*parent_nodes - 4) *parent_digests[n] 0 0 0 0 [0 0 right_xfe] *xfes[2*n-1]

                /* load the left leaf, zero-padded to digest width */
                push 0
                push 0
                pick 2
                read_mem {EXTENSION_DEGREE}
                // _ (*parent_nodes - 4) *parent_digests[n] 0 0 0 0 [0 0 right_xfe] [0 0 left_xfe] *xfes[2*n-2]

                place 10
                // _ (*parent_nodes - 4) *parent_digests[n] 0 0 0 0 *xfes[2*n-2] [0 0 right_xfe] [0 0 left_xfe]

                hash
                // _ (*parent_nodes - 4) *parent_digests[n] 0 0 0 0 *xfes[2*n-2] [parent_digest]

                pick 10
                write_mem {Digest::LEN}
                // _ (*parent_nodes - 4) 0 0 0 0 *xfes[2*n-2] *parent_digests[n+1]

                /* undo the advance caused by the write, then step to the previous slot */
                addi -10
                // _ (*parent_nodes - 4) 0 0 0 0 *xfes[2*n-2] *parent_digests[n-1]

                place 5
                // _ (*parent_nodes - 4) *parent_digests[n-1] 0 0 0 0 *xfes[2*n-2]

                /* the loop ends once the write pointer has reached the stop marker */
                recurse_or_return
        )
    }

    /// Reviewer sign-offs for this snippet, keyed by reviewer. Each entry maps
    /// to the fingerprint that reviewer signed off on.
    fn sign_offs(&self) -> HashMap<Reviewer, SignOffFingerprint> {
        let mut sign_offs = HashMap::new();
        sign_offs.insert(Reviewer("ferdinand"), 0x57f2e812e29b71b8.into());

        sign_offs
    }
}
209
#[cfg(test)]
mod tests {
    use proptest::collection::vec;
    use twenty_first::util_types::merkle_tree::MerkleTree;

    use super::*;
    use crate::rust_shadowing_helper_functions::dyn_malloc::dynamic_allocator;
    use crate::rust_shadowing_helper_functions::list::list_new;
    use crate::rust_shadowing_helper_functions::list::list_push;
    use crate::test_helpers::test_assertion_failure;
    use crate::test_prelude::*;

    impl MerkleRootFromXfes {
        /// Build the initial state for a run of the snippet: `leafs` is encoded
        /// into memory at `leaf_pointer`, and `leaf_pointer` is put on top of
        /// the stack for an isolated run.
        fn init_state(
            &self,
            leafs: Vec<XFieldElement>,
            leaf_pointer: BFieldElement,
        ) -> FunctionInitialState {
            let stack: Vec<BFieldElement> = self
                .init_stack_for_isolated_run()
                .into_iter()
                .chain(std::iter::once(leaf_pointer))
                .collect();

            let mut memory = HashMap::new();
            encode_to_memory(&mut memory, leaf_pointer, &leafs);

            FunctionInitialState { stack, memory }
        }
    }

    impl Function for MerkleRootFromXfes {
        /// Rust reference implementation of the snippet.
        ///
        /// Pops the list pointer, decodes the leafs from memory, builds the
        /// Merkle tree, and pushes the root. For more than one leaf, all
        /// internal nodes are additionally written to memory.
        fn rust_shadow(
            &self,
            stack: &mut Vec<BFieldElement>,
            memory: &mut HashMap<BFieldElement, BFieldElement>,
        ) -> Result<(), RustShadowError> {
            let Some(leafs_pointer) = stack.pop() else {
                return Err(RustShadowError::StackUnderflow);
            };
            let Ok(xfes) = Vec::<XFieldElement>::decode_from_memory(memory, leafs_pointer) else {
                return Err(RustShadowError::DecodingError);
            };
            let leafs: Vec<Digest> = (*xfes).into_iter().map(Digest::from).collect();
            let Ok(mt) = MerkleTree::par_new(&leafs) else {
                return Err(RustShadowError::Other);
            };

            // A single leaf is its own root; in that case, memory stays untouched.
            if leafs.len() != 1 {
                let tree_height = mt.height();

                // Write entire Merkle tree to memory, because that's what the VM does.
                // The layer directly above the leafs is stored as a list.
                let first_layer_pointer = dynamic_allocator(memory);
                list_new(first_layer_pointer, memory);
                for offset in 0..(leafs.len() >> 1) {
                    let node_index = offset + (1 << (tree_height - 1));
                    let Some(node) = mt.node(node_index) else {
                        return Err(RustShadowError::Other);
                    };
                    list_push(first_layer_pointer, node.values().to_vec(), memory)?;
                }

                // All remaining layers live in a second allocation, addressed by node index.
                let rest_of_tree_pointer = dynamic_allocator(memory);
                for layer in 2..=tree_height {
                    let layer_depth = tree_height - layer;
                    for offset in 0..(leafs.len() >> layer) {
                        let node_index = offset + (1 << layer_depth);
                        let Some(node) = mt.node(node_index) else {
                            return Err(RustShadowError::Other);
                        };
                        let node_address = rest_of_tree_pointer + bfe!(node_index * Digest::LEN);
                        encode_to_memory(memory, node_address, &node);
                    }
                }
            }

            let root = mt.root();
            stack.extend(root.reversed().values());

            Ok(())
        }

        /// Random initial state with a power-of-2 number of leafs. Benchmark
        /// cases use fixed sizes; otherwise, the size is drawn from the seeded
        /// generator.
        fn pseudorandom_initial_state(
            &self,
            seed: [u8; 32],
            bench_case: Option<BenchmarkCase>,
        ) -> FunctionInitialState {
            let mut rng = StdRng::from_seed(seed);
            let num_leafs = match bench_case {
                Some(BenchmarkCase::CommonCase) => 1 << 9,
                Some(BenchmarkCase::WorstCase) => 1 << 10,
                None => 1 << rng.random_range(1..=10),
            };

            // The pointer is sampled before the leafs.
            let list_pointer: BFieldElement = rng.random();
            let leafs: Vec<XFieldElement> = (0..num_leafs).map(|_| rng.random()).collect();

            self.init_state(leafs, list_pointer)
        }

        /// The smallest supported list lengths, each stored at address 0.
        fn corner_case_initial_states(&self) -> Vec<FunctionInitialState> {
            let states = [1, 2, 4, 8].map(|len| self.init_state(xfe_vec![1; len], bfe!(0)));

            Vec::from(states)
        }
    }

    #[macro_rules_attr::apply(test)]
    fn rust_shadow() {
        let shadowed_function = ShadowedFunction::new(MerkleRootFromXfes);
        shadowed_function.test();
    }

    #[macro_rules_attr::apply(proptest(cases = 100))]
    fn cannot_handle_input_list_of_length_not_pow2(
        #[strategy(vec(arb(), 0..2048))]
        #[filter(!#leafs.len().is_power_of_two())]
        leafs: Vec<XFieldElement>,
        #[strategy(arb())] address: BFieldElement,
    ) {
        let snippet = MerkleRootFromXfes;
        let initial_state = snippet.init_state(leafs, address);
        let expected_error_ids = [MerkleRootFromXfes::NUM_ELEMENTS_NOT_POWER_OF_2_ERROR_ID];

        test_assertion_failure(
            &ShadowedFunction::new(snippet),
            initial_state.into(),
            &expected_error_ids,
        );
    }
}
321
#[cfg(test)]
mod benches {
    use super::*;
    use crate::test_prelude::*;

    /// Run the benchmark harness for the snippet.
    #[macro_rules_attr::apply(test)]
    fn benchmark() {
        let shadowed_function = ShadowedFunction::new(MerkleRootFromXfes);
        shadowed_function.bench();
    }
}
331}