mod frame;
use std::error;
use std::fmt;
use std::sync::Arc;
use crate::analysis;
use crate::jet::{Jet, JetFailed};
use crate::node::{self, RedeemNode};
use crate::types::Final;
use crate::{Cmr, FailEntropy, Value};
use frame::Frame;
/// An execution context for Simplicity programs.
///
/// The machine owns one cell-addressed memory region (`data`) that is carved
/// into frames, plus two stacks of frames: one for reading, one for writing.
pub struct BitMachine {
    /// Backing storage for every frame, packed 8 cells per byte.
    data: Vec<u8>,
    /// Cell offset into `data` at which the next new frame will begin.
    next_frame_start: usize,
    /// Read frame stack; the last element is the active read frame.
    read: Vec<Frame>,
    /// Write frame stack; the last element is the active write frame.
    write: Vec<Frame>,
    /// Source type of the program the machine was built for; input is checked against it.
    source_ty: Arc<Final>,
}
impl BitMachine {
    /// Constructs a bit machine sized to execute `program`.
    ///
    /// The data region holds the program's input and output widths plus the
    /// extra cells reported by `program.bounds()`, rounded up to whole bytes.
    /// Both frame stacks reserve capacity for the reported extra frames plus
    /// [`analysis::IO_EXTRA_FRAMES`].
    pub fn for_program<J: Jet>(program: &RedeemNode<J>) -> Self {
        let io_width = program.arrow().source.bit_width() + program.arrow().target.bit_width();

        Self {
            // Ceiling division: the backing storage packs 8 cells per byte.
            data: vec![0; (io_width + program.bounds().extra_cells + 7) / 8],
            next_frame_start: 0,
            read: Vec::with_capacity(program.bounds().extra_frames + analysis::IO_EXTRA_FRAMES),
            write: Vec::with_capacity(program.bounds().extra_frames + analysis::IO_EXTRA_FRAMES),
            source_ty: program.arrow().source.clone(),
        }
    }

    /// Test helper: finalizes a construct-time program and executes it.
    ///
    /// Finalization uses an empty witness iterator. Execution runs on a freshly
    /// sized machine.
    ///
    /// # Panics
    ///
    /// Panics if type finalization or program finalization fails.
    #[cfg(test)]
    pub fn test_exec<J: Jet>(
        program: Arc<crate::node::ConstructNode<J>>,
        env: &J::Environment,
    ) -> Result<Value, ExecutionError> {
        use crate::node::SimpleFinalizer;

        let prog = program
            .finalize_types_non_program()
            .expect("finalizing types")
            .finalize(&mut SimpleFinalizer::new(None.into_iter()))
            .expect("finalizing");
        let mut mac = BitMachine::for_program(&prog);
        mac.exec(&prog, env)
    }

    /// Pushes a new write frame of `len` cells starting at `next_frame_start`.
    ///
    /// Debug builds assert two things. First, that the frame fits in `data`.
    /// Second, that the total number of frames stays below the read stack's
    /// capacity, as reserved in [`BitMachine::for_program`].
    fn new_frame(&mut self, len: usize) {
        debug_assert!(
            self.next_frame_start + len <= self.data.len() * 8,
            "Data out of bounds: number of cells"
        );
        debug_assert!(
            self.write.len() + self.read.len() < self.read.capacity(),
            "Stacks out of bounds: number of frames"
        );

        self.write.push(Frame::new(self.next_frame_start, len));
        self.next_frame_start += len;
    }

    /// Moves the active write frame onto the read stack, resetting its cursor.
    ///
    /// # Panics
    ///
    /// Panics if the write stack is empty.
    fn move_frame(&mut self) {
        let mut frame = self.write.pop().unwrap();
        frame.reset_cursor();
        self.read.push(frame);
    }

    /// Pops the active read frame and releases its cells for reuse.
    ///
    /// # Panics
    ///
    /// Panics if the read stack is empty. Also panics if the popped frame was
    /// not the most recently allocated one, because frames must be released
    /// in LIFO order.
    fn drop_frame(&mut self) {
        let active_read_frame = self.read.pop().unwrap();
        self.next_frame_start -= active_read_frame.bit_width();
        assert_eq!(self.next_frame_start, active_read_frame.start());
    }

    /// Writes a single bit to the active write frame.
    ///
    /// # Panics
    ///
    /// Panics if the write stack is empty.
    fn write_bit(&mut self, bit: bool) {
        self.write
            .last_mut()
            .expect("Empty write frame stack")
            .write_bit(bit, &mut self.data);
    }

