// tasm_lib/list/multiset_equality_u64s.rs

1use triton_vm::prelude::*;
2use twenty_first::math::x_field_element::EXTENSION_DEGREE;
3
4use crate::hashing::algebraic_hasher::hash_varlen::HashVarlen;
5use crate::prelude::*;
6
/// Tests whether two lists of `u64`s are equal as multisets, i.e. whether one
/// list is a permutation of the other.
///
/// The test is probabilistic: both lists are compressed into running products
/// over a Fiat-Shamir challenge derived from hashing the two lists.
///
/// BEFORE: _ *list_a *list_b
/// AFTER:  _ multisets_are_equal
#[derive(Debug, Clone, Copy)]
pub struct MultisetEqualityU64s;
9
/// Number of stack words (`BFieldElement`s) one `u64` occupies; checked against
/// `DataType::U64.stack_size()` when the snippet's code is generated.
const U64_STACK_SIZE: usize = 2;
11
12impl BasicSnippet for MultisetEqualityU64s {
13    fn inputs(&self) -> Vec<(DataType, String)> {
14        vec![
15            (DataType::List(Box::new(DataType::U64)), "list_a".to_owned()),
16            (DataType::List(Box::new(DataType::U64)), "list_b".to_owned()),
17        ]
18    }
19
20    fn outputs(&self) -> Vec<(DataType, String)> {
21        vec![(DataType::Bool, "multisets_are_equal".to_owned())]
22    }
23
24    fn entrypoint(&self) -> String {
25        "tasmlib_list_multiset_equality_u64s".to_owned()
26    }
27
28    fn code(&self, library: &mut Library) -> Vec<LabelledInstruction> {
29        let entrypoint = self.entrypoint();
30        assert_eq!(U64_STACK_SIZE, DataType::U64.stack_size());
31
32        let hash_varlen = library.import(Box::new(HashVarlen));
33        let compare_xfes = DataType::Xfe.compare();
34
35        let running_product_result_alloc = library.kmalloc(EXTENSION_DEGREE.try_into().unwrap());
36
37        let compare_lengths = triton_asm!(
38            // _ *a *b
39
40            dup 1
41            dup 1
42            // _ *a *b *a *b
43
44            read_mem 1
45            pop 1
46            // _ *a *b *a b_len
47
48            swap 1
49            read_mem 1
50            pop 1
51            // _ *a *b b_len a_len
52
53            dup 1
54            eq
55            // _ *a *b b_len (b_len == a_len)
56        );
57
58        let not_equal_length_label = format!("{entrypoint}_not_equal_length");
59        let not_equal_length_code = triton_asm!(
60            {not_equal_length_label}:
61            // _ *list_a *list_b len_b 1
62            pop 4
63
64            /* push snippet return value (false) */
65            push 0 // _ 0
66
67            /* ensure `else` branch is not taken */
68            push 0
69
70            return
71        );
72
73        let find_challenge_indeterminate = triton_asm!(
74            // _ *a *b size
75
76            /* Hash list `a` */
77            dup 2
78            dup 1
79            // _ *a *b size *a size
80
81            push 1 add
82            // _ *a *b size *a (size + 1)
83            // _ *a *b size *a full_size ; `full_size` includes length indicator
84
85            call {hash_varlen}
86            // _ *a *b size [a_digest]
87
88            /* Hash list `b` */
89            dup 6
90            dup 6
91            // _ *a *b size [a_digest] *b size
92
93            push 1 add
94            // _ *a *b size [a_digest] *b (size + 1)
95
96            call {hash_varlen}
97            // _ *a *b size [a_digest] [b_digest]
98
99            /* Get challenge indeterminate */
100            hash
101            pop 2
102            // _ *a *b size [-indeterminate]
103        );
104
105        let calculate_running_product_loop_label = format!("{entrypoint}_loop");
106        let calculate_running_product_loop_code = triton_asm!(
107            // INVARIANT: _ [-indeterminate] *list *list[i]_lw [garbage; 2] [running_product]
108            {calculate_running_product_loop_label}:
109
110                push 0
111                dup 6
112                read_mem {U64_STACK_SIZE}
113                swap 9
114                pop 1
115                // _ [-indeterminate] *list *list[i-1]_lw [garbage; 2] [running_product] 0 u64_hi u64_lo
116                // _ [-indeterminate] *list *list[i-1]_lw [garbage; 2] [running_product] [elem_as_xfe]
117
118                dup 12
119                dup 12
120                dup 12
121                xx_add
122                xx_mul
123                // _ [-indeterminate] *list *list[i-1]_lw [garbage; 2] [running_product * (elem_as_xfe -indeterminate)]
124                // _ [-indeterminate] *list *list[i-1]_lw [garbage; 2] [running_product']
125
126                recurse_or_return
127        );
128
129        let equal_length_label = format!("{entrypoint}_equal_length");
130        let equal_length_code = triton_asm!(
131            {equal_length_label}:
132                // _ *a *b len
133
134                push {U64_STACK_SIZE}
135                mul
136                // _ *a *b size
137
138                // `size` is size of elements excluding length indicator
139                // Notice that we also absorb the length indicator into the
140                // sponge state.
141
142                {&find_challenge_indeterminate}
143                // _ *a *b size [-indeterminate]
144
145                dup 5
146                dup 6
147                dup 5
148                add
149                // _ *a *b size [-indeterminate] *a (*a + size)
150
151                push 0
152                push 0
153                // _ *a *b size [-indeterminate] *a (*a + size) [garbage; 2]
154
155                push 0
156                push 0
157                push 1
158                // _ *a *b size [-indeterminate] *a (*a + size) [garbage; 2] [1]
159                // _ *a *b size [-indeterminate] *a (*a + size) [garbage; 2] [running_product]
160
161                dup 6
162                dup 6
163                eq
164                push 0
165                eq
166                skiz call {calculate_running_product_loop_label}
167                // _ *a *b size [-indeterminate] *a *a [garbage; 2] [a_rp]
168
169                /* store result in static memory and cleanup stack */
170                push {running_product_result_alloc.write_address()}
171                write_mem {running_product_result_alloc.num_words()}
172                pop 5
173                // _ *a *b size [-indeterminate]
174
175                /* Prepare stack for loop */
176                dup 4
177                dup 5
178                dup 5
179                // _ *a *b size [-indeterminate] *b *b size
180
181                add
182                // _ *a *b size [-indeterminate] *b *b_lw
183
184                push 0
185                push 0
186                push 0
187                push 0
188                push 1
189                // _ *a *b size [-indeterminate] *b *b_lw [garbage; 2] [running_product]
190
191                dup 6
192                dup 6
193                eq
194                push 0
195                eq
196                // _ *a *b size [-indeterminate] *b *b_lw [garbage; 2] [running_product] (*b != *b_lw)
197
198                skiz call {calculate_running_product_loop_label}
199                // _ *a *b size [-indeterminate] *b *b_lw [garbage; 2] [b_rp]
200
201                swap 10
202                pop 1
203                swap 10
204                pop 1
205                swap 10
206                // _ [b_rp] [-indeterminate] *b *b_lw [garbage; 2] *a
207
208                pop 5
209                pop 3
210                // _ [b_rp]
211
212                push {running_product_result_alloc.read_address()}
213                read_mem {running_product_result_alloc.num_words()}
214                pop 1
215                // _ [b_rp] [a_rp]
216
217                {&compare_xfes}
218                // _ (b_rp == a_rp)
219
220                return
221        );
222
223        triton_asm!(
224            // BEFORE: _ *a *b
225            // AFTER: a == b (as multisets, or up to permutation)
226            {entrypoint}:
227                {&compare_lengths}
228                // _ *a *b b_len (a_len == b_len)
229
230                push 1
231                swap 1
232                push 0
233                eq
234                // _ *a *b b_len 1 (a_len != b_len)
235
236                skiz call {not_equal_length_label}
237                skiz call {equal_length_label}
238                // _ multisets_are_equal
239
240                return
241
242            {&not_equal_length_code}
243            {&equal_length_code}
244            {&calculate_running_product_loop_code}
245        )
246    }
247}
248
#[cfg(test)]
mod tests {
    use num::One;
    use num::Zero;

