tasm_lib/hashing/merkle_root.rs
1use std::collections::HashMap;
2
3use triton_vm::prelude::*;
4
5use crate::prelude::*;
6use crate::traits::basic_snippet::Reviewer;
7use crate::traits::basic_snippet::SignOffFingerprint;
8
/// Compute the Merkle root of a slice of `Digest`s. Corresponds to
/// `MerkleTree::`[`sequential_new`][new]`(leafs).`[`root`][root]`()`.
///
/// ### Behavior
///
/// ```text
/// BEFORE: _ *leafs
/// AFTER: _ [root: Digest]
/// ```
///
/// ### Preconditions
///
/// - `*leafs` points to a list of Digests
/// - the length of the pointed-to list is a u32
///
/// ### Postconditions
///
/// None.
///
/// ### Crashes
///
/// The VM crashes with error ID
/// [`MerkleRoot::NUM_LEAFS_NOT_POWER_OF_2_ERROR_ID`] if the length of the
/// pointed-to list is not a power of 2. In particular, this rejects the empty
/// list, since 0 is not a power of 2.
///
/// ### Side Effects
///
/// Memory is allocated via `DynMalloc`, and the internal nodes of the tree are
/// written to that allocation: node `i`, for `1 <= i < leafs.len()`, is stored
/// at `allocation + i * Digest::LEN`, with the root at index 1. Slot 0 is never
/// written; for a single leaf, nothing is written at all. The leafs themselves
/// are neither copied nor modified.
///
/// [new]: twenty_first::prelude::MerkleTree::sequential_new
/// [root]: twenty_first::prelude::MerkleTree::root
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub struct MerkleRoot;

impl MerkleRoot {
    /// Error ID of the assertion that the number of leafs is a power of 2.
    /// Also triggered by an empty list of leafs.
    pub const NUM_LEAFS_NOT_POWER_OF_2_ERROR_ID: i128 = 431;
}
38
impl BasicSnippet for MerkleRoot {
    /// A single argument: a pointer to the list of leafs.
    fn inputs(&self) -> Vec<(DataType, String)> {
        vec![(
            DataType::List(Box::new(DataType::Digest)),
            "*leafs".to_string(),
        )]
    }

    /// A single result: the Merkle root.
    fn outputs(&self) -> Vec<(DataType, String)> {
        vec![(DataType::Digest, "root".to_string())]
    }

    fn entrypoint(&self) -> String {
        "tasmlib_hashing_merkle_root".to_string()
    }

