1use std::collections::hash_map::Entry;
2use std::collections::HashMap;
3use std::collections::HashSet;
4use std::fmt::Display;
5use std::fmt::Formatter;
6use std::fmt::Result as FmtResult;
7use std::hash::Hash;
8use std::io::Cursor;
9
10use arbitrary::Arbitrary;
11use get_size2::GetSize;
12use itertools::Itertools;
13use serde::Deserialize;
14use serde::Serialize;
15use thiserror::Error;
16use twenty_first::prelude::*;
17
18use crate::instruction::AnInstruction;
19use crate::instruction::AssertionContext;
20use crate::instruction::Instruction;
21use crate::instruction::InstructionError;
22use crate::instruction::LabelledInstruction;
23use crate::instruction::TypeHint;
24use crate::parser;
25use crate::parser::ParseError;
26
/// A program: a flat list of instructions plus optional label and debug information.
///
/// Equality between programs is decided by the `instructions` alone; labels and debug
/// information take no part in the comparison. The `BFieldCodec` encoding likewise carries
/// only the instructions, so a decoded program has empty label and debug information.
#[derive(Debug, Clone, Eq, Serialize, Deserialize, GetSize)]
pub struct Program {
    /// The instructions, indexed by address. An instruction whose size is larger than one
    /// is repeated so that it fills every address it occupies.
    pub instructions: Vec<Instruction>,

    /// Label names, keyed by the address they refer to.
    address_to_label: HashMap<u64, String>,

