Skip to main content

structured_zstd/encoding/
frame_compressor.rs

1//! Utilities and interfaces for encoding an entire frame. Allows reusing resources
2
3use alloc::vec::Vec;
4use core::convert::TryInto;
5#[cfg(feature = "hash")]
6use twox_hash::XxHash64;
7
8#[cfg(feature = "hash")]
9use core::hash::Hasher;
10
11use super::{
12    block_header::BlockHeader, frame_header::FrameHeader, levels::*,
13    match_generator::MatchGeneratorDriver, CompressionLevel, Matcher,
14};
15use crate::fse::fse_encoder::{default_ll_table, default_ml_table, default_of_table, FSETable};
16
17use crate::io::{Read, Write};
18
/// An interface for compressing arbitrary data with the ZStandard compression algorithm.
///
/// `FrameCompressor` will generally be used by:
/// 1. Initializing a compressor with a compression level using `FrameCompressor::new()`
/// 2. Providing the data to compress with `FrameCompressor::set_source` and the destination
///    for the compressed output with `FrameCompressor::set_drain`
/// 3. Compressing the source into the drain as one frame using `FrameCompressor::compress`
///
/// # Examples
/// ```
/// use ruzstd::encoding::{FrameCompressor, CompressionLevel};
/// let mock_data: &[_] = &[0x1, 0x2, 0x3, 0x4];
/// let mut output = std::vec::Vec::new();
/// // Initialize a compressor.
/// let mut compressor = FrameCompressor::new(CompressionLevel::Uncompressed);
/// compressor.set_source(mock_data);
/// compressor.set_drain(&mut output);
///
/// // `compress` writes the compressed output into the provided buffer.
/// compressor.compress();
/// ```
pub struct FrameCompressor<R: Read, W: Write, M: Matcher> {
    /// Source the uncompressed input is read from; `None` until a source has been set.
    uncompressed_data: Option<R>,
    /// Drain the encoded frame is written to; `None` until a drain has been set.
    compressed_data: Option<W>,
    /// Level that selects how each block is encoded when compressing.
    compression_level: CompressionLevel,
    /// Matcher and entropy-coding tables that are kept between blocks.
    state: CompressState<M>,
    /// Hasher fed with the uncompressed input to produce the frame's content checksum.
    #[cfg(feature = "hash")]
    hasher: XxHash64,
}
46
47pub(crate) struct FseTables {
48    pub(crate) ll_default: FSETable,
49    pub(crate) ll_previous: Option<FSETable>,
50    pub(crate) ml_default: FSETable,
51    pub(crate) ml_previous: Option<FSETable>,
52    pub(crate) of_default: FSETable,
53    pub(crate) of_previous: Option<FSETable>,
54}
55
56impl FseTables {
57    pub fn new() -> Self {
58        Self {
59            ll_default: default_ll_table(),
60            ll_previous: None,
61            ml_default: default_ml_table(),
62            ml_previous: None,
63            of_default: default_of_table(),
64            of_previous: None,
65        }
66    }
67}
68
/// Encoder state that is shared between the blocks of a frame.
pub(crate) struct CompressState<M: Matcher> {
    /// The match-finding implementation; it also hands out the buffers that input is read into.
    pub(crate) matcher: M,
    /// Huffman table remembered from an earlier block, if any.
    pub(crate) last_huff_table: Option<crate::huff0::huff0_encoder::HuffmanTable>,
    /// Default and previously used FSE tables.
    pub(crate) fse_tables: FseTables,
}
74
75impl<R: Read, W: Write> FrameCompressor<R, W, MatchGeneratorDriver> {
76    /// Create a new `FrameCompressor`
77    pub fn new(compression_level: CompressionLevel) -> Self {
78        Self {
79            uncompressed_data: None,
80            compressed_data: None,
81            compression_level,
82            state: CompressState {
83                matcher: MatchGeneratorDriver::new(1024 * 128, 1),
84                last_huff_table: None,
85                fse_tables: FseTables::new(),
86            },
87            #[cfg(feature = "hash")]
88            hasher: XxHash64::with_seed(0),
89        }
90    }
91}
92
93impl<R: Read, W: Write, M: Matcher> FrameCompressor<R, W, M> {
94    /// Create a new `FrameCompressor` with a custom matching algorithm implementation
95    pub fn new_with_matcher(matcher: M, compression_level: CompressionLevel) -> Self {
96        Self {
97            uncompressed_data: None,
98            compressed_data: None,
99            state: CompressState {
100                matcher,
101                last_huff_table: None,
102                fse_tables: FseTables::new(),
103            },
104            compression_level,
105            #[cfg(feature = "hash")]
106            hasher: XxHash64::with_seed(0),
107        }
108    }
109
110    /// Before calling [FrameCompressor::compress] you need to set the source.
111    ///
112    /// This is the data that is compressed and written into the drain.
113    pub fn set_source(&mut self, uncompressed_data: R) -> Option<R> {
114        self.uncompressed_data.replace(uncompressed_data)
115    }
116
117    /// Before calling [FrameCompressor::compress] you need to set the drain.
118    ///
119    /// As the compressor compresses data, the drain serves as a place for the output to be writte.
120    pub fn set_drain(&mut self, compressed_data: W) -> Option<W> {
121        self.compressed_data.replace(compressed_data)
122    }
123
124    /// Compress the uncompressed data from the provided source as one Zstd frame and write it to the provided drain
125    ///
126    /// This will repeatedly call [Read::read] on the source to fill up blocks until the source returns 0 on the read call.
127    /// Also [Write::write_all] will be called on the drain after each block has been encoded.
128    ///
129    /// To avoid endlessly encoding from a potentially endless source (like a network socket) you can use the
130    /// [Read::take] function
131    pub fn compress(&mut self) {
132        // Clearing buffers to allow re-using of the compressor
133        self.state.matcher.reset(self.compression_level);
134        self.state.last_huff_table = None;
135        #[cfg(feature = "hash")]
136        {
137            self.hasher = XxHash64::with_seed(0);
138        }
139        let source = self.uncompressed_data.as_mut().unwrap();
140        let drain = self.compressed_data.as_mut().unwrap();
141        // As the frame is compressed, it's stored here
142        let output: &mut Vec<u8> = &mut Vec::with_capacity(1024 * 130);
143        // First write the frame header
144        let header = FrameHeader {
145            frame_content_size: None,
146            single_segment: false,
147            content_checksum: cfg!(feature = "hash"),
148            dictionary_id: None,
149            window_size: Some(self.state.matcher.window_size()),
150        };
151        header.serialize(output);
152        // Now compress block by block
153        loop {
154            // Read a single block's worth of uncompressed data from the input
155            let mut uncompressed_data = self.state.matcher.get_next_space();
156            let mut read_bytes = 0;
157            let last_block;
158            'read_loop: loop {
159                let new_bytes = source.read(&mut uncompressed_data[read_bytes..]).unwrap();
160                if new_bytes == 0 {
161                    last_block = true;
162                    break 'read_loop;
163                }
164                read_bytes += new_bytes;
165                if read_bytes == uncompressed_data.len() {
166                    last_block = false;
167                    break 'read_loop;
168                }
169            }
170            uncompressed_data.resize(read_bytes, 0);
171            // As we read, hash that data too
172            #[cfg(feature = "hash")]
173            self.hasher.write(&uncompressed_data);
174            // Special handling is needed for compression of a totally empty file (why you'd want to do that, I don't know)
175            if uncompressed_data.is_empty() {
176                let header = BlockHeader {
177                    last_block: true,
178                    block_type: crate::blocks::block::BlockType::Raw,
179                    block_size: 0,
180                };
181                // Write the header, then the block
182                header.serialize(output);
183                drain.write_all(output).unwrap();
184                output.clear();
185                break;
186            }
187
188            match self.compression_level {
189                CompressionLevel::Uncompressed => {
190                    let header = BlockHeader {
191                        last_block,
192                        block_type: crate::blocks::block::BlockType::Raw,
193                        block_size: read_bytes.try_into().unwrap(),
194                    };
195                    // Write the header, then the block
196                    header.serialize(output);
197                    output.extend_from_slice(&uncompressed_data);
198                }
199                CompressionLevel::Fastest => {
200                    compress_fastest(&mut self.state, last_block, uncompressed_data, output)
201                }
202                _ => {
203                    unimplemented!();
204                }
205            }
206            drain.write_all(output).unwrap();
207            output.clear();
208            if last_block {
209                break;
210            }
211        }
212
213        // If the `hash` feature is enabled, then `content_checksum` is set to true in the header
214        // and a 32 bit hash is written at the end of the data.
215        #[cfg(feature = "hash")]
216        {
217            // Because we only have the data as a reader, we need to read all of it to calculate the checksum
218            // Possible TODO: create a wrapper around self.uncompressed data that hashes the data as it's read?
219            let content_checksum = self.hasher.finish();
220            drain
221                .write_all(&(content_checksum as u32).to_le_bytes())
222                .unwrap();
223        }
224    }
225
226    /// Get a mutable reference to the source
227    pub fn source_mut(&mut self) -> Option<&mut R> {
228        self.uncompressed_data.as_mut()
229    }
230
231    /// Get a mutable reference to the drain
232    pub fn drain_mut(&mut self) -> Option<&mut W> {
233        self.compressed_data.as_mut()
234    }
235
236    /// Get a reference to the source
237    pub fn source(&self) -> Option<&R> {
238        self.uncompressed_data.as_ref()
239    }
240
241    /// Get a reference to the drain
242    pub fn drain(&self) -> Option<&W> {
243        self.compressed_data.as_ref()
244    }
245
246    /// Retrieve the source
247    pub fn take_source(&mut self) -> Option<R> {
248        self.uncompressed_data.take()
249    }
250
251    /// Retrieve the drain
252    pub fn take_drain(&mut self) -> Option<W> {
253        self.compressed_data.take()
254    }
255
256    /// Before calling [FrameCompressor::compress] you can replace the matcher
257    pub fn replace_matcher(&mut self, mut match_generator: M) -> M {
258        core::mem::swap(&mut match_generator, &mut self.state.matcher);
259        match_generator
260    }
261
262    /// Before calling [FrameCompressor::compress] you can replace the compression level
263    pub fn set_compression_level(
264        &mut self,
265        compression_level: CompressionLevel,
266    ) -> CompressionLevel {
267        let old = self.compression_level;
268        self.compression_level = compression_level;
269        old
270    }
271
272    /// Get the current compression level
273    pub fn compression_level(&self) -> CompressionLevel {
274        self.compression_level
275    }
276}
277
#[cfg(test)]
mod tests {
    use alloc::vec;

