Skip to main content

sit_algos/
huffman_decoder.rs

1use bitstream_io::{BitRead, BitReader, Endianness};
2use std::io::{self};
3
4#[derive(thiserror::Error, Debug)]
5pub enum InitializationError {
6    #[error("Prefix already exists")]
7    DuplicatedPrefix,
8    #[error("Invalid repeat position")]
9    InvalidRepeatPosition,
10    #[error("Invalid repeating code")]
11    InvaildRepeatingCode,
12}
13
14impl From<InitializationError> for io::Error {
15    fn from(value: InitializationError) -> Self {
16        io::Error::other(value)
17    }
18}
19
20#[derive(thiserror::Error, Debug)]
21pub enum DecompressionError {
22    #[error("Invalid prefix code when getting next symbol [length]")]
23    InvalidPrefixCodeLength,
24    #[error("Invalid prefix code when getting next symbol [code]")]
25    InvalidPrefixCodeCode,
26}
27
28impl From<DecompressionError> for io::Error {
29    fn from(value: DecompressionError) -> Self {
30        io::Error::other(value)
31    }
32}
33
/// One node of the decoder's flattened binary tree.
///
/// A freshly created (empty) node has `left == -1` and `right == -2`. A leaf stores its
/// symbol in both branches, so `left == right`. Otherwise each non-negative branch is the
/// index of a child node, and a negative branch means the branch is still open.
#[derive(Debug, Clone)]
struct HuffmanTreeNode {
    /// Branch followed on a 0 bit.
    left: i32,
    /// Branch followed on a 1 bit.
    right: i32,
}
39
/// One entry of the lookup table built by `HuffmanDecoder::make_table`.
///
/// * `length <= 0`: no valid code starts with the indexing bits; `next_symbol` rejects it.
/// * `1..=table_size`: `value` is the decoded symbol and `length` its code length.
/// * `table_size + 1`: the code is longer than the table index; `value` is the tree node
///   reached after `table_size` bits, from which decoding continues bit by bit.
#[derive(Clone, Default)]
struct HuffmanTableRow {
    length: i32,
    value: i32,
}
45
46#[derive(Clone)]
47pub struct HuffmanDecoder {
48    min_length: usize,
49    max_length: usize,
50
51    tree: Vec<HuffmanTreeNode>,
52
53    table: Option<Vec<HuffmanTableRow>>,
54    table_size: usize,
55}
56
57impl Default for HuffmanDecoder {
58    fn default() -> Self {
59        Self::new()
60    }
61}
62
63impl HuffmanDecoder {
64    pub fn new() -> Self {
65        let mut me = Self {
66            min_length: usize::MAX,
67            max_length: usize::MIN,
68            tree: Vec::new(),
69            table: None,
70            table_size: 0,
71        };
72        me.new_node();
73        me
74    }
75
76    pub fn initialize(
77        lengths: &[isize],
78        max_code_length: isize,
79        zeros: bool,
80    ) -> Result<Self, InitializationError> {
81        let mut me = Self::new();
82        let mut code = 0;
83        let mut unhandled_symbols = lengths.len();
84
85        for length in 1..=max_code_length {
86            for (i, cur_len) in lengths.iter().enumerate() {
87                if *cur_len != length {
88                    continue;
89                }
90
91                me.add_value(
92                    i as i32,
93                    if zeros { code } else { !code },
94                    length as usize,
95                    length as usize,
96                )?;
97
98                code += 1;
99
100                unhandled_symbols -= 1;
101                if unhandled_symbols == 0 {
102                    break;
103                }
104            }
105
106            code <<= 1;
107        }
108
109        Ok(me)
110    }
111
112    pub fn add_value(
113        &mut self,
114        value: i32,
115        code: u32,
116        length: usize,
117        repeat_pos: usize,
118    ) -> Result<(), InitializationError> {
119        self.max_length = self.max_length.max(length);
120        self.min_length = self.min_length.min(length);
121
122        let repeat_pos = length as isize - 1 - repeat_pos as isize;
123        let mut last_node = 0;
124
125        let codest = (((repeat_pos - 1) as u32) >> 1) & 3;
126        if repeat_pos == 0 || (repeat_pos >= 0 && (codest == 0 || codest == 3)) {
127            return Err(InitializationError::InvalidRepeatPosition);
128        }
129
130        let mut bitpos = length as isize - 1;
131        loop {
132            if bitpos < 0 {
133                break;
134            }
135
136            let bit = ((code >> bitpos) & 1) != 0;
137            if self.is_leaf_node(last_node) {
138                return Err(InitializationError::DuplicatedPrefix);
139            };
140
141            if bitpos == repeat_pos {
142                if !self.is_open_branch(last_node, bit) {
143                    return Err(InitializationError::InvaildRepeatingCode);
144                };
145
146                let repeat_node = self.new_node();
147                let next_node = self.new_node();
148
149                self.set_branch(last_node, bit, repeat_node);
150                self.set_branch(repeat_node, bit, repeat_node);
151                self.set_branch(repeat_node, !bit, next_node);
152
153                last_node = next_node;
154
155                bitpos += 1;
156            } else {
157                if self.is_open_branch(last_node, bit) {
158                    let new_node = self.new_node();
159                    self.set_branch(last_node, bit, new_node);
160                }
161                last_node = self.branch(last_node, bit);
162            }
163
164            bitpos -= 1;
165        }
166
167        if !self.is_empty_node(last_node) {
168            return Err(InitializationError::DuplicatedPrefix);
169        }
170
171        self.set_leaf_value(last_node, value);
172
173        Ok(())
174    }
175
176    pub(crate) fn new_node(&mut self) -> i32 {
177        self.tree.push(HuffmanTreeNode {
178            left: -1,
179            right: -2,
180        });
181        self.tree.len() as i32 - 1
182    }
183
184    pub(crate) fn set_leaf_value(&mut self, node: i32, value: i32) {
185        self.set_branch(node, false, value);
186        self.set_branch(node, true, value);
187    }
188
189    fn leaf_value(&self, node: i32) -> i32 {
190        assert!(self.branch(node, false) == self.branch(node, true));
191        self.branch(node, false)
192    }
193
194    fn is_empty_node(&self, node: i32) -> bool {
195        self.branch(node, false) == -1 && self.branch(node, true) == -2
196    }
197
198    fn is_leaf_node(&self, node: i32) -> bool {
199        self.branch(node, false) == self.branch(node, true)
200    }
201
202    fn is_open_branch(&self, node: i32, right_branch: bool) -> bool {
203        if right_branch {
204            self.tree[node as usize].right < 0
205        } else {
206            self.tree[node as usize].left < 0
207        }
208    }
209
210    pub(crate) fn set_branch(&mut self, node: i32, right_branch: bool, value: i32) {
211        if right_branch {
212            self.tree[node as usize].right = value;
213        } else {
214            self.tree[node as usize].left = value;
215        }
216    }
217
218    fn branch(&self, node: i32, right_branch: bool) -> i32 {
219        if right_branch {
220            self.tree[node as usize].right
221        } else {
222            self.tree[node as usize].left
223        }
224    }
225
226    fn is_invalid_node(&self, node: i32) -> bool {
227        node < 0
228    }
229
230    pub fn add_value_lf(
231        &mut self,
232        value: i32,
233        code: u32,
234        length: usize,
235        repeat_pos: usize,
236    ) -> Result<(), InitializationError> {
237        self.add_value(value, reverse_n(code, length), length, repeat_pos)
238    }
239
240    #[inline]
241    pub fn next_symbol<R: io::Read + io::Seek, E: Endianness>(
242        &mut self,
243        input: &mut bitstream_io::BitReader<R, E>,
244    ) -> io::Result<i32> {
245        let Some(table) = self.table.as_ref() else {
246            panic!("Huffman: search table not built");
247        };
248
249        let bits: u32 = input.peek_word(self.table_size as u32)?;
250        let HuffmanTableRow { length, value } = table[bits as usize];
251        if length <= 0 {
252            return Err(DecompressionError::InvalidPrefixCodeLength)?;
253        }
254
255        if length <= self.table_size as i32 {
256            input.skip(length as u32)?;
257            return Ok(value);
258        }
259
260        input.skip(self.table_size as u32)?;
261
262        let mut node = value;
263        loop {
264            if self.is_leaf_node(node) {
265                break;
266            }
267            let bit = input.read_bit()?;
268            if self.is_open_branch(node, bit) {
269                return Err(DecompressionError::InvalidPrefixCodeCode)?;
270            }
271            node = self.branch(node, bit);
272        }
273
274        Ok(self.leaf_value(node))
275    }
276
277    fn make_table_recursive_le(&mut self, node: i32, table: &mut [HuffmanTableRow], depth: i32) {
278        let curr_table_size = 1 << (self.table_size as i32 - depth);
279        let curr_stride = 1 << depth;
280
281        if self.is_invalid_node(node) {
282            for i in 0..curr_table_size {
283                table[i * curr_stride].length = -1;
284            }
285        } else if self.is_leaf_node(node) {
286            for i in 0..curr_table_size {
287                table[i * curr_stride].length = depth;
288                table[i * curr_stride].value = self.leaf_value(node);
289            }
290        } else if depth == self.table_size as i32 {
291            table[0].length = self.table_size as i32 + 1;
292            table[0].value = node;
293        } else {
294            self.make_table_recursive_le(self.branch(node, false), table, depth + 1);
295            let size = table.len();
296            self.make_table_recursive_le(
297                self.branch(node, true),
298                &mut table[curr_stride..size],
299                depth + 1,
300            );
301        }
302    }
303
304    fn make_table_recursive_be(&mut self, node: i32, table: &mut [HuffmanTableRow], depth: i32) {
305        let curr_table_size = 1 << (self.table_size as i32 - depth);
306
307        if self.is_invalid_node(node) {
308            table
309                .iter_mut()
310                .take(curr_table_size)
311                .for_each(|i| i.length = -1);
312        } else if self.is_leaf_node(node) {
313            table.iter_mut().take(curr_table_size).for_each(|i| {
314                i.length = depth;
315                i.value = self.leaf_value(node);
316            });
317        } else if depth == self.table_size as i32 {
318            table[0].length = self.table_size as i32 + 1;
319            table[0].value = node;
320        } else {
321            self.make_table_recursive_be(self.branch(node, false), table, depth + 1);
322            let size = table.len();
323            self.make_table_recursive_be(
324                self.branch(node, true),
325                &mut table[(curr_table_size / 2)..size],
326                depth + 1,
327            );
328        }
329    }
330
331    pub fn make_table(&mut self, little_endian: bool) {
332        self.table_size = if self.max_length < self.min_length || self.max_length >= 10 {
333            10
334        } else {
335            self.max_length
336        };
337
338        let mut table = vec![Default::default(); 1 << self.table_size];
339        if little_endian {
340            self.make_table_recursive_le(0, &mut table, 0)
341        } else {
342            self.make_table_recursive_be(0, &mut table, 0)
343        }
344        self.table = Some(table);
345    }
346}
347
/// Reverses the order of all 32 bits of `value`.
#[inline]
fn reverse(value: u32) -> u32 {
    value.reverse_bits()
}

