1use triton_vm::prelude::*;
2use twenty_first::math::x_field_element::EXTENSION_DEGREE;
3
4use crate::hashing::algebraic_hasher::hash_varlen::HashVarlen;
5use crate::prelude::*;
6
/// Decides whether two lists of `u64`s hold the same elements with the same
/// multiplicities, irrespective of order.
///
/// Lists of different length are rejected immediately. Otherwise both lists are
/// hashed to derive an extension-field challenge `α`, and the running products
/// `∏ (a_i + α)` and `∏ (b_i + α)` are compared, where each `u64` is lifted into an
/// extension-field element. Equal products are taken as multiset equality; distinct
/// multisets are misreported only if `α` happens to be a root of the difference of
/// the two product polynomials.
///
/// As a side effect, list_a's running product is left in a statically allocated
/// memory region.
///
/// ### Stack
/// BEFORE: `_ *list_a *list_b`
/// AFTER:  `_ multisets_are_equal`
#[derive(Clone, Copy, Debug)]
pub struct MultisetEqualityU64s;
9
/// Number of words a single `u64` list element occupies on the stack and in memory.
/// `code` asserts this agrees with `DataType::U64.stack_size()`.
const U64_STACK_SIZE: usize = 2;
11
12impl BasicSnippet for MultisetEqualityU64s {
13 fn parameters(&self) -> Vec<(DataType, String)> {
14 vec![
15 (DataType::List(Box::new(DataType::U64)), "list_a".to_owned()),
16 (DataType::List(Box::new(DataType::U64)), "list_b".to_owned()),
17 ]
18 }
19
20 fn return_values(&self) -> Vec<(DataType, String)> {
21 vec![(DataType::Bool, "multisets_are_equal".to_owned())]
22 }
23
24 fn entrypoint(&self) -> String {
25 "tasmlib_list_multiset_equality_u64s".to_owned()
26 }
27
28 fn code(&self, library: &mut Library) -> Vec<LabelledInstruction> {
29 let entrypoint = self.entrypoint();
30 assert_eq!(U64_STACK_SIZE, DataType::U64.stack_size());
31
32 let hash_varlen = library.import(Box::new(HashVarlen));
33 let compare_xfes = DataType::Xfe.compare();
34
35 let running_product_result_alloc = library.kmalloc(EXTENSION_DEGREE.try_into().unwrap());
36
37 let compare_lengths = triton_asm!(
38 dup 1
41 dup 1
42 read_mem 1
45 pop 1
46 swap 1
49 read_mem 1
50 pop 1
51 dup 1
54 eq
55 );
57
58 let not_equal_length_label = format!("{entrypoint}_not_equal_length");
59 let not_equal_length_code = triton_asm!(
60 {not_equal_length_label}:
61 pop 4
63
64 push 0 push 0
69
70 return
71 );
72
73 let find_challenge_indeterminate = triton_asm!(
74 dup 2
78 dup 1
79 push 1 add
82 call {hash_varlen}
86 dup 6
90 dup 6
91 push 1 add
94 call {hash_varlen}
97 hash
101 pop 2
102 );
104
105 let calculate_running_product_loop_label = format!("{entrypoint}_loop");
106 let calculate_running_product_loop_code = triton_asm!(
107 {calculate_running_product_loop_label}:
109
110 push 0
111 dup 6
112 read_mem {U64_STACK_SIZE}
113 swap 9
114 pop 1
115 dup 12
119 dup 12
120 dup 12
121 xx_add
122 xx_mul
123 recurse_or_return
127 );
128
129 let equal_length_label = format!("{entrypoint}_equal_length");
130 let equal_length_code = triton_asm!(
131 {equal_length_label}:
132 push {U64_STACK_SIZE}
135 mul
136 {&find_challenge_indeterminate}
143 dup 5
146 dup 6
147 dup 5
148 add
149 push 0
152 push 0
153 push 0
156 push 0
157 push 1
158 dup 6
162 dup 6
163 eq
164 push 0
165 eq
166 skiz call {calculate_running_product_loop_label}
167 push {running_product_result_alloc.write_address()}
171 write_mem {running_product_result_alloc.num_words()}
172 pop 5
173 dup 4
177 dup 5
178 dup 5
179 add
182 push 0
185 push 0
186 push 0
187 push 0
188 push 1
189 dup 6
192 dup 6
193 eq
194 push 0
195 eq
196 skiz call {calculate_running_product_loop_label}
199 swap 10
202 pop 1
203 swap 10
204 pop 1
205 swap 10
206 pop 5
209 pop 3
210 push {running_product_result_alloc.read_address()}
213 read_mem {running_product_result_alloc.num_words()}
214 pop 1
215 {&compare_xfes}
218 return
221 );
222
223 triton_asm!(
224 {entrypoint}:
227 {&compare_lengths}
228 push 1
231 swap 1
232 push 0
233 eq
234 skiz call {not_equal_length_label}
237 skiz call {equal_length_label}
238 return
241
242 {¬_equal_length_code}
243 {&equal_length_code}
244 {&calculate_running_product_loop_code}
245 )
246 }
247}
248
#[cfg(test)]
mod tests {
    use num::One;
    use num::Zero;

