Skip to main content

structured_zstd/encoding/
frame_compressor.rs

1//! Utilities and interfaces for encoding an entire frame. Allows reusing resources
2
3use alloc::{boxed::Box, vec::Vec};
4use core::convert::TryInto;
5#[cfg(feature = "hash")]
6use twox_hash::XxHash64;
7
8#[cfg(feature = "hash")]
9use core::hash::Hasher;
10
11use super::{
12    CompressionLevel, Matcher, block_header::BlockHeader, frame_header::FrameHeader, levels::*,
13    match_generator::MatchGeneratorDriver,
14};
15use crate::fse::fse_encoder::{FSETable, default_ll_table, default_ml_table, default_of_table};
16
17use crate::io::{Read, Write};
18
19/// An interface for compressing arbitrary data with the ZStandard compression algorithm.
20///
21/// `FrameCompressor` will generally be used by:
22/// 1. Initializing a compressor by providing a buffer of data using `FrameCompressor::new()`
23/// 2. Starting compression and writing that compression into a vec using `FrameCompressor::begin`
24///
25/// # Examples
26/// ```
27/// use structured_zstd::encoding::{FrameCompressor, CompressionLevel};
28/// let mock_data: &[_] = &[0x1, 0x2, 0x3, 0x4];
29/// let mut output = std::vec::Vec::new();
30/// // Initialize a compressor.
31/// let mut compressor = FrameCompressor::new(CompressionLevel::Uncompressed);
32/// compressor.set_source(mock_data);
33/// compressor.set_drain(&mut output);
34///
35/// // `compress` writes the compressed output into the provided buffer.
36/// compressor.compress();
37/// ```
pub struct FrameCompressor<R: Read, W: Write, M: Matcher> {
    /// Source of the uncompressed input; `None` until `set_source` is called.
    uncompressed_data: Option<R>,
    /// Sink the encoded frame is written to; `None` until `set_drain` is called.
    compressed_data: Option<W>,
    /// Level selecting between raw blocks and compressed block encoding.
    compression_level: CompressionLevel,
    /// Optional dictionary used to prime the matcher and seed entropy tables.
    dictionary: Option<crate::decoding::Dictionary>,
    /// Encoder-side entropy tables derived from `dictionary`, cached so repeated
    /// `compress` calls do not have to re-convert the decoder-side tables.
    dictionary_entropy_cache: Option<CachedDictionaryEntropy>,
    /// Mutable per-frame encoding state (matcher, entropy tables, offset history).
    state: CompressState<M>,
    /// Running hash of the uncompressed input, emitted as the content checksum.
    #[cfg(feature = "hash")]
    hasher: XxHash64,
}
48
/// Encoder-side entropy tables pre-built from an attached dictionary.
///
/// Populated in `set_dictionary` so each `compress` call can cheaply seed the
/// "previous" Huffman/FSE tables without re-deriving them from the dictionary.
#[derive(Clone, Default)]
struct CachedDictionaryEntropy {
    /// Huffman table for literals, if the dictionary carried one.
    huff: Option<crate::huff0::huff0_encoder::HuffmanTable>,
    /// Literal-lengths FSE table to treat as the "previous" table.
    ll_previous: Option<PreviousFseTable>,
    /// Match-lengths FSE table to treat as the "previous" table.
    ml_previous: Option<PreviousFseTable>,
    /// Offsets FSE table to treat as the "previous" table.
    of_previous: Option<PreviousFseTable>,
}
56
/// An FSE table remembered from a previous block or a dictionary, so later
/// blocks can reuse it instead of re-transmitting a table description.
#[derive(Clone)]
pub(crate) enum PreviousFseTable {
    // Default tables are immutable and already stored alongside the state, so
    // repeating them only needs a lightweight marker instead of cloning FSETable.
    Default,
    Custom(Box<FSETable>),
}
64
65impl PreviousFseTable {
66    pub(crate) fn as_table<'a>(&'a self, default: &'a FSETable) -> &'a FSETable {
67        match self {
68            Self::Default => default,
69            Self::Custom(table) => table,
70        }
71    }
72}
73
/// Default and remembered ("previous") FSE tables for the three sequence streams.
pub(crate) struct FseTables {
    /// Predefined literal-lengths table.
    pub(crate) ll_default: FSETable,
    /// Literal-lengths table carried over from the last block/dictionary, if any.
    pub(crate) ll_previous: Option<PreviousFseTable>,
    /// Predefined match-lengths table.
    pub(crate) ml_default: FSETable,
    /// Match-lengths table carried over from the last block/dictionary, if any.
    pub(crate) ml_previous: Option<PreviousFseTable>,
    /// Predefined offsets table.
    pub(crate) of_default: FSETable,
    /// Offsets table carried over from the last block/dictionary, if any.
    pub(crate) of_previous: Option<PreviousFseTable>,
}
82
83impl FseTables {
84    pub fn new() -> Self {
85        Self {
86            ll_default: default_ll_table(),
87            ll_previous: None,
88            ml_default: default_ml_table(),
89            ml_previous: None,
90            of_default: default_of_table(),
91            of_previous: None,
92        }
93    }
94}
95
/// Mutable encoder state shared across the blocks of a frame.
pub(crate) struct CompressState<M: Matcher> {
    /// Match finder producing sequences for compressed blocks.
    pub(crate) matcher: M,
    /// Huffman table used for the previous block's literals, if any.
    pub(crate) last_huff_table: Option<crate::huff0::huff0_encoder::HuffmanTable>,
    /// Default and remembered FSE tables for sequence encoding.
    pub(crate) fse_tables: FseTables,
    /// Offset history for repeat offset encoding: [rep0, rep1, rep2].
    /// Initialized to [1, 4, 8] per RFC 8878 §3.1.2.5.
    pub(crate) offset_hist: [u32; 3],
}
104
105impl<R: Read, W: Write> FrameCompressor<R, W, MatchGeneratorDriver> {
106    /// Create a new `FrameCompressor`
107    pub fn new(compression_level: CompressionLevel) -> Self {
108        Self {
109            uncompressed_data: None,
110            compressed_data: None,
111            compression_level,
112            dictionary: None,
113            dictionary_entropy_cache: None,
114            state: CompressState {
115                matcher: MatchGeneratorDriver::new(1024 * 128, 1),
116                last_huff_table: None,
117                fse_tables: FseTables::new(),
118                offset_hist: [1, 4, 8],
119            },
120            #[cfg(feature = "hash")]
121            hasher: XxHash64::with_seed(0),
122        }
123    }
124}
125
126impl<R: Read, W: Write, M: Matcher> FrameCompressor<R, W, M> {
127    /// Create a new `FrameCompressor` with a custom matching algorithm implementation
128    pub fn new_with_matcher(matcher: M, compression_level: CompressionLevel) -> Self {
129        Self {
130            uncompressed_data: None,
131            compressed_data: None,
132            dictionary: None,
133            dictionary_entropy_cache: None,
134            state: CompressState {
135                matcher,
136                last_huff_table: None,
137                fse_tables: FseTables::new(),
138                offset_hist: [1, 4, 8],
139            },
140            compression_level,
141            #[cfg(feature = "hash")]
142            hasher: XxHash64::with_seed(0),
143        }
144    }
145
    /// Before calling [FrameCompressor::compress] you need to set the source.
    ///
    /// This is the data that is compressed and written into the drain.
    /// Returns the previously set source, if any.
    pub fn set_source(&mut self, uncompressed_data: R) -> Option<R> {
        self.uncompressed_data.replace(uncompressed_data)
    }
152
    /// Before calling [FrameCompressor::compress] you need to set the drain.
    ///
    /// As the compressor compresses data, the drain serves as a place for the output to be written.
    /// Returns the previously set drain, if any.
    pub fn set_drain(&mut self, compressed_data: W) -> Option<W> {
        self.compressed_data.replace(compressed_data)
    }
159
    /// Compress the uncompressed data from the provided source as one Zstd frame and write it to the provided drain
    ///
    /// This will repeatedly call [Read::read] on the source to fill up blocks until the source returns 0 on the read call.
    /// All compressed blocks are buffered in memory so that the frame header can include the
    /// `Frame_Content_Size` field (which requires knowing the total uncompressed size). The
    /// entire frame — header, blocks, and optional checksum — is then written to the drain
    /// at the end. This means peak memory usage is O(compressed_size).
    ///
    /// To avoid endlessly encoding from a potentially endless source (like a network socket) you can use the
    /// [Read::take] function
    ///
    /// # Panics
    /// Panics if no source or drain has been set, if a read/write on them fails,
    /// or if the matcher reports a window size of zero.
    pub fn compress(&mut self) {
        // Clearing buffers to allow re-using of the compressor
        self.state.matcher.reset(self.compression_level);
        self.state.offset_hist = [1, 4, 8];
        // Dictionary state only applies when we actually produce compressed blocks
        // and the matcher knows how to be primed with dictionary content.
        let use_dictionary_state =
            !matches!(self.compression_level, CompressionLevel::Uncompressed)
                && self.state.matcher.supports_dictionary_priming();
        let cached_entropy = if use_dictionary_state {
            self.dictionary_entropy_cache.as_ref()
        } else {
            None
        };
        if use_dictionary_state && let Some(dict) = self.dictionary.as_ref() {
            // This state drives sequence encoding, while matcher priming below updates
            // the match generator's internal repeat-offset history for match finding.
            self.state.offset_hist = dict.offset_hist;
            self.state
                .matcher
                .prime_with_dictionary(dict.dict_content.as_slice(), dict.offset_hist);
        }
        if let Some(cache) = cached_entropy {
            self.state.last_huff_table.clone_from(&cache.huff);
        } else {
            self.state.last_huff_table = None;
        }
        // `clone_from` keeps frame-to-frame seeding cheap for reused compressors by
        // reusing existing allocations where possible instead of reallocating every frame.
        if let Some(cache) = cached_entropy {
            self.state
                .fse_tables
                .ll_previous
                .clone_from(&cache.ll_previous);
            self.state
                .fse_tables
                .ml_previous
                .clone_from(&cache.ml_previous);
            self.state
                .fse_tables
                .of_previous
                .clone_from(&cache.of_previous);
        } else {
            self.state.fse_tables.ll_previous = None;
            self.state.fse_tables.ml_previous = None;
            self.state.fse_tables.of_previous = None;
        }
        // Restart the checksum for this frame so reuse doesn't mix inputs.
        #[cfg(feature = "hash")]
        {
            self.hasher = XxHash64::with_seed(0);
        }
        let source = self.uncompressed_data.as_mut().unwrap();
        let drain = self.compressed_data.as_mut().unwrap();
        let window_size = self.state.matcher.window_size();
        assert!(
            window_size != 0,
            "matcher reported window_size == 0, which is invalid"
        );
        // Accumulate all compressed blocks; the frame header is written after
        // all input has been read so that Frame_Content_Size is known.
        let mut all_blocks: Vec<u8> = Vec::with_capacity(1024 * 130);
        let mut total_uncompressed: u64 = 0;
        // Compress block by block
        loop {
            // Read a single block's worth of uncompressed data from the input
            let mut uncompressed_data = self.state.matcher.get_next_space();
            let mut read_bytes = 0;
            let last_block;
            'read_loop: loop {
                let new_bytes = source.read(&mut uncompressed_data[read_bytes..]).unwrap();
                if new_bytes == 0 {
                    // Source exhausted: whatever we gathered becomes the last block.
                    last_block = true;
                    break 'read_loop;
                }
                read_bytes += new_bytes;
                if read_bytes == uncompressed_data.len() {
                    last_block = false;
                    break 'read_loop;
                }
            }
            // Shrink the buffer to the bytes actually read (short final block).
            uncompressed_data.resize(read_bytes, 0);
            total_uncompressed += read_bytes as u64;
            // As we read, hash that data too
            #[cfg(feature = "hash")]
            self.hasher.write(&uncompressed_data);
            // Special handling is needed for compression of a totally empty file
            if uncompressed_data.is_empty() {
                let header = BlockHeader {
                    last_block: true,
                    block_type: crate::blocks::block::BlockType::Raw,
                    block_size: 0,
                };
                header.serialize(&mut all_blocks);
                break;
            }

            match self.compression_level {
                CompressionLevel::Uncompressed => {
                    // Raw block: header followed by the input bytes verbatim.
                    let header = BlockHeader {
                        last_block,
                        block_type: crate::blocks::block::BlockType::Raw,
                        block_size: read_bytes.try_into().unwrap(),
                    };
                    header.serialize(&mut all_blocks);
                    all_blocks.extend_from_slice(&uncompressed_data);
                }
                CompressionLevel::Fastest
                | CompressionLevel::Default
                | CompressionLevel::Better
                | CompressionLevel::Best => compress_block_encoded(
                    &mut self.state,
                    last_block,
                    uncompressed_data,
                    &mut all_blocks,
                ),
            }
            if last_block {
                break;
            }
        }

        // Now that total_uncompressed is known, write the frame header with FCS.
        // We always include the window descriptor (single_segment = false) because
        // compressed blocks are encoded against the matcher's window, not the content
        // size. Setting single_segment would tell the decoder to use FCS as window,
        // which can be smaller than the encoder's actual window and trip up decoders.
        let header = FrameHeader {
            frame_content_size: Some(total_uncompressed),
            single_segment: false,
            content_checksum: cfg!(feature = "hash"),
            dictionary_id: if use_dictionary_state {
                self.dictionary.as_ref().map(|dict| dict.id as u64)
            } else {
                None
            },
            window_size: Some(window_size),
        };
        // Write the frame header and compressed blocks separately to avoid
        // shifting the entire `all_blocks` buffer to prepend the header.
        let mut header_buf: Vec<u8> = Vec::with_capacity(14);
        header.serialize(&mut header_buf);
        drain.write_all(&header_buf).unwrap();
        drain.write_all(&all_blocks).unwrap();

        // If the `hash` feature is enabled, then `content_checksum` is set to true in the header
        // and a 32 bit hash is written at the end of the data.
        #[cfg(feature = "hash")]
        {
            // Because we only have the data as a reader, we need to read all of it to calculate the checksum
            // Possible TODO: create a wrapper around self.uncompressed data that hashes the data as it's read?
            // Only the low 32 bits of the 64-bit hash are emitted.
            let content_checksum = self.hasher.finish();
            drain
                .write_all(&(content_checksum as u32).to_le_bytes())
                .unwrap();
        }
    }
