stealth_lib/merkle/tree.rs
1//! Merkle tree data structure.
2//!
3//! A sparse Merkle tree implementation using MiMC hash, designed for
4//! zero-knowledge proof applications.
5
6use crate::error::{Error, Result};
7use crate::hash::MimcHasher;
8use crate::merkle::proof::MerkleProof;
9use crate::merkle::ROOT_HISTORY_SIZE;
10
11#[cfg(feature = "std")]
12use std::collections::HashMap;
13
14#[cfg(not(feature = "std"))]
15extern crate alloc;
16#[cfg(not(feature = "std"))]
17use alloc::collections::BTreeMap as HashMap;
18#[cfg(not(feature = "std"))]
19use alloc::vec::Vec;
20
21/// A Merkle tree with MiMC hash function.
22///
23/// This implementation is optimized for ZK-circuit compatibility and includes
24/// features like root history for handling concurrent on-chain insertions.
25///
26/// # Example
27///
28/// ```
29/// use stealth_lib::MerkleTree;
30///
31/// // Create a new tree with 20 levels
32/// let mut tree = MerkleTree::new(20).unwrap();
33///
34/// // Insert leaves
35/// let index = tree.insert(12345).unwrap();
36/// assert_eq!(index, 0);
37///
38/// // Get the current root
39/// let root = tree.root().unwrap();
40/// println!("Root: {}", root);
41/// ```
42///
43/// # Capacity
44///
45/// A tree with `n` levels can hold `2^n` leaves. The maximum supported
46/// depth is 255 levels, though practical trees typically use 20-32 levels.
47#[derive(Debug, Clone)]
48pub struct MerkleTree {
49 /// Number of levels in the tree (excluding root).
50 levels: u8,
51 /// Pre-computed subtree hashes for empty positions.
52 filled_subtrees: HashMap<u8, u128>,
53 /// Circular buffer of recent root hashes.
54 roots: HashMap<u8, u128>,
55 /// Index into the roots circular buffer.
56 current_root_index: u8,
57 /// Index for the next leaf to be inserted.
58 next_index: u32,
59 /// Hash function used for the tree.
60 hasher: MimcHasher,
61 /// Leaves inserted into the tree (for proof generation).
62 leaves: Vec<u128>,
63}
64
65impl MerkleTree {
66 /// Creates a new empty Merkle tree with the specified number of levels.
67 ///
68 /// # Arguments
69 ///
70 /// * `levels` - The depth of the tree. The tree can hold `2^levels` leaves.
71 ///
72 /// # Returns
73 ///
74 /// A new `MerkleTree` or an error if the configuration is invalid.
75 ///
76 /// # Errors
77 ///
78 /// Returns [`Error::InvalidTreeConfig`] if `levels` is 0 or greater than 32.
79 ///
80 /// # Example
81 ///
82 /// ```
83 /// use stealth_lib::MerkleTree;
84 ///
85 /// let tree = MerkleTree::new(20).unwrap();
86 /// assert_eq!(tree.levels(), 20);
87 /// assert_eq!(tree.capacity(), 1 << 20);
88 /// ```
89 pub fn new(levels: u8) -> Result<Self> {
90 if levels == 0 {
91 return Err(Error::InvalidTreeConfig(
92 "Tree must have at least 1 level".to_string(),
93 ));
94 }
95 if levels > 32 {
96 return Err(Error::InvalidTreeConfig(
97 "Tree depth cannot exceed 32 levels".to_string(),
98 ));
99 }
100
101 let hasher = MimcHasher::default();
102 let mut instance = MerkleTree {
103 levels,
104 filled_subtrees: HashMap::new(),
105 roots: HashMap::new(),
106 current_root_index: 0,
107 next_index: 0,
108 hasher,
109 leaves: Vec::new(),
110 };
111
112 // Initialize filled_subtrees with zero hashes
113 for i in 0..levels {
114 instance.filled_subtrees.insert(i, instance.zeros(i));
115 }
116
117 // Initialize root with the empty tree root
118 instance.roots.insert(0, instance.zeros(levels - 1));
119
120 Ok(instance)
121 }
122
123 /// Creates a new Merkle tree with a custom hasher.
124 ///
125 /// # Arguments
126 ///
127 /// * `levels` - The depth of the tree
128 /// * `hasher` - Custom MiMC hasher configuration
129 ///
130 /// # Example
131 ///
132 /// ```
133 /// use stealth_lib::{MerkleTree, hash::MimcHasher};
134 ///
135 /// let hasher = MimcHasher::default();
136 /// let tree = MerkleTree::with_hasher(20, hasher).unwrap();
137 /// ```
138 pub fn with_hasher(levels: u8, hasher: MimcHasher) -> Result<Self> {
139 if levels == 0 {
140 return Err(Error::InvalidTreeConfig(
141 "Tree must have at least 1 level".to_string(),
142 ));
143 }
144 if levels > 32 {
145 return Err(Error::InvalidTreeConfig(
146 "Tree depth cannot exceed 32 levels".to_string(),
147 ));
148 }
149
150 let mut instance = MerkleTree {
151 levels,
152 filled_subtrees: HashMap::new(),
153 roots: HashMap::new(),
154 current_root_index: 0,
155 next_index: 0,
156 hasher,
157 leaves: Vec::new(),
158 };
159
160 for i in 0..levels {
161 instance.filled_subtrees.insert(i, instance.zeros(i));
162 }
163
164 instance.roots.insert(0, instance.zeros(levels - 1));
165
166 Ok(instance)
167 }
168
169 /// Returns the number of levels in the tree.
170 #[inline]
171 pub fn levels(&self) -> u8 {
172 self.levels
173 }
174
175 /// Returns the maximum capacity of the tree.
176 ///
177 /// This is `2^levels`.
178 #[inline]
179 pub fn capacity(&self) -> usize {
180 1usize << self.levels
181 }
182
183 /// Returns the current number of leaves in the tree.
184 #[inline]
185 pub fn len(&self) -> u32 {
186 self.next_index
187 }
188
189 /// Returns true if the tree is empty.
190 #[inline]
191 pub fn is_empty(&self) -> bool {
192 self.next_index == 0
193 }
194
195 /// Returns a reference to the hasher used by this tree.
196 #[inline]
197 pub fn hasher(&self) -> &MimcHasher {
198 &self.hasher
199 }
200
201 /// Returns the current root hash of the tree.
202 ///
203 /// Returns `None` only if the tree is in an invalid state (should not happen
204 /// under normal usage).
205 ///
206 /// # Example
207 ///
208 /// ```
209 /// use stealth_lib::MerkleTree;
210 ///
211 /// let tree = MerkleTree::new(20).unwrap();
212 /// let root = tree.root().unwrap();
213 /// println!("Empty tree root: {}", root);
214 /// ```
215 pub fn root(&self) -> Option<u128> {
216 self.roots.get(&self.current_root_index).copied()
217 }
218
219 /// Hashes two child nodes to produce a parent node.
220 ///
221 /// Uses the MiMC sponge construction for ZK-circuit compatibility.
222 fn hash_left_right(&self, left: u128, right: u128) -> u128 {
223 let field_size = self.hasher.field_prime();
224 let c = 0_u128;
225
226 let mut r = left;
227 r = self.hasher.mimc_sponge(r, c, field_size);
228 r = r.wrapping_add(right).wrapping_rem(field_size);
229 r = self.hasher.mimc_sponge(r, c, field_size);
230
231 r
232 }
233
234 /// Inserts a new leaf into the tree.
235 ///
236 /// # Arguments
237 ///
238 /// * `leaf` - The leaf value to insert
239 ///
240 /// # Returns
241 ///
242 /// The index of the inserted leaf, or an error if the tree is full.
243 ///
244 /// # Errors
245 ///
246 /// Returns [`Error::TreeFull`] if the tree has reached its maximum capacity.
247 ///
248 /// # Example
249 ///
250 /// ```
251 /// use stealth_lib::MerkleTree;
252 ///
253 /// let mut tree = MerkleTree::new(20).unwrap();
254 /// let index = tree.insert(12345).unwrap();
255 /// assert_eq!(index, 0);
256 ///
257 /// let index = tree.insert(67890).unwrap();
258 /// assert_eq!(index, 1);
259 /// ```
260 pub fn insert(&mut self, leaf: u128) -> Result<u32> {
261 let capacity = self.capacity();
262 if (self.next_index as usize) >= capacity {
263 return Err(Error::TreeFull {
264 capacity,
265 attempted_index: self.next_index as usize,
266 });
267 }
268
269 let inserted_index = self.next_index;
270 let mut current_index = self.next_index;
271 let mut current_level_hash = leaf;
272
273 // Store the leaf for proof generation
274 self.leaves.push(leaf);
275
276 // Update the tree path from leaf to root
277 for i in 0..self.levels {
278 let (left, right) = if current_index % 2 == 0 {
279 // This is a left child
280 self.filled_subtrees.insert(i, current_level_hash);
281 (current_level_hash, self.zeros(i))
282 } else {
283 // This is a right child
284 let left = self
285 .filled_subtrees
286 .get(&i)
287 .copied()
288 .unwrap_or_else(|| self.zeros(i));
289 (left, current_level_hash)
290 };
291
292 current_level_hash = self.hash_left_right(left, right);
293 current_index /= 2;
294 }
295
296 // Update root history
297 let new_root_index = (self.current_root_index + 1) % ROOT_HISTORY_SIZE;
298 self.current_root_index = new_root_index;
299 self.roots.insert(new_root_index, current_level_hash);
300 self.next_index = inserted_index + 1;
301
302 Ok(inserted_index)
303 }
304
305 /// Checks if a root hash is in the recent root history.
306 ///
307 /// The tree maintains a circular buffer of recent roots to handle
308 /// concurrent insertions in on-chain applications.
309 ///
310 /// # Arguments
311 ///
312 /// * `root` - The root hash to check
313 ///
314 /// # Returns
315 ///
316 /// `true` if the root is in the history, `false` otherwise.
317 ///
318 /// # Example
319 ///
320 /// ```
321 /// use stealth_lib::MerkleTree;
322 ///
323 /// let mut tree = MerkleTree::new(20).unwrap();
324 /// let root_before = tree.root().unwrap();
325 /// tree.insert(12345).unwrap();
326 /// let root_after = tree.root().unwrap();
327 ///
328 /// // Both roots are in history
329 /// assert!(tree.is_known_root(root_before));
330 /// assert!(tree.is_known_root(root_after));
331 ///
332 /// // Random value is not
333 /// assert!(!tree.is_known_root(99999));
334 /// ```
335 pub fn is_known_root(&self, root: u128) -> bool {
336 if root == 0 {
337 return false;
338 }
339
340 let mut i = self.current_root_index;
341 loop {
342 if let Some(&stored_root) = self.roots.get(&i) {
343 if stored_root == root {
344 return true;
345 }
346 }
347
348 i = if i == 0 {
349 ROOT_HISTORY_SIZE - 1
350 } else {
351 i - 1
352 };
353
354 if i == self.current_root_index {
355 break;
356 }
357 }
358
359 false
360 }
361
362 /// Returns the last (current) root hash.
363 ///
364 /// # Panics
365 ///
366 /// Panics if the tree is in an invalid state (should not happen under normal usage).
367 /// Prefer using [`root`](Self::root) for fallible access.
368 #[deprecated(since = "1.0.0", note = "Use root() instead")]
369 pub fn get_last_root(&self) -> u128 {
370 self.root().expect("Tree in invalid state: no root")
371 }
372
373 /// Computes the zero hash at a given level.
374 ///
375 /// Zero hashes represent empty subtrees at each level.
376 /// This uses the same formula as the original Tornado Cash implementation:
377 /// `zeros(0) = 0`, `zeros(i) = mimc_sponge(zeros(i-1), 0, p)`.
378 ///
379 /// Note: This is NOT the same as `hash_left_right(zeros(i-1), zeros(i-1))`.
380 /// The formula is chosen for compatibility with existing ZK circuits.
381 pub fn zeros(&self, level: u8) -> u128 {
382 let mut result = 0u128;
383 for _ in 0..level {
384 result = self.hasher.mimc_sponge(result, 0, self.hasher.field_prime());
385 }
386 result
387 }
388
389 /// Generates a Merkle proof for the leaf at the given index.
390 ///
391 /// # Arguments
392 ///
393 /// * `leaf_index` - The index of the leaf to prove
394 ///
395 /// # Returns
396 ///
397 /// A [`MerkleProof`] that can be used to verify inclusion.
398 ///
399 /// # Errors
400 ///
401 /// Returns [`Error::LeafIndexOutOfBounds`] if the index is invalid.
402 ///
403 /// # Example
404 ///
405 /// ```
406 /// use stealth_lib::MerkleTree;
407 ///
408 /// let mut tree = MerkleTree::new(20).unwrap();
409 /// tree.insert(12345).unwrap();
410 /// tree.insert(67890).unwrap();
411 ///
412 /// let proof = tree.prove(0).unwrap();
413 /// let root = tree.root().unwrap();
414 /// assert!(proof.verify(root, &tree.hasher()));
415 /// ```
416 pub fn prove(&self, leaf_index: u32) -> Result<MerkleProof> {
417 if leaf_index >= self.next_index {
418 return Err(Error::LeafIndexOutOfBounds {
419 index: leaf_index,
420 tree_size: self.next_index,
421 });
422 }
423
424 let leaf = self.leaves[leaf_index as usize];
425 let mut path = Vec::with_capacity(self.levels as usize);
426 let mut indices = Vec::with_capacity(self.levels as usize);
427 let mut current_index = leaf_index;
428
429 for level in 0..self.levels {
430 let is_right = current_index % 2 == 1;
431 indices.push(is_right);
432
433 // Get sibling
434 let sibling_index = if is_right {
435 current_index - 1
436 } else {
437 current_index + 1
438 };
439
440 let sibling = self.get_node_at(level, sibling_index);
441 path.push(sibling);
442
443 current_index /= 2;
444 }
445
446 Ok(MerkleProof {
447 leaf,
448 leaf_index,
449 path,
450 indices,
451 })
452 }
453
454 /// Gets the hash value of a node at a specific level and index.
455 ///
456 /// For levels below the current tree depth, this reconstructs the hash.
457 /// Empty positions return the zero hash for that level.
458 fn get_node_at(&self, level: u8, index: u32) -> u128 {
459 if level == 0 {
460 // Leaf level
461 if (index as usize) < self.leaves.len() {
462 return self.leaves[index as usize];
463 } else {
464 return 0; // zeros(0) = 0
465 }
466 }
467
468 // Check if this subtree is completely empty
469 // A subtree at (level, index) covers leaf indices from
470 // index * 2^level to (index+1) * 2^level - 1
471 let leaves_per_subtree = 1u32 << level;
472 let subtree_start = index * leaves_per_subtree;
473
474 // If all leaves in this subtree would be beyond our current tree size,
475 // return the precomputed zero value
476 if subtree_start >= self.next_index {
477 return self.zeros(level);
478 }
479
480 // Otherwise compute by combining children
481 let left_index = index * 2;
482 let right_index = left_index + 1;
483
484 let left = self.get_node_at(level - 1, left_index);
485 let right = self.get_node_at(level - 1, right_index);
486
487 self.hash_left_right(left, right)
488 }
489}
490
#[cfg(feature = "borsh")]
mod borsh_impl {
    // Placeholder only: enabling the `borsh` feature currently adds NO
    // serialization impls for `MerkleTree`.
    // TODO: implement `BorshSerialize` / `BorshDeserialize` here.
}
496
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_new_tree() {
        let tree = MerkleTree::new(20).unwrap();
        assert_eq!(tree.levels(), 20);
        assert_eq!(tree.capacity(), 1 << 20);
        assert_eq!(tree.len(), 0);
        assert!(tree.is_empty());
    }

    #[test]
    fn test_new_tree_invalid_levels() {
        assert!(MerkleTree::new(0).is_err());
        assert!(MerkleTree::new(33).is_err());
    }

    #[test]
    fn test_insert_single() {
        let mut tree = MerkleTree::new(20).unwrap();
        let index = tree.insert(12345).unwrap();
        assert_eq!(index, 0);
        assert_eq!(tree.len(), 1);
        assert!(!tree.is_empty());
    }

    #[test]
    fn test_insert_multiple() {
        let mut tree = MerkleTree::new(20).unwrap();
        for i in 0..10 {
            let index = tree.insert(i as u128).unwrap();
            assert_eq!(index, i);
        }
        assert_eq!(tree.len(), 10);
    }

    #[test]
    fn test_tree_full() {
        let mut tree = MerkleTree::new(2).unwrap(); // Can hold 4 leaves
        for i in 0..4 {
            tree.insert(i as u128).unwrap();
        }
        let result = tree.insert(100);
        assert!(matches!(result, Err(Error::TreeFull { .. })));
    }

    #[test]
    fn test_root_changes_on_insert() {
        let mut tree = MerkleTree::new(20).unwrap();
        let root1 = tree.root().unwrap();
        tree.insert(12345).unwrap();
        let root2 = tree.root().unwrap();
        assert_ne!(root1, root2);
    }

    #[test]
    fn test_is_known_root() {
        let mut tree = MerkleTree::new(20).unwrap();
        let root1 = tree.root().unwrap();
        tree.insert(12345).unwrap();
        let root2 = tree.root().unwrap();

        assert!(tree.is_known_root(root1));
        assert!(tree.is_known_root(root2));
        assert!(!tree.is_known_root(99999));
        assert!(!tree.is_known_root(0));
    }

    #[test]
    fn test_current_root_is_known_after_each_insert() {
        let mut tree = MerkleTree::new(4).unwrap();
        for leaf in 1..=6u128 {
            tree.insert(leaf * 7).unwrap();
            assert!(tree.is_known_root(tree.root().unwrap()));
        }
    }

    #[test]
    fn test_zeros_computation() {
        let tree = MerkleTree::new(10).unwrap();
        let zero0 = tree.zeros(0);
        let zero1 = tree.zeros(1);
        assert_eq!(zero0, 0);
        assert_ne!(zero1, 0);
    }

    #[test]
    fn test_deterministic_roots() {
        let mut tree1 = MerkleTree::new(10).unwrap();
        let mut tree2 = MerkleTree::new(10).unwrap();

        tree1.insert(123).unwrap();
        tree1.insert(456).unwrap();

        tree2.insert(123).unwrap();
        tree2.insert(456).unwrap();

        assert_eq!(tree1.root(), tree2.root());
    }

    #[test]
    fn test_with_default_hasher_matches_new() {
        let mut via_new = MerkleTree::new(8).unwrap();
        let mut via_hasher = MerkleTree::with_hasher(8, MimcHasher::default()).unwrap();
        assert_eq!(via_new.root(), via_hasher.root());

        for leaf in 1..=5u128 {
            via_new.insert(leaf).unwrap();
            via_hasher.insert(leaf).unwrap();
            assert_eq!(via_new.root(), via_hasher.root());
        }
    }

    #[test]
    fn test_prove_valid_index() {
        let mut tree = MerkleTree::new(10).unwrap();
        tree.insert(12345).unwrap();
        tree.insert(67890).unwrap();

        let proof = tree.prove(0).unwrap();
        assert_eq!(proof.leaf, 12345);
        assert_eq!(proof.leaf_index, 0);
        assert_eq!(proof.path.len(), 10);
    }

    #[test]
    fn test_prove_invalid_index() {
        let mut tree = MerkleTree::new(10).unwrap();
        tree.insert(12345).unwrap();

        let result = tree.prove(1);
        assert!(matches!(result, Err(Error::LeafIndexOutOfBounds { .. })));
    }

    #[test]
    fn test_proof_verifies() {
        let mut tree = MerkleTree::new(10).unwrap();
        tree.insert(12345).unwrap();
        tree.insert(67890).unwrap();
        tree.insert(11111).unwrap();

        let root = tree.root().unwrap();

        for i in 0..3 {
            let proof = tree.prove(i).unwrap();
            assert!(proof.verify(root, tree.hasher()), "Proof failed for leaf {}", i);
        }
    }

    #[test]
    fn test_proofs_verify_in_full_tree() {
        let mut tree = MerkleTree::new(3).unwrap(); // Exactly 8 leaves
        for leaf in 1..=8u128 {
            tree.insert(leaf * 1000).unwrap();
        }

        let root = tree.root().unwrap();

        for i in 0..8 {
            let proof = tree.prove(i).unwrap();
            assert_eq!(proof.path.len(), 3);
            assert!(proof.verify(root, tree.hasher()), "Proof failed for leaf {}", i);
        }
    }
}