    use super::*;
    use crate::library::STATIC_MEMORY_FIRST_ADDRESS;
    use crate::list::LIST_METADATA_SIZE;
    use crate::memory::encode_to_memory;
    use crate::rust_shadowing_helper_functions;
    use crate::rust_shadowing_helper_functions::list::list_get;
    use crate::rust_shadowing_helper_functions::list::load_list_with_copy_elements;
    use crate::test_helpers::test_rust_equivalence_given_complete_state;
    use crate::test_prelude::*;

    /// Builds an RNG from a freshly drawn seed, so every run exercises new lists.
    fn freshly_seeded_rng() -> StdRng {
        let mut seed = [0u8; 32];
        let mut entropy = rand::rng();
        entropy.fill_bytes(&mut seed);
        StdRng::from_seed(seed)
    }

    /// Runs the snippet and its Rust shadow on `init_state`. Both must end with
    /// the isolated-run stack followed by the single word `expected_verdict`.
    fn assert_final_verdict(init_state: FunctionInitialState, expected_verdict: BFieldElement) {
        let snippet = MultisetEqualityU64s;
        let mut expected_final_stack = snippet.init_stack_for_isolated_run();
        expected_final_stack.push(expected_verdict);

        let non_determinism = NonDeterminism::default().with_ram(init_state.memory);
        test_rust_equivalence_given_complete_state(
            &ShadowedFunction::new(snippet),
            &init_state.stack,
            &[],
            &non_determinism,
            &None,
            Some(&expected_final_stack),
        );
    }

    #[macro_rules_attr::apply(test)]
    fn returns_true_on_multiset_equality() {
        let mut rng = freshly_seeded_rng();
        for length in (0..10).chain(std::iter::once(1000)) {
            let init_state = MultisetEqualityU64s.random_equal_multisets(length, &mut rng);
            assert_final_verdict(init_state, BFieldElement::one());
        }
    }

    #[macro_rules_attr::apply(test)]
    fn returns_false_on_multiset_inequality() {
        let mut rng = freshly_seeded_rng();
        for length in (1..10).chain(std::iter::once(1000)) {
            let init_state =
                MultisetEqualityU64s.random_same_length_mutated_elements(length, 1, 1, &mut rng);
            assert_final_verdict(init_state, BFieldElement::zero());
        }
    }

    #[macro_rules_attr::apply(test)]
    fn multiset_equality_u64s_pbt() {
        let shadowed = ShadowedFunction::new(MultisetEqualityU64s);
        shadowed.test();
    }

