1use triton_vm::prelude::*;
2use twenty_first::math::x_field_element::EXTENSION_DEGREE;
3
4use crate::hashing::algebraic_hasher::hash_varlen::HashVarlen;
5use crate::prelude::*;
6
/// Snippet deciding whether two lists of `u64`s are equal as multisets, i.e.,
/// whether they contain the same elements with the same multiplicities,
/// irrespective of order.
///
/// Lists of different lengths are rejected immediately. For lists of equal
/// length, the check compares the running products of `(m + X)` over all
/// elements `m` of each list, where `X` is a challenge derived by hashing both
/// lists.
#[derive(Clone, Copy, Debug)]
pub struct MultisetEqualityU64s;
9
/// Number of words a single `u64` occupies, both on the stack and as a list
/// element in memory.
const U64_STACK_SIZE: usize = 2;
11
12impl BasicSnippet for MultisetEqualityU64s {
13 fn inputs(&self) -> Vec<(DataType, String)> {
14 vec![
15 (DataType::List(Box::new(DataType::U64)), "list_a".to_owned()),
16 (DataType::List(Box::new(DataType::U64)), "list_b".to_owned()),
17 ]
18 }
19
20 fn outputs(&self) -> Vec<(DataType, String)> {
21 vec![(DataType::Bool, "multisets_are_equal".to_owned())]
22 }
23
24 fn entrypoint(&self) -> String {
25 "tasmlib_list_multiset_equality_u64s".to_owned()
26 }
27
28 fn code(&self, library: &mut Library) -> Vec<LabelledInstruction> {
29 let entrypoint = self.entrypoint();
30 assert_eq!(U64_STACK_SIZE, DataType::U64.stack_size());
31
32 let hash_varlen = library.import(Box::new(HashVarlen));
33 let compare_xfes = DataType::Xfe.compare();
34
35 let running_product_result_alloc = library.kmalloc(EXTENSION_DEGREE.try_into().unwrap());
36
37 let compare_lengths = triton_asm!(
38 dup 1
41 dup 1
42 read_mem 1
45 pop 1
46 swap 1
49 read_mem 1
50 pop 1
51 dup 1
54 eq
55 );
57
58 let not_equal_length_label = format!("{entrypoint}_not_equal_length");
59 let not_equal_length_code = triton_asm!(
60 {not_equal_length_label}:
61 pop 4
63
64 push 0 push 0
69
70 return
71 );
72
73 let find_challenge_indeterminate = triton_asm!(
74 dup 2
78 dup 1
79 push 1 add
82 call {hash_varlen}
86 dup 6
90 dup 6
91 push 1 add
94 call {hash_varlen}
97 hash
101 pop 2
102 );
104
105 let calculate_running_product_loop_label = format!("{entrypoint}_loop");
106 let calculate_running_product_loop_code = triton_asm!(
107 {calculate_running_product_loop_label}:
109
110 push 0
111 dup 6
112 read_mem {U64_STACK_SIZE}
113 swap 9
114 pop 1
115 dup 12
119 dup 12
120 dup 12
121 xx_add
122 xx_mul
123 recurse_or_return
127 );
128
129 let equal_length_label = format!("{entrypoint}_equal_length");
130 let equal_length_code = triton_asm!(
131 {equal_length_label}:
132 push {U64_STACK_SIZE}
135 mul
136 {&find_challenge_indeterminate}
143 dup 5
146 dup 6
147 dup 5
148 add
149 push 0
152 push 0
153 push 0
156 push 0
157 push 1
158 dup 6
162 dup 6
163 eq
164 push 0
165 eq
166 skiz call {calculate_running_product_loop_label}
167 push {running_product_result_alloc.write_address()}
171 write_mem {running_product_result_alloc.num_words()}
172 pop 5
173 dup 4
177 dup 5
178 dup 5
179 add
182 push 0
185 push 0
186 push 0
187 push 0
188 push 1
189 dup 6
192 dup 6
193 eq
194 push 0
195 eq
196 skiz call {calculate_running_product_loop_label}
199 swap 10
202 pop 1
203 swap 10
204 pop 1
205 swap 10
206 pop 5
209 pop 3
210 push {running_product_result_alloc.read_address()}
213 read_mem {running_product_result_alloc.num_words()}
214 pop 1
215 {&compare_xfes}
218 return
221 );
222
223 triton_asm!(
224 {entrypoint}:
227 {&compare_lengths}
228 push 1
231 swap 1
232 push 0
233 eq
234 skiz call {not_equal_length_label}
237 skiz call {equal_length_label}
238 return
241
242 {¬_equal_length_code}
243 {&equal_length_code}
244 {&calculate_running_product_loop_code}
245 )
246 }
247}
248
#[cfg(test)]
mod tests {
    use num::One;
    use num::Zero;

