use crate::dag::{InternalSharing, PostOrderIterItem};
use crate::encode;
use crate::jet::Jet;
use crate::types::{self, arrow::Arrow};
use crate::{BitIter, BitWriter, Cmr, FailEntropy, Value};
use std::io;
use std::marker::PhantomData;
use std::sync::Arc;
use super::{
    Commit, CommitData, CommitNode, Converter, Inner, Marker, NoDisconnect, NoWitness, Node,
};
use super::{CoreConstructible, DisconnectConstructible, JetConstructible, WitnessConstructible};
/// Sharing ID type for construct-stage nodes.
///
/// This enum has no variants, so no value of it can ever exist. It serves as
/// the `SharingId` of [`Construct`], whose `compute_sharing_id` always
/// returns `None`.
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Hash)]
pub enum ConstructId {}
/// Marker type selecting the "construct" stage for [`Node`].
///
/// Because it holds an [`std::convert::Infallible`] field, no value of this
/// type can be created; it is used purely at the type level, parameterized by
/// the jet type `J`.
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct Construct<J> {
    never: std::convert::Infallible,
    phantom: PhantomData<J>,
}
impl<J: Jet> Marker for Construct<J> {
    type Jet = J;
    /// Per-node data: the node's (not yet finalized) type arrow.
    type CachedData = ConstructData<J>;
    /// Construct-stage nodes carry no witness values.
    type Witness = NoWitness;
    /// The right child of a `disconnect` is optional at this stage.
    type Disconnect = Option<Arc<ConstructNode<J>>>;
    type SharingId = ConstructId;

