tasm_lib/list/higher_order/
zip.rs1use std::collections::HashMap;
2
3use itertools::Itertools;
4use rand::prelude::*;
5use triton_vm::isa::op_stack::NUM_OP_STACK_REGISTERS;
6use triton_vm::prelude::*;
7
8use crate::InitVmState;
9use crate::empty_stack;
10use crate::list::LIST_METADATA_SIZE;
11use crate::list::new::New;
12use crate::prelude::*;
13use crate::rust_shadowing_helper_functions::list::untyped_insert_random_list;
14use crate::snippet_bencher::BenchmarkCase;
15use crate::traits::function::*;
16
17#[derive(Debug, Clone, Eq, PartialEq, Hash)]
19pub struct Zip {
20 pub left_type: DataType,
21 pub right_type: DataType,
22}
23
24impl Zip {
25 pub fn new(left_type: DataType, right_type: DataType) -> Self {
26 Self {
27 left_type,
28 right_type,
29 }
30 }
31}
32
33impl BasicSnippet for Zip {
34 fn inputs(&self) -> Vec<(DataType, String)> {
35 let list = |data_type| DataType::List(Box::new(data_type));
36
37 let left_list = (list(self.left_type.clone()), "*left_list".to_string());
38 let right_list = (list(self.right_type.clone()), "*right_list".to_string());
39 vec![left_list, right_list]
40 }
41
42 fn outputs(&self) -> Vec<(DataType, String)> {
43 let list = |data_type| DataType::List(Box::new(data_type));
44
45 let tuple_type = DataType::Tuple(vec![self.left_type.clone(), self.right_type.clone()]);
46 let output_list = (list(tuple_type), "*output_list".to_string());
47 vec![output_list]
48 }
49
50 fn entrypoint(&self) -> String {
51 format!(
52 "tasmlib_list_higher_order_u32_zip_{}_with_{}",
53 self.left_type.label_friendly_name(),
54 self.right_type.label_friendly_name()
55 )
56 }
57
58 fn code(&self, library: &mut Library) -> Vec<LabelledInstruction> {
59 let output_type = DataType::Tuple(vec![self.left_type.clone(), self.right_type.clone()]);
60
61 let new_output_list = library.import(Box::new(New));
62
63 let entrypoint = self.entrypoint();
64 let main_loop_label = format!("{entrypoint}_loop");
65
66 let right_size = self.right_type.stack_size();
67 let left_size = self.left_type.stack_size();
68 let read_left_element = self.left_type.read_value_from_memory_leave_pointer();
69 let read_right_element = self.right_type.read_value_from_memory_leave_pointer();
70 let write_output_element = output_type.write_value_to_memory_leave_pointer();
71 let left_size_plus_one = left_size + 1;
72 let left_size_plus_three = left_size + 3;
73 let sum_of_size = left_size + right_size;
74 let sum_of_size_plus_two = sum_of_size + 2;
75 assert!(
76 sum_of_size_plus_two <= NUM_OP_STACK_REGISTERS,
77 "zip only works for an output element size less than or equal to the available \
78 op-stack words"
79 );
80 let minus_two_times_sum_of_size = -(2 * sum_of_size as i32);
81
82 let mul_with_size = |n| match n {
83 0 => triton_asm!(pop 1 push 0),
84 1 => triton_asm!(),
85 n => triton_asm!(
86 push {n}
87 mul
88 ),
89 };
90
91 let main_loop = triton_asm!(
92 {main_loop_label}:
94 dup 3
96 dup 3
97 eq
98
99 skiz return
100 dup 2
103 {&read_left_element}
104 swap {left_size_plus_three}
107 pop 1
108 dup {left_size_plus_one}
111 {&read_right_element}
114 swap {sum_of_size_plus_two}
117 pop 1
118 dup {sum_of_size}
121 {&write_output_element}
124 push {minus_two_times_sum_of_size}
127 add
128 swap 1
131 pop 1
132 recurse
135 );
136
137 triton_asm!(
138 {entrypoint}:
141 dup 1 read_mem 1 pop 1 dup 1 read_mem 1 pop 1 dup 1 eq assert call {new_output_list} dup 1
157 swap 1
158 write_mem 1
159 dup 1
163 push -1
164 add
165 {&mul_with_size(sum_of_size)}
168 add
169 swap 2
172 dup 1
175 {&mul_with_size(right_size)}
176 add
177 swap 1
180 {&mul_with_size(left_size)}
183 dup 3
186 add
189 swap 2
192 call {main_loop_label}
195 push {sum_of_size - 1}
199 add
200
201 swap 3
202
203 pop 3
204
205 return
206
207 {&main_loop}
208 )
209 }
210}
211
212impl Function for Zip {
213 fn rust_shadow(
214 &self,
215 stack: &mut Vec<BFieldElement>,
216 memory: &mut HashMap<BFieldElement, BFieldElement>,
217 ) {
218 use crate::rust_shadowing_helper_functions::dyn_malloc;
219 use crate::rust_shadowing_helper_functions::list;
220
221 let right_pointer = stack.pop().unwrap();
222 let left_pointer = stack.pop().unwrap();
223
224 let left_length = list::list_get_length(left_pointer, memory);
225 let right_length = list::list_get_length(right_pointer, memory);
226 assert_eq!(left_length, right_length);
227 let len = left_length;
228
229 let output_pointer = dyn_malloc::dynamic_allocator(memory);
230 list::list_new(output_pointer, memory);
231 list::list_set_length(output_pointer, len, memory);
232
233 for i in 0..len {
234 let left_item = list::list_get(left_pointer, i, memory, self.left_type.stack_size());
235 let right_item = list::list_get(right_pointer, i, memory, self.right_type.stack_size());
236
237 let pair = right_item.into_iter().chain(left_item).collect_vec();
238 list::list_set(output_pointer, i, pair, memory);
239 }
240
241 stack.push(output_pointer);
242 }
243
244 fn pseudorandom_initial_state(
245 &self,
246 seed: [u8; 32],
247 _bench_case: Option<BenchmarkCase>,
248 ) -> FunctionInitialState {
249 let mut rng = StdRng::from_seed(seed);
250 let list_len = rng.random_range(0..20);
251 let execution_state = self.generate_input_state(list_len, list_len);
252 FunctionInitialState {
253 stack: execution_state.stack,
254 memory: execution_state.nondeterminism.ram,
255 }
256 }
257}
258
259impl Zip {
260 fn generate_input_state(&self, left_length: usize, right_length: usize) -> InitVmState {
261 let fill_with_random_elements =
262 |data_type: &DataType, list_pointer, list_len, memory: &mut _| {
263 untyped_insert_random_list(list_pointer, list_len, memory, data_type.stack_size())
264 };
265
266 let left_pointer = BFieldElement::new(0);
267 let left_size = LIST_METADATA_SIZE + left_length * self.left_type.stack_size();
268 let right_pointer = left_pointer + BFieldElement::new(left_size as u64);
269
270 let mut memory = HashMap::default();
271 fill_with_random_elements(&self.left_type, left_pointer, left_length, &mut memory);
272 fill_with_random_elements(&self.right_type, right_pointer, right_length, &mut memory);
273
274 let stack = [empty_stack(), vec![left_pointer, right_pointer]].concat();
275
276 InitVmState::with_stack_and_memory(stack, memory)
277 }
278}
279
#[cfg(test)]
mod tests {
    use proptest::collection::vec;