/// Reverses the low `length` bits of `value`; all higher bits are discarded.
///
/// `HuffmanDecoder::add_value_lf` uses this to turn a code whose first bit is its
/// least-significant bit into the MSB-first order that `add_value` walks.
/// A `length` of 0 yields 0. Previously this shifted a `u32` by 32, which panics in
/// debug builds and returns an unrelated value in release builds.
///
/// # Panics
/// If `length` exceeds 32.
#[inline]
fn reverse_n(value: u32, length: usize) -> u32 {
    match length {
        0 => 0,
        1..=32 => reverse(value) >> (32 - length),
        _ => panic!("reverse_n: cannot reverse {length} bits of a 32-bit value"),
    }
}
362
/// Word-sized bit reads, as used by `HuffmanDecoder::next_symbol` for table lookups.
pub trait ReadWord {
    /// Returns the next `bits` bits without advancing the reader.
    fn peek_word(&mut self, bits: u32) -> io::Result<u32>;

    /// Returns the next `bits` bits and advances the reader past them.
    fn read_word(&mut self, bits: u32) -> io::Result<u32>;
}
367
368impl<R: io::Read + io::Seek, E: Endianness> ReadWord for BitReader<R, E> {
369    // TODO: Peaking with a backwards seek like this is verify expensive if the underlying
370    // stream is itself not easily seekable like binhex's sixbit rle reader for example.
371    // This bin reader should probably be wrapped in way that caches the peeked bits without doing
372    // a backwards seek!
373    fn peek_word(&mut self, bits: u32) -> io::Result<u32> {
374        let start = self.position_in_bits().unwrap();
375        match self.read_var::<u32>(bits) {
376            Ok(result) => {
377                self.seek_bits(io::SeekFrom::Current(-(bits as i64)))?;
378                Ok(result)
379            }
380            Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => {
381                let end = self.seek_bits(io::SeekFrom::End(0)).unwrap();
382                self.seek_bits(io::SeekFrom::Start(start)).unwrap();
383
384                if start < end {
385                    log::debug!("Adjusting to {} bit peek", (end - start));
386                    let available_bits = end - start;
387                    let result = self.read_var::<u32>(available_bits as u32)?
388                        << (bits as u64 - available_bits);
389                    self.seek_bits(io::SeekFrom::Current(-(available_bits as i64)))?;
390                    return Ok(result);
391                }
392                Err(e)
393            }
394            Err(e) => Err(e),
395        }
396    }
397
398    #[inline]
399    fn read_word(&mut self, bits: u32) -> io::Result<u32> {
400        let start = self.position_in_bits().unwrap();
401
402        match self.read_var::<u32>(bits) {
403            Ok(result) => Ok(result),
404            Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => {
405                let end = self.seek_bits(io::SeekFrom::End(0)).unwrap();
406                if start < end {
407                    log::debug!("Adjusting to {} bit read", (end - start));
408                    let available_bits = end - start;
409                    self.seek_bits(io::SeekFrom::Start(start)).unwrap();
410                    return Ok(self.read_var::<u32>(available_bits as u32)?
411                        << (bits as u64 - available_bits));
412                }
413
414                Err(e)
415            }
416            Err(e) => Err(e),
417        }
418    }
419}
420
#[cfg(test)]
mod test {
    use super::*;
    use bitstream_io::{BigEndian, BitReader};
    use std::io::Cursor;

