// structured_zstd/encoding/frame_compressor.rs

//! Utilities and interfaces for encoding an entire frame. Allows reusing resources between frames.
2
3use alloc::{boxed::Box, vec::Vec};
4use core::convert::TryInto;
5#[cfg(feature = "hash")]
6use twox_hash::XxHash64;
7
8#[cfg(feature = "hash")]
9use core::hash::Hasher;
10
11use super::{
12    CompressionLevel, Matcher, block_header::BlockHeader, frame_header::FrameHeader, levels::*,
13    match_generator::MatchGeneratorDriver,
14};
15use crate::fse::fse_encoder::{FSETable, default_ll_table, default_ml_table, default_of_table};
16
17use crate::io::{Read, Write};
18
/// An interface for compressing arbitrary data with the ZStandard compression algorithm.
///
/// `FrameCompressor` will generally be used by:
/// 1. Initializing a compressor by providing a buffer of data using `FrameCompressor::new()`
/// 2. Starting compression and writing that compression into a vec using `FrameCompressor::begin`
///
/// # Examples
/// ```
/// use structured_zstd::encoding::{FrameCompressor, CompressionLevel};
/// let mock_data: &[_] = &[0x1, 0x2, 0x3, 0x4];
/// let mut output = std::vec::Vec::new();
/// // Initialize a compressor.
/// let mut compressor = FrameCompressor::new(CompressionLevel::Uncompressed);
/// compressor.set_source(mock_data);
/// compressor.set_drain(&mut output);
///
/// // `compress` writes the compressed output into the provided buffer.
/// compressor.compress();
/// ```
pub struct FrameCompressor<R: Read, W: Write, M: Matcher> {
    /// The source the uncompressed data is read from; `None` until `set_source` is called.
    uncompressed_data: Option<R>,
    /// The drain the encoded frame is written to; `None` until `set_drain` is called.
    compressed_data: Option<W>,
    /// Selects raw vs. compressed block encoding and is passed to `Matcher::reset`.
    compression_level: CompressionLevel,
    /// Optional dictionary used to prime the matcher and to seed entropy tables.
    dictionary: Option<crate::decoding::Dictionary>,
    /// Encoder-side entropy tables converted from `dictionary`, cached so they
    /// don't have to be rebuilt at the start of every frame.
    dictionary_entropy_cache: Option<CachedDictionaryEntropy>,
    /// Mutable per-frame encoding state (matcher, entropy tables, offset history).
    state: CompressState<M>,
    /// Running XXH64 of the uncompressed input; its low 32 bits are written as
    /// the frame's content checksum at the end of `compress`.
    #[cfg(feature = "hash")]
    hasher: XxHash64,
}
48
/// Encoder-side entropy tables converted from an attached dictionary.
///
/// Built once in `set_dictionary` and copied into the compressor state at the
/// start of every `compress` call that uses the dictionary.
#[derive(Clone, Default)]
struct CachedDictionaryEntropy {
    // Huffman table for literals, if the dictionary defined one.
    huff: Option<crate::huff0::huff0_encoder::HuffmanTable>,
    // "Previous" FSE table for literal lengths, seeded from the dictionary.
    ll_previous: Option<PreviousFseTable>,
    // "Previous" FSE table for match lengths, seeded from the dictionary.
    ml_previous: Option<PreviousFseTable>,
    // "Previous" FSE table for offsets, seeded from the dictionary.
    of_previous: Option<PreviousFseTable>,
}
56
/// A previously-used FSE table that a later block may repeat.
#[derive(Clone)]
pub(crate) enum PreviousFseTable {
    // Default tables are immutable and already stored alongside the state, so
    // repeating them only needs a lightweight marker instead of cloning FSETable.
    Default,
    /// A concrete table (e.g. converted from a dictionary), stored boxed.
    Custom(Box<FSETable>),
}
64
65impl PreviousFseTable {
66    pub(crate) fn as_table<'a>(&'a self, default: &'a FSETable) -> &'a FSETable {
67        match self {
68            Self::Default => default,
69            Self::Custom(table) => table,
70        }
71    }
72}
73
/// The FSE tables for the three sequence streams: the immutable predefined
/// tables plus the most recently used table of each kind, if any.
pub(crate) struct FseTables {
    // Predefined literal-lengths table.
    pub(crate) ll_default: FSETable,
    // Previously-used literal-lengths table, enabling repeat-table encoding.
    pub(crate) ll_previous: Option<PreviousFseTable>,
    // Predefined match-lengths table.
    pub(crate) ml_default: FSETable,
    // Previously-used match-lengths table.
    pub(crate) ml_previous: Option<PreviousFseTable>,
    // Predefined offsets table.
    pub(crate) of_default: FSETable,
    // Previously-used offsets table.
    pub(crate) of_previous: Option<PreviousFseTable>,
}
82
83impl FseTables {
84    pub fn new() -> Self {
85        Self {
86            ll_default: default_ll_table(),
87            ll_previous: None,
88            ml_default: default_ml_table(),
89            ml_previous: None,
90            of_default: default_of_table(),
91            of_previous: None,
92        }
93    }
94}
95
/// Mutable state carried across blocks while compressing a frame, and reused
/// across frames when the compressor is reused.
pub(crate) struct CompressState<M: Matcher> {
    // The match-finding backend.
    pub(crate) matcher: M,
    // Huffman table used most recently, if any (enables repeat-table mode).
    pub(crate) last_huff_table: Option<crate::huff0::huff0_encoder::HuffmanTable>,
    // Default and previously-used FSE tables for sequence encoding.
    pub(crate) fse_tables: FseTables,
    /// Offset history for repeat offset encoding: [rep0, rep1, rep2].
    /// Initialized to [1, 4, 8] per RFC 8878 §3.1.2.5.
    pub(crate) offset_hist: [u32; 3],
}
104
105impl<R: Read, W: Write> FrameCompressor<R, W, MatchGeneratorDriver> {
106    /// Create a new `FrameCompressor`
107    pub fn new(compression_level: CompressionLevel) -> Self {
108        Self {
109            uncompressed_data: None,
110            compressed_data: None,
111            compression_level,
112            dictionary: None,
113            dictionary_entropy_cache: None,
114            state: CompressState {
115                matcher: MatchGeneratorDriver::new(1024 * 128, 1),
116                last_huff_table: None,
117                fse_tables: FseTables::new(),
118                offset_hist: [1, 4, 8],
119            },
120            #[cfg(feature = "hash")]
121            hasher: XxHash64::with_seed(0),
122        }
123    }
124}
125
126impl<R: Read, W: Write, M: Matcher> FrameCompressor<R, W, M> {
127    /// Create a new `FrameCompressor` with a custom matching algorithm implementation
128    pub fn new_with_matcher(matcher: M, compression_level: CompressionLevel) -> Self {
129        Self {
130            uncompressed_data: None,
131            compressed_data: None,
132            dictionary: None,
133            dictionary_entropy_cache: None,
134            state: CompressState {
135                matcher,
136                last_huff_table: None,
137                fse_tables: FseTables::new(),
138                offset_hist: [1, 4, 8],
139            },
140            compression_level,
141            #[cfg(feature = "hash")]
142            hasher: XxHash64::with_seed(0),
143        }
144    }
145
146    /// Before calling [FrameCompressor::compress] you need to set the source.
147    ///
148    /// This is the data that is compressed and written into the drain.
149    pub fn set_source(&mut self, uncompressed_data: R) -> Option<R> {
150        self.uncompressed_data.replace(uncompressed_data)
151    }
152
    /// Before calling [FrameCompressor::compress] you need to set the drain.
    ///
    /// As the compressor compresses data, the drain serves as a place for the output to be written.
    /// Returns the previously set drain, if there was one.
    pub fn set_drain(&mut self, compressed_data: W) -> Option<W> {
        self.compressed_data.replace(compressed_data)
    }
159
    /// Compress the uncompressed data from the provided source as one Zstd frame and write it to the provided drain
    ///
    /// This will repeatedly call [Read::read] on the source to fill up blocks until the source returns 0 on the read call.
    /// Also [Write::write_all] will be called on the drain after each block has been encoded.
    ///
    /// To avoid endlessly encoding from a potentially endless source (like a network socket) you can use the
    /// [Read::take] function
    ///
    /// # Panics
    /// Panics if no source or drain has been set, or if a read/write on them fails.
    pub fn compress(&mut self) {
        // Clearing buffers to allow re-using of the compressor
        self.state.matcher.reset(self.compression_level);
        self.state.offset_hist = [1, 4, 8];
        // Dictionary state only applies when we actually compress AND the
        // matcher backend can be primed with dictionary content.
        let use_dictionary_state =
            !matches!(self.compression_level, CompressionLevel::Uncompressed)
                && self.state.matcher.supports_dictionary_priming();
        let cached_entropy = if use_dictionary_state {
            self.dictionary_entropy_cache.as_ref()
        } else {
            None
        };
        if use_dictionary_state && let Some(dict) = self.dictionary.as_ref() {
            // This state drives sequence encoding, while matcher priming below updates
            // the match generator's internal repeat-offset history for match finding.
            self.state.offset_hist = dict.offset_hist;
            self.state
                .matcher
                .prime_with_dictionary(dict.dict_content.as_slice(), dict.offset_hist);
        }
        // Seed (or clear) the "previous" Huffman table used for repeat-table encoding.
        if let Some(cache) = cached_entropy {
            self.state.last_huff_table.clone_from(&cache.huff);
        } else {
            self.state.last_huff_table = None;
        }
        // `clone_from` keeps frame-to-frame seeding cheap for reused compressors by
        // reusing existing allocations where possible instead of reallocating every frame.
        if let Some(cache) = cached_entropy {
            self.state
                .fse_tables
                .ll_previous
                .clone_from(&cache.ll_previous);
            self.state
                .fse_tables
                .ml_previous
                .clone_from(&cache.ml_previous);
            self.state
                .fse_tables
                .of_previous
                .clone_from(&cache.of_previous);
        } else {
            self.state.fse_tables.ll_previous = None;
            self.state.fse_tables.ml_previous = None;
            self.state.fse_tables.of_previous = None;
        }
        // Restart the content checksum for this frame.
        #[cfg(feature = "hash")]
        {
            self.hasher = XxHash64::with_seed(0);
        }
        let source = self.uncompressed_data.as_mut().unwrap();
        let drain = self.compressed_data.as_mut().unwrap();
        // As the frame is compressed, it's stored here
        // (capacity sized to hold roughly one block plus headers).
        let output: &mut Vec<u8> = &mut Vec::with_capacity(1024 * 130);
        // First write the frame header
        let header = FrameHeader {
            frame_content_size: None,
            single_segment: false,
            content_checksum: cfg!(feature = "hash"),
            // The dictionary id is only advertised when dictionary state is in use;
            // raw frames and non-priming matchers must stay decodable without it.
            dictionary_id: if use_dictionary_state {
                self.dictionary.as_ref().map(|dict| dict.id as u64)
            } else {
                None
            },
            window_size: Some(self.state.matcher.window_size()),
        };
        header.serialize(output);
        // Now compress block by block
        loop {
            // Read a single block's worth of uncompressed data from the input
            let mut uncompressed_data = self.state.matcher.get_next_space();
            let mut read_bytes = 0;
            let last_block;
            // Keep reading until the block buffer is full (more data may follow)
            // or the source reports EOF with a zero-length read (last block).
            'read_loop: loop {
                let new_bytes = source.read(&mut uncompressed_data[read_bytes..]).unwrap();
                if new_bytes == 0 {
                    last_block = true;
                    break 'read_loop;
                }
                read_bytes += new_bytes;
                if read_bytes == uncompressed_data.len() {
                    last_block = false;
                    break 'read_loop;
                }
            }
            uncompressed_data.resize(read_bytes, 0);
            // As we read, hash that data too
            #[cfg(feature = "hash")]
            self.hasher.write(&uncompressed_data);
            // Special handling is needed for compression of a totally empty file (why you'd want to do that, I don't know)
            // A zero-size Raw block marked `last_block` terminates the frame.
            if uncompressed_data.is_empty() {
                let header = BlockHeader {
                    last_block: true,
                    block_type: crate::blocks::block::BlockType::Raw,
                    block_size: 0,
                };
                // Write the header, then the block
                header.serialize(output);
                drain.write_all(output).unwrap();
                output.clear();
                break;
            }

            match self.compression_level {
                CompressionLevel::Uncompressed => {
                    let header = BlockHeader {
                        last_block,
                        block_type: crate::blocks::block::BlockType::Raw,
                        block_size: read_bytes.try_into().unwrap(),
                    };
                    // Write the header, then the block
                    header.serialize(output);
                    output.extend_from_slice(&uncompressed_data);
                }
                CompressionLevel::Fastest
                | CompressionLevel::Default
                | CompressionLevel::Better
                | CompressionLevel::Best => {
                    // All compressed levels share this block-encoding pipeline;
                    // they differ only in the matcher backend and its parameters.
                    compress_block_encoded(&mut self.state, last_block, uncompressed_data, output)
                }
            }
            drain.write_all(output).unwrap();
            output.clear();
            if last_block {
                break;
            }
        }

        // If the `hash` feature is enabled, then `content_checksum` is set to true in the header
        // and a 32 bit hash is written at the end of the data.
        #[cfg(feature = "hash")]
        {
            // Because we only have the data as a reader, we need to read all of it to calculate the checksum
            // Possible TODO: create a wrapper around self.uncompressed data that hashes the data as it's read?
            // Only the low 32 bits of the XXH64 digest are emitted.
            let content_checksum = self.hasher.finish();
            drain
                .write_all(&(content_checksum as u32).to_le_bytes())
                .unwrap();
        }
    }
