1use bytes::Bytes;
8
9use crate::{
10 ChunkManifest, Codec, CodecError, CodecKind, DECOMPRESS_BOOTSTRAP_CAPACITY,
11 validate_decompress_manifest,
12};
13
/// CPU-backed zstd codec.
///
/// The only state is the compression level, which is always stored clamped
/// into `1..=22` regardless of what the caller passed to [`CpuZstd::new`].
#[derive(Debug, Clone)]
pub struct CpuZstd {
    level: i32,
}

impl CpuZstd {
    /// Level used by [`Default`]; a balanced speed/ratio setting.
    pub const DEFAULT_LEVEL: i32 = 3;

    /// Builds a codec that compresses at `level`.
    ///
    /// Out-of-range values are not rejected: anything below 1 becomes 1 and
    /// anything above 22 becomes 22.
    pub fn new(level: i32) -> Self {
        const LOWEST: i32 = 1;
        const HIGHEST: i32 = 22;
        let level = level.clamp(LOWEST, HIGHEST);
        CpuZstd { level }
    }
}

impl Default for CpuZstd {
    /// Equivalent to `CpuZstd::new(CpuZstd::DEFAULT_LEVEL)`.
    fn default() -> Self {
        CpuZstd::new(CpuZstd::DEFAULT_LEVEL)
    }
}
37
38pub fn decompress_blocking(input: &[u8], manifest: &ChunkManifest) -> Result<Vec<u8>, CodecError> {
46 if manifest.codec != CodecKind::CpuZstd {
47 return Err(CodecError::CodecMismatch {
48 expected: CodecKind::CpuZstd,
49 got: manifest.codec,
50 });
51 }
52 let allocated_orig_size = validate_decompress_manifest(manifest, input.len())?;
56 use std::io::Read;
57 let limit = manifest.original_size.saturating_add(1024);
58 let mut decoder = zstd::stream::Decoder::new(input).map_err(CodecError::Io)?;
59 let mut buf = Vec::with_capacity(allocated_orig_size.min(DECOMPRESS_BOOTSTRAP_CAPACITY));
65 {
66 let mut limited = (&mut decoder).take(limit);
67 limited.read_to_end(&mut buf).map_err(CodecError::Io)?;
68 }
69 if buf.len() as u64 > manifest.original_size {
79 let mut peek = [0u8; 1];
80 let more_available = decoder.read(&mut peek).map(|n| n > 0).unwrap_or(false);
81 return Err(CodecError::Io(std::io::Error::other(format!(
82 "zstd decompression bomb detected: produced at least {} bytes \
83 (truncated at cap = manifest.original_size + 1024 = {}); \
84 manifest claimed {}{}",
85 buf.len(),
86 limit,
87 manifest.original_size,
88 if more_available {
89 "; decoder had more bytes available beyond the cap"
90 } else {
91 ""
92 },
93 ))));
94 }
95 if buf.len() as u64 != manifest.original_size {
96 return Err(CodecError::SizeMismatch {
97 expected: manifest.original_size,
98 got: buf.len() as u64,
99 });
100 }
101 let actual_crc = crc32c::crc32c(&buf);
102 if actual_crc != manifest.crc32c {
103 return Err(CodecError::CrcMismatch {
104 expected: manifest.crc32c,
105 got: actual_crc,
106 });
107 }
108 Ok(buf)
109}
110
111pub fn compress_blocking(input: &[u8], level: i32) -> Result<(Vec<u8>, ChunkManifest), CodecError> {
115 let level = level.clamp(1, 22);
116 let original_size = input.len() as u64;
117 let original_crc = crc32c::crc32c(input);
118 let compressed = zstd::stream::encode_all(input, level).map_err(CodecError::Io)?;
119 Ok((
120 compressed.clone(),
121 ChunkManifest {
122 codec: CodecKind::CpuZstd,
123 original_size,
124 compressed_size: compressed.len() as u64,
125 crc32c: original_crc,
126 },
127 ))
128}
129
130#[async_trait::async_trait]
131impl Codec for CpuZstd {
132 fn kind(&self) -> CodecKind {
133 CodecKind::CpuZstd
134 }
135
136 async fn compress(&self, input: Bytes) -> Result<(Bytes, ChunkManifest), CodecError> {
137 let level = self.level;
138 let original_size = input.len() as u64;
139 let original_crc = crc32c::crc32c(&input);
140
141 let compressed = tokio::task::spawn_blocking(move || -> std::io::Result<Vec<u8>> {
142 zstd::stream::encode_all(input.as_ref(), level)
143 })
144 .await??;
145
146 let compressed_size = compressed.len() as u64;
147 let manifest = ChunkManifest {
148 codec: CodecKind::CpuZstd,
149 original_size,
150 compressed_size,
151 crc32c: original_crc,
152 };
153 Ok((Bytes::from(compressed), manifest))
154 }
155
156 async fn decompress(
157 &self,
158 input: Bytes,
159 manifest: &ChunkManifest,
160 ) -> Result<Bytes, CodecError> {
161 if manifest.codec != CodecKind::CpuZstd {
162 return Err(CodecError::CodecMismatch {
163 expected: CodecKind::CpuZstd,
164 got: manifest.codec,
165 });
166 }
167 let allocated_orig_size = validate_decompress_manifest(manifest, input.len())?;
174
175 let expected_crc = manifest.crc32c;
176 let expected_orig_size = manifest.original_size;
177
178 let decompressed = tokio::task::spawn_blocking(move || -> std::io::Result<Vec<u8>> {
183 use std::io::Read;
184 let limit = expected_orig_size.saturating_add(1024);
188 let mut decoder = zstd::stream::Decoder::new(input.as_ref())?;
189 let mut buf =
192 Vec::with_capacity(allocated_orig_size.min(DECOMPRESS_BOOTSTRAP_CAPACITY));
193 {
194 let mut limited = (&mut decoder).take(limit);
195 limited.read_to_end(&mut buf)?;
196 }
197 if (buf.len() as u64) > expected_orig_size {
208 let mut peek = [0u8; 1];
209 let more_available = decoder.read(&mut peek).map(|n| n > 0).unwrap_or(false);
210 return Err(std::io::Error::other(format!(
211 "zstd decompression bomb detected: produced at least {} bytes \
212 (truncated at cap = manifest.original_size + 1024 = {}); \
213 manifest claimed {}{}",
214 buf.len(),
215 limit,
216 expected_orig_size,
217 if more_available {
218 "; decoder had more bytes available beyond the cap"
219 } else {
220 ""
221 },
222 )));
223 }
224 Ok(buf)
225 })
226 .await??;
227
228 if decompressed.len() as u64 != expected_orig_size {
229 return Err(CodecError::SizeMismatch {
230 expected: expected_orig_size,
231 got: decompressed.len() as u64,
232 });
233 }
234 let actual_crc = crc32c::crc32c(&decompressed);
235 if actual_crc != expected_crc {
236 return Err(CodecError::CrcMismatch {
237 expected: expected_crc,
238 got: actual_crc,
239 });
240 }
241 Ok(Bytes::from(decompressed))
242 }
243}
244
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn roundtrip_small() {
        let codec = CpuZstd::default();
        let input = Bytes::from_static(b"hello squished s3 hello squished s3 hello squished s3");
        let (compressed, manifest) = codec.compress(input.clone()).await.unwrap();
        assert_eq!(manifest.codec, CodecKind::CpuZstd);
        assert_eq!(manifest.original_size, input.len() as u64);
        let decompressed = codec.decompress(compressed, &manifest).await.unwrap();
        assert_eq!(decompressed, input);
    }

    #[tokio::test]
    async fn roundtrip_compressible() {
        let codec = CpuZstd::default();
        let input = Bytes::from(vec![b'x'; 1024 * 1024]);
        let (compressed, manifest) = codec.compress(input.clone()).await.unwrap();
        assert!(
            compressed.len() < input.len() / 100,
            "expected zstd to compress 1 MiB of x bytes very well, got {} bytes",
            compressed.len()
        );
        let decompressed = codec.decompress(compressed, &manifest).await.unwrap();
        assert_eq!(decompressed, input);
    }

    #[tokio::test]
    async fn detects_corrupted_compressed_payload() {
        let codec = CpuZstd::default();
        let input = Bytes::from(vec![b'x'; 1024]);
        let (mut compressed, manifest) = codec.compress(input).await.unwrap();
        let mut buf = compressed.to_vec();
        // Index 5 lies past zstd's 4-byte magic number, inside the frame header.
        if buf.len() > 8 {
            buf[5] ^= 0xff;
        }
        compressed = Bytes::from(buf);
        let err = codec.decompress(compressed, &manifest).await.unwrap_err();
        // Depending on which header bits flip, the damage can surface as a
        // decoder error, a length mismatch, or a checksum mismatch.
        assert!(matches!(
            err,
            CodecError::Io(_) | CodecError::CrcMismatch { .. } | CodecError::SizeMismatch { .. }
        ));
    }

    #[tokio::test]
    async fn rejects_codec_mismatch() {
        let codec = CpuZstd::default();
        let manifest = ChunkManifest {
            codec: CodecKind::Passthrough,
            original_size: 10,
            compressed_size: 10,
            crc32c: 0,
        };
        let err = codec
            .decompress(Bytes::from_static(b"0123456789"), &manifest)
            .await
            .unwrap_err();
        assert!(matches!(err, CodecError::CodecMismatch { .. }));
    }

    #[tokio::test]
    async fn issue_89_rejects_manifest_over_5gib() {
        let codec = CpuZstd::default();
        // Not a valid zstd frame; the manifest must be rejected before decoding.
        let body = Bytes::from_static(&[0x00, 0xd1, 0xd1, 0xd1, 0xd1, 0xd1]);
        let manifest = ChunkManifest {
            codec: CodecKind::CpuZstd,
            original_size: crate::MAX_DECOMPRESSED_BYTES + 1,
            compressed_size: body.len() as u64,
            crc32c: 0,
        };
        let err = codec.decompress(body, &manifest).await.unwrap_err();
        match err {
            CodecError::ManifestSizeExceedsLimit { requested, limit } => {
                assert_eq!(requested, crate::MAX_DECOMPRESSED_BYTES + 1);
                assert_eq!(limit, crate::MAX_DECOMPRESSED_BYTES);
            }
            other => panic!("expected ManifestSizeExceedsLimit, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn issue_89_bootstrap_cap_keeps_4gib_claim_alloc_safe() {
        let codec = CpuZstd::default();
        // A ~4 GiB claim is under the hard limit, so validation passes; the
        // bootstrap capacity must keep the up-front allocation small anyway.
        let body = Bytes::from_static(&[0x00, 0xd1, 0xd1, 0xd1, 0xd1, 0xd1]);
        let manifest = ChunkManifest {
            codec: CodecKind::CpuZstd,
            original_size: u32::MAX as u64,
            compressed_size: body.len() as u64,
            crc32c: 0,
        };
        let err = codec.decompress(body, &manifest).await.unwrap_err();
        assert!(
            matches!(err, CodecError::Io(_) | CodecError::SizeMismatch { .. }),
            "expected Io or SizeMismatch, got {err:?}"
        );
    }

    #[test]
    fn blocking_roundtrip() {
        let input = b"hello squished s3 hello squished s3 hello squished s3";
        let (compressed, manifest) = compress_blocking(input, CpuZstd::DEFAULT_LEVEL).unwrap();
        assert_eq!(manifest.codec, CodecKind::CpuZstd);
        let decompressed = decompress_blocking(&compressed, &manifest).unwrap();
        assert_eq!(decompressed, input);
    }

    #[test]
    fn issue_89_blocking_rejects_manifest_over_5gib() {
        let body = &[0x00, 0xd1, 0xd1, 0xd1, 0xd1, 0xd1];
        let manifest = ChunkManifest {
            codec: CodecKind::CpuZstd,
            original_size: crate::MAX_DECOMPRESSED_BYTES + 1,
            compressed_size: body.len() as u64,
            crc32c: 0,
        };
        let err = decompress_blocking(body, &manifest).unwrap_err();
        match err {
            CodecError::ManifestSizeExceedsLimit { requested, limit } => {
                assert_eq!(requested, crate::MAX_DECOMPRESSED_BYTES + 1);
                assert_eq!(limit, crate::MAX_DECOMPRESSED_BYTES);
            }
            other => panic!("expected ManifestSizeExceedsLimit, got {other:?}"),
        }
    }

    #[test]
    fn issue_89_blocking_bootstrap_cap_keeps_4gib_claim_alloc_safe() {
        let body = &[0x00, 0xd1, 0xd1, 0xd1, 0xd1, 0xd1];
        let manifest = ChunkManifest {
            codec: CodecKind::CpuZstd,
            original_size: u32::MAX as u64,
            compressed_size: body.len() as u64,
            crc32c: 0,
        };
        let err = decompress_blocking(body, &manifest).unwrap_err();
        assert!(
            matches!(err, CodecError::Io(_) | CodecError::SizeMismatch { .. }),
            "expected Io or SizeMismatch, got {err:?}"
        );
    }

    #[test]
    fn new_clamps_level_into_zstd_range() {
        assert_eq!(CpuZstd::new(0).level, 1);
        assert_eq!(CpuZstd::new(-7).level, 1);
        assert_eq!(CpuZstd::new(22).level, 22);
        assert_eq!(CpuZstd::new(99).level, 22);
        assert_eq!(CpuZstd::default().level, CpuZstd::DEFAULT_LEVEL);
    }

    #[tokio::test]
    async fn rejects_output_larger_than_manifest_claims() {
        // A valid frame whose manifest under-declares the size is exactly what
        // a decompression bomb looks like; the cap must turn it into an error.
        let codec = CpuZstd::default();
        let input = Bytes::from(vec![b'a'; 64 * 1024]);
        let (compressed, mut manifest) = codec.compress(input).await.unwrap();
        manifest.original_size = 16;
        let err = codec.decompress(compressed, &manifest).await.unwrap_err();
        assert!(matches!(err, CodecError::Io(_)), "expected Io, got {err:?}");
    }

    #[test]
    fn blocking_rejects_output_larger_than_manifest_claims() {
        let input = vec![b'a'; 64 * 1024];
        let (compressed, mut manifest) =
            compress_blocking(&input, CpuZstd::DEFAULT_LEVEL).unwrap();
        manifest.original_size = 16;
        let err = decompress_blocking(&compressed, &manifest).unwrap_err();
        assert!(matches!(err, CodecError::Io(_)), "expected Io, got {err:?}");
    }

    #[tokio::test]
    async fn rejects_output_shorter_than_manifest_claims() {
        let codec = CpuZstd::default();
        let input = Bytes::from(vec![b'a'; 100]);
        let (compressed, mut manifest) = codec.compress(input).await.unwrap();
        manifest.original_size = 200;
        let err = codec.decompress(compressed, &manifest).await.unwrap_err();
        match err {
            CodecError::SizeMismatch { expected, got } => {
                assert_eq!(expected, 200);
                assert_eq!(got, 100);
            }
            other => panic!("expected SizeMismatch, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn rejects_crc_mismatch() {
        let codec = CpuZstd::default();
        let input = Bytes::from_static(b"checksum me please, checksum me please");
        let (compressed, mut manifest) = codec.compress(input).await.unwrap();
        let good_crc = manifest.crc32c;
        manifest.crc32c ^= 1;
        let err = codec.decompress(compressed, &manifest).await.unwrap_err();
        match err {
            CodecError::CrcMismatch { expected, got } => {
                assert_eq!(expected, good_crc ^ 1);
                assert_eq!(got, good_crc);
            }
            other => panic!("expected CrcMismatch, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn async_and_blocking_paths_interoperate() {
        let codec = CpuZstd::default();
        let input = Bytes::from_static(b"interop interop interop interop interop");

        let (compressed, manifest) = codec.compress(input.clone()).await.unwrap();
        assert_eq!(
            decompress_blocking(&compressed, &manifest).unwrap(),
            input.as_ref()
        );

        let (blocking_compressed, blocking_manifest) =
            compress_blocking(&input, CpuZstd::DEFAULT_LEVEL).unwrap();
        let roundtripped = codec
            .decompress(Bytes::from(blocking_compressed), &blocking_manifest)
            .await
            .unwrap();
        assert_eq!(roundtripped, input);
    }
}
419}