use crate::analysis::NodeBounds;
use crate::dag::{DagLike, InternalSharing, MaxSharing, PostOrderIterItem};
use crate::jet::Jet;
use crate::types::{self, arrow::FinalArrow};
use crate::{encode, write_to_vec, WitnessNode};
use crate::{Amr, BitIter, BitWriter, Cmr, Error, FirstPassImr, Imr, Value};
use super::{
Commit, CommitData, CommitNode, Construct, ConstructNode, Constructible, Converter, Inner,
Marker, NoDisconnect, NoWitness, Node, Witness, WitnessData,
};
use std::collections::HashSet;
use std::io;
use std::marker::PhantomData;
use std::sync::Arc;
/// Marker type used to build [`RedeemNode`]s (see its `Marker` impl below).
///
/// The `never` field has type [`std::convert::Infallible`], so no value of this
/// type can ever be constructed; it exists purely as a type-level parameter.
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Hash)]
pub struct Redeem<J> {
    // Uninhabited field: makes `Redeem<J>` impossible to instantiate.
    never: std::convert::Infallible,
    // Ties the marker to the jet type `J` without storing one.
    phantom: std::marker::PhantomData<J>,
}
/// Redeem programs cache a full [`RedeemData`] at every node, carry concrete
/// witness values and concrete disconnected expressions, and are shared by IMR.
impl<J: Jet> Marker for Redeem<J> {
    type Jet = J;
    type SharingId = Imr;
    type Disconnect = Arc<RedeemNode<J>>;
    type Witness = Arc<Value>;
    type CachedData = Arc<RedeemData<J>>;

    /// The sharing id of a node is its IMR; the CMR argument is not consulted.
    fn compute_sharing_id(_cmr: Cmr, data: &Arc<RedeemData<J>>) -> Option<Imr> {
        let imr = data.imr;
        Some(imr)
    }
}
/// A node of a redeem program: a `Node` whose marker is [`Redeem`], so every
/// node caches an `Arc<RedeemData<J>>` and holds concrete witnesses/disconnects.
pub type RedeemNode<J> = Node<Redeem<J>>;
/// Data cached at every node of a [`RedeemNode`] program.
///
/// Only the `imr` field participates in this type's `PartialEq`/`Eq`/`Hash`.
#[derive(Clone, Debug)]
pub struct RedeemData<J> {
    // The node's AMR.
    amr: Amr,
    // First-pass IMR; combined with `arrow` to produce `imr` in `RedeemData::new`.
    first_pass_imr: FirstPassImr,
    // The node's IMR; also used as the node's sharing id.
    imr: Imr,
    // The node's finalized source/target types.
    arrow: FinalArrow,
    // Resource bounds computed for this node from its children's bounds.
    bounds: NodeBounds,
    // Ties the data to the jet type `J` without storing one.
    phantom: PhantomData<J>,
}
/// Two cached-data values are equal exactly when their IMRs are equal; the AMR,
/// arrow and bounds are not compared.
impl<J> PartialEq for RedeemData<J> {
    fn eq(&self, rhs: &Self) -> bool {
        let (lhs_imr, rhs_imr) = (&self.imr, &rhs.imr);
        lhs_imr == rhs_imr
    }
}

/// Equality on IMRs is a full equivalence relation, so `Eq` may be asserted.
impl<J> Eq for RedeemData<J> {}

