tasm_lib/list/higher_order/
zip.rs1use std::collections::HashMap;
2
3use itertools::Itertools;
4use rand::prelude::*;
5use triton_vm::isa::op_stack::NUM_OP_STACK_REGISTERS;
6use triton_vm::prelude::*;
7
8use crate::InitVmState;
9use crate::empty_stack;
10use crate::list::LIST_METADATA_SIZE;
11use crate::list::new::New;
12use crate::prelude::*;
13use crate::rust_shadowing_helper_functions::list::untyped_insert_random_list;
14use crate::snippet_bencher::BenchmarkCase;
15use crate::traits::function::*;
16use crate::traits::rust_shadow::RustShadowError;
17
18#[derive(Debug, Clone, Eq, PartialEq, Hash)]
20pub struct Zip {
21 pub left_type: DataType,
22 pub right_type: DataType,
23}
24
25impl Zip {
26 pub fn new(left_type: DataType, right_type: DataType) -> Self {
27 Self {
28 left_type,
29 right_type,
30 }
31 }
32}
33
34impl BasicSnippet for Zip {
35 fn parameters(&self) -> Vec<(DataType, String)> {
36 let list = |data_type| DataType::List(Box::new(data_type));
37
38 let left_list = (list(self.left_type.clone()), "*left_list".to_string());
39 let right_list = (list(self.right_type.clone()), "*right_list".to_string());
40 vec![left_list, right_list]
41 }
42
43 fn return_values(&self) -> Vec<(DataType, String)> {
44 let list = |data_type| DataType::List(Box::new(data_type));
45
46 let tuple_type = DataType::Tuple(vec![self.left_type.clone(), self.right_type.clone()]);
47 let output_list = (list(tuple_type), "*output_list".to_string());
48 vec![output_list]
49 }
50
51 fn entrypoint(&self) -> String {
52 format!(
53 "tasmlib_list_higher_order_u32_zip_{}_with_{}",
54 self.left_type.label_friendly_name(),
55 self.right_type.label_friendly_name()
56 )
57 }
58
59 fn code(&self, library: &mut Library) -> Vec<LabelledInstruction> {
60 let output_type = DataType::Tuple(vec![self.left_type.clone(), self.right_type.clone()]);
61
62 let new_output_list = library.import(Box::new(New));
63
64 let entrypoint = self.entrypoint();
65 let main_loop_label = format!("{entrypoint}_loop");
66
67 let right_size = self.right_type.stack_size();
68 let left_size = self.left_type.stack_size();
69 let read_left_element = self.left_type.read_value_from_memory_leave_pointer();
70 let read_right_element = self.right_type.read_value_from_memory_leave_pointer();
71 let write_output_element = output_type.write_value_to_memory_leave_pointer();
72 let left_size_plus_one = left_size + 1;
73 let left_size_plus_three = left_size + 3;
74 let sum_of_size = left_size + right_size;
75 let sum_of_size_plus_two = sum_of_size + 2;
76 assert!(
77 sum_of_size_plus_two <= NUM_OP_STACK_REGISTERS,
78 "zip only works for an output element size less than or equal to the available \
79 op-stack words"
80 );
81 let minus_two_times_sum_of_size = -(2 * sum_of_size as i32);
82
83 let mul_with_size = |n| match n {
84 0 => triton_asm!(pop 1 push 0),
85 1 => triton_asm!(),
86 n => triton_asm!(
87 push {n}
88 mul
89 ),
90 };
91
92 let main_loop = triton_asm!(
93 {main_loop_label}:
95 dup 3
97 dup 3
98 eq
99
100 skiz return
101 dup 2
104 {&read_left_element}
105 swap {left_size_plus_three}
108 pop 1
109 dup {left_size_plus_one}
112 {&read_right_element}
115 swap {sum_of_size_plus_two}
118 pop 1
119 dup {sum_of_size}
122 {&write_output_element}
125 push {minus_two_times_sum_of_size}
128 add
129 swap 1
132 pop 1
133 recurse
136 );
137
138 triton_asm!(
139 {entrypoint}:
142 dup 1 read_mem 1 pop 1 dup 1 read_mem 1 pop 1 dup 1 eq assert call {new_output_list} dup 1
158 swap 1
159 write_mem 1
160 dup 1
164 push -1
165 add
166 {&mul_with_size(sum_of_size)}
169 add
170 swap 2
173 dup 1
176 {&mul_with_size(right_size)}
177 add
178 swap 1
181 {&mul_with_size(left_size)}
184 dup 3
187 add
190 swap 2
193 call {main_loop_label}
196 push {sum_of_size - 1}
200 add
201
202 swap 3
203
204 pop 3
205
206 return
207
208 {&main_loop}
209 )
210 }
211}
212
213impl Function for Zip {
214 fn rust_shadow(
215 &self,
216 stack: &mut Vec<BFieldElement>,
217 memory: &mut HashMap<BFieldElement, BFieldElement>,
218 ) -> Result<(), RustShadowError> {
219 use crate::rust_shadowing_helper_functions::dyn_malloc;
220 use crate::rust_shadowing_helper_functions::list;
221
222 let right_pointer = stack.pop().ok_or(RustShadowError::StackUnderflow)?;
223 let left_pointer = stack.pop().ok_or(RustShadowError::StackUnderflow)?;
224
225 let left_length = list::list_get_length(left_pointer, memory)?;
226 let right_length = list::list_get_length(right_pointer, memory)?;
227 if left_length != right_length {
228 return Err(RustShadowError::Other);
229 }
230 let len = left_length;
231
232 let output_pointer = dyn_malloc::dynamic_allocator(memory);
233 list::list_new(output_pointer, memory);
234 list::list_set_length(output_pointer, len, memory);
235
236 for i in 0..len {
237 let left_item = list::list_get(left_pointer, i, memory, self.left_type.stack_size())?;
238 let right_item =
239 list::list_get(right_pointer, i, memory, self.right_type.stack_size())?;
240
241 let pair = right_item.into_iter().chain(left_item).collect_vec();
242 list::list_set(output_pointer, i, pair, memory)?;
243 }
244
245 stack.push(output_pointer);
246 Ok(())
247 }
248
249 fn pseudorandom_initial_state(
250 &self,
251 seed: [u8; 32],
252 _bench_case: Option<BenchmarkCase>,
253 ) -> FunctionInitialState {
254 let mut rng = StdRng::from_seed(seed);
255 let list_len = rng.random_range(0..20);
256 let execution_state = self.generate_input_state(list_len, list_len);
257 FunctionInitialState {
258 stack: execution_state.stack,
259 memory: execution_state.nondeterminism.ram,
260 }
261 }
262}
263
264impl Zip {
265 fn generate_input_state(&self, left_length: usize, right_length: usize) -> InitVmState {
266 let fill_with_random_elements =
267 |data_type: &DataType, list_pointer, list_len, memory: &mut _| {
268 untyped_insert_random_list(list_pointer, list_len, memory, data_type.stack_size())
269 };
270
271 let left_pointer = BFieldElement::new(0);
272 let left_size = LIST_METADATA_SIZE + left_length * self.left_type.stack_size();
273 let right_pointer = left_pointer + BFieldElement::new(left_size as u64);
274
275 let mut memory = HashMap::default();
276 fill_with_random_elements(&self.left_type, left_pointer, left_length, &mut memory);
277 fill_with_random_elements(&self.right_type, right_pointer, right_length, &mut memory);
278
279 let stack = [empty_stack(), vec![left_pointer, right_pointer]].concat();
280
281 InitVmState::with_stack_and_memory(stack, memory)
282 }
283}
284
#[cfg(test)]
mod tests {
    use proptest::collection::vec;

