1use bytes::Bytes;
8
9use crate::{ChunkManifest, Codec, CodecError, CodecKind};
10
/// CPU-side zstd codec.
///
/// Holds the zstd compression level used when compressing. Build one with
/// [`CpuZstd::new`], which saturates the level into `1..=22`, or via [`Default`].
#[derive(Clone, Debug)]
pub struct CpuZstd {
    /// Compression level; always within `1..=22` because [`CpuZstd::new`] clamps it.
    level: i32,
}

impl CpuZstd {
    /// Compression level used by [`CpuZstd::default`].
    pub const DEFAULT_LEVEL: i32 = 3;

    /// Creates a codec that compresses at `level`.
    ///
    /// Requests below 1 become 1 and requests above 22 become 22; out-of-range
    /// values are saturated rather than rejected.
    pub fn new(level: i32) -> Self {
        // Bounds applied to the requested level.
        const MIN_LEVEL: i32 = 1;
        const MAX_LEVEL: i32 = 22;
        let level = i32::clamp(level, MIN_LEVEL, MAX_LEVEL);
        CpuZstd { level }
    }
}

impl Default for CpuZstd {
    /// Returns a codec configured with [`CpuZstd::DEFAULT_LEVEL`].
    fn default() -> Self {
        let level = CpuZstd::DEFAULT_LEVEL;
        CpuZstd::new(level)
    }
}
34
35pub fn decompress_blocking(input: &[u8], manifest: &ChunkManifest) -> Result<Vec<u8>, CodecError> {
43 if manifest.codec != CodecKind::CpuZstd {
44 return Err(CodecError::CodecMismatch {
45 expected: CodecKind::CpuZstd,
46 got: manifest.codec,
47 });
48 }
49 if input.len() as u64 != manifest.compressed_size {
50 return Err(CodecError::SizeMismatch {
51 expected: manifest.compressed_size,
52 got: input.len() as u64,
53 });
54 }
55 use std::io::Read;
56 let limit = manifest.original_size.saturating_add(1024);
57 let mut decoder = zstd::stream::Decoder::new(input).map_err(CodecError::Io)?;
58 let mut buf = Vec::with_capacity(manifest.original_size as usize);
59 (&mut decoder)
60 .take(limit)
61 .read_to_end(&mut buf)
62 .map_err(CodecError::Io)?;
63 if (buf.len() as u64) > manifest.original_size {
64 return Err(CodecError::Io(std::io::Error::other(format!(
65 "zstd decompression bomb detected: produced {} bytes, manifest claimed {}",
66 buf.len(),
67 manifest.original_size
68 ))));
69 }
70 if buf.len() as u64 != manifest.original_size {
71 return Err(CodecError::SizeMismatch {
72 expected: manifest.original_size,
73 got: buf.len() as u64,
74 });
75 }
76 let actual_crc = crc32c::crc32c(&buf);
77 if actual_crc != manifest.crc32c {
78 return Err(CodecError::CrcMismatch {
79 expected: manifest.crc32c,
80 got: actual_crc,
81 });
82 }
83 Ok(buf)
84}
85
86pub fn compress_blocking(input: &[u8], level: i32) -> Result<(Vec<u8>, ChunkManifest), CodecError> {
90 let level = level.clamp(1, 22);
91 let original_size = input.len() as u64;
92 let original_crc = crc32c::crc32c(input);
93 let compressed = zstd::stream::encode_all(input, level).map_err(CodecError::Io)?;
94 Ok((
95 compressed.clone(),
96 ChunkManifest {
97 codec: CodecKind::CpuZstd,
98 original_size,
99 compressed_size: compressed.len() as u64,
100 crc32c: original_crc,
101 },
102 ))
103}
104
105#[async_trait::async_trait]
106impl Codec for CpuZstd {
107 fn kind(&self) -> CodecKind {
108 CodecKind::CpuZstd
109 }
110
111 async fn compress(&self, input: Bytes) -> Result<(Bytes, ChunkManifest), CodecError> {
112 let level = self.level;
113 let original_size = input.len() as u64;
114 let original_crc = crc32c::crc32c(&input);
115
116 let compressed = tokio::task::spawn_blocking(move || -> std::io::Result<Vec<u8>> {
117 zstd::stream::encode_all(input.as_ref(), level)
118 })
119 .await??;
120
121 let compressed_size = compressed.len() as u64;
122 let manifest = ChunkManifest {
123 codec: CodecKind::CpuZstd,
124 original_size,
125 compressed_size,
126 crc32c: original_crc,
127 };
128 Ok((Bytes::from(compressed), manifest))
129 }
130
131 async fn decompress(
132 &self,
133 input: Bytes,
134 manifest: &ChunkManifest,
135 ) -> Result<Bytes, CodecError> {
136 if manifest.codec != CodecKind::CpuZstd {
137 return Err(CodecError::CodecMismatch {
138 expected: CodecKind::CpuZstd,
139 got: manifest.codec,
140 });
141 }
142 if input.len() as u64 != manifest.compressed_size {
143 return Err(CodecError::SizeMismatch {
144 expected: manifest.compressed_size,
145 got: input.len() as u64,
146 });
147 }
148
149 let expected_crc = manifest.crc32c;
150 let expected_orig_size = manifest.original_size;
151
152 let decompressed = tokio::task::spawn_blocking(move || -> std::io::Result<Vec<u8>> {
157 use std::io::Read;
158 let limit = expected_orig_size.saturating_add(1024);
162 let mut decoder = zstd::stream::Decoder::new(input.as_ref())?;
163 let mut buf = Vec::with_capacity(expected_orig_size as usize);
164 (&mut decoder).take(limit).read_to_end(&mut buf)?;
165 if (buf.len() as u64) > expected_orig_size {
167 return Err(std::io::Error::other(format!(
168 "zstd decompression bomb detected: produced {} bytes, manifest claimed {}",
169 buf.len(),
170 expected_orig_size
171 )));
172 }
173 Ok(buf)
174 })
175 .await??;
176
177 if decompressed.len() as u64 != expected_orig_size {
178 return Err(CodecError::SizeMismatch {
179 expected: expected_orig_size,
180 got: decompressed.len() as u64,
181 });
182 }
183 let actual_crc = crc32c::crc32c(&decompressed);
184 if actual_crc != expected_crc {
185 return Err(CodecError::CrcMismatch {
186 expected: expected_crc,
187 got: actual_crc,
188 });
189 }
190 Ok(Bytes::from(decompressed))
191 }
192}
193
#[cfg(test)]
mod tests {
    use super::*;

