Skip to main content

sp1_prover/worker/prover/
recursion.rs

1use crate::{
2    build::{try_build_groth16_artifacts_dir, try_build_plonk_artifacts_dir},
3    recursion::{
4        compose_program_from_input, deferred_program_from_input, dummy_deferred_input,
5        recursive_verifier, shrink_program_from_input, wrap_program_from_input, RecursionVks,
6    },
7    shapes::SP1RecursionProofShape,
8    verify::WRAP_VK_BYTES,
9    worker::{
10        CommonProverInput, DeferredInputs, ProverMetrics, RangeProofs, RawTaskRequest, TaskContext,
11        TaskError, TaskMetadata, WrapAirProverInit,
12    },
13    RecursionSC, SP1CircuitWitness, SP1ProverComponents,
14};
15use slop_algebra::PrimeField32;
16use slop_algebra::{AbstractField, PrimeField};
17use slop_bn254::Bn254Fr;
18use slop_challenger::IopCtx;
19use slop_futures::pipeline::{
20    AsyncEngine, AsyncWorker, BlockingEngine, BlockingWorker, Chain, Pipeline, SubmitError,
21    SubmitHandle,
22};
23use sp1_hypercube::{
24    inner_perm, koalabears_to_bn254,
25    prover::{AirProver, ProverSemaphore, ProvingKey},
26    HashableKey, MachineProof, MachineVerifier, MachineVerifyingKey, MerkleProof, SP1PcsProofInner,
27    SP1PcsProofOuter, SP1RecursionProof, SP1WrapProof, ShardProof, DIGEST_SIZE,
28};
29use sp1_primitives::{SP1ExtensionField, SP1Field, SP1GlobalContext, SP1OuterGlobalContext};
30use sp1_prover_types::{Artifact, ArtifactClient, ArtifactId};
31use sp1_recursion_circuit::{
32    machine::{
33        SP1CompressWithVKeyWitnessValues, SP1MerkleProofWitnessValues, SP1NormalizeWitnessValues,
34        SP1ShapedWitnessValues,
35    },
36    utils::{koalabear_bytes_to_bn254, koalabears_proof_nonce_to_bn254, words_to_bytes},
37    witness::{OuterWitness, Witnessable},
38    WrapConfig,
39};
40use sp1_recursion_compiler::config::InnerConfig;
41use sp1_recursion_executor::{
42    shape::RecursionShape, Block, ExecutionRecord, Executor, RecursionProgram,
43    RecursionPublicValues,
44};
45use sp1_recursion_gnark_ffi::{Groth16Bn254Prover, PlonkBn254Prover};
46use std::{
47    borrow::Borrow,
48    collections::{BTreeMap, BTreeSet, VecDeque},
49    sync::Arc,
50};
51use tokio::sync::{oneshot, OnceCell};
52use tracing::Instrument;
53
/// Tuning knobs for the recursion prover.
///
/// Holds worker counts and channel buffer sizes for each pipeline stage, the maximum compose
/// arity, and the verification switches.
#[derive(Debug, Clone)]
pub struct SP1RecursionProverConfig {
    /// How many workers prepare reduce tasks.
    pub num_prepare_reduce_workers: usize,
    /// Channel buffer size in front of the prepare-reduce workers.
    pub prepare_reduce_buffer_size: usize,
    /// How many workers execute recursion programs.
    pub num_recursion_executor_workers: usize,
    /// Channel buffer size in front of the recursion executor workers.
    pub recursion_executor_buffer_size: usize,
    /// How many workers prove recursion shards.
    pub num_recursion_prover_workers: usize,
    /// Channel buffer size in front of the recursion prover workers.
    pub recursion_prover_buffer_size: usize,
    /// Largest number of proofs a single compose (reduce) program accepts.
    pub max_compose_arity: usize,
    /// Whether recursion vks are verified. Always `true` unless turned off through the
    /// `experimental`-gated `without_vk_verification` builder.
    vk_verification: bool,
    /// Whether each recursion shard proof is verified right after it is produced.
    pub verify_intermediates: bool,
    /// Optional path to a recursion vk map. Always `None` unless set through the
    /// `experimental`-gated `with_vk_map_path` builder.
    vk_map_file: Option<String>,
}

