// File: tasm_lib/list/multiset_equality_u64s.rs

1use triton_vm::prelude::*;
2use twenty_first::math::x_field_element::EXTENSION_DEGREE;
3
4use crate::hashing::algebraic_hasher::hash_varlen::HashVarlen;
5use crate::prelude::*;
6
/// Snippet deciding whether two lists of `u64`s are equal as multisets, i.e.,
/// whether one list is a permutation of the other.
///
/// ```text
/// BEFORE: _ *list_a *list_b
/// AFTER:  _ multisets_are_equal
/// ```
///
/// Lists of different lengths are reported as unequal right away. Otherwise a
/// Fiat-Shamir challenge `α` is derived from the digests of both lists, and the
/// running products `∏ (x_i - α)` of the two lists are compared. Equal multisets
/// always yield `true`; unequal multisets yield `false` except with negligible
/// probability.
#[derive(Debug, Clone, Copy)]
pub struct MultisetEqualityU64s;

/// Number of words a single `u64` occupies, both on the stack and in memory.
const U64_STACK_SIZE: usize = 2;

impl BasicSnippet for MultisetEqualityU64s {
    /// Two pointers to lists of `u64`s; `list_b` is on top of the stack.
    fn parameters(&self) -> Vec<(DataType, String)> {
        vec![
            (DataType::List(Box::new(DataType::U64)), "list_a".to_owned()),
            (DataType::List(Box::new(DataType::U64)), "list_b".to_owned()),
        ]
    }

    /// A single `bool`: whether the two lists are equal as multisets.
    fn return_values(&self) -> Vec<(DataType, String)> {
        vec![(DataType::Bool, "multisets_are_equal".to_owned())]
    }

    fn entrypoint(&self) -> String {
        "tasmlib_list_multiset_equality_u64s".to_owned()
    }

