1use crate::error::{BitError, BitResult};
4
5#[derive(Debug)]
10pub struct BitWriter<'a> {
11 buf: &'a mut [u8],
12 bit_pos: usize,
13}
14
15impl<'a> BitWriter<'a> {
16 #[must_use]
18 pub const fn new(buf: &'a mut [u8]) -> Self {
19 Self { buf, bit_pos: 0 }
20 }
21
22 #[must_use]
24 pub const fn bits_written(&self) -> usize {
25 self.bit_pos
26 }
27
28 #[must_use]
30 pub const fn bits_remaining(&self) -> usize {
31 self.buf
32 .len()
33 .saturating_mul(8)
34 .saturating_sub(self.bit_pos)
35 }
36
37 pub fn write_bit(&mut self, value: bool) -> BitResult<()> {
39 self.ensure_bits(1)?;
40 self.write_bit_unchecked(value);
41 Ok(())
42 }
43
44 pub fn write_bits(&mut self, value: u64, bits: u8) -> BitResult<()> {
51 if bits > 64 {
52 return Err(BitError::InvalidBitCount { bits, max_bits: 64 });
53 }
54 if bits == 0 {
55 return Ok(());
56 }
57 if bits < 64 && value >= (1u64 << bits) {
58 return Err(BitError::ValueOutOfRange { value, bits });
59 }
60 self.ensure_bits(bits as usize)?;
61 for i in (0..bits).rev() {
62 let bit = (value >> i) & 1 == 1;
63 self.write_bit_unchecked(bit);
64 }
65 Ok(())
66 }
67
68 pub fn align_to_byte(&mut self) -> BitResult<()> {
70 let rem = self.bit_pos % 8;
71 if rem == 0 {
72 return Ok(());
73 }
74 let padding = 8 - rem;
75 self.ensure_bits(padding)?;
76 for _ in 0..padding {
77 self.write_bit_unchecked(false);
78 }
79 Ok(())
80 }
81
82 pub fn write_u8_aligned(&mut self, value: u8) -> BitResult<()> {
84 self.ensure_aligned()?;
85 self.ensure_bits(8)?;
86 let idx = self.bit_pos / 8;
87 self.buf[idx] = value;
88 self.bit_pos += 8;
89 Ok(())
90 }
91
92 pub fn write_u16_aligned(&mut self, value: u16) -> BitResult<()> {
94 self.write_bytes_aligned(&value.to_le_bytes())
95 }
96
97 pub fn write_u32_aligned(&mut self, value: u32) -> BitResult<()> {
99 self.write_bytes_aligned(&value.to_le_bytes())
100 }
101
102 pub fn write_u64_aligned(&mut self, value: u64) -> BitResult<()> {
104 self.write_bytes_aligned(&value.to_le_bytes())
105 }
106
107 pub fn write_varu32(&mut self, mut value: u32) -> BitResult<()> {
109 self.ensure_aligned()?;
110 loop {
111 let mut byte = (value & 0x7F) as u8;
112 value >>= 7;
113 if value != 0 {
114 byte |= 0x80;
115 }
116 self.write_u8_aligned(byte)?;
117 if value == 0 {
118 break;
119 }
120 }
121 Ok(())
122 }
123
124 pub fn write_vars32(&mut self, value: i32) -> BitResult<()> {
126 let zigzag = ((value << 1) ^ (value >> 31)) as u32;
127 self.write_varu32(zigzag)
128 }
129
130 #[must_use]
132 pub fn finish(self) -> usize {
133 self.bit_pos.div_ceil(8)
134 }
135
136 fn ensure_bits(&self, bits: usize) -> BitResult<()> {
137 let available = self.bits_remaining();
138 if bits > available {
139 return Err(BitError::WriteOverflow {
140 attempted: bits,
141 available,
142 });
143 }
144 Ok(())
145 }
146
147 fn ensure_aligned(&self) -> BitResult<()> {
148 if self.bit_pos % 8 != 0 {
149 return Err(BitError::MisalignedAccess {
150 bit_position: self.bit_pos,
151 });
152 }
153 Ok(())
154 }
155
156 fn write_bytes_aligned(&mut self, bytes: &[u8]) -> BitResult<()> {
157 self.ensure_aligned()?;
158 self.ensure_bits(bytes.len() * 8)?;
159 let idx = self.bit_pos / 8;
160 self.buf[idx..idx + bytes.len()].copy_from_slice(bytes);
161 self.bit_pos += bytes.len() * 8;
162 Ok(())
163 }
164
165 fn write_bit_unchecked(&mut self, value: bool) {
166 let byte_idx = self.bit_pos / 8;
167 let bit_idx = self.bit_pos % 8;
168 let mask = 1u8 << (7 - bit_idx);
169 if value {
170 self.buf[byte_idx] |= mask;
171 } else {
172 self.buf[byte_idx] &= !mask;
173 }
174 self.bit_pos += 1;
175 }
176}
177
178#[derive(Debug, Default)]
180pub struct BitVecWriter {
181 buf: Vec<u8>,
182 bit_pos: usize,
183}
184
185impl BitVecWriter {
186 #[must_use]
188 pub fn new() -> Self {
189 Self::default()
190 }
191
192 #[must_use]
194 pub fn with_capacity(bytes: usize) -> Self {
195 Self {
196 buf: Vec::with_capacity(bytes),
197 bit_pos: 0,
198 }
199 }
200
201 #[must_use]
203 pub const fn bits_written(&self) -> usize {
204 self.bit_pos
205 }
206
207 pub fn write_bit(&mut self, value: bool) {
209 self.ensure_capacity_bits(1);
210 self.write_bit_unchecked(value);
211 }
212
213 pub fn write_bits(&mut self, value: u64, bits: u8) -> BitResult<()> {
215 if bits > 64 {
216 return Err(BitError::InvalidBitCount { bits, max_bits: 64 });
217 }
218 if bits == 0 {
219 return Ok(());
220 }
221 if bits < 64 && value >= (1u64 << bits) {
222 return Err(BitError::ValueOutOfRange { value, bits });
223 }
224 self.ensure_capacity_bits(bits as usize);
225 for i in (0..bits).rev() {
226 let bit = (value >> i) & 1 == 1;
227 self.write_bit_unchecked(bit);
228 }
229 Ok(())
230 }
231
232 pub fn align_to_byte(&mut self) {
234 let rem = self.bit_pos % 8;
235 if rem == 0 {
236 return;
237 }
238 let padding = 8 - rem;
239 self.ensure_capacity_bits(padding);
240 for _ in 0..padding {
241 self.write_bit_unchecked(false);
242 }
243 }
244
245 pub fn write_u8_aligned(&mut self, value: u8) -> BitResult<()> {
247 self.ensure_aligned()?;
248 self.ensure_capacity_bits(8);
249 let idx = self.bit_pos / 8;
250 self.buf[idx] = value;
251 self.bit_pos += 8;
252 Ok(())
253 }
254
255 pub fn write_u16_aligned(&mut self, value: u16) -> BitResult<()> {
257 self.write_bytes_aligned(&value.to_le_bytes())
258 }
259
260 pub fn write_u32_aligned(&mut self, value: u32) -> BitResult<()> {
262 self.write_bytes_aligned(&value.to_le_bytes())
263 }
264
265 pub fn write_u64_aligned(&mut self, value: u64) -> BitResult<()> {
267 self.write_bytes_aligned(&value.to_le_bytes())
268 }
269
270 pub fn write_varu32(&mut self, mut value: u32) -> BitResult<()> {
272 self.ensure_aligned()?;
273 loop {
274 let mut byte = (value & 0x7F) as u8;
275 value >>= 7;
276 if value != 0 {
277 byte |= 0x80;
278 }
279 self.write_u8_aligned(byte)?;
280 if value == 0 {
281 break;
282 }
283 }
284 Ok(())
285 }
286
287 pub fn write_vars32(&mut self, value: i32) -> BitResult<()> {
289 let zigzag = ((value << 1) ^ (value >> 31)) as u32;
290 self.write_varu32(zigzag)
291 }
292
293 #[must_use]
295 pub fn finish(mut self) -> Vec<u8> {
296 let bytes = self.bit_pos.div_ceil(8);
297 self.buf.truncate(bytes);
298 self.buf
299 }
300
301 fn ensure_capacity_bits(&mut self, bits: usize) {
302 let required_bits = self.bit_pos + bits;
303 let required_bytes = required_bits.div_ceil(8);
304 if required_bytes > self.buf.len() {
305 self.buf.resize(required_bytes, 0);
306 }
307 }
308
309 fn ensure_aligned(&self) -> BitResult<()> {
310 if self.bit_pos % 8 != 0 {
311 return Err(BitError::MisalignedAccess {
312 bit_position: self.bit_pos,
313 });
314 }
315 Ok(())
316 }
317
318 fn write_bytes_aligned(&mut self, bytes: &[u8]) -> BitResult<()> {
319 self.ensure_aligned()?;
320 self.ensure_capacity_bits(bytes.len() * 8);
321 let idx = self.bit_pos / 8;
322 self.buf[idx..idx + bytes.len()].copy_from_slice(bytes);
323 self.bit_pos += bytes.len() * 8;
324 Ok(())
325 }
326
327 fn write_bit_unchecked(&mut self, value: bool) {
328 let byte_idx = self.bit_pos / 8;
329 let bit_idx = self.bit_pos % 8;
330 let mask = 1u8 << (7 - bit_idx);
331 if value {
332 self.buf[byte_idx] |= mask;
333 } else {
334 self.buf[byte_idx] &= !mask;
335 }
336 self.bit_pos += 1;
337 }
338}
339
#[cfg(test)]
mod tests {
    use super::*;