    /// A complete 64-symbol code (Kraft sum exactly 1) must build without errors.
    #[test]
    fn successful_initialization() {
        assert!(
            HuffmanDecoder::initialize(
                &[
                    3, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7,
                    7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8,
                    8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
                ],
                8,
                true,
            )
            .is_ok()
        )
    }

    /// Three 1-bit codes cannot coexist: the third one lands on the first one's leaf.
    #[test]
    fn oversubscribed_lengths_are_rejected() {
        assert!(matches!(
            HuffmanDecoder::initialize(&[1, 1, 1], 1, true),
            Err(InitializationError::DuplicatedPrefix)
        ));
    }

    /// End-to-end decode through the big-endian lookup table.
    #[test]
    fn decodes_big_endian_symbols() {
        // Canonical code for lengths [1, 2, 2]: 0 -> "0", 1 -> "10", 2 -> "11".
        let mut decoder = HuffmanDecoder::initialize(&[1, 2, 2], 2, true).unwrap();
        decoder.make_table(false);

        // "0" "10" "11" "0" followed by zero padding: 0b0101_1000, 0b0000_0000.
        let mut reader = BitReader::endian(Cursor::new(vec![0x58u8, 0x00]), BigEndian);
        let decoded: Vec<i32> = (0..4)
            .map(|_| decoder.next_symbol(&mut reader).unwrap())
            .collect();
        assert_eq!(decoded, vec![0, 1, 2, 0]);
    }
}