    use super::FrameCompressor;
    use crate::common::MAGIC_NUM;
    use crate::decoding::FrameDecoder;
    use alloc::vec::Vec;

    /// Compress `input` as a single frame using `CompressionLevel::Uncompressed`.
    fn compress_uncompressed(input: &[u8]) -> Vec<u8> {
        let mut output: Vec<u8> = Vec::new();
        {
            let mut compressor = FrameCompressor::new(super::CompressionLevel::Uncompressed);
            compressor.set_source(input);
            compressor.set_drain(&mut output);
            compressor.compress();
        }
        output
    }

    /// Decode `compressed` with this crate's `FrameDecoder`.
    ///
    /// The target vector is created with `expected_len` bytes of capacity up front.
    fn decode_with_frame_decoder(compressed: &[u8], expected_len: usize) -> Vec<u8> {
        let mut decoder = FrameDecoder::new();
        let mut decoded = Vec::with_capacity(expected_len);
        decoder.decode_all_to_vec(compressed, &mut decoded).unwrap();
        decoded
    }

    /// Decode `compressed` with the `zstd` crate to check interoperability.
    fn decode_with_zstd_crate(compressed: &[u8]) -> Vec<u8> {
        let mut decoded = Vec::new();
        zstd::stream::copy_decode(compressed, &mut decoded).unwrap();
        decoded
    }

