// sparse_merkle_tree/tree.rs

1use crate::{
2    collections::VecDeque,
3    error::{Error, Result},
4    merge::{merge, MergeValue},
5    merkle_proof::MerkleProof,
6    traits::{Hasher, StoreReadOps, StoreWriteOps, Value},
7    vec::Vec,
8    H256, MAX_STACK_SIZE,
9};
10use core::cmp::Ordering;
11use core::marker::PhantomData;
12/// The branch key
13#[derive(Debug, Clone, Eq, PartialEq, Hash)]
14pub struct BranchKey {
15    pub height: u8,
16    pub node_key: H256,
17}
18
19impl BranchKey {
20    pub fn new(height: u8, node_key: H256) -> BranchKey {
21        BranchKey { height, node_key }
22    }
23}
24
25impl PartialOrd for BranchKey {
26    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
27        Some(self.cmp(other))
28    }
29}
30impl Ord for BranchKey {
31    fn cmp(&self, other: &Self) -> Ordering {
32        match self.height.cmp(&other.height) {
33            Ordering::Equal => self.node_key.cmp(&other.node_key),
34            ordering => ordering,
35        }
36    }
37}
38
39/// A branch in the SMT
40#[derive(Debug, Eq, PartialEq, Clone)]
41pub struct BranchNode {
42    pub left: MergeValue,
43    pub right: MergeValue,
44}
45
46impl BranchNode {
47    /// Create a new empty branch
48    pub fn new_empty() -> BranchNode {
49        BranchNode {
50            left: MergeValue::zero(),
51            right: MergeValue::zero(),
52        }
53    }
54
55    /// Determine whether a node did not store any value
56    pub fn is_empty(&self) -> bool {
57        self.left.is_zero() && self.right.is_zero()
58    }
59}
60
61/// Sparse merkle tree
62#[derive(Default, Debug)]
63pub struct SparseMerkleTree<H, V, S> {
64    store: S,
65    root: H256,
66    phantom: PhantomData<(H, V)>,
67}
68
69impl<H, V, S> SparseMerkleTree<H, V, S> {
70    /// Build a merkle tree from root and store
71    pub fn new(root: H256, store: S) -> SparseMerkleTree<H, V, S> {
72        SparseMerkleTree {
73            root,
74            store,
75            phantom: PhantomData,
76        }
77    }
78
79    /// Merkle root
80    pub fn root(&self) -> &H256 {
81        &self.root
82    }
83
84    /// Check empty of the tree
85    pub fn is_empty(&self) -> bool {
86        self.root.is_zero()
87    }
88
89    /// Destroy current tree and retake store
90    pub fn take_store(self) -> S {
91        self.store
92    }
93
94    /// Get backend store
95    pub fn store(&self) -> &S {
96        &self.store
97    }
98
99    /// Get mutable backend store
100    pub fn store_mut(&mut self) -> &mut S {
101        &mut self.store
102    }
103}
104
105impl<H: Hasher + Default, V, S: StoreReadOps<V>> SparseMerkleTree<H, V, S> {
106    /// Build a merkle tree from store, the root will be calculated automatically
107    pub fn new_with_store(store: S) -> Result<SparseMerkleTree<H, V, S>> {
108        let root_branch_key = BranchKey::new(core::u8::MAX, H256::zero());
109        store
110            .get_branch(&root_branch_key)
111            .map(|branch_node| {
112                branch_node
113                    .map(|n| {
114                        merge::<H>(core::u8::MAX, &H256::zero(), &n.left, &n.right).hash::<H>()
115                    })
116                    .unwrap_or_default()
117            })
118            .map(|root| SparseMerkleTree::new(root, store))
119    }
120}
121
122impl<H: Hasher + Default, V: Value, S: StoreReadOps<V> + StoreWriteOps<V>>
123    SparseMerkleTree<H, V, S>
124{
125    /// Update a leaf, return new merkle root
126    /// set to zero value to delete a key
127    pub fn update(&mut self, key: H256, value: V) -> Result<&H256> {
128        // compute and store new leaf
129        let node = MergeValue::from_h256(value.to_h256());
130        // notice when value is zero the leaf is deleted, so we do not need to store it
131        if !node.is_zero() {
132            self.store.insert_leaf(key, value)?;
133        } else {
134            self.store.remove_leaf(&key)?;
135        }
136
137        // recompute the tree from bottom to top
138        let mut current_key = key;
139        let mut current_node = node;
140        for height in 0..=core::u8::MAX {
141            let parent_key = current_key.parent_path(height);
142            let parent_branch_key = BranchKey::new(height, parent_key);
143            let (left, right) =
144                if let Some(parent_branch) = self.store.get_branch(&parent_branch_key)? {
145                    if current_key.is_right(height) {
146                        (parent_branch.left, current_node)
147                    } else {
148                        (current_node, parent_branch.right)
149                    }
150                } else if current_key.is_right(height) {
151                    (MergeValue::zero(), current_node)
152                } else {
153                    (current_node, MergeValue::zero())
154                };
155
156            if !left.is_zero() || !right.is_zero() {
157                // insert or update branch
158                self.store.insert_branch(
159                    parent_branch_key,
160                    BranchNode {
161                        left: left.clone(),
162                        right: right.clone(),
163                    },
164                )?;
165            } else {
166                // remove empty branch
167                self.store.remove_branch(&parent_branch_key)?;
168            }
169            // prepare for next round
170            current_key = parent_key;
171            current_node = merge::<H>(height, &parent_key, &left, &right);
172        }
173
174        self.root = current_node.hash::<H>();
175        Ok(&self.root)
176    }
177
178    /// Update multiple leaves at once
179    pub fn update_all(&mut self, mut leaves: Vec<(H256, V)>) -> Result<&H256> {
180        // Dedup(only keep the last of each key) and sort leaves
181        leaves.reverse();
182        leaves.sort_by_key(|(a, _)| *a);
183        leaves.dedup_by_key(|(a, _)| *a);
184
185        let mut nodes = leaves
186            .into_iter()
187            .map(|(k, v)| {
188                let value = MergeValue::from_h256(v.to_h256());
189                if !value.is_zero() {
190                    self.store.insert_leaf(k, v)
191                } else {
192                    self.store.remove_leaf(&k)
193                }
194                .map(|_| (k, value, 0))
195            })
196            .collect::<Result<VecDeque<(H256, MergeValue, u8)>>>()?;
197
198        while let Some((current_key, current_merge_value, height)) = nodes.pop_front() {
199            let parent_key = current_key.parent_path(height);
200            let parent_branch_key = BranchKey::new(height, parent_key);
201
202            // Test for neighbors
203            let mut right = None;
204            if !current_key.is_right(height) && !nodes.is_empty() {
205                let (neighbor_key, _, neighbor_height) = nodes.front().expect("nodes is not empty");
206                if neighbor_height.eq(&height) {
207                    let mut right_key = current_key;
208                    right_key.set_bit(height);
209                    if neighbor_key.eq(&right_key) {
210                        let (_, neighbor_value, _) = nodes.pop_front().expect("nodes is not empty");
211                        right = Some(neighbor_value);
212                    }
213                }
214            }
215
216            let (left, right) = if let Some(right_merge_value) = right {
217                (current_merge_value, right_merge_value)
218            } else {
219                // In case neighbor is not available, fetch from store
220                if let Some(parent_branch) = self.store.get_branch(&parent_branch_key)? {
221                    if current_key.is_right(height) {
222                        (parent_branch.left, current_merge_value)
223                    } else {
224                        (current_merge_value, parent_branch.right)
225                    }
226                } else if current_key.is_right(height) {
227                    (MergeValue::zero(), current_merge_value)
228                } else {
229                    (current_merge_value, MergeValue::zero())
230                }
231            };
232
233            if !left.is_zero() || !right.is_zero() {
234                self.store.insert_branch(
235                    parent_branch_key,
236                    BranchNode {
237                        left: left.clone(),
238                        right: right.clone(),
239                    },
240                )?;
241            } else {
242                self.store.remove_branch(&parent_branch_key)?;
243            }
244            if height == core::u8::MAX {
245                self.root = merge::<H>(height, &parent_key, &left, &right).hash::<H>();
246                break;
247            } else {
248                nodes.push_back((
249                    parent_key,
250                    merge::<H>(height, &parent_key, &left, &right),
251                    height + 1,
252                ));
253            }
254        }
255
256        Ok(&self.root)
257    }
258}
259
260impl<H: Hasher + Default, V: Value, S: StoreReadOps<V>> SparseMerkleTree<H, V, S> {
261    /// Get value of a leaf
262    /// return zero value if leaf not exists
263    pub fn get(&self, key: &H256) -> Result<V> {
264        if self.is_empty() {
265            return Ok(V::zero());
266        }
267        Ok(self.store.get_leaf(key)?.unwrap_or_else(V::zero))
268    }
269
270    /// Generate merkle proof
271    pub fn merkle_proof(&self, mut keys: Vec<H256>) -> Result<MerkleProof> {
272        if keys.is_empty() {
273            return Err(Error::EmptyKeys);
274        }
275
276        // sort keys
277        keys.sort_unstable();
278
279        // Collect leaf bitmaps
280        let mut leaves_bitmap: Vec<H256> = Default::default();
281        for current_key in &keys {
282            let mut bitmap = H256::zero();
283            for height in 0..=core::u8::MAX {
284                let parent_key = current_key.parent_path(height);
285                let parent_branch_key = BranchKey::new(height, parent_key);
286                if let Some(parent_branch) = self.store.get_branch(&parent_branch_key)? {
287                    let sibling = if current_key.is_right(height) {
288                        parent_branch.left
289                    } else {
290                        parent_branch.right
291                    };
292                    if !sibling.is_zero() {
293                        bitmap.set_bit(height);
294                    }
295                } else {
296                    // The key is not in the tree (support non-inclusion proof)
297                }
298            }
299            leaves_bitmap.push(bitmap);
300        }
301
302        let mut proof: Vec<MergeValue> = Default::default();
303        let mut stack_fork_height = [0u8; MAX_STACK_SIZE]; // store fork height
304        let mut stack_top = 0;
305        let mut leaf_index = 0;
306        while leaf_index < keys.len() {
307            let leaf_key = keys[leaf_index];
308            let fork_height = if leaf_index + 1 < keys.len() {
309                leaf_key.fork_height(&keys[leaf_index + 1])
310            } else {
311                core::u8::MAX
312            };
313            for height in 0..=fork_height {
314                if height == fork_height && leaf_index + 1 < keys.len() {
315                    // If it's not final round, we don't need to merge to root (height=255)
316                    break;
317                }
318                let parent_key = leaf_key.parent_path(height);
319                let is_right = leaf_key.is_right(height);
320
321                // has non-zero sibling
322                if stack_top > 0 && stack_fork_height[stack_top - 1] == height {
323                    stack_top -= 1;
324                } else if leaves_bitmap[leaf_index].get_bit(height) {
325                    let parent_branch_key = BranchKey::new(height, parent_key);
326                    if let Some(parent_branch) = self.store.get_branch(&parent_branch_key)? {
327                        let sibling = if is_right {
328                            parent_branch.left
329                        } else {
330                            parent_branch.right
331                        };
332                        if !sibling.is_zero() {
333                            proof.push(sibling);
334                        } else {
335                            unreachable!();
336                        }
337                    } else {
338                        // The key is not in the tree (support non-inclusion proof)
339                    }
340                }
341            }
342            debug_assert!(stack_top < MAX_STACK_SIZE);
343            stack_fork_height[stack_top] = fork_height;
344            stack_top += 1;
345            leaf_index += 1;
346        }
347        assert_eq!(stack_top, 1);
348        Ok(MerkleProof::new(leaves_bitmap, proof))
349    }
350}