// File: sp1_hypercube/prover/shard.rs

1use derive_where::derive_where;
2use itertools::Itertools;
3use serde::{Deserialize, Serialize};
4use slop_air::Air;
5use slop_algebra::{AbstractField, Field};
6use slop_alloc::{Backend, CanCopyFromRef, CpuBackend};
7use slop_challenger::{CanObserve, FieldChallenger, IopCtx, VariableLengthChallenger};
8use slop_commit::Rounds;
9use slop_jagged::{DefaultJaggedProver, JaggedProver, JaggedProverData};
10use slop_matrix::dense::RowMajorMatrixView;
11use slop_multilinear::{
12    Evaluations, MleEval, MultilinearPcsProver, MultilinearPcsVerifier, Point, VirtualGeq,
13};
14use slop_sumcheck::{reduce_sumcheck_to_evaluation, PartialSumcheckProof};
15use slop_tensor::Tensor;
16use std::{
17    collections::{BTreeMap, BTreeSet},
18    fmt::Debug,
19    future::Future,
20    iter::once,
21    sync::Arc,
22};
23use thousands::Separable;
24use tracing::Instrument;
25
26use crate::{
27    air::{MachineAir, MachineProgram},
28    prover::{
29        DefaultTraceGenerator, Program, ProverPermit, ProverSemaphore, Record, ZeroCheckPoly,
30        ZerocheckCpuProverData,
31    },
32    septic_digest::SepticDigest,
33    AirOpenedValues, Chip, ChipEvaluation, ChipOpenedValues, ChipStatistics,
34    ConstraintSumcheckFolder, GkrProverImpl, LogUpEvaluations, Machine, MachineVerifyingKey,
35    ShardContext, ShardOpenedValues, ShardProof,
36};
37
38use super::{TraceGenerator, Traces};
39
/// The PCS proof type associated to a shard context.
///
/// This is the `Proof` type of the multilinear PCS verifier selected by the shard context's
/// `Config`. Both trait paths are spelled out explicitly in the projection.
pub type PcsProof<GC, SC> = <<SC as ShardContext<GC>>::Config as MultilinearPcsVerifier<GC>>::Proof;
42
/// A prover for an AIR.
///
/// Implementors produce proving/verifying keys for a [`Program`] and prove shards from
/// execution [`Record`]s. Every proving entry point is async, takes a [`ProverSemaphore`], and
/// hands back a [`ProverPermit`] (directly, or inside a [`PreprocessedData`]) alongside its
/// output. The permit is presumably acquired from the semaphore passed in — confirm against
/// the implementation.
#[allow(clippy::type_complexity)]
pub trait AirProver<GC: IopCtx, SC: ShardContext<GC>>: 'static + Send + Sync + Sized {
    /// The prover-specific preprocessed data stored in a [`ProvingKey`] (its
    /// `preprocessed_data` field).
    ///
    /// Note: this associated type is distinct from the module-level [`PreprocessedData`] struct.
    /// The method signatures below name that struct by its bare name, `PreprocessedData<..>`.
    /// The associated type would have to be written `Self::PreprocessedData`.
    type PreprocessedData: 'static + Send + Sync;

    /// Get the machine.
    fn machine(&self) -> &Machine<GC::F, SC::Air>;

    /// Setup from a verifying key.
    ///
    /// When `vk` is `Some`, implementations may reuse values from it instead of recomputing
    /// them from `program`. When it is `None`, everything is derived from `program`.
    /// Returns the proving key wrapped together with a permit, and the verifying key.
    fn setup_from_vk(
        &self,
        program: Arc<Program<GC, SC>>,
        vk: Option<MachineVerifyingKey<GC>>,
        prover_permits: ProverSemaphore,
    ) -> impl Future<Output = (PreprocessedData<ProvingKey<GC, SC, Self>>, MachineVerifyingKey<GC>)> + Send;

    /// Setup and prove a shard.
    ///
    /// Generates the keys for `program` (reusing `vk` if provided) and proves `record` in one
    /// pass. Returns the verifying key, the shard proof, and a prover permit.
    fn setup_and_prove_shard(
        &self,
        program: Arc<Program<GC, SC>>,
        record: Record<GC, SC>,
        vk: Option<MachineVerifyingKey<GC>>,
        prover_permits: ProverSemaphore,
    ) -> impl Future<
        Output = (MachineVerifyingKey<GC>, ShardProof<GC, PcsProof<GC, SC>>, ProverPermit),
    > + Send;

    /// Prove a shard with a given proving key.
    ///
    /// Returns the shard proof together with a prover permit.
    fn prove_shard_with_pk(
        &self,
        pk: Arc<ProvingKey<GC, SC, Self>>,
        record: Record<GC, SC>,
        prover_permits: ProverSemaphore,
    ) -> impl Future<Output = (ShardProof<GC, PcsProof<GC, SC>>, ProverPermit)> + Send;

    /// Get all the chips in the machine.
    ///
    /// The default implementation delegates to `self.machine().chips()`.
    fn all_chips(&self) -> &[Chip<GC::F, SC::Air>] {
        self.machine().chips()
    }

    /// Setup from a program.
    ///
    /// The setup phase produces a pair `(pk, vk)` of proving and verifying keys. The proving
    /// key consists of information used by the prover that only depends on the program itself
    /// and not on a specific execution.
    ///
    /// The default implementation is `self.setup_from_vk(program, None, setup_permits)`.
    fn setup(
        &self,
        program: Arc<Program<GC, SC>>,
        setup_permits: ProverSemaphore,
    ) -> impl Future<Output = (PreprocessedData<ProvingKey<GC, SC, Self>>, MachineVerifyingKey<GC>)> + Send
    {
        self.setup_from_vk(program, None, setup_permits)
    }

    /// Deduce the preprocessed table heights from the proving key, keyed by table name.
    fn preprocessed_table_heights(
        pk: Arc<ProvingKey<GC, SC, Self>>,
    ) -> impl Future<Output = BTreeMap<String, usize>> + Send;
}
102
/// A proving key for an AIR prover.
///
/// Pairs the machine verifying key with the prover-specific preprocessed data
/// (`Prover::PreprocessedData`).
pub struct ProvingKey<GC: IopCtx, SC: ShardContext<GC>, Prover: AirProver<GC, SC>> {
    /// The verifying key.
    pub vk: MachineVerifyingKey<GC>,
    /// The prover-specific preprocessed data.
    pub preprocessed_data: Prover::PreprocessedData,
}
110
/// Everything needed to prove one shard: the proving key and the shard's main trace data.
///
/// The prover permit travels inside `main_trace_data`.
#[allow(clippy::type_complexity)]
pub struct ShardData<GC: IopCtx, SC: ShardContext<GC>, C: DefaultJaggedProver<GC, SC::Config>> {
    /// The proving key.
    pub pk: Arc<ProvingKey<GC, SC, ShardProver<GC, SC, C>>>,
    /// The main trace data (traces, public values, chips and permit) for this shard.
    pub main_trace_data: MainTraceData<GC::F, SC::Air, CpuBackend>,
}
119
/// The main traces for a shard, with the public values, the chips involved and a permit.
pub struct MainTraceData<F: Field, A: MachineAir<F>, B: Backend> {
    /// The main traces, keyed by chip name.
    pub traces: Traces<F, B>,
    /// The public values.
    pub public_values: Vec<F>,
    /// The chips present in this shard (the shard's shape).
    ///
    /// NOTE(review): `ShardProver::prove_shard_with_data` pairs these positionally with
    /// `traces.values()` via `zip_eq`. Both collections must therefore have the same length
    /// and iteration order.
    pub shard_chips: BTreeSet<Chip<F, A>>,
    /// A permit for a prover resource.
    pub permit: ProverPermit,
}
131
/// The total trace data for a shard: preprocessed traces plus main trace data.
pub struct TraceData<F: Field, A: MachineAir<F>, B: Backend> {
    /// The preprocessed traces.
    pub preprocessed_traces: Traces<F, B>,
    /// The main traces, together with public values, shard chips and the permit.
    pub main_trace_data: MainTraceData<F, A, B>,
}
139
/// The preprocessed traces for a program, together with the permit held while they exist.
pub struct PreprocessedTraceData<F: Field, B: Backend> {
    /// The preprocessed traces.
    pub preprocessed_traces: Traces<F, B>,
    /// A permit for a prover resource.
    pub permit: ProverPermit,
}
147
/// A shared proving key bundled with the prover permit obtained while producing it.
///
/// Not to be confused with the `AirProver::PreprocessedData` associated type. This struct
/// wraps a whole key (e.g. a `ProvingKey`), not just its preprocessed part.
pub struct PreprocessedData<T> {
    /// The proving key.
    pub pk: Arc<T>,
    /// A permit for a prover resource.
    pub permit: ProverPermit,
}
155
156impl<T> PreprocessedData<T> {
157    /// Unsafely take the inner proving key.
158    ///
159    /// # Safety
160    /// This is unsafe because the permit is dropped.
161    #[must_use]
162    #[inline]
163    pub unsafe fn into_inner(self) -> Arc<T> {
164        self.pk
165    }
166}
167
/// Inner struct containing the actual prover data.
///
/// Shared behind an `Arc` by every clone of a [`ShardProver`].
pub struct ShardProverInner<
    GC: IopCtx,
    SC: ShardContext<GC>,
    C: MultilinearPcsProver<GC, PcsProof<GC, SC>>,
