Skip to main content

zrip_core/huffman/
encode.rs

1#![forbid(unsafe_code)]
2
3#[cfg(feature = "alloc")]
4use alloc::vec;
5#[cfg(feature = "alloc")]
6use alloc::vec::Vec;
7
8use super::primitives;
9use crate::huffman::{MAX_BITS, MAX_SYMBOL_VALUE};
10
/// Encoder-side Huffman table built from a sample of the data to compress.
///
/// The per-symbol arrays are indexed by byte value. Instances are created by
/// `HuffmanEncodeTable::from_data`.
pub struct HuffmanEncodeTable {
    /// Code value for each byte; only the low `num_bits[b]` bits are meaningful.
    codes: [u16; MAX_SYMBOL_VALUE + 1],
    /// Code length in bits for each byte; `0` means the byte has no code
    /// (this is what `can_encode` tests for).
    num_bits: [u8; MAX_SYMBOL_VALUE + 1],
    /// Weight of every symbol in `0..=max_symbol` (`0` for symbols that never
    /// occurred), as returned by `compute_huffman_weights`.
    weights: Vec<u8>,
    /// Largest byte value present in the data the table was built from.
    max_symbol: u8,
    /// Length in bits of the longest code in the table.
    table_log: u8,
}
18
#[cfg(feature = "alloc")]
impl HuffmanEncodeTable {
    /// Builds an encode table from the byte histogram of `data`.
    ///
    /// Returns `None` when no usable table exists:
    /// * `data` is empty,
    /// * fewer than two distinct byte values occur,
    /// * the largest byte value is above 128 (the one-byte header written by
    ///   [`serialize_weights`](Self::serialize_weights) cannot describe more
    ///   than 128 explicit weights), or
    /// * the resulting tree is deeper than `MAX_BITS`.
    pub fn from_data(data: &[u8]) -> Option<Self> {
        if data.is_empty() {
            return None;
        }

        let mut freqs = [0u32; MAX_SYMBOL_VALUE + 1];
        for &b in data {
            freqs[b as usize] += 1;
        }

        // The largest symbol is the last non-zero histogram bucket; reading it
        // from the histogram keeps the counting loop branch-free.
        let max_sym = freqs.iter().rposition(|&f| f > 0)?;
        if max_sym > 128 {
            return None;
        }

        let num_symbols = max_sym + 1;
        let active_count = freqs[..num_symbols].iter().filter(|&&f| f > 0).count();
        if active_count < 2 {
            return None;
        }

        let (weights, table_log) = compute_huffman_weights(&freqs, num_symbols)?;
        let (codes, num_bits) = build_encode_codes(&weights, table_log);

        Some(Self {
            codes,
            num_bits,
            weights,
            // `max_sym` is an index produced by a `u8` byte value, so it fits.
            max_symbol: max_sym as u8,
            table_log,
        })
    }

    /// Length in bits of the longest code in this table.
    pub fn table_log(&self) -> u8 {
        self.table_log
    }

    /// Returns `true` when every byte of `data` has a code in this table.
    ///
    /// The encoders do not check this themselves: a byte without a code has a
    /// length of zero and therefore contributes no bits to the output.
    pub fn can_encode(&self, data: &[u8]) -> bool {
        data.iter().all(|&b| self.num_bits[b as usize] != 0)
    }

    /// Serializes the weights as a header byte followed by packed 4-bit weights.
    ///
    /// Only symbols `0..max_symbol` are written; the weight of `max_symbol`
    /// itself is not stored. The header byte is `count + 127`, and each
    /// following byte holds two weights, the earlier symbol in the high nibble.
    /// An odd trailing weight is padded with a zero low nibble.
    pub fn serialize_weights(&self) -> Vec<u8> {
        let explicit = &self.weights[..self.max_symbol as usize];

        let mut out = Vec::with_capacity(1 + explicit.len().div_ceil(2));
        out.push((explicit.len() + 127) as u8);
        out.extend(explicit.chunks(2).map(|pair| {
            let hi = pair[0];
            let lo = pair.get(1).copied().unwrap_or(0);
            (hi << 4) | lo
        }));
        out
    }

    /// Encodes `data` as one bitstream and returns it in a fresh buffer.
    ///
    /// See [`encode_single_stream_into`](Self::encode_single_stream_into).
    pub fn encode_single_stream(&self, data: &[u8]) -> Vec<u8> {
        let mut buf = Vec::with_capacity(data.len() + 8);
        self.encode_single_stream_into(data, &mut buf);
        buf
    }

