1use std::fmt::{Debug, Display, Formatter};
2use std::io::Write;
3use std::mem::ManuallyDrop;
4
5use zstd_safe::{CCtx, CParameter, DCtx, InBuffer, OutBuffer, ResetDirective, get_error_name};
6
7type Result<T> = std::result::Result<T, ZstdError>;
8
/// A pending one-shot decompression: the borrowed compressed bytes plus the
/// decompressed size to plan for, when one is available.
#[derive(Clone, Copy)]
pub struct ZstdDecompress<'a> {
    /// Compressed bytes that will be decoded.
    input: &'a [u8],
    /// Expected size of the decoded data; `None` when it is not known up front.
    decompressed_size: Option<u64>,
}
14
15impl<'a> ZstdDecompress<'a> {
16 pub fn estimate_size(input: &'a [u8]) -> Result<Option<u64>> {
17 const ZSTD_CONTENTSIZE_UNKNOWN: u64 = u64::MAX;
18 const ZSTD_CONTENTSIZE_ERROR: u64 = u64::MAX - 1;
19
20 if input.is_empty() {
21 return Ok(Some(0));
22 }
23
24 let decompressed_size =
26 unsafe { zstd_sys::ZSTD_getFrameContentSize(input.as_ptr().cast(), input.len() as _) };
27
28 match decompressed_size {
29 ZSTD_CONTENTSIZE_UNKNOWN => Ok(None),
31 ZSTD_CONTENTSIZE_ERROR => Err(ZstdError::InvalidDecompressedSize {
32 decompressed_size,
33 input_size: input.len(),
34 }),
35 _ if decompressed_size > input.len().saturating_mul(10) as u64 => {
36 Err(ZstdError::SuspiciousCompressionRatio {
37 compressed_size: input.len(),
38 decompressed_size,
39 })
40 }
41 _ => Ok(Some(decompressed_size)),
42 }
43 }
44
45 pub fn begin(input: &'a [u8]) -> Result<Self> {
46 let decompressed_size = Self::estimate_size(input)?;
47 Ok(Self {
48 input,
49 decompressed_size,
50 })
51 }
52
53 pub fn with_known_size(input: &'a [u8], decompressed_size: Option<u64>) -> Self {
54 Self {
55 input,
56 decompressed_size,
57 }
58 }
59
60 pub fn decompressed_size(&self) -> Option<u64> {
61 self.decompressed_size
62 }
63
64 pub fn decompress(self, output: &mut Vec<u8>) -> Result<()> {
65 const MAX_SAFE_RESERVE: usize = 1 << 30; output.clear();
68 if self.input.is_empty() {
69 return Ok(());
70 }
71
72 if let Some(decompressed_size) = self.decompressed_size {
73 output.reserve(std::cmp::min(decompressed_size as usize, MAX_SAFE_RESERVE));
74 zstd_safe::decompress(output, self.input).map_err(ZstdError::from_raw)?;
75 Ok(())
76 } else {
77 ZstdDecompressStream::new(self.input.len())?.write(self.input, output)
78 }
79 }
80}
81
/// Test helper: decompresses `input` into a freshly allocated vector.
///
/// # Errors
///
/// Propagates errors from [`ZstdDecompress::begin`] and
/// [`ZstdDecompress::decompress`].
#[cfg(any(test, feature = "test"))]
pub fn zstd_decompress_simple(input: &[u8]) -> Result<Vec<u8>> {
    let decoder = ZstdDecompress::begin(input)?;
    let mut decoded = Vec::new();
    decoder.decompress(&mut decoded)?;
    Ok(decoded)
}
89
90pub fn zstd_compress(input: &[u8], output: &mut Vec<u8>, compression_level: i32) {
93 output.clear();
94
95 let max_compressed_size = zstd_safe::compress_bound(input.len());
97
98 output.reserve_exact(max_compressed_size);
100
101 zstd_safe::compress(output, input, compression_level).expect("buffer size is set correctly");
103}
104
/// Test helper: compresses `data` at level 3 and returns the result.
#[cfg(any(test, feature = "test"))]
pub fn zstd_compress_simple(data: &[u8]) -> Vec<u8> {
    let mut out = Vec::new();
    zstd_compress(data, &mut out, 3);
    out
}
112
113pub struct ZstdCompressedFile<W: Write> {
114 writer: W,
115 compressor: ZstdCompressStream<'static>,
116 buffer: Vec<u8>,
117}
118
119impl<W: Write> ZstdCompressedFile<W> {
120 pub fn new(writer: W, compression_level: i32, buffer_capacity: usize) -> Result<Self> {
121 Ok(Self {
122 writer,
123 buffer: Vec::with_capacity(buffer_capacity),
124 compressor: ZstdCompressStream::new(compression_level, buffer_capacity)?,
125 })
126 }
127
128 pub fn finish(mut self) -> std::io::Result<W> {
130 self.finish_impl()?;
131
132 let mut this = ManuallyDrop::new(self);
133 let _buffer = std::mem::take(&mut this.buffer);
134
135 let writer = unsafe { std::ptr::read(&this.writer) };
137
138 let _compressor = unsafe { std::ptr::read(&this.compressor) };
140
141 Ok(writer)
142 }
143
144 fn finish_impl(&mut self) -> std::io::Result<()> {
145 self.compressor.finish(&mut self.buffer)?;
146 if !self.buffer.is_empty() {
147 self.writer.write_all(&self.buffer)?;
148 self.buffer.clear();
149 }
150 Ok(())
151 }
152
153 fn flush_buf(&mut self) -> std::io::Result<()> {
154 if !self.buffer.is_empty() {
155 if self.compressor.finished {
156 return Err(std::io::Error::other("compressor already terminated"));
157 }
158
159 self.writer.write_all(&self.buffer)?;
160 self.buffer.clear();
161 }
162 Ok(())
163 }
164}
165
166impl<W: Write> Write for ZstdCompressedFile<W> {
167 fn write(&mut self, data: &[u8]) -> std::io::Result<usize> {
168 self.write_all(data).map(|_| data.len())
169 }
170
171 fn write_all(&mut self, data: &[u8]) -> std::io::Result<()> {
172 self.compressor.write(data, &mut self.buffer)?;
173 self.flush_buf()
174 }
175
176 fn flush(&mut self) -> std::io::Result<()> {
177 self.flush_buf()?;
178 self.writer.flush()
179 }
180}
181
182impl<W: Write> Drop for ZstdCompressedFile<W> {
183 fn drop(&mut self) {
184 if !self.compressor.finished {
185 let _ = self.finish_impl();
186 }
187 }
188}
189
190pub struct ZstdCompressStream<'s> {
191 cctx: CCtx<'s>,
192 finished: bool,
193 resize_by: usize,
194}
195
196impl ZstdCompressStream<'_> {
197 pub fn new(compression_level: i32, resize_by: usize) -> Result<Self> {
201 let mut cctx = CCtx::create();
202 cctx.set_parameter(CParameter::CompressionLevel(compression_level))
203 .map_err(ZstdError::from_raw)?;
204
205 Ok(Self {
206 cctx,
207 finished: false,
208 resize_by,
209 })
210 }
211
212 pub fn multithreaded(&mut self, workers: u8) -> Result<()> {
218 self.cctx
219 .set_parameter(CParameter::NbWorkers(workers as _))
220 .map_err(ZstdError::from_raw)?;
221
222 Ok(())
223 }
224
225 pub fn write(&mut self, uncompressed: &[u8], compress_buffer: &mut Vec<u8>) -> Result<()> {
226 const MODE: zstd_sys::ZSTD_EndDirective = zstd_sys::ZSTD_EndDirective::ZSTD_e_continue;
227 if self.finished {
228 return Err(ZstdError::StreamAlreadyFinished);
229 }
230
231 if uncompressed.is_empty() {
232 return Ok(());
233 }
234
235 let mut input = InBuffer::around(uncompressed);
236
237 loop {
242 let mut output = self.out_buffer(compress_buffer);
243
244 self.cctx
245 .compress_stream2(&mut output, &mut input, MODE)
246 .map_err(ZstdError::from_raw)?;
247
248 if input.pos() >= input.src.len() {
263 break Ok(());
264 }
265 }
266 }
267
268 fn out_buffer<'b>(&self, compress_buffer: &'b mut Vec<u8>) -> OutBuffer<'b, Vec<u8>> {
269 let start = compress_buffer.len();
271 if compress_buffer.spare_capacity_mut().len() < self.resize_by {
273 compress_buffer.reserve(self.resize_by);
274 }
275
276 OutBuffer::around_pos(compress_buffer, start)
277 }
278
279 pub fn finish(&mut self, compress_buffer: &mut Vec<u8>) -> Result<()> {
280 if self.finished {
281 return Ok(());
282 }
283
284 loop {
285 let mut output = self.out_buffer(compress_buffer);
286
287 let remaining = self
288 .cctx
289 .end_stream(&mut output)
290 .map_err(ZstdError::from_raw)?;
291
292 if remaining == 0 {
293 self.finished = true;
294 return Ok(());
295 }
296 }
297 }
298
299 pub fn reset(&mut self) -> Result<()> {
302 self.cctx
303 .reset(ResetDirective::SessionOnly)
304 .map_err(ZstdError::from_raw)?;
305 self.finished = false;
306
307 Ok(())
308 }
309}
310
311pub struct ZstdDecompressStream<'s> {
312 dctx: DCtx<'s>,
313 resize_by: usize,
314 finished: bool,
315}
316
317impl ZstdDecompressStream<'_> {
318 pub fn new(resize_by: usize) -> Result<Self> {
319 let mut dctx = DCtx::create();
320 dctx.init().map_err(ZstdError::from_raw)?;
321
322 Ok(Self {
323 dctx,
324 resize_by,
325 finished: false,
326 })
327 }
328
329 pub fn write(&mut self, compressed: &[u8], decompress_buffer: &mut Vec<u8>) -> Result<()> {
330 if self.finished {
331 return Err(ZstdError::StreamAlreadyFinished);
332 }
333 if compressed.is_empty() {
334 return Ok(());
335 }
336
337 let mut input = InBuffer::around(compressed);
338
339 loop {
340 let start = decompress_buffer.len();
341 if decompress_buffer.spare_capacity_mut().len() < self.resize_by {
342 decompress_buffer.reserve(self.resize_by);
343 }
344
345 if input.pos() == input.src.len() {
347 break Ok(());
348 }
349
350 let mut output = OutBuffer::around_pos(decompress_buffer, start);
351 let read = self
352 .dctx
353 .decompress_stream(&mut output, &mut input)
354 .map_err(ZstdError::from_raw)?;
355
356 if read == 0 {
358 self.finished = true;
359 break Ok(());
360 }
361 }
362 }
363
364 pub fn reset(&mut self) -> Result<()> {
367 self.dctx
368 .reset(ResetDirective::SessionOnly)
369 .map_err(ZstdError::from_raw)?;
370 self.finished = false;
371
372 Ok(())
373 }
374}
375
376#[derive(thiserror::Error, Debug)]
377pub enum ZstdError {
378 #[error("Zstd error: {0}")]
379 Raw(#[from] RawCompressorError),
380
381 #[error(
382 "Suspicious compression ratio detected: compressed size: {compressed_size}, decompressed size: {decompressed_size}"
383 )]
384 SuspiciousCompressionRatio {
385 compressed_size: usize,
386 decompressed_size: u64,
387 },
388
389 #[error("Invalid decompressed size: {decompressed_size}, input size: {input_size}")]
390 InvalidDecompressedSize {
391 decompressed_size: u64,
392 input_size: usize,
393 },
394
395 #[error("Stream already finished")]
396 StreamAlreadyFinished,
397}
398
399impl From<ZstdError> for std::io::Error {
400 fn from(value: ZstdError) -> Self {
401 std::io::Error::other(value)
402 }
403}
404
405impl ZstdError {
406 fn from_raw(code: usize) -> Self {
407 ZstdError::Raw(RawCompressorError { code })
408 }
409}
410
/// A raw error code as returned by the zstd library.
pub struct RawCompressorError {
    /// The error code; turned into a readable name when formatted.
    code: usize,
}
414
415impl Debug for RawCompressorError {
416 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
417 f.write_str(get_error_name(self.code))
418 }
419}
420
421impl Display for RawCompressorError {
422 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
423 f.write_str(get_error_name(self.code))
424 }
425}
426
427impl std::error::Error for RawCompressorError {}
428
#[cfg(test)]
mod tests {
    use std::io::{Read, Seek};