    impl Function for MultisetEqualityU64s {
        fn rust_shadow(
            &self,
            stack: &mut Vec<BFieldElement>,
            memory: &mut HashMap<BFieldElement, BFieldElement>,
        ) -> Result<(), RustShadowError> {
            let pointer_b = stack.pop().ok_or(RustShadowError::StackUnderflow)?;
            let pointer_a = stack.pop().ok_or(RustShadowError::StackUnderflow)?;

            let list_a = load_list_with_copy_elements::<2>(pointer_a, memory)?;
            let list_b = load_list_with_copy_elements::<2>(pointer_b, memory)?;

            let len = list_a.len();
            if len != list_b.len() {
                stack.push(BFieldElement::zero());
                return Ok(());
            }

            // Challenge: hash both lists, combine the digests, and keep the last
            // `EXTENSION_DEGREE` words of the result.
            let digest = Tip5::hash_pair(Tip5::hash(&list_b), Tip5::hash(&list_a));
            let challenge_words: [BFieldElement; EXTENSION_DEGREE] =
                digest.values()[2..Digest::LEN].try_into().unwrap();
            let indeterminate = -XFieldElement::new(challenge_words);

            let product_a = Self::running_product(pointer_a, len, indeterminate, memory)?;
            let product_b = Self::running_product(pointer_b, len, indeterminate, memory)?;

            // The TASM code leaves list_a's running product in its static-memory
            // allocation; mirror that write (presumably so the final RAM matches).
            encode_to_memory(
                memory,
                STATIC_MEMORY_FIRST_ADDRESS - bfe!(EXTENSION_DEGREE as u64 - 1),
                &product_a,
            );
            stack.push(bfe!(u64::from(product_a == product_b)));

            Ok(())
        }

        fn pseudorandom_initial_state(
            &self,
            seed: [u8; 32],
            bench_case: Option<BenchmarkCase>,
        ) -> FunctionInitialState {
            let mut rng = StdRng::from_seed(seed);
            match bench_case {
                Some(BenchmarkCase::CommonCase) => self.random_equal_multisets(90, &mut rng),
                Some(BenchmarkCase::WorstCase) => self.random_equal_multisets(360, &mut rng),
                None => self.random_state_of_random_kind(&mut rng),
            }
        }

        fn corner_case_initial_states(&self) -> Vec<FunctionInitialState> {
            let seed = [111u8; 32];
            let mut rng = StdRng::from_seed(seed);

            let lengths_0_and_1 = self.random_unequal_length_lists(0, 1, &mut rng);
            let lengths_1_and_0 = self.random_unequal_length_lists(1, 0, &mut rng);
            let both_empty = self.random_equal_multisets(0, &mut rng);
            let equal_singletons = self.random_equal_multisets(1, &mut rng);
            let identical_length_2 = self.random_equal_lists(2, &mut rng);
            let equal_with_swapped_pointers =
                self.random_equal_multisets_flipped_pointers(4, &mut rng);

            let unequal_length_1_plus_1 =
                self.random_same_length_mutated_elements(1, 1, 1, &mut rng);
            let unequal_length_1_plus_2pow32 =
                self.random_same_length_mutated_elements(1, 1, 1u64 << 32, &mut rng);
            let unequal_length_2_plus_1 =
                self.random_same_length_mutated_elements(2, 1, 1, &mut rng);
            let unequal_length_2_plus_2pow32 =
                self.random_same_length_mutated_elements(2, 1, 1u64 << 32, &mut rng);

            // Ten equal-multiset pairs for each of the lengths 2, 3, and 4, in that order.
            let mut small_equal_multisets = Vec::new();
            for length in 2..=4 {
                for _ in 0..10 {
                    small_equal_multisets.push(self.random_equal_multisets(length, &mut rng));
                }
            }

            let shared_prefix_lengths_1_2 = self.random_unequal_length_lists(1, 2, &mut rng);
            let shared_prefix_lengths_2_1 = self.random_unequal_length_lists(2, 1, &mut rng);
            let zero_padded_lengths_1_2 =
                self.random_unequal_length_lists_trailing_zeros(1, 2, &mut rng);

            let mut states = vec![
                lengths_0_and_1,
                lengths_1_and_0,
                both_empty,
                equal_singletons,
                identical_length_2,
                equal_with_swapped_pointers,
                unequal_length_1_plus_1,
                unequal_length_1_plus_2pow32,
                unequal_length_2_plus_1,
                unequal_length_2_plus_2pow32,
                shared_prefix_lengths_1_2,
                shared_prefix_lengths_2_1,
                zero_padded_lengths_1_2,
            ];
            states.extend(small_equal_multisets);
            states
        }
    }

