simplicity/bit_machine/
mod.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
// SPDX-License-Identifier: CC0-1.0

//! # Simplicity Execution
//!
//! Implementation of the Bit Machine, without TCO, as TCO precludes some
//! frame management optimizations which can be used to great benefit.
//!

mod frame;

use std::error;
use std::fmt;
use std::sync::Arc;

use crate::analysis;
use crate::jet::{Jet, JetFailed};
use crate::node::{self, RedeemNode};
use crate::types::Final;
use crate::{Cmr, FailEntropy, Value};
use frame::Frame;

/// An execution context for a Simplicity program: a packed cell buffer plus the
/// read and write frame stacks of the Bit Machine.
pub struct BitMachine {
    /// Backing storage for the cells of every frame, packed eight cells per byte.
    /// Frames are (de)allocated LIFO, growing from the start of the buffer.
    data: Vec<u8>,
    /// Top of the data stack: index of the first unused cell (bit) in `data`,
    /// which is where the next frame will start.
    next_frame_start: usize,
    /// Read frame stack; the last element is the active read frame.
    read: Vec<Frame>,
    /// Write frame stack; the last element is the active write frame.
    write: Vec<Frame>,
    /// Source type of the program this machine was built for; values passed to
    /// `input` must have this type.
    source_ty: Arc<Final>,
}

impl BitMachine {
    /// Construct a Bit Machine with enough space to execute the given program.
    pub fn for_program<J: Jet>(program: &RedeemNode<J>) -> Self {
        let io_width = program.arrow().source.bit_width() + program.arrow().target.bit_width();

        Self {
            data: vec![0; (io_width + program.bounds().extra_cells + 7) / 8],
            next_frame_start: 0,
            read: Vec::with_capacity(program.bounds().extra_frames + analysis::IO_EXTRA_FRAMES),
            write: Vec::with_capacity(program.bounds().extra_frames + analysis::IO_EXTRA_FRAMES),
            source_ty: program.arrow().source.clone(),
        }
    }

    /// Test-only helper: finalize the types of `program`, finalize it into a
    /// redeem program using a `SimpleFinalizer` fed an empty iterator, and run it
    /// on a freshly sized Bit Machine. Panics if either finalization step fails.
    #[cfg(test)]
    pub fn test_exec<J: Jet>(
        program: Arc<crate::node::ConstructNode<J>>,
        env: &J::Environment,
    ) -> Result<Value, ExecutionError> {
        use crate::node::SimpleFinalizer;

        let typed = program
            .finalize_types_non_program()
            .expect("finalizing types");
        let redeem = typed
            .finalize(&mut SimpleFinalizer::new(None.into_iter()))
            .expect("finalizing");
        BitMachine::for_program(&redeem).exec(&redeem, env)
    }

    /// Allocate a frame of `len` cells at the top of the data stack and make it
    /// the active write frame.
    ///
    /// Debug builds assert that the cell buffer has room for `len` more cells and
    /// that the total number of frames stays below the preallocated capacity.
    fn new_frame(&mut self, len: usize) {
        let frame_end = self.next_frame_start + len;
        let cell_capacity = 8 * self.data.len();
        debug_assert!(
            frame_end <= cell_capacity,
            "Data out of bounds: number of cells"
        );

        let frame_count = self.read.len() + self.write.len();
        debug_assert!(
            frame_count < self.read.capacity(),
            "Stacks out of bounds: number of frames"
        );

        let frame = Frame::new(self.next_frame_start, len);
        self.write.push(frame);
        self.next_frame_start = frame_end;
    }

    /// Pop the active write frame, rewind its cursor to the beginning, and push it
    /// onto the read frame stack so its contents can be read back.
    fn move_frame(&mut self) {
        let mut frame = self.write.pop().unwrap();
        frame.reset_cursor();
        self.read.push(frame);
    }

    /// Pop the active read frame and release its cells from the top of the data
    /// stack. Panics if the frame does not end exactly at the top of the stack.
    fn drop_frame(&mut self) {
        let frame = self.read.pop().unwrap();
        let width = frame.bit_width();
        self.next_frame_start -= width;
        assert_eq!(self.next_frame_start, frame.start());
    }

