1use triton_vm::prelude::*;
2
3use crate::list::get::Get;
4use crate::prelude::*;
5
6#[derive(Debug, Clone, Eq, PartialEq, Hash)]
31pub struct Contains {
32 element_type: DataType,
33}
34
35impl Contains {
36 pub fn new(element_type: DataType) -> Self {
44 Get::assert_element_type_is_supported(&element_type);
45
46 Self { element_type }
47 }
48}
49
50impl BasicSnippet for Contains {
51 fn inputs(&self) -> Vec<(DataType, String)> {
52 let element_type = self.element_type.clone();
53 let list_type = DataType::List(Box::new(element_type.clone()));
54
55 vec![
56 (list_type, "self".to_owned()),
57 (element_type, "needle".to_owned()),
58 ]
59 }
60
61 fn outputs(&self) -> Vec<(DataType, String)> {
62 vec![(DataType::Bool, "match_found".to_owned())]
63 }
64
65 fn entrypoint(&self) -> String {
66 let element_type = self.element_type.label_friendly_name();
67 format!("tasmlib_list_contains___{element_type}")
68 }
69
70 fn code(&self, library: &mut Library) -> Vec<LabelledInstruction> {
71 let element_size = self.element_type.stack_size().try_into().unwrap();
73 let needle_alloc = library.kmalloc(element_size);
74
75 let entrypoint = self.entrypoint();
76 let loop_label = format!("{entrypoint}_loop");
77 let mul_with_element_size = match element_size {
78 1 => triton_asm!(), n => triton_asm!(push {n} mul),
80 };
81
82 triton_asm!(
83 {entrypoint}:
86 push {needle_alloc.write_address()}
87 {&self.element_type.write_value_to_memory_leave_pointer()}
88 pop 1 push 0 hint match_found: bool = stack[0]
91 pick 1 dup 0
94 read_mem 1 addi 1 pick 1 {&mul_with_element_size}
98 add call {loop_label}
102 pop 2 return
106
107 {loop_label}:
109 dup 1
111 dup 1
112 eq dup 3
114 add skiz return {&self.element_type.read_value_from_memory_leave_pointer()}
120 place {self.element_type.stack_size()}
122 push {needle_alloc.read_address()}
125 {&self.element_type.read_value_from_memory_pop_pointer()}
126 {&self.element_type.compare()}
128 swap 3
131 pop 1 recurse
133 )
134 }
135}
136
#[cfg(test)]
mod tests {
    use super::*;
    use crate::library::STATIC_MEMORY_FIRST_ADDRESS;
    use crate::rust_shadowing_helper_functions::list::load_list_unstructured;
    use crate::test_helpers::test_rust_equivalence_given_complete_state;
    use crate::test_prelude::*;

    impl Contains {
        /// Address of the needle's first word in static memory, when the snippet
        /// runs in isolation (its `kmalloc` is the only static allocation).
        /// NOTE(review): mirrors the library's static-allocation layout; keep in sync.
        fn static_pointer_isolated_run(&self) -> BFieldElement {
            STATIC_MEMORY_FIRST_ADDRESS - bfe!(self.element_type.stack_size()) + bfe!(1)
        }

        /// Build an initial state for the snippet.
        ///
        /// - `haystack_elements` is stored as a list at `list_pointer`: the length
        ///   word first, then every element's words in order.
        /// - `list_pointer` and `needle` are pushed onto the stack, with the
        ///   needle's first word on top.
        fn prepare_state(
            &self,
            list_pointer: BFieldElement,
            mut needle: Vec<BFieldElement>,
            haystack_elements: Vec<Vec<BFieldElement>>,
        ) -> FunctionInitialState {
            let mut memory: HashMap<BFieldElement, BFieldElement> = HashMap::default();
            memory.insert(list_pointer, bfe!(haystack_elements.len()));

            let mut word_pointer = list_pointer;
            for word in haystack_elements.iter().flatten() {
                word_pointer.increment();
                memory.insert(word_pointer, *word);
            }

            // Push in reverse so that the needle's first word ends up on top.
            needle.reverse();
            let init_stack = [
                self.init_stack_for_isolated_run(),
                vec![list_pointer],
                needle,
            ]
            .concat();

            FunctionInitialState {
                stack: init_stack,
                memory,
            }
        }
    }

    impl Function for Contains {
        fn rust_shadow(
            &self,
            stack: &mut Vec<BFieldElement>,
            memory: &mut HashMap<BFieldElement, BFieldElement>,
        ) {
            let element_size = self.element_type.stack_size();

            // The needle's first word is on top, so popping yields its words in memory order.
            let needle = (0..element_size)
                .map(|_| stack.pop().unwrap())
                .collect_vec();

            let haystack_list_ptr = stack.pop().unwrap();
            let haystack_elems = load_list_unstructured(element_size, haystack_list_ptr, memory);

            stack.push(bfe!(haystack_elems.contains(&needle) as u32));

            // Mirror the snippet's side effect of storing the needle in static memory.
            let mut static_pointer = self.static_pointer_isolated_run();
            for word in needle {
                memory.insert(static_pointer, word);
                static_pointer.increment();
            }
        }

