1use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
47use std::io::{self, Cursor, Write};
48
/// Tag identifying the kind of a block.
///
/// The enum is `#[repr(u8)]`, so `block_type as u8` yields the discriminant
/// written next to each variant. Converting back with `From<u8>` never fails:
/// bytes without a dedicated variant map to [`BlockType::Unknown`].
#[repr(u8)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BlockType {
    Data = 0,
    TemporalIndex = 1,
    EdgeIndex = 2,
    BloomFilter = 3,
    FencePointers = 4,
    BlockIndex = 5,
    Footer = 6,
    /// Catch-all produced by `From<u8>` for every byte outside `0..=6`.
    Unknown = 255,
}

impl From<u8> for BlockType {
    /// Maps a raw tag byte to its variant; unrecognised bytes become
    /// [`BlockType::Unknown`].
    fn from(value: u8) -> Self {
        // Indexed by discriminant, so position `n` holds the variant whose tag is `n`.
        const KNOWN: [BlockType; 7] = [
            BlockType::Data,
            BlockType::TemporalIndex,
            BlockType::EdgeIndex,
            BlockType::BloomFilter,
            BlockType::FencePointers,
            BlockType::BlockIndex,
            BlockType::Footer,
        ];

        KNOWN
            .get(usize::from(value))
            .copied()
            .unwrap_or(BlockType::Unknown)
    }
}
85
/// Number of trailer bytes that follow a block's payload: a 4-byte
/// little-endian masked CRC plus a 1-byte block type tag.
pub const BLOCK_TRAILER_SIZE: usize = 5;