324
    /// Get a mutable reference to the source, or `None` if no source is set.
    pub fn source_mut(&mut self) -> Option<&mut R> {
        self.uncompressed_data.as_mut()
    }
329
    /// Get a mutable reference to the drain, or `None` if no drain is set.
    pub fn drain_mut(&mut self) -> Option<&mut W> {
        self.compressed_data.as_mut()
    }
334
    /// Get a reference to the source, or `None` if no source is set.
    pub fn source(&self) -> Option<&R> {
        self.uncompressed_data.as_ref()
    }
339
    /// Get a reference to the drain, or `None` if no drain is set.
    pub fn drain(&self) -> Option<&W> {
        self.compressed_data.as_ref()
    }
344
    /// Retrieve the source, leaving the compressor without one.
    pub fn take_source(&mut self) -> Option<R> {
        self.uncompressed_data.take()
    }
349
    /// Retrieve the drain, leaving the compressor without one.
    pub fn take_drain(&mut self) -> Option<W> {
        self.compressed_data.take()
    }
354
355    /// Before calling [FrameCompressor::compress] you can replace the matcher
356    pub fn replace_matcher(&mut self, mut match_generator: M) -> M {
357        core::mem::swap(&mut match_generator, &mut self.state.matcher);
358        match_generator
359    }
360
361    /// Before calling [FrameCompressor::compress] you can replace the compression level
362    pub fn set_compression_level(
363        &mut self,
364        compression_level: CompressionLevel,
365    ) -> CompressionLevel {
366        let old = self.compression_level;
367        self.compression_level = compression_level;
368        old
369    }
370
    /// Get the current compression level
    pub fn compression_level(&self) -> CompressionLevel {
        self.compression_level
    }