    // --- Existing coverage -------------------------------------------------

    #[test]
    fn bounded_overflow() {
        let mut buf = [0u8; 1];
        let mut writer = BitWriter::new(&mut buf);
        writer.write_bits(0xFF, 8).unwrap();
        let err = writer.write_bit(true).unwrap_err();
        assert!(matches!(err, BitError::WriteOverflow { .. }));
    }

    #[test]
    fn bounded_write_and_finish() {
        let mut buf = [0u8; 2];
        let mut writer = BitWriter::new(&mut buf);
        writer.write_bits(0b1010, 4).unwrap();
        writer.align_to_byte().unwrap();
        writer.write_u8_aligned(0xAB).unwrap();
        let bytes = writer.finish();
        assert_eq!(bytes, 2);
        assert_eq!(&buf[..2], &[0b1010_0000, 0xAB]);
    }

    #[test]
    fn vec_writer_roundtrip_bits() {
        let mut writer = BitVecWriter::new();
        writer.write_bits(0b1010, 4).unwrap();
        writer.write_bits(0xAB, 8).unwrap();
        let bytes = writer.finish();
        assert_eq!(bytes, vec![0b1010_1010, 0b1011_0000]);
    }

    #[test]
    fn vec_writer_align() {
        let mut writer = BitVecWriter::new();
        writer.write_bits(0b1010, 4).unwrap();
        writer.align_to_byte();
        writer.write_u8_aligned(0xFF).unwrap();
        let bytes = writer.finish();
        assert_eq!(bytes, vec![0b1010_0000, 0xFF]);
    }