    use super::*;
    use crate::library::STATIC_MEMORY_FIRST_ADDRESS;
    use crate::list::LIST_METADATA_SIZE;
    use crate::memory::encode_to_memory;
    use crate::rust_shadowing_helper_functions;
    use crate::rust_shadowing_helper_functions::list::load_list_with_copy_elements;
    use crate::test_helpers::test_rust_equivalence_given_complete_state;
    use crate::test_prelude::*;

    /// Returns a `StdRng` seeded from the thread-local RNG, so each test run
    /// exercises fresh inputs.
    fn freshly_seeded_rng() -> StdRng {
        let mut thread_rng = rand::rng();
        let mut seed = [0u8; 32];
        thread_rng.fill_bytes(&mut seed);
        StdRng::from_seed(seed)
    }

    /// Computes the product of `(m - indeterminate)` over all elements of `list`.
    /// Each element `m` is given as its two words `[w0, w1]` and is lifted to the
    /// extension-field element `[w0, w1, 0]`.
    fn running_product(list: &[[BFieldElement; 2]], indeterminate: XFieldElement) -> XFieldElement {
        list.iter()
            .map(|&[w0, w1]| XFieldElement::new([w0, w1, BFieldElement::zero()]) - indeterminate)
            .fold(XFieldElement::one(), |acc, factor| acc * factor)
    }

    #[test]
    fn returns_true_on_multiset_equality() {
        let snippet = MultisetEqualityU64s;
        let return_value_is_true = [
            snippet.init_stack_for_isolated_run(),
            vec![BFieldElement::one()],
        ]
        .concat();

        let mut rng = freshly_seeded_rng();

        for length in (0..10).chain(1000..1001) {
            let init_state = snippet.random_equal_multisets(length, &mut rng);
            let nd = NonDeterminism::default().with_ram(init_state.memory);
            test_rust_equivalence_given_complete_state(
                &ShadowedFunction::new(snippet),
                &init_state.stack,
                &[],
                &nd,
                &None,
                Some(&return_value_is_true),
            );
        }
    }

    #[test]
    fn returns_false_on_multiset_inequality() {
        let snippet = MultisetEqualityU64s;
        let return_value_is_false = [
            snippet.init_stack_for_isolated_run(),
            vec![BFieldElement::zero()],
        ]
        .concat();

        let mut rng = freshly_seeded_rng();

        for length in (1..10).chain(1000..1001) {
            let init_state = snippet.random_same_length_mutated_elements(length, 1, 1, &mut rng);
            let nd = NonDeterminism::default().with_ram(init_state.memory);
            test_rust_equivalence_given_complete_state(
                &ShadowedFunction::new(snippet),
                &init_state.stack,
                &[],
                &nd,
                &None,
                Some(&return_value_is_false),
            );
        }
    }

    #[test]
    fn multiset_equality_u64s_pbt() {
        ShadowedFunction::new(MultisetEqualityU64s).test()
    }

    impl Function for MultisetEqualityU64s {
        fn rust_shadow(
            &self,
            stack: &mut Vec<BFieldElement>,
            memory: &mut HashMap<BFieldElement, BFieldElement>,
        ) {
            let list_b_pointer = stack.pop().unwrap();
            let list_a_pointer = stack.pop().unwrap();

            let a: Vec<[BFieldElement; 2]> = load_list_with_copy_elements(list_a_pointer, memory);
            let b: Vec<[BFieldElement; 2]> = load_list_with_copy_elements(list_b_pointer, memory);

            if a.len() != b.len() {
                stack.push(BFieldElement::zero());
                return;
            }

            // Derive the challenge from both lists. This mirrors the assembly's
            // `hash_varlen`, `hash`, `pop 2` sequence.
            let a_digest = Tip5::hash(&a);
            let b_digest = Tip5::hash(&b);
            let indeterminate = Tip5::hash_pair(b_digest, a_digest);
            let indeterminate =
                -XFieldElement::new(indeterminate.values()[2..Digest::LEN].try_into().unwrap());

            let running_product_a = running_product(&a, indeterminate);
            let running_product_b = running_product(&b, indeterminate);

            // The assembly leaves list_a's running product in its static
            // allocation; reproduce that memory side effect.
            encode_to_memory(
                memory,
                STATIC_MEMORY_FIRST_ADDRESS - bfe!(EXTENSION_DEGREE as u64 - 1),
                &running_product_a,
            );

            stack.push(bfe!((running_product_a == running_product_b) as u64))
        }

