1use itertools::Itertools;
2use std::{marker::PhantomData, sync::Arc};
3
4use slop_algebra::{AbstractExtensionField, AbstractField, ExtensionField, TwoAdicField};
5use slop_alloc::{Buffer, HasBackend};
6use slop_basefold::{BasefoldProof, FriConfig, BATCH_GRINDING_BITS};
7use slop_basefold_prover::{host_fold_even_odd, BasefoldProverError};
8use slop_challenger::{CanObserve, CanSampleBits, FieldChallenger, IopCtx};
9use slop_commit::{Message, Rounds};
10use slop_merkle_tree::MerkleTreeOpeningAndProof;
11use slop_multilinear::{partial_lagrange_blocking, Mle, MultilinearPcsChallenger, Point};
12use slop_tensor::Tensor;
13use sp1_primitives::{SP1ExtensionField, SP1Field};
14
15use sp1_gpu_cudart::{
16 args,
17 sys::{
18 basefold::{
19 batch_koala_bear_base_ext_kernel, batch_koala_bear_base_ext_kernel_flattened,
20 flatten_to_base_koala_bear_base_ext_kernel,
21 transpose_even_odd_koala_bear_base_ext_kernel,
22 },
23 runtime::KernelPtr,
24 },
25 DeviceBuffer, DeviceMle, DeviceTensor, TaskScope,
26};
27use sp1_gpu_merkle_tree::{CudaTcsProver, MerkleTreeProverData, SingleLayerMerkleTreeProverError};
28use sp1_gpu_utils::{Ext, Felt, JaggedTraceMle, TraceDenseData};
29
30use crate::{
31 encode_batch, CudaStackedPcsProverData, DeviceGrindingChallenger, GrindingPowCudaProver,
32 SpparkDftKoalaBear,
33};
34
/// Supplies the CUDA kernel that batches many multilinear polynomials into one.
///
/// [`FriCudaProver::batch`] launches the returned kernel with, in order: a pointer to the dense
/// trace data, a pointer to the output batched MLE, a pointer to the batching coefficients on the
/// device, the value `1 << log_height`, and the total number of polynomials.
///
/// # Safety
///
/// NOTE(review): the contract is not written down in this file. Presumably an implementor must
/// return a kernel whose parameter list matches the launch arguments listed above and which only
/// touches memory reachable through them — confirm against the kernel source.
pub unsafe trait MleBatchKernel<F: TwoAdicField, EF: ExtensionField<F>> {
    /// Returns the kernel handle used for batching MLEs.
    fn batch_mle_kernel() -> KernelPtr;
}
40
/// Supplies the CUDA kernel that accumulates a tensor of codewords into the batched codeword.
///
/// [`FriCudaProver::batch`] launches the returned kernel once per codeword tensor with, in order:
/// a pointer to that tensor, a pointer to the batched output codeword, a pointer to the batching
/// coefficients for that tensor on the device, the codeword length, and the number of codewords
/// in the tensor.
///
/// # Safety
///
/// NOTE(review): the contract is not written down in this file. Presumably an implementor must
/// return a kernel whose parameter list matches the launch arguments listed above — confirm
/// against the kernel source.
pub unsafe trait RsCodeWordBatchKernel<F: TwoAdicField, EF: ExtensionField<F>> {
    /// Returns the kernel handle used for batching codewords.
    fn batch_rs_codeword_kernel() -> KernelPtr;
}
46
/// Supplies the CUDA kernel that rearranges a codeword into the leaves committed in a FRI round.
///
/// The commit-phase round of [`FriCudaProver`] launches the returned kernel with, in order: a
/// pointer to the current codeword, a pointer to a leaves tensor with twice as many rows and half
/// as many columns as the codeword, and the halved codeword length.
///
/// # Safety
///
/// NOTE(review): the contract is not written down in this file. Presumably an implementor must
/// return a kernel whose parameter list matches the launch arguments listed above — confirm
/// against the kernel source.
pub unsafe trait RsCodeWordTransposeKernel<F: TwoAdicField, EF: ExtensionField<F>> {
    /// Returns the kernel handle used for the even/odd transpose.
    fn transpose_even_odd_kernel() -> KernelPtr;
}
51
/// Supplies the CUDA kernel that rewrites an extension-field MLE as base-field rows.
///
/// The commit-phase round of [`FriCudaProver`] launches the returned kernel with, in order: a
/// pointer to the folded MLE's values, a pointer to an output tensor of sizes
/// `[EF::D, folded_height]`, and `folded_height`.
///
/// # Safety
///
/// NOTE(review): the contract is not written down in this file. Presumably an implementor must
/// return a kernel whose parameter list matches the launch arguments listed above — confirm
/// against the kernel source.
pub unsafe trait MleFlattenKernel<F: TwoAdicField, EF: ExtensionField<F>> {
    /// Returns the kernel handle used for flattening to the base field.
    fn flatten_to_base_kernel() -> KernelPtr;
}
56
/// CUDA-backed prover that encodes, commits to, and opens stacked polynomials, producing a
/// [`BasefoldProof`].
///
/// `GC` is the global IOP context, `P` the tensor commitment prover and `F` the field the FRI
/// configuration is expressed over.
pub struct FriCudaProver<GC, P, F> {
    /// Prover used to commit to tensors and to open them at query indices.
    pub tcs_prover: P,
    /// FRI parameters; this file reads `log_blowup`, `num_queries` and `proof_of_work_bits`.
    pub config: FriConfig<F>,
    /// Log2 of the stacking height; also used as the number of variables of the batched MLE.
    pub log_height: u32,
    /// Ties the prover to `GC` without storing a value of that type.
    _marker: PhantomData<GC>,
}
63
64impl<GC: IopCtx<F = Felt, EF = Ext>, P> FriCudaProver<GC, P, GC::F>
65where
66 GC::F: TwoAdicField,
67 GC::EF: ExtensionField<GC::F> + TwoAdicField,
68 P: CudaTcsProver<GC>,
69
70 TaskScope: MleBatchKernel<GC::F, GC::EF>
71 + RsCodeWordBatchKernel<GC::F, GC::EF>
72 + RsCodeWordTransposeKernel<GC::F, GC::EF>
73 + MleFlattenKernel<GC::F, GC::EF>,
74{
75 pub fn new(tcs_prover: P, config: FriConfig<GC::F>, log_height: u32) -> Self {
76 Self { tcs_prover, config, log_height, _marker: PhantomData }
77 }
78 pub fn encode_and_commit(
79 &self,
80 use_preprocessed: bool,
81 drop_traces: bool,
82 jagged_trace_mle: &JaggedTraceMle<Felt, TaskScope>,
83 mut dst: Tensor<Felt, TaskScope>,
84 ) -> Result<
85 (<GC as IopCtx>::Digest, CudaStackedPcsProverData<GC>),
86 SingleLayerMerkleTreeProverError,
87 > {
88 let encoder = SpparkDftKoalaBear::default();
89
90 unsafe {
91 dst.assume_init();
92 }
93
94 let virtual_tensor = if use_preprocessed {
95 jagged_trace_mle.preprocessed_virtual_tensor(self.log_height)
96 } else {
97 jagged_trace_mle.main_virtual_tensor(self.log_height)
98 };
99
100 encode_batch(encoder, self.config.log_blowup as u32, virtual_tensor, &mut dst).unwrap();
101
102 let (commitment, tcs_data) = self.tcs_prover.commit_tensors(&dst)?;
105
106 let codeword_mle = if drop_traces { None } else { Some(Arc::new(dst)) };
107 let prover_data = CudaStackedPcsProverData { merkle_tree_tcs_data: tcs_data, codeword_mle };
108
109 Ok((commitment, prover_data))
110 }
111
112 #[allow(clippy::type_complexity)]
113 pub fn batch(
114 &self,
115 batching_coefficients: &Tensor<GC::EF>,
116 mles: &TraceDenseData<GC::F, TaskScope>,
117 codewords: Message<Tensor<Felt, TaskScope>>,
118 evaluation_claims: Vec<GC::EF>,
119 ) -> (Mle<GC::EF, TaskScope>, Tensor<GC::F, TaskScope>, GC::EF) {
120 let log_stacking_height = self.log_height;
121 let total_num_polynomials = codewords.iter().map(|c| c.sizes()[0]).sum::<usize>();
123
124 let num_variables = log_stacking_height;
126 let codeword_size = (codewords.first().unwrap()).sizes()[1];
127 let scope: TaskScope = mles.backend().clone();
128 let mut batch_mle =
129 Mle::new(Tensor::<GC::EF, TaskScope>::zeros_in([1, 1 << num_variables], scope.clone()));
130 let mut batch_codeword = Tensor::<GC::F, TaskScope>::zeros_in(
131 [<GC::EF as AbstractExtensionField<GC::F>>::D, codeword_size],
132 scope.clone(),
133 );
134
135 unsafe {
136 let block_dim = 256;
137 let grid_dim = (1usize << num_variables).div_ceil(block_dim);
138 let batch_size = total_num_polynomials;
139 let powers_device = DeviceBuffer::from_host(batching_coefficients.as_buffer(), &scope)
140 .unwrap()
141 .into_inner();
142 let mle_args = args!(
143 mles.dense.as_ptr(),
144 batch_mle.guts_mut().as_mut_ptr(),
145 powers_device.as_ptr(),
146 (1 << num_variables) as usize,
147 batch_size
148 );
149 scope
150 .launch_kernel(TaskScope::batch_mle_kernel(), grid_dim, block_dim, &mle_args, 0)
151 .unwrap();
152 }
153
154 let mut batch_coefficients = batching_coefficients.as_buffer().to_vec();
155 for codeword in codewords.iter() {
156 let batch_size = codeword.sizes()[0];
157 let mut powers = batch_coefficients;
158 batch_coefficients = powers.split_off(batch_size);
159 let powers_device = DeviceBuffer::from_host(&Buffer::from(powers.clone()), &scope)
160 .unwrap()
161 .into_inner();
162
163 let block_dim = 256;
164 let grid_dim = codeword_size.div_ceil(block_dim);
165 let codeword_args = args!(
166 codeword.as_ptr(),
167 batch_codeword.as_mut_ptr(),
168 powers_device.as_ptr(),
169 codeword_size,
170 batch_size
171 );
172 unsafe {
173 scope
174 .launch_kernel(
175 TaskScope::batch_rs_codeword_kernel(),
176 grid_dim,
177 block_dim,
178 &codeword_args,
179 0,
180 )
181 .unwrap();
182 }
183 }
184
185 let batch_eval_claim = evaluation_claims
187 .into_iter()
188 .zip(batching_coefficients.as_slice())
189 .map(|(eval, coeff)| eval * *coeff)
190 .sum::<GC::EF>();
191
192 (batch_mle, batch_codeword, batch_eval_claim)
193 }
194
195 #[allow(clippy::type_complexity)]
196 fn commit_phase_round(
197 &self,
198 current_mle: Mle<GC::EF, TaskScope>,
199 current_codeword: Tensor<GC::F, TaskScope>,
200 challenger: &mut GC::Challenger,
201 ) -> Result<
202 (
203 GC::EF,
204 Mle<GC::EF, TaskScope>,
205 Tensor<GC::F, TaskScope>,
206 GC::Digest,
207 Tensor<GC::F, TaskScope>,
208 MerkleTreeProverData<GC::Digest>,
209 ),
210 SingleLayerMerkleTreeProverError,
211 > {
212 let codeword_size = current_codeword.sizes()[1];
218 let batch_size = current_codeword.sizes()[0];
219 let scope = current_codeword.backend().clone();
220
221 let mut leaves = Tensor::with_sizes_in([batch_size * 2, codeword_size / 2], scope.clone());
222 let output_codeword_size = codeword_size / 2;
223 let block_dim = 256;
224 let grid_dim = output_codeword_size.div_ceil(block_dim);
225 unsafe {
226 let args = args!(current_codeword.as_ptr(), leaves.as_mut_ptr(), output_codeword_size);
227 leaves.assume_init();
228 scope
229 .launch_kernel(
230 TaskScope::transpose_even_odd_kernel(),
231 grid_dim,
232 block_dim,
233 &args,
234 0,
235 )
236 .unwrap();
237 }
238
239 let (commit, prover_data) = self.tcs_prover.commit_tensors(&leaves)?;
240 challenger.observe(commit);
242
243 let beta: GC::EF = challenger.sample_ext_element();
244
245 let folded_mle: Mle<_, TaskScope> = {
247 let device_mle = DeviceMle::from(current_mle);
248 device_mle.fold(beta).into()
249 };
250 let folded_num_variables = folded_mle.num_variables();
251
252 if folded_num_variables < 4 {
253 let current_codeword_transposed =
254 DeviceTensor::from_raw(current_codeword.clone()).transpose();
255 let current_codeword_vec = current_codeword_transposed.to_host().unwrap();
256 let current_codeword_vec =
257 current_codeword_vec.into_buffer().into_extension::<GC::EF>().into_vec();
258 let folded_codeword_vec = host_fold_even_odd(current_codeword_vec, beta);
259 let folded_codeword_storage =
260 Buffer::from(folded_codeword_vec).flatten_to_base::<GC::F>();
261 let mut new_size = current_codeword.sizes().to_vec();
262 new_size[1] /= 2;
263 let folded_codeword =
264 DeviceBuffer::from_host(&folded_codeword_storage, folded_mle.backend())
265 .unwrap()
266 .into_inner();
267 let folded_codeword = Tensor::from(folded_codeword).reshape([new_size[1], new_size[0]]);
268 let folded_codeword = DeviceTensor::from_raw(folded_codeword).transpose().into_inner();
269 return Ok((beta, folded_mle, folded_codeword, commit, leaves, prover_data));
270 }
271
272 let folded_height = 1 << folded_num_variables;
273 let mut folded_mle_flattened = Tensor::<GC::F, TaskScope>::with_sizes_in(
274 [<GC::EF as AbstractExtensionField<GC::F>>::D, folded_height],
275 scope.clone(),
276 );
277
278 let mut folded_codeword = Tensor::<GC::F, TaskScope>::zeros_in(
279 [<GC::EF as AbstractExtensionField<GC::F>>::D, folded_height << self.config.log_blowup],
280 scope.clone(),
281 );
282
283 let block_dim = 256;
284 let grid_dim = folded_height.div_ceil(block_dim);
285 unsafe {
286 let args =
287 args!(folded_mle.guts().as_ptr(), folded_mle_flattened.as_mut_ptr(), folded_height);
288 folded_mle_flattened.assume_init();
289 scope
290 .launch_kernel(TaskScope::flatten_to_base_kernel(), grid_dim, block_dim, &args, 0)
291 .unwrap();
292 }
293 let encoder = SpparkDftKoalaBear::default();
294 encode_batch(
295 encoder,
296 self.config.log_blowup as u32,
297 folded_mle_flattened.as_view(),
298 &mut folded_codeword,
299 )
300 .unwrap();
301
302 Ok((beta, folded_mle, folded_codeword, commit, leaves, prover_data))
303 }
304
305 fn final_poly(&self, final_codeword: Tensor<GC::F, TaskScope>) -> GC::EF {
306 let final_codeword_host = DeviceTensor::from_raw(final_codeword).to_host().unwrap();
307 let final_codeword_transposed = final_codeword_host.transpose();
308 GC::EF::from_base_slice(
309 &final_codeword_transposed.storage.as_slice()
310 [0..(<GC::EF as AbstractExtensionField<GC::F>>::D)],
311 )
312 }
313
    /// Produces a [`BasefoldProof`] for evaluation claims that the caller vouches for.
    ///
    /// The statement order below defines the transcript: every `observe`, `sample_*` and `grind`
    /// call on `challenger` must stay in this order.
    ///
    /// # Arguments
    ///
    /// * `eval_point` - point the claims refer to; its dimension must equal the number of
    ///   variables of the batched MLE.
    /// * `evaluation_claims` - claimed evaluations, passed on to [`Self::batch`].
    /// * `mles` - the jagged trace whose dense data is batched.
    /// * `prover_data` - commitment data per round; rounds without a stored codeword are
    ///   re-encoded here.
    /// * `challenger` - transcript state, mutated throughout.
    ///
    /// # Errors
    ///
    /// Returns `CommitPhaseError` when a commit-phase round fails and `TcsCommitError` when
    /// opening a commitment at the query indices fails.
    ///
    /// # Panics
    ///
    /// Panics if the evaluation point dimension does not match the batched MLE, or if an
    /// encoding step or device-to-host copy fails.
    #[inline]
    pub fn prove_trusted_evaluations_basefold(
        &self,
        mut eval_point: Point<GC::EF>,
        evaluation_claims: Vec<GC::EF>,
        mles: &JaggedTraceMle<GC::F, TaskScope>,
        prover_data: Rounds<&CudaStackedPcsProverData<GC>>,
        challenger: &mut GC::Challenger,
    ) -> Result<BasefoldProof<GC>, BasefoldProverError<SingleLayerMerkleTreeProverError>>
    where
        GC::Challenger: DeviceGrindingChallenger<Witness = GC::F>,
    {
        let scope = mles.dense().dense.backend().clone();
        // Collect one codeword tensor per round, reusing the stored one when available.
        let mut codewords: Vec<Arc<Tensor<Felt, TaskScope>>> = Vec::new();
        for data in prover_data.iter() {
            if let Some(ref codeword) = data.codeword_mle {
                codewords.push(codeword.clone());
            } else {
                // NOTE(review): the fallback always sizes and encodes from the *main* trace
                // (`main_size`, `main_virtual_tensor`). Presumably only main-trace rounds are
                // ever committed with the codeword dropped — confirm against callers.
                let mut dst = Tensor::<Felt, TaskScope>::with_sizes_in(
                    [
                        mles.dense().main_size() >> self.log_height,
                        1 << (self.log_height as usize + self.config.log_blowup()),
                    ],
                    scope.clone(),
                );
                unsafe {
                    dst.assume_init();
                }

                let encoder = SpparkDftKoalaBear::default();
                encode_batch(
                    encoder,
                    self.config.log_blowup as u32,
                    mles.main_virtual_tensor(self.log_height),
                    &mut dst,
                )
                .unwrap();

                codewords.push(Arc::new(dst));
            }
        }

        // The batching point needs enough variables to index every polynomial.
        let total_num_polynomials = codewords.iter().map(|c| c.sizes()[0]).sum::<usize>();
        let num_batching_variables = total_num_polynomials.next_power_of_two().ilog2();

        let encoded_messages: Message<_> = codewords.iter().cloned().collect();

        // Grind before the batching point is sampled.
        let batch_grinding_witness =
            GrindingPowCudaProver::grind(challenger, BATCH_GRINDING_BITS, &scope);

        let batching_point = challenger.sample_point::<GC::EF>(num_batching_variables);
        let batching_coefficients = partial_lagrange_blocking(&batching_point);

        let (mle_batch, codeword_batch, batched_eval_claim) =
            self.batch(&batching_coefficients, mles.dense(), encoded_messages, evaluation_claims);
        let mut current_mle = mle_batch;
        let mut current_codeword = codeword_batch;
        // Remembered for sizing the query indices after the MLE has been folded away.
        let log_len = current_mle.num_variables();
        let mut univariate_messages: Vec<[GC::EF; 2]> = vec![];
        let mut fri_commitments = vec![];
        let mut commit_phase_data = vec![];
        let mut current_batched_eval_claim = batched_eval_claim;
        let mut commit_phase_values = vec![];

        assert_eq!(
            current_mle.num_variables(),
            eval_point.dimension() as u32,
            "eval point dimension mismatch"
        );

        challenger.observe(Felt::from_canonical_usize(eval_point.dimension()));
        // One round per coordinate of the evaluation point, consumed from the last one backwards.
        for _ in 0..eval_point.dimension() {
            let last_coord = eval_point.remove_last_coordinate();
            let zero_values = {
                use sp1_gpu_cudart::DeviceMle;
                let device_mle = DeviceMle::from(current_mle.clone());
                let evals = device_mle.fixed_at_zero(&eval_point);
                evals.to_host_vec().unwrap()
            };
            let zero_val = zero_values[0];
            // Solves `claim = zero_val + last_coord * (one_val - zero_val)` for `one_val`.
            let one_val = (current_batched_eval_claim - zero_val) / last_coord + zero_val;
            let uni_poly = [zero_val, one_val];
            univariate_messages.push(uni_poly);

            uni_poly.iter().for_each(|elem| challenger.observe_ext_element(*elem));

            let (beta, folded_mle, folded_codeword, commitment, leaves, prover_data) = self
                .commit_phase_round(current_mle, current_codeword, challenger)
                .map_err(BasefoldProverError::CommitPhaseError)?;

            fri_commitments.push(commitment);
            commit_phase_data.push(prover_data);
            commit_phase_values.push(leaves);

            current_mle = folded_mle;
            current_codeword = folded_codeword;
            // Claim carried into the next round, using the challenge sampled in this round.
            current_batched_eval_claim = zero_val + beta * one_val;
        }

        let final_poly = self.final_poly(current_codeword);
        challenger.observe_ext_element(final_poly);

        let fri_config = self.config;
        let pow_bits = fri_config.proof_of_work_bits;
        let pow_witness = GrindingPowCudaProver::grind(challenger, pow_bits, &scope);
        // Each query index has `log_len + log_blowup` bits.
        let query_indices: Vec<usize> = (0..fri_config.num_queries)
            .map(|_| challenger.sample_bits(log_len as usize + fri_config.log_blowup()))
            .collect();

        // Open every original commitment at the sampled indices.
        let mut component_polynomials_query_openings_and_proofs = vec![];
        for (data, codeword) in prover_data.iter().zip(codewords.iter()) {
            let values = self.tcs_prover.compute_openings_at_indices(codeword, &query_indices);
            let proof = self
                .tcs_prover
                .prove_openings_at_indices(&data.merkle_tree_tcs_data, &query_indices)
                .map_err(BasefoldProverError::TcsCommitError)?;
            let opening = MerkleTreeOpeningAndProof::<GC> { values, proof };
            component_polynomials_query_openings_and_proofs.push(opening);
        }

        // Open every commit-phase commitment; the indices are halved once per round.
        let mut query_phase_openings_and_proofs = vec![];
        let mut indices = query_indices;
        for (leaves, data) in commit_phase_values.into_iter().zip_eq(commit_phase_data) {
            for index in indices.iter_mut() {
                *index >>= 1;
            }
            let values = self.tcs_prover.compute_openings_at_indices(&leaves, &indices);

            let proof = self
                .tcs_prover
                .prove_openings_at_indices(&data, &indices)
                .map_err(BasefoldProverError::TcsCommitError)?;
            let opening = MerkleTreeOpeningAndProof { values, proof };
            query_phase_openings_and_proofs.push(opening);
        }

        Ok(BasefoldProof {
            univariate_messages,
            fri_commitments,
            component_polynomials_query_openings_and_proofs,
            query_phase_openings_and_proofs,
            final_poly,
            pow_witness,
            batch_grinding_witness,
        })
    }
473}
474
// Binds MLE batching for the SP1 base/extension field pair to the
// `batch_koala_bear_base_ext_kernel` exported by `sp1_gpu_cudart::sys::basefold`.
unsafe impl MleBatchKernel<SP1Field, SP1ExtensionField> for TaskScope {
    fn batch_mle_kernel() -> KernelPtr {
        // NOTE(review): presumably the getter only returns a kernel handle and has no
        // preconditions of its own — confirm against `sp1_gpu_cudart::sys`.
        unsafe { batch_koala_bear_base_ext_kernel() }
    }
}
480
// Binds codeword batching for the SP1 base/extension field pair to the
// `batch_koala_bear_base_ext_kernel_flattened` exported by `sp1_gpu_cudart::sys::basefold`.
unsafe impl RsCodeWordBatchKernel<SP1Field, SP1ExtensionField> for TaskScope {
    fn batch_rs_codeword_kernel() -> KernelPtr {
        // NOTE(review): presumably the getter only returns a kernel handle and has no
        // preconditions of its own — confirm against `sp1_gpu_cudart::sys`.
        unsafe { batch_koala_bear_base_ext_kernel_flattened() }
    }
}
486
// Binds the even/odd transpose for the SP1 base/extension field pair to the
// `transpose_even_odd_koala_bear_base_ext_kernel` exported by `sp1_gpu_cudart::sys::basefold`.
unsafe impl RsCodeWordTransposeKernel<SP1Field, SP1ExtensionField> for TaskScope {
    fn transpose_even_odd_kernel() -> KernelPtr {
        // NOTE(review): presumably the getter only returns a kernel handle and has no
        // preconditions of its own — confirm against `sp1_gpu_cudart::sys`.
        unsafe { transpose_even_odd_koala_bear_base_ext_kernel() }
    }
}
492
// Binds flattening to the base field for the SP1 base/extension field pair to the
// `flatten_to_base_koala_bear_base_ext_kernel` exported by `sp1_gpu_cudart::sys::basefold`.
unsafe impl MleFlattenKernel<SP1Field, SP1ExtensionField> for TaskScope {
    fn flatten_to_base_kernel() -> KernelPtr {
        // NOTE(review): presumably the getter only returns a kernel handle and has no
        // preconditions of its own — confirm against `sp1_gpu_cudart::sys`.
        unsafe { flatten_to_base_koala_bear_base_ext_kernel() }
    }
}
498
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use slop_alloc::{CpuBackend, ToHost};
    use slop_basefold::BasefoldVerifier;
    use slop_basefold_prover::BasefoldProver;
    use slop_commit::Message;
    use slop_futures::queue::WorkerQueue;
    use slop_merkle_tree::Poseidon2KoalaBear16Prover;
    use slop_multilinear::{Evaluations, Mle, MleEval};
    use slop_stacked::interleave_multilinears_with_fixed_rate;
    use sp1_gpu_cudart::{run_sync_in_place, PinnedBuffer};
    use sp1_gpu_merkle_tree::{CudaTcsProver, Poseidon2SP1Field16CudaProver};
    use sp1_gpu_tracegen::CudaTraceGenerator;
    use sp1_hypercube::prover::{ProverSemaphore, TraceGenerator};

    use sp1_core_machine::io::SP1Stdin;
    use sp1_gpu_jagged_tracegen::test_utils::tracegen_setup::{
        self, CORE_MAX_LOG_ROW_COUNT, LOG_STACKING_HEIGHT,
    };
    use sp1_gpu_jagged_tracegen::{full_tracegen, CORE_MAX_TRACE_SIZE};
    use sp1_gpu_utils::{Ext, Felt, TestGC};
    use sp1_primitives::fri_params::core_fri_config;
    use sp1_primitives::SP1GlobalContext;

    use super::*;

    /// End-to-end check of [`FriCudaProver`] against the host `BasefoldProver`.
    ///
    /// Commits to the preprocessed and main traces of the Fibonacci program with both provers,
    /// asserts that the commitments agree, then produces a proof with each prover and verifies
    /// both with the same verifier.
    #[test]
    fn test_basefold() {
        let rt = tokio::runtime::Runtime::new().unwrap();
        let (machine, record, program) =
            rt.block_on(tracegen_setup::setup(&test_artifacts::FIBONACCI_ELF, SP1Stdin::new()));

        run_sync_in_place(|scope| {
            // Reference (host) prover and the CUDA prover under test share one FRI config.
            let verifier = BasefoldVerifier::<SP1GlobalContext>::new(core_fri_config(), 2);
            let old_prover =
                BasefoldProver::<SP1GlobalContext, Poseidon2KoalaBear16Prover>::new(&verifier);

            let new_cuda_prover = FriCudaProver::<TestGC, _, Felt> {
                tcs_prover: Poseidon2SP1Field16CudaProver::new(&scope),
                config: verifier.fri_config,
                log_height: LOG_STACKING_HEIGHT,
                _marker: PhantomData::<TestGC>,
            };

            // Reference path: generate traces with the non-jagged trace generator.
            let semaphore = ProverSemaphore::new(1);
            let trace_generator = CudaTraceGenerator::new_in(machine.clone(), scope.clone());
            let old_traces = rt.block_on(trace_generator.generate_traces(
                program.clone(),
                record.clone(),
                CORE_MAX_LOG_ROW_COUNT as usize,
                semaphore.clone(),
            ));

            let preprocessed_traces = old_traces.preprocessed_traces.clone();

            let message = preprocessed_traces
                .into_iter()
                .filter_map(|mle| mle.1.into_inner())
                .map(|x| Clone::clone(x.as_ref()))
                .collect::<Message<Mle<_, _>>>();

            // Bring the preprocessed MLEs to the host for the reference prover.
            let host_message: Message<_> = message
                .clone()
                .into_iter()
                .map(|mle| {
                    let mle = Arc::unwrap_or_clone(mle);
                    let guts = mle.into_guts();
                    let device_mle = sp1_gpu_cudart::DeviceMle::from(guts);
                    device_mle.to_host().unwrap()
                })
                .collect();

            let interleaved_message =
                interleave_multilinears_with_fixed_rate(32, host_message, LOG_STACKING_HEIGHT);

            let interleaved_message =
                interleaved_message.into_iter().map(|x| x.as_ref().clone()).collect::<Message<_>>();

            let (old_preprocessed_commitment, old_preprocessed_prover_data) =
                old_prover.commit_mles(interleaved_message.clone()).unwrap();

            // Path under test: jagged trace generation into a pinned host buffer.
            let new_semaphore = ProverSemaphore::new(1);
            let capacity = CORE_MAX_TRACE_SIZE as usize;
            let buffer = PinnedBuffer::<Felt>::with_capacity(capacity);
            let queue = Arc::new(WorkerQueue::new(vec![buffer]));
            let buffer = rt.block_on(queue.pop()).unwrap();
            let (_, new_traces, _, _) = rt.block_on(full_tracegen(
                &machine,
                program,
                Arc::new(record),
                &buffer,
                CORE_MAX_TRACE_SIZE as usize,
                LOG_STACKING_HEIGHT,
                CORE_MAX_LOG_ROW_COUNT,
                &scope,
                new_semaphore,
                false,
            ));

            // Destination for the encoded preprocessed codewords.
            let dst = Tensor::<Felt, TaskScope>::with_sizes_in(
                [
                    new_traces.0.dense().preprocessed_offset >> LOG_STACKING_HEIGHT,
                    1 << (LOG_STACKING_HEIGHT as usize + verifier.fri_config.log_blowup()),
                ],
                scope.clone(),
            );

            let (new_preprocessed_commit, new_preprocessed_prover_data) =
                new_cuda_prover.encode_and_commit(true, false, &new_traces, dst).unwrap();

            assert_eq!(new_preprocessed_commit, old_preprocessed_commitment);

            // Destination for the encoded main-trace codewords.
            let dst = Tensor::<Felt, TaskScope>::with_sizes_in(
                [
                    new_traces.0.dense().main_size() >> LOG_STACKING_HEIGHT,
                    1 << (LOG_STACKING_HEIGHT as usize + verifier.fri_config.log_blowup()),
                ],
                scope.clone(),
            );

            let (new_main_commit, new_main_prover_data) =
                new_cuda_prover.encode_and_commit(false, false, &new_traces, dst).unwrap();
            let message = old_traces
                .main_trace_data
                .traces
                .into_iter()
                .filter_map(|mle| mle.1.into_inner())
                .map(|x| Clone::clone(x.as_ref()))
                .collect::<Message<Mle<_, _>>>();

            // Bring the main-trace MLEs to the host for the reference prover.
            let mut host_message = Vec::new();
            for mle in message.into_iter() {
                let mle = Arc::unwrap_or_clone(mle);
                let guts = mle.into_guts();
                let device_mle = sp1_gpu_cudart::DeviceMle::from(guts);
                let mle_host = device_mle.to_host().unwrap();
                host_message.push(mle_host);
            }

            let host_message = host_message.into_iter().collect::<Message<Mle<Felt, CpuBackend>>>();

            let interleaved_message_2 =
                interleave_multilinears_with_fixed_rate(32, host_message, LOG_STACKING_HEIGHT);

            let (old_main_commitment, old_main_prover_data) =
                old_prover.commit_mles(interleaved_message_2.clone()).unwrap();

            assert_eq!(new_main_commit, old_main_commitment);

            let mut rng = rand::thread_rng();

            // Random point; the claims are computed on the host from the reference MLEs.
            let eval_point_host = Point::<Ext>::rand(&mut rng, LOG_STACKING_HEIGHT);

            let evaluation_claims_1: Vec<_> = interleaved_message
                .clone()
                .into_iter()
                .map(|mle| mle.eval_at(&eval_point_host))
                .collect();

            let evaluation_claims_1 = Evaluations { round_evaluations: evaluation_claims_1 };

            let evaluation_claims_2: Vec<_> = interleaved_message_2
                .clone()
                .into_iter()
                .map(|mle| mle.eval_at(&eval_point_host))
                .collect();

            let host_evaluation_claims_1: Vec<MleEval<Ext, CpuBackend>> = evaluation_claims_1
                .round_evaluations
                .iter()
                .map(|mle| mle.to_host().unwrap())
                .collect();

            let host_evaluation_claims_2: Vec<MleEval<Ext, CpuBackend>> =
                evaluation_claims_2.iter().map(|mle| mle.to_host().unwrap()).collect();

            // One flattened claim vector per commitment round, as passed to the verifier below.
            let flattened_evaluation_claims = vec![
                MleEval::new(
                    host_evaluation_claims_1
                        .into_iter()
                        .flat_map(|x: MleEval<Ext, CpuBackend>| x.evaluations().storage.to_vec())
                        .collect(),
                ),
                MleEval::new(
                    host_evaluation_claims_2
                        .into_iter()
                        .flat_map(|x: MleEval<Ext, CpuBackend>| x.evaluations().storage.to_vec())
                        .collect(),
                ),
            ];

            let evaluation_claims_2 = Evaluations { round_evaluations: evaluation_claims_2 };

            let mut challenger = SP1GlobalContext::default_challenger();

            // Synchronize around the timed section so the measurement covers only the proof.
            scope.synchronize_blocking().unwrap();
            let now = std::time::Instant::now();

            let basefold_proof = old_prover
                .prove_trusted_mle_evaluations(
                    eval_point_host.clone(),
                    vec![interleaved_message, interleaved_message_2].into_iter().collect(),
                    vec![evaluation_claims_1.clone(), evaluation_claims_2.clone()]
                        .into_iter()
                        .collect(),
                    vec![old_preprocessed_prover_data, old_main_prover_data].into_iter().collect(),
                    &mut challenger,
                )
                .unwrap();

            scope.synchronize_blocking().unwrap();
            tracing::info!("Old proof time: {:?}", now.elapsed());

            // Fresh challenger so both provers start from the same transcript state.
            let mut challenger = SP1GlobalContext::default_challenger();

            // The CUDA prover takes all claims as one flat vector, preprocessed round first.
            let flat_evaluation_claims: Vec<Ext> = evaluation_claims_1
                .round_evaluations
                .iter()
                .chain(evaluation_claims_2.round_evaluations.iter())
                .flat_map(|mle_eval| mle_eval.iter().copied())
                .collect();

            scope.synchronize_blocking().unwrap();

            let now = std::time::Instant::now();

            let new_basefold_proof = new_cuda_prover
                .prove_trusted_evaluations_basefold(
                    eval_point_host.clone(),
                    flat_evaluation_claims,
                    &new_traces,
                    [&new_preprocessed_prover_data, &new_main_prover_data].into_iter().collect(),
                    &mut challenger,
                )
                .unwrap();

            scope.synchronize_blocking().unwrap();
            tracing::info!("New proof time: {:?}", now.elapsed());

            // The reference proof must verify.
            verifier
                .verify_mle_evaluations(
                    &[old_preprocessed_commitment, old_main_commitment],
                    eval_point_host.clone(),
                    &flattened_evaluation_claims,
                    &basefold_proof,
                    &mut SP1GlobalContext::default_challenger(),
                )
                .unwrap();

            // The CUDA proof must verify against its own commitments with the same verifier.
            verifier
                .verify_mle_evaluations(
                    &[new_preprocessed_commit, new_main_commit],
                    eval_point_host,
                    &flattened_evaluation_claims,
                    &new_basefold_proof,
                    &mut SP1GlobalContext::default_challenger(),
                )
                .unwrap();
        })
        .unwrap();
    }
}