Skip to main content

zrip_core/huffman/
encode.rs

1#![forbid(unsafe_code)]
2
3#[cfg(feature = "alloc")]
4use alloc::vec;
5#[cfg(feature = "alloc")]
6use alloc::vec::Vec;
7
8use super::primitives;
9use crate::huffman::{MAX_BITS, MAX_SYMBOL_VALUE};
10
11#[derive(Clone)]
12pub struct HuffmanEncodeTable {
13    codes: [u16; MAX_SYMBOL_VALUE + 1],
14    num_bits: [u8; MAX_SYMBOL_VALUE + 1],
15    weights: Vec<u8>,
16    max_symbol: u8,
17    table_log: u8,
18}
19
#[cfg(feature = "alloc")]
impl HuffmanEncodeTable {
    /// Builds an encoding table from the byte frequencies of `data`.
    ///
    /// Returns `None` when:
    /// - `data` is empty or contains fewer than two distinct byte values;
    /// - the largest byte value exceeds 128, because [`Self::serialize_weights`] stores the number
    ///   of explicit weights in a one-byte header as `count + 127`;
    /// - the optimal (not length-limited) Huffman code would need more than `MAX_BITS` bits for
    ///   some byte.
    pub fn from_data(data: &[u8]) -> Option<Self> {
        if data.is_empty() {
            return None;
        }

        let mut freqs = [0u32; MAX_SYMBOL_VALUE + 1];
        let mut max_sym = 0u8;
        for &b in data {
            freqs[b as usize] += 1;
            max_sym = max_sym.max(b);
        }

        let num_symbols = max_sym as usize + 1;
        let active_count = freqs[..num_symbols].iter().filter(|&&f| f > 0).count();
        if active_count < 2 {
            return None;
        }

        if max_sym as usize > 128 {
            return None;
        }

        let (weights, table_log) = compute_huffman_weights(&freqs, num_symbols)?;
        let (codes, num_bits) = build_encode_codes(&weights, table_log);

        Some(Self {
            codes,
            num_bits,
            weights,
            max_symbol: max_sym,
            table_log,
        })
    }

    /// Rebuilds an encoding table from a Huffman decoding table.
    ///
    /// Only the first `1 << table_log` entries of `decode_table` are read. Each symbol's code
    /// length is taken from the `num_bits` of its entries and turned back into a weight
    /// (`table_log + 1 - num_bits`). The canonical codes are then regenerated from those weights.
    ///
    /// Returns `None` when:
    /// - `table_log` is 0 or exceeds `MAX_BITS`;
    /// - `decode_table` is shorter than `1 << table_log`;
    /// - an entry's `num_bits` lies outside `1..=table_log`, or disagrees with another entry for
    ///   the same symbol;
    /// - fewer than two symbols occur;
    /// - the largest symbol exceeds 128.
    pub fn from_decode_table(
        decode_table: &[super::HuffmanDecodeEntry],
        table_log: u8,
    ) -> Option<Self> {
        // Validate before shifting: `1 << table_log` overflows for large logs, and
        // `build_encode_codes` sizes its per-weight arrays for logs up to `MAX_BITS`.
        if table_log == 0 || table_log > MAX_BITS {
            return None;
        }
        let entries = decode_table.get(..1usize << table_log)?;

        let mut num_bits_per_sym = [0u8; MAX_SYMBOL_VALUE + 1];
        let mut seen = [false; MAX_SYMBOL_VALUE + 1];
        let mut max_sym = 0u8;
        for entry in entries {
            let s = entry.symbol as usize;
            let nb = entry.num_bits;
            // A valid code length lies in 1..=table_log. Anything else would underflow the
            // weight computation below or describe a zero-length code.
            if nb == 0 || nb > table_log {
                return None;
            }
            if seen[s] {
                // Every slot owned by one symbol must agree on that symbol's code length.
                if num_bits_per_sym[s] != nb {
                    return None;
                }
                continue;
            }
            seen[s] = true;
            num_bits_per_sym[s] = nb;
            max_sym = max_sym.max(entry.symbol);
        }

        if max_sym as usize > 128 {
            return None;
        }

        let num_symbols = max_sym as usize + 1;
        let active_count = seen[..num_symbols].iter().filter(|&&s| s).count();
        if active_count < 2 {
            return None;
        }

        // `nb` was checked to be in 1..=table_log above, so this cannot underflow.
        let weights: Vec<u8> = (0..num_symbols)
            .map(|s| {
                if seen[s] {
                    table_log + 1 - num_bits_per_sym[s]
                } else {
                    0
                }
            })
            .collect();

        let (codes, num_bits) = build_encode_codes(&weights, table_log);

        Some(Self {
            codes,
            num_bits,
            weights,
            max_symbol: max_sym,
            table_log,
        })
    }

