1use std::fmt::Display;
2use std::fmt::Formatter;
3use std::str::FromStr;
4
5use itertools::Itertools;
6use rand::prelude::*;
7use triton_vm::isa::op_stack::NUM_OP_STACK_REGISTERS;
8use triton_vm::prelude::*;
9
10use crate::io::InputSource;
11use crate::memory::load_words_from_memory_leave_pointer;
12use crate::memory::load_words_from_memory_pop_pointer;
13use crate::memory::write_words_to_memory_leave_pointer;
14use crate::memory::write_words_to_memory_pop_pointer;
15use crate::pop_encodable;
16
17#[derive(Debug, Clone, Eq, PartialEq, Hash)]
18pub enum DataType {
19 Bool,
20 U32,
21 U64,
22 U128,
23 I128,
24 Bfe,
25 Xfe,
26 Digest,
27 List(Box<DataType>),
28 Array(Box<ArrayType>),
29 Tuple(Vec<DataType>),
30 VoidPointer,
31 StructRef(StructType),
32}
33
34#[derive(Debug, Clone, Hash, PartialEq, Eq)]
35pub struct ArrayType {
36 pub element_type: DataType,
37 pub length: usize,
38}
39
40#[derive(Debug, Clone, Hash, PartialEq, Eq)]
41pub struct StructType {
42 pub name: String,
43 pub fields: Vec<(String, DataType)>,
44}
45
46impl Display for StructType {
47 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
48 write!(f, "{}", self.name)
49 }
50}
51
52impl DataType {
53 pub(crate) fn static_length(&self) -> Option<usize> {
55 match self {
56 Self::Bool => bool::static_length(),
57 Self::U32 => u32::static_length(),
58 Self::U64 => u64::static_length(),
59 Self::U128 => u128::static_length(),
60 Self::I128 => i128::static_length(),
61 Self::Bfe => BFieldElement::static_length(),
62 Self::Xfe => XFieldElement::static_length(),
63 Self::Digest => Digest::static_length(),
64 Self::List(_) => None,
65 Self::Array(a) => Some(a.length * a.element_type.static_length()?),
66 Self::Tuple(t) => t.iter().map(|dt| dt.static_length()).sum(),
67 Self::VoidPointer => None,
68 Self::StructRef(s) => s.fields.iter().map(|(_, dt)| dt.static_length()).sum(),
69 }
70 }
71
72 pub fn label_friendly_name(&self) -> String {
74 match self {
75 Self::List(inner_type) => format!("list_L{}R", inner_type.label_friendly_name()),
76 Self::Tuple(inner_types) => {
77 format!(
78 "tuple_L{}R",
79 inner_types
80 .iter()
81 .map(|x| x.label_friendly_name())
82 .join("___")
83 )
84 }
85 Self::VoidPointer => "void_pointer".to_string(),
86 Self::Bool => "bool".to_string(),
87 Self::U32 => "u32".to_string(),
88 Self::U64 => "u64".to_string(),
89 Self::U128 => "u128".to_string(),
90 Self::I128 => "i128".to_string(),
91 Self::Bfe => "bfe".to_string(),
92 Self::Xfe => "xfe".to_string(),
93 Self::Digest => "digest".to_string(),
94 Self::Array(array_type) => format!(
95 "array{}___{}",
96 array_type.length,
97 array_type.element_type.label_friendly_name()
98 ),
99 Self::StructRef(struct_type) => format!("{struct_type}"),
100 }
101 }
102
103 pub fn stack_size(&self) -> usize {
105 match self {
106 Self::Bool
107 | Self::U32
108 | Self::U64
109 | Self::U128
110 | Self::I128
111 | Self::Bfe
112 | Self::Xfe
113 | Self::Digest => self.static_length().unwrap(),
114 Self::List(_) => 1,
115 Self::Array(_) => 1,
116 Self::Tuple(t) => t.iter().map(|dt| dt.stack_size()).sum(),
117 Self::VoidPointer => 1,
118 Self::StructRef(_) => 1,
119 }
120 }
121
122 pub fn read_value_from_memory_leave_pointer(&self) -> Vec<LabelledInstruction> {
130 load_words_from_memory_leave_pointer(self.stack_size())
131 }
132
133 pub fn read_value_from_memory_pop_pointer(&self) -> Vec<LabelledInstruction> {
141 load_words_from_memory_pop_pointer(self.stack_size())
142 }
143
144 pub fn write_value_to_memory_leave_pointer(&self) -> Vec<LabelledInstruction> {
151 write_words_to_memory_leave_pointer(self.stack_size())
152 }
153
154 pub fn write_value_to_memory_pop_pointer(&self) -> Vec<LabelledInstruction> {
161 write_words_to_memory_pop_pointer(self.stack_size())
162 }
163
164 pub fn read_value_from_input(&self, input_source: InputSource) -> Vec<LabelledInstruction> {
171 input_source.read_words(self.stack_size())
172 }
173
174 pub fn write_value_to_stdout(&self) -> Vec<LabelledInstruction> {
181 crate::io::write_words(self.stack_size())
182 }
183
184 pub fn compare_elem_of_stack_size(stack_size: usize) -> Vec<LabelledInstruction> {
191 if stack_size == 0 {
192 return triton_asm!(push 1);
193 } else if stack_size == 1 {
194 return triton_asm!(eq);
195 }
196
197 assert!(stack_size + 1 < NUM_OP_STACK_REGISTERS);
198 let first_cmps = vec![triton_asm!(swap {stack_size + 1} eq); stack_size - 1].concat();
199 let last_cmp = triton_asm!(swap 2 eq);
200 let boolean_ands = triton_asm![mul; stack_size - 1];
201
202 [first_cmps, last_cmp, boolean_ands].concat()
203 }
204
205 pub fn compare(&self) -> Vec<LabelledInstruction> {
212 Self::compare_elem_of_stack_size(self.stack_size())
213 }
214
215 pub fn variant_name(&self) -> String {
217 match self {
219 Self::Bool => "DataType::Bool".to_owned(),
220 Self::U32 => "DataType::U32".to_owned(),
221 Self::U64 => "DataType::U64".to_owned(),
222 Self::U128 => "DataType::U128".to_owned(),
223 Self::I128 => "DataType::I128".to_owned(),
224 Self::Bfe => "DataType::BFE".to_owned(),
225 Self::Xfe => "DataType::XFE".to_owned(),
226 Self::Digest => "DataType::Digest".to_owned(),
227 Self::List(elem_type) => {
228 format!("DataType::List(Box::new({}))", elem_type.variant_name())
229 }
230 Self::VoidPointer => "DataType::VoidPointer".to_owned(),
231 Self::Tuple(elements) => {
232 let elements_as_variant_names =
233 elements.iter().map(|x| x.variant_name()).collect_vec();
234 format!(
235 "DataType::Tuple(vec![{}])",
236 elements_as_variant_names.join(", ")
237 )
238 }
239 Self::Array(array_type) => format!(
240 "[{}; {}]",
241 array_type.element_type.variant_name(),
242 array_type.length
243 ),
244 Self::StructRef(struct_type) => format!("Box<{struct_type}>"),
245 }
246 }
247
248 #[cfg(test)]
250 pub fn big_random_generatable_type_collection() -> Vec<Self> {
251 vec![
252 Self::Bool,
253 Self::U32,
254 Self::U64,
255 Self::U128,
256 Self::Bfe,
257 Self::Xfe,
258 Self::Digest,
259 Self::VoidPointer,
260 Self::Tuple(vec![Self::Bool]),
261 Self::Tuple(vec![Self::Xfe, Self::Bool]),
262 Self::Tuple(vec![Self::Xfe, Self::Digest]),
263 Self::Tuple(vec![Self::Bool, Self::Bool]),
264 Self::Tuple(vec![Self::Digest, Self::Xfe]),
265 Self::Tuple(vec![Self::Bfe, Self::Xfe, Self::Digest]),
266 Self::Tuple(vec![Self::Xfe, Self::Bfe, Self::Digest]),
267 Self::Tuple(vec![Self::U64, Self::Digest, Self::Digest, Self::Digest]),
268 Self::Tuple(vec![Self::Digest, Self::Digest, Self::Digest, Self::U64]),
269 Self::Tuple(vec![Self::Digest, Self::Xfe, Self::U128, Self::Bool]),
270 ]
271 }
272
273 pub fn random_elements(&self, count: usize) -> Vec<Vec<BFieldElement>> {
274 (0..count)
275 .map(|_| self.seeded_random_element(&mut rand::rng()))
276 .collect()
277 }
278
279 pub fn seeded_random_element(&self, rng: &mut impl Rng) -> Vec<BFieldElement> {
280 match self {
281 Self::Bool => rng.random::<bool>().encode(),
282 Self::U32 => rng.random::<u32>().encode(),
283 Self::U64 => rng.random::<u64>().encode(),
284 Self::U128 => rng.random::<u128>().encode(),
285 Self::I128 => rng.random::<[u32; 4]>().encode(),
286 Self::Bfe => rng.random::<BFieldElement>().encode(),
287 Self::Xfe => rng.random::<XFieldElement>().encode(),
288 Self::Digest => rng.random::<Digest>().encode(),
289 Self::List(e) => {
290 let len = rng.random_range(0..20);
291 e.random_list(rng, len)
292 }
293 Self::Array(a) => Self::random_array(rng, a),
294 Self::Tuple(tys) => tys
295 .iter()
296 .flat_map(|ty| ty.seeded_random_element(rng))
297 .collect(),
298 Self::VoidPointer => vec![rng.random()],
299 Self::StructRef(_) => panic!("Random generation of structs is not supported"),
300 }
301 }
302
303 pub(crate) fn random_list(&self, rng: &mut impl Rng, len: usize) -> Vec<BFieldElement> {
305 let maybe_prepend_elem_len = |elem: Vec<_>| {
306 if self.static_length().is_some() {
307 elem
308 } else {
309 [bfe_vec![elem.len() as u64], elem].concat()
310 }
311 };
312
313 let elements = (0..len)
314 .map(|_| self.seeded_random_element(rng))
315 .flat_map(maybe_prepend_elem_len)
316 .collect();
317
318 [bfe_vec![len as u64], elements].concat()
319 }
320
321 pub(crate) fn random_array(rng: &mut impl Rng, array_ty: &ArrayType) -> Vec<BFieldElement> {
322 (0..array_ty.length)
323 .flat_map(|_| array_ty.element_type.seeded_random_element(rng))
324 .collect()
325 }
326}
327
328impl FromStr for DataType {
329 type Err = anyhow::Error;
330
331 fn from_str(s: &str) -> Result<Self, Self::Err> {
333 let res = if s.starts_with("list_L") && s.ends_with('R') {
334 let inner = &s[6..s.len() - 1];
335 let inner = FromStr::from_str(inner)?;
336 Self::List(Box::new(inner))
337 } else if s.starts_with("tuple_L") && s.ends_with('R') {
338 let inner = &s[7..s.len() - 1];
339 let inners = inner.split("___");
340 let mut inners_resolved: Vec<Self> = vec![];
341 for inner_elem in inners {
342 inners_resolved.push(FromStr::from_str(inner_elem)?);
343 }
344
345 Self::Tuple(inners_resolved)
346 } else {
347 match s {
348 "void_pointer" => Self::VoidPointer,
349 "bool" => Self::Bool,
350 "u32" => Self::U32,
351 "u64" => Self::U64,
352 "u128" => Self::U128,
353 "i128" => Self::I128,
354 "bfe" => Self::Bfe,
355 "xfe" => Self::Xfe,
356 "digest" => Self::Digest,
357 _ => anyhow::bail!("Could not parse {s} as a data type"),
358 }
359 };
360
361 Ok(res)
362 }
363}
364
365#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, arbitrary::Arbitrary)]
366pub enum Literal {
367 Bool(bool),
368 U32(u32),
369 U64(u64),
370 U128(u128),
371 I128(i128),
372 Bfe(BFieldElement),
373 Xfe(XFieldElement),
374 Digest(Digest),
375}
376
377impl Literal {
378 pub fn data_type(&self) -> DataType {
379 match self {
380 Self::Bool(_) => DataType::Bool,
381 Self::U32(_) => DataType::U32,
382 Self::U64(_) => DataType::U64,
383 Self::U128(_) => DataType::U128,
384 Self::I128(_) => DataType::I128,
385 Self::Bfe(_) => DataType::Bfe,
386 Self::Xfe(_) => DataType::Xfe,
387 Self::Digest(_) => DataType::Digest,
388 }
389 }
390
391 pub fn as_xfe(&self) -> XFieldElement {
395 match self {
396 Self::Xfe(xfe) => *xfe,
397 _ => panic!("Expected XFE, got {self:?}"),
398 }
399 }
400
401 pub fn push_to_stack_code(&self) -> Vec<LabelledInstruction> {
403 let encoding = match self {
404 Literal::Bool(x) => x.encode(),
405 Literal::U32(x) => x.encode(),
406 Literal::U64(x) => x.encode(),
407 Literal::U128(x) => x.encode(),
408 Literal::I128(x) => x.encode(),
409 Literal::Bfe(x) => x.encode(),
410 Literal::Xfe(x) => x.encode(),
411 Literal::Digest(x) => x.encode(),
412 };
413
414 encoding
415 .into_iter()
416 .rev()
417 .flat_map(|b| triton_asm!(push { b }))
418 .collect()
419 }
420
421 pub fn pop_from_stack(data_type: DataType, stack: &mut Vec<BFieldElement>) -> Self {
427 match data_type {
428 DataType::Bool => Self::Bool(pop_encodable(stack)),
429 DataType::U32 => Self::U32(pop_encodable(stack)),
430 DataType::U64 => Self::U64(pop_encodable(stack)),
431 DataType::U128 => Self::U128(pop_encodable(stack)),
432 DataType::I128 => Self::I128(pop_encodable(stack)),
433 DataType::Bfe => Self::Bfe(pop_encodable(stack)),
434 DataType::Xfe => Self::Xfe(pop_encodable(stack)),
435 DataType::Digest => Self::Digest(pop_encodable(stack)),
436 DataType::List(_)
437 | DataType::Array(_)
438 | DataType::Tuple(_)
439 | DataType::VoidPointer
440 | DataType::StructRef(_) => unimplemented!(),
441 }
442 }
443}
444
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_prelude::*;

    impl Literal {
        /// Unwraps an `I128` literal, panicking on every other variant.
        fn as_i128(&self) -> i128 {
            match self {
                Self::I128(val) => *val,
                _ => panic!("Expected i128, got: {self:?}"),
            }
        }
    }

    /// Asserts that `data_type` reports the same static length as the Rust
    /// type it mirrors.
    fn assert_static_length_matches(rust_static_length: Option<usize>, data_type: DataType) {
        assert_eq!(rust_static_length, data_type.static_length());
    }

    #[proptest]
    fn push_to_stack_leaves_value_on_top_of_stack(#[strategy(arb())] literal: Literal) {
        let push_then_halt = triton_asm!(
            {&literal.push_to_stack_code()}
            halt
        );
        let program = Program::new(&push_then_halt);
        let mut vm_state =
            VMState::new(program, PublicInput::default(), NonDeterminism::default());
        vm_state.run().unwrap();

        let stack = &mut vm_state.op_stack.stack;
        let read_back = Literal::pop_from_stack(literal.data_type(), stack);
        assert_eq!(literal, read_back);
    }

    #[test]
    fn static_lengths_match_up_for_vectors() {
        assert_static_length_matches(
            <Vec<XFieldElement>>::static_length(),
            DataType::List(Box::new(DataType::Xfe)),
        );
    }

    #[test]
    fn static_lengths_match_up_for_array_with_static_length_data_type() {
        let array = ArrayType {
            element_type: DataType::Bfe,
            length: 42,
        };
        assert_static_length_matches(
            <[BFieldElement; 42]>::static_length(),
            DataType::Array(Box::new(array)),
        );
    }

    #[test]
    fn static_lengths_match_up_for_array_with_dynamic_length_data_type() {
        let array = ArrayType {
            element_type: DataType::List(Box::new(DataType::Bfe)),
            length: 42,
        };
        assert_static_length_matches(
            <[Vec<BFieldElement>; 42]>::static_length(),
            DataType::Array(Box::new(array)),
        );
    }

    #[test]
    fn static_lengths_match_up_for_tuple_with_only_static_length_types() {
        assert_static_length_matches(
            <(XFieldElement, BFieldElement)>::static_length(),
            DataType::Tuple(vec![DataType::Xfe, DataType::Bfe]),
        );
    }

    #[test]
    fn static_lengths_match_up_for_tuple_with_dynamic_length_types() {
        let list_of_bfes = DataType::List(Box::new(DataType::Bfe));
        assert_static_length_matches(
            <(XFieldElement, Vec<BFieldElement>)>::static_length(),
            DataType::Tuple(vec![DataType::Xfe, list_of_bfes]),
        );
    }

    #[test]
    fn static_length_of_void_pointer_is_unknown() {
        assert_eq!(None, DataType::VoidPointer.static_length());
    }

    #[test]
    fn static_lengths_match_up_for_struct_with_only_static_length_types() {
        #[derive(Debug, Clone, BFieldCodec)]
        struct StructTyStatic {
            u32: u32,
            u64: u64,
        }

        let fields = vec![
            ("u32".to_owned(), DataType::U32),
            ("u64".to_owned(), DataType::U64),
        ];
        let struct_ty_static = StructType {
            name: "struct".to_owned(),
            fields,
        };
        assert_static_length_matches(
            StructTyStatic::static_length(),
            DataType::StructRef(struct_ty_static),
        );
    }

    #[test]
    fn static_lengths_match_up_for_struct_with_dynamic_length_types() {
        #[derive(Debug, Clone, BFieldCodec)]
        struct StructTyDyn {
            digest: Digest,
            list: Vec<BFieldElement>,
        }

        let fields = vec![
            ("digest".to_owned(), DataType::Digest),
            ("list".to_owned(), DataType::List(Box::new(DataType::Bfe))),
        ];
        let struct_ty_dyn = StructType {
            name: "struct".to_owned(),
            fields,
        };
        assert_static_length_matches(
            StructTyDyn::static_length(),
            DataType::StructRef(struct_ty_dyn),
        );
    }

    #[test]
    fn random_list_of_lists_can_be_generated() {
        let mut rng = StdRng::seed_from_u64(5950175350772851878);
        let list_of_digests = DataType::List(Box::new(DataType::Digest));
        let _generated = list_of_digests.random_list(&mut rng, 10);
    }

    #[test]
    fn i128_sizes() {
        let i128_type = DataType::I128;
        assert_eq!(4, i128_type.stack_size());
        assert_eq!(Some(4), i128_type.static_length());
    }

    #[proptest]
    fn non_negative_i128s_encode_like_u128s_prop(
        #[strategy(arb())]
        #[filter(#as_i128 >= 0i128)]
        as_i128: i128,
    ) {
        let as_u128 = u128::try_from(as_i128).unwrap();
        let unsigned_code = Literal::U128(as_u128).push_to_stack_code();
        let signed_code = Literal::I128(as_i128).push_to_stack_code();
        assert_eq!(unsigned_code, signed_code);
    }

    #[proptest]
    fn i128_literals_prop(val: i128) {
        let push_val = Literal::I128(val).push_to_stack_code();
        let program = triton_program!(
            {&push_val}
            halt
        );

        let mut vm_state = VMState::new(program, [].into(), [].into());
        vm_state.run().unwrap();
        let mut final_stack = vm_state.op_stack.stack;
        let popped = Literal::pop_from_stack(DataType::I128, &mut final_stack).as_i128();
        assert_eq!(val, popped);
    }

    #[proptest]
    fn random_list_conforms_to_bfield_codec(#[strategy(..255_usize)] len: usize, seed: u64) {
        let mut rng = StdRng::seed_from_u64(seed);
        let encoded = DataType::Digest.random_list(&mut rng, len);
        prop_assert!(<Vec<Digest>>::decode(&encoded).is_ok());
    }

    #[proptest]
    fn random_list_of_lists_conforms_to_bfield_codec(
        #[strategy(..255_usize)] len: usize,
        seed: u64,
    ) {
        let mut rng = StdRng::seed_from_u64(seed);
        let list_of_digests = DataType::List(Box::new(DataType::Digest));
        let encoded = list_of_digests.random_list(&mut rng, len);
        prop_assert!(<Vec<Vec<Digest>>>::decode(&encoded).is_ok());
    }

    #[proptest]
    fn random_array_conforms_to_bfield_codec(seed: u64) {
        const LEN: usize = 42;

        let mut rng = StdRng::seed_from_u64(seed);
        let digest_array = ArrayType {
            element_type: DataType::Digest,
            length: LEN,
        };
        let encoded = DataType::random_array(&mut rng, &digest_array);
        prop_assert!(<[Digest; LEN]>::decode(&encoded).is_ok());
    }

    #[proptest]
    fn random_array_of_arrays_conforms_to_bfield_codec(seed: u64) {
        const INNER_LEN: usize = 42;
        const OUTER_LEN: usize = 13;

        let mut rng = StdRng::seed_from_u64(seed);
        let inner = ArrayType {
            element_type: DataType::Digest,
            length: INNER_LEN,
        };
        let outer = ArrayType {
            element_type: DataType::Array(Box::new(inner)),
            length: OUTER_LEN,
        };
        let encoded = DataType::random_array(&mut rng, &outer);
        prop_assert!(<[[Digest; INNER_LEN]; OUTER_LEN]>::decode(&encoded).is_ok());
    }
}
664
#[cfg(test)]
mod compare_literals {
    use super::*;
    use crate::prelude::*;
    use crate::test_prelude::*;

