// structured_zstd/encoding/frame_compressor.rs
1//! Utilities and interfaces for encoding an entire frame. Allows reusing resources
2
3use alloc::{boxed::Box, vec::Vec};
4use core::convert::TryInto;
5#[cfg(feature = "hash")]
6use twox_hash::XxHash64;
7
8#[cfg(feature = "hash")]
9use core::hash::Hasher;
10
11use super::{
12    CompressionLevel, Matcher, block_header::BlockHeader, frame_header::FrameHeader, levels::*,
13    match_generator::MatchGeneratorDriver,
14};
15use crate::fse::fse_encoder::{FSETable, default_ll_table, default_ml_table, default_of_table};
16
17use crate::io::{Read, Write};
18
/// An interface for compressing arbitrary data with the ZStandard compression algorithm.
///
/// `FrameCompressor` will generally be used by:
/// 1. Initializing a compressor using `FrameCompressor::new()`
/// 2. Attaching a source and a drain, then starting compression with `FrameCompressor::compress`
///
/// # Examples
/// ```
/// use structured_zstd::encoding::{FrameCompressor, CompressionLevel};
/// let mock_data: &[_] = &[0x1, 0x2, 0x3, 0x4];
/// let mut output = std::vec::Vec::new();
/// // Initialize a compressor.
/// let mut compressor = FrameCompressor::new(CompressionLevel::Uncompressed);
/// compressor.set_source(mock_data);
/// compressor.set_drain(&mut output);
///
/// // `compress` writes the compressed output into the provided buffer.
/// compressor.compress();
/// ```
pub struct FrameCompressor<R: Read, W: Write, M: Matcher> {
    /// Source the payload bytes are read from; must be set before `compress`.
    uncompressed_data: Option<R>,
    /// Drain the encoded frame is written to; must be set before `compress`.
    compressed_data: Option<W>,
    /// Level selecting between raw (uncompressed) blocks and encoded blocks.
    compression_level: CompressionLevel,
    /// Optional attached dictionary used to prime the matcher and write a dictionary id.
    dictionary: Option<crate::decoding::Dictionary>,
    /// Encoder-side entropy tables derived from `dictionary`, converted once at
    /// `set_dictionary` time and reused by every subsequent `compress` call.
    dictionary_entropy_cache: Option<CachedDictionaryEntropy>,
    /// One-shot payload-size hint consumed by the next `compress` call.
    source_size_hint: Option<u64>,
    /// Per-frame encoding state (matcher, entropy tables, offset history).
    state: CompressState<M>,
    /// Running hash of the uncompressed payload, emitted as the content checksum.
    #[cfg(feature = "hash")]
    hasher: XxHash64,
}
49
/// Encoder-ready entropy structures pre-converted from an attached dictionary.
///
/// Built once in `set_dictionary` so each `compress` call can seed the
/// "previous table" state cheaply instead of re-deriving encoder tables.
#[derive(Clone, Default)]
struct CachedDictionaryEntropy {
    // Huffman table for literals, if the dictionary carried a usable one.
    huff: Option<crate::huff0::huff0_encoder::HuffmanTable>,
    // Seeds for the literal-length / match-length / offset "repeat previous
    // table" modes of the first block.
    ll_previous: Option<PreviousFseTable>,
    ml_previous: Option<PreviousFseTable>,
    of_previous: Option<PreviousFseTable>,
}
57
/// The FSE table used by a previous block, remembered so a later block can
/// reuse it via the "repeat" table mode.
#[derive(Clone)]
pub(crate) enum PreviousFseTable {
    // Default tables are immutable and already stored alongside the state, so
    // repeating them only needs a lightweight marker instead of cloning FSETable.
    Default,
    /// A custom table built for a prior block or imported from a dictionary.
    Custom(Box<FSETable>),
}
65
66impl PreviousFseTable {
67    pub(crate) fn as_table<'a>(&'a self, default: &'a FSETable) -> &'a FSETable {
68        match self {
69            Self::Default => default,
70            Self::Custom(table) => table,
71        }
72    }
73}
74
/// Per-stream FSE table state: the immutable default tables plus the optional
/// "previous block" table for each of the three sequence components.
pub(crate) struct FseTables {
    /// Default literal-lengths table (predefined distribution).
    pub(crate) ll_default: FSETable,
    /// Literal-lengths table used by the previous block, if any.
    pub(crate) ll_previous: Option<PreviousFseTable>,
    /// Default match-lengths table (predefined distribution).
    pub(crate) ml_default: FSETable,
    /// Match-lengths table used by the previous block, if any.
    pub(crate) ml_previous: Option<PreviousFseTable>,
    /// Default offsets table (predefined distribution).
    pub(crate) of_default: FSETable,
    /// Offsets table used by the previous block, if any.
    pub(crate) of_previous: Option<PreviousFseTable>,
}
83
84impl FseTables {
85    pub fn new() -> Self {
86        Self {
87            ll_default: default_ll_table(),
88            ll_previous: None,
89            ml_default: default_ml_table(),
90            ml_previous: None,
91            of_default: default_of_table(),
92            of_previous: None,
93        }
94    }
95}
96
/// Mutable state threaded through block compression for one frame.
pub(crate) struct CompressState<M: Matcher> {
    /// Match finder producing sequences for block encoding.
    pub(crate) matcher: M,
    /// Huffman table from the previous block (or dictionary), enabling the
    /// "repeat" literals-table mode.
    pub(crate) last_huff_table: Option<crate::huff0::huff0_encoder::HuffmanTable>,
    /// Default and previous FSE tables used for sequence encoding.
    pub(crate) fse_tables: FseTables,
    /// Offset history for repeat offset encoding: [rep0, rep1, rep2].
    /// Initialized to [1, 4, 8] per RFC 8878 §3.1.2.5.
    pub(crate) offset_hist: [u32; 3],
}
105
106impl<R: Read, W: Write> FrameCompressor<R, W, MatchGeneratorDriver> {
107    /// Create a new `FrameCompressor`
108    pub fn new(compression_level: CompressionLevel) -> Self {
109        Self {
110            uncompressed_data: None,
111            compressed_data: None,
112            compression_level,
113            dictionary: None,
114            dictionary_entropy_cache: None,
115            source_size_hint: None,
116            state: CompressState {
117                matcher: MatchGeneratorDriver::new(1024 * 128, 1),
118                last_huff_table: None,
119                fse_tables: FseTables::new(),
120                offset_hist: [1, 4, 8],
121            },
122            #[cfg(feature = "hash")]
123            hasher: XxHash64::with_seed(0),
124        }
125    }
126}
127
impl<R: Read, W: Write, M: Matcher> FrameCompressor<R, W, M> {
    /// Create a new `FrameCompressor` with a custom matching algorithm implementation
    pub fn new_with_matcher(matcher: M, compression_level: CompressionLevel) -> Self {
        Self {
            uncompressed_data: None,
            compressed_data: None,
            dictionary: None,
            dictionary_entropy_cache: None,
            source_size_hint: None,
            state: CompressState {
                matcher,
                last_huff_table: None,
                fse_tables: FseTables::new(),
                // Initial repeat-offset history per RFC 8878 §3.1.2.5.
                offset_hist: [1, 4, 8],
            },
            compression_level,
            #[cfg(feature = "hash")]
            hasher: XxHash64::with_seed(0),
        }
    }