375
376    /// Attach a pre-parsed dictionary to be used for subsequent compressions.
377    ///
378    /// In compressed modes, the dictionary id is written only when the active
379    /// matcher supports dictionary priming.
380    /// Uncompressed mode and non-priming matchers ignore the attached dictionary
381    /// at encode time.
382    pub fn set_dictionary(
383        &mut self,
384        dictionary: crate::decoding::Dictionary,
385    ) -> Result<Option<crate::decoding::Dictionary>, crate::decoding::errors::DictionaryDecodeError>
386    {
387        if dictionary.id == 0 {
388            return Err(crate::decoding::errors::DictionaryDecodeError::ZeroDictionaryId);
389        }
390        if let Some(index) = dictionary.offset_hist.iter().position(|&rep| rep == 0) {
391            return Err(
392                crate::decoding::errors::DictionaryDecodeError::ZeroRepeatOffsetInDictionary {
393                    index: index as u8,
394                },
395            );
396        }
397        self.dictionary_entropy_cache = Some(CachedDictionaryEntropy {
398            huff: dictionary.huf.table.to_encoder_table(),
399            ll_previous: dictionary
400                .fse
401                .literal_lengths
402                .to_encoder_table()
403                .map(|table| PreviousFseTable::Custom(Box::new(table))),
404            ml_previous: dictionary
405                .fse
406                .match_lengths
407                .to_encoder_table()
408                .map(|table| PreviousFseTable::Custom(Box::new(table))),
409            of_previous: dictionary
410                .fse
411                .offsets
412                .to_encoder_table()
413                .map(|table| PreviousFseTable::Custom(Box::new(table))),
414        });
415        Ok(self.dictionary.replace(dictionary))
416    }
417
418    /// Parse and attach a serialized dictionary blob.
419    pub fn set_dictionary_from_bytes(
420        &mut self,
421        raw_dictionary: &[u8],
422    ) -> Result<Option<crate::decoding::Dictionary>, crate::decoding::errors::DictionaryDecodeError>
423    {
424        let dictionary = crate::decoding::Dictionary::decode_dict(raw_dictionary)?;
425        self.set_dictionary(dictionary)
426    }
427
428    /// Remove the attached dictionary.
429    pub fn clear_dictionary(&mut self) -> Option<crate::decoding::Dictionary> {
430        self.dictionary_entropy_cache = None;
431        self.dictionary.take()
432    }
433}
434
435#[cfg(test)]
436mod tests {
437    #[cfg(all(feature = "dict_builder", feature = "std"))]
438    use alloc::format;
439    use alloc::vec;
440
441    use super::FrameCompressor;
442    use crate::common::MAGIC_NUM;
443    use crate::decoding::FrameDecoder;
444    use crate::encoding::{Matcher, Sequence};
445    use alloc::vec::Vec;
446
    /// Frame content size is written correctly and C zstd can decompress the output.
    #[cfg(feature = "std")]
    #[test]
    fn fcs_header_written_and_c_zstd_compatible() {
        // Exercise every compression level against inputs chosen to hit
        // different FCS field widths (empty, 1 byte, small, 2-byte range, large).
        let levels = [
            crate::encoding::CompressionLevel::Uncompressed,
            crate::encoding::CompressionLevel::Fastest,
            crate::encoding::CompressionLevel::Default,
            crate::encoding::CompressionLevel::Better,
            crate::encoding::CompressionLevel::Best,
        ];
        let fcs_2byte = vec![0xCDu8; 300]; // 300 bytes → 2-byte FCS (256..=65791 range)
        let large = vec![0xABu8; 100_000];
        let inputs: [&[u8]; 5] = [
            &[],
            &[0x00],
            b"abcdefghijklmnopqrstuvwxy\n",
            &fcs_2byte,
            &large,
        ];
        for level in levels {
            for data in &inputs {
                let compressed = crate::encoding::compress_to_vec(*data, level);
                // Verify FCS is present and correct
                let header = crate::decoding::frame::read_frame_header(compressed.as_slice())
                    .unwrap()
                    .0;
                assert_eq!(
                    header.frame_content_size(),
                    data.len() as u64,
                    "FCS mismatch for len={} level={:?}",
                    data.len(),
                    level as u8,
                );
                // Confirm the FCS field is actually present in the header
                // (not just the decoder returning 0 for absent FCS).
                assert_ne!(
                    header.descriptor.frame_content_size_bytes().unwrap(),
                    0,
                    "FCS field must be present for len={} level={:?}",
                    data.len(),
                    level as u8,
                );
                // Verify C zstd can decompress
                let mut decoded = Vec::new();
                zstd::stream::copy_decode(compressed.as_slice(), &mut decoded).unwrap();
                assert_eq!(
                    decoded.as_slice(),
                    *data,
                    "C zstd roundtrip failed for len={}",
                    data.len()
                );
            }
        }
    }
502
    // Test matcher that performs no match finding; see the `Matcher` impl below.
    struct NoDictionaryMatcher {
        // Buffer most recently handed back via `commit_space`.
        last_space: Vec<u8>,
        // Window size reported through `Matcher::window_size`.
        window_size: u64,
    }
507
508    impl NoDictionaryMatcher {
509        fn new(window_size: u64) -> Self {
510            Self {
511                last_space: Vec::new(),
512                window_size,
513            }
514        }
515    }
516
    // Minimal Matcher for tests: no match finding; every committed buffer is
    // emitted as a single literals-only sequence.
    // NOTE(review): the name suggests it relies on the trait's default
    // `supports_dictionary_priming` (presumably `false`) — confirm against the
    // `Matcher` trait definition.
    impl Matcher for NoDictionaryMatcher {
        // Hands out a fresh zeroed buffer of one full window for the caller to fill.
        fn get_next_space(&mut self) -> Vec<u8> {
            vec![0; self.window_size as usize]
        }

        fn get_last_space(&mut self) -> &[u8] {
            self.last_space.as_slice()
        }

        // Stores the filled buffer so `start_matching` can emit it.
        fn commit_space(&mut self, space: Vec<u8>) {
            self.last_space = space;
        }

        fn skip_matching(&mut self) {}

        // Emits the whole committed buffer as literals — no matches are produced.
        fn start_matching(&mut self, mut handle_sequence: impl for<'a> FnMut(Sequence<'a>)) {
            handle_sequence(Sequence::Literals {
                literals: self.last_space.as_slice(),
            });
        }

        fn reset(&mut self, _level: super::CompressionLevel) {
            self.last_space.clear();
        }

        fn window_size(&self) -> u64 {
            self.window_size
        }
    }
