// static_merkel_tree/lib.rs

1// CITA
2// Copyright 2016-2019 Cryptape Technologies LLC.
3
4// This program is free software: you can redistribute it
5// and/or modify it under the terms of the GNU General Public
6// License as published by the Free Software Foundation,
7// either version 3 of the License, or (at your option) any
8// later version.
9
10// This program is distributed in the hope that it will be
11// useful, but WITHOUT ANY WARRANTY; without even the implied
12// warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
13// PURPOSE. See the GNU General Public License for more details.
14
15// You should have received a copy of the GNU General Public License
16// along with this program.  If not, see <http://www.gnu.org/licenses/>.
17
//! Generates complete Merkle tree roots.
//!
//! This module builds a complete binary Merkle tree from a list of leaf hashes, returns its root
//! hash, and produces and verifies per-leaf membership proofs.
21
22use std::cmp::PartialEq;
23use std::marker::PhantomData;
24
/// A complete binary Merkle tree stored as a flat array.
///
/// Node `i` has children `2 * i + 1` and `2 * i + 2`; the root is node 0 and the leaves occupy
/// the last `leaf_size` slots. `M` is the type of the merge function the tree was built with; it
/// is tracked only at the type level and no value of it is stored.
pub struct Tree<T, M> {
    nodes: Vec<T>,
    leaf_size: usize,
    phanton: PhantomData<M>,
}

// `Debug` and `Clone` are written by hand instead of derived: `#[derive]` would add `M: Debug` /
// `M: Clone` bounds, and the merge functions used as `M` (closures, fn items) are not `Debug`.
// No `M` value is stored, so no bound on `M` is needed.
impl<T: std::fmt::Debug, M> std::fmt::Debug for Tree<T, M> {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        f.debug_struct("Tree")
            .field("nodes", &self.nodes)
            .field("leaf_size", &self.leaf_size)
            .finish()
    }
}

impl<T: Clone, M> Clone for Tree<T, M> {
    fn clone(&self) -> Self {
        Tree {
            nodes: self.nodes.clone(),
            leaf_size: self.leaf_size,
            phanton: PhantomData,
        }
    }
}
31
/// One step of a Merkle proof: a sibling hash to combine with the running hash.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProofNode<T> {
    /// `true` if `hash` is the right-hand operand of the merge (the running hash goes on the
    /// left), `false` if it is the left-hand operand.
    pub is_right: bool,
    /// The sibling node's hash.
    pub hash: T,
}

