1#![feature(maybe_uninit_fill)]
2#![feature(likely_unlikely)]
3use bytemuck::Pod;
6use digest::{Digest, FixedOutputReset, Output};
7use std::hint::unlikely;
8
/// Prefix byte fed to the hasher before the left and right child hashes
/// whenever an inner node is computed (see `UnhashedMerkleTree::finalize`).
///
/// NOTE(review): leaves are supplied pre-hashed and are not prefixed by this
/// module; presumably the prefix exists to separate inner-node hashes from
/// leaf hashes — confirm against the leaf-hashing convention used by callers.
pub const INNER_NODE_PREFIX: u8 = 0x01;
11
/// A binary Merkle tree stored as a flat array in implicit binary-heap order.
///
/// Layout of `nodes` (length `2 * len` for trees built by this module):
/// - index 0: unused placeholder, equal to `Output::<D>::default()`;
/// - index 1: the root;
/// - indices `1..len`: inner nodes, where node `i` has children `2i` and `2i + 1`;
/// - indices `len..2 * len`: the leaves, padded with `Output::<D>::default()`
///   up to a power-of-two count.
///
/// NOTE(review): the derived `Default` yields a tree with no nodes at all;
/// `root` must not be called on such a value.
#[derive(Debug, Clone, Default)]
pub struct MerkleTree<D: Digest> {
    // All tree nodes in heap order (see the type-level docs).
    nodes: Box<[Output<D>]>,
    // Number of leaf slots (a power of two for trees built via `new`/`new_unhashed`).
    len: usize,
}
21
22#[derive(Debug, Clone)]
24pub struct UnhashedMerkleTree<D: Digest> {
25 buffer: Vec<Output<D>>,
26 len: usize,
27}
28
29impl<D: Digest + FixedOutputReset> MerkleTree<D>
30where
31 Output<D>: Pod + Copy,
32{
33 pub fn new(data: &[Output<D>]) -> Self {
35 Self::new_unhashed(data).finalize()
36 }
37
38 pub fn new_unhashed(data: &[Output<D>]) -> UnhashedMerkleTree<D> {
40 let raw_len = data.len();
41 assert_ne!(raw_len, 0, "Cannot create Merkle tree with zero leaves");
42
43 let len = raw_len.next_power_of_two();
44 let mut nodes = Vec::<Output<D>>::with_capacity(2 * len);
45
46 unsafe {
47 let maybe_uninit = nodes.spare_capacity_mut();
48
49 maybe_uninit
52 .get_unchecked_mut(0)
53 .write(Output::<D>::default());
54
55 let dst = maybe_uninit.get_unchecked_mut(len..).as_mut_ptr().cast();
57 let src = data.as_ptr();
58 std::ptr::copy_nonoverlapping(src, dst, raw_len);
63
64 maybe_uninit
66 .get_unchecked_mut(len + raw_len..)
67 .write_filled(Output::<D>::default());
68 }
69
70 UnhashedMerkleTree { buffer: nodes, len }
71 }
72
73 #[inline]
75 pub fn root(&self) -> &Output<D> {
76 unsafe { self.nodes.get_unchecked(1) }
78 }
79
80 #[inline]
82 pub fn leaves(&self) -> &[Output<D>] {
83 unsafe { self.nodes.get_unchecked(self.len..self.len + self.len) }
84 }
85
86 #[inline]
88 pub fn contains(&self, leaf: &Output<D>) -> bool {
89 self.leaves().contains(leaf)
90 }
91
92 pub fn get_proof_iter(&self, leaf: &Output<D>) -> Option<SiblingIter<'_, D>> {
94 let leaf_index_in_slice = self.leaves().iter().position(|a| a == leaf)?;
95 Some(SiblingIter {
96 nodes: &self.nodes,
97 current: self.len + leaf_index_in_slice,
98 })
99 }
100
101 #[inline]
103 pub fn as_raw_bytes(&self) -> &[u8] {
104 bytemuck::cast_slice(&self.nodes)
105 }
106
107 #[inline]
115 pub fn from_raw_bytes(bytes: &[u8]) -> Self {
116 let nodes: &[Output<D>] = bytemuck::cast_slice(bytes);
117 assert!(nodes.len().is_multiple_of(2));
118 assert_eq!(nodes[0], Output::<D>::default());
119 let len = nodes.len() / 2;
120 Self {
121 nodes: nodes.to_vec().into_boxed_slice(),
122 len,
123 }
124 }
125}
126
127impl<D: Digest + FixedOutputReset> UnhashedMerkleTree<D>
128where
129 Output<D>: Pod + Copy,
130{
131 pub fn finalize(self) -> MerkleTree<D> {
133 let mut nodes = self.buffer;
134 let len = self.len;
135 unsafe {
136 let maybe_uninit = nodes.spare_capacity_mut();
137
138 let mut hasher = D::new();
140 for i in (1..len).rev() {
141 let left = maybe_uninit.get_unchecked(2 * i).assume_init_ref();
143 let right = maybe_uninit.get_unchecked(2 * i + 1).assume_init_ref();
144
145 Digest::update(&mut hasher, [INNER_NODE_PREFIX]);
146 Digest::update(&mut hasher, left);
147 Digest::update(&mut hasher, right);
148 let parent_hash = hasher.finalize_reset();
149
150 maybe_uninit.get_unchecked_mut(i).write(parent_hash);
151 }
152
153 nodes.set_len(2 * len);
155 }
156 MerkleTree {
157 nodes: nodes.into_boxed_slice(),
158 len,
159 }
160 }
161}
162
/// Iterator over a leaf's proof path, from the leaf level up to (but not
/// including) the root. Created by `MerkleTree::get_proof_iter`.
///
/// Each item is the position of the node currently on the path together with
/// its sibling's hash.
#[derive(Debug, Clone)]
pub struct SiblingIter<'a, D: Digest> {
    // All nodes of the tree, in the same heap layout as `MerkleTree::nodes`.
    nodes: &'a [Output<D>],
    // Heap index of the node whose sibling is yielded next; iteration stops
    // once this reaches 1 (the root).
    current: usize,
}
169
/// Which side of its parent the *current* node on a proof path sits on.
///
/// The sibling is on the opposite side. For `Left` the parent hash is
/// `H(INNER_NODE_PREFIX || current || sibling)`; for `Right` it is
/// `H(INNER_NODE_PREFIX || sibling || current)`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum NodePosition {
    /// The current node is a left child (even heap index).
    Left,
    /// The current node is a right child (odd heap index).
    Right,
}
178
179impl<'a, D: Digest> Iterator for SiblingIter<'a, D> {
180 type Item = (NodePosition, &'a Output<D>);
182
183 fn next(&mut self) -> Option<Self::Item> {
184 if unlikely(self.current <= 1) {
185 return None;
186 }
187 let side = if (self.current & 1) == 0 {
188 NodePosition::Left
189 } else {
190 NodePosition::Right
191 };
192 let sibling_index = self.current ^ 1;
193 let sibling = unsafe { self.nodes.get_unchecked(sibling_index) };
194 self.current >>= 1;
195 Some((side, sibling))
196 }
197
198 fn size_hint(&self) -> (usize, Option<usize>) {
199 let exact = self.current.ilog2() as usize;
200 (exact, Some(exact))
201 }
202}
203
204impl<D: Digest> ExactSizeIterator for SiblingIter<'_, D> {
205 fn len(&self) -> usize {
206 self.current.ilog2() as usize
207 }
208}
209
#[cfg(test)]
mod tests {
    use super::*;
    use alloy_primitives::{B256, U256};
    use alloy_sol_types::SolValue;
    use sha2::Sha256;
    use sha3::Keccak256;