546
547    #[test]
548    fn frame_starts_with_magic_num() {
549        let mock_data = [1_u8, 2, 3].as_slice();
550        let mut output: Vec<u8> = Vec::new();
551        let mut compressor = FrameCompressor::new(super::CompressionLevel::Uncompressed);
552        compressor.set_source(mock_data);
553        compressor.set_drain(&mut output);
554
555        compressor.compress();
556        assert!(output.starts_with(&MAGIC_NUM.to_le_bytes()));
557    }
558
559    #[test]
560    fn very_simple_raw_compress() {
561        let mock_data = [1_u8, 2, 3].as_slice();
562        let mut output: Vec<u8> = Vec::new();
563        let mut compressor = FrameCompressor::new(super::CompressionLevel::Uncompressed);
564        compressor.set_source(mock_data);
565        compressor.set_drain(&mut output);
566
567        compressor.compress();
568    }
569
570    #[test]
571    fn very_simple_compress() {
572        let mut mock_data = vec![0; 1 << 17];
573        mock_data.extend(vec![1; (1 << 17) - 1]);
574        mock_data.extend(vec![2; (1 << 18) - 1]);
575        mock_data.extend(vec![2; 1 << 17]);
576        mock_data.extend(vec![3; (1 << 17) - 1]);
577        let mut output: Vec<u8> = Vec::new();
578        let mut compressor = FrameCompressor::new(super::CompressionLevel::Uncompressed);
579        compressor.set_source(mock_data.as_slice());
580        compressor.set_drain(&mut output);
581
582        compressor.compress();
583
584        let mut decoder = FrameDecoder::new();
585        let mut decoded = Vec::with_capacity(mock_data.len());
586        decoder.decode_all_to_vec(&output, &mut decoded).unwrap();
587        assert_eq!(mock_data, decoded);
588
589        let mut decoded = Vec::new();
590        zstd::stream::copy_decode(output.as_slice(), &mut decoded).unwrap();
591        assert_eq!(mock_data, decoded);
592    }
593
594    #[test]
595    fn rle_compress() {
596        let mock_data = vec![0; 1 << 19];
597        let mut output: Vec<u8> = Vec::new();
598        let mut compressor = FrameCompressor::new(super::CompressionLevel::Uncompressed);
599        compressor.set_source(mock_data.as_slice());
600        compressor.set_drain(&mut output);
601
602        compressor.compress();
603
604        let mut decoder = FrameDecoder::new();
605        let mut decoded = Vec::with_capacity(mock_data.len());
606        decoder.decode_all_to_vec(&output, &mut decoded).unwrap();
607        assert_eq!(mock_data, decoded);
608    }
609
610    #[test]
611    fn aaa_compress() {
612        let mock_data = vec![0, 1, 3, 4, 5];
613        let mut output: Vec<u8> = Vec::new();
614        let mut compressor = FrameCompressor::new(super::CompressionLevel::Uncompressed);
615        compressor.set_source(mock_data.as_slice());
616        compressor.set_drain(&mut output);
617
618        compressor.compress();
619
620        let mut decoder = FrameDecoder::new();
621        let mut decoded = Vec::with_capacity(mock_data.len());
622        decoder.decode_all_to_vec(&output, &mut decoded).unwrap();
623        assert_eq!(mock_data, decoded);
624
625        let mut decoded = Vec::new();
626        zstd::stream::copy_decode(output.as_slice(), &mut decoded).unwrap();
627        assert_eq!(mock_data, decoded);
628    }
629
    #[test]
    fn dictionary_compression_sets_required_dict_id_and_roundtrips() {
        let dict_raw = include_bytes!("../../dict_tests/dictionary");
        let dict_for_encoder = crate::decoding::Dictionary::decode_dict(dict_raw).unwrap();
        let dict_for_decoder = crate::decoding::Dictionary::decode_dict(dict_raw).unwrap();

        // Payload built from dictionary content so the dictionary is actually useful.
        let mut data = Vec::new();
        for _ in 0..8 {
            data.extend_from_slice(&dict_for_decoder.dict_content[..2048]);
        }

        let mut with_dict = Vec::new();
        let mut compressor = FrameCompressor::new(super::CompressionLevel::Fastest);
        let previous = compressor
            .set_dictionary_from_bytes(dict_raw)
            .expect("dictionary bytes should parse");
        assert!(
            previous.is_none(),
            "first dictionary insert should return None"
        );
        // Attaching a second dictionary must hand back the first one.
        assert_eq!(
            compressor
                .set_dictionary(dict_for_encoder)
                .expect("valid dictionary should attach")
                .expect("set_dictionary_from_bytes inserted previous dictionary")
                .id,
            dict_for_decoder.id
        );
        compressor.set_source(data.as_slice());
        compressor.set_drain(&mut with_dict);
        compressor.compress();

        // The frame header must carry the dictionary id.
        let (frame_header, _) = crate::decoding::frame::read_frame_header(with_dict.as_slice())
            .expect("encoded stream should have a frame header");
        assert_eq!(frame_header.dictionary_id(), Some(dict_for_decoder.id));

        // Without the dictionary registered, decoding must fail loudly.
        let mut decoder = FrameDecoder::new();
        let mut missing_dict_target = Vec::with_capacity(data.len());
        let err = decoder
            .decode_all_to_vec(&with_dict, &mut missing_dict_target)
            .unwrap_err();
        assert!(
            matches!(
                &err,
                crate::decoding::errors::FrameDecoderError::DictNotProvided { .. }
            ),
            "dict-compressed stream should require dictionary id, got: {err:?}"
        );

        // With the dictionary registered, our decoder round-trips the payload.
        let mut decoder = FrameDecoder::new();
        decoder.add_dict(dict_for_decoder).unwrap();
        let mut decoded = Vec::with_capacity(data.len());
        decoder.decode_all_to_vec(&with_dict, &mut decoded).unwrap();
        assert_eq!(decoded, data);

        // The reference C implementation must accept the stream as well.
        let mut ffi_decoder = zstd::bulk::Decompressor::with_dictionary(dict_raw).unwrap();
        let mut ffi_decoded = Vec::with_capacity(data.len());
        let ffi_written = ffi_decoder
            .decompress_to_buffer(with_dict.as_slice(), &mut ffi_decoded)
            .unwrap();
        assert_eq!(ffi_written, data.len());
        assert_eq!(ffi_decoded, data);
    }
