semaphore_rs_trees/imt/
mod.rs1use std::fmt::Debug;
4use std::iter::{once, repeat_n, successors};
5
6use bytemuck::Pod;
7use derive_where::derive_where;
8use semaphore_rs_hasher::Hasher;
9
10use crate::proof::{Branch, InclusionProof};
11
12#[derive_where(Clone; <H as Hasher>::Hash: Clone)]
14#[derive_where(PartialEq; <H as Hasher>::Hash: PartialEq)]
15#[derive_where(Eq; <H as Hasher>::Hash: Eq)]
16#[derive_where(Debug; <H as Hasher>::Hash: Debug)]
17pub struct MerkleTree<H>
18where
19 H: Hasher,
20{
21 depth: usize,
23
24 empty: Vec<H::Hash>,
26
27 nodes: Vec<H::Hash>,
29}
30
/// Returns the heap index of the parent of node `index`.
///
/// The root (index 1) and the unused placeholder slot (index 0) have no
/// parent, so both yield `None`.
const fn parent(index: usize) -> Option<usize> {
    match index {
        0 | 1 => None,
        child => Some(child / 2),
    }
}
40
/// Returns the heap index of the left child of node `index`.
///
/// The right child is always the index immediately after the left one.
const fn left_child(index: usize) -> usize {
    // Doubling by a one-bit shift: never panics, the top bit is simply dropped.
    index.wrapping_shl(1)
}
45
/// Returns how many levels below the root node `index` sits, i.e.
/// `floor(log2(index))`.
///
/// The root (index 1) is level 0; the placeholder slot (index 0), which has
/// no logarithm, is also reported as level 0.
const fn depth(index: usize) -> usize {
    match index.checked_ilog2() {
        Some(level) => level as usize,
        None => 0,
    }
}
55
56impl<H> MerkleTree<H>
57where
58 H: Hasher,
59 <H as Hasher>::Hash: Clone + Copy + Pod + Eq + Debug,
60{
61 pub fn new(depth: usize, initial_leaf: H::Hash) -> Self {
65 let empty = successors(Some(initial_leaf), |prev| Some(H::hash_node(prev, prev)))
67 .take(depth + 1)
68 .collect::<Vec<_>>();
69
70 let first_node = std::iter::once(initial_leaf);
72 let nodes = empty
73 .iter()
74 .rev()
75 .enumerate()
76 .flat_map(|(depth, hash)| repeat_n(hash, 1 << depth))
77 .cloned();
78
79 let nodes = first_node.chain(nodes).collect();
80
81 Self {
82 depth,
83 empty,
84 nodes,
85 }
86 }
87
88 #[must_use]
89 pub fn num_leaves(&self) -> usize {
90 1 << self.depth
91 }
92
93 #[must_use]
94 pub fn root(&self) -> H::Hash {
95 self.nodes[1]
96 }
97
98 pub fn set(&mut self, leaf: usize, hash: H::Hash) {
99 self.set_range(leaf, once(hash));
100 }
101
102 pub fn set_range<I: IntoIterator<Item = H::Hash>>(&mut self, start: usize, hashes: I) {
103 let index = self.num_leaves() + start;
104
105 let mut count = 0;
106 for (leaf, hash) in self.nodes[index..].iter_mut().zip(hashes) {
108 *leaf = hash;
109 count += 1;
110 }
111
112 if count != 0 {
113 self.update_nodes(index, index + (count - 1));
114 }
115 }
116
117 fn update_nodes(&mut self, start: usize, end: usize) {
118 debug_assert_eq!(depth(start), depth(end));
119 if let (Some(start), Some(end)) = (parent(start), parent(end)) {
120 for parent in start..=end {
121 let child = left_child(parent);
122 self.nodes[parent] = H::hash_node(&self.nodes[child], &self.nodes[child + 1]);
123 }
124 self.update_nodes(start, end);
125 }
126 }
127
128 #[must_use]
129 pub fn proof(&self, leaf: usize) -> Option<InclusionProof<H>> {
130 if leaf >= self.num_leaves() {
131 return None;
132 }
133 let mut index = self.num_leaves() + leaf;
134 let mut path = Vec::with_capacity(self.depth);
135 while let Some(parent) = parent(index) {
136 path.push(match index & 1 {
138 1 => Branch::Right(self.nodes[index - 1]),
139 0 => Branch::Left(self.nodes[index + 1]),
140 _ => unreachable!(),
141 });
142 index = parent;
143 }
144 Some(InclusionProof(path))
145 }
146
147 #[must_use]
148 pub fn verify(&self, hash: H::Hash, proof: &InclusionProof<H>) -> bool {
149 proof.root(hash) == self.root()
150 }
151
152 #[must_use]
153 pub fn leaves(&self) -> &[H::Hash] {
154 &self.nodes[(self.num_leaves() - 1)..]
155 }
156}
157
158impl<H: Hasher> InclusionProof<H> {
159 #[must_use]
161 pub fn leaf_index(&self) -> usize {
162 self.0.iter().rev().fold(0, |index, branch| match branch {
163 Branch::Left(_) => index << 1,
164 Branch::Right(_) => (index << 1) + 1,
165 })
166 }
167
168 #[must_use]
170 pub fn root(&self, hash: H::Hash) -> H::Hash {
171 self.0.iter().fold(hash, |hash, branch| match branch {
172 Branch::Left(sibling) => H::hash_node(&hash, sibling),
173 Branch::Right(sibling) => H::hash_node(sibling, &hash),
174 })
175 }
176}
177
#[cfg(test)]
pub mod test {
    use hex_literal::hex;
    use ruint::aliases::U256;
    use semaphore_rs_keccak::keccak::Keccak256;
    use semaphore_rs_poseidon::Poseidon;
    use test_case::test_case;

