Skip to main content

tasm_lib/verifier/master_table/
divide_out_zerofiers.rs

1use triton_vm::prelude::LabelledInstruction;
2use triton_vm::prelude::*;
3use triton_vm::table::ConstraintType;
4use triton_vm::table::master_table::MasterAuxTable;
5use twenty_first::math::x_field_element::EXTENSION_DEGREE;
6
7use crate::prelude::*;
8use crate::verifier::master_table::air_constraint_evaluation::AirConstraintEvaluation;
9use crate::verifier::master_table::zerofiers_inverse::ZerofiersInverse;
10
/// Divides the zerofiers out of an AIR evaluation.
///
/// This is a unit type: the snippet carries no configuration of its own.
#[derive(Clone, Debug)]
pub struct DivideOutZerofiers;
14
15impl BasicSnippet for DivideOutZerofiers {
16    fn parameters(&self) -> Vec<(DataType, String)> {
17        vec![
18            (DataType::VoidPointer, "*air_evaluation_result".to_string()),
19            (DataType::Xfe, "out_of_domain_point_curr_row".to_owned()),
20            (DataType::U32, "padded_height".to_owned()),
21            (DataType::Bfe, "trace_domain_generator".to_owned()),
22        ]
23    }
24
25    fn return_values(&self) -> Vec<(DataType, String)> {
26        vec![(
27            AirConstraintEvaluation::output_type(),
28            "*quotient_summands".to_owned(),
29        )]
30    }
31
32    fn entrypoint(&self) -> String {
33        "tasmlib_verifier_master_table_divide_out_zerofiers".to_owned()
34    }
35
36    fn code(&self, library: &mut Library) -> Vec<LabelledInstruction> {
37        let entrypoint = self.entrypoint();
38
39        let zerofiers_inverse_alloc =
40            library.kmalloc(ZerofiersInverse::array_size().try_into().unwrap());
41        let zerofiers_inverse_snippet = ZerofiersInverse {
42            zerofiers_inverse_write_address: zerofiers_inverse_alloc.write_address(),
43        };
44        let zerofiers_inverse = library.import(Box::new(zerofiers_inverse_snippet));
45
46        let read_all_air_elements = vec![
47            triton_asm!(
48                // _ *air_elem
49                read_mem { EXTENSION_DEGREE } // _ [air_elem] *air_elem_prev
50            );
51            MasterAuxTable::NUM_CONSTRAINTS
52        ]
53        .concat();
54
55        let mul_and_write = |constraint_type: ConstraintType, num_constraints: usize| {
56            vec![
57                triton_asm!(
58                    // _ [[air_elem]] *air_elem
59
60                    swap 3
61                    swap 2
62                    swap 1
63                    // _ *air_elem [[air_elem]]
64
65                    push {zerofiers_inverse_snippet.zerofier_inv_read_address(constraint_type)}
66                    read_mem {EXTENSION_DEGREE}
67                    pop 1
68
69                    xx_mul
70                    // _ [[air_elem] *air_elem [[air_elem * z_inv]]]
71
72                    swap 1
73                    swap 2
74                    swap 3
75                    // _ [[air_elem] [[air_elem * z_inv]]] *air_elem
76
77                    write_mem {EXTENSION_DEGREE}
78                    // _ [[air_elem] *air_elem_next
79
80                );
81                num_constraints
82            ]
83            .concat()
84        };
85
86        let mul_and_write = [
87            mul_and_write(
88                ConstraintType::Initial,
89                MasterAuxTable::NUM_INITIAL_CONSTRAINTS,
90            ),
91            mul_and_write(
92                ConstraintType::Consistency,
93                MasterAuxTable::NUM_CONSISTENCY_CONSTRAINTS,
94            ),
95            mul_and_write(
96                ConstraintType::Transition,
97                MasterAuxTable::NUM_TRANSITION_CONSTRAINTS,
98            ),
99            mul_and_write(
100                ConstraintType::Terminal,
101                MasterAuxTable::NUM_TERMINAL_CONSTRAINTS,
102            ),
103        ]
104        .concat();
105
106        let jump_to_last_air_element = triton_asm!(
107            // _ *air_elem[0]
108
109            push {MasterAuxTable::NUM_CONSTRAINTS * EXTENSION_DEGREE - 1}
110            add
111            // _ *air_elem_last_word
112        );
113
114        let jump_to_first_air_element = triton_asm!(
115            // _ *air_elem[n]
116
117            push {-((MasterAuxTable::NUM_CONSTRAINTS * EXTENSION_DEGREE) as i32)}
118            add
119            // _ *air_elem_last_word
120        );
121
122        triton_asm!(
123            {entrypoint}:
124                // _ *air_evaluation_result [out_of_domain_point_curr_row] padded_height trace_domain_generator
125
126                call {zerofiers_inverse}
127                // _ *air_evaluation_result
128
129                {&jump_to_last_air_element}
130                // _ *air_elem_last_word
131
132                {&read_all_air_elements}
133                // _ [[air_elem]] (*air_constraints - 1)
134
135                push 1 add
136                // _ [[air_elem]] *air_constraints
137
138                {&mul_and_write}
139                // _ (*air_elem_last + 3)
140
141                {&jump_to_first_air_element}
142
143                return
144        )
145    }
146}
147
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use itertools::Itertools;
    use rand::prelude::*;
    use twenty_first::math::traits::ModPowU32;
    use twenty_first::math::traits::PrimitiveRootOfUnity;

    use super::*;
    use crate::empty_stack;
    use crate::execute_test;
    use crate::test_prelude::*;

    #[macro_rules_attr::apply(test)]
    fn divide_out_zerofiers_test() {
        let snippet = DivideOutZerofiers;
        let mut seed: [u8; 32] = [0u8; 32];
        rand::rng().fill_bytes(&mut seed);
        snippet.test_equivalence_with_host_machine(seed);
    }

    impl DivideOutZerofiers {
        /// Runs the snippet and the host-machine reference implementation on the
        /// same pseudorandom inputs (derived from `seed`) and asserts that both
        /// produce identical quotient summands.
        fn test_equivalence_with_host_machine(&self, seed: [u8; 32]) {
            let mut rng = StdRng::from_seed(seed);
            let (air_evaluation_result, ood_point_curr_row, padded_height, trace_domain_generator) =
                Self::random_input_values(&mut rng);

            let rust_result = Self::rust_result(
                air_evaluation_result,
                ood_point_curr_row,
                padded_height,
                trace_domain_generator,
            );

            let tasm_result = self.tasm_result(
                air_evaluation_result,
                ood_point_curr_row,
                padded_height,
                trace_domain_generator,
            );

            assert_eq!(tasm_result.len(), rust_result.len());
            // Comparing the sums first gives a compact failure message before the
            // full element-wise comparison below.
            assert_eq!(
                tasm_result.iter().copied().sum::<XFieldElement>(),
                rust_result.iter().copied().sum::<XFieldElement>(),
                "\ntasm: [{},...]\nrust: [{},...]",
                tasm_result.iter().take(3).join(","),
                rust_result.iter().take(3).join(",")
            );
            assert_eq!(tasm_result, rust_result);
        }

        /// Samples random AIR evaluations and a random out-of-domain point, a padded
        /// height that is a power of two in `2^8..=2^31`, and a primitive root of
        /// unity of that order as the trace domain generator.
        pub(super) fn random_input_values(
            rng: &mut StdRng,
        ) -> (
            [XFieldElement; MasterAuxTable::NUM_CONSTRAINTS],
            XFieldElement,
            u32,
            BFieldElement,
        ) {
            let air_evaluation_result =
                rng.random::<[XFieldElement; MasterAuxTable::NUM_CONSTRAINTS]>();
            let ood_point_curr_row: XFieldElement = rng.random();
            let padded_height = 2u32.pow(rng.random_range(8..32));
            let trace_domain_generator =
                BFieldElement::primitive_root_of_unity(padded_height as u64).unwrap();

            (
                air_evaluation_result,
                ood_point_curr_row,
                padded_height,
                trace_domain_generator,
            )
        }

        /// Runs the snippet on the given inputs and returns the array of quotient
        /// summands it leaves in memory.
        fn tasm_result(
            &self,
            air_evaluation_result: [XFieldElement; MasterAuxTable::NUM_CONSTRAINTS],
            out_of_domain_point_curr_row: XFieldElement,
            padded_height: u32,
            trace_domain_generator: BFieldElement,
        ) -> Vec<XFieldElement> {
            let free_page_pointer = BFieldElement::new(((1u64 << 32) - 3) * (1 << 32));
            let mut memory = HashMap::<BFieldElement, BFieldElement>::new();
            println!(
                "air evaluation result encoded: [{}, ...]",
                air_evaluation_result.encode().iter().take(4).join(",")
            );
            for (i, e) in air_evaluation_result.encode().into_iter().enumerate() {
                memory.insert(free_page_pointer + BFieldElement::new(i as u64), e);
            }

            // The extension-field element is pushed highest coefficient first so
            // that its constant term ends up closest to the top of the stack.
            let mut stack = [
                empty_stack(),
                vec![free_page_pointer],
                out_of_domain_point_curr_row
                    .coefficients
                    .into_iter()
                    .rev()
                    .collect_vec(),
                vec![
                    BFieldElement::new(padded_height as u64),
                    trace_domain_generator,
                ],
            ]
            .concat();
            let final_state = execute_test(
                &self.link_for_isolated_run(),
                &mut stack,
                self.stack_diff(),
                vec![],
                NonDeterminism::default().with_ram(memory),
                None,
            )
            .unwrap();

            // read the array pointed to by the pointer living on top of the stack
            AirConstraintEvaluation::read_result_from_memory(final_state).0
        }

        /// Host-machine reference: multiplies every AIR evaluation by the inverse of
        /// its constraint type's zerofier, evaluated at the out-of-domain point `x`:
        /// - initial: `1 / (x - 1)`
        /// - consistency: `1 / (x^padded_height - 1)`
        /// - transition: `(x - g^-1) / (x^padded_height - 1)`
        /// - terminal: `1 / (x - g^-1)`
        ///
        /// where `g` is the trace domain generator. The evaluations are expected in
        /// the order initial, consistency, transition, terminal.
        pub fn rust_result(
            air_evaluation_result: [XFieldElement; MasterAuxTable::NUM_CONSTRAINTS],
            out_of_domain_point_curr_row: XFieldElement,
            padded_height: u32,
            trace_domain_generator: BFieldElement,
        ) -> Vec<XFieldElement> {
            println!("trace domain generator: {trace_domain_generator}");
            println!("padded height: {padded_height}");
            println!("out-of-domain point current row: {out_of_domain_point_curr_row}");
            let initial_zerofier_inv = (out_of_domain_point_curr_row - bfe!(1)).inverse();
            let consistency_zerofier_inv =
                (out_of_domain_point_curr_row.mod_pow_u32(padded_height) - bfe!(1)).inverse();
            let except_last_row = out_of_domain_point_curr_row - trace_domain_generator.inverse();
            let transition_zerofier_inv = except_last_row * consistency_zerofier_inv;
            let terminal_zerofier_inv = except_last_row.inverse(); // i.e., only last row

            println!("initial zerofier inverse: {initial_zerofier_inv}");
            println!("consistency zerofier inverse: {consistency_zerofier_inv}");
            println!("transition zerofier inverse: {transition_zerofier_inv}");
            println!("terminal zerofier inverse: {terminal_zerofier_inv}");

            // (number of constraints, zerofier inverse), in memory-layout order.
            let sections = [
                (MasterAuxTable::NUM_INITIAL_CONSTRAINTS, initial_zerofier_inv),
                (MasterAuxTable::NUM_CONSISTENCY_CONSTRAINTS, consistency_zerofier_inv),
                (MasterAuxTable::NUM_TRANSITION_CONSTRAINTS, transition_zerofier_inv),
                (MasterAuxTable::NUM_TERMINAL_CONSTRAINTS, terminal_zerofier_inv),
            ];

            let mut remaining = air_evaluation_result.as_slice();
            let mut quotients = Vec::with_capacity(MasterAuxTable::NUM_CONSTRAINTS);
            for (num_constraints, zerofier_inv) in sections {
                let (section, rest) = remaining.split_at(num_constraints);
                quotients.extend(section.iter().map(|&air_elem| air_elem * zerofier_inv));
                remaining = rest;
            }
            assert!(
                remaining.is_empty(),
                "per-type constraint counts must add up to NUM_CONSTRAINTS"
            );

            quotients
        }
    }
}
329
#[cfg(test)]
mod bench {
    use std::collections::HashMap;

