Skip to main content

rust_zstd/
bitstream.rs

//! Bit-level stream writers for zstd encoding.
//!
//! - `BitWriter` — forward bitstream (Huffman literals).
//! - `BackwardBitWriter` — backward bitstream (FSE sequences, FSE weights).
//!
//! The backward bitstream matches C zstd's `BIT_CStream_t` exactly:
//! - Bits accumulate LSB-first in a 64-bit register.
//! - `flush_bits()` writes the completed bytes to the output (forward, little-endian).
//! - `finish()` appends a sentinel 1-bit, flushes, and returns the bytes.
//! - The decoder reads the result from the END toward the BEGINNING.
11
/// Forward bitstream writer (Huffman literal streams).
///
/// Bits are packed LSB-first: the first bit written lands in bit 0 of the
/// first output byte. Completed bytes are appended to an internal buffer;
/// a trailing partial byte is emitted by [`BitWriter::finish`] with its
/// unused high bits left at zero.
#[derive(Debug, Clone)]
pub struct BitWriter {
    /// Completed output bytes.
    buf: Vec<u8>,
    /// Number of bits already occupied in `current` (always in `0..8`).
    bit_pos: u32,
    /// Partially assembled byte; bits at and above `bit_pos` are zero.
    current: u8,
}

impl Default for BitWriter {
    fn default() -> Self {
        Self::new()
    }
}

impl BitWriter {
    /// Creates an empty writer with a small pre-allocated output buffer.
    pub fn new() -> Self {
        Self {
            buf: Vec::with_capacity(256),
            bit_pos: 0,
            current: 0,
        }
    }

    /// Appends the low `nbits` bits of `value`, least-significant bit first.
    ///
    /// Bits of `value` above `nbits` are ignored. Because `value` is a `u64`,
    /// any requested bits beyond the 64th are written as zeros.
    pub fn write_bits(&mut self, value: u64, nbits: u32) {
        let mut pending = value;
        let mut remaining = nbits;
        while remaining > 0 {
            // Fill whatever room is left in the byte under construction.
            let take = (8 - self.bit_pos).min(remaining);
            let chunk = (pending & ((1u64 << take) - 1)) as u8;
            self.current |= chunk << self.bit_pos;

            pending >>= take;
            remaining -= take;
            self.bit_pos += take;

            if self.bit_pos == 8 {
                self.buf.push(self.current);
                self.current = 0;
                self.bit_pos = 0;
            }
        }
    }

    /// Consumes the writer and returns the encoded bytes.
    ///
    /// If the bit count is not a multiple of eight, the final byte is
    /// emitted with its unused high bits set to zero.
    pub fn finish(mut self) -> Vec<u8> {
        if self.bit_pos > 0 {
            self.buf.push(self.current);
        }
        self.buf
    }

    /// Returns the number of **bits** (not bytes) written so far.
    pub fn len(&self) -> usize {
        self.buf.len() * 8 + self.bit_pos as usize
    }

    /// Returns `true` if no bits have been written yet.
    pub fn is_empty(&self) -> bool {
        self.buf.is_empty() && self.bit_pos == 0
    }
}
64
/// Backward bitstream writer matching C zstd's BIT_CStream_t.
///
/// Bits accumulate LSB-first in a 64-bit container. `flush_bits()` writes
/// complete bytes to the output buffer in LE order. The decoder reads
/// from the END of this buffer (BitReaderReversed).
///
/// Key: bytes are written FORWARD. No reverse needed. The decoder
/// naturally reads backward from the last byte.
#[derive(Debug, Clone)]
pub struct BackwardBitWriter {
    /// Pending bits that have not been flushed yet, LSB-first.
    container: u64,
    /// Number of valid bits currently held in `container`.
    bit_pos: u32,
    /// Bytes already flushed, in output order.
    buf: Vec<u8>,
}

impl Default for BackwardBitWriter {
    fn default() -> Self {
        Self::new()
    }
}

impl BackwardBitWriter {
    /// Creates an empty writer with a small pre-allocated output buffer.
    pub fn new() -> Self {
        Self {
            container: 0,
            bit_pos: 0,
            buf: Vec::with_capacity(256),
        }
    }

