Skip to main content

structured_zstd/encoding/
frame_compressor.rs

//! Utilities and interfaces for encoding an entire frame. Allows reusing resources across frames.
2
3use alloc::{boxed::Box, vec::Vec};
4use core::convert::TryInto;
5#[cfg(feature = "hash")]
6use twox_hash::XxHash64;
7
8#[cfg(feature = "hash")]
9use core::hash::Hasher;
10
11use super::{
12    CompressionLevel, Matcher, block_header::BlockHeader, frame_header::FrameHeader, levels::*,
13    match_generator::MatchGeneratorDriver,
14};
15use crate::fse::fse_encoder::{FSETable, default_ll_table, default_ml_table, default_of_table};
16
17use crate::io::{Read, Write};
18
/// An interface for compressing arbitrary data with the ZStandard compression algorithm.
///
/// `FrameCompressor` will generally be used by:
/// 1. Creating a compressor for a compression level using `FrameCompressor::new()`
/// 2. Providing the input and the output using `FrameCompressor::set_source` and
///    `FrameCompressor::set_drain`
/// 3. Compressing the source into the drain as one frame using `FrameCompressor::compress`
///
/// # Examples
/// ```
/// use structured_zstd::encoding::{FrameCompressor, CompressionLevel};
/// let mock_data: &[_] = &[0x1, 0x2, 0x3, 0x4];
/// let mut output = std::vec::Vec::new();
/// // Initialize a compressor.
/// let mut compressor = FrameCompressor::new(CompressionLevel::Uncompressed);
/// compressor.set_source(mock_data);
/// compressor.set_drain(&mut output);
///
/// // `compress` writes the compressed output into the provided buffer.
/// compressor.compress();
/// ```
pub struct FrameCompressor<R: Read, W: Write, M: Matcher> {
    /// Source the uncompressed input is read from; `None` while no source is set.
    uncompressed_data: Option<R>,
    /// Drain the encoded frame is written to; `None` while no drain is set.
    compressed_data: Option<W>,
    /// Level that selects how `compress` encodes each block.
    compression_level: CompressionLevel,
    /// Matcher, entropy tables and offset history; reset at the start of every `compress` call.
    state: CompressState<M>,
    /// Running hash over the uncompressed input; its low 32 bits are appended to the frame.
    #[cfg(feature = "hash")]
    hasher: XxHash64,
}
46
/// Marker for a previously used FSE table (stored in the `*_previous` fields of `FseTables`).
#[derive(Clone)]
pub(crate) enum PreviousFseTable {
    // Default tables are immutable and already stored alongside the state, so
    // repeating them only needs a lightweight marker instead of cloning FSETable.
    /// The default table; resolved through the `default` argument of `as_table`.
    Default,
    /// A non-default table, owned on the heap.
    Custom(Box<FSETable>),
}
54
55impl PreviousFseTable {
56    pub(crate) fn as_table<'a>(&'a self, default: &'a FSETable) -> &'a FSETable {
57        match self {
58            Self::Default => default,
59            Self::Custom(table) => table,
60        }
61    }
62}
63
/// The three default FSE tables plus, for each of them, an optional marker for
/// the table that was used previously.
///
/// NOTE(review): `ll` / `ml` / `of` presumably stand for literal lengths, match
/// lengths and offsets — confirm against the sequence encoder.
pub(crate) struct FseTables {
    /// Default table built by `default_ll_table()`.
    pub(crate) ll_default: FSETable,
    /// Previously used `ll` table, if any; cleared at the start of each frame.
    pub(crate) ll_previous: Option<PreviousFseTable>,
    /// Default table built by `default_ml_table()`.
    pub(crate) ml_default: FSETable,
    /// Previously used `ml` table, if any; cleared at the start of each frame.
    pub(crate) ml_previous: Option<PreviousFseTable>,
    /// Default table built by `default_of_table()`.
    pub(crate) of_default: FSETable,
    /// Previously used `of` table, if any; cleared at the start of each frame.
    pub(crate) of_previous: Option<PreviousFseTable>,
}
72
73impl FseTables {
74    pub fn new() -> Self {
75        Self {
76            ll_default: default_ll_table(),
77            ll_previous: None,
78            ml_default: default_ml_table(),
79            ml_previous: None,
80            of_default: default_of_table(),
81            of_previous: None,
82        }
83    }
84}
85
/// Per-frame encoder state owned by a `FrameCompressor`.
///
/// Everything except `fse_tables`' default tables is reset at the start of
/// each `FrameCompressor::compress` call.
pub(crate) struct CompressState<M: Matcher> {
    /// Matcher that hands out the input buffer for each block and reports the window size.
    pub(crate) matcher: M,
    /// Huffman table remembered from an earlier block, if any.
    /// NOTE(review): presumably kept so later blocks can reuse it — confirm in the block encoder.
    pub(crate) last_huff_table: Option<crate::huff0::huff0_encoder::HuffmanTable>,
    /// Default FSE tables and the markers for the previously used ones.
    pub(crate) fse_tables: FseTables,
    /// Offset history for repeat offset encoding: [rep0, rep1, rep2].
    /// Initialized to [1, 4, 8] per RFC 8878 §3.1.2.5.
    pub(crate) offset_hist: [u32; 3],
}
94
95impl<R: Read, W: Write> FrameCompressor<R, W, MatchGeneratorDriver> {
96    /// Create a new `FrameCompressor`
97    pub fn new(compression_level: CompressionLevel) -> Self {
98        Self {
99            uncompressed_data: None,
100            compressed_data: None,
101            compression_level,
102            state: CompressState {
103                matcher: MatchGeneratorDriver::new(1024 * 128, 1),
104                last_huff_table: None,
105                fse_tables: FseTables::new(),
106                offset_hist: [1, 4, 8],
107            },
108            #[cfg(feature = "hash")]
109            hasher: XxHash64::with_seed(0),
110        }
111    }
112}
113
114impl<R: Read, W: Write, M: Matcher> FrameCompressor<R, W, M> {
115    /// Create a new `FrameCompressor` with a custom matching algorithm implementation
116    pub fn new_with_matcher(matcher: M, compression_level: CompressionLevel) -> Self {
117        Self {
118            uncompressed_data: None,
119            compressed_data: None,
120            state: CompressState {
121                matcher,
122                last_huff_table: None,
123                fse_tables: FseTables::new(),
124                offset_hist: [1, 4, 8],
125            },
126            compression_level,
127            #[cfg(feature = "hash")]
128            hasher: XxHash64::with_seed(0),
129        }
130    }
131
132    /// Before calling [FrameCompressor::compress] you need to set the source.
133    ///
134    /// This is the data that is compressed and written into the drain.
135    pub fn set_source(&mut self, uncompressed_data: R) -> Option<R> {
136        self.uncompressed_data.replace(uncompressed_data)
137    }
138
139    /// Before calling [FrameCompressor::compress] you need to set the drain.
140    ///
141    /// As the compressor compresses data, the drain serves as a place for the output to be writte.
142    pub fn set_drain(&mut self, compressed_data: W) -> Option<W> {
143        self.compressed_data.replace(compressed_data)
144    }
145
146    /// Compress the uncompressed data from the provided source as one Zstd frame and write it to the provided drain
147    ///
148    /// This will repeatedly call [Read::read] on the source to fill up blocks until the source returns 0 on the read call.
149    /// Also [Write::write_all] will be called on the drain after each block has been encoded.
150    ///
151    /// To avoid endlessly encoding from a potentially endless source (like a network socket) you can use the
152    /// [Read::take] function
153    pub fn compress(&mut self) {
154        // Clearing buffers to allow re-using of the compressor
155        self.state.matcher.reset(self.compression_level);
156        self.state.last_huff_table = None;
157        self.state.fse_tables.ll_previous = None;
158        self.state.fse_tables.ml_previous = None;
159        self.state.fse_tables.of_previous = None;
160        self.state.offset_hist = [1, 4, 8];
161        #[cfg(feature = "hash")]
162        {
163            self.hasher = XxHash64::with_seed(0);
164        }
165        let source = self.uncompressed_data.as_mut().unwrap();
166        let drain = self.compressed_data.as_mut().unwrap();
167        // As the frame is compressed, it's stored here
168        let output: &mut Vec<u8> = &mut Vec::with_capacity(1024 * 130);
169        // First write the frame header
170        let header = FrameHeader {
171            frame_content_size: None,
172            single_segment: false,
173            content_checksum: cfg!(feature = "hash"),
174            dictionary_id: None,
175            window_size: Some(self.state.matcher.window_size()),
176        };
177        header.serialize(output);
178        // Now compress block by block
179        loop {
180            // Read a single block's worth of uncompressed data from the input
181            let mut uncompressed_data = self.state.matcher.get_next_space();
182            let mut read_bytes = 0;
183            let last_block;
184            'read_loop: loop {
185                let new_bytes = source.read(&mut uncompressed_data[read_bytes..]).unwrap();
186                if new_bytes == 0 {
187                    last_block = true;
188                    break 'read_loop;
189                }
190                read_bytes += new_bytes;
191                if read_bytes == uncompressed_data.len() {
192                    last_block = false;
193                    break 'read_loop;
194                }
195            }
196            uncompressed_data.resize(read_bytes, 0);
197            // As we read, hash that data too
198            #[cfg(feature = "hash")]
199            self.hasher.write(&uncompressed_data);
200            // Special handling is needed for compression of a totally empty file (why you'd want to do that, I don't know)
201            if uncompressed_data.is_empty() {
202                let header = BlockHeader {
203                    last_block: true,
204                    block_type: crate::blocks::block::BlockType::Raw,
205                    block_size: 0,
206                };
207                // Write the header, then the block
208                header.serialize(output);
209                drain.write_all(output).unwrap();
210                output.clear();
211                break;
212            }
213
214            match self.compression_level {
215                CompressionLevel::Uncompressed => {
216                    let header = BlockHeader {
217                        last_block,
218                        block_type: crate::blocks::block::BlockType::Raw,
219                        block_size: read_bytes.try_into().unwrap(),
220                    };
221                    // Write the header, then the block
222                    header.serialize(output);
223                    output.extend_from_slice(&uncompressed_data);
224                }
225                CompressionLevel::Fastest | CompressionLevel::Default => {
226                    // Default shares this fast block-encoding pipeline, but it
227                    // remains a distinct level via the matcher's dfast backend.
228                    compress_fastest(&mut self.state, last_block, uncompressed_data, output)
229                }
230                _ => {
231                    unimplemented!();
232                }
233            }
234            drain.write_all(output).unwrap();
235            output.clear();
236            if last_block {
237                break;
238            }
239        }
240
241        // If the `hash` feature is enabled, then `content_checksum` is set to true in the header
242        // and a 32 bit hash is written at the end of the data.
243        #[cfg(feature = "hash")]
244        {
245            // Because we only have the data as a reader, we need to read all of it to calculate the checksum
246            // Possible TODO: create a wrapper around self.uncompressed data that hashes the data as it's read?
247            let content_checksum = self.hasher.finish();
248            drain
249                .write_all(&(content_checksum as u32).to_le_bytes())
250                .unwrap();
251        }
252    }
253
254    /// Get a mutable reference to the source
255    pub fn source_mut(&mut self) -> Option<&mut R> {
256        self.uncompressed_data.as_mut()
257    }
258
259    /// Get a mutable reference to the drain
260    pub fn drain_mut(&mut self) -> Option<&mut W> {
261        self.compressed_data.as_mut()
262    }
263
264    /// Get a reference to the source
265    pub fn source(&self) -> Option<&R> {
266        self.uncompressed_data.as_ref()
267    }
268
269    /// Get a reference to the drain
270    pub fn drain(&self) -> Option<&W> {
271        self.compressed_data.as_ref()
272    }
273
274    /// Retrieve the source
275    pub fn take_source(&mut self) -> Option<R> {
276        self.uncompressed_data.take()
277    }
278
279    /// Retrieve the drain
280    pub fn take_drain(&mut self) -> Option<W> {
281        self.compressed_data.take()
282    }
283
284    /// Before calling [FrameCompressor::compress] you can replace the matcher
285    pub fn replace_matcher(&mut self, mut match_generator: M) -> M {
286        core::mem::swap(&mut match_generator, &mut self.state.matcher);
287        match_generator
288    }
289
290    /// Before calling [FrameCompressor::compress] you can replace the compression level
291    pub fn set_compression_level(
292        &mut self,
293        compression_level: CompressionLevel,
294    ) -> CompressionLevel {
295        let old = self.compression_level;
296        self.compression_level = compression_level;
297        old
298    }
299
300    /// Get the current compression level
301    pub fn compression_level(&self) -> CompressionLevel {
302        self.compression_level
303    }
304}
305
#[cfg(test)]
mod tests {
    use alloc::vec;

