Skip to main content

wire_codec/
bitfield.rs

1//! Bit-level read and write cursors.
2//!
3//! [`BitReader`] consumes bits from a borrowed byte slice in most-significant-bit
4//! first order, suitable for packed binary protocols where multiple fields share
5//! a single byte. [`BitWriter`] performs the reverse operation against a caller-
6//! supplied output buffer.
7//!
8//! Both cursors support arbitrary widths from 1 to 64 bits per call.
9
10use crate::error::{Error, Result};
11
/// Largest field width, in bits, accepted by a single `read_bits` or
/// `write_bits` call.
///
/// Field values travel in a `u64`, so the limit is that type's bit width.
pub const MAX_BIT_WIDTH: u32 = u64::BITS;
14
15/// Reads packed bits from a borrowed byte slice, most-significant-bit first.
16///
17/// # Example
18///
19/// ```
20/// use wire_codec::BitReader;
21///
22/// // 0b1010_1100 = 0xAC. Read 3 bits, then 5 bits.
23/// let mut r = BitReader::new(&[0xAC]);
24/// assert_eq!(r.read_bits(3).unwrap(), 0b101);
25/// assert_eq!(r.read_bits(5).unwrap(), 0b01100);
26/// ```
27#[derive(Debug, Clone)]
28pub struct BitReader<'a> {
29    bytes: &'a [u8],
30    bit_pos: usize,
31}
32
33impl<'a> BitReader<'a> {
34    /// Wrap `bytes` in a new bit cursor positioned at the first bit.
35    #[inline]
36    pub const fn new(bytes: &'a [u8]) -> Self {
37        Self { bytes, bit_pos: 0 }
38    }
39
40    /// Number of bits already consumed.
41    #[inline]
42    pub const fn bits_consumed(&self) -> usize {
43        self.bit_pos
44    }
45
46    /// Number of bits still available to read.
47    #[inline]
48    pub const fn bits_remaining(&self) -> usize {
49        (self.bytes.len() * 8) - self.bit_pos
50    }
51
52    /// Advance the cursor to the next byte boundary.
53    #[inline]
54    pub fn align_to_byte(&mut self) {
55        self.bit_pos = (self.bit_pos + 7) & !7;
56    }
57
58    /// Read the next `n` bits as the low-order bits of a `u64`.
59    ///
60    /// # Errors
61    ///
62    /// - [`Error::BitOverflow`] if `n` is `0` or greater than [`MAX_BIT_WIDTH`].
63    /// - [`Error::UnexpectedEof`] if fewer than `n` bits remain in the input.
64    pub fn read_bits(&mut self, n: u32) -> Result<u64> {
65        if n == 0 || n > MAX_BIT_WIDTH {
66            return Err(Error::BitOverflow);
67        }
68        if (n as usize) > self.bits_remaining() {
69            return Err(Error::UnexpectedEof);
70        }
71        let mut value: u64 = 0;
72        let mut bits_left = n;
73        while bits_left > 0 {
74            let byte_idx = self.bit_pos / 8;
75            let bit_off = (self.bit_pos % 8) as u32;
76            let avail = 8 - bit_off;
77            let take = if bits_left < avail { bits_left } else { avail };
78            let shift = avail - take;
79            let mask: u32 = (1u32 << take) - 1;
80            let chunk = (u32::from(self.bytes[byte_idx]) >> shift) & mask;
81            value = (value << take) | u64::from(chunk);
82            self.bit_pos += take as usize;
83            bits_left -= take;
84        }
85        Ok(value)
86    }
87}
88
89/// Writes packed bits to a caller-supplied byte slice, most-significant-bit first.
90///
91/// The cursor zero-fills bytes lazily as it advances, so callers can pass an
92/// uninitialized-looking buffer and read out only the prefix returned by
93/// [`BitWriter::finish`].
94///
95/// # Example
96///
97/// ```
98/// use wire_codec::BitWriter;
99///
100/// let mut storage = [0u8; 1];
101/// let mut w = BitWriter::new(&mut storage);
102/// w.write_bits(0b101, 3).unwrap();
103/// w.write_bits(0b01100, 5).unwrap();
104/// assert_eq!(w.finish(), 1);
105/// assert_eq!(storage[0], 0xAC);
106/// ```
107#[derive(Debug)]
108pub struct BitWriter<'a> {
109    bytes: &'a mut [u8],
110    bit_pos: usize,
111}
112
113impl<'a> BitWriter<'a> {
114    /// Wrap `bytes` in a new bit cursor. Existing contents of `bytes` are
115    /// overwritten as bits are written.
116    #[inline]
117    pub fn new(bytes: &'a mut [u8]) -> Self {
118        Self { bytes, bit_pos: 0 }
119    }
120
121    /// Number of bits written so far.
122    #[inline]
123    pub fn bits_written(&self) -> usize {
124        self.bit_pos
125    }
126
127    /// Pad the cursor with zero bits up to the next byte boundary.
128    ///
129    /// # Errors
130    ///
131    /// Returns [`Error::BufferFull`] if rounding up would exceed the backing
132    /// slice.
133    pub fn align_to_byte(&mut self) -> Result<()> {
134        let aligned = (self.bit_pos + 7) & !7;
135        if aligned > self.bytes.len() * 8 {
136            return Err(Error::BufferFull);
137        }
138        self.bit_pos = aligned;
139        Ok(())
140    }
141
142    /// Write the low `n` bits of `value` into the buffer.
143    ///
144    /// # Errors
145    ///
146    /// - [`Error::BitOverflow`] if `n` is `0`, greater than [`MAX_BIT_WIDTH`],
147    ///   or `value` has bits set above position `n`.
148    /// - [`Error::BufferFull`] if fewer than `n` bits remain in the buffer.
149    pub fn write_bits(&mut self, value: u64, n: u32) -> Result<()> {
150        if n == 0 || n > MAX_BIT_WIDTH {
151            return Err(Error::BitOverflow);
152        }
153        if n < 64 && (value >> n) != 0 {
154            return Err(Error::BitOverflow);
155        }
156        if (n as usize) > self.bytes.len() * 8 - self.bit_pos {
157            return Err(Error::BufferFull);
158        }
159        let mut bits_left = n;
160        while bits_left > 0 {
161            let byte_idx = self.bit_pos / 8;
162            let bit_off = (self.bit_pos % 8) as u32;
163            let avail = 8 - bit_off;
164            let take = if bits_left < avail { bits_left } else { avail };
165            let shift = avail - take;
166            let chunk = ((value >> (bits_left - take)) & ((1u64 << take) - 1)) as u8;
167            // Clear the destination bits, then OR in the new ones. This lets
168            // callers pass a freshly-allocated zeroed buffer or a reused one
169            // and get well-defined output either way.
170            let mask = (((1u32 << take) - 1) as u8) << shift;
171            self.bytes[byte_idx] = (self.bytes[byte_idx] & !mask) | (chunk << shift);
172            self.bit_pos += take as usize;
173            bits_left -= take;
174        }
175        Ok(())
176    }
177
178    /// Finalize writes and return the number of whole bytes consumed.
179    ///
180    /// A partially-filled trailing byte counts as one full byte, with the
181    /// remaining low-order bits left zero-padded.
182    #[inline]
183    pub fn finish(self) -> usize {
184        self.bit_pos.div_ceil(8)
185    }
186}
187
#[cfg(test)]
mod tests {
    use super::*;

