tasm_lib/hashing/merkle_root_from_xfes.rs

1use std::collections::HashMap;
2
3use triton_vm::prelude::*;
4use twenty_first::math::x_field_element::EXTENSION_DEGREE;
5
6use crate::hashing::merkle_root::MerkleRoot;
7use crate::prelude::*;
8use crate::traits::basic_snippet::Reviewer;
9use crate::traits::basic_snippet::SignOffFingerprint;
10
/// Calculate a Merkle root from a list of extension-field elements.
///
/// Every leaf is an [`XFieldElement`], which is padded with two zeros to the width of a
/// [`Digest`] before it enters the Merkle tree.
///
/// ### Behavior
///
/// ```text
/// BEFORE: _ *leafs
/// AFTER:  _ [root: Digest]
/// ```
///
/// ### Preconditions
///
/// - `*leafs` points to a list of [`XFieldElement`]s
/// - the length of the pointed-to list is a power of 2 (in particular, it is at least 1)
/// - the length of the pointed-to list is a u32
///
/// A list of length 1 is supported: its root is its single, zero-padded element.
///
/// If the length of the list is not a power of 2 (this includes the empty list), the VM
/// crashes with error ID [`Self::NUM_ELEMENTS_NOT_POWER_OF_2_ERROR_ID`].
///
/// ### Postconditions
///
/// None.
#[derive(Debug, Default, Copy, Clone, Eq, PartialEq, Hash)]
pub struct MerkleRootFromXfes;

impl MerkleRootFromXfes {
    /// Error ID of the assertion that the number of leafs is a power of 2.
    pub const NUM_ELEMENTS_NOT_POWER_OF_2_ERROR_ID: i128 = 90;
}
36
// How `code` works:
//
// 1. Assert that the list length is a power of 2. Because `pop_count(0) == 0`, this also
//    rejects the empty list.
// 2. If the list has exactly one element, the root is that element padded with two zeros.
//    The subroutine `{entrypoint}_list_len_is_1` builds it, and the snippet returns early.
// 3. Otherwise, hash each adjacent pair of (zero-padded) leafs into a freshly allocated list
//    of `len / 2` parent digests. The loop `{entrypoint}_build_1st_layer` fills this list
//    from the last parent towards the first. The Merkle root of that list is then computed
//    with `MerkleRoot`.
impl BasicSnippet for MerkleRootFromXfes {
    /// A single argument: a pointer to a list of extension-field elements.
    fn inputs(&self) -> Vec<(DataType, String)> {
        let list_type = DataType::List(Box::new(DataType::Xfe));
        vec![(list_type, "*leafs".to_string())]
    }

    /// A single result: the Merkle root.
    fn outputs(&self) -> Vec<(DataType, String)> {
        vec![(DataType::Digest, "root".to_string())]
    }

    fn entrypoint(&self) -> String {
        "tasmlib_hashing_merkle_root_from_xfes".to_string()
    }

