Skip to main content

tasm_lib/verifier/
out_of_domain_points.rs

1use triton_vm::prelude::*;
2use twenty_first::math::x_field_element::EXTENSION_DEGREE;
3
4use crate::data_type::ArrayType;
5use crate::prelude::*;
6
/// Snippet that derives the out-of-domain points from the trace domain generator and the
/// out-of-domain point of the current row.
///
/// The [`NUM_OF_OUT_OF_DOMAIN_POINTS`] resulting extension field elements are written to a
/// statically allocated array, and a pointer to that array is left on the stack. Use
/// [`OutOfDomainPoints::read_ood_point`] to load one of them again.
#[derive(Clone, Copy, Debug)]
pub struct OutOfDomainPoints;
11
/// Number of extension field elements in the array produced by [`OutOfDomainPoints`]; one per
/// variant of [`OodPoint`].
pub const NUM_OF_OUT_OF_DOMAIN_POINTS: usize = 3;
13
/// Selects one element of the array produced by [`OutOfDomainPoints`].
///
/// The discriminant of each variant is the index of the corresponding element in that array;
/// [`OutOfDomainPoints::read_ood_point`] relies on this when computing memory offsets.
#[derive(Clone, Copy, Debug)]
pub enum OodPoint {
    /// The out-of-domain point of the current row, as supplied to the snippet.
    CurrentRow = 0,

    /// The current-row point multiplied by the trace domain generator.
    NextRow = 1,

    /// The current-row point raised to the power of the number of segments.
    CurrentRowPowNumSegments = 2,
}
20
21impl OutOfDomainPoints {
22    /// Push the requested OOD point to the stack, pop the pointer.
23    pub fn read_ood_point(ood_point_type: OodPoint) -> Vec<LabelledInstruction> {
24        let address_offset = (ood_point_type as usize) * EXTENSION_DEGREE + (EXTENSION_DEGREE - 1);
25        triton_asm!(
26            // _ *ood_points // of type same as the output value of this snippet
27
28            push {address_offset}
29            add
30            // _ (*ood_points[n] + 2)
31
32            read_mem {EXTENSION_DEGREE}
33            // _ [ood_point] (*ood_points[n] - 1)
34
35            pop 1
36        )
37    }
38}
39
40impl BasicSnippet for OutOfDomainPoints {
41    fn parameters(&self) -> Vec<(DataType, String)> {
42        vec![
43            (DataType::Bfe, "trace_domain_generator".to_owned()),
44            (DataType::Xfe, "out_of_domain_curr_row".to_owned()),
45        ]
46    }
47
48    fn return_values(&self) -> Vec<(DataType, String)> {
49        vec![(
50            DataType::Array(Box::new(ArrayType {
51                element_type: DataType::Xfe,
52                length: NUM_OF_OUT_OF_DOMAIN_POINTS,
53            })),
54            "out_of_domain_points".to_owned(),
55        )]
56    }
57
58    fn entrypoint(&self) -> String {
59        "tasmlib_verifier_out_of_domain_points".to_owned()
60    }
61
62    fn code(&self, library: &mut Library) -> Vec<LabelledInstruction> {
63        let entrypoint = self.entrypoint();
64
65        // Snippet for sampling *one* scalar, and holding the values:
66        // - `out_of_domain_point_curr_row`
67        // - `out_of_domain_point_next_row`
68        // - `out_of_domain_point_curr_row_pow_num_segments`
69        let num_words_for_out_of_domain_points = (NUM_OF_OUT_OF_DOMAIN_POINTS * EXTENSION_DEGREE)
70            .try_into()
71            .unwrap();
72        let ood_points_alloc = library.kmalloc(num_words_for_out_of_domain_points);
73
74        triton_asm!(
75            {entrypoint}:
76                // _ trace_domain_generator [ood_curr_row]
77
78                dup 2
79                dup 2
80                dup 2
81                dup 2
82                dup 2
83                dup 2
84                push {ood_points_alloc.write_address()}
85                write_mem {EXTENSION_DEGREE}
86                // _ trace_domain_generator [ood_curr_row] [ood_curr_row] *ood_points[1]
87
88                swap 7
89                // _ *ood_points[1] [ood_curr_row] [ood_curr_row] trace_domain_generator
90
91                xb_mul
92                // _ *ood_points[1] [ood_curr_row] [ood_next_row]
93
94                dup 6
95                write_mem {EXTENSION_DEGREE}
96                // _ *ood_points[1] [ood_curr_row] *ood_points[2]
97
98                swap 4
99                pop 1
100                // _ *ood_points[2] [ood_curr_row]
101
102                dup 2 dup 2 dup 2
103                xx_mul
104                dup 2 dup 2 dup 2
105                xx_mul
106                // _ *ood_points[2] [ood_curr_row**4]
107
108                swap 1
109                swap 2
110                swap 3
111                // _ [ood_curr_row**4] *ood_points[2]
112
113                write_mem {EXTENSION_DEGREE}
114                // _ *ood_points[3]
115
116                push {-(3 * EXTENSION_DEGREE as i32)}
117                add
118                // _ *ood_points
119
120                return
121        )
122    }
123}
124
#[cfg(test)]
mod tests {
    use triton_vm::table::NUM_QUOTIENT_SEGMENTS;
    use twenty_first::math::traits::ModPowU32;
    use twenty_first::math::traits::PrimitiveRootOfUnity;

