// tasm_lib/verifier/fri/barycentric_evaluation.rs

1use triton_vm::prelude::*;
2use twenty_first::math::x_field_element::EXTENSION_DEGREE;
3
4use crate::arithmetic::bfe::primitive_root_of_unity::PrimitiveRootOfUnity;
5use crate::prelude::*;
6
/// Largest codeword length (in extension-field elements) the snippet accepts.
/// The statically allocated scratch region for the partial terms is sized for
/// exactly this many elements.
const MAX_CODEWORD_LENGTH: u32 = 2_u32.pow(15);
8
/// Evaluates a codeword at an out-of-domain point by means of the barycentric
/// Lagrange evaluation formula.
///
/// The codeword's length has to be a power of 2 and must be at most
/// `MAX_CODEWORD_LENGTH`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct BarycentricEvaluation;
14
15impl BasicSnippet for BarycentricEvaluation {
16    fn inputs(&self) -> Vec<(DataType, String)> {
17        vec![
18            (
19                DataType::List(Box::new(DataType::Xfe)),
20                "codeword".to_owned(),
21            ),
22            (DataType::Xfe, "indeterminate".to_owned()),
23        ]
24    }
25
26    fn outputs(&self) -> Vec<(DataType, String)> {
27        vec![(DataType::Xfe, "evaluation_result".to_owned())]
28    }
29
30    fn entrypoint(&self) -> String {
31        "tasmlib_verifier_fri_barycentric_evaluation".to_owned()
32    }
33
34    fn code(&self, library: &mut Library) -> Vec<LabelledInstruction> {
35        let entrypoint = self.entrypoint();
36        let generator = library.import(Box::new(PrimitiveRootOfUnity));
37        let partial_terms_alloc = library.kmalloc(MAX_CODEWORD_LENGTH * EXTENSION_DEGREE as u32);
38
39        // Partial terms are of the form $ g^i / (g^i - indeterminate) $
40        let calculate_and_store_partial_terms_loop_label =
41            format!("{entrypoint}_partial_terms_loop");
42        let calculate_and_store_partial_terms_code = triton_asm!(
43            // BEGIN:     _ *partial_terms_end_cond *partial_terms[0] generator [-indeterminate] 1
44            // INVARIANT: _ *partial_terms_end_cond *partial_terms[n] generator [-indeterminate] generator_acc
45            {calculate_and_store_partial_terms_loop_label}:
46                dup 3
47                dup 3
48                dup 3
49                dup 3
50                add
51                // _ *partial_terms_end_cond *partial_terms generator [-indeterminate] generator_acc [generator_acc-indeterminate]
52
53                x_invert
54                // _ *partial_terms_end_cond *partial_terms generator [-indeterminate] generator_acc [1/(generator_acc-indeterminate)]
55
56                dup 3
57                xb_mul
58                // _ *partial_terms_end_cond *partial_terms generator [-indeterminate] generator_acc [generator_acc/(generator_acc-indeterminate)]
59
60                pick 8
61                write_mem {EXTENSION_DEGREE}
62                place 5
63                // _ *partial_terms_end_cond *partial_terms' generator [-indeterminate] generator_acc
64
65                dup 4
66                mul
67                // _ *partial_terms_end_cond *partial_terms' generator [-indeterminate] generator_acc'
68
69                recurse_or_return
70        );
71
72        // The numerator is the sum of codeword[n] * partial_terms[n], i.e. the
73        // inner product of the codeword and the partial terms.
74        let numerator_from_partial_sums_loop_label =
75            format!("{entrypoint}_numerator_from_partial_sums");
76        let numerator_from_partial_sums_loop_code = triton_asm!(
77            // BEGIN:     _ *ptec *partial_terms[0] 0 [0; 3]   *codeword[0]
78            // INVARIANT: _ *ptec *partial_terms[n] 0 [acc; 3] *codeword[n]
79            {numerator_from_partial_sums_loop_label}:
80                pick 5
81                xx_dot_step
82                place 5
83                recurse_or_return
84        );
85
86        // The denominator is the sum of all terms in `partial_terms`
87        let denominator_from_partial_sums_loop_label =
88            format!("{entrypoint}_denominator_from_partial_sums");
89        let denominator_from_partial_sums_loop_code = triton_asm!(
90            // START:     _ (*partial_terms-1) *partial_terms_last_word 0 0 [0]
91            // INVARIANT: _ (*partial_terms-1) *partial_terms[n]        0 0 [acc]
92            {denominator_from_partial_sums_loop_label}:
93                // _ (*partial_terms-1) *partial_terms[n] 0 0 [acc]
94
95                pick 5
96                read_mem {EXTENSION_DEGREE}
97                place 8
98                // _ (*partial_terms-1) *partial_terms[n-1] 0 0 [acc] [term]
99
100                xx_add
101                // _ (*partial_terms-1) *partial_terms[n-1] 0 0 [acc']
102
103                recurse_or_return
104        );
105
106        triton_asm!(
107            {entrypoint}:
108                // _ *codeword [indeterminate]
109
110                push -1
111                xb_mul
112                hint neg_indeterminate = stack[0..3]
113                // _ *codeword [-indeterminate]
114
115                /* Prepare stack for call to partial terms' loop */
116                dup 3
117                read_mem 1
118                pop 1
119                // _ *codeword [-indeterminate] codeword_len
120
121                /* assert `codeword_len <= MAX_CODEWORD_LENGTH` */
122                push {MAX_CODEWORD_LENGTH + 1}
123                dup 1
124                lt
125                assert
126                // _ *codeword [-indeterminate] codeword_len
127
128                push 0
129                dup 1
130                // _ *codeword [-indeterminate] codeword_len [codeword_len; as u64]
131
132                /* This call to get the generator will fail if codeword_len
133                   exceeds u32::MAX, or, if it is not a power of 2. Which is
134                   desired behavior. */
135                call {generator}
136                hint generator: BFieldElement = stack[0]
137                // _ *codeword [-indeterminate] codeword_len generator
138
139                pick 1
140                // _ *codeword [-indeterminate] generator codeword_len
141
142                push {EXTENSION_DEGREE}
143                mul
144                push {partial_terms_alloc.write_address()}
145                add
146                // _ *codeword [-indeterminate] generator (*partial_terms_last_word+1)
147
148                place 4
149                place 3
150                push {partial_terms_alloc.write_address()}
151                place 4
152                // _ *codeword (*partial_terms_last_word+1) *partial_terms[0] generator [-indeterminate]
153
154                push 1
155                // _ *codeword (*partial_terms_last_word+1) *partial_terms[0] generator [-indeterminate] generator_acc
156
157                call {calculate_and_store_partial_terms_loop_label}
158                // _ *codeword (*partial_terms_last_word+1) (*partial_terms_last_word+1) generator [-indeterminate] generator_acc
159
160                pop 5
161                pop 1
162                // _ *codeword (*partial_terms_last_word+1)
163
164                place 1
165                push {partial_terms_alloc.write_address()}
166                push 0
167                push 0 push 0 push 0
168                // _ (*partial_terms_last_word+1) *codeword *partial_terms[0] 0 [acc; 3]
169
170                pick 5
171                addi 1
172                // _ (*partial_terms_last_word+1) *partial_terms[0] 0 [acc; 3] *codeword[0]
173
174                call {numerator_from_partial_sums_loop_label}
175                hint numerator: XFieldElement = stack[1..4]
176                // _ (*partial_terms_last_word+1) (*partial_terms_last_word+1) 0 [numerator; 3] *codeword[last + 1]
177
178                pick 4
179                pick 5
180                pop 3
181                // _ (*partial_terms_last_word+1) [numerator; 3]
182
183                pick 3
184                addi -1
185                // _ [numerator; 3] *partial_terms_last_word
186
187                push 0
188                push 0
189                push 0
190                push 0
191                push 0
192                // _ [numerator; 3] *partial_terms_last_word 0 0 [0]
193
194                push {partial_terms_alloc.write_address() - bfe!(1)}
195                hint loop_end_condition = stack[0]
196                place 6
197                // _ [numerator; 3] (*partial_terms - 1) *partial_terms_last_word 0 0 [0]
198
199                call {denominator_from_partial_sums_loop_label}
200                hint denominator: XFieldElement = stack[0..2]
201                // _ [numerator; 3] (*partial_terms - 1) (*partial_terms - 1) 0 0 [denominator]
202
203                place 6
204                place 6
205                place 6
206                pop 4
207                // _ [numerator; 3] [denominator]
208
209                x_invert
210                xx_mul
211                // _ [numerator / denominator]
212                // _ [result]                  <-- rename
213
214                return
215
216            {&calculate_and_store_partial_terms_code}
217            {&numerator_from_partial_sums_loop_code}
218            {&denominator_from_partial_sums_loop_code}
219        )
220    }
221}
222
#[cfg(test)]
mod tests {
    use twenty_first::math::other::random_elements;
    use twenty_first::math::polynomial::barycentric_evaluate;
    use twenty_first::math::traits::PrimitiveRootOfUnity;

