Skip to main content

sp1_hypercube/prover/
simple.rs

1//! Simplified prover for test and development use.
2
3use slop_air::BaseAir;
4use slop_algebra::PrimeField32;
5use slop_challenger::IopCtx;
6use std::{collections::BTreeMap, collections::BTreeSet, sync::Arc};
7
8use crate::{
9    air::MachineAir,
10    prover::{shard::AirProver, CoreProofShape, PcsProof, ProvingKey},
11    MachineVerifier, MachineVerifierConfigError, MachineVerifyingKey, ShardContext, ShardProof,
12    ShardVerifier,
13};
14
15use super::{PreprocessedData, ProverSemaphore};
16
17/// Given a record, compute the shape of the resulting shard proof.
18///
19/// This is a standalone function that can be used outside of `SimpleProver`.
20pub fn shape_from_record<GC: IopCtx, SC: ShardContext<GC>>(
21    verifier: &MachineVerifier<GC, SC>,
22    record: &<<SC as ShardContext<GC>>::Air as MachineAir<GC::F>>::Record,
23) -> Option<CoreProofShape<GC::F, SC::Air>> {
24    let log_stacking_height = verifier.log_stacking_height() as usize;
25    let max_log_row_count = verifier.max_log_row_count();
26    let airs = verifier.machine().chips();
27    let shard_chips: BTreeSet<_> =
28        airs.iter().filter(|air| air.included(record)).cloned().collect();
29    let preprocessed_multiple = shard_chips
30        .iter()
31        .map(|air| air.preprocessed_width() * air.num_rows(record).unwrap_or_default())
32        .sum::<usize>()
33        .div_ceil(1 << log_stacking_height);
34    let main_multiple = shard_chips
35        .iter()
36        .map(|air| air.width() * air.num_rows(record).unwrap_or_default())
37        .sum::<usize>()
38        .div_ceil(1 << log_stacking_height);
39
40    let main_padding_cols = (main_multiple * (1 << log_stacking_height)
41        - shard_chips
42            .iter()
43            .map(|air| air.width() * air.num_rows(record).unwrap_or_default())
44            .sum::<usize>())
45    .div_ceil(1 << max_log_row_count);
46
47    let preprocessed_padding_cols = (preprocessed_multiple * (1 << log_stacking_height)
48        - shard_chips
49            .iter()
50            .map(|air| air.preprocessed_width() * air.num_rows(record).unwrap_or_default())
51            .sum::<usize>())
52    .div_ceil(1 << max_log_row_count);
53
54    let shard_chips = verifier.machine().smallest_cluster(&shard_chips).cloned()?;
55    Some(CoreProofShape {
56        shard_chips,
57        preprocessed_multiple,
58        main_multiple,
59        preprocessed_padding_cols,
60        main_padding_cols,
61    })
62}
63
64/// Create a single-permit semaphore for simple prover operations.
65fn single_permit() -> ProverSemaphore {
66    ProverSemaphore::new(1)
67}
68
69/// The type of program this prover can make proofs for.
70pub type Program<GC, SC> =
71    <<SC as ShardContext<GC>>::Air as MachineAir<<GC as IopCtx>::F>>::Program;
72
73/// The execution record for this prover.
74pub type Record<GC, SC> = <<SC as ShardContext<GC>>::Air as MachineAir<<GC as IopCtx>::F>>::Record;
75
76/// A prover that proves traces sequentially using a single `AirProver`.
77///
78/// Prioritizes simplicity over performance - suitable for tests and development.
79pub struct SimpleProver<GC: IopCtx, SC: ShardContext<GC>, C: AirProver<GC, SC>> {
80    /// The underlying prover.
81    prover: Arc<C>,
82    /// The verifier.
83    verifier: MachineVerifier<GC, SC>,
84}
85
86impl<GC: IopCtx, SC: ShardContext<GC>, C: AirProver<GC, SC>> SimpleProver<GC, SC, C> {
87    /// Create a new simple prover.
88    #[must_use]
89    pub fn new(shard_verifier: ShardVerifier<GC, SC>, prover: C) -> Self {
90        Self { prover: Arc::new(prover), verifier: MachineVerifier::new(shard_verifier) }
91    }
92
93    /// Verify a machine proof.
94    pub fn verify(
95        &self,
96        vk: &MachineVerifyingKey<GC>,
97        proof: &crate::MachineProof<GC, PcsProof<GC, SC>>,
98    ) -> Result<(), MachineVerifierConfigError<GC, SC::Config>>
99    where
100        GC::F: PrimeField32,
101    {
102        self.verifier.verify(vk, proof)
103    }
104
105    /// Get the verifier.
106    #[must_use]
107    #[inline]
108    pub fn verifier(&self) -> &MachineVerifier<GC, SC> {
109        &self.verifier
110    }
111
112    /// Get a new challenger.
113    #[must_use]
114    #[inline]
115    pub fn challenger(&self) -> GC::Challenger {
116        self.verifier.challenger()
117    }
118
119    /// Get the machine.
120    #[must_use]
121    #[inline]
122    pub fn machine(&self) -> &crate::Machine<GC::F, SC::Air> {
123        self.verifier.machine()
124    }
125
126    /// Get the maximum log row count.
127    #[must_use]
128    pub fn max_log_row_count(&self) -> usize {
129        self.verifier.max_log_row_count()
130    }
131
132    /// Get the log stacking height.
133    #[must_use]
134    pub fn log_stacking_height(&self) -> u32 {
135        self.verifier.log_stacking_height()
136    }
137
138    /// Given a record, compute the shape of the resulting shard proof.
139    pub fn shape_from_record(
140        &self,
141        record: &Record<GC, SC>,
142    ) -> Option<CoreProofShape<GC::F, SC::Air>> {
143        shape_from_record(&self.verifier, record)
144    }
145
146    /// Setup the prover for a given program.
147    #[inline]
148    #[must_use]
149    #[tracing::instrument(skip_all, name = "simple_setup")]
150    pub async fn setup(
151        &self,
152        program: Arc<Program<GC, SC>>,
153    ) -> (PreprocessedData<ProvingKey<GC, SC, C>>, MachineVerifyingKey<GC>) {
154        self.prover.setup(program, single_permit()).await
155    }
156
157    /// Prove a shard with a given proving key.
158    #[inline]
159    #[must_use]
160    #[tracing::instrument(skip_all, name = "simple_prove_shard")]
161    pub async fn prove_shard(
162        &self,
163        pk: Arc<ProvingKey<GC, SC, C>>,
164        record: Record<GC, SC>,
165    ) -> ShardProof<GC, PcsProof<GC, SC>> {
166        let (proof, _) = self.prover.prove_shard_with_pk(pk, record, single_permit()).await;
167
168        proof
169    }
170
171    /// Setup and prove a shard in one call.
172    #[inline]
173    #[must_use]
174    #[allow(clippy::type_complexity)]
175    #[tracing::instrument(skip_all, name = "simple_setup_and_prove_shard")]
176    pub async fn setup_and_prove_shard(
177        &self,
178        program: Arc<Program<GC, SC>>,
179        vk: Option<MachineVerifyingKey<GC>>,
180        record: Record<GC, SC>,
181    ) -> (MachineVerifyingKey<GC>, ShardProof<GC, PcsProof<GC, SC>>) {
182        let (vk, proof, _) =
183            self.prover.setup_and_prove_shard(program, record, vk, single_permit()).await;
184
185        (vk, proof)
186    }
187
188    /// Get the preprocessed table heights from the proving key.
189    pub async fn preprocessed_table_heights(
190        &self,
191        pk: Arc<ProvingKey<GC, SC, C>>,
192    ) -> BTreeMap<String, usize> {
193        C::preprocessed_table_heights(pk).await
194    }
195}