    use super::FrameCompressor;
    use crate::common::MAGIC_NUM;
    use crate::decoding::FrameDecoder;
    use alloc::vec::Vec;

    /// Compresses `data` as a single frame at `CompressionLevel::Uncompressed`
    /// and returns the bytes of that frame.
    fn compress_uncompressed(data: &[u8]) -> Vec<u8> {
        let mut output: Vec<u8> = Vec::new();
        {
            let mut compressor = FrameCompressor::new(super::CompressionLevel::Uncompressed);
            compressor.set_source(data);
            compressor.set_drain(&mut output);
            compressor.compress();
        }
        output
    }

    /// Decodes `compressed` with this crate's `FrameDecoder` into a vector
    /// pre-allocated with `capacity` bytes.
    fn decode_with_frame_decoder(compressed: &[u8], capacity: usize) -> Vec<u8> {
        let mut decoder = FrameDecoder::new();
        let mut decoded = Vec::with_capacity(capacity);
        decoder.decode_all_to_vec(compressed, &mut decoded).unwrap();
        decoded
    }

    #[test]
    fn frame_starts_with_magic_num() {
        let output = compress_uncompressed(&[1_u8, 2, 3]);
        assert!(output.starts_with(&MAGIC_NUM.to_le_bytes()));
    }

