// sp1_prover/worker/controller/precompiles.rs

1use std::sync::Arc;
2
3use futures::StreamExt;
4use hashbrown::HashMap;
5use serde::{Deserialize, Serialize};
6use sp1_core_executor::{ExecutionRecord, Program, SP1CoreOpts, SplitOpts, SyscallCode};
7use sp1_hypercube::air::ShardRange;
8use sp1_prover_types::{await_scoped_vec, Artifact, ArtifactClient, ArtifactType, TaskStatus};
9use tokio::{sync::mpsc, task::JoinSet};
10use tracing::Instrument;
11
12use crate::worker::{
13    controller::create_core_proving_task, MessageSender, ProofData, SpawnProveOutput, TaskContext,
14    TaskError, TaskId, TraceData, WorkerClient,
15};
16
/// String used as key for `add_ref` to ensure precompile artifacts are not cleaned up before
/// they are fully split into multiple shards. Added in `DeferredEvents::append` and removed in
/// `DeferredEvents::split` once an artifact has been completely consumed.
const CONTROLLER_PRECOMPILE_ARTIFACT_REF: &str = "_controller";
20
/// An artifact of precompile events, and the half-open range `[start_idx, end_idx)` of event
/// indices within it that this slice refers to.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PrecompileArtifactSlice {
    // The uploaded artifact containing a chunk of precompile events.
    pub artifact: Artifact,
    // Index of the first event covered by this slice (inclusive; initially 0).
    pub start_idx: usize,
    // Index one past the last event covered by this slice (exclusive).
    pub end_idx: usize,
}
28
/// A lightweight container for the precompile events in a shard.
///
/// Rather than actually holding all of the events, the events are represented as `Artifact`s with
/// start and end indices (see [`PrecompileArtifactSlice`]), keyed by the syscall code that
/// produced them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeferredEvents(pub HashMap<SyscallCode, Vec<PrecompileArtifactSlice>>);
35
36impl DeferredEvents {
37    /// Defer all events in an ExecutionRecord by uploading each precompile in chunks.
38    pub async fn defer_record<A: ArtifactClient>(
39        record: ExecutionRecord,
40        client: &A,
41        split_opts: SplitOpts,
42    ) -> Result<DeferredEvents, TaskError> {
43        // Move all synchronous work (iteration, chunking) into spawn_blocking
44        // to avoid blocking the async runtime.
45        let chunk_data = tokio::task::spawn_blocking(move || {
46            let mut chunk_data = Vec::new();
47            for (code, events) in record.precompile_events.events.iter() {
48                let threshold = split_opts.syscall_threshold[*code];
49                for chunk in events.chunks(threshold) {
50                    chunk_data.push((*code, chunk.to_vec()));
51                }
52            }
53            chunk_data
54        })
55        .await
56        .map_err(|e| TaskError::Fatal(e.into()))?;
57
58        // Create all artifacts in batch (this is cheap - just generates IDs)
59        let artifacts =
60            client.create_artifacts(chunk_data.len()).map_err(TaskError::Fatal)?.to_vec();
61
62        // Build futures with pre-created artifacts and run uploads in parallel
63        let futures = chunk_data
64            .into_iter()
65            .zip(artifacts.into_iter())
66            .map(|((code, chunk), artifact)| {
67                let client = client.clone();
68                async move {
69                    client.upload(&artifact, &chunk).await.unwrap();
70                    (code, artifact, chunk.len())
71                }
72            })
73            .collect::<Vec<_>>();
74
75        let res =
76            await_scoped_vec(futures).await.map_err(|e| TaskError::Fatal(anyhow::anyhow!(e)))?;
77
78        let mut deferred: HashMap<SyscallCode, Vec<PrecompileArtifactSlice>> = HashMap::new();
79        for (code, artifact, count) in res {
80            deferred.entry(code).or_default().push(PrecompileArtifactSlice {
81                artifact,
82                start_idx: 0,
83                end_idx: count,
84            });
85        }
86        Ok(DeferredEvents(deferred))
87    }
88
89    /// Create an empty DeferredEvents.
90    pub fn empty() -> Self {
91        Self(HashMap::new())
92    }
93
94    /// Append the events from another DeferredEvents to self. Analogous to
95    /// `ExecutionRecord::append`.
96    pub async fn append(&mut self, other: DeferredEvents, client: &impl ArtifactClient) {
97        for (code, events) in other.0 {
98            // Add task references for artifacts so they are not cleaned up before they are fully
99            // split.
100            for PrecompileArtifactSlice { artifact, .. } in &events {
101                if let Err(e) = client.add_ref(artifact, CONTROLLER_PRECOMPILE_ARTIFACT_REF).await {
102                    tracing::error!("Failed to add ref to artifact {:?}: {:?}", artifact, e);
103                }
104            }
105            self.0.entry(code).or_default().extend(events);
106        }
107    }
108
109    /// Split the DeferredEvents into multiple TraceData. Similar to `ExecutionRecord::split`.
110    pub async fn split(
111        &mut self,
112        last: bool,
113        opts: SplitOpts,
114        client: &impl ArtifactClient,
115    ) -> Vec<TraceData> {
116        let mut shards = Vec::new();
117        let keys = self.0.keys().cloned().collect::<Vec<_>>();
118        for code in keys {
119            let threshold = opts.syscall_threshold[code];
120            // self.0[code] contains uploaded artifacts with start and end indices. start is
121            // initially 0. Create shards of precompiles from self.0[code] up to
122            // threshold, then update new [start, end) indices for future splits. If
123            // last is true, don't leave any remainder.
124            loop {
125                let mut count = 0;
126                // Loop through until we've found enough precompiles, and remove from self.0[code].
127                // `index` will be set such that artifacts [0, index) will be made into a shard.
128                let mut index = 0;
129                for (i, artifact_slice) in self.0[&code].iter().enumerate() {
130                    let PrecompileArtifactSlice { start_idx, end_idx, .. } = artifact_slice;
131                    count += end_idx - start_idx;
132                    // Break if we've found enough or it's the last Artifact and `last` is true.
133                    if count >= threshold || (last && i == self.0[&code].len() - 1) {
134                        index = i + 1;
135                        break;
136                    }
137                }
138                // If not enough was found, break.
139                if index == 0 {
140                    break;
141                }
142                // Otherwise remove the artifacts and handle remainder of last artifact if there is
143                // any.
144                let mut artifacts =
145                    self.0.get_mut(&code).unwrap().drain(..index).collect::<Vec<_>>();
146                // For each artifact, add refs for the range needed in prove_shard, and then remove
147                // the controller ref if it's been fully split.
148                for (i, slice) in artifacts.iter().enumerate() {
149                    let PrecompileArtifactSlice { artifact, start_idx, end_idx } = slice;
150                    if let Err(e) =
151                        client.add_ref(artifact, &format!("{:?}_{:?}", start_idx, end_idx)).await
152                    {
153                        tracing::error!("Failed to add ref to artifact {}: {:?}", artifact, e);
154                    }
155                    // If there's a remainder, don't remove the controller ref yet.
156                    if i == artifacts.len() - 1 && count > threshold {
157                        break;
158                    }
159                    if let Err(e) = client
160                        .remove_ref(
161                            artifact,
162                            ArtifactType::UnspecifiedArtifactType,
163                            CONTROLLER_PRECOMPILE_ARTIFACT_REF,
164                        )
165                        .await
166                    {
167                        tracing::error!("Failed to remove ref to artifact {}: {:?}", artifact, e);
168                    }
169                }
170                // If there's extra in the last artifact, truncate it and leave it in the front of
171                // self.0[code].
172                if count > threshold {
173                    let mut new_range = artifacts.last().cloned().unwrap();
174                    new_range.start_idx = new_range.end_idx - (count - threshold);
175                    artifacts[index - 1].end_idx = new_range.start_idx;
176                    self.0.get_mut(&code).unwrap().insert(0, new_range);
177                }
178                shards.push(TraceData::Precompile(artifacts, code));
179            }
180        }
181        shards
182    }
183}
184
/// Message announcing that task `task_id` will produce a serialized `DeferredEvents` in
/// `record`; consumed by `PrecompileHandler::emit_precompile_shards`, which downloads `record`
/// once the task succeeds.
pub struct DeferredMessage {
    // Id of the task whose successful completion gates downloading `record`.
    pub task_id: TaskId,
    // Artifact expected to contain the task's `DeferredEvents`.
    pub record: Artifact,
}
189
190pub fn precompile_channel(
191    program: &Program,
192    opts: &SP1CoreOpts,
193) -> (mpsc::UnboundedSender<DeferredMessage>, PrecompileHandler) {
194    let split_opts = SplitOpts::new(opts, program.instructions.len(), false);
195    let (deferred_marker_tx, deferred_marker_rx) = mpsc::unbounded_channel();
196    (deferred_marker_tx, PrecompileHandler { split_opts, deferred_marker_rx })
197}
198
/// Consumes [`DeferredMessage`]s and turns the accumulated precompile events into proving
/// shards; see [`PrecompileHandler::emit_precompile_shards`].
pub struct PrecompileHandler {
    // Per-syscall thresholds controlling how many events go into one shard.
    split_opts: SplitOpts,
    // Receiving end of the channel created by `precompile_channel`.
    deferred_marker_rx: mpsc::UnboundedReceiver<DeferredMessage>,
}
203
impl PrecompileHandler {
    /// Drive precompile proving for one proof: listen for `DeferredMessage`s, wait for the
    /// announcing tasks to succeed, accumulate their deferred events, and split them into
    /// precompile shards that are sent as core proving tasks via `prove_shard_tx`.
    ///
    /// Two tasks run concurrently in a `JoinSet`:
    /// 1. A listener that records each announced artifact in `task_data_map` and subscribes to
    ///    the announcing task's status.
    /// 2. A sender that reacts to task-status events: it downloads the deferred events, appends
    ///    them to an accumulator, splits off any threshold-full shards, and — once the event
    ///    stream ends — flushes the remainder with `split(true, ..)`.
    ///
    /// Returns `Ok(())` once both tasks finish, or the first fatal error (including a panic in
    /// either task).
    #[allow(clippy::too_many_arguments)]
    pub(super) async fn emit_precompile_shards<A: ArtifactClient, W: WorkerClient>(
        self,
        elf_artifact: Artifact,
        common_input_artifact: Artifact,
        prove_shard_tx: MessageSender<W, ProofData>,
        artifact_client: A,
        worker_client: W,
        context: TaskContext,
    ) -> Result<(), TaskError> {
        let precompile_range = ShardRange::precompile();
        let mut join_set = JoinSet::new();
        // Maps an announced task id to the artifact expected to hold its `DeferredEvents`.
        let task_data_map = Arc::new(tokio::sync::Mutex::new(HashMap::new()));

        let PrecompileHandler { split_opts, mut deferred_marker_rx } = self;

        // This subscriber monitors for deferred marker task completion
        let (subscriber, mut event_stream) =
            worker_client.subscriber(context.proof_id.clone()).await?.stream();
        join_set.spawn({
            let task_data_map = task_data_map.clone();
            async move {
                // Register every announced artifact, then subscribe to its task so the sender
                // task below is notified when the deferred events are ready to download.
                while let Some(deferred_message) = deferred_marker_rx.recv().await {
                    tracing::debug!(
                        "received deferred message with task id {:?}",
                        deferred_message.task_id
                    );
                    let DeferredMessage { task_id, record: deferred_events } = deferred_message;
                    task_data_map.lock().await.insert(task_id.clone(), deferred_events);
                    subscriber.subscribe(task_id.clone()).map_err(|e| {
                        TaskError::Fatal(anyhow::anyhow!(
                            "error subscribing to task {}: {}",
                            task_id,
                            e
                        ))
                    })?;
                }
                // Channel closed: all senders are gone, no more markers will arrive. Dropping
                // `subscriber` here presumably terminates `event_stream` in the sender task —
                // NOTE(review): confirm against the subscriber implementation.
                Ok::<_, TaskError>(())
            }
            .instrument(tracing::debug_span!("deferred listener"))
        });

        join_set.spawn({
            let worker_client = worker_client.clone();
            let artifact_client = artifact_client.clone();
            async move {
                let mut deferred_accumulator = DeferredEvents::empty();
                while let Some((task_id, status)) = event_stream.next().await {
                    tracing::debug!(
                        task_id = task_id.to_string(),
                        "received deferred marker task status: {:?}",
                        status
                    );
                    // Any non-success status is fatal. This assumes the stream only delivers
                    // terminal statuses — NOTE(review): confirm no intermediate states arrive.
                    if status != TaskStatus::Succeeded {
                        return Err(TaskError::Fatal(anyhow::anyhow!(
                            "deferred marker task failed: {}",
                            task_id
                        )));
                    }
                    let deferred_events_artifact = task_data_map.lock().await.remove(&task_id);
                    if let Some(deferred_events_artifact) = deferred_events_artifact {
                        let deferred_events = artifact_client
                            .download::<DeferredEvents>(&deferred_events_artifact)
                            .await;
                        if deferred_events.is_err() {
                            tracing::error!(
                                "failed to download deferred events artifact: {:?}",
                                deferred_events_artifact
                            );
                        }
                        // TODO: figure out how to return this as an error while still
                        // being able to run pure execution without proving.
                        let deferred_events =
                            deferred_events.unwrap_or_else(|_| DeferredEvents::empty());

                        // Accumulate and split off any shards that reached the per-syscall
                        // threshold; `last == false` keeps remainders for later events.
                        deferred_accumulator.append(deferred_events, &artifact_client).await;
                        let new_shards =
                            deferred_accumulator.split(false, split_opts, &artifact_client).await;

                        for shard in new_shards {
                            let SpawnProveOutput { deferred_message, proof_data } =
                                create_core_proving_task(
                                    elf_artifact.clone(),
                                    common_input_artifact.clone(),
                                    context.clone(),
                                    precompile_range,
                                    shard,
                                    worker_client.clone(),
                                    artifact_client.clone(),
                                )
                                .await
                                .map_err(|e| TaskError::Fatal(e.into()))?;

                            // A precompile shard is not expected to produce a further deferred
                            // message; treat it as a fatal invariant violation.
                            if deferred_message.is_some() {
                                return Err(TaskError::Fatal(anyhow::anyhow!(
                                    "deferred message is not none",
                                )));
                            }
                            prove_shard_tx.send(proof_data).await.map_err(|e| {
                                TaskError::Fatal(anyhow::anyhow!(
                                    "error sending to proving tx: {}",
                                    e
                                ))
                            })?;
                        }
                    } else {
                        tracing::debug!(
                            "deferred events artifact not found for task id: {}",
                            task_id
                        );
                    }
                }
                // The event stream ended: flush everything left in the accumulator, with
                // `last == true` so no remainder is held back.
                let final_shards = deferred_accumulator
                    .split(true, split_opts, &artifact_client)
                    .instrument(tracing::debug_span!("split last"))
                    .await;
                for shard in final_shards {
                    let SpawnProveOutput { deferred_message, proof_data } =
                        create_core_proving_task(
                            elf_artifact.clone(),
                            common_input_artifact.clone(),
                            context.clone(),
                            precompile_range,
                            shard,
                            worker_client.clone(),
                            artifact_client.clone(),
                        )
                        .await
                        .map_err(|e| TaskError::Fatal(e.into()))?;

                    debug_assert!(deferred_message.is_none());
                    prove_shard_tx.send(proof_data).await.map_err(|e| {
                        TaskError::Fatal(anyhow::anyhow!("error sending to proving tx: {}", e))
                    })?;
                }
                tracing::debug!("deferred listener task finished");
                Ok::<_, TaskError>(())
            }
            .instrument(tracing::debug_span!("deferred sender"))
        });

        // Propagate the first failure from either task; `??` unwraps both the join result
        // (panic / cancellation) and the task's own `Result`.
        while let Some(result) = join_set.join_next().await {
            result.map_err(|e| {
                TaskError::Fatal(anyhow::anyhow!("deferred listener task panicked: {}", e))
            })??;
        }
        Ok::<(), TaskError>(())
    }
}
353}