tycho_util/
compression.rs

1use std::fmt::{Debug, Display, Formatter};
2use std::io::Write;
3use std::mem::ManuallyDrop;
4
5use zstd_safe::{CCtx, CParameter, DCtx, InBuffer, OutBuffer, ResetDirective, get_error_name};
6
7type Result<T> = std::result::Result<T, ZstdError>;
8
/// A one-shot zstd decompression job over a borrowed input buffer.
///
/// Build it with [`ZstdDecompress::begin`] (reads the size from the frame header) or
/// [`ZstdDecompress::with_known_size`] (trusts the caller-provided size).
#[derive(Copy, Clone)]
pub struct ZstdDecompress<'a> {
    /// Compressed bytes to decode.
    input: &'a [u8],
    /// Expected decompressed length, or `None` when it is not known up front.
    decompressed_size: Option<u64>,
}
14
15impl<'a> ZstdDecompress<'a> {
16    pub fn estimate_size(input: &'a [u8]) -> Result<Option<u64>> {
17        const ZSTD_CONTENTSIZE_UNKNOWN: u64 = u64::MAX;
18        const ZSTD_CONTENTSIZE_ERROR: u64 = u64::MAX - 1;
19
20        if input.is_empty() {
21            return Ok(Some(0));
22        }
23
24        // Try to decompress with known size from header
25        let decompressed_size =
26            unsafe { zstd_sys::ZSTD_getFrameContentSize(input.as_ptr().cast(), input.len() as _) };
27
28        match decompressed_size {
29            // TODO: Maybe forbid decompress when unknown?
30            ZSTD_CONTENTSIZE_UNKNOWN => Ok(None),
31            ZSTD_CONTENTSIZE_ERROR => Err(ZstdError::InvalidDecompressedSize {
32                decompressed_size,
33                input_size: input.len(),
34            }),
35            _ if decompressed_size > input.len().saturating_mul(10) as u64 => {
36                Err(ZstdError::SuspiciousCompressionRatio {
37                    compressed_size: input.len(),
38                    decompressed_size,
39                })
40            }
41            _ => Ok(Some(decompressed_size)),
42        }
43    }
44
45    pub fn begin(input: &'a [u8]) -> Result<Self> {
46        let decompressed_size = Self::estimate_size(input)?;
47        Ok(Self {
48            input,
49            decompressed_size,
50        })
51    }
52
53    pub fn with_known_size(input: &'a [u8], decompressed_size: Option<u64>) -> Self {
54        Self {
55            input,
56            decompressed_size,
57        }
58    }
59
60    pub fn decompressed_size(&self) -> Option<u64> {
61        self.decompressed_size
62    }
63
64    pub fn decompress(self, output: &mut Vec<u8>) -> Result<()> {
65        const MAX_SAFE_RESERVE: usize = 1 << 30; // 1 GB
66
67        output.clear();
68        if self.input.is_empty() {
69            return Ok(());
70        }
71
72        if let Some(decompressed_size) = self.decompressed_size {
73            output.reserve(std::cmp::min(decompressed_size as usize, MAX_SAFE_RESERVE));
74            zstd_safe::decompress(output, self.input).map_err(ZstdError::from_raw)?;
75            Ok(())
76        } else {
77            ZstdDecompressStream::new(self.input.len())?.write(self.input, output)
78        }
79    }
80}
81
/// Decompresses `input` into a freshly allocated vector.
///
/// When the frame header records the decompressed size it is validated and used for a
/// single-pass decode; frames without a recorded size go through the streaming decoder.
/// There is no retry: any error is returned as-is.
#[cfg(any(test, feature = "test"))]
pub fn zstd_decompress_simple(input: &[u8]) -> Result<Vec<u8>> {
    let job = ZstdDecompress::begin(input)?;
    let mut decompressed = Vec::new();
    job.decompress(&mut decompressed).map(|()| decompressed)
}
89
90/// Compresses the input data using zstd with the specified compression level.
91/// Writes decompressed size into the output buffer.
92pub fn zstd_compress(input: &[u8], output: &mut Vec<u8>, compression_level: i32) {
93    output.clear();
94
95    // Calculate the maximum compressed size
96    let max_compressed_size = zstd_safe::compress_bound(input.len());
97
98    // Resize the output vector to accommodate the maximum possible compressed size
99    output.reserve_exact(max_compressed_size);
100
101    // Perform the compression
102    zstd_safe::compress(output, input, compression_level).expect("buffer size is set correctly");
103}
104
/// Test helper: compresses `data` at level 3 into a new vector.
#[cfg(any(test, feature = "test"))]
pub fn zstd_compress_simple(data: &[u8]) -> Vec<u8> {
    const LEVEL: i32 = 3;
    let mut out = Vec::new();
    zstd_compress(data, &mut out, LEVEL);
    out
}
112
113pub struct ZstdCompressedFile<W: Write> {
114    writer: W,
115    compressor: ZstdCompressStream<'static>,
116    buffer: Vec<u8>,
117}
118
119impl<W: Write> ZstdCompressedFile<W> {
120    pub fn new(writer: W, compression_level: i32, buffer_capacity: usize) -> Result<Self> {
121        Ok(Self {
122            writer,
123            buffer: Vec::with_capacity(buffer_capacity),
124            compressor: ZstdCompressStream::new(compression_level, buffer_capacity)?,
125        })
126    }
127
128    /// Terminates the compression stream. All subsequent writes will fail.
129    pub fn finish(mut self) -> std::io::Result<W> {
130        self.finish_impl()?;
131
132        let mut this = ManuallyDrop::new(self);
133        let _buffer = std::mem::take(&mut this.buffer);
134
135        // SAFETY: double-drops are prevented by putting `this` in a ManuallyDrop that is never dropped
136        let writer = unsafe { std::ptr::read(&this.writer) };
137
138        // SAFETY: double-drops are prevented by putting `this` in a ManuallyDrop that is never dropped
139        let _compressor = unsafe { std::ptr::read(&this.compressor) };
140
141        Ok(writer)
142    }
143
144    fn finish_impl(&mut self) -> std::io::Result<()> {
145        self.compressor.finish(&mut self.buffer)?;
146        if !self.buffer.is_empty() {
147            self.writer.write_all(&self.buffer)?;
148            self.buffer.clear();
149        }
150        Ok(())
151    }
152
153    fn flush_buf(&mut self) -> std::io::Result<()> {
154        if !self.buffer.is_empty() {
155            if self.compressor.finished {
156                return Err(std::io::Error::other("compressor already terminated"));
157            }
158
159            self.writer.write_all(&self.buffer)?;
160            self.buffer.clear();
161        }
162        Ok(())
163    }
164}
165
166impl<W: Write> Write for ZstdCompressedFile<W> {
167    fn write(&mut self, data: &[u8]) -> std::io::Result<usize> {
168        self.write_all(data).map(|_| data.len())
169    }
170
171    fn write_all(&mut self, data: &[u8]) -> std::io::Result<()> {
172        self.compressor.write(data, &mut self.buffer)?;
173        self.flush_buf()
174    }
175
176    fn flush(&mut self) -> std::io::Result<()> {
177        self.flush_buf()?;
178        self.writer.flush()
179    }
180}
181
182impl<W: Write> Drop for ZstdCompressedFile<W> {
183    fn drop(&mut self) {
184        if !self.compressor.finished {
185            let _ = self.finish_impl();
186        }
187    }
188}
189
190pub struct ZstdCompressStream<'s> {
191    cctx: CCtx<'s>,
192    finished: bool,
193    resize_by: usize,
194}
195
196impl ZstdCompressStream<'_> {
197    /// # Arguments
198    /// * `compression_level` - The compression level to use.
199    /// * `resize_by` - The amount to resize the buffer by when it runs out of space.
200    pub fn new(compression_level: i32, resize_by: usize) -> Result<Self> {
201        let mut cctx = CCtx::create();
202        cctx.set_parameter(CParameter::CompressionLevel(compression_level))
203            .map_err(ZstdError::from_raw)?;
204
205        Ok(Self {
206            cctx,
207            finished: false,
208            resize_by,
209        })
210    }
211
212    /// Sets the number of worker threads to use for compression.
213    /// Can be called at any time.
214    /// Setting `workers` to `>= 1` will make compression asynchronous.
215    /// All compression will be done in background threads.
216    /// So it's important to call `finish` before dropping the stream.
217    pub fn multithreaded(&mut self, workers: u8) -> Result<()> {
218        self.cctx
219            .set_parameter(CParameter::NbWorkers(workers as _))
220            .map_err(ZstdError::from_raw)?;
221
222        Ok(())
223    }
224
225    pub fn write(&mut self, uncompressed: &[u8], compress_buffer: &mut Vec<u8>) -> Result<()> {
226        const MODE: zstd_sys::ZSTD_EndDirective = zstd_sys::ZSTD_EndDirective::ZSTD_e_continue;
227        if self.finished {
228            return Err(ZstdError::StreamAlreadyFinished);
229        }
230
231        if uncompressed.is_empty() {
232            return Ok(());
233        }
234
235        let mut input = InBuffer::around(uncompressed);
236
237        // we check that there is spare space in the buffer, if it's true we fill spare space with zeroes
238        // and then we compress the data
239        // in the end of loop we resize the buffer to the actual size
240
241        loop {
242            let mut output = self.out_buffer(compress_buffer);
243
244            self.cctx
245                .compress_stream2(&mut output, &mut input, MODE)
246                .map_err(ZstdError::from_raw)?;
247
248            // from the https://facebook.github.io/zstd/zstd_manual.html
249            //
250            //   Select how many threads will be spawned to compress in parallel.
251            //   When nbWorkers >= 1, triggers asynchronous mode when invoking ZSTD_compressStream*() :
252            //   ZSTD_compressStream*() consumes input and flush output if possible, but immediately gives back control to caller,
253            //   while compression is performed in parallel, within worker thread(s).
254            //   (note : a strong exception to this rule is when first invocation of ZSTD_compressStream2() sets ZSTD_e_end :
255            //    in which case, ZSTD_compressStream2() delegates to ZSTD_compress2(), which is always a blocking call).
256            //   More workers improve speed, but also increase memory usage.
257            //   Default value is `0`, aka "single-threaded mode" : no worker is spawned,
258            //   compression is performed inside Caller's thread, and all invocations are blocking
259
260            // For multithreaded compression, we should continue if there's more input to process
261
262            if input.pos() >= input.src.len() {
263                break Ok(());
264            }
265        }
266    }
267
268    fn out_buffer<'b>(&self, compress_buffer: &'b mut Vec<u8>) -> OutBuffer<'b, Vec<u8>> {
269        // Ensure there's enough space in the output buffer
270        let start = compress_buffer.len();
271        // check if there is enough unused space in the buffer
272        if compress_buffer.spare_capacity_mut().len() < self.resize_by {
273            compress_buffer.reserve(self.resize_by);
274        }
275
276        OutBuffer::around_pos(compress_buffer, start)
277    }
278
279    pub fn finish(&mut self, compress_buffer: &mut Vec<u8>) -> Result<()> {
280        if self.finished {
281            return Ok(());
282        }
283
284        loop {
285            let mut output = self.out_buffer(compress_buffer);
286
287            let remaining = self
288                .cctx
289                .end_stream(&mut output)
290                .map_err(ZstdError::from_raw)?;
291
292            if remaining == 0 {
293                self.finished = true;
294                return Ok(());
295            }
296        }
297    }
298
299    /// Resets the compression context.
300    /// You can again write data to the stream after calling this method.
301    pub fn reset(&mut self) -> Result<()> {
302        self.cctx
303            .reset(ResetDirective::SessionOnly)
304            .map_err(ZstdError::from_raw)?;
305        self.finished = false;
306
307        Ok(())
308    }
309}
310
311pub struct ZstdDecompressStream<'s> {
312    dctx: DCtx<'s>,
313    resize_by: usize,
314    finished: bool,
315}
316
317impl ZstdDecompressStream<'_> {
318    pub fn new(resize_by: usize) -> Result<Self> {
319        let mut dctx = DCtx::create();
320        dctx.init().map_err(ZstdError::from_raw)?;
321
322        Ok(Self {
323            dctx,
324            resize_by,
325            finished: false,
326        })
327    }
328
329    pub fn write(&mut self, compressed: &[u8], decompress_buffer: &mut Vec<u8>) -> Result<()> {
330        if self.finished {
331            return Err(ZstdError::StreamAlreadyFinished);
332        }
333        if compressed.is_empty() {
334            return Ok(());
335        }
336
337        let mut input = InBuffer::around(compressed);
338
339        loop {
340            let start = decompress_buffer.len();
341            if decompress_buffer.spare_capacity_mut().len() < self.resize_by {
342                decompress_buffer.reserve(self.resize_by);
343            }
344
345            // all input was read, chunky boy wants more
346            if input.pos() == input.src.len() {
347                break Ok(());
348            }
349
350            let mut output = OutBuffer::around_pos(decompress_buffer, start);
351            let read = self
352                .dctx
353                .decompress_stream(&mut output, &mut input)
354                .map_err(ZstdError::from_raw)?;
355
356            // when a frame is completely decoded and fully flushed,
357            if read == 0 {
358                self.finished = true;
359                break Ok(());
360            }
361        }
362    }
363
364    /// Resets the decompression context.
365    /// You can again write data to the stream after calling this method.
366    pub fn reset(&mut self) -> Result<()> {
367        self.dctx
368            .reset(ResetDirective::SessionOnly)
369            .map_err(ZstdError::from_raw)?;
370        self.finished = false;
371
372        Ok(())
373    }
374}
375
376#[derive(thiserror::Error, Debug)]
377pub enum ZstdError {
378    #[error("Zstd error: {0}")]
379    Raw(#[from] RawCompressorError),
380
381    #[error(
382        "Suspicious compression ratio detected: compressed size: {compressed_size}, decompressed size: {decompressed_size}"
383    )]
384    SuspiciousCompressionRatio {
385        compressed_size: usize,
386        decompressed_size: u64,
387    },
388
389    #[error("Invalid decompressed size: {decompressed_size}, input size: {input_size}")]
390    InvalidDecompressedSize {
391        decompressed_size: u64,
392        input_size: usize,
393    },
394
395    #[error("Stream already finished")]
396    StreamAlreadyFinished,
397}
398
399impl From<ZstdError> for std::io::Error {
400    fn from(value: ZstdError) -> Self {
401        std::io::Error::other(value)
402    }
403}
404
405impl ZstdError {
406    fn from_raw(code: usize) -> Self {
407        ZstdError::Raw(RawCompressorError { code })
408    }
409}
410
/// A raw zstd error code, displayed via zstd's own error-name table.
pub struct RawCompressorError {
    /// Error code exactly as returned by the zstd C API.
    code: usize,
}
414
415impl Debug for RawCompressorError {
416    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
417        f.write_str(get_error_name(self.code))
418    }
419}
420
421impl Display for RawCompressorError {
422    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
423        f.write_str(get_error_name(self.code))
424    }
425}
426
427impl std::error::Error for RawCompressorError {}
428
// Unit tests for this module: one-shot and streaming round trips, chunked decoding,
// rejection of malformed/malicious frames, and the file-backed `ZstdCompressedFile`.
#[cfg(test)]
mod tests {
    use std::io::{Read, Seek};

