triton_isa/
program.rs

1use std::collections::hash_map::Entry;
2use std::collections::HashMap;
3use std::collections::HashSet;
4use std::fmt::Display;
5use std::fmt::Formatter;
6use std::fmt::Result as FmtResult;
7use std::hash::Hash;
8use std::io::Cursor;
9
10use arbitrary::Arbitrary;
11use get_size2::GetSize;
12use itertools::Itertools;
13use serde::Deserialize;
14use serde::Serialize;
15use thiserror::Error;
16use twenty_first::prelude::*;
17
18use crate::instruction::AnInstruction;
19use crate::instruction::AssertionContext;
20use crate::instruction::Instruction;
21use crate::instruction::InstructionError;
22use crate::instruction::LabelledInstruction;
23use crate::instruction::TypeHint;
24use crate::parser;
25use crate::parser::ParseError;
26
/// A program for Triton VM. Triton VM can run and profile such programs,
/// and trace their execution in order to generate a proof of correct execution.
///
/// [Hashing](Program::hash) a program yields its canonical [`Digest`].
/// See also Triton VM's [program attestation].
///
/// A program may contain debug information, such as label names and breakpoints.
/// Access this information through methods [`label_for_address()`][label_for_address] and
/// [`is_breakpoint()`][is_breakpoint]. Some operations, most notably
/// [BField-encoding](BFieldCodec::encode), discard this debug information.
///
/// Two programs are considered equal if their instructions are equal; labels and
/// debug information do not take part in the comparison.
///
/// [program attestation]: https://triton-vm.org/spec/program-attestation.html
/// [label_for_address]: Program::label_for_address
/// [is_breakpoint]: Program::is_breakpoint
#[derive(Debug, Clone, Eq, Serialize, Deserialize, GetSize)]
pub struct Program {
    /// The instructions of the program, indexed by address. An instruction that
    /// occupies more than one word is stored once for every word it occupies.
    pub instructions: Vec<Instruction>,

    /// The label declared at a given address, if any.
    address_to_label: HashMap<u64, String>,