        fn pseudorandom_initial_state(
            &self,
            seed: [u8; 32],
            bench_case: Option<BenchmarkCase>,
        ) -> FunctionInitialState {
            let mut rng: StdRng = StdRng::from_seed(seed);
            let list_length = match bench_case {
                Some(BenchmarkCase::CommonCase) => 100,
                Some(BenchmarkCase::WorstCase) => 400,
                None => rng.random_range(1..400),
            };
            let haystack_elements = (0..list_length)
                .map(|_| self.element_type.seeded_random_element(&mut rng))
                .collect_vec();

            let list_pointer: BFieldElement = rng.random();

            let needle = match bench_case {
                // A match roughly halfway through the scan.
                Some(BenchmarkCase::CommonCase) => haystack_elements[list_length / 2].clone(),
                // The snippet scans from the last element toward the first and stops
                // at the first match. A match on the first element therefore forces
                // a full scan.
                Some(BenchmarkCase::WorstCase) => haystack_elements[0].clone(),
                None => {
                    if rng.random() {
                        // Guaranteed match: reuse an element of the list.
                        haystack_elements
                            .choose(&mut rng)
                            .expect("list length is at least 1")
                            .clone()
                    } else {
                        // Most likely no match. A random element may still coincide
                        // with a list element, especially for small element types.
                        self.element_type.seeded_random_element(&mut rng)
                    }
                }
            };

            self.prepare_state(list_pointer, needle, haystack_elements)
        }

        fn corner_case_initial_states(&self) -> Vec<FunctionInitialState> {
            let empty_list =
                self.prepare_state(bfe!(1), bfe_vec![1; self.element_type.stack_size()], vec![]);

            let an_element = bfe_vec![42; self.element_type.stack_size()];
            let another_element = bfe_vec![420; self.element_type.stack_size()];
            let a_pointer = bfe!(42);
            let one_element_match =
                self.prepare_state(a_pointer, an_element.clone(), vec![an_element.clone()]);
            let one_element_no_match =
                self.prepare_state(a_pointer, an_element.clone(), vec![another_element.clone()]);
            let two_elements_match_first = self.prepare_state(
                a_pointer,
                an_element.clone(),
                vec![an_element.clone(), another_element.clone()],
            );
            let two_elements_match_last = self.prepare_state(
                a_pointer,
                an_element.clone(),
                vec![another_element.clone(), an_element.clone()],
            );
            let two_elements_no_match = self.prepare_state(
                a_pointer,
                an_element.clone(),
                vec![another_element.clone(), another_element.clone()],
            );
            let two_elements_both_match = self.prepare_state(
                a_pointer,
                an_element.clone(),
                vec![an_element.clone(), an_element.clone()],
            );

            // Detects word-order mix-ups: a multi-word value must not match its mirror image.
            let non_symmetric_value = (0..self.element_type.stack_size())
                .map(|i| bfe!(i + 200))
                .collect_vec();
            let mut mirrored_non_symmetric_value = non_symmetric_value.clone();
            mirrored_non_symmetric_value.reverse();
            let no_match_on_inverted_value_unless_size_1 = self.prepare_state(
                a_pointer,
                non_symmetric_value,
                vec![mirrored_non_symmetric_value],
            );

            vec![
                empty_list,
                one_element_match,
                one_element_no_match,
                two_elements_match_first,
                two_elements_match_last,
                two_elements_no_match,
                two_elements_both_match,
                no_match_on_inverted_value_unless_size_1,
            ]
        }
    }

    #[test]
    fn rust_shadow() {
        for element_type in [
            DataType::Bfe,
            DataType::U32,
            DataType::U64,
            DataType::Xfe,
            DataType::U128,
            DataType::Digest,
            DataType::Tuple(vec![DataType::Digest, DataType::Digest]),
        ] {
            ShadowedFunction::new(Contains::new(element_type)).test();
        }
    }

    #[test]
    fn contains_returns_true_on_contained_value() {
        let snippet = Contains::new(DataType::U64);
        let a_u64_element = bfe_vec![2, 3];
        let u64_list = vec![a_u64_element.clone()];
        let init_state = snippet.prepare_state(bfe!(0), a_u64_element, u64_list);
        let nd = NonDeterminism::default().with_ram(init_state.memory);

        let expected_final_stack = [snippet.init_stack_for_isolated_run(), bfe_vec![1]].concat();

        test_rust_equivalence_given_complete_state(
            &ShadowedFunction::new(snippet),
            &init_state.stack,
            &[],
            &nd,
            &None,
            Some(&expected_final_stack),
        );
    }

    #[test]
    fn contains_returns_false_on_mirrored_value() {
        let snippet = Contains::new(DataType::U64);
        let a_u64_element = bfe_vec![2, 3];
        let mirrored_u64_element = bfe_vec![3, 2];
        let init_state = snippet.prepare_state(bfe!(0), a_u64_element, vec![mirrored_u64_element]);
        let nd = NonDeterminism::default().with_ram(init_state.memory);

        let expected_final_stack = [snippet.init_stack_for_isolated_run(), bfe_vec![0]].concat();

        test_rust_equivalence_given_complete_state(
            &ShadowedFunction::new(snippet),
            &init_state.stack,
            &[],
            &nd,
            &None,
            Some(&expected_final_stack),
        );
    }
}
359
#[cfg(test)]
mod benches {
    use super::*;
    use crate::test_prelude::*;

    /// Benchmark `Contains` for `U64` and `Digest` element types, in that order.
    #[test]
    fn benchmark() {
        [DataType::U64, DataType::Digest]
            .into_iter()
            .map(Contains::new)
            .for_each(|snippet| {
                ShadowedFunction::new(snippet).bench();
            });
    }
}
371}