1#![forbid(unsafe_code)]
2
3#[cfg(feature = "alloc")]
4use alloc::vec;
5#[cfg(feature = "alloc")]
6use alloc::vec::Vec;
7
8use super::primitives;
9use crate::huffman::{MAX_BITS, MAX_SYMBOL_VALUE};
10
11#[derive(Clone)]
12pub struct HuffmanEncodeTable {
13 codes: [u16; MAX_SYMBOL_VALUE + 1],
14 num_bits: [u8; MAX_SYMBOL_VALUE + 1],
15 weights: Vec<u8>,
16 max_symbol: u8,
17 table_log: u8,
18}
19
#[cfg(feature = "alloc")]
impl HuffmanEncodeTable {
    /// Largest symbol the direct (4-bit) weight header written by
    /// [`serialize_weights`](Self::serialize_weights) can describe: the header byte stores
    /// `max_symbol + 127`, which must fit in a `u8`.
    const MAX_DIRECT_SYMBOL: usize = 128;

    /// Builds an encoding table from the byte frequencies of `data`.
    ///
    /// Returns `None` when `data` is empty, uses fewer than two distinct byte values,
    /// contains a byte greater than 128 (the direct weight header cannot describe it), or
    /// when `compute_huffman_weights` cannot produce a table of at most `MAX_BITS` bits.
    pub fn from_data(data: &[u8]) -> Option<Self> {
        if data.is_empty() {
            return None;
        }

        let mut freqs = [0u32; MAX_SYMBOL_VALUE + 1];
        let mut max_sym = 0u8;
        for &b in data {
            freqs[b as usize] += 1;
            max_sym = max_sym.max(b);
        }

        let num_symbols = max_sym as usize + 1;
        let active_count = freqs[..num_symbols].iter().filter(|&&f| f > 0).count();
        if active_count < 2 || max_sym as usize > Self::MAX_DIRECT_SYMBOL {
            return None;
        }

        let (weights, table_log) = compute_huffman_weights(&freqs, num_symbols)?;
        let (codes, num_bits) = build_encode_codes(&weights, table_log);

        Some(Self {
            codes,
            num_bits,
            weights,
            max_symbol: max_sym,
            table_log,
        })
    }

    /// Rebuilds an encoding table whose codes match an existing decoding table.
    ///
    /// Only the first `1 << table_log` entries are read; for each symbol, the code length
    /// of its first entry is used to derive its weight (`table_log + 1 - num_bits`).
    ///
    /// Returns `None` when:
    /// - `table_log` is 0 or exceeds `MAX_BITS`;
    /// - `decode_table` holds fewer than `1 << table_log` entries;
    /// - an entry has a code length outside `1..=table_log`;
    /// - the largest symbol exceeds 128;
    /// - fewer than two symbols appear;
    /// - the code lengths do not exactly fill the table.
    pub fn from_decode_table(
        decode_table: &[super::HuffmanDecodeEntry],
        table_log: u8,
    ) -> Option<Self> {
        // `build_encode_codes` sizes its rank arrays for at most MAX_BITS bits; larger logs
        // used to panic there (or overflow the shift below for huge values).
        if table_log == 0 || table_log > MAX_BITS {
            return None;
        }
        let table_size = 1usize << table_log;
        let entries = decode_table.get(..table_size)?;

        let mut num_bits_per_sym = [0u8; MAX_SYMBOL_VALUE + 1];
        let mut seen = [false; MAX_SYMBOL_VALUE + 1];
        let mut max_sym = 0u8;
        for entry in entries {
            let s = entry.symbol;
            if seen[s as usize] {
                continue;
            }
            // Any code in a table of this size is 1..=table_log bits long; other values
            // would produce a bogus weight (or underflow the subtraction below).
            if entry.num_bits == 0 || entry.num_bits > table_log {
                return None;
            }
            num_bits_per_sym[s as usize] = entry.num_bits;
            seen[s as usize] = true;
            if s > max_sym {
                max_sym = s;
            }
        }

        if max_sym as usize > Self::MAX_DIRECT_SYMBOL {
            return None;
        }

        let num_symbols = max_sym as usize + 1;
        let active_count = seen[..num_symbols].iter().filter(|&&s| s).count();
        if active_count < 2 {
            return None;
        }

        // Each symbol owns 2^(table_log - num_bits) slots; together they must cover the
        // table exactly, otherwise the canonical codes rebuilt below cannot line up with it.
        let covered: usize = (0..num_symbols)
            .filter(|&s| seen[s])
            .map(|s| 1usize << (table_log - num_bits_per_sym[s]))
            .sum();
        if covered != table_size {
            return None;
        }

        let weights: Vec<u8> = (0..num_symbols)
            .map(|s| {
                if seen[s] {
                    table_log + 1 - num_bits_per_sym[s]
                } else {
                    0
                }
            })
            .collect();

        let (codes, num_bits) = build_encode_codes(&weights, table_log);

        Some(Self {
            codes,
            num_bits,
            weights,
            max_symbol: max_sym,
            table_log,
        })
    }

