Skip to main content

tzap_core/
compression.rs

1use crate::format::FormatError;
2
/// Magic number that opens every standard (non-skippable) zstd frame, i.e. the
/// little-endian encoding of `0xFD2FB528` (bytes `28 b5 2f fd`).
const ZSTD_MAGIC: [u8; 4] = 0xFD2F_B528_u32.to_le_bytes();
4
5pub fn compress_zstd_frame(plaintext: &[u8], level: i32) -> Result<Vec<u8>, FormatError> {
6    zstd::bulk::compress(plaintext, level).map_err(|_| FormatError::ZstdCompressionFailure)
7}
8
9pub fn compress_zstd_frame_with_dictionary(
10    plaintext: &[u8],
11    level: i32,
12    dictionary: &[u8],
13) -> Result<Vec<u8>, FormatError> {
14    zstd::bulk::Compressor::with_dictionary(level, dictionary)
15        .and_then(|mut compressor| compressor.compress(plaintext))
16        .map_err(|_| FormatError::ZstdCompressionFailure)
17}
18
19pub fn decompress_exact_zstd_frame(
20    compressed: &[u8],
21    expected_decompressed_size: usize,
22) -> Result<Vec<u8>, FormatError> {
23    validate_metadata_decompressed_size(expected_decompressed_size)?;
24    validate_exact_zstd_frame(compressed)?;
25    let decompressed = zstd::bulk::decompress(compressed, expected_decompressed_size)
26        .map_err(|_| FormatError::ZstdDecompressionFailure)?;
27    if decompressed.len() != expected_decompressed_size {
28        return Err(FormatError::ZstdDecompressedSizeMismatch {
29            expected: expected_decompressed_size,
30            actual: decompressed.len(),
31        });
32    }
33    Ok(decompressed)
34}
35
36pub fn decompress_exact_zstd_frame_with_dictionary(
37    compressed: &[u8],
38    expected_decompressed_size: usize,
39    dictionary: &[u8],
40) -> Result<Vec<u8>, FormatError> {
41    validate_metadata_decompressed_size(expected_decompressed_size)?;
42    validate_exact_zstd_frame(compressed)?;
43    let decompressed = zstd::bulk::Decompressor::with_dictionary(dictionary)
44        .and_then(|mut decompressor| {
45            decompressor.decompress(compressed, expected_decompressed_size)
46        })
47        .map_err(|_| FormatError::ZstdDecompressionFailure)?;
48    if decompressed.len() != expected_decompressed_size {
49        return Err(FormatError::ZstdDecompressedSizeMismatch {
50            expected: expected_decompressed_size,
51            actual: decompressed.len(),
52        });
53    }
54    Ok(decompressed)
55}
56
57pub fn validate_exact_zstd_frame(compressed: &[u8]) -> Result<(), FormatError> {
58    if compressed.is_empty() {
59        return Err(FormatError::EmptyZstdFrame);
60    }
61    if compressed.len() < 4 || compressed[0..4] != ZSTD_MAGIC {
62        return Err(FormatError::NotStandardZstdFrame);
63    }
64    let frame_size = zstd_safe::find_frame_compressed_size(compressed)
65        .map_err(|_| FormatError::InvalidZstdFrame)?;
66    if frame_size != compressed.len() {
67        return Err(FormatError::TrailingBytesAfterZstdFrame);
68    }
69    Ok(())
70}
71
72fn validate_metadata_decompressed_size(
73    expected_decompressed_size: usize,
74) -> Result<(), FormatError> {
75    if expected_decompressed_size > u32::MAX as usize {
76        Err(FormatError::ReaderResourceLimitExceeded {
77            field: "decompressed_size",
78            cap: u32::MAX as u64,
79            actual: expected_decompressed_size as u64,
80        })
81    } else {
82        Ok(())
83    }
84}
85
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn compresses_and_decompresses_exact_frame() {
        let plaintext = b"metadata object payload";
        let compressed = compress_zstd_frame(plaintext, 3).unwrap();
        let decompressed = decompress_exact_zstd_frame(&compressed, plaintext.len()).unwrap();
        assert_eq!(decompressed, plaintext);
    }

    #[test]
    fn rejects_trailing_concatenated_and_skippable_frames() {
        let plaintext = b"payload";
        let mut compressed = compress_zstd_frame(plaintext, 1).unwrap();
        compressed.push(0);
        assert_eq!(
            decompress_exact_zstd_frame(&compressed, plaintext.len()).unwrap_err(),
            FormatError::TrailingBytesAfterZstdFrame
        );

        let one = compress_zstd_frame(plaintext, 1).unwrap();
        let mut concatenated = one.clone();
        concatenated.extend_from_slice(&one);
        assert_eq!(
            decompress_exact_zstd_frame(&concatenated, plaintext.len()).unwrap_err(),
            FormatError::TrailingBytesAfterZstdFrame
        );

        let skippable = [0x50, 0x2a, 0x4d, 0x18, 0, 0, 0, 0];
        assert_eq!(
            validate_exact_zstd_frame(&skippable).unwrap_err(),
            FormatError::NotStandardZstdFrame
        );
    }

    // Covers the empty-input branch and the "shorter than the magic" branch of
    // `validate_exact_zstd_frame`, which the other tests never reach.
    #[test]
    fn rejects_empty_and_short_inputs() {
        assert_eq!(
            validate_exact_zstd_frame(&[]).unwrap_err(),
            FormatError::EmptyZstdFrame
        );
        assert_eq!(
            validate_exact_zstd_frame(&ZSTD_MAGIC[..3]).unwrap_err(),
            FormatError::NotStandardZstdFrame
        );
    }

    // A frame cut short inside its last block has the right magic, but zstd cannot
    // determine its compressed size.
    #[test]
    fn rejects_truncated_frame() {
        let plaintext = b"payload";
        let compressed = compress_zstd_frame(plaintext, 1).unwrap();
        let truncated = &compressed[..compressed.len() - 1];
        assert_eq!(
            decompress_exact_zstd_frame(truncated, plaintext.len()).unwrap_err(),
            FormatError::InvalidZstdFrame
        );
    }

    #[test]
    fn rejects_wrong_decompressed_size() {
        let compressed = compress_zstd_frame(b"payload", 1).unwrap();
        assert_eq!(
            decompress_exact_zstd_frame(&compressed, 100).unwrap_err(),
            FormatError::ZstdDecompressedSizeMismatch {
                expected: 100,
                actual: 7
            }
        );
    }

    #[cfg(target_pointer_width = "64")]
    #[test]
    fn rejects_decompressed_size_over_u32_cap() {
        let compressed = compress_zstd_frame(b"metadata-object", 1).unwrap();
        assert_eq!(
            decompress_exact_zstd_frame(&compressed, (u32::MAX as usize) + 1).unwrap_err(),
            FormatError::ReaderResourceLimitExceeded {
                field: "decompressed_size",
                cap: u32::MAX as u64,
                actual: (u32::MAX as u64) + 1,
            }
        );
    }

    #[test]
    fn compresses_and_decompresses_exact_dictionary_frame() {
        let dictionary = b"common prefix common prefix common prefix";
        let plaintext = b"common prefix payload";
        let compressed = compress_zstd_frame_with_dictionary(plaintext, 3, dictionary).unwrap();
        let decompressed =
            decompress_exact_zstd_frame_with_dictionary(&compressed, plaintext.len(), dictionary)
                .unwrap();
        assert_eq!(decompressed, plaintext);
    }

    // The dictionary path must apply the same exact-frame validation as the plain path.
    #[test]
    fn dictionary_decompression_rejects_trailing_bytes() {
        let dictionary = b"common prefix common prefix common prefix";
        let plaintext = b"common prefix payload";
        let mut compressed =
            compress_zstd_frame_with_dictionary(plaintext, 3, dictionary).unwrap();
        compressed.push(0);
        assert_eq!(
            decompress_exact_zstd_frame_with_dictionary(&compressed, plaintext.len(), dictionary)
                .unwrap_err(),
            FormatError::TrailingBytesAfterZstdFrame
        );
    }
}