1use crate::{MainTraceData, ShardData};
2use slop_algebra::AbstractField;
3use slop_alloc::{Buffer, HasBackend};
4use slop_challenger::{CanObserve, FieldChallenger, FromChallenger, IopCtx};
5use slop_commit::Rounds;
6use slop_futures::queue::{Worker, WorkerQueue};
7use slop_jagged::{
8 unzip_and_prefix_sums, JaggedLittlePolynomialProverParams, JaggedPcsProof, JaggedProverData,
9 JaggedProverError, PrefixSumsMaxLogRowCount,
10};
11use slop_multilinear::{MleEval, MultilinearPcsVerifier, Point};
12use sp1_gpu_air::air_block::BlockAir;
13use sp1_gpu_air::SymbolicProverFolder;
14use sp1_gpu_basefold::{CudaStackedPcsProverData, DeviceGrindingChallenger, FriCudaProver};
15use sp1_gpu_challenger::FromHostChallengerSync;
16use sp1_gpu_cudart::PinnedBuffer;
17use sp1_gpu_cudart::{DeviceMle, DevicePoint, TaskScope};
18use sp1_gpu_jagged_assist::prove_jagged_evaluation_sync;
19use sp1_gpu_jagged_sumcheck::{generate_jagged_sumcheck_poly, jagged_sumcheck};
20use sp1_gpu_jagged_tracegen::{full_tracegen_permit, main_tracegen_permit, CudaShardProverData};
21use sp1_gpu_logup_gkr::{prove_logup_gkr, CudaLogUpGkrOptions, Interactions};
22use sp1_gpu_merkle_tree::{CudaTcsProver, SingleLayerMerkleTreeProverError};
23use sp1_gpu_tracegen::CudaTracegenAir;
24use sp1_gpu_utils::{Ext, Felt, JaggedTraceMle};
25use sp1_gpu_zerocheck::zerocheck;
26use sp1_gpu_zerocheck::CudaEvalResult;
27use sp1_hypercube::prover::ZerocheckAir;
28use sp1_hypercube::{
29 air::{MachineAir, MachineProgram},
30 prover::{AirProver, PreprocessedData, ProverPermit, ProverSemaphore, ProvingKey},
31 Machine, MachineVerifyingKey, ShardProof,
32};
33use sp1_hypercube::{SP1PcsProof, ShardContextImpl};
34use std::collections::BTreeMap;
35use std::iter::once;
36use std::vec;
37use std::{marker::PhantomData, sync::Arc};
38use thiserror::Error;
39use tokio::sync::Mutex;
40use tracing::Instrument;
41
42pub trait CudaShardProverComponents<GC: IopCtx>: Send + Sync + 'static {
43 type P: CudaTcsProver<GC>;
44 type Air: CudaTracegenAir<GC::F>
45 + ZerocheckAir<Felt, Ext>
46 + for<'a> BlockAir<SymbolicProverFolder<'a>>;
47 type C: MultilinearPcsVerifier<GC> + Send + Sync;
48 type DeviceChallenger: sp1_gpu_jagged_assist::AsMutRawChallenger
50 + FromChallenger<GC::Challenger, TaskScope>
51 + FromHostChallengerSync<GC::Challenger>
52 + Clone
53 + Send
54 + Sync;
55}
56
57pub struct CudaShardProver<GC: IopCtx, PC: CudaShardProverComponents<GC>> {
58 inner: Arc<CudaShardProverInner<GC, PC>>,
59}
60
61impl<GC: IopCtx, PC: CudaShardProverComponents<GC>> Clone for CudaShardProver<GC, PC> {
62 fn clone(&self) -> Self {
63 Self { inner: self.inner.clone() }
64 }
65}
66
67impl<GC: IopCtx, PC: CudaShardProverComponents<GC>> CudaShardProver<GC, PC> {
68 #[allow(clippy::too_many_arguments)]
69 pub fn new(
70 trace_buffers: Arc<WorkerQueue<PinnedBuffer<GC::F>>>,
71 max_log_row_count: u32,
72 basefold_prover: FriCudaProver<GC, PC::P, GC::F>,
73 machine: Machine<GC::F, PC::Air>,
74 max_trace_size: usize,
75 backend: TaskScope,
76 all_interactions: BTreeMap<String, Arc<Interactions<GC::F, TaskScope>>>,
77 all_zerocheck_programs: BTreeMap<String, CudaEvalResult>,
78 recompute_first_layer: bool,
79 drop_ldes: bool,
80 ) -> Self {
81 Self {
82 inner: Arc::new(CudaShardProverInner {
83 trace_buffers,
84 max_log_row_count,
85 basefold_prover,
86 machine,
87 max_trace_size,
88 backend,
89 all_interactions,
90 all_zerocheck_programs,
91 recompute_first_layer,
92 drop_ldes,
93 _marker: PhantomData,
94 }),
95 }
96 }
97}
98
99impl<GC: IopCtx<F = Felt, EF = Ext>, PC: CudaShardProverComponents<GC>> CudaShardProver<GC, PC> {
100 #[allow(clippy::type_complexity)]
102 pub fn prove_trusted_evaluations(
103 &self,
104 eval_point: Point<Ext>,
105 evaluation_claims: Rounds<MleEval<Ext, TaskScope>>,
106 all_mles: &JaggedTraceMle<Felt, TaskScope>,
107 prover_data: Rounds<&JaggedProverData<GC, CudaStackedPcsProverData<GC>>>,
108 challenger: &mut GC::Challenger,
109 ) -> Result<
110 JaggedPcsProof<GC, <PC::C as MultilinearPcsVerifier<GC>>::Proof>,
111 JaggedProverError<CudaShardProverError>,
112 >
113 where
114 GC::Challenger: DeviceGrindingChallenger<Witness = GC::F>,
115 GC::Challenger: slop_challenger::FieldChallenger<
116 <GC::Challenger as slop_challenger::GrindingChallenger>::Witness,
117 >,
118 SP1PcsProof<GC>: Into<<PC::C as MultilinearPcsVerifier<GC>>::Proof>,
119 TaskScope:
120 sp1_gpu_jagged_assist::BranchingProgramKernel<GC::F, GC::EF, PC::DeviceChallenger>,
121 {
122 self.inner.prove_trusted_evaluations(
123 eval_point,
124 evaluation_claims,
125 all_mles,
126 prover_data,
127 challenger,
128 )
129 }
130}
131
132pub(crate) struct CudaShardProverInner<GC: IopCtx, PC: CudaShardProverComponents<GC>> {
134 #[allow(clippy::type_complexity)]
135 pub trace_buffers: Arc<WorkerQueue<PinnedBuffer<GC::F>>>,
136 pub max_log_row_count: u32,
137 pub basefold_prover: FriCudaProver<GC, PC::P, GC::F>,
138 pub machine: Machine<GC::F, PC::Air>,
139 pub max_trace_size: usize,
140 pub backend: TaskScope,
141 pub all_interactions: BTreeMap<String, Arc<Interactions<GC::F, TaskScope>>>,
142 pub all_zerocheck_programs: BTreeMap<String, CudaEvalResult>,
143 pub recompute_first_layer: bool,
144 pub drop_ldes: bool,
145 pub _marker: PhantomData<GC>,
146}
147
148impl<GC: IopCtx<F = Felt, EF = Ext>, PC: CudaShardProverComponents<GC>>
149 CudaShardProverInner<GC, PC>
150{
151 pub async fn get_buffer(&self) -> Worker<PinnedBuffer<GC::F>> {
152 self.trace_buffers.clone().pop().await.expect("buffer pool exhausted")
153 }
154
155 fn machine(&self) -> &Machine<GC::F, PC::Air> {
156 &self.machine
157 }
158}
159
160impl<GC: IopCtx<F = Felt, EF = Ext>, PC: CudaShardProverComponents<GC>>
161 AirProver<GC, ShardContextImpl<GC, PC::C, PC::Air>> for CudaShardProver<GC, PC>
162where
163 GC::Challenger: DeviceGrindingChallenger<Witness = GC::F>,
164 GC::Challenger: slop_challenger::FieldChallenger<
165 <GC::Challenger as slop_challenger::GrindingChallenger>::Witness,
166 >,
167 SP1PcsProof<GC>: Into<<PC::C as MultilinearPcsVerifier<GC>>::Proof>,
168 TaskScope: sp1_gpu_jagged_assist::BranchingProgramKernel<GC::F, GC::EF, PC::DeviceChallenger>,
169{
170 type PreprocessedData = Mutex<CudaShardProverData<GC, PC::Air>>;
171
172 fn machine(&self) -> &Machine<GC::F, PC::Air> {
173 &self.inner.machine
174 }
175
176 async fn setup_from_vk(
178 &self,
179 program: Arc<<PC::Air as MachineAir<GC::F>>::Program>,
180 vk: Option<MachineVerifyingKey<GC>>,
181 prover_permits: ProverSemaphore,
182 ) -> (
183 PreprocessedData<ProvingKey<GC, ShardContextImpl<GC, PC::C, PC::Air>, Self>>,
184 MachineVerifyingKey<GC>,
185 ) {
186 let inner = self.inner.clone();
187 if let Some(vk) = vk {
188 let initial_global_cumulative_sum = vk.initial_global_cumulative_sum;
189 inner
190 .setup_with_initial_global_cumulative_sum(
191 program,
192 initial_global_cumulative_sum,
193 prover_permits,
194 )
195 .await
196 } else {
197 let program_sent = program.clone();
198 let initial_global_cumulative_sum =
199 tokio::task::spawn_blocking(move || program_sent.initial_global_cumulative_sum())
200 .await
201 .unwrap();
202 inner
203 .setup_with_initial_global_cumulative_sum(
204 program,
205 initial_global_cumulative_sum,
206 prover_permits,
207 )
208 .await
209 }
210 }
211
212 async fn setup_and_prove_shard(
214 &self,
215 program: Arc<<PC::Air as MachineAir<GC::F>>::Program>,
216 record: <PC::Air as MachineAir<GC::F>>::Record,
217 vk: Option<MachineVerifyingKey<GC>>,
218 prover_permits: ProverSemaphore,
219 ) -> (
220 MachineVerifyingKey<GC>,
221 ShardProof<GC, <PC::C as MultilinearPcsVerifier<GC>>::Proof>,
222 ProverPermit,
223 ) {
224 let pc_start = program.pc_start();
226 let untrusted_config = program.untrusted_config();
227 let initial_global_cumulative_sum = if let Some(vk) = vk {
228 vk.initial_global_cumulative_sum
229 } else {
230 let program = program.clone();
231 tokio::task::spawn_blocking(move || program.initial_global_cumulative_sum())
232 .instrument(tracing::debug_span!("initial_global_cumulative_sum"))
233 .await
234 .unwrap()
235 };
236
237 let buffer = self.inner.get_buffer().await;
238
239 let record = Arc::new(record);
240
241 let (public_values, trace_data, chip_set, permit) = full_tracegen_permit(
243 self.machine(),
244 program,
245 record,
246 &buffer,
247 self.inner.max_trace_size,
248 self.inner.basefold_prover.log_height,
249 self.inner.max_log_row_count,
250 &self.inner.backend,
251 prover_permits,
252 true,
253 )
254 .instrument(tracing::debug_span!("generate all traces"))
255 .await;
256
257 let inner = self.inner.clone();
258 let (pk, vk) = tokio::task::spawn_blocking({
259 let span = tracing::debug_span!("setup_from_preprocessed_data_and_traces");
260 move || {
261 let _guard = span.enter();
262 inner.setup_from_preprocessed_data_and_traces(
263 pc_start,
264 initial_global_cumulative_sum,
265 trace_data,
266 untrusted_config,
267 )
268 }
269 })
270 .await
271 .unwrap();
272
273 let trace_data = Mutex::new(pk);
274
275 let pk = ProvingKey { vk: vk.clone(), preprocessed_data: trace_data };
276
277 let pk = Arc::new(pk);
278
279 let main_trace_data =
280 MainTraceData { traces: pk, public_values, shard_chips: chip_set, permit };
281
282 let mut challenger = GC::default_challenger();
284 vk.observe_into(&mut challenger);
286
287 let shard_data = ShardData { main_trace_data };
288
289 let inner = self.inner.clone();
290 let (shard_proof, permit) = tokio::task::spawn_blocking({
291 let span = tracing::debug_span!("prove_shard_with_data");
292 move || {
293 let _guard = span.enter();
294 inner.prove_shard_with_data(shard_data, challenger)
295 }
296 })
297 .await
298 .unwrap();
299
300 drop(buffer);
303
304 (vk, shard_proof, permit)
305 }
306
307 async fn prove_shard_with_pk(
309 &self,
310 pk: Arc<ProvingKey<GC, ShardContextImpl<GC, PC::C, PC::Air>, Self>>,
311 record: <PC::Air as MachineAir<GC::F>>::Record,
312 prover_permits: ProverSemaphore,
313 ) -> (ShardProof<GC, <PC::C as MultilinearPcsVerifier<GC>>::Proof>, ProverPermit) {
314 let record = Arc::new(record);
316
317 let buffer = self.inner.get_buffer().await;
318
319 let (public_values, chip_set, permit) = main_tracegen_permit(
320 &self.inner.machine,
321 record,
322 &pk.preprocessed_data,
323 &buffer,
324 self.inner.basefold_prover.log_height,
325 self.inner.max_log_row_count,
326 &self.inner.backend,
327 prover_permits,
328 true,
329 )
330 .instrument(tracing::debug_span!("generate main traces"))
331 .await;
332
333 let shard_data = ShardData {
334 main_trace_data: MainTraceData {
335 traces: pk.clone(),
336 public_values,
337 shard_chips: chip_set,
338 permit,
339 },
340 };
341
342 let mut challenger = GC::default_challenger();
343 pk.vk.observe_into(&mut challenger);
344
345 let inner = self.inner.clone();
346 let (shard_proof, permit) = tokio::task::spawn_blocking({
347 let span = tracing::debug_span!("prove_shard_with_data");
348 move || {
349 let _guard = span.enter();
350 inner.prove_shard_with_data(shard_data, challenger)
351 }
352 })
353 .await
354 .unwrap();
355
356 drop(buffer);
357
358 (shard_proof, permit)
359 }
360
361 async fn preprocessed_table_heights(
362 pk: Arc<ProvingKey<GC, ShardContextImpl<GC, PC::C, PC::Air>, Self>>,
363 ) -> BTreeMap<String, usize> {
364 let preprocessed_data = pk.preprocessed_data.lock().await;
366 preprocessed_data
367 .preprocessed_traces
368 .dense()
369 .preprocessed_table_index
370 .iter()
371 .map(|(name, offset)| (name.clone(), offset.poly_size))
372 .collect()
373 }
374}
375
376#[derive(Debug, Error)]
378pub enum CudaShardProverError {}
379
380impl<GC: IopCtx<F = Felt, EF = Ext>, PC: CudaShardProverComponents<GC>>
381 CudaShardProverInner<GC, PC>
382{
383 #[allow(clippy::type_complexity)]
389 pub fn commit_multilinears(
390 &self,
391 multilinears: &JaggedTraceMle<Felt, TaskScope>,
392 use_preprocessed_data: bool,
393 ) -> Result<
394 (GC::Digest, JaggedProverData<GC, CudaStackedPcsProverData<GC>>),
395 JaggedProverError<SingleLayerMerkleTreeProverError>,
396 > {
397 sp1_gpu_commit::commit_multilinears::<GC, PC::P>(
398 multilinears,
399 self.max_log_row_count,
400 use_preprocessed_data,
401 self.drop_ldes,
402 &self.basefold_prover,
403 )
404 .map_err(JaggedProverError::BatchPcsProverError)
405 }
406
407 #[allow(clippy::type_complexity)]
409 pub fn prove_trusted_evaluations(
410 &self,
411 eval_point: Point<Ext>,
412 evaluation_claims: Rounds<MleEval<Ext, TaskScope>>,
413 all_mles: &JaggedTraceMle<Felt, TaskScope>,
414 prover_data: Rounds<&JaggedProverData<GC, CudaStackedPcsProverData<GC>>>,
415 challenger: &mut GC::Challenger,
416 ) -> Result<
417 JaggedPcsProof<GC, <PC::C as MultilinearPcsVerifier<GC>>::Proof>,
418 JaggedProverError<CudaShardProverError>,
419 >
420 where
421 GC::Challenger: DeviceGrindingChallenger<Witness = GC::F>,
422 GC::Challenger: slop_challenger::FieldChallenger<
423 <GC::Challenger as slop_challenger::GrindingChallenger>::Witness,
424 >,
425 SP1PcsProof<GC>: Into<<PC::C as MultilinearPcsVerifier<GC>>::Proof>,
426 TaskScope:
427 sp1_gpu_jagged_assist::BranchingProgramKernel<GC::F, GC::EF, PC::DeviceChallenger>,
428 {
429 let num_col_variables = prover_data
430 .iter()
431 .map(|data| data.column_counts.iter().sum::<usize>())
432 .sum::<usize>()
433 .next_power_of_two()
434 .ilog2();
435 let z_col = (0..num_col_variables)
436 .map(|_| challenger.sample_ext_element::<Ext>())
437 .collect::<Point<_>>();
438
439 let z_row = eval_point.clone();
440
441 let backend = evaluation_claims[0].backend().clone();
442
443 let total_column_claims =
445 evaluation_claims.iter().map(|evals| evals.num_polynomials()).sum::<usize>();
446
447 let total_len = total_column_claims
449 + prover_data.iter().map(|data| data.padding_column_count).sum::<usize>();
450
451 let mut column_claims: Buffer<Ext, TaskScope> =
452 Buffer::with_capacity_in(total_len, backend.clone());
453
454 for (column_claim_round, data) in evaluation_claims.into_iter().zip(prover_data.iter()) {
457 column_claims
458 .extend_from_device_slice(column_claim_round.into_evaluations().as_buffer())?;
459 column_claims
460 .extend_from_host_slice(vec![Ext::zero(); data.padding_column_count].as_slice())?;
461 }
462
463 assert!(prover_data
464 .iter()
465 .flat_map(|data| data.row_counts.iter())
466 .all(|x| *x <= 1 << self.max_log_row_count));
467
468 let params = JaggedLittlePolynomialProverParams::new(
470 prover_data
471 .iter()
472 .flat_map(|data| {
473 data.row_counts
474 .iter()
475 .copied()
476 .zip(data.column_counts.iter().copied())
477 .flat_map(|(row_count, column_count)| {
478 std::iter::repeat_n(row_count, column_count)
479 })
480 })
481 .collect(),
482 self.max_log_row_count as usize,
483 );
484
485 let z_row_device = DevicePoint::from_host(&z_row, &backend).unwrap();
487 let z_col_device = DevicePoint::from_host(&z_col, &backend).unwrap();
488
489 let device_column_claims = DeviceMle::from(column_claims);
492
493 let sumcheck_claims = device_column_claims.eval_at_point(&z_col_device);
495 let sumcheck_claims_host = sumcheck_claims.to_host_vec().unwrap();
496 let sumcheck_claim = sumcheck_claims_host[0];
497
498 let eq_z_row = z_row_device.partial_lagrange();
500 let eq_z_col = z_col_device.partial_lagrange();
501
502 let sumcheck_poly = generate_jagged_sumcheck_poly(all_mles, eq_z_col, eq_z_row);
503
504 let log_stacking_height = self.basefold_prover.log_height as usize;
505
506 let (sumcheck_proof, component_poly_evals, column_evals) =
507 tracing::debug_span!("jagged sumcheck").in_scope(|| {
508 jagged_sumcheck(sumcheck_poly, challenger, sumcheck_claim, log_stacking_height)
509 });
510 let final_eval_point = sumcheck_proof.point_and_eval.0.clone();
511
512 let jagged_eval_proof = tracing::debug_span!("jagged evaluation proof").in_scope(|| {
514 prove_jagged_evaluation_sync::<Felt, Ext, GC::Challenger, PC::DeviceChallenger>(
515 ¶ms,
516 &z_row,
517 &z_col,
518 &final_eval_point,
519 challenger,
520 component_poly_evals[1],
521 &backend,
522 )
523 });
524
525 let (row_counts, column_counts): (Rounds<_>, Rounds<_>) = prover_data
526 .iter()
527 .map(|data| {
528 (Clone::clone(data.row_counts.as_ref()), Clone::clone(data.column_counts.as_ref()))
529 })
530 .unzip();
531
532 let original_commitments: Rounds<_> =
533 prover_data.iter().map(|data| data.original_commitment).collect();
534
535 let stacked_prover_data =
536 prover_data.iter().map(|data| &data.pcs_prover_data).collect::<Rounds<_>>();
537
538 let final_eval_point = sumcheck_proof.point_and_eval.0.clone();
539
540 let (_, stack_point) =
541 final_eval_point.split_at(final_eval_point.dimension() - log_stacking_height);
542
543 let column_evals_host = column_evals.to_host().unwrap();
546
547 challenger.observe_ext_element(component_poly_evals[0]);
548 for &evaluation in &column_evals_host {
549 challenger.observe_ext_element(evaluation);
550 }
551
552 let pcs_proof = tracing::debug_span!("prove trusted evaluations basefold")
553 .in_scope(|| {
554 self.basefold_prover.prove_trusted_evaluations_basefold(
555 stack_point,
556 column_evals_host.clone(),
557 all_mles,
558 stacked_prover_data,
559 challenger,
560 )
561 })
562 .unwrap();
563
564 let row_counts_and_column_counts: Rounds<Vec<(usize, usize)>> = row_counts
565 .into_iter()
566 .zip(column_counts)
567 .map(|(r, c)| r.into_iter().zip(c).collect())
568 .collect();
569
570 let preprocessed_stacked_size =
571 all_mles.dense().preprocessed_offset / (1 << log_stacking_height);
572 let mut prep_evals_host = column_evals_host;
573 let main_evals_host = prep_evals_host.split_off(preprocessed_stacked_size);
574
575 let host_batch_evaluations: Rounds<MleEval<Ext>> = Rounds {
576 rounds: vec![
577 MleEval::new(prep_evals_host.into()),
578 MleEval::new(main_evals_host.into()),
579 ],
580 };
581
582 let stacked_basefold_proof =
583 SP1PcsProof { basefold_proof: pcs_proof, batch_evaluations: host_batch_evaluations };
584
585 let PrefixSumsMaxLogRowCount { log_m, .. } =
586 unzip_and_prefix_sums(&row_counts_and_column_counts);
587
588 Ok(JaggedPcsProof {
589 pcs_proof: stacked_basefold_proof.into(),
590 sumcheck_proof,
591 jagged_eval_proof,
592 row_counts_and_column_counts,
593 merkle_tree_commitments: original_commitments,
594 expected_eval: component_poly_evals[0],
595 max_log_row_count: self.max_log_row_count as usize,
596 log_m,
597 })
598 }
599
600 fn commit_traces(
601 &self,
602 traces: &JaggedTraceMle<GC::F, TaskScope>,
603 use_preprocessed: bool,
604 ) -> (GC::Digest, JaggedProverData<GC, CudaStackedPcsProverData<GC>>) {
605 self.commit_multilinears(traces, use_preprocessed).unwrap()
606 }
607
608 #[allow(clippy::type_complexity)]
611 pub fn prove_shard_with_data(
612 &self,
613 data: ShardData<GC, PC>,
614 mut challenger: GC::Challenger,
615 ) -> (ShardProof<GC, <PC::C as MultilinearPcsVerifier<GC>>::Proof>, ProverPermit)
616 where
617 GC::Challenger: DeviceGrindingChallenger<Witness = GC::F>,
618 GC::Challenger: slop_challenger::FieldChallenger<
619 <GC::Challenger as slop_challenger::GrindingChallenger>::Witness,
620 >,
621 SP1PcsProof<GC>: Into<<PC::C as MultilinearPcsVerifier<GC>>::Proof>,
622 TaskScope:
623 sp1_gpu_jagged_assist::BranchingProgramKernel<GC::F, GC::EF, PC::DeviceChallenger>,
624 {
625 let ShardData { main_trace_data } = data;
626 let MainTraceData { traces, public_values, shard_chips, permit } = main_trace_data;
627
628 let shard_chips = self.machine().smallest_cluster(&shard_chips).unwrap();
629
630 challenger.observe_slice(&public_values);
632
633 let locked_preprocessed_data = traces.preprocessed_data.blocking_lock();
634 let traces = &locked_preprocessed_data.preprocessed_traces;
635 let preprocessed_data = &locked_preprocessed_data.preprocessed_data;
636
637 let (main_commit, main_data) =
639 tracing::debug_span!("commit traces").in_scope(|| self.commit_traces(traces, false));
640 <GC::Challenger as CanObserve<GC::Digest>>::observe(&mut challenger, main_commit);
642 challenger.observe(GC::F::from_canonical_usize(shard_chips.len()));
643
644 for (chip_name, chip_height) in traces.dense().main_table_index.iter() {
645 let chip_height = chip_height.poly_size;
646 challenger.observe(GC::F::from_canonical_usize(chip_height));
647 challenger.observe(GC::F::from_canonical_usize(chip_name.len()));
648 for byte in chip_name.as_bytes() {
649 challenger.observe(GC::F::from_canonical_u8(*byte));
650 }
651 }
652
653 let logup_gkr_proof = tracing::debug_span!("logup gkr proof").in_scope(|| {
654 prove_logup_gkr::<GC, _>(
655 shard_chips,
656 self.all_interactions.clone(),
657 traces,
658 CudaLogUpGkrOptions {
659 recompute_first_layer: self.recompute_first_layer,
660 num_row_variables: self.max_log_row_count,
661 },
662 &mut challenger,
663 )
664 });
665 let batching_challenge = challenger.sample_ext_element::<GC::EF>();
667 let gkr_opening_batch_challenge = challenger.sample_ext_element::<GC::EF>();
669
670 let (shard_open_values, zerocheck_partial_sumcheck_proof) =
672 tracing::debug_span!("zerocheck").in_scope(|| {
673 zerocheck(
674 shard_chips,
675 &self.all_zerocheck_programs,
676 traces,
677 batching_challenge,
678 gkr_opening_batch_challenge,
679 &logup_gkr_proof.logup_evaluations,
680 public_values.clone(),
681 &mut challenger,
682 self.max_log_row_count,
683 )
684 });
685
686 let evaluation_point = zerocheck_partial_sumcheck_proof.point_and_eval.0.clone();
688 let mut preprocessed_host: Vec<GC::EF> = Vec::new();
689 let mut main_host: Vec<GC::EF> = Vec::new();
690 let mut has_preprocessed = false;
691
692 let alloc = self.backend.clone();
693
694 for (_, open_values) in shard_open_values.chips.iter() {
695 let prep_local = &open_values.preprocessed.local;
696 let main_local = &open_values.main.local;
697 if !prep_local.is_empty() {
698 has_preprocessed = true;
699 preprocessed_host.extend_from_slice(prep_local);
700 }
701 main_host.extend_from_slice(main_local);
702 }
703
704 let main_evaluation_claims = MleEval::new(
705 sp1_gpu_cudart::DeviceTensor::from_host(
706 &MleEval::from(main_host).into_evaluations(),
707 &alloc,
708 )
709 .unwrap()
710 .into_inner(),
711 );
712 let preprocessed_evaluation_claims = has_preprocessed.then(|| {
713 MleEval::new(
714 sp1_gpu_cudart::DeviceTensor::from_host(
715 &MleEval::from(preprocessed_host).into_evaluations(),
716 &alloc,
717 )
718 .unwrap()
719 .into_inner(),
720 )
721 });
722
723 let round_evaluation_claims = preprocessed_evaluation_claims
724 .into_iter()
725 .chain(once(main_evaluation_claims))
726 .collect::<Rounds<_>>();
727
728 let round_prover_data =
729 once(preprocessed_data).chain(once(&main_data)).collect::<Rounds<_>>();
730
731 let evaluation_proof = tracing::debug_span!("prove evaluation claims").in_scope(|| {
733 self.prove_trusted_evaluations(
734 evaluation_point,
735 round_evaluation_claims,
736 traces,
737 round_prover_data,
738 &mut challenger,
739 )
740 .unwrap()
741 });
742
743 let proof = ShardProof {
744 main_commitment: main_commit,
745 opened_values: shard_open_values,
746 logup_gkr_proof,
747 evaluation_proof,
748 zerocheck_proof: zerocheck_partial_sumcheck_proof,
749 public_values,
750 };
751
752 (proof, permit)
753 }
754}
755
#[cfg(test)]
mod tests {
    use super::*;
    use serial_test::serial;
    use slop_basefold::BasefoldVerifier;
    use slop_jagged::JaggedPcsVerifier;
    use slop_multilinear::MultilinearPcsChallenger;
    use slop_tensor::Tensor;
    use sp1_core_machine::io::SP1Stdin;
    use sp1_core_machine::riscv::RiscvAir;
    use sp1_gpu_air::codegen_cuda_eval;
    use sp1_gpu_cudart::run_in_place;
    use sp1_gpu_jagged_tracegen::test_utils::tracegen_setup::{
        self, CORE_MAX_LOG_ROW_COUNT, LOG_STACKING_HEIGHT,
    };
    use sp1_gpu_jagged_tracegen::{full_tracegen, CORE_MAX_TRACE_SIZE};
    use sp1_gpu_merkle_tree::{CudaTcsProver, Poseidon2SP1Field16CudaProver};
    use sp1_gpu_utils::TestGC;
    use sp1_gpu_zerocheck::primitives::round_batch_evaluations;
    use sp1_hypercube::SP1InnerPcs;
    use sp1_primitives::fri_params::core_fri_config;

