1use std::cmp::PartialEq;
23use std::marker::PhantomData;
24
/// A binary tree of node values stored in heap order inside a `Vec`.
///
/// `M` is the type of the merge function the tree was built with. It is only recorded
/// through `PhantomData`; no value of type `M` is stored.
pub struct Tree<T, M> {
    nodes: Vec<T>,
    leaf_size: usize,
    phanton: PhantomData<M>,
}

// `Debug` and `Clone` are written by hand instead of derived. A derive would demand
// `M: Debug` / `M: Clone`, which closures generally do not satisfy, even though no `M`
// value lives in the struct. The output and field order match what the derive produced.
impl<T: std::fmt::Debug, M> std::fmt::Debug for Tree<T, M> {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        f.debug_struct("Tree")
            .field("nodes", &self.nodes)
            .field("leaf_size", &self.leaf_size)
            .field("phanton", &self.phanton)
            .finish()
    }
}

impl<T: Clone, M> Clone for Tree<T, M> {
    fn clone(&self) -> Self {
        Tree {
            nodes: self.nodes.clone(),
            leaf_size: self.leaf_size,
            phanton: PhantomData,
        }
    }
}
31
/// One sibling entry on the path from a leaf up to the root.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProofNode<T> {
    /// `true` when this sibling is the right-hand child, so verification computes
    /// `merge(running, hash)`; `false` means `merge(hash, running)`.
    pub is_right: bool,
    /// Value of the sibling node.
    pub hash: T,
}

