sp1_prover/worker/controller/
precompiles.rs1use std::sync::Arc;
2
3use futures::StreamExt;
4use hashbrown::HashMap;
5use serde::{Deserialize, Serialize};
6use sp1_core_executor::{ExecutionRecord, Program, SP1CoreOpts, SplitOpts, SyscallCode};
7use sp1_hypercube::air::ShardRange;
8use sp1_prover_types::{await_scoped_vec, Artifact, ArtifactClient, ArtifactType, TaskStatus};
9use tokio::{sync::mpsc, task::JoinSet};
10use tracing::Instrument;
11
12use crate::worker::{
13 controller::create_core_proving_task, MessageSender, ProofData, SpawnProveOutput, TaskContext,
14 TaskError, TaskId, TraceData, WorkerClient,
15};
16
/// Reference name the controller holds on each artifact whose events are waiting in a
/// [`DeferredEvents`] accumulator.
///
/// [`DeferredEvents::append`] adds this reference; [`DeferredEvents::split`] removes it once a
/// slice of the artifact is handed out and no part of that artifact is left pending.
const CONTROLLER_PRECOMPILE_ARTIFACT_REF: &str = "_controller";
20
/// A half-open range `[start_idx, end_idx)` of precompile events stored in `artifact`.
///
/// Each artifact is created by [`DeferredEvents::defer_record`] from one chunk of events for a
/// single syscall code. [`DeferredEvents::split`] may later cut a slice so that several slices
/// cover disjoint ranges of the same artifact.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PrecompileArtifactSlice {
    /// Artifact the chunk of events was uploaded to.
    pub artifact: Artifact,
    /// Index of the first event covered by this slice.
    pub start_idx: usize,
    /// One past the index of the last event covered by this slice.
    pub end_idx: usize,
}
28
/// Precompile events deferred from execution, grouped by syscall code.
///
/// Each syscall code maps to an ordered queue of [`PrecompileArtifactSlice`]s that
/// [`DeferredEvents::split`] packs into precompile shards.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeferredEvents(pub HashMap<SyscallCode, Vec<PrecompileArtifactSlice>>);
35
36impl DeferredEvents {
37 pub async fn defer_record<A: ArtifactClient>(
39 record: ExecutionRecord,
40 client: &A,
41 split_opts: SplitOpts,
42 ) -> Result<DeferredEvents, TaskError> {
43 let chunk_data = tokio::task::spawn_blocking(move || {
46 let mut chunk_data = Vec::new();
47 for (code, events) in record.precompile_events.events.iter() {
48 let threshold = split_opts.syscall_threshold[*code];
49 for chunk in events.chunks(threshold) {
50 chunk_data.push((*code, chunk.to_vec()));
51 }
52 }
53 chunk_data
54 })
55 .await
56 .map_err(|e| TaskError::Fatal(e.into()))?;
57
58 let artifacts =
60 client.create_artifacts(chunk_data.len()).map_err(TaskError::Fatal)?.to_vec();
61
62 let futures = chunk_data
64 .into_iter()
65 .zip(artifacts.into_iter())
66 .map(|((code, chunk), artifact)| {
67 let client = client.clone();
68 async move {
69 client.upload(&artifact, &chunk).await.unwrap();
70 (code, artifact, chunk.len())
71 }
72 })
73 .collect::<Vec<_>>();
74
75 let res =
76 await_scoped_vec(futures).await.map_err(|e| TaskError::Fatal(anyhow::anyhow!(e)))?;
77
78 let mut deferred: HashMap<SyscallCode, Vec<PrecompileArtifactSlice>> = HashMap::new();
79 for (code, artifact, count) in res {
80 deferred.entry(code).or_default().push(PrecompileArtifactSlice {
81 artifact,
82 start_idx: 0,
83 end_idx: count,
84 });
85 }
86 Ok(DeferredEvents(deferred))
87 }
88
89 pub fn empty() -> Self {
91 Self(HashMap::new())
92 }
93
94 pub async fn append(&mut self, other: DeferredEvents, client: &impl ArtifactClient) {
97 for (code, events) in other.0 {
98 for PrecompileArtifactSlice { artifact, .. } in &events {
101 if let Err(e) = client.add_ref(artifact, CONTROLLER_PRECOMPILE_ARTIFACT_REF).await {
102 tracing::error!("Failed to add ref to artifact {:?}: {:?}", artifact, e);
103 }
104 }
105 self.0.entry(code).or_default().extend(events);
106 }
107 }
108
109 pub async fn split(
111 &mut self,
112 last: bool,
113 opts: SplitOpts,
114 client: &impl ArtifactClient,
115 ) -> Vec<TraceData> {
116 let mut shards = Vec::new();
117 let keys = self.0.keys().cloned().collect::<Vec<_>>();
118 for code in keys {
119 let threshold = opts.syscall_threshold[code];
120 loop {
125 let mut count = 0;
126 let mut index = 0;
129 for (i, artifact_slice) in self.0[&code].iter().enumerate() {
130 let PrecompileArtifactSlice { start_idx, end_idx, .. } = artifact_slice;
131 count += end_idx - start_idx;
132 if count >= threshold || (last && i == self.0[&code].len() - 1) {
134 index = i + 1;
135 break;
136 }
137 }
138 if index == 0 {
140 break;
141 }
142 let mut artifacts =
145 self.0.get_mut(&code).unwrap().drain(..index).collect::<Vec<_>>();
146 for (i, slice) in artifacts.iter().enumerate() {
149 let PrecompileArtifactSlice { artifact, start_idx, end_idx } = slice;
150 if let Err(e) =
151 client.add_ref(artifact, &format!("{:?}_{:?}", start_idx, end_idx)).await
152 {
153 tracing::error!("Failed to add ref to artifact {}: {:?}", artifact, e);
154 }
155 if i == artifacts.len() - 1 && count > threshold {
157 break;
158 }
159 if let Err(e) = client
160 .remove_ref(
161 artifact,
162 ArtifactType::UnspecifiedArtifactType,
163 CONTROLLER_PRECOMPILE_ARTIFACT_REF,
164 )
165 .await
166 {
167 tracing::error!("Failed to remove ref to artifact {}: {:?}", artifact, e);
168 }
169 }
170 if count > threshold {
173 let mut new_range = artifacts.last().cloned().unwrap();
174 new_range.start_idx = new_range.end_idx - (count - threshold);
175 artifacts[index - 1].end_idx = new_range.start_idx;
176 self.0.get_mut(&code).unwrap().insert(0, new_range);
177 }
178 shards.push(TraceData::Precompile(artifacts, code));
179 }
180 }
181 shards
182 }
183}
184
/// Announces a task whose deferred precompile events should be turned into precompile shards.
pub struct DeferredMessage {
    /// Task whose status is watched; `record` is downloaded once this task succeeds.
    pub task_id: TaskId,
    /// Artifact from which the task's [`DeferredEvents`] are downloaded.
    pub record: Artifact,
}
189
190pub fn precompile_channel(
191 program: &Program,
192 opts: &SP1CoreOpts,
193) -> (mpsc::UnboundedSender<DeferredMessage>, PrecompileHandler) {
194 let split_opts = SplitOpts::new(opts, program.instructions.len(), false);
195 let (deferred_marker_tx, deferred_marker_rx) = mpsc::unbounded_channel();
196 (deferred_marker_tx, PrecompileHandler { split_opts, deferred_marker_rx })
197}
198
/// Receiving side of [`precompile_channel`]. It consumes [`DeferredMessage`]s and emits
/// precompile proving tasks.
pub struct PrecompileHandler {
    /// Split options used to size precompile shards.
    split_opts: SplitOpts,
    /// Receives the [`DeferredMessage`]s sent through the channel's sender half.
    deferred_marker_rx: mpsc::UnboundedReceiver<DeferredMessage>,
}
203
204impl PrecompileHandler {
205 #[allow(clippy::too_many_arguments)]
206 pub(super) async fn emit_precompile_shards<A: ArtifactClient, W: WorkerClient>(
207 self,
208 elf_artifact: Artifact,
209 common_input_artifact: Artifact,
210 prove_shard_tx: MessageSender<W, ProofData>,
211 artifact_client: A,
212 worker_client: W,
213 context: TaskContext,
214 ) -> Result<(), TaskError> {
215 let precompile_range = ShardRange::precompile();
216 let mut join_set = JoinSet::new();
217 let task_data_map = Arc::new(tokio::sync::Mutex::new(HashMap::new()));
218
219 let PrecompileHandler { split_opts, mut deferred_marker_rx } = self;
220
221 let (subscriber, mut event_stream) =
223 worker_client.subscriber(context.proof_id.clone()).await?.stream();
224 join_set.spawn({
225 let task_data_map = task_data_map.clone();
226 async move {
227 while let Some(deferred_message) = deferred_marker_rx.recv().await {
228 tracing::debug!(
229 "received deferred message with task id {:?}",
230 deferred_message.task_id
231 );
232 let DeferredMessage { task_id, record: deferred_events } = deferred_message;
233 task_data_map.lock().await.insert(task_id.clone(), deferred_events);
234 subscriber.subscribe(task_id.clone()).map_err(|e| {
235 TaskError::Fatal(anyhow::anyhow!(
236 "error subscribing to task {}: {}",
237 task_id,
238 e
239 ))
240 })?;
241 }
242 Ok::<_, TaskError>(())
243 }
244 .instrument(tracing::debug_span!("deferred listener"))
245 });
246
247 join_set.spawn({
248 let worker_client = worker_client.clone();
249 let artifact_client = artifact_client.clone();
250 async move {
251 let mut deferred_accumulator = DeferredEvents::empty();
252 while let Some((task_id, status)) = event_stream.next().await {
253 tracing::debug!(
254 task_id = task_id.to_string(),
255 "received deferred marker task status: {:?}",
256 status
257 );
258 if status != TaskStatus::Succeeded {
259 return Err(TaskError::Fatal(anyhow::anyhow!(
260 "deferred marker task failed: {}",
261 task_id
262 )));
263 }
264 let deferred_events_artifact = task_data_map.lock().await.remove(&task_id);
265 if let Some(deferred_events_artifact) = deferred_events_artifact {
266 let deferred_events = artifact_client
267 .download::<DeferredEvents>(&deferred_events_artifact)
268 .await;
269 if deferred_events.is_err() {
270 tracing::error!(
271 "failed to download deferred events artifact: {:?}",
272 deferred_events_artifact
273 );
274 }
275 let deferred_events =
278 deferred_events.unwrap_or_else(|_| DeferredEvents::empty());
279
280 deferred_accumulator.append(deferred_events, &artifact_client).await;
281 let new_shards =
282 deferred_accumulator.split(false, split_opts, &artifact_client).await;
283
284 for shard in new_shards {
285 let SpawnProveOutput { deferred_message, proof_data } =
286 create_core_proving_task(
287 elf_artifact.clone(),
288 common_input_artifact.clone(),
289 context.clone(),
290 precompile_range,
291 shard,
292 worker_client.clone(),
293 artifact_client.clone(),
294 )
295 .await
296 .map_err(|e| TaskError::Fatal(e.into()))?;
297
298 if deferred_message.is_some() {
299 return Err(TaskError::Fatal(anyhow::anyhow!(
300 "deferred message is not none",
301 )));
302 }
303 prove_shard_tx.send(proof_data).await.map_err(|e| {
304 TaskError::Fatal(anyhow::anyhow!(
305 "error sending to proving tx: {}",
306 e
307 ))
308 })?;
309 }
310 } else {
311 tracing::debug!(
312 "deferred events artifact not found for task id: {}",
313 task_id
314 );
315 }
316 }
317 let final_shards = deferred_accumulator
318 .split(true, split_opts, &artifact_client)
319 .instrument(tracing::debug_span!("split last"))
320 .await;
321 for shard in final_shards {
322 let SpawnProveOutput { deferred_message, proof_data } =
323 create_core_proving_task(
324 elf_artifact.clone(),
325 common_input_artifact.clone(),
326 context.clone(),
327 precompile_range,
328 shard,
329 worker_client.clone(),
330 artifact_client.clone(),
331 )
332 .await
333 .map_err(|e| TaskError::Fatal(e.into()))?;
334
335 debug_assert!(deferred_message.is_none());
336 prove_shard_tx.send(proof_data).await.map_err(|e| {
337 TaskError::Fatal(anyhow::anyhow!("error sending to proving tx: {}", e))
338 })?;
339 }
340 tracing::debug!("deferred listener task finished");
341 Ok::<_, TaskError>(())
342 }
343 .instrument(tracing::debug_span!("deferred sender"))
344 });
345
346 while let Some(result) = join_set.join_next().await {
347 result.map_err(|e| {
348 TaskError::Fatal(anyhow::anyhow!("deferred listener task panicked: {}", e))
349 })??;
350 }
351 Ok::<(), TaskError>(())
352 }
353}