    /// Breakpoints, type hints, and assertion contexts of the program.
    debug_information: DebugInformation,
}
45
46impl Display for Program {
47    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
48        for instruction in self.labelled_instructions() {
49            writeln!(f, "{instruction}")?;
50        }
51        Ok(())
52    }
53}
54
55impl PartialEq for Program {
56    fn eq(&self, other: &Program) -> bool {
57        self.instructions.eq(&other.instructions)
58    }
59}
60
61impl BFieldCodec for Program {
62    type Error = ProgramDecodingError;
63
64    fn decode(sequence: &[BFieldElement]) -> Result<Box<Self>, Self::Error> {
65        if sequence.is_empty() {
66            return Err(Self::Error::EmptySequence);
67        }
68        let program_length = sequence[0].value() as usize;
69        let sequence = &sequence[1..];
70        if sequence.len() < program_length {
71            return Err(Self::Error::SequenceTooShort);
72        }
73        if sequence.len() > program_length {
74            return Err(Self::Error::SequenceTooLong);
75        }
76
77        // instantiating with claimed capacity is a potential DOS vector
78        let mut instructions = vec![];
79        let mut read_idx = 0;
80        while read_idx < program_length {
81            let opcode = sequence[read_idx];
82            let mut instruction = Instruction::try_from(opcode)
83                .map_err(|err| Self::Error::InvalidInstruction(read_idx, err))?;
84            let instruction_has_arg = instruction.arg().is_some();
85            if instruction_has_arg && instructions.len() + instruction.size() > program_length {
86                return Err(Self::Error::MissingArgument(read_idx, instruction));
87            }
88            if instruction_has_arg {
89                let arg = sequence[read_idx + 1];
90                instruction = instruction
91                    .change_arg(arg)
92                    .map_err(|err| Self::Error::InvalidInstruction(read_idx, err))?;
93            }
94
95            instructions.extend(vec![instruction; instruction.size()]);
96            read_idx += instruction.size();
97        }
98
99        if read_idx != program_length {
100            return Err(Self::Error::LengthMismatch);
101        }
102        if instructions.len() != program_length {
103            return Err(Self::Error::LengthMismatch);
104        }
105
106        Ok(Box::new(Program {
107            instructions,
108            address_to_label: HashMap::default(),
109            debug_information: DebugInformation::default(),
110        }))
111    }
112
113    fn encode(&self) -> Vec<BFieldElement> {
114        let mut sequence = Vec::with_capacity(self.len_bwords() + 1);
115        sequence.push(bfe!(self.len_bwords() as u64));
116        sequence.extend(self.to_bwords());
117        sequence
118    }
119
120    fn static_length() -> Option<usize> {
121        None
122    }
123}
124
125impl<'a> Arbitrary<'a> for Program {
126    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
127        let contains_label = |labelled_instructions: &[_], maybe_label: &_| {
128            let LabelledInstruction::Label(label) = maybe_label else {
129                return false;
130            };
131            labelled_instructions
132                .iter()
133                .any(|labelled_instruction| match labelled_instruction {
134                    LabelledInstruction::Label(l) => l == label,
135                    _ => false,
136                })
137        };
138        let is_assertion = |maybe_instruction: &_| {
139            matches!(
140                maybe_instruction,
141                LabelledInstruction::Instruction(
142                    AnInstruction::Assert | AnInstruction::AssertVector
143                )
144            )
145        };
146
147        let mut labelled_instructions = vec![];
148        for _ in 0..u.arbitrary_len::<LabelledInstruction>()? {
149            let labelled_instruction = u.arbitrary()?;
150            if contains_label(&labelled_instructions, &labelled_instruction) {
151                continue;
152            }
153            if let LabelledInstruction::AssertionContext(_) = labelled_instruction {
154                // assertion context must come after an assertion
155                continue;
156            }
157
158            let is_assertion = is_assertion(&labelled_instruction);
159            labelled_instructions.push(labelled_instruction);
160
161            if is_assertion && u.arbitrary()? {
162                let assertion_context = LabelledInstruction::AssertionContext(u.arbitrary()?);
163                labelled_instructions.push(assertion_context);
164            }
165        }
166
167        let all_call_targets = labelled_instructions
168            .iter()
169            .filter_map(|instruction| match instruction {
170                LabelledInstruction::Instruction(AnInstruction::Call(target)) => Some(target),
171                _ => None,
172            })
173            .unique();
174        let labels_that_are_called_but_not_declared = all_call_targets
175            .map(|target| LabelledInstruction::Label(target.clone()))
176            .filter(|label| !contains_label(&labelled_instructions, label))
177            .collect_vec();
178
179        for label in labels_that_are_called_but_not_declared {
180            let insertion_index = u.choose_index(labelled_instructions.len() + 1)?;
181            labelled_instructions.insert(insertion_index, label);
182        }
183
184        Ok(Program::new(&labelled_instructions))
185    }
186}
187
/// An iterator over the instructions of a [`Program`].
///
/// A program stores an instruction once for every word the instruction occupies.
/// An `InstructionIter` skips these duplicate placeholders: after yielding an
/// instruction, it advances by that instruction's size.
#[derive(Debug, Default, Clone, Eq, PartialEq)]
pub struct InstructionIter {
    /// The instruction words to iterate over. The cursor's position is the
    /// address of the instruction that is yielded next.
    cursor: Cursor<Vec<Instruction>>,
}
193
194impl Iterator for InstructionIter {
195    type Item = Instruction;
196
197    fn next(&mut self) -> Option<Self::Item> {
198        let pos = self.cursor.position() as usize;
199        let instructions = self.cursor.get_ref();
200        let instruction = *instructions.get(pos)?;
201        self.cursor.set_position((pos + instruction.size()) as u64);
202
203        Some(instruction)
204    }
205}
206
207impl IntoIterator for Program {
208    type Item = Instruction;
209
210    type IntoIter = InstructionIter;
211
212    fn into_iter(self) -> Self::IntoIter {
213        let cursor = Cursor::new(self.instructions);
214        InstructionIter { cursor }
215    }
216}
217
/// Information about a [`Program`] that is useful for debugging but has no
/// influence on the program's instructions.
#[derive(Debug, Default, Clone, Eq, PartialEq, Serialize, Deserialize, Arbitrary, GetSize)]
struct DebugInformation {
    /// One entry per instruction word, indexed by address: whether a breakpoint
    /// is set on the instruction occupying that word.
    breakpoints: Vec<bool>,

    /// The type hints declared at an address, i.e., directly before the
    /// instruction that starts at that address.
    type_hints: HashMap<u64, Vec<TypeHint>>,

