Skip to main content

sp1_gpu_shard_prover/
prover.rs

1use crate::{MainTraceData, ShardData};
2use slop_algebra::AbstractField;
3use slop_alloc::{Buffer, HasBackend};
4use slop_challenger::{CanObserve, FieldChallenger, FromChallenger, IopCtx};
5use slop_commit::Rounds;
6use slop_futures::queue::{Worker, WorkerQueue};
7use slop_jagged::{
8    unzip_and_prefix_sums, JaggedLittlePolynomialProverParams, JaggedPcsProof, JaggedProverData,
9    JaggedProverError, PrefixSumsMaxLogRowCount,
10};
11use slop_multilinear::{MleEval, MultilinearPcsVerifier, Point};
12use sp1_gpu_air::air_block::BlockAir;
13use sp1_gpu_air::SymbolicProverFolder;
14use sp1_gpu_basefold::{CudaStackedPcsProverData, DeviceGrindingChallenger, FriCudaProver};
15use sp1_gpu_challenger::FromHostChallengerSync;
16use sp1_gpu_cudart::PinnedBuffer;
17use sp1_gpu_cudart::{DeviceMle, DevicePoint, TaskScope};
18use sp1_gpu_jagged_assist::prove_jagged_evaluation_sync;
19use sp1_gpu_jagged_sumcheck::{generate_jagged_sumcheck_poly, jagged_sumcheck};
20use sp1_gpu_jagged_tracegen::{full_tracegen_permit, main_tracegen_permit, CudaShardProverData};
21use sp1_gpu_logup_gkr::{prove_logup_gkr, CudaLogUpGkrOptions, Interactions};
22use sp1_gpu_merkle_tree::{CudaTcsProver, SingleLayerMerkleTreeProverError};
23use sp1_gpu_tracegen::CudaTracegenAir;
24use sp1_gpu_utils::{Ext, Felt, JaggedTraceMle};
25use sp1_gpu_zerocheck::zerocheck;
26use sp1_gpu_zerocheck::CudaEvalResult;
27use sp1_hypercube::prover::ZerocheckAir;
28use sp1_hypercube::{
29    air::{MachineAir, MachineProgram},
30    prover::{AirProver, PreprocessedData, ProverPermit, ProverSemaphore, ProvingKey},
31    Machine, MachineVerifyingKey, ShardProof,
32};
33use sp1_hypercube::{SP1PcsProof, ShardContextImpl};
34use std::collections::BTreeMap;
35use std::iter::once;
36use std::vec;
37use std::{marker::PhantomData, sync::Arc};
38use thiserror::Error;
39use tokio::sync::Mutex;
40use tracing::Instrument;
41
42pub trait CudaShardProverComponents<GC: IopCtx>: Send + Sync + 'static {
43    type P: CudaTcsProver<GC>;
44    type Air: CudaTracegenAir<GC::F>
45        + ZerocheckAir<Felt, Ext>
46        + for<'a> BlockAir<SymbolicProverFolder<'a>>;
47    type C: MultilinearPcsVerifier<GC> + Send + Sync;
48    /// The device challenger type used for GPU-based challenger operations.
49    type DeviceChallenger: sp1_gpu_jagged_assist::AsMutRawChallenger
50        + FromChallenger<GC::Challenger, TaskScope>
51        + FromHostChallengerSync<GC::Challenger>
52        + Clone
53        + Send
54        + Sync;
55}
56
57pub struct CudaShardProver<GC: IopCtx, PC: CudaShardProverComponents<GC>> {
58    inner: Arc<CudaShardProverInner<GC, PC>>,
59}
60
61impl<GC: IopCtx, PC: CudaShardProverComponents<GC>> Clone for CudaShardProver<GC, PC> {
62    fn clone(&self) -> Self {
63        Self { inner: self.inner.clone() }
64    }
65}
66
67impl<GC: IopCtx, PC: CudaShardProverComponents<GC>> CudaShardProver<GC, PC> {
68    #[allow(clippy::too_many_arguments)]
69    pub fn new(
70        trace_buffers: Arc<WorkerQueue<PinnedBuffer<GC::F>>>,
71        max_log_row_count: u32,
72        basefold_prover: FriCudaProver<GC, PC::P, GC::F>,
73        machine: Machine<GC::F, PC::Air>,
74        max_trace_size: usize,
75        backend: TaskScope,
76        all_interactions: BTreeMap<String, Arc<Interactions<GC::F, TaskScope>>>,
77        all_zerocheck_programs: BTreeMap<String, CudaEvalResult>,
78        recompute_first_layer: bool,
79        drop_ldes: bool,
80    ) -> Self {
81        Self {
82            inner: Arc::new(CudaShardProverInner {
83                trace_buffers,
84                max_log_row_count,
85                basefold_prover,
86                machine,
87                max_trace_size,
88                backend,
89                all_interactions,
90                all_zerocheck_programs,
91                recompute_first_layer,
92                drop_ldes,
93                _marker: PhantomData,
94            }),
95        }
96    }
97}
98
99impl<GC: IopCtx<F = Felt, EF = Ext>, PC: CudaShardProverComponents<GC>> CudaShardProver<GC, PC> {
100    /// Prove trusted evaluations for a shard.
101    #[allow(clippy::type_complexity)]
102    pub fn prove_trusted_evaluations(
103        &self,
104        eval_point: Point<Ext>,
105        evaluation_claims: Rounds<MleEval<Ext, TaskScope>>,
106        all_mles: &JaggedTraceMle<Felt, TaskScope>,
107        prover_data: Rounds<&JaggedProverData<GC, CudaStackedPcsProverData<GC>>>,
108        challenger: &mut GC::Challenger,
109    ) -> Result<
110        JaggedPcsProof<GC, <PC::C as MultilinearPcsVerifier<GC>>::Proof>,
111        JaggedProverError<CudaShardProverError>,
112    >
113    where
114        GC::Challenger: DeviceGrindingChallenger<Witness = GC::F>,
115        GC::Challenger: slop_challenger::FieldChallenger<
116            <GC::Challenger as slop_challenger::GrindingChallenger>::Witness,
117        >,
118        SP1PcsProof<GC>: Into<<PC::C as MultilinearPcsVerifier<GC>>::Proof>,
119        TaskScope:
120            sp1_gpu_jagged_assist::BranchingProgramKernel<GC::F, GC::EF, PC::DeviceChallenger>,
121    {
122        self.inner.prove_trusted_evaluations(
123            eval_point,
124            evaluation_claims,
125            all_mles,
126            prover_data,
127            challenger,
128        )
129    }
130}
131
132/// A prover for the hypercube STARK, given a configuration.
133pub(crate) struct CudaShardProverInner<GC: IopCtx, PC: CudaShardProverComponents<GC>> {
134    #[allow(clippy::type_complexity)]
135    pub trace_buffers: Arc<WorkerQueue<PinnedBuffer<GC::F>>>,
136    pub max_log_row_count: u32,
137    pub basefold_prover: FriCudaProver<GC, PC::P, GC::F>,
138    pub machine: Machine<GC::F, PC::Air>,
139    pub max_trace_size: usize,
140    pub backend: TaskScope,
141    pub all_interactions: BTreeMap<String, Arc<Interactions<GC::F, TaskScope>>>,
142    pub all_zerocheck_programs: BTreeMap<String, CudaEvalResult>,
143    pub recompute_first_layer: bool,
144    pub drop_ldes: bool,
145    pub _marker: PhantomData<GC>,
146}
147
148impl<GC: IopCtx<F = Felt, EF = Ext>, PC: CudaShardProverComponents<GC>>
149    CudaShardProverInner<GC, PC>
150{
151    pub async fn get_buffer(&self) -> Worker<PinnedBuffer<GC::F>> {
152        self.trace_buffers.clone().pop().await.expect("buffer pool exhausted")
153    }
154
155    fn machine(&self) -> &Machine<GC::F, PC::Air> {
156        &self.machine
157    }
158}
159
160impl<GC: IopCtx<F = Felt, EF = Ext>, PC: CudaShardProverComponents<GC>>
161    AirProver<GC, ShardContextImpl<GC, PC::C, PC::Air>> for CudaShardProver<GC, PC>
162where
163    GC::Challenger: DeviceGrindingChallenger<Witness = GC::F>,
164    GC::Challenger: slop_challenger::FieldChallenger<
165        <GC::Challenger as slop_challenger::GrindingChallenger>::Witness,
166    >,
167    SP1PcsProof<GC>: Into<<PC::C as MultilinearPcsVerifier<GC>>::Proof>,
168    TaskScope: sp1_gpu_jagged_assist::BranchingProgramKernel<GC::F, GC::EF, PC::DeviceChallenger>,
169{
170    type PreprocessedData = Mutex<CudaShardProverData<GC, PC::Air>>;
171
172    fn machine(&self) -> &Machine<GC::F, PC::Air> {
173        &self.inner.machine
174    }
175
176    /// Setup a shard, using a verifying key if provided.
177    async fn setup_from_vk(
178        &self,
179        program: Arc<<PC::Air as MachineAir<GC::F>>::Program>,
180        vk: Option<MachineVerifyingKey<GC>>,
181        prover_permits: ProverSemaphore,
182    ) -> (
183        PreprocessedData<ProvingKey<GC, ShardContextImpl<GC, PC::C, PC::Air>, Self>>,
184        MachineVerifyingKey<GC>,
185    ) {
186        let inner = self.inner.clone();
187        if let Some(vk) = vk {
188            let initial_global_cumulative_sum = vk.initial_global_cumulative_sum;
189            inner
190                .setup_with_initial_global_cumulative_sum(
191                    program,
192                    initial_global_cumulative_sum,
193                    prover_permits,
194                )
195                .await
196        } else {
197            let program_sent = program.clone();
198            let initial_global_cumulative_sum =
199                tokio::task::spawn_blocking(move || program_sent.initial_global_cumulative_sum())
200                    .await
201                    .unwrap();
202            inner
203                .setup_with_initial_global_cumulative_sum(
204                    program,
205                    initial_global_cumulative_sum,
206                    prover_permits,
207                )
208                .await
209        }
210    }
211
212    /// Setup and prove a shard.
213    async fn setup_and_prove_shard(
214        &self,
215        program: Arc<<PC::Air as MachineAir<GC::F>>::Program>,
216        record: <PC::Air as MachineAir<GC::F>>::Record,
217        vk: Option<MachineVerifyingKey<GC>>,
218        prover_permits: ProverSemaphore,
219    ) -> (
220        MachineVerifyingKey<GC>,
221        ShardProof<GC, <PC::C as MultilinearPcsVerifier<GC>>::Proof>,
222        ProverPermit,
223    ) {
224        // Get the initial global cumulative sum and pc start.
225        let pc_start = program.pc_start();
226        let untrusted_config = program.untrusted_config();
227        let initial_global_cumulative_sum = if let Some(vk) = vk {
228            vk.initial_global_cumulative_sum
229        } else {
230            let program = program.clone();
231            tokio::task::spawn_blocking(move || program.initial_global_cumulative_sum())
232                .instrument(tracing::debug_span!("initial_global_cumulative_sum"))
233                .await
234                .unwrap()
235        };
236
237        let buffer = self.inner.get_buffer().await;
238
239        let record = Arc::new(record);
240
241        // Generate trace.
242        let (public_values, trace_data, chip_set, permit) = full_tracegen_permit(
243            self.machine(),
244            program,
245            record,
246            &buffer,
247            self.inner.max_trace_size,
248            self.inner.basefold_prover.log_height,
249            self.inner.max_log_row_count,
250            &self.inner.backend,
251            prover_permits,
252            true,
253        )
254        .instrument(tracing::debug_span!("generate all traces"))
255        .await;
256
257        let inner = self.inner.clone();
258        let (pk, vk) = tokio::task::spawn_blocking({
259            let span = tracing::debug_span!("setup_from_preprocessed_data_and_traces");
260            move || {
261                let _guard = span.enter();
262                inner.setup_from_preprocessed_data_and_traces(
263                    pc_start,
264                    initial_global_cumulative_sum,
265                    trace_data,
266                    untrusted_config,
267                )
268            }
269        })
270        .await
271        .unwrap();
272
273        let trace_data = Mutex::new(pk);
274
275        let pk = ProvingKey { vk: vk.clone(), preprocessed_data: trace_data };
276
277        let pk = Arc::new(pk);
278
279        let main_trace_data =
280            MainTraceData { traces: pk, public_values, shard_chips: chip_set, permit };
281
282        // Create a chanllenger
283        let mut challenger = GC::default_challenger();
284        // Observe the preprocessed information.
285        vk.observe_into(&mut challenger);
286
287        let shard_data = ShardData { main_trace_data };
288
289        let inner = self.inner.clone();
290        let (shard_proof, permit) = tokio::task::spawn_blocking({
291            let span = tracing::debug_span!("prove_shard_with_data");
292            move || {
293                let _guard = span.enter();
294                inner.prove_shard_with_data(shard_data, challenger)
295            }
296        })
297        .await
298        .unwrap();
299
300        // tracing::debug_span!("prove shard with data")
301        //     .in_scope(|| self.prove_shard_with_data(shard_data, challenger));
302        drop(buffer);
303
304        (vk, shard_proof, permit)
305    }
306
307    /// Prove a shard with a given proving key.
308    async fn prove_shard_with_pk(
309        &self,
310        pk: Arc<ProvingKey<GC, ShardContextImpl<GC, PC::C, PC::Air>, Self>>,
311        record: <PC::Air as MachineAir<GC::F>>::Record,
312        prover_permits: ProverSemaphore,
313    ) -> (ShardProof<GC, <PC::C as MultilinearPcsVerifier<GC>>::Proof>, ProverPermit) {
314        // Generate the traces.
315        let record = Arc::new(record);
316
317        let buffer = self.inner.get_buffer().await;
318
319        let (public_values, chip_set, permit) = main_tracegen_permit(
320            &self.inner.machine,
321            record,
322            &pk.preprocessed_data,
323            &buffer,
324            self.inner.basefold_prover.log_height,
325            self.inner.max_log_row_count,
326            &self.inner.backend,
327            prover_permits,
328            true,
329        )
330        .instrument(tracing::debug_span!("generate main traces"))
331        .await;
332
333        let shard_data = ShardData {
334            main_trace_data: MainTraceData {
335                traces: pk.clone(),
336                public_values,
337                shard_chips: chip_set,
338                permit,
339            },
340        };
341
342        let mut challenger = GC::default_challenger();
343        pk.vk.observe_into(&mut challenger);
344
345        let inner = self.inner.clone();
346        let (shard_proof, permit) = tokio::task::spawn_blocking({
347            let span = tracing::debug_span!("prove_shard_with_data");
348            move || {
349                let _guard = span.enter();
350                inner.prove_shard_with_data(shard_data, challenger)
351            }
352        })
353        .await
354        .unwrap();
355
356        drop(buffer);
357
358        (shard_proof, permit)
359    }
360
361    async fn preprocessed_table_heights(
362        pk: Arc<ProvingKey<GC, ShardContextImpl<GC, PC::C, PC::Air>, Self>>,
363    ) -> BTreeMap<String, usize> {
364        // Access through pk.preprocessed_data which is of type CudaShardProverData
365        let preprocessed_data = pk.preprocessed_data.lock().await;
366        preprocessed_data
367            .preprocessed_traces
368            .dense()
369            .preprocessed_table_index
370            .iter()
371            .map(|(name, offset)| (name.clone(), offset.poly_size))
372            .collect()
373    }
374}
375
376// An error type for cuda jagged prover
377#[derive(Debug, Error)]
378pub enum CudaShardProverError {}
379
380impl<GC: IopCtx<F = Felt, EF = Ext>, PC: CudaShardProverComponents<GC>>
381    CudaShardProverInner<GC, PC>
382{
383    /// Commit to a batch of padded multilinears.
384    ///
385    /// The jagged polynomial commitments scheme is able to commit to sparse polynomials having
386    /// very few or no real rows.
387    /// **Note** the padding values will be ignored and treated as though they are zero.
388    #[allow(clippy::type_complexity)]
389    pub fn commit_multilinears(
390        &self,
391        multilinears: &JaggedTraceMle<Felt, TaskScope>,
392        use_preprocessed_data: bool,
393    ) -> Result<
394        (GC::Digest, JaggedProverData<GC, CudaStackedPcsProverData<GC>>),
395        JaggedProverError<SingleLayerMerkleTreeProverError>,
396    > {
397        sp1_gpu_commit::commit_multilinears::<GC, PC::P>(
398            multilinears,
399            self.max_log_row_count,
400            use_preprocessed_data,
401            self.drop_ldes,
402            &self.basefold_prover,
403        )
404        .map_err(JaggedProverError::BatchPcsProverError)
405    }
406
407    /// Prove trusted evaluations (sync version).
408    #[allow(clippy::type_complexity)]
409    pub fn prove_trusted_evaluations(
410        &self,
411        eval_point: Point<Ext>,
412        evaluation_claims: Rounds<MleEval<Ext, TaskScope>>,
413        all_mles: &JaggedTraceMle<Felt, TaskScope>,
414        prover_data: Rounds<&JaggedProverData<GC, CudaStackedPcsProverData<GC>>>,
415        challenger: &mut GC::Challenger,
416    ) -> Result<
417        JaggedPcsProof<GC, <PC::C as MultilinearPcsVerifier<GC>>::Proof>,
418        JaggedProverError<CudaShardProverError>,
419    >
420    where
421        GC::Challenger: DeviceGrindingChallenger<Witness = GC::F>,
422        GC::Challenger: slop_challenger::FieldChallenger<
423            <GC::Challenger as slop_challenger::GrindingChallenger>::Witness,
424        >,
425        SP1PcsProof<GC>: Into<<PC::C as MultilinearPcsVerifier<GC>>::Proof>,
426        TaskScope:
427            sp1_gpu_jagged_assist::BranchingProgramKernel<GC::F, GC::EF, PC::DeviceChallenger>,
428    {
429        let num_col_variables = prover_data
430            .iter()
431            .map(|data| data.column_counts.iter().sum::<usize>())
432            .sum::<usize>()
433            .next_power_of_two()
434            .ilog2();
435        let z_col = (0..num_col_variables)
436            .map(|_| challenger.sample_ext_element::<Ext>())
437            .collect::<Point<_>>();
438
439        let z_row = eval_point.clone();
440
441        let backend = evaluation_claims[0].backend().clone();
442
443        // First, allocate a buffer for all of the column claims on device.
444        let total_column_claims =
445            evaluation_claims.iter().map(|evals| evals.num_polynomials()).sum::<usize>();
446
447        // Add in the dummy padding columns added during the stacked PCS commitment.
448        let total_len = total_column_claims
449            + prover_data.iter().map(|data| data.padding_column_count).sum::<usize>();
450
451        let mut column_claims: Buffer<Ext, TaskScope> =
452            Buffer::with_capacity_in(total_len, backend.clone());
453
454        // Then, copy the column claims from the evaluation claims into the buffer, inserting extra
455        // zeros for the dummy columns.
456        for (column_claim_round, data) in evaluation_claims.into_iter().zip(prover_data.iter()) {
457            column_claims
458                .extend_from_device_slice(column_claim_round.into_evaluations().as_buffer())?;
459            column_claims
460                .extend_from_host_slice(vec![Ext::zero(); data.padding_column_count].as_slice())?;
461        }
462
463        assert!(prover_data
464            .iter()
465            .flat_map(|data| data.row_counts.iter())
466            .all(|x| *x <= 1 << self.max_log_row_count));
467
468        // Collect the jagged polynomial parameters.
469        let params = JaggedLittlePolynomialProverParams::new(
470            prover_data
471                .iter()
472                .flat_map(|data| {
473                    data.row_counts
474                        .iter()
475                        .copied()
476                        .zip(data.column_counts.iter().copied())
477                        .flat_map(|(row_count, column_count)| {
478                            std::iter::repeat_n(row_count, column_count)
479                        })
480                })
481                .collect(),
482            self.max_log_row_count as usize,
483        );
484
485        // Generate the jagged sumcheck proof.
486        let z_row_device = DevicePoint::from_host(&z_row, &backend).unwrap();
487        let z_col_device = DevicePoint::from_host(&z_col, &backend).unwrap();
488
489        // The overall evaluation claim of the sparse polynomial is inferred from the individual
490        // table claims.
491        let device_column_claims = DeviceMle::from(column_claims);
492
493        // Use the sync GPU evaluation
494        let sumcheck_claims = device_column_claims.eval_at_point(&z_col_device);
495        let sumcheck_claims_host = sumcheck_claims.to_host_vec().unwrap();
496        let sumcheck_claim = sumcheck_claims_host[0];
497
498        // Compute eq polynomials for the jagged sumcheck
499        let eq_z_row = z_row_device.partial_lagrange();
500        let eq_z_col = z_col_device.partial_lagrange();
501
502        let sumcheck_poly = generate_jagged_sumcheck_poly(all_mles, eq_z_col, eq_z_row);
503
504        let log_stacking_height = self.basefold_prover.log_height as usize;
505
506        let (sumcheck_proof, component_poly_evals, column_evals) =
507            tracing::debug_span!("jagged sumcheck").in_scope(|| {
508                jagged_sumcheck(sumcheck_poly, challenger, sumcheck_claim, log_stacking_height)
509            });
510        let final_eval_point = sumcheck_proof.point_and_eval.0.clone();
511
512        // Use sync GPU jagged evaluation proof
513        let jagged_eval_proof = tracing::debug_span!("jagged evaluation proof").in_scope(|| {
514            prove_jagged_evaluation_sync::<Felt, Ext, GC::Challenger, PC::DeviceChallenger>(
515                &params,
516                &z_row,
517                &z_col,
518                &final_eval_point,
519                challenger,
520                component_poly_evals[1],
521                &backend,
522            )
523        });
524
525        let (row_counts, column_counts): (Rounds<_>, Rounds<_>) = prover_data
526            .iter()
527            .map(|data| {
528                (Clone::clone(data.row_counts.as_ref()), Clone::clone(data.column_counts.as_ref()))
529            })
530            .unzip();
531
532        let original_commitments: Rounds<_> =
533            prover_data.iter().map(|data| data.original_commitment).collect();
534
535        let stacked_prover_data =
536            prover_data.iter().map(|data| &data.pcs_prover_data).collect::<Rounds<_>>();
537
538        let final_eval_point = sumcheck_proof.point_and_eval.0.clone();
539
540        let (_, stack_point) =
541            final_eval_point.split_at(final_eval_point.dimension() - log_stacking_height);
542
543        // Copy column evals to host once and reuse for the challenger transcript, the basefold
544        // call, and the SP1PcsProof field.
545        let column_evals_host = column_evals.to_host().unwrap();
546
547        challenger.observe_ext_element(component_poly_evals[0]);
548        for &evaluation in &column_evals_host {
549            challenger.observe_ext_element(evaluation);
550        }
551
552        let pcs_proof = tracing::debug_span!("prove trusted evaluations basefold")
553            .in_scope(|| {
554                self.basefold_prover.prove_trusted_evaluations_basefold(
555                    stack_point,
556                    column_evals_host.clone(),
557                    all_mles,
558                    stacked_prover_data,
559                    challenger,
560                )
561            })
562            .unwrap();
563
564        let row_counts_and_column_counts: Rounds<Vec<(usize, usize)>> = row_counts
565            .into_iter()
566            .zip(column_counts)
567            .map(|(r, c)| r.into_iter().zip(c).collect())
568            .collect();
569
570        let preprocessed_stacked_size =
571            all_mles.dense().preprocessed_offset / (1 << log_stacking_height);
572        let mut prep_evals_host = column_evals_host;
573        let main_evals_host = prep_evals_host.split_off(preprocessed_stacked_size);
574
575        let host_batch_evaluations: Rounds<MleEval<Ext>> = Rounds {
576            rounds: vec![
577                MleEval::new(prep_evals_host.into()),
578                MleEval::new(main_evals_host.into()),
579            ],
580        };
581
582        let stacked_basefold_proof =
583            SP1PcsProof { basefold_proof: pcs_proof, batch_evaluations: host_batch_evaluations };
584
585        let PrefixSumsMaxLogRowCount { log_m, .. } =
586            unzip_and_prefix_sums(&row_counts_and_column_counts);
587
588        Ok(JaggedPcsProof {
589            pcs_proof: stacked_basefold_proof.into(),
590            sumcheck_proof,
591            jagged_eval_proof,
592            row_counts_and_column_counts,
593            merkle_tree_commitments: original_commitments,
594            expected_eval: component_poly_evals[0],
595            max_log_row_count: self.max_log_row_count as usize,
596            log_m,
597        })
598    }
599
600    fn commit_traces(
601        &self,
602        traces: &JaggedTraceMle<GC::F, TaskScope>,
603        use_preprocessed: bool,
604    ) -> (GC::Digest, JaggedProverData<GC, CudaStackedPcsProverData<GC>>) {
605        self.commit_multilinears(traces, use_preprocessed).unwrap()
606    }
607
608    /// Prove a shard with the given data (sync version).
609    /// This is the main proving function that runs on the GPU.
610    #[allow(clippy::type_complexity)]
611    pub fn prove_shard_with_data(
612        &self,
613        data: ShardData<GC, PC>,
614        mut challenger: GC::Challenger,
615    ) -> (ShardProof<GC, <PC::C as MultilinearPcsVerifier<GC>>::Proof>, ProverPermit)
616    where
617        GC::Challenger: DeviceGrindingChallenger<Witness = GC::F>,
618        GC::Challenger: slop_challenger::FieldChallenger<
619            <GC::Challenger as slop_challenger::GrindingChallenger>::Witness,
620        >,
621        SP1PcsProof<GC>: Into<<PC::C as MultilinearPcsVerifier<GC>>::Proof>,
622        TaskScope:
623            sp1_gpu_jagged_assist::BranchingProgramKernel<GC::F, GC::EF, PC::DeviceChallenger>,
624    {
625        let ShardData { main_trace_data } = data;
626        let MainTraceData { traces, public_values, shard_chips, permit } = main_trace_data;
627
628        let shard_chips = self.machine().smallest_cluster(&shard_chips).unwrap();
629
630        // Observe the public values.
631        challenger.observe_slice(&public_values);
632
633        let locked_preprocessed_data = traces.preprocessed_data.blocking_lock();
634        let traces = &locked_preprocessed_data.preprocessed_traces;
635        let preprocessed_data = &locked_preprocessed_data.preprocessed_data;
636
637        // Commit to the traces.
638        let (main_commit, main_data) =
639            tracing::debug_span!("commit traces").in_scope(|| self.commit_traces(traces, false));
640        // Observe the commitments.
641        <GC::Challenger as CanObserve<GC::Digest>>::observe(&mut challenger, main_commit);
642        challenger.observe(GC::F::from_canonical_usize(shard_chips.len()));
643
644        for (chip_name, chip_height) in traces.dense().main_table_index.iter() {
645            let chip_height = chip_height.poly_size;
646            challenger.observe(GC::F::from_canonical_usize(chip_height));
647            challenger.observe(GC::F::from_canonical_usize(chip_name.len()));
648            for byte in chip_name.as_bytes() {
649                challenger.observe(GC::F::from_canonical_u8(*byte));
650            }
651        }
652
653        let logup_gkr_proof = tracing::debug_span!("logup gkr proof").in_scope(|| {
654            prove_logup_gkr::<GC, _>(
655                shard_chips,
656                self.all_interactions.clone(),
657                traces,
658                CudaLogUpGkrOptions {
659                    recompute_first_layer: self.recompute_first_layer,
660                    num_row_variables: self.max_log_row_count,
661                },
662                &mut challenger,
663            )
664        });
665        // Get the challenge for batching constraints.
666        let batching_challenge = challenger.sample_ext_element::<GC::EF>();
667        // Get the challenge for batching the evaluations from the GKR proof.
668        let gkr_opening_batch_challenge = challenger.sample_ext_element::<GC::EF>();
669
670        // Generate the zerocheck proof.
671        let (shard_open_values, zerocheck_partial_sumcheck_proof) =
672            tracing::debug_span!("zerocheck").in_scope(|| {
673                zerocheck(
674                    shard_chips,
675                    &self.all_zerocheck_programs,
676                    traces,
677                    batching_challenge,
678                    gkr_opening_batch_challenge,
679                    &logup_gkr_proof.logup_evaluations,
680                    public_values.clone(),
681                    &mut challenger,
682                    self.max_log_row_count,
683                )
684            });
685
686        // Get the evaluation point for the trace polynomials.
687        let evaluation_point = zerocheck_partial_sumcheck_proof.point_and_eval.0.clone();
688        let mut preprocessed_host: Vec<GC::EF> = Vec::new();
689        let mut main_host: Vec<GC::EF> = Vec::new();
690        let mut has_preprocessed = false;
691
692        let alloc = self.backend.clone();
693
694        for (_, open_values) in shard_open_values.chips.iter() {
695            let prep_local = &open_values.preprocessed.local;
696            let main_local = &open_values.main.local;
697            if !prep_local.is_empty() {
698                has_preprocessed = true;
699                preprocessed_host.extend_from_slice(prep_local);
700            }
701            main_host.extend_from_slice(main_local);
702        }
703
704        let main_evaluation_claims = MleEval::new(
705            sp1_gpu_cudart::DeviceTensor::from_host(
706                &MleEval::from(main_host).into_evaluations(),
707                &alloc,
708            )
709            .unwrap()
710            .into_inner(),
711        );
712        let preprocessed_evaluation_claims = has_preprocessed.then(|| {
713            MleEval::new(
714                sp1_gpu_cudart::DeviceTensor::from_host(
715                    &MleEval::from(preprocessed_host).into_evaluations(),
716                    &alloc,
717                )
718                .unwrap()
719                .into_inner(),
720            )
721        });
722
723        let round_evaluation_claims = preprocessed_evaluation_claims
724            .into_iter()
725            .chain(once(main_evaluation_claims))
726            .collect::<Rounds<_>>();
727
728        let round_prover_data =
729            once(preprocessed_data).chain(once(&main_data)).collect::<Rounds<_>>();
730
731        // Generate the evaluation proof (sync call).
732        let evaluation_proof = tracing::debug_span!("prove evaluation claims").in_scope(|| {
733            self.prove_trusted_evaluations(
734                evaluation_point,
735                round_evaluation_claims,
736                traces,
737                round_prover_data,
738                &mut challenger,
739            )
740            .unwrap()
741        });
742
743        let proof = ShardProof {
744            main_commitment: main_commit,
745            opened_values: shard_open_values,
746            logup_gkr_proof,
747            evaluation_proof,
748            zerocheck_proof: zerocheck_partial_sumcheck_proof,
749            public_values,
750        };
751
752        (proof, permit)
753    }
754}
755
#[cfg(test)]
mod tests {
    //! GPU integration tests for the CUDA shard prover.