    use rand::prelude::StdRng;
    use rand::{RngCore, SeedableRng};

    use super::*;

    /// Input sizes exercised by the streaming round-trip tests.
    const STREAMING_SIZES: [usize; 5] = [10, 1021, 1024, 1024 * 1024, 10 * 1024 * 1024];

    /// One-shot compression followed by decompression must round-trip, and a
    /// frame with garbage in front of it must be rejected.
    #[test]
    fn test_zstd_compress_decompress() {
        let seed = 42;
        let mut rng = StdRng::seed_from_u64(seed);

        for size in [10, 1024, 1024 * 1024, 10 * 1024 * 1024] {
            let mut input = vec![0; size];
            rng.fill_bytes(input.as_mut_slice());

            let mut compressed = Vec::new();
            zstd_compress(&input, &mut compressed, 3);

            let decompressed = zstd_decompress_simple(&compressed).unwrap();
            assert_eq!(input, decompressed.as_slice());
        }

        let input = b"Hello, world!";
        let mut compressed = Vec::new();
        zstd_compress(input, &mut compressed, 3);
        let decompressed = zstd_decompress_simple(&compressed).unwrap();
        assert_eq!(input, decompressed.as_slice());

        // Prefixing the frame with junk must make decompression fail.
        let mut input = b"bad".to_vec();
        input.extend_from_slice(&compressed);
        zstd_decompress_simple(&input).unwrap_err();
    }

