Skip to main content

xpress_huffman/
lib.rs

1//! Pure-Rust decompressor for **Microsoft Xpress-Huffman**
2//! (`LZXPRESS_HUFFMAN`, `COMPRESSION_FORMAT_XPRESS_HUFF` = 4), specified in
3//! [MS-XCA] §2.2.4.
4//!
5//! Xpress-Huffman is the LZ77+Huffman codec that modern Windows uses pervasively:
//! Windows 10+ **prefetch** (`MAM` wrapper), **`hiberfil.sys`**, **SMB3** compression,
7//! **registry-hive** compression, and Windows Update payloads. Unlike *plain*
//! `LZXpress` (`COMPRESSION_FORMAT_XPRESS` = 3, the plain LZ77 format the existing
9//! `rust-lzxpress` / `xpress_rs` crates implement), this is the Huffman-coded
10//! variant.
11//!
12//! This crate is **cross-platform** — it does not call the Windows
13//! `RtlDecompressBufferEx` API, so it decompresses Windows artifacts on Linux and
14//! macOS just as well as on Windows. It is **panic-free**: every length/offset
15//! field read from the (untrusted) input is bounds-checked, never indexed blindly.
16//!
17//! ```
18//! # fn main() -> Result<(), xpress_huffman::Error> {
19//! # let compressed: &[u8] = &[];
20//! # let known_output_size = 0;
21//! # if !compressed.is_empty() {
22//! let plain = xpress_huffman::decompress(compressed, known_output_size)?;
23//! # let _ = plain;
24//! # }
25//! # Ok(())
26//! # }
27//! ```
28//!
29//! Reimplemented clean-room from the [MS-XCA] algorithm (structure cross-checked
30//! against Fox-IT's `dissect.util`; no code copied).
31//! Spec: <https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-xca/>.
32
33#![forbid(unsafe_code)]
34#![cfg_attr(not(any(feature = "std", test)), no_std)]
35
36extern crate alloc;
37use alloc::vec;
38use alloc::vec::Vec;
39
/// Ways an Xpress-Huffman stream can fail to decode.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Error {
    /// Fewer than 256 bytes were left for a block's Huffman code-length table.
    TruncatedTable,
    /// A match referred further back than the bytes decoded so far (corrupt stream).
    BadMatchOffset,
}

impl core::fmt::Display for Error {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let message = match *self {
            Self::TruncatedTable => "xpress-huffman: truncated Huffman table",
            Self::BadMatchOffset => "xpress-huffman: match offset before output start",
        };
        f.write_str(message)
    }
}