    use super::*;

    // `parent` maps a heap index one level up; 0 (placeholder) and 1 (root)
    // have no parent.
    #[test_case(0 => None)]
    #[test_case(1 => None)]
    #[test_case(2 => Some(1))]
    #[test_case(3 => Some(1))]
    #[test_case(4 => Some(2))]
    #[test_case(5 => Some(2))]
    #[test_case(6 => Some(3))]
    #[test_case(27 => Some(13))]
    fn parent_of(index: usize) -> Option<usize> {
        parent(index)
    }

    // `left_child` doubles the heap index; index 0 is not a real node, hence
    // the "Nonsense case" label.
    #[test_case(0 => 0 ; "Nonsense case")]
    #[test_case(1 => 2)]
    #[test_case(2 => 4)]
    #[test_case(3 => 6)]
    fn left_child_of(index: usize) -> usize {
        left_child(index)
    }

    // `depth` is the level below the root: floor(log2(index)), with 0 and 1
    // both mapping to 0.
    #[test_case(0 => 0)]
    #[test_case(1 => 0)]
    #[test_case(2 => 1)]
    #[test_case(3 => 1)]
    #[test_case(6 => 2)]
    fn depth_of(index: usize) -> usize {
        depth(index)
    }

    // Root of a Keccak-256 tree of the given depth whose leaves are all zero
    // bytes.
    #[test_case(2 => hex!("b4c11951957c6f8f642c4af61cd6b24640fec6dc7fc607ee8206a99e92410d30"))]
    fn empty_keccak(depth: usize) -> [u8; 32] {
        let tree = MerkleTree::<Keccak256>::new(depth, [0; 32]);

        tree.root()
    }

    // Checks the root of a depth-10 Poseidon tree with all-zero leaves, then
    // the root after setting leaf 0 to 1.
    #[test]
    fn simple_poseidon() {
        let mut tree = MerkleTree::<Poseidon>::new(10, U256::ZERO);

        let expected_root = ruint::uint!(
            12413880268183407374852357075976609371175688755676981206018884971008854919922_U256
        );
        assert_eq!(tree.root(), expected_root);

        tree.set(0, ruint::uint!(1_U256));

        let expected_root = ruint::uint!(
            467068234150758165281816522946040748310650451788100792957402532717155514893_U256
        );
        assert_eq!(tree.root(), expected_root);
    }
}