    #[test]
    fn very_simple_raw_compress() {
        let mock_data = [1_u8, 2, 3].as_slice();
        let output = compress_uncompressed(mock_data);

        // Running without a panic is not enough: the frame must decode back to the input.
        let decoded = decode_with_frame_decoder(&output, mock_data.len());
        assert_eq!(mock_data, decoded);
    }

    #[test]
    fn very_simple_compress() {
        let mut mock_data = vec![0; 1 << 17];
        mock_data.extend(vec![1; (1 << 17) - 1]);
        mock_data.extend(vec![2; (1 << 18) - 1]);
        mock_data.extend(vec![2; 1 << 17]);
        mock_data.extend(vec![3; (1 << 17) - 1]);
        let output = compress_uncompressed(&mock_data);

        let decoded = decode_with_frame_decoder(&output, mock_data.len());
        assert_eq!(mock_data, decoded);

        // Cross-check against the reference implementation.
        let mut decoded = Vec::new();
        zstd::stream::copy_decode(output.as_slice(), &mut decoded).unwrap();
        assert_eq!(mock_data, decoded);
    }

    #[test]
    fn rle_compress() {
        let mock_data = vec![0; 1 << 19];
        let output = compress_uncompressed(&mock_data);

        let decoded = decode_with_frame_decoder(&output, mock_data.len());
        assert_eq!(mock_data, decoded);
    }

