Skip to main content

zrip_decode/
streaming.rs

1#![forbid(unsafe_code)]
2
3use std::io::{self, Read};
4
5use crate::BlockDecodeWorkspace;
6use crate::literals::decode_literals_ws;
7use crate::sequences::{SequenceDecodeTables, parse_sequence_count, parse_sequence_tables_ws};
8
9use crate::exec::decode_execute_sequences;
10use zrip_core::block::{BlockType, parse_block_header};
11use zrip_core::dict::Dictionary;
12use zrip_core::error::DecompressError;
13use zrip_core::frame::MAX_BLOCK_SIZE;
14use zrip_core::frame::header::parse_frame_header;
15use zrip_core::fse::{promote_ll_table, promote_ml_table, promote_of_table};
16use zrip_core::xxhash::Xxh64State;
17
18#[cfg(all(
19    any(target_arch = "x86_64", target_arch = "aarch64"),
20    not(feature = "paranoid")
21))]
22use zrip_core::simd::CpuTier;
23
/// Position of the decoder within the compressed stream.
///
/// `fill_output` dispatches on this value to decide what to read next from
/// the inner reader.
enum State {
    /// At a frame boundary: the next bytes are a frame header (regular or
    /// skippable), or the stream ends here.
    FrameHeader,
    /// Inside a frame: the next bytes are a 3-byte block header.
    BlockHeader,
    /// A block header has been parsed; the block payload is still unread.
    BlockData {
        /// Kind of block (raw, RLE or compressed) taken from the block header.
        block_type: BlockType,
        /// Size field from the block header, in bytes.
        block_size: usize,
        /// Whether this is the final block of the current frame.
        last: bool,
    },
    /// The last block was consumed and a 4-byte content checksum follows.
    Checksum,
    /// The inner reader ended at a frame boundary; `read` returns `Ok(0)`.
    Done,
}
35
/// Streaming zstd decompressor implementing [`Read`].
///
/// Wraps a reader of compressed data and yields decompressed bytes.
/// Supports multi-frame streams and skippable frames.
///
/// ```
/// use std::io::Read;
///
/// let data = b"hello, streaming world!".repeat(100);
/// let compressed = zrip::compress(&data, 1).unwrap();
///
/// let mut decoder = zrip::FrameDecoder::new(&compressed[..]);
/// let mut output = Vec::new();
/// decoder.read_to_end(&mut output).unwrap();
/// assert_eq!(output, data);
/// ```
pub struct FrameDecoder<R: Read> {
    // Source of compressed bytes.
    inner: R,
    // Where the decoder currently is within the frame/block structure.
    state: State,
    // Scratch space for frame-header bytes and compressed block payloads.
    read_buf: Vec<u8>,
    // Decompressed bytes of the most recently decoded block.
    output_buf: Vec<u8>,
    // Index of the first byte in `output_buf` not yet handed to the caller.
    output_pos: usize,
    // Reusable block-decoding workspace (literal buffer, Huffman table state).
    ws: Box<BlockDecodeWorkspace>,
    // Sequence decoding tables; re-initialised at the start of every frame.
    seq_tables: SequenceDecodeTables,
    // Repeat offsets; start at [1, 4, 8] or at the dictionary's values.
    rep_offsets: [u32; 3],
    // Running content hash; `Some` only while the current frame has a checksum.
    hasher: Option<Xxh64State>,
    // Whether the current frame is followed by a content checksum.
    content_checksum: bool,
    // Upper bound on the number of decompressed bytes this decoder may produce.
    max_output: usize,
    // Decompressed bytes produced since construction or the last `reset`.
    bytes_output: usize,
    // Content size declared by the current frame header, if any.
    frame_content_size: Option<u64>,
    // Decompressed bytes produced so far for the current frame.
    frame_bytes: usize,
    // Optional dictionary applied at the start of every frame.
    dict: Option<Dictionary>,
}
69
70impl<R: Read> FrameDecoder<R> {
71    /// Creates a decoder with [`DEFAULT_DECOMPRESS_LIMIT`](zrip_core::DEFAULT_DECOMPRESS_LIMIT).
72    pub fn new(reader: R) -> Self {
73        Self::with_limit(reader, zrip_core::DEFAULT_DECOMPRESS_LIMIT)
74    }
75
76    /// Creates a decoder with an explicit output size limit.
77    pub fn with_limit(reader: R, max_output: usize) -> Self {
78        Self {
79            inner: reader,
80            state: State::FrameHeader,
81            read_buf: Vec::new(),
82            output_buf: Vec::new(),
83            output_pos: 0,
84            ws: Box::new(BlockDecodeWorkspace::new()),
85            seq_tables: SequenceDecodeTables::new_default(),
86            rep_offsets: [1, 4, 8],
87            hasher: None,
88            content_checksum: false,
89            max_output,
90            bytes_output: 0,
91            frame_content_size: None,
92            frame_bytes: 0,
93            dict: None,
94        }
95    }
96
97    /// Creates a decoder with a dictionary and default output limit.
98    pub fn with_dict(reader: R, dict: Dictionary) -> Self {
99        Self::with_dict_and_limit(reader, dict, zrip_core::DEFAULT_DECOMPRESS_LIMIT)
100    }
101
102    /// Creates a decoder with a dictionary and explicit output limit.
103    pub fn with_dict_and_limit(reader: R, dict: Dictionary, max_output: usize) -> Self {
104        Self {
105            inner: reader,
106            state: State::FrameHeader,
107            read_buf: Vec::new(),
108            output_buf: Vec::new(),
109            output_pos: 0,
110            ws: Box::new(BlockDecodeWorkspace::new()),
111            seq_tables: SequenceDecodeTables::new_default(),
112            rep_offsets: [1, 4, 8],
113            hasher: None,
114            content_checksum: false,
115            max_output,
116            bytes_output: 0,
117            frame_content_size: None,
118            frame_bytes: 0,
119            dict: Some(dict),
120        }
121    }
122
123    /// Consumes the decoder and returns the underlying reader.
124    pub fn into_inner(self) -> R {
125        self.inner
126    }
127
128    /// Installs a new reader for the next frame, keeping all internal
129    /// buffers allocated. Returns the previous reader.
130    pub fn reset(&mut self, new_reader: R) -> R {
131        let old = core::mem::replace(&mut self.inner, new_reader);
132        self.state = State::FrameHeader;
133        self.output_buf.clear();
134        self.output_pos = 0;
135        self.rep_offsets = [1, 4, 8];
136        self.seq_tables = SequenceDecodeTables::new_default();
137        self.ws.huf_valid = false;
138        self.hasher = None;
139        self.content_checksum = false;
140        self.bytes_output = 0;
141        self.frame_content_size = None;
142        self.frame_bytes = 0;
143        old
144    }
145
146    fn fill_output(&mut self) -> io::Result<()> {
147        loop {
148            match self.state {
149                State::Done => return Ok(()),
150                State::FrameHeader => self.read_frame_header()?,
151                State::BlockHeader => self.read_block_header()?,
152                State::BlockData {
153                    block_type,
154                    block_size,
155                    last,
156                } => {
157                    self.read_block_data(block_type, block_size, last)?;
158                    if self.output_pos < self.output_buf.len() {
159                        return Ok(());
160                    }
161                }
162                State::Checksum => self.read_checksum()?,
163            }
164        }
165    }
166
167    fn read_frame_header(&mut self) -> io::Result<()> {
168        self.read_buf.resize(18, 0);
169        self.inner.read_exact(&mut self.read_buf[..5])?;
170
171        let magic = u32::from_le_bytes([
172            self.read_buf[0],
173            self.read_buf[1],
174            self.read_buf[2],
175            self.read_buf[3],
176        ]);
177
178        if (magic & 0xFFFF_FFF0) == 0x184D_2A50 {
179            self.inner.read_exact(&mut self.read_buf[5..9])?;
180            let skip_size = u32::from_le_bytes([
181                self.read_buf[5],
182                self.read_buf[6],
183                self.read_buf[7],
184                self.read_buf[8],
185            ]) as usize;
186            io::copy(
187                &mut self.inner.by_ref().take(skip_size as u64),
188                &mut io::sink(),
189            )?;
190            return Ok(());
191        }
192
193        let descriptor = self.read_buf[4];
194        let single_segment = (descriptor & 0x20) != 0;
195        let dict_id_flag = descriptor & 0x03;
196        let fcs_flag = (descriptor >> 6) & 0x03;
197
198        let mut hdr_len = 5usize;
199        if !single_segment {
200            hdr_len += 1;
201        }
202        hdr_len += match dict_id_flag {
203            0 => 0,
204            1 => 1,
205            2 => 2,
206            3 => 4,
207            _ => unreachable!(),
208        };
209        hdr_len += match fcs_flag {
210            0 if single_segment => 1,
211            0 => 0,
212            1 => 2,
213            2 => 4,
214            3 => 8,
215            _ => unreachable!(),
216        };
217
218        if hdr_len > 5 {
219            self.inner.read_exact(&mut self.read_buf[5..hdr_len])?;
220        }
221
222        let header = parse_frame_header(&self.read_buf[..hdr_len])
223            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
224
225        if let Some(frame_dict_id) = header.dict_id {
226            match &self.dict {
227                Some(d) if d.id() == frame_dict_id => {}
228                Some(d) => {
229                    return Err(io::Error::new(
230                        io::ErrorKind::InvalidData,
231                        DecompressError::DictMismatch {
232                            expected: frame_dict_id,
233                            got: d.id(),
234                        },
235                    ));
236                }
237                None => {
238                    return Err(io::Error::new(
239                        io::ErrorKind::InvalidData,
240                        DecompressError::DictRequired,
241                    ));
242                }
243            }
244        }
245
246        if let Some(fcs) = header.frame_content_size {
247            if fcs as usize > self.max_output {
248                return Err(io::Error::new(
249                    io::ErrorKind::InvalidData,
250                    DecompressError::OutputTooSmall,
251                ));
252            }
253        }
254
255        self.frame_content_size = header.frame_content_size;
256        self.frame_bytes = 0;
257        self.content_checksum = header.content_checksum;
258        self.hasher = if header.content_checksum {
259            Some(Xxh64State::new(0))
260        } else {
261            None
262        };
263
264        if let Some(ref d) = self.dict {
265            self.rep_offsets = *d.rep_offsets();
266            let mut st = SequenceDecodeTables::new_default();
267            if let Some((t, l)) = d.of_table() {
268                st.of_table = promote_of_table(t);
269                st.of_accuracy = l;
270                st.of_set = true;
271            }
272            if let Some((t, l)) = d.ml_table() {
273                st.ml_table = promote_ml_table(t);
274                st.ml_accuracy = l;
275                st.ml_set = true;
276            }
277            if let Some((t, l)) = d.ll_table() {
278                st.ll_table = promote_ll_table(t);
279                st.ll_accuracy = l;
280                st.ll_set = true;
281            }
282            self.seq_tables = st;
283            self.ws.huf_valid = false;
284            if let Some((t, l)) = d.huf_table() {
285                self.ws.huf_table.clear();
286                self.ws.huf_table.extend_from_slice(t);
287                self.ws.huf_table_log = l;
288                self.ws.huf_valid = true;
289            }
290        } else {
291            self.rep_offsets = [1, 4, 8];
292            self.seq_tables = SequenceDecodeTables::new_default();
293            self.ws.huf_valid = false;
294        }
295
296        self.state = State::BlockHeader;
297        Ok(())
298    }
299
300    fn read_block_header(&mut self) -> io::Result<()> {
301        let mut hdr = [0u8; 3];
302        self.inner.read_exact(&mut hdr)?;
303        let block_header =
304            parse_block_header(&hdr).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
305
306        let block_size = block_header.block_size as usize;
307
308        match block_header.block_type {
309            BlockType::Raw | BlockType::Rle if block_size > MAX_BLOCK_SIZE => {
310                return Err(io::Error::new(
311                    io::ErrorKind::InvalidData,
312                    DecompressError::BlockTooLarge,
313                ));
314            }
315            _ => {}
316        }
317
318        self.state = State::BlockData {
319            block_type: block_header.block_type,
320            block_size,
321            last: block_header.last_block,
322        };
323        Ok(())
324    }
325
326    fn read_block_data(
327        &mut self,
328        block_type: BlockType,
329        block_size: usize,
330        last: bool,
331    ) -> io::Result<()> {
332        self.output_buf.clear();
333        self.output_pos = 0;
334
335        match block_type {
336            BlockType::Raw => {
337                self.output_buf.resize(block_size, 0);
338                self.inner.read_exact(&mut self.output_buf)?;
339            }
340            BlockType::Rle => {
341                let mut byte = [0u8; 1];
342                self.inner.read_exact(&mut byte)?;
343                self.output_buf.resize(block_size, byte[0]);
344            }
345            BlockType::Compressed => {
346                self.read_buf.resize(block_size, 0);
347                self.inner.read_exact(&mut self.read_buf[..block_size])?;
348                self.decode_compressed_block(block_size)?;
349            }
350        }
351
352        if let Some(ref mut hasher) = self.hasher {
353            hasher.update(&self.output_buf);
354        }
355        self.bytes_output += self.output_buf.len();
356        self.frame_bytes += self.output_buf.len();
357        if self.bytes_output > self.max_output {
358            return Err(io::Error::new(
359                io::ErrorKind::InvalidData,
360                DecompressError::OutputTooSmall,
361            ));
362        }
363
364        self.state = if last {
365            if let Some(fcs) = self.frame_content_size {
366                if self.frame_bytes as u64 != fcs {
367                    return Err(io::Error::new(
368                        io::ErrorKind::InvalidData,
369                        DecompressError::FrameSizeMismatch,
370                    ));
371                }
372            }
373            if self.content_checksum {
374                State::Checksum
375            } else {
376                State::FrameHeader
377            }
378        } else {
379            State::BlockHeader
380        };
381
382        Ok(())
383    }
384
385    fn decode_compressed_block(&mut self, block_size: usize) -> io::Result<()> {
386        let dict_history: &[u8] = match &self.dict {
387            Some(d) => d.content(),
388            None => &[],
389        };
390        let block_data = &self.read_buf[..block_size];
391
392        let lit_consumed = decode_literals_ws(block_data, &mut self.ws)
393            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
394
395        let remaining = &block_data[lit_consumed..];
396
397        if remaining.is_empty() {
398            self.output_buf.extend_from_slice(&self.ws.literal_buf);
399            return Ok(());
400        }
401
402        let (num_sequences, seq_count_size) = parse_sequence_count(remaining)
403            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
404
405        if num_sequences == 0 {
406            self.output_buf.extend_from_slice(&self.ws.literal_buf);
407            return Ok(());
408        }
409
410        let table_data = &remaining[seq_count_size..];
411        let tables_consumed =
412            parse_sequence_tables_ws(table_data, &mut self.seq_tables, &mut self.ws)
413                .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
414
415        let seq_data = &table_data[tables_consumed..];
416
417        #[cfg(all(target_arch = "x86_64", not(feature = "paranoid")))]
418        {
419            if zrip_core::simd::cpu_tier() >= CpuTier::Avx2 {
420                let before = self.output_buf.len();
421                crate::simd_decode::x86_64::decode::decode_execute_avx2_safe(
422                    seq_data,
423                    num_sequences,
424                    &self.seq_tables,
425                    &mut self.rep_offsets,
426                    &self.ws.literal_buf,
427                    &mut self.output_buf,
428                    dict_history,
429                )
430                .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
431                if self.output_buf.len() - before > MAX_BLOCK_SIZE {
432                    return Err(io::Error::new(
433                        io::ErrorKind::InvalidData,
434                        DecompressError::BlockTooLarge,
435                    ));
436                }
437                return Ok(());
438            }
439        }
440
441        #[cfg(all(target_arch = "aarch64", not(feature = "paranoid")))]
442        {
443            if zrip_core::simd::cpu_tier() >= CpuTier::Neon {
444                let before = self.output_buf.len();
445                crate::simd_decode::aarch64::decode::decode_execute_neon_safe(
446                    seq_data,
447                    num_sequences,
448                    &self.seq_tables,
449                    &mut self.rep_offsets,
450                    &self.ws.literal_buf,
451                    &mut self.output_buf,
452                    dict_history,
453                )
454                .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
455                if self.output_buf.len() - before > MAX_BLOCK_SIZE {
456                    return Err(io::Error::new(
457                        io::ErrorKind::InvalidData,
458                        DecompressError::BlockTooLarge,
459                    ));
460                }
461                return Ok(());
462            }
463        }
464
465        let before = self.output_buf.len();
466        decode_execute_sequences(
467            seq_data,
468            num_sequences,
469            &self.seq_tables,
470            &mut self.rep_offsets,
471            &self.ws.literal_buf,
472            &mut self.output_buf,
473            dict_history,
474        )
475        .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
476        if self.output_buf.len() - before > MAX_BLOCK_SIZE {
477            return Err(io::Error::new(
478                io::ErrorKind::InvalidData,
479                DecompressError::BlockTooLarge,
480            ));
481        }
482        Ok(())
483    }
484
485    fn read_checksum(&mut self) -> io::Result<()> {
486        let mut buf = [0u8; 4];
487        self.inner.read_exact(&mut buf)?;
488        let stored = u32::from_le_bytes(buf);
489
490        if let Some(ref hasher) = self.hasher {
491            let hash = hasher.finish();
492            let expected = (hash & 0xFFFF_FFFF) as u32;
493            if expected != stored {
494                return Err(io::Error::new(
495                    io::ErrorKind::InvalidData,
496                    DecompressError::ChecksumMismatch {
497                        expected: stored,
498                        got: expected,
499                    },
500                ));
501            }
502        }
503
504        self.state = State::FrameHeader;
505        Ok(())
506    }
507}
508
509impl<R: Read> Read for FrameDecoder<R> {
510    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
511        if self.output_pos >= self.output_buf.len() {
512            if let State::Done = &self.state {
513                return Ok(0);
514            }
515
516            self.output_buf.clear();
517            self.output_pos = 0;
518
519            match self.fill_output() {
520                Ok(()) => {}
521                Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => match &self.state {
522                    State::FrameHeader => {
523                        self.state = State::Done;
524                        return Ok(0);
525                    }
526                    _ => return Err(e),
527                },
528                Err(e) => return Err(e),
529            }
530        }
531
532        let available = &self.output_buf[self.output_pos..];
533        let n = buf.len().min(available.len());
534        buf[..n].copy_from_slice(&available[..n]);
535        self.output_pos += n;
536        Ok(n)
537    }
538}