/// The sibling entries needed to recompute the root from one leaf, ordered from the
/// leaf's own sibling upward. Comparable whenever `T` is.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Proof<T>(pub Vec<ProofNode<T>>);
40
41impl<T, M> Tree<T, M>
42where
43 T: Default + Clone + PartialEq,
44 M: Fn(&T, &T) -> T,
45{
46 pub fn from_hashes(input: Vec<T>, merge: M) -> Self {
67 let leaf_size = input.len();
68
69 let nodes = match leaf_size {
70 0 => vec![],
71
72 1 => input,
74
75 _ => {
76 let nodes_number = get_number_of_nodes(leaf_size);
77 let mut nodes = vec![T::default(); nodes_number];
78
79 let depth = get_depth(leaf_size);
80
81 let first_input_node_index = nodes_number - leaf_size;
82 let first_index_in_lowest_row = (1 << depth) - 1;
83 let nodes_number_of_lowest_row = nodes_number - first_index_in_lowest_row;
84
85 nodes[first_input_node_index..nodes_number].clone_from_slice(&input);
87
88 let max_nodes_number_of_lowest_row = 1 << depth;
89 if max_nodes_number_of_lowest_row == leaf_size {
91 calc_tree_at_row(&mut nodes, depth, 0, &merge)
93 } else {
94 calc_tree_at_row(&mut nodes, depth, nodes_number_of_lowest_row >> 1, &merge);
95 }
96
97 for i in 1..depth {
99 let row_index = depth - i;
100 calc_tree_at_row(&mut nodes, row_index, 0, &merge);
101 }
102
103 nodes
104 }
105 };
106
107 Self {
108 nodes,
109 leaf_size,
110 phanton: PhantomData,
111 }
112 }
113
114 pub fn get_root_hash(&self) -> Option<&T> {
115 self.nodes.get(0)
116 }
117
118 pub fn get_proof_by_input_index(&self, input_index: usize) -> Option<Proof<T>> {
119 get_proof_indexes(input_index, self.leaf_size).map(|indexes| {
120 Proof::<T>(
121 indexes
122 .into_iter()
123 .map(|i| ProofNode::<T> {
124 is_right: (i & 1) == 0,
125 hash: self.nodes[i].clone(),
126 })
127 .collect(),
128 )
129 })
130 }
131}
132
133impl<T> Proof<T>
134where
135 T: Default + Clone + PartialEq,
136{
137 pub fn verify<M>(&self, root: &T, data: T, merge: M) -> bool
138 where
139 M: Fn(&T, &T) -> T,
140 {
141 &self.0.iter().fold(data, |h, ref x| {
142 if x.is_right {
143 merge(&h, &x.hash)
144 } else {
145 merge(&x.hash, &h)
146 }
147 }) == root
148 }
149}
150
/// Recomputes the parents of row `row_index` in the heap-ordered `nodes`.
///
/// Rows are numbered from the root (row 0). The parent row `row_index - 1` starts at
/// index `2^(row_index - 1) - 1` and holds at most `2^(row_index - 1)` nodes. Each
/// updated parent `j` becomes `merge(&nodes[2j + 1], &nodes[2j + 2])`, processed left
/// to right.
///
/// When `0 < break_cnt < 2^(row_index - 1)`, only the first `break_cnt` parents are
/// updated. Any other value, including `0`, updates the whole parent row.
///
/// `row_index` must be at least 1. Indexing panics if `nodes` is too short to hold
/// every child that is read.
fn calc_tree_at_row<T, M>(nodes: &mut [T], row_index: usize, break_cnt: usize, merge: M)
where
    M: Fn(&T, &T) -> T,
{
    let parent_row_width: usize = 1 << (row_index - 1);
    let parent_row_start = parent_row_width - 1;
    let parents_to_update = if break_cnt > 0 && break_cnt < parent_row_width {
        break_cnt
    } else {
        parent_row_width
    };
    for parent in parent_row_start..parent_row_start + parents_to_update {
        nodes[parent] = merge(&nodes[2 * parent + 1], &nodes[2 * parent + 2]);
    }
}
170
/// Number of rows below the root needed for `m` leaves: the smallest `d` with
/// `2^d >= m`, which is 0 when `m` is 0 or 1.
#[inline]
fn get_depth(m: usize) -> usize {
    m.next_power_of_two().trailing_zeros() as usize
}
181
/// Total node count of a tree with `m` leaves: `2m - 1`. Zero leaves is treated as
/// a single slot.
#[inline]
fn get_number_of_nodes(m: usize) -> usize {
    match m {
        0 => 1,
        leaves => 2 * leaves - 1,
    }
}
204
/// For a non-root heap index, returns `(sibling, parent)`.
///
/// Left children sit at odd indices and right children at even ones, so a node's
/// sibling is the neighbour on the other side. `index` must be greater than 0; the
/// root has neither sibling nor parent, and 0 overflows.
#[inline]
fn get_index_of_brother_and_father(index: usize) -> (usize, usize) {
    let sibling = if index % 2 == 1 { index + 1 } else { index - 1 };
    let parent = (index + 1) / 2 - 1;
    (sibling, parent)
}
215
216#[inline]
217fn get_proof_indexes(input_index: usize, leaf_size: usize) -> Option<Vec<usize>> {
218 if input_index == 0 && leaf_size < 2 {
219 Some(vec![])
220 } else if leaf_size != 0 && input_index < leaf_size {
221 let mut ret = Vec::new();
222 let nodes_number = get_number_of_nodes(leaf_size);
223 let mut index = nodes_number - leaf_size + input_index;
224 while index > 0 {
225 let (brother_index, parent_index) = get_index_of_brother_and_father(index);
226 ret.push(brother_index);
227 index = parent_index;
228 }
229 Some(ret)
230 } else {
231 None
232 }
233}
234
#[cfg(test)]
mod tests {
    #[derive(Default, Clone, PartialEq, Debug)]
    struct Node(Vec<u32>);

    /// Concatenating merge: a tree's root is the in-order sequence of its leaves, which
    /// makes the tree's shape easy to assert on.
    fn merge(left: &Node, right: &Node) -> Node {
        let mut root: Vec<u32> = Vec::with_capacity(left.0.len() + right.0.len());
        root.extend_from_slice(&left.0);
        root.extend_from_slice(&right.0);
        Node(root)
    }

    #[test]
    fn test_depth() {
        let check = vec![
            (0, 0),
            (1, 0),
            (2, 1),
            (3, 2),
            (4, 2),
            (5, 3),
            (8, 3),
            (9, 4),
            (16, 4),
            (17, 5),
        ];
        for (x, y) in check {
            assert_eq!(y, super::get_depth(x));
        }
    }

