spl_merkle_tree_reference/
lib.rs

1#![allow(clippy::arithmetic_side_effects)]
2use {
3    solana_program::keccak::hashv,
4    std::{cell::RefCell, collections::VecDeque, iter::FromIterator, rc::Rc},
5};
6
/// A 32-byte tree node value: either raw leaf data or a hash of two children.
pub type Node = [u8; 32];

/// The all-zero node; this is also what `empty_node(0)` returns.
pub const EMPTY: Node = [0u8; 32];

/// Max number of concurrent changes to the tree that are supported before
/// proofs have to be regenerated.
pub const MAX_SIZE: usize = 64;

/// Max depth of the Merkle tree.
pub const MAX_DEPTH: usize = 14;

/// Bit mask equal to `MAX_SIZE - 1`, documented as being used for node parity
/// when hashing.
///
/// NOTE(review): nothing in this file reads `MASK`. Because `MAX_SIZE` (64) is a
/// power of two, `x & MASK == x % MAX_SIZE`; confirm the intended use.
pub const MASK: usize = MAX_SIZE - 1;
19
20/// Recomputes root of the Merkle tree from Node & proof
21pub fn recompute(mut leaf: Node, proof: &[Node], index: u32) -> Node {
22    for (i, s) in proof.iter().enumerate() {
23        if index >> i & 1 == 0 {
24            let res = hashv(&[&leaf, s.as_ref()]);
25            leaf.copy_from_slice(res.as_ref());
26        } else {
27            let res = hashv(&[s.as_ref(), &leaf]);
28            leaf.copy_from_slice(res.as_ref());
29        }
30    }
31    leaf
32}
33
34// Off-chain implementation to keep track of nodes
35pub struct MerkleTree {
36    pub leaf_nodes: Vec<Rc<RefCell<TreeNode>>>,
37    pub root: Node,
38}
39
40impl MerkleTree {
41    /// Calculates updated root from the passed leaves
42    pub fn new(leaves: &[Node]) -> Self {
43        let mut leaf_nodes = vec![];
44        for (i, node) in leaves.iter().enumerate() {
45            let mut tree_node = TreeNode::new_empty(0, i as u128);
46            tree_node.node = *node;
47            leaf_nodes.push(Rc::new(RefCell::new(tree_node)));
48        }
49        let root = MerkleTree::build_root(leaf_nodes.as_slice());
50        Self { leaf_nodes, root }
51    }
52
53    /// Builds root from stack of leaves
54    pub fn build_root(leaves: &[Rc<RefCell<TreeNode>>]) -> Node {
55        let mut tree = VecDeque::from_iter(leaves.iter().map(Rc::clone));
56        let mut seq_num = leaves.len() as u128;
57        while tree.len() > 1 {
58            let left = tree.pop_front().unwrap();
59            let level = left.borrow().level;
60            let right = if level != tree[0].borrow().level {
61                let node = Rc::new(RefCell::new(TreeNode::new_empty(level, seq_num)));
62                seq_num += 1;
63                node
64            } else {
65                tree.pop_front().unwrap()
66            };
67            let mut hashed_parent = EMPTY;
68
69            hashed_parent
70                .copy_from_slice(hashv(&[&left.borrow().node, &right.borrow().node]).as_ref());
71            let parent = Rc::new(RefCell::new(TreeNode::new(
72                hashed_parent,
73                left.clone(),
74                right.clone(),
75                level + 1,
76                seq_num,
77            )));
78            left.borrow_mut().assign_parent(parent.clone());
79            right.borrow_mut().assign_parent(parent.clone());
80            tree.push_back(parent);
81            seq_num += 1;
82        }
83
84        let root = tree[0].borrow().node;
85        root
86    }
87
88    /// Traverses TreeNodes upwards to root from a Leaf TreeNode
89    /// hashing along the way
90    pub fn get_proof_of_leaf(&self, idx: usize) -> Vec<Node> {
91        let mut proof = vec![];
92        let mut node = self.leaf_nodes[idx].clone();
93        loop {
94            let ref_node = node.clone();
95            if ref_node.borrow().parent.is_none() {
96                break;
97            }
98            let parent = ref_node.borrow().parent.as_ref().unwrap().clone();
99            if parent.borrow().left.as_ref().unwrap().borrow().id == ref_node.borrow().id {
100                proof.push(parent.borrow().right.as_ref().unwrap().borrow().node);
101            } else {
102                proof.push(parent.borrow().left.as_ref().unwrap().borrow().node);
103            }
104            node = parent;
105        }
106        proof
107    }
108
109    /// Updates root from an updated leaf node set at index: `idx`
110    fn update_root_from_leaf(&mut self, leaf_idx: usize) {
111        let mut node = self.leaf_nodes[leaf_idx].clone();
112        loop {
113            let ref_node = node.clone();
114            if ref_node.borrow().parent.is_none() {
115                self.root = ref_node.borrow().node;
116                break;
117            }
118            let parent = ref_node.borrow().parent.as_ref().unwrap().clone();
119            let hash = if parent.borrow().left.as_ref().unwrap().borrow().id == ref_node.borrow().id
120            {
121                hashv(&[
122                    &ref_node.borrow().node,
123                    &parent.borrow().right.as_ref().unwrap().borrow().node,
124                ])
125            } else {
126                hashv(&[
127                    &parent.borrow().left.as_ref().unwrap().borrow().node,
128                    &ref_node.borrow().node,
129                ])
130            };
131            node = parent;
132            node.borrow_mut().node.copy_from_slice(hash.as_ref());
133        }
134    }
135
136    pub fn get_node(&self, idx: usize) -> Node {
137        self.leaf_nodes[idx].borrow().node
138    }
139
140    pub fn get_root(&self) -> Node {
141        self.root
142    }
143
144    pub fn add_leaf(&mut self, leaf: Node, leaf_idx: usize) {
145        self.leaf_nodes[leaf_idx].borrow_mut().node = leaf;
146        self.update_root_from_leaf(leaf_idx)
147    }
148
149    pub fn remove_leaf(&mut self, leaf_idx: usize) {
150        self.leaf_nodes[leaf_idx].borrow_mut().node = EMPTY;
151        self.update_root_from_leaf(leaf_idx)
152    }
153
154    pub fn get_leaf(&self, leaf_idx: usize) -> Node {
155        self.leaf_nodes[leaf_idx].borrow().node
156    }
157}
158
159#[derive(Clone)]
160pub struct TreeNode {
161    pub node: Node,
162    left: Option<Rc<RefCell<TreeNode>>>,
163    right: Option<Rc<RefCell<TreeNode>>>,
164    parent: Option<Rc<RefCell<TreeNode>>>,
165    level: u32,
166    /// ID needed to figure out whether we came from left or right child node
167    /// when hashing path upwards
168    id: u128,
169}
170
171impl TreeNode {
172    pub fn new(
173        node: Node,
174        left: Rc<RefCell<TreeNode>>,
175        right: Rc<RefCell<TreeNode>>,
176        level: u32,
177        id: u128,
178    ) -> Self {
179        Self {
180            node,
181            left: Some(left),
182            right: Some(right),
183            parent: None,
184            level,
185            id,
186        }
187    }
188
189    pub fn new_empty(level: u32, id: u128) -> Self {
190        Self {
191            node: empty_node(level),
192            left: None,
193            right: None,
194            parent: None,
195            level,
196            id,
197        }
198    }
199
200    /// Allows to propagate parent assignment
201    pub fn assign_parent(&mut self, parent: Rc<RefCell<TreeNode>>) {
202        self.parent = Some(parent);
203    }
204}
205
206/// Calculates hash of empty nodes up to level i
207/// TODO: cache this
208pub fn empty_node(level: u32) -> Node {
209    let mut data = EMPTY;
210    if level != 0 {
211        let lower_empty = empty_node(level - 1);
212        let hash = hashv(&[&lower_empty, &lower_empty]);
213        data.copy_from_slice(hash.as_ref());
214    }
215    data
216}