1use std::io::{self, Read, Write};
46
/// One mebibyte; the unit used to spell the size limits below.
const MIB: u64 = 1024 * 1024;

/// Four-byte signature that opens every serialized index.
const MAGIC: &[u8; 4] = b"SQRY";

/// Header format version written into newly serialized indexes.
const FORMAT_VERSION: u32 = 1;

/// Default zstd compression level.
pub const DEFAULT_COMPRESSION_LEVEL: i32 = 3;

/// Decompression ceiling (500 MiB) used when `SQRY_MAX_INDEX_SIZE` is absent or unparsable.
pub const DEFAULT_MAX_UNCOMPRESSED_SIZE: u64 = 500 * MIB;

/// Smallest ceiling `max_uncompressed_size` will ever return (1 MiB).
const MIN_MAX_UNCOMPRESSED_SIZE: u64 = MIB;

/// Largest ceiling `max_uncompressed_size` will ever return (2 GiB).
const MAX_MAX_UNCOMPRESSED_SIZE: u64 = 2048 * MIB;
67
68#[must_use]
80pub fn max_uncompressed_size() -> u64 {
81 let size = std::env::var("SQRY_MAX_INDEX_SIZE")
82 .ok()
83 .and_then(|s| s.parse().ok())
84 .unwrap_or(DEFAULT_MAX_UNCOMPRESSED_SIZE);
85 size.clamp(MIN_MAX_UNCOMPRESSED_SIZE, MAX_MAX_UNCOMPRESSED_SIZE)
86}
87
/// How the payload of a serialized index is encoded.
///
/// The `u8` discriminant is the byte stored in the header right after the
/// magic and version fields, so these values are part of the on-disk format.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum CompressionFormat {
    /// Payload holds the original bytes as-is.
    None = 0,
    /// Payload is a zstd stream.
    Zstd = 1,
}
97
98impl CompressionFormat {
99 fn from_u8(value: u8) -> Result<Self, CompressionError> {
101 match value {
102 0 => Ok(Self::None),
103 1 => Ok(Self::Zstd),
104 _ => Err(CompressionError::UnsupportedCompression(value)),
105 }
106 }
107}
108
/// Errors raised while compressing, (de)serializing, or decompressing an index.
#[derive(Debug, thiserror::Error)]
pub enum CompressionError {
    /// An I/O error, including failures reported by the zstd encoder/decoder streams.
    #[error("I/O error: {0}")]
    Io(#[from] io::Error),

    /// The header's compression byte does not name a known `CompressionFormat`;
    /// carries the raw byte.
    #[error("Unsupported compression format: {0}")]
    UnsupportedCompression(u8),

    /// The first four bytes of the input are not the `SQRY` signature.
    #[error("Invalid magic bytes, expected SQRY")]
    InvalidMagic,

    /// The index header declares a format version newer than this build supports.
    #[error("Index version {index_version} is too new for sqry {sqry_version}, please upgrade")]
    IndexVersionTooNew {
        /// Version number read from the index header.
        index_version: u32,
        /// Version of the running sqry build (the deserializer fills this from
        /// `CARGO_PKG_VERSION`).
        sqry_version: &'static str,
    },

    /// The header's version number is not usable (for example 0) and is not a
    /// newer version either.
    #[error("Invalid index version: {0}")]
    InvalidIndexVersion(u32),

    /// The input is shorter than the fixed 21-byte header; carries the actual length.
    #[error("Invalid header size: expected at least 21 bytes, got {0}")]
    InvalidHeaderSize(usize),

    /// The decompressed output length differs from the size recorded in the header.
    #[error("Decompressed size mismatch: expected {expected}, got {actual}")]
    SizeMismatch {
        /// Uncompressed size recorded in the header.
        expected: u64,
        /// Number of bytes actually produced.
        actual: u64,
    },

    /// A declared or produced size exceeds the limit from `max_uncompressed_size()`.
    #[error("Decompression bomb detected: uncompressed size {size} exceeds maximum {max}")]
    DecompressionBomb {
        /// Offending size in bytes.
        size: u64,
        /// Limit that was in effect when the check ran.
        max: u64,
    },
}
159
/// An index payload together with the metadata carried in its serialized header.
#[derive(Debug, Clone)]
pub struct CompressedIndex {
    /// Header format version: `FORMAT_VERSION` for freshly built values, otherwise
    /// the version read from the header.
    version: u32,
    /// How `data` is encoded.
    compression: CompressionFormat,
    /// zstd level passed to `compress` (0 for uncompressed values). It is written to
    /// the header but not consulted when decompressing.
    level: i32,
    /// Length in bytes of the original data. For deserialized values this comes
    /// straight from the header and is untrusted until `decompress` checks it.
    uncompressed_size: u64,
    /// Stored payload: a zstd stream or the raw bytes, depending on `compression`.
    data: Vec<u8>,
}
174
175impl CompressedIndex {
176 pub fn compress(data: &[u8], level: i32) -> Result<Self, CompressionError> {
198 let mut encoder = zstd::Encoder::new(Vec::new(), level)?;
199 encoder.write_all(data)?;
200 let compressed = encoder.finish()?;
201
202 Ok(Self {
203 version: FORMAT_VERSION,
204 compression: CompressionFormat::Zstd,
205 level,
206 uncompressed_size: data.len() as u64,
207 data: compressed,
208 })
209 }
210
211 #[must_use]
222 pub fn uncompressed(data: &[u8]) -> Self {
223 Self {
224 version: FORMAT_VERSION,
225 compression: CompressionFormat::None,
226 level: 0,
227 uncompressed_size: data.len() as u64,
228 data: data.to_vec(),
229 }
230 }
231
232 pub fn decompress(&self) -> Result<Vec<u8>, CompressionError> {
251 let max_size = max_uncompressed_size();
253 if self.uncompressed_size > max_size {
254 return Err(CompressionError::DecompressionBomb {
255 size: self.uncompressed_size,
256 max: max_size,
257 });
258 }
259
260 match self.compression {
261 CompressionFormat::None => {
262 if self.data.len() as u64 > max_size {
264 return Err(CompressionError::DecompressionBomb {
265 size: self.data.len() as u64,
266 max: max_size,
267 });
268 }
269 Ok(self.data.clone())
270 }
271 CompressionFormat::Zstd => {
272 let decoder = zstd::Decoder::new(&self.data[..])?;
274
275 let mut limited = decoder.take(max_size + 1);
281 let mut decompressed = Vec::new();
282 limited.read_to_end(&mut decompressed)?;
283
284 let actual_size = decompressed.len() as u64;
286 if actual_size != self.uncompressed_size {
287 return Err(CompressionError::SizeMismatch {
288 expected: self.uncompressed_size,
289 actual: actual_size,
290 });
291 }
292
293 if actual_size > max_size {
296 return Err(CompressionError::DecompressionBomb {
297 size: actual_size,
298 max: max_size,
299 });
300 }
301
302 Ok(decompressed)
303 }
304 }
305 }
306
307 #[must_use]
320 pub fn serialize(&self) -> Vec<u8> {
321 let mut buffer = Vec::with_capacity(21 + self.data.len());
322
323 buffer.extend_from_slice(MAGIC);
325
326 buffer.extend_from_slice(&self.version.to_le_bytes());
328
329 buffer.push(self.compression as u8);
331
332 buffer.extend_from_slice(&self.level.to_le_bytes());
334
335 buffer.extend_from_slice(&self.uncompressed_size.to_le_bytes());
337
338 buffer.extend_from_slice(&self.data);
340
341 buffer
342 }
343
344 pub fn deserialize(data: &[u8]) -> Result<Self, CompressionError> {
354 if data.len() < 21 {
356 return Err(CompressionError::InvalidHeaderSize(data.len()));
357 }
358
359 if &data[0..4] != MAGIC {
361 return Err(CompressionError::InvalidMagic);
362 }
363
364 let version = u32::from_le_bytes(
366 data[4..8]
367 .try_into()
368 .map_err(|_| CompressionError::InvalidHeaderSize(data.len()))?,
369 );
370
371 match version {
373 0 => return Err(CompressionError::InvalidIndexVersion(0)),
374 FORMAT_VERSION => {
375 }
377 v if v > FORMAT_VERSION => {
378 return Err(CompressionError::IndexVersionTooNew {
379 index_version: v,
380 sqry_version: env!("CARGO_PKG_VERSION"),
381 });
382 }
383 _ => {
384 return Err(CompressionError::InvalidIndexVersion(version));
386 }
387 }
388
389 let compression = CompressionFormat::from_u8(data[8])?;
391
392 let level = i32::from_le_bytes(
394 data[9..13]
395 .try_into()
396 .map_err(|_| CompressionError::InvalidHeaderSize(data.len()))?,
397 );
398
399 let uncompressed_size = u64::from_le_bytes(
401 data[13..21]
402 .try_into()
403 .map_err(|_| CompressionError::InvalidHeaderSize(data.len()))?,
404 );
405
406 let index_data = data[21..].to_vec();
408
409 Ok(Self {
410 version,
411 compression,
412 level,
413 uncompressed_size,
414 data: index_data,
415 })
416 }
417
418 #[must_use]
420 pub fn compression(&self) -> CompressionFormat {
421 self.compression
422 }
423
424 #[must_use]
426 pub fn uncompressed_size(&self) -> u64 {
427 self.uncompressed_size
428 }
429
430 #[must_use]
432 pub fn compressed_size(&self) -> usize {
433 self.data.len()
434 }
435
436 #[must_use]
440 pub fn compression_ratio(&self) -> f64 {
441 if self.data.is_empty() {
442 return 1.0;
443 }
444 Self::to_f64_lossy_u64(self.uncompressed_size) / Self::to_f64_lossy_usize(self.data.len())
445 }
446
447 #[inline]
448 #[allow(clippy::cast_precision_loss)] fn to_f64_lossy_u64(value: u64) -> f64 {
450 value as f64
451 }
452
453 #[inline]
454 #[allow(clippy::cast_precision_loss)] fn to_f64_lossy_usize(value: usize) -> f64 {
456 value as f64
457 }
458}
459
#[cfg(test)]
mod tests {
    use super::*;

