1use std::sync::Arc;
2
3use slop_futures::pipeline::TaskJoinError;
4use sp1_hypercube::prover::ProverSemaphore;
5use sp1_prover_types::{
6 ArtifactClient, ArtifactType, InMemoryArtifactClient, TaskStatus, TaskType,
7};
8use tokio::{sync::mpsc, task::JoinSet};
9use tracing::Instrument;
10
11use crate::{
12 worker::{
13 node::SP1NodeCore, run_vk_generation, LocalWorkerClient, LocalWorkerClientChannels,
14 ProofId, RawTaskRequest, SP1LocalNode, SP1NodeInner, SP1WorkerBuilder, TaskError, TaskId,
15 TaskMetadata, WorkerClient,
16 },
17 SP1ProverComponents,
18};
19
/// Builder for an [`SP1LocalNode`]: an in-process prover node whose workers
/// communicate over local channels and keep artifacts in memory.
pub struct SP1LocalNodeBuilder<C: SP1ProverComponents> {
    /// Underlying worker builder, pre-wired with an in-memory artifact client
    /// and a local (channel-based) worker client.
    pub worker_builder: SP1WorkerBuilder<C, InMemoryArtifactClient, LocalWorkerClient>,
    /// Per-task-type receivers; `build` removes each one to spawn its
    /// dispatch loop.
    pub channels: LocalWorkerClientChannels,
}
24
25impl<C: SP1ProverComponents> Default for SP1LocalNodeBuilder<C> {
26 fn default() -> Self {
27 Self::new()
28 }
29}
30
31impl<C: SP1ProverComponents> SP1LocalNodeBuilder<C> {
32 pub fn new() -> Self {
34 Self::from_worker_client_builder(SP1WorkerBuilder::new())
35 }
36
37 pub fn from_worker_client_builder(builder: SP1WorkerBuilder<C>) -> Self {
42 let artifact_client = InMemoryArtifactClient::new();
43 let (worker_client, channels) = LocalWorkerClient::init();
44 let worker_builder =
45 builder.with_artifact_client(artifact_client).with_worker_client(worker_client);
46 Self { worker_builder, channels }
47 }
48
49 pub fn with_core_air_prover(
51 mut self,
52 core_air_prover: Arc<C::CoreProver>,
53 permit: ProverSemaphore,
54 ) -> Self {
55 self.worker_builder = self.worker_builder.with_core_air_prover(core_air_prover, permit);
56 self
57 }
58
59 pub fn with_compress_air_prover(
61 mut self,
62 compress_air_prover: Arc<C::RecursionProver>,
63 permit: ProverSemaphore,
64 ) -> Self {
65 self.worker_builder =
66 self.worker_builder.with_compress_air_prover(compress_air_prover, permit);
67 self
68 }
69
70 pub fn with_shrink_air_prover(
72 mut self,
73 shrink_air_prover: Arc<C::RecursionProver>,
74 permit: ProverSemaphore,
75 ) -> Self {
76 self.worker_builder = self.worker_builder.with_shrink_air_prover(shrink_air_prover, permit);
77 self
78 }
79
80 pub fn with_wrap_air_prover(
82 mut self,
83 wrap_air_prover: C::WrapProverBuilder,
84 permit: ProverSemaphore,
85 ) -> Self {
86 self.worker_builder = self.worker_builder.with_wrap_air_prover(wrap_air_prover, permit);
87 self
88 }
89
    /// Builds the worker, spawns one dispatch loop per [`TaskType`], and
    /// returns the assembled [`SP1LocalNode`].
    ///
    /// Each loop takes ownership of the mpsc receiver for its task type
    /// (removed from `channels`) and runs until that channel closes. All
    /// loops are kept in a `JoinSet` stored on the node (`_tasks`), so they
    /// are tied to the node's lifetime.
    ///
    /// # Errors
    /// Returns an error if the underlying worker fails to build.
    ///
    /// # Panics
    /// Panics if a receiver for any expected task type is missing from
    /// `channels` (the `remove(..).unwrap()` calls), or if submitting work to
    /// the prover engine fails (the `unwrap`s on `submit_*`).
    pub async fn build(self) -> anyhow::Result<SP1LocalNode> {
        let Self { worker_builder, mut channels } = self;
        // Snapshot the core options before the builder is consumed below.
        let opts = worker_builder.core_opts().clone();

        let worker = worker_builder.build().await?;

        let mut join_set = JoinSet::new();

        // Controller: runs the top-level orchestration for a proof request,
        // marks the task complete, then deletes the request's input artifacts.
        join_set.spawn({
            let mut controller_rx = channels.task_receivers.remove(&TaskType::Controller).unwrap();
            let worker = worker.clone();
            async move {
                while let Some((task_id, request)) = controller_rx.recv().await {
                    let span = tracing::debug_span!("Controller", proof_id = %request.context.proof_id, task_id = %task_id);
                    if let Err(e) = worker.controller().run(request.clone()).instrument(span).await
                    {
                        tracing::error!("Controller: task failed: {e:?}");
                    }

                    // The task is reported complete even if `run` failed; the
                    // failure is only logged above.
                    if let Err(e) = worker
                        .worker_client()
                        .complete_task(
                            request.context.proof_id,
                            task_id,
                            TaskMetadata { gpu_ms: None },
                        )
                        .await
                    {
                        tracing::error!("Controller: marking task as complete failed: {e:?}");
                    }

                    // Best-effort cleanup of the request's input artifacts.
                    for input in request.inputs {
                        if let Err(e) = worker
                            .artifact_client()
                            .delete(&input, ArtifactType::UnspecifiedArtifactType)
                            .await
                        {
                            tracing::error!("Controller: deleting input artifact failed: {e:?}");
                        }
                    }
                }
            }
        });

        // CoreExecute: parses the raw request and runs execution on the
        // controller; completes the task regardless of the outcome.
        join_set.spawn({
            let mut execute_rx = channels.task_receivers.remove(&TaskType::CoreExecute).unwrap();
            let worker = worker.clone();
            async move {
                while let Some((task_id, request)) = execute_rx.recv().await {
                    let span = tracing::debug_span!("CoreExecute", proof_id = %request.context.proof_id, task_id = %task_id);
                    let proof_id = request.context.proof_id.clone();
                    match crate::worker::CoreExecuteTaskRequest::from_raw(request.clone()) {
                        Ok(req) => {
                            if let Err(e) =
                                worker.controller().execute(task_id.clone(), req).instrument(span).await
                            {
                                tracing::error!("CoreExecute: task failed: {e:?}");
                            }
                        }
                        Err(e) => {
                            tracing::error!("CoreExecute: failed to parse request: {e:?}");
                        }
                    }

                    if let Err(e) = worker
                        .worker_client()
                        .complete_task(proof_id, task_id, TaskMetadata { gpu_ms: None })
                        .await
                    {
                        tracing::error!("CoreExecute: marking task as complete failed: {e:?}");
                    }
                }
            }
        });

        // SetupVkey: submits setup to the prover engine and funnels each
        // task's result back into this same select loop (via task_tx/task_rx)
        // where it is reported through `handle_task_output`.
        //
        // NOTE(review): `task_set` is spawned into but never joined; completed
        // task results accumulate in the set until it is dropped — confirm
        // this is intended (same pattern in the loops below).
        join_set.spawn({
            let mut setup_rx = channels.task_receivers.remove(&TaskType::SetupVkey).unwrap();
            let worker = worker.clone();
            let worker_client = worker.worker_client().clone();
            async move {
                let mut task_set = JoinSet::new();
                let (task_tx, mut task_rx) = mpsc::unbounded_channel();
                loop {
                    tokio::select! {
                        Some((id, request)) = setup_rx.recv() => {
                            let span = tracing::debug_span!("SetupVkey", proof_id = %request.context.proof_id, task_id = %id);
                            let RawTaskRequest { inputs, outputs, context } = request.clone();
                            let proof_id = context.proof_id.clone();
                            // Convention: input[0] is the ELF, output[0] the vkey destination.
                            let elf = inputs[0].clone();
                            let output = outputs[0].clone();
                            let handle = worker
                                .prover_engine()
                                .submit_setup(id.clone(), elf, output)
                                .instrument(span.clone())
                                .await
                                .unwrap();
                            let tx = task_tx.clone();
                            task_set.spawn(async move {
                                // Drop the setup value, keep only the metadata.
                                let result = handle.await.map(|res| res.map(|(_, metadata)| metadata));
                                TaskOutput::handle_worker_result(result, &tx, proof_id, id, request, TaskType::SetupVkey);
                            }
                            );
                        }

                        Some(output) = task_rx.recv() => {
                            output.handle_task_output(&worker_client).await;
                        }
                        else => {
                            break;
                        }
                    }
                }
            }
        });

        // UtilVkeyMapController: like the Controller loop, but for the vkey
        // map controller entry point. Not wrapped in a tracing span.
        join_set.spawn({
            let mut controller_rx =
                channels.task_receivers.remove(&TaskType::UtilVkeyMapController).unwrap();
            let worker = worker.clone();
            async move {
                while let Some((task_id, request)) = controller_rx.recv().await {
                    if let Err(e) =
                        worker.controller().run_sp1_util_vkey_map_controller(request.clone()).await
                    {
                        tracing::error!("Controller: task failed: {e:?}");
                    }

                    if let Err(e) = worker
                        .worker_client()
                        .complete_task(
                            request.context.proof_id,
                            task_id,
                            TaskMetadata { gpu_ms: None },
                        )
                        .await
                    {
                        tracing::error!("Controller: marking task as complete failed: {e:?}");
                    }

                    for input in request.inputs {
                        if let Err(e) = worker
                            .artifact_client()
                            .delete(&input, ArtifactType::UnspecifiedArtifactType)
                            .await
                        {
                            tracing::error!("Controller: deleting input artifact failed: {e:?}");
                        }
                    }
                }
            }
        });

        // UtilVkeyMapChunk: runs vk chunk generation on the shared vk worker.
        // Only successes are sent back for completion; failures are logged and
        // the task is never completed.
        join_set.spawn({
            let mut core_prover_rx =
                channels.task_receivers.remove(&TaskType::UtilVkeyMapChunk).unwrap();
            let worker = worker.clone();
            let worker_client = worker.worker_client().clone();
            let vk_worker = Arc::new(worker.clone().prover_engine().vk_worker.clone());
            async move {
                let mut task_set = JoinSet::new();
                let (task_tx, mut task_rx) = mpsc::unbounded_channel();

                loop {
                    // Fresh Arc clone per iteration; moved into the spawned task.
                    let vk_worker = vk_worker.clone();
                    tokio::select! {
                        Some((id, request)) = core_prover_rx.recv() => {
                            let proof_id = request.context.proof_id.clone();
                            let handle = run_vk_generation::<_,_>(vk_worker, request, worker.artifact_client().clone());
                            let tx = task_tx.clone();
                            let task_id = id;
                            task_set.spawn(async move {
                                match handle.await {
                                    Ok(()) => {
                                        tx.send((proof_id, task_id, TaskStatus::Succeeded)).ok();
                                    }
                                    Err(e) => {
                                        // No status is sent on failure, so the task
                                        // is never marked complete.
                                        tracing::error!("Failed to generate vk chunk: {:?}", e);
                                    }
                                }
                            });
                        }

                        Some((proof_id, task_id , status)) = task_rx.recv() => {
                            // Only Succeeded is ever sent above.
                            assert_eq!(status, TaskStatus::Succeeded);
                            if let Err(e) = worker_client.complete_task(proof_id, task_id, TaskMetadata { gpu_ms: None }).await {
                                tracing::error!("Failed to complete vk chunk task: {:?}", e);
                            }
                        }
                        else => {
                            break;
                        }
                    }
                }
            }
        });

        // ProveShard: submits a core shard proof to the prover engine;
        // results come back through task_rx for reporting.
        join_set.spawn({
            let mut core_prover_rx = channels.task_receivers.remove(&TaskType::ProveShard).unwrap();
            let worker = worker.clone();
            let worker_client = worker.worker_client().clone();
            async move {
                let mut task_set = JoinSet::new();
                let (task_tx, mut task_rx) = mpsc::unbounded_channel();

                loop {
                    tokio::select! {
                        Some((id, request)) = core_prover_rx.recv() => {
                            let span = tracing::debug_span!("ProveShard", proof_id = %request.context.proof_id, task_id = %id);
                            let proof_id = request.context.proof_id.clone();
                            let handle = worker
                                .prover_engine()
                                .submit_prove_core_shard(
                                    request.clone(),
                                )
                                .instrument(span.clone())
                                .await
                                .unwrap();
                            let tx = task_tx.clone();
                            task_set.spawn(
                                async move {
                                    let result = handle.await;
                                    TaskOutput::handle_worker_result(result, &tx, proof_id, id, request, TaskType::ProveShard);
                                }.instrument(span)
                            );
                        }

                        Some(output) = task_rx.recv() => {
                            output.handle_task_output(&worker_client).await;
                        }
                        else => {
                            break;
                        }
                    }
                }
            }
        });

        // RecursionReduce: same submit/report pattern as ProveShard.
        join_set.spawn({
            let mut recursion_reduce_rx =
                channels.task_receivers.remove(&TaskType::RecursionReduce).unwrap();
            let worker = worker.clone();
            let worker_client = worker.worker_client().clone();
            async move {
                let mut task_set = JoinSet::new();
                let (task_tx, mut task_rx) = mpsc::unbounded_channel();
                loop {
                    tokio::select! {
                        Some((id, request)) = recursion_reduce_rx.recv() => {
                            let span = tracing::debug_span!("RecursionReduce", proof_id = %request.context.proof_id, task_id = %id);
                            let proof_id = request.context.proof_id.clone();
                            let handle = worker
                                .prover_engine()
                                .submit_recursion_reduce(request.clone())
                                .instrument(span.clone())
                                .await
                                .unwrap();
                            let tx = task_tx.clone();
                            task_set.spawn(async move {
                                let result = handle.await;
                                TaskOutput::handle_worker_result(result, &tx, proof_id, id, request, TaskType::RecursionReduce);
                            }.instrument(span)
                            );
                        }

                        Some(output) = task_rx.recv() => {
                            output.handle_task_output(&worker_client).await;
                        }
                        else => {
                            break;
                        }
                    }
                }
            }
        });

        // RecursionDeferred: same submit/report pattern as ProveShard.
        join_set.spawn({
            let mut recursion_deferred_rx =
                channels.task_receivers.remove(&TaskType::RecursionDeferred).unwrap();
            let worker = worker.clone();
            let worker_client = worker.worker_client().clone();
            async move {
                let mut task_set = JoinSet::new();
                let (task_tx, mut task_rx) = mpsc::unbounded_channel();
                loop {
                    tokio::select! {
                        Some((id, request)) = recursion_deferred_rx.recv() => {
                            let span = tracing::debug_span!("RecursionDeferred", proof_id = %request.context.proof_id, task_id = %id);
                            let proof_id = request.context.proof_id.clone();
                            let handle = worker
                                .prover_engine()
                                .submit_prove_deferred(request.clone())
                                .instrument(span.clone())
                                .await
                                .unwrap();
                            let tx = task_tx.clone();
                            task_set.spawn(async move {
                                let result = handle.await;
                                TaskOutput::handle_worker_result(result, &tx, proof_id, id, request, TaskType::RecursionDeferred);
                            }.instrument(span)
                            );
                        }
                        Some(output) = task_rx.recv() => {
                            output.handle_task_output(&worker_client).await;
                        }
                        else => {
                            break;
                        }
                    }
                }
            }
        });

        // MarkerDeferredRecord: intentionally a no-op drain — messages are
        // received and discarded so the channel never backs up.
        join_set.spawn({
            let mut marker_deferred_task_rx =
                channels.task_receivers.remove(&TaskType::MarkerDeferredRecord).unwrap();
            async move { while let Some((_task_id, _request)) = marker_deferred_task_rx.recv().await {} }
        });

        // ShrinkWrap: runs inline in the select arm (no inner spawn), so only
        // one shrink-wrap task is in flight at a time in this loop.
        join_set.spawn({
            let mut shrink_wrap_rx = channels.task_receivers.remove(&TaskType::ShrinkWrap).unwrap();
            let worker = worker.clone();
            let worker_client = worker.worker_client().clone();
            async move {
                let (task_tx, mut task_rx) = mpsc::unbounded_channel();
                loop {
                    tokio::select! {
                        Some((id, request)) = shrink_wrap_rx.recv() => {
                            let span = tracing::debug_span!("ShrinkWrap", proof_id = %request.context.proof_id, task_id = %id);
                            let worker = worker.clone();
                            let proof_id = request.context.proof_id.clone();
                            // `Ok(..)` marks the outer layer as "did not panic";
                            // the inner Result carries the task outcome.
                            let result = worker
                                .prover_engine()
                                .run_shrink_wrap(request.clone())
                                .instrument(span)
                                .await
                                .map(|_| TaskMetadata::default());
                            TaskOutput::handle_worker_result(Ok(result), &task_tx, proof_id, id, request, TaskType::ShrinkWrap);
                        }
                        Some(output) = task_rx.recv() => {
                            output.handle_task_output(&worker_client).await;
                        }
                        else => {
                            break;
                        }
                    }
                }
            }
        });

        // PlonkWrap: same inline pattern as ShrinkWrap.
        join_set.spawn({
            let mut plonk_wrap_rx = channels.task_receivers.remove(&TaskType::PlonkWrap).unwrap();
            let worker = worker.clone();
            let worker_client = worker.worker_client().clone();
            async move {
                let (task_tx, mut task_rx) = mpsc::unbounded_channel();
                loop {
                    tokio::select! {
                        Some((id, request)) = plonk_wrap_rx.recv() => {
                            let span = tracing::debug_span!("PlonkWrap", proof_id = %request.context.proof_id, task_id = %id);
                            let worker = worker.clone();
                            let proof_id = request.context.proof_id.clone();
                            let result = worker
                                .prover_engine()
                                .run_plonk(request.clone())
                                .instrument(span)
                                .await
                                .map(|_| TaskMetadata::default());
                            TaskOutput::handle_worker_result(Ok(result), &task_tx, proof_id, id, request, TaskType::PlonkWrap);
                        }
                        Some(output) = task_rx.recv() => {
                            output.handle_task_output(&worker_client).await;
                        }
                        else => {
                            break;
                        }
                    }
                }
            }
        });

        // Groth16Wrap: same inline pattern; uses `worker.worker_client()`
        // directly instead of a cloned handle.
        join_set.spawn({
            let mut groth16_wrap_rx =
                channels.task_receivers.remove(&TaskType::Groth16Wrap).unwrap();
            let worker = worker.clone();
            async move {
                let (task_tx, mut task_rx) = mpsc::unbounded_channel();
                loop {
                    tokio::select! {
                        Some((id, request)) = groth16_wrap_rx.recv() => {
                            let span = tracing::debug_span!("Groth16Wrap", proof_id = %request.context.proof_id, task_id = %id);
                            let worker = worker.clone();
                            let proof_id = request.context.proof_id.clone();
                            let result = worker
                                .prover_engine()
                                .run_groth16(request.clone())
                                .instrument(span)
                                .await
                                .map(|_| TaskMetadata::default());
                            TaskOutput::handle_worker_result(Ok(result), &task_tx, proof_id, id, request, TaskType::Groth16Wrap);
                        }
                        Some(output) = task_rx.recv() => {
                            output.handle_task_output(worker.worker_client()).await;
                        }
                        else => {
                            break;
                        }
                    }
                }
            }
        });

        // Assemble the node; the JoinSet keeps every dispatch loop alive for
        // the node's lifetime.
        let verifier = worker.verifier().clone();
        let artifact_client = worker.artifact_client().clone();
        let worker_client = worker.worker_client().clone();
        let core = SP1NodeCore::new(verifier, opts);
        let inner =
            Arc::new(SP1NodeInner { artifact_client, worker_client, core, _tasks: join_set });
        Ok(SP1LocalNode { inner })
    }
543}
544
/// Result of a finished worker task, sent from a spawned task back to its
/// dispatch loop for reporting through the worker client.
struct TaskOutput {
    proof_id: ProofId,
    task_id: TaskId,
    // Final status; determines how `handle_task_output` reports the task.
    status: TaskStatus,
    // Metadata from a successful run; defaulted on failure.
    task_metadata: TaskMetadata,
    // Original request, kept only for retryable failures so they can be
    // resubmitted.
    task_data: Option<RawTaskRequest>,
    task_type: TaskType,
}
553
554impl TaskOutput {
555 fn handle_worker_result(
556 result: Result<Result<TaskMetadata, TaskError>, TaskJoinError>,
557 tx: &mpsc::UnboundedSender<TaskOutput>,
558 proof_id: ProofId,
559 task_id: TaskId,
560 request: RawTaskRequest,
561 task_type: TaskType,
562 ) {
563 match result {
564 Ok(Ok(task_metadata)) => {
565 tracing::debug!("task succeeded");
566 let task_output = TaskOutput {
567 proof_id,
568 task_id,
569 status: TaskStatus::Succeeded,
570 task_metadata,
571 task_data: None,
572 task_type,
573 };
574 tx.send(task_output).ok();
575 }
576 Ok(Err(TaskError::Retryable(e))) => {
577 tracing::error!("task failed with retryable error: {:?}", e);
578 let task_output = TaskOutput {
579 proof_id,
580 task_id,
581 status: TaskStatus::FailedRetryable,
582 task_metadata: TaskMetadata::default(),
583 task_data: Some(request),
584 task_type,
585 };
586 tx.send(task_output).ok();
587 }
588 Ok(Err(TaskError::Fatal(e))) => {
589 tracing::error!("task failed with fatal error: {:?}", e);
590 let task_output = TaskOutput {
591 proof_id,
592 task_id,
593 status: TaskStatus::FailedFatal,
594 task_metadata: TaskMetadata::default(),
595 task_data: None,
596 task_type,
597 };
598 tx.send(task_output).ok();
599 }
600 Ok(Err(TaskError::Execution(e))) => {
601 tracing::error!("task failed with fatal error: {:?}", e);
602 let task_output = TaskOutput {
603 proof_id,
604 task_id,
605 status: TaskStatus::FailedFatal,
606 task_metadata: TaskMetadata::default(),
607 task_data: None,
608 task_type,
609 };
610 tx.send(task_output).ok();
611 }
612 Err(e) => {
613 tracing::error!("task panicked: {:?}", e);
614 }
615 }
616 }
617
618 async fn handle_task_output(self, worker_client: &LocalWorkerClient) {
619 let TaskOutput { proof_id, task_id, status, task_metadata, task_data, task_type } = self;
620 match status {
621 TaskStatus::Succeeded => {
622 let result = worker_client
623 .complete_task(proof_id.clone(), task_id.clone(), task_metadata)
624 .await;
625 if let Err(e) = result {
626 tracing::error!(
627 "Failed to complete task, proof_id: {:?}, task_id: {:?}, error: {:?}",
628 proof_id,
629 task_id,
630 e
631 );
632 }
633 }
634 TaskStatus::FailedRetryable => {
635 let task = task_data.unwrap();
636 let res = worker_client.submit_task(task_type, task).await;
637 if let Err(e) = res {
638 tracing::error!("Failed to submit retry, task: {:?}, error: {:?}", task_id, e);
639 }
640 }
641 TaskStatus::FailedFatal => {
642 let res = worker_client
643 .update_task_status(task_id.clone(), TaskStatus::FailedFatal)
644 .await;
645 if let Err(e) = res {
646 tracing::error!("Failed to fail task, task: {:?}, error: {:?}", task_id, e);
647 }
648 }
649 _ => {}
650 }
651 }
652}