// tasm_lib/list/multiset_equality_digests.rs

1use triton_vm::prelude::*;
2
3use crate::hashing::algebraic_hasher::hash_varlen::HashVarlen;
4use crate::list::length::Length;
5use crate::prelude::*;
6
/// Decide whether two lists of digests are equal up to permutation, i.e.,
/// whether they contain the same digests with the same multiplicities.
///
/// Both lists are hashed to derive challenge values (Fiat-Shamir). Each list
/// is then folded into a running product over those challenges, and the two
/// products are compared. A future version of this snippet may replace the
/// Fiat-Shamir challenges and running products with Triton VM's native
/// support for permutation checks.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct MultisetEqualityDigests;
17
18impl BasicSnippet for MultisetEqualityDigests {
19    fn inputs(&self) -> Vec<(DataType, String)> {
20        vec![
21            (DataType::List(Box::new(DataType::Digest)), "a".to_owned()),
22            (DataType::List(Box::new(DataType::Digest)), "b".to_owned()),
23        ]
24    }
25
26    fn outputs(&self) -> Vec<(DataType, String)> {
27        vec![(DataType::Bool, "equal_multisets".to_owned())]
28    }
29
30    fn entrypoint(&self) -> String {
31        "tasmlib_list_multiset_equality_digests".to_owned()
32    }
33
34    fn code(&self, library: &mut Library) -> Vec<LabelledInstruction> {
35        let entrypoint = self.entrypoint();
36        let length_snippet = library.import(Box::new(Length));
37        let hash_varlen = library.import(Box::new(HashVarlen));
38
39        let early_abort_label = format!("{entrypoint}_early_abort");
40        let continue_label = format!("{entrypoint}_continue");
41        let running_product_label = format!("{entrypoint}_running_product");
42        let running_product_loop_label = format!("{entrypoint}_running_product_loop");
43
44        triton_asm!(
45            // BEFORE: _ *list_a *list_b
46            // AFTER:  _ list_a==list_b (as multisets, or up to permutation)
47            {entrypoint}:
48
49                // read lengths of lists
50                dup 1 dup 1             // _ *list_a *list_b *list_a *list_b
51                call {length_snippet}   // _ *list_a *list_b *list_a len_b
52                swap 1                  // _ *list_a *list_b len_b *list_a
53                call {length_snippet}   // _ *list_a *list_b len_b len_a
54
55                // equate lengths and return early if possible
56                dup 1                   // _ *list_a *list_b len_b len_a len_b
57                eq                      // _ *list_a *list_b len_b (len_a==len_b)
58                push 0 eq               // _ *list_a *list_b len_b (len_a!=len_b)
59
60                // early return if lengths mismatch
61                // otherwise continue
62                push 1 swap 1           // _ *list_a *list_b len_b 1 (len_a!=len_b)
63                skiz call {early_abort_label}
64                skiz call {continue_label}
65
66                // _ (list_a == list_b) (as multisets, or up to permutation)
67                return
68
69            {early_abort_label}:
70                // _ *list_a *list_b len_b 1
71                pop 4
72
73                // push return value (false)
74                push 0 // _ 0
75
76                // ensure `else` branch is not taken
77                push 0
78                return
79
80            {continue_label}:
81                // _ *list_a *list_b len
82
83                // hash list_a
84                dup 2                    // _ *list_a *list_b len *list_a
85                push 1 add               // _ *list_a *list_b len *list_a[0]
86                dup 1                    // _ *list_a *list_b len *list_a[0] len
87                push {Digest::LEN} mul // _ *list_a *list_b len *list_a[0] (len*{Digest::LEN})
88                call {hash_varlen}       // _ *list_a *list_b len da4 da3 da2 da1 da0
89
90                // hash list_b
91                dup 6                            // _ *list_a *list_b len da4 da3 da2 da1 da0 *list_b
92                push 1 add                       // _ *list_a *list_b len *list_b[0]
93                dup 6                            // _ *list_a *list_b len da4 da3 da2 da1 da0 *list_b[0] len
94                push {Digest::LEN} mul         // _ *list_a *list_b len da4 da3 da2 da1 da0 *list_b[0] (len*{Digest::LEN})
95                call {hash_varlen}               // _ *list_a *list_b len da4 da3 da2 da1 da0 db4 db3 db2 db1 db0
96
97                // hash together
98                hash
99                // _ *list_a *list_b len d4 d3 d2 d1 d0
100
101                // Get 2nd challenge
102                push 0
103                push 0
104                push 0
105                push 0
106                push 0
107                dup 9
108                dup 9
109                dup 9
110                dup 9
111                dup 9
112                // _ *list_a *list_b len d4 d3 d2 d1 d0 0 0 0 0 0 d4 d3 d2 d1 d0
113
114                hash
115                // _ *list_a *list_b len d4 d3 d2 d1 d0 e4 e3 e2 e1 e0
116
117                pop 4
118                hint _x0: XFieldElement = stack[3..6]
119                hint x1: XFieldElement = stack[0..3]
120                // _ *list_a *list_b len d4 d3 d2 d1 d0 e4
121                // _ *list_a *list_b len [-x0] [x1] <- rename
122
123                call {running_product_label} // _ *list_a *list_b len [-x0] [x1] [rpb]
124                dup 11                       // _ *list_a *list_b len [-x0] [x1] [rpb] *list_a
125                dup 10                       // _ *list_a *list_b len [-x0] [x1] [rpb] *list_a len
126                dup 10 dup 10 dup 10         // _ *list_a *list_b len [-x0] [x1] [rpb] *list_a len [-x0]
127                dup 10 dup 10 dup 10         // _ *list_a *list_b len [-x0] [x1] [rpb] *list_a len [-x0] [x1]
128                call {running_product_label} // _ *list_a *list_b len [-x0] [x1] [rpb] *list_a len [-x0] [x1] [rpa]
129
130                // test equality
131                dup 11 // _ *list_a *list_b len [-x0] [x1] rpb2 rpb1 rpb0 *list_a len [-x0] [x1] rpa2 rpa1 rpa0 rpb0
132                eq     // _ *list_a *list_b len [-x0] [x1] rpb2 rpb1 rpb0 *list_a len [-x0] [x1] rpa2 rpa1 rpa0==rpb0
133                swap 1 // _ *list_a *list_b len [-x0] [x1] rpb2 rpb1 rpb0 *list_a len [-x0] [x1] rpa2 rpa0==rpb0 rpa1
134                dup 12 // _ *list_a *list_b len [-x0] [x1] rpb2 rpb1 rpb0 *list_a len [-x0] [x1] rpa2 rpa0==rpb0 rpa1 rpb1
135                eq mul // _ *list_a *list_b len [-x0] [x1] rpb2 rpb1 rpb0 *list_a len [-x0] [x1] rpa2 rpa0==rpb0&&rpa1==rpb1
136                swap 1 // _ *list_a *list_b len [-x0] [x1] rpb2 rpb1 rpb0 *list_a len [-x0] [x1] rpa0==rpb0&&rpa1==rpb1 rpa2
137                dup 12 // _ *list_a *list_b len [-x0] [x1] rpb2 rpb1 rpb0 *list_a len [-x0] [x1] rpa0==rpb0&&rpa1==rpb1 rpa2 rpb2
138                eq mul // _ *list_a *list_b len [-x0] [x1] rpb2 rpb1 rpb0 *list_a len [-x0] [x1] rpa0==rpb0&&rpa1==rpb1&&rpa2==rpb2
139
140                // clean up and return
141                swap 14 // _ rpa0==rpb0&&rpa1==rpb1 rpa2 rpb2 *list_b len [-indeterminate] rpb2 rpb1 rpb0 *list_a len [-indeterminate] list_a
142                pop 5 pop 5 pop 4
143                // _ *list_a *list_b len [-x0] (rpa == rpb)
144
145                swap 6
146                pop 5
147                pop 1
148                // _ (rpa0==rpb0&&rpa1==rpb1)
149
150                return
151
152            // BEFORE: _ *list len [-x0] [x1]
153            // AFTER:  _ *list len [-x0] [x1] rp2 rp1 rp0
154            {running_product_label}:
155                // initialize loop
156                dup 7                // _ *list len [-x0] [x1] *list
157                push 1 add           // _ *list len [-x0] [x1] addr
158                dup 7                // _ *list len [-x0] [x1] addr itrs_left
159                push 0 push 0 push 1 // _ *list len [-x0] [x1] addr itrs_left [rp]
160
161                call {running_product_loop_label}
162                // _ *list len [-x0] [x1] addr* 0 [rp]
163
164                // clean up and return
165                swap 2
166                swap 4
167                pop 1
168                swap 2
169                pop 1
170                // _ *list len [-x0] [x1] [rp]
171
172                return
173
174            // INVARIANT: _ *list len [-x0] [x1] addr itrs_left [rp]
175            {running_product_loop_label}:
176                hint running_prod: XFieldElement = stack[0..3]
177                hint itrs_left = stack[3]
178
179                // test termination condition
180                dup 3       // _ *list len [-x0] [x1] addr itrs_left [rp] itrs_left
181                push 0 eq   // _ *list len [-x0] [x1] addr itrs_left [rp] itrs_left==0
182                skiz return // _ *list len [-x0] [x1] addr itrs_left [rp]
183
184                // read two first words
185                dup 4 push {Digest::LEN - 1} add read_mem 2
186                // _ *list len [-x0] [x1] addr itrs_left [rp] m4 m3 (addr + 2)
187
188                swap 7
189                pop 1
190                // _ *list len [-x0] [x1] (addr + 2) itrs_left [rp] m4 m3
191
192                push 0
193                dup 10
194                dup 10
195                dup 10
196                // _ *list len [-x0] [x1] (addr + 2) itrs_left [rp] m4 m3 0 [x1]
197
198                xx_mul
199                // _ *list len [-x0] [x1] (addr + 2) itrs_left [rp] m4' m3' µ
200
201                // Read last three words
202                dup 7
203                read_mem 3
204                push {Digest::LEN + 1} add
205                swap 11
206                pop 1
207                // _ *list len [-x0] [x1] (addr + 5) itrs_left [rp] m4' m3' µ m2 m1 m0
208
209                xx_add
210                // _ *list len [-x0] [x1] (addr + 5) itrs_left [rp] [m']
211
212                // itrs_left -= 1
213                swap 6 push -1 add swap 6         // _ *list len [-x0] [x1] addr' itrs_left' [rp] [m']
214
215                // add x0
216                dup 13 dup 13 dup 13  // _ *list len [-x0] [x1] addr' itrs_left' [rp] [m'] [-x0]
217                xx_add                // _ *list len [-x0] [x1] addr' itrs_left' [rp] [m' - x0]
218
219                // multiply into running product
220                xx_mul                // _ *list len [-x0] [x1] addr' itrs_left' [rp']
221
222                recurse
223        )
224    }
225}
226
#[cfg(test)]
mod tests {
    use num::One;
    use twenty_first::math::other::random_elements;

