Skip to main content

tasm_lib/list/
contains.rs

1use triton_vm::prelude::*;
2
3use crate::list::get::Get;
4use crate::prelude::*;
5
6/// Returns `true` if the list contains an element with the given value.
7///
8/// This operation is *O*(*n*).
9///
10/// Mirrors the `contains` method from Rust `core` as closely as possible.
11///
12/// Only supports lists with [statically sized](BFieldCodec::static_length)
13/// elements. The element's static size must be in range `1..=14`.
14///
15/// ### Behavior
16///
17/// ```text
18/// BEFORE: _ *list [needle: ElementType]
19/// AFTER:  _ [needle ∈ list: bool]
20/// ```
21///
22/// ### Preconditions
23///
24/// - the argument `*list` points to a properly [`BFieldCodec`]-encoded list
25/// - all input arguments are properly [`BFieldCodec`] encoded
26///
27/// ### Postconditions
28///
29/// - the output is properly [`BFieldCodec`] encoded
30#[derive(Debug, Clone, Eq, PartialEq, Hash)]
31pub struct Contains {
32    element_type: DataType,
33}
34
35impl Contains {
36    /// # Panics
37    ///
38    /// Panics
39    /// - if the element has [dynamic length][BFieldCodec::static_length], or
40    /// - if the static length is 0, or
41    /// - if the static length is larger than or equal to 15.
42    // Requirement “static length < 15” is needed for comparing elements.
43    pub fn new(element_type: DataType) -> Self {
44        Get::assert_element_type_is_supported(&element_type);
45
46        Self { element_type }
47    }
48}
49
50impl BasicSnippet for Contains {
51    fn parameters(&self) -> Vec<(DataType, String)> {
52        let element_type = self.element_type.clone();
53        let list_type = DataType::List(Box::new(element_type.clone()));
54
55        vec![
56            (list_type, "self".to_owned()),
57            (element_type, "needle".to_owned()),
58        ]
59    }
60
61    fn return_values(&self) -> Vec<(DataType, String)> {
62        vec![(DataType::Bool, "match_found".to_owned())]
63    }
64
65    fn entrypoint(&self) -> String {
66        let element_type = self.element_type.label_friendly_name();
67        format!("tasmlib_list_contains___{element_type}")
68    }
69
70    fn code(&self, library: &mut Library) -> Vec<LabelledInstruction> {
71        // unwrap is fine: Self::new checks range of stack size
72        let element_size = self.element_type.stack_size().try_into().unwrap();
73        let needle_alloc = library.kmalloc(element_size);
74
75        let entrypoint = self.entrypoint();
76        let loop_label = format!("{entrypoint}_loop");
77        let mul_with_element_size = match element_size {
78            1 => triton_asm!(), // no-op
79            n => triton_asm!(push {n} mul),
80        };
81
82        triton_asm!(
83            // BEFORE: _ *list [value]
84            // AFTER:  _ match_found
85            {entrypoint}:
86                push {needle_alloc.write_address()}
87                {&self.element_type.write_value_to_memory_leave_pointer()}
88                pop 1           // _ *list
89
90                push 0          hint match_found: bool = stack[0]
91                pick 1          // _ 0 *list
92
93                dup 0
94                read_mem 1      // _ 0 *list list_len (*list - 1)
95                addi 1          // _ 0 *list list_len *list
96                pick 1          // _ 0 *list *list list_len
97                {&mul_with_element_size}
98                                // _ 0 *list *list (list_len * elem_size)
99                add             // _ 0 *list *list_last_word
100
101                call {loop_label}
102                                // _ match_found *list *list_last_word
103                pop 2           // _ match_found
104
105                return
106
107            // INVARIANT: _ match_found *list *list[i]
108            {loop_label}:
109                /* loop header – all elements checked, or match found? */
110                dup 1
111                dup 1
112                eq              // _ match_found *list *list[i] (*list == *list[i])
113                dup 3
114                add             // _ match_found *list *list[i] ((*list == *list[i]) || match_found)
115                skiz return     // _ 0           *list *list[i]
116
117
118                /* Loop body */
119                {&self.element_type.read_value_from_memory_leave_pointer()}
120                                // _ 0 *list [haystack_element] *list[i-1]
121                place {self.element_type.stack_size()}
122                                // _ 0 *list *list[i-1] [haystack_element]
123
124                push {needle_alloc.read_address()}
125                {&self.element_type.read_value_from_memory_pop_pointer()}
126                                // _ 0 *list *list[i-1] [haystack_element] [needle]
127                {&self.element_type.compare()}
128                                // _ 0 *list *list[i-1] (haystack_element == needle)
129
130                swap 3
131                pop 1           // _ (haystack_element == needle) *list *list[i-1]
132                recurse
133        )
134    }
135}
136
#[cfg(test)]
mod tests {
    use super::*;
    use crate::library::STATIC_MEMORY_FIRST_ADDRESS;
    use crate::rust_shadowing_helper_functions::list::load_list_unstructured;
    use crate::test_helpers::test_rust_equivalence_given_complete_state;
    use crate::test_prelude::*;

