Skip to main content

tasm_lib/list/higher_order/
zip.rs

1use std::collections::HashMap;
2
3use itertools::Itertools;
4use rand::prelude::*;
5use triton_vm::isa::op_stack::NUM_OP_STACK_REGISTERS;
6use triton_vm::prelude::*;
7
8use crate::InitVmState;
9use crate::empty_stack;
10use crate::list::LIST_METADATA_SIZE;
11use crate::list::new::New;
12use crate::prelude::*;
13use crate::rust_shadowing_helper_functions::list::untyped_insert_random_list;
14use crate::snippet_bencher::BenchmarkCase;
15use crate::traits::function::*;
16use crate::traits::rust_shadow::RustShadowError;
17
/// Zips two lists of equal length, returning a new list of pairs of elements.
///
/// The snippet takes pointers to a list of `left_type` elements and a list of
/// `right_type` elements and produces a pointer to a freshly created list whose
/// element type is the tuple `(left_type, right_type)`. Both the generated
/// assembly and the Rust shadow reject input lists of differing lengths.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct Zip {
    /// Element type of the left input list.
    pub left_type: DataType,
    /// Element type of the right input list.
    pub right_type: DataType,
}
24
25impl Zip {
26    pub fn new(left_type: DataType, right_type: DataType) -> Self {
27        Self {
28            left_type,
29            right_type,
30        }
31    }
32}
33
34impl BasicSnippet for Zip {
35    fn parameters(&self) -> Vec<(DataType, String)> {
36        let list = |data_type| DataType::List(Box::new(data_type));
37
38        let left_list = (list(self.left_type.clone()), "*left_list".to_string());
39        let right_list = (list(self.right_type.clone()), "*right_list".to_string());
40        vec![left_list, right_list]
41    }
42
43    fn return_values(&self) -> Vec<(DataType, String)> {
44        let list = |data_type| DataType::List(Box::new(data_type));
45
46        let tuple_type = DataType::Tuple(vec![self.left_type.clone(), self.right_type.clone()]);
47        let output_list = (list(tuple_type), "*output_list".to_string());
48        vec![output_list]
49    }
50
51    fn entrypoint(&self) -> String {
52        format!(
53            "tasmlib_list_higher_order_u32_zip_{}_with_{}",
54            self.left_type.label_friendly_name(),
55            self.right_type.label_friendly_name()
56        )
57    }
58
59    fn code(&self, library: &mut Library) -> Vec<LabelledInstruction> {
60        let output_type = DataType::Tuple(vec![self.left_type.clone(), self.right_type.clone()]);
61
62        let new_output_list = library.import(Box::new(New));
63
64        let entrypoint = self.entrypoint();
65        let main_loop_label = format!("{entrypoint}_loop");
66
67        let right_size = self.right_type.stack_size();
68        let left_size = self.left_type.stack_size();
69        let read_left_element = self.left_type.read_value_from_memory_leave_pointer();
70        let read_right_element = self.right_type.read_value_from_memory_leave_pointer();
71        let write_output_element = output_type.write_value_to_memory_leave_pointer();
72        let left_size_plus_one = left_size + 1;
73        let left_size_plus_three = left_size + 3;
74        let sum_of_size = left_size + right_size;
75        let sum_of_size_plus_two = sum_of_size + 2;
76        assert!(
77            sum_of_size_plus_two <= NUM_OP_STACK_REGISTERS,
78            "zip only works for an output element size less than or equal to the available \
79                op-stack words"
80        );
81        let minus_two_times_sum_of_size = -(2 * sum_of_size as i32);
82
83        let mul_with_size = |n| match n {
84            0 => triton_asm!(pop 1 push 0),
85            1 => triton_asm!(),
86            n => triton_asm!(
87                push {n}
88                mul
89            ),
90        };
91
92        let main_loop = triton_asm!(
93            // INVARIANT: _ *l *l_elem_last_word *r_elem_last_word *pair_elem_first_word
94            {main_loop_label}:
95                // test return condition: *l == *l_elem_last_word
96                dup 3
97                dup 3
98                eq
99
100                skiz return
101                // _*l *l_elem_last_word *r_elem_last_word *pair_elem_first_word
102
103                dup 2
104                {&read_left_element}
105                // _ *l *l_elem_last_word *r_elem_last_word *pair_elem_first_word [left_element] *l_elem_last_word_prev
106
107                swap {left_size_plus_three}
108                pop 1
109                // _ *l *l_elem_last_word_prev *r_elem_last_word *pair_elem_first_word [left_element]
110
111                dup {left_size_plus_one}
112                // _ *l *l_elem_last_word *r_elem_last_word *pair_elem_first_word [left_element] *r_elem_last_word
113
114                {&read_right_element}
115                // _ *l *l_elem_last_word *r_elem_last_word *pair_elem_first_word [left_element] [right_element] *r_elem_last_word_prev
116
117                swap {sum_of_size_plus_two}
118                pop 1
119                // _ *l *l_elem_last_word_prev *r_elem_last_word_prev *pair_elem_first_word [left_element] [right_element]
120
121                dup {sum_of_size}
122                // _ *l *l_elem_last_word_prev *r_elem_last_word_prev *pair_elem_first_word [right_element] [left_element] *pair_elem_first_word
123
124                {&write_output_element}
125                // _ *l *l_elem_last_word_prev *r_elem_last_word_prev *pair_elem_first_word *pair_elem_first_word_next
126
127                push {minus_two_times_sum_of_size}
128                add
129                // _ *l *l_elem_last_word_prev *r_elem_last_word_prev *pair_elem_first_word *pair_elem_first_word_prev
130
131                swap 1
132                pop 1
133                // _ *l *l_elem_last_word_prev *r_elem_last_word_prev *pair_elem_first_word_prev
134
135                recurse
136        );
137
138        triton_asm!(
139            // BEFORE: _ *left_list *right_list
140            // AFTER:  _ *pair_list
141            {entrypoint}:
142            // get lengths
143            dup 1                   // _ *left_list *right_list *left_list
144            read_mem 1 pop 1        // _ *left_list *right_list left_len
145
146            dup 1                   // _ *left_list *right_list left_len *right_list
147            read_mem 1 pop 1        // _ *left_list *right_list left_len right_len
148
149            // assert equal lengths
150            dup 1                   // _ *left_list *right_list left_len right_len left_len
151            eq assert               // _ *left_list *right_list len
152
153            // create object for pair list and set length
154            call {new_output_list}  // _ *left_list *right_list len *pair_list
155
156            // Write length of *pair_list
157            dup 1
158            swap 1
159            write_mem 1
160            // _ *left_list *right_list len *pair_list_first_word
161
162            // Change all pointers to point to end of lists, in preparation for loop
163            dup 1
164            push -1
165            add
166            // _ *left_list *right_list len *pair_list_first_word (len - 1)
167
168            {&mul_with_size(sum_of_size)}
169            add
170            // _ *left_list *right_list len *pair_list_last_element_first_word
171
172            swap 2
173            // _ *left_list *pair_list_last_element_first_word len *right_list
174
175            dup 1
176            {&mul_with_size(right_size)}
177            add
178            // _ *left_list *pair_list_last_element_first_word len *r_list_last_elem_last_word
179
180            swap 1
181            // _ *left_list *pair_list_last_element_first_word *r_list_last_elem_last_word len
182
183            {&mul_with_size(left_size)}
184            // _ *left_list *pair_list_last_element_first_word *r_list_last_elem_last_word left_offset
185
186            dup 3
187            // _ *left_list *pair_list_last_element_first_word *r_list_last_elem_last_word left_offset *left_list
188
189            add
190            // _ *left_list *pair_list_last_element_first_word *r_list_last_elem_last_word *l_list_last_elem_last_word
191
192            swap 2
193            // _ *l *l_elem_last_word *r_elem_last_word *pair_elem_first_word
194
195            call {main_loop_label}
196            // _ *l *l_elem_last_word *r_elem_last_word *pair_elem_first_word
197
198            // Adjust *pair to point to list instead of element in list
199            push {sum_of_size - 1}
200            add
201
202            swap 3
203
204            pop 3
205
206            return
207
208            {&main_loop}
209        )
210    }
211}
212
213impl Function for Zip {
214    fn rust_shadow(
215        &self,
216        stack: &mut Vec<BFieldElement>,
217        memory: &mut HashMap<BFieldElement, BFieldElement>,
218    ) -> Result<(), RustShadowError> {
219        use crate::rust_shadowing_helper_functions::dyn_malloc;
220        use crate::rust_shadowing_helper_functions::list;
221
222        let right_pointer = stack.pop().ok_or(RustShadowError::StackUnderflow)?;
223        let left_pointer = stack.pop().ok_or(RustShadowError::StackUnderflow)?;
224
225        let left_length = list::list_get_length(left_pointer, memory)?;
226        let right_length = list::list_get_length(right_pointer, memory)?;
227        if left_length != right_length {
228            return Err(RustShadowError::Other);
229        }
230        let len = left_length;
231
232        let output_pointer = dyn_malloc::dynamic_allocator(memory);
233        list::list_new(output_pointer, memory);
234        list::list_set_length(output_pointer, len, memory);
235
236        for i in 0..len {
237            let left_item = list::list_get(left_pointer, i, memory, self.left_type.stack_size())?;
238            let right_item =
239                list::list_get(right_pointer, i, memory, self.right_type.stack_size())?;
240
241            let pair = right_item.into_iter().chain(left_item).collect_vec();
242            list::list_set(output_pointer, i, pair, memory)?;
243        }
244
245        stack.push(output_pointer);
246        Ok(())
247    }
248
249    fn pseudorandom_initial_state(
250        &self,
251        seed: [u8; 32],
252        _bench_case: Option<BenchmarkCase>,
253    ) -> FunctionInitialState {
254        let mut rng = StdRng::from_seed(seed);
255        let list_len = rng.random_range(0..20);
256        let execution_state = self.generate_input_state(list_len, list_len);
257        FunctionInitialState {
258            stack: execution_state.stack,
259            memory: execution_state.nondeterminism.ram,
260        }
261    }
262}
263
264impl Zip {
265    fn generate_input_state(&self, left_length: usize, right_length: usize) -> InitVmState {
266        let fill_with_random_elements =
267            |data_type: &DataType, list_pointer, list_len, memory: &mut _| {
268                untyped_insert_random_list(list_pointer, list_len, memory, data_type.stack_size())
269            };
270
271        let left_pointer = BFieldElement::new(0);
272        let left_size = LIST_METADATA_SIZE + left_length * self.left_type.stack_size();
273        let right_pointer = left_pointer + BFieldElement::new(left_size as u64);
274
275        let mut memory = HashMap::default();
276        fill_with_random_elements(&self.left_type, left_pointer, left_length, &mut memory);
277        fill_with_random_elements(&self.right_type, right_pointer, right_length, &mut memory);
278
279        let stack = [empty_stack(), vec![left_pointer, right_pointer]].concat();
280
281        InitVmState::with_stack_and_memory(stack, memory)
282    }
283}
284
#[cfg(test)]
mod tests {
    use proptest::collection::vec;