    #[test]
    fn basic() {
        test_merkle_tree::<Sha256>();
        test_merkle_tree::<Keccak256>();
    }

    #[test]
    fn proof() {
        test_proof::<Sha256>();
        test_proof::<Keccak256>();
    }

    #[test]
    fn single_leaf() {
        test_single_leaf::<Sha256>();
        test_single_leaf::<Keccak256>();
    }

    #[test]
    fn padding() {
        test_padding::<Sha256>();
        test_padding::<Keccak256>();
    }

    #[test]
    fn raw_bytes_round_trip() {
        test_raw_bytes_round_trip::<Sha256>();
        test_raw_bytes_round_trip::<Keccak256>();
    }

    /// Hashes an inner node the same way `UnhashedMerkleTree::finalize` does:
    /// `H(INNER_NODE_PREFIX || left || right)`.
    fn hash_inner<D: Digest>(left: &Output<D>, right: &Output<D>) -> Output<D> {
        let mut hasher = D::new();
        Digest::update(&mut hasher, [INNER_NODE_PREFIX]);
        Digest::update(&mut hasher, left);
        Digest::update(&mut hasher, right);
        hasher.finalize()
    }

    /// Recomputes the root from `leaf` and its proof path, ordering each pair
    /// according to the reported `NodePosition`.
    fn root_from_proof<D: Digest + FixedOutputReset>(
        tree: &MerkleTree<D>,
        leaf: &Output<D>,
    ) -> Output<D>
    where
        Output<D>: Pod + Copy,
    {
        let iter = tree
            .get_proof_iter(leaf)
            .expect("Leaf should be in the tree");
        let mut current_hash = *leaf;
        for (side, sibling_hash) in iter {
            current_hash = match side {
                NodePosition::Left => hash_inner::<D>(&current_hash, sibling_hash),
                NodePosition::Right => hash_inner::<D>(sibling_hash, &current_hash),
            };
        }
        current_hash
    }

