solana_merkle_tree/
merkle_tree.rs

1use {solana_hash::Hash, solana_sha256_hasher::hashv};
2
// Leaves and intermediate nodes are hashed under different one-byte prefixes so the
// two kinds of node are never hashed over indistinguishable inputs; without this
// separation the tree is open to trivial second pre-image attacks.
// https://flawed.net.nz/2018/02/21/attacking-merkle-trees-with-a-second-preimage-attack
/// Prefix placed in front of a leaf's bytes before hashing.
const LEAF_PREFIX: &[u8] = b"\x00";
/// Prefix placed in front of a pair of child hashes before hashing.
const INTERMEDIATE_PREFIX: &[u8] = b"\x01";
8
// Hashes a single leaf: `LEAF_PREFIX` followed by the leaf's bytes.
// The argument must be an identifier whose value coerces to `&[u8]`, since it is
// placed directly into the slice handed to `hashv`.
macro_rules! hash_leaf {
    ($data:ident) => {
        hashv(&[LEAF_PREFIX, $data])
    };
}
14
// Hashes a parent node: `INTERMEDIATE_PREFIX`, then the left child's bytes, then the
// right child's bytes. Both arguments are identifiers converted with `as_ref()`.
macro_rules! hash_intermediate {
    ($left:ident, $right:ident) => {
        hashv(&[INTERMEDIATE_PREFIX, $left.as_ref(), $right.as_ref()])
    };
}
20
21#[derive(Debug)]
22pub struct MerkleTree {
23    leaf_count: usize,
24    nodes: Vec<Hash>,
25}
26
27#[derive(Debug, PartialEq, Eq)]
28pub struct ProofEntry<'a>(&'a Hash, Option<&'a Hash>, Option<&'a Hash>);
29
30impl<'a> ProofEntry<'a> {
31    pub fn new(
32        target: &'a Hash,
33        left_sibling: Option<&'a Hash>,
34        right_sibling: Option<&'a Hash>,
35    ) -> Self {
36        assert!(left_sibling.is_none() ^ right_sibling.is_none());
37        Self(target, left_sibling, right_sibling)
38    }
39}
40
41#[derive(Debug, Default, PartialEq, Eq)]
42pub struct Proof<'a>(Vec<ProofEntry<'a>>);
43
44impl<'a> Proof<'a> {
45    pub fn push(&mut self, entry: ProofEntry<'a>) {
46        self.0.push(entry)
47    }
48
49    pub fn verify(&self, candidate: Hash) -> bool {
50        let result = self.0.iter().try_fold(candidate, |candidate, pe| {
51            let lsib = pe.1.unwrap_or(&candidate);
52            let rsib = pe.2.unwrap_or(&candidate);
53            let hash = hash_intermediate!(lsib, rsib);
54
55            if hash == *pe.0 {
56                Some(hash)
57            } else {
58                None
59            }
60        });
61        result.is_some()
62    }
63}
64
65impl MerkleTree {
66    #[inline]
67    fn next_level_len(level_len: usize) -> usize {
68        if level_len == 1 {
69            0
70        } else {
71            level_len.div_ceil(2)
72        }
73    }
74
75    fn calculate_vec_capacity(leaf_count: usize) -> usize {
76        // the most nodes consuming case is when n-1 is full balanced binary tree
77        // then n will cause the previous tree add a left only path to the root
78        // this cause the total nodes number increased by tree height, we use this
79        // condition as the max nodes consuming case.
80        // n is current leaf nodes number
81        // assuming n-1 is a full balanced binary tree, n-1 tree nodes number will be
82        // 2(n-1) - 1, n tree height is closed to log2(n) + 1
83        // so the max nodes number is 2(n-1) - 1 + log2(n) + 1, finally we can use
84        // 2n + log2(n+1) as a safe capacity value.
85        // test results:
86        // 8192 leaf nodes(full balanced):
87        // computed cap is 16398, actually using is 16383
88        // 8193 leaf nodes:(full balanced plus 1 leaf):
89        // computed cap is 16400, actually using is 16398
90        // about performance: current used fast_math log2 code is constant algo time
91        if leaf_count > 0 {
92            fast_math::log2_raw(leaf_count as f32) as usize + 2 * leaf_count + 1
93        } else {
94            0
95        }
96    }
97
98    pub fn new<T: AsRef<[u8]>>(items: &[T]) -> Self {
99        let cap = MerkleTree::calculate_vec_capacity(items.len());
100        let mut mt = MerkleTree {
101            leaf_count: items.len(),
102            nodes: Vec::with_capacity(cap),
103        };
104
105        for item in items {
106            let item = item.as_ref();
107            let hash = hash_leaf!(item);
108            mt.nodes.push(hash);
109        }
110
111        let mut level_len = MerkleTree::next_level_len(items.len());
112        let mut level_start = items.len();
113        let mut prev_level_len = items.len();
114        let mut prev_level_start = 0;
115        while level_len > 0 {
116            for i in 0..level_len {
117                let prev_level_idx = 2 * i;
118                let lsib = &mt.nodes[prev_level_start + prev_level_idx];
119                let rsib = if prev_level_idx + 1 < prev_level_len {
120                    &mt.nodes[prev_level_start + prev_level_idx + 1]
121                } else {
122                    // Duplicate last entry if the level length is odd
123                    &mt.nodes[prev_level_start + prev_level_idx]
124                };
125
126                let hash = hash_intermediate!(lsib, rsib);
127                mt.nodes.push(hash);
128            }
129            prev_level_start = level_start;
130            prev_level_len = level_len;
131            level_start += level_len;
132            level_len = MerkleTree::next_level_len(level_len);
133        }
134
135        mt
136    }
137
138    pub fn get_root(&self) -> Option<&Hash> {
139        self.nodes.iter().last()
140    }
141
142    pub fn find_path(&self, index: usize) -> Option<Proof> {
143        if index >= self.leaf_count {
144            return None;
145        }
146
147        let mut level_len = self.leaf_count;
148        let mut level_start = 0;
149        let mut path = Proof::default();
150        let mut node_index = index;
151        let mut lsib = None;
152        let mut rsib = None;
153        while level_len > 0 {
154            let level = &self.nodes[level_start..(level_start + level_len)];
155
156            let target = &level[node_index];
157            if lsib.is_some() || rsib.is_some() {
158                path.push(ProofEntry::new(target, lsib, rsib));
159            }
160            if node_index % 2 == 0 {
161                lsib = None;
162                rsib = if node_index + 1 < level.len() {
163                    Some(&level[node_index + 1])
164                } else {
165                    Some(&level[node_index])
166                };
167            } else {
168                lsib = Some(&level[node_index - 1]);
169                rsib = None;
170            }
171            node_index /= 2;
172
173            level_start += level_len;
174            level_len = MerkleTree::next_level_len(level_len);
175        }
176        Some(path)
177    }
178}
179
#[cfg(test)]
mod tests {
    use {super::*, solana_hash::HASH_BYTES};

