Skip to main content

structured_zstd/decoding/
frame_decoder.rs

//! [`FrameDecoder`] is the main low-level struct users interact with to decode zstd frames.
//!
//! Zstandard compressed data is made of one or more frames. Each frame is independent and can be
//! decompressed independently of other frames. This module contains structures
//! and utilities that can be used to decode a frame.
6
7use super::frame;
8use crate::decoding;
9use crate::decoding::block_decoder::BlockDecoder;
10use crate::decoding::decode_buffer::DecodeBuffer;
11use crate::decoding::dictionary::{Dictionary, DictionaryHandle};
12use crate::decoding::errors::{DecodeBlockContentError, FrameDecoderError};
13use crate::decoding::flat_buf::FlatBuf;
14use crate::decoding::ringbuffer::RingBuffer;
15use crate::decoding::scratch::DecoderScratch;
16use crate::io::{Error, Read, Write};
17use alloc::collections::BTreeMap;
18use alloc::vec::Vec;
19use core::convert::TryInto;
20
21use crate::common::MAXIMUM_ALLOWED_WINDOW_SIZE;
22
/// Low level Zstandard decoder that can be used to decompress frames with fine control over when and how many bytes are decoded.
///
/// This decoder is able to decode frames only partially and gives control
/// over how many bytes/blocks will be decoded at a time (so you don't have to decode a 10GB file into memory all at once).
/// It reads bytes as needed from a provided source and can be read from to collect partial results.
///
/// If you want to just read the whole frame with an `io::Read` without having to deal with manually calling [FrameDecoder::decode_blocks]
/// you can use the provided [crate::decoding::StreamingDecoder] which wraps this FrameDecoder.
///
/// Workflow is as follows:
/// ```
/// use structured_zstd::decoding::BlockDecodingStrategy;
///
/// # #[cfg(feature = "std")]
/// use std::io::{Read, Write};
///
/// // no_std environments can use the crate's own Read traits
/// # #[cfg(not(feature = "std"))]
/// use structured_zstd::io::{Read, Write};
///
/// fn decode_this(mut file: impl Read) {
///     //Create a new decoder
///     let mut frame_dec = structured_zstd::decoding::FrameDecoder::new();
///     // `read` can only fill as many bytes as the target slice holds,
///     // so the scratch buffer must have a non-zero length
///     let mut result = vec![0u8; 1024];
///
///     // Use reset or init to make the decoder ready to decode the frame from the io::Read
///     frame_dec.reset(&mut file).unwrap();
///
///     // Loop until the frame has been decoded completely
///     while !frame_dec.is_finished() {
///         // decode (roughly) batch_size many bytes
///         frame_dec.decode_blocks(&mut file, BlockDecodingStrategy::UptoBytes(1024)).unwrap();
///
///         // read from the decoder to collect bytes from the internal buffer
///         let bytes_read = frame_dec.read(result.as_mut_slice()).unwrap();
///
///         // then do something with it
///         do_something(&result[0..bytes_read]);
///     }
///
///     // handle the last chunk of data
///     while frame_dec.can_collect() > 0 {
///         let x = frame_dec.read(result.as_mut_slice()).unwrap();
///
///         do_something(&result[0..x]);
///     }
/// }
///
/// fn do_something(data: &[u8]) {
/// # #[cfg(feature = "std")]
///     std::io::stdout().write_all(data).unwrap();
/// }
/// ```
pub struct FrameDecoder {
    /// Per-frame decoding state; `None` until the first `init`/`reset` call.
    state: Option<FrameDecoderState>,
    /// Dictionaries owned by this decoder, keyed by dictionary ID.
    owned_dicts: BTreeMap<u32, Dictionary>,
    /// Shared dictionary handles, keyed by dictionary ID.
    /// Only a real map on targets with pointer-width atomics.
    #[cfg(target_has_atomic = "ptr")]
    shared_dicts: BTreeMap<u32, DictionaryHandle>,
    /// Placeholder on targets without pointer-width atomics, where
    /// shared dictionary handles cannot be registered.
    #[cfg(not(target_has_atomic = "ptr"))]
    shared_dicts: (),
}
84
/// Backend-tagged decode scratch — chosen at frame-reset time based
/// on the parsed `FrameHeader.descriptor.single_segment_flag()` and
/// kept stable through the lifetime of the frame. The match in each
/// helper below dispatches **once per call** (e.g. once per block in
/// `decode_block_content`, once per drain in `drain_to_writer`) —
/// never inside the hot push/repeat loop, which is fully
/// monomorphised through the `DecoderScratch<B>` generic.
enum DecoderScratchKind {
    /// Ring-buffer backed scratch; selected when the frame header's
    /// single-segment flag is not set.
    Ring(DecoderScratch<RingBuffer>),
    /// Flat-buffer backed scratch; selected when the frame header's
    /// single-segment flag is set.
    Flat(DecoderScratch<FlatBuf>),
}
96
97impl DecoderScratchKind {
98    fn new_ring(window_size: usize) -> Self {
99        let mut s = DecoderScratch::<RingBuffer>::new(window_size);
100        s.buffer.reserve(window_size);
101        Self::Ring(s)
102    }
103
104    /// Construct a flat-backed scratch sized for a single-segment
105    /// frame. `frame_content_size` is the upcoming output size in
106    /// bytes (== `window_size` when the flag is set).
107    fn new_flat(frame_content_size: usize) -> Self {
108        let flat = FlatBuf::with_capacity(frame_content_size);
109        // DecoderScratch's default ctor would discard the pre-sized
110        // FlatBuf — go through from_backend so the buffer carries the
111        // capacity the constructor wants.
112        let mut s = DecoderScratch::<FlatBuf>::new(frame_content_size);
113        s.buffer = DecodeBuffer::from_backend(flat, frame_content_size);
114        Self::Flat(s)
115    }
116
117    /// Reset (or transition between) backends for a new frame.
118    /// Reuses the existing `DecoderScratch` allocations (FSE / HUF
119    /// tables, sequence vec, etc.) when the backend kind is unchanged
120    /// — only the underlying buffer is re-sized for the new frame.
121    /// Building a fresh `DecoderScratch` on every frame would
122    /// re-allocate everything and was measured at +255 % vs ring on
123    /// small frames; reusing it keeps the small-frame cost flat.
124    fn reset(&mut self, frame: &frame::FrameHeader, window_size: usize) {
125        if frame.descriptor.single_segment_flag() {
126            match self {
127                Self::Flat(s) => {
128                    s.reset(window_size);
129                    // DecodeBuffer::reset clears + reserves
130                    // window_size; FlatBuf's reserve grows the
131                    // backing Vec if the new FCS is larger than
132                    // what's already allocated. No alloc when the
133                    // previous flat frame had >= this capacity.
134                }
135                Self::Ring(_) => *self = Self::new_flat(window_size),
136            }
137        } else {
138            match self {
139                Self::Ring(s) => s.reset(window_size),
140                Self::Flat(_) => *self = Self::new_ring(window_size),
141            }
142        }
143    }
144
145    fn init_from_dict(&mut self, dict: &Dictionary) {
146        match self {
147            Self::Ring(s) => s.init_from_dict(dict),
148            Self::Flat(s) => s.init_from_dict(dict),
149        }
150    }
151
152    #[inline]
153    fn buffer_len(&self) -> usize {
154        match self {
155            Self::Ring(s) => s.buffer.len(),
156            Self::Flat(s) => s.buffer.len(),
157        }
158    }
159
160    fn buffer_drain(&mut self) -> Vec<u8> {
161        match self {
162            Self::Ring(s) => s.buffer.drain(),
163            Self::Flat(s) => s.buffer.drain(),
164        }
165    }
166
167    fn buffer_drain_to_window_size(&mut self) -> Option<Vec<u8>> {
168        match self {
169            Self::Ring(s) => s.buffer.drain_to_window_size(),
170            Self::Flat(s) => s.buffer.drain_to_window_size(),
171        }
172    }
173
174    fn buffer_drain_to_writer(&mut self, sink: impl Write) -> Result<usize, Error> {
175        match self {
176            Self::Ring(s) => s.buffer.drain_to_writer(sink),
177            Self::Flat(s) => s.buffer.drain_to_writer(sink),
178        }
179    }
180
181    fn buffer_drain_to_window_size_writer(&mut self, sink: impl Write) -> Result<usize, Error> {
182        match self {
183            Self::Ring(s) => s.buffer.drain_to_window_size_writer(sink),
184            Self::Flat(s) => s.buffer.drain_to_window_size_writer(sink),
185        }
186    }
187
188    fn buffer_can_drain(&self) -> usize {
189        match self {
190            Self::Ring(s) => s.buffer.can_drain(),
191            Self::Flat(s) => s.buffer.can_drain(),
192        }
193    }
194
195    fn buffer_can_drain_to_window_size(&self) -> Option<usize> {
196        match self {
197            Self::Ring(s) => s.buffer.can_drain_to_window_size(),
198            Self::Flat(s) => s.buffer.can_drain_to_window_size(),
199        }
200    }
201
202    fn buffer_read(&mut self, target: &mut [u8]) -> Result<usize, Error> {
203        match self {
204            Self::Ring(s) => s.buffer.read(target),
205            Self::Flat(s) => s.buffer.read(target),
206        }
207    }
208
209    fn buffer_read_all(&mut self, target: &mut [u8]) -> Result<usize, Error> {
210        match self {
211            Self::Ring(s) => s.buffer.read_all(target),
212            Self::Flat(s) => s.buffer.read_all(target),
213        }
214    }
215
216    fn decode_block_content<R: Read>(
217        &mut self,
218        decoder: &mut BlockDecoder,
219        header: &crate::blocks::block::BlockHeader,
220        source: R,
221    ) -> Result<u64, DecodeBlockContentError> {
222        match self {
223            Self::Ring(s) => decoder.decode_block_content(header, s, source),
224            Self::Flat(s) => decoder.decode_block_content(header, s, source),
225        }
226    }
227
228    #[cfg(feature = "hash")]
229    fn hash_finish(&self) -> u64 {
230        use core::hash::Hasher;
231        match self {
232            Self::Ring(s) => s.buffer.hash.finish(),
233            Self::Flat(s) => s.buffer.hash.finish(),
234        }
235    }
236}
237
/// Everything the decoder tracks for the frame currently being decoded.
struct FrameDecoderState {
    /// Header parsed from the start of the current frame.
    pub frame_header: frame::FrameHeader,
    /// Decode buffer plus per-frame scratch, tagged by backend kind.
    decoder_scratch: DecoderScratchKind,
    /// Set once a block flagged as the last block has been decoded.
    frame_finished: bool,
    /// Number of blocks decoded so far in this frame.
    block_counter: usize,
    /// Bytes consumed from the source for this frame: starts at the header
    /// size and grows with every block header, block body and checksum read.
    bytes_read_counter: u64,
    /// Checksum read from the source after the last block, if one has been read.
    check_sum: Option<u32>,
    /// ID of the dictionary applied to this frame, if any.
    using_dict: Option<u32>,
}
247
/// Controls how much work a single [`FrameDecoder::decode_blocks`] call performs.
///
/// Regardless of the variant, decoding stops once the last block of the
/// frame has been decoded, and at least one block is decoded per call.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BlockDecodingStrategy {
    /// Decode every remaining block of the frame.
    All,
    /// Stop once at least this many blocks have been decoded by the call.
    UptoBlocks(usize),
    /// Stop once the call has added at least this many bytes to the decode
    /// buffer. The limit is only checked between blocks, so it can be
    /// overshot by up to one block.
    UptoBytes(usize),
}
253
254impl FrameDecoderState {
255    /// Read the frame header from `source` and create a new decoder state.
256    ///
257    /// Pre-allocates the decode buffer to `window_size` so the first block
258    /// does not trigger incremental growth from zero capacity.
259    pub fn new(source: impl Read) -> Result<FrameDecoderState, FrameDecoderError> {
260        let (frame, header_size) = frame::read_frame_header(source)?;
261        let window_size = frame.window_size()?;
262
263        if window_size > MAXIMUM_ALLOWED_WINDOW_SIZE {
264            return Err(FrameDecoderError::WindowSizeTooBig {
265                requested: window_size,
266            });
267        }
268
269        let decoder_scratch = if frame.descriptor.single_segment_flag() {
270            DecoderScratchKind::new_flat(window_size as usize)
271        } else {
272            DecoderScratchKind::new_ring(window_size as usize)
273        };
274        Ok(FrameDecoderState {
275            frame_header: frame,
276            frame_finished: false,
277            block_counter: 0,
278            decoder_scratch,
279            bytes_read_counter: u64::from(header_size),
280            check_sum: None,
281            using_dict: None,
282        })
283    }
284
285    /// Reset this state for a new frame read from `source`, reusing existing allocations.
286    ///
287    /// `DecodeBuffer::reset` reserves `window_size` internally, so no
288    /// additional frame-level reservation is needed here. Further buffer
289    /// growth during decoding is performed on demand by the active block path.
290    pub fn reset(&mut self, source: impl Read) -> Result<(), FrameDecoderError> {
291        let (frame_header, header_size) = frame::read_frame_header(source)?;
292        let window_size = frame_header.window_size()?;
293
294        if window_size > MAXIMUM_ALLOWED_WINDOW_SIZE {
295            return Err(FrameDecoderError::WindowSizeTooBig {
296                requested: window_size,
297            });
298        }
299
300        self.decoder_scratch
301            .reset(&frame_header, window_size as usize);
302        self.frame_header = frame_header;
303        self.frame_finished = false;
304        self.block_counter = 0;
305        self.bytes_read_counter = u64::from(header_size);
306        self.check_sum = None;
307        self.using_dict = None;
308        Ok(())
309    }
310}
311
312impl Default for FrameDecoder {
313    fn default() -> Self {
314        Self::new()
315    }
316}
317
318impl FrameDecoder {
319    /// This will create a new decoder without allocating anything yet.
320    /// init()/reset() will allocate all needed buffers if it is the first time this decoder is used
321    /// else they just reset these buffers with not further allocations
322    pub fn new() -> FrameDecoder {
323        FrameDecoder {
324            state: None,
325            owned_dicts: BTreeMap::new(),
326            #[cfg(target_has_atomic = "ptr")]
327            shared_dicts: BTreeMap::new(),
328            #[cfg(not(target_has_atomic = "ptr"))]
329            shared_dicts: (),
330        }
331    }
332
333    #[cfg(target_has_atomic = "ptr")]
334    fn shared_dict_exists(&self, dict_id: u32) -> bool {
335        self.shared_dicts.contains_key(&dict_id)
336    }
337
338    #[cfg(not(target_has_atomic = "ptr"))]
339    fn shared_dict_exists(&self, _dict_id: u32) -> bool {
340        false
341    }
342
343    fn validate_registered_dictionary(dict: &Dictionary) -> Result<(), FrameDecoderError> {
344        use crate::decoding::errors::DictionaryDecodeError as dict_err;
345
346        if dict.id == 0 {
347            return Err(FrameDecoderError::from(dict_err::ZeroDictionaryId));
348        }
349        if let Some(index) = dict.offset_hist.iter().position(|&rep| rep == 0) {
350            return Err(FrameDecoderError::from(
351                dict_err::ZeroRepeatOffsetInDictionary { index: index as u8 },
352            ));
353        }
354        Ok(())
355    }
356
    /// Prepare the decoder for a new frame by reading its header from `source`.
    ///
    /// Allocates all needed buffers the first time this decoder is used;
    /// later calls reset the existing buffers instead.
    ///
    /// Note that all bytes currently in the decode buffer from any previous frame will be lost.
    /// Collect them with collect()/collect_to_writer() first.
    ///
    /// Equivalent to [`FrameDecoder::reset`], to which this call is forwarded.
    pub fn init(&mut self, source: impl Read) -> Result<(), FrameDecoderError> {
        self.reset(source)
    }
366
    /// Initialize the decoder for a new frame using a pre-parsed dictionary handle.
    ///
    /// If the frame header has a dictionary ID, this validates it against
    /// `dict.id()` and returns [`FrameDecoderError::DictIdMismatch`] on mismatch.
    ///
    /// If the header omits the optional dictionary ID, this still applies the
    /// provided dictionary handle.
    ///
    /// Equivalent to [`FrameDecoder::reset_with_dict_handle`], to which this
    /// call is forwarded.
    ///
    /// # Warning
    ///
    /// This method always applies `dict` unless the frame header contains a
    /// non-matching dictionary ID. Callers must only use this API when they
    /// already know the frame was encoded with the provided dictionary, even if
    /// the frame header omits the dictionary ID or encodes an explicit
    /// dictionary ID of `0`.
    ///
    /// Passing a dictionary for a frame that was not encoded with it can
    /// silently corrupt the decoded output.
    pub fn init_with_dict_handle(
        &mut self,
        source: impl Read,
        dict: &DictionaryHandle,
    ) -> Result<(), FrameDecoderError> {
        self.reset_with_dict_handle(source, dict)
    }
392
393    /// reset() will allocate all needed buffers if it is the first time this decoder is used
394    /// else they just reset these buffers with not further allocations
395    ///
396    /// Note that all bytes currently in the decodebuffer from any previous frame will be lost. Collect them with collect()/collect_to_writer()
397    ///
398    /// equivalent to init()
399    pub fn reset(&mut self, source: impl Read) -> Result<(), FrameDecoderError> {
400        use FrameDecoderError as err;
401        let dict_id = match &mut self.state {
402            Some(s) => {
403                s.reset(source)?;
404                s.frame_header.dictionary_id()
405            }
406            None => {
407                self.state = Some(FrameDecoderState::new(source)?);
408                self.state
409                    .as_ref()
410                    .and_then(|state| state.frame_header.dictionary_id())
411            }
412        };
413        if let Some(dict_id) = dict_id {
414            let state = self.state.as_mut().expect("state initialized");
415            let owned_dicts = &self.owned_dicts;
416            #[cfg(target_has_atomic = "ptr")]
417            let shared_dicts = &self.shared_dicts;
418            let dict = owned_dicts
419                .get(&dict_id)
420                .or_else(|| {
421                    #[cfg(target_has_atomic = "ptr")]
422                    {
423                        shared_dicts.get(&dict_id).map(DictionaryHandle::as_dict)
424                    }
425                    #[cfg(not(target_has_atomic = "ptr"))]
426                    {
427                        None
428                    }
429                })
430                .ok_or(err::DictNotProvided { dict_id })?;
431            state.decoder_scratch.init_from_dict(dict);
432            state.using_dict = Some(dict_id);
433        }
434        Ok(())
435    }
436
437    /// Reset this decoder for a new frame using a pre-parsed dictionary handle.
438    ///
439    /// If the frame header has a dictionary ID, this validates it against
440    /// `dict.id()` and returns [`FrameDecoderError::DictIdMismatch`] on mismatch.
441    ///
442    /// If the header omits the optional dictionary ID, this still applies the
443    /// provided dictionary handle.
444    ///
445    /// # Warning
446    ///
447    /// This method always applies `dict` unless the frame header contains a
448    /// non-matching dictionary ID. Callers must only use this API when they
449    /// already know the frame was encoded with the provided dictionary, even if
450    /// the frame header omits the dictionary ID or encodes an explicit
451    /// dictionary ID of `0`.
452    ///
453    /// Passing a dictionary for a frame that was not encoded with it can
454    /// silently corrupt the decoded output.
455    pub fn reset_with_dict_handle(
456        &mut self,
457        source: impl Read,
458        dict: &DictionaryHandle,
459    ) -> Result<(), FrameDecoderError> {
460        use FrameDecoderError as err;
461        Self::validate_registered_dictionary(dict.as_dict())?;
462        let state = match &mut self.state {
463            Some(s) => {
464                s.reset(source)?;
465                s
466            }
467            None => {
468                self.state = Some(FrameDecoderState::new(source)?);
469                self.state.as_mut().unwrap()
470            }
471        };
472        if let Some(dict_id) = state.frame_header.dictionary_id()
473            && dict_id != dict.id()
474        {
475            return Err(err::DictIdMismatch {
476                expected: dict_id,
477                provided: dict.id(),
478            });
479        }
480        state.decoder_scratch.init_from_dict(dict.as_dict());
481        state.using_dict = Some(dict.id());
482        Ok(())
483    }
484
485    /// Add a dictionary that can be selected dynamically by frame dictionary ID.
486    ///
487    /// Returns [`FrameDecoderError::DictAlreadyRegistered`] if the ID is already
488    /// registered (either as owned or shared).
489    pub fn add_dict(&mut self, dict: Dictionary) -> Result<(), FrameDecoderError> {
490        Self::validate_registered_dictionary(&dict)?;
491        let dict_id = dict.id;
492        if self.owned_dicts.contains_key(&dict_id) || self.shared_dict_exists(dict_id) {
493            return Err(FrameDecoderError::DictAlreadyRegistered { dict_id });
494        }
495        self.owned_dicts.insert(dict_id, dict);
496        Ok(())
497    }
498
499    /// Parse and add a serialized dictionary blob.
500    pub fn add_dict_from_bytes(&mut self, raw_dictionary: &[u8]) -> Result<(), FrameDecoderError> {
501        let dict = Dictionary::decode_dict(raw_dictionary)?;
502        self.add_dict(dict)
503    }
504
505    /// Add a pre-parsed dictionary handle for reuse across decoders.
506    ///
507    /// This API is available on targets with pointer-width atomics
508    /// (`target_has_atomic = "ptr"`).
509    ///
510    /// Returns [`FrameDecoderError::DictAlreadyRegistered`] if the ID is already
511    /// registered (either as owned or shared).
512    #[cfg(target_has_atomic = "ptr")]
513    pub fn add_dict_handle(&mut self, dict: DictionaryHandle) -> Result<(), FrameDecoderError> {
514        Self::validate_registered_dictionary(dict.as_dict())?;
515        let dict_id = dict.id();
516        if self.owned_dicts.contains_key(&dict_id) || self.shared_dicts.contains_key(&dict_id) {
517            return Err(FrameDecoderError::DictAlreadyRegistered { dict_id });
518        }
519        self.shared_dicts.insert(dict_id, dict);
520        Ok(())
521    }
522
523    pub fn force_dict(&mut self, dict_id: u32) -> Result<(), FrameDecoderError> {
524        use FrameDecoderError as err;
525        let state = self.state.as_mut().ok_or(err::NotYetInitialized)?;
526        let owned_dicts = &self.owned_dicts;
527        #[cfg(target_has_atomic = "ptr")]
528        let shared_dicts = &self.shared_dicts;
529
530        let dict = owned_dicts
531            .get(&dict_id)
532            .or_else(|| {
533                #[cfg(target_has_atomic = "ptr")]
534                {
535                    shared_dicts.get(&dict_id).map(DictionaryHandle::as_dict)
536                }
537                #[cfg(not(target_has_atomic = "ptr"))]
538                {
539                    None
540                }
541            })
542            .ok_or(err::DictNotProvided { dict_id })?;
543        state.decoder_scratch.init_from_dict(dict);
544        state.using_dict = Some(dict_id);
545
546        Ok(())
547    }
548
549    /// Returns how many bytes the frame contains after decompression
550    pub fn content_size(&self) -> u64 {
551        match &self.state {
552            None => 0,
553            Some(s) => s.frame_header.frame_content_size(),
554        }
555    }
556
557    /// Returns the checksum that was read from the data. Only available after all bytes have been read. It is the last 4 bytes of a zstd-frame
558    pub fn get_checksum_from_data(&self) -> Option<u32> {
559        let state = self.state.as_ref()?;
560
561        state.check_sum
562    }
563
564    /// Returns the checksum that was calculated while decoding.
565    /// Only a sensible value after all decoded bytes have been collected/read from the FrameDecoder
566    #[cfg(feature = "hash")]
567    pub fn get_calculated_checksum(&self) -> Option<u32> {
568        let state = self.state.as_ref()?;
569        let cksum_64bit = state.decoder_scratch.hash_finish();
570        //truncate to lower 32bit because reasons...
571        Some(cksum_64bit as u32)
572    }
573
574    /// Counter for how many bytes have been consumed while decoding the frame
575    pub fn bytes_read_from_source(&self) -> u64 {
576        let state = match &self.state {
577            None => return 0,
578            Some(s) => s,
579        };
580        state.bytes_read_counter
581    }
582
583    /// Whether the current frames last block has been decoded yet
584    /// If this returns true you can call the drain* functions to get all content
585    /// (the read() function will drain automatically if this returns true)
586    pub fn is_finished(&self) -> bool {
587        let state = match &self.state {
588            None => return true,
589            Some(s) => s,
590        };
591        if state.frame_header.descriptor.content_checksum_flag() {
592            state.frame_finished && state.check_sum.is_some()
593        } else {
594            state.frame_finished
595        }
596    }
597
598    /// Counter for how many blocks have already been decoded
599    pub fn blocks_decoded(&self) -> usize {
600        let state = match &self.state {
601            None => return 0,
602            Some(s) => s,
603        };
604        state.block_counter
605    }
606
607    /// Decodes blocks from a reader. It requires that the framedecoder has been initialized first.
608    /// The Strategy influences how many blocks will be decoded before the function returns
609    /// This is important if you want to manage memory consumption carefully. If you don't care
610    /// about that you can just choose the strategy "All" and have all blocks of the frame decoded into the buffer
611    pub fn decode_blocks(
612        &mut self,
613        mut source: impl Read,
614        strat: BlockDecodingStrategy,
615    ) -> Result<bool, FrameDecoderError> {
616        use FrameDecoderError as err;
617        let state = self.state.as_mut().ok_or(err::NotYetInitialized)?;
618
619        let mut block_dec = decoding::block_decoder::new();
620
621        let buffer_size_before = state.decoder_scratch.buffer_len();
622        let block_counter_before = state.block_counter;
623        loop {
624            vprintln!("################");
625            vprintln!("Next Block: {}", state.block_counter);
626            vprintln!("################");
627            let (block_header, block_header_size) = block_dec
628                .read_block_header(&mut source)
629                .map_err(err::FailedToReadBlockHeader)?;
630            state.bytes_read_counter += u64::from(block_header_size);
631
632            vprintln!();
633            vprintln!(
634                "Found {} block with size: {}, which will be of size: {}",
635                block_header.block_type,
636                block_header.content_size,
637                block_header.decompressed_size
638            );
639
640            let bytes_read_in_block_body = state
641                .decoder_scratch
642                .decode_block_content(&mut block_dec, &block_header, &mut source)
643                .map_err(err::FailedToReadBlockBody)?;
644            state.bytes_read_counter += bytes_read_in_block_body;
645
646            state.block_counter += 1;
647
648            vprintln!("Output: {}", state.decoder_scratch.buffer_len());
649
650            if block_header.last_block {
651                state.frame_finished = true;
652                if state.frame_header.descriptor.content_checksum_flag() {
653                    let mut chksum = [0u8; 4];
654                    source
655                        .read_exact(&mut chksum)
656                        .map_err(err::FailedToReadChecksum)?;
657                    state.bytes_read_counter += 4;
658                    let chksum = u32::from_le_bytes(chksum);
659                    state.check_sum = Some(chksum);
660                }
661                break;
662            }
663
664            match strat {
665                BlockDecodingStrategy::All => { /* keep going */ }
666                BlockDecodingStrategy::UptoBlocks(n) => {
667                    if state.block_counter - block_counter_before >= n {
668                        break;
669                    }
670                }
671                BlockDecodingStrategy::UptoBytes(n) => {
672                    if state.decoder_scratch.buffer_len() - buffer_size_before >= n {
673                        break;
674                    }
675                }
676            }
677        }
678
679        Ok(state.frame_finished)
680    }
681
682    /// Collect bytes and retain window_size bytes while decoding is still going on.
683    /// After decoding of the frame (is_finished() == true) has finished it will collect all remaining bytes
684    pub fn collect(&mut self) -> Option<Vec<u8>> {
685        let finished = self.is_finished();
686        let state = self.state.as_mut()?;
687        if finished {
688            Some(state.decoder_scratch.buffer_drain())
689        } else {
690            state.decoder_scratch.buffer_drain_to_window_size()
691        }
692    }
693
694    /// Collect bytes and retain window_size bytes while decoding is still going on.
695    /// After decoding of the frame (is_finished() == true) has finished it will collect all remaining bytes
696    pub fn collect_to_writer(&mut self, w: impl Write) -> Result<usize, Error> {
697        let finished = self.is_finished();
698        let state = match &mut self.state {
699            None => return Ok(0),
700            Some(s) => s,
701        };
702        if finished {
703            state.decoder_scratch.buffer_drain_to_writer(w)
704        } else {
705            state.decoder_scratch.buffer_drain_to_window_size_writer(w)
706        }
707    }
708
709    /// How many bytes can currently be collected from the decodebuffer, while decoding is going on this will be lower than the actual decodbuffer size
710    /// because window_size bytes need to be retained for decoding.
711    /// After decoding of the frame (is_finished() == true) has finished it will report all remaining bytes
712    pub fn can_collect(&self) -> usize {
713        let finished = self.is_finished();
714        let state = match &self.state {
715            None => return 0,
716            Some(s) => s,
717        };
718        if finished {
719            state.decoder_scratch.buffer_can_drain()
720        } else {
721            state
722                .decoder_scratch
723                .buffer_can_drain_to_window_size()
724                .unwrap_or(0)
725        }
726    }
727
728    /// Decodes as many blocks as possible from the source slice and reads from the decodebuffer into the target slice
729    /// The source slice may contain only parts of a frame but must contain at least one full block to make progress
730    ///
731    /// By all means use decode_blocks if you have a io.Reader available. This is just for compatibility with other decompressors
732    /// which try to serve an old-style c api
733    ///
734    /// Returns (read, written), if read == 0 then the source did not contain a full block and further calls with the same
735    /// input will not make any progress!
736    ///
737    /// Note that no kind of block can be bigger than 128kb.
738    /// So to be safe use at least 128*1024 (max block content size) + 3 (block_header size) + 18 (max frame_header size) bytes as your source buffer
739    ///
740    /// You may call this function with an empty source after all bytes have been decoded. This is equivalent to just call decoder.read(&mut target)
741    pub fn decode_from_to(
742        &mut self,
743        source: &[u8],
744        target: &mut [u8],
745    ) -> Result<(usize, usize), FrameDecoderError> {
746        use FrameDecoderError as err;
747        let bytes_read_at_start = match &self.state {
748            Some(s) => s.bytes_read_counter,
749            None => 0,
750        };
751
752        if !self.is_finished() || self.state.is_none() {
753            let mut mt_source = source;
754
755            if self.state.is_none() {
756                self.init(&mut mt_source)?;
757            }
758
759            //pseudo block to scope "state" so we can borrow self again after the block
760            {
761                let state = match &mut self.state {
762                    Some(s) => s,
763                    None => panic!("Bug in library"),
764                };
765                let mut block_dec = decoding::block_decoder::new();
766
767                if state.frame_header.descriptor.content_checksum_flag()
768                    && state.frame_finished
769                    && state.check_sum.is_none()
770                {
771                    //this block is needed if the checksum were the only 4 bytes that were not included in the last decode_from_to call for a frame
772                    if mt_source.len() >= 4 {
773                        let chksum = mt_source[..4].try_into().expect("optimized away");
774                        state.bytes_read_counter += 4;
775                        let chksum = u32::from_le_bytes(chksum);
776                        state.check_sum = Some(chksum);
777                    }
778                    return Ok((4, 0));
779                }
780
781                loop {
782                    //check if there are enough bytes for the next header
783                    if mt_source.len() < 3 {
784                        break;
785                    }
786                    let (block_header, block_header_size) = block_dec
787                        .read_block_header(&mut mt_source)
788                        .map_err(err::FailedToReadBlockHeader)?;
789
790                    // check the needed size for the block before updating counters.
791                    // If not enough bytes are in the source, the header will have to be read again, so act like we never read it in the first place
792                    if mt_source.len() < block_header.content_size as usize {
793                        break;
794                    }
795                    state.bytes_read_counter += u64::from(block_header_size);
796
797                    let bytes_read_in_block_body = state
798                        .decoder_scratch
799                        .decode_block_content(&mut block_dec, &block_header, &mut mt_source)
800                        .map_err(err::FailedToReadBlockBody)?;
801                    state.bytes_read_counter += bytes_read_in_block_body;
802                    state.block_counter += 1;
803
804                    if block_header.last_block {
805                        state.frame_finished = true;
806                        if state.frame_header.descriptor.content_checksum_flag() {
807                            //if there are enough bytes handle this here. Else the block at the start of this function will handle it at the next call
808                            if mt_source.len() >= 4 {
809                                let chksum = mt_source[..4].try_into().expect("optimized away");
810                                state.bytes_read_counter += 4;
811                                let chksum = u32::from_le_bytes(chksum);
812                                state.check_sum = Some(chksum);
813                            }
814                        }
815                        break;
816                    }
817                }
818            }
819        }
820
821        let result_len = self.read(target).map_err(err::FailedToDrainDecodebuffer)?;
822        let bytes_read_at_end = match &mut self.state {
823            Some(s) => s.bytes_read_counter,
824            None => panic!("Bug in library"),
825        };
826        let read_len = bytes_read_at_end - bytes_read_at_start;
827        Ok((read_len as usize, result_len))
828    }
829
830    /// Decode multiple frames into the output slice.
831    ///
832    /// `input` must contain an exact number of frames. Skippable frames are allowed and will be
833    /// skipped during decode.
834    ///
835    /// `output` must be large enough to hold the decompressed data. If you don't know
836    /// how large the output will be, use [`FrameDecoder::decode_blocks`] instead.
837    ///
838    /// This calls [`FrameDecoder::init`], and all bytes currently in the decoder will be lost.
839    ///
840    /// Returns the number of bytes written to `output`.
841    pub fn decode_all(
842        &mut self,
843        input: &[u8],
844        output: &mut [u8],
845    ) -> Result<usize, FrameDecoderError> {
846        self.decode_all_impl(input, output, |this, src| this.init(src))
847    }
848
849    /// Decode multiple frames into the output slice using a pre-parsed dictionary handle.
850    ///
851    /// `input` must contain an exact number of frames. Skippable frames are allowed and will be
852    /// skipped during decode.
853    ///
854    /// `output` must be large enough to hold the decompressed data. If you don't know
855    /// how large the output will be, use [`FrameDecoder::decode_blocks`] instead.
856    ///
857    /// This calls [`FrameDecoder::init_with_dict_handle`], and all bytes currently in the
858    /// decoder will be lost.
859    ///
860    /// # Warning
861    ///
862    /// Each decoded frame is initialized with `dict`, even when a frame header
863    /// omits the optional dictionary ID. Callers must only use this API when
864    /// they already know the input frames were encoded with the provided
865    /// dictionary; otherwise decoded output can be silently corrupted.
866    pub fn decode_all_with_dict_handle(
867        &mut self,
868        input: &[u8],
869        output: &mut [u8],
870        dict: &DictionaryHandle,
871    ) -> Result<usize, FrameDecoderError> {
872        self.decode_all_impl(input, output, |this, src| {
873            this.init_with_dict_handle(src, dict)
874        })
875    }
876
877    fn decode_all_impl(
878        &mut self,
879        mut input: &[u8],
880        mut output: &mut [u8],
881        mut init_frame: impl FnMut(&mut Self, &mut &[u8]) -> Result<(), FrameDecoderError>,
882    ) -> Result<usize, FrameDecoderError> {
883        let mut total_bytes_written = 0;
884        while !input.is_empty() {
885            match init_frame(self, &mut input) {
886                Ok(_) => {}
887                Err(FrameDecoderError::ReadFrameHeaderError(
888                    crate::decoding::errors::ReadFrameHeaderError::SkipFrame { length, .. },
889                )) => {
890                    input = input
891                        .get(length as usize..)
892                        .ok_or(FrameDecoderError::FailedToSkipFrame)?;
893                    continue;
894                }
895                Err(e) => return Err(e),
896            };
897            loop {
898                self.decode_blocks(&mut input, BlockDecodingStrategy::UptoBytes(1024 * 1024))?;
899                let bytes_written = self
900                    .read(output)
901                    .map_err(FrameDecoderError::FailedToDrainDecodebuffer)?;
902                output = &mut output[bytes_written..];
903                total_bytes_written += bytes_written;
904                if self.can_collect() != 0 {
905                    return Err(FrameDecoderError::TargetTooSmall);
906                }
907                if self.is_finished() {
908                    break;
909                }
910            }
911        }
912
913        Ok(total_bytes_written)
914    }
915
916    /// Decode multiple frames into the output slice using a serialized dictionary.
917    ///
918    /// # Warning
919    ///
920    /// Each decoded frame is initialized with the parsed dictionary, even when a
921    /// frame header omits the optional dictionary ID. Callers must only use this
922    /// API when they already know the input frames were encoded with that
923    /// dictionary; otherwise decoded output can be silently corrupted.
924    pub fn decode_all_with_dict_bytes(
925        &mut self,
926        input: &[u8],
927        output: &mut [u8],
928        raw_dictionary: &[u8],
929    ) -> Result<usize, FrameDecoderError> {
930        let dict = DictionaryHandle::decode_dict(raw_dictionary)?;
931        self.decode_all_with_dict_handle(input, output, &dict)
932    }
933
934    /// Decode multiple frames into the extra capacity of the output vector.
935    ///
936    /// `input` must contain an exact number of frames.
937    ///
938    /// `output` must have enough extra capacity to hold the decompressed data.
939    /// This function will not reallocate or grow the vector. If you don't know
940    /// how large the output will be, use [`FrameDecoder::decode_blocks`] instead.
941    ///
942    /// This calls [`FrameDecoder::init`], and all bytes currently in the decoder will be lost.
943    ///
944    /// The length of the output vector is updated to include the decompressed data.
945    /// The length is not changed if an error occurs.
946    pub fn decode_all_to_vec(
947        &mut self,
948        input: &[u8],
949        output: &mut Vec<u8>,
950    ) -> Result<(), FrameDecoderError> {
951        let len = output.len();
952        let cap = output.capacity();
953        output.resize(cap, 0);
954        match self.decode_all(input, &mut output[len..]) {
955            Ok(bytes_written) => {
956                let new_len = core::cmp::min(len + bytes_written, cap); // Sanitizes `bytes_written`.
957                output.resize(new_len, 0);
958                Ok(())
959            }
960            Err(e) => {
961                output.resize(len, 0);
962                Err(e)
963            }
964        }
965    }
966}
967
968/// Read bytes from the decode_buffer that are no longer needed. While the frame is not yet finished
969/// this will retain window_size bytes, else it will drain it completely
970impl Read for FrameDecoder {
971    fn read(&mut self, target: &mut [u8]) -> Result<usize, Error> {
972        let state = match &mut self.state {
973            None => return Ok(0),
974            Some(s) => s,
975        };
976        if state.frame_finished {
977            state.decoder_scratch.buffer_read_all(target)
978        } else {
979            state.decoder_scratch.buffer_read(target)
980        }
981    }
982}
983
#[cfg(test)]
mod tests {
    extern crate std;

