1use bytes::Bytes;
8
9use crate::{ChunkManifest, Codec, CodecError, CodecKind};
10
/// CPU-side zstd codec.
///
/// The only state is the compression level that is handed to the zstd encoder whenever a
/// chunk is compressed.
#[derive(Clone, Debug)]
pub struct CpuZstd {
    /// Compression level passed to the zstd encoder.
    level: i32,
}
18
19impl CpuZstd {
20 pub const DEFAULT_LEVEL: i32 = 3;
21
22 pub fn new(level: i32) -> Self {
23 Self {
24 level: level.clamp(1, 22),
25 }
26 }
27}
28
29impl Default for CpuZstd {
30 fn default() -> Self {
31 Self::new(Self::DEFAULT_LEVEL)
32 }
33}
34
35#[async_trait::async_trait]
36impl Codec for CpuZstd {
37 fn kind(&self) -> CodecKind {
38 CodecKind::CpuZstd
39 }
40
41 async fn compress(&self, input: Bytes) -> Result<(Bytes, ChunkManifest), CodecError> {
42 let level = self.level;
43 let original_size = input.len() as u64;
44 let original_crc = crc32c::crc32c(&input);
45
46 let compressed = tokio::task::spawn_blocking(move || -> std::io::Result<Vec<u8>> {
47 zstd::stream::encode_all(input.as_ref(), level)
48 })
49 .await??;
50
51 let compressed_size = compressed.len() as u64;
52 let manifest = ChunkManifest {
53 codec: CodecKind::CpuZstd,
54 original_size,
55 compressed_size,
56 crc32c: original_crc,
57 };
58 Ok((Bytes::from(compressed), manifest))
59 }
60
61 async fn decompress(
62 &self,
63 input: Bytes,
64 manifest: &ChunkManifest,
65 ) -> Result<Bytes, CodecError> {
66 if manifest.codec != CodecKind::CpuZstd {
67 return Err(CodecError::CodecMismatch {
68 expected: CodecKind::CpuZstd,
69 got: manifest.codec,
70 });
71 }
72 if input.len() as u64 != manifest.compressed_size {
73 return Err(CodecError::SizeMismatch {
74 expected: manifest.compressed_size,
75 got: input.len() as u64,
76 });
77 }
78
79 let expected_crc = manifest.crc32c;
80 let expected_orig_size = manifest.original_size;
81
82 let decompressed = tokio::task::spawn_blocking(move || -> std::io::Result<Vec<u8>> {
87 use std::io::Read;
88 let limit = expected_orig_size.saturating_add(1024);
92 let mut decoder = zstd::stream::Decoder::new(input.as_ref())?;
93 let mut buf = Vec::with_capacity(expected_orig_size as usize);
94 (&mut decoder).take(limit).read_to_end(&mut buf)?;
95 if (buf.len() as u64) > expected_orig_size {
97 return Err(std::io::Error::other(format!(
98 "zstd decompression bomb detected: produced {} bytes, manifest claimed {}",
99 buf.len(),
100 expected_orig_size
101 )));
102 }
103 Ok(buf)
104 })
105 .await??;
106
107 if decompressed.len() as u64 != expected_orig_size {
108 return Err(CodecError::SizeMismatch {
109 expected: expected_orig_size,
110 got: decompressed.len() as u64,
111 });
112 }
113 let actual_crc = crc32c::crc32c(&decompressed);
114 if actual_crc != expected_crc {
115 return Err(CodecError::CrcMismatch {
116 expected: expected_crc,
117 got: actual_crc,
118 });
119 }
120 Ok(Bytes::from(decompressed))
121 }
122}
123
#[cfg(test)]
mod tests {
    use super::*;

    /// A short, repetitive payload survives compress followed by decompress unchanged, and
    /// the manifest describes it correctly.
    #[tokio::test]
    async fn roundtrip_small() {
        let codec = CpuZstd::default();
        let input = Bytes::from_static(b"hello squished s3 hello squished s3 hello squished s3");
        let (compressed, manifest) = codec.compress(input.clone()).await.unwrap();
        assert_eq!(manifest.codec, CodecKind::CpuZstd);
        assert_eq!(manifest.original_size, input.len() as u64);
        let decompressed = codec.decompress(compressed, &manifest).await.unwrap();
        assert_eq!(decompressed, input);
    }

