1use digest::{Digest, FixedOutputReset, Output, typenum::Unsigned};
42
/// Byte fed to the hasher ahead of the two child hashes whenever an inner
/// (non-leaf) node is computed, i.e. `H(INNER_NODE_PREFIX || left || right)`.
///
/// Presumably this domain-separates inner-node hashes from leaf hashes, which
/// callers supply already hashed — confirm against the leaf-hashing convention.
pub const INNER_NODE_PREFIX: u8 = 1;
45
46#[derive(Debug, Clone, Default)]
50pub struct MerkleTree<D: Digest> {
51 nodes: Box<[Output<D>]>,
53 len: usize,
54}
55
56#[derive(Debug, Clone)]
58pub struct UnhashedMerkleTree<D: Digest> {
59 buffer: Vec<Output<D>>,
60 len: usize,
61}
62
63impl<D: Digest + FixedOutputReset> MerkleTree<D>
64where
65 Output<D>: Copy,
66{
67 pub fn new(data: &[Output<D>]) -> Self {
69 Self::new_unhashed(data).finalize()
70 }
71
72 pub fn new_unhashed(data: &[Output<D>]) -> UnhashedMerkleTree<D> {
74 let raw_len = data.len();
75 assert_ne!(raw_len, 0, "Cannot create Merkle tree with zero leaves");
76
77 let len = raw_len.next_power_of_two();
78 let mut nodes = Vec::<Output<D>>::with_capacity(2 * len);
79
80 unsafe {
81 let maybe_uninit = nodes.spare_capacity_mut();
82
83 maybe_uninit
86 .get_unchecked_mut(0)
87 .write(Output::<D>::default());
88
89 let dst = maybe_uninit.get_unchecked_mut(len..).as_mut_ptr().cast();
91 let src = data.as_ptr();
92 std::ptr::copy_nonoverlapping(src, dst, raw_len);
97
98 for e in maybe_uninit.get_unchecked_mut(len + raw_len..) {
100 e.write(Output::<D>::default());
101 }
102 }
103
104 UnhashedMerkleTree { buffer: nodes, len }
105 }
106
107 #[inline]
109 pub fn root(&self) -> &Output<D> {
110 unsafe { self.nodes.get_unchecked(1) }
112 }
113
114 #[inline]
116 pub fn leaves(&self) -> &[Output<D>] {
117 unsafe { self.nodes.get_unchecked(self.len..self.len + self.len) }
118 }
119
120 #[inline]
122 pub fn contains(&self, leaf: &Output<D>) -> bool {
123 self.leaves().contains(leaf)
124 }
125
126 pub fn get_proof_iter(&self, leaf: &Output<D>) -> Option<SiblingIter<'_, D>> {
128 let leaf_index_in_slice = self.leaves().iter().position(|a| a == leaf)?;
129 Some(SiblingIter {
130 nodes: &self.nodes,
131 current: self.len + leaf_index_in_slice,
132 })
133 }
134
135 #[inline]
137 pub fn to_raw_bytes(&self) -> Vec<u8> {
138 self.nodes
139 .iter()
140 .flat_map(|node| node.as_slice())
141 .copied()
142 .collect()
143 }
144
145 #[inline]
153 pub fn from_raw_bytes(bytes: &[u8]) -> Self {
154 assert!(
155 bytes.len().is_multiple_of(D::OutputSize::USIZE),
156 "Invalid raw bytes length"
157 );
158 let len = bytes.len() / D::OutputSize::USIZE;
159 assert!(len.is_multiple_of(2));
160 let mut nodes: Vec<Output<D>> = Vec::with_capacity(len);
161 for chunk in bytes.chunks_exact(D::OutputSize::USIZE) {
162 let node = Output::<D>::from_slice(chunk);
163 nodes.push(*node);
164 }
165 assert_eq!(nodes[0], Output::<D>::default());
166 let len = nodes.len() / 2;
167 Self {
168 nodes: nodes.to_vec().into_boxed_slice(),
169 len,
170 }
171 }
172}
173
174impl<D: Digest + FixedOutputReset> UnhashedMerkleTree<D>
175where
176 Output<D>: Copy,
177{
178 pub fn finalize(self) -> MerkleTree<D> {
180 let mut nodes = self.buffer;
181 let len = self.len;
182 unsafe {
183 let maybe_uninit = nodes.spare_capacity_mut();
184
185 let mut hasher = D::new();
187 for i in (1..len).rev() {
188 let left = maybe_uninit.get_unchecked(2 * i).assume_init_ref();
190 let right = maybe_uninit.get_unchecked(2 * i + 1).assume_init_ref();
191
192 Digest::update(&mut hasher, [INNER_NODE_PREFIX]);
193 Digest::update(&mut hasher, left);
194 Digest::update(&mut hasher, right);
195 let parent_hash = hasher.finalize_reset();
196
197 maybe_uninit.get_unchecked_mut(i).write(parent_hash);
198 }
199
200 nodes.set_len(2 * len);
202 }
203 MerkleTree {
204 nodes: nodes.into_boxed_slice(),
205 len,
206 }
207 }
208}
209
210#[derive(Debug, Clone)]
212pub struct SiblingIter<'a, D: Digest> {
213 nodes: &'a [Output<D>],
214 current: usize,
215}
216
/// The side of its pair that the node being proven sits on at one tree level.
///
/// This describes the running (current) node, not the yielded sibling:
/// - `Left`: the parent is `H(prefix || current || sibling)`.
/// - `Right`: the parent is `H(prefix || sibling || current)`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum NodePosition {
    /// The current node is a left child (even heap index).
    Left = 0,
    /// The current node is a right child (odd heap index).
    Right = 1,
}
225
226impl<'a, D: Digest> Iterator for SiblingIter<'a, D> {
227 type Item = (NodePosition, &'a Output<D>);
229
230 fn next(&mut self) -> Option<Self::Item> {
231 if self.current <= 1 {
232 return None;
233 }
234 let side = if (self.current & 1) == 0 {
235 NodePosition::Left
236 } else {
237 NodePosition::Right
238 };
239 let sibling_index = self.current ^ 1;
240 let sibling = unsafe { self.nodes.get_unchecked(sibling_index) };
241 self.current >>= 1;
242 Some((side, sibling))
243 }
244
245 fn size_hint(&self) -> (usize, Option<usize>) {
246 let exact = self.current.ilog2() as usize;
247 (exact, Some(exact))
248 }
249}
250
251impl<D: Digest> ExactSizeIterator for SiblingIter<'_, D> {
252 fn len(&self) -> usize {
253 self.current.ilog2() as usize
254 }
255}
256
#[cfg(test)]
mod tests {
    use super::*;
    use alloy_primitives::{B256, U256};
    use alloy_sol_types::SolValue;
    use sha2::Sha256;
    use sha3::Keccak256;