    use super::*;
    use crate::empty_stack;
    use crate::rust_shadowing_helper_functions;
    use crate::rust_shadowing_helper_functions::list::load_list_with_copy_elements;
    use crate::test_prelude::*;

    impl Function for MultisetEqualityDigests {
        /// Reference implementation: sort both lists and compare them.
        fn rust_shadow(
            &self,
            stack: &mut Vec<BFieldElement>,
            memory: &mut HashMap<BFieldElement, BFieldElement>,
        ) {
            let list_b_pointer = stack.pop().unwrap();
            let list_a_pointer = stack.pop().unwrap();

            let a: Vec<[BFieldElement; Digest::LEN]> =
                load_list_with_copy_elements(list_a_pointer, memory);
            let mut a = a.into_iter().map(Digest::new).collect_vec();
            a.sort_unstable();
            let b: Vec<[BFieldElement; Digest::LEN]> =
                load_list_with_copy_elements(list_b_pointer, memory);
            let mut b = b.into_iter().map(Digest::new).collect_vec();
            b.sort_unstable();

            // equate and push result to stack
            let result = a == b;
            stack.push(BFieldElement::new(result as u64));
        }

        // NOTE(review): `seed` only selects the scenario and the lengths/indices.
        // The helpers draw list contents and pointers from `rand::random` and
        // `random_elements`, which presumably use thread-local randomness. The
        // resulting states are therefore not reproducible from `seed` alone.
        fn pseudorandom_initial_state(
            &self,
            seed: [u8; 32],
            bench_case: Option<BenchmarkCase>,
        ) -> FunctionInitialState {
            match bench_case {
                Some(BenchmarkCase::CommonCase) => self.random_equal_lists(2),
                Some(BenchmarkCase::WorstCase) => self.random_equal_lists(100),
                None => {
                    let mut rng = StdRng::from_seed(seed);
                    let length = rng.random_range(1..50);
                    let index = rng.random_range(0..length);
                    let digest_word_index = rng.random_range(0..Digest::LEN);
                    let another_length = length + rng.random_range(1..10);
                    match rng.random_range(0..=3) {
                        0 => self.random_equal_lists(length),
                        1 => self.random_unequal_lists(length),
                        2 => self.random_unequal_length_lists(length, another_length),
                        3 => {
                            self.random_lists_one_element_flipped(length, index, digest_word_index)
                        }
                        _ => unreachable!(),
                    }
                }
            }
        }