/// A Merkle membership proof.
///
/// As produced by `Tree::get_proof_by_input_index`, the nodes are ordered from the leaf's
/// sibling upward; the root itself is not included.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Proof<T>(pub Vec<ProofNode<T>>);
40
41impl<T, M> Tree<T, M>
42where
43    T: Default + Clone + PartialEq,
44    M: Fn(&T, &T) -> T,
45{
46    // For example, there is 11 hashes [A..K] are used to generate a merkle tree:
47    //
48    //      F_G H_I J_K
49    //       |   |   |
50    //       7___8   9___A   B___C   D___E
51    //         |       |       |       |
52    //         3_______4       5_______6
53    //             |               |
54    //             1_______________2
55    //                     |
56    //                     0
57    //
58    // The algorithm is:
59    //
60    //      1. Create a vec:    [_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]
61    //      2. Insert A..K:     [_, _, _, _, _, _, _, _, _, _, A, B, C, D, E, F, G, H, I, J, K]
62    //      3. Update for 7..9: [_, _, _, _, _, _, _, 7, 8, 9, A, B, C, D, E, F, G, H, I, J, K]
63    //      4. Update for 3..6: [_, _, _, 3, 4, 5, 6, 7, 8, 9, A, B, C, D, E, F, G, H, I, J, K]
64    //      5. Update for 1..2: [_, 1, 2, 3, 4, 5, 6, 7, 8, 9, A, B, C, D, E, F, G, H, I, J, K]
65    //      6. Update for 0:    [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, A, B, C, D, E, F, G, H, I, J, K]
66    pub fn from_hashes(input: Vec<T>, merge: M) -> Self {
67        let leaf_size = input.len();
68
69        let nodes = match leaf_size {
70            0 => vec![],
71
72            // If only one input.
73            1 => input,
74
75            _ => {
76                let nodes_number = get_number_of_nodes(leaf_size);
77                let mut nodes = vec![T::default(); nodes_number];
78
79                let depth = get_depth(leaf_size);
80
81                let first_input_node_index = nodes_number - leaf_size;
82                let first_index_in_lowest_row = (1 << depth) - 1;
83                let nodes_number_of_lowest_row = nodes_number - first_index_in_lowest_row;
84
85                // Insert the input into the merkle tree.
86                nodes[first_input_node_index..nodes_number].clone_from_slice(&input);
87
88                let max_nodes_number_of_lowest_row = 1 << depth;
89                // Calc hash for the lowest row
90                if max_nodes_number_of_lowest_row == leaf_size {
91                    // The lowest row is full.
92                    calc_tree_at_row(&mut nodes, depth, 0, &merge)
93                } else {
94                    calc_tree_at_row(&mut nodes, depth, nodes_number_of_lowest_row >> 1, &merge);
95                }
96
97                // From the second row from the bottom to the second row from the top.
98                for i in 1..depth {
99                    let row_index = depth - i;
100                    calc_tree_at_row(&mut nodes, row_index, 0, &merge);
101                }
102
103                nodes
104            }
105        };
106
107        Self {
108            nodes,
109            leaf_size,
110            phanton: PhantomData,
111        }
112    }
113
114    pub fn get_root_hash(&self) -> Option<&T> {
115        self.nodes.get(0)
116    }
117
118    pub fn get_proof_by_input_index(&self, input_index: usize) -> Option<Proof<T>> {
119        get_proof_indexes(input_index, self.leaf_size).map(|indexes| {
120            Proof::<T>(
121                indexes
122                    .into_iter()
123                    .map(|i| ProofNode::<T> {
124                        is_right: (i & 1) == 0,
125                        hash: self.nodes[i].clone(),
126                    })
127                    .collect(),
128            )
129        })
130    }
131}
132
133impl<T> Proof<T>
134where
135    T: Default + Clone + PartialEq,
136{
137    pub fn verify<M>(&self, root: &T, data: T, merge: M) -> bool
138    where
139        M: Fn(&T, &T) -> T,
140    {
141        &self.0.iter().fold(data, |h, ref x| {
142            if x.is_right {
143                merge(&h, &x.hash)
144            } else {
145                merge(&x.hash, &h)
146            }
147        }) == root
148    }
149}
150
/// Computes the parent hashes of row `row_index` of the tree stored in `nodes`.
///
/// Rows are numbered from the root (row 0); row `r` starts at index `2^r - 1` and is `2^r` nodes
/// wide. Each node `j` in row `row_index - 1` is set to
/// `merge(&nodes[2 * j + 1], &nodes[2 * j + 2])`. If `break_cnt` is non-zero and smaller than the
/// width of row `row_index - 1`, only its first `break_cnt` nodes are computed (used for a
/// partially filled lowest row). Otherwise the whole row is computed.
///
/// # Panics
///
/// Panics if `row_index` is 0 (the root has no row above it), or if `nodes` is too short to hold
/// the children being read.
fn calc_tree_at_row<T, M>(nodes: &mut [T], row_index: usize, break_cnt: usize, merge: M)
where
    M: Fn(&T, &T) -> T,
{
    assert!(row_index > 0, "row 0 (the root) has no parent row to compute");
    let parent_row = row_index - 1;
    let row_start = (1 << parent_row) - 1;
    let row_width = 1 << parent_row;
    let count = if break_cnt == 0 {
        row_width
    } else {
        break_cnt.min(row_width)
    };
    for j in row_start..row_start + count {
        nodes[j] = merge(&nodes[2 * j + 1], &nodes[2 * j + 2]);
    }
}
170
/// Returns the number of levels below the root needed to hold `m` leaves, i.e. `ceil(log2(m))`.
///
/// Both `m == 0` and `m == 1` give 0.
#[inline]
fn get_depth(m: usize) -> usize {
    match m.checked_next_power_of_two() {
        // `p` is the smallest power of two >= m (1 for m == 0), so log2(p) = ceil(log2(m)).
        Some(p) => p.trailing_zeros() as usize,
        // m > 2^(BITS-1): the next power of two is 2^BITS, which does not fit in a usize.
        None => std::mem::size_of::<usize>() * 8,
    }
}
181
/// Returns how many nodes a complete Merkle tree over `m` leaves stores.
///
/// A tree with `m >= 1` leaves has `m - 1` internal nodes, so `2 * m - 1` in total. Counted by
/// rows for `m >= 2`, with `depth = get_depth(m)`, the rows above the lowest hold
/// `2^depth - 1` nodes and the lowest row holds `2 * (m - 2^(depth - 1))`. The two sum to the
/// same value. `m == 0` is reported as 1.
#[inline]
fn get_number_of_nodes(m: usize) -> usize {
    match m {
        0 => 1,
        leaves => 2 * leaves - 1,
    }
}
204
/// Returns `(sibling, parent)` for the node at 0-based `index` in the flat tree array.
///
/// In 1-based numbering, siblings differ only in the lowest bit and the parent is the index
/// shifted right by one. `index` must be non-zero: the root has neither sibling nor parent.
#[inline]
fn get_index_of_brother_and_father(index: usize) -> (usize, usize) {
    let one_based = index + 1;
    let sibling = (one_based ^ 1) - 1;
    let parent = (one_based >> 1) - 1;
    (sibling, parent)
}
215
216#[inline]
217fn get_proof_indexes(input_index: usize, leaf_size: usize) -> Option<Vec<usize>> {
218    if input_index == 0 && leaf_size < 2 {
219        Some(vec![])
220    } else if leaf_size != 0 && input_index < leaf_size {
221        let mut ret = Vec::new();
222        let nodes_number = get_number_of_nodes(leaf_size);
223        let mut index = nodes_number - leaf_size + input_index;
224        while index > 0 {
225            let (brother_index, parent_index) = get_index_of_brother_and_father(index);
226            ret.push(brother_index);
227            index = parent_index;
228        }
229        Some(ret)
230    } else {
231        None
232    }
233}
234
#[cfg(test)]
mod tests {
    #[derive(Default, Clone, PartialEq, Debug)]
    struct Node(Vec<u32>);