    use super::*;
    use serial_test::serial;
    use slop_basefold::BasefoldVerifier;
    use slop_jagged::JaggedPcsVerifier;
    use slop_multilinear::MultilinearPcsChallenger;
    use slop_tensor::Tensor;
    use sp1_core_machine::io::SP1Stdin;
    use sp1_core_machine::riscv::RiscvAir;
    use sp1_gpu_air::codegen_cuda_eval;
    use sp1_gpu_cudart::run_in_place;
    use sp1_gpu_jagged_tracegen::test_utils::tracegen_setup::{
        self, CORE_MAX_LOG_ROW_COUNT, LOG_STACKING_HEIGHT,
    };
    use sp1_gpu_jagged_tracegen::{full_tracegen, CORE_MAX_TRACE_SIZE};
    use sp1_gpu_merkle_tree::{CudaTcsProver, Poseidon2SP1Field16CudaProver};
    use sp1_gpu_utils::TestGC;
    use sp1_gpu_zerocheck::primitives::round_batch_evaluations;
    use sp1_hypercube::SP1InnerPcs;
    use sp1_primitives::fri_params::core_fri_config;

    /// Component bundle used to instantiate the CUDA shard prover in tests: the Poseidon2
    /// SP1-field CUDA Merkle tree prover, the RISC-V AIR over `Felt`, the inner SP1 PCS, and a
    /// device-side duplex challenger.
    pub struct TestProverComponentsImpl {}