308
    /// Get a mutable reference to the source, or `None` if no source is set.
    pub fn source_mut(&mut self) -> Option<&mut R> {
        self.uncompressed_data.as_mut()
    }
313
    /// Get a mutable reference to the drain, or `None` if no drain is set.
    pub fn drain_mut(&mut self) -> Option<&mut W> {
        self.compressed_data.as_mut()
    }
318
    /// Get a reference to the source, or `None` if no source is set.
    pub fn source(&self) -> Option<&R> {
        self.uncompressed_data.as_ref()
    }
323
    /// Get a reference to the drain, or `None` if no drain is set.
    pub fn drain(&self) -> Option<&W> {
        self.compressed_data.as_ref()
    }
328
329    /// Retrieve the source
330    pub fn take_source(&mut self) -> Option<R> {
331        self.uncompressed_data.take()
332    }
333
334    /// Retrieve the drain
335    pub fn take_drain(&mut self) -> Option<W> {
336        self.compressed_data.take()
337    }
338
339    /// Before calling [FrameCompressor::compress] you can replace the matcher
340    pub fn replace_matcher(&mut self, mut match_generator: M) -> M {
341        core::mem::swap(&mut match_generator, &mut self.state.matcher);
342        match_generator
343    }
344
345    /// Before calling [FrameCompressor::compress] you can replace the compression level
346    pub fn set_compression_level(
347        &mut self,
348        compression_level: CompressionLevel,
349    ) -> CompressionLevel {
350        let old = self.compression_level;
351        self.compression_level = compression_level;
352        old
353    }
354
    /// Get the compression level that will be used for the next `compress` call.
    pub fn compression_level(&self) -> CompressionLevel {
        self.compression_level
    }