    /// Write one bit at the cursor of the active write frame.
    fn write_bit(&mut self, bit: bool) {
        let frame = self.write.last_mut().expect("Empty write frame stack");
        frame.write_bit(bit, &mut self.data);
    }

    /// Advance the cursor of the active write frame by `n` bits without writing.
    ///
    /// # Panics
    ///
    /// Panics with "Empty write frame stack" if `n > 0` and there is no write frame.
    fn skip(&mut self, n: usize) {
        // A zero-width skip is a no-op and does not even require a write frame.
        if n == 0 {
            return;
        }
        self.write
            .last_mut()
            .expect("Empty write frame stack")
            .move_cursor_forward(n);
    }

    /// Copy `n` bits from the cursor of the active read frame to the cursor of the
    /// active write frame.
    ///
    /// # Panics
    ///
    /// Panics with a descriptive message if `n > 0` and either the write or the
    /// read frame stack is empty (the write stack is checked first).
    fn copy(&mut self, n: usize) {
        // A zero-width copy is a no-op and does not even require any frames.
        if n == 0 {
            return;
        }
        // `write`, `read` and `data` are distinct fields, so these borrows coexist.
        let write_frame = self.write.last_mut().expect("Empty write frame stack");
        let read_frame = self.read.last().expect("Empty read frame stack");
        write_frame.copy_from(read_frame, n, &mut self.data);
    }

    /// Advance the cursor of the active read frame by `n` bits.
    ///
    /// # Panics
    ///
    /// Panics with "Empty read frame stack" if `n > 0` and there is no read frame.
    fn fwd(&mut self, n: usize) {
        // A zero-width move is a no-op and does not even require a read frame.
        if n == 0 {
            return;
        }
        self.read
            .last_mut()
            .expect("Empty read frame stack")
            .move_cursor_forward(n);
    }

    /// Move the cursor of the active read frame back by `n` bits.
    ///
    /// # Panics
    ///
    /// Panics with "Empty read frame stack" if `n > 0` and there is no read frame.
    fn back(&mut self, n: usize) {
        // A zero-width move is a no-op and does not even require a read frame.
        if n == 0 {
            return;
        }
        self.read
            .last_mut()
            .expect("Empty read frame stack")
            .move_cursor_backward(n);
    }

    /// Write the eight bits of `value` (big-endian) at the cursor of the active
    /// write frame.
    fn write_u8(&mut self, value: u8) {
        let frame = self.write.last_mut().expect("Empty write frame stack");
        frame.write_u8(value, &mut self.data);
    }

    /// Read the bit under the cursor of the active read frame, advancing the cursor.
    fn read_bit(&mut self) -> bool {
        let frame = self.read.last_mut().expect("Empty read frame stack");
        frame.read_bit(&self.data)
    }

    /// Write every byte of `bytes`, in order and eight bits each, to the active
    /// write frame.
    fn write_bytes(&mut self, bytes: &[u8]) {
        bytes.iter().for_each(|&byte| self.write_u8(byte));
    }

    /// Write the padded bit encoding of `val` to the active write frame.
    fn write_value(&mut self, val: &Value) {
        val.iter_padded().for_each(|bit| self.write_bit(bit));
    }

    /// Total width in bits of the active read frame, or 0 when the read stack is empty.
    fn active_read_bit_width(&self) -> usize {
        if let Some(frame) = self.read.last() {
            frame.bit_width()
        } else {
            0
        }
    }

    /// Total width in bits of the active write frame, or 0 when the write stack is empty.
    fn active_write_bit_width(&self) -> usize {
        self.write.last().map_or(0, |frame| frame.bit_width())
    }

    /// Add a read frame with some given value in it, as input to the
    /// program
    pub fn input(&mut self, input: &Value) -> Result<(), ExecutionError> {
        if !input.is_of_type(&self.source_ty) {
            return Err(ExecutionError::InputWrongType(self.source_ty.clone()));
        }
        // Unit value doesn't need extra frame
        if !input.is_empty() {
            self.new_frame(input.padded_len());
            self.write_value(input);
            self.move_frame();
        }
        Ok(())
    }

