smtree/tree.rs
1// Copyright (c) Facebook, Inc. and its affiliates.
2//
3// This source code is licensed under the MIT license found in the
4// LICENSE file in the root directory of this source tree.
5
6//! This module provides definitions of the tree node and the paddable sparse Merkle tree,
7//! together with methods of tree generation/update, Merkle proof generation, and random sampling.
8
9use std::fmt::Debug;
10
11use crate::pad_secret::{Secret, ALL_ZEROS_SECRET};
12use crate::utils::tree_index_from_u64;
13use crate::{
14 error::{DecodingError, TreeError},
15 index::{TreeIndex, MAX_HEIGHT},
16 traits::{Mergeable, Paddable, ProofExtractable, Serializable},
17 utils::{log_2, Nil},
18};
19
/// The direction of a child node relative to its parent, either left or right.
///
/// The enum is a plain, field-less tag, so it is cheap to copy and supports
/// total equality (`Eq`) in addition to `PartialEq`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ChildDir {
    /// The left child of a node.
    Left,
    /// The right child of a node.
    Right,
}
26
/// The type of a tree node:
/// an internal node has child nodes;
/// a padding node has padding value and no child node;
/// a leaf node has real value and no child node.
///
/// The enum is a field-less tag, so it supports total equality (`Eq`)
/// in addition to `PartialEq`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum NodeType {
    /// An internal node has child nodes.
    Internal,
    /// A padding node has padding value and no child node.
    Padding,
    /// A leaf node has real value and no child node.
    Leaf,
}