    /// Log2 of the decoding table size (also the maximum code length in bits).
    pub fn table_log(&self) -> u8 {
        self.table_log
    }

    /// Returns `true` if every byte of `data` has a code in this table.
    ///
    /// The encoders do not check this themselves: a byte without a code contributes zero
    /// bits and silently disappears from the output.
    pub fn can_encode(&self, data: &[u8]) -> bool {
        data.iter().all(|&b| self.num_bits[b as usize] != 0)
    }

    /// Serializes the weights with the direct header form.
    ///
    /// The output is a header byte `max_symbol + 127`, followed by the weights of symbols
    /// `0..max_symbol` packed two per byte (high nibble first, zero padded). The weight of
    /// `max_symbol` itself is not written; presumably the decoder derives it from the others.
    pub fn serialize_weights(&self) -> Vec<u8> {
        let explicit = &self.weights[..self.max_symbol as usize];

        let mut out = Vec::with_capacity(1 + explicit.len().div_ceil(2));
        out.push((explicit.len() + 127) as u8);
        out.extend(explicit.chunks(2).map(|pair| {
            let hi = pair[0];
            let lo = pair.get(1).copied().unwrap_or(0);
            (hi << 4) | lo
        }));
        out
    }

    /// Encodes `data` as a single bitstream; see
    /// [`encode_single_stream_into`](Self::encode_single_stream_into).
    pub fn encode_single_stream(&self, data: &[u8]) -> Vec<u8> {
        let mut buf = Vec::with_capacity(data.len() + 8);
        self.encode_single_stream_into(data, &mut buf);
        buf
    }

    /// Encodes `data` as a single bitstream into `buf`, replacing its contents.
    ///
    /// Bytes are fed last-to-first; each code is OR-ed into a 64-bit accumulator above the
    /// bits already pending, and whole bytes are flushed little-endian. A single `1` bit
    /// follows the last code as an end marker, so the output is never empty.
    pub fn encode_single_stream_into(&self, data: &[u8], buf: &mut Vec<u8>) {
        buf.clear();
        let tl = self.table_log as usize;
        // Symbols per unrolled batch. A batch adds at most `unroll * tl` bits, which is <= 32
        // for table logs up to 16; with fewer than 32 bits pending before each batch the u64
        // accumulator cannot overflow.
        let unroll: usize = (32usize).checked_div(tl).unwrap_or(1).max(2);

        buf.reserve(data.len() + 16);
        let mut bits: u64 = 0;
        let mut bits_used: u8 = 0;
        let mut wpos: usize = 0;

        // Stores the accumulator at `wpos`, then advances past the complete bytes it held and
        // keeps the remaining (< 8) bits pending.
        macro_rules! flush_bits {
            () => {
                // The capacity check assumes the flush may touch up to 8 bytes at `wpos`.
                if wpos + 8 > buf.capacity() {
                    primitives::set_vec_len(buf, wpos);
                    buf.reserve(64);
                }
                primitives::bitstream_flush_vec(buf, wpos, bits);
                let nb = (bits_used >> 3) as usize;
                wpos += nb;
                bits >>= nb << 3;
                bits_used &= 7;
            };
        }

        // Main loop: whole batches taken from the end of `data`, each batch in reverse.
        let mut pos = data.len();
        while pos >= unroll {
            pos -= unroll;
            for j in 0..unroll {
                let b = primitives::get_unchecked_byte(data, pos + (unroll - 1 - j));
                let c = primitives::get_unchecked_u16(&self.codes, b as usize) as u64;
                let n = primitives::get_unchecked_u8_arr(&self.num_bits, b as usize);
                bits |= c << bits_used;
                bits_used += n;
            }
            if bits_used >= 32 {
                flush_bits!();
            }
        }
        // Remaining leading bytes (fewer than `unroll`), still last-to-first.
        while pos > 0 {
            pos -= 1;
            let b = primitives::get_unchecked_byte(data, pos);
            let c = primitives::get_unchecked_u16(&self.codes, b as usize) as u64;
            let n = primitives::get_unchecked_u8_arr(&self.num_bits, b as usize);
            bits |= c << bits_used;
            bits_used += n;
            if bits_used >= 32 {
                flush_bits!();
            }
        }

        primitives::set_vec_len(buf, wpos);
        // End marker bit, then drain the pending bits byte by byte.
        bits |= 1u64 << bits_used;
        bits_used += 1;
        while bits_used > 0 {
            buf.push(bits as u8);
            bits >>= 8;
            bits_used = bits_used.saturating_sub(8);
        }
    }

