sp1_recursion_program/fri/
domain.rs1use p3_commit::{LagrangeSelectors, TwoAdicMultiplicativeCoset};
2use p3_field::{AbstractField, TwoAdicField};
3use sp1_recursion_compiler::prelude::*;
4
5use super::types::FriConfigVariable;
6use crate::commit::PolynomialSpaceVariable;
7
8#[derive(DslVariable, Clone, Copy)]
10pub struct TwoAdicMultiplicativeCosetVariable<C: Config> {
11 pub log_n: Var<C::N>,
12 pub size: Var<C::N>,
13 pub shift: Felt<C::F>,
14 pub g: Felt<C::F>,
15}
16
17impl<C: Config> TwoAdicMultiplicativeCosetVariable<C> {
18 pub const fn size(&self) -> Var<C::N> {
19 self.size
20 }
21
22 pub const fn first_point(&self) -> Felt<C::F> {
23 self.shift
24 }
25
26 pub const fn gen(&self) -> Felt<C::F> {
27 self.g
28 }
29}
30
31impl<C: Config> FromConstant<C> for TwoAdicMultiplicativeCosetVariable<C>
32where
33 C::F: TwoAdicField,
34{
35 type Constant = TwoAdicMultiplicativeCoset<C::F>;
36
37 fn constant(value: Self::Constant, builder: &mut Builder<C>) -> Self {
38 let log_d_val = value.log_n as u32;
39 let g_val = C::F::two_adic_generator(value.log_n);
40 TwoAdicMultiplicativeCosetVariable::<C> {
41 log_n: builder.eval::<Var<_>, _>(C::N::from_canonical_u32(log_d_val)),
42 size: builder.eval::<Var<_>, _>(C::N::from_canonical_u32(1 << (log_d_val))),
43 shift: builder.eval(value.shift),
44 g: builder.eval(g_val),
45 }
46 }
47}
48
49impl<C: Config> PolynomialSpaceVariable<C> for TwoAdicMultiplicativeCosetVariable<C>
50where
51 C::F: TwoAdicField,
52{
53 type Constant = p3_commit::TwoAdicMultiplicativeCoset<C::F>;
54
55 fn next_point(
56 &self,
57 builder: &mut Builder<C>,
58 point: Ext<<C as Config>::F, <C as Config>::EF>,
59 ) -> Ext<<C as Config>::F, <C as Config>::EF> {
60 builder.eval(point * self.gen())
61 }
62
63 fn selectors_at_point(
64 &self,
65 builder: &mut Builder<C>,
66 point: Ext<<C as Config>::F, <C as Config>::EF>,
67 ) -> LagrangeSelectors<Ext<<C as Config>::F, <C as Config>::EF>> {
68 let unshifted_point: Ext<_, _> = builder.eval(point * self.shift.inverse());
69 let z_h_expr = builder
70 .exp_power_of_2_v::<Ext<_, _>>(unshifted_point, Usize::Var(self.log_n))
71 - C::EF::one();
72 let z_h: Ext<_, _> = builder.eval(z_h_expr);
73
74 LagrangeSelectors {
75 is_first_row: builder.eval(z_h / (unshifted_point - C::EF::one())),
76 is_last_row: builder.eval(z_h / (unshifted_point - self.gen().inverse())),
77 is_transition: builder.eval(unshifted_point - self.gen().inverse()),
78 inv_zeroifier: builder.eval(z_h.inverse()),
79 }
80 }
81
82 fn zp_at_point(
83 &self,
84 builder: &mut Builder<C>,
85 point: Ext<<C as Config>::F, <C as Config>::EF>,
86 ) -> Ext<<C as Config>::F, <C as Config>::EF> {
87 let unshifted_power = builder
88 .exp_power_of_2_v::<Ext<_, _>>(point * self.shift.inverse(), Usize::Var(self.log_n));
89 builder.eval(unshifted_power - C::EF::one())
90 }
91
92 fn split_domains(
93 &self,
94 builder: &mut Builder<C>,
95 log_num_chunks: impl Into<Usize<C::N>>,
96 num_chunks: impl Into<Usize<C::N>>,
97 ) -> Array<C, Self> {
98 let log_num_chunks = log_num_chunks.into();
99 let num_chunks = num_chunks.into();
100 let log_n: Var<_> = builder.eval(self.log_n - log_num_chunks);
101 let size = builder.sll(C::N::one(), Usize::Var(log_n));
102
103 let g_dom = self.gen();
104 let g = builder.exp_power_of_2_v::<Felt<C::F>>(g_dom, log_num_chunks);
105
106 let domain_power: Felt<_> = builder.eval(C::F::one());
107
108 let mut domains = builder.dyn_array(num_chunks);
109
110 builder.range(0, num_chunks).for_each(|i, builder| {
111 let domain = TwoAdicMultiplicativeCosetVariable {
112 log_n,
113 size,
114 shift: builder.eval(self.shift * domain_power),
115 g,
116 };
117 builder.set(&mut domains, i, domain);
118 builder.assign(domain_power, domain_power * g_dom);
119 });
120
121 domains
122 }
123
124 fn split_domains_const(&self, builder: &mut Builder<C>, log_num_chunks: usize) -> Vec<Self> {
125 let num_chunks = 1 << log_num_chunks;
126 let log_n: Var<_> = builder.eval(self.log_n - C::N::from_canonical_usize(log_num_chunks));
127 let size = builder.sll(C::N::one(), Usize::Var(log_n));
128
129 let g_dom = self.gen();
130 let g = builder.exp_power_of_2_v::<Felt<C::F>>(g_dom, log_num_chunks);
131
132 let domain_power: Felt<_> = builder.eval(C::F::one());
133 let mut domains = vec![];
134
135 for _ in 0..num_chunks {
136 domains.push(TwoAdicMultiplicativeCosetVariable {
137 log_n,
138 size,
139 shift: builder.eval(self.shift * domain_power),
140 g,
141 });
142 builder.assign(domain_power, domain_power * g_dom);
143 }
144 domains
145 }
146
147 fn create_disjoint_domain(
148 &self,
149 builder: &mut Builder<C>,
150 log_degree: Usize<<C as Config>::N>,
151 config: Option<FriConfigVariable<C>>,
152 ) -> Self {
153 let domain = config.unwrap().get_subgroup(builder, log_degree);
154 builder.assign(domain.shift, self.shift * C::F::generator());
155 domain
156 }
157}
158
#[cfg(test)]
pub(crate) mod tests {