    /// The assertion context, keyed by the address of the last word that
    /// precedes the context's declaration.
    assertion_context: HashMap<u64, AssertionContext>,
}
224
225impl Program {
226    pub fn new(labelled_instructions: &[LabelledInstruction]) -> Self {
227        let label_to_address = parser::build_label_to_address_map(labelled_instructions);
228        let instructions =
229            parser::turn_labels_into_addresses(labelled_instructions, &label_to_address);
230        let address_to_label = Self::flip_map(label_to_address);
231        let debug_information = Self::extract_debug_information(labelled_instructions);
232
233        debug_assert_eq!(instructions.len(), debug_information.breakpoints.len());
234        Program {
235            instructions,
236            address_to_label,
237            debug_information,
238        }
239    }
240
241    fn flip_map<Key, Value: Eq + Hash>(map: HashMap<Key, Value>) -> HashMap<Value, Key> {
242        map.into_iter().map(|(key, value)| (value, key)).collect()
243    }
244
245    fn extract_debug_information(
246        labelled_instructions: &[LabelledInstruction],
247    ) -> DebugInformation {
248        let mut address = 0;
249        let mut break_before_next_instruction = false;
250        let mut debug_info = DebugInformation::default();
251        for instruction in labelled_instructions {
252            match instruction {
253                LabelledInstruction::Instruction(instruction) => {
254                    let new_breakpoints = vec![break_before_next_instruction; instruction.size()];
255                    debug_info.breakpoints.extend(new_breakpoints);
256                    break_before_next_instruction = false;
257                    address += instruction.size() as u64;
258                }
259                LabelledInstruction::Label(_) => (),
260                LabelledInstruction::Breakpoint => break_before_next_instruction = true,
261                LabelledInstruction::TypeHint(hint) => match debug_info.type_hints.entry(address) {
262                    Entry::Occupied(mut entry) => entry.get_mut().push(hint.clone()),
263                    Entry::Vacant(entry) => entry.insert(vec![]).push(hint.clone()),
264                },
265                LabelledInstruction::AssertionContext(ctx) => {
266                    let address_of_associated_assertion = address.saturating_sub(1);
267                    debug_info
268                        .assertion_context
269                        .insert(address_of_associated_assertion, ctx.clone());
270                }
271            }
272        }
273
274        debug_info
275    }
276
277    /// Create a `Program` by parsing source code.
278    pub fn from_code(code: &str) -> Result<Self, ParseError> {
279        parser::parse(code)
280            .map(|tokens| parser::to_labelled_instructions(&tokens))
281            .map(|instructions| Program::new(&instructions))
282    }
283
284    pub fn labelled_instructions(&self) -> Vec<LabelledInstruction> {
285        let call_targets = self.call_targets();
286        let instructions_with_labels = self.instructions.iter().map(|instruction| {
287            instruction.map_call_address(|&address| self.label_for_address(address.value()))
288        });
289
290        let mut labelled_instructions = vec![];
291        let mut address = 0;
292        let mut instruction_stream = instructions_with_labels.into_iter();
293        while let Some(instruction) = instruction_stream.next() {
294            let instruction_size = instruction.size() as u64;
295            if call_targets.contains(&address) {
296                let label = self.label_for_address(address);
297                let label = LabelledInstruction::Label(label);
298                labelled_instructions.push(label);
299            }
300            for type_hint in self.type_hints_at(address) {
301                labelled_instructions.push(LabelledInstruction::TypeHint(type_hint));
302            }
303            if self.is_breakpoint(address) {
304                labelled_instructions.push(LabelledInstruction::Breakpoint);
305            }
306            labelled_instructions.push(LabelledInstruction::Instruction(instruction));
307            if let Some(context) = self.assertion_context_at(address) {
308                labelled_instructions.push(LabelledInstruction::AssertionContext(context));
309            }
310
311            for _ in 1..instruction_size {
312                instruction_stream.next();
313            }
314            address += instruction_size;
315        }
316
317        let leftover_labels = self
318            .address_to_label
319            .iter()
320            .filter(|(&labels_address, _)| labels_address >= address)
321            .sorted();
322        for (_, label) in leftover_labels {
323            labelled_instructions.push(LabelledInstruction::Label(label.clone()));
324        }
325
326        labelled_instructions
327    }
328
329    fn call_targets(&self) -> HashSet<u64> {
330        self.instructions
331            .iter()
332            .filter_map(|instruction| match instruction {
333                Instruction::Call(address) => Some(address.value()),
334                _ => None,
335            })
336            .collect()
337    }
338
339    pub fn is_breakpoint(&self, address: u64) -> bool {
340        let address: usize = address.try_into().unwrap();
341        self.debug_information
342            .breakpoints
343            .get(address)
344            .copied()
345            .unwrap_or_default()
346    }
347
348    pub fn type_hints_at(&self, address: u64) -> Vec<TypeHint> {
349        self.debug_information
350            .type_hints
351            .get(&address)
352            .cloned()
353            .unwrap_or_default()
354    }
355
356    pub fn assertion_context_at(&self, address: u64) -> Option<AssertionContext> {
357        self.debug_information
358            .assertion_context
359            .get(&address)
360            .cloned()
361    }
362
363    /// Turn the program into a sequence of `BFieldElement`s. Each instruction is encoded as its
364    /// opcode, followed by its argument (if any).
365    ///
366    /// **Note**: This is _almost_ (but not quite!) equivalent to [encoding](BFieldCodec::encode)
367    /// the program. For that, use [`encode()`](Self::encode()) instead.
368    pub fn to_bwords(&self) -> Vec<BFieldElement> {
369        self.clone()
370            .into_iter()
371            .flat_map(|instruction| {
372                let opcode = instruction.opcode_b();
373                if let Some(arg) = instruction.arg() {
374                    vec![opcode, arg]
375                } else {
376                    vec![opcode]
377                }
378            })
379            .collect()
380    }
381
382    /// The total length of the program as `BFieldElement`s. Double-word instructions contribute
383    /// two `BFieldElement`s.
384    pub fn len_bwords(&self) -> usize {
385        self.instructions.len()
386    }
387
388    pub fn is_empty(&self) -> bool {
389        self.instructions.is_empty()
390    }
391
392    /// Produces the program's canonical hash digest. Uses [`Tip5`], the
393    /// canonical hash function for Triton VM.
394    pub fn hash(&self) -> Digest {
395        // not encoded using `BFieldCodec` because that would prepend the length
396        Tip5::hash_varlen(&self.to_bwords())
397    }
398
399    /// The label for the given address, or a deterministic, unique substitute if no label is found.
400    pub fn label_for_address(&self, address: u64) -> String {
401        // Uniqueness of the label is relevant for printing and subsequent parsing:
402        // Parsing fails on duplicate labels.
403        self.address_to_label
404            .get(&address)
405            .cloned()
406            .unwrap_or_else(|| format!("address_{address}"))
407    }
408}
409
/// The reasons for which [decoding](BFieldCodec::decode) a [`Program`] can fail.
#[non_exhaustive]
#[derive(Debug, Clone, Eq, PartialEq, Error)]
pub enum ProgramDecodingError {
    /// The sequence contains no elements, not even the length prefix.
    #[error("sequence to decode is empty")]
    EmptySequence,

