// zrip_core/bitstream/writer.rs

1#![forbid(unsafe_code)]
2
3#[cfg(feature = "alloc")]
4use alloc::vec::Vec;
5
6use super::primitives;
7
/// LSB-first bit packer that accumulates bits in a 64-bit register and
/// spills them into a growable byte buffer.
///
/// The byte buffer only exists when the `alloc` feature is enabled; all
/// constructors and writing methods are gated on that feature as well.
#[derive(Debug)]
pub struct BitWriter {
    /// Output storage. Bytes below `pos` have been committed by the writer.
    #[cfg(feature = "alloc")]
    buf: Vec<u8>,
    /// Accumulator holding bits that have not been committed yet; the
    /// oldest pending bit sits in the least-significant position.
    bits: u64,
    /// Number of valid low bits currently held in `bits`.
    bits_used: u8,
    /// Byte offset into `buf` at which the next spill is written.
    pos: usize,
}
15
#[cfg(feature = "alloc")]
impl Default for BitWriter {
    /// Produces the same writer as [`BitWriter::new`].
    fn default() -> Self {
        BitWriter::new()
    }
}
22
#[cfg(feature = "alloc")]
impl BitWriter {
    /// Wraps `buf` in a writer whose accumulator is empty and whose write
    /// position is byte 0.
    fn with_buffer(buf: Vec<u8>) -> Self {
        Self {
            buf,
            bits: 0,
            bits_used: 0,
            pos: 0,
        }
    }

    /// Creates a writer backed by a fresh buffer with 64 bytes of capacity.
    pub fn new() -> Self {
        Self::with_buffer(Vec::with_capacity(64))
    }

    /// Creates a writer whose buffer has capacity for `cap` bytes plus 8
    /// bytes of slack (the accumulator is spilled as one 8-byte store).
    ///
    /// # Panics
    ///
    /// Panics if the requested capacity is larger than `Vec` can allocate.
    pub fn with_capacity(cap: usize) -> Self {
        // Saturate rather than `cap + 8`: near `usize::MAX` the plain sum
        // panics in debug builds and wraps to a tiny capacity in release
        // builds. Saturating lets `Vec::with_capacity` report the overflow
        // the same way in both profiles.
        Self::with_buffer(Vec::with_capacity(cap.saturating_add(8)))
    }

    /// Creates a writer that reuses the allocation of `buf`.
    ///
    /// Any existing contents of `buf` are discarded; at least 8 bytes of
    /// capacity are reserved.
    pub fn from_vec(mut buf: Vec<u8>) -> Self {
        buf.clear();
        buf.reserve(8);
        Self::with_buffer(buf)
    }

    /// Flushes pending bits (see [`Self::flush_remaining`]) and returns the
    /// underlying buffer.
    pub fn into_vec(mut self) -> Vec<u8> {
        self.flush_remaining();
        self.buf
    }

    /// Makes sure at least 8 bytes of capacity exist starting at `pos`,
    /// growing the buffer by 64 bytes beyond `pos` when they do not.
    #[inline(always)]
    fn ensure_capacity(&mut self) {
        if self.pos + 8 > self.buf.capacity() {
            // `Vec::reserve` is relative to the current length, so the length
            // is moved up to `pos` first.
            // NOTE(review): `set_vec_len` is defined in `primitives`;
            // presumably it sets `buf.len()` to `pos` so the bytes already
            // written survive reallocation — confirm against that module.
            primitives::set_vec_len(&mut self.buf, self.pos);
            self.buf.reserve(64);
        }
    }

    /// Appends the low `n` bits of `value`, least-significant bit first.
    ///
    /// `n` must be at most 25 and `value` must fit in `n` bits; both are
    /// checked with `debug_assert!` only. `n == 0` is a no-op.
    #[inline(always)]
    pub fn write_bits(&mut self, value: u32, n: u8) {
        debug_assert!(n <= 25);
        if n == 0 {
            return;
        }
        debug_assert!(value < (1u32 << n));
        // `bits_used` is below 32 on entry (every spill leaves at most 7
        // bits), so the shift and the sum below stay within 64 bits.
        self.bits |= (value as u64) << self.bits_used;
        self.bits_used += n;
        if self.bits_used >= 32 {
            self.ensure_capacity();
            // Store the whole accumulator at `pos`, but only advance past the
            // complete bytes; the leftover 0..=7 bits stay pending and the
            // bytes beyond the new `pos` are overwritten by the next spill.
            primitives::write_u64_le_unaligned(&mut self.buf, self.pos, self.bits);
            let nb = (self.bits_used >> 3) as usize;
            self.pos += nb;
            self.bits >>= nb << 3;
            self.bits_used &= 7;
        }
    }

