semaphore_rs_trees/imt/
mod.rs

1//! Implements basic binary Merkle trees
2
3use std::fmt::Debug;
4use std::iter::{once, repeat_n, successors};
5
6use bytemuck::Pod;
7use derive_where::derive_where;
8use semaphore_rs_hasher::Hasher;
9
10use crate::proof::{Branch, InclusionProof};
11
12/// Merkle tree with all leaf and intermediate hashes stored
13#[derive_where(Clone; <H as Hasher>::Hash: Clone)]
14#[derive_where(PartialEq; <H as Hasher>::Hash: PartialEq)]
15#[derive_where(Eq; <H as Hasher>::Hash: Eq)]
16#[derive_where(Debug; <H as Hasher>::Hash: Debug)]
17pub struct MerkleTree<H>
18where
19    H: Hasher,
20{
21    /// Depth of the tree, # of layers including leaf layer
22    depth: usize,
23
24    /// Hash value of empty subtrees of given depth, starting at leaf level
25    empty: Vec<H::Hash>,
26
27    /// Hash values of tree nodes and leaves, breadth first order
28    nodes: Vec<H::Hash>,
29}
30
/// Index of the parent of node `index`.
///
/// Yields `None` for the root (`1`) and for the unused placeholder slot (`0`),
/// neither of which has a parent.
const fn parent(index: usize) -> Option<usize> {
    match index {
        0 | 1 => None,
        _ => Some(index / 2),
    }
}
40
/// Index of the left (first) child of node `index`; the right child sits at the
/// following index.
const fn left_child(index: usize) -> usize {
    // Wrapping doubling: bit-for-bit identical to `index << 1`.
    index.wrapping_mul(2)
}
45
/// Layer of node `index` in the 1-indexed breadth-first layout, with the root
/// (`1`) on layer 0.
///
/// Node `i` lives on layer `floor(log2(i))`. Index 0 is the unused placeholder
/// slot and has no logarithm, so it is reported as layer 0 as well.
const fn depth(index: usize) -> usize {
    if index <= 1 {
        0
    } else {
        index.ilog2() as usize
    }
}
55
56impl<H> MerkleTree<H>
57where
58    H: Hasher,
59    <H as Hasher>::Hash: Clone + Copy + Pod + Eq + Debug,
60{
61    /// Creates a new `MerkleTree`
62    /// * `depth` - The depth of the tree, including the root. This is 1 greater
63    ///   than the `treeLevels` argument to the Semaphore contract.
64    pub fn new(depth: usize, initial_leaf: H::Hash) -> Self {
65        // Compute empty node values, leaf to root
66        let empty = successors(Some(initial_leaf), |prev| Some(H::hash_node(prev, prev)))
67            .take(depth + 1)
68            .collect::<Vec<_>>();
69
70        // Compute node values
71        let first_node = std::iter::once(initial_leaf);
72        let nodes = empty
73            .iter()
74            .rev()
75            .enumerate()
76            .flat_map(|(depth, hash)| repeat_n(hash, 1 << depth))
77            .cloned();
78
79        let nodes = first_node.chain(nodes).collect();
80
81        Self {
82            depth,
83            empty,
84            nodes,
85        }
86    }
87
88    #[must_use]
89    pub fn num_leaves(&self) -> usize {
90        1 << self.depth
91    }
92
93    #[must_use]
94    pub fn root(&self) -> H::Hash {
95        self.nodes[1]
96    }
97
98    pub fn set(&mut self, leaf: usize, hash: H::Hash) {
99        self.set_range(leaf, once(hash));
100    }
101
102    pub fn set_range<I: IntoIterator<Item = H::Hash>>(&mut self, start: usize, hashes: I) {
103        let index = self.num_leaves() + start;
104
105        let mut count = 0;
106        // TODO: Error/panic when hashes is longer than available leafs
107        for (leaf, hash) in self.nodes[index..].iter_mut().zip(hashes) {
108            *leaf = hash;
109            count += 1;
110        }
111
112        if count != 0 {
113            self.update_nodes(index, index + (count - 1));
114        }
115    }
116
117    fn update_nodes(&mut self, start: usize, end: usize) {
118        debug_assert_eq!(depth(start), depth(end));
119        if let (Some(start), Some(end)) = (parent(start), parent(end)) {
120            for parent in start..=end {
121                let child = left_child(parent);
122                self.nodes[parent] = H::hash_node(&self.nodes[child], &self.nodes[child + 1]);
123            }
124            self.update_nodes(start, end);
125        }
126    }
127
128    #[must_use]
129    pub fn proof(&self, leaf: usize) -> Option<InclusionProof<H>> {
130        if leaf >= self.num_leaves() {
131            return None;
132        }
133        let mut index = self.num_leaves() + leaf;
134        let mut path = Vec::with_capacity(self.depth);
135        while let Some(parent) = parent(index) {
136            // Add proof for node at index to parent
137            path.push(match index & 1 {
138                1 => Branch::Right(self.nodes[index - 1]),
139                0 => Branch::Left(self.nodes[index + 1]),
140                _ => unreachable!(),
141            });
142            index = parent;
143        }
144        Some(InclusionProof(path))
145    }
146
147    #[must_use]
148    pub fn verify(&self, hash: H::Hash, proof: &InclusionProof<H>) -> bool {
149        proof.root(hash) == self.root()
150    }
151
152    #[must_use]
153    pub fn leaves(&self) -> &[H::Hash] {
154        &self.nodes[(self.num_leaves() - 1)..]
155    }
156}
157
158impl<H: Hasher> InclusionProof<H> {
159    /// Compute the leaf index for this proof
160    #[must_use]
161    pub fn leaf_index(&self) -> usize {
162        self.0.iter().rev().fold(0, |index, branch| match branch {
163            Branch::Left(_) => index << 1,
164            Branch::Right(_) => (index << 1) + 1,
165        })
166    }
167
168    /// Compute the Merkle root given a leaf hash
169    #[must_use]
170    pub fn root(&self, hash: H::Hash) -> H::Hash {
171        self.0.iter().fold(hash, |hash, branch| match branch {
172            Branch::Left(sibling) => H::hash_node(&hash, sibling),
173            Branch::Right(sibling) => H::hash_node(sibling, &hash),
174        })
175    }
176}
177
#[cfg(test)]
pub mod test {
    use hex_literal::hex;
    use ruint::aliases::U256;
    use semaphore_rs_keccak::keccak::Keccak256;
    use semaphore_rs_poseidon::Poseidon;
    use test_case::test_case;