/// Hashes only the IMR, keeping `Hash` consistent with the `PartialEq` above.
impl<J> std::hash::Hash for RedeemData<J> {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        std::hash::Hash::hash(&self.imr, state)
    }
}
impl<J: Jet> RedeemData<J> {
    /// Computes the cached data for a node with finalized type `arrow` and
    /// combinator `inner`.
    ///
    /// `inner` carries the already-computed data of the node's children (and,
    /// for `disconnect`, of its right-hand expression), so this function only
    /// combines child data; it does not traverse the program.
    ///
    /// The AMR, first-pass IMR and resource bounds are computed per combinator;
    /// the final IMR is then derived from the first-pass IMR and `arrow`.
    pub fn new(arrow: FinalArrow, inner: Inner<&Arc<Self>, J, &Arc<Self>, Arc<Value>>) -> Self {
        let (amr, first_pass_imr, bounds) = match inner {
            Inner::Iden => (
                Amr::iden(&arrow),
                FirstPassImr::iden(),
                NodeBounds::iden(arrow.source.bit_width()),
            ),
            Inner::Unit => (Amr::unit(&arrow), FirstPassImr::unit(), NodeBounds::unit()),
            Inner::InjL(child) => (
                Amr::injl(&arrow, child.amr),
                FirstPassImr::injl(child.first_pass_imr),
                NodeBounds::injl(child.bounds),
            ),
            Inner::InjR(child) => (
                Amr::injr(&arrow, child.amr),
                FirstPassImr::injr(child.first_pass_imr),
                NodeBounds::injr(child.bounds),
            ),
            Inner::Take(child) => (
                Amr::take(&arrow, child.amr),
                FirstPassImr::take(child.first_pass_imr),
                NodeBounds::take(child.bounds),
            ),
            Inner::Drop(child) => (
                Amr::drop(&arrow, child.amr),
                FirstPassImr::drop(child.first_pass_imr),
                NodeBounds::drop(child.bounds),
            ),
            // The AMR of `comp` also depends on the left child's arrow, and the
            // bounds depend on the width of the intermediate (left target) type.
            Inner::Comp(left, right) => (
                Amr::comp(&arrow, &left.arrow, left.amr, right.amr),
                FirstPassImr::comp(left.first_pass_imr, right.first_pass_imr),
                NodeBounds::comp(left.bounds, right.bounds, left.arrow.target.bit_width()),
            ),
            Inner::Case(left, right) => (
                Amr::case(&arrow, left.amr, right.amr),
                FirstPassImr::case(left.first_pass_imr, right.first_pass_imr),
                NodeBounds::case(left.bounds, right.bounds),
            ),
            // For the assertions, the pruned branch contributes only its CMR, and
            // the first-pass IMR is computed exactly as for `case`.
            Inner::AssertL(left, r_cmr) => (
                Amr::assertl(&arrow, left.amr, r_cmr.into()),
                FirstPassImr::case(left.first_pass_imr, r_cmr.into()),
                NodeBounds::assertl(left.bounds),
            ),
            Inner::AssertR(l_cmr, right) => (
                Amr::assertr(&arrow, l_cmr.into(), right.amr),
                FirstPassImr::case(l_cmr.into(), right.first_pass_imr),
                NodeBounds::assertr(right.bounds),
            ),
            Inner::Pair(left, right) => (
                Amr::pair(&arrow, &left.arrow, &right.arrow, left.amr, right.amr),
                FirstPassImr::pair(left.first_pass_imr, right.first_pass_imr),
                NodeBounds::pair(left.bounds, right.bounds),
            ),
            Inner::Disconnect(left, right) => (
                Amr::disconnect(&arrow, &right.arrow, left.amr, right.amr),
                FirstPassImr::disconnect(left.first_pass_imr, right.first_pass_imr),
                NodeBounds::disconnect(
                    left.bounds,
                    right.bounds,
                    // Left target width minus right source width.
                    // NOTE(review): relies on finalized types guaranteeing the left
                    // target is at least as wide as the right source (no underflow).
                    left.arrow.target.bit_width() - right.arrow.source.bit_width(),
                    left.arrow.source.bit_width(),
                    left.arrow.target.bit_width(),
                ),
            ),
            // Witness roots commit to the witness value itself.
            Inner::Witness(ref value) => (
                Amr::witness(&arrow, value),
                FirstPassImr::witness(&arrow, value),
                NodeBounds::witness(arrow.target.bit_width()),
            ),
            Inner::Fail(entropy) => (
                Amr::fail(entropy),
                FirstPassImr::fail(entropy),
                NodeBounds::fail(),
            ),
            Inner::Jet(jet) => (Amr::jet(jet), FirstPassImr::jet(jet), NodeBounds::jet(jet)),
            Inner::Word(ref val) => (
                Amr::const_word(val),
                FirstPassImr::const_word(val),
                NodeBounds::const_word(val),
            ),
        };

        RedeemData {
            amr,
            first_pass_imr,
            // Second pass: the final IMR mixes in the node's finalized arrow.
            imr: Imr::compute_pass2(first_pass_imr, &arrow),
            arrow,
            bounds,
            phantom: PhantomData,
        }
    }
}
impl<J: Jet> RedeemNode<J> {
    /// Accessor for the node's AMR.
    pub fn amr(&self) -> Amr {
        self.data.amr
    }

    /// Accessor for the node's IMR.
    pub fn imr(&self) -> Imr {
        self.data.imr
    }

