tasm_lib/list/
sum_xfes.rs

1use triton_vm::prelude::*;
2use twenty_first::math::x_field_element::EXTENSION_DEGREE;
3
4use crate::prelude::*;
5
/// Sums all `XFieldElement`s stored in a list; the TASM counterpart of calling
/// `Iterator::sum` on a list of `XFieldElement`s.
// In this file the struct is only constructed inside the `#[cfg(test)]` modules, so non-test
// builds would otherwise flag it as unused.
#[allow(dead_code)]
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
struct SumOfXfes;
10
11impl BasicSnippet for SumOfXfes {
12    fn inputs(&self) -> Vec<(DataType, String)> {
13        vec![(
14            // For naming the input argument, I just follow what `Rust` calls this argument in
15            // its `sum` method.
16            DataType::List(Box::new(DataType::Xfe)),
17            "self".to_owned(),
18        )]
19    }
20
21    fn outputs(&self) -> Vec<(DataType, String)> {
22        vec![(DataType::Xfe, "sum".to_owned())]
23    }
24
25    fn entrypoint(&self) -> String {
26        format!("tasmlib_list_sum_{}", DataType::Xfe.label_friendly_name())
27    }
28
29    fn code(&self, _library: &mut Library) -> Vec<LabelledInstruction> {
30        assert_eq!(
31            3, EXTENSION_DEGREE,
32            "Code only works for extension degree = 3, got: {EXTENSION_DEGREE}"
33        );
34        let entrypoint = self.entrypoint();
35        let accumulate_5_elements_loop_label = format!("{entrypoint}_acc_5_elements_loop");
36        let accumulate_5_elements_loop = triton_asm!(
37            // Invariant: _ *end_loop *element_last_word [acc; 3]
38            {accumulate_5_elements_loop_label}:
39                // _ *end_loop *element_last_word [acc; 3]
40                dup 4
41                dup 4
42                eq
43                skiz
44                    return
45                // _ *end_loop *element_last_word [acc]
46
47                dup 3
48                read_mem 5
49                read_mem 5
50                read_mem 5
51                // _ *end_loop *element_last_word [acc] [elem_4] [elem_3] [elem_2] [elem_1] [elem_0] (*element_last_word - 15)
52
53                pop 1
54                // _ *end_loop *element_last_word [acc] [elem_4] [elem_3] [elem_2] [elem_1] [elem_0]
55
56                xx_add
57                xx_add
58                xx_add
59                xx_add
60                xx_add
61                // _ *end_loop *element_last_word [acc']
62
63                swap 3
64                push -15
65                add
66                swap 3
67                // _ *end_loop *element_last_word' [acc']
68
69                recurse
70        );
71
72        let accumulate_one_element_loop_label = format!("{entrypoint}_acc_1_element_loop");
73        let accumulate_one_element_loop = triton_asm!(
74            // Invariant: _ *end_loop *element_last_word [acc; 3]
75            {accumulate_one_element_loop_label}:
76                // _ *end_loop *element_last_word [acc; 3]
77                dup 4
78                dup 4
79                eq
80                skiz
81                    return
82                // _ *end_loop *element_last_word [acc]
83
84                dup 3
85                read_mem 3
86                // _ *end_loop *element_last_word [acc; 3] [element_last_words; 3] (*element_last_word - 3)
87
88                swap 7
89                pop 1
90                // _ *end_loop (*element_last_word - 3) [acc; 3] [element_last_words; 3]
91
92                xx_add
93                // _ *end_loop *element_last_word' [acc']
94
95                recurse
96        );
97
98        let offset_for_last_word = triton_asm!(
99            // _ len
100            push 3
101            mul // _ offset_last_word
102        );
103
104        triton_asm!(
105            {entrypoint}:
106                // _ *list
107
108                // Calculate pointer to last element
109                dup 0
110                read_mem 1
111                // _ *list len (*list - 1)
112
113                pop 1
114                // _ *list len
115
116                {&offset_for_last_word}
117                // _ *list offset_last_word
118
119                dup 1
120                add
121                // _ *list *last_word
122
123                // Get pointer to *end_loop that is the loop termination condition
124                push 5
125                dup 2
126                // _ *list *last_word 5 *list
127
128                read_mem 1
129                pop 1
130                // _ *list *last_word 5 len
131
132                div_mod
133                // _ *list *last_word (len / 5) (len % 5)
134
135                swap 1
136                pop 1
137                // _ *list *last_word (len % 5)
138
139                push {EXTENSION_DEGREE}
140                mul
141                // _ *list *last_word ((len % 5) * 3)
142
143                dup 2
144                add
145                // _ *list *last_word ((len % 5) * 3 + *list)
146                // _ *list *last_word *end_5_loop
147
148                swap 1
149                push 0
150                push 0
151                push 0
152                // _ *list *end_5_loop *last_word [acc]
153
154                call {accumulate_5_elements_loop_label}
155                // _ *list *end_5_loop *end_5_loop [acc]
156
157                swap 1
158                swap 2
159                swap 3
160                pop 1
161                // _ *list *end_5_loop [acc]
162                // _ *end_condition_1_loop *end_5_loop [acc]
163
164                call {accumulate_one_element_loop_label}
165                // _ *end_condition_1_loop *end_5_loop [acc]
166
167                swap 2
168                swap 4
169                pop 1
170                swap 2
171                pop 1
172                // _ [acc]
173
174                return
175
176                {&accumulate_one_element_loop}
177                {&accumulate_5_elements_loop}
178        )
179    }
180}
181
#[cfg(test)]
mod tests {
    use twenty_first::math::x_field_element::EXTENSION_DEGREE;

