tasm_lib/list/multiset_equality_digests.rs
1use triton_vm::prelude::*;
2
3use crate::hashing::algebraic_hasher::hash_varlen::HashVarlen;
4use crate::list::length::Length;
5use crate::prelude::*;
6
/// Decide whether two lists contain the same elements with the same
/// multiplicities, i.e., whether they are equal up to permutation.
///
/// Both lists are lists of digests. The snippet hashes the lists to derive
/// challenges (Fiat-Shamir), computes one running product per list over
/// those challenges, and compares the two products. In the future, this
/// implementation may be replaced by one that uses Triton VM's native
/// support for permutation checks instead of Fiat-Shamir and running
/// products.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct MultisetEqualityDigests;
17
18impl BasicSnippet for MultisetEqualityDigests {
19 fn parameters(&self) -> Vec<(DataType, String)> {
20 vec![
21 (DataType::List(Box::new(DataType::Digest)), "a".to_owned()),
22 (DataType::List(Box::new(DataType::Digest)), "b".to_owned()),
23 ]
24 }
25
26 fn return_values(&self) -> Vec<(DataType, String)> {
27 vec![(DataType::Bool, "equal_multisets".to_owned())]
28 }
29
30 fn entrypoint(&self) -> String {
31 "tasmlib_list_multiset_equality_digests".to_owned()
32 }
33
34 fn code(&self, library: &mut Library) -> Vec<LabelledInstruction> {
35 let entrypoint = self.entrypoint();
36 let length_snippet = library.import(Box::new(Length));
37 let hash_varlen = library.import(Box::new(HashVarlen));
38
39 let early_abort_label = format!("{entrypoint}_early_abort");
40 let continue_label = format!("{entrypoint}_continue");
41 let running_product_label = format!("{entrypoint}_running_product");
42 let running_product_loop_label = format!("{entrypoint}_running_product_loop");
43
44 triton_asm!(
45 // BEFORE: _ *list_a *list_b
46 // AFTER: _ list_a==list_b (as multisets, or up to permutation)
47 {entrypoint}:
48
49 // read lengths of lists
50 dup 1 dup 1 // _ *list_a *list_b *list_a *list_b
51 call {length_snippet} // _ *list_a *list_b *list_a len_b
52 swap 1 // _ *list_a *list_b len_b *list_a
53 call {length_snippet} // _ *list_a *list_b len_b len_a
54
55 // equate lengths and return early if possible
56 dup 1 // _ *list_a *list_b len_b len_a len_b
57 eq // _ *list_a *list_b len_b (len_a==len_b)
58 push 0 eq // _ *list_a *list_b len_b (len_a!=len_b)
59
60 // early return if lengths mismatch
61 // otherwise continue
62 push 1 swap 1 // _ *list_a *list_b len_b 1 (len_a!=len_b)
63 skiz call {early_abort_label}
64 skiz call {continue_label}
65
66 // _ (list_a == list_b) (as multisets, or up to permutation)
67 return
68
69 {early_abort_label}:
70 // _ *list_a *list_b len_b 1
71 pop 4
72
73 // push return value (false)
74 push 0 // _ 0
75
76 // ensure `else` branch is not taken
77 push 0
78 return
79
80 {continue_label}:
81 // _ *list_a *list_b len
82
83 // hash list_a
84 dup 2 // _ *list_a *list_b len *list_a
85 push 1 add // _ *list_a *list_b len *list_a[0]
86 dup 1 // _ *list_a *list_b len *list_a[0] len
87 push {Digest::LEN} mul // _ *list_a *list_b len *list_a[0] (len*{Digest::LEN})
88 call {hash_varlen} // _ *list_a *list_b len da4 da3 da2 da1 da0
89
90 // hash list_b
91 dup 6 // _ *list_a *list_b len da4 da3 da2 da1 da0 *list_b
92 push 1 add // _ *list_a *list_b len *list_b[0]
93 dup 6 // _ *list_a *list_b len da4 da3 da2 da1 da0 *list_b[0] len
94 push {Digest::LEN} mul // _ *list_a *list_b len da4 da3 da2 da1 da0 *list_b[0] (len*{Digest::LEN})
95 call {hash_varlen} // _ *list_a *list_b len da4 da3 da2 da1 da0 db4 db3 db2 db1 db0
96
97 // hash together
98 hash
99 // _ *list_a *list_b len d4 d3 d2 d1 d0
100
101 // Get 2nd challenge
102 push 0
103 push 0
104 push 0
105 push 0
106 push 0
107 dup 9
108 dup 9
109 dup 9
110 dup 9
111 dup 9
112 // _ *list_a *list_b len d4 d3 d2 d1 d0 0 0 0 0 0 d4 d3 d2 d1 d0
113
114 hash
115 // _ *list_a *list_b len d4 d3 d2 d1 d0 e4 e3 e2 e1 e0
116
117 pop 4
118 hint _x0: XFieldElement = stack[3..6]
119 hint x1: XFieldElement = stack[0..3]
120 // _ *list_a *list_b len d4 d3 d2 d1 d0 e4
121 // _ *list_a *list_b len [-x0] [x1] <- rename
122
123 call {running_product_label} // _ *list_a *list_b len [-x0] [x1] [rpb]
124 dup 11 // _ *list_a *list_b len [-x0] [x1] [rpb] *list_a
125 dup 10 // _ *list_a *list_b len [-x0] [x1] [rpb] *list_a len
126 dup 10 dup 10 dup 10 // _ *list_a *list_b len [-x0] [x1] [rpb] *list_a len [-x0]
127 dup 10 dup 10 dup 10 // _ *list_a *list_b len [-x0] [x1] [rpb] *list_a len [-x0] [x1]
128 call {running_product_label} // _ *list_a *list_b len [-x0] [x1] [rpb] *list_a len [-x0] [x1] [rpa]
129
130 // test equality
131 dup 11 // _ *list_a *list_b len [-x0] [x1] rpb2 rpb1 rpb0 *list_a len [-x0] [x1] rpa2 rpa1 rpa0 rpb0
132 eq // _ *list_a *list_b len [-x0] [x1] rpb2 rpb1 rpb0 *list_a len [-x0] [x1] rpa2 rpa1 rpa0==rpb0
133 swap 1 // _ *list_a *list_b len [-x0] [x1] rpb2 rpb1 rpb0 *list_a len [-x0] [x1] rpa2 rpa0==rpb0 rpa1
134 dup 12 // _ *list_a *list_b len [-x0] [x1] rpb2 rpb1 rpb0 *list_a len [-x0] [x1] rpa2 rpa0==rpb0 rpa1 rpb1
135 eq mul // _ *list_a *list_b len [-x0] [x1] rpb2 rpb1 rpb0 *list_a len [-x0] [x1] rpa2 rpa0==rpb0&&rpa1==rpb1
136 swap 1 // _ *list_a *list_b len [-x0] [x1] rpb2 rpb1 rpb0 *list_a len [-x0] [x1] rpa0==rpb0&&rpa1==rpb1 rpa2
137 dup 12 // _ *list_a *list_b len [-x0] [x1] rpb2 rpb1 rpb0 *list_a len [-x0] [x1] rpa0==rpb0&&rpa1==rpb1 rpa2 rpb2
138 eq mul // _ *list_a *list_b len [-x0] [x1] rpb2 rpb1 rpb0 *list_a len [-x0] [x1] rpa0==rpb0&&rpa1==rpb1&&rpa2==rpb2
139
140 // clean up and return
141 swap 14 // _ rpa0==rpb0&&rpa1==rpb1 rpa2 rpb2 *list_b len [-indeterminate] rpb2 rpb1 rpb0 *list_a len [-indeterminate] list_a
142 pop 5 pop 5 pop 4
143 // _ *list_a *list_b len [-x0] (rpa == rpb)
144
145 swap 6
146 pop 5
147 pop 1
148 // _ (rpa0==rpb0&&rpa1==rpb1)
149
150 return
151
152 // BEFORE: _ *list len [-x0] [x1]
153 // AFTER: _ *list len [-x0] [x1] rp2 rp1 rp0
154 {running_product_label}:
155 // initialize loop
156 dup 7 // _ *list len [-x0] [x1] *list
157 push 1 add // _ *list len [-x0] [x1] addr
158 dup 7 // _ *list len [-x0] [x1] addr itrs_left
159 push 0 push 0 push 1 // _ *list len [-x0] [x1] addr itrs_left [rp]
160
161 call {running_product_loop_label}
162 // _ *list len [-x0] [x1] addr* 0 [rp]
163
164 // clean up and return
165 swap 2
166 swap 4
167 pop 1
168 swap 2
169 pop 1
170 // _ *list len [-x0] [x1] [rp]
171
172 return
173
174 // INVARIANT: _ *list len [-x0] [x1] addr itrs_left [rp]
175 {running_product_loop_label}:
176 hint running_prod: XFieldElement = stack[0..3]
177 hint itrs_left = stack[3]
178
179 // test termination condition
180 dup 3 // _ *list len [-x0] [x1] addr itrs_left [rp] itrs_left
181 push 0 eq // _ *list len [-x0] [x1] addr itrs_left [rp] itrs_left==0
182 skiz return // _ *list len [-x0] [x1] addr itrs_left [rp]
183
184 // read two first words
185 dup 4 push {Digest::LEN - 1} add read_mem 2
186 // _ *list len [-x0] [x1] addr itrs_left [rp] m4 m3 (addr + 2)
187
188 swap 7
189 pop 1
190 // _ *list len [-x0] [x1] (addr + 2) itrs_left [rp] m4 m3
191
192 push 0
193 dup 10
194 dup 10
195 dup 10
196 // _ *list len [-x0] [x1] (addr + 2) itrs_left [rp] m4 m3 0 [x1]
197
198 xx_mul
199 // _ *list len [-x0] [x1] (addr + 2) itrs_left [rp] m4' m3' µ
200
201 // Read last three words
202 dup 7
203 read_mem 3
204 push {Digest::LEN + 1} add
205 swap 11
206 pop 1
207 // _ *list len [-x0] [x1] (addr + 5) itrs_left [rp] m4' m3' µ m2 m1 m0
208
209 xx_add
210 // _ *list len [-x0] [x1] (addr + 5) itrs_left [rp] [m']
211
212 // itrs_left -= 1
213 swap 6 push -1 add swap 6 // _ *list len [-x0] [x1] addr' itrs_left' [rp] [m']
214
215 // add x0
216 dup 13 dup 13 dup 13 // _ *list len [-x0] [x1] addr' itrs_left' [rp] [m'] [-x0]
217 xx_add // _ *list len [-x0] [x1] addr' itrs_left' [rp] [m' - x0]
218
219 // multiply into running product
220 xx_mul // _ *list len [-x0] [x1] addr' itrs_left' [rp']
221
222 recurse
223 )
224 }
225}
226
#[cfg(test)]
mod tests {
    use num::One;
    use twenty_first::math::other::random_elements;

