use crate::{
changelog::ChangeLog,
error::ConcurrentMerkleTreeError,
hash::{fill_in_proof, hash_to_parent, recompute},
node::{empty_node, empty_node_cached, Node, EMPTY},
path::Path,
};
use bytemuck::{Pod, Zeroable};
use log_compute;
use solana_logging;
/// Panics unless the const parameters describe a supported tree shape.
///
/// * `max_depth` must be below 31 (presumably to keep `1 << max_depth` and the change-log bit
///   math inside `u32` — confirm against `changelog`).
/// * `max_buffer_size` must be a non-zero power of two. Ring-buffer indices are wrapped with
///   `& (MAX_BUFFER_SIZE - 1)`, which only acts as a modulo for powers of two. A size of 0
///   would also leave no slot for change log 0. `is_power_of_two` rejects 0 explicitly, so
///   there is no `0 - 1` underflow that panics in debug builds and silently passes in release.
///
/// # Panics
/// If either condition above does not hold.
#[inline(always)]
fn check_bounds(max_depth: usize, max_buffer_size: usize) {
    assert!(
        max_depth < 31,
        "max_depth must be at most 30, got {}",
        max_depth
    );
    assert!(
        max_buffer_size.is_power_of_two(),
        "max_buffer_size must be a non-zero power of two, got {}",
        max_buffer_size
    );
}
/// Checks that `leaf_index` addresses a leaf of a tree with `max_depth` levels, i.e. that
/// `leaf_index < 2^max_depth`.
///
/// For `max_depth >= 32`, every `u32` is below `2^max_depth`, so the check short-circuits.
/// Evaluating `1u32 << max_depth` there would be an out-of-range shift: it panics in debug
/// builds and masks the shift amount in release builds.
///
/// # Errors
/// `ConcurrentMerkleTreeError::LeafIndexOutOfBounds` if the index is too large.
fn check_leaf_index(leaf_index: u32, max_depth: usize) -> Result<(), ConcurrentMerkleTreeError> {
    if max_depth < 32 && leaf_index >= 1u32 << max_depth {
        return Err(ConcurrentMerkleTreeError::LeafIndexOutOfBounds);
    }
    Ok(())
}
/// Fixed-capacity Merkle tree of depth `MAX_DEPTH` that remembers its last `MAX_BUFFER_SIZE`
/// change logs.
///
/// The buffered change logs let a proof built against a recent (possibly stale) root be
/// fast-forwarded to the current root before it is checked. The rightmost proof lets
/// `append` add leaves without a caller-supplied proof.
///
/// The struct is `#[repr(C)]` and implements bytemuck `Pod`, so its in-memory layout is part
/// of its contract: do not reorder, add, or remove fields.
#[repr(C)]
#[derive(Copy, Clone)]
pub struct ConcurrentMerkleTree<const MAX_DEPTH: usize, const MAX_BUFFER_SIZE: usize> {
/// Number of updates applied. `initialize` sets it to 0, `initialize_with_root` sets it to 1,
/// and every change bumps it with a saturating add.
pub sequence_number: u64,
/// Slot in `change_logs` holding the most recent change log; wraps modulo `MAX_BUFFER_SIZE`.
pub active_index: u64,
/// Number of populated change-log slots; grows by one per change until it reaches
/// `MAX_BUFFER_SIZE`. Zero (together with the other counters) means "not initialized".
pub buffer_size: u64,
/// Ring buffer of recent change logs, indexed by `active_index`.
pub change_logs: [ChangeLog<MAX_DEPTH>; MAX_BUFFER_SIZE],
/// Proof for the most recently appended leaf. Its `leaf` is that leaf, and its `index` is
/// the position the next `append` will write to.
pub rightmost_proof: Path<MAX_DEPTH>,
}
// SAFETY: `Zeroable` requires the all-zero bit pattern to be a valid value of the type.
// The three counters are `u64`s, for which zero is valid. `change_logs` and `rightmost_proof`
// are built from the project types `ChangeLog<MAX_DEPTH>` and `Path<MAX_DEPTH>`, whose
// zero-validity is assumed here.
// NOTE(review): confirm `ChangeLog` and `Path` are themselves `Zeroable` (nothing in this
// impl enforces it).
unsafe impl<const MAX_DEPTH: usize, const MAX_BUFFER_SIZE: usize> Zeroable
for ConcurrentMerkleTree<MAX_DEPTH, MAX_BUFFER_SIZE>
{
}
// SAFETY: bytemuck `Pod` requires a `Copy + 'static`, `#[repr(C)]` type with no padding
// bytes in which every bit pattern is valid (every field itself `Pod`). The struct is
// `#[repr(C)]` and derives `Copy`. The remaining conditions depend on `ChangeLog<MAX_DEPTH>`
// and `Path<MAX_DEPTH>` being `Pod` and on the whole layout containing no padding, including
// trailing padding to the 8-byte alignment of the `u64` counters. `Path` carries an explicit
// `_padding` field, which suggests its size was chosen with this in mind.
// NOTE(review): none of this is checked at compile time — re-verify whenever `ChangeLog` or
// `Path` change.
unsafe impl<const MAX_DEPTH: usize, const MAX_BUFFER_SIZE: usize> Pod
for ConcurrentMerkleTree<MAX_DEPTH, MAX_BUFFER_SIZE>
{
}
impl<const MAX_DEPTH: usize, const MAX_BUFFER_SIZE: usize> Default
    for ConcurrentMerkleTree<MAX_DEPTH, MAX_BUFFER_SIZE>
{
    /// Builds an uninitialized tree.
    ///
    /// All three counters are zero, every change-log slot holds `ChangeLog::default()`, and the
    /// rightmost proof is `Path::default()`. Because all counters are zero, `is_initialized`
    /// reports `false` until `initialize` or `initialize_with_root` is called.
    fn default() -> Self {
        let blank_log = ChangeLog::<MAX_DEPTH>::default();
        let blank_path = Path::<MAX_DEPTH>::default();
        Self {
            sequence_number: 0,
            active_index: 0,
            buffer_size: 0,
            change_logs: [blank_log; MAX_BUFFER_SIZE],
            rightmost_proof: blank_path,
        }
    }
}
impl<const MAX_DEPTH: usize, const MAX_BUFFER_SIZE: usize>
    ConcurrentMerkleTree<MAX_DEPTH, MAX_BUFFER_SIZE>
{
    /// Creates an uninitialized tree; identical to `Self::default()`.
    ///
    /// Call `initialize` or `initialize_with_root` before using it.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns `true` unless `buffer_size`, `sequence_number` and `active_index` are all zero.
    ///
    /// Both initializers set `buffer_size` to 1, so this becomes `true` once either succeeds.
    pub fn is_initialized(&self) -> bool {
        !(self.buffer_size == 0 && self.sequence_number == 0 && self.active_index == 0)
    }

    /// Initializes an empty tree and returns its root, `empty_node(MAX_DEPTH)`.
    ///
    /// Change log 0 records the empty root together with a path of empty nodes. The rightmost
    /// proof receives the same empty siblings so the first `append` has a proof to start
    /// from. Leaves `sequence_number = 0`, `active_index = 0`, `buffer_size = 1`.
    ///
    /// # Errors
    /// `TreeAlreadyInitialized` if the tree is already initialized.
    ///
    /// # Panics
    /// If `MAX_DEPTH` / `MAX_BUFFER_SIZE` are rejected by `check_bounds`.
    pub fn initialize(&mut self) -> Result<Node, ConcurrentMerkleTreeError> {
        check_bounds(MAX_DEPTH, MAX_BUFFER_SIZE);
        if self.is_initialized() {
            return Err(ConcurrentMerkleTreeError::TreeAlreadyInitialized);
        }
        // The empty-tree path and the empty rightmost proof hold the same values (the empty
        // node at each level), so compute the array once and use it for both.
        let mut empty_node_cache = Box::new([Node::default(); MAX_DEPTH]);
        let mut empty_path = [Node::default(); MAX_DEPTH];
        for (level, node) in empty_path.iter_mut().enumerate() {
            *node = empty_node_cached::<MAX_DEPTH>(level as u32, &mut empty_node_cache);
        }
        self.change_logs[0].root = empty_node(MAX_DEPTH as u32);
        self.change_logs[0].path = empty_path;
        self.sequence_number = 0;
        self.active_index = 0;
        self.buffer_size = 1;
        self.rightmost_proof = Path {
            proof: empty_path,
            ..Path::default()
        };
        Ok(self.change_logs[0].root)
    }

    /// Initializes the tree from an existing `root` and the proof of its rightmost leaf.
    ///
    /// `rightmost_leaf` is the leaf at `index`, and `proof_vec` is its sibling path (exactly
    /// `MAX_DEPTH` nodes). The proof is verified against `root` *before* any state is
    /// written, so a rejected call leaves the tree uninitialized and can be retried. On
    /// success, change log 0 holds `root`, `sequence_number = 1`, `buffer_size = 1`, and
    /// the next `append` goes to `index + 1`.
    ///
    /// # Errors
    /// * `LeafIndexOutOfBounds` if `index >= 2^MAX_DEPTH`.
    /// * `TreeAlreadyInitialized` if the tree is already initialized.
    /// * `InvalidProof` if `proof_vec` does not have `MAX_DEPTH` entries, or if it does not
    ///   hash `rightmost_leaf` at `index` up to `root`.
    ///
    /// # Panics
    /// If `MAX_DEPTH` / `MAX_BUFFER_SIZE` are rejected by `check_bounds`.
    pub fn initialize_with_root(
        &mut self,
        root: Node,
        rightmost_leaf: Node,
        proof_vec: &[Node],
        index: u32,
    ) -> Result<Node, ConcurrentMerkleTreeError> {
        check_bounds(MAX_DEPTH, MAX_BUFFER_SIZE);
        check_leaf_index(index, MAX_DEPTH)?;
        if self.is_initialized() {
            return Err(ConcurrentMerkleTreeError::TreeAlreadyInitialized);
        }
        // A proof of the wrong length can never verify; reject it rather than letting
        // `copy_from_slice` panic.
        if proof_vec.len() != MAX_DEPTH {
            solana_logging!("Proof failed to verify");
            return Err(ConcurrentMerkleTreeError::InvalidProof);
        }
        let mut proof: [Node; MAX_DEPTH] = [Node::default(); MAX_DEPTH];
        proof.copy_from_slice(proof_vec);
        // Verify before mutating: on failure the tree must stay uninitialized.
        if root != recompute(rightmost_leaf, &proof, index) {
            solana_logging!("Proof failed to verify");
            return Err(ConcurrentMerkleTreeError::InvalidProof);
        }
        self.change_logs[0].root = root;
        self.sequence_number = 1;
        self.active_index = 0;
        self.buffer_size = 1;
        self.rightmost_proof = Path {
            proof,
            index: index + 1,
            leaf: rightmost_leaf,
            _padding: 0,
        };
        Ok(root)
    }

    /// Succeeds if the current root equals the root of an all-empty tree.
    ///
    /// The empty root is `empty_node(MAX_DEPTH)`, the same value `initialize` stores and
    /// `initialize_tree_from_append` compares against.
    ///
    /// # Errors
    /// * `TreeNotInitialized` if the tree is not initialized.
    /// * `TreeNonEmpty` if the current root differs from the empty root.
    pub fn prove_tree_is_empty(&self) -> Result<(), ConcurrentMerkleTreeError> {
        if !self.is_initialized() {
            return Err(ConcurrentMerkleTreeError::TreeNotInitialized);
        }
        if self.get_root() != empty_node(MAX_DEPTH as u32) {
            return Err(ConcurrentMerkleTreeError::TreeNonEmpty);
        }
        Ok(())
    }

    /// Returns the root stored in the active change log.
    ///
    /// If the tree is not initialized, this logs a message and returns the root of
    /// `ChangeLog::default()`. The root is read in place; no copy of the change log is boxed.
    pub fn get_root(&self) -> [u8; 32] {
        match self.active_change_log() {
            Some(change_log) => change_log.root,
            None => {
                solana_logging!("Tree is not initialized, returning default change log");
                ChangeLog::<MAX_DEPTH>::default().root
            }
        }
    }

    /// Returns a boxed copy of the active change log.
    ///
    /// If the tree is not initialized, this logs a message and returns a boxed
    /// `ChangeLog::default()`.
    pub fn get_change_log(&self) -> Box<ChangeLog<MAX_DEPTH>> {
        match self.active_change_log() {
            Some(change_log) => Box::new(*change_log),
            None => {
                solana_logging!("Tree is not initialized, returning default change log");
                Box::<ChangeLog<MAX_DEPTH>>::default()
            }
        }
    }

    /// Borrows the change log at `active_index`, or returns `None` if the tree is not
    /// initialized.
    #[inline(always)]
    fn active_change_log(&self) -> Option<&ChangeLog<MAX_DEPTH>> {
        if self.is_initialized() {
            Some(&self.change_logs[self.active_index as usize])
        } else {
            None
        }
    }

    /// Verifies that `leaf` sits at `leaf_index` according to `proof_vec`, built against
    /// `current_root`.
    ///
    /// `proof_vec` is expanded into a full `MAX_DEPTH` array by `fill_in_proof`, then
    /// fast-forwarded through the buffered change logs. If `current_root` is not in the
    /// buffer, the whole buffer is replayed instead (inferred proof). Nothing is modified.
    ///
    /// # Errors
    /// * `LeafIndexOutOfBounds` if `leaf_index >= 2^MAX_DEPTH` or is past the rightmost index.
    /// * `TreeNotInitialized` if the tree is not initialized.
    /// * `LeafContentsModified` if replay shows the leaf was changed since `current_root`.
    /// * `InvalidProof` if the fast-forwarded proof does not reach the current root.
    pub fn prove_leaf(
        &self,
        current_root: Node,
        leaf: Node,
        proof_vec: &[Node],
        leaf_index: u32,
    ) -> Result<(), ConcurrentMerkleTreeError> {
        check_bounds(MAX_DEPTH, MAX_BUFFER_SIZE);
        check_leaf_index(leaf_index, MAX_DEPTH)?;
        if !self.is_initialized() {
            return Err(ConcurrentMerkleTreeError::TreeNotInitialized);
        }
        if leaf_index > self.rightmost_proof.index {
            solana_logging!(
                "Received an index larger than the rightmost index {} > {}",
                leaf_index,
                self.rightmost_proof.index
            );
            return Err(ConcurrentMerkleTreeError::LeafIndexOutOfBounds);
        }
        let mut proof: [Node; MAX_DEPTH] = [Node::default(); MAX_DEPTH];
        fill_in_proof::<MAX_DEPTH>(proof_vec, &mut proof);
        let valid_root =
            self.check_valid_leaf(current_root, leaf, &mut proof, leaf_index, true)?;
        if !valid_root {
            solana_logging!("Proof failed to verify");
            return Err(ConcurrentMerkleTreeError::InvalidProof);
        }
        Ok(())
    }

    /// First-append path, used while `rightmost_proof.index == 0`.
    ///
    /// Requires `proof` (the current rightmost proof) to hash `EMPTY` at index 0 up to the
    /// empty root, then writes `leaf` at index 0 through `try_apply_proof` (no inferred proof).
    ///
    /// # Errors
    /// `TreeAlreadyInitialized` if `proof` does not describe an empty tree; otherwise any
    /// error from `try_apply_proof`.
    #[inline(always)]
    fn initialize_tree_from_append(
        &mut self,
        leaf: Node,
        mut proof: [Node; MAX_DEPTH],
    ) -> Result<Node, ConcurrentMerkleTreeError> {
        let old_root = recompute(EMPTY, &proof, 0);
        if old_root == empty_node(MAX_DEPTH as u32) {
            self.try_apply_proof(old_root, EMPTY, leaf, &mut proof, 0, false)
        } else {
            Err(ConcurrentMerkleTreeError::TreeAlreadyInitialized)
        }
    }

    /// Appends `node` at `rightmost_proof.index` and returns the new root.
    ///
    /// No caller proof is needed: the new root is derived from the cached rightmost proof,
    /// which is then rewritten to prove the appended leaf. The full path from the new leaf to
    /// the root is recorded as a new change log.
    ///
    /// # Errors
    /// * `TreeNotInitialized` if the tree is not initialized.
    /// * `CannotAppendEmptyNode` if `node == EMPTY`.
    /// * `TreeFull` once `rightmost_proof.index` has reached `2^MAX_DEPTH`.
    /// * Any error from `initialize_tree_from_append` when appending the very first leaf.
    ///
    /// # Panics
    /// If `MAX_DEPTH` / `MAX_BUFFER_SIZE` are rejected by `check_bounds`.
    pub fn append(&mut self, mut node: Node) -> Result<Node, ConcurrentMerkleTreeError> {
        check_bounds(MAX_DEPTH, MAX_BUFFER_SIZE);
        if !self.is_initialized() {
            return Err(ConcurrentMerkleTreeError::TreeNotInitialized);
        }
        if node == EMPTY {
            return Err(ConcurrentMerkleTreeError::CannotAppendEmptyNode);
        }
        if self.rightmost_proof.index >= 1 << MAX_DEPTH {
            return Err(ConcurrentMerkleTreeError::TreeFull);
        }
        if self.rightmost_proof.index == 0 {
            return self.initialize_tree_from_append(node, self.rightmost_proof.proof);
        }
        let leaf = node;
        // Level at which the new leaf's path meets the previous rightmost leaf's path: the
        // lowest set bit of the new index.
        let intersection = self.rightmost_proof.index.trailing_zeros() as usize;
        let mut change_list = [EMPTY; MAX_DEPTH];
        let mut intersection_node = self.rightmost_proof.leaf;
        let mut empty_node_cache = Box::new([Node::default(); MAX_DEPTH]);
        // The bool passed to `hash_to_parent` is true exactly when bit `i` of the relevant
        // index is 0 (presumably "this node is the left child").
        for (i, cl_item) in change_list.iter_mut().enumerate() {
            // Record the new leaf's path node at this level before hashing one level up.
            *cl_item = node;
            match i {
                i if i < intersection => {
                    // Below the intersection, the new leaf's sibling is an empty node.
                    // Meanwhile, the previous rightmost leaf is hashed up to the same level
                    // to build the subtree the new path will meet.
                    let sibling = empty_node_cached::<MAX_DEPTH>(i as u32, &mut empty_node_cache);
                    hash_to_parent(
                        &mut intersection_node,
                        &self.rightmost_proof.proof[i],
                        ((self.rightmost_proof.index - 1) >> i) & 1 == 0,
                    );
                    hash_to_parent(&mut node, &sibling, true);
                    self.rightmost_proof.proof[i] = sibling;
                }
                i if i == intersection => {
                    // The two paths join: the old subtree becomes the new leaf's sibling here.
                    hash_to_parent(&mut node, &intersection_node, false);
                    self.rightmost_proof.proof[intersection] = intersection_node;
                }
                _ => {
                    // Above the join, both paths share the existing rightmost siblings.
                    hash_to_parent(
                        &mut node,
                        &self.rightmost_proof.proof[i],
                        ((self.rightmost_proof.index - 1) >> i) & 1 == 0,
                    );
                }
            }
        }
        self.update_internal_counters();
        self.change_logs[self.active_index as usize] =
            ChangeLog::<MAX_DEPTH>::new(node, change_list, self.rightmost_proof.index);
        self.rightmost_proof.index += 1;
        self.rightmost_proof.leaf = leaf;
        Ok(node)
    }

    /// Writes `leaf` into the empty slot at `index`, or appends it if that slot has been
    /// filled since `current_root`.
    ///
    /// The proof must match a root still in the change-log buffer (no inferred proofs).
    ///
    /// # Errors
    /// * `LeafIndexOutOfBounds` if `index >= 2^MAX_DEPTH` or is past the rightmost index.
    ///   Filling beyond the rightmost index would otherwise pass validation and then trip
    ///   the rightmost-index assertion in `update_buffers_from_proof` after the counters
    ///   had already advanced.
    /// * `TreeNotInitialized` if the tree is not initialized.
    /// * Any other error from `try_apply_proof`, or from `append` on the fallback path.
    pub fn fill_empty_or_append(
        &mut self,
        current_root: Node,
        leaf: Node,
        proof_vec: &[Node],
        index: u32,
    ) -> Result<Node, ConcurrentMerkleTreeError> {
        check_bounds(MAX_DEPTH, MAX_BUFFER_SIZE);
        check_leaf_index(index, MAX_DEPTH)?;
        if !self.is_initialized() {
            return Err(ConcurrentMerkleTreeError::TreeNotInitialized);
        }
        if index > self.rightmost_proof.index {
            solana_logging!(
                "Received an index larger than the rightmost index {} > {}",
                index,
                self.rightmost_proof.index
            );
            return Err(ConcurrentMerkleTreeError::LeafIndexOutOfBounds);
        }
        let mut proof: [Node; MAX_DEPTH] = [Node::default(); MAX_DEPTH];
        fill_in_proof::<MAX_DEPTH>(proof_vec, &mut proof);
        log_compute!();
        match self.try_apply_proof(current_root, EMPTY, leaf, &mut proof, index, false) {
            // The target slot is no longer empty: fall back to appending.
            Err(ConcurrentMerkleTreeError::LeafContentsModified) => self.append(leaf),
            result => result,
        }
    }

    /// Replaces `previous_leaf` at `index` with `new_leaf` and returns the new root.
    ///
    /// The proof is fast-forwarded through the change-log buffer. If `current_root` is not
    /// buffered, the full buffer is replayed (inferred proof).
    ///
    /// # Errors
    /// * `LeafIndexOutOfBounds` if `index >= 2^MAX_DEPTH` or is past the rightmost index.
    /// * `TreeNotInitialized` if the tree is not initialized.
    /// * Any error from `try_apply_proof`.
    pub fn set_leaf(
        &mut self,
        current_root: Node,
        previous_leaf: Node,
        new_leaf: Node,
        proof_vec: &[Node],
        index: u32,
    ) -> Result<Node, ConcurrentMerkleTreeError> {
        check_bounds(MAX_DEPTH, MAX_BUFFER_SIZE);
        check_leaf_index(index, MAX_DEPTH)?;
        if !self.is_initialized() {
            return Err(ConcurrentMerkleTreeError::TreeNotInitialized);
        }
        if index > self.rightmost_proof.index {
            return Err(ConcurrentMerkleTreeError::LeafIndexOutOfBounds);
        }
        let mut proof: [Node; MAX_DEPTH] = [Node::default(); MAX_DEPTH];
        fill_in_proof::<MAX_DEPTH>(proof_vec, &mut proof);
        log_compute!();
        self.try_apply_proof(
            current_root,
            previous_leaf,
            new_leaf,
            &mut proof,
            index,
            true,
        )
    }

    /// Returns the current sequence number.
    pub fn get_seq(&self) -> u64 {
        self.sequence_number
    }

    /// Replays buffered change logs onto `proof` (and `leaf`) for the leaf at `leaf_index`.
    ///
    /// Entries are applied starting *after* `changelog_buffer_index`, up to and including
    /// `active_index`:
    /// * With `use_full_buffer == false`, nothing is applied if `changelog_buffer_index`
    ///   already equals `active_index`.
    /// * With `use_full_buffer == true`, at least one entry is always applied, and the walk
    ///   stops right after applying the active entry.
    ///
    /// `leaf` is overwritten with the replayed value. Returns `true` if replay left it
    /// unchanged.
    #[inline(always)]
    fn fast_forward_proof(
        &self,
        leaf: &mut Node,
        proof: &mut [Node; MAX_DEPTH],
        leaf_index: u32,
        mut changelog_buffer_index: u64,
        use_full_buffer: bool,
    ) -> bool {
        solana_logging!(
            "Fast-forwarding proof, starting index {}",
            changelog_buffer_index
        );
        let mask: usize = MAX_BUFFER_SIZE - 1;
        let mut updated_leaf = *leaf;
        log_compute!();
        loop {
            if !use_full_buffer && changelog_buffer_index == self.active_index {
                break;
            }
            changelog_buffer_index = (changelog_buffer_index + 1) & mask as u64;
            self.change_logs[changelog_buffer_index as usize].update_proof_or_leaf(
                leaf_index,
                proof,
                &mut updated_leaf,
            );
            if use_full_buffer && changelog_buffer_index == self.active_index {
                break;
            }
        }
        log_compute!();
        let proof_leaf_unchanged = updated_leaf == *leaf;
        *leaf = updated_leaf;
        proof_leaf_unchanged
    }

    /// Returns the buffer slot whose root equals `current_root`.
    ///
    /// Slots are searched from newest (`active_index`) to oldest, over the `buffer_size`
    /// populated slots.
    #[inline(always)]
    fn find_root_in_changelog(&self, current_root: Node) -> Option<u64> {
        let mask: usize = MAX_BUFFER_SIZE - 1;
        (0..self.buffer_size)
            .map(|age| self.active_index.wrapping_sub(age) & mask as u64)
            .find(|&slot| self.change_logs[slot as usize].root == current_root)
    }

    /// Fast-forwards `proof` from `current_root` to the current root, then checks it.
    ///
    /// If `current_root` is in the buffer, replay starts from that slot. Otherwise, with
    /// `allow_inferred_proof`, the whole buffer is replayed starting from
    /// `active_index - (buffer_size - 1)`.
    ///
    /// # Errors
    /// * `RootNotFound` if `current_root` is not buffered and inferred proofs are not allowed.
    /// * `LeafContentsModified` if replay changed the leaf.
    ///
    /// Returns `Ok(false)` if the fast-forwarded proof does not reach the current root.
    #[inline(always)]
    fn check_valid_leaf(
        &self,
        current_root: Node,
        leaf: Node,
        proof: &mut [Node; MAX_DEPTH],
        leaf_index: u32,
        allow_inferred_proof: bool,
    ) -> Result<bool, ConcurrentMerkleTreeError> {
        let mask: usize = MAX_BUFFER_SIZE - 1;
        let (changelog_index, use_full_buffer) = match self.find_root_in_changelog(current_root) {
            Some(matching_changelog_index) => (matching_changelog_index, false),
            None if allow_inferred_proof => {
                solana_logging!("Failed to find root in change log -> replaying full buffer");
                (
                    self.active_index.wrapping_sub(self.buffer_size - 1) & mask as u64,
                    true,
                )
            }
            None => return Err(ConcurrentMerkleTreeError::RootNotFound),
        };
        let mut updatable_leaf_node = leaf;
        let proof_leaf_unchanged = self.fast_forward_proof(
            &mut updatable_leaf_node,
            proof,
            leaf_index,
            changelog_index,
            use_full_buffer,
        );
        if !proof_leaf_unchanged {
            return Err(ConcurrentMerkleTreeError::LeafContentsModified);
        }
        Ok(self.check_valid_proof(updatable_leaf_node, proof, leaf_index))
    }

    /// Returns whether `proof` hashes `leaf` at `leaf_index` up to the current root.
    ///
    /// No change-log replay happens here. Returns `false` (after logging) if the tree is not
    /// initialized or `leaf_index` is out of range for `MAX_DEPTH`.
    pub fn check_valid_proof(
        &self,
        leaf: Node,
        proof: &[Node; MAX_DEPTH],
        leaf_index: u32,
    ) -> bool {
        if !self.is_initialized() {
            solana_logging!("Tree is not initialized, returning false");
            return false;
        }
        if check_leaf_index(leaf_index, MAX_DEPTH).is_err() {
            solana_logging!("Leaf index out of bounds for max_depth");
            return false;
        }
        recompute(leaf, proof, leaf_index) == self.get_root()
    }

    /// Validates `leaf` at `leaf_index` against `current_root`, then writes `new_leaf` there.
    ///
    /// On success, this claims a new change-log slot, records the update, and returns the
    /// new root. On any error, no state is modified.
    ///
    /// # Errors
    /// * Any error from `check_valid_leaf`.
    /// * `InvalidProof` if the fast-forwarded proof does not reach the current root.
    #[inline(always)]
    fn try_apply_proof(
        &mut self,
        current_root: Node,
        leaf: Node,
        new_leaf: Node,
        proof: &mut [Node; MAX_DEPTH],
        leaf_index: u32,
        allow_inferred_proof: bool,
    ) -> Result<Node, ConcurrentMerkleTreeError> {
        solana_logging!("Active Index: {}", self.active_index);
        solana_logging!("Rightmost Index: {}", self.rightmost_proof.index);
        solana_logging!("Buffer Size: {}", self.buffer_size);
        solana_logging!("Leaf Index: {}", leaf_index);
        let valid_root =
            self.check_valid_leaf(current_root, leaf, proof, leaf_index, allow_inferred_proof)?;
        if !valid_root {
            return Err(ConcurrentMerkleTreeError::InvalidProof);
        }
        self.update_internal_counters();
        Ok(self.update_buffers_from_proof(new_leaf, proof, leaf_index))
    }

    /// Claims the next ring-buffer slot.
    ///
    /// Advances `active_index` modulo `MAX_BUFFER_SIZE`, grows `buffer_size` until it
    /// reaches `MAX_BUFFER_SIZE`, and increments `sequence_number` (saturating).
    fn update_internal_counters(&mut self) {
        let mask: usize = MAX_BUFFER_SIZE - 1;
        self.active_index += 1;
        self.active_index &= mask as u64;
        if self.buffer_size < MAX_BUFFER_SIZE as u64 {
            self.buffer_size += 1;
        }
        self.sequence_number = self.sequence_number.saturating_add(1);
    }

    /// Records a proof-based write in the change log at `active_index` and returns the new
    /// root. It also keeps `rightmost_proof` in sync.
    ///
    /// `try_apply_proof` advances the counters before calling this, so the write goes to the
    /// freshly claimed slot.
    ///
    /// # Panics
    /// If `index` is past `rightmost_proof.index` while the tree is not full. Public entry
    /// points reject such indices first.
    fn update_buffers_from_proof(&mut self, start: Node, proof: &[Node], index: u32) -> Node {
        let change_log = &mut self.change_logs[self.active_index as usize];
        let root = change_log.replace_and_recompute_path(index, start, proof);
        // Once all 2^MAX_DEPTH slots are used, there is no append position left to track.
        if self.rightmost_proof.index < (1 << MAX_DEPTH) {
            if index < self.rightmost_proof.index {
                // A leaf left of the append position changed: patch the cached proof and leaf
                // for the last appended index from the new change log.
                change_log.update_proof_or_leaf(
                    self.rightmost_proof.index - 1,
                    &mut self.rightmost_proof.proof,
                    &mut self.rightmost_proof.leaf,
                );
            } else {
                // Writing exactly at the append position extends the tree by one leaf.
                assert!(
                    index == self.rightmost_proof.index,
                    "leaf index {} is past the rightmost index {}",
                    index,
                    self.rightmost_proof.index
                );
                solana_logging!("Appending rightmost leaf");
                self.rightmost_proof.proof.copy_from_slice(proof);
                self.rightmost_proof.index = index + 1;
                self.rightmost_proof.leaf = change_log.get_leaf();
            }
        }
        root
    }
}