    /// 1 MiB of a single repeated byte must shrink to under 1% of its size and round-trip.
    #[tokio::test]
    async fn roundtrip_compressible() {
        let codec = CpuZstd::default();
        let input = Bytes::from(vec![b'x'; 1024 * 1024]);
        let (compressed, manifest) = codec.compress(input.clone()).await.unwrap();
        assert!(
            compressed.len() < input.len() / 100,
            "expected zstd to compress 1 MiB of x bytes very well, got {} bytes",
            compressed.len()
        );
        let decompressed = codec.decompress(compressed, &manifest).await.unwrap();
        assert_eq!(decompressed, input);
    }

    /// Flipping a byte of the compressed payload must be reported as an error.
    #[tokio::test]
    async fn detects_corrupted_compressed_payload() {
        let codec = CpuZstd::default();
        let input = Bytes::from(vec![b'x'; 1024]);
        let (compressed, manifest) = codec.compress(input).await.unwrap();
        let mut buf = compressed.to_vec();
        // Fail loudly if the payload is too short to corrupt, rather than skipping the flip
        // and leaving the test to fail later for an unrelated-looking reason.
        assert!(
            buf.len() > 8,
            "compressed payload unexpectedly short: {} bytes",
            buf.len()
        );
        buf[5] ^= 0xff;
        let err = codec
            .decompress(Bytes::from(buf), &manifest)
            .await
            .unwrap_err();
        // Which check trips depends on how the decoder reacts to the damaged byte.
        assert!(matches!(
            err,
            CodecError::Io(_) | CodecError::CrcMismatch { .. } | CodecError::SizeMismatch { .. }
        ));
    }

    /// A manifest written by a different codec is rejected before any decoding happens.
    #[tokio::test]
    async fn rejects_codec_mismatch() {
        let codec = CpuZstd::default();
        let manifest = ChunkManifest {
            codec: CodecKind::Passthrough,
            original_size: 10,
            compressed_size: 10,
            crc32c: 0,
        };
        let err = codec
            .decompress(Bytes::from_static(b"0123456789"), &manifest)
            .await
            .unwrap_err();
        assert!(matches!(err, CodecError::CodecMismatch { .. }));
    }

    /// A payload whose length disagrees with `compressed_size` is rejected.
    #[tokio::test]
    async fn rejects_compressed_size_mismatch() {
        let codec = CpuZstd::default();
        let input = Bytes::from(vec![b'x'; 1024]);
        let (compressed, mut manifest) = codec.compress(input).await.unwrap();
        manifest.compressed_size += 1;
        let err = codec.decompress(compressed, &manifest).await.unwrap_err();
        assert!(matches!(err, CodecError::SizeMismatch { .. }));
    }

    /// Output shorter than the manifest's `original_size` is rejected.
    #[tokio::test]
    async fn rejects_original_size_mismatch() {
        let codec = CpuZstd::default();
        let input = Bytes::from(vec![b'x'; 1024]);
        let (compressed, mut manifest) = codec.compress(input).await.unwrap();
        manifest.original_size += 1;
        let err = codec.decompress(compressed, &manifest).await.unwrap_err();
        assert!(matches!(err, CodecError::SizeMismatch { .. }));
    }

    /// Correctly sized output with the wrong checksum is rejected.
    #[tokio::test]
    async fn rejects_crc_mismatch() {
        let codec = CpuZstd::default();
        let input = Bytes::from(vec![b'x'; 1024]);
        let (compressed, mut manifest) = codec.compress(input).await.unwrap();
        manifest.crc32c ^= 1;
        let err = codec.decompress(compressed, &manifest).await.unwrap_err();
        assert!(matches!(err, CodecError::CrcMismatch { .. }));
    }

    /// A stream that expands far beyond the manifest's claimed size is rejected.
    #[tokio::test]
    async fn rejects_output_larger_than_manifest() {
        let codec = CpuZstd::default();
        let input = Bytes::from(vec![b'x'; 64 * 1024]);
        let (compressed, mut manifest) = codec.compress(input).await.unwrap();
        manifest.original_size = 16;
        let err = codec.decompress(compressed, &manifest).await.unwrap_err();
        assert!(matches!(
            err,
            CodecError::Io(_) | CodecError::SizeMismatch { .. }
        ));
    }
}