    /// Test component set: RISC-V AIR, Poseidon2 Merkle prover, inner PCS, and the device
    /// duplex challenger.
    pub struct TestProverComponentsImpl {}

    impl CudaShardProverComponents<TestGC> for TestProverComponentsImpl {
        type P = Poseidon2SP1Field16CudaProver;
        type Air = RiscvAir<Felt>;
        type C = SP1InnerPcs;
        type DeviceChallenger = sp1_gpu_challenger::DuplexChallenger<Felt, TaskScope>;
    }

    /// End-to-end check of `prove_trusted_evaluations`: traces for the Fibonacci program are
    /// generated and committed, a proof is produced, and the host `JaggedPcsVerifier` must
    /// accept it.
    #[tokio::test]
    #[serial]
    async fn test_prove_trusted_evaluations() {
        let (machine, record, program) =
            tracegen_setup::setup(&test_artifacts::FIBONACCI_ELF, SP1Stdin::new()).await;
        run_in_place(|scope| async move {
            // A single pinned buffer is enough for one `full_tracegen` call.
            let capacity = CORE_MAX_TRACE_SIZE as usize;
            let buffer = PinnedBuffer::<Felt>::with_capacity(capacity);
            let queue = Arc::new(WorkerQueue::new(vec![buffer]));
            let buffer = queue.pop().await.unwrap();
            let (_public_values, jagged_trace_data, _shard_chips, _permit) = full_tracegen(
                &machine,
                program.clone(),
                Arc::new(record),
                &buffer,
                CORE_MAX_TRACE_SIZE as usize,
                LOG_STACKING_HEIGHT,
                CORE_MAX_LOG_ROW_COUNT,
                &scope,
                ProverSemaphore::new(1),
                true,
            )
            .await;

            let jagged_trace_data = Arc::new(jagged_trace_data);

            let verifier = BasefoldVerifier::<TestGC>::new(core_fri_config(), 2);

            let basefold_prover = FriCudaProver::<TestGC, _, Felt>::new(
                Poseidon2SP1Field16CudaProver::new(&scope),
                verifier.fri_config,
                LOG_STACKING_HEIGHT,
            );

            // Per-chip device interactions, keyed by chip name.
            let mut all_interactions = BTreeMap::new();

            for chip in machine.chips().iter() {
                let host_interactions = Interactions::new(chip.sends(), chip.receives());
                let device_interactions = host_interactions.copy_to_device(&scope).unwrap();
                all_interactions.insert(chip.name().to_string(), Arc::new(device_interactions));
            }

            // Per-chip generated zerocheck programs, keyed by chip name.
            let mut cache = BTreeMap::new();
            for chip in machine.chips().iter() {
                let result = codegen_cuda_eval(chip.air.as_ref());
                cache.insert(chip.name().to_string(), result);
            }

            let num_workers = 1;
            let mut trace_buffers = Vec::with_capacity(num_workers);
            for _ in 0..num_workers {
                let buffer = PinnedBuffer::<Felt>::with_capacity(CORE_MAX_TRACE_SIZE as usize);
                trace_buffers.push(buffer);
            }

            // Built directly from fields rather than via `CudaShardProver::new`.
            let shard_prover_inner: CudaShardProverInner<TestGC, TestProverComponentsImpl> =
                CudaShardProverInner {
                    trace_buffers: Arc::new(WorkerQueue::new(trace_buffers)),
                    all_interactions,
                    all_zerocheck_programs: cache,
                    max_log_row_count: CORE_MAX_LOG_ROW_COUNT,
                    basefold_prover,
                    max_trace_size: CORE_MAX_TRACE_SIZE as usize,
                    machine,
                    recompute_first_layer: false,
                    drop_ldes: false,
                    backend: scope.clone(),
                    _marker: PhantomData,
                };
            let shard_prover = CudaShardProver { inner: Arc::new(shard_prover_inner) };

            let mut challenger = TestGC::default_challenger();

            let eval_point = challenger.sample_point(CORE_MAX_LOG_ROW_COUNT);

            let evaluation_claims =
                round_batch_evaluations(&eval_point, jagged_trace_data.as_ref());

            // Commit the same jagged trace twice: once as the preprocessed round (`true`) and
            // once as the main round (`false`).
            let (preprocessed_digest, preprocessed_prover_data) =
                shard_prover.inner.commit_multilinears(jagged_trace_data.as_ref(), true).unwrap();

            let (main_digest, main_prover_data) =
                shard_prover.inner.commit_multilinears(jagged_trace_data.as_ref(), false).unwrap();

            let prover_data = Rounds::from_iter([&preprocessed_prover_data, &main_prover_data]);

            // Flatten each round's claims on the host and upload them as one device tensor
            // per round, which is the shape `prove_trusted_evaluations` consumes.
            let mut new_evaluation_claims = Vec::new();
            for round_evals in evaluation_claims.iter() {
                let mut round_host: Vec<Ext> = Vec::new();
                for eval in round_evals.iter() {
                    round_host.extend_from_slice(eval.to_vec().as_slice());
                }
                let device_tensor = sp1_gpu_cudart::DeviceTensor::from_host(
                    &MleEval::from(round_host).into_evaluations(),
                    &scope,
                )
                .unwrap();
                new_evaluation_claims.push(MleEval::new(device_tensor.into_inner()));
            }

            // Prover and verifier each get a clone of the transcript taken after
            // `eval_point` was sampled, so both start from the same state.
            let mut prover_challenger = challenger.clone();
            let proof = shard_prover
                .inner
                .prove_trusted_evaluations(
                    eval_point.clone(),
                    new_evaluation_claims.into_iter().collect(),
                    jagged_trace_data.as_ref(),
                    prover_data,
                    &mut prover_challenger,
                )
                .unwrap();

            let jagged_verifier = JaggedPcsVerifier::<_, SP1InnerPcs>::new_from_basefold_params(
                core_fri_config(),
                LOG_STACKING_HEIGHT,
                CORE_MAX_LOG_ROW_COUNT as usize,
                2,
            );

            // Rebuild the same per-round claims as host tensors for the verifier.
            let mut all_evaluations = Vec::new();
            for round_evals in evaluation_claims.iter() {
                let mut host_evals = Vec::new();
                for eval in round_evals.iter() {
                    host_evals.extend_from_slice(eval.evaluations().as_buffer().as_slice());
                }
                let buf = Buffer::from(host_evals);
                let mle_eval = MleEval::new(Tensor::from(buf));
                all_evaluations.push(mle_eval);
            }

            let mut verifier_challenger = challenger.clone();
            jagged_verifier
                .verify_trusted_evaluations(
                    &[preprocessed_digest, main_digest],
                    eval_point,
                    &all_evaluations,
                    &proof,
                    &mut verifier_challenger,
                )
                .unwrap();
        })
        .await;
    }
}
937}