solana_merkle_tree/
merkle_tree.rs1use {solana_hash::Hash, solana_sha256_hasher::hashv};
2
/// Single byte prepended to leaf data before hashing.
const LEAF_PREFIX: &[u8] = b"\x00";
/// Single byte prepended to a pair of child hashes before hashing; differs from
/// `LEAF_PREFIX` so leaf and interior-node hash inputs never share a prefix.
const INTERMEDIATE_PREFIX: &[u8] = b"\x01";

/// Hashes the byte slice bound to `$data` as a leaf: `hashv(LEAF_PREFIX || data)`.
macro_rules! hash_leaf {
    ($data:ident) => {
        hashv(&[LEAF_PREFIX, $data])
    };
}

/// Hashes two child nodes into their parent:
/// `hashv(INTERMEDIATE_PREFIX || left || right)`.
///
/// Both arguments must be identifiers whose values implement `AsRef<[u8]>`.
macro_rules! hash_intermediate {
    ($left:ident, $right:ident) => {
        hashv(&[INTERMEDIATE_PREFIX, $left.as_ref(), $right.as_ref()])
    };
}

21#[derive(Debug)]
22pub struct MerkleTree {
23 leaf_count: usize,
24 nodes: Vec<Hash>,
25}
26
27#[derive(Debug, PartialEq, Eq)]
28pub struct ProofEntry<'a>(&'a Hash, Option<&'a Hash>, Option<&'a Hash>);
29
30impl<'a> ProofEntry<'a> {
31 pub fn new(
32 target: &'a Hash,
33 left_sibling: Option<&'a Hash>,
34 right_sibling: Option<&'a Hash>,
35 ) -> Self {
36 assert!(left_sibling.is_none() ^ right_sibling.is_none());
37 Self(target, left_sibling, right_sibling)
38 }
39}
40
41#[derive(Debug, Default, PartialEq, Eq)]
42pub struct Proof<'a>(Vec<ProofEntry<'a>>);
43
44impl<'a> Proof<'a> {
45 pub fn push(&mut self, entry: ProofEntry<'a>) {
46 self.0.push(entry)
47 }
48
49 pub fn verify(&self, candidate: Hash) -> bool {
50 let result = self.0.iter().try_fold(candidate, |candidate, pe| {
51 let lsib = pe.1.unwrap_or(&candidate);
52 let rsib = pe.2.unwrap_or(&candidate);
53 let hash = hash_intermediate!(lsib, rsib);
54
55 if hash == *pe.0 {
56 Some(hash)
57 } else {
58 None
59 }
60 });
61 result.is_some()
62 }
63}
64
65impl MerkleTree {
66 #[inline]
67 fn next_level_len(level_len: usize) -> usize {
68 if level_len == 1 {
69 0
70 } else {
71 level_len.div_ceil(2)
72 }
73 }
74
75 fn calculate_vec_capacity(leaf_count: usize) -> usize {
76 if leaf_count > 0 {
92 fast_math::log2_raw(leaf_count as f32) as usize + 2 * leaf_count + 1
93 } else {
94 0
95 }
96 }
97
98 pub fn new<T: AsRef<[u8]>>(items: &[T]) -> Self {
99 let cap = MerkleTree::calculate_vec_capacity(items.len());
100 let mut mt = MerkleTree {
101 leaf_count: items.len(),
102 nodes: Vec::with_capacity(cap),
103 };
104
105 for item in items {
106 let item = item.as_ref();
107 let hash = hash_leaf!(item);
108 mt.nodes.push(hash);
109 }
110
111 let mut level_len = MerkleTree::next_level_len(items.len());
112 let mut level_start = items.len();
113 let mut prev_level_len = items.len();
114 let mut prev_level_start = 0;
115 while level_len > 0 {
116 for i in 0..level_len {
117 let prev_level_idx = 2 * i;
118 let lsib = &mt.nodes[prev_level_start + prev_level_idx];
119 let rsib = if prev_level_idx + 1 < prev_level_len {
120 &mt.nodes[prev_level_start + prev_level_idx + 1]
121 } else {
122 &mt.nodes[prev_level_start + prev_level_idx]
124 };
125
126 let hash = hash_intermediate!(lsib, rsib);
127 mt.nodes.push(hash);
128 }
129 prev_level_start = level_start;
130 prev_level_len = level_len;
131 level_start += level_len;
132 level_len = MerkleTree::next_level_len(level_len);
133 }
134
135 mt
136 }
137
138 pub fn get_root(&self) -> Option<&Hash> {
139 self.nodes.iter().last()
140 }
141
142 pub fn find_path(&self, index: usize) -> Option<Proof> {
143 if index >= self.leaf_count {
144 return None;
145 }
146
147 let mut level_len = self.leaf_count;
148 let mut level_start = 0;
149 let mut path = Proof::default();
150 let mut node_index = index;
151 let mut lsib = None;
152 let mut rsib = None;
153 while level_len > 0 {
154 let level = &self.nodes[level_start..(level_start + level_len)];
155
156 let target = &level[node_index];
157 if lsib.is_some() || rsib.is_some() {
158 path.push(ProofEntry::new(target, lsib, rsib));
159 }
160 if node_index % 2 == 0 {
161 lsib = None;
162 rsib = if node_index + 1 < level.len() {
163 Some(&level[node_index + 1])
164 } else {
165 Some(&level[node_index])
166 };
167 } else {
168 lsib = Some(&level[node_index - 1]);
169 rsib = None;
170 }
171 node_index /= 2;
172
173 level_start += level_len;
174 level_len = MerkleTree::next_level_len(level_len);
175 }
176 Some(path)
177 }
178}
179
#[cfg(test)]
mod tests {
    use {super::*, solana_hash::HASH_BYTES};

