Skip to main content

zrip_decode/
streaming.rs

1#![forbid(unsafe_code)]
2
3use std::io::{self, Read};
4
5use crate::BlockDecodeWorkspace;
6use crate::literals::decode_literals_ws;
7use crate::sequences::{SequenceDecodeTables, parse_sequence_count, parse_sequence_tables_ws};
8
9use crate::exec::decode_execute_sequences;
10use zrip_core::block::{BlockType, parse_block_header};
11use zrip_core::dict::Dictionary;
12use zrip_core::error::DecompressError;
13use zrip_core::frame::MAX_BLOCK_SIZE;
14use zrip_core::frame::header::parse_frame_header;
15use zrip_core::fse::{promote_ll_table, promote_ml_table, promote_of_table};
16use zrip_core::xxhash::Xxh64State;
17
18#[cfg(all(
19    any(target_arch = "x86_64", target_arch = "aarch64"),
20    not(feature = "paranoid")
21))]
22use zrip_core::simd::CpuTier;
23
/// Position of the decoder inside the compressed stream.
///
/// `fill_output` dispatches on this value to decide what to read next.
enum State {
    /// The next bytes are expected to start a frame (regular or skippable).
    FrameHeader,
    /// A 3-byte block header is expected next.
    BlockHeader,
    /// A block header has been parsed; its payload still has to be read.
    BlockData {
        /// Kind of block announced by the header.
        block_type: BlockType,
        /// Size field taken from the block header.
        block_size: usize,
        /// Whether the header marked this block as the last one of the frame.
        last: bool,
    },
    /// The 4-byte content checksum that trails the last block is expected next.
    Checksum,
    /// The input is exhausted; `read` returns `Ok(0)` from now on.
    Done,
}
35
/// Streaming zstd decompressor implementing [`Read`].
///
/// Wraps a reader of compressed data and yields decompressed bytes.
/// Supports multi-frame streams and skippable frames.
///
/// ```
/// use std::io::Read;
///
/// let data = b"hello, streaming world!".repeat(100);
/// let compressed = zrip::compress(&data, 1).unwrap();
///
/// let mut decoder = zrip::FrameDecoder::new(&compressed[..]);
/// let mut output = Vec::new();
/// decoder.read_to_end(&mut output).unwrap();
/// assert_eq!(output, data);
/// ```
pub struct FrameDecoder<R: Read> {
    /// Source of compressed bytes.
    inner: R,
    /// What the decoder expects to read next from `inner`.
    state: State,
    /// Scratch buffer holding raw frame-header bytes and compressed block payloads.
    read_buf: Vec<u8>,
    /// Decoded bytes of the most recently decoded block.
    output_buf: Vec<u8>,
    /// Offset into `output_buf` of the next byte to hand to the caller.
    output_pos: usize,
    /// Reusable workspace passed to the literal and sequence-table decoders.
    ws: Box<BlockDecodeWorkspace>,
    /// Sequence decoding tables; re-initialised at every frame header.
    seq_tables: SequenceDecodeTables,
    /// Repeat offsets; start at `[1, 4, 8]` or at the dictionary's values.
    rep_offsets: [u32; 3],
    /// Running content hash, present only while the current frame carries a checksum.
    hasher: Option<Xxh64State>,
    /// Whether the current frame header announced a trailing content checksum.
    content_checksum: bool,
    /// Upper bound on the number of decoded bytes this decoder may produce.
    max_output: usize,
    /// Decoded bytes produced since construction or the last `reset`.
    bytes_output: usize,
    /// Optional dictionary applied to every frame.
    dict: Option<Dictionary>,
}
67
68impl<R: Read> FrameDecoder<R> {
69    /// Creates a decoder with [`DEFAULT_DECOMPRESS_LIMIT`](zrip_core::DEFAULT_DECOMPRESS_LIMIT).
70    pub fn new(reader: R) -> Self {
71        Self::with_limit(reader, zrip_core::DEFAULT_DECOMPRESS_LIMIT)
72    }
73
74    /// Creates a decoder with an explicit output size limit.
75    pub fn with_limit(reader: R, max_output: usize) -> Self {
76        Self {
77            inner: reader,
78            state: State::FrameHeader,
79            read_buf: Vec::new(),
80            output_buf: Vec::new(),
81            output_pos: 0,
82            ws: Box::new(BlockDecodeWorkspace::new()),
83            seq_tables: SequenceDecodeTables::new_default(),
84            rep_offsets: [1, 4, 8],
85            hasher: None,
86            content_checksum: false,
87            max_output,
88            bytes_output: 0,
89            dict: None,
90        }
91    }
92
93    /// Creates a decoder with a dictionary and default output limit.
94    pub fn with_dict(reader: R, dict: Dictionary) -> Self {
95        Self::with_dict_and_limit(reader, dict, zrip_core::DEFAULT_DECOMPRESS_LIMIT)
96    }
97
98    /// Creates a decoder with a dictionary and explicit output limit.
99    pub fn with_dict_and_limit(reader: R, dict: Dictionary, max_output: usize) -> Self {
100        Self {
101            inner: reader,
102            state: State::FrameHeader,
103            read_buf: Vec::new(),
104            output_buf: Vec::new(),
105            output_pos: 0,
106            ws: Box::new(BlockDecodeWorkspace::new()),
107            seq_tables: SequenceDecodeTables::new_default(),
108            rep_offsets: [1, 4, 8],
109            hasher: None,
110            content_checksum: false,
111            max_output,
112            bytes_output: 0,
113            dict: Some(dict),
114        }
115    }
116
117    /// Consumes the decoder and returns the underlying reader.
118    pub fn into_inner(self) -> R {
119        self.inner
120    }
121
122    /// Installs a new reader for the next frame, keeping all internal
123    /// buffers allocated. Returns the previous reader.
124    pub fn reset(&mut self, new_reader: R) -> R {
125        let old = core::mem::replace(&mut self.inner, new_reader);
126        self.state = State::FrameHeader;
127        self.output_buf.clear();
128        self.output_pos = 0;
129        self.rep_offsets = [1, 4, 8];
130        self.seq_tables = SequenceDecodeTables::new_default();
131        self.ws.huf_valid = false;
132        self.hasher = None;
133        self.content_checksum = false;
134        self.bytes_output = 0;
135        old
136    }
137
138    fn fill_output(&mut self) -> io::Result<()> {
139        loop {
140            match self.state {
141                State::Done => return Ok(()),
142                State::FrameHeader => self.read_frame_header()?,
143                State::BlockHeader => self.read_block_header()?,
144                State::BlockData {
145                    block_type,
146                    block_size,
147                    last,
148                } => {
149                    self.read_block_data(block_type, block_size, last)?;
150                    if self.output_pos < self.output_buf.len() {
151                        return Ok(());
152                    }
153                }
154                State::Checksum => self.read_checksum()?,
155            }
156        }
157    }
158
159    fn read_frame_header(&mut self) -> io::Result<()> {
160        self.read_buf.resize(18, 0);
161        self.inner.read_exact(&mut self.read_buf[..5])?;
162
163        let magic = u32::from_le_bytes([
164            self.read_buf[0],
165            self.read_buf[1],
166            self.read_buf[2],
167            self.read_buf[3],
168        ]);
169
170        if (magic & 0xFFFF_FFF0) == 0x184D_2A50 {
171            self.inner.read_exact(&mut self.read_buf[5..9])?;
172            let skip_size = u32::from_le_bytes([
173                self.read_buf[5],
174                self.read_buf[6],
175                self.read_buf[7],
176                self.read_buf[8],
177            ]) as usize;
178            io::copy(
179                &mut self.inner.by_ref().take(skip_size as u64),
180                &mut io::sink(),
181            )?;
182            return Ok(());
183        }
184
185        let descriptor = self.read_buf[4];
186        let single_segment = (descriptor & 0x20) != 0;
187        let dict_id_flag = descriptor & 0x03;
188        let fcs_flag = (descriptor >> 6) & 0x03;
189
190        let mut hdr_len = 5usize;
191        if !single_segment {
192            hdr_len += 1;
193        }
194        hdr_len += match dict_id_flag {
195            0 => 0,
196            1 => 1,
197            2 => 2,
198            3 => 4,
199            _ => unreachable!(),
200        };
201        hdr_len += match fcs_flag {
202            0 if single_segment => 1,
203            0 => 0,
204            1 => 2,
205            2 => 4,
206            3 => 8,
207            _ => unreachable!(),
208        };
209
210        if hdr_len > 5 {
211            self.inner.read_exact(&mut self.read_buf[5..hdr_len])?;
212        }
213
214        let header = parse_frame_header(&self.read_buf[..hdr_len])
215            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
216
217        if let Some(frame_dict_id) = header.dict_id {
218            match &self.dict {
219                Some(d) if d.id() == frame_dict_id => {}
220                Some(d) => {
221                    return Err(io::Error::new(
222                        io::ErrorKind::InvalidData,
223                        DecompressError::DictMismatch {
224                            expected: frame_dict_id,
225                            got: d.id(),
226                        },
227                    ));
228                }
229                None => {
230                    return Err(io::Error::new(
231                        io::ErrorKind::InvalidData,
232                        DecompressError::DictRequired,
233                    ));
234                }
235            }
236        }
237
238        if let Some(fcs) = header.frame_content_size {
239            if fcs as usize > self.max_output {
240                return Err(io::Error::new(
241                    io::ErrorKind::InvalidData,
242                    DecompressError::OutputTooSmall,
243                ));
244            }
245        }
246
247        self.content_checksum = header.content_checksum;
248        self.hasher = if header.content_checksum {
249            Some(Xxh64State::new(0))
250        } else {
251            None
252        };
253
254        if let Some(ref d) = self.dict {
255            self.rep_offsets = *d.rep_offsets();
256            let mut st = SequenceDecodeTables::new_default();
257            if let Some((t, l)) = d.of_table() {
258                st.of_table = promote_of_table(t);
259                st.of_accuracy = l;
260            }
261            if let Some((t, l)) = d.ml_table() {
262                st.ml_table = promote_ml_table(t);
263                st.ml_accuracy = l;
264            }
265            if let Some((t, l)) = d.ll_table() {
266                st.ll_table = promote_ll_table(t);
267                st.ll_accuracy = l;
268            }
269            self.seq_tables = st;
270            self.ws.huf_valid = false;
271            if let Some((t, l)) = d.huf_table() {
272                self.ws.huf_table.clear();
273                self.ws.huf_table.extend_from_slice(t);
274                self.ws.huf_table_log = l;
275                self.ws.huf_valid = true;
276            }
277        } else {
278            self.rep_offsets = [1, 4, 8];
279            self.seq_tables = SequenceDecodeTables::new_default();
280            self.ws.huf_valid = false;
281        }
282
283        self.state = State::BlockHeader;
284        Ok(())
285    }
286
287    fn read_block_header(&mut self) -> io::Result<()> {
288        let mut hdr = [0u8; 3];
289        self.inner.read_exact(&mut hdr)?;
290        let block_header =
291            parse_block_header(&hdr).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
292
293        let block_size = block_header.block_size as usize;
294
295        match block_header.block_type {
296            BlockType::Raw | BlockType::Rle if block_size > MAX_BLOCK_SIZE => {
297                return Err(io::Error::new(
298                    io::ErrorKind::InvalidData,
299                    DecompressError::BlockTooLarge,
300                ));
301            }
302            _ => {}
303        }
304
305        self.state = State::BlockData {
306            block_type: block_header.block_type,
307            block_size,
308            last: block_header.last_block,
309        };
310        Ok(())
311    }
312
313    fn read_block_data(
314        &mut self,
315        block_type: BlockType,
316        block_size: usize,
317        last: bool,
318    ) -> io::Result<()> {
319        self.output_buf.clear();
320        self.output_pos = 0;
321
322        match block_type {
323            BlockType::Raw => {
324                self.output_buf.resize(block_size, 0);
325                self.inner.read_exact(&mut self.output_buf)?;
326            }
327            BlockType::Rle => {
328                let mut byte = [0u8; 1];
329                self.inner.read_exact(&mut byte)?;
330                self.output_buf.resize(block_size, byte[0]);
331            }
332            BlockType::Compressed => {
333                self.read_buf.resize(block_size, 0);
334                self.inner.read_exact(&mut self.read_buf[..block_size])?;
335                self.decode_compressed_block(block_size)?;
336            }
337        }
338
339        if let Some(ref mut hasher) = self.hasher {
340            hasher.update(&self.output_buf);
341        }
342        self.bytes_output += self.output_buf.len();
343        if self.bytes_output > self.max_output {
344            return Err(io::Error::new(
345                io::ErrorKind::InvalidData,
346                DecompressError::OutputTooSmall,
347            ));
348        }
349
350        self.state = if last {
351            if self.content_checksum {
352                State::Checksum
353            } else {
354                State::FrameHeader
355            }
356        } else {
357            State::BlockHeader
358        };
359
360        Ok(())
361    }
362
363    fn decode_compressed_block(&mut self, block_size: usize) -> io::Result<()> {
364        let dict_history: &[u8] = match &self.dict {
365            Some(d) => d.content(),
366            None => &[],
367        };
368        let block_data = &self.read_buf[..block_size];
369
370        let lit_consumed = decode_literals_ws(block_data, &mut self.ws)
371            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
372
373        let remaining = &block_data[lit_consumed..];
374
375        if remaining.is_empty() {
376            self.output_buf.extend_from_slice(&self.ws.literal_buf);
377            return Ok(());
378        }
379
380        let (num_sequences, seq_count_size) = parse_sequence_count(remaining)
381            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
382
383        if num_sequences == 0 {
384            self.output_buf.extend_from_slice(&self.ws.literal_buf);
385            return Ok(());
386        }
387
388        let table_data = &remaining[seq_count_size..];
389        let tables_consumed =
390            parse_sequence_tables_ws(table_data, &mut self.seq_tables, &mut self.ws)
391                .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
392
393        let seq_data = &table_data[tables_consumed..];
394
395        #[cfg(all(target_arch = "x86_64", not(feature = "paranoid")))]
396        {
397            if zrip_core::simd::cpu_tier() >= CpuTier::Avx2 {
398                let before = self.output_buf.len();
399                crate::simd_decode::x86_64::decode::decode_execute_avx2_safe(
400                    seq_data,
401                    num_sequences,
402                    &self.seq_tables,
403                    &mut self.rep_offsets,
404                    &self.ws.literal_buf,
405                    &mut self.output_buf,
406                    dict_history,
407                )
408                .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
409                if self.output_buf.len() - before > MAX_BLOCK_SIZE {
410                    return Err(io::Error::new(
411                        io::ErrorKind::InvalidData,
412                        DecompressError::BlockTooLarge,
413                    ));
414                }
415                return Ok(());
416            }
417        }
418
419        #[cfg(all(target_arch = "aarch64", not(feature = "paranoid")))]
420        {
421            if zrip_core::simd::cpu_tier() >= CpuTier::Neon {
422                let before = self.output_buf.len();
423                crate::simd_decode::aarch64::decode::decode_execute_neon_safe(
424                    seq_data,
425                    num_sequences,
426                    &self.seq_tables,
427                    &mut self.rep_offsets,
428                    &self.ws.literal_buf,
429                    &mut self.output_buf,
430                    dict_history,
431                )
432                .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
433                if self.output_buf.len() - before > MAX_BLOCK_SIZE {
434                    return Err(io::Error::new(
435                        io::ErrorKind::InvalidData,
436                        DecompressError::BlockTooLarge,
437                    ));
438                }
439                return Ok(());
440            }
441        }
442
443        let before = self.output_buf.len();
444        decode_execute_sequences(
445            seq_data,
446            num_sequences,
447            &self.seq_tables,
448            &mut self.rep_offsets,
449            &self.ws.literal_buf,
450            &mut self.output_buf,
451            dict_history,
452        )
453        .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
454        if self.output_buf.len() - before > MAX_BLOCK_SIZE {
455            return Err(io::Error::new(
456                io::ErrorKind::InvalidData,
457                DecompressError::BlockTooLarge,
458            ));
459        }
460        Ok(())
461    }
462
463    fn read_checksum(&mut self) -> io::Result<()> {
464        let mut buf = [0u8; 4];
465        self.inner.read_exact(&mut buf)?;
466        let stored = u32::from_le_bytes(buf);
467
468        if let Some(ref hasher) = self.hasher {
469            let hash = hasher.finish();
470            let expected = (hash & 0xFFFF_FFFF) as u32;
471            if expected != stored {
472                return Err(io::Error::new(
473                    io::ErrorKind::InvalidData,
474                    DecompressError::ChecksumMismatch {
475                        expected: stored,
476                        got: expected,
477                    },
478                ));
479            }
480        }
481
482        self.state = State::FrameHeader;
483        Ok(())
484    }
485}
486
487impl<R: Read> Read for FrameDecoder<R> {
488    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
489        if self.output_pos >= self.output_buf.len() {
490            if let State::Done = &self.state {
491                return Ok(0);
492            }
493
494            self.output_buf.clear();
495            self.output_pos = 0;
496
497            match self.fill_output() {
498                Ok(()) => {}
499                Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => match &self.state {
500                    State::FrameHeader => {
501                        self.state = State::Done;
502                        return Ok(0);
503                    }
504                    _ => return Err(e),
505                },
506                Err(e) => return Err(e),
507            }
508        }
509
510        let available = &self.output_buf[self.output_pos..];
511        let n = buf.len().min(available.len());
512        buf[..n].copy_from_slice(&available[..n]);
513        self.output_pos += n;
514        Ok(n)
515    }
516}