1use crate::{
2 error::{Error, Result},
3 merge::{merge, MergeValue},
4 traits::Hasher,
5 vec::Vec,
6 H256, MAX_STACK_SIZE,
7};
8
9#[derive(Debug, Clone, PartialEq, Eq)]
10pub struct MerkleProof {
11 leaves_bitmap: Vec<H256>,
13 merkle_path: Vec<MergeValue>,
15}
16
17impl MerkleProof {
18 pub fn new(leaves_bitmap: Vec<H256>, merkle_path: Vec<MergeValue>) -> Self {
22 MerkleProof {
23 leaves_bitmap,
24 merkle_path,
25 }
26 }
27
28 pub fn take(self) -> (Vec<H256>, Vec<MergeValue>) {
30 let MerkleProof {
31 leaves_bitmap,
32 merkle_path,
33 } = self;
34 (leaves_bitmap, merkle_path)
35 }
36
37 pub fn leaves_count(&self) -> usize {
39 self.leaves_bitmap.len()
40 }
41
42 pub fn leaves_bitmap(&self) -> &Vec<H256> {
44 &self.leaves_bitmap
45 }
46
47 pub fn merkle_path(&self) -> &Vec<MergeValue> {
49 &self.merkle_path
50 }
51
52 pub fn compile(self, mut leaves_keys: Vec<H256>) -> Result<CompiledMerkleProof> {
53 if leaves_keys.is_empty() {
54 return Err(Error::EmptyKeys);
55 } else if leaves_keys.len() != self.leaves_count() {
56 return Err(Error::IncorrectNumberOfLeaves {
57 expected: self.leaves_count(),
58 actual: leaves_keys.len(),
59 });
60 }
61 leaves_keys.sort_unstable();
63
64 let (leaves_bitmap, merkle_path) = self.take();
65
66 let mut proof: Vec<u8> = Vec::with_capacity(merkle_path.len() * 33 + leaves_keys.len());
67 let mut stack_fork_height = [0u8; MAX_STACK_SIZE]; let mut stack_top = 0;
69 let mut leaf_index = 0;
70 let mut merkle_path_index = 0;
71 while leaf_index < leaves_keys.len() {
72 let leaf_key = leaves_keys[leaf_index];
73 let fork_height = if leaf_index + 1 < leaves_keys.len() {
74 leaf_key.fork_height(&leaves_keys[leaf_index + 1])
75 } else {
76 core::u8::MAX
77 };
78 proof.push(0x4C);
79 let mut zero_count = 0u16;
80 for height in 0..=fork_height {
81 if height == fork_height && leaf_index + 1 < leaves_keys.len() {
82 break;
84 }
85 let (op_code_opt, sibling_data_opt): (_, Option<Vec<u8>>) =
86 if stack_top > 0 && stack_fork_height[stack_top - 1] == height {
87 stack_top -= 1;
88 (Some(0x48), None)
89 } else if leaves_bitmap[leaf_index].get_bit(height) {
90 if merkle_path_index >= merkle_path.len() {
91 return Err(Error::CorruptedProof);
92 }
93 let node = &merkle_path[merkle_path_index];
94 merkle_path_index += 1;
95 match node {
96 MergeValue::Value(v) => (Some(0x50), Some(v.as_slice().to_vec())),
97 MergeValue::MergeWithZero {
98 base_node,
99 zero_bits,
100 zero_count,
101 } => {
102 let mut buffer = crate::vec![*zero_count];
103 buffer.extend_from_slice(base_node.as_slice());
104 buffer.extend_from_slice(zero_bits.as_slice());
105 (Some(0x51), Some(buffer))
106 }
107 #[cfg(feature = "trie")]
108 _ => unreachable!(),
109 }
110 } else {
111 zero_count += 1;
112 if zero_count > 256 {
113 return Err(Error::CorruptedProof);
114 }
115 (None, None)
116 };
117 if let Some(op_code) = op_code_opt {
118 if zero_count > 0 {
119 let n = if zero_count == 256 {
120 0
121 } else {
122 zero_count as u8
123 };
124 proof.push(0x4F);
125 proof.push(n);
126 zero_count = 0;
127 }
128 proof.push(op_code);
129 }
130 if let Some(data) = sibling_data_opt {
131 proof.extend(&data);
132 }
133 }
134 if zero_count > 0 {
135 let n = if zero_count == 256 {
136 0
137 } else {
138 zero_count as u8
139 };
140 proof.push(0x4F);
141 proof.push(n);
142 }
143 debug_assert!(stack_top < MAX_STACK_SIZE);
144 stack_fork_height[stack_top] = fork_height;
145 stack_top += 1;
146 leaf_index += 1;
147 }
148
149 if stack_top != 1 {
150 return Err(Error::CorruptedProof);
151 }
152 if leaf_index != leaves_keys.len() {
153 return Err(Error::CorruptedProof);
154 }
155 if merkle_path_index != merkle_path.len() {
156 return Err(Error::CorruptedProof);
157 }
158 Ok(CompiledMerkleProof(proof))
159 }
160
161 pub fn compute_root<H: Hasher + Default>(self, leaves: Vec<(H256, H256)>) -> Result<H256> {
167 self.compile(leaves.iter().map(|(key, _value)| *key).collect())?
168 .compute_root::<H>(leaves)
169 }
170
171 pub fn verify<H: Hasher + Default>(
174 self,
175 root: &H256,
176 leaves: Vec<(H256, H256)>,
177 ) -> Result<bool> {
178 let calculated_root = self.compute_root::<H>(leaves)?;
179 Ok(&calculated_root == root)
180 }
181}
182
/// A Merkle proof compiled into a compact opcode program.
///
/// The program is interpreted against the proven leaves (sorted by key):
///
/// * `0x4C` (`L`): push the next leaf.
/// * `0x50` (`P`): merge the stack top with the following 32-byte sibling hash.
/// * `0x51` (`Q`): merge the stack top with the following 65-byte
///   merge-with-zero sibling (zero count, base node, zero bits).
/// * `0x48` (`H`): merge the two topmost stack items with each other.
/// * `0x4F` (`O`): merge the stack top with `n` zero siblings, where `n` is the
///   following byte and `0` stands for 256.
#[derive(Debug, Clone)]
pub struct CompiledMerkleProof(pub Vec<u8>);
186
187enum OpCodeContext<'a> {
189 L {
190 key: &'a H256,
191 },
192 P {
193 key: &'a H256,
194 height: u8,
195 program_index: usize,
196 },
197 Q {
198 key: &'a H256,
199 height: u8,
200 program_index: usize,
201 },
202 H {
203 key_a: &'a H256,
204 key_b: &'a H256,
205 height: u8,
206 value_a: &'a MergeValue,
207 value_b: &'a MergeValue,
208 },
209 O {
210 key: &'a H256,
211 height: u8,
212 n: u8,
213 },
214}
215
216impl CompiledMerkleProof {
217 fn compute_root_inner<H: Hasher + Default, F: FnMut(OpCodeContext) -> Result<()>>(
218 &self,
219 mut leaves: Vec<(H256, H256)>,
220 mut callback: F,
221 ) -> Result<H256> {
222 leaves.sort_unstable_by_key(|(k, _v)| *k);
223 let mut program_index = 0;
224 let mut leaf_index = 0;
225 let mut stack: Vec<(u16, H256, MergeValue)> = Vec::new();
226 while program_index < self.0.len() {
227 let code = self.0[program_index];
228 program_index += 1;
229 match code {
230 0x4C => {
232 if leaf_index >= leaves.len() {
233 return Err(Error::CorruptedStack);
234 }
235 let (k, v) = leaves[leaf_index];
236 callback(OpCodeContext::L { key: &k })?;
237 stack.push((0, k, MergeValue::from_h256(v)));
238 leaf_index += 1;
239 }
240 0x50 => {
242 if stack.is_empty() {
243 return Err(Error::CorruptedStack);
244 }
245 if program_index + 32 > self.0.len() {
246 return Err(Error::CorruptedProof);
247 }
248 let mut data = [0u8; 32];
249 data.copy_from_slice(&self.0[program_index..program_index + 32]);
250 program_index += 32;
251 let sibling_node = MergeValue::from_h256(H256::from(data));
252 let (height_u16, key, value) = stack.pop().unwrap();
253 if height_u16 > 255 {
254 return Err(Error::CorruptedProof);
255 }
256 let height = height_u16 as u8;
257 let parent_key = key.parent_path(height);
258 callback(OpCodeContext::P {
259 key: &key,
260 height,
261 program_index,
262 })?;
263 let parent = if key.get_bit(height) {
264 merge::<H>(height, &parent_key, &sibling_node, &value)
265 } else {
266 merge::<H>(height, &parent_key, &value, &sibling_node)
267 };
268 stack.push((height_u16 + 1, parent_key, parent));
269 }
270 0x51 => {
274 if stack.is_empty() {
275 return Err(Error::CorruptedStack);
276 }
277 if program_index + 65 > self.0.len() {
278 return Err(Error::CorruptedProof);
279 }
280 let zero_count = self.0[program_index];
281 let base_node = {
282 let mut data = [0u8; 32];
283 data.copy_from_slice(&self.0[program_index + 1..program_index + 33]);
284 H256::from(data)
285 };
286 let zero_bits = {
287 let mut data = [0u8; 32];
288 data.copy_from_slice(&self.0[program_index + 33..program_index + 65]);
289 H256::from(data)
290 };
291 program_index += 65;
292 let sibling_node = MergeValue::MergeWithZero {
293 base_node,
294 zero_bits,
295 zero_count,
296 };
297 let (height_u16, key, value) = stack.pop().unwrap();
298 if height_u16 > 255 {
299 return Err(Error::CorruptedProof);
300 }
301 let height = height_u16 as u8;
302 let parent_key = key.parent_path(height);
303 callback(OpCodeContext::Q {
304 key: &key,
305 height,
306 program_index,
307 })?;
308 let parent = if key.get_bit(height) {
309 merge::<H>(height, &parent_key, &sibling_node, &value)
310 } else {
311 merge::<H>(height, &parent_key, &value, &sibling_node)
312 };
313 stack.push((height_u16 + 1, parent_key, parent));
314 }
315 0x48 => {
317 if stack.len() < 2 {
318 return Err(Error::CorruptedStack);
319 }
320 let (height_b, key_b, value_b) = stack.pop().unwrap();
321 let (height_a, key_a, value_a) = stack.pop().unwrap();
322 if height_a != height_b {
323 return Err(Error::CorruptedProof);
324 }
325 if height_a > 255 {
326 return Err(Error::CorruptedProof);
327 }
328 let height_u16 = height_a;
329 let height = height_u16 as u8;
330 let parent_key_a = key_a.parent_path(height);
331 let parent_key_b = key_b.parent_path(height);
332 if parent_key_a != parent_key_b {
333 return Err(Error::CorruptedProof);
334 }
335 callback(OpCodeContext::H {
336 key_a: &key_a,
337 key_b: &key_b,
338 height,
339 value_a: &value_a,
340 value_b: &value_b,
341 })?;
342 let parent = if key_a.get_bit(height) {
343 merge::<H>(height, &parent_key_a, &value_b, &value_a)
344 } else {
345 merge::<H>(height, &parent_key_a, &value_a, &value_b)
346 };
347 stack.push((height_u16 + 1, parent_key_a, parent));
348 }
349 0x4F => {
351 if stack.is_empty() {
352 return Err(Error::CorruptedStack);
353 }
354 if program_index >= self.0.len() {
355 return Err(Error::CorruptedProof);
356 }
357 let n = self.0[program_index];
358 program_index += 1;
359 let zero_count: u16 = if n == 0 { 256 } else { n as u16 };
360 let (base_height, key, mut value) = stack.pop().unwrap();
361 if base_height > 255 {
362 return Err(Error::CorruptedProof);
363 }
364 callback(OpCodeContext::O {
365 key: &key,
366 height: base_height as u8,
367 n,
368 })?;
369 let mut parent_key = key;
370 let mut height_u16 = base_height;
371 for idx in 0..zero_count {
372 if base_height + idx > 255 {
373 return Err(Error::CorruptedProof);
374 }
375 height_u16 = base_height + idx;
376 let height = height_u16 as u8;
377 parent_key = key.parent_path(height);
378 value = if key.get_bit(height) {
379 merge::<H>(height, &parent_key, &MergeValue::zero(), &value)
380 } else {
381 merge::<H>(height, &parent_key, &value, &MergeValue::zero())
382 };
383 }
384 stack.push((height_u16 + 1, parent_key, value));
385 }
386 _ => return Err(Error::InvalidCode(code)),
387 }
388 debug_assert!(stack.len() <= MAX_STACK_SIZE);
389 }
390 if stack.len() != 1 {
391 return Err(Error::CorruptedStack);
392 }
393 if stack[0].0 != 256 {
394 return Err(Error::CorruptedProof);
395 }
396 if leaf_index != leaves.len() {
397 return Err(Error::CorruptedProof);
398 }
399 Ok(stack[0].2.hash::<H>())
400 }
401
402 pub fn extract_proof<H: Hasher + Default>(
407 &self,
408 all_leaves: Vec<(H256, H256, bool)>,
409 ) -> Result<CompiledMerkleProof> {
410 let mut leaves = Vec::with_capacity(all_leaves.len());
411 let mut sub_keys = Vec::new();
412 for (key, value, included) in all_leaves {
413 leaves.push((key, value));
414 if included {
415 sub_keys.push(key);
416 }
417 }
418
419 fn match_any_sub_key(key: &H256, height: u8, sub_keys: &[H256]) -> bool {
420 sub_keys.iter().any(|sub_key| {
421 if height == 0 {
422 key == sub_key
423 } else {
424 key == &sub_key.parent_path(height - 1)
425 }
426 })
427 }
428
429 let mut sub_proof = Vec::default();
430 let mut is_last_merge_zero = false;
431 let mut callback = |ctx: OpCodeContext| {
432 match ctx {
433 OpCodeContext::L { key } => {
434 if sub_keys.contains(key) {
435 sub_proof.push(0x4C);
436 is_last_merge_zero = false;
437 }
438 }
439 OpCodeContext::P {
440 key,
441 height,
442 program_index,
443 } => {
444 if match_any_sub_key(key, height, &sub_keys) {
445 sub_proof.push(0x50);
446 sub_proof.extend(&self.0[program_index - 32..program_index]);
447 is_last_merge_zero = false;
448 }
449 }
450 OpCodeContext::Q {
451 key,
452 height,
453 program_index,
454 } => {
455 if match_any_sub_key(key, height, &sub_keys) {
456 sub_proof.push(0x51);
457 sub_proof.extend(&self.0[program_index - 65..program_index]);
458 is_last_merge_zero = false;
459 }
460 }
461 OpCodeContext::H {
462 key_a,
463 key_b,
464 height,
465 value_a,
466 value_b,
467 } => {
468 let key_a_included = match_any_sub_key(key_a, height, &sub_keys);
469 let key_b_included = match_any_sub_key(key_b, height, &sub_keys);
470 if key_a_included && key_b_included {
471 sub_proof.push(0x48);
472 is_last_merge_zero = false;
473 } else if key_a_included || key_b_included {
474 let sibling_value = if key_a_included { &value_b } else { &value_a };
475 match sibling_value {
476 MergeValue::Value(hash) => {
477 if hash.is_zero() {
478 if is_last_merge_zero {
479 let last_n = *sub_proof.last().unwrap();
480 if last_n == 0 {
481 return Err(Error::CorruptedProof);
482 }
483 *sub_proof.last_mut().unwrap() = last_n.wrapping_add(1);
484 } else {
485 sub_proof.push(0x4F);
486 sub_proof.push(1);
487 is_last_merge_zero = true;
488 }
489 } else {
490 sub_proof.push(0x50);
491 sub_proof.extend(hash.as_slice());
492 is_last_merge_zero = false;
493 }
494 }
495 MergeValue::MergeWithZero {
496 base_node,
497 zero_bits,
498 zero_count,
499 } => {
500 sub_proof.push(0x51);
501 sub_proof.push(*zero_count);
502 sub_proof.extend(base_node.as_slice());
503 sub_proof.extend(zero_bits.as_slice());
504 is_last_merge_zero = false;
505 }
506 #[cfg(feature = "trie")]
507 _ => {}
508 };
509 }
510 }
511 OpCodeContext::O { key, height, n } => {
512 if match_any_sub_key(key, height, &sub_keys) {
513 if is_last_merge_zero {
514 let last_n = *sub_proof.last().unwrap();
515 if last_n == 0 || (last_n as u16 + n as u16) > 256 {
516 return Err(Error::CorruptedProof);
517 }
518 *sub_proof.last_mut().unwrap() = last_n.wrapping_add(n);
519 } else {
520 sub_proof.push(0x4F);
521 sub_proof.push(n);
522 is_last_merge_zero = true;
523 }
524 }
525 }
526 }
527 Ok(())
528 };
529 self.compute_root_inner::<H, _>(leaves, &mut callback)?;
530 Ok(CompiledMerkleProof(sub_proof))
531 }
532
533 pub fn compute_root<H: Hasher + Default>(&self, leaves: Vec<(H256, H256)>) -> Result<H256> {
534 self.compute_root_inner::<H, _>(leaves, |_| Ok(()))
535 }
536
537 pub fn verify<H: Hasher + Default>(
538 &self,
539 root: &H256,
540 leaves: Vec<(H256, H256)>,
541 ) -> Result<bool> {
542 let calculated_root = self.compute_root::<H>(leaves)?;
543 Ok(&calculated_root == root)
544 }
545}
546
547impl From<CompiledMerkleProof> for Vec<u8> {
548 fn from(proof: CompiledMerkleProof) -> Vec<u8> {
549 proof.0
550 }
551}