    use super::{DictionaryHandle, FrameDecoder};
    use crate::encoding::{CompressionLevel, FrameCompressor};
    use alloc::vec::Vec;

    /// A frame whose header carries no dictionary ID must still have the dictionary applied when
    /// the decoder is reset with a dictionary handle.
    #[test]
    fn reset_with_dict_handle_applies_dict_when_no_dict_id() {
        // Parse the dictionary that will be handed to the decoder.
        let dict_handle =
            DictionaryHandle::decode_dict(include_bytes!("../../dict_tests/dictionary"))
                .expect("dictionary should parse");

        // Compress a small payload; no dictionary is configured on the compressor.
        let plain = b"reset-without-dict-id";
        let mut frame_bytes = Vec::new();
        let mut encoder = FrameCompressor::new(CompressionLevel::Default);
        encoder.set_source(plain.as_slice());
        encoder.set_drain(&mut frame_bytes);
        encoder.compress();

        let mut decoder = FrameDecoder::new();
        decoder
            .reset_with_dict_handle(frame_bytes.as_slice(), &dict_handle)
            .expect("reset should succeed");

        // The header has no dictionary ID, yet the decoder records the handle's ID as in use.
        let state = decoder.state.as_ref().expect("state should be initialized");
        assert!(state.frame_header.dictionary_id().is_none());
        assert_eq!(state.using_dict, Some(dict_handle.id()));
    }
}