Skip to main content

sit_algos/
lzw.rs

1use std::{
2    cmp::Ordering,
3    io::{self, Read},
4};
5
6use bitstream_io::BitRead2;
7
8#[derive(Debug, thiserror::Error)]
9pub enum Error {
10    #[error(transparent)]
11    Io(#[from] io::Error),
12    #[error("Invalid LZW code encountered")]
13    InvalidCode,
14    #[error("Too many LZW codes encountered")]
15    TooManyCodes,
16}
17
18impl From<Error> for io::Error {
19    fn from(val: Error) -> Self {
20        match val {
21            Error::Io(error) => error,
22            Error::InvalidCode => io::Error::other("Invalid LZW code encountered"),
23            Error::TooManyCodes => io::Error::other("Too many LZW codes encountered"),
24        }
25    }
26}
27
/// LZW dictionary with room for `N` codes, plus scratch space for decoded output.
///
/// Each code is stored as a link toward its root:
/// - `values[c]` is the last byte of code `c`'s string;
/// - `parents[c]` is the code for the rest of that string, or `usize::MAX`
///   when `c` is a single-byte root.
pub struct LzwTree<const N: usize> {
    /// Parent code of each entry, or `usize::MAX` for a root.
    parents: [usize; N],
    /// Final byte of each entry's string.
    values: [u8; N],
    /// Number of defined codes; also the index the next new entry is written to.
    symbol_count: usize,
    /// Current code width in bits.
    symbol_size: u32,
    /// Most recently decoded code, or `usize::MAX` right after construction/reset.
    previous_symbol: usize,
    /// Reusable buffer the latest code's expanded string is written into.
    buffer: Vec<u8>,
}
37
38impl<const N: usize> Default for LzwTree<N> {
39    fn default() -> Self {
40        if N < 256 {
41            panic!("Invalid LZWTree configuration");
42        }
43
44        let mut symbols = [0u8; N];
45        symbols
46            .iter_mut()
47            .take(256)
48            .enumerate()
49            .for_each(|(i, symbol)| *symbol = i as u8);
50
51        Self {
52            symbol_count: 256 + 1,
53            symbol_size: 9,
54            previous_symbol: usize::MAX,
55            parents: [usize::MAX; N],
56            buffer: Vec::with_capacity(1024),
57            values: symbols,
58        }
59    }
60}
61
62impl<const N: usize> LzwTree<N> {
63    fn reset(&mut self) {
64        self.symbol_count = 256 + 1;
65        self.previous_symbol = usize::MAX;
66        self.symbol_size = 9;
67    }
68
69    fn advance(&mut self, symbol: usize) -> Result<&[u8], Error> {
70        if self.previous_symbol == usize::MAX {
71            if symbol >= self.symbol_count {
72                return Err(Error::InvalidCode);
73            }
74
75            self.previous_symbol = symbol;
76        } else {
77            let value = match symbol.cmp(&self.symbol_count) {
78                Ordering::Less => self.find_first_byte(symbol),
79                Ordering::Equal => self.find_first_byte(self.previous_symbol),
80                Ordering::Greater => return Err(Error::InvalidCode),
81            };
82
83            let parent = self.previous_symbol;
84            self.previous_symbol = symbol;
85
86            if !self.full() {
87                self.parents[self.symbol_count] = parent;
88                self.values[self.symbol_count] = value;
89                self.symbol_count += 1;
90
91                if !self.full() && (self.symbol_count & (self.symbol_count - 1)) == 0 {
92                    self.symbol_size += 1;
93                }
94            } else {
95                log::warn!("Ignore overflowing code table, hopefully the block ends soon…");
96            }
97        }
98
99        let n = self.output_len();
100        if n > self.buffer.len() {
101            self.buffer = vec![0u8; n];
102        }
103
104        let mut i = n;
105        let mut symbol = self.previous_symbol;
106        loop {
107            match symbol {
108                usize::MAX => return Ok(&self.buffer[0..n]),
109                _ => {
110                    self.buffer[i - 1] = self.values[symbol];
111                    symbol = self.parents[symbol];
112                    i -= 1;
113                }
114            }
115        }
116    }
117
118    fn find_first_byte(&mut self, mut symbol: usize) -> u8 {
119        assert_ne!(symbol, usize::MAX);
120        loop {
121            match self.parents[symbol] {
122                usize::MAX => return self.values[symbol],
123                _ => symbol = self.parents[symbol],
124            }
125        }
126    }
127
128    fn full(&self) -> bool {
129        self.symbol_count == N
130    }
131
132    fn output_len(&self) -> usize {
133        let mut n = 0;
134        let mut symbol = self.previous_symbol;
135        loop {
136            match symbol {
137                usize::MAX => return n,
138                _ => {
139                    n += 1;
140                    symbol = self.parents[symbol]
141                }
142            }
143        }
144    }
145}
146
147pub struct LzwReader<R: io::Read> {
148    initialized: bool,
149    inner: bitstream_io::BitReader<R, bitstream_io::LittleEndian>,
150    tree: LzwTree<0x4000>,
151    symbol_counter: u32,
152    buffer: Vec<u8>,
153    buffer_pos: usize,
154
155    position: u64,
156    uncompressed_size: u64,
157}
158
159impl<R: io::Read> LzwReader<R> {
160    pub fn new(inner: R, uncompressed_size: u64) -> Self {
161        Self {
162            initialized: false,
163            inner: bitstream_io::BitReader::<_, bitstream_io::LittleEndian>::new(inner),
164            tree: Default::default(),
165            symbol_counter: 0,
166            buffer: Vec::new(),
167            buffer_pos: 0,
168
169            position: 0,
170            uncompressed_size,
171        }
172    }
173
174    pub fn into_inner(self) -> R {
175        self.inner.into_reader()
176    }
177
178    fn decode_chunk(&mut self) -> Result<&[u8], Error> {
179        loop {
180            self.symbol_counter += 1;
181            match self.inner.read(self.tree.symbol_size)? {
182                256u16 => {
183                    log::info!("End of block found");
184                    if !self.symbol_counter.is_multiple_of(8) {
185                        self.inner
186                            .skip(self.tree.symbol_size * (8 - (self.symbol_counter % 8)))?;
187                    }
188                    self.tree.reset();
189                    self.symbol_counter = 0;
190                }
191                symbol => return self.tree.advance(symbol as usize),
192            }
193        }
194    }
195}
196
197impl<R: io::Read> io::Read for LzwReader<R> {
198    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
199        if self.position >= self.uncompressed_size {
200            return Ok(0);
201        }
202
203        if !self.initialized {
204            self.initialized = true;
205            self.buffer = self.decode_chunk()?.to_vec();
206            self.buffer_pos = 0;
207        }
208
209        for (idx, byte) in buf.iter_mut().enumerate() {
210            // Copy already decoded data
211            if self.buffer_pos < self.buffer.len() {
212                *byte = self.buffer[self.buffer_pos];
213                self.buffer_pos += 1;
214                continue;
215            }
216
217            // Decode another chunk
218            self.buffer = match self.decode_chunk() {
219                Ok(buf) => buf.to_vec(),
220                Err(Error::Io(e)) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(idx),
221                Err(e) => return Err(e.into()),
222            };
223            self.buffer_pos = 0;
224
225            // Copy more data if we have it, otherwise report number of copied bytes thus far
226            if self.buffer_pos < self.buffer.len() {
227                *byte = self.buffer[self.buffer_pos];
228                self.buffer_pos += 1;
229                continue;
230            }
231
232            self.position += idx as u64;
233            return Ok(idx);
234        }
235
236        Ok(buf.len())
237    }
238}
239
240impl<R: io::Read> io::Seek for LzwReader<R> {
241    fn seek(&mut self, pos: io::SeekFrom) -> io::Result<u64> {
242        Ok(match pos {
243            io::SeekFrom::Current(0) => todo!(),
244            io::SeekFrom::Current(n) if n < 0 => todo!(),
245            io::SeekFrom::Current(x) => {
246                let mut buf = vec![0u8; x as usize];
247                self.read(&mut buf)? as u64
248            }
249            io::SeekFrom::End(_) => todo!(),
250            io::SeekFrom::Start(n) if n > self.position => {
251                self.seek(io::SeekFrom::Current(n as i64 - self.position as i64))?
252            }
253            _ => todo!(),
254        })
255    }
256
257    #[inline]
258    fn stream_position(&mut self) -> io::Result<u64> {
259        Ok(self.position)
260    }
261
262    #[inline]
263    fn stream_len(&mut self) -> io::Result<u64> {
264        Ok(self.uncompressed_size)
265    }
266}