    fn code(&self, library: &mut Library) -> Vec<LabelledInstruction> {
        let dyn_malloc = library.import(Box::new(DynMalloc));
        let merkle_root = library.import(Box::new(MerkleRoot));

        // Labels of the two local subroutines, prefixed with this snippet's entrypoint.
        let entrypoint = self.entrypoint();
        let list_len_is_1 = format!("{entrypoint}_list_len_is_1");
        let build_1st_layer = format!("{entrypoint}_build_1st_layer");

        triton_asm!(
            // BEFORE: _ *leafs
            {entrypoint}:
                read_mem 1
                addi 1
                pick 1
                // _ *xfes len
                // `*xfes` is the list pointer (it points to the length word). The elements
                // occupy the words `*xfes + 1` through `*xfes + 3·len`.

                /* assert the number of elements is some power of 2 */
                // pop_count(len) == 1 holds exactly for powers of 2. It also rejects len == 0.
                dup 0
                pop_count
                push 1
                eq
                assert error_id {Self::NUM_ELEMENTS_NOT_POWER_OF_2_ERROR_ID}

                /* special case: list length is 1 */
                // If len == 1, the subroutine consumes `*xfes len 0`, leaves the padded element
                // on the stack, and puts a 1 on top. The following `skiz return` then returns
                // from this snippet. Otherwise, `skiz return` just consumes the 0 pushed here
                // and execution falls through.
                push 0      hint return_early: bool = stack[0]
                dup 1
                push 1
                eq
                skiz call {list_len_is_1}
                skiz return
                // _ *xfes len

                /* Strategy: Construct the 1st parent layer and store it as a list in memory. */
                push 2
                dup 1
                div_mod
                pop 1
                // _ *xfes len (len / 2)

                dup 0
                call {dyn_malloc}
                // _ *xfes len (len / 2) (len / 2) *parent_nodes

                write_mem 1
                // _ *xfes len (len / 2) *parent_nodes[0]
                // (the list's length, len / 2, is now stored at *parent_nodes)

                pick 1
                // _ *xfes len *parent_nodes[0] (len / 2)

                addi -1
                // _ *xfes len *parent_nodes[0] (len / 2 - 1)

                push {Digest::LEN}
                mul
                // _ *xfes len *parent_nodes[0] parent_offset_last_element

                dup 1
                add
                // _ *xfes len *parent_nodes[0] *parent_nodes[last]

                place 2
                // _ *xfes *parent_nodes[last] len *parent_nodes[0]

                addi {-(Digest::LEN as isize)}
                // _ *xfes *parent_nodes[last] len (*parent_nodes - 4)
                // This is the address one digest before *parent_nodes[0]. The loop below stops
                // once its write pointer reaches this address.

                place 3
                // _ (*parent_nodes - 4) *xfes *parent_nodes[last] len

                push {EXTENSION_DEGREE}
                mul
                // _ (*parent_nodes - 4) *xfes *parent_nodes[last] (len·3)

                pick 2
                add
                // _ (*parent_nodes - 4) *parent_nodes[last] *xfes[last]_last_word

                push 0
                push 0
                push 0
                push 0
                pick 4
                // _ (*parent_nodes - 4) *parent_nodes[last] 0 0 0 0 *xfes[last]_last_word
                // The loop never reads the four zeros. They place the write pointer at
                // stack[5] and the stop address at stack[6], which is what `recurse_or_return`
                // compares.

                call {build_1st_layer}
                // _ (*parent_nodes - 4) (*parent_nodes - 4) 0 0 0 0 *xfes
                // (the loop stops when the write pointer equals the stop address; the xfe
                // pointer has then walked back to the list's length word)

                pop 5
                pop 1
                // _ (*parent_nodes - 4)

                addi {Digest::LEN - 1}
                // _ *parent_digests
                // i.e., *parent_nodes: the list of `len / 2` parent digests built above

                call {merkle_root}
                // _ [merkle_root]

                return

            // BEFORE: _ *xfes 1 0
            // AFTER:  _ [0 0 xfes[0]] 1
            // Reads the only element and pads it with two zeros to form the root. The trailing
            // 1 makes the caller's `skiz return` return early.
            {list_len_is_1}:
                            hint filler = stack[0]
                            hint return_early: bool = stack[1]
                push 0      hint filler = stack[0]
                // _ *xfes 1 0 0

                pick 3
                addi {EXTENSION_DEGREE}
                read_mem {EXTENSION_DEGREE}
                            hint root: Digest = stack[1..6]
                pop 1
                // _ 1 [0 0 xfes[0]]

                pick 5
                return


            // INVARIANT: _ (*parent_nodes - 4) *parent_digests[n] 0 0 0 0 *xfes[2n+1]_last_word
            // Each iteration hashes leafs 2n (left) and 2n+1 (right) into parent node n, writes
            // that node, and moves both pointers one step towards the front of their lists.
            // The loop terminates right after parent node 0 has been written.
            {build_1st_layer}:
                push 0
                push 0
                pick 2
                read_mem {EXTENSION_DEGREE}
                // _ (*parent_nodes - 4) *parent_digests[n] 0 0 0 0 [0 0 right_xfe] *xfes[2n]_last_word

                push 0
                push 0
                pick 2
                read_mem {EXTENSION_DEGREE}
                // _ (*parent_nodes - 4) *parent_digests[n] 0 0 0 0 [0 0 right_xfe] [0 0 left_xfe] *xfes[2n-1]_last_word

                place 10
                // _ (*parent_nodes - 4) *parent_digests[n] 0 0 0 0 *xfes[2n-1]_last_word [0 0 right_xfe] [0 0 left_xfe]

                hash
                // _ (*parent_nodes - 4) *parent_digests[n] 0 0 0 0 *xfes[2n-1]_last_word [parent_digest]

                pick 10
                write_mem {Digest::LEN}
                // _ (*parent_nodes - 4) 0 0 0 0 *xfes[2n-1]_last_word *parent_digests[n+1]

                addi -10
                // _ (*parent_nodes - 4) 0 0 0 0 *xfes[2n-1]_last_word *parent_digests[n-1]
                // (`write_mem` advanced the pointer by one digest; step back by two digests)

                place 5
                // _ (*parent_nodes - 4) *parent_digests[n-1] 0 0 0 0 *xfes[2(n-1)+1]_last_word

                // Returns if stack[5] == stack[6], i.e., if parent node 0 was just written.
                // Recurses otherwise.
                recurse_or_return
        )
    }

    // NOTE(review): the fingerprint presumably has to be re-approved whenever `code` changes.
    fn sign_offs(&self) -> HashMap<Reviewer, SignOffFingerprint> {
        let mut sign_offs = HashMap::new();
        sign_offs.insert(Reviewer("ferdinand"), 0x850f6c4f5a62ccb5.into());

        sign_offs
    }
}
209
#[cfg(test)]
mod tests {
    use proptest::collection::vec;
    use twenty_first::util_types::merkle_tree::MerkleTree;