    /// Execute the given program on the Bit Machine, using the given environment.
    ///
    /// Make sure the Bit Machine has enough space by constructing it via [`Self::for_program()`].
    /// If the program's source type is non-empty, its input must already have been loaded
    /// with [`Self::input()`].
    ///
    /// The program is walked iteratively rather than recursively: each node performs its
    /// immediate Bit Machine instructions and pushes its remaining work onto an explicit
    /// call stack, which is then popped until the next node to execute is found.
    ///
    /// # Errors
    ///
    /// - [`ExecutionError::InputWrongType`] if an input frame is present for an empty source
    ///   type, or missing for a non-empty one.
    /// - [`ExecutionError::ReachedPrunedBranch`] if an `AssertL`/`AssertR` node selects its
    ///   pruned branch.
    /// - [`ExecutionError::ReachedFailNode`] if a `Fail` node is executed.
    /// - [`ExecutionError::JetFailed`] if a jet reports failure.
    ///
    /// NOTE(review): frames are not popped on early error returns, and the input and output
    /// frames remain on the stacks after success, so a machine is presumably meant for a
    /// single run — confirm before reusing one.
    pub fn exec<J: Jet + std::fmt::Debug>(
        &mut self,
        program: &RedeemNode<J>,
        env: &J::Environment,
    ) -> Result<Value, ExecutionError> {
        // Deferred work for the iterative interpreter. Entries are popped LIFO, so
        // each node pushes them in the reverse of their execution order.
        enum CallStack<'a, J: Jet> {
            // Execute the given node next.
            Goto(&'a RedeemNode<J>),
            // Move the active write frame onto the read frame stack.
            MoveFrame,
            // Drop the active read frame.
            DropFrame,
            // Copy `n` bits from the active read frame, then advance its cursor by `n`.
            CopyFwd(usize),
            // Move the active read frame's cursor back by `n` bits.
            Back(usize),
        }

        // Not used, but useful for debugging, so keep it around
        impl<'a, J: Jet> fmt::Debug for CallStack<'a, J> {
            fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
                match self {
                    CallStack::Goto(ins) => write!(f, "goto {}", ins.inner()),
                    CallStack::MoveFrame => f.write_str("move frame"),
                    CallStack::DropFrame => f.write_str("drop frame"),
                    CallStack::CopyFwd(n) => write!(f, "copy/fwd {}", n),
                    CallStack::Back(n) => write!(f, "back {}", n),
                }
            }
        }

        // An input frame must be present exactly when the source type is non-empty.
        if self.read.is_empty() != self.source_ty.is_empty() {
            return Err(ExecutionError::InputWrongType(self.source_ty.clone()));
        }

        let mut ip = program;
        let mut call_stack = vec![];

        // Allocate the frame that receives the program's output; a zero-width
        // output needs no frame.
        let output_width = ip.arrow().target.bit_width();
        if output_width > 0 {
            self.new_frame(output_width);
        }

