tasm_lib/
data_type.rs

1use std::fmt::Display;
2use std::fmt::Formatter;
3use std::str::FromStr;
4
5use itertools::Itertools;
6use rand::prelude::*;
7use triton_vm::isa::op_stack::NUM_OP_STACK_REGISTERS;
8use triton_vm::prelude::*;
9
10use crate::io::InputSource;
11use crate::memory::load_words_from_memory_leave_pointer;
12use crate::memory::load_words_from_memory_pop_pointer;
13use crate::memory::write_words_to_memory_leave_pointer;
14use crate::memory::write_words_to_memory_pop_pointer;
15use crate::pop_encodable;
16
17/// A type hint for developers of Triton Assembly.
18///
19/// Note that _no_ type checking is performed.
20#[derive(Debug, Clone, Eq, PartialEq, Hash)]
21pub enum DataType {
22    Bool,
23    U32,
24    U64,
25    U128,
26    U160,
27    U192,
28    I128,
29    Bfe,
30    Xfe,
31    Digest,
32    List(Box<DataType>),
33    Array(Box<ArrayType>),
34    Tuple(Vec<DataType>),
35    VoidPointer,
36    StructRef(StructType),
37}
38
39#[derive(Debug, Clone, Hash, PartialEq, Eq)]
40pub struct ArrayType {
41    pub element_type: DataType,
42    pub length: usize,
43}
44
45#[derive(Debug, Clone, Hash, PartialEq, Eq)]
46pub struct StructType {
47    pub name: String,
48    pub fields: Vec<(String, DataType)>,
49}
50
51impl Display for StructType {
52    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
53        write!(f, "{}", self.name)
54    }
55}
56
57impl DataType {
58    /// See [`BFieldCodec::static_length`].
59    pub(crate) fn static_length(&self) -> Option<usize> {
60        match self {
61            Self::Bool => bool::static_length(),
62            Self::U32 => u32::static_length(),
63            Self::U64 => u64::static_length(),
64            Self::U128 => u128::static_length(),
65            Self::U160 => Some(5),
66            Self::U192 => Some(6),
67            Self::I128 => i128::static_length(),
68            Self::Bfe => BFieldElement::static_length(),
69            Self::Xfe => XFieldElement::static_length(),
70            Self::Digest => Digest::static_length(),
71            Self::List(_) => None,
72            Self::Array(a) => Some(a.length * a.element_type.static_length()?),
73            Self::Tuple(t) => t.iter().map(|dt| dt.static_length()).sum(),
74            Self::VoidPointer => None,
75            Self::StructRef(s) => s.fields.iter().map(|(_, dt)| dt.static_length()).sum(),
76        }
77    }
78
79    /// A string which can be used as part of function labels in Triton-VM.
80    pub fn label_friendly_name(&self) -> String {
81        match self {
82            Self::List(inner_type) => format!("list_L{}R", inner_type.label_friendly_name()),
83            Self::Tuple(inner_types) => {
84                format!(
85                    "tuple_L{}R",
86                    inner_types
87                        .iter()
88                        .map(|x| x.label_friendly_name())
89                        .join("___")
90                )
91            }
92            Self::VoidPointer => "void_pointer".to_string(),
93            Self::Bool => "bool".to_string(),
94            Self::U32 => "u32".to_string(),
95            Self::U64 => "u64".to_string(),
96            Self::U128 => "u128".to_string(),
97            Self::U160 => "u160".to_string(),
98            Self::U192 => "u192".to_string(),
99            Self::I128 => "i128".to_string(),
100            Self::Bfe => "bfe".to_string(),
101            Self::Xfe => "xfe".to_string(),
102            Self::Digest => "digest".to_string(),
103            Self::Array(array_type) => format!(
104                "array{}___{}",
105                array_type.length,
106                array_type.element_type.label_friendly_name()
107            ),
108            Self::StructRef(struct_type) => format!("{struct_type}"),
109        }
110    }
111
112    /// The size that the data type takes up on stack
113    pub fn stack_size(&self) -> usize {
114        match self {
115            Self::Bool
116            | Self::U32
117            | Self::U64
118            | Self::U128
119            | Self::U160
120            | Self::U192
121            | Self::I128
122            | Self::Bfe
123            | Self::Xfe
124            | Self::Digest => self.static_length().unwrap(),
125            Self::List(_) => 1,
126            Self::Array(_) => 1,
127            Self::Tuple(t) => t.iter().map(|dt| dt.stack_size()).sum(),
128            Self::VoidPointer => 1,
129            Self::StructRef(_) => 1,
130        }
131    }
132
133    /// The code to read a value of this type from memory.
134    /// Leaves mutated point on top of stack.
135    ///
136    /// ```text
137    /// BEFORE: _ (*address + self.stack_size() - 1)
138    /// AFTER:  _ [value] (*address - 1)
139    /// ```
140    pub fn read_value_from_memory_leave_pointer(&self) -> Vec<LabelledInstruction> {
141        load_words_from_memory_leave_pointer(self.stack_size())
142    }
143
144    /// The code to read a value of this type from memory.
145    /// Pops pointer from stack.
146    ///
147    /// ```text
148    /// BEFORE: _ (*address + self.stack_size() - 1)
149    /// AFTER:  _ [value]
150    /// ```
151    pub fn read_value_from_memory_pop_pointer(&self) -> Vec<LabelledInstruction> {
152        load_words_from_memory_pop_pointer(self.stack_size())
153    }
154
155    /// The code to write a value of this type to memory
156    ///
157    /// ```text
158    /// BEFORE: _ [value] *address
159    /// AFTER:  _ (*address + self.stack_size())
160    /// ```
161    pub fn write_value_to_memory_leave_pointer(&self) -> Vec<LabelledInstruction> {
162        write_words_to_memory_leave_pointer(self.stack_size())
163    }
164
165    /// The code to write a value of this type to memory
166    ///
167    /// ```text
168    /// BEFORE: _ [value] *address
169    /// AFTER:  _
170    /// ```
171    pub fn write_value_to_memory_pop_pointer(&self) -> Vec<LabelledInstruction> {
172        write_words_to_memory_pop_pointer(self.stack_size())
173    }
174
175    /// The code to read a value of this type from the specified input source
176    ///
177    /// ```text
178    /// BEFORE: _
179    /// AFTER:  _ [value]
180    /// ```
181    pub fn read_value_from_input(&self, input_source: InputSource) -> Vec<LabelledInstruction> {
182        input_source.read_words(self.stack_size())
183    }
184
185    /// The code to write a value of this type to standard output
186    ///
187    /// ```text
188    /// BEFORE: _ [value]
189    /// AFTER:  _
190    /// ```
191    pub fn write_value_to_stdout(&self) -> Vec<LabelledInstruction> {
192        crate::io::write_words(self.stack_size())
193    }
194
195    /// The code that compares two elements of this stack-size.
196    ///
197    /// ```text
198    /// BEFORE: _ [self] [other]
199    /// AFTER:  _ (self == other)
200    /// ```
201    pub fn compare_elem_of_stack_size(stack_size: usize) -> Vec<LabelledInstruction> {
202        if stack_size == 0 {
203            return triton_asm!(push 1);
204        } else if stack_size == 1 {
205            return triton_asm!(eq);
206        }
207
208        assert!(stack_size + 1 < NUM_OP_STACK_REGISTERS);
209        let first_cmps = vec![triton_asm!(swap {stack_size + 1} eq); stack_size - 1].concat();
210        let last_cmp = triton_asm!(swap 2 eq);
211        let boolean_ands = triton_asm![mul; stack_size - 1];
212
213        [first_cmps, last_cmp, boolean_ands].concat()
214    }
215
216    /// The code that compares two elements of this type.
217    ///
218    /// ```text
219    /// BEFORE: _ [self] [other]
220    /// AFTER:  _ (self == other)
221    /// ```
222    pub fn compare(&self) -> Vec<LabelledInstruction> {
223        Self::compare_elem_of_stack_size(self.stack_size())
224    }
225
226    /// A string matching how the variant looks in source code.
227    pub fn variant_name(&self) -> String {
228        // This function is used to autogenerate snippets in the tasm-lang compiler
229        match self {
230            Self::Bool => "DataType::Bool".to_owned(),
231            Self::U32 => "DataType::U32".to_owned(),
232            Self::U64 => "DataType::U64".to_owned(),
233            Self::U128 => "DataType::U128".to_owned(),
234            Self::U160 => "DataType::U160".to_owned(),
235            Self::U192 => "DataType::U192".to_owned(),
236            Self::I128 => "DataType::I128".to_owned(),
237            Self::Bfe => "DataType::BFE".to_owned(),
238            Self::Xfe => "DataType::XFE".to_owned(),
239            Self::Digest => "DataType::Digest".to_owned(),
240            Self::List(elem_type) => {
241                format!("DataType::List(Box::new({}))", elem_type.variant_name())
242            }
243            Self::VoidPointer => "DataType::VoidPointer".to_owned(),
244            Self::Tuple(elements) => {
245                let elements_as_variant_names =
246                    elements.iter().map(|x| x.variant_name()).collect_vec();
247                format!(
248                    "DataType::Tuple(vec![{}])",
249                    elements_as_variant_names.join(", ")
250                )
251            }
252            Self::Array(array_type) => format!(
253                "[{}; {}]",
254                array_type.element_type.variant_name(),
255                array_type.length
256            ),
257            Self::StructRef(struct_type) => format!("Box<{struct_type}>"),
258        }
259    }
260
261    /// A collection of different data types, used for testing.
262    #[cfg(test)]
263    pub fn big_random_generatable_type_collection() -> Vec<Self> {
264        vec![
265            Self::Bool,
266            Self::U32,
267            Self::U64,
268            Self::U128,
269            Self::U160,
270            Self::U192,
271            Self::Bfe,
272            Self::Xfe,
273            Self::Digest,
274            Self::VoidPointer,
275            Self::Tuple(vec![Self::Bool]),
276            Self::Tuple(vec![Self::Xfe, Self::Bool]),
277            Self::Tuple(vec![Self::Xfe, Self::Digest]),
278            Self::Tuple(vec![Self::Bool, Self::Bool]),
279            Self::Tuple(vec![Self::Digest, Self::Xfe]),
280            Self::Tuple(vec![Self::Bfe, Self::Xfe, Self::Digest]),
281            Self::Tuple(vec![Self::Xfe, Self::Bfe, Self::Digest]),
282            Self::Tuple(vec![Self::U64, Self::Digest, Self::Digest, Self::Digest]),
283            Self::Tuple(vec![Self::Digest, Self::Digest, Self::Digest, Self::U64]),
284            Self::Tuple(vec![Self::Digest, Self::Xfe, Self::U128, Self::Bool]),
285        ]
286    }
287
288    pub fn random_elements(&self, count: usize) -> Vec<Vec<BFieldElement>> {
289        (0..count)
290            .map(|_| self.seeded_random_element(&mut rand::rng()))
291            .collect()
292    }
293
294    pub fn seeded_random_element(&self, rng: &mut impl Rng) -> Vec<BFieldElement> {
295        match self {
296            Self::Bool => rng.random::<bool>().encode(),
297            Self::U32 => rng.random::<u32>().encode(),
298            Self::U64 => rng.random::<u64>().encode(),
299            Self::U128 => rng.random::<u128>().encode(),
300            Self::U160 => rng.random::<[u32; 5]>().encode(),
301            Self::U192 => rng.random::<[u32; 6]>().encode(),
302            Self::I128 => rng.random::<[u32; 4]>().encode(),
303            Self::Bfe => rng.random::<BFieldElement>().encode(),
304            Self::Xfe => rng.random::<XFieldElement>().encode(),
305            Self::Digest => rng.random::<Digest>().encode(),
306            Self::List(e) => {
307                let len = rng.random_range(0..20);
308                e.random_list(rng, len)
309            }
310            Self::Array(a) => Self::random_array(rng, a),
311            Self::Tuple(tys) => tys
312                .iter()
313                .flat_map(|ty| ty.seeded_random_element(rng))
314                .collect(),
315            Self::VoidPointer => vec![rng.random()],
316            Self::StructRef(_) => panic!("Random generation of structs is not supported"),
317        }
318    }
319
320    /// A list of given length with random elements of type `self`.
321    pub(crate) fn random_list(&self, rng: &mut impl Rng, len: usize) -> Vec<BFieldElement> {
322        let maybe_prepend_elem_len = |elem: Vec<_>| {
323            if self.static_length().is_some() {
324                elem
325            } else {
326                [bfe_vec![elem.len() as u64], elem].concat()
327            }
328        };
329
330        let elements = (0..len)
331            .map(|_| self.seeded_random_element(rng))
332            .flat_map(maybe_prepend_elem_len)
333            .collect();
334
335        [bfe_vec![len as u64], elements].concat()
336    }
337
338    pub(crate) fn random_array(rng: &mut impl Rng, array_ty: &ArrayType) -> Vec<BFieldElement> {
339        (0..array_ty.length)
340            .flat_map(|_| array_ty.element_type.seeded_random_element(rng))
341            .collect()
342    }
343}
344
345impl FromStr for DataType {
346    type Err = anyhow::Error;
347
348    // This implementation must be the inverse of `label_friendly_name`
349    fn from_str(s: &str) -> Result<Self, Self::Err> {
350        let res = if s.starts_with("list_L") && s.ends_with('R') {
351            let inner = &s[6..s.len() - 1];
352            let inner = FromStr::from_str(inner)?;
353            Self::List(Box::new(inner))
354        } else if s.starts_with("tuple_L") && s.ends_with('R') {
355            let inner = &s[7..s.len() - 1];
356            let inners = inner.split("___");
357            let mut inners_resolved: Vec<Self> = vec![];
358            for inner_elem in inners {
359                inners_resolved.push(FromStr::from_str(inner_elem)?);
360            }
361
362            Self::Tuple(inners_resolved)
363        } else {
364            match s {
365                "void_pointer" => Self::VoidPointer,
366                "bool" => Self::Bool,
367                "u32" => Self::U32,
368                "u64" => Self::U64,
369                "u128" => Self::U128,
370                "u160" => Self::U160,
371                "u192" => Self::U192,
372                "i128" => Self::I128,
373                "bfe" => Self::Bfe,
374                "xfe" => Self::Xfe,
375                "digest" => Self::Digest,
376                _ => anyhow::bail!("Could not parse {s} as a data type"),
377            }
378        };
379
380        Ok(res)
381    }
382}
383
384#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, arbitrary::Arbitrary)]
385pub enum Literal {
386    Bool(bool),
387    U32(u32),
388    U64(u64),
389    U128(u128),
390    I128(i128),
391    Bfe(BFieldElement),
392    Xfe(XFieldElement),
393    Digest(Digest),
394}
395
396impl Literal {
397    pub fn data_type(&self) -> DataType {
398        match self {
399            Self::Bool(_) => DataType::Bool,
400            Self::U32(_) => DataType::U32,
401            Self::U64(_) => DataType::U64,
402            Self::U128(_) => DataType::U128,
403            Self::I128(_) => DataType::I128,
404            Self::Bfe(_) => DataType::Bfe,
405            Self::Xfe(_) => DataType::Xfe,
406            Self::Digest(_) => DataType::Digest,
407        }
408    }
409
410    /// # Panics
411    ///
412    /// Panics if `self` is anything but an [extension field element](Self::Xfe).
413    pub fn as_xfe(&self) -> XFieldElement {
414        match self {
415            Self::Xfe(xfe) => *xfe,
416            _ => panic!("Expected XFE, got {self:?}"),
417        }
418    }
419
420    /// The code to push the literal to the stack.
421    pub fn push_to_stack_code(&self) -> Vec<LabelledInstruction> {
422        let encoding = match self {
423            Literal::Bool(x) => x.encode(),
424            Literal::U32(x) => x.encode(),
425            Literal::U64(x) => x.encode(),
426            Literal::U128(x) => x.encode(),
427            Literal::I128(x) => x.encode(),
428            Literal::Bfe(x) => x.encode(),
429            Literal::Xfe(x) => x.encode(),
430            Literal::Digest(x) => x.encode(),
431        };
432
433        encoding
434            .into_iter()
435            .rev()
436            .flat_map(|b| triton_asm!(push { b }))
437            .collect()
438    }
439
440    /// # Panics
441    ///
442    /// - if the stack is too shallow
443    /// - if the top of the stack does not contain an element of the requested type
444    /// - if the element is incorrectly [`BFieldCodec`] encoded
445    pub fn pop_from_stack(data_type: DataType, stack: &mut Vec<BFieldElement>) -> Self {
446        match data_type {
447            DataType::Bool => Self::Bool(pop_encodable(stack)),
448            DataType::U32 => Self::U32(pop_encodable(stack)),
449            DataType::U64 => Self::U64(pop_encodable(stack)),
450            DataType::U128 => Self::U128(pop_encodable(stack)),
451            DataType::I128 => Self::I128(pop_encodable(stack)),
452            DataType::Bfe => Self::Bfe(pop_encodable(stack)),
453            DataType::Xfe => Self::Xfe(pop_encodable(stack)),
454            DataType::Digest => Self::Digest(pop_encodable(stack)),
455            DataType::List(_)
456            | DataType::Array(_)
457            | DataType::Tuple(_)
458            | DataType::VoidPointer
459            | DataType::StructRef(_)
460            | DataType::U160
461            | DataType::U192 => unimplemented!(),
462        }
463    }
464}
465
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_prelude::*;

    impl Literal {
        /// The contained `i128`.
        ///
        /// # Panics
        ///
        /// Panics if `self` is anything but an [`i128`](Self::I128).
        fn as_i128(&self) -> i128 {
            match self {
                Self::I128(val) => *val,
                _ => panic!("Expected i128, got: {self:?}"),
            }
        }
    }

    #[proptest]
    fn push_to_stack_leaves_value_on_top_of_stack(#[strategy(arb())] literal: Literal) {
        let push_literal = literal.push_to_stack_code();
        let program = Program::new(&triton_asm!({&push_literal} halt));

        let mut vm_state =
            VMState::new(program, PublicInput::default(), NonDeterminism::default());
        vm_state.run().unwrap();

        let popped = Literal::pop_from_stack(literal.data_type(), &mut vm_state.op_stack.stack);
        assert_eq!(literal, popped);
    }

    #[test]
    fn static_lengths_match_up_for_vectors() {
        let list_type = DataType::List(Box::new(DataType::Xfe));
        assert_eq!(
            <Vec<XFieldElement>>::static_length(),
            list_type.static_length()
        );
    }

    #[test]
    fn static_lengths_match_up_for_array_with_static_length_data_type() {
        let array_type = ArrayType {
            element_type: DataType::Bfe,
            length: 42,
        };
        let array_type = DataType::Array(Box::new(array_type));

        assert_eq!(
            <[BFieldElement; 42]>::static_length(),
            array_type.static_length()
        );
    }

    #[test]
    fn static_lengths_match_up_for_array_with_dynamic_length_data_type() {
        let array_type = ArrayType {
            element_type: DataType::List(Box::new(DataType::Bfe)),
            length: 42,
        };
        let array_type = DataType::Array(Box::new(array_type));

        assert_eq!(
            <[Vec<BFieldElement>; 42]>::static_length(),
            array_type.static_length()
        );
    }

    #[test]
    fn static_lengths_match_up_for_tuple_with_only_static_length_types() {
        let tuple_type = DataType::Tuple(vec![DataType::Xfe, DataType::Bfe]);
        assert_eq!(
            <(XFieldElement, BFieldElement)>::static_length(),
            tuple_type.static_length()
        );
    }

    #[test]
    fn static_lengths_match_up_for_tuple_with_dynamic_length_types() {
        let list_type = DataType::List(Box::new(DataType::Bfe));
        let tuple_type = DataType::Tuple(vec![DataType::Xfe, list_type]);
        assert_eq!(
            <(XFieldElement, Vec<BFieldElement>)>::static_length(),
            tuple_type.static_length()
        );
    }

    #[test]
    fn static_length_of_void_pointer_is_unknown() {
        assert!(DataType::VoidPointer.static_length().is_none());
    }

    #[test]
    fn static_lengths_match_up_for_struct_with_only_static_length_types() {
        #[derive(Debug, Clone, BFieldCodec)]
        struct StructTyStatic {
            u32: u32,
            u64: u64,
        }

        let fields = vec![
            ("u32".to_owned(), DataType::U32),
            ("u64".to_owned(), DataType::U64),
        ];
        let struct_ty_static = StructType {
            name: "struct".to_owned(),
            fields,
        };

        assert_eq!(
            StructTyStatic::static_length(),
            DataType::StructRef(struct_ty_static).static_length()
        );
    }

    #[test]
    fn static_lengths_match_up_for_struct_with_dynamic_length_types() {
        #[derive(Debug, Clone, BFieldCodec)]
        struct StructTyDyn {
            digest: Digest,
            list: Vec<BFieldElement>,
        }

        let fields = vec![
            ("digest".to_owned(), DataType::Digest),
            ("list".to_owned(), DataType::List(Box::new(DataType::Bfe))),
        ];
        let struct_ty_dyn = StructType {
            name: "struct".to_owned(),
            fields,
        };

        assert_eq!(
            StructTyDyn::static_length(),
            DataType::StructRef(struct_ty_dyn).static_length()
        );
    }

    /// Generating a list whose elements have no static length must not panic.
    #[test]
    fn random_list_of_lists_can_be_generated() {
        let mut rng = StdRng::seed_from_u64(5950175350772851878);
        let list_of_digests = DataType::List(Box::new(DataType::Digest));
        let _list = list_of_digests.random_list(&mut rng, 10);
    }

    #[test]
    fn i128_sizes() {
        let i128_type = DataType::I128;
        assert_eq!(4, i128_type.stack_size());
        assert_eq!(Some(4), i128_type.static_length());
    }

    #[proptest]
    fn non_negative_i128s_encode_like_u128s_prop(
        #[strategy(arb())]
        #[filter(#as_i128 >= 0i128)]
        as_i128: i128,
    ) {
        let as_u128 = u128::try_from(as_i128).unwrap();
        let push_u128 = Literal::U128(as_u128).push_to_stack_code();
        let push_i128 = Literal::I128(as_i128).push_to_stack_code();
        assert_eq!(push_u128, push_i128);
    }

    #[proptest]
    fn i128_literals_prop(val: i128) {
        let push_literal = Literal::I128(val).push_to_stack_code();
        let program = triton_program!(
            {&push_literal}
            halt
        );

        let mut vm_state = VMState::new(program, [].into(), [].into());
        vm_state.run().unwrap();

        let mut final_stack = vm_state.op_stack.stack;
        let popped = Literal::pop_from_stack(DataType::I128, &mut final_stack);
        assert_eq!(val, popped.as_i128());
    }

    #[proptest]
    fn random_list_conforms_to_bfield_codec(#[strategy(..255_usize)] len: usize, seed: u64) {
        let mut rng = StdRng::seed_from_u64(seed);
        let list = DataType::Digest.random_list(&mut rng, len);
        prop_assert!(<Vec<Digest>>::decode(&list).is_ok());
    }

    #[proptest]
    fn random_list_of_lists_conforms_to_bfield_codec(
        #[strategy(..255_usize)] len: usize,
        seed: u64,
    ) {
        let mut rng = StdRng::seed_from_u64(seed);
        let list_of_digests = DataType::List(Box::new(DataType::Digest));
        let list = list_of_digests.random_list(&mut rng, len);
        prop_assert!(<Vec<Vec<Digest>>>::decode(&list).is_ok());
    }

    #[proptest]
    fn random_array_conforms_to_bfield_codec(seed: u64) {
        const LEN: usize = 42;

        let array_type = ArrayType {
            element_type: DataType::Digest,
            length: LEN,
        };

        let mut rng = StdRng::seed_from_u64(seed);
        let array = DataType::random_array(&mut rng, &array_type);
        prop_assert!(<[Digest; LEN]>::decode(&array).is_ok());
    }

    #[proptest]
    fn random_array_of_arrays_conforms_to_bfield_codec(seed: u64) {
        const INNER_LEN: usize = 42;
        const OUTER_LEN: usize = 13;

        let inner_type = ArrayType {
            element_type: DataType::Digest,
            length: INNER_LEN,
        };
        let outer_type = ArrayType {
            element_type: DataType::Array(Box::new(inner_type)),
            length: OUTER_LEN,
        };

        let mut rng = StdRng::seed_from_u64(seed);
        let array = DataType::random_array(&mut rng, &outer_type);
        prop_assert!(<[[Digest; INNER_LEN]; OUTER_LEN]>::decode(&array).is_ok());
    }
}
685
/// Test [`DataType::compare`] by wrapping it in [`BasicSnippet`] and
/// implementing [`RustShadow`] for it.
#[cfg(test)]
mod compare_literals {
    use super::*;
    use crate::prelude::*;
    use crate::test_prelude::*;

