1use std::fmt::Display;
2use std::fmt::Formatter;
3use std::str::FromStr;
4
5use itertools::Itertools;
6use rand::prelude::*;
7use triton_vm::isa::op_stack::NUM_OP_STACK_REGISTERS;
8use triton_vm::prelude::*;
9
10use crate::io::InputSource;
11use crate::memory::load_words_from_memory_leave_pointer;
12use crate::memory::load_words_from_memory_pop_pointer;
13use crate::memory::write_words_to_memory_leave_pointer;
14use crate::memory::write_words_to_memory_pop_pointer;
15use crate::pop_encodable;
16
/// A data type as used to describe values handled by snippets, e.g. in the
/// `(DataType, String)` parameter and return-value lists of `BasicSnippet`.
///
/// Every type has a stack footprint ([`DataType::stack_size`]) and, where it is fixed, an
/// encoding length in field elements ([`DataType::static_length`]).
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub enum DataType {
    Bool,
    U32,
    U64,
    U128,
    /// A 160-bit unsigned integer, encoded as 5 words (randomly generated as `[u32; 5]`).
    U160,
    /// A 192-bit unsigned integer, encoded as 6 words (randomly generated as `[u32; 6]`).
    U192,
    I128,
    /// A base-field element (`BFieldElement`).
    Bfe,
    /// An extension-field element (`XFieldElement`).
    Xfe,
    Digest,
    /// A list whose elements all have the given type. Occupies one stack word.
    List(Box<DataType>),
    /// A fixed-length array, see [`ArrayType`]. Occupies one stack word.
    Array(Box<ArrayType>),
    /// A tuple; its stack size is the sum of its elements' stack sizes.
    Tuple(Vec<DataType>),
    /// An untyped pointer. Occupies one stack word; has no static encoding length.
    VoidPointer,
    /// A reference to a struct, see [`StructType`]. Occupies one stack word.
    StructRef(StructType),
}
38
/// The element type and fixed element count of a [`DataType::Array`].
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct ArrayType {
    /// Type of every element in the array.
    pub element_type: DataType,
    /// Number of elements in the array.
    pub length: usize,
}
44
/// A named struct type, referenced through [`DataType::StructRef`].
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct StructType {
    /// The struct's name; this is also its `Display` representation.
    pub name: String,
    /// The struct's fields as ordered `(field_name, field_type)` pairs.
    pub fields: Vec<(String, DataType)>,
}
50
51impl Display for StructType {
52 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
53 write!(f, "{}", self.name)
54 }
55}
56
57impl DataType {
58 pub(crate) fn static_length(&self) -> Option<usize> {
60 match self {
61 Self::Bool => bool::static_length(),
62 Self::U32 => u32::static_length(),
63 Self::U64 => u64::static_length(),
64 Self::U128 => u128::static_length(),
65 Self::U160 => Some(5),
66 Self::U192 => Some(6),
67 Self::I128 => i128::static_length(),
68 Self::Bfe => BFieldElement::static_length(),
69 Self::Xfe => XFieldElement::static_length(),
70 Self::Digest => Digest::static_length(),
71 Self::List(_) => None,
72 Self::Array(a) => Some(a.length * a.element_type.static_length()?),
73 Self::Tuple(t) => t.iter().map(|dt| dt.static_length()).sum(),
74 Self::VoidPointer => None,
75 Self::StructRef(s) => s.fields.iter().map(|(_, dt)| dt.static_length()).sum(),
76 }
77 }
78
79 pub fn label_friendly_name(&self) -> String {
81 match self {
82 Self::List(inner_type) => format!("list_L{}R", inner_type.label_friendly_name()),
83 Self::Tuple(inner_types) => {
84 format!(
85 "tuple_L{}R",
86 inner_types
87 .iter()
88 .map(|x| x.label_friendly_name())
89 .join("___")
90 )
91 }
92 Self::VoidPointer => "void_pointer".to_string(),
93 Self::Bool => "bool".to_string(),
94 Self::U32 => "u32".to_string(),
95 Self::U64 => "u64".to_string(),
96 Self::U128 => "u128".to_string(),
97 Self::U160 => "u160".to_string(),
98 Self::U192 => "u192".to_string(),
99 Self::I128 => "i128".to_string(),
100 Self::Bfe => "bfe".to_string(),
101 Self::Xfe => "xfe".to_string(),
102 Self::Digest => "digest".to_string(),
103 Self::Array(array_type) => format!(
104 "array{}___{}",
105 array_type.length,
106 array_type.element_type.label_friendly_name()
107 ),
108 Self::StructRef(struct_type) => format!("{struct_type}"),
109 }
110 }
111
112 pub fn stack_size(&self) -> usize {
114 match self {
115 Self::Bool
116 | Self::U32
117 | Self::U64
118 | Self::U128
119 | Self::U160
120 | Self::U192
121 | Self::I128
122 | Self::Bfe
123 | Self::Xfe
124 | Self::Digest => self.static_length().unwrap(),
125 Self::List(_) => 1,
126 Self::Array(_) => 1,
127 Self::Tuple(t) => t.iter().map(|dt| dt.stack_size()).sum(),
128 Self::VoidPointer => 1,
129 Self::StructRef(_) => 1,
130 }
131 }
132
133 pub fn read_value_from_memory_leave_pointer(&self) -> Vec<LabelledInstruction> {
141 load_words_from_memory_leave_pointer(self.stack_size())
142 }
143
144 pub fn read_value_from_memory_pop_pointer(&self) -> Vec<LabelledInstruction> {
152 load_words_from_memory_pop_pointer(self.stack_size())
153 }
154
155 pub fn write_value_to_memory_leave_pointer(&self) -> Vec<LabelledInstruction> {
162 write_words_to_memory_leave_pointer(self.stack_size())
163 }
164
165 pub fn write_value_to_memory_pop_pointer(&self) -> Vec<LabelledInstruction> {
172 write_words_to_memory_pop_pointer(self.stack_size())
173 }
174
175 pub fn read_value_from_input(&self, input_source: InputSource) -> Vec<LabelledInstruction> {
182 input_source.read_words(self.stack_size())
183 }
184
185 pub fn write_value_to_stdout(&self) -> Vec<LabelledInstruction> {
192 crate::io::write_words(self.stack_size())
193 }
194
195 pub fn compare_elem_of_stack_size(stack_size: usize) -> Vec<LabelledInstruction> {
202 if stack_size == 0 {
203 return triton_asm!(push 1);
204 } else if stack_size == 1 {
205 return triton_asm!(eq);
206 }
207
208 assert!(stack_size + 1 < NUM_OP_STACK_REGISTERS);
209 let first_cmps = vec![triton_asm!(swap {stack_size + 1} eq); stack_size - 1].concat();
210 let last_cmp = triton_asm!(swap 2 eq);
211 let boolean_ands = triton_asm![mul; stack_size - 1];
212
213 [first_cmps, last_cmp, boolean_ands].concat()
214 }
215
216 pub fn compare(&self) -> Vec<LabelledInstruction> {
223 Self::compare_elem_of_stack_size(self.stack_size())
224 }
225
226 pub fn variant_name(&self) -> String {
228 match self {
230 Self::Bool => "DataType::Bool".to_owned(),
231 Self::U32 => "DataType::U32".to_owned(),
232 Self::U64 => "DataType::U64".to_owned(),
233 Self::U128 => "DataType::U128".to_owned(),
234 Self::U160 => "DataType::U160".to_owned(),
235 Self::U192 => "DataType::U192".to_owned(),
236 Self::I128 => "DataType::I128".to_owned(),
237 Self::Bfe => "DataType::BFE".to_owned(),
238 Self::Xfe => "DataType::XFE".to_owned(),
239 Self::Digest => "DataType::Digest".to_owned(),
240 Self::List(elem_type) => {
241 format!("DataType::List(Box::new({}))", elem_type.variant_name())
242 }
243 Self::VoidPointer => "DataType::VoidPointer".to_owned(),
244 Self::Tuple(elements) => {
245 let elements_as_variant_names =
246 elements.iter().map(|x| x.variant_name()).collect_vec();
247 format!(
248 "DataType::Tuple(vec![{}])",
249 elements_as_variant_names.join(", ")
250 )
251 }
252 Self::Array(array_type) => format!(
253 "[{}; {}]",
254 array_type.element_type.variant_name(),
255 array_type.length
256 ),
257 Self::StructRef(struct_type) => format!("Box<{struct_type}>"),
258 }
259 }
260
261 #[cfg(test)]
263 pub fn big_random_generatable_type_collection() -> Vec<Self> {
264 vec![
265 Self::Bool,
266 Self::U32,
267 Self::U64,
268 Self::U128,
269 Self::U160,
270 Self::U192,
271 Self::Bfe,
272 Self::Xfe,
273 Self::Digest,
274 Self::VoidPointer,
275 Self::Tuple(vec![Self::Bool]),
276 Self::Tuple(vec![Self::Xfe, Self::Bool]),
277 Self::Tuple(vec![Self::Xfe, Self::Digest]),
278 Self::Tuple(vec![Self::Bool, Self::Bool]),
279 Self::Tuple(vec![Self::Digest, Self::Xfe]),
280 Self::Tuple(vec![Self::Bfe, Self::Xfe, Self::Digest]),
281 Self::Tuple(vec![Self::Xfe, Self::Bfe, Self::Digest]),
282 Self::Tuple(vec![Self::U64, Self::Digest, Self::Digest, Self::Digest]),
283 Self::Tuple(vec![Self::Digest, Self::Digest, Self::Digest, Self::U64]),
284 Self::Tuple(vec![Self::Digest, Self::Xfe, Self::U128, Self::Bool]),
285 ]
286 }
287
288 pub fn random_elements(&self, count: usize) -> Vec<Vec<BFieldElement>> {
289 (0..count)
290 .map(|_| self.seeded_random_element(&mut rand::rng()))
291 .collect()
292 }
293
294 pub fn seeded_random_element(&self, rng: &mut impl Rng) -> Vec<BFieldElement> {
295 match self {
296 Self::Bool => rng.random::<bool>().encode(),
297 Self::U32 => rng.random::<u32>().encode(),
298 Self::U64 => rng.random::<u64>().encode(),
299 Self::U128 => rng.random::<u128>().encode(),
300 Self::U160 => rng.random::<[u32; 5]>().encode(),
301 Self::U192 => rng.random::<[u32; 6]>().encode(),
302 Self::I128 => rng.random::<[u32; 4]>().encode(),
303 Self::Bfe => rng.random::<BFieldElement>().encode(),
304 Self::Xfe => rng.random::<XFieldElement>().encode(),
305 Self::Digest => rng.random::<Digest>().encode(),
306 Self::List(e) => {
307 let len = rng.random_range(0..20);
308 e.random_list(rng, len)
309 }
310 Self::Array(a) => Self::random_array(rng, a),
311 Self::Tuple(tys) => tys
312 .iter()
313 .flat_map(|ty| ty.seeded_random_element(rng))
314 .collect(),
315 Self::VoidPointer => vec![rng.random()],
316 Self::StructRef(_) => panic!("Random generation of structs is not supported"),
317 }
318 }
319
320 pub(crate) fn random_list(&self, rng: &mut impl Rng, len: usize) -> Vec<BFieldElement> {
322 let maybe_prepend_elem_len = |elem: Vec<_>| {
323 if self.static_length().is_some() {
324 elem
325 } else {
326 [bfe_vec![elem.len() as u64], elem].concat()
327 }
328 };
329
330 let elements = (0..len)
331 .map(|_| self.seeded_random_element(rng))
332 .flat_map(maybe_prepend_elem_len)
333 .collect();
334
335 [bfe_vec![len as u64], elements].concat()
336 }
337
338 pub(crate) fn random_array(rng: &mut impl Rng, array_ty: &ArrayType) -> Vec<BFieldElement> {
339 (0..array_ty.length)
340 .flat_map(|_| array_ty.element_type.seeded_random_element(rng))
341 .collect()
342 }
343}
344
345impl FromStr for DataType {
346 type Err = anyhow::Error;
347
348 fn from_str(s: &str) -> Result<Self, Self::Err> {
350 let res = if s.starts_with("list_L") && s.ends_with('R') {
351 let inner = &s[6..s.len() - 1];
352 let inner = FromStr::from_str(inner)?;
353 Self::List(Box::new(inner))
354 } else if s.starts_with("tuple_L") && s.ends_with('R') {
355 let inner = &s[7..s.len() - 1];
356 let inners = inner.split("___");
357 let mut inners_resolved: Vec<Self> = vec![];
358 for inner_elem in inners {
359 inners_resolved.push(FromStr::from_str(inner_elem)?);
360 }
361
362 Self::Tuple(inners_resolved)
363 } else {
364 match s {
365 "void_pointer" => Self::VoidPointer,
366 "bool" => Self::Bool,
367 "u32" => Self::U32,
368 "u64" => Self::U64,
369 "u128" => Self::U128,
370 "u160" => Self::U160,
371 "u192" => Self::U192,
372 "i128" => Self::I128,
373 "bfe" => Self::Bfe,
374 "xfe" => Self::Xfe,
375 "digest" => Self::Digest,
376 _ => anyhow::bail!("Could not parse {s} as a data type"),
377 }
378 };
379
380 Ok(res)
381 }
382}
383
/// A constant value of one of the primitive data types; see [`Literal::data_type`] for the
/// corresponding [`DataType`] of each variant.
///
/// There are no literals for `U160`, `U192`, or any compound type.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, arbitrary::Arbitrary)]
pub enum Literal {
    Bool(bool),
    U32(u32),
    U64(u64),
    U128(u128),
    I128(i128),
    /// A base-field element literal.
    Bfe(BFieldElement),
    /// An extension-field element literal.
    Xfe(XFieldElement),
    Digest(Digest),
}
395
396impl Literal {
397 pub fn data_type(&self) -> DataType {
398 match self {
399 Self::Bool(_) => DataType::Bool,
400 Self::U32(_) => DataType::U32,
401 Self::U64(_) => DataType::U64,
402 Self::U128(_) => DataType::U128,
403 Self::I128(_) => DataType::I128,
404 Self::Bfe(_) => DataType::Bfe,
405 Self::Xfe(_) => DataType::Xfe,
406 Self::Digest(_) => DataType::Digest,
407 }
408 }
409
410 pub fn as_xfe(&self) -> XFieldElement {
414 match self {
415 Self::Xfe(xfe) => *xfe,
416 _ => panic!("Expected XFE, got {self:?}"),
417 }
418 }
419
420 pub fn push_to_stack_code(&self) -> Vec<LabelledInstruction> {
422 let encoding = match self {
423 Literal::Bool(x) => x.encode(),
424 Literal::U32(x) => x.encode(),
425 Literal::U64(x) => x.encode(),
426 Literal::U128(x) => x.encode(),
427 Literal::I128(x) => x.encode(),
428 Literal::Bfe(x) => x.encode(),
429 Literal::Xfe(x) => x.encode(),
430 Literal::Digest(x) => x.encode(),
431 };
432
433 encoding
434 .into_iter()
435 .rev()
436 .flat_map(|b| triton_asm!(push { b }))
437 .collect()
438 }
439
440 pub fn pop_from_stack(data_type: DataType, stack: &mut Vec<BFieldElement>) -> Self {
446 match data_type {
447 DataType::Bool => Self::Bool(pop_encodable(stack)),
448 DataType::U32 => Self::U32(pop_encodable(stack)),
449 DataType::U64 => Self::U64(pop_encodable(stack)),
450 DataType::U128 => Self::U128(pop_encodable(stack)),
451 DataType::I128 => Self::I128(pop_encodable(stack)),
452 DataType::Bfe => Self::Bfe(pop_encodable(stack)),
453 DataType::Xfe => Self::Xfe(pop_encodable(stack)),
454 DataType::Digest => Self::Digest(pop_encodable(stack)),
455 DataType::List(_)
456 | DataType::Array(_)
457 | DataType::Tuple(_)
458 | DataType::VoidPointer
459 | DataType::StructRef(_)
460 | DataType::U160
461 | DataType::U192 => unimplemented!(),
462 }
463 }
464}
465
// Unit and property tests for `DataType` and `Literal`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_prelude::*;

    impl Literal {
        // Test-only accessor: unwraps an `I128` literal, panicking on any other variant.
        fn as_i128(&self) -> i128 {
            let Self::I128(val) = self else {
                panic!("Expected i128, got: {self:?}");
            };

            *val
        }
    }

    // Running a literal's push code in the VM and popping with the literal's own data type
    // must give back the same literal.
    #[proptest]
    fn push_to_stack_leaves_value_on_top_of_stack(#[strategy(arb())] literal: Literal) {
        let code = triton_asm!(
            {&literal.push_to_stack_code()}
            halt
        );

        let mut vm_state = VMState::new(
            Program::new(&code),
            PublicInput::default(),
            NonDeterminism::default(),
        );
        vm_state.run().unwrap();
        let read = Literal::pop_from_stack(literal.data_type(), &mut vm_state.op_stack.stack);
        assert_eq!(literal, read);
    }

    // The following tests check `DataType::static_length` against `BFieldCodec` for the
    // corresponding Rust types.
    #[test]
    fn static_lengths_match_up_for_vectors() {
        assert_eq!(
            <Vec<XFieldElement>>::static_length(),
            DataType::List(Box::new(DataType::Xfe)).static_length()
        );
    }

    #[test]
    fn static_lengths_match_up_for_array_with_static_length_data_type() {
        assert_eq!(
            <[BFieldElement; 42]>::static_length(),
            DataType::Array(Box::new(ArrayType {
                element_type: DataType::Bfe,
                length: 42
            }))
            .static_length()
        );
    }

    #[test]
    fn static_lengths_match_up_for_array_with_dynamic_length_data_type() {
        assert_eq!(
            <[Vec<BFieldElement>; 42]>::static_length(),
            DataType::Array(Box::new(ArrayType {
                element_type: DataType::List(Box::new(DataType::Bfe)),
                length: 42
            }))
            .static_length()
        );
    }

    #[test]
    fn static_lengths_match_up_for_tuple_with_only_static_length_types() {
        assert_eq!(
            <(XFieldElement, BFieldElement)>::static_length(),
            DataType::Tuple(vec![DataType::Xfe, DataType::Bfe]).static_length()
        );
    }

    #[test]
    fn static_lengths_match_up_for_tuple_with_dynamic_length_types() {
        assert_eq!(
            <(XFieldElement, Vec<BFieldElement>)>::static_length(),
            DataType::Tuple(vec![DataType::Xfe, DataType::List(Box::new(DataType::Bfe))])
                .static_length()
        );
    }

    #[test]
    fn static_length_of_void_pointer_is_unknown() {
        assert!(DataType::VoidPointer.static_length().is_none());
    }

    // A struct with only statically-sized fields must report the same static length as the
    // derived `BFieldCodec` impl of an equivalent Rust struct.
    #[test]
    fn static_lengths_match_up_for_struct_with_only_static_length_types() {
        #[derive(Debug, Clone, BFieldCodec)]
        struct StructTyStatic {
            u32: u32,
            u64: u64,
        }

        let struct_ty_static = StructType {
            name: "struct".to_owned(),
            fields: vec![
                ("u32".to_owned(), DataType::U32),
                ("u64".to_owned(), DataType::U64),
            ],
        };
        assert_eq!(
            StructTyStatic::static_length(),
            DataType::StructRef(struct_ty_static).static_length()
        );
    }

    // Same as above, but with a dynamically-sized field.
    #[test]
    fn static_lengths_match_up_for_struct_with_dynamic_length_types() {
        #[derive(Debug, Clone, BFieldCodec)]
        struct StructTyDyn {
            digest: Digest,
            list: Vec<BFieldElement>,
        }

        let struct_ty_dyn = StructType {
            name: "struct".to_owned(),
            fields: vec![
                ("digest".to_owned(), DataType::Digest),
                ("list".to_owned(), DataType::List(Box::new(DataType::Bfe))),
            ],
        };
        assert_eq!(
            StructTyDyn::static_length(),
            DataType::StructRef(struct_ty_dyn).static_length()
        );
    }

    // Smoke test: generation must not panic; the result is not inspected.
    #[test]
    fn random_list_of_lists_can_be_generated() {
        let mut rng = StdRng::seed_from_u64(5950175350772851878);
        let element_type = DataType::List(Box::new(DataType::Digest));
        let _list = element_type.random_list(&mut rng, 10);
    }

    #[test]
    fn i128_sizes() {
        assert_eq!(4, DataType::I128.stack_size());
        assert_eq!(Some(4), DataType::I128.static_length());
    }

    // Non-negative values must produce identical push code as `I128` and as `U128`.
    #[proptest]
    fn non_negative_i128s_encode_like_u128s_prop(
        #[strategy(arb())]
        #[filter(#as_i128 >= 0i128)]
        as_i128: i128,
    ) {
        let as_u128: u128 = as_i128.try_into().unwrap();
        assert_eq!(
            Literal::U128(as_u128).push_to_stack_code(),
            Literal::I128(as_i128).push_to_stack_code()
        );
    }

    // Round trip of arbitrary `i128` literals through the VM's op stack.
    #[proptest]
    fn i128_literals_prop(val: i128) {
        let program = Literal::I128(val).push_to_stack_code();
        let program = triton_program!(
            {&program}
            halt
        );

        let mut vm_state = VMState::new(program, [].into(), [].into());
        vm_state.run().unwrap();
        let mut final_stack = vm_state.op_stack.stack;
        let popped = Literal::pop_from_stack(DataType::I128, &mut final_stack).as_i128();
        assert_eq!(val, popped);
    }

    // The following tests check that randomly generated encodings decode via `BFieldCodec`.
    #[proptest]
    fn random_list_conforms_to_bfield_codec(#[strategy(..255_usize)] len: usize, seed: u64) {
        let mut rng = StdRng::seed_from_u64(seed);
        let element_type = DataType::Digest;
        let list = element_type.random_list(&mut rng, len);
        prop_assert!(<Vec<Digest>>::decode(&list).is_ok());
    }

    #[proptest]
    fn random_list_of_lists_conforms_to_bfield_codec(
        #[strategy(..255_usize)] len: usize,
        seed: u64,
    ) {
        let mut rng = StdRng::seed_from_u64(seed);
        let element_type = DataType::List(Box::new(DataType::Digest));
        let list = element_type.random_list(&mut rng, len);
        prop_assert!(<Vec<Vec<Digest>>>::decode(&list).is_ok());
    }

    #[proptest]
    fn random_array_conforms_to_bfield_codec(seed: u64) {
        const LEN: usize = 42;

        let mut rng = StdRng::seed_from_u64(seed);
        let array_type = ArrayType {
            element_type: DataType::Digest,
            length: LEN,
        };
        let array = DataType::random_array(&mut rng, &array_type);
        prop_assert!(<[Digest; LEN]>::decode(&array).is_ok());
    }

    #[proptest]
    fn random_array_of_arrays_conforms_to_bfield_codec(seed: u64) {
        const INNER_LEN: usize = 42;
        const OUTER_LEN: usize = 13;

        let mut rng = StdRng::seed_from_u64(seed);
        let inner_type = ArrayType {
            element_type: DataType::Digest,
            length: INNER_LEN,
        };
        let outer_type = ArrayType {
            element_type: DataType::Array(Box::new(inner_type)),
            length: OUTER_LEN,
        };
        let array = DataType::random_array(&mut rng, &outer_type);
        prop_assert!(<[[Digest; INNER_LEN]; OUTER_LEN]>::decode(&array).is_ok());
    }
}
685
// Tests `DataType::compare` by wrapping it in snippets that are checked against a Rust
// shadow via `ShadowedClosure`.
#[cfg(test)]
mod compare_literals {
    use super::*;
    use crate::prelude::*;
    use crate::test_prelude::*;