359
360    /// Attach a pre-parsed dictionary to be used for subsequent compressions.
361    ///
362    /// In compressed modes, the dictionary id is written only when the active
363    /// matcher supports dictionary priming.
364    /// Uncompressed mode and non-priming matchers ignore the attached dictionary
365    /// at encode time.
366    pub fn set_dictionary(
367        &mut self,
368        dictionary: crate::decoding::Dictionary,
369    ) -> Result<Option<crate::decoding::Dictionary>, crate::decoding::errors::DictionaryDecodeError>
370    {
371        if dictionary.id == 0 {
372            return Err(crate::decoding::errors::DictionaryDecodeError::ZeroDictionaryId);
373        }
374        if let Some(index) = dictionary.offset_hist.iter().position(|&rep| rep == 0) {
375            return Err(
376                crate::decoding::errors::DictionaryDecodeError::ZeroRepeatOffsetInDictionary {
377                    index: index as u8,
378                },
379            );
380        }
381        self.dictionary_entropy_cache = Some(CachedDictionaryEntropy {
382            huff: dictionary.huf.table.to_encoder_table(),
383            ll_previous: dictionary
384                .fse
385                .literal_lengths
386                .to_encoder_table()
387                .map(|table| PreviousFseTable::Custom(Box::new(table))),
388            ml_previous: dictionary
389                .fse
390                .match_lengths
391                .to_encoder_table()
392                .map(|table| PreviousFseTable::Custom(Box::new(table))),
393            of_previous: dictionary
394                .fse
395                .offsets
396                .to_encoder_table()
397                .map(|table| PreviousFseTable::Custom(Box::new(table))),
398        });
399        Ok(self.dictionary.replace(dictionary))
400    }
401
402    /// Parse and attach a serialized dictionary blob.
403    pub fn set_dictionary_from_bytes(
404        &mut self,
405        raw_dictionary: &[u8],
406    ) -> Result<Option<crate::decoding::Dictionary>, crate::decoding::errors::DictionaryDecodeError>
407    {
408        let dictionary = crate::decoding::Dictionary::decode_dict(raw_dictionary)?;
409        self.set_dictionary(dictionary)
410    }
411
412    /// Remove the attached dictionary.
413    pub fn clear_dictionary(&mut self) -> Option<crate::decoding::Dictionary> {
414        self.dictionary_entropy_cache = None;
415        self.dictionary.take()
416    }
417}
418
419#[cfg(test)]
420mod tests {
421    #[cfg(all(feature = "dict_builder", feature = "std"))]
422    use alloc::format;
423    use alloc::vec;
424
425    use super::FrameCompressor;
426    use crate::common::MAGIC_NUM;
427    use crate::decoding::FrameDecoder;
428    use crate::encoding::{Matcher, Sequence};
429    use alloc::vec::Vec;
430
    // Test-only matcher that performs no match finding; used to exercise the
    // compressor with a backend that (per its name) lacks dictionary support.
    struct NoDictionaryMatcher {
        // The most recently committed block buffer.
        last_space: Vec<u8>,
        // Fixed window/block size handed out by `get_next_space`.
        window_size: u64,
    }
435
436    impl NoDictionaryMatcher {
437        fn new(window_size: u64) -> Self {
438            Self {
439                last_space: Vec::new(),
440                window_size,
441            }
442        }
443    }
444
    // A minimal matcher that never finds matches: every committed block is
    // reported as a single run of literals. It does not override the
    // dictionary-priming hooks, so it relies on the `Matcher` trait's defaults
    // (presumably "no dictionary support" — verify against the trait definition).
    impl Matcher for NoDictionaryMatcher {
        fn get_next_space(&mut self) -> Vec<u8> {
            // Fresh zeroed buffer of exactly one window's size per block.
            vec![0; self.window_size as usize]
        }

        fn get_last_space(&mut self) -> &[u8] {
            self.last_space.as_slice()
        }

        fn commit_space(&mut self, space: Vec<u8>) {
            self.last_space = space;
        }

        fn skip_matching(&mut self) {}

        fn start_matching(&mut self, mut handle_sequence: impl for<'a> FnMut(Sequence<'a>)) {
            // No match finding at all: emit the whole committed space as literals.
            handle_sequence(Sequence::Literals {
                literals: self.last_space.as_slice(),
            });
        }

        fn reset(&mut self, _level: super::CompressionLevel) {
            self.last_space.clear();
        }

        fn window_size(&self) -> u64 {
            self.window_size
        }
    }
