1#[cfg(feature = "alloc")]
2use alloc::vec;
3#[cfg(feature = "alloc")]
4use alloc::vec::Vec;
5
6use crate::huffman::{MAX_BITS, MAX_SYMBOL_VALUE};
7
8pub struct HuffmanEncodeTable {
9 codes: [u16; MAX_SYMBOL_VALUE + 1],
10 num_bits: [u8; MAX_SYMBOL_VALUE + 1],
11 weights: Vec<u8>,
12 max_symbol: u8,
13 table_log: u8,
14}
15
#[cfg(feature = "alloc")]
impl HuffmanEncodeTable {
    /// Builds an encoding table from the byte histogram of `data`.
    ///
    /// Returns `None` when no usable table can be produced:
    /// * `data` is empty or contains fewer than two distinct byte values;
    /// * the largest byte value present is greater than 128;
    /// * `compute_huffman_weights` rejects the histogram (for example because
    ///   the longest code would exceed `MAX_BITS`).
    pub fn from_data(data: &[u8]) -> Option<Self> {
        // `serialize_weights` stores the number of explicit weights (which equals
        // the largest symbol value) as `count + 127` in a single byte, so a
        // largest symbol above 128 cannot be represented.
        const MAX_SERIALIZABLE_SYMBOL: u8 = 128;

        if data.is_empty() {
            return None;
        }

        let mut freqs = [0u32; MAX_SYMBOL_VALUE + 1];
        let mut max_sym = 0u8;
        for &b in data {
            freqs[usize::from(b)] += 1;
            max_sym = max_sym.max(b);
        }

        let num_symbols = usize::from(max_sym) + 1;
        let distinct = freqs[..num_symbols].iter().filter(|&&f| f > 0).count();
        if distinct < 2 || max_sym > MAX_SERIALIZABLE_SYMBOL {
            return None;
        }

        let (weights, table_log) = compute_huffman_weights(&freqs, num_symbols)?;
        let (codes, num_bits) = build_encode_codes(&weights, table_log);

        Some(Self {
            codes,
            num_bits,
            weights,
            max_symbol: max_sym,
            table_log,
        })
    }

    /// Returns the length in bits of the longest code in this table.
    pub fn table_log(&self) -> u8 {
        self.table_log
    }

    /// Returns `true` when every byte of `data` has a code in this table.
    ///
    /// An empty slice is always encodable.
    pub fn can_encode(&self, data: &[u8]) -> bool {
        data.iter().all(|&b| self.num_bits[usize::from(b)] != 0)
    }

    /// Serializes the table's weights into a compact byte description.
    ///
    /// The first byte is `count + 127`, where `count` is the number of explicit
    /// weights (symbols `0..max_symbol`; the weight of `max_symbol` itself is
    /// not written). The weights follow, two per byte, with the first weight of
    /// each pair in the high nibble; an odd count leaves the final low nibble
    /// at zero.
    // NOTE(review): this matches the "direct" weight representation of the
    // Zstandard format (RFC 8878), where the decoder derives the final weight —
    // confirm against `crate::huffman::weights::parse_huffman_weights`.
    pub fn serialize_weights(&self) -> Vec<u8> {
        let explicit = &self.weights[..usize::from(self.max_symbol)];
        let count = explicit.len();

        let mut out = Vec::with_capacity(1 + (count + 1) / 2);
        out.push((count + 127) as u8);
        out.extend(explicit.chunks(2).map(|pair| {
            let hi = pair[0];
            let lo = pair.get(1).copied().unwrap_or(0);
            (hi << 4) | lo
        }));
        out
    }

    /// Encodes `data` as one bitstream and returns it in a fresh buffer.
    ///
    /// See [`encode_single_stream_into`](Self::encode_single_stream_into) for
    /// the stream layout and preconditions.
    pub fn encode_single_stream(&self, data: &[u8]) -> Vec<u8> {
        // Matches the amount `encode_single_stream_into` reserves up front, so
        // the buffer is allocated only once.
        let mut buf = Vec::with_capacity(data.len() + 16);
        self.encode_single_stream_into(data, &mut buf);
        buf
    }