    /// Encodes `data` as four bitstreams; see
    /// [`encode_4_streams_into`](Self::encode_4_streams_into).
    ///
    /// # Panics
    /// Panics if one of the first three encoded streams is longer than 65535 bytes.
    pub fn encode_4_streams(&self, data: &[u8]) -> Vec<u8> {
        let mut out = Vec::new();
        self.encode_4_streams_into(data, &mut out, &mut Vec::new());
        out
    }

    /// Encodes `data` as four independent bitstreams into `out`, replacing its contents.
    ///
    /// `data` is split into four segments of `ceil(len / 4)` bytes; the last one may be
    /// shorter or empty. The output is a 6-byte jump table holding the encoded sizes of
    /// streams 1-3 as little-endian `u16`, followed by the four streams back to back.
    /// `stream_buf` is scratch space reused between the streams.
    ///
    /// # Panics
    /// Panics if one of the first three encoded streams is longer than 65535 bytes, since
    /// its size could not be stored in the jump table (previously it was silently truncated,
    /// producing a corrupt header).
    pub fn encode_4_streams_into(&self, data: &[u8], out: &mut Vec<u8>, stream_buf: &mut Vec<u8>) {
        let len = data.len();
        let seg = len.div_ceil(4);
        let bound = |k: usize| (seg * k).min(len);
        let segments = [
            &data[..bound(1)],
            &data[bound(1)..bound(2)],
            &data[bound(2)..bound(3)],
            &data[bound(3)..],
        ];

        out.clear();
        out.extend_from_slice(&[0u8; 6]);

        // Only the sizes of the first three streams go into the jump table.
        let mut sizes = [0usize; 3];
        for (i, &segment) in segments.iter().enumerate() {
            self.encode_single_stream_into(segment, stream_buf);
            if let Some(size) = sizes.get_mut(i) {
                *size = stream_buf.len();
            }
            out.extend_from_slice(stream_buf);
        }

        for (entry, &size) in out[..6].chunks_exact_mut(2).zip(&sizes) {
            assert!(
                size <= usize::from(u16::MAX),
                "Huffman stream of {size} bytes does not fit the 16-bit jump table"
            );
            entry.copy_from_slice(&(size as u16).to_le_bytes());
        }
    }