693
    #[cfg(all(feature = "dict_builder", feature = "std"))]
    #[test]
    fn dictionary_compression_roundtrips_with_dict_builder_dictionary() {
        use std::io::Cursor;

        // Train a dictionary on structured, repetitive sample data.
        let mut training = Vec::new();
        for idx in 0..256u32 {
            training.extend_from_slice(
                format!("tenant=demo table=orders key={idx} region=eu\n").as_bytes(),
            );
        }
        let mut raw_dict = Vec::new();
        crate::dictionary::create_raw_dict_from_source(
            Cursor::new(training.as_slice()),
            training.len(),
            &mut raw_dict,
            4096,
        );
        assert!(
            !raw_dict.is_empty(),
            "dict_builder produced an empty dictionary"
        );

        let dict_id = 0xD1C7_0008;
        let encoder_dict =
            crate::decoding::Dictionary::from_raw_content(dict_id, raw_dict.clone()).unwrap();
        let decoder_dict =
            crate::decoding::Dictionary::from_raw_content(dict_id, raw_dict.clone()).unwrap();

        // Payload shares structure with the training data so the dictionary helps.
        let mut payload = Vec::new();
        for idx in 0..96u32 {
            payload.extend_from_slice(
                format!(
                    "tenant=demo table=orders op=put key={idx} value=aaaaabbbbbcccccdddddeeeee\n"
                )
                .as_bytes(),
            );
        }

        // Baseline: the same payload compressed without any dictionary.
        let mut without_dict = Vec::new();
        let mut baseline = FrameCompressor::new(super::CompressionLevel::Fastest);
        baseline.set_source(payload.as_slice());
        baseline.set_drain(&mut without_dict);
        baseline.compress();

        let mut with_dict = Vec::new();
        let mut compressor = FrameCompressor::new(super::CompressionLevel::Fastest);
        compressor
            .set_dictionary(encoder_dict)
            .expect("valid dict_builder dictionary should attach");
        compressor.set_source(payload.as_slice());
        compressor.set_drain(&mut with_dict);
        compressor.compress();

        // Header must advertise the dictionary id; decoder with the dictionary
        // must round-trip; and the dictionary should beat the baseline size.
        let (frame_header, _) = crate::decoding::frame::read_frame_header(with_dict.as_slice())
            .expect("encoded stream should have a frame header");
        assert_eq!(frame_header.dictionary_id(), Some(dict_id));
        let mut decoder = FrameDecoder::new();
        decoder.add_dict(decoder_dict).unwrap();
        let mut decoded = Vec::with_capacity(payload.len());
        decoder.decode_all_to_vec(&with_dict, &mut decoded).unwrap();
        assert_eq!(decoded, payload);
        assert!(
            with_dict.len() < without_dict.len(),
            "trained dictionary should improve compression for this small payload"
        );
    }