    use super::*;
    use crate::library::STATIC_MEMORY_FIRST_ADDRESS;
    use crate::list::LIST_METADATA_SIZE;
    use crate::memory::encode_to_memory;
    use crate::rust_shadowing_helper_functions;
    use crate::rust_shadowing_helper_functions::list::list_get;
    use crate::rust_shadowing_helper_functions::list::load_list_with_copy_elements;
    use crate::test_helpers::test_rust_equivalence_given_complete_state;
    use crate::test_prelude::*;

    /// Returns a [`StdRng`] seeded with fresh entropy from the thread-local RNG.
    fn freshly_seeded_rng() -> StdRng {
        let mut seed = [0u8; 32];
        rand::rng().fill_bytes(&mut seed);
        StdRng::from_seed(seed)
    }

    #[test]
    fn returns_true_on_multiset_equality() {
        let snippet = MultisetEqualityU64s;
        let return_value_is_true = [
            snippet.init_stack_for_isolated_run(),
            vec![BFieldElement::one()],
        ]
        .concat();

        let mut rng = freshly_seeded_rng();
        for length in (0..10).chain(1000..1001) {
            let init_state = snippet.random_equal_multisets(length, &mut rng);
            let nd = NonDeterminism::default().with_ram(init_state.memory);
            test_rust_equivalence_given_complete_state(
                &ShadowedFunction::new(snippet),
                &init_state.stack,
                &[],
                &nd,
                &None,
                Some(&return_value_is_true),
            );
        }
    }

