1#![forbid(unsafe_code)]
2
3#[cfg(feature = "alloc")]
4use alloc::vec;
5#[cfg(feature = "alloc")]
6use alloc::vec::Vec;
7
8use super::primitives;
9use crate::huffman::{MAX_BITS, MAX_SYMBOL_VALUE};
10
/// Huffman encoding table built from a byte histogram (see `from_data`).
///
/// Codes and code lengths are indexed directly by byte value. A code length of
/// 0 marks a byte value that has no code in this table.
///
/// Gated on `alloc` like its `impl`, because it owns a `Vec`.
#[cfg(feature = "alloc")]
pub struct HuffmanEncodeTable {
    /// Code bits per symbol, held in the low `num_bits[s]` bits.
    codes: [u16; MAX_SYMBOL_VALUE + 1],
    /// Code length in bits per symbol; 0 for symbols without a code.
    num_bits: [u8; MAX_SYMBOL_VALUE + 1],
    /// Weight of each symbol `0..=max_symbol`; 0 for symbols that never occurred.
    weights: Vec<u8>,
    /// Largest byte value seen in the data the table was built from.
    max_symbol: u8,
    /// Longest code length in bits.
    table_log: u8,
}
18
#[cfg(feature = "alloc")]
impl HuffmanEncodeTable {
    /// Builds an encoding table from the byte histogram of `data`.
    ///
    /// Returns `None` when `data` cannot be coded with this table format:
    /// - `data` is empty or contains fewer than two distinct byte values;
    /// - the largest byte value is above 128, because
    ///   [`serialize_weights`](Self::serialize_weights) writes it into a header
    ///   byte as `max_symbol + 127`;
    /// - the Huffman tree for the histogram is deeper than `MAX_BITS` (code
    ///   lengths are not limited; such input is simply rejected).
    pub fn from_data(data: &[u8]) -> Option<Self> {
        if data.is_empty() {
            return None;
        }

        let mut freqs = [0u32; MAX_SYMBOL_VALUE + 1];
        let mut max_sym = 0u8;
        for &b in data {
            freqs[b as usize] += 1;
            max_sym = max_sym.max(b);
        }

        let num_symbols = max_sym as usize + 1;
        let active_count = freqs[..num_symbols].iter().filter(|&&f| f > 0).count();
        if active_count < 2 {
            return None;
        }

        // The weight header byte is `max_symbol + 127` and must fit in a u8.
        if max_sym as usize > 128 {
            return None;
        }

        let (weights, table_log) = compute_huffman_weights(&freqs, num_symbols)?;
        let (codes, num_bits) = build_encode_codes(&weights, table_log);

        Some(Self {
            codes,
            num_bits,
            weights,
            max_symbol: max_sym,
            table_log,
        })
    }

    /// Longest code length in bits.
    pub fn table_log(&self) -> u8 {
        self.table_log
    }

    /// Returns `true` if every byte of `data` has a code in this table.
    pub fn can_encode(&self, data: &[u8]) -> bool {
        data.iter().all(|&b| self.num_bits[b as usize] != 0)
    }

    /// Serializes the weights in the direct 4-bit-per-weight format.
    ///
    /// The first byte is `127 + n`, where `n = max_symbol` is the number of
    /// weights written. The weight of `max_symbol` itself is not written;
    /// presumably the decoder reconstructs it, as in the Zstandard format.
    /// Weights follow two per byte, the first in the high nibble. An odd count
    /// is padded with a zero low nibble.
    pub fn serialize_weights(&self) -> Vec<u8> {
        let explicit = &self.weights[..self.max_symbol as usize];
        let mut out = Vec::with_capacity(1 + explicit.len().div_ceil(2));
        out.push((explicit.len() + 127) as u8);
        out.extend(explicit.chunks(2).map(|pair| {
            let hi = pair[0];
            let lo = pair.get(1).copied().unwrap_or(0);
            (hi << 4) | lo
        }));
        out
    }

    /// Encodes `data` as a single bitstream and returns it in a new buffer.
    ///
    /// See [`encode_single_stream_into`](Self::encode_single_stream_into).
    pub fn encode_single_stream(&self, data: &[u8]) -> Vec<u8> {
        let mut buf = Vec::with_capacity(data.len() + 8);
        self.encode_single_stream_into(data, &mut buf);
        buf
    }

