// vex_core/merkle.rs

1//! Provides cryptographic verification of context packet hierarchies.
2
3// Removed unused import
4
5use serde::{Deserialize, Serialize};
6use sha2::{Digest, Sha256};
7use std::fmt;
8
9/// A SHA-256 hash (32 bytes)
10#[derive(Clone, PartialEq, Eq, Serialize, Deserialize)]
11pub struct Hash(pub [u8; 32]);
12
13impl Hash {
14    /// Create a hash from raw bytes
15    pub fn from_bytes(bytes: [u8; 32]) -> Self {
16        Self(bytes)
17    }
18
19    /// Hash arbitrary data (Leaf node domain separation: 0x00)
20    pub fn digest(data: &[u8]) -> Self {
21        let mut hasher = Sha256::new();
22        hasher.update([0x00]); // Leaf prefix
23        hasher.update(data);
24        Self(hasher.finalize().into())
25    }
26
27    /// Combine two hashes (Internal node domain separation: 0x01)
28    pub fn combine(left: &Hash, right: &Hash) -> Self {
29        let mut hasher = Sha256::new();
30        hasher.update([0x01]); // Internal prefix
31        hasher.update(left.0);
32        hasher.update(right.0);
33        Self(hasher.finalize().into())
34    }
35
36    /// Get hex representation
37    pub fn to_hex(&self) -> String {
38        hex::encode(self.0)
39    }
40}
41
42impl fmt::Debug for Hash {
43    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
44        write!(f, "Hash({})", &self.to_hex()[..16])
45    }
46}
47
48impl fmt::Display for Hash {
49    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
50        write!(f, "{}", &self.to_hex()[..16])
51    }
52}
53
54/// A node in the Merkle tree
55#[derive(Debug, Clone, Serialize, Deserialize)]
56pub enum MerkleNode {
57    /// Leaf node containing actual data hash
58    Leaf { hash: Hash, data_id: String },
59    /// Internal node combining two child hashes
60    Internal {
61        hash: Hash,
62        left: Box<MerkleNode>,
63        right: Box<MerkleNode>,
64    },
65}
66
67impl MerkleNode {
68    /// Get the hash of this node
69    pub fn hash(&self) -> &Hash {
70        match self {
71            Self::Leaf { hash, .. } => hash,
72            Self::Internal { hash, .. } => hash,
73        }
74    }
75}
76
77/// Direction indicator for Merkle proof steps
78#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
79pub enum ProofDirection {
80    /// Sibling hash is on the left
81    Left,
82    /// Sibling hash is on the right
83    Right,
84}
85
86/// A single step in a Merkle inclusion proof
87#[derive(Debug, Clone, Serialize, Deserialize)]
88pub struct ProofStep {
89    /// The sibling hash at this level
90    pub sibling_hash: Hash,
91    /// Whether the sibling is on the left or right
92    pub direction: ProofDirection,
93}
94
95/// A Merkle inclusion proof (RFC 6962 compatible)
96/// Allows proving that a leaf is part of a tree without revealing other leaves
97#[derive(Debug, Clone, Serialize, Deserialize)]
98pub struct MerkleProof {
99    /// The leaf hash being proven
100    pub leaf_hash: Hash,
101    /// The leaf's data ID
102    pub leaf_id: String,
103    /// The path from leaf to root (bottom to top)
104    pub path: Vec<ProofStep>,
105    /// The expected root hash
106    pub root_hash: Hash,
107}
108
109impl MerkleProof {
110    /// Verify this proof against a root hash
111    pub fn verify(&self, expected_root: &Hash) -> bool {
112        if &self.root_hash != expected_root {
113            return false;
114        }
115
116        let mut current_hash = self.leaf_hash.clone();
117
118        for step in &self.path {
119            current_hash = match step.direction {
120                ProofDirection::Left => Hash::combine(&step.sibling_hash, &current_hash),
121                ProofDirection::Right => Hash::combine(&current_hash, &step.sibling_hash),
122            };
123        }
124
125        &current_hash == expected_root
126    }
127
128    /// Export proof as a compact JSON string for transmission
129    pub fn to_json(&self) -> Result<String, serde_json::Error> {
130        serde_json::to_string(self)
131    }
132
133    /// Default maximum size for proof JSON (1 MB)
134    /// This limits the number of proof steps to prevent DoS
135    pub const MAX_PROOF_JSON_SIZE: usize = 1024 * 1024;
136
137    /// Import proof from JSON string with size limit (MEDIUM-2 fix)
138    ///
139    /// # Arguments
140    /// * `json` - The JSON string to parse
141    /// * `max_size` - Maximum allowed size in bytes (prevents DoS)
142    ///
143    /// # Errors
144    /// Returns error if JSON exceeds max_size or is invalid
145    pub fn from_json_with_limit(json: &str, max_size: usize) -> Result<Self, String> {
146        if json.len() > max_size {
147            return Err(format!(
148                "Proof JSON too large: {} bytes exceeds limit of {} bytes",
149                json.len(),
150                max_size
151            ));
152        }
153        serde_json::from_str(json).map_err(|e| e.to_string())
154    }
155
156    /// Import proof from JSON string with default 1MB limit
157    ///
158    /// For custom limits, use `from_json_with_limit()`.
159    pub fn from_json(json: &str) -> Result<Self, String> {
160        Self::from_json_with_limit(json, Self::MAX_PROOF_JSON_SIZE)
161    }
162}
163
164/// A Merkle tree for verifying context packet integrity
165#[derive(Debug, Clone)]
166pub struct MerkleTree {
167    root: Option<MerkleNode>,
168    leaf_count: usize,
169}
170
171impl MerkleTree {
172    /// Create an empty Merkle tree
173    pub fn new() -> Self {
174        Self {
175            root: None,
176            leaf_count: 0,
177        }
178    }
179
180    /// Build a Merkle tree from a list of (id, hash) pairs (zero-copy construction)
181    pub fn from_leaves(leaves: Vec<(String, Hash)>) -> Self {
182        if leaves.is_empty() {
183            return Self::new();
184        }
185
186        let leaf_count = leaves.len();
187        let mut nodes: Vec<MerkleNode> = leaves
188            .into_iter()
189            .map(|(data_id, hash)| MerkleNode::Leaf { hash, data_id })
190            .collect();
191
192        // Build tree bottom-up using move semantics (no cloning)
193        while nodes.len() > 1 {
194            let mut next_level = Vec::with_capacity(nodes.len().div_ceil(2));
195            let mut iter = nodes.into_iter();
196
197            while let Some(left_node) = iter.next() {
198                if let Some(right_node) = iter.next() {
199                    let combined_hash = Hash::combine(left_node.hash(), right_node.hash());
200                    next_level.push(MerkleNode::Internal {
201                        hash: combined_hash,
202                        left: Box::new(left_node),
203                        right: Box::new(right_node),
204                    });
205                } else {
206                    // Odd number of nodes, carry the last one up
207                    next_level.push(left_node);
208                }
209            }
210
211            nodes = next_level;
212        }
213
214        Self {
215            root: nodes.into_iter().next(),
216            leaf_count,
217        }
218    }
219
220    /// Get the root hash (None if tree is empty)
221    pub fn root_hash(&self) -> Option<&Hash> {
222        self.root.as_ref().map(|n| n.hash())
223    }
224
225    /// Get the number of leaves
226    pub fn len(&self) -> usize {
227        self.leaf_count
228    }
229
230    /// Check if tree is empty
231    pub fn is_empty(&self) -> bool {
232        self.leaf_count == 0
233    }
234
235    /// Verify that a hash is part of this tree (zero-copy traversal)
236    pub fn contains(&self, target_hash: &Hash) -> bool {
237        match &self.root {
238            None => false,
239            Some(node) => Self::contains_node(node, target_hash),
240        }
241    }
242
243    /// Recursive helper that takes a reference - no cloning needed
244    fn contains_node(node: &MerkleNode, target: &Hash) -> bool {
245        match node {
246            MerkleNode::Leaf { hash, .. } => hash == target,
247            MerkleNode::Internal { hash, left, right } => {
248                hash == target
249                    || Self::contains_node(left, target)
250                    || Self::contains_node(right, target)
251            }
252        }
253    }
254
255    /// Optimized search using iterative traversal (prevents stack overflow for deep trees)
256    pub fn contains_iterative(&self, target_hash: &Hash) -> bool {
257        let mut stack = Vec::new();
258        if let Some(root) = &self.root {
259            stack.push(root);
260        }
261
262        while let Some(node) = stack.pop() {
263            match node {
264                MerkleNode::Leaf { hash, .. } => {
265                    if hash == target_hash {
266                        return true;
267                    }
268                }
269                MerkleNode::Internal { hash, left, right } => {
270                    if hash == target_hash {
271                        return true;
272                    }
273                    stack.push(right);
274                    stack.push(left);
275                }
276            }
277        }
278        false
279    }
280
281    /// AlgoSwitch Select - picks the best search strategy based on tree size
282    #[cfg(feature = "algoswitch")]
283    pub fn contains_optimized(&self, target_hash: &Hash) -> bool {
284        // For small trees, recursive is overhead-free (no stack allocation)
285        // For large trees, iterative is safer and often faster
286        if self.leaf_count < 128 {
287            self.contains(target_hash)
288        } else {
289            self.contains_iterative(target_hash)
290        }
291    }
292
293    /// Generate an inclusion proof for a leaf by its hash
294    /// Returns None if the hash is not found in the tree
295    pub fn get_proof_by_hash(&self, target_hash: &Hash) -> Option<MerkleProof> {
296        let root = self.root.as_ref()?;
297        let root_hash = root.hash().clone();
298
299        let mut path = Vec::new();
300        let (leaf_hash, leaf_id) = Self::find_path_to_hash(root, target_hash, &mut path)?;
301
302        Some(MerkleProof {
303            leaf_hash,
304            leaf_id,
305            path,
306            root_hash,
307        })
308    }
309
310    /// Helper: Find path from root to target hash, collecting sibling hashes
311    fn find_path_to_hash(
312        node: &MerkleNode,
313        target: &Hash,
314        path: &mut Vec<ProofStep>,
315    ) -> Option<(Hash, String)> {
316        match node {
317            MerkleNode::Leaf { hash, data_id } => {
318                if hash == target {
319                    Some((hash.clone(), data_id.clone()))
320                } else {
321                    None
322                }
323            }
324            MerkleNode::Internal { left, right, .. } => {
325                // Try left subtree first
326                if let Some(result) = Self::find_path_to_hash(left, target, path) {
327                    // Target is in left subtree, sibling is on the right
328                    path.push(ProofStep {
329                        sibling_hash: right.hash().clone(),
330                        direction: ProofDirection::Right,
331                    });
332                    return Some(result);
333                }
334
335                // Try right subtree
336                if let Some(result) = Self::find_path_to_hash(right, target, path) {
337                    // Target is in right subtree, sibling is on the left
338                    path.push(ProofStep {
339                        sibling_hash: left.hash().clone(),
340                        direction: ProofDirection::Left,
341                    });
342                    return Some(result);
343                }
344
345                None
346            }
347        }
348    }
349
350    /// Verify a proof against this tree's root
351    pub fn verify_proof(&self, proof: &MerkleProof) -> bool {
352        match self.root_hash() {
353            Some(root) => proof.verify(root),
354            None => false,
355        }
356    }
357}
358
359impl Default for MerkleTree {
360    fn default() -> Self {
361        Self::new()
362    }
363}
364
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_hash_combine() {
        let h1 = Hash::digest(b"hello");
        let h2 = Hash::digest(b"world");
        let combined = Hash::combine(&h1, &h2);

        // Combining the same inputs again must be deterministic
        let combined2 = Hash::combine(&h1, &h2);
        assert_eq!(combined, combined2);
    }

    #[test]
    fn test_hash_combine_is_order_sensitive() {
        let h1 = Hash::digest(b"hello");
        let h2 = Hash::digest(b"world");

        // Left/right position is part of the hashed input
        assert_ne!(Hash::combine(&h1, &h2), Hash::combine(&h2, &h1));
    }

    #[test]
    fn test_leaf_and_internal_domains_differ() {
        let left = Hash::digest(b"left");
        let right = Hash::digest(b"right");

        let mut concatenated = Vec::with_capacity(64);
        concatenated.extend_from_slice(&left.0);
        concatenated.extend_from_slice(&right.0);

        // Same payload bytes, but the leaf prefix (0x00) and the internal prefix (0x01)
        // must keep the two digests apart
        assert_ne!(Hash::digest(&concatenated), Hash::combine(&left, &right));
    }

    #[test]
    fn test_merkle_tree() {
        let leaves = vec![
            ("a".to_string(), Hash::digest(b"data_a")),
            ("b".to_string(), Hash::digest(b"data_b")),
            ("c".to_string(), Hash::digest(b"data_c")),
            ("d".to_string(), Hash::digest(b"data_d")),
        ];

        let tree = MerkleTree::from_leaves(leaves.clone());
        assert_eq!(tree.len(), 4);
        assert!(!tree.is_empty());
        assert!(tree.root_hash().is_some());

        // Should find all original hashes
        for (_, hash) in &leaves {
            assert!(tree.contains(hash));
        }
    }

    #[test]
    fn test_empty_tree() {
        let tree = MerkleTree::from_leaves(Vec::new());
        assert!(tree.is_empty());
        assert_eq!(tree.len(), 0);
        assert!(tree.root_hash().is_none());

        // Every lookup on an empty tree must come back negative
        let probe = Hash::digest(b"anything");
        assert!(!tree.contains(&probe));
        assert!(!tree.contains_iterative(&probe));
        assert!(tree.get_proof_by_hash(&probe).is_none());

        // The default tree is the same empty tree
        let default_tree = MerkleTree::default();
        assert!(default_tree.is_empty());
        assert!(default_tree.root_hash().is_none());
    }

    #[test]
    fn test_single_leaf_tree() {
        let leaf = Hash::digest(b"only");
        let tree = MerkleTree::from_leaves(vec![("only".to_string(), leaf.clone())]);

        // With one leaf there is nothing to combine: the leaf is the root
        assert_eq!(tree.len(), 1);
        assert_eq!(tree.root_hash(), Some(&leaf));

        let proof = tree.get_proof_by_hash(&leaf).expect("Should find leaf");
        assert_eq!(proof.leaf_id, "only");
        assert!(proof.path.is_empty());
        assert!(tree.verify_proof(&proof));
    }

    #[test]
    fn test_contains_iterative_matches_recursive() {
        let leaves = vec![
            ("a".to_string(), Hash::digest(b"data_a")),
            ("b".to_string(), Hash::digest(b"data_b")),
            ("c".to_string(), Hash::digest(b"data_c")),
            ("d".to_string(), Hash::digest(b"data_d")),
            ("e".to_string(), Hash::digest(b"data_e")),
        ];
        let tree = MerkleTree::from_leaves(leaves.clone());

        for (_, hash) in &leaves {
            assert!(tree.contains(hash));
            assert!(tree.contains_iterative(hash));
        }

        // Both searches also match internal nodes, including the root
        let root = tree.root_hash().unwrap();
        assert!(tree.contains(root));
        assert!(tree.contains_iterative(root));

        let missing = Hash::digest(b"not_in_tree");
        assert!(!tree.contains(&missing));
        assert!(!tree.contains_iterative(&missing));
    }

    #[test]
    fn test_merkle_proof_generation() {
        let leaves = vec![
            ("event_1".to_string(), Hash::digest(b"data_1")),
            ("event_2".to_string(), Hash::digest(b"data_2")),
            ("event_3".to_string(), Hash::digest(b"data_3")),
            ("event_4".to_string(), Hash::digest(b"data_4")),
        ];

        let tree = MerkleTree::from_leaves(leaves.clone());
        let root = tree.root_hash().unwrap();

        // Generate and verify proof for each leaf
        for (id, hash) in &leaves {
            let proof = tree.get_proof_by_hash(hash).expect("Should find leaf");
            assert_eq!(&proof.leaf_id, id);
            assert_eq!(&proof.leaf_hash, hash);
            assert_eq!(&proof.root_hash, root);
            assert!(proof.verify(root), "Proof should verify against root");
            assert!(tree.verify_proof(&proof), "Tree should accept its own proof");
        }
    }

    #[test]
    fn test_merkle_proof_serialization() {
        let leaves = vec![
            ("a".to_string(), Hash::digest(b"data_a")),
            ("b".to_string(), Hash::digest(b"data_b")),
        ];

        let tree = MerkleTree::from_leaves(leaves.clone());
        let proof = tree.get_proof_by_hash(&leaves[0].1).unwrap();

        // Serialize to JSON
        let json = proof.to_json().expect("Should serialize");
        assert!(json.contains("leaf_hash"));
        assert!(json.contains("path"));

        // Deserialize and verify
        let restored = MerkleProof::from_json(&json).expect("Should deserialize");
        assert_eq!(proof.leaf_id, restored.leaf_id);
        assert!(restored.verify(tree.root_hash().unwrap()));
    }

    #[test]
    fn test_merkle_proof_json_size_limit() {
        let leaves = vec![
            ("a".to_string(), Hash::digest(b"data_a")),
            ("b".to_string(), Hash::digest(b"data_b")),
        ];

        let tree = MerkleTree::from_leaves(leaves.clone());
        let proof = tree.get_proof_by_hash(&leaves[0].1).unwrap();
        let json = proof.to_json().expect("Should serialize");

        // A limit equal to the payload size is still accepted
        assert!(MerkleProof::from_json_with_limit(&json, json.len()).is_ok());

        // One byte less and the payload is rejected before parsing
        let err = MerkleProof::from_json_with_limit(&json, json.len() - 1)
            .expect_err("Oversized JSON should be rejected");
        assert!(err.contains("too large"));

        // Malformed JSON within the limit is reported as an error too
        assert!(MerkleProof::from_json("{not json").is_err());
    }

    #[test]
    fn test_merkle_proof_not_found() {
        let leaves = vec![("a".to_string(), Hash::digest(b"data_a"))];
        let tree = MerkleTree::from_leaves(leaves);

        let fake_hash = Hash::digest(b"not_in_tree");
        assert!(tree.get_proof_by_hash(&fake_hash).is_none());
    }

    #[test]
    fn test_merkle_proof_odd_leaves() {
        // Odd number of leaves - tests edge case in tree construction
        let leaves = vec![
            ("a".to_string(), Hash::digest(b"data_a")),
            ("b".to_string(), Hash::digest(b"data_b")),
            ("c".to_string(), Hash::digest(b"data_c")),
        ];

        let tree = MerkleTree::from_leaves(leaves.clone());
        let root = tree.root_hash().unwrap();

        // All proofs should still work
        for (_, hash) in &leaves {
            let proof = tree.get_proof_by_hash(hash).expect("Should find leaf");
            assert!(proof.verify(root), "Proof should verify for odd tree");
        }
    }

    #[test]
    fn test_merkle_proof_tamper_detection() {
        let leaves = vec![
            ("a".to_string(), Hash::digest(b"data_a")),
            ("b".to_string(), Hash::digest(b"data_b")),
        ];

        let tree = MerkleTree::from_leaves(leaves.clone());
        let mut proof = tree.get_proof_by_hash(&leaves[0].1).unwrap();

        // Tamper with the leaf hash
        proof.leaf_hash = Hash::digest(b"tampered");

        // Should fail verification
        assert!(
            !proof.verify(tree.root_hash().unwrap()),
            "Tampered proof should fail"
        );
    }

    #[test]
    fn test_merkle_proof_tampered_sibling() {
        let leaves = vec![
            ("a".to_string(), Hash::digest(b"data_a")),
            ("b".to_string(), Hash::digest(b"data_b")),
        ];

        let tree = MerkleTree::from_leaves(leaves.clone());
        let mut proof = tree.get_proof_by_hash(&leaves[0].1).unwrap();
        assert_eq!(proof.path.len(), 1);

        // Replace the sibling hash with an unrelated one
        proof.path[0].sibling_hash = Hash::digest(b"tampered");

        assert!(
            !proof.verify(tree.root_hash().unwrap()),
            "Proof with a forged sibling should fail"
        );
    }

    #[test]
    fn test_merkle_proof_rejected_for_other_root() {
        let tree_a = MerkleTree::from_leaves(vec![
            ("a".to_string(), Hash::digest(b"data_a")),
            ("b".to_string(), Hash::digest(b"data_b")),
        ]);
        let tree_b = MerkleTree::from_leaves(vec![
            ("c".to_string(), Hash::digest(b"data_c")),
            ("d".to_string(), Hash::digest(b"data_d")),
        ]);

        let proof = tree_a
            .get_proof_by_hash(&Hash::digest(b"data_a"))
            .expect("Should find leaf");

        // A proof issued by one tree must not verify against another tree's root
        assert!(!proof.verify(tree_b.root_hash().unwrap()));
        assert!(!tree_b.verify_proof(&proof));
    }
}