use crate::{
collections::{BTreeMap, VecDeque},
error::{Error, Result},
merge::{hash_leaf, merge},
merkle_proof::MerkleProof,
traits::{Hasher, Store, Value},
vec::Vec,
EXPECTED_PATH_SIZE, H256,
};
use core::{cmp::max, marker::PhantomData};
/// A branch entry in the node store, saved under its own hash.
///
/// `SparseMerkleTree::update` creates these with `insert_branch(hash, BranchNode { .. })`:
/// one `NodeType::Single` entry per stored leaf, and one `NodeType::Pair` entry
/// per merged parent.
#[derive(Debug, Eq, PartialEq, Clone)]
pub struct BranchNode {
    /// Height at which this branch's children are merged; `0` for a
    /// `NodeType::Single` leaf entry.
    pub fork_height: u8,
    /// Key of the leaf whose update produced this entry. `node_at` uses its
    /// bits to tell which child of a `Pair` is the left one.
    pub key: H256,
    /// The child hash(es) this entry points to.
    pub node_type: NodeType,
}
impl BranchNode {
    /// Returns this branch's children oriented for traversal at `height`.
    ///
    /// A `Pair` is stored as `(node, sibling)`, with `node` on `self.key`'s side.
    /// When `self.key` has its bit set at `height`, `node` is the right child,
    /// so the pair is flipped to yield `(left, right)`. A `Single` is returned
    /// unchanged.
    fn node_at(&self, height: u8) -> NodeType {
        if let NodeType::Pair(own_side, other_side) = self.node_type {
            let (left, right) = if self.key.get_bit(height) {
                (other_side, own_side)
            } else {
                (own_side, other_side)
            };
            NodeType::Pair(left, right)
        } else {
            self.node_type.clone()
        }
    }

    /// Borrows the key this branch entry was built for.
    fn key(&self) -> &H256 {
        &self.key
    }
}
/// Contents of a [`BranchNode`].
#[derive(Debug, Eq, PartialEq, Clone)]
pub enum NodeType {
    /// A leaf entry. It holds the leaf hash, which `update` also uses as the
    /// key the branch entry is stored under.
    Single(H256),
    /// Two child hashes as stored: `(node, sibling)`, where `node` is the child
    /// on the owning `BranchNode::key`'s side. Use `BranchNode::node_at` to
    /// obtain them as `(left, right)`.
    Pair(H256, H256),
}
/// A stored leaf: a key and its value.
///
/// `update` stores it under `hash_leaf(key, value)`.
#[derive(Debug, Eq, PartialEq, Clone)]
pub struct LeafNode<V> {
    /// The leaf's key.
    pub key: H256,
    /// The value written for `key`.
    pub value: V,
}
/// A sparse Merkle tree keyed by `H256`, with its nodes kept in a `Store`.
///
/// `H` is the hasher, `V` the value type and `S` the node store. A zero `root`
/// is treated as the empty tree (see `is_empty`).
#[derive(Default, Debug)]
pub struct SparseMerkleTree<H, V, S> {
    /// Backing storage for branch and leaf entries.
    store: S,
    /// Current root hash; zero means empty.
    root: H256,
    /// Ties the hasher and value types to the tree without storing them.
    phantom: PhantomData<(H, V)>,
}
impl<H: Hasher + Default, V: Value, S: Store<V>> SparseMerkleTree<H, V, S> {
    /// Creates a tree handle from an existing `root` hash and the `store`
    /// holding its nodes.
    ///
    /// Nothing is validated here. A branch reachable from `root` that is
    /// missing from `store` is reported later as `Error::MissingBranch`.
    pub fn new(root: H256, store: S) -> SparseMerkleTree<H, V, S> {
        SparseMerkleTree {
            root,
            store,
            phantom: PhantomData,
        }
    }

    /// Returns the current root hash.
    pub fn root(&self) -> &H256 {
        &self.root
    }

    /// Returns `true` when the root hash is zero, which this type treats as
    /// the empty tree.
    pub fn is_empty(&self) -> bool {
        self.root.is_zero()
    }

    /// Consumes the tree and returns the underlying store.
    pub fn take_store(self) -> S {
        self.store
    }

