Skip to main content

uts_bmt/
lib.rs

1//! High performance binary Merkle tree implementation in Rust.
2
3// MIT License
4//
5// Copyright (c) 2025 UTS Contributors
6//
7// Permission is hereby granted, free of charge, to any person obtaining a copy
8// of this software and associated documentation files (the "Software"), to deal
9// in the Software without restriction, including without limitation the rights
10// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11// copies of the Software, and to permit persons to whom the Software is
12// furnished to do so, subject to the following conditions:
13//
14// The above copyright notice and this permission notice shall be included in all
15// copies or substantial portions of the Software.
16//
17// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23// SOFTWARE.
24//
25// Apache License, Version 2.0
26//
27// Copyright (c) 2025 UTS Contributors
28//
29// Licensed under the Apache License, Version 2.0 (the "License");
30// you may not use this file except in compliance with the License.
31// You may obtain a copy of the License at
32//
33//     http://www.apache.org/licenses/LICENSE-2.0
34//
35// Unless required by applicable law or agreed to in writing, software
36// distributed under the License is distributed on an "AS IS" BASIS,
37// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
38// See the License for the specific language governing permissions and
39// limitations under the License.
40
41use digest::{Digest, FixedOutputReset, Output, typenum::Unsigned};
42
/// Domain-separation byte hashed ahead of the two child digests of every
/// internal node.
///
/// Leaves are stored in the tree as given, so this prefix is what keeps the
/// preimage of an internal node distinguishable from leaf data.
pub const INNER_NODE_PREFIX: u8 = 0x01;
45
46/// Flat, Fixed-Size, Read only Merkle Tree
47///
48/// Expects the length of leaves to be equal or near(less) to a power of two.
49#[derive(Debug, Clone, Default)]
50pub struct MerkleTree<D: Digest> {
51    /// Index 0 is not used, leaves start at index `len`.
52    nodes: Box<[Output<D>]>,
53    len: usize,
54}
55
56/// Merkle Tree without hashing the leaves
57#[derive(Debug, Clone)]
58pub struct UnhashedMerkleTree<D: Digest> {
59    buffer: Vec<Output<D>>,
60    len: usize,
61}
62
63impl<D: Digest + FixedOutputReset> MerkleTree<D>
64where
65    Output<D>: Copy,
66{
67    /// Constructs a new Merkle tree from the given hash leaves.
68    pub fn new(data: &[Output<D>]) -> Self {
69        Self::new_unhashed(data).finalize()
70    }
71
72    /// Constructs a new Merkle tree from the given hash leaves, without hashing internal nodes.
73    pub fn new_unhashed(data: &[Output<D>]) -> UnhashedMerkleTree<D> {
74        let raw_len = data.len();
75        assert_ne!(raw_len, 0, "Cannot create Merkle tree with zero leaves");
76
77        let len = raw_len.next_power_of_two();
78        let mut nodes = Vec::<Output<D>>::with_capacity(2 * len);
79
80        unsafe {
81            let maybe_uninit = nodes.spare_capacity_mut();
82
83            // SAFETY: tree is valid for writes, properly aligned, and at least 1 element long.
84            // index 0, we will never use it
85            maybe_uninit
86                .get_unchecked_mut(0)
87                .write(Output::<D>::default());
88
89            // SAFETY: capacity * sizeof(T) is within the allocated size of `tree`
90            let dst = maybe_uninit.get_unchecked_mut(len..).as_mut_ptr().cast();
91            let src = data.as_ptr();
92            // SAFETY:
93            // - src is valid for reads `len` elements and properly aligned
94            // - dst is valid for writes `len` elements and properly aligned
95            // - the regions do not overlap since we just allocated `tree`
96            std::ptr::copy_nonoverlapping(src, dst, raw_len);
97
98            // SAFETY: capacity + len is within the allocated size of `tree`
99            for e in maybe_uninit.get_unchecked_mut(len + raw_len..) {
100                e.write(Output::<D>::default());
101            }
102        }
103
104        UnhashedMerkleTree { buffer: nodes, len }
105    }
106
107    /// Returns the root hash of the Merkle tree
108    #[inline]
109    pub fn root(&self) -> &Output<D> {
110        // SAFETY: index 1 is always initialized in new()
111        unsafe { self.nodes.get_unchecked(1) }
112    }
113
114    /// Returns the leaves of the Merkle tree
115    #[inline]
116    pub fn leaves(&self) -> &[Output<D>] {
117        unsafe { self.nodes.get_unchecked(self.len..self.len + self.len) }
118    }
119
120    /// Checks if the given leaf is contained in the Merkle tree
121    #[inline]
122    pub fn contains(&self, leaf: &Output<D>) -> bool {
123        self.leaves().contains(leaf)
124    }
125
126    /// Get proof for a given leaf
127    pub fn get_proof_iter(&self, leaf: &Output<D>) -> Option<SiblingIter<'_, D>> {
128        let leaf_index_in_slice = self.leaves().iter().position(|a| a == leaf)?;
129        Some(SiblingIter {
130            nodes: &self.nodes,
131            current: self.len + leaf_index_in_slice,
132        })
133    }
134
135    /// Returns the raw bytes of the Merkle tree nodes
136    #[inline]
137    pub fn to_raw_bytes(&self) -> Vec<u8> {
138        self.nodes
139            .iter()
140            .flat_map(|node| node.as_slice())
141            .copied()
142            .collect()
143    }
144
145    /// From raw bytes, reconstruct the Merkle tree
146    ///
147    /// # Panics
148    ///
149    /// - If the length of `bytes` is not a multiple of the hash output size.
150    /// - If the number of nodes implied by `bytes` is not consistent with a valid
151    ///   Merkle tree structure.
152    #[inline]
153    pub fn from_raw_bytes(bytes: &[u8]) -> Self {
154        assert!(
155            bytes.len().is_multiple_of(D::OutputSize::USIZE),
156            "Invalid raw bytes length"
157        );
158        let len = bytes.len() / D::OutputSize::USIZE;
159        assert!(len.is_multiple_of(2));
160        let mut nodes: Vec<Output<D>> = Vec::with_capacity(len);
161        for chunk in bytes.chunks_exact(D::OutputSize::USIZE) {
162            let node = Output::<D>::from_slice(chunk);
163            nodes.push(*node);
164        }
165        assert_eq!(nodes[0], Output::<D>::default());
166        let len = nodes.len() / 2;
167        Self {
168            nodes: nodes.to_vec().into_boxed_slice(),
169            len,
170        }
171    }
172}
173
174impl<D: Digest + FixedOutputReset> UnhashedMerkleTree<D>
175where
176    Output<D>: Copy,
177{
178    /// Finalizes the Merkle tree by hashing internal nodes
179    pub fn finalize(self) -> MerkleTree<D> {
180        let mut nodes = self.buffer;
181        let len = self.len;
182        unsafe {
183            let maybe_uninit = nodes.spare_capacity_mut();
184
185            // Build the tree
186            let mut hasher = D::new();
187            for i in (1..len).rev() {
188                // SAFETY: in bounds due to loop range and initialization above
189                let left = maybe_uninit.get_unchecked(2 * i).assume_init_ref();
190                let right = maybe_uninit.get_unchecked(2 * i + 1).assume_init_ref();
191
192                Digest::update(&mut hasher, [INNER_NODE_PREFIX]);
193                Digest::update(&mut hasher, left);
194                Digest::update(&mut hasher, right);
195                let parent_hash = hasher.finalize_reset();
196
197                maybe_uninit.get_unchecked_mut(i).write(parent_hash);
198            }
199
200            // SAFETY: initialized all elements.
201            nodes.set_len(2 * len);
202        }
203        MerkleTree {
204            nodes: nodes.into_boxed_slice(),
205            len,
206        }
207    }
208}
209
210/// Iterator over the sibling nodes of a leaf in the Merkle tree
211#[derive(Debug, Clone)]
212pub struct SiblingIter<'a, D: Digest> {
213    nodes: &'a [Output<D>],
214    current: usize,
215}
216
/// Side on which the current node sits relative to the sibling yielded with
/// it.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum NodePosition {
    /// The current node is the left child and the sibling the right one:
    /// `APPEND` the sibling hash when computing the parent.
    Left,
    /// The current node is the right child and the sibling the left one:
    /// `PREPEND` the sibling hash when computing the parent.
    Right,
}
225
226impl<'a, D: Digest> Iterator for SiblingIter<'a, D> {
227    /// (Yielded Node Position, Sibling Hash)
228    type Item = (NodePosition, &'a Output<D>);
229
230    fn next(&mut self) -> Option<Self::Item> {
231        if self.current <= 1 {
232            return None;
233        }
234        let side = if (self.current & 1) == 0 {
235            NodePosition::Left
236        } else {
237            NodePosition::Right
238        };
239        let sibling_index = self.current ^ 1;
240        let sibling = unsafe { self.nodes.get_unchecked(sibling_index) };
241        self.current >>= 1;
242        Some((side, sibling))
243    }
244
245    fn size_hint(&self) -> (usize, Option<usize>) {
246        let exact = self.current.ilog2() as usize;
247        (exact, Some(exact))
248    }
249}
250
251impl<D: Digest> ExactSizeIterator for SiblingIter<'_, D> {
252    fn len(&self) -> usize {
253        self.current.ilog2() as usize
254    }
255}
256
#[cfg(test)]
mod tests {
    use super::*;
    use alloy_primitives::{B256, U256};
    use alloy_sol_types::SolValue;
    use sha2::Sha256;
    use sha3::Keccak256;