impl SP1RecursionProverConfig {
    /// Builds a config from the given stage sizes.
    ///
    /// vk verification starts enabled and no custom vk map path is set.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        num_prepare_reduce_workers: usize,
        prepare_reduce_buffer_size: usize,
        num_recursion_executor_workers: usize,
        recursion_executor_buffer_size: usize,
        num_recursion_prover_workers: usize,
        recursion_prover_buffer_size: usize,
        max_compose_arity: usize,
        verify_intermediates: bool,
    ) -> Self {
        // The private fields always start at their safe defaults; only the
        // `experimental` builders below may change them.
        let vk_verification = true;
        let vk_map_file = None;
        Self {
            max_compose_arity,
            num_prepare_reduce_workers,
            num_recursion_executor_workers,
            num_recursion_prover_workers,
            prepare_reduce_buffer_size,
            recursion_executor_buffer_size,
            recursion_prover_buffer_size,
            verify_intermediates,
            vk_map_file,
            vk_verification,
        }
    }

    /// Turn off vk verification for recursion proofs.
    #[cfg(feature = "experimental")]
    pub fn without_vk_verification(mut self) -> Self {
        self.vk_verification = false;
        self
    }

    /// Set the path to the recursion vk map.
    #[cfg(feature = "experimental")]
    pub fn with_vk_map_path(mut self, vk_map_path: String) -> Self {
        self.vk_map_file = Some(vk_map_path);
        self
    }
}
118
/// A request to reduce a range of proofs into a single recursion proof.
///
/// Converted to and from its raw artifact encoding with `from_raw` / `into_raw`.
pub struct ReduceTaskRequest {
    /// The proofs to compose; their count selects the compose program's arity.
    pub range_proofs: RangeProofs,
    /// Forwarded as the `is_complete` flag when downloading the witness.
    pub is_complete: bool,
    /// Artifact that receives the resulting proof.
    pub output: Artifact,
    /// Task context carried through the raw encoding.
    pub context: TaskContext,
}
125
126impl ReduceTaskRequest {
127    pub fn from_raw(request: RawTaskRequest) -> Result<Self, TaskError> {
128        let RawTaskRequest { inputs, mut outputs, context } = request;
129        let is_complete = inputs[0].id().parse::<bool>().map_err(|e| TaskError::Fatal(e.into()))?;
130        let range_proofs = RangeProofs::from_artifacts(&inputs[1..])?;
131        let output =
132            outputs.pop().ok_or(TaskError::Fatal(anyhow::anyhow!("No output artifact")))?;
133        Ok(ReduceTaskRequest { range_proofs, is_complete, output, context })
134    }
135
136    pub fn into_raw(self) -> Result<RawTaskRequest, TaskError> {
137        let ReduceTaskRequest { range_proofs, is_complete, output, context } = self;
138        let is_complete_artifact = Artifact::from(is_complete.to_string());
139        let mut inputs = Vec::with_capacity(2 * range_proofs.len() + 2);
140        inputs.push(is_complete_artifact);
141        inputs.extend(range_proofs.as_artifacts());
142        let raw_task_request = RawTaskRequest { inputs, outputs: vec![output], context };
143        Ok(raw_task_request)
144    }
145}
146
/// Async worker that turns a [`ReduceTaskRequest`] into a [`RecursionTask`].
///
/// It selects the compose program for the request's arity and downloads the witness.
pub struct PrepareReduceTaskWorker<A, C: SP1ProverComponents> {
    /// Shared recursion programs, keys and vk data.
    prover_data: Arc<RecursionProverData<C>>,
    /// Client passed to `download_witness` to fetch the input proofs.
    artifact_client: A,
}
151
152impl<A: ArtifactClient, C: SP1ProverComponents>
153    AsyncWorker<ReduceTaskRequest, Result<RecursionTask, TaskError>>
154    for PrepareReduceTaskWorker<A, C>
155{
156    #[tracing::instrument(level = "trace", name = "prepare_reduce_task", skip(self, input))]
157    async fn call(&self, input: ReduceTaskRequest) -> Result<RecursionTask, TaskError> {
158        let ReduceTaskRequest { range_proofs, is_complete, output, .. } = input;
159
160        let program = self.prover_data.compose_programs.get(&range_proofs.len()).cloned().ok_or(
161            TaskError::Fatal(anyhow::anyhow!(
162                "Compress program not found for arity {}",
163                range_proofs.len()
164            )),
165        )?;
166
167        let witness = range_proofs
168            .download_witness::<C>(is_complete, &self.artifact_client, &self.prover_data)
169            .await?;
170
171        let metrics = ProverMetrics::new();
172        Ok(RecursionTask {
173            program,
174            witness,
175            output,
176            metrics,
177            range_proofs_to_cleanup: Some(range_proofs),
178        })
179    }
180}
181
/// A recursion program paired with its witness, ready to be executed.
pub struct RecursionTask {
    /// The recursion program to execute.
    program: Arc<RecursionProgram<SP1Field>>,
    /// Input from which the executor's witness stream is built.
    witness: SP1CircuitWitness,
    /// Input proofs to delete once the resulting proof has been uploaded.
    /// `None` for tasks submitted through `submit_prove_shard`.
    range_proofs_to_cleanup: Option<RangeProofs>,
    /// Artifact that receives the resulting proof.
    output: Artifact,
    /// Metrics accumulated for this task (e.g. prover permit time).
    metrics: ProverMetrics,
}
189
/// Blocking worker that executes a [`RecursionTask`] and selects the keys to prove it with.
pub struct RecursionExecutorWorker<C: SP1ProverComponents> {
    /// Compress verifier; its machine generates the dependencies of each execution record.
    compress_verifier: MachineVerifier<SP1GlobalContext, RecursionSC>,
    /// Shared recursion programs, keys and vk data.
    prover_data: Arc<RecursionProverData<C>>,
}
194
195impl<C: SP1ProverComponents>
196    BlockingWorker<Result<RecursionTask, TaskError>, Result<ProveRecursionTask<C>, TaskError>>
197    for RecursionExecutorWorker<C>
198{
199    fn call(
200        &self,
201        input: Result<RecursionTask, TaskError>,
202    ) -> Result<ProveRecursionTask<C>, TaskError> {
203        let RecursionTask { program, witness, output, metrics, range_proofs_to_cleanup } = input?;
204
205        // Execute the runtime.
206        let runtime_span = tracing::debug_span!("execute runtime").entered();
207        let mut runtime =
208            Executor::<SP1Field, SP1ExtensionField, _>::new(program.clone(), inner_perm());
209        runtime.witness_stream = self.prover_data.witness_stream(&witness)?;
210        runtime.run().map_err(|e| TaskError::Fatal(e.into()))?;
211        let mut record = runtime.record;
212        runtime_span.exit();
213
214        tokio::task::spawn_blocking(move || {
215            drop(runtime.memory);
216            drop(runtime.program);
217            drop(runtime.witness_stream);
218        });
219
220        // Generate the dependencies.
221        tracing::debug_span!("generate dependencies").in_scope(|| {
222            self.compress_verifier
223                .machine()
224                .generate_dependencies(std::iter::once(&mut record), None)
225        });
226
227        let keys = tracing::debug_span!("get keys").in_scope(|| match witness {
228            SP1CircuitWitness::Core(_) => anyhow::Ok(RecursionKeys::Program(program)),
229            SP1CircuitWitness::Compress(input) => {
230                let arity = input.compress_val.vks_and_proofs.len();
231                let (pk, vk) = self.prover_data.compose_keys.get(&arity).cloned().ok_or(
232                    TaskError::Fatal(anyhow::anyhow!("Compose key not found for arity {}", arity)),
233                )?;
234                anyhow::Ok(RecursionKeys::Exists(pk, vk))
235            }
236            SP1CircuitWitness::Deferred(_) => {
237                let keys = self
238                    .prover_data
239                    .deferred_keys
240                    .clone()
241                    .map(|(pk, vk)| RecursionKeys::Exists(pk, vk))
242                    .unwrap_or_else(|| {
243                        RecursionKeys::Program(self.prover_data.deferred_program.clone())
244                    });
245                anyhow::Ok(keys)
246            }
247            _ => unimplemented!(),
248        })?;
249
250        Ok(ProveRecursionTask { record, keys, output, metrics, range_proofs_to_cleanup })
251    }
252}
253
/// Proving key for recursion (compress) programs, as produced by the components'
/// recursion prover.
pub type CompressProvingKey<C> =
    ProvingKey<SP1GlobalContext, RecursionSC, <C as SP1ProverComponents>::RecursionProver>;