    /// Small, repetitive payload shared by the round-trip and tamper tests.
    const SAMPLE: &[u8] = b"hello squished s3 hello squished s3 hello squished s3";

    #[tokio::test]
    async fn roundtrip_small() {
        let codec = CpuZstd::default();
        let input = Bytes::from_static(SAMPLE);
        let (compressed, manifest) = codec.compress(input.clone()).await.unwrap();
        assert_eq!(manifest.codec, CodecKind::CpuZstd);
        assert_eq!(manifest.original_size, input.len() as u64);
        assert_eq!(manifest.compressed_size, compressed.len() as u64);
        assert_eq!(manifest.crc32c, crc32c::crc32c(SAMPLE));
        let decompressed = codec.decompress(compressed, &manifest).await.unwrap();
        assert_eq!(decompressed, input);
    }

    #[tokio::test]
    async fn roundtrip_compressible() {
        let codec = CpuZstd::default();
        let input = Bytes::from(vec![b'x'; 1024 * 1024]);
        let (compressed, manifest) = codec.compress(input.clone()).await.unwrap();
        assert!(
            compressed.len() < input.len() / 100,
            "expected zstd to compress 1 MiB of x bytes very well, got {} bytes",
            compressed.len()
        );
        let decompressed = codec.decompress(compressed, &manifest).await.unwrap();
        assert_eq!(decompressed, input);
    }