    // Defines a unit struct `$name` that implements `BasicSnippet` and `Closure`. The snippet
    // takes two `DataType::$tasm_ty` arguments ("left", "right") and returns a `Bool` computed
    // by `DataType::$tasm_ty.compare()`. The Rust shadow pops two `$rust_ty` values and pushes
    // whether they are equal.
    macro_rules! comparison_snippet {
        ($name:ident for tasm_ty $tasm_ty:ident and rust_ty $rust_ty:ident) => {
            #[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
            struct $name;

            impl BasicSnippet for $name {
                fn parameters(&self) -> Vec<(DataType, String)> {
                    ["left", "right"]
                        .map(|s| (DataType::$tasm_ty, s.to_string()))
                        .to_vec()
                }

                fn return_values(&self) -> Vec<(DataType, String)> {
                    vec![(DataType::Bool, "are_eq".to_string())]
                }

                fn entrypoint(&self) -> String {
                    let ty = stringify!($tasm_ty);
                    format!("tasmlib_test_compare_{ty}")
                }

                fn code(&self, _: &mut Library) -> Vec<LabelledInstruction> {
                    triton_asm!({self.entrypoint()}: {&DataType::$tasm_ty.compare()} return)
                }
            }

            impl Closure for $name {
                type Args = ($rust_ty, $rust_ty);

                fn rust_shadow(&self, stack: &mut Vec<BFieldElement>) {
                    // Equality is symmetric, so the order of the popped pair does not matter.
                    let (right, left) = pop_encodable::<Self::Args>(stack);
                    push_encodable(stack, &(left == right));
                }

                // Two independently random values, most likely unequal.
                fn pseudorandom_args(
                    &self,
                    seed: [u8; 32],
                    _: Option<BenchmarkCase>
                ) -> Self::Args {
                    StdRng::from_seed(seed).random()
                }

                // Two default (hence equal) values, covering the "equal" outcome.
                fn corner_case_args(&self) -> Vec<Self::Args> {
                    vec![Self::Args::default()]
                }
            }
        };
    }

    // Single-word comparison (the `eq`-only path of `compare_elem_of_stack_size`).
    comparison_snippet!(CompareBfes for tasm_ty Bfe and rust_ty BFieldElement);

    // Multi-word comparison (the `swap`/`eq`/`mul` path).
    comparison_snippet!(CompareDigests for tasm_ty Digest and rust_ty Digest);

    #[test]
    fn test() {
        ShadowedClosure::new(CompareBfes).test();
        ShadowedClosure::new(CompareDigests).test();
    }

    #[test]
    fn bench() {
        ShadowedClosure::new(CompareBfes).bench();
        ShadowedClosure::new(CompareDigests).bench();
    }
}