// sp1_prover/worker/prover/deferred.rs

1use std::sync::Arc;
2
3use slop_algebra::AbstractField;
4use slop_futures::pipeline::{AsyncEngine, AsyncWorker, Pipeline, SubmitHandle};
5
6use sp1_hypercube::HashableKey;
7use sp1_primitives::SP1Field;
8use sp1_prover_types::{Artifact, ArtifactClient};
9use sp1_recursion_circuit::machine::SP1DeferredWitnessValues;
10
11use crate::{
12    worker::{
13        CommonProverInput, ProverMetrics, RawTaskRequest, SP1DeferredData, SP1RecursionProver,
14        TaskContext, TaskError, TaskMetadata,
15    },
16    SP1CircuitWitness, SP1ProverComponents,
17};
18
/// Configuration for the deferred-proof prover.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct SP1DeferredProverConfig {
    /// The number of deferred workers; one worker is created per unit.
    pub num_deferred_workers: usize,
    /// The buffer size handed to the deferred engine alongside the workers.
    pub deferred_buffer_size: usize,
}
26
27pub type SP1DeferredEngine<A, C> = AsyncEngine<
28    RecursionDeferredTaskRequest,
29    Result<TaskMetadata, TaskError>,
30    SP1DeferredWorker<A, C>,
31>;
32
33pub type SP1DeferredSubmitHandle<A, C> = SubmitHandle<SP1DeferredEngine<A, C>>;
34
35pub struct SP1DeferredProver<A, C: SP1ProverComponents> {
36    engine: Arc<SP1DeferredEngine<A, C>>,
37}
38
39impl<A, C: SP1ProverComponents> Clone for SP1DeferredProver<A, C> {
40    fn clone(&self) -> Self {
41        Self { engine: self.engine.clone() }
42    }
43}
44
45impl<A: ArtifactClient, C: SP1ProverComponents> SP1DeferredProver<A, C> {
46    pub fn new(
47        config: SP1DeferredProverConfig,
48        recursion_prover: SP1RecursionProver<A, C>,
49        artifact_client: A,
50    ) -> Self {
51        let deferred_workers = (0..config.num_deferred_workers)
52            .map(|_| SP1DeferredWorker {
53                recursion_prover: recursion_prover.clone(),
54                artifact_client: artifact_client.clone(),
55            })
56            .collect();
57        let engine = AsyncEngine::new(deferred_workers, config.deferred_buffer_size);
58        Self { engine: Arc::new(engine) }
59    }
60
61    pub(super) async fn submit(
62        &self,
63        task: RawTaskRequest,
64    ) -> Result<SP1DeferredSubmitHandle<A, C>, TaskError> {
65        let task = RecursionDeferredTaskRequest::from_raw(task)?;
66        let handle = self.engine.submit(task).await?;
67        Ok(handle)
68    }
69}
70
71pub struct SP1DeferredWorker<A, C: SP1ProverComponents> {
72    recursion_prover: SP1RecursionProver<A, C>,
73    artifact_client: A,
74}
75
76pub struct RecursionDeferredTaskRequest {
77    /// The common input artifact.
78    pub common_input: Artifact,
79    /// The deferred data artifact.
80    pub deferred_data: Artifact,
81    // The output artifact.
82    pub output: Artifact,
83    /// The task context.
84    pub context: TaskContext,
85}
86
87impl RecursionDeferredTaskRequest {
88    pub fn from_raw(request: RawTaskRequest) -> Result<Self, TaskError> {
89        let RawTaskRequest { inputs, mut outputs, context } = request;
90        let [common_input, deferred_data] = inputs
91            .try_into()
92            .map_err(|_| TaskError::Fatal(anyhow::anyhow!("Invalid input length")))?;
93        let output =
94            outputs.pop().ok_or(TaskError::Fatal(anyhow::anyhow!("No output artifact")))?;
95
96        Ok(RecursionDeferredTaskRequest { common_input, deferred_data, output, context })
97    }
98
99    pub fn into_raw(self) -> Result<RawTaskRequest, TaskError> {
100        let RecursionDeferredTaskRequest { common_input, deferred_data, output, context } = self;
101
102        let inputs = vec![common_input, deferred_data];
103        let raw_task_request = RawTaskRequest { inputs, outputs: vec![output], context };
104        Ok(raw_task_request)
105    }
106}
107
108impl<A: ArtifactClient, C: SP1ProverComponents>
109    AsyncWorker<RecursionDeferredTaskRequest, Result<TaskMetadata, TaskError>>
110    for SP1DeferredWorker<A, C>
111{
112    async fn call(&self, input: RecursionDeferredTaskRequest) -> Result<TaskMetadata, TaskError> {
113        let RecursionDeferredTaskRequest { common_input, deferred_data, output, .. } = input;
114
115        // Download the inputs
116        let (common_input, deferred_data) = tokio::try_join!(
117            self.artifact_client.download::<CommonProverInput>(&common_input),
118            self.artifact_client.download::<SP1DeferredData>(&deferred_data),
119        )?;
120
121        let SP1DeferredData {
122            input,
123            start_reconstruct_deferred_digest,
124            deferred_proof_index,
125            vk_merkle_proofs,
126        } = deferred_data;
127
128        let input = self
129            .recursion_prover
130            .prover_data
131            .append_merkle_proofs_to_witness(input, vk_merkle_proofs)?;
132
133        let nonce = common_input.nonce.map(SP1Field::from_canonical_u32);
134
135        let witness = SP1DeferredWitnessValues {
136            vks_and_proofs: input.compress_val.vks_and_proofs,
137            vk_merkle_data: input.merkle_val,
138            start_reconstruct_deferred_digest,
139            sp1_vk_digest: common_input.vk.hash_koalabear(),
140            end_pc: common_input.vk.vk.pc_start,
141            proof_nonce: nonce,
142            deferred_proof_index,
143        };
144
145        let witness = SP1CircuitWitness::Deferred(witness);
146
147        let program = self.recursion_prover.prover_data.deferred_program().clone();
148
149        // Get the deferred proof
150        let metrics = ProverMetrics::new();
151        let metadata = self
152            .recursion_prover
153            .submit_prove_shard(program, witness, output, metrics)
154            .await?
155            .await
156            .map_err(|e| TaskError::Fatal(e.into()))??;
157
158        Ok(metadata)
159    }
160}