    #[test]
    fn returns_false_on_multiset_inequality() {
        let snippet = MultisetEqualityU64s;
        let return_value_is_false = [
            snippet.init_stack_for_isolated_run(),
            vec![BFieldElement::zero()],
        ]
        .concat();

        let mut rng = freshly_seeded_rng();
        for length in (1..10).chain(1000..1001) {
            let init_state = snippet.random_same_length_mutated_elements(length, 1, 1, &mut rng);
            let nd = NonDeterminism::default().with_ram(init_state.memory);
            test_rust_equivalence_given_complete_state(
                &ShadowedFunction::new(snippet),
                &init_state.stack,
                &[],
                &nd,
                &None,
                Some(&return_value_is_false),
            );
        }
    }

    #[test]
    fn multiset_equality_u64s_pbt() {
        ShadowedFunction::new(MultisetEqualityU64s).test()
    }

    impl Function for MultisetEqualityU64s {
        fn rust_shadow(
            &self,
            stack: &mut Vec<BFieldElement>,
            memory: &mut HashMap<BFieldElement, BFieldElement>,
        ) {
            let list_b_pointer = stack.pop().unwrap();
            let list_a_pointer = stack.pop().unwrap();

            let a: Vec<[BFieldElement; 2]> = load_list_with_copy_elements(list_a_pointer, memory);
            let b: Vec<[BFieldElement; 2]> = load_list_with_copy_elements(list_b_pointer, memory);

            if a.len() != b.len() {
                stack.push(BFieldElement::zero());
                return;
            }

            let len = a.len();

            // hash to get Fiat-Shamir challenge
            let a_digest = Tip5::hash(&a);
            let b_digest = Tip5::hash(&b);
            let indeterminate = Tip5::hash_pair(b_digest, a_digest);
            let indeterminate =
                -XFieldElement::new(indeterminate.values()[2..Digest::LEN].try_into().unwrap());

            let running_product_a =
                Self::running_product(list_a_pointer, len, indeterminate, memory);
            let running_product_b =
                Self::running_product(list_b_pointer, len, indeterminate, memory);

            // Write to static memory, since that's what the TASM code does
            encode_to_memory(
                memory,
                STATIC_MEMORY_FIRST_ADDRESS - bfe!(EXTENSION_DEGREE as u64 - 1),
                &running_product_a,
            );

            stack.push(bfe!((running_product_a == running_product_b) as u64))
        }

        fn pseudorandom_initial_state(
            &self,
            seed: [u8; 32],
            bench_case: Option<BenchmarkCase>,
        ) -> FunctionInitialState {
            let mut rng = StdRng::from_seed(seed);

            match bench_case {
                // Common case: 2 * 45 ~ 2 inputs
                // Common case: 8 inputs
                Some(BenchmarkCase::CommonCase) => self.random_equal_multisets(90, &mut rng),
                Some(BenchmarkCase::WorstCase) => self.random_equal_multisets(360, &mut rng),
                None => {
                    let length = rng.random_range(0..50);
                    let num_mutations = rng.random_range(0..=length);
                    let mutation_translation: u64 = rng.random();
                    let another_length = length + rng.random_range(1..10);
                    match rng.random_range(0..=5) {
                        0 => self.random_equal_multisets(length, &mut rng),
                        1 => self.random_equal_lists(length, &mut rng),
                        2 => self.random_equal_multisets_flipped_pointers(length, &mut rng),
                        3 => self.random_same_length_mutated_elements(
                            length,
                            num_mutations,
                            mutation_translation,
                            &mut rng,
                        ),
                        4 => self.random_unequal_length_lists(length, another_length, &mut rng),
                        5 => self.random_unequal_length_lists_trailing_zeros(
                            length,
                            another_length,
                            &mut rng,
                        ),
                        _ => unreachable!(),
                    }
                }
            }
        }