    /// Exact size in bytes of [`encode_single_stream`](Self::encode_single_stream)'s output
    /// for `data`: the code bits plus the 1-bit end marker, rounded up to whole bytes.
    pub fn compressed_size_single(&self, data: &[u8]) -> usize {
        let payload_bits: usize = data
            .iter()
            .map(|&b| usize::from(self.num_bits[b as usize]))
            .sum();
        (payload_bits + 1).div_ceil(8)
    }
}
247
248fn compute_huffman_weights(freqs: &[u32], num_symbols: usize) -> Option<(Vec<u8>, u8)> {
249 use alloc::collections::BinaryHeap;
250 use core::cmp::Reverse;
251
252 let active: Vec<(u64, usize)> = freqs[..num_symbols]
253 .iter()
254 .enumerate()
255 .filter(|(_, f)| **f > 0)
256 .map(|(s, &f)| (f as u64, s))
257 .collect();
258
259 if active.len() < 2 {
260 return None;
261 }
262
263 let n = active.len();
264
265 let max_nodes = 2 * n;
266 let mut parent = vec![usize::MAX; max_nodes];
267
268 let mut heap: BinaryHeap<Reverse<(u64, usize)>> = BinaryHeap::with_capacity(n);
269 for (i, &(f, _)) in active.iter().enumerate() {
270 heap.push(Reverse((f, i)));
271 }
272
273 for next_id in n..n + (n - 1) {
274 let Reverse((f1, n1)) = heap.pop().unwrap();
275 let Reverse((f2, n2)) = heap.pop().unwrap();
276 parent[n1] = next_id;
277 parent[n2] = next_id;
278 heap.push(Reverse((f1 + f2, next_id)));
279 }
280
281 let mut bit_lengths = vec![0u8; num_symbols];
282 for (i, &(_, sym)) in active.iter().enumerate().take(n) {
283 let mut depth = 0u8;
284 let mut node = i;
285 while parent[node] != usize::MAX {
286 depth += 1;
287 node = parent[node];
288 }
289 bit_lengths[sym] = depth;
290 }
291
292 let max_bl = *bit_lengths.iter().max().unwrap();
293 if max_bl == 0 || max_bl > MAX_BITS {
294 return None;
295 }
296
297 let table_log = max_bl;
298 let mut weights = vec![0u8; num_symbols];
299 for (s, &bl) in bit_lengths.iter().enumerate() {
300 if bl > 0 {
301 weights[s] = table_log + 1 - bl;
302 }
303 }
304
305 Some((weights, table_log))
306}
307
308fn build_encode_codes(
309 weights: &[u8],
310 table_log: u8,
311) -> ([u16; MAX_SYMBOL_VALUE + 1], [u8; MAX_SYMBOL_VALUE + 1]) {
312 let mut codes = [0u16; MAX_SYMBOL_VALUE + 1];
313 let mut num_bits = [0u8; MAX_SYMBOL_VALUE + 1];
314
315 let max_w = table_log + 1;
316 let mut rank_count = [0u32; MAX_BITS as usize + 2];
317
318 for (s, &w) in weights.iter().enumerate() {
319 if w > 0 && w <= max_w {
320 num_bits[s] = table_log + 1 - w;
321 rank_count[w as usize] += 1;
322 }
323 }
324
325 let mut rank_start = [0u32; MAX_BITS as usize + 2];
326 let mut cumul = 0u32;
327 for w in 1..=max_w {
328 rank_start[w as usize] = cumul;
329 cumul += rank_count[w as usize] * (1u32 << (w - 1));
330 }
331
332 for (s, &w) in weights.iter().enumerate() {
333 if w == 0 {
334 continue;
335 }
336 let start = rank_start[w as usize];
337 codes[s] = (start >> (w - 1)) as u16;
338 rank_start[w as usize] += 1u32 << (w - 1);
339 }
340
341 (codes, num_bits)
342}
343
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes `table`'s weights, rebuilds a decoder from that header, and checks that the
    /// single-stream encoding of `data` decodes back to `data`.
    fn assert_single_stream_roundtrip(table: &HuffmanEncodeTable, data: &[u8]) {
        let header = table.serialize_weights();
        let stream = table.encode_single_stream(data);

        let (weights, _) = crate::huffman::weights::parse_huffman_weights(&header).unwrap();
        let (dtable, dlog) =
            crate::huffman::weights::build_huffman_decode_table(&weights).unwrap();
        let decoded =
            crate::huffman::decode::decode_single_stream(&dtable, dlog, &stream, data.len())
                .unwrap();
        assert_eq!(decoded, data);
    }

    /// 900 zeros, 80 ones, 15 twos and 5 threes, in that order.
    fn skewed_data() -> Vec<u8> {
        [(0u8, 900usize), (1, 80), (2, 15), (3, 5)]
            .iter()
            .flat_map(|&(sym, count)| core::iter::repeat(sym).take(count))
            .collect()
    }

    #[test]
    fn from_decode_table_roundtrip() {
        let data = b"hello world hello world hello world!";
        let original = HuffmanEncodeTable::from_data(data).unwrap();

        let header = original.serialize_weights();
        let (weights, _) = crate::huffman::weights::parse_huffman_weights(&header).unwrap();
        let (dtable, dlog) =
            crate::huffman::weights::build_huffman_decode_table(&weights).unwrap();

        let rebuilt = HuffmanEncodeTable::from_decode_table(&dtable, dlog).unwrap();
        assert_eq!(original.table_log, rebuilt.table_log);
        assert_eq!(original.max_symbol, rebuilt.max_symbol);
        assert_eq!(original.weights, rebuilt.weights);
        assert_eq!(original.num_bits, rebuilt.num_bits);
        assert_eq!(original.codes, rebuilt.codes);

        let stream = rebuilt.encode_single_stream(data);
        let decoded =
            crate::huffman::decode::decode_single_stream(&dtable, dlog, &stream, data.len())
                .unwrap();
        assert_eq!(decoded, data);
    }

    #[test]
    fn from_decode_table_skewed() {
        let data = skewed_data();
        let original = HuffmanEncodeTable::from_data(&data).unwrap();

        let header = original.serialize_weights();
        let (weights, _) = crate::huffman::weights::parse_huffman_weights(&header).unwrap();
        let (dtable, dlog) =
            crate::huffman::weights::build_huffman_decode_table(&weights).unwrap();

        let rebuilt = HuffmanEncodeTable::from_decode_table(&dtable, dlog).unwrap();
        assert_eq!(original.codes, rebuilt.codes);
        assert_eq!(original.num_bits, rebuilt.num_bits);

        let stream = rebuilt.encode_single_stream(&data);
        let decoded =
            crate::huffman::decode::decode_single_stream(&dtable, dlog, &stream, data.len())
                .unwrap();
        assert_eq!(decoded, data);
    }

    #[test]
    fn roundtrip_simple() {
        let data = b"hello world hello world hello world!";
        let table = HuffmanEncodeTable::from_data(data).unwrap();
        assert_single_stream_roundtrip(&table, data);
    }

    #[test]
    fn roundtrip_4_streams() {
        let data: Vec<u8> = (0..1024usize).map(|i| b"ABCDEFGH"[i % 8]).collect();
        let table = HuffmanEncodeTable::from_data(&data).unwrap();

        let header = table.serialize_weights();
        let streams = table.encode_4_streams(&data);

        let (weights, _) = crate::huffman::weights::parse_huffman_weights(&header).unwrap();
        let (dtable, dlog) =
            crate::huffman::weights::build_huffman_decode_table(&weights).unwrap();
        let decoded =
            crate::huffman::decode::decode_4_streams(&dtable, dlog, &streams, data.len())
                .unwrap();
        assert_eq!(decoded, data);
    }

    #[test]
    fn roundtrip_all_bytes() {
        // Symbols 0..=127 repeated 32 times.
        let data: Vec<u8> = (0..4096usize).map(|i| (i % 128) as u8).collect();
        let table = HuffmanEncodeTable::from_data(&data).unwrap();
        assert_single_stream_roundtrip(&table, &data);
    }

    #[test]
    fn skewed_distribution() {
        let data = skewed_data();
        let table = HuffmanEncodeTable::from_data(&data).unwrap();
        // The most frequent symbol must get a shorter code than the rarest one.
        assert!(table.num_bits[0] < table.num_bits[3]);
        assert_single_stream_roundtrip(&table, &data);
    }

    #[test]
    fn two_symbols() {
        // 500 zeros followed by 500 ones.
        let data: Vec<u8> = (0..1000usize).map(|i| u8::from(i >= 500)).collect();
        let table = HuffmanEncodeTable::from_data(&data).unwrap();
        assert_eq!(table.num_bits[0], 1);
        assert_eq!(table.num_bits[1], 1);
        assert_single_stream_roundtrip(&table, &data);
    }
}