    use super::*;
    use crate::library::STATIC_MEMORY_FIRST_ADDRESS;
    use crate::rust_shadowing_helper_functions::list::list_insert;
    use crate::rust_shadowing_helper_functions::list::load_list_with_copy_elements;
    use crate::test_prelude::*;

    /// Compare the snippet against its Rust shadow on the corner cases and on
    /// pseudorandom inputs.
    #[test]
    fn barycentric_evaluation_pbt() {
        let shadowed_snippet = ShadowedFunction::new(BarycentricEvaluation);
        shadowed_snippet.test()
    }

    impl BarycentricEvaluation {
        /// Build an initial state in which `codeword` lives in memory at
        /// `codeword_pointer` and the stack holds the pointer followed by the
        /// indeterminate.
        fn prepare_state(
            &self,
            codeword: Vec<XFieldElement>,
            codeword_pointer: BFieldElement,
            indeterminate: XFieldElement,
        ) -> FunctionInitialState {
            let mut stack = self.init_stack_for_isolated_run();
            stack.push(codeword_pointer);
            stack.extend(indeterminate.coefficients.into_iter().rev());

            let mut memory = HashMap::default();
            list_insert(codeword_pointer, codeword, &mut memory);

            FunctionInitialState { stack, memory }
        }
    }

