Skip to main content

simplicity/bit_encoding/
encode.rs

1// SPDX-License-Identifier: CC0-1.0
2
3//! # Encoding
4//!
5//! Functionality to encode Simplicity programs.
6//! These programs are encoded bitwise rather than bytewise,
7//! so given a hex dump of a program it is not generally possible
8//! to read it visually the way you can with Bitcoin Script.
9
10use crate::dag::{Dag, DagLike, PostOrderIterItem, SharingTracker};
11use crate::node::{self, Disconnectable};
12use crate::{BitWriter, Cmr, Value};
13
14use std::collections::{hash_map::Entry, HashMap};
15use std::sync::Arc;
16use std::{hash, io, mem};
17
18#[derive(Copy, Clone)]
19enum EncodeNode<'n, N: node::Marker> {
20    Node(&'n node::Node<N>),
21    Hidden(Cmr),
22}
23
24impl<'n, N: node::Marker> Disconnectable<EncodeNode<'n, N>> for EncodeNode<'n, N> {
25    fn disconnect_dag_arc(self, other: Arc<EncodeNode<'n, N>>) -> Dag<Arc<EncodeNode<'n, N>>> {
26        Dag::Binary(other, Arc::new(self))
27    }
28
29    fn disconnect_dag_ref<'s>(
30        &'s self,
31        other: &'s EncodeNode<'n, N>,
32    ) -> Dag<&'s EncodeNode<'n, N>> {
33        Dag::Binary(other, self)
34    }
35}
36
37impl<N: node::Marker> DagLike for EncodeNode<'_, N> {
38    type Node = Self;
39    fn data(&self) -> &Self {
40        self
41    }
42
43    fn as_dag_node(&self) -> Dag<Self> {
44        let node = match *self {
45            EncodeNode::Node(node) => node,
46            EncodeNode::Hidden(..) => return Dag::Nullary,
47        };
48        match node.inner() {
49            node::Inner::Unit
50            | node::Inner::Iden
51            | node::Inner::Fail(..)
52            | node::Inner::Jet(..)
53            | node::Inner::Word(..) => Dag::Nullary,
54            node::Inner::InjL(sub)
55            | node::Inner::InjR(sub)
56            | node::Inner::Take(sub)
57            | node::Inner::Drop(sub) => Dag::Unary(EncodeNode::Node(sub)),
58            node::Inner::Comp(left, right)
59            | node::Inner::Case(left, right)
60            | node::Inner::Pair(left, right) => {
61                Dag::Binary(EncodeNode::Node(left), EncodeNode::Node(right))
62            }
63            node::Inner::Disconnect(left, right) => {
64                right.disconnect_dag_ref(left).map(EncodeNode::Node)
65            }
66            node::Inner::AssertL(left, rcmr) => {
67                Dag::Binary(EncodeNode::Node(left), EncodeNode::Hidden(*rcmr))
68            }
69            node::Inner::AssertR(lcmr, right) => {
70                Dag::Binary(EncodeNode::Hidden(*lcmr), EncodeNode::Node(right))
71            }
72            node::Inner::Witness(..) => Dag::Nullary,
73        }
74    }
75}
76
77#[derive(Clone)]
78enum EncodeId<N: node::Marker> {
79    Node(N::SharingId),
80    Hidden(Cmr),
81}
82
83// Have to implement these manually because Rust sucks.
84impl<N: node::Marker> PartialEq for EncodeId<N> {
85    fn eq(&self, other: &Self) -> bool {
86        match (self, other) {
87            (EncodeId::Node(left), EncodeId::Node(right)) => left == right,
88            (EncodeId::Hidden(left), EncodeId::Hidden(right)) => left == right,
89            _ => false,
90        }
91    }
92}
93
94impl<N: node::Marker> Eq for EncodeId<N> {}
95
96impl<N: node::Marker> hash::Hash for EncodeId<N> {
97    fn hash<H: hash::Hasher>(&self, hasher: &mut H) {
98        match self {
99            EncodeId::Node(id) => {
100                hash::Hash::hash(&false, hasher);
101                hash::Hash::hash(id, hasher);
102            }
103            EncodeId::Hidden(cmr) => {
104                hash::Hash::hash(&true, hasher);
105                hash::Hash::hash(cmr, hasher);
106            }
107        }
108    }
109}
110
111/// Shares nodes based on IHR, *except* for Hidden nodes, which are identified
112/// solely by the hash they contain
113#[derive(Clone)]
114pub struct EncodeSharing<N: node::Marker> {
115    map: HashMap<EncodeId<N>, usize>,
116}
117
118// Annoyingly we have to implement Default by hand
119impl<N: node::Marker> Default for EncodeSharing<N> {
120    fn default() -> Self {
121        EncodeSharing {
122            map: HashMap::default(),
123        }
124    }
125}
126
127impl<N: node::Marker> SharingTracker<EncodeNode<'_, N>> for EncodeSharing<N> {
128    fn record(&mut self, d: &EncodeNode<N>, index: usize) -> Option<usize> {
129        let id = match d {
130            EncodeNode::Node(n) => EncodeId::Node(n.sharing_id()?),
131            EncodeNode::Hidden(cmr) => EncodeId::Hidden(*cmr),
132        };
133
134        match self.map.entry(id) {
135            Entry::Occupied(occ) => Some(*occ.get()),
136            Entry::Vacant(vac) => {
137                vac.insert(index);
138                None
139            }
140        }
141    }
142
143    fn seen_before(&self, d: &EncodeNode<N>) -> Option<usize> {
144        let id = match d {
145            EncodeNode::Node(n) => EncodeId::Node(n.sharing_id()?),
146            EncodeNode::Hidden(cmr) => EncodeId::Hidden(*cmr),
147        };
148
149        self.map.get(&id).copied()
150    }
151}
152
153/// Encode a Simplicity program to bits, without witness data.
154///
155/// Returns the number of written bits.
156pub fn encode_program<N: node::Marker>(
157    program: &node::Node<N>,
158    w: &mut BitWriter<&mut dyn io::Write>,
159) -> io::Result<usize> {
160    let iter = EncodeNode::Node(program).post_order_iter::<EncodeSharing<N>>();
161
162    let len = iter.clone().count();
163    let n_start = w.n_total_written();
164    encode_natural(len, w)?;
165
166    for node in iter {
167        encode_node(node, w)?;
168    }
169
170    Ok(w.n_total_written() - n_start)
171}
172
173/// Encode a node to bits.
174fn encode_node<N: node::Marker>(
175    data: PostOrderIterItem<EncodeNode<N>>,
176    w: &mut BitWriter<&mut dyn io::Write>,
177) -> io::Result<()> {
178    // Handle Hidden nodes specially
179    let node = match data.node {
180        EncodeNode::Node(node) => node,
181        EncodeNode::Hidden(cmr) => {
182            w.write_bits_be(0b0110, 4)?;
183            encode_hash(cmr.as_ref(), w)?;
184            return Ok(());
185        }
186    };
187
188    if let Some(i_abs) = data.left_index {
189        debug_assert!(i_abs < data.index);
190        let i = data.index - i_abs;
191
192        if let Some(j_abs) = data.right_index {
193            debug_assert!(j_abs < data.index);
194            let j = data.index - j_abs;
195
196            match node.inner() {
197                node::Inner::Comp(_, _) => {
198                    w.write_bits_be(0x00000, 5)?;
199                }
200                node::Inner::Case(_, _)
201                | node::Inner::AssertL(_, _)
202                | node::Inner::AssertR(_, _) => {
203                    w.write_bits_be(0b00001, 5)?;
204                }
205                node::Inner::Pair(_, _) => {
206                    w.write_bits_be(0b00010, 5)?;
207                }
208                node::Inner::Disconnect(_, _) => {
209                    w.write_bits_be(0b00011, 5)?;
210                }
211                _ => unreachable!(),
212            }
213
214            encode_natural(i, w)?;
215            encode_natural(j, w)?;
216        } else {
217            match node.inner() {
218                node::Inner::InjL(_) => {
219                    w.write_bits_be(0b00100, 5)?;
220                }
221                node::Inner::InjR(_) => {
222                    w.write_bits_be(0b00101, 5)?;
223                }
224                node::Inner::Take(_) => {
225                    w.write_bits_be(0b00110, 5)?;
226                }
227                node::Inner::Drop(_) => {
228                    w.write_bits_be(0b00111, 5)?;
229                }
230                node::Inner::Disconnect(_, _) => {
231                    w.write_bits_be(0b01011, 5)?;
232                }
233                _ => unreachable!(),
234            };
235
236            encode_natural(i, w)?;
237        }
238    } else {
239        match node.inner() {
240            node::Inner::Iden => {
241                w.write_bits_be(0b01000, 5)?;
242            }
243            node::Inner::Unit => {
244                w.write_bits_be(0b01001, 5)?;
245            }
246            node::Inner::Fail(entropy) => {
247                w.write_bits_be(0b01010, 5)?;
248                encode_hash(entropy.as_ref(), w)?;
249            }
250            node::Inner::Witness(_) => {
251                w.write_bits_be(0b0111, 4)?;
252            }
253            node::Inner::Jet(jet) => {
254                w.write_bit(true)?; // jet or word
255                w.write_bit(true)?; // jet
256                jet.encode(w)?;
257            }
258            node::Inner::Word(word) => {
259                w.write_bit(true)?; // jet or word
260                w.write_bit(false)?; // word
261                encode_natural(1 + word.n() as usize, w)?;
262                encode_value(word.as_value(), w)?;
263            }
264            _ => unreachable!(),
265        }
266    }
267
268    Ok(())
269}
270
271/// Encode witness data to bits.
272///
273/// Returns the number of written bits.
274pub fn encode_witness<'a, W: io::Write, I>(witness: I, w: &mut BitWriter<W>) -> io::Result<usize>
275where
276    I: Iterator<Item = &'a Value> + Clone,
277{
278    let mut len = 0;
279    for value in witness {
280        len += encode_value(value, w)?;
281    }
282    Ok(len)
283}
284
285/// Encode a value to bits.
286pub fn encode_value<W: io::Write>(value: &Value, w: &mut BitWriter<W>) -> io::Result<usize> {
287    let n_start = w.n_total_written();
288    for bit in value.iter_compact() {
289        w.write_bit(bit)?;
290    }
291    Ok(w.n_total_written() - n_start)
292}
293
294/// Encode a hash to bits.
295pub fn encode_hash<W: io::Write>(h: &[u8], w: &mut BitWriter<W>) -> io::Result<usize> {
296    for byte in h {
297        w.write_bits_be(u64::from(*byte), 8)?;
298    }
299
300    Ok(h.len() * 8)
301}
302
303/// Encode a positive integer to bits.
304pub fn encode_natural<W: io::Write>(mut n: usize, w: &mut BitWriter<W>) -> io::Result<usize> {
305    assert!(n > 0, "Zero cannot be encoded");
306    let n_start = w.n_total_written();
307
308    /// Minimum number of bits to represent `n` minus the most-significant bit
309    fn truncated_bit_len(n: usize) -> usize {
310        8 * mem::size_of::<usize>() - n.leading_zeros() as usize - 1
311    }
312
313    let mut suffix = Vec::new();
314
315    loop {
316        debug_assert!(n > 0);
317        let len = truncated_bit_len(n);
318        if len == 0 {
319            w.write_bit(false)?;
320            break;
321        } else {
322            w.write_bit(true)?;
323            suffix.push((n, len));
324            n = len;
325        }
326    }
327
328    while let Some((bits, len)) = suffix.pop() {
329        let bits = bits as u64; // Case safety: assuming 64-bit machine or lower
330        w.write_bits_be(bits, len)?;
331    }
332
333    Ok(w.n_total_written() - n_start)
334}
335
#[cfg(test)]
mod test {
    use super::*;

    use crate::BitIter;

    /// Encodes `n` into a byte vector and decodes that vector again.
    fn round_trip(n: usize) -> usize {
        let mut bytes = Vec::<u8>::new();
        let mut writer = BitWriter::from(&mut bytes);
        encode_natural(n, &mut writer).expect("encoding to vector");
        writer.flush_all().expect("flushing");

        BitIter::from(bytes.into_iter())
            .read_natural(None)
            .expect("decoding from vector")
    }

    /// Every natural number below 1000 must survive an encode/decode round trip.
    #[test]
    fn encode_decode_natural() {
        for n in 1..1000 {
            assert_eq!(n, round_trip(n));
        }
    }
}
355}