    /// Builds the Merkle tree bottom-up, one layer per iteration of
    /// `next_layer_loop`, and reads the root from memory at the end.
    ///
    /// Memory layout: all internal nodes are written into a fresh `DynMalloc`
    /// allocation. Node `i` (for `1 <= i < leafs_len`, root at index 1) lives at
    /// `allocation + i * Digest::LEN`. The leafs are only read, never copied.
    ///
    /// Every layer is traversed back to front: the pointers on the stack always
    /// reference the *last* element of a layer and are decremented.
    fn code(&self, library: &mut Library) -> Vec<LabelledInstruction> {
        let dyn_malloc = library.import(Box::new(DynMalloc));

        let entrypoint = self.entrypoint();
        let calculate_parent_digests = format!("{entrypoint}_calculate_parent_digests");
        let next_layer_loop = format!("{entrypoint}_next_layer_loop");

        // Note: `//` and `/* */` comments inside `triton_asm!` are removed by the
        // Rust lexer and have no influence on the generated program.
        triton_asm!(
            {entrypoint}:
                // _ *leafs

                read_mem 1
                addi 1
                // _ leafs_len *leafs

                /* assert the number of leafs is some power of 2; this also
                 * rejects the empty list because pop_count(0) is 0
                 * (precondition: leafs_len is a u32, as required by pop_count)
                 */
                dup 1
                pop_count
                push 1
                eq
                assert error_id {Self::NUM_LEAFS_NOT_POWER_OF_2_ERROR_ID}

                call {dyn_malloc}
                // _ leafs_len *leafs *parent_level

                /* adjust `*parent_level` to point to last element, first word;
                 * the last parent is node number (leafs_len - 1)
                 */
                dup 2
                addi -1
                push {Digest::LEN}
                mul
                add
                // _ leafs_len *leafs (*parent_level + (leafs_len - 1) * Digest::LEN)
                // _ leafs_len *leafs *parent_level'

                /* adjust `*leafs` to point to last element, last word; the list's
                 * length word precedes the elements, which is why no `- 1` is needed
                 */
                pick 1
                dup 2
                push {Digest::LEN}
                mul
                add
                // _ leafs_len *parent_level' (*leafs + leafs_len * Digest::LEN)
                // _ leafs_len *parent_level' *leafs'

                call {next_layer_loop}
                // _ 1 *address (*root + Digest::LEN - 1)

                place 2
                pop 2
                // _ (*root + Digest::LEN - 1)

                read_mem {Digest::LEN}
                // _ [root: Digest] (*root - 1)

                pop 1
                // _ [root: Digest]

                return

            // Computes one layer of parent nodes per iteration and stops once the
            // current layer consists of a single node, the root.
            // INVARIANT: _ current_len *next_level[last]_first_word *current_level[last]_last_word
            {next_layer_loop}:
                // _ current_len *next_level *current_level

                /* end loop if `current_len == 1` */
                dup 2
                push 1
                eq
                skiz
                return
                // _ current_len *next_level *current_level

                /* update `current_len`; multiplying with the field inverse of 2 is
                 * an exact halving because `current_len` is a power of 2 greater
                 * than 1 at this point, hence even
                 */
                pick 2
                push {bfe!(2).inverse()}
                hint one_half = stack[0]
                mul
                place 2
                // _ (current_len/2) *next_level *current_level

                /* set up termination condition for parent calculation loop:
                 * `*next_level - current_len / 2 * Digest::LEN`
                 */
                dup 1
                dup 3
                push {-(Digest::LEN as isize)}
                mul
                add
                // _ (current_len/2) *next_level *current_level *next_level_stop
                // _ (current_len/2) *next_level *current_elem *next_elem_stop

                /* the four 0s are padding: they put `*next_elem` and
                 * `*next_elem_stop` at stack positions 5 and 6, the two words
                 * compared by `recurse_or_return`
                 */
                dup 2
                push 0
                push 0
                push 0
                push 0
                pick 6
                // _ (current_len/2) *next_level *next_elem_stop *next_level 0 0 0 0 *current_elem

                call {calculate_parent_digests}
                pop 5
                pop 1
                // _ (current_len/2) *next_level *next_elem_stop

                /* Update `*current_level` based on `*next_level`: the layer just
                 * written becomes the current layer (pointer moved to its last
                 * word), and `*next_elem_stop` is exactly the first word of the
                 * last element of the layer above it
                 */
                pick 1
                // _ (current_len/2) *next_elem_stop *next_level

                addi {Digest::LEN - 1}
                // _ (current_len/2) *next_level' *current_level'

                recurse

            // Populate the `*next` digest list, back to front: every iteration
            // hashes the last two not-yet-consumed children into one parent.
            // INVARIANT: _ *next_elem_stop *next_elem 0 0 0 0 *curr_elem
            {calculate_parent_digests}:
                read_mem {Digest::LEN}
                read_mem {Digest::LEN}
                // _ *next_elem_stop *next_elem 0 0 0 0 [right] [left] (*curr_elem[n] - 10)
                // _ *next_elem_stop *next_elem 0 0 0 0 [right] [left] *curr_elem[n - 2]
                // _ *next_elem_stop *next_elem 0 0 0 0 [right] [left] *curr_elem'

                place 10
                // _ *next_elem_stop *next_elem 0 0 0 0 *curr_elem' [right] [left]

                hash
                // _ *next_elem_stop *next_elem 0 0 0 0 *curr_elem' [parent_digest]

                pick 10
                // _ *next_elem_stop 0 0 0 0 *curr_elem' [parent_digest] *next_elem

                write_mem {Digest::LEN}
                // _ *next_elem_stop 0 0 0 0 *curr_elem' (*next_elem + 5)

                /* step back over the digest just written, then one more element */
                addi -10
                // _ *next_elem_stop 0 0 0 0 *curr_elem' (*next_elem - 5)
                // _ *next_elem_stop 0 0 0 0 *curr_elem' *next_elem[n-1]
                // _ *next_elem_stop 0 0 0 0 *curr_elem' *next_elem'

                place 5
                // _ *next_elem_stop *next_elem' 0 0 0 0 *curr_elem'

                /* returns once `*next_elem'` equals `*next_elem_stop`, i.e., after
                 * `current_len / 2` parent digests have been written
                 */
                recurse_or_return
        )
    }

    /// Reviewer sign-offs for this snippet, mapping each reviewer to a
    /// `SignOffFingerprint`.
    fn sign_offs(&self) -> HashMap<Reviewer, SignOffFingerprint> {
        let mut sign_offs = HashMap::new();
        sign_offs.insert(Reviewer("ferdinand"), 0x1c30ac983fdca9da.into());
        sign_offs
    }
}
205
#[cfg(test)]
mod tests {
    use proptest::collection::vec;
    use twenty_first::util_types::merkle_tree::MerkleTree;