    #[test]
    fn aaa_compress() {
        let mock_data = vec![0, 1, 3, 4, 5];
        let output = compress_uncompressed(&mock_data);

        let decoded = decode_with_frame_decoder(&output, mock_data.len());
        assert_eq!(mock_data, decoded);

        // Cross-check against the reference implementation.
        let mut decoded = Vec::new();
        zstd::stream::copy_decode(output.as_slice(), &mut decoded).unwrap();
        assert_eq!(mock_data, decoded);
    }

    #[cfg(feature = "hash")]
    #[test]
    fn checksum_two_frames_reused_compressor() {
        // Compress the same data twice using the same compressor and verify that:
        // 1. The checksum written in each frame matches what the decoder calculates.
        // 2. The hasher is correctly reset between frames (no cross-contamination).
        //    If the hasher were NOT reset, the second frame's calculated checksum
        //    would differ from the one stored in the frame data, causing assert_eq to fail.
        let data: Vec<u8> = (0u8..=255).cycle().take(1024).collect();

        let mut compressor = FrameCompressor::new(super::CompressionLevel::Uncompressed);

        // --- Frame 1 ---
        let mut compressed1 = Vec::new();
        compressor.set_source(data.as_slice());
        compressor.set_drain(&mut compressed1);
        compressor.compress();

        // --- Frame 2 (reuse the same compressor) ---
        let mut compressed2 = Vec::new();
        compressor.set_source(data.as_slice());
        compressor.set_drain(&mut compressed2);
        compressor.compress();

        /// Decodes one frame and returns the decoded bytes, the checksum stored in
        /// the frame and the checksum the decoder calculated.
        fn decode_and_collect(compressed: &[u8]) -> (Vec<u8>, Option<u32>, Option<u32>) {
            let mut decoder = FrameDecoder::new();
            let mut source = compressed;
            decoder.reset(&mut source).unwrap();
            while !decoder.is_finished() {
                decoder
                    .decode_blocks(&mut source, crate::decoding::BlockDecodingStrategy::All)
                    .unwrap();
            }
            let mut decoded = Vec::new();
            decoder.collect_to_writer(&mut decoded).unwrap();
            (
                decoded,
                decoder.get_checksum_from_data(),
                decoder.get_calculated_checksum(),
            )
        }

        let (decoded1, chksum_from_data1, chksum_calculated1) = decode_and_collect(&compressed1);
        assert_eq!(decoded1, data, "frame 1: decoded data mismatch");
        assert_eq!(
            chksum_from_data1, chksum_calculated1,
            "frame 1: checksum mismatch"
        );

        let (decoded2, chksum_from_data2, chksum_calculated2) = decode_and_collect(&compressed2);
        assert_eq!(decoded2, data, "frame 2: decoded data mismatch");
        assert_eq!(
            chksum_from_data2, chksum_calculated2,
            "frame 2: checksum mismatch"
        );

        // Same data compressed twice must produce the same checksum.
        // If state leaked across frames, the second calculated checksum would differ.
        assert_eq!(
            chksum_from_data1, chksum_from_data2,
            "frame 1 and frame 2 should have the same checksum (same data, hash must reset per frame)"
        );
    }

