tasm_lib/array/
inner_product_of_xfes.rs

1use std::collections::HashMap;
2
3use triton_vm::prelude::*;
4
5use crate::data_type::ArrayType;
6use crate::prelude::*;
7use crate::traits::basic_snippet::Reviewer;
8use crate::traits::basic_snippet::SignOffFingerprint;
9
/// Snippet computing the inner product of two equally long arrays of
/// [`XFieldElement`]s.
///
/// Both arguments are pointers to arrays of `length` elements; the snippet is
/// specialized for that `length`.
///
/// ### Behavior
///
/// ```text
/// BEFORE: _ *a *b
/// AFTER:  _ [inner_product: XFieldElement]
/// ```
///
/// ### Preconditions
///
/// None.
///
/// ### Postconditions
///
/// None.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct InnerProductOfXfes {
    /// Number of elements in each of the two input arrays.
    pub length: usize,
}
30
31impl InnerProductOfXfes {
32    pub fn new(length: usize) -> Self {
33        Self { length }
34    }
35}
36
37impl BasicSnippet for InnerProductOfXfes {
38    fn inputs(&self) -> Vec<(DataType, String)> {
39        let argument_type = DataType::Array(Box::new(ArrayType {
40            element_type: DataType::Xfe,
41            length: self.length,
42        }));
43
44        vec![
45            (argument_type.clone(), "*a".to_owned()),
46            (argument_type, "*b".to_owned()),
47        ]
48    }
49
50    fn outputs(&self) -> Vec<(DataType, String)> {
51        vec![(DataType::Xfe, "inner_product".to_owned())]
52    }
53
54    fn entrypoint(&self) -> String {
55        format!("tasmlib_array_inner_product_of_{}_xfes", self.length)
56    }
57
58    fn code(&self, _: &mut Library) -> Vec<LabelledInstruction> {
59        triton_asm!(
60            // BEFORE: _ *a *b
61            // AFTER:  _ [inner_product: XFieldElement]
62            {self.entrypoint()}:
63
64                push 0
65                push 0
66                push 0
67                // _ *a *b [0: XFE]
68
69                pick 4
70                pick 4
71                // _ [0: XFE] *a *b
72
73                {&triton_asm![xx_dot_step; self.length]}
74                // _ [acc: XFE] *garbage0 *garbage1
75
76                pop 2
77                // _ [acc: XFE]
78
79                return
80        )
81    }
82
83    fn sign_offs(&self) -> HashMap<Reviewer, SignOffFingerprint> {
84        let mut sign_offs = HashMap::new();
85
86        if self.length == 4 {
87            sign_offs.insert(Reviewer("ferdinand"), 0x154bf4aa5a53bef7.into());
88        }
89
90        sign_offs
91    }
92}
93
#[cfg(test)]
mod tests {
    use super::*;
    use crate::rust_shadowing_helper_functions::array::array_from_memory;
    use crate::rust_shadowing_helper_functions::array::insert_as_array;
    use crate::rust_shadowing_helper_functions::array::insert_random_array;
    use crate::test_prelude::*;

    impl Accessor for InnerProductOfXfes {
        /// Reference implementation: pops `*b` and `*a`, loads both arrays
        /// from `memory`, and pushes their inner product onto the stack.
        fn rust_shadow(
            &self,
            stack: &mut Vec<BFieldElement>,
            memory: &HashMap<BFieldElement, BFieldElement>,
        ) {
            // `*b` sits on top of the stack, so it is popped first.
            let b = array_from_memory::<XFieldElement>(stack.pop().unwrap(), self.length, memory);
            let a = array_from_memory::<XFieldElement>(stack.pop().unwrap(), self.length, memory);
            let inner_product: XFieldElement = a.into_iter().zip(b).map(|(a, b)| a * b).sum();

            push_encodable(stack, &inner_product);
        }