    /// Accessor for the node's finalized type arrow.
    pub fn arrow(&self) -> &FinalArrow {
        &self.data.arrow
    }

    /// Accessor for the node's resource bounds.
    pub fn bounds(&self) -> NodeBounds {
        self.data.bounds
    }

    /// Converts the program back into a [`CommitNode`].
    ///
    /// Witness values and disconnected right-hand expressions are discarded
    /// (replaced by `NoWitness` / `NoDisconnect`). The already-finalized arrows
    /// are reused for the new `CommitData`. The conversion uses `MaxSharing`.
    ///
    /// # Errors
    ///
    /// None of the conversion callbacks below return an error themselves; the
    /// `types::Error` in the signature is the converter's declared error type.
    pub fn unfinalize(&self) -> Result<Arc<CommitNode<J>>, types::Error> {
        /// Converter that strips witness values and disconnected expressions.
        struct Unfinalizer<J>(PhantomData<J>);

        impl<J: Jet> Converter<Redeem<J>, Commit<J>> for Unfinalizer<J> {
            type Error = types::Error;

            fn convert_witness(
                &mut self,
                _: &PostOrderIterItem<&RedeemNode<J>>,
                _: &Arc<Value>,
            ) -> Result<NoWitness, Self::Error> {
                Ok(NoWitness)
            }

            fn convert_disconnect(
                &mut self,
                _: &PostOrderIterItem<&RedeemNode<J>>,
                _: Option<&Arc<CommitNode<J>>>,
                _: &Arc<RedeemNode<J>>,
            ) -> Result<NoDisconnect, Self::Error> {
                Ok(NoDisconnect)
            }

            fn convert_data(
                &mut self,
                data: &PostOrderIterItem<&RedeemNode<J>>,
                inner: Inner<&Arc<CommitNode<J>>, J, &NoDisconnect, &NoWitness>,
            ) -> Result<Arc<CommitData<J>>, Self::Error> {
                // Reuse the finalized arrow rather than re-inferring types.
                let converted_data = inner.map(|node| node.cached_data());
                Ok(Arc::new(CommitData::from_final(
                    data.node.data.arrow.shallow_clone(),
                    converted_data,
                )))
            }
        }

        self.convert::<MaxSharing<Redeem<J>>, _, _>(&mut Unfinalizer(PhantomData))
    }

    /// Converts the program into a [`WitnessNode`].
    ///
    /// Every witness value is kept (as `Some`), and every disconnected
    /// right-hand expression is carried over. The conversion uses
    /// `InternalSharing`.
    ///
    /// # Panics
    ///
    /// Panics if `WitnessData::from_inner` rejects a node. The converter treats
    /// this as impossible because the program's types are already finalized.
    pub fn to_witness_node(&self) -> Arc<WitnessNode<J>> {
        /// Converter that keeps witnesses and disconnects; it never fails.
        struct ToWitness<J>(PhantomData<J>);

        impl<J: Jet> Converter<Redeem<J>, Witness<J>> for ToWitness<J> {
            type Error = ();

            fn convert_witness(
                &mut self,
                _: &PostOrderIterItem<&Node<Redeem<J>>>,
                witness: &Arc<Value>,
            ) -> Result<Option<Arc<Value>>, Self::Error> {
                Ok(Some(Arc::clone(witness)))
            }

            fn convert_disconnect(
                &mut self,
                _: &PostOrderIterItem<&Node<Redeem<J>>>,
                right: Option<&Arc<Node<Witness<J>>>>,
                _: &Arc<RedeemNode<J>>,
            ) -> Result<Option<Arc<Node<Witness<J>>>>, Self::Error> {
                Ok(right.cloned())
            }

            fn convert_data(
                &mut self,
                _: &PostOrderIterItem<&Node<Redeem<J>>>,
                inner: Inner<
                    &Arc<Node<Witness<J>>>,
                    J,
                    &Option<Arc<WitnessNode<J>>>,
                    &Option<Arc<Value>>,
                >,
            ) -> Result<WitnessData<J>, Self::Error> {
                let inner = inner
                    .map(|node| node.cached_data())
                    .map_witness(|maybe_value| maybe_value.clone());
                Ok(WitnessData::from_inner(inner).expect("types are already finalized"))
            }
        }

        self.convert::<InternalSharing, _, _>(&mut ToWitness(PhantomData))
            .expect("ToWitness conversion callbacks never return an error")
    }

