1use slop_air::BaseAir;
4use slop_algebra::PrimeField32;
5use slop_challenger::IopCtx;
6use std::{collections::BTreeMap, collections::BTreeSet, sync::Arc};
7
8use crate::{
9 air::MachineAir,
10 prover::{shard::AirProver, CoreProofShape, PcsProof, ProvingKey},
11 MachineVerifier, MachineVerifierConfigError, MachineVerifyingKey, ShardContext, ShardProof,
12 ShardVerifier,
13};
14
15use super::{PreprocessedData, ProverSemaphore};
16
17pub fn shape_from_record<GC: IopCtx, SC: ShardContext<GC>>(
21 verifier: &MachineVerifier<GC, SC>,
22 record: &<<SC as ShardContext<GC>>::Air as MachineAir<GC::F>>::Record,
23) -> Option<CoreProofShape<GC::F, SC::Air>> {
24 let log_stacking_height = verifier.log_stacking_height() as usize;
25 let max_log_row_count = verifier.max_log_row_count();
26 let airs = verifier.machine().chips();
27 let shard_chips: BTreeSet<_> =
28 airs.iter().filter(|air| air.included(record)).cloned().collect();
29 let preprocessed_multiple = shard_chips
30 .iter()
31 .map(|air| air.preprocessed_width() * air.num_rows(record).unwrap_or_default())
32 .sum::<usize>()
33 .div_ceil(1 << log_stacking_height);
34 let main_multiple = shard_chips
35 .iter()
36 .map(|air| air.width() * air.num_rows(record).unwrap_or_default())
37 .sum::<usize>()
38 .div_ceil(1 << log_stacking_height);
39
40 let main_padding_cols = (main_multiple * (1 << log_stacking_height)
41 - shard_chips
42 .iter()
43 .map(|air| air.width() * air.num_rows(record).unwrap_or_default())
44 .sum::<usize>())
45 .div_ceil(1 << max_log_row_count);
46
47 let preprocessed_padding_cols = (preprocessed_multiple * (1 << log_stacking_height)
48 - shard_chips
49 .iter()
50 .map(|air| air.preprocessed_width() * air.num_rows(record).unwrap_or_default())
51 .sum::<usize>())
52 .div_ceil(1 << max_log_row_count);
53
54 let shard_chips = verifier.machine().smallest_cluster(&shard_chips).cloned()?;
55 Some(CoreProofShape {
56 shard_chips,
57 preprocessed_multiple,
58 main_multiple,
59 preprocessed_padding_cols,
60 main_padding_cols,
61 })
62}
63
64fn single_permit() -> ProverSemaphore {
66 ProverSemaphore::new(1)
67}
68
69pub type Program<GC, SC> =
71 <<SC as ShardContext<GC>>::Air as MachineAir<<GC as IopCtx>::F>>::Program;
72
73pub type Record<GC, SC> = <<SC as ShardContext<GC>>::Air as MachineAir<<GC as IopCtx>::F>>::Record;
75
76pub struct SimpleProver<GC: IopCtx, SC: ShardContext<GC>, C: AirProver<GC, SC>> {
80 prover: Arc<C>,
82 verifier: MachineVerifier<GC, SC>,
84}
85
86impl<GC: IopCtx, SC: ShardContext<GC>, C: AirProver<GC, SC>> SimpleProver<GC, SC, C> {
87 #[must_use]
89 pub fn new(shard_verifier: ShardVerifier<GC, SC>, prover: C) -> Self {
90 Self { prover: Arc::new(prover), verifier: MachineVerifier::new(shard_verifier) }
91 }
92
93 pub fn verify(
95 &self,
96 vk: &MachineVerifyingKey<GC>,
97 proof: &crate::MachineProof<GC, PcsProof<GC, SC>>,
98 ) -> Result<(), MachineVerifierConfigError<GC, SC::Config>>
99 where
100 GC::F: PrimeField32,
101 {
102 self.verifier.verify(vk, proof)
103 }
104
105 #[must_use]
107 #[inline]
108 pub fn verifier(&self) -> &MachineVerifier<GC, SC> {
109 &self.verifier
110 }
111
112 #[must_use]
114 #[inline]
115 pub fn challenger(&self) -> GC::Challenger {
116 self.verifier.challenger()
117 }
118
119 #[must_use]
121 #[inline]
122 pub fn machine(&self) -> &crate::Machine<GC::F, SC::Air> {
123 self.verifier.machine()
124 }
125
126 #[must_use]
128 pub fn max_log_row_count(&self) -> usize {
129 self.verifier.max_log_row_count()
130 }
131
132 #[must_use]
134 pub fn log_stacking_height(&self) -> u32 {
135 self.verifier.log_stacking_height()
136 }
137
138 pub fn shape_from_record(
140 &self,
141 record: &Record<GC, SC>,
142 ) -> Option<CoreProofShape<GC::F, SC::Air>> {
143 shape_from_record(&self.verifier, record)
144 }
145
146 #[inline]
148 #[must_use]
149 #[tracing::instrument(skip_all, name = "simple_setup")]
150 pub async fn setup(
151 &self,
152 program: Arc<Program<GC, SC>>,
153 ) -> (PreprocessedData<ProvingKey<GC, SC, C>>, MachineVerifyingKey<GC>) {
154 self.prover.setup(program, single_permit()).await
155 }
156
157 #[inline]
159 #[must_use]
160 #[tracing::instrument(skip_all, name = "simple_prove_shard")]
161 pub async fn prove_shard(
162 &self,
163 pk: Arc<ProvingKey<GC, SC, C>>,
164 record: Record<GC, SC>,
165 ) -> ShardProof<GC, PcsProof<GC, SC>> {
166 let (proof, _) = self.prover.prove_shard_with_pk(pk, record, single_permit()).await;
167
168 proof
169 }
170
171 #[inline]
173 #[must_use]
174 #[allow(clippy::type_complexity)]
175 #[tracing::instrument(skip_all, name = "simple_setup_and_prove_shard")]
176 pub async fn setup_and_prove_shard(
177 &self,
178 program: Arc<Program<GC, SC>>,
179 vk: Option<MachineVerifyingKey<GC>>,
180 record: Record<GC, SC>,
181 ) -> (MachineVerifyingKey<GC>, ShardProof<GC, PcsProof<GC, SC>>) {
182 let (vk, proof, _) =
183 self.prover.setup_and_prove_shard(program, record, vk, single_permit()).await;
184
185 (vk, proof)
186 }
187
188 pub async fn preprocessed_table_heights(
190 &self,
191 pk: Arc<ProvingKey<GC, SC, C>>,
192 ) -> BTreeMap<String, usize> {
193 C::preprocessed_table_heights(pk).await
194 }
195}