1mod compress;
2mod core;
3mod deferred;
4mod global;
5mod precompiles;
6mod splicing;
7mod vk_tree;
8
9pub use compress::*;
10pub use core::*;
11pub use deferred::*;
12pub use global::*;
13pub use precompiles::*;
14pub use splicing::*;
15pub use vk_tree::*;
16
17use lru::LruCache;
18
19use slop_algebra::PrimeField32;
20
21use sp1_core_executor::SP1CoreOpts;
22use sp1_core_executor_runner::MinimalExecutorRunner;
23use sp1_core_machine::{executor::ExecutionOutput, io::SP1Stdin};
24use sp1_hypercube::{
25 air::{PublicValues, PROOF_NONCE_NUM_WORDS},
26 SP1PcsProofInner, SP1VerifyingKey, ShardProof,
27};
28use sp1_primitives::{io::SP1PublicValues, SP1GlobalContext};
29use sp1_prover_types::{
30 network_base_types::ProofMode, Artifact, ArtifactClient, ArtifactType, TaskStatus, TaskType,
31};
32use sp1_verifier::{ProofFromNetwork, SP1Proof};
33use std::{borrow::Borrow, sync::Arc};
34use tokio::{
35 sync::{oneshot, Mutex, MutexGuard},
36 task::JoinSet,
37};
38use tracing::Instrument;
39
40use crate::{
41 verify::SP1Verifier,
42 worker::{MessageReceiver, RawTaskRequest, TaskContext, TaskError, TaskId, WorkerClient},
43 SP1_CIRCUIT_VERSION,
44};
45
46#[derive(Clone)]
47pub struct MinimalExecutorCache(Arc<Mutex<Option<MinimalExecutorRunner>>>);
48
49impl MinimalExecutorCache {
50 pub fn empty() -> Self {
51 Self(Arc::new(Mutex::new(None)))
52 }
53
54 pub async fn lock(&self) -> MutexGuard<'_, Option<MinimalExecutorRunner>> {
55 self.0.lock().await
56 }
57}
58
59#[derive(Clone)]
60pub struct SP1ControllerConfig {
61 pub opts: SP1CoreOpts,
62 pub num_splicing_workers: usize,
63 pub splicing_buffer_size: usize,
64 pub max_reduce_arity: usize,
65 pub number_of_send_splice_workers_per_splice: usize,
66 pub send_splice_input_buffer_size_per_splice: usize,
67 pub use_fixed_pk: bool,
68 pub global_memory_buffer_size: usize,
69}
70
71pub struct SP1Controller<A, W> {
72 config: SP1ControllerConfig,
73 setup_cache: Arc<Mutex<LruCache<Artifact, SP1VerifyingKey>>>,
74 pub(crate) artifact_client: A,
75 pub(crate) worker_client: W,
76 pub(crate) verifier: SP1Verifier,
77 minimal_executor_cache: Option<MinimalExecutorCache>,
78}
79
80impl<A, W> SP1Controller<A, W>
81where
82 A: ArtifactClient,
83 W: WorkerClient,
84{
85 pub fn new(
86 config: SP1ControllerConfig,
87 artifact_client: A,
88 worker_client: W,
89 verifier: SP1Verifier,
90 ) -> Self {
91 let minimal_executor_cache =
92 if config.use_fixed_pk { Some(MinimalExecutorCache::empty()) } else { None };
93
94 Self {
95 config,
96 setup_cache: Arc::new(Mutex::new(LruCache::new(20.try_into().unwrap()))),
97 artifact_client,
98 worker_client,
99 verifier,
100 minimal_executor_cache,
101 }
102 }
103
104 #[inline]
105 pub const fn opts(&self) -> &SP1CoreOpts {
106 &self.config.opts
107 }
108
109 #[inline]
110 pub const fn max_reduce_arity(&self) -> usize {
111 self.config.max_reduce_arity
112 }
113
114 #[inline]
115 pub const fn global_memory_buffer_size(&self) -> usize {
116 self.config.global_memory_buffer_size
117 }
118
119 pub fn initialize_splicing_engine(&self) -> Arc<SplicingEngine<A, W>> {
120 let splicing_workers = (0..self.config.num_splicing_workers)
121 .map(|_| {
122 SplicingWorker::new(
123 self.artifact_client.clone(),
124 self.worker_client.clone(),
125 self.config.number_of_send_splice_workers_per_splice,
126 self.config.send_splice_input_buffer_size_per_splice,
127 )
128 })
129 .collect();
130 Arc::new(SplicingEngine::new(splicing_workers, self.config.splicing_buffer_size))
131 }
132
133 pub async fn execute(
137 &self,
138 task_id: TaskId,
139 request: CoreExecuteTaskRequest,
140 ) -> Result<ExecutionOutput, TaskError> {
141 let stdin = self.artifact_client.download_stdin::<SP1Stdin>(&request.stdin).await?;
142
143 let deferred_proofs = stdin.proofs.iter().map(|(proof, _)| proof.clone());
144 let deferred_inputs = DeferredInputs::new(deferred_proofs);
145
146 let splicing_engine = self.initialize_splicing_engine();
147 let proof_data_sender =
148 MessageSender::<W, ProofData>::new(self.worker_client.clone(), task_id);
149 let executor = SP1CoreExecutor::new(
150 splicing_engine,
151 self.global_memory_buffer_size(),
152 request.elf,
153 Arc::new(stdin),
154 request.common_input.clone(),
155 self.opts().clone(),
156 request.num_deferred_proofs,
157 request.context.clone(),
158 proof_data_sender.clone(),
159 self.artifact_client.clone(),
160 self.worker_client.clone(),
161 self.minimal_executor_cache.clone(),
162 request.cycle_limit,
163 );
164
165 let mut join_set = JoinSet::<Result<(), TaskError>>::new();
166
167 {
169 let deferred_sender = proof_data_sender.clone();
170 let artifact_client = self.artifact_client.clone();
171 let worker_client = self.worker_client.clone();
172 let common_input_artifact = request.common_input.clone();
173 let context = request.context.clone();
174 join_set.spawn(async move {
175 deferred_inputs
176 .emit_deferred_tasks(
177 common_input_artifact,
178 context,
179 deferred_sender,
180 artifact_client,
181 worker_client,
182 )
183 .await
184 });
185 }
186
187 let output = executor.execute().await;
189
190 while let Some(result) = join_set.join_next().await {
192 result.map_err(|e| TaskError::Fatal(e.into()))??;
193 }
194
195 let output = output?;
196 if let Some(limit) = request.cycle_limit {
197 if limit > 0 && output.cycles > limit {
198 return Err(TaskError::Fatal(anyhow::anyhow!(
199 "cycle limit exceeded: {} > {}",
200 output.cycles,
201 limit
202 )));
203 }
204 }
205 self.artifact_client.upload(&request.execution_output, &output).await?;
206 Ok(output)
207 }
208
209 pub async fn run(&self, request: RawTaskRequest) -> Result<ExecutionOutput, TaskError> {
210 let RawTaskRequest { inputs, outputs, context } = request;
211 let elf = inputs[0].clone();
212 let stdin_artifact = inputs[1].clone();
213 let mode_artifact = inputs[2].clone();
214 let cycle_limit = inputs.get(3).and_then(|a| a.clone().to_id().parse::<u64>().ok());
215 let proof_nonce = inputs.get(4);
216 let [output] = outputs.try_into().unwrap();
217 let mode = {
218 let parsed =
219 mode_artifact.to_id().parse::<i32>().map_err(|e| TaskError::Fatal(e.into()))?;
220 ProofMode::try_from(parsed).map_err(|e| TaskError::Fatal(e.into()))?
221 };
222
223 let stdin_download_handle =
224 self.artifact_client.download_stdin::<SP1Stdin>(&stdin_artifact);
225
226 let proof_nonce = match proof_nonce {
227 Some(artifact) => {
228 self.artifact_client.download::<[u32; PROOF_NONCE_NUM_WORDS]>(artifact).await?
229 }
230 None => [0u32; PROOF_NONCE_NUM_WORDS],
231 };
232
233 let vkey_download_handle = tokio::spawn({
234 let artifact_client_clone = self.artifact_client.clone();
235 let worker_client_clone = self.worker_client.clone();
236 let elf_clone = elf.clone();
237 let setup_cache = self.setup_cache.clone();
238 let context = context.clone();
239 async move {
240 let mut lock = setup_cache.lock().await;
241 let vkey = lock.get(&elf_clone).cloned();
242 drop(lock);
243 let vk = if let Some(vkey) = vkey {
244 tracing::debug!("setup cache hit");
245 vkey.clone()
246 } else {
247 let vk_artifact = artifact_client_clone.create_artifact()?;
248 let setup_request = RawTaskRequest {
249 inputs: vec![elf_clone.clone()],
250 outputs: vec![vk_artifact.clone()],
251 context: context.clone(),
252 };
253
254 tracing::debug!("submitting setup task");
255 let setup_id =
256 worker_client_clone.submit_task(TaskType::SetupVkey, setup_request).await?;
257
258 let subscriber =
259 worker_client_clone.subscriber(context.proof_id.clone()).await?.per_task();
260 let status = subscriber
261 .wait_task(setup_id)
262 .instrument(tracing::debug_span!("setup task"))
263 .await
264 .map_err(|e| TaskError::Fatal(e.into()))?;
265 if status != TaskStatus::Succeeded {
266 return Err(TaskError::Fatal(anyhow::anyhow!("setup task failed")));
267 }
268 tracing::debug!("setup task succeeded");
269 let vk =
270 artifact_client_clone.download::<SP1VerifyingKey>(&vk_artifact).await?;
271 setup_cache.lock().await.put(elf_clone, vk.clone());
272 vk
273 };
274 Ok(vk)
275 }
276 .instrument(tracing::debug_span!("setup vkey"))
277 });
278
279 let stdin: SP1Stdin = stdin_download_handle.await?;
280 let vk = vkey_download_handle.await.map_err(|e| TaskError::Fatal(e.into()))??;
281
282 let stdin = Arc::new(stdin);
283
284 let deferred_proofs = stdin.proofs.iter().map(|(proof, _)| proof.clone());
285 let deferred_inputs = DeferredInputs::new(deferred_proofs);
286
287 let num_deferred_proofs = deferred_inputs.num_deferred_proofs();
288 let deferred_digest = deferred_inputs.deferred_digest().map(|x| x.as_canonical_u32());
289 let common_input = CommonProverInput {
290 vk,
291 mode,
292 deferred_digest,
293 num_deferred_proofs,
294 nonce: proof_nonce,
295 };
296 let common_input_artifact = self.artifact_client.create_artifact()?;
297 self.artifact_client.upload(&common_input_artifact.clone(), common_input.clone()).await?;
298
299 let execution_output_artifact = self.artifact_client.create_artifact()?;
301 let executor_request = CoreExecuteTaskRequest {
302 elf: elf.clone(),
303 stdin: stdin_artifact.clone(),
304 common_input: common_input_artifact.clone(),
305 execution_output: execution_output_artifact.clone(),
306 num_deferred_proofs,
307 cycle_limit,
308 context: context.clone(),
309 };
310 let executor_task_id = self
311 .worker_client
312 .submit_task(TaskType::CoreExecute, executor_request.into_raw()?)
313 .await?;
314
315 let core_proof_rx = MessageReceiver::<ProofData>::new(
316 self.worker_client.subscribe_task_messages(&executor_task_id).await?,
317 );
318
319 let mut join_set = JoinSet::<Result<(), TaskError>>::new();
320
321 let mut core_proof_artifact = None;
322 let mut compress_proof_artifact = None;
323 let mut shrinkwrap_proof_artifact = None;
324 let mut groth16_proof_artifact = None;
325 let mut plonk_proof_artifact = None;
326
327 let (compress_complete_tx, compress_complete_rx) = oneshot::channel();
328
329 if mode == ProofMode::Core {
330 core_proof_artifact = Some(self.artifact_client.create_artifact()?);
331 join_set.spawn(collect_core_proofs(
332 self.worker_client.clone(),
333 self.artifact_client.clone(),
334 core_proof_artifact.clone().unwrap(),
335 context.clone(),
336 core_proof_rx,
337 ));
338 } else {
339 let mut tree = CompressTree::new(self.max_reduce_arity());
340 let artifact_client = self.artifact_client.clone();
341 let worker_client = self.worker_client.clone();
342 let context = context.clone();
343 compress_proof_artifact = Some(self.artifact_client.create_artifact()?);
344 let compress_proof_artifact = compress_proof_artifact.clone().unwrap();
345 join_set.spawn(
346 async move {
347 tree.reduce_proofs(
348 context,
349 compress_proof_artifact.clone(),
350 core_proof_rx,
351 &artifact_client,
352 &worker_client,
353 )
354 .await?;
355 compress_complete_tx.send(()).unwrap();
356 Ok(())
357 }
358 .instrument(tracing::debug_span!("reduce")),
359 );
360 }
361
362 match mode {
363 ProofMode::Groth16 => {
364 shrinkwrap_proof_artifact = Some(self.artifact_client.create_artifact()?);
365 groth16_proof_artifact = Some(self.artifact_client.create_artifact()?);
366
367 let shrinkwrap_task = RawTaskRequest {
368 inputs: vec![compress_proof_artifact.clone().unwrap()],
369 outputs: vec![shrinkwrap_proof_artifact.clone().unwrap()],
370 context: context.clone(),
371 };
372
373 let groth16_task = RawTaskRequest {
374 inputs: vec![shrinkwrap_proof_artifact.clone().unwrap()],
375 outputs: vec![groth16_proof_artifact.clone().unwrap()],
376 context: context.clone(),
377 };
378
379 let subscriber =
380 self.worker_client.subscriber(context.proof_id.clone()).await?.per_task();
381 let worker_client = self.worker_client.clone();
382 join_set.spawn(async move {
383 compress_complete_rx.await.unwrap();
384
385 let shrinkwrap_task_id =
386 worker_client.submit_task(TaskType::ShrinkWrap, shrinkwrap_task).await?;
387 subscriber.wait_task(shrinkwrap_task_id).await?;
388
389 let groth16_task_id =
390 worker_client.submit_task(TaskType::Groth16Wrap, groth16_task).await?;
391 subscriber.wait_task(groth16_task_id).await?;
392 Ok(())
393 });
394 }
395 ProofMode::Plonk => {
396 shrinkwrap_proof_artifact = Some(self.artifact_client.create_artifact()?);
397 plonk_proof_artifact = Some(self.artifact_client.create_artifact()?);
398
399 let shrinkwrap_task = RawTaskRequest {
400 inputs: vec![compress_proof_artifact.clone().unwrap()],
401 outputs: vec![shrinkwrap_proof_artifact.clone().unwrap()],
402 context: context.clone(),
403 };
404 let plonk_task = RawTaskRequest {
405 inputs: vec![shrinkwrap_proof_artifact.clone().unwrap()],
406 outputs: vec![plonk_proof_artifact.clone().unwrap()],
407 context: context.clone(),
408 };
409
410 let subscriber =
411 self.worker_client.subscriber(context.proof_id.clone()).await?.per_task();
412 let worker_client = self.worker_client.clone();
413 join_set.spawn(async move {
414 compress_complete_rx.await.unwrap();
415
416 let shrinkwrap_task_id =
417 worker_client.submit_task(TaskType::ShrinkWrap, shrinkwrap_task).await?;
418 subscriber.wait_task(shrinkwrap_task_id).await?;
419
420 let plonk_task_id =
421 worker_client.submit_task(TaskType::PlonkWrap, plonk_task).await?;
422 subscriber.wait_task(plonk_task_id).await?;
423 Ok(())
424 });
425 }
426 _ => {}
427 }
428
429 {
431 let subscriber =
432 self.worker_client.subscriber(context.proof_id.clone()).await?.per_task();
433 join_set.spawn(async move {
434 let status = subscriber
435 .wait_task(executor_task_id)
436 .instrument(tracing::debug_span!("wait executor"))
437 .await?;
438 if status != TaskStatus::Succeeded {
439 return Err(TaskError::Fatal(anyhow::anyhow!("CoreExecute task failed")));
440 }
441 Ok(())
442 });
443 }
444
445 while let Some(result) = join_set.join_next().await {
447 result.map_err(|e| TaskError::Fatal(e.into()))??;
448 }
449
450 let result: ExecutionOutput =
452 self.artifact_client.download(&execution_output_artifact).await?;
453
454 let inner_proof = match mode {
456 ProofMode::Core => {
457 let shard_proofs =
458 self.artifact_client.download(&core_proof_artifact.clone().unwrap()).await?;
459 SP1Proof::Core(shard_proofs)
460 }
461 ProofMode::Compressed => {
462 let proof = self
463 .artifact_client
464 .download(&compress_proof_artifact.clone().unwrap())
465 .await?;
466 SP1Proof::Compressed(Box::new(proof))
467 }
468 ProofMode::Plonk => {
469 let proof =
470 self.artifact_client.download(&plonk_proof_artifact.clone().unwrap()).await?;
471 SP1Proof::Plonk(proof)
472 }
473 ProofMode::Groth16 => {
474 let proof =
475 self.artifact_client.download(&groth16_proof_artifact.clone().unwrap()).await?;
476 SP1Proof::Groth16(proof)
477 }
478 _ => unimplemented!("proof mode not supported: {:?}", mode),
479 };
480
481 let public_values = SP1PublicValues::from(&result.public_value_stream);
483 let proof = ProofFromNetwork {
484 proof: inner_proof,
485 public_values,
486 sp1_version: SP1_CIRCUIT_VERSION.to_string(),
487 };
488
489 self.artifact_client.upload_proof(&output, proof).await?;
491
492 let artifacts_to_cleanup = vec![
494 Some(common_input_artifact),
495 Some(stdin_artifact),
496 Some(execution_output_artifact),
497 core_proof_artifact,
498 compress_proof_artifact,
499 shrinkwrap_proof_artifact,
500 groth16_proof_artifact,
501 plonk_proof_artifact,
502 ]
503 .into_iter()
504 .flatten()
505 .collect::<Vec<_>>();
506
507 self.artifact_client
508 .delete_batch(&artifacts_to_cleanup, ArtifactType::UnspecifiedArtifactType)
509 .await?;
510
511 Ok(result)
512 }
513}
514
515async fn collect_core_proofs(
516 worker_client: impl WorkerClient,
517 artifact_client: impl ArtifactClient,
518 result_artifact: Artifact,
519 context: TaskContext,
520 mut core_proof_rx: MessageReceiver<ProofData>,
521) -> Result<(), TaskError> {
522 let subscriber = worker_client.subscriber(context.proof_id.clone()).await?.per_task();
523 let mut shard_proofs = Vec::new();
524 while let Some(proof_data) = core_proof_rx.recv().await {
525 let ProofData { task_id, proof, .. } = proof_data;
526 let status = subscriber.wait_task(task_id.clone()).await?;
527 if status != TaskStatus::Succeeded {
528 tracing::error!("core proof task failed: {:?}", task_id);
529 return Err(TaskError::Fatal(anyhow::anyhow!("core proof task failed: {:?}", task_id)));
530 }
531 let proof = artifact_client
532 .download::<ShardProof<SP1GlobalContext, SP1PcsProofInner>>(&proof)
533 .await?;
534 shard_proofs.push(proof);
535 }
536 shard_proofs.sort_by_key(|shard_proof| {
537 let public_values: &PublicValues<[_; 4], [_; 3], [_; 4], _> =
538 shard_proof.public_values.as_slice().borrow();
539 public_values.range()
540 });
541
542 artifact_client.upload(&result_artifact, shard_proofs).await?;
543
544 Ok(())
545}