1use std::collections::HashMap;
2use std::collections::HashSet;
3use std::collections::hash_map::Entry;
4use std::fmt::Display;
5use std::fmt::Formatter;
6use std::fmt::Result as FmtResult;
7use std::hash::Hash;
8use std::io::Cursor;
9
10use arbitrary::Arbitrary;
11use get_size2::GetSize;
12use itertools::Itertools;
13use serde::Deserialize;
14use serde::Serialize;
15use thiserror::Error;
16use twenty_first::prelude::*;
17
18use crate::instruction::AnInstruction;
19use crate::instruction::AssertionContext;
20use crate::instruction::Instruction;
21use crate::instruction::InstructionError;
22use crate::instruction::LabelledInstruction;
23use crate::instruction::TypeHint;
24use crate::parser;
25use crate::parser::ParseError;
26
/// A Triton VM program: a sequence of [`Instruction`]s plus optional debugging metadata.
///
/// The vector `instructions` is indexed by address: an instruction occupying more than one
/// word (see `Instruction::size`) is stored once per word it occupies.
///
/// Only `instructions` takes part in equality (see the manual `PartialEq` impl) and in the
/// [`BFieldCodec`] encoding; labels and debug information are neither compared nor encoded.
#[derive(Debug, Clone, Eq, Serialize, Deserialize, GetSize)]
pub struct Program {
    /// The program's instructions, one entry per word; multi-word instructions are repeated.
    pub instructions: Vec<Instruction>,
    /// Maps addresses back to the label names declared there; used when printing.
    address_to_label: HashMap<u64, String>,
    /// Breakpoints, type hints, and assertion contexts.
    debug_information: DebugInformation,
}
46
47impl Display for Program {
48 fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
49 for instruction in self.labelled_instructions() {
50 writeln!(f, "{instruction}")?;
51 }
52 Ok(())
53 }
54}
55
56impl PartialEq for Program {
57 fn eq(&self, other: &Program) -> bool {
58 self.instructions.eq(&other.instructions)
59 }
60}
61
62impl BFieldCodec for Program {
63 type Error = ProgramDecodingError;
64
65 fn decode(sequence: &[BFieldElement]) -> Result<Box<Self>, Self::Error> {
66 if sequence.is_empty() {
67 return Err(Self::Error::EmptySequence);
68 }
69 let program_length = sequence[0].value() as usize;
70 let sequence = &sequence[1..];
71 if sequence.len() < program_length {
72 return Err(Self::Error::SequenceTooShort);
73 }
74 if sequence.len() > program_length {
75 return Err(Self::Error::SequenceTooLong);
76 }
77
78 let mut instructions = vec![];
80 let mut read_idx = 0;
81 while read_idx < program_length {
82 let opcode = sequence[read_idx];
83 let mut instruction = Instruction::try_from(opcode)
84 .map_err(|err| Self::Error::InvalidInstruction(read_idx, err))?;
85 let instruction_has_arg = instruction.arg().is_some();
86 if instruction_has_arg && instructions.len() + instruction.size() > program_length {
87 return Err(Self::Error::MissingArgument(read_idx, instruction));
88 }
89 if instruction_has_arg {
90 let arg = sequence[read_idx + 1];
91 instruction = instruction
92 .change_arg(arg)
93 .map_err(|err| Self::Error::InvalidInstruction(read_idx, err))?;
94 }
95
96 instructions.extend(vec![instruction; instruction.size()]);
97 read_idx += instruction.size();
98 }
99
100 if read_idx != program_length {
101 return Err(Self::Error::LengthMismatch);
102 }
103 if instructions.len() != program_length {
104 return Err(Self::Error::LengthMismatch);
105 }
106
107 Ok(Box::new(Program {
108 instructions,
109 address_to_label: HashMap::default(),
110 debug_information: DebugInformation::default(),
111 }))
112 }
113
114 fn encode(&self) -> Vec<BFieldElement> {
115 let mut sequence = Vec::with_capacity(self.len_bwords() + 1);
116 sequence.push(bfe!(self.len_bwords() as u64));
117 sequence.extend(self.to_bwords());
118 sequence
119 }
120
121 fn static_length() -> Option<usize> {
122 None
123 }
124}
125
126impl<'a> Arbitrary<'a> for Program {
127 fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
128 let contains_label = |labelled_instructions: &[_], maybe_label: &_| {
129 let LabelledInstruction::Label(label) = maybe_label else {
130 return false;
131 };
132 labelled_instructions
133 .iter()
134 .any(|labelled_instruction| match labelled_instruction {
135 LabelledInstruction::Label(l) => l == label,
136 _ => false,
137 })
138 };
139 let is_assertion = |maybe_instruction: &_| {
140 matches!(
141 maybe_instruction,
142 LabelledInstruction::Instruction(
143 AnInstruction::Assert | AnInstruction::AssertVector
144 )
145 )
146 };
147
148 let mut labelled_instructions = vec![];
149 for _ in 0..u.arbitrary_len::<LabelledInstruction>()? {
150 let labelled_instruction = u.arbitrary()?;
151 if contains_label(&labelled_instructions, &labelled_instruction) {
152 continue;
153 }
154 if let LabelledInstruction::AssertionContext(_) = labelled_instruction {
155 continue;
157 }
158
159 let is_assertion = is_assertion(&labelled_instruction);
160 labelled_instructions.push(labelled_instruction);
161
162 if is_assertion && u.arbitrary()? {
163 let assertion_context = LabelledInstruction::AssertionContext(u.arbitrary()?);
164 labelled_instructions.push(assertion_context);
165 }
166 }
167
168 let all_call_targets = labelled_instructions
169 .iter()
170 .filter_map(|instruction| match instruction {
171 LabelledInstruction::Instruction(AnInstruction::Call(target)) => Some(target),
172 _ => None,
173 })
174 .unique();
175 let labels_that_are_called_but_not_declared = all_call_targets
176 .map(|target| LabelledInstruction::Label(target.clone()))
177 .filter(|label| !contains_label(&labelled_instructions, label))
178 .collect_vec();
179
180 for label in labels_that_are_called_but_not_declared {
181 let insertion_index = u.choose_index(labelled_instructions.len() + 1)?;
182 labelled_instructions.insert(insertion_index, label);
183 }
184
185 Ok(Program::new(&labelled_instructions))
186 }
187}
188
/// An iterator over the [`Instruction`]s of a [`Program`] that yields each instruction once.
///
/// In the underlying vector, an instruction occupies `Instruction::size` consecutive slots;
/// the iterator advances its cursor past all of them.
#[derive(Debug, Default, Clone, Eq, PartialEq)]
pub struct InstructionIter {
    /// The program's instructions and the position (address) of the next one to yield.
    cursor: Cursor<Vec<Instruction>>,
}
195
196impl Iterator for InstructionIter {
197 type Item = Instruction;
198
199 fn next(&mut self) -> Option<Self::Item> {
200 let pos = self.cursor.position() as usize;
201 let instructions = self.cursor.get_ref();
202 let instruction = *instructions.get(pos)?;
203 self.cursor.set_position((pos + instruction.size()) as u64);
204
205 Some(instruction)
206 }
207}
208
209impl IntoIterator for Program {
210 type Item = Instruction;
211
212 type IntoIter = InstructionIter;
213
214 fn into_iter(self) -> Self::IntoIter {
215 let cursor = Cursor::new(self.instructions);
216 InstructionIter { cursor }
217 }
218}
219
/// Debugging metadata of a [`Program`], extracted from its labelled instructions by
/// `Program::extract_debug_information`.
///
/// None of it is part of the program's [`BFieldCodec`] encoding or of program equality.
#[derive(Debug, Default, Clone, Eq, PartialEq, Serialize, Deserialize, Arbitrary, GetSize)]
struct DebugInformation {
    /// One flag per instruction word; `true` if a `break` directly preceded the instruction
    /// occupying that word.
    breakpoints: Vec<bool>,
    /// Type hints, keyed by the address of the instruction they precede.
    type_hints: HashMap<u64, Vec<TypeHint>>,
    /// Assertion contexts, keyed by the address of the last word of the instruction they
    /// follow (an assertion, as the `Arbitrary` generator guarantees; presumably the parser
    /// enforces the same — confirm).
    assertion_context: HashMap<u64, AssertionContext>,
}
226
227impl Program {
228 pub fn new(labelled_instructions: &[LabelledInstruction]) -> Self {
229 let label_to_address = parser::build_label_to_address_map(labelled_instructions);
230 let instructions =
231 parser::turn_labels_into_addresses(labelled_instructions, &label_to_address);
232 let address_to_label = Self::flip_map(label_to_address);
233 let debug_information = Self::extract_debug_information(labelled_instructions);
234
235 debug_assert_eq!(instructions.len(), debug_information.breakpoints.len());
236 Program {
237 instructions,
238 address_to_label,
239 debug_information,
240 }
241 }
242
243 fn flip_map<Key, Value: Eq + Hash>(map: HashMap<Key, Value>) -> HashMap<Value, Key> {
244 map.into_iter().map(|(key, value)| (value, key)).collect()
245 }
246
247 fn extract_debug_information(
248 labelled_instructions: &[LabelledInstruction],
249 ) -> DebugInformation {
250 let mut address = 0;
251 let mut break_before_next_instruction = false;
252 let mut debug_info = DebugInformation::default();
253 for instruction in labelled_instructions {
254 match instruction {
255 LabelledInstruction::Instruction(instruction) => {
256 let new_breakpoints = vec![break_before_next_instruction; instruction.size()];
257 debug_info.breakpoints.extend(new_breakpoints);
258 break_before_next_instruction = false;
259 address += instruction.size() as u64;
260 }
261 LabelledInstruction::Label(_) => (),
262 LabelledInstruction::Breakpoint => break_before_next_instruction = true,
263 LabelledInstruction::TypeHint(hint) => match debug_info.type_hints.entry(address) {
264 Entry::Occupied(mut entry) => entry.get_mut().push(hint.clone()),
265 Entry::Vacant(entry) => entry.insert(vec![]).push(hint.clone()),
266 },
267 LabelledInstruction::AssertionContext(ctx) => {
268 let address_of_associated_assertion = address.saturating_sub(1);
269 debug_info
270 .assertion_context
271 .insert(address_of_associated_assertion, ctx.clone());
272 }
273 }
274 }
275
276 debug_info
277 }
278
279 pub fn from_code(code: &str) -> Result<Self, ParseError<'_>> {
281 parser::parse(code)
282 .map(|tokens| parser::to_labelled_instructions(&tokens))
283 .map(|instructions| Program::new(&instructions))
284 }
285
286 pub fn labelled_instructions(&self) -> Vec<LabelledInstruction> {
287 let call_targets = self.call_targets();
288 let instructions_with_labels = self.instructions.iter().map(|instruction| {
289 instruction.map_call_address(|&address| self.label_for_address(address.value()))
290 });
291
292 let mut labelled_instructions = vec![];
293 let mut address = 0;
294 let mut instruction_stream = instructions_with_labels.into_iter();
295 while let Some(instruction) = instruction_stream.next() {
296 let instruction_size = instruction.size() as u64;
297 if call_targets.contains(&address) {
298 let label = self.label_for_address(address);
299 let label = LabelledInstruction::Label(label);
300 labelled_instructions.push(label);
301 }
302 for type_hint in self.type_hints_at(address) {
303 labelled_instructions.push(LabelledInstruction::TypeHint(type_hint));
304 }
305 if self.is_breakpoint(address) {
306 labelled_instructions.push(LabelledInstruction::Breakpoint);
307 }
308 labelled_instructions.push(LabelledInstruction::Instruction(instruction));
309 if let Some(context) = self.assertion_context_at(address) {
310 labelled_instructions.push(LabelledInstruction::AssertionContext(context));
311 }
312
313 for _ in 1..instruction_size {
314 instruction_stream.next();
315 }
316 address += instruction_size;
317 }
318
319 let leftover_labels = self
320 .address_to_label
321 .iter()
322 .filter(|&(&labels_address, _)| labels_address >= address)
323 .sorted();
324 for (_, label) in leftover_labels {
325 labelled_instructions.push(LabelledInstruction::Label(label.clone()));
326 }
327
328 labelled_instructions
329 }
330
331 fn call_targets(&self) -> HashSet<u64> {
332 self.instructions
333 .iter()
334 .filter_map(|instruction| match instruction {
335 Instruction::Call(address) => Some(address.value()),
336 _ => None,
337 })
338 .collect()
339 }
340
341 pub fn is_breakpoint(&self, address: u64) -> bool {
342 let address: usize = address.try_into().unwrap();
343 self.debug_information
344 .breakpoints
345 .get(address)
346 .copied()
347 .unwrap_or_default()
348 }
349
350 pub fn type_hints_at(&self, address: u64) -> Vec<TypeHint> {
351 self.debug_information
352 .type_hints
353 .get(&address)
354 .cloned()
355 .unwrap_or_default()
356 }
357
358 pub fn assertion_context_at(&self, address: u64) -> Option<AssertionContext> {
359 self.debug_information
360 .assertion_context
361 .get(&address)
362 .cloned()
363 }
364
365 pub fn to_bwords(&self) -> Vec<BFieldElement> {
372 self.clone()
373 .into_iter()
374 .flat_map(|instruction| {
375 let opcode = instruction.opcode_b();
376 if let Some(arg) = instruction.arg() {
377 vec![opcode, arg]
378 } else {
379 vec![opcode]
380 }
381 })
382 .collect()
383 }
384
385 pub fn len_bwords(&self) -> usize {
388 self.instructions.len()
389 }
390
391 pub fn is_empty(&self) -> bool {
392 self.instructions.is_empty()
393 }
394
395 pub fn hash(&self) -> Digest {
398 Tip5::hash_varlen(&self.to_bwords())
400 }
401
402 pub fn label_for_address(&self, address: u64) -> String {
405 self.address_to_label
408 .get(&address)
409 .cloned()
410 .unwrap_or_else(|| format!("address_{address}"))
411 }
412}
413
/// Errors that can occur when decoding a [`Program`] via [`BFieldCodec::decode`].
///
/// Indices refer to positions in the sequence after the length prefix has been removed.
#[non_exhaustive]
#[derive(Debug, Clone, Eq, PartialEq, Error)]
pub enum ProgramDecodingError {
    /// The sequence does not even contain a length prefix.
    #[error("sequence to decode is empty")]
    EmptySequence,