256
/// How a recursion shard obtains its keys for proving.
enum RecursionKeys<C: SP1ProverComponents> {
    /// Precomputed proving and verifying keys; used for compose programs, and for the
    /// deferred program when its keys exist.
    Exists(Arc<CompressProvingKey<C>>, MachineVerifyingKey<SP1GlobalContext>),
    /// No precomputed keys: the program is set up and proved in a single call.
    Program(Arc<RecursionProgram<SP1Field>>),
}
261
/// An executed recursion task, ready to be proved.
pub struct ProveRecursionTask<C: SP1ProverComponents> {
    /// Execution record produced by the recursion executor.
    record: ExecutionRecord<SP1Field>,
    /// Keys (or the program to set up) used to prove `record`.
    keys: RecursionKeys<C>,
    /// Artifact that receives the resulting proof.
    output: Artifact,
    /// Metrics accumulated for this task.
    metrics: ProverMetrics,
    /// Input proofs to delete after the resulting proof has been uploaded.
    range_proofs_to_cleanup: Option<RangeProofs>,
}
269
/// Async worker that proves executed recursion records and uploads the resulting proofs.
pub struct RecursionProverWorker<A, C: SP1ProverComponents> {
    /// Prover for recursion shards.
    recursion_prover: Arc<C::RecursionProver>,
    /// Semaphore handed to the prover; the time each permit is held is recorded in metrics.
    permits: ProverSemaphore,
    /// Client used to upload proofs and delete consumed input proofs.
    artifact_client: A,
    /// Whether each proof is verified right after it is produced.
    verify_intermediates: bool,
    /// Shared recursion programs, keys and vk data.
    prover_data: Arc<RecursionProverData<C>>,
}
277
278impl<A: ArtifactClient, C: SP1ProverComponents> RecursionProverWorker<A, C> {
279    async fn prove_shard(
280        &self,
281        keys: RecursionKeys<C>,
282        record: ExecutionRecord<SP1Field>,
283        metrics: ProverMetrics,
284    ) -> Result<SP1RecursionProof<SP1GlobalContext, SP1PcsProofInner>, TaskError> {
285        let proof = match keys {
286            RecursionKeys::Exists(pk, vk) => {
287                let (proof, permit) = self
288                    .recursion_prover
289                    .prove_shard_with_pk(pk.clone(), record, self.permits.clone())
290                    .await;
291                let duration = permit.release();
292                metrics.increment_permit_time(duration);
293
294                if self.verify_intermediates {
295                    let proof = proof.clone();
296                    let vk = vk.clone();
297                    let parent = tracing::Span::current();
298                    tokio::task::spawn_blocking(move || {
299                        let _guard = parent.enter();
300                        C::compress_verifier()
301                            .verify(&vk, &MachineProof::from(vec![proof]))
302                            .map_err(|e| {
303                                TaskError::Retryable(anyhow::anyhow!(
304                                    "compress verify failed: {}",
305                                    e
306                                ))
307                            })
308                    })
309                    .await
310                    .map_err(|e| TaskError::Fatal(e.into()))??;
311                }
312                let vk_merkle_proof = self.prover_data.recursion_vks.open(&vk)?.1;
313                SP1RecursionProof { vk, proof, vk_merkle_proof }
314            }
315            RecursionKeys::Program(program) => {
316                let (vk, proof, permit) = self
317                    .recursion_prover
318                    .setup_and_prove_shard(program, record, None, self.permits.clone())
319                    .await;
320                let duration = permit.release();
321                metrics.increment_permit_time(duration);
322                if self.verify_intermediates {
323                    let proof = proof.clone();
324                    let vk = vk.clone();
325                    let parent = tracing::Span::current();
326                    tokio::task::spawn_blocking(move || {
327                        let _guard = parent.enter();
328                        C::compress_verifier()
329                            .verify(&vk, &MachineProof::from(vec![proof.clone()]))
330                            .map_err(|e| {
331                                TaskError::Retryable(anyhow::anyhow!(
332                                    "lift/deferred verify failed: {}",
333                                    e
334                                ))
335                            })
336                    })
337                    .await
338                    .map_err(|e| TaskError::Fatal(e.into()))??;
339                }
340                let vk_merkle_proof = self.prover_data.recursion_vks.open(&vk)?.1;
341                SP1RecursionProof { vk, proof, vk_merkle_proof }
342            }
343        };
344        Ok(proof)
345    }
346}
347
348impl<A: ArtifactClient, C: SP1ProverComponents>
349    AsyncWorker<Result<ProveRecursionTask<C>, TaskError>, Result<TaskMetadata, TaskError>>
350    for RecursionProverWorker<A, C>
351{
352    async fn call(
353        &self,
354        input: Result<ProveRecursionTask<C>, TaskError>,
355    ) -> Result<TaskMetadata, TaskError> {
356        // Get the input or return an error
357        let ProveRecursionTask { record, keys, output, metrics, range_proofs_to_cleanup } = input?;
358        // Prove the shard
359        let proof = self.prove_shard(keys, record, metrics.clone()).await?;
360        // Upload the proof
361
362        self.artifact_client.upload(&output, proof.clone()).await?;
363        let metadata = metrics.to_metadata();
364
365        // Delete the proofs to cleanup.
366        if let Some(proofs_to_cleanup) = range_proofs_to_cleanup {
367            proofs_to_cleanup.try_delete_proofs(&self.artifact_client).await?;
368        }
369
370        Ok(metadata)
371    }
372}
373
/// Blocking stage that executes recursion programs.
type ExecutorEngine<C> = Arc<
    BlockingEngine<
        Result<RecursionTask, TaskError>,
        Result<ProveRecursionTask<C>, TaskError>,
        RecursionExecutorWorker<C>,
    >,
>;

/// Async stage that proves executed records and uploads the proofs.
type RecursionProverEngine<A, C> = Arc<
    AsyncEngine<
        Result<ProveRecursionTask<C>, TaskError>,
        Result<TaskMetadata, TaskError>,
        RecursionProverWorker<A, C>,
    >,
>;

/// Async stage that turns reduce requests into recursion tasks.
type PrepareReduceEngine<A, C> = Arc<
    AsyncEngine<ReduceTaskRequest, Result<RecursionTask, TaskError>, PrepareReduceTaskWorker<A, C>>,
>;

/// Execute-then-prove pipeline for recursion tasks.
type RecursionProvePipeline<A, C> = Chain<ExecutorEngine<C>, RecursionProverEngine<A, C>>;

/// Full reduce pipeline: prepare, then execute and prove.
type ReducePipeline<A, C> = Chain<PrepareReduceEngine<A, C>, Arc<RecursionProvePipeline<A, C>>>;

/// Handle returned when submitting to the execute-then-prove pipeline.
pub type RecursionProveSubmitHandle<A, C> = SubmitHandle<RecursionProvePipeline<A, C>>;