    use super::*;
    use crate::empty_stack;
    use crate::rust_shadowing_helper_functions;
    use crate::rust_shadowing_helper_functions::list::load_list_with_copy_elements;
    use crate::test_prelude::*;

    impl Function for MultisetEqualityDigests {
        /// Reference implementation of the snippet.
        ///
        /// Pops `*list_b` and then `*list_a` off the stack, loads both lists
        /// from memory, sorts them, and pushes `1` if the sorted lists are
        /// identical and `0` otherwise.
        ///
        /// Returns an error if the stack holds fewer than two elements or
        /// if loading either list fails.
        fn rust_shadow(
            &self,
            stack: &mut Vec<BFieldElement>,
            memory: &mut HashMap<BFieldElement, BFieldElement>,
        ) -> Result<(), RustShadowError> {
            // `*list_b` sits on top of `*list_a`
            let list_b_pointer = stack.pop().ok_or(RustShadowError::StackUnderflow)?;
            let list_a_pointer = stack.pop().ok_or(RustShadowError::StackUnderflow)?;

            let a = load_list_with_copy_elements::<{ Digest::LEN }>(list_a_pointer, memory)?;
            let mut a = a.into_iter().map(Digest::new).collect_vec();
            a.sort_unstable();

            let b = load_list_with_copy_elements::<{ Digest::LEN }>(list_b_pointer, memory)?;
            let mut b = b.into_iter().map(Digest::new).collect_vec();
            b.sort_unstable();

            // two multisets are equal iff their sorted representations are
            let equal_multisets = a == b;
            stack.push(BFieldElement::new(u64::from(equal_multisets)));

            Ok(())
        }