    use super::*;
    use crate::rust_shadowing_helper_functions::list;
    use crate::test_prelude::*;

    #[test]
    fn prop_test_xfe_digest() {
        ShadowedFunction::new(Zip::new(DataType::Xfe, DataType::Digest)).test();
    }

    #[test]
    fn list_prop_test_more_types() {
        let type_pairs = [
            (DataType::Bfe, DataType::Bfe),
            (DataType::U64, DataType::U32),
            (DataType::Bool, DataType::Digest),
            (DataType::U128, DataType::VoidPointer),
            (DataType::U128, DataType::Digest),
            (DataType::U128, DataType::U128),
            (DataType::Digest, DataType::Digest),
        ];
        for (left_type, right_type) in type_pairs {
            ShadowedFunction::new(Zip::new(left_type, right_type)).test();
        }
    }

    /// The rust shadow's output list decodes (via `BFieldCodec`) to exactly
    /// what `Iterator::zip` produces on the original Rust vectors.
    #[proptest]
    fn zipping_u32s_with_x_field_elements_correspond_to_bfieldcodec(
        left_list: Vec<u32>,
        #[strategy(vec(arb(), #left_list.len()))] right_list: Vec<XFieldElement>,
    ) {
        // Place the lists far apart so that they cannot overlap.
        let left_pointer = bfe!(0);
        let right_pointer = bfe!(1_u64 << 60);

        let mut ram = HashMap::default();
        write_list_to_ram(&mut ram, left_pointer, &left_list);
        write_list_to_ram(&mut ram, right_pointer, &right_list);

        let mut stack = [empty_stack(), vec![left_pointer, right_pointer]].concat();

        let zip = Zip::new(DataType::U32, DataType::Xfe);
        zip.rust_shadow(&mut stack, &mut ram);
        let output_list_pointer = stack.pop().unwrap();

        // This is the rust shadow's result; no TASM is executed in this test.
        let shadow_zipped = *Vec::decode_from_memory(&ram, output_list_pointer).unwrap();

        let rust_zipped = left_list.into_iter().zip_eq(right_list).collect_vec();
        prop_assert_eq!(rust_zipped, shadow_zipped);
    }

    /// Writes `list` to `ram` as a list starting at `list_pointer`.
    fn write_list_to_ram<T: BFieldCodec + Copy>(
        ram: &mut HashMap<BFieldElement, BFieldElement>,
        list_pointer: BFieldElement,
        list: &[T],
    ) {
        list::list_new(list_pointer, ram);
        for &item in list {
            list::list_push(list_pointer, item.encode(), ram);
        }
    }
}
338
#[cfg(test)]
mod benches {
    use super::*;
    use crate::test_prelude::*;

    /// Benchmarks zipping a list of extension-field elements with a list of
    /// digests.
    #[test]
    fn benchmark() {
        let zip_xfe_with_digest = Zip::new(DataType::Xfe, DataType::Digest);
        ShadowedFunction::new(zip_xfe_with_digest).bench();
    }
}