sp1_stark/
prover.rs

1use crate::{
2    air::InteractionScope, septic_curve::SepticCurve, septic_digest::SepticDigest,
3    septic_extension::SepticExtension, AirOpenedValues, ChipOpenedValues, ShardOpenedValues,
4};
5use core::fmt::Display;
6use itertools::Itertools;
7use p3_air::Air;
8use p3_challenger::{CanObserve, FieldChallenger};
9use p3_commit::{Pcs, PolynomialSpace};
10use p3_field::{AbstractExtensionField, AbstractField, PrimeField32};
11use p3_matrix::{dense::RowMajorMatrix, Matrix};
12use p3_maybe_rayon::prelude::*;
13use p3_uni_stark::SymbolicAirBuilder;
14use p3_util::log2_strict_usize;
15use serde::{de::DeserializeOwned, Serialize};
16use std::{cmp::Reverse, error::Error, time::Instant};
17
18use super::{
19    quotient_values, Com, OpeningProof, StarkGenericConfig, StarkMachine, StarkProvingKey, Val,
20    VerifierConstraintFolder,
21};
22use crate::{
23    air::MachineAir, lookup::InteractionBuilder, opts::SP1CoreOpts, record::MachineRecord,
24    Challenger, DebugConstraintBuilder, MachineChip, MachineProof, PackedChallenge, PcsProverData,
25    ProverConstraintFolder, ShardCommitment, ShardMainData, ShardProof, StarkVerifyingKey,
26};
27
28/// An algorithmic & hardware independent prover implementation for any [`MachineAir`].
29pub trait MachineProver<SC: StarkGenericConfig, A: MachineAir<SC::Val>>:
30    'static + Send + Sync
31{
32    /// The type used to store the traces.
33    type DeviceMatrix: Matrix<SC::Val>;
34
35    /// The type used to store the polynomial commitment schemes data.
36    type DeviceProverData;
37
38    /// The type used to store the proving key.
39    type DeviceProvingKey: MachineProvingKey<SC>;
40
41    /// The type used for error handling.
42    type Error: Error + Send + Sync;
43
44    /// Create a new prover from a given machine.
45    fn new(machine: StarkMachine<SC, A>) -> Self;
46
47    /// A reference to the machine that this prover is using.
48    fn machine(&self) -> &StarkMachine<SC, A>;
49
50    /// Setup the preprocessed data into a proving and verifying key.
51    fn setup(&self, program: &A::Program) -> (Self::DeviceProvingKey, StarkVerifyingKey<SC>);
52
53    /// Setup the proving key given a verifying key. This is similar to `setup` but faster since
54    /// some computed information is already in the verifying key.
55    fn pk_from_vk(
56        &self,
57        program: &A::Program,
58        vk: &StarkVerifyingKey<SC>,
59    ) -> Self::DeviceProvingKey;
60
61    /// Copy the proving key from the host to the device.
62    fn pk_to_device(&self, pk: &StarkProvingKey<SC>) -> Self::DeviceProvingKey;
63
64    /// Copy the proving key from the device to the host.
65    fn pk_to_host(&self, pk: &Self::DeviceProvingKey) -> StarkProvingKey<SC>;
66
67    /// Generate the main traces.
68    fn generate_traces(&self, record: &A::Record) -> Vec<(String, RowMajorMatrix<Val<SC>>)> {
69        let shard_chips = self.shard_chips(record).collect::<Vec<_>>();
70
71        // For each chip, generate the trace.
72        let parent_span = tracing::debug_span!("generate traces for shard");
73        parent_span.in_scope(|| {
74            shard_chips
75                .par_iter()
76                .map(|chip| {
77                    let chip_name = chip.name();
78                    let begin = Instant::now();
79                    let trace = chip.generate_trace(record, &mut A::Record::default());
80                    tracing::debug!(
81                        parent: &parent_span,
82                        "generated trace for chip {} in {:?}",
83                        chip_name,
84                        begin.elapsed()
85                    );
86                    (chip_name, trace)
87                })
88                .collect::<Vec<_>>()
89        })
90    }
91
92    /// Commit to the main traces.
93    fn commit(
94        &self,
95        record: &A::Record,
96        traces: Vec<(String, RowMajorMatrix<Val<SC>>)>,
97    ) -> ShardMainData<SC, Self::DeviceMatrix, Self::DeviceProverData>;
98
99    /// Observe the main commitment and public values and update the challenger.
100    fn observe(
101        &self,
102        challenger: &mut SC::Challenger,
103        commitment: Com<SC>,
104        public_values: &[SC::Val],
105    ) {
106        // Observe the commitment.
107        challenger.observe(commitment);
108
109        // Observe the public values.
110        challenger.observe_slice(public_values);
111    }
112
113    /// Compute the openings of the traces.
114    fn open(
115        &self,
116        pk: &Self::DeviceProvingKey,
117        data: ShardMainData<SC, Self::DeviceMatrix, Self::DeviceProverData>,
118        challenger: &mut SC::Challenger,
119    ) -> Result<ShardProof<SC>, Self::Error>;
120
121    /// Generate a proof for the given records.
122    fn prove(
123        &self,
124        pk: &Self::DeviceProvingKey,
125        records: Vec<A::Record>,
126        challenger: &mut SC::Challenger,
127        opts: <A::Record as MachineRecord>::Config,
128    ) -> Result<MachineProof<SC>, Self::Error>
129    where
130        A: for<'a> Air<DebugConstraintBuilder<'a, Val<SC>, SC::Challenge>>;
131
132    /// The stark config for the machine.
133    fn config(&self) -> &SC {
134        self.machine().config()
135    }
136
137    /// The number of public values elements.
138    fn num_pv_elts(&self) -> usize {
139        self.machine().num_pv_elts()
140    }
141
142    /// The chips that will be necessary to prove this record.
143    fn shard_chips<'a, 'b>(
144        &'a self,
145        record: &'b A::Record,
146    ) -> impl Iterator<Item = &'b MachineChip<SC, A>>
147    where
148        'a: 'b,
149        SC: 'b,
150    {
151        self.machine().shard_chips(record)
152    }
153
154    /// Debug the constraints for the given inputs.
155    fn debug_constraints(
156        &self,
157        pk: &StarkProvingKey<SC>,
158        records: Vec<A::Record>,
159        challenger: &mut SC::Challenger,
160    ) where
161        SC::Val: PrimeField32,
162        A: for<'a> Air<DebugConstraintBuilder<'a, Val<SC>, SC::Challenge>>,
163    {
164        self.machine().debug_constraints(pk, records, challenger);
165    }
166}
167
/// A proving key for any [`MachineAir`] that is agnostic to hardware.
///
/// Exposes the pieces of the key that a prover needs to read directly, plus a hook for
/// absorbing the key into a challenger.
pub trait MachineProvingKey<SC: StarkGenericConfig>: Send + Sync {
    /// The commitment to the preprocessed data (not the per-shard main commitment).
    fn preprocessed_commit(&self) -> Com<SC>;