    impl Contains {
        /// First (lowest) address at which the Rust shadow writes the needle
        /// in an isolated run, mirroring the snippet's `kmalloc` allocation.
        fn static_pointer_isolated_run(&self) -> BFieldElement {
            STATIC_MEMORY_FIRST_ADDRESS - bfe!(self.element_type.stack_size()) + bfe!(1)
        }

        /// Builds an initial state in which `haystack_elements` is stored as a
        /// list at `list_pointer` and `needle` lies on top of the stack, with
        /// its first word topmost.
        fn prepare_state(
            &self,
            list_pointer: BFieldElement,
            mut needle: Vec<BFieldElement>,
            haystack_elements: Vec<Vec<BFieldElement>>,
        ) -> FunctionInitialState {
            let mut memory: HashMap<BFieldElement, BFieldElement> = HashMap::default();
            memory.insert(list_pointer, bfe!(haystack_elements.len()));

            // The elements follow the length indicator, word by word.
            let mut word_pointer = list_pointer;
            for &word in haystack_elements.iter().flatten() {
                word_pointer.increment();
                memory.insert(word_pointer, word);
            }

            // Push the needle in reverse so that its first word ends up on top.
            needle.reverse();
            let init_stack = [
                self.init_stack_for_isolated_run(),
                vec![list_pointer],
                needle,
            ]
            .concat();

            FunctionInitialState {
                stack: init_stack,
                memory,
            }
        }
    }

    impl Function for Contains {
        fn rust_shadow(
            &self,
            stack: &mut Vec<BFieldElement>,
            memory: &mut HashMap<BFieldElement, BFieldElement>,
        ) -> Result<(), RustShadowError> {
            // The needle's first word is on top of the stack.
            let needle = (0..self.element_type.stack_size())
                .map(|_| stack.pop().ok_or(RustShadowError::StackUnderflow))
                .try_collect()?;

            let haystack_list_ptr = stack.pop().ok_or(RustShadowError::StackUnderflow)?;
            let haystack_elems =
                load_list_unstructured(self.element_type.stack_size(), haystack_list_ptr, memory)?;

            stack.push(bfe!(haystack_elems.contains(&needle) as u32));

            // Mirror the TASM code's side effect: it writes the needle to static memory.
            let mut static_pointer = self.static_pointer_isolated_run();
            for word in needle {
                memory.insert(static_pointer, word);
                static_pointer.increment();
            }

            Ok(())
        }

        fn pseudorandom_initial_state(
            &self,
            seed: [u8; 32],
            bench_case: Option<BenchmarkCase>,
        ) -> FunctionInitialState {
            let mut rng: StdRng = StdRng::from_seed(seed);
            let list_length = match bench_case {
                Some(BenchmarkCase::CommonCase) => 100,
                Some(BenchmarkCase::WorstCase) => 400,
                None => rng.random_range(1..400),
            };
            let haystack_elements = (0..list_length)
                .map(|_| self.element_type.seeded_random_element(&mut rng))
                .collect_vec();

            let list_pointer: BFieldElement = rng.random();

            let needle = match bench_case {
                // Benchmarks differ in list length only; the needle sits in the middle.
                Some(BenchmarkCase::CommonCase | BenchmarkCase::WorstCase) => {
                    haystack_elements[list_length / 2].clone()
                }
                None => {
                    if rng.random() {
                        // `list_length` is never 0, so an element always exists.
                        haystack_elements.choose(&mut rng).unwrap().clone()
                    } else {
                        // Will create a false positive with rate
                        // $ list_length / element-type-value-space $. But
                        // since the rust-shadowing agrees with the TASM code,
                        // the test will not fail.
                        self.element_type.seeded_random_element(&mut rng)
                    }
                }
            };

            self.prepare_state(list_pointer, needle, haystack_elements)
        }