474
475    #[test]
476    fn frame_starts_with_magic_num() {
477        let mock_data = [1_u8, 2, 3].as_slice();
478        let mut output: Vec<u8> = Vec::new();
479        let mut compressor = FrameCompressor::new(super::CompressionLevel::Uncompressed);
480        compressor.set_source(mock_data);
481        compressor.set_drain(&mut output);
482
483        compressor.compress();
484        assert!(output.starts_with(&MAGIC_NUM.to_le_bytes()));
485    }
486
487    #[test]
488    fn very_simple_raw_compress() {
489        let mock_data = [1_u8, 2, 3].as_slice();
490        let mut output: Vec<u8> = Vec::new();
491        let mut compressor = FrameCompressor::new(super::CompressionLevel::Uncompressed);
492        compressor.set_source(mock_data);
493        compressor.set_drain(&mut output);
494
495        compressor.compress();
496    }
497
    #[test]
    fn very_simple_compress() {
        // Large runs spanning multiple blocks; round-trip through both this
        // crate's decoder and the reference zstd implementation.
        let mut mock_data = vec![0; 1 << 17];
        mock_data.extend(vec![1; (1 << 17) - 1]);
        mock_data.extend(vec![2; (1 << 18) - 1]);
        mock_data.extend(vec![2; 1 << 17]);
        mock_data.extend(vec![3; (1 << 17) - 1]);
        let mut output: Vec<u8> = Vec::new();
        let mut compressor = FrameCompressor::new(super::CompressionLevel::Uncompressed);
        compressor.set_source(mock_data.as_slice());
        compressor.set_drain(&mut output);

        compressor.compress();

        // Our own decoder must reproduce the input ...
        let mut decoder = FrameDecoder::new();
        let mut decoded = Vec::with_capacity(mock_data.len());
        decoder.decode_all_to_vec(&output, &mut decoded).unwrap();
        assert_eq!(mock_data, decoded);

        // ... and so must the reference implementation.
        let mut decoded = Vec::new();
        zstd::stream::copy_decode(output.as_slice(), &mut decoded).unwrap();
        assert_eq!(mock_data, decoded);
    }