    /// One mebibyte, for readable size arithmetic in the bomb tests.
    const ONE_MIB: u64 = 1024 * 1024;

    /// Builds a bare 21-byte header with the given magic, version, and compression
    /// byte; the level and size fields are left zeroed.
    fn raw_header(magic: &[u8; 4], version: u32, compression: u8) -> Vec<u8> {
        let mut header = vec![0u8; 21];
        header[0..4].copy_from_slice(magic);
        header[4..8].copy_from_slice(&version.to_le_bytes());
        header[8] = compression;
        header
    }

    /// Compresses `original`, rewrites the serialized uncompressed-size field
    /// (header bytes 13..21) to `declared`, and parses the result back.
    fn index_with_declared_size(original: &[u8], declared: u64) -> CompressedIndex {
        let mut serialized = CompressedIndex::compress(original, 3).unwrap().serialize();
        serialized[13..21].copy_from_slice(&declared.to_le_bytes());
        CompressedIndex::deserialize(&serialized).unwrap()
    }

    #[test]
    fn test_compress_decompress_roundtrip() {
        let original = b"test data for compression";
        let compressed = CompressedIndex::compress(original, DEFAULT_COMPRESSION_LEVEL).unwrap();
        let decompressed = compressed.decompress().unwrap();

        assert_eq!(original, &decompressed[..]);
    }