> {
    /// The trace generator (also the source of the machine description).
    pub trace_generator: DefaultTraceGenerator<GC::F, SC::Air, CpuBackend>,
    /// The logup GKR prover.
    pub logup_gkr_prover: GkrProverImpl<GC, SC>,
    /// A jagged prover for the PCS; it also holds the `max_log_row_count` used by the shard prover.
    pub pcs_prover: JaggedProver<GC, PcsProof<GC, SC>, C>,
}
181
/// A prover for the hypercube STARK, given a configuration.
///
/// The state lives in an `Arc`, so cloning is a reference-count bump. This lets the prover be
/// moved into `spawn_blocking` closures cheaply.
pub struct ShardProver<
    GC: IopCtx,
    SC: ShardContext<GC>,
    C: MultilinearPcsProver<GC, PcsProof<GC, SC>>,
> {
    // Shared prover components; see `ShardProverInner`.
    inner: Arc<ShardProverInner<GC, SC, C>>,
}
191
192// Implement Clone manually to avoid requiring Clone bounds on generic parameters.
193// Arc::clone doesn't need the inner type to be Clone.
194impl<GC: IopCtx, SC: ShardContext<GC>, C: MultilinearPcsProver<GC, PcsProof<GC, SC>>> Clone
195    for ShardProver<GC, SC, C>
196{
197    fn clone(&self) -> Self {
198        Self { inner: Arc::clone(&self.inner) }
199    }
200}
201
202impl<GC: IopCtx, SC: ShardContext<GC>, C: MultilinearPcsProver<GC, PcsProof<GC, SC>>>
203    ShardProver<GC, SC, C>
204{
205    /// Create a new `ShardProver` from its components.
206    pub fn from_components(
207        trace_generator: DefaultTraceGenerator<GC::F, SC::Air, CpuBackend>,
208        logup_gkr_prover: GkrProverImpl<GC, SC>,
209        pcs_prover: JaggedProver<GC, PcsProof<GC, SC>, C>,
210    ) -> Self {
211        Self { inner: Arc::new(ShardProverInner { trace_generator, logup_gkr_prover, pcs_prover }) }
212    }
213
214    /// Access the trace generator.
215    #[must_use]
216    pub fn trace_generator(&self) -> &DefaultTraceGenerator<GC::F, SC::Air, CpuBackend> {
217        &self.inner.trace_generator
218    }
219
220    /// Access the logup GKR prover.
221    #[must_use]
222    pub fn logup_gkr_prover(&self) -> &GkrProverImpl<GC, SC> {
223        &self.inner.logup_gkr_prover
224    }
225
226    /// Access the PCS prover.
227    #[must_use]
228    pub fn pcs_prover(&self) -> &JaggedProver<GC, PcsProof<GC, SC>, C> {
229        &self.inner.pcs_prover
230    }
231}
232
233impl<GC: IopCtx, SC: ShardContext<GC>, C: DefaultJaggedProver<GC, SC::Config>> AirProver<GC, SC>
234    for ShardProver<GC, SC, C>
235{
236    type PreprocessedData = ShardProverData<GC, SC, C>;
237
238    fn machine(&self) -> &Machine<GC::F, SC::Air> {
239        self.inner.trace_generator.machine()
240    }
241
242    /// Setup a shard, using a verifying key if provided.
243    async fn setup_from_vk(
244        &self,
245        program: Arc<Program<GC, SC>>,
246        vk: Option<MachineVerifyingKey<GC>>,
247        prover_permits: ProverSemaphore,
248    ) -> (PreprocessedData<ProvingKey<GC, SC, Self>>, MachineVerifyingKey<GC>) {
249        if let Some(vk) = vk {
250            let initial_global_cumulative_sum = vk.initial_global_cumulative_sum;
251            self.setup_with_initial_global_cumulative_sum(
252                program,
253                initial_global_cumulative_sum,
254                prover_permits,
255            )
256            .await
257        } else {
258            let program_sent = program.clone();
259            let initial_global_cumulative_sum =
260                tokio::task::spawn_blocking(move || program_sent.initial_global_cumulative_sum())
261                    .await
262                    .unwrap();
263            self.setup_with_initial_global_cumulative_sum(
264                program,
265                initial_global_cumulative_sum,
266                prover_permits,
267            )
268            .await
269        }
270    }
271
272    /// Setup and prove a shard.
273    async fn setup_and_prove_shard(
274        &self,
275        program: Arc<Program<GC, SC>>,
276        record: Record<GC, SC>,
277        vk: Option<MachineVerifyingKey<GC>>,
278        prover_permits: ProverSemaphore,
279    ) -> (MachineVerifyingKey<GC>, ShardProof<GC, PcsProof<GC, SC>>, ProverPermit) {
280        // Get the initial global cumulative sum and pc start.
281        let pc_start = program.pc_start();
282        let enable_untrusted_programs = program.enable_untrusted_programs();
283        let initial_global_cumulative_sum = if let Some(vk) = vk {
284            vk.initial_global_cumulative_sum
285        } else {
286            let program = program.clone();
287            tokio::task::spawn_blocking(move || program.initial_global_cumulative_sum())
288                .instrument(tracing::debug_span!("initial_global_cumulative_sum"))
289                .await
290                .unwrap()
291        };
292
293        // Generate trace.
294        let trace_data = self
295            .inner
296            .trace_generator
297            .generate_traces(program, record, self.max_log_row_count(), prover_permits)
298            .instrument(tracing::debug_span!("generate full traces"))
299            .await;
300
301        let TraceData { preprocessed_traces, main_trace_data } = trace_data;
302
303        let (pk, vk) = {
304            let _span = tracing::debug_span!("setup_from_preprocessed_data_and_traces").entered();
305            self.setup_from_preprocessed_data_and_traces(
306                pc_start,
307                initial_global_cumulative_sum,
308                preprocessed_traces,
309                enable_untrusted_programs,
310            )
311        };
312
313        let pk = ProvingKey { vk: vk.clone(), preprocessed_data: pk };
314
315        let pk = Arc::new(pk);
316
317        // Create a challenger.
318        let mut challenger = GC::default_challenger();
319        // Observe the preprocessed information.
320        vk.observe_into(&mut challenger);
321
322        let shard_data = ShardData { pk, main_trace_data };
323
324        let prover = self.clone();
325        let (shard_proof, permit) = tokio::task::spawn_blocking(move || {
326            let _span = tracing::debug_span!("prove shard with data").entered();
327            prover.prove_shard_with_data(shard_data, challenger)
328        })
329        .await
330        .unwrap();
331
332        (vk, shard_proof, permit)
333    }
334
335    /// Prove a shard with a given proving key.
336    async fn prove_shard_with_pk(
337        &self,
338        pk: Arc<ProvingKey<GC, SC, Self>>,
339        record: Record<GC, SC>,
340        prover_permits: ProverSemaphore,
341    ) -> (ShardProof<GC, PcsProof<GC, SC>>, ProverPermit) {
342        let mut challenger = GC::default_challenger();
343        pk.vk.observe_into(&mut challenger);
344        // Generate the traces.
345        let main_trace_data = self
346            .inner
347            .trace_generator
348            .generate_main_traces(record, self.max_log_row_count(), prover_permits)
349            .instrument(tracing::debug_span!("generate main traces"))
350            .await;
351
352        let shard_data = ShardData { pk, main_trace_data };
353
354        let prover = self.clone();
355        tokio::task::spawn_blocking(move || {
356            let _span = tracing::debug_span!("prove shard with data").entered();
357            prover.prove_shard_with_data(shard_data, challenger)
358        })
359        .await
360        .unwrap()
361    }
362
363    async fn preprocessed_table_heights(
364        pk: Arc<super::ProvingKey<GC, SC, Self>>,
365    ) -> BTreeMap<String, usize> {
366        std::future::ready(
367            pk.preprocessed_data
368                .preprocessed_traces
369                .iter()
370                .map(|(name, trace)| (name.to_owned(), trace.num_real_entries()))
371                .collect(),
372        )
373        .await
374    }
375}
376
377impl<GC: IopCtx, SC: ShardContext<GC>, C: DefaultJaggedProver<GC, SC::Config>>
378    ShardProver<GC, SC, C>
379{
380    /// Get all the chips in the machine.
381    #[must_use]
382    pub fn all_chips(&self) -> &[Chip<GC::F, SC::Air>] {
383        self.inner.trace_generator.machine().chips()
384    }
385
386    /// Get the machine.
387    #[must_use]
388    pub fn machine(&self) -> &Machine<GC::F, SC::Air> {
389        self.inner.trace_generator.machine()
390    }
391
392    /// Get the number of public values in the machine.
393    #[must_use]
394    pub fn num_pv_elts(&self) -> usize {
395        self.inner.trace_generator.machine().num_pv_elts()
396    }
397
398    /// Get the maximum log row count.
399    #[inline]
400    #[must_use]
401    pub fn max_log_row_count(&self) -> usize {
402        self.inner.pcs_prover.max_log_row_count
403    }
404
405    /// Setup from preprocessed data and traces.
406    pub fn setup_from_preprocessed_data_and_traces(
407        &self,
408        pc_start: [GC::F; 3],
409        initial_global_cumulative_sum: SepticDigest<GC::F>,
410        preprocessed_traces: Traces<GC::F, CpuBackend>,
411        enable_untrusted_programs: GC::F,
412    ) -> (ShardProverData<GC, SC, C>, MachineVerifyingKey<GC>) {
413        // Commit to the preprocessed traces, if there are any.
414        assert!(!preprocessed_traces.is_empty(), "preprocessed trace cannot be empty");
415        let message = preprocessed_traces.values().cloned().collect::<Vec<_>>();
416        let (preprocessed_commit, preprocessed_data) =
417            self.inner.pcs_prover.commit_multilinears(message).unwrap();
418
419        let vk = MachineVerifyingKey {
420            pc_start,
421            initial_global_cumulative_sum,
422            preprocessed_commit,
423            enable_untrusted_programs,
424        };
425
426        let pk = ShardProverData { preprocessed_traces, preprocessed_data };
427
428        (pk, vk)
429    }
430
431    /// Setup from a program with a specific initial global cumulative sum.
432    pub async fn setup_with_initial_global_cumulative_sum(
433        &self,
434        program: Arc<Program<GC, SC>>,
435        initial_global_cumulative_sum: SepticDigest<GC::F>,
436        setup_permits: ProverSemaphore,
437    ) -> (PreprocessedData<ProvingKey<GC, SC, Self>>, MachineVerifyingKey<GC>) {
438        let pc_start = program.pc_start();
439        let enable_untrusted_programs = program.enable_untrusted_programs();
440        let preprocessed_data = self
441            .inner
442            .trace_generator
443            .generate_preprocessed_traces(program, self.max_log_row_count(), setup_permits)
444            .await;
445
446        let PreprocessedTraceData { preprocessed_traces, permit } = preprocessed_data;
447
448        let (pk, vk) = self.setup_from_preprocessed_data_and_traces(
449            pc_start,
450            initial_global_cumulative_sum,
451            preprocessed_traces,
452            enable_untrusted_programs,
453        );
454
455        let pk = ProvingKey { vk: vk.clone(), preprocessed_data: pk };
456
457        let pk = Arc::new(pk);
458
459        (PreprocessedData { pk, permit }, vk)
460    }
461
462    fn commit_traces(
463        &self,
464        traces: &Traces<GC::F, CpuBackend>,
465    ) -> (GC::Digest, JaggedProverData<GC, C::ProverData>) {
466        let message = traces.values().cloned().collect::<Vec<_>>();
467        self.inner.pcs_prover.commit_multilinears(message).unwrap()
468    }
469
470    #[allow(clippy::too_many_arguments)]
471    #[allow(clippy::too_many_lines)]
472    #[allow(clippy::type_complexity)]
473    #[allow(clippy::needless_pass_by_value)]
474    fn zerocheck(
475        &self,
476        chips: &BTreeSet<Chip<GC::F, SC::Air>>,
477        preprocessed_traces: Traces<GC::F, CpuBackend>,
478        traces: Traces<GC::F, CpuBackend>,
479        batching_challenge: GC::EF,
480        gkr_opening_batch_randomness: GC::EF,
481        logup_evaluations: &LogUpEvaluations<GC::EF>,
482        public_values: Vec<GC::F>,
483        challenger: &mut GC::Challenger,
484    ) -> (ShardOpenedValues<GC::F, GC::EF>, PartialSumcheckProof<GC::EF>) {
485        let max_num_constraints =
486            itertools::max(chips.iter().map(|chip| chip.num_constraints)).unwrap();
487        let powers_of_challenge =
488            batching_challenge.powers().take(max_num_constraints).collect::<Vec<_>>();
489        let airs =
490            chips.iter().map(|chip| (chip.air.clone(), chip.num_constraints)).collect::<Vec<_>>();
491
492        let public_values = Arc::new(public_values);
493
494        let mut zerocheck_polys = Vec::new();
495        let mut chip_sumcheck_claims = Vec::new();
496
497        let LogUpEvaluations { point: gkr_point, chip_openings } = logup_evaluations;
498
499        let mut chip_heights = BTreeMap::new();
500        for ((air, num_constraints), chip) in airs.iter().cloned().zip_eq(chips.iter()) {
501            let ChipEvaluation {
502                main_trace_evaluations: main_opening,
503                preprocessed_trace_evaluations: prep_opening,
504            } = chip_openings.get(chip.name()).unwrap();
505
506            let main_trace = traces.get(air.name()).unwrap().clone();
507            let num_real_entries = main_trace.num_real_entries();
508
509            let threshold_point =
510                Point::from_usize(num_real_entries, self.inner.pcs_prover.max_log_row_count + 1);
511            chip_heights.insert(air.name().to_string(), threshold_point);
512            let name = air.name();
513            let num_variables = main_trace.num_variables();
514            assert_eq!(num_variables, self.inner.pcs_prover.max_log_row_count as u32);
515
516            let preprocessed_width = air.preprocessed_width();
517            let dummy_preprocessed_trace = vec![GC::F::zero(); preprocessed_width];
518            let dummy_main_trace = vec![GC::F::zero(); main_trace.num_polynomials()];
519
520            // Calculate powers of alpha for constraint evaluation:
521            // 1. Generate sequence [α⁰, α¹, ..., α^(n-1)] where n = num_constraints.
522            // 2. Reverse to [α^(n-1), ..., α¹, α⁰] to align with Horner's method in the verifier.
523            let mut chip_powers_of_alpha = powers_of_challenge[0..num_constraints].to_vec();
524            chip_powers_of_alpha.reverse();
525
526            let mut folder = ConstraintSumcheckFolder {
527                preprocessed: RowMajorMatrixView::new_row(&dummy_preprocessed_trace),
528                main: RowMajorMatrixView::new_row(&dummy_main_trace),
529                accumulator: GC::EF::zero(),
530                public_values: &public_values,
531                constraint_index: 0,
532                powers_of_alpha: &chip_powers_of_alpha,
533            };
534
535            air.eval(&mut folder);
536            let padded_row_adjustment = folder.accumulator;
537
538            // TODO: This could be computed once for the maximally wide chip and stored for later
539            // use, but since it's a computation that's done once per chip, we have chosen not to
540            // perform this optimization for now.
541            let gkr_opening_batch_randomness_powers = gkr_opening_batch_randomness
542                .powers()
543                .skip(1)
544                .take(
545                    main_opening.num_polynomials()
546                        + prep_opening.as_ref().map_or(0, MleEval::num_polynomials),
547                )
548                .collect::<Vec<_>>();
549            let gkr_powers = Arc::new(gkr_opening_batch_randomness_powers);
550
551            let alpha_powers = Arc::new(chip_powers_of_alpha);
552            let air_data = ZerocheckCpuProverData::round_prover(
553                air,
554                public_values.clone(),
555                alpha_powers,
556                gkr_powers.clone(),
557            );
558            let preprocessed_trace = preprocessed_traces.get(name).cloned();
559
560            let chip_sumcheck_claim = main_opening
561                .evaluations()
562                .as_slice()
563                .iter()
564                .chain(
565                    prep_opening
566                        .as_ref()
567                        .map_or_else(Vec::new, |mle| mle.evaluations().as_slice().to_vec())
568                        .iter(),
569                )
570                .zip(gkr_powers.iter())
571                .map(|(opening, power)| *opening * *power)
572                .sum::<GC::EF>();
573
574            let initial_geq_value =
575                if main_trace.num_real_entries() > 0 { GC::EF::zero() } else { GC::EF::one() };
576
577            let virtual_geq = VirtualGeq::new(
578                main_trace.num_real_entries() as u32,
579                GC::F::one(),
580                GC::F::zero(),
581                self.inner.pcs_prover.max_log_row_count as u32,
582            );
583
584            let zerocheck_poly = ZeroCheckPoly::new(
585                air_data,
586                gkr_point.clone(),
587                preprocessed_trace,
588                main_trace,
589                GC::EF::one(),
590                initial_geq_value,
591                padded_row_adjustment,
592                virtual_geq,
593            );
594            zerocheck_polys.push(zerocheck_poly);
595            chip_sumcheck_claims.push(chip_sumcheck_claim);
596        }
597
598        // Same lambda for the RLC of the zerocheck polynomials.
599        let lambda = challenger.sample_ext_element::<GC::EF>();
600
601        // Compute the sumcheck proof for the zerocheck polynomials.
602        let (partial_sumcheck_proof, component_poly_evals) = reduce_sumcheck_to_evaluation(
603            zerocheck_polys,
604            challenger,
605            chip_sumcheck_claims,
606            1,
607            lambda,
608        );
609
610        let mut point_extended = partial_sumcheck_proof.point_and_eval.0.clone();
611        point_extended.add_dimension(GC::EF::zero());
612
613        // Compute the chip openings from the component poly evaluations.
614
615        debug_assert_eq!(component_poly_evals.len(), airs.len());
616        let len = airs.len();
617        challenger.observe(GC::F::from_canonical_usize(len));
618        let shard_open_values = airs
619            .into_iter()
620            .zip_eq(component_poly_evals)
621            .map(|((air, _), evals)| {
622                let (preprocessed_evals, main_evals) = evals.split_at(air.preprocessed_width());
623
624                // Observe the openings
625                challenger.observe_variable_length_extension_slice(preprocessed_evals);
626                challenger.observe_variable_length_extension_slice(main_evals);
627
628                let preprocessed = AirOpenedValues { local: preprocessed_evals.to_vec() };
629
630                let main = AirOpenedValues { local: main_evals.to_vec() };
631
632                (
633                    air.name().to_string(),
634                    ChipOpenedValues {
635                        preprocessed,
636                        main,
637                        degree: chip_heights[air.name()].clone(),
638                    },
639                )
640            })
641            .collect::<BTreeMap<_, _>>();
642
643        let shard_open_values = ShardOpenedValues { chips: shard_open_values };
644
645        (shard_open_values, partial_sumcheck_proof)
646    }
647
648    /// Generate a proof for a given execution record.
649    #[allow(clippy::type_complexity)]
650    pub fn prove_shard_with_data(
651        &self,
652        data: ShardData<GC, SC, C>,
653        mut challenger: GC::Challenger,
654    ) -> (ShardProof<GC, PcsProof<GC, SC>>, ProverPermit) {
655        let ShardData { pk, main_trace_data } = data;
656        let MainTraceData { traces, public_values, shard_chips, permit } = main_trace_data;
657
658        // Log the shard data.
659        let mut total_number_of_cells = 0;
660        tracing::debug!("Proving shard");
661        for (chip, trace) in shard_chips.iter().zip_eq(traces.values()) {
662            let height = trace.num_real_entries();
663            let stats = ChipStatistics::new(chip, height);
664            tracing::debug!("{}", stats);
665            total_number_of_cells += stats.total_number_of_cells();
666        }
667
668        tracing::debug!(
669            "Total number of cells: {}, number of variables: {}",
670            total_number_of_cells.separate_with_underscores(),
671            total_number_of_cells.next_power_of_two().ilog2(),
672        );
673
674        // Observe the public values.
675        challenger.observe_constant_length_slice(&public_values);
676
677        // Commit to the traces.
678        let (main_commit, main_data) = {
679            let _span = tracing::debug_span!("commit traces").entered();
680            self.commit_traces(&traces)
681        };
682        // Observe the commitments.
683        challenger.observe(main_commit);
684        // Observe the number of chips.
685        challenger.observe(GC::F::from_canonical_usize(shard_chips.len()));
686
687        for chips in shard_chips.iter() {
688            let num_real_entries = traces.get(chips.air.name()).unwrap().num_real_entries();
689            challenger.observe(GC::F::from_canonical_usize(num_real_entries));
690            challenger.observe(GC::F::from_canonical_usize(chips.air.name().len()));
691            for byte in chips.air.name().as_bytes() {
692                challenger.observe(GC::F::from_canonical_u8(*byte));
693            }
694        }
695
696        let logup_gkr_proof = {
697            let _span = tracing::debug_span!("logup gkr proof").entered();
698            self.inner.logup_gkr_prover.prove_logup_gkr(
699                &shard_chips,
700                &pk.preprocessed_data.preprocessed_traces,
701                &traces,
702                public_values.clone(),
703                &mut challenger,
704            )
705        };
706        // Get the challenge for batching constraints.
707        let batching_challenge = challenger.sample_ext_element::<GC::EF>();
708        // Get the challenge for batching the evaluations from the GKR proof.
709        let gkr_opening_batch_challenge = challenger.sample_ext_element::<GC::EF>();
710
711        #[cfg(sp1_debug_constraints)]
712        {
713            crate::debug::debug_constraints_all_chips::<GC, _>(
714                &shard_chips.iter().cloned().collect::<Vec<_>>(),
715                &pk.preprocessed_data.preprocessed_traces,
716                &traces,
717                &public_values,
718            );
719        }
720
721        // Generate the zerocheck proof.
722        let (shard_open_values, zerocheck_partial_sumcheck_proof) = {
723            let _span = tracing::debug_span!("zerocheck").entered();
724            self.zerocheck(
725                &shard_chips,
726                pk.preprocessed_data.preprocessed_traces.clone(),
727                traces,
728                batching_challenge,
729                gkr_opening_batch_challenge,
730                &logup_gkr_proof.logup_evaluations,
731                public_values.clone(),
732                &mut challenger,
733            )
734        };
735
736        // Get the evaluation point for the trace polynomials.
737        let evaluation_point = zerocheck_partial_sumcheck_proof.point_and_eval.0.clone();
738        let mut preprocessed_evaluation_claims: Option<Evaluations<GC::EF, CpuBackend>> = None;
739        let mut main_evaluation_claims = Evaluations::new(vec![]);
740
741        let alloc = self.inner.trace_generator.allocator();
742
743        for (_, open_values) in shard_open_values.chips.iter() {
744            let prep_local = &open_values.preprocessed.local;
745            let main_local = &open_values.main.local;
746            if !prep_local.is_empty() {
747                let preprocessed_evals = alloc.copy_to(&MleEval::from(prep_local.clone())).unwrap();
748                if let Some(preprocessed_claims) = preprocessed_evaluation_claims.as_mut() {
749                    preprocessed_claims.push(preprocessed_evals);
750                } else {
751                    let evals = Evaluations::new(vec![preprocessed_evals]);
752                    preprocessed_evaluation_claims = Some(evals);
753                }
754            }
755            let main_evals = alloc.copy_to(&MleEval::from(main_local.clone())).unwrap();
756            main_evaluation_claims.push(main_evals);
757        }
758
759        let round_evaluation_claims = preprocessed_evaluation_claims
760            .into_iter()
761            .chain(once(main_evaluation_claims))
762            .collect::<Rounds<_>>();
763
764        let round_prover_data = once(pk.preprocessed_data.preprocessed_data.clone())
765            .chain(once(main_data))
766            .collect::<Rounds<_>>();
767
768        // Generate the evaluation proof.
769        let evaluation_proof = {
770            let _span = tracing::debug_span!("prove evaluation claims").entered();
771            self.inner
772                .pcs_prover
773                .prove_trusted_evaluations(
774                    evaluation_point,
775                    round_evaluation_claims,
776                    round_prover_data,
777                    &mut challenger,
778                )
779                .unwrap()
780        };
781
782        let proof = ShardProof {
783            main_commitment: main_commit,
784            opened_values: shard_open_values,
785            logup_gkr_proof,
786            evaluation_proof,
787            zerocheck_proof: zerocheck_partial_sumcheck_proof,
788            public_values,
789        };
790
791        (proof, permit)
792    }
793}
794
/// The shape of the core proof. This and prover setup parameters should entirely determine the
/// verifier circuit.
#[derive_where(Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct CoreProofShape<F: Field, A: MachineAir<F>> {
    /// The chips included in the record.
    pub shard_chips: BTreeSet<Chip<F, A>>,

    /// The multiple of `log_stacking_height` that the preprocessed area adds up to.
    ///
    /// NOTE(review): the padding fields below speak of multiples of `stacking_height`, not
    /// `log_stacking_height`. Confirm which unit is meant here.
    pub preprocessed_multiple: usize,

    /// The multiple of `log_stacking_height` that the main area adds up to.
    ///
    /// NOTE(review): same unit question as `preprocessed_multiple`.
    pub main_multiple: usize,

    /// The number of columns added to the preprocessed commit to round to the nearest multiple of
    /// `stacking_height`.
    pub preprocessed_padding_cols: usize,

    /// The number of columns added to the main commit to round to the nearest multiple of
    /// `stacking_height`.
    pub main_padding_cols: usize,
}
816
817impl<F, A> Debug for CoreProofShape<F, A>
818where
819    F: Field + Debug,
820    A: MachineAir<F> + Debug,
821{
822    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
823        f.debug_struct("ProofShape")
824            .field(
825                "shard_chips",
826                &self.shard_chips.iter().map(MachineAir::name).collect::<BTreeSet<_>>(),
827            )
828            .field("preprocessed_multiple", &self.preprocessed_multiple)
829            .field("main_multiple", &self.main_multiple)
830            .field("preprocessed_padding_cols", &self.preprocessed_padding_cols)
831            .field("main_padding_cols", &self.main_padding_cols)
832            .finish()
833    }
834}
835
/// The prover-specific preprocessed data of a `ShardProver` (its
/// `AirProver::PreprocessedData`).
///
/// It holds the preprocessed traces and the PCS prover data produced by committing to them.
/// Serialization bounds are spelled out explicitly so that `GC`, `SC` and `C` themselves need
/// not be `Serialize`/`Deserialize`.
#[derive(Clone, Serialize, Deserialize)]
#[serde(bound(
    serialize = "Tensor<GC::F, CpuBackend>: Serialize, JaggedProverData<GC, C::ProverData>: Serialize, GC::F: Serialize,"
))]
#[serde(bound(
    deserialize = "Tensor<GC::F, CpuBackend>: Deserialize<'de>, JaggedProverData<GC, C::ProverData>: Deserialize<'de>, GC::F: Deserialize<'de>, "
))]
pub struct ShardProverData<
    GC: IopCtx,
    SC: ShardContext<GC>,
    C: MultilinearPcsProver<GC, PcsProof<GC, SC>>,
> {
    /// The preprocessed traces.
    pub preprocessed_traces: Traces<GC::F, CpuBackend>,
    /// The PCS prover data from committing to `preprocessed_traces`.
    pub preprocessed_data: JaggedProverData<GC, C::ProverData>,
}