    /// Construct-stage nodes never yield a sharing ID; this always returns `None`.
    fn compute_sharing_id(_cmr: Cmr, _data: &ConstructData<J>) -> Option<ConstructId> {
        None
    }
}
/// A node of an expression that is still under construction.
///
/// Its cached data is a [`ConstructData`] holding a type arrow that has not yet
/// been finalized (see [`ConstructNode::finalize_types`]). It carries no
/// witness values, and the right child of a `disconnect` is optional.
pub type ConstructNode<J> = Node<Construct<J>>;
impl<J: Jet> ConstructNode<J> {
    pub fn arrow(&self) -> &Arrow {
        self.data.arrow()
    }
    pub fn set_arrow_to_program(&self) -> Result<(), types::Error> {
        let unit_ty = types::Type::unit();
        self.arrow()
            .source
            .unify(&unit_ty, "setting root source to unit")?;
        self.arrow()
            .target
            .unify(&unit_ty, "setting root target to unit")?;
        Ok(())
    }
    pub fn finalize_types(&self) -> Result<Arc<CommitNode<J>>, crate::Error> {
        self.set_arrow_to_program()?;
        self.finalize_types_non_program()
    }
    pub fn finalize_types_non_program(&self) -> Result<Arc<CommitNode<J>>, crate::Error> {
        struct FinalizeTypes<J: Jet>(PhantomData<J>);
        impl<J: Jet> Converter<Construct<J>, Commit<J>> for FinalizeTypes<J> {
            type Error = crate::Error;
            fn convert_witness(
                &mut self,
                _: &PostOrderIterItem<&ConstructNode<J>>,
                _: &NoWitness,
            ) -> Result<NoWitness, Self::Error> {
                Ok(NoWitness)
            }
            fn convert_disconnect(
                &mut self,
                _: &PostOrderIterItem<&ConstructNode<J>>,
                maybe_converted: Option<&Arc<CommitNode<J>>>,
                _: &Option<Arc<ConstructNode<J>>>,
            ) -> Result<NoDisconnect, Self::Error> {
                if maybe_converted.is_some() {
                    Err(crate::Error::DisconnectCommitTime)
                } else {
                    Ok(NoDisconnect)
                }
            }
            fn convert_data(
                &mut self,
                data: &PostOrderIterItem<&ConstructNode<J>>,
                inner: Inner<&Arc<CommitNode<J>>, J, &NoDisconnect, &NoWitness>,
            ) -> Result<Arc<CommitData<J>>, Self::Error> {
                let converted_data = inner.map(|node| node.cached_data());
                CommitData::new(&data.node.data.arrow, converted_data)
                    .map(Arc::new)
                    .map_err(crate::Error::from)
            }
        }
        self.convert::<InternalSharing, _, _>(&mut FinalizeTypes(PhantomData))
    }
    pub fn decode<I: Iterator<Item = u8>>(
        bits: &mut BitIter<I>,
    ) -> Result<Arc<Self>, crate::decode::Error> {
        crate::decode::decode_expression(bits)
    }
    pub fn encode<W: io::Write>(&self, w: &mut BitWriter<W>) -> io::Result<usize> {
        let program_bits = encode::encode_program(self, w)?;
        w.flush_all()?;
        Ok(program_bits)
    }
}
/// Cached data of a [`ConstructNode`]: the node's type arrow.
///
/// `J` is the jet type; it is tracked only through `PhantomData`, so no jet
/// value is stored here.
#[derive(Clone, Debug)]
pub struct ConstructData<J> {
    // Source/target type of the node; not finalized until the node is
    // converted by `ConstructNode::finalize_types*`.
    arrow: Arrow,
    phantom: PhantomData<J>,
}
impl<J: Jet> ConstructData<J> {
    pub fn new(arrow: Arrow) -> Self {
        ConstructData {
            arrow,
            phantom: PhantomData,
        }
    }
    pub fn arrow(&self) -> &Arrow {
        &self.arrow
    }
}
impl<J> CoreConstructible for ConstructData<J> {
    fn iden() -> Self {
        ConstructData {
            arrow: Arrow::iden(),
            phantom: PhantomData,
        }
    }
    fn unit() -> Self {
        ConstructData {
            arrow: Arrow::unit(),
            phantom: PhantomData,
        }
    }
    fn injl(child: &Self) -> Self {
        ConstructData {
            arrow: Arrow::injl(&child.arrow),
            phantom: PhantomData,
        }
    }
    fn injr(child: &Self) -> Self {
        ConstructData {
            arrow: Arrow::injr(&child.arrow),
            phantom: PhantomData,
        }
    }
    fn take(child: &Self) -> Self {
        ConstructData {
            arrow: Arrow::take(&child.arrow),
            phantom: PhantomData,
        }
    }
    fn drop_(child: &Self) -> Self {
        ConstructData {
            arrow: Arrow::drop_(&child.arrow),
            phantom: PhantomData,
        }
    }
    fn comp(left: &Self, right: &Self) -> Result<Self, types::Error> {
        Ok(ConstructData {
            arrow: Arrow::comp(&left.arrow, &right.arrow)?,
            phantom: PhantomData,
        })
    }
    fn case(left: &Self, right: &Self) -> Result<Self, types::Error> {
        Ok(ConstructData {
            arrow: Arrow::case(&left.arrow, &right.arrow)?,
            phantom: PhantomData,
        })
    }
    fn assertl(left: &Self, right: Cmr) -> Result<Self, types::Error> {
        Ok(ConstructData {
            arrow: Arrow::assertl(&left.arrow, right)?,
            phantom: PhantomData,
        })
    }
    fn assertr(left: Cmr, right: &Self) -> Result<Self, types::Error> {
        Ok(ConstructData {
            arrow: Arrow::assertr(left, &right.arrow)?,
            phantom: PhantomData,
        })
    }
    fn pair(left: &Self, right: &Self) -> Result<Self, types::Error> {
        Ok(ConstructData {
            arrow: Arrow::pair(&left.arrow, &right.arrow)?,
            phantom: PhantomData,
        })
    }
    fn fail(entropy: FailEntropy) -> Self {
        ConstructData {
            arrow: Arrow::fail(entropy),
            phantom: PhantomData,
        }
    }
    fn const_word(word: Arc<Value>) -> Self {
        ConstructData {
            arrow: Arrow::const_word(word),
            phantom: PhantomData,
        }
    }
}
impl<J: Jet> DisconnectConstructible<Option<Arc<ConstructNode<J>>>> for ConstructData<J> {
    /// Builds the arrow of `disconnect(left, right)`.
    ///
    /// `right` may be absent at construction time; in that case `None` is
    /// passed on to [`Arrow::disconnect`] in place of the right child's arrow.
    fn disconnect(
        left: &Self,
        right: &Option<Arc<ConstructNode<J>>>,
    ) -> Result<Self, types::Error> {
        let right_arrow = right.as_ref().map(|node| node.arrow());
        let arrow = Arrow::disconnect(&left.arrow, &right_arrow)?;
        Ok(Self { arrow, phantom: PhantomData })
    }
}
impl<J> WitnessConstructible<NoWitness> for ConstructData<J> {
    /// Builds the arrow of a witness node; no witness value is stored.
    fn witness(witness: NoWitness) -> Self {
        let arrow = Arrow::witness(witness);
        Self { arrow, phantom: PhantomData }
    }
}
impl<J: Jet> JetConstructible<J> for ConstructData<J> {
    /// Builds the arrow of a node that invokes `jet`.
    fn jet(jet: J) -> Self {
        let arrow = Arrow::jet(jet);
        Self { arrow, phantom: PhantomData }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::jet::Core;

    /// Node handle type used by every test in this module.
    type TestNode = Arc<ConstructNode<Core>>;

    /// Asserts that finalizing `node`'s types, without forcing it to be a
    /// program, fails with an occurs-check error.
    fn assert_occurs_check(node: &TestNode) {
        assert!(matches!(
            node.finalize_types_non_program(),
            Err(crate::Error::Type(types::Error::OccursCheck)),
        ));
    }

    #[test]
    fn occurs_check_error() {
        let id = TestNode::iden();
        let looped = TestNode::disconnect(&id, &Some(Arc::clone(&id))).unwrap();
        assert_occurs_check(&looped);
    }

    #[test]
    fn occurs_check_2() {
        let id = TestNode::iden();
        let right = TestNode::injr(&id);
        let paired = TestNode::pair(&right, &id).unwrap();
        let dropped = TestNode::drop_(&paired);
        let case_a = TestNode::case(&dropped, &dropped).unwrap();
        let case_b = TestNode::case(&case_a, &case_a).unwrap();
        let comp_a = TestNode::comp(&case_b, &case_b).unwrap();
        let comp_b = TestNode::comp(&comp_a, &case_a).unwrap();
        assert_occurs_check(&comp_b);
    }

    #[test]
    fn occurs_check_3() {
        let witness = TestNode::witness(NoWitness);
        let dropped = TestNode::drop_(&witness);
        let c1 = TestNode::comp(&dropped, &dropped).unwrap();
        let c2 = TestNode::comp(&c1, &c1).unwrap();
        let c3 = TestNode::comp(&c2, &c2).unwrap();
        let c4 = TestNode::comp(&c3, &c3).unwrap();
        let c5 = TestNode::comp(&c4, &c4).unwrap();
        let branch = TestNode::case(&c5, &c4).unwrap();
        let dropped_branch = TestNode::drop_(&branch);
        let branch2 = TestNode::case(&dropped_branch, &branch).unwrap();
        let c6 = TestNode::comp(&branch2, &branch2).unwrap();
        let branch3 = TestNode::case(&c6, &c6).unwrap();
        let c7 = TestNode::comp(&branch3, &branch3).unwrap();
        let c8 = TestNode::comp(&c7, &c7).unwrap();
        assert_occurs_check(&c8);
    }

    #[test]
    fn type_check_error() {
        let unit = TestNode::unit();
        let case_unit = TestNode::case(&unit, &unit).unwrap();
        let result = TestNode::disconnect(&case_unit, &Some(unit));
        assert!(matches!(result, Err(types::Error::Bind { .. })));
    }

    #[test]
    fn scribe() {
        let unit = TestNode::unit();
        let bit0 = TestNode::injl(&unit);
        let bit1 = TestNode::injr(&unit);
        let bits01 = TestNode::pair(&bit0, &bit1).unwrap();

        let scribed_unit = TestNode::scribe(&Value::Unit);
        let scribed_bit0 = TestNode::scribe(&Value::u1(0));
        let scribed_bit1 = TestNode::scribe(&Value::u1(1));
        let scribed_bits01 = TestNode::scribe(&Value::u2(1));

        assert_eq!(unit.cmr(), scribed_unit.cmr());
        assert_eq!(bit0.cmr(), scribed_bit0.cmr());
        assert_eq!(bit1.cmr(), scribed_bit1.cmr());
        assert_eq!(bits01.cmr(), scribed_bits01.cmr());
    }
}