        'main_loop: loop {
            match ip.inner() {
                // Unit produces no bits.
                node::Inner::Unit => {}
                // Copy the whole input to the output.
                node::Inner::Iden => {
                    let size_a = ip.arrow().source.bit_width();
                    self.copy(size_a);
                }
                // Write tag bit 0 and the left padding, then run the child.
                node::Inner::InjL(left) => {
                    let (b, c) = ip.arrow().target.as_sum().unwrap();
                    self.write_bit(false);
                    self.skip(b.pad_left(c));
                    call_stack.push(CallStack::Goto(left));
                }
                // Write tag bit 1 and the right padding, then run the child.
                node::Inner::InjR(left) => {
                    let (b, c) = ip.arrow().target.as_sum().unwrap();
                    self.write_bit(true);
                    self.skip(b.pad_right(c));
                    call_stack.push(CallStack::Goto(left));
                }
                // Run `left` then `right`; both write into the same active write
                // frame, one after the other.
                node::Inner::Pair(left, right) => {
                    call_stack.push(CallStack::Goto(right));
                    call_stack.push(CallStack::Goto(left));
                }
                // `left` writes into a fresh frame, which becomes `right`'s input
                // and is dropped once `right` finishes.
                node::Inner::Comp(left, right) => {
                    let size_b = left.arrow().target.bit_width();

                    self.new_frame(size_b);
                    call_stack.push(CallStack::DropFrame);
                    call_stack.push(CallStack::Goto(right));
                    call_stack.push(CallStack::MoveFrame);
                    call_stack.push(CallStack::Goto(left));
                }
                // Build `left`'s input (the CMR of `right` followed by this node's
                // input) in a fresh frame, and allocate a frame for `left`'s output.
                // Afterwards the first part of that output is copied through to this
                // node's output, and the rest is fed to `right`.
                node::Inner::Disconnect(left, right) => {
                    let size_prod_256_a = left.arrow().source.bit_width();
                    let size_a = size_prod_256_a - 256;
                    let size_prod_b_c = left.arrow().target.bit_width();
                    let size_b = size_prod_b_c - right.arrow().source.bit_width();

                    self.new_frame(size_prod_256_a);
                    self.write_bytes(right.cmr().as_ref());
                    self.copy(size_a);
                    self.move_frame();
                    self.new_frame(size_prod_b_c);

                    // Remember that call stack pushes are executed in reverse order
                    call_stack.push(CallStack::DropFrame);
                    call_stack.push(CallStack::DropFrame);
                    call_stack.push(CallStack::Goto(right));
                    call_stack.push(CallStack::CopyFwd(size_b));
                    call_stack.push(CallStack::MoveFrame);
                    call_stack.push(CallStack::Goto(left));
                }
                // The first component starts at the cursor, so no movement is needed.
                node::Inner::Take(left) => call_stack.push(CallStack::Goto(left)),
                // Skip the first component, run the child, then restore the cursor.
                node::Inner::Drop(left) => {
                    let size_a = ip.arrow().source.as_product().unwrap().0.bit_width();
                    self.fwd(size_a);
                    call_stack.push(CallStack::Back(size_a));
                    call_stack.push(CallStack::Goto(left));
                }
                // Peek (without consuming) the sum tag bit, skip the tag and the
                // branch's padding, run the chosen branch, then restore the cursor.
                node::Inner::Case(..) | node::Inner::AssertL(..) | node::Inner::AssertR(..) => {
                    let choice_bit = self.read[self.read.len() - 1].peek_bit(&self.data);

                    let (sum_a_b, _c) = ip.arrow().source.as_product().unwrap();
                    let (a, b) = sum_a_b.as_sum().unwrap();

                    match (ip.inner(), choice_bit) {
                        (node::Inner::Case(_, right), true)
                        | (node::Inner::AssertR(_, right), true) => {
                            self.fwd(1 + a.pad_right(b));
                            call_stack.push(CallStack::Back(1 + a.pad_right(b)));
                            call_stack.push(CallStack::Goto(right));
                        }
                        (node::Inner::Case(left, _), false)
                        | (node::Inner::AssertL(left, _), false) => {
                            self.fwd(1 + a.pad_left(b));
                            call_stack.push(CallStack::Back(1 + a.pad_left(b)));
                            call_stack.push(CallStack::Goto(left));
                        }
                        // The selected branch was pruned; report its CMR.
                        (node::Inner::AssertL(_, r_cmr), true) => {
                            return Err(ExecutionError::ReachedPrunedBranch(*r_cmr))
                        }
                        (node::Inner::AssertR(l_cmr, _), false) => {
                            return Err(ExecutionError::ReachedPrunedBranch(*l_cmr))
                        }
                        _ => unreachable!(),
                    }
                }
                node::Inner::Witness(value) => self.write_value(value),
                node::Inner::Jet(jet) => self.exec_jet(*jet, env)?,
                node::Inner::Word(value) => self.write_value(value),
                node::Inner::Fail(entropy) => {
                    return Err(ExecutionError::ReachedFailNode(*entropy))
                }
            }

