1use serde::{Deserialize, Serialize};
6use sha2::{Digest, Sha256};
7use std::fmt;
8
9#[derive(Clone, PartialEq, Eq, Serialize, Deserialize)]
11pub struct Hash(pub [u8; 32]);
12
13impl Hash {
14 pub fn from_bytes(bytes: [u8; 32]) -> Self {
16 Self(bytes)
17 }
18
19 pub fn digest(data: &[u8]) -> Self {
21 let mut hasher = Sha256::new();
22 hasher.update([0x00]); hasher.update(data);
24 Self(hasher.finalize().into())
25 }
26
27 pub fn combine(left: &Hash, right: &Hash) -> Self {
29 let mut hasher = Sha256::new();
30 hasher.update([0x01]); hasher.update(left.0);
32 hasher.update(right.0);
33 Self(hasher.finalize().into())
34 }
35
36 pub fn to_hex(&self) -> String {
38 hex::encode(self.0)
39 }
40}
41
42impl fmt::Debug for Hash {
43 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
44 write!(f, "Hash({})", &self.to_hex()[..16])
45 }
46}
47
48impl fmt::Display for Hash {
49 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
50 write!(f, "{}", &self.to_hex()[..16])
51 }
52}
53
54#[derive(Debug, Clone, Serialize, Deserialize)]
56pub enum MerkleNode {
57 Leaf { hash: Hash, data_id: String },
59 Internal {
61 hash: Hash,
62 left: Box<MerkleNode>,
63 right: Box<MerkleNode>,
64 },
65}
66
67impl MerkleNode {
68 pub fn hash(&self) -> &Hash {
70 match self {
71 Self::Leaf { hash, .. } => hash,
72 Self::Internal { hash, .. } => hash,
73 }
74 }
75}
76
77#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
79pub enum ProofDirection {
80 Left,
82 Right,
84}
85
86#[derive(Debug, Clone, Serialize, Deserialize)]
88pub struct ProofStep {
89 pub sibling_hash: Hash,
91 pub direction: ProofDirection,
93}
94
95#[derive(Debug, Clone, Serialize, Deserialize)]
98pub struct MerkleProof {
99 pub leaf_hash: Hash,
101 pub leaf_id: String,
103 pub path: Vec<ProofStep>,
105 pub root_hash: Hash,
107}
108
109impl MerkleProof {
110 pub fn verify(&self, expected_root: &Hash) -> bool {
112 if &self.root_hash != expected_root {
113 return false;
114 }
115
116 let mut current_hash = self.leaf_hash.clone();
117
118 for step in &self.path {
119 current_hash = match step.direction {
120 ProofDirection::Left => Hash::combine(&step.sibling_hash, ¤t_hash),
121 ProofDirection::Right => Hash::combine(¤t_hash, &step.sibling_hash),
122 };
123 }
124
125 ¤t_hash == expected_root
126 }
127
128 pub fn to_json(&self) -> Result<String, serde_json::Error> {
130 serde_json::to_string(self)
131 }
132
133 pub const MAX_PROOF_JSON_SIZE: usize = 1024 * 1024;
136
137 pub fn from_json_with_limit(json: &str, max_size: usize) -> Result<Self, String> {
146 if json.len() > max_size {
147 return Err(format!(
148 "Proof JSON too large: {} bytes exceeds limit of {} bytes",
149 json.len(),
150 max_size
151 ));
152 }
153 serde_json::from_str(json).map_err(|e| e.to_string())
154 }
155
156 pub fn from_json(json: &str) -> Result<Self, String> {
160 Self::from_json_with_limit(json, Self::MAX_PROOF_JSON_SIZE)
161 }
162}
163
164#[derive(Debug, Clone)]
166pub struct MerkleTree {
167 root: Option<MerkleNode>,
168 leaf_count: usize,
169}
170
171impl MerkleTree {
172 pub fn new() -> Self {
174 Self {
175 root: None,
176 leaf_count: 0,
177 }
178 }
179
180 pub fn from_leaves(leaves: Vec<(String, Hash)>) -> Self {
182 if leaves.is_empty() {
183 return Self::new();
184 }
185
186 let leaf_count = leaves.len();
187 let mut nodes: Vec<MerkleNode> = leaves
188 .into_iter()
189 .map(|(data_id, hash)| MerkleNode::Leaf { hash, data_id })
190 .collect();
191
192 while nodes.len() > 1 {
194 let mut next_level = Vec::with_capacity(nodes.len().div_ceil(2));
195 let mut iter = nodes.into_iter();
196
197 while let Some(left_node) = iter.next() {
198 if let Some(right_node) = iter.next() {
199 let combined_hash = Hash::combine(left_node.hash(), right_node.hash());
200 next_level.push(MerkleNode::Internal {
201 hash: combined_hash,
202 left: Box::new(left_node),
203 right: Box::new(right_node),
204 });
205 } else {
206 next_level.push(left_node);
208 }
209 }
210
211 nodes = next_level;
212 }
213
214 Self {
215 root: nodes.into_iter().next(),
216 leaf_count,
217 }
218 }
219
220 pub fn root_hash(&self) -> Option<&Hash> {
222 self.root.as_ref().map(|n| n.hash())
223 }
224
225 pub fn len(&self) -> usize {
227 self.leaf_count
228 }
229
230 pub fn is_empty(&self) -> bool {
232 self.leaf_count == 0
233 }
234
235 pub fn contains(&self, target_hash: &Hash) -> bool {
237 match &self.root {
238 None => false,
239 Some(node) => Self::contains_node(node, target_hash),
240 }
241 }
242
243 fn contains_node(node: &MerkleNode, target: &Hash) -> bool {
245 match node {
246 MerkleNode::Leaf { hash, .. } => hash == target,
247 MerkleNode::Internal { hash, left, right } => {
248 hash == target
249 || Self::contains_node(left, target)
250 || Self::contains_node(right, target)
251 }
252 }
253 }
254
255 pub fn contains_iterative(&self, target_hash: &Hash) -> bool {
257 let mut stack = Vec::new();
258 if let Some(root) = &self.root {
259 stack.push(root);
260 }
261
262 while let Some(node) = stack.pop() {
263 match node {
264 MerkleNode::Leaf { hash, .. } => {
265 if hash == target_hash {
266 return true;
267 }
268 }
269 MerkleNode::Internal { hash, left, right } => {
270 if hash == target_hash {
271 return true;
272 }
273 stack.push(right);
274 stack.push(left);
275 }
276 }
277 }
278 false
279 }
280
281 #[cfg(feature = "algoswitch")]
283 pub fn contains_optimized(&self, target_hash: &Hash) -> bool {
284 if self.leaf_count < 128 {
287 self.contains(target_hash)
288 } else {
289 self.contains_iterative(target_hash)
290 }
291 }
292
293 pub fn get_proof_by_hash(&self, target_hash: &Hash) -> Option<MerkleProof> {
296 let root = self.root.as_ref()?;
297 let root_hash = root.hash().clone();
298
299 let mut path = Vec::new();
300 let (leaf_hash, leaf_id) = Self::find_path_to_hash(root, target_hash, &mut path)?;
301
302 Some(MerkleProof {
303 leaf_hash,
304 leaf_id,
305 path,
306 root_hash,
307 })
308 }
309
310 fn find_path_to_hash(
312 node: &MerkleNode,
313 target: &Hash,
314 path: &mut Vec<ProofStep>,
315 ) -> Option<(Hash, String)> {
316 match node {
317 MerkleNode::Leaf { hash, data_id } => {
318 if hash == target {
319 Some((hash.clone(), data_id.clone()))
320 } else {
321 None
322 }
323 }
324 MerkleNode::Internal { left, right, .. } => {
325 if let Some(result) = Self::find_path_to_hash(left, target, path) {
327 path.push(ProofStep {
329 sibling_hash: right.hash().clone(),
330 direction: ProofDirection::Right,
331 });
332 return Some(result);
333 }
334
335 if let Some(result) = Self::find_path_to_hash(right, target, path) {
337 path.push(ProofStep {
339 sibling_hash: left.hash().clone(),
340 direction: ProofDirection::Left,
341 });
342 return Some(result);
343 }
344
345 None
346 }
347 }
348 }
349
350 pub fn verify_proof(&self, proof: &MerkleProof) -> bool {
352 match self.root_hash() {
353 Some(root) => proof.verify(root),
354 None => false,
355 }
356 }
357}
358
359impl Default for MerkleTree {
360 fn default() -> Self {
361 Self::new()
362 }
363}
364
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds one `(id, digest(data))` leaf entry.
    fn leaf(id: &str, data: &[u8]) -> (String, Hash) {
        (id.to_string(), Hash::digest(data))
    }

    /// Combining the same pair twice yields the same hash.
    #[test]
    fn test_hash_combine() {
        let first = Hash::digest(b"hello");
        let second = Hash::digest(b"world");

        let once = Hash::combine(&first, &second);
        let again = Hash::combine(&first, &second);

        assert_eq!(once, again);
    }

    /// A four-leaf tree reports its size, has a root and contains every leaf.
    #[test]
    fn test_merkle_tree() {
        let leaves = vec![
            leaf("a", b"data_a"),
            leaf("b", b"data_b"),
            leaf("c", b"data_c"),
            leaf("d", b"data_d"),
        ];

        let tree = MerkleTree::from_leaves(leaves.clone());
        assert_eq!(tree.len(), 4);
        assert!(tree.root_hash().is_some());

        for (_, leaf_hash) in leaves.iter() {
            assert!(tree.contains(leaf_hash));
        }
    }

    /// Every leaf of a four-leaf tree gets a proof that verifies.
    #[test]
    fn test_merkle_proof_generation() {
        let leaves = vec![
            leaf("event_1", b"data_1"),
            leaf("event_2", b"data_2"),
            leaf("event_3", b"data_3"),
            leaf("event_4", b"data_4"),
        ];

        let tree = MerkleTree::from_leaves(leaves.clone());
        let root = tree.root_hash().unwrap();

        for (expected_id, expected_hash) in leaves.iter() {
            let proof = tree
                .get_proof_by_hash(expected_hash)
                .expect("Should find leaf");
            assert_eq!(&proof.leaf_id, expected_id);
            assert_eq!(&proof.leaf_hash, expected_hash);
            assert!(proof.verify(root), "Proof should verify against root");
        }
    }

    /// A proof survives a JSON round trip and still verifies.
    #[test]
    fn test_merkle_proof_serialization() {
        let leaves = vec![leaf("a", b"data_a"), leaf("b", b"data_b")];

        let tree = MerkleTree::from_leaves(leaves.clone());
        let proof = tree.get_proof_by_hash(&leaves[0].1).unwrap();

        let json = proof.to_json().expect("Should serialize");
        assert!(json.contains("leaf_hash"));
        assert!(json.contains("path"));

        let restored = MerkleProof::from_json(&json).expect("Should deserialize");
        assert_eq!(proof.leaf_id, restored.leaf_id);
        assert!(restored.verify(tree.root_hash().unwrap()));
    }

    /// Asking for a hash that is not in the tree yields no proof.
    #[test]
    fn test_merkle_proof_not_found() {
        let tree = MerkleTree::from_leaves(vec![leaf("a", b"data_a")]);

        let absent = Hash::digest(b"not_in_tree");
        assert!(tree.get_proof_by_hash(&absent).is_none());
    }

    /// Proofs verify for a tree with an odd number of leaves.
    #[test]
    fn test_merkle_proof_odd_leaves() {
        let leaves = vec![
            leaf("a", b"data_a"),
            leaf("b", b"data_b"),
            leaf("c", b"data_c"),
        ];

        let tree = MerkleTree::from_leaves(leaves.clone());
        let root = tree.root_hash().unwrap();

        for (_, leaf_hash) in leaves.iter() {
            let proof = tree.get_proof_by_hash(leaf_hash).expect("Should find leaf");
            assert!(proof.verify(root), "Proof should verify for odd tree");
        }
    }

    /// Replacing the leaf hash in a valid proof makes verification fail.
    #[test]
    fn test_merkle_proof_tamper_detection() {
        let leaves = vec![leaf("a", b"data_a"), leaf("b", b"data_b")];

        let tree = MerkleTree::from_leaves(leaves.clone());
        let mut proof = tree.get_proof_by_hash(&leaves[0].1).unwrap();

        proof.leaf_hash = Hash::digest(b"tampered");

        assert!(
            !proof.verify(tree.root_hash().unwrap()),
            "Tampered proof should fail"
        );
    }
}