/// Computes the CRC-32C checksum of `data` with a byte-wise lookup table.
///
/// The register starts as all ones and the final value is inverted, so the
/// checksum of an empty slice is `0`.
pub fn crc32c(data: &[u8]) -> u32 {
    /// Bit-reflected form of the polynomial, as required by the
    /// shift-right table algorithm below.
    const POLY_REFLECTED: u32 = 0x82F63B78;

    /// Builds the 256-entry table at compile time (`while` loops because
    /// `for` is not available in `const fn`).
    const fn build_table() -> [u32; 256] {
        let mut table = [0u32; 256];
        let mut byte = 0usize;
        while byte < 256 {
            let mut remainder = byte as u32;
            let mut bit = 0;
            while bit < 8 {
                // All ones when the low bit is set, zero otherwise: applies
                // the polynomial conditionally without a branch.
                let low_bit_mask = (remainder & 1).wrapping_neg();
                remainder = (remainder >> 1) ^ (POLY_REFLECTED & low_bit_mask);
                bit += 1;
            }
            table[byte] = remainder;
            byte += 1;
        }
        table
    }

    static LOOKUP: [u32; 256] = build_table();

    let register = data.iter().fold(!0u32, |state, &byte| {
        let slot = ((state ^ u32::from(byte)) & 0xFF) as usize;
        LOOKUP[slot] ^ (state >> 8)
    });
    !register
}
128
/// Masks a CRC before it is stored: rotate right by 15 bits, then add a
/// fixed constant with wrap-around.
///
/// [`unmask_crc`] is the exact inverse.
pub fn mask_crc(crc: u32) -> u32 {
    const MASK_DELTA: u32 = 0xa282ead8;

    // A 32-bit rotate right by 15 spelled out as two shifts.
    let rotated = (crc >> 15) | (crc << 17);
    rotated.wrapping_add(MASK_DELTA)
}
138
/// Recovers the original CRC from a value produced by [`mask_crc`]:
/// subtract the fixed constant with wrap-around, then rotate left by 15 bits.
pub fn unmask_crc(masked: u32) -> u32 {
    const MASK_DELTA: u32 = 0xa282ead8;

    let rotated = masked.wrapping_sub(MASK_DELTA);
    // A 32-bit rotate left by 15 spelled out as two shifts.
    (rotated << 15) | (rotated >> 17)
}
145
146#[derive(Debug, Clone)]
148pub struct ChecksummedBlock {
149 pub data: Vec<u8>,
151 pub block_type: BlockType,
153 pub checksum: u32,
155}
156
157impl ChecksummedBlock {
158 pub fn new(data: Vec<u8>, block_type: BlockType) -> Self {
160 let checksum = crc32c(&data);
161 Self {
162 data,
163 block_type,
164 checksum,
165 }
166 }
167
168 pub fn to_bytes(&self) -> Vec<u8> {
170 let mut buf = Vec::with_capacity(self.data.len() + BLOCK_TRAILER_SIZE);
171 buf.extend_from_slice(&self.data);
172 buf.write_u32::<LittleEndian>(mask_crc(self.checksum))
173 .unwrap();
174 buf.push(self.block_type as u8);
175 buf
176 }
177
178 pub fn from_bytes(bytes: &[u8]) -> io::Result<Self> {
182 if bytes.len() < BLOCK_TRAILER_SIZE {
183 return Err(io::Error::new(
184 io::ErrorKind::InvalidData,
185 "Block too small for trailer",
186 ));
187 }
188
189 let data_len = bytes.len() - BLOCK_TRAILER_SIZE;
190 let data = bytes[..data_len].to_vec();
191 let trailer = &bytes[data_len..];
192
193 let mut cursor = Cursor::new(trailer);
194 let masked_crc = cursor.read_u32::<LittleEndian>()?;
195 let stored_crc = unmask_crc(masked_crc);
196 let block_type = BlockType::from(trailer[4]);
197
198 let computed_crc = crc32c(&data);
199
200 if stored_crc != computed_crc {
201 return Err(io::Error::new(
202 io::ErrorKind::InvalidData,
203 format!(
204 "Block checksum mismatch: stored 0x{:08x}, computed 0x{:08x}",
205 stored_crc, computed_crc
206 ),
207 ));
208 }
209
210 Ok(Self {
211 data,
212 block_type,
213 checksum: computed_crc,
214 })
215 }
216
217 pub fn verify(bytes: &[u8]) -> bool {
221 if bytes.len() < BLOCK_TRAILER_SIZE {
222 return false;
223 }
224
225 let data_len = bytes.len() - BLOCK_TRAILER_SIZE;
226 let data = &bytes[..data_len];
227 let trailer = &bytes[data_len..];
228
229 let masked_crc = u32::from_le_bytes([trailer[0], trailer[1], trailer[2], trailer[3]]);
230 let stored_crc = unmask_crc(masked_crc);
231 let computed_crc = crc32c(data);
232
233 stored_crc == computed_crc
234 }
235
236 pub fn total_size(&self) -> usize {
238 self.data.len() + BLOCK_TRAILER_SIZE
239 }
240}
241
242#[derive(Debug, Clone)]
244pub struct BlockChecksumConfig {
245 pub verify_on_read: bool,
247 pub skip_types: Vec<BlockType>,
249}
250
251impl Default for BlockChecksumConfig {
252 fn default() -> Self {
253 Self {
254 verify_on_read: true,
255 skip_types: Vec::new(),
256 }
257 }
258}
259
260impl BlockChecksumConfig {
261 pub fn no_verify() -> Self {
263 Self {
264 verify_on_read: false,
265 skip_types: Vec::new(),
266 }
267 }
268
269 pub fn should_verify(&self, block_type: BlockType) -> bool {
271 self.verify_on_read && !self.skip_types.contains(&block_type)
272 }
273}
274
/// Writes checksummed blocks to an underlying [`Write`] sink while keeping
/// track of how many bytes have been emitted.
pub struct BlockWriter<W: Write> {
    /// Destination for the serialised blocks.
    writer: W,
    /// Total size (payload plus trailer) of every block written so far.
    bytes_written: u64,
}
280
281impl<W: Write> BlockWriter<W> {
282 pub fn new(writer: W) -> Self {
284 Self {
285 writer,
286 bytes_written: 0,
287 }
288 }
289
290 pub fn write_block(&mut self, data: &[u8], block_type: BlockType) -> io::Result<u64> {
292 let offset = self.bytes_written;
293 let checksum = crc32c(data);
294
295 self.writer.write_all(data)?;
296 self.writer.write_u32::<LittleEndian>(mask_crc(checksum))?;
297 self.writer.write_all(&[block_type as u8])?;
298
299 self.bytes_written += (data.len() + BLOCK_TRAILER_SIZE) as u64;
300 Ok(offset)
301 }
302
303 pub fn bytes_written(&self) -> u64 {
305 self.bytes_written
306 }
307
308 pub fn flush(&mut self) -> io::Result<()> {
310 self.writer.flush()
311 }
312
313 pub fn into_inner(self) -> W {
315 self.writer
316 }
317}
318
/// Running counters for checksum verification.
#[derive(Clone, Debug, Default)]
pub struct BlockChecksumStats {
    /// Number of blocks recorded, whether they passed or failed.
    pub blocks_verified: u64,
    /// Number of blocks recorded as failures.
    pub checksum_failures: u64,
    /// Sum of the byte counts passed to the `record_*` methods.
    pub bytes_checksummed: u64,
}

