sp1_recursion_program/
challenger.rs

1use p3_field::AbstractField;
2use sp1_recursion_compiler::prelude::{
3    Array, Builder, Config, DslVariable, Ext, Felt, MemIndex, MemVariable, Ptr, Usize, Var,
4    Variable,
5};
6use sp1_recursion_core::runtime::{DIGEST_SIZE, HASH_RATE, PERMUTATION_WIDTH};
7
8use crate::{fri::types::DigestVariable, types::VerifyingKeyVariable};
9
/// Reference: [p3_challenger::CanObserve].
///
/// In-circuit counterpart of `CanObserve`: absorbs values of type `V` into a challenger,
/// using `builder` to emit the operations.
pub trait CanObserveVariable<C: Config, V> {
    /// Observes a single value.
    fn observe(&mut self, builder: &mut Builder<C>, value: V);

    /// Observes the elements of `values`.
    ///
    /// NOTE(review): some implementations in this file are still `todo!()` and panic.
    fn observe_slice(&mut self, builder: &mut Builder<C>, values: Array<C, V>);
}
16
/// Reference: [p3_challenger::CanSample].
///
/// In-circuit counterpart of `CanSample`: draws a value of type `V` from a challenger,
/// using `builder` to emit the operations.
pub trait CanSampleVariable<C: Config, V> {
    /// Samples one value, updating the challenger's state.
    fn sample(&mut self, builder: &mut Builder<C>) -> V;
}
20
/// Reference: [p3_challenger::FieldChallenger].
///
/// A challenger that can observe and sample base-field elements (`Felt`), sample bits, and
/// sample extension-field elements (`Ext`).
pub trait FeltChallenger<C: Config>:
    CanObserveVariable<C, Felt<C::F>> + CanSampleVariable<C, Felt<C::F>> + CanSampleBitsVariable<C>
{
    /// Samples one extension-field element.
    fn sample_ext(&mut self, builder: &mut Builder<C>) -> Ext<C::F, C::EF>;
}
27
/// Reference: [p3_challenger::CanSampleBits].
///
/// Samples a random value and returns it as an array of bit variables.
pub trait CanSampleBitsVariable<C: Config> {
    /// Samples a value and returns its bit decomposition, keeping only `nb_bits` bits.
    ///
    /// NOTE(review): the `DuplexChallengerVariable` implementation zeroes every bit at index
    /// `nb_bits` and above; presumably index 0 is the least-significant bit — confirm against
    /// `num2bits_f`.
    fn sample_bits(
        &mut self,
        builder: &mut Builder<C>,
        nb_bits: Usize<C::N>,
    ) -> Array<C, Var<C::N>>;
}
35
/// Reference: [p3_challenger::DuplexChallenger]
///
/// In-circuit duplex-sponge challenger. Its state is held in DSL arrays and variables, and
/// every operation on it takes a [`Builder`] to emit the corresponding instructions.
#[derive(Clone, DslVariable)]
pub struct DuplexChallengerVariable<C: Config> {
    /// The sponge (permutation) state; `new` allocates `PERMUTATION_WIDTH` slots.
    pub sponge_state: Array<C, Felt<C::F>>,
    /// Number of elements at the front of `input_buffer` waiting to be absorbed.
    pub nb_inputs: Var<C::N>,
    /// Observed elements that have not been absorbed into the sponge yet.
    pub input_buffer: Array<C, Felt<C::F>>,
    /// Number of elements in `output_buffer` still available to `sample`.
    pub nb_outputs: Var<C::N>,
    /// Elements squeezed by the last duplexing; `sample` pops them from the end.
    pub output_buffer: Array<C, Felt<C::F>>,
}
45
46impl<C: Config> DuplexChallengerVariable<C> {
47    /// Creates a new duplex challenger with the default state.
48    pub fn new(builder: &mut Builder<C>) -> Self {
49        let mut result = DuplexChallengerVariable::<C> {
50            sponge_state: builder.dyn_array(PERMUTATION_WIDTH),
51            nb_inputs: builder.eval(C::N::zero()),
52            input_buffer: builder.dyn_array(PERMUTATION_WIDTH),
53            nb_outputs: builder.eval(C::N::zero()),
54            output_buffer: builder.dyn_array(PERMUTATION_WIDTH),
55        };
56
57        // Constrain the state of the challenger to contain all zeroes.
58        builder.range(0, PERMUTATION_WIDTH).for_each(|i, builder| {
59            builder.set(&mut result.sponge_state, i, C::F::zero());
60            builder.set(&mut result.input_buffer, i, C::F::zero());
61            builder.set(&mut result.output_buffer, i, C::F::zero());
62        });
63        result
64    }
65
66    /// Creates a new challenger with the same state as an existing challenger.
67    pub fn copy(&self, builder: &mut Builder<C>) -> Self {
68        let mut sponge_state = builder.dyn_array(PERMUTATION_WIDTH);
69        builder.range(0, PERMUTATION_WIDTH).for_each(|i, builder| {
70            let element = builder.get(&self.sponge_state, i);
71            builder.set(&mut sponge_state, i, element);
72        });
73        let nb_inputs = builder.eval(self.nb_inputs);
74        let mut input_buffer = builder.dyn_array(PERMUTATION_WIDTH);
75        builder.range(0, PERMUTATION_WIDTH).for_each(|i, builder| {
76            let element = builder.get(&self.input_buffer, i);
77            builder.set(&mut input_buffer, i, element);
78        });
79        let nb_outputs = builder.eval(self.nb_outputs);
80        let mut output_buffer = builder.dyn_array(PERMUTATION_WIDTH);
81        builder.range(0, PERMUTATION_WIDTH).for_each(|i, builder| {
82            let element = builder.get(&self.output_buffer, i);
83            builder.set(&mut output_buffer, i, element);
84        });
85        DuplexChallengerVariable::<C> {
86            sponge_state,
87            nb_inputs,
88            input_buffer,
89            nb_outputs,
90            output_buffer,
91        }
92    }
93
94    /// Asserts that the state of this challenger is equal to the state of another challenger.
95    pub fn assert_eq(&self, builder: &mut Builder<C>, other: &Self) {
96        builder.assert_var_eq(self.nb_inputs, other.nb_inputs);
97        builder.assert_var_eq(self.nb_outputs, other.nb_outputs);
98        builder.range(0, PERMUTATION_WIDTH).for_each(|i, builder| {
99            let element = builder.get(&self.sponge_state, i);
100            let other_element = builder.get(&other.sponge_state, i);
101            builder.assert_felt_eq(element, other_element);
102        });
103        builder.range(0, self.nb_inputs).for_each(|i, builder| {
104            let element = builder.get(&self.input_buffer, i);
105            let other_element = builder.get(&other.input_buffer, i);
106            builder.assert_felt_eq(element, other_element);
107        });
108        builder.range(0, self.nb_outputs).for_each(|i, builder| {
109            let element = builder.get(&self.output_buffer, i);
110            let other_element = builder.get(&other.output_buffer, i);
111            builder.assert_felt_eq(element, other_element);
112        });
113    }
114
115    pub fn reset(&mut self, builder: &mut Builder<C>) {
116        let zero: Var<_> = builder.eval(C::N::zero());
117        let zero_felt: Felt<_> = builder.eval(C::F::zero());
118        builder.range(0, PERMUTATION_WIDTH).for_each(|i, builder| {
119            builder.set(&mut self.sponge_state, i, zero_felt);
120        });
121        builder.assign(self.nb_inputs, zero);
122        builder.range(0, PERMUTATION_WIDTH).for_each(|i, builder| {
123            builder.set(&mut self.input_buffer, i, zero_felt);
124        });
125        builder.assign(self.nb_outputs, zero);
126        builder.range(0, PERMUTATION_WIDTH).for_each(|i, builder| {
127            builder.set(&mut self.output_buffer, i, zero_felt);
128        });
129    }
130
131    pub fn duplexing(&mut self, builder: &mut Builder<C>) {
132        builder.range(0, self.nb_inputs).for_each(|i, builder| {
133            let element = builder.get(&self.input_buffer, i);
134            builder.set(&mut self.sponge_state, i, element);
135        });
136        builder.assign(self.nb_inputs, C::N::zero());
137
138        builder.poseidon2_permute_mut(&self.sponge_state);
139
140        builder.assign(self.nb_outputs, C::N::zero());
141
142        for i in 0..PERMUTATION_WIDTH {
143            let element = builder.get(&self.sponge_state, i);
144            builder.set(&mut self.output_buffer, i, element);
145            builder.assign(self.nb_outputs, self.nb_outputs + C::N::one());
146        }
147    }
148
149    fn observe(&mut self, builder: &mut Builder<C>, value: Felt<C::F>) {
150        builder.assign(self.nb_outputs, C::N::zero());
151
152        builder.set(&mut self.input_buffer, self.nb_inputs, value);
153        builder.assign(self.nb_inputs, self.nb_inputs + C::N::one());
154
155        builder.if_eq(self.nb_inputs, C::N::from_canonical_usize(HASH_RATE)).then(|builder| {
156            self.duplexing(builder);
157        })
158    }
159
160    fn observe_commitment(&mut self, builder: &mut Builder<C>, commitment: DigestVariable<C>) {
161        for i in 0..DIGEST_SIZE {
162            let element = builder.get(&commitment, i);
163            self.observe(builder, element);
164        }
165    }
166
167    fn sample(&mut self, builder: &mut Builder<C>) -> Felt<C::F> {
168        let zero: Var<_> = builder.eval(C::N::zero());
169        builder.if_ne(self.nb_inputs, zero).then_or_else(
170            |builder| {
171                self.clone().duplexing(builder);
172            },
173            |builder| {
174                builder.if_eq(self.nb_outputs, zero).then(|builder| {
175                    self.clone().duplexing(builder);
176                });
177            },
178        );
179        let idx: Var<_> = builder.eval(self.nb_outputs - C::N::one());
180        let output = builder.get(&self.output_buffer, idx);
181        builder.assign(self.nb_outputs, self.nb_outputs - C::N::one());
182        output
183    }
184
185    fn sample_ext(&mut self, builder: &mut Builder<C>) -> Ext<C::F, C::EF> {
186        let a = self.sample(builder);
187        let b = self.sample(builder);
188        let c = self.sample(builder);
189        let d = self.sample(builder);
190        builder.ext_from_base_slice(&[a, b, c, d])
191    }
192
193    fn sample_bits(
194        &mut self,
195        builder: &mut Builder<C>,
196        nb_bits: Usize<C::N>,
197    ) -> Array<C, Var<C::N>> {
198        let rand_f = self.sample(builder);
199        let mut bits = builder.num2bits_f(rand_f);
200
201        builder.range(nb_bits, bits.len()).for_each(|i, builder| {
202            builder.set(&mut bits, i, C::N::zero());
203        });
204
205        bits
206    }
207
208    pub fn check_witness(
209        &mut self,
210        builder: &mut Builder<C>,
211        nb_bits: Var<C::N>,
212        witness: Felt<C::F>,
213    ) {
214        self.observe(builder, witness);
215        let element_bits = self.sample_bits(builder, nb_bits.into());
216        builder.range(0, nb_bits).for_each(|i, builder| {
217            let element = builder.get(&element_bits, i);
218            builder.assert_var_eq(element, C::N::zero());
219        });
220    }
221}
222
223impl<C: Config> CanObserveVariable<C, Felt<C::F>> for DuplexChallengerVariable<C> {
224    fn observe(&mut self, builder: &mut Builder<C>, value: Felt<C::F>) {
225        DuplexChallengerVariable::observe(self, builder, value);
226    }
227
228    fn observe_slice(&mut self, builder: &mut Builder<C>, values: Array<C, Felt<C::F>>) {
229        match values {
230            Array::Dyn(_, len) => {
231                builder.range(0, len).for_each(|i, builder| {
232                    let element = builder.get(&values, i);
233                    self.observe(builder, element);
234                });
235            }
236            Array::Fixed(values) => {
237                values.iter().for_each(|value| {
238                    self.observe(builder, *value);
239                });
240            }
241        }
242    }
243}
244
245impl<C: Config> CanSampleVariable<C, Felt<C::F>> for DuplexChallengerVariable<C> {
246    fn sample(&mut self, builder: &mut Builder<C>) -> Felt<C::F> {
247        DuplexChallengerVariable::sample(self, builder)
248    }
249}
250
251impl<C: Config> CanSampleBitsVariable<C> for DuplexChallengerVariable<C> {
252    fn sample_bits(
253        &mut self,
254        builder: &mut Builder<C>,
255        nb_bits: Usize<C::N>,
256    ) -> Array<C, Var<C::N>> {
257        DuplexChallengerVariable::sample_bits(self, builder, nb_bits)
258    }
259}
260
261impl<C: Config> CanObserveVariable<C, DigestVariable<C>> for DuplexChallengerVariable<C> {
262    fn observe(&mut self, builder: &mut Builder<C>, commitment: DigestVariable<C>) {
263        DuplexChallengerVariable::observe_commitment(self, builder, commitment);
264    }
265
266    fn observe_slice(&mut self, _builder: &mut Builder<C>, _values: Array<C, DigestVariable<C>>) {
267        todo!()
268    }
269}
270
271impl<C: Config> CanObserveVariable<C, VerifyingKeyVariable<C>> for DuplexChallengerVariable<C> {
272    fn observe(&mut self, builder: &mut Builder<C>, value: VerifyingKeyVariable<C>) {
273        self.observe_commitment(builder, value.commitment);
274        self.observe(builder, value.pc_start)
275    }
276
277    fn observe_slice(
278        &mut self,
279        _builder: &mut Builder<C>,
280        _values: Array<C, VerifyingKeyVariable<C>>,
281    ) {
282        todo!()
283    }
284}
285
286impl<C: Config> FeltChallenger<C> for DuplexChallengerVariable<C> {
287    fn sample_ext(&mut self, builder: &mut Builder<C>) -> Ext<C::F, C::EF> {
288        DuplexChallengerVariable::sample_ext(self, builder)
289    }
290}
291
#[cfg(test)]
mod tests {
    use p3_challenger::{CanObserve, CanSample};
    use p3_field::AbstractField;