    /// Hashes two children into their parent the same way the tree does:
    /// prefix byte, then left child, then right child.
    fn hash_inner<D: Digest>(left: &Output<D>, right: &Output<D>) -> Output<D> {
        let mut hasher = D::new();
        Digest::update(&mut hasher, [INNER_NODE_PREFIX]);
        Digest::update(&mut hasher, left);
        Digest::update(&mut hasher, right);
        hasher.finalize()
    }

    /// Recomputes the root from `leaf` and the proof the tree returns for it.
    fn root_from_proof<D: Digest + FixedOutputReset>(
        tree: &MerkleTree<D>,
        leaf: &Output<D>,
    ) -> Output<D>
    where
        Output<D>: Copy,
    {
        let proof = tree
            .get_proof_iter(leaf)
            .expect("Leaf should be in the tree");
        proof.fold(*leaf, |current, (side, sibling)| match side {
            NodePosition::Left => hash_inner::<D>(&current, sibling),
            NodePosition::Right => hash_inner::<D>(sibling, &current),
        })
    }

    #[test]
    fn basic() {
        test_merkle_tree::<Sha256>();
        test_merkle_tree::<Keccak256>();
    }

    #[test]
    fn proof() {
        test_proof::<Sha256>();
        test_proof::<Keccak256>();
    }