    /// Encodes `data` as one bitstream into `buf`, replacing its previous contents.
    ///
    /// Symbols are consumed from the last byte of `data` to the first. Each
    /// code is appended above the bits already written (least-significant bit
    /// first), and after the final symbol a single `1` bit is appended as an
    /// end marker; the unused high bits of the last byte are zero.
    ///
    /// Every byte of `data` must have a code (see
    /// [`can_encode`](Self::can_encode)); a byte without one contributes no
    /// bits, so the output would not decode back to `data`.
    pub fn encode_single_stream_into(&self, data: &[u8], buf: &mut Vec<u8>) {
        buf.clear();
        let tl = usize::from(self.table_log);
        // Number of symbols accumulated between flush checks. A group adds at
        // most `unroll * table_log` bits on top of fewer than 32 pending bits;
        // this stays below 64 as long as `table_log <= 16`.
        // NOTE(review): presumably guaranteed by the `MAX_BITS` limit applied
        // when the table is built — confirm.
        let unroll: usize = if tl > 0 { (32 / tl).max(2) } else { 2 };

        buf.reserve(data.len() + 16);
        let mut bits: u64 = 0;
        let mut bits_used: u8 = 0;
        let mut wpos: usize = 0;
        let codes = &self.codes;
        let lens = &self.num_bits;

        // Stores the accumulator at `wpos` and retires the complete bytes,
        // keeping the 0..=7 leftover bits in the accumulator.
        macro_rules! flush_bits {
            () => {
                if wpos + 8 > buf.capacity() {
                    // SAFETY: the first `wpos` bytes were initialised by earlier
                    // flushes, and `wpos` never exceeds the capacity because each
                    // flush checks for 8 spare bytes and advances by at most 7.
                    // Publishing the length makes `reserve` preserve those bytes.
                    unsafe {
                        buf.set_len(wpos);
                    }
                    buf.reserve(64);
                }
                // SAFETY: the check above guarantees `wpos + 8 <= buf.capacity()`,
                // so the 8-byte unaligned store stays inside the allocation.
                unsafe {
                    let ptr = buf.as_mut_ptr().add(wpos);
                    (ptr as *mut u64).write_unaligned(bits.to_le());
                }
                let nb = (bits_used >> 3) as usize;
                wpos += nb;
                bits >>= nb << 3;
                bits_used &= 7;
            };
        }

        // Full groups are taken from the end of `data`; whatever does not fill
        // a group is left at the front and handled symbol by symbol below.
        let mut groups = data.rchunks_exact(unroll);
        for group in groups.by_ref() {
            for &sym in group.iter().rev() {
                // Safe indexing: a `u8` index can only be out of range if the
                // tables had fewer than 256 entries, which would panic here
                // rather than read out of bounds.
                let idx = usize::from(sym);
                bits |= u64::from(codes[idx]) << bits_used;
                bits_used += lens[idx];
            }
            if bits_used >= 32 {
                flush_bits!();
            }
        }
        for &sym in groups.remainder().iter().rev() {
            let idx = usize::from(sym);
            bits |= u64::from(codes[idx]) << bits_used;
            bits_used += lens[idx];
            if bits_used >= 32 {
                flush_bits!();
            }
        }

        // SAFETY: exactly `wpos` bytes have been written through the raw
        // pointer above and `wpos <= buf.capacity()` (see `flush_bits!`).
        unsafe {
            buf.set_len(wpos);
        }
        // End marker, then drain the remaining (fewer than 40) bits bytewise.
        bits |= 1u64 << bits_used;
        bits_used += 1;
        while bits_used > 0 {
            buf.push(bits as u8);
            bits >>= 8;
            bits_used = bits_used.saturating_sub(8);
        }
    }

    /// Encodes `data` as four independent bitstreams and returns the result.
    ///
    /// See [`encode_4_streams_into`](Self::encode_4_streams_into) for the layout.
    pub fn encode_4_streams(&self, data: &[u8]) -> Vec<u8> {
        let mut out = Vec::new();
        self.encode_4_streams_into(data, &mut out, &mut Vec::new());
        out
    }