    #[test]
    fn basic() {
        test_merkle_tree::<Sha256>();
        test_merkle_tree::<Keccak256>();
    }

    #[test]
    fn proof() {
        test_proof::<Sha256>();
        test_proof::<Keccak256>();
    }

    #[test]
    fn padding() {
        test_padding::<Sha256>();
        test_padding::<Keccak256>();
    }

    #[test]
    fn single_leaf() {
        test_single_leaf::<Sha256>();
        test_single_leaf::<Keccak256>();
    }

    #[test]
    fn raw_bytes_round_trip() {
        test_raw_bytes_round_trip::<Sha256>();
        test_raw_bytes_round_trip::<Keccak256>();
    }

    #[test]
    #[should_panic(expected = "Cannot create Merkle tree with zero leaves")]
    fn empty_input_panics() {
        let _ = MerkleTree::<Sha256>::new(&[]);
    }

    /// Hashes an inner node as `H(INNER_NODE_PREFIX || left || right)`.
    fn hash_pair<D: Digest>(left: &Output<D>, right: &Output<D>) -> Output<D> {
        let mut hasher = D::new();
        Digest::update(&mut hasher, [INNER_NODE_PREFIX]);
        Digest::update(&mut hasher, left);
        Digest::update(&mut hasher, right);
        hasher.finalize()
    }

    /// Folds `leaf` with its proof path and returns the recomputed root.
    fn fold_proof<D: Digest + FixedOutputReset>(
        tree: &MerkleTree<D>,
        leaf: &Output<D>,
    ) -> Output<D>
    where
        Output<D>: Copy,
    {
        let iter = tree
            .get_proof_iter(leaf)
            .expect("Leaf should be in the tree");
        iter.fold(*leaf, |current, (side, sibling)| match side {
            NodePosition::Left => hash_pair::<D>(&current, sibling),
            NodePosition::Right => hash_pair::<D>(sibling, &current),
        })
    }

    fn test_merkle_tree<D: Digest + FixedOutputReset>()
    where
        Output<D>: Copy,
    {
        let leaves = vec![
            D::digest(b"leaf1"),
            D::digest(b"leaf2"),
            D::digest(b"leaf3"),
            D::digest(b"leaf4"),
        ];

        let tree = MerkleTree::<D>::new(&leaves);

        let mut hasher = D::new();
        Digest::update(&mut hasher, [INNER_NODE_PREFIX]);
        Digest::update(&mut hasher, leaves[0]);
        Digest::update(&mut hasher, leaves[1]);
        let left_hash = hasher.finalize_reset();

        Digest::update(&mut hasher, [INNER_NODE_PREFIX]);
        Digest::update(&mut hasher, leaves[2]);
        Digest::update(&mut hasher, leaves[3]);
        let right_hash = hasher.finalize_reset();

        Digest::update(&mut hasher, [INNER_NODE_PREFIX]);
        Digest::update(&mut hasher, left_hash);
        Digest::update(&mut hasher, right_hash);
        let expected_root = hasher.finalize();

        assert_eq!(tree.root().as_slice(), expected_root.as_slice());
    }