    impl CudaShardProverComponents<TestGC> for TestProverComponentsImpl {
        type P = Poseidon2SP1Field16CudaProver;
        type Air = RiscvAir<Felt>;
        type C = SP1InnerPcs;
        type DeviceChallenger = sp1_gpu_challenger::DuplexChallenger<Felt, TaskScope>;
    }

    /// End-to-end check of `prove_trusted_evaluations`.
    ///
    /// Generates jagged traces for the Fibonacci program and commits them twice via
    /// `commit_multilinears` (flag `true` for the preprocessed round, `false` for the main
    /// round). It then samples an evaluation point and computes the claimed evaluations with
    /// `round_batch_evaluations`. Next it proves those claims on the GPU. Finally it requires
    /// the host `JaggedPcsVerifier` to accept the proof, starting from the same challenger
    /// state as the prover.
    ///
    /// Marked `#[serial]`, presumably so that GPU tests do not run concurrently.
    #[tokio::test]
    #[serial]
    async fn test_prove_trusted_evaluations() {
        let (machine, record, program) =
            tracegen_setup::setup(&test_artifacts::FIBONACCI_ELF, SP1Stdin::new()).await;
        run_in_place(|scope| async move {
            let capacity = CORE_MAX_TRACE_SIZE as usize;

            // *********** Generate the jagged trace data. ***********
            // `full_tracegen` is handed a buffer popped from a worker queue, so seed a
            // one-element queue with a pinned buffer and check it out.
            let tracegen_buffer = PinnedBuffer::<Felt>::with_capacity(capacity);
            let queue = Arc::new(WorkerQueue::new(vec![tracegen_buffer]));
            let buffer = queue.pop().await.expect("worker queue was seeded with one buffer");
            let (_public_values, jagged_trace_data, _shard_chips, _permit) = full_tracegen(
                &machine,
                program.clone(),
                Arc::new(record),
                &buffer,
                capacity,
                LOG_STACKING_HEIGHT,
                CORE_MAX_LOG_ROW_COUNT,
                &scope,
                ProverSemaphore::new(1),
                true,
            )
            .await;
            let jagged_trace_data = Arc::new(jagged_trace_data);

            // *********** Build the shard prover. ***********
            let verifier = BasefoldVerifier::<TestGC>::new(core_fri_config(), 2);
            let basefold_prover = FriCudaProver::<TestGC, _, Felt>::new(
                Poseidon2SP1Field16CudaProver::new(&scope),
                verifier.fri_config,
                LOG_STACKING_HEIGHT,
            );

            // Device copies of every chip's send/receive interactions, keyed by chip name.
            let all_interactions: BTreeMap<_, _> = machine
                .chips()
                .iter()
                .map(|chip| {
                    let host_interactions = Interactions::new(chip.sends(), chip.receives());
                    let device_interactions = host_interactions
                        .copy_to_device(&scope)
                        .expect("failed to copy chip interactions to the device");
                    (chip.name().to_string(), Arc::new(device_interactions))
                })
                .collect();

            // Generated CUDA zerocheck evaluation code for every chip, keyed by chip name.
            let all_zerocheck_programs: BTreeMap<_, _> = machine
                .chips()
                .iter()
                .map(|chip| (chip.name().to_string(), codegen_cuda_eval(chip.air.as_ref())))
                .collect();

            let num_workers = 1;
            let trace_buffers: Vec<_> = (0..num_workers)
                .map(|_| PinnedBuffer::<Felt>::with_capacity(capacity))
                .collect();

            let shard_prover_inner: CudaShardProverInner<TestGC, TestProverComponentsImpl> =
                CudaShardProverInner {
                    trace_buffers: Arc::new(WorkerQueue::new(trace_buffers)),
                    all_interactions,
                    all_zerocheck_programs,
                    max_log_row_count: CORE_MAX_LOG_ROW_COUNT,
                    basefold_prover,
                    max_trace_size: capacity,
                    machine,
                    recompute_first_layer: false,
                    drop_ldes: false,
                    backend: scope.clone(),
                    _marker: PhantomData,
                };
            let shard_prover = CudaShardProver { inner: Arc::new(shard_prover_inner) };

            // *********** Compute evaluation claims and commitments. ***********
            let mut challenger = TestGC::default_challenger();
            let eval_point = challenger.sample_point(CORE_MAX_LOG_ROW_COUNT);

            // `round_batch_evaluations` is synchronous and returns host-side evaluations,
            // split per round and then per chip.
            let evaluation_claims =
                round_batch_evaluations(&eval_point, jagged_trace_data.as_ref());

            let (preprocessed_digest, preprocessed_prover_data) = shard_prover
                .inner
                .commit_multilinears(jagged_trace_data.as_ref(), true)
                .expect("failed to commit preprocessed multilinears");
            let (main_digest, main_prover_data) = shard_prover
                .inner
                .commit_multilinears(jagged_trace_data.as_ref(), false)
                .expect("failed to commit main multilinears");

            let prover_data = Rounds::from_iter([&preprocessed_prover_data, &main_prover_data]);

            // Concatenate each round's per-chip evaluations into one contiguous host vector.
            // Both the prover (device copy) and the verifier (host copy) consume this layout,
            // so it is computed once and shared.
            let flattened_rounds: Vec<Vec<Ext>> = evaluation_claims
                .iter()
                .map(|round_evals| {
                    let mut flat = Vec::new();
                    for eval in round_evals.iter() {
                        flat.extend_from_slice(eval.evaluations().as_buffer().as_slice());
                    }
                    flat
                })
                .collect();

            // Upload each round's flattened claims to the device with a single copy.
            let device_evaluation_claims: Vec<_> = flattened_rounds
                .iter()
                .map(|flat| {
                    let device_tensor = sp1_gpu_cudart::DeviceTensor::from_host(
                        &MleEval::from(flat.clone()).into_evaluations(),
                        &scope,
                    )
                    .expect("failed to upload evaluation claims to the device");
                    MleEval::new(device_tensor.into_inner())
                })
                .collect();

            // *********** Prove on the GPU. ***********
            // Prover and verifier each start from a clone of the same post-sampling challenger
            // state so that their transcripts line up.
            let mut prover_challenger = challenger.clone();
            let proof = shard_prover
                .inner
                .prove_trusted_evaluations(
                    eval_point.clone(),
                    device_evaluation_claims.into_iter().collect(),
                    jagged_trace_data.as_ref(),
                    prover_data,
                    &mut prover_challenger,
                )
                .expect("GPU prover failed to prove the trusted evaluations");

            // *********** Verify on the host. ***********
            let jagged_verifier = JaggedPcsVerifier::<_, SP1InnerPcs>::new_from_basefold_params(
                core_fri_config(),
                LOG_STACKING_HEIGHT,
                CORE_MAX_LOG_ROW_COUNT as usize,
                2,
            );

            let host_evaluation_claims: Vec<_> = flattened_rounds
                .into_iter()
                .map(|flat| MleEval::new(Tensor::from(Buffer::from(flat))))
                .collect();

            let mut verifier_challenger = challenger.clone();
            jagged_verifier
                .verify_trusted_evaluations(
                    &[preprocessed_digest, main_digest],
                    eval_point,
                    &host_evaluation_claims,
                    &proof,
                    &mut verifier_challenger,
                )
                .expect("host verifier rejected the GPU evaluation proof");
        })
        .await;
    }
}