    /// Concatenating "hash": the merged value records the exact leaf order, so tests can check
    /// the shape of the tree directly.
    fn merge(left: &Node, right: &Node) -> Node {
        let mut root = Vec::with_capacity(left.0.len() + right.0.len());
        root.extend_from_slice(&left.0);
        root.extend_from_slice(&right.0);
        Node(root)
    }

    #[test]
    fn test_depth() {
        let check = vec![
            (0, 0),
            (1, 0),
            (2, 1),
            (3, 2),
            (4, 2),
            (5, 3),
            (8, 3),
            (9, 4),
            (16, 4),
            (17, 5),
        ];
        for (x, y) in check {
            assert_eq!(y, super::get_depth(x));
        }
    }

    #[test]
    fn test_number_of_nodes() {
        let check = vec![
            (0, 1),
            (1, 1),
            (2, 3),
            (3, 5),
            (4, 7),
            (5, 9),
            (8, 15),
            (9, 17),
            (16, 31),
            (20, 39),
        ];
        for (x, y) in check {
            assert_eq!(y, super::get_number_of_nodes(x));
        }
    }

    #[test]
    fn test_index_of_brother_and_father() {
        let check = vec![
            (1, (2, 0)),
            (2, (1, 0)),
            (11, (12, 5)),
            (12, (11, 5)),
            (21, (22, 10)),
            (22, (21, 10)),
            (31, (32, 15)),
            (32, (31, 15)),
        ];
        for (x, y) in check {
            assert_eq!(y, super::get_index_of_brother_and_father(x));
        }
    }