        fn pseudorandom_initial_state(
            &self,
            seed: [u8; 32],
            bench_case: Option<BenchmarkCase>,
        ) -> FunctionInitialState {
            let mut rng = StdRng::from_seed(seed);

            match bench_case {
                Some(BenchmarkCase::CommonCase) => self.random_equal_multisets(90, &mut rng),
                Some(BenchmarkCase::WorstCase) => self.random_equal_multisets(360, &mut rng),
                None => {
                    let length = rng.random_range(0..50);
                    let num_mutations = rng.random_range(0..=length);
                    let mutation_translation: u64 = rng.random();
                    let another_length = length + rng.random_range(1..10);
                    match rng.random_range(0..=5) {
                        0 => self.random_equal_multisets(length, &mut rng),
                        1 => self.random_equal_lists(length, &mut rng),
                        2 => self.random_equal_multisets_flipped_pointers(length, &mut rng),
                        3 => self.random_same_length_mutated_elements(
                            length,
                            num_mutations,
                            mutation_translation,
                            &mut rng,
                        ),
                        4 => self.random_unequal_length_lists(length, another_length, &mut rng),
                        5 => self.random_unequal_length_lists_trailing_zeros(
                            length,
                            another_length,
                            &mut rng,
                        ),
                        _ => unreachable!(),
                    }
                }
            }
        }

        fn corner_case_initial_states(&self) -> Vec<FunctionInitialState> {
            let seed = [111u8; 32];
            let mut rng = StdRng::from_seed(seed);

            let length_0_length_1 = self.random_unequal_length_lists(0, 1, &mut rng);
            let length_1_length_0 = self.random_unequal_length_lists(1, 0, &mut rng);
            let two_empty_lists = self.random_equal_multisets(0, &mut rng);
            let two_equal_singletons = self.random_equal_multisets(1, &mut rng);
            let two_equal_lists_length_2 = self.random_equal_lists(2, &mut rng);
            let two_equal_lists_flipped_order =
                self.random_equal_multisets_flipped_pointers(4, &mut rng);

            let unequal_lists_length_1_add_1 =
                self.random_same_length_mutated_elements(1, 1, 1, &mut rng);
            let unequal_lists_length_1_add_2pow32 =
                self.random_same_length_mutated_elements(1, 1, 1u64 << 32, &mut rng);

            let unequal_lists_length_2_add_1 =
                self.random_same_length_mutated_elements(2, 1, 1, &mut rng);
            let unequal_lists_length_2_add_2pow32 =
                self.random_same_length_mutated_elements(2, 1, 1u64 << 32, &mut rng);

            let equal_multisets_length_2s = (0..10)
                .map(|_| self.random_equal_multisets(2, &mut rng))
                .collect_vec();
            let equal_multisets_length_3s = (0..10)
                .map(|_| self.random_equal_multisets(3, &mut rng))
                .collect_vec();
            let equal_multisets_length_4s = (0..10)
                .map(|_| self.random_equal_multisets(4, &mut rng))
                .collect_vec();

            let different_lengths_same_initial_elements_1_2 =
                self.random_unequal_length_lists(1, 2, &mut rng);
            let different_lengths_same_initial_elements_2_1 =
                self.random_unequal_length_lists(2, 1, &mut rng);
            let different_lengths_trailing_zeros_1_2 =
                self.random_unequal_length_lists_trailing_zeros(1, 2, &mut rng);

            [
                vec![
                    length_0_length_1,
                    length_1_length_0,
                    two_empty_lists,
                    two_equal_singletons,
                    two_equal_lists_length_2,
                    two_equal_lists_flipped_order,
                    unequal_lists_length_1_add_1,
                    unequal_lists_length_1_add_2pow32,
                    unequal_lists_length_2_add_1,
                    unequal_lists_length_2_add_2pow32,
                    different_lengths_same_initial_elements_1_2,
                    different_lengths_same_initial_elements_2_1,
                    different_lengths_trailing_zeros_1_2,
                ],
                equal_multisets_length_2s,
                equal_multisets_length_3s,
                equal_multisets_length_4s,
            ]
            .concat()
        }
    }

    impl MultisetEqualityU64s {
        /// Generates a random list of `length` `u64`s plus two list pointers.
        /// The second pointer starts after the end of a list of `length`
        /// elements placed at the first pointer.
        fn list_a_and_both_pointers(
            &self,
            length: usize,
            rng: &mut StdRng,
        ) -> (Vec<u64>, BFieldElement, BFieldElement) {
            let list_a: Vec<u64> = (0..length).map(|_| rng.random()).collect();

            let pointer_a: BFieldElement = rng.random();

            let list_size = length * U64_STACK_SIZE + LIST_METADATA_SIZE;
            let pointer_b_offset: u32 = rng.random_range(list_size as u32..u32::MAX);
            let pointer_b: BFieldElement =
                BFieldElement::new(pointer_a.value() + pointer_b_offset as u64);

            (list_a, pointer_a, pointer_b)
        }