    use rand::prelude::StdRng;
    use rand::{RngCore, SeedableRng};

    use super::*;

    /// One-shot `zstd_compress` followed by `zstd_decompress_simple` reproduces the input,
    /// and the same frame with garbage prepended is rejected.
    #[test]
    fn test_zstd_compress_decompress() {
        let seed = 42; // I've asked the universe
        let mut rng = StdRng::seed_from_u64(seed);

        for size in [10, 1024, 1024 * 1024, 10 * 1024 * 1024] {
            let mut input = vec![0; size];
            // Random bytes keep the ratio under the 10x guard in `estimate_size`;
            // zero-filled input would be rejected as a suspicious compression ratio.
            rng.fill_bytes(input.as_mut_slice());

            let mut compressed = Vec::new();
            zstd_compress(&input, &mut compressed, 3);

            let decompressed = zstd_decompress_simple(&compressed).unwrap();
            assert_eq!(input, decompressed.as_slice());
        }

        let input = b"Hello, world!";
        let mut compressed = Vec::new();
        zstd_compress(input, &mut compressed, 3);
        let decompressed = zstd_decompress_simple(&compressed).unwrap();
        assert_eq!(input, decompressed.as_slice());

        // Prepending bytes means the input no longer starts with a valid frame.
        let mut input = b"bad".to_vec();
        input.extend_from_slice(&compressed);
        zstd_decompress_simple(&input).unwrap_err();
    }