    use super::*;

    #[test_case(0 => None)]
    #[test_case(1 => None)]
    #[test_case(2 => Some(1))]
    #[test_case(3 => Some(1))]
    #[test_case(4 => Some(2))]
    #[test_case(5 => Some(2))]
    #[test_case(6 => Some(3))]
    #[test_case(27 => Some(13))]
    fn parent_of(index: usize) -> Option<usize> {
        parent(index)
    }

    #[test_case(0 => 0 ; "Nonsense case")]
    #[test_case(1 => 2)]
    #[test_case(2 => 4)]
    #[test_case(3 => 6)]
    fn left_child_of(index: usize) -> usize {
        left_child(index)
    }

    #[test_case(0 => 0)]
    #[test_case(1 => 0)]
    #[test_case(2 => 1)]
    #[test_case(3 => 1)]
    #[test_case(6 => 2)]
    fn depth_of(index: usize) -> usize {
        depth(index)
    }

    #[test_case(2 => hex!("b4c11951957c6f8f642c4af61cd6b24640fec6dc7fc607ee8206a99e92410d30"))]
    fn empty_keccak(depth: usize) -> [u8; 32] {
        let tree = MerkleTree::<Keccak256>::new(depth, [0; 32]);

        tree.root()
    }

    #[test]
    fn simple_poseidon() {
        let mut tree = MerkleTree::<Poseidon>::new(10, U256::ZERO);

        let expected_root = ruint::uint!(
            12413880268183407374852357075976609371175688755676981206018884971008854919922_U256
        );
        assert_eq!(tree.root(), expected_root);

        tree.set(0, ruint::uint!(1_U256));

        let expected_root = ruint::uint!(
            467068234150758165281816522946040748310650451788100792957402532717155514893_U256
        );
        assert_eq!(tree.root(), expected_root);
    }

    /// Every leaf's proof must recover its index, hash back to the root, and be
    /// rejected for a different leaf value.
    #[test]
    fn proofs_round_trip() {
        let mut tree = MerkleTree::<Keccak256>::new(3, [0; 32]);
        let leaves: Vec<[u8; 32]> = (1..=8u8).map(|i| [i; 32]).collect();
        tree.set_range(0, leaves.iter().copied());

        for (index, leaf) in leaves.iter().enumerate() {
            let proof = tree.proof(index).expect("index is within the tree");
            assert_eq!(proof.leaf_index(), index);
            assert_eq!(proof.root(*leaf), tree.root());
            assert!(tree.verify(*leaf, &proof));
            assert!(!tree.verify([0xff; 32], &proof));
        }
        assert!(tree.proof(leaves.len()).is_none());
    }

    /// A bulk `set_range` must produce the same tree as setting each leaf alone.
    #[test]
    fn set_range_matches_individual_sets() {
        let mut bulk = MerkleTree::<Keccak256>::new(3, [0; 32]);
        let mut single = bulk.clone();
        let hashes = [[7u8; 32], [8; 32], [9; 32]];

        bulk.set_range(2, hashes);
        for (offset, hash) in hashes.into_iter().enumerate() {
            single.set(2 + offset, hash);
        }
        assert_eq!(bulk, single);
    }
}