Skip to main content

zrip_core/huffman/
encode.rs

1#[cfg(feature = "alloc")]
2use alloc::vec;
3#[cfg(feature = "alloc")]
4use alloc::vec::Vec;
5
6use crate::huffman::{MAX_BITS, MAX_SYMBOL_VALUE};
7
/// Huffman encoding table: per-symbol codes and code lengths, plus the
/// weights from which the table can be serialized.
///
/// Built from sample data by `HuffmanEncodeTable::from_data`.
pub struct HuffmanEncodeTable {
    /// Code value for each byte value, as assigned by `build_encode_codes`.
    codes: [u16; MAX_SYMBOL_VALUE + 1],
    /// Code length in bits for each byte value; 0 marks a byte that has no
    /// code (see `can_encode`).
    num_bits: [u8; MAX_SYMBOL_VALUE + 1],
    /// Weight of every symbol in `0..=max_symbol`: `table_log + 1 - code length`
    /// for symbols in use, 0 for symbols that never occurred.
    weights: Vec<u8>,
    /// Largest byte value present in the data the table was built from.
    max_symbol: u8,
    /// Length in bits of the longest code in the table.
    table_log: u8,
}
15
16#[cfg(feature = "alloc")]
17impl HuffmanEncodeTable {
18    pub fn from_data(data: &[u8]) -> Option<Self> {
19        if data.is_empty() {
20            return None;
21        }
22
23        let mut freqs = [0u32; MAX_SYMBOL_VALUE + 1];
24        let mut max_sym = 0u8;
25        for &b in data {
26            freqs[b as usize] += 1;
27            if b > max_sym {
28                max_sym = b;
29            }
30        }
31
32        let num_symbols = max_sym as usize + 1;
33        let active_count = freqs[..num_symbols].iter().filter(|&&f| f > 0).count();
34        if active_count < 2 {
35            return None;
36        }
37
38        // Direct weight encoding supports at most 128 explicit symbols.
39        // Fall back to raw literals for blocks with more distinct values.
40        if max_sym as usize > 128 {
41            return None;
42        }
43
44        let (weights, table_log) = compute_huffman_weights(&freqs, num_symbols)?;
45        let (codes, num_bits) = build_encode_codes(&weights, table_log);
46
47        Some(Self {
48            codes,
49            num_bits,
50            weights,
51            max_symbol: max_sym,
52            table_log,
53        })
54    }
55
    /// Returns the table log: the length in bits of the longest code.
    pub fn table_log(&self) -> u8 {
        self.table_log
    }
59
60    pub fn can_encode(&self, data: &[u8]) -> bool {
61        for &b in data {
62            if self.num_bits[b as usize] == 0 {
63                return false;
64            }
65        }
66        true
67    }
68
69    pub fn serialize_weights(&self) -> Vec<u8> {
70        let explicit = &self.weights[..self.max_symbol as usize];
71        let num_symbols = explicit.len();
72
73        let mut out = Vec::with_capacity(1 + (num_symbols + 1) / 2);
74        out.push((num_symbols + 127) as u8);
75        let num_bytes = (num_symbols + 1) / 2;
76        for i in 0..num_bytes {
77            let hi = explicit.get(i * 2).copied().unwrap_or(0);
78            let lo = explicit.get(i * 2 + 1).copied().unwrap_or(0);
79            out.push((hi << 4) | lo);
80        }
81        out
82    }
83
84    pub fn encode_single_stream(&self, data: &[u8]) -> Vec<u8> {
85        let mut buf = Vec::with_capacity(data.len() + 8);
86        self.encode_single_stream_into(data, &mut buf);
87        buf
88    }
89
    /// Encodes `data` as one Huffman bitstream, replacing the contents of `buf`.
    ///
    /// Symbols are emitted from the last byte of `data` to the first. Each code
    /// is placed above the bits already accumulated, and the stream is closed
    /// with a single 1 bit followed by zero padding up to a byte boundary.
    ///
    /// Bytes whose code length is 0 contribute no bits to the output; see
    /// `can_encode` for detecting such input beforehand.
    pub fn encode_single_stream_into(&self, data: &[u8], buf: &mut Vec<u8>) {
        buf.clear();
        let tl = self.table_log as usize;
        // Symbols handled per batch before the flush check: `32 / tl`, but at
        // least 2.
        // NOTE(review): the 64-bit accumulator relies on one batch adding at
        // most 32 bits, which holds while `table_log <= 16` — `MAX_BITS` is
        // defined elsewhere, confirm.
        let unroll: usize = if tl > 0 { 32 / tl } else { 1 }.max(2);

        // Initial capacity guess; `flush_bits!` grows the buffer when needed.
        buf.reserve(data.len() + 16);
        // Pending bits, least significant bit first.
        let mut bits: u64 = 0;
        // Number of valid low bits in `bits`.
        let mut bits_used: u8 = 0;
        // Number of finished bytes already stored in `buf`'s allocation.
        let mut wpos: usize = 0;
        let codes = self.codes.as_ptr();
        let nbits = self.num_bits.as_ptr();

        // Stores the whole accumulator at `wpos` in little-endian order, then
        // advances only past the complete bytes; the unfinished byte stays in
        // `bits` and is rewritten by the next flush.
        macro_rules! flush_bits {
            () => {
                // The store below always touches 8 bytes, so make room first.
                if wpos + 8 > buf.capacity() {
                    // SAFETY: the first `wpos` bytes were written by earlier
                    // flushes, and each of those flushes checked
                    // `wpos + 8 <= capacity` before advancing by at most 7.
                    // Setting the length makes `reserve` keep those bytes.
                    unsafe {
                        buf.set_len(wpos);
                    }
                    buf.reserve(64);
                }
                // SAFETY: `wpos + 8 <= buf.capacity()` holds here, so the
                // 8-byte unaligned store stays inside the allocation.
                unsafe {
                    let ptr = buf.as_mut_ptr().add(wpos);
                    (ptr as *mut u64).write_unaligned(bits.to_le());
                    let nb = (bits_used >> 3) as usize;
                    wpos += nb;
                    bits >>= nb << 3;
                    bits_used &= 7;
                }
            };
        }

        // Main loop: consume `unroll` symbols at a time, walking backwards.
        let mut pos = data.len();
        while pos >= unroll {
            pos -= unroll;
            for j in 0..unroll {
                // SAFETY: `pos + unroll` was at most `data.len()` before the
                // subtraction, so `pos + (unroll - 1 - j)` is a valid index.
                let b = unsafe { *data.get_unchecked(pos + (unroll - 1 - j)) };
                // SAFETY: `b` is a `u8`; this is in bounds provided both
                // tables have at least 256 entries (`MAX_SYMBOL_VALUE` is
                // defined elsewhere — confirm).
                let c = unsafe { *codes.add(b as usize) } as u64;
                let n = unsafe { *nbits.add(b as usize) };
                bits |= c << bits_used;
                bits_used += n;
            }
            if bits_used >= 32 {
                flush_bits!();
            }
        }
        // Remaining symbols (fewer than `unroll`), one at a time.
        while pos > 0 {
            pos -= 1;
            // SAFETY: `pos < data.len()` after the decrement.
            let b = unsafe { *data.get_unchecked(pos) };
            // SAFETY: same table-size assumption as in the main loop.
            let c = unsafe { *codes.add(b as usize) } as u64;
            let n = unsafe { *nbits.add(b as usize) };
            bits |= c << bits_used;
            bits_used += n;
            if bits_used >= 32 {
                flush_bits!();
            }
        }

        // SAFETY: the first `wpos` bytes were written by `flush_bits!`, and
        // `wpos` is 0 if no flush happened.
        unsafe {
            buf.set_len(wpos);
        }
        // Terminating marker bit, then drain what is left byte by byte.
        bits |= 1u64 << bits_used;
        bits_used += 1;
        while bits_used > 0 {
            buf.push(bits as u8);
            bits >>= 8;
            bits_used = bits_used.saturating_sub(8);
        }
    }
158
159    pub fn encode_4_streams(&self, data: &[u8]) -> Vec<u8> {
160        let mut out = Vec::new();
161        self.encode_4_streams_into(data, &mut out, &mut Vec::new());
162        out
163    }
164
165    pub fn encode_4_streams_into(&self, data: &[u8], out: &mut Vec<u8>, stream_buf: &mut Vec<u8>) {
166        let seg = (data.len() + 3) / 4;
167        let s1 = &data[..seg.min(data.len())];
168        let s2 = &data[seg.min(data.len())..(seg * 2).min(data.len())];
169        let s3 = &data[(seg * 2).min(data.len())..(seg * 3).min(data.len())];
170        let s4 = &data[(seg * 3).min(data.len())..];
171
172        out.clear();
173        out.extend_from_slice(&[0u8; 6]);
174
175        self.encode_single_stream_into(s1, stream_buf);
176        let e1_len = stream_buf.len();
177        out.extend_from_slice(stream_buf);
178
179        self.encode_single_stream_into(s2, stream_buf);
180        let e2_len = stream_buf.len();
181        out.extend_from_slice(stream_buf);
182
183        self.encode_single_stream_into(s3, stream_buf);
184        let e3_len = stream_buf.len();
185        out.extend_from_slice(stream_buf);
186
187        self.encode_single_stream_into(s4, stream_buf);
188        out.extend_from_slice(stream_buf);
189
190        out[0..2].copy_from_slice(&(e1_len as u16).to_le_bytes());
191        out[2..4].copy_from_slice(&(e2_len as u16).to_le_bytes());
192        out[4..6].copy_from_slice(&(e3_len as u16).to_le_bytes());
193    }
194
195    pub fn compressed_size_single(&self, data: &[u8]) -> usize {
196        let total_bits: usize = data
197            .iter()
198            .map(|&b| self.num_bits[b as usize] as usize)
199            .sum();
200        (total_bits + 8) / 8
201    }
202}
203
204fn compute_huffman_weights(freqs: &[u32], num_symbols: usize) -> Option<(Vec<u8>, u8)> {
205    use alloc::collections::BinaryHeap;
206    use core::cmp::Reverse;
207
208    let active: Vec<(u64, usize)> = freqs[..num_symbols]
209        .iter()
210        .enumerate()
211        .filter(|(_, f)| **f > 0)
212        .map(|(s, &f)| (f as u64, s))
213        .collect();
214
215    if active.len() < 2 {
216        return None;
217    }
218
219    let n = active.len();
220
221    let max_nodes = 2 * n;
222    let mut parent = vec![usize::MAX; max_nodes];
223
224    let mut heap: BinaryHeap<Reverse<(u64, usize)>> = BinaryHeap::with_capacity(n);
225    for (i, &(f, _)) in active.iter().enumerate() {
226        heap.push(Reverse((f, i)));
227    }
228
229    let mut next_id = n;
230    for _ in 0..n - 1 {
231        let Reverse((f1, n1)) = heap.pop().unwrap();
232        let Reverse((f2, n2)) = heap.pop().unwrap();
233        parent[n1] = next_id;
234        parent[n2] = next_id;
235        heap.push(Reverse((f1 + f2, next_id)));
236        next_id += 1;
237    }
238
239    let mut bit_lengths = vec![0u8; num_symbols];
240    for i in 0..n {
241        let sym = active[i].1;
242        let mut depth = 0u8;
243        let mut node = i;
244        while parent[node] != usize::MAX {
245            depth += 1;
246            node = parent[node];
247        }
248        bit_lengths[sym] = depth;
249    }
250
251    let max_bl = *bit_lengths.iter().max().unwrap();
252    if max_bl == 0 || max_bl > MAX_BITS {
253        return None;
254    }
255
256    let table_log = max_bl;
257    let mut weights = vec![0u8; num_symbols];
258    for (s, &bl) in bit_lengths.iter().enumerate() {
259        if bl > 0 {
260            weights[s] = table_log + 1 - bl;
261        }
262    }
263
264    Some((weights, table_log))
265}
266
267fn build_encode_codes(
268    weights: &[u8],
269    table_log: u8,
270) -> ([u16; MAX_SYMBOL_VALUE + 1], [u8; MAX_SYMBOL_VALUE + 1]) {
271    let mut codes = [0u16; MAX_SYMBOL_VALUE + 1];
272    let mut num_bits = [0u8; MAX_SYMBOL_VALUE + 1];
273
274    let max_w = table_log + 1;
275    let mut rank_count = [0u32; MAX_BITS as usize + 2];
276
277    for (s, &w) in weights.iter().enumerate() {
278        if w > 0 && w <= max_w {
279            num_bits[s] = table_log + 1 - w;
280            rank_count[w as usize] += 1;
281        }
282    }
283
284    let mut rank_start = [0u32; MAX_BITS as usize + 2];
285    let mut cumul = 0u32;
286    for w in 1..=max_w {
287        rank_start[w as usize] = cumul;
288        cumul += rank_count[w as usize] * (1u32 << (w - 1));
289    }
290
291    for (s, &w) in weights.iter().enumerate() {
292        if w == 0 {
293            continue;
294        }
295        let start = rank_start[w as usize];
296        codes[s] = (start >> (w - 1)) as u16;
297        rank_start[w as usize] += 1u32 << (w - 1);
298    }
299
300    (codes, num_bits)
301}
302
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes `table`'s weights, encodes `data` as a single stream, then
    /// decodes it through the crate's decoder and checks the result equals
    /// `data`.
    fn assert_single_stream_roundtrip(table: &HuffmanEncodeTable, data: &[u8]) {
        let weights_raw = table.serialize_weights();
        let encoded = table.encode_single_stream(data);

        let (parsed_weights, _) =
            crate::huffman::weights::parse_huffman_weights(&weights_raw).unwrap();
        let (decode_table, decode_log) =
            crate::huffman::weights::build_huffman_decode_table(&parsed_weights).unwrap();
        let decoded = crate::huffman::decode::decode_single_stream(
            &decode_table,
            decode_log,
            &encoded,
            data.len(),
        )
        .unwrap();
        assert_eq!(&decoded[..], data);
    }

    #[test]
    fn roundtrip_simple() {
        let data = b"hello world hello world hello world!";
        let table = HuffmanEncodeTable::from_data(data).unwrap();
        assert_single_stream_roundtrip(&table, data);
    }

    #[test]
    fn roundtrip_4_streams() {
        let data: Vec<u8> = b"ABCDEFGH".iter().cycle().take(1024).copied().collect();
        let table = HuffmanEncodeTable::from_data(&data).unwrap();
        let weights_raw = table.serialize_weights();
        let encoded = table.encode_4_streams(&data);

        let (parsed_weights, _) =
            crate::huffman::weights::parse_huffman_weights(&weights_raw).unwrap();
        let (decode_table, decode_log) =
            crate::huffman::weights::build_huffman_decode_table(&parsed_weights).unwrap();
        let decoded = crate::huffman::decode::decode_4_streams(
            &decode_table,
            decode_log,
            &encoded,
            data.len(),
        )
        .unwrap();
        assert_eq!(decoded, data);
    }

    #[test]
    fn roundtrip_all_bytes() {
        // Byte values 0..=127 only: `from_data` rejects input whose largest
        // byte value exceeds 128.
        let data: Vec<u8> = (0u8..=127).cycle().take(4096).collect();
        let table = HuffmanEncodeTable::from_data(&data).unwrap();
        assert_single_stream_roundtrip(&table, &data);
    }

    #[test]
    fn skewed_distribution() {
        let mut data = vec![0u8; 900];
        data.extend(vec![1u8; 80]);
        data.extend(vec![2u8; 15]);
        data.extend(vec![3u8; 5]);
        let table = HuffmanEncodeTable::from_data(&data).unwrap();
        // The most frequent symbol must get a shorter code than the rarest.
        assert!(table.num_bits[0] < table.num_bits[3]);
        assert_single_stream_roundtrip(&table, &data);
    }

    #[test]
    fn two_symbols() {
        let mut data = vec![0u8; 500];
        data.extend(vec![1u8; 500]);
        let table = HuffmanEncodeTable::from_data(&data).unwrap();
        assert_eq!(table.num_bits[0], 1);
        assert_eq!(table.num_bits[1], 1);
        assert_single_stream_roundtrip(&table, &data);
    }
}
418}