Skip to main content

tasm_lib/list/
multiset_equality_digests.rs

1use triton_vm::prelude::*;
2
3use crate::hashing::algebraic_hasher::hash_varlen::HashVarlen;
4use crate::list::length::Length;
5use crate::prelude::*;
6
/// Decide whether two lists of digests are equal up to permutation, i.e.,
/// whether they contain the same digests with the same multiplicities.
///
/// Lists of different lengths are rejected immediately. Otherwise, both lists
/// are hashed and the resulting digests are hashed together (Fiat-Shamir) to
/// derive challenge values. The snippet then computes one running product per
/// list and reports whether the two products agree. In the future, this
/// implementation may be replaced by one that uses Triton VM's native support
/// for permutation checks instead of Fiat-Shamir and running products.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct MultisetEqualityDigests;
17
18impl BasicSnippet for MultisetEqualityDigests {
19    fn parameters(&self) -> Vec<(DataType, String)> {
20        vec![
21            (DataType::List(Box::new(DataType::Digest)), "a".to_owned()),
22            (DataType::List(Box::new(DataType::Digest)), "b".to_owned()),
23        ]
24    }
25
26    fn return_values(&self) -> Vec<(DataType, String)> {
27        vec![(DataType::Bool, "equal_multisets".to_owned())]
28    }
29
30    fn entrypoint(&self) -> String {
31        "tasmlib_list_multiset_equality_digests".to_owned()
32    }
33
34    fn code(&self, library: &mut Library) -> Vec<LabelledInstruction> {
35        let entrypoint = self.entrypoint();
36        let length_snippet = library.import(Box::new(Length));
37        let hash_varlen = library.import(Box::new(HashVarlen));
38
39        let early_abort_label = format!("{entrypoint}_early_abort");
40        let continue_label = format!("{entrypoint}_continue");
41        let running_product_label = format!("{entrypoint}_running_product");
42        let running_product_loop_label = format!("{entrypoint}_running_product_loop");
43
44        triton_asm!(
45            // BEFORE: _ *list_a *list_b
46            // AFTER:  _ list_a==list_b (as multisets, or up to permutation)
47            {entrypoint}:
48
49                // read lengths of lists
50                dup 1 dup 1             // _ *list_a *list_b *list_a *list_b
51                call {length_snippet}   // _ *list_a *list_b *list_a len_b
52                swap 1                  // _ *list_a *list_b len_b *list_a
53                call {length_snippet}   // _ *list_a *list_b len_b len_a
54
55                // equate lengths and return early if possible
56                dup 1                   // _ *list_a *list_b len_b len_a len_b
57                eq                      // _ *list_a *list_b len_b (len_a==len_b)
58                push 0 eq               // _ *list_a *list_b len_b (len_a!=len_b)
59
60                // early return if lengths mismatch
61                // otherwise continue
62                push 1 swap 1           // _ *list_a *list_b len_b 1 (len_a!=len_b)
63                skiz call {early_abort_label}
64                skiz call {continue_label}
65
66                // _ (list_a == list_b) (as multisets, or up to permutation)
67                return
68
69            {early_abort_label}:
70                // _ *list_a *list_b len_b 1
71                pop 4
72
73                // push return value (false)
74                push 0 // _ 0
75
76                // ensure `else` branch is not taken
77                push 0
78                return
79
80            {continue_label}:
81                // _ *list_a *list_b len
82
83                // hash list_a
84                dup 2                    // _ *list_a *list_b len *list_a
85                push 1 add               // _ *list_a *list_b len *list_a[0]
86                dup 1                    // _ *list_a *list_b len *list_a[0] len
87                push {Digest::LEN} mul // _ *list_a *list_b len *list_a[0] (len*{Digest::LEN})
88                call {hash_varlen}       // _ *list_a *list_b len da4 da3 da2 da1 da0
89
90                // hash list_b
91                dup 6                            // _ *list_a *list_b len da4 da3 da2 da1 da0 *list_b
92                push 1 add                       // _ *list_a *list_b len *list_b[0]
93                dup 6                            // _ *list_a *list_b len da4 da3 da2 da1 da0 *list_b[0] len
94                push {Digest::LEN} mul         // _ *list_a *list_b len da4 da3 da2 da1 da0 *list_b[0] (len*{Digest::LEN})
95                call {hash_varlen}               // _ *list_a *list_b len da4 da3 da2 da1 da0 db4 db3 db2 db1 db0
96
97                // hash together
98                hash
99                // _ *list_a *list_b len d4 d3 d2 d1 d0
100
101                // Get 2nd challenge
102                push 0
103                push 0
104                push 0
105                push 0
106                push 0
107                dup 9
108                dup 9
109                dup 9
110                dup 9
111                dup 9
112                // _ *list_a *list_b len d4 d3 d2 d1 d0 0 0 0 0 0 d4 d3 d2 d1 d0
113
114                hash
115                // _ *list_a *list_b len d4 d3 d2 d1 d0 e4 e3 e2 e1 e0
116
117                pop 4
118                hint _x0: XFieldElement = stack[3..6]
119                hint x1: XFieldElement = stack[0..3]
120                // _ *list_a *list_b len d4 d3 d2 d1 d0 e4
121                // _ *list_a *list_b len [-x0] [x1] <- rename
122
123                call {running_product_label} // _ *list_a *list_b len [-x0] [x1] [rpb]
124                dup 11                       // _ *list_a *list_b len [-x0] [x1] [rpb] *list_a
125                dup 10                       // _ *list_a *list_b len [-x0] [x1] [rpb] *list_a len
126                dup 10 dup 10 dup 10         // _ *list_a *list_b len [-x0] [x1] [rpb] *list_a len [-x0]
127                dup 10 dup 10 dup 10         // _ *list_a *list_b len [-x0] [x1] [rpb] *list_a len [-x0] [x1]
128                call {running_product_label} // _ *list_a *list_b len [-x0] [x1] [rpb] *list_a len [-x0] [x1] [rpa]
129
130                // test equality
131                dup 11 // _ *list_a *list_b len [-x0] [x1] rpb2 rpb1 rpb0 *list_a len [-x0] [x1] rpa2 rpa1 rpa0 rpb0
132                eq     // _ *list_a *list_b len [-x0] [x1] rpb2 rpb1 rpb0 *list_a len [-x0] [x1] rpa2 rpa1 rpa0==rpb0
133                swap 1 // _ *list_a *list_b len [-x0] [x1] rpb2 rpb1 rpb0 *list_a len [-x0] [x1] rpa2 rpa0==rpb0 rpa1
134                dup 12 // _ *list_a *list_b len [-x0] [x1] rpb2 rpb1 rpb0 *list_a len [-x0] [x1] rpa2 rpa0==rpb0 rpa1 rpb1
135                eq mul // _ *list_a *list_b len [-x0] [x1] rpb2 rpb1 rpb0 *list_a len [-x0] [x1] rpa2 rpa0==rpb0&&rpa1==rpb1
136                swap 1 // _ *list_a *list_b len [-x0] [x1] rpb2 rpb1 rpb0 *list_a len [-x0] [x1] rpa0==rpb0&&rpa1==rpb1 rpa2
137                dup 12 // _ *list_a *list_b len [-x0] [x1] rpb2 rpb1 rpb0 *list_a len [-x0] [x1] rpa0==rpb0&&rpa1==rpb1 rpa2 rpb2
138                eq mul // _ *list_a *list_b len [-x0] [x1] rpb2 rpb1 rpb0 *list_a len [-x0] [x1] rpa0==rpb0&&rpa1==rpb1&&rpa2==rpb2
139
140                // clean up and return
141                swap 14 // _ rpa0==rpb0&&rpa1==rpb1 rpa2 rpb2 *list_b len [-indeterminate] rpb2 rpb1 rpb0 *list_a len [-indeterminate] list_a
142                pop 5 pop 5 pop 4
143                // _ *list_a *list_b len [-x0] (rpa == rpb)
144
145                swap 6
146                pop 5
147                pop 1
148                // _ (rpa0==rpb0&&rpa1==rpb1)
149
150                return
151
152            // BEFORE: _ *list len [-x0] [x1]
153            // AFTER:  _ *list len [-x0] [x1] rp2 rp1 rp0
154            {running_product_label}:
155                // initialize loop
156                dup 7                // _ *list len [-x0] [x1] *list
157                push 1 add           // _ *list len [-x0] [x1] addr
158                dup 7                // _ *list len [-x0] [x1] addr itrs_left
159                push 0 push 0 push 1 // _ *list len [-x0] [x1] addr itrs_left [rp]
160
161                call {running_product_loop_label}
162                // _ *list len [-x0] [x1] addr* 0 [rp]
163
164                // clean up and return
165                swap 2
166                swap 4
167                pop 1
168                swap 2
169                pop 1
170                // _ *list len [-x0] [x1] [rp]
171
172                return
173
174            // INVARIANT: _ *list len [-x0] [x1] addr itrs_left [rp]
175            {running_product_loop_label}:
176                hint running_prod: XFieldElement = stack[0..3]
177                hint itrs_left = stack[3]
178
179                // test termination condition
180                dup 3       // _ *list len [-x0] [x1] addr itrs_left [rp] itrs_left
181                push 0 eq   // _ *list len [-x0] [x1] addr itrs_left [rp] itrs_left==0
182                skiz return // _ *list len [-x0] [x1] addr itrs_left [rp]
183
184                // read two first words
185                dup 4 push {Digest::LEN - 1} add read_mem 2
186                // _ *list len [-x0] [x1] addr itrs_left [rp] m4 m3 (addr + 2)
187
188                swap 7
189                pop 1
190                // _ *list len [-x0] [x1] (addr + 2) itrs_left [rp] m4 m3
191
192                push 0
193                dup 10
194                dup 10
195                dup 10
196                // _ *list len [-x0] [x1] (addr + 2) itrs_left [rp] m4 m3 0 [x1]
197
198                xx_mul
199                // _ *list len [-x0] [x1] (addr + 2) itrs_left [rp] m4' m3' µ
200
201                // Read last three words
202                dup 7
203                read_mem 3
204                push {Digest::LEN + 1} add
205                swap 11
206                pop 1
207                // _ *list len [-x0] [x1] (addr + 5) itrs_left [rp] m4' m3' µ m2 m1 m0
208
209                xx_add
210                // _ *list len [-x0] [x1] (addr + 5) itrs_left [rp] [m']
211
212                // itrs_left -= 1
213                swap 6 push -1 add swap 6         // _ *list len [-x0] [x1] addr' itrs_left' [rp] [m']
214
215                // add x0
216                dup 13 dup 13 dup 13  // _ *list len [-x0] [x1] addr' itrs_left' [rp] [m'] [-x0]
217                xx_add                // _ *list len [-x0] [x1] addr' itrs_left' [rp] [m' - x0]
218
219                // multiply into running product
220                xx_mul                // _ *list len [-x0] [x1] addr' itrs_left' [rp']
221
222                recurse
223        )
224    }
225}
226
#[cfg(test)]
mod tests {
    use num::One;
    use twenty_first::math::other::random_elements;

