1use p3_field::AbstractField;
2use sp1_recursion_compiler::prelude::{
3 Array, Builder, Config, DslVariable, Ext, Felt, MemIndex, MemVariable, Ptr, Usize, Var,
4 Variable,
5};
6use sp1_recursion_core::runtime::{DIGEST_SIZE, HASH_RATE, PERMUTATION_WIDTH};
7
8use crate::{fri::types::DigestVariable, types::VerifyingKeyVariable};
9
10pub trait CanObserveVariable<C: Config, V> {
12 fn observe(&mut self, builder: &mut Builder<C>, value: V);
13
14 fn observe_slice(&mut self, builder: &mut Builder<C>, values: Array<C, V>);
15}
16
17pub trait CanSampleVariable<C: Config, V> {
18 fn sample(&mut self, builder: &mut Builder<C>) -> V;
19}
20
21pub trait FeltChallenger<C: Config>:
23 CanObserveVariable<C, Felt<C::F>> + CanSampleVariable<C, Felt<C::F>> + CanSampleBitsVariable<C>
24{
25 fn sample_ext(&mut self, builder: &mut Builder<C>) -> Ext<C::F, C::EF>;
26}
27
28pub trait CanSampleBitsVariable<C: Config> {
29 fn sample_bits(
30 &mut self,
31 builder: &mut Builder<C>,
32 nb_bits: Usize<C::N>,
33 ) -> Array<C, Var<C::N>>;
34}
35
/// In-circuit state of a duplex sponge challenger, held as recursion DSL variables.
///
/// The state is a permutation sponge plus two buffers, each with a counter:
/// - pending inputs that have been observed but not yet absorbed;
/// - squeezed outputs that have not yet been sampled.
///
/// NOTE(review): field order presumably determines the layout generated by the
/// `DslVariable` derive; keep it stable.
#[derive(Clone, DslVariable)]
pub struct DuplexChallengerVariable<C: Config> {
    /// Permutation state. `duplexing` overwrites its leading slots with the buffered inputs
    /// and then runs `poseidon2_permute_mut` on it.
    pub sponge_state: Array<C, Felt<C::F>>,
    /// Number of elements currently pending in `input_buffer`.
    pub nb_inputs: Var<C::N>,
    /// Observed elements awaiting absorption; `observe` writes each one at index `nb_inputs`.
    pub input_buffer: Array<C, Felt<C::F>>,
    /// Number of unconsumed elements in `output_buffer`; `sample` reads index `nb_outputs - 1`.
    pub nb_outputs: Var<C::N>,
    /// Elements copied out of `sponge_state` after each permutation.
    pub output_buffer: Array<C, Felt<C::F>>,
}
45
46impl<C: Config> DuplexChallengerVariable<C> {
47 pub fn new(builder: &mut Builder<C>) -> Self {
49 let mut result = DuplexChallengerVariable::<C> {
50 sponge_state: builder.dyn_array(PERMUTATION_WIDTH),
51 nb_inputs: builder.eval(C::N::zero()),
52 input_buffer: builder.dyn_array(PERMUTATION_WIDTH),
53 nb_outputs: builder.eval(C::N::zero()),
54 output_buffer: builder.dyn_array(PERMUTATION_WIDTH),
55 };
56
57 builder.range(0, PERMUTATION_WIDTH).for_each(|i, builder| {
59 builder.set(&mut result.sponge_state, i, C::F::zero());
60 builder.set(&mut result.input_buffer, i, C::F::zero());
61 builder.set(&mut result.output_buffer, i, C::F::zero());
62 });
63 result
64 }
65
66 pub fn copy(&self, builder: &mut Builder<C>) -> Self {
68 let mut sponge_state = builder.dyn_array(PERMUTATION_WIDTH);
69 builder.range(0, PERMUTATION_WIDTH).for_each(|i, builder| {
70 let element = builder.get(&self.sponge_state, i);
71 builder.set(&mut sponge_state, i, element);
72 });
73 let nb_inputs = builder.eval(self.nb_inputs);
74 let mut input_buffer = builder.dyn_array(PERMUTATION_WIDTH);
75 builder.range(0, PERMUTATION_WIDTH).for_each(|i, builder| {
76 let element = builder.get(&self.input_buffer, i);
77 builder.set(&mut input_buffer, i, element);
78 });
79 let nb_outputs = builder.eval(self.nb_outputs);
80 let mut output_buffer = builder.dyn_array(PERMUTATION_WIDTH);
81 builder.range(0, PERMUTATION_WIDTH).for_each(|i, builder| {
82 let element = builder.get(&self.output_buffer, i);
83 builder.set(&mut output_buffer, i, element);
84 });
85 DuplexChallengerVariable::<C> {
86 sponge_state,
87 nb_inputs,
88 input_buffer,
89 nb_outputs,
90 output_buffer,
91 }
92 }
93
94 pub fn assert_eq(&self, builder: &mut Builder<C>, other: &Self) {
96 builder.assert_var_eq(self.nb_inputs, other.nb_inputs);
97 builder.assert_var_eq(self.nb_outputs, other.nb_outputs);
98 builder.range(0, PERMUTATION_WIDTH).for_each(|i, builder| {
99 let element = builder.get(&self.sponge_state, i);
100 let other_element = builder.get(&other.sponge_state, i);
101 builder.assert_felt_eq(element, other_element);
102 });
103 builder.range(0, self.nb_inputs).for_each(|i, builder| {
104 let element = builder.get(&self.input_buffer, i);
105 let other_element = builder.get(&other.input_buffer, i);
106 builder.assert_felt_eq(element, other_element);
107 });
108 builder.range(0, self.nb_outputs).for_each(|i, builder| {
109 let element = builder.get(&self.output_buffer, i);
110 let other_element = builder.get(&other.output_buffer, i);
111 builder.assert_felt_eq(element, other_element);
112 });
113 }
114
115 pub fn reset(&mut self, builder: &mut Builder<C>) {
116 let zero: Var<_> = builder.eval(C::N::zero());
117 let zero_felt: Felt<_> = builder.eval(C::F::zero());
118 builder.range(0, PERMUTATION_WIDTH).for_each(|i, builder| {
119 builder.set(&mut self.sponge_state, i, zero_felt);
120 });
121 builder.assign(self.nb_inputs, zero);
122 builder.range(0, PERMUTATION_WIDTH).for_each(|i, builder| {
123 builder.set(&mut self.input_buffer, i, zero_felt);
124 });
125 builder.assign(self.nb_outputs, zero);
126 builder.range(0, PERMUTATION_WIDTH).for_each(|i, builder| {
127 builder.set(&mut self.output_buffer, i, zero_felt);
128 });
129 }
130
131 pub fn duplexing(&mut self, builder: &mut Builder<C>) {
132 builder.range(0, self.nb_inputs).for_each(|i, builder| {
133 let element = builder.get(&self.input_buffer, i);
134 builder.set(&mut self.sponge_state, i, element);
135 });
136 builder.assign(self.nb_inputs, C::N::zero());
137
138 builder.poseidon2_permute_mut(&self.sponge_state);
139
140 builder.assign(self.nb_outputs, C::N::zero());
141
142 for i in 0..PERMUTATION_WIDTH {
143 let element = builder.get(&self.sponge_state, i);
144 builder.set(&mut self.output_buffer, i, element);
145 builder.assign(self.nb_outputs, self.nb_outputs + C::N::one());
146 }
147 }
148
149 fn observe(&mut self, builder: &mut Builder<C>, value: Felt<C::F>) {
150 builder.assign(self.nb_outputs, C::N::zero());
151
152 builder.set(&mut self.input_buffer, self.nb_inputs, value);
153 builder.assign(self.nb_inputs, self.nb_inputs + C::N::one());
154
155 builder.if_eq(self.nb_inputs, C::N::from_canonical_usize(HASH_RATE)).then(|builder| {
156 self.duplexing(builder);
157 })
158 }
159
160 fn observe_commitment(&mut self, builder: &mut Builder<C>, commitment: DigestVariable<C>) {
161 for i in 0..DIGEST_SIZE {
162 let element = builder.get(&commitment, i);
163 self.observe(builder, element);
164 }
165 }
166
167 fn sample(&mut self, builder: &mut Builder<C>) -> Felt<C::F> {
168 let zero: Var<_> = builder.eval(C::N::zero());
169 builder.if_ne(self.nb_inputs, zero).then_or_else(
170 |builder| {
171 self.clone().duplexing(builder);
172 },
173 |builder| {
174 builder.if_eq(self.nb_outputs, zero).then(|builder| {
175 self.clone().duplexing(builder);
176 });
177 },
178 );
179 let idx: Var<_> = builder.eval(self.nb_outputs - C::N::one());
180 let output = builder.get(&self.output_buffer, idx);
181 builder.assign(self.nb_outputs, self.nb_outputs - C::N::one());
182 output
183 }
184
185 fn sample_ext(&mut self, builder: &mut Builder<C>) -> Ext<C::F, C::EF> {
186 let a = self.sample(builder);
187 let b = self.sample(builder);
188 let c = self.sample(builder);
189 let d = self.sample(builder);
190 builder.ext_from_base_slice(&[a, b, c, d])
191 }
192
193 fn sample_bits(
194 &mut self,
195 builder: &mut Builder<C>,
196 nb_bits: Usize<C::N>,
197 ) -> Array<C, Var<C::N>> {
198 let rand_f = self.sample(builder);
199 let mut bits = builder.num2bits_f(rand_f);
200
201 builder.range(nb_bits, bits.len()).for_each(|i, builder| {
202 builder.set(&mut bits, i, C::N::zero());
203 });
204
205 bits
206 }
207
208 pub fn check_witness(
209 &mut self,
210 builder: &mut Builder<C>,
211 nb_bits: Var<C::N>,
212 witness: Felt<C::F>,
213 ) {
214 self.observe(builder, witness);
215 let element_bits = self.sample_bits(builder, nb_bits.into());
216 builder.range(0, nb_bits).for_each(|i, builder| {
217 let element = builder.get(&element_bits, i);
218 builder.assert_var_eq(element, C::N::zero());
219 });
220 }
221}
222
223impl<C: Config> CanObserveVariable<C, Felt<C::F>> for DuplexChallengerVariable<C> {
224 fn observe(&mut self, builder: &mut Builder<C>, value: Felt<C::F>) {
225 DuplexChallengerVariable::observe(self, builder, value);
226 }
227
228 fn observe_slice(&mut self, builder: &mut Builder<C>, values: Array<C, Felt<C::F>>) {
229 match values {
230 Array::Dyn(_, len) => {
231 builder.range(0, len).for_each(|i, builder| {
232 let element = builder.get(&values, i);
233 self.observe(builder, element);
234 });
235 }
236 Array::Fixed(values) => {
237 values.iter().for_each(|value| {
238 self.observe(builder, *value);
239 });
240 }
241 }
242 }
243}
244
245impl<C: Config> CanSampleVariable<C, Felt<C::F>> for DuplexChallengerVariable<C> {
246 fn sample(&mut self, builder: &mut Builder<C>) -> Felt<C::F> {
247 DuplexChallengerVariable::sample(self, builder)
248 }
249}
250
251impl<C: Config> CanSampleBitsVariable<C> for DuplexChallengerVariable<C> {
252 fn sample_bits(
253 &mut self,
254 builder: &mut Builder<C>,
255 nb_bits: Usize<C::N>,
256 ) -> Array<C, Var<C::N>> {
257 DuplexChallengerVariable::sample_bits(self, builder, nb_bits)
258 }
259}
260
261impl<C: Config> CanObserveVariable<C, DigestVariable<C>> for DuplexChallengerVariable<C> {
262 fn observe(&mut self, builder: &mut Builder<C>, commitment: DigestVariable<C>) {
263 DuplexChallengerVariable::observe_commitment(self, builder, commitment);
264 }
265
266 fn observe_slice(&mut self, _builder: &mut Builder<C>, _values: Array<C, DigestVariable<C>>) {
267 todo!()
268 }
269}
270
271impl<C: Config> CanObserveVariable<C, VerifyingKeyVariable<C>> for DuplexChallengerVariable<C> {
272 fn observe(&mut self, builder: &mut Builder<C>, value: VerifyingKeyVariable<C>) {
273 self.observe_commitment(builder, value.commitment);
274 self.observe(builder, value.pc_start)
275 }
276
277 fn observe_slice(
278 &mut self,
279 _builder: &mut Builder<C>,
280 _values: Array<C, VerifyingKeyVariable<C>>,
281 ) {
282 todo!()
283 }
284}
285
286impl<C: Config> FeltChallenger<C> for DuplexChallengerVariable<C> {
287 fn sample_ext(&mut self, builder: &mut Builder<C>) -> Ext<C::F, C::EF> {
288 DuplexChallengerVariable::sample_ext(self, builder)
289 }
290}
291
#[cfg(test)]
mod tests {
    use p3_challenger::{CanObserve, CanSample};
    use p3_field::AbstractField;

