1use crate::format::FormatError;
2
/// Leading bytes of a standard (non-skippable) Zstandard frame: the magic
/// number `0xFD2FB528` stored little-endian, i.e. `28 B5 2F FD` on the wire.
const ZSTD_MAGIC: [u8; 4] = 0xFD2F_B528_u32.to_le_bytes();
4
5pub fn compress_zstd_frame(plaintext: &[u8], level: i32) -> Result<Vec<u8>, FormatError> {
6 zstd::bulk::compress(plaintext, level).map_err(|_| FormatError::ZstdCompressionFailure)
7}
8
9pub fn compress_zstd_frame_with_dictionary(
10 plaintext: &[u8],
11 level: i32,
12 dictionary: &[u8],
13) -> Result<Vec<u8>, FormatError> {
14 zstd::bulk::Compressor::with_dictionary(level, dictionary)
15 .and_then(|mut compressor| compressor.compress(plaintext))
16 .map_err(|_| FormatError::ZstdCompressionFailure)
17}
18
19pub fn decompress_exact_zstd_frame(
20 compressed: &[u8],
21 expected_decompressed_size: usize,
22) -> Result<Vec<u8>, FormatError> {
23 validate_metadata_decompressed_size(expected_decompressed_size)?;
24 validate_exact_zstd_frame(compressed)?;
25 let decompressed = zstd::bulk::decompress(compressed, expected_decompressed_size)
26 .map_err(|_| FormatError::ZstdDecompressionFailure)?;
27 if decompressed.len() != expected_decompressed_size {
28 return Err(FormatError::ZstdDecompressedSizeMismatch {
29 expected: expected_decompressed_size,
30 actual: decompressed.len(),
31 });
32 }
33 Ok(decompressed)
34}
35
36pub fn decompress_exact_zstd_frame_with_dictionary(
37 compressed: &[u8],
38 expected_decompressed_size: usize,
39 dictionary: &[u8],
40) -> Result<Vec<u8>, FormatError> {
41 validate_metadata_decompressed_size(expected_decompressed_size)?;
42 validate_exact_zstd_frame(compressed)?;
43 let decompressed = zstd::bulk::Decompressor::with_dictionary(dictionary)
44 .and_then(|mut decompressor| {
45 decompressor.decompress(compressed, expected_decompressed_size)
46 })
47 .map_err(|_| FormatError::ZstdDecompressionFailure)?;
48 if decompressed.len() != expected_decompressed_size {
49 return Err(FormatError::ZstdDecompressedSizeMismatch {
50 expected: expected_decompressed_size,
51 actual: decompressed.len(),
52 });
53 }
54 Ok(decompressed)
55}
56
57pub fn validate_exact_zstd_frame(compressed: &[u8]) -> Result<(), FormatError> {
58 if compressed.is_empty() {
59 return Err(FormatError::EmptyZstdFrame);
60 }
61 if compressed.len() < 4 || compressed[0..4] != ZSTD_MAGIC {
62 return Err(FormatError::NotStandardZstdFrame);
63 }
64 let frame_size = zstd_safe::find_frame_compressed_size(compressed)
65 .map_err(|_| FormatError::InvalidZstdFrame)?;
66 if frame_size != compressed.len() {
67 return Err(FormatError::TrailingBytesAfterZstdFrame);
68 }
69 Ok(())
70}
71
72fn validate_metadata_decompressed_size(
73 expected_decompressed_size: usize,
74) -> Result<(), FormatError> {
75 if expected_decompressed_size > u32::MAX as usize {
76 Err(FormatError::ReaderResourceLimitExceeded {
77 field: "decompressed_size",
78 cap: u32::MAX as u64,
79 actual: expected_decompressed_size as u64,
80 })
81 } else {
82 Ok(())
83 }
84}
85
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn compresses_and_decompresses_exact_frame() {
        let plaintext = b"metadata object payload";
        let compressed = compress_zstd_frame(plaintext, 3).unwrap();
        let decompressed = decompress_exact_zstd_frame(&compressed, plaintext.len()).unwrap();
        assert_eq!(decompressed, plaintext);
    }

    #[test]
    fn rejects_trailing_concatenated_and_skippable_frames() {
        let plaintext = b"payload";
        let mut compressed = compress_zstd_frame(plaintext, 1).unwrap();
        compressed.push(0);
        assert_eq!(
            decompress_exact_zstd_frame(&compressed, plaintext.len()).unwrap_err(),
            FormatError::TrailingBytesAfterZstdFrame
        );

        let one = compress_zstd_frame(plaintext, 1).unwrap();
        let mut concatenated = one.clone();
        concatenated.extend_from_slice(&one);
        assert_eq!(
            decompress_exact_zstd_frame(&concatenated, plaintext.len()).unwrap_err(),
            FormatError::TrailingBytesAfterZstdFrame
        );

        let skippable = [0x50, 0x2a, 0x4d, 0x18, 0, 0, 0, 0];
        assert_eq!(
            validate_exact_zstd_frame(&skippable).unwrap_err(),
            FormatError::NotStandardZstdFrame
        );
    }

    #[test]
    fn rejects_empty_short_and_foreign_input() {
        assert_eq!(
            validate_exact_zstd_frame(&[]).unwrap_err(),
            FormatError::EmptyZstdFrame
        );
        // Shorter than the four-byte frame magic.
        assert_eq!(
            validate_exact_zstd_frame(&[0x28, 0xb5]).unwrap_err(),
            FormatError::NotStandardZstdFrame
        );
        assert_eq!(
            validate_exact_zstd_frame(b"not zstd").unwrap_err(),
            FormatError::NotStandardZstdFrame
        );
    }

    #[test]
    fn rejects_truncated_frames() {
        // The magic number alone leaves the frame header incomplete.
        assert_eq!(
            validate_exact_zstd_frame(&[0x28, 0xb5, 0x2f, 0xfd]).unwrap_err(),
            FormatError::InvalidZstdFrame
        );

        let mut compressed = compress_zstd_frame(b"payload", 1).unwrap();
        compressed.pop();
        assert_eq!(
            decompress_exact_zstd_frame(&compressed, 7).unwrap_err(),
            FormatError::InvalidZstdFrame
        );
    }

    #[test]
    fn rejects_wrong_decompressed_size() {
        let compressed = compress_zstd_frame(b"payload", 1).unwrap();
        assert_eq!(
            decompress_exact_zstd_frame(&compressed, 100).unwrap_err(),
            FormatError::ZstdDecompressedSizeMismatch {
                expected: 100,
                actual: 7
            }
        );
    }

    #[test]
    fn rejects_expected_size_smaller_than_frame_content() {
        let compressed = compress_zstd_frame(b"payload", 1).unwrap();
        // The output buffer is capped at the expected size. Only failure is
        // pinned here, not the specific error variant zstd's overflow maps to.
        assert!(decompress_exact_zstd_frame(&compressed, 3).is_err());
    }

    #[cfg(target_pointer_width = "64")]
    #[test]
    fn rejects_decompressed_size_over_u32_cap() {
        let compressed = compress_zstd_frame(b"metadata-object", 1).unwrap();
        assert_eq!(
            decompress_exact_zstd_frame(&compressed, (u32::MAX as usize) + 1).unwrap_err(),
            FormatError::ReaderResourceLimitExceeded {
                field: "decompressed_size",
                cap: u32::MAX as u64,
                actual: (u32::MAX as u64) + 1,
            }
        );
    }

    #[test]
    fn compresses_and_decompresses_exact_dictionary_frame() {
        let dictionary = b"common prefix common prefix common prefix";
        let plaintext = b"common prefix payload";
        let compressed = compress_zstd_frame_with_dictionary(plaintext, 3, dictionary).unwrap();
        let decompressed =
            decompress_exact_zstd_frame_with_dictionary(&compressed, plaintext.len(), dictionary)
                .unwrap();
        assert_eq!(decompressed, plaintext);
    }

    #[test]
    fn rejects_trailing_bytes_after_dictionary_frame() {
        let dictionary = b"common prefix common prefix common prefix";
        let plaintext = b"common prefix payload";
        let mut compressed =
            compress_zstd_frame_with_dictionary(plaintext, 3, dictionary).unwrap();
        compressed.push(0);
        assert_eq!(
            decompress_exact_zstd_frame_with_dictionary(&compressed, plaintext.len(), dictionary)
                .unwrap_err(),
            FormatError::TrailingBytesAfterZstdFrame
        );
    }
}