    fn test_proof<D: Digest + FixedOutputReset>()
    where
        Output<D>: Copy,
    {
        let leaves = vec![
            D::digest(b"apple"),
            D::digest(b"banana"),
            D::digest(b"cherry"),
            D::digest(b"date"),
        ];

        let tree = MerkleTree::<D>::new(&leaves);

        for leaf in &leaves {
            let iter = tree
                .get_proof_iter(leaf)
                .expect("Leaf should be in the tree");
            let mut current_hash = *leaf;

            let mut hasher = D::new();
            for (side, sibling_hash) in iter {
                match side {
                    NodePosition::Left => {
                        Digest::update(&mut hasher, [INNER_NODE_PREFIX]);
                        Digest::update(&mut hasher, current_hash);
                        Digest::update(&mut hasher, sibling_hash);
                    }
                    NodePosition::Right => {
                        Digest::update(&mut hasher, [INNER_NODE_PREFIX]);
                        Digest::update(&mut hasher, sibling_hash);
                        Digest::update(&mut hasher, current_hash);
                    }
                }
                current_hash = hasher.finalize_reset();
            }

            assert_eq!(current_hash.as_slice(), tree.root().as_slice());
        }
    }

    /// A non-power-of-two leaf count is padded with zero hashes.
    fn test_padding<D: Digest + FixedOutputReset>()
    where
        Output<D>: Copy,
    {
        let leaves = vec![
            D::digest(b"leaf1"),
            D::digest(b"leaf2"),
            D::digest(b"leaf3"),
        ];
        let tree = MerkleTree::<D>::new(&leaves);

        let zero = Output::<D>::default();
        let expected_root = hash_pair::<D>(
            &hash_pair::<D>(&leaves[0], &leaves[1]),
            &hash_pair::<D>(&leaves[2], &zero),
        );
        assert_eq!(tree.root().as_slice(), expected_root.as_slice());
        assert_eq!(tree.leaves(), &[leaves[0], leaves[1], leaves[2], zero][..]);

        for leaf in &leaves {
            let proof = tree
                .get_proof_iter(leaf)
                .expect("Leaf should be in the tree");
            assert_eq!(proof.len(), 2);
            assert_eq!(&fold_proof(&tree, leaf), tree.root());
        }
    }

    /// With one leaf there are no inner nodes, so the root is the leaf itself.
    fn test_single_leaf<D: Digest + FixedOutputReset>()
    where
        Output<D>: Copy,
    {
        let leaf = D::digest(b"only");
        let tree = MerkleTree::<D>::new(&[leaf]);

        assert_eq!(tree.root(), &leaf);
        assert_eq!(tree.leaves(), &[leaf][..]);
        let proof = tree
            .get_proof_iter(&leaf)
            .expect("Leaf should be in the tree");
        assert_eq!(proof.len(), 0);
    }

    /// `from_raw_bytes(to_raw_bytes())` reproduces the same tree.
    fn test_raw_bytes_round_trip<D: Digest + FixedOutputReset>()
    where
        Output<D>: Copy,
    {
        let leaves: Vec<_> = (0u8..5).map(|i| D::digest([i])).collect();
        let tree = MerkleTree::<D>::new(&leaves);

        let bytes = tree.to_raw_bytes();
        // 5 leaves round up to 8 leaf slots, i.e. 16 nodes.
        assert_eq!(bytes.len(), 16 * D::OutputSize::USIZE);

        let restored = MerkleTree::<D>::from_raw_bytes(&bytes);
        assert_eq!(restored.root(), tree.root());
        assert_eq!(restored.leaves(), tree.leaves());
        assert_eq!(restored.to_raw_bytes(), bytes);
    }

    #[ignore]
    #[test]
    fn generate_sol_test() {
        let mut leaves = Vec::with_capacity(1024);
        for i in 0..1024 {
            let mut hasher = Keccak256::new();
            let value = U256::from(i).abi_encode_packed();
            hasher.update(&value);
            leaves.push(hasher.finalize());
        }

        for i in 0..=10u32 {
            let tree = MerkleTree::<Keccak256>::new(&leaves[..2usize.pow(i)]);
            let root = B256::from_slice(tree.root());
            println!("bytes32({root}),");
        }
    }
}