    /// Single-threaded streaming compression round-trips through both decoders.
    #[test]
    fn test_streaming() {
        for size in [10usize, 1021, 1024, 1024 * 1024, 10 * 1024 * 1024] {
            // NOTE(review): zero-filled input passes here presumably because streaming
            // frames carry no content size, so `estimate_size` returns `None` and the
            // ratio guard is skipped — confirm.
            let input = vec![0; size];
            check_compression(input, false);

            // NOTE: streaming compression will give slightly different results with one shot compression,
            // so we can't compare the compressed data directly, only that decompression works
        }

        let pseudo_random = (0..1024)
            .map(|i: u32| i.overflowing_mul(13).0 as u8)
            .collect::<Vec<_>>();
        check_compression(pseudo_random, false);

        let hello_world = Vec::from_iter(b"Hello, world!".repeat(1023));
        check_compression(hello_world, false);
    }

    // Same cases as `test_streaming`, but with 4 compression workers. This is split into
    // a separate test so neither test runs too long.
    #[test]
    fn test_steaming_mt() {
        for size in [10usize, 1021, 1024, 1024 * 1024, 10 * 1024 * 1024] {
            let input = vec![0; size];
            check_compression(input, true);

            // NOTE: streaming compression will give slightly different results with one shot compression,
            // so we can't compare the compressed data directly, only that decompression works
        }

        let pseudo_random = (0..1024)
            .map(|i: u32| i.overflowing_mul(13).0 as u8)
            .collect::<Vec<_>>();
        check_compression(pseudo_random, true);

        let hello_world = Vec::from_iter(b"Hello, world!".repeat(1023));
        check_compression(hello_world, true);
    }

