1#![forbid(unsafe_code)]
2
3use std::io::{self, Read};
4
5use crate::BlockDecodeWorkspace;
6use crate::literals::decode_literals_ws;
7use crate::sequences::{SequenceDecodeTables, parse_sequence_count, parse_sequence_tables_ws};
8
9use crate::exec::decode_execute_sequences;
10use zrip_core::block::{BlockType, parse_block_header};
11use zrip_core::dict::Dictionary;
12use zrip_core::error::DecompressError;
13use zrip_core::frame::MAX_BLOCK_SIZE;
14use zrip_core::frame::header::parse_frame_header;
15use zrip_core::fse::{promote_ll_table, promote_ml_table, promote_of_table};
16use zrip_core::xxhash::Xxh64State;
17
18#[cfg(all(
19 any(target_arch = "x86_64", target_arch = "aarch64"),
20 not(feature = "paranoid")
21))]
22use zrip_core::simd::CpuTier;
23
24enum State {
25 FrameHeader,
26 BlockHeader,
27 BlockData {
28 block_type: BlockType,
29 block_size: usize,
30 last: bool,
31 },
32 Checksum,
33 Done,
34}
35
36pub struct FrameDecoder<R: Read> {
53 inner: R,
54 state: State,
55 read_buf: Vec<u8>,
56 output_buf: Vec<u8>,
57 output_pos: usize,
58 ws: Box<BlockDecodeWorkspace>,
59 seq_tables: SequenceDecodeTables,
60 rep_offsets: [u32; 3],
61 hasher: Option<Xxh64State>,
62 content_checksum: bool,
63 max_output: usize,
64 bytes_output: usize,
65 dict: Option<Dictionary>,
66}
67
68impl<R: Read> FrameDecoder<R> {
69 pub fn new(reader: R) -> Self {
71 Self::with_limit(reader, zrip_core::DEFAULT_DECOMPRESS_LIMIT)
72 }
73
74 pub fn with_limit(reader: R, max_output: usize) -> Self {
76 Self {
77 inner: reader,
78 state: State::FrameHeader,
79 read_buf: Vec::new(),
80 output_buf: Vec::new(),
81 output_pos: 0,
82 ws: Box::new(BlockDecodeWorkspace::new()),
83 seq_tables: SequenceDecodeTables::new_default(),
84 rep_offsets: [1, 4, 8],
85 hasher: None,
86 content_checksum: false,
87 max_output,
88 bytes_output: 0,
89 dict: None,
90 }
91 }
92
93 pub fn with_dict(reader: R, dict: Dictionary) -> Self {
95 Self::with_dict_and_limit(reader, dict, zrip_core::DEFAULT_DECOMPRESS_LIMIT)
96 }
97
98 pub fn with_dict_and_limit(reader: R, dict: Dictionary, max_output: usize) -> Self {
100 Self {
101 inner: reader,
102 state: State::FrameHeader,
103 read_buf: Vec::new(),
104 output_buf: Vec::new(),
105 output_pos: 0,
106 ws: Box::new(BlockDecodeWorkspace::new()),
107 seq_tables: SequenceDecodeTables::new_default(),
108 rep_offsets: [1, 4, 8],
109 hasher: None,
110 content_checksum: false,
111 max_output,
112 bytes_output: 0,
113 dict: Some(dict),
114 }
115 }
116
117 pub fn into_inner(self) -> R {
119 self.inner
120 }
121
122 pub fn reset(&mut self, new_reader: R) -> R {
125 let old = core::mem::replace(&mut self.inner, new_reader);
126 self.state = State::FrameHeader;
127 self.output_buf.clear();
128 self.output_pos = 0;
129 self.rep_offsets = [1, 4, 8];
130 self.seq_tables = SequenceDecodeTables::new_default();
131 self.ws.huf_valid = false;
132 self.hasher = None;
133 self.content_checksum = false;
134 self.bytes_output = 0;
135 old
136 }
137
138 fn fill_output(&mut self) -> io::Result<()> {
139 loop {
140 match self.state {
141 State::Done => return Ok(()),
142 State::FrameHeader => self.read_frame_header()?,
143 State::BlockHeader => self.read_block_header()?,
144 State::BlockData {
145 block_type,
146 block_size,
147 last,
148 } => {
149 self.read_block_data(block_type, block_size, last)?;
150 if self.output_pos < self.output_buf.len() {
151 return Ok(());
152 }
153 }
154 State::Checksum => self.read_checksum()?,
155 }
156 }
157 }
158
159 fn read_frame_header(&mut self) -> io::Result<()> {
160 self.read_buf.resize(18, 0);
161 self.inner.read_exact(&mut self.read_buf[..5])?;
162
163 let magic = u32::from_le_bytes([
164 self.read_buf[0],
165 self.read_buf[1],
166 self.read_buf[2],
167 self.read_buf[3],
168 ]);
169
170 if (magic & 0xFFFF_FFF0) == 0x184D_2A50 {
171 self.inner.read_exact(&mut self.read_buf[5..9])?;
172 let skip_size = u32::from_le_bytes([
173 self.read_buf[5],
174 self.read_buf[6],
175 self.read_buf[7],
176 self.read_buf[8],
177 ]) as usize;
178 io::copy(
179 &mut self.inner.by_ref().take(skip_size as u64),
180 &mut io::sink(),
181 )?;
182 return Ok(());
183 }
184
185 let descriptor = self.read_buf[4];
186 let single_segment = (descriptor & 0x20) != 0;
187 let dict_id_flag = descriptor & 0x03;
188 let fcs_flag = (descriptor >> 6) & 0x03;
189
190 let mut hdr_len = 5usize;
191 if !single_segment {
192 hdr_len += 1;
193 }
194 hdr_len += match dict_id_flag {
195 0 => 0,
196 1 => 1,
197 2 => 2,
198 3 => 4,
199 _ => unreachable!(),
200 };
201 hdr_len += match fcs_flag {
202 0 if single_segment => 1,
203 0 => 0,
204 1 => 2,
205 2 => 4,
206 3 => 8,
207 _ => unreachable!(),
208 };
209
210 if hdr_len > 5 {
211 self.inner.read_exact(&mut self.read_buf[5..hdr_len])?;
212 }
213
214 let header = parse_frame_header(&self.read_buf[..hdr_len])
215 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
216
217 if let Some(frame_dict_id) = header.dict_id {
218 match &self.dict {
219 Some(d) if d.id() == frame_dict_id => {}
220 Some(d) => {
221 return Err(io::Error::new(
222 io::ErrorKind::InvalidData,
223 DecompressError::DictMismatch {
224 expected: frame_dict_id,
225 got: d.id(),
226 },
227 ));
228 }
229 None => {
230 return Err(io::Error::new(
231 io::ErrorKind::InvalidData,
232 DecompressError::DictRequired,
233 ));
234 }
235 }
236 }
237
238 if let Some(fcs) = header.frame_content_size {
239 if fcs as usize > self.max_output {
240 return Err(io::Error::new(
241 io::ErrorKind::InvalidData,
242 DecompressError::OutputTooSmall,
243 ));
244 }
245 }
246
247 self.content_checksum = header.content_checksum;
248 self.hasher = if header.content_checksum {
249 Some(Xxh64State::new(0))
250 } else {
251 None
252 };
253
254 if let Some(ref d) = self.dict {
255 self.rep_offsets = *d.rep_offsets();
256 let mut st = SequenceDecodeTables::new_default();
257 if let Some((t, l)) = d.of_table() {
258 st.of_table = promote_of_table(t);
259 st.of_accuracy = l;
260 }
261 if let Some((t, l)) = d.ml_table() {
262 st.ml_table = promote_ml_table(t);
263 st.ml_accuracy = l;
264 }
265 if let Some((t, l)) = d.ll_table() {
266 st.ll_table = promote_ll_table(t);
267 st.ll_accuracy = l;
268 }
269 self.seq_tables = st;
270 self.ws.huf_valid = false;
271 if let Some((t, l)) = d.huf_table() {
272 self.ws.huf_table.clear();
273 self.ws.huf_table.extend_from_slice(t);
274 self.ws.huf_table_log = l;
275 self.ws.huf_valid = true;
276 }
277 } else {
278 self.rep_offsets = [1, 4, 8];
279 self.seq_tables = SequenceDecodeTables::new_default();
280 self.ws.huf_valid = false;
281 }
282
283 self.state = State::BlockHeader;
284 Ok(())
285 }
286
287 fn read_block_header(&mut self) -> io::Result<()> {
288 let mut hdr = [0u8; 3];
289 self.inner.read_exact(&mut hdr)?;
290 let block_header =
291 parse_block_header(&hdr).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
292
293 let block_size = block_header.block_size as usize;
294
295 match block_header.block_type {
296 BlockType::Raw | BlockType::Rle if block_size > MAX_BLOCK_SIZE => {
297 return Err(io::Error::new(
298 io::ErrorKind::InvalidData,
299 DecompressError::BlockTooLarge,
300 ));
301 }
302 _ => {}
303 }
304
305 self.state = State::BlockData {
306 block_type: block_header.block_type,
307 block_size,
308 last: block_header.last_block,
309 };
310 Ok(())
311 }
312
313 fn read_block_data(
314 &mut self,
315 block_type: BlockType,
316 block_size: usize,
317 last: bool,
318 ) -> io::Result<()> {
319 self.output_buf.clear();
320 self.output_pos = 0;
321
322 match block_type {
323 BlockType::Raw => {
324 self.output_buf.resize(block_size, 0);
325 self.inner.read_exact(&mut self.output_buf)?;
326 }
327 BlockType::Rle => {
328 let mut byte = [0u8; 1];
329 self.inner.read_exact(&mut byte)?;
330 self.output_buf.resize(block_size, byte[0]);
331 }
332 BlockType::Compressed => {
333 self.read_buf.resize(block_size, 0);
334 self.inner.read_exact(&mut self.read_buf[..block_size])?;
335 self.decode_compressed_block(block_size)?;
336 }
337 }
338
339 if let Some(ref mut hasher) = self.hasher {
340 hasher.update(&self.output_buf);
341 }
342 self.bytes_output += self.output_buf.len();
343 if self.bytes_output > self.max_output {
344 return Err(io::Error::new(
345 io::ErrorKind::InvalidData,
346 DecompressError::OutputTooSmall,
347 ));
348 }
349
350 self.state = if last {
351 if self.content_checksum {
352 State::Checksum
353 } else {
354 State::FrameHeader
355 }
356 } else {
357 State::BlockHeader
358 };
359
360 Ok(())
361 }
362
363 fn decode_compressed_block(&mut self, block_size: usize) -> io::Result<()> {
364 let dict_history: &[u8] = match &self.dict {
365 Some(d) => d.content(),
366 None => &[],
367 };
368 let block_data = &self.read_buf[..block_size];
369
370 let lit_consumed = decode_literals_ws(block_data, &mut self.ws)
371 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
372
373 let remaining = &block_data[lit_consumed..];
374
375 if remaining.is_empty() {
376 self.output_buf.extend_from_slice(&self.ws.literal_buf);
377 return Ok(());
378 }
379
380 let (num_sequences, seq_count_size) = parse_sequence_count(remaining)
381 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
382
383 if num_sequences == 0 {
384 self.output_buf.extend_from_slice(&self.ws.literal_buf);
385 return Ok(());
386 }
387
388 let table_data = &remaining[seq_count_size..];
389 let tables_consumed =
390 parse_sequence_tables_ws(table_data, &mut self.seq_tables, &mut self.ws)
391 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
392
393 let seq_data = &table_data[tables_consumed..];
394
395 #[cfg(all(target_arch = "x86_64", not(feature = "paranoid")))]
396 {
397 if zrip_core::simd::cpu_tier() >= CpuTier::Avx2 {
398 let before = self.output_buf.len();
399 crate::simd_decode::x86_64::decode::decode_execute_avx2_safe(
400 seq_data,
401 num_sequences,
402 &self.seq_tables,
403 &mut self.rep_offsets,
404 &self.ws.literal_buf,
405 &mut self.output_buf,
406 dict_history,
407 )
408 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
409 if self.output_buf.len() - before > MAX_BLOCK_SIZE {
410 return Err(io::Error::new(
411 io::ErrorKind::InvalidData,
412 DecompressError::BlockTooLarge,
413 ));
414 }
415 return Ok(());
416 }
417 }
418
419 #[cfg(all(target_arch = "aarch64", not(feature = "paranoid")))]
420 {
421 if zrip_core::simd::cpu_tier() >= CpuTier::Neon {
422 let before = self.output_buf.len();
423 crate::simd_decode::aarch64::decode::decode_execute_neon_safe(
424 seq_data,
425 num_sequences,
426 &self.seq_tables,
427 &mut self.rep_offsets,
428 &self.ws.literal_buf,
429 &mut self.output_buf,
430 dict_history,
431 )
432 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
433 if self.output_buf.len() - before > MAX_BLOCK_SIZE {
434 return Err(io::Error::new(
435 io::ErrorKind::InvalidData,
436 DecompressError::BlockTooLarge,
437 ));
438 }
439 return Ok(());
440 }
441 }
442
443 let before = self.output_buf.len();
444 decode_execute_sequences(
445 seq_data,
446 num_sequences,
447 &self.seq_tables,
448 &mut self.rep_offsets,
449 &self.ws.literal_buf,
450 &mut self.output_buf,
451 dict_history,
452 )
453 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
454 if self.output_buf.len() - before > MAX_BLOCK_SIZE {
455 return Err(io::Error::new(
456 io::ErrorKind::InvalidData,
457 DecompressError::BlockTooLarge,
458 ));
459 }
460 Ok(())
461 }
462
463 fn read_checksum(&mut self) -> io::Result<()> {
464 let mut buf = [0u8; 4];
465 self.inner.read_exact(&mut buf)?;
466 let stored = u32::from_le_bytes(buf);
467
468 if let Some(ref hasher) = self.hasher {
469 let hash = hasher.finish();
470 let expected = (hash & 0xFFFF_FFFF) as u32;
471 if expected != stored {
472 return Err(io::Error::new(
473 io::ErrorKind::InvalidData,
474 DecompressError::ChecksumMismatch {
475 expected: stored,
476 got: expected,
477 },
478 ));
479 }
480 }
481
482 self.state = State::FrameHeader;
483 Ok(())
484 }
485}
486
487impl<R: Read> Read for FrameDecoder<R> {
488 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
489 if self.output_pos >= self.output_buf.len() {
490 if let State::Done = &self.state {
491 return Ok(0);
492 }
493
494 self.output_buf.clear();
495 self.output_pos = 0;
496
497 match self.fill_output() {
498 Ok(()) => {}
499 Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => match &self.state {
500 State::FrameHeader => {
501 self.state = State::Done;
502 return Ok(0);
503 }
504 _ => return Err(e),
505 },
506 Err(e) => return Err(e),
507 }
508 }
509
510 let available = &self.output_buf[self.output_pos..];
511 let n = buf.len().min(available.len());
512 buf[..n].copy_from_slice(&available[..n]);
513 self.output_pos += n;
514 Ok(n)
515 }
516}