    /// Fields written back-to-back, including ones that straddle byte
    /// boundaries, must read back with the same values and widths.
    #[test]
    fn pack_unpack_round_trip() {
        const FIELDS: [(u64, u32); 3] = [(0b1011, 4), (0b1100_1010, 8), (0b1111_0000, 8)];

        let mut storage = [0u8; 4];
        let mut writer = BitWriter::new(&mut storage);
        for (value, width) in FIELDS {
            writer.write_bits(value, width).unwrap();
        }
        // 20 bits were written, which rounds up to three bytes.
        let used = writer.finish();
        assert_eq!(used, 3);

        let mut reader = BitReader::new(&storage[..used]);
        for (value, width) in FIELDS {
            assert_eq!(reader.read_bits(width).unwrap(), value);
        }
    }

    /// A value wider than the requested field is rejected, not truncated.
    #[test]
    fn write_rejects_value_overflow() {
        let mut storage = [0u8; 1];
        // 0b1_0000 needs 5 bits but only 4 were requested.
        let outcome = BitWriter::new(&mut storage).write_bits(0b1_0000, 4);
        assert_eq!(outcome, Err(Error::BitOverflow));
    }

    /// Aligning after a partial byte moves the cursor to the next byte, so
    /// the following write lands entirely in the second byte.
    #[test]
    fn align_rounds_up_to_byte() {
        let mut storage = [0u8; 2];
        let mut writer = BitWriter::new(&mut storage);
        writer.write_bits(0b101, 3).unwrap();
        writer.align_to_byte().unwrap();
        writer.write_bits(0xFF, 8).unwrap();
        let used = writer.finish();
        assert_eq!(used, 2);
        assert_eq!(storage, [0b1010_0000, 0xFF]);
    }
}