        fn corner_case_initial_states(&self) -> Vec<FunctionInitialState> {
            let short_equal_multisets = (0..15).map(|i| self.random_equal_lists(i)).collect_vec();
            let short_unequal_multisets =
                (0..15).map(|i| self.random_unequal_lists(i)).collect_vec();
            let mut short_lists_one_element_flipped = vec![];
            for length in 1..7 {
                for manipulated_element in 0..length {
                    for manipulated_word in 0..Digest::LEN {
                        short_lists_one_element_flipped.push(
                            self.random_lists_one_element_flipped(
                                length,
                                manipulated_element,
                                manipulated_word,
                            ),
                        );
                    }
                }
            }

            let unequal_lengths = vec![
                self.random_unequal_length_lists(0, 5),
                self.random_unequal_length_lists(0, 1),
                self.random_unequal_length_lists(1, 0),
                self.random_unequal_length_lists(5, 0),
                self.random_unequal_length_lists(1, 2),
                self.random_unequal_length_lists(2, 1),
                self.random_unequal_length_lists(10, 17),
                self.random_unequal_length_lists(21, 0),
            ];

            [
                short_equal_multisets,
                short_unequal_multisets,
                short_lists_one_element_flipped,
                unequal_lengths,
            ]
            .concat()
        }
    }