    /// Advances the active write frame's cursor by `n` cells without writing.
    ///
    /// Does nothing when `n` is zero.
    ///
    /// # Panics
    ///
    /// Panics if `n` is non-zero and the write stack is empty.
    fn skip(&mut self, n: usize) {
        if n == 0 {
            return;
        }
        self.write
            .last_mut()
            .expect("Empty write frame stack")
            .move_cursor_forward(n);
    }

    /// Copies `n` cells from the active read frame into the active write frame.
    ///
    /// Does nothing when `n` is zero.
    ///
    /// # Panics
    ///
    /// Panics if `n` is non-zero and either frame stack is empty.
    fn copy(&mut self, n: usize) {
        if n == 0 {
            return;
        }
        // Disjoint field borrows: `write` mutably, `read` shared, `data` mutably.
        let dst = self.write.last_mut().expect("Empty write frame stack");
        let src = self.read.last().expect("Empty read frame stack");
        dst.copy_from(src, n, &mut self.data);
    }

    /// Advances the active read frame's cursor by `n` cells.
    ///
    /// Does nothing when `n` is zero.
    ///
    /// # Panics
    ///
    /// Panics if `n` is non-zero and the read stack is empty.
    fn fwd(&mut self, n: usize) {
        if n == 0 {
            return;
        }
        self.read
            .last_mut()
            .expect("Empty read frame stack")
            .move_cursor_forward(n);
    }

    /// Moves the active read frame's cursor back by `n` cells.
    ///
    /// Does nothing when `n` is zero.
    ///
    /// # Panics
    ///
    /// Panics if `n` is non-zero and the read stack is empty.
    fn back(&mut self, n: usize) {
        if n == 0 {
            return;
        }
        self.read
            .last_mut()
            .expect("Empty read frame stack")
            .move_cursor_backward(n);
    }

    /// Writes one byte to the active write frame.
    ///
    /// # Panics
    ///
    /// Panics if the write stack is empty.
    fn write_u8(&mut self, value: u8) {
        self.write
            .last_mut()
            .expect("Empty write frame stack")
            .write_u8(value, &mut self.data);
    }

    /// Reads one bit from the active read frame.
    ///
    /// # Panics
    ///
    /// Panics if the read stack is empty.
    fn read_bit(&mut self) -> bool {
        self.read
            .last_mut()
            .expect("Empty read frame stack")
            .read_bit(&self.data)
    }

    /// Writes every byte of `bytes`, in order, to the active write frame.
    fn write_bytes(&mut self, bytes: &[u8]) {
        for byte in bytes {
            self.write_u8(*byte);
        }
    }

    /// Writes the padded bit encoding of `val` to the active write frame.
    fn write_value(&mut self, val: &Value) {
        for bit in val.iter_padded() {
            self.write_bit(bit);
        }
    }

    /// Width, in cells, of the active read frame, or 0 if there is none.
    fn active_read_bit_width(&self) -> usize {
        self.read.last().map(|frame| frame.bit_width()).unwrap_or(0)
    }

    /// Width, in cells, of the active write frame, or 0 if there is none.
    fn active_write_bit_width(&self) -> usize {
        self.write
            .last()
            .map(|frame| frame.bit_width())
            .unwrap_or(0)
    }

    /// Supplies the program input after checking it against the source type.
    ///
    /// A non-empty value (per `Value::is_empty`) is written into a fresh frame.
    /// That frame then becomes the active read frame. An empty value needs no
    /// frame.
    ///
    /// # Errors
    ///
    /// Returns [`ExecutionError::InputWrongType`] if `input` is not of the
    /// machine's source type.
    pub fn input(&mut self, input: &Value) -> Result<(), ExecutionError> {
        if !input.is_of_type(&self.source_ty) {
            return Err(ExecutionError::InputWrongType(self.source_ty.clone()));
        }
        // An empty value occupies no cells, so no frame is needed.
        if !input.is_empty() {
            self.new_frame(input.padded_len());
            self.write_value(input);
            self.move_frame();
        }
        Ok(())
    }

