simplicity/bit_encoding/
bititer.rs

1// SPDX-License-Identifier: CC0-1.0
2
//! Bit Iterator functionality
//!
//! Simplicity programs are encoded bitwise rather than bytewise. This
//! module provides some helper functionality to make efficient parsing
//! easier. In particular, the `BitIter` type takes a byte iterator and
//! wraps it with some additional functionality (including implementing
//! `Iterator<Item=bool>`).
//!
11
12use crate::{Cmr, FailEntropy};
13use std::{error, fmt};
14
/// Attempted to read from a bit iterator, but there was no more data.
///
/// Returned by the `read_*` methods of [`BitIter`] when the underlying byte
/// iterator runs out before the requested bits could be read.
#[non_exhaustive]
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct EarlyEndOfStreamError;
19
20impl fmt::Display for EarlyEndOfStreamError {
21    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
22        f.write_str("bitstream ended early")
23    }
24}
25
// This error wraps no other error, so the trait's default `source()`
// (which returns `None`) is used as-is.
impl error::Error for EarlyEndOfStreamError {}
27
/// Failed to decode a natural number from a bitstream.
///
/// Returned by [`BitIter::read_natural`].
#[non_exhaustive]
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum DecodeNaturalError {
    /// Natural was a backreference, and pointed past the beginning of the program.
    ///
    /// Produced when the decoded number is greater than the bound that was
    /// passed to the decoder.
    BadIndex {
        /// The number we read (saturated to `usize::MAX` if it does not fit).
        got: usize,
        /// The maximum value (saturated to `usize::MAX` if it does not fit).
        max: usize,
    },
    /// Ran out of bits to read.
    EndOfStream(EarlyEndOfStreamError),
    /// Read a natural that was too large to decode: either one of its
    /// intermediate length prefixes exceeded 31, or the decoded value did
    /// not fit in the requested integer type.
    Overflow,
}
44
45impl fmt::Display for DecodeNaturalError {
46    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
47        match *self {
48            Self::BadIndex { got, max } => {
49                write!(
50                    f,
51                    "backreference {} exceeds current program length {}",
52                    got, max
53                )
54            }
55            Self::EndOfStream(ref e) => e.fmt(f),
56            Self::Overflow => f.write_str("encoded number exceeded 31 bits"),
57        }
58    }
59}
60
61impl error::Error for DecodeNaturalError {
62    fn source(&self) -> Option<&(dyn error::Error + 'static)> {
63        match *self {
64            Self::BadIndex { .. } => None,
65            Self::EndOfStream(ref e) => Some(e),
66            Self::Overflow => None,
67        }
68    }
69}
70
/// Closed out a bit iterator and there was remaining data.
///
/// Returned by [`BitIter::close`].
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub enum CloseError {
    /// The iterator was closed but the underlying byte iterator was
    /// still yielding data.
    TrailingBytes {
        /// The first unused byte from the iterator.
        first_byte: u8,
    },
    /// The iterator was closed partway through its last byte, and the
    /// unread bits of that byte were not all zero.
    IllegalPadding {
        /// The last byte with the already-read bits masked off, leaving
        /// only the unread (padding) bits.
        masked_padding: u8,
        /// The number of unread (padding) bits in the last byte.
        n_bits: usize,
    },
}
85
86impl fmt::Display for CloseError {
87    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
88        match self {
89            CloseError::TrailingBytes { first_byte } => {
90                write!(f, "bitstream had trailing bytes 0x{:02x}...", first_byte)
91            }
92            CloseError::IllegalPadding {
93                masked_padding,
94                n_bits,
95            } => write!(
96                f,
97                "bitstream had {n_bits} bits in its last byte 0x{:02x}, not all zero",
98                masked_padding
99            ),
100        }
101    }
102}
103
// Neither variant wraps another error, so the trait's default `source()`
// (which returns `None`) is used as-is.
impl error::Error for CloseError {}
105
/// Two-bit type used during decoding
///
/// Use of an enum rather than a numeric primitive type makes it easier to
/// match on.
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
// The lowercase name is presumably chosen to resemble the primitive integer
// types (`u8`, `u32`, ...), hence the lint exemption.
#[allow(non_camel_case_types)]
pub enum u2 {
    /// The value 0 (bits `00`).
    _0,
    /// The value 1 (bits `01`).
    _1,
    /// The value 2 (bits `10`).
    _2,
    /// The value 3 (bits `11`).
    _3,
}
118
119impl From<u2> for u8 {
120    fn from(s: u2) -> u8 {
121        match s {
122            u2::_0 => 0,
123            u2::_1 => 1,
124            u2::_2 => 2,
125            u2::_3 => 3,
126        }
127    }
128}
129
/// Bitwise iterator formed from a wrapped bytewise iterator. Bytes are
/// interpreted big-endian, i.e. MSB is returned first
///
/// Besides yielding single bits through `Iterator<Item = bool>`, it keeps
/// a running count of how many bits have been read in total.
#[derive(Debug, Clone)]
pub struct BitIter<I: Iterator<Item = u8>> {
    /// Byte iterator
    iter: I,
    /// Current byte that contains next bit
    cached_byte: u8,
    /// Number of read bits in current byte
    ///
    /// A value of 8 means the cached byte is used up, so the next read has
    /// to fetch a fresh byte from `iter`.
    read_bits: usize,
    /// Total number of read bits
    total_read: usize,
}
143
144impl From<Vec<u8>> for BitIter<std::vec::IntoIter<u8>> {
145    fn from(v: Vec<u8>) -> Self {
146        BitIter {
147            iter: v.into_iter(),
148            cached_byte: 0,
149            // Set to 8 to force next `Iterator::next` to read a new byte
150            // from the underlying iterator
151            read_bits: 8,
152            total_read: 0,
153        }
154    }
155}
156
157impl<'a> From<&'a [u8]> for BitIter<std::iter::Copied<std::slice::Iter<'a, u8>>> {
158    fn from(sl: &'a [u8]) -> Self {
159        BitIter {
160            iter: sl.iter().copied(),
161            cached_byte: 0,
162            // Set to 8 to force next `Iterator::next` to read a new byte
163            // from the underlying iterator
164            read_bits: 8,
165            total_read: 0,
166        }
167    }
168}
169
170impl<I: Iterator<Item = u8>> From<I> for BitIter<I> {
171    fn from(iter: I) -> Self {
172        BitIter {
173            iter,
174            cached_byte: 0,
175            // Set to 8 to force next `Iterator::next` to read a new byte
176            // from the underlying iterator
177            read_bits: 8,
178            total_read: 0,
179        }
180    }
181}
182
183impl<I: Iterator<Item = u8>> Iterator for BitIter<I> {
184    type Item = bool;
185
186    fn next(&mut self) -> Option<bool> {
187        if self.read_bits < 8 {
188            self.read_bits += 1;
189            self.total_read += 1;
190            Some(self.cached_byte & (1 << (8 - self.read_bits as u8)) != 0)
191        } else {
192            self.cached_byte = self.iter.next()?;
193            self.read_bits = 0;
194            self.next()
195        }
196    }
197
198    fn size_hint(&self) -> (usize, Option<usize>) {
199        let (lo, hi) = self.iter.size_hint();
200        let adj = |n| 8 - self.read_bits + 8 * n;
201        (adj(lo), hi.map(adj))
202    }
203}
204
// Once the cached byte is used up, `next` asks the byte iterator again on
// every call; when that iterator is fused it keeps answering `None`, so
// `BitIter` keeps returning `None` as well.
impl<I> core::iter::FusedIterator for BitIter<I> where
    I: Iterator<Item = u8> + core::iter::FusedIterator
{
}
209
// `size_hint` is derived from the byte iterator's own hint (8 bits per
// remaining byte plus the unread bits of the cached byte), so `len()` relies
// on the byte iterator's hint being exact.
impl<I> core::iter::ExactSizeIterator for BitIter<I> where
    I: Iterator<Item = u8> + core::iter::ExactSizeIterator
{
}
214
215impl<'a> BitIter<std::iter::Copied<std::slice::Iter<'a, u8>>> {
216    /// Creates a new bitwise iterator from a bytewise one.
217    ///
218    /// Takes start and end indices *in bits*. If you want to use the entire slice,
219    /// `BitIter::from` is equivalent and easier to call.
220    pub fn byte_slice_window(sl: &'a [u8], start: usize, end: usize) -> Self {
221        assert!(start <= end);
222        assert!(end <= sl.len() * 8);
223
224        let actual_sl = &sl[start / 8..end.div_ceil(8)];
225        let mut iter = actual_sl.iter().copied();
226
227        let read_bits = start % 8;
228        if read_bits == 0 {
229            BitIter {
230                iter,
231                cached_byte: 0,
232                read_bits: 8,
233                total_read: 0,
234            }
235        } else {
236            BitIter {
237                cached_byte: iter.by_ref().next().unwrap(),
238                iter,
239                read_bits,
240                total_read: 0,
241            }
242        }
243    }
244}
245
246impl<I: Iterator<Item = u8>> BitIter<I> {
247    /// Creates a new bitwise iterator from a bytewise one. Equivalent
248    /// to using `From`
249    pub fn new(iter: I) -> Self {
250        Self::from(iter)
251    }
252
253    /// Reads a single bit from the iterator.
254    pub fn read_bit(&mut self) -> Result<bool, EarlyEndOfStreamError> {
255        self.next().ok_or(EarlyEndOfStreamError)
256    }
257
258    /// Reads two bits from the iterator.
259    pub fn read_u2(&mut self) -> Result<u2, EarlyEndOfStreamError> {
260        match (self.next(), self.next()) {
261            (Some(false), Some(false)) => Ok(u2::_0),
262            (Some(false), Some(true)) => Ok(u2::_1),
263            (Some(true), Some(false)) => Ok(u2::_2),
264            (Some(true), Some(true)) => Ok(u2::_3),
265            _ => Err(EarlyEndOfStreamError),
266        }
267    }
268
269    /// Reads a byte from the iterator.
270    pub fn read_u8(&mut self) -> Result<u8, EarlyEndOfStreamError> {
271        debug_assert!(self.read_bits > 0);
272        let cached = self.cached_byte;
273        self.cached_byte = self.iter.next().ok_or(EarlyEndOfStreamError)?;
274        self.total_read += 8;
275
276        Ok(cached.checked_shl(self.read_bits as u32).unwrap_or(0)
277            + (self.cached_byte >> (8 - self.read_bits)))
278    }
279
280    /// Reads a 256-bit CMR from the iterator.
281    pub fn read_cmr(&mut self) -> Result<Cmr, EarlyEndOfStreamError> {
282        let mut ret = [0; 32];
283        for byte in &mut ret {
284            *byte = self.read_u8()?;
285        }
286        Ok(Cmr::from_byte_array(ret))
287    }
288
289    /// Reads a 512-bit fail-combinator entropy from the iterator.
290    pub fn read_fail_entropy(&mut self) -> Result<FailEntropy, EarlyEndOfStreamError> {
291        let mut ret = [0; 64];
292        for byte in &mut ret {
293            *byte = self.read_u8()?;
294        }
295        Ok(FailEntropy::from_byte_array(ret))
296    }
297
298    /// Decode a natural number from bits.
299    ///
300    /// If a bound is specified, then the decoding terminates before trying to
301    /// decode a larger number.
302    pub fn read_natural<N>(&mut self, bound: Option<N>) -> Result<N, DecodeNaturalError>
303    where
304        N: TryFrom<u32> + PartialOrd,
305        u32: TryFrom<N>,
306        usize: TryFrom<N>,
307    {
308        let mut recurse_depth = 0;
309        loop {
310            match self.read_bit() {
311                Ok(true) => recurse_depth += 1,
312                Ok(false) => break,
313                Err(e) => return Err(DecodeNaturalError::EndOfStream(e)),
314            }
315        }
316
317        let mut len = 0;
318        loop {
319            let mut n = 1u32;
320            for _ in 0..len {
321                let bit = u32::from(self.read_bit().map_err(DecodeNaturalError::EndOfStream)?);
322                n = 2 * n + bit;
323            }
324
325            if recurse_depth == 0 {
326                let ret = N::try_from(n).map_err(|_| DecodeNaturalError::Overflow)?;
327
328                if let Some(bound) = bound {
329                    if ret > bound {
330                        // We are doing our arithmetic in 32 bits and will return early if we try
331                        // to exceed this. But maybe usize will be smaller than 32 bits. To handle
332                        // this without casts or panics we just saturate here. It's just error
333                        // reporting.
334                        //
335                        // Also, we can't write usize::try_from(n) here even though usize implements
336                        // TryFrom<u32> because of some bug in the Rust compiler. I couldn't find an
337                        // issue in a five minute search and I refuse to file any more issues on
338                        // rustc because it's a waste of time. So we have to write try_into and the
339                        // reader can do type inference themselves.
340                        let got = n.try_into().unwrap_or(usize::MAX);
341                        let max = bound.try_into().unwrap_or(usize::MAX);
342                        return Err(DecodeNaturalError::BadIndex { got, max });
343                    }
344                }
345
346                return Ok(ret);
347            } else {
348                // This is an attempted conversion to usize. See above comment.
349                len = n.try_into().map_err(|_| DecodeNaturalError::Overflow)?;
350                if len > 31 {
351                    return Err(DecodeNaturalError::Overflow);
352                }
353                recurse_depth -= 1;
354            }
355        }
356    }
357
358    /// Accessor for the number of bits which have been read,
359    /// in total, from this iterator
360    pub fn n_total_read(&self) -> usize {
361        self.total_read
362    }
363
364    /// Consumes the bit iterator, checking that there are no remaining
365    /// bytes and that any unread bits are zero.
366    pub fn close(mut self) -> Result<(), CloseError> {
367        if let Some(first_byte) = self.iter.next() {
368            return Err(CloseError::TrailingBytes { first_byte });
369        }
370
371        debug_assert!(self.read_bits >= 1);
372        debug_assert!(self.read_bits <= 8);
373        let n_bits = 8 - self.read_bits;
374        let masked_padding = self.cached_byte & ((1u8 << n_bits) - 1);
375        if masked_padding != 0 {
376            Err(CloseError::IllegalPadding {
377                masked_padding,
378                n_bits,
379            })
380        } else {
381            Ok(())
382        }
383    }
384}
385
/// Functionality for Boolean iterators to collect their bits or bytes.
pub trait BitCollector: Sized {
    /// Collect the bits of the iterator into a byte vector and a bit length.
    fn collect_bits(self) -> (Vec<u8>, usize);

