// willow_encoding/compact_width.rs

1use crate::error::DecodeError;
2
/// The minimum width in bytes needed to represent an unsigned integer.
///
/// [Definition](https://willowprotocol.org/specs/encodings/index.html#compact_width)
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub enum CompactWidth {
    /// One byte: represents numbers up to 255 (`u8::MAX`, i.e. an 8-bit number).
    One,
    /// Two bytes: represents numbers up to 256^2 - 1 (`u16::MAX`, i.e. a 16-bit number).
    Two,
    /// Four bytes: represents numbers up to 256^4 - 1 (`u32::MAX`, i.e. a 32-bit number).
    Four,
    /// Eight bytes: represents numbers up to 256^8 - 1 (`u64::MAX`, i.e. a 64-bit number).
    Eight,
}
17
/// Error returned when constructing a [`CompactWidth`] from a byte count that is not 1, 2, 4, or 8.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) struct NotACompactWidthError();

impl core::fmt::Display for NotACompactWidthError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str("Tried to construct a CompactWidth from a u8 that was not 1, 2, 4, or 8.")
    }
}

impl std::error::Error for NotACompactWidthError {}
31
32impl CompactWidth {
33    /// Return a new [`CompactWidth`].
34    pub(crate) fn new(n: u8) -> Result<CompactWidth, NotACompactWidthError> {
35        match n {
36            1 => Ok(CompactWidth::One),
37            2 => Ok(CompactWidth::Two),
38            4 => Ok(CompactWidth::Four),
39            8 => Ok(CompactWidth::Eight),
40            _ => Err(NotACompactWidthError()),
41        }
42    }
43
44    /// Return the most compact width in bytes (1, 2, 4, or 8) needed to represent a given `u64` as a corresponding 8-bit, 16-bit, 32-bit, or 64-bit number.
45    ///
46    /// [Definition](https://willowprotocol.org/specs/encodings/index.html#compact_width).
47    pub fn from_u64(value: u64) -> Self {
48        if value <= u8::MAX as u64 {
49            CompactWidth::One
50        } else if value <= u16::MAX as u64 {
51            CompactWidth::Two
52        } else if value <= u32::MAX as u64 {
53            CompactWidth::Four
54        } else {
55            CompactWidth::Eight
56        }
57    }
58
59    /// Return the most compact width in bytes (1, 2, 4) needed to represent a given `u32` as a corresponding 8-bit, 16-bit, or 32-bit number.
60    ///
61    /// [Definition](https://willowprotocol.org/specs/encodings/index.html#compact_width).
62    pub fn from_u32(value: u32) -> Self {
63        if value <= u8::MAX as u32 {
64            CompactWidth::One
65        } else if value <= u16::MAX as u32 {
66            CompactWidth::Two
67        } else {
68            CompactWidth::Four
69        }
70    }
71
72    /// Return the most compact width in bytes (1 or 2) needed to represent a given `u16` as a corresponding 8-bit or 16-bit number.
73    ///
74    /// [Definition](https://willowprotocol.org/specs/encodings/index.html#compact_width).
75    pub fn from_u16(value: u16) -> Self {
76        if value <= u8::MAX as u16 {
77            CompactWidth::One
78        } else {
79            CompactWidth::Two
80        }
81    }
82
83    /// Return [`CompactWidth::One`], the only [`CompactWidth`] needed to represent a given `u8`.
84    ///
85    /// [Definition](https://willowprotocol.org/specs/encodings/index.html#compact_width).
86    pub fn from_u8(_: u8) -> Self {
87        CompactWidth::One
88    }
89
90    /// Return the width in bytes of this [`CompactWidth`].
91    pub fn width(&self) -> usize {
92        match self {
93            CompactWidth::One => 1,
94            CompactWidth::Two => 2,
95            CompactWidth::Four => 4,
96            CompactWidth::Eight => 8,
97        }
98    }
99
100    /// Encode a [`CompactWidth`] as a 2-bit integer `n` such that 2^n gives the bytewidth of the [`CompactWidth`], and then place that 2-bit number into a `u8` at the bit-index of `position`.
101    pub fn bitmask(&self, position: u8) -> u8 {
102        let og = match self {
103            CompactWidth::One => 0b0000_0000,
104            CompactWidth::Two => 0b0100_0000,
105            CompactWidth::Four => 0b1000_0000,
106            CompactWidth::Eight => 0b1100_0000,
107        };
108
109        og >> position
110    }
111
112    pub fn decode_fixed_width_bitmask(mask: u8, offset: u8) -> Self {
113        let twobit_mask = 0b0000_0011;
114        let two_bit_int = mask >> (6 - offset) & twobit_mask;
115
116        // Because we sanitise the input down to a 2-bit integer, we can safely unwrap this.
117        CompactWidth::new(2u8.pow(two_bit_int as u32)).unwrap()
118    }
119}
120
121use syncify::syncify;
122use syncify::syncify_replace;
123
124#[syncify(encoding_sync)]
125pub mod encoding {
126    use super::*;
127
128    #[syncify_replace(use ufotofu::sync::{BulkConsumer, BulkProducer};)]
129    use ufotofu::local_nb::{BulkConsumer, BulkProducer};
130
131    use crate::unsigned_int::{U16BE, U32BE, U64BE, U8BE};
132
133    #[syncify_replace(use crate::sync::{Decodable};)]
134    use crate::Decodable;
135
136    /// Encode a `u64` integer as a `compact_width(value)`-byte big-endian integer, and consume that with a [`BulkConsumer`].
137    pub async fn encode_compact_width_be<Consumer: BulkConsumer<Item = u8>>(
138        value: u64,
139        consumer: &mut Consumer,
140    ) -> Result<(), Consumer::Error> {
141        let width = CompactWidth::from_u64(value).width();
142
143        consumer
144            .bulk_consume_full_slice(&value.to_be_bytes()[8 - width..])
145            .await
146            .map_err(|f| f.reason)?;
147
148        Ok(())
149    }
150
151    /// Decode the bytes representing a [`CompactWidth`]-bytes integer into a `usize`.
152    pub async fn decode_compact_width_be<Producer: BulkProducer<Item = u8>>(
153        compact_width: CompactWidth,
154        producer: &mut Producer,
155    ) -> Result<u64, DecodeError<Producer::Error>> {
156        let decoded = match compact_width {
157            CompactWidth::One => U8BE::decode(producer).await.map(u64::from),
158            CompactWidth::Two => U16BE::decode(producer).await.map(u64::from),
159            CompactWidth::Four => U32BE::decode(producer).await.map(u64::from),
160            CompactWidth::Eight => U64BE::decode(producer).await.map(u64::from),
161        }?;
162
163        let real_width = CompactWidth::from_u64(decoded);
164
165        if real_width != compact_width {
166            return Err(DecodeError::InvalidInput);
167        }
168
169        Ok(decoded)
170    }
171}
172
#[cfg(test)]
mod tests {
    use encoding_sync::{decode_compact_width_be, encode_compact_width_be};
    use ufotofu::local_nb::consumer::IntoVec;
    use ufotofu::local_nb::producer::FromBoxedSlice;