    /// Table log the codes were built against.
    pub fn table_log(&self) -> u8 {
        self.table_log
    }

    /// Returns `true` when every byte of `data` has a code in this table.
    pub fn can_encode(&self, data: &[u8]) -> bool {
        data.iter().all(|&b| self.num_bits[b as usize] != 0)
    }

    /// Serializes the weights in the direct (4 bits per weight) representation.
    ///
    /// The output starts with a header byte equal to `count + 127`, where `count == max_symbol`.
    /// The weights of symbols `0..max_symbol` follow, packed two per byte with the earlier symbol
    /// in the high nibble; an odd trailing weight is paired with a zero nibble. The weight of
    /// `max_symbol` itself is not written.
    pub fn serialize_weights(&self) -> Vec<u8> {
        let explicit = &self.weights[..self.max_symbol as usize];

        let mut out = Vec::with_capacity(1 + explicit.len().div_ceil(2));
        out.push((explicit.len() + 127) as u8);
        out.extend(explicit.chunks(2).map(|pair| {
            let hi = pair[0];
            let lo = pair.get(1).copied().unwrap_or(0);
            (hi << 4) | lo
        }));
        out
    }

    /// Encodes `data` as a single Huffman bitstream into a freshly allocated buffer.
    ///
    /// See [`Self::encode_single_stream_into`] for the stream layout.
    pub fn encode_single_stream(&self, data: &[u8]) -> Vec<u8> {
        let mut buf = Vec::with_capacity(data.len() + 8);
        self.encode_single_stream_into(data, &mut buf);
        buf
    }

    /// Encodes `data` as a single Huffman bitstream, replacing the contents of `buf`.
    ///
    /// Bytes are encoded from the last byte of `data` to the first. Each code is ORed into a
    /// 64-bit accumulator above the bits already pending. Once 32 or more bits are pending, whole
    /// bytes are flushed through `primitives::bitstream_flush_vec`. After the last code a single
    /// `1` bit is appended as an end marker, and the remaining partial bytes are pushed low byte
    /// first.
    ///
    /// Every byte in `data` is assumed to have a code (see [`Self::can_encode`]); bytes without
    /// one contribute zero bits.
    pub fn encode_single_stream_into(&self, data: &[u8], buf: &mut Vec<u8>) {
        buf.clear();
        let tl = self.table_log as usize;
        // Bytes encoded between flush checks: 32 / table_log, but at least 2. Fewer than 32
        // bits are pending before a group, so a group of codes of at most table_log bits stays
        // within the 64-bit accumulator while table_log <= 16.
        let unroll: usize = (32usize).checked_div(tl).unwrap_or(1).max(2);

        buf.reserve(data.len() + 16);
        let mut bits: u64 = 0;
        let mut bits_used: u8 = 0;
        let mut wpos: usize = 0;

        // Moves the whole bytes of the accumulator to `buf[wpos..]`, keeping the <8 leftover
        // bits. NOTE(review): presumably `bitstream_flush_vec` writes 8 bytes at `wpos`, hence
        // the capacity check before it.
        macro_rules! flush_bits {
            () => {
                if wpos + 8 > buf.capacity() {
                    primitives::set_vec_len(buf, wpos);
                    buf.reserve(64);
                }
                primitives::bitstream_flush_vec(buf, wpos, bits);
                let nb = (bits_used >> 3) as usize;
                wpos += nb;
                bits >>= nb << 3;
                bits_used &= 7;
            };
        }

        // Unrolled groups, walking backwards: within a group, the highest index is encoded first.
        let mut pos = data.len();
        while pos >= unroll {
            pos -= unroll;
            for j in 0..unroll {
                let b = primitives::get_unchecked_byte(data, pos + (unroll - 1 - j));
                let c = primitives::get_unchecked_u16(&self.codes, b as usize) as u64;
                let n = primitives::get_unchecked_u8_arr(&self.num_bits, b as usize);
                bits |= c << bits_used;
                bits_used += n;
            }
            if bits_used >= 32 {
                flush_bits!();
            }
        }
        // Remaining (< unroll) leading bytes, one at a time.
        while pos > 0 {
            pos -= 1;
            let b = primitives::get_unchecked_byte(data, pos);
            let c = primitives::get_unchecked_u16(&self.codes, b as usize) as u64;
            let n = primitives::get_unchecked_u8_arr(&self.num_bits, b as usize);
            bits |= c << bits_used;
            bits_used += n;
            if bits_used >= 32 {
                flush_bits!();
            }
        }

        // Truncate to the flushed bytes, then append the end-marker bit and the tail.
        primitives::set_vec_len(buf, wpos);
        bits |= 1u64 << bits_used;
        bits_used += 1;
        while bits_used > 0 {
            buf.push(bits as u8);
            bits >>= 8;
            bits_used = bits_used.saturating_sub(8);
        }
    }