    /// Declare a snippet that compares two values of the given type using
    /// [`DataType::compare`], together with its rust shadow.
    macro_rules! comparison_snippet {
        ($snippet:ident for tasm_ty $data_type:ident and rust_ty $shadow_type:ident) => {
            #[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
            struct $snippet;

            impl BasicSnippet for $snippet {
                fn parameters(&self) -> Vec<(DataType, String)> {
                    vec![
                        (DataType::$data_type, "left".to_string()),
                        (DataType::$data_type, "right".to_string()),
                    ]
                }

                fn return_values(&self) -> Vec<(DataType, String)> {
                    vec![(DataType::Bool, "are_eq".to_string())]
                }

                fn entrypoint(&self) -> String {
                    let ty = stringify!($data_type);
                    format!("tasmlib_test_compare_{ty}")
                }

                fn code(&self, _: &mut Library) -> Vec<LabelledInstruction> {
                    let entrypoint = self.entrypoint();
                    let comparison = DataType::$data_type.compare();

                    triton_asm!({entrypoint}: {&comparison} return)
                }
            }

            impl Closure for $snippet {
                type Args = ($shadow_type, $shadow_type);

                fn rust_shadow(&self, stack: &mut Vec<BFieldElement>) {
                    let (right, left) = pop_encodable::<Self::Args>(stack);
                    let are_eq = left == right;
                    push_encodable(stack, &are_eq);
                }

                fn pseudorandom_args(
                    &self,
                    seed: [u8; 32],
                    _: Option<BenchmarkCase>
                ) -> Self::Args {
                    // The two arguments are almost certainly different, which
                    // makes the comparison give `false`.
                    let mut rng = StdRng::from_seed(seed);
                    rng.random()
                }

                fn corner_case_args(&self) -> Vec<Self::Args> {
                    // The two arguments are identical, which makes the
                    // comparison give `true`.
                    let identical_args = Self::Args::default();
                    vec![identical_args]
                }
            }
        };
    }

    // stack size == 1
    comparison_snippet!(CompareBfes for tasm_ty Bfe and rust_ty BFieldElement);

    // stack size > 1
    comparison_snippet!(CompareDigests for tasm_ty Digest and rust_ty Digest);

    #[test]
    fn test() {
        ShadowedClosure::new(CompareBfes).test();
        ShadowedClosure::new(CompareDigests).test();
    }

    #[test]
    fn bench() {
        ShadowedClosure::new(CompareBfes).bench();
        ShadowedClosure::new(CompareDigests).bench();
    }
}