// sp1_recursion_program/fri/domain.rs

use p3_commit::{LagrangeSelectors, TwoAdicMultiplicativeCoset};
use p3_field::{AbstractField, TwoAdicField};
use sp1_recursion_compiler::prelude::*;

use super::types::FriConfigVariable;
use crate::commit::PolynomialSpaceVariable;

8/// Reference: [p3_commit::TwoAdicMultiplicativeCoset]
9#[derive(DslVariable, Clone, Copy)]
10pub struct TwoAdicMultiplicativeCosetVariable<C: Config> {
11    pub log_n: Var<C::N>,
12    pub size: Var<C::N>,
13    pub shift: Felt<C::F>,
14    pub g: Felt<C::F>,
15}
16
17impl<C: Config> TwoAdicMultiplicativeCosetVariable<C> {
18    pub const fn size(&self) -> Var<C::N> {
19        self.size
20    }
21
22    pub const fn first_point(&self) -> Felt<C::F> {
23        self.shift
24    }
25
26    pub const fn gen(&self) -> Felt<C::F> {
27        self.g
28    }
29}
30
31impl<C: Config> FromConstant<C> for TwoAdicMultiplicativeCosetVariable<C>
32where
33    C::F: TwoAdicField,
34{
35    type Constant = TwoAdicMultiplicativeCoset<C::F>;
36
37    fn constant(value: Self::Constant, builder: &mut Builder<C>) -> Self {
38        let log_d_val = value.log_n as u32;
39        let g_val = C::F::two_adic_generator(value.log_n);
40        TwoAdicMultiplicativeCosetVariable::<C> {
41            log_n: builder.eval::<Var<_>, _>(C::N::from_canonical_u32(log_d_val)),
42            size: builder.eval::<Var<_>, _>(C::N::from_canonical_u32(1 << (log_d_val))),
43            shift: builder.eval(value.shift),
44            g: builder.eval(g_val),
45        }
46    }
47}
48
49impl<C: Config> PolynomialSpaceVariable<C> for TwoAdicMultiplicativeCosetVariable<C>
50where
51    C::F: TwoAdicField,
52{
53    type Constant = p3_commit::TwoAdicMultiplicativeCoset<C::F>;
54
55    fn next_point(
56        &self,
57        builder: &mut Builder<C>,
58        point: Ext<<C as Config>::F, <C as Config>::EF>,
59    ) -> Ext<<C as Config>::F, <C as Config>::EF> {
60        builder.eval(point * self.gen())
61    }
62
63    fn selectors_at_point(
64        &self,
65        builder: &mut Builder<C>,
66        point: Ext<<C as Config>::F, <C as Config>::EF>,
67    ) -> LagrangeSelectors<Ext<<C as Config>::F, <C as Config>::EF>> {
68        let unshifted_point: Ext<_, _> = builder.eval(point * self.shift.inverse());
69        let z_h_expr = builder
70            .exp_power_of_2_v::<Ext<_, _>>(unshifted_point, Usize::Var(self.log_n))
71            - C::EF::one();
72        let z_h: Ext<_, _> = builder.eval(z_h_expr);
73
74        LagrangeSelectors {
75            is_first_row: builder.eval(z_h / (unshifted_point - C::EF::one())),
76            is_last_row: builder.eval(z_h / (unshifted_point - self.gen().inverse())),
77            is_transition: builder.eval(unshifted_point - self.gen().inverse()),
78            inv_zeroifier: builder.eval(z_h.inverse()),
79        }
80    }
81
82    fn zp_at_point(
83        &self,
84        builder: &mut Builder<C>,
85        point: Ext<<C as Config>::F, <C as Config>::EF>,
86    ) -> Ext<<C as Config>::F, <C as Config>::EF> {
87        let unshifted_power = builder
88            .exp_power_of_2_v::<Ext<_, _>>(point * self.shift.inverse(), Usize::Var(self.log_n));
89        builder.eval(unshifted_power - C::EF::one())
90    }
91
92    fn split_domains(
93        &self,
94        builder: &mut Builder<C>,
95        log_num_chunks: impl Into<Usize<C::N>>,
96        num_chunks: impl Into<Usize<C::N>>,
97    ) -> Array<C, Self> {
98        let log_num_chunks = log_num_chunks.into();
99        let num_chunks = num_chunks.into();
100        let log_n: Var<_> = builder.eval(self.log_n - log_num_chunks);
101        let size = builder.sll(C::N::one(), Usize::Var(log_n));
102
103        let g_dom = self.gen();
104        let g = builder.exp_power_of_2_v::<Felt<C::F>>(g_dom, log_num_chunks);
105
106        let domain_power: Felt<_> = builder.eval(C::F::one());
107
108        let mut domains = builder.dyn_array(num_chunks);
109
110        builder.range(0, num_chunks).for_each(|i, builder| {
111            let domain = TwoAdicMultiplicativeCosetVariable {
112                log_n,
113                size,
114                shift: builder.eval(self.shift * domain_power),
115                g,
116            };
117            builder.set(&mut domains, i, domain);
118            builder.assign(domain_power, domain_power * g_dom);
119        });
120
121        domains
122    }
123
124    fn split_domains_const(&self, builder: &mut Builder<C>, log_num_chunks: usize) -> Vec<Self> {
125        let num_chunks = 1 << log_num_chunks;
126        let log_n: Var<_> = builder.eval(self.log_n - C::N::from_canonical_usize(log_num_chunks));
127        let size = builder.sll(C::N::one(), Usize::Var(log_n));
128
129        let g_dom = self.gen();
130        let g = builder.exp_power_of_2_v::<Felt<C::F>>(g_dom, log_num_chunks);
131
132        let domain_power: Felt<_> = builder.eval(C::F::one());
133        let mut domains = vec![];
134
135        for _ in 0..num_chunks {
136            domains.push(TwoAdicMultiplicativeCosetVariable {
137                log_n,
138                size,
139                shift: builder.eval(self.shift * domain_power),
140                g,
141            });
142            builder.assign(domain_power, domain_power * g_dom);
143        }
144        domains
145    }
146
147    fn create_disjoint_domain(
148        &self,
149        builder: &mut Builder<C>,
150        log_degree: Usize<<C as Config>::N>,
151        config: Option<FriConfigVariable<C>>,
152    ) -> Self {
153        let domain = config.unwrap().get_subgroup(builder, log_degree);
154        builder.assign(domain.shift, self.shift * C::F::generator());
155        domain
156    }
157}
158
#[cfg(test)]
pub(crate) mod tests {

