1use std::{iter::once, sync::Arc};
2
3use slop_algebra::AbstractField;
4use slop_alloc::HasBackend;
5use slop_challenger::IopCtx;
6use slop_jagged::JaggedProverData;
7use slop_symmetric::{CryptographicHasher, PseudoCompressionFunction as _};
8use slop_tensor::Tensor;
9use sp1_gpu_basefold::{CudaStackedPcsProverData, FriCudaProver};
10use sp1_gpu_cudart::TaskScope;
11use sp1_gpu_merkle_tree::{CudaTcsProver, SingleLayerMerkleTreeProverError};
12use sp1_gpu_utils::{traces::JaggedTraceMle, Ext, Felt};
13
14#[allow(clippy::type_complexity)]
16pub fn commit_multilinears<GC: IopCtx<F = Felt, EF = Ext>, P: CudaTcsProver<GC>>(
17 jagged_trace_mle: &JaggedTraceMle<Felt, TaskScope>,
18 max_log_row_count: u32,
19 use_preprocessed: bool,
20 drop_main_traces: bool,
21 basefold_prover: &FriCudaProver<GC, P, Felt>,
22) -> Result<
23 (GC::Digest, JaggedProverData<GC, CudaStackedPcsProverData<GC>>),
24 SingleLayerMerkleTreeProverError,
25> {
26 let (index, padding, dst) = if use_preprocessed {
27 (
28 &jagged_trace_mle.dense().preprocessed_table_index,
29 jagged_trace_mle.dense().preprocessed_padding,
30 Tensor::<Felt, TaskScope>::with_sizes_in(
31 [
32 jagged_trace_mle.dense().preprocessed_offset >> basefold_prover.log_height,
33 1 << (basefold_prover.log_height as usize
34 + basefold_prover.config.log_blowup()),
35 ],
36 jagged_trace_mle.dense().dense.backend().clone(),
37 ),
38 )
39 } else {
40 (
41 &jagged_trace_mle.dense().main_table_index,
42 jagged_trace_mle.dense().main_padding,
43 Tensor::<Felt, TaskScope>::with_sizes_in(
44 [
45 jagged_trace_mle.dense().main_size() >> basefold_prover.log_height,
46 1 << (basefold_prover.log_height as usize
47 + basefold_prover.config.log_blowup()),
48 ],
49 jagged_trace_mle.dense().dense.backend().clone(),
50 ),
51 )
52 };
53 let (mut row_counts, mut column_counts) = (
54 index.values().map(|x| x.poly_size).collect::<Vec<_>>(),
55 index.values().map(|x| x.num_polys).collect::<Vec<_>>(),
56 );
57
58 let drop_traces = drop_main_traces && !use_preprocessed;
59
60 let (commitment, data) =
61 basefold_prover.encode_and_commit(use_preprocessed, drop_traces, jagged_trace_mle, dst)?;
62
63 let num_added_cols = padding.div_ceil(1 << max_log_row_count).max(1);
64
65 row_counts.push(1 << max_log_row_count);
66 row_counts.push(padding - (num_added_cols - 1) * (1 << max_log_row_count));
67 column_counts.push(num_added_cols - 1);
68 column_counts.push(1);
69
70 let (hasher, compressor) = GC::default_hasher_and_compressor();
71
72 let hash = hasher.hash_iter(
73 once(Felt::from_canonical_u32(row_counts.len() as u32))
74 .chain(row_counts.clone().into_iter().map(|x| Felt::from_canonical_u32(x as u32)))
75 .chain(column_counts.clone().into_iter().map(|x| Felt::from_canonical_u32(x as u32))),
76 );
77
78 let final_commitment = compressor.compress([commitment, hash]);
79
80 let jagged_prover_data = JaggedProverData {
81 pcs_prover_data: data,
82 row_counts: Arc::new(row_counts),
83 column_counts: Arc::new(column_counts),
84 padding_column_count: num_added_cols,
85 original_commitment: commitment,
86 };
87
88 Ok((final_commitment, jagged_prover_data))
89}
90
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use serial_test::serial;
    use slop_alloc::{CpuBackend, ToHost};
    use slop_challenger::IopCtx;
    use slop_futures::queue::WorkerQueue;
    use slop_jagged::{JaggedPcsVerifier, JaggedProver};
    use slop_merkle_tree::Poseidon2KoalaBear16Prover;
    use slop_stacked::StackedPcsProver;
    use sp1_gpu_basefold::FriCudaProver;
    use sp1_gpu_cudart::{run_in_place, PinnedBuffer};
    use sp1_gpu_jagged_tracegen::test_utils::tracegen_setup::{
        self, CORE_MAX_LOG_ROW_COUNT, LOG_STACKING_HEIGHT,
    };
    use sp1_gpu_jagged_tracegen::{full_tracegen, CORE_MAX_TRACE_SIZE};
    use sp1_gpu_merkle_tree::{CudaTcsProver, Poseidon2SP1Field16CudaProver};
    use sp1_gpu_utils::{Felt, TestGC};
    use sp1_hypercube::prover::{DefaultTraceGenerator, ProverSemaphore, TraceGenerator};
    use sp1_hypercube::{SP1InnerPcs, SP1PcsProofInner};
    use sp1_primitives::fri_params::core_fri_config;

    use crate::commit::commit_multilinears;
    /// Cross-checks the CUDA `commit_multilinears` against the reference
    /// CPU-side `JaggedProver::commit_multilinears`: both the jagged shape
    /// metadata and the commitments must match for the preprocessed and main
    /// regions. `#[serial]` because the test owns the GPU via `run_in_place`.
    #[serial]
    #[tokio::test]
    async fn test_commit_matches() {
        let (machine, record, program) = tracegen_setup::setup().await;

        // Reference (CPU-path) jagged prover type used to produce the
        // "old" commitments we compare against.
        type JC = SP1InnerPcs;
        type Prover = JaggedProver<
            TestGC,
            SP1PcsProofInner,
            StackedPcsProver<Poseidon2KoalaBear16Prover, TestGC>,
        >;

        run_in_place(|scope| async move {
            let semaphore = ProverSemaphore::new(1);
            let trace_generator = DefaultTraceGenerator::new_in(machine.clone(), CpuBackend);
            // Generate the reference traces on the CPU backend.
            let old_traces = trace_generator
                .generate_traces(
                    program.clone(),
                    record.clone(),
                    CORE_MAX_LOG_ROW_COUNT as usize,
                    semaphore.clone(),
                )
                .await;

            tracing::info!(
                "warmup traces generated: {:?}",
                old_traces.main_trace_data.shard_chips.len()
            );

            let num_rounds = 2;

            // Verifier fixes the FRI parameters; both provers are derived
            // from it so their configurations agree.
            let jagged_verifier = JaggedPcsVerifier::<_, JC>::new_from_basefold_params(
                core_fri_config(),
                LOG_STACKING_HEIGHT,
                CORE_MAX_LOG_ROW_COUNT as usize,
                num_rounds,
            );

            let jagged_prover = Prover::from_verifier(&jagged_verifier);

            // Copy the reference traces to host memory so the CPU prover can
            // commit to them.
            let mut preprocessed_host_values = Vec::new();
            for mle in old_traces.preprocessed_traces.values() {
                let mle_host = mle.to_host().unwrap();
                preprocessed_host_values.push(mle_host);
            }

            let mut main_host_values = Vec::new();
            for mle in old_traces.main_trace_data.traces.values() {
                let mle_host = mle.to_host().unwrap();
                main_host_values.push(mle_host);
            }

            let preprocessed_message = preprocessed_host_values.into_iter().collect();
            let main_message = main_host_values.into_iter().collect();

            // Reference commitments from the CPU-side jagged prover.
            let (old_preprocessed_commitment, old_preprocessed_data) =
                jagged_prover.commit_multilinears(preprocessed_message).ok().unwrap();
            let (old_main_commitment, old_main_data) =
                jagged_prover.commit_multilinears(main_message).ok().unwrap();

            let record = Arc::new(record);
            // Pinned host buffer sized for the largest core trace; handed to
            // the GPU tracegen via a single-slot worker queue.
            let capacity = CORE_MAX_TRACE_SIZE as usize;
            let buffer = PinnedBuffer::<Felt>::with_capacity(capacity);
            let queue = Arc::new(WorkerQueue::new(vec![buffer]));
            let buffer = queue.pop().await.unwrap();
            // Full GPU trace generation for the same program/record.
            let (_public_values, jagged_trace_data, _chip_set, _permit) = full_tracegen(
                &machine,
                program.clone(),
                record.clone(),
                &buffer,
                CORE_MAX_TRACE_SIZE as usize,
                LOG_STACKING_HEIGHT,
                CORE_MAX_LOG_ROW_COUNT,
                &scope,
                ProverSemaphore::new(1),
                false,
            )
            .await;

            let tcs_prover = Poseidon2SP1Field16CudaProver::new(&scope);

            // CUDA BaseFold prover configured identically to the verifier.
            let basefold_prover = FriCudaProver::<TestGC, _, <TestGC as IopCtx>::F>::new(
                tcs_prover,
                jagged_verifier.pcs_verifier.basefold_verifier.fri_config,
                LOG_STACKING_HEIGHT,
            );

            // GPU commitments: preprocessed region (use_preprocessed = true)
            // and main region (use_preprocessed = false); traces are kept
            // (drop_main_traces = false) in both calls.
            let (new_preprocessed_commitment, new_preprocessed_data) =
                commit_multilinears::<TestGC, _>(
                    &jagged_trace_data,
                    CORE_MAX_LOG_ROW_COUNT,
                    true,
                    false,
                    &basefold_prover,
                )
                .unwrap();

            let (new_main_commitment, new_main_data) = commit_multilinears::<TestGC, _>(
                &jagged_trace_data,
                CORE_MAX_LOG_ROW_COUNT,
                false,
                false,
                &basefold_prover,
            )
            .unwrap();

            // Jagged shape metadata must agree between CPU and GPU paths.
            assert_eq!(old_preprocessed_data.row_counts, new_preprocessed_data.row_counts);
            assert_eq!(old_preprocessed_data.column_counts, new_preprocessed_data.column_counts);
            assert_eq!(
                old_preprocessed_data.padding_column_count,
                new_preprocessed_data.padding_column_count
            );
            assert_eq!(old_main_data.row_counts, new_main_data.row_counts);
            assert_eq!(old_main_data.column_counts, new_main_data.column_counts);
            assert_eq!(old_main_data.padding_column_count, new_main_data.padding_column_count);
            // Raw (pre-shape-hash) commitments must agree...
            assert_eq!(
                old_preprocessed_data.original_commitment,
                new_preprocessed_data.original_commitment
            );
            assert_eq!(old_main_data.original_commitment, new_main_data.original_commitment);
            // ...as must the final shape-bound commitments.
            assert_eq!(old_preprocessed_commitment, new_preprocessed_commitment);
            assert_eq!(old_main_commitment, new_main_commitment);
        })
        .await;
    }
}