impl Default for NodeType {
    /// The default NodeType is [NodeType::Internal](../tree/enum.NodeType.html#variant.Internal)
    fn default() -> Self {
        Self::Internal
    }
}
47
48/// A node in the SMT, consisting of the links to its parent, child nodes, value and node type.
49#[derive(Debug, Clone, Default)]
50pub struct TreeNode<V> {
51 // The reference to its parent/left child/right child.
52 // Being ```None``` for non-existing node.
53 parent: Option<usize>,
54 lch: Option<usize>,
55 rch: Option<usize>,
56
57 value: V,
58 // The value of the tree node.
59 node_type: NodeType, // The type of the node.
60}
61
62impl<V: Clone + Default + Mergeable + Paddable> TreeNode<V> {
63 /// The constructor.
64 pub fn new(node_type: NodeType) -> TreeNode<V> {
65 TreeNode {
66 parent: None,
67 lch: None,
68 rch: None,
69 value: V::default(),
70 node_type,
71 }
72 }
73
74 /// Returns the reference to the left child of the tree node.
75 ///
76 /// If the child node doesn't exist, return ```None```.
77 pub fn get_lch(&self) -> Option<usize> {
78 self.lch
79 }
80
81 /// Returns the reference to the right child of the tree node.
82 ///
83 /// If the child node doesn't exist, return ```None```.
84 pub fn get_rch(&self) -> Option<usize> {
85 self.rch
86 }
87
88 /// Returns the reference to the child in the input direction of the tree node.
89 ///
90 /// If the child node doesn't exist, return ```None```.
91 pub fn get_child_by_dir(&self, dir: ChildDir) -> Option<usize> {
92 match dir {
93 ChildDir::Left => self.lch,
94 ChildDir::Right => self.rch,
95 }
96 }
97
98 /// Returns the reference to the parent of the tree node.
99 ///
100 /// If the parent node doesn't exist, return ```None```.
101 pub fn get_parent(&self) -> Option<usize> {
102 self.parent
103 }
104
105 /// Returns the node type.
106 pub fn get_node_type(&self) -> &NodeType {
107 &self.node_type
108 }
109
110 /// Returns the value of the tree node.
111 pub fn get_value(&self) -> &V {
112 &self.value
113 }
114
115 /// Set the reference to the parent node as the input.
116 pub fn set_parent(&mut self, idx: usize) {
117 self.parent = Some(idx);
118 }
119
120 /// Set the reference to the left child as the input.
121 pub fn set_lch(&mut self, idx: usize) {
122 self.lch = Some(idx);
123 }
124
125 /// Set the reference to the right child as the input.
126 pub fn set_rch(&mut self, idx: usize) {
127 self.rch = Some(idx);
128 }
129
130 /// Set the value of the tree node as the input.
131 pub fn set_value(&mut self, val: V) {
132 self.value = val;
133 }
134
135 /// Set the tree node type as the input.
136 pub fn set_node_type(&mut self, x: NodeType) {
137 self.node_type = x;
138 }
139}
140
141/// Paddable sparse Merkle tree.
142#[derive(Default, Debug)]
143pub struct SparseMerkleTree<P> {
144 height: usize,
145 // The height of the SMT.
146 root: usize,
147 // The reference to the root of the SMT.
148 nodes: Vec<TreeNode<P>>, // The values of tree nodes.
149}
150
151impl<P: Clone + Default + Mergeable + Paddable + ProofExtractable> SparseMerkleTree<P>
152where
153 <P as ProofExtractable>::ProofNode: Clone + Default + Eq + Mergeable + Serializable,
154{
155 /// The constructor.
156 ///
157 /// Panics if the input height exceeds [MAX_HEIGHT](../index/constant.MAX_HEIGHT.html).
158 pub fn new(height: usize) -> SparseMerkleTree<P> {
159 if height > MAX_HEIGHT {
160 panic!("{}", DecodingError::ExceedMaxHeight);
161 }
162 let mut root_node = TreeNode::<P>::new(NodeType::Padding);
163 root_node.set_value(P::padding(&TreeIndex::zero(0), &ALL_ZEROS_SECRET));
164 SparseMerkleTree {
165 height,
166 root: 0,
167 nodes: vec![root_node],
168 }
169 }
170
171 /// A simple Merkle tree constructor, where all items are added next to each other from left to
172 /// right. Note that zero padding secret is used and the height depends on the input list size.
173 /// Use this helper constructor only when simulating a plain Merkle tree.
174 pub fn new_merkle_tree(list: &[P]) -> SparseMerkleTree<P> {
175 let height = log_2(list.len() as u32) as usize;
176 let mut smtree = Self::new(height);
177 smtree.build_merkle_tree_zero_padding(list);
178 smtree
179 }
180
181 /// Returns the height of the SMT.
182 pub fn get_height(&self) -> usize {
183 self.height
184 }
185
186 /// Returns the number of nodes in the SMT.
187 pub fn get_nodes_num(&self) -> usize {
188 self.nodes.len()
189 }
190
191 /// Returns the tree node by reference.
192 ///
193 /// Panics if the reference is out of range.
194 pub fn get_node_by_ref(&self, link: usize) -> &TreeNode<P> {
195 if link > self.nodes.len() {
196 panic!("Input reference out of range");
197 }
198 &self.nodes[link]
199 }
200
201 /// Returns the tree node by references.
202 ///
203 /// Panics if the reference is out of range.
204 pub fn get_node_raw_by_refs(&self, list: &[usize]) -> Vec<&P> {
205 let mut vec = Vec::new();
206 for link in list {
207 vec.push(self.get_node_by_ref(*link).get_value());
208 }
209 vec
210 }
211
212 /// Returns the tree node by references.
213 ///
214 /// Panics if the reference is out of range.
215 pub fn get_node_proof_by_refs(&self, list: &[usize]) -> Vec<P::ProofNode> {
216 let mut vec = Vec::new();
217 for link in list {
218 vec.push(self.get_node_by_ref(*link).get_value().get_proof_node());
219 }
220 vec
221 }
222
223 /// Returns the reference to the root ndoe.
224 pub fn get_root_ref(&self) -> usize {
225 self.root
226 }
227
228 /// Returns the raw data of the root.
229 pub fn get_root_raw(&self) -> &P {
230 self.get_node_by_ref(self.root).get_value()
231 }
232
233 /// Returns the data of the root that is visible in the Merkle proof.
234 pub fn get_root(&self) -> <P as ProofExtractable>::ProofNode {
235 self.get_root_raw().get_proof_node()
236 }
237
238 // Returns the ref and tree index of the ancestor that is closest to the input index in the tree.
239 // Panics if the height of the input index doesn't match with that of the tree.
240 pub fn get_closest_ancestor_ref_index(&self, idx: &TreeIndex) -> (usize, TreeIndex) {
241 // Panics if the the height of the input index doesn't match with the tree height.
242 if idx.get_height() != self.height {
243 panic!("{}", TreeError::HeightNotMatch);
244 }
245
246 let mut ancestor = self.root;
247 let mut ancestor_idx = *idx;
248 // Navigate by the tree index from the root node to the queried node.
249 for i in 0..self.height {
250 if idx.get_bit(i) == 0 {
251 // The queried index is in the left sub-tree.
252 if self.nodes[ancestor].get_lch().is_none() {
253 // Terminates at current bit if there is no child node to follow along.
254 ancestor_idx = ancestor_idx.get_prefix(i);
255 break;
256 }
257 ancestor = self.nodes[ancestor].get_lch().unwrap();
258 } else {
259 // The queried index is in the right sub-tree.
260 if self.nodes[ancestor].get_rch().is_none() {
261 // Terminates at current bit if there is no child node to follow along.
262 ancestor_idx = ancestor_idx.get_prefix(i);
263 break;
264 }
265 ancestor = self.nodes[ancestor].get_rch().unwrap();
266 }
267 }
268 (ancestor, ancestor_idx)
269 }
270
271 /// Returns the tree node of a queried tree index.
272 ///
273 /// Panics if the the height of the input index doesn't match with the tree height.
274 ///
275 /// If the node doesn't exist, return ```None```.
276 pub fn get_leaf_by_index(&self, idx: &TreeIndex) -> Option<&TreeNode<P>> {
277 let (node, node_idx) = self.get_closest_ancestor_ref_index(idx);
278 if node_idx.get_height() < self.height {
279 None
280 } else {
281 Some(&self.nodes[node])
282 }
283 }
284
285 /// Returns the index-reference pairs of all tree nodes in a BFS order.
286 pub fn get_index_ref_pairs(&self) -> Vec<(TreeIndex, usize)> {
287 // Run a BFS to go through all tree nodes and
288 // generate the tree index for each node in the meanwhile.
289 // The first node in the vector is the root.
290 let mut vec: Vec<(TreeIndex, usize)> = vec![(TreeIndex::zero(0), self.root)];
291 let mut head: usize = 0;
292 while head < vec.len() {
293 // If there is a left child, add it to the vector.
294 if let Some(x) = self.nodes[vec[head].1].get_lch() {
295 vec.push((vec[head].0.get_lch_index(), x));
296 }
297 // If there is a right child, add it to the vector.
298 if let Some(x) = self.nodes[vec[head].1].get_rch() {
299 vec.push((vec[head].0.get_rch_index(), x));
300 }
301 // Move on to the next node in the vector.
302 head += 1;
303 }
304 vec
305 }
306
307 /// Returns the index-node pairs of all tree nodes.
308 pub fn get_index_node_pairs(&self) -> Vec<(TreeIndex, &TreeNode<P>)> {
309 let mut vec: Vec<(TreeIndex, &TreeNode<P>)> = Vec::new();
310 let index_ref = self.get_index_ref_pairs();
311 for (index, refer) in index_ref {
312 vec.push((index, &self.nodes[refer]));
313 }
314 vec
315 }
316
317 // Returns the index-node pairs of the input node type.
318 fn get_nodes_of_type(&self, _node_type: NodeType) -> Vec<(TreeIndex, &TreeNode<P>)> {
319 let mut vec: Vec<(TreeIndex, &TreeNode<P>)> = Vec::new();
320 let nodes = self.get_index_node_pairs();
321 for (key, value) in nodes.iter() {
322 if _node_type == *value.get_node_type() {
323 vec.push((*key, value));
324 }
325 }
326 vec
327 }
328
329 /// Returns the index-node pairs of all leaf nodes.
330 pub fn get_leaves(&self) -> Vec<(TreeIndex, &TreeNode<P>)> {
331 self.get_nodes_of_type(NodeType::Leaf)
332 }
333
334 /// Returns the index-node pairs of all padding nodes.
335 pub fn get_paddings(&self) -> Vec<(TreeIndex, &TreeNode<P>)> {
336 self.get_nodes_of_type(NodeType::Padding)
337 }
338
339 /// Returns the index-node pairs of all internal nodes.
340 pub fn get_internals(&self) -> Vec<(TreeIndex, &TreeNode<P>)> {
341 self.get_nodes_of_type(NodeType::Internal)
342 }
343
344 /// Add a new child to the input parent node.
345 fn add_child(&mut self, parent: usize, dir: ChildDir) {
346 let mut node: TreeNode<P> = TreeNode::new(NodeType::Internal);
347 node.set_parent(parent); // Link the parent to the child node.
348 self.nodes.push(node);
349 let len = self.nodes.len();
350
351 // Link the child to the parent node.
352 match dir {
353 ChildDir::Left => {
354 self.nodes[parent].set_lch(len - 1);
355 }
356 ChildDir::Right => {
357 self.nodes[parent].set_rch(len - 1);
358 }
359 }
360 }
361
362 /// Add a left child to the input parent node.
363 fn add_lch(&mut self, parent: usize) {
364 self.add_child(parent, ChildDir::Left);
365 }
366
367 /// Add a right child to the input parent node.
368 fn add_rch(&mut self, parent: usize) {
369 self.add_child(parent, ChildDir::Right);
370 }
371
372 /// Add a new node in the node list with the input node type and value,
373 /// and return the reference to the new node.
374 fn add_node(&mut self, node_type: NodeType) -> usize {
375 let node = TreeNode::new(node_type);
376 self.nodes.push(node);
377 self.nodes.len() - 1
378 }
379
380 /// Set references to child nodes and the value as the merging result of two child nodes.
381 fn set_children(&mut self, parent: &mut TreeNode<P>, lref: usize, rref: usize) {
382 parent.set_lch(lref);
383 parent.set_rch(rref);
384
385 let lch = self.nodes[lref].get_value();
386 let rch = self.nodes[rref].get_value();
387 let value = Mergeable::merge(lch, rch);
388 parent.set_value(value);
389 }
390
391 /// Check if the tree indexes in the list are all valid and sorted.
392 ///
393 /// If the height of some index doesn't match with the height of the tree,
394 /// return [TreeError::HeightNotMatch](../error/enum.TreeError.html#variant.HeightNotMatch).
395 ///
396 /// If the indexes are not in order,
397 /// return [TreeError::IndexNotSorted](../error/enum.TreeError.html#variant.IndexNotSorted).
398 ///
399 /// If there are duplicated indexes in the list,
400 /// return [TreeError::IndexDuplicated](../error/enum.TreeError.html#variant.IndexDuplicated).
401 pub fn check_index_list_validity(&self, list: &[(TreeIndex, P)]) -> Option<TreeError> {
402 // Check validity of the input list.
403 for (i, item) in list.iter().enumerate() {
404 // Panic if any index in the list doesn't match with the height of the SMT.
405 if item.0.get_height() != self.height {
406 return Some(TreeError::HeightNotMatch);
407 }
408 // Panic if two consecutive indexes after sorting are the same.
409 if i > 0 {
410 if item.0 < list[i - 1].0 {
411 return Some(TreeError::IndexNotSorted);
412 }
413 if item.0 == list[i - 1].0 {
414 return Some(TreeError::IndexDuplicated);
415 }
416 }
417 }
418 None
419 }
420
421 /// Construct SMT from the input list of sorted index-value pairs, index being the sorting key.
422 ///
423 /// If the height of some index in the input list doesn't match with the height of the tree,
424 /// return [TreeError::HeightNotMatch](../error/enum.TreeError.html#variant.HeightNotMatch).
425 ///
426 /// If the indexes in the input list are not in order,
427 /// return [TreeError::IndexNotSorted](../error/enum.TreeError.html#variant.IndexNotSorted).
428 ///
429 /// If there are duplicated indexes in the list,
430 /// return [TreeError::IndexDuplicated](../error/enum.TreeError.html#variant.IndexDuplicated).
431 pub fn construct_smt_nodes(
432 &mut self,
433 list: &[(TreeIndex, P)],
434 secret: &Secret,
435 ) -> Option<TreeError> {
436 // Check the validity of the input list.
437 if let Some(x) = self.check_index_list_validity(list) {
438 return Some(x);
439 }
440
441 // If the input list is empty, no change to the tree.
442 if list.is_empty() {
443 return None;
444 }
445 // If the input list is not empty, pop out the original padding root node.
446 self.nodes.pop();
447
448 let mut layer: Vec<(TreeIndex, usize)> = Vec::new();
449 for (i, item) in list.iter().enumerate() {
450 layer.push((item.0, i));
451 }
452
453 // Clear the node list.
454 self.nodes.clear();
455
456 // Build the tree layer by layer.
457 for i in (0..self.height).rev() {
458 let mut upper: Vec<(TreeIndex, usize)> = Vec::new(); // The upper layer to be constructed.
459
460 // Build the upper layer starting from the left-most tree index of the current highest existing layer.
461 let mut head = 0;
462 let length = layer.len();
463 while head < length {
464 // Get the index and instance of the current child node.
465 let node_idx = &layer[head].0;
466 let node_link: usize; // Reference to the current node.
467 if i == self.height - 1 {
468 // If the current layer is the leaf layer, the node hasn't been added to the tree.
469 // Add the node and refer to it, the last node in the node vector.
470 node_link = self.add_node(NodeType::Leaf);
471 self.nodes[node_link].set_value(list[layer[head].1].1.clone());
472 } else {
473 // If the current layer is above the leaf layer, the node is already in the list,
474 // and the reference is the second element of the ```(TreeIndex, usize)``` pair.
475 node_link = layer[head].1;
476 }
477
478 // Get the index and instance of the parent node,
479 // which is to be added to the upper layer.
480 let parent_idx = node_idx.get_parent_index();
481 let mut parent = TreeNode::new(NodeType::Internal);
482
483 // Get the index and instance of the sibling node,
484 // which is to be merged with the current node to get the value of the current node.
485 let sibling_idx = node_idx.get_sibling_index();
486 let sibling_link: usize; // Reference to the sibling node.
487 if node_idx.get_last_bit() == 0 {
488 // When the current node is the left child of its parent,
489 // its sibling either is the next node in the sorted list,
490 // or doesn't exist yet.
491 if head < length - 1 && layer[head + 1].0 == sibling_idx {
492 // When the sibling is the next node in the list,
493 // retrieve the node reference, and move the pointer to the next node.
494 if i == self.height - 1 {
495 // If the current layer is the leaf layer, the node hasn't been added to the tree.
496 // Add the node and refer to it, the last node in the node vector.
497 sibling_link = self.add_node(NodeType::Leaf);
498 self.nodes[sibling_link].set_value(list[layer[head + 1].1].1.clone());
499 } else {
500 // If the current layer is above the leaf layer, the node is already in the list,
501 // and the reference is the second element of the (TreeIndex, usize) pair.
502 sibling_link = layer[head + 1].1;
503 }
504 head += 1; // Move the pointer to the next node.
505 } else {
506 // When the sibling doesn't exist, generate a new padding node.
507 sibling_link = self.add_node(NodeType::Padding);
508 self.nodes[sibling_link].set_value(Paddable::padding(&sibling_idx, secret));
509 }
510 self.set_children(&mut parent, node_link, sibling_link);
511 } else {
512 // When the current node is the right node of its parent,
513 // its sibling doesn't exist yet, so need to generate a new padding node.
514 sibling_link = self.add_node(NodeType::Padding);
515 self.nodes[sibling_link].set_value(Paddable::padding(&sibling_idx, secret));
516 self.set_children(&mut parent, sibling_link, node_link);
517 }
518
519 self.nodes.push(parent); // Add the parent node to the node list.
520 // Link the child nodes to the parent.
521 let len = self.nodes.len();
522 self.nodes[node_link].set_parent(len - 1);
523 self.nodes[sibling_link].set_parent(len - 1);
524 upper.push((parent_idx, len - 1)); // Add the new parent node to the upper layer for generating the next layer.
525
526 head += 1; // Done with the current node, move the pointer to the next node.
527 }
528 layer.clear();
529 layer = upper; // Continue to generate the upper layer.
530 }
531 self.root = self.nodes.len() - 1; // The root is the last node added to the tree.
532 None
533 }
534
535 /// Build SMT from the input list of sorted index-value pairs, index being the sorting key.
536 ///
537 /// Panics if the input list is not valid.
538 pub fn build(&mut self, list: &[(TreeIndex, P)], secret: &Secret) {
539 if let Some(x) = self.construct_smt_nodes(list, secret) {
540 panic!("{}", x);
541 }
542 }
543
544 /// Build simple Merkle tree from the input list with zero padding secret.
545 ///
546 /// Panics if the input list is not valid.
547 fn build_merkle_tree_zero_padding(&mut self, list: &[P]) {
548 let tree_list: Vec<(TreeIndex, P)> = list
549 .iter()
550 .enumerate()
551 .map(|(index, p)| (tree_index_from_u64(self.height, index as u64), p.clone()))
552 .collect();
553 if let Some(x) = self.construct_smt_nodes(&tree_list, &ALL_ZEROS_SECRET) {
554 panic!("{}", x);
555 }
556 }
557
558 /// Retrieve the path from the root to the input leaf node.
559 /// If there is any node on the path or its sibling not existing yet, add it to the tree.
560 fn retrieve_path(&mut self, key: &TreeIndex) -> Vec<usize> {
561 let mut vec: Vec<usize> = Vec::new();
562
563 // Start from the index of the root.
564 let mut node_idx = TreeIndex::zero(0);
565 let mut node: usize = self.root;
566 vec.push(node); // Add the root to the path.
567
568 for i in 0..self.height {
569 // Add the left child if not exist.
570 if self.nodes[node].get_lch().is_none() {
571 self.add_lch(node);
572 }
573 // Add the right child if not exist.
574 if self.nodes[node].get_rch().is_none() {
575 self.add_rch(node);
576 }
577
578 // Move on to the next node in the path.
579 if key.get_bit(i) == 0 {
580 // Go to the left child.
581 node = self.nodes[node].get_lch().unwrap();
582 node_idx = node_idx.get_lch_index();
583 } else {
584 // Go to the right child.
585 node = self.nodes[node].get_rch().unwrap();
586 node_idx = node_idx.get_rch_index();
587 }
588 vec.push(node);
589 }
590 vec
591 }
592
593 /// Update the tree by modifying the leaf node of a certain tree index.
594 ///
595 /// Panics if the height of the input index doesn't match with that of the tree.
596 pub fn update(&mut self, key: &TreeIndex, value: P, secret: &Secret) {
597 // Panic if the height of the input tree index doesn't match with that of the tree.
598 if key.get_height() != self.height {
599 panic!("{}", TreeError::HeightNotMatch)
600 }
601
602 let vec = self.retrieve_path(key); // Retrieve the path from the root to the input leaf node.
603
604 // Update the leaf node.
605 let len = vec.len();
606 self.nodes[vec[len - 1]].set_node_type(NodeType::Leaf);
607 self.nodes[vec[len - 1]].set_value(value);
608
609 assert_eq!(len - 1, self.height); // Make sure the length of the path matches with the tree height.
610
611 // Merge nodes to update parent nodes along the path from the leaf to the root.
612 let mut idx = *key; // The node index starting from the leaf node.
613 for i in (0..len - 1).rev() {
614 let parent = vec[i]; // The link to the parent node.
615 self.nodes[parent].set_node_type(NodeType::Internal);
616
617 let sibling: usize;
618 let sibling_idx: TreeIndex;
619
620 // Get the link to and the index of the sibling node.
621 if idx.get_last_bit() == 0 {
622 sibling = self.nodes[parent].get_rch().unwrap();
623 } else {
624 sibling = self.nodes[parent].get_lch().unwrap();
625 }
626 sibling_idx = idx.get_sibling_index();
627
628 // Adjust the node type of the sibling node.
629 match *self.nodes[sibling].get_node_type() {
630 NodeType::Leaf => (),
631 _ => {
632 // If the sibling node has no child, it is a padding node.
633 if self.nodes[sibling].get_lch().is_none()
634 && self.nodes[sibling].get_rch().is_none()
635 {
636 self.nodes[sibling].set_node_type(NodeType::Padding);
637 self.nodes[sibling].set_value(Paddable::padding(&sibling_idx, secret));
638 }
639 }
640 }
641
642 // Merge the two child nodes and set the value of the parent node.
643 let new_value = Mergeable::merge(
644 self.nodes[self.nodes[parent].get_lch().unwrap()].get_value(),
645 self.nodes[self.nodes[parent].get_rch().unwrap()].get_value(),
646 );
647 self.nodes[parent].set_value(new_value);
648
649 idx = idx.get_parent_index(); // Move on to the node at the upper layer.
650 }
651 }
652
653 /// Returns the references to the input leaf node and siblings of nodes long the Merkle path from the root to the leaf.
654 /// The result is a list of references ```[leaf, sibling, ..., sibling]```.
655 ///
656 /// If the input leaf node doesn't exist, return ```None```.
657 ///
658 /// Panics if the height of the input index is different from the height of the tree.
659 pub fn get_merkle_path_ref(&self, idx: &TreeIndex) -> Option<Vec<usize>> {
660 // Panics if the height of the input index is different from the height of the tree.
661 if idx.get_height() != self.height {
662 panic!("{}", TreeError::HeightNotMatch);
663 }
664
665 let mut siblings = Vec::new();
666 let mut node = self.root;
667 // Add references to sibling nodes along the path from the root to the input node.
668 for i in 0..self.height {
669 if idx.get_bit(i) == 0 {
670 // Add the reference to the right child to the sibling list and move on to the left child.
671 self.nodes[node].get_lch()?;
672 siblings.push(self.nodes[node].get_rch().unwrap());
673 node = self.nodes[node].get_lch().unwrap();
674 } else {
675 // Add the reference to the left child to the sibling list and move on to the right child.
676 self.nodes[node].get_rch()?;
677 siblings.push(self.nodes[node].get_lch().unwrap());
678 node = self.nodes[node].get_rch().unwrap();
679 }
680 }
681 let mut path = vec![node];
682 path.append(&mut siblings);
683 Some(path) // Some([leaf, sibling, ..., sibling])
684 }
685
686 /// Returns the references to the input leaves and siblings of nodes long the batched Merkle paths from the root to the leaves.
687 /// The result is a list of references ```[leaf, ..., leaf, sibling, ..., sibling]```.
688 ///
689 /// If the root or some input leaf node doesn't exist, return ```None```.
690 ///
691 /// If the input list is empty, return an empty vector.
692 ///
693 /// Panics if the input list is not valid.
694 pub fn get_merkle_path_ref_batch(&self, list: &[TreeIndex]) -> Option<Vec<usize>> {
695 // If the input list is empty, return an empty vector.
696 if list.is_empty() {
697 return Some(Vec::new());
698 }
699
700 // Construct an SMT from the input list of indexes with void value.
701 // Panics if the input list is invalid for constructing an SMT.
702 let mut proof_tree: SparseMerkleTree<Nil> = SparseMerkleTree::new(self.height);
703 let mut list_for_building: Vec<(TreeIndex, Nil)> = Vec::new();
704 for index in list {
705 list_for_building.push((*index, Nil));
706 }
707 if let Some(x) = proof_tree.construct_smt_nodes(&list_for_building, &ALL_ZEROS_SECRET) {
708 panic!("{}", x);
709 }
710
711 // Extract values of leaves and siblings in the batched Merkle proof from the original SMT
712 // in the BFS order of all nodes in proof_tree.
713 let mut leaves: Vec<usize> = Vec::new();
714 let mut siblings: Vec<usize> = Vec::new();
715 let vec = proof_tree.get_index_ref_pairs(); // Get the index-ref pair in BFS order.
716 let mut smt_refs = vec![0usize; vec.len()]; // Map from nodes in proof_tree to nodes in self.
717 smt_refs[vec[0].1] = self.root;
718 for (_idx, proof_ref) in vec {
719 let smt_ref = smt_refs[proof_ref];
720 match &proof_tree.nodes[proof_ref].node_type {
721 // The padding node in proof_tree is a sibling node in the batched proof.
722 NodeType::Padding => {
723 siblings.push(smt_ref);
724 }
725 // The leaf node in proof_tree in also a leaf node in the batched proof.
726 NodeType::Leaf => {
727 leaves.push(smt_ref);
728 }
729 NodeType::Internal => {}
730 }
731 // Map the left child of current node in proof_tree to that of the referenced node in the original SMT.
732 if let Some(x) = proof_tree.nodes[proof_ref].get_lch() {
733 self.nodes[smt_ref].get_lch()?;
734 smt_refs[x] = self.nodes[smt_ref].get_lch().unwrap();
735 }
736 // Map the right child of current node in proof_tree to that of the referenced node in the original SMT.
737 if let Some(x) = proof_tree.nodes[proof_ref].get_rch() {
738 self.nodes[smt_ref].get_rch()?;
739 smt_refs[x] = self.nodes[smt_ref].get_rch().unwrap();
740 }
741 }
742 leaves.append(&mut siblings);
743 Some(leaves) // Some([leaf, ..., leaf, sibling, ..., sibling])
744 }
745
746 /// Returns the tree index of closest left/right (depending on input direction) node in the tree.
747 pub fn get_closest_index_by_dir(
748 &self,
749 ancestor_ref: usize,
750 ancestor_idx: TreeIndex,
751 dir: ChildDir,
752 ) -> Option<TreeIndex> {
753 let mut closest_ref = ancestor_ref;
754 let mut closest_idx = ancestor_idx;
755
756 // Find the node of which the subtree contains the closest node.
757 while closest_ref != self.root {
758 let parent_ref = self.nodes[closest_ref].get_parent().unwrap();
759 if self.nodes[parent_ref].get_child_by_dir(dir).is_none()
760 || closest_ref == self.nodes[parent_ref].get_child_by_dir(dir).unwrap()
761 || *self.nodes[self.nodes[parent_ref].get_child_by_dir(dir).unwrap()]
762 .get_node_type()
763 == NodeType::Padding
764 {
765 // When the parent node doesn't have a non-padding dir child or the current node itself is the left child,
766 // go up to the upper level.
767 closest_ref = parent_ref;
768 closest_idx = closest_idx.get_prefix(closest_idx.get_height() - 1);
769 } else {
770 // The sibling of the current node is a dir-child of its parent, thus its subtree contains the target node.
771 closest_ref = self.nodes[parent_ref].get_child_by_dir(dir).unwrap();
772 closest_idx = closest_idx.get_sibling_index();
773 break;
774 }
775 }
776 if closest_idx.get_height() == 0 {
777 // The closest left/right node doesn't exist in the tree.
778 return None;
779 }
780
781 let mut opp_dir = ChildDir::Left;
782 if dir == ChildDir::Left {
783 opp_dir = ChildDir::Right;
784 }
785
786 // Retrieve the opp_dir most node in the subtree, which is our target.
787 while *self.nodes[closest_ref].get_node_type() == NodeType::Internal {
788 if *self.nodes[self.nodes[closest_ref].get_child_by_dir(opp_dir).unwrap()]
789 .get_node_type()
790 == NodeType::Padding
791 {
792 closest_ref = self.nodes[closest_ref].get_child_by_dir(dir).unwrap();
793 closest_idx = closest_idx.get_child_index_by_dir(dir);
794 } else {
795 closest_ref = self.nodes[closest_ref].get_child_by_dir(opp_dir).unwrap();
796 closest_idx = closest_idx.get_child_index_by_dir(opp_dir);
797 }
798 }
799 Some(closest_idx)
800 }
801
802 /// Returns the index-reference pairs to necessary padding nodes to prove that
803 /// the input index is the left/right (depending on the input direction) most real leaf in the tree.
804 /// Note that the reference is the offset from the end of the sibling list.
805 pub fn get_padding_proof_by_dir_index_ref_pairs(
806 idx: &TreeIndex,
807 dir: ChildDir,
808 ) -> Vec<(TreeIndex, usize)> {
809 let mut opp_dir = ChildDir::Right;
810 let mut dir_bit = 0;
811 if dir == ChildDir::Right {
812 opp_dir = ChildDir::Left;
813 dir_bit = 1;
814 }
815
816 // Along the path from the leaf node to the root,
817 // any sibling that is an opp_dir child of its parent,
818 // it must be a padding node and should be part of proof.
819 let mut refs: Vec<(TreeIndex, usize)> = Vec::new();
820 for i in (0..idx.get_height()).rev() {
821 if idx.get_bit(i) == dir_bit {
822 refs.push((
823 idx.get_prefix(i).get_child_index_by_dir(opp_dir),
824 idx.get_height() - 1 - i,
825 ));
826 }
827 }
828 refs
829 }
830
831 /// Returns the index-reference pairs to necessary padding nodes to prove that
832 /// there are no other real leaf nodes between the input indexes in the tree.
833 /// Note that the reference is the offset from the end of the sibling list.
834 ///
835 /// Panics if the input indexes don't have the same height or not in the right order.
836 pub fn get_padding_proof_batch_index_ref_pairs(
837 left_idx: &TreeIndex,
838 right_idx: &TreeIndex,
839 ) -> Vec<(TreeIndex, usize)> {
840 // Panics if the heights of two indexes don't match.
841 if left_idx.get_height() != right_idx.get_height() {
842 panic!("{}", TreeError::HeightNotMatch);
843 }
844 // Panics if the two indexes are not in the right order.
845 if left_idx >= right_idx {
846 panic!("{}", TreeError::IndexNotSorted);
847 }
848
849 // Check all siblings in the batched Merkle proof of the two input indexes.
850 // If any sibling or the subtree of the sibling is between the two input indexes,
851 // they must be padding nodes and should be included in the padding node proof.
852 let mut refs: Vec<(TreeIndex, usize)> = Vec::new();
853 let mut cur_ref = 0usize;
854 let mut index: [TreeIndex; 2] = [*left_idx, *right_idx];
855 let mut parent: [TreeIndex; 2] =
856 [left_idx.get_parent_index(), right_idx.get_parent_index()];
857 while parent[0] != parent[1] {
858 // There won't be such padding nodes in above the common ancestor of two input indexes.
859 for dir_bit in (0..2).rev() {
860 if index[dir_bit].get_last_bit() == dir_bit as u8 {
861 // If the current index or the subtree of the index is between the two input indexes,
862 // add it to the reference of padding node proof.
863 // Not that the reference is the offset from the end of the sibling list in the Merkle proof.
864 refs.push((index[dir_bit].get_sibling_index(), cur_ref));
865 }
866 index[dir_bit] = parent[dir_bit];
867 parent[dir_bit] = parent[dir_bit].get_parent_index();
868 cur_ref += 1;
869 }
870 }
871 refs
872 }
873}