        fn corner_case_initial_states(&self) -> Vec<FunctionInitialState> {
            let seed = [111u8; 32];
            let mut rng = StdRng::from_seed(seed);

            let length_0_length_1 = self.random_unequal_length_lists(0, 1, &mut rng);
            let length_1_length_0 = self.random_unequal_length_lists(1, 0, &mut rng);
            let two_empty_lists = self.random_equal_multisets(0, &mut rng);
            let two_equal_singletons = self.random_equal_multisets(1, &mut rng);
            let two_equal_lists_length_2 = self.random_equal_lists(2, &mut rng);
            let two_equal_lists_flipped_order =
                self.random_equal_multisets_flipped_pointers(4, &mut rng);

            let unequal_lists_length_1_add_1 =
                self.random_same_length_mutated_elements(1, 1, 1, &mut rng);
            let unequal_lists_length_1_add_2pow32 =
                self.random_same_length_mutated_elements(1, 1, 1u64 << 32, &mut rng);

            let unequal_lists_length_2_add_1 =
                self.random_same_length_mutated_elements(2, 1, 1, &mut rng);
            let unequal_lists_length_2_add_2pow32 =
                self.random_same_length_mutated_elements(2, 1, 1u64 << 32, &mut rng);

            let equal_multisets_length_2s = (0..10)
                .map(|_| self.random_equal_multisets(2, &mut rng))
                .collect_vec();
            let equal_multisets_length_3s = (0..10)
                .map(|_| self.random_equal_multisets(3, &mut rng))
                .collect_vec();
            let equal_multisets_length_4s = (0..10)
                .map(|_| self.random_equal_multisets(4, &mut rng))
                .collect_vec();

            let different_lengths_same_initial_elements_1_2 =
                self.random_unequal_length_lists(1, 2, &mut rng);
            let different_lengths_same_initial_elements_2_1 =
                self.random_unequal_length_lists(2, 1, &mut rng);
            let different_lengths_trailing_zeros_1_2 =
                self.random_unequal_length_lists_trailing_zeros(1, 2, &mut rng);

            [
                vec![
                    length_0_length_1,
                    length_1_length_0,
                    two_empty_lists,
                    two_equal_singletons,
                    two_equal_lists_length_2,
                    two_equal_lists_flipped_order,
                    unequal_lists_length_1_add_1,
                    unequal_lists_length_1_add_2pow32,
                    unequal_lists_length_2_add_1,
                    unequal_lists_length_2_add_2pow32,
                    different_lengths_same_initial_elements_1_2,
                    different_lengths_same_initial_elements_2_1,
                    different_lengths_trailing_zeros_1_2,
                ],
                equal_multisets_length_2s,
                equal_multisets_length_3s,
                equal_multisets_length_4s,
            ]
            .concat()
        }
    }

    impl MultisetEqualityU64s {
        /// Running product `∏ (elem_as_xfe - indeterminate)` over the `len`
        /// `u64` elements of the list at `list_pointer`, mirroring the TASM loop.
        fn running_product(
            list_pointer: BFieldElement,
            len: usize,
            indeterminate: XFieldElement,
            memory: &HashMap<BFieldElement, BFieldElement>,
        ) -> XFieldElement {
            (0..len)
                .map(|i| list_get(list_pointer, i, memory, U64_STACK_SIZE))
                .map(|limbs| XFieldElement::new([limbs[0], limbs[1], BFieldElement::zero()]))
                .fold(XFieldElement::one(), |product, elem| {
                    product * (elem - indeterminate)
                })
        }

        /// Returns a random list `a` of `length` elements, a random pointer for
        /// it, and a pointer for a second list placed far enough away that a
        /// list of the same length does not overlap `a` in memory.
        fn list_a_and_both_pointers(
            &self,
            length: usize,
            rng: &mut StdRng,
        ) -> (Vec<u64>, BFieldElement, BFieldElement) {
            let list_a: Vec<u64> = (0..length).map(|_| rng.random()).collect();

            let pointer_a: BFieldElement = rng.random();

            // Avoid lists from overlapping in memory
            let list_size = length * U64_STACK_SIZE + LIST_METADATA_SIZE;
            let pointer_b_offset: u32 = rng.random_range(list_size as u32..u32::MAX);
            let pointer_b = BFieldElement::new(pointer_a.value() + u64::from(pointer_b_offset));

            (list_a, pointer_a, pointer_b)
        }