    /// Fewer elements follow the length prefix than the length prefix claims.
    #[error("sequence to decode is too short")]
    SequenceTooShort,

    /// More elements follow the length prefix than the length prefix claims.
    #[error("sequence to decode is too long")]
    SequenceTooLong,

    /// The length of the decoded program differs from the claimed length.
    #[error("length of decoded program is unexpected")]
    LengthMismatch,

    /// An opcode or an instruction's argument is invalid. The index is the
    /// position of the offending instruction's opcode, not counting the length
    /// prefix.
    #[error("sequence to decode contains invalid instruction at index {0}: {1}")]
    InvalidInstruction(usize, InstructionError),

    /// The sequence ends where the argument of the given instruction is
    /// expected. The index is the position of the instruction's opcode, not
    /// counting the length prefix.
    #[error("missing argument for instruction {1} at index {0}")]
    MissingArgument(usize, Instruction),
}
431
// Tests for encoding, decoding, hashing, parsing, and printing of programs,
// as well as for the propagation of debug information.
#[cfg(test)]
mod tests {
    use assert2::assert;
    use assert2::let_assert;
    use proptest::prelude::*;
    use proptest_arbitrary_interop::arb;
    use rand::Rng;
    use test_strategy::proptest;

    use crate::triton_program;

    use super::*;

    // Decoding an encoded program yields an equal program. Note that program
    // equality only compares instructions.
    #[proptest]
    fn random_program_encode_decode_equivalence(#[strategy(arb())] program: Program) {
        let encoding = program.encode();
        let decoding = *Program::decode(&encoding).unwrap();
        prop_assert_eq!(program, decoding);
    }

