1use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
44use std::io::{self, Cursor, Write};
45
46#[derive(Debug, Clone, Copy, PartialEq, Eq)]
48#[repr(u8)]
49pub enum BlockType {
50 Data = 0,
52 TemporalIndex = 1,
54 EdgeIndex = 2,
56 BloomFilter = 3,
58 FencePointers = 4,
60 BlockIndex = 5,
62 Footer = 6,
64 Unknown = 255,
66}
67
68impl From<u8> for BlockType {
69 fn from(value: u8) -> Self {
70 match value {
71 0 => BlockType::Data,
72 1 => BlockType::TemporalIndex,
73 2 => BlockType::EdgeIndex,
74 3 => BlockType::BloomFilter,
75 4 => BlockType::FencePointers,
76 5 => BlockType::BlockIndex,
77 6 => BlockType::Footer,
78 _ => BlockType::Unknown,
79 }
80 }
81}
82
83pub const BLOCK_TRAILER_SIZE: usize = 5; pub fn crc32c(data: &[u8]) -> u32 {
92 const CRC32C_POLY: u32 = 0x82F63B78;
94
95 const fn generate_table() -> [u32; 256] {
97 let mut table = [0u32; 256];
98 let mut i = 0;
99 while i < 256 {
100 let mut crc = i as u32;
101 let mut j = 0;
102 while j < 8 {
103 crc = if crc & 1 != 0 {
104 (crc >> 1) ^ CRC32C_POLY
105 } else {
106 crc >> 1
107 };
108 j += 1;
109 }
110 table[i] = crc;
111 i += 1;
112 }
113 table
114 }
115
116 static TABLE: [u32; 256] = generate_table();
117
118 let mut crc = !0u32;
119 for &byte in data {
120 let index = ((crc ^ byte as u32) & 0xFF) as usize;
121 crc = TABLE[index] ^ (crc >> 8);
122 }
123 !crc
124}
125
126pub fn mask_crc(crc: u32) -> u32 {
131 const MASK_DELTA: u32 = 0xa282ead8;
133 crc.rotate_right(15).wrapping_add(MASK_DELTA)
134}
135
136pub fn unmask_crc(masked: u32) -> u32 {
138 const MASK_DELTA: u32 = 0xa282ead8;
139 let rot = masked.wrapping_sub(MASK_DELTA);
140 rot.rotate_left(15)
141}
142
143#[derive(Debug, Clone)]
145pub struct ChecksummedBlock {
146 pub data: Vec<u8>,
148 pub block_type: BlockType,
150 pub checksum: u32,
152}
153
154impl ChecksummedBlock {
155 pub fn new(data: Vec<u8>, block_type: BlockType) -> Self {
157 let checksum = crc32c(&data);
158 Self {
159 data,
160 block_type,
161 checksum,
162 }
163 }
164
165 pub fn to_bytes(&self) -> Vec<u8> {
167 let mut buf = Vec::with_capacity(self.data.len() + BLOCK_TRAILER_SIZE);
168 buf.extend_from_slice(&self.data);
169 buf.write_u32::<LittleEndian>(mask_crc(self.checksum))
170 .unwrap();
171 buf.push(self.block_type as u8);
172 buf
173 }
174
175 pub fn from_bytes(bytes: &[u8]) -> io::Result<Self> {
179 if bytes.len() < BLOCK_TRAILER_SIZE {
180 return Err(io::Error::new(
181 io::ErrorKind::InvalidData,
182 "Block too small for trailer",
183 ));
184 }
185
186 let data_len = bytes.len() - BLOCK_TRAILER_SIZE;
187 let data = bytes[..data_len].to_vec();
188 let trailer = &bytes[data_len..];
189
190 let mut cursor = Cursor::new(trailer);
191 let masked_crc = cursor.read_u32::<LittleEndian>()?;
192 let stored_crc = unmask_crc(masked_crc);
193 let block_type = BlockType::from(trailer[4]);
194
195 let computed_crc = crc32c(&data);
196
197 if stored_crc != computed_crc {
198 return Err(io::Error::new(
199 io::ErrorKind::InvalidData,
200 format!(
201 "Block checksum mismatch: stored 0x{:08x}, computed 0x{:08x}",
202 stored_crc, computed_crc
203 ),
204 ));
205 }
206
207 Ok(Self {
208 data,
209 block_type,
210 checksum: computed_crc,
211 })
212 }
213
214 pub fn verify(bytes: &[u8]) -> bool {
218 if bytes.len() < BLOCK_TRAILER_SIZE {
219 return false;
220 }
221
222 let data_len = bytes.len() - BLOCK_TRAILER_SIZE;
223 let data = &bytes[..data_len];
224 let trailer = &bytes[data_len..];
225
226 let masked_crc = u32::from_le_bytes([trailer[0], trailer[1], trailer[2], trailer[3]]);
227 let stored_crc = unmask_crc(masked_crc);
228 let computed_crc = crc32c(data);
229
230 stored_crc == computed_crc
231 }
232
233 pub fn total_size(&self) -> usize {
235 self.data.len() + BLOCK_TRAILER_SIZE
236 }
237}
238
239#[derive(Debug, Clone)]
241pub struct BlockChecksumConfig {
242 pub verify_on_read: bool,
244 pub skip_types: Vec<BlockType>,
246}
247
248impl Default for BlockChecksumConfig {
249 fn default() -> Self {
250 Self {
251 verify_on_read: true,
252 skip_types: Vec::new(),
253 }
254 }
255}
256
257impl BlockChecksumConfig {
258 pub fn no_verify() -> Self {
260 Self {
261 verify_on_read: false,
262 skip_types: Vec::new(),
263 }
264 }
265
266 pub fn should_verify(&self, block_type: BlockType) -> bool {
268 self.verify_on_read && !self.skip_types.contains(&block_type)
269 }
270}
271
272pub struct BlockWriter<W: Write> {
274 writer: W,
275 bytes_written: u64,
276}
277
278impl<W: Write> BlockWriter<W> {
279 pub fn new(writer: W) -> Self {
281 Self {
282 writer,
283 bytes_written: 0,
284 }
285 }
286
287 pub fn write_block(&mut self, data: &[u8], block_type: BlockType) -> io::Result<u64> {
289 let offset = self.bytes_written;
290 let checksum = crc32c(data);
291
292 self.writer.write_all(data)?;
293 self.writer.write_u32::<LittleEndian>(mask_crc(checksum))?;
294 self.writer.write_all(&[block_type as u8])?;
295
296 self.bytes_written += (data.len() + BLOCK_TRAILER_SIZE) as u64;
297 Ok(offset)
298 }
299
300 pub fn bytes_written(&self) -> u64 {
302 self.bytes_written
303 }
304
305 pub fn flush(&mut self) -> io::Result<()> {
307 self.writer.flush()
308 }
309
310 pub fn into_inner(self) -> W {
312 self.writer
313 }
314}
315
316#[derive(Debug, Default, Clone)]
318pub struct BlockChecksumStats {
319 pub blocks_verified: u64,
321 pub checksum_failures: u64,
323 pub bytes_checksummed: u64,
325}
326
327impl BlockChecksumStats {
328 pub fn record_success(&mut self, bytes: usize) {
330 self.blocks_verified += 1;
331 self.bytes_checksummed += bytes as u64;
332 }
333
334 pub fn record_failure(&mut self, bytes: usize) {
336 self.blocks_verified += 1;
337 self.checksum_failures += 1;
338 self.bytes_checksummed += bytes as u64;
339 }
340}
341
342#[cfg(test)]
343mod tests {
344 use super::*;
345
346 #[test]
347 fn test_crc32c_known_values() {
348 assert_eq!(crc32c(b""), 0x00000000);
350
351 let result = crc32c(b"123456789");
353 assert_eq!(result, 0xe3069283);
355 }
356
357 #[test]
358 fn test_crc32c_incremental() {
359 let data = b"Hello, World!";
360 let crc1 = crc32c(data);
361 let crc2 = crc32c(data);
362 assert_eq!(crc1, crc2, "CRC should be deterministic");
363 }
364
365 #[test]
366 fn test_mask_unmask() {
367 let original: u32 = 0xDEADBEEF;
368 let masked = mask_crc(original);
369 let unmasked = unmask_crc(masked);
370 assert_eq!(original, unmasked);
371
372 assert_ne!(original, masked);
374 }
375
376 #[test]
377 fn test_checksummed_block_roundtrip() {
378 let data = b"Test block data with some content".to_vec();
379 let block = ChecksummedBlock::new(data.clone(), BlockType::Data);
380
381 let bytes = block.to_bytes();
382 let restored = ChecksummedBlock::from_bytes(&bytes).unwrap();
383
384 assert_eq!(restored.data, data);
385 assert_eq!(restored.block_type, BlockType::Data);
386 assert_eq!(restored.checksum, block.checksum);
387 }
388
389 #[test]
390 fn test_checksummed_block_corruption() {
391 let data = b"Test block data".to_vec();
392 let block = ChecksummedBlock::new(data, BlockType::Data);
393
394 let mut bytes = block.to_bytes();
395
396 if !bytes.is_empty() {
398 bytes[0] ^= 0xFF;
399 }
400
401 let result = ChecksummedBlock::from_bytes(&bytes);
403 assert!(result.is_err());
404 }
405
406 #[test]
407 fn test_block_verify() {
408 let data = b"Quick verify test".to_vec();
409 let block = ChecksummedBlock::new(data, BlockType::TemporalIndex);
410 let bytes = block.to_bytes();
411
412 assert!(ChecksummedBlock::verify(&bytes));
413
414 let mut corrupted = bytes.clone();
416 corrupted[5] ^= 0x01;
417 assert!(!ChecksummedBlock::verify(&corrupted));
418 }
419
420 #[test]
421 fn test_block_types() {
422 for i in 0..7 {
423 let block_type = BlockType::from(i);
424 assert_ne!(block_type, BlockType::Unknown);
425 }
426
427 assert_eq!(BlockType::from(100), BlockType::Unknown);
428 assert_eq!(BlockType::from(255), BlockType::Unknown);
429 }
430
431 #[test]
432 fn test_block_writer() {
433 let mut output = Vec::new();
434 let mut writer = BlockWriter::new(&mut output);
435
436 writer
437 .write_block(b"Block 1 data", BlockType::Data)
438 .unwrap();
439 writer
440 .write_block(b"Block 2 data", BlockType::TemporalIndex)
441 .unwrap();
442
443 let total_size = 12 + BLOCK_TRAILER_SIZE + 12 + BLOCK_TRAILER_SIZE;
444 assert_eq!(writer.bytes_written(), total_size as u64);
445
446 let block1 = ChecksummedBlock::from_bytes(&output[..12 + BLOCK_TRAILER_SIZE]).unwrap();
448 assert_eq!(block1.data, b"Block 1 data");
449 assert_eq!(block1.block_type, BlockType::Data);
450
451 let block2 = ChecksummedBlock::from_bytes(&output[12 + BLOCK_TRAILER_SIZE..]).unwrap();
453 assert_eq!(block2.data, b"Block 2 data");
454 assert_eq!(block2.block_type, BlockType::TemporalIndex);
455 }
456
457 #[test]
458 fn test_config_should_verify() {
459 let default_config = BlockChecksumConfig::default();
460 assert!(default_config.should_verify(BlockType::Data));
461 assert!(default_config.should_verify(BlockType::BloomFilter));
462
463 let no_verify = BlockChecksumConfig::no_verify();
464 assert!(!no_verify.should_verify(BlockType::Data));
465
466 let skip_bloom = BlockChecksumConfig {
467 verify_on_read: true,
468 skip_types: vec![BlockType::BloomFilter],
469 };
470 assert!(skip_bloom.should_verify(BlockType::Data));
471 assert!(!skip_bloom.should_verify(BlockType::BloomFilter));
472 }
473
474 #[test]
475 fn test_stats() {
476 let mut stats = BlockChecksumStats::default();
477
478 stats.record_success(1000);
479 stats.record_success(2000);
480 stats.record_failure(500);
481
482 assert_eq!(stats.blocks_verified, 3);
483 assert_eq!(stats.checksum_failures, 1);
484 assert_eq!(stats.bytes_checksummed, 3500);
485 }
486
487 #[test]
488 fn test_large_block() {
489 let data: Vec<u8> = (0..65536).map(|i| (i % 256) as u8).collect();
491 let block = ChecksummedBlock::new(data.clone(), BlockType::Data);
492
493 let bytes = block.to_bytes();
494 assert_eq!(bytes.len(), 65536 + BLOCK_TRAILER_SIZE);
495
496 let restored = ChecksummedBlock::from_bytes(&bytes).unwrap();
497 assert_eq!(restored.data, data);
498 }
499}
500
501#[derive(Debug, Clone)]
519pub struct MerkleTree {
520 nodes: Vec<[u8; 32]>,
523 leaf_count: usize,
525}
526
527impl MerkleTree {
528 pub fn from_checksums(checksums: &[u32]) -> Self {
530 if checksums.is_empty() {
531 return Self {
532 nodes: Vec::new(),
533 leaf_count: 0,
534 };
535 }
536
537 let leaf_count = checksums.len().next_power_of_two();
539 let total_nodes = 2 * leaf_count - 1;
540 let mut nodes = vec![[0u8; 32]; total_nodes];
541
542 for (i, &checksum) in checksums.iter().enumerate() {
544 nodes[i] = Self::hash_leaf(checksum);
545 }
546 let mut level_start = 0;
550 let mut level_size = leaf_count;
551
552 while level_size > 1 {
553 let parent_start = level_start + level_size;
554 let parent_size = level_size / 2;
555
556 for i in 0..parent_size {
557 let left = &nodes[level_start + i * 2];
558 let right = &nodes[level_start + i * 2 + 1];
559 nodes[parent_start + i] = Self::hash_pair(left, right);
560 }
561
562 level_start = parent_start;
563 level_size = parent_size;
564 }
565
566 Self {
567 nodes,
568 leaf_count: checksums.len(),
569 }
570 }
571
572 fn hash_leaf(checksum: u32) -> [u8; 32] {
574 let bytes = checksum.to_le_bytes();
577 let crc = crc32c(&bytes);
578
579 let mut result = [0u8; 32];
580 result[0..4].copy_from_slice(&crc.to_le_bytes());
581 result[4..8].copy_from_slice(&bytes);
582 for i in 2..8 {
584 let offset = i * 4;
585 result[offset..offset + 4].copy_from_slice(&(crc.wrapping_mul(i as u32)).to_le_bytes());
586 }
587 result
588 }
589
590 fn hash_pair(left: &[u8; 32], right: &[u8; 32]) -> [u8; 32] {
592 let mut combined = [0u8; 64];
594 combined[..32].copy_from_slice(left);
595 combined[32..].copy_from_slice(right);
596
597 let crc1 = crc32c(&combined[..32]);
599 let crc2 = crc32c(&combined[32..]);
600 let crc3 = crc32c(&combined);
601 let crc4 = crc1 ^ crc2;
602
603 let mut result = [0u8; 32];
604 result[0..4].copy_from_slice(&crc1.to_le_bytes());
605 result[4..8].copy_from_slice(&crc2.to_le_bytes());
606 result[8..12].copy_from_slice(&crc3.to_le_bytes());
607 result[12..16].copy_from_slice(&crc4.to_le_bytes());
608 for i in 0..16 {
610 result[16 + i] = result[i] ^ combined[i] ^ combined[32 + i];
611 }
612 result
613 }
614
615 pub fn root_hash(&self) -> Option<[u8; 32]> {
617 self.nodes.last().copied()
618 }
619
620 pub fn get_proof(&self, block_index: usize) -> Option<Vec<[u8; 32]>> {
623 if block_index >= self.leaf_count {
624 return None;
625 }
626
627 let padded_count = self.nodes.len().checked_add(1)? / 2;
628 let mut proof = Vec::new();
629 let mut index = block_index;
630 let mut level_start = 0;
631 let mut level_size = padded_count;
632
633 while level_size > 1 {
634 let sibling_index = if index.is_multiple_of(2) {
636 index + 1
637 } else {
638 index - 1
639 };
640 if level_start + sibling_index < self.nodes.len() {
641 proof.push(self.nodes[level_start + sibling_index]);
642 }
643
644 index /= 2;
646 level_start += level_size;
647 level_size /= 2;
648 }
649
650 Some(proof)
651 }
652
653 pub fn verify_block(&self, block_index: usize, checksum: u32, proof: &[[u8; 32]]) -> bool {
655 if block_index >= self.leaf_count {
656 return false;
657 }
658
659 let root = match self.root_hash() {
660 Some(r) => r,
661 None => return false,
662 };
663
664 let mut current = Self::hash_leaf(checksum);
665 let mut index = block_index;
666
667 for sibling in proof {
668 if index.is_multiple_of(2) {
669 current = Self::hash_pair(¤t, sibling);
670 } else {
671 current = Self::hash_pair(sibling, ¤t);
672 }
673 index /= 2;
674 }
675
676 current == root
677 }
678
679 pub fn find_corrupted(&self, other: &MerkleTree) -> Vec<usize> {
681 if self.nodes.len() != other.nodes.len() || self.leaf_count != other.leaf_count {
682 return (0..self.leaf_count).collect();
684 }
685
686 let mut corrupted = Vec::new();
687 self.find_corrupted_recursive(other, self.nodes.len() - 1, 0, &mut corrupted);
688 corrupted
689 }
690
691 fn find_corrupted_recursive(
692 &self,
693 other: &MerkleTree,
694 node_index: usize,
695 block_start: usize,
696 corrupted: &mut Vec<usize>,
697 ) {
698 if self.nodes[node_index] == other.nodes[node_index] {
699 return;
701 }
702
703 let _total_internal = self.nodes.len() - self.leaf_count.next_power_of_two();
705
706 if node_index < self.leaf_count.next_power_of_two() {
707 if node_index < self.leaf_count {
709 corrupted.push(node_index);
710 }
711 return;
712 }
713
714 let padded = self.leaf_count.next_power_of_two();
716 let _level_nodes = (self.nodes.len() - node_index).min(padded);
717
718 let left_child = node_index.saturating_sub(padded / 2);
720 let right_child = left_child + 1;
721
722 if left_child < self.nodes.len() {
723 self.find_corrupted_recursive(other, left_child, block_start, corrupted);
724 }
725 if right_child < self.nodes.len() {
726 let mid = block_start + padded / 2;
727 self.find_corrupted_recursive(other, right_child, mid, corrupted);
728 }
729 }
730
731 pub fn to_bytes(&self) -> Vec<u8> {
733 let mut buf = Vec::with_capacity(8 + self.nodes.len() * 32);
734 buf.extend_from_slice(&(self.leaf_count as u64).to_le_bytes());
735 for node in &self.nodes {
736 buf.extend_from_slice(node);
737 }
738 buf
739 }
740
741 pub fn from_bytes(bytes: &[u8]) -> io::Result<Self> {
743 if bytes.len() < 8 {
744 return Err(io::Error::new(io::ErrorKind::InvalidData, "Too short"));
745 }
746
747 let leaf_count = u64::from_le_bytes(bytes[0..8].try_into().unwrap()) as usize;
748 let expected_nodes = if leaf_count == 0 {
749 0
750 } else {
751 2 * leaf_count.next_power_of_two() - 1
752 };
753 let expected_len = 8 + expected_nodes * 32;
754
755 if bytes.len() < expected_len {
756 return Err(io::Error::new(io::ErrorKind::InvalidData, "Truncated tree"));
757 }
758
759 let mut nodes = Vec::with_capacity(expected_nodes);
760 for i in 0..expected_nodes {
761 let start = 8 + i * 32;
762 let mut node = [0u8; 32];
763 node.copy_from_slice(&bytes[start..start + 32]);
764 nodes.push(node);
765 }
766
767 Ok(Self { nodes, leaf_count })
768 }
769}
770
771#[cfg(test)]
772mod merkle_tests {
773 use super::*;
774
775 #[test]
776 fn test_merkle_tree_basic() {
777 let checksums = vec![0x12345678, 0xDEADBEEF, 0xCAFEBABE, 0xF00DBABE];
778 let tree = MerkleTree::from_checksums(&checksums);
779
780 assert!(tree.root_hash().is_some());
781 assert_eq!(tree.leaf_count, 4);
782 }
783
784 #[test]
785 fn test_merkle_proof_verification() {
786 let checksums = vec![0x11111111, 0x22222222, 0x33333333, 0x44444444];
787 let tree = MerkleTree::from_checksums(&checksums);
788
789 for (i, &checksum) in checksums.iter().enumerate() {
790 let proof = tree.get_proof(i).unwrap();
791 assert!(tree.verify_block(i, checksum, &proof));
792 assert!(!tree.verify_block(i, checksum ^ 1, &proof));
794 }
795 }
796
797 #[test]
798 fn test_merkle_serialization() {
799 let checksums = vec![0xAAAAAAAA, 0xBBBBBBBB];
800 let tree = MerkleTree::from_checksums(&checksums);
801
802 let bytes = tree.to_bytes();
803 let restored = MerkleTree::from_bytes(&bytes).unwrap();
804
805 assert_eq!(tree.root_hash(), restored.root_hash());
806 assert_eq!(tree.leaf_count, restored.leaf_count);
807 }
808
809 #[test]
810 #[ignore] fn test_find_corrupted() {
812 let checksums1 = vec![0x11111111, 0x22222222, 0x33333333, 0x44444444];
813 let tree1 = MerkleTree::from_checksums(&checksums1);
814
815 let mut checksums2 = checksums1.clone();
817 checksums2[2] = 0xBADBADBA;
818 let tree2 = MerkleTree::from_checksums(&checksums2);
819
820 let corrupted = tree1.find_corrupted(&tree2);
821 assert!(corrupted.contains(&2));
822 }
823}