    /// Encodes `data` as four Huffman streams into a freshly allocated buffer.
    ///
    /// See [`Self::encode_4_streams_into`] for the layout.
    pub fn encode_4_streams(&self, data: &[u8]) -> Vec<u8> {
        let mut out = Vec::new();
        self.encode_4_streams_into(data, &mut out, &mut Vec::new());
        out
    }

    /// Splits `data` into four segments and encodes each as an independent single stream.
    ///
    /// Each of the first three segments holds `ceil(len / 4)` bytes (fewer if `data` runs out),
    /// and the fourth segment holds the rest. `out` is replaced with a 6-byte jump table followed
    /// by the four encoded streams back to back. The jump table holds the little-endian `u16`
    /// sizes of the first three streams. `stream_buf` is scratch space, reused between segments.
    ///
    /// The jump table can only describe streams of up to `u16::MAX` bytes. Larger streams trip
    /// a debug assertion; in release builds their sizes would be truncated.
    pub fn encode_4_streams_into(&self, data: &[u8], out: &mut Vec<u8>, stream_buf: &mut Vec<u8>) {
        const JUMP_TABLE_LEN: usize = 6;
        let len = data.len();
        let seg = len.div_ceil(4);

        out.clear();
        out.extend_from_slice(&[0u8; JUMP_TABLE_LEN]);

        for i in 0..4 {
            // For i == 3, `end` is always `len`, since 4 * seg >= len.
            let start = (i * seg).min(len);
            let end = ((i + 1) * seg).min(len);
            self.encode_single_stream_into(&data[start..end], stream_buf);

            if i < 3 {
                let stream_len = stream_buf.len();
                debug_assert!(
                    stream_len <= usize::from(u16::MAX),
                    "Huffman stream too large for the 4-stream jump table"
                );
                out[i * 2..i * 2 + 2].copy_from_slice(&(stream_len as u16).to_le_bytes());
            }
            out.extend_from_slice(stream_buf);
        }
    }