    /// Executes `program` on this machine and returns its output value.
    ///
    /// Any non-empty input must already have been supplied through
    /// [`BitMachine::input`]. The program tree is walked iteratively, using an
    /// explicit continuation stack in place of recursion.
    ///
    /// # Errors
    ///
    /// - [`ExecutionError::InputWrongType`] if the presence of an input frame
    ///   does not match whether the source type is empty.
    /// - [`ExecutionError::ReachedFailNode`] if a `fail` node is executed.
    /// - [`ExecutionError::ReachedPrunedBranch`] if an assertion takes its
    ///   pruned branch.
    /// - [`ExecutionError::JetFailed`] if a jet reports failure.
    ///
    /// # Panics
    ///
    /// Panics if the output frame cannot be decoded as a value of the
    /// program's target type.
    pub fn exec<J: Jet + std::fmt::Debug>(
        &mut self,
        program: &RedeemNode<J>,
        env: &J::Environment,
    ) -> Result<Value, ExecutionError> {
        /// A pending action on the explicit continuation stack.
        enum CallStack<'a, J: Jet> {
            /// Execute this node next.
            Goto(&'a RedeemNode<J>),
            /// Move the active write frame onto the read stack.
            MoveFrame,
            /// Drop the active read frame.
            DropFrame,
            /// Copy `n` cells to the write frame, then advance the read cursor by `n`.
            CopyFwd(usize),
            /// Move the read cursor back by `n` cells.
            Back(usize),
        }

        // Inner items cannot use the outer function's generic parameters, so
        // this impl declares its own `J`.
        impl<'a, J: Jet> fmt::Debug for CallStack<'a, J> {
            fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
                match self {
                    CallStack::Goto(ins) => write!(f, "goto {}", ins.inner()),
                    CallStack::MoveFrame => f.write_str("move frame"),
                    CallStack::DropFrame => f.write_str("drop frame"),
                    CallStack::CopyFwd(n) => write!(f, "copy/fwd {}", n),
                    CallStack::Back(n) => write!(f, "back {}", n),
                }
            }
        }

        // An input frame must be present exactly when the source type is non-empty.
        if self.read.is_empty() != self.source_ty.is_empty() {
            return Err(ExecutionError::InputWrongType(self.source_ty.clone()));
        }

        let mut ip = program;
        let mut call_stack = vec![];

        // Allocate the output frame up front; a zero-width output needs none.
        let output_width = ip.arrow().target.bit_width();
        if output_width > 0 {
            self.new_frame(output_width);
        }

