1use std::{
2 collections::{BTreeMap, BTreeSet},
3 marker::PhantomData,
4 path::PathBuf,
5};
6
7use slop_algebra::{AbstractField, PrimeField32};
8use slop_challenger::IopCtx;
9use sp1_core_machine::riscv::RiscvAir;
10use sp1_hypercube::{
11 air::{POSEIDON_NUM_WORDS, PROOF_NONCE_NUM_WORDS},
12 prover::ZerocheckAir,
13 verify_merkle_proof, HashableKey, MachineVerifier, MachineVerifyingKey, MerkleProof,
14 SP1InnerPcs, SP1PcsProofInner, ShardVerifier, SP1SC,
15};
16use sp1_primitives::{SP1ExtensionField, SP1Field, SP1GlobalContext};
17use sp1_recursion_circuit::{
18 basefold::{
19 merkle_tree::MerkleTree, stacked::RecursiveStackedPcsVerifier, tcs::RecursiveMerkleTreeTcs,
20 RecursiveBasefoldVerifier,
21 },
22 jagged::{RecursiveJaggedEvalSumcheckConfig, RecursiveJaggedPcsVerifier},
23 machine::{
24 InnerVal, PublicValuesOutputDigest, SP1CompressRootVerifierWithVKey,
25 SP1CompressWithVKeyVerifier, SP1CompressWithVKeyWitnessValues, SP1DeferredVerifier,
26 SP1DeferredWitnessValues, SP1NormalizeWitnessValues, SP1RecursiveVerifier,
27 },
28 shard::RecursiveShardVerifier,
29 witness::Witnessable,
30 CircuitConfig, SP1FieldConfigVariable, WrapConfig as CircuitWrapConfig,
31};
32use sp1_recursion_compiler::{
33 circuit::AsmCompiler,
34 config::InnerConfig,
35 ir::{Builder, DslIrProgram},
36};
37use sp1_recursion_executor::{RecursionProgram, DIGEST_SIZE};
38
39use crate::{
40 shapes::{create_all_input_shapes, SP1RecursionProofShape},
41 worker::{TaskError, DEFAULT_MAX_COMPOSE_ARITY},
42 CompressAir, RecursionSC,
43};
44
45#[derive(Clone)]
46pub struct RecursionVks {
47 root: <SP1GlobalContext as IopCtx>::Digest,
48 map: BTreeMap<<SP1GlobalContext as IopCtx>::Digest, usize>,
49 tree: MerkleTree<SP1GlobalContext>,
50 vk_verification: bool,
51}
52
53impl Default for RecursionVks {
54 fn default() -> Self {
55 Self::new(None, DEFAULT_MAX_COMPOSE_ARITY, true)
56 }
57}
58
59impl RecursionVks {
60 const RECURSION_VK_MAP_BYTES: &[u8] = include_bytes!("vk_map.bin");
62
63 fn from_map(
64 mut map: BTreeMap<[SP1Field; DIGEST_SIZE], usize>,
65 max_compose_arity: usize,
66 vk_verification: bool,
67 ) -> Self {
68 let num_shapes = create_all_input_shapes(RiscvAir::machine().shape(), max_compose_arity)
71 .into_iter()
72 .collect::<BTreeSet<_>>()
73 .len();
74
75 let added_len = num_shapes.saturating_sub(map.len());
76 let prev_len = map.len();
77
78 map.extend((0..added_len).map(|i| {
79 let index = i + prev_len;
80 ([SP1Field::from_canonical_u32(index as u32); DIGEST_SIZE], index)
81 }));
82
83 let vks = map.into_keys().collect::<BTreeSet<_>>();
84 let map: BTreeMap<_, _> = vks.into_iter().enumerate().map(|(i, vk)| (vk, i)).collect();
85
86 let (root, tree) = MerkleTree::<SP1GlobalContext>::commit(map.keys().copied().collect());
88
89 Self { root, map, tree, vk_verification }
90 }
91
92 fn dummy(max_compose_arity: usize) -> Self {
93 Self::from_map(BTreeMap::new(), max_compose_arity, false)
94 }
95
96 fn from_file(path: PathBuf, max_compose_arity: usize, vk_verification: bool) -> Self {
97 let file = std::fs::File::open(path).expect("failed to open vk map file");
98 let map = bincode::deserialize_from(file).expect("failed to deserialize vk map");
99 Self::from_map(map, max_compose_arity, vk_verification)
100 }
101
102 pub fn new(path: Option<PathBuf>, max_compose_arity: usize, vk_verification: bool) -> Self {
103 if !vk_verification {
104 return Self::dummy(max_compose_arity);
105 }
106
107 if let Some(path) = path {
108 return Self::from_file(path, max_compose_arity, vk_verification);
109 }
110
111 let map = bincode::deserialize(Self::RECURSION_VK_MAP_BYTES)
112 .expect("failed to deserialize vk map");
113 Self::from_map(map, max_compose_arity, vk_verification)
114 }
115
116 pub fn root(&self) -> <SP1GlobalContext as IopCtx>::Digest {
117 self.root
118 }
119
120 pub fn contains(&self, vk: &MachineVerifyingKey<SP1GlobalContext>) -> bool {
121 self.map.contains_key(&vk.hash_koalabear())
122 }
123
124 pub fn num_keys(&self) -> usize {
125 self.map.len()
126 }
127
128 pub fn vk_verification(&self) -> bool {
130 self.vk_verification
131 }
132
133 pub fn open(
134 &self,
135 vk: &MachineVerifyingKey<SP1GlobalContext>,
136 ) -> Result<([SP1Field; DIGEST_SIZE], MerkleProof<SP1GlobalContext>), TaskError> {
137 let index = if self.vk_verification {
138 let digest = vk.hash_koalabear();
139 let index = self
140 .map
141 .get(&digest)
142 .copied()
143 .ok_or(TaskError::Fatal(anyhow::anyhow!("vk not allowed")))?;
144 index
145 } else {
146 let vk_digest = vk.hash_koalabear();
147 let num_vks = self.num_keys();
148 (vk_digest[0].as_canonical_u32() as usize) % num_vks
149 };
150
151 let (value, proof) = MerkleTree::open(&self.tree, index);
152 verify_merkle_proof(&proof, value, self.root)
154 .map_err(|e| TaskError::Fatal(anyhow::anyhow!("invalid merkle proof: {:?}", e)))?;
155
156 Ok((value, proof))
157 }
158
159 pub fn verify(
160 &self,
161 proof: &MerkleProof<SP1GlobalContext>,
162 vk: &MachineVerifyingKey<SP1GlobalContext>,
163 ) -> Result<(), TaskError> {
164 let mut digest = vk.hash_koalabear();
165 if !self.vk_verification {
166 let num_vks = self.num_keys();
167 let vk_index = digest[0].as_canonical_u32() % num_vks as u32;
168 digest = [SP1Field::from_canonical_u32(vk_index); DIGEST_SIZE];
169 }
170 let result = verify_merkle_proof(proof, digest, self.root)
171 .map_err(|e| TaskError::Fatal(anyhow::anyhow!("invalid merkle proof: {:?}", e)));
172 result
173 }
174
175 pub fn height(&self) -> usize {
176 self.tree.height
177 }
178}
179
180pub fn normalize_program_from_input(
183 recursive_verifier: &RecursiveShardVerifier<SP1GlobalContext, RiscvAir<SP1Field>, InnerConfig>,
184 input: &SP1NormalizeWitnessValues<SP1GlobalContext, SP1PcsProofInner>,
185) -> RecursionProgram<SP1Field> {
186 let builder_span = tracing::debug_span!("build recursion program").entered();
188 let mut builder = Builder::<InnerConfig>::default();
189 let input_variable = input.read(&mut builder);
190 SP1RecursiveVerifier::<InnerConfig>::verify(&mut builder, recursive_verifier, input_variable);
191 let block = builder.into_root_block();
192 let dsl_program = unsafe { DslIrProgram::new_unchecked(block) };
195 builder_span.exit();
196
197 let compiler_span = tracing::debug_span!("compile recursion program").entered();
199 let mut compiler = AsmCompiler::default();
200 let program = compiler.compile(dsl_program);
201 compiler_span.exit();
202 program
203}
204
205pub(crate) fn deferred_program_from_input(
207 recursive_verifier: &RecursiveShardVerifier<
208 SP1GlobalContext,
209 CompressAir<InnerVal>,
210 InnerConfig,
211 >,
212 vk_verification: bool,
213 input: &SP1DeferredWitnessValues<SP1GlobalContext, SP1PcsProofInner>,
214) -> RecursionProgram<SP1Field> {
215 let operations_span = tracing::debug_span!("get operations for the deferred program").entered();
217 let mut builder = Builder::<InnerConfig>::default();
218 let input_read_span = tracing::debug_span!("Read input values").entered();
219 let input = input.read(&mut builder);
220 input_read_span.exit();
221 let verify_span = tracing::debug_span!("Verify deferred program").entered();
222
223 SP1DeferredVerifier::verify(&mut builder, recursive_verifier, input, vk_verification);
225 verify_span.exit();
226 let block = builder.into_root_block();
227 operations_span.exit();
228 let dsl_program = unsafe { DslIrProgram::new_unchecked(block) };
231
232 let compiler_span = tracing::debug_span!("compile deferred program").entered();
233 let mut compiler = AsmCompiler::default();
234 let program = compiler.compile(dsl_program);
235 compiler_span.exit();
236 program
237}
238
239pub(crate) fn compose_program_from_input(
241 recursive_verifier: &RecursiveShardVerifier<
242 SP1GlobalContext,
243 CompressAir<InnerVal>,
244 InnerConfig,
245 >,
246 vk_verification: bool,
247 input: &SP1CompressWithVKeyWitnessValues<SP1PcsProofInner>,
248) -> RecursionProgram<SP1Field> {
249 let builder_span = tracing::debug_span!("build compress program").entered();
250 let mut builder = Builder::<InnerConfig>::default();
251 let input = input.read(&mut builder);
253
254 SP1CompressWithVKeyVerifier::<InnerConfig, SP1InnerPcs, _>::verify(
256 &mut builder,
257 recursive_verifier,
258 input,
259 vk_verification,
260 PublicValuesOutputDigest::Reduce,
261 );
262 let block = builder.into_root_block();
263 builder_span.exit();
264 let dsl_program = unsafe { DslIrProgram::new_unchecked(block) };
267
268 let compiler_span = tracing::debug_span!("compile compress program").entered();
270 let mut compiler = AsmCompiler::default();
271 let program = compiler.compile(dsl_program);
272 compiler_span.exit();
273 program
274}
275
276pub(crate) fn shrink_program_from_input(
278 recursive_verifier: &RecursiveShardVerifier<
279 SP1GlobalContext,
280 CompressAir<InnerVal>,
281 InnerConfig,
282 >,
283 vk_verification: bool,
284 input: &SP1CompressWithVKeyWitnessValues<SP1PcsProofInner>,
285) -> RecursionProgram<SP1Field> {
286 let builder_span = tracing::debug_span!("build shrink program").entered();
287 let mut builder = Builder::<InnerConfig>::default();
288 let input = input.read(&mut builder);
290
291 SP1CompressRootVerifierWithVKey::<InnerConfig, _>::verify(
293 &mut builder,
294 recursive_verifier,
295 input,
296 vk_verification,
297 PublicValuesOutputDigest::Reduce,
298 );
299
300 let block = builder.into_root_block();
301 builder_span.exit();
302 let dsl_program = unsafe { DslIrProgram::new_unchecked(block) };
305
306 let compiler_span = tracing::debug_span!("compile shrink program").entered();
308 let mut compiler = AsmCompiler::default();
309 let program = compiler.compile(dsl_program);
310 compiler_span.exit();
311
312 program
313}
314
315pub(crate) fn wrap_program_from_input(
317 recursive_verifier: &RecursiveShardVerifier<
318 SP1GlobalContext,
319 CompressAir<InnerVal>,
320 CircuitWrapConfig,
321 >,
322 vk_verification: bool,
323 input: &SP1CompressWithVKeyWitnessValues<SP1PcsProofInner>,
324) -> RecursionProgram<SP1Field> {
325 let builder_span = tracing::debug_span!("build wrap program").entered();
326 let mut builder = Builder::<CircuitWrapConfig>::default();
327 let input = input.read(&mut builder);
329
330 SP1CompressRootVerifierWithVKey::<CircuitWrapConfig, _>::verify(
332 &mut builder,
333 recursive_verifier,
334 input,
335 vk_verification,
336 PublicValuesOutputDigest::Root,
337 );
338
339 let block = builder.into_root_block();
340 builder_span.exit();
341 let dsl_program = unsafe { DslIrProgram::new_unchecked(block) };
344
345 let compiler_span = tracing::debug_span!("compile wrap program").entered();
347 let mut compiler = AsmCompiler::default();
348 let program = compiler.compile(dsl_program);
349 compiler_span.exit();
350
351 program
352}
353
354pub(crate) fn dummy_compose_input(
355 verifier: &MachineVerifier<SP1GlobalContext, RecursionSC>,
356 shape: &SP1RecursionProofShape,
357 arity: usize,
358 height: usize,
359) -> SP1CompressWithVKeyWitnessValues<SP1PcsProofInner> {
360 let chips =
361 verifier.shard_verifier().machine().chips().iter().cloned().collect::<BTreeSet<_>>();
362
363 let max_log_row_count = verifier.max_log_row_count();
364 let log_stacking_height = verifier.log_stacking_height() as usize;
365
366 shape.dummy_input(
367 arity,
368 height,
369 chips,
370 max_log_row_count,
371 *verifier.fri_config(),
372 log_stacking_height,
373 )
374}
375
376pub(crate) fn dummy_deferred_input(
377 verifier: &MachineVerifier<SP1GlobalContext, RecursionSC>,
378 shape: &SP1RecursionProofShape,
379 height: usize,
380) -> SP1DeferredWitnessValues<SP1GlobalContext, SP1PcsProofInner> {
381 let chips =
382 verifier.shard_verifier().machine().chips().iter().cloned().collect::<BTreeSet<_>>();
383
384 let max_log_row_count = verifier.max_log_row_count();
385 let log_stacking_height = verifier.log_stacking_height() as usize;
386
387 let compress_input = shape.dummy_input(
388 1,
389 height,
390 chips,
391 max_log_row_count,
392 *verifier.fri_config(),
393 log_stacking_height,
394 );
395
396 SP1DeferredWitnessValues {
397 vks_and_proofs: compress_input.compress_val.vks_and_proofs,
398 vk_merkle_data: compress_input.merkle_val,
399 start_reconstruct_deferred_digest: [SP1Field::zero(); POSEIDON_NUM_WORDS],
400 sp1_vk_digest: [SP1Field::zero(); DIGEST_SIZE],
401 end_pc: [SP1Field::zero(); 3],
402 proof_nonce: [SP1Field::zero(); PROOF_NONCE_NUM_WORDS],
403 deferred_proof_index: SP1Field::zero(),
404 }
405}
406
407pub(crate) fn recursive_verifier<GC, A, C>(
408 shard_verifier: &ShardVerifier<GC, SP1SC<GC, A>>,
409) -> RecursiveShardVerifier<GC, A, C>
410where
411 GC: IopCtx<F = SP1Field, EF = SP1ExtensionField> + SP1FieldConfigVariable<C>,
412 A: ZerocheckAir<SP1Field, SP1ExtensionField>,
413 C: CircuitConfig,
414{
415 let log_stacking_height = shard_verifier.log_stacking_height();
416 let max_log_row_count = shard_verifier.max_log_row_count();
417 let machine = shard_verifier.machine().clone();
418 let pcs_verifier = RecursiveBasefoldVerifier {
419 fri_config: shard_verifier.jagged_pcs_verifier.pcs_verifier.basefold_verifier.fri_config,
420 tcs: RecursiveMerkleTreeTcs::<C, GC>(PhantomData),
421 };
422 let recursive_verifier = RecursiveStackedPcsVerifier::new(pcs_verifier, log_stacking_height);
423
424 let recursive_jagged_verifier = RecursiveJaggedPcsVerifier {
425 stacked_pcs_verifier: recursive_verifier,
426 max_log_row_count,
427 jagged_evaluator: RecursiveJaggedEvalSumcheckConfig::<GC>(PhantomData),
428 };
429
430 RecursiveShardVerifier {
431 machine,
432 pcs_verifier: recursive_jagged_verifier,
433 _phantom: std::marker::PhantomData,
434 }
435}