    #[test]
    fn test_number_of_nodes() {
        let check = vec![
            (0, 1),
            (1, 1),
            (2, 3),
            (3, 5),
            (4, 7),
            (5, 9),
            (8, 15),
            (9, 17),
            (16, 31),
            (20, 39),
        ];
        for (x, y) in check {
            assert_eq!(y, super::get_number_of_nodes(x));
        }
    }

    #[test]
    fn test_index_of_brother_and_father() {
        let check = vec![
            (1, (2, 0)),
            (2, (1, 0)),
            (11, (12, 5)),
            (12, (11, 5)),
            (21, (22, 10)),
            (22, (21, 10)),
            (31, (32, 15)),
            (32, (31, 15)),
        ];
        for (x, y) in check {
            assert_eq!(y, super::get_index_of_brother_and_father(x));
        }
    }

    #[test]
    fn test_proof_indexes() {
        let check = vec![
            ((1, 0), None),
            ((1, 1), None),
            ((2, 1), None),
            ((2, 2), None),
            ((0, 0), Some(vec![])),
            ((0, 1), Some(vec![])),
            ((0, 11), Some(vec![9, 3, 2])),
            ((10, 11), Some(vec![19, 10, 3, 2])),
            ((9, 11), Some(vec![20, 10, 3, 2])),
            ((8, 11), Some(vec![17, 7, 4, 2])),
        ];
        for ((x1, x2), y) in check {
            assert_eq!(y, super::get_proof_indexes(x1, x2));
        }
    }

    #[test]
    fn test_proof() {
        let inputs: Vec<Vec<Node>> = vec![
            vec![Node(vec![1u32])],
            (1u32..26u32).map(|i| Node(vec![i])).collect(),
        ];
        for input in inputs {
            let tree = super::Tree::from_hashes(input.clone(), merge);
            let root_hash = tree.get_root_hash().unwrap().clone();
            for (index, item) in input.into_iter().enumerate() {
                let proof = tree
                    .get_proof_by_input_index(index)
                    .expect("proof is not none");
                assert!(proof.verify(&root_hash, item, merge));
            }
        }
    }

    #[test]
    fn test_proof_rejects_wrong_leaf() {
        let input: Vec<Node> = (1u32..6u32).map(|i| Node(vec![i])).collect();
        let tree = super::Tree::from_hashes(input, merge);
        let root_hash = tree.get_root_hash().unwrap().clone();
        let proof = tree
            .get_proof_by_input_index(0)
            .expect("proof is not none");
        assert!(!proof.verify(&root_hash, Node(vec![99]), merge));
        assert!(tree.get_proof_by_input_index(5).is_none());
    }

    #[test]
    fn test_empty_tree() {
        let tree = super::Tree::from_hashes(Vec::<Node>::new(), merge);
        assert!(tree.get_root_hash().is_none());
        let proof = tree
            .get_proof_by_input_index(0)
            .expect("index 0 of an empty tree yields an empty proof");
        assert!(proof.0.is_empty());
        assert!(tree.get_proof_by_input_index(1).is_none());
    }

    #[test]
    fn test_root() {
        assert_root(&(0u32..12u32).collect::<Vec<u32>>());
        assert_root(&(0u32..11u32).collect::<Vec<u32>>());
        assert_root(&[1u32, 5u32, 100u32, 4u32, 7u32, 9u32, 11u32]);
        assert_root(&(0u32..27u32).collect::<Vec<u32>>());
    }

    /// With the concatenating `merge`, the root lists the leaves in tree order. The
    /// leaves that sit in the lowest row come first, followed by the leaves that were
    /// placed one row higher.
    fn assert_root(raw: &[u32]) {
        let leaves: Vec<Node> = raw.iter().map(|&i| Node(vec![i])).collect();
        let leaves_len = leaves.len();
        let tree = super::Tree::from_hashes(leaves, merge);
        let root = tree.get_root_hash().unwrap();
        let depth = super::get_depth(leaves_len);
        let nodes_number = super::get_number_of_nodes(leaves_len);
        let last_row_number = nodes_number - (1usize << depth) + 1;
        let first_part_index = leaves_len - last_row_number;
        let mut expected = raw[first_part_index..].to_vec();
        expected.extend_from_slice(&raw[..first_part_index]);
        assert_eq!(root, &Node(expected));
    }
}