    use itertools::Itertools;
    use twenty_first::math::traits::PrimitiveRootOfUnity;

    use super::*;
    use crate::empty_stack;
    use crate::test_prelude::*;

    #[macro_rules_attr::apply(test)]
    fn bench_divide_out_zerofiers() {
        ShadowedFunction::new(DivideOutZerofiers).bench();
    }

    impl Function for DivideOutZerofiers {
        fn rust_shadow(
            &self,
            _: &mut Vec<BFieldElement>,
            _: &mut HashMap<BFieldElement, BFieldElement>,
        ) -> Result<(), RustShadowError> {
            // Never called as we do a more manual test.
            // The more manual test is done bc we don't want to
            // have to simulate all the intermediate calculations
            // that are stored to memory.
            unimplemented!()
        }

        /// Builds a benchmark state: random AIR evaluations stored at a fixed
        /// memory address, a random out-of-domain point, padded height `2^20`,
        /// and the matching trace domain generator on the stack.
        fn pseudorandom_initial_state(
            &self,
            seed: [u8; 32],
            _: Option<BenchmarkCase>,
        ) -> FunctionInitialState {
            let mut rng = StdRng::from_seed(seed);
            let air_evaluation_result =
                rng.random::<[XFieldElement; MasterAuxTable::NUM_CONSTRAINTS]>();
            let ood_point_current_row = rng.random::<XFieldElement>();
            let padded_height = 1 << 20;
            let trace_domain_generator =
                BFieldElement::primitive_root_of_unity(padded_height).unwrap();

            // Write the complete encoding, exactly as the equivalence test in
            // `tests` does: the snippet expects the first coefficient of the first
            // element at `free_page_pointer`. Skipping any word would shift every
            // coefficient by one address and leave the array's last word unset.
            let free_page_pointer = BFieldElement::new(((1u64 << 32) - 3) * (1 << 32));
            let memory: HashMap<BFieldElement, BFieldElement> = air_evaluation_result
                .encode()
                .into_iter()
                .enumerate()
                .map(|(i, e)| (free_page_pointer + BFieldElement::new(i as u64), e))
                .collect();

            // The extension-field element is pushed highest coefficient first.
            let stack = [
                empty_stack(),
                vec![free_page_pointer],
                ood_point_current_row
                    .coefficients
                    .into_iter()
                    .rev()
                    .collect_vec(),
                vec![BFieldElement::new(padded_height), trace_domain_generator],
            ]
            .concat();

            FunctionInitialState { stack, memory }
        }
    }
}