    impl MultisetEqualityU64s {
        /// Computes `∏ (element - indeterminate)` over the first `len` `u64`s of the
        /// list at `list_pointer`. Each `u64` is lifted into the extension field
        /// with a zero top coefficient.
        fn running_product(
            list_pointer: BFieldElement,
            len: usize,
            indeterminate: XFieldElement,
            memory: &HashMap<BFieldElement, BFieldElement>,
        ) -> Result<XFieldElement, RustShadowError> {
            let mut product = XFieldElement::one();
            for index in 0..len {
                let words = list_get(list_pointer, index, memory, U64_STACK_SIZE)?;
                let lifted = XFieldElement::new([words[0], words[1], BFieldElement::zero()]);
                product *= lifted - indeterminate;
            }
            Ok(product)
        }

        /// Draws one of six kinds of input, chosen uniformly at random:
        /// - equal multisets;
        /// - identical lists;
        /// - equal multisets at swapped addresses;
        /// - same-length lists with mutated elements;
        /// - lists of unequal length with random padding;
        /// - lists of unequal length with zero padding.
        fn random_state_of_random_kind(&self, rng: &mut StdRng) -> FunctionInitialState {
            // All parameters are drawn before the kind is selected; some kinds
            // ignore some of them.
            let length = rng.random_range(0..50);
            let num_mutations = rng.random_range(0..=length);
            let mutation_translation: u64 = rng.random();
            let longer_length = length + rng.random_range(1..10);

            match rng.random_range(0..=5) {
                0 => self.random_equal_multisets(length, rng),
                1 => self.random_equal_lists(length, rng),
                2 => self.random_equal_multisets_flipped_pointers(length, rng),
                3 => self.random_same_length_mutated_elements(
                    length,
                    num_mutations,
                    mutation_translation,
                    rng,
                ),
                4 => self.random_unequal_length_lists(length, longer_length, rng),
                5 => self.random_unequal_length_lists_trailing_zeros(length, longer_length, rng),
                _ => unreachable!(),
            }
        }

        /// Returns a random list of `length` `u64`s, a random pointer for it, and a
        /// second pointer placed at least that list's full memory footprint later.
        fn list_a_and_both_pointers(
            &self,
            length: usize,
            rng: &mut StdRng,
        ) -> (Vec<u64>, BFieldElement, BFieldElement) {
            let list_a: Vec<u64> = (0..length).map(|_| rng.random()).collect();
            let pointer_a: BFieldElement = rng.random();

            // list_a occupies its metadata plus its elements; starting list_b no
            // earlier than that keeps list_b from overlapping list_a.
            let footprint_a = LIST_METADATA_SIZE + length * U64_STACK_SIZE;
            let offset: u32 = rng.random_range(footprint_a as u32..u32::MAX);
            let pointer_b = BFieldElement::new(pointer_a.value() + u64::from(offset));

            (list_a, pointer_a, pointer_b)
        }