    impl Function for BarycentricEvaluation {
        /// Reference implementation: pops the arguments, pushes the evaluation
        /// result, and writes the partial terms to the memory locations
        /// computed below.
        fn rust_shadow(
            &self,
            stack: &mut Vec<BFieldElement>,
            memory: &mut HashMap<BFieldElement, BFieldElement>,
        ) {
            let coefficient_0 = stack.pop().unwrap();
            let coefficient_1 = stack.pop().unwrap();
            let coefficient_2 = stack.pop().unwrap();
            let indeterminate = XFieldElement::new([coefficient_0, coefficient_1, coefficient_2]);

            let codeword_pointer = stack.pop().unwrap();
            let raw_codeword =
                load_list_with_copy_elements::<EXTENSION_DEGREE>(codeword_pointer, memory);
            let codeword_length = u32::try_from(raw_codeword.len()).unwrap();
            assert!(codeword_length <= MAX_CODEWORD_LENGTH);

            let codeword: Vec<XFieldElement> =
                raw_codeword.into_iter().map(|x| x.into()).collect();
            let evaluation = barycentric_evaluate(&codeword, indeterminate);

            // Emulate effect on memory: one partial term per codeword element
            let generator =
                BFieldElement::primitive_root_of_unity(u64::from(codeword_length)).unwrap();
            let mut write_pointer = STATIC_MEMORY_FIRST_ADDRESS
                - bfe!(MAX_CODEWORD_LENGTH * EXTENSION_DEGREE as u32 - 1);
            let mut generator_power = bfe!(1);
            for _ in 0..codeword_length {
                let denominator = generator_power - indeterminate;
                let partial_term: XFieldElement = denominator.inverse() * generator_power;
                for coefficient in partial_term.coefficients {
                    memory.insert(write_pointer, coefficient);
                    write_pointer.increment();
                }

                generator_power *= generator;
            }

            stack.extend(evaluation.coefficients.into_iter().rev());
        }

        /// Benchmarks use fixed codeword lengths; otherwise the length is a
        /// random power of two with exponent in `0..=14`.
        fn pseudorandom_initial_state(
            &self,
            seed: [u8; 32],
            bench_case: Option<BenchmarkCase>,
        ) -> FunctionInitialState {
            let mut rng = StdRng::from_seed(seed);
            let codeword_length = match bench_case {
                None => 1 << rng.random_range(0..=14),
                Some(BenchmarkCase::CommonCase) => 256,
                Some(BenchmarkCase::WorstCase) => 512,
            };

            let codeword_pointer = bfe!(rng.random_range(0..=(1u64 << 34)));
            let indeterminate: XFieldElement = rng.random();

            self.prepare_state(
                random_elements(codeword_length),
                codeword_pointer,
                indeterminate,
            )
        }

        /// Short codewords: length one, constant and non-constant of length
        /// two, and constant of length eight.
        fn corner_case_initial_states(&self) -> Vec<FunctionInitialState> {
            let some_codeword_pointer = bfe!(19191919);
            let some_indeterminate = XFieldElement::new([bfe!(1555), bfe!(1556), bfe!(1557)]);

            [
                xfe_vec![155],
                xfe_vec![155, 155],
                xfe_vec![155, 1_919_191_919],
                xfe_vec![155; 8],
            ]
            .into_iter()
            .map(|codeword| {
                self.prepare_state(codeword, some_codeword_pointer, some_indeterminate)
            })
            .collect()
        }
    }
}
352
#[cfg(test)]
mod benches {
    use super::*;
    use crate::test_prelude::*;

    /// Run the benchmark for the barycentric-evaluation snippet.
    #[test]
    fn bench_barycentric_evaluation() {
        let shadowed_snippet = ShadowedFunction::new(BarycentricEvaluation);
        shadowed_snippet.bench();
    }
}
362}