    /// Decodes a program, followed by its witness data, from `bits`.
    ///
    /// This reads, in order:
    /// 1. the program expression;
    /// 2. one bit saying whether a witness length follows;
    /// 3. if so, that length as a natural number;
    /// 4. one value per witness node, each of that node's finalized target type.
    ///
    /// # Errors
    ///
    /// - Any error from decoding the expression, assigning program types,
    ///   finalizing types or reading bits/witness values.
    /// - [`Error::DisconnectRedeemTime`] if a disconnect node has no right-hand
    ///   expression.
    /// - [`Error::InconsistentWitnessLength`] if the number of bits consumed while
    ///   reading witnesses differs from the encoded witness length (including a
    ///   length so large that the end position is not representable).
    /// - `Error::Decode(SharingNotMaximal)` if two distinct nodes share an IMR.
    pub fn decode<I: Iterator<Item = u8>>(bits: &mut BitIter<I>) -> Result<Arc<Self>, Error> {
        /// Finalizes a decoded `ConstructNode` program, reading each witness
        /// value from `bits` as the corresponding witness node is converted.
        struct DecodeFinalizer<'bits, J: Jet, I: Iterator<Item = u8>> {
            bits: &'bits mut BitIter<I>,
            phantom: PhantomData<J>,
        }

        impl<'bits, J: Jet, I: Iterator<Item = u8>> Converter<Construct<J>, Redeem<J>>
            for DecodeFinalizer<'bits, J, I>
        {
            type Error = Error;

            fn convert_witness(
                &mut self,
                data: &PostOrderIterItem<&ConstructNode<J>>,
                _: &NoWitness,
            ) -> Result<Arc<Value>, Self::Error> {
                let target_ty = data.node.data.arrow().target.finalize()?;
                self.bits.read_value(&target_ty).map_err(Error::from)
            }

            fn convert_disconnect(
                &mut self,
                _: &PostOrderIterItem<&ConstructNode<J>>,
                right: Option<&Arc<RedeemNode<J>>>,
                _: &Option<Arc<ConstructNode<J>>>,
            ) -> Result<Arc<RedeemNode<J>>, Self::Error> {
                if let Some(child) = right {
                    Ok(Arc::clone(child))
                } else {
                    Err(Error::DisconnectRedeemTime)
                }
            }

            fn convert_data(
                &mut self,
                data: &PostOrderIterItem<&ConstructNode<J>>,
                inner: Inner<&Arc<RedeemNode<J>>, J, &Arc<RedeemNode<J>>, &Arc<Value>>,
            ) -> Result<Arc<RedeemData<J>>, Self::Error> {
                let arrow = data.node.data.arrow().finalize()?;
                let converted_data = inner
                    .map(|node| node.cached_data())
                    .map_disconnect(|node| node.cached_data())
                    .map_witness(Arc::clone);
                Ok(Arc::new(RedeemData::new(arrow, converted_data)))
            }
        }

        // 1. Decode the program structure and assign it the program type.
        let construct = crate::decode::decode_expression(bits)?;
        construct.set_arrow_to_program()?;

        // 2. Read the (optional) witness length, then the witness values while
        //    finalizing the program.
        let witness_len = if bits.read_bit()? {
            bits.read_natural(None)?
        } else {
            0
        };
        let witness_start = bits.n_total_read();

        let program: Arc<Self> =
            construct.convert::<InternalSharing, _, _>(&mut DecodeFinalizer {
                bits,
                phantom: PhantomData,
            })?;

        // 3. Check that exactly `witness_len` bits of witness data were consumed.
        //    `witness_len` comes from untrusted input, so use a checked addition:
        //    an unrepresentable end position is an inconsistency, not a panic.
        let witness_end = witness_start.checked_add(witness_len);
        if witness_end != Some(bits.n_total_read()) {
            return Err(Error::InconsistentWitnessLength);
        }

        // 4. Reject programs in which two distinct nodes have the same IMR.
        let mut imrs: HashSet<Imr> = HashSet::new();
        for data in program.as_ref().post_order_iter::<InternalSharing>() {
            if !imrs.insert(data.node.imr()) {
                return Err(Error::Decode(crate::decode::Error::SharingNotMaximal));
            }
        }