    use sp1_recursion_compiler::asm::AsmBuilder;
    use sp1_recursion_core::stark::utils::{run_test_recursion, TestConfig};
    use sp1_stark::{
        baby_bear_poseidon2::BabyBearPoseidon2, inner_fri_config, Dom, StarkGenericConfig,
    };

    use crate::utils::const_fri_config;

    use super::*;
    use p3_commit::{Pcs, PolynomialSpace};
    use rand::{thread_rng, Rng};

    /// Emits builder assertions that `domain` agrees with the native reference `domain_val`.
    ///
    /// Checks the coset parameters (`log_n`, `size`, `shift`, `g`). Also checks every
    /// Lagrange selector and the vanishing polynomial at `zeta_val` against the values
    /// that `p3_commit` computes natively.
    pub(crate) fn domain_assertions<F: TwoAdicField, C: Config<N = F, F = F>>(
        builder: &mut Builder<C>,
        domain: &TwoAdicMultiplicativeCosetVariable<C>,
        domain_val: &TwoAdicMultiplicativeCoset<F>,
        zeta_val: C::EF,
    ) {
        // Assert the domain parameters are the same.
        builder.assert_var_eq(domain.log_n, F::from_canonical_usize(domain_val.log_n));
        builder.assert_var_eq(domain.size, F::from_canonical_usize(1 << domain_val.log_n));
        builder.assert_felt_eq(domain.shift, domain_val.shift);
        builder.assert_felt_eq(domain.g, F::two_adic_generator(domain_val.log_n));

        // Lift the random evaluation point into the builder.
        let zeta: Ext<_, _> = builder.eval(zeta_val.cons());

        // Compare every selector value of the reference and the builder.
        let sels_expected = domain_val.selectors_at_point(zeta_val);
        let sels = domain.selectors_at_point(builder, zeta);
        builder.assert_ext_eq(sels.is_first_row, sels_expected.is_first_row.cons());
        builder.assert_ext_eq(sels.is_last_row, sels_expected.is_last_row.cons());
        builder.assert_ext_eq(sels.is_transition, sels_expected.is_transition.cons());
        builder.assert_ext_eq(sels.inv_zeroifier, sels_expected.inv_zeroifier.cons());

        let zp_val = domain_val.zp_at_point(zeta_val);
        let zp = domain.zp_at_point(builder, zeta);
        builder.assert_ext_eq(zp, zp_val.cons());
    }

