Skip to main content

rust_zstd/
decode.rs

1#![allow(clippy::manual_is_multiple_of, clippy::identity_op)]
2//! Self-contained Zstandard decompressor.
3//!
4//! Ported from ruzstd 0.8.2 by Moritz Borcherding, used under the MIT license.
5//!
6//! ```text
7//! MIT License
8//!
9//! Copyright (c) ruzstd contributors
10//!
11//! Permission is hereby granted, free of charge, to any person obtaining a copy
12//! of this software and associated documentation files (the "Software"), to deal
13//! in the Software without restriction, including without limitation the rights
14//! to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
15//! copies of the Software, and to permit persons to whom the Software is
16//! furnished to do so, subject to the following conditions:
17//!
18//! The above copyright notice and this permission notice shall be included in all
19//! copies or substantial portions of the Software.
20//!
21//! THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
22//! IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
23//! FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
24//! AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
25//! LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
26//! OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
27//! SOFTWARE.
28//! ```
29//!
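//! Primary public API: `decompress(data: &[u8]) -> Result<Vec<u8>, String>`.
//! The lower-level helpers `decode_huf_weights_from_fse` and `parse_fse_header`
//! are also exported.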
31//!
32//! Supports raw blocks, RLE blocks, and compressed blocks with Huffman
33//! literals and FSE sequences. No dictionary support.
34
35#![allow(
36    clippy::needless_range_loop,
37    clippy::len_without_is_empty,
38    clippy::upper_case_acronyms,
39    clippy::manual_range_contains,
40    dead_code
41)]
42
43// ============================================================
44// Constants
45// ============================================================
46
/// Magic number that opens every standard zstd frame.
const ZSTD_MAGIC: u32 = 0xFD2F_B528;
/// Smallest window size (1 KiB).
const MIN_WINDOW_SIZE: u64 = 1 << 10;
/// Largest window size the zstd frame header format can express (3.75 TiB).
const MAX_WINDOW_SIZE: u64 = (1 << 41) + (7 << 38);
/// Upper bound on the size of a single block (128 KiB).
const MAX_BLOCK_SIZE: u32 = 1 << 17;
/// Decoder-imposed window-size limit (100 MiB); enforcement lives outside this section.
const MAXIMUM_ALLOWED_WINDOW_SIZE: u64 = 100 << 20;
/// Largest Huffman weight / code length accepted when building Huffman tables.
const MAX_MAX_NUM_BITS: u8 = 11;
/// Added to the 4-bit accuracy-log field of an FSE table header.
const ACC_LOG_OFFSET: u8 = 5;

/// Highest literal-length code.
const MAX_LITERAL_LENGTH_CODE: u8 = 35;
/// Highest match-length code.
const MAX_MATCH_LENGTH_CODE: u8 = 52;
/// Highest offset code.
const MAX_OFFSET_CODE: u8 = 31;

/// Maximum accuracy log for literal-length FSE tables.
const LL_MAX_LOG: u8 = 9;
/// Maximum accuracy log for match-length FSE tables.
const ML_MAX_LOG: u8 = 9;
/// Maximum accuracy log for offset FSE tables.
const OF_MAX_LOG: u8 = 8;

/// Accuracy log of the predefined literal-length distribution.
const LL_DEFAULT_ACC_LOG: u8 = 6;
/// Accuracy log of the predefined match-length distribution.
const ML_DEFAULT_ACC_LOG: u8 = 6;
/// Accuracy log of the predefined offset distribution.
const OF_DEFAULT_ACC_LOG: u8 = 5;

/// Predefined literal-length distribution (codes 0..=35); `-1` = "less than one".
const LITERALS_LENGTH_DEFAULT_DISTRIBUTION: [i32; 36] = [
    4, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, // codes 0..=15
    2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 2, 1, 1, 1, 1, 1, // codes 16..=31
    -1, -1, -1, -1, // codes 32..=35
];

/// Predefined match-length distribution (codes 0..=52); `-1` = "less than one".
const MATCH_LENGTH_DEFAULT_DISTRIBUTION: [i32; 53] = [
    1, 4, 3, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, // codes 0..=15
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // codes 16..=31
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // codes 32..=45
    -1, -1, -1, -1, -1, -1, -1, // codes 46..=52
];

/// Predefined offset-code distribution (codes 0..=28); `-1` = "less than one".
const OFFSET_DEFAULT_DISTRIBUTION: [i32; 29] = [
    1, 1, 1, 1, 1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, // codes 0..=15
    1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, // codes 16..=28
];
80
81// ============================================================
82// Public API
83// ============================================================
84
85/// Decode Huffman weights from FSE-compressed data (for encoder verification).
86pub fn decode_huf_weights_from_fse(source: &[u8], header: u8) -> Result<Vec<u8>, String> {
87    let mut ht = HuffmanTable::new();
88    let mut full = vec![header];
89    full.extend_from_slice(source);
90    let _ = ht.read_weights(&full)?;
91    Ok(ht.weights.clone())
92}
93
94/// Decode a Huffman tree description and reconstruct canonical codes.
95/// Returns codes array matching compress.rs format: [(code, nbits); 256].
96pub fn parse_fse_header(source: &[u8], max_log: u8) -> Result<(u8, Vec<i32>, usize), String> {
97    let mut table = FSETable::new(255);
98    let bytes = table.read_probabilities(source, max_log)?;
99    Ok((
100        table.accuracy_log,
101        table.symbol_probabilities.clone(),
102        bytes,
103    ))
104}
105
106/// Decompress a zstd-compressed byte slice, returning the uncompressed data.
107///
108/// Supports one or more concatenated zstd frames. Skippable frames are skipped.
109/// Dictionary frames are not supported.
110pub fn decompress(data: &[u8]) -> Result<Vec<u8>, String> {
111    let mut cursor = std::io::Cursor::new(data);
112    let mut output = Vec::new();
113    let mut decoder = FrameDecoder::new();
114
115    loop {
116        // Check if we have consumed all the data
117        if cursor.position() as usize >= data.len() {
118            break;
119        }
120
121        match decoder.reset(&mut cursor) {
122            Ok(()) => {}
123            Err(e) => {
124                if let Some(skip_len) = e.skip_frame_length() {
125                    let new_pos = cursor.position() + skip_len as u64;
126                    if new_pos as usize > data.len() {
127                        return Err("Skippable frame extends past end of input".to_string());
128                    }
129                    cursor.set_position(new_pos);
130                    continue;
131                }
132                // If we already have output and hit an error, it might just be trailing data
133                if !output.is_empty() {
134                    break;
135                }
136                return Err(format!("Frame header error: {}", e));
137            }
138        }
139
140        // Decode all blocks in this frame
141        decoder.decode_all_blocks(&mut cursor)?;
142
143        // Collect the output
144        if let Some(mut collected) = decoder.collect() {
145            output.append(&mut collected);
146        }
147    }
148
149    Ok(output)
150}
151
152// ============================================================
153// BitReader (forward)
154// ============================================================
155
/// Forward bit reader over a byte slice.
///
/// Bits are consumed starting at the least significant bit of `source[0]`. The
/// first bit read lands in the least significant bit of the returned value. FSE
/// table headers are read this way.
struct BitReader<'s> {
    /// Absolute position (in bits) of the next bit to read.
    idx: usize,
    source: &'s [u8],
}

impl<'s> BitReader<'s> {
    fn new(source: &'s [u8]) -> BitReader<'s> {
        BitReader { idx: 0, source }
    }

    /// Number of bits not yet consumed.
    fn bits_left(&self) -> usize {
        self.source.len() * 8 - self.idx
    }

    /// Number of bits consumed so far.
    fn bits_read(&self) -> usize {
        self.idx
    }

    /// Un-read the last `n` bits.
    ///
    /// # Panics
    /// Panics if `n` exceeds the number of bits read so far.
    fn return_bits(&mut self, n: usize) {
        if n > self.idx {
            panic!("Cannot return more bits than have been read");
        }
        self.idx -= n;
    }

    /// Read the next `n` bits (`0..=64`) as an integer.
    ///
    /// # Errors
    /// Fails if `n > 64` or fewer than `n` bits remain.
    fn get_bits(&mut self, n: usize) -> Result<u64, String> {
        if n > 64 {
            return Err(format!("Cannot read {} bits, maximum is 64", n));
        }
        if self.bits_left() < n {
            return Err(format!(
                "Cannot read {} bits, only {} remaining",
                n,
                self.bits_left()
            ));
        }
        // A zero-width read is a no-op. Without this early return, a 0-bit read
        // on an exhausted reader would index `source[source.len()]` and panic.
        if n == 0 {
            return Ok(0);
        }

        let old_idx = self.idx;
        let bits_left_in_current_byte = 8 - (self.idx % 8);
        let bits_not_needed_in_current_byte = 8 - bits_left_in_current_byte;

        // Unread high part of the current byte, moved down to bit 0.
        let mut value = u64::from(self.source[self.idx / 8] >> bits_not_needed_in_current_byte);

        if bits_left_in_current_byte >= n {
            value &= (1 << n) - 1;
            self.idx += n;
        } else {
            self.idx += bits_left_in_current_byte;
            let full_bytes_needed = (n - bits_left_in_current_byte) / 8;
            let bits_in_last_byte_needed = n - bits_left_in_current_byte - full_bytes_needed * 8;

            let mut bit_shift = bits_left_in_current_byte;

            for _ in 0..full_bytes_needed {
                value |= u64::from(self.source[self.idx / 8]) << bit_shift;
                self.idx += 8;
                bit_shift += 8;
            }

            if bits_in_last_byte_needed > 0 {
                let val_last_byte =
                    u64::from(self.source[self.idx / 8]) & ((1 << bits_in_last_byte_needed) - 1);
                value |= val_last_byte << bit_shift;
                self.idx += bits_in_last_byte_needed;
            }
        }

        debug_assert!(self.idx == old_idx + n);
        Ok(value)
    }
}
227
228// ============================================================
229// BitReaderReversed
230// ============================================================
231
/// Bit reader that consumes a byte slice backwards: reading starts at the end
/// of `source` and proceeds toward its start. FSE- and Huffman-coded streams
/// are decoded through it.
///
/// Up to 8 bytes are cached little-endian in `bit_container`, and bits are taken
/// from its most significant end. Once `source` is exhausted, zero bits are
/// shifted in and tallied in `extra_bits`, so `bits_remaining` goes negative on
/// over-read.
struct BitReaderReversed<'s> {
    // Byte offset in `source` at which `bit_container` was last loaded; only decreases.
    index: usize,
    // How many of the container's high bits are already consumed (64 = none available).
    bits_consumed: u8,
    // Count of zero bits shifted into the container after `source` ran out.
    extra_bits: usize,
    source: &'s [u8],
    // Cached bits; the next bit to read is bit `63 - bits_consumed`.
    bit_container: u64,
}

impl<'s> BitReaderReversed<'s> {
    /// Bits still available from `source`. Becomes negative once reads have gone
    /// past the start of `source` (those reads returned zero bits).
    fn bits_remaining(&self) -> isize {
        self.index as isize * 8 + (64 - self.bits_consumed as isize) - self.extra_bits as isize
    }

    /// Create a reader positioned at the end of `source`. The container starts
    /// empty (`bits_consumed == 64`), so the first non-zero read triggers `refill`.
    fn new(source: &'s [u8]) -> BitReaderReversed<'s> {
        BitReaderReversed {
            index: source.len(),
            bits_consumed: 64,
            source,
            bit_container: 0,
            extra_bits: 0,
        }
    }

    /// Step back over the fully consumed bytes and reload the container, so that
    /// fewer than 8 bits remain consumed afterwards.
    #[cold]
    fn refill(&mut self) {
        let bytes_consumed = self.bits_consumed as usize / 8;
        if bytes_consumed == 0 {
            return;
        }

        if self.index >= bytes_consumed {
            // Enough earlier bytes: move the window back and reload up to 8 bytes.
            self.index -= bytes_consumed;
            self.bits_consumed &= 7;
            let remaining = self.source.len() - self.index;
            if remaining >= 8 {
                self.bit_container =
                    u64::from_le_bytes(self.source[self.index..][..8].try_into().unwrap());
            } else {
                // Fewer than 8 bytes from `index` to the end: zero-pad the high bytes.
                let mut value = [0u8; 8];
                value[..remaining].copy_from_slice(&self.source[self.index..]);
                self.bit_container = u64::from_le_bytes(value);
            }
        } else if self.index > 0 {
            // Only `index` (< bytes_consumed) bytes precede the window: load the
            // first bytes of `source`, then align the unconsumed bits to the top.
            if self.source.len() >= 8 {
                self.bit_container = u64::from_le_bytes(self.source[..8].try_into().unwrap());
            } else {
                let mut value = [0; 8];
                value[..self.source.len()].copy_from_slice(self.source);
                self.bit_container = u64::from_le_bytes(value);
            }

            // The window top moved down by `index` bytes, so fewer bits count as consumed.
            self.bits_consumed -= 8 * self.index as u8;
            self.index = 0;

            // Shift out the consumed bits; the zeros shifted in are phantom bits.
            self.bit_container <<= self.bits_consumed;
            self.extra_bits += self.bits_consumed as usize;
            self.bits_consumed = 0;
        } else if self.bits_consumed < 64 {
            // Source exhausted but some real bits remain: align them to the top.
            self.bit_container <<= self.bits_consumed;
            self.extra_bits += self.bits_consumed as usize;
            self.bits_consumed = 0;
        } else {
            // Source exhausted and container empty: every further bit is a zero.
            self.extra_bits += self.bits_consumed as usize;
            self.bits_consumed = 0;
            self.bit_container = 0;
        }

        debug_assert!(self.bits_consumed < 8);
    }

    /// Read `n` bits. The first bit read becomes the most significant bit of the
    /// result. Refills only when the request does not fit in the container.
    #[inline(always)]
    fn get_bits(&mut self, n: u8) -> u64 {
        if self.bits_consumed + n > 64 {
            self.refill();
        }
        let value = self.peek_bits(n);
        self.consume(n);
        value
    }

    /// Return the next `n` bits without consuming them.
    /// NOTE(review): assumes the caller guarantees `bits_consumed + n <= 64` and `n < 64`.
    #[inline(always)]
    fn peek_bits(&mut self, n: u8) -> u64 {
        // Early return also avoids a 64-bit shift when `bits_consumed` is 0.
        if n == 0 {
            return 0;
        }
        let mask = (1u64 << n) - 1u64;
        let shift_by = 64 - self.bits_consumed - n;
        (self.bit_container >> shift_by) & mask
    }

    /// Peek `sum = n1 + n2 + n3` bits at once and split them into three fields,
    /// the first field taken from the most significant end.
    #[inline(always)]
    fn peek_bits_triple(&mut self, sum: u8, n1: u8, n2: u8, n3: u8) -> (u64, u64, u64) {
        if sum == 0 {
            return (0, 0, 0);
        }
        let all_three = self.bit_container >> (64 - self.bits_consumed - sum);

        let mask1 = (1u64 << n1) - 1u64;
        let val1 = (all_three >> (n3 + n2)) & mask1;

        let mask2 = (1u64 << n2) - 1u64;
        let val2 = (all_three >> n3) & mask2;

        let mask3 = (1u64 << n3) - 1u64;
        let val3 = all_three & mask3;

        (val1, val2, val3)
    }

    /// Mark `n` bits as consumed; the caller must already have peeked them.
    #[inline(always)]
    fn consume(&mut self, n: u8) {
        self.bits_consumed += n;
        debug_assert!(self.bits_consumed <= 64);
    }

    /// Read three consecutive fields of `n1`, `n2` and `n3` bits.
    #[inline(always)]
    fn get_bits_triple(&mut self, n1: u8, n2: u8, n3: u8) -> (u64, u64, u64) {
        let sum = n1 + n2 + n3;
        // Fast path: after `refill` fewer than 8 bits are consumed, so at least
        // 57 bits are available and all three fields fit.
        if sum <= 56 {
            self.refill();
            let triple = self.peek_bits_triple(sum, n1, n2, n3);
            self.consume(sum);
            return triple;
        }
        (self.get_bits(n1), self.get_bits(n2), self.get_bits(n3))
    }
}
359
360// ============================================================
361// FSE Table and Decoder
362// ============================================================
363
/// One slot of an FSE decoding table.
#[derive(Clone, Copy, Debug)]
struct FSEEntry {
    /// Added to the bits read on leaving this state to form the next state index.
    base_line: u32,
    /// Number of bits read when leaving this state.
    num_bits: u8,
    /// Symbol emitted while in this state.
    symbol: u8,
}
370
/// Decoding table for one FSE stream, built from a normalized probability
/// distribution (RFC 8878 §4.1.1).
#[derive(Debug, Clone)]
struct FSETable {
    // Largest symbol value accepted; `symbol_probabilities` may hold at most
    // `max_symbol + 1` entries.
    max_symbol: u8,
    // Decoding table with `1 << accuracy_log` entries, indexed by state.
    decode: Vec<FSEEntry>,
    // log2 of the table size.
    accuracy_log: u8,
    // Normalized count per symbol; `-1` marks a "less than one" probability
    // that takes exactly one slot at the top of the table.
    symbol_probabilities: Vec<i32>,
    // Scratch: slots already assigned per symbol while computing baselines.
    symbol_counter: Vec<u32>,
}

impl FSETable {
    /// Create an empty table that accepts symbols `0..=max_symbol`.
    fn new(max_symbol: u8) -> FSETable {
        FSETable {
            max_symbol,
            symbol_probabilities: Vec::with_capacity(256),
            symbol_counter: Vec::with_capacity(256),
            decode: Vec::new(),
            accuracy_log: 0,
        }
    }

    /// Make this table a copy of `other`, reusing existing allocations.
    /// `max_symbol` is not copied: `self` keeps its own limit.
    fn reinit_from(&mut self, other: &Self) {
        self.reset();
        self.symbol_counter.extend_from_slice(&other.symbol_counter);
        self.symbol_probabilities
            .extend_from_slice(&other.symbol_probabilities);
        self.decode.extend_from_slice(&other.decode);
        self.accuracy_log = other.accuracy_log;
    }

    /// Clear all contents except `max_symbol`, keeping vector capacity.
    fn reset(&mut self) {
        self.symbol_counter.clear();
        self.symbol_probabilities.clear();
        self.decode.clear();
        self.accuracy_log = 0;
    }

    /// Parse a serialized distribution from `source` (accuracy log at most
    /// `max_log`) and build the decoding table from it.
    /// Returns the number of header bytes consumed.
    fn build_decoder(&mut self, source: &[u8], max_log: u8) -> Result<usize, String> {
        self.accuracy_log = 0;
        let bytes_read = self.read_probabilities(source, max_log)?;
        self.build_decoding_table()?;
        Ok(bytes_read)
    }

    /// Build the decoding table from an explicit distribution, such as one of
    /// the predefined default distributions.
    /// NOTE(review): does not check that `probs` sums to `1 << acc_log`;
    /// callers presumably pass only well-formed distributions.
    fn build_from_probabilities(&mut self, acc_log: u8, probs: &[i32]) -> Result<(), String> {
        if acc_log == 0 {
            return Err("Accuracy log is zero".to_string());
        }
        self.symbol_probabilities = probs.to_vec();
        self.accuracy_log = acc_log;
        self.build_decoding_table()
    }

    /// Spread the symbols over the table and compute each slot's baseline and
    /// bit count, from `symbol_probabilities` and `accuracy_log`.
    fn build_decoding_table(&mut self) -> Result<(), String> {
        if self.symbol_probabilities.len() > self.max_symbol as usize + 1 {
            return Err(format!(
                "Too many symbols: {}, max: {}",
                self.symbol_probabilities.len(),
                self.max_symbol + 1
            ));
        }

        self.decode.clear();

        let table_size = 1 << self.accuracy_log;
        self.decode.resize(
            table_size,
            FSEEntry {
                base_line: 0,
                num_bits: 0,
                symbol: 0,
            },
        );

        let mut negative_idx = table_size;

        // "Less than one" symbols get one slot each, filled downward from the end of
        // the table. Such a state reads a full `accuracy_log` bits from baseline 0.
        for symbol in 0..self.symbol_probabilities.len() {
            if self.symbol_probabilities[symbol] == -1 {
                negative_idx -= 1;
                let entry = &mut self.decode[negative_idx];
                entry.symbol = symbol as u8;
                entry.base_line = 0;
                entry.num_bits = self.accuracy_log;
            }
        }

        // Scatter every positive-probability symbol `prob` times with the fixed
        // step, skipping the slots reserved above.
        let mut position = 0;
        for idx in 0..self.symbol_probabilities.len() {
            let symbol = idx as u8;
            if self.symbol_probabilities[idx] <= 0 {
                continue;
            }
            let prob = self.symbol_probabilities[idx];
            for _ in 0..prob {
                let entry = &mut self.decode[position];
                entry.symbol = symbol;
                position = fse_next_position(position, table_size);
                while position >= negative_idx {
                    position = fse_next_position(position, table_size);
                }
            }
        }

        // Walk the ordinary slots in order. The n-th occurrence of a symbol gets the
        // baseline/bit count for state number n of that symbol.
        self.symbol_counter.clear();
        self.symbol_counter
            .resize(self.symbol_probabilities.len(), 0);
        for idx in 0..negative_idx {
            let symbol = self.decode[idx].symbol;
            let prob = self.symbol_probabilities[symbol as usize];
            let symbol_count = self.symbol_counter[symbol as usize];
            let (bl, nb) =
                fse_calc_baseline_and_numbits(table_size as u32, prob as u32, symbol_count);

            // Holds by construction: fse_calc_baseline_and_numbits returns
            // `accuracy_log - floor(log2(state_desc))`.
            assert!(nb <= self.accuracy_log);
            self.symbol_counter[symbol as usize] += 1;

            self.decode[idx].base_line = bl;
            self.decode[idx].num_bits = nb;
        }
        Ok(())
    }

    /// Decode a serialized normalized distribution from the start of `source`
    /// into `accuracy_log` and `symbol_probabilities`.
    /// Returns the number of bytes consumed, rounded up to whole bytes.
    fn read_probabilities(&mut self, source: &[u8], max_log: u8) -> Result<usize, String> {
        self.symbol_probabilities.clear();

        let mut br = BitReader::new(source);
        // The low 4 bits store the accuracy log minus ACC_LOG_OFFSET.
        self.accuracy_log = ACC_LOG_OFFSET + (br.get_bits(4)? as u8);
        if self.accuracy_log > max_log {
            return Err(format!(
                "Accuracy log {} exceeds max {}",
                self.accuracy_log, max_log
            ));
        }
        // Defensive: with a non-zero ACC_LOG_OFFSET this cannot trigger.
        if self.accuracy_log == 0 {
            return Err("Accuracy log is zero".to_string());
        }

        let probability_sum = 1u32 << self.accuracy_log;
        let mut probability_counter = 0u32;

        while probability_counter < probability_sum {
            // Largest encodable value (probability + 1) given what is left to distribute.
            let max_remaining_value = probability_sum - probability_counter + 1;
            let bits_to_read = highest_bit_set(max_remaining_value);

            let unchecked_value = br.get_bits(bits_to_read as usize)? as u32;

            // Values below `low_threshold` are stored with one bit fewer; the
            // rest use the full width with the upper range shifted down.
            let low_threshold = ((1 << bits_to_read) - 1) - max_remaining_value;
            let mask = (1 << (bits_to_read - 1)) - 1;
            let small_value = unchecked_value & mask;

            let value = if small_value < low_threshold {
                // Short encoding: give back the extra bit that was read.
                br.return_bits(1);
                small_value
            } else if unchecked_value > mask {
                unchecked_value - low_threshold
            } else {
                unchecked_value
            };

            let prob = (value as i32) - 1;
            self.symbol_probabilities.push(prob);

            if prob != 0 {
                if prob > 0 {
                    probability_counter += prob as u32;
                } else {
                    // probability -1 counts as 1
                    probability_counter += 1;
                }
            } else {
                // A zero probability is followed by 2-bit repeat flags, each adding
                // that many further zero-probability symbols; a flag of 3 means
                // another flag follows.
                loop {
                    let skip_amount = br.get_bits(2)? as usize;
                    self.symbol_probabilities
                        .resize(self.symbol_probabilities.len() + skip_amount, 0);
                    if skip_amount != 3 {
                        break;
                    }
                }
            }
        }

        if probability_counter != probability_sum {
            return Err(format!(
                "Probability counter {} does not match expected sum {}",
                probability_counter, probability_sum
            ));
        }
        if self.symbol_probabilities.len() > self.max_symbol as usize + 1 {
            return Err(format!(
                "Too many symbols: {}",
                self.symbol_probabilities.len()
            ));
        }

        // Round the consumed bit count up to whole bytes.
        let bytes_read = if br.bits_read() % 8 == 0 {
            br.bits_read() / 8
        } else {
            (br.bits_read() / 8) + 1
        };

        Ok(bytes_read)
    }
}
573
/// Advance to the next slot when spreading symbols over an FSE table.
///
/// Uses the fixed step `table_size/2 + table_size/8 + 3`, wrapping by masking.
/// `table_size` is expected to be a power of two.
fn fse_next_position(p: usize, table_size: usize) -> usize {
    let step = (table_size >> 1) + (table_size >> 3) + 3;
    let wrap_mask = table_size - 1;
    (p + step) & wrap_mask
}
579
/// Compute the baseline and bit count for one FSE decoding-table slot.
///
/// Mirrors the reference `FSE_buildDTable` construction:
///   state_desc = num_states_symbol + state_number
///   num_bits   = accuracy_log - floor(log2(state_desc))
///   baseline   = (state_desc << num_bits) - table_size
///
/// Parameters:
/// - `num_states_total`: the table size (`1 << accuracy_log`);
/// - `num_states_symbol`: the symbol's normalized count;
/// - `state_number`: how many of that symbol's slots were already assigned.
///
/// A zero count yields `(0, 0)`.
pub(crate) fn fse_calc_baseline_and_numbits(
    num_states_total: u32,
    num_states_symbol: u32,
    state_number: u32,
) -> (u32, u8) {
    match num_states_symbol {
        0 => (0, 0),
        count => {
            let table_log = highest_bit_set(num_states_total) - 1;
            let desc = count + state_number;
            let bits = table_log - (highest_bit_set(desc) - 1);
            ((desc << bits) - num_states_total, bits as u8)
        }
    }
}

/// 1-based position of the most significant set bit of `x`, i.e.
/// `floor(log2(x)) + 1`. For example: 1 → 1, 8 → 4, 255 → 8.
///
/// # Panics
/// Panics if `x` is zero.
pub(crate) fn highest_bit_set(x: u32) -> u32 {
    assert!(x > 0);
    32 - x.leading_zeros()
}
605
606struct FSEDecoder<'table> {
607    state: FSEEntry,
608    table: &'table FSETable,
609}
610
611impl<'t> FSEDecoder<'t> {
612    fn new(table: &'t FSETable) -> FSEDecoder<'t> {
613        FSEDecoder {
614            state: table.decode.first().copied().unwrap_or(FSEEntry {
615                base_line: 0,
616                num_bits: 0,
617                symbol: 0,
618            }),
619            table,
620        }
621    }
622
623    fn decode_symbol(&self) -> u8 {
624        self.state.symbol
625    }
626
627    fn init_state(&mut self, bits: &mut BitReaderReversed<'_>) -> Result<(), String> {
628        if self.table.accuracy_log == 0 {
629            return Err("FSE table is uninitialized".to_string());
630        }
631        let new_state = bits.get_bits(self.table.accuracy_log);
632        self.state = self.table.decode[new_state as usize];
633        Ok(())
634    }
635
636    fn update_state(&mut self, bits: &mut BitReaderReversed<'_>) {
637        let num_bits = self.state.num_bits;
638        let add = bits.get_bits(num_bits);
639        let base_line = self.state.base_line;
640        let new_state = base_line + add as u32;
641        self.state = self.table.decode[new_state as usize];
642    }
643}
644
645// ============================================================
646// Huffman Table and Decoder
647// ============================================================
648
/// One slot of a Huffman lookup table.
#[derive(Clone, Copy, Debug)]
struct HuffmanEntry {
    /// Symbol decoded when the lookup bits index this slot.
    symbol: u8,
    /// Length in bits of that symbol's code.
    num_bits: u8,
}
654
655struct HuffmanTable {
656    decode: Vec<HuffmanEntry>,
657    weights: Vec<u8>,
658    max_num_bits: u8,
659    bits: Vec<u8>,
660    bit_ranks: Vec<u32>,
661    rank_indexes: Vec<usize>,
662    fse_table: FSETable,
663}
664
665impl HuffmanTable {
666    fn new() -> HuffmanTable {
667        HuffmanTable {
668            decode: Vec::new(),
669            weights: Vec::with_capacity(256),
670            max_num_bits: 0,
671            bits: Vec::with_capacity(256),
672            bit_ranks: Vec::with_capacity(11),
673            rank_indexes: Vec::with_capacity(11),
674            fse_table: FSETable::new(255),
675        }
676    }
677
678    fn reinit_from(&mut self, other: &Self) {
679        self.reset();
680        self.decode.extend_from_slice(&other.decode);
681        self.weights.extend_from_slice(&other.weights);
682        self.max_num_bits = other.max_num_bits;
683        self.bits.extend_from_slice(&other.bits);
684        self.rank_indexes.extend_from_slice(&other.rank_indexes);
685        self.fse_table.reinit_from(&other.fse_table);
686    }
687
688    fn reset(&mut self) {
689        self.decode.clear();
690        self.weights.clear();
691        self.max_num_bits = 0;
692        self.bits.clear();
693        self.bit_ranks.clear();
694        self.rank_indexes.clear();
695        self.fse_table.reset();
696    }
697
698    fn build_decoder(&mut self, source: &[u8]) -> Result<u32, String> {
699        self.decode.clear();
700        let bytes_used = self.read_weights(source)?;
701        self.build_table_from_weights()?;
702        Ok(bytes_used)
703    }
704
705    fn read_weights(&mut self, source: &[u8]) -> Result<u32, String> {
706        if source.is_empty() {
707            return Err("Huffman source is empty".to_string());
708        }
709        let header = source[0];
710        let mut bits_read = 8;
711
712        match header {
713            0..=127 => {
714                let fse_stream = &source[1..];
715                if (header as usize) > fse_stream.len() {
716                    return Err(format!(
717                        "Not enough bytes for weights: have {}, need {}",
718                        fse_stream.len(),
719                        header
720                    ));
721                }
722                let bytes_used_by_fse_header = self.fse_table.build_decoder(fse_stream, 6)?;
723
724                if bytes_used_by_fse_header > header as usize {
725                    return Err(format!(
726                        "FSE table used {} bytes but only {} available",
727                        bytes_used_by_fse_header, header
728                    ));
729                }
730
731                let mut dec1 = FSEDecoder::new(&self.fse_table);
732                let mut dec2 = FSEDecoder::new(&self.fse_table);
733
734                let compressed_start = bytes_used_by_fse_header;
735                let compressed_length = header as usize - bytes_used_by_fse_header;
736
737                let compressed_weights = &fse_stream[compressed_start..];
738                if compressed_weights.len() < compressed_length {
739                    return Err(format!(
740                        "Not enough bytes to decompress weights: have {}, need {}",
741                        compressed_weights.len(),
742                        compressed_length
743                    ));
744                }
745                let compressed_weights = &compressed_weights[..compressed_length];
746                let mut br = BitReaderReversed::new(compressed_weights);
747
748                bits_read += (bytes_used_by_fse_header + compressed_length) * 8;
749
750                let mut skipped_bits = 0;
751                loop {
752                    let val = br.get_bits(1);
753                    skipped_bits += 1;
754                    if val == 1 || skipped_bits > 8 {
755                        break;
756                    }
757                }
758                if skipped_bits > 8 {
759                    return Err(format!("Extra padding: {} bits skipped", skipped_bits));
760                }
761
762                dec1.init_state(&mut br)?;
763                dec2.init_state(&mut br)?;
764
765                self.weights.clear();
766
767                loop {
768                    let w = dec1.decode_symbol();
769                    self.weights.push(w);
770                    dec1.update_state(&mut br);
771
772                    if br.bits_remaining() <= -1 {
773                        self.weights.push(dec2.decode_symbol());
774                        break;
775                    }
776
777                    let w = dec2.decode_symbol();
778                    self.weights.push(w);
779                    dec2.update_state(&mut br);
780
781                    if br.bits_remaining() <= -1 {
782                        self.weights.push(dec1.decode_symbol());
783                        break;
784                    }
785                    if self.weights.len() > 255 {
786                        return Err(format!("Too many weights: {}", self.weights.len()));
787                    }
788                }
789            }
790            _ => {
791                let weights_raw = &source[1..];
792                let num_weights = header - 127;
793                self.weights.resize(num_weights as usize, 0);
794
795                let bytes_needed = if num_weights % 2 == 0 {
796                    num_weights as usize / 2
797                } else {
798                    (num_weights as usize / 2) + 1
799                };
800
801                if weights_raw.len() < bytes_needed {
802                    return Err(format!(
803                        "Not enough bytes in source: have {}, need {}",
804                        weights_raw.len(),
805                        bytes_needed
806                    ));
807                }
808
809                for idx in 0..num_weights {
810                    if idx % 2 == 0 {
811                        self.weights[idx as usize] = weights_raw[idx as usize / 2] >> 4;
812                    } else {
813                        self.weights[idx as usize] = weights_raw[idx as usize / 2] & 0xF;
814                    }
815                    bits_read += 4;
816                }
817            }
818        }
819
820        let bytes_read = if bits_read % 8 == 0 {
821            bits_read / 8
822        } else {
823            (bits_read / 8) + 1
824        };
825        Ok(bytes_read as u32)
826    }
827
828    fn build_table_from_weights(&mut self) -> Result<(), String> {
829        self.bits.clear();
830        self.bits.resize(self.weights.len() + 1, 0);
831
832        let mut weight_sum: u32 = 0;
833        for w in &self.weights {
834            if *w > MAX_MAX_NUM_BITS {
835                return Err(format!("Weight {} exceeds max {}", w, MAX_MAX_NUM_BITS));
836            }
837            weight_sum += if *w > 0 { 1_u32 << (*w - 1) } else { 0 };
838        }
839
840        if weight_sum == 0 {
841            return Err("Missing weights".to_string());
842        }
843
844        let max_bits = highest_bit_set(weight_sum) as u8;
845        let left_over = (1u32 << max_bits) - weight_sum;
846
847        if !left_over.is_power_of_two() {
848            return Err(format!("Leftover {} is not a power of 2", left_over));
849        }
850
851        let last_weight = highest_bit_set(left_over) as u8;
852
853        for symbol in 0..self.weights.len() {
854            let bits = if self.weights[symbol] > 0 {
855                max_bits + 1 - self.weights[symbol]
856            } else {
857                0
858            };
859            self.bits[symbol] = bits;
860        }
861
862        self.bits[self.weights.len()] = max_bits + 1 - last_weight;
863        self.max_num_bits = max_bits;
864
865        if max_bits > MAX_MAX_NUM_BITS {
866            return Err(format!("Max bits {} too high", max_bits));
867        }
868
869        self.bit_ranks.clear();
870        self.bit_ranks.resize((max_bits + 1) as usize, 0);
871        for num_bits in &self.bits {
872            self.bit_ranks[(*num_bits) as usize] += 1;
873        }
874
875        self.decode.resize(
876            1 << self.max_num_bits,
877            HuffmanEntry {
878                symbol: 0,
879                num_bits: 0,
880            },
881        );
882
883        self.rank_indexes.clear();
884        self.rank_indexes.resize((max_bits + 1) as usize, 0);
885
886        self.rank_indexes[max_bits as usize] = 0;
887        for bits in (1..self.rank_indexes.len() as u8).rev() {
888            self.rank_indexes[bits as usize - 1] = self.rank_indexes[bits as usize]
889                + self.bit_ranks[bits as usize] as usize * (1 << (max_bits - bits));
890        }
891
892        assert!(
893            self.rank_indexes[0] == self.decode.len(),
894            "rank_idx[0]: {} should be: {}",
895            self.rank_indexes[0],
896            self.decode.len()
897        );
898
899        for symbol in 0..self.bits.len() {
900            let bits_for_symbol = self.bits[symbol];
901            if bits_for_symbol != 0 {
902                let base_idx = self.rank_indexes[bits_for_symbol as usize];
903                let len = 1 << (max_bits - bits_for_symbol);
904                self.rank_indexes[bits_for_symbol as usize] += len;
905                for idx in 0..len {
906                    self.decode[base_idx + idx].symbol = symbol as u8;
907                    self.decode[base_idx + idx].num_bits = bits_for_symbol;
908                }
909            }
910        }
911
912        Ok(())
913    }
914}
915
916struct HuffmanDecoder<'table> {
917    table: &'table HuffmanTable,
918    state: u64,
919}
920
921impl<'t> HuffmanDecoder<'t> {
922    fn new(table: &'t HuffmanTable) -> HuffmanDecoder<'t> {
923        HuffmanDecoder { table, state: 0 }
924    }
925
926    fn decode_symbol(&mut self) -> u8 {
927        self.table.decode[self.state as usize].symbol
928    }
929
930    fn init_state(&mut self, br: &mut BitReaderReversed<'_>) -> u8 {
931        let num_bits = self.table.max_num_bits;
932        let new_bits = br.get_bits(num_bits);
933        self.state = new_bits;
934        num_bits
935    }
936
937    fn next_state(&mut self, br: &mut BitReaderReversed<'_>) -> u8 {
938        let num_bits = self.table.decode[self.state as usize].num_bits;
939        let new_bits = br.get_bits(num_bits);
940        self.state <<= num_bits;
941        self.state &= self.table.decode.len() as u64 - 1;
942        self.state |= new_bits;
943        num_bits
944    }
945}
946
947// ============================================================
948// Block types and headers
949// ============================================================
950
/// Block type, taken from bits 1-2 of the 3-byte block header.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum BlockType {
    /// Value 0: `block_size` content bytes; decompressed size equals block size.
    Raw,
    /// Value 1: one content byte; decompressed size equals block size.
    RLE,
    /// Value 2: `block_size` bytes of literals and sequences sections.
    Compressed,
    /// Value 3: rejected by `read_block_header`.
    Reserved,
}
958
/// Parsed 3-byte block header.
struct BlockHeader {
    /// Set when bit 0 of the header is 1, marking the frame's final block.
    last_block: bool,
    block_type: BlockType,
    /// Output size for `Raw`/`RLE` blocks; 0 for `Compressed` blocks.
    decompressed_size: u32,
    /// Bytes of block content following the header: `block_size` for
    /// `Raw`/`Compressed`, 1 for `RLE`.
    content_size: u32,
}
965
966// ============================================================
967// Literals Section
968// ============================================================
969
/// Literals section type (low two bits of the section's first header byte).
enum LiteralsSectionType {
    /// Literal bytes are stored verbatim in the section.
    Raw,
    /// A single stored byte repeated `regenerated_size` times.
    RLE,
    /// Huffman-coded, preceded by a new Huffman table description.
    Compressed,
    /// Huffman-coded, reusing the Huffman table already held in scratch.
    Treeless,
}
976
/// Parsed literals section header.
struct LiteralsSection {
    /// Number of literal bytes the section decodes to.
    regenerated_size: u32,
    /// Payload size for `Compressed`/`Treeless` sections; `None` for `Raw`/`RLE`.
    compressed_size: Option<u32>,
    /// Number of Huffman streams (1 or 4) for `Compressed`/`Treeless` sections.
    num_streams: Option<u8>,
    ls_type: LiteralsSectionType,
}
983
984impl LiteralsSection {
985    fn new() -> LiteralsSection {
986        LiteralsSection {
987            regenerated_size: 0,
988            compressed_size: None,
989            num_streams: None,
990            ls_type: LiteralsSectionType::Raw,
991        }
992    }
993
994    fn section_type(raw: u8) -> Result<LiteralsSectionType, String> {
995        let t = raw & 0x3;
996        match t {
997            0 => Ok(LiteralsSectionType::Raw),
998            1 => Ok(LiteralsSectionType::RLE),
999            2 => Ok(LiteralsSectionType::Compressed),
1000            3 => Ok(LiteralsSectionType::Treeless),
1001            other => Err(format!("Illegal literal section type: {}", other)),
1002        }
1003    }
1004
1005    fn header_bytes_needed(&self, first_byte: u8) -> Result<u8, String> {
1006        let ls_type = Self::section_type(first_byte)?;
1007        let size_format = (first_byte >> 2) & 0x3;
1008        match ls_type {
1009            LiteralsSectionType::RLE | LiteralsSectionType::Raw => match size_format {
1010                0 | 2 => Ok(1),
1011                1 => Ok(2),
1012                3 => Ok(3),
1013                _ => unreachable!(),
1014            },
1015            LiteralsSectionType::Compressed | LiteralsSectionType::Treeless => match size_format {
1016                0 | 1 => Ok(3),
1017                2 => Ok(4),
1018                3 => Ok(5),
1019                _ => unreachable!(),
1020            },
1021        }
1022    }
1023
1024    fn parse_from_header(&mut self, raw: &[u8]) -> Result<u8, String> {
1025        let mut br = BitReader::new(raw);
1026        let block_type = br.get_bits(2)? as u8;
1027        self.ls_type = Self::section_type(block_type)?;
1028        let size_format = br.get_bits(2)? as u8;
1029
1030        let byte_needed = self.header_bytes_needed(raw[0])?;
1031        if raw.len() < byte_needed as usize {
1032            return Err(format!(
1033                "Not enough bytes for literals header: have {}, need {}",
1034                raw.len(),
1035                byte_needed
1036            ));
1037        }
1038
1039        match self.ls_type {
1040            LiteralsSectionType::RLE | LiteralsSectionType::Raw => {
1041                self.compressed_size = None;
1042                match size_format {
1043                    0 | 2 => {
1044                        self.regenerated_size = u32::from(raw[0]) >> 3;
1045                        Ok(1)
1046                    }
1047                    1 => {
1048                        self.regenerated_size = (u32::from(raw[0]) >> 4) + (u32::from(raw[1]) << 4);
1049                        Ok(2)
1050                    }
1051                    3 => {
1052                        self.regenerated_size = (u32::from(raw[0]) >> 4)
1053                            + (u32::from(raw[1]) << 4)
1054                            + (u32::from(raw[2]) << 12);
1055                        Ok(3)
1056                    }
1057                    _ => unreachable!(),
1058                }
1059            }
1060            LiteralsSectionType::Compressed | LiteralsSectionType::Treeless => {
1061                match size_format {
1062                    0 => {
1063                        self.num_streams = Some(1);
1064                    }
1065                    1..=3 => {
1066                        self.num_streams = Some(4);
1067                    }
1068                    _ => unreachable!(),
1069                };
1070
1071                match size_format {
1072                    0 | 1 => {
1073                        self.regenerated_size =
1074                            (u32::from(raw[0]) >> 4) + ((u32::from(raw[1]) & 0x3f) << 4);
1075                        self.compressed_size =
1076                            Some(u32::from(raw[1] >> 6) + (u32::from(raw[2]) << 2));
1077                        Ok(3)
1078                    }
1079                    2 => {
1080                        self.regenerated_size = (u32::from(raw[0]) >> 4)
1081                            + (u32::from(raw[1]) << 4)
1082                            + ((u32::from(raw[2]) & 0x3) << 12);
1083                        self.compressed_size =
1084                            Some((u32::from(raw[2]) >> 2) + (u32::from(raw[3]) << 6));
1085                        Ok(4)
1086                    }
1087                    3 => {
1088                        self.regenerated_size = (u32::from(raw[0]) >> 4)
1089                            + (u32::from(raw[1]) << 4)
1090                            + ((u32::from(raw[2]) & 0x3F) << 12);
1091                        self.compressed_size = Some(
1092                            (u32::from(raw[2]) >> 6)
1093                                + (u32::from(raw[3]) << 2)
1094                                + (u32::from(raw[4]) << 10),
1095                        );
1096                        Ok(5)
1097                    }
1098                    _ => unreachable!(),
1099                }
1100            }
1101        }
1102    }
1103}
1104
1105// ============================================================
1106// Sequences Section
1107// ============================================================
1108
/// One decoded sequence from a block's sequences section.
#[derive(Clone, Copy)]
struct Sequence {
    /// Literal length: base value for the LL code plus its extra bits.
    ll: u32,
    /// Match length: base value for the ML code plus its extra bits.
    ml: u32,
    /// Offset value as decoded: `offset extra bits + (1 << offset code)`.
    /// NOTE(review): presumably still subject to repeat-offset resolution when
    /// sequences are executed — confirm against the executor.
    of: u32,
}
1115
/// The symbol compression modes byte of a sequences section header.
///
/// Bits 7-6 select the literal-length mode, bits 5-4 the offset mode and
/// bits 3-2 the match-length mode. Bits 1-0 are not read by this type.
#[derive(Copy, Clone)]
struct CompressionModes(u8);

/// How the FSE table for one sequence field is obtained.
enum ModeType {
    Predefined,
    RLE,
    FSECompressed,
    Repeat,
}

impl CompressionModes {
    /// Converts a two-bit mode field into a `ModeType`.
    ///
    /// # Panics
    /// Panics if `m` is greater than 3.
    fn decode_mode(m: u8) -> ModeType {
        match m {
            3 => ModeType::Repeat,
            2 => ModeType::FSECompressed,
            1 => ModeType::RLE,
            0 => ModeType::Predefined,
            _ => panic!("Invalid mode value"),
        }
    }

    /// Decodes the two-bit mode field whose lowest bit sits at `shift`.
    fn mode_bits_at(self, shift: u32) -> ModeType {
        Self::decode_mode((self.0 >> shift) & 0b11)
    }

    /// Mode of the literal-length table (bits 7-6).
    fn ll_mode(self) -> ModeType {
        self.mode_bits_at(6)
    }

    /// Mode of the offset table (bits 5-4).
    fn of_mode(self) -> ModeType {
        self.mode_bits_at(4)
    }

    /// Mode of the match-length table (bits 3-2).
    fn ml_mode(self) -> ModeType {
        self.mode_bits_at(2)
    }
}
1146
1147struct SequencesHeader {
1148    num_sequences: u32,
1149    modes: Option<CompressionModes>,
1150}
1151
1152impl SequencesHeader {
1153    fn new() -> SequencesHeader {
1154        SequencesHeader {
1155            num_sequences: 0,
1156            modes: None,
1157        }
1158    }
1159
1160    fn parse_from_header(&mut self, source: &[u8]) -> Result<u8, String> {
1161        let mut bytes_read = 0;
1162        if source.is_empty() {
1163            return Err("Sequences header source is empty".to_string());
1164        }
1165
1166        match source[0] {
1167            0 => {
1168                self.num_sequences = 0;
1169                bytes_read += 1;
1170            }
1171            1..=127 => {
1172                if source.len() < 2 {
1173                    return Err(format!(
1174                        "Not enough bytes for sequences header: have {}, need 2",
1175                        source.len()
1176                    ));
1177                }
1178                self.num_sequences = u32::from(source[0]);
1179                self.modes = Some(CompressionModes(source[1]));
1180                bytes_read += 2;
1181            }
1182            128..=254 => {
1183                if source.len() < 2 {
1184                    return Err(format!(
1185                        "Not enough bytes for sequences header: have {}, need 2",
1186                        source.len()
1187                    ));
1188                }
1189                self.num_sequences = ((u32::from(source[0]) - 128) << 8) + u32::from(source[1]);
1190                bytes_read += 2;
1191                if self.num_sequences != 0 {
1192                    if source.len() < 3 {
1193                        return Err(format!(
1194                            "Not enough bytes for sequences header: have {}, need 3",
1195                            source.len()
1196                        ));
1197                    }
1198                    self.modes = Some(CompressionModes(source[2]));
1199                    bytes_read += 1;
1200                }
1201            }
1202            255 => {
1203                if source.len() < 4 {
1204                    return Err(format!(
1205                        "Not enough bytes for sequences header: have {}, need 4",
1206                        source.len()
1207                    ));
1208                }
1209                self.num_sequences = u32::from(source[1]) + (u32::from(source[2]) << 8) + 0x7F00;
1210                self.modes = Some(CompressionModes(source[3]));
1211                bytes_read += 4;
1212            }
1213        }
1214
1215        Ok(bytes_read)
1216    }
1217}
1218
1219// ============================================================
1220// Decode Buffer (Vec-based, no ringbuffer)
1221// ============================================================
1222
/// Output accumulator: every decoded byte is appended to a single `Vec`, and
/// matches copy from its tail.
struct DecodeBuffer {
    buffer: Vec<u8>,
    /// Stored on `new`/`reset`. No method here consults it; `repeat` bounds
    /// offsets by the bytes actually present.
    window_size: usize,
}

impl DecodeBuffer {
    /// Creates an empty buffer recording `window_size`.
    fn new(window_size: usize) -> DecodeBuffer {
        DecodeBuffer {
            buffer: Vec::new(),
            window_size,
        }
    }

    /// Empties the buffer (keeping its allocation) and records a new window size.
    fn reset(&mut self, window_size: usize) {
        self.window_size = window_size;
        self.buffer.clear();
    }

    /// Number of bytes currently held.
    fn len(&self) -> usize {
        self.buffer.len()
    }

    /// Appends `data` verbatim.
    fn push(&mut self, data: &[u8]) {
        self.buffer.extend_from_slice(data);
    }

    /// Appends `match_length` bytes copied from `offset` bytes before the end.
    ///
    /// The source range may overlap the bytes being written (`offset <
    /// match_length`). In that case the last `offset` bytes repeat as a pattern.
    ///
    /// # Errors
    /// Returns an error if `offset` is 0 or larger than the current length.
    fn repeat(&mut self, offset: usize, match_length: usize) -> Result<(), String> {
        if offset > self.buffer.len() {
            return Err(format!(
                "Offset {} exceeds buffer length {}",
                offset,
                self.buffer.len()
            ));
        }
        if offset == 0 {
            return Err("Zero offset in repeat".to_string());
        }

        let start_idx = self.buffer.len() - offset;
        self.buffer.reserve(match_length);

        // Each pass copies the whole of `start_idx..len`, or just what is
        // still needed. That span always holds a whole number of
        // `offset`-byte periods, so copying any prefix of it continues the
        // pattern. The span doubles each pass, so long overlapping matches
        // take O(log n) bulk copies instead of a per-byte loop.
        let mut remaining = match_length;
        while remaining > 0 {
            let chunk = remaining.min(self.buffer.len() - start_idx);
            self.buffer.extend_from_within(start_idx..start_idx + chunk);
            remaining -= chunk;
        }

        Ok(())
    }

    /// Takes all buffered bytes, leaving the buffer empty.
    fn drain(&mut self) -> Vec<u8> {
        std::mem::take(&mut self.buffer)
    }
}
1276
1277// ============================================================
1278// Scratch space
1279// ============================================================
1280
/// Huffman table storage. `Compressed` literal sections rebuild `table`;
/// `Treeless` sections decode with whatever table it already holds.
struct HuffmanScratch {
    table: HuffmanTable,
}
1284
/// FSE decoding tables for the three sequence fields (offset, literal length,
/// match length).
///
/// When a field's `*_rle` is `Some(code)`, that code is used for every
/// sequence. The field's table is then neither initialized nor updated from
/// the bitstream.
struct FSEScratch {
    offsets: FSETable,
    of_rle: Option<u8>,
    literal_lengths: FSETable,
    ll_rle: Option<u8>,
    match_lengths: FSETable,
    ml_rle: Option<u8>,
}
1293
/// Working tables and buffers reused while decoding.
///
/// `reset` clears them in place (keeping allocations) instead of rebuilding.
struct DecoderScratch {
    huf: HuffmanScratch,
    fse: FSEScratch,
    /// Accumulated decoded output.
    buffer: DecodeBuffer,
    /// Repeat-offset history; initialized and reset to `[1, 4, 8]`.
    offset_hist: [u32; 3],
    /// NOTE(review): presumably the current block's decoded literals — confirm.
    literals_buffer: Vec<u8>,
    /// NOTE(review): presumably the current block's decoded sequences — confirm.
    sequences: Vec<Sequence>,
    /// NOTE(review): presumably the current block's raw content bytes — confirm.
    block_content_buffer: Vec<u8>,
}
1303
1304impl DecoderScratch {
1305    fn new(window_size: usize) -> DecoderScratch {
1306        DecoderScratch {
1307            huf: HuffmanScratch {
1308                table: HuffmanTable::new(),
1309            },
1310            fse: FSEScratch {
1311                offsets: FSETable::new(MAX_OFFSET_CODE),
1312                of_rle: None,
1313                literal_lengths: FSETable::new(MAX_LITERAL_LENGTH_CODE),
1314                ll_rle: None,
1315                match_lengths: FSETable::new(MAX_MATCH_LENGTH_CODE),
1316                ml_rle: None,
1317            },
1318            buffer: DecodeBuffer::new(window_size),
1319            offset_hist: [1, 4, 8],
1320            block_content_buffer: Vec::new(),
1321            literals_buffer: Vec::new(),
1322            sequences: Vec::new(),
1323        }
1324    }
1325
1326    fn reset(&mut self, window_size: usize) {
1327        self.offset_hist = [1, 4, 8];
1328        self.literals_buffer.clear();
1329        self.sequences.clear();
1330        self.block_content_buffer.clear();
1331        self.buffer.reset(window_size);
1332        self.fse.literal_lengths.reset();
1333        self.fse.match_lengths.reset();
1334        self.fse.offsets.reset();
1335        self.fse.ll_rle = None;
1336        self.fse.ml_rle = None;
1337        self.fse.of_rle = None;
1338        self.huf.table.reset();
1339    }
1340}
1341
1342// ============================================================
1343// Frame header
1344// ============================================================
1345
/// The frame header descriptor byte.
///
/// Bits 7-6 hold the content size flag, bit 5 the single-segment flag, bit 2
/// the content checksum flag, and bits 1-0 the dictionary ID flag.
struct FrameDescriptor(u8);

impl FrameDescriptor {
    /// Bits 7-6: selects how many bytes encode the frame content size.
    fn frame_content_size_flag(&self) -> u8 {
        self.0 >> 6
    }

    /// Bit 5: set when the frame is a single segment (no window descriptor).
    fn single_segment_flag(&self) -> bool {
        self.0 & 0x20 != 0
    }

    /// Bit 2: set when a content checksum follows the last block.
    fn content_checksum_flag(&self) -> bool {
        self.0 & 0x04 != 0
    }

    /// Bits 1-0: selects how many bytes encode the dictionary ID.
    fn dict_id_flag(&self) -> u8 {
        self.0 & 0x3
    }

    /// Size in bytes of the frame content size field (0, 1, 2, 4 or 8).
    ///
    /// Flag 0 means one byte for single-segment frames and none otherwise;
    /// flags 1-3 mean `1 << flag` bytes.
    fn frame_content_size_bytes(&self) -> Result<u8, String> {
        let flag = self.frame_content_size_flag();
        if flag == 0 {
            return Ok(u8::from(self.single_segment_flag()));
        }
        if flag > 3 {
            return Err(format!("Invalid frame content size flag: {}", flag));
        }
        Ok(1 << flag)
    }

    /// Size in bytes of the dictionary ID field (0, 1, 2 or 4).
    fn dictionary_id_bytes(&self) -> Result<u8, String> {
        match self.dict_id_flag() {
            3 => Ok(4),
            small @ 0..=2 => Ok(small),
            other => Err(format!("Invalid dict id flag: {}", other)),
        }
    }
}
1391
1392struct FrameHeader {
1393    descriptor: FrameDescriptor,
1394    window_descriptor: u8,
1395    frame_content_size: u64,
1396}
1397
1398impl FrameHeader {
1399    fn window_size(&self) -> Result<u64, String> {
1400        if self.descriptor.single_segment_flag() {
1401            Ok(self.frame_content_size)
1402        } else {
1403            let exp = self.window_descriptor >> 3;
1404            let mantissa = self.window_descriptor & 0x7;
1405
1406            let window_log = 10 + u64::from(exp);
1407            let window_base = 1u64 << window_log;
1408            let window_add = (window_base / 8) * u64::from(mantissa);
1409
1410            let window_size = window_base + window_add;
1411
1412            if window_size < MIN_WINDOW_SIZE {
1413                Err(format!("Window size {} too small", window_size))
1414            } else if window_size >= MAX_WINDOW_SIZE {
1415                Err(format!("Window size {} too big", window_size))
1416            } else {
1417                Ok(window_size)
1418            }
1419        }
1420    }
1421
1422    fn frame_content_size(&self) -> u64 {
1423        self.frame_content_size
1424    }
1425}
1426
1427// ============================================================
1428// Error wrapper for skip frames
1429// ============================================================
1430
/// Error from reading a frame header.
///
/// When `skip_length` is `Some`, the input was a skippable frame rather than a
/// failure, and the value is that frame's length field.
struct FrameDecoderError {
    msg: String,
    skip_length: Option<u32>,
}

impl FrameDecoderError {
    /// A plain error carrying `msg`.
    fn new(msg: String) -> Self {
        FrameDecoderError {
            skip_length: None,
            msg,
        }
    }

    /// Signals a skippable frame whose length field is `length`.
    fn skip(length: u32) -> Self {
        let msg = format!("Skippable frame with length {}", length);
        FrameDecoderError {
            msg,
            skip_length: Some(length),
        }
    }

    /// The skippable frame length, if this error signals one.
    fn skip_frame_length(&self) -> Option<u32> {
        self.skip_length
    }
}

impl std::fmt::Display for FrameDecoderError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.msg)
    }
}
1461
1462// ============================================================
1463// Frame header reading
1464// ============================================================
1465
1466fn read_frame_header(r: &mut dyn std::io::Read) -> Result<(FrameHeader, u8), FrameDecoderError> {
1467    let mut buf = [0u8; 4];
1468
1469    r.read_exact(&mut buf)
1470        .map_err(|e| FrameDecoderError::new(format!("Error reading magic number: {}", e)))?;
1471    let mut bytes_read: usize = 4;
1472    let magic_num = u32::from_le_bytes(buf);
1473
1474    // Skippable frames
1475    if (0x184D2A50..=0x184D2A5F).contains(&magic_num) {
1476        r.read_exact(&mut buf)
1477            .map_err(|e| FrameDecoderError::new(format!("Error reading skip frame size: {}", e)))?;
1478        let skip_size = u32::from_le_bytes(buf);
1479        return Err(FrameDecoderError::skip(skip_size));
1480    }
1481
1482    if magic_num != ZSTD_MAGIC {
1483        return Err(FrameDecoderError::new(format!(
1484            "Bad magic number: 0x{:X}",
1485            magic_num
1486        )));
1487    }
1488
1489    r.read_exact(&mut buf[0..1])
1490        .map_err(|e| FrameDecoderError::new(format!("Error reading frame descriptor: {}", e)))?;
1491    let desc = FrameDescriptor(buf[0]);
1492    bytes_read += 1;
1493
1494    let mut frame_header = FrameHeader {
1495        descriptor: FrameDescriptor(desc.0),
1496        frame_content_size: 0,
1497        window_descriptor: 0,
1498    };
1499
1500    if !desc.single_segment_flag() {
1501        r.read_exact(&mut buf[0..1]).map_err(|e| {
1502            FrameDecoderError::new(format!("Error reading window descriptor: {}", e))
1503        })?;
1504        frame_header.window_descriptor = buf[0];
1505        bytes_read += 1;
1506    }
1507
1508    let dict_id_len = desc.dictionary_id_bytes().map_err(FrameDecoderError::new)? as usize;
1509    if dict_id_len != 0 {
1510        let buf = &mut buf[..dict_id_len];
1511        r.read_exact(buf)
1512            .map_err(|e| FrameDecoderError::new(format!("Error reading dictionary id: {}", e)))?;
1513        bytes_read += dict_id_len;
1514        // We don't support dictionaries, but we still need to skip these bytes
1515    }
1516
1517    let fcs_len = desc
1518        .frame_content_size_bytes()
1519        .map_err(FrameDecoderError::new)? as usize;
1520    if fcs_len != 0 {
1521        let mut fcs_buf = [0u8; 8];
1522        let fcs_buf = &mut fcs_buf[..fcs_len];
1523        r.read_exact(fcs_buf).map_err(|e| {
1524            FrameDecoderError::new(format!("Error reading frame content size: {}", e))
1525        })?;
1526        bytes_read += fcs_len;
1527        let mut fcs = 0u64;
1528        for i in 0..fcs_len {
1529            fcs += (fcs_buf[i] as u64) << (8 * i);
1530        }
1531        if fcs_len == 2 {
1532            fcs += 256;
1533        }
1534        frame_header.frame_content_size = fcs;
1535    }
1536
1537    Ok((frame_header, bytes_read as u8))
1538}
1539
1540// ============================================================
1541// Block header reading
1542// ============================================================
1543
1544fn read_block_header(r: &mut dyn std::io::Read) -> Result<(BlockHeader, u8), String> {
1545    let mut buf = [0u8; 3];
1546    r.read_exact(&mut buf)
1547        .map_err(|e| format!("Error reading block header: {}", e))?;
1548
1549    let last_block = buf[0] & 0x1 == 1;
1550    let block_type_raw = (buf[0] >> 1) & 0x3;
1551    let block_type = match block_type_raw {
1552        0 => BlockType::Raw,
1553        1 => BlockType::RLE,
1554        2 => BlockType::Compressed,
1555        3 => BlockType::Reserved,
1556        _ => unreachable!(),
1557    };
1558
1559    if block_type == BlockType::Reserved {
1560        return Err("Found reserved block type".to_string());
1561    }
1562
1563    let block_size = u32::from(buf[0] >> 3) | (u32::from(buf[1]) << 5) | (u32::from(buf[2]) << 13);
1564
1565    if block_size > MAX_BLOCK_SIZE {
1566        return Err(format!(
1567            "Block size {} exceeds max {}",
1568            block_size, MAX_BLOCK_SIZE
1569        ));
1570    }
1571
1572    let decompressed_size = match block_type {
1573        BlockType::Raw | BlockType::RLE => block_size,
1574        BlockType::Compressed | BlockType::Reserved => 0,
1575    };
1576    let content_size = match block_type {
1577        BlockType::Raw | BlockType::Compressed => block_size,
1578        BlockType::RLE => 1,
1579        BlockType::Reserved => 0,
1580    };
1581
1582    Ok((
1583        BlockHeader {
1584            last_block,
1585            block_type,
1586            decompressed_size,
1587            content_size,
1588        },
1589        3,
1590    ))
1591}
1592
1593// ============================================================
1594// Literals section decoder
1595// ============================================================
1596
1597fn decode_literals(
1598    section: &LiteralsSection,
1599    scratch: &mut HuffmanScratch,
1600    source: &[u8],
1601    target: &mut Vec<u8>,
1602) -> Result<u32, String> {
1603    match section.ls_type {
1604        LiteralsSectionType::Raw => {
1605            target.extend(&source[0..section.regenerated_size as usize]);
1606            Ok(section.regenerated_size)
1607        }
1608        LiteralsSectionType::RLE => {
1609            target.resize(target.len() + section.regenerated_size as usize, source[0]);
1610            Ok(1)
1611        }
1612        LiteralsSectionType::Compressed | LiteralsSectionType::Treeless => {
1613            decompress_literals(section, scratch, source, target)
1614        }
1615    }
1616}
1617
1618fn decompress_literals(
1619    section: &LiteralsSection,
1620    scratch: &mut HuffmanScratch,
1621    source: &[u8],
1622    target: &mut Vec<u8>,
1623) -> Result<u32, String> {
1624    let compressed_size = section
1625        .compressed_size
1626        .ok_or_else(|| "Missing compressed size".to_string())? as usize;
1627    let num_streams = section
1628        .num_streams
1629        .ok_or_else(|| "Missing num_streams".to_string())?;
1630
1631    target.reserve(section.regenerated_size as usize);
1632    let source = &source[0..compressed_size];
1633    let mut bytes_read = 0u32;
1634
1635    match section.ls_type {
1636        LiteralsSectionType::Compressed => {
1637            bytes_read += scratch.table.build_decoder(source)?;
1638        }
1639        LiteralsSectionType::Treeless => {
1640            if scratch.table.max_num_bits == 0 {
1641                return Err("Uninitialized Huffman table for treeless literals".to_string());
1642            }
1643        }
1644        _ => {}
1645    }
1646
1647    let source = &source[bytes_read as usize..];
1648
1649    if num_streams == 4 {
1650        if source.len() < 6 {
1651            return Err(format!(
1652                "Missing bytes for jump header: have {}",
1653                source.len()
1654            ));
1655        }
1656        let jump1 = source[0] as usize + ((source[1] as usize) << 8);
1657        let jump2 = jump1 + source[2] as usize + ((source[3] as usize) << 8);
1658        let jump3 = jump2 + source[4] as usize + ((source[5] as usize) << 8);
1659        bytes_read += 6;
1660        let source = &source[6..];
1661
1662        if source.len() < jump3 {
1663            return Err(format!(
1664                "Missing bytes for literals: have {}, need {}",
1665                source.len(),
1666                jump3
1667            ));
1668        }
1669
1670        let stream1 = &source[..jump1];
1671        let stream2 = &source[jump1..jump2];
1672        let stream3 = &source[jump2..jump3];
1673        let stream4 = &source[jump3..];
1674
1675        for stream in &[stream1, stream2, stream3, stream4] {
1676            let mut decoder = HuffmanDecoder::new(&scratch.table);
1677            let mut br = BitReaderReversed::new(stream);
1678            let mut skipped_bits = 0;
1679            loop {
1680                let val = br.get_bits(1);
1681                skipped_bits += 1;
1682                if val == 1 || skipped_bits > 8 {
1683                    break;
1684                }
1685            }
1686            if skipped_bits > 8 {
1687                return Err(format!("Extra padding: {} bits skipped", skipped_bits));
1688            }
1689            decoder.init_state(&mut br);
1690
1691            while br.bits_remaining() > -(scratch.table.max_num_bits as isize) {
1692                target.push(decoder.decode_symbol());
1693                decoder.next_state(&mut br);
1694            }
1695            if br.bits_remaining() != -(scratch.table.max_num_bits as isize) {
1696                return Err(format!(
1697                    "Bitstream read mismatch: {} vs expected {}",
1698                    br.bits_remaining(),
1699                    -(scratch.table.max_num_bits as isize)
1700                ));
1701            }
1702        }
1703
1704        bytes_read += source.len() as u32;
1705    } else {
1706        assert!(num_streams == 1);
1707        let mut decoder = HuffmanDecoder::new(&scratch.table);
1708        let mut br = BitReaderReversed::new(source);
1709        let mut skipped_bits = 0;
1710        loop {
1711            let val = br.get_bits(1);
1712            skipped_bits += 1;
1713            if val == 1 || skipped_bits > 8 {
1714                break;
1715            }
1716        }
1717        if skipped_bits > 8 {
1718            return Err(format!("Extra padding: {} bits skipped", skipped_bits));
1719        }
1720        decoder.init_state(&mut br);
1721        while br.bits_remaining() > -(scratch.table.max_num_bits as isize) {
1722            target.push(decoder.decode_symbol());
1723            decoder.next_state(&mut br);
1724        }
1725        bytes_read += source.len() as u32;
1726    }
1727
1728    if target.len() != section.regenerated_size as usize {
1729        return Err(format!(
1730            "Decoded literal count mismatch: {} vs expected {}",
1731            target.len(),
1732            section.regenerated_size
1733        ));
1734    }
1735
1736    Ok(bytes_read)
1737}
1738
1739// ============================================================
1740// Sequence section decoder
1741// ============================================================
1742
1743fn decode_sequences(
1744    section: &SequencesHeader,
1745    source: &[u8],
1746    scratch: &mut FSEScratch,
1747    target: &mut Vec<Sequence>,
1748) -> Result<(), String> {
1749    let bytes_read = maybe_update_fse_tables(section, source, scratch)?;
1750    let bit_stream = &source[bytes_read..];
1751
1752    let mut br = BitReaderReversed::new(bit_stream);
1753
1754    let mut skipped_bits = 0;
1755    loop {
1756        let val = br.get_bits(1);
1757        skipped_bits += 1;
1758        if val == 1 || skipped_bits > 8 {
1759            break;
1760        }
1761    }
1762    if skipped_bits > 8 {
1763        return Err(format!("Extra padding: {} bits skipped", skipped_bits));
1764    }
1765
1766    if scratch.ll_rle.is_some() || scratch.ml_rle.is_some() || scratch.of_rle.is_some() {
1767        decode_sequences_with_rle(section, &mut br, scratch, target)
1768    } else {
1769        decode_sequences_without_rle(section, &mut br, scratch, target)
1770    }
1771}
1772
1773fn decode_sequences_with_rle(
1774    section: &SequencesHeader,
1775    br: &mut BitReaderReversed<'_>,
1776    scratch: &FSEScratch,
1777    target: &mut Vec<Sequence>,
1778) -> Result<(), String> {
1779    let mut ll_dec = FSEDecoder::new(&scratch.literal_lengths);
1780    let mut ml_dec = FSEDecoder::new(&scratch.match_lengths);
1781    let mut of_dec = FSEDecoder::new(&scratch.offsets);
1782
1783    if scratch.ll_rle.is_none() {
1784        ll_dec.init_state(br)?;
1785    }
1786    if scratch.of_rle.is_none() {
1787        of_dec.init_state(br)?;
1788    }
1789    if scratch.ml_rle.is_none() {
1790        ml_dec.init_state(br)?;
1791    }
1792
1793    target.clear();
1794    target.reserve(section.num_sequences as usize);
1795
1796    for _seq_idx in 0..section.num_sequences {
1797        let ll_code = scratch.ll_rle.unwrap_or_else(|| ll_dec.decode_symbol());
1798        let ml_code = scratch.ml_rle.unwrap_or_else(|| ml_dec.decode_symbol());
1799        let of_code = scratch.of_rle.unwrap_or_else(|| of_dec.decode_symbol());
1800
1801        let (ll_value, ll_num_bits) = lookup_ll_code(ll_code)?;
1802        let (ml_value, ml_num_bits) = lookup_ml_code(ml_code)?;
1803
1804        if of_code > MAX_OFFSET_CODE {
1805            return Err(format!("Unsupported offset code: {}", of_code));
1806        }
1807
1808        let (obits, ml_add, ll_add) = br.get_bits_triple(of_code, ml_num_bits, ll_num_bits);
1809        let offset = obits as u32 + (1u32 << of_code);
1810
1811        if offset == 0 {
1812            return Err("Zero offset".to_string());
1813        }
1814
1815        target.push(Sequence {
1816            ll: ll_value + ll_add as u32,
1817            ml: ml_value + ml_add as u32,
1818            of: offset,
1819        });
1820
1821        if target.len() < section.num_sequences as usize {
1822            if scratch.ll_rle.is_none() {
1823                ll_dec.update_state(br);
1824            }
1825            if scratch.ml_rle.is_none() {
1826                ml_dec.update_state(br);
1827            }
1828            if scratch.of_rle.is_none() {
1829                of_dec.update_state(br);
1830            }
1831        }
1832
1833        if br.bits_remaining() < 0 {
1834            return Err("Not enough bytes for number of sequences".to_string());
1835        }
1836    }
1837
1838    if br.bits_remaining() > 0 {
1839        Err(format!("Extra bits remaining: {}", br.bits_remaining()))
1840    } else {
1841        Ok(())
1842    }
1843}
1844
1845fn decode_sequences_without_rle(
1846    section: &SequencesHeader,
1847    br: &mut BitReaderReversed<'_>,
1848    scratch: &FSEScratch,
1849    target: &mut Vec<Sequence>,
1850) -> Result<(), String> {
1851    let mut ll_dec = FSEDecoder::new(&scratch.literal_lengths);
1852    let mut ml_dec = FSEDecoder::new(&scratch.match_lengths);
1853    let mut of_dec = FSEDecoder::new(&scratch.offsets);
1854
1855    ll_dec.init_state(br)?;
1856    of_dec.init_state(br)?;
1857    ml_dec.init_state(br)?;
1858
1859    target.clear();
1860    target.reserve(section.num_sequences as usize);
1861
1862    for _seq_idx in 0..section.num_sequences {
1863        let ll_code = ll_dec.decode_symbol();
1864        let ml_code = ml_dec.decode_symbol();
1865        let of_code = of_dec.decode_symbol();
1866
1867        let (ll_value, ll_num_bits) = lookup_ll_code(ll_code)?;
1868        let (ml_value, ml_num_bits) = lookup_ml_code(ml_code)?;
1869
1870        if of_code > MAX_OFFSET_CODE {
1871            return Err(format!("Unsupported offset code: {}", of_code));
1872        }
1873
1874        let (obits, ml_add, ll_add) = br.get_bits_triple(of_code, ml_num_bits, ll_num_bits);
1875        let offset = obits as u32 + (1u32 << of_code);
1876
1877        if offset == 0 {
1878            return Err("Zero offset".to_string());
1879        }
1880
1881        target.push(Sequence {
1882            ll: ll_value + ll_add as u32,
1883            ml: ml_value + ml_add as u32,
1884            of: offset,
1885        });
1886
1887        if target.len() < section.num_sequences as usize {
1888            ll_dec.update_state(br);
1889            ml_dec.update_state(br);
1890            of_dec.update_state(br);
1891        }
1892
1893        if br.bits_remaining() < 0 {
1894            return Err("Not enough bytes for number of sequences".to_string());
1895        }
1896    }
1897
1898    if br.bits_remaining() > 0 {
1899        Err(format!("Extra bits remaining: {}", br.bits_remaining()))
1900    } else {
1901        Ok(())
1902    }
1903}
1904
/// Maps a literal-length code to `(baseline, number_of_extra_bits)`.
///
/// Codes `0..=15` are the literal length itself with no extra bits. Codes `16..=35`
/// come from the Zstandard literal-length code table.
///
/// # Errors
/// Returns `Err` for any code above 35.
fn lookup_ll_code(code: u8) -> Result<(u32, u8), String> {
    // Codes below this value encode the length directly.
    const DIRECT_CODES: u8 = 16;
    // (baseline, extra-bit count) for codes 16..=35, in code order.
    const EXTENDED: [(u32, u8); 20] = [
        (16, 1),
        (18, 1),
        (20, 1),
        (22, 1),
        (24, 2),
        (28, 2),
        (32, 3),
        (40, 3),
        (48, 4),
        (64, 6),
        (128, 7),
        (256, 8),
        (512, 9),
        (1024, 10),
        (2048, 11),
        (4096, 12),
        (8192, 13),
        (16384, 14),
        (32768, 15),
        (65536, 16),
    ];

    if code < DIRECT_CODES {
        return Ok((u32::from(code), 0));
    }
    EXTENDED
        .get(usize::from(code - DIRECT_CODES))
        .copied()
        .ok_or_else(|| format!("Illegal literal length code: {}", code))
}
1932
/// Maps a match-length code to `(baseline, number_of_extra_bits)`.
///
/// Codes `0..=31` encode `code + 3` with no extra bits (the minimum match is 3).
/// Codes `32..=52` come from the Zstandard match-length code table.
///
/// # Errors
/// Returns `Err` for any code above 52.
fn lookup_ml_code(code: u8) -> Result<(u32, u8), String> {
    // Codes below this value encode `code + 3` directly.
    const DIRECT_CODES: u8 = 32;
    // (baseline, extra-bit count) for codes 32..=52, in code order.
    const EXTENDED: [(u32, u8); 21] = [
        (35, 1),
        (37, 1),
        (39, 1),
        (41, 1),
        (43, 2),
        (47, 2),
        (51, 3),
        (59, 3),
        (67, 4),
        (83, 4),
        (99, 5),
        (131, 7),
        (259, 8),
        (515, 9),
        (1027, 10),
        (2051, 11),
        (4099, 12),
        (8195, 13),
        (16387, 14),
        (32771, 15),
        (65539, 16),
    ];

    if code < DIRECT_CODES {
        return Ok((u32::from(code) + 3, 0));
    }
    EXTENDED
        .get(usize::from(code - DIRECT_CODES))
        .copied()
        .ok_or_else(|| format!("Illegal match length code: {}", code))
}
1961
1962fn maybe_update_fse_tables(
1963    section: &SequencesHeader,
1964    source: &[u8],
1965    scratch: &mut FSEScratch,
1966) -> Result<usize, String> {
1967    let modes = section
1968        .modes
1969        .ok_or_else(|| "Missing compression mode".to_string())?;
1970
1971    let mut bytes_read = 0;
1972
1973    match modes.ll_mode() {
1974        ModeType::FSECompressed => {
1975            let bytes = scratch.literal_lengths.build_decoder(source, LL_MAX_LOG)?;
1976            bytes_read += bytes;
1977            scratch.ll_rle = None;
1978        }
1979        ModeType::RLE => {
1980            if source.is_empty() {
1981                return Err("Missing byte for RLE LL table".to_string());
1982            }
1983            bytes_read += 1;
1984            if source[0] > MAX_LITERAL_LENGTH_CODE {
1985                return Err(format!("RLE LL code {} exceeds max", source[0]));
1986            }
1987            scratch.ll_rle = Some(source[0]);
1988        }
1989        ModeType::Predefined => {
1990            scratch.literal_lengths.build_from_probabilities(
1991                LL_DEFAULT_ACC_LOG,
1992                &LITERALS_LENGTH_DEFAULT_DISTRIBUTION,
1993            )?;
1994            scratch.ll_rle = None;
1995        }
1996        ModeType::Repeat => { /* Nothing to do */ }
1997    };
1998
1999    let of_source = &source[bytes_read..];
2000
2001    match modes.of_mode() {
2002        ModeType::FSECompressed => {
2003            let bytes = scratch.offsets.build_decoder(of_source, OF_MAX_LOG)?;
2004            bytes_read += bytes;
2005            scratch.of_rle = None;
2006        }
2007        ModeType::RLE => {
2008            if of_source.is_empty() {
2009                return Err("Missing byte for RLE OF table".to_string());
2010            }
2011            bytes_read += 1;
2012            if of_source[0] > MAX_OFFSET_CODE {
2013                return Err(format!("RLE OF code {} exceeds max", of_source[0]));
2014            }
2015            scratch.of_rle = Some(of_source[0]);
2016        }
2017        ModeType::Predefined => {
2018            scratch
2019                .offsets
2020                .build_from_probabilities(OF_DEFAULT_ACC_LOG, &OFFSET_DEFAULT_DISTRIBUTION)?;
2021            scratch.of_rle = None;
2022        }
2023        ModeType::Repeat => { /* Nothing to do */ }
2024    };
2025
2026    let ml_source = &source[bytes_read..];
2027
2028    match modes.ml_mode() {
2029        ModeType::FSECompressed => {
2030            let bytes = scratch.match_lengths.build_decoder(ml_source, ML_MAX_LOG)?;
2031            bytes_read += bytes;
2032            scratch.ml_rle = None;
2033        }
2034        ModeType::RLE => {
2035            if ml_source.is_empty() {
2036                return Err("Missing byte for RLE ML table".to_string());
2037            }
2038            bytes_read += 1;
2039            if ml_source[0] > MAX_MATCH_LENGTH_CODE {
2040                return Err(format!("RLE ML code {} exceeds max", ml_source[0]));
2041            }
2042            scratch.ml_rle = Some(ml_source[0]);
2043        }
2044        ModeType::Predefined => {
2045            scratch
2046                .match_lengths
2047                .build_from_probabilities(ML_DEFAULT_ACC_LOG, &MATCH_LENGTH_DEFAULT_DISTRIBUTION)?;
2048            scratch.ml_rle = None;
2049        }
2050        ModeType::Repeat => { /* Nothing to do */ }
2051    };
2052
2053    Ok(bytes_read)
2054}
2055
2056// ============================================================
2057// Sequence execution
2058// ============================================================
2059
2060fn execute_sequences(scratch: &mut DecoderScratch) -> Result<(), String> {
2061    let mut literals_copy_counter = 0;
2062    let old_buffer_size = scratch.buffer.len();
2063    let mut seq_sum = 0u32;
2064
2065    for idx in 0..scratch.sequences.len() {
2066        let seq = scratch.sequences[idx];
2067
2068        if seq.ll > 0 {
2069            let high = literals_copy_counter + seq.ll as usize;
2070            if high > scratch.literals_buffer.len() {
2071                return Err(format!(
2072                    "Not enough bytes for sequence: wanted {}, have {}",
2073                    high,
2074                    scratch.literals_buffer.len()
2075                ));
2076            }
2077            let literals = &scratch.literals_buffer[literals_copy_counter..high];
2078            literals_copy_counter += seq.ll as usize;
2079            scratch.buffer.push(literals);
2080        }
2081
2082        let actual_offset = do_offset_history(seq.of, seq.ll, &mut scratch.offset_hist);
2083        if actual_offset == 0 {
2084            return Err("Zero offset in sequence execution".to_string());
2085        }
2086        if seq.ml > 0 {
2087            scratch
2088                .buffer
2089                .repeat(actual_offset as usize, seq.ml as usize)?;
2090        }
2091
2092        seq_sum += seq.ml;
2093        seq_sum += seq.ll;
2094    }
2095
2096    if literals_copy_counter < scratch.literals_buffer.len() {
2097        let rest_literals = &scratch.literals_buffer[literals_copy_counter..];
2098        scratch.buffer.push(rest_literals);
2099        seq_sum += rest_literals.len() as u32;
2100    }
2101
2102    let diff = scratch.buffer.len() - old_buffer_size;
2103    assert!(
2104        seq_sum as usize == diff,
2105        "Seq_sum: {} is different from the difference in buffersize: {}",
2106        seq_sum,
2107        diff
2108    );
2109    Ok(())
2110}
2111
/// Resolves an offset value against the three-entry repeat-offset history and
/// updates the history in place. Returns the actual match offset.
///
/// - `offset_value > 3`: a literal offset of `offset_value - 3`. It is pushed to the
///   front and the oldest entry is dropped.
/// - `offset_value` in `1..=3` with `lit_len > 0`: selects repeat offset 1, 2 or 3.
/// - `offset_value` in `1..=3` with `lit_len == 0`: the selection shifts by one, so
///   1 → repeat offset 2, 2 → repeat offset 3, 3 → repeat offset 1 minus one.
///
/// The selected offset moves to the front of the history, except when repeat
/// offset 1 is reused unchanged.
fn do_offset_history(offset_value: u32, lit_len: u32, scratch: &mut [u32; 3]) -> u32 {
    if !(1..=3).contains(&offset_value) {
        let fresh = offset_value - 3;
        scratch.rotate_right(1);
        scratch[0] = fresh;
        return fresh;
    }

    // Slot 0..=2 names a history entry; slot 3 means "history[0] - 1".
    let slot = offset_value as usize - usize::from(lit_len > 0);
    match slot {
        0 => scratch[0],
        1 => {
            scratch.swap(0, 1);
            scratch[0]
        }
        2 => {
            scratch.rotate_right(1);
            scratch[0]
        }
        _ => {
            let adjusted = scratch[0].wrapping_sub(1);
            scratch.rotate_right(1);
            scratch[0] = adjusted;
            adjusted
        }
    }
}
2155
2156// ============================================================
2157// Block decoder
2158// ============================================================
2159
2160fn decode_block_content(
2161    header: &BlockHeader,
2162    workspace: &mut DecoderScratch,
2163    source: &mut dyn std::io::Read,
2164) -> Result<u64, String> {
2165    match header.block_type {
2166        BlockType::RLE => {
2167            const BATCH_SIZE: usize = 512;
2168            let mut buf = [0u8; BATCH_SIZE];
2169            let full_reads = header.decompressed_size / BATCH_SIZE as u32;
2170            let single_read_size = header.decompressed_size % BATCH_SIZE as u32;
2171
2172            source
2173                .read_exact(&mut buf[0..1])
2174                .map_err(|e| format!("Error reading RLE byte: {}", e))?;
2175
2176            for i in 1..BATCH_SIZE {
2177                buf[i] = buf[0];
2178            }
2179
2180            for _ in 0..full_reads {
2181                workspace.buffer.push(&buf[..]);
2182            }
2183            let smaller = &buf[..single_read_size as usize];
2184            workspace.buffer.push(smaller);
2185
2186            Ok(1)
2187        }
2188        BlockType::Raw => {
2189            const BATCH_SIZE: usize = 128 * 1024;
2190            let mut buf = [0u8; BATCH_SIZE];
2191            let full_reads = header.decompressed_size / BATCH_SIZE as u32;
2192            let single_read_size = header.decompressed_size % BATCH_SIZE as u32;
2193
2194            for _ in 0..full_reads {
2195                source
2196                    .read_exact(&mut buf[..])
2197                    .map_err(|e| format!("Error reading raw block: {}", e))?;
2198                workspace.buffer.push(&buf[..]);
2199            }
2200
2201            let smaller = &mut buf[..single_read_size as usize];
2202            source
2203                .read_exact(smaller)
2204                .map_err(|e| format!("Error reading raw block: {}", e))?;
2205            workspace.buffer.push(smaller);
2206
2207            Ok(u64::from(header.decompressed_size))
2208        }
2209        BlockType::Reserved => Err("Reserved block type encountered".to_string()),
2210        BlockType::Compressed => {
2211            decompress_block(header, workspace, source)?;
2212            Ok(u64::from(header.content_size))
2213        }
2214    }
2215}
2216
2217fn decompress_block(
2218    header: &BlockHeader,
2219    workspace: &mut DecoderScratch,
2220    source: &mut dyn std::io::Read,
2221) -> Result<(), String> {
2222    workspace
2223        .block_content_buffer
2224        .resize(header.content_size as usize, 0);
2225
2226    source
2227        .read_exact(workspace.block_content_buffer.as_mut_slice())
2228        .map_err(|e| format!("Error reading compressed block: {}", e))?;
2229    let raw = workspace.block_content_buffer.as_slice();
2230
2231    let mut section = LiteralsSection::new();
2232    let bytes_in_literals_header = section.parse_from_header(raw)?;
2233    let raw = &raw[bytes_in_literals_header as usize..];
2234
2235    let upper_limit_for_literals = match section.compressed_size {
2236        Some(x) => x as usize,
2237        None => match section.ls_type {
2238            LiteralsSectionType::RLE => 1,
2239            LiteralsSectionType::Raw => section.regenerated_size as usize,
2240            _ => return Err("Bug: unexpected literals section type".to_string()),
2241        },
2242    };
2243
2244    if raw.len() < upper_limit_for_literals {
2245        return Err(format!(
2246            "Malformed section header: expected {} bytes, have {}",
2247            upper_limit_for_literals,
2248            raw.len()
2249        ));
2250    }
2251
2252    let raw_literals = &raw[..upper_limit_for_literals];
2253
2254    workspace.literals_buffer.clear();
2255    let bytes_used_in_literals_section = decode_literals(
2256        &section,
2257        &mut workspace.huf,
2258        raw_literals,
2259        &mut workspace.literals_buffer,
2260    )?;
2261    assert!(
2262        section.regenerated_size == workspace.literals_buffer.len() as u32,
2263        "Wrong number of literals: {}, Should have been: {}",
2264        workspace.literals_buffer.len(),
2265        section.regenerated_size
2266    );
2267    assert!(bytes_used_in_literals_section == upper_limit_for_literals as u32);
2268
2269    let raw = &raw[upper_limit_for_literals..];
2270
2271    let mut seq_section = SequencesHeader::new();
2272    let bytes_in_sequence_header = seq_section.parse_from_header(raw)?;
2273    let raw = &raw[bytes_in_sequence_header as usize..];
2274
2275    assert!(
2276        u32::from(bytes_in_literals_header)
2277            + bytes_used_in_literals_section
2278            + u32::from(bytes_in_sequence_header)
2279            + raw.len() as u32
2280            == header.content_size
2281    );
2282
2283    if seq_section.num_sequences != 0 {
2284        decode_sequences(
2285            &seq_section,
2286            raw,
2287            &mut workspace.fse,
2288            &mut workspace.sequences,
2289        )?;
2290        execute_sequences(workspace)?;
2291    } else {
2292        if !raw.is_empty() {
2293            return Err(format!(
2294                "Extra bits remaining: {} bits",
2295                raw.len() as isize * 8
2296            ));
2297        }
2298        workspace.buffer.push(&workspace.literals_buffer);
2299        workspace.sequences.clear();
2300    }
2301
2302    Ok(())
2303}
2304
2305// ============================================================
2306// Frame Decoder (top-level)
2307// ============================================================
2308
2309struct FrameDecoder {
2310    scratch: Option<DecoderScratch>,
2311    frame_header: Option<FrameHeader>,
2312    frame_finished: bool,
2313}
2314
2315impl FrameDecoder {
2316    fn new() -> FrameDecoder {
2317        FrameDecoder {
2318            scratch: None,
2319            frame_header: None,
2320            frame_finished: false,
2321        }
2322    }
2323
2324    fn reset(&mut self, source: &mut dyn std::io::Read) -> Result<(), FrameDecoderError> {
2325        let (frame_header, _header_size) = read_frame_header(source)?;
2326        let window_size = frame_header.window_size().map_err(FrameDecoderError::new)?;
2327
2328        if window_size > MAXIMUM_ALLOWED_WINDOW_SIZE {
2329            return Err(FrameDecoderError::new(format!(
2330                "Window size {} exceeds maximum allowed {}",
2331                window_size, MAXIMUM_ALLOWED_WINDOW_SIZE
2332            )));
2333        }
2334
2335        match &mut self.scratch {
2336            Some(s) => s.reset(window_size as usize),
2337            None => {
2338                self.scratch = Some(DecoderScratch::new(window_size as usize));
2339            }
2340        }
2341
2342        self.frame_header = Some(frame_header);
2343        self.frame_finished = false;
2344        Ok(())
2345    }
2346
2347    fn decode_all_blocks(&mut self, source: &mut dyn std::io::Read) -> Result<(), String> {
2348        let scratch = self
2349            .scratch
2350            .as_mut()
2351            .ok_or_else(|| "Decoder not initialized".to_string())?;
2352
2353        loop {
2354            let (block_header, _block_header_size) = read_block_header(source)?;
2355
2356            decode_block_content(&block_header, scratch, source)?;
2357
2358            if block_header.last_block {
2359                self.frame_finished = true;
2360
2361                // Read and discard checksum if present
2362                if let Some(ref fh) = self.frame_header {
2363                    if fh.descriptor.content_checksum_flag() {
2364                        let mut chksum = [0u8; 4];
2365                        source
2366                            .read_exact(&mut chksum)
2367                            .map_err(|e| format!("Error reading checksum: {}", e))?;
2368                        // We skip checksum verification in this simplified decoder
2369                    }
2370                }
2371                break;
2372            }
2373        }
2374
2375        Ok(())
2376    }
2377
2378    fn collect(&mut self) -> Option<Vec<u8>> {
2379        self.scratch.as_mut().map(|s| s.buffer.drain())
2380    }
2381}
2382
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_empty_input() {
        let result = decompress(&[]);
        assert!(result.is_ok());
        assert!(result.unwrap().is_empty());
    }

    #[test]
    fn test_bad_magic() {
        let result = decompress(&[0, 0, 0, 0, 0]);
        assert!(result.is_err());
    }

    #[test]
    fn test_roundtrip_raw() {
        // A minimal zstd frame: magic + frame header + single raw block
        // containing "hello".
        let data = b"hello";
        let mut frame = Vec::new();
        // Magic number
        frame.extend_from_slice(&ZSTD_MAGIC.to_le_bytes());
        // Frame descriptor: single_segment=1, no checksum, no dict, fcs_flag=0,
        // so the Frame_Content_Size field is 1 byte.
        frame.push(0x20); // single_segment_flag set
        frame.push(5); // FCS = 5 (length of "hello")
        // Block header (3 bytes, little-endian): bit 0 = last_block,
        // bits 1-2 = block type, bits 3-23 = block size.
        let bh = 1u32 | (0u32 << 1) | (5u32 << 3);
        frame.push((bh & 0xFF) as u8);
        frame.push(((bh >> 8) & 0xFF) as u8);
        frame.push(((bh >> 16) & 0xFF) as u8);
        // Block content
        frame.extend_from_slice(data);

        let result = decompress(&frame).unwrap();
        assert_eq!(result, data);
    }

    #[test]
    fn test_roundtrip_rle() {
        // Frame with an RLE block: 10 copies of byte 0x42
        let mut frame = Vec::new();
        frame.extend_from_slice(&ZSTD_MAGIC.to_le_bytes());
        frame.push(0x20); // single_segment_flag set
        frame.push(10); // FCS = 10
        // Block header: last_block=1, type=RLE(1), size=10
        let bh = 1u32 | (1u32 << 1) | (10u32 << 3);
        frame.push((bh & 0xFF) as u8);
        frame.push(((bh >> 8) & 0xFF) as u8);
        frame.push(((bh >> 16) & 0xFF) as u8);
        // Single RLE byte
        frame.push(0x42);

        let result = decompress(&frame).unwrap();
        assert_eq!(result, vec![0x42; 10]);
    }

    #[test]
    fn test_roundtrip_with_compressor() {
        // Use the crate's own compressor to produce a valid zstd frame,
        // then decompress with our decoder.
        let data = b"Hello, world! This is a test of the zstd compression and decompression round-trip. \
                      The quick brown fox jumps over the lazy dog. \
                      AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA \
                      BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB \
                      Hello, world! This is a test of the zstd compression and decompression round-trip.";
        let compressed = crate::compress::compress_to_vec(data);
        let decompressed = decompress(&compressed).unwrap();
        assert_eq!(decompressed, data);
    }

    #[test]
    fn test_roundtrip_larger() {
        // 16 KiB of repetitive text (a 97-line cycle) so the compressor has
        // plenty of matches to emit. Note that `Vec::with_capacity` alone would
        // leave the input empty; the loop below actually fills it.
        const SIZE: usize = 16384;
        let mut data: Vec<u8> = Vec::with_capacity(SIZE);
        let mut line = 0u32;
        while data.len() < SIZE {
            data.extend_from_slice(
                format!("line {} of the larger zstd round-trip test\n", line % 97).as_bytes(),
            );
            line += 1;
        }
        data.truncate(SIZE);
        assert_eq!(data.len(), SIZE);

        let compressed = crate::compress::compress_to_vec(&data);
        let decompressed = decompress(&compressed).unwrap();
        assert_eq!(decompressed, data);
    }
}