Skip to main content

tasm_lib/array/
inner_product_of_xfes.rs

1use std::collections::HashMap;
2
3use triton_vm::prelude::*;
4
5use crate::data_type::ArrayType;
6use crate::prelude::*;
7use crate::traits::basic_snippet::Reviewer;
8use crate::traits::basic_snippet::SignOffFingerprint;
9
/// Compute the inner product of two equally long arrays of
/// [`XFieldElement`]s.
///
/// The number of elements per array is fixed when the snippet is created, see
/// [`InnerProductOfXfes::length`].
///
/// ### Behavior
///
/// ```text
/// BEFORE: _ *a *b
/// AFTER:  _ [inner_product: XFieldElement]
/// ```
///
/// ### Preconditions
///
/// None.
///
/// ### Postconditions
///
/// None.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub struct InnerProductOfXfes {
    /// The number of elements in each of the two arrays.
    pub length: usize,
}

impl InnerProductOfXfes {
    /// Create a snippet for arrays that hold `length` elements each.
    pub fn new(length: usize) -> Self {
        InnerProductOfXfes { length }
    }
}
36
37impl BasicSnippet for InnerProductOfXfes {
38    fn parameters(&self) -> Vec<(DataType, String)> {
39        let argument_type = DataType::Array(Box::new(ArrayType {
40            element_type: DataType::Xfe,
41            length: self.length,
42        }));
43
44        vec![
45            (argument_type.clone(), "*a".to_owned()),
46            (argument_type, "*b".to_owned()),
47        ]
48    }
49
50    fn return_values(&self) -> Vec<(DataType, String)> {
51        vec![(DataType::Xfe, "inner_product".to_owned())]
52    }
53
54    fn entrypoint(&self) -> String {
55        format!("tasmlib_array_inner_product_of_{}_xfes", self.length)
56    }
57
58    fn code(&self, _: &mut Library) -> Vec<LabelledInstruction> {
59        triton_asm!(
60            // BEFORE: _ *a *b
61            // AFTER:  _ [inner_product: XFieldElement]
62            {self.entrypoint()}:
63
64                push 0
65                push 0
66                push 0
67                // _ *a *b [0: XFE]
68
69                pick 4
70                pick 4
71                // _ [0: XFE] *a *b
72
73                {&triton_asm![xx_dot_step; self.length]}
74                // _ [acc: XFE] *garbage0 *garbage1
75
76                pop 2
77                // _ [acc: XFE]
78
79                return
80        )
81    }
82
83    fn sign_offs(&self) -> HashMap<Reviewer, SignOffFingerprint> {
84        let mut sign_offs = HashMap::new();
85
86        if self.length == 4 {
87            sign_offs.insert(Reviewer("ferdinand"), 0x6c3e24944691423f.into());
88        }
89
90        sign_offs
91    }
92}
93
#[cfg(test)]
mod tests {
    use super::*;
    use crate::rust_shadowing_helper_functions::array::array_from_memory;
    use crate::rust_shadowing_helper_functions::array::insert_as_array;
    use crate::rust_shadowing_helper_functions::array::insert_random_array;
    use crate::test_prelude::*;

    impl Accessor for InnerProductOfXfes {
        /// Rust reference implementation: read both arrays from memory and
        /// push their inner product onto the stack.
        fn rust_shadow(
            &self,
            stack: &mut Vec<BFieldElement>,
            memory: &HashMap<BFieldElement, BFieldElement>,
        ) -> Result<(), RustShadowError> {
            // `*b` is on top of the stack, so it is popped first
            let b = array_from_memory::<XFieldElement>(
                stack.pop().ok_or(RustShadowError::StackUnderflow)?,
                self.length,
                memory,
            );
            let a = array_from_memory::<XFieldElement>(
                stack.pop().ok_or(RustShadowError::StackUnderflow)?,
                self.length,
                memory,
            );
            let inner_product: XFieldElement = a.into_iter().zip(b).map(|(a, b)| a * b).sum();

            push_encodable(stack, &inner_product);
            Ok(())
        }