    use sp1_recursion_compiler::asm::AsmBuilder;
    use sp1_recursion_core::stark::utils::{run_test_recursion, TestConfig};
    use sp1_stark::{
        baby_bear_poseidon2::BabyBearPoseidon2, inner_fri_config, Dom, StarkGenericConfig,
    };

    use crate::utils::const_fri_config;

    use super::*;
    use p3_commit::{Pcs, PolynomialSpace};
    use rand::{thread_rng, Rng};

    /// Emits circuit assertions that `domain` matches the native coset `domain_val`.
    ///
    /// The checks cover `log_n`, `size`, the shift, all four Lagrange selectors at `zeta_val`,
    /// and the vanishing polynomial at `zeta_val`.
    pub(crate) fn domain_assertions<F: TwoAdicField, C: Config<N = F, F = F>>(
        builder: &mut Builder<C>,
        domain: &TwoAdicMultiplicativeCosetVariable<C>,
        domain_val: &TwoAdicMultiplicativeCoset<F>,
        zeta_val: C::EF,
    ) {
        // Coset parameters.
        builder.assert_var_eq(domain.log_n, F::from_canonical_usize(domain_val.log_n));
        builder.assert_var_eq(domain.size, F::from_canonical_usize(1 << domain_val.log_n));
        builder.assert_felt_eq(domain.shift, domain_val.shift);

        let zeta: Ext<_, _> = builder.eval(zeta_val.cons());

        // Lagrange selectors: compare every field, including the inverse zerofier.
        let sels_expected = domain_val.selectors_at_point(zeta_val);
        let sels = domain.selectors_at_point(builder, zeta);
        builder.assert_ext_eq(sels.is_first_row, sels_expected.is_first_row.cons());
        builder.assert_ext_eq(sels.is_last_row, sels_expected.is_last_row.cons());
        builder.assert_ext_eq(sels.is_transition, sels_expected.is_transition.cons());
        builder.assert_ext_eq(sels.inv_zeroifier, sels_expected.inv_zeroifier.cons());

        // Vanishing polynomial.
        let zp_val = domain_val.zp_at_point(zeta_val);
        let zp = domain.zp_at_point(builder, zeta);
        builder.assert_ext_eq(zp, zp_val.cons());
    }