    /// Before calling [FrameCompressor::compress] you need to set the source.
    ///
    /// This is the data that is compressed and written into the drain.
    /// Returns the previously set source, if any.
    pub fn set_source(&mut self, uncompressed_data: R) -> Option<R> {
        self.uncompressed_data.replace(uncompressed_data)
    }

    /// Before calling [FrameCompressor::compress] you need to set the drain.
    ///
    /// As the compressor compresses data, the drain serves as a place for the output to be written.
    /// Returns the previously set drain, if any.
    pub fn set_drain(&mut self, compressed_data: W) -> Option<W> {
        self.compressed_data.replace(compressed_data)
    }

    /// Provide a hint about the total uncompressed size for the next frame.
    ///
    /// When set, the encoder selects smaller hash tables and windows for
    /// small inputs, matching the C zstd source-size-class behavior.
    ///
    /// This hint applies only to frame payload bytes (`size`). Dictionary
    /// history is primed separately and does not inflate the hinted size or
    /// advertised frame window.
    /// Must be called before [`compress`](Self::compress).
    pub fn set_source_size_hint(&mut self, size: u64) {
        self.source_size_hint = Some(size);
    }

    /// Compress the uncompressed data from the provided source as one Zstd frame and write it to the provided drain
    ///
    /// This will repeatedly call [Read::read] on the source to fill up blocks until the source returns 0 on the read call.
    /// All compressed blocks are buffered in memory so that the frame header can include the
    /// `Frame_Content_Size` field (which requires knowing the total uncompressed size). The
    /// entire frame — header, blocks, and optional checksum — is then written to the drain
    /// at the end. This means peak memory usage is O(compressed_size).
    ///
    /// To avoid endlessly encoding from a potentially endless source (like a network socket) you can use the
    /// [Read::take] function
    ///
    /// # Panics
    /// Panics if the source or drain has not been set, if an I/O error occurs
    /// on either, or if the matcher reports a zero window size.
    pub fn compress(&mut self) {
        // Dictionary state only applies in compressed modes and only when the
        // matcher can actually be primed with dictionary content.
        let use_dictionary_state =
            !matches!(self.compression_level, CompressionLevel::Uncompressed)
                && self.state.matcher.supports_dictionary_priming();
        if let Some(size_hint) = self.source_size_hint.take() {
            // Keep source-size hint scoped to payload bytes; dictionary priming
            // is applied separately and should not force larger matcher sizing.
            self.state.matcher.set_source_size_hint(size_hint);
        }
        // Clearing buffers to allow re-using of the compressor
        self.state.matcher.reset(self.compression_level);
        self.state.offset_hist = [1, 4, 8];
        let cached_entropy = if use_dictionary_state {
            self.dictionary_entropy_cache.as_ref()
        } else {
            None
        };
        if use_dictionary_state && let Some(dict) = self.dictionary.as_ref() {
            // This state drives sequence encoding, while matcher priming below updates
            // the match generator's internal repeat-offset history for match finding.
            self.state.offset_hist = dict.offset_hist;
            self.state
                .matcher
                .prime_with_dictionary(dict.dict_content.as_slice(), dict.offset_hist);
        }
        if let Some(cache) = cached_entropy {
            self.state.last_huff_table.clone_from(&cache.huff);
        } else {
            self.state.last_huff_table = None;
        }
        // `clone_from` keeps frame-to-frame seeding cheap for reused compressors by
        // reusing existing allocations where possible instead of reallocating every frame.
        if let Some(cache) = cached_entropy {
            self.state
                .fse_tables
                .ll_previous
                .clone_from(&cache.ll_previous);
            self.state
                .fse_tables
                .ml_previous
                .clone_from(&cache.ml_previous);
            self.state
                .fse_tables
                .of_previous
                .clone_from(&cache.of_previous);
        } else {
            self.state.fse_tables.ll_previous = None;
            self.state.fse_tables.ml_previous = None;
            self.state.fse_tables.of_previous = None;
        }
        // Fresh checksum state for every frame.
        #[cfg(feature = "hash")]
        {
            self.hasher = XxHash64::with_seed(0);
        }
        let source = self.uncompressed_data.as_mut().unwrap();
        let drain = self.compressed_data.as_mut().unwrap();
        let window_size = self.state.matcher.window_size();
        assert!(
            window_size != 0,
            "matcher reported window_size == 0, which is invalid"
        );
        // Accumulate all compressed blocks; the frame header is written after
        // all input has been read so that Frame_Content_Size is known.
        let mut all_blocks: Vec<u8> = Vec::with_capacity(1024 * 130);
        let mut total_uncompressed: u64 = 0;
        // Compress block by block
        loop {
            // Read a single block's worth of uncompressed data from the input
            let mut uncompressed_data = self.state.matcher.get_next_space();
            let mut read_bytes = 0;
            let last_block;
            'read_loop: loop {
                let new_bytes = source.read(&mut uncompressed_data[read_bytes..]).unwrap();
                if new_bytes == 0 {
                    // EOF: this iteration produces the frame's final block.
                    last_block = true;
                    break 'read_loop;
                }
                read_bytes += new_bytes;
                if read_bytes == uncompressed_data.len() {
                    last_block = false;
                    break 'read_loop;
                }
            }
            uncompressed_data.resize(read_bytes, 0);
            total_uncompressed += read_bytes as u64;
            // As we read, hash that data too
            #[cfg(feature = "hash")]
            self.hasher.write(&uncompressed_data);
            // Special handling is needed for compression of a totally empty file
            if uncompressed_data.is_empty() {
                let header = BlockHeader {
                    last_block: true,
                    block_type: crate::blocks::block::BlockType::Raw,
                    block_size: 0,
                };
                header.serialize(&mut all_blocks);
                break;
            }

            match self.compression_level {
                CompressionLevel::Uncompressed => {
                    // Raw block: header followed by the input bytes verbatim.
                    let header = BlockHeader {
                        last_block,
                        block_type: crate::blocks::block::BlockType::Raw,
                        block_size: read_bytes.try_into().unwrap(),
                    };
                    header.serialize(&mut all_blocks);
                    all_blocks.extend_from_slice(&uncompressed_data);
                }
                CompressionLevel::Fastest
                | CompressionLevel::Default
                | CompressionLevel::Better
                | CompressionLevel::Best
                | CompressionLevel::Level(_) => compress_block_encoded(
                    &mut self.state,
                    last_block,
                    uncompressed_data,
                    &mut all_blocks,
                ),
            }
            if last_block {
                break;
            }
        }