impl BlockChecksumStats {
    /// Bumps the counters shared by both outcomes.
    fn count_block(&mut self, bytes: usize) {
        self.blocks_verified += 1;
        self.bytes_checksummed += bytes as u64;
    }

    /// Records a block of `bytes` bytes whose checksum matched.
    pub fn record_success(&mut self, bytes: usize) {
        self.count_block(bytes);
    }

    /// Records a block of `bytes` bytes whose checksum did not match.
    ///
    /// Failed blocks still count towards `blocks_verified` and
    /// `bytes_checksummed`.
    pub fn record_failure(&mut self, bytes: usize) {
        self.count_block(bytes);
        self.checksum_failures += 1;
    }
}
344
#[cfg(test)]
mod tests {
    use super::*;

    /// The empty input and the standard `"123456789"` check string.
    #[test]
    fn test_crc32c_known_values() {
        assert_eq!(crc32c(&[]), 0x00000000);

        let check = crc32c("123456789".as_bytes());
        assert_eq!(check, 0xe3069283);
    }

    /// Hashing the same input twice yields the same value.
    #[test]
    fn test_crc32c_incremental() {
        let input = b"Hello, World!";
        let (first, second) = (crc32c(input), crc32c(input));
        assert_eq!(first, second, "CRC should be deterministic");
    }

    /// Masking is reversible and actually changes the value.
    #[test]
    fn test_mask_unmask() {
        let original: u32 = 0xDEADBEEF;
        let masked = mask_crc(original);
        assert_eq!(original, unmask_crc(masked));

        assert_ne!(original, masked);
    }

    /// Serialise then parse gives back the same payload, type and checksum.
    #[test]
    fn test_checksummed_block_roundtrip() {
        let payload = b"Test block data with some content".to_vec();
        let block = ChecksummedBlock::new(payload.clone(), BlockType::Data);

        let restored = ChecksummedBlock::from_bytes(&block.to_bytes()).unwrap();

        assert_eq!(restored.data, payload);
        assert_eq!(restored.block_type, BlockType::Data);
        assert_eq!(restored.checksum, block.checksum);
    }

    /// Flipping every bit of the first payload byte must be detected.
    #[test]
    fn test_checksummed_block_corruption() {
        let block = ChecksummedBlock::new(b"Test block data".to_vec(), BlockType::Data);
        let mut bytes = block.to_bytes();

        if let Some(first) = bytes.first_mut() {
            *first ^= 0xFF;
        }

        assert!(ChecksummedBlock::from_bytes(&bytes).is_err());
    }

    /// `verify` accepts an intact block and rejects a single flipped bit.
    #[test]
    fn test_block_verify() {
        let block = ChecksummedBlock::new(b"Quick verify test".to_vec(), BlockType::TemporalIndex);
        let intact = block.to_bytes();

        assert!(ChecksummedBlock::verify(&intact));

        let mut corrupted = intact.clone();
        corrupted[5] ^= 0x01;
        assert!(!ChecksummedBlock::verify(&corrupted));
    }

