1use crate::{
2 build::{try_build_groth16_artifacts_dir, try_build_plonk_artifacts_dir},
3 recursion::{
4 compose_program_from_input, deferred_program_from_input, dummy_deferred_input,
5 recursive_verifier, shrink_program_from_input, wrap_program_from_input, RecursionVks,
6 },
7 shapes::SP1RecursionProofShape,
8 verify::WRAP_VK_BYTES,
9 worker::{
10 CommonProverInput, DeferredInputs, ProverMetrics, RangeProofs, RawTaskRequest, TaskContext,
11 TaskError, TaskMetadata, WrapAirProverInit,
12 },
13 RecursionSC, SP1CircuitWitness, SP1ProverComponents,
14};
15use slop_algebra::PrimeField32;
16use slop_algebra::{AbstractField, PrimeField};
17use slop_bn254::Bn254Fr;
18use slop_challenger::IopCtx;
19use slop_futures::pipeline::{
20 AsyncEngine, AsyncWorker, BlockingEngine, BlockingWorker, Chain, Pipeline, SubmitError,
21 SubmitHandle,
22};
23use sp1_hypercube::{
24 inner_perm, koalabears_to_bn254,
25 prover::{AirProver, ProverSemaphore, ProvingKey},
26 HashableKey, MachineProof, MachineVerifier, MachineVerifyingKey, MerkleProof, SP1PcsProofInner,
27 SP1PcsProofOuter, SP1RecursionProof, SP1WrapProof, ShardProof, DIGEST_SIZE,
28};
29use sp1_primitives::{SP1ExtensionField, SP1Field, SP1GlobalContext, SP1OuterGlobalContext};
30use sp1_prover_types::{Artifact, ArtifactClient, ArtifactId};
31use sp1_recursion_circuit::{
32 machine::{
33 SP1CompressWithVKeyWitnessValues, SP1MerkleProofWitnessValues, SP1NormalizeWitnessValues,
34 SP1ShapedWitnessValues,
35 },
36 utils::{koalabear_bytes_to_bn254, koalabears_proof_nonce_to_bn254, words_to_bytes},
37 witness::{OuterWitness, Witnessable},
38 WrapConfig,
39};
40use sp1_recursion_compiler::config::InnerConfig;
41use sp1_recursion_executor::{
42 shape::RecursionShape, Block, ExecutionRecord, Executor, RecursionProgram,
43 RecursionPublicValues,
44};
45use sp1_recursion_gnark_ffi::{Groth16Bn254Prover, PlonkBn254Prover};
46use std::{
47 borrow::Borrow,
48 collections::{BTreeMap, BTreeSet, VecDeque},
49 sync::Arc,
50};
51use tokio::sync::{oneshot, OnceCell};
52use tracing::Instrument;
53
/// Tuning knobs for [`SP1RecursionProver`]: worker counts and buffer sizes for each
/// pipeline stage, plus recursion-program options.
#[derive(Debug, Clone)]
pub struct SP1RecursionProverConfig {
    /// How many `PrepareReduceTaskWorker`s (download + witness assembly) to spawn.
    pub num_prepare_reduce_workers: usize,
    /// Input buffer capacity handed to the prepare-reduce engine.
    pub prepare_reduce_buffer_size: usize,
    /// How many `RecursionExecutorWorker`s (blocking program execution) to spawn.
    pub num_recursion_executor_workers: usize,
    /// Input buffer capacity handed to the executor engine.
    pub recursion_executor_buffer_size: usize,
    /// How many `RecursionProverWorker`s (shard proving + upload) to spawn.
    pub num_recursion_prover_workers: usize,
    /// Input buffer capacity handed to the prover engine.
    pub recursion_prover_buffer_size: usize,
    /// Largest compose arity; one compose program is built for every arity in
    /// `1..=max_compose_arity`.
    pub max_compose_arity: usize,
    /// Whether recursion programs check verifying keys against the vk Merkle root.
    vk_verification: bool,
    /// When set, every intermediate recursion proof is verified right after proving.
    pub verify_intermediates: bool,
    /// Optional path of a vk map file forwarded to `RecursionVks::new`.
    vk_map_file: Option<String>,
}
80
81impl SP1RecursionProverConfig {
82 #[allow(clippy::too_many_arguments)]
83 pub fn new(
84 num_prepare_reduce_workers: usize,
85 prepare_reduce_buffer_size: usize,
86 num_recursion_executor_workers: usize,
87 recursion_executor_buffer_size: usize,
88 num_recursion_prover_workers: usize,
89 recursion_prover_buffer_size: usize,
90 max_compose_arity: usize,
91 verify_intermediates: bool,
92 ) -> Self {
93 Self {
94 num_prepare_reduce_workers,
95 prepare_reduce_buffer_size,
96 num_recursion_executor_workers,
97 recursion_executor_buffer_size,
98 num_recursion_prover_workers,
99 recursion_prover_buffer_size,
100 max_compose_arity,
101 vk_verification: true,
102 verify_intermediates,
103 vk_map_file: None,
104 }
105 }
106 #[cfg(feature = "experimental")]
107 pub fn without_vk_verification(self) -> Self {
109 Self { vk_verification: false, ..self }
110 }
111
112 #[cfg(feature = "experimental")]
113 pub fn with_vk_map_path(self, vk_map_path: String) -> Self {
115 Self { vk_map_file: Some(vk_map_path), ..self }
116 }
117}
118
119pub struct ReduceTaskRequest {
120 pub range_proofs: RangeProofs,
121 pub is_complete: bool,
122 pub output: Artifact,
123 pub context: TaskContext,
124}
125
126impl ReduceTaskRequest {
127 pub fn from_raw(request: RawTaskRequest) -> Result<Self, TaskError> {
128 let RawTaskRequest { inputs, mut outputs, context } = request;
129 let is_complete = inputs[0].id().parse::<bool>().map_err(|e| TaskError::Fatal(e.into()))?;
130 let range_proofs = RangeProofs::from_artifacts(&inputs[1..])?;
131 let output =
132 outputs.pop().ok_or(TaskError::Fatal(anyhow::anyhow!("No output artifact")))?;
133 Ok(ReduceTaskRequest { range_proofs, is_complete, output, context })
134 }
135
136 pub fn into_raw(self) -> Result<RawTaskRequest, TaskError> {
137 let ReduceTaskRequest { range_proofs, is_complete, output, context } = self;
138 let is_complete_artifact = Artifact::from(is_complete.to_string());
139 let mut inputs = Vec::with_capacity(2 * range_proofs.len() + 2);
140 inputs.push(is_complete_artifact);
141 inputs.extend(range_proofs.as_artifacts());
142 let raw_task_request = RawTaskRequest { inputs, outputs: vec![output], context };
143 Ok(raw_task_request)
144 }
145}
146
147pub struct PrepareReduceTaskWorker<A, C: SP1ProverComponents> {
148 prover_data: Arc<RecursionProverData<C>>,
149 artifact_client: A,
150}
151
152impl<A: ArtifactClient, C: SP1ProverComponents>
153 AsyncWorker<ReduceTaskRequest, Result<RecursionTask, TaskError>>
154 for PrepareReduceTaskWorker<A, C>
155{
156 #[tracing::instrument(level = "trace", name = "prepare_reduce_task", skip(self, input))]
157 async fn call(&self, input: ReduceTaskRequest) -> Result<RecursionTask, TaskError> {
158 let ReduceTaskRequest { range_proofs, is_complete, output, .. } = input;
159
160 let program = self.prover_data.compose_programs.get(&range_proofs.len()).cloned().ok_or(
161 TaskError::Fatal(anyhow::anyhow!(
162 "Compress program not found for arity {}",
163 range_proofs.len()
164 )),
165 )?;
166
167 let witness = range_proofs
168 .download_witness::<C>(is_complete, &self.artifact_client, &self.prover_data)
169 .await?;
170
171 let metrics = ProverMetrics::new();
172 Ok(RecursionTask {
173 program,
174 witness,
175 output,
176 metrics,
177 range_proofs_to_cleanup: Some(range_proofs),
178 })
179 }
180}
181
182pub struct RecursionTask {
183 program: Arc<RecursionProgram<SP1Field>>,
184 witness: SP1CircuitWitness,
185 range_proofs_to_cleanup: Option<RangeProofs>,
186 output: Artifact,
187 metrics: ProverMetrics,
188}
189
190pub struct RecursionExecutorWorker<C: SP1ProverComponents> {
191 compress_verifier: MachineVerifier<SP1GlobalContext, RecursionSC>,
192 prover_data: Arc<RecursionProverData<C>>,
193}
194
195impl<C: SP1ProverComponents>
196 BlockingWorker<Result<RecursionTask, TaskError>, Result<ProveRecursionTask<C>, TaskError>>
197 for RecursionExecutorWorker<C>
198{
199 fn call(
200 &self,
201 input: Result<RecursionTask, TaskError>,
202 ) -> Result<ProveRecursionTask<C>, TaskError> {
203 let RecursionTask { program, witness, output, metrics, range_proofs_to_cleanup } = input?;
204
205 let runtime_span = tracing::debug_span!("execute runtime").entered();
207 let mut runtime =
208 Executor::<SP1Field, SP1ExtensionField, _>::new(program.clone(), inner_perm());
209 runtime.witness_stream = self.prover_data.witness_stream(&witness)?;
210 runtime.run().map_err(|e| TaskError::Fatal(e.into()))?;
211 let mut record = runtime.record;
212 runtime_span.exit();
213
214 tokio::task::spawn_blocking(move || {
215 drop(runtime.memory);
216 drop(runtime.program);
217 drop(runtime.witness_stream);
218 });
219
220 tracing::debug_span!("generate dependencies").in_scope(|| {
222 self.compress_verifier
223 .machine()
224 .generate_dependencies(std::iter::once(&mut record), None)
225 });
226
227 let keys = tracing::debug_span!("get keys").in_scope(|| match witness {
228 SP1CircuitWitness::Core(_) => anyhow::Ok(RecursionKeys::Program(program)),
229 SP1CircuitWitness::Compress(input) => {
230 let arity = input.compress_val.vks_and_proofs.len();
231 let (pk, vk) = self.prover_data.compose_keys.get(&arity).cloned().ok_or(
232 TaskError::Fatal(anyhow::anyhow!("Compose key not found for arity {}", arity)),
233 )?;
234 anyhow::Ok(RecursionKeys::Exists(pk, vk))
235 }
236 SP1CircuitWitness::Deferred(_) => {
237 let keys = self
238 .prover_data
239 .deferred_keys
240 .clone()
241 .map(|(pk, vk)| RecursionKeys::Exists(pk, vk))
242 .unwrap_or_else(|| {
243 RecursionKeys::Program(self.prover_data.deferred_program.clone())
244 });
245 anyhow::Ok(keys)
246 }
247 _ => unimplemented!(),
248 })?;
249
250 Ok(ProveRecursionTask { record, keys, output, metrics, range_proofs_to_cleanup })
251 }
252}
253
254pub type CompressProvingKey<C> =
255 ProvingKey<SP1GlobalContext, RecursionSC, <C as SP1ProverComponents>::RecursionProver>;
256
257enum RecursionKeys<C: SP1ProverComponents> {
258 Exists(Arc<CompressProvingKey<C>>, MachineVerifyingKey<SP1GlobalContext>),
259 Program(Arc<RecursionProgram<SP1Field>>),
260}
261
262pub struct ProveRecursionTask<C: SP1ProverComponents> {
263 record: ExecutionRecord<SP1Field>,
264 keys: RecursionKeys<C>,
265 output: Artifact,
266 metrics: ProverMetrics,
267 range_proofs_to_cleanup: Option<RangeProofs>,
268}
269
270pub struct RecursionProverWorker<A, C: SP1ProverComponents> {
271 recursion_prover: Arc<C::RecursionProver>,
272 permits: ProverSemaphore,
273 artifact_client: A,
274 verify_intermediates: bool,
275 prover_data: Arc<RecursionProverData<C>>,
276}
277
278impl<A: ArtifactClient, C: SP1ProverComponents> RecursionProverWorker<A, C> {
279 async fn prove_shard(
280 &self,
281 keys: RecursionKeys<C>,
282 record: ExecutionRecord<SP1Field>,
283 metrics: ProverMetrics,
284 ) -> Result<SP1RecursionProof<SP1GlobalContext, SP1PcsProofInner>, TaskError> {
285 let proof = match keys {
286 RecursionKeys::Exists(pk, vk) => {
287 let (proof, permit) = self
288 .recursion_prover
289 .prove_shard_with_pk(pk.clone(), record, self.permits.clone())
290 .await;
291 let duration = permit.release();
292 metrics.increment_permit_time(duration);
293
294 if self.verify_intermediates {
295 let proof = proof.clone();
296 let vk = vk.clone();
297 let parent = tracing::Span::current();
298 tokio::task::spawn_blocking(move || {
299 let _guard = parent.enter();
300 C::compress_verifier()
301 .verify(&vk, &MachineProof::from(vec![proof]))
302 .map_err(|e| {
303 TaskError::Retryable(anyhow::anyhow!(
304 "compress verify failed: {}",
305 e
306 ))
307 })
308 })
309 .await
310 .map_err(|e| TaskError::Fatal(e.into()))??;
311 }
312 let vk_merkle_proof = self.prover_data.recursion_vks.open(&vk)?.1;
313 SP1RecursionProof { vk, proof, vk_merkle_proof }
314 }
315 RecursionKeys::Program(program) => {
316 let (vk, proof, permit) = self
317 .recursion_prover
318 .setup_and_prove_shard(program, record, None, self.permits.clone())
319 .await;
320 let duration = permit.release();
321 metrics.increment_permit_time(duration);
322 if self.verify_intermediates {
323 let proof = proof.clone();
324 let vk = vk.clone();
325 let parent = tracing::Span::current();
326 tokio::task::spawn_blocking(move || {
327 let _guard = parent.enter();
328 C::compress_verifier()
329 .verify(&vk, &MachineProof::from(vec![proof.clone()]))
330 .map_err(|e| {
331 TaskError::Retryable(anyhow::anyhow!(
332 "lift/deferred verify failed: {}",
333 e
334 ))
335 })
336 })
337 .await
338 .map_err(|e| TaskError::Fatal(e.into()))??;
339 }
340 let vk_merkle_proof = self.prover_data.recursion_vks.open(&vk)?.1;
341 SP1RecursionProof { vk, proof, vk_merkle_proof }
342 }
343 };
344 Ok(proof)
345 }
346}
347
348impl<A: ArtifactClient, C: SP1ProverComponents>
349 AsyncWorker<Result<ProveRecursionTask<C>, TaskError>, Result<TaskMetadata, TaskError>>
350 for RecursionProverWorker<A, C>
351{
352 async fn call(
353 &self,
354 input: Result<ProveRecursionTask<C>, TaskError>,
355 ) -> Result<TaskMetadata, TaskError> {
356 let ProveRecursionTask { record, keys, output, metrics, range_proofs_to_cleanup } = input?;
358 let proof = self.prove_shard(keys, record, metrics.clone()).await?;
360 self.artifact_client.upload(&output, proof.clone()).await?;
363 let metadata = metrics.to_metadata();
364
365 if let Some(proofs_to_cleanup) = range_proofs_to_cleanup {
367 proofs_to_cleanup.try_delete_proofs(&self.artifact_client).await?;
368 }
369
370 Ok(metadata)
371 }
372}
373
374type ExecutorEngine<C> = Arc<
375 BlockingEngine<
376 Result<RecursionTask, TaskError>,
377 Result<ProveRecursionTask<C>, TaskError>,
378 RecursionExecutorWorker<C>,
379 >,
380>;
381
382type RecursionProverEngine<A, C> = Arc<
383 AsyncEngine<
384 Result<ProveRecursionTask<C>, TaskError>,
385 Result<TaskMetadata, TaskError>,
386 RecursionProverWorker<A, C>,
387 >,
388>;
389
390type PrepareReduceEngine<A, C> = Arc<
391 AsyncEngine<ReduceTaskRequest, Result<RecursionTask, TaskError>, PrepareReduceTaskWorker<A, C>>,
392>;
393
394type RecursionProvePipeline<A, C> = Chain<ExecutorEngine<C>, RecursionProverEngine<A, C>>;
395
396type ReducePipeline<A, C> = Chain<PrepareReduceEngine<A, C>, Arc<RecursionProvePipeline<A, C>>>;
397
398pub type RecursionProveSubmitHandle<A, C> = SubmitHandle<RecursionProvePipeline<A, C>>;
399
400pub type ReduceSubmitHandle<A, C> = SubmitHandle<ReducePipeline<A, C>>;
401
402pub struct SP1RecursionProver<A, C: SP1ProverComponents> {
403 reduce_pipeline: Arc<ReducePipeline<A, C>>,
404 pub shrink_prover: Arc<ShrinkProver<C>>,
405 wrap_prover: Arc<OnceCell<Arc<WrapProver<C>>>>,
406 wrap_prover_init: Arc<WrapProverInit<C>>,
407 pub prover_data: Arc<RecursionProverData<C>>,
408 artifact_client: A,
409}
410
411struct WrapProverInit<C: SP1ProverComponents> {
412 wrap_air_prover: WrapAirProverInit<C>,
413 config: SP1RecursionProverConfig,
414 shrink_shape: BTreeMap<String, usize>,
415 expected_wrap_vk: MachineVerifyingKey<SP1OuterGlobalContext>,
416}
417
418impl<A: Clone, C: SP1ProverComponents> Clone for SP1RecursionProver<A, C> {
419 fn clone(&self) -> Self {
420 Self {
421 reduce_pipeline: self.reduce_pipeline.clone(),
422 shrink_prover: self.shrink_prover.clone(),
423 wrap_prover: self.wrap_prover.clone(),
424 wrap_prover_init: self.wrap_prover_init.clone(),
425 prover_data: self.prover_data.clone(),
426 artifact_client: self.artifact_client.clone(),
427 }
428 }
429}
430
431impl<A: ArtifactClient, C: SP1ProverComponents> SP1RecursionProver<A, C> {
432 pub async fn new(
433 config: SP1RecursionProverConfig,
434 artifact_client: A,
435 (compress_prover, compress_prover_permits): (Arc<C::RecursionProver>, ProverSemaphore),
436 (shrink_prover, shrink_prover_permits): (Arc<C::RecursionProver>, ProverSemaphore),
437 wrap_air_prover_init: WrapAirProverInit<C>,
438 ) -> Self {
439 tokio::task::spawn_blocking(move || {
440 let reduce_shape =
442 SP1RecursionProofShape::compress_proof_shape_from_arity(config.max_compose_arity)
443 .expect("arity not supported");
444
445 let mut compose_programs = BTreeMap::new();
447 let mut compose_keys = BTreeMap::new();
448
449 let vk_map_path = config.vk_map_file.as_ref().map(std::path::PathBuf::from);
450
451 let recursion_vks =
452 RecursionVks::new(vk_map_path, config.max_compose_arity, config.vk_verification);
453
454 let recursion_vks_height = recursion_vks.height();
455
456 let compress_verifier = C::compress_verifier();
457 let recursive_compress_verifier =
458 recursive_verifier::<SP1GlobalContext, _, InnerConfig>(
459 compress_verifier.shard_verifier(),
460 );
461 for arity in 1..=config.max_compose_arity {
462 let dummy_input =
463 dummy_compose_input::<C>(&reduce_shape, arity, recursion_vks_height);
464 let mut program = compose_program_from_input(
465 &recursive_compress_verifier,
466 config.vk_verification,
467 &dummy_input,
468 );
469 program.shape = Some(reduce_shape.shape.clone());
470 let program = Arc::new(program);
471
472 let (tx, rx) = oneshot::channel();
474 tokio::task::spawn({
475 let program = program.clone();
476 let air_prover = compress_prover.clone();
477 async move {
478 let permits = ProverSemaphore::new(1);
479 let (pk, vk) = air_prover.setup(program, permits).await;
480 tx.send((pk, vk)).ok();
481 }
482 });
483 let (pk, vk) = rx.blocking_recv().unwrap();
484 let pk = unsafe { pk.into_inner() };
485 compose_keys.insert(arity, (pk, vk));
486 compose_programs.insert(arity, program);
487 }
488
489 let deferred_input =
491 dummy_deferred_input(&compress_verifier, &reduce_shape, recursion_vks_height);
492 let mut deferred_program = deferred_program_from_input(
493 &recursive_compress_verifier,
494 config.vk_verification,
495 &deferred_input,
496 );
497 deferred_program.shape = Some(reduce_shape.shape.clone());
498 let deferred_program = Arc::new(deferred_program);
499 let (tx, rx) = oneshot::channel();
500 tokio::task::spawn({
501 let program = deferred_program.clone();
502 let air_prover = compress_prover.clone();
503 async move {
504 let permits = ProverSemaphore::new(1);
505 let (pk, vk) = air_prover.setup(program, permits).await;
506 tx.send((pk, vk)).ok();
507 }
508 });
509 let (pk, vk) = rx.blocking_recv().unwrap();
510 let pk = unsafe { pk.into_inner() };
511 let deferred_keys = (pk, vk);
512
513 let prover_data = Arc::new(RecursionProverData {
514 recursion_vks,
515 reduce_shape,
516 compose_programs,
517 compose_keys,
518 deferred_program,
519 deferred_keys: Some(deferred_keys),
520 });
521
522 let compress_verifier = C::compress_verifier();
523
524 let prepare_reduce_workers = (0..config.num_prepare_reduce_workers)
526 .map(|_| PrepareReduceTaskWorker {
527 prover_data: prover_data.clone(),
528 artifact_client: artifact_client.clone(),
529 })
530 .collect();
531 let prepare_reduce_engine = Arc::new(AsyncEngine::new(
532 prepare_reduce_workers,
533 config.prepare_reduce_buffer_size,
534 ));
535
536 let executor_workers = (0..config.num_recursion_executor_workers)
538 .map(|_| RecursionExecutorWorker {
539 compress_verifier: compress_verifier.clone(),
540 prover_data: prover_data.clone(),
541 })
542 .collect();
543
544 let executor_engine = Arc::new(BlockingEngine::new(
545 executor_workers,
546 config.recursion_executor_buffer_size,
547 ));
548
549 let prove_workers = (0..config.num_recursion_prover_workers)
551 .map(|_| RecursionProverWorker {
552 prover_data: prover_data.clone(),
553 recursion_prover: compress_prover.clone(),
554 permits: compress_prover_permits.clone(),
555 artifact_client: artifact_client.clone(),
556 verify_intermediates: config.verify_intermediates,
557 })
558 .collect();
559 let prove_engine =
560 Arc::new(AsyncEngine::new(prove_workers, config.recursion_prover_buffer_size));
561
562 let recursion_pipeline = Arc::new(Chain::new(executor_engine, prove_engine));
564
565 let reduce_pipeline = Arc::new(Chain::new(prepare_reduce_engine, recursion_pipeline));
567
568 let shrink_prover = Arc::new(ShrinkProver::new(
569 shrink_prover,
570 shrink_prover_permits,
571 prover_data.clone(),
572 config.clone(),
573 ));
574
575 let expected_wrap_vk = bincode::deserialize(WRAP_VK_BYTES).unwrap();
576 let wrap_prover_init = WrapProverInit {
577 wrap_air_prover: wrap_air_prover_init,
578 config: config.clone(),
579 shrink_shape: shrink_prover.shrink_shape.clone(),
580 expected_wrap_vk,
581 };
582
583 Self {
584 reduce_pipeline,
585 shrink_prover,
586 wrap_prover: Arc::new(OnceCell::new()),
587 wrap_prover_init: Arc::new(wrap_prover_init),
588 prover_data,
589 artifact_client,
590 }
591 })
592 .await
593 .unwrap()
594 }
595
596 pub fn recursion_prover_pipeline(&self) -> &Arc<RecursionProvePipeline<A, C>> {
597 self.reduce_pipeline.second()
598 }
599
600 pub async fn submit_prove_shard(
601 &self,
602 program: Arc<RecursionProgram<SP1Field>>,
603 witness: SP1CircuitWitness,
604 output: Artifact,
605 metrics: ProverMetrics,
606 ) -> Result<RecursionProveSubmitHandle<A, C>, SubmitError> {
607 self.recursion_prover_pipeline()
608 .submit(Ok(RecursionTask {
609 program,
610 witness,
611 output,
612 metrics,
613 range_proofs_to_cleanup: None,
614 }))
615 .await
616 }
617
618 pub async fn submit_recursion_reduce(
619 &self,
620 request: RawTaskRequest,
621 ) -> Result<ReduceSubmitHandle<A, C>, TaskError> {
622 let input = ReduceTaskRequest::from_raw(request)?;
623 let handle = self.reduce_pipeline.submit(input).await?;
624 Ok(handle)
625 }
626
627 async fn wrap_prover(&self) -> Result<Arc<WrapProver<C>>, TaskError> {
628 let wrap_prover_init = self.wrap_prover_init.clone();
629 let prover_data = self.prover_data.clone();
630
631 let wrap_prover = self
632 .wrap_prover
633 .get_or_try_init(|| async move {
634 let wrap_prover_init = wrap_prover_init.clone();
635 let prover_data = prover_data.clone();
636 tokio::task::spawn_blocking(move || {
637 let expected_wrap_vk = wrap_prover_init.expected_wrap_vk.clone();
638 let wrap_air_prover = wrap_prover_init.wrap_air_prover.build();
639 let wrap_air_permits = wrap_prover_init.wrap_air_prover.permits();
640 let wrap_prover = WrapProver::new(
641 wrap_air_prover,
642 wrap_air_permits,
643 prover_data,
644 wrap_prover_init.config.clone(),
645 wrap_prover_init.shrink_shape.clone(),
646 );
647
648 if wrap_prover.prover_data.recursion_vks.vk_verification()
649 && wrap_prover.verifying_key != expected_wrap_vk
650 {
651 return Err(TaskError::Fatal(anyhow::anyhow!(
652 "Wrap vk mismatch, expected: {:?}, got: {:?}",
653 expected_wrap_vk,
654 wrap_prover.verifying_key
655 )));
656 }
657
658 Ok(Arc::new(wrap_prover))
659 })
660 .await
661 .map_err(|err| TaskError::Fatal(anyhow::anyhow!(err)))?
662 })
663 .await?;
664
665 Ok(wrap_prover.clone())
666 }
667
668 pub async fn run_shrink_wrap(&self, request: RawTaskRequest) -> Result<(), TaskError> {
669 let RawTaskRequest { inputs, outputs, .. } = request;
670 let [compress_proof_artifact] = inputs.try_into().unwrap();
671 let [wrap_proof_artifact] = outputs.try_into().unwrap();
672
673 let compress_proof = self
674 .artifact_client
675 .download(&compress_proof_artifact)
676 .instrument(tracing::debug_span!("download compress proof"))
677 .await?;
678
679 let shrink_proof = self
680 .shrink_prover
681 .prove(compress_proof)
682 .instrument(tracing::info_span!("prove shrink"))
683 .await?;
684
685 tracing::debug_span!("verify shrink proof")
686 .in_scope(|| self.shrink_prover.verify(&shrink_proof))?;
687
688 let wrap_prover = self.wrap_prover().await?;
689 let wrap_proof =
690 wrap_prover.prove(shrink_proof).instrument(tracing::info_span!("prove wrap")).await?;
691
692 tracing::debug_span!("verify wrap proof").in_scope(|| wrap_prover.verify(&wrap_proof))?;
693
694 self.artifact_client
695 .upload(&wrap_proof_artifact, wrap_proof)
696 .instrument(tracing::debug_span!("upload wrap proof"))
697 .await?;
698
699 Ok(())
700 }
701
702 pub async fn run_groth16(&self, request: RawTaskRequest) -> Result<(), TaskError> {
703 let RawTaskRequest { inputs, outputs, .. } = request;
704 let [wrap_proof_artifact] = inputs.try_into().unwrap();
705 let [groth16_proof_artifact] = outputs.try_into().unwrap();
706
707 let wrap_proof: SP1WrapProof<SP1OuterGlobalContext, SP1PcsProofOuter> = self
708 .artifact_client
709 .download(&wrap_proof_artifact)
710 .instrument(tracing::debug_span!("download wrap proof"))
711 .await?;
712
713 let build_dir = try_build_groth16_artifacts_dir(&wrap_proof.vk, &wrap_proof.proof)
714 .await
715 .map_err(TaskError::Fatal)?;
716
717 let groth16_proof = tokio::task::spawn_blocking(move || -> Result<_, anyhow::Error> {
718 let SP1WrapProof { vk, proof } = wrap_proof;
719 let input = SP1ShapedWitnessValues {
720 vks_and_proofs: vec![(vk, proof.clone())],
721 is_complete: true,
722 };
723 let pv: &RecursionPublicValues<SP1Field> = proof.public_values.as_slice().borrow();
724 let vkey_hash = koalabears_to_bn254(&pv.sp1_vk_digest);
725 let committed_values_digest_bytes: [SP1Field; 32] =
726 words_to_bytes(&pv.committed_value_digest).try_into().map_err(|_| {
727 anyhow::anyhow!(
728 "committed_value_digest has invalid length, expected exactly 32 elements"
729 )
730 })?;
731 let committed_values_digest = koalabear_bytes_to_bn254(&committed_values_digest_bytes);
732 let exit_code = Bn254Fr::from_canonical_u32(pv.exit_code.as_canonical_u32());
733 let proof_nonce = koalabears_proof_nonce_to_bn254(&pv.proof_nonce);
734 let vk_root = koalabears_to_bn254(&pv.vk_root);
735 let witness = {
736 let mut witness = OuterWitness::default();
737 input.write(&mut witness);
738 witness.write_committed_values_digest(committed_values_digest);
739 witness.write_vkey_hash(vkey_hash);
740 witness.write_exit_code(exit_code);
741 witness.write_vk_root(vk_root);
742 witness.write_proof_nonce(proof_nonce);
743 witness
744 };
745 let prover = Groth16Bn254Prover::new();
746 let proof = prover.prove(witness, &build_dir);
747 prover
748 .verify(
749 &proof,
750 &vkey_hash.as_canonical_biguint(),
751 &committed_values_digest.as_canonical_biguint(),
752 &exit_code.as_canonical_biguint(),
753 &vk_root.as_canonical_biguint(),
754 &proof_nonce.as_canonical_biguint(),
755 &build_dir,
756 )
757 .map_err(|e| anyhow::anyhow!("Failed to verify groth16 wrap proof: {}", e))?;
758 Ok(proof)
759 })
760 .instrument(tracing::info_span!("prove groth16"))
761 .await
762 .map_err(|e| TaskError::Fatal(anyhow::anyhow!("Groth16 proof task panicked: {}", e)))?
763 .map_err(TaskError::Fatal)?;
764
765 self.artifact_client
766 .upload(&groth16_proof_artifact, groth16_proof)
767 .instrument(tracing::debug_span!("upload groth16 proof"))
768 .await?;
769 Ok(())
770 }
771
772 pub async fn run_plonk(&self, request: RawTaskRequest) -> Result<(), TaskError> {
773 let RawTaskRequest { inputs, outputs, .. } = request;
774 let [wrap_proof_artifact] = inputs.try_into().unwrap();
775 let [plonk_proof_artifact] = outputs.try_into().unwrap();
776 let wrap_proof: SP1WrapProof<SP1OuterGlobalContext, SP1PcsProofOuter> = self
777 .artifact_client
778 .download(&wrap_proof_artifact)
779 .instrument(tracing::debug_span!("download wrap proof"))
780 .await?;
781
782 let build_dir = try_build_plonk_artifacts_dir(&wrap_proof.vk, &wrap_proof.proof).await?;
783
784 let plonk_proof = tokio::task::spawn_blocking(move || -> Result<_, anyhow::Error> {
785 let SP1WrapProof { vk: wrap_vk, proof: wrap_proof } = wrap_proof;
786 let input = SP1ShapedWitnessValues {
787 vks_and_proofs: vec![(wrap_vk.clone(), wrap_proof.clone())],
788 is_complete: true,
789 };
790 let pv: &RecursionPublicValues<SP1Field> = wrap_proof.public_values.as_slice().borrow();
791 let vkey_hash = koalabears_to_bn254(&pv.sp1_vk_digest);
792 let committed_values_digest_bytes: [SP1Field; 32] =
793 words_to_bytes(&pv.committed_value_digest).try_into().map_err(|_| {
794 anyhow::anyhow!(
795 "committed_value_digest has invalid length, expected exactly 32 elements"
796 )
797 })?;
798 let committed_values_digest = koalabear_bytes_to_bn254(&committed_values_digest_bytes);
799 let exit_code = Bn254Fr::from_canonical_u32(pv.exit_code.as_canonical_u32());
800 let vk_root = koalabears_to_bn254(&pv.vk_root);
801 let proof_nonce = koalabears_proof_nonce_to_bn254(&pv.proof_nonce);
802 let witness = {
803 let mut witness = OuterWitness::default();
804 input.write(&mut witness);
805 witness.write_committed_values_digest(committed_values_digest);
806 witness.write_vkey_hash(vkey_hash);
807 witness.write_exit_code(exit_code);
808 witness.write_vk_root(vk_root);
809 witness.write_proof_nonce(proof_nonce);
810 witness
811 };
812 let prover = PlonkBn254Prover::new();
813 let proof = prover.prove(witness, &build_dir);
814 prover
815 .verify(
816 &proof,
817 &vkey_hash.as_canonical_biguint(),
818 &committed_values_digest.as_canonical_biguint(),
819 &exit_code.as_canonical_biguint(),
820 &vk_root.as_canonical_biguint(),
821 &proof_nonce.as_canonical_biguint(),
822 &build_dir,
823 )
824 .map_err(|e| anyhow::anyhow!("Failed to verify plonk wrap proof: {}", e))?;
825 Ok(proof)
826 })
827 .instrument(tracing::info_span!("prove plonk"))
828 .await
829 .map_err(|e| TaskError::Fatal(anyhow::anyhow!("Plonk proof task panicked: {}", e)))?
830 .map_err(TaskError::Fatal)?;
831
832 self.artifact_client
833 .upload(&plonk_proof_artifact, plonk_proof)
834 .instrument(tracing::debug_span!("upload plonk proof"))
835 .await?;
836 Ok(())
837 }
838
839 #[inline]
840 #[must_use]
841 pub fn recursion_vk_root(&self) -> [SP1Field; DIGEST_SIZE] {
842 self.prover_data.recursion_vks.root()
843 }
844
845 #[must_use]
846 pub fn vk_verification(&self) -> bool {
847 self.prover_data.vk_verification()
848 }
849
850 #[must_use]
851 pub fn get_normalize_witness(
852 &self,
853 common_input: &CommonProverInput,
854 proof: &ShardProof<SP1GlobalContext, SP1PcsProofInner>,
855 is_complete: bool,
856 is_precompile: bool,
857 ) -> SP1NormalizeWitnessValues<SP1GlobalContext, SP1PcsProofInner> {
858 let (num_deferred_proofs, reconstruct_deferred_digest) = if is_precompile {
866 (SP1Field::zero(), DeferredInputs::initial_deferred_digest())
867 } else {
868 (
869 SP1Field::from_canonical_usize(common_input.num_deferred_proofs),
870 common_input.deferred_digest.map(SP1Field::from_canonical_u32),
871 )
872 };
873 SP1NormalizeWitnessValues {
874 vk: common_input.vk.vk.clone(),
875 shard_proofs: vec![proof.clone()],
876 is_complete,
877 vk_root: self.recursion_vk_root(),
878 reconstruct_deferred_digest,
879 num_deferred_proofs,
880 }
881 }
882
883 pub fn reduce_shape(&self) -> &SP1RecursionProofShape {
884 &self.prover_data.reduce_shape
885 }
886}
887
888type CompressKeys<C> = (
889 Arc<ProvingKey<SP1GlobalContext, RecursionSC, <C as SP1ProverComponents>::RecursionProver>>,
890 MachineVerifyingKey<SP1GlobalContext>,
891);
892
893pub struct RecursionProverData<C: SP1ProverComponents> {
894 recursion_vks: RecursionVks,
895 reduce_shape: SP1RecursionProofShape,
896 compose_programs: BTreeMap<usize, Arc<RecursionProgram<SP1Field>>>,
897 compose_keys: BTreeMap<usize, CompressKeys<C>>,
898 deferred_program: Arc<RecursionProgram<SP1Field>>,
899 deferred_keys: Option<CompressKeys<C>>,
900}
901
902impl<C: SP1ProverComponents> RecursionProverData<C> {
903 pub fn vk_verification(&self) -> bool {
904 self.recursion_vks.vk_verification()
905 }
906
907 pub fn recursion_vks(&self) -> &RecursionVks {
908 &self.recursion_vks
909 }
910
911 pub fn append_merkle_proofs_to_witness(
912 &self,
913 input: SP1ShapedWitnessValues<SP1GlobalContext, SP1PcsProofInner>,
914 merkle_proofs: Vec<MerkleProof<SP1GlobalContext>>,
915 ) -> Result<SP1CompressWithVKeyWitnessValues<SP1PcsProofInner>, TaskError> {
916 let values = if self.recursion_vks.vk_verification() {
917 input.vks_and_proofs.iter().map(|(vk, _)| vk.hash_koalabear()).collect()
918 } else {
919 let num_vks = self.recursion_vks.num_keys();
920 input
921 .vks_and_proofs
922 .iter()
923 .map(|(vk, _)| {
924 let vk_digest = vk.hash_koalabear();
925 let index = (vk_digest[0].as_canonical_u32() as usize) % num_vks;
926 [SP1Field::from_canonical_u32(index as u32); DIGEST_SIZE]
927 })
928 .collect()
929 };
930
931 let merkle_val = SP1MerkleProofWitnessValues {
932 root: self.recursion_vks.root(),
933 values,
934 vk_merkle_proofs: merkle_proofs,
935 };
936
937 Ok(SP1CompressWithVKeyWitnessValues { compress_val: input, merkle_val })
938 }
939
940 pub fn witness_stream(
941 &self,
942 witness: &SP1CircuitWitness,
943 ) -> Result<VecDeque<Block<SP1Field>>, TaskError> {
944 let mut witness_stream = Vec::new();
945 match witness {
946 SP1CircuitWitness::Core(input) => {
947 Witnessable::<InnerConfig>::write(&input, &mut witness_stream);
948 }
949 SP1CircuitWitness::Deferred(input) => {
950 Witnessable::<InnerConfig>::write(&input, &mut witness_stream);
951 }
952 SP1CircuitWitness::Compress(input) => {
953 Witnessable::<InnerConfig>::write(&input, &mut witness_stream);
954 }
955 SP1CircuitWitness::Shrink(input) => {
956 Witnessable::<InnerConfig>::write(&input, &mut witness_stream);
957 }
958 SP1CircuitWitness::Wrap(input) => {
959 Witnessable::<WrapConfig>::write(&input, &mut witness_stream);
960 }
961 }
962 Ok(witness_stream.into())
963 }
964
965 pub fn deferred_program(&self) -> &Arc<RecursionProgram<SP1Field>> {
966 &self.deferred_program
967 }
968}
969
970fn dummy_compose_input<C: SP1ProverComponents>(
971 shape: &SP1RecursionProofShape,
972 arity: usize,
973 height: usize,
974) -> SP1CompressWithVKeyWitnessValues<SP1PcsProofInner> {
975 let verifier = C::compress_verifier();
976 shape.dummy_input(
977 arity,
978 height,
979 verifier.shard_verifier().machine().chips().iter().cloned().collect::<BTreeSet<_>>(),
980 verifier.max_log_row_count(),
981 *verifier.fri_config(),
982 verifier.log_stacking_height() as usize,
983 )
984}
985
/// Re-proves a compressed recursion proof with the shrink program.
pub struct ShrinkProver<C: SP1ProverComponents> {
    /// Backend prover used for setup and shard proving of the shrink program.
    prover: Arc<C::RecursionProver>,
    /// Permits handed to every `setup` / `setup_and_prove_shard` call on `prover`.
    permits: ProverSemaphore,
    /// Compiled shrink program.
    program: Arc<RecursionProgram<SP1Field>>,
    /// Verifying key of `program`; embedded in every proof returned by `prove`.
    pub verifying_key: MachineVerifyingKey<SP1GlobalContext>,
    /// Shared recursion data: witness serialization and the vk Merkle tree (`recursion_vks`).
    prover_data: Arc<RecursionProverData<C>>,
    /// Preprocessed table heights of `program` (keys presumably chip names — confirm);
    /// `WrapProver::new` accepts a map of this type to shape its dummy input.
    pub shrink_shape: BTreeMap<String, usize>,
}
994
995impl<C: SP1ProverComponents> ShrinkProver<C> {
996 fn new(
997 prover: Arc<C::RecursionProver>,
998 permits: ProverSemaphore,
999 prover_data: Arc<RecursionProverData<C>>,
1000 config: SP1RecursionProverConfig,
1001 ) -> Self {
1002 let verifier = C::compress_verifier();
1003 let input = prover_data.reduce_shape.dummy_input(
1004 1,
1005 prover_data.recursion_vks.height(),
1006 verifier.shard_verifier().machine().chips().iter().cloned().collect::<BTreeSet<_>>(),
1007 verifier.max_log_row_count(),
1008 *verifier.fri_config(),
1009 verifier.log_stacking_height() as usize,
1010 );
1011 let program = Arc::new(shrink_program_from_input(
1012 &recursive_verifier(verifier.shard_verifier()),
1013 config.vk_verification,
1014 &input,
1015 ));
1016
1017 let (pk, vk) = {
1018 let (prover, program, permits) = (prover.clone(), program.clone(), permits.clone());
1019 let (tx, rx) = oneshot::channel();
1020 tokio::task::spawn(async move {
1021 tx.send(prover.setup(program.clone(), permits.clone()).await).ok()
1022 });
1023 rx.blocking_recv().unwrap()
1024 };
1025 let shrink_shape = {
1026 let (tx, rx) = oneshot::channel();
1027 tokio::task::spawn(async move {
1028 let heights = <C::RecursionProver as AirProver<
1029 SP1GlobalContext,_
1030 >>::preprocessed_table_heights(pk.pk)
1031 .await;
1032 tx.send(heights).ok();
1033 });
1034 rx.blocking_recv().unwrap()
1035 };
1036 Self { prover, permits, program, verifying_key: vk, prover_data, shrink_shape }
1037 }
1038
1039 pub(crate) async fn setup(
1040 self: Arc<Self>,
1041 program: Arc<RecursionProgram<SP1Field>>,
1042 ) -> MachineVerifyingKey<SP1GlobalContext> {
1043 self.prover.setup(program, self.permits.clone()).await.1
1044 }
1045
1046 async fn prove(
1047 &self,
1048 compressed_proof: SP1RecursionProof<SP1GlobalContext, SP1PcsProofInner>,
1049 ) -> Result<SP1RecursionProof<SP1GlobalContext, SP1PcsProofInner>, TaskError> {
1050 let execution_record = {
1051 let mut runtime =
1052 Executor::<SP1Field, SP1ExtensionField, _>::new(self.program.clone(), inner_perm());
1053 runtime.witness_stream = self.prover_data.witness_stream(&{
1054 let SP1RecursionProof { vk, proof, vk_merkle_proof } = compressed_proof;
1055 let input =
1056 SP1ShapedWitnessValues { vks_and_proofs: vec![(vk, proof)], is_complete: true };
1057 SP1CircuitWitness::Shrink(
1058 self.prover_data
1059 .append_merkle_proofs_to_witness(input, vec![vk_merkle_proof])?,
1060 )
1061 })?;
1062 runtime.run().map_err(|e| TaskError::Fatal(e.into()))?;
1063 runtime.record
1064 };
1065
1066 let (vk, proof, _permit) = self
1067 .prover
1068 .setup_and_prove_shard(
1069 self.program.clone(),
1070 execution_record,
1071 Some(self.verifying_key.clone()),
1072 self.permits.clone(),
1073 )
1074 .await;
1075 let vk_merkle_proof = self.prover_data.recursion_vks.open(&vk)?.1;
1076 Ok(SP1RecursionProof { vk: self.verifying_key.clone(), proof, vk_merkle_proof })
1077 }
1078
1079 fn verify(
1080 &self,
1081 shrink_proof: &SP1RecursionProof<SP1GlobalContext, SP1PcsProofInner>,
1082 ) -> Result<(), TaskError> {
1083 let SP1RecursionProof { vk, proof, vk_merkle_proof } = shrink_proof;
1084 let mut challenger = SP1GlobalContext::default_challenger();
1085 vk.observe_into(&mut challenger);
1086 C::shrink_verifier()
1087 .verify_shard(vk, proof, &mut challenger)
1088 .map_err(|e| TaskError::Fatal(e.into()))?;
1089
1090 self.prover_data.recursion_vks.verify(vk_merkle_proof, vk)
1091 }
1092}
1093
/// Wraps shrink proofs into proofs over `SP1OuterGlobalContext` using the wrap program.
pub struct WrapProver<C: SP1ProverComponents> {
    /// Backend prover used for setup and shard proving of the wrap program.
    prover: Arc<C::WrapProver>,
    /// Permits handed to every `setup` / `setup_and_prove_shard` call on `prover`.
    permits: ProverSemaphore,
    /// Compiled wrap program, built for one fixed shrink-proof shape.
    program: Arc<RecursionProgram<SP1Field>>,
    /// Verifying key of `program`; embedded in every `SP1WrapProof` returned by `prove`.
    pub verifying_key: MachineVerifyingKey<SP1OuterGlobalContext>,
    /// Shared recursion data: witness serialization and the vk Merkle tree.
    prover_data: Arc<RecursionProverData<C>>,
}
1101
1102impl<C: SP1ProverComponents> WrapProver<C> {
1103 pub fn new(
1104 prover: Arc<C::WrapProver>,
1105 permits: ProverSemaphore,
1106 prover_data: Arc<RecursionProverData<C>>,
1107 config: SP1RecursionProverConfig,
1108 shrink_shape: BTreeMap<String, usize>,
1109 ) -> Self {
1110 let verifier = C::shrink_verifier();
1111 let shrink_proof_shape =
1112 SP1RecursionProofShape { shape: RecursionShape::new(shrink_shape) };
1113 let wrap_input = shrink_proof_shape.dummy_input(
1114 1,
1115 prover_data.recursion_vks.height(),
1116 verifier.shard_verifier().machine().chips().iter().cloned().collect::<BTreeSet<_>>(),
1117 verifier.max_log_row_count(),
1118 *verifier.fri_config(),
1119 verifier.log_stacking_height() as usize,
1120 );
1121
1122 let program = Arc::new(wrap_program_from_input(
1123 &recursive_verifier(verifier.shard_verifier()),
1124 config.vk_verification,
1125 &wrap_input,
1126 ));
1127 let (_, verifying_key) = {
1128 let (prover, program, permits) = (prover.clone(), program.clone(), permits.clone());
1129 let (tx, rx) = oneshot::channel();
1130 tokio::task::spawn(async move {
1131 tx.send(prover.setup(program.clone(), permits).await).ok();
1132 });
1133 rx.blocking_recv().unwrap()
1134 };
1135
1136 Self { prover, permits, program, verifying_key, prover_data }
1137 }
1138
1139 pub async fn prove(
1140 &self,
1141 shrunk_proof: SP1RecursionProof<SP1GlobalContext, SP1PcsProofInner>,
1142 ) -> Result<SP1WrapProof<SP1OuterGlobalContext, SP1PcsProofOuter>, TaskError> {
1143 let execution_record = {
1144 let mut runtime =
1145 Executor::<SP1Field, SP1ExtensionField, _>::new(self.program.clone(), inner_perm());
1146 runtime.witness_stream = self.prover_data.witness_stream(&{
1147 let SP1RecursionProof { vk, proof, vk_merkle_proof } = shrunk_proof;
1148 let input =
1149 SP1ShapedWitnessValues { vks_and_proofs: vec![(vk, proof)], is_complete: true };
1150 SP1CircuitWitness::Wrap(
1151 self.prover_data
1152 .append_merkle_proofs_to_witness(input, vec![vk_merkle_proof.clone()])?,
1153 )
1154 })?;
1155 runtime.run().map_err(|e| TaskError::Fatal(e.into()))?;
1156 runtime.record
1157 };
1158
1159 let (_, proof, _permit) = self
1160 .prover
1161 .setup_and_prove_shard(
1162 self.program.clone(),
1163 execution_record,
1164 Some(self.verifying_key.clone()),
1165 self.permits.clone(),
1166 )
1167 .await;
1168
1169 Ok(SP1WrapProof { vk: self.verifying_key.clone(), proof })
1170 }
1171
1172 fn verify(
1173 &self,
1174 wrapped_proof: &SP1WrapProof<SP1OuterGlobalContext, SP1PcsProofOuter>,
1175 ) -> Result<(), TaskError> {
1176 let SP1WrapProof { vk, proof } = wrapped_proof;
1177 let mut challenger = SP1OuterGlobalContext::default_challenger();
1178 vk.observe_into(&mut challenger);
1179 C::wrap_verifier()
1180 .verify_shard(vk, proof, &mut challenger)
1181 .map_err(|e| TaskError::Fatal(e.into()))
1182 }
1183}