    /// Compresses `input` in 1 KiB chunks with `ZstdCompressStream`. It drains the output
    /// buffer whenever it exceeds 1 KiB, then checks that both `zstd_decompress_simple` and
    /// a single `ZstdDecompressStream::write` call restore the original bytes.
    fn check_compression(input: Vec<u8>, multithreaded: bool) {
        let mut compressor = ZstdCompressStream::new(3, 128).unwrap();
        if multithreaded {
            compressor.multithreaded(4).unwrap();
        }

        let mut compress_buffer = Vec::new();
        let mut result_buf = Vec::new();

        for chunk in input.chunks(1024) {
            compressor.write(chunk, &mut compress_buffer).unwrap();
            if compress_buffer.len() > 1024 {
                result_buf.extend_from_slice(&compress_buffer);
                compress_buffer.clear();
            }
        }
        compressor.finish(&mut compress_buffer).unwrap();
        result_buf.extend_from_slice(&compress_buffer);

        let decompressed = zstd_decompress_simple(&result_buf).unwrap();
        assert_eq!(input, decompressed);

        let decompressed = {
            let mut streaming_decoder = ZstdDecompressStream::new(128).unwrap();
            let mut decompressed = Vec::new();
            streaming_decoder
                .write(&result_buf, &mut decompressed)
                .unwrap();
            decompressed
        };
        assert_eq!(input, decompressed);
    }