        /// Writes both lists to a fresh memory map. The returned stack is the
        /// isolated-run stack followed by `pointer_a` and `pointer_b`.
        fn init_state(
            &self,
            pointer_a: BFieldElement,
            pointer_b: BFieldElement,
            a: Vec<u64>,
            b: Vec<u64>,
        ) -> FunctionInitialState {
            let mut memory = HashMap::default();
            for (pointer, list) in [(pointer_a, a), (pointer_b, b)] {
                rust_shadowing_helper_functions::list::list_insert(pointer, list, &mut memory);
            }

            let mut stack = self.init_stack_for_isolated_run();
            stack.extend([pointer_a, pointer_b]);
            FunctionInitialState { stack, memory }
        }

        /// list_b is the sorted version of list_a.
        fn random_equal_multisets(&self, length: usize, rng: &mut StdRng) -> FunctionInitialState {
            let (a, pointer_a, pointer_b) = self.list_a_and_both_pointers(length, rng);
            let mut b = a.clone();
            b.sort_unstable();
            self.init_state(pointer_a, pointer_b, a, b)
        }

        /// list_b is an exact copy of list_a.
        fn random_equal_lists(&self, length: usize, rng: &mut StdRng) -> FunctionInitialState {
            let (a, pointer_a, pointer_b) = self.list_a_and_both_pointers(length, rng);
            let b = a.clone();
            self.init_state(pointer_a, pointer_b, a, b)
        }

        /// Like `random_equal_multisets`, but list_b lives at the lower address and
        /// list_a (the sorted copy) at the higher one.
        fn random_equal_multisets_flipped_pointers(
            &self,
            length: usize,
            rng: &mut StdRng,
        ) -> FunctionInitialState {
            let (b, pointer_b, pointer_a) = self.list_a_and_both_pointers(length, rng);
            let mut a = b.clone();
            a.sort_unstable();
            self.init_state(pointer_a, pointer_b, a, b)
        }

        /// list_b is sorted list_a with `num_mutations` randomly chosen entries each
        /// increased (wrapping) by `mutation_translation`. The same entry may be
        /// picked more than once.
        fn random_same_length_mutated_elements(
            &self,
            length: usize,
            num_mutations: usize,
            mutation_translation: u64,
            rng: &mut StdRng,
        ) -> FunctionInitialState {
            let (a, pointer_a, pointer_b) = self.list_a_and_both_pointers(length, rng);
            let mut b = a.clone();
            b.sort_unstable();

            for _ in 0..num_mutations {
                let victim = b.choose_mut(rng).unwrap();
                *victim = victim.wrapping_add(mutation_translation);
            }

            self.init_state(pointer_a, pointer_b, a, b)
        }

        /// list_b is list_a truncated, or extended with random values, to `length_b`.
        fn random_unequal_length_lists(
            &self,
            length_a: usize,
            length_b: usize,
            rng: &mut StdRng,
        ) -> FunctionInitialState {
            assert_ne!(length_a, length_b, "Don't do this");

            let (a, pointer_a, pointer_b) = self.list_a_and_both_pointers(length_a, rng);
            let mut b = a.clone();
            b.resize_with(length_b, || rng.random());
            self.init_state(pointer_a, pointer_b, a, b)
        }

        /// list_b is list_a followed by zeros up to the strictly larger `length_b`.
        fn random_unequal_length_lists_trailing_zeros(
            &self,
            length_a: usize,
            length_b: usize,
            rng: &mut StdRng,
        ) -> FunctionInitialState {
            assert!(length_b > length_a);

            let (a, pointer_a, pointer_b) = self.list_a_and_both_pointers(length_a, rng);
            let mut b = a.clone();
            b.resize(length_b, 0);
            self.init_state(pointer_a, pointer_b, a, b)
        }
    }
}
599
#[cfg(test)]
mod benches {
    use super::*;
    use crate::test_prelude::*;

    /// Benchmarks the snippet through the `ShadowedFunction` harness.
    #[macro_rules_attr::apply(test)]
    fn benchmark() {
        let shadowed = ShadowedFunction::new(MultisetEqualityU64s);
        shadowed.bench();
    }
}