        /// Produce an initial state for testing or benchmarking.
        ///
        /// The benchmark cases use equal lists of length 2 (common case) and
        /// 100 (worst case). Without a benchmark case, the seed selects one
        /// of four scenarios and the list lengths.
        ///
        /// NOTE: list pointers are drawn with `rand::random`, so the seed
        /// fixes the shape of the state but not its exact content.
        fn pseudorandom_initial_state(
            &self,
            seed: [u8; 32],
            bench_case: Option<BenchmarkCase>,
        ) -> FunctionInitialState {
            match bench_case {
                Some(BenchmarkCase::CommonCase) => self.random_equal_lists(2),
                Some(BenchmarkCase::WorstCase) => self.random_equal_lists(100),
                None => {
                    let mut rng = StdRng::from_seed(seed);
                    let length = rng.random_range(1..50);
                    let index = rng.random_range(0..length);
                    let digest_word_index = rng.random_range(0..Digest::LEN);
                    let another_length = length + rng.random_range(1..10);
                    match rng.random_range(0..=3) {
                        0 => self.random_equal_lists(length),
                        1 => self.random_unequal_lists(length),
                        2 => self.random_unequal_length_lists(length, another_length),
                        3 => {
                            self.random_lists_one_element_flipped(length, index, digest_word_index)
                        }
                        _ => unreachable!(),
                    }
                }
            }
        }