    #[test]
    fn padding() {
        test_padding::<Sha256>();
        test_padding::<Keccak256>();
    }

    #[test]
    fn raw_bytes_round_trip() {
        test_raw_bytes_round_trip::<Sha256>();
        test_raw_bytes_round_trip::<Keccak256>();
    }

    /// The root of a four-leaf tree matches a manual computation.
    fn test_merkle_tree<D: Digest + FixedOutputReset>()
    where
        Output<D>: Copy,
    {
        let leaves = [
            D::digest(b"leaf1"),
            D::digest(b"leaf2"),
            D::digest(b"leaf3"),
            D::digest(b"leaf4"),
        ];

        let tree = MerkleTree::<D>::new(&leaves);

        let left_hash = hash_inner::<D>(&leaves[0], &leaves[1]);
        let right_hash = hash_inner::<D>(&leaves[2], &leaves[3]);
        let expected_root = hash_inner::<D>(&left_hash, &right_hash);

        assert_eq!(tree.root().as_slice(), expected_root.as_slice());
    }

    /// Every leaf's proof folds back to the root.
    fn test_proof<D: Digest + FixedOutputReset>()
    where
        Output<D>: Copy,
    {
        let leaves = [
            D::digest(b"apple"),
            D::digest(b"banana"),
            D::digest(b"cherry"),
            D::digest(b"date"),
        ];

        let tree = MerkleTree::<D>::new(&leaves);

        for leaf in &leaves {
            assert!(tree.contains(leaf));
            let proof = tree
                .get_proof_iter(leaf)
                .expect("Leaf should be in the tree");
            // Four leaf slots: two levels below the root.
            assert_eq!(proof.len(), 2);

            let recomputed = root_from_proof::<D>(&tree, leaf);
            assert_eq!(recomputed.as_slice(), tree.root().as_slice());
        }

        assert!(tree.get_proof_iter(&D::digest(b"missing")).is_none());
    }