    /// Breakpoints, type hints, and assertion contexts collected while assembling.
    debug_information: DebugInformation,
}
45
46impl Display for Program {
47 fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
48 for instruction in self.labelled_instructions() {
49 writeln!(f, "{instruction}")?;
50 }
51 Ok(())
52 }
53}
54
55impl PartialEq for Program {
56 fn eq(&self, other: &Program) -> bool {
57 self.instructions.eq(&other.instructions)
58 }
59}
60
61impl BFieldCodec for Program {
62 type Error = ProgramDecodingError;
63
64 fn decode(sequence: &[BFieldElement]) -> Result<Box<Self>, Self::Error> {
65 if sequence.is_empty() {
66 return Err(Self::Error::EmptySequence);
67 }
68 let program_length = sequence[0].value() as usize;
69 let sequence = &sequence[1..];
70 if sequence.len() < program_length {
71 return Err(Self::Error::SequenceTooShort);
72 }
73 if sequence.len() > program_length {
74 return Err(Self::Error::SequenceTooLong);
75 }
76
77 let mut instructions = vec![];
79 let mut read_idx = 0;
80 while read_idx < program_length {
81 let opcode = sequence[read_idx];
82 let mut instruction = Instruction::try_from(opcode)
83 .map_err(|err| Self::Error::InvalidInstruction(read_idx, err))?;
84 let instruction_has_arg = instruction.arg().is_some();
85 if instruction_has_arg && instructions.len() + instruction.size() > program_length {
86 return Err(Self::Error::MissingArgument(read_idx, instruction));
87 }
88 if instruction_has_arg {
89 let arg = sequence[read_idx + 1];
90 instruction = instruction
91 .change_arg(arg)
92 .map_err(|err| Self::Error::InvalidInstruction(read_idx, err))?;
93 }
94
95 instructions.extend(vec![instruction; instruction.size()]);
96 read_idx += instruction.size();
97 }
98
99 if read_idx != program_length {
100 return Err(Self::Error::LengthMismatch);
101 }
102 if instructions.len() != program_length {
103 return Err(Self::Error::LengthMismatch);
104 }
105
106 Ok(Box::new(Program {
107 instructions,
108 address_to_label: HashMap::default(),
109 debug_information: DebugInformation::default(),
110 }))
111 }
112
113 fn encode(&self) -> Vec<BFieldElement> {
114 let mut sequence = Vec::with_capacity(self.len_bwords() + 1);
115 sequence.push(bfe!(self.len_bwords() as u64));
116 sequence.extend(self.to_bwords());
117 sequence
118 }
119
120 fn static_length() -> Option<usize> {
121 None
122 }
123}
124
125impl<'a> Arbitrary<'a> for Program {
126 fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
127 let contains_label = |labelled_instructions: &[_], maybe_label: &_| {
128 let LabelledInstruction::Label(label) = maybe_label else {
129 return false;
130 };
131 labelled_instructions
132 .iter()
133 .any(|labelled_instruction| match labelled_instruction {
134 LabelledInstruction::Label(l) => l == label,
135 _ => false,
136 })
137 };
138 let is_assertion = |maybe_instruction: &_| {
139 matches!(
140 maybe_instruction,
141 LabelledInstruction::Instruction(
142 AnInstruction::Assert | AnInstruction::AssertVector
143 )
144 )
145 };
146
147 let mut labelled_instructions = vec![];
148 for _ in 0..u.arbitrary_len::<LabelledInstruction>()? {
149 let labelled_instruction = u.arbitrary()?;
150 if contains_label(&labelled_instructions, &labelled_instruction) {
151 continue;
152 }
153 if let LabelledInstruction::AssertionContext(_) = labelled_instruction {
154 continue;
156 }
157
158 let is_assertion = is_assertion(&labelled_instruction);
159 labelled_instructions.push(labelled_instruction);
160
161 if is_assertion && u.arbitrary()? {
162 let assertion_context = LabelledInstruction::AssertionContext(u.arbitrary()?);
163 labelled_instructions.push(assertion_context);
164 }
165 }
166
167 let all_call_targets = labelled_instructions
168 .iter()
169 .filter_map(|instruction| match instruction {
170 LabelledInstruction::Instruction(AnInstruction::Call(target)) => Some(target),
171 _ => None,
172 })
173 .unique();
174 let labels_that_are_called_but_not_declared = all_call_targets
175 .map(|target| LabelledInstruction::Label(target.clone()))
176 .filter(|label| !contains_label(&labelled_instructions, label))
177 .collect_vec();
178
179 for label in labels_that_are_called_but_not_declared {
180 let insertion_index = u.choose_index(labelled_instructions.len() + 1)?;
181 labelled_instructions.insert(insertion_index, label);
182 }
183
184 Ok(Program::new(&labelled_instructions))
185 }
186}
187
/// Owning iterator over a program's instructions.
///
/// The wrapped list holds one entry per address. The iterator advances by the size of the
/// instruction it just yielded, so an instruction that spans several addresses is yielded
/// once rather than once per address.
#[derive(Debug, Default, Clone, Eq, PartialEq)]
pub struct InstructionIter {
    /// The instruction list together with the current read position.
    cursor: Cursor<Vec<Instruction>>,
}
193
194impl Iterator for InstructionIter {
195 type Item = Instruction;
196
197 fn next(&mut self) -> Option<Self::Item> {
198 let pos = self.cursor.position() as usize;
199 let instructions = self.cursor.get_ref();
200 let instruction = *instructions.get(pos)?;
201 self.cursor.set_position((pos + instruction.size()) as u64);
202
203 Some(instruction)
204 }
205}
206
207impl IntoIterator for Program {
208 type Item = Instruction;
209
210 type IntoIter = InstructionIter;
211
212 fn into_iter(self) -> Self::IntoIter {
213 let cursor = Cursor::new(self.instructions);
214 InstructionIter { cursor }
215 }
216}
217
/// Debug information gathered from the non-instruction items of a labelled program.
#[derive(Debug, Default, Clone, Eq, PartialEq, Serialize, Deserialize, Arbitrary, GetSize)]
struct DebugInformation {
    /// One flag per address: whether a breakpoint was declared directly before the
    /// instruction occupying that address.
    breakpoints: Vec<bool>,

    /// Type hints, keyed by the address of the instruction that follows them.
    type_hints: HashMap<u64, Vec<TypeHint>>,