    fn test_merkle_tree<D: Digest + FixedOutputReset>()
    where
        Output<D>: Pod + Copy,
    {
        let leaves = vec![
            D::digest(b"leaf1"),
            D::digest(b"leaf2"),
            D::digest(b"leaf3"),
            D::digest(b"leaf4"),
        ];

        let tree = MerkleTree::<D>::new(&leaves);

        let mut hasher = D::new();
        Digest::update(&mut hasher, [INNER_NODE_PREFIX]);
        Digest::update(&mut hasher, leaves[0]);
        Digest::update(&mut hasher, leaves[1]);
        let left_hash = hasher.finalize_reset();

        Digest::update(&mut hasher, [INNER_NODE_PREFIX]);
        Digest::update(&mut hasher, leaves[2]);
        Digest::update(&mut hasher, leaves[3]);
        let right_hash = hasher.finalize_reset();

        Digest::update(&mut hasher, [INNER_NODE_PREFIX]);
        Digest::update(&mut hasher, left_hash);
        Digest::update(&mut hasher, right_hash);
        let expected_root = hasher.finalize();

        assert_eq!(tree.root().as_slice(), expected_root.as_slice());
    }

    fn test_proof<D: Digest + FixedOutputReset>()
    where
        Output<D>: Pod + Copy,
    {
        let leaves = vec![
            D::digest(b"apple"),
            D::digest(b"banana"),
            D::digest(b"cherry"),
            D::digest(b"date"),
        ];

        let tree = MerkleTree::<D>::new(&leaves);

        for leaf in &leaves {
            let current_hash = root_from_proof(&tree, leaf);
            assert_eq!(current_hash.as_slice(), tree.root().as_slice());
        }
    }

    fn test_single_leaf<D: Digest + FixedOutputReset>()
    where
        Output<D>: Pod + Copy,
    {
        let leaf = D::digest(b"only");
        let tree = MerkleTree::<D>::new(&[leaf]);

        // One leaf means no inner nodes: the root is the leaf itself and the
        // proof path is empty.
        assert_eq!(tree.root(), &leaf);
        assert_eq!(tree.leaves(), &[leaf][..]);
        let proof = tree
            .get_proof_iter(&leaf)
            .expect("Leaf should be in the tree");
        assert_eq!(proof.len(), 0);
        assert_eq!(proof.count(), 0);
    }

    fn test_padding<D: Digest + FixedOutputReset>()
    where
        Output<D>: Pod + Copy,
    {
        let leaves = vec![D::digest(b"x"), D::digest(b"y"), D::digest(b"z")];
        let tree = MerkleTree::<D>::new(&leaves);

        // Three leaves are padded to four with the default digest.
        let pad = Output::<D>::default();
        assert_eq!(tree.leaves(), &[leaves[0], leaves[1], leaves[2], pad][..]);

        let expected_root = hash_inner::<D>(
            &hash_inner::<D>(&leaves[0], &leaves[1]),
            &hash_inner::<D>(&leaves[2], &pad),
        );
        assert_eq!(tree.root(), &expected_root);

        for leaf in &leaves {
            assert!(tree.contains(leaf));
            assert_eq!(&root_from_proof(&tree, leaf), tree.root());
        }
    }

    fn test_raw_bytes_round_trip<D: Digest + FixedOutputReset>()
    where
        Output<D>: Pod + Copy,
    {
        let leaves: Vec<_> = (0u8..5).map(|i| D::digest([i])).collect();
        let tree = MerkleTree::<D>::new(&leaves);

        let restored = MerkleTree::<D>::from_raw_bytes(tree.as_raw_bytes());
        assert_eq!(restored.root(), tree.root());
        assert_eq!(restored.leaves(), tree.leaves());
        assert_eq!(restored.as_raw_bytes(), tree.as_raw_bytes());
    }

    #[ignore]
    #[test]
    fn generate_sol_test() {
        let mut leaves = Vec::with_capacity(1024);
        for i in 0..1024 {
            let mut hasher = Keccak256::new();
            let value = U256::from(i).abi_encode_packed();
            hasher.update(&value);
            leaves.push(hasher.finalize());
        }

        for i in 0..=10u32 {
            let tree = MerkleTree::<Keccak256>::new(&leaves[..2usize.pow(i)]);
            let root = B256::new(tree.root().0);
            println!("bytes32({root}),");
        }
    }
}