Skip to main content

sp1_gpu_basefold/
fri.rs

1use itertools::Itertools;
2use std::{marker::PhantomData, sync::Arc};
3
4use slop_algebra::{AbstractExtensionField, AbstractField, ExtensionField, TwoAdicField};
5use slop_alloc::{Buffer, HasBackend};
6use slop_basefold::{BasefoldProof, FriConfig, BATCH_GRINDING_BITS};
7use slop_basefold_prover::{host_fold_even_odd, BasefoldProverError};
8use slop_challenger::{CanObserve, CanSampleBits, FieldChallenger, IopCtx};
9use slop_commit::{Message, Rounds};
10use slop_merkle_tree::MerkleTreeOpeningAndProof;
11use slop_multilinear::{partial_lagrange_blocking, Mle, MultilinearPcsChallenger, Point};
12use slop_tensor::Tensor;
13use sp1_primitives::{SP1ExtensionField, SP1Field};
14
15use sp1_gpu_cudart::{
16    args,
17    sys::{
18        basefold::{
19            batch_koala_bear_base_ext_kernel, batch_koala_bear_base_ext_kernel_flattened,
20            flatten_to_base_koala_bear_base_ext_kernel,
21            transpose_even_odd_koala_bear_base_ext_kernel,
22        },
23        runtime::KernelPtr,
24    },
25    DeviceBuffer, DeviceMle, DeviceTensor, TaskScope,
26};
27use sp1_gpu_merkle_tree::{CudaTcsProver, MerkleTreeProverData, SingleLayerMerkleTreeProverError};
28use sp1_gpu_utils::{Ext, Felt, JaggedTraceMle, TraceDenseData};
29
30use crate::{
31    encode_batch, CudaStackedPcsProverData, DeviceGrindingChallenger, GrindingPowCudaProver,
32    SpparkDftKoalaBear,
33};
34
/// Supplies the CUDA kernel that batches multilinear extensions into a single one.
///
/// In this file the kernel is launched by `FriCudaProver::batch` with the arguments
/// `(input pointer, output pointer, coefficient pointer, output length, batch size)`.
///
/// # Safety
///
/// NOTE(review): the contract is not spelled out in the source. Presumably implementors must
/// return a kernel whose parameter list matches the arguments passed at the launch site —
/// confirm and document.
pub unsafe trait MleBatchKernel<F: TwoAdicField, EF: ExtensionField<F>> {
    /// Returns the pointer of the kernel to launch.
    fn batch_mle_kernel() -> KernelPtr;
}
40
/// Supplies the CUDA kernel that accumulates codewords into a single batched codeword.
///
/// In this file the kernel is launched by `FriCudaProver::batch`, once per codeword tensor,
/// with the arguments
/// `(codeword pointer, output pointer, coefficient pointer, codeword size, batch size)`.
///
/// # Safety
///
/// NOTE(review): the contract is not spelled out in the source. Presumably implementors must
/// return a kernel whose parameter list matches the arguments passed at the launch site —
/// confirm and document.
pub unsafe trait RsCodeWordBatchKernel<F: TwoAdicField, EF: ExtensionField<F>> {
    /// Returns the pointer of the kernel to launch.
    fn batch_rs_codeword_kernel() -> KernelPtr;
}
46
/// Supplies the CUDA kernel that rearranges a codeword into the leaves committed in a FRI
/// round.
///
/// In this file the kernel is launched by `FriCudaProver::commit_phase_round` with the
/// arguments `(codeword pointer, leaves pointer, output codeword size)`.
///
/// # Safety
///
/// NOTE(review): the contract is not spelled out in the source. Presumably implementors must
/// return a kernel whose parameter list matches the arguments passed at the launch site —
/// confirm and document.
pub unsafe trait RsCodeWordTransposeKernel<F: TwoAdicField, EF: ExtensionField<F>> {
    /// Returns the pointer of the kernel to launch.
    fn transpose_even_odd_kernel() -> KernelPtr;
}
51
/// Supplies the CUDA kernel that flattens an extension-field multilinear into base-field
/// rows.
///
/// In this file the kernel is launched by `FriCudaProver::commit_phase_round` with the
/// arguments `(folded MLE pointer, flattened output pointer, folded height)`.
///
/// # Safety
///
/// NOTE(review): the contract is not spelled out in the source. Presumably implementors must
/// return a kernel whose parameter list matches the arguments passed at the launch site —
/// confirm and document.
pub unsafe trait MleFlattenKernel<F: TwoAdicField, EF: ExtensionField<F>> {
    /// Returns the pointer of the kernel to launch.
    fn flatten_to_base_kernel() -> KernelPtr;
}
56
/// GPU prover for the BaseFold/FRI protocol implemented in this file.
///
/// `GC` is the IOP context, `P` the tensor commitment prover and `F` the base field of the
/// FRI configuration.
pub struct FriCudaProver<GC, P, F> {
    /// Prover used to commit to tensors and to open them at query indices.
    pub tcs_prover: P,
    /// FRI parameters; this file reads `log_blowup`, `num_queries` and
    /// `proof_of_work_bits` from it.
    pub config: FriConfig<F>,
    /// Log2 of the stacking height; also used as the number of variables of the batched
    /// multilinear.
    pub log_height: u32,
    // Ties the prover to the IOP context `GC` without storing a value of that type.
    _marker: PhantomData<GC>,
}
63
64impl<GC: IopCtx<F = Felt, EF = Ext>, P> FriCudaProver<GC, P, GC::F>
65where
66    GC::F: TwoAdicField,
67    GC::EF: ExtensionField<GC::F> + TwoAdicField,
68    P: CudaTcsProver<GC>,
69
70    TaskScope: MleBatchKernel<GC::F, GC::EF>
71        + RsCodeWordBatchKernel<GC::F, GC::EF>
72        + RsCodeWordTransposeKernel<GC::F, GC::EF>
73        + MleFlattenKernel<GC::F, GC::EF>,
74{
75    pub fn new(tcs_prover: P, config: FriConfig<GC::F>, log_height: u32) -> Self {
76        Self { tcs_prover, config, log_height, _marker: PhantomData }
77    }
78    pub fn encode_and_commit(
79        &self,
80        use_preprocessed: bool,
81        drop_traces: bool,
82        jagged_trace_mle: &JaggedTraceMle<Felt, TaskScope>,
83        mut dst: Tensor<Felt, TaskScope>,
84    ) -> Result<
85        (<GC as IopCtx>::Digest, CudaStackedPcsProverData<GC>),
86        SingleLayerMerkleTreeProverError,
87    > {
88        let encoder = SpparkDftKoalaBear::default();
89
90        unsafe {
91            dst.assume_init();
92        }
93
94        let virtual_tensor = if use_preprocessed {
95            jagged_trace_mle.preprocessed_virtual_tensor(self.log_height)
96        } else {
97            jagged_trace_mle.main_virtual_tensor(self.log_height)
98        };
99
100        encode_batch(encoder, self.config.log_blowup as u32, virtual_tensor, &mut dst).unwrap();
101
102        // Commit to the tensors.
103
104        let (commitment, tcs_data) = self.tcs_prover.commit_tensors(&dst)?;
105
106        let codeword_mle = if drop_traces { None } else { Some(Arc::new(dst)) };
107        let prover_data = CudaStackedPcsProverData { merkle_tree_tcs_data: tcs_data, codeword_mle };
108
109        Ok((commitment, prover_data))
110    }
111
112    #[allow(clippy::type_complexity)]
113    pub fn batch(
114        &self,
115        batching_coefficients: &Tensor<GC::EF>,
116        mles: &TraceDenseData<GC::F, TaskScope>,
117        codewords: Message<Tensor<Felt, TaskScope>>,
118        evaluation_claims: Vec<GC::EF>,
119    ) -> (Mle<GC::EF, TaskScope>, Tensor<GC::F, TaskScope>, GC::EF) {
120        let log_stacking_height = self.log_height;
121        // Compute all the batch challenge powers.
122        let total_num_polynomials = codewords.iter().map(|c| c.sizes()[0]).sum::<usize>();
123
124        // Compute the random linear combination of the MLEs of the columns of the matrices
125        let num_variables = log_stacking_height;
126        let codeword_size = (codewords.first().unwrap()).sizes()[1];
127        let scope: TaskScope = mles.backend().clone();
128        let mut batch_mle =
129            Mle::new(Tensor::<GC::EF, TaskScope>::zeros_in([1, 1 << num_variables], scope.clone()));
130        let mut batch_codeword = Tensor::<GC::F, TaskScope>::zeros_in(
131            [<GC::EF as AbstractExtensionField<GC::F>>::D, codeword_size],
132            scope.clone(),
133        );
134
135        unsafe {
136            let block_dim = 256;
137            let grid_dim = (1usize << num_variables).div_ceil(block_dim);
138            let batch_size = total_num_polynomials;
139            let powers_device = DeviceBuffer::from_host(batching_coefficients.as_buffer(), &scope)
140                .unwrap()
141                .into_inner();
142            let mle_args = args!(
143                mles.dense.as_ptr(),
144                batch_mle.guts_mut().as_mut_ptr(),
145                powers_device.as_ptr(),
146                (1 << num_variables) as usize,
147                batch_size
148            );
149            scope
150                .launch_kernel(TaskScope::batch_mle_kernel(), grid_dim, block_dim, &mle_args, 0)
151                .unwrap();
152        }
153
154        let mut batch_coefficients = batching_coefficients.as_buffer().to_vec();
155        for codeword in codewords.iter() {
156            let batch_size = codeword.sizes()[0];
157            let mut powers = batch_coefficients;
158            batch_coefficients = powers.split_off(batch_size);
159            let powers_device = DeviceBuffer::from_host(&Buffer::from(powers.clone()), &scope)
160                .unwrap()
161                .into_inner();
162
163            let block_dim = 256;
164            let grid_dim = codeword_size.div_ceil(block_dim);
165            let codeword_args = args!(
166                codeword.as_ptr(),
167                batch_codeword.as_mut_ptr(),
168                powers_device.as_ptr(),
169                codeword_size,
170                batch_size
171            );
172            unsafe {
173                scope
174                    .launch_kernel(
175                        TaskScope::batch_rs_codeword_kernel(),
176                        grid_dim,
177                        block_dim,
178                        &codeword_args,
179                        0,
180                    )
181                    .unwrap();
182            }
183        }
184
185        // Compute the batched evaluation claim.
186        let batch_eval_claim = evaluation_claims
187            .into_iter()
188            .zip(batching_coefficients.as_slice())
189            .map(|(eval, coeff)| eval * *coeff)
190            .sum::<GC::EF>();
191
192        (batch_mle, batch_codeword, batch_eval_claim)
193    }
194
    /// Performs a single round of the FRI commit phase.
    ///
    /// The round rearranges `current_codeword` into `leaves`, commits to them, observes the
    /// commitment, samples the folding challenge `beta`, folds `current_mle` by `beta` and
    /// produces the codeword for the next round.
    ///
    /// Returns `(beta, folded_mle, folded_codeword, commit, leaves, prover_data)`.
    ///
    /// # Errors
    ///
    /// Propagates the error returned by `commit_tensors`.
    ///
    /// # Panics
    ///
    /// Panics if a kernel launch, a host/device copy or `encode_batch` fails.
    #[allow(clippy::type_complexity)]
    fn commit_phase_round(
        &self,
        current_mle: Mle<GC::EF, TaskScope>,
        current_codeword: Tensor<GC::F, TaskScope>,
        challenger: &mut GC::Challenger,
    ) -> Result<
        (
            GC::EF,
            Mle<GC::EF, TaskScope>,
            Tensor<GC::F, TaskScope>,
            GC::Digest,
            Tensor<GC::F, TaskScope>,
            MerkleTreeProverData<GC::Digest>,
        ),
        SingleLayerMerkleTreeProverError,
    > {
        // Perform a single round of the FRI commit phase, returning the commitment, folded
        // codeword, and folding parameter.
        // On CPU, the current codeword is in row-major form, which means that in order to put
        // even and odd entries together all we need to do is reshape it to multiply the number of
        // columns by 2 and divide the number of rows by 2.
        let codeword_size = current_codeword.sizes()[1];
        let batch_size = current_codeword.sizes()[0];
        let scope = current_codeword.backend().clone();

        // The leaves have twice as many rows and half as many columns as the codeword.
        let mut leaves = Tensor::with_sizes_in([batch_size * 2, codeword_size / 2], scope.clone());
        let output_codeword_size = codeword_size / 2;
        let block_dim = 256;
        let grid_dim = output_codeword_size.div_ceil(block_dim);
        unsafe {
            let args = args!(current_codeword.as_ptr(), leaves.as_mut_ptr(), output_codeword_size);
            // NOTE(review): presumably the kernel below writes every element of `leaves`,
            // which is what makes treating it as initialized acceptable — confirm.
            leaves.assume_init();
            scope
                .launch_kernel(
                    TaskScope::transpose_even_odd_kernel(),
                    grid_dim,
                    block_dim,
                    &args,
                    0,
                )
                .unwrap();
        }

        let (commit, prover_data) = self.tcs_prover.commit_tensors(&leaves)?;
        // Observe the commitment.
        challenger.observe(commit);

        // The folding challenge is sampled only after the commitment has been observed.
        let beta: GC::EF = challenger.sample_ext_element();

        // Fold the mle.
        let folded_mle: Mle<_, TaskScope> = {
            let device_mle = DeviceMle::from(current_mle);
            device_mle.fold(beta).into()
        };
        let folded_num_variables = folded_mle.num_variables();

        // Small instances: fold the codeword on the host instead of re-encoding on the device.
        if folded_num_variables < 4 {
            let current_codeword_transposed =
                DeviceTensor::from_raw(current_codeword.clone()).transpose();
            let current_codeword_vec = current_codeword_transposed.to_host().unwrap();
            let current_codeword_vec =
                current_codeword_vec.into_buffer().into_extension::<GC::EF>().into_vec();
            let folded_codeword_vec = host_fold_even_odd(current_codeword_vec, beta);
            let folded_codeword_storage =
                Buffer::from(folded_codeword_vec).flatten_to_base::<GC::F>();
            // Folding halves the codeword length.
            let mut new_size = current_codeword.sizes().to_vec();
            new_size[1] /= 2;
            let folded_codeword =
                DeviceBuffer::from_host(&folded_codeword_storage, folded_mle.backend())
                    .unwrap()
                    .into_inner();
            // Upload with the two dimensions swapped, then transpose back on the device.
            let folded_codeword = Tensor::from(folded_codeword).reshape([new_size[1], new_size[0]]);
            let folded_codeword = DeviceTensor::from_raw(folded_codeword).transpose().into_inner();
            return Ok((beta, folded_mle, folded_codeword, commit, leaves, prover_data));
        }

        // Otherwise flatten the folded multilinear to base-field rows and re-encode it.
        let folded_height = 1 << folded_num_variables;
        let mut folded_mle_flattened = Tensor::<GC::F, TaskScope>::with_sizes_in(
            [<GC::EF as AbstractExtensionField<GC::F>>::D, folded_height],
            scope.clone(),
        );

        let mut folded_codeword = Tensor::<GC::F, TaskScope>::zeros_in(
            [<GC::EF as AbstractExtensionField<GC::F>>::D, folded_height << self.config.log_blowup],
            scope.clone(),
        );

        let block_dim = 256;
        let grid_dim = folded_height.div_ceil(block_dim);
        unsafe {
            let args =
                args!(folded_mle.guts().as_ptr(), folded_mle_flattened.as_mut_ptr(), folded_height);
            // NOTE(review): presumably the kernel below writes every element of
            // `folded_mle_flattened` — confirm.
            folded_mle_flattened.assume_init();
            scope
                .launch_kernel(TaskScope::flatten_to_base_kernel(), grid_dim, block_dim, &args, 0)
                .unwrap();
        }
        let encoder = SpparkDftKoalaBear::default();
        encode_batch(
            encoder,
            self.config.log_blowup as u32,
            folded_mle_flattened.as_view(),
            &mut folded_codeword,
        )
        .unwrap();

        Ok((beta, folded_mle, folded_codeword, commit, leaves, prover_data))
    }
304
305    fn final_poly(&self, final_codeword: Tensor<GC::F, TaskScope>) -> GC::EF {
306        let final_codeword_host = DeviceTensor::from_raw(final_codeword).to_host().unwrap();
307        let final_codeword_transposed = final_codeword_host.transpose();
308        GC::EF::from_base_slice(
309            &final_codeword_transposed.storage.as_slice()
310                [0..(<GC::EF as AbstractExtensionField<GC::F>>::D)],
311        )
312    }
313
    /// Produces a BaseFold proof for the given evaluation claims at `eval_point`.
    ///
    /// * `eval_point` - evaluation point; its dimension must equal the number of variables of
    ///   the batched multilinear.
    /// * `evaluation_claims` - flat list of claimed evaluations.
    /// * `mles` - the jagged trace the claims refer to.
    /// * `prover_data` - one entry per commitment round; a missing codeword is re-encoded
    ///   from the main trace.
    /// * `challenger` - Fiat–Shamir transcript, mutated throughout.
    ///
    /// # Errors
    ///
    /// Returns `CommitPhaseError` when a commit-phase round fails and `TcsCommitError` when
    /// an opening proof cannot be produced.
    ///
    /// # Panics
    ///
    /// Panics if the evaluation point dimension does not match the batched multilinear, or if
    /// an encode or device-to-host copy fails.
    #[inline]
    pub fn prove_trusted_evaluations_basefold(
        &self,
        mut eval_point: Point<GC::EF>,
        evaluation_claims: Vec<GC::EF>,
        mles: &JaggedTraceMle<GC::F, TaskScope>,
        prover_data: Rounds<&CudaStackedPcsProverData<GC>>,
        challenger: &mut GC::Challenger,
    ) -> Result<BasefoldProof<GC>, BasefoldProverError<SingleLayerMerkleTreeProverError>>
    where
        GC::Challenger: DeviceGrindingChallenger<Witness = GC::F>,
    {
        let scope = mles.dense().dense.backend().clone();
        // Collect one codeword per round, re-encoding any that was dropped at commit time.
        let mut codewords: Vec<Arc<Tensor<Felt, TaskScope>>> = Vec::new();
        for data in prover_data.iter() {
            if let Some(ref codeword) = data.codeword_mle {
                codewords.push(codeword.clone());
            } else {
                // Codeword was dropped - this is always a main trace.
                let mut dst = Tensor::<Felt, TaskScope>::with_sizes_in(
                    [
                        mles.dense().main_size() >> self.log_height,
                        1 << (self.log_height as usize + self.config.log_blowup()),
                    ],
                    scope.clone(),
                );
                // NOTE(review): presumably `dst` is fully overwritten by the encoder below —
                // confirm.
                unsafe {
                    dst.assume_init();
                }

                let encoder = SpparkDftKoalaBear::default();
                encode_batch(
                    encoder,
                    self.config.log_blowup as u32,
                    mles.main_virtual_tensor(self.log_height),
                    &mut dst,
                )
                .unwrap();

                codewords.push(Arc::new(dst));
            }
        }

        // The batching point has enough variables to index every polynomial.
        let total_num_polynomials = codewords.iter().map(|c| c.sizes()[0]).sum::<usize>();
        let num_batching_variables = total_num_polynomials.next_power_of_two().ilog2();

        let encoded_messages: Message<_> = codewords.iter().cloned().collect();

        // Grind for batch randomness.
        let batch_grinding_witness =
            GrindingPowCudaProver::grind(challenger, BATCH_GRINDING_BITS, &scope);

        let batching_point = challenger.sample_point::<GC::EF>(num_batching_variables);
        let batching_coefficients = partial_lagrange_blocking(&batching_point);

        // Batch the mles and codewords.
        let (mle_batch, codeword_batch, batched_eval_claim) =
            self.batch(&batching_coefficients, mles.dense(), encoded_messages, evaluation_claims);
        // From this point on, run the BaseFold protocol on the random linear combination codeword,
        // the random linear combination multilinear, and the random linear combination of the
        // evaluation claims.
        let mut current_mle = mle_batch;
        let mut current_codeword = codeword_batch;
        // Initialize the vecs that go into a BaseFoldProof.
        let log_len = current_mle.num_variables();
        let mut univariate_messages: Vec<[GC::EF; 2]> = vec![];
        let mut fri_commitments = vec![];
        let mut commit_phase_data = vec![];
        let mut current_batched_eval_claim = batched_eval_claim;
        let mut commit_phase_values = vec![];

        assert_eq!(
            current_mle.num_variables(),
            eval_point.dimension() as u32,
            "eval point dimension mismatch"
        );

        challenger.observe(Felt::from_canonical_usize(eval_point.dimension()));
        // One round per coordinate of the evaluation point; each round removes the last one.
        for _ in 0..eval_point.dimension() {
            // Compute claims for `g(X_0, X_1, ..., X_{d-1}, 0)` and `g(X_0, X_1, ..., X_{d-1}, 1)`.
            let last_coord = eval_point.remove_last_coordinate();
            let zero_values = {
                use sp1_gpu_cudart::DeviceMle;
                let device_mle = DeviceMle::from(current_mle.clone());
                let evals = device_mle.fixed_at_zero(&eval_point);
                evals.to_host_vec().unwrap()
            };
            let zero_val = zero_values[0];
            // Derive the value at one from the running claim and the value at zero.
            let one_val = (current_batched_eval_claim - zero_val) / last_coord + zero_val;
            let uni_poly = [zero_val, one_val];
            univariate_messages.push(uni_poly);

            uni_poly.iter().for_each(|elem| challenger.observe_ext_element(*elem));

            // Perform a single round of the FRI commit phase, returning the commitment, folded
            // codeword, and folding parameter.
            let (beta, folded_mle, folded_codeword, commitment, leaves, prover_data) = self
                .commit_phase_round(current_mle, current_codeword, challenger)
                .map_err(BasefoldProverError::CommitPhaseError)?;

            fri_commitments.push(commitment);
            commit_phase_data.push(prover_data);
            commit_phase_values.push(leaves);

            current_mle = folded_mle;
            current_codeword = folded_codeword;
            // The next round's claim is the round's two values combined with `beta`.
            current_batched_eval_claim = zero_val + beta * one_val;
        }

        let final_poly = self.final_poly(current_codeword);
        challenger.observe_ext_element(final_poly);

        let fri_config = self.config;
        let pow_bits = fri_config.proof_of_work_bits;
        let pow_witness = GrindingPowCudaProver::grind(challenger, pow_bits, &scope);
        // FRI Query Phase.
        let query_indices: Vec<usize> = (0..fri_config.num_queries)
            .map(|_| challenger.sample_bits(log_len as usize + fri_config.log_blowup()))
            .collect();

        // Open the original polynomials at the query indices.
        let mut component_polynomials_query_openings_and_proofs = vec![];
        for (data, codeword) in prover_data.iter().zip(codewords.iter()) {
            let values = self.tcs_prover.compute_openings_at_indices(codeword, &query_indices);
            let proof = self
                .tcs_prover
                .prove_openings_at_indices(&data.merkle_tree_tcs_data, &query_indices)
                .map_err(BasefoldProverError::TcsCommitError)?;
            let opening = MerkleTreeOpeningAndProof::<GC> { values, proof };
            component_polynomials_query_openings_and_proofs.push(opening);
        }

        // Provide openings for the FRI query phase.
        let mut query_phase_openings_and_proofs = vec![];
        let mut indices = query_indices;
        for (leaves, data) in commit_phase_values.into_iter().zip_eq(commit_phase_data) {
            // Each round's leaves are indexed by the previous round's indices shifted right
            // by one bit.
            for index in indices.iter_mut() {
                *index >>= 1;
            }
            let values = self.tcs_prover.compute_openings_at_indices(&leaves, &indices);

            let proof = self
                .tcs_prover
                .prove_openings_at_indices(&data, &indices)
                .map_err(BasefoldProverError::TcsCommitError)?;
            let opening = MerkleTreeOpeningAndProof { values, proof };
            query_phase_openings_and_proofs.push(opening);
        }

        Ok(BasefoldProof {
            univariate_messages,
            fri_commitments,
            component_polynomials_query_openings_and_proofs,
            query_phase_openings_and_proofs,
            final_poly,
            pow_witness,
            batch_grinding_witness,
        })
    }
473}
474
/// Binds the MLE batching kernel for the SP1 field pair to
/// `batch_koala_bear_base_ext_kernel`.
unsafe impl MleBatchKernel<SP1Field, SP1ExtensionField> for TaskScope {
    fn batch_mle_kernel() -> KernelPtr {
        // NOTE(review): assumes the accessor has no preconditions and only returns the
        // kernel's address — confirm against the `sp1_gpu_cudart::sys` bindings.
        unsafe { batch_koala_bear_base_ext_kernel() }
    }
}
480
/// Binds the codeword batching kernel for the SP1 field pair to
/// `batch_koala_bear_base_ext_kernel_flattened`.
unsafe impl RsCodeWordBatchKernel<SP1Field, SP1ExtensionField> for TaskScope {
    fn batch_rs_codeword_kernel() -> KernelPtr {
        // NOTE(review): assumes the accessor has no preconditions and only returns the
        // kernel's address — confirm against the `sp1_gpu_cudart::sys` bindings.
        unsafe { batch_koala_bear_base_ext_kernel_flattened() }
    }
}
486
/// Binds the even/odd transpose kernel for the SP1 field pair to
/// `transpose_even_odd_koala_bear_base_ext_kernel`.
unsafe impl RsCodeWordTransposeKernel<SP1Field, SP1ExtensionField> for TaskScope {
    fn transpose_even_odd_kernel() -> KernelPtr {
        // NOTE(review): assumes the accessor has no preconditions and only returns the
        // kernel's address — confirm against the `sp1_gpu_cudart::sys` bindings.
        unsafe { transpose_even_odd_koala_bear_base_ext_kernel() }
    }
}
492
/// Binds the flatten-to-base kernel for the SP1 field pair to
/// `flatten_to_base_koala_bear_base_ext_kernel`.
unsafe impl MleFlattenKernel<SP1Field, SP1ExtensionField> for TaskScope {
    fn flatten_to_base_kernel() -> KernelPtr {
        // NOTE(review): assumes the accessor has no preconditions and only returns the
        // kernel's address — confirm against the `sp1_gpu_cudart::sys` bindings.
        unsafe { flatten_to_base_koala_bear_base_ext_kernel() }
    }
}
498
499#[cfg(test)]
500mod tests {
501    use std::sync::Arc;
502
503    use slop_alloc::{CpuBackend, ToHost};
504    use slop_basefold::BasefoldVerifier;
505    use slop_basefold_prover::BasefoldProver;
506    use slop_commit::Message;
507    use slop_futures::queue::WorkerQueue;
508    use slop_merkle_tree::Poseidon2KoalaBear16Prover;
509    use slop_multilinear::{Evaluations, Mle, MleEval};
510    use slop_stacked::interleave_multilinears_with_fixed_rate;
511    use sp1_gpu_cudart::{run_sync_in_place, PinnedBuffer};
512    use sp1_gpu_merkle_tree::{CudaTcsProver, Poseidon2SP1Field16CudaProver};
513    use sp1_gpu_tracegen::CudaTraceGenerator;
514    use sp1_hypercube::prover::{ProverSemaphore, TraceGenerator};
515
516    use sp1_core_machine::io::SP1Stdin;
517    use sp1_gpu_jagged_tracegen::test_utils::tracegen_setup::{
518        self, CORE_MAX_LOG_ROW_COUNT, LOG_STACKING_HEIGHT,
519    };
520    use sp1_gpu_jagged_tracegen::{full_tracegen, CORE_MAX_TRACE_SIZE};
521    use sp1_gpu_utils::{Ext, Felt, TestGC};
522    use sp1_primitives::fri_params::core_fri_config;
523    use sp1_primitives::SP1GlobalContext;
524
525    use super::*;
526
    /// End-to-end check of the CUDA prover against the reference `BasefoldProver`.
    ///
    /// The test commits to the preprocessed and main traces with both provers and asserts
    /// that the commitments agree, then produces a proof with each prover and verifies both
    /// proofs with the same verifier.
    #[test]
    fn test_basefold() {
        let rt = tokio::runtime::Runtime::new().unwrap();
        let (machine, record, program) =
            rt.block_on(tracegen_setup::setup(&test_artifacts::FIBONACCI_ELF, SP1Stdin::new()));

        run_sync_in_place(|scope| {
            let verifier = BasefoldVerifier::<SP1GlobalContext>::new(core_fri_config(), 2);
            let old_prover =
                BasefoldProver::<SP1GlobalContext, Poseidon2KoalaBear16Prover>::new(&verifier);

            let new_cuda_prover = FriCudaProver::<TestGC, _, Felt> {
                tcs_prover: Poseidon2SP1Field16CudaProver::new(&scope),
                config: verifier.fri_config,
                log_height: LOG_STACKING_HEIGHT,
                _marker: PhantomData::<TestGC>,
            };

            // Generate traces using the host tracegen.
            let semaphore = ProverSemaphore::new(1);
            let trace_generator = CudaTraceGenerator::new_in(machine.clone(), scope.clone());
            let old_traces = rt.block_on(trace_generator.generate_traces(
                program.clone(),
                record.clone(),
                CORE_MAX_LOG_ROW_COUNT as usize,
                semaphore.clone(),
            ));

            let preprocessed_traces = old_traces.preprocessed_traces.clone();

            let message = preprocessed_traces
                .into_iter()
                .filter_map(|mle| mle.1.into_inner())
                .map(|x| Clone::clone(x.as_ref()))
                .collect::<Message<Mle<_, _>>>();

            // Copy the preprocessed traces to the host for the reference prover.
            let host_message: Message<_> = message
                .clone()
                .into_iter()
                .map(|mle| {
                    let mle = Arc::unwrap_or_clone(mle);
                    let guts = mle.into_guts();
                    let device_mle = sp1_gpu_cudart::DeviceMle::from(guts);
                    device_mle.to_host().unwrap()
                })
                .collect();

            let interleaved_message =
                interleave_multilinears_with_fixed_rate(32, host_message, LOG_STACKING_HEIGHT);

            let interleaved_message =
                interleaved_message.into_iter().map(|x| x.as_ref().clone()).collect::<Message<_>>();

            let (old_preprocessed_commitment, old_preprocessed_prover_data) =
                old_prover.commit_mles(interleaved_message.clone()).unwrap();

            // Generate the jagged traces consumed by the CUDA prover.
            let new_semaphore = ProverSemaphore::new(1);
            let capacity = CORE_MAX_TRACE_SIZE as usize;
            let buffer = PinnedBuffer::<Felt>::with_capacity(capacity);
            let queue = Arc::new(WorkerQueue::new(vec![buffer]));
            let buffer = rt.block_on(queue.pop()).unwrap();
            let (_, new_traces, _, _) = rt.block_on(full_tracegen(
                &machine,
                program,
                Arc::new(record),
                &buffer,
                CORE_MAX_TRACE_SIZE as usize,
                LOG_STACKING_HEIGHT,
                CORE_MAX_LOG_ROW_COUNT,
                &scope,
                new_semaphore,
                false,
            ));

            // Destination tensor for the encoded preprocessed trace.
            let dst = Tensor::<Felt, TaskScope>::with_sizes_in(
                [
                    new_traces.0.dense().preprocessed_offset >> LOG_STACKING_HEIGHT,
                    1 << (LOG_STACKING_HEIGHT as usize + verifier.fri_config.log_blowup()),
                ],
                scope.clone(),
            );

            let (new_preprocessed_commit, new_preprocessed_prover_data) =
                new_cuda_prover.encode_and_commit(true, false, &new_traces, dst).unwrap();

            assert_eq!(new_preprocessed_commit, old_preprocessed_commitment);

            // Destination tensor for the encoded main trace.
            let dst = Tensor::<Felt, TaskScope>::with_sizes_in(
                [
                    new_traces.0.dense().main_size() >> LOG_STACKING_HEIGHT,
                    1 << (LOG_STACKING_HEIGHT as usize + verifier.fri_config.log_blowup()),
                ],
                scope.clone(),
            );

            let (new_main_commit, new_main_prover_data) =
                new_cuda_prover.encode_and_commit(false, false, &new_traces, dst).unwrap();
            let message = old_traces
                .main_trace_data
                .traces
                .into_iter()
                .filter_map(|mle| mle.1.into_inner())
                .map(|x| Clone::clone(x.as_ref()))
                .collect::<Message<Mle<_, _>>>();

            // Copy the main traces to the host for the reference prover.
            let mut host_message = Vec::new();
            for mle in message.into_iter() {
                let mle = Arc::unwrap_or_clone(mle);
                let guts = mle.into_guts();
                let device_mle = sp1_gpu_cudart::DeviceMle::from(guts);
                let mle_host = device_mle.to_host().unwrap();
                host_message.push(mle_host);
            }

            let host_message = host_message.into_iter().collect::<Message<Mle<Felt, CpuBackend>>>();

            let interleaved_message_2 =
                interleave_multilinears_with_fixed_rate(32, host_message, LOG_STACKING_HEIGHT);

            let (old_main_commitment, old_main_prover_data) =
                old_prover.commit_mles(interleaved_message_2.clone()).unwrap();

            assert_eq!(new_main_commit, old_main_commitment);

            let mut rng = rand::thread_rng();

            let eval_point_host = Point::<Ext>::rand(&mut rng, LOG_STACKING_HEIGHT);

            // Evaluate every interleaved multilinear at the random point to obtain the claims.
            let evaluation_claims_1: Vec<_> = interleaved_message
                .clone()
                .into_iter()
                .map(|mle| mle.eval_at(&eval_point_host))
                .collect();

            let evaluation_claims_1 = Evaluations { round_evaluations: evaluation_claims_1 };

            let evaluation_claims_2: Vec<_> = interleaved_message_2
                .clone()
                .into_iter()
                .map(|mle| mle.eval_at(&eval_point_host))
                .collect();

            let host_evaluation_claims_1: Vec<MleEval<Ext, CpuBackend>> = evaluation_claims_1
                .round_evaluations
                .iter()
                .map(|mle| mle.to_host().unwrap())
                .collect();

            let host_evaluation_claims_2: Vec<MleEval<Ext, CpuBackend>> =
                evaluation_claims_2.iter().map(|mle| mle.to_host().unwrap()).collect();

            // One flattened list of claims per commitment round, as passed to the verifier.
            let flattened_evaluation_claims = vec![
                MleEval::new(
                    host_evaluation_claims_1
                        .into_iter()
                        .flat_map(|x: MleEval<Ext, CpuBackend>| x.evaluations().storage.to_vec())
                        .collect(),
                ),
                MleEval::new(
                    host_evaluation_claims_2
                        .into_iter()
                        .flat_map(|x: MleEval<Ext, CpuBackend>| x.evaluations().storage.to_vec())
                        .collect(),
                ),
            ];

            let evaluation_claims_2 = Evaluations { round_evaluations: evaluation_claims_2 };

            let mut challenger = SP1GlobalContext::default_challenger();

            // Synchronize before and after so the timing covers only the proof itself.
            scope.synchronize_blocking().unwrap();
            let now = std::time::Instant::now();

            let basefold_proof = old_prover
                .prove_trusted_mle_evaluations(
                    eval_point_host.clone(),
                    vec![interleaved_message, interleaved_message_2].into_iter().collect(),
                    vec![evaluation_claims_1.clone(), evaluation_claims_2.clone()]
                        .into_iter()
                        .collect(),
                    vec![old_preprocessed_prover_data, old_main_prover_data].into_iter().collect(),
                    &mut challenger,
                )
                .unwrap();

            scope.synchronize_blocking().unwrap();
            tracing::info!("Old proof time: {:?}", now.elapsed());

            // Start the CUDA prover from a fresh transcript.
            let mut challenger = SP1GlobalContext::default_challenger();

            let flat_evaluation_claims: Vec<Ext> = evaluation_claims_1
                .round_evaluations
                .iter()
                .chain(evaluation_claims_2.round_evaluations.iter())
                .flat_map(|mle_eval| mle_eval.iter().copied())
                .collect();

            scope.synchronize_blocking().unwrap();

            let now = std::time::Instant::now();

            let new_basefold_proof = new_cuda_prover
                .prove_trusted_evaluations_basefold(
                    eval_point_host.clone(),
                    flat_evaluation_claims,
                    &new_traces,
                    [&new_preprocessed_prover_data, &new_main_prover_data].into_iter().collect(),
                    &mut challenger,
                )
                .unwrap();

            scope.synchronize_blocking().unwrap();
            tracing::info!("New proof time: {:?}", now.elapsed());

            // Because the batch grinding is non-deterministic between CPU and GPU, the
            // grinding witnesses may differ, causing all subsequent proof values (batching
            // point, univariate messages, etc.) to diverge. Instead of comparing proof
            // components directly, we verify both proofs independently.

            verifier
                .verify_mle_evaluations(
                    &[old_preprocessed_commitment, old_main_commitment],
                    eval_point_host.clone(),
                    &flattened_evaluation_claims,
                    &basefold_proof,
                    &mut SP1GlobalContext::default_challenger(),
                )
                .unwrap();

            verifier
                .verify_mle_evaluations(
                    &[new_preprocessed_commit, new_main_commit],
                    eval_point_host,
                    &flattened_evaluation_claims,
                    &new_basefold_proof,
                    &mut SP1GlobalContext::default_challenger(),
                )
                .unwrap();
        })
        .unwrap();
    }
768}