1use crate::{Cmr, FailEntropy};
13use std::{error, fmt};
14
/// Error returned when a bitstream is exhausted before a read completes.
#[non_exhaustive]
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct EarlyEndOfStreamError;

impl fmt::Display for EarlyEndOfStreamError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "bitstream ended early")
    }
}

impl error::Error for EarlyEndOfStreamError {}
27
28#[non_exhaustive]
30#[derive(Clone, Debug, PartialEq, Eq)]
31pub enum DecodeNaturalError {
32 BadIndex {
34 got: usize,
36 max: usize,
38 },
39 EndOfStream(EarlyEndOfStreamError),
41 Overflow,
43}
44
45impl fmt::Display for DecodeNaturalError {
46 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
47 match *self {
48 Self::BadIndex { got, max } => {
49 write!(
50 f,
51 "backreference {} exceeds current program length {}",
52 got, max
53 )
54 }
55 Self::EndOfStream(ref e) => e.fmt(f),
56 Self::Overflow => f.write_str("encoded number exceeded 31 bits"),
57 }
58 }
59}
60
61impl error::Error for DecodeNaturalError {
62 fn source(&self) -> Option<&(dyn error::Error + 'static)> {
63 match *self {
64 Self::BadIndex { .. } => None,
65 Self::EndOfStream(ref e) => Some(e),
66 Self::Overflow => None,
67 }
68 }
69}
70
/// Errors detected when closing out a bitstream after decoding finishes.
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub enum CloseError {
    /// Whole bytes remained in the stream after decoding finished.
    TrailingBytes {
        /// The first of the leftover bytes.
        first_byte: u8,
    },
    /// The final partially-read byte contained nonzero padding bits.
    IllegalPadding {
        /// The offending padding bits, masked out of the last byte.
        masked_padding: u8,
        /// The number of padding bits.
        n_bits: usize,
    },
}

impl fmt::Display for CloseError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match *self {
            Self::TrailingBytes { first_byte } => {
                write!(f, "bitstream had trailing bytes 0x{:02x}...", first_byte)
            }
            Self::IllegalPadding {
                masked_padding,
                n_bits,
            } => {
                write!(
                    f,
                    "bitstream had {n_bits} bits in its last byte 0x{:02x}, not all zero",
                    masked_padding
                )
            }
        }
    }
}

impl error::Error for CloseError {}
105
/// A two-bit unsigned integer; holds the values 0 through 3.
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
#[allow(non_camel_case_types)]
pub enum u2 {
    _0,
    _1,
    _2,
    _3,
}

