tlbits/
integer.rs

1//! Collection of **de**/**ser**ialization helpers for integers
2use core::mem;
3
4use bitvec::{
5    mem::bits_of,
6    order::Msb0,
7    view::{AsBits, AsMutBits},
8};
9
10use crate::{
11    Error,
12    r#as::{AsBytes, NBits},
13    de::{BitReader, BitReaderExt, BitUnpack, r#as::BitUnpackAs},
14    ser::{BitPack, BitWriter, BitWriterExt, r#as::BitPackAs},
15};
16
/// Type-level constant `bool`: a zero-sized marker for a single bit whose
/// value is fixed by the `VALUE` type parameter.
///
/// ## Deserialization
///
/// Reads one `bool` and fails with an error if it differs from `VALUE`.
///
/// ```rust
/// # use tlbits::{
/// #   bitvec::{bits, order::Msb0},
/// #   de::{BitReaderExt},
/// #   Error,
/// #   integer::ConstBit,
/// #   StringError,
/// # };
/// # fn main() -> Result<(), StringError> {
/// # let mut reader = bits![u8, Msb0; 1, 1];
/// reader.unpack::<ConstBit<true>>()?;
/// // is equivalent to:
/// if !reader.unpack::<bool>()? {
///     return Err(Error::custom("expected 0b1, got 0b0"));
/// }
/// # Ok(())
/// # }
/// ```
///
/// ## Serialization
///
/// Writes the `bool` given as the type parameter.
///
/// ```rust
/// # use tlbits::{
/// #   bitvec::{bits, vec::BitVec, order::Msb0},
/// #   integer::ConstBit,
/// #   ser::BitWriterExt,
/// #   StringError,
/// # };
/// # fn main() -> Result<(), StringError> {
/// # let mut writer = BitVec::<u8, Msb0>::new();
/// writer.pack(ConstBit::<true>)?;
/// // is equivalent to:
/// writer.pack(true)?;
/// # assert_eq!(writer, bits![u8, Msb0; 1, 1]);
/// # Ok(())
/// # }
/// ```
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct ConstBit<const VALUE: bool>;
65
66impl<const VALUE: bool> BitPack for ConstBit<VALUE> {
67    #[inline]
68    fn pack<W>(&self, writer: W) -> Result<(), W::Error>
69    where
70        W: BitWriter,
71    {
72        VALUE.pack(writer)
73    }
74}
75
76impl<'de, const VALUE: bool> BitUnpack<'de> for ConstBit<VALUE> {
77    #[inline]
78    fn unpack<R>(mut reader: R) -> Result<Self, R::Error>
79    where
80        R: BitReader<'de>,
81    {
82        if VALUE != reader.unpack::<bool>()? {
83            Err(Error::custom(format!(
84                "expected {:#b}, got {:#b}",
85                VALUE as u8, !VALUE as u8
86            )))
87        } else {
88            Ok(Self)
89        }
90    }
91}
92
93macro_rules! impl_bit_serde_for_integers {
94    ($($t:tt)+) => {$(
95        impl BitPack for $t {
96            #[inline]
97            fn pack<W>(&self, mut writer: W) -> Result<(), W::Error>
98            where
99                W: BitWriter,
100            {
101                writer.pack_as::<_, AsBytes>(self.to_be_bytes())?;
102                Ok(())
103            }
104        }
105
106        impl<'de> BitUnpack<'de> for $t {
107            #[inline]
108            fn unpack<R>(mut reader: R) -> Result<Self, R::Error>
109            where
110                R: BitReader<'de>,
111            {
112                reader.read_bytes_array().map(Self::from_be_bytes)
113            }
114        }
115
116        impl<const BITS: usize> BitPackAs<$t> for NBits<BITS> {
117            #[inline]
118            fn pack_as<W>(source: &$t, mut writer: W) -> Result<(), W::Error>
119            where
120                W: BitWriter,
121            {
122                const BITS_SIZE: usize = bits_of::<$t>();
123                assert!(BITS <= BITS_SIZE, "excessive bits for type");
124                if BITS < BITS_SIZE - source.leading_zeros() as usize {
125                    return Err(Error::custom(
126                        format!("{source:#b} cannot be packed into {BITS} bits"),
127                    ));
128                }
129                let bytes = source.to_be_bytes();
130                let bits = bytes.as_bits::<Msb0>();
131                writer.write_bitslice(&bits[bits.len() - BITS..])?;
132                Ok(())
133            }
134        }
135
136        impl<'de, const BITS: usize> BitUnpackAs<'de, $t> for NBits<BITS> {
137            #[inline]
138            fn unpack_as<R>(mut reader: R) -> Result<$t, R::Error>
139            where
140                R: BitReader<'de>,
141            {
142                const BITS_SIZE: usize = bits_of::<$t>();
143                assert!(BITS <= BITS_SIZE, "excessive bits for type");
144                let mut arr = [0u8; mem::size_of::<$t>()];
145                let arr_bits = &mut arr.as_mut_bits()[BITS_SIZE - BITS..];
146                if reader.read_bits_into(arr_bits)? != arr_bits.len() {
147                    return Err(Error::custom("EOF"));
148                }
149                Ok($t::from_be_bytes(arr))
150            }
151        }
152    )+};
153}
154impl_bit_serde_for_integers! {
155    u8 u16 u32 u64 u128 usize
156    i8 i16 i32 i64 i128 isize
157}
158
159macro_rules! const_uint {
160    ($($vis:vis $name:ident<$typ:tt>)+) => {$(
161        #[doc = concat!("Constant version of `", stringify!($typ), "`")]
162        /// ## Deserialization
163        #[doc = concat!(
164            "Reads `", stringify!($typ), "` and returns an error
165            if it didn't match the type parameter.",
166        )]
167        ///
168        /// ```rust
169        /// # use tlbits::{
170        /// #   bitvec::{vec::BitVec, order::Msb0},
171        /// #   de::BitReaderExt,
172        /// #   Error,
173        #[doc = concat!("# integer::", stringify!($name), ",")]
174        /// #   ser::BitWriterExt,
175        /// #   StringError,
176        /// # };
177        /// # fn main() -> Result<(), StringError> {
178        /// # let mut buff = BitVec::<u8, Msb0>::new();
179        #[doc = concat!("# buff.pack::<[", stringify!($typ), "; 2]>([123; 2])?;")]
180        /// # let mut reader = buff.as_bitslice();
181        #[doc = concat!("reader.unpack::<", stringify!($name), "<123>>()?;")]
182        /// // is equivalent of:
183        #[doc = concat!("let got: ", stringify!($typ), " = reader.unpack()?;")]
184        /// if got != 123 {
185        ///     return Err(Error::custom(format!("expected 123, got {got}")));
186        /// }
187        /// # Ok(())
188        /// # }
189        /// ```
190        ///
191        /// ## Serialization
192        ///
193        #[doc = concat!(
194            "Writes `", stringify!($typ), "` as specified in type parameter."
195        )]
196        ///
197        /// ```rust
198        /// # use tlbits::{
199        /// #   bitvec::{bits, vec::BitVec, order::Msb0},
200        /// #   de::BitReaderExt,
201        #[doc = concat!("# integer::", stringify!($name), ",")]
202        /// #   ser::BitWriterExt,
203        /// #   StringError,
204        /// # };
205        /// # fn main() -> Result<(), StringError> {
206        /// # let mut writer = BitVec::<u8, Msb0>::new();
207        #[doc = concat!("writer.pack(", stringify!($name), "::<123>)?;")]
208        /// // is equivalent of:
209        #[doc = concat!("writer.pack::<", stringify!($typ), ">(123)?;")]
210        /// # let mut reader = writer.as_bitslice();
211        #[doc = concat!(
212            "# assert_eq!(reader.unpack::<[", stringify!($typ), "; 2]>()?, [123; 2]);"
213        )]
214        /// # Ok(())
215        /// # }
216        /// ```
217        #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
218        $vis struct $name<const VALUE: $typ, const BITS: usize = { bits_of::<$typ>() }>;
219
220        impl<const VALUE: $typ, const BITS: usize> BitPack for $name<VALUE, BITS> {
221            #[inline]
222            fn pack<W>(&self, mut writer: W) -> Result<(), W::Error>
223            where
224                W: BitWriter,
225            {
226                writer.pack_as::<_, NBits<BITS>>(VALUE)?;
227                Ok(())
228            }
229        }
230
231        impl<'de, const VALUE: $typ, const BITS: usize> BitUnpack<'de> for $name<VALUE, BITS> {
232            #[inline]
233            fn unpack<R>(mut reader: R) -> Result<Self, R::Error>
234            where
235                R: BitReader<'de>,
236            {
237                let v = reader.unpack_as::<$typ, NBits<BITS>>()?;
238                if v != VALUE {
239                    return Err(Error::custom(format!(
240                        "expected {VALUE:#b}, got: {v:#b}"
241                    )));
242                }
243                Ok(Self)
244            }
245        }
246    )+};
247}
248
249const_uint! {
250    pub ConstU8  <u8>
251    pub ConstI8  <i8>
252    pub ConstU16 <u16>
253    pub ConstI16 <i16>
254    pub ConstU32 <u32>
255    pub ConstI32 <i32>
256    pub ConstU64 <u64>
257    pub ConstI64 <i64>
258    pub ConstU128<u128>
259    pub ConstI128<i128>
260}
261
#[cfg(test)]
mod tests {
    use bitvec::{bits, order::Msb0};
    use num_bigint::BigUint;

