// sparse_merkle_tree/merkle_proof.rs

1use crate::{
2    error::{Error, Result},
3    merge::{merge, MergeValue},
4    traits::Hasher,
5    vec::Vec,
6    H256, MAX_STACK_SIZE,
7};
8
9#[derive(Debug, Clone, PartialEq, Eq)]
10pub struct MerkleProof {
11    // leaf bitmap, bitmap.get_bit(height) is true means there need a non zero sibling in this height
12    leaves_bitmap: Vec<H256>,
13    // needed sibling node hash
14    merkle_path: Vec<MergeValue>,
15}
16
17impl MerkleProof {
18    /// Create MerkleProof
19    /// leaves_bitmap: leaf bitmap, bitmap.get_bit(height) is true means there need a non zero sibling in this height
20    /// proof: needed sibling node hash
21    pub fn new(leaves_bitmap: Vec<H256>, merkle_path: Vec<MergeValue>) -> Self {
22        MerkleProof {
23            leaves_bitmap,
24            merkle_path,
25        }
26    }
27
28    /// Destruct the structure, useful for serialization
29    pub fn take(self) -> (Vec<H256>, Vec<MergeValue>) {
30        let MerkleProof {
31            leaves_bitmap,
32            merkle_path,
33        } = self;
34        (leaves_bitmap, merkle_path)
35    }
36
37    /// number of leaves required by this merkle proof
38    pub fn leaves_count(&self) -> usize {
39        self.leaves_bitmap.len()
40    }
41
42    /// return the inner leaves_bitmap vector
43    pub fn leaves_bitmap(&self) -> &Vec<H256> {
44        &self.leaves_bitmap
45    }
46
47    /// return sibling node hashes
48    pub fn merkle_path(&self) -> &Vec<MergeValue> {
49        &self.merkle_path
50    }
51
52    pub fn compile(self, mut leaves_keys: Vec<H256>) -> Result<CompiledMerkleProof> {
53        if leaves_keys.is_empty() {
54            return Err(Error::EmptyKeys);
55        } else if leaves_keys.len() != self.leaves_count() {
56            return Err(Error::IncorrectNumberOfLeaves {
57                expected: self.leaves_count(),
58                actual: leaves_keys.len(),
59            });
60        }
61        // sort leaves keys
62        leaves_keys.sort_unstable();
63
64        let (leaves_bitmap, merkle_path) = self.take();
65
66        let mut proof: Vec<u8> = Vec::with_capacity(merkle_path.len() * 33 + leaves_keys.len());
67        let mut stack_fork_height = [0u8; MAX_STACK_SIZE]; // store fork height
68        let mut stack_top = 0;
69        let mut leaf_index = 0;
70        let mut merkle_path_index = 0;
71        while leaf_index < leaves_keys.len() {
72            let leaf_key = leaves_keys[leaf_index];
73            let fork_height = if leaf_index + 1 < leaves_keys.len() {
74                leaf_key.fork_height(&leaves_keys[leaf_index + 1])
75            } else {
76                core::u8::MAX
77            };
78            proof.push(0x4C);
79            let mut zero_count = 0u16;
80            for height in 0..=fork_height {
81                if height == fork_height && leaf_index + 1 < leaves_keys.len() {
82                    // If it's not final round, we don't need to merge to root (height=255)
83                    break;
84                }
85                let (op_code_opt, sibling_data_opt): (_, Option<Vec<u8>>) =
86                    if stack_top > 0 && stack_fork_height[stack_top - 1] == height {
87                        stack_top -= 1;
88                        (Some(0x48), None)
89                    } else if leaves_bitmap[leaf_index].get_bit(height) {
90                        if merkle_path_index >= merkle_path.len() {
91                            return Err(Error::CorruptedProof);
92                        }
93                        let node = &merkle_path[merkle_path_index];
94                        merkle_path_index += 1;
95                        match node {
96                            MergeValue::Value(v) => (Some(0x50), Some(v.as_slice().to_vec())),
97                            MergeValue::MergeWithZero {
98                                base_node,
99                                zero_bits,
100                                zero_count,
101                            } => {
102                                let mut buffer = crate::vec![*zero_count];
103                                buffer.extend_from_slice(base_node.as_slice());
104                                buffer.extend_from_slice(zero_bits.as_slice());
105                                (Some(0x51), Some(buffer))
106                            }
107                            #[cfg(feature = "trie")]
108                            _ => unreachable!(),
109                        }
110                    } else {
111                        zero_count += 1;
112                        if zero_count > 256 {
113                            return Err(Error::CorruptedProof);
114                        }
115                        (None, None)
116                    };
117                if let Some(op_code) = op_code_opt {
118                    if zero_count > 0 {
119                        let n = if zero_count == 256 {
120                            0
121                        } else {
122                            zero_count as u8
123                        };
124                        proof.push(0x4F);
125                        proof.push(n);
126                        zero_count = 0;
127                    }
128                    proof.push(op_code);
129                }
130                if let Some(data) = sibling_data_opt {
131                    proof.extend(&data);
132                }
133            }
134            if zero_count > 0 {
135                let n = if zero_count == 256 {
136                    0
137                } else {
138                    zero_count as u8
139                };
140                proof.push(0x4F);
141                proof.push(n);
142            }
143            debug_assert!(stack_top < MAX_STACK_SIZE);
144            stack_fork_height[stack_top] = fork_height;
145            stack_top += 1;
146            leaf_index += 1;
147        }
148
149        if stack_top != 1 {
150            return Err(Error::CorruptedProof);
151        }
152        if leaf_index != leaves_keys.len() {
153            return Err(Error::CorruptedProof);
154        }
155        if merkle_path_index != merkle_path.len() {
156            return Err(Error::CorruptedProof);
157        }
158        Ok(CompiledMerkleProof(proof))
159    }
160
161    /// Compute root from proof
162    /// leaves: a vector of (key, value)
163    ///
164    /// return EmptyProof error when proof is empty
165    /// return CorruptedProof error when proof is invalid
166    pub fn compute_root<H: Hasher + Default>(self, leaves: Vec<(H256, H256)>) -> Result<H256> {
167        self.compile(leaves.iter().map(|(key, _value)| *key).collect())?
168            .compute_root::<H>(leaves)
169    }
170
171    /// Verify merkle proof
172    /// see compute_root_from_proof
173    pub fn verify<H: Hasher + Default>(
174        self,
175        root: &H256,
176        leaves: Vec<(H256, H256)>,
177    ) -> Result<bool> {
178        let calculated_root = self.compute_root::<H>(leaves)?;
179        Ok(&calculated_root == root)
180    }
181}
182
/// A Merkle proof compiled into a compact byte program, optimized for
/// recomputing (and thus verifying) a root.
///
/// The raw program bytes are exposed through the public field `.0`.
#[derive(Clone, Debug)]
pub struct CompiledMerkleProof(pub Vec<u8>);
186
187// A op code context passing to the callback function
188enum OpCodeContext<'a> {
189    L {
190        key: &'a H256,
191    },
192    P {
193        key: &'a H256,
194        height: u8,
195        program_index: usize,
196    },
197    Q {
198        key: &'a H256,
199        height: u8,
200        program_index: usize,
201    },
202    H {
203        key_a: &'a H256,
204        key_b: &'a H256,
205        height: u8,
206        value_a: &'a MergeValue,
207        value_b: &'a MergeValue,
208    },
209    O {
210        key: &'a H256,
211        height: u8,
212        n: u8,
213    },
214}
215
216impl CompiledMerkleProof {
217    fn compute_root_inner<H: Hasher + Default, F: FnMut(OpCodeContext) -> Result<()>>(
218        &self,
219        mut leaves: Vec<(H256, H256)>,
220        mut callback: F,
221    ) -> Result<H256> {
222        leaves.sort_unstable_by_key(|(k, _v)| *k);
223        let mut program_index = 0;
224        let mut leaf_index = 0;
225        let mut stack: Vec<(u16, H256, MergeValue)> = Vec::new();
226        while program_index < self.0.len() {
227            let code = self.0[program_index];
228            program_index += 1;
229            match code {
230                // L : push leaf value
231                0x4C => {
232                    if leaf_index >= leaves.len() {
233                        return Err(Error::CorruptedStack);
234                    }
235                    let (k, v) = leaves[leaf_index];
236                    callback(OpCodeContext::L { key: &k })?;
237                    stack.push((0, k, MergeValue::from_h256(v)));
238                    leaf_index += 1;
239                }
240                // P : hash stack top item with sibling node in proof
241                0x50 => {
242                    if stack.is_empty() {
243                        return Err(Error::CorruptedStack);
244                    }
245                    if program_index + 32 > self.0.len() {
246                        return Err(Error::CorruptedProof);
247                    }
248                    let mut data = [0u8; 32];
249                    data.copy_from_slice(&self.0[program_index..program_index + 32]);
250                    program_index += 32;
251                    let sibling_node = MergeValue::from_h256(H256::from(data));
252                    let (height_u16, key, value) = stack.pop().unwrap();
253                    if height_u16 > 255 {
254                        return Err(Error::CorruptedProof);
255                    }
256                    let height = height_u16 as u8;
257                    let parent_key = key.parent_path(height);
258                    callback(OpCodeContext::P {
259                        key: &key,
260                        height,
261                        program_index,
262                    })?;
263                    let parent = if key.get_bit(height) {
264                        merge::<H>(height, &parent_key, &sibling_node, &value)
265                    } else {
266                        merge::<H>(height, &parent_key, &value, &sibling_node)
267                    };
268                    stack.push((height_u16 + 1, parent_key, parent));
269                }
270                // Q : hash stack top item with sibling node in proof,
271                // this is similar to P except that proof comes in using
272                // MergeWithZero format.
273                0x51 => {
274                    if stack.is_empty() {
275                        return Err(Error::CorruptedStack);
276                    }
277                    if program_index + 65 > self.0.len() {
278                        return Err(Error::CorruptedProof);
279                    }
280                    let zero_count = self.0[program_index];
281                    let base_node = {
282                        let mut data = [0u8; 32];
283                        data.copy_from_slice(&self.0[program_index + 1..program_index + 33]);
284                        H256::from(data)
285                    };
286                    let zero_bits = {
287                        let mut data = [0u8; 32];
288                        data.copy_from_slice(&self.0[program_index + 33..program_index + 65]);
289                        H256::from(data)
290                    };
291                    program_index += 65;
292                    let sibling_node = MergeValue::MergeWithZero {
293                        base_node,
294                        zero_bits,
295                        zero_count,
296                    };
297                    let (height_u16, key, value) = stack.pop().unwrap();
298                    if height_u16 > 255 {
299                        return Err(Error::CorruptedProof);
300                    }
301                    let height = height_u16 as u8;
302                    let parent_key = key.parent_path(height);
303                    callback(OpCodeContext::Q {
304                        key: &key,
305                        height,
306                        program_index,
307                    })?;
308                    let parent = if key.get_bit(height) {
309                        merge::<H>(height, &parent_key, &sibling_node, &value)
310                    } else {
311                        merge::<H>(height, &parent_key, &value, &sibling_node)
312                    };
313                    stack.push((height_u16 + 1, parent_key, parent));
314                }
315                // H : pop 2 items in stack hash them then push the result
316                0x48 => {
317                    if stack.len() < 2 {
318                        return Err(Error::CorruptedStack);
319                    }
320                    let (height_b, key_b, value_b) = stack.pop().unwrap();
321                    let (height_a, key_a, value_a) = stack.pop().unwrap();
322                    if height_a != height_b {
323                        return Err(Error::CorruptedProof);
324                    }
325                    if height_a > 255 {
326                        return Err(Error::CorruptedProof);
327                    }
328                    let height_u16 = height_a;
329                    let height = height_u16 as u8;
330                    let parent_key_a = key_a.parent_path(height);
331                    let parent_key_b = key_b.parent_path(height);
332                    if parent_key_a != parent_key_b {
333                        return Err(Error::CorruptedProof);
334                    }
335                    callback(OpCodeContext::H {
336                        key_a: &key_a,
337                        key_b: &key_b,
338                        height,
339                        value_a: &value_a,
340                        value_b: &value_b,
341                    })?;
342                    let parent = if key_a.get_bit(height) {
343                        merge::<H>(height, &parent_key_a, &value_b, &value_a)
344                    } else {
345                        merge::<H>(height, &parent_key_a, &value_a, &value_b)
346                    };
347                    stack.push((height_u16 + 1, parent_key_a, parent));
348                }
349                // O : hash stack top item with n zero values
350                0x4F => {
351                    if stack.is_empty() {
352                        return Err(Error::CorruptedStack);
353                    }
354                    if program_index >= self.0.len() {
355                        return Err(Error::CorruptedProof);
356                    }
357                    let n = self.0[program_index];
358                    program_index += 1;
359                    let zero_count: u16 = if n == 0 { 256 } else { n as u16 };
360                    let (base_height, key, mut value) = stack.pop().unwrap();
361                    if base_height > 255 {
362                        return Err(Error::CorruptedProof);
363                    }
364                    callback(OpCodeContext::O {
365                        key: &key,
366                        height: base_height as u8,
367                        n,
368                    })?;
369                    let mut parent_key = key;
370                    let mut height_u16 = base_height;
371                    for idx in 0..zero_count {
372                        if base_height + idx > 255 {
373                            return Err(Error::CorruptedProof);
374                        }
375                        height_u16 = base_height + idx;
376                        let height = height_u16 as u8;
377                        parent_key = key.parent_path(height);
378                        value = if key.get_bit(height) {
379                            merge::<H>(height, &parent_key, &MergeValue::zero(), &value)
380                        } else {
381                            merge::<H>(height, &parent_key, &value, &MergeValue::zero())
382                        };
383                    }
384                    stack.push((height_u16 + 1, parent_key, value));
385                }
386                _ => return Err(Error::InvalidCode(code)),
387            }
388            debug_assert!(stack.len() <= MAX_STACK_SIZE);
389        }
390        if stack.len() != 1 {
391            return Err(Error::CorruptedStack);
392        }
393        if stack[0].0 != 256 {
394            return Err(Error::CorruptedProof);
395        }
396        if leaf_index != leaves.len() {
397            return Err(Error::CorruptedProof);
398        }
399        Ok(stack[0].2.hash::<H>())
400    }
401
402    /// Extract sub compiled proof for certain sub leaves from current compiled proof.
403    ///
404    /// The argument must include all leaves. The 3rd item of every tuple
405    /// indicate if the sub key is selected.
406    pub fn extract_proof<H: Hasher + Default>(
407        &self,
408        all_leaves: Vec<(H256, H256, bool)>,
409    ) -> Result<CompiledMerkleProof> {
410        let mut leaves = Vec::with_capacity(all_leaves.len());
411        let mut sub_keys = Vec::new();
412        for (key, value, included) in all_leaves {
413            leaves.push((key, value));
414            if included {
415                sub_keys.push(key);
416            }
417        }
418
419        fn match_any_sub_key(key: &H256, height: u8, sub_keys: &[H256]) -> bool {
420            sub_keys.iter().any(|sub_key| {
421                if height == 0 {
422                    key == sub_key
423                } else {
424                    key == &sub_key.parent_path(height - 1)
425                }
426            })
427        }
428
429        let mut sub_proof = Vec::default();
430        let mut is_last_merge_zero = false;
431        let mut callback = |ctx: OpCodeContext| {
432            match ctx {
433                OpCodeContext::L { key } => {
434                    if sub_keys.contains(key) {
435                        sub_proof.push(0x4C);
436                        is_last_merge_zero = false;
437                    }
438                }
439                OpCodeContext::P {
440                    key,
441                    height,
442                    program_index,
443                } => {
444                    if match_any_sub_key(key, height, &sub_keys) {
445                        sub_proof.push(0x50);
446                        sub_proof.extend(&self.0[program_index - 32..program_index]);
447                        is_last_merge_zero = false;
448                    }
449                }
450                OpCodeContext::Q {
451                    key,
452                    height,
453                    program_index,
454                } => {
455                    if match_any_sub_key(key, height, &sub_keys) {
456                        sub_proof.push(0x51);
457                        sub_proof.extend(&self.0[program_index - 65..program_index]);
458                        is_last_merge_zero = false;
459                    }
460                }
461                OpCodeContext::H {
462                    key_a,
463                    key_b,
464                    height,
465                    value_a,
466                    value_b,
467                } => {
468                    let key_a_included = match_any_sub_key(key_a, height, &sub_keys);
469                    let key_b_included = match_any_sub_key(key_b, height, &sub_keys);
470                    if key_a_included && key_b_included {
471                        sub_proof.push(0x48);
472                        is_last_merge_zero = false;
473                    } else if key_a_included || key_b_included {
474                        let sibling_value = if key_a_included { &value_b } else { &value_a };
475                        match sibling_value {
476                            MergeValue::Value(hash) => {
477                                if hash.is_zero() {
478                                    if is_last_merge_zero {
479                                        let last_n = *sub_proof.last().unwrap();
480                                        if last_n == 0 {
481                                            return Err(Error::CorruptedProof);
482                                        }
483                                        *sub_proof.last_mut().unwrap() = last_n.wrapping_add(1);
484                                    } else {
485                                        sub_proof.push(0x4F);
486                                        sub_proof.push(1);
487                                        is_last_merge_zero = true;
488                                    }
489                                } else {
490                                    sub_proof.push(0x50);
491                                    sub_proof.extend(hash.as_slice());
492                                    is_last_merge_zero = false;
493                                }
494                            }
495                            MergeValue::MergeWithZero {
496                                base_node,
497                                zero_bits,
498                                zero_count,
499                            } => {
500                                sub_proof.push(0x51);
501                                sub_proof.push(*zero_count);
502                                sub_proof.extend(base_node.as_slice());
503                                sub_proof.extend(zero_bits.as_slice());
504                                is_last_merge_zero = false;
505                            }
506                            #[cfg(feature = "trie")]
507                            _ => {}
508                        };
509                    }
510                }
511                OpCodeContext::O { key, height, n } => {
512                    if match_any_sub_key(key, height, &sub_keys) {
513                        if is_last_merge_zero {
514                            let last_n = *sub_proof.last().unwrap();
515                            if last_n == 0 || (last_n as u16 + n as u16) > 256 {
516                                return Err(Error::CorruptedProof);
517                            }
518                            *sub_proof.last_mut().unwrap() = last_n.wrapping_add(n);
519                        } else {
520                            sub_proof.push(0x4F);
521                            sub_proof.push(n);
522                            is_last_merge_zero = true;
523                        }
524                    }
525                }
526            }
527            Ok(())
528        };
529        self.compute_root_inner::<H, _>(leaves, &mut callback)?;
530        Ok(CompiledMerkleProof(sub_proof))
531    }
532
533    pub fn compute_root<H: Hasher + Default>(&self, leaves: Vec<(H256, H256)>) -> Result<H256> {
534        self.compute_root_inner::<H, _>(leaves, |_| Ok(()))
535    }
536
537    pub fn verify<H: Hasher + Default>(
538        &self,
539        root: &H256,
540        leaves: Vec<(H256, H256)>,
541    ) -> Result<bool> {
542        let calculated_root = self.compute_root::<H>(leaves)?;
543        Ok(&calculated_root == root)
544    }
545}
546
547impl From<CompiledMerkleProof> for Vec<u8> {
548    fn from(proof: CompiledMerkleProof) -> Vec<u8> {
549        proof.0
550    }
551}