    use super::*;

    #[test]
    fn compact_width_works() {
        // u64
        assert_eq!(CompactWidth::from_u64(0_u64), CompactWidth::One);
        assert_eq!(CompactWidth::from_u64(u8::MAX as u64), CompactWidth::One);

        assert_eq!(
            CompactWidth::from_u64(u8::MAX as u64 + 1),
            CompactWidth::Two
        );
        assert_eq!(CompactWidth::from_u64(u16::MAX as u64), CompactWidth::Two);

        assert_eq!(
            CompactWidth::from_u64(u16::MAX as u64 + 1),
            CompactWidth::Four
        );
        assert_eq!(CompactWidth::from_u64(u32::MAX as u64), CompactWidth::Four);

        assert_eq!(
            CompactWidth::from_u64(u32::MAX as u64 + 1),
            CompactWidth::Eight
        );
        assert_eq!(CompactWidth::from_u64(u64::MAX), CompactWidth::Eight);

        // u32
        assert_eq!(CompactWidth::from_u32(0_u32), CompactWidth::One);
        assert_eq!(CompactWidth::from_u32(u8::MAX as u32), CompactWidth::One);

        assert_eq!(
            CompactWidth::from_u32(u8::MAX as u32 + 1),
            CompactWidth::Two
        );
        assert_eq!(CompactWidth::from_u32(u16::MAX as u32), CompactWidth::Two);

        assert_eq!(
            CompactWidth::from_u32(u16::MAX as u32 + 1),
            CompactWidth::Four
        );
        assert_eq!(CompactWidth::from_u32(u32::MAX), CompactWidth::Four);

        // u16
        assert_eq!(CompactWidth::from_u16(0_u16), CompactWidth::One);
        assert_eq!(CompactWidth::from_u16(u8::MAX as u16), CompactWidth::One);

        assert_eq!(
            CompactWidth::from_u16(u8::MAX as u16 + 1),
            CompactWidth::Two
        );
        assert_eq!(CompactWidth::from_u16(u16::MAX), CompactWidth::Two);

        // u8
        assert_eq!(CompactWidth::from_u8(0_u8), CompactWidth::One);
        assert_eq!(CompactWidth::from_u8(u8::MAX), CompactWidth::One);
    }