    /// Borrows the underlying store.
    pub fn store(&self) -> &S {
        &self.store
    }

    /// Mutably borrows the underlying store.
    ///
    /// Changes made through this reference do not update `root`.
    pub fn store_mut(&mut self) -> &mut S {
        &mut self.store
    }

    /// Sets `key` to `value` and returns the new root hash.
    ///
    /// The update runs in two phases:
    ///
    /// 1. Walk from the root toward `key`. Remove the branch entries that lie
    ///    on its path and collect the `(height, sibling)` pairs needed to
    ///    re-hash.
    /// 2. Rebuild bottom-up from the new leaf, inserting fresh branch entries.
    ///
    /// If the new leaf hash is zero, no leaf or branch is stored for it, and
    /// parents are computed by `merge` with a zero child. NOTE(review): this
    /// presumably makes writing `V::zero()` a deletion (`hash_leaf` yielding
    /// zero, `merge` returning the non-zero side). Confirm in `merge`.
    ///
    /// # Errors
    ///
    /// Returns `Error::MissingBranch` if a branch on the path is absent from
    /// the store, and propagates any error from the store operations.
    pub fn update(&mut self, key: H256, value: V) -> Result<&H256> {
        // `(height, sibling_hash)` pairs, collected top-down along `key`'s path.
        let mut path = Vec::new();
        if !self.is_empty() {
            let mut node = self.root;
            loop {
                let branch_node = self
                    .store
                    .get_branch(&node)?
                    .ok_or_else(|| Error::MissingBranch(node))?;
                // Height at which `key` leaves this branch: where it forks from
                // the branch's key, but never below the branch's own fork height.
                let height = max(key.fork_height(branch_node.key()), branch_node.fork_height);
                match branch_node.node_at(height) {
                    NodeType::Pair(left, right) => {
                        if height > branch_node.fork_height {
                            // `key` diverges above this subtree's fork point, so
                            // the whole subtree is the sibling at `height`. Its
                            // entry stays in the store.
                            path.push((height, node));
                            break;
                        } else {
                            // This branch is on `key`'s path and is rebuilt
                            // below, so drop the old entry and descend.
                            self.store.remove_branch(&node)?;
                            let is_right = key.get_bit(height);
                            if is_right {
                                node = right;
                                path.push((height, left));
                            } else {
                                node = left;
                                path.push((height, right));
                            }
                        }
                    }
                    // Leaf entry. Leaves inserted by this method are stored
                    // under their own leaf hash, so this inner `node` shadows
                    // the outer one with the same hash.
                    NodeType::Single(node) => {
                        if &key == branch_node.key() {
                            // Same key: the old leaf is being replaced.
                            self.store.remove_leaf(&node)?;
                            self.store.remove_branch(&node)?;
                        } else {
                            // Different key: the existing leaf becomes the
                            // sibling at the fork height.
                            path.push((height, node));
                        }
                        break;
                    }
                }
            }
        }
        // Hash the new leaf. A zero hash is neither stored nor, below, recorded
        // as a branch child.
        let mut node = hash_leaf::<H>(&key, &value.to_h256());
        if !node.is_zero() {
            self.store.insert_leaf(node, LeafNode { key, value })?;
            // Each leaf is also registered as a `Single` branch under the same
            // hash, so traversals end on a branch entry.
            self.store.insert_branch(
                node,
                BranchNode {
                    key,
                    fork_height: 0,
                    node_type: NodeType::Single(node),
                },
            )?;
        }
        // Re-hash bottom-up; `path` was collected top-down, hence `rev()`.
        for (height, sibling) in path.into_iter().rev() {
            // `key`'s bit at `height` says whether the current node is the
            // right child.
            let is_right = key.get_bit(height);
            let parent = if is_right {
                merge::<H>(&sibling, &node)
            } else {
                merge::<H>(&node, &sibling)
            };
            // No branch is recorded above a zero child. NOTE(review): if `merge`
            // returns the non-zero side, `parent == sibling` here, and this
            // check keeps the sibling's existing entry from being overwritten.
            // Confirm in `merge`.
            if !node.is_zero() {
                let branch_node = BranchNode {
                    key,
                    fork_height: height,
                    node_type: NodeType::Pair(node, sibling),
                };
                self.store.insert_branch(parent, branch_node)?;
            }
            node = parent;
        }
        self.root = node;
        Ok(&self.root)
    }

