// sp1_gpu_commit/commit.rs

1use std::{iter::once, sync::Arc};
2
3use slop_algebra::AbstractField;
4use slop_alloc::HasBackend;
5use slop_challenger::IopCtx;
6use slop_jagged::JaggedProverData;
7use slop_symmetric::{CryptographicHasher, PseudoCompressionFunction as _};
8use slop_tensor::Tensor;
9use sp1_gpu_basefold::{CudaStackedPcsProverData, FriCudaProver};
10use sp1_gpu_cudart::TaskScope;
11use sp1_gpu_merkle_tree::{CudaTcsProver, SingleLayerMerkleTreeProverError};
12use sp1_gpu_utils::{traces::JaggedTraceMle, Ext, Felt};
13
14/// TODO: document
15#[allow(clippy::type_complexity)]
16pub fn commit_multilinears<GC: IopCtx<F = Felt, EF = Ext>, P: CudaTcsProver<GC>>(
17    jagged_trace_mle: &JaggedTraceMle<Felt, TaskScope>,
18    max_log_row_count: u32,
19    use_preprocessed: bool,
20    drop_main_traces: bool,
21    basefold_prover: &FriCudaProver<GC, P, Felt>,
22) -> Result<
23    (GC::Digest, JaggedProverData<GC, CudaStackedPcsProverData<GC>>),
24    SingleLayerMerkleTreeProverError,
25> {
26    let (index, padding, dst) = if use_preprocessed {
27        (
28            &jagged_trace_mle.dense().preprocessed_table_index,
29            jagged_trace_mle.dense().preprocessed_padding,
30            Tensor::<Felt, TaskScope>::with_sizes_in(
31                [
32                    jagged_trace_mle.dense().preprocessed_offset >> basefold_prover.log_height,
33                    1 << (basefold_prover.log_height as usize
34                        + basefold_prover.config.log_blowup()),
35                ],
36                jagged_trace_mle.dense().dense.backend().clone(),
37            ),
38        )
39    } else {
40        (
41            &jagged_trace_mle.dense().main_table_index,
42            jagged_trace_mle.dense().main_padding,
43            Tensor::<Felt, TaskScope>::with_sizes_in(
44                [
45                    jagged_trace_mle.dense().main_size() >> basefold_prover.log_height,
46                    1 << (basefold_prover.log_height as usize
47                        + basefold_prover.config.log_blowup()),
48                ],
49                jagged_trace_mle.dense().dense.backend().clone(),
50            ),
51        )
52    };
53    let (mut row_counts, mut column_counts) = (
54        index.values().map(|x| x.poly_size).collect::<Vec<_>>(),
55        index.values().map(|x| x.num_polys).collect::<Vec<_>>(),
56    );
57
58    let drop_traces = drop_main_traces && !use_preprocessed;
59
60    let (commitment, data) =
61        basefold_prover.encode_and_commit(use_preprocessed, drop_traces, jagged_trace_mle, dst)?;
62
63    let num_added_cols = padding.div_ceil(1 << max_log_row_count).max(1);
64
65    row_counts.push(1 << max_log_row_count);
66    row_counts.push(padding - (num_added_cols - 1) * (1 << max_log_row_count));
67    column_counts.push(num_added_cols - 1);
68    column_counts.push(1);
69
70    let (hasher, compressor) = GC::default_hasher_and_compressor();
71
72    let hash = hasher.hash_iter(
73        once(Felt::from_canonical_u32(row_counts.len() as u32))
74            .chain(row_counts.clone().into_iter().map(|x| Felt::from_canonical_u32(x as u32)))
75            .chain(column_counts.clone().into_iter().map(|x| Felt::from_canonical_u32(x as u32))),
76    );
77
78    let final_commitment = compressor.compress([commitment, hash]);
79
80    let jagged_prover_data = JaggedProverData {
81        pcs_prover_data: data,
82        row_counts: Arc::new(row_counts),
83        column_counts: Arc::new(column_counts),
84        padding_column_count: num_added_cols,
85        original_commitment: commitment,
86    };
87
88    Ok((final_commitment, jagged_prover_data))
89}
90
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use serial_test::serial;
    use slop_alloc::{CpuBackend, ToHost};
    use slop_challenger::IopCtx;
    use slop_futures::queue::WorkerQueue;
    use slop_jagged::{JaggedPcsVerifier, JaggedProver};
    use slop_merkle_tree::Poseidon2KoalaBear16Prover;
    use slop_stacked::StackedPcsProver;
    use sp1_gpu_basefold::FriCudaProver;
    use sp1_gpu_cudart::{run_in_place, PinnedBuffer};
    use sp1_gpu_jagged_tracegen::test_utils::tracegen_setup::{
        self, CORE_MAX_LOG_ROW_COUNT, LOG_STACKING_HEIGHT,
    };
    use sp1_gpu_jagged_tracegen::{full_tracegen, CORE_MAX_TRACE_SIZE};
    use sp1_gpu_merkle_tree::{CudaTcsProver, Poseidon2SP1Field16CudaProver};
    use sp1_gpu_utils::{Felt, TestGC};
    use sp1_hypercube::prover::{DefaultTraceGenerator, ProverSemaphore, TraceGenerator};
    use sp1_hypercube::{SP1InnerPcs, SP1PcsProofInner};
    use sp1_primitives::fri_params::core_fri_config;

    use crate::commit::commit_multilinears;

    /// Cross-checks the GPU [`commit_multilinears`] against the host `JaggedProver`.
    ///
    /// Traces for the same program and record are produced by the host trace generator and by
    /// the GPU `full_tracegen`. Both sides then commit to the preprocessed and the main traces.
    /// The row counts, column counts, padding column counts, inner commitments, and final
    /// commitments must all agree.
    #[serial]
    #[tokio::test]
    async fn test_commit_matches() {
        let (machine, record, program) = tracegen_setup::setup().await;

        type JC = SP1InnerPcs;
        type Prover = JaggedProver<
            TestGC,
            SP1PcsProofInner,
            StackedPcsProver<Poseidon2KoalaBear16Prover, TestGC>,
        >;

        run_in_place(|scope| async move {
            let semaphore = ProverSemaphore::new(1);

            // Reference path: generate traces with the host trace generator.
            let trace_generator = DefaultTraceGenerator::new_in(machine.clone(), CpuBackend);
            let old_traces = trace_generator
                .generate_traces(
                    program.clone(),
                    record.clone(),
                    CORE_MAX_LOG_ROW_COUNT as usize,
                    semaphore.clone(),
                )
                .await;

            tracing::info!(
                "warmup traces generated: {:?}",
                old_traces.main_trace_data.shard_chips.len()
            );

            // Presumably one round each for the preprocessed and the main commitment below.
            let num_rounds = 2;

            let jagged_verifier = JaggedPcsVerifier::<_, JC>::new_from_basefold_params(
                core_fri_config(),
                LOG_STACKING_HEIGHT,
                CORE_MAX_LOG_ROW_COUNT as usize,
                num_rounds,
            );

            // Commit to the preprocessed and main traces using the host prover.
            let jagged_prover = Prover::from_verifier(&jagged_verifier);

            let preprocessed_message = old_traces
                .preprocessed_traces
                .values()
                .map(|mle| mle.to_host().unwrap())
                .collect();
            let main_message = old_traces
                .main_trace_data
                .traces
                .values()
                .map(|mle| mle.to_host().unwrap())
                .collect();

            // `unwrap_or_else` + `panic!` reports which commit failed without requiring the
            // error type to implement `Debug`.
            let (old_preprocessed_commitment, old_preprocessed_data) = jagged_prover
                .commit_multilinears(preprocessed_message)
                .unwrap_or_else(|_| panic!("host prover failed to commit to preprocessed traces"));
            let (old_main_commitment, old_main_data) = jagged_prover
                .commit_multilinears(main_message)
                .unwrap_or_else(|_| panic!("host prover failed to commit to main traces"));

            // GPU path: run tracegen on the device, then commit with `commit_multilinears`.
            let record = Arc::new(record);
            let capacity = CORE_MAX_TRACE_SIZE as usize;
            let buffer = PinnedBuffer::<Felt>::with_capacity(capacity);
            // The single pinned buffer is routed through a one-element `WorkerQueue` so that
            // `full_tracegen` receives a popped queue item (presumably the type it expects).
            let queue = Arc::new(WorkerQueue::new(vec![buffer]));
            let buffer = queue.pop().await.unwrap();
            let (_public_values, jagged_trace_data, _chip_set, _permit) = full_tracegen(
                &machine,
                program.clone(),
                record.clone(),
                &buffer,
                CORE_MAX_TRACE_SIZE as usize,
                LOG_STACKING_HEIGHT,
                CORE_MAX_LOG_ROW_COUNT,
                &scope,
                ProverSemaphore::new(1),
                false,
            )
            .await;

            let tcs_prover = Poseidon2SP1Field16CudaProver::new(&scope);

            // Use the same FRI configuration and stacking height as the host verifier.
            let basefold_prover = FriCudaProver::<TestGC, _, <TestGC as IopCtx>::F>::new(
                tcs_prover,
                jagged_verifier.pcs_verifier.basefold_verifier.fri_config,
                LOG_STACKING_HEIGHT,
            );

            let (new_preprocessed_commitment, new_preprocessed_data) =
                commit_multilinears::<TestGC, _>(
                    &jagged_trace_data,
                    CORE_MAX_LOG_ROW_COUNT,
                    true,
                    false,
                    &basefold_prover,
                )
                .unwrap();

            let (new_main_commitment, new_main_data) = commit_multilinears::<TestGC, _>(
                &jagged_trace_data,
                CORE_MAX_LOG_ROW_COUNT,
                false,
                false,
                &basefold_prover,
            )
            .unwrap();

            // Layout bookkeeping is compared first so a mismatch is reported before the
            // (less informative) digest comparisons.
            assert_eq!(old_preprocessed_data.row_counts, new_preprocessed_data.row_counts);
            assert_eq!(old_preprocessed_data.column_counts, new_preprocessed_data.column_counts);
            assert_eq!(
                old_preprocessed_data.padding_column_count,
                new_preprocessed_data.padding_column_count
            );
            assert_eq!(old_main_data.row_counts, new_main_data.row_counts);
            assert_eq!(old_main_data.column_counts, new_main_data.column_counts);
            assert_eq!(old_main_data.padding_column_count, new_main_data.padding_column_count);
            assert_eq!(
                old_preprocessed_data.original_commitment,
                new_preprocessed_data.original_commitment
            );
            assert_eq!(old_main_data.original_commitment, new_main_data.original_commitment);
            assert_eq!(old_preprocessed_commitment, new_preprocessed_commitment);
            assert_eq!(old_main_commitment, new_main_commitment);
        })
        .await;
    }
}