        // Now that total_uncompressed is known, write the frame header with FCS.
        // We always include the window descriptor (single_segment = false) because
        // compressed blocks are encoded against the matcher's window, not the content
        // size. Setting single_segment would tell the decoder to use FCS as window,
        // which can be smaller than the encoder's actual window and trip up decoders.
        let header = FrameHeader {
            frame_content_size: Some(total_uncompressed),
            single_segment: false,
            content_checksum: cfg!(feature = "hash"),
            dictionary_id: if use_dictionary_state {
                self.dictionary.as_ref().map(|dict| dict.id as u64)
            } else {
                None
            },
            window_size: Some(window_size),
        };
        // Write the frame header and compressed blocks separately to avoid
        // shifting the entire `all_blocks` buffer to prepend the header.
        let mut header_buf: Vec<u8> = Vec::with_capacity(14);
        header.serialize(&mut header_buf);
        drain.write_all(&header_buf).unwrap();
        drain.write_all(&all_blocks).unwrap();

        // If the `hash` feature is enabled, then `content_checksum` is set to true in the header
        // and a 32 bit hash is written at the end of the data.
        #[cfg(feature = "hash")]
        {
            // The hash was accumulated incrementally while reading blocks above;
            // only the low 32 bits of the XxHash64 digest are emitted.
            let content_checksum = self.hasher.finish();
            drain
                .write_all(&(content_checksum as u32).to_le_bytes())
                .unwrap();
        }
    }

    /// Get a mutable reference to the source
    pub fn source_mut(&mut self) -> Option<&mut R> {
        self.uncompressed_data.as_mut()
    }

    /// Get a mutable reference to the drain
    pub fn drain_mut(&mut self) -> Option<&mut W> {
        self.compressed_data.as_mut()
    }

    /// Get a reference to the source
    pub fn source(&self) -> Option<&R> {
        self.uncompressed_data.as_ref()
    }

    /// Get a reference to the drain
    pub fn drain(&self) -> Option<&W> {
        self.compressed_data.as_ref()
    }

    /// Retrieve the source, leaving the compressor without one
    pub fn take_source(&mut self) -> Option<R> {
        self.uncompressed_data.take()
    }

    /// Retrieve the drain, leaving the compressor without one
    pub fn take_drain(&mut self) -> Option<W> {
        self.compressed_data.take()
    }

    /// Before calling [FrameCompressor::compress] you can replace the matcher.
    /// Returns the matcher that was previously installed.
    pub fn replace_matcher(&mut self, mut match_generator: M) -> M {
        core::mem::swap(&mut match_generator, &mut self.state.matcher);
        match_generator
    }

    /// Before calling [FrameCompressor::compress] you can replace the compression level.
    /// Returns the previously configured level.
    pub fn set_compression_level(
        &mut self,
        compression_level: CompressionLevel,
    ) -> CompressionLevel {
        let old = self.compression_level;
        self.compression_level = compression_level;
        old
    }

    /// Get the current compression level
    pub fn compression_level(&self) -> CompressionLevel {
        self.compression_level
    }

    /// Attach a pre-parsed dictionary to be used for subsequent compressions.
    ///
    /// In compressed modes, the dictionary id is written only when the active
    /// matcher supports dictionary priming.
    /// Uncompressed mode and non-priming matchers ignore the attached dictionary
    /// at encode time.
    ///
    /// # Errors
    /// Returns an error if the dictionary id is zero or any repeat offset in
    /// the dictionary history is zero. On success, returns the previously
    /// attached dictionary, if any.
    pub fn set_dictionary(
        &mut self,
        dictionary: crate::decoding::Dictionary,
    ) -> Result<Option<crate::decoding::Dictionary>, crate::decoding::errors::DictionaryDecodeError>
    {
        if dictionary.id == 0 {
            return Err(crate::decoding::errors::DictionaryDecodeError::ZeroDictionaryId);
        }
        if let Some(index) = dictionary.offset_hist.iter().position(|&rep| rep == 0) {
            return Err(
                crate::decoding::errors::DictionaryDecodeError::ZeroRepeatOffsetInDictionary {
                    index: index as u8,
                },
            );
        }
        // Convert the dictionary's decoder-side entropy tables into encoder
        // form once, so `compress` can seed state cheaply on every frame.
        self.dictionary_entropy_cache = Some(CachedDictionaryEntropy {
            huff: dictionary.huf.table.to_encoder_table(),
            ll_previous: dictionary
                .fse
                .literal_lengths
                .to_encoder_table()
                .map(|table| PreviousFseTable::Custom(Box::new(table))),
            ml_previous: dictionary
                .fse
                .match_lengths
                .to_encoder_table()
                .map(|table| PreviousFseTable::Custom(Box::new(table))),
            of_previous: dictionary
                .fse
                .offsets
                .to_encoder_table()
                .map(|table| PreviousFseTable::Custom(Box::new(table))),
        });
        Ok(self.dictionary.replace(dictionary))
    }

    /// Parse and attach a serialized dictionary blob.
    ///
    /// # Errors
    /// Returns an error if the blob fails to parse or fails the validation
    /// performed by [`set_dictionary`](Self::set_dictionary).
    pub fn set_dictionary_from_bytes(
        &mut self,
        raw_dictionary: &[u8],
    ) -> Result<Option<crate::decoding::Dictionary>, crate::decoding::errors::DictionaryDecodeError>
    {
        let dictionary = crate::decoding::Dictionary::decode_dict(raw_dictionary)?;
        self.set_dictionary(dictionary)
    }

    /// Remove the attached dictionary, returning it if one was attached.
    pub fn clear_dictionary(&mut self) -> Option<crate::decoding::Dictionary> {
        self.dictionary_entropy_cache = None;
        self.dictionary.take()
    }
}
456
457#[cfg(test)]
458mod tests {
459    #[cfg(all(feature = "dict_builder", feature = "std"))]
460    use alloc::format;
461    use alloc::vec;
462
463    use super::FrameCompressor;
464    use crate::common::MAGIC_NUM;
465    use crate::decoding::FrameDecoder;
466    use crate::encoding::{Matcher, Sequence};
467    use alloc::vec::Vec;
468
469    /// Frame content size is written correctly and C zstd can decompress the output.
470    #[cfg(feature = "std")]
471    #[test]
472    fn fcs_header_written_and_c_zstd_compatible() {
473        let levels = [
474            crate::encoding::CompressionLevel::Uncompressed,
475            crate::encoding::CompressionLevel::Fastest,
476            crate::encoding::CompressionLevel::Default,
477            crate::encoding::CompressionLevel::Better,
478            crate::encoding::CompressionLevel::Best,
479        ];
480        let fcs_2byte = vec![0xCDu8; 300]; // 300 bytes โ†’ 2-byte FCS (256..=65791 range)
481        let large = vec![0xABu8; 100_000];
482        let inputs: [&[u8]; 5] = [
483            &[],
484            &[0x00],
485            b"abcdefghijklmnopqrstuvwxy\n",
486            &fcs_2byte,
487            &large,
488        ];
489        for level in levels {
490            for data in &inputs {
491                let compressed = crate::encoding::compress_to_vec(*data, level);
492                // Verify FCS is present and correct
493                let header = crate::decoding::frame::read_frame_header(compressed.as_slice())
494                    .unwrap()
495                    .0;
496                assert_eq!(
497                    header.frame_content_size(),
498                    data.len() as u64,
499                    "FCS mismatch for len={} level={:?}",
500                    data.len(),
501                    level,
502                );
503                // Confirm the FCS field is actually present in the header
504                // (not just the decoder returning 0 for absent FCS).
505                assert_ne!(
506                    header.descriptor.frame_content_size_bytes().unwrap(),
507                    0,
508                    "FCS field must be present for len={} level={:?}",
509                    data.len(),
510                    level,
511                );
512                // Verify C zstd can decompress
513                let mut decoded = Vec::new();
514                zstd::stream::copy_decode(compressed.as_slice(), &mut decoded).unwrap();
515                assert_eq!(
516                    decoded.as_slice(),
517                    *data,
518                    "C zstd roundtrip failed for len={}",
519                    data.len()
520                );
521            }
522        }
523    }
524
525    struct NoDictionaryMatcher {
526        last_space: Vec<u8>,
527        window_size: u64,
528    }
529
530    impl NoDictionaryMatcher {
531        fn new(window_size: u64) -> Self {
532            Self {
533                last_space: Vec::new(),
534                window_size,
535            }
536        }
537    }
538
539    impl Matcher for NoDictionaryMatcher {
540        fn get_next_space(&mut self) -> Vec<u8> {
541            vec![0; self.window_size as usize]
542        }
543
544        fn get_last_space(&mut self) -> &[u8] {
545            self.last_space.as_slice()
546        }
547
548        fn commit_space(&mut self, space: Vec<u8>) {
549            self.last_space = space;
550        }
551
552        fn skip_matching(&mut self) {}
553
554        fn start_matching(&mut self, mut handle_sequence: impl for<'a> FnMut(Sequence<'a>)) {
555            handle_sequence(Sequence::Literals {
556                literals: self.last_space.as_slice(),
557            });
558        }
559
560        fn reset(&mut self, _level: super::CompressionLevel) {
561            self.last_space.clear();
562        }
563
564        fn window_size(&self) -> u64 {
565            self.window_size
566        }
567    }
568
569    #[test]
570    fn frame_starts_with_magic_num() {
571        let mock_data = [1_u8, 2, 3].as_slice();
572        let mut output: Vec<u8> = Vec::new();
573        let mut compressor = FrameCompressor::new(super::CompressionLevel::Uncompressed);
574        compressor.set_source(mock_data);
575        compressor.set_drain(&mut output);
576
577        compressor.compress();
578        assert!(output.starts_with(&MAGIC_NUM.to_le_bytes()));
579    }
580
581    #[test]
582    fn very_simple_raw_compress() {
583        let mock_data = [1_u8, 2, 3].as_slice();
584        let mut output: Vec<u8> = Vec::new();
585        let mut compressor = FrameCompressor::new(super::CompressionLevel::Uncompressed);
586        compressor.set_source(mock_data);
587        compressor.set_drain(&mut output);
588
589        compressor.compress();
590    }
591
592    #[test]
593    fn very_simple_compress() {
594        let mut mock_data = vec![0; 1 << 17];
595        mock_data.extend(vec![1; (1 << 17) - 1]);
596        mock_data.extend(vec![2; (1 << 18) - 1]);
597        mock_data.extend(vec![2; 1 << 17]);
598        mock_data.extend(vec![3; (1 << 17) - 1]);
599        let mut output: Vec<u8> = Vec::new();
600        let mut compressor = FrameCompressor::new(super::CompressionLevel::Uncompressed);
601        compressor.set_source(mock_data.as_slice());
602        compressor.set_drain(&mut output);
603
604        compressor.compress();
605
606        let mut decoder = FrameDecoder::new();
607        let mut decoded = Vec::with_capacity(mock_data.len());
608        decoder.decode_all_to_vec(&output, &mut decoded).unwrap();
609        assert_eq!(mock_data, decoded);
610
611        let mut decoded = Vec::new();
612        zstd::stream::copy_decode(output.as_slice(), &mut decoded).unwrap();
613        assert_eq!(mock_data, decoded);
614    }
615
616    #[test]
617    fn rle_compress() {
618        let mock_data = vec![0; 1 << 19];
619        let mut output: Vec<u8> = Vec::new();
620        let mut compressor = FrameCompressor::new(super::CompressionLevel::Uncompressed);
621        compressor.set_source(mock_data.as_slice());
622        compressor.set_drain(&mut output);
623
624        compressor.compress();
625
626        let mut decoder = FrameDecoder::new();
627        let mut decoded = Vec::with_capacity(mock_data.len());
628        decoder.decode_all_to_vec(&output, &mut decoded).unwrap();
629        assert_eq!(mock_data, decoded);
630    }
631
632    #[test]
633    fn aaa_compress() {
634        let mock_data = vec![0, 1, 3, 4, 5];
635        let mut output: Vec<u8> = Vec::new();
636        let mut compressor = FrameCompressor::new(super::CompressionLevel::Uncompressed);
637        compressor.set_source(mock_data.as_slice());
638        compressor.set_drain(&mut output);
639
640        compressor.compress();
641
642        let mut decoder = FrameDecoder::new();
643        let mut decoded = Vec::with_capacity(mock_data.len());
644        decoder.decode_all_to_vec(&output, &mut decoded).unwrap();
645        assert_eq!(mock_data, decoded);
646
647        let mut decoded = Vec::new();
648        zstd::stream::copy_decode(output.as_slice(), &mut decoded).unwrap();
649        assert_eq!(mock_data, decoded);
650    }
651
652    #[test]
653    fn dictionary_compression_sets_required_dict_id_and_roundtrips() {
654        let dict_raw = include_bytes!("../../dict_tests/dictionary");
655        let dict_for_encoder = crate::decoding::Dictionary::decode_dict(dict_raw).unwrap();
656        let dict_for_decoder = crate::decoding::Dictionary::decode_dict(dict_raw).unwrap();
657
658        let mut data = Vec::new();
659        for _ in 0..8 {
660            data.extend_from_slice(&dict_for_decoder.dict_content[..2048]);
661        }
662
663        let mut with_dict = Vec::new();
664        let mut compressor = FrameCompressor::new(super::CompressionLevel::Fastest);
665        let previous = compressor
666            .set_dictionary_from_bytes(dict_raw)
667            .expect("dictionary bytes should parse");
668        assert!(
669            previous.is_none(),
670            "first dictionary insert should return None"
671        );
672        assert_eq!(
673            compressor
674                .set_dictionary(dict_for_encoder)
675                .expect("valid dictionary should attach")
676                .expect("set_dictionary_from_bytes inserted previous dictionary")
677                .id,
678            dict_for_decoder.id
679        );
680        compressor.set_source(data.as_slice());
681        compressor.set_drain(&mut with_dict);
682        compressor.compress();
683
684        let (frame_header, _) = crate::decoding::frame::read_frame_header(with_dict.as_slice())
685            .expect("encoded stream should have a frame header");
686        assert_eq!(frame_header.dictionary_id(), Some(dict_for_decoder.id));
687
688        let mut decoder = FrameDecoder::new();
689        let mut missing_dict_target = Vec::with_capacity(data.len());
690        let err = decoder
691            .decode_all_to_vec(&with_dict, &mut missing_dict_target)
692            .unwrap_err();
693        assert!(
694            matches!(
695                &err,
696                crate::decoding::errors::FrameDecoderError::DictNotProvided { .. }
697            ),
698            "dict-compressed stream should require dictionary id, got: {err:?}"
699        );
700
701        let mut decoder = FrameDecoder::new();
702        decoder.add_dict(dict_for_decoder).unwrap();
703        let mut decoded = Vec::with_capacity(data.len());
704        decoder.decode_all_to_vec(&with_dict, &mut decoded).unwrap();
705        assert_eq!(decoded, data);
706
707        let mut ffi_decoder = zstd::bulk::Decompressor::with_dictionary(dict_raw).unwrap();
708        let mut ffi_decoded = Vec::with_capacity(data.len());
709        let ffi_written = ffi_decoder
710            .decompress_to_buffer(with_dict.as_slice(), &mut ffi_decoded)
711            .unwrap();
712        assert_eq!(ffi_written, data.len());
713        assert_eq!(ffi_decoded, data);
714    }
715
716    #[cfg(all(feature = "dict_builder", feature = "std"))]
717    #[test]
718    fn dictionary_compression_roundtrips_with_dict_builder_dictionary() {
719        use std::io::Cursor;
720
721        let mut training = Vec::new();
722        for idx in 0..256u32 {
723            training.extend_from_slice(
724                format!("tenant=demo table=orders key={idx} region=eu\n").as_bytes(),
725            );
726        }
727        let mut raw_dict = Vec::new();
728        crate::dictionary::create_raw_dict_from_source(
729            Cursor::new(training.as_slice()),
730            training.len(),
731            &mut raw_dict,
732            4096,
733        );
734        assert!(
735            !raw_dict.is_empty(),
736            "dict_builder produced an empty dictionary"
737        );
738
739        let dict_id = 0xD1C7_0008;
740        let encoder_dict =
741            crate::decoding::Dictionary::from_raw_content(dict_id, raw_dict.clone()).unwrap();
742        let decoder_dict =
743            crate::decoding::Dictionary::from_raw_content(dict_id, raw_dict.clone()).unwrap();
744
745        let mut payload = Vec::new();
746        for idx in 0..96u32 {
747            payload.extend_from_slice(
748                format!(
749                    "tenant=demo table=orders op=put key={idx} value=aaaaabbbbbcccccdddddeeeee\n"
750                )
751                .as_bytes(),
752            );
753        }
754
755        let mut without_dict = Vec::new();
756        let mut baseline = FrameCompressor::new(super::CompressionLevel::Fastest);
757        baseline.set_source(payload.as_slice());
758        baseline.set_drain(&mut without_dict);
759        baseline.compress();
760
761        let mut with_dict = Vec::new();
762        let mut compressor = FrameCompressor::new(super::CompressionLevel::Fastest);
763        compressor
764            .set_dictionary(encoder_dict)
765            .expect("valid dict_builder dictionary should attach");
766        compressor.set_source(payload.as_slice());
767        compressor.set_drain(&mut with_dict);
768        compressor.compress();
769
770        let (frame_header, _) = crate::decoding::frame::read_frame_header(with_dict.as_slice())
771            .expect("encoded stream should have a frame header");
772        assert_eq!(frame_header.dictionary_id(), Some(dict_id));
773        let mut decoder = FrameDecoder::new();
774        decoder.add_dict(decoder_dict).unwrap();
775        let mut decoded = Vec::with_capacity(payload.len());
776        decoder.decode_all_to_vec(&with_dict, &mut decoded).unwrap();
777        assert_eq!(decoded, payload);
778        assert!(
779            with_dict.len() < without_dict.len(),
780            "trained dictionary should improve compression for this small payload"
781        );
782    }
783
784    #[test]
785    fn set_dictionary_from_bytes_seeds_entropy_tables_for_first_block() {
786        let dict_raw = include_bytes!("../../dict_tests/dictionary");
787        let mut output = Vec::new();
788        let input = b"";
789
790        let mut compressor = FrameCompressor::new(super::CompressionLevel::Fastest);
791        let previous = compressor
792            .set_dictionary_from_bytes(dict_raw)
793            .expect("dictionary bytes should parse");
794        assert!(previous.is_none());
795
796        compressor.set_source(input.as_slice());
797        compressor.set_drain(&mut output);
798        compressor.compress();
799
800        assert!(
801            compressor.state.last_huff_table.is_some(),
802            "dictionary entropy should seed previous huffman table before first block"
803        );
804        assert!(
805            compressor.state.fse_tables.ll_previous.is_some(),
806            "dictionary entropy should seed previous ll table before first block"
807        );
808        assert!(
809            compressor.state.fse_tables.ml_previous.is_some(),
810            "dictionary entropy should seed previous ml table before first block"
811        );
812        assert!(
813            compressor.state.fse_tables.of_previous.is_some(),
814            "dictionary entropy should seed previous of table before first block"
815        );
816    }
817
818    #[test]
819    fn set_dictionary_rejects_zero_dictionary_id() {
820        let invalid = crate::decoding::Dictionary {
821            id: 0,
822            fse: crate::decoding::scratch::FSEScratch::new(),
823            huf: crate::decoding::scratch::HuffmanScratch::new(),
824            dict_content: vec![1, 2, 3],
825            offset_hist: [1, 4, 8],
826        };
827
828        let mut compressor: FrameCompressor<
829            &[u8],
830            Vec<u8>,
831            crate::encoding::match_generator::MatchGeneratorDriver,
832        > = FrameCompressor::new(super::CompressionLevel::Fastest);
833        let result = compressor.set_dictionary(invalid);
834        assert!(matches!(
835            result,
836            Err(crate::decoding::errors::DictionaryDecodeError::ZeroDictionaryId)
837        ));
838    }
839
840    #[test]
841    fn set_dictionary_rejects_zero_repeat_offsets() {
842        let invalid = crate::decoding::Dictionary {
843            id: 1,
844            fse: crate::decoding::scratch::FSEScratch::new(),
845            huf: crate::decoding::scratch::HuffmanScratch::new(),
846            dict_content: vec![1, 2, 3],
847            offset_hist: [0, 4, 8],
848        };
849
850        let mut compressor: FrameCompressor<
851            &[u8],
852            Vec<u8>,
853            crate::encoding::match_generator::MatchGeneratorDriver,
854        > = FrameCompressor::new(super::CompressionLevel::Fastest);
855        let result = compressor.set_dictionary(invalid);
856        assert!(matches!(
857            result,
858            Err(
859                crate::decoding::errors::DictionaryDecodeError::ZeroRepeatOffsetInDictionary {
860                    index: 0
861                }
862            )
863        ));
864    }
865
866    #[test]
867    fn uncompressed_mode_does_not_require_dictionary() {
868        let dict_id = 0xABCD_0001;
869        let dict =
870            crate::decoding::Dictionary::from_raw_content(dict_id, b"shared-history".to_vec())
871                .expect("raw dictionary should be valid");
872
873        let payload = b"plain-bytes-that-should-stay-raw";
874        let mut output = Vec::new();
875        let mut compressor = FrameCompressor::new(super::CompressionLevel::Uncompressed);
876        compressor
877            .set_dictionary(dict)
878            .expect("dictionary should attach in uncompressed mode");
879        compressor.set_source(payload.as_slice());
880        compressor.set_drain(&mut output);
881        compressor.compress();
882
883        let (frame_header, _) = crate::decoding::frame::read_frame_header(output.as_slice())
884            .expect("encoded frame should have a header");
885        assert_eq!(
886            frame_header.dictionary_id(),
887            None,
888            "raw/uncompressed frames must not advertise dictionary dependency"
889        );
890
891        let mut decoder = FrameDecoder::new();
892        let mut decoded = Vec::with_capacity(payload.len());
893        decoder.decode_all_to_vec(&output, &mut decoded).unwrap();
894        assert_eq!(decoded, payload);
895    }
896
897    #[test]
898    fn dictionary_roundtrip_stays_valid_after_output_exceeds_window() {
899        use crate::encoding::match_generator::MatchGeneratorDriver;
900
901        let dict_id = 0xABCD_0002;
902        let dict = crate::decoding::Dictionary::from_raw_content(dict_id, b"abcdefgh".to_vec())
903            .expect("raw dictionary should be valid");
904        let dict_for_decoder =
905            crate::decoding::Dictionary::from_raw_content(dict_id, b"abcdefgh".to_vec())
906                .expect("raw dictionary should be valid");
907
908        // Payload must exceed the encoder's advertised window (128 KiB for
909        // Fastest) so the test actually exercises cross-window-boundary behavior.
910        let payload = b"abcdefgh".repeat(128 * 1024 / 8 + 64);
911        let matcher = MatchGeneratorDriver::new(1024, 1);
912
913        let mut no_dict_output = Vec::new();
914        let mut no_dict_compressor =
915            FrameCompressor::new_with_matcher(matcher, super::CompressionLevel::Fastest);
916        no_dict_compressor.set_source(payload.as_slice());
917        no_dict_compressor.set_drain(&mut no_dict_output);
918        no_dict_compressor.compress();
919        let (no_dict_frame_header, _) =
920            crate::decoding::frame::read_frame_header(no_dict_output.as_slice())
921                .expect("baseline frame should have a header");
922        let no_dict_window = no_dict_frame_header
923            .window_size()
924            .expect("window size should be present");
925
926        let mut output = Vec::new();
927        let matcher = MatchGeneratorDriver::new(1024, 1);
928        let mut compressor =
929            FrameCompressor::new_with_matcher(matcher, super::CompressionLevel::Fastest);
930        compressor
931            .set_dictionary(dict)
932            .expect("dictionary should attach");
933        compressor.set_source(payload.as_slice());
934        compressor.set_drain(&mut output);
935        compressor.compress();
936
937        let (frame_header, _) = crate::decoding::frame::read_frame_header(output.as_slice())
938            .expect("encoded frame should have a header");
939        let advertised_window = frame_header
940            .window_size()
941            .expect("window size should be present");
942        assert_eq!(
943            advertised_window, no_dict_window,
944            "dictionary priming must not inflate advertised window size"
945        );
946        assert!(
947            payload.len() > advertised_window as usize,
948            "test must cross the advertised window boundary"
949        );
950
951        let mut decoder = FrameDecoder::new();
952        decoder.add_dict(dict_for_decoder).unwrap();
953        let mut decoded = Vec::with_capacity(payload.len());
954        decoder.decode_all_to_vec(&output, &mut decoded).unwrap();
955        assert_eq!(decoded, payload);
956    }
957
958    #[test]
959    fn source_size_hint_with_dictionary_keeps_roundtrip_and_nonincreasing_window() {
960        let dict_id = 0xABCD_0004;
961        let dict_content = b"abcd".repeat(1024); // 4 KiB dictionary history
962        let dict = crate::decoding::Dictionary::from_raw_content(dict_id, dict_content).unwrap();
963        let dict_for_decoder =
964            crate::decoding::Dictionary::from_raw_content(dict_id, b"abcd".repeat(1024)).unwrap();
965        let payload = b"abcdabcdabcdabcd".repeat(128);
966
967        let mut hinted_output = Vec::new();
968        let mut hinted = FrameCompressor::new(super::CompressionLevel::Fastest);
969        hinted.set_dictionary(dict).unwrap();
970        hinted.set_source_size_hint(1);
971        hinted.set_source(payload.as_slice());
972        hinted.set_drain(&mut hinted_output);
973        hinted.compress();
974
975        let mut no_hint_output = Vec::new();
976        let mut no_hint = FrameCompressor::new(super::CompressionLevel::Fastest);
977        no_hint
978            .set_dictionary(
979                crate::decoding::Dictionary::from_raw_content(dict_id, b"abcd".repeat(1024))
980                    .unwrap(),
981            )
982            .unwrap();
983        no_hint.set_source(payload.as_slice());
984        no_hint.set_drain(&mut no_hint_output);
985        no_hint.compress();
986
987        let hinted_window = crate::decoding::frame::read_frame_header(hinted_output.as_slice())
988            .expect("encoded frame should have a header")
989            .0
990            .window_size()
991            .expect("window size should be present");
992        let no_hint_window = crate::decoding::frame::read_frame_header(no_hint_output.as_slice())
993            .expect("encoded frame should have a header")
994            .0
995            .window_size()
996            .expect("window size should be present");
997        assert!(
998            hinted_window <= no_hint_window,
999            "source-size hint should not increase advertised window with dictionary priming",
1000        );
1001
1002        let mut decoder = FrameDecoder::new();
1003        decoder.add_dict(dict_for_decoder).unwrap();
1004        let mut decoded = Vec::with_capacity(payload.len());
1005        decoder
1006            .decode_all_to_vec(&hinted_output, &mut decoded)
1007            .unwrap();
1008        assert_eq!(decoded, payload);
1009    }
1010
1011    #[test]
1012    fn source_size_hint_with_dictionary_keeps_roundtrip_for_larger_payload() {
1013        let dict_id = 0xABCD_0005;
1014        let dict_content = b"abcd".repeat(1024); // 4 KiB dictionary history
1015        let dict = crate::decoding::Dictionary::from_raw_content(dict_id, dict_content).unwrap();
1016        let dict_for_decoder =
1017            crate::decoding::Dictionary::from_raw_content(dict_id, b"abcd".repeat(1024)).unwrap();
1018        let payload = b"abcd".repeat(1024); // 4 KiB payload
1019        let payload_len = payload.len() as u64;
1020
1021        let mut hinted_output = Vec::new();
1022        let mut hinted = FrameCompressor::new(super::CompressionLevel::Fastest);
1023        hinted.set_dictionary(dict).unwrap();
1024        hinted.set_source_size_hint(payload_len);
1025        hinted.set_source(payload.as_slice());
1026        hinted.set_drain(&mut hinted_output);
1027        hinted.compress();
1028
1029        let mut no_hint_output = Vec::new();
1030        let mut no_hint = FrameCompressor::new(super::CompressionLevel::Fastest);
1031        no_hint
1032            .set_dictionary(
1033                crate::decoding::Dictionary::from_raw_content(dict_id, b"abcd".repeat(1024))
1034                    .unwrap(),
1035            )
1036            .unwrap();
1037        no_hint.set_source(payload.as_slice());
1038        no_hint.set_drain(&mut no_hint_output);
1039        no_hint.compress();
1040
1041        let hinted_window = crate::decoding::frame::read_frame_header(hinted_output.as_slice())
1042            .expect("encoded frame should have a header")
1043            .0
1044            .window_size()
1045            .expect("window size should be present");
1046        let no_hint_window = crate::decoding::frame::read_frame_header(no_hint_output.as_slice())
1047            .expect("encoded frame should have a header")
1048            .0
1049            .window_size()
1050            .expect("window size should be present");
1051        assert!(
1052            hinted_window <= no_hint_window,
1053            "source-size hint should not increase advertised window with dictionary priming",
1054        );
1055
1056        let mut decoder = FrameDecoder::new();
1057        decoder.add_dict(dict_for_decoder).unwrap();
1058        let mut decoded = Vec::with_capacity(payload.len());
1059        decoder
1060            .decode_all_to_vec(&hinted_output, &mut decoded)
1061            .unwrap();
1062        assert_eq!(decoded, payload);
1063    }
1064
1065    #[test]
1066    fn custom_matcher_without_dictionary_priming_does_not_advertise_dict_id() {
1067        let dict_id = 0xABCD_0003;
1068        let dict = crate::decoding::Dictionary::from_raw_content(dict_id, b"abcdefgh".to_vec())
1069            .expect("raw dictionary should be valid");
1070        let payload = b"abcdefghabcdefgh";
1071
1072        let mut output = Vec::new();
1073        let matcher = NoDictionaryMatcher::new(64);
1074        let mut compressor =
1075            FrameCompressor::new_with_matcher(matcher, super::CompressionLevel::Fastest);
1076        compressor
1077            .set_dictionary(dict)
1078            .expect("dictionary should attach");
1079        compressor.set_source(payload.as_slice());
1080        compressor.set_drain(&mut output);
1081        compressor.compress();
1082
1083        let (frame_header, _) = crate::decoding::frame::read_frame_header(output.as_slice())
1084            .expect("encoded frame should have a header");
1085        assert_eq!(
1086            frame_header.dictionary_id(),
1087            None,
1088            "matchers that do not support dictionary priming must not advertise dictionary dependency"
1089        );
1090
1091        let mut decoder = FrameDecoder::new();
1092        let mut decoded = Vec::with_capacity(payload.len());
1093        decoder.decode_all_to_vec(&output, &mut decoded).unwrap();
1094        assert_eq!(decoded, payload);
1095    }
1096
1097    #[cfg(feature = "hash")]
1098    #[test]
1099    fn checksum_two_frames_reused_compressor() {
1100        // Compress the same data twice using the same compressor and verify that:
1101        // 1. The checksum written in each frame matches what the decoder calculates.
1102        // 2. The hasher is correctly reset between frames (no cross-contamination).
1103        //    If the hasher were NOT reset, the second frame's calculated checksum
1104        //    would differ from the one stored in the frame data, causing assert_eq to fail.
1105        let data: Vec<u8> = (0u8..=255).cycle().take(1024).collect();
1106
1107        let mut compressor = FrameCompressor::new(super::CompressionLevel::Uncompressed);
1108
1109        // --- Frame 1 ---
1110        let mut compressed1 = Vec::new();
1111        compressor.set_source(data.as_slice());
1112        compressor.set_drain(&mut compressed1);
1113        compressor.compress();
1114
1115        // --- Frame 2 (reuse the same compressor) ---
1116        let mut compressed2 = Vec::new();
1117        compressor.set_source(data.as_slice());
1118        compressor.set_drain(&mut compressed2);
1119        compressor.compress();
1120
1121        fn decode_and_collect(compressed: &[u8]) -> (Vec<u8>, Option<u32>, Option<u32>) {
1122            let mut decoder = FrameDecoder::new();
1123            let mut source = compressed;
1124            decoder.reset(&mut source).unwrap();
1125            while !decoder.is_finished() {
1126                decoder
1127                    .decode_blocks(&mut source, crate::decoding::BlockDecodingStrategy::All)
1128                    .unwrap();
1129            }
1130            let mut decoded = Vec::new();
1131            decoder.collect_to_writer(&mut decoded).unwrap();
1132            (
1133                decoded,
1134                decoder.get_checksum_from_data(),
1135                decoder.get_calculated_checksum(),
1136            )
1137        }
1138
1139        let (decoded1, chksum_from_data1, chksum_calculated1) = decode_and_collect(&compressed1);
1140        assert_eq!(decoded1, data, "frame 1: decoded data mismatch");
1141        assert_eq!(
1142            chksum_from_data1, chksum_calculated1,
1143            "frame 1: checksum mismatch"
1144        );
1145
1146        let (decoded2, chksum_from_data2, chksum_calculated2) = decode_and_collect(&compressed2);
1147        assert_eq!(decoded2, data, "frame 2: decoded data mismatch");
1148        assert_eq!(
1149            chksum_from_data2, chksum_calculated2,
1150            "frame 2: checksum mismatch"
1151        );
1152
1153        // Same data compressed twice must produce the same checksum.
1154        // If state leaked across frames, the second calculated checksum would differ.
1155        assert_eq!(
1156            chksum_from_data1, chksum_from_data2,
1157            "frame 1 and frame 2 should have the same checksum (same data, hash must reset per frame)"
1158        );
1159    }
1160
    #[cfg(feature = "std")]
    #[test]
    fn fuzz_targets() {
        use std::io::Read;
        // Decode using this crate's pull-based `StreamingDecoder` (Read API).
        fn decode_szstd(data: &mut dyn std::io::Read) -> Vec<u8> {
            let mut decoder = crate::decoding::StreamingDecoder::new(data).unwrap();
            let mut result: Vec<u8> = Vec::new();
            decoder.read_to_end(&mut result).expect("Decoding failed");
            result
        }

        // Decode using the block-oriented `FrameDecoder`, collecting output in
        // bounded (1 MiB) steps to exercise the incremental decode path.
        fn decode_szstd_writer(mut data: impl Read) -> Vec<u8> {
            let mut decoder = crate::decoding::FrameDecoder::new();
            decoder.reset(&mut data).unwrap();
            let mut result = vec![];
            while !decoder.is_finished() || decoder.can_collect() > 0 {
                decoder
                    .decode_blocks(
                        &mut data,
                        crate::decoding::BlockDecodingStrategy::UptoBytes(1024 * 1024),
                    )
                    .unwrap();
                decoder.collect_to_writer(&mut result).unwrap();
            }
            result
        }

        // Encode with the reference `zstd` crate at level 3 (interop oracle).
        fn encode_zstd(data: &[u8]) -> Result<Vec<u8>, std::io::Error> {
            zstd::stream::encode_all(std::io::Cursor::new(data), 3)
        }

        // Encode with this crate in raw/uncompressed mode.
        fn encode_szstd_uncompressed(data: &mut dyn std::io::Read) -> Vec<u8> {
            let mut input = Vec::new();
            data.read_to_end(&mut input).unwrap();

            crate::encoding::compress_to_vec(
                input.as_slice(),
                crate::encoding::CompressionLevel::Uncompressed,
            )
        }

        // Encode with this crate at the Fastest compression level.
        fn encode_szstd_compressed(data: &mut dyn std::io::Read) -> Vec<u8> {
            let mut input = Vec::new();
            data.read_to_end(&mut input).unwrap();

            crate::encoding::compress_to_vec(
                input.as_slice(),
                crate::encoding::CompressionLevel::Fastest,
            )
        }

        // Decode with the reference `zstd` crate (interop oracle).
        fn decode_zstd(data: &[u8]) -> Result<Vec<u8>, std::io::Error> {
            let mut output = Vec::new();
            zstd::stream::copy_decode(data, &mut output)?;
            Ok(output)
        }
        // Replay locally saved fuzzer artifacts through both interop
        // directions; the test is a silent no-op when the directory is absent.
        if std::fs::exists("fuzz/artifacts/interop").unwrap_or(false) {
            for file in std::fs::read_dir("fuzz/artifacts/interop").unwrap() {
                if file.as_ref().unwrap().file_type().unwrap().is_file() {
                    let data = std::fs::read(file.unwrap().path()).unwrap();
                    let data = data.as_slice();
                    // Decoding: reference encoder -> both of our decoders.
                    let compressed = encode_zstd(data).unwrap();
                    let decoded = decode_szstd(&mut compressed.as_slice());
                    let decoded2 = decode_szstd_writer(&mut compressed.as_slice());
                    assert!(
                        decoded == data,
                        "Decoded data did not match the original input during decompression"
                    );
                    assert_eq!(
                        decoded2, data,
                        "Decoded data did not match the original input during decompression"
                    );

                    // Encoding: our encoder -> reference decoder.
                    // Uncompressed encoding
                    let mut input = data;
                    let compressed = encode_szstd_uncompressed(&mut input);
                    let decoded = decode_zstd(&compressed).unwrap();
                    assert_eq!(
                        decoded, data,
                        "Decoded data did not match the original input during compression"
                    );
                    // Compressed encoding
                    let mut input = data;
                    let compressed = encode_szstd_compressed(&mut input);
                    let decoded = decode_zstd(&compressed).unwrap();
                    assert_eq!(
                        decoded, data,
                        "Decoded data did not match the original input during compression"
                    );
                }
            }
        }
    }
1256}