521
    #[test]
    fn rle_compress() {
        // NOTE(review): despite the name this encodes at the Uncompressed level,
        // so it round-trips 512 KiB of zeroes through raw blocks — confirm intent.
        let mock_data = vec![0; 1 << 19];
        let mut output: Vec<u8> = Vec::new();
        let mut compressor = FrameCompressor::new(super::CompressionLevel::Uncompressed);
        compressor.set_source(mock_data.as_slice());
        compressor.set_drain(&mut output);

        compressor.compress();

        let mut decoder = FrameDecoder::new();
        let mut decoded = Vec::with_capacity(mock_data.len());
        decoder.decode_all_to_vec(&output, &mut decoded).unwrap();
        assert_eq!(mock_data, decoded);
    }
537
    #[test]
    fn aaa_compress() {
        // Round-trip a tiny non-repetitive input through our decoder and the
        // reference implementation.
        let mock_data = vec![0, 1, 3, 4, 5];
        let mut output: Vec<u8> = Vec::new();
        let mut compressor = FrameCompressor::new(super::CompressionLevel::Uncompressed);
        compressor.set_source(mock_data.as_slice());
        compressor.set_drain(&mut output);

        compressor.compress();

        let mut decoder = FrameDecoder::new();
        let mut decoded = Vec::with_capacity(mock_data.len());
        decoder.decode_all_to_vec(&output, &mut decoded).unwrap();
        assert_eq!(mock_data, decoded);

        let mut decoded = Vec::new();
        zstd::stream::copy_decode(output.as_slice(), &mut decoded).unwrap();
        assert_eq!(mock_data, decoded);
    }
557
    #[test]
    fn dictionary_compression_sets_required_dict_id_and_roundtrips() {
        let dict_raw = include_bytes!("../../dict_tests/dictionary");
        let dict_for_encoder = crate::decoding::Dictionary::decode_dict(dict_raw).unwrap();
        let dict_for_decoder = crate::decoding::Dictionary::decode_dict(dict_raw).unwrap();

        // Payload built from dictionary content so the dictionary actually helps.
        let mut data = Vec::new();
        for _ in 0..8 {
            data.extend_from_slice(&dict_for_decoder.dict_content[..2048]);
        }

        let mut with_dict = Vec::new();
        let mut compressor = FrameCompressor::new(super::CompressionLevel::Fastest);
        let previous = compressor
            .set_dictionary_from_bytes(dict_raw)
            .expect("dictionary bytes should parse");
        assert!(
            previous.is_none(),
            "first dictionary insert should return None"
        );
        // Attaching a second dictionary must hand back the first one.
        assert_eq!(
            compressor
                .set_dictionary(dict_for_encoder)
                .expect("valid dictionary should attach")
                .expect("set_dictionary_from_bytes inserted previous dictionary")
                .id,
            dict_for_decoder.id
        );
        compressor.set_source(data.as_slice());
        compressor.set_drain(&mut with_dict);
        compressor.compress();

        // The frame header must advertise the dictionary id.
        let (frame_header, _) = crate::decoding::frame::read_frame_header(with_dict.as_slice())
            .expect("encoded stream should have a frame header");
        assert_eq!(frame_header.dictionary_id(), Some(dict_for_decoder.id));

        // Decoding without the dictionary must fail with DictNotProvided.
        let mut decoder = FrameDecoder::new();
        let mut missing_dict_target = Vec::with_capacity(data.len());
        let err = decoder
            .decode_all_to_vec(&with_dict, &mut missing_dict_target)
            .unwrap_err();
        assert!(
            matches!(
                &err,
                crate::decoding::errors::FrameDecoderError::DictNotProvided { .. }
            ),
            "dict-compressed stream should require dictionary id, got: {err:?}"
        );

        // With the dictionary registered, our decoder round-trips the payload.
        let mut decoder = FrameDecoder::new();
        decoder.add_dict(dict_for_decoder).unwrap();
        let mut decoded = Vec::with_capacity(data.len());
        decoder.decode_all_to_vec(&with_dict, &mut decoded).unwrap();
        assert_eq!(decoded, data);

        // The reference implementation must agree.
        let mut ffi_decoder = zstd::bulk::Decompressor::with_dictionary(dict_raw).unwrap();
        let mut ffi_decoded = Vec::with_capacity(data.len());
        let ffi_written = ffi_decoder
            .decompress_to_buffer(with_dict.as_slice(), &mut ffi_decoded)
            .unwrap();
        assert_eq!(ffi_written, data.len());
        assert_eq!(ffi_decoded, data);
    }