    /// Try to collect the bits of the iterator into a clean byte vector.
    ///
    /// # Errors
    ///
    /// This fails if the number of bits is not divisible by 8.
    fn try_collect_bytes(self) -> Result<Vec<u8>, &'static str> {
        match self.collect_bits() {
            (bytes, n_bits) if n_bits % 8 == 0 => Ok(bytes),
            _ => Err("Number of collected bits is not divisible by 8"),
        }
    }
}
403
404impl<I: Iterator<Item = bool>> BitCollector for I {
405    fn collect_bits(self) -> (Vec<u8>, usize) {
406        let mut bytes = vec![];
407        let mut unfinished_byte = Vec::with_capacity(8);
408
409        for bit in self {
410            unfinished_byte.push(bit);
411
412            if unfinished_byte.len() == 8 {
413                bytes.push(
414                    unfinished_byte
415                        .iter()
416                        .fold(0, |acc, &b| acc * 2 + u8::from(b)),
417                );
418                unfinished_byte.clear();
419            }
420        }
421
422        let bit_length = bytes.len() * 8 + unfinished_byte.len();
423
424        if !unfinished_byte.is_empty() {
425            unfinished_byte.resize(8, false);
426            bytes.push(
427                unfinished_byte
428                    .iter()
429                    .fold(0, |acc, &b| acc * 2 + u8::from(b)),
430            );
431        }
432
433        (bytes, bit_length)
434    }
435}
436
// Unit tests for `BitIter`: bit-wise and byte-wise reads, read counters,
// remaining-length reporting and windowed construction.
#[cfg(test)]
mod tests {
    use super::*;