        Ok(program)
    }

    /// Encodes the program, followed by its witness values, into `w`, then
    /// flushes `w`.
    ///
    /// Witness values are taken from a `MaxSharing` post-order traversal.
    /// Returns the sum of the counts reported by `encode_program` and
    /// `encode_witness`.
    ///
    /// # Errors
    ///
    /// Propagates any I/O error from encoding or flushing.
    pub fn encode<W: io::Write>(&self, w: &mut BitWriter<W>) -> io::Result<usize> {
        let sharing_iter = self.post_order_iter::<MaxSharing<Redeem<J>>>();
        let program_bits = encode::encode_program(self, w)?;
        let witness_bits =
            encode::encode_witness(sharing_iter.into_witnesses().map(Arc::as_ref), w)?;
        w.flush_all()?;
        Ok(program_bits + witness_bits)
    }

    /// Encodes the program and its witness values (see [`Self::encode`]) into a
    /// new byte vector.
    pub fn encode_to_vec(&self) -> Vec<u8> {
        write_to_vec(|w| self.encode(w))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    use hex::DisplayHex;
    use std::fmt;

    use crate::jet::Core;
    use crate::node::SimpleFinalizer;

    /// Decodes `prog_bytes` as a redeem program and checks that:
    /// - its CMR, AMR and IMR match the expected hex strings;
    /// - re-encoding it reproduces `prog_bytes` exactly.
    ///
    /// Returns the decoded program.
    fn assert_program_deserializable<J: Jet>(
        prog_bytes: &[u8],
        cmr_str: &str,
        amr_str: &str,
        imr_str: &str,
    ) -> Arc<RedeemNode<J>> {
        let prog_hex = prog_bytes.as_hex();

        let mut iter = BitIter::from(prog_bytes);
        let prog = match RedeemNode::<J>::decode(&mut iter) {
            Ok(prog) => prog,
            Err(e) => panic!("program {} failed: {}", prog_hex, e),
        };

        assert_eq!(
            prog.cmr().to_string(),
            cmr_str,
            "CMR mismatch (got {} expected {}) for program {}",
            prog.cmr(),
            cmr_str,
            prog_hex,
        );
        assert_eq!(
            prog.amr().to_string(),
            amr_str,
            "AMR mismatch (got {} expected {}) for program {}",
            prog.amr(),
            amr_str,
            prog_hex,
        );
        assert_eq!(
            prog.imr().to_string(),
            imr_str,
            "IMR mismatch (got {} expected {}) for program {}",
            prog.imr(),
            imr_str,
            prog_hex,
        );

        let reser_sink = prog.encode_to_vec();
        assert_eq!(
            prog_bytes,
            &reser_sink[..],
            "program {} reserialized as {}",
            prog_hex,
            reser_sink.as_hex(),
        );

        prog
    }

    /// Checks that decoding `prog` fails with an error whose `Display` output
    /// equals that of `err`.
    fn assert_program_not_deserializable<J: Jet>(prog: &[u8], err: &dyn fmt::Display) {
        let prog_hex = prog.as_hex();
        let err_str = err.to_string();

        let mut iter = BitIter::from(prog);
        match RedeemNode::<J>::decode(&mut iter) {
            Ok(prog) => panic!(
                "Program {} succeeded (expected error {}). Program parsed as:\n{}",
                prog_hex, err, prog
            ),
            Err(e) if e.to_string() == err_str => {}
            Err(e) => panic!(
                "Program {} failed with error {} (expected error {})",
                prog_hex, e, err
            ),
        };
    }

    #[test]
    fn encode_shared_witnesses() {
        let eqwits = [0xcd, 0xdc, 0x51, 0xb6, 0xe2, 0x08, 0xc0, 0x40];
        let mut iter = BitIter::from(&eqwits[..]);
        let eqwits_prog = CommitNode::<Core>::decode(&mut iter).unwrap();

        let eqwits_final = eqwits_prog
            .finalize(&mut SimpleFinalizer::new(std::iter::repeat(Value::u32(
                0xDEADBEEF,
            ))))
            .unwrap();
        let output = eqwits_final.encode_to_vec();

        assert_eq!(
            output,
            [0xc9, 0xc4, 0x6d, 0xb8, 0x82, 0x30, 0x11, 0xe2, 0x0d, 0xea, 0xdb, 0xee, 0xf0],
        );
    }

    #[test]
    fn decode_shared_witnesses() {
        assert_program_deserializable::<Core>(
            &[
                0xc9, 0xc4, 0x6d, 0xb8, 0x82, 0x30, 0x11, 0xe2, 0x0d, 0xea, 0xdb, 0xee, 0xf0,
            ],
            "2d170e731b6d6856e69f3c6ee04b368302f7f71b2270a26276d98ea494bbebd7",
            "9bdb88f9a9ef64d5ec507af96e5b88ae3a8b09c042cb3c3563f982cafc572bae",
            "71cdfd26a3f4dd865e2e92b526fc2083260c964c52dd9773aa52771f253b73e1",
        );
    }

    #[test]
    fn unshared_child() {
        assert_program_not_deserializable::<Core>(
            &[0xc1, 0x08, 0x04, 0x00, 0x00, 0x74, 0x74, 0x74],
            &Error::Decode(crate::decode::Error::SharingNotMaximal),
        );
    }

    #[test]
    fn witness_consumed() {
        let badwit = [0x27, 0x00];
        let mut iter = BitIter::from(&badwit[..]);
        let result = RedeemNode::<Core>::decode(&mut iter);
        assert!(
            matches!(result, Err(Error::InconsistentWitnessLength)),
            "accepted program with bad witness length"
        );
    }

    #[test]
    fn shared_grandchild() {
        assert_program_deserializable::<Core>(
            &[0xc1, 0x00, 0x00, 0x01, 0x00],
            "c2c86be0081a9c75af49098f359c7efdfa7ccbd0459adb11bcf676b80c8644b1",
            "e053520f0c3219d1cabd705b4523ccd05c8d703a70f6f3994a20774a42b5ccfc",
            "7b0ad0514279280d5c2ac1da729222936b8768d9f465c6c6ade3b0ed7dc97263",
        );
    }

    #[test]
    #[rustfmt::skip]
    fn assert_lr() {
        assert_program_deserializable::<Core>(
            &[
                0xcd, 0x24, 0x08, 0x4b, 0x6f, 0x56, 0xdf, 0x77,
                0xef, 0x56, 0xdf, 0x77, 0xef, 0x56, 0xdf, 0x77,
                0xef, 0x56, 0xdf, 0x77, 0xef, 0x56, 0xdf, 0x77,
                0xef, 0x56, 0xdf, 0x77, 0xef, 0x56, 0xdf, 0x77,
                0xef, 0x56, 0xdf, 0x77, 0x86, 0x01, 0x80,
            ],
            "c7194362a5480900dd44f9f647a49b8adcb92a25fb293c920e6bbcf6977cf63d",
            "eaf95c23d967563132b65e43578fe08dae2a29ac66775ddd37af3ac7de28678b",
            "d2927a9a54ddea8359ee00aa27e0aa1e354cc6924b090c759e2ed686712700a0",
        );

        assert_program_deserializable::<Core>(
            &[
                0xcd, 0x25, 0x08, 0x6d, 0xea, 0xdb, 0xee, 0xfd,
                0xea, 0xdb, 0xee, 0xfd, 0xea, 0xdb, 0xee, 0xfd,
                0xea, 0xdb, 0xee, 0xfd, 0xea, 0xdb, 0xee, 0xfd,
                0xea, 0xdb, 0xee, 0xfd, 0xea, 0xdb, 0xee, 0xfd,
                0xea, 0xdb, 0xee, 0xf4, 0x86, 0x01, 0x80,
            ],
            "8e471ac519e0b16a2b7dda7e8d68165f260cae4823861ddc494b7c73a615b212",
            "ea1ee417816a57b80739520c7319c33a39a5f4ce7b59856e69f768d5d8f174a6",
            "f262f83f1c9341390e015e4c5126f3954e17a1f275af73da2948eaf4797fda48",
        );
    }

    #[test]
    #[rustfmt::skip]
    fn disconnect() {
        assert_program_deserializable::<Core>(
            &[0xc5, 0x02, 0x06, 0x24, 0x10],
            "a453360c0825cc2d3c4c907d67b174273b0e0386c7e5ecdb28394a8f37fd68b9",
            "d5b05a5da87ee490312279496e12e17bc987c98219d8961bc3a7c3ec95a7ce1e",
            "3579ae2a05bbe689f16bd3ff29d840ae8aa8bbad70f6de27b7473746637abeb6",
        );
    }

    #[test]
    #[rustfmt::skip]
    #[cfg(feature = "elements")]
    fn disconnect2() {
        assert_program_deserializable::<crate::jet::Elements>(
            &[
                0xd3, 0x69, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
                0x00, 0x00, 0x00, 0x00, 0x00, 0x3b, 0x78, 0xce,
                0x56, 0x3f, 0x89, 0xa0, 0xed, 0x94, 0x14, 0xf5,
                0xaa, 0x28, 0xad, 0x0d, 0x96, 0xd6, 0x79, 0x5f,
                0x9c, 0x63, 0x47, 0x07, 0x02, 0xc0, 0xe2, 0x8d,
                0x88, 0x11, 0xe9, 0x00, 0x00, 0x00, 0x00, 0x00,
                0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x1d,
                0xbc, 0x67, 0x2b, 0x1f, 0xc4, 0xd0, 0x76, 0xca,
                0x0a, 0x7a, 0xd5, 0x14, 0x56, 0x86, 0xcb, 0x6b,
                0x3c, 0xaf, 0xce, 0x31, 0xed, 0xc3, 0x46, 0xa2,
                0xd0, 0x5e, 0x0e, 0x8c, 0x80, 0x98, 0x15, 0xe4,
                0x3d, 0x43, 0x8e, 0x78, 0xac, 0x71, 0x5e, 0xf1,
                0x67, 0xd3, 0x22, 0xd4, 0x4a, 0xe0, 0xda, 0x2e,
                0xb4, 0x75, 0x12, 0x60, 0x00,
            ],
            "3c77e90bcf5ff2bf45f6f30ecb093da96ff22509b5e981af0c21dddb84eec184",
            "8e1ea76972cfe1684295784a59cb3c7229c9ab64bcdbc159278a7092b625d67c",
            "dfb28b5859be539546f4fe9ce8c89083f021c76895be684d337087ffcfb4a7af",
        );
    }

    #[test]
    #[rustfmt::skip]
    #[cfg(feature = "elements")]
    fn disconnect3() {
        assert_program_deserializable::<crate::jet::Elements>(
            &[0xc9, 0x09, 0x20, 0x74, 0x90, 0x40],
            "a8c9cc7a83518d0886afe1078d88eabca8353509e8c2e3b5c72cf559c713c9f5",
            "97f77a7e7d7f3b2b1ac790bf54b39d47d6db8dcab7ed3c0a48df12f2c940af58",
            "ed8152948589d65e0dea6d84f90eb752f63df818041f46bdc8f959f33299cbd3",
        );
    }

    #[test]
    #[cfg(feature = "elements")]
    fn decode_schnorr() {
        #[rustfmt::skip]
        let schnorr0 = vec![
            0xc6, 0xd5, 0xf2, 0x61, 0x14, 0x03, 0x24, 0xb1, 0x86, 0x20, 0x92, 0x68, 0x9f, 0x0b, 0xf1, 0x3a,
            0xa4, 0x53, 0x6a, 0x63, 0x90, 0x8b, 0x06, 0xdf, 0x33, 0x61, 0x0c, 0x03, 0xe2, 0x27, 0x79, 0xc0,
            0x6d, 0xf2, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
            0x00, 0x00, 0xe2, 0x8d, 0x8c, 0x04, 0x7a, 0x40, 0x1d, 0x20, 0xf0, 0x63, 0xf0, 0x10, 0x91, 0xa2,
            0x0d, 0x34, 0xa6, 0xe3, 0x68, 0x04, 0x82, 0x06, 0xc9, 0x7b, 0xe3, 0x8b, 0xf0, 0x60, 0xf6, 0x01,
            0x09, 0x8a, 0xbe, 0x39, 0xc5, 0xb9, 0x50, 0x42, 0xa4, 0xbe, 0xcd, 0x49, 0x50, 0xbd, 0x51, 0x6e,
            0x3c, 0x90, 0x54, 0xe9, 0xe7, 0x05, 0xa5, 0x9c, 0xbd, 0x7d, 0xdd, 0x1f, 0xb6, 0x42, 0xe5, 0xe8,
            0xef, 0xbe, 0x92, 0x01, 0xa6, 0x20, 0xa6, 0xd8, 0x00
        ];

        assert_program_deserializable::<crate::jet::Elements>(
            &schnorr0,
            "dacbdfcf64122edf8efda2b34fe353cac4424dd455a9204fc92af258b465bbc4",
            "097f231c68c5cd55fc23c70c6101463d3547046e62b90c43ed65c4c1c2aeea91",
            "190bfc6677d227f1301ab6694f4de230b02277a8d2936517bddf9ebd16dc8250",
        );
    }
}