        /// Build an initial state with two random arrays of `self.length`
        /// elements at pseudorandom addresses derived from `seed`.
        fn pseudorandom_initial_state(
            &self,
            seed: [u8; 32],
            _: Option<BenchmarkCase>,
        ) -> AccessorInitialState {
            // Assumes each `XFieldElement` occupies three consecutive memory
            // words (one per coefficient). The gap between the two arrays is
            // measured in words, not in elements, so that they do not overlap.
            const WORDS_PER_XFE: usize = 3;

            let mut rng = StdRng::from_seed(seed);
            let pointer_a = rng.random();
            let array_size = WORDS_PER_XFE * self.length;
            let pointer_b_offset = rng.random_range(array_size..usize::MAX - array_size);
            let pointer_b = pointer_a + bfe!(pointer_b_offset);

            let mut memory = HashMap::default();
            insert_random_array(&DataType::Xfe, pointer_a, self.length, &mut memory);
            insert_random_array(&DataType::Xfe, pointer_b, self.length, &mut memory);

            let mut stack = self.init_stack_for_isolated_run();
            stack.push(pointer_a);
            stack.push(pointer_b);

            AccessorInitialState { stack, memory }
        }

        /// Corner case: both pointers refer to memory that was never written.
        fn corner_case_initial_states(&self) -> Vec<AccessorInitialState> {
            let all_zeros = AccessorInitialState {
                stack: [self.init_stack_for_isolated_run(), bfe_vec![0, 1_u64 << 40]].concat(),
                memory: HashMap::default(),
            };

            vec![all_zeros]
        }
    }

    /// Compare the snippet against its Rust shadow for a range of array
    /// lengths, including the empty array.
    #[test]
    fn inner_product_of_xfes_pbt() {
        for test_case in (0..20).chain(100..110).map(InnerProductOfXfes::new) {
            ShadowedAccessor::new(test_case).test()
        }
    }

    /// Hand-computed example with two arrays of two elements each.
    #[test]
    fn inner_product_unit_test() {
        let a = xfe_vec![[3, 0, 0], [5, 0, 0]];
        let b = xfe_vec![[501, 0, 0], [1003, 0, 0]];
        let inner_product = xfe!([3 * 501 + 5 * 1003, 0, 0]);

        // Sanity-check the hand-computed expectation. A plain `assert_eq!` is
        // used so the check also runs when debug assertions are disabled.
        let rust_inner_product = a
            .iter()
            .zip(&b)
            .map(|(&a, &b)| a * b)
            .sum::<XFieldElement>();
        assert_eq!(inner_product, rust_inner_product);

        let mut memory = HashMap::default();
        let pointer_a = bfe!(1_u64 << 44);
        let pointer_b = bfe!(1_u64 << 45);
        insert_as_array(pointer_a, &mut memory, a);
        insert_as_array(pointer_b, &mut memory, b);

        let snippet = InnerProductOfXfes::new(2);
        let mut initial_stack = snippet.init_stack_for_isolated_run();
        initial_stack.push(pointer_a);
        initial_stack.push(pointer_b);

        let mut expected_final_stack = snippet.init_stack_for_isolated_run();
        push_encodable(&mut expected_final_stack, &inner_product);

        test_rust_equivalence_given_complete_state(
            &ShadowedAccessor::new(snippet),
            &initial_stack,
            &[],
            &NonDeterminism::default().with_ram(memory),
            &None,
            Some(&expected_final_stack),
        );
    }
}
190
#[cfg(test)]
mod benches {
    use triton_vm::table::master_table::MasterAuxTable;
    use triton_vm::table::master_table::MasterMainTable;

    use super::*;
    use crate::test_prelude::*;

    /// Benchmark the snippet for two fixed lengths and for two lengths taken
    /// from Triton VM's master tables: the total number of columns and the
    /// number of constraints.
    #[test]
    fn benchmark() {
        let num_columns = MasterMainTable::NUM_COLUMNS + MasterAuxTable::NUM_COLUMNS;
        let lengths = [100, 200, num_columns, MasterAuxTable::NUM_CONSTRAINTS];

        for length in lengths {
            ShadowedAccessor::new(InnerProductOfXfes::new(length)).bench();
        }
    }
}
208}