    /// Encodes `data` as a single Huffman bitstream into `buf`.
    ///
    /// `buf` is cleared first and its allocation is reused.
    ///
    /// Symbols are processed from the last byte of `data` to the first. Each
    /// code is OR-ed into a 64-bit accumulator above the bits already present,
    /// so the first byte of `data` lands in the highest (last-written) bits.
    /// Presumably this lets a decoder read the stream back from its end.
    ///
    /// After the final symbol, a single `1` end-marker bit is appended. The
    /// remaining bits are then written low byte first, padded to a whole byte.
    ///
    /// Bytes without a code (see [`can_encode`](Self::can_encode)) contribute
    /// zero bits.
    pub fn encode_single_stream_into(&self, data: &[u8], buf: &mut Vec<u8>) {
        buf.clear();
        let tl = self.table_log as usize;
        // Symbols per batch between flush checks. Every code is at most `tl`
        // bits, so for table logs up to 16 a batch adds at most 32 bits. Starting
        // from fewer than 32 buffered bits, the accumulator stays within 64 bits.
        let unroll: usize = (32usize).checked_div(tl).unwrap_or(1).max(2);

        buf.reserve(data.len() + 16);
        let mut bits: u64 = 0;
        let mut bits_used: u8 = 0;
        // Number of completed bytes committed to `buf`.
        let mut wpos: usize = 0;

        // Commits the complete bytes held in the accumulator at `wpos` and keeps
        // only the bits of the trailing partial byte.
        // NOTE(review): bitstream_flush_vec appears to write all 8 bytes of
        // `bits` at `wpos` (hence the 8-byte capacity check) — confirm in
        // primitives.
        macro_rules! flush_bits {
            () => {
                if wpos + 8 > buf.capacity() {
                    primitives::set_vec_len(buf, wpos);
                    buf.reserve(64);
                }
                primitives::bitstream_flush_vec(buf, wpos, bits);
                let nb = (bits_used >> 3) as usize;
                wpos += nb;
                bits >>= nb << 3;
                bits_used &= 7;
            };
        }

        let mut pos = data.len();
        while pos >= unroll {
            pos -= unroll;
            for j in 0..unroll {
                let b = primitives::get_unchecked_byte(data, pos + (unroll - 1 - j));
                let c = primitives::get_unchecked_u16(&self.codes, b as usize) as u64;
                let n = primitives::get_unchecked_u8_arr(&self.num_bits, b as usize);
                bits |= c << bits_used;
                bits_used += n;
            }
            if bits_used >= 32 {
                flush_bits!();
            }
        }
        // Leftover symbols at the front of `data`, one at a time.
        while pos > 0 {
            pos -= 1;
            let b = primitives::get_unchecked_byte(data, pos);
            let c = primitives::get_unchecked_u16(&self.codes, b as usize) as u64;
            let n = primitives::get_unchecked_u8_arr(&self.num_bits, b as usize);
            bits |= c << bits_used;
            bits_used += n;
            if bits_used >= 32 {
                flush_bits!();
            }
        }

        // Commit flushed bytes, then append the end marker. At most 32 bits
        // remain, i.e. at most 4 more bytes.
        primitives::set_vec_len(buf, wpos);
        bits |= 1u64 << bits_used;
        bits_used += 1;
        while bits_used > 0 {
            buf.push(bits as u8);
            bits >>= 8;
            bits_used = bits_used.saturating_sub(8);
        }
    }

    /// Encodes `data` as four streams with a jump table and returns the result.
    ///
    /// See [`encode_4_streams_into`](Self::encode_4_streams_into).
    ///
    /// # Panics
    ///
    /// Same conditions as `encode_4_streams_into`.
    pub fn encode_4_streams(&self, data: &[u8]) -> Vec<u8> {
        let mut out = Vec::new();
        self.encode_4_streams_into(data, &mut out, &mut Vec::new());
        out
    }