    #[test]
    fn frame_starts_with_magic_num() {
        let mock_data = [1_u8, 2, 3].as_slice();
        let output = compress_uncompressed(mock_data);
        assert!(output.starts_with(&MAGIC_NUM.to_le_bytes()));
    }

    #[test]
    fn very_simple_raw_compress() {
        let mock_data = [1_u8, 2, 3].as_slice();
        let output = compress_uncompressed(mock_data);

        // Compressing must not only succeed, the frame must also decode to the input again.
        let decoded = decode_with_frame_decoder(&output, mock_data.len());
        assert_eq!(mock_data, decoded.as_slice());
    }

    #[test]
    fn very_simple_compress() {
        // Several runs of constant bytes with lengths at and just below powers of two.
        let mut mock_data: Vec<u8> = vec![0; 1 << 17];
        mock_data.extend(vec![1; (1 << 17) - 1]);
        mock_data.extend(vec![2; (1 << 18) - 1]);
        mock_data.extend(vec![2; 1 << 17]);
        mock_data.extend(vec![3; (1 << 17) - 1]);
        let output = compress_uncompressed(&mock_data);

        let decoded = decode_with_frame_decoder(&output, mock_data.len());
        assert_eq!(mock_data, decoded);

        let decoded = decode_with_zstd_crate(&output);
        assert_eq!(mock_data, decoded);
    }