    const TEST: &[&[u8]] = &[
        b"my", b"very", b"eager", b"mother", b"just", b"served", b"us", b"nine", b"pizzas",
        b"make", b"prime",
    ];
    const BAD: &[&[u8]] = &[b"bad", b"missing", b"false"];

    #[test]
    fn test_tree_from_empty() {
        let mt = MerkleTree::new::<[u8; 0]>(&[]);
        assert_eq!(mt.get_root(), None);
    }

    #[test]
    fn test_tree_from_one() {
        let input = b"test";
        let mt = MerkleTree::new(&[input]);
        let expected = hash_leaf!(input);
        assert_eq!(mt.get_root(), Some(&expected));
    }

    #[test]
    fn test_tree_from_two() {
        let left: &[u8] = b"left";
        let right: &[u8] = b"right";
        let mt = MerkleTree::new(&[left, right]);
        let left_leaf = hash_leaf!(left);
        let right_leaf = hash_leaf!(right);
        let expected = hash_intermediate!(left_leaf, right_leaf);
        assert_eq!(mt.get_root(), Some(&expected));
        assert!(mt.find_path(0).unwrap().verify(left_leaf));
        assert!(mt.find_path(1).unwrap().verify(right_leaf));
    }

    #[test]
    fn test_tree_from_many() {
        let mt = MerkleTree::new(TEST);
        // This golden hash must be updated whenever the contents of `TEST` change in any
        // way (addition, removal, or reordering), or whenever the tree calculation
        // algorithm changes.
        let bytes = hex::decode("b40c847546fdceea166f927fc46c5ca33c3638236a36275c1346d3dffb84e1bc")
            .unwrap();
        let expected = <[u8; HASH_BYTES]>::try_from(bytes)
            .map(Hash::new_from_array)
            .unwrap();
        assert_eq!(mt.get_root(), Some(&expected));
    }