    /// Generates the TASM implementation.
    ///
    /// Stack signature: `_ *list_a *list_b` -> `_ multisets_are_equal`.
    ///
    /// Outline:
    /// 1. Compare the two length indicators; if they differ, return `false`
    ///    without touching any list element.
    /// 2. Hash each list, length indicator included, with [`HashVarlen`], then
    ///    `hash` the two digests together. Three words of the result are kept
    ///    as an extension-field element, labelled `[-indeterminate]` in the
    ///    stack comments.
    /// 3. For each list, compute the running product of
    ///    `(element_as_xfe - indeterminate)`, where a `u64` is embedded into the
    ///    extension field by using its two words as the two lowest coefficients
    ///    (third coefficient 0).
    /// 4. Return whether the two running products are equal.
    ///
    /// The running product of list `a` is parked in one statically allocated
    /// extension-field element (reserved via `library.kmalloc`) while the
    /// running product of list `b` is computed on the stack.
    fn code(&self, library: &mut Library) -> Vec<LabelledInstruction> {
        let entrypoint = self.entrypoint();
        // The TASM below hard-codes two words per element (`read_mem`, `mul`).
        assert_eq!(U64_STACK_SIZE, DataType::U64.stack_size());

        let hash_varlen = library.import(Box::new(HashVarlen));
        let compare_xfes = DataType::Xfe.compare();

        // Scratch space for list `a`'s running product (one XFieldElement).
        let running_product_result_alloc = library.kmalloc(EXTENSION_DEGREE.try_into().unwrap());

        let compare_lengths = triton_asm!(
            // _ *a *b

            dup 1
            dup 1
            // _ *a *b *a *b

            // `pop 1` discards the updated pointer that `read_mem` leaves on top
            read_mem 1
            pop 1
            // _ *a *b *a b_len

            swap 1
            read_mem 1
            pop 1
            // _ *a *b b_len a_len

            dup 1
            eq
            // _ *a *b b_len (b_len == a_len)
        );

        // Taken when the lengths differ. Besides dropping the arguments, it
        // also drops the `1` that the entrypoint prepared for the subsequent
        // `skiz call {equal_length_label}`, and replaces it with a `0`.
        let not_equal_length_label = format!("{entrypoint}_not_equal_length");
        let not_equal_length_code = triton_asm!(
            {not_equal_length_label}:
            // _ *list_a *list_b len_b 1
            pop 4

            /* push snippet return value (false) */
            push 0 // _ 0

            /* ensure `else` branch is not taken */
            push 0

            return
        );

        // Fiat-Shamir: the challenge depends on the contents of both lists.
        let find_challenge_indeterminate = triton_asm!(
            // _ *a *b size

            /* Hash list `a` */
            dup 2
            dup 1
            // _ *a *b size *a size

            push 1 add
            // _ *a *b size *a (size + 1)
            // _ *a *b size *a full_size ; `full_size` includes length indicator

            call {hash_varlen}
            // _ *a *b size [a_digest]

            /* Hash list `b` */
            dup 6
            dup 6
            // _ *a *b size [a_digest] *b size

            push 1 add
            // _ *a *b size [a_digest] *b (size + 1)

            call {hash_varlen}
            // _ *a *b size [a_digest] [b_digest]

            /* Get challenge indeterminate */
            // `hash` compresses both digests into one; dropping the top two
            // words leaves three words, i.e., one extension-field element.
            hash
            pop 2
            // _ *a *b size [-indeterminate]
        );

        // Walks a list from its last word down to its length indicator, two
        // words (one `u64`) per iteration. The two garbage words pad the frame
        // so that `*list` and the moving word pointer sit at stack positions 6
        // and 5, which are the positions `recurse_or_return` compares. The
        // loop body runs at least once, so callers must skip it for empty
        // lists.
        let calculate_running_product_loop_label = format!("{entrypoint}_loop");
        let calculate_running_product_loop_code = triton_asm!(
            // INVARIANT: _ [-indeterminate] *list *list[i]_lw [garbage; 2] [running_product]
            {calculate_running_product_loop_label}:

                push 0
                dup 6
                read_mem {U64_STACK_SIZE}
                swap 9
                pop 1
                // _ [-indeterminate] *list *list[i-1]_lw [garbage; 2] [running_product] 0 u64_hi u64_lo
                // _ [-indeterminate] *list *list[i-1]_lw [garbage; 2] [running_product] [elem_as_xfe]

                dup 12
                dup 12
                dup 12
                xx_add
                xx_mul
                // _ [-indeterminate] *list *list[i-1]_lw [garbage; 2] [running_product * (elem_as_xfe -indeterminate)]
                // _ [-indeterminate] *list *list[i-1]_lw [garbage; 2] [running_product']

                recurse_or_return
        );

        let equal_length_label = format!("{entrypoint}_equal_length");
        let equal_length_code = triton_asm!(
            {equal_length_label}:
                // _ *a *b len

                push {U64_STACK_SIZE}
                mul
                // _ *a *b size

                // `size` is size of elements excluding length indicator
                // Notice that we also absorb the length indicator into the
                // sponge state.

                {&find_challenge_indeterminate}
                // _ *a *b size [-indeterminate]

                dup 5
                dup 6
                dup 5
                add
                // _ *a *b size [-indeterminate] *a (*a + size)

                push 0
                push 0
                // _ *a *b size [-indeterminate] *a (*a + size) [garbage; 2]

                push 0
                push 0
                push 1
                // _ *a *b size [-indeterminate] *a (*a + size) [garbage; 2] [1]
                // _ *a *b size [-indeterminate] *a (*a + size) [garbage; 2] [running_product]

                // Only enter the loop if list `a` is non-empty; the loop body
                // executes before `recurse_or_return` checks termination.
                dup 6
                dup 6
                eq
                push 0
                eq
                skiz call {calculate_running_product_loop_label}
                // _ *a *b size [-indeterminate] *a *a [garbage; 2] [a_rp]

                /* store result in static memory and cleanup stack */
                push {running_product_result_alloc.write_address()}
                write_mem {running_product_result_alloc.num_words()}
                pop 5
                // _ *a *b size [-indeterminate]

                /* Prepare stack for loop */
                dup 4
                dup 5
                dup 5
                // _ *a *b size [-indeterminate] *b *b size

                add
                // _ *a *b size [-indeterminate] *b *b_lw

                push 0
                push 0
                push 0
                push 0
                push 1
                // _ *a *b size [-indeterminate] *b *b_lw [garbage; 2] [running_product]

                dup 6
                dup 6
                eq
                push 0
                eq
                // _ *a *b size [-indeterminate] *b *b_lw [garbage; 2] [running_product] (*b != *b_lw)

                skiz call {calculate_running_product_loop_label}
                // _ *a *b size [-indeterminate] *b *b_lw [garbage; 2] [b_rp]

                // Move the three words of `b_rp` down into the slots held by
                // `size`, `*b`, and `*a`, preserving their order.
                swap 10
                pop 1
                swap 10
                pop 1
                swap 10
                // _ [b_rp] [-indeterminate] *b *b_lw [garbage; 2] *a

                pop 5
                pop 3
                // _ [b_rp]

                push {running_product_result_alloc.read_address()}
                read_mem {running_product_result_alloc.num_words()}
                pop 1
                // _ [b_rp] [a_rp]

                {&compare_xfes}
                // _ (b_rp == a_rp)

                return
        );

        triton_asm!(
            // BEFORE: _ *a *b
            // AFTER:  _ (a == b as multisets, i.e., up to permutation)
            {entrypoint}:
                {&compare_lengths}
                // _ *a *b b_len (a_len == b_len)

                // The `1` is the flag for the second `skiz` below (the
                // equal-length branch); the not-equal branch overwrites it.
                push 1
                swap 1
                push 0
                eq
                // _ *a *b b_len 1 (a_len != b_len)

                skiz call {not_equal_length_label}
                skiz call {equal_length_label}
                // _ multisets_are_equal

                return

            {&not_equal_length_code}
            {&equal_length_code}
            {&calculate_running_product_loop_code}
        )
    }
}