/// Handle returned when submitting to the reduce pipeline.
pub type ReduceSubmitHandle<A, C> = SubmitHandle<ReducePipeline<A, C>>;
401
/// Worker-side recursion prover.
///
/// Owns the reduce pipeline and the shrink, wrap, Groth16 and PLONK stages.
pub struct SP1RecursionProver<A, C: SP1ProverComponents> {
    /// Prepare → execute → prove pipeline for reduce tasks.
    reduce_pipeline: Arc<ReducePipeline<A, C>>,
    /// Prover for the shrink stage.
    pub shrink_prover: Arc<ShrinkProver<C>>,
    /// Wrap prover, built lazily on first use.
    wrap_prover: Arc<OnceCell<Arc<WrapProver<C>>>>,
    /// Everything needed to build the wrap prover.
    wrap_prover_init: Arc<WrapProverInit<C>>,
    /// Shared recursion programs, keys and vk data.
    pub prover_data: Arc<RecursionProverData<C>>,
    /// Client for downloading and uploading artifacts.
    artifact_client: A,
}
410
/// Inputs needed to construct the wrap prover lazily.
struct WrapProverInit<C: SP1ProverComponents> {
    /// Builds the wrap AIR prover and supplies its permits.
    wrap_air_prover: WrapAirProverInit<C>,
    /// Recursion prover configuration passed to the wrap prover.
    config: SP1RecursionProverConfig,
    /// Shrink shape, taken from the shrink prover.
    shrink_shape: BTreeMap<String, usize>,
    /// Wrap vk deserialized from `WRAP_VK_BYTES`.
    /// The built wrap prover's vk must match it when vk verification is enabled.
    expected_wrap_vk: MachineVerifyingKey<SP1OuterGlobalContext>,
}
417
418impl<A: Clone, C: SP1ProverComponents> Clone for SP1RecursionProver<A, C> {
419    fn clone(&self) -> Self {
420        Self {
421            reduce_pipeline: self.reduce_pipeline.clone(),
422            shrink_prover: self.shrink_prover.clone(),
423            wrap_prover: self.wrap_prover.clone(),
424            wrap_prover_init: self.wrap_prover_init.clone(),
425            prover_data: self.prover_data.clone(),
426            artifact_client: self.artifact_client.clone(),
427        }
428    }
429}
430
431impl<A: ArtifactClient, C: SP1ProverComponents> SP1RecursionProver<A, C> {
432    pub async fn new(
433        config: SP1RecursionProverConfig,
434        artifact_client: A,
435        (compress_prover, compress_prover_permits): (Arc<C::RecursionProver>, ProverSemaphore),
436        (shrink_prover, shrink_prover_permits): (Arc<C::RecursionProver>, ProverSemaphore),
437        wrap_air_prover_init: WrapAirProverInit<C>,
438    ) -> Self {
439        tokio::task::spawn_blocking(move || {
440            // Get the reduce shape.
441            let reduce_shape =
442                SP1RecursionProofShape::compress_proof_shape_from_arity(config.max_compose_arity)
443                    .expect("arity not supported");
444
445            // Make the reduce programs and keys.
446            let mut compose_programs = BTreeMap::new();
447            let mut compose_keys = BTreeMap::new();
448
449            let vk_map_path = config.vk_map_file.as_ref().map(std::path::PathBuf::from);
450
451            let recursion_vks =
452                RecursionVks::new(vk_map_path, config.max_compose_arity, config.vk_verification);
453
454            let recursion_vks_height = recursion_vks.height();
455
456            let compress_verifier = C::compress_verifier();
457            let recursive_compress_verifier =
458                recursive_verifier::<SP1GlobalContext, _,  InnerConfig>(
459                    compress_verifier.shard_verifier(),
460                );
461            for arity in 1..=config.max_compose_arity {
462                let dummy_input =
463                    dummy_compose_input::<C>(&reduce_shape, arity, recursion_vks_height);
464                let mut program = compose_program_from_input(
465                    &recursive_compress_verifier,
466                    config.vk_verification,
467                    &dummy_input,
468                );
469                program.shape = Some(reduce_shape.shape.clone());
470                let program = Arc::new(program);
471
472                // Make the reduce keys.
473                let (tx, rx) = oneshot::channel();
474                tokio::task::spawn({
475                    let program = program.clone();
476                    let air_prover = compress_prover.clone();
477                    async move {
478                        let permits = ProverSemaphore::new(1);
479                        let (pk, vk) = air_prover.setup(program, permits).await;
480                        tx.send((pk, vk)).ok();
481                    }
482                });
483                let (pk, vk) = rx.blocking_recv().unwrap();
484                let pk = unsafe { pk.into_inner() };
485                compose_keys.insert(arity, (pk, vk));
486                compose_programs.insert(arity, program);
487            }
488
489            // Make the deferred program and keys.
490            let deferred_input =
491                dummy_deferred_input(&compress_verifier, &reduce_shape, recursion_vks_height);
492            let mut deferred_program = deferred_program_from_input(
493                &recursive_compress_verifier,
494                config.vk_verification,
495                &deferred_input,
496            );
497            deferred_program.shape = Some(reduce_shape.shape.clone());
498            let deferred_program = Arc::new(deferred_program);
499            let (tx, rx) = oneshot::channel();
500            tokio::task::spawn({
501                let program = deferred_program.clone();
502                let air_prover = compress_prover.clone();
503                async move {
504                    let permits = ProverSemaphore::new(1);
505                    let (pk, vk) = air_prover.setup(program, permits).await;
506                    tx.send((pk, vk)).ok();
507                }
508            });
509            let (pk, vk) = rx.blocking_recv().unwrap();
510            let pk = unsafe { pk.into_inner() };
511            let deferred_keys = (pk, vk);
512
513            let prover_data = Arc::new(RecursionProverData {
514                recursion_vks,
515                reduce_shape,
516                compose_programs,
517                compose_keys,
518                deferred_program,
519                deferred_keys: Some(deferred_keys),
520            });
521
522            let compress_verifier = C::compress_verifier();
523
524            // Initialize the prepare reduce engine.
525            let prepare_reduce_workers = (0..config.num_prepare_reduce_workers)
526                .map(|_| PrepareReduceTaskWorker {
527                    prover_data: prover_data.clone(),
528                    artifact_client: artifact_client.clone(),
529                })
530                .collect();
531            let prepare_reduce_engine = Arc::new(AsyncEngine::new(
532                prepare_reduce_workers,
533                config.prepare_reduce_buffer_size,
534            ));
535
536            // Initialize the executor engine.
537            let executor_workers = (0..config.num_recursion_executor_workers)
538                .map(|_| RecursionExecutorWorker {
539                    compress_verifier: compress_verifier.clone(),
540                    prover_data: prover_data.clone(),
541                })
542                .collect();
543
544            let executor_engine = Arc::new(BlockingEngine::new(
545                executor_workers,
546                config.recursion_executor_buffer_size,
547            ));
548
549            // Initialize the prove engine.
550            let prove_workers = (0..config.num_recursion_prover_workers)
551                .map(|_| RecursionProverWorker {
552                    prover_data: prover_data.clone(),
553                    recursion_prover: compress_prover.clone(),
554                    permits: compress_prover_permits.clone(),
555                    artifact_client: artifact_client.clone(),
556                    verify_intermediates: config.verify_intermediates,
557                })
558                .collect();
559            let prove_engine =
560                Arc::new(AsyncEngine::new(prove_workers, config.recursion_prover_buffer_size));
561
562            // Make the recursion pipeline.
563            let recursion_pipeline = Arc::new(Chain::new(executor_engine, prove_engine));
564
565            // Make the reduce pipeline.
566            let reduce_pipeline = Arc::new(Chain::new(prepare_reduce_engine, recursion_pipeline));
567
568            let shrink_prover = Arc::new(ShrinkProver::new(
569                shrink_prover,
570                shrink_prover_permits,
571                prover_data.clone(),
572                config.clone(),
573            ));
574
575            let expected_wrap_vk = bincode::deserialize(WRAP_VK_BYTES).unwrap();
576            let wrap_prover_init = WrapProverInit {
577                wrap_air_prover: wrap_air_prover_init,
578                config: config.clone(),
579                shrink_shape: shrink_prover.shrink_shape.clone(),
580                expected_wrap_vk,
581            };
582
583            Self {
584                reduce_pipeline,
585                shrink_prover,
586                wrap_prover: Arc::new(OnceCell::new()),
587                wrap_prover_init: Arc::new(wrap_prover_init),
588                prover_data,
589                artifact_client,
590            }
591        })
592        .await
593        .unwrap()
594    }
595
596    pub fn recursion_prover_pipeline(&self) -> &Arc<RecursionProvePipeline<A, C>> {
597        self.reduce_pipeline.second()
598    }
599
600    pub async fn submit_prove_shard(
601        &self,
602        program: Arc<RecursionProgram<SP1Field>>,
603        witness: SP1CircuitWitness,
604        output: Artifact,
605        metrics: ProverMetrics,
606    ) -> Result<RecursionProveSubmitHandle<A, C>, SubmitError> {
607        self.recursion_prover_pipeline()
608            .submit(Ok(RecursionTask {
609                program,
610                witness,
611                output,
612                metrics,
613                range_proofs_to_cleanup: None,
614            }))
615            .await
616    }
617
618    pub async fn submit_recursion_reduce(
619        &self,
620        request: RawTaskRequest,
621    ) -> Result<ReduceSubmitHandle<A, C>, TaskError> {
622        let input = ReduceTaskRequest::from_raw(request)?;
623        let handle = self.reduce_pipeline.submit(input).await?;
624        Ok(handle)
625    }
626
627    async fn wrap_prover(&self) -> Result<Arc<WrapProver<C>>, TaskError> {
628        let wrap_prover_init = self.wrap_prover_init.clone();
629        let prover_data = self.prover_data.clone();
630
631        let wrap_prover = self
632            .wrap_prover
633            .get_or_try_init(|| async move {
634                let wrap_prover_init = wrap_prover_init.clone();
635                let prover_data = prover_data.clone();
636                tokio::task::spawn_blocking(move || {
637                    let expected_wrap_vk = wrap_prover_init.expected_wrap_vk.clone();
638                    let wrap_air_prover = wrap_prover_init.wrap_air_prover.build();
639                    let wrap_air_permits = wrap_prover_init.wrap_air_prover.permits();
640                    let wrap_prover = WrapProver::new(
641                        wrap_air_prover,
642                        wrap_air_permits,
643                        prover_data,
644                        wrap_prover_init.config.clone(),
645                        wrap_prover_init.shrink_shape.clone(),
646                    );
647
648                    if wrap_prover.prover_data.recursion_vks.vk_verification()
649                        && wrap_prover.verifying_key != expected_wrap_vk
650                    {
651                        return Err(TaskError::Fatal(anyhow::anyhow!(
652                            "Wrap vk mismatch, expected: {:?}, got: {:?}",
653                            expected_wrap_vk,
654                            wrap_prover.verifying_key
655                        )));
656                    }
657
658                    Ok(Arc::new(wrap_prover))
659                })
660                .await
661                .map_err(|err| TaskError::Fatal(anyhow::anyhow!(err)))?
662            })
663            .await?;
664
665        Ok(wrap_prover.clone())
666    }
667
668    pub async fn run_shrink_wrap(&self, request: RawTaskRequest) -> Result<(), TaskError> {
669        let RawTaskRequest { inputs, outputs, .. } = request;
670        let [compress_proof_artifact] = inputs.try_into().unwrap();
671        let [wrap_proof_artifact] = outputs.try_into().unwrap();
672
673        let compress_proof = self
674            .artifact_client
675            .download(&compress_proof_artifact)
676            .instrument(tracing::debug_span!("download compress proof"))
677            .await?;
678
679        let shrink_proof = self
680            .shrink_prover
681            .prove(compress_proof)
682            .instrument(tracing::info_span!("prove shrink"))
683            .await?;
684
685        tracing::debug_span!("verify shrink proof")
686            .in_scope(|| self.shrink_prover.verify(&shrink_proof))?;
687
688        let wrap_prover = self.wrap_prover().await?;
689        let wrap_proof =
690            wrap_prover.prove(shrink_proof).instrument(tracing::info_span!("prove wrap")).await?;
691
692        tracing::debug_span!("verify wrap proof").in_scope(|| wrap_prover.verify(&wrap_proof))?;
693
694        self.artifact_client
695            .upload(&wrap_proof_artifact, wrap_proof)
696            .instrument(tracing::debug_span!("upload wrap proof"))
697            .await?;
698
699        Ok(())
700    }
701
702    pub async fn run_groth16(&self, request: RawTaskRequest) -> Result<(), TaskError> {
703        let RawTaskRequest { inputs, outputs, .. } = request;
704        let [wrap_proof_artifact] = inputs.try_into().unwrap();
705        let [groth16_proof_artifact] = outputs.try_into().unwrap();
706
707        let wrap_proof: SP1WrapProof<SP1OuterGlobalContext, SP1PcsProofOuter> = self
708            .artifact_client
709            .download(&wrap_proof_artifact)
710            .instrument(tracing::debug_span!("download wrap proof"))
711            .await?;
712
713        let build_dir = try_build_groth16_artifacts_dir(&wrap_proof.vk, &wrap_proof.proof)
714            .await
715            .map_err(TaskError::Fatal)?;
716
717        let groth16_proof = tokio::task::spawn_blocking(move || -> Result<_, anyhow::Error> {
718            let SP1WrapProof { vk, proof } = wrap_proof;
719            let input = SP1ShapedWitnessValues {
720                vks_and_proofs: vec![(vk, proof.clone())],
721                is_complete: true,
722            };
723            let pv: &RecursionPublicValues<SP1Field> = proof.public_values.as_slice().borrow();
724            let vkey_hash = koalabears_to_bn254(&pv.sp1_vk_digest);
725            let committed_values_digest_bytes: [SP1Field; 32] =
726                words_to_bytes(&pv.committed_value_digest).try_into().map_err(|_| {
727                    anyhow::anyhow!(
728                        "committed_value_digest has invalid length, expected exactly 32 elements"
729                    )
730                })?;
731            let committed_values_digest = koalabear_bytes_to_bn254(&committed_values_digest_bytes);
732            let exit_code = Bn254Fr::from_canonical_u32(pv.exit_code.as_canonical_u32());
733            let proof_nonce = koalabears_proof_nonce_to_bn254(&pv.proof_nonce);
734            let vk_root = koalabears_to_bn254(&pv.vk_root);
735            let witness = {
736                let mut witness = OuterWitness::default();
737                input.write(&mut witness);
738                witness.write_committed_values_digest(committed_values_digest);
739                witness.write_vkey_hash(vkey_hash);
740                witness.write_exit_code(exit_code);
741                witness.write_vk_root(vk_root);
742                witness.write_proof_nonce(proof_nonce);
743                witness
744            };
745            let prover = Groth16Bn254Prover::new();
746            let proof = prover.prove(witness, &build_dir);
747            prover
748                .verify(
749                    &proof,
750                    &vkey_hash.as_canonical_biguint(),
751                    &committed_values_digest.as_canonical_biguint(),
752                    &exit_code.as_canonical_biguint(),
753                    &vk_root.as_canonical_biguint(),
754                    &proof_nonce.as_canonical_biguint(),
755                    &build_dir,
756                )
757                .map_err(|e| anyhow::anyhow!("Failed to verify groth16 wrap proof: {}", e))?;
758            Ok(proof)
759        })
760        .instrument(tracing::info_span!("prove groth16"))
761        .await
762        .map_err(|e| TaskError::Fatal(anyhow::anyhow!("Groth16 proof task panicked: {}", e)))?
763        .map_err(TaskError::Fatal)?;
764
765        self.artifact_client
766            .upload(&groth16_proof_artifact, groth16_proof)
767            .instrument(tracing::debug_span!("upload groth16 proof"))
768            .await?;
769        Ok(())
770    }
771
772    pub async fn run_plonk(&self, request: RawTaskRequest) -> Result<(), TaskError> {
773        let RawTaskRequest { inputs, outputs, .. } = request;
774        let [wrap_proof_artifact] = inputs.try_into().unwrap();
775        let [plonk_proof_artifact] = outputs.try_into().unwrap();
776        let wrap_proof: SP1WrapProof<SP1OuterGlobalContext, SP1PcsProofOuter> = self
777            .artifact_client
778            .download(&wrap_proof_artifact)
779            .instrument(tracing::debug_span!("download wrap proof"))
780            .await?;
781
782        let build_dir = try_build_plonk_artifacts_dir(&wrap_proof.vk, &wrap_proof.proof).await?;
783
784        let plonk_proof = tokio::task::spawn_blocking(move || -> Result<_, anyhow::Error> {
785            let SP1WrapProof { vk: wrap_vk, proof: wrap_proof } = wrap_proof;
786            let input = SP1ShapedWitnessValues {
787                vks_and_proofs: vec![(wrap_vk.clone(), wrap_proof.clone())],
788                is_complete: true,
789            };
790            let pv: &RecursionPublicValues<SP1Field> = wrap_proof.public_values.as_slice().borrow();
791            let vkey_hash = koalabears_to_bn254(&pv.sp1_vk_digest);
792            let committed_values_digest_bytes: [SP1Field; 32] =
793                words_to_bytes(&pv.committed_value_digest).try_into().map_err(|_| {
794                    anyhow::anyhow!(
795                        "committed_value_digest has invalid length, expected exactly 32 elements"
796                    )
797                })?;
798            let committed_values_digest = koalabear_bytes_to_bn254(&committed_values_digest_bytes);
799            let exit_code = Bn254Fr::from_canonical_u32(pv.exit_code.as_canonical_u32());
800            let vk_root = koalabears_to_bn254(&pv.vk_root);
801            let proof_nonce = koalabears_proof_nonce_to_bn254(&pv.proof_nonce);
802            let witness = {
803                let mut witness = OuterWitness::default();
804                input.write(&mut witness);
805                witness.write_committed_values_digest(committed_values_digest);
806                witness.write_vkey_hash(vkey_hash);
807                witness.write_exit_code(exit_code);
808                witness.write_vk_root(vk_root);
809                witness.write_proof_nonce(proof_nonce);
810                witness
811            };
812            let prover = PlonkBn254Prover::new();
813            let proof = prover.prove(witness, &build_dir);
814            prover
815                .verify(
816                    &proof,
817                    &vkey_hash.as_canonical_biguint(),
818                    &committed_values_digest.as_canonical_biguint(),
819                    &exit_code.as_canonical_biguint(),
820                    &vk_root.as_canonical_biguint(),
821                    &proof_nonce.as_canonical_biguint(),
822                    &build_dir,
823                )
824                .map_err(|e| anyhow::anyhow!("Failed to verify plonk wrap proof: {}", e))?;
825            Ok(proof)
826        })
827        .instrument(tracing::info_span!("prove plonk"))
828        .await
829        .map_err(|e| TaskError::Fatal(anyhow::anyhow!("Plonk proof task panicked: {}", e)))?
830        .map_err(TaskError::Fatal)?;
831
832        self.artifact_client
833            .upload(&plonk_proof_artifact, plonk_proof)
834            .instrument(tracing::debug_span!("upload plonk proof"))
835            .await?;
836        Ok(())
837    }
838
839    #[inline]
840    #[must_use]
841    pub fn recursion_vk_root(&self) -> [SP1Field; DIGEST_SIZE] {
842        self.prover_data.recursion_vks.root()
843    }
844
845    #[must_use]
846    pub fn vk_verification(&self) -> bool {
847        self.prover_data.vk_verification()
848    }
849
850    #[must_use]
851    pub fn get_normalize_witness(
852        &self,
853        common_input: &CommonProverInput,
854        proof: &ShardProof<SP1GlobalContext, SP1PcsProofInner>,
855        is_complete: bool,
856        is_precompile: bool,
857    ) -> SP1NormalizeWitnessValues<SP1GlobalContext, SP1PcsProofInner> {
858        // Use the final deferred digest from common_input for reconstruct_deferred_digest.
859        // This is needed because:
860        // - For core and global memory shards: deferred_proofs_digest equals
861        //   common_input.deferred_digest and the number of deferred proofs accumulated so far is
862        //   the total number of deferred proofs.
863        // - For precompile shards: they are ordered first in the deferred tree so their number
864        //   of accumulated deferred proofs is 0 and deferred_proofs_digest is the initial digest
865        let (num_deferred_proofs, reconstruct_deferred_digest) = if is_precompile {
866            (SP1Field::zero(), DeferredInputs::initial_deferred_digest())
867        } else {
868            (
869                SP1Field::from_canonical_usize(common_input.num_deferred_proofs),
870                common_input.deferred_digest.map(SP1Field::from_canonical_u32),
871            )
872        };
873        SP1NormalizeWitnessValues {
874            vk: common_input.vk.vk.clone(),
875            shard_proofs: vec![proof.clone()],
876            is_complete,
877            vk_root: self.recursion_vk_root(),
878            reconstruct_deferred_digest,
879            num_deferred_proofs,
880        }
881    }
882
883    pub fn reduce_shape(&self) -> &SP1RecursionProofShape {
884        &self.prover_data.reduce_shape
885    }
886}
887
/// A (proving key, verifying key) pair for a recursion program proved by
/// `C::RecursionProver`. [`RecursionProverData`] stores these for its compose and deferred
/// programs.
type CompressKeys<C> = (
    Arc<ProvingKey<SP1GlobalContext, RecursionSC, <C as SP1ProverComponents>::RecursionProver>>,
    MachineVerifyingKey<SP1GlobalContext>,
);
892
/// Recursion state shared (via `Arc`) by the shrink and wrap provers.
///
/// Holds the recursion verifying-key set, the reduce proof shape, and the compose/deferred
/// programs together with their keys.
pub struct RecursionProverData<C: SP1ProverComponents> {
    // Recursion verifying-key set: Merkle root, openings, and vk verification.
    recursion_vks: RecursionVks,
    // Shape of reduce proofs; `ShrinkProver::new` builds its dummy input from it.
    reduce_shape: SP1RecursionProofShape,
    // Compose programs keyed by a `usize` (presumably the compose arity — TODO confirm).
    compose_programs: BTreeMap<usize, Arc<RecursionProgram<SP1Field>>>,
    // Keys for the compose programs; NOTE(review): assumed to share the keys of
    // `compose_programs` — confirm where this map is populated.
    compose_keys: BTreeMap<usize, CompressKeys<C>>,
    // Program used for deferred proofs.
    deferred_program: Arc<RecursionProgram<SP1Field>>,
    // Keys for `deferred_program`; NOTE(review): the meaning of `None` is not visible here.
    deferred_keys: Option<CompressKeys<C>>,
}
901
902impl<C: SP1ProverComponents> RecursionProverData<C> {
903    pub fn vk_verification(&self) -> bool {
904        self.recursion_vks.vk_verification()
905    }
906
907    pub fn recursion_vks(&self) -> &RecursionVks {
908        &self.recursion_vks
909    }
910
911    pub fn append_merkle_proofs_to_witness(
912        &self,
913        input: SP1ShapedWitnessValues<SP1GlobalContext, SP1PcsProofInner>,
914        merkle_proofs: Vec<MerkleProof<SP1GlobalContext>>,
915    ) -> Result<SP1CompressWithVKeyWitnessValues<SP1PcsProofInner>, TaskError> {
916        let values = if self.recursion_vks.vk_verification() {
917            input.vks_and_proofs.iter().map(|(vk, _)| vk.hash_koalabear()).collect()
918        } else {
919            let num_vks = self.recursion_vks.num_keys();
920            input
921                .vks_and_proofs
922                .iter()
923                .map(|(vk, _)| {
924                    let vk_digest = vk.hash_koalabear();
925                    let index = (vk_digest[0].as_canonical_u32() as usize) % num_vks;
926                    [SP1Field::from_canonical_u32(index as u32); DIGEST_SIZE]
927                })
928                .collect()
929        };
930
931        let merkle_val = SP1MerkleProofWitnessValues {
932            root: self.recursion_vks.root(),
933            values,
934            vk_merkle_proofs: merkle_proofs,
935        };
936
937        Ok(SP1CompressWithVKeyWitnessValues { compress_val: input, merkle_val })
938    }
939
940    pub fn witness_stream(
941        &self,
942        witness: &SP1CircuitWitness,
943    ) -> Result<VecDeque<Block<SP1Field>>, TaskError> {
944        let mut witness_stream = Vec::new();
945        match witness {
946            SP1CircuitWitness::Core(input) => {
947                Witnessable::<InnerConfig>::write(&input, &mut witness_stream);
948            }
949            SP1CircuitWitness::Deferred(input) => {
950                Witnessable::<InnerConfig>::write(&input, &mut witness_stream);
951            }
952            SP1CircuitWitness::Compress(input) => {
953                Witnessable::<InnerConfig>::write(&input, &mut witness_stream);
954            }
955            SP1CircuitWitness::Shrink(input) => {
956                Witnessable::<InnerConfig>::write(&input, &mut witness_stream);
957            }
958            SP1CircuitWitness::Wrap(input) => {
959                Witnessable::<WrapConfig>::write(&input, &mut witness_stream);
960            }
961        }
962        Ok(witness_stream.into())
963    }
964
965    pub fn deferred_program(&self) -> &Arc<RecursionProgram<SP1Field>> {
966        &self.deferred_program
967    }
968}
969
970fn dummy_compose_input<C: SP1ProverComponents>(
971    shape: &SP1RecursionProofShape,
972    arity: usize,
973    height: usize,
974) -> SP1CompressWithVKeyWitnessValues<SP1PcsProofInner> {
975    let verifier = C::compress_verifier();
976    shape.dummy_input(
977        arity,
978        height,
979        verifier.shard_verifier().machine().chips().iter().cloned().collect::<BTreeSet<_>>(),
980        verifier.max_log_row_count(),
981        *verifier.fri_config(),
982        verifier.log_stacking_height() as usize,
983    )
984}
985
/// Re-proves compressed recursion proofs with a fixed shrink program.
pub struct ShrinkProver<C: SP1ProverComponents> {
    // Backend prover used for setup and proving.
    prover: Arc<C::RecursionProver>,
    // Semaphore passed to every setup/prove call on `prover`.
    permits: ProverSemaphore,
    // The shrink program, built once in `new`.
    program: Arc<RecursionProgram<SP1Field>>,
    /// Verifying key of the shrink program.
    pub verifying_key: MachineVerifyingKey<SP1GlobalContext>,
    // Shared recursion data (vk set, shapes, witness serialization).
    prover_data: Arc<RecursionProverData<C>>,
    /// Preprocessed table heights of the shrink program's proving key, computed in `new`.
    pub shrink_shape: BTreeMap<String, usize>,
}
994
995impl<C: SP1ProverComponents> ShrinkProver<C> {
996    fn new(
997        prover: Arc<C::RecursionProver>,
998        permits: ProverSemaphore,
999        prover_data: Arc<RecursionProverData<C>>,
1000        config: SP1RecursionProverConfig,
1001    ) -> Self {
1002        let verifier = C::compress_verifier();
1003        let input = prover_data.reduce_shape.dummy_input(
1004            1,
1005            prover_data.recursion_vks.height(),
1006            verifier.shard_verifier().machine().chips().iter().cloned().collect::<BTreeSet<_>>(),
1007            verifier.max_log_row_count(),
1008            *verifier.fri_config(),
1009            verifier.log_stacking_height() as usize,
1010        );
1011        let program = Arc::new(shrink_program_from_input(
1012            &recursive_verifier(verifier.shard_verifier()),
1013            config.vk_verification,
1014            &input,
1015        ));
1016
1017        let (pk, vk) = {
1018            let (prover, program, permits) = (prover.clone(), program.clone(), permits.clone());
1019            let (tx, rx) = oneshot::channel();
1020            tokio::task::spawn(async move {
1021                tx.send(prover.setup(program.clone(), permits.clone()).await).ok()
1022            });
1023            rx.blocking_recv().unwrap()
1024        };
1025        let shrink_shape = {
1026            let (tx, rx) = oneshot::channel();
1027            tokio::task::spawn(async move {
1028                let heights = <C::RecursionProver as AirProver<
1029                    SP1GlobalContext,_
1030                >>::preprocessed_table_heights(pk.pk)
1031                .await;
1032                tx.send(heights).ok();
1033            });
1034            rx.blocking_recv().unwrap()
1035        };
1036        Self { prover, permits, program, verifying_key: vk, prover_data, shrink_shape }
1037    }
1038
1039    pub(crate) async fn setup(
1040        self: Arc<Self>,
1041        program: Arc<RecursionProgram<SP1Field>>,
1042    ) -> MachineVerifyingKey<SP1GlobalContext> {
1043        self.prover.setup(program, self.permits.clone()).await.1
1044    }
1045
1046    async fn prove(
1047        &self,
1048        compressed_proof: SP1RecursionProof<SP1GlobalContext, SP1PcsProofInner>,
1049    ) -> Result<SP1RecursionProof<SP1GlobalContext, SP1PcsProofInner>, TaskError> {
1050        let execution_record = {
1051            let mut runtime =
1052                Executor::<SP1Field, SP1ExtensionField, _>::new(self.program.clone(), inner_perm());
1053            runtime.witness_stream = self.prover_data.witness_stream(&{
1054                let SP1RecursionProof { vk, proof, vk_merkle_proof } = compressed_proof;
1055                let input =
1056                    SP1ShapedWitnessValues { vks_and_proofs: vec![(vk, proof)], is_complete: true };
1057                SP1CircuitWitness::Shrink(
1058                    self.prover_data
1059                        .append_merkle_proofs_to_witness(input, vec![vk_merkle_proof])?,
1060                )
1061            })?;
1062            runtime.run().map_err(|e| TaskError::Fatal(e.into()))?;
1063            runtime.record
1064        };
1065
1066        let (vk, proof, _permit) = self
1067            .prover
1068            .setup_and_prove_shard(
1069                self.program.clone(),
1070                execution_record,
1071                Some(self.verifying_key.clone()),
1072                self.permits.clone(),
1073            )
1074            .await;
1075        let vk_merkle_proof = self.prover_data.recursion_vks.open(&vk)?.1;
1076        Ok(SP1RecursionProof { vk: self.verifying_key.clone(), proof, vk_merkle_proof })
1077    }
1078
1079    fn verify(
1080        &self,
1081        shrink_proof: &SP1RecursionProof<SP1GlobalContext, SP1PcsProofInner>,
1082    ) -> Result<(), TaskError> {
1083        let SP1RecursionProof { vk, proof, vk_merkle_proof } = shrink_proof;
1084        let mut challenger = SP1GlobalContext::default_challenger();
1085        vk.observe_into(&mut challenger);
1086        C::shrink_verifier()
1087            .verify_shard(vk, proof, &mut challenger)
1088            .map_err(|e| TaskError::Fatal(e.into()))?;
1089
1090        self.prover_data.recursion_vks.verify(vk_merkle_proof, vk)
1091    }
1092}
1093
/// Wraps shrink proofs into outer-context (`SP1OuterGlobalContext`) proofs with a fixed wrap
/// program.
pub struct WrapProver<C: SP1ProverComponents> {
    // Backend wrap prover used for setup and proving.
    prover: Arc<C::WrapProver>,
    // Semaphore passed to every setup/prove call on `prover`.
    permits: ProverSemaphore,
    // The wrap program, built once in `new`.
    program: Arc<RecursionProgram<SP1Field>>,
    /// Verifying key of the wrap program.
    pub verifying_key: MachineVerifyingKey<SP1OuterGlobalContext>,
    // Shared recursion data (vk set, witness serialization).
    prover_data: Arc<RecursionProverData<C>>,
}
1101
1102impl<C: SP1ProverComponents> WrapProver<C> {
1103    pub fn new(
1104        prover: Arc<C::WrapProver>,
1105        permits: ProverSemaphore,
1106        prover_data: Arc<RecursionProverData<C>>,
1107        config: SP1RecursionProverConfig,
1108        shrink_shape: BTreeMap<String, usize>,
1109    ) -> Self {
1110        let verifier = C::shrink_verifier();
1111        let shrink_proof_shape =
1112            SP1RecursionProofShape { shape: RecursionShape::new(shrink_shape) };
1113        let wrap_input = shrink_proof_shape.dummy_input(
1114            1,
1115            prover_data.recursion_vks.height(),
1116            verifier.shard_verifier().machine().chips().iter().cloned().collect::<BTreeSet<_>>(),
1117            verifier.max_log_row_count(),
1118            *verifier.fri_config(),
1119            verifier.log_stacking_height() as usize,
1120        );
1121
1122        let program = Arc::new(wrap_program_from_input(
1123            &recursive_verifier(verifier.shard_verifier()),
1124            config.vk_verification,
1125            &wrap_input,
1126        ));
1127        let (_, verifying_key) = {
1128            let (prover, program, permits) = (prover.clone(), program.clone(), permits.clone());
1129            let (tx, rx) = oneshot::channel();
1130            tokio::task::spawn(async move {
1131                tx.send(prover.setup(program.clone(), permits).await).ok();
1132            });
1133            rx.blocking_recv().unwrap()
1134        };
1135
1136        Self { prover, permits, program, verifying_key, prover_data }
1137    }
1138
1139    pub async fn prove(
1140        &self,
1141        shrunk_proof: SP1RecursionProof<SP1GlobalContext, SP1PcsProofInner>,
1142    ) -> Result<SP1WrapProof<SP1OuterGlobalContext, SP1PcsProofOuter>, TaskError> {
1143        let execution_record = {
1144            let mut runtime =
1145                Executor::<SP1Field, SP1ExtensionField, _>::new(self.program.clone(), inner_perm());
1146            runtime.witness_stream = self.prover_data.witness_stream(&{
1147                let SP1RecursionProof { vk, proof, vk_merkle_proof } = shrunk_proof;
1148                let input =
1149                    SP1ShapedWitnessValues { vks_and_proofs: vec![(vk, proof)], is_complete: true };
1150                SP1CircuitWitness::Wrap(
1151                    self.prover_data
1152                        .append_merkle_proofs_to_witness(input, vec![vk_merkle_proof.clone()])?,
1153                )
1154            })?;
1155            runtime.run().map_err(|e| TaskError::Fatal(e.into()))?;
1156            runtime.record
1157        };
1158
1159        let (_, proof, _permit) = self
1160            .prover
1161            .setup_and_prove_shard(
1162                self.program.clone(),
1163                execution_record,
1164                Some(self.verifying_key.clone()),
1165                self.permits.clone(),
1166            )
1167            .await;
1168
1169        Ok(SP1WrapProof { vk: self.verifying_key.clone(), proof })
1170    }
1171
1172    fn verify(
1173        &self,
1174        wrapped_proof: &SP1WrapProof<SP1OuterGlobalContext, SP1PcsProofOuter>,
1175    ) -> Result<(), TaskError> {
1176        let SP1WrapProof { vk, proof } = wrapped_proof;
1177        let mut challenger = SP1OuterGlobalContext::default_challenger();
1178        vk.observe_into(&mut challenger);
1179        C::wrap_verifier()
1180            .verify_shard(vk, proof, &mut challenger)
1181            .map_err(|e| TaskError::Fatal(e.into()))
1182    }
1183}