    use sp1_recursion_compiler::{
        asm::{AsmBuilder, AsmConfig},
        ir::Felt,
    };

    use sp1_recursion_core::stark::utils::{run_test_recursion, TestConfig};
    use sp1_stark::{baby_bear_poseidon2::BabyBearPoseidon2, StarkGenericConfig};

    use crate::challenger::DuplexChallengerVariable;

    /// Checks that the in-circuit duplex challenger samples the same value as the native
    /// challenger after both observe the same four elements.
    #[test]
    fn test_compiler_challenger() {
        type SC = BabyBearPoseidon2;
        type F = <SC as StarkGenericConfig>::Val;
        type EF = <SC as StarkGenericConfig>::Challenge;

        let config = SC::default();
        let mut challenger = config.challenger();
        challenger.observe(F::one());
        challenger.observe(F::two());
        challenger.observe(F::two());
        challenger.observe(F::two());
        let result: F = challenger.sample();
        println!("expected result: {}", result);

        let mut builder = AsmBuilder::<F, EF>::default();

        // `new` explicitly zeroes the sponge and both buffers, matching the native
        // challenger's initial state instead of relying on untouched memory reading as zero.
        let mut challenger = DuplexChallengerVariable::<AsmConfig<F, EF>>::new(&mut builder);
        let one: Felt<_> = builder.eval(F::one());
        let two: Felt<_> = builder.eval(F::two());
        challenger.observe(&mut builder, one);
        challenger.observe(&mut builder, two);
        challenger.observe(&mut builder, two);
        challenger.observe(&mut builder, two);
        let element = challenger.sample(&mut builder);

        let expected_result: Felt<_> = builder.eval(result);
        builder.assert_felt_eq(expected_result, element);

        // Halt only after the observations, the sample and the assertion. Emitting the halt
        // before them (as previously done) stops execution before the check runs, so the
        // test could never fail.
        builder.halt();

        let program = builder.compile_program();
        run_test_recursion(program, None, TestConfig::All);
    }
}