        /// An initial state with two random arrays at random locations in
        /// memory. Fully determined by `seed`.
        fn pseudorandom_initial_state(
            &self,
            seed: [u8; 32],
            _: Option<BenchmarkCase>,
        ) -> AccessorInitialState {
            let mut rng = StdRng::from_seed(seed);
            let pointer_a = rng.random();

            // Pointers address words, not elements. The offset between the two
            // arrays must therefore be bounded by the size of an array in
            // words, so that `*b` starts no earlier than where `*a` ends.
            let array_size = self.length * DataType::Xfe.stack_size();
            let pointer_b_offset = rng.random_range(array_size..usize::MAX - array_size);
            let pointer_b = pointer_a + bfe!(pointer_b_offset);

            let mut memory = HashMap::default();
            insert_random_array(&DataType::Xfe, pointer_a, self.length, &mut memory);
            insert_random_array(&DataType::Xfe, pointer_b, self.length, &mut memory);

            let mut stack = self.init_stack_for_isolated_run();
            stack.push(pointer_a);
            stack.push(pointer_b);

            AccessorInitialState { stack, memory }
        }

        /// A single corner case: both pointers refer to memory that has never
        /// been written to.
        fn corner_case_initial_states(&self) -> Vec<AccessorInitialState> {
            // NOTE(review): presumably, uninitialized memory reads as zero,
            // making both arrays all-zero – confirm against the VM's semantics
            let all_zeros = AccessorInitialState {
                stack: [self.init_stack_for_isolated_run(), bfe_vec![0, 1_u64 << 40]].concat(),
                memory: HashMap::default(),
            };

            vec![all_zeros]
        }
    }

    /// Compare the snippet against its Rust shadow for many array lengths,
    /// including the empty array.
    #[macro_rules_attr::apply(test)]
    fn inner_product_of_xfes_pbt() {
        for test_case in (0..20).chain(100..110).map(InnerProductOfXfes::new) {
            ShadowedAccessor::new(test_case).test()
        }
    }

    /// Check the snippet against a small, hand-computed inner product.
    #[macro_rules_attr::apply(test)]
    fn inner_product_unit_test() {
        let a = xfe_vec![[3, 0, 0], [5, 0, 0]];
        let b = xfe_vec![[501, 0, 0], [1003, 0, 0]];
        let inner_product = xfe!([3 * 501 + 5 * 1003, 0, 0]);

        // Sanity check of the hand-computed expectation. A plain `assert_eq!`
        // is used so that the check also runs when debug assertions are off.
        let rust_inner_product = a
            .iter()
            .zip(&b)
            .map(|(&a, &b)| a * b)
            .sum::<XFieldElement>();
        assert_eq!(inner_product, rust_inner_product);

        let mut memory = HashMap::default();
        let pointer_a = bfe!(1_u64 << 44);
        let pointer_b = bfe!(1_u64 << 45);
        insert_as_array(pointer_a, &mut memory, a);
        insert_as_array(pointer_b, &mut memory, b);

        let snippet = InnerProductOfXfes::new(2);
        let mut initial_stack = snippet.init_stack_for_isolated_run();
        initial_stack.push(pointer_a);
        initial_stack.push(pointer_b);

        let mut expected_final_stack = snippet.init_stack_for_isolated_run();
        push_encodable(&mut expected_final_stack, &inner_product);

        test_rust_equivalence_given_complete_state(
            &ShadowedAccessor::new(snippet),
            &initial_stack,
            &[],
            &NonDeterminism::default().with_ram(memory),
            &None,
            Some(&expected_final_stack),
        );
    }
}
199
#[cfg(test)]
mod benches {
    use triton_vm::table::master_table::MasterAuxTable;
    use triton_vm::table::master_table::MasterMainTable;

    use super::*;
    use crate::test_prelude::*;

    /// Benchmark the snippet for two round array lengths as well as for
    /// lengths derived from the dimensions of Triton VM's master tables.
    #[macro_rules_attr::apply(test)]
    fn benchmark() {
        let num_columns = MasterMainTable::NUM_COLUMNS + MasterAuxTable::NUM_COLUMNS;
        let num_constraints = MasterAuxTable::NUM_CONSTRAINTS;

        for length in [100, 200, num_columns, num_constraints] {
            ShadowedAccessor::new(InnerProductOfXfes::new(length)).bench();
        }
    }
}