    use sp1_recursion_compiler::{
        asm::{AsmBuilder, AsmConfig},
        ir::{Felt, Usize, Var},
    };

    use sp1_recursion_core::{
        runtime::PERMUTATION_WIDTH,
        stark::utils::{run_test_recursion, TestConfig},
    };
    use sp1_stark::{baby_bear_poseidon2::BabyBearPoseidon2, StarkGenericConfig};

    use crate::challenger::DuplexChallengerVariable;

    /// Checks that the in-circuit challenger samples the same element as the native
    /// challenger from `BabyBearPoseidon2` after observing the same four inputs.
    #[test]
    fn test_compiler_challenger() {
        type SC = BabyBearPoseidon2;
        type F = <SC as StarkGenericConfig>::Val;
        type EF = <SC as StarkGenericConfig>::Challenge;

        let config = SC::default();
        let mut challenger = config.challenger();
        challenger.observe(F::one());
        challenger.observe(F::two());
        challenger.observe(F::two());
        challenger.observe(F::two());
        let result: F = challenger.sample();
        println!("expected result: {}", result);

        let mut builder = AsmBuilder::<F, EF>::default();

        let width: Var<_> = builder.eval(F::from_canonical_usize(PERMUTATION_WIDTH));
        let mut challenger = DuplexChallengerVariable::<AsmConfig<F, EF>> {
            sponge_state: builder.array(Usize::Var(width)),
            nb_inputs: builder.eval(F::zero()),
            input_buffer: builder.array(Usize::Var(width)),
            nb_outputs: builder.eval(F::zero()),
            output_buffer: builder.array(Usize::Var(width)),
        };
        let one: Felt<_> = builder.eval(F::one());
        let two: Felt<_> = builder.eval(F::two());
        challenger.observe(&mut builder, one);
        challenger.observe(&mut builder, two);
        challenger.observe(&mut builder, two);
        challenger.observe(&mut builder, two);
        let element = challenger.sample(&mut builder);

        let expected_result: Felt<_> = builder.eval(result);
        builder.assert_felt_eq(expected_result, element);

        // Halt only after the challenger operations and the assertion. The halt was
        // previously emitted before them, which would presumably end execution before the
        // checks under test run.
        builder.halt();

        let program = builder.compile_program();
        run_test_recursion(program, None, TestConfig::All);
    }
}