tasm_lib/list/multiset_equality_digests.rs
1use triton_vm::prelude::*;
2
3use crate::hashing::algebraic_hasher::hash_varlen::HashVarlen;
4use crate::list::length::Length;
5use crate::prelude::*;
6
/// Decide whether two lists of digests are permutations of each other, i.e.,
/// whether they are equal as multisets.
///
/// Lists of different lengths are rejected right away. Otherwise, both lists
/// are hashed in order to derive challenge indeterminates, and a running
/// product is computed over each list. The result is the equality of the two
/// running products.
///
/// In the future, this implementation may be replaced by one that uses
/// Triton VM's native support for permutation checks instead of Fiat-Shamir
/// and running products.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub struct MultisetEqualityDigests;
17
18impl BasicSnippet for MultisetEqualityDigests {
19 fn inputs(&self) -> Vec<(DataType, String)> {
20 vec![
21 (DataType::List(Box::new(DataType::Digest)), "a".to_owned()),
22 (DataType::List(Box::new(DataType::Digest)), "b".to_owned()),
23 ]
24 }
25
26 fn outputs(&self) -> Vec<(DataType, String)> {
27 vec![(DataType::Bool, "equal_multisets".to_owned())]
28 }
29
30 fn entrypoint(&self) -> String {
31 "tasmlib_list_multiset_equality_digests".to_owned()
32 }
33
34 fn code(&self, library: &mut Library) -> Vec<LabelledInstruction> {
35 let entrypoint = self.entrypoint();
36 let length_snippet = library.import(Box::new(Length));
37 let hash_varlen = library.import(Box::new(HashVarlen));
38
39 let early_abort_label = format!("{entrypoint}_early_abort");
40 let continue_label = format!("{entrypoint}_continue");
41 let running_product_label = format!("{entrypoint}_running_product");
42 let running_product_loop_label = format!("{entrypoint}_running_product_loop");
43
44 triton_asm!(
45 // BEFORE: _ *list_a *list_b
46 // AFTER: _ list_a==list_b (as multisets, or up to permutation)
47 {entrypoint}:
48
49 // read lengths of lists
50 dup 1 dup 1 // _ *list_a *list_b *list_a *list_b
51 call {length_snippet} // _ *list_a *list_b *list_a len_b
52 swap 1 // _ *list_a *list_b len_b *list_a
53 call {length_snippet} // _ *list_a *list_b len_b len_a
54
55 // equate lengths and return early if possible
56 dup 1 // _ *list_a *list_b len_b len_a len_b
57 eq // _ *list_a *list_b len_b (len_a==len_b)
58 push 0 eq // _ *list_a *list_b len_b (len_a!=len_b)
59
60 // early return if lengths mismatch
61 // otherwise continue
62 push 1 swap 1 // _ *list_a *list_b len_b 1 (len_a!=len_b)
63 skiz call {early_abort_label}
64 skiz call {continue_label}
65
66 // _ (list_a == list_b) (as multisets, or up to permutation)
67 return
68
69 {early_abort_label}:
70 // _ *list_a *list_b len_b 1
71 pop 4
72
73 // push return value (false)
74 push 0 // _ 0
75
76 // ensure `else` branch is not taken
77 push 0
78 return
79
80 {continue_label}:
81 // _ *list_a *list_b len
82
83 // hash list_a
84 dup 2 // _ *list_a *list_b len *list_a
85 push 1 add // _ *list_a *list_b len *list_a[0]
86 dup 1 // _ *list_a *list_b len *list_a[0] len
87 push {Digest::LEN} mul // _ *list_a *list_b len *list_a[0] (len*{Digest::LEN})
88 call {hash_varlen} // _ *list_a *list_b len da4 da3 da2 da1 da0
89
90 // hash list_b
91 dup 6 // _ *list_a *list_b len da4 da3 da2 da1 da0 *list_b
92 push 1 add // _ *list_a *list_b len *list_b[0]
93 dup 6 // _ *list_a *list_b len da4 da3 da2 da1 da0 *list_b[0] len
94 push {Digest::LEN} mul // _ *list_a *list_b len da4 da3 da2 da1 da0 *list_b[0] (len*{Digest::LEN})
95 call {hash_varlen} // _ *list_a *list_b len da4 da3 da2 da1 da0 db4 db3 db2 db1 db0
96
97 // hash together
98 hash
99 // _ *list_a *list_b len d4 d3 d2 d1 d0
100
101 // Get 2nd challenge
102 push 0
103 push 0
104 push 0
105 push 0
106 push 0
107 dup 9
108 dup 9
109 dup 9
110 dup 9
111 dup 9
112 // _ *list_a *list_b len d4 d3 d2 d1 d0 0 0 0 0 0 d4 d3 d2 d1 d0
113
114 hash
115 // _ *list_a *list_b len d4 d3 d2 d1 d0 e4 e3 e2 e1 e0
116
117 pop 4
118 hint _x0: XFieldElement = stack[3..6]
119 hint x1: XFieldElement = stack[0..3]
120 // _ *list_a *list_b len d4 d3 d2 d1 d0 e4
121 // _ *list_a *list_b len [-x0] [x1] <- rename
122
123 call {running_product_label} // _ *list_a *list_b len [-x0] [x1] [rpb]
124 dup 11 // _ *list_a *list_b len [-x0] [x1] [rpb] *list_a
125 dup 10 // _ *list_a *list_b len [-x0] [x1] [rpb] *list_a len
126 dup 10 dup 10 dup 10 // _ *list_a *list_b len [-x0] [x1] [rpb] *list_a len [-x0]
127 dup 10 dup 10 dup 10 // _ *list_a *list_b len [-x0] [x1] [rpb] *list_a len [-x0] [x1]
128 call {running_product_label} // _ *list_a *list_b len [-x0] [x1] [rpb] *list_a len [-x0] [x1] [rpa]
129
130 // test equality
131 dup 11 // _ *list_a *list_b len [-x0] [x1] rpb2 rpb1 rpb0 *list_a len [-x0] [x1] rpa2 rpa1 rpa0 rpb0
132 eq // _ *list_a *list_b len [-x0] [x1] rpb2 rpb1 rpb0 *list_a len [-x0] [x1] rpa2 rpa1 rpa0==rpb0
133 swap 1 // _ *list_a *list_b len [-x0] [x1] rpb2 rpb1 rpb0 *list_a len [-x0] [x1] rpa2 rpa0==rpb0 rpa1
134 dup 12 // _ *list_a *list_b len [-x0] [x1] rpb2 rpb1 rpb0 *list_a len [-x0] [x1] rpa2 rpa0==rpb0 rpa1 rpb1
135 eq mul // _ *list_a *list_b len [-x0] [x1] rpb2 rpb1 rpb0 *list_a len [-x0] [x1] rpa2 rpa0==rpb0&&rpa1==rpb1
136 swap 1 // _ *list_a *list_b len [-x0] [x1] rpb2 rpb1 rpb0 *list_a len [-x0] [x1] rpa0==rpb0&&rpa1==rpb1 rpa2
137 dup 12 // _ *list_a *list_b len [-x0] [x1] rpb2 rpb1 rpb0 *list_a len [-x0] [x1] rpa0==rpb0&&rpa1==rpb1 rpa2 rpb2
138 eq mul // _ *list_a *list_b len [-x0] [x1] rpb2 rpb1 rpb0 *list_a len [-x0] [x1] rpa0==rpb0&&rpa1==rpb1&&rpa2==rpb2
139
140 // clean up and return
141 swap 14 // _ rpa0==rpb0&&rpa1==rpb1 rpa2 rpb2 *list_b len [-indeterminate] rpb2 rpb1 rpb0 *list_a len [-indeterminate] list_a
142 pop 5 pop 5 pop 4
143 // _ *list_a *list_b len [-x0] (rpa == rpb)
144
145 swap 6
146 pop 5
147 pop 1
148 // _ (rpa0==rpb0&&rpa1==rpb1)
149
150 return
151
152 // BEFORE: _ *list len [-x0] [x1]
153 // AFTER: _ *list len [-x0] [x1] rp2 rp1 rp0
154 {running_product_label}:
155 // initialize loop
156 dup 7 // _ *list len [-x0] [x1] *list
157 push 1 add // _ *list len [-x0] [x1] addr
158 dup 7 // _ *list len [-x0] [x1] addr itrs_left
159 push 0 push 0 push 1 // _ *list len [-x0] [x1] addr itrs_left [rp]
160
161 call {running_product_loop_label}
162 // _ *list len [-x0] [x1] addr* 0 [rp]
163
164 // clean up and return
165 swap 2
166 swap 4
167 pop 1
168 swap 2
169 pop 1
170 // _ *list len [-x0] [x1] [rp]
171
172 return
173
174 // INVARIANT: _ *list len [-x0] [x1] addr itrs_left [rp]
175 {running_product_loop_label}:
176 hint running_prod: XFieldElement = stack[0..3]
177 hint itrs_left = stack[3]
178
179 // test termination condition
180 dup 3 // _ *list len [-x0] [x1] addr itrs_left [rp] itrs_left
181 push 0 eq // _ *list len [-x0] [x1] addr itrs_left [rp] itrs_left==0
182 skiz return // _ *list len [-x0] [x1] addr itrs_left [rp]
183
184 // read two first words
185 dup 4 push {Digest::LEN - 1} add read_mem 2
186 // _ *list len [-x0] [x1] addr itrs_left [rp] m4 m3 (addr + 2)
187
188 swap 7
189 pop 1
190 // _ *list len [-x0] [x1] (addr + 2) itrs_left [rp] m4 m3
191
192 push 0
193 dup 10
194 dup 10
195 dup 10
196 // _ *list len [-x0] [x1] (addr + 2) itrs_left [rp] m4 m3 0 [x1]
197
198 xx_mul
199 // _ *list len [-x0] [x1] (addr + 2) itrs_left [rp] m4' m3' µ
200
201 // Read last three words
202 dup 7
203 read_mem 3
204 push {Digest::LEN + 1} add
205 swap 11
206 pop 1
207 // _ *list len [-x0] [x1] (addr + 5) itrs_left [rp] m4' m3' µ m2 m1 m0
208
209 xx_add
210 // _ *list len [-x0] [x1] (addr + 5) itrs_left [rp] [m']
211
212 // itrs_left -= 1
213 swap 6 push -1 add swap 6 // _ *list len [-x0] [x1] addr' itrs_left' [rp] [m']
214
215 // add x0
216 dup 13 dup 13 dup 13 // _ *list len [-x0] [x1] addr' itrs_left' [rp] [m'] [-x0]
217 xx_add // _ *list len [-x0] [x1] addr' itrs_left' [rp] [m' - x0]
218
219 // multiply into running product
220 xx_mul // _ *list len [-x0] [x1] addr' itrs_left' [rp']
221
222 recurse
223 )
224 }
225}
226
227#[cfg(test)]
228mod tests {
229 use num::One;
230 use twenty_first::math::other::random_elements;
231
232 use super::*;
233 use crate::empty_stack;
234 use crate::rust_shadowing_helper_functions;
235 use crate::rust_shadowing_helper_functions::list::load_list_with_copy_elements;
236 use crate::test_prelude::*;
237
238 impl Function for MultisetEqualityDigests {
239 fn rust_shadow(
240 &self,
241 stack: &mut Vec<BFieldElement>,
242 memory: &mut HashMap<BFieldElement, BFieldElement>,
243 ) {
244 let list_b_pointer = stack.pop().unwrap();
245 let list_a_pointer = stack.pop().unwrap();
246
247 let a: Vec<[BFieldElement; Digest::LEN]> =
248 load_list_with_copy_elements(list_a_pointer, memory);
249 let mut a = a.into_iter().map(Digest::new).collect_vec();
250 a.sort_unstable();
251 let b: Vec<[BFieldElement; Digest::LEN]> =
252 load_list_with_copy_elements(list_b_pointer, memory);
253 let mut b = b.into_iter().map(Digest::new).collect_vec();
254 b.sort_unstable();
255
256 // equate and push result to stack
257 let result = a == b;
258 stack.push(BFieldElement::new(result as u64));
259 }
260
261 fn pseudorandom_initial_state(
262 &self,
263 seed: [u8; 32],
264 bench_case: Option<BenchmarkCase>,
265 ) -> FunctionInitialState {
266 match bench_case {
267 Some(BenchmarkCase::CommonCase) => self.random_equal_lists(2),
268 Some(BenchmarkCase::WorstCase) => self.random_equal_lists(100),
269 None => {
270 let mut rng = StdRng::from_seed(seed);
271 let length = rng.random_range(1..50);
272 let index = rng.random_range(0..length);
273 let digest_word_index = rng.random_range(0..Digest::LEN);
274 let another_length = length + rng.random_range(1..10);
275 match rng.random_range(0..=3) {
276 0 => self.random_equal_lists(length),
277 1 => self.random_unequal_lists(length),
278 2 => self.random_unequal_length_lists(length, another_length),
279 3 => {
280 self.random_lists_one_element_flipped(length, index, digest_word_index)
281 }
282 _ => unreachable!(),
283 }
284 }
285 }
286 }
287
288 fn corner_case_initial_states(&self) -> Vec<FunctionInitialState> {
289 let short_equal_multisets = (0..15).map(|i| self.random_equal_lists(i)).collect_vec();
290 let short_unequal_multisets =
291 (0..15).map(|i| self.random_unequal_lists(i)).collect_vec();
292 let mut short_lists_one_element_flipped = vec![];
293 for length in 1..7 {
294 for manipulated_element in 0..length {
295 for manipulated_word in 0..Digest::LEN {
296 short_lists_one_element_flipped.push(
297 self.random_lists_one_element_flipped(
298 length,
299 manipulated_element,
300 manipulated_word,
301 ),
302 );
303 }
304 }
305 }
306
307 let unequal_lengths = vec![
308 self.random_unequal_length_lists(0, 5),
309 self.random_unequal_length_lists(0, 1),
310 self.random_unequal_length_lists(1, 0),
311 self.random_unequal_length_lists(5, 0),
312 self.random_unequal_length_lists(1, 2),
313 self.random_unequal_length_lists(2, 1),
314 self.random_unequal_length_lists(10, 17),
315 self.random_unequal_length_lists(21, 0),
316 ];
317
318 [
319 short_equal_multisets,
320 short_unequal_multisets,
321 short_lists_one_element_flipped,
322 unequal_lengths,
323 ]
324 .concat()
325 }
326 }
327
328 impl MultisetEqualityDigests {
329 fn random_equal_lists(&self, length: usize) -> FunctionInitialState {
330 let list_a: Vec<Digest> = random_elements(length);
331 let mut list_b = list_a.clone();
332 list_b.sort();
333 let pointer_a: BFieldElement = rand::random();
334 let pointer_b: BFieldElement =
335 BFieldElement::new(pointer_a.value() + rand::random::<u32>() as u64);
336
337 let mut memory: HashMap<BFieldElement, BFieldElement> = HashMap::new();
338
339 rust_shadowing_helper_functions::list::list_insert(pointer_a, list_a, &mut memory);
340 rust_shadowing_helper_functions::list::list_insert(pointer_b, list_b, &mut memory);
341
342 let stack = [empty_stack(), vec![pointer_a, pointer_b]].concat();
343 FunctionInitialState { stack, memory }
344 }
345
346 fn random_unequal_lists(&self, length: usize) -> FunctionInitialState {
347 let list_a: Vec<Digest> = random_elements(length);
348 let list_b: Vec<Digest> = random_elements(length);
349 let pointer_a: BFieldElement = rand::random();
350 let pointer_b: BFieldElement =
351 BFieldElement::new(pointer_a.value() + rand::random::<u32>() as u64);
352 let mut memory: HashMap<BFieldElement, BFieldElement> = HashMap::new();
353
354 rust_shadowing_helper_functions::list::list_insert(pointer_a, list_a, &mut memory);
355 rust_shadowing_helper_functions::list::list_insert(pointer_b, list_b, &mut memory);
356
357 let stack = [empty_stack(), vec![pointer_a, pointer_b]].concat();
358 FunctionInitialState { stack, memory }
359 }
360
361 fn random_unequal_length_lists(
362 &self,
363 length_a: usize,
364 length_b: usize,
365 ) -> FunctionInitialState {
366 let list_a: Vec<Digest> = random_elements(length_a);
367 let list_b: Vec<Digest> = random_elements(length_b);
368 let pointer_a: BFieldElement = rand::random();
369 let pointer_b: BFieldElement =
370 BFieldElement::new(pointer_a.value() + rand::random::<u32>() as u64);
371 let mut memory: HashMap<BFieldElement, BFieldElement> = HashMap::new();
372
373 rust_shadowing_helper_functions::list::list_insert(pointer_a, list_a, &mut memory);
374 rust_shadowing_helper_functions::list::list_insert(pointer_b, list_b, &mut memory);
375
376 let stack = [empty_stack(), vec![pointer_a, pointer_b]].concat();
377 FunctionInitialState { stack, memory }
378 }
379
380 fn random_lists_one_element_flipped(
381 &self,
382 length: usize,
383 manipulated_index: usize,
384 manipulated_digest_word_index: usize,
385 ) -> FunctionInitialState {
386 assert!(manipulated_index < length);
387 assert!(manipulated_digest_word_index < Digest::LEN);
388 let list_a: Vec<Digest> = random_elements(length);
389 let mut list_b = list_a.clone();
390 list_b.sort();
391 list_b[manipulated_index].0[manipulated_digest_word_index] += BFieldElement::one();
392 let pointer_a: BFieldElement = rand::random();
393 let pointer_b: BFieldElement =
394 BFieldElement::new(pointer_a.value() + rand::random::<u32>() as u64);
395
396 let mut memory: HashMap<BFieldElement, BFieldElement> = HashMap::new();
397
398 rust_shadowing_helper_functions::list::list_insert(pointer_a, list_a, &mut memory);
399 rust_shadowing_helper_functions::list::list_insert(pointer_b, list_b, &mut memory);
400
401 let stack = [empty_stack(), vec![pointer_a, pointer_b]].concat();
402 FunctionInitialState { stack, memory }
403 }
404 }
405
/// Run the shadowed-function test for the snippet.
#[test]
fn rust_shadow() {
    let snippet = MultisetEqualityDigests;
    ShadowedFunction::new(snippet).test();
}
410}
411
412#[cfg(test)]
413mod benches {
414 use super::*;
415 use crate::test_prelude::*;
416
/// Run the snippet's benchmark.
#[test]
fn benchmark() {
    let snippet = MultisetEqualityDigests;
    ShadowedFunction::new(snippet).bench();
}
421}