1use std::sync::Arc;
2
3use futures::{stream::FuturesUnordered, StreamExt};
4use slop_futures::pipeline::{AsyncEngine, AsyncWorker, Pipeline};
5use sp1_core_executor::{
6 CompressedMemory, CycleResult, ExecutionError, Program, SP1CoreOpts, SplicedMinimalTrace,
7 SplicingVM,
8};
9use sp1_hypercube::air::{ShardBoundary, ShardRange};
10use sp1_jit::{MinimalTrace, TraceChunkRaw};
11use sp1_prover_types::{await_blocking, Artifact, ArtifactClient};
12use tokio::{sync::mpsc, task::JoinSet};
13use tracing::Instrument;
14
15use crate::worker::{
16 controller::create_core_proving_task, CommonProverInput, DeferredMessage, FinalVmState,
17 FinalVmStateLock, MessageSender, ProofData, SpawnProveOutput, TaskContext, TouchedAddresses,
18 TraceData, WorkerClient,
19};
20
/// Pipeline engine that runs [`SplicingWorker`]s over [`SplicingTask`]s,
/// each producing `Result<(), ExecutionError>`.
pub type SplicingEngine<A, W> =
    AsyncEngine<SplicingTask<W>, Result<(), ExecutionError>, SplicingWorker<A, W>>;
