tasm_lib/list/
sum_bfes.rs

1use triton_vm::prelude::*;
2
3use crate::prelude::*;
4
/// Calculate the sum of the `BFieldElement`s in a list.
///
/// Stack signature: `_ *list` -> `_ sum`, where `*list` points to the list's length word and
/// `sum` is the field sum of all list elements (`0` for an empty list).
// Only the `#[cfg(test)]` modules of this file construct this snippet, so non-test builds
// would otherwise report the type as dead code.
#[allow(dead_code)]
#[derive(Debug, Default, Copy, Clone, Eq, PartialEq, Hash)]
struct SumOfBfes;
9
10impl BasicSnippet for SumOfBfes {
11    fn inputs(&self) -> Vec<(DataType, String)> {
12        vec![(
13            // For naming the input argument, I just follow what `Rust` calls this argument
14            DataType::List(Box::new(DataType::Bfe)),
15            "self".to_owned(),
16        )]
17    }
18
19    fn outputs(&self) -> Vec<(DataType, String)> {
20        vec![(DataType::Bfe, "sum".to_owned())]
21    }
22
23    fn entrypoint(&self) -> String {
24        format!("tasmlib_list_sum_{}", DataType::Bfe.label_friendly_name())
25    }
26
27    fn code(&self, _library: &mut Library) -> Vec<LabelledInstruction> {
28        let entrypoint = self.entrypoint();
29        let accumulate_five_elements_loop_label = format!("{entrypoint}_acc_5_elements");
30
31        let accumulate_five_elements_loop = triton_asm!(
32            // Invariant: _ *end_loop *element acc
33            {accumulate_five_elements_loop_label}:
34
35                dup 2
36                dup 2
37                eq
38                skiz
39                    return
40                // _ *end_loop *element acc
41
42                dup 1
43                read_mem 5
44                // _ *end_loop *element acc [elements] (*element - 5)
45
46                swap 7
47                pop 1
48                // _ *end_loop (*element - 5) acc [elements]
49                // _ *end_loop *element' acc [elements]
50
51                add
52                add
53                add
54                add
55                add
56                // _ *end_loop *element' acc'
57
58                recurse
59        );
60
61        let accumulate_one_element_loop_label = format!("{entrypoint}_acc_1_element");
62        let accumulate_one_element_loop = triton_asm!(
63            // Invariant: _ *end_loop *element acc
64            {accumulate_one_element_loop_label}:
65                dup 2
66                dup 2
67                eq
68                skiz
69                    return
70                // _ *end_loop *element acc
71
72                dup 1
73                read_mem 1
74                swap 3
75                pop 1
76                // _ *end_loop (*element - 1) acc element
77
78                add
79                // _ *end_loop *element' acc'
80
81                recurse
82        );
83
84        triton_asm!(
85            {entrypoint}:
86                // _ *list
87
88                // Get pointer to last element
89                dup 0
90                read_mem 1
91                // _ *list length (*list - 1)
92
93                pop 1
94                // _ *list length
95
96                dup 1
97                dup 1
98                add
99                // _ *list length *last_element
100
101                // Get pointer to *end_loop that is the loop termination condition
102
103                push 5
104                dup 2
105                // _ *list length *last_element 5 length
106
107                div_mod
108                // _ *list length *last_element (length / 5) (length % 5)
109
110                swap 1
111                pop 1
112                // _ *list length *last_element (length % 5)
113
114                dup 3
115                add
116                // _ *list length *last_element *element[length % 5]
117                // _ *list length *last_element *end_loop
118
119                swap 1
120                push 0
121                // _ *list length *end_loop *last_element 0
122
123                call {accumulate_five_elements_loop_label}
124                // _ *list length *end_loop *next_element sum
125
126                swap 1
127                // _ *list length *end_loop sum *next_element
128
129                swap 3
130                // _ *list *next_element *end_loop sum length
131
132                pop 1
133                // _ *list *next_element *end_loop sum
134
135                swap 1
136                pop 1
137                // _ *list *next_element sum
138
139                call {accumulate_one_element_loop_label}
140                // _ *list *list sum
141
142                swap 2
143                pop 2
144                // _ sum
145
146                return
147
148                {&accumulate_five_elements_loop}
149                {&accumulate_one_element_loop}
150        )
151    }
152}
153
#[cfg(test)]
mod tests {
    use num_traits::Zero;

    use super::*;
    use crate::rust_shadowing_helper_functions::list::insert_random_list;
    use crate::rust_shadowing_helper_functions::list::load_list_with_copy_elements;
    use crate::test_prelude::*;

    impl Function for SumOfBfes {
        /// Pops the list pointer and pushes the field sum of the list's elements.
        fn rust_shadow(
            &self,
            stack: &mut Vec<BFieldElement>,
            memory: &mut HashMap<BFieldElement, BFieldElement>,
        ) {
            // Every list element is a single `BFieldElement`.
            const WORDS_PER_ELEMENT: usize = 1;

            let list_pointer = stack.pop().unwrap();
            let elements = load_list_with_copy_elements::<WORDS_PER_ELEMENT>(list_pointer, memory);
            let total = elements
                .iter()
                .fold(BFieldElement::zero(), |acc, element| acc + element[0]);
            stack.push(total);
        }

        fn pseudorandom_initial_state(
            &self,
            seed: [u8; 32],
            bench_case: Option<BenchmarkCase>,
        ) -> FunctionInitialState {
            let mut rng = StdRng::from_seed(seed);

            // The pointer is always drawn first; the length only consumes randomness when no
            // benchmark case is requested.
            let list_pointer = BFieldElement::new(rng.random());
            let list_length = if let Some(case) = bench_case {
                match case {
                    BenchmarkCase::CommonCase => 104,
                    BenchmarkCase::WorstCase => 1004,
                }
            } else {
                rng.random_range(0..200)
            };

            self.prepare_state(list_pointer, list_length)
        }

        fn corner_case_initial_states(&self) -> Vec<FunctionInitialState> {
            // Lengths 0..13 cover every remainder modulo 5 with zero, one and two full
            // five-element chunks.
            (0..13)
                .map(|length| self.prepare_state(BFieldElement::zero(), length))
                .collect()
        }
    }

    impl SumOfBfes {
        /// Writes a random `Bfe` list of `list_length` elements at `list_pointer` and places
        /// that pointer on top of the stack from `init_stack_for_isolated_run`.
        fn prepare_state(
            &self,
            list_pointer: BFieldElement,
            list_length: usize,
        ) -> FunctionInitialState {
            let mut memory = HashMap::default();
            insert_random_list(&DataType::Bfe, list_pointer, list_length, &mut memory);

            let stack = [self.init_stack_for_isolated_run(), vec![list_pointer]].concat();

            FunctionInitialState { stack, memory }
        }
    }

    #[test]
    fn sum_bfes_pbt() {
        let shadowed = ShadowedFunction::new(SumOfBfes);
        shadowed.test();
    }
}
222
#[cfg(test)]
mod benches {
    use super::*;
    use crate::test_prelude::*;

    /// Benchmarks the snippet through its shadowed-function harness.
    #[test]
    fn benchmark() {
        let shadowed = ShadowedFunction::new(SumOfBfes);
        shadowed.bench();
    }
}
232}