mod frame;
use crate::analysis;
use crate::dag::{DagLike, NoSharing};
use crate::jet::{Jet, JetFailed};
use crate::node::{self, RedeemNode};
use crate::{Cmr, FailEntropy, Value};
use frame::Frame;
use std::fmt;
use std::sync::Arc;
use std::{cmp, error};
pub struct BitMachine {
    data: Vec<u8>,
    next_frame_start: usize,
    read: Vec<Frame>,
    write: Vec<Frame>,
}
impl BitMachine {
    pub fn for_program<J: Jet>(program: &RedeemNode<J>) -> Self {
        let io_width = program.arrow().source.bit_width() + program.arrow().target.bit_width();
        Self {
            data: vec![0; (io_width + program.bounds().extra_cells + 7) / 8],
            next_frame_start: 0,
            read: Vec::with_capacity(program.bounds().extra_frames + analysis::IO_EXTRA_FRAMES),
            write: Vec::with_capacity(program.bounds().extra_frames + analysis::IO_EXTRA_FRAMES),
        }
    }
    #[cfg(test)]
    pub fn test_exec<J: Jet>(
        program: Arc<crate::node::ConstructNode<J>>,
        env: &J::Environment,
    ) -> Result<Arc<Value>, ExecutionError> {
        use crate::node::SimpleFinalizer;
        let prog = program
            .finalize_types_non_program()
            .expect("finalizing types")
            .finalize(&mut SimpleFinalizer::new(None.into_iter()))
            .expect("finalizing");
        let mut mac = BitMachine::for_program(&prog);
        mac.exec(&prog, env)
    }
    fn new_frame(&mut self, len: usize) {
        debug_assert!(
            self.next_frame_start + len <= self.data.len() * 8,
            "Data out of bounds: number of cells"
        );
        debug_assert!(
            self.write.len() + self.read.len() < self.read.capacity(),
            "Stacks out of bounds: number of frames"
        );
        self.write.push(Frame::new(self.next_frame_start, len));
        self.next_frame_start += len;
    }
    fn move_frame(&mut self) {
        let mut _active_write_frame = self.write.pop().unwrap();
        _active_write_frame.reset_cursor();
        self.read.push(_active_write_frame);
    }
    fn drop_frame(&mut self) {
        let active_read_frame = self.read.pop().unwrap();
        self.next_frame_start -= active_read_frame.len;
        assert_eq!(self.next_frame_start, active_read_frame.start);
    }
    fn write_bit(&mut self, bit: bool) {
        self.write
            .last_mut()
            .expect("Empty write frame stack")
            .write_bit(bit, &mut self.data);
    }
    fn skip(&mut self, n: usize) {
        if n == 0 {
            return;
        }
        let idx = self.write.len() - 1;
        self.write[idx].move_cursor_forward(n);
    }
    fn copy(&mut self, n: usize) {
        if n == 0 {
            return;
        }
        let widx = self.write.len() - 1;
        let ridx = self.read.len() - 1;
        self.write[widx].copy_from(&self.read[ridx], n, &mut self.data);
    }
    fn fwd(&mut self, n: usize) {
        if n == 0 {
            return;
        }
        let idx = self.read.len() - 1;
        self.read[idx].move_cursor_forward(n);
    }
    fn back(&mut self, n: usize) {
        if n == 0 {
            return;
        }
        let idx = self.read.len() - 1;
        self.read[idx].move_cursor_backward(n);
    }
    fn write_u8(&mut self, value: u8) {
        self.write
            .last_mut()
            .expect("Empty write frame stack")
            .write_u8(value, &mut self.data);
    }
    fn read_bit(&mut self) -> bool {
        self.read
            .last_mut()
            .expect("Empty read frame stack")
            .read_bit(&self.data)
    }
    fn write_bytes(&mut self, bytes: &[u8]) {
        for bit in bytes {
            self.write_u8(*bit);
        }
    }
    fn write_value(&mut self, val: &Value) {
        for val in val.pre_order_iter::<NoSharing>() {
            match val {
                Value::Unit => {}
                Value::SumL(..) => self.write_bit(false),
                Value::SumR(..) => self.write_bit(true),
                Value::Prod(..) => {}
            }
        }
    }
    pub fn input(&mut self, input: &Value) {
        self.new_frame(input.len());
        self.write_value(input);
        self.move_frame();
    }
    pub fn exec<J: Jet + std::fmt::Debug>(
        &mut self,
        program: &RedeemNode<J>,
        env: &J::Environment,
    ) -> Result<Arc<Value>, ExecutionError> {
        enum CallStack<'a, J: Jet> {
            Goto(&'a RedeemNode<J>),
            MoveFrame,
            DropFrame,
            CopyFwd(usize),
            Back(usize),
        }
        impl<'a, J: Jet> fmt::Debug for CallStack<'a, J> {
            fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
                match self {
                    CallStack::Goto(ins) => write!(f, "goto {}", ins.inner()),
                    CallStack::MoveFrame => f.write_str("move frame"),
                    CallStack::DropFrame => f.write_str("drop frame"),
                    CallStack::CopyFwd(n) => write!(f, "copy/fwd {}", n),
                    CallStack::Back(n) => write!(f, "back {}", n),
                }
            }
        }
        let mut ip = program;
        let mut call_stack = vec![];
        let mut iterations = 0u64;
        let input_width = ip.arrow().source.bit_width();
        assert!(
            self.read.is_empty() || input_width > 0,
            "Program requires a non-empty input to execute",
        );
        let output_width = ip.arrow().target.bit_width();
        if output_width > 0 {
            self.new_frame(output_width);
        }
        'main_loop: loop {
            iterations += 1;
            if iterations % 1_000_000_000 == 0 {
                println!("({:5} M) exec {:?}", iterations / 1_000_000, ip);
            }
            match ip.inner() {
                node::Inner::Unit => {}
                node::Inner::Iden => {
                    let size_a = ip.arrow().source.bit_width();
                    self.copy(size_a);
                }
                node::Inner::InjL(left) => {
                    let (b, _c) = ip.arrow().target.split_sum().unwrap();
                    let padl_b_c = ip.arrow().target.bit_width() - b.bit_width() - 1;
                    self.write_bit(false);
                    self.skip(padl_b_c);
                    call_stack.push(CallStack::Goto(left));
                }
                node::Inner::InjR(left) => {
                    let (_b, c) = ip.arrow().target.split_sum().unwrap();
                    let padr_b_c = ip.arrow().target.bit_width() - c.bit_width() - 1;
                    self.write_bit(true);
                    self.skip(padr_b_c);
                    call_stack.push(CallStack::Goto(left));
                }
                node::Inner::Pair(left, right) => {
                    call_stack.push(CallStack::Goto(right));
                    call_stack.push(CallStack::Goto(left));
                }
                node::Inner::Comp(left, right) => {
                    let size_b = left.arrow().target.bit_width();
                    self.new_frame(size_b);
                    call_stack.push(CallStack::DropFrame);
                    call_stack.push(CallStack::Goto(right));
                    call_stack.push(CallStack::MoveFrame);
                    call_stack.push(CallStack::Goto(left));
                }
                node::Inner::Disconnect(left, right) => {
                    let size_prod_256_a = left.arrow().source.bit_width();
                    let size_a = size_prod_256_a - 256;
                    let size_prod_b_c = left.arrow().target.bit_width();
                    let size_b = size_prod_b_c - right.arrow().source.bit_width();
                    self.new_frame(size_prod_256_a);
                    self.write_bytes(right.cmr().as_ref());
                    self.copy(size_a);
                    self.move_frame();
                    self.new_frame(size_prod_b_c);
                    call_stack.push(CallStack::DropFrame);
                    call_stack.push(CallStack::DropFrame);
                    call_stack.push(CallStack::Goto(right));
                    call_stack.push(CallStack::CopyFwd(size_b));
                    call_stack.push(CallStack::MoveFrame);
                    call_stack.push(CallStack::Goto(left));
                }
                node::Inner::Take(left) => call_stack.push(CallStack::Goto(left)),
                node::Inner::Drop(left) => {
                    let size_a = ip.arrow().source.split_product().unwrap().0.bit_width();
                    self.fwd(size_a);
                    call_stack.push(CallStack::Back(size_a));
                    call_stack.push(CallStack::Goto(left));
                }
                node::Inner::Case(..) | node::Inner::AssertL(..) | node::Inner::AssertR(..) => {
                    let choice_bit = self.read[self.read.len() - 1].peek_bit(&self.data);
                    let (sum_a_b, _c) = ip.arrow().source.split_product().unwrap();
                    let (a, b) = sum_a_b.split_sum().unwrap();
                    let size_a = a.bit_width();
                    let size_b = b.bit_width();
                    match (ip.inner(), choice_bit) {
                        (node::Inner::Case(_, right), true)
                        | (node::Inner::AssertR(_, right), true) => {
                            let padr_a_b = cmp::max(size_a, size_b) - size_b;
                            self.fwd(1 + padr_a_b);
                            call_stack.push(CallStack::Back(1 + padr_a_b));
                            call_stack.push(CallStack::Goto(right));
                        }
                        (node::Inner::Case(left, _), false)
                        | (node::Inner::AssertL(left, _), false) => {
                            let padl_a_b = cmp::max(size_a, size_b) - size_a;
                            self.fwd(1 + padl_a_b);
                            call_stack.push(CallStack::Back(1 + padl_a_b));
                            call_stack.push(CallStack::Goto(left));
                        }
                        (node::Inner::AssertL(_, r_cmr), true) => {
                            return Err(ExecutionError::ReachedPrunedBranch(*r_cmr))
                        }
                        (node::Inner::AssertR(l_cmr, _), false) => {
                            return Err(ExecutionError::ReachedPrunedBranch(*l_cmr))
                        }
                        _ => unreachable!(),
                    }
                }
                node::Inner::Witness(value) => self.write_value(value),
                node::Inner::Jet(jet) => self.exec_jet(*jet, env)?,
                node::Inner::Word(value) => self.write_value(value),
                node::Inner::Fail(entropy) => {
                    return Err(ExecutionError::ReachedFailNode(*entropy))
                }
            }
            ip = loop {
                match call_stack.pop() {
                    Some(CallStack::Goto(next)) => break next,
                    Some(CallStack::MoveFrame) => self.move_frame(),
                    Some(CallStack::DropFrame) => self.drop_frame(),
                    Some(CallStack::CopyFwd(n)) => {
                        self.copy(n);
                        self.fwd(n);
                    }
                    Some(CallStack::Back(n)) => self.back(n),
                    None => break 'main_loop,
                };
            };
        }
        if output_width > 0 {
            let out_frame = self.write.last_mut().unwrap();
            out_frame.reset_cursor();
            let value = out_frame
                .as_bit_iter(&self.data)
                .read_value(&program.arrow().target)
                .expect("Decode value of output frame");
            Ok(value)
        } else {
            Ok(Value::unit())
        }
    }
    fn exec_jet<J: Jet>(&mut self, jet: J, env: &J::Environment) -> Result<(), JetFailed> {
        use simplicity_sys::c_jets::frame_ffi::{c_readBit, c_writeBit, CFrameItem};
        use simplicity_sys::c_jets::round_u_word;
        if !simplicity_sys::c_jets::sanity_checks() {
            return Err(JetFailed);
        }
        let src_ty_bit_width = jet.source_ty().to_bit_width();
        let target_ty_bit_width = jet.target_ty().to_bit_width();
        let a_frame_size = round_u_word(src_ty_bit_width);
        let b_frame_size = round_u_word(target_ty_bit_width);
        if a_frame_size == 0 && b_frame_size == 0 {
            return Ok(());
        }
        let mut src_buf = vec![0usize; a_frame_size + b_frame_size];
        let src_ptr_end = unsafe { src_buf.as_mut_ptr().add(a_frame_size) }; let src_ptr = src_buf.as_mut_ptr(); let dst_ptr_begin = unsafe { src_buf.as_mut_ptr().add(a_frame_size) }; let dst_ptr_end = unsafe { src_buf.as_mut_ptr().add(a_frame_size + b_frame_size) }; let mut a_frame = unsafe { CFrameItem::new_write(src_ty_bit_width, src_ptr_end) };
        for _ in 0..src_ty_bit_width {
            let bit = self.read_bit();
            unsafe {
                c_writeBit(&mut a_frame, bit);
            }
        }
        self.back(src_ty_bit_width);
        let src_frame = unsafe { CFrameItem::new_read(src_ty_bit_width, src_ptr) };
        let mut dst_frame = unsafe { CFrameItem::new_write(target_ty_bit_width, dst_ptr_end) };
        let jet_fn = jet.c_jet_ptr();
        let c_env = jet.c_jet_env(env);
        let res = jet_fn(&mut dst_frame, src_frame, c_env);
        if !res {
            return Err(JetFailed);
        }
        let mut b_frame = unsafe { CFrameItem::new_read(target_ty_bit_width, dst_ptr_begin) };
        for _ in 0..target_ty_bit_width {
            let bit = unsafe { c_readBit(&mut b_frame) };
            self.write_bit(bit);
        }
        Ok(())
    }
}
#[derive(Debug)]
pub enum ExecutionError {
    ReachedFailNode(FailEntropy),
    ReachedPrunedBranch(Cmr),
    JetFailed(JetFailed),
}
impl fmt::Display for ExecutionError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            ExecutionError::ReachedFailNode(entropy) => {
                write!(f, "Execution reached a fail node: {}", entropy)
            }
            ExecutionError::ReachedPrunedBranch(hash) => {
                write!(f, "Execution reached a pruned branch: {}", hash)
            }
            ExecutionError::JetFailed(jet_failed) => fmt::Display::fmt(jet_failed, f),
        }
    }
}
impl error::Error for ExecutionError {}
impl From<JetFailed> for ExecutionError {
    fn from(jet_failed: JetFailed) -> Self {
        ExecutionError::JetFailed(jet_failed)
    }
}
#[cfg(test)]
mod tests {
    #[cfg(feature = "elements")]
    use super::*;
    #[cfg(feature = "elements")]
    use crate::jet::{elements::ElementsEnv, Elements};
    #[cfg(feature = "elements")]
    use crate::{node::RedeemNode, BitIter};
    #[cfg(feature = "elements")]
    use hex::DisplayHex;
    #[cfg(feature = "elements")]
    fn run_program_elements(
        prog_bytes: &[u8],
        cmr_str: &str,
        amr_str: &str,
        imr_str: &str,
    ) -> Result<Arc<Value>, ExecutionError> {
        let prog_hex = prog_bytes.as_hex();
        let mut iter = BitIter::from(prog_bytes);
        let prog = match RedeemNode::<Elements>::decode(&mut iter) {
            Ok(prog) => prog,
            Err(e) => panic!("program {} failed: {}", prog_hex, e),
        };
        assert_eq!(
            prog.cmr().to_string(),
            cmr_str,
            "CMR mismatch (got {} expected {}) for program {}",
            prog.cmr(),
            cmr_str,
            prog_hex,
        );
        assert_eq!(
            prog.imr().to_string(),
            imr_str,
            "IMR mismatch (got {} expected {}) for program {}",
            prog.imr(),
            imr_str,
            prog_hex,
        );
        assert_eq!(
            prog.amr().to_string(),
            amr_str,
            "AMR mismatch (got {} expected {}) for program {}",
            prog.amr(),
            amr_str,
            prog_hex,
        );
        let env = ElementsEnv::dummy();
        BitMachine::for_program(&prog).exec(&prog, &env)
    }
    #[test]
    #[cfg(feature = "elements")]
    fn crash_regression1() {
        let res = run_program_elements(
            &[0xcf, 0xe1, 0x8f, 0xb4, 0x40, 0x28, 0x87, 0x04, 0x00],
            "ec48102095c13fcdc1d539de2848ae287acdea249e2cda6f0d8f34bccd292294",
            "abd217b5ea14d7da249a03e16dd047b136a2efec4b82c1b60675297d782b51ad",
            "dea130f31a0754ea2f82ad570f7a4882c9e465b6bdd6f8be4d6d68342a57dff3",
        );
        assert_eq!(res.unwrap(), Value::unit());
    }
}