    /// Returns the value stored under `key`, or `V::zero()` if the tree is
    /// empty or `key` is absent.
    ///
    /// # Errors
    ///
    /// Returns `Error::MissingBranch` or `Error::MissingLeaf` if a node
    /// reachable from the root is missing from the store, and propagates
    /// store errors.
    pub fn get(&self, key: &H256) -> Result<V> {
        if self.is_empty() {
            return Ok(V::zero());
        }
        let mut node = self.root;
        loop {
            let branch_node = self
                .store
                .get_branch(&node)?
                .ok_or_else(|| Error::MissingBranch(node))?;
            // Orient the children as (left, right), then follow `key`'s bit at
            // this branch's fork height.
            match branch_node.node_at(branch_node.fork_height) {
                NodeType::Pair(left, right) => {
                    let is_right = key.get_bit(branch_node.fork_height);
                    node = if is_right { right } else { left };
                }
                // Only fork-height bits were checked on the way down, so the
                // full key must match before the leaf's value is returned.
                NodeType::Single(node) => {
                    if key == branch_node.key() {
                        return Ok(self
                            .store
                            .get_leaf(&node)?
                            .ok_or_else(|| Error::MissingLeaf(node))?
                            .value);
                    } else {
                        return Ok(V::zero());
                    }
                }
            }
        }
    }

    /// Walks from the root toward `key` and records in `cache` each sibling
    /// hash met on the way, keyed by `(height, sibling_key)`.
    ///
    /// `sibling_key` is `key.parent_path(height)` with the bit at `height` set
    /// when the sibling is on the right. NOTE(review): the left-sibling case
    /// assumes `parent_path(height)` already clears the bit at `height`.
    /// Confirm in the `H256` implementation.
    ///
    /// Must only be called on a non-empty tree; the root is read unconditionally.
    ///
    /// # Errors
    ///
    /// Returns `Error::MissingBranch` if a branch on the path is absent, and
    /// propagates store errors.
    fn fetch_merkle_path(&self, key: &H256, cache: &mut BTreeMap<(u8, H256), H256>) -> Result<()> {
        let mut node = self.root;
        loop {
            let branch_node = self
                .store
                .get_branch(&node)?
                .ok_or_else(|| Error::MissingBranch(node))?;
            // Same "leave height" computation as in `update`.
            let height = max(key.fork_height(branch_node.key()), branch_node.fork_height);
            let is_right = key.get_bit(height);
            let mut sibling_key = key.parent_path(height);
            if !is_right {
                sibling_key.set_bit(height);
            };
            match branch_node.node_at(height) {
                NodeType::Pair(left, right) => {
                    if height > branch_node.fork_height {
                        // `key` diverges above this subtree, so the whole
                        // subtree is the sibling. Unlike the `insert` below,
                        // this keeps any entry already recorded for this slot.
                        cache.entry((height, sibling_key)).or_insert(node);
                        break;
                    } else {
                        let sibling;
                        if is_right {
                            // A child equal to the current branch hash would
                            // re-read the same branch forever; stop instead.
                            if node == right {
                                break;
                            }
                            sibling = left;
                            node = right;
                        } else {
                            // Same self-reference guard for the left child.
                            if node == left {
                                break;
                            }
                            sibling = right;
                            node = left;
                        }
                        cache.insert((height, sibling_key), sibling);
                    }
                }
                NodeType::Single(node) => {
                    // A different leaf is the sibling at the fork height; the
                    // key's own leaf needs no entry.
                    if key != branch_node.key() {
                        cache.insert((height, sibling_key), node);
                    }
                    break;
                }
            }
        }
        Ok(())
    }

