Skip to main content

tasm_lib/array/
sum_of_xfes.rs

1use num::Zero;
2use triton_vm::prelude::*;
3use twenty_first::math::x_field_element::EXTENSION_DEGREE;
4
5use crate::data_type::ArrayType;
6use crate::memory::load_words_from_memory_pop_pointer;
7use crate::prelude::*;
8
9#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
10pub struct SumOfXfes {
11    length: usize,
12}
13
14impl SumOfXfes {
15    fn input_type(&self) -> DataType {
16        DataType::Array(Box::new(ArrayType {
17            element_type: DataType::Xfe,
18            length: self.length,
19        }))
20    }
21}
22
23impl BasicSnippet for SumOfXfes {
24    fn parameters(&self) -> Vec<(DataType, String)> {
25        vec![(self.input_type(), "*array".to_owned())]
26    }
27
28    fn return_values(&self) -> Vec<(DataType, String)> {
29        vec![(DataType::Xfe, "sum".to_owned())]
30    }
31
32    fn entrypoint(&self) -> String {
33        format!("tasmlib_array_sum_of_{}_xfes", self.length)
34    }
35
36    fn code(&self, _library: &mut Library) -> Vec<LabelledInstruction> {
37        let entrypoint = self.entrypoint();
38
39        let move_pointer_to_last_word = match self.length {
40            0 => triton_asm!(),
41            n => {
42                let word_length_minus_one = EXTENSION_DEGREE * n - 1;
43                triton_asm!(
44                    push {word_length_minus_one}
45                    add
46                )
47            }
48        };
49
50        let load_all_elements_to_stack =
51            load_words_from_memory_pop_pointer(EXTENSION_DEGREE * self.length);
52
53        let sum = if self.length.is_zero() {
54            triton_asm!(
55                push 0
56                push 0
57                push 0
58            )
59        } else {
60            vec![triton_asm!(xx_add); self.length - 1].concat()
61        };
62
63        triton_asm!(
64            {entrypoint}:
65                // _ *array
66
67                {&move_pointer_to_last_word}
68                // _ *last_word
69
70                {&load_all_elements_to_stack}
71                // _ [elements]
72
73                {&sum}
74                // _ sum
75
76                return
77        )
78    }
79}
80
#[cfg(test)]
mod tests {
    use num::One;
    use num::Zero;
    use num_traits::ConstZero;

    use super::*;
    use crate::rust_shadowing_helper_functions::array::insert_as_array;
    use crate::rust_shadowing_helper_functions::array::insert_random_array;
    use crate::test_prelude::*;

    impl Function for SumOfXfes {
        /// Reference implementation of the snippet.
        ///
        /// Pops the array pointer and reads `self.length` consecutive XFEs,
        /// coefficient 0 first, starting at that address. It then pushes
        /// their sum so that coefficient 0 ends up on top of the stack.
        fn rust_shadow(
            &self,
            stack: &mut Vec<BFieldElement>,
            memory: &mut HashMap<BFieldElement, BFieldElement>,
        ) -> Result<(), RustShadowError> {
            let mut address = stack.pop().ok_or(RustShadowError::StackUnderflow)?;

            // Addresses absent from `memory` are read as zero.
            let mut read_next_word = || {
                let word = memory
                    .get(&address)
                    .copied()
                    .unwrap_or(BFieldElement::ZERO);
                address.increment();
                word
            };

            let sum: XFieldElement = (0..self.length)
                .map(|_| {
                    let mut element = XFieldElement::zero();
                    for coefficient in &mut element.coefficients {
                        *coefficient = read_next_word();
                    }
                    element
                })
                .sum();

            stack.extend(sum.coefficients.into_iter().rev());
            Ok(())
        }

        fn pseudorandom_initial_state(
            &self,
            seed: [u8; 32],
            _bench_case: Option<BenchmarkCase>,
        ) -> crate::traits::function::FunctionInitialState {
            let raw_pointer: u64 = StdRng::from_seed(seed).random();
            self.prepare_state(BFieldElement::new(raw_pointer))
        }

        /// Two edge cases, both with the array at address 500: memory left
        /// entirely empty, and an array whose every element is one.
        fn corner_case_initial_states(&self) -> Vec<FunctionInitialState> {
            let array_pointer = BFieldElement::new(500);
            let state_with_memory = |memory: HashMap<BFieldElement, BFieldElement>| {
                let mut stack = self.init_stack_for_isolated_run();
                stack.push(array_pointer);
                FunctionInitialState { stack, memory }
            };

            let zeros_state = state_with_memory(HashMap::default());

            let mut ones_memory = HashMap::default();
            insert_as_array(
                array_pointer,
                &mut ones_memory,
                vec![XFieldElement::one(); self.length],
            );
            let ones_state = state_with_memory(ones_memory);

            vec![zeros_state, ones_state]
        }
    }

    impl SumOfXfes {
        /// Places a random array of `self.length` XFEs at `array_pointer`.
        /// The returned stack has that pointer on top.
        fn prepare_state(&self, array_pointer: BFieldElement) -> FunctionInitialState {
            let mut memory = HashMap::default();
            insert_random_array(&DataType::Xfe, array_pointer, self.length, &mut memory);

            let stack = [self.init_stack_for_isolated_run(), vec![array_pointer]].concat();
            FunctionInitialState { stack, memory }
        }
    }

    #[macro_rules_attr::apply(test)]
    fn sum_xfes_pbt() {
        for length in (0..20).chain(100..110) {
            ShadowedFunction::new(SumOfXfes { length }).test();
        }
    }

    #[macro_rules_attr::apply(test)]
    fn xfe_array_sum_unit_test() {
        let xfe = |c0: u64, c1: u64, c2: u64| {
            XFieldElement::new([
                BFieldElement::new(c0),
                BFieldElement::new(c1),
                BFieldElement::new(c2),
            ])
        };

        let xfes = vec![xfe(100, 0, 10), xfe(200, 0, 4)];
        let expected_sum = xfe(300, 0, 14);
        assert_eq!(xfes.iter().cloned().sum::<XFieldElement>(), expected_sum);

        let array_pointer = BFieldElement::new(1u64 << 44);
        let mut memory = HashMap::default();
        insert_as_array(array_pointer, &mut memory, xfes);

        let snippet = SumOfXfes { length: 2 };

        let mut init_stack = snippet.init_stack_for_isolated_run();
        init_stack.push(array_pointer);

        let mut expected_final_stack = snippet.init_stack_for_isolated_run();
        expected_final_stack.extend(expected_sum.coefficients.into_iter().rev());

        test_rust_equivalence_given_complete_state(
            &ShadowedFunction::new(snippet),
            &init_stack,
            &[],
            &NonDeterminism::default().with_ram(memory),
            &None,
            Some(&expected_final_stack),
        );
    }
}
220
#[cfg(test)]
mod benches {
    use super::*;
    use crate::test_prelude::*;

    /// Benchmarks the snippet on an array of `length` XFEs.
    fn bench_with_length(length: usize) {
        ShadowedFunction::new(SumOfXfes { length }).bench();
    }

    #[macro_rules_attr::apply(test)]
    fn sum_xfes_bench_100() {
        bench_with_length(100);
    }

    #[macro_rules_attr::apply(test)]
    fn sum_xfes_bench_200() {
        bench_with_length(200);
    }
}