Skip to main content

uts_bmt/
lib.rs

1#![feature(maybe_uninit_fill)]
2#![feature(likely_unlikely)]
3//! High performance binary Merkle tree implementation in Rust.
4
5use bytemuck::Pod;
6use digest::{Digest, FixedOutputReset, Output};
7use std::hint::unlikely;
8
/// Domain-separation byte hashed in front of every internal node.
///
/// A parent is computed as `H(INNER_NODE_PREFIX || left || right)`, while leaves are stored
/// exactly as supplied (no prefix). This keeps internal-node hashes distinguishable from
/// leaf hashes.
pub const INNER_NODE_PREFIX: u8 = 1;
11
12/// Flat, Fixed-Size, Read only Merkle Tree
13///
14/// Expects the length of leaves to be equal or near(less) to a power of two.
15#[derive(Debug, Clone, Default)]
16pub struct MerkleTree<D: Digest> {
17    /// Index 0 is not used, leaves start at index `len`.
18    nodes: Box<[Output<D>]>,
19    len: usize,
20}
21
22/// Merkle Tree without hashing the leaves
23#[derive(Debug, Clone)]
24pub struct UnhashedMerkleTree<D: Digest> {
25    buffer: Vec<Output<D>>,
26    len: usize,
27}
28
29impl<D: Digest + FixedOutputReset> MerkleTree<D>
30where
31    Output<D>: Pod + Copy,
32{
33    /// Constructs a new Merkle tree from the given hash leaves.
34    pub fn new(data: &[Output<D>]) -> Self {
35        Self::new_unhashed(data).finalize()
36    }
37
38    /// Constructs a new Merkle tree from the given hash leaves, without hashing internal nodes.
39    pub fn new_unhashed(data: &[Output<D>]) -> UnhashedMerkleTree<D> {
40        let raw_len = data.len();
41        assert_ne!(raw_len, 0, "Cannot create Merkle tree with zero leaves");
42
43        let len = raw_len.next_power_of_two();
44        let mut nodes = Vec::<Output<D>>::with_capacity(2 * len);
45
46        unsafe {
47            let maybe_uninit = nodes.spare_capacity_mut();
48
49            // SAFETY: tree is valid for writes, properly aligned, and at least 1 element long.
50            // index 0, we will never use it
51            maybe_uninit
52                .get_unchecked_mut(0)
53                .write(Output::<D>::default());
54
55            // SAFETY: capacity * sizeof(T) is within the allocated size of `tree`
56            let dst = maybe_uninit.get_unchecked_mut(len..).as_mut_ptr().cast();
57            let src = data.as_ptr();
58            // SAFETY:
59            // - src is valid for reads `len` elements and properly aligned
60            // - dst is valid for writes `len` elements and properly aligned
61            // - the regions do not overlap since we just allocated `tree`
62            std::ptr::copy_nonoverlapping(src, dst, raw_len);
63
64            // SAFETY: capacity + len is within the allocated size of `tree`
65            maybe_uninit
66                .get_unchecked_mut(len + raw_len..)
67                .write_filled(Output::<D>::default());
68        }
69
70        UnhashedMerkleTree { buffer: nodes, len }
71    }
72
73    /// Returns the root hash of the Merkle tree
74    #[inline]
75    pub fn root(&self) -> &Output<D> {
76        // SAFETY: index 1 is always initialized in new()
77        unsafe { self.nodes.get_unchecked(1) }
78    }
79
80    /// Returns the leaves of the Merkle tree
81    #[inline]
82    pub fn leaves(&self) -> &[Output<D>] {
83        unsafe { self.nodes.get_unchecked(self.len..self.len + self.len) }
84    }
85
86    /// Checks if the given leaf is contained in the Merkle tree
87    #[inline]
88    pub fn contains(&self, leaf: &Output<D>) -> bool {
89        self.leaves().contains(leaf)
90    }
91
92    /// Get proof for a given leaf
93    pub fn get_proof_iter(&self, leaf: &Output<D>) -> Option<SiblingIter<'_, D>> {
94        let leaf_index_in_slice = self.leaves().iter().position(|a| a == leaf)?;
95        Some(SiblingIter {
96            nodes: &self.nodes,
97            current: self.len + leaf_index_in_slice,
98        })
99    }
100
101    /// Returns the raw bytes of the Merkle tree nodes
102    #[inline]
103    pub fn as_raw_bytes(&self) -> &[u8] {
104        bytemuck::cast_slice(&self.nodes)
105    }
106
107    /// From raw bytes, reconstruct the Merkle tree
108    ///
109    /// # Panics
110    ///
111    /// - If the length of `bytes` is not a multiple of the hash output size.
112    /// - If the number of nodes implied by `bytes` is not consistent with a valid
113    ///   Merkle tree structure.
114    #[inline]
115    pub fn from_raw_bytes(bytes: &[u8]) -> Self {
116        let nodes: &[Output<D>] = bytemuck::cast_slice(bytes);
117        assert!(nodes.len().is_multiple_of(2));
118        assert_eq!(nodes[0], Output::<D>::default());
119        let len = nodes.len() / 2;
120        Self {
121            nodes: nodes.to_vec().into_boxed_slice(),
122            len,
123        }
124    }
125}
126
127impl<D: Digest + FixedOutputReset> UnhashedMerkleTree<D>
128where
129    Output<D>: Pod + Copy,
130{
131    /// Finalizes the Merkle tree by hashing internal nodes
132    pub fn finalize(self) -> MerkleTree<D> {
133        let mut nodes = self.buffer;
134        let len = self.len;
135        unsafe {
136            let maybe_uninit = nodes.spare_capacity_mut();
137
138            // Build the tree
139            let mut hasher = D::new();
140            for i in (1..len).rev() {
141                // SAFETY: in bounds due to loop range and initialization above
142                let left = maybe_uninit.get_unchecked(2 * i).assume_init_ref();
143                let right = maybe_uninit.get_unchecked(2 * i + 1).assume_init_ref();
144
145                Digest::update(&mut hasher, [INNER_NODE_PREFIX]);
146                Digest::update(&mut hasher, left);
147                Digest::update(&mut hasher, right);
148                let parent_hash = hasher.finalize_reset();
149
150                maybe_uninit.get_unchecked_mut(i).write(parent_hash);
151            }
152
153            // SAFETY: initialized all elements.
154            nodes.set_len(2 * len);
155        }
156        MerkleTree {
157            nodes: nodes.into_boxed_slice(),
158            len,
159        }
160    }
161}
162
163/// Iterator over the sibling nodes of a leaf in the Merkle tree
164#[derive(Debug, Clone)]
165pub struct SiblingIter<'a, D: Digest> {
166    nodes: &'a [Output<D>],
167    current: usize,
168}
169
/// Position of the current node relative to the sibling yielded alongside it.
///
/// Tells a proof verifier in which order to concatenate the running hash and the sibling
/// hash when recomputing the parent.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum NodePosition {
    /// The current node is a left child and its sibling is on the right:
    /// `APPEND` the sibling hash when computing the parent.
    Left,
    /// The current node is a right child and its sibling is on the left:
    /// `PREPEND` the sibling hash when computing the parent.
    Right,
}
178
179impl<'a, D: Digest> Iterator for SiblingIter<'a, D> {
180    /// (Yielded Node Position, Sibling Hash)
181    type Item = (NodePosition, &'a Output<D>);
182
183    fn next(&mut self) -> Option<Self::Item> {
184        if unlikely(self.current <= 1) {
185            return None;
186        }
187        let side = if (self.current & 1) == 0 {
188            NodePosition::Left
189        } else {
190            NodePosition::Right
191        };
192        let sibling_index = self.current ^ 1;
193        let sibling = unsafe { self.nodes.get_unchecked(sibling_index) };
194        self.current >>= 1;
195        Some((side, sibling))
196    }
197
198    fn size_hint(&self) -> (usize, Option<usize>) {
199        let exact = self.current.ilog2() as usize;
200        (exact, Some(exact))
201    }
202}
203
204impl<D: Digest> ExactSizeIterator for SiblingIter<'_, D> {
205    fn len(&self) -> usize {
206        self.current.ilog2() as usize
207    }
208}
209
#[cfg(test)]
mod tests {
    use super::*;
    use alloy_primitives::{B256, U256};
    use alloy_sol_types::SolValue;
    use sha2::Sha256;
    use sha3::Keccak256;