    /// Encodes `data` as one bitstream into `buf`, replacing its contents.
    ///
    /// Symbols are written from the last byte of `data` to the first, and a
    /// single `1` bit is appended after the final code as an end marker.
    /// Bytes without a code are silently skipped; call
    /// [`can_encode`](Self::can_encode) first when `data` may contain symbols
    /// the table was not built from.
    pub fn encode_single_stream_into(&self, data: &[u8], buf: &mut Vec<u8>) {
        buf.clear();
        let tl = self.table_log as usize;
        // Symbols per unrolled group: as many as fit in 32 bits, but at least
        // two. With `table_log <= 16` a group adds at most 32 bits, so the
        // 64-bit accumulator (which holds < 32 bits after any flush check)
        // cannot overflow.
        let unroll: usize = (32usize).checked_div(tl).unwrap_or(1).max(2);

        buf.reserve(data.len() + 16);
        let mut bits: u64 = 0; // pending bits, least-significant bit first
        let mut bits_used: u8 = 0; // number of valid bits in `bits`
        let mut wpos: usize = 0; // bytes of `buf` committed so far

        // Stores the accumulator at `wpos` and commits the whole bytes it
        // holds, keeping the 0..=7 leftover bits for the next round.
        // NOTE(review): `bitstream_flush_vec` presumably writes 8 bytes at
        // `wpos` without a bounds check, hence the capacity test before it.
        macro_rules! flush_bits {
            () => {
                if wpos + 8 > buf.capacity() {
                    // `reserve` is relative to the length, so publish it first.
                    primitives::set_vec_len(buf, wpos);
                    buf.reserve(64);
                }
                primitives::bitstream_flush_vec(buf, wpos, bits);
                let nb = (bits_used >> 3) as usize;
                wpos += nb;
                bits >>= nb << 3;
                bits_used &= 7;
            };
        }

        // Main loop: whole groups of `unroll` symbols, walking backwards.
        let mut pos = data.len();
        while pos >= unroll {
            pos -= unroll;
            for j in 0..unroll {
                let b = primitives::get_unchecked_byte(data, pos + (unroll - 1 - j));
                let c = primitives::get_unchecked_u16(&self.codes, b as usize) as u64;
                let n = primitives::get_unchecked_u8_arr(&self.num_bits, b as usize);
                bits |= c << bits_used;
                bits_used += n;
            }
            if bits_used >= 32 {
                flush_bits!();
            }
        }
        // Tail: the remaining `< unroll` symbols at the front of `data`.
        while pos > 0 {
            pos -= 1;
            let b = primitives::get_unchecked_byte(data, pos);
            let c = primitives::get_unchecked_u16(&self.codes, b as usize) as u64;
            let n = primitives::get_unchecked_u8_arr(&self.num_bits, b as usize);
            bits |= c << bits_used;
            bits_used += n;
            if bits_used >= 32 {
                flush_bits!();
            }
        }

        // Publish the committed bytes, then append the end-marker bit and
        // drain whatever is left in the accumulator byte by byte.
        primitives::set_vec_len(buf, wpos);
        bits |= 1u64 << bits_used;
        bits_used += 1;
        while bits_used > 0 {
            buf.push(bits as u8);
            bits >>= 8;
            bits_used = bits_used.saturating_sub(8);
        }
    }

    /// Encodes `data` as four bitstreams and returns the result in a fresh buffer.
    ///
    /// See [`encode_4_streams_into`](Self::encode_4_streams_into).
    pub fn encode_4_streams(&self, data: &[u8]) -> Vec<u8> {
        let mut out = Vec::new();
        let mut scratch = Vec::new();
        self.encode_4_streams_into(data, &mut out, &mut scratch);
        out
    }