    impl MultisetEqualityDigests {
        /// Two lists holding the same random digests; `list_b` is sorted, so it
        /// is (generally) a permutation of `list_a`.
        fn random_equal_lists(&self, length: usize) -> FunctionInitialState {
            let list_a: Vec<Digest> = random_elements(length);
            let mut list_b = list_a.clone();
            list_b.sort();
            Self::initial_state(list_a, list_b)
        }

        /// Two independently sampled lists of the same length. For `length == 0`
        /// both lists are empty and hence equal.
        fn random_unequal_lists(&self, length: usize) -> FunctionInitialState {
            Self::initial_state(random_elements(length), random_elements(length))
        }

        /// Two independently sampled lists of the given (possibly different)
        /// lengths.
        fn random_unequal_length_lists(
            &self,
            length_a: usize,
            length_b: usize,
        ) -> FunctionInitialState {
            Self::initial_state(random_elements(length_a), random_elements(length_b))
        }

        /// A permutation of `list_a` in which one word of one digest is
        /// incremented by one.
        ///
        /// # Panics
        ///
        /// Panics if `manipulated_index >= length` or
        /// `manipulated_digest_word_index >= Digest::LEN`.
        fn random_lists_one_element_flipped(
            &self,
            length: usize,
            manipulated_index: usize,
            manipulated_digest_word_index: usize,
        ) -> FunctionInitialState {
            assert!(manipulated_index < length);
            assert!(manipulated_digest_word_index < Digest::LEN);
            let list_a: Vec<Digest> = random_elements(length);
            let mut list_b = list_a.clone();
            list_b.sort();
            list_b[manipulated_index].0[manipulated_digest_word_index] += BFieldElement::one();
            Self::initial_state(list_a, list_b)
        }

        /// Writes both lists into fresh memory and returns the initial state
        /// `_ *list_a *list_b`.
        ///
        /// `list_b` is placed strictly after the memory occupied by `list_a`.
        /// Inserting `list_b` can therefore never overwrite any part of `list_a`.
        fn initial_state(list_a: Vec<Digest>, list_b: Vec<Digest>) -> FunctionInitialState {
            let pointer_a: BFieldElement = rand::random();

            // A list occupies one length word followed by its elements.
            let list_a_footprint = (1 + Digest::LEN * list_a.len()) as u64;
            let gap = u64::from(rand::random::<u32>());
            // Field addition wraps modulo p, so this cannot overflow, and the
            // distance between the two pointers is preserved.
            let pointer_b = pointer_a + BFieldElement::new(list_a_footprint + gap);

            let mut memory: HashMap<BFieldElement, BFieldElement> = HashMap::new();
            rust_shadowing_helper_functions::list::list_insert(pointer_a, list_a, &mut memory);
            rust_shadowing_helper_functions::list::list_insert(pointer_b, list_b, &mut memory);

            let stack = [empty_stack(), vec![pointer_a, pointer_b]].concat();
            FunctionInitialState { stack, memory }
        }
    }

    #[test]
    fn rust_shadow() {
        ShadowedFunction::new(MultisetEqualityDigests).test();
    }
}
411
#[cfg(test)]
mod benches {
    use super::*;
    use crate::test_prelude::*;

    /// Runs the snippet's benchmark cases (common and worst case).
    #[test]
    fn benchmark() {
        let snippet = MultisetEqualityDigests;
        ShadowedFunction::new(snippet).bench();
    }
}