1use slop_air::BaseAir;
4use slop_algebra::PrimeField32;
5use slop_challenger::IopCtx;
6use std::{collections::BTreeMap, collections::BTreeSet, sync::Arc};
7
8use crate::{
9 air::MachineAir,
10 prover::{shard::AirProver, CoreProofShape, PcsProof, ProvingKey},
11 MachineVerifier, MachineVerifierConfigError, MachineVerifyingKey, ShardContext, ShardProof,
12 ShardVerifier,
13};
14
15use super::{PreprocessedData, ProverSemaphore};
16
17pub fn shape_from_record<GC: IopCtx, SC: ShardContext<GC>>(
21 verifier: &MachineVerifier<GC, SC>,
22 record: &<<SC as ShardContext<GC>>::Air as MachineAir<GC::F>>::Record,
23) -> Option<CoreProofShape<GC::F, SC::Air>> {
24 let log_stacking_height = verifier.log_stacking_height() as usize;
25 let max_log_row_count = verifier.max_log_row_count();
26 let airs = verifier.machine().chips();
27 let shard_chips: BTreeSet<_> =
28 airs.iter().filter(|air| air.included(record)).cloned().collect();
29 let preprocessed_multiple = shard_chips
30 .iter()
31 .map(|air| air.preprocessed_width() * air.num_rows(record).unwrap_or_default())
32 .sum::<usize>()
33 .div_ceil(1 << log_stacking_height);
34 let main_multiple = shard_chips
35 .iter()
36 .map(|air| air.width() * air.num_rows(record).unwrap_or_default())
37 .sum::<usize>()
38 .div_ceil(1 << log_stacking_height);
39
40 let main_padding_cols = (main_multiple * (1 << log_stacking_height)
41 - shard_chips
42 .iter()
43 .map(|air| air.width() * air.num_rows(record).unwrap_or_default())
44 .sum::<usize>())
45 .div_ceil(1 << max_log_row_count)
46 .max(1);
47
48 let preprocessed_padding_cols = (preprocessed_multiple * (1 << log_stacking_height)
49 - shard_chips
50 .iter()
51 .map(|air| air.preprocessed_width() * air.num_rows(record).unwrap_or_default())
52 .sum::<usize>())
53 .div_ceil(1 << max_log_row_count)
54 .max(1);
55
56 let shard_chips = verifier.machine().smallest_cluster(&shard_chips).cloned()?;
57 Some(CoreProofShape {
58 shard_chips,
59 preprocessed_multiple,
60 main_multiple,
61 preprocessed_padding_cols,
62 main_padding_cols,
63 })
64}
65
66fn single_permit() -> ProverSemaphore {
68 ProverSemaphore::new(1)
69}
70
/// The `Program` type of the AIR selected by shard context `SC`, over the field `GC::F`.
///
/// Spelled with fully-qualified projections because a type alias carries no bounds, so the
/// `SC::Air` / `GC::F` shorthands would not resolve here.
pub type Program<GC, SC> =
    <<SC as ShardContext<GC>>::Air as MachineAir<<GC as IopCtx>::F>>::Program;

/// The `Record` type of the AIR selected by shard context `SC`, over the field `GC::F`.
pub type Record<GC, SC> = <<SC as ShardContext<GC>>::Air as MachineAir<<GC as IopCtx>::F>>::Record;
77
/// A thin wrapper pairing an [`AirProver`] with the [`MachineVerifier`] for its machine.
///
/// Each proving method on this type passes the prover a freshly built one-permit
/// [`ProverSemaphore`].
pub struct SimpleProver<GC: IopCtx, SC: ShardContext<GC>, C: AirProver<GC, SC>> {
    /// The underlying AIR prover, held behind an [`Arc`].
    // NOTE(review): field order is also drop order; keep `prover` first unless that is
    // known not to matter.
    prover: Arc<C>,
    /// The machine verifier; also the source of the machine, challenger and stacking
    /// parameters exposed by the accessor methods.
    verifier: MachineVerifier<GC, SC>,
}
87
88impl<GC: IopCtx, SC: ShardContext<GC>, C: AirProver<GC, SC>> SimpleProver<GC, SC, C> {
89 #[must_use]
91 pub fn new(shard_verifier: ShardVerifier<GC, SC>, prover: C) -> Self {
92 Self { prover: Arc::new(prover), verifier: MachineVerifier::new(shard_verifier) }
93 }
94
95 pub fn verify(
97 &self,
98 vk: &MachineVerifyingKey<GC>,
99 proof: &crate::MachineProof<GC, PcsProof<GC, SC>>,
100 ) -> Result<(), MachineVerifierConfigError<GC, SC::Config>>
101 where
102 GC::F: PrimeField32,
103 {
104 self.verifier.verify(vk, proof)
105 }
106
107 #[must_use]
109 #[inline]
110 pub fn verifier(&self) -> &MachineVerifier<GC, SC> {
111 &self.verifier
112 }
113
114 #[must_use]
116 #[inline]
117 pub fn challenger(&self) -> GC::Challenger {
118 self.verifier.challenger()
119 }
120
121 #[must_use]
123 #[inline]
124 pub fn machine(&self) -> &crate::Machine<GC::F, SC::Air> {
125 self.verifier.machine()
126 }
127
128 #[must_use]
130 pub fn max_log_row_count(&self) -> usize {
131 self.verifier.max_log_row_count()
132 }
133
134 #[must_use]
136 pub fn log_stacking_height(&self) -> u32 {
137 self.verifier.log_stacking_height()
138 }
139
140 pub fn shape_from_record(
142 &self,
143 record: &Record<GC, SC>,
144 ) -> Option<CoreProofShape<GC::F, SC::Air>> {
145 shape_from_record(&self.verifier, record)
146 }
147
148 #[inline]
150 #[must_use]
151 #[tracing::instrument(skip_all, name = "simple_setup")]
152 pub async fn setup(
153 &self,
154 program: Arc<Program<GC, SC>>,
155 ) -> (PreprocessedData<ProvingKey<GC, SC, C>>, MachineVerifyingKey<GC>) {
156 self.prover.setup(program, single_permit()).await
157 }
158
159 #[inline]
161 #[must_use]
162 #[tracing::instrument(skip_all, name = "simple_prove_shard")]
163 pub async fn prove_shard(
164 &self,
165 pk: Arc<ProvingKey<GC, SC, C>>,
166 record: Record<GC, SC>,
167 ) -> ShardProof<GC, PcsProof<GC, SC>> {
168 let (proof, _) = self.prover.prove_shard_with_pk(pk, record, single_permit()).await;
169
170 proof
171 }
172
173 #[inline]
175 #[must_use]
176 #[allow(clippy::type_complexity)]
177 #[tracing::instrument(skip_all, name = "simple_setup_and_prove_shard")]
178 pub async fn setup_and_prove_shard(
179 &self,
180 program: Arc<Program<GC, SC>>,
181 vk: Option<MachineVerifyingKey<GC>>,
182 record: Record<GC, SC>,
183 ) -> (MachineVerifyingKey<GC>, ShardProof<GC, PcsProof<GC, SC>>) {
184 let (vk, proof, _) =
185 self.prover.setup_and_prove_shard(program, record, vk, single_permit()).await;
186
187 (vk, proof)
188 }
189
190 pub async fn preprocessed_table_heights(
192 &self,
193 pk: Arc<ProvingKey<GC, SC, C>>,
194 ) -> BTreeMap<String, usize> {
195 C::preprocessed_table_heights(pk).await
196 }
197}