    /// Encodes `data` as four independent streams preceded by a 6-byte jump table.
    ///
    /// `data` is split into four consecutive segments of `ceil(len / 4)` bytes.
    /// For small inputs, the later segments may be shorter or empty.
    ///
    /// `out` is cleared and then receives:
    /// - the little-endian `u16` byte lengths of encoded streams 1–3;
    /// - all four encoded streams, in order.
    ///
    /// The fourth stream's length is implied by the total. `stream_buf` is
    /// scratch space reused for each stream.
    ///
    /// # Panics
    ///
    /// Panics if any of the first three encoded streams is longer than
    /// `u16::MAX` bytes, since its length cannot be represented in the jump
    /// table.
    pub fn encode_4_streams_into(&self, data: &[u8], out: &mut Vec<u8>, stream_buf: &mut Vec<u8>) {
        let seg = data.len().div_ceil(4);
        let bound = |k: usize| (seg * k).min(data.len());
        let segments = [
            &data[..bound(1)],
            &data[bound(1)..bound(2)],
            &data[bound(2)..bound(3)],
            &data[bound(3)..],
        ];

        out.clear();
        out.extend_from_slice(&[0u8; 6]);

        for (i, &segment) in segments.iter().enumerate() {
            self.encode_single_stream_into(segment, stream_buf);
            // The jump table only has slots for the first three streams.
            if i < 3 {
                let len = stream_buf.len();
                assert!(
                    len <= u16::MAX as usize,
                    "huffman stream {} is {} bytes, which does not fit the 16-bit jump table",
                    i + 1,
                    len
                );
                out[2 * i..2 * i + 2].copy_from_slice(&(len as u16).to_le_bytes());
            }
            out.extend_from_slice(stream_buf);
        }
    }

