tasm_lib/verifier/
out_of_domain_points.rs

1use triton_vm::prelude::*;
2use twenty_first::math::x_field_element::EXTENSION_DEGREE;
3
4use crate::data_type::ArrayType;
5use crate::prelude::*;
6
/// Snippet that derives the out-of-domain points from the trace domain generator and the
/// out-of-domain point of the current row.
///
/// The resulting [`NUM_OF_OUT_OF_DOMAIN_POINTS`] extension-field elements are written to a
/// statically allocated array, laid out in the order given by [`OodPoint`]. The snippet leaves
/// a pointer to that array on the stack.
#[derive(Clone, Copy, Debug)]
pub struct OutOfDomainPoints;
11
/// Number of extension-field elements in the array produced by [`OutOfDomainPoints`]; used both
/// as the length of the output array type and to size the static allocation backing it.
pub const NUM_OF_OUT_OF_DOMAIN_POINTS: usize = 3;
13
/// Selects one element of the array produced by [`OutOfDomainPoints`].
///
/// The discriminant of each variant is the element's index in that array:
/// [`OutOfDomainPoints::read_ood_point`] converts the variant with `as usize` to compute the
/// memory offset, so the values below must match the order in which the snippet writes the
/// points.
#[derive(Clone, Copy, Debug)]
pub enum OodPoint {
    /// The out-of-domain point of the current row (array index 0).
    CurrentRow = 0,

    /// The out-of-domain point of the next row (array index 1).
    NextRow = 1,

    /// The current row's out-of-domain point raised to the number of segments (array index 2).
    CurrentRowPowNumSegments = 2,
}
20
21impl OutOfDomainPoints {
22    /// Push the requested OOD point to the stack, pop the pointer.
23    pub fn read_ood_point(ood_point_type: OodPoint) -> Vec<LabelledInstruction> {
24        let address_offset = (ood_point_type as usize) * EXTENSION_DEGREE + (EXTENSION_DEGREE - 1);
25        triton_asm!(
26            // _ *ood_points // of type same as the output value of this snippet
27
28            push {address_offset}
29            add
30            // _ (*ood_points[n] + 2)
31
32            read_mem {EXTENSION_DEGREE}
33            // _ [ood_point] (*ood_points[n] - 1)
34
35            pop 1
36        )
37    }
38}
39
40impl BasicSnippet for OutOfDomainPoints {
41    fn inputs(&self) -> Vec<(DataType, String)> {
42        vec![
43            (DataType::Bfe, "trace_domain_generator".to_owned()),
44            (DataType::Xfe, "out_of_domain_curr_row".to_owned()),
45        ]
46    }
47
48    fn outputs(&self) -> Vec<(DataType, String)> {
49        vec![(
50            DataType::Array(Box::new(ArrayType {
51                element_type: DataType::Xfe,
52                length: NUM_OF_OUT_OF_DOMAIN_POINTS,
53            })),
54            "out_of_domain_points".to_owned(),
55        )]
56    }
57
58    fn entrypoint(&self) -> String {
59        "tasmlib_verifier_out_of_domain_points".to_owned()
60    }
61
62    fn code(&self, library: &mut Library) -> Vec<LabelledInstruction> {
63        let entrypoint = self.entrypoint();
64
65        // Snippet for sampling *one* scalar, and holding the values:
66        // - `out_of_domain_point_curr_row`
67        // - `out_of_domain_point_next_row`
68        // - `out_of_domain_point_curr_row_pow_num_segments`
69        let num_words_for_out_of_domain_points = (NUM_OF_OUT_OF_DOMAIN_POINTS * EXTENSION_DEGREE)
70            .try_into()
71            .unwrap();
72        let ood_points_alloc = library.kmalloc(num_words_for_out_of_domain_points);
73
74        triton_asm!(
75            {entrypoint}:
76                // _ trace_domain_generator [ood_curr_row]
77
78                dup 2
79                dup 2
80                dup 2
81                dup 2
82                dup 2
83                dup 2
84                push {ood_points_alloc.write_address()}
85                write_mem {EXTENSION_DEGREE}
86                // _ trace_domain_generator [ood_curr_row] [ood_curr_row] *ood_points[1]
87
88                swap 7
89                // _ *ood_points[1] [ood_curr_row] [ood_curr_row] trace_domain_generator
90
91                xb_mul
92                // _ *ood_points[1] [ood_curr_row] [ood_next_row]
93
94                dup 6
95                write_mem {EXTENSION_DEGREE}
96                // _ *ood_points[1] [ood_curr_row] *ood_points[2]
97
98                swap 4
99                pop 1
100                // _ *ood_points[2] [ood_curr_row]
101
102                dup 2 dup 2 dup 2
103                xx_mul
104                dup 2 dup 2 dup 2
105                xx_mul
106                // _ *ood_points[2] [ood_curr_row**4]
107
108                swap 1
109                swap 2
110                swap 3
111                // _ [ood_curr_row**4] *ood_points[2]
112
113                write_mem {EXTENSION_DEGREE}
114                // _ *ood_points[3]
115
116                push {-(3 * EXTENSION_DEGREE as i32)}
117                add
118                // _ *ood_points
119
120                return
121        )
122    }
123}
124
#[cfg(test)]
mod tests {
    use triton_vm::table::NUM_QUOTIENT_SEGMENTS;
    use twenty_first::math::traits::ModPowU32;
    use twenty_first::math::traits::PrimitiveRootOfUnity;