761
    #[test]
    fn set_dictionary_from_bytes_seeds_entropy_tables_for_first_block() {
        let dict_raw = include_bytes!("../../dict_tests/dictionary");
        let mut output = Vec::new();
        // Empty input: even with no data, compress() should seed entropy state.
        let input = b"";

        let mut compressor = FrameCompressor::new(super::CompressionLevel::Fastest);
        let previous = compressor
            .set_dictionary_from_bytes(dict_raw)
            .expect("dictionary bytes should parse");
        assert!(previous.is_none());

        compressor.set_source(input.as_slice());
        compressor.set_drain(&mut output);
        compressor.compress();

        // All four "previous" entropy tables must be populated from the dictionary.
        assert!(
            compressor.state.last_huff_table.is_some(),
            "dictionary entropy should seed previous huffman table before first block"
        );
        assert!(
            compressor.state.fse_tables.ll_previous.is_some(),
            "dictionary entropy should seed previous ll table before first block"
        );
        assert!(
            compressor.state.fse_tables.ml_previous.is_some(),
            "dictionary entropy should seed previous ml table before first block"
        );
        assert!(
            compressor.state.fse_tables.of_previous.is_some(),
            "dictionary entropy should seed previous of table before first block"
        );
    }
795
796    #[test]
797    fn set_dictionary_rejects_zero_dictionary_id() {
798        let invalid = crate::decoding::Dictionary {
799            id: 0,
800            fse: crate::decoding::scratch::FSEScratch::new(),
801            huf: crate::decoding::scratch::HuffmanScratch::new(),
802            dict_content: vec![1, 2, 3],
803            offset_hist: [1, 4, 8],
804        };
805
806        let mut compressor: FrameCompressor<
807            &[u8],
808            Vec<u8>,
809            crate::encoding::match_generator::MatchGeneratorDriver,
810        > = FrameCompressor::new(super::CompressionLevel::Fastest);
811        let result = compressor.set_dictionary(invalid);
812        assert!(matches!(
813            result,
814            Err(crate::decoding::errors::DictionaryDecodeError::ZeroDictionaryId)
815        ));
816    }
817
818    #[test]
819    fn set_dictionary_rejects_zero_repeat_offsets() {
820        let invalid = crate::decoding::Dictionary {
821            id: 1,
822            fse: crate::decoding::scratch::FSEScratch::new(),
823            huf: crate::decoding::scratch::HuffmanScratch::new(),
824            dict_content: vec![1, 2, 3],
825            offset_hist: [0, 4, 8],
826        };
827
828        let mut compressor: FrameCompressor<
829            &[u8],
830            Vec<u8>,
831            crate::encoding::match_generator::MatchGeneratorDriver,
832        > = FrameCompressor::new(super::CompressionLevel::Fastest);
833        let result = compressor.set_dictionary(invalid);
834        assert!(matches!(
835            result,
836            Err(
837                crate::decoding::errors::DictionaryDecodeError::ZeroRepeatOffsetInDictionary {
838                    index: 0
839                }
840            )
841        ));
842    }
843
844    #[test]
845    fn uncompressed_mode_does_not_require_dictionary() {
846        let dict_id = 0xABCD_0001;
847        let dict =
848            crate::decoding::Dictionary::from_raw_content(dict_id, b"shared-history".to_vec())
849                .expect("raw dictionary should be valid");
850
851        let payload = b"plain-bytes-that-should-stay-raw";
852        let mut output = Vec::new();
853        let mut compressor = FrameCompressor::new(super::CompressionLevel::Uncompressed);
854        compressor
855            .set_dictionary(dict)
856            .expect("dictionary should attach in uncompressed mode");
857        compressor.set_source(payload.as_slice());
858        compressor.set_drain(&mut output);
859        compressor.compress();
860
861        let (frame_header, _) = crate::decoding::frame::read_frame_header(output.as_slice())
862            .expect("encoded frame should have a header");
863        assert_eq!(
864            frame_header.dictionary_id(),
865            None,
866            "raw/uncompressed frames must not advertise dictionary dependency"
867        );
868
869        let mut decoder = FrameDecoder::new();
870        let mut decoded = Vec::with_capacity(payload.len());
871        decoder.decode_all_to_vec(&output, &mut decoded).unwrap();
872        assert_eq!(decoded, payload);
873    }
874
875    #[test]
876    fn dictionary_roundtrip_stays_valid_after_output_exceeds_window() {
877        use crate::encoding::match_generator::MatchGeneratorDriver;
878
879        let dict_id = 0xABCD_0002;
880        let dict = crate::decoding::Dictionary::from_raw_content(dict_id, b"abcdefgh".to_vec())
881            .expect("raw dictionary should be valid");
882        let dict_for_decoder =
883            crate::decoding::Dictionary::from_raw_content(dict_id, b"abcdefgh".to_vec())
884                .expect("raw dictionary should be valid");
885
886        let payload = b"abcdefgh".repeat(512);
887        let matcher = MatchGeneratorDriver::new(8, 1);
888
889        let mut no_dict_output = Vec::new();
890        let mut no_dict_compressor =
891            FrameCompressor::new_with_matcher(matcher, super::CompressionLevel::Fastest);
892        no_dict_compressor.set_source(payload.as_slice());
893        no_dict_compressor.set_drain(&mut no_dict_output);
894        no_dict_compressor.compress();
895        let (no_dict_frame_header, _) =
896            crate::decoding::frame::read_frame_header(no_dict_output.as_slice())
897                .expect("baseline frame should have a header");
898        let no_dict_window = no_dict_frame_header
899            .window_size()
900            .expect("window size should be present");
901
902        let mut output = Vec::new();
903        let matcher = MatchGeneratorDriver::new(8, 1);
904        let mut compressor =
905            FrameCompressor::new_with_matcher(matcher, super::CompressionLevel::Fastest);
906        compressor
907            .set_dictionary(dict)
908            .expect("dictionary should attach");
909        compressor.set_source(payload.as_slice());
910        compressor.set_drain(&mut output);
911        compressor.compress();
912
913        let (frame_header, _) = crate::decoding::frame::read_frame_header(output.as_slice())
914            .expect("encoded frame should have a header");
915        let advertised_window = frame_header
916            .window_size()
917            .expect("window size should be present");
918        assert_eq!(
919            advertised_window, no_dict_window,
920            "dictionary priming must not inflate advertised window size"
921        );
922        assert!(
923            payload.len() > advertised_window as usize,
924            "test must cross the advertised window boundary"
925        );
926
927        let mut decoder = FrameDecoder::new();
928        decoder.add_dict(dict_for_decoder).unwrap();
929        let mut decoded = Vec::with_capacity(payload.len());
930        decoder.decode_all_to_vec(&output, &mut decoded).unwrap();
931        assert_eq!(decoded, payload);
932    }
933
934    #[test]
935    fn custom_matcher_without_dictionary_priming_does_not_advertise_dict_id() {
936        let dict_id = 0xABCD_0003;
937        let dict = crate::decoding::Dictionary::from_raw_content(dict_id, b"abcdefgh".to_vec())
938            .expect("raw dictionary should be valid");
939        let payload = b"abcdefghabcdefgh";
940
941        let mut output = Vec::new();
942        let matcher = NoDictionaryMatcher::new(64);
943        let mut compressor =
944            FrameCompressor::new_with_matcher(matcher, super::CompressionLevel::Fastest);
945        compressor
946            .set_dictionary(dict)
947            .expect("dictionary should attach");
948        compressor.set_source(payload.as_slice());
949        compressor.set_drain(&mut output);
950        compressor.compress();
951
952        let (frame_header, _) = crate::decoding::frame::read_frame_header(output.as_slice())
953            .expect("encoded frame should have a header");
954        assert_eq!(
955            frame_header.dictionary_id(),
956            None,
957            "matchers that do not support dictionary priming must not advertise dictionary dependency"
958        );
959
960        let mut decoder = FrameDecoder::new();
961        let mut decoded = Vec::with_capacity(payload.len());
962        decoder.decode_all_to_vec(&output, &mut decoded).unwrap();
963        assert_eq!(decoded, payload);
964    }
965
966    #[cfg(feature = "hash")]
967    #[test]
968    fn checksum_two_frames_reused_compressor() {
969        // Compress the same data twice using the same compressor and verify that:
970        // 1. The checksum written in each frame matches what the decoder calculates.
971        // 2. The hasher is correctly reset between frames (no cross-contamination).
972        //    If the hasher were NOT reset, the second frame's calculated checksum
973        //    would differ from the one stored in the frame data, causing assert_eq to fail.
974        let data: Vec<u8> = (0u8..=255).cycle().take(1024).collect();
975
976        let mut compressor = FrameCompressor::new(super::CompressionLevel::Uncompressed);
977
978        // --- Frame 1 ---
979        let mut compressed1 = Vec::new();
980        compressor.set_source(data.as_slice());
981        compressor.set_drain(&mut compressed1);
982        compressor.compress();
983
984        // --- Frame 2 (reuse the same compressor) ---
985        let mut compressed2 = Vec::new();
986        compressor.set_source(data.as_slice());
987        compressor.set_drain(&mut compressed2);
988        compressor.compress();
989
990        fn decode_and_collect(compressed: &[u8]) -> (Vec<u8>, Option<u32>, Option<u32>) {
991            let mut decoder = FrameDecoder::new();
992            let mut source = compressed;
993            decoder.reset(&mut source).unwrap();
994            while !decoder.is_finished() {
995                decoder
996                    .decode_blocks(&mut source, crate::decoding::BlockDecodingStrategy::All)
997                    .unwrap();
998            }
999            let mut decoded = Vec::new();
1000            decoder.collect_to_writer(&mut decoded).unwrap();
1001            (
1002                decoded,
1003                decoder.get_checksum_from_data(),
1004                decoder.get_calculated_checksum(),
1005            )
1006        }
1007
1008        let (decoded1, chksum_from_data1, chksum_calculated1) = decode_and_collect(&compressed1);
1009        assert_eq!(decoded1, data, "frame 1: decoded data mismatch");
1010        assert_eq!(
1011            chksum_from_data1, chksum_calculated1,
1012            "frame 1: checksum mismatch"
1013        );
1014
1015        let (decoded2, chksum_from_data2, chksum_calculated2) = decode_and_collect(&compressed2);
1016        assert_eq!(decoded2, data, "frame 2: decoded data mismatch");
1017        assert_eq!(
1018            chksum_from_data2, chksum_calculated2,
1019            "frame 2: checksum mismatch"
1020        );
1021
1022        // Same data compressed twice must produce the same checksum.
1023        // If state leaked across frames, the second calculated checksum would differ.
1024        assert_eq!(
1025            chksum_from_data1, chksum_from_data2,
1026            "frame 1 and frame 2 should have the same checksum (same data, hash must reset per frame)"
1027        );
1028    }
1029
    #[cfg(feature = "std")]
    #[test]
    fn fuzz_targets() {
        // Interop regression test: replays locally saved fuzz artifacts
        // (fuzz/artifacts/interop) through both this crate and the reference
        // `zstd` crate, in both directions. Does nothing when the artifact
        // directory is absent, so it is a no-op on checkouts without fuzz data.
        use std::io::Read;
        // Decode with this crate's streaming (std::io::Read-based) decoder.
        fn decode_szstd(data: &mut dyn std::io::Read) -> Vec<u8> {
            let mut decoder = crate::decoding::StreamingDecoder::new(data).unwrap();
            let mut result: Vec<u8> = Vec::new();
            decoder.read_to_end(&mut result).expect("Decoding failed");
            result
        }

        // Decode with this crate's block-oriented FrameDecoder, draining the
        // decode buffer incrementally through collect_to_writer.
        fn decode_szstd_writer(mut data: impl Read) -> Vec<u8> {
            let mut decoder = crate::decoding::FrameDecoder::new();
            decoder.reset(&mut data).unwrap();
            let mut result = vec![];
            // Loop while blocks remain OR decoded bytes are still buffered.
            while !decoder.is_finished() || decoder.can_collect() > 0 {
                decoder
                    .decode_blocks(
                        &mut data,
                        crate::decoding::BlockDecodingStrategy::UptoBytes(1024 * 1024),
                    )
                    .unwrap();
                decoder.collect_to_writer(&mut result).unwrap();
            }
            result
        }

        // Encode with the reference `zstd` crate at compression level 3.
        fn encode_zstd(data: &[u8]) -> Result<Vec<u8>, std::io::Error> {
            zstd::stream::encode_all(std::io::Cursor::new(data), 3)
        }

        // Encode with this crate in raw/uncompressed mode.
        fn encode_szstd_uncompressed(data: &mut dyn std::io::Read) -> Vec<u8> {
            let mut input = Vec::new();
            data.read_to_end(&mut input).unwrap();

            crate::encoding::compress_to_vec(
                input.as_slice(),
                crate::encoding::CompressionLevel::Uncompressed,
            )
        }

        // Encode with this crate at the Fastest (actually compressing) level.
        fn encode_szstd_compressed(data: &mut dyn std::io::Read) -> Vec<u8> {
            let mut input = Vec::new();
            data.read_to_end(&mut input).unwrap();

            crate::encoding::compress_to_vec(
                input.as_slice(),
                crate::encoding::CompressionLevel::Fastest,
            )
        }

        // Decode with the reference `zstd` crate.
        fn decode_zstd(data: &[u8]) -> Result<Vec<u8>, std::io::Error> {
            let mut output = Vec::new();
            zstd::stream::copy_decode(data, &mut output)?;
            Ok(output)
        }
        if std::fs::exists("fuzz/artifacts/interop").unwrap_or(false) {
            for file in std::fs::read_dir("fuzz/artifacts/interop").unwrap() {
                if file.as_ref().unwrap().file_type().unwrap().is_file() {
                    let data = std::fs::read(file.unwrap().path()).unwrap();
                    let data = data.as_slice();
                    // Decoding
                    // Reference-encoded bytes must decode via both of our decoders.
                    let compressed = encode_zstd(data).unwrap();
                    let decoded = decode_szstd(&mut compressed.as_slice());
                    let decoded2 = decode_szstd_writer(&mut compressed.as_slice());
                    assert!(
                        decoded == data,
                        "Decoded data did not match the original input during decompression"
                    );
                    assert_eq!(
                        decoded2, data,
                        "Decoded data did not match the original input during decompression"
                    );

                    // Encoding
                    // Uncompressed encoding
                    let mut input = data;
                    let compressed = encode_szstd_uncompressed(&mut input);
                    let decoded = decode_zstd(&compressed).unwrap();
                    assert_eq!(
                        decoded, data,
                        "Decoded data did not match the original input during compression"
                    );
                    // Compressed encoding
                    let mut input = data;
                    let compressed = encode_szstd_compressed(&mut input);
                    let decoded = decode_zstd(&compressed).unwrap();
                    assert_eq!(
                        decoded, data,
                        "Decoded data did not match the original input during compression"
                    );
                }
            }
        }
    }
1125}