    /// Returns the exact byte length that
    /// [`encode_single_stream`](Self::encode_single_stream) produces for `data`.
    ///
    /// This is the sum of all code lengths plus the one-bit end marker, rounded
    /// up to whole bytes.
    pub fn compressed_size_single(&self, data: &[u8]) -> usize {
        let total_bits: usize = data
            .iter()
            .map(|&b| usize::from(self.num_bits[b as usize]))
            .sum();
        (total_bits + 1).div_ceil(8)
    }
}
195
196fn compute_huffman_weights(freqs: &[u32], num_symbols: usize) -> Option<(Vec<u8>, u8)> {
197 use alloc::collections::BinaryHeap;
198 use core::cmp::Reverse;
199
200 let active: Vec<(u64, usize)> = freqs[..num_symbols]
201 .iter()
202 .enumerate()
203 .filter(|(_, f)| **f > 0)
204 .map(|(s, &f)| (f as u64, s))
205 .collect();
206
207 if active.len() < 2 {
208 return None;
209 }
210
211 let n = active.len();
212
213 let max_nodes = 2 * n;
214 let mut parent = vec![usize::MAX; max_nodes];
215
216 let mut heap: BinaryHeap<Reverse<(u64, usize)>> = BinaryHeap::with_capacity(n);
217 for (i, &(f, _)) in active.iter().enumerate() {
218 heap.push(Reverse((f, i)));
219 }
220
221 for next_id in n..n + (n - 1) {
222 let Reverse((f1, n1)) = heap.pop().unwrap();
223 let Reverse((f2, n2)) = heap.pop().unwrap();
224 parent[n1] = next_id;
225 parent[n2] = next_id;
226 heap.push(Reverse((f1 + f2, next_id)));
227 }
228
229 let mut bit_lengths = vec![0u8; num_symbols];
230 for (i, &(_, sym)) in active.iter().enumerate().take(n) {
231 let mut depth = 0u8;
232 let mut node = i;
233 while parent[node] != usize::MAX {
234 depth += 1;
235 node = parent[node];
236 }
237 bit_lengths[sym] = depth;
238 }
239
240 let max_bl = *bit_lengths.iter().max().unwrap();
241 if max_bl == 0 || max_bl > MAX_BITS {
242 return None;
243 }
244
245 let table_log = max_bl;
246 let mut weights = vec![0u8; num_symbols];
247 for (s, &bl) in bit_lengths.iter().enumerate() {
248 if bl > 0 {
249 weights[s] = table_log + 1 - bl;
250 }
251 }
252
253 Some((weights, table_log))
254}
255
256fn build_encode_codes(
257 weights: &[u8],
258 table_log: u8,
259) -> ([u16; MAX_SYMBOL_VALUE + 1], [u8; MAX_SYMBOL_VALUE + 1]) {
260 let mut codes = [0u16; MAX_SYMBOL_VALUE + 1];
261 let mut num_bits = [0u8; MAX_SYMBOL_VALUE + 1];
262
263 let max_w = table_log + 1;
264 let mut rank_count = [0u32; MAX_BITS as usize + 2];
265
266 for (s, &w) in weights.iter().enumerate() {
267 if w > 0 && w <= max_w {
268 num_bits[s] = table_log + 1 - w;
269 rank_count[w as usize] += 1;
270 }
271 }
272
273 let mut rank_start = [0u32; MAX_BITS as usize + 2];
274 let mut cumul = 0u32;
275 for w in 1..=max_w {
276 rank_start[w as usize] = cumul;
277 cumul += rank_count[w as usize] * (1u32 << (w - 1));
278 }
279
280 for (s, &w) in weights.iter().enumerate() {
281 if w == 0 {
282 continue;
283 }
284 let start = rank_start[w as usize];
285 codes[s] = (start >> (w - 1)) as u16;
286 rank_start[w as usize] += 1u32 << (w - 1);
287 }
288
289 (codes, num_bits)
290}
291
#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trips `data` through a single stream and returns the table.
    ///
    /// Builds a table for `data` and encodes it as one stream. It then decodes
    /// with the crate's decoder, starting from the serialized weights, and
    /// asserts the bytes match.
    fn roundtrip_single(data: &[u8]) -> HuffmanEncodeTable {
        let table = HuffmanEncodeTable::from_data(data).unwrap();
        let weights_raw = table.serialize_weights();
        let encoded = table.encode_single_stream(data);

        let (parsed_weights, _) =
            crate::huffman::weights::parse_huffman_weights(&weights_raw).unwrap();
        let (decode_table, decode_log) =
            crate::huffman::weights::build_huffman_decode_table(&parsed_weights).unwrap();
        let decoded = crate::huffman::decode::decode_single_stream(
            &decode_table,
            decode_log,
            &encoded,
            data.len(),
        )
        .unwrap();
        assert_eq!(decoded, data);
        table
    }

    /// Same as [`roundtrip_single`], but using the four-stream encoding.
    fn roundtrip_four(data: &[u8]) {
        let table = HuffmanEncodeTable::from_data(data).unwrap();
        let weights_raw = table.serialize_weights();
        let encoded = table.encode_4_streams(data);

        let (parsed_weights, _) =
            crate::huffman::weights::parse_huffman_weights(&weights_raw).unwrap();
        let (decode_table, decode_log) =
            crate::huffman::weights::build_huffman_decode_table(&parsed_weights).unwrap();
        let decoded = crate::huffman::decode::decode_4_streams(
            &decode_table,
            decode_log,
            &encoded,
            data.len(),
        )
        .unwrap();
        assert_eq!(decoded, data);
    }

    #[test]
    fn roundtrip_simple() {
        roundtrip_single(b"hello world hello world hello world!");
    }

    #[test]
    fn roundtrip_4_streams() {
        let data: Vec<u8> = b"ABCDEFGH".iter().cycle().take(1024).copied().collect();
        roundtrip_four(&data);
    }

    #[test]
    fn roundtrip_all_bytes() {
        let data: Vec<u8> = (0u8..=127).cycle().take(4096).collect();
        roundtrip_single(&data);
    }

    #[test]
    fn skewed_distribution() {
        let mut data = vec![0u8; 900];
        data.extend(vec![1u8; 80]);
        data.extend(vec![2u8; 15]);
        data.extend(vec![3u8; 5]);
        let table = roundtrip_single(&data);
        assert!(table.num_bits[0] < table.num_bits[3]);
    }

    #[test]
    fn two_symbols() {
        let mut data = vec![0u8; 500];
        data.extend(vec![1u8; 500]);
        let table = roundtrip_single(&data);
        assert_eq!(table.num_bits[0], 1);
        assert_eq!(table.num_bits[1], 1);
    }

    #[test]
    fn rejects_uncodable_inputs() {
        assert!(HuffmanEncodeTable::from_data(b"").is_none());
        assert!(HuffmanEncodeTable::from_data(b"aaaa").is_none());
        // A byte value above 128 cannot be described by the weight header.
        assert!(HuffmanEncodeTable::from_data(&[0, 200]).is_none());
    }

    #[test]
    fn can_encode_only_seen_symbols() {
        let table = HuffmanEncodeTable::from_data(b"abab").unwrap();
        assert!(table.can_encode(b"ba"));
        assert!(table.can_encode(b""));
        assert!(!table.can_encode(b"abc"));
    }

    #[test]
    fn size_estimate_matches_encoding() {
        let data = b"hello world hello world hello world!";
        let table = HuffmanEncodeTable::from_data(data).unwrap();
        assert_eq!(
            table.compressed_size_single(data),
            table.encode_single_stream(data).len()
        );

        let mut skewed = vec![0u8; 900];
        skewed.extend(vec![1u8; 80]);
        skewed.extend(vec![2u8; 15]);
        skewed.extend(vec![3u8; 5]);
        let table = HuffmanEncodeTable::from_data(&skewed).unwrap();
        assert_eq!(
            table.compressed_size_single(&skewed),
            table.encode_single_stream(&skewed).len()
        );
    }
}