    use crate::{
        ser::{r#as::pack_as, pack},
        tests::{assert_pack_unpack_as_eq, assert_pack_unpack_eq},
    };

    use super::*;

    // A plain integer is packed at its full width, most significant bit first.
    #[test]
    fn store_uint() {
        let packed = pack(0xFD_FE_u16).unwrap();
        assert_eq!(
            packed,
            bits![u8, Msb0; 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0],
        );
    }

    // Packing and then unpacking a plain integer yields the same value.
    #[test]
    fn serde_uint() {
        assert_pack_unpack_eq(12345_u32);
    }

    // `NBits<7>` keeps only the seven lowest bits of the value.
    #[test]
    fn store_nbits_uint() {
        let packed = pack_as::<_, NBits<7>>(0x7E).unwrap();
        assert_eq!(packed, bits![u8, Msb0; 1, 1, 1, 1, 1, 1, 0]);
    }

    // A one-bit integer has the same encoding as `true`.
    #[test]
    fn nbits_one_bit() {
        let as_nbits = pack_as::<_, NBits<1>>(0b1).unwrap();
        let as_bool = pack(true).unwrap();
        assert_eq!(as_nbits, as_bool);
    }

    // `NBits` at the full width of the type matches the plain encoding.
    #[test]
    fn store_nbits_same_uint() {
        let value: u8 = 231;
        let plain = pack(value).unwrap();
        let full_width = pack_as::<_, NBits<8>>(value).unwrap();
        assert_eq!(plain, full_width);
    }

    // Round trip through `NBits<7>` for a value that fits in seven bits.
    #[test]
    fn serde_nbits_uint() {
        assert_pack_unpack_as_eq::<u8, NBits<7>>(0x7E);
    }

    // Round trip through `NBits<100>` for an arbitrary-precision integer.
    #[test]
    fn serde_big_nbits() {
        let value = BigUint::from(12345_u64);
        assert_pack_unpack_as_eq::<BigUint, NBits<100>>(value);
    }
}