    #[test]
    fn test_odd_level_duplicates_last_node() {
        // An unpaired last node is hashed with itself, so padding an odd level with a
        // copy of its last item does not change the root.
        let odd = MerkleTree::new(&[b"a", b"b", b"c"]);
        let padded = MerkleTree::new(&[b"a", b"b", b"c", b"c"]);
        assert_eq!(odd.get_root(), padded.get_root());
    }

    #[test]
    fn test_path_creation() {
        let mt = MerkleTree::new(TEST);
        for (i, _s) in TEST.iter().enumerate() {
            let _path = mt.find_path(i).unwrap();
        }
    }

    #[test]
    fn test_path_creation_bad_index() {
        let mt = MerkleTree::new(TEST);
        assert_eq!(mt.find_path(TEST.len()), None);
    }

    #[test]
    fn test_path_creation_empty_tree() {
        let mt = MerkleTree::new::<[u8; 0]>(&[]);
        assert_eq!(mt.find_path(0), None);
    }

    #[test]
    fn test_path_single_leaf_is_empty() {
        let input = b"test";
        let mt = MerkleTree::new(&[input]);
        let path = mt.find_path(0).unwrap();
        assert_eq!(path, Proof::default());
        assert!(path.verify(hash_leaf!(input)));
        assert_eq!(mt.find_path(1), None);
    }

    #[test]
    fn test_path_verify_good() {
        let mt = MerkleTree::new(TEST);
        for (i, s) in TEST.iter().enumerate() {
            let hash = hash_leaf!(s);
            let path = mt.find_path(i).unwrap();
            assert!(path.verify(hash));
        }
    }

    #[test]
    fn test_path_verify_bad() {
        let mt = MerkleTree::new(TEST);
        for (i, s) in BAD.iter().enumerate() {
            let hash = hash_leaf!(s);
            let path = mt.find_path(i).unwrap();
            assert!(!path.verify(hash));
        }
    }

    #[test]
    fn test_proof_entry_instantiation_lsib_set() {
        ProofEntry::new(&Hash::default(), Some(&Hash::default()), None);
    }

    #[test]
    fn test_proof_entry_instantiation_rsib_set() {
        ProofEntry::new(&Hash::default(), None, Some(&Hash::default()));
    }

    #[test]
    fn test_nodes_capacity_compute() {
        let iteration_count = |mut leaf_count: usize| -> usize {
            let mut capacity = 0;
            while leaf_count > 0 {
                capacity += leaf_count;
                leaf_count = MerkleTree::next_level_len(leaf_count);
            }
            capacity
        };

        // test max 64k leaf nodes compute
        for leaf_count in 0..65536 {
            let math_count = MerkleTree::calculate_vec_capacity(leaf_count);
            let iter_count = iteration_count(leaf_count);
            assert!(math_count >= iter_count);
        }
    }

    #[test]
    fn test_nodes_capacity_examples() {
        assert_eq!(MerkleTree::calculate_vec_capacity(0), 0);
        assert_eq!(MerkleTree::calculate_vec_capacity(8192), 16398);
        assert_eq!(MerkleTree::calculate_vec_capacity(8193), 16400);
    }

    #[test]
    #[should_panic]
    fn test_proof_entry_instantiation_both_clear() {
        ProofEntry::new(&Hash::default(), None, None);
    }

    #[test]
    #[should_panic]
    fn test_proof_entry_instantiation_both_set() {
        ProofEntry::new(
            &Hash::default(),
            Some(&Hash::default()),
            Some(&Hash::default()),
        );
    }
}