    /// Encodes `data` as four independent bitstreams into `out`.
    ///
    /// `data` is cut into four segments of `ceil(len / 4)` bytes (the last
    /// ones may be shorter or empty). `out` is cleared and then receives a
    /// 6-byte jump table holding the little-endian `u16` sizes of the first
    /// three encoded streams, followed by the four streams. `stream_buf` is
    /// scratch space that is overwritten.
    ///
    /// Each of the first three encoded streams must be at most 65535 bytes
    /// long, because the jump table cannot represent anything larger. This is
    /// checked in debug builds only.
    pub fn encode_4_streams_into(&self, data: &[u8], out: &mut Vec<u8>, stream_buf: &mut Vec<u8>) {
        let seg = data.len().div_ceil(4);
        let cut = |k: usize| (seg * k).min(data.len());
        let segments: [&[u8]; 4] = [
            &data[..cut(1)],
            &data[cut(1)..cut(2)],
            &data[cut(2)..cut(3)],
            &data[cut(3)..],
        ];

        out.clear();
        // Placeholder for the jump table, filled in once the sizes are known.
        out.extend_from_slice(&[0u8; 6]);

        let mut encoded_lens = [0usize; 4];
        for (segment, len) in segments.iter().zip(encoded_lens.iter_mut()) {
            self.encode_single_stream_into(segment, stream_buf);
            *len = stream_buf.len();
            out.extend_from_slice(stream_buf);
        }

        for (slot, &len) in out[..6].chunks_exact_mut(2).zip(&encoded_lens[..3]) {
            debug_assert!(
                len <= u16::MAX as usize,
                "encoded stream of {len} bytes does not fit the 16-bit jump table"
            );
            slot.copy_from_slice(&(len as u16).to_le_bytes());
        }
    }