    /// Assertion contexts, keyed by the last address occupied by the instruction preceding
    /// the context. For a one-word assertion, that is the address of the assertion itself.
    assertion_context: HashMap<u64, AssertionContext>,
}
224
225impl Program {
226 pub fn new(labelled_instructions: &[LabelledInstruction]) -> Self {
227 let label_to_address = parser::build_label_to_address_map(labelled_instructions);
228 let instructions =
229 parser::turn_labels_into_addresses(labelled_instructions, &label_to_address);
230 let address_to_label = Self::flip_map(label_to_address);
231 let debug_information = Self::extract_debug_information(labelled_instructions);
232
233 debug_assert_eq!(instructions.len(), debug_information.breakpoints.len());
234 Program {
235 instructions,
236 address_to_label,
237 debug_information,
238 }
239 }
240
241 fn flip_map<Key, Value: Eq + Hash>(map: HashMap<Key, Value>) -> HashMap<Value, Key> {
242 map.into_iter().map(|(key, value)| (value, key)).collect()
243 }
244
245 fn extract_debug_information(
246 labelled_instructions: &[LabelledInstruction],
247 ) -> DebugInformation {
248 let mut address = 0;
249 let mut break_before_next_instruction = false;
250 let mut debug_info = DebugInformation::default();
251 for instruction in labelled_instructions {
252 match instruction {
253 LabelledInstruction::Instruction(instruction) => {
254 let new_breakpoints = vec![break_before_next_instruction; instruction.size()];
255 debug_info.breakpoints.extend(new_breakpoints);
256 break_before_next_instruction = false;
257 address += instruction.size() as u64;
258 }
259 LabelledInstruction::Label(_) => (),
260 LabelledInstruction::Breakpoint => break_before_next_instruction = true,
261 LabelledInstruction::TypeHint(hint) => match debug_info.type_hints.entry(address) {
262 Entry::Occupied(mut entry) => entry.get_mut().push(hint.clone()),
263 Entry::Vacant(entry) => entry.insert(vec![]).push(hint.clone()),
264 },
265 LabelledInstruction::AssertionContext(ctx) => {
266 let address_of_associated_assertion = address.saturating_sub(1);
267 debug_info
268 .assertion_context
269 .insert(address_of_associated_assertion, ctx.clone());
270 }
271 }
272 }
273
274 debug_info
275 }
276
277 pub fn from_code(code: &str) -> Result<Self, ParseError> {
279 parser::parse(code)
280 .map(|tokens| parser::to_labelled_instructions(&tokens))
281 .map(|instructions| Program::new(&instructions))
282 }
283
284 pub fn labelled_instructions(&self) -> Vec<LabelledInstruction> {
285 let call_targets = self.call_targets();
286 let instructions_with_labels = self.instructions.iter().map(|instruction| {
287 instruction.map_call_address(|&address| self.label_for_address(address.value()))
288 });
289
290 let mut labelled_instructions = vec![];
291 let mut address = 0;
292 let mut instruction_stream = instructions_with_labels.into_iter();
293 while let Some(instruction) = instruction_stream.next() {
294 let instruction_size = instruction.size() as u64;
295 if call_targets.contains(&address) {
296 let label = self.label_for_address(address);
297 let label = LabelledInstruction::Label(label);
298 labelled_instructions.push(label);
299 }
300 for type_hint in self.type_hints_at(address) {
301 labelled_instructions.push(LabelledInstruction::TypeHint(type_hint));
302 }
303 if self.is_breakpoint(address) {
304 labelled_instructions.push(LabelledInstruction::Breakpoint);
305 }
306 labelled_instructions.push(LabelledInstruction::Instruction(instruction));
307 if let Some(context) = self.assertion_context_at(address) {
308 labelled_instructions.push(LabelledInstruction::AssertionContext(context));
309 }
310
311 for _ in 1..instruction_size {
312 instruction_stream.next();
313 }
314 address += instruction_size;
315 }
316
317 let leftover_labels = self
318 .address_to_label
319 .iter()
320 .filter(|(&labels_address, _)| labels_address >= address)
321 .sorted();
322 for (_, label) in leftover_labels {
323 labelled_instructions.push(LabelledInstruction::Label(label.clone()));
324 }
325
326 labelled_instructions
327 }
328
329 fn call_targets(&self) -> HashSet<u64> {
330 self.instructions
331 .iter()
332 .filter_map(|instruction| match instruction {
333 Instruction::Call(address) => Some(address.value()),
334 _ => None,
335 })
336 .collect()
337 }
338
339 pub fn is_breakpoint(&self, address: u64) -> bool {
340 let address: usize = address.try_into().unwrap();
341 self.debug_information
342 .breakpoints
343 .get(address)
344 .copied()
345 .unwrap_or_default()
346 }
347
348 pub fn type_hints_at(&self, address: u64) -> Vec<TypeHint> {
349 self.debug_information
350 .type_hints
351 .get(&address)
352 .cloned()
353 .unwrap_or_default()
354 }
355
356 pub fn assertion_context_at(&self, address: u64) -> Option<AssertionContext> {
357 self.debug_information
358 .assertion_context
359 .get(&address)
360 .cloned()
361 }
362
363 pub fn to_bwords(&self) -> Vec<BFieldElement> {
369 self.clone()
370 .into_iter()
371 .flat_map(|instruction| {
372 let opcode = instruction.opcode_b();
373 if let Some(arg) = instruction.arg() {
374 vec![opcode, arg]
375 } else {
376 vec![opcode]
377 }
378 })
379 .collect()
380 }
381
382 pub fn len_bwords(&self) -> usize {
385 self.instructions.len()
386 }
387
388 pub fn is_empty(&self) -> bool {
389 self.instructions.is_empty()
390 }
391
392 pub fn hash(&self) -> Digest {
395 Tip5::hash_varlen(&self.to_bwords())
397 }
398
399 pub fn label_for_address(&self, address: u64) -> String {
401 self.address_to_label
404 .get(&address)
405 .cloned()
406 .unwrap_or_else(|| format!("address_{address}"))
407 }
408}
409
/// The reasons for which decoding a [`Program`] from a sequence of words can fail.
#[non_exhaustive]
#[derive(Debug, Clone, Eq, PartialEq, Error)]
pub enum ProgramDecodingError {
    /// The sequence contains no elements, not even the length prefix.
    #[error("sequence to decode is empty")]
    EmptySequence,

