tasm_lib/list/higher_order/zip.rs

1use std::collections::HashMap;
2
3use itertools::Itertools;
4use rand::prelude::*;
5use triton_vm::isa::op_stack::NUM_OP_STACK_REGISTERS;
6use triton_vm::prelude::*;
7
8use crate::InitVmState;
9use crate::empty_stack;
10use crate::list::LIST_METADATA_SIZE;
11use crate::list::new::New;
12use crate::prelude::*;
13use crate::rust_shadowing_helper_functions::list::untyped_insert_random_list;
14use crate::snippet_bencher::BenchmarkCase;
15use crate::traits::function::*;
16
17/// Zips two lists of equal length, returning a new list of pairs of elements.
18#[derive(Debug, Clone, Eq, PartialEq, Hash)]
19pub struct Zip {
20    pub left_type: DataType,
21    pub right_type: DataType,
22}
23
24impl Zip {
25    pub fn new(left_type: DataType, right_type: DataType) -> Self {
26        Self {
27            left_type,
28            right_type,
29        }
30    }
31}
32
33impl BasicSnippet for Zip {
34    fn inputs(&self) -> Vec<(DataType, String)> {
35        let list = |data_type| DataType::List(Box::new(data_type));
36
37        let left_list = (list(self.left_type.clone()), "*left_list".to_string());
38        let right_list = (list(self.right_type.clone()), "*right_list".to_string());
39        vec![left_list, right_list]
40    }
41
42    fn outputs(&self) -> Vec<(DataType, String)> {
43        let list = |data_type| DataType::List(Box::new(data_type));
44
45        let tuple_type = DataType::Tuple(vec![self.left_type.clone(), self.right_type.clone()]);
46        let output_list = (list(tuple_type), "*output_list".to_string());
47        vec![output_list]
48    }
49
50    fn entrypoint(&self) -> String {
51        format!(
52            "tasmlib_list_higher_order_u32_zip_{}_with_{}",
53            self.left_type.label_friendly_name(),
54            self.right_type.label_friendly_name()
55        )
56    }
57
58    fn code(&self, library: &mut Library) -> Vec<LabelledInstruction> {
59        let output_type = DataType::Tuple(vec![self.left_type.clone(), self.right_type.clone()]);
60
61        let new_output_list = library.import(Box::new(New));
62
63        let entrypoint = self.entrypoint();
64        let main_loop_label = format!("{entrypoint}_loop");
65
66        let right_size = self.right_type.stack_size();
67        let left_size = self.left_type.stack_size();
68        let read_left_element = self.left_type.read_value_from_memory_leave_pointer();
69        let read_right_element = self.right_type.read_value_from_memory_leave_pointer();
70        let write_output_element = output_type.write_value_to_memory_leave_pointer();
71        let left_size_plus_one = left_size + 1;
72        let left_size_plus_three = left_size + 3;
73        let sum_of_size = left_size + right_size;
74        let sum_of_size_plus_two = sum_of_size + 2;
75        assert!(
76            sum_of_size_plus_two <= NUM_OP_STACK_REGISTERS,
77            "zip only works for an output element size less than or equal to the available \
78                op-stack words"
79        );
80        let minus_two_times_sum_of_size = -(2 * sum_of_size as i32);
81
82        let mul_with_size = |n| match n {
83            0 => triton_asm!(pop 1 push 0),
84            1 => triton_asm!(),
85            n => triton_asm!(
86                push {n}
87                mul
88            ),
89        };
90
91        let main_loop = triton_asm!(
92            // INVARIANT: _ *l *l_elem_last_word *r_elem_last_word *pair_elem_first_word
93            {main_loop_label}:
94                // test return condition: *l == *l_elem_last_word
95                dup 3
96                dup 3
97                eq
98
99                skiz return
100                // _*l *l_elem_last_word *r_elem_last_word *pair_elem_first_word
101
102                dup 2
103                {&read_left_element}
104                // _ *l *l_elem_last_word *r_elem_last_word *pair_elem_first_word [left_element] *l_elem_last_word_prev
105
106                swap {left_size_plus_three}
107                pop 1
108                // _ *l *l_elem_last_word_prev *r_elem_last_word *pair_elem_first_word [left_element]
109
110                dup {left_size_plus_one}
111                // _ *l *l_elem_last_word *r_elem_last_word *pair_elem_first_word [left_element] *r_elem_last_word
112
113                {&read_right_element}
114                // _ *l *l_elem_last_word *r_elem_last_word *pair_elem_first_word [left_element] [right_element] *r_elem_last_word_prev
115
116                swap {sum_of_size_plus_two}
117                pop 1
118                // _ *l *l_elem_last_word_prev *r_elem_last_word_prev *pair_elem_first_word [left_element] [right_element]
119
120                dup {sum_of_size}
121                // _ *l *l_elem_last_word_prev *r_elem_last_word_prev *pair_elem_first_word [right_element] [left_element] *pair_elem_first_word
122
123                {&write_output_element}
124                // _ *l *l_elem_last_word_prev *r_elem_last_word_prev *pair_elem_first_word *pair_elem_first_word_next
125
126                push {minus_two_times_sum_of_size}
127                add
128                // _ *l *l_elem_last_word_prev *r_elem_last_word_prev *pair_elem_first_word *pair_elem_first_word_prev
129
130                swap 1
131                pop 1
132                // _ *l *l_elem_last_word_prev *r_elem_last_word_prev *pair_elem_first_word_prev
133
134                recurse
135        );
136
137        triton_asm!(
138            // BEFORE: _ *left_list *right_list
139            // AFTER:  _ *pair_list
140            {entrypoint}:
141            // get lengths
142            dup 1                   // _ *left_list *right_list *left_list
143            read_mem 1 pop 1        // _ *left_list *right_list left_len
144
145            dup 1                   // _ *left_list *right_list left_len *right_list
146            read_mem 1 pop 1        // _ *left_list *right_list left_len right_len
147
148            // assert equal lengths
149            dup 1                   // _ *left_list *right_list left_len right_len left_len
150            eq assert               // _ *left_list *right_list len
151
152            // create object for pair list and set length
153            call {new_output_list}  // _ *left_list *right_list len *pair_list
154
155            // Write length of *pair_list
156            dup 1
157            swap 1
158            write_mem 1
159            // _ *left_list *right_list len *pair_list_first_word
160
161            // Change all pointers to point to end of lists, in preparation for loop
162            dup 1
163            push -1
164            add
165            // _ *left_list *right_list len *pair_list_first_word (len - 1)
166
167            {&mul_with_size(sum_of_size)}
168            add
169            // _ *left_list *right_list len *pair_list_last_element_first_word
170
171            swap 2
172            // _ *left_list *pair_list_last_element_first_word len *right_list
173
174            dup 1
175            {&mul_with_size(right_size)}
176            add
177            // _ *left_list *pair_list_last_element_first_word len *r_list_last_elem_last_word
178
179            swap 1
180            // _ *left_list *pair_list_last_element_first_word *r_list_last_elem_last_word len
181
182            {&mul_with_size(left_size)}
183            // _ *left_list *pair_list_last_element_first_word *r_list_last_elem_last_word left_offset
184
185            dup 3
186            // _ *left_list *pair_list_last_element_first_word *r_list_last_elem_last_word left_offset *left_list
187
188            add
189            // _ *left_list *pair_list_last_element_first_word *r_list_last_elem_last_word *l_list_last_elem_last_word
190
191            swap 2
192            // _ *l *l_elem_last_word *r_elem_last_word *pair_elem_first_word
193
194            call {main_loop_label}
195            // _ *l *l_elem_last_word *r_elem_last_word *pair_elem_first_word
196
197            // Adjust *pair to point to list instead of element in list
198            push {sum_of_size - 1}
199            add
200
201            swap 3
202
203            pop 3
204
205            return
206
207            {&main_loop}
208        )
209    }
210}
211
212impl Function for Zip {
213    fn rust_shadow(
214        &self,
215        stack: &mut Vec<BFieldElement>,
216        memory: &mut HashMap<BFieldElement, BFieldElement>,
217    ) {
218        use crate::rust_shadowing_helper_functions::dyn_malloc;
219        use crate::rust_shadowing_helper_functions::list;
220
221        let right_pointer = stack.pop().unwrap();
222        let left_pointer = stack.pop().unwrap();
223
224        let left_length = list::list_get_length(left_pointer, memory);
225        let right_length = list::list_get_length(right_pointer, memory);
226        assert_eq!(left_length, right_length);
227        let len = left_length;
228
229        let output_pointer = dyn_malloc::dynamic_allocator(memory);
230        list::list_new(output_pointer, memory);
231        list::list_set_length(output_pointer, len, memory);
232
233        for i in 0..len {
234            let left_item = list::list_get(left_pointer, i, memory, self.left_type.stack_size());
235            let right_item = list::list_get(right_pointer, i, memory, self.right_type.stack_size());
236
237            let pair = right_item.into_iter().chain(left_item).collect_vec();
238            list::list_set(output_pointer, i, pair, memory);
239        }
240
241        stack.push(output_pointer);
242    }
243
244    fn pseudorandom_initial_state(
245        &self,
246        seed: [u8; 32],
247        _bench_case: Option<BenchmarkCase>,
248    ) -> FunctionInitialState {
249        let mut rng = StdRng::from_seed(seed);
250        let list_len = rng.random_range(0..20);
251        let execution_state = self.generate_input_state(list_len, list_len);
252        FunctionInitialState {
253            stack: execution_state.stack,
254            memory: execution_state.nondeterminism.ram,
255        }
256    }
257}
258
259impl Zip {
260    fn generate_input_state(&self, left_length: usize, right_length: usize) -> InitVmState {
261        let fill_with_random_elements =
262            |data_type: &DataType, list_pointer, list_len, memory: &mut _| {
263                untyped_insert_random_list(list_pointer, list_len, memory, data_type.stack_size())
264            };
265
266        let left_pointer = BFieldElement::new(0);
267        let left_size = LIST_METADATA_SIZE + left_length * self.left_type.stack_size();
268        let right_pointer = left_pointer + BFieldElement::new(left_size as u64);
269
270        let mut memory = HashMap::default();
271        fill_with_random_elements(&self.left_type, left_pointer, left_length, &mut memory);
272        fill_with_random_elements(&self.right_type, right_pointer, right_length, &mut memory);
273
274        let stack = [empty_stack(), vec![left_pointer, right_pointer]].concat();
275
276        InitVmState::with_stack_and_memory(stack, memory)
277    }
278}
279
#[cfg(test)]
mod tests {
    use proptest::collection::vec;

