// sp1_prover/recursion.rs

1use std::{
2    collections::{BTreeMap, BTreeSet},
3    marker::PhantomData,
4    path::PathBuf,
5};
6
7use slop_algebra::{AbstractField, PrimeField32};
8use slop_challenger::IopCtx;
9use sp1_core_machine::riscv::RiscvAir;
10use sp1_hypercube::{
11    air::{POSEIDON_NUM_WORDS, PROOF_NONCE_NUM_WORDS},
12    prover::ZerocheckAir,
13    verify_merkle_proof, HashableKey, MachineVerifier, MachineVerifyingKey, MerkleProof,
14    SP1InnerPcs, SP1PcsProofInner, ShardVerifier, SP1SC,
15};
16use sp1_primitives::{SP1ExtensionField, SP1Field, SP1GlobalContext};
17use sp1_recursion_circuit::{
18    basefold::{
19        merkle_tree::MerkleTree, stacked::RecursiveStackedPcsVerifier, tcs::RecursiveMerkleTreeTcs,
20        RecursiveBasefoldVerifier,
21    },
22    jagged::{RecursiveJaggedEvalSumcheckConfig, RecursiveJaggedPcsVerifier},
23    machine::{
24        InnerVal, PublicValuesOutputDigest, SP1CompressRootVerifierWithVKey,
25        SP1CompressWithVKeyVerifier, SP1CompressWithVKeyWitnessValues, SP1DeferredVerifier,
26        SP1DeferredWitnessValues, SP1NormalizeWitnessValues, SP1RecursiveVerifier,
27    },
28    shard::RecursiveShardVerifier,
29    witness::Witnessable,
30    CircuitConfig, SP1FieldConfigVariable, WrapConfig as CircuitWrapConfig,
31};
32use sp1_recursion_compiler::{
33    circuit::AsmCompiler,
34    config::InnerConfig,
35    ir::{Builder, DslIrProgram},
36};
37use sp1_recursion_executor::{RecursionProgram, DIGEST_SIZE};
38
39use crate::{
40    shapes::{create_all_input_shapes, SP1RecursionProofShape},
41    worker::{TaskError, DEFAULT_MAX_COMPOSE_ARITY},
42    CompressAir, RecursionSC,
43};
44
45#[derive(Clone)]
46pub struct RecursionVks {
47    root: <SP1GlobalContext as IopCtx>::Digest,
48    map: BTreeMap<<SP1GlobalContext as IopCtx>::Digest, usize>,
49    tree: MerkleTree<SP1GlobalContext>,
50    vk_verification: bool,
51}
52
53impl Default for RecursionVks {
54    fn default() -> Self {
55        Self::new(None, DEFAULT_MAX_COMPOSE_ARITY, true)
56    }
57}
58
59impl RecursionVks {
60    /// The map for the recursion vk hashes to their indice in the merkle tree.
61    const RECURSION_VK_MAP_BYTES: &[u8] = include_bytes!("vk_map.bin");
62
63    fn from_map(
64        mut map: BTreeMap<[SP1Field; DIGEST_SIZE], usize>,
65        max_compose_arity: usize,
66        vk_verification: bool,
67    ) -> Self {
68        // Pad the map to the expected number of shapes. This allows us to build partial vk maps
69        // for development purposes.
70        let num_shapes = create_all_input_shapes(RiscvAir::machine().shape(), max_compose_arity)
71            .into_iter()
72            .collect::<BTreeSet<_>>()
73            .len();
74
75        let added_len = num_shapes.saturating_sub(map.len());
76        let prev_len = map.len();
77
78        map.extend((0..added_len).map(|i| {
79            let index = i + prev_len;
80            ([SP1Field::from_canonical_u32(index as u32); DIGEST_SIZE], index)
81        }));
82
83        let vks = map.into_keys().collect::<BTreeSet<_>>();
84        let map: BTreeMap<_, _> = vks.into_iter().enumerate().map(|(i, vk)| (vk, i)).collect();
85
86        // Commit the merkle tree.
87        let (root, tree) = MerkleTree::<SP1GlobalContext>::commit(map.keys().copied().collect());
88
89        Self { root, map, tree, vk_verification }
90    }
91
92    fn dummy(max_compose_arity: usize) -> Self {
93        Self::from_map(BTreeMap::new(), max_compose_arity, false)
94    }
95
96    fn from_file(path: PathBuf, max_compose_arity: usize, vk_verification: bool) -> Self {
97        let file = std::fs::File::open(path).expect("failed to open vk map file");
98        let map = bincode::deserialize_from(file).expect("failed to deserialize vk map");
99        Self::from_map(map, max_compose_arity, vk_verification)
100    }
101
102    pub fn new(path: Option<PathBuf>, max_compose_arity: usize, vk_verification: bool) -> Self {
103        if !vk_verification {
104            return Self::dummy(max_compose_arity);
105        }
106
107        if let Some(path) = path {
108            return Self::from_file(path, max_compose_arity, vk_verification);
109        }
110
111        let map = bincode::deserialize(Self::RECURSION_VK_MAP_BYTES)
112            .expect("failed to deserialize vk map");
113        Self::from_map(map, max_compose_arity, vk_verification)
114    }
115
116    pub fn root(&self) -> <SP1GlobalContext as IopCtx>::Digest {
117        self.root
118    }
119
120    pub fn contains(&self, vk: &MachineVerifyingKey<SP1GlobalContext>) -> bool {
121        self.map.contains_key(&vk.hash_koalabear())
122    }
123
124    pub fn num_keys(&self) -> usize {
125        self.map.len()
126    }
127
128    /// Whether to verify the recursion vks.
129    pub fn vk_verification(&self) -> bool {
130        self.vk_verification
131    }
132
133    pub fn open(
134        &self,
135        vk: &MachineVerifyingKey<SP1GlobalContext>,
136    ) -> Result<([SP1Field; DIGEST_SIZE], MerkleProof<SP1GlobalContext>), TaskError> {
137        let index = if self.vk_verification {
138            let digest = vk.hash_koalabear();
139            let index = self
140                .map
141                .get(&digest)
142                .copied()
143                .ok_or(TaskError::Fatal(anyhow::anyhow!("vk not allowed")))?;
144            index
145        } else {
146            let vk_digest = vk.hash_koalabear();
147            let num_vks = self.num_keys();
148            (vk_digest[0].as_canonical_u32() as usize) % num_vks
149        };
150
151        let (value, proof) = MerkleTree::open(&self.tree, index);
152        // Verify the proof.
153        verify_merkle_proof(&proof, value, self.root)
154            .map_err(|e| TaskError::Fatal(anyhow::anyhow!("invalid merkle proof: {:?}", e)))?;
155
156        Ok((value, proof))
157    }
158
159    pub fn verify(
160        &self,
161        proof: &MerkleProof<SP1GlobalContext>,
162        vk: &MachineVerifyingKey<SP1GlobalContext>,
163    ) -> Result<(), TaskError> {
164        let mut digest = vk.hash_koalabear();
165        if !self.vk_verification {
166            let num_vks = self.num_keys();
167            let vk_index = digest[0].as_canonical_u32() % num_vks as u32;
168            digest = [SP1Field::from_canonical_u32(vk_index); DIGEST_SIZE];
169        }
170        let result = verify_merkle_proof(proof, digest, self.root)
171            .map_err(|e| TaskError::Fatal(anyhow::anyhow!("invalid merkle proof: {:?}", e)));
172        result
173    }
174
175    pub fn height(&self) -> usize {
176        self.tree.height
177    }
178}
179
180/// The program that proves the correct execution of the verifier of a single shard of the core
181/// (RISC-V) machine.
182pub fn normalize_program_from_input(
183    recursive_verifier: &RecursiveShardVerifier<SP1GlobalContext, RiscvAir<SP1Field>, InnerConfig>,
184    input: &SP1NormalizeWitnessValues<SP1GlobalContext, SP1PcsProofInner>,
185) -> RecursionProgram<SP1Field> {
186    // Get the operations.
187    let builder_span = tracing::debug_span!("build recursion program").entered();
188    let mut builder = Builder::<InnerConfig>::default();
189    let input_variable = input.read(&mut builder);
190    SP1RecursiveVerifier::<InnerConfig>::verify(&mut builder, recursive_verifier, input_variable);
191    let block = builder.into_root_block();
192    // SAFETY: The circuit is well-formed. It does not use synchronization primitives
193    // (or possibly other means) to violate the invariants.
194    let dsl_program = unsafe { DslIrProgram::new_unchecked(block) };
195    builder_span.exit();
196
197    // Compile the program.
198    let compiler_span = tracing::debug_span!("compile recursion program").entered();
199    let mut compiler = AsmCompiler::default();
200    let program = compiler.compile(dsl_program);
201    compiler_span.exit();
202    program
203}
204
205/// The deferred program.
206pub(crate) fn deferred_program_from_input(
207    recursive_verifier: &RecursiveShardVerifier<
208        SP1GlobalContext,
209        CompressAir<InnerVal>,
210        InnerConfig,
211    >,
212    vk_verification: bool,
213    input: &SP1DeferredWitnessValues<SP1GlobalContext, SP1PcsProofInner>,
214) -> RecursionProgram<SP1Field> {
215    // Get the operations.
216    let operations_span = tracing::debug_span!("get operations for the deferred program").entered();
217    let mut builder = Builder::<InnerConfig>::default();
218    let input_read_span = tracing::debug_span!("Read input values").entered();
219    let input = input.read(&mut builder);
220    input_read_span.exit();
221    let verify_span = tracing::debug_span!("Verify deferred program").entered();
222
223    // Verify the proof.
224    SP1DeferredVerifier::verify(&mut builder, recursive_verifier, input, vk_verification);
225    verify_span.exit();
226    let block = builder.into_root_block();
227    operations_span.exit();
228    // SAFETY: The circuit is well-formed. It does not use synchronization primitives
229    // (or possibly other means) to violate the invariants.
230    let dsl_program = unsafe { DslIrProgram::new_unchecked(block) };
231
232    let compiler_span = tracing::debug_span!("compile deferred program").entered();
233    let mut compiler = AsmCompiler::default();
234    let program = compiler.compile(dsl_program);
235    compiler_span.exit();
236    program
237}
238
239/// The "compose" program, which verifies some number of normalized shard proofs.
240pub(crate) fn compose_program_from_input(
241    recursive_verifier: &RecursiveShardVerifier<
242        SP1GlobalContext,
243        CompressAir<InnerVal>,
244        InnerConfig,
245    >,
246    vk_verification: bool,
247    input: &SP1CompressWithVKeyWitnessValues<SP1PcsProofInner>,
248) -> RecursionProgram<SP1Field> {
249    let builder_span = tracing::debug_span!("build compress program").entered();
250    let mut builder = Builder::<InnerConfig>::default();
251    // read the input.
252    let input = input.read(&mut builder);
253
254    // Verify the proof.
255    SP1CompressWithVKeyVerifier::<InnerConfig, SP1InnerPcs, _>::verify(
256        &mut builder,
257        recursive_verifier,
258        input,
259        vk_verification,
260        PublicValuesOutputDigest::Reduce,
261    );
262    let block = builder.into_root_block();
263    builder_span.exit();
264    // SAFETY: The circuit is well-formed. It does not use synchronization primitives
265    // (or possibly other means) to violate the invariants.
266    let dsl_program = unsafe { DslIrProgram::new_unchecked(block) };
267
268    // Compile the program.
269    let compiler_span = tracing::debug_span!("compile compress program").entered();
270    let mut compiler = AsmCompiler::default();
271    let program = compiler.compile(dsl_program);
272    compiler_span.exit();
273    program
274}
275
276/// The "shrink" program, which only verifies the single root shard.
277pub(crate) fn shrink_program_from_input(
278    recursive_verifier: &RecursiveShardVerifier<
279        SP1GlobalContext,
280        CompressAir<InnerVal>,
281        InnerConfig,
282    >,
283    vk_verification: bool,
284    input: &SP1CompressWithVKeyWitnessValues<SP1PcsProofInner>,
285) -> RecursionProgram<SP1Field> {
286    let builder_span = tracing::debug_span!("build shrink program").entered();
287    let mut builder = Builder::<InnerConfig>::default();
288    // read the input.
289    let input = input.read(&mut builder);
290
291    // Verify the root proof.
292    SP1CompressRootVerifierWithVKey::<InnerConfig, _>::verify(
293        &mut builder,
294        recursive_verifier,
295        input,
296        vk_verification,
297        PublicValuesOutputDigest::Reduce,
298    );
299
300    let block = builder.into_root_block();
301    builder_span.exit();
302    // SAFETY: The circuit is well-formed. It does not use synchronization primitives
303    // (or possibly other means) to violate the invariants.
304    let dsl_program = unsafe { DslIrProgram::new_unchecked(block) };
305
306    // Compile the program.
307    let compiler_span = tracing::debug_span!("compile shrink program").entered();
308    let mut compiler = AsmCompiler::default();
309    let program = compiler.compile(dsl_program);
310    compiler_span.exit();
311
312    program
313}
314
315/// The "wrap" program, which only verifies the single root shard.
316pub(crate) fn wrap_program_from_input(
317    recursive_verifier: &RecursiveShardVerifier<
318        SP1GlobalContext,
319        CompressAir<InnerVal>,
320        CircuitWrapConfig,
321    >,
322    vk_verification: bool,
323    input: &SP1CompressWithVKeyWitnessValues<SP1PcsProofInner>,
324) -> RecursionProgram<SP1Field> {
325    let builder_span = tracing::debug_span!("build wrap program").entered();
326    let mut builder = Builder::<CircuitWrapConfig>::default();
327    // read the input.
328    let input = input.read(&mut builder);
329
330    // Verify the root proof.
331    SP1CompressRootVerifierWithVKey::<CircuitWrapConfig, _>::verify(
332        &mut builder,
333        recursive_verifier,
334        input,
335        vk_verification,
336        PublicValuesOutputDigest::Root,
337    );
338
339    let block = builder.into_root_block();
340    builder_span.exit();
341    // SAFETY: The circuit is well-formed. It does not use synchronization primitives
342    // (or possibly other means) to violate the invariants.
343    let dsl_program = unsafe { DslIrProgram::new_unchecked(block) };
344
345    // Compile the program.
346    let compiler_span = tracing::debug_span!("compile wrap program").entered();
347    let mut compiler = AsmCompiler::default();
348    let program = compiler.compile(dsl_program);
349    compiler_span.exit();
350
351    program
352}
353
354pub(crate) fn dummy_compose_input(
355    verifier: &MachineVerifier<SP1GlobalContext, RecursionSC>,
356    shape: &SP1RecursionProofShape,
357    arity: usize,
358    height: usize,
359) -> SP1CompressWithVKeyWitnessValues<SP1PcsProofInner> {
360    let chips =
361        verifier.shard_verifier().machine().chips().iter().cloned().collect::<BTreeSet<_>>();
362
363    let max_log_row_count = verifier.max_log_row_count();
364    let log_stacking_height = verifier.log_stacking_height() as usize;
365
366    shape.dummy_input(
367        arity,
368        height,
369        chips,
370        max_log_row_count,
371        *verifier.fri_config(),
372        log_stacking_height,
373    )
374}
375
376pub(crate) fn dummy_deferred_input(
377    verifier: &MachineVerifier<SP1GlobalContext, RecursionSC>,
378    shape: &SP1RecursionProofShape,
379    height: usize,
380) -> SP1DeferredWitnessValues<SP1GlobalContext, SP1PcsProofInner> {
381    let chips =
382        verifier.shard_verifier().machine().chips().iter().cloned().collect::<BTreeSet<_>>();
383
384    let max_log_row_count = verifier.max_log_row_count();
385    let log_stacking_height = verifier.log_stacking_height() as usize;
386
387    let compress_input = shape.dummy_input(
388        1,
389        height,
390        chips,
391        max_log_row_count,
392        *verifier.fri_config(),
393        log_stacking_height,
394    );
395
396    SP1DeferredWitnessValues {
397        vks_and_proofs: compress_input.compress_val.vks_and_proofs,
398        vk_merkle_data: compress_input.merkle_val,
399        start_reconstruct_deferred_digest: [SP1Field::zero(); POSEIDON_NUM_WORDS],
400        sp1_vk_digest: [SP1Field::zero(); DIGEST_SIZE],
401        end_pc: [SP1Field::zero(); 3],
402        proof_nonce: [SP1Field::zero(); PROOF_NONCE_NUM_WORDS],
403        deferred_proof_index: SP1Field::zero(),
404    }
405}
406
407pub(crate) fn recursive_verifier<GC, A, C>(
408    shard_verifier: &ShardVerifier<GC, SP1SC<GC, A>>,
409) -> RecursiveShardVerifier<GC, A, C>
410where
411    GC: IopCtx<F = SP1Field, EF = SP1ExtensionField> + SP1FieldConfigVariable<C>,
412    A: ZerocheckAir<SP1Field, SP1ExtensionField>,
413    C: CircuitConfig,
414{
415    let log_stacking_height = shard_verifier.log_stacking_height();
416    let max_log_row_count = shard_verifier.max_log_row_count();
417    let machine = shard_verifier.machine().clone();
418    let pcs_verifier = RecursiveBasefoldVerifier {
419        fri_config: shard_verifier.jagged_pcs_verifier.pcs_verifier.basefold_verifier.fri_config,
420        tcs: RecursiveMerkleTreeTcs::<C, GC>(PhantomData),
421    };
422    let recursive_verifier = RecursiveStackedPcsVerifier::new(pcs_verifier, log_stacking_height);
423
424    let recursive_jagged_verifier = RecursiveJaggedPcsVerifier {
425        stacked_pcs_verifier: recursive_verifier,
426        max_log_row_count,
427        jagged_evaluator: RecursiveJaggedEvalSumcheckConfig::<GC>(PhantomData),
428    };
429
430    RecursiveShardVerifier {
431        machine,
432        pcs_verifier: recursive_jagged_verifier,
433        _phantom: std::marker::PhantomData,
434    }
435}