    #[test]
    fn new_accepts_only_valid_widths() {
        assert_eq!(CompactWidth::new(1).unwrap(), CompactWidth::One);
        assert_eq!(CompactWidth::new(2).unwrap(), CompactWidth::Two);
        assert_eq!(CompactWidth::new(4).unwrap(), CompactWidth::Four);
        assert_eq!(CompactWidth::new(8).unwrap(), CompactWidth::Eight);

        for n in [0_u8, 3, 5, 6, 7, 9, 16, u8::MAX] {
            assert!(CompactWidth::new(n).is_err(), "{n} is not a compact width");
        }
    }

    #[test]
    fn width_matches_new() {
        for n in [1_u8, 2, 4, 8] {
            assert_eq!(CompactWidth::new(n).unwrap().width(), n as usize);
        }
    }

    #[test]
    fn bitmask_known_values() {
        assert_eq!(CompactWidth::One.bitmask(0), 0b0000_0000);
        assert_eq!(CompactWidth::Eight.bitmask(0), 0b1100_0000);
        assert_eq!(CompactWidth::Four.bitmask(2), 0b0010_0000);
        assert_eq!(CompactWidth::Two.bitmask(6), 0b0000_0001);
    }

    #[test]
    fn bitmask_round_trips() {
        for width in [
            CompactWidth::One,
            CompactWidth::Two,
            CompactWidth::Four,
            CompactWidth::Eight,
        ] {
            for offset in 0..=6 {
                let mask = width.bitmask(offset);
                assert_eq!(
                    CompactWidth::decode_fixed_width_bitmask(mask, offset),
                    width
                );
            }
        }
    }

    #[test]
    fn decode_fixed_width_bitmask_ignores_other_bits() {
        assert_eq!(
            CompactWidth::decode_fixed_width_bitmask(0b0011_1111, 0),
            CompactWidth::One
        );
        assert_eq!(
            CompactWidth::decode_fixed_width_bitmask(0b1100_1111, 2),
            CompactWidth::One
        );
        assert_eq!(
            CompactWidth::decode_fixed_width_bitmask(0b1101_1111, 2),
            CompactWidth::Two
        );
        assert_eq!(
            CompactWidth::decode_fixed_width_bitmask(0b1111_1111, 0),
            CompactWidth::Eight
        );
    }

    #[test]
    fn encoding() {
        let values = [
            (CompactWidth::One, 0),
            (CompactWidth::One, u8::MAX as u64),
            (CompactWidth::Two, u8::MAX as u64 + 1),
            (CompactWidth::Two, u16::MAX as u64),
            (CompactWidth::Four, u16::MAX as u64 + 1),
            (CompactWidth::Four, u32::MAX as u64),
            (CompactWidth::Eight, u32::MAX as u64 + 1),
            (CompactWidth::Eight, u64::MAX),
        ];

        for (compact_width, value) in values {
            let mut consumer = IntoVec::<u8>::new();

            encode_compact_width_be(value, &mut consumer).unwrap();

            let encode_result = consumer.into_vec();

            let decoded_compact_width = CompactWidth::new(encode_result.len() as u8).unwrap();

            assert_eq!(decoded_compact_width, compact_width);

            let mut producer = FromBoxedSlice::from_vec(encode_result);

            let decode_result =
                decode_compact_width_be(decoded_compact_width, &mut producer).unwrap();

            assert_eq!(decode_result, value);
        }
    }

    #[test]
    fn decoding_rejects_non_minimal_widths() {
        // Each value fits in a narrower width than the one it is encoded with.
        let non_minimal: [(CompactWidth, Vec<u8>); 3] = [
            (CompactWidth::Two, vec![0, 5]),
            (CompactWidth::Four, vec![0, 0, 1, 0]),
            (
                CompactWidth::Eight,
                vec![0, 0, 0, 0, 0xff, 0xff, 0xff, 0xff],
            ),
        ];

        for (compact_width, bytes) in non_minimal {
            let mut producer = FromBoxedSlice::from_vec(bytes);

            let result = decode_compact_width_be(compact_width, &mut producer);

            assert!(matches!(result, Err(DecodeError::InvalidInput)));
        }
    }
}