    /// Add `nbits` from the low bits of `value` to the container.
    /// Matches: `BIT_addBits(bitC, value, nbBits)`
    ///
    /// Bits of `value` above `nbits` are masked off. The caller must call
    /// [`flush_bits`](Self::flush_bits) often enough that the container never
    /// has to hold more than 64 bits; this is checked in debug builds only.
    #[inline]
    pub fn add_bits(&mut self, value: u64, nbits: u32) {
        if nbits == 0 {
            return;
        }
        debug_assert!(nbits <= 57);
        debug_assert!(self.bit_pos + nbits <= 64);
        // `1 << 64` would overflow, so a full-width mask is special-cased.
        let mask = if nbits >= 64 {
            u64::MAX
        } else {
            (1u64 << nbits) - 1
        };
        self.container |= (value & mask) << self.bit_pos;
        self.bit_pos += nbits;
    }

    /// Flush complete bytes from the container to the output.
    /// Matches: `BIT_flushBits(bitC)`
    ///
    /// Up to seven trailing bits that do not yet form a whole byte stay in
    /// the container for the next call.
    #[inline]
    pub fn flush_bits(&mut self) {
        let nb_bytes = (self.bit_pos / 8) as usize;
        let bytes = self.container.to_le_bytes();
        self.buf.extend(bytes.iter().take(nb_bytes));

        // When the container is completely full the shift amount is 64, which
        // a plain `>>` rejects (panic in debug, wrap-to-zero shift in release
        // that would leave already-flushed bits behind). `checked_shr` lets us
        // clear the container explicitly in that case.
        let flushed_bits = self.bit_pos & !7;
        self.container = self.container.checked_shr(flushed_bits).unwrap_or(0);
        self.bit_pos &= 7;
    }

    /// Finalize: add sentinel 1-bit, flush remaining.
    /// Matches: `BIT_closeCStream(bitC)`
    ///
    /// Consumes the writer and returns the finished byte buffer. The final
    /// byte always contains the sentinel bit as its highest set bit.
    pub fn finish(mut self) -> Vec<u8> {
        // Drain whole bytes first so the sentinel always has room, even when
        // the container is completely full. This does not change the output:
        // the same bytes are emitted, just slightly earlier.
        self.flush_bits();
        self.add_bits(1, 1);
        self.flush_bits();
        if self.bit_pos > 0 {
            self.buf.push(self.container as u8);
        }
        self.buf
    }
}
135
#[cfg(test)]
mod tests {
    use super::{BackwardBitWriter, BitWriter};

    /// Three writes totalling exactly eight bits pack LSB-first into one byte.
    #[test]
    fn bit_writer_basic() {
        let mut writer = BitWriter::new();
        for (value, nbits) in [(0b101, 3), (0b1100, 4), (0b1, 1)] {
            writer.write_bits(value, nbits);
        }
        let expected: Vec<u8> = vec![0xE5];
        assert_eq!(writer.finish(), expected);
    }

    /// An untouched backward writer emits only the sentinel bit.
    #[test]
    fn backward_writer_sentinel_only() {
        // The sentinel 1-bit sits at position 0, giving the byte 0x01.
        let expected: Vec<u8> = vec![0x01];
        assert_eq!(BackwardBitWriter::new().finish(), expected);
    }

    /// Bytes come out in the order they were added, followed by the sentinel.
    #[test]
    fn backward_writer_c_layout() {
        let mut writer = BackwardBitWriter::new();

        // First byte is flushed straight away: output so far is [0xFF].
        writer.add_bits(0xFF, 8);
        writer.flush_bits();

        // 0xAB plus the sentinel occupy nine bits; finishing emits the full
        // byte 0xAB and then the leftover sentinel bit as 0x01.
        writer.add_bits(0xAB, 8);
        let encoded = writer.finish();

        let expected: Vec<u8> = vec![0xFF, 0xAB, 0x01];
        assert_eq!(encoded, expected);
    }
}