solana_merkle_tree/
merkle_tree.rs1use solana_program::hash::{hashv, Hash};
2
/// One-byte tag prepended to a leaf's data before hashing (used by `hash_leaf!`).
///
/// Leaves and interior nodes are hashed over inputs that begin with different
/// tag bytes, so a leaf hash cannot be presented as an interior-node hash (or
/// vice versa). This is the usual defence against second-preimage attacks on
/// Merkle trees.
const LEAF_PREFIX: &[u8] = b"\x00";
/// One-byte tag prepended to an interior node's two child hashes before hashing
/// (used by `hash_intermediate!`).
const INTERMEDIATE_PREFIX: &[u8] = b"\x01";
8
/// Hashes one leaf: `hashv(LEAF_PREFIX || data)`.
///
/// `$data` must be an identifier whose value coerces to `&[u8]`.
macro_rules! hash_leaf {
    ($data:ident) => {
        hashv(&[LEAF_PREFIX, $data])
    };
}
14
/// Hashes an interior node from its children:
/// `hashv(INTERMEDIATE_PREFIX || left || right)`.
///
/// Both arguments must be identifiers whose values implement `AsRef<[u8]>`.
/// Order matters: the left child is always hashed first.
macro_rules! hash_intermediate {
    ($left:ident, $right:ident) => {
        hashv(&[INTERMEDIATE_PREFIX, $left.as_ref(), $right.as_ref()])
    };
}
20
/// A binary Merkle tree stored as one flat vector of hashes.
///
/// `nodes` holds every level back to back. It starts with the `leaf_count`
/// leaf hashes, continues with each successively smaller level, and ends with
/// the root as its last element (see `MerkleTree::new` and `get_root`).
#[derive(Debug)]
pub struct MerkleTree {
    /// Number of leaves, i.e. the number of items the tree was built from.
    leaf_count: usize,
    /// All node hashes, level by level from the leaves up to the root.
    nodes: Vec<Hash>,
}
26
/// One step of an inclusion proof, laid out as `(target, left_sibling, right_sibling)`.
///
/// `target` is the hash expected at this level of the tree. Exactly one sibling
/// is meant to be `Some` (asserted by `ProofEntry::new`). During verification,
/// the hash carried up from the level below fills the side whose sibling is `None`.
#[derive(Debug, PartialEq, Eq)]
pub struct ProofEntry<'a>(&'a Hash, Option<&'a Hash>, Option<&'a Hash>);

impl<'a> ProofEntry<'a> {
    /// Creates a proof step for `target` with the given sibling hash.
    ///
    /// # Panics
    ///
    /// Panics unless exactly one of `left_sibling` and `right_sibling` is `Some`.
    pub fn new(
        target: &'a Hash,
        left_sibling: Option<&'a Hash>,
        right_sibling: Option<&'a Hash>,
    ) -> Self {
        // XOR of the two `is_none()` flags is true iff exactly one sibling is set.
        assert!(left_sibling.is_none() ^ right_sibling.is_none());
        Self(target, left_sibling, right_sibling)
    }
}
40
41#[derive(Debug, Default, PartialEq, Eq)]
42pub struct Proof<'a>(Vec<ProofEntry<'a>>);
43
44impl<'a> Proof<'a> {
45 pub fn push(&mut self, entry: ProofEntry<'a>) {
46 self.0.push(entry)
47 }
48
49 pub fn verify(&self, candidate: Hash) -> bool {
50 let result = self.0.iter().try_fold(candidate, |candidate, pe| {
51 let lsib = pe.1.unwrap_or(&candidate);
52 let rsib = pe.2.unwrap_or(&candidate);
53 let hash = hash_intermediate!(lsib, rsib);
54
55 if hash == *pe.0 {
56 Some(hash)
57 } else {
58 None
59 }
60 });
61 matches!(result, Some(_))
62 }
63}
64
65impl MerkleTree {
66 #[inline]
67 fn next_level_len(level_len: usize) -> usize {
68 if level_len == 1 {
69 0
70 } else {
71 (level_len + 1) / 2
72 }
73 }
74
75 fn calculate_vec_capacity(leaf_count: usize) -> usize {
76 if leaf_count > 0 {
92 fast_math::log2_raw(leaf_count as f32) as usize + 2 * leaf_count + 1
93 } else {
94 0
95 }
96 }
97
98 pub fn new<T: AsRef<[u8]>>(items: &[T]) -> Self {
99 let cap = MerkleTree::calculate_vec_capacity(items.len());
100 let mut mt = MerkleTree {
101 leaf_count: items.len(),
102 nodes: Vec::with_capacity(cap),
103 };
104
105 for item in items {
106 let item = item.as_ref();
107 let hash = hash_leaf!(item);
108 mt.nodes.push(hash);
109 }
110
111 let mut level_len = MerkleTree::next_level_len(items.len());
112 let mut level_start = items.len();
113 let mut prev_level_len = items.len();
114 let mut prev_level_start = 0;
115 while level_len > 0 {
116 for i in 0..level_len {
117 let prev_level_idx = 2 * i;
118 let lsib = &mt.nodes[prev_level_start + prev_level_idx];
119 let rsib = if prev_level_idx + 1 < prev_level_len {
120 &mt.nodes[prev_level_start + prev_level_idx + 1]
121 } else {
122 &mt.nodes[prev_level_start + prev_level_idx]
124 };
125
126 let hash = hash_intermediate!(lsib, rsib);
127 mt.nodes.push(hash);
128 }
129 prev_level_start = level_start;
130 prev_level_len = level_len;
131 level_start += level_len;
132 level_len = MerkleTree::next_level_len(level_len);
133 }
134
135 mt
136 }
137
138 pub fn get_root(&self) -> Option<&Hash> {
139 self.nodes.iter().last()
140 }
141
142 pub fn find_path(&self, index: usize) -> Option<Proof> {
143 if index >= self.leaf_count {
144 return None;
145 }
146
147 let mut level_len = self.leaf_count;
148 let mut level_start = 0;
149 let mut path = Proof::default();
150 let mut node_index = index;
151 let mut lsib = None;
152 let mut rsib = None;
153 while level_len > 0 {
154 let level = &self.nodes[level_start..(level_start + level_len)];
155
156 let target = &level[node_index];
157 if lsib.is_some() || rsib.is_some() {
158 path.push(ProofEntry::new(target, lsib, rsib));
159 }
160 if node_index % 2 == 0 {
161 lsib = None;
162 rsib = if node_index + 1 < level.len() {
163 Some(&level[node_index + 1])
164 } else {
165 Some(&level[node_index])
166 };
167 } else {
168 lsib = Some(&level[node_index - 1]);
169 rsib = None;
170 }
171 node_index /= 2;
172
173 level_start += level_len;
174 level_len = MerkleTree::next_level_len(level_len);
175 }
176 Some(path)
177 }
178}
179
#[cfg(test)]
mod tests {
    use super::*;