    /// Leaf counts that are not a power of two are padded with default hashes,
    /// and a single leaf is its own root.
    fn test_padding<D: Digest + FixedOutputReset>()
    where
        Output<D>: Copy,
    {
        let leaves = [D::digest(b"a"), D::digest(b"b"), D::digest(b"c")];
        let tree = MerkleTree::<D>::new(&leaves);

        let padding_leaf = Output::<D>::default();
        let left_hash = hash_inner::<D>(&leaves[0], &leaves[1]);
        let right_hash = hash_inner::<D>(&leaves[2], &padding_leaf);
        let expected_root = hash_inner::<D>(&left_hash, &right_hash);

        assert_eq!(tree.leaves().len(), 4);
        assert_eq!(tree.root().as_slice(), expected_root.as_slice());
        for leaf in &leaves {
            let recomputed = root_from_proof::<D>(&tree, leaf);
            assert_eq!(recomputed.as_slice(), tree.root().as_slice());
        }

        let single = MerkleTree::<D>::new(&leaves[..1]);
        assert_eq!(single.root().as_slice(), leaves[0].as_slice());
        let proof = single
            .get_proof_iter(&leaves[0])
            .expect("Leaf should be in the tree");
        assert_eq!(proof.len(), 0);
    }

    /// Serializing and deserializing a tree yields the same nodes.
    fn test_raw_bytes_round_trip<D: Digest + FixedOutputReset>()
    where
        Output<D>: Copy,
    {
        let leaves = [
            D::digest(b"north"),
            D::digest(b"east"),
            D::digest(b"south"),
            D::digest(b"west"),
        ];
        let tree = MerkleTree::<D>::new(&leaves);

        let bytes = tree.to_raw_bytes();
        // Placeholder + 3 internal nodes + 4 leaves.
        assert_eq!(bytes.len(), 8 * tree.root().as_slice().len());

        let restored = MerkleTree::<D>::from_raw_bytes(&bytes);
        assert_eq!(restored.root().as_slice(), tree.root().as_slice());
        assert_eq!(restored.leaves(), tree.leaves());
        assert_eq!(restored.to_raw_bytes(), bytes);
    }

    /// Prints Keccak-256 roots for 1, 2, 4, ..., 1024 leaves.
    /// NOTE(review): the output format suggests these are pasted into a
    /// Solidity test - confirm against the consumer.
    #[ignore]
    #[test]
    fn generate_sol_test() {
        let mut leaves = Vec::with_capacity(1024);
        for i in 0..1024 {
            let mut hasher = Keccak256::new();
            let value = U256::from(i).abi_encode_packed();
            hasher.update(&value);
            leaves.push(hasher.finalize());
        }

        for i in 0..=10u32 {
            let tree = MerkleTree::<Keccak256>::new(&leaves[..2usize.pow(i)]);
            let root = B256::from_slice(tree.root());
            println!("bytes32({root}),");
        }
    }
}