    /// Eleven leaf payloads. The odd count means some levels have an unpaired
    /// last node.
    const TEST: &[&[u8]] = &[
        b"my", b"very", b"eager", b"mother", b"just", b"served", b"us", b"nine", b"pizzas",
        b"make", b"prime",
    ];
    /// Payloads that are not in `TEST`; their proofs must not verify.
    const BAD: &[&[u8]] = &[b"bad", b"missing", b"false"];

    #[test]
    fn test_tree_from_empty() {
        let tree = MerkleTree::new::<[u8; 0]>(&[]);
        assert!(tree.get_root().is_none());
    }

    #[test]
    fn test_tree_from_one() {
        // With a single leaf, the root is the leaf hash itself.
        let input = b"test";
        let tree = MerkleTree::new(&[input]);
        let leaf = hash_leaf!(input);
        assert_eq!(tree.get_root(), Some(&leaf));
    }

    #[test]
    fn test_tree_from_many() {
        // Pinned root for `TEST`; catches any accidental change to the hashing scheme.
        let tree = MerkleTree::new(TEST);
        let root_bytes: [u8; HASH_BYTES] =
            hex::decode("b40c847546fdceea166f927fc46c5ca33c3638236a36275c1346d3dffb84e1bc")
                .unwrap()
                .try_into()
                .unwrap();
        assert_eq!(tree.get_root(), Some(&Hash::new_from_array(root_bytes)));
    }

    #[test]
    fn test_path_creation() {
        let tree = MerkleTree::new(TEST);
        for index in 0..TEST.len() {
            assert!(tree.find_path(index).is_some());
        }
    }

    #[test]
    fn test_path_creation_bad_index() {
        let tree = MerkleTree::new(TEST);
        assert!(tree.find_path(TEST.len()).is_none());
    }

    #[test]
    fn test_path_verify_good() {
        let tree = MerkleTree::new(TEST);
        for (index, item) in TEST.iter().enumerate() {
            let leaf = hash_leaf!(item);
            let proof = tree.find_path(index).unwrap();
            assert!(proof.verify(leaf));
        }
    }

    #[test]
    fn test_path_verify_bad() {
        let tree = MerkleTree::new(TEST);
        for (index, item) in BAD.iter().enumerate() {
            let leaf = hash_leaf!(item);
            let proof = tree.find_path(index).unwrap();
            assert!(!proof.verify(leaf));
        }
    }

    #[test]
    fn test_proof_entry_instantiation_lsib_set() {
        ProofEntry::new(&Hash::default(), Some(&Hash::default()), None);
    }

    #[test]
    fn test_proof_entry_instantiation_rsib_set() {
        ProofEntry::new(&Hash::default(), None, Some(&Hash::default()));
    }

    #[test]
    fn test_nodes_capacity_compute() {
        // Counts nodes level by level, the same way `MerkleTree::new` builds them.
        let nodes_needed = |leaf_count: usize| -> usize {
            let mut total = 0;
            let mut level_len = leaf_count;
            while level_len > 0 {
                total += level_len;
                level_len = MerkleTree::next_level_len(level_len);
            }
            total
        };

        for leaf_count in 0..65536 {
            let reserved = MerkleTree::calculate_vec_capacity(leaf_count);
            assert!(reserved >= nodes_needed(leaf_count));
        }
    }

    #[test]
    #[should_panic]
    fn test_proof_entry_instantiation_both_clear() {
        ProofEntry::new(&Hash::default(), None, None);
    }

    #[test]
    #[should_panic]
    fn test_proof_entry_instantiation_both_set() {
        ProofEntry::new(
            &Hash::default(),
            Some(&Hash::default()),
            Some(&Hash::default()),
        );
    }
}