// sp1_prover/worker/prover/core.rs
1use std::sync::Arc;
2
3use anyhow::anyhow;
4use slop_algebra::AbstractField;
5use slop_futures::pipeline::{AsyncEngine, AsyncWorker, Pipeline, SubmitError, SubmitHandle};
6use sp1_core_executor::{
7    events::{PrecompileEvent, SyscallEvent},
8    ExecutionRecord, Program, SP1CoreOpts, SplitOpts,
9};
10use sp1_core_machine::{executor::trace_chunk, riscv::RiscvAir};
11use sp1_hypercube::{
12    prover::{shape_from_record, CoreProofShape, ProverSemaphore, ProvingKey},
13    Machine, MachineProof, MachineVerifier, SP1VerifyingKey,
14};
15use sp1_jit::TraceChunk;
16use sp1_primitives::{SP1Field, SP1GlobalContext};
17use sp1_prover_types::{
18    await_scoped_vec, network_base_types::ProofMode, Artifact, ArtifactClient, ArtifactType,
19};
20use sp1_recursion_circuit::shard::RecursiveShardVerifier;
21use sp1_recursion_compiler::config::InnerConfig;
22use sp1_recursion_executor::RecursionProgram;
23use tokio::sync::OnceCell;
24use tracing::Instrument;
25
26use crate::{
27    recursion::{normalize_program_from_input, recursive_verifier},
28    shapes::{SP1NormalizeCache, SP1NormalizeInputShape, SP1RecursionProofShape},
29    worker::{
30        AirProverWorker, CommonProverInput, DeferredEvents, GlobalMemoryShard,
31        PrecompileArtifactSlice, ProofId, ProverMetrics, RawTaskRequest, SP1RecursionProver,
32        TaskContext, TaskError, TaskId, TaskMetadata, TraceData, WorkerClient,
33    },
34    CoreSC, SP1CircuitWitness, SP1ProverComponents,
35};
36
/// Input for a one-off setup task: disassemble an ELF, derive its keys, and
/// upload the resulting verifying key (see the `AsyncWorker<SetupTask, _>` impl).
pub struct SetupTask {
    /// The unique id of this task.
    pub id: TaskId,
    /// The ELF program artifact to run setup for.
    pub elf: Artifact,
    /// The artifact the resulting verifying key is uploaded to.
    pub output: Artifact,
}
42
/// A structured view of a raw prove-shard task request.
///
/// The wire format is positional: inputs are `[elf, common_input, record,
/// deferred_marker_task]` and outputs are `[output, deferred_output]`
/// (see `from_raw` / `into_raw`).
pub struct ProveShardTaskRequest {
    /// The elf artifact.
    pub elf: Artifact,
    /// The common input artifact.
    pub common_input: Artifact,
    /// The record artifact.
    pub record: Artifact,
    /// The traces output artifact.
    pub output: Artifact,
    /// The deferred marker task id.
    pub deferred_marker_task: Artifact,
    /// The deferred output artifact.
    pub deferred_output: Artifact,
    /// The task context.
    pub context: TaskContext,
}
59
60impl ProveShardTaskRequest {
61    pub fn from_raw(request: RawTaskRequest) -> Result<Self, TaskError> {
62        let RawTaskRequest { inputs, outputs, context } = request;
63        let elf = inputs[0].clone();
64        let common_input = inputs[1].clone();
65        let record = inputs[2].clone();
66        let deferred_marker_task = inputs[3].clone();
67
68        let output = outputs[0].clone();
69        let deferred_output = outputs[1].clone();
70
71        Ok(ProveShardTaskRequest {
72            elf,
73            common_input,
74            record,
75            output,
76            deferred_marker_task,
77            deferred_output,
78            context,
79        })
80    }
81
82    pub fn into_raw(self) -> Result<RawTaskRequest, TaskError> {
83        let ProveShardTaskRequest {
84            elf,
85            common_input,
86            record,
87            output,
88            deferred_marker_task,
89            deferred_output,
90            context,
91        } = self;
92
93        let inputs = vec![elf, common_input, record, deferred_marker_task];
94        let outputs = vec![output, deferred_output];
95        let raw_task_request = RawTaskRequest { inputs, outputs, context };
96        Ok(raw_task_request)
97    }
98}
99
100/// Generates traces and optionally deferred records for a core shard.
/// Unified task handled by [`CoreWorker`]: traces a shard record, proves it,
/// and (depending on the proof mode) produces the normalized recursion proof
/// and/or the deferred events record.
pub struct CoreProvingTask {
    /// The proof id.
    pub proof_id: ProofId,
    /// The elf artifact.
    pub elf: Artifact,
    /// The common input artifact.
    pub common_input: Artifact,
    /// The record artifact.
    pub record: Artifact,
    /// The traces output artifact.
    pub output: Artifact,
    /// The deferred marker task id.
    pub deferred_marker_task: Artifact,
    /// The deferred output artifact.
    pub deferred_output: Artifact,
    /// The metrics for the prover.
    pub metrics: ProverMetrics,
}
119
/// Compiles recursion programs that normalize core shard proofs, with a
/// shape-keyed cache to avoid recompiling for repeated proof shapes.
struct NormalizeProgramCompiler {
    /// Cache of compiled normalize programs, keyed by `SP1NormalizeInputShape`.
    cache: SP1NormalizeCache,
    /// Recursive verifier circuit for core (RISC-V) shard proofs.
    recursive_verifier: RecursiveShardVerifier<SP1GlobalContext, RiscvAir<SP1Field>, InnerConfig>,
    /// Shape assigned to every compiled program (see `get_program`).
    reduce_shape: SP1RecursionProofShape,
    /// Core machine verifier; supplies the shape parameters and the machine.
    verifier: MachineVerifier<SP1GlobalContext, CoreSC>,
}
126
127impl NormalizeProgramCompiler {
128    pub fn new(
129        cache: SP1NormalizeCache,
130        recursive_verifier: RecursiveShardVerifier<
131            SP1GlobalContext,
132            RiscvAir<SP1Field>,
133            InnerConfig,
134        >,
135
136        reduce_shape: SP1RecursionProofShape,
137        machine_verifier: MachineVerifier<SP1GlobalContext, CoreSC>,
138    ) -> Self {
139        Self { cache, recursive_verifier, reduce_shape, verifier: machine_verifier }
140    }
141
142    pub fn machine(&self) -> &Machine<SP1Field, RiscvAir<SP1Field>> {
143        self.verifier.machine()
144    }
145
146    pub fn get_program(
147        &self,
148        vk: SP1VerifyingKey,
149        proof_shape: &CoreProofShape<SP1Field, RiscvAir<SP1Field>>,
150    ) -> Arc<RecursionProgram<SP1Field>> {
151        let shape = SP1NormalizeInputShape {
152            proof_shapes: vec![proof_shape.clone()],
153            max_log_row_count: self.verifier.max_log_row_count(),
154            log_blowup: self.verifier.fri_config().log_blowup,
155            log_stacking_height: self.verifier.log_stacking_height() as usize,
156        };
157        if let Some(program) = self.cache.get(&shape) {
158            return program.clone();
159        }
160
161        let input = shape.dummy_input(vk);
162        let mut program = normalize_program_from_input(&self.recursive_verifier, &input);
163        program.shape = Some(self.reduce_shape.shape.clone());
164        let program = Arc::new(program);
165        self.cache.push(shape, program.clone());
166        program
167    }
168}
169
170/// Unified worker that combines tracing, core proving, and normalize proving.
171pub struct CoreWorker<A, W, C: SP1ProverComponents> {
172    normalize_program_compiler: Arc<NormalizeProgramCompiler>,
173    opts: SP1CoreOpts,
174    artifact_client: A,
175    worker_client: W,
176    core_prover: Arc<C::CoreProver>,
177    recursion_prover: SP1RecursionProver<A, C>,
178    permits: ProverSemaphore,
179    /// Optional fixed PK cache shared across workers.
180    pk: Option<CoreProvingKeyCache<C>>,
181    verify_intermediates: bool,
182    dump_shard_dir: Option<String>,
183}
184
185impl<A, W, C: SP1ProverComponents> CoreWorker<A, W, C> {
186    #[allow(clippy::too_many_arguments)]
187    fn new(
188        normalize_program_compiler: Arc<NormalizeProgramCompiler>,
189        opts: SP1CoreOpts,
190        artifact_client: A,
191        worker_client: W,
192        core_prover: Arc<C::CoreProver>,
193        recursion_prover: SP1RecursionProver<A, C>,
194        permits: ProverSemaphore,
195        pk: Option<CoreProvingKeyCache<C>>,
196        verify_intermediates: bool,
197        dump_shard_dir: Option<String>,
198    ) -> Self {
199        Self {
200            normalize_program_compiler,
201            opts,
202            artifact_client,
203            worker_client,
204            core_prover,
205            recursion_prover,
206            permits,
207            pk,
208            verify_intermediates,
209            dump_shard_dir,
210        }
211    }
212
213    fn machine(&self) -> &Machine<SP1Field, RiscvAir<SP1Field>> {
214        self.normalize_program_compiler.machine()
215    }
216}
217
impl<A, W, C> AsyncWorker<CoreProvingTask, Result<TaskMetadata, TaskError>> for CoreWorker<A, W, C>
where
    A: ArtifactClient,
    W: WorkerClient,
    C: SP1ProverComponents,
{
    /// Runs a core proving task end-to-end:
    /// 1. downloads the ELF, common input, and record data,
    /// 2. reconstructs the `ExecutionRecord` (core chunk, global-memory shard,
    ///    or precompile events, depending on the `TraceData` variant),
    /// 3. proves the shard (optionally reusing a shared fixed proving key),
    /// 4. for non-Core proof modes, runs the normalize recursion proof;
    ///    otherwise uploads the shard proof directly,
    /// 5. cleans up the record artifact and precompile references.
    ///
    /// Deferred events (core shards only) are uploaded on a background task
    /// that is joined before returning.
    async fn call(&self, input: CoreProvingTask) -> Result<TaskMetadata, TaskError> {
        // === Phase 1: Tracing ===
        // Save the trace input artifact for later use in the task
        let record_artifact = input.record.clone();
        let metrics = input.metrics.clone();

        // Ok to panic because it will send a JoinError.
        let (elf, common_input, record) = tokio::try_join!(
            self.artifact_client.download_program(&input.elf),
            self.artifact_client.download::<CommonProverInput>(&input.common_input),
            self.artifact_client.download::<TraceData>(&input.record),
        )?;

        // Extract precompile artifacts before moving input
        // (needed later to drop the task's references on success).
        let precompile_artifacts = if let TraceData::Precompile(ref artifacts, _) = record {
            Some(artifacts.clone())
        } else {
            None
        };

        // Record reconstruction is CPU-heavy, so it runs on a blocking thread.
        let span = tracing::debug_span!("into_record");
        let (program, mut record, deferred_record, is_precompile) = tokio::task::spawn_blocking({
            let artifact_client = self.artifact_client.clone();
            let opts = self.opts.clone();
            move || {
                let _guard = span.enter();
                {
                    let program = Program::from(&elf).map_err(|e| {
                        TaskError::Fatal(anyhow::anyhow!("failed to disassemble program: {}", e))
                    })?;
                    let program = Arc::new(program);
                    let (record, deferred_record, is_precompile) = match record {
                        // A regular execution chunk: re-trace it into a record
                        // and split off the deferred events.
                        TraceData::Core(chunk_bytes) => {
                            let chunk: TraceChunk =
                                bincode::deserialize(&chunk_bytes).map_err(|e| {
                                    TaskError::Fatal(anyhow::anyhow!(
                                        "failed to deserialize chunk: {}",
                                        e
                                    ))
                                })?;
                            tracing::debug!(
                                "tracing chunk at clk range: {}..{}",
                                chunk.clk_start,
                                chunk.clk_end
                            );
                            // Here, we reserve 1/8 of the shard size for common events. In other words,
                            // we assume that no event will take up more than 1/8 of the shard's events.
                            let record = tracing::debug_span!("allocating record").in_scope(|| {
                                ExecutionRecord::new_preallocated(
                                    program.clone(),
                                    common_input.nonce,
                                    opts.global_dependencies_opt,
                                    opts.shard_size >> 3,
                                )
                            });
                            let (_, mut record, _) = trace_chunk::<SP1Field>(
                                program.clone(),
                                opts.clone(),
                                chunk,
                                common_input.nonce,
                                record,
                            )
                            .map_err(|e| {
                                TaskError::Fatal(anyhow::anyhow!("failed to trace chunk: {}", e))
                            })?;

                            let deferred_record = record.defer(&opts.retained_events_presets);

                            (record, Some(deferred_record), false)
                        }
                        // A global-memory shard: build the record directly from
                        // the init/finalize events and oracle addresses.
                        TraceData::Memory(shard) => {
                            tracing::debug!("global memory shard");
                            let GlobalMemoryShard {
                                final_state,
                                initialize_events,
                                finalize_events,
                                previous_init_addr,
                                previous_finalize_addr,
                                previous_init_page_idx,
                                previous_finalize_page_idx,
                                last_init_addr,
                                last_finalize_addr,
                                last_init_page_idx,
                                last_finalize_page_idx,
                            } = *shard;
                            let mut record = ExecutionRecord::new(
                                program.clone(),
                                common_input.nonce,
                                opts.global_dependencies_opt,
                            );
                            record.global_memory_initialize_events = initialize_events;
                            record.global_memory_finalize_events = finalize_events;

                            let enable_untrusted_programs =
                                common_input.vk.vk.enable_untrusted_programs == SP1Field::one();

                            // Update the public values
                            record.public_values.update_finalized_state(
                                final_state.timestamp,
                                final_state.pc,
                                final_state.exit_code,
                                enable_untrusted_programs as u32,
                                final_state.public_value_digest,
                                common_input.deferred_digest,
                                final_state.proof_nonce,
                            );
                            // Update previous init and finalize addresses and page indices from the
                            // oracle values received from the controller.
                            record.public_values.previous_init_addr = previous_init_addr;
                            record.public_values.previous_finalize_addr = previous_finalize_addr;
                            record.public_values.previous_init_page_idx = previous_init_page_idx;
                            record.public_values.previous_finalize_page_idx =
                                previous_finalize_page_idx;

                            // Update last init and finalize addresses and page indices from the
                            // events of the shard.
                            record.public_values.last_init_addr = last_init_addr;
                            record.public_values.last_finalize_addr = last_finalize_addr;
                            record.public_values.last_init_page_idx = last_init_page_idx;
                            record.public_values.last_finalize_page_idx = last_finalize_page_idx;

                            record.finalize_public_values::<SP1Field>(false);
                            (record, None, false)
                        }
                        // A precompile shard: gather event slices from several
                        // artifacts into a single record for syscall `code`.
                        TraceData::Precompile(artifacts, code) => {
                            tracing::debug!("precompile events: code {}", code);
                            let mut main_record = ExecutionRecord::new(
                                program.clone(),
                                common_input.nonce,
                                opts.global_dependencies_opt,
                            );

                            // [start, end)
                            // Pre-compute the total event count so the events
                            // vector can be allocated once.
                            let mut total_events = 0;
                            let mut indices = Vec::new();
                            for artifact_slice in artifacts.iter() {
                                let PrecompileArtifactSlice { start_idx, end_idx, .. } =
                                    artifact_slice;
                                indices.push(total_events);
                                total_events += end_idx - start_idx;
                            }

                            main_record
                                .precompile_events
                                .events
                                .insert(code, Vec::with_capacity(total_events));

                            // Download all artifacts at once.
                            let mut futures = Vec::new();
                            for artifact_slice in &artifacts {
                                let PrecompileArtifactSlice { artifact, .. } = artifact_slice;
                                let client = artifact_client.clone();
                                futures.push(async move {
                                    client
                                        .download::<Vec<(SyscallEvent, PrecompileEvent)>>(artifact)
                                        .await
                                });
                            }

                            // TODO: Better error handling here?
                            // NOTE(review): block_on inside spawn_blocking drives these
                            // downloads on this blocking thread; assumes the artifact
                            // client futures don't require the tokio reactor — confirm.
                            let results = futures::executor::block_on(await_scoped_vec(futures))
                                .map_err(|e| {
                                    TaskError::Fatal(anyhow::anyhow!(
                                        "failed to download precompile events: {}",
                                        e
                                    ))
                                })?;

                            for (i, events) in results.into_iter().enumerate() {
                                // TODO: unwrap
                                let events = events.unwrap();
                                // Keep only this task's [start_idx, end_idx) slice
                                // of the downloaded events.
                                let PrecompileArtifactSlice { start_idx, end_idx, .. } =
                                    artifacts[i];
                                main_record
                                    .precompile_events
                                    .events
                                    .get_mut(&code)
                                    .unwrap()
                                    .append(
                                        &mut events
                                            .into_iter()
                                            .skip(start_idx)
                                            .take(end_idx - start_idx)
                                            .collect(),
                                    );
                            }

                            // Set the precompile shard's public values to the initialized state.
                            main_record.public_values.update_initialized_state(
                                program.pc_start_abs,
                                program.enable_untrusted_programs,
                            );

                            (main_record, None, true)
                        }
                    };

                    Ok::<_, TaskError>((program, record, deferred_record, is_precompile))
                }
            }
        })
        .await
        .map_err(|e| TaskError::Fatal(e.into()))??;

        // Asynchronously upload the deferred record
        // (only present for TraceData::Core shards); the handle is joined at
        // the end of this function so failures still fail the task.
        let deferred_upload_handle = deferred_record.map(|deferred_record| {
            let artifact_client = self.artifact_client.clone();
            let worker_client = self.worker_client.clone();
            let output_artifact = input.deferred_output.clone();
            let deferred_marker_task = TaskId::new(input.deferred_marker_task.clone().to_id());
            let opts = self.opts.clone();
            let program = program.clone();
            tokio::spawn(
                async move {
                    // SplitOpts::new() parses JSON and builds lookup tables - run in spawn_blocking
                    let program_len = program.instructions.len();
                    let split_opts = tokio::task::spawn_blocking(move || {
                        SplitOpts::new(&opts, program_len, false)
                    })
                    .await
                    .map_err(|e| TaskError::Fatal(e.into()))?;
                    let deferred_data =
                        DeferredEvents::defer_record(deferred_record, &artifact_client, split_opts)
                            .await?;

                    artifact_client.upload(&output_artifact, &deferred_data).await?;
                    worker_client
                        .complete_task(
                            input.proof_id,
                            deferred_marker_task,
                            TaskMetadata::default(),
                        )
                        .await?;
                    Ok::<(), TaskError>(())
                }
                .instrument(tracing::debug_span!("deferred upload")),
            )
        });

        // Generate dependencies on the main record.
        let span = tracing::debug_span!("generate dependencies");
        let machine_clone = self.machine().clone();
        let record = tokio::task::spawn_blocking(move || {
            let _guard = span.enter();
            let record_iter = std::iter::once(&mut record);
            machine_clone.generate_dependencies(record_iter, None);
            record
        })
        .await
        .map_err(|e| TaskError::Fatal(e.into()))?;

        // Optionally dump the shard record and vk to disk for benchmarking/replay.
        if let Some(dir) = self.dump_shard_dir.as_ref() {
            use std::sync::atomic::{AtomicUsize, Ordering};
            // Process-wide counter so concurrent workers get distinct filenames.
            static SHARD_IDX: AtomicUsize = AtomicUsize::new(0);
            let idx = SHARD_IDX.fetch_add(1, Ordering::SeqCst);
            let path = std::path::PathBuf::from(&dir);
            std::fs::create_dir_all(&path).ok();

            let record_bytes = bincode::serialize(&record).expect("failed to serialize record");
            std::fs::write(path.join(format!("record_{idx:04}.bin")), &record_bytes)
                .expect("failed to write record");

            let vk_bytes = bincode::serialize(&common_input.vk.vk).expect("failed to serialize vk");
            std::fs::write(path.join(format!("vk_{idx:04}.bin")), &vk_bytes)
                .expect("failed to write vk");

            tracing::info!(
                "Dumped shard {idx} record ({} bytes) and vk to {dir}",
                record_bytes.len()
            );
        }

        // If this is not a Core proof request, spawn a task to get the recursion program.
        let span = tracing::debug_span!("get recursion program");
        let recursion_program_handle = if common_input.mode != ProofMode::Core {
            let handle = tokio::task::spawn_blocking({
                let normalize_program_compiler = self.normalize_program_compiler.clone();
                let vk = common_input.vk.clone();
                // The shape lookup runs here (before the closure) so its error
                // propagates from `call` via `?` rather than from the blocking task.
                let shape = shape_from_record(&normalize_program_compiler.verifier, &record)
                    .ok_or_else(|| {
                        TaskError::Fatal(anyhow::anyhow!("failed to get shape from record"))
                    })?;
                move || {
                    let _guard = span.enter();
                    normalize_program_compiler.get_program(vk, &shape)
                }
            });
            Some(handle)
        } else {
            None
        };

        // === Phase 2: Core Proving ===
        let permits = self.permits.clone();

        let (proof, permit) = if let Some(pk_cache) = &self.pk {
            // We have a fixed PK cache - use get_or_init to ensure only one worker does setup
            let pk = pk_cache
                .get_or_init(|| async {
                    tracing::info!("Initializing fixed PK cache");
                    let (pk, _vk) = self
                        .core_prover
                        .setup(program.clone(), permits.clone())
                        .instrument(tracing::debug_span!("core setup"))
                        .await;
                    pk
                })
                .await;

            tracing::debug!("Using fixed PK");
            self.core_prover
                .prove_shard_with_pk(pk.clone(), record, permits)
                .instrument(tracing::debug_span!("core prove with pk"))
                .await
        } else {
            // No fixed PK cache - always do setup and prove
            let (_, proof, permit) = self
                .core_prover
                .setup_and_prove_shard(
                    program.clone(),
                    record,
                    Some(common_input.vk.vk.clone()),
                    permits,
                )
                .instrument(tracing::debug_span!("core setup and prove"))
                .await;
            (proof, permit)
        };
        // Release the permit and update the metrics
        let duration = permit.release();
        metrics.increment_permit_time(duration);

        let vk_clone = common_input.vk.vk.clone();
        let proof_clone = proof.clone();

        // Optional self-check: verify the shard proof before passing it on.
        // Verification failures are retryable, not fatal.
        if self.verify_intermediates {
            let parent = tracing::Span::current();
            tokio::task::spawn_blocking(move || {
                let _guard = parent.enter();
                let machine_proof = MachineProof::from(vec![proof_clone]);
                C::core_verifier()
                    .verify(&vk_clone, &machine_proof)
                    .map_err(|e| TaskError::Retryable(anyhow!("shard verification failed: {e}")))
            })
            .await
            .map_err(|e| TaskError::Fatal(e.into()))??;
        }

        let output = input.output;
        if common_input.mode != ProofMode::Core {
            // Non-Core modes: run the normalize recursion proof; the recursion
            // prover writes the result to `output`.
            let recursion_program = recursion_program_handle
                .ok_or_else(|| {
                    TaskError::Fatal(anyhow::anyhow!("recursion program handle not found"))
                })?
                .await
                .map_err(|e| TaskError::Fatal(e.into()))?;
            let input = self.recursion_prover.get_normalize_witness(
                &common_input,
                &proof,
                false,
                is_precompile,
            );
            let witness = SP1CircuitWitness::Core(input);
            self.recursion_prover
                .submit_prove_shard(recursion_program, witness, output, metrics.clone())
                .instrument(tracing::debug_span!("normalize prove shard"))
                .await?
                .await
                .map_err(|e| TaskError::Fatal(e.into()))??;
        } else {
            // Upload the proof
            self.artifact_client.upload(&output, proof).await?;
        }

        // Remove the record artifact since it is no longer needed
        self.artifact_client
            .try_delete(&record_artifact, ArtifactType::UnspecifiedArtifactType)
            .await?;

        // Remove task reference for precompile artifacts only at successful completion
        // (errors are ignored deliberately — best-effort cleanup).
        if let Some(artifacts) = precompile_artifacts {
            for range in artifacts {
                let PrecompileArtifactSlice { artifact, start_idx, end_idx } = range;
                let _ = self
                    .artifact_client
                    .remove_ref(
                        &artifact,
                        ArtifactType::UnspecifiedArtifactType,
                        &format!("{}_{}", start_idx, end_idx),
                    )
                    .await;
            }
        }

        // Join the deferred upload so its failures fail this task.
        if let Some(deferred_upload_handle) = deferred_upload_handle {
            deferred_upload_handle.await.map_err(|e| TaskError::Fatal(e.into()))??;
        }

        // Get the metadata
        let metadata = metrics.to_metadata();
        Ok(metadata)
    }
}
628
/// The proving key type for the core prover of a given component set.
pub type CoreProvingKey<C> =
    ProvingKey<SP1GlobalContext, CoreSC, <C as SP1ProverComponents>::CoreProver>;