#[cfg(test)]
mod tests {
    use num::One;
    use num::Zero;

    use super::*;
    use crate::library::STATIC_MEMORY_FIRST_ADDRESS;
    use crate::list::LIST_METADATA_SIZE;
    use crate::memory::encode_to_memory;
    use crate::rust_shadowing_helper_functions;
    use crate::rust_shadowing_helper_functions::list::list_get;
    use crate::rust_shadowing_helper_functions::list::load_list_with_copy_elements;
    use crate::test_helpers::test_rust_equivalence_given_complete_state;
    use crate::test_prelude::*;

    #[macro_rules_attr::apply(test)]
    fn returns_true_on_multiset_equality() {
        let snippet = MultisetEqualityU64s;
        let return_value_is_true = [
            snippet.init_stack_for_isolated_run(),
            vec![BFieldElement::one()],
        ]
        .concat();

        let mut rng = rand::rng();
        let mut seed = [0u8; 32];
        rng.fill_bytes(&mut seed);
        let mut rng = StdRng::from_seed(seed);

        for length in (0..10).chain(1000..1001) {
            let init_state = snippet.random_equal_multisets(length, &mut rng);
            let nd = NonDeterminism::default().with_ram(init_state.memory);
            test_rust_equivalence_given_complete_state(
                &ShadowedFunction::new(snippet),
                &init_state.stack,
                &[],
                &nd,
                &None,
                Some(&return_value_is_true),
            );
        }
    }

    #[macro_rules_attr::apply(test)]
    fn returns_false_on_multiset_inequality() {
        let snippet = MultisetEqualityU64s;
        let return_value_is_false = [
            snippet.init_stack_for_isolated_run(),
            vec![BFieldElement::zero()],
        ]
        .concat();

        let mut rng = rand::rng();
        let mut seed = [0u8; 32];
        rng.fill_bytes(&mut seed);
        let mut rng = StdRng::from_seed(seed);

        // Exactly one element is incremented by 1, which always changes the
        // multiset (the wrapping sum of the elements changes by 1).
        for length in (1..10).chain(1000..1001) {
            let init_state = snippet.random_same_length_mutated_elements(length, 1, 1, &mut rng);
            let nd = NonDeterminism::default().with_ram(init_state.memory);
            test_rust_equivalence_given_complete_state(
                &ShadowedFunction::new(snippet),
                &init_state.stack,
                &[],
                &nd,
                &None,
                Some(&return_value_is_false),
            );
        }
    }

