1use {
2 crate::{
3 changelog::ChangeLog,
4 error::ConcurrentMerkleTreeError,
5 hash::{fill_in_proof, hash_to_parent, recompute},
6 node::{empty_node, empty_node_cached, Node, EMPTY},
7 path::Path,
8 },
9 bytemuck::{Pod, Zeroable},
10 log_compute, solana_logging,
11};
12
/// Enforces the constraints on tree depth and change-log buffer size.
///
/// # Panics
/// * if `max_depth` is 31 or more (leaf indices are `u32`; presumably the limit
///   keeps the index bit math in `ChangeLog` in range — confirm there).
/// * if `max_buffer_size` is not a non-zero power of two: change-log buffer
///   indices are wrapped with the mask `max_buffer_size - 1`.
#[inline(always)]
fn check_bounds(max_depth: usize, max_buffer_size: usize) {
    assert!(max_depth < 31, "max_depth must be less than 31");
    // `is_power_of_two` is false for 0. The previous `n & (n - 1) == 0` test
    // underflowed on 0: a panic in debug builds, a silent pass in release builds.
    assert!(
        max_buffer_size.is_power_of_two(),
        "max_buffer_size must be a non-zero power of two"
    );
}
22
23fn check_leaf_index(leaf_index: u32, max_depth: usize) -> Result<(), ConcurrentMerkleTreeError> {
24 if leaf_index >= (1 << max_depth) {
25 return Err(ConcurrentMerkleTreeError::LeafIndexOutOfBounds);
26 }
27 Ok(())
28}
29
/// Concurrent Merkle tree stored as a fixed-size, plain-old-data value.
///
/// `MAX_DEPTH` is the tree depth and `MAX_BUFFER_SIZE` the number of slots in
/// the circular change-log buffer; `check_bounds` is asserted against both by
/// the mutating methods of this type.
///
/// The struct is `#[repr(C)]` and implements `Pod`, so the field order and
/// field types below ARE the byte layout — do not reorder or retype them.
#[repr(C)]
#[derive(Copy, Clone)]
pub struct ConcurrentMerkleTree<const MAX_DEPTH: usize, const MAX_BUFFER_SIZE: usize> {
    /// Incremented (saturating) on every successful modification.
    pub sequence_number: u64,
    /// Slot in `change_logs` holding the most recent change log (current root).
    pub active_index: u64,
    /// Number of populated slots in `change_logs`, capped at `MAX_BUFFER_SIZE`.
    pub buffer_size: u64,
    /// Circular buffer of change logs, indexed modulo `MAX_BUFFER_SIZE`.
    pub change_logs: [ChangeLog<MAX_DEPTH>; MAX_BUFFER_SIZE],
    /// Proof and leaf for the leaf at `rightmost_proof.index - 1`; `index`
    /// itself is the slot the next `append` writes to (0 for an empty tree).
    pub rightmost_proof: Path<MAX_DEPTH>,
}
75
// SAFETY: `ConcurrentMerkleTree` is `#[repr(C)]` and its fields are three `u64`s
// plus arrays of `ChangeLog<MAX_DEPTH>` and a `Path<MAX_DEPTH>`. The all-zero bit
// pattern is therefore valid only if `ChangeLog` and `Path` are themselves
// zeroable. NOTE(review): the compiler does not check this — confirm it in
// their definitions whenever those types change.
unsafe impl<const MAX_DEPTH: usize, const MAX_BUFFER_SIZE: usize> Zeroable
    for ConcurrentMerkleTree<MAX_DEPTH, MAX_BUFFER_SIZE>
{
}
// SAFETY: `Pod` requires `Copy` (derived on the struct), a `#[repr(C)]` layout
// (present), no padding bytes, and every bit pattern being valid. The three
// leading `u64`s are padding-free; the remaining guarantees depend on
// `ChangeLog<MAX_DEPTH>` and `Path<MAX_DEPTH>` being `Pod` with sizes that are
// multiples of 8. NOTE(review): confirm this in their definitions — an
// undetected padding byte here is undefined behavior.
unsafe impl<const MAX_DEPTH: usize, const MAX_BUFFER_SIZE: usize> Pod
    for ConcurrentMerkleTree<MAX_DEPTH, MAX_BUFFER_SIZE>
{
}
84
85impl<const MAX_DEPTH: usize, const MAX_BUFFER_SIZE: usize> Default
86 for ConcurrentMerkleTree<MAX_DEPTH, MAX_BUFFER_SIZE>
87{
88 fn default() -> Self {
89 Self {
90 sequence_number: 0,
91 active_index: 0,
92 buffer_size: 0,
93 change_logs: [ChangeLog::<MAX_DEPTH>::default(); MAX_BUFFER_SIZE],
94 rightmost_proof: Path::<MAX_DEPTH>::default(),
95 }
96 }
97}
98
99pub struct InitializeWithRootArgs {
101 pub root: Node,
102 pub rightmost_leaf: Node,
103 pub proof_vec: Vec<Node>,
104 pub index: u32,
105}
106
107pub struct SetLeafArgs {
109 pub current_root: Node,
110 pub previous_leaf: Node,
111 pub new_leaf: Node,
112 pub proof_vec: Vec<Node>,
113 pub index: u32,
114}
115
116pub struct FillEmptyOrAppendArgs {
119 pub current_root: Node,
120 pub leaf: Node,
121 pub proof_vec: Vec<Node>,
122 pub index: u32,
123}
124
125pub struct ProveLeafArgs {
127 pub current_root: Node,
128 pub leaf: Node,
129 pub proof_vec: Vec<Node>,
130 pub index: u32,
131}
132
133impl<const MAX_DEPTH: usize, const MAX_BUFFER_SIZE: usize>
134 ConcurrentMerkleTree<MAX_DEPTH, MAX_BUFFER_SIZE>
135{
136 pub fn new() -> Self {
137 Self::default()
138 }
139
140 pub fn is_initialized(&self) -> bool {
141 !(self.buffer_size == 0 && self.sequence_number == 0 && self.active_index == 0)
142 }
143
144 pub fn initialize(&mut self) -> Result<Node, ConcurrentMerkleTreeError> {
147 check_bounds(MAX_DEPTH, MAX_BUFFER_SIZE);
148 if self.is_initialized() {
149 return Err(ConcurrentMerkleTreeError::TreeAlreadyInitialized);
150 }
151 let mut rightmost_proof = Path::default();
152 let empty_node_cache = [Node::default(); MAX_DEPTH];
153 for (i, node) in rightmost_proof.proof.iter_mut().enumerate() {
154 *node = empty_node_cached::<MAX_DEPTH>(i as u32, &empty_node_cache);
155 }
156 let mut path = [Node::default(); MAX_DEPTH];
157 for (i, node) in path.iter_mut().enumerate() {
158 *node = empty_node_cached::<MAX_DEPTH>(i as u32, &empty_node_cache);
159 }
160 self.change_logs[0].root = empty_node(MAX_DEPTH as u32);
161 self.change_logs[0].path = path;
162 self.sequence_number = 0;
163 self.active_index = 0;
164 self.buffer_size = 1;
165 self.rightmost_proof = rightmost_proof;
166 Ok(self.change_logs[0].root)
167 }
168
169 pub fn initialize_with_root(
177 &mut self,
178 args: &InitializeWithRootArgs,
179 ) -> Result<Node, ConcurrentMerkleTreeError> {
180 check_bounds(MAX_DEPTH, MAX_BUFFER_SIZE);
181 check_leaf_index(args.index, MAX_DEPTH)?;
182
183 if self.is_initialized() {
184 return Err(ConcurrentMerkleTreeError::TreeAlreadyInitialized);
185 }
186 let mut proof: [Node; MAX_DEPTH] = [Node::default(); MAX_DEPTH];
187 proof.copy_from_slice(&args.proof_vec);
188 let rightmost_proof = Path {
189 proof,
190 index: args.index + 1,
191 leaf: args.rightmost_leaf,
192 _padding: 0,
193 };
194 self.change_logs[0].root = args.root;
195 self.sequence_number = 1;
196 self.active_index = 0;
197 self.buffer_size = 1;
198 self.rightmost_proof = rightmost_proof;
199 if args.root != recompute(args.rightmost_leaf, &proof, args.index) {
200 solana_logging!("Proof failed to verify");
201 return Err(ConcurrentMerkleTreeError::InvalidProof);
202 }
203 Ok(args.root)
204 }
205
206 pub fn prove_tree_is_empty(&self) -> Result<(), ConcurrentMerkleTreeError> {
208 if !self.is_initialized() {
209 return Err(ConcurrentMerkleTreeError::TreeNotInitialized);
210 }
211 let empty_node_cache = [EMPTY; MAX_DEPTH];
212 if self.get_root() != empty_node_cached::<MAX_DEPTH>(MAX_DEPTH as u32, &empty_node_cache) {
213 return Err(ConcurrentMerkleTreeError::TreeNonEmpty);
214 }
215 Ok(())
216 }
217
218 pub fn get_root(&self) -> [u8; 32] {
220 self.get_change_log().root
221 }
222
223 pub fn get_change_log(&self) -> Box<ChangeLog<MAX_DEPTH>> {
225 if !self.is_initialized() {
226 solana_logging!("Tree is not initialized, returning default change log");
227 return Box::<ChangeLog<MAX_DEPTH>>::default();
228 }
229 Box::new(self.change_logs[self.active_index as usize])
230 }
231
232 pub fn prove_leaf(&self, args: &ProveLeafArgs) -> Result<(), ConcurrentMerkleTreeError> {
244 check_bounds(MAX_DEPTH, MAX_BUFFER_SIZE);
245 check_leaf_index(args.index, MAX_DEPTH)?;
246 if !self.is_initialized() {
247 return Err(ConcurrentMerkleTreeError::TreeNotInitialized);
248 }
249
250 if args.index > self.rightmost_proof.index {
251 solana_logging!(
252 "Received an index larger than the rightmost index {} > {}",
253 args.index,
254 self.rightmost_proof.index
255 );
256 Err(ConcurrentMerkleTreeError::LeafIndexOutOfBounds)
257 } else {
258 let mut proof: [Node; MAX_DEPTH] = [Node::default(); MAX_DEPTH];
259 fill_in_proof::<MAX_DEPTH>(&args.proof_vec, &mut proof);
260 let valid_root =
261 self.check_valid_leaf(args.current_root, args.leaf, &mut proof, args.index, true)?;
262 if !valid_root {
263 solana_logging!("Proof failed to verify");
264 return Err(ConcurrentMerkleTreeError::InvalidProof);
265 }
266 Ok(())
267 }
268 }
269
270 #[inline(always)]
272 fn initialize_tree_from_append(
273 &mut self,
274 leaf: Node,
275 mut proof: [Node; MAX_DEPTH],
276 ) -> Result<Node, ConcurrentMerkleTreeError> {
277 let old_root = recompute(EMPTY, &proof, 0);
278 if old_root == empty_node(MAX_DEPTH as u32) {
279 self.try_apply_proof(old_root, EMPTY, leaf, &mut proof, 0, false)
280 } else {
281 Err(ConcurrentMerkleTreeError::TreeAlreadyInitialized)
282 }
283 }
284
285 pub fn append(&mut self, mut node: Node) -> Result<Node, ConcurrentMerkleTreeError> {
287 check_bounds(MAX_DEPTH, MAX_BUFFER_SIZE);
288 if !self.is_initialized() {
289 return Err(ConcurrentMerkleTreeError::TreeNotInitialized);
290 }
291 if node == EMPTY {
292 return Err(ConcurrentMerkleTreeError::CannotAppendEmptyNode);
293 }
294 if self.rightmost_proof.index >= 1 << MAX_DEPTH {
295 return Err(ConcurrentMerkleTreeError::TreeFull);
296 }
297 if self.rightmost_proof.index == 0 {
298 return self.initialize_tree_from_append(node, self.rightmost_proof.proof);
299 }
300 let leaf = node;
301 let intersection = self.rightmost_proof.index.trailing_zeros() as usize;
302 let mut change_list = [EMPTY; MAX_DEPTH];
303 let mut intersection_node = self.rightmost_proof.leaf;
304 let empty_node_cache = [Node::default(); MAX_DEPTH];
305
306 for (i, cl_item) in change_list.iter_mut().enumerate().take(MAX_DEPTH) {
307 *cl_item = node;
308 match i {
309 i if i < intersection => {
310 let sibling = empty_node_cached::<MAX_DEPTH>(i as u32, &empty_node_cache);
312 hash_to_parent(
313 &mut intersection_node,
314 &self.rightmost_proof.proof[i],
315 ((self.rightmost_proof.index - 1) >> i) & 1 == 0,
316 );
317 hash_to_parent(&mut node, &sibling, true);
318 self.rightmost_proof.proof[i] = sibling;
319 }
320 i if i == intersection => {
321 hash_to_parent(&mut node, &intersection_node, false);
323 self.rightmost_proof.proof[intersection] = intersection_node;
324 }
325 _ => {
326 hash_to_parent(
328 &mut node,
329 &self.rightmost_proof.proof[i],
330 ((self.rightmost_proof.index - 1) >> i) & 1 == 0,
331 );
332 }
333 }
334 }
335
336 self.update_internal_counters();
337 self.change_logs[self.active_index as usize] =
338 ChangeLog::<MAX_DEPTH>::new(node, change_list, self.rightmost_proof.index);
339 self.rightmost_proof.index += 1;
340 self.rightmost_proof.leaf = leaf;
341 Ok(node)
342 }
343
344 pub fn fill_empty_or_append(
349 &mut self,
350 args: &FillEmptyOrAppendArgs,
351 ) -> Result<Node, ConcurrentMerkleTreeError> {
352 check_bounds(MAX_DEPTH, MAX_BUFFER_SIZE);
353 check_leaf_index(args.index, MAX_DEPTH)?;
354 if !self.is_initialized() {
355 return Err(ConcurrentMerkleTreeError::TreeNotInitialized);
356 }
357
358 let mut proof: [Node; MAX_DEPTH] = [Node::default(); MAX_DEPTH];
359 fill_in_proof::<MAX_DEPTH>(&args.proof_vec, &mut proof);
360
361 log_compute!();
362 match self.try_apply_proof(
363 args.current_root,
364 EMPTY,
365 args.leaf,
366 &mut proof,
367 args.index,
368 false,
369 ) {
370 Ok(new_root) => Ok(new_root),
371 Err(error) => match error {
372 ConcurrentMerkleTreeError::LeafContentsModified => self.append(args.leaf),
373 _ => Err(error),
374 },
375 }
376 }
377
378 pub fn set_leaf(&mut self, args: &SetLeafArgs) -> Result<Node, ConcurrentMerkleTreeError> {
382 check_bounds(MAX_DEPTH, MAX_BUFFER_SIZE);
383 check_leaf_index(args.index, MAX_DEPTH)?;
384 if !self.is_initialized() {
385 return Err(ConcurrentMerkleTreeError::TreeNotInitialized);
386 }
387
388 if args.index > self.rightmost_proof.index {
389 Err(ConcurrentMerkleTreeError::LeafIndexOutOfBounds)
390 } else {
391 let mut proof: [Node; MAX_DEPTH] = [Node::default(); MAX_DEPTH];
392 fill_in_proof::<MAX_DEPTH>(&args.proof_vec, &mut proof);
393
394 log_compute!();
395 self.try_apply_proof(
396 args.current_root,
397 args.previous_leaf,
398 args.new_leaf,
399 &mut proof,
400 args.index,
401 true,
402 )
403 }
404 }
405
406 pub fn get_seq(&self) -> u64 {
410 self.sequence_number
411 }
412
413 #[inline(always)]
418 fn fast_forward_proof(
419 &self,
420 leaf: &mut Node,
421 proof: &mut [Node; MAX_DEPTH],
422 leaf_index: u32,
423 mut changelog_buffer_index: u64,
424 use_full_buffer: bool,
425 ) -> bool {
426 solana_logging!(
427 "Fast-forwarding proof, starting index {}",
428 changelog_buffer_index
429 );
430 let mask: usize = MAX_BUFFER_SIZE - 1;
431
432 let mut updated_leaf = *leaf;
433 log_compute!();
434 loop {
436 if !use_full_buffer && changelog_buffer_index == self.active_index {
439 break;
440 }
441 changelog_buffer_index = (changelog_buffer_index + 1) & mask as u64;
442 self.change_logs[changelog_buffer_index as usize].update_proof_or_leaf(
443 leaf_index,
444 proof,
445 &mut updated_leaf,
446 );
447 if use_full_buffer && changelog_buffer_index == self.active_index {
449 break;
450 }
451 }
452 log_compute!();
453 let proof_leaf_unchanged = updated_leaf == *leaf;
454 *leaf = updated_leaf;
455 proof_leaf_unchanged
456 }
457
458 #[inline(always)]
459 fn find_root_in_changelog(&self, current_root: Node) -> Option<u64> {
460 let mask: usize = MAX_BUFFER_SIZE - 1;
461 for i in 0..self.buffer_size {
462 let j = self.active_index.wrapping_sub(i) & mask as u64;
463 if self.change_logs[j as usize].root == current_root {
464 return Some(j);
465 }
466 }
467 None
468 }
469
470 #[inline(always)]
471 fn check_valid_leaf(
472 &self,
473 current_root: Node,
474 leaf: Node,
475 proof: &mut [Node; MAX_DEPTH],
476 leaf_index: u32,
477 allow_inferred_proof: bool,
478 ) -> Result<bool, ConcurrentMerkleTreeError> {
479 let mask: usize = MAX_BUFFER_SIZE - 1;
480 let (changelog_index, use_full_buffer) = match self.find_root_in_changelog(current_root) {
481 Some(matching_changelog_index) => (matching_changelog_index, false),
482 None => {
483 if allow_inferred_proof {
484 solana_logging!("Failed to find root in change log -> replaying full buffer");
485 (
486 self.active_index.wrapping_sub(self.buffer_size - 1) & mask as u64,
487 true,
488 )
489 } else {
490 return Err(ConcurrentMerkleTreeError::RootNotFound);
491 }
492 }
493 };
494 let mut updatable_leaf_node = leaf;
495 let proof_leaf_unchanged = self.fast_forward_proof(
496 &mut updatable_leaf_node,
497 proof,
498 leaf_index,
499 changelog_index,
500 use_full_buffer,
501 );
502 if !proof_leaf_unchanged {
503 return Err(ConcurrentMerkleTreeError::LeafContentsModified);
504 }
505 Ok(self.check_valid_proof(updatable_leaf_node, proof, leaf_index))
506 }
507
508 pub fn check_valid_proof(
510 &self,
511 leaf: Node,
512 proof: &[Node; MAX_DEPTH],
513 leaf_index: u32,
514 ) -> bool {
515 if !self.is_initialized() {
516 solana_logging!("Tree is not initialized, returning false");
517 return false;
518 }
519 if check_leaf_index(leaf_index, MAX_DEPTH).is_err() {
520 solana_logging!("Leaf index out of bounds for max_depth");
521 return false;
522 }
523 recompute(leaf, proof, leaf_index) == self.get_root()
524 }
525
526 #[inline(always)]
530 fn try_apply_proof(
531 &mut self,
532 current_root: Node,
533 leaf: Node,
534 new_leaf: Node,
535 proof: &mut [Node; MAX_DEPTH],
536 leaf_index: u32,
537 allow_inferred_proof: bool,
538 ) -> Result<Node, ConcurrentMerkleTreeError> {
539 solana_logging!("Active Index: {}", self.active_index);
540 solana_logging!("Rightmost Index: {}", self.rightmost_proof.index);
541 solana_logging!("Buffer Size: {}", self.buffer_size);
542 solana_logging!("Leaf Index: {}", leaf_index);
543 let valid_root =
544 self.check_valid_leaf(current_root, leaf, proof, leaf_index, allow_inferred_proof)?;
545 if !valid_root {
546 return Err(ConcurrentMerkleTreeError::InvalidProof);
547 }
548 self.update_internal_counters();
549 Ok(self.update_buffers_from_proof(new_leaf, proof, leaf_index))
550 }
551
552 fn update_internal_counters(&mut self) {
554 let mask: usize = MAX_BUFFER_SIZE - 1;
555 self.active_index += 1;
556 self.active_index &= mask as u64;
557 if self.buffer_size < MAX_BUFFER_SIZE as u64 {
558 self.buffer_size += 1;
559 }
560 self.sequence_number = self.sequence_number.saturating_add(1);
561 }
562
563 fn update_buffers_from_proof(&mut self, start: Node, proof: &[Node], index: u32) -> Node {
566 let change_log = &mut self.change_logs[self.active_index as usize];
567 let root = change_log.replace_and_recompute_path(index, start, proof);
569 if self.rightmost_proof.index < (1 << MAX_DEPTH) {
571 if index < self.rightmost_proof.index {
572 change_log.update_proof_or_leaf(
573 self.rightmost_proof.index - 1,
574 &mut self.rightmost_proof.proof,
575 &mut self.rightmost_proof.leaf,
576 );
577 } else {
578 assert!(index == self.rightmost_proof.index);
579 solana_logging!("Appending rightmost leaf");
580 self.rightmost_proof.proof.copy_from_slice(proof);
581 self.rightmost_proof.index = index + 1;
582 self.rightmost_proof.leaf = change_log.get_leaf();
583 }
584 }
585 root
586 }
587}