1use derive_where::derive_where;
2use itertools::Itertools;
3use serde::{Deserialize, Serialize};
4use slop_air::Air;
5use slop_algebra::{AbstractField, Field};
6use slop_alloc::{Backend, CanCopyFromRef, CpuBackend};
7use slop_challenger::{CanObserve, FieldChallenger, IopCtx, VariableLengthChallenger};
8use slop_commit::Rounds;
9use slop_jagged::{DefaultJaggedProver, JaggedProver, JaggedProverData};
10use slop_matrix::dense::RowMajorMatrixView;
11use slop_multilinear::{
12 Evaluations, MleEval, MultilinearPcsProver, MultilinearPcsVerifier, Point, VirtualGeq,
13};
14use slop_sumcheck::{reduce_sumcheck_to_evaluation, PartialSumcheckProof};
15use slop_tensor::Tensor;
16use std::{
17 collections::{BTreeMap, BTreeSet},
18 fmt::Debug,
19 future::Future,
20 iter::once,
21 sync::Arc,
22};
23use thousands::Separable;
24use tracing::Instrument;
25
26use crate::{
27 air::{MachineAir, MachineProgram},
28 prover::{
29 DefaultTraceGenerator, Program, ProverPermit, ProverSemaphore, Record, ZeroCheckPoly,
30 ZerocheckCpuProverData,
31 },
32 septic_digest::SepticDigest,
33 AirOpenedValues, Chip, ChipEvaluation, ChipOpenedValues, ChipStatistics,
34 ConstraintSumcheckFolder, GkrProverImpl, LogUpEvaluations, Machine, MachineVerifyingKey,
35 ShardContext, ShardOpenedValues, ShardProof,
36};
37
38use super::{TraceGenerator, Traces};
39
/// The proof type produced by the multilinear PCS configured in `SC`'s shard context
/// (i.e. `<SC::Config as MultilinearPcsVerifier<GC>>::Proof`).
pub type PcsProof<GC, SC> = <<SC as ShardContext<GC>>::Config as MultilinearPcsVerifier<GC>>::Proof;
42
/// A prover that sets up proving keys for programs and proves individual shards of a
/// [`Machine`]'s execution.
#[allow(clippy::type_complexity)]
pub trait AirProver<GC: IopCtx, SC: ShardContext<GC>>: 'static + Send + Sync + Sized {
    /// Prover-specific data stored in [`ProvingKey::preprocessed_data`].
    type PreprocessedData: 'static + Send + Sync;

    /// Returns the machine whose chips this prover works with.
    fn machine(&self) -> &Machine<GC::F, SC::Air>;

    /// Sets up the proving key for `program`.
    ///
    /// When `vk` is `Some`, implementations may reuse information from it instead of
    /// recomputing it (the `ShardProver` implementation reuses its
    /// `initial_global_cumulative_sum`). Returns the proving key wrapped together with a
    /// permit, plus the verifying key.
    fn setup_from_vk(
        &self,
        program: Arc<Program<GC, SC>>,
        vk: Option<MachineVerifyingKey<GC>>,
        prover_permits: ProverSemaphore,
    ) -> impl Future<Output = (PreprocessedData<ProvingKey<GC, SC, Self>>, MachineVerifyingKey<GC>)> + Send;

    /// Sets up a key for `program` and proves the shard described by `record` in one pass.
    ///
    /// Returns the verifying key, the shard proof, and the permit held for the proof.
    fn setup_and_prove_shard(
        &self,
        program: Arc<Program<GC, SC>>,
        record: Record<GC, SC>,
        vk: Option<MachineVerifyingKey<GC>>,
        prover_permits: ProverSemaphore,
    ) -> impl Future<
        Output = (MachineVerifyingKey<GC>, ShardProof<GC, PcsProof<GC, SC>>, ProverPermit),
    > + Send;

    /// Proves the shard described by `record` using an existing proving key `pk`.
    ///
    /// Returns the shard proof and the permit held for the proof.
    fn prove_shard_with_pk(
        &self,
        pk: Arc<ProvingKey<GC, SC, Self>>,
        record: Record<GC, SC>,
        prover_permits: ProverSemaphore,
    ) -> impl Future<Output = (ShardProof<GC, PcsProof<GC, SC>>, ProverPermit)> + Send;
    /// Returns every chip of [`Self::machine`].
    fn all_chips(&self) -> &[Chip<GC::F, SC::Air>] {
        self.machine().chips()
    }

    /// Sets up the proving key for `program` without a pre-existing verifying key.
    ///
    /// Equivalent to `self.setup_from_vk(program, None, setup_permits)`.
    fn setup(
        &self,
        program: Arc<Program<GC, SC>>,
        setup_permits: ProverSemaphore,
    ) -> impl Future<Output = (PreprocessedData<ProvingKey<GC, SC, Self>>, MachineVerifyingKey<GC>)> + Send
    {
        self.setup_from_vk(program, None, setup_permits)
    }

    /// Returns the height of each preprocessed table held in `pk`, keyed by table name.
    fn preprocessed_table_heights(
        pk: Arc<ProvingKey<GC, SC, Self>>,
    ) -> impl Future<Output = BTreeMap<String, usize>> + Send;
}
102
/// A proving key for an [`AirProver`]: the machine verifying key paired with the prover's
/// preprocessed data.
pub struct ProvingKey<GC: IopCtx, SC: ShardContext<GC>, Prover: AirProver<GC, SC>> {
    /// The verifying key that corresponds to this proving key.
    pub vk: MachineVerifyingKey<GC>,
    /// Prover-specific preprocessed data (for `ShardProver` this is a `ShardProverData`).
    pub preprocessed_data: Prover::PreprocessedData,
}
110
/// Everything `ShardProver::prove_shard_with_data` needs to prove one shard: the (shared)
/// proving key and the shard's main-trace data.
#[allow(clippy::type_complexity)]
pub struct ShardData<GC: IopCtx, SC: ShardContext<GC>, C: DefaultJaggedProver<GC, SC::Config>> {
    /// Proving key shared across shards.
    pub pk: Arc<ProvingKey<GC, SC, ShardProver<GC, SC, C>>>,
    /// Main traces, public values, participating chips and permit for this shard.
    pub main_trace_data: MainTraceData<GC::F, SC::Air, CpuBackend>,
}
119
/// The per-shard (non-preprocessed) output of trace generation.
pub struct MainTraceData<F: Field, A: MachineAir<F>, B: Backend> {
    /// Main traces, keyed by chip name.
    pub traces: Traces<F, B>,
    /// The shard's public values.
    pub public_values: Vec<F>,
    /// The chips that take part in this shard.
    pub shard_chips: BTreeSet<Chip<F, A>>,
    /// Permit obtained during trace generation; it is handed back to the caller with the proof.
    pub permit: ProverPermit,
}
131
/// The full output of trace generation: preprocessed traces plus the shard's main-trace data.
pub struct TraceData<F: Field, A: MachineAir<F>, B: Backend> {
    /// Preprocessed traces, keyed by chip name.
    pub preprocessed_traces: Traces<F, B>,
    /// Main traces, public values, chips and permit for the shard.
    pub main_trace_data: MainTraceData<F, A, B>,
}
139
/// The output of preprocessed-only trace generation used during setup.
pub struct PreprocessedTraceData<F: Field, B: Backend> {
    /// Preprocessed traces, keyed by chip name.
    pub preprocessed_traces: Traces<F, B>,
    /// Permit obtained while generating the preprocessed traces.
    pub permit: ProverPermit,
}
147
/// A shared proving key bundled with the permit that was held while producing it.
pub struct PreprocessedData<T> {
    /// The shared proving key.
    pub pk: Arc<T>,
    /// Permit carried alongside the key; dropped by `into_inner`.
    pub permit: ProverPermit,
}
155
156impl<T> PreprocessedData<T> {
157 #[must_use]
162 #[inline]
163 pub unsafe fn into_inner(self) -> Arc<T> {
164 self.pk
165 }
166}
167
/// The sub-provers shared (behind an `Arc`) by every clone of a [`ShardProver`].
pub struct ShardProverInner<
    GC: IopCtx,
    SC: ShardContext<GC>,
    C: MultilinearPcsProver<GC, PcsProof<GC, SC>>,