    use super::*;
    use crate::rust_shadowing_helper_functions::dyn_malloc::dynamic_allocator;
    use crate::test_prelude::*;

    impl MerkleRoot {
        /// Initial state for running the snippet on `leafs`: the leafs are
        /// encoded as a list starting at `digests_pointer`, and that pointer is
        /// pushed onto the stack.
        fn init_state(
            &self,
            leafs: Vec<Digest>,
            digests_pointer: BFieldElement,
        ) -> FunctionInitialState {
            let mut memory = HashMap::new();
            encode_to_memory(&mut memory, digests_pointer, &leafs);
            let mut stack = self.init_stack_for_isolated_run();
            stack.push(digests_pointer);

            FunctionInitialState { stack, memory }
        }
    }

    impl Function for MerkleRoot {
        fn rust_shadow(
            &self,
            stack: &mut Vec<BFieldElement>,
            memory: &mut HashMap<BFieldElement, BFieldElement>,
        ) {
            let leafs_pointer = stack.pop().unwrap();
            let leafs = *Vec::decode_from_memory(memory, leafs_pointer).unwrap();

            // Use the constructor that the snippet's documentation refers to.
            let mt = MerkleTree::sequential_new(&leafs).unwrap();

            // Mimic the snippet: write the internal nodes to freshly allocated
            // memory. A tree with `n` leafs has internal nodes with indices
            // `1..n`; the (dummy) node 0 is skipped.
            let tree_pointer = dynamic_allocator(memory);
            for node_index in 1..leafs.len() {
                let node = mt.node(node_index).unwrap();
                let node_address = tree_pointer + bfe!(node_index * Digest::LEN);
                encode_to_memory(memory, node_address, &node);
            }

            stack.extend(mt.root().reversed().values());
        }

        fn pseudorandom_initial_state(
            &self,
            seed: [u8; 32],
            bench_case: Option<BenchmarkCase>,
        ) -> FunctionInitialState {
            let mut rng = StdRng::from_seed(seed);
            let num_leafs = match bench_case {
                Some(BenchmarkCase::CommonCase) => 512,
                Some(BenchmarkCase::WorstCase) => 1024,
                None => 1 << rng.random_range(0..=8),
            };
            let leafs = (0..num_leafs).map(|_| rng.random()).collect_vec();
            let digests_pointer = rng.random();

            self.init_state(leafs, digests_pointer)
        }

        /// Trees of height 0 (a single leaf, which is also the root) and of
        /// height 1 (two leafs).
        fn corner_case_initial_states(&self) -> Vec<FunctionInitialState> {
            let height_0 = self.init_state(vec![Digest::default()], bfe!(0));
            let height_1 = self.init_state(vec![Digest::default(), Digest::default()], bfe!(0));

            vec![height_0, height_1]
        }
    }

    #[test]
    fn rust_shadow() {
        ShadowedFunction::new(MerkleRoot).test();
    }

    // An empty list of leafs has no Merkle root. Its length, 0, is not a power
    // of 2, so the snippet's power-of-2 assertion must fail.
    #[test]
    fn computing_root_of_empty_list_of_leafs_crashes_vm() {
        test_assertion_failure(
            &ShadowedFunction::new(MerkleRoot),
            MerkleRoot.init_state(vec![], bfe!(0)).into(),
            &[MerkleRoot::NUM_LEAFS_NOT_POWER_OF_2_ERROR_ID],
        );
    }

    // Any number of leafs that is not a power of 2 must be rejected.
    #[proptest(cases = 100)]
    fn computing_root_of_non_power_of_2_many_leafs_crashes_vm(
        #[strategy(vec(arb(), 0..2048))]
        #[filter(!#leafs.len().is_power_of_two())]
        leafs: Vec<Digest>,
        #[strategy(arb())] address: BFieldElement,
    ) {
        test_assertion_failure(
            &ShadowedFunction::new(MerkleRoot),
            MerkleRoot.init_state(leafs, address).into(),
            &[MerkleRoot::NUM_LEAFS_NOT_POWER_OF_2_ERROR_ID],
        );
    }
}
306
#[cfg(test)]
mod benches {
    use super::*;
    use crate::test_prelude::*;

    /// Runs the tasm-lib benchmark harness on the `MerkleRoot` snippet.
    #[test]
    fn benchmark() {
        let shadowed_snippet = ShadowedFunction::new(MerkleRoot);
        shadowed_snippet.bench();
    }
}
316}