    use super::*;
    use crate::empty_stack;
    use crate::rust_shadowing_helper_functions;
    use crate::rust_shadowing_helper_functions::list::load_list_with_copy_elements;
    use crate::test_prelude::*;

    impl Function for MultisetEqualityDigests {
        /// Reference implementation: two lists are equal as multisets if and only
        /// if their sorted versions are equal.
        fn rust_shadow(
            &self,
            stack: &mut Vec<BFieldElement>,
            memory: &mut HashMap<BFieldElement, BFieldElement>,
        ) -> Result<(), RustShadowError> {
            let list_b_pointer = stack.pop().ok_or(RustShadowError::StackUnderflow)?;
            let list_a_pointer = stack.pop().ok_or(RustShadowError::StackUnderflow)?;

            let a = load_list_with_copy_elements::<{ Digest::LEN }>(list_a_pointer, memory)?;
            let mut a = a.into_iter().map(Digest::new).collect_vec();
            a.sort_unstable();
            let b = load_list_with_copy_elements::<{ Digest::LEN }>(list_b_pointer, memory)?;
            let mut b = b.into_iter().map(Digest::new).collect_vec();
            b.sort_unstable();

            // equate and push result to stack
            let result = a == b;
            stack.push(BFieldElement::new(result as u64));
            Ok(())
        }

