// simple_merkle_tree/lib.rs

//! Simple Merkle Tree implementation
//!
//! Meant to be used hand-in-hand with Solidity, which is why it uses `keccak256` for hashing.
//! Internal nodes are hashed as `keccak256(left || right)` in positional order (not sorted-pair
//! order), so an on-chain verifier has to combine hashes using the same convention.
//!
//! # Example
//!
//! ```rust
//! use simple_merkle_tree::MerkleTree;
//! let elements = (0..4)
//!     .map(|el| format!("item-string-{:}", el).into_bytes())
//!     .collect::<Vec<Vec<u8>>>();
//!
//! let tree = MerkleTree::new(elements.clone());
//! let a = &elements[0];
//! let b = &elements[1];
//! let c = &elements[2];
//! let d = &elements[3];
//!
//! let h_a = MerkleTree::hash(a);
//! let h_b = MerkleTree::hash(b);
//! let h_c = MerkleTree::hash(c); // Part of the proof for `d` (sibling of `h_d`)
//! let h_d = MerkleTree::hash(d); // The leaf being proven
//! let h_ab = MerkleTree::combined_hash(&h_a, &h_b); // Part of the proof for `d` (sibling of `h_cd`)
//! let h_cd = MerkleTree::combined_hash(&h_c, &h_d);
//! let h_abcd = MerkleTree::combined_hash(&h_ab, &h_cd);
//!
//! let proof = tree.get_proof(d).unwrap();
//! assert_eq!(proof.len(), 2);
//!
//! assert_eq!(
//!     vec![hex::encode(h_c), hex::encode(h_ab),],
//!     proof
//!         .iter()
//!         .map(|e| hex::encode(e))
//!         .collect::<Vec<String>>()
//! );
//!
//! let root = tree.get_root();
//! assert_eq!(hex::encode(h_abcd), hex::encode(root));
//! ```

use std::fmt::Debug;

pub use tiny_keccak::Hasher;

/// Raw element bytes accepted by [`MerkleTree::new`] and [`MerkleTree::get_proof`].
pub type Buffer = Vec<u8>;
/// A 32-byte `keccak256` digest.
type Hash = [u8; 32];

/// Merkle tree over byte buffers, hashed with `keccak256`.
///
/// Every node is stored in one flat vector. The root is at index 0, the children of node `i`
/// are at `2 * i + 1` and `2 * i + 2`, and the leaves occupy the trailing slots.
pub struct MerkleTree {
    hashed_elements: Vec<Hash>,
}