    // Cutting off the argument of the final `push` must be reported as a
    // missing argument at the index of that `push`.
    #[test]
    fn decode_program_with_missing_argument_as_last_instruction() {
        let program = triton_program!(push 3 push 3 eq assert push 3);
        let program_length = program.len_bwords() as u64;
        let encoded = program.encode();

        // drop the last word and adjust the length prefix accordingly
        let mut encoded = encoded[0..encoded.len() - 1].to_vec();
        encoded[0] = bfe!(program_length - 1);

        let_assert!(Err(err) = Program::decode(&encoded));
        let_assert!(ProgramDecodingError::MissingArgument(6, _) = err);
    }

    // A length prefix that claims more words than are present.
    #[test]
    fn decode_program_with_shorter_than_indicated_sequence() {
        let program = triton_program!(nop nop hash push 0 skiz end: halt call end);
        let mut encoded = program.encode();
        encoded[0] += bfe!(1);
        let_assert!(Err(err) = Program::decode(&encoded));
        let_assert!(ProgramDecodingError::SequenceTooShort = err);
    }

    // A length prefix that claims fewer words than are present.
    #[test]
    fn decode_program_with_longer_than_indicated_sequence() {
        let program = triton_program!(nop nop hash push 0 skiz end: halt call end);
        let mut encoded = program.encode();
        encoded[0] -= bfe!(1);
        let_assert!(Err(err) = Program::decode(&encoded));
        let_assert!(ProgramDecodingError::SequenceTooLong = err);
    }

    // Without a length prefix, there is nothing to decode.
    #[test]
    fn decode_program_from_empty_sequence() {
        let encoded = vec![];
        let_assert!(Err(err) = Program::decode(&encoded));
        let_assert!(ProgramDecodingError::EmptySequence = err);
    }

    // Pins the digest of a minimal program to guard against accidental changes
    // of the hashing procedure.
    #[test]
    fn hash_simple_program() {
        let program = triton_program!(halt);
        let digest = program.hash();

        let expected_digest = bfe_array![
            0x4338_de79_520b_3949_u64,
            0xe6a2_129b_2885_0dc9_u64,
            0xfd3c_d098_6a86_0450_u64,
            0x69fd_ba91_0ceb_a7bc_u64,
            0x7e5b_118c_9594_c062_u64,
        ];
        let expected_digest = Digest::new(expected_digest);

        assert!(expected_digest == digest);
    }

    #[test]
    fn empty_program_is_empty() {
        let program = triton_program!();
        assert!(program.is_empty());
    }

    // Parsing source code directly and going through the macro must agree,
    // also when values of various types are interpolated into the source.
    #[test]
    fn create_program_from_code() {
        let element_3 = rand::rng().random_range(0..BFieldElement::P);
        let element_2 = 1337_usize;
        let element_1 = "17";
        let element_0 = bfe!(0);
        let instruction_push = Instruction::Push(bfe!(42));
        let dup_arg = 1;
        let label = "my_label".to_string();

        let source_code = format!(
            "push {element_3} push {element_2} push {element_1} push {element_0}
             call {label} halt
             {label}:
                {instruction_push}
                dup {dup_arg}
                skiz
                recurse
                return"
        );
        let program_from_code = Program::from_code(&source_code).unwrap();
        let program_from_macro = triton_program!({ source_code });
        assert!(program_from_code == program_from_macro);
    }

    // The macro must accept an interpolated label as its very first token.
    #[test]
    fn parser_macro_with_interpolated_label_as_first_argument() {
        let label = "my_label";
        let _program = triton_program!(
            {label}: push 1 assert halt
        );
    }

    // A breakpoint applies to every word of the instruction following it.
    // Repeated breakpoints before the same instruction collapse into one.
    #[test]
    fn breakpoints_propagate_to_debug_information_as_expected() {
        let program = triton_program! {
            break push 1 push 2
            break break break break
            pop 2 hash halt
            break // no effect
        };

        assert!(program.is_breakpoint(0));
        assert!(program.is_breakpoint(1));
        assert!(!program.is_breakpoint(2));
        assert!(!program.is_breakpoint(3));
        assert!(program.is_breakpoint(4));
        assert!(program.is_breakpoint(5));
        assert!(!program.is_breakpoint(6));
        assert!(!program.is_breakpoint(7));

        // going beyond the length of the program must not break things
        assert!(!program.is_breakpoint(8));
        assert!(!program.is_breakpoint(9));
    }

    // Decoding discards all labels. Printing must work regardless.
    #[test]
    fn print_program_without_any_debug_information() {
        let program = triton_program! {
            call foo
            call bar
            call baz
            halt
            foo: nop nop return
            bar: call baz return
            baz: push 1 return
        };
        let encoding = program.encode();
        let program = Program::decode(&encoding).unwrap();
        println!("{program}");
    }