    /// Predicts the byte length of [`Self::encode_single_stream`]'s output for `data`.
    ///
    /// The result is the sum of the code lengths plus the 1-bit end marker, rounded up to whole
    /// bytes.
    pub fn compressed_size_single(&self, data: &[u8]) -> usize {
        let total_bits: usize = data
            .iter()
            .map(|&b| self.num_bits[b as usize] as usize)
            .sum();
        (total_bits + 8) / 8
    }
}
247
248fn compute_huffman_weights(freqs: &[u32], num_symbols: usize) -> Option<(Vec<u8>, u8)> {
249    use alloc::collections::BinaryHeap;
250    use core::cmp::Reverse;
251
252    let active: Vec<(u64, usize)> = freqs[..num_symbols]
253        .iter()
254        .enumerate()
255        .filter(|(_, f)| **f > 0)
256        .map(|(s, &f)| (f as u64, s))
257        .collect();
258
259    if active.len() < 2 {
260        return None;
261    }
262
263    let n = active.len();
264
265    let max_nodes = 2 * n;
266    let mut parent = vec![usize::MAX; max_nodes];
267
268    let mut heap: BinaryHeap<Reverse<(u64, usize)>> = BinaryHeap::with_capacity(n);
269    for (i, &(f, _)) in active.iter().enumerate() {
270        heap.push(Reverse((f, i)));
271    }
272
273    for next_id in n..n + (n - 1) {
274        let Reverse((f1, n1)) = heap.pop().unwrap();
275        let Reverse((f2, n2)) = heap.pop().unwrap();
276        parent[n1] = next_id;
277        parent[n2] = next_id;
278        heap.push(Reverse((f1 + f2, next_id)));
279    }
280
281    let mut bit_lengths = vec![0u8; num_symbols];
282    for (i, &(_, sym)) in active.iter().enumerate().take(n) {
283        let mut depth = 0u8;
284        let mut node = i;
285        while parent[node] != usize::MAX {
286            depth += 1;
287            node = parent[node];
288        }
289        bit_lengths[sym] = depth;
290    }
291
292    let max_bl = *bit_lengths.iter().max().unwrap();
293    if max_bl == 0 || max_bl > MAX_BITS {
294        return None;
295    }
296
297    let table_log = max_bl;
298    let mut weights = vec![0u8; num_symbols];
299    for (s, &bl) in bit_lengths.iter().enumerate() {
300        if bl > 0 {
301            weights[s] = table_log + 1 - bl;
302        }
303    }
304
305    Some((weights, table_log))
306}
307
308fn build_encode_codes(
309    weights: &[u8],
310    table_log: u8,
311) -> ([u16; MAX_SYMBOL_VALUE + 1], [u8; MAX_SYMBOL_VALUE + 1]) {
312    let mut codes = [0u16; MAX_SYMBOL_VALUE + 1];
313    let mut num_bits = [0u8; MAX_SYMBOL_VALUE + 1];
314
315    let max_w = table_log + 1;
316    let mut rank_count = [0u32; MAX_BITS as usize + 2];
317
318    for (s, &w) in weights.iter().enumerate() {
319        if w > 0 && w <= max_w {
320            num_bits[s] = table_log + 1 - w;
321            rank_count[w as usize] += 1;
322        }
323    }
324
325    let mut rank_start = [0u32; MAX_BITS as usize + 2];
326    let mut cumul = 0u32;
327    for w in 1..=max_w {
328        rank_start[w as usize] = cumul;
329        cumul += rank_count[w as usize] * (1u32 << (w - 1));
330    }
331
332    for (s, &w) in weights.iter().enumerate() {
333        if w == 0 {
334            continue;
335        }
336        let start = rank_start[w as usize];
337        codes[s] = (start >> (w - 1)) as u16;
338        rank_start[w as usize] += 1u32 << (w - 1);
339    }
340
341    (codes, num_bits)
342}
343
#[cfg(test)]
mod tests {
    use super::*;

    /// Rebuilds a decoder from `table`'s serialized weights and asserts that the single-stream
    /// encoding of `data` decodes back to `data`.
    fn assert_single_stream_roundtrip(table: &HuffmanEncodeTable, data: &[u8]) {
        let weights_raw = table.serialize_weights();
        let encoded = table.encode_single_stream(data);

        let (parsed_weights, _) =
            crate::huffman::weights::parse_huffman_weights(&weights_raw).unwrap();
        let (decode_table, decode_log) =
            crate::huffman::weights::build_huffman_decode_table(&parsed_weights).unwrap();
        let decoded = crate::huffman::decode::decode_single_stream(
            &decode_table,
            decode_log,
            &encoded,
            data.len(),
        )
        .unwrap();
        assert_eq!(decoded, data);
    }

