tasm_lib/verifier/master_table/
divide_out_zerofiers.rs

1use triton_vm::prelude::LabelledInstruction;
2use triton_vm::prelude::*;
3use triton_vm::table::ConstraintType;
4use triton_vm::table::master_table::MasterAuxTable;
5use twenty_first::math::x_field_element::EXTENSION_DEGREE;
6
7use crate::prelude::*;
8use crate::verifier::master_table::air_constraint_evaluation::AirConstraintEvaluation;
9use crate::verifier::master_table::zerofiers_inverse::ZerofiersInverse;
10
/// Snippet that converts an AIR evaluation into quotient summands by dividing
/// out the zerofiers.
#[derive(Clone, Debug)]
pub struct DivideOutZerofiers;
14
15impl BasicSnippet for DivideOutZerofiers {
16    fn inputs(&self) -> Vec<(DataType, String)> {
17        vec![
18            (DataType::VoidPointer, "*air_evaluation_result".to_string()),
19            (DataType::Xfe, "out_of_domain_point_curr_row".to_owned()),
20            (DataType::U32, "padded_height".to_owned()),
21            (DataType::Bfe, "trace_domain_generator".to_owned()),
22        ]
23    }
24
25    fn outputs(&self) -> Vec<(DataType, String)> {
26        vec![(
27            AirConstraintEvaluation::output_type(),
28            "*quotient_summands".to_owned(),
29        )]
30    }
31
32    fn entrypoint(&self) -> String {
33        "tasmlib_verifier_master_table_divide_out_zerofiers".to_owned()
34    }
35
36    fn code(&self, library: &mut Library) -> Vec<LabelledInstruction> {
37        let entrypoint = self.entrypoint();
38
39        let zerofiers_inverse_alloc =
40            library.kmalloc(ZerofiersInverse::array_size().try_into().unwrap());
41        let zerofiers_inverse_snippet = ZerofiersInverse {
42            zerofiers_inverse_write_address: zerofiers_inverse_alloc.write_address(),
43        };
44        let zerofiers_inverse = library.import(Box::new(zerofiers_inverse_snippet));
45
46        let read_all_air_elements = vec![
47            triton_asm!(
48                // _ *air_elem
49                read_mem { EXTENSION_DEGREE } // _ [air_elem] *air_elem_prev
50            );
51            MasterAuxTable::NUM_CONSTRAINTS
52        ]
53        .concat();
54
55        let mul_and_write = |constraint_type: ConstraintType, num_constraints: usize| {
56            vec![
57                triton_asm!(
58                    // _ [[air_elem]] *air_elem
59
60                    swap 3
61                    swap 2
62                    swap 1
63                    // _ *air_elem [[air_elem]]
64
65                    push {zerofiers_inverse_snippet.zerofier_inv_read_address(constraint_type)}
66                    read_mem {EXTENSION_DEGREE}
67                    pop 1
68
69                    xx_mul
70                    // _ [[air_elem] *air_elem [[air_elem * z_inv]]]
71
72                    swap 1
73                    swap 2
74                    swap 3
75                    // _ [[air_elem] [[air_elem * z_inv]]] *air_elem
76
77                    write_mem {EXTENSION_DEGREE}
78                    // _ [[air_elem] *air_elem_next
79
80                );
81                num_constraints
82            ]
83            .concat()
84        };
85
86        let mul_and_write = [
87            mul_and_write(
88                ConstraintType::Initial,
89                MasterAuxTable::NUM_INITIAL_CONSTRAINTS,
90            ),
91            mul_and_write(
92                ConstraintType::Consistency,
93                MasterAuxTable::NUM_CONSISTENCY_CONSTRAINTS,
94            ),
95            mul_and_write(
96                ConstraintType::Transition,
97                MasterAuxTable::NUM_TRANSITION_CONSTRAINTS,
98            ),
99            mul_and_write(
100                ConstraintType::Terminal,
101                MasterAuxTable::NUM_TERMINAL_CONSTRAINTS,
102            ),
103        ]
104        .concat();
105
106        let jump_to_last_air_element = triton_asm!(
107            // _ *air_elem[0]
108
109            push {MasterAuxTable::NUM_CONSTRAINTS * EXTENSION_DEGREE - 1}
110            add
111            // _ *air_elem_last_word
112        );
113
114        let jump_to_first_air_element = triton_asm!(
115            // _ *air_elem[n]
116
117            push {-((MasterAuxTable::NUM_CONSTRAINTS * EXTENSION_DEGREE) as i32)}
118            add
119            // _ *air_elem_last_word
120        );
121
122        triton_asm!(
123            {entrypoint}:
124                // _ *air_evaluation_result [out_of_domain_point_curr_row] padded_height trace_domain_generator
125
126                call {zerofiers_inverse}
127                // _ *air_evaluation_result
128
129                {&jump_to_last_air_element}
130                // _ *air_elem_last_word
131
132                {&read_all_air_elements}
133                // _ [[air_elem]] (*air_constraints - 1)
134
135                push 1 add
136                // _ [[air_elem]] *air_constraints
137
138                {&mul_and_write}
139                // _ (*air_elem_last + 3)
140
141                {&jump_to_first_air_element}
142
143                return
144        )
145    }
146}
147
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use itertools::Itertools;
    use rand::prelude::*;
    use twenty_first::math::traits::ModPowU32;
    use twenty_first::math::traits::PrimitiveRootOfUnity;

    use super::*;
    use crate::empty_stack;
    use crate::execute_test;

    /// Compares the snippet with the host-side reference on a freshly drawn seed.
    #[test]
    fn divide_out_zerofiers_test() {
        let mut seed = [0u8; 32];
        rand::rng().fill_bytes(&mut seed);
        DivideOutZerofiers.test_equivalence_with_host_machine(seed);
    }

    impl DivideOutZerofiers {
        /// Derive inputs from `seed`, evaluate them on the host and in the VM,
        /// and require both results to be identical.
        fn test_equivalence_with_host_machine(&self, seed: [u8; 32]) {
            let mut rng = StdRng::from_seed(seed);
            let (evaluations, ood_point, height, generator) = Self::random_input_values(&mut rng);

            // Host first, VM second: both print diagnostics.
            let expected = Self::rust_result(evaluations, ood_point, height, generator);
            let actual = self.tasm_result(evaluations, ood_point, height, generator);

            assert_eq!(actual.len(), expected.len());

            // Comparing the sums first keeps the failure message short.
            let actual_sum: XFieldElement = actual.iter().copied().sum();
            let expected_sum: XFieldElement = expected.iter().copied().sum();
            assert_eq!(
                actual_sum,
                expected_sum,
                "\ntasm: [{},...]\nrust: [{},...]",
                actual.iter().take(3).join(","),
                expected.iter().take(3).join(",")
            );
            assert_eq!(actual, expected);
        }

        /// Draw a random AIR evaluation result, a random out-of-domain point,
        /// a power-of-two padded height between 2^8 and 2^31, and the
        /// primitive root of unity of that order.
        pub(super) fn random_input_values(
            rng: &mut StdRng,
        ) -> (
            [XFieldElement; MasterAuxTable::NUM_CONSTRAINTS],
            XFieldElement,
            u32,
            BFieldElement,
        ) {
            let evaluations: [XFieldElement; MasterAuxTable::NUM_CONSTRAINTS] = rng.random();
            let ood_point = rng.random::<XFieldElement>();
            let log2_padded_height: u32 = rng.random_range(8..32);
            let padded_height = 2u32.pow(log2_padded_height);
            let generator =
                BFieldElement::primitive_root_of_unity(u64::from(padded_height)).unwrap();

            (evaluations, ood_point, padded_height, generator)
        }

        /// Execute the snippet in the VM and return the array of quotient
        /// values that is read back from the final state.
        fn tasm_result(
            &self,
            air_evaluation_result: [XFieldElement; MasterAuxTable::NUM_CONSTRAINTS],
            out_of_domain_point_curr_row: XFieldElement,
            padded_height: u32,
            trace_domain_generator: BFieldElement,
        ) -> Vec<XFieldElement> {
            let free_page_pointer = BFieldElement::new(((1u64 << 32) - 3) * (1 << 32));

            let encoding = air_evaluation_result.encode();
            println!(
                "air evaluation result encoded: [{}, ...]",
                encoding.iter().take(4).join(",")
            );

            // Place the encoding word by word, starting at the pointer that is
            // handed to the snippet.
            let memory: HashMap<BFieldElement, BFieldElement> = encoding
                .into_iter()
                .enumerate()
                .map(|(offset, word)| (free_page_pointer + BFieldElement::new(offset as u64), word))
                .collect();

            // Same order as the snippet's declared inputs; the last pushed
            // value ends up on top of the stack.
            let mut stack = empty_stack();
            stack.push(free_page_pointer);
            stack.extend(out_of_domain_point_curr_row.coefficients.into_iter().rev());
            stack.push(BFieldElement::new(padded_height as u64));
            stack.push(trace_domain_generator);

            let final_state = execute_test(
                &self.link_for_isolated_run(),
                &mut stack,
                self.stack_diff(),
                vec![],
                NonDeterminism::default().with_ram(memory),
                None,
            );

            // read the array pointed to by the pointer living on top of the stack
            let (quotients, _) = AirConstraintEvaluation::read_result_from_memory(final_state);
            quotients
        }

        /// Host-side reference: multiply every AIR evaluation with the
        /// zerofier inverse of the constraint group it belongs to. The groups
        /// are laid out as initial, consistency, transition, terminal.
        pub fn rust_result(
            air_evaluation_result: [XFieldElement; MasterAuxTable::NUM_CONSTRAINTS],
            out_of_domain_point_curr_row: XFieldElement,
            padded_height: u32,
            trace_domain_generator: BFieldElement,
        ) -> Vec<XFieldElement> {
            println!("trace domain generator: {trace_domain_generator}");
            println!("padded height: {padded_height}");
            println!("out-of-domain point current row: {out_of_domain_point_curr_row}");

            let ood_point = out_of_domain_point_curr_row;
            let except_last_row = ood_point - trace_domain_generator.inverse();

            let initial_zerofier_inv = (ood_point - bfe!(1)).inverse();
            let consistency_zerofier_inv =
                (ood_point.mod_pow_u32(padded_height) - bfe!(1)).inverse();
            let transition_zerofier_inv = except_last_row * consistency_zerofier_inv;
            let terminal_zerofier_inv = except_last_row.inverse(); // i.e., only last row

            println!("initial zerofier inverse: {initial_zerofier_inv}");
            println!("consistency zerofier inverse: {consistency_zerofier_inv}");
            println!("transition zerofier inverse: {transition_zerofier_inv}");
            println!("terminal zerofier inverse: {terminal_zerofier_inv}");

            let groups = [
                (
                    MasterAuxTable::NUM_INITIAL_CONSTRAINTS,
                    initial_zerofier_inv,
                ),
                (
                    MasterAuxTable::NUM_CONSISTENCY_CONSTRAINTS,
                    consistency_zerofier_inv,
                ),
                (
                    MasterAuxTable::NUM_TRANSITION_CONSTRAINTS,
                    transition_zerofier_inv,
                ),
                (
                    MasterAuxTable::NUM_TERMINAL_CONSTRAINTS,
                    terminal_zerofier_inv,
                ),
            ];

            // Peel one group after the other off the front of the evaluations.
            let mut quotients = Vec::with_capacity(air_evaluation_result.len());
            let mut remaining: &[XFieldElement] = &air_evaluation_result;
            for (group_size, zerofier_inv) in groups {
                let (group, rest) = remaining.split_at(group_size);
                quotients.extend(group.iter().map(|&evaluation| evaluation * zerofier_inv));
                remaining = rest;
            }

            quotients
        }
    }
}
327
#[cfg(test)]
mod bench {
    use std::collections::HashMap;