    /// Defines a unit-struct snippet named `$snippet`. The snippet compares two
    /// values of `DataType::$tasm_type` using [`DataType::compare`] and is
    /// shadowed by Rust's `==` on `$rust_type`.
    macro_rules! comparison_snippet {
        ($snippet:ident for tasm_ty $tasm_type:ident and rust_ty $rust_type:ident) => {
            #[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
            struct $snippet;

            impl BasicSnippet for $snippet {
                fn inputs(&self) -> Vec<(DataType, String)> {
                    vec![
                        (DataType::$tasm_type, "left".to_string()),
                        (DataType::$tasm_type, "right".to_string()),
                    ]
                }

                fn outputs(&self) -> Vec<(DataType, String)> {
                    vec![(DataType::Bool, "are_eq".to_string())]
                }

                fn entrypoint(&self) -> String {
                    format!("tasmlib_test_compare_{}", stringify!($tasm_type))
                }

                fn code(&self, _: &mut Library) -> Vec<LabelledInstruction> {
                    let entrypoint = self.entrypoint();
                    let comparison = DataType::$tasm_type.compare();
                    triton_asm!({entrypoint}: {&comparison} return)
                }
            }

            impl Closure for $snippet {
                type Args = ($rust_type, $rust_type);

                fn rust_shadow(&self, stack: &mut Vec<BFieldElement>) {
                    let (right, left) = pop_encodable::<Self::Args>(stack);
                    let are_eq = left == right;
                    push_encodable(stack, &are_eq);
                }

                fn pseudorandom_args(
                    &self,
                    seed: [u8; 32],
                    _: Option<BenchmarkCase>,
                ) -> Self::Args {
                    let mut rng = StdRng::from_seed(seed);
                    rng.random()
                }

                fn corner_case_args(&self) -> Vec<Self::Args> {
                    vec![Default::default()]
                }
            }
        };
    }

    comparison_snippet!(CompareBfes for tasm_ty Bfe and rust_ty BFieldElement);
    comparison_snippet!(CompareDigests for tasm_ty Digest and rust_ty Digest);

    #[test]
    fn test() {
        let compare_bfes = ShadowedClosure::new(CompareBfes);
        compare_bfes.test();
        let compare_digests = ShadowedClosure::new(CompareDigests);
        compare_digests.test();
    }

    #[test]
    fn bench() {
        let compare_bfes = ShadowedClosure::new(CompareBfes);
        compare_bfes.bench();
        let compare_digests = ShadowedClosure::new(CompareDigests);
        compare_digests.bench();
    }
}