// tasm_lib/array/sum_of_bfes.rs

1use num::Zero;
2use triton_vm::prelude::*;
3
4use crate::data_type::ArrayType;
5use crate::memory::load_words_from_memory_pop_pointer;
6use crate::prelude::*;
7
/// Snippet that sums a fixed-length array of `BFieldElement`s read from memory.
///
/// The array length is baked into the snippet, so every distinct length yields
/// a distinct snippet (and a distinct entrypoint label).
///
/// Stack behavior:
/// - BEFORE: `_ *array`
/// - AFTER:  `_ sum`
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub struct SumOfBfes {
    /// Number of base field elements in the array.
    length: usize,
}

impl SumOfBfes {
    /// Creates a snippet that sums an array of exactly `length` base field elements.
    ///
    /// The `length` field is private, so this is the only way to build the
    /// snippet from outside this module.
    pub const fn new(length: usize) -> Self {
        Self { length }
    }
}
12
13impl BasicSnippet for SumOfBfes {
14    fn parameters(&self) -> Vec<(DataType, String)> {
15        vec![(
16            DataType::Array(Box::new(ArrayType {
17                element_type: DataType::Bfe,
18                length: self.length,
19            })),
20            "*array".to_owned(),
21        )]
22    }
23
24    fn return_values(&self) -> Vec<(DataType, String)> {
25        vec![(DataType::Bfe, "sum".to_owned())]
26    }
27
28    fn entrypoint(&self) -> String {
29        format!("tasmlib_array_sum_of_{}_bfes", self.length)
30    }
31
32    fn code(&self, _: &mut Library) -> Vec<LabelledInstruction> {
33        let move_pointer_to_last_word = match self.length {
34            0 | 1 => triton_asm!(),
35            n => triton_asm!( addi {n - 1} ),
36        };
37
38        let load_all_elements_to_stack = load_words_from_memory_pop_pointer(self.length);
39
40        let sum = if self.length.is_zero() {
41            triton_asm!(push 0)
42        } else {
43            triton_asm![add; self.length - 1]
44        };
45
46        triton_asm!(
47            {self.entrypoint()}:
48                // _ *array
49
50                {&move_pointer_to_last_word}
51                // _ *last_word
52
53                {&load_all_elements_to_stack}
54                // _ [elements]
55
56                {&sum}
57                // _ sum
58
59                return
60        )
61    }
62}
63
#[cfg(test)]
mod tests {
    use num_traits::ConstZero;

    use super::*;
    use crate::rust_shadowing_helper_functions::array::insert_random_array;
    use crate::test_prelude::*;

    impl Function for SumOfBfes {
        /// Reference implementation: pops the array pointer, adds up `self.length`
        /// consecutive memory words starting there, and pushes the sum.
        fn rust_shadow(
            &self,
            stack: &mut Vec<BFieldElement>,
            memory: &mut HashMap<BFieldElement, BFieldElement>,
        ) -> Result<(), RustShadowError> {
            let mut address = stack.pop().ok_or(RustShadowError::StackUnderflow)?;

            // Addresses absent from `memory` contribute zero to the sum.
            let mut sum = BFieldElement::ZERO;
            for _ in 0..self.length {
                sum += memory.get(&address).copied().unwrap_or(BFieldElement::ZERO);
                address.increment();
            }

            stack.push(sum);
            Ok(())
        }

        /// Places a random array at a pointer drawn from the seeded RNG.
        fn pseudorandom_initial_state(
            &self,
            seed: [u8; 32],
            _: Option<BenchmarkCase>,
        ) -> FunctionInitialState {
            let mut rng = StdRng::from_seed(seed);
            let array_pointer = rng.random();
            self.prepare_state(array_pointer)
        }

        /// Pointer 500 with empty memory; the shadow treats every element as zero.
        fn corner_case_initial_states(&self) -> Vec<FunctionInitialState> {
            let mut stack = self.init_stack_for_isolated_run();
            stack.push(BFieldElement::new(500));

            let all_zeros = FunctionInitialState {
                stack,
                memory: HashMap::default(),
            };

            vec![all_zeros]
        }
    }

    impl SumOfBfes {
        /// Builds an initial state whose memory holds an array of `self.length`
        /// elements (populated by `insert_random_array`) at `array_pointer`, and
        /// whose stack ends with that pointer.
        fn prepare_state(&self, array_pointer: BFieldElement) -> FunctionInitialState {
            let stack = {
                let mut stack = self.init_stack_for_isolated_run();
                stack.push(array_pointer);
                stack
            };

            let mut memory = HashMap::default();
            insert_random_array(&DataType::Bfe, array_pointer, self.length, &mut memory);

            FunctionInitialState { stack, memory }
        }
    }

    #[macro_rules_attr::apply(test)]
    fn sum_bfes_pbt() {
        let lengths = (0..20).chain(100..110);
        for length in lengths {
            ShadowedFunction::new(SumOfBfes { length }).test();
        }
    }
}
136
#[cfg(test)]
mod benches {
    use super::*;
    use crate::test_prelude::*;

    /// Benchmarks the snippet for an array of `length` elements.
    fn bench_for_length(length: usize) {
        ShadowedFunction::new(SumOfBfes { length }).bench();
    }

    #[macro_rules_attr::apply(test)]
    fn sum_bfes_bench_100() {
        bench_for_length(100);
    }

    #[macro_rules_attr::apply(test)]
    fn sum_bfes_bench_200() {
        bench_for_length(200);
    }
}