    /// None of the crafted inputs from `malicious_files` may decompress successfully.
    #[test]
    fn test_dos() {
        for malicious in malicious_files() {
            if zstd_decompress_simple(&malicious).is_ok() {
                panic!("Malicious file was decompressed successfully");
            }
        }
    }

    /// Builds a set of inputs that must be rejected by `zstd_decompress_simple`.
    fn malicious_files() -> Vec<Vec<u8>> {
        let mut files = Vec::new();

        // 1. Lie about content size (much larger)
        files.push(create_malicious_zstd(1_000_000_000, b"Small content"));

        // 2. Lie about content size (much smaller)
        files.push(create_malicious_zstd(
            10,
            b"This content is actually longer than claimed",
        ));

        // 3. Extremely high compression ratio
        let large_content = vec![b'A'; 1_000_000];
        files.push(create_malicious_zstd(
            large_content.len() as u64,
            &large_content,
        ));

        // 4. Invalid content size (zstd magic number followed by a bogus header)
        files.push(vec![
            0x28, 0xB5, 0x2F, 0xFD, 0x40, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
        ]);

        // 5. Truncated file
        let truncated_content = b"This file will be truncated";
        let mut truncated_compressed = encode_all(truncated_content.as_slice(), 3).unwrap();
        truncated_compressed.truncate(truncated_compressed.len() / 2);
        files.push(truncated_compressed);

        files
    }

