// snarkvm_console_collections/merkle_tree/mod.rs
// Copyright (c) 2019-2026 Provable Inc.
// This file is part of the snarkVM library.

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at:

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
15
16mod helpers;
17pub use helpers::*;
18
19mod path;
20pub use path::*;
21
22#[cfg(test)]
23mod tests;
24
25use snarkvm_console_types::prelude::*;
26
27use aleo_std::prelude::*;
28
29#[cfg(feature = "locktick")]
30use locktick::parking_lot::Mutex;
31#[cfg(not(feature = "locktick"))]
32use parking_lot::Mutex;
33use serde::{Deserialize, Serialize};
34use std::{collections::BTreeMap, mem};
35
36#[cfg(not(feature = "serial"))]
37use rayon::prelude::*;
38
/// A binary Merkle tree constructed with a leaf-digest hash function and a
/// two-to-one compressing hash function.
///
/// If the number of leaves is less than `2**DEPTH`, the leaf layer is first
/// padded to the next power of 2 with the empty-hash value `e` returned by the
/// implementation of `PathHash::hash_empty()` for `PH`, then a balanced binary
/// tree is built. In concrete terms, at most one `e` leaf is added: the rest
/// are only virtual in that instead nodes with the value `PH::hash_children(e,
/// e)` are added to the next level, which is indeed full of size equal to a
/// power of 2.
///
/// Padding levels are then added as needed to reach the full `DEPTH`, each of
/// which is constructed by hashing the root of the previous level together with
/// `e`.
///
/// Storage layout: every stored hash lives in the single vector `tree`. The
/// root of the unpadded tree is at index 0, and the leaf level starts at index
/// `number_of_leaves.next_power_of_two() - 1`. Leaf slots beyond the supplied
/// leaves (other than the single `e` leaf mentioned above) are not materialised.
#[derive(Deserialize, Serialize)]
#[serde(bound = "E: Serialize + DeserializeOwned, LH: Serialize + DeserializeOwned, PH: Serialize + DeserializeOwned")]
pub struct MerkleTree<E: Environment, LH: LeafHash<Hash = PH::Hash>, PH: PathHash<Hash = Field<E>>, const DEPTH: u8> {
    /// The leaf hasher for the Merkle tree.
    leaf_hasher: LH,
    /// The path hasher for the Merkle tree.
    path_hasher: PH,
    /// The computed root of the full Merkle tree, i.e. after the padding levels up to `DEPTH`
    /// have been applied.
    root: PH::Hash,
    /// The internal hashes, from root to hashed leaves, of the full Merkle tree.
    /// Index 0 holds the root *before* the padding levels are applied.
    tree: Vec<PH::Hash>,
    /// The canonical empty hash (`e` above).
    empty_hash: Field<E>,
    /// The number of hashed leaves in the tree (not counting any `e` padding leaf).
    number_of_leaves: usize,
    /// An optimization: the previous tree allocation reused in prepare_append.
    /// It is skipped by serde and is reset (not copied) by `Clone`.
    #[serde(skip)]
    preserved_tree_allocation: Mutex<Option<Vec<PH::Hash>>>,
}
72
73impl<E: Environment, LH: LeafHash<Hash = PH::Hash>, PH: PathHash<Hash = Field<E>>, const DEPTH: u8> Clone
74    for MerkleTree<E, LH, PH, DEPTH>
75{
76    fn clone(&self) -> Self {
77        Self {
78            leaf_hasher: self.leaf_hasher.clone(),
79            path_hasher: self.path_hasher.clone(),
80            root: self.root,
81            tree: self.tree.clone(),
82            empty_hash: self.empty_hash,
83            number_of_leaves: self.number_of_leaves,
84            preserved_tree_allocation: Default::default(),
85        }
86    }
87}
88
89impl<E: Environment, LH: LeafHash<Hash = PH::Hash>, PH: PathHash<Hash = Field<E>>, const DEPTH: u8>
90    MerkleTree<E, LH, PH, DEPTH>
91{
92    #[inline]
93    /// Initializes a new Merkle tree with the given leaves.
94    pub fn new(leaf_hasher: &LH, path_hasher: &PH, leaves: &[LH::Leaf]) -> Result<Self> {
95        let timer = timer!("MerkleTree::new");
96
97        // Ensure the Merkle tree depth is greater than 0.
98        ensure!(DEPTH > 0, "Merkle tree depth must be greater than 0");
99        // Ensure the Merkle tree depth is less than or equal to 64.
100        ensure!(DEPTH <= 64u8, "Merkle tree depth must be less than or equal to 64");
101
102        // Compute the maximum number of leaves.
103        let max_leaves = match leaves.len().checked_next_power_of_two() {
104            Some(num_leaves) => num_leaves,
105            None => bail!("Integer overflow when computing the maximum number of leaves in the Merkle tree"),
106        };
107
108        // Compute the number of nodes.
109        let num_nodes = max_leaves - 1;
110        // Compute the tree size as the maximum number of leaves plus the number of nodes.
111        let tree_size = max_leaves + num_nodes;
112        // Compute the number of levels in the Merkle tree (i.e. log2(tree_size)).
113        let tree_depth = tree_depth::<DEPTH>(tree_size)?;
114        // Compute the number of padded levels.
115        let padding_depth = DEPTH - tree_depth;
116
117        // Compute the empty hash.
118        let empty_hash = path_hasher.hash_empty()?;
119
120        // Calculate the size of the tree which excludes leafless nodes.
121        // The minimum tree size is either a single root node or the calculated number of nodes plus
122        // the supplied leaves; if the number of leaves is odd, an empty hash is added for padding.
123        let minimum_tree_size =
124            std::cmp::max(1, num_nodes + leaves.len() + if leaves.len() > 1 { leaves.len() % 2 } else { 0 });
125
126        // Initialize the Merkle tree.
127        let mut tree = vec![empty_hash; minimum_tree_size];
128
129        // Compute and store each leaf hash.
130        tree[num_nodes..num_nodes + leaves.len()].copy_from_slice(&leaf_hasher.hash_leaves(leaves)?);
131        lap!(timer, "Hashed {} leaves", leaves.len());
132
133        // Compute and store the hashes for each level, iterating from the penultimate level to the root level.
134        let mut start_index = num_nodes;
135        // Compute the start index of the current level.
136        while let Some(start) = parent(start_index) {
137            // Compute the end index of the current level.
138            let end = left_child(start);
139            // Construct the children for each node in the current level; the leaves are padded, which means
140            // that there either are 2 children, or there are none, at which point we may stop iterating.
141            let tuples = (start..end)
142                .take_while(|&i| tree.get(left_child(i)).is_some())
143                .map(|i| (tree[left_child(i)], tree[right_child(i)]))
144                .collect::<Vec<_>>();
145            // Compute and store the hashes for each node in the current level.
146            let num_full_nodes = tuples.len();
147            tree[start..][..num_full_nodes].copy_from_slice(&path_hasher.hash_all_children(&tuples)?);
148            // Use the precomputed empty node hash for every empty node, if there are any.
149            if start + num_full_nodes < end {
150                let empty_node_hash = path_hasher.hash_children(&empty_hash, &empty_hash)?;
151                for node in tree.iter_mut().take(end).skip(start + num_full_nodes) {
152                    *node = empty_node_hash;
153                }
154            }
155            // Update the start index for the next level.
156            start_index = start;
157        }
158        lap!(timer, "Hashed {} levels", tree_depth);
159
160        // Compute the root hash, by iterating from the root level up to `DEPTH`.
161        let mut root_hash = tree[0];
162        for _ in 0..padding_depth {
163            // Update the root hash, by hashing the current root hash with the empty hash.
164            root_hash = path_hasher.hash_children(&root_hash, &empty_hash)?;
165        }
166        lap!(timer, "Hashed {} padding levels", padding_depth);
167
168        finish!(timer);
169
170        Ok(Self {
171            leaf_hasher: leaf_hasher.clone(),
172            path_hasher: path_hasher.clone(),
173            root: root_hash,
174            tree,
175            empty_hash,
176            number_of_leaves: leaves.len(),
177            preserved_tree_allocation: Default::default(),
178        })
179    }
180
181    #[inline]
182    /// Returns a new Merkle tree with the given new leaves appended to it.
183    pub fn prepare_append(&self, new_leaves: &[LH::Leaf]) -> Result<Self> {
184        let timer = timer!("MerkleTree::prepare_append");
185
186        // Compute the maximum number of leaves.
187        let max_leaves = match (self.number_of_leaves + new_leaves.len()).checked_next_power_of_two() {
188            Some(num_leaves) => num_leaves,
189            None => bail!("Integer overflow when computing the maximum number of leaves in the Merkle tree"),
190        };
191        // Compute the number of nodes.
192        let num_nodes = max_leaves - 1;
193        // Compute the tree size as the maximum number of leaves plus the number of nodes.
194        let tree_size = num_nodes + max_leaves;
195        // Compute the number of levels in the Merkle tree (i.e. log2(tree_size)).
196        let tree_depth = tree_depth::<DEPTH>(tree_size)?;
197        // Compute the number of padded levels.
198        let padding_depth = DEPTH - tree_depth;
199
200        // Reuse the previous Merkle tree, or initialize it if missing.
201        // All the (inner) nodes are rewritten, so their current values are irrelevant.
202        // The slowest part is populating the values, but large allocations are also slow.
203        let mut tree = self.preserved_tree_allocation.lock().take().unwrap_or_else(|| vec![self.empty_hash; num_nodes]);
204        // The number of nodes in the preserved allocation is too small if the depth increases.
205        // This is basically a noop if there are sufficient nodes already.
206        tree.resize(num_nodes, self.empty_hash);
207
208        // Extend the new Merkle tree with the existing leaf hashes.
209        tree.extend(self.leaf_hashes()?);
210        // Extend the new Merkle tree with the new leaf hashes.
211        tree.extend(&self.leaf_hasher.hash_leaves(new_leaves)?);
212
213        // Calculate the size of the tree which excludes leafless nodes.
214        let new_number_of_leaves = self.number_of_leaves + new_leaves.len();
215        let minimum_tree_size = std::cmp::max(
216            1,
217            num_nodes + new_number_of_leaves + if new_number_of_leaves > 1 { new_number_of_leaves % 2 } else { 0 },
218        );
219
220        // Resize the new Merkle tree with empty hashes to pad up to `tree_size`.
221        tree.resize(minimum_tree_size, self.empty_hash);
222        lap!(timer, "Hashed {} new leaves", new_leaves.len());
223
224        // Initialize a start index to track the starting index of the current level.
225        let start_index = num_nodes;
226        // Initialize a middle index to separate the precomputed indices from the new indices that need to be computed.
227        let middle_index = num_nodes + self.number_of_leaves;
228        // Initialize a precompute index to track the starting index of each precomputed level.
229        let start_precompute_index = match self.number_of_leaves.checked_next_power_of_two() {
230            Some(num_leaves) => num_leaves - 1,
231            None => bail!("Integer overflow when computing the Merkle tree precompute index"),
232        };
233        // Initialize a precompute index to track the middle index of each precomputed level.
234        let middle_precompute_index = match num_nodes == start_precompute_index {
235            // If the old tree and new tree are of the same size, then we can copy over the right half of the old tree.
236            true => Some(start_precompute_index + self.number_of_leaves + new_leaves.len() + 1),
237            // Otherwise, we need to compute the right half of the new tree.
238            false => None,
239        };
240
241        // Compute and store the hashes for each level, iterating from the penultimate level to the root level.
242        self.compute_updated_tree(
243            &mut tree,
244            start_index,
245            middle_index,
246            start_precompute_index,
247            middle_precompute_index,
248        )?;
249
250        // Compute the root hash, by iterating from the root level up to `DEPTH`.
251        let mut root_hash = tree[0];
252        for _ in 0..padding_depth {
253            // Update the root hash, by hashing the current root hash with the empty hash.
254            root_hash = self.path_hasher.hash_children(&root_hash, &self.empty_hash)?;
255        }
256        lap!(timer, "Hashed {} padding levels", padding_depth);
257
258        finish!(timer);
259
260        Ok(Self {
261            leaf_hasher: self.leaf_hasher.clone(),
262            path_hasher: self.path_hasher.clone(),
263            root: root_hash,
264            tree,
265            empty_hash: self.empty_hash,
266            number_of_leaves: self.number_of_leaves + new_leaves.len(),
267            preserved_tree_allocation: Default::default(), // Placeholder; will be updated at the callsite using Self::preserve_tree_allocation
268        })
269    }
270
271    #[inline]
272    /// Updates the Merkle tree with the given new leaves appended to it.
273    pub fn append(&mut self, new_leaves: &[LH::Leaf]) -> Result<()> {
274        let timer = timer!("MerkleTree::append");
275
276        // Compute the updated Merkle tree with the new leaves.
277        let updated_tree = self.prepare_append(new_leaves)?;
278        // Update the tree at the very end, so the original tree is not altered in case of failure.
279        *self = updated_tree;
280
281        finish!(timer);
282        Ok(())
283    }
284
285    #[inline]
286    /// Updates the Merkle tree at the location of the given leaf index with the new leaf.
287    pub fn update(&mut self, leaf_index: usize, new_leaf: &LH::Leaf) -> Result<()> {
288        let timer = timer!("MerkleTree::update");
289
290        // Compute the updated Merkle tree with the new leaves.
291        let updated_tree = self.prepare_update(leaf_index, new_leaf)?;
292        // Update the tree at the very end, so the original tree is not altered in case of failure.
293        *self = updated_tree;
294
295        finish!(timer);
296        Ok(())
297    }
298
299    #[inline]
300    /// Returns a new Merkle tree with updates at the location of the given leaf index with the new leaf.
301    pub fn prepare_update(&self, leaf_index: usize, new_leaf: &LH::Leaf) -> Result<Self> {
302        let timer = timer!("MerkleTree::prepare_update");
303
304        // Check that the leaf index is within the bounds of the Merkle tree.
305        ensure!(
306            leaf_index < self.number_of_leaves,
307            "Leaf index must be less than the number of leaves in the Merkle tree {leaf_index} , {}",
308            self.number_of_leaves
309        );
310
311        // Allocate a vector to store the path hashes.
312        let mut path_hashes = Vec::with_capacity(DEPTH as usize);
313
314        // Compute and add the new leaf hash to the path hashes.
315        path_hashes.push(self.leaf_hasher.hash_leaf(new_leaf)?);
316        lap!(timer, "Hashed 1 new leaf");
317
318        // Compute the start index (on the left) for the leaf hashes level in the Merkle tree.
319        let start = match self.number_of_leaves.checked_next_power_of_two() {
320            Some(num_leaves) => num_leaves - 1,
321            None => bail!("Integer overflow when computing the Merkle tree start index"),
322        };
323
324        // Compute the new hashes for the path from the leaf to the root.
325        let mut index = start + leaf_index;
326        while let Some(parent) = parent(index) {
327            // Get the left and right child hashes of the parent.
328            let (left, right) = match is_left_child(index) {
329                true => (path_hashes.last().unwrap(), &self.tree[right_child(parent)]),
330                false => (&self.tree[left_child(parent)], path_hashes.last().unwrap()),
331            };
332            // Compute and add the new parent hash to the path hashes.
333            path_hashes.push(self.path_hasher.hash_children(left, right)?);
334            // Update the index to the parent.
335            index = parent;
336        }
337
338        // Compute the number of levels in the Merkle tree (i.e. log2(tree_size)).
339        let tree_depth = tree_depth::<DEPTH>(self.tree.len())?;
340        // Compute the padding depth.
341        let padding_depth = DEPTH - tree_depth;
342
343        // Update the root hash.
344        // This unwrap is safe, as the path hashes vector is guaranteed to have at least one element.
345        let mut root_hash = *path_hashes.last().unwrap();
346        for _ in 0..padding_depth {
347            // Update the root hash, by hashing the current root hash with the empty hash.
348            root_hash = self.path_hasher.hash_children(&root_hash, &self.empty_hash)?;
349        }
350        lap!(timer, "Hashed {} padding levels", padding_depth);
351
352        // Initialize the Merkle tree.
353        let mut tree = Vec::with_capacity(self.tree.len());
354        // Extend the new Merkle tree with the existing leaf hashes.
355        tree.extend(&self.tree);
356
357        // Update the rest of the tree with the new path hashes.
358        let mut index = Some(start + leaf_index);
359        for path_hash in path_hashes {
360            tree[index.unwrap()] = path_hash;
361            index = parent(index.unwrap());
362        }
363
364        finish!(timer);
365
366        Ok(Self {
367            leaf_hasher: self.leaf_hasher.clone(),
368            path_hasher: self.path_hasher.clone(),
369            root: root_hash,
370            tree,
371            empty_hash: self.empty_hash,
372            number_of_leaves: self.number_of_leaves,
373            preserved_tree_allocation: Default::default(),
374        })
375    }
376
    #[inline]
    /// Updates the Merkle tree at the location of the given leaf indices with the new leaves.
    ///
    /// The whole tree is not recomputed. Only the nodes on the paths from the updated leaves to the
    /// stored root are rehashed, one level at a time. When two updated nodes are siblings, their
    /// shared parent is hashed only once. Every hash is computed before `self` is modified, so on
    /// error the tree is left unchanged.
    ///
    /// # Errors
    /// Fails if any of the following holds:
    /// - `updates` is empty;
    /// - the largest leaf index is not below the number of leaves;
    /// - the leaf count cannot be rounded up to a power of two;
    /// - `tree_depth` rejects the stored tree size;
    /// - any hashing step fails.
    pub fn update_many(&mut self, updates: &BTreeMap<usize, LH::Leaf>) -> Result<()> {
        let timer = timer!("MerkleTree::update_many");

        // Check that there are updates to perform.
        ensure!(!updates.is_empty(), "There must be at least one leaf to update in the Merkle tree");

        // Check that the latest leaf index is less than number of leaves in the Merkle tree.
        // Since `BTreeMap` keys are sorted, checking the last key bounds every key.
        // Note: This unwrap is safe since updates is guaranteed to be non-empty.
        ensure!(
            *updates.last_key_value().unwrap().0 < self.number_of_leaves,
            "Leaf index must be less than the number of leaves in the Merkle tree"
        );

        // Compute the start index (on the left) for the leaf hashes level in the Merkle tree.
        let start = match self.number_of_leaves.checked_next_power_of_two() {
            Some(num_leaves) => num_leaves - 1,
            None => bail!("Integer overflow when computing the Merkle tree start index"),
        };

        // A helper to compute the leaf hash, paired with the leaf's absolute index in `self.tree`.
        let hash_update = |(leaf_index, leaf): &(&usize, &LH::Leaf)| {
            self.leaf_hasher.hash_leaf(leaf).map(|hash| (start + **leaf_index, hash))
        };

        // Hash the leaves and add them to the updated hashes.
        // Batches of up to 100 are hashed sequentially. Larger ones go through `cfg_iter!`, which is
        // presumably parallel (rayon) unless the `serial` feature is enabled.
        let leaf_hashes: Vec<(usize, LH::Hash)> = match updates.len() {
            0..=100 => updates.iter().map(|update| hash_update(&update)).collect::<Result<Vec<_>>>()?,
            _ => cfg_iter!(updates).map(|update| hash_update(&update)).collect::<Result<Vec<_>>>()?,
        };
        lap!(timer, "Hashed {} new leaves", leaf_hashes.len());

        // Store the updated hashes by level: `updated_hashes[0]` holds the leaves, and each later
        // entry holds the (index, hash) pairs of the level above, in ascending index order.
        let mut updated_hashes = Vec::new();
        updated_hashes.push(leaf_hashes);

        // A helper function to compute the path hashes for a given level.
        // Each input is `(parent_index, (left_child_hash, right_child_hash))`.
        type Update<PH> = (usize, (<PH as PathHash>::Hash, <PH as PathHash>::Hash));
        let compute_path_hashes = |inputs: &[Update<PH>]| match inputs.len() {
            0..=100 => inputs
                .iter()
                .map(|(index, (left, right))| self.path_hasher.hash_children(left, right).map(|hash| (*index, hash)))
                .collect::<Result<Vec<_>>>(),
            _ => cfg_iter!(inputs)
                .map(|(index, (left, right))| self.path_hasher.hash_children(left, right).map(|hash| (*index, hash)))
                .collect::<Result<Vec<_>>>(),
        };

        // Compute the depth of the tree. This corresponds to the number of levels of hashes in the tree.
        let tree_depth = tree_depth::<DEPTH>(self.tree.len())?;
        // Allocate a vector to store the inputs to the path hasher.
        let mut inputs = Vec::with_capacity(updated_hashes[0].len());
        // For each level in the tree, compute the path hashes.
        // In the first iteration, we compute the path hashes for the updated leaf hashes.
        // In the subsequent iterations, we compute the path hashes for the updated path hashes, until we reach the root.
        for level in 0..tree_depth as usize {
            let mut current = 0;
            while current < updated_hashes[level].len() {
                let (current_leaf_index, current_leaf_hash) = updated_hashes[level][current];
                // Get the sibling of the current leaf.
                let sibling_leaf_index = match sibling(current_leaf_index) {
                    Some(sibling_index) => sibling_index,
                    // If there is no sibling, then we have reached the root.
                    None => break,
                };
                // Check if the sibling hash is the next hash in the vector.
                let sibling_is_next_hash = match current + 1 < updated_hashes[level].len() {
                    true => updated_hashes[level][current + 1].0 == sibling_leaf_index,
                    false => false,
                };
                // Get the sibling hash.
                // Note: This algorithm assumes that the sibling hash is either the next hash in the vector,
                // or in the original Merkle tree. Consequently, updates need to be provided in sequential order.
                // This is enforced by the type of `updates: BTreeMap<usize, LH::Leaf>`.
                // If this assumption is violated, then the algorithm will compute incorrect path hashes in the Merkle tree.
                let sibling_leaf_hash = match sibling_is_next_hash {
                    true => updated_hashes[level][current + 1].1,
                    false => self.tree[sibling_leaf_index],
                };
                // Order the current and sibling hashes.
                let (left, right) = match is_left_child(current_leaf_index) {
                    true => (current_leaf_hash, sibling_leaf_hash),
                    false => (sibling_leaf_hash, current_leaf_hash),
                };
                // Compute the parent index.
                // Note that this unwrap is safe, since we check that the `current_leaf_index` is not the root.
                let parent_index = parent(current_leaf_index).unwrap();
                // Add the parent hash to the updated hashes.
                inputs.push((parent_index, (left, right)));
                // Update the current index; skip the sibling too when it was consumed as part of this pair.
                match sibling_is_next_hash {
                    true => current += 2,
                    false => current += 1,
                }
            }
            // Compute the path hashes for the current level.
            let path_hashes = compute_path_hashes(&inputs)?;
            // Add the path hashes to the updated hashes.
            updated_hashes.push(path_hashes);
            // Clear the inputs (keeping the allocation for the next level).
            inputs.clear();
        }

        // Compute the padding depth.
        let padding_depth = DEPTH - tree_depth;

        // Update the root hash.
        // This unwrap is safe, as the updated hashes is guaranteed to have at least one element.
        // The last level holds exactly the stored root (index 0).
        let mut root_hash = updated_hashes.last().unwrap()[0].1;
        for _ in 0..padding_depth {
            // Update the root hash, by hashing the current root hash with the empty hash.
            root_hash = self.path_hasher.hash_children(&root_hash, &self.empty_hash)?;
        }
        lap!(timer, "Hashed {} padding levels", padding_depth);

        // Update the root hash. From here on nothing can fail, so `self` is only mutated after all hashing succeeded.
        self.root = root_hash;

        // Update the rest of the tree with the updated hashes.
        for (index, hash) in updated_hashes.into_iter().flatten() {
            self.tree[index] = hash;
        }

        finish!(timer);
        Ok(())
    }
504
505    #[inline]
506    /// Returns a new Merkle tree with the last 'n' leaves removed from it.
507    pub fn prepare_remove_last_n(&self, n: usize) -> Result<Self> {
508        let timer = timer!("MerkleTree::prepare_remove_last_n");
509
510        ensure!(n > 0, "Cannot remove zero leaves from the Merkle tree");
511
512        // Determine the updated number of leaves, after removing the last 'n' leaves.
513        let updated_number_of_leaves = self.number_of_leaves.checked_sub(n).ok_or_else(|| {
514            anyhow!("Failed to remove '{n}' leaves from the Merkle tree, as it only contains {}", self.number_of_leaves)
515        })?;
516
517        // Compute the maximum number of leaves.
518        let max_leaves = match (updated_number_of_leaves).checked_next_power_of_two() {
519            Some(num_leaves) => num_leaves,
520            None => bail!("Integer overflow when computing the maximum number of leaves in the Merkle tree"),
521        };
522        // Compute the number of nodes.
523        let num_nodes = max_leaves - 1;
524        // Compute the tree size as the maximum number of leaves plus the number of nodes.
525        let tree_size = num_nodes + max_leaves;
526        // Compute the number of levels in the Merkle tree (i.e. log2(tree_size)).
527        let tree_depth = tree_depth::<DEPTH>(tree_size)?;
528        // Compute the number of padded levels.
529        let padding_depth = DEPTH - tree_depth;
530
531        // Calculate the size of the tree which excludes leafless nodes.
532        let minimum_tree_size = std::cmp::max(
533            1,
534            num_nodes
535                + updated_number_of_leaves
536                + if updated_number_of_leaves > 1 { updated_number_of_leaves % 2 } else { 0 },
537        );
538
539        // Initialize the Merkle tree.
540        let mut tree = vec![self.empty_hash; num_nodes];
541        // Extend the new Merkle tree with the existing leaf hashes, excluding the last 'n' leaves.
542        tree.extend(&self.leaf_hashes()?[..updated_number_of_leaves]);
543        // Resize the new Merkle tree with empty hashes to pad up to `tree_size`.
544        tree.resize(minimum_tree_size, self.empty_hash);
545        lap!(timer, "Resizing to {} leaves", updated_number_of_leaves);
546
547        // Initialize a start index to track the starting index of the current level.
548        let start_index = num_nodes;
549        // Initialize a middle index to separate the precomputed indices from the new indices that need to be computed.
550        let middle_index = num_nodes + updated_number_of_leaves;
551        // Initialize a precompute index to track the starting index of each precomputed level.
552        let start_precompute_index = match self.number_of_leaves.checked_next_power_of_two() {
553            Some(num_leaves) => num_leaves - 1,
554            None => bail!("Integer overflow when computing the Merkle tree precompute index"),
555        };
556        // Initialize a precompute index to track the middle index of each precomputed level.
557        let middle_precompute_index = match num_nodes == start_precompute_index {
558            // If the old tree and new tree are of the same size, then we can copy over the right half of the old tree.
559            true => Some(start_precompute_index + self.number_of_leaves + 1),
560            // true => None,
561            // Otherwise, do nothing, since shrinking the tree is already free.
562            false => None,
563        };
564
565        // Compute and store the hashes for each level, iterating from the penultimate level to the root level.
566        self.compute_updated_tree(
567            &mut tree,
568            start_index,
569            middle_index,
570            start_precompute_index,
571            middle_precompute_index,
572        )?;
573
574        // Compute the root hash, by iterating from the root level up to `DEPTH`.
575        let mut root_hash = tree[0];
576        for _ in 0..padding_depth {
577            // Update the root hash, by hashing the current root hash with the empty hash.
578            root_hash = self.path_hasher.hash_children(&root_hash, &self.empty_hash)?;
579        }
580        lap!(timer, "Hashed {} padding levels", padding_depth);
581
582        finish!(timer);
583
584        Ok(Self {
585            leaf_hasher: self.leaf_hasher.clone(),
586            path_hasher: self.path_hasher.clone(),
587            root: root_hash,
588            tree,
589            empty_hash: self.empty_hash,
590            number_of_leaves: updated_number_of_leaves,
591            preserved_tree_allocation: Default::default(),
592        })
593    }
594
595    #[inline]
596    /// Updates the Merkle tree with the last 'n' leaves removed from it.
597    pub fn remove_last_n(&mut self, n: usize) -> Result<()> {
598        let timer = timer!("MerkleTree::remove_last_n");
599
600        // Compute the updated Merkle tree with the last 'n' leaves removed.
601        let updated_tree = self.prepare_remove_last_n(n)?;
602        // Update the tree at the very end, so the original tree is not altered in case of failure.
603        *self = updated_tree;
604
605        finish!(timer);
606        Ok(())
607    }
608
609    #[inline]
610    /// Returns the Merkle path for the given leaf index and leaf.
611    pub fn prove(&self, leaf_index: usize, leaf: &LH::Leaf) -> Result<MerklePath<E, DEPTH>> {
612        // Ensure the leaf index is valid.
613        ensure!(leaf_index < self.number_of_leaves, "The given Merkle leaf index is out of bounds");
614
615        // Compute the leaf hash.
616        let leaf_hash = self.leaf_hasher.hash_leaf(leaf)?;
617
618        // Compute the start index (on the left) for the leaf hashes level in the Merkle tree.
619        let start = match self.number_of_leaves.checked_next_power_of_two() {
620            Some(num_leaves) => num_leaves - 1,
621            None => bail!("Integer overflow when computing the Merkle tree start index"),
622        };
623
624        // Compute the absolute index of the leaf in the Merkle tree.
625        let mut index = start + leaf_index;
626        // Ensure the leaf index is valid.
627        ensure!(index < self.tree.len(), "The given Merkle leaf index is out of bounds");
628        // Ensure the leaf hash matches the one in the tree.
629        ensure!(self.tree[index] == leaf_hash, "The given Merkle leaf does not match the one in the Merkle tree");
630
631        // Initialize a vector for the Merkle path.
632        let mut path = Vec::with_capacity(DEPTH as usize);
633
634        // Iterate from the leaf hash to the root level, storing the sibling hashes along the path.
635        for _ in 0..DEPTH {
636            // Compute the index of the sibling hash, if it exists.
637            if let Some(sibling) = sibling(index) {
638                // Append the sibling hash to the path.
639                path.push(self.tree[sibling]);
640                // Compute the index of the parent hash, if it exists.
641                match parent(index) {
642                    // Update the index to the parent index.
643                    Some(parent) => index = parent,
644                    // If the parent does not exist, the path is complete.
645                    None => break,
646                }
647            }
648        }
649
650        // If the Merkle path length is not equal to `DEPTH`, pad the path with the empty hash.
651        path.resize(DEPTH as usize, self.empty_hash);
652
653        // Return the Merkle path.
654        MerklePath::try_from((U64::new(leaf_index as u64), path))
655    }
656
657    /// Returns `true` if the given Merkle path is valid for the given root and leaf.
658    pub fn verify(&self, path: &MerklePath<E, DEPTH>, root: &PH::Hash, leaf: &LH::Leaf) -> bool {
659        path.verify(&self.leaf_hasher, &self.path_hasher, root, leaf)
660    }
661
662    /// Returns the Merkle root of the tree.
663    pub const fn root(&self) -> &PH::Hash {
664        &self.root
665    }
666
667    /// Returns the Merkle tree (excluding the hashes of the leaves).
668    pub fn tree(&self) -> &[PH::Hash] {
669        &self.tree
670    }
671
672    /// Returns the empty hash.
673    pub const fn empty_hash(&self) -> &PH::Hash {
674        &self.empty_hash
675    }
676
677    /// Returns the leaf hashes from the Merkle tree.
678    pub fn leaf_hashes(&self) -> Result<&[LH::Hash]> {
679        // Compute the start index (on the left) for the leaf hashes level in the Merkle tree.
680        let start = match self.number_of_leaves.checked_next_power_of_two() {
681            Some(num_leaves) => num_leaves - 1,
682            None => bail!("Integer overflow when computing the Merkle tree start index"),
683        };
684        // Compute the end index (on the right) for the leaf hashes level in the Merkle tree.
685        let end = start + self.number_of_leaves;
686        // Return the leaf hashes.
687        Ok(&self.tree[start..end])
688    }
689
690    /// Returns the number of leaves in the Merkle tree.
691    pub const fn number_of_leaves(&self) -> usize {
692        self.number_of_leaves
693    }
694
    /// Fills the levels of `tree` above its starting level, one level per iteration, from the
    /// penultimate level up to the root (index `0`). Hashes from the old tree (`self.tree`) are
    /// reused where possible.
    ///
    /// `tree` is the buffer of the tree being built, in this module's heap layout: node `i` has
    /// children `left_child(i)` and `right_child(i)`, and parent `parent(i)`.
    ///
    /// Each processed level spans `start..end` and is filled in up to three segments:
    ///
    /// ```ignore
    ///  start                 middle                     middle_precompute             end
    ///    |-- copied (left) ----|-- recomputed (hashed) ----|-- copied (right) ----------|
    ///    from self.tree[start_precompute..]                 from self.tree[middle_precompute..end]
    /// ```
    ///
    /// - `[start, middle)` is copied from the old tree, starting at the first index of the
    ///   corresponding old level (`start_precompute`). Once the old tree has no higher level
    ///   (`parent` returns `None`), this segment must be empty, otherwise an error is returned.
    /// - The middle segment is recomputed with `path_hasher.hash_children` from the children
    ///   already present in `tree`. A child index past the end of `tree` is read as
    ///   `path_hasher.hash_empty()`.
    /// - `[middle_precompute, end)` is used only when `middle_precompute_index` is `Some`. It is
    ///   copied from `self.tree` at the *same* indices. When it is `None`, the middle segment
    ///   runs all the way to `end`.
    ///
    /// # Arguments
    ///
    /// - `tree`: destination buffer. The starting level must already be filled in.
    /// - `start_index`: first index of the starting level. Each level's end is derived as
    ///   `left_child(start)`, which is the level end only if `start` is the first index of its level.
    /// - `middle_index`: index on the starting level from which nodes are recomputed rather than
    ///   copied. Its parent is the first recomputed node of the next level.
    /// - `start_precompute_index`: first index of the corresponding level in `self.tree`.
    /// - `middle_precompute_index`: judging by the `middle_precompute + 1` update below, this is one
    ///   past the first reusable index of the level below. `parent` of that value is then the first
    ///   node whose two children both lie in the reusable range.
    ///   NOTE(review): confirm callers pass `first_reusable_index + 1` for the starting level.
    ///
    /// # Errors
    ///
    /// Propagates failures from `hash_empty` / `hash_children`. Also fails if a precomputed segment
    /// is left unprocessed (see the `ensure!` checks).
    ///
    /// # Panics
    ///
    /// Slice indexing panics if any derived range falls outside `tree` or `self.tree`.
    #[inline]
    fn compute_updated_tree(
        &self,
        tree: &mut [Field<E>],
        mut start_index: usize,
        mut middle_index: usize,
        mut start_precompute_index: usize,
        mut middle_precompute_index: Option<usize>,
    ) -> Result<()> {
        // Initialize a timer for the while loop.
        let timer = timer!("MerkleTree::compute_updated_tree");

        // Compute and store the hashes for each level, iterating from the penultimate level to the root level.
        // `empty_hash` stands in for any child index that lies past the end of `tree`.
        let empty_hash = self.path_hasher.hash_empty()?;
        // Each iteration moves one level up. The loop ends once either index is the root (index 0),
        // since `parent(0)` is `None`.
        while let (Some(start), Some(middle)) = (parent(start_index), parent(middle_index)) {
            // `start` is the first index of the current level, so the next level begins at
            // `left_child(start)`, which is therefore the (exclusive) end of this level.
            let end = left_child(start);

            // Left segment `[start, middle)`: copy from the corresponding level of the old tree
            // instead of recomputing it.
            if let Some(start_precompute) = parent(start_precompute_index) {
                // The copied range in the old tree has the same length as `[start, middle)`.
                let end_precompute = start_precompute + (middle - start);
                // Copy the hashes for each node in the current level.
                tree[start..middle].copy_from_slice(&self.tree[start_precompute..end_precompute]);
                // Update the precompute index for the next level.
                start_precompute_index = start_precompute;
            } else {
                // The old tree has no higher level left to copy from, so the left segment must be empty.
                ensure!(start == middle, "Failed to process all left precomputed indices in the Merkle tree");
            }
            lap!(timer, "Precompute (Left): {start} -> {middle}");

            // Right segment `[middle_precompute, end)`: copy from the old tree at the *same* indices.
            // Note: This logic works because the old tree and new tree are the same power of two.
            if let Some(middle_precompute) = middle_precompute_index {
                if let Some(middle_precompute) = parent(middle_precompute) {
                    // Construct the children for the new indices in the current level.
                    // The pairs are collected up front because the children are read from `tree`,
                    // which is mutably borrowed while the parent hashes are written below.
                    let tuples = (middle..middle_precompute)
                        .map(|i| {
                            (
                                tree.get(left_child(i)).copied().unwrap_or(empty_hash),
                                tree.get(right_child(i)).copied().unwrap_or(empty_hash),
                            )
                        })
                        .collect::<Vec<_>>();
                    // Process the indices that need to be computed for the current level.
                    // Levels with at least 100 nodes to hash go through the `cfg_iter*` macros
                    // (presumably parallel via rayon unless the `serial` feature is enabled).
                    // Smaller levels are hashed with a plain sequential iterator.
                    match tuples.len() >= 100 {
                        // Option 1: Hash the new indices via `cfg_iter_mut!` / `cfg_iter!`.
                        true => cfg_iter_mut!(tree[middle..middle_precompute]).zip_eq(cfg_iter!(tuples)).try_for_each(
                            |(node, (left, right))| {
                                *node = self.path_hasher.hash_children(left, right)?;
                                Ok::<_, Error>(())
                            },
                        )?,
                        // Option 2: Hash the new indices sequentially.
                        false => tree[middle..middle_precompute].iter_mut().zip_eq(&tuples).try_for_each(
                            |(node, (left, right))| {
                                *node = self.path_hasher.hash_children(left, right)?;
                                Ok::<_, Error>(())
                            },
                        )?,
                    }
                    lap!(timer, "Compute: {middle} -> {middle_precompute}");

                    // Copy the hashes for each node in the current level.
                    tree[middle_precompute..end].copy_from_slice(&self.tree[middle_precompute..end]);
                    // Update the precompute index for the next level. It stores one past the first
                    // copied index, so that `parent` of it is the first node of the next level whose
                    // two children were both copied.
                    middle_precompute_index = Some(middle_precompute + 1);
                    lap!(timer, "Precompute (Right): {middle_precompute} -> {end}");
                } else {
                    // `parent` returns `None` only for index 0. Inside the loop this index is always
                    // set to `parent(..) + 1 >= 1`, so this branch is reached only if the caller passed
                    // `Some(0)`. NOTE(review): `end = left_child(start)` is always >= 1, so the check
                    // below cannot pass here; it effectively rejects `Some(0)`.
                    // Ensure the middle precompute index is equal to the end index, as all precomputed indices have been processed.
                    ensure!(
                        middle_precompute == end,
                        "Failed to process all right precomputed indices in the Merkle tree"
                    );
                }
            } else {
                // No right segment: every node in `[middle, end)` is recomputed.
                // Construct the children for the new indices in the current level.
                // The pairs are collected up front because the children are read from `tree`,
                // which is mutably borrowed while the parent hashes are written below.
                let tuples = (middle..end)
                    .map(|i| {
                        (
                            tree.get(left_child(i)).copied().unwrap_or(empty_hash),
                            tree.get(right_child(i)).copied().unwrap_or(empty_hash),
                        )
                    })
                    .collect::<Vec<_>>();
                // Process the indices that need to be computed for the current level.
                // Same threshold as above: at least 100 nodes go through the `cfg_iter*` macros.
                match tuples.len() >= 100 {
                    // Option 1: Hash the new indices via `cfg_iter_mut!` / `cfg_iter!`.
                    true => cfg_iter_mut!(tree[middle..end]).zip_eq(cfg_iter!(tuples)).try_for_each(
                        |(node, (left, right))| {
                            *node = self.path_hasher.hash_children(left, right)?;
                            Ok::<_, Error>(())
                        },
                    )?,
                    // Option 2: Hash the new indices sequentially.
                    false => tree[middle..end].iter_mut().zip_eq(&tuples).try_for_each(|(node, (left, right))| {
                        *node = self.path_hasher.hash_children(left, right)?;
                        Ok::<_, Error>(())
                    })?,
                }
                lap!(timer, "Compute: {middle} -> {end}");
            }

            // Update the start index for the next level.
            start_index = start;
            // Update the middle index for the next level.
            middle_index = middle;
        }

        // End the timer for the while loop.
        finish!(timer);

        Ok(())
    }
818
819    /// Save the previous tree in order to reuse its allocation later on.
820    pub fn preserve_tree_allocation(&self, previous: &mut Self) {
821        *self.preserved_tree_allocation.lock() = Some(mem::take(&mut previous.tree));
822    }
823}
824
825/// Returns the depth of the tree, given the size of the tree.
826#[inline]
827fn tree_depth<const DEPTH: u8>(tree_size: usize) -> Result<u8> {
828    let tree_size = u64::try_from(tree_size)?;
829    // Since we only allow tree sizes up to u64::MAX, the maximum possible depth is 63.
830    let tree_depth = u8::try_from(tree_size.checked_ilog2().unwrap_or(0))?;
831    // Ensure the tree depth is within the depth bound.
832    match tree_depth <= DEPTH {
833        // Return the tree depth.
834        true => Ok(tree_depth),
835        false => bail!("Merkle tree cannot exceed depth {DEPTH}: attempted to reach depth {tree_depth}"),
836    }
837}
838
/// Maps a node index to the index of its left child in the heap-ordered node array.
#[inline]
const fn left_child(index: usize) -> usize {
    index * 2 + 1
}
844
/// Maps a node index to the index of its right child in the heap-ordered node array.
#[inline]
const fn right_child(index: usize) -> usize {
    (index + 1) * 2
}
850
/// Returns the index of the node that shares a parent with `index`, or `None` for the root (index `0`).
///
/// Left children have odd indices and their sibling is the next index. Right children have even
/// indices and their sibling is the previous one.
#[inline]
const fn sibling(index: usize) -> Option<usize> {
    match index {
        0 => None,
        odd if odd % 2 == 1 => Some(odd + 1),
        even => Some(even - 1),
    }
}
862
/// Reports whether `index` is the root node, which always sits at index `0`.
#[inline]
const fn is_root(index: usize) -> bool {
    matches!(index, 0)
}
868
/// Reports whether `index` is a left child; in the heap layout, left children have odd indices.
#[inline]
const fn is_left_child(index: usize) -> bool {
    (index & 1) == 1
}
874
/// Returns the index of the parent of `index`, or `None` when `index` is the root (index `0`).
#[inline]
const fn parent(index: usize) -> Option<usize> {
    match index.checked_sub(1) {
        Some(previous) => Some(previous / 2),
        None => None,
    }
}