    /// The start pc.
    fn pc_start(&self) -> Val<SC>;

    /// The initial global cumulative sum.
    fn initial_global_cumulative_sum(&self) -> SepticDigest<Val<SC>>;

    /// Observe itself in the challenger, so that later challenges depend on this key.
    fn observe_into(&self, challenger: &mut Challenger<SC>);
}
182
/// A prover implementation based on x86 and ARM CPUs.
///
/// It is a thin wrapper around a [`StarkMachine`]; all state lives in the machine.
pub struct CpuProver<SC: StarkGenericConfig, A> {
    /// The machine whose chips and configuration are used for proving.
    machine: StarkMachine<SC, A>,
}
187
/// An error that occurs during the execution of the [`CpuProver`].
///
/// This is a unit struct: it carries no information about what went wrong.
#[derive(Debug, Clone, Copy)]
pub struct CpuProverError;
191
192impl<SC, A> MachineProver<SC, A> for CpuProver<SC, A>
193where
194    SC: 'static + StarkGenericConfig + Send + Sync,
195    A: MachineAir<SC::Val>
196        + for<'a> Air<ProverConstraintFolder<'a, SC>>
197        + Air<InteractionBuilder<Val<SC>>>
198        + for<'a> Air<VerifierConstraintFolder<'a, SC>>
199        + for<'a> Air<SymbolicAirBuilder<Val<SC>>>,
200    A::Record: MachineRecord<Config = SP1CoreOpts>,
201    SC::Val: PrimeField32,
202    Com<SC>: Send + Sync,
203    PcsProverData<SC>: Send + Sync + Serialize + DeserializeOwned,
204    OpeningProof<SC>: Send + Sync,
205    SC::Challenger: Clone,
206{
207    type DeviceMatrix = RowMajorMatrix<Val<SC>>;
208    type DeviceProverData = PcsProverData<SC>;
209    type DeviceProvingKey = StarkProvingKey<SC>;
210    type Error = CpuProverError;
211
212    fn new(machine: StarkMachine<SC, A>) -> Self {
213        Self { machine }
214    }
215
216    fn machine(&self) -> &StarkMachine<SC, A> {
217        &self.machine
218    }
219
220    fn setup(&self, program: &A::Program) -> (Self::DeviceProvingKey, StarkVerifyingKey<SC>) {
221        self.machine().setup(program)
222    }
223
224    fn pk_from_vk(
225        &self,
226        program: &A::Program,
227        vk: &StarkVerifyingKey<SC>,
228    ) -> Self::DeviceProvingKey {
229        self.machine().setup_core(program, vk.initial_global_cumulative_sum).0
230    }
231
232    fn pk_to_device(&self, pk: &StarkProvingKey<SC>) -> Self::DeviceProvingKey {
233        pk.clone()
234    }
235
236    fn pk_to_host(&self, pk: &Self::DeviceProvingKey) -> StarkProvingKey<SC> {
237        pk.clone()
238    }
239
240    fn commit(
241        &self,
242        record: &A::Record,
243        mut named_traces: Vec<(String, RowMajorMatrix<Val<SC>>)>,
244    ) -> ShardMainData<SC, Self::DeviceMatrix, Self::DeviceProverData> {
245        // Order the chips and traces by trace size (biggest first), and get the ordering map.
246        named_traces.sort_by_key(|(name, trace)| (Reverse(trace.height()), name.clone()));
247
248        let pcs = self.config().pcs();
249
250        let domains_and_traces = named_traces
251            .iter()
252            .map(|(_, trace)| {
253                let domain = pcs.natural_domain_for_degree(trace.height());
254                (domain, trace.to_owned())
255            })
256            .collect::<Vec<_>>();
257
258        // Commit to the batch of traces.
259        let (main_commit, main_data) = pcs.commit(domains_and_traces);
260
261        // Get the chip ordering.
262        let chip_ordering =
263            named_traces.iter().enumerate().map(|(i, (name, _))| (name.to_owned(), i)).collect();
264
265        let traces = named_traces.into_iter().map(|(_, trace)| trace).collect::<Vec<_>>();
266
267        ShardMainData {
268            traces,
269            main_commit,
270            main_data,
271            chip_ordering,
272            public_values: record.public_values(),
273        }
274    }
275
276    /// Prove the program for the given shard and given a commitment to the main data.
277    #[allow(clippy::too_many_lines)]
278    #[allow(clippy::redundant_closure_for_method_calls)]
279    #[allow(clippy::map_unwrap_or)]
280    fn open(
281        &self,
282        pk: &StarkProvingKey<SC>,
283        data: ShardMainData<SC, Self::DeviceMatrix, Self::DeviceProverData>,
284        challenger: &mut <SC as StarkGenericConfig>::Challenger,
285    ) -> Result<ShardProof<SC>, Self::Error> {
286        let chips = self.machine().shard_chips_ordered(&data.chip_ordering).collect::<Vec<_>>();
287        let traces = data.traces;
288
289        let config = self.machine().config();
290
291        let degrees = traces.iter().map(|trace| trace.height()).collect::<Vec<_>>();
292
293        let log_degrees =
294            degrees.iter().map(|degree| log2_strict_usize(*degree)).collect::<Vec<_>>();
295
296        let log_quotient_degrees =
297            chips.iter().map(|chip| chip.log_quotient_degree()).collect::<Vec<_>>();
298
299        let pcs = config.pcs();
300        let trace_domains =
301            degrees.iter().map(|degree| pcs.natural_domain_for_degree(*degree)).collect::<Vec<_>>();
302
303        // Observe the public values and the main commitment.
304        challenger.observe_slice(&data.public_values[0..self.num_pv_elts()]);
305        challenger.observe(data.main_commit.clone());
306
307        // Obtain the challenges used for the local permutation argument.
308        let mut local_permutation_challenges: Vec<SC::Challenge> = Vec::new();
309        for _ in 0..2 {
310            local_permutation_challenges.push(challenger.sample_ext_element());
311        }
312
313        let packed_perm_challenges = local_permutation_challenges
314            .iter()
315            .map(|c| PackedChallenge::<SC>::from_f(*c))
316            .collect::<Vec<_>>();
317
318        // Generate the permutation traces.
319        let ((permutation_traces, prep_traces), (global_cumulative_sums, local_cumulative_sums)): (
320            (Vec<_>, Vec<_>),
321            (Vec<_>, Vec<_>),
322        ) = tracing::debug_span!("generate permutation traces").in_scope(|| {
323            chips
324                .par_iter()
325                .zip(traces.par_iter())
326                .map(|(chip, main_trace)| {
327                    let preprocessed_trace =
328                        pk.chip_ordering.get(&chip.name()).map(|&index| &pk.traces[index]);
329                    let (perm_trace, local_sum) = chip.generate_permutation_trace(
330                        preprocessed_trace,
331                        main_trace,
332                        &local_permutation_challenges,
333                    );
334                    let global_sum = if chip.commit_scope() == InteractionScope::Local {
335                        SepticDigest::<Val<SC>>::zero()
336                    } else {
337                        let main_trace_size = main_trace.height() * main_trace.width();
338                        let last_row = &main_trace.values[main_trace_size - 14..main_trace_size];
339                        SepticDigest(SepticCurve {
340                            x: SepticExtension::<Val<SC>>::from_base_fn(|i| last_row[i]),
341                            y: SepticExtension::<Val<SC>>::from_base_fn(|i| last_row[i + 7]),
342                        })
343                    };
344                    ((perm_trace, preprocessed_trace), (global_sum, local_sum))
345                })
346                .unzip()
347        });
348
349        // Compute some statistics.
350        for i in 0..chips.len() {
351            let trace_width = traces[i].width();
352            let trace_height = traces[i].height();
353            let prep_width = prep_traces[i].map_or(0, |x| x.width());
354            let permutation_width = permutation_traces[i].width();
355            let total_width = trace_width +
356                prep_width +
357                permutation_width * <SC::Challenge as AbstractExtensionField<SC::Val>>::D;
358            tracing::debug!(
359                "{:<15} | Main Cols = {:<5} | Pre Cols = {:<5}  | Perm Cols = {:<5} | Rows = {:<5} | Cells = {:<10}",
360                chips[i].name(),
361                trace_width,
362                prep_width,
363                permutation_width * <SC::Challenge as AbstractExtensionField<SC::Val>>::D,
364                trace_height,
365                total_width * trace_height,
366            );
367        }
368
369        let domains_and_perm_traces =
370            tracing::debug_span!("flatten permutation traces and collect domains").in_scope(|| {
371                permutation_traces
372                    .into_iter()
373                    .zip(trace_domains.iter())
374                    .map(|(perm_trace, domain)| {
375                        let trace = perm_trace.flatten_to_base();
376                        (*domain, trace.clone())
377                    })
378                    .collect::<Vec<_>>()
379            });
380
381        let pcs = config.pcs();
382
383        let (permutation_commit, permutation_data) =
384            tracing::debug_span!("commit to permutation traces")
385                .in_scope(|| pcs.commit(domains_and_perm_traces));
386
387        // Observe the permutation commitment and cumulative sums.
388        challenger.observe(permutation_commit.clone());
389        for (local_sum, global_sum) in
390            local_cumulative_sums.iter().zip(global_cumulative_sums.iter())
391        {
392            challenger.observe_slice(local_sum.as_base_slice());
393            challenger.observe_slice(&global_sum.0.x.0);
394            challenger.observe_slice(&global_sum.0.y.0);
395        }
396
397        // Compute the quotient polynomial for all chips.
398        let quotient_domains = trace_domains
399            .iter()
400            .zip_eq(log_degrees.iter())
401            .zip_eq(log_quotient_degrees.iter())
402            .map(|((domain, log_degree), log_quotient_degree)| {
403                domain.create_disjoint_domain(1 << (log_degree + log_quotient_degree))
404            })
405            .collect::<Vec<_>>();
406
407        // Compute the quotient values.
408        let alpha: SC::Challenge = challenger.sample_ext_element::<SC::Challenge>();
409        let parent_span = tracing::debug_span!("compute quotient values");
410        let quotient_values =
411            parent_span.in_scope(|| {
412                quotient_domains
413                .into_par_iter()
414                .enumerate()
415                .map(|(i, quotient_domain)| {
416                    tracing::debug_span!(parent: &parent_span, "compute quotient values for domain")
417                        .in_scope(|| {
418                            let preprocessed_trace_on_quotient_domains =
419                                pk.chip_ordering.get(&chips[i].name()).map(|&index| {
420                                    pcs.get_evaluations_on_domain(&pk.data, index, *quotient_domain)
421                                        .to_row_major_matrix()
422                                });
423                            let main_trace_on_quotient_domains = pcs
424                                .get_evaluations_on_domain(&data.main_data, i, *quotient_domain)
425                                .to_row_major_matrix();
426                            let permutation_trace_on_quotient_domains = pcs
427                                .get_evaluations_on_domain(&permutation_data, i, *quotient_domain)
428                                .to_row_major_matrix();
429
430                            let chip_num_constraints =
431                                pk.constraints_map.get(&chips[i].name()).unwrap();
432
433                            // Calculate powers of alpha for constraint evaluation:
434                            // 1. Generate sequence [α⁰, α¹, ..., α^(n-1)] where n = chip_num_constraints.
435                            // 2. Reverse to [α^(n-1), ..., α¹, α⁰] to align with Horner's method in the verifier.
436                            let powers_of_alpha =
437                                alpha.powers().take(*chip_num_constraints).collect::<Vec<_>>();
438                            let mut powers_of_alpha_rev = powers_of_alpha.clone();
439                            powers_of_alpha_rev.reverse();
440
441                            quotient_values(
442                                chips[i],
443                                &local_cumulative_sums[i],
444                                &global_cumulative_sums[i],
445                                trace_domains[i],
446                                *quotient_domain,
447                                preprocessed_trace_on_quotient_domains,
448                                main_trace_on_quotient_domains,
449                                permutation_trace_on_quotient_domains,
450                                &packed_perm_challenges,
451                                &powers_of_alpha_rev,
452                                &data.public_values,
453                            )
454                        })
455                })
456                .collect::<Vec<_>>()
457            });
458
459        // Split the quotient values and commit to them.
460        let quotient_domains_and_chunks = quotient_domains
461            .into_iter()
462            .zip_eq(quotient_values)
463            .zip_eq(log_quotient_degrees.iter())
464            .flat_map(|((quotient_domain, quotient_values), log_quotient_degree)| {
465                let quotient_degree = 1 << *log_quotient_degree;
466                let quotient_flat = RowMajorMatrix::new_col(quotient_values).flatten_to_base();
467                let quotient_chunks = quotient_domain.split_evals(quotient_degree, quotient_flat);
468                let qc_domains = quotient_domain.split_domains(quotient_degree);
469                qc_domains.into_iter().zip_eq(quotient_chunks)
470            })
471            .collect::<Vec<_>>();
472
473        let num_quotient_chunks = quotient_domains_and_chunks.len();
474        assert_eq!(
475            num_quotient_chunks,
476            chips.iter().map(|c| 1 << c.log_quotient_degree()).sum::<usize>()
477        );
478
479        let (quotient_commit, quotient_data) = tracing::debug_span!("commit to quotient traces")
480            .in_scope(|| pcs.commit(quotient_domains_and_chunks));
481        challenger.observe(quotient_commit.clone());
482
483        // Compute the quotient argument.
484        let zeta: SC::Challenge = challenger.sample_ext_element();
485
486        let preprocessed_opening_points =
487            tracing::debug_span!("compute preprocessed opening points").in_scope(|| {
488                pk.traces
489                    .iter()
490                    .zip(pk.local_only.iter())
491                    .map(|(trace, local_only)| {
492                        let domain = pcs.natural_domain_for_degree(trace.height());
493                        if !local_only {
494                            vec![zeta, domain.next_point(zeta).unwrap()]
495                        } else {
496                            vec![zeta]
497                        }
498                    })
499                    .collect::<Vec<_>>()
500            });
501
502        let main_trace_opening_points = tracing::debug_span!("compute main trace opening points")
503            .in_scope(|| {
504                trace_domains
505                    .iter()
506                    .zip(chips.iter())
507                    .map(|(domain, chip)| {
508                        if !chip.local_only() {
509                            vec![zeta, domain.next_point(zeta).unwrap()]
510                        } else {
511                            vec![zeta]
512                        }
513                    })
514                    .collect::<Vec<_>>()
515            });
516
517        let permutation_trace_opening_points =
518            tracing::debug_span!("compute permutation trace opening points").in_scope(|| {
519                trace_domains
520                    .iter()
521                    .map(|domain| vec![zeta, domain.next_point(zeta).unwrap()])
522                    .collect::<Vec<_>>()
523            });
524
525        // Compute quotient opening points, open every chunk at zeta.
526        let quotient_opening_points =
527            (0..num_quotient_chunks).map(|_| vec![zeta]).collect::<Vec<_>>();
528
529        let (openings, opening_proof) = tracing::debug_span!("open multi batches").in_scope(|| {
530            pcs.open(
531                vec![
532                    (&pk.data, preprocessed_opening_points),
533                    (&data.main_data, main_trace_opening_points.clone()),
534                    (&permutation_data, permutation_trace_opening_points.clone()),
535                    (&quotient_data, quotient_opening_points),
536                ],
537                challenger,
538            )
539        });
540
541        // Collect the opened values for each chip.
542        let [preprocessed_values, main_values, permutation_values, mut quotient_values] =
543            openings.try_into().unwrap();
544        assert!(main_values.len() == chips.len());
545        let preprocessed_opened_values = preprocessed_values
546            .into_iter()
547            .zip(pk.local_only.iter())
548            .map(|(op, local_only)| {
549                if !local_only {
550                    let [local, next] = op.try_into().unwrap();
551                    AirOpenedValues { local, next }
552                } else {
553                    let [local] = op.try_into().unwrap();
554                    let width = local.len();
555                    AirOpenedValues { local, next: vec![SC::Challenge::zero(); width] }
556                }
557            })
558            .collect::<Vec<_>>();
559
560        let main_opened_values = main_values
561            .into_iter()
562            .zip(chips.iter())
563            .map(|(op, chip)| {
564                if !chip.local_only() {
565                    let [local, next] = op.try_into().unwrap();
566                    AirOpenedValues { local, next }
567                } else {
568                    let [local] = op.try_into().unwrap();
569                    let width = local.len();
570                    AirOpenedValues { local, next: vec![SC::Challenge::zero(); width] }
571                }
572            })
573            .collect::<Vec<_>>();
574        let permutation_opened_values = permutation_values
575            .into_iter()
576            .map(|op| {
577                let [local, next] = op.try_into().unwrap();
578                AirOpenedValues { local, next }
579            })
580            .collect::<Vec<_>>();
581        let mut quotient_opened_values = Vec::with_capacity(log_quotient_degrees.len());
582        for log_quotient_degree in log_quotient_degrees.iter() {
583            let degree = 1 << *log_quotient_degree;
584            let slice = quotient_values.drain(0..degree);
585            quotient_opened_values.push(slice.map(|mut op| op.pop().unwrap()).collect::<Vec<_>>());
586        }
587
588        let opened_values = main_opened_values
589            .into_iter()
590            .zip_eq(permutation_opened_values)
591            .zip_eq(quotient_opened_values)
592            .zip_eq(local_cumulative_sums)
593            .zip_eq(global_cumulative_sums)
594            .zip_eq(log_degrees.iter())
595            .enumerate()
596            .map(
597                |(
598                    i,
599                    (
600                        (
601                            (((main, permutation), quotient), local_cumulative_sum),
602                            global_cumulative_sum,
603                        ),
604                        log_degree,
605                    ),
606                )| {
607                    let preprocessed = pk
608                        .chip_ordering
609                        .get(&chips[i].name())
610                        .map(|&index| preprocessed_opened_values[index].clone())
611                        .unwrap_or(AirOpenedValues { local: vec![], next: vec![] });
612                    ChipOpenedValues {
613                        preprocessed,
614                        main,
615                        permutation,
616                        quotient,
617                        global_cumulative_sum,
618                        local_cumulative_sum,
619                        log_degree: *log_degree,
620                    }
621                },
622            )
623            .collect::<Vec<_>>();
624
625        Ok(ShardProof::<SC> {
626            commitment: ShardCommitment {
627                main_commit: data.main_commit.clone(),
628                permutation_commit,
629                quotient_commit,
630            },
631            opened_values: ShardOpenedValues { chips: opened_values },
632            opening_proof,
633            chip_ordering: data.chip_ordering,
634            public_values: data.public_values,
635        })
636    }
637
638    /// Prove the execution record is valid.
639    ///
640    /// Given a proving key `pk` and a matching execution record `record`, this function generates
641    /// a STARK proof that the execution record is valid.
642    #[allow(clippy::needless_for_each)]
643    fn prove(
644        &self,
645        pk: &StarkProvingKey<SC>,
646        mut records: Vec<A::Record>,
647        challenger: &mut SC::Challenger,
648        opts: <A::Record as MachineRecord>::Config,
649    ) -> Result<MachineProof<SC>, Self::Error>
650    where
651        A: for<'a> Air<DebugConstraintBuilder<'a, Val<SC>, SC::Challenge>>,
652    {
653        // Generate dependencies.
654        self.machine().generate_dependencies(&mut records, &opts, None);
655
656        // Observe the preprocessed commitment.
657        pk.observe_into(challenger);
658
659        let shard_proofs = tracing::info_span!("prove_shards").in_scope(|| {
660            records
661                .into_par_iter()
662                .map(|record| {
663                    let named_traces = self.generate_traces(&record);
664                    let shard_data = self.commit(&record, named_traces);
665                    self.open(pk, shard_data, &mut challenger.clone())
666                })
667                .collect::<Result<Vec<_>, _>>()
668        })?;
669
670        Ok(MachineProof { shard_proofs })
671    }
672}
673
674impl<SC> MachineProvingKey<SC> for StarkProvingKey<SC>
675where
676    SC: 'static + StarkGenericConfig + Send + Sync,
677    PcsProverData<SC>: Send + Sync + Serialize + DeserializeOwned,
678    Com<SC>: Send + Sync,
679{
680    fn preprocessed_commit(&self) -> Com<SC> {
681        self.commit.clone()
682    }
683
684    fn pc_start(&self) -> Val<SC> {
685        self.pc_start
686    }
687
688    fn initial_global_cumulative_sum(&self) -> SepticDigest<Val<SC>> {
689        self.initial_global_cumulative_sum
690    }
691
692    fn observe_into(&self, challenger: &mut Challenger<SC>) {
693        challenger.observe(self.commit.clone());
694        challenger.observe(self.pc_start);
695        challenger.observe_slice(&self.initial_global_cumulative_sum.0.x.0);
696        challenger.observe_slice(&self.initial_global_cumulative_sum.0.y.0);
697        let zero = Val::<SC>::zero();
698        challenger.observe(zero);
699    }
700}
701
702impl Display for CpuProverError {
703    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
704        write!(f, "DefaultProverError")
705    }
706}
707
// Empty impl: every `Error` method keeps its default, so no underlying source is reported.
impl Error for CpuProverError {}