    #[test]
    fn test_proof_indexes() {
        let check = vec![
            ((1, 0), None),
            ((1, 1), None),
            ((2, 1), None),
            ((2, 2), None),
            ((0, 0), Some(vec![])),
            ((0, 1), Some(vec![])),
            ((0, 11), Some(vec![9, 3, 2])),
            ((10, 11), Some(vec![19, 10, 3, 2])),
            ((9, 11), Some(vec![20, 10, 3, 2])),
            ((8, 11), Some(vec![17, 7, 4, 2])),
        ];
        for ((x1, x2), y) in check {
            assert_eq!(y, super::get_proof_indexes(x1, x2));
        }
    }

    #[test]
    fn test_proof() {
        let inputs = vec![
            vec![Node(vec![1u32])],
            (1u32..26u32).map(|i| Node(vec![i])).collect(),
        ];
        for input in inputs {
            let tree = super::Tree::from_hashes(input.clone(), merge);
            let root_hash = tree.get_root_hash().unwrap();
            for (index, item) in input.into_iter().enumerate() {
                let proof = tree
                    .get_proof_by_input_index(index)
                    .expect("every in-range index has a proof");
                assert!(proof.verify(root_hash, item, merge));
            }
        }
    }

    #[test]
    fn test_empty_tree() {
        let tree = super::Tree::from_hashes(Vec::<Node>::new(), merge);
        assert!(tree.get_root_hash().is_none());
        assert!(tree.get_proof_by_input_index(1).is_none());
    }

    #[test]
    fn test_proof_out_of_range() {
        let leaves: Vec<Node> = (0u32..5).map(|i| Node(vec![i])).collect();
        let tree = super::Tree::from_hashes(leaves, merge);
        assert!(tree.get_proof_by_input_index(4).is_some());
        assert!(tree.get_proof_by_input_index(5).is_none());
    }

    #[test]
    fn test_proof_rejects_wrong_data() {
        let leaves: Vec<Node> = (0u32..5).map(|i| Node(vec![i])).collect();
        let tree = super::Tree::from_hashes(leaves, merge);
        let root = tree.get_root_hash().unwrap();
        let proof = tree.get_proof_by_input_index(2).unwrap();
        assert!(proof.verify(root, Node(vec![2]), merge));
        assert!(!proof.verify(root, Node(vec![99]), merge));
    }

    #[test]
    fn test_root() {
        assert_root(&(0u32..12u32).collect::<Vec<u32>>());
        assert_root(&(0u32..11u32).collect::<Vec<u32>>());
        assert_root(&[1u32, 5u32, 100u32, 4u32, 7u32, 9u32, 11u32]);
        assert_root(&(0u32..27u32).collect::<Vec<u32>>());
    }

    /// Builds a tree over `raw` and checks its root against the expected leaf order.
    ///
    /// With the concatenating `merge`, the root lists the leaves left to right as laid out in the
    /// tree. The leaves in the lowest row (the tail of `raw`) come first, followed by the leaves
    /// one row higher (the head of `raw`).
    fn assert_root(raw: &[u32]) {
        let leaves: Vec<Node> = raw.iter().map(|&i| Node(vec![i])).collect();
        let leaves_len = leaves.len();
        let tree = super::Tree::from_hashes(leaves, merge);
        let root = tree.get_root_hash().unwrap();
        let depth = super::get_depth(leaves_len);
        let nodes_number = super::get_number_of_nodes(leaves_len);
        let last_row_number = nodes_number - 2usize.pow(depth as u32) + 1;
        let first_part_index = leaves_len - last_row_number;
        let mut expected = raw[first_part_index..].to_vec();
        expected.extend_from_slice(&raw[..first_part_index]);
        assert_eq!(root, &Node(expected));
    }
}