    /// Builds a table from `data`, derives a decoder from its serialized weights, rebuilds an
    /// encoder from that decoder, and asserts the rebuilt encoder's output decodes to `data`.
    /// Returns `(original, rebuilt)` for further comparison.
    fn rebuild_via_decode_table(data: &[u8]) -> (HuffmanEncodeTable, HuffmanEncodeTable) {
        let original = HuffmanEncodeTable::from_data(data).unwrap();
        let weights_raw = original.serialize_weights();

        let (parsed_weights, _) =
            crate::huffman::weights::parse_huffman_weights(&weights_raw).unwrap();
        let (decode_table, decode_log) =
            crate::huffman::weights::build_huffman_decode_table(&parsed_weights).unwrap();

        let rebuilt = HuffmanEncodeTable::from_decode_table(&decode_table, decode_log).unwrap();
        let encoded = rebuilt.encode_single_stream(data);
        let decoded = crate::huffman::decode::decode_single_stream(
            &decode_table,
            decode_log,
            &encoded,
            data.len(),
        )
        .unwrap();
        assert_eq!(decoded, data);

        (original, rebuilt)
    }

    /// 900 zeros, then 80 ones, 15 twos and 5 threes.
    fn skewed_data() -> Vec<u8> {
        let mut data = Vec::with_capacity(1000);
        for &(sym, count) in &[(0u8, 900usize), (1, 80), (2, 15), (3, 5)] {
            data.extend(vec![sym; count]);
        }
        data
    }

    #[test]
    fn from_decode_table_roundtrip() {
        let (original, rebuilt) = rebuild_via_decode_table(b"hello world hello world hello world!");
        assert_eq!(original.table_log, rebuilt.table_log);
        assert_eq!(original.max_symbol, rebuilt.max_symbol);
        assert_eq!(original.weights, rebuilt.weights);
        assert_eq!(original.num_bits, rebuilt.num_bits);
        assert_eq!(original.codes, rebuilt.codes);
    }

    #[test]
    fn from_decode_table_skewed() {
        let (original, rebuilt) = rebuild_via_decode_table(&skewed_data());
        assert_eq!(original.codes, rebuilt.codes);
        assert_eq!(original.num_bits, rebuilt.num_bits);
    }

    #[test]
    fn roundtrip_simple() {
        let data = b"hello world hello world hello world!";
        let table = HuffmanEncodeTable::from_data(data).unwrap();
        assert_single_stream_roundtrip(&table, data);
    }

    #[test]
    fn roundtrip_4_streams() {
        let data: Vec<u8> = b"ABCDEFGH".iter().copied().cycle().take(1024).collect();
        let table = HuffmanEncodeTable::from_data(&data).unwrap();
        let weights_raw = table.serialize_weights();
        let encoded = table.encode_4_streams(&data);

        let (parsed_weights, _) =
            crate::huffman::weights::parse_huffman_weights(&weights_raw).unwrap();
        let (decode_table, decode_log) =
            crate::huffman::weights::build_huffman_decode_table(&parsed_weights).unwrap();
        let decoded = crate::huffman::decode::decode_4_streams(
            &decode_table,
            decode_log,
            &encoded,
            data.len(),
        )
        .unwrap();
        assert_eq!(decoded, data);
    }

    #[test]
    fn roundtrip_all_bytes() {
        let data: Vec<u8> = (0u8..=127).cycle().take(4096).collect();
        let table = HuffmanEncodeTable::from_data(&data).unwrap();
        assert_single_stream_roundtrip(&table, &data);
    }

    #[test]
    fn skewed_distribution() {
        let data = skewed_data();
        let table = HuffmanEncodeTable::from_data(&data).unwrap();
        assert!(table.num_bits[0] < table.num_bits[3]);
        assert_single_stream_roundtrip(&table, &data);
    }

    #[test]
    fn two_symbols() {
        let mut data = vec![0u8; 500];
        data.extend(vec![1u8; 500]);
        let table = HuffmanEncodeTable::from_data(&data).unwrap();
        assert_eq!(table.num_bits[0], 1);
        assert_eq!(table.num_bits[1], 1);
        assert_single_stream_roundtrip(&table, &data);
    }
}