    #[test]
    fn rle_compress() {
        let mock_data: Vec<u8> = vec![0; 1 << 19];
        let output = compress_uncompressed(&mock_data);

        let decoded = decode_with_frame_decoder(&output, mock_data.len());
        assert_eq!(mock_data, decoded);
    }

    #[test]
    fn aaa_compress() {
        let mock_data: Vec<u8> = vec![0, 1, 3, 4, 5];
        let output = compress_uncompressed(&mock_data);

        let decoded = decode_with_frame_decoder(&output, mock_data.len());
        assert_eq!(mock_data, decoded);

        let decoded = decode_with_zstd_crate(&output);
        assert_eq!(mock_data, decoded);
    }

    #[cfg(feature = "hash")]
    #[test]
    fn checksum_two_frames_reused_compressor() {
        // Compress the same data twice using the same compressor and verify that:
        // 1. The checksum written in each frame matches what the decoder calculates.
        // 2. The hasher is correctly reset between frames (no cross-contamination).
        //    If the hasher were NOT reset, the second frame's calculated checksum
        //    would differ from the one stored in the frame data, causing assert_eq to fail.
        let data: Vec<u8> = (0u8..=255).cycle().take(1024).collect();

        let mut compressor = FrameCompressor::new(super::CompressionLevel::Uncompressed);

        // --- Frame 1 ---
        let mut compressed1 = Vec::new();
        compressor.set_source(data.as_slice());
        compressor.set_drain(&mut compressed1);
        compressor.compress();

        // --- Frame 2 (reuse the same compressor) ---
        let mut compressed2 = Vec::new();
        compressor.set_source(data.as_slice());
        compressor.set_drain(&mut compressed2);
        compressor.compress();

        /// Decode one frame and return the decoded bytes together with the checksum stored in
        /// the frame and the checksum the decoder calculated itself.
        fn decode_and_collect(compressed: &[u8]) -> (Vec<u8>, Option<u32>, Option<u32>) {
            let mut decoder = FrameDecoder::new();
            let mut source = compressed;
            decoder.reset(&mut source).unwrap();
            while !decoder.is_finished() {
                decoder
                    .decode_blocks(&mut source, crate::decoding::BlockDecodingStrategy::All)
                    .unwrap();
            }
            let mut decoded = Vec::new();
            decoder.collect_to_writer(&mut decoded).unwrap();
            (
                decoded,
                decoder.get_checksum_from_data(),
                decoder.get_calculated_checksum(),
            )
        }

        let (decoded1, chksum_from_data1, chksum_calculated1) = decode_and_collect(&compressed1);
        assert_eq!(decoded1, data, "frame 1: decoded data mismatch");
        assert_eq!(
            chksum_from_data1, chksum_calculated1,
            "frame 1: checksum mismatch"
        );

        let (decoded2, chksum_from_data2, chksum_calculated2) = decode_and_collect(&compressed2);
        assert_eq!(decoded2, data, "frame 2: decoded data mismatch");
        assert_eq!(
            chksum_from_data2, chksum_calculated2,
            "frame 2: checksum mismatch"
        );

        // Same data compressed twice must produce the same checksum.
        // If state leaked across frames, the second calculated checksum would differ.
        assert_eq!(
            chksum_from_data1, chksum_from_data2,
            "frame 1 and frame 2 should have the same checksum (same data, hash must reset per frame)"
        );
    }