    use super::*;
    use crate::rust_shadowing_helper_functions::array::insert_as_array;
    use crate::test_prelude::*;

    #[test]
    fn ood_points_pbt() {
        ShadowedFunction::new(OutOfDomainPoints).test();
    }

    impl Function for OutOfDomainPoints {
        /// Rust reference implementation of the snippet: pops the current-row point and the
        /// trace domain generator, stores the three out-of-domain points as an array in
        /// `memory`, and pushes the pointer to that array.
        fn rust_shadow(
            &self,
            stack: &mut Vec<BFieldElement>,
            memory: &mut HashMap<BFieldElement, BFieldElement>,
        ) {
            // The first word popped becomes the first coefficient, matching the order in
            // which `pseudorandom_initial_state` pushes the coefficients.
            let ood_curr_row = XFieldElement::new([
                stack.pop().unwrap(),
                stack.pop().unwrap(),
                stack.pop().unwrap(),
            ]);
            let domain_generator = stack.pop().unwrap();

            let ood_next_row = ood_curr_row * domain_generator;
            let ood_curr_row_pow_num_segments =
                ood_curr_row.mod_pow_u32(NUM_QUOTIENT_SEGMENTS.try_into().unwrap());

            // NOTE(review): this hard-codes the address expected from the snippet's
            // `kmalloc` call. Presumably that only holds while the snippet's allocation is
            // the first static allocation of the program -- confirm against `Library`.
            let static_malloc_size: i32 = (EXTENSION_DEGREE * NUM_OF_OUT_OF_DOMAIN_POINTS)
                .try_into()
                .unwrap();
            let ood_points_pointer = bfe!(-static_malloc_size - 1);
            insert_as_array(
                ood_points_pointer,
                memory,
                vec![ood_curr_row, ood_next_row, ood_curr_row_pow_num_segments],
            );

            stack.push(ood_points_pointer)
        }

        /// Build an initial state whose random parts are derived exclusively from `seed`,
        /// so that any state can be reproduced from its seed.
        ///
        /// The benchmark cases fix the domain length; otherwise it is a random power of two
        /// between 2^8 and 2^32. The current-row point is always drawn from the seeded RNG.
        fn pseudorandom_initial_state(
            &self,
            seed: [u8; 32],
            bench_case: Option<BenchmarkCase>,
        ) -> FunctionInitialState {
            let mut rng = StdRng::from_seed(seed);

            // The domain length is drawn before the point so that, for a given seed, it is
            // the first value taken from the RNG.
            let domain_length = match bench_case {
                Some(BenchmarkCase::CommonCase) => 1u64 << 20,
                Some(BenchmarkCase::WorstCase) => 1u64 << 24,
                None => 1u64 << rng.random_range(8..=32),
            };
            println!("domain_length: {domain_length}");

            let domain_generator = BFieldElement::primitive_root_of_unity(domain_length).unwrap();
            let ood_curr_row: XFieldElement = rng.random();

            FunctionInitialState {
                stack: [
                    self.init_stack_for_isolated_run(),
                    vec![
                        domain_generator,
                        ood_curr_row.coefficients[2],
                        ood_curr_row.coefficients[1],
                        ood_curr_row.coefficients[0],
                    ],
                ]
                .concat(),
                memory: HashMap::default(),
            }
        }
    }
}
202
#[cfg(test)]
mod benches {
    use super::*;
    use crate::test_prelude::*;

    /// Invoke `bench()` on the shadowed [`OutOfDomainPoints`] snippet.
    #[test]
    fn benchmark() {
        let shadowed_snippet = ShadowedFunction::new(OutOfDomainPoints);
        shadowed_snippet.bench();
    }
}