impl From<u2> for u8 {
    fn from(s: u2) -> u8 {
        // Fieldless enum with default discriminants 0..=3, so the
        // numeric cast yields exactly the variant's value.
        s as u8
    }
}
129
/// Bitwise iterator adapted from a bytewise iterator.
///
/// Bits are yielded most-significant-bit first within each byte.
#[derive(Debug, Clone)]
pub struct BitIter<I: Iterator<Item = u8>> {
    /// Byte iterator that bits are drawn from.
    iter: I,
    /// Most recent byte pulled from `iter`; bits are served out of this.
    cached_byte: u8,
    /// How many bits of `cached_byte` have been consumed (8 = exhausted).
    read_bits: usize,
    /// Total number of bits yielded since construction.
    total_read: usize,
}
143
144impl From<Vec<u8>> for BitIter<std::vec::IntoIter<u8>> {
145 fn from(v: Vec<u8>) -> Self {
146 BitIter {
147 iter: v.into_iter(),
148 cached_byte: 0,
149 read_bits: 8,
152 total_read: 0,
153 }
154 }
155}
156
157impl<'a> From<&'a [u8]> for BitIter<std::iter::Copied<std::slice::Iter<'a, u8>>> {
158 fn from(sl: &'a [u8]) -> Self {
159 BitIter {
160 iter: sl.iter().copied(),
161 cached_byte: 0,
162 read_bits: 8,
165 total_read: 0,
166 }
167 }
168}
169
170impl<I: Iterator<Item = u8>> From<I> for BitIter<I> {
171 fn from(iter: I) -> Self {
172 BitIter {
173 iter,
174 cached_byte: 0,
175 read_bits: 8,
178 total_read: 0,
179 }
180 }
181}
182
183impl<I: Iterator<Item = u8>> Iterator for BitIter<I> {
184 type Item = bool;
185
186 fn next(&mut self) -> Option<bool> {
187 if self.read_bits < 8 {
188 self.read_bits += 1;
189 self.total_read += 1;
190 Some(self.cached_byte & (1 << (8 - self.read_bits as u8)) != 0)
191 } else {
192 self.cached_byte = self.iter.next()?;
193 self.read_bits = 0;
194 self.next()
195 }
196 }
197
198 fn size_hint(&self) -> (usize, Option<usize>) {
199 let (lo, hi) = self.iter.size_hint();
200 let adj = |n| 8 - self.read_bits + 8 * n;
201 (adj(lo), hi.map(adj))
202 }
203}
204
/// Once `next` returns `None` it keeps returning `None`, provided the
/// inner byte iterator is itself fused.
impl<I> core::iter::FusedIterator for BitIter<I> where
    I: Iterator<Item = u8> + core::iter::FusedIterator
{
}
209
/// The exact bit count follows from the inner iterator's exact byte
/// count via `size_hint` (8 bits per pending byte plus the cached bits).
impl<I> core::iter::ExactSizeIterator for BitIter<I> where
    I: Iterator<Item = u8> + core::iter::ExactSizeIterator
{
}
214
impl<'a> BitIter<std::iter::Copied<std::slice::Iter<'a, u8>>> {
    /// Creates a bit iterator over a sub-range of a byte slice.
    ///
    /// `start` and `end` are *bit* indices into `sl`, counting from the
    /// most significant bit of the first byte; the window covers bits
    /// `start..end`.
    ///
    /// # Panics
    ///
    /// Panics if `start > end` or if `end` exceeds the number of bits
    /// in `sl`.
    ///
    /// NOTE(review): the byte window is rounded out to whole bytes, so
    /// when `end` is not a multiple of 8 the iterator can yield the
    /// trailing bits of the final byte past `end` — confirm callers
    /// either use byte-aligned `end` or track the bit count themselves.
    pub fn byte_slice_window(sl: &'a [u8], start: usize, end: usize) -> Self {
        assert!(start <= end);
        assert!(end <= sl.len() * 8);

        // Restrict to the bytes that contain the requested bit range.
        let actual_sl = &sl[start / 8..end.div_ceil(8)];
        let mut iter = actual_sl.iter().copied();

        // Bits of the first byte lying before `start` count as already read.
        let read_bits = start % 8;
        if read_bits == 0 {
            BitIter {
                iter,
                cached_byte: 0,
                read_bits: 8,
                total_read: 0,
            }
        } else {
            BitIter {
                // Pre-load the first byte so the leading offset bits are skipped.
                cached_byte: iter.by_ref().next().unwrap(),
                iter,
                read_bits,
                total_read: 0,
            }
        }
    }
}
245
impl<I: Iterator<Item = u8>> BitIter<I> {
    /// Constructs a bit iterator from a byte iterator.
    ///
    /// Equivalent to using the `From` impl.
    pub fn new(iter: I) -> Self {
        Self::from(iter)
    }

    /// Reads a single bit, failing if the stream is exhausted.
    pub fn read_bit(&mut self) -> Result<bool, EarlyEndOfStreamError> {
        self.next().ok_or(EarlyEndOfStreamError)
    }

    /// Reads two bits (most significant first) as a [`u2`].
    ///
    /// Fails if the stream ends before both bits are available.
    pub fn read_u2(&mut self) -> Result<u2, EarlyEndOfStreamError> {
        match (self.next(), self.next()) {
            (Some(false), Some(false)) => Ok(u2::_0),
            (Some(false), Some(true)) => Ok(u2::_1),
            (Some(true), Some(false)) => Ok(u2::_2),
            (Some(true), Some(true)) => Ok(u2::_3),
            _ => Err(EarlyEndOfStreamError),
        }
    }