    /// Fewer words follow the length prefix than it announces.
    #[error("sequence to decode is too short")]
    SequenceTooShort,

    /// More words follow the length prefix than it announces.
    #[error("sequence to decode is too long")]
    SequenceTooLong,

    /// The number of decoded words does not match the length prefix.
    #[error("length of decoded program is unexpected")]
    LengthMismatch,

    /// The word at the given index is not a valid opcode, or the argument that follows it
    /// is invalid for that instruction.
    #[error("sequence to decode contains invalid instruction at index {0}: {1}")]
    InvalidInstruction(usize, InstructionError),

    /// The instruction at the given index requires an argument, but the sequence ends before
    /// the argument.
    #[error("missing argument for instruction {1} at index {0}")]
    MissingArgument(usize, Instruction),
}
435
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use assert2::assert;
    use assert2::let_assert;
    use proptest::prelude::*;
    use proptest_arbitrary_interop::arb;
    use rand::Rng;
    use test_strategy::proptest;

    use crate::triton_program;

    use super::*;

    // Encoding, then decoding, yields an equal program. Equality only compares
    // instructions, so the lost labels and debug information do not matter.
    #[proptest]
    fn random_program_encode_decode_equivalence(#[strategy(arb())] program: Program) {
        let encoding = program.encode();
        let decoding = *Program::decode(&encoding).unwrap();
        prop_assert_eq!(program, decoding);
    }

    // Drops the final `push`'s argument and shrinks the length prefix to match. The last
    // `push` sits at index 6 of the 8-word program.
    #[test]
    fn decode_program_with_missing_argument_as_last_instruction() {
        let program = triton_program!(push 3 push 3 eq assert push 3);
        let program_length = program.len_bwords() as u64;
        let encoded = program.encode();

        let mut encoded = encoded[0..encoded.len() - 1].to_vec();
        encoded[0] = bfe!(program_length - 1);

        let_assert!(Err(err) = Program::decode(&encoded));
        let_assert!(ProgramDecodingError::MissingArgument(6, _) = err);
    }

    // The length prefix announces one more word than is present.
    #[test]
    fn decode_program_with_shorter_than_indicated_sequence() {
        let program = triton_program!(nop nop hash push 0 skiz end: halt call end);
        let mut encoded = program.encode();
        encoded[0] += bfe!(1);
        let_assert!(Err(err) = Program::decode(&encoded));
        let_assert!(ProgramDecodingError::SequenceTooShort = err);
    }

    // The length prefix announces one word fewer than is present.
    #[test]
    fn decode_program_with_longer_than_indicated_sequence() {
        let program = triton_program!(nop nop hash push 0 skiz end: halt call end);
        let mut encoded = program.encode();
        encoded[0] -= bfe!(1);
        let_assert!(Err(err) = Program::decode(&encoded));
        let_assert!(ProgramDecodingError::SequenceTooLong = err);
    }

    #[test]
    fn decode_program_from_empty_sequence() {
        let encoded = vec![];
        let_assert!(Err(err) = Program::decode(&encoded));
        let_assert!(ProgramDecodingError::EmptySequence = err);
    }

    // Pins the digest of a known program; any change to the hashed words breaks this test.
    #[test]
    fn hash_simple_program() {
        let program = triton_program!(halt);
        let digest = program.hash();

        let expected_digest = bfe_array![
            0x4338_de79_520b_3949_u64,
            0xe6a2_129b_2885_0dc9_u64,
            0xfd3c_d098_6a86_0450_u64,
            0x69fd_ba91_0ceb_a7bc_u64,
            0x7e5b_118c_9594_c062_u64,
        ];
        let expected_digest = Digest::new(expected_digest);

        assert!(expected_digest == digest);
    }

    #[test]
    fn empty_program_is_empty() {
        let program = triton_program!();
        assert!(program.is_empty());
    }

    // Parsing at runtime and via the macro must agree, including for interpolated
    // arguments of various types.
    #[test]
    fn create_program_from_code() {
        let element_3 = rand::rng().random_range(0..BFieldElement::P);
        let element_2 = 1337_usize;
        let element_1 = "17";
        let element_0 = bfe!(0);
        let instruction_push = Instruction::Push(bfe!(42));
        let dup_arg = 1;
        let label = "my_label".to_string();

        let source_code = format!(
            "push {element_3} push {element_2} push {element_1} push {element_0}
 call {label} halt
 {label}:
 {instruction_push}
 dup {dup_arg}
 skiz
 recurse
 return"
        );
        let program_from_code = Program::from_code(&source_code).unwrap();
        let program_from_macro = triton_program!({ source_code });
        assert!(program_from_code == program_from_macro);
    }

    // Compiling is the test: an interpolated label must be accepted in first position.
    #[test]
    fn parser_macro_with_interpolated_label_as_first_argument() {
        let label = "my_label";
        let _program = triton_program!(
            {label}: push 1 assert halt
        );
    }

    // A `break` marks every word of the next instruction. Consecutive `break`s collapse
    // into one, and a trailing `break` without a following instruction marks nothing.
    #[test]
    fn breakpoints_propagate_to_debug_information_as_expected() {
        let program = triton_program! {
            break push 1 push 2
            break break break break
            pop 2 hash halt
            break };

        assert!(program.is_breakpoint(0));
        assert!(program.is_breakpoint(1));
        assert!(!program.is_breakpoint(2));
        assert!(!program.is_breakpoint(3));
        assert!(program.is_breakpoint(4));
        assert!(program.is_breakpoint(5));
        assert!(!program.is_breakpoint(6));
        assert!(!program.is_breakpoint(7));

        // Addresses past the end of the program.
        assert!(!program.is_breakpoint(8));
        assert!(!program.is_breakpoint(9));
    }

    // Printing a decoded program (no labels, no debug info) must not panic; the call
    // targets get synthesized `address_N` labels.
    #[test]
    fn print_program_without_any_debug_information() {
        let program = triton_program! {
            call foo
            call bar
            call baz
            halt
            foo: nop nop return
            bar: call baz return
            baz: push 1 return
        };
        let encoding = program.encode();
        let program = Program::decode(&encoding).unwrap();
        println!("{program}");
    }

    #[proptest]
    fn printed_program_can_be_parsed_again(#[strategy(arb())] program: Program) {
        parser::parse(&program.to_string())?;
    }

    // Parses `input` and expects exactly one type hint, `expected`, at address 0.
    struct TypeHintTestCase {
        expected: TypeHint,
        input: &'static str,
    }

    impl TypeHintTestCase {
        fn run(&self) {
            let program = Program::from_code(self.input).unwrap();
            let [ref type_hint] = program.type_hints_at(0)[..] else {
                panic!("Expected a single type hint at address 0");
            };
            assert!(&self.expected == type_hint);
        }
    }

    #[test]
    fn parse_simple_type_hint() {
        let expected = TypeHint {
            starting_index: 0,
            length: 1,
            type_name: Some("Type".to_string()),
            variable_name: "foo".to_string(),
        };

        TypeHintTestCase {
            expected,
            input: "hint foo: Type = stack[0]",
        }
        .run();
    }

    #[test]
    fn parse_type_hint_with_range() {
        let expected = TypeHint {
            starting_index: 0,
            length: 5,
            type_name: Some("Digest".to_string()),
            variable_name: "foo".to_string(),
        };

        TypeHintTestCase {
            expected,
            input: "hint foo: Digest = stack[0..5]",
        }
        .run();
    }

    #[test]
    fn parse_type_hint_with_range_and_offset() {
        let expected = TypeHint {
            starting_index: 7,
            length: 3,
            type_name: Some("XFieldElement".to_string()),
            variable_name: "bar".to_string(),
        };

        TypeHintTestCase {
            expected,
            input: "hint bar: XFieldElement = stack[7..10]",
        }
        .run();
    }

    #[test]
    fn parse_type_hint_with_range_and_offset_and_weird_whitespace() {
        let expected = TypeHint {
            starting_index: 2,
            length: 12,
            type_name: Some("BigType".to_string()),
            variable_name: "bar".to_string(),
        };

        TypeHintTestCase {
            expected,
            input: " hint \t \t bar :BigType=stack[ 2\t.. 14 ]\t \n",
        }
        .run();
    }

    #[test]
    fn parse_type_hint_with_no_type_only_variable_name() {
        let expected = TypeHint {
            starting_index: 0,
            length: 1,
            type_name: None,
            variable_name: "foo".to_string(),
        };

        TypeHintTestCase {
            expected,
            input: "hint foo = stack[0]",
        }
        .run();
    }

    #[test]
    fn parse_type_hint_with_no_type_only_variable_name_with_range() {
        let expected = TypeHint {
            starting_index: 2,
            length: 5,
            type_name: None,
            variable_name: "foo".to_string(),
        };

        TypeHintTestCase {
            expected,
            input: "hint foo = stack[2..7]",
        }
        .run();
    }

    // `push 1000` occupies addresses 0 and 1, so the context belongs to `assert` at 2.
    #[test]
    fn assertion_context_is_propagated_into_debug_info() {
        let program = triton_program! {push 1000 assert error_id 17 halt};
        let assertion_contexts = program.debug_information.assertion_context;
        assert!(1 == assertion_contexts.len());
        let_assert!(AssertionContext::ID(error_id) = &assertion_contexts[&2]);
        assert!(17 == *error_id);
    }

    // Round trip: parsing, then printing, reproduces the source exactly, including
    // breakpoints, type hints, assertion contexts, and dead code after `return`.
    #[test]
    fn printing_program_includes_debug_information() {
        let source_code = "\
            call foo\n\
            break\n\
            call bar\n\
            halt\n\
            foo:\n\
            break\n\
            call baz\n\
            push 1\n\
            nop\n\
            return\n\
            baz:\n\
            hash\n\
            hint my_digest: Digest = stack[0..5]\n\
            hint random_stuff = stack[17]\n\
            return\n\
            nop\n\
            pop 1\n\
            bar:\n\
            divine 1\n\
            hint got_insight: Magic = stack[0]\n\
            skiz\n\
            split\n\
            break\n\
            assert\n\
            error_id 1337\n\
            return\n\
            ";
        let program = Program::from_code(source_code).unwrap();
        let printed_program = format!("{program}");
        assert_eq!(source_code, &printed_program);
    }
}