    /// Fewer words follow the length prefix than the prefix announces.
    #[error("sequence to decode is too short")]
    SequenceTooShort,

    /// More words follow the length prefix than the prefix announces.
    #[error("sequence to decode is too long")]
    SequenceTooLong,

    /// The decoded program does not have the announced length.
    #[error("length of decoded program is unexpected")]
    LengthMismatch,

    /// The word at the given index is not a valid opcode, or the argument following it
    /// is not valid for the instruction.
    #[error("sequence to decode contains invalid instruction at index {0}: {1}")]
    InvalidInstruction(usize, InstructionError),

    /// The instruction at the given index requires an argument, but the sequence ends
    /// before that argument.
    #[error("missing argument for instruction {1} at index {0}")]
    MissingArgument(usize, Instruction),
}
431
#[cfg(test)]
mod tests {
    use assert2::assert;
    use assert2::let_assert;
    use proptest::prelude::*;
    use proptest_arbitrary_interop::arb;
    use rand::Rng;
    use test_strategy::proptest;

    use crate::triton_program;

    use super::*;

    // Encoding and then decoding an arbitrary program gives back an equal program.
    #[proptest]
    fn random_program_encode_decode_equivalence(#[strategy(arb())] program: Program) {
        let encoding = program.encode();
        let decoding = *Program::decode(&encoding).unwrap();
        prop_assert_eq!(program, decoding);
    }

    // Cutting off the last word leaves the final `push` (opcode at index 6) without its
    // argument.
    #[test]
    fn decode_program_with_missing_argument_as_last_instruction() {
        let program = triton_program!(push 3 push 3 eq assert push 3);
        let program_length = program.len_bwords() as u64;
        let encoded = program.encode();

        // drop the last word and adjust the length prefix accordingly
        let mut encoded = encoded[0..encoded.len() - 1].to_vec();
        encoded[0] = bfe!(program_length - 1);

        let_assert!(Err(err) = Program::decode(&encoded));
        let_assert!(ProgramDecodingError::MissingArgument(6, _) = err);
    }