    /// Reads eight bits as a byte, honouring the current bit offset.
    pub fn read_u8(&mut self) -> Result<u8, EarlyEndOfStreamError> {
        // `next()` never leaves `read_bits` at 0, and a fresh iterator
        // starts at 8, so there is always at least one consumed bit.
        debug_assert!(self.read_bits > 0);
        let cached = self.cached_byte;
        self.cached_byte = self.iter.next().ok_or(EarlyEndOfStreamError)?;
        self.total_read += 8;

        // The unread low bits of the old byte become the high bits of
        // the result; `checked_shl` handles the byte-aligned case
        // (`read_bits == 8`), where the old byte contributes nothing.
        // The low `read_bits` bits come from the top of the new byte.
        Ok(cached.checked_shl(self.read_bits as u32).unwrap_or(0)
            + (self.cached_byte >> (8 - self.read_bits)))
    }

    /// Reads 256 bits as a [`Cmr`] (32-byte commitment Merkle root).
    pub fn read_cmr(&mut self) -> Result<Cmr, EarlyEndOfStreamError> {
        let mut ret = [0; 32];
        for byte in &mut ret {
            *byte = self.read_u8()?;
        }
        Ok(Cmr::from_byte_array(ret))
    }

    /// Reads 512 bits as a [`FailEntropy`] (64-byte entropy block).
    pub fn read_fail_entropy(&mut self) -> Result<FailEntropy, EarlyEndOfStreamError> {
        let mut ret = [0; 64];
        for byte in &mut ret {
            *byte = self.read_u8()?;
        }
        Ok(FailEntropy::from_byte_array(ret))
    }

    /// Decodes a natural number from the stream.
    ///
    /// The encoding is a unary prefix (a run of 1-bits terminated by a
    /// 0-bit, giving the nesting depth) followed by a chain of binary
    /// numbers, each read with an implicit leading 1-bit and each
    /// giving the bit-length of the next; the innermost number is the
    /// result. A bare 0-bit prefix therefore decodes to 1.
    ///
    /// # Errors
    ///
    /// * [`DecodeNaturalError::EndOfStream`] if the stream runs out mid-number.
    /// * [`DecodeNaturalError::Overflow`] if any intermediate length
    ///   exceeds 31 bits or the result does not fit in `N`.
    /// * [`DecodeNaturalError::BadIndex`] if `bound` is given and the
    ///   decoded value exceeds it.
    pub fn read_natural<N>(&mut self, bound: Option<N>) -> Result<N, DecodeNaturalError>
    where
        N: TryFrom<u32> + PartialOrd,
        u32: TryFrom<N>,
        usize: TryFrom<N>,
    {
        // Count the unary prefix: each 1-bit adds a level of nesting.
        let mut recurse_depth = 0;
        loop {
            match self.read_bit() {
                Ok(true) => recurse_depth += 1,
                Ok(false) => break,
                Err(e) => return Err(DecodeNaturalError::EndOfStream(e)),
            }
        }

        let mut len = 0;
        loop {
            // Read `len` bits as a binary number with an implicit
            // leading 1 (so `len == 0` yields n == 1).
            let mut n = 1u32;
            for _ in 0..len {
                let bit = u32::from(self.read_bit().map_err(DecodeNaturalError::EndOfStream)?);
                n = 2 * n + bit;
            }

            if recurse_depth == 0 {
                // Innermost level: `n` is the decoded value.
                let ret = N::try_from(n).map_err(|_| DecodeNaturalError::Overflow)?;

                if let Some(bound) = bound {
                    if ret > bound {
                        // Saturate to usize::MAX purely for error reporting.
                        let got = n.try_into().unwrap_or(usize::MAX);
                        let max = bound.try_into().unwrap_or(usize::MAX);
                        return Err(DecodeNaturalError::BadIndex { got, max });
                    }
                }

                return Ok(ret);
            } else {
                // Outer level: `n` is the bit-length of the next number.
                len = n.try_into().map_err(|_| DecodeNaturalError::Overflow)?;
                if len > 31 {
                    return Err(DecodeNaturalError::Overflow);
                }
                recurse_depth -= 1;
            }
        }
    }

    /// Returns the total number of bits read from the stream so far.
    pub fn n_total_read(&self) -> usize {
        self.total_read
    }