    #[test]
    fn test_serialize_deserialize_roundtrip() {
        let original = b"test data for serialization";
        let compressed = CompressedIndex::compress(original, 3).unwrap();
        let deserialized = CompressedIndex::deserialize(&compressed.serialize()).unwrap();

        assert_eq!(deserialized.compression(), CompressionFormat::Zstd);
        assert_eq!(deserialized.uncompressed_size(), original.len() as u64);
        let decompressed = deserialized.decompress().unwrap();
        assert_eq!(original, &decompressed[..]);
    }

    #[test]
    fn test_compression_reduces_size() {
        let original = vec![b'a'; 10000];
        let compressed = CompressedIndex::compress(&original, 3).unwrap();

        assert!(
            compressed.compressed_size() < original.len(),
            "Compressed size {} should be less than original size {}",
            compressed.compressed_size(),
            original.len()
        );
    }

    #[test]
    fn test_compression_ratio() {
        let original = vec![b'x'; 1000];
        let compressed = CompressedIndex::compress(&original, 3).unwrap();

        let ratio = compressed.compression_ratio();
        assert!(
            ratio > 1.0,
            "Compression ratio should be > 1.0 for compressible data"
        );
    }

    #[test]
    fn test_uncompressed_roundtrip() {
        let original = b"uncompressed test data";
        let uncompressed = CompressedIndex::uncompressed(original);
        let decompressed = uncompressed.decompress().unwrap();

        assert_eq!(original, &decompressed[..]);
        assert_eq!(uncompressed.compression(), CompressionFormat::None);
    }