23
/// Input for a single [`SplicingWorker`] invocation: one raw trace chunk plus
/// everything needed to splice it into shard-sized traces and hand the
/// resulting shards to the proving pipeline.
pub struct SplicingTask<W: WorkerClient> {
    /// The program being executed by the splicing VM.
    pub program: Arc<Program>,
    /// The raw minimal-trace chunk to splice into shards.
    pub chunk: TraceChunkRaw,
    /// Artifact handle for the ELF; forwarded to each core proving task.
    pub elf_artifact: Artifact,
    /// Number of deferred proofs; recorded in every emitted `ShardBoundary`.
    pub num_deferred_proofs: usize,
    /// Artifact holding the `CommonProverInput`; downloaded for its nonce,
    /// which seeds the `SplicingVM`.
    pub common_input_artifact: Artifact,
    /// Shared accumulator of touched addresses, extended with this chunk's
    /// touched set when splicing completes.
    pub all_touched_addresses: TouchedAddresses,
    /// Slot that receives the final VM state when the program finishes
    /// inside this chunk.
    pub final_vm_state: FinalVmStateLock,
    /// Channel on which `ProofData` for each spawned shard proof is sent.
    pub prove_shard_tx: MessageSender<W, ProofData>,
    /// Task context forwarded to the core proving tasks.
    pub context: TaskContext,
    /// Core executor options passed to the `SplicingVM`.
    pub opts: SP1CoreOpts,
    /// Channel for deferred-proof marker messages produced while proving.
    pub deferred_marker_tx: mpsc::UnboundedSender<DeferredMessage>,
}
38
/// Worker that splices one raw execution-trace chunk into per-shard traces and
/// dispatches each shard to a pool of send-splice workers for proving.
///
/// NOTE(review): the `Copy`/`Eq`/`Hash`/`Default` derives only yield impls when
/// `A` and `W` themselves implement those traits — presumably most client types
/// do not, in which case the derived impls never apply; confirm intentional.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct SplicingWorker<A, W> {
    // Client used to download artifacts (e.g. the common prover input).
    artifact_client: A,
    // Client used to communicate with downstream proving workers.
    worker_client: W,
    // Size of the send-splice worker pool built per task.
    number_of_send_splice_workers: usize,
    // Input buffer size handed to the send-splice engine.
    send_splice_input_buffer_size: usize,
}
46
47impl<A, W> SplicingWorker<A, W>
48where
49 A: ArtifactClient,
50 W: WorkerClient,
51{
52 pub fn new(
53 artifact_client: A,
54 worker_client: W,
55 number_of_send_splice_workers: usize,
56 send_splice_input_buffer_size: usize,
57 ) -> Self {
58 Self {
59 artifact_client,
60 worker_client,
61 number_of_send_splice_workers,
62 send_splice_input_buffer_size,
63 }
64 }
65
66 fn initialize_send_splice_engine(
67 &self,
68 elf_artifact: Artifact,
69 common_input_artifact: Artifact,
70 context: TaskContext,
71 prove_shard_tx: MessageSender<W, ProofData>,
72 deferred_marker_tx: mpsc::UnboundedSender<DeferredMessage>,
73 ) -> SendSpliceEngine<A, W> {
74 let workers = (0..self.number_of_send_splice_workers)
75 .map(|_| SendSpliceWorker {
76 artifact_client: self.artifact_client.clone(),
77 worker_client: self.worker_client.clone(),
78 elf_artifact: elf_artifact.clone(),
79 common_input_artifact: common_input_artifact.clone(),
80 context: context.clone(),
81 prove_shard_tx: prove_shard_tx.clone(),
82 deferred_marker_tx: deferred_marker_tx.clone(),
83 })
84 .collect();
85 let input_buffer_size = self.send_splice_input_buffer_size;
86 SendSpliceEngine::new(workers, input_buffer_size)
87 }
88}
89
90impl<A, W> AsyncWorker<SplicingTask<W>, Result<(), ExecutionError>> for SplicingWorker<A, W>
91where
92 A: ArtifactClient,
93 W: WorkerClient,
94{
95 async fn call(&self, input: SplicingTask<W>) -> Result<(), ExecutionError> {
96 let SplicingTask {
97 program,
98 chunk,
99 all_touched_addresses,
100 final_vm_state,
101 elf_artifact,
102 common_input_artifact,
103 num_deferred_proofs,
104 prove_shard_tx,
105 context,
106 deferred_marker_tx,
107 opts,
108 } = input;
109 let (splicing_tx, mut splicing_rx) = mpsc::channel::<SendSpliceTask>(2);
110
111 let mut join_set = JoinSet::<Result<(), ExecutionError>>::new();
112 let (send_handle_tx, mut send_handle_rx) = mpsc::unbounded_channel();
114 join_set.spawn(
115 {
116 let send_splice_engine = self.initialize_send_splice_engine(
117 elf_artifact.clone(),
118 common_input_artifact.clone(),
119 context.clone(),
120 prove_shard_tx.clone(),
121 deferred_marker_tx,
122 );
123 async move {
124 while let Some(task) = splicing_rx.recv().await {
125 let handle = send_splice_engine
126 .submit(task)
127 .instrument(tracing::debug_span!("send splice"))
128 .await
129 .map_err(|_| {
130 ExecutionError::Other(
131 "failed to submit send splice task".to_string(),
132 )
133 })?;
134 send_handle_tx.send(handle).map_err(|e| {
135 ExecutionError::Other(format!("error sending to send handle tx: {}", e))
136 })?;
137 }
138 Ok(())
139 }
140 }
141 .instrument(tracing::debug_span!("get splices to serialize")),
142 );
143
144 join_set.spawn(
146 {
147 async move {
148 let mut handles = FuturesUnordered::new();
149 loop {
150 tokio::select! {
151 Some(handle) = send_handle_rx.recv() => {
152 handles.push(handle);
153 }
154 Some(result) = handles.next() => {
155 result.map_err(|e| ExecutionError::Other(format!("failed to join send splice task: {}", e)))??;
156 }
157 else => {
158 break;
159 }
160 }
161 }
162 Ok::<_, ExecutionError>(())
163 }
164 }
165 .instrument(tracing::debug_span!("spawn prove shard tasks")),
166 );
167
168 let common_prover_input = self
169 .artifact_client
170 .download::<CommonProverInput>(&common_input_artifact)
171 .await
172 .map_err(|e| {
173 ExecutionError::Other(format!("error downloading common prover input: {}", e))
174 })?;
175
176 let span = tracing::debug_span!("splicing trace chunk");
178 join_set.spawn_blocking(
179 move || {
180 let _guard = span.enter();
181 let mut touched_addresses = CompressedMemory::new();
182 let mut vm = SplicingVM::new(&chunk, program.clone(), &mut touched_addresses, common_prover_input.nonce, opts);
183
184 let start_num_mem_reads = chunk.num_mem_reads();
185 let start_clk = vm.core.clk();
186 let mut end_clk : u64;
187 let mut last_splice = SplicedMinimalTrace::new_full_trace(chunk.clone());
188 let mut boundary = ShardBoundary {
189 timestamp: start_clk,
190 initialized_address: 0,
191 finalized_address: 0,
192 initialized_page_index: 0,
193 finalized_page_index: 0,
194 deferred_proof: num_deferred_proofs as u64,
195 };
196 loop {
197 tracing::debug!("starting new shard at clk: {} at pc: {}", vm.core.clk(), vm.core.pc());
198 match vm.execute()? {
199 CycleResult::ShardBoundary => {
200 if let Some(spliced) = vm.splice(chunk.clone()) {
202 tracing::debug!(global_clk = vm.core.global_clk(), pc = vm.core.pc(), num_mem_reads_left = vm.core.mem_reads.len(), clk = vm.core.clk(), "shard boundary");
203 end_clk = vm.core.clk();
205 let end = ShardBoundary {
206 timestamp: end_clk,
207 initialized_address: 0,
208 finalized_address: 0,
209 initialized_page_index: 0,
210 finalized_page_index: 0,
211 deferred_proof: num_deferred_proofs as u64,
212 };
213 let range = (boundary..end).into();
215 boundary = end;
217
218 last_splice.set_last_clk(vm.core.clk());
220 last_splice.set_last_mem_reads_idx(
221 start_num_mem_reads as usize - vm.core.mem_reads.len(),
222 );
223 let splice_to_send = std::mem::replace(&mut last_splice, spliced);
224 tracing::debug!(global_clk = vm.core.global_clk(), "sending spliced trace to splicing tx");
225 splicing_tx.blocking_send(SendSpliceTask { chunk: splice_to_send, range })
226 .map_err(|e| ExecutionError::Other(format!("error sending to splicing tx: {}", e)))?;
227 tracing::debug!(global_clk = vm.core.global_clk(), "spliced trace sent to splicing tx");
228 } else {
229 tracing::debug!(global_clk = vm.core.global_clk(), pc = vm.core.pc(), num_mem_reads_left = vm.core.mem_reads.len(), "trace ended");
230 end_clk = vm.core.clk();
232 let end = ShardBoundary {
233 timestamp: end_clk,
234 initialized_address: 0,
235 finalized_address: 0,
236 initialized_page_index: 0,
237 finalized_page_index: 0,
238 deferred_proof: num_deferred_proofs as u64,
239 };
240 let range = (boundary..end).into();
242
243 last_splice.set_last_clk(vm.core.clk());
244 last_splice.set_last_mem_reads_idx(
245 start_num_mem_reads as usize - vm.core.mem_reads.len(),
246 );
247 tracing::debug!(global_clk = vm.core.global_clk(), "sending last splice to splicing tx");
248 splicing_tx.blocking_send(SendSpliceTask { chunk: last_splice, range })
249 .map_err(|e| ExecutionError::Other(format!("error sending to splicing tx: {}", e)))?;
250 tracing::debug!(global_clk = vm.core.global_clk(), "last splice sent to splicing tx");
251 break;
252 }
253 }
254 CycleResult::Done(true) => {
255 tracing::debug!(global_clk = vm.core.global_clk(), "done cycle result");
256 last_splice.set_last_clk(vm.core.clk());
257 last_splice.set_last_mem_reads_idx(chunk.num_mem_reads() as usize);
258
259 end_clk = vm.core.clk();
261 let end = ShardBoundary {
262 timestamp: end_clk,
263 initialized_address: 0,
264 finalized_address: 0,
265 initialized_page_index: 0,
266 finalized_page_index: 0,
267 deferred_proof: num_deferred_proofs as u64,
268 };
269 let range = (boundary..end).into();
271
272 let final_state = FinalVmState::new(&vm.core);
275 final_vm_state.set(final_state).map_err(|e| ExecutionError::Other(e.to_string()))?;
276
277 tracing::debug!(global_clk = vm.core.global_clk(), "sending last splice to splicing tx");
278 splicing_tx.blocking_send(SendSpliceTask { chunk: last_splice, range })
280 .map_err(|e| ExecutionError::Other(format!("error sending to splicing tx: {}", e)))?;
281 tracing::debug!(global_clk = vm.core.global_clk(), "last splice sent to splicing tx");
282 break;
283 }
284 CycleResult::Done(false) | CycleResult::TraceEnd => {
285 unreachable!("The executor should never return an imcomplete program without a shard boundary");
287 }
288 }
289 }
290 tracing::debug_span!("collecting touched addresses and sending to global memory").in_scope(|| {
292 all_touched_addresses.blocking_extend(start_clk, end_clk, touched_addresses.is_set())
293 .map_err(|e| ExecutionError::Other(e.to_string()))})?;
294 Ok(())
295 });
296
297 while let Some(result) = join_set.join_next().await {
299 result
300 .map_err(|e| ExecutionError::Other(format!("splicer task panicked: {}", e)))??;
301 }
302
303 Ok(())
304 }
305}
306
/// A spliced, shard-sized slice of the minimal trace together with the shard
/// range it covers; consumed by a send-splice worker.
pub struct SendSpliceTask {
    /// The spliced trace for one shard.
    pub chunk: SplicedMinimalTrace<TraceChunkRaw>,
    /// The shard range, built from a half-open `start..end` boundary pair.
    pub range: ShardRange,
}
311
/// Pool member that serializes one spliced trace, spawns the core proving
/// task for it, and forwards the resulting proof data / deferred marker.
struct SendSpliceWorker<A, W: WorkerClient> {
    // Client used to upload the serialized trace via the proving task.
    artifact_client: A,
    // Client used to dispatch the proving task to a worker.
    worker_client: W,
    // Task context forwarded to `create_core_proving_task`.
    context: TaskContext,
    // ELF artifact handle forwarded to the proving task.
    elf_artifact: Artifact,
    // Common prover input artifact handle forwarded to the proving task.
    common_input_artifact: Artifact,
    // Channel receiving the `ProofData` for each spawned shard proof.
    prove_shard_tx: MessageSender<W, ProofData>,
    // Channel receiving deferred-proof marker messages, when produced.
    deferred_marker_tx: mpsc::UnboundedSender<DeferredMessage>,
}
321
322impl<A, W> AsyncWorker<SendSpliceTask, Result<(), ExecutionError>> for SendSpliceWorker<A, W>
323where
324 A: ArtifactClient,
325 W: WorkerClient,
326{
327 async fn call(&self, input: SendSpliceTask) -> Result<(), ExecutionError> {
328 let SendSpliceTask { chunk, range } = input;
329 let chunk_bytes = await_blocking(|| bincode::serialize(&chunk))
330 .await
331 .map_err(|_| ExecutionError::Other("chunk serialization failed".to_string()))?
332 .map_err(|e| ExecutionError::Other(e.to_string()))?;
333 let data = TraceData::Core(chunk_bytes);
334
335 let SpawnProveOutput { deferred_message, proof_data } = create_core_proving_task(
336 self.elf_artifact.clone(),
337 self.common_input_artifact.clone(),
338 self.context.clone(),
339 range,
340 data,
341 self.worker_client.clone(),
342 self.artifact_client.clone(),
343 )
344 .await
345 .map_err(|e| ExecutionError::Other(format!("error in create_core_proving_task: {}", e)))?;
346
347 self.prove_shard_tx
348 .send(proof_data)
349 .await
350 .map_err(|e| ExecutionError::Other(format!("error in send proof data: {}", e)))?;
351 if let Some(deferred_message) = deferred_message {
352 self.deferred_marker_tx.send(deferred_message).map_err(|e| {
353 ExecutionError::Other(format!("error in send deferred message: {}", e))
354 })?;
355 }
356 Ok(())
357 }
358}
359
/// Pipeline engine that runs the send-splice worker pool over
/// [`SendSpliceTask`]s, each producing `Result<(), ExecutionError>`.
type SendSpliceEngine<A, W> =
    AsyncEngine<SendSpliceTask, Result<(), ExecutionError>, SendSpliceWorker<A, W>>;