simplicity/bit_encoding/
encode.rs

1// SPDX-License-Identifier: CC0-1.0
2
3//! # Encoding
4//!
5//! Functionality to encode Simplicity programs.
6//! These programs are encoded bitwise rather than bytewise,
7//! so given a hex dump of a program it is not generally possible
8//! to read it visually the way you can with Bitcoin Script.
9
10use crate::dag::{Dag, DagLike, PostOrderIterItem, SharingTracker};
11use crate::jet::Jet;
12use crate::node::{self, Disconnectable};
13use crate::{BitWriter, Cmr, Value};
14
15use std::collections::{hash_map::Entry, HashMap};
16use std::sync::Arc;
17use std::{hash, io, mem};
18
19#[derive(Copy, Clone)]
20enum EncodeNode<'n, N: node::Marker> {
21    Node(&'n node::Node<N>),
22    Hidden(Cmr),
23}
24
25impl<'n, N: node::Marker> Disconnectable<EncodeNode<'n, N>> for EncodeNode<'n, N> {
26    fn disconnect_dag_arc(self, other: Arc<EncodeNode<'n, N>>) -> Dag<Arc<EncodeNode<'n, N>>> {
27        Dag::Binary(other, Arc::new(self))
28    }
29
30    fn disconnect_dag_ref<'s>(
31        &'s self,
32        other: &'s EncodeNode<'n, N>,
33    ) -> Dag<&'s EncodeNode<'n, N>> {
34        Dag::Binary(other, self)
35    }
36}
37
38impl<N: node::Marker> DagLike for EncodeNode<'_, N> {
39    type Node = Self;
40    fn data(&self) -> &Self {
41        self
42    }
43
44    fn as_dag_node(&self) -> Dag<Self> {
45        let node = match *self {
46            EncodeNode::Node(node) => node,
47            EncodeNode::Hidden(..) => return Dag::Nullary,
48        };
49        match node.inner() {
50            node::Inner::Unit
51            | node::Inner::Iden
52            | node::Inner::Fail(..)
53            | node::Inner::Jet(..)
54            | node::Inner::Word(..) => Dag::Nullary,
55            node::Inner::InjL(sub)
56            | node::Inner::InjR(sub)
57            | node::Inner::Take(sub)
58            | node::Inner::Drop(sub) => Dag::Unary(EncodeNode::Node(sub)),
59            node::Inner::Comp(left, right)
60            | node::Inner::Case(left, right)
61            | node::Inner::Pair(left, right) => {
62                Dag::Binary(EncodeNode::Node(left), EncodeNode::Node(right))
63            }
64            node::Inner::Disconnect(left, right) => {
65                right.disconnect_dag_ref(left).map(EncodeNode::Node)
66            }
67            node::Inner::AssertL(left, rcmr) => {
68                Dag::Binary(EncodeNode::Node(left), EncodeNode::Hidden(*rcmr))
69            }
70            node::Inner::AssertR(lcmr, right) => {
71                Dag::Binary(EncodeNode::Hidden(*lcmr), EncodeNode::Node(right))
72            }
73            node::Inner::Witness(..) => Dag::Nullary,
74        }
75    }
76}
77
78#[derive(Clone)]
79enum EncodeId<N: node::Marker> {
80    Node(N::SharingId),
81    Hidden(Cmr),
82}
83
84// Have to implement these manually because Rust sucks.
85impl<N: node::Marker> PartialEq for EncodeId<N> {
86    fn eq(&self, other: &Self) -> bool {
87        match (self, other) {
88            (EncodeId::Node(left), EncodeId::Node(right)) => left == right,
89            (EncodeId::Hidden(left), EncodeId::Hidden(right)) => left == right,
90            _ => false,
91        }
92    }
93}
94
95impl<N: node::Marker> Eq for EncodeId<N> {}
96
97impl<N: node::Marker> hash::Hash for EncodeId<N> {
98    fn hash<H: hash::Hasher>(&self, hasher: &mut H) {
99        match self {
100            EncodeId::Node(id) => {
101                hash::Hash::hash(&false, hasher);
102                hash::Hash::hash(id, hasher);
103            }
104            EncodeId::Hidden(cmr) => {
105                hash::Hash::hash(&true, hasher);
106                hash::Hash::hash(cmr, hasher);
107            }
108        }
109    }
110}
111
112/// Shares nodes based on IHR, *except* for Hidden nodes, which are identified
113/// solely by the hash they contain
114#[derive(Clone)]
115pub struct EncodeSharing<N: node::Marker> {
116    map: HashMap<EncodeId<N>, usize>,
117}
118
119// Annoyingly we have to implement Default by hand
120impl<N: node::Marker> Default for EncodeSharing<N> {
121    fn default() -> Self {
122        EncodeSharing {
123            map: HashMap::default(),
124        }
125    }
126}
127
128impl<N: node::Marker> SharingTracker<EncodeNode<'_, N>> for EncodeSharing<N> {
129    fn record(&mut self, d: &EncodeNode<N>, index: usize) -> Option<usize> {
130        let id = match d {
131            EncodeNode::Node(n) => EncodeId::Node(n.sharing_id()?),
132            EncodeNode::Hidden(cmr) => EncodeId::Hidden(*cmr),
133        };
134
135        match self.map.entry(id) {
136            Entry::Occupied(occ) => Some(*occ.get()),
137            Entry::Vacant(vac) => {
138                vac.insert(index);
139                None
140            }
141        }
142    }
143
144    fn seen_before(&self, d: &EncodeNode<N>) -> Option<usize> {
145        let id = match d {
146            EncodeNode::Node(n) => EncodeId::Node(n.sharing_id()?),
147            EncodeNode::Hidden(cmr) => EncodeId::Hidden(*cmr),
148        };
149
150        self.map.get(&id).copied()
151    }
152}
153
154/// Encode a Simplicity program to bits, without witness data.
155///
156/// Returns the number of written bits.
157pub fn encode_program<W: io::Write, N: node::Marker>(
158    program: &node::Node<N>,
159    w: &mut BitWriter<W>,
160) -> io::Result<usize> {
161    let iter = EncodeNode::Node(program).post_order_iter::<EncodeSharing<N>>();
162
163    let len = iter.clone().count();
164    let n_start = w.n_total_written();
165    encode_natural(len, w)?;
166
167    for node in iter {
168        encode_node(node, w)?;
169    }
170
171    Ok(w.n_total_written() - n_start)
172}
173
174/// Encode a node to bits.
175fn encode_node<W: io::Write, N: node::Marker>(
176    data: PostOrderIterItem<EncodeNode<N>>,
177    w: &mut BitWriter<W>,
178) -> io::Result<()> {
179    // Handle Hidden nodes specially
180    let node = match data.node {
181        EncodeNode::Node(node) => node,
182        EncodeNode::Hidden(cmr) => {
183            w.write_bits_be(0b0110, 4)?;
184            encode_hash(cmr.as_ref(), w)?;
185            return Ok(());
186        }
187    };
188
189    if let Some(i_abs) = data.left_index {
190        debug_assert!(i_abs < data.index);
191        let i = data.index - i_abs;
192
193        if let Some(j_abs) = data.right_index {
194            debug_assert!(j_abs < data.index);
195            let j = data.index - j_abs;
196
197            match node.inner() {
198                node::Inner::Comp(_, _) => {
199                    w.write_bits_be(0x00000, 5)?;
200                }
201                node::Inner::Case(_, _)
202                | node::Inner::AssertL(_, _)
203                | node::Inner::AssertR(_, _) => {
204                    w.write_bits_be(0b00001, 5)?;
205                }
206                node::Inner::Pair(_, _) => {
207                    w.write_bits_be(0b00010, 5)?;
208                }
209                node::Inner::Disconnect(_, _) => {
210                    w.write_bits_be(0b00011, 5)?;
211                }
212                _ => unreachable!(),
213            }
214
215            encode_natural(i, w)?;
216            encode_natural(j, w)?;
217        } else {
218            match node.inner() {
219                node::Inner::InjL(_) => {
220                    w.write_bits_be(0b00100, 5)?;
221                }
222                node::Inner::InjR(_) => {
223                    w.write_bits_be(0b00101, 5)?;
224                }
225                node::Inner::Take(_) => {
226                    w.write_bits_be(0b00110, 5)?;
227                }
228                node::Inner::Drop(_) => {
229                    w.write_bits_be(0b00111, 5)?;
230                }
231                node::Inner::Disconnect(_, _) => {
232                    w.write_bits_be(0b01011, 5)?;
233                }
234                _ => unreachable!(),
235            };
236
237            encode_natural(i, w)?;
238        }
239    } else {
240        match node.inner() {
241            node::Inner::Iden => {
242                w.write_bits_be(0b01000, 5)?;
243            }
244            node::Inner::Unit => {
245                w.write_bits_be(0b01001, 5)?;
246            }
247            node::Inner::Fail(entropy) => {
248                w.write_bits_be(0b01010, 5)?;
249                encode_hash(entropy.as_ref(), w)?;
250            }
251            node::Inner::Witness(_) => {
252                w.write_bits_be(0b0111, 4)?;
253            }
254            node::Inner::Jet(jet) => {
255                w.write_bit(true)?; // jet or word
256                w.write_bit(true)?; // jet
257                jet.encode(w)?;
258            }
259            node::Inner::Word(word) => {
260                w.write_bit(true)?; // jet or word
261                w.write_bit(false)?; // word
262                encode_natural(1 + word.n() as usize, w)?;
263                encode_value(word.as_value(), w)?;
264            }
265            _ => unreachable!(),
266        }
267    }
268
269    Ok(())
270}
271
272/// Encode witness data to bits.
273///
274/// Returns the number of written bits.
275pub fn encode_witness<'a, W: io::Write, I>(witness: I, w: &mut BitWriter<W>) -> io::Result<usize>
276where
277    I: Iterator<Item = &'a Value> + Clone,
278{
279    let mut len = 0;
280    for value in witness {
281        len += encode_value(value, w)?;
282    }
283    Ok(len)
284}
285
286/// Encode a value to bits.
287pub fn encode_value<W: io::Write>(value: &Value, w: &mut BitWriter<W>) -> io::Result<usize> {
288    let n_start = w.n_total_written();
289    for bit in value.iter_compact() {
290        w.write_bit(bit)?;
291    }
292    Ok(w.n_total_written() - n_start)
293}
294
295/// Encode a hash to bits.
296pub fn encode_hash<W: io::Write>(h: &[u8], w: &mut BitWriter<W>) -> io::Result<usize> {
297    for byte in h {
298        w.write_bits_be(u64::from(*byte), 8)?;
299    }
300
301    Ok(h.len() * 8)
302}
303
304/// Encode a positive integer to bits.
305pub fn encode_natural<W: io::Write>(mut n: usize, w: &mut BitWriter<W>) -> io::Result<usize> {
306    assert!(n > 0, "Zero cannot be encoded");
307    let n_start = w.n_total_written();
308
309    /// Minimum number of bits to represent `n` minus the most-significant bit
310    fn truncated_bit_len(n: usize) -> usize {
311        8 * mem::size_of::<usize>() - n.leading_zeros() as usize - 1
312    }
313
314    let mut suffix = Vec::new();
315
316    loop {
317        debug_assert!(n > 0);
318        let len = truncated_bit_len(n);
319        if len == 0 {
320            w.write_bit(false)?;
321            break;
322        } else {
323            w.write_bit(true)?;
324            suffix.push((n, len));
325            n = len;
326        }
327    }
328
329    while let Some((bits, len)) = suffix.pop() {
330        let bits = bits as u64; // Case safety: assuming 64-bit machine or lower
331        w.write_bits_be(bits, len)?;
332    }
333
334    Ok(w.n_total_written() - n_start)
335}
336
#[cfg(test)]
mod test {
    use super::*;

    use crate::BitIter;

    /// Every natural number in `1..1000` must survive an encode/decode round trip.
    #[test]
    fn encode_decode_natural() {
        for original in 1..1000 {
            let mut buffer = Vec::<u8>::new();
            let mut writer = BitWriter::from(&mut buffer);
            encode_natural(original, &mut writer).expect("encoding to vector");
            writer.flush_all().expect("flushing");

            let decoded = BitIter::from(buffer.into_iter())
                .read_natural(None)
                .expect("decoding from vector");
            assert_eq!(original, decoded);
        }
    }
}