    use super::*;
    use crate::rust_shadowing_helper_functions::dyn_malloc::dynamic_allocator;
    use crate::rust_shadowing_helper_functions::list::list_new;
    use crate::rust_shadowing_helper_functions::list::list_push;
    use crate::test_helpers::test_assertion_failure;
    use crate::test_prelude::*;

    impl MerkleRootFromXfes {
        /// Initial state: `leafs` is encoded as a list starting at `leaf_pointer`, and that
        /// pointer is the snippet's only argument on top of the stack.
        fn init_state(
            &self,
            leafs: Vec<XFieldElement>,
            leaf_pointer: BFieldElement,
        ) -> FunctionInitialState {
            let mut stack = self.init_stack_for_isolated_run();
            stack.push(leaf_pointer);

            let mut memory = HashMap::new();
            encode_to_memory(&mut memory, leaf_pointer, &leafs);

            FunctionInitialState { stack, memory }
        }
    }

    impl Function for MerkleRootFromXfes {
        fn rust_shadow(
            &self,
            stack: &mut Vec<BFieldElement>,
            memory: &mut HashMap<BFieldElement, BFieldElement>,
        ) {
            let list_pointer = stack.pop().unwrap();
            let xfes = *Vec::<XFieldElement>::decode_from_memory(memory, list_pointer).unwrap();
            let leaf_digests = xfes.into_iter().map(Digest::from).collect_vec();
            let num_leafs = leaf_digests.len();
            let tree = MerkleTree::par_new(&leaf_digests).unwrap();

            // A single leaf is its own root, and the VM touches no memory in that case.
            // Otherwise, mirror the VM: the first parent layer is stored as a list, and all
            // higher layers go into a second allocation, addressed by node index.
            if num_leafs != 1 {
                let first_layer_pointer = dynamic_allocator(memory);
                list_new(first_layer_pointer, memory);
                for offset in 0..(num_leafs >> 1) {
                    let node_index = (1 << (tree.height() - 1)) + offset;
                    let parent = tree.node(node_index).unwrap();
                    list_push(first_layer_pointer, parent.values().to_vec(), memory);
                }

                let rest_of_tree_pointer = dynamic_allocator(memory);
                for layer in 2..=tree.height() {
                    let first_index_in_layer: usize = 1 << (tree.height() - layer);
                    for offset in 0..(num_leafs >> layer) {
                        let node_index = first_index_in_layer + offset;
                        let inner_node = tree.node(node_index).unwrap();
                        let address = rest_of_tree_pointer + bfe!(node_index * Digest::LEN);
                        encode_to_memory(memory, address, &inner_node);
                    }
                }
            }

            stack.extend(tree.root().reversed().values());
        }

        fn pseudorandom_initial_state(
            &self,
            seed: [u8; 32],
            bench_case: Option<BenchmarkCase>,
        ) -> FunctionInitialState {
            let mut rng = StdRng::from_seed(seed);
            let log_2_num_leafs = match bench_case {
                Some(BenchmarkCase::CommonCase) => 9,
                Some(BenchmarkCase::WorstCase) => 10,
                None => rng.random_range(1..=10),
            };
            let num_leafs = 1 << log_2_num_leafs;
            let list_pointer = rng.random();
            let leafs = (0..num_leafs).map(|_| rng.random()).collect_vec();

            self.init_state(leafs, list_pointer)
        }

        fn corner_case_initial_states(&self) -> Vec<FunctionInitialState> {
            let lengths = [1, 2, 4, 8];
            lengths
                .iter()
                .map(|&len| self.init_state(xfe_vec![1; len], bfe!(0)))
                .collect()
        }
    }

    #[test]
    fn rust_shadow() {
        let snippet = ShadowedFunction::new(MerkleRootFromXfes);
        snippet.test();
    }

    #[proptest(cases = 100)]
    fn cannot_handle_input_list_of_length_not_pow2(
        #[strategy(vec(arb(), 0..2048))]
        #[filter(!#leafs.len().is_power_of_two())]
        leafs: Vec<XFieldElement>,
        #[strategy(arb())] address: BFieldElement,
    ) {
        let snippet = ShadowedFunction::new(MerkleRootFromXfes);
        let initial_state = MerkleRootFromXfes.init_state(leafs, address);
        let expected_errors = [MerkleRootFromXfes::NUM_ELEMENTS_NOT_POWER_OF_2_ERROR_ID];
        test_assertion_failure(&snippet, initial_state.into(), &expected_errors);
    }
}
319
#[cfg(test)]
mod benches {
    use super::*;
    use crate::test_prelude::*;

    /// Benchmarks the snippet on its common-case and worst-case inputs.
    #[test]
    fn benchmark() {
        let snippet = ShadowedFunction::new(MerkleRootFromXfes);
        snippet.bench();
    }
}