    /// Encodes `data` as four independent bitstreams into `out`.
    ///
    /// `data` is split into four consecutive segments of `ceil(len / 4)` bytes
    /// (the last one takes whatever remains and may be shorter or empty). `out`
    /// is cleared and then receives a 6-byte header holding the encoded sizes
    /// of the first three streams as little-endian `u16`s, followed by the four
    /// streams back to back. The size of the fourth stream is not stored.
    ///
    /// `stream_buf` is scratch space that is overwritten; on return it holds
    /// the encoding of the fourth stream.
    ///
    /// Each of the first three encoded streams must be at most 65535 bytes
    /// long; longer streams cannot be represented in the header (checked with
    /// a debug assertion).
    pub fn encode_4_streams_into(&self, data: &[u8], out: &mut Vec<u8>, stream_buf: &mut Vec<u8>) {
        let seg = (data.len() + 3) / 4;
        let (s1, rest) = data.split_at(seg.min(data.len()));
        let (s2, rest) = rest.split_at(seg.min(rest.len()));
        let (s3, s4) = rest.split_at(seg.min(rest.len()));

        out.clear();
        // Placeholder for the size header, filled in once the sizes are known.
        out.extend_from_slice(&[0u8; 6]);

        let leading = [s1, s2, s3];
        let mut sizes = [0usize; 3];
        for (size, stream) in sizes.iter_mut().zip(leading.iter()) {
            self.encode_single_stream_into(stream, stream_buf);
            *size = stream_buf.len();
            out.extend_from_slice(stream_buf);
        }

        self.encode_single_stream_into(s4, stream_buf);
        out.extend_from_slice(stream_buf);

        for (slot, &size) in out[..6].chunks_exact_mut(2).zip(sizes.iter()) {
            debug_assert!(
                size <= 0xFFFF,
                "encoded stream of {} bytes does not fit the 16-bit size header",
                size
            );
            slot.copy_from_slice(&(size as u16).to_le_bytes());
        }
    }