#[cfg(feature = "std")]
impl std::error::Error for Error {}
62
/// Bytes in the code-length table that opens every block: 512 symbols at 4 bits each.
const TABLE_LEN: usize = 512 * 4 / 8;
/// Output bytes a block decodes before the next table is read (the final match
/// of a block may run past this).
const BLOCK_SIZE: usize = 64 * 1024;
67
68/// Decompress an Xpress-Huffman stream into `decompressed_size` bytes.
69///
70/// `decompressed_size` is the output length recorded by the container that wraps
71/// the stream (e.g. the 4-byte size after a prefetch `MAM\x04` signature, or the
72/// uncompressed size in an SMB3 / hibernation header). Decoding stops once that
73/// many bytes have been produced or the input is exhausted, whichever comes
74/// first; the returned `Vec` is exactly the bytes decoded.
75pub fn decompress(compressed: &[u8], decompressed_size: usize) -> Result<Vec<u8>, Error> {
76    let mut dst: Vec<u8> = Vec::with_capacity(decompressed_size);
77    let mut bs = BitStream::new(compressed);
78
79    while bs.pos < compressed.len() && dst.len() < decompressed_size {
80        let table = compressed
81            .get(bs.pos..bs.pos + TABLE_LEN)
82            .ok_or(Error::TruncatedTable)?;
83        let tree = build_tree(table);
84        bs.pos += TABLE_LEN;
85        bs.init();
86
87        let mut produced_in_block = 0usize;
88        while produced_in_block < BLOCK_SIZE
89            && bs.pos < compressed.len()
90            && dst.len() < decompressed_size
91        {
92            let symbol = bs.decode(&tree);
93            if symbol < 256 {
94                dst.push(symbol as u8);
95                produced_in_block += 1;
96                continue;
97            }
98            let symbol = symbol - 256;
99            let offset_bits = u32::from(symbol >> 4);
100            let offset = (1usize << offset_bits) + bs.lookup(offset_bits) as usize;
101            let length = bs.match_length(u32::from(symbol & 0x0F));
102            bs.skip(offset_bits as i32);
103
104            if offset == 0 || offset > dst.len() {
105                return Err(Error::BadMatchOffset);
106            }
107            let mut remaining = length as usize;
108            while remaining > 0 {
109                let from = dst.len() - offset;
110                let n = remaining.min(offset);
111                for k in 0..n {
112                    let b = dst[from + k];
113                    dst.push(b);
114                }
115                remaining -= n;
116            }
117            produced_in_block += length as usize;
118        }
119    }
120    Ok(dst)
121}
122
123/// One Huffman tree node (index-based; no per-child allocation).
124#[derive(Clone, Copy)]
125struct Node {
126    children: [usize; 2],
127    is_leaf: bool,
128    symbol: u16,
129}
130const NONE: usize = usize::MAX;
131
132/// Build the per-block Huffman decode tree from its 256-byte code-length table
133/// (512 symbols, 4 bits each: byte `k` holds symbol `2k` in the low nibble,
134/// `2k+1` in the high nibble). Canonical-code assignment per [MS-XCA].
135fn build_tree(buf: &[u8]) -> Vec<Node> {
136    let mut nodes = vec![
137        Node {
138            children: [NONE, NONE],
139            is_leaf: false,
140            symbol: 0,
141        };
142        1024
143    ];
144
145    let mut symbols: Vec<(u8, u16)> = Vec::with_capacity(512);
146    for (i, &c) in buf.iter().enumerate().take(TABLE_LEN) {
147        symbols.push((c & 0x0F, (i * 2) as u16));
148        symbols.push((c >> 4, (i * 2 + 1) as u16));
149    }
150    symbols.sort_unstable();
151
152    let start = symbols.iter().take_while(|(len, _)| *len == 0).count();
153
154    let mut mask: u32 = 0;
155    let mut bits: u32 = 1;
156    let mut tree_index = 1usize;
157
158    for &(length, symbol) in symbols.iter().take(512).skip(start) {
159        let length = u32::from(length);
160        {
161            let node = &mut nodes[tree_index];
162            node.symbol = symbol;
163            node.is_leaf = true;
164        }
165        mask = mask.wrapping_shl(length.wrapping_sub(bits));
166        bits = length;
167        tree_index = add_leaf(&mut nodes, tree_index, mask, bits);
168        mask = mask.wrapping_add(1);
169    }
170    nodes
171}
172
173/// Splice leaf node `idx` into the tree along the path described by `mask`/`bits`,
174/// creating internal nodes as needed. Returns the next free node index.
175fn add_leaf(nodes: &mut [Node], idx: usize, mask: u32, bits: u32) -> usize {
176    let mut cur = 0usize;
177    let mut i = idx + 1;
178    let mut bits = bits;
179    while bits > 1 {
180        bits -= 1;
181        let childidx = ((mask >> bits) & 1) as usize;
182        if nodes[cur].children[childidx] == NONE {
183            nodes[cur].children[childidx] = i;
184            nodes[i].is_leaf = false;
185            i += 1;
186        }
187        cur = nodes[cur].children[childidx];
188    }
189    nodes[cur].children[(mask & 1) as usize] = idx;
190    i
191}
192
193/// Bit reader over the compressed stream: a 32-bit window refilled 16 bits at a
194/// time from little-endian source words, with a byte cursor SHARED with the
195/// extended-length / extra-offset byte reads (they interleave, per [MS-XCA]).
196struct BitStream<'a> {
197    data: &'a [u8],
198    pos: usize,
199    mask: u32,
200    bits: i32,
201}
202
203impl<'a> BitStream<'a> {
204    fn new(data: &'a [u8]) -> Self {
205        Self {
206            data,
207            pos: 0,
208            mask: 0,
209            bits: 0,
210        }
211    }
212
213    /// Read one 16-bit little-endian word; at EOF a lone trailing byte becomes
214    /// the high byte. Advances the cursor by the bytes actually available (≤ 2).
215    fn read16(&mut self) -> u32 {
216        let avail = self.data.len().saturating_sub(self.pos);
217        let v = match avail {
218            0 => 0,
219            1 => u32::from(self.data[self.pos]) << 8,
220            _ => u32::from(u16::from_le_bytes([
221                self.data[self.pos],
222                self.data[self.pos + 1],
223            ])),
224        };
225        self.pos += avail.min(2);
226        v
227    }
228
229    fn init(&mut self) {
230        self.mask = (self.read16() << 16).wrapping_add(self.read16());
231        self.bits = 32;
232    }
233
234    fn lookup(&self, n: u32) -> u32 {
235        if n == 0 {
236            0
237        } else {
238            self.mask >> (32 - n)
239        }
240    }
241
242    fn skip(&mut self, n: i32) {
243        self.mask = self.mask.wrapping_shl(n as u32);
244        self.bits -= n;
245        if self.bits < 16 {
246            self.mask = self.mask.wrapping_add(self.read16() << (16 - self.bits));
247            self.bits += 16;
248        }
249    }
250
251    fn read_byte(&mut self) -> u8 {
252        let b = self.data.get(self.pos).copied().unwrap_or(0);
253        self.pos += 1;
254        b
255    }
256
257    /// Decode an LZ77 match length from its 4-bit symbol nibble ([MS-XCA]): a
258    /// nibble of 15 escalates to a trailing length byte, and a byte of 255 (i.e.
259    /// length 270) escalates again to a trailing 16-bit length word. The returned
260    /// value already includes the +3 minimum-match bias.
261    fn match_length(&mut self, nibble: u32) -> u32 {
262        let mut length = nibble;
263        if length == 15 {
264            length = u32::from(self.read_byte()) + 15;
265            if length == 270 {
266                length = self.read16();
267            }
268        }
269        length + 3
270    }
271
272    /// Walk the Huffman tree one bit at a time to a leaf symbol.
273    fn decode(&mut self, nodes: &[Node]) -> u16 {
274        let mut node = 0usize;
275        while !nodes[node].is_leaf {
276            let bit = self.lookup(1) as usize;
277            self.skip(1);
278            let next = nodes[node].children[bit];
279            if next == NONE {
280                return 0; // cov:unreachable: a valid tree always reaches a leaf
281            }
282            node = next;
283        }
284        nodes[node].symbol
285    }
286}
287
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    // Real Xpress-Huffman stream from a Win10 prefetch file (Stolen Szechuan
    // Sauce, AM_DELTA.EXE) with its 8-byte MAM wrapper removed, paired with the
    // expected plaintext. Fox-IT's dissect.util decompressor independently
    // reproduced the expected bytes exactly (see docs/validation.md); provenance
    // is recorded in tests/data/README.md.
    const COMPRESSED: &[u8] = include_bytes!("../tests/data/am_delta.xhuff");
    const EXPECTED: &[u8] = include_bytes!("../tests/data/am_delta.expected");
    // A larger real vector (35954 bytes) that reaches the extended match-length
    // path (length byte 255 -> trailing 16-bit length word).
    const AUDIODG: &[u8] = include_bytes!("../tests/data/audiodg.xhuff");
    const AUDIODG_EXPECTED: &[u8] = include_bytes!("../tests/data/audiodg.expected");

    /// Code-length table whose only coded symbol is 256 (a match with no extra
    /// offset bits), using the 1-bit code `0`. Decoding it before any literal
    /// yields a match at offset 1 into empty output.
    fn table_first_symbol_is_match() -> [u8; TABLE_LEN] {
        let mut table = [0u8; TABLE_LEN];
        // Symbol 256 is the low nibble of byte 256 / 2 = 128; give it length 1.
        table[256 / 2] = 0x01;
        table
    }

    /// `table_first_symbol_is_match()` followed by the stream bytes `tail`.
    fn stream_with(tail: &[u8]) -> Vec<u8> {
        let mut input = table_first_symbol_is_match().to_vec();
        input.extend_from_slice(tail);
        input
    }

    #[test]
    fn decompresses_real_xpress_huffman_vector() {
        let out = decompress(COMPRESSED, EXPECTED.len()).unwrap();
        assert_eq!(out.len(), EXPECTED.len());
        assert_eq!(out, EXPECTED);
    }

    #[test]
    fn decompresses_larger_real_vector() {
        let out = decompress(AUDIODG, AUDIODG_EXPECTED.len()).unwrap();
        assert_eq!(out, AUDIODG_EXPECTED);
    }

    #[test]
    fn match_before_any_output_errors() {
        // 4 init bytes plus 2 more so the decode loop runs: the first symbol is
        // 256, a match at offset 1 into still-empty output -> BadMatchOffset.
        // This also covers the zero-bit lookup(0) path.
        let input = stream_with(&[0; 6]);
        assert_eq!(decompress(&input, 100), Err(Error::BadMatchOffset));
    }

    #[test]
    fn handles_init_at_exact_eof() {
        // Only 2 bytes follow the table, so init's second word read finds
        // nothing (EOF -> 0) and the exhausted stream decodes to nothing.
        let input = stream_with(&[0; 2]);
        assert_eq!(decompress(&input, 100), Ok(Vec::new()));
    }

    #[test]
    fn handles_init_with_one_trailing_byte() {
        // 3 bytes follow the table, so init's second word read sees one lone
        // byte, which is padded into the high byte.
        let input = stream_with(&[0; 3]);
        assert_eq!(decompress(&input, 100), Ok(Vec::new()));
    }

    #[test]
    fn stops_at_requested_size() {
        // Requesting fewer bytes than the stream holds returns exactly that many.
        let wanted = 16;
        let out = decompress(COMPRESSED, wanted).unwrap();
        assert_eq!(out.len(), wanted);
        assert_eq!(out, &EXPECTED[..wanted]);
    }

    #[test]
    fn empty_input_yields_empty() {
        for size in [0, 100] {
            assert_eq!(decompress(&[], size).unwrap(), Vec::<u8>::new());
        }
    }

    #[test]
    fn truncated_table_errors() {
        // A 10-byte input cannot hold a 256-byte table: error, not a panic.
        assert_eq!(decompress(&[0u8; 10], 100), Err(Error::TruncatedTable));
    }

    #[test]
    fn error_is_display() {
        for err in [Error::TruncatedTable, Error::BadMatchOffset] {
            assert!(!format!("{}", err).is_empty());
        }
    }

    // The match-length escalation ladder ([MS-XCA]), tested in isolation: a
    // 273+ byte run (nibble 15 -> byte 255 -> trailing 16-bit word) is too rare
    // to rely on the small real vectors for coverage.
    #[test]
    fn match_length_short() {
        // Nibble below 15: no escalation, only the +3 bias.
        assert_eq!(BitStream::new(&[]).match_length(7), 10);
    }

    #[test]
    fn match_length_one_byte_extension() {
        // Nibble 15 with extension byte 10 (not 255): 10 + 15 + 3.
        assert_eq!(BitStream::new(&[10]).match_length(15), 10 + 15 + 3);
    }

    #[test]
    fn match_length_max_run_reads_16bit() {
        // Nibble 15, extension byte 255 (-> 270), so the little-endian word
        // 0x1234 that follows is the real length.
        assert_eq!(BitStream::new(&[255, 0x34, 0x12]).match_length(15), 0x1234 + 3);
    }
}