        'main_loop: loop {
            match ip.inner() {
                node::Inner::Unit => {}
                node::Inner::Iden => {
                    let size_a = ip.arrow().source.bit_width();
                    self.copy(size_a);
                }
                node::Inner::InjL(left) => {
                    // Tag bit, then padding, then the payload written by `left`.
                    let (b, c) = ip.arrow().target.as_sum().unwrap();
                    self.write_bit(false);
                    self.skip(b.pad_left(c));
                    call_stack.push(CallStack::Goto(left));
                }
                node::Inner::InjR(left) => {
                    let (b, c) = ip.arrow().target.as_sum().unwrap();
                    self.write_bit(true);
                    self.skip(b.pad_right(c));
                    call_stack.push(CallStack::Goto(left));
                }
                node::Inner::Pair(left, right) => {
                    // The call stack is LIFO: `left` runs first, then `right`.
                    call_stack.push(CallStack::Goto(right));
                    call_stack.push(CallStack::Goto(left));
                }
                node::Inner::Comp(left, right) => {
                    // `left` writes into a fresh frame. That frame is then moved
                    // to the read stack as `right`'s input, and finally dropped.
                    let size_b = left.arrow().target.bit_width();
                    self.new_frame(size_b);

                    call_stack.push(CallStack::DropFrame);
                    call_stack.push(CallStack::Goto(right));
                    call_stack.push(CallStack::MoveFrame);
                    call_stack.push(CallStack::Goto(left));
                }
                node::Inner::Disconnect(left, right) => {
                    let size_prod_256_a = left.arrow().source.bit_width();
                    let size_a = size_prod_256_a - 256;
                    let size_prod_b_c = left.arrow().target.bit_width();
                    let size_b = size_prod_b_c - right.arrow().source.bit_width();

                    // Input frame for `left`: the 256-bit CMR of `right`,
                    // followed by a copy of this node's own input.
                    self.new_frame(size_prod_256_a);
                    self.write_bytes(right.cmr().as_ref());
                    self.copy(size_a);
                    self.move_frame();
                    // Output frame for `left`.
                    self.new_frame(size_prod_b_c);

                    // Actions run in reverse order of pushing:
                    // 1. Run `left` and move its output to the read stack.
                    // 2. Copy its `b` part to our output.
                    // 3. Run `right` on the remainder.
                    // 4. Drop both frames.
                    call_stack.push(CallStack::DropFrame);
                    call_stack.push(CallStack::DropFrame);
                    call_stack.push(CallStack::Goto(right));
                    call_stack.push(CallStack::CopyFwd(size_b));
                    call_stack.push(CallStack::MoveFrame);
                    call_stack.push(CallStack::Goto(left));
                }
                node::Inner::Take(left) => call_stack.push(CallStack::Goto(left)),
                node::Inner::Drop(left) => {
                    // Skip the first component of the product input, run `left`
                    // on the second, then restore the read cursor.
                    let size_a = ip.arrow().source.as_product().unwrap().0.bit_width();
                    self.fwd(size_a);
                    call_stack.push(CallStack::Back(size_a));
                    call_stack.push(CallStack::Goto(left));
                }
                node::Inner::Case(..) | node::Inner::AssertL(..) | node::Inner::AssertR(..) => {
                    // Inspect the sum tag at the read cursor; the cursor is
                    // moved explicitly below (the tag bit is included in `fwd`).
                    let choice_bit = self
                        .read
                        .last_mut()
                        .expect("Empty read frame stack")
                        .peek_bit(&self.data);

                    let (sum_a_b, _c) = ip.arrow().source.as_product().unwrap();
                    let (a, b) = sum_a_b.as_sum().unwrap();

                    match (ip.inner(), choice_bit) {
                        (node::Inner::Case(_, right), true)
                        | (node::Inner::AssertR(_, right), true) => {
                            self.fwd(1 + a.pad_right(b));
                            call_stack.push(CallStack::Back(1 + a.pad_right(b)));
                            call_stack.push(CallStack::Goto(right));
                        }
                        (node::Inner::Case(left, _), false)
                        | (node::Inner::AssertL(left, _), false) => {
                            self.fwd(1 + a.pad_left(b));
                            call_stack.push(CallStack::Back(1 + a.pad_left(b)));
                            call_stack.push(CallStack::Goto(left));
                        }
                        (node::Inner::AssertL(_, r_cmr), true) => {
                            return Err(ExecutionError::ReachedPrunedBranch(*r_cmr))
                        }
                        (node::Inner::AssertR(l_cmr, _), false) => {
                            return Err(ExecutionError::ReachedPrunedBranch(*l_cmr))
                        }
                        _ => unreachable!(),
                    }
                }
                node::Inner::Witness(value) => self.write_value(value),
                node::Inner::Jet(jet) => self.exec_jet(*jet, env)?,
                node::Inner::Word(value) => self.write_value(value),
                node::Inner::Fail(entropy) => {
                    return Err(ExecutionError::ReachedFailNode(*entropy))
                }
            }