    /// Returns the number of bytes `encode_single_stream` produces for `data`.
    ///
    /// This is the sum of the code lengths plus one end-marker bit, rounded up
    /// to whole bytes. Bytes without a code count as zero bits.
    pub fn compressed_size_single(&self, data: &[u8]) -> usize {
        let total_bits: usize = data
            .iter()
            .map(|&b| usize::from(self.num_bits[usize::from(b)]))
            .sum();
        (total_bits + 8) / 8
    }
}
203
204fn compute_huffman_weights(freqs: &[u32], num_symbols: usize) -> Option<(Vec<u8>, u8)> {
205 use alloc::collections::BinaryHeap;
206 use core::cmp::Reverse;
207
208 let active: Vec<(u64, usize)> = freqs[..num_symbols]
209 .iter()
210 .enumerate()
211 .filter(|(_, f)| **f > 0)
212 .map(|(s, &f)| (f as u64, s))
213 .collect();
214
215 if active.len() < 2 {
216 return None;
217 }
218
219 let n = active.len();
220
221 let max_nodes = 2 * n;
222 let mut parent = vec![usize::MAX; max_nodes];
223
224 let mut heap: BinaryHeap<Reverse<(u64, usize)>> = BinaryHeap::with_capacity(n);
225 for (i, &(f, _)) in active.iter().enumerate() {
226 heap.push(Reverse((f, i)));
227 }
228
229 let mut next_id = n;
230 for _ in 0..n - 1 {
231 let Reverse((f1, n1)) = heap.pop().unwrap();
232 let Reverse((f2, n2)) = heap.pop().unwrap();
233 parent[n1] = next_id;
234 parent[n2] = next_id;
235 heap.push(Reverse((f1 + f2, next_id)));
236 next_id += 1;
237 }
238
239 let mut bit_lengths = vec![0u8; num_symbols];
240 for i in 0..n {
241 let sym = active[i].1;
242 let mut depth = 0u8;
243 let mut node = i;
244 while parent[node] != usize::MAX {
245 depth += 1;
246 node = parent[node];
247 }
248 bit_lengths[sym] = depth;
249 }
250
251 let max_bl = *bit_lengths.iter().max().unwrap();
252 if max_bl == 0 || max_bl > MAX_BITS {
253 return None;
254 }
255
256 let table_log = max_bl;
257 let mut weights = vec![0u8; num_symbols];
258 for (s, &bl) in bit_lengths.iter().enumerate() {
259 if bl > 0 {
260 weights[s] = table_log + 1 - bl;
261 }
262 }
263
264 Some((weights, table_log))
265}
266
267fn build_encode_codes(
268 weights: &[u8],
269 table_log: u8,
270) -> ([u16; MAX_SYMBOL_VALUE + 1], [u8; MAX_SYMBOL_VALUE + 1]) {
271 let mut codes = [0u16; MAX_SYMBOL_VALUE + 1];
272 let mut num_bits = [0u8; MAX_SYMBOL_VALUE + 1];
273
274 let max_w = table_log + 1;
275 let mut rank_count = [0u32; MAX_BITS as usize + 2];
276
277 for (s, &w) in weights.iter().enumerate() {
278 if w > 0 && w <= max_w {
279 num_bits[s] = table_log + 1 - w;
280 rank_count[w as usize] += 1;
281 }
282 }
283
284 let mut rank_start = [0u32; MAX_BITS as usize + 2];
285 let mut cumul = 0u32;
286 for w in 1..=max_w {
287 rank_start[w as usize] = cumul;
288 cumul += rank_count[w as usize] * (1u32 << (w - 1));
289 }
290
291 for (s, &w) in weights.iter().enumerate() {
292 if w == 0 {
293 continue;
294 }
295 let start = rank_start[w as usize];
296 codes[s] = (start >> (w - 1)) as u16;
297 rank_start[w as usize] += 1u32 << (w - 1);
298 }
299
300 (codes, num_bits)
301}
302
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `data` with `table` as a single stream, rebuilds a decoder from
    /// the serialized weights and asserts that decoding returns `data`.
    fn assert_single_stream_roundtrip(table: &HuffmanEncodeTable, data: &[u8]) {
        let weights_raw = table.serialize_weights();
        let encoded = table.encode_single_stream(data);

        let (parsed_weights, _) =
            crate::huffman::weights::parse_huffman_weights(&weights_raw).unwrap();
        let (decode_table, decode_log) =
            crate::huffman::weights::build_huffman_decode_table(&parsed_weights).unwrap();
        let decoded = crate::huffman::decode::decode_single_stream(
            &decode_table,
            decode_log,
            &encoded,
            data.len(),
        )
        .unwrap();
        assert_eq!(decoded, data);
    }

    /// Short ASCII text survives a single-stream round trip.
    #[test]
    fn roundtrip_simple() {
        let data = b"hello world hello world hello world!";
        let table = HuffmanEncodeTable::from_data(data).unwrap();
        assert_single_stream_roundtrip(&table, data);
    }

    /// Eight evenly distributed symbols survive a four-stream round trip.
    #[test]
    fn roundtrip_4_streams() {
        let data: Vec<u8> = b"ABCDEFGH".iter().cycle().take(1024).copied().collect();
        let table = HuffmanEncodeTable::from_data(&data).unwrap();
        let weights_raw = table.serialize_weights();
        let encoded = table.encode_4_streams(&data);

        let (parsed_weights, _) =
            crate::huffman::weights::parse_huffman_weights(&weights_raw).unwrap();
        let (decode_table, decode_log) =
            crate::huffman::weights::build_huffman_decode_table(&parsed_weights).unwrap();
        let decoded = crate::huffman::decode::decode_4_streams(
            &decode_table,
            decode_log,
            &encoded,
            data.len(),
        )
        .unwrap();
        assert_eq!(decoded, data);
    }

    /// Every byte value in `0..=127`, evenly distributed, survives a round trip.
    #[test]
    fn roundtrip_all_bytes() {
        let data: Vec<u8> = (0u8..=127).cycle().take(4096).collect();
        let table = HuffmanEncodeTable::from_data(&data).unwrap();
        assert_single_stream_roundtrip(&table, &data);
    }

    /// A heavily skewed histogram gives the frequent symbol the shorter code.
    #[test]
    fn skewed_distribution() {
        let mut data = vec![0u8; 900];
        data.extend(vec![1u8; 80]);
        data.extend(vec![2u8; 15]);
        data.extend(vec![3u8; 5]);
        let table = HuffmanEncodeTable::from_data(&data).unwrap();
        assert!(table.num_bits[0] < table.num_bits[3]);
        assert_single_stream_roundtrip(&table, &data);
    }

    /// Two equally frequent symbols each receive a one-bit code.
    #[test]
    fn two_symbols() {
        let mut data = vec![0u8; 500];
        data.extend(vec![1u8; 500]);
        let table = HuffmanEncodeTable::from_data(&data).unwrap();
        assert_eq!(table.num_bits[0], 1);
        assert_eq!(table.num_bits[1], 1);
        assert_single_stream_roundtrip(&table, &data);
    }
}