    use super::*;
    use crate::rust_shadowing_helper_functions::list;
    use crate::test_prelude::*;

    #[macro_rules_attr::apply(test)]
    fn prop_test_xfe_digest() {
        let snippet = Zip::new(DataType::Xfe, DataType::Digest);
        ShadowedFunction::new(snippet).test();
    }

    #[macro_rules_attr::apply(test)]
    fn list_prop_test_more_types() {
        let type_pairs = [
            (DataType::Bfe, DataType::Bfe),
            (DataType::U64, DataType::U32),
            (DataType::Bool, DataType::Digest),
            (DataType::U128, DataType::VoidPointer),
            (DataType::U128, DataType::Digest),
            (DataType::U128, DataType::U128),
            (DataType::Digest, DataType::Digest),
        ];

        for (left_type, right_type) in type_pairs {
            ShadowedFunction::new(Zip::new(left_type, right_type)).test();
        }
    }

    #[macro_rules_attr::apply(proptest)]
    fn zipping_u32s_with_x_field_elements_correspond_to_bfieldcodec(
        left_list: Vec<u32>,
        #[strategy(vec(arb(), #left_list.len()))] right_list: Vec<XFieldElement>,
    ) {
        let left_pointer = bfe!(0);
        let right_pointer = bfe!(1_u64 << 60);

        let mut ram = HashMap::default();
        write_list_to_ram(&mut ram, left_pointer, &left_list)?;
        write_list_to_ram(&mut ram, right_pointer, &right_list)?;

        let mut stack = empty_stack();
        stack.extend([left_pointer, right_pointer]);

        Zip::new(DataType::U32, DataType::Xfe)
            .rust_shadow(&mut stack, &mut ram)
            .unwrap();
        let output_list_pointer = stack.pop().unwrap();
        let shadow_zipped = *Vec::decode_from_memory(&ram, output_list_pointer).unwrap();

        let expected_zipped = left_list.into_iter().zip_eq(right_list).collect_vec();
        prop_assert_eq!(expected_zipped, shadow_zipped);
    }

    /// Creates a list at `list_pointer` and appends the `BFieldCodec` encoding of every
    /// item in `list`, in order.
    fn write_list_to_ram<T: BFieldCodec + Copy>(
        ram: &mut HashMap<BFieldElement, BFieldElement>,
        list_pointer: BFieldElement,
        list: &[T],
    ) -> Result<(), RustShadowError> {
        list::list_new(list_pointer, ram);
        for element in list.iter().copied() {
            list::list_push(list_pointer, element.encode(), ram)?;
        }

        Ok(())
    }
}
345
#[cfg(test)]
mod benches {
    use super::*;
    use crate::test_prelude::*;

    #[macro_rules_attr::apply(test)]
    fn benchmark() {
        let snippet = Zip::new(DataType::Xfe, DataType::Digest);
        ShadowedFunction::new(snippet).bench();
    }
}