    // Whatever gets printed must be valid source code.
    #[proptest]
    fn printed_program_can_be_parsed_again(#[strategy(arb())] program: Program) {
        parser::parse(&program.to_string())?;
    }

    // Source code consisting of a single type hint, together with the type
    // hint that parsing this source code is expected to produce.
    struct TypeHintTestCase {
        expected: TypeHint,
        input: &'static str,
    }

    impl TypeHintTestCase {
        // Parse the input and compare the only type hint at address 0 to the
        // expected type hint.
        fn run(&self) {
            let program = Program::from_code(self.input).unwrap();
            let [ref type_hint] = program.type_hints_at(0)[..] else {
                panic!("Expected a single type hint at address 0");
            };
            assert!(&self.expected == type_hint);
        }
    }

    #[test]
    fn parse_simple_type_hint() {
        let expected = TypeHint {
            starting_index: 0,
            length: 1,
            type_name: Some("Type".to_string()),
            variable_name: "foo".to_string(),
        };

        TypeHintTestCase {
            expected,
            input: "hint foo: Type = stack[0]",
        }
        .run();
    }

    #[test]
    fn parse_type_hint_with_range() {
        let expected = TypeHint {
            starting_index: 0,
            length: 5,
            type_name: Some("Digest".to_string()),
            variable_name: "foo".to_string(),
        };

        TypeHintTestCase {
            expected,
            input: "hint foo: Digest = stack[0..5]",
        }
        .run();
    }

    #[test]
    fn parse_type_hint_with_range_and_offset() {
        let expected = TypeHint {
            starting_index: 7,
            length: 3,
            type_name: Some("XFieldElement".to_string()),
            variable_name: "bar".to_string(),
        };

        TypeHintTestCase {
            expected,
            input: "hint bar: XFieldElement = stack[7..10]",
        }
        .run();
    }

    #[test]
    fn parse_type_hint_with_range_and_offset_and_weird_whitespace() {
        let expected = TypeHint {
            starting_index: 2,
            length: 12,
            type_name: Some("BigType".to_string()),
            variable_name: "bar".to_string(),
        };

        TypeHintTestCase {
            expected,
            input: " hint \t \t bar  :BigType=stack[ 2\t.. 14 ]\t \n",
        }
        .run();
    }

    #[test]
    fn parse_type_hint_with_no_type_only_variable_name() {
        let expected = TypeHint {
            starting_index: 0,
            length: 1,
            type_name: None,
            variable_name: "foo".to_string(),
        };

        TypeHintTestCase {
            expected,
            input: "hint foo = stack[0]",
        }
        .run();
    }

    #[test]
    fn parse_type_hint_with_no_type_only_variable_name_with_range() {
        let expected = TypeHint {
            starting_index: 2,
            length: 5,
            type_name: None,
            variable_name: "foo".to_string(),
        };

        TypeHintTestCase {
            expected,
            input: "hint foo = stack[2..7]",
        }
        .run();
    }

    // The assertion context must be stored under the address of the `assert`.
    #[test]
    fn assertion_context_is_propagated_into_debug_info() {
        let program = triton_program! {push 1000 assert error_id 17 halt};
        //                              ↑0   ↑1   ↑2

        let assertion_contexts = program.debug_information.assertion_context;
        assert!(1 == assertion_contexts.len());
        let_assert!(AssertionContext::ID(error_id) = &assertion_contexts[&2]);
        assert!(17 == *error_id);
    }

    // Printing a parsed program must reproduce the source code verbatim,
    // including labels, breakpoints, type hints, and assertion contexts.
    #[test]
    fn printing_program_includes_debug_information() {
        let source_code = "\
            call foo\n\
            break\n\
            call bar\n\
            halt\n\
            foo:\n\
            break\n\
            call baz\n\
            push 1\n\
            nop\n\
            return\n\
            baz:\n\
            hash\n\
            hint my_digest: Digest = stack[0..5]\n\
            hint random_stuff = stack[17]\n\
            return\n\
            nop\n\
            pop 1\n\
            bar:\n\
            divine 1\n\
            hint got_insight: Magic = stack[0]\n\
            skiz\n\
            split\n\
            break\n\
            assert\n\
            error_id 1337\n\
            return\n\
        ";
        let program = Program::from_code(source_code).unwrap();
        let printed_program = format!("{program}");
        assert_eq!(source_code, &printed_program);
    }
}