    /// An odd number of items (11), so the self-pairing of a trailing odd node
    /// is exercised.
    const TEST: &[&[u8]] = &[
        b"my", b"very", b"eager", b"mother", b"just", b"served", b"us", b"nine", b"pizzas",
        b"make", b"prime",
    ];
    /// Items that are not in `TEST`; their leaf hashes must fail verification.
    const BAD: &[&[u8]] = &[b"bad", b"missing", b"false"];

    #[test]
    fn test_tree_from_empty() {
        let tree = MerkleTree::new::<[u8; 0]>(&[]);
        assert_eq!(tree.get_root(), None);
    }

    #[test]
    fn test_tree_from_one() {
        let input = b"test";
        let tree = MerkleTree::new(&[input]);
        // With a single leaf, the root is the leaf hash itself.
        assert_eq!(tree.get_root(), Some(&hash_leaf!(input)));
    }

    #[test]
    fn test_tree_from_many() {
        let tree = MerkleTree::new(TEST);
        // Pinned root for `TEST`; any change to the hashing scheme breaks this.
        let root_bytes =
            hex::decode("b40c847546fdceea166f927fc46c5ca33c3638236a36275c1346d3dffb84e1bc")
                .unwrap();
        assert_eq!(tree.get_root(), Some(&Hash::new(&root_bytes)));
    }

    #[test]
    fn test_path_creation() {
        let tree = MerkleTree::new(TEST);
        for index in 0..TEST.len() {
            tree.find_path(index).unwrap();
        }
    }

    #[test]
    fn test_path_creation_bad_index() {
        let tree = MerkleTree::new(TEST);
        let one_past_end = TEST.len();
        assert_eq!(tree.find_path(one_past_end), None);
    }

    #[test]
    fn test_path_verify_good() {
        let tree = MerkleTree::new(TEST);
        TEST.iter().enumerate().for_each(|(i, s)| {
            let hash = hash_leaf!(s);
            let path = tree.find_path(i).unwrap();
            assert!(path.verify(hash));
        });
    }

    #[test]
    fn test_path_verify_bad() {
        let tree = MerkleTree::new(TEST);
        // Each bad item is checked against the proof of a genuine leaf.
        BAD.iter().enumerate().for_each(|(i, s)| {
            let hash = hash_leaf!(s);
            let path = tree.find_path(i).unwrap();
            assert!(!path.verify(hash));
        });
    }

    #[test]
    fn test_proof_entry_instantiation_lsib_set() {
        ProofEntry::new(&Hash::default(), Some(&Hash::default()), None);
    }

    #[test]
    fn test_proof_entry_instantiation_rsib_set() {
        ProofEntry::new(&Hash::default(), None, Some(&Hash::default()));
    }

    #[test]
    fn test_nodes_capacity_compute() {
        // Exact node count: sum the level sizes until the (empty) level above the root.
        let iteration_count = |leaf_count: usize| -> usize {
            std::iter::successors(Some(leaf_count), |&len| {
                Some(MerkleTree::next_level_len(len))
            })
            .take_while(|&len| len > 0)
            .sum()
        };

        for leaf_count in 0..65536 {
            let math_count = MerkleTree::calculate_vec_capacity(leaf_count);
            let iter_count = iteration_count(leaf_count);
            assert!(math_count >= iter_count);
        }
    }

    #[test]
    #[should_panic]
    fn test_proof_entry_instantiation_both_clear() {
        ProofEntry::new(&Hash::default(), None, None);
    }

    #[test]
    #[should_panic]
    fn test_proof_entry_instantiation_both_set() {
        ProofEntry::new(
            &Hash::default(),
            Some(&Hash::default()),
            Some(&Hash::default()),
        );
    }
}