    /// Commits every pending bit to the buffer, rounding up to a whole byte.
    ///
    /// The unused high bits of the final byte come from the accumulator,
    /// which holds zeros there as long as the `write_bits` contract was
    /// respected. Calling this with nothing pending only synchronises the
    /// buffer length with the write position.
    pub fn flush_remaining(&mut self) {
        primitives::set_vec_len(&mut self.buf, self.pos);
        while self.bits_used > 0 {
            self.buf.push(self.bits as u8);
            self.bits >>= 8;
            // The last byte may hold fewer than 8 valid bits.
            self.bits_used = self.bits_used.saturating_sub(8);
        }
        self.pos = self.buf.len();
    }

    /// Writes a single `1` bit and then flushes to a byte boundary.
    pub fn close_reverse_stream(&mut self) {
        self.write_bits(1, 1);
        self.flush_remaining();
    }

    /// Returns the number of bits written so far: committed bytes times 8
    /// plus the bits still pending in the accumulator.
    ///
    /// After [`Self::flush_remaining`] this count includes any padding bits.
    pub fn bits_written(&self) -> usize {
        self.pos * 8 + self.bits_used as usize
    }

    /// Flushes pending bits and returns the underlying buffer.
    ///
    /// Identical to [`Self::into_vec`].
    pub fn into_bytes(self) -> Vec<u8> {
        self.into_vec()
    }

    /// Returns the bytes committed to the buffer so far.
    ///
    /// Bits still pending in the accumulator are NOT included; call
    /// [`Self::flush_remaining`] first if they are needed.
    pub fn as_bytes(&mut self) -> &[u8] {
        primitives::set_vec_len(&mut self.buf, self.pos);
        &self.buf
    }

    /// Appends `b` as 8 bits at the current bit position.
    pub fn write_byte(&mut self, b: u8) {
        self.write_bits(b as u32, 8);
    }

    /// Appends every byte of `bytes` in order, each as 8 bits.
    pub fn write_bytes(&mut self, bytes: &[u8]) {
        for &b in bytes {
            self.write_byte(b);
        }
    }
}
125
#[cfg(test)]
mod tests {
    use super::*;

    /// Three fields totalling eight bits pack into one byte, with the first
    /// field occupying the lowest bits.
    #[test]
    fn write_and_read_back() {
        let mut writer = BitWriter::new();
        for (value, width) in [(0b101, 3), (0b11, 2), (0b000, 3)] {
            writer.write_bits(value, width);
        }
        assert_eq!(writer.into_bytes(), [0b0001_1101]);
    }

    /// Two full-byte writes come out as two bytes in write order.
    #[test]
    fn write_cross_byte() {
        let mut writer = BitWriter::new();
        writer.write_bits(0xFF, 8);
        writer.write_bits(0x01, 8);
        assert_eq!(writer.into_bytes(), [0xFF, 0x01]);
    }

    /// A single bit is padded out to a whole byte on conversion.
    #[test]
    fn write_partial() {
        let mut writer = BitWriter::new();
        writer.write_bits(0b1, 1);
        assert_eq!(writer.into_bytes(), [0b0000_0001]);
    }

    /// A closed stream read back through the reverse reader yields the
    /// fields in the opposite order from which they were written.
    #[test]
    fn reverse_stream_roundtrip() {
        use crate::bitstream::reader_reverse::ReverseBitReader;

        let encoded = {
            let mut writer = BitWriter::new();
            writer.write_bits(0b1010, 4);
            writer.write_bits(0b0101, 4);
            writer.close_reverse_stream();
            writer.into_bytes()
        };

        let mut reader = ReverseBitReader::new(&encoded).unwrap();
        assert_eq!(reader.read_bits(4).unwrap(), 0b0101);
        assert_eq!(reader.read_bits(4).unwrap(), 0b1010);
        assert_eq!(reader.bits_remaining(), 0);
    }
}
172}