    /// Size in bytes that [`encode_single_stream`](Self::encode_single_stream)
    /// produces for `data`: all code bits plus the end-marker bit, rounded up
    /// to whole bytes.
    pub fn compressed_size_single(&self, data: &[u8]) -> usize {
        let total_bits: usize = data
            .iter()
            .map(|&b| self.num_bits[b as usize] as usize)
            .sum();
        // ceil((total_bits + 1) / 8)
        total_bits / 8 + 1
    }
}
195
196fn compute_huffman_weights(freqs: &[u32], num_symbols: usize) -> Option<(Vec<u8>, u8)> {
197    use alloc::collections::BinaryHeap;
198    use core::cmp::Reverse;
199
200    let active: Vec<(u64, usize)> = freqs[..num_symbols]
201        .iter()
202        .enumerate()
203        .filter(|(_, f)| **f > 0)
204        .map(|(s, &f)| (f as u64, s))
205        .collect();
206
207    if active.len() < 2 {
208        return None;
209    }
210
211    let n = active.len();
212
213    let max_nodes = 2 * n;
214    let mut parent = vec![usize::MAX; max_nodes];
215
216    let mut heap: BinaryHeap<Reverse<(u64, usize)>> = BinaryHeap::with_capacity(n);
217    for (i, &(f, _)) in active.iter().enumerate() {
218        heap.push(Reverse((f, i)));
219    }
220
221    for next_id in n..n + (n - 1) {
222        let Reverse((f1, n1)) = heap.pop().unwrap();
223        let Reverse((f2, n2)) = heap.pop().unwrap();
224        parent[n1] = next_id;
225        parent[n2] = next_id;
226        heap.push(Reverse((f1 + f2, next_id)));
227    }
228
229    let mut bit_lengths = vec![0u8; num_symbols];
230    for (i, &(_, sym)) in active.iter().enumerate().take(n) {
231        let mut depth = 0u8;
232        let mut node = i;
233        while parent[node] != usize::MAX {
234            depth += 1;
235            node = parent[node];
236        }
237        bit_lengths[sym] = depth;
238    }
239
240    let max_bl = *bit_lengths.iter().max().unwrap();
241    if max_bl == 0 || max_bl > MAX_BITS {
242        return None;
243    }
244
245    let table_log = max_bl;
246    let mut weights = vec![0u8; num_symbols];
247    for (s, &bl) in bit_lengths.iter().enumerate() {
248        if bl > 0 {
249            weights[s] = table_log + 1 - bl;
250        }
251    }
252
253    Some((weights, table_log))
254}
255
256fn build_encode_codes(
257    weights: &[u8],
258    table_log: u8,
259) -> ([u16; MAX_SYMBOL_VALUE + 1], [u8; MAX_SYMBOL_VALUE + 1]) {
260    let mut codes = [0u16; MAX_SYMBOL_VALUE + 1];
261    let mut num_bits = [0u8; MAX_SYMBOL_VALUE + 1];
262
263    let max_w = table_log + 1;
264    let mut rank_count = [0u32; MAX_BITS as usize + 2];
265
266    for (s, &w) in weights.iter().enumerate() {
267        if w > 0 && w <= max_w {
268            num_bits[s] = table_log + 1 - w;
269            rank_count[w as usize] += 1;
270        }
271    }
272
273    let mut rank_start = [0u32; MAX_BITS as usize + 2];
274    let mut cumul = 0u32;
275    for w in 1..=max_w {
276        rank_start[w as usize] = cumul;
277        cumul += rank_count[w as usize] * (1u32 << (w - 1));
278    }
279
280    for (s, &w) in weights.iter().enumerate() {
281        if w == 0 {
282            continue;
283        }
284        let start = rank_start[w as usize];
285        codes[s] = (start >> (w - 1)) as u16;
286        rank_start[w as usize] += 1u32 << (w - 1);
287    }
288
289    (codes, num_bits)
290}
291
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a table from `data`, encodes it as a single stream, decodes it
    /// again through the serialized weights and checks the result.
    /// Returns the table so callers can make further assertions on it.
    fn assert_single_stream_roundtrip(data: &[u8]) -> HuffmanEncodeTable {
        let table = HuffmanEncodeTable::from_data(data).unwrap();
        let weights_raw = table.serialize_weights();
        let encoded = table.encode_single_stream(data);

        let (parsed_weights, _) =
            crate::huffman::weights::parse_huffman_weights(&weights_raw).unwrap();
        let (decode_table, decode_log) =
            crate::huffman::weights::build_huffman_decode_table(&parsed_weights).unwrap();
        let decoded = crate::huffman::decode::decode_single_stream(
            &decode_table,
            decode_log,
            &encoded,
            data.len(),
        )
        .unwrap();
        assert_eq!(&decoded[..], data);
        table
    }

    /// Same as `assert_single_stream_roundtrip`, for the four-stream format.
    fn assert_4_streams_roundtrip(data: &[u8]) -> HuffmanEncodeTable {
        let table = HuffmanEncodeTable::from_data(data).unwrap();
        let weights_raw = table.serialize_weights();
        let encoded = table.encode_4_streams(data);

        let (parsed_weights, _) =
            crate::huffman::weights::parse_huffman_weights(&weights_raw).unwrap();
        let (decode_table, decode_log) =
            crate::huffman::weights::build_huffman_decode_table(&parsed_weights).unwrap();
        let decoded = crate::huffman::decode::decode_4_streams(
            &decode_table,
            decode_log,
            &encoded,
            data.len(),
        )
        .unwrap();
        assert_eq!(&decoded[..], data);
        table
    }

    #[test]
    fn roundtrip_simple() {
        assert_single_stream_roundtrip(b"hello world hello world hello world!");
    }

    #[test]
    fn roundtrip_4_streams() {
        let data: Vec<u8> = b"ABCDEFGH".iter().cycle().take(1024).copied().collect();
        assert_4_streams_roundtrip(&data);
    }

    #[test]
    fn roundtrip_all_bytes() {
        let data: Vec<u8> = (0u8..=127).cycle().take(4096).collect();
        assert_single_stream_roundtrip(&data);
    }

    #[test]
    fn skewed_distribution() {
        let mut data = vec![0u8; 900];
        data.extend(vec![1u8; 80]);
        data.extend(vec![2u8; 15]);
        data.extend(vec![3u8; 5]);
        let table = assert_single_stream_roundtrip(&data);
        // The most frequent symbol must get a shorter code than the rarest.
        assert!(table.num_bits[0] < table.num_bits[3]);
    }

    #[test]
    fn two_symbols() {
        let mut data = vec![0u8; 500];
        data.extend(vec![1u8; 500]);
        let table = assert_single_stream_roundtrip(&data);
        assert_eq!(table.num_bits[0], 1);
        assert_eq!(table.num_bits[1], 1);
    }

    #[test]
    fn from_data_rejects_unusable_input() {
        // Empty input.
        assert!(HuffmanEncodeTable::from_data(&[]).is_none());
        // A single distinct symbol.
        assert!(HuffmanEncodeTable::from_data(&[7u8; 64]).is_none());
        // A symbol above 128 cannot be described by the weight header.
        assert!(HuffmanEncodeTable::from_data(&[0u8, 200]).is_none());
    }

    #[test]
    fn can_encode_detects_symbols_without_a_code() {
        let table = HuffmanEncodeTable::from_data(b"aaaabbbbab").unwrap();
        assert!(table.can_encode(b"abba"));
        assert!(table.can_encode(b""));
        assert!(!table.can_encode(b"abz"));
    }

    #[test]
    fn compressed_size_matches_encoded_length() {
        let data = b"hello world hello world hello world!";
        let table = HuffmanEncodeTable::from_data(data).unwrap();
        assert_eq!(
            table.encode_single_stream(data).len(),
            table.compressed_size_single(data)
        );
    }
}