1use triton_vm::prelude::*;
2
3use crate::list::get::Get;
4use crate::prelude::*;
5
6#[derive(Debug, Clone, Eq, PartialEq, Hash)]
31pub struct Contains {
32 element_type: DataType,
33}
34
35impl Contains {
36 pub fn new(element_type: DataType) -> Self {
44 Get::assert_element_type_is_supported(&element_type);
45
46 Self { element_type }
47 }
48}
49
50impl BasicSnippet for Contains {
51 fn parameters(&self) -> Vec<(DataType, String)> {
52 let element_type = self.element_type.clone();
53 let list_type = DataType::List(Box::new(element_type.clone()));
54
55 vec![
56 (list_type, "self".to_owned()),
57 (element_type, "needle".to_owned()),
58 ]
59 }
60
61 fn return_values(&self) -> Vec<(DataType, String)> {
62 vec![(DataType::Bool, "match_found".to_owned())]
63 }
64
65 fn entrypoint(&self) -> String {
66 let element_type = self.element_type.label_friendly_name();
67 format!("tasmlib_list_contains___{element_type}")
68 }
69
    /// Emits the Triton assembly for this snippet.
    ///
    /// ```text
    /// BEFORE: _ *list [needle]
    /// AFTER:  _ match_found
    /// ```
    ///
    /// The generated code consists of two labelled parts:
    ///
    /// - the entrypoint, which moves the needle from the stack into a static
    ///   memory region obtained from `library.kmalloc`, pushes the initial
    ///   `match_found = 0`, computes `*list + len · element_size`, and calls
    ///   the loop;
    /// - the loop, which returns once the running pointer equals `*list` or
    ///   `match_found` is non-zero, and otherwise reads one element, compares
    ///   it to the stored needle, and overwrites `match_found` with the result.
    ///
    /// The stack annotations inside the assembly assume standard Triton VM
    /// instruction semantics and that the code generators of `DataType`
    /// behave as their names suggest.
    /// NOTE(review): confirm against their definitions.
    ///
    /// # Panics
    ///
    /// Panics if the element type's stack size cannot be converted into the
    /// integer type expected by `kmalloc` (the `try_into().unwrap()` below).
    fn code(&self, library: &mut Library) -> Vec<LabelledInstruction> {
        // presumably cannot fail for element types accepted by `Self::new` —
        // TODO confirm against `Get::assert_element_type_is_supported`
        let element_size = self.element_type.stack_size().try_into().unwrap();

        // static memory region that holds the needle while the list is scanned
        let needle_alloc = library.kmalloc(element_size);

        let entrypoint = self.entrypoint();
        let loop_label = format!("{entrypoint}_loop");

        // multiplying by 1 is a no-op, so no instructions are emitted then
        let mul_with_element_size = match element_size {
            1 => triton_asm!(),
            n => triton_asm!(push {n} mul),
        };

        triton_asm!(
            // BEFORE: _ *list [needle]
            // AFTER:  _ match_found
            {entrypoint}:
                // move the needle from the stack into static memory
                push {needle_alloc.write_address()}
                {&self.element_type.write_value_to_memory_leave_pointer()}
                pop 1
                // _ *list

                push 0
                hint match_found: bool = stack[0]
                // _ *list match_found

                pick 1
                dup 0
                // _ match_found *list *list

                read_mem 1
                // _ match_found *list len (*list - 1)

                addi 1
                pick 1
                {&mul_with_element_size}
                // _ match_found *list *list (len · element_size)

                add
                // _ match_found *list (*list + len · element_size)

                call {loop_label}
                // _ match_found *list *ptr

                pop 2
                return

            // INVARIANT: _ match_found *list *ptr
            //
            // `*ptr` starts at `*list + len · element_size`; the loop ends
            // when it equals `*list`. Presumably `*ptr` always addresses the
            // last word of the next element to inspect — verify against the
            // memory-reading helper.
            {loop_label}:
                // terminate if all elements were inspected or a match was found
                dup 1
                dup 1
                eq
                // _ match_found *list *ptr (*list == *ptr)

                dup 3
                add
                // _ match_found *list *ptr ((*list == *ptr) + match_found)

                skiz return
                // _ match_found *list *ptr

                // load the next list element
                {&self.element_type.read_value_from_memory_leave_pointer()}
                // _ match_found *list [element] *ptr'

                place {self.element_type.stack_size()}
                // _ match_found *list *ptr' [element]

                // load the needle from static memory
                push {needle_alloc.read_address()}
                {&self.element_type.read_value_from_memory_pop_pointer()}
                // _ match_found *list *ptr' [element] [needle]

                {&self.element_type.compare()}
                // _ match_found *list *ptr' (element == needle)

                // replace the old `match_found` with the comparison result
                swap 3
                // _ (element == needle) *list *ptr' match_found

                pop 1
                // _ (element == needle) *list *ptr'

                recurse
        )
    }
135}
136
137#[cfg(test)]
138mod tests {
139 use super::*;
140 use crate::library::STATIC_MEMORY_FIRST_ADDRESS;
141 use crate::rust_shadowing_helper_functions::list::load_list_unstructured;
142 use crate::test_helpers::test_rust_equivalence_given_complete_state;
143 use crate::test_prelude::*;
144
145 impl Contains {
146 fn static_pointer_isolated_run(&self) -> BFieldElement {
147 STATIC_MEMORY_FIRST_ADDRESS - bfe!(self.element_type.stack_size()) + bfe!(1)
148 }
149
150 fn prepare_state(
151 &self,
152 list_pointer: BFieldElement,
153 mut needle: Vec<BFieldElement>,
154 haystack_elements: Vec<Vec<BFieldElement>>,
155 ) -> FunctionInitialState {
156 let mut memory: HashMap<BFieldElement, BFieldElement> = HashMap::default();
157 let list_length = haystack_elements.len();
158 memory.insert(list_pointer, bfe!(list_length));
159 let mut word_pointer = list_pointer;
160 word_pointer.increment();
161 for rand_elem in haystack_elements.iter() {
162 for word in rand_elem {
163 memory.insert(word_pointer, *word);
164 word_pointer.increment();
165 }
166 }
167
168 needle.reverse();
169 let init_stack = [
170 self.init_stack_for_isolated_run(),
171 vec![list_pointer],
172 needle,
173 ]
174 .concat();
175 FunctionInitialState {
176 stack: init_stack,
177 memory,
178 }
179 }
180 }
181
182 impl Function for Contains {
183 fn rust_shadow(
184 &self,
185 stack: &mut Vec<BFieldElement>,
186 memory: &mut HashMap<BFieldElement, BFieldElement>,
187 ) -> Result<(), RustShadowError> {
188 let needle = (0..self.element_type.stack_size())
189 .map(|_| stack.pop().ok_or(RustShadowError::StackUnderflow))
190 .try_collect()?;
191
192 let haystack_list_ptr = stack.pop().ok_or(RustShadowError::StackUnderflow)?;
193 let haystack_elems =
194 load_list_unstructured(self.element_type.stack_size(), haystack_list_ptr, memory)?;
195
196 stack.push(bfe!(haystack_elems.contains(&needle) as u32));
197
198 let mut static_pointer = self.static_pointer_isolated_run();
200 for word in needle {
201 memory.insert(static_pointer, word);
202 static_pointer.increment();
203 }
204
205 Ok(())
206 }
207
208 fn pseudorandom_initial_state(
209 &self,
210 seed: [u8; 32],
211 bench_case: Option<BenchmarkCase>,
212 ) -> FunctionInitialState {
213 let mut rng: StdRng = StdRng::from_seed(seed);
214 let list_length = match bench_case {
215 Some(BenchmarkCase::CommonCase) => 100,
216 Some(BenchmarkCase::WorstCase) => 400,
217 None => rng.random_range(1..400),
218 };
219 let haystack_elements = (0..list_length)
220 .map(|_| self.element_type.seeded_random_element(&mut rng))
221 .collect_vec();
222
223 let list_pointer: BFieldElement = rng.random();
224
225 let needle = match bench_case {
226 Some(BenchmarkCase::CommonCase) => haystack_elements[list_length / 2].clone(),
227 Some(BenchmarkCase::WorstCase) => haystack_elements[list_length / 2].clone(),
228 None => {
229 if rng.random() {
231 haystack_elements
232 .choose(&mut rng)
233 .as_ref()
234 .unwrap()
235 .to_owned()
236 .to_owned()
237 } else {
238 self.element_type.seeded_random_element(&mut rng)
243 }
244 }
245 };
246
247 self.prepare_state(list_pointer, needle, haystack_elements)
248 }
249
250 fn corner_case_initial_states(&self) -> Vec<FunctionInitialState> {
251 let empty_list =
252 self.prepare_state(bfe!(1), bfe_vec![1; self.element_type.stack_size()], vec![]);
253
254 let an_element = bfe_vec![42; self.element_type.stack_size()];
255 let another_element = bfe_vec![420; self.element_type.stack_size()];
256 let a_pointer = bfe!(42);
257 let one_element_match =
258 self.prepare_state(a_pointer, an_element.clone(), vec![an_element.clone()]);
259 let one_element_no_match =
260 self.prepare_state(a_pointer, an_element.clone(), vec![another_element.clone()]);
261 let two_elements_match_first = self.prepare_state(
262 a_pointer,
263 an_element.clone(),
264 vec![an_element.clone(), another_element.clone()],
265 );
266 let two_elements_match_last = self.prepare_state(
267 a_pointer,
268 an_element.clone(),
269 vec![another_element.clone(), an_element.clone()],
270 );
271 let two_elements_no_match = self.prepare_state(
272 a_pointer,
273 an_element.clone(),
274 vec![another_element.clone(), another_element.clone()],
275 );
276 let two_elements_both_match = self.prepare_state(
277 a_pointer,
278 an_element.clone(),
279 vec![an_element.clone(), an_element.clone()],
280 );
281
282 let non_symmetric_value = (0..self.element_type.stack_size())
283 .map(|i| bfe!(i + 200))
284 .collect_vec();
285 let mut mirrored_non_symmetric_value = non_symmetric_value.clone();
286 mirrored_non_symmetric_value.reverse();
287 let no_match_on_inverted_value_unless_size_1 = self.prepare_state(
288 a_pointer,
289 non_symmetric_value,
290 vec![mirrored_non_symmetric_value],
291 );
292
293 vec![
294 empty_list,
295 one_element_match,
296 one_element_no_match,
297 two_elements_match_first,
298 two_elements_match_last,
299 two_elements_no_match,
300 two_elements_both_match,
301 no_match_on_inverted_value_unless_size_1,
302 ]
303 }
304 }
305
306 #[macro_rules_attr::apply(test)]
307 fn rust_shadow() {
308 for element_type in [
309 DataType::Bfe,
310 DataType::U32,
311 DataType::U64,
312 DataType::Xfe,
313 DataType::U128,
314 DataType::Digest,
315 DataType::Tuple(vec![DataType::Digest, DataType::Digest]),
316 ] {
317 ShadowedFunction::new(Contains::new(element_type)).test()
318 }
319 }
320
321 #[macro_rules_attr::apply(test)]
322 fn contains_returns_true_on_contained_value() {
323 let snippet = Contains::new(DataType::U64);
324 let a_u64_element = bfe_vec![2, 3];
325 let u64_list = vec![a_u64_element.clone()];
326 let init_state = snippet.prepare_state(bfe!(0), a_u64_element, u64_list);
327 let nd = NonDeterminism::default().with_ram(init_state.memory);
328
329 let expected_final_stack = [snippet.init_stack_for_isolated_run(), bfe_vec![1]].concat();
330
331 test_rust_equivalence_given_complete_state(
332 &ShadowedFunction::new(snippet),
333 &init_state.stack,
334 &[],
335 &nd,
336 &None,
337 Some(&expected_final_stack),
338 );
339 }
340
341 #[macro_rules_attr::apply(test)]
342 fn contains_returns_false_on_mirrored_value() {
343 let snippet = Contains::new(DataType::U64);
344 let a_u64_element = bfe_vec![2, 3];
345 let mirrored_u64_element = bfe_vec![3, 2];
346 let init_state = snippet.prepare_state(bfe!(0), a_u64_element, vec![mirrored_u64_element]);
347 let nd = NonDeterminism::default().with_ram(init_state.memory);
348
349 let expected_final_stack = [snippet.init_stack_for_isolated_run(), bfe_vec![0]].concat();
350
351 test_rust_equivalence_given_complete_state(
352 &ShadowedFunction::new(Contains::new(DataType::U64)),
353 &init_state.stack,
354 &[],
355 &nd,
356 &None,
357 Some(&expected_final_stack),
358 );
359 }
360}
361
#[cfg(test)]
mod benches {
    use super::*;
    use crate::test_prelude::*;

    /// Benchmarks the snippet for a two-word and a five-word element type.
    #[macro_rules_attr::apply(test)]
    fn benchmark() {
        let benchmarked_types = [DataType::U64, DataType::Digest];

        for element_type in benchmarked_types {
            let snippet = Contains::new(element_type);
            ShadowedFunction::new(snippet).bench();
        }
    }
}