/// The Core Proving Key cache is initialized once and shared across all CoreAndNormalizeWorkers.
pub type CoreProvingKeyCache<C> = Arc<OnceCell<Arc<CoreProvingKey<C>>>>;
634
/// Worker for handling setup tasks only.
pub struct CoreAndNormalizeWorker<A, C: SP1ProverComponents> {
    /// Client used to download the ELF and upload the verifying key.
    artifact_client: A,
    /// The prover whose `setup` derives the keys.
    core_prover: Arc<C::CoreProver>,
    /// Semaphore bounding concurrent setup work.
    permits: ProverSemaphore,
    // NOTE(review): `C` already appears in `core_prover`, so this marker looks
    // redundant — confirm before removing (would change the struct literal).
    _marker: std::marker::PhantomData<C>,
}
642
643impl<A, C: SP1ProverComponents> CoreAndNormalizeWorker<A, C> {
644    pub fn new(
645        artifact_client: A,
646        core_prover: Arc<C::CoreProver>,
647        permits: ProverSemaphore,
648    ) -> Self {
649        Self { artifact_client, core_prover, permits, _marker: std::marker::PhantomData }
650    }
651}
652
653impl<A: ArtifactClient, C: SP1ProverComponents>
654    AsyncWorker<SetupTask, Result<(TaskId, TaskMetadata), TaskError>>
655    for CoreAndNormalizeWorker<A, C>
656{
657    async fn call(&self, input: SetupTask) -> Result<(TaskId, TaskMetadata), TaskError> {
658        let SetupTask { id, elf, output } = input;
659
660        let elf = self.artifact_client.download_program(&elf).await?;
661
662        let program = Program::from(&elf)?;
663        let program = Arc::new(program);
664
665        let permits = self.permits.clone();
666        let (_pk, vk) = self.core_prover.setup(program, permits).await;
667        tracing::debug!("setup completed for task {}", id);
668
669        // Upload the vk
670        self.artifact_client.upload(&output, vk).await.expect("failed to upload vk");
671        tracing::debug!("upload completed for artifact {}", output.to_id());
672
673        // TODO: Add the busy time here.
674        Ok((id, TaskMetadata::default()))
675    }
676}
677
/// Shared engine that dispatches setup tasks to `CoreAndNormalizeWorker`s.
pub type SetupEngine<A, P> = Arc<
    AsyncEngine<SetupTask, Result<(TaskId, TaskMetadata), TaskError>, CoreAndNormalizeWorker<A, P>>,