        fn pseudorandom_initial_state(
            &self,
            seed: [u8; 32],
            bench_case: Option<BenchmarkCase>,
        ) -> FunctionInitialState {
            match bench_case {
                Some(BenchmarkCase::CommonCase) => self.random_equal_lists(2),
                Some(BenchmarkCase::WorstCase) => self.random_equal_lists(100),
                None => {
                    let mut rng = StdRng::from_seed(seed);
                    let length = rng.random_range(1..50);
                    let index = rng.random_range(0..length);
                    let digest_word_index = rng.random_range(0..Digest::LEN);
                    let another_length = length + rng.random_range(1..10);
                    match rng.random_range(0..=3) {
                        0 => self.random_equal_lists(length),
                        1 => self.random_unequal_lists(length),
                        2 => self.random_unequal_length_lists(length, another_length),
                        3 => {
                            self.random_lists_one_element_flipped(length, index, digest_word_index)
                        }
                        _ => unreachable!(),
                    }
                }
            }
        }

        fn corner_case_initial_states(&self) -> Vec<FunctionInitialState> {
            let short_equal_multisets = (0..15).map(|i| self.random_equal_lists(i)).collect_vec();
            let short_unequal_multisets =
                (0..15).map(|i| self.random_unequal_lists(i)).collect_vec();
            let mut short_lists_one_element_flipped = vec![];
            for length in 1..7 {
                for manipulated_element in 0..length {
                    for manipulated_word in 0..Digest::LEN {
                        short_lists_one_element_flipped.push(
                            self.random_lists_one_element_flipped(
                                length,
                                manipulated_element,
                                manipulated_word,
                            ),
                        );
                    }
                }
            }

            let unequal_lengths = vec![
                self.random_unequal_length_lists(0, 5),
                self.random_unequal_length_lists(0, 1),
                self.random_unequal_length_lists(1, 0),
                self.random_unequal_length_lists(5, 0),
                self.random_unequal_length_lists(1, 2),
                self.random_unequal_length_lists(2, 1),
                self.random_unequal_length_lists(10, 17),
                self.random_unequal_length_lists(21, 0),
            ];

            [
                short_equal_multisets,
                short_unequal_multisets,
                short_lists_one_element_flipped,
                unequal_lengths,
            ]
            .concat()
        }
    }

