tasm_lib/list/contains.rs

1use triton_vm::prelude::*;
2
3use crate::list::get::Get;
4use crate::prelude::*;
5
6/// Returns `true` if the list contains an element with the given value.
7///
8/// This operation is *O*(*n*).
9///
10/// Mirrors the `contains` method from Rust `core` as closely as possible.
11///
12/// Only supports lists with [statically sized](BFieldCodec::static_length)
13/// elements. The element's static size must be in range `1..=14`.
14///
15/// ### Behavior
16///
17/// ```text
18/// BEFORE: _ *list [needle: ElementType]
19/// AFTER:  _ [needle ∈ list: bool]
20/// ```
21///
22/// ### Preconditions
23///
24/// - the argument `*list` points to a properly [`BFieldCodec`]-encoded list
25/// - all input arguments are properly [`BFieldCodec`] encoded
26///
27/// ### Postconditions
28///
29/// - the output is properly [`BFieldCodec`] encoded
30#[derive(Debug, Clone, Eq, PartialEq, Hash)]
31pub struct Contains {
32    element_type: DataType,
33}
34
35impl Contains {
36    /// # Panics
37    ///
38    /// Panics
39    /// - if the element has [dynamic length][BFieldCodec::static_length], or
40    /// - if the static length is 0, or
41    /// - if the static length is larger than or equal to 15.
42    // Requirement “static length < 15” is needed for comparing elements.
43    pub fn new(element_type: DataType) -> Self {
44        Get::assert_element_type_is_supported(&element_type);
45
46        Self { element_type }
47    }
48}
49
50impl BasicSnippet for Contains {
51    fn inputs(&self) -> Vec<(DataType, String)> {
52        let element_type = self.element_type.clone();
53        let list_type = DataType::List(Box::new(element_type.clone()));
54
55        vec![
56            (list_type, "self".to_owned()),
57            (element_type, "needle".to_owned()),
58        ]
59    }
60
61    fn outputs(&self) -> Vec<(DataType, String)> {
62        vec![(DataType::Bool, "match_found".to_owned())]
63    }
64
65    fn entrypoint(&self) -> String {
66        let element_type = self.element_type.label_friendly_name();
67        format!("tasmlib_list_contains___{element_type}")
68    }
69
70    fn code(&self, library: &mut Library) -> Vec<LabelledInstruction> {
71        // unwrap is fine: Self::new checks range of stack size
72        let element_size = self.element_type.stack_size().try_into().unwrap();
73        let needle_alloc = library.kmalloc(element_size);
74
75        let entrypoint = self.entrypoint();
76        let loop_label = format!("{entrypoint}_loop");
77        let mul_with_element_size = match element_size {
78            1 => triton_asm!(), // no-op
79            n => triton_asm!(push {n} mul),
80        };
81
82        triton_asm!(
83            // BEFORE: _ *list [value]
84            // AFTER:  _ match_found
85            {entrypoint}:
86                push {needle_alloc.write_address()}
87                {&self.element_type.write_value_to_memory_leave_pointer()}
88                pop 1           // _ *list
89
90                push 0          hint match_found: bool = stack[0]
91                pick 1          // _ 0 *list
92
93                dup 0
94                read_mem 1      // _ 0 *list list_len (*list - 1)
95                addi 1          // _ 0 *list list_len *list
96                pick 1          // _ 0 *list *list list_len
97                {&mul_with_element_size}
98                                // _ 0 *list *list (list_len * elem_size)
99                add             // _ 0 *list *list_last_word
100
101                call {loop_label}
102                                // _ match_found *list *list_last_word
103                pop 2           // _ match_found
104
105                return
106
107            // INVARIANT: _ match_found *list *list[i]
108            {loop_label}:
109                /* loop header – all elements checked, or match found? */
110                dup 1
111                dup 1
112                eq              // _ match_found *list *list[i] (*list == *list[i])
113                dup 3
114                add             // _ match_found *list *list[i] ((*list == *list[i]) || match_found)
115                skiz return     // _ 0           *list *list[i]
116
117
118                /* Loop body */
119                {&self.element_type.read_value_from_memory_leave_pointer()}
120                                // _ 0 *list [haystack_element] *list[i-1]
121                place {self.element_type.stack_size()}
122                                // _ 0 *list *list[i-1] [haystack_element]
123
124                push {needle_alloc.read_address()}
125                {&self.element_type.read_value_from_memory_pop_pointer()}
126                                // _ 0 *list *list[i-1] [haystack_element] [needle]
127                {&self.element_type.compare()}
128                                // _ 0 *list *list[i-1] (haystack_element == needle)
129
130                swap 3
131                pop 1           // _ (haystack_element == needle) *list *list[i-1]
132                recurse
133        )
134    }
135}
136
#[cfg(test)]
mod tests {
    use super::*;
    use crate::library::STATIC_MEMORY_FIRST_ADDRESS;
    use crate::rust_shadowing_helper_functions::list::load_list_unstructured;
    use crate::test_helpers::test_rust_equivalence_given_complete_state;
    use crate::test_prelude::*;