    #[test]
    fn test_magic_bytes_in_header() {
        let compressed = CompressedIndex::compress(b"test", 3).unwrap();
        let serialized = compressed.serialize();

        assert_eq!(&serialized[0..4], b"SQRY");
    }

    #[test]
    fn test_invalid_magic_bytes() {
        let result = CompressedIndex::deserialize(&raw_header(b"XXXX", 0, 0));

        assert!(matches!(result, Err(CompressionError::InvalidMagic)));
    }

    #[test]
    fn test_header_too_small() {
        let result = CompressedIndex::deserialize(b"SQRY123");

        assert!(matches!(
            result,
            Err(CompressionError::InvalidHeaderSize(7))
        ));
    }

    #[test]
    fn test_unsupported_compression_format() {
        let result = CompressedIndex::deserialize(&raw_header(b"SQRY", FORMAT_VERSION, 99));

        assert!(matches!(
            result,
            Err(CompressionError::UnsupportedCompression(99))
        ));
    }

    #[test]
    fn test_future_version_error() {
        let result = CompressedIndex::deserialize(&raw_header(b"SQRY", 999, 0));

        assert!(matches!(
            result,
            Err(CompressionError::IndexVersionTooNew {
                index_version: 999,
                ..
            })
        ));
    }

    #[test]
    fn test_zero_version_error() {
        let result = CompressedIndex::deserialize(&raw_header(b"SQRY", 0, 0));

        assert!(matches!(
            result,
            Err(CompressionError::InvalidIndexVersion(0))
        ));
    }

    #[test]
    fn test_compression_metadata() {
        let original = vec![b'y'; 5000];
        let compressed = CompressedIndex::compress(&original, 5).unwrap();

        assert_eq!(compressed.uncompressed_size(), 5000);
        assert_eq!(compressed.compression(), CompressionFormat::Zstd);
        assert!(compressed.compressed_size() < 5000);
    }

    #[test]
    fn test_empty_data_compression() {
        let original = b"";
        let compressed = CompressedIndex::compress(original, 3).unwrap();
        let decompressed = compressed.decompress().unwrap();

        assert_eq!(original, &decompressed[..]);
        assert_eq!(compressed.uncompressed_size(), 0);
    }

    #[test]
    fn test_large_data_compression() {
        let original = vec![b'z'; 1_000_000];
        let compressed = CompressedIndex::compress(&original, 3).unwrap();
        let decompressed = compressed.decompress().unwrap();

        assert_eq!(original, decompressed);
        assert!(
            compressed.compressed_size() < 100_000,
            "Expected < 100KB compressed, got {}",
            compressed.compressed_size()
        );
    }

    #[test]
    fn test_decompression_bomb_protection_blocks_oversized() {
        let original = vec![b'a'; 1_000_000];
        // Derive the claim from the active limit so an env override can't defeat the test.
        let declared = max_uncompressed_size() + 100 * ONE_MIB;
        let corrupted = index_with_declared_size(&original, declared);

        assert!(
            matches!(
                corrupted.decompress(),
                Err(CompressionError::DecompressionBomb { .. })
            ),
            "Should reject oversized decompression claim"
        );
    }

    #[test]
    fn test_decompression_bomb_protection_allows_at_limit() {
        let original = vec![b'b'; 100_000];
        let exact_limit = max_uncompressed_size();
        let at_limit = index_with_declared_size(&original, exact_limit);

        // A declared size exactly at the limit passes the bomb check; the payload only
        // holds 100_000 bytes, so the failure must be a plain size mismatch.
        match at_limit.decompress() {
            Err(CompressionError::SizeMismatch { expected, actual }) => {
                assert_eq!(expected, exact_limit);
                assert_eq!(actual, 100_000);
            }
            other => panic!("Expected SizeMismatch for a claim exactly at the limit, got {other:?}"),
        }
    }