            // Perform queued frame bookkeeping until the next node to execute is
            // found. An empty stack means the program has finished.
            ip = loop {
                match call_stack.pop() {
                    Some(CallStack::Goto(next)) => break next,
                    Some(CallStack::MoveFrame) => self.move_frame(),
                    Some(CallStack::DropFrame) => self.drop_frame(),
                    Some(CallStack::CopyFwd(n)) => {
                        self.copy(n);
                        self.fwd(n);
                    }
                    Some(CallStack::Back(n)) => self.back(n),
                    None => break 'main_loop,
                };
            };
        }

        if output_width > 0 {
            // Rewind the output frame and decode it as a value of the target type.
            let out_frame = self.write.last_mut().unwrap();
            out_frame.reset_cursor();
            let value = Value::from_padded_bits(
                &mut out_frame.as_bit_iter(&self.data),
                &program.arrow().target,
            )
            .expect("Decode value of output frame");

            Ok(value)
        } else {
            Ok(Value::unit())
        }
    }

    /// Executes `jet` through its C implementation.
    ///
    /// Input bits are copied from the active read frame into a temporary C
    /// frame, and the read cursor is moved back afterwards. On success, the
    /// jet's output bits are copied into the active write frame.
    ///
    /// # Errors
    ///
    /// Returns [`JetFailed`] if the C sanity checks fail or the jet itself
    /// reports failure.
    fn exec_jet<J: Jet>(&mut self, jet: J, env: &J::Environment) -> Result<(), JetFailed> {
        use crate::ffi::c_jets::frame_ffi::{c_readBit, c_writeBit, CFrameItem};
        use crate::ffi::c_jets::uword_width;
        use crate::ffi::ffi::UWORD;

        /// Copies the next `bit_width` bits of the active read frame into a
        /// freshly allocated buffer.
        ///
        /// Returns a C read frame over that buffer, together with the buffer
        /// itself. The active read frame's cursor is moved back afterwards.
        ///
        /// # Safety
        ///
        /// The returned frame is built from a raw pointer into the returned
        /// buffer. The buffer must stay alive, and must not be reallocated, for
        /// as long as the frame is used.
        ///
        /// # Panics
        ///
        /// Panics if the active read frame is narrower than `bit_width` cells.
        unsafe fn get_input_frame(
            mac: &mut BitMachine,
            bit_width: usize,
        ) -> (CFrameItem, Vec<UWORD>) {
            assert!(bit_width <= mac.active_read_bit_width());
            let uword_width = uword_width(bit_width);
            let mut buffer = vec![0; uword_width];

            // SAFETY: `buffer` holds exactly `uword_width` words, so this is the
            // one-past-the-end pointer of its allocation, which `add` permits.
            let buffer_end = buffer.as_mut_ptr().add(uword_width);
            // NOTE(review): the C write frame is given the buffer's *end*;
            // presumably it fills downward. Confirm against the C frame API.
            let mut write_frame = CFrameItem::new_write(bit_width, buffer_end);
            for _ in 0..bit_width {
                let bit = mac.read_bit();
                c_writeBit(&mut write_frame, bit);
            }
            mac.back(bit_width);

            // Re-expose the same buffer as a read frame for the jet's input.
            let buffer_ptr = buffer.as_mut_ptr();
            let read_frame = CFrameItem::new_read(bit_width, buffer_ptr);

            (read_frame, buffer)
        }

        /// Allocates a zeroed buffer large enough for `bit_width` bits.
        ///
        /// Returns a C write frame over that buffer, together with the buffer.
        ///
        /// # Safety
        ///
        /// The returned frame is built from a raw pointer into the returned
        /// buffer. The buffer must outlive every use of the frame.
        unsafe fn get_output_frame(bit_width: usize) -> (CFrameItem, Vec<UWORD>) {
            let uword_width = uword_width(bit_width);
            let mut buffer = vec![0; uword_width];

            // SAFETY: one-past-the-end pointer of a `uword_width`-word allocation.
            let buffer_end = buffer.as_mut_ptr().add(uword_width);
            let write_frame = CFrameItem::new_write(bit_width, buffer_end);

            (write_frame, buffer)
        }

        /// Copies `bit_width` bits from `buffer` into the active write frame.
        ///
        /// # Panics
        ///
        /// Panics if the active write frame is narrower than `bit_width` cells,
        /// or if `buffer` holds fewer than `bit_width` bits.
        fn update_active_write_frame(mac: &mut BitMachine, bit_width: usize, buffer: &[UWORD]) {
            assert!(bit_width <= mac.active_write_bit_width());
            assert!(uword_width(bit_width) <= buffer.len());
            let buffer_ptr = buffer.as_ptr();
            // SAFETY: the asserts above ensure `buffer` covers `bit_width` bits,
            // and `buffer` outlives `read_frame`, which is local to this function.
            let mut read_frame = unsafe { CFrameItem::new_read(bit_width, buffer_ptr) };

            for _ in 0..bit_width {
                // SAFETY: at most `bit_width` bits are read from a frame that
                // covers exactly `bit_width` bits of the live `buffer`.
                let bit = unsafe { c_readBit(&mut read_frame) };
                mac.write_bit(bit);
            }
        }

        // The C bindings expose a self-check; a failure is reported as a jet failure.
        if !simplicity_sys::c_jets::sanity_checks() {
            return Err(JetFailed);
        }

        let input_width = jet.source_ty().to_bit_width();
        let output_width = jet.target_ty().to_bit_width();
        // SAFETY: both buffers are bound to locals that live until the end of
        // this function, so they outlive every use of the frames built on them.
        // `_input_buffer` is never read, but the named binding keeps it alive
        // (unlike `_`, which would drop it immediately).
        let (input_read_frame, _input_buffer) = unsafe { get_input_frame(self, input_width) };
        let (mut output_write_frame, output_buffer) = unsafe { get_output_frame(output_width) };

        let jet_fn = jet.c_jet_ptr();
        let c_env = J::c_jet_env(env);
        let success = jet_fn(&mut output_write_frame, input_read_frame, c_env);

        if !success {
            Err(JetFailed)
        } else {
            update_active_write_frame(self, output_width, &output_buffer);
            Ok(())
        }
    }
}
/// Errors that can occur while running a program on the bit machine.
#[derive(Debug)]
pub enum ExecutionError {
    /// The supplied input (or its absence) does not match the program's
    /// source type; carries the expected type.
    InputWrongType(Arc<Final>),
    /// Execution reached a `fail` node; carries that node's entropy.
    ReachedFailNode(FailEntropy),
    /// An assertion took its pruned branch; carries that branch's CMR.
    ReachedPrunedBranch(Cmr),
    /// A jet reported failure during execution.
    JetFailed(JetFailed),
}
impl fmt::Display for ExecutionError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
ExecutionError::InputWrongType(expected_ty) => {
write!(f, "Expected input of type: {expected_ty}")
}
ExecutionError::ReachedFailNode(entropy) => {
write!(f, "Execution reached a fail node: {}", entropy)
}
ExecutionError::ReachedPrunedBranch(hash) => {
write!(f, "Execution reached a pruned branch: {}", hash)
}
ExecutionError::JetFailed(jet_failed) => fmt::Display::fmt(jet_failed, f),
}
}
}
/// `ExecutionError` is a standard error type; it reports no underlying `source`.
impl error::Error for ExecutionError {}
impl From<JetFailed> for ExecutionError {
fn from(jet_failed: JetFailed) -> Self {
ExecutionError::JetFailed(jet_failed)
}
}
#[cfg(test)]
mod tests {
    #[cfg(feature = "elements")]
    use super::*;
    #[cfg(feature = "elements")]
    use crate::jet::{elements::ElementsEnv, Elements};
    #[cfg(feature = "elements")]
    use crate::{node::RedeemNode, BitIter};
    #[cfg(feature = "elements")]
    use hex::DisplayHex;

