spl_concurrent_merkle_tree/
concurrent_merkle_tree.rs

1use {
2    crate::{
3        changelog::ChangeLog,
4        error::ConcurrentMerkleTreeError,
5        hash::{fill_in_proof, hash_to_parent, recompute},
6        node::{empty_node, empty_node_cached, Node, EMPTY},
7        path::Path,
8    },
9    bytemuck::{Pod, Zeroable},
10    log_compute, solana_logging,
11};
12
/// Enforce constraints on max depth and buffer size.
///
/// # Panics
///
/// * if `max_depth` is 31 or more: the bit math used to update `ChangeLog`s
///   only supports depths up to 30;
/// * if `max_buffer_size` is not a power of two (0 included): changelog
///   indices wrap with the mask `max_buffer_size - 1`, which only works for
///   powers of two, and a zero-sized buffer has no slot for the first changelog.
#[inline(always)]
fn check_bounds(max_depth: usize, max_buffer_size: usize) {
    // We cannot allow a tree depth greater than 30 because of the bit math
    // required to update `ChangeLog`s
    assert!(
        max_depth < 31,
        "max_depth must be less than 31, got {}",
        max_depth
    );
    // `is_power_of_two` rejects 0 explicitly; the bit trick
    // `n & (n - 1) == 0` underflowed on 0 instead of reporting the bad size.
    assert!(
        max_buffer_size.is_power_of_two(),
        "max_buffer_size must be a non-zero power of two, got {}",
        max_buffer_size
    );
}
22
23fn check_leaf_index(leaf_index: u32, max_depth: usize) -> Result<(), ConcurrentMerkleTreeError> {
24    if leaf_index >= (1 << max_depth) {
25        return Err(ConcurrentMerkleTreeError::LeafIndexOutOfBounds);
26    }
27    Ok(())
28}
29
/// Concurrent Merkle Tree is a Merkle Tree that allows
/// multiple tree operations targeted for the same tree root to succeed.
///
/// In a normal merkle tree, only the first tree operation will succeed because
/// the following operations will have proofs for the unmodified tree state.
/// ConcurrentMerkleTree avoids this by storing a buffer of modified nodes
/// (`change_logs`) which allows it to implement fast-forwarding of concurrent
/// merkle tree operations.
///
/// As long as the concurrent merkle tree operations
/// have proofs that are valid for a previous state of the tree that can be
/// found in the stored buffer, that tree operation's proof can be
/// fast-forwarded and the tree operation can be applied.
///
/// There are two primitive operations for Concurrent Merkle Trees:
/// [set_leaf](ConcurrentMerkleTree::set_leaf) and
/// [append](ConcurrentMerkleTree::append). Setting a leaf value requires
/// passing a proof to perform that tree operation, but appending does not
/// require a proof.
///
/// An additional key property of ConcurrentMerkleTree is support for
/// [append](ConcurrentMerkleTree::append) operations, which do not require any
/// proofs to be passed. This is accomplished by keeping track of the
/// proof to the rightmost leaf in the tree (`rightmost_proof`).
///
/// The `ConcurrentMerkleTree` is a generic struct that may be interacted with
/// using macros. Those macros may wrap up the construction and both mutable and
/// immutable calls to the `ConcurrentMerkleTree` struct. If the macro contains
/// a big match statement over different sizes of a tree and buffer, it might
/// create a huge stack footprint. This in turn might lead to a stack overflow
/// given the max stack offset of just 4kb. In order to minimize the stack frame
/// size, the arguments for the `ConcurrentMerkleTree` methods that contain the
/// proofs are passed as references to structs.
///
/// `MAX_DEPTH` must be less than 31 and `MAX_BUFFER_SIZE` must be a power of
/// two; both are asserted by the initializing and mutating methods.
///
/// The layout is part of the type's contract: it is `#[repr(C)]` and
/// implements `bytemuck::Pod`, so reordering or retyping fields changes how
/// existing byte buffers are interpreted.
#[repr(C)]
#[derive(Copy, Clone)]
pub struct ConcurrentMerkleTree<const MAX_DEPTH: usize, const MAX_BUFFER_SIZE: usize> {
    /// Count of mutating operations applied to the tree. Set to 1 by
    /// `initialize_with_root` and incremented (saturating) by every
    /// successful append / set / fill.
    pub sequence_number: u64,
    /// Index of most recent root & changes: the slot in `change_logs` that
    /// holds the current root
    pub active_index: u64,
    /// Number of active changes we are tracking, i.e. the number of valid
    /// entries in `change_logs` (never more than `MAX_BUFFER_SIZE`)
    pub buffer_size: u64,
    /// Circular buffer of the most recent changes. Each entry stores the root
    /// produced by a change together with the path that change wrote.
    pub change_logs: [ChangeLog<MAX_DEPTH>; MAX_BUFFER_SIZE],
    /// Proof and value of the rightmost leaf, which sits at `index - 1`;
    /// `index` itself is the position the next `append` writes to.
    pub rightmost_proof: Path<MAX_DEPTH>,
}
75
// SAFETY: `ConcurrentMerkleTree` is `#[repr(C)]` and is made only of `u64`
// counters, an array of `ChangeLog<MAX_DEPTH>` and a `Path<MAX_DEPTH>`. An
// all-zero bit pattern is valid for `u64`; treating it as valid for the whole
// struct assumes `ChangeLog` and `Path` are themselves `Zeroable`.
// NOTE(review): confirm in `changelog.rs` / `path.rs` whenever those types change.
unsafe impl<const MAX_DEPTH: usize, const MAX_BUFFER_SIZE: usize> Zeroable
    for ConcurrentMerkleTree<MAX_DEPTH, MAX_BUFFER_SIZE>
{
}
// SAFETY: `Pod` additionally requires `Copy` (the struct derives it), a
// `#[repr(C)]` layout (present), every bit pattern being valid, and no padding
// bytes. The three leading `u64` fields are tightly packed. The remaining
// guarantees hold only if `ChangeLog<MAX_DEPTH>` and `Path<MAX_DEPTH>` are
// `Pod` and their sizes keep the struct free of padding.
// NOTE(review): `Path` carries an explicit `_padding` field, which suggests this
// was considered; confirm whenever those types change.
unsafe impl<const MAX_DEPTH: usize, const MAX_BUFFER_SIZE: usize> Pod
    for ConcurrentMerkleTree<MAX_DEPTH, MAX_BUFFER_SIZE>
{
}
84
85impl<const MAX_DEPTH: usize, const MAX_BUFFER_SIZE: usize> Default
86    for ConcurrentMerkleTree<MAX_DEPTH, MAX_BUFFER_SIZE>
87{
88    fn default() -> Self {
89        Self {
90            sequence_number: 0,
91            active_index: 0,
92            buffer_size: 0,
93            change_logs: [ChangeLog::<MAX_DEPTH>::default(); MAX_BUFFER_SIZE],
94            rightmost_proof: Path::<MAX_DEPTH>::default(),
95        }
96    }
97}
98
99/// Arguments structure for initializing a tree with a root.
100pub struct InitializeWithRootArgs {
101    pub root: Node,
102    pub rightmost_leaf: Node,
103    pub proof_vec: Vec<Node>,
104    pub index: u32,
105}
106
107/// Arguments structure for setting a leaf in the tree.
108pub struct SetLeafArgs {
109    pub current_root: Node,
110    pub previous_leaf: Node,
111    pub new_leaf: Node,
112    pub proof_vec: Vec<Node>,
113    pub index: u32,
114}
115
116/// Arguments structure for filling an empty leaf or appending a new leaf to the
117/// tree.
118pub struct FillEmptyOrAppendArgs {
119    pub current_root: Node,
120    pub leaf: Node,
121    pub proof_vec: Vec<Node>,
122    pub index: u32,
123}
124
125/// Arguments structure for proving a leaf in the tree.
126pub struct ProveLeafArgs {
127    pub current_root: Node,
128    pub leaf: Node,
129    pub proof_vec: Vec<Node>,
130    pub index: u32,
131}
132
133impl<const MAX_DEPTH: usize, const MAX_BUFFER_SIZE: usize>
134    ConcurrentMerkleTree<MAX_DEPTH, MAX_BUFFER_SIZE>
135{
136    pub fn new() -> Self {
137        Self::default()
138    }
139
140    pub fn is_initialized(&self) -> bool {
141        !(self.buffer_size == 0 && self.sequence_number == 0 && self.active_index == 0)
142    }
143
144    /// This is the trustless initialization method that should be used in most
145    /// cases.
146    pub fn initialize(&mut self) -> Result<Node, ConcurrentMerkleTreeError> {
147        check_bounds(MAX_DEPTH, MAX_BUFFER_SIZE);
148        if self.is_initialized() {
149            return Err(ConcurrentMerkleTreeError::TreeAlreadyInitialized);
150        }
151        let mut rightmost_proof = Path::default();
152        let empty_node_cache = [Node::default(); MAX_DEPTH];
153        for (i, node) in rightmost_proof.proof.iter_mut().enumerate() {
154            *node = empty_node_cached::<MAX_DEPTH>(i as u32, &empty_node_cache);
155        }
156        let mut path = [Node::default(); MAX_DEPTH];
157        for (i, node) in path.iter_mut().enumerate() {
158            *node = empty_node_cached::<MAX_DEPTH>(i as u32, &empty_node_cache);
159        }
160        self.change_logs[0].root = empty_node(MAX_DEPTH as u32);
161        self.change_logs[0].path = path;
162        self.sequence_number = 0;
163        self.active_index = 0;
164        self.buffer_size = 1;
165        self.rightmost_proof = rightmost_proof;
166        Ok(self.change_logs[0].root)
167    }
168
169    /// This is a trustful initialization method that assumes the root contains
170    /// the expected leaves.
171    ///
172    /// At the time of this crate's publishing, there is no supported way to
173    /// efficiently verify a pre-initialized root on-chain. Using this
174    /// method before having a method for on-chain verification will prevent
175    /// other applications from indexing the leaf data stored in this tree.
176    pub fn initialize_with_root(
177        &mut self,
178        args: &InitializeWithRootArgs,
179    ) -> Result<Node, ConcurrentMerkleTreeError> {
180        check_bounds(MAX_DEPTH, MAX_BUFFER_SIZE);
181        check_leaf_index(args.index, MAX_DEPTH)?;
182
183        if self.is_initialized() {
184            return Err(ConcurrentMerkleTreeError::TreeAlreadyInitialized);
185        }
186        let mut proof: [Node; MAX_DEPTH] = [Node::default(); MAX_DEPTH];
187        proof.copy_from_slice(&args.proof_vec);
188        let rightmost_proof = Path {
189            proof,
190            index: args.index + 1,
191            leaf: args.rightmost_leaf,
192            _padding: 0,
193        };
194        self.change_logs[0].root = args.root;
195        self.sequence_number = 1;
196        self.active_index = 0;
197        self.buffer_size = 1;
198        self.rightmost_proof = rightmost_proof;
199        if args.root != recompute(args.rightmost_leaf, &proof, args.index) {
200            solana_logging!("Proof failed to verify");
201            return Err(ConcurrentMerkleTreeError::InvalidProof);
202        }
203        Ok(args.root)
204    }
205
206    /// Errors if one of the leaves of the current merkle tree is non-EMPTY
207    pub fn prove_tree_is_empty(&self) -> Result<(), ConcurrentMerkleTreeError> {
208        if !self.is_initialized() {
209            return Err(ConcurrentMerkleTreeError::TreeNotInitialized);
210        }
211        let empty_node_cache = [EMPTY; MAX_DEPTH];
212        if self.get_root() != empty_node_cached::<MAX_DEPTH>(MAX_DEPTH as u32, &empty_node_cache) {
213            return Err(ConcurrentMerkleTreeError::TreeNonEmpty);
214        }
215        Ok(())
216    }
217
218    /// Returns the current root of the merkle tree
219    pub fn get_root(&self) -> [u8; 32] {
220        self.get_change_log().root
221    }
222
223    /// Returns the most recent changelog
224    pub fn get_change_log(&self) -> Box<ChangeLog<MAX_DEPTH>> {
225        if !self.is_initialized() {
226            solana_logging!("Tree is not initialized, returning default change log");
227            return Box::<ChangeLog<MAX_DEPTH>>::default();
228        }
229        Box::new(self.change_logs[self.active_index as usize])
230    }
231
232    /// This method will fail if the leaf cannot be proven
233    /// to exist in the current tree root.
234    ///
235    /// This method will attempts to prove the leaf first
236    /// using the proof nodes provided. However if this fails,
237    /// then a proof will be constructed by inferring a proof
238    /// from the changelog buffer.
239    ///
240    /// Note: this is *not* the same as verifying that a (proof, leaf)
241    /// combination is valid for the given root. That functionality
242    /// is provided by `check_valid_proof`.
243    pub fn prove_leaf(&self, args: &ProveLeafArgs) -> Result<(), ConcurrentMerkleTreeError> {
244        check_bounds(MAX_DEPTH, MAX_BUFFER_SIZE);
245        check_leaf_index(args.index, MAX_DEPTH)?;
246        if !self.is_initialized() {
247            return Err(ConcurrentMerkleTreeError::TreeNotInitialized);
248        }
249
250        if args.index > self.rightmost_proof.index {
251            solana_logging!(
252                "Received an index larger than the rightmost index {} > {}",
253                args.index,
254                self.rightmost_proof.index
255            );
256            Err(ConcurrentMerkleTreeError::LeafIndexOutOfBounds)
257        } else {
258            let mut proof: [Node; MAX_DEPTH] = [Node::default(); MAX_DEPTH];
259            fill_in_proof::<MAX_DEPTH>(&args.proof_vec, &mut proof);
260            let valid_root =
261                self.check_valid_leaf(args.current_root, args.leaf, &mut proof, args.index, true)?;
262            if !valid_root {
263                solana_logging!("Proof failed to verify");
264                return Err(ConcurrentMerkleTreeError::InvalidProof);
265            }
266            Ok(())
267        }
268    }
269
270    /// Only used to initialize right most path for a completely empty tree.
271    #[inline(always)]
272    fn initialize_tree_from_append(
273        &mut self,
274        leaf: Node,
275        mut proof: [Node; MAX_DEPTH],
276    ) -> Result<Node, ConcurrentMerkleTreeError> {
277        let old_root = recompute(EMPTY, &proof, 0);
278        if old_root == empty_node(MAX_DEPTH as u32) {
279            self.try_apply_proof(old_root, EMPTY, leaf, &mut proof, 0, false)
280        } else {
281            Err(ConcurrentMerkleTreeError::TreeAlreadyInitialized)
282        }
283    }
284
285    /// Appending a non-empty Node will always succeed .
286    pub fn append(&mut self, mut node: Node) -> Result<Node, ConcurrentMerkleTreeError> {
287        check_bounds(MAX_DEPTH, MAX_BUFFER_SIZE);
288        if !self.is_initialized() {
289            return Err(ConcurrentMerkleTreeError::TreeNotInitialized);
290        }
291        if node == EMPTY {
292            return Err(ConcurrentMerkleTreeError::CannotAppendEmptyNode);
293        }
294        if self.rightmost_proof.index >= 1 << MAX_DEPTH {
295            return Err(ConcurrentMerkleTreeError::TreeFull);
296        }
297        if self.rightmost_proof.index == 0 {
298            return self.initialize_tree_from_append(node, self.rightmost_proof.proof);
299        }
300        let leaf = node;
301        let intersection = self.rightmost_proof.index.trailing_zeros() as usize;
302        let mut change_list = [EMPTY; MAX_DEPTH];
303        let mut intersection_node = self.rightmost_proof.leaf;
304        let empty_node_cache = [Node::default(); MAX_DEPTH];
305
306        for (i, cl_item) in change_list.iter_mut().enumerate().take(MAX_DEPTH) {
307            *cl_item = node;
308            match i {
309                i if i < intersection => {
310                    // Compute proof to the appended node from empty nodes
311                    let sibling = empty_node_cached::<MAX_DEPTH>(i as u32, &empty_node_cache);
312                    hash_to_parent(
313                        &mut intersection_node,
314                        &self.rightmost_proof.proof[i],
315                        ((self.rightmost_proof.index - 1) >> i) & 1 == 0,
316                    );
317                    hash_to_parent(&mut node, &sibling, true);
318                    self.rightmost_proof.proof[i] = sibling;
319                }
320                i if i == intersection => {
321                    // Compute the where the new node intersects the main tree
322                    hash_to_parent(&mut node, &intersection_node, false);
323                    self.rightmost_proof.proof[intersection] = intersection_node;
324                }
325                _ => {
326                    // Update the change list path up to the root
327                    hash_to_parent(
328                        &mut node,
329                        &self.rightmost_proof.proof[i],
330                        ((self.rightmost_proof.index - 1) >> i) & 1 == 0,
331                    );
332                }
333            }
334        }
335
336        self.update_internal_counters();
337        self.change_logs[self.active_index as usize] =
338            ChangeLog::<MAX_DEPTH>::new(node, change_list, self.rightmost_proof.index);
339        self.rightmost_proof.index += 1;
340        self.rightmost_proof.leaf = leaf;
341        Ok(node)
342    }
343
344    /// Convenience function for `set_leaf`
345    ///
346    /// This method will `set_leaf` if the leaf at `index` is an empty node,
347    /// otherwise it will `append` the new leaf.
348    pub fn fill_empty_or_append(
349        &mut self,
350        args: &FillEmptyOrAppendArgs,
351    ) -> Result<Node, ConcurrentMerkleTreeError> {
352        check_bounds(MAX_DEPTH, MAX_BUFFER_SIZE);
353        check_leaf_index(args.index, MAX_DEPTH)?;
354        if !self.is_initialized() {
355            return Err(ConcurrentMerkleTreeError::TreeNotInitialized);
356        }
357
358        let mut proof: [Node; MAX_DEPTH] = [Node::default(); MAX_DEPTH];
359        fill_in_proof::<MAX_DEPTH>(&args.proof_vec, &mut proof);
360
361        log_compute!();
362        match self.try_apply_proof(
363            args.current_root,
364            EMPTY,
365            args.leaf,
366            &mut proof,
367            args.index,
368            false,
369        ) {
370            Ok(new_root) => Ok(new_root),
371            Err(error) => match error {
372                ConcurrentMerkleTreeError::LeafContentsModified => self.append(args.leaf),
373                _ => Err(error),
374            },
375        }
376    }
377
378    /// This method will update the leaf at `index`.
379    ///
380    /// However if the proof cannot be verified, this method will fail.
381    pub fn set_leaf(&mut self, args: &SetLeafArgs) -> Result<Node, ConcurrentMerkleTreeError> {
382        check_bounds(MAX_DEPTH, MAX_BUFFER_SIZE);
383        check_leaf_index(args.index, MAX_DEPTH)?;
384        if !self.is_initialized() {
385            return Err(ConcurrentMerkleTreeError::TreeNotInitialized);
386        }
387
388        if args.index > self.rightmost_proof.index {
389            Err(ConcurrentMerkleTreeError::LeafIndexOutOfBounds)
390        } else {
391            let mut proof: [Node; MAX_DEPTH] = [Node::default(); MAX_DEPTH];
392            fill_in_proof::<MAX_DEPTH>(&args.proof_vec, &mut proof);
393
394            log_compute!();
395            self.try_apply_proof(
396                args.current_root,
397                args.previous_leaf,
398                args.new_leaf,
399                &mut proof,
400                args.index,
401                true,
402            )
403        }
404    }
405
406    /// Returns the Current Seq of the tree, the seq is the monotonic counter of
407    /// the tree operations that is incremented every time a mutable
408    /// operation is performed on the tree.
409    pub fn get_seq(&self) -> u64 {
410        self.sequence_number
411    }
412
413    /// Modifies the `proof` for leaf at `leaf_index`
414    /// in place by fast-forwarding the given `proof` through the
415    /// `changelog`s, starting at index `changelog_buffer_index`
416    /// Returns false if the leaf was updated in the change log
417    #[inline(always)]
418    fn fast_forward_proof(
419        &self,
420        leaf: &mut Node,
421        proof: &mut [Node; MAX_DEPTH],
422        leaf_index: u32,
423        mut changelog_buffer_index: u64,
424        use_full_buffer: bool,
425    ) -> bool {
426        solana_logging!(
427            "Fast-forwarding proof, starting index {}",
428            changelog_buffer_index
429        );
430        let mask: usize = MAX_BUFFER_SIZE - 1;
431
432        let mut updated_leaf = *leaf;
433        log_compute!();
434        // Modifies proof by iterating through the change log
435        loop {
436            // If use_full_buffer is false, this loop will terminate if the initial value of
437            // changelog_buffer_index is the active index
438            if !use_full_buffer && changelog_buffer_index == self.active_index {
439                break;
440            }
441            changelog_buffer_index = (changelog_buffer_index + 1) & mask as u64;
442            self.change_logs[changelog_buffer_index as usize].update_proof_or_leaf(
443                leaf_index,
444                proof,
445                &mut updated_leaf,
446            );
447            // If use_full_buffer is true, this loop will do 1 full pass of the change logs
448            if use_full_buffer && changelog_buffer_index == self.active_index {
449                break;
450            }
451        }
452        log_compute!();
453        let proof_leaf_unchanged = updated_leaf == *leaf;
454        *leaf = updated_leaf;
455        proof_leaf_unchanged
456    }
457
458    #[inline(always)]
459    fn find_root_in_changelog(&self, current_root: Node) -> Option<u64> {
460        let mask: usize = MAX_BUFFER_SIZE - 1;
461        for i in 0..self.buffer_size {
462            let j = self.active_index.wrapping_sub(i) & mask as u64;
463            if self.change_logs[j as usize].root == current_root {
464                return Some(j);
465            }
466        }
467        None
468    }
469
470    #[inline(always)]
471    fn check_valid_leaf(
472        &self,
473        current_root: Node,
474        leaf: Node,
475        proof: &mut [Node; MAX_DEPTH],
476        leaf_index: u32,
477        allow_inferred_proof: bool,
478    ) -> Result<bool, ConcurrentMerkleTreeError> {
479        let mask: usize = MAX_BUFFER_SIZE - 1;
480        let (changelog_index, use_full_buffer) = match self.find_root_in_changelog(current_root) {
481            Some(matching_changelog_index) => (matching_changelog_index, false),
482            None => {
483                if allow_inferred_proof {
484                    solana_logging!("Failed to find root in change log -> replaying full buffer");
485                    (
486                        self.active_index.wrapping_sub(self.buffer_size - 1) & mask as u64,
487                        true,
488                    )
489                } else {
490                    return Err(ConcurrentMerkleTreeError::RootNotFound);
491                }
492            }
493        };
494        let mut updatable_leaf_node = leaf;
495        let proof_leaf_unchanged = self.fast_forward_proof(
496            &mut updatable_leaf_node,
497            proof,
498            leaf_index,
499            changelog_index,
500            use_full_buffer,
501        );
502        if !proof_leaf_unchanged {
503            return Err(ConcurrentMerkleTreeError::LeafContentsModified);
504        }
505        Ok(self.check_valid_proof(updatable_leaf_node, proof, leaf_index))
506    }
507
508    /// Checks that the proof provided is valid for the current root.
509    pub fn check_valid_proof(
510        &self,
511        leaf: Node,
512        proof: &[Node; MAX_DEPTH],
513        leaf_index: u32,
514    ) -> bool {
515        if !self.is_initialized() {
516            solana_logging!("Tree is not initialized, returning false");
517            return false;
518        }
519        if check_leaf_index(leaf_index, MAX_DEPTH).is_err() {
520            solana_logging!("Leaf index out of bounds for max_depth");
521            return false;
522        }
523        recompute(leaf, proof, leaf_index) == self.get_root()
524    }
525
526    /// Note: Enabling `allow_inferred_proof` will fast forward the given proof
527    /// from the beginning of the buffer in the case that the supplied root is
528    /// not in the buffer.
529    #[inline(always)]
530    fn try_apply_proof(
531        &mut self,
532        current_root: Node,
533        leaf: Node,
534        new_leaf: Node,
535        proof: &mut [Node; MAX_DEPTH],
536        leaf_index: u32,
537        allow_inferred_proof: bool,
538    ) -> Result<Node, ConcurrentMerkleTreeError> {
539        solana_logging!("Active Index: {}", self.active_index);
540        solana_logging!("Rightmost Index: {}", self.rightmost_proof.index);
541        solana_logging!("Buffer Size: {}", self.buffer_size);
542        solana_logging!("Leaf Index: {}", leaf_index);
543        let valid_root =
544            self.check_valid_leaf(current_root, leaf, proof, leaf_index, allow_inferred_proof)?;
545        if !valid_root {
546            return Err(ConcurrentMerkleTreeError::InvalidProof);
547        }
548        self.update_internal_counters();
549        Ok(self.update_buffers_from_proof(new_leaf, proof, leaf_index))
550    }
551
552    /// Implements circular addition for changelog buffer index
553    fn update_internal_counters(&mut self) {
554        let mask: usize = MAX_BUFFER_SIZE - 1;
555        self.active_index += 1;
556        self.active_index &= mask as u64;
557        if self.buffer_size < MAX_BUFFER_SIZE as u64 {
558            self.buffer_size += 1;
559        }
560        self.sequence_number = self.sequence_number.saturating_add(1);
561    }
562
563    /// Creates a new root from a proof that is valid for the root at
564    /// `self.active_index`
565    fn update_buffers_from_proof(&mut self, start: Node, proof: &[Node], index: u32) -> Node {
566        let change_log = &mut self.change_logs[self.active_index as usize];
567        // Also updates change_log's current root
568        let root = change_log.replace_and_recompute_path(index, start, proof);
569        // Update rightmost path if possible
570        if self.rightmost_proof.index < (1 << MAX_DEPTH) {
571            if index < self.rightmost_proof.index {
572                change_log.update_proof_or_leaf(
573                    self.rightmost_proof.index - 1,
574                    &mut self.rightmost_proof.proof,
575                    &mut self.rightmost_proof.leaf,
576                );
577            } else {
578                assert!(index == self.rightmost_proof.index);
579                solana_logging!("Appending rightmost leaf");
580                self.rightmost_proof.proof.copy_from_slice(proof);
581                self.rightmost_proof.index = index + 1;
582                self.rightmost_proof.leaf = change_log.get_leaf();
583            }
584        }
585        root
586    }
587}