    /// Replays the fuzzing artifacts (if the directory exists) through both
    /// directions: reference encoder -> our decoders, and our encoder -> reference decoder.
    #[cfg(feature = "std")]
    #[test]
    fn fuzz_targets() {
        use std::io::Read;
        fn decode_szstd(data: &mut dyn std::io::Read) -> Vec<u8> {
            let mut decoder = crate::decoding::StreamingDecoder::new(data).unwrap();
            let mut result: Vec<u8> = Vec::new();
            decoder.read_to_end(&mut result).expect("Decoding failed");
            result
        }

        fn decode_szstd_writer(mut data: impl Read) -> Vec<u8> {
            let mut decoder = crate::decoding::FrameDecoder::new();
            decoder.reset(&mut data).unwrap();
            let mut result = vec![];
            while !decoder.is_finished() || decoder.can_collect() > 0 {
                decoder
                    .decode_blocks(
                        &mut data,
                        crate::decoding::BlockDecodingStrategy::UptoBytes(1024 * 1024),
                    )
                    .unwrap();
                decoder.collect_to_writer(&mut result).unwrap();
            }
            result
        }

        fn encode_zstd(data: &[u8]) -> Result<Vec<u8>, std::io::Error> {
            zstd::stream::encode_all(std::io::Cursor::new(data), 3)
        }

        fn encode_szstd_uncompressed(data: &mut dyn std::io::Read) -> Vec<u8> {
            let mut input = Vec::new();
            data.read_to_end(&mut input).unwrap();

            crate::encoding::compress_to_vec(
                input.as_slice(),
                crate::encoding::CompressionLevel::Uncompressed,
            )
        }

        fn encode_szstd_compressed(data: &mut dyn std::io::Read) -> Vec<u8> {
            let mut input = Vec::new();
            data.read_to_end(&mut input).unwrap();

            crate::encoding::compress_to_vec(
                input.as_slice(),
                crate::encoding::CompressionLevel::Fastest,
            )
        }

        fn decode_zstd(data: &[u8]) -> Result<Vec<u8>, std::io::Error> {
            let mut output = Vec::new();
            zstd::stream::copy_decode(data, &mut output)?;
            Ok(output)
        }
        if std::fs::exists("fuzz/artifacts/interop").unwrap_or(false) {
            for file in std::fs::read_dir("fuzz/artifacts/interop").unwrap() {
                let file = file.unwrap();
                // Only regular files are artifacts; skip directories and the like.
                if !file.file_type().unwrap().is_file() {
                    continue;
                }
                let data = std::fs::read(file.path()).unwrap();
                let data = data.as_slice();
                // Decoding
                let compressed = encode_zstd(data).unwrap();
                let decoded = decode_szstd(&mut compressed.as_slice());
                let decoded2 = decode_szstd_writer(&mut compressed.as_slice());
                assert_eq!(
                    decoded, data,
                    "Decoded data did not match the original input during decompression"
                );
                assert_eq!(
                    decoded2, data,
                    "Decoded data did not match the original input during decompression"
                );

                // Encoding
                // Uncompressed encoding
                let mut input = data;
                let compressed = encode_szstd_uncompressed(&mut input);
                let decoded = decode_zstd(&compressed).unwrap();
                assert_eq!(
                    decoded, data,
                    "Decoded data did not match the original input during compression"
                );
                // Compressed encoding
                let mut input = data;
                let compressed = encode_szstd_compressed(&mut input);
                let decoded = decode_zstd(&compressed).unwrap();
                assert_eq!(
                    decoded, data,
                    "Decoded data did not match the original input during compression"
                );
            }
        }
    }
}