    use super::*;
    use crate::rust_shadowing_helper_functions::list;
    use crate::test_prelude::*;

    /// Runs the shadowed-function test for zipping the given element types.
    fn test_zip_of(left_type: DataType, right_type: DataType) {
        ShadowedFunction::new(Zip::new(left_type, right_type)).test();
    }

    #[macro_rules_attr::apply(test)]
    fn prop_test_xfe_digest() {
        test_zip_of(DataType::Xfe, DataType::Digest);
    }

    #[macro_rules_attr::apply(test)]
    fn list_prop_test_more_types() {
        let type_combinations = [
            (DataType::Bfe, DataType::Bfe),
            (DataType::U64, DataType::U32),
            (DataType::Bool, DataType::Digest),
            (DataType::U128, DataType::VoidPointer),
            (DataType::U128, DataType::Digest),
            (DataType::U128, DataType::U128),
            (DataType::Digest, DataType::Digest),
        ];

        for (left_type, right_type) in type_combinations {
            test_zip_of(left_type, right_type);
        }
    }

    /// The list produced by the Rust shadow must decode to the same pairs that
    /// zipping the plain Rust vectors yields.
    #[macro_rules_attr::apply(proptest)]
    fn zipping_u32s_with_x_field_elements_correspond_to_bfieldcodec(
        left_list: Vec<u32>,
        #[strategy(vec(arb(), #left_list.len()))] right_list: Vec<XFieldElement>,
    ) {
        let left_pointer = bfe!(0);
        let right_pointer = bfe!(1_u64 << 60); // far enough

        let mut ram = HashMap::default();
        write_list_to_ram(&mut ram, left_pointer, &left_list)?;
        write_list_to_ram(&mut ram, right_pointer, &right_list)?;

        let mut stack = empty_stack();
        stack.push(left_pointer);
        stack.push(right_pointer);

        Zip::new(DataType::U32, DataType::Xfe)
            .rust_shadow(&mut stack, &mut ram)
            .unwrap();
        let output_list_pointer = stack.pop().unwrap();
        let tasm_zipped = *Vec::decode_from_memory(&ram, output_list_pointer).unwrap();

        let rust_zipped = left_list.into_iter().zip_eq(right_list).collect_vec();
        prop_assert_eq!(rust_zipped, tasm_zipped);
    }

    /// Creates a list at `list_pointer` and pushes the encoding of every item
    /// of `list` onto it, in order.
    fn write_list_to_ram<T: BFieldCodec + Copy>(
        ram: &mut HashMap<BFieldElement, BFieldElement>,
        list_pointer: BFieldElement,
        list: &[T],
    ) -> Result<(), RustShadowError> {
        list::list_new(list_pointer, ram);
        for item in list {
            let encoding = item.encode();
            list::list_push(list_pointer, encoding, ram)?;
        }

        Ok(())
    }
}
345
#[cfg(test)]
mod benches {
    use super::*;
    use crate::test_prelude::*;

    /// Benchmarks zipping a list of extension field elements with a list of digests.
    #[macro_rules_attr::apply(test)]
    fn benchmark() {
        let zip = Zip::new(DataType::Xfe, DataType::Digest);
        ShadowedFunction::new(zip).bench();
    }
}
355}