Skip to main content

sit_algos/
lzah.rs

1use std::io::{self, Seek};
2
3use bitstream_io::BitRead;
4
5use super::huffman_decoder::HuffmanDecoder;
6
7#[derive(Debug, thiserror::Error)]
8pub enum Error {
9    #[error(transparent)]
10    Io(#[from] io::Error),
11
12    #[error("Encountered an invalid tree path while reading next symbol")]
13    InvalidCode,
14}
15
16impl From<Error> for io::Error {
17    fn from(val: Error) -> Self {
18        match val {
19            Error::Io(error) => error,
20            err => io::Error::other(err),
21        }
22    }
23}
24
/// A token decoded from the compressed stream.
pub(crate) enum LiteralOrOffset {
    /// A single byte to emit unchanged.
    Literal(u8),
    /// A back-reference: copy `length` bytes, starting `offset` bytes before the
    /// current output position.
    Offset { length: u16, offset: u32 },
}
29
30#[derive(Debug, Copy, Clone)]
31pub struct TreeNode {
32    parent_ptr: Option<NodeId>,
33    left_ptr: Option<NodeId>,
34    right_ptr: Option<NodeId>,
35    index: TreeIdx,
36    value: u16,
37    frequency: usize,
38}
39
/// Identity of a tree node: an index into the reader's `nodes` array.
///
/// A node keeps its id when the tree is rearranged; only its position (`TreeIdx`)
/// and its links change.
#[derive(Debug, Copy, Clone, Default, PartialEq, Eq)]
struct NodeId(u16);

impl NodeId {
    /// Returns the id widened to `usize` for indexing.
    #[inline]
    fn raw(&self) -> usize {
        usize::from(self.0)
    }
}
49
/// Position of a node in the reader's `tree` array; position 0 is the root.
///
/// Ordering compares positions, so a smaller `TreeIdx` lies closer to the root
/// end of the array.
#[derive(Debug, Copy, Clone, Default, PartialEq, Eq, PartialOrd, Ord)]
struct TreeIdx(u16);

impl TreeIdx {
    /// Returns `true` for position 0 (the root).
    #[inline]
    fn is_root(&self) -> bool {
        self.0 == ROOT.0
    }

    /// Returns the position just before this one, or `None` at the root.
    #[inline]
    fn previous(&self) -> Option<Self> {
        if self.is_root() {
            return None;
        }
        Some(Self(self.0 - 1))
    }

    /// Returns the position widened to `usize` for indexing.
    #[inline]
    fn index(&self) -> usize {
        usize::from(self.0)
    }
}

/// Tree position of the root node.
const ROOT: TreeIdx = TreeIdx(0);
75
76impl Default for TreeNode {
77    fn default() -> Self {
78        Self {
79            parent_ptr: None,
80            left_ptr: None,
81            right_ptr: None,
82            index: Default::default(),
83            value: Default::default(),
84            frequency: usize::MAX,
85        }
86    }
87}
88
/// Size of the LZ sliding window in bytes. Positions are wrapped with
/// `WINDOW_SIZE - 1` as a bit mask, so this must remain a power of two.
const WINDOW_SIZE: usize = 4096;
/// Number of symbols (tree leaves): 256 literal bytes plus 58 match lengths.
const LEAF_COUNT: usize = 314;
/// Total nodes in a full binary tree with `LEAF_COUNT` leaves.
const NODE_COUNT: usize = 2 * LEAF_COUNT - 1;
92
/// LZ77 history window plus the state of the back-reference currently being copied.
pub(crate) struct LzSlidingWindow<const WINDOW_SIZE: usize> {
    /// Most recently produced bytes, addressed by output position modulo the window size.
    pub(crate) window: [u8; WINDOW_SIZE],
    /// Mask applied to positions before indexing; expected to be `WINDOW_SIZE - 1`.
    pub(crate) window_mask: usize,
    /// Bytes of the active match still to be copied; 0 when no match is active.
    pub(crate) match_len: u16,
    /// Source position of the next byte to copy (masked with `window_mask` before use).
    pub(crate) match_offset: i32,
}
99
100const fn default_window() -> LzSlidingWindow<WINDOW_SIZE> {
101    LzSlidingWindow {
102        match_len: 0,
103        match_offset: 0,
104        window_mask: WINDOW_SIZE - 1,
105        window: default_window_contents(),
106    }
107}
108
109const fn default_window_contents() -> [u8; WINDOW_SIZE] {
110    // Use cfor macros to loop at constant time
111    use cfor::cfor;
112
113    let mut window = [0u8; WINDOW_SIZE];
114
115    let mut cur = 0;
116
117    // "Add" 18 leading zeros
118    cur += 18;
119
120    // Add 13 repetitions of each byte
121    cfor! {let mut i=0; i < 256; i += 1; {
122        cfor!{let mut j=0; j < 13; j+=1; {
123            window[cur + i * 13 + j] = i as u8;
124        }}
125    }}
126    cur += 13 * 256;
127
128    // Add sequence of increasing bytes
129    cfor! {let mut i=0; i < 256; i += 1; {
130        window[cur + i] = i as u8;
131    }}
132    cur += 256;
133
134    // Add sequence of decreasing bytes
135    cfor! {let mut i=0; i < 256; i+=1; {
136        window[cur + i] = 255 - i as u8;
137    }}
138    cur += 256;
139
140    // "Add" a run of 128 zeros
141    cur += 128;
142
143    // Fill rest of window (110 bytes) with ascii spaces
144    cfor! {let mut i=0; i < 110; i += 1; {
145        window[cur + i] = b' ';
146    }}
147    cur += 110;
148
149    if cur != WINDOW_SIZE {
150        panic!("Something went wrong during window initialization");
151    }
152
153    window
154}
155
156impl Default for LzSlidingWindow<WINDOW_SIZE> {
157    fn default() -> Self {
158        Self {
159            window: default_window_contents(),
160            window_mask: WINDOW_SIZE - 1,
161            match_len: 0,
162            match_offset: 0,
163        }
164    }
165}
166
167impl<const WINDOW_SIZE: usize> LzSlidingWindow<WINDOW_SIZE> {
168    #[inline]
169    pub(crate) fn is_empty(&self) -> bool {
170        self.match_len == 0
171    }
172
173    #[inline]
174    pub(crate) fn update(&mut self, pos: usize, val: LiteralOrOffset) -> u8 {
175        match val {
176            LiteralOrOffset::Literal(lit) => {
177                self.window[pos & self.window_mask] = lit;
178                lit
179            }
180            LiteralOrOffset::Offset { length, offset } => {
181                self.match_len = length;
182                self.match_offset = pos as i32 - offset as i32;
183
184                self.next(pos)
185            }
186        }
187    }
188
189    #[inline]
190    pub(crate) fn next(&mut self, pos: usize) -> u8 {
191        self.match_len -= 1;
192        let byte = self.window[self.match_offset as usize & self.window_mask];
193        self.match_offset += 1;
194        self.window[pos & self.window_mask] = byte;
195
196        byte
197    }
198}
199
200/// Reader for LZ77+Huffman with Adaptive Huffman coding streams
201pub struct LzahReader<R: io::Read + io::Seek> {
202    inner: bitstream_io::BitReader<R, bitstream_io::BigEndian>,
203    uncompressed_size: u64,
204    nodes: [TreeNode; NODE_COUNT],
205    tree: [NodeId; NODE_COUNT],
206
207    pos: usize,
208    decoder: HuffmanDecoder,
209    win: LzSlidingWindow<WINDOW_SIZE>,
210}
211
212impl<R: io::Read + io::Seek> LzahReader<R> {
213    pub fn new(inner: R, uncompressed_size: u64) -> Self {
214        let mut decoder = HuffmanDecoder::initialize(
215            &[
216                3, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7,
217                7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8,
218                8, 8, 8, 8, 8, 8, 8, 8,
219            ],
220            8,
221            true,
222        )
223        .unwrap();
224        decoder.make_table(false);
225
226        let mut me = Self {
227            inner: bitstream_io::BitReader::<_, bitstream_io::BigEndian>::new(inner),
228            nodes: [Default::default(); NODE_COUNT],
229            tree: [NodeId(0); NODE_COUNT],
230            decoder,
231            pos: 0,
232            win: default_window(),
233            uncompressed_size,
234        };
235        me.reset();
236        me
237    }
238
239    fn reset(&mut self) {
240        self.nodes = [Default::default(); NODE_COUNT];
241        self.tree = [Default::default(); NODE_COUNT];
242
243        // Initialize tree
244        self.tree
245            .iter_mut()
246            .enumerate()
247            .for_each(|(idx, node)| *node = NodeId(idx as u16));
248
249        // Initialize leaves
250        for i in 0..LEAF_COUNT {
251            let node = NodeId(NODE_COUNT as u16 - 1 - i as u16);
252            self.node_mut(node).index = TreeIdx(node.0);
253            self.node_mut(node).frequency = 1;
254            self.node_mut(node).value = i as u16;
255        }
256
257        // Initialize intermediate nodes and hierarchy
258        for i in (0..(LEAF_COUNT - 1)).rev() {
259            let parent = NodeId(i as u16);
260            let left = NodeId(i as u16 * 2 + 1);
261            let right = NodeId(i as u16 * 2 + 2);
262
263            self.node_mut(parent).index = TreeIdx(i as u16);
264            self.node_mut(parent).left_ptr = Some(left);
265            self.node_mut(parent).right_ptr = Some(right);
266            self.node_mut(parent).frequency =
267                self.node(left).frequency + self.node(right).frequency;
268
269            self.node_mut(left).parent_ptr = Some(parent);
270            self.node_mut(right).parent_ptr = Some(parent);
271        }
272    }
273
274    pub fn into_inner(self) -> R {
275        self.inner.into_reader()
276    }
277
278    #[inline]
279    fn next(&mut self) -> Result<LiteralOrOffset, Error> {
280        let mut node = self.tree_lookup(ROOT);
281        while self.node(node).left_ptr.is_some() || self.node(node).right_ptr.is_some() {
282            node = match self.inner.read_bit() {
283                Ok(true) => self.node(node).left_ptr.unwrap(),
284                Ok(false) => self.node(node).right_ptr.unwrap(),
285                Err(e) => return Err(e)?,
286            }
287        }
288
289        if self.tree_node(ROOT).frequency == 0x8000 {
290            self.reconstruct_tree();
291        }
292
293        self.update_node(node);
294
295        let literal = self.node(node).value;
296        if literal < 0x100 {
297            return Ok(LiteralOrOffset::Literal(literal as u8));
298        }
299
300        let length = literal - 0x100 + 3;
301        let highbits = self.decoder.next_symbol(&mut self.inner)?;
302        let lowbits = self.inner.read::<6, u32>()?;
303        let offset = ((highbits as u32) << 6) + lowbits + 1;
304
305        Ok(LiteralOrOffset::Offset { length, offset })
306    }
307
308    #[inline]
309    fn update_node(&mut self, node: NodeId) {
310        let mut node = node;
311        loop {
312            self.node_mut(node).frequency += 1;
313
314            if self.node(node).parent_ptr.is_none() {
315                break;
316            }
317
318            self.rearrange_node(node);
319            // Re-fetch node's parent as it might have been changed during rearrangment
320            node = self.node(node).parent_ptr.unwrap();
321        }
322    }
323
324    fn rearrange_node(&mut self, node: NodeId) {
325        let TreeNode {
326            index: node_idx,
327            frequency,
328            ..
329        } = self.node(node);
330
331        // Find ancestor in tree with lower frequency than `node`
332        let mut ancestor = node_idx;
333        while let Some(prev) = ancestor.previous() {
334            if self.tree_node(prev).frequency < frequency {
335                ancestor = prev;
336            } else {
337                break;
338            }
339        }
340
341        // Move `node` closer to root if we've found a better spot
342        if ancestor < node_idx {
343            self.swap_nodes(ancestor, node_idx);
344        }
345    }
346
347    fn swap_nodes(&mut self, node_1_idx: TreeIdx, node_2_idx: TreeIdx) {
348        let node_1 = self.tree_lookup(node_1_idx);
349        let parent_1 = self.node(node_1).parent_ptr;
350
351        let node_2 = self.tree_lookup(node_2_idx);
352        let parent_2 = self.node(node_2).parent_ptr;
353
354        // Determine which branch each node is in before making any changes, in case they both belong to the same node
355        let node_1_is_right_child = parent_1
356            .map(|parent| self.node(parent).right_ptr == Some(node_1))
357            .unwrap_or_default();
358        let node_2_is_right_child = parent_2
359            .map(|parent| self.node(parent).right_ptr == Some(node_2))
360            .unwrap_or_default();
361
362        // Update child pointer of node1's parent to point to node2
363        if let Some(parent) = parent_1 {
364            if node_1_is_right_child {
365                self.node_mut(parent).right_ptr = Some(node_2);
366            } else {
367                self.node_mut(parent).left_ptr = Some(node_2);
368            }
369        }
370
371        // Update child pointer of node2's parent to point to node1
372        if let Some(parent) = parent_2 {
373            if node_2_is_right_child {
374                self.node_mut(parent).right_ptr = Some(node_1);
375            } else {
376                self.node_mut(parent).left_ptr = Some(node_1);
377            }
378        }
379
380        // Update parent pointers
381        self.node_mut(node_1).parent_ptr = parent_2;
382        self.node_mut(node_2).parent_ptr = parent_1;
383
384        // Update self-position in tree
385        self.node_mut(node_1).index = node_2_idx;
386        self.node_mut(node_2).index = node_1_idx;
387
388        // Update references in tree
389        self.tree[node_1_idx.index()] = node_2;
390        self.tree[node_2_idx.index()] = node_1;
391    }
392
393    fn reconstruct_tree(&mut self) {
394        let mut leaf_nodes = Vec::with_capacity(LEAF_COUNT);
395
396        // Collect all leaf nodes and half their frequency
397        for index in 0..NODE_COUNT {
398            let node = self.tree[index];
399            if self.is_leaf(node) {
400                self.node_mut(node).frequency = self.node(node).frequency.div_ceil(2);
401                leaf_nodes.push(node);
402            }
403        }
404        assert_eq!(leaf_nodes.len(), LEAF_COUNT);
405
406        let mut leaf_index = LEAF_COUNT as i32 - 1;
407        let mut branch_index = LEAF_COUNT as i32 - 2;
408        let mut node_index: i32 = NODE_COUNT as i32 - 1;
409        let mut pair_index: i32 = NODE_COUNT as i32 - 2;
410
411        while node_index >= 0 {
412            while node_index >= pair_index {
413                let leaf = leaf_nodes[leaf_index as usize];
414                self.tree[node_index as usize] = leaf;
415                self.node_mut(leaf).index = TreeIdx(node_index as u16);
416                node_index -= 1;
417                leaf_index -= 1;
418            }
419
420            let branch = NodeId(branch_index as u16);
421            let left_child = self.tree[pair_index as usize];
422            let right_child = self.tree[pair_index as usize + 1];
423            self.node_mut(branch).left_ptr = Some(left_child);
424            self.node_mut(branch).right_ptr = Some(right_child);
425            self.node_mut(left_child).parent_ptr = Some(branch);
426            self.node_mut(right_child).parent_ptr = Some(branch);
427            self.node_mut(branch).frequency =
428                self.node(left_child).frequency + self.node(right_child).frequency;
429            branch_index -= 1;
430
431            while leaf_index >= 0
432                && self.node(leaf_nodes[leaf_index as usize]).frequency
433                    <= self.node(branch).frequency
434            {
435                let leaf = leaf_nodes[leaf_index as usize];
436                self.tree[node_index as usize] = leaf;
437                self.node_mut(leaf).index = TreeIdx(node_index as u16);
438                node_index -= 1;
439                leaf_index -= 1;
440            }
441
442            self.tree[node_index as usize] = branch;
443            self.node_mut(branch).index = TreeIdx(node_index as u16);
444
445            node_index -= 1;
446            pair_index -= 2;
447        }
448        self.node_mut(self.tree_lookup(ROOT)).parent_ptr = None;
449    }
450
451    #[inline]
452    fn is_leaf(&self, node: NodeId) -> bool {
453        !self.has_left_child(node) && !self.has_right_child(node)
454    }
455
456    #[inline]
457    fn has_left_child(&self, node: NodeId) -> bool {
458        self.node(node).left_ptr.is_some()
459    }
460
461    #[inline]
462    fn has_right_child(&self, node: NodeId) -> bool {
463        self.node(node).right_ptr.is_some()
464    }
465
466    #[inline]
467    fn produce_next_byte(&mut self) -> Result<u8, Error> {
468        if self.win.is_empty() {
469            let token = self.next()?;
470            return Ok(self.win.update(self.pos, token));
471        }
472
473        Ok(self.win.next(self.pos))
474    }
475
476    #[inline]
477    fn node(&self, node: NodeId) -> TreeNode {
478        self.nodes[node.raw()]
479    }
480
481    #[inline]
482    fn node_mut(&mut self, node: NodeId) -> &mut TreeNode {
483        &mut self.nodes[node.raw()]
484    }
485
486    #[inline]
487    fn tree_node(&self, idx: TreeIdx) -> TreeNode {
488        self.node(self.tree[idx.index()])
489    }
490
491    #[inline]
492    fn tree_lookup(&self, idx: TreeIdx) -> NodeId {
493        self.tree[idx.index()]
494    }
495}
496
497impl<R: io::Read + io::Seek> io::Read for LzahReader<R> {
498    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
499        for (idx, b) in buf.iter_mut().enumerate() {
500            if self.stream_position()? >= self.stream_len()? {
501                return Ok(idx);
502            }
503
504            match self.produce_next_byte() {
505                Ok(byte) => {
506                    *b = byte;
507                    self.pos += 1;
508                }
509                Err(e) => return Err(e)?,
510            }
511        }
512
513        Ok(buf.len())
514    }
515}
516
517impl<R: io::Read + io::Seek> io::Seek for LzahReader<R> {
518    fn seek(&mut self, _: io::SeekFrom) -> io::Result<u64> {
519        todo!()
520    }
521
522    #[inline]
523    fn stream_position(&mut self) -> io::Result<u64> {
524        Ok(self.pos as u64)
525    }
526
527    #[inline]
528    fn stream_len(&mut self) -> io::Result<u64> {
529        Ok(self.uncompressed_size)
530    }
531}