1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
// SPDX-License-Identifier: CC0-1.0

use std::io;

/// Bitwise writer formed by wrapping a bytewise [`io::Write`].
///
/// Within each byte, bits are written in big-endian order (most significant
/// bit first). A trailing partial byte is padded with zero bits when
/// [`BitWriter::flush_all()`] is called.
#[derive(Debug)]
pub struct BitWriter<W: io::Write> {
    /// Underlying byte writer
    w: W,
    /// Byte being assembled from the current bits, not yet written out
    cache: u8,
    /// Number of valid bits in `cache`, counted from its most significant bit
    cache_len: usize,
    /// Total number of bits written by the caller (padding not included)
    total_written: usize,
}

impl<W: io::Write> From<W> for BitWriter<W> {
    fn from(w: W) -> Self {
        BitWriter {
            w,
            cache: 0,
            cache_len: 0,
            total_written: 0,
        }
    }
}

impl<W: io::Write> io::Write for BitWriter<W> {
    /// Write every byte of `buf`, each one most significant bit first.
    ///
    /// On success the whole buffer is consumed. If the underlying writer
    /// fails part-way, the bits of `buf` accepted before the failure remain
    /// in this writer.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        for &byte in buf {
            for shift in (0..8u32).rev() {
                self.write_bit((byte >> shift) & 1 == 1)?;
            }
        }
        Ok(buf.len())
    }

    /// Flushes only the underlying writer; bits still held in the internal
    /// cache are **not** written out.
    /// To pad and emit them, call [`BitWriter::flush_all()`].
    fn flush(&mut self) -> io::Result<()> {
        self.w.flush()
    }
}

impl<W: io::Write> BitWriter<W> {
    /// Create a bitwise writer from a bytewise one.
    /// Equivalent to using [`From`].
    pub fn new(w: W) -> BitWriter<W> {
        BitWriter::from(w)
    }

    /// Write a single bit.
    ///
    /// Bits fill the current byte starting at its most significant bit. A
    /// byte is passed to the underlying writer as soon as its eighth bit
    /// arrives, so complete bytes never linger in the cache and a plain
    /// [`io::Write::flush`] delivers everything up to the last byte boundary.
    ///
    /// # Errors
    ///
    /// Returns the underlying writer's error if a completed byte cannot be
    /// written out. The bit is then not recorded: the cache and
    /// [`BitWriter::n_total_written`] are left unchanged, so the call can be
    /// retried.
    pub fn write_bit(&mut self, b: bool) -> io::Result<()> {
        debug_assert!(self.cache_len < 8, "a full byte must never stay cached");

        let mut cache = self.cache;
        if b {
            cache |= 1 << (7 - self.cache_len);
        }

        if self.cache_len == 7 {
            // This bit completes the byte. Commit the new state only after the
            // byte has actually been written, so a failed write is retryable.
            self.w.write_all(&[cache])?;
            self.cache = 0;
            self.cache_len = 0;
        } else {
            self.cache = cache;
            self.cache_len += 1;
        }
        self.total_written += 1;
        Ok(())
    }

    /// Write out all cached bits and flush the underlying [`io::Write`].
    ///
    /// A pending partial byte is padded with zero bits and written, so this
    /// writes at most one byte. Padding bits are not counted by
    /// [`BitWriter::n_total_written`].
    pub fn flush_all(&mut self) -> io::Result<()> {
        if self.cache_len > 0 {
            self.w.write_all(&[self.cache])?;
            self.cache_len = 0;
            self.cache = 0;
        }

        io::Write::flush(&mut self.w)
    }

    /// Return total number of written bits (excluding padding).
    pub fn n_total_written(&self) -> usize {
        self.total_written
    }

    /// Write the `len` least significant bits of `n` in big-endian order
    /// (most significant of those bits first).
    ///
    /// Intended for up to 64 bits. For `len > 64`, `n` is treated as
    /// zero-extended: the excess leading bits are written as zeroes instead
    /// of overflowing the shift (which would panic in debug builds and select
    /// wrong bits in release builds).
    ///
    /// Returns the number of written bits, i.e. `len`.
    ///
    /// # Errors
    ///
    /// Propagates the first error from the underlying writer; bits written
    /// before the failing one stay written.
    pub fn write_bits_be(&mut self, n: u64, len: usize) -> io::Result<usize> {
        for shift in (0..len).rev() {
            // Bits at position 64 and above lie beyond `n`'s width: zero.
            let bit = shift < 64 && (n >> shift) & 1 == 1;
            self.write_bit(bit)?;
        }
        Ok(len)
    }
}

/// Run a bit-writing operation against a fresh byte vector and return the
/// bytes produced, with any trailing partial byte zero-padded.
///
/// The `usize` returned by `f` is ignored.
///
/// # Panics
///
/// Writing to a `Vec` cannot fail, so this panics only if `f` itself
/// returns an error.
pub fn write_to_vec<F>(f: F) -> Vec<u8>
where
    F: FnOnce(&mut BitWriter<&mut Vec<u8>>) -> io::Result<usize>,
{
    const INFALLIBLE: &str = "I/O to vector never fails";

    let mut out = Vec::new();
    {
        let mut writer = BitWriter::new(&mut out);
        f(&mut writer).expect(INFALLIBLE);
        writer.flush_all().expect(INFALLIBLE);
    }
    out
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::jet::Core;
    use crate::node::CoreConstructible;
    use crate::ConstructNode;
    use std::io::Write;
    use std::sync::Arc;

    #[test]
    fn vec() {
        let program = Arc::<ConstructNode<Core>>::unit();
        let _ = write_to_vec(|w| program.encode(w));
    }

    #[test]
    fn empty_vec() {
        let vec = write_to_vec(|_| Ok(0));
        assert!(vec.is_empty());
    }

    /// A trailing partial byte is padded with zero bits.
    #[test]
    fn partial_byte_is_zero_padded() {
        let vec = write_to_vec(|w| w.write_bits_be(0b101, 3));
        assert_eq!(vec, [0b1010_0000]);
    }

    /// Multi-byte values come out most significant byte/bit first.
    #[test]
    fn bits_are_big_endian() {
        let vec = write_to_vec(|w| w.write_bits_be(0x0123_4567_89AB_CDEF, 64));
        assert_eq!(vec, [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF]);
    }

    /// Bytes written through `io::Write` are emitted unchanged when aligned,
    /// and shifted by the pending bits when not.
    #[test]
    fn io_write_bytes() {
        let aligned = write_to_vec(|w| {
            w.write_all(&[0xDE, 0xAD, 0xBE, 0xEF])?;
            Ok(w.n_total_written())
        });
        assert_eq!(aligned, [0xDE, 0xAD, 0xBE, 0xEF]);

        let unaligned = write_to_vec(|w| {
            w.write_bit(true)?;
            w.write_all(&[0xFF])?;
            Ok(w.n_total_written())
        });
        assert_eq!(unaligned, [0xFF, 0x80]);
    }

    /// `n_total_written` counts the caller's bits, not the padding.
    #[test]
    fn counts_bits_not_padding() {
        let mut bytes = Vec::new();
        let mut w = BitWriter::new(&mut bytes);
        w.write_bits_be(0b11, 2).unwrap();
        w.write_bit(false).unwrap();
        w.flush_all().unwrap();
        assert_eq!(w.n_total_written(), 3);
        drop(w);
        assert_eq!(bytes, [0b1100_0000]);
    }
}