621
    #[cfg(all(feature = "dict_builder", feature = "std"))]
    #[test]
    fn dictionary_compression_roundtrips_with_dict_builder_dictionary() {
        use std::io::Cursor;

        // Train a dictionary on structured, repetitive records.
        let mut training = Vec::new();
        for idx in 0..256u32 {
            training.extend_from_slice(
                format!("tenant=demo table=orders key={idx} region=eu\n").as_bytes(),
            );
        }
        let mut raw_dict = Vec::new();
        crate::dictionary::create_raw_dict_from_source(
            Cursor::new(training.as_slice()),
            training.len(),
            &mut raw_dict,
            4096,
        );
        assert!(
            !raw_dict.is_empty(),
            "dict_builder produced an empty dictionary"
        );

        let dict_id = 0xD1C7_0008;
        let encoder_dict =
            crate::decoding::Dictionary::from_raw_content(dict_id, raw_dict.clone()).unwrap();
        let decoder_dict =
            crate::decoding::Dictionary::from_raw_content(dict_id, raw_dict.clone()).unwrap();

        // Payload shaped like the training data, but not identical to it.
        let mut payload = Vec::new();
        for idx in 0..96u32 {
            payload.extend_from_slice(
                format!(
                    "tenant=demo table=orders op=put key={idx} value=aaaaabbbbbcccccdddddeeeee\n"
                )
                .as_bytes(),
            );
        }

        // Baseline: same payload compressed without a dictionary.
        let mut without_dict = Vec::new();
        let mut baseline = FrameCompressor::new(super::CompressionLevel::Fastest);
        baseline.set_source(payload.as_slice());
        baseline.set_drain(&mut without_dict);
        baseline.compress();

        let mut with_dict = Vec::new();
        let mut compressor = FrameCompressor::new(super::CompressionLevel::Fastest);
        compressor
            .set_dictionary(encoder_dict)
            .expect("valid dict_builder dictionary should attach");
        compressor.set_source(payload.as_slice());
        compressor.set_drain(&mut with_dict);
        compressor.compress();

        let (frame_header, _) = crate::decoding::frame::read_frame_header(with_dict.as_slice())
            .expect("encoded stream should have a frame header");
        assert_eq!(frame_header.dictionary_id(), Some(dict_id));
        let mut decoder = FrameDecoder::new();
        decoder.add_dict(decoder_dict).unwrap();
        let mut decoded = Vec::with_capacity(payload.len());
        decoder.decode_all_to_vec(&with_dict, &mut decoded).unwrap();
        assert_eq!(decoded, payload);
        assert!(
            with_dict.len() < without_dict.len(),
            "trained dictionary should improve compression for this small payload"
        );
    }
689
    #[test]
    fn set_dictionary_from_bytes_seeds_entropy_tables_for_first_block() {
        let dict_raw = include_bytes!("../../dict_tests/dictionary");
        let mut output = Vec::new();
        // Even with empty input, `compress` should copy the cached dictionary
        // entropy tables into the compressor state.
        let input = b"";

        let mut compressor = FrameCompressor::new(super::CompressionLevel::Fastest);
        let previous = compressor
            .set_dictionary_from_bytes(dict_raw)
            .expect("dictionary bytes should parse");
        assert!(previous.is_none());

        compressor.set_source(input.as_slice());
        compressor.set_drain(&mut output);
        compressor.compress();

        assert!(
            compressor.state.last_huff_table.is_some(),
            "dictionary entropy should seed previous huffman table before first block"
        );
        assert!(
            compressor.state.fse_tables.ll_previous.is_some(),
            "dictionary entropy should seed previous ll table before first block"
        );
        assert!(
            compressor.state.fse_tables.ml_previous.is_some(),
            "dictionary entropy should seed previous ml table before first block"
        );
        assert!(
            compressor.state.fse_tables.of_previous.is_some(),
            "dictionary entropy should seed previous of table before first block"
        );
    }
723
    #[test]
    fn set_dictionary_rejects_zero_dictionary_id() {
        // Dictionary id 0 is reserved and must be rejected at attach time.
        let invalid = crate::decoding::Dictionary {
            id: 0,
            fse: crate::decoding::scratch::FSEScratch::new(),
            huf: crate::decoding::scratch::HuffmanScratch::new(),
            dict_content: vec![1, 2, 3],
            offset_hist: [1, 4, 8],
        };

        let mut compressor: FrameCompressor<
            &[u8],
            Vec<u8>,
            crate::encoding::match_generator::MatchGeneratorDriver,
        > = FrameCompressor::new(super::CompressionLevel::Fastest);
        let result = compressor.set_dictionary(invalid);
        assert!(matches!(
            result,
            Err(crate::decoding::errors::DictionaryDecodeError::ZeroDictionaryId)
        ));
    }
745
    #[test]
    fn set_dictionary_rejects_zero_repeat_offsets() {
        // A zero repeat offset is invalid; the error must name the offending slot.
        let invalid = crate::decoding::Dictionary {
            id: 1,
            fse: crate::decoding::scratch::FSEScratch::new(),
            huf: crate::decoding::scratch::HuffmanScratch::new(),
            dict_content: vec![1, 2, 3],
            offset_hist: [0, 4, 8],
        };

        let mut compressor: FrameCompressor<
            &[u8],
            Vec<u8>,
            crate::encoding::match_generator::MatchGeneratorDriver,
        > = FrameCompressor::new(super::CompressionLevel::Fastest);
        let result = compressor.set_dictionary(invalid);
        assert!(matches!(
            result,
            Err(
                crate::decoding::errors::DictionaryDecodeError::ZeroRepeatOffsetInDictionary {
                    index: 0
                }
            )
        ));
    }
771
    #[test]
    fn uncompressed_mode_does_not_require_dictionary() {
        let dict_id = 0xABCD_0001;
        let dict =
            crate::decoding::Dictionary::from_raw_content(dict_id, b"shared-history".to_vec())
                .expect("raw dictionary should be valid");

        let payload = b"plain-bytes-that-should-stay-raw";
        let mut output = Vec::new();
        // A dictionary may be attached, but Uncompressed mode must not use it.
        let mut compressor = FrameCompressor::new(super::CompressionLevel::Uncompressed);
        compressor
            .set_dictionary(dict)
            .expect("dictionary should attach in uncompressed mode");
        compressor.set_source(payload.as_slice());
        compressor.set_drain(&mut output);
        compressor.compress();

        let (frame_header, _) = crate::decoding::frame::read_frame_header(output.as_slice())
            .expect("encoded frame should have a header");
        assert_eq!(
            frame_header.dictionary_id(),
            None,
            "raw/uncompressed frames must not advertise dictionary dependency"
        );

        // And a dictionary-less decoder must be able to decode the frame.
        let mut decoder = FrameDecoder::new();
        let mut decoded = Vec::with_capacity(payload.len());
        decoder.decode_all_to_vec(&output, &mut decoded).unwrap();
        assert_eq!(decoded, payload);
    }