    /// Builds a `MerkleProof` for `keys`.
    ///
    /// The proof is assembled in three steps:
    ///
    /// 1. Sort the keys.
    /// 2. Gather each key's sibling hashes into `cache` via `fetch_merkle_path`.
    /// 3. Climb from height 0 with a FIFO queue of `(key, height, leaf_index)`.
    ///    When the next queued entry is the current entry's sibling at the same
    ///    height, the two merge without a proof hash. Otherwise a cached
    ///    sibling hash is taken and appended to the proof.
    ///
    /// The result holds, per sorted key, the heights at which a merge occurs
    /// (`leaves_path`, indexed by position in the sorted `keys`), together
    /// with the ordered `(sibling_hash, height)` list (`proof`). A leaf that
    /// records no height gets a single `u8::MAX` entry. NOTE(review): how
    /// `MerkleProof` interprets that marker is defined elsewhere.
    ///
    /// # Errors
    ///
    /// Returns `Error::EmptyKeys` if `keys` is empty, and propagates errors
    /// from `fetch_merkle_path`.
    pub fn merkle_proof(&self, mut keys: Vec<H256>) -> Result<MerkleProof> {
        if keys.is_empty() {
            return Err(Error::EmptyKeys);
        }
        // Sorting places siblings next to each other in the queue below.
        keys.sort_unstable();
        let mut cache: BTreeMap<(u8, H256), H256> = Default::default();
        if !self.is_empty() {
            for k in &keys {
                self.fetch_merkle_path(k, &mut cache)?;
            }
        }
        let mut proof: Vec<(H256, u8)> = Vec::with_capacity(EXPECTED_PATH_SIZE * keys.len());
        let mut leaves_path: Vec<Vec<u8>> = Vec::with_capacity(keys.len());
        leaves_path.resize_with(keys.len(), Default::default);
        let keys_len = keys.len();
        let mut queue: VecDeque<(H256, u8, usize)> = keys
            .into_iter()
            .enumerate()
            .map(|(i, k)| (k, 0, i))
            .collect();
        while let Some((key, height, leaf_index)) = queue.pop_front() {
            // Last entry and no cached siblings left: nothing more to merge.
            if queue.is_empty() && cache.is_empty() {
                if leaves_path[leaf_index].is_empty() {
                    leaves_path[leaf_index].push(core::u8::MAX);
                }
                break;
            }
            // Key of the sibling subtree at `height`: same prefix, opposite bit
            // at `height`.
            let mut sibling_key = key.parent_path(height);
            let is_right = key.get_bit(height);
            if is_right {
                sibling_key.clear_bit(height);
            } else {
                sibling_key.set_bit(height);
            }
            if Some((&sibling_key, &height))
                == queue
                    .front()
                    .map(|(sibling_key, height, _leaf_index)| (sibling_key, height))
            {
                // The sibling is the next queued entry: consume it and record
                // the merge height on its leaf as well (the current leaf is
                // recorded after this `if`).
                let (_sibling_key, height, leaf_index) = queue.pop_front().unwrap();
                leaves_path[leaf_index].push(height);
            } else {
                match cache.remove(&(height, sibling_key)) {
                    Some(sibling) => {
                        proof.push((sibling, height));
                    }
                    None => {
                        // No sibling at this height: move up one level without
                        // recording a merge. Clearing the bit turns
                        // `sibling_key` into the parent's path.
                        if !is_right {
                            sibling_key.clear_bit(height);
                        }
                        if height == core::u8::MAX {
                            // Top height reached; there is no parent to move to.
                            if leaves_path[leaf_index].is_empty() {
                                leaves_path[leaf_index].push(height);
                            }
                            break;
                        } else {
                            let parent_key = sibling_key;
                            queue.push_back((parent_key, height + 1, leaf_index));
                            continue;
                        }
                    }
                }
            }
            // A merge (with the queued sibling or a cached hash) happened here.
            leaves_path[leaf_index].push(height);
            if height == core::u8::MAX {
                break;
            } else {
                // Parent path: for a right child this is `sibling_key` (bit
                // cleared); for a left child, `key`'s bit at `height` is
                // already 0.
                let parent_key = if is_right { sibling_key } else { key };
                queue.push_back((parent_key, height + 1, leaf_index));
            }
        }
        debug_assert_eq!(leaves_path.len(), keys_len);
        Ok(MerkleProof::new(leaves_path, proof))
    }
}