    #[tokio::test]
    async fn detects_corrupted_compressed_payload() {
        let codec = CpuZstd::default();
        let (compressed, manifest) = codec.compress(Bytes::from(vec![b'x'; 1024])).await.unwrap();
        let mut corrupted = compressed.to_vec();
        // The flip targets a byte just past the 4-byte zstd magic number. Fail loudly if the
        // payload is too short, rather than silently testing an uncorrupted buffer.
        assert!(
            corrupted.len() > 8,
            "compressed payload unexpectedly short: {} bytes",
            corrupted.len()
        );
        corrupted[5] ^= 0xff;
        let err = codec
            .decompress(Bytes::from(corrupted), &manifest)
            .await
            .unwrap_err();
        assert!(matches!(
            err,
            CodecError::Io(_) | CodecError::CrcMismatch { .. } | CodecError::SizeMismatch { .. }
        ));
    }

    #[tokio::test]
    async fn detects_crc_mismatch() {
        let codec = CpuZstd::default();
        let (compressed, mut manifest) = codec.compress(Bytes::from_static(SAMPLE)).await.unwrap();
        manifest.crc32c ^= 1;
        let err = codec.decompress(compressed, &manifest).await.unwrap_err();
        assert!(matches!(err, CodecError::CrcMismatch { .. }));
    }

    #[tokio::test]
    async fn rejects_compressed_size_mismatch() {
        let codec = CpuZstd::default();
        let (compressed, mut manifest) = codec.compress(Bytes::from_static(SAMPLE)).await.unwrap();
        manifest.compressed_size += 1;
        let err = codec.decompress(compressed, &manifest).await.unwrap_err();
        assert!(matches!(err, CodecError::SizeMismatch { .. }));
    }

    #[tokio::test]
    async fn rejects_codec_mismatch() {
        let codec = CpuZstd::default();
        let manifest = ChunkManifest {
            codec: CodecKind::Passthrough,
            original_size: 10,
            compressed_size: 10,
            crc32c: 0,
        };
        let err = codec
            .decompress(Bytes::from_static(b"0123456789"), &manifest)
            .await
            .unwrap_err();
        assert!(matches!(err, CodecError::CodecMismatch { .. }));
    }

    #[test]
    fn blocking_roundtrip() {
        let (compressed, manifest) = compress_blocking(SAMPLE, CpuZstd::DEFAULT_LEVEL).unwrap();
        assert_eq!(manifest.codec, CodecKind::CpuZstd);
        assert_eq!(manifest.compressed_size, compressed.len() as u64);
        let decompressed = decompress_blocking(&compressed, &manifest).unwrap();
        assert_eq!(decompressed, SAMPLE);
    }

    #[test]
    fn blocking_rejects_output_larger_than_claimed() {
        let (compressed, mut manifest) = compress_blocking(SAMPLE, CpuZstd::DEFAULT_LEVEL).unwrap();
        // Claim one byte less than the stream really holds: the bomb guard must fire.
        manifest.original_size -= 1;
        let err = decompress_blocking(&compressed, &manifest).unwrap_err();
        assert!(matches!(err, CodecError::Io(_)));
    }

    #[test]
    fn blocking_rejects_output_shorter_than_claimed() {
        let (compressed, mut manifest) = compress_blocking(SAMPLE, CpuZstd::DEFAULT_LEVEL).unwrap();
        manifest.original_size += 1;
        let err = decompress_blocking(&compressed, &manifest).unwrap_err();
        assert!(matches!(err, CodecError::SizeMismatch { .. }));
    }

    #[test]
    fn blocking_rejects_codec_mismatch() {
        let (compressed, mut manifest) = compress_blocking(SAMPLE, CpuZstd::DEFAULT_LEVEL).unwrap();
        manifest.codec = CodecKind::Passthrough;
        let err = decompress_blocking(&compressed, &manifest).unwrap_err();
        assert!(matches!(err, CodecError::CodecMismatch { .. }));
    }
}