        fn corner_case_initial_states(&self) -> Vec<FunctionInitialState> {
            let empty_list =
                self.prepare_state(bfe!(1), bfe_vec![1; self.element_type.stack_size()], vec![]);

            let an_element = bfe_vec![42; self.element_type.stack_size()];
            let another_element = bfe_vec![420; self.element_type.stack_size()];
            let a_pointer = bfe!(42);
            let one_element_match =
                self.prepare_state(a_pointer, an_element.clone(), vec![an_element.clone()]);
            let one_element_no_match =
                self.prepare_state(a_pointer, an_element.clone(), vec![another_element.clone()]);
            let two_elements_match_first = self.prepare_state(
                a_pointer,
                an_element.clone(),
                vec![an_element.clone(), another_element.clone()],
            );
            let two_elements_match_last = self.prepare_state(
                a_pointer,
                an_element.clone(),
                vec![another_element.clone(), an_element.clone()],
            );
            let two_elements_no_match = self.prepare_state(
                a_pointer,
                an_element.clone(),
                vec![another_element.clone(), another_element.clone()],
            );
            let two_elements_both_match = self.prepare_state(
                a_pointer,
                an_element.clone(),
                vec![an_element.clone(), an_element.clone()],
            );

            // Guards against comparing the needle's words in the wrong order.
            let non_symmetric_value = (0..self.element_type.stack_size())
                .map(|i| bfe!(i + 200))
                .collect_vec();
            let mut mirrored_non_symmetric_value = non_symmetric_value.clone();
            mirrored_non_symmetric_value.reverse();
            let no_match_on_inverted_value_unless_size_1 = self.prepare_state(
                a_pointer,
                non_symmetric_value,
                vec![mirrored_non_symmetric_value],
            );

            vec![
                empty_list,
                one_element_match,
                one_element_no_match,
                two_elements_match_first,
                two_elements_match_last,
                two_elements_no_match,
                two_elements_both_match,
                no_match_on_inverted_value_unless_size_1,
            ]
        }
    }

    /// Runs `snippet` from `init_state`, checks TASM/Rust-shadow equivalence,
    /// and asserts that the final stack is the isolated-run stack followed by
    /// `expected_match_found`.
    fn assert_match_found(
        snippet: Contains,
        init_state: FunctionInitialState,
        expected_match_found: bool,
    ) {
        let nd = NonDeterminism::default().with_ram(init_state.memory);
        let expected_final_stack = [
            snippet.init_stack_for_isolated_run(),
            vec![bfe!(u32::from(expected_match_found))],
        ]
        .concat();

        test_rust_equivalence_given_complete_state(
            &ShadowedFunction::new(snippet),
            &init_state.stack,
            &[],
            &nd,
            &None,
            Some(&expected_final_stack),
        );
    }

    #[macro_rules_attr::apply(test)]
    fn rust_shadow() {
        for element_type in [
            DataType::Bfe,
            DataType::U32,
            DataType::U64,
            DataType::Xfe,
            DataType::U128,
            DataType::Digest,
            DataType::Tuple(vec![DataType::Digest, DataType::Digest]),
        ] {
            ShadowedFunction::new(Contains::new(element_type)).test()
        }
    }

    #[macro_rules_attr::apply(test)]
    fn contains_returns_true_on_contained_value() {
        let snippet = Contains::new(DataType::U64);
        let a_u64_element = bfe_vec![2, 3];
        let u64_list = vec![a_u64_element.clone()];
        let init_state = snippet.prepare_state(bfe!(0), a_u64_element, u64_list);

        assert_match_found(snippet, init_state, true);
    }

    #[macro_rules_attr::apply(test)]
    fn contains_returns_false_on_mirrored_value() {
        let snippet = Contains::new(DataType::U64);
        let a_u64_element = bfe_vec![2, 3];
        let mirrored_u64_element = bfe_vec![3, 2];
        let init_state = snippet.prepare_state(bfe!(0), a_u64_element, vec![mirrored_u64_element]);

        assert_match_found(snippet, init_state, false);
    }
}
361
#[cfg(test)]
mod benches {
    use super::*;
    use crate::test_prelude::*;

    /// Benchmarks the snippet for two representative element types.
    #[macro_rules_attr::apply(test)]
    fn benchmark() {
        let element_types = [DataType::U64, DataType::Digest];
        for element_type in element_types {
            let snippet = Contains::new(element_type);
            ShadowedFunction::new(snippet).bench();
        }
    }
}
369        for element_type in [DataType::U64, DataType::Digest] {
370            ShadowedFunction::new(Contains::new(element_type)).bench();
371        }
372    }
373}