Skip to main content

sp1_prover/worker/controller/
splicing.rs

1use std::sync::Arc;
2
3use futures::{stream::FuturesUnordered, StreamExt};
4use slop_futures::pipeline::{AsyncEngine, AsyncWorker, Pipeline};
5use sp1_core_executor::{
6    CompressedMemory, CycleResult, ExecutionError, Program, SP1CoreOpts, SplicedMinimalTrace,
7    SplicingVM,
8};
9use sp1_hypercube::air::{ShardBoundary, ShardRange};
10use sp1_jit::{MinimalTrace, TraceChunkRaw};
11use sp1_prover_types::{await_blocking, Artifact, ArtifactClient};
12use tokio::{sync::mpsc, task::JoinSet};
13use tracing::Instrument;
14
15use crate::worker::{
16    controller::create_core_proving_task, CommonProverInput, DeferredMessage, FinalVmState,
17    FinalVmStateLock, MessageSender, ProofData, SpawnProveOutput, TaskContext, TouchedAddresses,
18    TraceData, WorkerClient,
19};
20
21pub type SplicingEngine<A, W> =
22    AsyncEngine<SplicingTask<W>, Result<(), ExecutionError>, SplicingWorker<A, W>>;
23
24/// A task for splicing a trace into single shard chunks.
25pub struct SplicingTask<W: WorkerClient> {
26    pub program: Arc<Program>,
27    pub chunk: TraceChunkRaw,
28    pub elf_artifact: Artifact,
29    pub num_deferred_proofs: usize,
30    pub common_input_artifact: Artifact,
31    pub all_touched_addresses: TouchedAddresses,
32    pub final_vm_state: FinalVmStateLock,
33    pub prove_shard_tx: MessageSender<W, ProofData>,
34    pub context: TaskContext,
35    pub opts: SP1CoreOpts,
36    pub deferred_marker_tx: mpsc::UnboundedSender<DeferredMessage>,
37}
38
/// Worker that splices trace chunks and fans the spliced shards out to a pool of send-splice
/// workers.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub struct SplicingWorker<A, W> {
    /// Artifact store client, cloned into every send-splice worker.
    artifact_client: A,
    /// Worker client, cloned into every send-splice worker.
    worker_client: W,
    /// Number of workers in each send-splice engine this worker builds.
    number_of_send_splice_workers: usize,
    /// Input buffer size of each send-splice engine this worker builds.
    send_splice_input_buffer_size: usize,
}
46
47impl<A, W> SplicingWorker<A, W>
48where
49    A: ArtifactClient,
50    W: WorkerClient,
51{
52    pub fn new(
53        artifact_client: A,
54        worker_client: W,
55        number_of_send_splice_workers: usize,
56        send_splice_input_buffer_size: usize,
57    ) -> Self {
58        Self {
59            artifact_client,
60            worker_client,
61            number_of_send_splice_workers,
62            send_splice_input_buffer_size,
63        }
64    }
65
66    fn initialize_send_splice_engine(
67        &self,
68        elf_artifact: Artifact,
69        common_input_artifact: Artifact,
70        context: TaskContext,
71        prove_shard_tx: MessageSender<W, ProofData>,
72        deferred_marker_tx: mpsc::UnboundedSender<DeferredMessage>,
73    ) -> SendSpliceEngine<A, W> {
74        let workers = (0..self.number_of_send_splice_workers)
75            .map(|_| SendSpliceWorker {
76                artifact_client: self.artifact_client.clone(),
77                worker_client: self.worker_client.clone(),
78                elf_artifact: elf_artifact.clone(),
79                common_input_artifact: common_input_artifact.clone(),
80                context: context.clone(),
81                prove_shard_tx: prove_shard_tx.clone(),
82                deferred_marker_tx: deferred_marker_tx.clone(),
83            })
84            .collect();
85        let input_buffer_size = self.send_splice_input_buffer_size;
86        SendSpliceEngine::new(workers, input_buffer_size)
87    }
88}
89
90impl<A, W> AsyncWorker<SplicingTask<W>, Result<(), ExecutionError>> for SplicingWorker<A, W>
91where
92    A: ArtifactClient,
93    W: WorkerClient,
94{
95    async fn call(&self, input: SplicingTask<W>) -> Result<(), ExecutionError> {
96        let SplicingTask {
97            program,
98            chunk,
99            all_touched_addresses,
100            final_vm_state,
101            elf_artifact,
102            common_input_artifact,
103            num_deferred_proofs,
104            prove_shard_tx,
105            context,
106            deferred_marker_tx,
107            opts,
108        } = input;
109        let (splicing_tx, mut splicing_rx) = mpsc::channel::<SendSpliceTask>(2);
110
111        let mut join_set = JoinSet::<Result<(), ExecutionError>>::new();
112        // Spawn the task to spawn the prove shard tasks.
113        let (send_handle_tx, mut send_handle_rx) = mpsc::unbounded_channel();
114        join_set.spawn(
115            {
116                let send_splice_engine = self.initialize_send_splice_engine(
117                    elf_artifact.clone(),
118                    common_input_artifact.clone(),
119                    context.clone(),
120                    prove_shard_tx.clone(),
121                    deferred_marker_tx,
122                );
123                async move {
124                    while let Some(task) = splicing_rx.recv().await {
125                        let handle = send_splice_engine
126                            .submit(task)
127                            .instrument(tracing::debug_span!("send splice"))
128                            .await
129                            .map_err(|_| {
130                                ExecutionError::Other(
131                                    "failed to submit send splice task".to_string(),
132                                )
133                            })?;
134                        send_handle_tx.send(handle).map_err(|e| {
135                            ExecutionError::Other(format!("error sending to send handle tx: {}", e))
136                        })?;
137                    }
138                    Ok(())
139                }
140            }
141            .instrument(tracing::debug_span!("get splices to serialize")),
142        );
143
144        // This task waits for prove shard tasks to be sent.
145        join_set.spawn(
146            {
147                async move {
148                    let mut handles = FuturesUnordered::new();
149                    loop {
150                        tokio::select! {
151                            Some(handle) = send_handle_rx.recv() => {
152                                handles.push(handle);
153                            }
154                            Some(result) = handles.next() => {
155                                result.map_err(|e| ExecutionError::Other(format!("failed to join send splice task: {}", e)))??;
156                            }
157                            else => {
158                                break;
159                            }
160                        }
161                    }
162                    Ok::<_, ExecutionError>(())
163                }
164            }
165            .instrument(tracing::debug_span!("spawn prove shard tasks")),
166        );
167
168        let common_prover_input = self
169            .artifact_client
170            .download::<CommonProverInput>(&common_input_artifact)
171            .await
172            .map_err(|e| {
173                ExecutionError::Other(format!("error downloading common prover input: {}", e))
174            })?;
175
176        // Spawn the task that splices the trace.
177        let span = tracing::debug_span!("splicing trace chunk");
178        join_set.spawn_blocking(
179            move || {
180            let _guard = span.enter();
181            let mut touched_addresses = CompressedMemory::new();
182            let mut vm = SplicingVM::new(&chunk, program.clone(), &mut touched_addresses, common_prover_input.nonce, opts);
183
184            let start_num_mem_reads = chunk.num_mem_reads();
185            let start_clk = vm.core.clk();
186            let mut end_clk : u64;
187            let mut last_splice = SplicedMinimalTrace::new_full_trace(chunk.clone());
188                let mut boundary = ShardBoundary {
189                    timestamp: start_clk,
190                    initialized_address: 0,
191                    finalized_address: 0,
192                    initialized_page_index: 0,
193                    finalized_page_index: 0,
194                    deferred_proof: num_deferred_proofs as u64,
195                };
196            loop {
197                tracing::debug!("starting new shard at clk: {} at pc: {}", vm.core.clk(), vm.core.pc());
198                match vm.execute()? {
199                    CycleResult::ShardBoundary => {
200                        // Note: Chunk implentations should always be cheap to clone.
201                        if let Some(spliced) = vm.splice(chunk.clone()) {
202                            tracing::debug!(global_clk = vm.core.global_clk(), pc = vm.core.pc(), num_mem_reads_left = vm.core.mem_reads.len(), clk = vm.core.clk(), "shard boundary");
203                            // Get the end boundary of the shard.
204                            end_clk = vm.core.clk();
205                            let end = ShardBoundary {
206                                timestamp: end_clk,
207                                initialized_address: 0,
208                                finalized_address: 0,
209                                initialized_page_index: 0,
210                                finalized_page_index: 0,
211                                deferred_proof: num_deferred_proofs as u64,
212                            };
213                            // Get the range of the shard.
214                            let range = (boundary..end).into();
215                            // Update the boundary to the end of the shard.
216                            boundary = end;
217
218                            // Set the last splice clk.
219                            last_splice.set_last_clk(vm.core.clk());
220                            last_splice.set_last_mem_reads_idx(
221                                start_num_mem_reads as usize - vm.core.mem_reads.len(),
222                            );
223                            let splice_to_send = std::mem::replace(&mut last_splice, spliced);
224                            tracing::debug!(global_clk = vm.core.global_clk(), "sending spliced trace to splicing tx");
225                            splicing_tx.blocking_send(SendSpliceTask { chunk: splice_to_send, range })
226                                .map_err(|e| ExecutionError::Other(format!("error sending to splicing tx: {}", e)))?;
227                            tracing::debug!(global_clk = vm.core.global_clk(), "spliced trace sent to splicing tx");
228                        } else {
229                            tracing::debug!(global_clk = vm.core.global_clk(), pc = vm.core.pc(), num_mem_reads_left = vm.core.mem_reads.len(), "trace ended");
230                            // Get the end boundary of the shard.
231                            end_clk = vm.core.clk();
232                            let end = ShardBoundary {
233                                timestamp: end_clk,
234                                initialized_address: 0,
235                                finalized_address: 0,
236                                initialized_page_index: 0,
237                                finalized_page_index: 0,
238                                deferred_proof: num_deferred_proofs as u64,
239                            };
240                            // Get the range of the shard.
241                            let range = (boundary..end).into();
242
243                            last_splice.set_last_clk(vm.core.clk());
244                            last_splice.set_last_mem_reads_idx(
245                                start_num_mem_reads as usize - vm.core.mem_reads.len(),
246                            );
247                            tracing::debug!(global_clk = vm.core.global_clk(), "sending last splice to splicing tx");
248                            splicing_tx.blocking_send(SendSpliceTask { chunk: last_splice, range })
249                                .map_err(|e| ExecutionError::Other(format!("error sending to splicing tx: {}", e)))?;
250                            tracing::debug!(global_clk = vm.core.global_clk(), "last splice sent to splicing tx");
251                            break;
252                        }
253                    }
254                    CycleResult::Done(true) => {
255                        tracing::debug!(global_clk = vm.core.global_clk(), "done cycle result");
256                        last_splice.set_last_clk(vm.core.clk());
257                        last_splice.set_last_mem_reads_idx(chunk.num_mem_reads() as usize);
258
259                        // Get the end boundary of the shard.
260                        end_clk = vm.core.clk();
261                        let end = ShardBoundary {
262                            timestamp: end_clk,
263                            initialized_address: 0,
264                            finalized_address: 0,
265                            initialized_page_index: 0,
266                            finalized_page_index: 0,
267                            deferred_proof: num_deferred_proofs as u64,
268                        };
269                        // Get the range of the shard.
270                        let range = (boundary..end).into();
271
272                        // Get the last state of the vm execution and set the global final vm state to
273                        // this value.
274                        let final_state = FinalVmState::new(&vm.core);
275                        final_vm_state.set(final_state).map_err(|e| ExecutionError::Other(e.to_string()))?;
276
277                        tracing::debug!(global_clk = vm.core.global_clk(), "sending last splice to splicing tx");
278                        // Send the last splice.
279                        splicing_tx.blocking_send(SendSpliceTask { chunk: last_splice, range })
280                            .map_err(|e| ExecutionError::Other(format!("error sending to splicing tx: {}", e)))?;
281                        tracing::debug!(global_clk = vm.core.global_clk(), "last splice sent to splicing tx");
282                        break;
283                    }
284                    CycleResult::Done(false) | CycleResult::TraceEnd => {
285                        // Note: Trace ends get mapped to shard boundaries.
286                        unreachable!("The executor should never return an imcomplete program without a shard boundary");
287                    }
288                }
289            }
290            // Append the touched addresses from this chunk to the globally tracked touched addresses.
291            tracing::debug_span!("collecting touched addresses and sending to global memory").in_scope(|| {
292            all_touched_addresses.blocking_extend(start_clk, end_clk, touched_addresses.is_set())
293                .map_err(|e| ExecutionError::Other(e.to_string()))})?;
294            Ok(())
295           });
296
297        // Wait for the tasks to finish and collect the errors.
298        while let Some(result) = join_set.join_next().await {
299            result
300                .map_err(|e| ExecutionError::Other(format!("splicer task panicked: {}", e)))??;
301        }
302
303        Ok(())
304    }
305}
306
307pub struct SendSpliceTask {
308    pub chunk: SplicedMinimalTrace<TraceChunkRaw>,
309    pub range: ShardRange,
310}
311
312struct SendSpliceWorker<A, W: WorkerClient> {
313    artifact_client: A,
314    worker_client: W,
315    context: TaskContext,
316    elf_artifact: Artifact,
317    common_input_artifact: Artifact,
318    prove_shard_tx: MessageSender<W, ProofData>,
319    deferred_marker_tx: mpsc::UnboundedSender<DeferredMessage>,
320}
321
322impl<A, W> AsyncWorker<SendSpliceTask, Result<(), ExecutionError>> for SendSpliceWorker<A, W>
323where
324    A: ArtifactClient,
325    W: WorkerClient,
326{
327    async fn call(&self, input: SendSpliceTask) -> Result<(), ExecutionError> {
328        let SendSpliceTask { chunk, range } = input;
329        let chunk_bytes = await_blocking(|| bincode::serialize(&chunk))
330            .await
331            .map_err(|_| ExecutionError::Other("chunk serialization failed".to_string()))?
332            .map_err(|e| ExecutionError::Other(e.to_string()))?;
333        let data = TraceData::Core(chunk_bytes);
334
335        let SpawnProveOutput { deferred_message, proof_data } = create_core_proving_task(
336            self.elf_artifact.clone(),
337            self.common_input_artifact.clone(),
338            self.context.clone(),
339            range,
340            data,
341            self.worker_client.clone(),
342            self.artifact_client.clone(),
343        )
344        .await
345        .map_err(|e| ExecutionError::Other(format!("error in create_core_proving_task: {}", e)))?;
346
347        self.prove_shard_tx
348            .send(proof_data)
349            .await
350            .map_err(|e| ExecutionError::Other(format!("error in send proof data: {}", e)))?;
351        if let Some(deferred_message) = deferred_message {
352            self.deferred_marker_tx.send(deferred_message).map_err(|e| {
353                ExecutionError::Other(format!("error in send deferred message: {}", e))
354            })?;
355        }
356        Ok(())
357    }
358}
359
360type SendSpliceEngine<A, W> =
361    AsyncEngine<SendSpliceTask, Result<(), ExecutionError>, SendSpliceWorker<A, W>>;