    /// Decodes, checks and runs an Elements program.
    ///
    /// The program and its witness are decoded from raw bytes. The resulting
    /// CMR, IMR and AMR are checked against the expected strings. The program
    /// is then executed on a freshly sized machine with a dummy environment.
    ///
    /// # Panics
    ///
    /// Panics if decoding fails or any of the three commitments mismatches.
    #[cfg(feature = "elements")]
    fn run_program_elements(
        prog_bytes: &[u8],
        witness_bytes: &[u8],
        cmr_str: &str,
        amr_str: &str,
        imr_str: &str,
    ) -> Result<Value, ExecutionError> {
        let prog_hex = prog_bytes.as_hex();
        let program_bits = BitIter::from(prog_bytes);
        let witness_bits = BitIter::from(witness_bytes);
        let program = RedeemNode::<Elements>::decode(program_bits, witness_bits)
            .unwrap_or_else(|e| panic!("program {} failed: {}", prog_hex, e));

        assert_eq!(
            program.cmr().to_string(),
            cmr_str,
            "CMR mismatch (got {} expected {}) for program {}",
            program.cmr(),
            cmr_str,
            prog_hex,
        );
        assert_eq!(
            program.imr().to_string(),
            imr_str,
            "IMR mismatch (got {} expected {}) for program {}",
            program.imr(),
            imr_str,
            prog_hex,
        );
        assert_eq!(
            program.amr().to_string(),
            amr_str,
            "AMR mismatch (got {} expected {}) for program {}",
            program.amr(),
            amr_str,
            prog_hex,
        );

        let env = ElementsEnv::dummy();
        let mut machine = BitMachine::for_program(&program);
        machine.exec(&program, &env)
    }

    /// Regression case for a fixed program with an empty witness.
    ///
    /// The program must decode, match its known commitments, and evaluate to
    /// the unit value.
    #[test]
    #[cfg(feature = "elements")]
    fn crash_regression1() {
        let prog_bytes: &[u8] = &[0xcf, 0xe1, 0x8f, 0xb4, 0x40, 0x28, 0x87, 0x04, 0x00];
        let cmr = "615034594b26f261f89485f71b705ebf2e5b27233130d9c41c49c214dcbf0a7f";
        let amr = "3e2c6ae87f6578e52d51510b476fd2e1dd400ce4f4f6e8a9174574434dc93d7d";
        let imr = "ffc4aa8b46fd3c25f765f7ad1f44474bd936f9edeb4a90e8b198215c3b743f17";

        let output = run_program_elements(prog_bytes, &[], cmr, amr, imr);
        assert_eq!(output.unwrap(), Value::unit());
    }
}