            // Run deferred instructions until the next node to execute is found;
            // an empty call stack means the program has finished.
            ip = loop {
                match call_stack.pop() {
                    Some(CallStack::Goto(next)) => break next,
                    Some(CallStack::MoveFrame) => self.move_frame(),
                    Some(CallStack::DropFrame) => self.drop_frame(),
                    Some(CallStack::CopyFwd(n)) => {
                        self.copy(n);
                        self.fwd(n);
                    }
                    Some(CallStack::Back(n)) => self.back(n),
                    None => break 'main_loop,
                };
            };
        }

        // Rewind the output frame and decode its bits as a value of the target type.
        if output_width > 0 {
            let out_frame = self.write.last_mut().unwrap();
            out_frame.reset_cursor();
            let value = Value::from_padded_bits(
                &mut out_frame.as_bit_iter(&self.data),
                &program.arrow().target,
            )
            .expect("Decode value of output frame");

            Ok(value)
        } else {
            Ok(Value::unit())
        }
    }

    /// Execute `jet` through its C implementation.
    ///
    /// The jet's input bits are copied from the active read frame into a temporary C
    /// read frame (the Bit Machine's read cursor ends where it started), the C jet
    /// function is called, and on success the bits it produced are written to the
    /// active write frame.
    ///
    /// # Errors
    ///
    /// Returns [`JetFailed`] if the C sanity checks fail or the jet reports failure.
    fn exec_jet<J: Jet>(&mut self, jet: J, env: &J::Environment) -> Result<(), JetFailed> {
        use crate::ffi::c_jets::frame_ffi::{c_readBit, c_writeBit, CFrameItem};
        use crate::ffi::c_jets::uword_width;
        use crate::ffi::ffi::UWORD;

        /// Create new C read frame that contains `bit_width` many bits from active read frame.
        ///
        /// Return C read frame together with underlying buffer. The bits are taken from the
        /// cursor of the active read frame, which is moved back afterwards, so the Bit
        /// Machine's read position is unchanged.
        ///
        /// ## Safety
        ///
        /// The returned frame points into the returned buffer, so the buffer must outlive
        /// the frame; dropping the buffer first leaves the frame with a dangling pointer.
        ///
        /// ## Panics
        ///
        /// Active read frame has fewer bits than `bit_width`.
        ///
        /// NOTE(review): the assert compares against the active frame's total width, not
        /// the bits remaining after its cursor — confirm that reads past the end of the
        /// frame are caught elsewhere.
        unsafe fn get_input_frame(
            mac: &mut BitMachine,
            bit_width: usize,
        ) -> (CFrameItem, Vec<UWORD>) {
            assert!(bit_width <= mac.active_read_bit_width());
            let uword_width = uword_width(bit_width);
            let mut buffer = vec![0; uword_width];

            // Copy bits from active read frame into input frame
            // (`new_write` is handed a pointer one past the end of the buffer)
            let buffer_end = buffer.as_mut_ptr().add(uword_width);
            let mut write_frame = CFrameItem::new_write(bit_width, buffer_end);
            for _ in 0..bit_width {
                let bit = mac.read_bit();
                c_writeBit(&mut write_frame, bit);
            }
            // Undo the cursor movement made by `read_bit` above.
            mac.back(bit_width);

            // Convert input frame into read frame
            let buffer_ptr = buffer.as_mut_ptr();
            let read_frame = CFrameItem::new_read(bit_width, buffer_ptr);

            // Returning `buffer` moves only the `Vec` handle; the heap allocation that
            // `read_frame` points into stays where it is.
            (read_frame, buffer)
        }

        /// Create C write frame that is as wide as `bit_width`.
        ///
        /// Return C write frame together with underlying buffer.
        ///
        /// ## Safety
        ///
        /// The returned frame points into the returned buffer, so the buffer must outlive
        /// the frame; dropping the buffer first leaves the frame with a dangling pointer.
        unsafe fn get_output_frame(bit_width: usize) -> (CFrameItem, Vec<UWORD>) {
            let uword_width = uword_width(bit_width);
            let mut buffer = vec![0; uword_width];

            // Return output frame as write frame
            // (`new_write` is handed a pointer one past the end of the buffer)
            let buffer_end = buffer.as_mut_ptr().add(uword_width);
            let write_frame = CFrameItem::new_write(bit_width, buffer_end);

            (write_frame, buffer)
        }

        /// Write `bit_width` many bits from `buffer` into active write frame.
        ///
        /// ## Panics
        ///
        /// Active write frame has fewer bits than `bit_width`.
        ///
        /// Buffer has fewer bits than `bit_width` (converted to UWORDs).
        fn update_active_write_frame(mac: &mut BitMachine, bit_width: usize, buffer: &[UWORD]) {
            assert!(bit_width <= mac.active_write_bit_width());
            assert!(uword_width(bit_width) <= buffer.len());
            let buffer_ptr = buffer.as_ptr();
            // SAFETY: `buffer` holds at least `uword_width(bit_width)` words (asserted
            // above) and is borrowed for this whole function, so it outlives `read_frame`.
            let mut read_frame = unsafe { CFrameItem::new_read(bit_width, buffer_ptr) };

            for _ in 0..bit_width {
                // SAFETY: at most `bit_width` bits are read, which is the frame's width.
                let bit = unsafe { c_readBit(&mut read_frame) };
                mac.write_bit(bit);
            }
        }

        // Sanity Check: This should never really fail, but still good to do
        if !simplicity_sys::c_jets::sanity_checks() {
            return Err(JetFailed);
        }

        let input_width = jet.source_ty().to_bit_width();
        let output_width = jet.target_ty().to_bit_width();
        // Input buffer is implicitly referenced by input read frame!
        // Same goes for output buffer
        let (input_read_frame, _input_buffer) = unsafe { get_input_frame(self, input_width) };
        let (mut output_write_frame, output_buffer) = unsafe { get_output_frame(output_width) };

        let jet_fn = jet.c_jet_ptr();
        let c_env = J::c_jet_env(env);
        let success = jet_fn(&mut output_write_frame, input_read_frame, c_env);

        if !success {
            Err(JetFailed)
        } else {
            update_active_write_frame(self, output_width, &output_buffer);
            Ok(())
        }
    }
}

