// sp1_prover/worker/controller/global.rs
use std::{
2 collections::{BTreeMap, BTreeSet},
3 sync::Arc,
4};
5
6use itertools::Itertools;
7use sp1_core_executor::{
8 chunked_memory_init_events, events::MemoryInitializeFinalizeEvent, Program, SP1CoreOpts,
9 SplitOpts, UnsafeMemory,
10};
11use sp1_core_executor_runner::MinimalExecutorRunner;
12use sp1_hypercube::air::ShardRange;
13use sp1_prover_types::{Artifact, ArtifactClient};
14use tokio::{
15 sync::{mpsc, oneshot},
16 task::JoinSet,
17};
18use tracing::Instrument;
19
20use crate::worker::{
21 controller::create_core_proving_task, FinalVmState, GlobalMemoryShard, MessageSender,
22 MinimalExecutorCache, ProofData, SpawnProveOutput, TaskContext, TaskError, TraceData,
23 WorkerClient,
24};
25
/// A batch of memory addresses touched within one clock-cycle window of
/// execution, sent from tracing tasks to the [`GlobalMemoryHandler`].
pub struct SpliceAddresses {
    // Inclusive lower bound of the clock window this batch was observed in.
    start_clk: u64,
    // Upper bound of the clock window; values finalized after this clock are
    // treated as "dirty" and re-resolved from the minimal executor later.
    end_clk: u64,
    // The touched memory addresses themselves.
    addresses: Vec<u64>,
}
31
/// Cloneable sending handle used by tracing tasks to report touched memory
/// addresses to the global-memory collection task. Created in pairs with a
/// [`GlobalMemoryHandler`] by [`global_memory`].
#[derive(Clone)]
pub struct TouchedAddresses {
    // Bounded channel sender; back-pressure applies when the handler lags.
    inner: mpsc::Sender<SpliceAddresses>,
}
36
37impl std::fmt::Debug for TouchedAddresses {
38 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
39 write!(f, "TouchedAddresses")
40 }
41}
42
impl TouchedAddresses {
    /// Reports a batch of touched `addresses` observed within the clock
    /// window `[start_clk, end_clk]` from a *synchronous* (blocking) context.
    ///
    /// Blocks the current thread if the channel is at capacity. Returns an
    /// error if the receiving [`GlobalMemoryHandler`] has been dropped.
    /// Must not be called from within an async runtime thread.
    pub fn blocking_extend(
        &self,
        start_clk: u64,
        end_clk: u64,
        addresses: Vec<u64>,
    ) -> anyhow::Result<()> {
        self.inner.blocking_send(SpliceAddresses { start_clk, end_clk, addresses })?;
        Ok(())
    }