    impl Contains {
        /// Address where the snippet's statically allocated copy of the needle
        /// starts, assuming the snippet is run in isolation.
        fn static_pointer_isolated_run(&self) -> BFieldElement {
            STATIC_MEMORY_FIRST_ADDRESS - bfe!(self.element_type.stack_size()) + bfe!(1)
        }

        /// Builds an initial state with two parts:
        ///
        /// - memory: the list `haystack_elements` stored at `list_pointer`, as
        ///   its length followed by the words of all elements;
        /// - stack: `_ list_pointer [needle]`, with the needle pushed in reverse
        ///   so that its first word ends up on top of the stack.
        fn prepare_state(
            &self,
            list_pointer: BFieldElement,
            mut needle: Vec<BFieldElement>,
            haystack_elements: Vec<Vec<BFieldElement>>,
        ) -> FunctionInitialState {
            let mut memory: HashMap<BFieldElement, BFieldElement> = HashMap::default();
            let list_length = haystack_elements.len();
            memory.insert(list_pointer, bfe!(list_length));

            let mut word_pointer = list_pointer;
            word_pointer.increment();
            for &word in haystack_elements.iter().flatten() {
                memory.insert(word_pointer, word);
                word_pointer.increment();
            }

            needle.reverse();
            let init_stack = [
                self.init_stack_for_isolated_run(),
                vec![list_pointer],
                needle,
            ]
            .concat();

            FunctionInitialState {
                stack: init_stack,
                memory,
            }
        }
    }

    impl Function for Contains {
        fn rust_shadow(
            &self,
            stack: &mut Vec<BFieldElement>,
            memory: &mut HashMap<BFieldElement, BFieldElement>,
        ) {
            let needle = (0..self.element_type.stack_size())
                .map(|_| stack.pop().unwrap())
                .collect_vec();

            let haystack_list_ptr = stack.pop().unwrap();
            let haystack_elems =
                load_list_unstructured(self.element_type.stack_size(), haystack_list_ptr, memory);

            stack.push(bfe!(haystack_elems.contains(&needle) as u32));

            // The TASM code stores the needle in static memory; mirror that.
            let mut static_pointer = self.static_pointer_isolated_run();
            for word in needle {
                memory.insert(static_pointer, word);
                static_pointer.increment();
            }
        }

        fn pseudorandom_initial_state(
            &self,
            seed: [u8; 32],
            bench_case: Option<BenchmarkCase>,
        ) -> FunctionInitialState {
            let mut rng: StdRng = StdRng::from_seed(seed);
            let list_length = match bench_case {
                Some(BenchmarkCase::CommonCase) => 100,
                Some(BenchmarkCase::WorstCase) => 400,
                None => rng.random_range(1..400),
            };
            let haystack_elements = (0..list_length)
                .map(|_| self.element_type.seeded_random_element(&mut rng))
                .collect_vec();

            let list_pointer: BFieldElement = rng.random();

            let needle = match bench_case {
                Some(BenchmarkCase::CommonCase | BenchmarkCase::WorstCase) => {
                    haystack_elements[list_length / 2].clone()
                }
                None => {
                    if rng.random() {
                        haystack_elements
                            .choose(&mut rng)
                            .expect("haystack is non-empty: its length is drawn from 1..400")
                            .clone()
                    } else {
                        // A random needle is usually absent from the haystack,
                        // but it may coincide with some element, with a
                        // probability of roughly
                        // `list_length / size of the element type's value space`.
                        // The Rust shadow agrees with the TASM code either way,
                        // so the test remains valid.
                        self.element_type.seeded_random_element(&mut rng)
                    }
                }
            };

            self.prepare_state(list_pointer, needle, haystack_elements)
        }