53// Public interface impl
54impl MerkleTree {
55    pub fn new(elements: Vec<Buffer>) -> Self {
56        let elements = {
57            let mut elements: Vec<Buffer> = elements
58                .into_iter()
59                // Filter empty
60                .filter(|e| !e.iter().all(|e| *e == 0))
61                .collect();
62
63            // Sort
64            elements.sort();
65
66            // Deduplicate
67            let el_len = elements.len();
68            let elements = elements
69                .into_iter()
70                .fold(Vec::with_capacity(el_len), |mut acc, i| {
71                    if !acc.contains(&i) {
72                        acc.push(i);
73                    }
74                    acc
75                });
76            elements
77        };
78
79        // Construct hashes
80        let el_len = elements.len();
81        let (capacity, levels) = MerkleTree::calculate_levels(&el_len);
82
83        let vector_size = 2 * el_len - 1;
84        let mut result = vec![[0; 32]; vector_size];
85        log::debug!("Creating a vector with size {:}", vector_size);
86
87        let mut prior_elements = 0;
88        for level in 1..=levels {
89            let elem_count_in_level = el_len / level as usize;
90            let start_index = capacity - prior_elements - elem_count_in_level;
91
92            let end_index = start_index + elem_count_in_level; // non inclusive
93            prior_elements += elem_count_in_level;
94            log::trace!(
95                "start_index: {}| end_index {}| elem_count_in_level {}",
96                start_index,
97                end_index,
98                elem_count_in_level
99            );
100
101            if level == 1 {
102                for (idx, elem) in elements.iter().enumerate() {
103                    let hashed = MerkleTree::hash(&elem);
104                    log::trace!(
105                        "Setting idx {:} to {:}",
106                        start_index + idx,
107                        hex::encode(hashed)
108                    );
109                    result[start_index + idx] = hashed;
110                }
111            } else {
112                for idx in start_index..end_index {
113                    let left = (2_usize * idx) + 1;
114                    let right = (2_usize * idx) + 2;
115
116                    log::trace!("Getting child of {}| L: {}| R: {}", idx, left, right);
117                    let left = result[left];
118                    let right = result[right];
119                    let parent = MerkleTree::combined_hash(&left, &right);
120                    // log::trace!("Setting idx {:} to {:}", start_index + idx, hex::encode(parent));
121                    result[idx] = parent;
122                }
123            }
124        }
125
126        let res = Self {
127            hashed_elements: result,
128        };
129        log::debug!("Constructed merkle tree {:#?}", &res);
130        res
131    }
132
133
134    pub fn get_root(&self) -> &[u8; 32] {
135        &self.hashed_elements[0]
136    }
137
138    pub fn get_proof(&self, el: &Buffer) -> Option<Vec<&[u8; 32]>> {
139        let hashed = MerkleTree::hash(&el);
140        log::debug!("Finding proof for {:}", hex::encode(hashed));
141
142        let index = self.hashed_elements.iter().position(|e| e == &hashed);
143
144        match index {
145            Some(mut index) => {
146                let mut res = vec![];
147
148                while index > 0 {
149                    // Skip the root element
150                    let sibling = self.get_pair_element(index);
151
152                    if let Some(sibling) = sibling {
153                        log::trace!(
154                            "getting pair elem for index {:}; res {:}",
155                            index,
156                            hex::encode(sibling)
157                        );
158                        res.push(sibling);
159                    }
160
161                    index = MerkleTree::calculate_parent_idx(index);
162                    log::trace!("Parent {:}", index);
163                }
164                Some(res)
165            }
166            None => None,
167        }
168    }
169
170}
171
172// Different helpers
173impl MerkleTree {
174
175    /// Create a hash of two byte array slices
176    ///```rust
177    /// use simple_merkle_tree::MerkleTree;
178    /// let elements = (0..3)
179    ///     .map(|el| format!("item-string-{:}", el).into_bytes())
180    ///     .collect::<Vec<Vec<u8>>>();
181    ///
182    /// let a = &elements[0];
183    /// let b = &elements[1];
184    /// let h_a = MerkleTree::hash(a);
185    /// let h_b = MerkleTree::hash(b);
186    /// let h_ab = MerkleTree::combined_hash(&h_a, &h_b);
187    ///```
188    pub fn combined_hash(first: &[u8], second: &[u8]) -> [u8; 32] {
189        let mut keccak = tiny_keccak::Keccak::v256();
190        keccak.update(first);
191        keccak.update(second);
192        let mut result: [u8; 32] = Default::default();
193        keccak.finalize(&mut result);
194        result
195    }
196
197    /// Create a hash of a single byte array slice
198    ///```rust
199    /// use simple_merkle_tree::MerkleTree;
200    /// let elements = (0..3)
201    ///     .map(|el| format!("item-string-{:}", el).into_bytes())
202    ///     .collect::<Vec<Vec<u8>>>();
203    ///
204    /// let a = &elements[0];
205    ///
206    /// let h_a = MerkleTree::hash(a);
207    ///```
208    pub fn hash(data: &[u8]) -> [u8; 32] {
209        let mut keccak = tiny_keccak::Keccak::v256();
210        keccak.update(&data);
211        let mut result: [u8; 32] = Default::default();
212        keccak.finalize(&mut result);
213        result
214    }
215
216    fn get_pair_element(&self, idx: usize) -> Option<&[u8; 32]> {
217        let pair_idx = MerkleTree::calculate_sibling_idx(idx);
218
219        if pair_idx < self.hashed_elements.len() {
220            return Some(&self.hashed_elements[pair_idx]);
221        }
222        return None;
223    }
224
225    fn calculate_sibling_idx(idx: usize) -> usize {
226        if idx % 2 == 0 {
227            idx - 1
228        } else {
229            idx + 1
230        }
231    }
232
233    fn calculate_parent_idx(child_idx: usize) -> usize {
234        let child_offset = {
235            if child_idx % 2 == 0 {
236                // If is right child
237                2
238            } else {
239                // If is left child
240                1
241            }
242        };
243
244        (child_idx - child_offset) / 2
245    }
246
247
248    fn calculate_levels(el_len: &usize) -> (usize, u32) {
249        let capacity = 2 * el_len - 1;
250        let levels: u32 = ((capacity as f32).log2() + 1.) as u32;
251        (capacity, levels)
252    }
253}
254
255impl Debug for MerkleTree {
256    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
257        let hashed_elements: Vec<String> = self
258            .hashed_elements
259            .iter()
260            .map(|e| hex::encode(e))
261            .collect();
262
263        f.debug_struct("MerkleTree")
264            .field("hashed_elements", &hashed_elements)
265            .finish()
266    }
267}
268
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `items` distinct sample buffers: `item-string-0`, `item-string-1`, ...
    fn generate_sample_vec(items: u32) -> Vec<Vec<u8>> {
        (0..items)
            .map(|idx| format!("item-string-{:}", idx).into_bytes())
            .collect()
    }

    /// Hex-encodes each node of a proof so assertion failures print readably.
    fn proof_as_hex(proof: &[&[u8; 32]]) -> Vec<String> {
        proof.iter().map(|node| hex::encode(node)).collect()
    }

    #[test]
    fn construct_tree() {
        simple_logger::init_with_level(log::Level::Trace).unwrap();
        let elements = generate_sample_vec(4);
        let tree = MerkleTree::new(elements.clone());

        let (a, b, c, d) = (&elements[0], &elements[1], &elements[2], &elements[3]);

        // Expected leaves, inner nodes and root, computed by hand.
        let h_a = MerkleTree::hash(a);
        let h_b = MerkleTree::hash(b);
        let h_c = MerkleTree::hash(c);
        let h_d = MerkleTree::hash(d);
        let h_ab = MerkleTree::combined_hash(&h_a, &h_b);
        let h_cd = MerkleTree::combined_hash(&h_c, &h_d);
        let h_abcd = MerkleTree::combined_hash(&h_ab, &h_cd);

        log::debug!("h_abcd = {:}", hex::encode(h_abcd));

        log::debug!("h_ab = {:}", hex::encode(h_ab));
        log::debug!("h_cd = {:}", hex::encode(h_cd));

        log::debug!("h_a = {:}", hex::encode(h_a));
        log::debug!("h_b = {:}", hex::encode(h_b));
        log::debug!("h_c = {:}", hex::encode(h_c));
        log::debug!("h_d = {:}", hex::encode(h_d));

        // `d` is the right-most leaf: its path needs `c`, then the `ab` subtree.
        let proof_d = tree.get_proof(d).unwrap();
        assert_eq!(proof_d.len(), 2);
        assert_eq!(
            vec![hex::encode(h_c), hex::encode(h_ab)],
            proof_as_hex(&proof_d)
        );

        // `a` is the left-most leaf: its path needs `b`, then the `cd` subtree.
        let proof_a = tree.get_proof(a).unwrap();
        assert_eq!(proof_a.len(), 2);
        assert_eq!(
            vec![hex::encode(h_b), hex::encode(h_cd)],
            proof_as_hex(&proof_a)
        );

        assert_eq!(hex::encode(h_abcd), hex::encode(tree.get_root()));
    }

    #[test]
    fn levels_get_calculated() {
        let leaves = generate_sample_vec(4).len();
        assert_eq!(MerkleTree::calculate_levels(&leaves), (7, 3));
    }

    #[test]
    fn levels_get_calculated_v2() {
        let leaves = generate_sample_vec(3).len();
        assert_eq!(MerkleTree::calculate_levels(&leaves), (5, 3));
    }

    #[test]
    fn levels_get_calculated_v3() {
        let leaves = generate_sample_vec(2).len();
        assert_eq!(MerkleTree::calculate_levels(&leaves), (3, 2));
    }
}
361}