    #[test]
    fn basic() {
        test_merkle_tree::<Sha256>();
        test_merkle_tree::<Keccak256>();
    }

    #[test]
    fn proof() {
        test_proof::<Sha256>();
        test_proof::<Keccak256>();
    }

    #[test]
    fn single_leaf() {
        test_single_leaf::<Sha256>();
        test_single_leaf::<Keccak256>();
    }

    #[test]
    fn padding() {
        test_padding::<Sha256>();
        test_padding::<Keccak256>();
    }

    #[test]
    fn raw_bytes_roundtrip() {
        test_raw_bytes_roundtrip::<Sha256>();
        test_raw_bytes_roundtrip::<Keccak256>();
    }

    /// Recomputes the root from `leaf` and its proof, the way an external verifier would.
    fn root_from_proof<D: Digest + FixedOutputReset>(
        leaf: &Output<D>,
        proof: SiblingIter<'_, D>,
    ) -> Output<D>
    where
        Output<D>: Copy,
    {
        let mut current_hash = *leaf;
        let mut hasher = D::new();
        for (side, sibling_hash) in proof {
            Digest::update(&mut hasher, [INNER_NODE_PREFIX]);
            match side {
                NodePosition::Left => {
                    Digest::update(&mut hasher, current_hash);
                    Digest::update(&mut hasher, sibling_hash);
                }
                NodePosition::Right => {
                    Digest::update(&mut hasher, sibling_hash);
                    Digest::update(&mut hasher, current_hash);
                }
            }
            current_hash = hasher.finalize_reset();
        }
        current_hash
    }

