Skip to main content

sp1_gpu_basefold/
grinding_challenger.rs

1use serde::{Deserialize, Serialize};
2use slop_algebra::{Field, PrimeField, PrimeField31, PrimeField32, PrimeField64};
3use slop_challenger::GrindingChallenger;
4use slop_symmetric::CryptographicPermutation;
5use sp1_gpu_challenger::{
6    grind_koala_bear_challenger_on_device, grind_multi_field32_challenger_on_device,
7    KoalaBearDuplexChallenger,
8};
9use sp1_gpu_cudart::TaskScope;
10
11/// A [`GrindingChallenger`] that can also grind on device.
12pub trait DeviceGrindingChallenger: GrindingChallenger {
13    /// Grinds on device.
14    fn grind_device(&mut self, bits: usize, scope: &TaskScope) -> Self::Witness;
15}
16
17// Concrete implementation for KoalaBear DuplexChallenger - uses GPU grinding
18impl DeviceGrindingChallenger for KoalaBearDuplexChallenger {
19    fn grind_device(&mut self, bits: usize, scope: &TaskScope) -> Self::Witness {
20        grind_koala_bear_challenger_on_device(self, bits, scope)
21    }
22}
23
24// GPU implementation for MultiField32Challenger
25impl<F, PF, P, const WIDTH: usize, const RATE: usize> DeviceGrindingChallenger
26    for slop_challenger::MultiField32Challenger<F, PF, P, WIDTH, RATE>
27where
28    F: PrimeField64 + PrimeField31 + PrimeField32 + Send + Sync,
29    PF: PrimeField + Field + Send + Sync,
30    P: CryptographicPermutation<[PF; WIDTH]> + Send + Sync,
31{
32    fn grind_device(&mut self, bits: usize, scope: &TaskScope) -> Self::Witness {
33        grind_multi_field32_challenger_on_device(self, bits, scope)
34    }
35}
36
37#[derive(
38    Debug, Clone, Copy, Default, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize,
39)]
40pub struct GrindingPowCudaProver;
41
42impl GrindingPowCudaProver {
43    pub fn grind<C: DeviceGrindingChallenger + Send + Sync>(
44        challenger: &mut C,
45        bits: usize,
46        scope: &TaskScope,
47    ) -> C::Witness {
48        challenger.grind_device(bits, scope)
49    }
50}
51
52#[cfg(test)]
53mod tests {
54    use crate::grinding_challenger::DeviceGrindingChallenger;
55    use slop_algebra::AbstractField;
56    use slop_challenger::{CanObserve, CanSample, GrindingChallenger};
57    use slop_poseidon2::{Poseidon2, Poseidon2ExternalMatrixGeneral};
58    use sp1_hypercube::inner_perm;
59    use sp1_primitives::{SP1DiffusionMatrix, SP1Field};
60
61    pub type Perm = Poseidon2<SP1Field, Poseidon2ExternalMatrixGeneral, SP1DiffusionMatrix, 16, 3>;
62
63    #[test]
64    fn test_grinding() {
65        sp1_gpu_cudart::run_sync_in_place(|t| {
66            for bits in 1..20 {
67                let default_perm = inner_perm();
68                let mut challenger =
69                    slop_challenger::DuplexChallenger::<SP1Field, Perm, 16, 8>::new(default_perm);
70
71                // Observe 7 elements to make the input buffer almost full and trigger duplexing on
72                challenger.observe(SP1Field::from_canonical_u32(0));
73                challenger.observe(SP1Field::from_canonical_u32(1));
74                challenger.observe(SP1Field::from_canonical_u32(2));
75                challenger.observe(SP1Field::from_canonical_u32(3));
76                challenger.observe(SP1Field::from_canonical_u32(4));
77                challenger.observe(SP1Field::from_canonical_u32(5));
78                challenger.observe(SP1Field::from_canonical_u32(6));
79                challenger.observe(SP1Field::from_canonical_u32(7));
80
81                // Make another challenger that also samples before grinding (this empties the input buffer).
82                let mut challenger_2 = challenger.clone();
83                let _: SP1Field = challenger.sample();
84
85                let mut original_challenger = challenger.clone();
86                let result = challenger.grind_device(bits, &t);
87
88                assert!(original_challenger.check_witness(bits, result));
89
90                let mut original_challenger_2 = challenger_2.clone();
91                let result_2 = challenger_2.grind_device(bits, &t);
92
93                assert!(original_challenger_2.check_witness(bits, result_2));
94
95                // Checks to make sure the pow witness was properly observed in `grind_on_device`.
96                assert!(original_challenger_2.sponge_state == challenger_2.sponge_state);
97                assert!(original_challenger_2.input_buffer == challenger_2.input_buffer);
98                assert!(original_challenger_2.output_buffer == challenger_2.output_buffer);
99            }
100        })
101        .unwrap()
102    }
103
104    #[test]
105    fn test_grinding_multi_field32() {
106        use slop_bn254::{
107            outer_perm, Bn254Fr, OuterPerm, OUTER_CHALLENGER_RATE, OUTER_CHALLENGER_STATE_WIDTH,
108        };
109
110        sp1_gpu_cudart::run_sync_in_place(|t| {
111            for bits in 1..10 {
112                let perm = outer_perm();
113                let mut challenger = slop_challenger::MultiField32Challenger::<
114                    SP1Field,
115                    Bn254Fr,
116                    OuterPerm,
117                    OUTER_CHALLENGER_STATE_WIDTH,
118                    OUTER_CHALLENGER_RATE,
119                >::new(perm)
120                .unwrap();
121
122                // Observe some elements
123                challenger.observe(SP1Field::from_canonical_u32(0));
124                challenger.observe(SP1Field::from_canonical_u32(1));
125                challenger.observe(SP1Field::from_canonical_u32(2));
126                challenger.observe(SP1Field::from_canonical_u32(3));
127
128                // Make another challenger that also samples before grinding.
129                let mut challenger_2 = challenger.clone();
130                let _: SP1Field = challenger.sample();
131
132                let mut original_challenger = challenger.clone();
133                let result = challenger.grind_device(bits, &t);
134
135                assert!(original_challenger.check_witness(bits, result));
136
137                let mut original_challenger_2 = challenger_2.clone();
138                let result_2 = challenger_2.grind_device(bits, &t);
139
140                assert!(original_challenger_2.check_witness(bits, result_2));
141
142                // Checks to make sure the pow witness was properly observed in `grind_on_device`.
143                assert!(original_challenger_2.sponge_state == challenger_2.sponge_state);
144                assert!(original_challenger_2.input_buffer == challenger_2.input_buffer);
145                assert!(original_challenger_2.output_buffer == challenger_2.output_buffer);
146            }
147        })
148        .unwrap()
149    }
150}