    // Every read on an empty stream fails and nothing is counted as read.
    #[test]
    fn empty_iter() {
        let mut iter = BitIter::from([].iter().cloned());
        assert_eq!(iter.len(), 0);
        assert!(iter.next().is_none());
        assert!(iter.next().is_none());
        assert_eq!(iter.read_bit(), Err(EarlyEndOfStreamError));
        assert_eq!(iter.read_u2(), Err(EarlyEndOfStreamError));
        assert_eq!(iter.read_u8(), Err(EarlyEndOfStreamError));
        assert_eq!(iter.read_cmr(), Err(EarlyEndOfStreamError));
        assert_eq!(iter.n_total_read(), 0);
    }

    // A single byte 0x80 = 1000_0000: bits come out MSB first, and a byte
    // read that needs a second byte fails without changing the counter.
    #[test]
    fn one_bit_iter() {
        let mut iter = BitIter::from([0x80].iter().cloned());
        assert_eq!(iter.len(), 8);
        assert_eq!(iter.read_bit(), Ok(true));
        assert_eq!(iter.len(), 7);
        assert_eq!(iter.read_bit(), Ok(false));
        assert_eq!(iter.len(), 6);
        assert_eq!(iter.read_u8(), Err(EarlyEndOfStreamError));
        assert_eq!(iter.n_total_read(), 2);
    }