        /// Hand-picked states:
        ///
        /// - an empty list;
        /// - one- and two-element lists, with and without matches;
        /// - a needle whose words are the reverse of the list's only element.
        fn corner_case_initial_states(&self) -> Vec<FunctionInitialState> {
            let empty_list =
                self.prepare_state(bfe!(1), bfe_vec![1; self.element_type.stack_size()], vec![]);

            let an_element = bfe_vec![42; self.element_type.stack_size()];
            let another_element = bfe_vec![420; self.element_type.stack_size()];
            let a_pointer = bfe!(42);
            let one_element_match =
                self.prepare_state(a_pointer, an_element.clone(), vec![an_element.clone()]);
            let one_element_no_match =
                self.prepare_state(a_pointer, an_element.clone(), vec![another_element.clone()]);
            let two_elements_match_first = self.prepare_state(
                a_pointer,
                an_element.clone(),
                vec![an_element.clone(), another_element.clone()],
            );
            let two_elements_match_last = self.prepare_state(
                a_pointer,
                an_element.clone(),
                vec![another_element.clone(), an_element.clone()],
            );
            let two_elements_no_match = self.prepare_state(
                a_pointer,
                an_element.clone(),
                vec![another_element.clone(), another_element.clone()],
            );
            let two_elements_both_match = self.prepare_state(
                a_pointer,
                an_element.clone(),
                vec![an_element.clone(), an_element.clone()],
            );

            let non_symmetric_value = (0..self.element_type.stack_size())
                .map(|i| bfe!(i + 200))
                .collect_vec();
            let mut mirrored_non_symmetric_value = non_symmetric_value.clone();
            mirrored_non_symmetric_value.reverse();
            let no_match_on_inverted_value_unless_size_1 = self.prepare_state(
                a_pointer,
                non_symmetric_value,
                vec![mirrored_non_symmetric_value],
            );

            vec![
                empty_list,
                one_element_match,
                one_element_no_match,
                two_elements_match_first,
                two_elements_match_last,
                two_elements_no_match,
                two_elements_both_match,
                no_match_on_inverted_value_unless_size_1,
            ]
        }
    }

    #[test]
    fn rust_shadow() {
        for element_type in [
            DataType::Bfe,
            DataType::U32,
            DataType::U64,
            DataType::Xfe,
            DataType::U128,
            DataType::Digest,
            DataType::Tuple(vec![DataType::Digest, DataType::Digest]),
        ] {
            ShadowedFunction::new(Contains::new(element_type)).test()
        }
    }

    #[test]
    fn contains_returns_true_on_contained_value() {
        let snippet = Contains::new(DataType::U64);
        let a_u64_element = bfe_vec![2, 3];
        let u64_list = vec![a_u64_element.clone()];
        let init_state = snippet.prepare_state(bfe!(0), a_u64_element, u64_list);
        let nd = NonDeterminism::default().with_ram(init_state.memory);

        let expected_final_stack = [snippet.init_stack_for_isolated_run(), bfe_vec![1]].concat();

        test_rust_equivalence_given_complete_state(
            &ShadowedFunction::new(snippet),
            &init_state.stack,
            &[],
            &nd,
            &None,
            Some(&expected_final_stack),
        );
    }

    #[test]
    fn contains_returns_false_on_mirrored_value() {
        let snippet = Contains::new(DataType::U64);
        let a_u64_element = bfe_vec![2, 3];
        let mirrored_u64_element = bfe_vec![3, 2];
        let init_state = snippet.prepare_state(bfe!(0), a_u64_element, vec![mirrored_u64_element]);
        let nd = NonDeterminism::default().with_ram(init_state.memory);

        let expected_final_stack = [snippet.init_stack_for_isolated_run(), bfe_vec![0]].concat();

        // Run the very snippet whose state was prepared above.
        test_rust_equivalence_given_complete_state(
            &ShadowedFunction::new(snippet),
            &init_state.stack,
            &[],
            &nd,
            &None,
            Some(&expected_final_stack),
        );
    }
}
359
#[cfg(test)]
mod benches {
    use super::*;
    use crate::test_prelude::*;

    /// Benchmarks the snippet for a two-word (`U64`) and a five-word
    /// (`Digest`) element type.
    #[test]
    fn benchmark() {
        [DataType::U64, DataType::Digest]
            .into_iter()
            .map(Contains::new)
            .for_each(|snippet| {
                ShadowedFunction::new(snippet).bench();
            });
    }
}