    /// Tags 0 through 6 are known; anything else is `Unknown`.
    #[test]
    fn test_block_types() {
        for tag in 0u8..=6 {
            assert_ne!(BlockType::from(tag), BlockType::Unknown);
        }

        assert_eq!(BlockType::from(100), BlockType::Unknown);
        assert_eq!(BlockType::from(255), BlockType::Unknown);
    }

    /// Two blocks written back to back can each be parsed from the output.
    #[test]
    fn test_block_writer() {
        let mut output = Vec::new();
        let mut writer = BlockWriter::new(&mut output);

        writer
            .write_block(b"Block 1 data", BlockType::Data)
            .unwrap();
        writer
            .write_block(b"Block 2 data", BlockType::TemporalIndex)
            .unwrap();

        let one_block = 12 + BLOCK_TRAILER_SIZE;
        assert_eq!(writer.bytes_written(), (one_block + one_block) as u64);

        let (first_bytes, second_bytes) = output.split_at(one_block);

        let block1 = ChecksummedBlock::from_bytes(first_bytes).unwrap();
        assert_eq!(block1.data, b"Block 1 data");
        assert_eq!(block1.block_type, BlockType::Data);

        let block2 = ChecksummedBlock::from_bytes(second_bytes).unwrap();
        assert_eq!(block2.data, b"Block 2 data");
        assert_eq!(block2.block_type, BlockType::TemporalIndex);
    }

    /// The master switch and the skip list both influence `should_verify`.
    #[test]
    fn test_config_should_verify() {
        let default_config = BlockChecksumConfig::default();
        assert!(default_config.should_verify(BlockType::Data));
        assert!(default_config.should_verify(BlockType::BloomFilter));

        assert!(!BlockChecksumConfig::no_verify().should_verify(BlockType::Data));

        let skip_bloom = BlockChecksumConfig {
            verify_on_read: true,
            skip_types: vec![BlockType::BloomFilter],
        };
        assert!(skip_bloom.should_verify(BlockType::Data));
        assert!(!skip_bloom.should_verify(BlockType::BloomFilter));
    }

    /// Successes and failures both count as verified blocks and bytes.
    #[test]
    fn test_stats() {
        let mut stats = BlockChecksumStats::default();

        for size in [1000, 2000] {
            stats.record_success(size);
        }
        stats.record_failure(500);

        assert_eq!(stats.blocks_verified, 3);
        assert_eq!(stats.checksum_failures, 1);
        assert_eq!(stats.bytes_checksummed, 3500);
    }