802
803    #[test]
804    fn dictionary_roundtrip_stays_valid_after_output_exceeds_window() {
805        use crate::encoding::match_generator::MatchGeneratorDriver;
806
807        let dict_id = 0xABCD_0002;
808        let dict = crate::decoding::Dictionary::from_raw_content(dict_id, b"abcdefgh".to_vec())
809            .expect("raw dictionary should be valid");
810        let dict_for_decoder =
811            crate::decoding::Dictionary::from_raw_content(dict_id, b"abcdefgh".to_vec())
812                .expect("raw dictionary should be valid");
813
814        let payload = b"abcdefgh".repeat(512);
815        let matcher = MatchGeneratorDriver::new(8, 1);
816
817        let mut no_dict_output = Vec::new();
818        let mut no_dict_compressor =
819            FrameCompressor::new_with_matcher(matcher, super::CompressionLevel::Fastest);
820        no_dict_compressor.set_source(payload.as_slice());
821        no_dict_compressor.set_drain(&mut no_dict_output);
822        no_dict_compressor.compress();
823        let (no_dict_frame_header, _) =
824            crate::decoding::frame::read_frame_header(no_dict_output.as_slice())
825                .expect("baseline frame should have a header");
826        let no_dict_window = no_dict_frame_header
827            .window_size()
828            .expect("window size should be present");
829
830        let mut output = Vec::new();
831        let matcher = MatchGeneratorDriver::new(8, 1);
832        let mut compressor =
833            FrameCompressor::new_with_matcher(matcher, super::CompressionLevel::Fastest);
834        compressor
835            .set_dictionary(dict)
836            .expect("dictionary should attach");
837        compressor.set_source(payload.as_slice());
838        compressor.set_drain(&mut output);
839        compressor.compress();
840
841        let (frame_header, _) = crate::decoding::frame::read_frame_header(output.as_slice())
842            .expect("encoded frame should have a header");
843        let advertised_window = frame_header
844            .window_size()
845            .expect("window size should be present");
846        assert_eq!(
847            advertised_window, no_dict_window,
848            "dictionary priming must not inflate advertised window size"
849        );
850        assert!(
851            payload.len() > advertised_window as usize,
852            "test must cross the advertised window boundary"
853        );
854
855        let mut decoder = FrameDecoder::new();
856        decoder.add_dict(dict_for_decoder).unwrap();
857        let mut decoded = Vec::with_capacity(payload.len());
858        decoder.decode_all_to_vec(&output, &mut decoded).unwrap();
859        assert_eq!(decoded, payload);
860    }
861
862    #[test]
863    fn custom_matcher_without_dictionary_priming_does_not_advertise_dict_id() {
864        let dict_id = 0xABCD_0003;
865        let dict = crate::decoding::Dictionary::from_raw_content(dict_id, b"abcdefgh".to_vec())
866            .expect("raw dictionary should be valid");
867        let payload = b"abcdefghabcdefgh";
868
869        let mut output = Vec::new();
870        let matcher = NoDictionaryMatcher::new(64);
871        let mut compressor =
872            FrameCompressor::new_with_matcher(matcher, super::CompressionLevel::Fastest);
873        compressor
874            .set_dictionary(dict)
875            .expect("dictionary should attach");
876        compressor.set_source(payload.as_slice());
877        compressor.set_drain(&mut output);
878        compressor.compress();
879
880        let (frame_header, _) = crate::decoding::frame::read_frame_header(output.as_slice())
881            .expect("encoded frame should have a header");
882        assert_eq!(
883            frame_header.dictionary_id(),
884            None,
885            "matchers that do not support dictionary priming must not advertise dictionary dependency"
886        );
887
888        let mut decoder = FrameDecoder::new();
889        let mut decoded = Vec::with_capacity(payload.len());
890        decoder.decode_all_to_vec(&output, &mut decoded).unwrap();
891        assert_eq!(decoded, payload);
892    }
893
894    #[cfg(feature = "hash")]
895    #[test]
896    fn checksum_two_frames_reused_compressor() {
897        // Compress the same data twice using the same compressor and verify that:
898        // 1. The checksum written in each frame matches what the decoder calculates.
899        // 2. The hasher is correctly reset between frames (no cross-contamination).
900        //    If the hasher were NOT reset, the second frame's calculated checksum
901        //    would differ from the one stored in the frame data, causing assert_eq to fail.
902        let data: Vec<u8> = (0u8..=255).cycle().take(1024).collect();
903
904        let mut compressor = FrameCompressor::new(super::CompressionLevel::Uncompressed);
905
906        // --- Frame 1 ---
907        let mut compressed1 = Vec::new();
908        compressor.set_source(data.as_slice());
909        compressor.set_drain(&mut compressed1);
910        compressor.compress();
911
912        // --- Frame 2 (reuse the same compressor) ---
913        let mut compressed2 = Vec::new();
914        compressor.set_source(data.as_slice());
915        compressor.set_drain(&mut compressed2);
916        compressor.compress();
917
918        fn decode_and_collect(compressed: &[u8]) -> (Vec<u8>, Option<u32>, Option<u32>) {
919            let mut decoder = FrameDecoder::new();
920            let mut source = compressed;
921            decoder.reset(&mut source).unwrap();
922            while !decoder.is_finished() {
923                decoder
924                    .decode_blocks(&mut source, crate::decoding::BlockDecodingStrategy::All)
925                    .unwrap();
926            }
927            let mut decoded = Vec::new();
928            decoder.collect_to_writer(&mut decoded).unwrap();
929            (
930                decoded,
931                decoder.get_checksum_from_data(),
932                decoder.get_calculated_checksum(),
933            )
934        }
935
936        let (decoded1, chksum_from_data1, chksum_calculated1) = decode_and_collect(&compressed1);
937        assert_eq!(decoded1, data, "frame 1: decoded data mismatch");
938        assert_eq!(
939            chksum_from_data1, chksum_calculated1,
940            "frame 1: checksum mismatch"
941        );
942
943        let (decoded2, chksum_from_data2, chksum_calculated2) = decode_and_collect(&compressed2);
944        assert_eq!(decoded2, data, "frame 2: decoded data mismatch");
945        assert_eq!(
946            chksum_from_data2, chksum_calculated2,
947            "frame 2: checksum mismatch"
948        );
949
950        // Same data compressed twice must produce the same checksum.
951        // If state leaked across frames, the second calculated checksum would differ.
952        assert_eq!(
953            chksum_from_data1, chksum_from_data2,
954            "frame 1 and frame 2 should have the same checksum (same data, hash must reset per frame)"
955        );
956    }
957
958    #[cfg(feature = "std")]
959    #[test]
960    fn fuzz_targets() {
961        use std::io::Read;
962        fn decode_szstd(data: &mut dyn std::io::Read) -> Vec<u8> {
963            let mut decoder = crate::decoding::StreamingDecoder::new(data).unwrap();
964            let mut result: Vec<u8> = Vec::new();
965            decoder.read_to_end(&mut result).expect("Decoding failed");
966            result
967        }
968
969        fn decode_szstd_writer(mut data: impl Read) -> Vec<u8> {
970            let mut decoder = crate::decoding::FrameDecoder::new();
971            decoder.reset(&mut data).unwrap();
972            let mut result = vec![];
973            while !decoder.is_finished() || decoder.can_collect() > 0 {
974                decoder
975                    .decode_blocks(
976                        &mut data,
977                        crate::decoding::BlockDecodingStrategy::UptoBytes(1024 * 1024),
978                    )
979                    .unwrap();
980                decoder.collect_to_writer(&mut result).unwrap();
981            }
982            result
983        }
984
985        fn encode_zstd(data: &[u8]) -> Result<Vec<u8>, std::io::Error> {
986            zstd::stream::encode_all(std::io::Cursor::new(data), 3)
987        }
988
989        fn encode_szstd_uncompressed(data: &mut dyn std::io::Read) -> Vec<u8> {
990            let mut input = Vec::new();
991            data.read_to_end(&mut input).unwrap();
992
993            crate::encoding::compress_to_vec(
994                input.as_slice(),
995                crate::encoding::CompressionLevel::Uncompressed,
996            )
997        }
998
999        fn encode_szstd_compressed(data: &mut dyn std::io::Read) -> Vec<u8> {
1000            let mut input = Vec::new();
1001            data.read_to_end(&mut input).unwrap();
1002
1003            crate::encoding::compress_to_vec(
1004                input.as_slice(),
1005                crate::encoding::CompressionLevel::Fastest,
1006            )
1007        }
1008
1009        fn decode_zstd(data: &[u8]) -> Result<Vec<u8>, std::io::Error> {
1010            let mut output = Vec::new();
1011            zstd::stream::copy_decode(data, &mut output)?;
1012            Ok(output)
1013        }
1014        if std::fs::exists("fuzz/artifacts/interop").unwrap_or(false) {
1015            for file in std::fs::read_dir("fuzz/artifacts/interop").unwrap() {
1016                if file.as_ref().unwrap().file_type().unwrap().is_file() {
1017                    let data = std::fs::read(file.unwrap().path()).unwrap();
1018                    let data = data.as_slice();
1019                    // Decoding
1020                    let compressed = encode_zstd(data).unwrap();
1021                    let decoded = decode_szstd(&mut compressed.as_slice());
1022                    let decoded2 = decode_szstd_writer(&mut compressed.as_slice());
1023                    assert!(
1024                        decoded == data,
1025                        "Decoded data did not match the original input during decompression"
1026                    );
1027                    assert_eq!(
1028                        decoded2, data,
1029                        "Decoded data did not match the original input during decompression"
1030                    );
1031
1032                    // Encoding
1033                    // Uncompressed encoding
1034                    let mut input = data;
1035                    let compressed = encode_szstd_uncompressed(&mut input);
1036                    let decoded = decode_zstd(&compressed).unwrap();
1037                    assert_eq!(
1038                        decoded, data,
1039                        "Decoded data did not match the original input during compression"
1040                    );
1041                    // Compressed encoding
1042                    let mut input = data;
1043                    let compressed = encode_szstd_compressed(&mut input);
1044                    let decoded = decode_zstd(&compressed).unwrap();
1045                    assert_eq!(
1046                        decoded, data,
1047                        "Decoded data did not match the original input during compression"
1048                    );
1049                }
1050            }
1051        }
1052    }
1053}