1use crate::{
2 collections::VecDeque,
3 error::{Error, Result},
4 merge::{merge, MergeValue},
5 merkle_proof::MerkleProof,
6 traits::{Hasher, StoreReadOps, StoreWriteOps, Value},
7 vec::Vec,
8 H256, MAX_STACK_SIZE,
9};
10use core::cmp::Ordering;
11use core::marker::PhantomData;
12#[derive(Debug, Clone, Eq, PartialEq, Hash)]
14pub struct BranchKey {
15 pub height: u8,
16 pub node_key: H256,
17}
18
19impl BranchKey {
20 pub fn new(height: u8, node_key: H256) -> BranchKey {
21 BranchKey { height, node_key }
22 }
23}
24
25impl PartialOrd for BranchKey {
26 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
27 Some(self.cmp(other))
28 }
29}
30impl Ord for BranchKey {
31 fn cmp(&self, other: &Self) -> Ordering {
32 match self.height.cmp(&other.height) {
33 Ordering::Equal => self.node_key.cmp(&other.node_key),
34 ordering => ordering,
35 }
36 }
37}
38
39#[derive(Debug, Eq, PartialEq, Clone)]
41pub struct BranchNode {
42 pub left: MergeValue,
43 pub right: MergeValue,
44}
45
46impl BranchNode {
47 pub fn new_empty() -> BranchNode {
49 BranchNode {
50 left: MergeValue::zero(),
51 right: MergeValue::zero(),
52 }
53 }
54
55 pub fn is_empty(&self) -> bool {
57 self.left.is_zero() && self.right.is_zero()
58 }
59}
60
61#[derive(Default, Debug)]
63pub struct SparseMerkleTree<H, V, S> {
64 store: S,
65 root: H256,
66 phantom: PhantomData<(H, V)>,
67}
68
69impl<H, V, S> SparseMerkleTree<H, V, S> {
70 pub fn new(root: H256, store: S) -> SparseMerkleTree<H, V, S> {
72 SparseMerkleTree {
73 root,
74 store,
75 phantom: PhantomData,
76 }
77 }
78
79 pub fn root(&self) -> &H256 {
81 &self.root
82 }
83
84 pub fn is_empty(&self) -> bool {
86 self.root.is_zero()
87 }
88
89 pub fn take_store(self) -> S {
91 self.store
92 }
93
94 pub fn store(&self) -> &S {
96 &self.store
97 }
98
99 pub fn store_mut(&mut self) -> &mut S {
101 &mut self.store
102 }
103}
104
105impl<H: Hasher + Default, V, S: StoreReadOps<V>> SparseMerkleTree<H, V, S> {
106 pub fn new_with_store(store: S) -> Result<SparseMerkleTree<H, V, S>> {
108 let root_branch_key = BranchKey::new(core::u8::MAX, H256::zero());
109 store
110 .get_branch(&root_branch_key)
111 .map(|branch_node| {
112 branch_node
113 .map(|n| {
114 merge::<H>(core::u8::MAX, &H256::zero(), &n.left, &n.right).hash::<H>()
115 })
116 .unwrap_or_default()
117 })
118 .map(|root| SparseMerkleTree::new(root, store))
119 }
120}
121
122impl<H: Hasher + Default, V: Value, S: StoreReadOps<V> + StoreWriteOps<V>>
123 SparseMerkleTree<H, V, S>
124{
125 pub fn update(&mut self, key: H256, value: V) -> Result<&H256> {
128 let node = MergeValue::from_h256(value.to_h256());
130 if !node.is_zero() {
132 self.store.insert_leaf(key, value)?;
133 } else {
134 self.store.remove_leaf(&key)?;
135 }
136
137 let mut current_key = key;
139 let mut current_node = node;
140 for height in 0..=core::u8::MAX {
141 let parent_key = current_key.parent_path(height);
142 let parent_branch_key = BranchKey::new(height, parent_key);
143 let (left, right) =
144 if let Some(parent_branch) = self.store.get_branch(&parent_branch_key)? {
145 if current_key.is_right(height) {
146 (parent_branch.left, current_node)
147 } else {
148 (current_node, parent_branch.right)
149 }
150 } else if current_key.is_right(height) {
151 (MergeValue::zero(), current_node)
152 } else {
153 (current_node, MergeValue::zero())
154 };
155
156 if !left.is_zero() || !right.is_zero() {
157 self.store.insert_branch(
159 parent_branch_key,
160 BranchNode {
161 left: left.clone(),
162 right: right.clone(),
163 },
164 )?;
165 } else {
166 self.store.remove_branch(&parent_branch_key)?;
168 }
169 current_key = parent_key;
171 current_node = merge::<H>(height, &parent_key, &left, &right);
172 }
173
174 self.root = current_node.hash::<H>();
175 Ok(&self.root)
176 }
177
178 pub fn update_all(&mut self, mut leaves: Vec<(H256, V)>) -> Result<&H256> {
180 leaves.reverse();
182 leaves.sort_by_key(|(a, _)| *a);
183 leaves.dedup_by_key(|(a, _)| *a);
184
185 let mut nodes = leaves
186 .into_iter()
187 .map(|(k, v)| {
188 let value = MergeValue::from_h256(v.to_h256());
189 if !value.is_zero() {
190 self.store.insert_leaf(k, v)
191 } else {
192 self.store.remove_leaf(&k)
193 }
194 .map(|_| (k, value, 0))
195 })
196 .collect::<Result<VecDeque<(H256, MergeValue, u8)>>>()?;
197
198 while let Some((current_key, current_merge_value, height)) = nodes.pop_front() {
199 let parent_key = current_key.parent_path(height);
200 let parent_branch_key = BranchKey::new(height, parent_key);
201
202 let mut right = None;
204 if !current_key.is_right(height) && !nodes.is_empty() {
205 let (neighbor_key, _, neighbor_height) = nodes.front().expect("nodes is not empty");
206 if neighbor_height.eq(&height) {
207 let mut right_key = current_key;
208 right_key.set_bit(height);
209 if neighbor_key.eq(&right_key) {
210 let (_, neighbor_value, _) = nodes.pop_front().expect("nodes is not empty");
211 right = Some(neighbor_value);
212 }
213 }
214 }
215
216 let (left, right) = if let Some(right_merge_value) = right {
217 (current_merge_value, right_merge_value)
218 } else {
219 if let Some(parent_branch) = self.store.get_branch(&parent_branch_key)? {
221 if current_key.is_right(height) {
222 (parent_branch.left, current_merge_value)
223 } else {
224 (current_merge_value, parent_branch.right)
225 }
226 } else if current_key.is_right(height) {
227 (MergeValue::zero(), current_merge_value)
228 } else {
229 (current_merge_value, MergeValue::zero())
230 }
231 };
232
233 if !left.is_zero() || !right.is_zero() {
234 self.store.insert_branch(
235 parent_branch_key,
236 BranchNode {
237 left: left.clone(),
238 right: right.clone(),
239 },
240 )?;
241 } else {
242 self.store.remove_branch(&parent_branch_key)?;
243 }
244 if height == core::u8::MAX {
245 self.root = merge::<H>(height, &parent_key, &left, &right).hash::<H>();
246 break;
247 } else {
248 nodes.push_back((
249 parent_key,
250 merge::<H>(height, &parent_key, &left, &right),
251 height + 1,
252 ));
253 }
254 }
255
256 Ok(&self.root)
257 }
258}
259
260impl<H: Hasher + Default, V: Value, S: StoreReadOps<V>> SparseMerkleTree<H, V, S> {
261 pub fn get(&self, key: &H256) -> Result<V> {
264 if self.is_empty() {
265 return Ok(V::zero());
266 }
267 Ok(self.store.get_leaf(key)?.unwrap_or_else(V::zero))
268 }
269
270 pub fn merkle_proof(&self, mut keys: Vec<H256>) -> Result<MerkleProof> {
272 if keys.is_empty() {
273 return Err(Error::EmptyKeys);
274 }
275
276 keys.sort_unstable();
278
279 let mut leaves_bitmap: Vec<H256> = Default::default();
281 for current_key in &keys {
282 let mut bitmap = H256::zero();
283 for height in 0..=core::u8::MAX {
284 let parent_key = current_key.parent_path(height);
285 let parent_branch_key = BranchKey::new(height, parent_key);
286 if let Some(parent_branch) = self.store.get_branch(&parent_branch_key)? {
287 let sibling = if current_key.is_right(height) {
288 parent_branch.left
289 } else {
290 parent_branch.right
291 };
292 if !sibling.is_zero() {
293 bitmap.set_bit(height);
294 }
295 } else {
296 }
298 }
299 leaves_bitmap.push(bitmap);
300 }
301
302 let mut proof: Vec<MergeValue> = Default::default();
303 let mut stack_fork_height = [0u8; MAX_STACK_SIZE]; let mut stack_top = 0;
305 let mut leaf_index = 0;
306 while leaf_index < keys.len() {
307 let leaf_key = keys[leaf_index];
308 let fork_height = if leaf_index + 1 < keys.len() {
309 leaf_key.fork_height(&keys[leaf_index + 1])
310 } else {
311 core::u8::MAX
312 };
313 for height in 0..=fork_height {
314 if height == fork_height && leaf_index + 1 < keys.len() {
315 break;
317 }
318 let parent_key = leaf_key.parent_path(height);
319 let is_right = leaf_key.is_right(height);
320
321 if stack_top > 0 && stack_fork_height[stack_top - 1] == height {
323 stack_top -= 1;
324 } else if leaves_bitmap[leaf_index].get_bit(height) {
325 let parent_branch_key = BranchKey::new(height, parent_key);
326 if let Some(parent_branch) = self.store.get_branch(&parent_branch_key)? {
327 let sibling = if is_right {
328 parent_branch.left
329 } else {
330 parent_branch.right
331 };
332 if !sibling.is_zero() {
333 proof.push(sibling);
334 } else {
335 unreachable!();
336 }
337 } else {
338 }
340 }
341 }
342 debug_assert!(stack_top < MAX_STACK_SIZE);
343 stack_fork_height[stack_top] = fork_height;
344 stack_top += 1;
345 leaf_index += 1;
346 }
347 assert_eq!(stack_top, 1);
348 Ok(MerkleProof::new(leaves_bitmap, proof))
349 }
350}