    /// Streaming round-trip with the default (single-threaded) compressor.
    #[test]
    fn test_streaming() {
        run_streaming_cases(false);
    }

    /// Streaming round-trip with worker threads enabled.
    #[test]
    fn test_streaming_mt() {
        run_streaming_cases(true);
    }

    /// Runs the shared set of streaming inputs through `check_compression`.
    fn run_streaming_cases(multithreaded: bool) {
        for size in STREAMING_SIZES {
            check_compression(vec![0; size], multithreaded);
        }

        // `as u8` deliberately keeps only the low byte of the product.
        let pseudo_random: Vec<u8> = (0..1024u32).map(|i| i.wrapping_mul(13) as u8).collect();
        check_compression(pseudo_random, multithreaded);

        let hello_world = b"Hello, world!".repeat(1023);
        check_compression(hello_world, multithreaded);
    }

    /// Compresses `input` in 1 KiB chunks through the streaming compressor and
    /// checks that both the one-shot and the streaming decoder restore it.
    fn check_compression(input: Vec<u8>, multithreaded: bool) {
        let mut compressor = ZstdCompressStream::new(3, 128).unwrap();
        if multithreaded {
            compressor.multithreaded(4).unwrap();
        }

        let mut compress_buffer = Vec::new();
        let mut result_buf = Vec::new();

        for chunk in input.chunks(1024) {
            compressor.write(chunk, &mut compress_buffer).unwrap();
            // Drain the staging buffer into the result once it grows past 1 KiB.
            if compress_buffer.len() > 1024 {
                result_buf.extend_from_slice(&compress_buffer);
                compress_buffer.clear();
            }
        }
        compressor.finish(&mut compress_buffer).unwrap();
        result_buf.extend_from_slice(&compress_buffer);

        let decompressed = zstd_decompress_simple(&result_buf).unwrap();
        assert_eq!(input, decompressed);

        let mut streaming_decoder = ZstdDecompressStream::new(128).unwrap();
        let mut decompressed = Vec::new();
        streaming_decoder
            .write(&result_buf, &mut decompressed)
            .unwrap();
        assert_eq!(input, decompressed);
    }

