// sp1_prover/worker/controller/core.rs

1use std::{
2    marker::PhantomData,
3    sync::{Arc, OnceLock},
4};
5
6use futures::{prelude::*, stream::FuturesUnordered};
7use serde::{Deserialize, Serialize};
8use slop_futures::pipeline::Pipeline;
9use sp1_core_executor::{
10    events::{MemoryInitializeFinalizeEvent, MemoryRecord},
11    CoreVM, ExecutionError, Program, SP1CoreOpts, SyscallCode, UnsafeMemory,
12};
13use sp1_core_executor_runner::MinimalExecutorRunner;
14use sp1_core_machine::{executor::ExecutionOutput, io::SP1Stdin};
15use sp1_hypercube::{
16    air::{ShardRange, PROOF_NONCE_NUM_WORDS, PV_DIGEST_NUM_WORDS},
17    SP1VerifyingKey, DIGEST_SIZE,
18};
19use sp1_jit::MinimalTrace;
20use sp1_prover_types::{network_base_types::ProofMode, Artifact, ArtifactClient, TaskType};
21use tokio::{
22    sync::{mpsc, oneshot},
23    task::JoinSet,
24};
25use tracing::Instrument;
26
27use crate::worker::{
28    global_memory, precompile_channel, DeferredMessage, MinimalExecutorCache,
29    PrecompileArtifactSlice, ProveShardTaskRequest, RawTaskRequest, SplicingEngine, SplicingTask,
30    TaskContext, TaskError, TaskId, WorkerClient,
31};
32
/// Handle to an in-flight shard proof: the prover task id, the shard range it
/// covers, and the artifact the proof will be written into.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProofData {
    /// Id of the submitted prove-shard task.
    pub task_id: TaskId,
    /// Shard range covered by this proof.
    pub range: ShardRange,
    /// Artifact that holds (or will hold) the shard proof.
    pub proof: Artifact,
}
39
/// Typed wrapper around a [`WorkerClient`] that bincode-serializes messages of
/// type `T` and sends them tagged with a fixed [`TaskId`].
#[derive(Debug, Clone)]
pub struct MessageSender<W: WorkerClient, T: Serialize> {
    worker_client: W,
    task_id: TaskId,
    // Marks the message type without storing a `T`.
    _marker: PhantomData<T>,
}
46
47impl<W: WorkerClient, T: Serialize> MessageSender<W, T> {
48    pub fn new(worker_client: W, task_id: TaskId) -> Self {
49        Self { worker_client, task_id, _marker: PhantomData }
50    }
51
52    pub async fn send(&self, message: T) -> anyhow::Result<()> {
53        let payload = bincode::serialize(&message)?;
54        self.worker_client.send_task_message(&self.task_id, payload).await
55    }
56}
57
/// JSON metadata carried in the fourth input artifact of a core-execute task;
/// produced by `CoreExecuteTaskRequest::into_raw` and parsed by `from_raw`.
#[derive(Serialize, Deserialize)]
struct CoreExecuteMetadata {
    // Number of deferred proofs expected for this execution.
    num_deferred_proofs: usize,
    // Optional hard cap on the global clock; `None` means unlimited.
    cycle_limit: Option<u64>,
}
63
/// Typed view of a core-execute task: three input artifacts, one output
/// artifact, and the metadata fields from [`CoreExecuteMetadata`].
pub struct CoreExecuteTaskRequest {
    /// Artifact containing the guest ELF.
    pub elf: Artifact,
    /// Artifact containing the guest stdin.
    pub stdin: Artifact,
    /// Artifact with the prover input shared by all shards.
    pub common_input: Artifact,
    /// Artifact the execution output will be written to.
    pub execution_output: Artifact,
    /// Number of deferred proofs expected for this execution.
    pub num_deferred_proofs: usize,
    /// Optional hard cap on the global clock; `None` means unlimited.
    pub cycle_limit: Option<u64>,
    /// Task routing/ownership context.
    pub context: TaskContext,
}
73
74impl CoreExecuteTaskRequest {
75    pub fn from_raw(request: RawTaskRequest) -> Result<Self, TaskError> {
76        let RawTaskRequest { inputs, outputs, context } = request;
77        let [elf, stdin, common_input, metadata] = inputs
78            .try_into()
79            .map_err(|e| TaskError::Fatal(anyhow::anyhow!("invalid task inputs: {e:?}")))?;
80        let [execution_output] = outputs
81            .try_into()
82            .map_err(|e| TaskError::Fatal(anyhow::anyhow!("invalid task outputs: {e:?}")))?;
83        let metadata: CoreExecuteMetadata =
84            serde_json::from_str(&metadata.to_id()).map_err(|e| {
85                TaskError::Fatal(anyhow::anyhow!("failed to deserialize CoreExecuteMetadata: {e}"))
86            })?;
87        Ok(CoreExecuteTaskRequest {
88            elf,
89            stdin,
90            common_input,
91            execution_output,
92            num_deferred_proofs: metadata.num_deferred_proofs,
93            cycle_limit: metadata.cycle_limit,
94            context,
95        })
96    }
97
98    pub fn into_raw(self) -> Result<RawTaskRequest, TaskError> {
99        let metadata = CoreExecuteMetadata {
100            num_deferred_proofs: self.num_deferred_proofs,
101            cycle_limit: self.cycle_limit,
102        };
103        let metadata_str = serde_json::to_string(&metadata).map_err(|e| {
104            TaskError::Fatal(anyhow::anyhow!("failed to serialize CoreExecuteMetadata: {e}"))
105        })?;
106        let metadata_artifact = Artifact::from(metadata_str);
107
108        let inputs = vec![self.elf, self.stdin, self.common_input, metadata_artifact];
109        let outputs = vec![self.execution_output];
110        Ok(RawTaskRequest { inputs, outputs, context: self.context })
111    }
112}
113
/// Trace payload uploaded as the record artifact of a prove-shard task.
#[derive(Serialize, Deserialize)]
pub enum TraceData {
    /// A core record to be proven.
    Core(Vec<u8>),
    /// Precompile data. Several `PrecompileArtifactSlice`s, and the type of precompile.
    Precompile(Vec<PrecompileArtifactSlice>, SyscallCode),
    /// Memory data.
    Memory(Box<GlobalMemoryShard>),
}
123
/// Payload for a global-memory shard: the VM's final state plus the memory
/// initialize/finalize event lists for this shard.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GlobalMemoryShard {
    /// Final architectural state of the VM at the end of execution.
    pub final_state: FinalVmState,
    /// Memory-initialize events belonging to this shard.
    pub initialize_events: Vec<MemoryInitializeFinalizeEvent>,
    /// Memory-finalize events belonging to this shard.
    pub finalize_events: Vec<MemoryInitializeFinalizeEvent>,
    // NOTE(review): the previous_*/last_* fields appear to delimit this
    // shard's address and page-index window so consecutive shards chain
    // together — confirm against the global-memory handler.
    pub previous_init_addr: u64,
    pub previous_finalize_addr: u64,
    pub previous_init_page_idx: u64,
    pub previous_finalize_page_idx: u64,
    pub last_init_addr: u64,
    pub last_finalize_addr: u64,
    pub last_init_page_idx: u64,
    pub last_finalize_page_idx: u64,
}
138
/// Fully materialized input for proving one shard: the ELF bytes, the shared
/// prover input, the shard's trace data, and the core options.
pub struct ProveShardInput {
    /// Raw guest ELF bytes.
    pub elf: Vec<u8>,
    /// Prover input shared by every shard of the proof.
    pub common_input: CommonProverInput,
    /// Trace payload for this shard.
    pub record: TraceData,
    /// Core execution/proving options.
    pub opts: SP1CoreOpts,
}
145
/// Prover input shared by every shard of a single proof.
#[derive(Clone, Serialize, Deserialize)]
pub struct CommonProverInput {
    /// Verifying key of the program being proven.
    pub vk: SP1VerifyingKey,
    /// Requested proof mode.
    pub mode: ProofMode,
    // NOTE(review): presumably the running digest over deferred proofs —
    // confirm with the recursion layer.
    pub deferred_digest: [u32; DIGEST_SIZE],
    /// Number of deferred proofs expected.
    pub num_deferred_proofs: usize,
    /// Proof nonce words.
    pub nonce: [u32; PROOF_NONCE_NUM_WORDS],
}
154
/// Orchestrates one core execution: runs the guest program with the minimal
/// executor, streams trace chunks into the splicing pipeline, and emits
/// global-memory and precompile shards for proving (see `execute`).
pub struct SP1CoreExecutor<A, W: WorkerClient> {
    /// Pipeline that turns trace chunks into prove-shard work.
    splicing_engine: Arc<SplicingEngine<A, W>>,
    /// Buffer size handed to `global_memory` when building the handler.
    global_memory_buffer_size: usize,
    /// Artifact containing the guest ELF.
    elf: Artifact,
    /// Guest program input.
    stdin: Arc<SP1Stdin>,
    /// Artifact with the prover input shared by all shards.
    common_input: Artifact,
    /// Core execution/proving options.
    opts: SP1CoreOpts,
    /// Number of deferred proofs expected for this execution.
    num_deferred_proofs: usize,
    /// Task routing/ownership context.
    context: TaskContext,
    /// Channel used to report `ProofData` for each submitted shard.
    sender: MessageSender<W, ProofData>,
    artifact_client: A,
    worker_client: W,
    /// Optional cache so a warm `MinimalExecutorRunner` can be reused
    /// across requests.
    minimal_executor_cache: Option<MinimalExecutorCache>,
    /// If set, execution fails once the global clock exceeds this limit.
    cycle_limit: Option<u64>,
}
170
impl<A, W: WorkerClient> SP1CoreExecutor<A, W> {
    /// Builds a core executor. Pure constructor: every argument is stored
    /// verbatim; all work happens in `execute`.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        splicing_engine: Arc<SplicingEngine<A, W>>,
        global_memory_buffer_size: usize,
        elf: Artifact,
        stdin: Arc<SP1Stdin>,
        common_input: Artifact,
        opts: SP1CoreOpts,
        num_deferred_proofs: usize,
        context: TaskContext,
        sender: MessageSender<W, ProofData>,
        artifact_client: A,
        worker_client: W,
        minimal_executor_cache: Option<MinimalExecutorCache>,
        cycle_limit: Option<u64>,
    ) -> Self {
        Self {
            splicing_engine,
            global_memory_buffer_size,
            elf,
            stdin,
            common_input,
            opts,
            num_deferred_proofs,
            context,
            sender,
            artifact_client,
            worker_client,
            minimal_executor_cache,
            cycle_limit,
        }
    }
}
205
impl<A, W> SP1CoreExecutor<A, W>
where
    A: ArtifactClient,
    W: WorkerClient,
{
    /// Runs the guest program to completion and fans out all proving work.
    ///
    /// Orchestration:
    /// 1. Downloads and disassembles the ELF into a [`Program`].
    /// 2. Runs the minimal executor on a blocking thread; each trace chunk
    ///    is submitted to the splicing engine.
    /// 3. A watcher task awaits every splicing handle, then forwards the
    ///    final VM state to the global-memory handler.
    /// 4. Two further tasks emit the global-memory and precompile shards.
    ///
    /// Returns the execution output (cycle count and public-value stream).
    /// Fails with `ExceededCycleLimit` if `cycle_limit` is set and crossed,
    /// and with `TaskError::Fatal` if the executor yields no trace chunks.
    pub async fn execute(self) -> Result<ExecutionOutput, TaskError> {
        let elf_bytes = self.artifact_client.download_program(&self.elf).await?;
        let stdin = self.stdin.clone();
        let opts = self.opts.clone();

        // Get the program from the elf. TODO: handle errors.
        let program = Arc::new(Program::from(&elf_bytes).map_err(|e| {
            TaskError::Execution(ExecutionError::Other(format!(
                "failed to dissassemble program: {}",
                e
            )))
        })?);

        // Initialize the touched addresses map.
        let (all_touched_addresses, global_memory_handler) =
            global_memory(self.global_memory_buffer_size);
        let (deferred_marker_tx, precompile_handler) = precompile_channel(&program, &opts);
        // Initialize the final vm state.
        let final_vm_state = FinalVmStateLock::new();
        let (final_state_tx, final_state_rx) = oneshot::channel::<FinalVmState>();

        // Create a join set in order to be able to cancel all tasks
        let mut join_set = JoinSet::<Result<(), TaskError>>::new();

        // Start the minimal executor.
        let (memory_tx, memory_rx) = oneshot::channel::<UnsafeMemory>();
        let (minimal_executor_tx, minimal_executor_rx) =
            oneshot::channel::<MinimalExecutorRunner>();
        let (output_tx, output_rx) = oneshot::channel::<ExecutionOutput>();
        // Create a channel to send the splicing handles to be awaited and their task_ids being
        // sent after being submitted to the splicing pipeline.
        let (splicing_submit_tx, mut splicing_submit_rx) = mpsc::unbounded_channel();
        let span = tracing::debug_span!("minimal executor");

        // Making the minimal executor blocks the rest of execution anyway, so we initialize it before spawning the rest of the tokio tasks.
        // Reuse a cached runner when available; otherwise build a fresh one.
        let mut minimal_executor = if let Some(cache) = &self.minimal_executor_cache {
            let mut optional_minimal_executor = cache.lock().await;
            if let Some(minimal_executor) = optional_minimal_executor.take() {
                tracing::info!("minimal executor cache hit");
                minimal_executor
            } else {
                MinimalExecutorRunner::new(
                    program.clone(),
                    false,
                    Some(opts.minimal_trace_chunk_threshold),
                    opts.memory_limit,
                    opts.trace_chunk_slots,
                )
            }
        } else {
            MinimalExecutorRunner::new(
                program.clone(),
                false,
                Some(opts.minimal_trace_chunk_threshold),
                opts.memory_limit,
                opts.trace_chunk_slots,
            )
        };
        // Drive the (synchronous) minimal executor on a blocking thread and
        // submit each produced trace chunk to the splicing engine.
        join_set.spawn_blocking({
            let program = program.clone();
            let elf = self.elf.clone();
            let common_input_artifact = self.common_input.clone();
            let context = self.context.clone();
            let sender = self.sender.clone();
            let final_vm_state = final_vm_state.clone();
            let opts = opts.clone();
            let splicing_engine = self.splicing_engine.clone();

            move || {
                let _guard = span.enter();
                // Write input to the minimal executor.
                for buf in stdin.buffer.iter() {
                    minimal_executor.with_input(buf);
                }
                // Get the unsafe memory view of the minimal executor.
                let unsafe_memory = minimal_executor.unsafe_memory();
                // Send the unsafe memory view to the parent task.
                memory_tx
                    .send(unsafe_memory)
                    .map_err(|_| anyhow::anyhow!("failed to send unsafe memory"))?;
                tracing::debug!("Starting minimal executor");
                let now = std::time::Instant::now();
                let mut chunk_count = 0;
                while let Some(chunk) = minimal_executor
                    .try_execute_chunk()
                    .map_err(|e| anyhow::anyhow!("failed to execute chunk: {e}"))?
                {
                    tracing::debug!(
                        trace_chunk = chunk_count,
                        "mem reads chunk size bytes {}, program is done?: {}",
                        chunk.num_mem_reads() * std::mem::size_of::<sp1_jit::MemValue>() as u64,
                        minimal_executor.is_done()
                    );

                    // Check the `end_clk` for cycle limit
                    if let Some(cycle_limit) = self.cycle_limit {
                        let last_clk = chunk.global_clk_end();
                        if last_clk > cycle_limit {
                            tracing::error!("Cycle limit exceeded: last_clk = {last_clk}, cycle_limit = {cycle_limit}");
                            return Err(TaskError::Execution(ExecutionError::ExceededCycleLimit(
                                cycle_limit,
                            )));
                        }
                    }

                    // Create a splicing task
                    let task = SplicingTask {
                        program: program.clone(),
                        chunk,
                        elf_artifact: elf.clone(),
                        common_input_artifact: common_input_artifact.clone(),
                        num_deferred_proofs: self.num_deferred_proofs,
                        all_touched_addresses: all_touched_addresses.clone(),
                        final_vm_state: final_vm_state.clone(),
                        prove_shard_tx: sender.clone(),
                        context: context.clone(),
                        opts: opts.clone(),
                        deferred_marker_tx: deferred_marker_tx.clone(),
                    };

                    let splicing_handle = tracing::debug_span!("splicing", idx = chunk_count)
                        .in_scope(|| {
                            splicing_engine.blocking_submit(task).map_err(|e| {
                                anyhow::anyhow!("failed to submit splicing task: {}", e)
                            })
                        })?;
                    // Hand the handle to the async watcher task below.
                    splicing_submit_tx
                        .send((chunk_count, splicing_handle))
                        .map_err(|e| anyhow::anyhow!("failed to send splicing handle: {}", e))?;

                    chunk_count += 1;
                }
                let elapsed = now.elapsed().as_secs_f64();
                tracing::debug!(
                    "minimal Executor finished. elapsed: {}s, mhz: {}",
                    elapsed,
                    minimal_executor.global_clk() as f64 / (elapsed * 1e6)
                );

                // A run with zero chunks cannot be proven; surface it as fatal.
                if chunk_count == 0 {
                    return Err(TaskError::Fatal(anyhow::anyhow!(
                        "executor produced zero trace chunks in {elapsed:.3}s \
                         (global_clk={}, is_done={})",
                        minimal_executor.global_clk(),
                        minimal_executor.is_done(),
                    )));
                }
                // Get the output and send it to the output channel.
                let cycles = minimal_executor.global_clk();
                let public_value_stream = minimal_executor.public_values_stream().clone();

                let output = ExecutionOutput { cycles, public_value_stream };
                output_tx.send(output).map_err(|_| anyhow::anyhow!("failed to send output"))?;
                // Send the hints to the global memory handler.
                minimal_executor_tx
                    .send(minimal_executor)
                    .map_err(|_| anyhow::anyhow!("failed to send minimal executor"))?;
                Ok::<_, TaskError>(())
            }
        });

        // Wait for the executor thread to expose its memory view before
        // spawning the shard emitters that need it.
        let memory =
            memory_rx.await.map_err(|_| anyhow::anyhow!("failed to receive unsafe memory"))?;

        // Watcher task: collect splicing handles as they are submitted and
        // await their completion; once all are done, publish the final VM
        // state for the global-memory handler.
        join_set.spawn({
            async move {
                let mut splicing_handles = FuturesUnordered::new();
                loop {
                    tokio::select! {
                        Some((chunk_count, splicing_handle)) = splicing_submit_rx.recv() => {
                            tracing::debug!(chunk_count = chunk_count, "Received splicing handle");
                            let handle = splicing_handle.map_ok(move |_| chunk_count);
                            splicing_handles.push(handle);
                        }
                        Some(result) = splicing_handles.next() => {
                            let chunk_count = result.map_err(|e| anyhow::anyhow!("splicing task panicked: {}", e))?;
                            tracing::debug!(chunk_count = chunk_count, "Splicing task finished");
                        }
                        else => {
                            tracing::debug!("No more splicing handles to receive");
                            break;
                        }
                    }
                }
                // Now that all the splicing tasks are finished, send the final vm state to the global memory handler.
                let final_state = *final_vm_state.get().ok_or(TaskError::Fatal(anyhow::anyhow!("final vm state not set")))?;
                final_state_tx.send(final_state).map_err(|_| anyhow::anyhow!("failed to send final vm state"))?;
                Ok::<_, TaskError>(())
            }
            .instrument(tracing::debug_span!("wait for splicers"))
        });

        // Emit the global memory shards.
        join_set.spawn(
            {
                let artifact_client = self.artifact_client.clone();
                let worker_client = self.worker_client.clone();
                let num_deferred_proofs = self.num_deferred_proofs;
                let sender = self.sender.clone();
                let elf = self.elf.clone();
                let common_input = self.common_input.clone();
                let context = self.context.clone();
                let minimal_executor_cache = self.minimal_executor_cache.clone();

                async move {
                    global_memory_handler
                        .emit_global_memory_shards(
                            program,
                            final_state_rx,
                            minimal_executor_rx,
                            sender,
                            elf,
                            common_input,
                            context,
                            memory,
                            opts,
                            num_deferred_proofs,
                            artifact_client,
                            worker_client,
                            minimal_executor_cache,
                        )
                        .await?;
                    Ok::<_, TaskError>(())
                }
            }
            .instrument(tracing::debug_span!("emit global memory shards")),
        );

        // Emit the precompile shards.
        join_set.spawn({
            let artifact_client = self.artifact_client.clone();
            let worker_client = self.worker_client.clone();
            let sender = self.sender.clone();
            let elf = self.elf.clone();
            let common_input = self.common_input.clone();
            let context = self.context.clone();
            async move {
                precompile_handler
                    .emit_precompile_shards(
                        elf,
                        common_input,
                        sender,
                        artifact_client,
                        worker_client,
                        context,
                    )
                    .await?;
                Ok::<_, TaskError>(())
            }
            .instrument(tracing::debug_span!("emit precompile shards"))
        });

        // Wait for tasks to finish
        while let Some(result) = join_set.join_next().await {
            result.map_err(|e| TaskError::Fatal(e.into()))??;
        }

        let output = output_rx.await.map_err(|_| anyhow::anyhow!("failed to receive output"))?;

        Ok(output)
    }
}
473
/// Snapshot of the VM's architectural state when execution halts.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, Default)]
pub struct FinalVmState {
    /// Final register file (32 registers).
    pub registers: [MemoryRecord; 32],
    /// Final clock value (`vm.clk()`).
    pub timestamp: u64,
    /// Final program counter.
    pub pc: u64,
    /// Guest exit code.
    pub exit_code: u32,
    /// Committed public-values digest.
    pub public_value_digest: [u32; PV_DIGEST_NUM_WORDS],
    /// Proof nonce words.
    pub proof_nonce: [u32; PROOF_NONCE_NUM_WORDS],
}
483
484impl FinalVmState {
485    pub fn new<'a, 'b>(vm: &'a CoreVM<'b>) -> Self {
486        let registers = *vm.registers();
487        let timestamp = vm.clk();
488        let pc = vm.pc();
489        let exit_code = vm.exit_code();
490        let public_value_digest = vm.public_value_digest;
491        let proof_nonce = vm.proof_nonce;
492
493        Self { registers, timestamp, pc, exit_code, public_value_digest, proof_nonce }
494    }
495}
496
/// Cloneable write-once cell for the [`FinalVmState`]: shared across tasks so
/// the final state can be set exactly once and read afterwards.
#[derive(Debug, Clone)]
pub struct FinalVmStateLock {
    // Arc so clones observe the same OnceLock.
    inner: Arc<OnceLock<FinalVmState>>,
}
501
502impl Default for FinalVmStateLock {
503    fn default() -> Self {
504        Self::new()
505    }
506}
507
508impl FinalVmStateLock {
509    pub fn new() -> Self {
510        Self { inner: Arc::new(OnceLock::new()) }
511    }
512
513    pub fn set(&self, state: FinalVmState) -> Result<(), TaskError> {
514        self.inner
515            .set(state)
516            .map_err(|_| TaskError::Fatal(anyhow::anyhow!("final vm state already set")))
517    }
518
519    pub fn get(&self) -> Option<&FinalVmState> {
520        self.inner.get()
521    }
522}
523
/// Result of `create_core_proving_task`: the submitted proof handle and, for
/// core shards only, the deferred-record message.
pub struct SpawnProveOutput {
    /// Present only for `TraceData::Core` shards; `None` for memory and
    /// precompile shards.
    pub deferred_message: Option<DeferredMessage>,
    /// Handle to the submitted prove-shard task.
    pub proof_data: ProofData,
}
528
529pub(super) async fn create_core_proving_task<A: ArtifactClient, W: WorkerClient>(
530    elf_artifact: Artifact,
531    common_input_artifact: Artifact,
532    context: TaskContext,
533    range: ShardRange,
534    trace_data: TraceData,
535    worker_client: W,
536    artifact_client: A,
537) -> Result<SpawnProveOutput, ExecutionError> {
538    let record_artifact =
539        artifact_client.create_artifact().map_err(|e| ExecutionError::Other(e.to_string()))?;
540
541    // Make a deferred marker task. This is used for the worker to send
542    // its deferred record back to the controller.
543    let deferred_message = match &trace_data {
544        TraceData::Core(_) => {
545            let marker_task_id = worker_client
546                .submit_task(
547                    TaskType::MarkerDeferredRecord,
548                    RawTaskRequest {
549                        inputs: vec![],
550                        outputs: vec![],
551                        context: TaskContext {
552                            proof_id: context.proof_id.clone(),
553                            parent_id: None,
554                            parent_context: None,
555                            requester_id: context.requester_id.clone(),
556                        },
557                    },
558                )
559                .await
560                .map_err(|e| ExecutionError::Other(e.to_string()))?;
561            let deferred_output_artifact = artifact_client
562                .create_artifact()
563                .map_err(|e| ExecutionError::Other(e.to_string()))?;
564            Some(DeferredMessage { task_id: marker_task_id, record: deferred_output_artifact })
565        }
566        TraceData::Memory(_) | TraceData::Precompile(_, _) => None,
567    };
568
569    artifact_client
570        .upload(&record_artifact, trace_data)
571        .await
572        .map_err(|e| ExecutionError::Other(e.to_string()))?;
573
574    // Allocate an artifact for the proof
575    let proof_artifact = artifact_client
576        .create_artifact()
577        .map_err(|_| ExecutionError::Other("failed to create shard proof artifact".to_string()))?;
578
579    let request = ProveShardTaskRequest {
580        elf: elf_artifact,
581        common_input: common_input_artifact,
582        record: record_artifact,
583        output: proof_artifact.clone(),
584        deferred_marker_task: deferred_message
585            .as_ref()
586            .map(|m| Artifact::from(m.task_id.to_string()))
587            .unwrap_or(Artifact::from("dummy marker task".to_string())),
588        deferred_output: deferred_message
589            .as_ref()
590            .map(|m| m.record.clone())
591            .unwrap_or(Artifact::from("dummy output artifact".to_string())),
592        context,
593    };
594
595    let task = request.into_raw().map_err(|e| ExecutionError::Other(e.to_string()))?;
596
597    // Send the task to the worker.
598    let task_id = worker_client
599        .submit_task(TaskType::ProveShard, task)
600        .await
601        .map_err(|e| ExecutionError::Other(e.to_string()))?;
602    let proof_data = ProofData { task_id, range, proof: proof_artifact };
603    Ok(SpawnProveOutput { deferred_message, proof_data })
604}