1use bytes::Bytes;
8
9use crate::{
10 ChunkManifest, Codec, CodecError, CodecKind, DECOMPRESS_BOOTSTRAP_CAPACITY,
11 validate_decompress_manifest,
12};
13
/// Zstandard codec that runs on the CPU.
///
/// The only state is the compression level handed to the zstd encoder;
/// construct one with [`CpuZstd::new`] or [`Default::default`].
#[derive(Clone, Debug)]
pub struct CpuZstd {
    /// zstd compression level passed to the encoder.
    level: i32,
}
21
22impl CpuZstd {
23 pub const DEFAULT_LEVEL: i32 = 3;
24
25 pub fn new(level: i32) -> Self {
26 Self {
27 level: level.clamp(1, 22),
28 }
29 }
30}
31
32impl Default for CpuZstd {
33 fn default() -> Self {
34 Self::new(Self::DEFAULT_LEVEL)
35 }
36}
37
38pub fn decompress_blocking(input: &[u8], manifest: &ChunkManifest) -> Result<Vec<u8>, CodecError> {
46 if manifest.codec != CodecKind::CpuZstd {
47 return Err(CodecError::CodecMismatch {
48 expected: CodecKind::CpuZstd,
49 got: manifest.codec,
50 });
51 }
52 let allocated_orig_size = validate_decompress_manifest(manifest, input.len())?;
56 use std::io::Read;
57 let limit = manifest.original_size.saturating_add(1024);
58 let mut decoder = zstd::stream::Decoder::new(input).map_err(CodecError::Io)?;
59 let mut buf = Vec::with_capacity(allocated_orig_size.min(DECOMPRESS_BOOTSTRAP_CAPACITY));
65 (&mut decoder)
66 .take(limit)
67 .read_to_end(&mut buf)
68 .map_err(CodecError::Io)?;
69 if (buf.len() as u64) > manifest.original_size {
70 return Err(CodecError::Io(std::io::Error::other(format!(
71 "zstd decompression bomb detected: produced {} bytes, manifest claimed {}",
72 buf.len(),
73 manifest.original_size
74 ))));
75 }
76 if buf.len() as u64 != manifest.original_size {
77 return Err(CodecError::SizeMismatch {
78 expected: manifest.original_size,
79 got: buf.len() as u64,
80 });
81 }
82 let actual_crc = crc32c::crc32c(&buf);
83 if actual_crc != manifest.crc32c {
84 return Err(CodecError::CrcMismatch {
85 expected: manifest.crc32c,
86 got: actual_crc,
87 });
88 }
89 Ok(buf)
90}
91
92pub fn compress_blocking(input: &[u8], level: i32) -> Result<(Vec<u8>, ChunkManifest), CodecError> {
96 let level = level.clamp(1, 22);
97 let original_size = input.len() as u64;
98 let original_crc = crc32c::crc32c(input);
99 let compressed = zstd::stream::encode_all(input, level).map_err(CodecError::Io)?;
100 Ok((
101 compressed.clone(),
102 ChunkManifest {
103 codec: CodecKind::CpuZstd,
104 original_size,
105 compressed_size: compressed.len() as u64,
106 crc32c: original_crc,
107 },
108 ))
109}
110
111#[async_trait::async_trait]
112impl Codec for CpuZstd {
113 fn kind(&self) -> CodecKind {
114 CodecKind::CpuZstd
115 }
116
117 async fn compress(&self, input: Bytes) -> Result<(Bytes, ChunkManifest), CodecError> {
118 let level = self.level;
119 let original_size = input.len() as u64;
120 let original_crc = crc32c::crc32c(&input);
121
122 let compressed = tokio::task::spawn_blocking(move || -> std::io::Result<Vec<u8>> {
123 zstd::stream::encode_all(input.as_ref(), level)
124 })
125 .await??;
126
127 let compressed_size = compressed.len() as u64;
128 let manifest = ChunkManifest {
129 codec: CodecKind::CpuZstd,
130 original_size,
131 compressed_size,
132 crc32c: original_crc,
133 };
134 Ok((Bytes::from(compressed), manifest))
135 }
136
137 async fn decompress(
138 &self,
139 input: Bytes,
140 manifest: &ChunkManifest,
141 ) -> Result<Bytes, CodecError> {
142 if manifest.codec != CodecKind::CpuZstd {
143 return Err(CodecError::CodecMismatch {
144 expected: CodecKind::CpuZstd,
145 got: manifest.codec,
146 });
147 }
148 let allocated_orig_size = validate_decompress_manifest(manifest, input.len())?;
155
156 let expected_crc = manifest.crc32c;
157 let expected_orig_size = manifest.original_size;
158
159 let decompressed = tokio::task::spawn_blocking(move || -> std::io::Result<Vec<u8>> {
164 use std::io::Read;
165 let limit = expected_orig_size.saturating_add(1024);
169 let mut decoder = zstd::stream::Decoder::new(input.as_ref())?;
170 let mut buf =
173 Vec::with_capacity(allocated_orig_size.min(DECOMPRESS_BOOTSTRAP_CAPACITY));
174 (&mut decoder).take(limit).read_to_end(&mut buf)?;
175 if (buf.len() as u64) > expected_orig_size {
177 return Err(std::io::Error::other(format!(
178 "zstd decompression bomb detected: produced {} bytes, manifest claimed {}",
179 buf.len(),
180 expected_orig_size
181 )));
182 }
183 Ok(buf)
184 })
185 .await??;
186
187 if decompressed.len() as u64 != expected_orig_size {
188 return Err(CodecError::SizeMismatch {
189 expected: expected_orig_size,
190 got: decompressed.len() as u64,
191 });
192 }
193 let actual_crc = crc32c::crc32c(&decompressed);
194 if actual_crc != expected_crc {
195 return Err(CodecError::CrcMismatch {
196 expected: expected_crc,
197 got: actual_crc,
198 });
199 }
200 Ok(Bytes::from(decompressed))
201 }
202}
203
#[cfg(test)]
mod tests {
    use super::*;