    // 0x0f 0xaa = 0000_1111 1010_1010, read one bit at a time.
    #[test]
    fn bit_by_bit() {
        let mut iter = BitIter::from([0x0f, 0xaa].iter().cloned());
        assert_eq!(iter.len(), 16);
        for _ in 0..4 {
            assert_eq!(iter.next(), Some(false));
        }
        assert_eq!(iter.len(), 12);
        for _ in 0..4 {
            assert_eq!(iter.next(), Some(true));
        }
        assert_eq!(iter.len(), 8);
        for _ in 0..4 {
            assert_eq!(iter.next(), Some(true));
            assert_eq!(iter.next(), Some(false));
        }
        assert_eq!(iter.len(), 0);
        assert_eq!(iter.next(), None);
        assert_eq!(iter.len(), 0);
    }

    // Byte-aligned byte reads return the input bytes unchanged.
    #[test]
    fn byte_by_byte() {
        let mut iter = BitIter::from([0x0f, 0xaa].iter().cloned());
        assert_eq!(iter.read_u8(), Ok(0x0f));
        assert_eq!(iter.read_u8(), Ok(0xaa));
        assert_eq!(iter.next(), None);
    }

    // 0x34 0x90 = 0011_0100 1001_0000; the last two-bit read straddles the
    // byte boundary.
    #[test]
    fn regression_1() {
        let mut iter = BitIter::from([0x34, 0x90].iter().cloned());
        assert_eq!(iter.read_u2(), Ok(u2::_0)); // first half of nibble 0011
        assert_eq!(iter.read_u2(), Ok(u2::_3)); // second half of nibble 0011
        assert_eq!(iter.next(), Some(false)); // 0
        assert_eq!(iter.read_u2(), Ok(u2::_2)); // 10
        assert_eq!(iter.read_u2(), Ok(u2::_1)); // 01
        assert_eq!(iter.n_total_read(), 9);
    }