        /// Writes lists `a` and `b` to memory at the given pointers and builds
        /// the initial stack `_ *a *b`.
        fn init_state(
            &self,
            pointer_a: BFieldElement,
            pointer_b: BFieldElement,
            a: Vec<u64>,
            b: Vec<u64>,
        ) -> FunctionInitialState {
            let mut memory = HashMap::default();
            rust_shadowing_helper_functions::list::list_insert(pointer_a, a, &mut memory);
            rust_shadowing_helper_functions::list::list_insert(pointer_b, b, &mut memory);

            let stack = [
                self.init_stack_for_isolated_run(),
                vec![pointer_a, pointer_b],
            ]
            .concat();
            FunctionInitialState { stack, memory }
        }

        /// `b` is a sorted copy of `a`: equal as multisets.
        fn random_equal_multisets(&self, length: usize, rng: &mut StdRng) -> FunctionInitialState {
            let (a, pointer_a, pointer_b) = self.list_a_and_both_pointers(length, rng);
            let mut b = a.clone();
            b.sort();

            self.init_state(pointer_a, pointer_b, a, b)
        }

        /// `b` is an exact copy of `a`.
        fn random_equal_lists(&self, length: usize, rng: &mut StdRng) -> FunctionInitialState {
            let (a, pointer_a, pointer_b) = self.list_a_and_both_pointers(length, rng);
            let b = a.clone();

            self.init_state(pointer_a, pointer_b, a, b)
        }

        /// Like [`Self::random_equal_multisets`], but with the pointer roles
        /// swapped.
        fn random_equal_multisets_flipped_pointers(
            &self,
            length: usize,
            rng: &mut StdRng,
        ) -> FunctionInitialState {
            let (b, pointer_b, pointer_a) = self.list_a_and_both_pointers(length, rng);
            let mut a = b.clone();
            a.sort();

            // Generate testcase where, barring wrap-around modulo the field
            // prime, `(*a).value() > (*b).value()`, i.e. list `a` lives at the
            // higher address.
            self.init_state(pointer_a, pointer_b, a, b)
        }

        /// `b` is a sorted copy of `a`, in which `num_mutations` randomly chosen
        /// elements (possibly the same one repeatedly) are shifted by
        /// `mutation_translation` with wrapping addition.
        fn random_same_length_mutated_elements(
            &self,
            length: usize,
            num_mutations: usize,
            mutation_translation: u64,
            rng: &mut StdRng,
        ) -> FunctionInitialState {
            let (a, pointer_a, pointer_b) = self.list_a_and_both_pointers(length, rng);
            let mut b = a.clone();
            b.sort();

            for _ in 0..num_mutations {
                let elem_mut_ref = b.choose_mut(rng).unwrap();
                *elem_mut_ref = elem_mut_ref.wrapping_add(mutation_translation);
            }

            self.init_state(pointer_a, pointer_b, a, b)
        }

        /// `b` agrees with `a` on their common prefix, is truncated or extended
        /// with random elements to `length_b`.
        fn random_unequal_length_lists(
            &self,
            length_a: usize,
            length_b: usize,
            rng: &mut StdRng,
        ) -> FunctionInitialState {
            assert_ne!(length_a, length_b, "Don't do this");

            let (a, pointer_a, pointer_b) = self.list_a_and_both_pointers(length_a, rng);
            let mut b = a.clone();
            b.resize_with(length_b, || rng.random());

            self.init_state(pointer_a, pointer_b, a, b)
        }

        /// `b` is `a` padded with zeros to the (strictly greater) `length_b`.
        fn random_unequal_length_lists_trailing_zeros(
            &self,
            length_a: usize,
            length_b: usize,
            rng: &mut StdRng,
        ) -> FunctionInitialState {
            assert!(length_b > length_a);

            let (a, pointer_a, pointer_b) = self.list_a_and_both_pointers(length_a, rng);
            let mut b = a.clone();
            b.resize_with(length_b, || 0);

            self.init_state(pointer_a, pointer_b, a, b)
        }
    }
}
598
#[cfg(test)]
mod benches {
    use super::*;
    use crate::test_prelude::*;

    /// Benchmarks the snippet via the `ShadowedFunction` benchmark harness.
    #[test]
    fn benchmark() {
        ShadowedFunction::new(MultisetEqualityU64s).bench()
    }
}