    #[cfg(feature = "std")]
    #[test]
    fn fuzz_targets() {
        use std::io::Read;

        /// Decode with this crate's streaming decoder.
        fn decode_ruzstd(data: &mut dyn std::io::Read) -> Vec<u8> {
            let mut decoder = crate::decoding::StreamingDecoder::new(data).unwrap();
            let mut result: Vec<u8> = Vec::new();
            decoder.read_to_end(&mut result).expect("Decoding failed");
            result
        }

        /// Decode with this crate's `FrameDecoder`, collecting the output piece by piece.
        fn decode_ruzstd_writer(mut data: impl Read) -> Vec<u8> {
            let mut decoder = crate::decoding::FrameDecoder::new();
            decoder.reset(&mut data).unwrap();
            let mut result = vec![];
            while !decoder.is_finished() || decoder.can_collect() > 0 {
                decoder
                    .decode_blocks(
                        &mut data,
                        crate::decoding::BlockDecodingStrategy::UptoBytes(1024 * 1024),
                    )
                    .unwrap();
                decoder.collect_to_writer(&mut result).unwrap();
            }
            result
        }

        /// Encode with the `zstd` crate at level 3.
        fn encode_zstd(data: &[u8]) -> Result<Vec<u8>, std::io::Error> {
            zstd::stream::encode_all(std::io::Cursor::new(data), 3)
        }

        /// Encode with this crate at the given compression level.
        fn encode_ruzstd(
            data: &mut dyn std::io::Read,
            level: crate::encoding::CompressionLevel,
        ) -> Vec<u8> {
            let mut input = Vec::new();
            data.read_to_end(&mut input).unwrap();

            crate::encoding::compress_to_vec(input.as_slice(), level)
        }

        /// Decode with the `zstd` crate.
        fn decode_zstd(data: &[u8]) -> Result<Vec<u8>, std::io::Error> {
            let mut output = Vec::new();
            zstd::stream::copy_decode(data, &mut output)?;
            Ok(output)
        }

        // The artifacts directory is optional; without it there is nothing to replay.
        if !std::fs::exists("fuzz/artifacts/interop").unwrap_or(false) {
            return;
        }

        for entry in std::fs::read_dir("fuzz/artifacts/interop").unwrap() {
            let entry = entry.unwrap();
            if !entry.file_type().unwrap().is_file() {
                continue;
            }
            let data = std::fs::read(entry.path()).unwrap();
            let data = data.as_slice();

            // Decoding
            let compressed = encode_zstd(data).unwrap();
            let decoded = decode_ruzstd(&mut compressed.as_slice());
            let decoded2 = decode_ruzstd_writer(&mut compressed.as_slice());
            assert!(
                decoded == data,
                "Decoded data did not match the original input during decompression"
            );
            assert_eq!(
                decoded2, data,
                "Decoded data did not match the original input during decompression"
            );

            // Encoding
            // Uncompressed encoding
            let mut input = data;
            let compressed = encode_ruzstd(
                &mut input,
                crate::encoding::CompressionLevel::Uncompressed,
            );
            let decoded = decode_zstd(&compressed).unwrap();
            assert_eq!(
                decoded, data,
                "Decoded data did not match the original input during compression"
            );
            // Compressed encoding
            let mut input = data;
            let compressed =
                encode_ruzstd(&mut input, crate::encoding::CompressionLevel::Fastest);
            let decoded = decode_zstd(&compressed).unwrap();
            assert_eq!(
                decoded, data,
                "Decoded data did not match the original input during compression"
            );
        }
    }
}