    /// None of the crafted inputs may decompress successfully.
    #[test]
    fn test_dos() {
        for malicious in malicious_files() {
            assert!(
                zstd_decompress_simple(&malicious).is_err(),
                "Malicious file was decompressed successfully"
            );
        }
    }

    /// Builds the crafted inputs used by `test_dos`.
    fn malicious_files() -> Vec<Vec<u8>> {
        let large_content = vec![b'A'; 1_000_000];

        let truncated_content = b"This file will be truncated";
        let mut truncated_compressed = encode_all(truncated_content.as_slice(), 3).unwrap();
        truncated_compressed.truncate(truncated_compressed.len() / 2);

        vec![
            // Header rewritten with a huge size over tiny content.
            create_malicious_zstd(1_000_000_000, b"Small content"),
            // Header rewritten with a size smaller than the content.
            create_malicious_zstd(10, b"This content is actually longer than claimed"),
            // Highly compressible content with its header rewritten.
            create_malicious_zstd(large_content.len() as u64, &large_content),
            // Hand-written frame prefix followed by 0xFF bytes.
            vec![
                0x28, 0xB5, 0x2F, 0xFD, 0x40, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
            ],
            // Valid frame cut in half.
            truncated_compressed,
        ]
    }

    /// One-shot compression helper returning the compressed bytes.
    fn encode_all(input: &[u8], level: i32) -> Result<Vec<u8>> {
        let mut compressed = Vec::new();
        zstd_compress(input, &mut compressed, level);
        Ok(compressed)
    }