    #[test]
    fn test_domain() {
        type SC = BabyBearPoseidon2;
        type F = <SC as StarkGenericConfig>::Val;
        type EF = <SC as StarkGenericConfig>::Challenge;
        type Challenger = <SC as StarkGenericConfig>::Challenger;
        type ScPcs = <SC as StarkGenericConfig>::Pcs;

        let mut rng = thread_rng();
        let config = SC::default();
        let pcs = config.pcs();
        let natural_domain_for_degree = |degree: usize| -> Dom<SC> {
            <ScPcs as Pcs<EF, Challenger>>::natural_domain_for_degree(pcs, degree)
        };

        // Initialize a builder.
        let mut builder = AsmBuilder::<F, EF>::default();

        let config_var = const_fri_config(&mut builder, &inner_fri_config());
        for i in 0..5 {
            let log_d_val = 10 + i;

            let log_quotient_degree = 2;

            // Reference domain, both native and embedded as a constant.
            let domain_val = natural_domain_for_degree(1 << log_d_val);
            let domain = builder.constant(domain_val);

            // Random evaluation point.
            let zeta_val = rng.gen::<EF>();
            domain_assertions(&mut builder, &domain, &domain_val, zeta_val);

            // Disjoint domain, embedded as a constant.
            let disjoint_domain_val =
                domain_val.create_disjoint_domain(1 << (log_d_val + log_quotient_degree));
            let disjoint_domain = builder.constant(disjoint_domain_val);
            domain_assertions(&mut builder, &disjoint_domain, &disjoint_domain_val, zeta_val);

            // Disjoint domain, computed in-circuit from the FRI config.
            let log_degree: Usize<_> = builder.eval(Usize::Const(log_d_val) + log_quotient_degree);
            let disjoint_domain_gen =
                domain.create_disjoint_domain(&mut builder, log_degree, Some(config_var.clone()));
            domain_assertions(&mut builder, &disjoint_domain_gen, &disjoint_domain_val, zeta_val);

            // Native quotient-chunk domains, embedded as constants.
            let qc_domains_val = disjoint_domain_val.split_domains(1 << log_quotient_degree);
            for dom_val in qc_domains_val.iter() {
                let dom = builder.constant(*dom_val);
                domain_assertions(&mut builder, &dom, dom_val, zeta_val);
            }

            // In-circuit split with a compile-time chunk count.
            let qc_domains_const =
                disjoint_domain.split_domains_const(&mut builder, log_quotient_degree);
            for (dom, dom_val) in qc_domains_const.iter().zip(qc_domains_val.iter()) {
                domain_assertions(&mut builder, dom, dom_val, zeta_val);
            }

            // In-circuit split with a dynamic chunk count.
            let quotient_size: Usize<_> = builder.eval(1 << log_quotient_degree);
            let log_quotient_degree: Usize<_> = builder.eval(log_quotient_degree);
            let qc_domains =
                disjoint_domain.split_domains(&mut builder, log_quotient_degree, quotient_size);
            for (i, dom_val) in qc_domains_val.iter().enumerate() {
                let dom = builder.get(&qc_domains, i);
                domain_assertions(&mut builder, &dom, dom_val, zeta_val);
            }
        }
        builder.halt();

        let program = builder.compile_program();
        run_test_recursion(program, None, TestConfig::All);
    }
}