    #[test]
    fn test_decompression_bomb_protection_blocks_one_over_limit() {
        let original = vec![b'c'; 100_000];
        let corrupted = index_with_declared_size(&original, max_uncompressed_size() + 1);

        assert!(
            matches!(
                corrupted.decompress(),
                Err(CompressionError::DecompressionBomb { .. })
            ),
            "Should reject data exceeding limit by even 1 byte"
        );
    }

    #[test]
    fn test_decompression_enforces_streaming_limit() {
        let original = vec![b'd'; 200_000];
        let compressed = CompressedIndex::compress(&original, 3).unwrap();

        let decompressed = compressed
            .decompress()
            .expect("Decompression within limit should succeed");
        assert_eq!(decompressed, original);
    }

    #[test]
    fn test_decompression_rejects_stream_longer_than_declared() {
        let original = vec![b'f'; 200_000];
        let lying = index_with_declared_size(&original, 1_000);

        match lying.decompress() {
            Err(CompressionError::SizeMismatch { expected, actual }) => {
                assert_eq!(expected, 1_000);
                assert_eq!(actual, 200_000);
            }
            other => panic!("Expected SizeMismatch for an under-declared stream, got {other:?}"),
        }
    }

    #[test]
    fn test_max_uncompressed_size_clamping_enforces_minimum() {
        // Check the module's real bounds, not local copies of them.
        assert_eq!(MIN_MAX_UNCOMPRESSED_SIZE, 1_048_576, "MIN should be 1MB");
        assert_eq!(
            MAX_MAX_UNCOMPRESSED_SIZE, 2_147_483_648,
            "MAX should be 2GB"
        );
        assert!(
            (MIN_MAX_UNCOMPRESSED_SIZE..=MAX_MAX_UNCOMPRESSED_SIZE)
                .contains(&DEFAULT_MAX_UNCOMPRESSED_SIZE),
            "DEFAULT must lie within [MIN, MAX]"
        );

        let effective = max_uncompressed_size();
        assert!(
            effective >= MIN_MAX_UNCOMPRESSED_SIZE,
            "Effective limit {effective} should be >= MIN {MIN_MAX_UNCOMPRESSED_SIZE}"
        );
        assert!(
            effective <= MAX_MAX_UNCOMPRESSED_SIZE,
            "Effective limit {effective} should be <= MAX {MAX_MAX_UNCOMPRESSED_SIZE}"
        );
    }

    #[test]
    fn test_max_uncompressed_size_default_is_500mb() {
        assert_eq!(DEFAULT_MAX_UNCOMPRESSED_SIZE, 500 * ONE_MIB);

        // Only pin the effective limit when nothing in the environment overrides it.
        if std::env::var_os("SQRY_MAX_INDEX_SIZE").is_none() {
            assert_eq!(
                max_uncompressed_size(),
                DEFAULT_MAX_UNCOMPRESSED_SIZE,
                "Default should be 500MB"
            );
        }
    }

    #[test]
    fn test_decompression_bomb_error_includes_sizes() {
        let original = vec![b'e'; 100_000];
        let limit = max_uncompressed_size();
        let oversized = limit + 100 * ONE_MIB;
        let corrupted = index_with_declared_size(&original, oversized);

        match corrupted.decompress() {
            Err(CompressionError::DecompressionBomb { size, max }) => {
                assert_eq!(size, oversized, "Error should report actual claimed size");
                assert_eq!(max, limit, "Error should report the active limit");
                assert!(size > max, "Error should show size exceeds max");
            }
            other => panic!("Expected DecompressionBomb error, got {other:?}"),
        }
    }

    #[test]
    fn test_compression_format_from_u8() {
        assert!(matches!(
            CompressionFormat::from_u8(0),
            Ok(CompressionFormat::None)
        ));
        assert!(matches!(
            CompressionFormat::from_u8(1),
            Ok(CompressionFormat::Zstd)
        ));
        assert!(matches!(
            CompressionFormat::from_u8(99),
            Err(CompressionError::UnsupportedCompression(99))
        ));
    }
}