    // data = 0001_0010 0010_0011 0011_0100, viewed through several windows.
    #[test]
    fn byte_slice_window() {
        let data = [0x12, 0x23, 0x34];

        // Whole slice.
        let mut full = BitIter::byte_slice_window(&data, 0, 24);
        assert_eq!(full.len(), 24);
        assert_eq!(full.read_u8(), Ok(0x12));
        assert_eq!(full.len(), 16);
        assert_eq!(full.n_total_read(), 8);
        assert_eq!(full.read_u8(), Ok(0x23));
        assert_eq!(full.n_total_read(), 16);
        assert_eq!(full.read_u8(), Ok(0x34));
        assert_eq!(full.n_total_read(), 24);
        assert_eq!(full.read_u8(), Err(EarlyEndOfStreamError));

        // Byte-aligned window covering only the middle byte.
        let mut mid = BitIter::byte_slice_window(&data, 8, 16);
        assert_eq!(mid.len(), 8);
        assert_eq!(mid.read_u8(), Ok(0x23));
        assert_eq!(mid.len(), 0);
        assert_eq!(mid.read_u8(), Err(EarlyEndOfStreamError));

        // Window starting in the middle of the first byte.
        let mut offs = BitIter::byte_slice_window(&data, 4, 20);
        assert_eq!(offs.read_u8(), Ok(0x22));
        assert_eq!(offs.read_u8(), Ok(0x33));
        assert_eq!(offs.read_u8(), Err(EarlyEndOfStreamError));

        // Window shifted by one bit; a failed byte read leaves the length
        // unchanged.
        let mut shift1 = BitIter::byte_slice_window(&data, 1, 24);
        assert_eq!(shift1.len(), 23);
        assert_eq!(shift1.read_u8(), Ok(0x24));
        assert_eq!(shift1.len(), 15);
        assert_eq!(shift1.read_u8(), Ok(0x46));
        assert_eq!(shift1.len(), 7);
        assert_eq!(shift1.read_u8(), Err(EarlyEndOfStreamError));
        assert_eq!(shift1.len(), 7);

        // Window shifted by seven bits.
        let mut shift7 = BitIter::byte_slice_window(&data, 7, 24);
        assert_eq!(shift7.len(), 17);
        assert_eq!(shift7.read_u8(), Ok(0x11));
        assert_eq!(shift7.len(), 9);
        assert_eq!(shift7.read_u8(), Ok(0x9a));
        assert_eq!(shift7.len(), 1);
        assert_eq!(shift7.read_u8(), Err(EarlyEndOfStreamError));
        assert_eq!(shift7.len(), 1);
    }
}