    #[test]
    fn vec_writer_varint() {
        let mut writer = BitVecWriter::new();
        writer.align_to_byte();
        writer.write_varu32(300).unwrap();
        let bytes = writer.finish();
        assert_eq!(bytes, vec![0xAC, 0x02]);
    }

    #[test]
    fn vec_writer_zigzag() {
        let mut writer = BitVecWriter::new();
        writer.align_to_byte();
        writer.write_vars32(-1).unwrap();
        let bytes = writer.finish();
        assert_eq!(bytes, vec![0x01]);
    }

    // --- Error paths and boundaries ----------------------------------------

    #[test]
    fn bounded_bits_remaining_tracks_writes() {
        let mut buf = [0u8; 2];
        let mut writer = BitWriter::new(&mut buf);
        assert_eq!(writer.bits_remaining(), 16);
        writer.write_bits(0b101, 3).unwrap();
        assert_eq!(writer.bits_written(), 3);
        assert_eq!(writer.bits_remaining(), 13);
    }

    #[test]
    fn bounded_rejects_misaligned_byte_write() {
        let mut buf = [0u8; 2];
        let mut writer = BitWriter::new(&mut buf);
        writer.write_bit(true).unwrap();
        let err = writer.write_u8_aligned(0).unwrap_err();
        assert!(matches!(err, BitError::MisalignedAccess { .. }));
    }

    #[test]
    fn write_bits_rejects_invalid_count_and_value() {
        let mut writer = BitVecWriter::new();
        let err = writer.write_bits(0, 65).unwrap_err();
        assert!(matches!(err, BitError::InvalidBitCount { .. }));
        let err = writer.write_bits(0b1_0000, 4).unwrap_err();
        assert!(matches!(err, BitError::ValueOutOfRange { .. }));
        // Rejected writes must not advance the position.
        assert_eq!(writer.bits_written(), 0);
    }

    #[test]
    fn vec_writer_full_width_write() {
        let mut writer = BitVecWriter::new();
        writer.write_bits(u64::MAX, 64).unwrap();
        assert_eq!(writer.finish(), vec![0xFF; 8]);
    }

    #[test]
    fn vec_writer_multibyte_little_endian() {
        let mut writer = BitVecWriter::new();
        writer.write_u16_aligned(0x1234).unwrap();
        writer.write_u32_aligned(0xDEAD_BEEF).unwrap();
        assert_eq!(writer.finish(), vec![0x34, 0x12, 0xEF, 0xBE, 0xAD, 0xDE]);
    }

    #[test]
    fn vec_writer_zigzag_extremes() {
        let mut writer = BitVecWriter::new();
        writer.write_vars32(i32::MIN).unwrap();
        writer.write_vars32(i32::MAX).unwrap();
        assert_eq!(
            writer.finish(),
            vec![0xFF, 0xFF, 0xFF, 0xFF, 0x0F, 0xFE, 0xFF, 0xFF, 0xFF, 0x0F]
        );
    }
}