    /// Consumes the iterator, checking that the stream is fully spent.
    ///
    /// Fails if any whole bytes remain, or if the unread bits of the
    /// current byte are not all zero padding.
    pub fn close(mut self) -> Result<(), CloseError> {
        if let Some(first_byte) = self.iter.next() {
            return Err(CloseError::TrailingBytes { first_byte });
        }

        // `read_bits` is always in 1..=8 outside of `next()` itself.
        debug_assert!(self.read_bits >= 1);
        debug_assert!(self.read_bits <= 8);
        let n_bits = 8 - self.read_bits;
        // Mask off the bits of the cached byte that were never read.
        let masked_padding = self.cached_byte & ((1u8 << n_bits) - 1);
        if masked_padding != 0 {
            Err(CloseError::IllegalPadding {
                masked_padding,
                n_bits,
            })
        } else {
            Ok(())
        }
    }
}
385
/// Trait for collecting an iterator of bits into packed bytes.
pub trait BitCollector: Sized {
    /// Collects the bits into a byte vector (most significant bit
    /// first), returning the bytes and the number of bits collected.
    fn collect_bits(self) -> (Vec<u8>, usize);

    /// Collects the bits into a byte vector, failing unless the number
    /// of bits is a multiple of 8.
    fn try_collect_bytes(self) -> Result<Vec<u8>, &'static str> {
        let (bytes, bit_length) = self.collect_bits();
        if bit_length % 8 == 0 {
            Ok(bytes)
        } else {
            Err("Number of collected bits is not divisible by 8")
        }
    }
}
403
404impl<I: Iterator<Item = bool>> BitCollector for I {
405 fn collect_bits(self) -> (Vec<u8>, usize) {
406 let mut bytes = vec![];
407 let mut unfinished_byte = Vec::with_capacity(8);
408
409 for bit in self {
410 unfinished_byte.push(bit);
411
412 if unfinished_byte.len() == 8 {
413 bytes.push(
414 unfinished_byte
415 .iter()
416 .fold(0, |acc, &b| acc * 2 + u8::from(b)),
417 );
418 unfinished_byte.clear();
419 }
420 }
421
422 let bit_length = bytes.len() * 8 + unfinished_byte.len();
423
424 if !unfinished_byte.is_empty() {
425 unfinished_byte.resize(8, false);
426 bytes.push(
427 unfinished_byte
428 .iter()
429 .fold(0, |acc, &b| acc * 2 + u8::from(b)),
430 );
431 }
432
433 (bytes, bit_length)
434 }
435}
436
#[cfg(test)]
mod tests {
    use super::*;

    // Every read on an empty stream must fail without panicking, and
    // the read counter must stay at zero.
    #[test]
    fn empty_iter() {
        let mut iter = BitIter::from([].iter().cloned());
        assert_eq!(iter.len(), 0);
        assert!(iter.next().is_none());
        assert!(iter.next().is_none());
        assert_eq!(iter.read_bit(), Err(EarlyEndOfStreamError));
        assert_eq!(iter.read_u2(), Err(EarlyEndOfStreamError));
        assert_eq!(iter.read_u8(), Err(EarlyEndOfStreamError));
        assert_eq!(iter.read_cmr(), Err(EarlyEndOfStreamError));
        assert_eq!(iter.n_total_read(), 0);
    }

    // Single byte 0x80: first bit set, rest clear; `len` and
    // `n_total_read` track individual bit reads, and `read_u8` fails
    // once fewer than 8 bits remain in the underlying stream.
    #[test]
    fn one_bit_iter() {
        let mut iter = BitIter::from([0x80].iter().cloned());
        assert_eq!(iter.len(), 8);
        assert_eq!(iter.read_bit(), Ok(true));
        assert_eq!(iter.len(), 7);
        assert_eq!(iter.read_bit(), Ok(false));
        assert_eq!(iter.len(), 6);
        assert_eq!(iter.read_u8(), Err(EarlyEndOfStreamError));
        assert_eq!(iter.n_total_read(), 2);
    }