    fn test_merkle_tree<D: Digest + FixedOutputReset>()
    where
        Output<D>: Pod + Copy,
    {
        let leaves = vec![
            D::digest(b"leaf1"),
            D::digest(b"leaf2"),
            D::digest(b"leaf3"),
            D::digest(b"leaf4"),
        ];

        let tree = MerkleTree::<D>::new(&leaves);

        // Manually compute the expected root
        let mut hasher = D::new();
        Digest::update(&mut hasher, [INNER_NODE_PREFIX]);
        Digest::update(&mut hasher, leaves[0]);
        Digest::update(&mut hasher, leaves[1]);
        let left_hash = hasher.finalize_reset();

        Digest::update(&mut hasher, [INNER_NODE_PREFIX]);
        Digest::update(&mut hasher, leaves[2]);
        Digest::update(&mut hasher, leaves[3]);
        let right_hash = hasher.finalize_reset();

        Digest::update(&mut hasher, [INNER_NODE_PREFIX]);
        Digest::update(&mut hasher, left_hash);
        Digest::update(&mut hasher, right_hash);
        let expected_root = hasher.finalize();

        assert_eq!(tree.root().as_slice(), expected_root.as_slice());
    }

    fn test_proof<D: Digest + FixedOutputReset>()
    where
        Output<D>: Pod + Copy,
    {
        let leaves = vec![
            D::digest(b"apple"),
            D::digest(b"banana"),
            D::digest(b"cherry"),
            D::digest(b"date"),
        ];

        let tree = MerkleTree::<D>::new(&leaves);

        for leaf in &leaves {
            let iter = tree
                .get_proof_iter(leaf)
                .expect("Leaf should be in the tree");
            // Four leaves => two levels below the root.
            assert_eq!(iter.len(), 2);
            assert_eq!(
                root_from_proof::<D>(leaf, iter).as_slice(),
                tree.root().as_slice()
            );
        }

        // A hash that is not a leaf has no proof.
        assert!(tree.get_proof_iter(&D::digest(b"elderberry")).is_none());
    }

    fn test_single_leaf<D: Digest + FixedOutputReset>()
    where
        Output<D>: Pod + Copy,
    {
        let leaf = D::digest(b"only");
        let tree = MerkleTree::<D>::new(&[leaf]);

        // With one leaf there are no internal nodes: the root is the leaf itself.
        assert_eq!(tree.root().as_slice(), leaf.as_slice());
        assert_eq!(tree.leaves().len(), 1);

        let proof = tree
            .get_proof_iter(&leaf)
            .expect("Leaf should be in the tree");
        assert_eq!(proof.len(), 0);
        assert_eq!(
            root_from_proof::<D>(&leaf, proof).as_slice(),
            leaf.as_slice()
        );
    }

    fn test_padding<D: Digest + FixedOutputReset>()
    where
        Output<D>: Pod + Copy,
    {
        let leaves = [D::digest(b"x"), D::digest(b"y"), D::digest(b"z")];
        let tree = MerkleTree::<D>::new(&leaves);

        // Three leaves are padded to four with the default hash.
        let padded = [leaves[0], leaves[1], leaves[2], Output::<D>::default()];
        assert_eq!(tree.leaves(), &padded[..]);

        let expected = MerkleTree::<D>::new(&padded);
        assert_eq!(tree.root(), expected.root());

        for leaf in &leaves {
            let proof = tree
                .get_proof_iter(leaf)
                .expect("Leaf should be in the tree");
            assert_eq!(
                root_from_proof::<D>(leaf, proof).as_slice(),
                tree.root().as_slice()
            );
        }
    }

    fn test_raw_bytes_roundtrip<D: Digest + FixedOutputReset>()
    where
        Output<D>: Pod + Copy,
    {
        let leaves: Vec<Output<D>> = (0u8..5).map(|i| D::digest([i])).collect();
        let tree = MerkleTree::<D>::new(&leaves);

        let restored = MerkleTree::<D>::from_raw_bytes(tree.as_raw_bytes());
        assert_eq!(restored.as_raw_bytes(), tree.as_raw_bytes());
        assert_eq!(restored.root(), tree.root());
        assert_eq!(restored.leaves(), tree.leaves());
    }

    #[ignore]
    #[test]
    fn generate_sol_test() {
        let mut leaves = Vec::with_capacity(1024);
        for i in 0..1024 {
            let mut hasher = Keccak256::new();
            let value = U256::from(i).abi_encode_packed();
            hasher.update(&value);
            leaves.push(hasher.finalize());
        }

        for i in 0..=10u32 {
            let tree = MerkleTree::<Keccak256>::new(&leaves[..2usize.pow(i)]);
            let root = B256::new(tree.root().0);
            println!("bytes32({root}),");
        }
    }
}