Skip to main content

stealth_lib/merkle/
tree.rs

//! Merkle tree data structure.
//!
//! A fixed-depth, append-only (incremental) Merkle tree using the MiMC hash,
//! designed for zero-knowledge proof applications.
5
6use crate::error::{Error, Result};
7use crate::hash::MimcHasher;
8use crate::merkle::proof::MerkleProof;
9use crate::merkle::ROOT_HISTORY_SIZE;
10
11#[cfg(feature = "std")]
12use std::collections::HashMap;
13
14#[cfg(not(feature = "std"))]
15extern crate alloc;
16#[cfg(not(feature = "std"))]
17use alloc::collections::BTreeMap as HashMap;
18#[cfg(not(feature = "std"))]
19use alloc::vec::Vec;
20
21/// A Merkle tree with MiMC hash function.
22///
23/// This implementation is optimized for ZK-circuit compatibility and includes
24/// features like root history for handling concurrent on-chain insertions.
25///
26/// # Example
27///
28/// ```
29/// use stealth_lib::MerkleTree;
30///
31/// // Create a new tree with 20 levels
32/// let mut tree = MerkleTree::new(20).unwrap();
33///
34/// // Insert leaves
35/// let index = tree.insert(12345).unwrap();
36/// assert_eq!(index, 0);
37///
38/// // Get the current root
39/// let root = tree.root().unwrap();
40/// println!("Root: {}", root);
41/// ```
42///
43/// # Capacity
44///
45/// A tree with `n` levels can hold `2^n` leaves. The maximum supported
46/// depth is 255 levels, though practical trees typically use 20-32 levels.
47#[derive(Debug, Clone)]
48pub struct MerkleTree {
49    /// Number of levels in the tree (excluding root).
50    levels: u8,
51    /// Pre-computed subtree hashes for empty positions.
52    filled_subtrees: HashMap<u8, u128>,
53    /// Circular buffer of recent root hashes.
54    roots: HashMap<u8, u128>,
55    /// Index into the roots circular buffer.
56    current_root_index: u8,
57    /// Index for the next leaf to be inserted.
58    next_index: u32,
59    /// Hash function used for the tree.
60    hasher: MimcHasher,
61    /// Leaves inserted into the tree (for proof generation).
62    leaves: Vec<u128>,
63}
64
65impl MerkleTree {
66    /// Creates a new empty Merkle tree with the specified number of levels.
67    ///
68    /// # Arguments
69    ///
70    /// * `levels` - The depth of the tree. The tree can hold `2^levels` leaves.
71    ///
72    /// # Returns
73    ///
74    /// A new `MerkleTree` or an error if the configuration is invalid.
75    ///
76    /// # Errors
77    ///
78    /// Returns [`Error::InvalidTreeConfig`] if `levels` is 0 or greater than 32.
79    ///
80    /// # Example
81    ///
82    /// ```
83    /// use stealth_lib::MerkleTree;
84    ///
85    /// let tree = MerkleTree::new(20).unwrap();
86    /// assert_eq!(tree.levels(), 20);
87    /// assert_eq!(tree.capacity(), 1 << 20);
88    /// ```
89    pub fn new(levels: u8) -> Result<Self> {
90        if levels == 0 {
91            return Err(Error::InvalidTreeConfig(
92                "Tree must have at least 1 level".to_string(),
93            ));
94        }
95        if levels > 32 {
96            return Err(Error::InvalidTreeConfig(
97                "Tree depth cannot exceed 32 levels".to_string(),
98            ));
99        }
100
101        let hasher = MimcHasher::default();
102        let mut instance = MerkleTree {
103            levels,
104            filled_subtrees: HashMap::new(),
105            roots: HashMap::new(),
106            current_root_index: 0,
107            next_index: 0,
108            hasher,
109            leaves: Vec::new(),
110        };
111
112        // Initialize filled_subtrees with zero hashes
113        for i in 0..levels {
114            instance.filled_subtrees.insert(i, instance.zeros(i));
115        }
116
117        // Initialize root with the empty tree root
118        instance.roots.insert(0, instance.zeros(levels - 1));
119
120        Ok(instance)
121    }
122
123    /// Creates a new Merkle tree with a custom hasher.
124    ///
125    /// # Arguments
126    ///
127    /// * `levels` - The depth of the tree
128    /// * `hasher` - Custom MiMC hasher configuration
129    ///
130    /// # Example
131    ///
132    /// ```
133    /// use stealth_lib::{MerkleTree, hash::MimcHasher};
134    ///
135    /// let hasher = MimcHasher::default();
136    /// let tree = MerkleTree::with_hasher(20, hasher).unwrap();
137    /// ```
138    pub fn with_hasher(levels: u8, hasher: MimcHasher) -> Result<Self> {
139        if levels == 0 {
140            return Err(Error::InvalidTreeConfig(
141                "Tree must have at least 1 level".to_string(),
142            ));
143        }
144        if levels > 32 {
145            return Err(Error::InvalidTreeConfig(
146                "Tree depth cannot exceed 32 levels".to_string(),
147            ));
148        }
149
150        let mut instance = MerkleTree {
151            levels,
152            filled_subtrees: HashMap::new(),
153            roots: HashMap::new(),
154            current_root_index: 0,
155            next_index: 0,
156            hasher,
157            leaves: Vec::new(),
158        };
159
160        for i in 0..levels {
161            instance.filled_subtrees.insert(i, instance.zeros(i));
162        }
163
164        instance.roots.insert(0, instance.zeros(levels - 1));
165
166        Ok(instance)
167    }
168
169    /// Returns the number of levels in the tree.
170    #[inline]
171    pub fn levels(&self) -> u8 {
172        self.levels
173    }
174
175    /// Returns the maximum capacity of the tree.
176    ///
177    /// This is `2^levels`.
178    #[inline]
179    pub fn capacity(&self) -> usize {
180        1usize << self.levels
181    }
182
183    /// Returns the current number of leaves in the tree.
184    #[inline]
185    pub fn len(&self) -> u32 {
186        self.next_index
187    }
188
189    /// Returns true if the tree is empty.
190    #[inline]
191    pub fn is_empty(&self) -> bool {
192        self.next_index == 0
193    }
194
195    /// Returns a reference to the hasher used by this tree.
196    #[inline]
197    pub fn hasher(&self) -> &MimcHasher {
198        &self.hasher
199    }
200
201    /// Returns the current root hash of the tree.
202    ///
203    /// Returns `None` only if the tree is in an invalid state (should not happen
204    /// under normal usage).
205    ///
206    /// # Example
207    ///
208    /// ```
209    /// use stealth_lib::MerkleTree;
210    ///
211    /// let tree = MerkleTree::new(20).unwrap();
212    /// let root = tree.root().unwrap();
213    /// println!("Empty tree root: {}", root);
214    /// ```
215    pub fn root(&self) -> Option<u128> {
216        self.roots.get(&self.current_root_index).copied()
217    }
218
219    /// Hashes two child nodes to produce a parent node.
220    ///
221    /// Uses the MiMC sponge construction for ZK-circuit compatibility.
222    fn hash_left_right(&self, left: u128, right: u128) -> u128 {
223        let field_size = self.hasher.field_prime();
224        let c = 0_u128;
225
226        let mut r = left;
227        r = self.hasher.mimc_sponge(r, c, field_size);
228        r = r.wrapping_add(right).wrapping_rem(field_size);
229        r = self.hasher.mimc_sponge(r, c, field_size);
230
231        r
232    }
233
234    /// Inserts a new leaf into the tree.
235    ///
236    /// # Arguments
237    ///
238    /// * `leaf` - The leaf value to insert
239    ///
240    /// # Returns
241    ///
242    /// The index of the inserted leaf, or an error if the tree is full.
243    ///
244    /// # Errors
245    ///
246    /// Returns [`Error::TreeFull`] if the tree has reached its maximum capacity.
247    ///
248    /// # Example
249    ///
250    /// ```
251    /// use stealth_lib::MerkleTree;
252    ///
253    /// let mut tree = MerkleTree::new(20).unwrap();
254    /// let index = tree.insert(12345).unwrap();
255    /// assert_eq!(index, 0);
256    ///
257    /// let index = tree.insert(67890).unwrap();
258    /// assert_eq!(index, 1);
259    /// ```
260    pub fn insert(&mut self, leaf: u128) -> Result<u32> {
261        let capacity = self.capacity();
262        if (self.next_index as usize) >= capacity {
263            return Err(Error::TreeFull {
264                capacity,
265                attempted_index: self.next_index as usize,
266            });
267        }
268
269        let inserted_index = self.next_index;
270        let mut current_index = self.next_index;
271        let mut current_level_hash = leaf;
272
273        // Store the leaf for proof generation
274        self.leaves.push(leaf);
275
276        // Update the tree path from leaf to root
277        for i in 0..self.levels {
278            let (left, right) = if current_index % 2 == 0 {
279                // This is a left child
280                self.filled_subtrees.insert(i, current_level_hash);
281                (current_level_hash, self.zeros(i))
282            } else {
283                // This is a right child
284                let left = self
285                    .filled_subtrees
286                    .get(&i)
287                    .copied()
288                    .unwrap_or_else(|| self.zeros(i));
289                (left, current_level_hash)
290            };
291
292            current_level_hash = self.hash_left_right(left, right);
293            current_index /= 2;
294        }
295
296        // Update root history
297        let new_root_index = (self.current_root_index + 1) % ROOT_HISTORY_SIZE;
298        self.current_root_index = new_root_index;
299        self.roots.insert(new_root_index, current_level_hash);
300        self.next_index = inserted_index + 1;
301
302        Ok(inserted_index)
303    }
304
305    /// Checks if a root hash is in the recent root history.
306    ///
307    /// The tree maintains a circular buffer of recent roots to handle
308    /// concurrent insertions in on-chain applications.
309    ///
310    /// # Arguments
311    ///
312    /// * `root` - The root hash to check
313    ///
314    /// # Returns
315    ///
316    /// `true` if the root is in the history, `false` otherwise.
317    ///
318    /// # Example
319    ///
320    /// ```
321    /// use stealth_lib::MerkleTree;
322    ///
323    /// let mut tree = MerkleTree::new(20).unwrap();
324    /// let root_before = tree.root().unwrap();
325    /// tree.insert(12345).unwrap();
326    /// let root_after = tree.root().unwrap();
327    ///
328    /// // Both roots are in history
329    /// assert!(tree.is_known_root(root_before));
330    /// assert!(tree.is_known_root(root_after));
331    ///
332    /// // Random value is not
333    /// assert!(!tree.is_known_root(99999));
334    /// ```
335    pub fn is_known_root(&self, root: u128) -> bool {
336        if root == 0 {
337            return false;
338        }
339
340        let mut i = self.current_root_index;
341        loop {
342            if let Some(&stored_root) = self.roots.get(&i) {
343                if stored_root == root {
344                    return true;
345                }
346            }
347
348            i = if i == 0 {
349                ROOT_HISTORY_SIZE - 1
350            } else {
351                i - 1
352            };
353
354            if i == self.current_root_index {
355                break;
356            }
357        }
358
359        false
360    }
361
362    /// Returns the last (current) root hash.
363    ///
364    /// # Panics
365    ///
366    /// Panics if the tree is in an invalid state (should not happen under normal usage).
367    /// Prefer using [`root`](Self::root) for fallible access.
368    #[deprecated(since = "1.0.0", note = "Use root() instead")]
369    pub fn get_last_root(&self) -> u128 {
370        self.root().expect("Tree in invalid state: no root")
371    }
372
373    /// Computes the zero hash at a given level.
374    ///
375    /// Zero hashes represent empty subtrees at each level.
376    /// This uses the same formula as the original Tornado Cash implementation:
377    /// `zeros(0) = 0`, `zeros(i) = mimc_sponge(zeros(i-1), 0, p)`.
378    ///
379    /// Note: This is NOT the same as `hash_left_right(zeros(i-1), zeros(i-1))`.
380    /// The formula is chosen for compatibility with existing ZK circuits.
381    pub fn zeros(&self, level: u8) -> u128 {
382        let mut result = 0u128;
383        for _ in 0..level {
384            result = self.hasher.mimc_sponge(result, 0, self.hasher.field_prime());
385        }
386        result
387    }
388
389    /// Generates a Merkle proof for the leaf at the given index.
390    ///
391    /// # Arguments
392    ///
393    /// * `leaf_index` - The index of the leaf to prove
394    ///
395    /// # Returns
396    ///
397    /// A [`MerkleProof`] that can be used to verify inclusion.
398    ///
399    /// # Errors
400    ///
401    /// Returns [`Error::LeafIndexOutOfBounds`] if the index is invalid.
402    ///
403    /// # Example
404    ///
405    /// ```
406    /// use stealth_lib::MerkleTree;
407    ///
408    /// let mut tree = MerkleTree::new(20).unwrap();
409    /// tree.insert(12345).unwrap();
410    /// tree.insert(67890).unwrap();
411    ///
412    /// let proof = tree.prove(0).unwrap();
413    /// let root = tree.root().unwrap();
414    /// assert!(proof.verify(root, &tree.hasher()));
415    /// ```
416    pub fn prove(&self, leaf_index: u32) -> Result<MerkleProof> {
417        if leaf_index >= self.next_index {
418            return Err(Error::LeafIndexOutOfBounds {
419                index: leaf_index,
420                tree_size: self.next_index,
421            });
422        }
423
424        let leaf = self.leaves[leaf_index as usize];
425        let mut path = Vec::with_capacity(self.levels as usize);
426        let mut indices = Vec::with_capacity(self.levels as usize);
427        let mut current_index = leaf_index;
428
429        for level in 0..self.levels {
430            let is_right = current_index % 2 == 1;
431            indices.push(is_right);
432
433            // Get sibling
434            let sibling_index = if is_right {
435                current_index - 1
436            } else {
437                current_index + 1
438            };
439
440            let sibling = self.get_node_at(level, sibling_index);
441            path.push(sibling);
442
443            current_index /= 2;
444        }
445
446        Ok(MerkleProof {
447            leaf,
448            leaf_index,
449            path,
450            indices,
451        })
452    }
453
454    /// Gets the hash value of a node at a specific level and index.
455    ///
456    /// For levels below the current tree depth, this reconstructs the hash.
457    /// Empty positions return the zero hash for that level.
458    fn get_node_at(&self, level: u8, index: u32) -> u128 {
459        if level == 0 {
460            // Leaf level
461            if (index as usize) < self.leaves.len() {
462                return self.leaves[index as usize];
463            } else {
464                return 0; // zeros(0) = 0
465            }
466        }
467
468        // Check if this subtree is completely empty
469        // A subtree at (level, index) covers leaf indices from 
470        // index * 2^level to (index+1) * 2^level - 1
471        let leaves_per_subtree = 1u32 << level;
472        let subtree_start = index * leaves_per_subtree;
473        
474        // If all leaves in this subtree would be beyond our current tree size,
475        // return the precomputed zero value
476        if subtree_start >= self.next_index {
477            return self.zeros(level);
478        }
479
480        // Otherwise compute by combining children
481        let left_index = index * 2;
482        let right_index = left_index + 1;
483
484        let left = self.get_node_at(level - 1, left_index);
485        let right = self.get_node_at(level - 1, right_index);
486
487        self.hash_left_right(left, right)
488    }
489}
490
#[cfg(feature = "borsh")]
mod borsh_impl {
    // Placeholder: Borsh (de)serialization for `MerkleTree` is NOT implemented
    // yet. Enabling the `borsh` feature currently compiles this empty module and
    // adds no API.
}
496
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_new_tree() {
        let tree = MerkleTree::new(20).unwrap();
        assert_eq!(tree.levels(), 20);
        assert_eq!(tree.capacity(), 1 << 20);
        assert_eq!(tree.len(), 0);
        assert!(tree.is_empty());
    }

    #[test]
    fn test_new_tree_invalid_levels() {
        assert!(MerkleTree::new(0).is_err());
        assert!(MerkleTree::new(33).is_err());
    }

    #[test]
    fn test_with_hasher_validates_and_matches_new() {
        assert!(MerkleTree::with_hasher(0, MimcHasher::default()).is_err());
        assert!(MerkleTree::with_hasher(33, MimcHasher::default()).is_err());

        let mut a = MerkleTree::new(8).unwrap();
        let mut b = MerkleTree::with_hasher(8, MimcHasher::default()).unwrap();
        assert_eq!(a.root(), b.root());

        a.insert(42).unwrap();
        b.insert(42).unwrap();
        assert_eq!(a.root(), b.root());
    }

    #[test]
    #[cfg(target_pointer_width = "64")]
    fn test_max_depth_capacity() {
        let tree = MerkleTree::new(32).unwrap();
        assert_eq!(tree.capacity(), 1usize << 32);
    }

    #[test]
    fn test_insert_single() {
        let mut tree = MerkleTree::new(20).unwrap();
        let index = tree.insert(12345).unwrap();
        assert_eq!(index, 0);
        assert_eq!(tree.len(), 1);
        assert!(!tree.is_empty());
    }

    #[test]
    fn test_insert_multiple() {
        let mut tree = MerkleTree::new(20).unwrap();
        for i in 0..10 {
            let index = tree.insert(i as u128).unwrap();
            assert_eq!(index, i);
        }
        assert_eq!(tree.len(), 10);
    }

    #[test]
    fn test_tree_full() {
        let mut tree = MerkleTree::new(2).unwrap(); // Can hold 4 leaves
        for i in 0..4 {
            tree.insert(i as u128).unwrap();
        }
        let result = tree.insert(100);
        assert!(matches!(result, Err(Error::TreeFull { .. })));
    }

    #[test]
    fn test_root_changes_on_insert() {
        let mut tree = MerkleTree::new(20).unwrap();
        let root1 = tree.root().unwrap();
        tree.insert(12345).unwrap();
        let root2 = tree.root().unwrap();
        assert_ne!(root1, root2);
    }

    #[test]
    fn test_is_known_root() {
        let mut tree = MerkleTree::new(20).unwrap();
        let root1 = tree.root().unwrap();
        tree.insert(12345).unwrap();
        let root2 = tree.root().unwrap();

        assert!(tree.is_known_root(root1));
        assert!(tree.is_known_root(root2));
        assert!(!tree.is_known_root(99999));
        assert!(!tree.is_known_root(0));
    }

    #[test]
    fn test_root_history_evicts_oldest_root() {
        let mut tree = MerkleTree::new(10).unwrap();
        let initial_root = tree.root().unwrap();
        let history = u32::from(ROOT_HISTORY_SIZE);

        // After `history - 1` inserts, the initial root still occupies a slot.
        for i in 1..history {
            tree.insert(u128::from(i)).unwrap();
        }
        assert!(tree.is_known_root(initial_root));

        // One more insert wraps the circular buffer onto the initial root's slot.
        tree.insert(u128::from(history)).unwrap();
        assert!(!tree.is_known_root(initial_root));
        assert!(tree.is_known_root(tree.root().unwrap()));
    }

    #[test]
    fn test_zeros_computation() {
        let tree = MerkleTree::new(10).unwrap();
        let zero0 = tree.zeros(0);
        let zero1 = tree.zeros(1);
        assert_eq!(zero0, 0);
        assert_ne!(zero1, 0);
    }

    #[test]
    fn test_zeros_recurrence_and_initial_root() {
        let tree = MerkleTree::new(6).unwrap();
        let hasher = tree.hasher();
        let p = hasher.field_prime();
        for level in 0..6u8 {
            assert_eq!(
                tree.zeros(level + 1),
                hasher.mimc_sponge(tree.zeros(level), 0, p)
            );
        }
        // The empty tree's root is recorded as zeros(levels - 1).
        assert_eq!(tree.root(), Some(tree.zeros(5)));
    }

    #[test]
    fn test_deterministic_roots() {
        let mut tree1 = MerkleTree::new(10).unwrap();
        let mut tree2 = MerkleTree::new(10).unwrap();

        tree1.insert(123).unwrap();
        tree1.insert(456).unwrap();

        tree2.insert(123).unwrap();
        tree2.insert(456).unwrap();

        assert_eq!(tree1.root(), tree2.root());
    }

    #[test]
    fn test_prove_valid_index() {
        let mut tree = MerkleTree::new(10).unwrap();
        tree.insert(12345).unwrap();
        tree.insert(67890).unwrap();

        let proof = tree.prove(0).unwrap();
        assert_eq!(proof.leaf, 12345);
        assert_eq!(proof.leaf_index, 0);
        assert_eq!(proof.path.len(), 10);
    }

    #[test]
    fn test_prove_invalid_index() {
        let mut tree = MerkleTree::new(10).unwrap();
        tree.insert(12345).unwrap();

        let result = tree.prove(1);
        assert!(matches!(result, Err(Error::LeafIndexOutOfBounds { .. })));
    }

    #[test]
    fn test_proof_verifies() {
        let mut tree = MerkleTree::new(10).unwrap();
        tree.insert(12345).unwrap();
        tree.insert(67890).unwrap();
        tree.insert(11111).unwrap();

        let root = tree.root().unwrap();

        for i in 0..3 {
            let proof = tree.prove(i).unwrap();
            assert!(proof.verify(root, tree.hasher()), "Proof failed for leaf {}", i);
        }
    }

    #[test]
    fn test_proofs_verify_for_partially_filled_tree() {
        let mut tree = MerkleTree::new(4).unwrap();
        for leaf in 1..=5u128 {
            tree.insert(leaf * 1000).unwrap();
        }

        let root = tree.root().unwrap();

        for index in 0..tree.len() {
            let proof = tree.prove(index).unwrap();
            assert!(proof.verify(root, tree.hasher()), "Proof failed for leaf {}", index);
            assert!(
                !proof.verify(root.wrapping_add(1), tree.hasher()),
                "Proof for leaf {} accepted a wrong root",
                index
            );
        }
    }
}