1use std::{
2 marker::PhantomData,
3 sync::{Arc, OnceLock},
4};
5
6use futures::{prelude::*, stream::FuturesUnordered};
7use serde::{Deserialize, Serialize};
8use slop_futures::pipeline::Pipeline;
9use sp1_core_executor::{
10 events::{MemoryInitializeFinalizeEvent, MemoryRecord},
11 CoreVM, ExecutionError, Program, SP1CoreOpts, SyscallCode, UnsafeMemory,
12};
13use sp1_core_executor_runner::MinimalExecutorRunner;
14use sp1_core_machine::{executor::ExecutionOutput, io::SP1Stdin};
15use sp1_hypercube::{
16 air::{ShardRange, PROOF_NONCE_NUM_WORDS, PV_DIGEST_NUM_WORDS},
17 SP1VerifyingKey, DIGEST_SIZE,
18};
19use sp1_jit::MinimalTrace;
20use sp1_prover_types::{network_base_types::ProofMode, Artifact, ArtifactClient, TaskType};
21use tokio::{
22 sync::{mpsc, oneshot},
23 task::JoinSet,
24};
25use tracing::Instrument;
26
27use crate::worker::{
28 global_memory, precompile_channel, DeferredMessage, MinimalExecutorCache,
29 PrecompileArtifactSlice, ProveShardTaskRequest, RawTaskRequest, SplicingEngine, SplicingTask,
30 TaskContext, TaskError, TaskId, WorkerClient,
31};
32
/// Handle to a shard proof that is being (or has been) produced: the task
/// generating it, the shard range it covers, and the artifact it is written to.
///
/// NOTE: serialized with bincode via `MessageSender`; field order is part of
/// the wire format — do not reorder.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProofData {
    /// Id of the task that produces this proof.
    pub task_id: TaskId,
    /// Shard range covered by the proof.
    pub range: ShardRange,
    /// Artifact where the resulting proof is stored.
    pub proof: Artifact,
}
39
/// Typed message channel to a worker: bincode-serializes values of type `T`
/// and sends them tagged with a fixed task id via the [`WorkerClient`].
#[derive(Debug, Clone)]
pub struct MessageSender<W: WorkerClient, T: Serialize> {
    // Client used to deliver serialized messages.
    worker_client: W,
    // All messages from this sender are attributed to this task.
    task_id: TaskId,
    // `T` only appears in the `send` signature; no value of `T` is stored.
    _marker: PhantomData<T>,
}
46
47impl<W: WorkerClient, T: Serialize> MessageSender<W, T> {
48 pub fn new(worker_client: W, task_id: TaskId) -> Self {
49 Self { worker_client, task_id, _marker: PhantomData }
50 }
51
52 pub async fn send(&self, message: T) -> anyhow::Result<()> {
53 let payload = bincode::serialize(&message)?;
54 self.worker_client.send_task_message(&self.task_id, payload).await
55 }
56}
57
/// Scalar parameters of a core-execute task, smuggled through the raw task
/// request as a JSON string stored in an artifact id (see
/// `CoreExecuteTaskRequest::from_raw` / `into_raw`).
#[derive(Serialize, Deserialize)]
struct CoreExecuteMetadata {
    // Number of deferred proofs the execution must account for.
    num_deferred_proofs: usize,
    // Optional hard cap on executed cycles; `None` means unlimited.
    cycle_limit: Option<u64>,
}
63
/// Typed view of a core-execute task request: the input artifacts, the output
/// artifact slot, and the scalar execution parameters.
pub struct CoreExecuteTaskRequest {
    /// Program ELF artifact.
    pub elf: Artifact,
    /// Program stdin artifact.
    pub stdin: Artifact,
    /// Shared prover input artifact (see [`CommonProverInput`]).
    pub common_input: Artifact,
    /// Output artifact the execution result is written to.
    pub execution_output: Artifact,
    /// Number of deferred proofs to account for.
    pub num_deferred_proofs: usize,
    /// Optional hard cap on executed cycles.
    pub cycle_limit: Option<u64>,
    /// Task routing/attribution context.
    pub context: TaskContext,
}
73
74impl CoreExecuteTaskRequest {
75 pub fn from_raw(request: RawTaskRequest) -> Result<Self, TaskError> {
76 let RawTaskRequest { inputs, outputs, context } = request;
77 let [elf, stdin, common_input, metadata] = inputs
78 .try_into()
79 .map_err(|e| TaskError::Fatal(anyhow::anyhow!("invalid task inputs: {e:?}")))?;
80 let [execution_output] = outputs
81 .try_into()
82 .map_err(|e| TaskError::Fatal(anyhow::anyhow!("invalid task outputs: {e:?}")))?;
83 let metadata: CoreExecuteMetadata =
84 serde_json::from_str(&metadata.to_id()).map_err(|e| {
85 TaskError::Fatal(anyhow::anyhow!("failed to deserialize CoreExecuteMetadata: {e}"))
86 })?;
87 Ok(CoreExecuteTaskRequest {
88 elf,
89 stdin,
90 common_input,
91 execution_output,
92 num_deferred_proofs: metadata.num_deferred_proofs,
93 cycle_limit: metadata.cycle_limit,
94 context,
95 })
96 }
97
98 pub fn into_raw(self) -> Result<RawTaskRequest, TaskError> {
99 let metadata = CoreExecuteMetadata {
100 num_deferred_proofs: self.num_deferred_proofs,
101 cycle_limit: self.cycle_limit,
102 };
103 let metadata_str = serde_json::to_string(&metadata).map_err(|e| {
104 TaskError::Fatal(anyhow::anyhow!("failed to serialize CoreExecuteMetadata: {e}"))
105 })?;
106 let metadata_artifact = Artifact::from(metadata_str);
107
108 let inputs = vec![self.elf, self.stdin, self.common_input, metadata_artifact];
109 let outputs = vec![self.execution_output];
110 Ok(RawTaskRequest { inputs, outputs, context: self.context })
111 }
112}
113
/// Record payload uploaded for a shard-proving task, one variant per shard kind.
#[derive(Serialize, Deserialize)]
pub enum TraceData {
    /// Serialized core execution trace bytes.
    Core(Vec<u8>),
    /// Precompile record slices plus the syscall they belong to.
    Precompile(Vec<PrecompileArtifactSlice>, SyscallCode),
    /// Global-memory shard; boxed to keep the enum small (large variant).
    Memory(Box<GlobalMemoryShard>),
}
123
/// One shard of global memory initialize/finalize events, plus the cursors
/// locating it within the overall memory address space.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GlobalMemoryShard {
    /// Terminal VM state this shard is proven against.
    pub final_state: FinalVmState,
    /// Memory-initialize events belonging to this shard.
    pub initialize_events: Vec<MemoryInitializeFinalizeEvent>,
    /// Memory-finalize events belonging to this shard.
    pub finalize_events: Vec<MemoryInitializeFinalizeEvent>,
    // NOTE(review): the `previous_*` / `last_*` pairs appear to be chaining
    // cursors — where the prior shard ended and where this one ends, for both
    // addresses and page indices — confirm against the global-memory handler
    // that populates them.
    pub previous_init_addr: u64,
    pub previous_finalize_addr: u64,
    pub previous_init_page_idx: u64,
    pub previous_finalize_page_idx: u64,
    pub last_init_addr: u64,
    pub last_finalize_addr: u64,
    pub last_init_page_idx: u64,
    pub last_finalize_page_idx: u64,
}
138
/// Fully-materialized input for proving a single shard.
pub struct ProveShardInput {
    /// Raw program ELF bytes.
    pub elf: Vec<u8>,
    /// Inputs shared by every shard proof of this proof (vk, mode, nonce, …).
    pub common_input: CommonProverInput,
    /// The shard's record data (core trace, precompile, or memory).
    pub record: TraceData,
    /// Core prover options.
    pub opts: SP1CoreOpts,
}
145
/// Prover inputs common to every shard-proving task of one proof.
#[derive(Clone, Serialize, Deserialize)]
pub struct CommonProverInput {
    /// Verifying key of the program being proven.
    pub vk: SP1VerifyingKey,
    /// Requested proof mode.
    pub mode: ProofMode,
    /// Digest over the deferred proofs.
    pub deferred_digest: [u32; DIGEST_SIZE],
    /// Number of deferred proofs accounted for.
    pub num_deferred_proofs: usize,
    /// Per-proof nonce words.
    pub nonce: [u32; PROOF_NONCE_NUM_WORDS],
}
154
/// Orchestrates a full core execution: runs the minimal executor over the
/// program, feeds trace chunks to the splicing engine, and coordinates the
/// background tasks that emit global-memory and precompile shards.
pub struct SP1CoreExecutor<A, W: WorkerClient> {
    // Engine that consumes trace chunks and produces shard work.
    splicing_engine: Arc<SplicingEngine<A, W>>,
    // Buffer size handed to `global_memory` for the memory handler.
    global_memory_buffer_size: usize,
    // Program ELF artifact.
    elf: Artifact,
    // Program stdin, shared with the blocking executor task.
    stdin: Arc<SP1Stdin>,
    // Artifact holding the shared prover input.
    common_input: Artifact,
    // Core executor/prover options.
    opts: SP1CoreOpts,
    // Number of deferred proofs to account for.
    num_deferred_proofs: usize,
    // Task routing/attribution context.
    context: TaskContext,
    // Channel used to report `ProofData` for spawned shard proofs.
    sender: MessageSender<W, ProofData>,
    // Client for artifact storage.
    artifact_client: A,
    // Client for submitting worker tasks.
    worker_client: W,
    // Optional cache from which a previously-built executor can be reused.
    minimal_executor_cache: Option<MinimalExecutorCache>,
    // Optional hard cap on executed cycles.
    cycle_limit: Option<u64>,
}
170
impl<A, W: WorkerClient> SP1CoreExecutor<A, W> {
    /// Creates an executor from its components; all arguments are stored
    /// as-is and consumed later by [`Self::execute`].
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        splicing_engine: Arc<SplicingEngine<A, W>>,
        global_memory_buffer_size: usize,
        elf: Artifact,
        stdin: Arc<SP1Stdin>,
        common_input: Artifact,
        opts: SP1CoreOpts,
        num_deferred_proofs: usize,
        context: TaskContext,
        sender: MessageSender<W, ProofData>,
        artifact_client: A,
        worker_client: W,
        minimal_executor_cache: Option<MinimalExecutorCache>,
        cycle_limit: Option<u64>,
    ) -> Self {
        Self {
            splicing_engine,
            global_memory_buffer_size,
            elf,
            stdin,
            common_input,
            opts,
            num_deferred_proofs,
            context,
            sender,
            artifact_client,
            worker_client,
            minimal_executor_cache,
            cycle_limit,
        }
    }
}
205
impl<A, W> SP1CoreExecutor<A, W>
where
    A: ArtifactClient,
    W: WorkerClient,
{
    /// Runs the program to completion, streaming trace chunks into the
    /// splicing engine while background tasks emit the global-memory and
    /// precompile shards.
    ///
    /// Pipeline:
    /// 1. A blocking task drives the minimal executor chunk by chunk,
    ///    enforcing the optional cycle limit and submitting one
    ///    [`SplicingTask`] per chunk.
    /// 2. An async task awaits every splicing handle, then forwards the
    ///    final VM state (recorded via the shared [`FinalVmStateLock`]).
    /// 3. Two further tasks emit global-memory and precompile shards once
    ///    their inputs arrive over the oneshot channels wired up below.
    ///
    /// Returns the [`ExecutionOutput`] (cycle count + public value stream)
    /// after every spawned task finishes; the first task failure aborts.
    pub async fn execute(self) -> Result<ExecutionOutput, TaskError> {
        let elf_bytes = self.artifact_client.download_program(&self.elf).await?;
        let stdin = self.stdin.clone();
        let opts = self.opts.clone();

        // Disassemble the ELF into the program shared by all tasks below.
        let program = Arc::new(Program::from(&elf_bytes).map_err(|e| {
            TaskError::Execution(ExecutionError::Other(format!(
                "failed to dissassemble program: {}",
                e
            )))
        })?);

        // Shared state: touched-address set + handler that emits the global
        // memory shards, and the precompile marker channel + handler.
        let (all_touched_addresses, global_memory_handler) =
            global_memory(self.global_memory_buffer_size);
        let (deferred_marker_tx, precompile_handler) = precompile_channel(&program, &opts);
        // Write-once cell filled with the VM's terminal state.
        let final_vm_state = FinalVmStateLock::new();
        let (final_state_tx, final_state_rx) = oneshot::channel::<FinalVmState>();

        let mut join_set = JoinSet::<Result<(), TaskError>>::new();

        // Oneshot plumbing from the blocking executor task to the async
        // shard-emitting tasks.
        let (memory_tx, memory_rx) = oneshot::channel::<UnsafeMemory>();
        let (minimal_executor_tx, minimal_executor_rx) =
            oneshot::channel::<MinimalExecutorRunner>();
        let (output_tx, output_rx) = oneshot::channel::<ExecutionOutput>();
        let (splicing_submit_tx, mut splicing_submit_rx) = mpsc::unbounded_channel();
        let span = tracing::debug_span!("minimal executor");

        // Take a cached executor when one is available; otherwise build a
        // fresh runner from `opts` (identical construction in both arms).
        let mut minimal_executor = if let Some(cache) = &self.minimal_executor_cache {
            let mut optional_minimal_executor = cache.lock().await;
            if let Some(minimal_executor) = optional_minimal_executor.take() {
                tracing::info!("minimal executor cache hit");
                minimal_executor
            } else {
                MinimalExecutorRunner::new(
                    program.clone(),
                    false,
                    Some(opts.minimal_trace_chunk_threshold),
                    opts.memory_limit,
                    opts.trace_chunk_slots,
                )
            }
        } else {
            MinimalExecutorRunner::new(
                program.clone(),
                false,
                Some(opts.minimal_trace_chunk_threshold),
                opts.memory_limit,
                opts.trace_chunk_slots,
            )
        };
        // Blocking task: run the executor and submit one splicing task per
        // produced trace chunk.
        join_set.spawn_blocking({
            let program = program.clone();
            let elf = self.elf.clone();
            let common_input_artifact = self.common_input.clone();
            let context = self.context.clone();
            let sender = self.sender.clone();
            let final_vm_state = final_vm_state.clone();
            let opts = opts.clone();
            let splicing_engine = self.splicing_engine.clone();

            move || {
                let _guard = span.enter();
                // Feed the entire stdin buffer before execution starts.
                for buf in stdin.buffer.iter() {
                    minimal_executor.with_input(buf);
                }
                // Hand the executor's memory view to the global-memory task.
                let unsafe_memory = minimal_executor.unsafe_memory();
                memory_tx
                    .send(unsafe_memory)
                    .map_err(|_| anyhow::anyhow!("failed to send unsafe memory"))?;
                tracing::debug!("Starting minimal executor");
                let now = std::time::Instant::now();
                let mut chunk_count = 0;
                while let Some(chunk) = minimal_executor
                    .try_execute_chunk()
                    .map_err(|e| anyhow::anyhow!("failed to execute chunk: {e}"))?
                {
                    tracing::debug!(
                        trace_chunk = chunk_count,
                        "mem reads chunk size bytes {}, program is done?: {}",
                        chunk.num_mem_reads() * std::mem::size_of::<sp1_jit::MemValue>() as u64,
                        minimal_executor.is_done()
                    );

                    // Abort as soon as the chunk's end clock passes the
                    // caller-supplied cycle limit.
                    if let Some(cycle_limit) = self.cycle_limit {
                        let last_clk = chunk.global_clk_end();
                        if last_clk > cycle_limit {
                            tracing::error!("Cycle limit exceeded: last_clk = {last_clk}, cycle_limit = {cycle_limit}");
                            return Err(TaskError::Execution(ExecutionError::ExceededCycleLimit(
                                cycle_limit,
                            )));
                        }
                    }

                    // Package everything a splicer needs; handles to the
                    // shared state (addresses, final-state lock, channels)
                    // are cloned per chunk.
                    let task = SplicingTask {
                        program: program.clone(),
                        chunk,
                        elf_artifact: elf.clone(),
                        common_input_artifact: common_input_artifact.clone(),
                        num_deferred_proofs: self.num_deferred_proofs,
                        all_touched_addresses: all_touched_addresses.clone(),
                        final_vm_state: final_vm_state.clone(),
                        prove_shard_tx: sender.clone(),
                        context: context.clone(),
                        opts: opts.clone(),
                        deferred_marker_tx: deferred_marker_tx.clone(),
                    };

                    // Blocking submit is fine here — we are on a
                    // spawn_blocking thread, not the async runtime.
                    let splicing_handle = tracing::debug_span!("splicing", idx = chunk_count)
                        .in_scope(|| {
                            splicing_engine.blocking_submit(task).map_err(|e| {
                                anyhow::anyhow!("failed to submit splicing task: {}", e)
                            })
                        })?;
                    // Forward the handle; completion is awaited by the
                    // "wait for splicers" task below.
                    splicing_submit_tx
                        .send((chunk_count, splicing_handle))
                        .map_err(|e| anyhow::anyhow!("failed to send splicing handle: {}", e))?;

                    chunk_count += 1;
                }
                let elapsed = now.elapsed().as_secs_f64();
                tracing::debug!(
                    "minimal Executor finished. elapsed: {}s, mhz: {}",
                    elapsed,
                    minimal_executor.global_clk() as f64 / (elapsed * 1e6)
                );

                // Zero chunks means no shard was ever produced — fatal.
                if chunk_count == 0 {
                    return Err(TaskError::Fatal(anyhow::anyhow!(
                        "executor produced zero trace chunks in {elapsed:.3}s \
                        (global_clk={}, is_done={})",
                        minimal_executor.global_clk(),
                        minimal_executor.is_done(),
                    )));
                }
                let cycles = minimal_executor.global_clk();
                let public_value_stream = minimal_executor.public_values_stream().clone();

                let output = ExecutionOutput { cycles, public_value_stream };
                output_tx.send(output).map_err(|_| anyhow::anyhow!("failed to send output"))?;
                // Hand the executor itself to the global-memory task (which
                // may also return it to the cache).
                minimal_executor_tx
                    .send(minimal_executor)
                    .map_err(|_| anyhow::anyhow!("failed to send minimal executor"))?;
                Ok::<_, TaskError>(())
            }
        });

        // Wait for the memory view before spawning the task that needs it.
        let memory =
            memory_rx.await.map_err(|_| anyhow::anyhow!("failed to receive unsafe memory"))?;

        // Task: collect splicing handles as they are submitted and await
        // their completion; when both streams are exhausted, publish the
        // final VM state recorded in the shared lock.
        join_set.spawn({
            async move {
                let mut splicing_handles = FuturesUnordered::new();
                loop {
                    tokio::select! {
                        Some((chunk_count, splicing_handle)) = splicing_submit_rx.recv() => {
                            tracing::debug!(chunk_count = chunk_count, "Received splicing handle");
                            let handle = splicing_handle.map_ok(move |_| chunk_count);
                            splicing_handles.push(handle);
                        }
                        Some(result) = splicing_handles.next() => {
                            let chunk_count = result.map_err(|e| anyhow::anyhow!("splicing task panicked: {}", e))?;
                            tracing::debug!(chunk_count = chunk_count, "Splicing task finished");
                        }
                        // Channel closed AND no pending handles remain.
                        else => {
                            tracing::debug!("No more splicing handles to receive");
                            break;
                        }
                    }
                }
                let final_state = *final_vm_state.get().ok_or(TaskError::Fatal(anyhow::anyhow!("final vm state not set")))?;
                final_state_tx.send(final_state).map_err(|_| anyhow::anyhow!("failed to send final vm state"))?;
                Ok::<_, TaskError>(())
            }
            .instrument(tracing::debug_span!("wait for splicers"))
        });

        // Task: emit global-memory shards once the final state and the
        // executor arrive from the tasks above.
        join_set.spawn(
            {
                let artifact_client = self.artifact_client.clone();
                let worker_client = self.worker_client.clone();
                let num_deferred_proofs = self.num_deferred_proofs;
                let sender = self.sender.clone();
                let elf = self.elf.clone();
                let common_input = self.common_input.clone();
                let context = self.context.clone();
                let minimal_executor_cache = self.minimal_executor_cache.clone();

                async move {
                    global_memory_handler
                        .emit_global_memory_shards(
                            program,
                            final_state_rx,
                            minimal_executor_rx,
                            sender,
                            elf,
                            common_input,
                            context,
                            memory,
                            opts,
                            num_deferred_proofs,
                            artifact_client,
                            worker_client,
                            minimal_executor_cache,
                        )
                        .await?;
                    Ok::<_, TaskError>(())
                }
            }
            .instrument(tracing::debug_span!("emit global memory shards")),
        );

        // Task: emit precompile shards collected via the precompile channel.
        join_set.spawn({
            let artifact_client = self.artifact_client.clone();
            let worker_client = self.worker_client.clone();
            let sender = self.sender.clone();
            let elf = self.elf.clone();
            let common_input = self.common_input.clone();
            let context = self.context.clone();
            async move {
                precompile_handler
                    .emit_precompile_shards(
                        elf,
                        common_input,
                        sender,
                        artifact_client,
                        worker_client,
                        context,
                    )
                    .await?;
                Ok::<_, TaskError>(())
            }
            .instrument(tracing::debug_span!("emit precompile shards"))
        });

        // `??`: first `?` surfaces a JoinError (panic/cancel) as Fatal,
        // second propagates the task's own Result.
        while let Some(result) = join_set.join_next().await {
            result.map_err(|e| TaskError::Fatal(e.into()))??;
        }

        let output = output_rx.await.map_err(|_| anyhow::anyhow!("failed to receive output"))?;

        Ok(output)
    }
}
473
/// Snapshot of the VM's state at the end of execution.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, Default)]
pub struct FinalVmState {
    /// Final register file (32 registers).
    pub registers: [MemoryRecord; 32],
    /// Final clock value.
    pub timestamp: u64,
    /// Final program counter.
    pub pc: u64,
    /// Program exit code.
    pub exit_code: u32,
    /// Digest of the committed public values.
    pub public_value_digest: [u32; PV_DIGEST_NUM_WORDS],
    /// Per-proof nonce words.
    pub proof_nonce: [u32; PROOF_NONCE_NUM_WORDS],
}
483
484impl FinalVmState {
485 pub fn new<'a, 'b>(vm: &'a CoreVM<'b>) -> Self {
486 let registers = *vm.registers();
487 let timestamp = vm.clk();
488 let pc = vm.pc();
489 let exit_code = vm.exit_code();
490 let public_value_digest = vm.public_value_digest;
491 let proof_nonce = vm.proof_nonce;
492
493 Self { registers, timestamp, pc, exit_code, public_value_digest, proof_nonce }
494 }
495}
496
/// Cloneable write-once cell holding the final VM state; all clones share the
/// same underlying `OnceLock`, so any holder may `set` exactly once.
#[derive(Debug, Clone)]
pub struct FinalVmStateLock {
    // Shared write-once storage.
    inner: Arc<OnceLock<FinalVmState>>,
}
501
impl Default for FinalVmStateLock {
    /// Equivalent to [`FinalVmStateLock::new`]: an empty, unset lock.
    fn default() -> Self {
        Self::new()
    }
}
507
508impl FinalVmStateLock {
509 pub fn new() -> Self {
510 Self { inner: Arc::new(OnceLock::new()) }
511 }
512
513 pub fn set(&self, state: FinalVmState) -> Result<(), TaskError> {
514 self.inner
515 .set(state)
516 .map_err(|_| TaskError::Fatal(anyhow::anyhow!("final vm state already set")))
517 }
518
519 pub fn get(&self) -> Option<&FinalVmState> {
520 self.inner.get()
521 }
522}
523
/// Result of spawning a shard-proving task.
pub struct SpawnProveOutput {
    /// Deferred-record marker info; `Some` only for core-trace shards.
    pub deferred_message: Option<DeferredMessage>,
    /// Handle to the spawned proof (task id, range, proof artifact).
    pub proof_data: ProofData,
}
528
529pub(super) async fn create_core_proving_task<A: ArtifactClient, W: WorkerClient>(
530 elf_artifact: Artifact,
531 common_input_artifact: Artifact,
532 context: TaskContext,
533 range: ShardRange,
534 trace_data: TraceData,
535 worker_client: W,
536 artifact_client: A,
537) -> Result<SpawnProveOutput, ExecutionError> {
538 let record_artifact =
539 artifact_client.create_artifact().map_err(|e| ExecutionError::Other(e.to_string()))?;
540
541 let deferred_message = match &trace_data {
544 TraceData::Core(_) => {
545 let marker_task_id = worker_client
546 .submit_task(
547 TaskType::MarkerDeferredRecord,
548 RawTaskRequest {
549 inputs: vec![],
550 outputs: vec![],
551 context: TaskContext {
552 proof_id: context.proof_id.clone(),
553 parent_id: None,
554 parent_context: None,
555 requester_id: context.requester_id.clone(),
556 },
557 },
558 )
559 .await
560 .map_err(|e| ExecutionError::Other(e.to_string()))?;
561 let deferred_output_artifact = artifact_client
562 .create_artifact()
563 .map_err(|e| ExecutionError::Other(e.to_string()))?;
564 Some(DeferredMessage { task_id: marker_task_id, record: deferred_output_artifact })
565 }
566 TraceData::Memory(_) | TraceData::Precompile(_, _) => None,
567 };
568
569 artifact_client
570 .upload(&record_artifact, trace_data)
571 .await
572 .map_err(|e| ExecutionError::Other(e.to_string()))?;
573
574 let proof_artifact = artifact_client
576 .create_artifact()
577 .map_err(|_| ExecutionError::Other("failed to create shard proof artifact".to_string()))?;
578
579 let request = ProveShardTaskRequest {
580 elf: elf_artifact,
581 common_input: common_input_artifact,
582 record: record_artifact,
583 output: proof_artifact.clone(),
584 deferred_marker_task: deferred_message
585 .as_ref()
586 .map(|m| Artifact::from(m.task_id.to_string()))
587 .unwrap_or(Artifact::from("dummy marker task".to_string())),
588 deferred_output: deferred_message
589 .as_ref()
590 .map(|m| m.record.clone())
591 .unwrap_or(Artifact::from("dummy output artifact".to_string())),
592 context,
593 };
594
595 let task = request.into_raw().map_err(|e| ExecutionError::Other(e.to_string()))?;
596
597 let task_id = worker_client
599 .submit_task(TaskType::ProveShard, task)
600 .await
601 .map_err(|e| ExecutionError::Other(e.to_string()))?;
602 let proof_data = ProofData { task_id, range, proof: proof_artifact };
603 Ok(SpawnProveOutput { deferred_message, proof_data })
604}