    #[macro_rules_attr::apply(test)]
    fn multiset_equality_u64s_pbt() {
        ShadowedFunction::new(MultisetEqualityU64s).test()
    }

    /// Computes the product of `(element_as_xfe - indeterminate)` over the first
    /// `len` elements of the `u64` list at `list_pointer`. Each element's two
    /// memory words become the two lowest coefficients of an extension-field
    /// element (third coefficient 0), mirroring the TASM loop.
    fn running_product(
        list_pointer: BFieldElement,
        len: usize,
        indeterminate: XFieldElement,
        memory: &HashMap<BFieldElement, BFieldElement>,
    ) -> Result<XFieldElement, RustShadowError> {
        let mut product = XFieldElement::one();
        for i in 0..len {
            let u64_elem = list_get(list_pointer, i, memory, U64_STACK_SIZE)?;
            let m = XFieldElement::new([u64_elem[0], u64_elem[1], BFieldElement::zero()]);
            product *= m - indeterminate;
        }

        Ok(product)
    }

    impl Function for MultisetEqualityU64s {
        fn rust_shadow(
            &self,
            stack: &mut Vec<BFieldElement>,
            memory: &mut HashMap<BFieldElement, BFieldElement>,
        ) -> Result<(), RustShadowError> {
            let list_b_pointer = stack.pop().ok_or(RustShadowError::StackUnderflow)?;
            let list_a_pointer = stack.pop().ok_or(RustShadowError::StackUnderflow)?;

            let a = load_list_with_copy_elements::<U64_STACK_SIZE>(list_a_pointer, memory)?;
            let b = load_list_with_copy_elements::<U64_STACK_SIZE>(list_b_pointer, memory)?;

            if a.len() != b.len() {
                stack.push(BFieldElement::zero());
                return Ok(());
            }

            let len = a.len();

            // hash to get Fiat-Shamir challenge
            let a_digest = Tip5::hash(&a);
            let b_digest = Tip5::hash(&b);
            let indeterminate = Tip5::hash_pair(b_digest, a_digest);
            let indeterminate =
                -XFieldElement::new(indeterminate.values()[2..Digest::LEN].try_into().unwrap());

            let running_product_a = running_product(list_a_pointer, len, indeterminate, memory)?;
            let running_product_b = running_product(list_b_pointer, len, indeterminate, memory)?;

            // Write to static memory, since that's what the TASM code does
            encode_to_memory(
                memory,
                STATIC_MEMORY_FIRST_ADDRESS - bfe!(EXTENSION_DEGREE as u64 - 1),
                &running_product_a,
            );
            stack.push(bfe!((running_product_a == running_product_b) as u64));

            Ok(())
        }

        fn pseudorandom_initial_state(
            &self,
            seed: [u8; 32],
            bench_case: Option<BenchmarkCase>,
        ) -> FunctionInitialState {
            let mut rng = StdRng::from_seed(seed);

            match bench_case {
                // Common case: 2 * 45 = 90 elements (~2 inputs)
                Some(BenchmarkCase::CommonCase) => self.random_equal_multisets(90, &mut rng),
                // Worst case: 8 * 45 = 360 elements (~8 inputs)
                Some(BenchmarkCase::WorstCase) => self.random_equal_multisets(360, &mut rng),
                None => {
                    let length = rng.random_range(0..50);
                    let num_mutations = rng.random_range(0..=length);
                    let mutation_translation: u64 = rng.random();
                    let another_length = length + rng.random_range(1..10);
                    match rng.random_range(0..=5) {
                        0 => self.random_equal_multisets(length, &mut rng),
                        1 => self.random_equal_lists(length, &mut rng),
                        2 => self.random_equal_multisets_flipped_pointers(length, &mut rng),
                        3 => self.random_same_length_mutated_elements(
                            length,
                            num_mutations,
                            mutation_translation,
                            &mut rng,
                        ),
                        4 => self.random_unequal_length_lists(length, another_length, &mut rng),
                        5 => self.random_unequal_length_lists_trailing_zeros(
                            length,
                            another_length,
                            &mut rng,
                        ),
                        _ => unreachable!(),
                    }
                }
            }
        }