        /// Short lists (including empty ones), lists that differ in exactly
        /// one word of one digest, and lists of different lengths.
        fn corner_case_initial_states(&self) -> Vec<FunctionInitialState> {
            let short_equal_multisets = (0..15).map(|i| self.random_equal_lists(i)).collect_vec();
            let short_unequal_multisets =
                (0..15).map(|i| self.random_unequal_lists(i)).collect_vec();

            // every (element, word) position of every list of length 1..=6
            let mut short_lists_one_element_flipped = vec![];
            for length in 1..7 {
                for manipulated_element in 0..length {
                    for manipulated_word in 0..Digest::LEN {
                        short_lists_one_element_flipped.push(
                            self.random_lists_one_element_flipped(
                                length,
                                manipulated_element,
                                manipulated_word,
                            ),
                        );
                    }
                }
            }

            let length_pairs = [
                (0, 5),
                (0, 1),
                (1, 0),
                (5, 0),
                (1, 2),
                (2, 1),
                (10, 17),
                (21, 0),
            ];
            let unequal_lengths = length_pairs
                .iter()
                .map(|&(length_a, length_b)| self.random_unequal_length_lists(length_a, length_b))
                .collect_vec();

            [
                short_equal_multisets,
                short_unequal_multisets,
                short_lists_one_element_flipped,
                unequal_lengths,
            ]
            .concat()
        }
    }

    impl MultisetEqualityDigests {
        /// Store both lists at random, non-overlapping locations in an
        /// otherwise empty memory and put their pointers on an otherwise
        /// empty stack, `*list_b` on top.
        fn initial_state(list_a: Vec<Digest>, list_b: Vec<Digest>) -> FunctionInitialState {
            // Words occupied by list a. Assumes `list_insert` stores one
            // length word followed by the elements, which is the layout the
            // snippet reads (first element at `pointer + 1`).
            let footprint_a = u64::try_from(1 + list_a.len() * Digest::LEN)
                .expect("size of a test list must fit into 64 bits");

            // Starting list b at least `footprint_a` words after list a
            // ensures the two lists never share memory cells.
            let pointer_a: BFieldElement = rand::random();
            let gap = footprint_a + u64::from(rand::random::<u32>());
            let pointer_b = pointer_a + BFieldElement::new(gap);

            let mut memory: HashMap<BFieldElement, BFieldElement> = HashMap::new();
            rust_shadowing_helper_functions::list::list_insert(pointer_a, list_a, &mut memory);
            rust_shadowing_helper_functions::list::list_insert(pointer_b, list_b, &mut memory);

            let stack = [empty_stack(), vec![pointer_a, pointer_b]].concat();
            FunctionInitialState { stack, memory }
        }

        /// Two lists with identical content; list b is the sorted version
        /// of list a.
        fn random_equal_lists(&self, length: usize) -> FunctionInitialState {
            let list_a: Vec<Digest> = random_elements(length);
            let mut list_b = list_a.clone();
            list_b.sort();

            Self::initial_state(list_a, list_b)
        }

        /// Two independently sampled lists of the same length.
        fn random_unequal_lists(&self, length: usize) -> FunctionInitialState {
            let list_a: Vec<Digest> = random_elements(length);
            let list_b: Vec<Digest> = random_elements(length);

            Self::initial_state(list_a, list_b)
        }

        /// Two independently sampled lists of the given lengths.
        fn random_unequal_length_lists(
            &self,
            length_a: usize,
            length_b: usize,
        ) -> FunctionInitialState {
            let list_a: Vec<Digest> = random_elements(length_a);
            let list_b: Vec<Digest> = random_elements(length_b);

            Self::initial_state(list_a, list_b)
        }

        /// Two lists that are equal up to permutation, except that one word
        /// of one digest in list b is incremented by one.
        ///
        /// # Panics
        ///
        /// Panics if `manipulated_index` is not a valid index into a list of
        /// the given `length`, or if `manipulated_digest_word_index` is not
        /// smaller than `Digest::LEN`.
        fn random_lists_one_element_flipped(
            &self,
            length: usize,
            manipulated_index: usize,
            manipulated_digest_word_index: usize,
        ) -> FunctionInitialState {
            assert!(manipulated_index < length);
            assert!(manipulated_digest_word_index < Digest::LEN);

            let list_a: Vec<Digest> = random_elements(length);
            let mut list_b = list_a.clone();
            list_b.sort();
            list_b[manipulated_index].0[manipulated_digest_word_index] += BFieldElement::one();

            Self::initial_state(list_a, list_b)
        }
    }

    #[macro_rules_attr::apply(test)]
    fn rust_shadow() {
        ShadowedFunction::new(MultisetEqualityDigests).test();
    }
}
410
#[cfg(test)]
mod benches {
    use super::*;
    use crate::test_prelude::*;

    /// Benchmark the snippet through its shadowed-function wrapper.
    #[macro_rules_attr::apply(test)]
    fn benchmark() {
        let snippet = MultisetEqualityDigests;
        ShadowedFunction::new(snippet).bench();
    }
}