    use super::*;
    use crate::rust_shadowing_helper_functions::list::insert_random_list;
    use crate::rust_shadowing_helper_functions::list::load_list_with_copy_elements;
    use crate::test_helpers::test_rust_equivalence_given_complete_state;
    use crate::test_prelude::*;

    impl Function for SumOfXfes {
        /// Pops the list pointer, sums the list's elements natively, and pushes the sum with
        /// its highest coefficient deepest, i.e., coefficient 0 ends up on top.
        fn rust_shadow(
            &self,
            stack: &mut Vec<BFieldElement>,
            memory: &mut HashMap<BFieldElement, BFieldElement>,
        ) {
            let list_pointer = stack.pop().unwrap();
            let elements = load_list_with_copy_elements::<EXTENSION_DEGREE>(list_pointer, memory);

            let total = elements
                .iter()
                .map(|words| XFieldElement::new([words[0], words[1], words[2]]))
                .sum::<XFieldElement>();
            stack.extend(total.coefficients.iter().rev().copied());
        }

        fn pseudorandom_initial_state(
            &self,
            seed: [u8; 32],
            bench_case: Option<BenchmarkCase>,
        ) -> FunctionInitialState {
            let mut rng = StdRng::from_seed(seed);

            // The pointer is drawn before the length so the RNG stream is consumed in a
            // fixed order.
            let list_pointer = rng.random();
            let list_length = match bench_case {
                None => rng.random_range(0..200),
                Some(BenchmarkCase::CommonCase) => 104,
                Some(BenchmarkCase::WorstCase) => 1004,
            };

            self.prepare_state(list_pointer, list_length)
        }

        fn corner_case_initial_states(&self) -> Vec<FunctionInitialState> {
            // Lengths 0..=12 hit every remainder modulo 5 with zero, one, and two passes
            // through the five-element loop.
            let list_pointer = BFieldElement::new((1u64 << 44) + 1313);
            (0..13)
                .map(|list_length| self.prepare_state(list_pointer, list_length))
                .collect()
        }
    }

    impl SumOfXfes {
        /// Writes a random list of `list_length` `XFieldElement`s at `list_pointer` and builds
        /// the initial stack `[isolated-run stack] *list`.
        fn prepare_state(
            &self,
            list_pointer: BFieldElement,
            list_length: usize,
        ) -> FunctionInitialState {
            let mut memory = HashMap::default();
            insert_random_list(&DataType::Xfe, list_pointer, list_length, &mut memory);

            let stack = [self.init_stack_for_isolated_run(), vec![list_pointer]].concat();
            FunctionInitialState { stack, memory }
        }
    }

    #[test]
    fn sum_xfes_pbt() {
        let snippet = ShadowedFunction::new(SumOfXfes);
        snippet.test();
    }

    #[test]
    fn sum_xfes_unit_test() {
        let snippet = SumOfXfes;

        let xfes: Vec<XFieldElement> = (0..2).map(|_| rand::random()).collect();
        let expected_sum: XFieldElement = xfes.iter().copied().sum();

        let list_pointer = BFieldElement::new(1u64 << 33);
        let mut memory = HashMap::default();
        insert_xfe_list_into_memory(list_pointer, xfes, &mut memory);

        let mut init_stack = snippet.init_stack_for_isolated_run();
        init_stack.push(list_pointer);

        // The sum is expected with coefficient 2 deepest and coefficient 0 on top.
        let mut expected_final_stack = snippet.init_stack_for_isolated_run();
        expected_final_stack.extend(expected_sum.coefficients.iter().rev().copied());

        test_rust_equivalence_given_complete_state(
            &ShadowedFunction::new(snippet),
            &init_stack,
            &[],
            &NonDeterminism::default().with_ram(memory),
            &None,
            Some(&expected_final_stack),
        );
    }

    /// Stores `list` at `list_pointer`: the length first, followed by each element's
    /// coefficients in ascending order at consecutive addresses.
    fn insert_xfe_list_into_memory(
        list_pointer: BFieldElement,
        list: Vec<XFieldElement>,
        memory: &mut HashMap<BFieldElement, BFieldElement>,
    ) {
        let mut cursor = list_pointer;
        memory.insert(cursor, BFieldElement::new(list.len() as u64));
        for word in list.iter().flat_map(|xfe| xfe.coefficients) {
            cursor.increment();
            memory.insert(cursor, word);
        }
    }
}
301
#[cfg(test)]
mod benches {
    use super::*;
    use crate::test_prelude::*;

    /// Benchmarks the snippet through `ShadowedFunction::bench`.
    #[test]
    fn benchmark() {
        let shadowed = ShadowedFunction::new(SumOfXfes);
        shadowed.bench();
    }
}