        fn corner_case_initial_states(&self) -> Vec<FunctionInitialState> {
            let seed = [111u8; 32];
            let mut rng = StdRng::from_seed(seed);

            let length_0_length_1 = self.random_unequal_length_lists(0, 1, &mut rng);
            let length_1_length_0 = self.random_unequal_length_lists(1, 0, &mut rng);
            let two_empty_lists = self.random_equal_multisets(0, &mut rng);
            let two_equal_singletons = self.random_equal_multisets(1, &mut rng);
            let two_equal_lists_length_2 = self.random_equal_lists(2, &mut rng);
            let two_equal_lists_flipped_order =
                self.random_equal_multisets_flipped_pointers(4, &mut rng);

            let unequal_lists_length_1_add_1 =
                self.random_same_length_mutated_elements(1, 1, 1, &mut rng);
            let unequal_lists_length_1_add_2pow32 =
                self.random_same_length_mutated_elements(1, 1, 1u64 << 32, &mut rng);

            let unequal_lists_length_2_add_1 =
                self.random_same_length_mutated_elements(2, 1, 1, &mut rng);
            let unequal_lists_length_2_add_2pow32 =
                self.random_same_length_mutated_elements(2, 1, 1u64 << 32, &mut rng);

            let equal_multisets_length_2s = (0..10)
                .map(|_| self.random_equal_multisets(2, &mut rng))
                .collect_vec();
            let equal_multisets_length_3s = (0..10)
                .map(|_| self.random_equal_multisets(3, &mut rng))
                .collect_vec();
            let equal_multisets_length_4s = (0..10)
                .map(|_| self.random_equal_multisets(4, &mut rng))
                .collect_vec();

            let different_lengths_same_initial_elements_1_2 =
                self.random_unequal_length_lists(1, 2, &mut rng);
            let different_lengths_same_initial_elements_2_1 =
                self.random_unequal_length_lists(2, 1, &mut rng);
            let different_lengths_trailing_zeros_1_2 =
                self.random_unequal_length_lists_trailing_zeros(1, 2, &mut rng);

            [
                vec![
                    length_0_length_1,
                    length_1_length_0,
                    two_empty_lists,
                    two_equal_singletons,
                    two_equal_lists_length_2,
                    two_equal_lists_flipped_order,
                    unequal_lists_length_1_add_1,
                    unequal_lists_length_1_add_2pow32,
                    unequal_lists_length_2_add_1,
                    unequal_lists_length_2_add_2pow32,
                    different_lengths_same_initial_elements_1_2,
                    different_lengths_same_initial_elements_2_1,
                    different_lengths_trailing_zeros_1_2,
                ],
                equal_multisets_length_2s,
                equal_multisets_length_3s,
                equal_multisets_length_4s,
            ]
            .concat()
        }
    }

    impl MultisetEqualityU64s {
        /// Returns a random list of `length` elements together with two list
        /// pointers. The second pointer lies strictly after the first list's
        /// memory region, so the two lists never overlap.
        fn list_a_and_both_pointers(
            &self,
            length: usize,
            rng: &mut StdRng,
        ) -> (Vec<u64>, BFieldElement, BFieldElement) {
            let mut list_a: Vec<u64> = vec![0u64; length];
            for elem in list_a.iter_mut() {
                *elem = rng.random();
            }

            let pointer_a: BFieldElement = rng.random();

            // Avoid lists from overlapping in memory
            let list_size = length * U64_STACK_SIZE + LIST_METADATA_SIZE;
            let pointer_b_offset: u32 = rng.random_range(list_size as u32..u32::MAX);
            let pointer_b: BFieldElement =
                BFieldElement::new(pointer_a.value() + pointer_b_offset as u64);

            (list_a, pointer_a, pointer_b)
        }

