1use crate::dag::{Dag, DagLike, PostOrderIterItem, SharingTracker};
11use crate::node::{self, Disconnectable};
12use crate::{BitWriter, Cmr, Value};
13
14use std::collections::{hash_map::Entry, HashMap};
15use std::sync::Arc;
16use std::{hash, io, mem};
17
18#[derive(Copy, Clone)]
19enum EncodeNode<'n, N: node::Marker> {
20 Node(&'n node::Node<N>),
21 Hidden(Cmr),
22}
23
24impl<'n, N: node::Marker> Disconnectable<EncodeNode<'n, N>> for EncodeNode<'n, N> {
25 fn disconnect_dag_arc(self, other: Arc<EncodeNode<'n, N>>) -> Dag<Arc<EncodeNode<'n, N>>> {
26 Dag::Binary(other, Arc::new(self))
27 }
28
29 fn disconnect_dag_ref<'s>(
30 &'s self,
31 other: &'s EncodeNode<'n, N>,
32 ) -> Dag<&'s EncodeNode<'n, N>> {
33 Dag::Binary(other, self)
34 }
35}
36
37impl<N: node::Marker> DagLike for EncodeNode<'_, N> {
38 type Node = Self;
39 fn data(&self) -> &Self {
40 self
41 }
42
43 fn as_dag_node(&self) -> Dag<Self> {
44 let node = match *self {
45 EncodeNode::Node(node) => node,
46 EncodeNode::Hidden(..) => return Dag::Nullary,
47 };
48 match node.inner() {
49 node::Inner::Unit
50 | node::Inner::Iden
51 | node::Inner::Fail(..)
52 | node::Inner::Jet(..)
53 | node::Inner::Word(..) => Dag::Nullary,
54 node::Inner::InjL(sub)
55 | node::Inner::InjR(sub)
56 | node::Inner::Take(sub)
57 | node::Inner::Drop(sub) => Dag::Unary(EncodeNode::Node(sub)),
58 node::Inner::Comp(left, right)
59 | node::Inner::Case(left, right)
60 | node::Inner::Pair(left, right) => {
61 Dag::Binary(EncodeNode::Node(left), EncodeNode::Node(right))
62 }
63 node::Inner::Disconnect(left, right) => {
64 right.disconnect_dag_ref(left).map(EncodeNode::Node)
65 }
66 node::Inner::AssertL(left, rcmr) => {
67 Dag::Binary(EncodeNode::Node(left), EncodeNode::Hidden(*rcmr))
68 }
69 node::Inner::AssertR(lcmr, right) => {
70 Dag::Binary(EncodeNode::Hidden(*lcmr), EncodeNode::Node(right))
71 }
72 node::Inner::Witness(..) => Dag::Nullary,
73 }
74 }
75}
76
77#[derive(Clone)]
78enum EncodeId<N: node::Marker> {
79 Node(N::SharingId),
80 Hidden(Cmr),
81}
82
83impl<N: node::Marker> PartialEq for EncodeId<N> {
85 fn eq(&self, other: &Self) -> bool {
86 match (self, other) {
87 (EncodeId::Node(left), EncodeId::Node(right)) => left == right,
88 (EncodeId::Hidden(left), EncodeId::Hidden(right)) => left == right,
89 _ => false,
90 }
91 }
92}
93
94impl<N: node::Marker> Eq for EncodeId<N> {}
95
96impl<N: node::Marker> hash::Hash for EncodeId<N> {
97 fn hash<H: hash::Hasher>(&self, hasher: &mut H) {
98 match self {
99 EncodeId::Node(id) => {
100 hash::Hash::hash(&false, hasher);
101 hash::Hash::hash(id, hasher);
102 }
103 EncodeId::Hidden(cmr) => {
104 hash::Hash::hash(&true, hasher);
105 hash::Hash::hash(cmr, hasher);
106 }
107 }
108 }
109}
110
111#[derive(Clone)]
114pub struct EncodeSharing<N: node::Marker> {
115 map: HashMap<EncodeId<N>, usize>,
116}
117
118impl<N: node::Marker> Default for EncodeSharing<N> {
120 fn default() -> Self {
121 EncodeSharing {
122 map: HashMap::default(),
123 }
124 }
125}
126
127impl<N: node::Marker> SharingTracker<EncodeNode<'_, N>> for EncodeSharing<N> {
128 fn record(&mut self, d: &EncodeNode<N>, index: usize) -> Option<usize> {
129 let id = match d {
130 EncodeNode::Node(n) => EncodeId::Node(n.sharing_id()?),
131 EncodeNode::Hidden(cmr) => EncodeId::Hidden(*cmr),
132 };
133
134 match self.map.entry(id) {
135 Entry::Occupied(occ) => Some(*occ.get()),
136 Entry::Vacant(vac) => {
137 vac.insert(index);
138 None
139 }
140 }
141 }
142
143 fn seen_before(&self, d: &EncodeNode<N>) -> Option<usize> {
144 let id = match d {
145 EncodeNode::Node(n) => EncodeId::Node(n.sharing_id()?),
146 EncodeNode::Hidden(cmr) => EncodeId::Hidden(*cmr),
147 };
148
149 self.map.get(&id).copied()
150 }
151}
152
153pub fn encode_program<N: node::Marker>(
157 program: &node::Node<N>,
158 w: &mut BitWriter<&mut dyn io::Write>,
159) -> io::Result<usize> {
160 let iter = EncodeNode::Node(program).post_order_iter::<EncodeSharing<N>>();
161
162 let len = iter.clone().count();
163 let n_start = w.n_total_written();
164 encode_natural(len, w)?;
165
166 for node in iter {
167 encode_node(node, w)?;
168 }
169
170 Ok(w.n_total_written() - n_start)
171}
172
173fn encode_node<N: node::Marker>(
175 data: PostOrderIterItem<EncodeNode<N>>,
176 w: &mut BitWriter<&mut dyn io::Write>,
177) -> io::Result<()> {
178 let node = match data.node {
180 EncodeNode::Node(node) => node,
181 EncodeNode::Hidden(cmr) => {
182 w.write_bits_be(0b0110, 4)?;
183 encode_hash(cmr.as_ref(), w)?;
184 return Ok(());
185 }
186 };
187
188 if let Some(i_abs) = data.left_index {
189 debug_assert!(i_abs < data.index);
190 let i = data.index - i_abs;
191
192 if let Some(j_abs) = data.right_index {
193 debug_assert!(j_abs < data.index);
194 let j = data.index - j_abs;
195
196 match node.inner() {
197 node::Inner::Comp(_, _) => {
198 w.write_bits_be(0x00000, 5)?;
199 }
200 node::Inner::Case(_, _)
201 | node::Inner::AssertL(_, _)
202 | node::Inner::AssertR(_, _) => {
203 w.write_bits_be(0b00001, 5)?;
204 }
205 node::Inner::Pair(_, _) => {
206 w.write_bits_be(0b00010, 5)?;
207 }
208 node::Inner::Disconnect(_, _) => {
209 w.write_bits_be(0b00011, 5)?;
210 }
211 _ => unreachable!(),
212 }
213
214 encode_natural(i, w)?;
215 encode_natural(j, w)?;
216 } else {
217 match node.inner() {
218 node::Inner::InjL(_) => {
219 w.write_bits_be(0b00100, 5)?;
220 }
221 node::Inner::InjR(_) => {
222 w.write_bits_be(0b00101, 5)?;
223 }
224 node::Inner::Take(_) => {
225 w.write_bits_be(0b00110, 5)?;
226 }
227 node::Inner::Drop(_) => {
228 w.write_bits_be(0b00111, 5)?;
229 }
230 node::Inner::Disconnect(_, _) => {
231 w.write_bits_be(0b01011, 5)?;
232 }
233 _ => unreachable!(),
234 };
235
236 encode_natural(i, w)?;
237 }
238 } else {
239 match node.inner() {
240 node::Inner::Iden => {
241 w.write_bits_be(0b01000, 5)?;
242 }
243 node::Inner::Unit => {
244 w.write_bits_be(0b01001, 5)?;
245 }
246 node::Inner::Fail(entropy) => {
247 w.write_bits_be(0b01010, 5)?;
248 encode_hash(entropy.as_ref(), w)?;
249 }
250 node::Inner::Witness(_) => {
251 w.write_bits_be(0b0111, 4)?;
252 }
253 node::Inner::Jet(jet) => {
254 w.write_bit(true)?; w.write_bit(true)?; jet.encode(w)?;
257 }
258 node::Inner::Word(word) => {
259 w.write_bit(true)?; w.write_bit(false)?; encode_natural(1 + word.n() as usize, w)?;
262 encode_value(word.as_value(), w)?;
263 }
264 _ => unreachable!(),
265 }
266 }
267
268 Ok(())
269}
270
271pub fn encode_witness<'a, W: io::Write, I>(witness: I, w: &mut BitWriter<W>) -> io::Result<usize>
275where
276 I: Iterator<Item = &'a Value> + Clone,
277{
278 let mut len = 0;
279 for value in witness {
280 len += encode_value(value, w)?;
281 }
282 Ok(len)
283}
284
285pub fn encode_value<W: io::Write>(value: &Value, w: &mut BitWriter<W>) -> io::Result<usize> {
287 let n_start = w.n_total_written();
288 for bit in value.iter_compact() {
289 w.write_bit(bit)?;
290 }
291 Ok(w.n_total_written() - n_start)
292}
293
294pub fn encode_hash<W: io::Write>(h: &[u8], w: &mut BitWriter<W>) -> io::Result<usize> {
296 for byte in h {
297 w.write_bits_be(u64::from(*byte), 8)?;
298 }
299
300 Ok(h.len() * 8)
301}
302
303pub fn encode_natural<W: io::Write>(mut n: usize, w: &mut BitWriter<W>) -> io::Result<usize> {
305 assert!(n > 0, "Zero cannot be encoded");
306 let n_start = w.n_total_written();
307
308 fn truncated_bit_len(n: usize) -> usize {
310 8 * mem::size_of::<usize>() - n.leading_zeros() as usize - 1
311 }
312
313 let mut suffix = Vec::new();
314
315 loop {
316 debug_assert!(n > 0);
317 let len = truncated_bit_len(n);
318 if len == 0 {
319 w.write_bit(false)?;
320 break;
321 } else {
322 w.write_bit(true)?;
323 suffix.push((n, len));
324 n = len;
325 }
326 }
327
328 while let Some((bits, len)) = suffix.pop() {
329 let bits = bits as u64; w.write_bits_be(bits, len)?;
331 }
332
333 Ok(w.n_total_written() - n_start)
334}
335
#[cfg(test)]
mod test {
    use super::*;

    use crate::BitIter;

    /// Encodes `n` into a fresh byte buffer and decodes it back.
    fn roundtrip_natural(n: usize) -> usize {
        let mut buf = Vec::<u8>::new();
        {
            let mut writer = BitWriter::from(&mut buf);
            encode_natural(n, &mut writer).expect("encoding to vector");
            writer.flush_all().expect("flushing");
        }
        BitIter::from(buf.into_iter())
            .read_natural(None)
            .expect("decoding from vector")
    }

    #[test]
    fn encode_decode_natural() {
        for n in 1..1000 {
            assert_eq!(n, roundtrip_natural(n));
        }
    }
}