// Source: sp1_prover/worker/node/full/init.rs
1use std::sync::Arc;
2
3use slop_futures::pipeline::TaskJoinError;
4use sp1_hypercube::prover::ProverSemaphore;
5use sp1_prover_types::{
6    ArtifactClient, ArtifactType, InMemoryArtifactClient, TaskStatus, TaskType,
7};
8use tokio::{sync::mpsc, task::JoinSet};
9use tracing::Instrument;
10
11use crate::{
12    worker::{
13        node::SP1NodeCore, run_vk_generation, LocalWorkerClient, LocalWorkerClientChannels,
14        ProofId, RawTaskRequest, SP1LocalNode, SP1NodeInner, SP1WorkerBuilder, TaskError, TaskId,
15        TaskMetadata, WorkerClient,
16    },
17    SP1ProverComponents,
18};
19
/// Builder for an [`SP1LocalNode`] that wires an [`InMemoryArtifactClient`] to a
/// channel-based [`LocalWorkerClient`].
pub struct SP1LocalNodeBuilder<C: SP1ProverComponents> {
    /// Worker builder pre-configured with the in-memory artifact client and the
    /// local worker client (see [`Self::from_worker_client_builder`]).
    pub worker_builder: SP1WorkerBuilder<C, InMemoryArtifactClient, LocalWorkerClient>,
    /// Receiving ends of the per-[`TaskType`] channels; each receiver is removed
    /// and drained by a handler task spawned in [`Self::build`].
    pub channels: LocalWorkerClientChannels,
}
24
impl<C: SP1ProverComponents> Default for SP1LocalNodeBuilder<C> {
    /// Equivalent to [`SP1LocalNodeBuilder::new`].
    fn default() -> Self {
        Self::new()
    }
}
30
31impl<C: SP1ProverComponents> SP1LocalNodeBuilder<C> {
32    /// Creates a new local node builder with a default worker client builder.
33    pub fn new() -> Self {
34        Self::from_worker_client_builder(SP1WorkerBuilder::new())
35    }
36
37    /// Creates a new local node builder from a worker client builder.
38    ///
39    /// This method can be used to initialize a node from a worker client builder that has already
40    /// been configured with the desired prover components.
41    pub fn from_worker_client_builder(builder: SP1WorkerBuilder<C>) -> Self {
42        let artifact_client = InMemoryArtifactClient::new();
43        let (worker_client, channels) = LocalWorkerClient::init();
44        let worker_builder =
45            builder.with_artifact_client(artifact_client).with_worker_client(worker_client);
46        Self { worker_builder, channels }
47    }
48
49    /// Sets the core air prover to the worker client builder.
50    pub fn with_core_air_prover(
51        mut self,
52        core_air_prover: Arc<C::CoreProver>,
53        permit: ProverSemaphore,
54    ) -> Self {
55        self.worker_builder = self.worker_builder.with_core_air_prover(core_air_prover, permit);
56        self
57    }
58
59    /// Sets the compress air prover to the worker client builder.
60    pub fn with_compress_air_prover(
61        mut self,
62        compress_air_prover: Arc<C::RecursionProver>,
63        permit: ProverSemaphore,
64    ) -> Self {
65        self.worker_builder =
66            self.worker_builder.with_compress_air_prover(compress_air_prover, permit);
67        self
68    }
69
70    /// Sets the shrink air prover to the worker client builder.
71    pub fn with_shrink_air_prover(
72        mut self,
73        shrink_air_prover: Arc<C::RecursionProver>,
74        permit: ProverSemaphore,
75    ) -> Self {
76        self.worker_builder = self.worker_builder.with_shrink_air_prover(shrink_air_prover, permit);
77        self
78    }
79
80    /// Sets the wrap air prover to the worker client builder.
81    pub fn with_wrap_air_prover(
82        mut self,
83        wrap_air_prover: C::WrapProverBuilder,
84        permit: ProverSemaphore,
85    ) -> Self {
86        self.worker_builder = self.worker_builder.with_wrap_air_prover(wrap_air_prover, permit);
87        self
88    }
89
    /// Builds the [`SP1LocalNode`].
    ///
    /// Consumes the builder, constructs the worker, and spawns one handler per
    /// [`TaskType`] channel into a [`JoinSet`]. The join set is stored on the
    /// node as `_tasks`, so the handlers live exactly as long as the node.
    ///
    /// # Errors
    /// Returns an error if building the worker fails.
    ///
    /// # Panics
    /// Panics if any expected per-task-type receiver is missing from
    /// `channels.task_receivers` (the `unwrap`s below), or if a prover-engine
    /// `submit_*` call fails.
    pub async fn build(self) -> anyhow::Result<SP1LocalNode> {
        // Destructure the builder.
        let Self { worker_builder, mut channels } = self;
        // Get the core options from the worker builder.
        let opts = worker_builder.core_opts().clone();

        // Build the node.
        let worker = worker_builder.build().await?;

        // Create a join set for the task handlers.
        let mut join_set = JoinSet::new();

        // Spawn tasks to handle all the requests. We must spawn a handler for each task type to
        // avoid blocking the main thread by not having processed the input channel.

        // Spawn the controller handler
        join_set.spawn({
            let mut controller_rx = channels.task_receivers.remove(&TaskType::Controller).unwrap();
            let worker = worker.clone();
            async move {
                while let Some((task_id, request)) = controller_rx.recv().await {
                    let span = tracing::debug_span!("Controller", proof_id = %request.context.proof_id, task_id = %task_id);
                    // Run the controller task
                    if let Err(e) = worker.controller().run(request.clone()).instrument(span).await
                    {
                        tracing::error!("Controller: task failed: {e:?}");
                    }

                    // Complete the task. Note this runs even if `run` failed above; the
                    // failure is only logged.
                    if let Err(e) = worker
                        .worker_client()
                        .complete_task(
                            request.context.proof_id,
                            task_id,
                            TaskMetadata { gpu_ms: None },
                        )
                        .await
                    {
                        tracing::error!("Controller: marking task as complete failed: {e:?}");
                    }

                    // Remove all the inputs from the task
                    for input in request.inputs {
                        if let Err(e) = worker
                            .artifact_client()
                            .delete(&input, ArtifactType::UnspecifiedArtifactType)
                            .await
                        {
                            tracing::error!("Controller: deleting input artifact failed: {e:?}");
                        }
                    }
                }
            }
        });

        // Spawn the CoreExecute handler
        join_set.spawn({
            let mut execute_rx =
                channels.task_receivers.remove(&TaskType::CoreExecute).unwrap();
            let worker = worker.clone();
            async move {
                while let Some((task_id, request)) = execute_rx.recv().await {
                    let span = tracing::debug_span!("CoreExecute", proof_id = %request.context.proof_id, task_id = %task_id);
                    let proof_id = request.context.proof_id.clone();
                    match crate::worker::CoreExecuteTaskRequest::from_raw(request.clone()) {
                        Ok(req) => {
                            if let Err(e) =
                                worker.controller().execute(task_id.clone(), req).instrument(span).await
                            {
                                tracing::error!("CoreExecute: task failed: {e:?}");
                            }
                        }
                        Err(e) => {
                            tracing::error!("CoreExecute: failed to parse request: {e:?}");
                        }
                    }

                    // The task is marked complete whether execution succeeded, failed,
                    // or the request failed to parse.
                    if let Err(e) = worker
                        .worker_client()
                        .complete_task(proof_id, task_id, TaskMetadata { gpu_ms: None })
                        .await
                    {
                        tracing::error!("CoreExecute: marking task as complete failed: {e:?}");
                    }
                }
            }
        });

        // Spawn the setup handler. Submissions run concurrently in `task_set`; their
        // results come back over the (task_tx, task_rx) channel and are reported via
        // `TaskOutput::handle_task_output`.
        join_set.spawn({
            let mut setup_rx = channels.task_receivers.remove(&TaskType::SetupVkey).unwrap();
            let worker = worker.clone();
            let worker_client = worker.worker_client().clone();
            async move {
                let mut task_set = JoinSet::new();
                let (task_tx, mut task_rx) = mpsc::unbounded_channel();
                loop {
                    tokio::select! {
                        Some((id, request)) = setup_rx.recv() => {
                            let span = tracing::debug_span!("SetupVkey", proof_id = %request.context.proof_id, task_id = %id);
                            let RawTaskRequest { inputs, outputs, context } = request.clone();
                            let proof_id = context.proof_id.clone();
                            // Convention: inputs[0] is the ELF artifact, outputs[0] the
                            // destination for the generated vkey.
                            let elf = inputs[0].clone();
                            let output = outputs[0].clone();
                            // NOTE(review): `unwrap` panics the handler if submission to
                            // the prover engine fails — confirm this is intentional.
                            let handle = worker
                                    .prover_engine()
                                    .submit_setup(id.clone(), elf, output)
                                    .instrument(span.clone())
                                    .await
                                    .unwrap();
                            let tx = task_tx.clone();
                            task_set.spawn(async move {
                                // Keep only the metadata from the (vkey, metadata) pair.
                                let result = handle.await.map(|res| res.map(|(_, metadata)| metadata));
                                TaskOutput::handle_worker_result(result, &tx, proof_id, id, request, TaskType::SetupVkey);
                            }
                          );
                        }

                        Some(output) = task_rx.recv() => {
                            output.handle_task_output(&worker_client).await;
                        }
                        // Both channels closed: shut this handler down.
                        else => {
                            break;
                        }
                    }
                }
            }
        });

        // Spawn the recursion vk tree handler
        join_set.spawn({
            let mut controller_rx =
                channels.task_receivers.remove(&TaskType::UtilVkeyMapController).unwrap();
            let worker = worker.clone();
            async move {
                while let Some((task_id, request)) = controller_rx.recv().await {
                    // Run the controller task
                    if let Err(e) =
                        worker.controller().run_sp1_util_vkey_map_controller(request.clone()).await
                    {
                        tracing::error!("Controller: task failed: {e:?}");
                    }

                    // Complete the task
                    if let Err(e) = worker
                        .worker_client()
                        .complete_task(
                            request.context.proof_id,
                            task_id,
                            TaskMetadata { gpu_ms: None },
                        )
                        .await
                    {
                        tracing::error!("Controller: marking task as complete failed: {e:?}");
                    }

                    // Remove all the inputs from the task
                    for input in request.inputs {
                        if let Err(e) = worker
                            .artifact_client()
                            .delete(&input, ArtifactType::UnspecifiedArtifactType)
                            .await
                        {
                            tracing::error!("Controller: deleting input artifact failed: {e:?}");
                        }
                    }
                }
            }
        });

        // Spawn the vk chunk worker handler.
        join_set.spawn({
            let mut core_prover_rx =
                channels.task_receivers.remove(&TaskType::UtilVkeyMapChunk).unwrap();
            let worker = worker.clone();
            let worker_client = worker.worker_client().clone();
            let vk_worker = Arc::new(worker.clone().prover_engine().vk_worker.clone());
            async move {
                let mut task_set = JoinSet::new();
                let (task_tx, mut task_rx) = mpsc::unbounded_channel();

                loop {
                    // Cloned each iteration so the request arm can move it into the task.
                    let vk_worker = vk_worker.clone();
                    tokio::select! {
                        Some((id, request)) = core_prover_rx.recv() => {
                            let proof_id = request.context.proof_id.clone();
                            let handle = run_vk_generation::<_,_>(vk_worker, request, worker.artifact_client().clone());
                            let tx = task_tx.clone();
                            let task_id = id;
                            task_set.spawn(async move {
                                match handle.await {
                                    Ok(()) => {
                                        tx.send((proof_id, task_id, TaskStatus::Succeeded)).ok();
                                    }
                                    Err(e) => {
                                        // NOTE(review): on failure nothing is sent, so the
                                        // task is never completed or marked failed through
                                        // the worker client — confirm this is intentional.
                                        tracing::error!("Failed to generate vk chunk: {:?}", e);
                                    }
                                }
                            });
                        }

                        Some((proof_id, task_id, status)) = task_rx.recv() => {
                            // Only Succeeded is ever sent on this channel (see the spawn above).
                            assert_eq!(status, TaskStatus::Succeeded);
                            if let Err(e) = worker_client.complete_task(proof_id, task_id, TaskMetadata { gpu_ms: None }).await {
                                tracing::error!("Failed to complete vk chunk task: {:?}", e);
                            }
                        }
                        else => {
                            break;
                        }
                    }
                }
            }
        });

        // Spawn the prove shard handler
        join_set.spawn({
            let mut core_prover_rx = channels.task_receivers.remove(&TaskType::ProveShard).unwrap();
            let worker = worker.clone();
            let worker_client = worker.worker_client().clone();
            async move {
                let mut task_set = JoinSet::new();
                let (task_tx, mut task_rx) = mpsc::unbounded_channel();

                loop {
                    tokio::select! {
                        Some((id, request)) = core_prover_rx.recv() => {
                            let span = tracing::debug_span!("ProveShard", proof_id = %request.context.proof_id, task_id = %id);
                            let proof_id = request.context.proof_id.clone();
                            let handle = worker
                                .prover_engine()
                                .submit_prove_core_shard(
                                    request.clone(),
                                )
                                .instrument(span.clone())
                                .await
                                .unwrap();
                            let tx = task_tx.clone();
                            task_set.spawn(
                                async move {
                                    let result = handle.await;
                                    TaskOutput::handle_worker_result(result, &tx, proof_id, id, request, TaskType::ProveShard);
                                }.instrument(span)
                           );
                        }

                        Some(output) = task_rx.recv() => {
                            output.handle_task_output(&worker_client).await;
                        }
                        else => {
                            break;
                        }
                    }
                }
            }
        });

        // Spawn the recursion reduce handler
        join_set.spawn({
            let mut recursion_reduce_rx =
                channels.task_receivers.remove(&TaskType::RecursionReduce).unwrap();
            let worker = worker.clone();
            let worker_client = worker.worker_client().clone();
            async move {
                let mut task_set = JoinSet::new();
                let (task_tx, mut task_rx) = mpsc::unbounded_channel();
                loop {
                    tokio::select! {
                        Some((id, request)) = recursion_reduce_rx.recv() => {
                            let span = tracing::debug_span!("RecursionReduce", proof_id = %request.context.proof_id, task_id = %id);
                            let proof_id = request.context.proof_id.clone();
                            let handle = worker
                                .prover_engine()
                                .submit_recursion_reduce(request.clone())
                                .instrument(span.clone())
                                .await
                                .unwrap();
                            let tx = task_tx.clone();
                            task_set.spawn(async move {
                                let result = handle.await;
                                TaskOutput::handle_worker_result(result, &tx, proof_id, id, request, TaskType::RecursionReduce);
                            }.instrument(span)
                          );
                        }

                        Some(output) = task_rx.recv() => {
                            output.handle_task_output(&worker_client).await;
                        }
                        else => {
                            break;
                        }
                    }
                }
            }
        });

        // Spawn the deferred handler
        join_set.spawn({
            let mut recursion_deferred_rx =
                channels.task_receivers.remove(&TaskType::RecursionDeferred).unwrap();
            let worker = worker.clone();
            let worker_client = worker.worker_client().clone();
            async move {
                let mut task_set = JoinSet::new();
                let (task_tx, mut task_rx) = mpsc::unbounded_channel();
                loop {
                    tokio::select! {
                        Some((id, request)) = recursion_deferred_rx.recv() => {
                            let span = tracing::debug_span!("RecursionDeferred", proof_id = %request.context.proof_id, task_id = %id);
                            let proof_id = request.context.proof_id.clone();
                            let handle = worker
                                .prover_engine()
                                .submit_prove_deferred(request.clone())
                                .instrument(span.clone())
                                .await
                                .unwrap();
                            let tx = task_tx.clone();
                            task_set.spawn(async move {
                                let result = handle.await;
                                TaskOutput::handle_worker_result(result, &tx, proof_id, id, request, TaskType::RecursionDeferred);
                            }.instrument(span)
                          );
                        }
                        Some(output) = task_rx.recv() => {
                            output.handle_task_output(&worker_client).await;
                        }
                        else => {
                            break;
                        }
                    }
                }
            }
        });

        // Spawn the deferred marker task handler.
        // Marker deferred tasks are completed by the [TaskType::ProveShard] task, but we still need to consume the receiver here.
        join_set.spawn({
            let mut marker_deferred_task_rx =
                channels.task_receivers.remove(&TaskType::MarkerDeferredRecord).unwrap();
            async move { while let Some((_task_id, _request)) = marker_deferred_task_rx.recv().await {} }
        });

        // Spawn the shrink wrap handler
        //
        // In the local node, we only allow one of these tasks to be run at a time.
        // (The request is awaited inline in the select arm rather than spawned.)
        join_set.spawn({
            let mut shrink_wrap_rx = channels.task_receivers.remove(&TaskType::ShrinkWrap).unwrap();
            let worker = worker.clone();
            let worker_client = worker.worker_client().clone();
            async move {
                let (task_tx, mut task_rx) = mpsc::unbounded_channel();
                loop {
                    tokio::select! {
                        Some((id, request)) = shrink_wrap_rx.recv() => {
                            let span = tracing::debug_span!("ShrinkWrap", proof_id = %request.context.proof_id, task_id = %id);
                            let worker = worker.clone();
                            let proof_id = request.context.proof_id.clone();
                            let result = worker
                                .prover_engine()
                                .run_shrink_wrap(request.clone())
                                .instrument(span)
                                .await
                                .map(|_| TaskMetadata::default());
                            TaskOutput::handle_worker_result(Ok(result), &task_tx, proof_id, id, request, TaskType::ShrinkWrap);
                        }
                        Some(output) = task_rx.recv() => {
                            output.handle_task_output(&worker_client).await;
                        }
                        else => {
                            break;
                        }
                    }
                }
            }
        });

        // Spawn the plonk wrap handler
        //
        // in the local node, we only allow one of these tasks to be run at a time.
        join_set.spawn({
            let mut plonk_wrap_rx = channels.task_receivers.remove(&TaskType::PlonkWrap).unwrap();
            let worker = worker.clone();
            let worker_client = worker.worker_client().clone();
            async move {
                let (task_tx, mut task_rx) = mpsc::unbounded_channel();
                loop {
                    tokio::select! {
                        Some((id, request)) = plonk_wrap_rx.recv() => {
                            let span = tracing::debug_span!("PlonkWrap", proof_id = %request.context.proof_id, task_id = %id);
                            let worker = worker.clone();
                            let proof_id = request.context.proof_id.clone();
                            let result = worker
                                .prover_engine()
                                .run_plonk(request.clone())
                                .instrument(span)
                                .await
                                .map(|_| TaskMetadata::default());
                            TaskOutput::handle_worker_result(Ok(result), &task_tx, proof_id, id, request, TaskType::PlonkWrap);
                        }
                        Some(output) = task_rx.recv() => {
                            output.handle_task_output(&worker_client).await;
                        }
                        else => {
                            break;
                        }
                    }
                }
            }
        });

        // Spawn the groth16 wrap handler
        //
        // In the local node, we only allow one of these tasks to be run at a time.
        join_set.spawn({
            let mut groth16_wrap_rx =
                channels.task_receivers.remove(&TaskType::Groth16Wrap).unwrap();
            let worker = worker.clone();
            async move {
                let (task_tx, mut task_rx) = mpsc::unbounded_channel();
                loop {
                    tokio::select! {
                        Some((id, request)) = groth16_wrap_rx.recv() => {
                            let span = tracing::debug_span!("Groth16Wrap", proof_id = %request.context.proof_id, task_id = %id);
                            let worker = worker.clone();
                            let proof_id = request.context.proof_id.clone();
                            let result = worker
                                .prover_engine()
                                .run_groth16(request.clone())
                                .instrument(span)
                                .await
                                .map(|_| TaskMetadata::default());
                            TaskOutput::handle_worker_result(Ok(result), &task_tx, proof_id, id, request, TaskType::Groth16Wrap);
                        }
                        Some(output) = task_rx.recv() => {
                            output.handle_task_output(worker.worker_client()).await;
                        }
                        else => {
                            break;
                        }
                    }
                }
            }
        });

        // Get the verifier, artifact client, and worker client from the worker
        let verifier = worker.verifier().clone();
        let artifact_client = worker.artifact_client().clone();
        let worker_client = worker.worker_client().clone();
        let core = SP1NodeCore::new(verifier, opts);
        // `_tasks` keeps the handler JoinSet alive for the lifetime of the node.
        let inner =
            Arc::new(SP1NodeInner { artifact_client, worker_client, core, _tasks: join_set });
        Ok(SP1LocalNode { inner })
    }
543}
544
/// Outcome of a worker task, sent from the spawned task back to its handler
/// loop so completion / retry / failure can be reported through the worker
/// client (see [`TaskOutput::handle_task_output`]).
struct TaskOutput {
    // Proof the task belongs to.
    proof_id: ProofId,
    // Identifier of the finished task.
    task_id: TaskId,
    // Final status: Succeeded, FailedRetryable, or FailedFatal.
    status: TaskStatus,
    // Metadata reported on success; default-constructed on failure.
    task_metadata: TaskMetadata,
    // Original request, populated only for retryable failures so the task can
    // be resubmitted.
    task_data: Option<RawTaskRequest>,
    // Task type used when resubmitting a retryable failure.
    task_type: TaskType,
}
553
554impl TaskOutput {
555    fn handle_worker_result(
556        result: Result<Result<TaskMetadata, TaskError>, TaskJoinError>,
557        tx: &mpsc::UnboundedSender<TaskOutput>,
558        proof_id: ProofId,
559        task_id: TaskId,
560        request: RawTaskRequest,
561        task_type: TaskType,
562    ) {
563        match result {
564            Ok(Ok(task_metadata)) => {
565                tracing::debug!("task succeeded");
566                let task_output = TaskOutput {
567                    proof_id,
568                    task_id,
569                    status: TaskStatus::Succeeded,
570                    task_metadata,
571                    task_data: None,
572                    task_type,
573                };
574                tx.send(task_output).ok();
575            }
576            Ok(Err(TaskError::Retryable(e))) => {
577                tracing::error!("task failed with retryable error: {:?}", e);
578                let task_output = TaskOutput {
579                    proof_id,
580                    task_id,
581                    status: TaskStatus::FailedRetryable,
582                    task_metadata: TaskMetadata::default(),
583                    task_data: Some(request),
584                    task_type,
585                };
586                tx.send(task_output).ok();
587            }
588            Ok(Err(TaskError::Fatal(e))) => {
589                tracing::error!("task failed with fatal error: {:?}", e);
590                let task_output = TaskOutput {
591                    proof_id,
592                    task_id,
593                    status: TaskStatus::FailedFatal,
594                    task_metadata: TaskMetadata::default(),
595                    task_data: None,
596                    task_type,
597                };
598                tx.send(task_output).ok();
599            }
600            Ok(Err(TaskError::Execution(e))) => {
601                tracing::error!("task failed with fatal error: {:?}", e);
602                let task_output = TaskOutput {
603                    proof_id,
604                    task_id,
605                    status: TaskStatus::FailedFatal,
606                    task_metadata: TaskMetadata::default(),
607                    task_data: None,
608                    task_type,
609                };
610                tx.send(task_output).ok();
611            }
612            Err(e) => {
613                tracing::error!("task panicked: {:?}", e);
614            }
615        }
616    }
617
618    async fn handle_task_output(self, worker_client: &LocalWorkerClient) {
619        let TaskOutput { proof_id, task_id, status, task_metadata, task_data, task_type } = self;
620        match status {
621            TaskStatus::Succeeded => {
622                let result = worker_client
623                    .complete_task(proof_id.clone(), task_id.clone(), task_metadata)
624                    .await;
625                if let Err(e) = result {
626                    tracing::error!(
627                        "Failed to complete task, proof_id: {:?}, task_id: {:?}, error: {:?}",
628                        proof_id,
629                        task_id,
630                        e
631                    );
632                }
633            }
634            TaskStatus::FailedRetryable => {
635                let task = task_data.unwrap();
636                let res = worker_client.submit_task(task_type, task).await;
637                if let Err(e) = res {
638                    tracing::error!("Failed to submit retry, task: {:?}, error: {:?}", task_id, e);
639                }
640            }
641            TaskStatus::FailedFatal => {
642                let res = worker_client
643                    .update_task_status(task_id.clone(), TaskStatus::FailedFatal)
644                    .await;
645                if let Err(e) = res {
646                    tracing::error!("Failed to fail task, task: {:?}, error: {:?}", task_id, e);
647                }
648            }
649            _ => {}
650        }
651    }
652}