1#![forbid(unsafe_code)]
2
3use std::io::{self, Read};
4
5use crate::BlockDecodeWorkspace;
6use crate::literals::decode_literals_ws;
7use crate::sequences::{SequenceDecodeTables, parse_sequence_count, parse_sequence_tables_ws};
8use zrip_core::block::{BlockType, parse_block_header};
9use zrip_core::error::DecompressError;
10use zrip_core::frame::MAX_BLOCK_SIZE;
11use zrip_core::frame::header::parse_frame_header;
12use zrip_core::xxhash::Xxh64State;
13
14#[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
15use zrip_core::simd::CpuTier;
16
17enum State {
18 FrameHeader,
19 BlockHeader,
20 BlockData {
21 block_type: BlockType,
22 block_size: usize,
23 last: bool,
24 },
25 Checksum,
26 Done,
27}
28
29pub struct FrameDecoder<R: Read> {
46 inner: R,
47 state: State,
48 read_buf: Vec<u8>,
49 output_buf: Vec<u8>,
50 output_pos: usize,
51 ws: Box<BlockDecodeWorkspace>,
52 seq_tables: SequenceDecodeTables,
53 rep_offsets: [u32; 3],
54 hasher: Option<Xxh64State>,
55 content_checksum: bool,
56 max_output: usize,
57 bytes_output: usize,
58}
59
60impl<R: Read> FrameDecoder<R> {
61 pub fn new(reader: R) -> Self {
63 Self::with_limit(reader, zrip_core::DEFAULT_DECOMPRESS_LIMIT)
64 }
65
66 pub fn with_limit(reader: R, max_output: usize) -> Self {
68 Self {
69 inner: reader,
70 state: State::FrameHeader,
71 read_buf: Vec::new(),
72 output_buf: Vec::new(),
73 output_pos: 0,
74 ws: Box::new(BlockDecodeWorkspace::new()),
75 seq_tables: SequenceDecodeTables::new_default(),
76 rep_offsets: [1, 4, 8],
77 hasher: None,
78 content_checksum: false,
79 max_output,
80 bytes_output: 0,
81 }
82 }
83
84 pub fn into_inner(self) -> R {
86 self.inner
87 }
88
89 fn fill_output(&mut self) -> io::Result<()> {
90 loop {
91 match self.state {
92 State::Done => return Ok(()),
93 State::FrameHeader => self.read_frame_header()?,
94 State::BlockHeader => self.read_block_header()?,
95 State::BlockData {
96 block_type,
97 block_size,
98 last,
99 } => {
100 self.read_block_data(block_type, block_size, last)?;
101 if self.output_pos < self.output_buf.len() {
102 return Ok(());
103 }
104 }
105 State::Checksum => self.read_checksum()?,
106 }
107 }
108 }
109
110 fn read_frame_header(&mut self) -> io::Result<()> {
111 self.read_buf.resize(18, 0);
112 self.inner.read_exact(&mut self.read_buf[..5])?;
113
114 let magic = u32::from_le_bytes([
115 self.read_buf[0],
116 self.read_buf[1],
117 self.read_buf[2],
118 self.read_buf[3],
119 ]);
120
121 if (magic & 0xFFFFFFF0) == 0x184D2A50 {
122 self.inner.read_exact(&mut self.read_buf[5..9])?;
123 let skip_size = u32::from_le_bytes([
124 self.read_buf[5],
125 self.read_buf[6],
126 self.read_buf[7],
127 self.read_buf[8],
128 ]) as usize;
129 io::copy(
130 &mut self.inner.by_ref().take(skip_size as u64),
131 &mut io::sink(),
132 )?;
133 return Ok(());
134 }
135
136 let descriptor = self.read_buf[4];
137 let single_segment = (descriptor & 0x20) != 0;
138 let dict_id_flag = descriptor & 0x03;
139 let fcs_flag = (descriptor >> 6) & 0x03;
140
141 let mut hdr_len = 5usize;
142 if !single_segment {
143 hdr_len += 1;
144 }
145 hdr_len += match dict_id_flag {
146 0 => 0,
147 1 => 1,
148 2 => 2,
149 3 => 4,
150 _ => unreachable!(),
151 };
152 hdr_len += match fcs_flag {
153 0 if single_segment => 1,
154 0 => 0,
155 1 => 2,
156 2 => 4,
157 3 => 8,
158 _ => unreachable!(),
159 };
160
161 if hdr_len > 5 {
162 self.inner.read_exact(&mut self.read_buf[5..hdr_len])?;
163 }
164
165 let header = parse_frame_header(&self.read_buf[..hdr_len])
166 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
167
168 if let Some(fcs) = header.frame_content_size {
169 if fcs as usize > self.max_output {
170 return Err(io::Error::new(
171 io::ErrorKind::InvalidData,
172 DecompressError::OutputTooSmall,
173 ));
174 }
175 }
176
177 self.content_checksum = header.content_checksum;
178 self.hasher = if header.content_checksum {
179 Some(Xxh64State::new(0))
180 } else {
181 None
182 };
183 self.rep_offsets = [1, 4, 8];
184 self.ws.huf_valid = false;
185 self.state = State::BlockHeader;
186 Ok(())
187 }
188
189 fn read_block_header(&mut self) -> io::Result<()> {
190 let mut hdr = [0u8; 3];
191 self.inner.read_exact(&mut hdr)?;
192 let block_header =
193 parse_block_header(&hdr).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
194
195 let block_size = block_header.block_size as usize;
196
197 match block_header.block_type {
198 BlockType::Raw | BlockType::Rle if block_size > MAX_BLOCK_SIZE => {
199 return Err(io::Error::new(
200 io::ErrorKind::InvalidData,
201 DecompressError::CorruptSequences,
202 ));
203 }
204 _ => {}
205 }
206
207 self.state = State::BlockData {
208 block_type: block_header.block_type,
209 block_size,
210 last: block_header.last_block,
211 };
212 Ok(())
213 }
214
215 fn read_block_data(
216 &mut self,
217 block_type: BlockType,
218 block_size: usize,
219 last: bool,
220 ) -> io::Result<()> {
221 self.output_buf.clear();
222 self.output_pos = 0;
223
224 match block_type {
225 BlockType::Raw => {
226 self.output_buf.resize(block_size, 0);
227 self.inner.read_exact(&mut self.output_buf)?;
228 }
229 BlockType::Rle => {
230 let mut byte = [0u8; 1];
231 self.inner.read_exact(&mut byte)?;
232 self.output_buf.resize(block_size, byte[0]);
233 }
234 BlockType::Compressed => {
235 self.read_buf.resize(block_size, 0);
236 self.inner.read_exact(&mut self.read_buf[..block_size])?;
237 self.decode_compressed_block(block_size)?;
238 }
239 }
240
241 if let Some(ref mut hasher) = self.hasher {
242 hasher.update(&self.output_buf);
243 }
244 self.bytes_output += self.output_buf.len();
245 if self.bytes_output > self.max_output {
246 return Err(io::Error::new(
247 io::ErrorKind::InvalidData,
248 DecompressError::OutputTooSmall,
249 ));
250 }
251
252 self.state = if last {
253 if self.content_checksum {
254 State::Checksum
255 } else {
256 State::FrameHeader
257 }
258 } else {
259 State::BlockHeader
260 };
261
262 Ok(())
263 }
264
265 fn decode_compressed_block(&mut self, block_size: usize) -> io::Result<()> {
266 let block_data = &self.read_buf[..block_size];
267
268 let lit_consumed = decode_literals_ws(block_data, &mut self.ws)
269 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
270
271 let remaining = &block_data[lit_consumed..];
272
273 if remaining.is_empty() {
274 self.output_buf.extend_from_slice(&self.ws.literal_buf);
275 return Ok(());
276 }
277
278 let (num_sequences, seq_count_size) = parse_sequence_count(remaining)
279 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
280
281 if num_sequences == 0 {
282 self.output_buf.extend_from_slice(&self.ws.literal_buf);
283 return Ok(());
284 }
285
286 let table_data = &remaining[seq_count_size..];
287 let tables_consumed =
288 parse_sequence_tables_ws(table_data, &mut self.seq_tables, &mut self.ws)
289 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
290
291 let seq_data = &table_data[tables_consumed..];
292
293 #[cfg(target_arch = "x86_64")]
294 {
295 if zrip_core::simd::cpu_tier() >= CpuTier::Avx2 {
296 let before = self.output_buf.len();
297 crate::simd_decode::x86_64::decode::decode_execute_avx2_safe(
298 seq_data,
299 num_sequences,
300 &self.seq_tables,
301 &mut self.rep_offsets,
302 &self.ws.literal_buf,
303 &mut self.output_buf,
304 &[],
305 )
306 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
307 if self.output_buf.len() - before > MAX_BLOCK_SIZE {
308 return Err(io::Error::new(
309 io::ErrorKind::InvalidData,
310 DecompressError::CorruptSequences,
311 ));
312 }
313 return Ok(());
314 }
315 }
316
317 #[cfg(target_arch = "aarch64")]
318 {
319 if zrip_core::simd::cpu_tier() >= CpuTier::Neon {
320 let before = self.output_buf.len();
321 crate::simd_decode::aarch64::decode::decode_execute_neon_safe(
322 seq_data,
323 num_sequences,
324 &self.seq_tables,
325 &mut self.rep_offsets,
326 &self.ws.literal_buf,
327 &mut self.output_buf,
328 &[],
329 )
330 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
331 if self.output_buf.len() - before > MAX_BLOCK_SIZE {
332 return Err(io::Error::new(
333 io::ErrorKind::InvalidData,
334 DecompressError::CorruptSequences,
335 ));
336 }
337 return Ok(());
338 }
339 }
340
341 let before = self.output_buf.len();
342 crate::exec::decode_execute_sequences(
343 seq_data,
344 num_sequences,
345 &self.seq_tables,
346 &mut self.rep_offsets,
347 &self.ws.literal_buf,
348 &mut self.output_buf,
349 &[],
350 )
351 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
352 if self.output_buf.len() - before > MAX_BLOCK_SIZE {
353 return Err(io::Error::new(
354 io::ErrorKind::InvalidData,
355 DecompressError::CorruptSequences,
356 ));
357 }
358 Ok(())
359 }
360
361 fn read_checksum(&mut self) -> io::Result<()> {
362 let mut buf = [0u8; 4];
363 self.inner.read_exact(&mut buf)?;
364 let stored = u32::from_le_bytes(buf);
365
366 if let Some(ref hasher) = self.hasher {
367 let hash = hasher.finish();
368 let expected = (hash & 0xFFFFFFFF) as u32;
369 if expected != stored {
370 return Err(io::Error::new(
371 io::ErrorKind::InvalidData,
372 DecompressError::ChecksumMismatch {
373 expected: stored,
374 got: expected,
375 },
376 ));
377 }
378 }
379
380 self.state = State::FrameHeader;
381 Ok(())
382 }
383}
384
385impl<R: Read> Read for FrameDecoder<R> {
386 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
387 if self.output_pos >= self.output_buf.len() {
388 match &self.state {
389 State::Done => return Ok(0),
390 _ => {}
391 }
392
393 self.output_buf.clear();
394 self.output_pos = 0;
395
396 match self.fill_output() {
397 Ok(()) => {}
398 Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => match &self.state {
399 State::FrameHeader => {
400 self.state = State::Done;
401 return Ok(0);
402 }
403 _ => return Err(e),
404 },
405 Err(e) => return Err(e),
406 }
407 }
408
409 let available = &self.output_buf[self.output_pos..];
410 let n = buf.len().min(available.len());
411 buf[..n].copy_from_slice(&available[..n]);
412 self.output_pos += n;
413 Ok(n)
414 }
415}