> {
    /// Generates preprocessed and main traces.
    pub trace_generator: DefaultTraceGenerator<GC::F, SC::Air, CpuBackend>,
    /// Produces the LogUp-GKR proof for a shard.
    pub logup_gkr_prover: GkrProverImpl<GC, SC>,
    /// Jagged PCS prover used to commit to traces and to prove the final evaluation claims.
    pub pcs_prover: JaggedProver<GC, PcsProof<GC, SC>, C>,
}
181
/// A CPU shard prover. All state lives behind an `Arc`, so clones share the same sub-provers.
pub struct ShardProver<
    GC: IopCtx,
    SC: ShardContext<GC>,
    C: MultilinearPcsProver<GC, PcsProof<GC, SC>>,
> {
    inner: Arc<ShardProverInner<GC, SC, C>>,
}
191
192impl<GC: IopCtx, SC: ShardContext<GC>, C: MultilinearPcsProver<GC, PcsProof<GC, SC>>> Clone
195 for ShardProver<GC, SC, C>
196{
197 fn clone(&self) -> Self {
198 Self { inner: Arc::clone(&self.inner) }
199 }
200}
201
202impl<GC: IopCtx, SC: ShardContext<GC>, C: MultilinearPcsProver<GC, PcsProof<GC, SC>>>
203 ShardProver<GC, SC, C>
204{
205 pub fn from_components(
207 trace_generator: DefaultTraceGenerator<GC::F, SC::Air, CpuBackend>,
208 logup_gkr_prover: GkrProverImpl<GC, SC>,
209 pcs_prover: JaggedProver<GC, PcsProof<GC, SC>, C>,
210 ) -> Self {
211 Self { inner: Arc::new(ShardProverInner { trace_generator, logup_gkr_prover, pcs_prover }) }
212 }
213
214 #[must_use]
216 pub fn trace_generator(&self) -> &DefaultTraceGenerator<GC::F, SC::Air, CpuBackend> {
217 &self.inner.trace_generator
218 }
219
220 #[must_use]
222 pub fn logup_gkr_prover(&self) -> &GkrProverImpl<GC, SC> {
223 &self.inner.logup_gkr_prover
224 }
225
226 #[must_use]
228 pub fn pcs_prover(&self) -> &JaggedProver<GC, PcsProof<GC, SC>, C> {
229 &self.inner.pcs_prover
230 }
231}
232
233impl<GC: IopCtx, SC: ShardContext<GC>, C: DefaultJaggedProver<GC, SC::Config>> AirProver<GC, SC>
234 for ShardProver<GC, SC, C>
235{
236 type PreprocessedData = ShardProverData<GC, SC, C>;
237
238 fn machine(&self) -> &Machine<GC::F, SC::Air> {
239 self.inner.trace_generator.machine()
240 }
241
242 async fn setup_from_vk(
244 &self,
245 program: Arc<Program<GC, SC>>,
246 vk: Option<MachineVerifyingKey<GC>>,
247 prover_permits: ProverSemaphore,
248 ) -> (PreprocessedData<ProvingKey<GC, SC, Self>>, MachineVerifyingKey<GC>) {
249 if let Some(vk) = vk {
250 let initial_global_cumulative_sum = vk.initial_global_cumulative_sum;
251 self.setup_with_initial_global_cumulative_sum(
252 program,
253 initial_global_cumulative_sum,
254 prover_permits,
255 )
256 .await
257 } else {
258 let program_sent = program.clone();
259 let initial_global_cumulative_sum =
260 tokio::task::spawn_blocking(move || program_sent.initial_global_cumulative_sum())
261 .await
262 .unwrap();
263 self.setup_with_initial_global_cumulative_sum(
264 program,
265 initial_global_cumulative_sum,
266 prover_permits,
267 )
268 .await
269 }
270 }
271
272 async fn setup_and_prove_shard(
274 &self,
275 program: Arc<Program<GC, SC>>,
276 record: Record<GC, SC>,
277 vk: Option<MachineVerifyingKey<GC>>,
278 prover_permits: ProverSemaphore,
279 ) -> (MachineVerifyingKey<GC>, ShardProof<GC, PcsProof<GC, SC>>, ProverPermit) {
280 let pc_start = program.pc_start();
282 let enable_untrusted_programs = program.enable_untrusted_programs();
283 let initial_global_cumulative_sum = if let Some(vk) = vk {
284 vk.initial_global_cumulative_sum
285 } else {
286 let program = program.clone();
287 tokio::task::spawn_blocking(move || program.initial_global_cumulative_sum())
288 .instrument(tracing::debug_span!("initial_global_cumulative_sum"))
289 .await
290 .unwrap()
291 };
292
293 let trace_data = self
295 .inner
296 .trace_generator
297 .generate_traces(program, record, self.max_log_row_count(), prover_permits)
298 .instrument(tracing::debug_span!("generate full traces"))
299 .await;
300
301 let TraceData { preprocessed_traces, main_trace_data } = trace_data;
302
303 let (pk, vk) = {
304 let _span = tracing::debug_span!("setup_from_preprocessed_data_and_traces").entered();
305 self.setup_from_preprocessed_data_and_traces(
306 pc_start,
307 initial_global_cumulative_sum,
308 preprocessed_traces,
309 enable_untrusted_programs,
310 )
311 };
312
313 let pk = ProvingKey { vk: vk.clone(), preprocessed_data: pk };
314
315 let pk = Arc::new(pk);
316
317 let mut challenger = GC::default_challenger();
319 vk.observe_into(&mut challenger);
321
322 let shard_data = ShardData { pk, main_trace_data };
323
324 let prover = self.clone();
325 let (shard_proof, permit) = tokio::task::spawn_blocking(move || {
326 let _span = tracing::debug_span!("prove shard with data").entered();
327 prover.prove_shard_with_data(shard_data, challenger)
328 })
329 .await
330 .unwrap();
331
332 (vk, shard_proof, permit)
333 }
334
335 async fn prove_shard_with_pk(
337 &self,
338 pk: Arc<ProvingKey<GC, SC, Self>>,
339 record: Record<GC, SC>,
340 prover_permits: ProverSemaphore,
341 ) -> (ShardProof<GC, PcsProof<GC, SC>>, ProverPermit) {
342 let mut challenger = GC::default_challenger();
343 pk.vk.observe_into(&mut challenger);
344 let main_trace_data = self
346 .inner
347 .trace_generator
348 .generate_main_traces(record, self.max_log_row_count(), prover_permits)
349 .instrument(tracing::debug_span!("generate main traces"))
350 .await;
351
352 let shard_data = ShardData { pk, main_trace_data };
353
354 let prover = self.clone();
355 tokio::task::spawn_blocking(move || {
356 let _span = tracing::debug_span!("prove shard with data").entered();
357 prover.prove_shard_with_data(shard_data, challenger)
358 })
359 .await
360 .unwrap()
361 }
362
363 async fn preprocessed_table_heights(
364 pk: Arc<super::ProvingKey<GC, SC, Self>>,
365 ) -> BTreeMap<String, usize> {
366 std::future::ready(
367 pk.preprocessed_data
368 .preprocessed_traces
369 .iter()
370 .map(|(name, trace)| (name.to_owned(), trace.num_real_entries()))
371 .collect(),
372 )
373 .await
374 }
375}
376
377impl<GC: IopCtx, SC: ShardContext<GC>, C: DefaultJaggedProver<GC, SC::Config>>
378 ShardProver<GC, SC, C>
379{
380 #[must_use]
382 pub fn all_chips(&self) -> &[Chip<GC::F, SC::Air>] {
383 self.inner.trace_generator.machine().chips()
384 }
385
386 #[must_use]
388 pub fn machine(&self) -> &Machine<GC::F, SC::Air> {
389 self.inner.trace_generator.machine()
390 }
391
392 #[must_use]
394 pub fn num_pv_elts(&self) -> usize {
395 self.inner.trace_generator.machine().num_pv_elts()
396 }
397
398 #[inline]
400 #[must_use]
401 pub fn max_log_row_count(&self) -> usize {
402 self.inner.pcs_prover.max_log_row_count
403 }
404
405 pub fn setup_from_preprocessed_data_and_traces(
407 &self,
408 pc_start: [GC::F; 3],
409 initial_global_cumulative_sum: SepticDigest<GC::F>,
410 preprocessed_traces: Traces<GC::F, CpuBackend>,
411 enable_untrusted_programs: GC::F,
412 ) -> (ShardProverData<GC, SC, C>, MachineVerifyingKey<GC>) {
413 assert!(!preprocessed_traces.is_empty(), "preprocessed trace cannot be empty");
415 let message = preprocessed_traces.values().cloned().collect::<Vec<_>>();
416 let (preprocessed_commit, preprocessed_data) =
417 self.inner.pcs_prover.commit_multilinears(message).unwrap();
418
419 let vk = MachineVerifyingKey {
420 pc_start,
421 initial_global_cumulative_sum,
422 preprocessed_commit,
423 enable_untrusted_programs,
424 };
425
426 let pk = ShardProverData { preprocessed_traces, preprocessed_data };
427
428 (pk, vk)
429 }
430
431 pub async fn setup_with_initial_global_cumulative_sum(
433 &self,
434 program: Arc<Program<GC, SC>>,
435 initial_global_cumulative_sum: SepticDigest<GC::F>,
436 setup_permits: ProverSemaphore,
437 ) -> (PreprocessedData<ProvingKey<GC, SC, Self>>, MachineVerifyingKey<GC>) {
438 let pc_start = program.pc_start();
439 let enable_untrusted_programs = program.enable_untrusted_programs();
440 let preprocessed_data = self
441 .inner
442 .trace_generator
443 .generate_preprocessed_traces(program, self.max_log_row_count(), setup_permits)
444 .await;
445
446 let PreprocessedTraceData { preprocessed_traces, permit } = preprocessed_data;
447
448 let (pk, vk) = self.setup_from_preprocessed_data_and_traces(
449 pc_start,
450 initial_global_cumulative_sum,
451 preprocessed_traces,
452 enable_untrusted_programs,
453 );
454
455 let pk = ProvingKey { vk: vk.clone(), preprocessed_data: pk };
456
457 let pk = Arc::new(pk);
458
459 (PreprocessedData { pk, permit }, vk)
460 }
461
462 fn commit_traces(
463 &self,
464 traces: &Traces<GC::F, CpuBackend>,
465 ) -> (GC::Digest, JaggedProverData<GC, C::ProverData>) {
466 let message = traces.values().cloned().collect::<Vec<_>>();
467 self.inner.pcs_prover.commit_multilinears(message).unwrap()
468 }
469
470 #[allow(clippy::too_many_arguments)]
471 #[allow(clippy::too_many_lines)]
472 #[allow(clippy::type_complexity)]
473 #[allow(clippy::needless_pass_by_value)]
474 fn zerocheck(
475 &self,
476 chips: &BTreeSet<Chip<GC::F, SC::Air>>,
477 preprocessed_traces: Traces<GC::F, CpuBackend>,
478 traces: Traces<GC::F, CpuBackend>,
479 batching_challenge: GC::EF,
480 gkr_opening_batch_randomness: GC::EF,
481 logup_evaluations: &LogUpEvaluations<GC::EF>,
482 public_values: Vec<GC::F>,
483 challenger: &mut GC::Challenger,
484 ) -> (ShardOpenedValues<GC::F, GC::EF>, PartialSumcheckProof<GC::EF>) {
485 let max_num_constraints =
486 itertools::max(chips.iter().map(|chip| chip.num_constraints)).unwrap();
487 let powers_of_challenge =
488 batching_challenge.powers().take(max_num_constraints).collect::<Vec<_>>();
489 let airs =
490 chips.iter().map(|chip| (chip.air.clone(), chip.num_constraints)).collect::<Vec<_>>();
491
492 let public_values = Arc::new(public_values);
493
494 let mut zerocheck_polys = Vec::new();
495 let mut chip_sumcheck_claims = Vec::new();
496
497 let LogUpEvaluations { point: gkr_point, chip_openings } = logup_evaluations;
498
499 let mut chip_heights = BTreeMap::new();
500 for ((air, num_constraints), chip) in airs.iter().cloned().zip_eq(chips.iter()) {
501 let ChipEvaluation {
502 main_trace_evaluations: main_opening,
503 preprocessed_trace_evaluations: prep_opening,
504 } = chip_openings.get(chip.name()).unwrap();
505
506 let main_trace = traces.get(air.name()).unwrap().clone();
507 let num_real_entries = main_trace.num_real_entries();
508
509 let threshold_point =
510 Point::from_usize(num_real_entries, self.inner.pcs_prover.max_log_row_count + 1);
511 chip_heights.insert(air.name().to_string(), threshold_point);
512 let name = air.name();
513 let num_variables = main_trace.num_variables();
514 assert_eq!(num_variables, self.inner.pcs_prover.max_log_row_count as u32);
515
516 let preprocessed_width = air.preprocessed_width();
517 let dummy_preprocessed_trace = vec![GC::F::zero(); preprocessed_width];
518 let dummy_main_trace = vec![GC::F::zero(); main_trace.num_polynomials()];
519
520 let mut chip_powers_of_alpha = powers_of_challenge[0..num_constraints].to_vec();
524 chip_powers_of_alpha.reverse();
525
526 let mut folder = ConstraintSumcheckFolder {
527 preprocessed: RowMajorMatrixView::new_row(&dummy_preprocessed_trace),
528 main: RowMajorMatrixView::new_row(&dummy_main_trace),
529 accumulator: GC::EF::zero(),
530 public_values: &public_values,
531 constraint_index: 0,
532 powers_of_alpha: &chip_powers_of_alpha,
533 };
534
535 air.eval(&mut folder);
536 let padded_row_adjustment = folder.accumulator;
537
538 let gkr_opening_batch_randomness_powers = gkr_opening_batch_randomness
542 .powers()
543 .skip(1)
544 .take(
545 main_opening.num_polynomials()
546 + prep_opening.as_ref().map_or(0, MleEval::num_polynomials),
547 )
548 .collect::<Vec<_>>();
549 let gkr_powers = Arc::new(gkr_opening_batch_randomness_powers);
550
551 let alpha_powers = Arc::new(chip_powers_of_alpha);
552 let air_data = ZerocheckCpuProverData::round_prover(
553 air,
554 public_values.clone(),
555 alpha_powers,
556 gkr_powers.clone(),
557 );
558 let preprocessed_trace = preprocessed_traces.get(name).cloned();
559
560 let chip_sumcheck_claim = main_opening
561 .evaluations()
562 .as_slice()
563 .iter()
564 .chain(
565 prep_opening
566 .as_ref()
567 .map_or_else(Vec::new, |mle| mle.evaluations().as_slice().to_vec())
568 .iter(),
569 )
570 .zip(gkr_powers.iter())
571 .map(|(opening, power)| *opening * *power)
572 .sum::<GC::EF>();
573
574 let initial_geq_value =
575 if main_trace.num_real_entries() > 0 { GC::EF::zero() } else { GC::EF::one() };
576
577 let virtual_geq = VirtualGeq::new(
578 main_trace.num_real_entries() as u32,
579 GC::F::one(),
580 GC::F::zero(),
581 self.inner.pcs_prover.max_log_row_count as u32,
582 );
583
584 let zerocheck_poly = ZeroCheckPoly::new(
585 air_data,
586 gkr_point.clone(),
587 preprocessed_trace,
588 main_trace,
589 GC::EF::one(),
590 initial_geq_value,
591 padded_row_adjustment,
592 virtual_geq,
593 );
594 zerocheck_polys.push(zerocheck_poly);
595 chip_sumcheck_claims.push(chip_sumcheck_claim);
596 }
597
598 let lambda = challenger.sample_ext_element::<GC::EF>();
600
601 let (partial_sumcheck_proof, component_poly_evals) = reduce_sumcheck_to_evaluation(
603 zerocheck_polys,
604 challenger,
605 chip_sumcheck_claims,
606 1,
607 lambda,
608 );
609
610 let mut point_extended = partial_sumcheck_proof.point_and_eval.0.clone();
611 point_extended.add_dimension(GC::EF::zero());
612
613 debug_assert_eq!(component_poly_evals.len(), airs.len());
616 let len = airs.len();
617 challenger.observe(GC::F::from_canonical_usize(len));
618 let shard_open_values = airs
619 .into_iter()
620 .zip_eq(component_poly_evals)
621 .map(|((air, _), evals)| {
622 let (preprocessed_evals, main_evals) = evals.split_at(air.preprocessed_width());
623
624 challenger.observe_variable_length_extension_slice(preprocessed_evals);
626 challenger.observe_variable_length_extension_slice(main_evals);
627
628 let preprocessed = AirOpenedValues { local: preprocessed_evals.to_vec() };
629
630 let main = AirOpenedValues { local: main_evals.to_vec() };
631
632 (
633 air.name().to_string(),
634 ChipOpenedValues {
635 preprocessed,
636 main,
637 degree: chip_heights[air.name()].clone(),
638 },
639 )
640 })
641 .collect::<BTreeMap<_, _>>();
642
643 let shard_open_values = ShardOpenedValues { chips: shard_open_values };
644
645 (shard_open_values, partial_sumcheck_proof)
646 }
647
648 #[allow(clippy::type_complexity)]
650 pub fn prove_shard_with_data(
651 &self,
652 data: ShardData<GC, SC, C>,
653 mut challenger: GC::Challenger,
654 ) -> (ShardProof<GC, PcsProof<GC, SC>>, ProverPermit) {
655 let ShardData { pk, main_trace_data } = data;
656 let MainTraceData { traces, public_values, shard_chips, permit } = main_trace_data;
657
658 let mut total_number_of_cells = 0;
660 tracing::debug!("Proving shard");
661 for (chip, trace) in shard_chips.iter().zip_eq(traces.values()) {
662 let height = trace.num_real_entries();
663 let stats = ChipStatistics::new(chip, height);
664 tracing::debug!("{}", stats);
665 total_number_of_cells += stats.total_number_of_cells();
666 }
667
668 tracing::debug!(
669 "Total number of cells: {}, number of variables: {}",
670 total_number_of_cells.separate_with_underscores(),
671 total_number_of_cells.next_power_of_two().ilog2(),
672 );
673
674 challenger.observe_constant_length_slice(&public_values);
676
677 let (main_commit, main_data) = {
679 let _span = tracing::debug_span!("commit traces").entered();
680 self.commit_traces(&traces)
681 };
682 challenger.observe(main_commit);
684 challenger.observe(GC::F::from_canonical_usize(shard_chips.len()));
686
687 for chips in shard_chips.iter() {
688 let num_real_entries = traces.get(chips.air.name()).unwrap().num_real_entries();
689 challenger.observe(GC::F::from_canonical_usize(num_real_entries));
690 challenger.observe(GC::F::from_canonical_usize(chips.air.name().len()));
691 for byte in chips.air.name().as_bytes() {
692 challenger.observe(GC::F::from_canonical_u8(*byte));
693 }
694 }
695
696 let logup_gkr_proof = {
697 let _span = tracing::debug_span!("logup gkr proof").entered();
698 self.inner.logup_gkr_prover.prove_logup_gkr(
699 &shard_chips,
700 &pk.preprocessed_data.preprocessed_traces,
701 &traces,
702 public_values.clone(),
703 &mut challenger,
704 )
705 };
706 let batching_challenge = challenger.sample_ext_element::<GC::EF>();
708 let gkr_opening_batch_challenge = challenger.sample_ext_element::<GC::EF>();
710
711 #[cfg(sp1_debug_constraints)]
712 {
713 crate::debug::debug_constraints_all_chips::<GC, _>(
714 &shard_chips.iter().cloned().collect::<Vec<_>>(),
715 &pk.preprocessed_data.preprocessed_traces,
716 &traces,
717 &public_values,
718 );
719 }
720
721 let (shard_open_values, zerocheck_partial_sumcheck_proof) = {
723 let _span = tracing::debug_span!("zerocheck").entered();
724 self.zerocheck(
725 &shard_chips,
726 pk.preprocessed_data.preprocessed_traces.clone(),
727 traces,
728 batching_challenge,
729 gkr_opening_batch_challenge,
730 &logup_gkr_proof.logup_evaluations,
731 public_values.clone(),
732 &mut challenger,
733 )
734 };
735
736 let evaluation_point = zerocheck_partial_sumcheck_proof.point_and_eval.0.clone();
738 let mut preprocessed_evaluation_claims: Option<Evaluations<GC::EF, CpuBackend>> = None;
739 let mut main_evaluation_claims = Evaluations::new(vec![]);
740
741 let alloc = self.inner.trace_generator.allocator();
742
743 for (_, open_values) in shard_open_values.chips.iter() {
744 let prep_local = &open_values.preprocessed.local;
745 let main_local = &open_values.main.local;
746 if !prep_local.is_empty() {
747 let preprocessed_evals = alloc.copy_to(&MleEval::from(prep_local.clone())).unwrap();
748 if let Some(preprocessed_claims) = preprocessed_evaluation_claims.as_mut() {
749 preprocessed_claims.push(preprocessed_evals);
750 } else {
751 let evals = Evaluations::new(vec![preprocessed_evals]);
752 preprocessed_evaluation_claims = Some(evals);
753 }
754 }
755 let main_evals = alloc.copy_to(&MleEval::from(main_local.clone())).unwrap();
756 main_evaluation_claims.push(main_evals);
757 }
758
759 let round_evaluation_claims = preprocessed_evaluation_claims
760 .into_iter()
761 .chain(once(main_evaluation_claims))
762 .collect::<Rounds<_>>();
763
764 let round_prover_data = once(pk.preprocessed_data.preprocessed_data.clone())
765 .chain(once(main_data))
766 .collect::<Rounds<_>>();
767
768 let evaluation_proof = {
770 let _span = tracing::debug_span!("prove evaluation claims").entered();
771 self.inner
772 .pcs_prover
773 .prove_trusted_evaluations(
774 evaluation_point,
775 round_evaluation_claims,
776 round_prover_data,
777 &mut challenger,
778 )
779 .unwrap()
780 };
781
782 let proof = ShardProof {
783 main_commitment: main_commit,
784 opened_values: shard_open_values,
785 logup_gkr_proof,
786 evaluation_proof,
787 zerocheck_proof: zerocheck_partial_sumcheck_proof,
788 public_values,
789 };
790
791 (proof, permit)
792 }
793}
794
/// The shape of a core shard proof: the set of participating chips plus four size parameters.
///
/// NOTE(review): the exact meaning of the `*_multiple` and `*_padding_cols` fields is not
/// visible here — confirm against the code that constructs this type.
#[derive_where(Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct CoreProofShape<F: Field, A: MachineAir<F>> {
    /// The chips included in the shard.
    pub shard_chips: BTreeSet<Chip<F, A>>,

    /// Multiple associated with the preprocessed round (presumably a size multiple — confirm).
    pub preprocessed_multiple: usize,

    /// Multiple associated with the main round (presumably a size multiple — confirm).
    pub main_multiple: usize,

    /// Number of padding columns added to the preprocessed round (assumed from the name).
    pub preprocessed_padding_cols: usize,

    /// Number of padding columns added to the main round (assumed from the name).
    pub main_padding_cols: usize,
}
816
817impl<F, A> Debug for CoreProofShape<F, A>
818where
819 F: Field + Debug,
820 A: MachineAir<F> + Debug,
821{
822 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
823 f.debug_struct("ProofShape")
824 .field(
825 "shard_chips",
826 &self.shard_chips.iter().map(MachineAir::name).collect::<BTreeSet<_>>(),
827 )
828 .field("preprocessed_multiple", &self.preprocessed_multiple)
829 .field("main_multiple", &self.main_multiple)
830 .field("preprocessed_padding_cols", &self.preprocessed_padding_cols)
831 .field("main_padding_cols", &self.main_padding_cols)
832 .finish()
833 }
834}
835
/// Preprocessed data held in a `ShardProver` proving key: the preprocessed traces and the jagged
/// PCS prover data produced when committing to them.
///
/// The serde bounds are written out explicitly so that only the stored field types (rather than
/// `GC`, `SC` and `C` themselves) must be (de)serializable.
#[derive(Clone, Serialize, Deserialize)]
#[serde(bound(
    serialize = "Tensor<GC::F, CpuBackend>: Serialize, JaggedProverData<GC, C::ProverData>: Serialize, GC::F: Serialize,"
))]
#[serde(bound(
    deserialize = "Tensor<GC::F, CpuBackend>: Deserialize<'de>, JaggedProverData<GC, C::ProverData>: Deserialize<'de>, GC::F: Deserialize<'de>, "
))]
pub struct ShardProverData<
    GC: IopCtx,
    SC: ShardContext<GC>,
    C: MultilinearPcsProver<GC, PcsProof<GC, SC>>,
> {
    /// The preprocessed traces, keyed by chip name.
    pub preprocessed_traces: Traces<GC::F, CpuBackend>,
    /// Jagged PCS prover data returned when committing to `preprocessed_traces`.
    pub preprocessed_data: JaggedProverData<GC, C::ProverData>,
}