    /// Six-byte body shared by the issue-89 regression tests.
    const ISSUE_89_BODY: &[u8] = &[0x00, 0xd1, 0xd1, 0xd1, 0xd1, 0xd1];

    /// Builds a `CpuZstd` manifest for [`ISSUE_89_BODY`] that claims
    /// `original_size` decoded bytes and a zero checksum.
    fn issue_89_manifest(original_size: u64) -> ChunkManifest {
        ChunkManifest {
            codec: CodecKind::CpuZstd,
            original_size,
            compressed_size: ISSUE_89_BODY.len() as u64,
            crc32c: 0,
        }
    }

    /// A short payload survives compress → decompress and the manifest
    /// records the codec and the original length.
    #[tokio::test]
    async fn roundtrip_small() {
        let payload = Bytes::from_static(b"hello squished s3 hello squished s3 hello squished s3");
        let zstd_codec = CpuZstd::default();

        let (packed, manifest) = zstd_codec.compress(payload.clone()).await.unwrap();
        assert_eq!(manifest.codec, CodecKind::CpuZstd);
        assert_eq!(manifest.original_size, payload.len() as u64);

        let unpacked = zstd_codec.decompress(packed, &manifest).await.unwrap();
        assert_eq!(unpacked, payload);
    }

    /// 1 MiB of a single repeated byte shrinks to under 1% and round-trips.
    #[tokio::test]
    async fn roundtrip_compressible() {
        let payload = Bytes::from(vec![b'x'; 1024 * 1024]);
        let zstd_codec = CpuZstd::default();

        let (packed, manifest) = zstd_codec.compress(payload.clone()).await.unwrap();
        let ceiling = payload.len() / 100;
        assert!(
            packed.len() < ceiling,
            "expected zstd to compress 1 MiB of x bytes very well, got {} bytes",
            packed.len()
        );

        let unpacked = zstd_codec.decompress(packed, &manifest).await.unwrap();
        assert_eq!(unpacked, payload);
    }

    /// Inverting one byte of the compressed stream must surface as an error.
    #[tokio::test]
    async fn detects_corrupted_compressed_payload() {
        let zstd_codec = CpuZstd::default();
        let (packed, manifest) = zstd_codec
            .compress(Bytes::from(vec![b'x'; 1024]))
            .await
            .unwrap();

        let mut tampered = packed.to_vec();
        if tampered.len() > 8 {
            tampered[5] = !tampered[5];
        }

        let outcome = zstd_codec
            .decompress(Bytes::from(tampered), &manifest)
            .await;
        let err = outcome.unwrap_err();
        assert!(matches!(
            err,
            CodecError::Io(_) | CodecError::CrcMismatch { .. } | CodecError::SizeMismatch { .. }
        ));
    }

    /// A manifest naming a different codec is rejected before any decoding.
    #[tokio::test]
    async fn rejects_codec_mismatch() {
        let foreign_manifest = ChunkManifest {
            codec: CodecKind::Passthrough,
            original_size: 10,
            compressed_size: 10,
            crc32c: 0,
        };
        let body = Bytes::from_static(b"0123456789");

        let err = CpuZstd::default()
            .decompress(body, &foreign_manifest)
            .await
            .unwrap_err();
        assert!(matches!(err, CodecError::CodecMismatch { .. }));
    }

    /// A claimed size one byte above `MAX_DECOMPRESSED_BYTES` is refused with
    /// `ManifestSizeExceedsLimit` carrying both the request and the limit.
    #[tokio::test]
    async fn issue_89_rejects_manifest_over_5gib() {
        let over_limit = crate::MAX_DECOMPRESSED_BYTES + 1;
        let manifest = issue_89_manifest(over_limit);

        let err = CpuZstd::default()
            .decompress(Bytes::from_static(ISSUE_89_BODY), &manifest)
            .await
            .unwrap_err();
        match err {
            CodecError::ManifestSizeExceedsLimit { requested, limit } => {
                assert_eq!(requested, over_limit);
                assert_eq!(limit, crate::MAX_DECOMPRESSED_BYTES);
            }
            other => panic!("expected ManifestSizeExceedsLimit, got {other:?}"),
        }
    }

    /// A manifest claiming `u32::MAX` bytes for a six-byte body must fail
    /// with an I/O or size error rather than succeed.
    #[tokio::test]
    async fn issue_89_bootstrap_cap_keeps_4gib_claim_alloc_safe() {
        let manifest = issue_89_manifest(u32::MAX as u64);

        let err = CpuZstd::default()
            .decompress(Bytes::from_static(ISSUE_89_BODY), &manifest)
            .await
            .unwrap_err();
        let acceptable = matches!(err, CodecError::Io(_) | CodecError::SizeMismatch { .. });
        assert!(acceptable, "expected Io or SizeMismatch, got {err:?}");
    }

    /// The synchronous helpers round-trip a short payload.
    #[test]
    fn blocking_roundtrip() {
        let payload = b"hello squished s3 hello squished s3 hello squished s3";

        let (packed, manifest) = compress_blocking(payload, CpuZstd::DEFAULT_LEVEL).unwrap();
        assert_eq!(manifest.codec, CodecKind::CpuZstd);

        let unpacked = decompress_blocking(&packed, &manifest).unwrap();
        assert_eq!(unpacked, payload);
    }
}