    // The length prefix announces one word more than is present.
    #[test]
    fn decode_program_with_shorter_than_indicated_sequence() {
        let program = triton_program!(nop nop hash push 0 skiz end: halt call end);
        let mut encoded = program.encode();
        encoded[0] += bfe!(1);
        let_assert!(Err(err) = Program::decode(&encoded));
        let_assert!(ProgramDecodingError::SequenceTooShort = err);
    }

    // The length prefix announces one word fewer than is present.
    #[test]
    fn decode_program_with_longer_than_indicated_sequence() {
        let program = triton_program!(nop nop hash push 0 skiz end: halt call end);
        let mut encoded = program.encode();
        encoded[0] -= bfe!(1);
        let_assert!(Err(err) = Program::decode(&encoded));
        let_assert!(ProgramDecodingError::SequenceTooLong = err);
    }

    #[test]
    fn decode_program_from_empty_sequence() {
        let encoded = vec![];
        let_assert!(Err(err) = Program::decode(&encoded));
        let_assert!(ProgramDecodingError::EmptySequence = err);
    }

    // Pins the digest of a minimal program to a fixed, known value.
    #[test]
    fn hash_simple_program() {
        let program = triton_program!(halt);
        let digest = program.hash();

        let expected_digest = bfe_array![
            0x4338_de79_520b_3949_u64,
            0xe6a2_129b_2885_0dc9_u64,
            0xfd3c_d098_6a86_0450_u64,
            0x69fd_ba91_0ceb_a7bc_u64,
            0x7e5b_118c_9594_c062_u64,
        ];
        let expected_digest = Digest::new(expected_digest);

        assert!(expected_digest == digest);
    }

    #[test]
    fn empty_program_is_empty() {
        let program = triton_program!();
        assert!(program.is_empty());
    }

    // `Program::from_code` and the `triton_program!` macro agree on the same source code,
    // which is assembled from interpolated values of various types.
    #[test]
    fn create_program_from_code() {
        let element_3 = rand::rng().random_range(0..BFieldElement::P);
        let element_2 = 1337_usize;
        let element_1 = "17";
        let element_0 = bfe!(0);
        let instruction_push = Instruction::Push(bfe!(42));
        let dup_arg = 1;
        let label = "my_label".to_string();

        let source_code = format!(
            "push {element_3} push {element_2} push {element_1} push {element_0}
 call {label} halt
 {label}:
 {instruction_push}
 dup {dup_arg}
 skiz
 recurse
 return"
        );
        let program_from_code = Program::from_code(&source_code).unwrap();
        let program_from_macro = triton_program!({ source_code });
        assert!(program_from_code == program_from_macro);
    }

    // Only checks that the macro accepts an interpolated label in first position.
    #[test]
    fn parser_macro_with_interpolated_label_as_first_argument() {
        let label = "my_label";
        let _program = triton_program!(
            {label}: push 1 assert halt
        );
    }

    #[test]
    fn breakpoints_propagate_to_debug_information_as_expected() {
        let program = triton_program! {
            break push 1 push 2
            break break break break
            pop 2 hash halt
            break // not followed by any instruction
        };

        // a breakpoint marks every address of the instruction that follows it
        assert!(program.is_breakpoint(0));
        assert!(program.is_breakpoint(1));
        assert!(!program.is_breakpoint(2));
        assert!(!program.is_breakpoint(3));
        assert!(program.is_breakpoint(4));
        assert!(program.is_breakpoint(5));
        assert!(!program.is_breakpoint(6));
        assert!(!program.is_breakpoint(7));

        // addresses beyond the end of the program are not breakpoints
        assert!(!program.is_breakpoint(8));
        assert!(!program.is_breakpoint(9));
    }

    // A decoded program has no label information; printing it must still work.
    #[test]
    fn print_program_without_any_debug_information() {
        let program = triton_program! {
            call foo
            call bar
            call baz
            halt
            foo: nop nop return
            bar: call baz return
            baz: push 1 return
        };
        let encoding = program.encode();
        let program = Program::decode(&encoding).unwrap();
        println!("{program}");
    }

    #[proptest]
    fn printed_program_can_be_parsed_again(#[strategy(arb())] program: Program) {
        parser::parse(&program.to_string())?;
    }