    /// Test shim: one-shot compression into a new vector (never fails).
    fn encode_all(input: &[u8], level: i32) -> Result<Vec<u8>> {
        let mut compressed = Vec::new();
        zstd_compress(input, &mut compressed, level);
        Ok(compressed)
    }

    /// Compresses `actual_content`, then tampers with the frame header bytes to
    /// advertise `content_size`.
    fn create_malicious_zstd(content_size: u64, actual_content: &[u8]) -> Vec<u8> {
        let mut compressed = encode_all(actual_content, 3).unwrap();

        // Rewrite the Frame_Header_Descriptor (byte 4, right after the magic number).
        // NOTE(review): per RFC 8878 the Frame_Content_Size_flag is bits 7-6. This line
        // keeps those bits, sets bits 1-0 (Dictionary_ID_flag = 3), and clears the
        // Single_Segment and Content_Checksum flags. So it does not switch to an 8-byte
        // content size as the original comment claimed; it just yields a malformed
        // header, which is still rejected in `test_dos`.
        compressed[4] = (compressed[4] & 0b11000000) | 0b00000011;

        // Replace bytes 5..9 with the 8-byte little-endian `content_size`
        // (the frame grows by 4 bytes).
        compressed.splice(5..9, content_size.to_le_bytes());

        compressed
    }

    /// Feeding a one-shot frame to `ZstdDecompressStream` in 1 KiB chunks restores the data.
    #[test]
    fn test_decode_chunked() {
        let mut rng = StdRng::seed_from_u64(42);
        let mut data = Vec::with_capacity(10 * 1024 * 1024);
        let mut pseudo_rand_patern = vec![0; 1024 * 1024];
        rng.fill_bytes(&mut pseudo_rand_patern);

        // Repeat one random 1 MiB pattern so the data compresses without hitting the
        // ratio guard (the stream decoder does not apply it anyway).
        for _ in 0..10 {
            data.extend_from_slice(&pseudo_rand_patern);
        }

        let compressed = encode_all(&data, 3).unwrap();
        let mut decompressed = Vec::new();

        let mut decompressor = ZstdDecompressStream::new(128).unwrap();
        for chunk in compressed.chunks(1024) {
            decompressor.write(chunk, &mut decompressed).unwrap();
        }

        assert_eq!(data, decompressed);
    }

    /// `ZstdCompressedFile` behind a `BufWriter` writes a frame to a preallocated temp
    /// file that decodes back to the original bytes.
    #[test]
    fn buffered_compress_decompress() {
        const BUFFER_LEN: usize = 64 << 20; // 64 MB

        // Prepare
        let mut rng = StdRng::seed_from_u64(42);
        let mut original = vec![0; 4 << 20];
        rng.fill_bytes(&mut original);

        // Try each kind of prealloc: small, exact, huge
        for prealloc in [1024, 4194409, BUFFER_LEN] {
            // Compress
            let mut compressed = Vec::new();
            {
                let file = tempfile::tempfile().unwrap();
                file.set_len(prealloc as _).unwrap();
                let file = ZstdCompressedFile::new(file, 9, BUFFER_LEN).unwrap();

                let mut buffer = std::io::BufWriter::with_capacity(BUFFER_LEN, file);
                for chunk in original.chunks(2048) {
                    buffer.write_all(chunk).unwrap();
                }

                let file = buffer.into_inner().map_err(|e| e.into_error()).unwrap();
                let mut file = file.finish().unwrap();
                file.flush().unwrap();

                // The cursor now sits right after the last compressed byte; drop any
                // preallocated tail beyond it.
                let file_size = file.stream_position().unwrap();
                file.set_len(file_size).unwrap(); // <- Truncate after prealloc

                file.seek(std::io::SeekFrom::Start(0)).unwrap();

                #[allow(clippy::verbose_file_reads)]
                file.read_to_end(&mut compressed).unwrap();
            }

            // Decompress
            {
                let mut stream = ZstdDecompressStream::new(1 << 20).unwrap();

                let mut decompressed = Vec::new();
                let mut decompressed_chunk = Vec::new();
                for chunk in compressed.chunks(1 << 20) {
                    decompressed_chunk.clear();
                    stream.write(chunk, &mut decompressed_chunk).unwrap();

                    decompressed.extend_from_slice(&decompressed_chunk);
                }

                assert_eq!(decompressed, original);
            }
        }
    }
}