// simplicity/bit_encoding/bitwriter.rs

// SPDX-License-Identifier: CC0-1.0
2
3use std::io;
4
/// Bitwise writer formed by wrapping a bytewise [`io::Write`].
///
/// Bits are written in big-endian order: the first bit written becomes the
/// most significant bit of the first output byte.
/// When cached bits are written out via [`BitWriter::flush_all()`], a partially
/// filled byte is padded with zeroes in its least significant bits.
pub struct BitWriter<W: io::Write> {
    /// Underlying byte writer.
    w: W,
    /// Byte under construction; holds the bits not yet handed to `w`.
    cache: u8,
    /// Number of bits currently held in `cache` (between 0 and 8 inclusive).
    cache_len: usize,
    /// Total number of bits accepted by [`BitWriter::write_bit()`] so far.
    total_written: usize,
}

impl<W: io::Write> From<W> for BitWriter<W> {
    /// Wrap a bytewise writer. The bit cache starts out empty.
    fn from(w: W) -> Self {
        BitWriter {
            w,
            cache: 0,
            cache_len: 0,
            total_written: 0,
        }
    }
}

impl<W: io::Write> io::Write for BitWriter<W> {
    /// Write every byte of `buf` bit by bit, most significant bit first.
    ///
    /// On success the whole buffer has been consumed and `buf.len()` is returned.
    ///
    /// # Errors
    ///
    /// Returns the first error reported by the underlying writer.
    /// Bits of `buf` that were accepted before the error remain accounted for.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        for &byte in buf {
            for shift in (0..8).rev() {
                self.write_bit((byte >> shift) & 1 == 1)?;
            }
        }
        Ok(buf.len())
    }

    /// Flush the underlying writer only.
    ///
    /// Does **not** write out cached bits. The cache holds up to eight bits,
    /// so even a completed byte may still be pending after this call.
    /// To write out the cache you must call [`BitWriter::flush_all()`].
    fn flush(&mut self) -> io::Result<()> {
        self.w.flush()
    }
}

impl<W: io::Write> BitWriter<W> {
    /// Create a bitwise writer from a bytewise one.
    /// Equivalent to using [`From`].
    pub fn new(w: W) -> BitWriter<W> {
        BitWriter::from(w)
    }

    /// Hand the cached byte to the underlying writer and empty the cache.
    ///
    /// If the write fails, the cache is left untouched so that no bits are lost.
    fn write_cache(&mut self) -> io::Result<()> {
        self.w.write_all(&[self.cache])?;
        self.cache = 0;
        self.cache_len = 0;
        Ok(())
    }

    /// Write a single bit.
    ///
    /// The bit is stored in the cache. A full cache is written to the
    /// underlying writer lazily, right before the next bit is stored.
    ///
    /// # Errors
    ///
    /// Returns the error of the underlying writer if writing out a full cache
    /// fails. In that case the bit is not stored and not counted.
    pub fn write_bit(&mut self, b: bool) -> io::Result<()> {
        if self.cache_len >= 8 {
            self.write_cache()?;
        }

        self.cache_len += 1;
        self.total_written += 1;
        if b {
            // The first bit of a byte lands in the most significant position.
            self.cache |= 1 << (8 - self.cache_len);
        }
        Ok(())
    }

    /// Write out all cached bits.
    ///
    /// This writes at most one byte (a partially filled byte is padded with
    /// zeroes) and then flushes the underlying [`io::Write`].
    ///
    /// # Errors
    ///
    /// Returns the error of the underlying writer if the write or the flush fails.
    pub fn flush_all(&mut self) -> io::Result<()> {
        if self.cache_len > 0 {
            self.write_cache()?;
        }

        self.w.flush()
    }

    /// Return total number of written bits.
    ///
    /// Bits that are still cached are included in the count.
    pub fn n_total_written(&self) -> usize {
        self.total_written
    }

    /// Write up to 64 bits in big-endian order.
    /// The first `len` many _least significant_ bits from `n` are written,
    /// starting with the most significant of them.
    ///
    /// Returns the number of written bits.
    ///
    /// # Errors
    ///
    /// Returns [`io::ErrorKind::InvalidInput`] without writing anything if
    /// `len` exceeds 64, and propagates errors of the underlying writer.
    pub fn write_bits_be(&mut self, n: u64, len: usize) -> io::Result<usize> {
        // A `u64` has only 64 bits; shifting by more would overflow.
        if len > 64 {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "cannot write more than 64 bits of a u64",
            ));
        }

        for shift in (0..len).rev() {
            self.write_bit((n >> shift) & 1 == 1)?;
        }
        Ok(len)
    }
}
100
101/// Write the result of a bit operation into a byte vector and return the vector.
102///
103/// I/O to a vector never fails.
104pub fn write_to_vec<F>(f: F) -> Vec<u8>
105where
106    F: FnOnce(&mut BitWriter<&mut Vec<u8>>) -> io::Result<usize>,
107{
108    let mut bytes = Vec::new();
109    let mut bits = BitWriter::new(&mut bytes);
110    f(&mut bits).expect("I/O to vector never fails");
111    bits.flush_all().expect("I/O to vector never fails");
112    bytes
113}
114
#[cfg(test)]
mod tests {
    use super::*;
    use crate::jet::Core;
    use crate::node::CoreConstructible;
    use crate::types;
    use crate::ConstructNode;
    use std::sync::Arc;

    /// Encoding a unit program into a vector must complete without panicking.
    #[test]
    fn vec() {
        types::Context::with_context(|ctx| {
            let unit_program = Arc::<ConstructNode<Core>>::unit(&ctx);
            let _ = write_to_vec(|writer| unit_program.encode_without_witness(writer));
        })
    }

    /// A closure that writes no bits must yield an empty byte vector.
    #[test]
    fn empty_vec() {
        let bytes = write_to_vec(|_| Ok(0));
        assert!(bytes.is_empty());
    }
}