    // Source code consisting of a single type hint, and the type hint expected at
    // address 0 after parsing it.
    struct TypeHintTestCase {
        expected: TypeHint,
        input: &'static str,
    }

    impl TypeHintTestCase {
        // Parses the input and checks that exactly the expected type hint is recorded at
        // address 0.
        fn run(&self) {
            let program = Program::from_code(self.input).unwrap();
            let [ref type_hint] = program.type_hints_at(0)[..] else {
                panic!("Expected a single type hint at address 0");
            };
            assert!(&self.expected == type_hint);
        }
    }

    #[test]
    fn parse_simple_type_hint() {
        let expected = TypeHint {
            starting_index: 0,
            length: 1,
            type_name: Some("Type".to_string()),
            variable_name: "foo".to_string(),
        };

        TypeHintTestCase {
            expected,
            input: "hint foo: Type = stack[0]",
        }
        .run();
    }

    #[test]
    fn parse_type_hint_with_range() {
        let expected = TypeHint {
            starting_index: 0,
            length: 5,
            type_name: Some("Digest".to_string()),
            variable_name: "foo".to_string(),
        };

        TypeHintTestCase {
            expected,
            input: "hint foo: Digest = stack[0..5]",
        }
        .run();
    }

    #[test]
    fn parse_type_hint_with_range_and_offset() {
        let expected = TypeHint {
            starting_index: 7,
            length: 3,
            type_name: Some("XFieldElement".to_string()),
            variable_name: "bar".to_string(),
        };

        TypeHintTestCase {
            expected,
            input: "hint bar: XFieldElement = stack[7..10]",
        }
        .run();
    }

    #[test]
    fn parse_type_hint_with_range_and_offset_and_weird_whitespace() {
        let expected = TypeHint {
            starting_index: 2,
            length: 12,
            type_name: Some("BigType".to_string()),
            variable_name: "bar".to_string(),
        };

        TypeHintTestCase {
            expected,
            input: " hint \t \t bar :BigType=stack[ 2\t.. 14 ]\t \n",
        }
        .run();
    }

    #[test]
    fn parse_type_hint_with_no_type_only_variable_name() {
        let expected = TypeHint {
            starting_index: 0,
            length: 1,
            type_name: None,
            variable_name: "foo".to_string(),
        };

        TypeHintTestCase {
            expected,
            input: "hint foo = stack[0]",
        }
        .run();
    }

    #[test]
    fn parse_type_hint_with_no_type_only_variable_name_with_range() {
        let expected = TypeHint {
            starting_index: 2,
            length: 5,
            type_name: None,
            variable_name: "foo".to_string(),
        };

        TypeHintTestCase {
            expected,
            input: "hint foo = stack[2..7]",
        }
        .run();
    }

    #[test]
    fn assertion_context_is_propagated_into_debug_info() {
        let program = triton_program! {push 1000 assert error_id 17 halt};
        let assertion_contexts = program.debug_information.assertion_context;

        // `push 1000` occupies addresses 0 and 1, which puts the `assert` at address 2
        assert!(1 == assertion_contexts.len());
        let_assert!(AssertionContext::ID(error_id) = &assertion_contexts[&2]);
        assert!(17 == *error_id);
    }

    // Printing a program parsed from source code reproduces that source code, including
    // labels, breakpoints, type hints, and assertion contexts.
    #[test]
    fn printing_program_includes_debug_information() {
        let source_code = "\
 call foo\n\
 break\n\
 call bar\n\
 halt\n\
 foo:\n\
 break\n\
 call baz\n\
 push 1\n\
 nop\n\
 return\n\
 baz:\n\
 hash\n\
 hint my_digest: Digest = stack[0..5]\n\
 hint random_stuff = stack[17]\n\
 return\n\
 nop\n\
 pop 1\n\
 bar:\n\
 divine 1\n\
 hint got_insight: Magic = stack[0]\n\
 skiz\n\
 split\n\
 break\n\
 assert\n\
 error_id 1337\n\
 return\n\
 ";
        let program = Program::from_code(source_code).unwrap();
        let printed_program = format!("{program}");
        assert_eq!(source_code, &printed_program);
    }
}