1use crate::dag::{Dag, DagLike, PostOrderIterItem, SharingTracker};
11use crate::jet::Jet;
12use crate::node::{self, Disconnectable};
13use crate::{BitWriter, Cmr, Value};
14
15use std::collections::{hash_map::Entry, HashMap};
16use std::sync::Arc;
17use std::{hash, io, mem};
18
/// A node as seen by the program encoder.
///
/// Either a real node of the program being encoded, or a pruned branch (for
/// example the unused side of an `AssertL`/`AssertR`) that is represented only
/// by its CMR.
#[derive(Copy, Clone)]
enum EncodeNode<'n, N: node::Marker> {
    /// A node of the program being encoded.
    Node(&'n node::Node<N>),
    /// A hidden branch; it has no children and is serialized as its CMR.
    Hidden(Cmr),
}
24
25impl<'n, N: node::Marker> Disconnectable<EncodeNode<'n, N>> for EncodeNode<'n, N> {
26 fn disconnect_dag_arc(self, other: Arc<EncodeNode<'n, N>>) -> Dag<Arc<EncodeNode<'n, N>>> {
27 Dag::Binary(other, Arc::new(self))
28 }
29
30 fn disconnect_dag_ref<'s>(
31 &'s self,
32 other: &'s EncodeNode<'n, N>,
33 ) -> Dag<&'s EncodeNode<'n, N>> {
34 Dag::Binary(other, self)
35 }
36}
37
38impl<N: node::Marker> DagLike for EncodeNode<'_, N> {
39 type Node = Self;
40 fn data(&self) -> &Self {
41 self
42 }
43
44 fn as_dag_node(&self) -> Dag<Self> {
45 let node = match *self {
46 EncodeNode::Node(node) => node,
47 EncodeNode::Hidden(..) => return Dag::Nullary,
48 };
49 match node.inner() {
50 node::Inner::Unit
51 | node::Inner::Iden
52 | node::Inner::Fail(..)
53 | node::Inner::Jet(..)
54 | node::Inner::Word(..) => Dag::Nullary,
55 node::Inner::InjL(sub)
56 | node::Inner::InjR(sub)
57 | node::Inner::Take(sub)
58 | node::Inner::Drop(sub) => Dag::Unary(EncodeNode::Node(sub)),
59 node::Inner::Comp(left, right)
60 | node::Inner::Case(left, right)
61 | node::Inner::Pair(left, right) => {
62 Dag::Binary(EncodeNode::Node(left), EncodeNode::Node(right))
63 }
64 node::Inner::Disconnect(left, right) => {
65 right.disconnect_dag_ref(left).map(EncodeNode::Node)
66 }
67 node::Inner::AssertL(left, rcmr) => {
68 Dag::Binary(EncodeNode::Node(left), EncodeNode::Hidden(*rcmr))
69 }
70 node::Inner::AssertR(lcmr, right) => {
71 Dag::Binary(EncodeNode::Hidden(*lcmr), EncodeNode::Node(right))
72 }
73 node::Inner::Witness(..) => Dag::Nullary,
74 }
75 }
76}
77
/// Key used by `EncodeSharing` to detect repeated nodes during encoding.
///
/// Real nodes are keyed by their marker-defined sharing id, and hidden
/// branches are keyed by their CMR. A `Node` key never compares equal to a
/// `Hidden` key.
#[derive(Clone)]
enum EncodeId<N: node::Marker> {
    /// Sharing id reported by `Node::sharing_id`.
    Node(N::SharingId),
    /// CMR of a hidden branch.
    Hidden(Cmr),
}
83
84impl<N: node::Marker> PartialEq for EncodeId<N> {
86 fn eq(&self, other: &Self) -> bool {
87 match (self, other) {
88 (EncodeId::Node(left), EncodeId::Node(right)) => left == right,
89 (EncodeId::Hidden(left), EncodeId::Hidden(right)) => left == right,
90 _ => false,
91 }
92 }
93}
94
// Needed because `EncodeId` is used as a `HashMap` key. Implemented manually
// rather than derived, since a derive would add an `N: Eq` bound on the marker.
impl<N: node::Marker> Eq for EncodeId<N> {}
96
97impl<N: node::Marker> hash::Hash for EncodeId<N> {
98 fn hash<H: hash::Hasher>(&self, hasher: &mut H) {
99 match self {
100 EncodeId::Node(id) => {
101 hash::Hash::hash(&false, hasher);
102 hash::Hash::hash(id, hasher);
103 }
104 EncodeId::Hidden(cmr) => {
105 hash::Hash::hash(&true, hasher);
106 hash::Hash::hash(cmr, hasher);
107 }
108 }
109 }
110}
111
/// Sharing tracker used when encoding programs.
///
/// Nodes are keyed by `EncodeId`. The first time a key is recorded its
/// post-order index is stored. Later lookups of an equal key return that
/// stored index instead of recording a new one.
#[derive(Clone)]
pub struct EncodeSharing<N: node::Marker> {
    /// Sharing key -> post-order index at which the key was first recorded.
    map: HashMap<EncodeId<N>, usize>,
}
118
119impl<N: node::Marker> Default for EncodeSharing<N> {
121 fn default() -> Self {
122 EncodeSharing {
123 map: HashMap::default(),
124 }
125 }
126}
127
128impl<N: node::Marker> SharingTracker<EncodeNode<'_, N>> for EncodeSharing<N> {
129 fn record(&mut self, d: &EncodeNode<N>, index: usize) -> Option<usize> {
130 let id = match d {
131 EncodeNode::Node(n) => EncodeId::Node(n.sharing_id()?),
132 EncodeNode::Hidden(cmr) => EncodeId::Hidden(*cmr),
133 };
134
135 match self.map.entry(id) {
136 Entry::Occupied(occ) => Some(*occ.get()),
137 Entry::Vacant(vac) => {
138 vac.insert(index);
139 None
140 }
141 }
142 }
143
144 fn seen_before(&self, d: &EncodeNode<N>) -> Option<usize> {
145 let id = match d {
146 EncodeNode::Node(n) => EncodeId::Node(n.sharing_id()?),
147 EncodeNode::Hidden(cmr) => EncodeId::Hidden(*cmr),
148 };
149
150 self.map.get(&id).copied()
151 }
152}
153
154pub fn encode_program<W: io::Write, N: node::Marker>(
158 program: &node::Node<N>,
159 w: &mut BitWriter<W>,
160) -> io::Result<usize> {
161 let iter = EncodeNode::Node(program).post_order_iter::<EncodeSharing<N>>();
162
163 let len = iter.clone().count();
164 let n_start = w.n_total_written();
165 encode_natural(len, w)?;
166
167 for node in iter {
168 encode_node(node, w)?;
169 }
170
171 Ok(w.n_total_written() - n_start)
172}
173
174fn encode_node<W: io::Write, N: node::Marker>(
176 data: PostOrderIterItem<EncodeNode<N>>,
177 w: &mut BitWriter<W>,
178) -> io::Result<()> {
179 let node = match data.node {
181 EncodeNode::Node(node) => node,
182 EncodeNode::Hidden(cmr) => {
183 w.write_bits_be(0b0110, 4)?;
184 encode_hash(cmr.as_ref(), w)?;
185 return Ok(());
186 }
187 };
188
189 if let Some(i_abs) = data.left_index {
190 debug_assert!(i_abs < data.index);
191 let i = data.index - i_abs;
192
193 if let Some(j_abs) = data.right_index {
194 debug_assert!(j_abs < data.index);
195 let j = data.index - j_abs;
196
197 match node.inner() {
198 node::Inner::Comp(_, _) => {
199 w.write_bits_be(0x00000, 5)?;
200 }
201 node::Inner::Case(_, _)
202 | node::Inner::AssertL(_, _)
203 | node::Inner::AssertR(_, _) => {
204 w.write_bits_be(0b00001, 5)?;
205 }
206 node::Inner::Pair(_, _) => {
207 w.write_bits_be(0b00010, 5)?;
208 }
209 node::Inner::Disconnect(_, _) => {
210 w.write_bits_be(0b00011, 5)?;
211 }
212 _ => unreachable!(),
213 }
214
215 encode_natural(i, w)?;
216 encode_natural(j, w)?;
217 } else {
218 match node.inner() {
219 node::Inner::InjL(_) => {
220 w.write_bits_be(0b00100, 5)?;
221 }
222 node::Inner::InjR(_) => {
223 w.write_bits_be(0b00101, 5)?;
224 }
225 node::Inner::Take(_) => {
226 w.write_bits_be(0b00110, 5)?;
227 }
228 node::Inner::Drop(_) => {
229 w.write_bits_be(0b00111, 5)?;
230 }
231 node::Inner::Disconnect(_, _) => {
232 w.write_bits_be(0b01011, 5)?;
233 }
234 _ => unreachable!(),
235 };
236
237 encode_natural(i, w)?;
238 }
239 } else {
240 match node.inner() {
241 node::Inner::Iden => {
242 w.write_bits_be(0b01000, 5)?;
243 }
244 node::Inner::Unit => {
245 w.write_bits_be(0b01001, 5)?;
246 }
247 node::Inner::Fail(entropy) => {
248 w.write_bits_be(0b01010, 5)?;
249 encode_hash(entropy.as_ref(), w)?;
250 }
251 node::Inner::Witness(_) => {
252 w.write_bits_be(0b0111, 4)?;
253 }
254 node::Inner::Jet(jet) => {
255 w.write_bit(true)?; w.write_bit(true)?; jet.encode(w)?;
258 }
259 node::Inner::Word(word) => {
260 w.write_bit(true)?; w.write_bit(false)?; encode_natural(1 + word.n() as usize, w)?;
263 encode_value(word.as_value(), w)?;
264 }
265 _ => unreachable!(),
266 }
267 }
268
269 Ok(())
270}
271
272pub fn encode_witness<'a, W: io::Write, I>(witness: I, w: &mut BitWriter<W>) -> io::Result<usize>
276where
277 I: Iterator<Item = &'a Value> + Clone,
278{
279 let mut len = 0;
280 for value in witness {
281 len += encode_value(value, w)?;
282 }
283 Ok(len)
284}
285
286pub fn encode_value<W: io::Write>(value: &Value, w: &mut BitWriter<W>) -> io::Result<usize> {
288 let n_start = w.n_total_written();
289 for bit in value.iter_compact() {
290 w.write_bit(bit)?;
291 }
292 Ok(w.n_total_written() - n_start)
293}
294
295pub fn encode_hash<W: io::Write>(h: &[u8], w: &mut BitWriter<W>) -> io::Result<usize> {
297 for byte in h {
298 w.write_bits_be(u64::from(*byte), 8)?;
299 }
300
301 Ok(h.len() * 8)
302}
303
304pub fn encode_natural<W: io::Write>(mut n: usize, w: &mut BitWriter<W>) -> io::Result<usize> {
306 assert!(n > 0, "Zero cannot be encoded");
307 let n_start = w.n_total_written();
308
309 fn truncated_bit_len(n: usize) -> usize {
311 8 * mem::size_of::<usize>() - n.leading_zeros() as usize - 1
312 }
313
314 let mut suffix = Vec::new();
315
316 loop {
317 debug_assert!(n > 0);
318 let len = truncated_bit_len(n);
319 if len == 0 {
320 w.write_bit(false)?;
321 break;
322 } else {
323 w.write_bit(true)?;
324 suffix.push((n, len));
325 n = len;
326 }
327 }
328
329 while let Some((bits, len)) = suffix.pop() {
330 let bits = bits as u64; w.write_bits_be(bits, len)?;
332 }
333
334 Ok(w.n_total_written() - n_start)
335}
336
#[cfg(test)]
mod test {
    use super::*;

    use crate::BitIter;

    /// Encode `n` into a fresh buffer with `encode_natural` and decode it back.
    fn round_trip(n: usize) -> usize {
        let mut bytes = Vec::<u8>::new();
        let mut writer = BitWriter::from(&mut bytes);
        encode_natural(n, &mut writer).expect("encoding to vector");
        writer.flush_all().expect("flushing");

        BitIter::from(bytes.into_iter())
            .read_natural(None)
            .expect("decoding from vector")
    }

    #[test]
    fn encode_decode_natural() {
        for n in 1..1000 {
            assert_eq!(n, round_trip(n));
        }
    }
}