    // Walks 0x0f 0xaa bit by bit, checking MSB-first ordering and that
    // `len` decrements correctly across the byte boundary.
    #[test]
    fn bit_by_bit() {
        let mut iter = BitIter::from([0x0f, 0xaa].iter().cloned());
        assert_eq!(iter.len(), 16);
        for _ in 0..4 {
            assert_eq!(iter.next(), Some(false));
        }
        assert_eq!(iter.len(), 12);
        for _ in 0..4 {
            assert_eq!(iter.next(), Some(true));
        }
        assert_eq!(iter.len(), 8);
        for _ in 0..4 {
            assert_eq!(iter.next(), Some(true));
            assert_eq!(iter.next(), Some(false));
        }
        assert_eq!(iter.len(), 0);
        assert_eq!(iter.next(), None);
        assert_eq!(iter.len(), 0);
    }

    // Byte-aligned `read_u8` returns the raw bytes unchanged.
    #[test]
    fn byte_by_byte() {
        let mut iter = BitIter::from([0x0f, 0xaa].iter().cloned());
        assert_eq!(iter.read_u8(), Ok(0x0f));
        assert_eq!(iter.read_u8(), Ok(0xaa));
        assert_eq!(iter.next(), None);
    }

    // Mixed 2-bit and 1-bit reads across a byte boundary; verifies the
    // bit-offset bookkeeping in `read_u2`.
    #[test]
    fn regression_1() {
        let mut iter = BitIter::from([0x34, 0x90].iter().cloned());
        assert_eq!(iter.read_u2(), Ok(u2::_0)); assert_eq!(iter.read_u2(), Ok(u2::_3)); assert_eq!(iter.next(), Some(false)); assert_eq!(iter.read_u2(), Ok(u2::_2)); assert_eq!(iter.read_u2(), Ok(u2::_1)); assert_eq!(iter.n_total_read(), 9);
    }

    // Windows at various (bit-level) offsets: full-slice, byte-aligned
    // sub-range, and unaligned starts of 4, 1 and 7 bits, checking both
    // the shifted byte values and the remaining-length accounting.
    #[test]
    fn byte_slice_window() {
        let data = [0x12, 0x23, 0x34];

        let mut full = BitIter::byte_slice_window(&data, 0, 24);
        assert_eq!(full.len(), 24);
        assert_eq!(full.read_u8(), Ok(0x12));
        assert_eq!(full.len(), 16);
        assert_eq!(full.n_total_read(), 8);
        assert_eq!(full.read_u8(), Ok(0x23));
        assert_eq!(full.n_total_read(), 16);
        assert_eq!(full.read_u8(), Ok(0x34));
        assert_eq!(full.n_total_read(), 24);
        assert_eq!(full.read_u8(), Err(EarlyEndOfStreamError));

        let mut mid = BitIter::byte_slice_window(&data, 8, 16);
        assert_eq!(mid.len(), 8);
        assert_eq!(mid.read_u8(), Ok(0x23));
        assert_eq!(mid.len(), 0);
        assert_eq!(mid.read_u8(), Err(EarlyEndOfStreamError));

        let mut offs = BitIter::byte_slice_window(&data, 4, 20);
        assert_eq!(offs.read_u8(), Ok(0x22));
        assert_eq!(offs.read_u8(), Ok(0x33));
        assert_eq!(offs.read_u8(), Err(EarlyEndOfStreamError));

        let mut shift1 = BitIter::byte_slice_window(&data, 1, 24);
        assert_eq!(shift1.len(), 23);
        assert_eq!(shift1.read_u8(), Ok(0x24));
        assert_eq!(shift1.len(), 15);
        assert_eq!(shift1.read_u8(), Ok(0x46));
        assert_eq!(shift1.len(), 7);
        assert_eq!(shift1.read_u8(), Err(EarlyEndOfStreamError));
        assert_eq!(shift1.len(), 7);

        let mut shift7 = BitIter::byte_slice_window(&data, 7, 24);
        assert_eq!(shift7.len(), 17);
        assert_eq!(shift7.read_u8(), Ok(0x11));
        assert_eq!(shift7.len(), 9);
        assert_eq!(shift7.read_u8(), Ok(0x9a));
        assert_eq!(shift7.len(), 1);
        assert_eq!(shift7.read_u8(), Err(EarlyEndOfStreamError));
        assert_eq!(shift7.len(), 1);
    }
}