    use super::*;
    use crate::rust_shadowing_helper_functions::list;
    use crate::test_prelude::*;

    #[test]
    fn prop_test_xfe_digest() {
        ShadowedFunction::new(Zip::new(DataType::Xfe, DataType::Digest)).test();
    }

    #[test]
    fn list_prop_test_more_types() {
        ShadowedFunction::new(Zip::new(DataType::Bfe, DataType::Bfe)).test();
        ShadowedFunction::new(Zip::new(DataType::U64, DataType::U32)).test();
        ShadowedFunction::new(Zip::new(DataType::Bool, DataType::Digest)).test();
        ShadowedFunction::new(Zip::new(DataType::U128, DataType::VoidPointer)).test();
        ShadowedFunction::new(Zip::new(DataType::U128, DataType::Digest)).test();
        ShadowedFunction::new(Zip::new(DataType::U128, DataType::U128)).test();
        ShadowedFunction::new(Zip::new(DataType::Digest, DataType::Digest)).test();
    }

    /// Checks that the list produced by the Rust shadow decodes (via `BFieldCodec`) to
    /// exactly what `zip_eq` produces on the inputs.
    ///
    /// The `ShadowedFunction` tests above are what check the TASM code against the shadow.
    #[proptest]
    fn zipping_u32s_with_x_field_elements_correspond_to_bfieldcodec(
        left_list: Vec<u32>,
        #[strategy(vec(arb(), #left_list.len()))] right_list: Vec<XFieldElement>,
    ) {
        let left_pointer = bfe!(0);
        let right_pointer = bfe!(1_u64 << 60); // far enough

        let mut ram = HashMap::default();
        write_list_to_ram(&mut ram, left_pointer, &left_list);
        write_list_to_ram(&mut ram, right_pointer, &right_list);

        let mut stack = [empty_stack(), vec![left_pointer, right_pointer]].concat();

        let zip = Zip::new(DataType::U32, DataType::Xfe);
        zip.rust_shadow(&mut stack, &mut ram);
        let output_list_pointer = stack.pop().unwrap();
        let shadow_zipped = *Vec::decode_from_memory(&ram, output_list_pointer).unwrap();

        let rust_zipped = left_list.into_iter().zip_eq(right_list).collect_vec();
        prop_assert_eq!(rust_zipped, shadow_zipped);
    }

    /// Initializes a list at `list_pointer` in `ram` and pushes the `BFieldCodec` encoding
    /// of every item in `list`, in order.
    fn write_list_to_ram<T: BFieldCodec>(
        ram: &mut HashMap<BFieldElement, BFieldElement>,
        list_pointer: BFieldElement,
        list: &[T],
    ) {
        list::list_new(list_pointer, ram);
        for item in list {
            list::list_push(list_pointer, item.encode(), ram);
        }
    }
}
338
#[cfg(test)]
mod benches {
    use super::*;
    use crate::test_prelude::*;

    /// Benchmarks zipping a list of extension-field elements with a list of digests.
    #[test]
    fn benchmark() {
        let xfe_with_digest = Zip::new(DataType::Xfe, DataType::Digest);
        ShadowedFunction::new(xfe_with_digest).bench();
    }
}