/// Errors that can occur while loading input into, or executing a program on,
/// the Bit Machine.
#[derive(Debug)]
pub enum ExecutionError {
    /// Provided input is of wrong type; carries the expected source type.
    InputWrongType(Arc<Final>),
    /// Reached a fail node; carries that node's entropy.
    ReachedFailNode(FailEntropy),
    /// Reached the pruned branch of an `AssertL`/`AssertR` node; carries the CMR
    /// of the pruned branch.
    ReachedPrunedBranch(Cmr),
    /// Jet failed during execution
    JetFailed(JetFailed),
}

impl fmt::Display for ExecutionError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            ExecutionError::InputWrongType(expected_ty) => {
                write!(f, "Expected input of type: {expected_ty}")
            }
            ExecutionError::ReachedFailNode(entropy) => {
                write!(f, "Execution reached a fail node: {}", entropy)
            }
            ExecutionError::ReachedPrunedBranch(hash) => {
                write!(f, "Execution reached a pruned branch: {}", hash)
            }
            ExecutionError::JetFailed(jet_failed) => fmt::Display::fmt(jet_failed, f),
        }
    }
}

// All `Error` methods keep their defaults, so `source()` returns `None`; the
// wrapped `JetFailed` message is rendered directly by the `Display` impl instead.
impl error::Error for ExecutionError {}

impl From<JetFailed> for ExecutionError {
    fn from(jet_failed: JetFailed) -> Self {
        ExecutionError::JetFailed(jet_failed)
    }
}