    #[test]
    fn test_domain() {
        type SC = BabyBearPoseidon2;
        type F = <SC as StarkGenericConfig>::Val;
        type EF = <SC as StarkGenericConfig>::Challenge;
        type Challenger = <SC as StarkGenericConfig>::Challenger;
        type ScPcs = <SC as StarkGenericConfig>::Pcs;

        let mut rng = thread_rng();
        let config = SC::default();
        let pcs = config.pcs();
        let natural_domain_for_degree = |degree: usize| -> Dom<SC> {
            <ScPcs as Pcs<EF, Challenger>>::natural_domain_for_degree(pcs, degree)
        };

        // Initialize a builder.
        let mut builder = AsmBuilder::<F, EF>::default();

        let config_var = const_fri_config(&mut builder, &inner_fri_config());
        for i in 0..5 {
            let log_d_val = 10 + i;

            let log_quotient_degree = 2;

            // Initialize a reference domain.
            let domain_val = natural_domain_for_degree(1 << log_d_val);
            let domain = builder.constant(domain_val);

            let zeta_val = rng.gen::<EF>();
            domain_assertions(&mut builder, &domain, &domain_val, zeta_val);

            // Try a shifted domain.
            let disjoint_domain_val =
                domain_val.create_disjoint_domain(1 << (log_d_val + log_quotient_degree));
            let disjoint_domain = builder.constant(disjoint_domain_val);
            domain_assertions(&mut builder, &disjoint_domain, &disjoint_domain_val, zeta_val);

            let log_degree: Usize<_> = builder.eval(Usize::Const(log_d_val) + log_quotient_degree);
            let disjoint_domain_gen =
                domain.create_disjoint_domain(&mut builder, log_degree, Some(config_var.clone()));
            domain_assertions(&mut builder, &disjoint_domain_gen, &disjoint_domain_val, zeta_val);

            // Now try split domains.
            let qc_domains_val = disjoint_domain_val.split_domains(1 << log_quotient_degree);
            for dom_val in qc_domains_val.iter() {
                let dom = builder.constant(*dom_val);
                domain_assertions(&mut builder, &dom, dom_val, zeta_val);
            }

            // Test the splitting of domains by the builder.
            let quotient_size: Usize<_> = builder.eval(1 << log_quotient_degree);
            let log_quotient_degree_var: Usize<_> = builder.eval(log_quotient_degree);
            let qc_domains =
                disjoint_domain.split_domains(&mut builder, log_quotient_degree_var, quotient_size);
            for (j, dom_val) in qc_domains_val.iter().enumerate() {
                let dom = builder.get(&qc_domains, j);
                domain_assertions(&mut builder, &dom, dom_val, zeta_val);
            }
        }
        builder.halt();

        let program = builder.compile_program();
        run_test_recursion(program, None, TestConfig::All);
    }
}