        /// Writes both lists to memory and builds the initial stack
        /// `_ *a *b`.
        fn init_state(
            &self,
            pointer_a: BFieldElement,
            pointer_b: BFieldElement,
            a: Vec<u64>,
            b: Vec<u64>,
        ) -> FunctionInitialState {
            let mut memory = HashMap::default();
            rust_shadowing_helper_functions::list::list_insert(pointer_a, a, &mut memory);
            rust_shadowing_helper_functions::list::list_insert(pointer_b, b, &mut memory);

            let stack = [
                self.init_stack_for_isolated_run(),
                vec![pointer_a, pointer_b],
            ]
            .concat();
            FunctionInitialState { stack, memory }
        }

        /// `b` is a sorted copy of `a`.
        fn random_equal_multisets(&self, length: usize, rng: &mut StdRng) -> FunctionInitialState {
            let (a, pointer_a, pointer_b) = self.list_a_and_both_pointers(length, rng);
            let mut b = a.clone();
            b.sort();

            self.init_state(pointer_a, pointer_b, a, b)
        }

        /// `b` is an exact copy of `a`.
        fn random_equal_lists(&self, length: usize, rng: &mut StdRng) -> FunctionInitialState {
            let (a, pointer_a, pointer_b) = self.list_a_and_both_pointers(length, rng);
            let b = a.clone();

            self.init_state(pointer_a, pointer_b, a, b)
        }

        /// Like [`Self::random_equal_multisets`], but with the pointer order
        /// reversed in memory.
        fn random_equal_multisets_flipped_pointers(
            &self,
            length: usize,
            rng: &mut StdRng,
        ) -> FunctionInitialState {
            let (b, pointer_b, pointer_a) = self.list_a_and_both_pointers(length, rng);
            let mut a = b.clone();
            a.sort();

            // List `b` is placed first, so `*a = *b + offset` (mod p): the
            // reverse of the pointer order produced by the other generators.
            self.init_state(pointer_a, pointer_b, a, b)
        }

        /// `b` is a sorted copy of `a` in which `num_mutations` randomly chosen
        /// entries (possibly the same entry repeatedly) are shifted by
        /// `mutation_translation`, wrapping on overflow.
        fn random_same_length_mutated_elements(
            &self,
            length: usize,
            num_mutations: usize,
            mutation_translation: u64,
            rng: &mut StdRng,
        ) -> FunctionInitialState {
            let (a, pointer_a, pointer_b) = self.list_a_and_both_pointers(length, rng);
            let mut b = a.clone();
            b.sort();

            for _ in 0..num_mutations {
                let elem_mut_ref = b.choose_mut(rng).unwrap();
                *elem_mut_ref = elem_mut_ref.wrapping_add(mutation_translation);
            }

            self.init_state(pointer_a, pointer_b, a, b)
        }

        /// `b` agrees with `a` on their common prefix; extra entries are random.
        fn random_unequal_length_lists(
            &self,
            length_a: usize,
            length_b: usize,
            rng: &mut StdRng,
        ) -> FunctionInitialState {
            assert_ne!(
                length_a, length_b,
                "list lengths must differ for an unequal-length test case"
            );

            let (a, pointer_a, pointer_b) = self.list_a_and_both_pointers(length_a, rng);
            let mut b = a.clone();
            b.resize_with(length_b, || rng.random());

            self.init_state(pointer_a, pointer_b, a, b)
        }

        /// `b` is `a` padded with trailing zeros.
        fn random_unequal_length_lists_trailing_zeros(
            &self,
            length_a: usize,
            length_b: usize,
            rng: &mut StdRng,
        ) -> FunctionInitialState {
            assert!(
                length_b > length_a,
                "list `b` must be strictly longer than list `a`"
            );

            let (a, pointer_a, pointer_b) = self.list_a_and_both_pointers(length_a, rng);
            let mut b = a.clone();
            b.resize_with(length_b, || 0);

            self.init_state(pointer_a, pointer_b, a, b)
        }
    }
}

#[cfg(test)]
mod benches {
    use super::*;
    use crate::test_prelude::*;

    // Runs the common-case and worst-case benchmarks through the
    // shadowed-function harness.
    #[macro_rules_attr::apply(test)]
    fn benchmark() {
        let snippet = MultisetEqualityU64s;
        ShadowedFunction::new(snippet).bench();
    }
}