    impl MultisetEqualityDigests {
        /// Two lists of `length` random digests where `b` is a sorted copy of `a`,
        /// i.e., the lists are equal as multisets.
        fn random_equal_lists(&self, length: usize) -> FunctionInitialState {
            let list_a: Vec<Digest> = random_elements(length);
            let mut list_b = list_a.clone();
            list_b.sort();
            Self::initial_state(list_a, list_b)
        }

        /// Two independently sampled lists of `length` random digests. For
        /// `length == 0` both lists are empty and therefore equal.
        fn random_unequal_lists(&self, length: usize) -> FunctionInitialState {
            let list_a: Vec<Digest> = random_elements(length);
            let list_b: Vec<Digest> = random_elements(length);
            Self::initial_state(list_a, list_b)
        }

        /// Two independently sampled lists of (generally) different lengths.
        fn random_unequal_length_lists(
            &self,
            length_a: usize,
            length_b: usize,
        ) -> FunctionInitialState {
            let list_a: Vec<Digest> = random_elements(length_a);
            let list_b: Vec<Digest> = random_elements(length_b);
            Self::initial_state(list_a, list_b)
        }

        /// Like [`Self::random_equal_lists`], but one word of one digest in `b`
        /// is incremented by one, so the lists are no longer equal as multisets.
        ///
        /// # Panics
        ///
        /// Panics if `manipulated_index >= length` or
        /// `manipulated_digest_word_index >= Digest::LEN`.
        fn random_lists_one_element_flipped(
            &self,
            length: usize,
            manipulated_index: usize,
            manipulated_digest_word_index: usize,
        ) -> FunctionInitialState {
            assert!(manipulated_index < length);
            assert!(manipulated_digest_word_index < Digest::LEN);
            let list_a: Vec<Digest> = random_elements(length);
            let mut list_b = list_a.clone();
            list_b.sort();
            list_b[manipulated_index].0[manipulated_digest_word_index] += BFieldElement::one();
            Self::initial_state(list_a, list_b)
        }

        /// Store `list_a` and `list_b` in fresh memory at random addresses and put
        /// both list pointers on the stack, in the order the snippet expects.
        ///
        /// A list occupies one length word plus `Digest::LEN` words per element.
        /// `list_b` is always placed strictly past the end of `list_a`, so the two
        /// lists can never overlap and overwrite each other.
        fn initial_state(list_a: Vec<Digest>, list_b: Vec<Digest>) -> FunctionInitialState {
            let pointer_a: BFieldElement = rand::random();
            let list_a_size_in_words = (1 + Digest::LEN * list_a.len()) as u64;
            let gap = list_a_size_in_words + u64::from(rand::random::<u32>());
            let pointer_b = pointer_a + BFieldElement::new(gap);

            let mut memory: HashMap<BFieldElement, BFieldElement> = HashMap::new();

            rust_shadowing_helper_functions::list::list_insert(pointer_a, list_a, &mut memory);
            rust_shadowing_helper_functions::list::list_insert(pointer_b, list_b, &mut memory);

            let stack = [empty_stack(), vec![pointer_a, pointer_b]].concat();
            FunctionInitialState { stack, memory }
        }
    }

    #[macro_rules_attr::apply(test)]
    fn rust_shadow() {
        ShadowedFunction::new(MultisetEqualityDigests).test();
    }
}
410
#[cfg(test)]
mod benches {
    use super::*;
    use crate::test_prelude::*;

    /// Run the benchmark harness for [`MultisetEqualityDigests`].
    #[macro_rules_attr::apply(test)]
    fn benchmark() {
        ShadowedFunction::new(MultisetEqualityDigests::default()).bench();
    }
}