        /// Writes both lists to fresh memory and builds the matching input stack.
        fn init_state(
            &self,
            pointer_a: BFieldElement,
            pointer_b: BFieldElement,
            a: Vec<u64>,
            b: Vec<u64>,
        ) -> FunctionInitialState {
            let mut memory = HashMap::default();
            rust_shadowing_helper_functions::list::list_insert(pointer_a, a, &mut memory);
            rust_shadowing_helper_functions::list::list_insert(pointer_b, b, &mut memory);

            let stack = [
                self.init_stack_for_isolated_run(),
                vec![pointer_a, pointer_b],
            ]
            .concat();
            FunctionInitialState { stack, memory }
        }

        /// `list_b` is a sorted permutation of `list_a`.
        fn random_equal_multisets(&self, length: usize, rng: &mut StdRng) -> FunctionInitialState {
            let (a, pointer_a, pointer_b) = self.list_a_and_both_pointers(length, rng);
            let mut b = a.clone();
            b.sort();

            self.init_state(pointer_a, pointer_b, a, b)
        }

        /// `list_b` is an exact copy of `list_a`.
        fn random_equal_lists(&self, length: usize, rng: &mut StdRng) -> FunctionInitialState {
            let (a, pointer_a, pointer_b) = self.list_a_and_both_pointers(length, rng);
            let b = a.clone();

            self.init_state(pointer_a, pointer_b, a, b)
        }

        /// Like [`Self::random_equal_multisets`], but with the roles of the two
        /// pointers swapped, so the sorted list sits at the lower-offset pointer.
        fn random_equal_multisets_flipped_pointers(
            &self,
            length: usize,
            rng: &mut StdRng,
        ) -> FunctionInitialState {
            let (b, pointer_b, pointer_a) = self.list_a_and_both_pointers(length, rng);
            let mut a = b.clone();
            a.sort();

            self.init_state(pointer_a, pointer_b, a, b)
        }

        /// `list_b` is a sorted copy of `list_a` in which `num_mutations`
        /// randomly chosen elements (possibly the same one repeatedly) are
        /// shifted by `mutation_translation` (wrapping).
        fn random_same_length_mutated_elements(
            &self,
            length: usize,
            num_mutations: usize,
            mutation_translation: u64,
            rng: &mut StdRng,
        ) -> FunctionInitialState {
            let (a, pointer_a, pointer_b) = self.list_a_and_both_pointers(length, rng);
            let mut b = a.clone();
            b.sort();

            for _ in 0..num_mutations {
                let elem_mut_ref = b.choose_mut(rng).unwrap();
                *elem_mut_ref = elem_mut_ref.wrapping_add(mutation_translation);
            }

            self.init_state(pointer_a, pointer_b, a, b)
        }

        /// `list_b` shares `list_a`'s prefix but is truncated or extended with
        /// random elements to `length_b`.
        fn random_unequal_length_lists(
            &self,
            length_a: usize,
            length_b: usize,
            rng: &mut StdRng,
        ) -> FunctionInitialState {
            assert_ne!(length_a, length_b, "Don't do this");

            let (a, pointer_a, pointer_b) = self.list_a_and_both_pointers(length_a, rng);
            let mut b = a.clone();
            b.resize_with(length_b, || rng.random());

            self.init_state(pointer_a, pointer_b, a, b)
        }

        /// `list_b` is `list_a` padded with zeros up to the larger `length_b`.
        fn random_unequal_length_lists_trailing_zeros(
            &self,
            length_a: usize,
            length_b: usize,
            rng: &mut StdRng,
        ) -> FunctionInitialState {
            assert!(length_b > length_a);

            let (a, pointer_a, pointer_b) = self.list_a_and_both_pointers(length_a, rng);
            let mut b = a.clone();
            b.resize_with(length_b, || 0);

            self.init_state(pointer_a, pointer_b, a, b)
        }
    }
}
598
#[cfg(test)]
mod benches {
    use super::*;
    use crate::test_prelude::*;

    /// Runs the benchmark harness on the shadowed snippet.
    #[test]
    fn benchmark() {
        let shadowed_snippet = ShadowedFunction::new(MultisetEqualityU64s);
        shadowed_snippet.bench();
    }
}