    /// Async counterpart of [`Self::blocking_extend`]: reports a batch of
    /// touched `addresses` for the window `[start_clk, end_clk]`, awaiting
    /// channel capacity if necessary.
    ///
    /// Returns an error if the receiving [`GlobalMemoryHandler`] has been
    /// dropped.
    pub async fn extend(
        &self,
        start_clk: u64,
        end_clk: u64,
        addresses: Vec<u64>,
    ) -> anyhow::Result<()> {
        self.inner.send(SpliceAddresses { start_clk, end_clk, addresses }).await?;
        Ok(())
    }
}
64
65pub struct GlobalMemoryHandler(mpsc::Receiver<SpliceAddresses>);
66
67pub fn global_memory(capacity: usize) -> (TouchedAddresses, GlobalMemoryHandler) {
68 let (tx, rx) = mpsc::channel(capacity);
69 (TouchedAddresses { inner: tx }, GlobalMemoryHandler(rx))
70}
71
impl GlobalMemoryHandler {
    /// Collects all globally touched memory addresses, resolves their final
    /// values, splits the resulting initialize/finalize event streams into
    /// shards, and spawns a core proving task per shard.
    ///
    /// Two concurrent tasks are run to completion:
    /// 1. A blocking task that drains the touched-address channel, builds the
    ///    sorted initialize/finalize event maps, chunks them into
    ///    [`GlobalMemoryShard`]s, and forwards each shard over an internal
    ///    channel. It returns the minimal executor so it can be cached.
    /// 2. An async task that receives shards and spawns
    ///    [`create_core_proving_task`] for each, sending the resulting proof
    ///    data to `prove_shard_tx`.
    ///
    /// Returns an error if either task fails or panics, if the final VM
    /// state / minimal executor channels are closed early, or if any shard
    /// proving task fails.
    #[allow(clippy::too_many_arguments)]
    pub(super) async fn emit_global_memory_shards<A: ArtifactClient, W: WorkerClient>(
        mut self,
        program: Arc<Program>,
        final_state_rx: oneshot::Receiver<FinalVmState>,
        executor_rx: oneshot::Receiver<MinimalExecutorRunner>,
        prove_shard_tx: MessageSender<W, ProofData>,
        elf_artifact: Artifact,
        common_input_artifact: Artifact,
        context: TaskContext,
        memory: UnsafeMemory,
        opts: SP1CoreOpts,
        num_deferred_proofs: usize,
        artifact_client: A,
        worker_client: W,
        minimal_executor_cache: Option<MinimalExecutorCache>,
    ) -> Result<(), TaskError> {
        // Unbounded: shard production is bounded by the chunking loop below,
        // and the consumer spawns proving tasks immediately.
        let (shard_data_tx, mut shard_data_rx) =
            mpsc::unbounded_channel::<(ShardRange, TraceData)>();

        let span = tracing::debug_span!("collect global memory events");
        let mut join_set = JoinSet::<Result<_, TaskError>>::new();
        // Task 1: blocking collection of memory events. Uses BTreeMaps so the
        // event streams come out sorted by address, which the address-range
        // bookkeeping in the chunking loop relies on.
        join_set.spawn_blocking({
            let program = program.clone();
            move || {
                let _guard = span.enter();
                let mut initialized_events = BTreeMap::<u64, MemoryInitializeFinalizeEvent>::new();
                let mut finalized_events = BTreeMap::<u64, MemoryInitializeFinalizeEvent>::new();
                // Addresses whose observed clock fell outside the reporting
                // window; their final values are re-resolved from the minimal
                // executor after the channel drains.
                let mut dirty_addresses = BTreeSet::<u64>::new();
                #[cfg(sp1_debug_global_memory)]
                let mut touched_addresses = hashbrown::HashSet::<u64>::new();

                // Drain every reported batch until all TouchedAddresses
                // senders are dropped.
                while let Some(addresses) = self.0.blocking_recv() {
                    let SpliceAddresses { start_clk, end_clk, addresses } = addresses;
                    for addr in addresses {
                        #[cfg(sp1_debug_global_memory)]
                        touched_addresses.insert(addr);
                        // First touch of an address initializes it to 0.
                        initialized_events
                            .entry(addr)
                            .or_insert_with(|| MemoryInitializeFinalizeEvent::initialize(addr, 0));

                        // SAFETY not shown here: `memory` is an UnsafeMemory
                        // shared with the executor; the caller must guarantee
                        // the read is valid at this point. TODO confirm.
                        let value = unsafe { memory.get(addr) };
                        // Value written outside this batch's window — defer
                        // finalization to the post-drain dirty pass.
                        if value.clk > end_clk || value.clk < start_clk {
                            dirty_addresses.insert(addr);
                            continue;
                        }
                        // Keep the finalize event with the latest timestamp
                        // seen so far for this address.
                        finalized_events
                            .entry(addr)
                            .and_modify(|entry| {
                                if entry.timestamp < value.clk {
                                    entry.value = value.value;
                                    entry.timestamp = value.clk;
                                }
                            })
                            .or_insert_with(|| {
                                MemoryInitializeFinalizeEvent::finalize(
                                    addr,
                                    value.value,
                                    value.clk,
                                )
                            });
                        // A clean in-window observation supersedes any earlier
                        // dirty marking.
                        dirty_addresses.remove(&addr);
                    }
                }

                // The minimal executor arrives only after execution completes;
                // it is the authoritative source for final memory values.
                let minimal_executor = executor_rx
                    .blocking_recv()
                    .map_err(|_| anyhow::anyhow!("failed to receive minimal executor"))?;
                // Hinted (host-provided) memory regions: overwrite both the
                // initialize and finalize events with authoritative values.
                let hint_init_events = minimal_executor
                    .hints()
                    .iter()
                    .flat_map(|(addr, value)| chunked_memory_init_events(*addr, value));
                for event in hint_init_events {
                    #[cfg(sp1_debug_global_memory)]
                    touched_addresses.insert(event.addr);
                    initialized_events.insert(event.addr, event);
                    let value = minimal_executor.get_memory_value(event.addr);
                    finalized_events.insert(
                        event.addr,
                        MemoryInitializeFinalizeEvent::finalize(event.addr, value.value, value.clk),
                    );
                }
                // Resolve every remaining dirty address from the executor's
                // final memory state.
                for addr in dirty_addresses {
                    let value = minimal_executor.get_memory_value(addr);
                    finalized_events.insert(
                        addr,
                        MemoryInitializeFinalizeEvent::finalize(addr, value.value, value.clk),
                    );
                }

                let final_state = final_state_rx
                    .blocking_recv()
                    .map_err(|_| anyhow::anyhow!("failed to receive final state"))?;

                // Registers live at the low addresses (index == address here);
                // only registers that were actually written (timestamp != 0)
                // get initialize/finalize events.
                for (i, entry) in
                    final_state.registers.iter().enumerate().filter(|(_, e)| e.timestamp != 0)
                {
                    initialized_events
                        .insert(i as u64, MemoryInitializeFinalizeEvent::initialize(i as u64, 0));
                    finalized_events.insert(
                        i as u64,
                        MemoryInitializeFinalizeEvent::finalize(
                            i as u64,
                            entry.value,
                            entry.timestamp,
                        ),
                    );
                }

                // Program-image addresses are initialized by the program
                // itself, so they must not appear as zero-initializations.
                for addr in program.memory_image.keys() {
                    initialized_events.remove(addr);
                }

                // NOTE(review): this second pass repeats the
                // `initialized_events.remove(addr)` already done in the loop
                // above — the duplicate removal looks redundant; confirm and
                // consolidate.
                for addr in program.memory_image.keys() {
                    #[cfg(sp1_debug_global_memory)]
                    touched_addresses.insert(*addr);
                    initialized_events.remove(addr);
                    let value = minimal_executor.get_memory_value(*addr);
                    let event =
                        MemoryInitializeFinalizeEvent::finalize(*addr, value.value, value.clk);
                    finalized_events.insert(*addr, event);
                }

                // Debug-only cross-check: every touched address's finalize
                // event must agree with the minimal executor's view.
                #[cfg(sp1_debug_global_memory)]
                for (i, addr) in touched_addresses.into_iter().enumerate() {
                    if i % 100_000 == 0 {
                        tracing::debug!("checked {i} addresses");
                    }
                    let value = minimal_executor.get_memory_value(addr);
                    let event = finalized_events.get(&addr).unwrap();

                    let expected_value = value.value;
                    let expected_clk = value.clk;
                    let seen_value = event.value;
                    let seen_clk = event.timestamp;
                    if expected_value != seen_value || expected_clk != seen_clk {
                        panic!("Address {addr} wrong value\n
                        Expected value: {expected_value}, expected clk: {expected_clk}/
                        seen value: {seen_value}, seen clk: {seen_clk}");
                    }

                }

                // Flatten the sorted maps into vectors for chunking.
                let mut memory_initialize_events = Vec::with_capacity(initialized_events.len());
                memory_initialize_events.extend(initialized_events.into_values());
                let mut memory_finalize_events = Vec::with_capacity(finalized_events.len());
                memory_finalize_events.extend(finalized_events.into_values());

                // `threshold` = max memory events per shard.
                let split_opts = SplitOpts::new(&opts, program.instructions.len(), false);
                let threshold = split_opts.memory;

                // Running bookkeeping: each shard's range starts where the
                // previous shard's ended, so consecutive shards tile the
                // address space without gaps.
                let mut previous_init_addr = 0;
                let mut previous_finalize_addr = 0;
                let mut previous_init_page_idx = 0;
                let mut previous_finalize_page_idx = 0;
                // zip_longest: the init and finalize streams may have
                // different lengths; trailing chunks get an empty counterpart.
                for (i, chunks) in memory_initialize_events
                    .chunks(threshold)
                    .zip_longest(memory_finalize_events.chunks(threshold))
                    .enumerate()
                {
                    let (initialize_events, finalize_events) = match chunks {
                        itertools::EitherOrBoth::Left(initialize_events) => {
                            let mut init_events = Vec::with_capacity(threshold);
                            init_events.extend_from_slice(initialize_events);
                            (init_events, vec![])
                        }
                        itertools::EitherOrBoth::Right(finalize_events) => {
                            let mut final_events = Vec::with_capacity(threshold);
                            final_events.extend_from_slice(finalize_events);
                            (vec![], final_events)
                        }
                        itertools::EitherOrBoth::Both(initialize_events, finalize_events) => {
                            let mut init_events = Vec::with_capacity(threshold);
                            init_events.extend_from_slice(initialize_events);
                            let mut final_events = Vec::with_capacity(threshold);
                            final_events.extend_from_slice(finalize_events);
                            (init_events, final_events)
                        }
                    };
                    tracing::debug!("Got global memory shard number {i}");
                    // An empty chunk keeps the previous boundary (degenerate
                    // range) so the tiling invariant still holds.
                    let last_init_addr = initialize_events
                        .last()
                        .map(|event| event.addr)
                        .unwrap_or(previous_init_addr);
                    let last_finalize_addr = finalize_events
                        .last()
                        .map(|event| event.addr)
                        .unwrap_or(previous_finalize_addr);
                    tracing::debug!("last_init_addr: {last_init_addr}, last_finalize_addr: {last_finalize_addr}");
                    // Page indices are not advanced by memory shards; carry
                    // the previous values through unchanged.
                    let last_init_page_idx = previous_init_page_idx;
                    let last_finalize_page_idx = previous_finalize_page_idx;
                    let range = ShardRange {
                        // Memory shards occupy a single (degenerate) point in
                        // the timestamp dimension: the final timestamp.
                        timestamp_range: (final_state.timestamp, final_state.timestamp),
                        initialized_address_range: (previous_init_addr, last_init_addr),
                        finalized_address_range: (previous_finalize_addr, last_finalize_addr),
                        initialized_page_index_range: (previous_init_page_idx, last_init_page_idx),
                        finalized_page_index_range: (
                            previous_finalize_page_idx,
                            last_finalize_page_idx,
                        ),
                        // Likewise degenerate in the deferred-proof dimension.
                        deferred_proof_range: (
                            num_deferred_proofs as u64,
                            num_deferred_proofs as u64,
                        ),
                    };
                    let mem_global_shard = GlobalMemoryShard {
                        final_state,
                        initialize_events,
                        finalize_events,
                        previous_init_addr,
                        previous_finalize_addr,
                        previous_init_page_idx,
                        previous_finalize_page_idx,
                        last_init_addr,
                        last_finalize_addr,
                        last_init_page_idx,
                        last_finalize_page_idx,
                    };

                    let data = TraceData::Memory(Box::new(mem_global_shard));
                    shard_data_tx
                        .send((range, data))
                        .map_err(|e| anyhow::anyhow!("failed to send shard data: {}", e))?;

                    // Advance the tiling boundaries for the next shard.
                    previous_init_addr = last_init_addr;
                    previous_finalize_addr = last_finalize_addr;
                    previous_init_page_idx = last_init_page_idx;
                    previous_finalize_page_idx = last_finalize_page_idx;
                }

                // Hand the executor back so it can be reset and cached.
                Ok(Some(minimal_executor))
            }
        });

        // Task 2: consume shards and spawn one core proving task per shard.
        join_set.spawn(
            async move {
                let mut shard_join_set = JoinSet::new();
                while let Some((range, data)) = shard_data_rx.recv().await {
                    shard_join_set.spawn({
                        let worker_client = worker_client.clone();
                        let artifact_client = artifact_client.clone();
                        let elf_artifact = elf_artifact.clone();
                        let common_input_artifact = common_input_artifact.clone();
                        let context = context.clone();
                        let prove_shard_tx = prove_shard_tx.clone();
                        async move {
                            let SpawnProveOutput { proof_data, .. } = create_core_proving_task(
                                elf_artifact.clone(),
                                common_input_artifact.clone(),
                                context.clone(),
                                range,
                                data,
                                worker_client,
                                artifact_client,
                            )
                            .await?;

                            prove_shard_tx
                                .send(proof_data)
                                .await
                                .map_err(|e| anyhow::anyhow!("failed to send task id: {}", e))?;
                            Ok::<(), TaskError>(())
                        }
                        .in_current_span()
                    });
                }
                // Propagate both join errors (panics) and task errors.
                while let Some(result) = shard_join_set.join_next().await {
                    result.map_err(|e| {
                        anyhow::anyhow!("failed to create a global memory shard task: {}", e)
                    })??;
                }
                // Task 2 never yields an executor; only task 1 does.
                Ok(None)
            }
            .instrument(tracing::debug_span!("create global memory shards")),
        );

        // Await both tasks; cache the minimal executor returned by task 1.
        while let Some(result) = join_set.join_next().await {
            let maybe_minimal_executor = result
                .map_err(|e| anyhow::anyhow!("global memory shards task panicked: {}", e))??;
            if let Some(mut minimal_executor) = maybe_minimal_executor {
                if let Some(ref minimal_executor_cache) = minimal_executor_cache {
                    // Reset before caching so the next run starts clean.
                    minimal_executor.reset();
                    let mut cache = minimal_executor_cache
                        .lock()
                        .instrument(tracing::debug_span!("wait for executor cache lock"))
                        .await;
                    if cache.is_some() {
                        tracing::warn!("Unexpected minimal executor cache is not empty");
                    }
                    *cache = Some(minimal_executor);
                }
            }
        }

        Ok(())
    }
}