>;

/// Unified engine that handles both tracing and core proving in a single async task.
pub type SP1CoreEngine<A, W, C> =
    Arc<AsyncEngine<CoreProvingTask, Result<TaskMetadata, TaskError>, CoreWorker<A, W, C>>>;

/// Handle for awaiting a submitted core proving task.
pub type CoreProveSubmitHandle<A, W, C> = SubmitHandle<SP1CoreEngine<A, W, C>>;

/// Handle for awaiting a submitted setup task.
pub type SetupSubmitHandle<A, C> = SubmitHandle<SetupEngine<A, C>>;
689
/// Front-end for core proving: owns the prove-shard and setup engines and
/// submits tasks to them (see `submit_prove_shard` / `submit_setup`).
pub struct SP1CoreProver<A, W, C: SP1ProverComponents> {
    /// Engine running the unified trace+prove workers.
    prove_shard_engine: SP1CoreEngine<A, W, C>,
    /// Engine running the setup workers.
    setup_engine: SetupEngine<A, C>,
}
694
695impl<A: ArtifactClient, W: WorkerClient, C: SP1ProverComponents> Clone for SP1CoreProver<A, W, C> {
696    fn clone(&self) -> Self {
697        Self {
698            prove_shard_engine: self.prove_shard_engine.clone(),
699            setup_engine: self.setup_engine.clone(),
700        }
701    }
702}
703
704impl<A: ArtifactClient, W: WorkerClient, C: SP1ProverComponents> SP1CoreProver<A, W, C> {
705    pub async fn submit_prove_shard(
706        &self,
707        task: RawTaskRequest,
708    ) -> Result<CoreProveSubmitHandle<A, W, C>, TaskError> {
709        let task = ProveShardTaskRequest::from_raw(task)?;
710        let ProveShardTaskRequest {
711            elf,
712            common_input,
713            record,
714            output,
715            deferred_marker_task,
716            deferred_output,
717            context,
718        } = task;
719
720        let metrics = ProverMetrics::new();
721        let tracing_task = CoreProvingTask {
722            proof_id: context.proof_id,
723            elf,
724            common_input,
725            record,
726            output,
727            deferred_marker_task,
728            deferred_output,
729            metrics,
730        };
731        let handle = self.prove_shard_engine.submit(tracing_task).await?;
732        Ok(handle)
733    }
734
735    pub async fn submit_setup(
736        &self,
737        task: SetupTask,
738    ) -> Result<SetupSubmitHandle<A, C>, SubmitError> {
739        self.setup_engine.submit(task).await
740    }
741}
742
/// Configuration for the core prover.
#[derive(Clone)]
pub struct SP1CoreProverConfig {
    /// The number of core workers (handles both tracing and proving).
    pub num_core_workers: usize,
    /// The buffer size for the core engine.
    pub core_buffer_size: usize,
    /// The number of setup workers.
    pub num_setup_workers: usize,
    /// The buffer size for the setup.
    pub setup_buffer_size: usize,
    /// The size of the normalize program cache.
    pub normalize_program_cache_size: usize,
    /// Whether to use a fixed proving key cache shared across core workers.
    pub use_fixed_pk: bool,
    /// Whether to verify intermediates.
    pub verify_intermediates: bool,
    /// Optional directory to dump shard records and vks for benchmarking/replay.
    pub dump_shard_dir: Option<String>,
}
763
764impl<A: ArtifactClient, W: WorkerClient, C: SP1ProverComponents> SP1CoreProver<A, W, C> {
765    pub fn new(
766        config: SP1CoreProverConfig,
767        opts: SP1CoreOpts,
768        artifact_client: A,
769        worker_client: W,
770        air_prover: Arc<C::CoreProver>,
771        permits: ProverSemaphore,
772        recursion_prover: SP1RecursionProver<A, C>,
773    ) -> Self {
774        // Initialize the normalize program compiler
775        let core_verifier = C::core_verifier();
776
777        let normalize_program_cache = SP1NormalizeCache::new(config.normalize_program_cache_size);
778
779        let recursive_core_verifier =
780            recursive_verifier::<SP1GlobalContext, _, InnerConfig>(core_verifier.shard_verifier());
781
782        let reduce_shape = recursion_prover.reduce_shape().clone();
783        let normalize_program_compiler = NormalizeProgramCompiler::new(
784            normalize_program_cache,
785            recursive_core_verifier,
786            reduce_shape,
787            core_verifier,
788        );
789        let normalize_program_compiler = Arc::new(normalize_program_compiler);
790
791        // Create a shared fixed PK cache if enabled
792        let pk_cache = if config.use_fixed_pk { Some(Arc::new(OnceCell::new())) } else { None };
793
794        // Initialize the unified core engine (handles both tracing and proving)
795        let core_workers = (0..config.num_core_workers)
796            .map(|_| {
797                CoreWorker::new(
798                    normalize_program_compiler.clone(),
799                    opts.clone(),
800                    artifact_client.clone(),
801                    worker_client.clone(),
802                    air_prover.clone(),
803                    recursion_prover.clone(),
804                    permits.clone(),
805                    pk_cache.clone(),
806                    config.verify_intermediates,
807                    config.dump_shard_dir.clone(),
808                )
809            })
810            .collect::<Vec<_>>();
811        let prove_shard_engine = Arc::new(AsyncEngine::new(core_workers, config.core_buffer_size));
812
813        // Make the setup engine
814        let setup_workers = (0..config.num_setup_workers)
815            .map(|_| {
816                CoreAndNormalizeWorker::new(
817                    artifact_client.clone(),
818                    air_prover.clone(),
819                    permits.clone(),
820                )
821            })
822            .collect::<Vec<_>>();
823        let setup_engine = Arc::new(AsyncEngine::new(setup_workers, config.setup_buffer_size));
824
825        Self { prove_shard_engine, setup_engine }
826    }
827}