    /// Compresses `actual_content` and then tampers with the frame header:
    /// the low bits of byte 4 are overwritten and bytes `5..9` are replaced by
    /// the 8 little-endian bytes of `content_size`.
    ///
    /// NOTE(review): presumably meant to forge the declared content size —
    /// confirm the bit layout against the zstd frame format.
    fn create_malicious_zstd(content_size: u64, actual_content: &[u8]) -> Vec<u8> {
        let mut compressed = encode_all(actual_content, 3).unwrap();

        compressed[4] = (compressed[4] & 0b11000000) | 0b00000011;

        compressed.splice(5..9, content_size.to_le_bytes());

        compressed
    }

    /// Feeding a one-shot-compressed frame to the streaming decoder in 1 KiB
    /// chunks must restore the original data.
    #[test]
    fn test_decode_chunked() {
        let mut rng = StdRng::seed_from_u64(42);
        let mut data = Vec::with_capacity(10 * 1024 * 1024);
        let mut pseudo_rand_pattern = vec![0; 1024 * 1024];
        rng.fill_bytes(&mut pseudo_rand_pattern);

        for _ in 0..10 {
            data.extend_from_slice(&pseudo_rand_pattern);
        }

        let compressed = encode_all(&data, 3).unwrap();
        let mut decompressed = Vec::new();

        let mut decompressor = ZstdDecompressStream::new(128).unwrap();
        for chunk in compressed.chunks(1024) {
            decompressor.write(chunk, &mut decompressed).unwrap();
        }

        assert_eq!(data, decompressed);
    }

    /// Writes through `BufWriter` -> `ZstdCompressedFile` -> temp file with
    /// several preallocated file sizes, then decodes the file in chunks.
    #[test]
    fn buffered_compress_decompress() {
        const BUFFER_LEN: usize = 64 << 20;

        let mut rng = StdRng::seed_from_u64(42);
        let mut original = vec![0; 4 << 20];
        rng.fill_bytes(&mut original);

        for prealloc in [1024, 4194409, BUFFER_LEN] {
            let mut compressed = Vec::new();
            {
                let file = tempfile::tempfile().unwrap();
                file.set_len(prealloc as u64).unwrap();
                let file = ZstdCompressedFile::new(file, 9, BUFFER_LEN).unwrap();

                let mut buffer = std::io::BufWriter::with_capacity(BUFFER_LEN, file);
                for chunk in original.chunks(2048) {
                    buffer.write_all(chunk).unwrap();
                }

                let file = buffer.into_inner().map_err(|e| e.into_error()).unwrap();
                let mut file = file.finish().unwrap();
                file.flush().unwrap();

                // Cut the file at the write position to drop leftover
                // preallocated bytes, then rewind for reading.
                let file_size = file.stream_position().unwrap();
                file.set_len(file_size).unwrap();
                file.seek(std::io::SeekFrom::Start(0)).unwrap();

                #[allow(clippy::verbose_file_reads)]
                file.read_to_end(&mut compressed).unwrap();
            }

            {
                let mut stream = ZstdDecompressStream::new(1 << 20).unwrap();

                let mut decompressed = Vec::new();
                let mut decompressed_chunk = Vec::new();
                for chunk in compressed.chunks(1 << 20) {
                    decompressed_chunk.clear();
                    stream.write(chunk, &mut decompressed_chunk).unwrap();

                    decompressed.extend_from_slice(&decompressed_chunk);
                }

                assert_eq!(decompressed, original);
            }
        }
    }
}