Skip to main content

sp1_prover/worker/controller/
mod.rs

1mod compress;
2mod core;
3mod deferred;
4mod global;
5mod precompiles;
6mod splicing;
7mod vk_tree;
8
9pub use compress::*;
10pub use core::*;
11pub use deferred::*;
12pub use global::*;
13pub use precompiles::*;
14pub use splicing::*;
15pub use vk_tree::*;
16
17use lru::LruCache;
18
19use slop_algebra::PrimeField32;
20
21use sp1_core_executor::SP1CoreOpts;
22use sp1_core_executor_runner::MinimalExecutorRunner;
23use sp1_core_machine::{executor::ExecutionOutput, io::SP1Stdin};
24use sp1_hypercube::{
25    air::{PublicValues, PROOF_NONCE_NUM_WORDS},
26    SP1PcsProofInner, SP1VerifyingKey, ShardProof,
27};
28use sp1_primitives::{io::SP1PublicValues, SP1GlobalContext};
29use sp1_prover_types::{
30    network_base_types::ProofMode, Artifact, ArtifactClient, ArtifactType, TaskStatus, TaskType,
31};
32use sp1_verifier::{ProofFromNetwork, SP1Proof};
33use std::{borrow::Borrow, sync::Arc};
34use tokio::{
35    sync::{oneshot, Mutex, MutexGuard},
36    task::JoinSet,
37};
38use tracing::Instrument;
39
40use crate::{
41    verify::SP1Verifier,
42    worker::{MessageReceiver, RawTaskRequest, TaskContext, TaskError, TaskId, WorkerClient},
43    SP1_CIRCUIT_VERSION,
44};
45
46#[derive(Clone)]
47pub struct MinimalExecutorCache(Arc<Mutex<Option<MinimalExecutorRunner>>>);
48
49impl MinimalExecutorCache {
50    pub fn empty() -> Self {
51        Self(Arc::new(Mutex::new(None)))
52    }
53
54    pub async fn lock(&self) -> MutexGuard<'_, Option<MinimalExecutorRunner>> {
55        self.0.lock().await
56    }
57}
58
/// Settings consumed by `SP1Controller`.
#[derive(Clone)]
pub struct SP1ControllerConfig {
    /// Core executor options; a clone is handed to every `SP1CoreExecutor` the controller builds.
    pub opts: SP1CoreOpts,
    /// Number of `SplicingWorker`s created for each splicing engine.
    pub num_splicing_workers: usize,
    /// Buffer size passed to `SplicingEngine::new`.
    pub splicing_buffer_size: usize,
    /// Arity passed to `CompressTree::new` when reducing proofs.
    pub max_reduce_arity: usize,
    /// Passed to `SplicingWorker::new`. Presumably the number of send-splice workers each
    /// splice uses — confirm against `SplicingWorker`.
    pub number_of_send_splice_workers_per_splice: usize,
    /// Passed to `SplicingWorker::new`. Presumably the per-splice input buffer size of the
    /// send-splice workers — confirm against `SplicingWorker`.
    pub send_splice_input_buffer_size_per_splice: usize,
    /// When `true`, the controller creates a `MinimalExecutorCache` and hands it to executors.
    pub use_fixed_pk: bool,
    /// Buffer size passed to `SP1CoreExecutor::new`.
    pub global_memory_buffer_size: usize,
}
70
/// Drives proof requests: runs the core executor and submits/awaits the follow-up tasks.
///
/// Generic over the artifact client `A` and the worker client `W`.
pub struct SP1Controller<A, W> {
    /// Settings supplied at construction.
    config: SP1ControllerConfig,
    /// LRU cache of verifying keys, keyed by ELF artifact; shared with spawned setup tasks.
    setup_cache: Arc<Mutex<LruCache<Artifact, SP1VerifyingKey>>>,
    /// Client used to create, upload, download and delete artifacts.
    pub(crate) artifact_client: A,
    /// Client used to submit tasks and to subscribe to their status and messages.
    pub(crate) worker_client: W,
    /// Verifier supplied at construction; not used by the code visible in this file.
    pub(crate) verifier: SP1Verifier,
    /// `Some` only when `config.use_fixed_pk` was set; passed to each `SP1CoreExecutor`.
    minimal_executor_cache: Option<MinimalExecutorCache>,
}
79
80impl<A, W> SP1Controller<A, W>
81where
82    A: ArtifactClient,
83    W: WorkerClient,
84{
85    pub fn new(
86        config: SP1ControllerConfig,
87        artifact_client: A,
88        worker_client: W,
89        verifier: SP1Verifier,
90    ) -> Self {
91        let minimal_executor_cache =
92            if config.use_fixed_pk { Some(MinimalExecutorCache::empty()) } else { None };
93
94        Self {
95            config,
96            setup_cache: Arc::new(Mutex::new(LruCache::new(20.try_into().unwrap()))),
97            artifact_client,
98            worker_client,
99            verifier,
100            minimal_executor_cache,
101        }
102    }
103
    /// Returns the core executor options stored in the controller configuration.
    #[inline]
    pub const fn opts(&self) -> &SP1CoreOpts {
        &self.config.opts
    }
108
    /// Returns the configured maximum reduce arity (used to build the `CompressTree`).
    #[inline]
    pub const fn max_reduce_arity(&self) -> usize {
        self.config.max_reduce_arity
    }
113
    /// Returns the configured global memory buffer size (passed to `SP1CoreExecutor::new`).
    #[inline]
    pub const fn global_memory_buffer_size(&self) -> usize {
        self.config.global_memory_buffer_size
    }
118
119    pub fn initialize_splicing_engine(&self) -> Arc<SplicingEngine<A, W>> {
120        let splicing_workers = (0..self.config.num_splicing_workers)
121            .map(|_| {
122                SplicingWorker::new(
123                    self.artifact_client.clone(),
124                    self.worker_client.clone(),
125                    self.config.number_of_send_splice_workers_per_splice,
126                    self.config.send_splice_input_buffer_size_per_splice,
127                )
128            })
129            .collect();
130        Arc::new(SplicingEngine::new(splicing_workers, self.config.splicing_buffer_size))
131    }
132
133    /// Execute Risc-V program, and trigger shard proofs for each trace chunk.
134    /// Run the core executor and deferred proof emitter for a `CoreExecute` task. Proof shards
135    /// are streamed back to the consumer via the task's message channel.
136    pub async fn execute(
137        &self,
138        task_id: TaskId,
139        request: CoreExecuteTaskRequest,
140    ) -> Result<ExecutionOutput, TaskError> {
141        let stdin = self.artifact_client.download_stdin::<SP1Stdin>(&request.stdin).await?;
142
143        let deferred_proofs = stdin.proofs.iter().map(|(proof, _)| proof.clone());
144        let deferred_inputs = DeferredInputs::new(deferred_proofs);
145
146        let splicing_engine = self.initialize_splicing_engine();
147        let proof_data_sender =
148            MessageSender::<W, ProofData>::new(self.worker_client.clone(), task_id);
149        let executor = SP1CoreExecutor::new(
150            splicing_engine,
151            self.global_memory_buffer_size(),
152            request.elf,
153            Arc::new(stdin),
154            request.common_input.clone(),
155            self.opts().clone(),
156            request.num_deferred_proofs,
157            request.context.clone(),
158            proof_data_sender.clone(),
159            self.artifact_client.clone(),
160            self.worker_client.clone(),
161            self.minimal_executor_cache.clone(),
162            request.cycle_limit,
163        );
164
165        let mut join_set = JoinSet::<Result<(), TaskError>>::new();
166
167        // Spawn the deferred proof emitter.
168        {
169            let deferred_sender = proof_data_sender.clone();
170            let artifact_client = self.artifact_client.clone();
171            let worker_client = self.worker_client.clone();
172            let common_input_artifact = request.common_input.clone();
173            let context = request.context.clone();
174            join_set.spawn(async move {
175                deferred_inputs
176                    .emit_deferred_tasks(
177                        common_input_artifact,
178                        context,
179                        deferred_sender,
180                        artifact_client,
181                        worker_client,
182                    )
183                    .await
184            });
185        }
186
187        // Run the executor inline (not spawned — it uses self's executor cache).
188        let output = executor.execute().await;
189
190        // Wait for the deferred emitter to finish.
191        while let Some(result) = join_set.join_next().await {
192            result.map_err(|e| TaskError::Fatal(e.into()))??;
193        }
194
195        let output = output?;
196        if let Some(limit) = request.cycle_limit {
197            if limit > 0 && output.cycles > limit {
198                return Err(TaskError::Fatal(anyhow::anyhow!(
199                    "cycle limit exceeded: {} > {}",
200                    output.cycles,
201                    limit
202                )));
203            }
204        }
205        self.artifact_client.upload(&request.execution_output, &output).await?;
206        Ok(output)
207    }
208
209    pub async fn run(&self, request: RawTaskRequest) -> Result<ExecutionOutput, TaskError> {
210        let RawTaskRequest { inputs, outputs, context } = request;
211        let elf = inputs[0].clone();
212        let stdin_artifact = inputs[1].clone();
213        let mode_artifact = inputs[2].clone();
214        let cycle_limit = inputs.get(3).and_then(|a| a.clone().to_id().parse::<u64>().ok());
215        let proof_nonce = inputs.get(4);
216        let [output] = outputs.try_into().unwrap();
217        let mode = {
218            let parsed =
219                mode_artifact.to_id().parse::<i32>().map_err(|e| TaskError::Fatal(e.into()))?;
220            ProofMode::try_from(parsed).map_err(|e| TaskError::Fatal(e.into()))?
221        };
222
223        let stdin_download_handle =
224            self.artifact_client.download_stdin::<SP1Stdin>(&stdin_artifact);
225
226        let proof_nonce = match proof_nonce {
227            Some(artifact) => {
228                self.artifact_client.download::<[u32; PROOF_NONCE_NUM_WORDS]>(artifact).await?
229            }
230            None => [0u32; PROOF_NONCE_NUM_WORDS],
231        };
232
233        let vkey_download_handle = tokio::spawn({
234            let artifact_client_clone = self.artifact_client.clone();
235            let worker_client_clone = self.worker_client.clone();
236            let elf_clone = elf.clone();
237            let setup_cache = self.setup_cache.clone();
238            let context = context.clone();
239            async move {
240                let mut lock = setup_cache.lock().await;
241                let vkey = lock.get(&elf_clone).cloned();
242                drop(lock);
243                let vk = if let Some(vkey) = vkey {
244                    tracing::debug!("setup cache hit");
245                    vkey.clone()
246                } else {
247                    let vk_artifact = artifact_client_clone.create_artifact()?;
248                    let setup_request = RawTaskRequest {
249                        inputs: vec![elf_clone.clone()],
250                        outputs: vec![vk_artifact.clone()],
251                        context: context.clone(),
252                    };
253
254                    tracing::debug!("submitting setup task");
255                    let setup_id =
256                        worker_client_clone.submit_task(TaskType::SetupVkey, setup_request).await?;
257
258                    let subscriber =
259                        worker_client_clone.subscriber(context.proof_id.clone()).await?.per_task();
260                    let status = subscriber
261                        .wait_task(setup_id)
262                        .instrument(tracing::debug_span!("setup task"))
263                        .await
264                        .map_err(|e| TaskError::Fatal(e.into()))?;
265                    if status != TaskStatus::Succeeded {
266                        return Err(TaskError::Fatal(anyhow::anyhow!("setup task failed")));
267                    }
268                    tracing::debug!("setup task succeeded");
269                    let vk =
270                        artifact_client_clone.download::<SP1VerifyingKey>(&vk_artifact).await?;
271                    setup_cache.lock().await.put(elf_clone, vk.clone());
272                    vk
273                };
274                Ok(vk)
275            }
276            .instrument(tracing::debug_span!("setup vkey"))
277        });
278
279        let stdin: SP1Stdin = stdin_download_handle.await?;
280        let vk = vkey_download_handle.await.map_err(|e| TaskError::Fatal(e.into()))??;
281
282        let stdin = Arc::new(stdin);
283
284        let deferred_proofs = stdin.proofs.iter().map(|(proof, _)| proof.clone());
285        let deferred_inputs = DeferredInputs::new(deferred_proofs);
286
287        let num_deferred_proofs = deferred_inputs.num_deferred_proofs();
288        let deferred_digest = deferred_inputs.deferred_digest().map(|x| x.as_canonical_u32());
289        let common_input = CommonProverInput {
290            vk,
291            mode,
292            deferred_digest,
293            num_deferred_proofs,
294            nonce: proof_nonce,
295        };
296        let common_input_artifact = self.artifact_client.create_artifact()?;
297        self.artifact_client.upload(&common_input_artifact.clone(), common_input.clone()).await?;
298
299        // Submit the executor as a CoreExecute task
300        let execution_output_artifact = self.artifact_client.create_artifact()?;
301        let executor_request = CoreExecuteTaskRequest {
302            elf: elf.clone(),
303            stdin: stdin_artifact.clone(),
304            common_input: common_input_artifact.clone(),
305            execution_output: execution_output_artifact.clone(),
306            num_deferred_proofs,
307            cycle_limit,
308            context: context.clone(),
309        };
310        let executor_task_id = self
311            .worker_client
312            .submit_task(TaskType::CoreExecute, executor_request.into_raw()?)
313            .await?;
314
315        let core_proof_rx = MessageReceiver::<ProofData>::new(
316            self.worker_client.subscribe_task_messages(&executor_task_id).await?,
317        );
318
319        let mut join_set = JoinSet::<Result<(), TaskError>>::new();
320
321        let mut core_proof_artifact = None;
322        let mut compress_proof_artifact = None;
323        let mut shrinkwrap_proof_artifact = None;
324        let mut groth16_proof_artifact = None;
325        let mut plonk_proof_artifact = None;
326
327        let (compress_complete_tx, compress_complete_rx) = oneshot::channel();
328
329        if mode == ProofMode::Core {
330            core_proof_artifact = Some(self.artifact_client.create_artifact()?);
331            join_set.spawn(collect_core_proofs(
332                self.worker_client.clone(),
333                self.artifact_client.clone(),
334                core_proof_artifact.clone().unwrap(),
335                context.clone(),
336                core_proof_rx,
337            ));
338        } else {
339            let mut tree = CompressTree::new(self.max_reduce_arity());
340            let artifact_client = self.artifact_client.clone();
341            let worker_client = self.worker_client.clone();
342            let context = context.clone();
343            compress_proof_artifact = Some(self.artifact_client.create_artifact()?);
344            let compress_proof_artifact = compress_proof_artifact.clone().unwrap();
345            join_set.spawn(
346                async move {
347                    tree.reduce_proofs(
348                        context,
349                        compress_proof_artifact.clone(),
350                        core_proof_rx,
351                        &artifact_client,
352                        &worker_client,
353                    )
354                    .await?;
355                    compress_complete_tx.send(()).unwrap();
356                    Ok(())
357                }
358                .instrument(tracing::debug_span!("reduce")),
359            );
360        }
361
362        match mode {
363            ProofMode::Groth16 => {
364                shrinkwrap_proof_artifact = Some(self.artifact_client.create_artifact()?);
365                groth16_proof_artifact = Some(self.artifact_client.create_artifact()?);
366
367                let shrinkwrap_task = RawTaskRequest {
368                    inputs: vec![compress_proof_artifact.clone().unwrap()],
369                    outputs: vec![shrinkwrap_proof_artifact.clone().unwrap()],
370                    context: context.clone(),
371                };
372
373                let groth16_task = RawTaskRequest {
374                    inputs: vec![shrinkwrap_proof_artifact.clone().unwrap()],
375                    outputs: vec![groth16_proof_artifact.clone().unwrap()],
376                    context: context.clone(),
377                };
378
379                let subscriber =
380                    self.worker_client.subscriber(context.proof_id.clone()).await?.per_task();
381                let worker_client = self.worker_client.clone();
382                join_set.spawn(async move {
383                    compress_complete_rx.await.unwrap();
384
385                    let shrinkwrap_task_id =
386                        worker_client.submit_task(TaskType::ShrinkWrap, shrinkwrap_task).await?;
387                    subscriber.wait_task(shrinkwrap_task_id).await?;
388
389                    let groth16_task_id =
390                        worker_client.submit_task(TaskType::Groth16Wrap, groth16_task).await?;
391                    subscriber.wait_task(groth16_task_id).await?;
392                    Ok(())
393                });
394            }
395            ProofMode::Plonk => {
396                shrinkwrap_proof_artifact = Some(self.artifact_client.create_artifact()?);
397                plonk_proof_artifact = Some(self.artifact_client.create_artifact()?);
398
399                let shrinkwrap_task = RawTaskRequest {
400                    inputs: vec![compress_proof_artifact.clone().unwrap()],
401                    outputs: vec![shrinkwrap_proof_artifact.clone().unwrap()],
402                    context: context.clone(),
403                };
404                let plonk_task = RawTaskRequest {
405                    inputs: vec![shrinkwrap_proof_artifact.clone().unwrap()],
406                    outputs: vec![plonk_proof_artifact.clone().unwrap()],
407                    context: context.clone(),
408                };
409
410                let subscriber =
411                    self.worker_client.subscriber(context.proof_id.clone()).await?.per_task();
412                let worker_client = self.worker_client.clone();
413                join_set.spawn(async move {
414                    compress_complete_rx.await.unwrap();
415
416                    let shrinkwrap_task_id =
417                        worker_client.submit_task(TaskType::ShrinkWrap, shrinkwrap_task).await?;
418                    subscriber.wait_task(shrinkwrap_task_id).await?;
419
420                    let plonk_task_id =
421                        worker_client.submit_task(TaskType::PlonkWrap, plonk_task).await?;
422                    subscriber.wait_task(plonk_task_id).await?;
423                    Ok(())
424                });
425            }
426            _ => {}
427        }
428
429        // Spawn a task to wait for the executor CoreExecute task to complete
430        {
431            let subscriber =
432                self.worker_client.subscriber(context.proof_id.clone()).await?.per_task();
433            join_set.spawn(async move {
434                let status = subscriber
435                    .wait_task(executor_task_id)
436                    .instrument(tracing::debug_span!("wait executor"))
437                    .await?;
438                if status != TaskStatus::Succeeded {
439                    return Err(TaskError::Fatal(anyhow::anyhow!("CoreExecute task failed")));
440                }
441                Ok(())
442            });
443        }
444
445        // Wait for all tasks to finish
446        while let Some(result) = join_set.join_next().await {
447            result.map_err(|e| TaskError::Fatal(e.into()))??;
448        }
449
450        // Download the execution output from the executor task's artifact
451        let result: ExecutionOutput =
452            self.artifact_client.download(&execution_output_artifact).await?;
453
454        // Get the proof and wrap it if the mode is either groth16 or plonk.
455        let inner_proof = match mode {
456            ProofMode::Core => {
457                let shard_proofs =
458                    self.artifact_client.download(&core_proof_artifact.clone().unwrap()).await?;
459                SP1Proof::Core(shard_proofs)
460            }
461            ProofMode::Compressed => {
462                let proof = self
463                    .artifact_client
464                    .download(&compress_proof_artifact.clone().unwrap())
465                    .await?;
466                SP1Proof::Compressed(Box::new(proof))
467            }
468            ProofMode::Plonk => {
469                let proof =
470                    self.artifact_client.download(&plonk_proof_artifact.clone().unwrap()).await?;
471                SP1Proof::Plonk(proof)
472            }
473            ProofMode::Groth16 => {
474                let proof =
475                    self.artifact_client.download(&groth16_proof_artifact.clone().unwrap()).await?;
476                SP1Proof::Groth16(proof)
477            }
478            _ => unimplemented!("proof mode not supported: {:?}", mode),
479        };
480
481        // Pair with public values and version
482        let public_values = SP1PublicValues::from(&result.public_value_stream);
483        let proof = ProofFromNetwork {
484            proof: inner_proof,
485            public_values,
486            sp1_version: SP1_CIRCUIT_VERSION.to_string(),
487        };
488
489        // Upload the proof
490        self.artifact_client.upload_proof(&output, proof).await?;
491
492        // Clean up artifacts
493        let artifacts_to_cleanup = vec![
494            Some(common_input_artifact),
495            Some(stdin_artifact),
496            Some(execution_output_artifact),
497            core_proof_artifact,
498            compress_proof_artifact,
499            shrinkwrap_proof_artifact,
500            groth16_proof_artifact,
501            plonk_proof_artifact,
502        ]
503        .into_iter()
504        .flatten()
505        .collect::<Vec<_>>();
506
507        self.artifact_client
508            .delete_batch(&artifacts_to_cleanup, ArtifactType::UnspecifiedArtifactType)
509            .await?;
510
511        Ok(result)
512    }
513}
514
515async fn collect_core_proofs(
516    worker_client: impl WorkerClient,
517    artifact_client: impl ArtifactClient,
518    result_artifact: Artifact,
519    context: TaskContext,
520    mut core_proof_rx: MessageReceiver<ProofData>,
521) -> Result<(), TaskError> {
522    let subscriber = worker_client.subscriber(context.proof_id.clone()).await?.per_task();
523    let mut shard_proofs = Vec::new();
524    while let Some(proof_data) = core_proof_rx.recv().await {
525        let ProofData { task_id, proof, .. } = proof_data;
526        let status = subscriber.wait_task(task_id.clone()).await?;
527        if status != TaskStatus::Succeeded {
528            tracing::error!("core proof task failed: {:?}", task_id);
529            return Err(TaskError::Fatal(anyhow::anyhow!("core proof task failed: {:?}", task_id)));
530        }
531        let proof = artifact_client
532            .download::<ShardProof<SP1GlobalContext, SP1PcsProofInner>>(&proof)
533            .await?;
534        shard_proofs.push(proof);
535    }
536    shard_proofs.sort_by_key(|shard_proof| {
537        let public_values: &PublicValues<[_; 4], [_; 3], [_; 4], _> =
538            shard_proof.public_values.as_slice().borrow();
539        public_values.range()
540    });
541
542    artifact_client.upload(&result_artifact, shard_proofs).await?;
543
544    Ok(())
545}