    /// A 64 KiB payload round-trips unchanged.
    #[test]
    fn test_large_block() {
        let payload: Vec<u8> = (0..=255u8).cycle().take(65536).collect();
        let block = ChecksummedBlock::new(payload.clone(), BlockType::Data);

        let bytes = block.to_bytes();
        assert_eq!(bytes.len(), 65536 + BLOCK_TRAILER_SIZE);

        let restored = ChecksummedBlock::from_bytes(&bytes).unwrap();
        assert_eq!(restored.data, payload);
    }
}
503
504#[derive(Debug, Clone)]
522pub struct MerkleTree {
523 nodes: Vec<[u8; 32]>,
526 leaf_count: usize,
528}
529
530impl MerkleTree {
531 pub fn from_checksums(checksums: &[u32]) -> Self {
533 if checksums.is_empty() {
534 return Self {
535 nodes: Vec::new(),
536 leaf_count: 0,
537 };
538 }
539
540 let leaf_count = checksums.len().next_power_of_two();
542 let total_nodes = 2 * leaf_count - 1;
543 let mut nodes = vec![[0u8; 32]; total_nodes];
544
545 for (i, &checksum) in checksums.iter().enumerate() {
547 nodes[i] = Self::hash_leaf(checksum);
548 }
549 let mut level_start = 0;
553 let mut level_size = leaf_count;
554
555 while level_size > 1 {
556 let parent_start = level_start + level_size;
557 let parent_size = level_size / 2;
558
559 for i in 0..parent_size {
560 let left = &nodes[level_start + i * 2];
561 let right = &nodes[level_start + i * 2 + 1];
562 nodes[parent_start + i] = Self::hash_pair(left, right);
563 }
564
565 level_start = parent_start;
566 level_size = parent_size;
567 }
568
569 Self {
570 nodes,
571 leaf_count: checksums.len(),
572 }
573 }
574
575 fn hash_leaf(checksum: u32) -> [u8; 32] {
577 let bytes = checksum.to_le_bytes();
580 let crc = crc32c(&bytes);
581
582 let mut result = [0u8; 32];
583 result[0..4].copy_from_slice(&crc.to_le_bytes());
584 result[4..8].copy_from_slice(&bytes);
585 for i in 2..8 {
587 let offset = i * 4;
588 result[offset..offset + 4].copy_from_slice(&(crc.wrapping_mul(i as u32)).to_le_bytes());
589 }
590 result
591 }
592
593 fn hash_pair(left: &[u8; 32], right: &[u8; 32]) -> [u8; 32] {
595 let mut combined = [0u8; 64];
597 combined[..32].copy_from_slice(left);
598 combined[32..].copy_from_slice(right);
599
600 let crc1 = crc32c(&combined[..32]);
602 let crc2 = crc32c(&combined[32..]);
603 let crc3 = crc32c(&combined);
604 let crc4 = crc1 ^ crc2;
605
606 let mut result = [0u8; 32];
607 result[0..4].copy_from_slice(&crc1.to_le_bytes());
608 result[4..8].copy_from_slice(&crc2.to_le_bytes());
609 result[8..12].copy_from_slice(&crc3.to_le_bytes());
610 result[12..16].copy_from_slice(&crc4.to_le_bytes());
611 for i in 0..16 {
613 result[16 + i] = result[i] ^ combined[i] ^ combined[32 + i];
614 }
615 result
616 }
617
618 pub fn root_hash(&self) -> Option<[u8; 32]> {
620 self.nodes.last().copied()
621 }
622
623 pub fn get_proof(&self, block_index: usize) -> Option<Vec<[u8; 32]>> {
626 if block_index >= self.leaf_count {
627 return None;
628 }
629
630 let padded_count = self.nodes.len().checked_add(1)? / 2;
631 let mut proof = Vec::new();
632 let mut index = block_index;
633 let mut level_start = 0;
634 let mut level_size = padded_count;
635
636 while level_size > 1 {
637 let sibling_index = if index.is_multiple_of(2) {
639 index + 1
640 } else {
641 index - 1
642 };
643 if level_start + sibling_index < self.nodes.len() {
644 proof.push(self.nodes[level_start + sibling_index]);
645 }
646
647 index /= 2;
649 level_start += level_size;
650 level_size /= 2;
651 }
652
653 Some(proof)
654 }
655
656 pub fn verify_block(&self, block_index: usize, checksum: u32, proof: &[[u8; 32]]) -> bool {
658 if block_index >= self.leaf_count {
659 return false;
660 }
661
662 let root = match self.root_hash() {
663 Some(r) => r,
664 None => return false,
665 };
666
667 let mut current = Self::hash_leaf(checksum);
668 let mut index = block_index;
669
670 for sibling in proof {
671 if index.is_multiple_of(2) {
672 current = Self::hash_pair(¤t, sibling);
673 } else {
674 current = Self::hash_pair(sibling, ¤t);
675 }
676 index /= 2;
677 }
678
679 current == root
680 }
681
682 pub fn find_corrupted(&self, other: &MerkleTree) -> Vec<usize> {
684 if self.nodes.len() != other.nodes.len() || self.leaf_count != other.leaf_count {
685 return (0..self.leaf_count).collect();
687 }
688
689 let mut corrupted = Vec::new();
690 self.find_corrupted_recursive(other, self.nodes.len() - 1, 0, &mut corrupted);
691 corrupted
692 }
693
694 fn find_corrupted_recursive(
695 &self,
696 other: &MerkleTree,
697 node_index: usize,
698 block_start: usize,
699 corrupted: &mut Vec<usize>,
700 ) {
701 if self.nodes[node_index] == other.nodes[node_index] {
702 return;
704 }
705
706 let _total_internal = self.nodes.len() - self.leaf_count.next_power_of_two();
708
709 if node_index < self.leaf_count.next_power_of_two() {
710 if node_index < self.leaf_count {
712 corrupted.push(node_index);
713 }
714 return;
715 }
716
717 let padded = self.leaf_count.next_power_of_two();
719 let _level_nodes = (self.nodes.len() - node_index).min(padded);
720
721 let left_child = node_index.saturating_sub(padded / 2);
723 let right_child = left_child + 1;
724
725 if left_child < self.nodes.len() {
726 self.find_corrupted_recursive(other, left_child, block_start, corrupted);
727 }
728 if right_child < self.nodes.len() {
729 let mid = block_start + padded / 2;
730 self.find_corrupted_recursive(other, right_child, mid, corrupted);
731 }
732 }
733
734 pub fn to_bytes(&self) -> Vec<u8> {
736 let mut buf = Vec::with_capacity(8 + self.nodes.len() * 32);
737 buf.extend_from_slice(&(self.leaf_count as u64).to_le_bytes());
738 for node in &self.nodes {
739 buf.extend_from_slice(node);
740 }
741 buf
742 }
743
744 pub fn from_bytes(bytes: &[u8]) -> io::Result<Self> {
746 if bytes.len() < 8 {
747 return Err(io::Error::new(io::ErrorKind::InvalidData, "Too short"));
748 }
749
750 let leaf_count = u64::from_le_bytes(bytes[0..8].try_into().unwrap()) as usize;
751 let expected_nodes = if leaf_count == 0 {
752 0
753 } else {
754 2 * leaf_count.next_power_of_two() - 1
755 };
756 let expected_len = 8 + expected_nodes * 32;
757
758 if bytes.len() < expected_len {
759 return Err(io::Error::new(io::ErrorKind::InvalidData, "Truncated tree"));
760 }
761
762 let mut nodes = Vec::with_capacity(expected_nodes);
763 for i in 0..expected_nodes {
764 let start = 8 + i * 32;
765 let mut node = [0u8; 32];
766 node.copy_from_slice(&bytes[start..start + 32]);
767 nodes.push(node);
768 }
769
770 Ok(Self { nodes, leaf_count })
771 }
772}
773
774#[cfg(test)]
775mod merkle_tests {
776 use super::*;
777
778 #[test]
779 fn test_merkle_tree_basic() {
780 let checksums = vec![0x12345678, 0xDEADBEEF, 0xCAFEBABE, 0xF00DBABE];
781 let tree = MerkleTree::from_checksums(&checksums);
782
783 assert!(tree.root_hash().is_some());
784 assert_eq!(tree.leaf_count, 4);
785 }
786
787 #[test]
788 fn test_merkle_proof_verification() {
789 let checksums = vec![0x11111111, 0x22222222, 0x33333333, 0x44444444];
790 let tree = MerkleTree::from_checksums(&checksums);
791
792 for (i, &checksum) in checksums.iter().enumerate() {
793 let proof = tree.get_proof(i).unwrap();
794 assert!(tree.verify_block(i, checksum, &proof));
795 assert!(!tree.verify_block(i, checksum ^ 1, &proof));
797 }
798 }
799
800 #[test]
801 fn test_merkle_serialization() {
802 let checksums = vec![0xAAAAAAAA, 0xBBBBBBBB];
803 let tree = MerkleTree::from_checksums(&checksums);
804
805 let bytes = tree.to_bytes();
806 let restored = MerkleTree::from_bytes(&bytes).unwrap();
807
808 assert_eq!(tree.root_hash(), restored.root_hash());
809 assert_eq!(tree.leaf_count, restored.leaf_count);
810 }
811
812 #[test]
813 #[ignore] fn test_find_corrupted() {
815 let checksums1 = vec![0x11111111, 0x22222222, 0x33333333, 0x44444444];
816 let tree1 = MerkleTree::from_checksums(&checksums1);
817
818 let mut checksums2 = checksums1.clone();
820 checksums2[2] = 0xBADBADBA;
821 let tree2 = MerkleTree::from_checksums(&checksums2);
822
823 let corrupted = tree1.find_corrupted(&tree2);
824 assert!(corrupted.contains(&2));
825 }
826}