// Regression tests that decode serialized Elements programs, check their Merkle
// roots and run them on the Bit Machine. Every item is gated on the `elements`
// feature.
#[cfg(test)]
mod tests {
    #[cfg(feature = "elements")]
    use super::*;

    #[cfg(feature = "elements")]
    use crate::jet::{elements::ElementsEnv, Elements};
    #[cfg(feature = "elements")]
    use crate::{node::RedeemNode, BitIter};
    #[cfg(feature = "elements")]
    use hex::DisplayHex;

    /// Decode a program and its witness data from bytes, check its CMR, IMR and AMR
    /// against the expected strings, and run it on a freshly sized Bit Machine with
    /// a dummy Elements environment.
    ///
    /// Panics if decoding fails or any Merkle root does not match.
    #[cfg(feature = "elements")]
    fn run_program_elements(
        prog_bytes: &[u8],
        witness_bytes: &[u8],
        cmr_str: &str,
        amr_str: &str,
        imr_str: &str,
    ) -> Result<Value, ExecutionError> {
        let prog_hex = prog_bytes.as_hex();

        let prog = BitIter::from(prog_bytes);
        let witness = BitIter::from(witness_bytes);
        let prog = match RedeemNode::<Elements>::decode(prog, witness) {
            Ok(prog) => prog,
            Err(e) => panic!("program {} failed: {}", prog_hex, e),
        };

        // Check Merkle roots; check AMR last because the AMR is a more complicated
        // calculation and historically has been much more likely to be wrong.
        assert_eq!(
            prog.cmr().to_string(),
            cmr_str,
            "CMR mismatch (got {} expected {}) for program {}",
            prog.cmr(),
            cmr_str,
            prog_hex,
        );
        assert_eq!(
            prog.imr().to_string(),
            imr_str,
            "IMR mismatch (got {} expected {}) for program {}",
            prog.imr(),
            imr_str,
            prog_hex,
        );
        assert_eq!(
            prog.amr().to_string(),
            amr_str,
            "AMR mismatch (got {} expected {}) for program {}",
            prog.amr(),
            amr_str,
            prog_hex,
        );

        // Try to run it on the bit machine and return the result
        let env = ElementsEnv::dummy();
        BitMachine::for_program(&prog).exec(&prog, &env)
    }

    #[test]
    #[cfg(feature = "elements")]
    fn crash_regression1() {
        // cfe18fb44028870400
        // This program caused an array OOB.
        // # Witnesses
        // wit1 = witness :: (((1 + (1 + 2^256)) * (1 + (1 + 2^256))) + 1) -> 1
        //
        // # Program code
        // jt1 = jet_num_outputs    :: 1 -> 2^32                                          # cmr 447165a3...
        // jt2 = jet_issuance_token :: 2^32 -> (1 + (1 + 2^256))                          # cmr 85e9591c...
        // pr3 = pair jt2 jt2       :: 2^32 -> ((1 + (1 + 2^256)) * (1 + (1 + 2^256)))    # cmr 45d40848...
        // cp4 = comp jt1 pr3       :: 1 -> ((1 + (1 + 2^256)) * (1 + (1 + 2^256)))       # cmr 7bb1824f...
        // jl5 = injl cp4           :: 1 -> (((1 + (1 + 2^256)) * (1 + (1 + 2^256))) + 1) # cmr 277ee32c...
        //
        // main = comp jl5 wit1     :: 1 -> 1                                             # cmr 7050c4a6...

        // Arguments: program bytes, witness bytes, then the expected CMR, AMR and IMR.
        let res = run_program_elements(
            &[0xcf, 0xe1, 0x8f, 0xb4, 0x40, 0x28, 0x87, 0x04, 0x00],
            &[],
            "615034594b26f261f89485f71b705ebf2e5b27233130d9c41c49c214dcbf0a7f",
            "3e2c6ae87f6578e52d51510b476fd2e1dd400ce4f4f6e8a9174574434dc93d7d",
            "ffc4aa8b46fd3c25f765f7ad1f44474bd936f9edeb4a90e8b198215c3b743f17",
        );
        assert_eq!(res.unwrap(), Value::unit());
    }
}