    use itertools::Itertools;
    use twenty_first::math::traits::PrimitiveRootOfUnity;

    use super::*;
    use crate::empty_stack;
    use crate::test_prelude::*;

    /// Benchmarks the snippet on the initial state produced by
    /// `pseudorandom_initial_state`.
    #[test]
    fn bench_divide_out_zerofiers() {
        ShadowedFunction::new(DivideOutZerofiers).bench();
    }

    impl Function for DivideOutZerofiers {
        fn rust_shadow(
            &self,
            _: &mut Vec<BFieldElement>,
            _: &mut HashMap<BFieldElement, BFieldElement>,
        ) {
            // Never called as we do a more manual test.
            // The more manual test is done bc we don't want to
            // have to simulate all the intermediate calculations
            // that are stored to memory.
            unimplemented!()
        }

        /// Build the initial stack and memory used for benchmarking: a random
        /// AIR evaluation result in memory, and on the stack a pointer to it,
        /// a random out-of-domain point, a padded height of 2^20, and the
        /// matching trace domain generator.
        fn pseudorandom_initial_state(
            &self,
            seed: [u8; 32],
            _: Option<BenchmarkCase>,
        ) -> FunctionInitialState {
            // Used for benchmarking
            let mut rng = StdRng::from_seed(seed);
            let air_evaluation_result =
                rng.random::<[XFieldElement; MasterAuxTable::NUM_CONSTRAINTS]>();
            let ood_point_current_row = rng.random::<XFieldElement>();
            let padded_height = 1 << 20;
            let trace_domain_generator =
                BFieldElement::primitive_root_of_unity(padded_height).unwrap();

            let free_page_pointer = BFieldElement::new(((1u64 << 32) - 3) * (1 << 32));
            let mut memory = HashMap::<BFieldElement, BFieldElement>::new();

            // Write the complete encoding, word by word, starting at the
            // pointer handed to the snippet. This is the layout the
            // correctness test in this file uses. Skipping the first word
            // would shift every element by one word and leave the last word
            // of the array uninitialized.
            for (i, e) in air_evaluation_result.encode().into_iter().enumerate() {
                memory.insert(free_page_pointer + BFieldElement::new(i as u64), e);
            }

            // Same order as the snippet's declared inputs; the last entry
            // ends up on top of the stack.
            let stack = [
                empty_stack(),
                vec![free_page_pointer],
                ood_point_current_row
                    .coefficients
                    .into_iter()
                    .rev()
                    .collect_vec(),
                vec![BFieldElement::new(padded_height), trace_domain_generator],
            ]
            .concat();

            FunctionInitialState { stack, memory }
        }
    }
}