    use super::*;
    use crate::rust_shadowing_helper_functions::array::insert_as_array;
    use crate::test_prelude::*;

    /// Compare the assembly of [`OutOfDomainPoints`] against its Rust shadow.
    #[macro_rules_attr::apply(test)]
    fn ood_points_pbt() {
        ShadowedFunction::new(OutOfDomainPoints).test();
    }

    impl Function for OutOfDomainPoints {
        /// Reference implementation of the snippet.
        ///
        /// Pops the out-of-domain point of the current row and the trace domain generator,
        /// writes the three derived points to `memory`, and pushes the pointer to them.
        ///
        /// # Errors
        ///
        /// Fails if the stack holds fewer than four elements or if one of the integer
        /// conversions fails.
        fn rust_shadow(
            &self,
            stack: &mut Vec<BFieldElement>,
            memory: &mut HashMap<BFieldElement, BFieldElement>,
        ) -> Result<(), RustShadowError> {
            // The coefficient with index 0 is on top of the stack, so it is popped first.
            let ood_curr_row = XFieldElement::new([
                stack.pop().ok_or(RustShadowError::StackUnderflow)?,
                stack.pop().ok_or(RustShadowError::StackUnderflow)?,
                stack.pop().ok_or(RustShadowError::StackUnderflow)?,
            ]);
            let domain_generator = stack.pop().ok_or(RustShadowError::StackUnderflow)?;

            let ood_next_row = ood_curr_row * domain_generator;
            let num_quotient_segments: u32 = NUM_QUOTIENT_SEGMENTS
                .try_into()
                .map_err(|_| RustShadowError::UsizeToU32Error)?;
            let ood_curr_row_pow_num_segments = ood_curr_row.mod_pow_u32(num_quotient_segments);

            // NOTE(review): presumably mirrors where `Library::kmalloc` places the snippet's
            // (first and only) static allocation — confirm against the allocator.
            let static_malloc_size: i32 = (EXTENSION_DEGREE * NUM_OF_OUT_OF_DOMAIN_POINTS)
                .try_into()
                .map_err(|_| RustShadowError::Other)?;
            let ood_points_pointer = bfe!(-static_malloc_size - 1);
            insert_as_array(
                ood_points_pointer,
                memory,
                vec![ood_curr_row, ood_next_row, ood_curr_row_pow_num_segments],
            );

            stack.push(ood_points_pointer);

            Ok(())
        }

        /// Produce an initial state that is fully determined by `seed` and `bench_case`.
        ///
        /// The benchmark cases fix the domain length; otherwise, it is a power of two sampled
        /// between 2^8 and 2^32. The out-of-domain point is always sampled from the seeded
        /// generator so that identical seeds reproduce identical states.
        fn pseudorandom_initial_state(
            &self,
            seed: [u8; 32],
            bench_case: Option<BenchmarkCase>,
        ) -> FunctionInitialState {
            // All randomness must come from this generator; using the thread-local generator
            // would make failures impossible to reproduce from the seed.
            let mut rng = StdRng::from_seed(seed);

            let domain_length = match bench_case {
                Some(BenchmarkCase::CommonCase) => 1u64 << 20,
                Some(BenchmarkCase::WorstCase) => 1u64 << 24,
                None => 1u64 << rng.random_range(8..=32),
            };
            println!("domain_length: {domain_length}");

            let domain_generator = BFieldElement::primitive_root_of_unity(domain_length).unwrap();
            let ood_curr_row: XFieldElement = rng.random();

            FunctionInitialState {
                stack: [
                    self.init_stack_for_isolated_run(),
                    // highest coefficient deepest, so that coefficient 0 ends up on top
                    vec![
                        domain_generator,
                        ood_curr_row.coefficients[2],
                        ood_curr_row.coefficients[1],
                        ood_curr_row.coefficients[0],
                    ],
                ]
                .concat(),
                memory: HashMap::default(),
            }
        }
    }
}
206
#[cfg(test)]
mod benches {
    use super::*;
    use crate::test_prelude::*;

    /// Run the benchmark of [`OutOfDomainPoints`] through the shadowing framework.
    #[macro_rules_attr::apply(test)]
    fn benchmark() {
        let shadowed_snippet = ShadowedFunction::new(OutOfDomainPoints);
        shadowed_snippet.bench();
    }
}