Skip to main content

sp1_prover/worker/controller/
global.rs

1use std::{
2    collections::{BTreeMap, BTreeSet},
3    sync::Arc,
4};
5
6use itertools::Itertools;
7use sp1_core_executor::{
8    chunked_memory_init_events, events::MemoryInitializeFinalizeEvent, Program, SP1CoreOpts,
9    SplitOpts, UnsafeMemory,
10};
11use sp1_core_executor_runner::MinimalExecutorRunner;
12use sp1_hypercube::air::ShardRange;
13use sp1_prover_types::{Artifact, ArtifactClient};
14use tokio::{
15    sync::{mpsc, oneshot},
16    task::JoinSet,
17};
18use tracing::Instrument;
19
20use crate::worker::{
21    controller::create_core_proving_task, FinalVmState, GlobalMemoryShard, MessageSender,
22    MinimalExecutorCache, ProofData, SpawnProveOutput, TaskContext, TaskError, TraceData,
23    WorkerClient,
24};
25
/// The set of memory addresses touched by the executor during one splice of execution.
///
/// The consumer (`GlobalMemoryHandler::emit_global_memory_shards`) treats
/// `[start_clk, end_clk]` as an inclusive clock window: an address whose current clock lies
/// outside it when processed is considered dirty and finalized later from the executor's final
/// memory.
pub struct SpliceAddresses {
    /// First clock of the splice's window.
    start_clk: u64,
    /// Last clock of the splice's window.
    end_clk: u64,
    /// Memory addresses touched while executing the splice.
    addresses: Vec<u64>,
}
31
32#[derive(Clone)]
33pub struct TouchedAddresses {
34    inner: mpsc::Sender<SpliceAddresses>,
35}
36
37impl std::fmt::Debug for TouchedAddresses {
38    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
39        write!(f, "TouchedAddresses")
40    }
41}
42
43impl TouchedAddresses {
44    pub fn blocking_extend(
45        &self,
46        start_clk: u64,
47        end_clk: u64,
48        addresses: Vec<u64>,
49    ) -> anyhow::Result<()> {
50        self.inner.blocking_send(SpliceAddresses { start_clk, end_clk, addresses })?;
51        Ok(())
52    }
53
54    pub async fn extend(
55        &self,
56        start_clk: u64,
57        end_clk: u64,
58        addresses: Vec<u64>,
59    ) -> anyhow::Result<()> {
60        self.inner.send(SpliceAddresses { start_clk, end_clk, addresses }).await?;
61        Ok(())
62    }
63}
64
65pub struct GlobalMemoryHandler(mpsc::Receiver<SpliceAddresses>);
66
67pub fn global_memory(capacity: usize) -> (TouchedAddresses, GlobalMemoryHandler) {
68    let (tx, rx) = mpsc::channel(capacity);
69    (TouchedAddresses { inner: tx }, GlobalMemoryHandler(rx))
70}
71
72impl GlobalMemoryHandler {
73    #[allow(clippy::too_many_arguments)]
74    pub(super) async fn emit_global_memory_shards<A: ArtifactClient, W: WorkerClient>(
75        mut self,
76        program: Arc<Program>,
77        final_state_rx: oneshot::Receiver<FinalVmState>,
78        executor_rx: oneshot::Receiver<MinimalExecutorRunner>,
79        prove_shard_tx: MessageSender<W, ProofData>,
80        elf_artifact: Artifact,
81        common_input_artifact: Artifact,
82        context: TaskContext,
83        memory: UnsafeMemory,
84        opts: SP1CoreOpts,
85        num_deferred_proofs: usize,
86        artifact_client: A,
87        worker_client: W,
88        minimal_executor_cache: Option<MinimalExecutorCache>,
89    ) -> Result<(), TaskError> {
90        let (shard_data_tx, mut shard_data_rx) =
91            mpsc::unbounded_channel::<(ShardRange, TraceData)>();
92
93        let span = tracing::debug_span!("collect global memory events");
94        let mut join_set = JoinSet::<Result<_, TaskError>>::new();
95        join_set.spawn_blocking({
96            let program = program.clone();
97            move || {
98                let _guard = span.enter();
99                let mut initialized_events = BTreeMap::<u64, MemoryInitializeFinalizeEvent>::new();
100                let mut finalized_events = BTreeMap::<u64, MemoryInitializeFinalizeEvent>::new();
101                let mut dirty_addresses = BTreeSet::<u64>::new();
102                #[cfg(sp1_debug_global_memory)]
103                let mut touched_addresses = hashbrown::HashSet::<u64>::new();
104
105                // Collect the addresses
106                while let Some(addresses) = self.0.blocking_recv() {
107                    let SpliceAddresses { start_clk, end_clk, addresses } = addresses;
108                    for addr in addresses {
109                        #[cfg(sp1_debug_global_memory)]
110                        touched_addresses.insert(addr);
111                        // Add the address to the initialized events map if it was not already initialized.
112                        initialized_events
113                            .entry(addr)
114                            .or_insert_with(|| MemoryInitializeFinalizeEvent::initialize(addr, 0));
115
116                        // Get the memory value
117                        // # Safety: since we are waiting for the minimal executor to finish, we assume that
118                        // it is still alive. However, if it did panic, the whole proof flow should fail
119                        // but the potential for undefined behavior is still there.
120                        let value = unsafe { memory.get(addr) };
121                        // If the value was touched after this splice has finished, add it to the 
122                        // dirty addresses set and skip the finalization.
123                        if value.clk > end_clk || value.clk < start_clk {
124                            dirty_addresses.insert(addr);
125                            continue;
126                        }
127                        // Add the address to the finalized events map. If it was already seen, 
128                        // update the value, timestamp and clk. Otherwise, create a new event and 
129                        // add it to the map.
130                        finalized_events
131                            .entry(addr)
132                            .and_modify(|entry| {
133                                if entry.timestamp < value.clk {
134                                   entry.value = value.value;
135                                   entry.timestamp = value.clk;
136                                }
137                            })
138                            .or_insert_with(|| {
139                                MemoryInitializeFinalizeEvent::finalize(
140                                    addr,
141                                    value.value,
142                                    value.clk,
143                                )
144                            });
145                        // If the address was previously dirty, remove it from the dirty addresses
146                        // set.
147                        dirty_addresses.remove(&addr);
148                    }
149                }
150
151                // Collect the hints
152                let minimal_executor = executor_rx
153                    .blocking_recv()
154                    .map_err(|_| anyhow::anyhow!("failed to receive minimal executor"))?;
155                let hint_init_events = minimal_executor
156                    .hints()
157                    .iter()
158                    .flat_map(|(addr, value)| chunked_memory_init_events(*addr, value));
159                for event in hint_init_events {
160                    #[cfg(sp1_debug_global_memory)]
161                    touched_addresses.insert(event.addr);
162                    // Initialize the hint address to the value of the hint
163                    initialized_events.insert(event.addr, event);
164                    // Finalize the addresses of hints.
165                    let value = minimal_executor.get_memory_value(event.addr);
166                    finalized_events.insert(
167                        event.addr,
168                        MemoryInitializeFinalizeEvent::finalize(event.addr, value.value, value.clk),
169                    );
170                }
171                // Finalize the dirty addresses.
172                for addr in dirty_addresses {
173                    let value = minimal_executor.get_memory_value(addr);
174                    finalized_events.insert(
175                        addr,
176                        MemoryInitializeFinalizeEvent::finalize(addr, value.value, value.clk),
177                    );
178                }
179
180                // Wait for the final state
181                let final_state = final_state_rx
182                    .blocking_recv()
183                    .map_err(|_| anyhow::anyhow!("failed to receive final state"))?;
184
185                for (i, entry) in
186                    final_state.registers.iter().enumerate().filter(|(_, e)| e.timestamp != 0)
187                {
188                    initialized_events
189                        .insert(i as u64, MemoryInitializeFinalizeEvent::initialize(i as u64, 0));
190                    finalized_events.insert(
191                        i as u64,
192                        MemoryInitializeFinalizeEvent::finalize(
193                            i as u64,
194                            entry.value,
195                            entry.timestamp,
196                        ),
197                    );
198                }
199
200                // Remove initialized events for addresses in the program memory image.
201                for addr in program.memory_image.keys() {
202                    initialized_events.remove(addr);
203                }
204
205                // Handle the program memory image addresses.
206                for addr in program.memory_image.keys() {
207                    #[cfg(sp1_debug_global_memory)]
208                    touched_addresses.insert(*addr);
209                    // Remove the address from the initialized events map. This is because the 
210                    // program memory image is already initialized as part of the program initial 
211                    // cumulative sum.
212                    initialized_events.remove(addr);
213                    // Finalize the address.
214                    let value = minimal_executor.get_memory_value(*addr);
215                    let event =
216                        MemoryInitializeFinalizeEvent::finalize(*addr, value.value, value.clk);
217                    finalized_events.insert(*addr, event);
218                }
219
220                #[cfg(sp1_debug_global_memory)]
221                for (i, addr) in touched_addresses.into_iter().enumerate() {
222                    if i % 100_000 == 0 {
223                        tracing::debug!("checked {i} addresses");
224                    }
225                    let value = minimal_executor.get_memory_value(addr);
226                    let event = finalized_events.get(&addr).unwrap();
227
228                    let expected_value = value.value;
229                    let expected_clk = value.clk;
230                    let seen_value = event.value;
231                    let seen_clk = event.timestamp;
232                    if expected_value != seen_value || expected_clk != seen_clk {
233                        panic!("Address {addr} wrong value\n
234                            Expected value: {expected_value}, expected clk: {expected_clk}/ 
235                            seen value: {seen_value}, seen clk: {seen_clk}");
236                    }
237
238                }
239
240                let mut memory_initialize_events = Vec::with_capacity(initialized_events.len());
241                memory_initialize_events.extend(initialized_events.into_values());
242                let mut memory_finalize_events = Vec::with_capacity(finalized_events.len());
243                memory_finalize_events.extend(finalized_events.into_values());
244
245                // Get the split opts.
246                let split_opts = SplitOpts::new(&opts, program.instructions.len(), false);
247                let threshold = split_opts.memory;
248
249                let mut previous_init_addr = 0;
250                let mut previous_finalize_addr = 0;
251                let mut previous_init_page_idx = 0;
252                let mut previous_finalize_page_idx = 0;
253                for (i, chunks) in memory_initialize_events
254                    .chunks(threshold)
255                    .zip_longest(memory_finalize_events.chunks(threshold))
256                    .enumerate()
257                {
258                    let (initialize_events, finalize_events) = match chunks {
259                        itertools::EitherOrBoth::Left(initialize_events) => {
260                            let mut init_events = Vec::with_capacity(threshold);
261                            init_events.extend_from_slice(initialize_events);
262                            (init_events, vec![])
263                        }
264                        itertools::EitherOrBoth::Right(finalize_events) => {
265                            let mut final_events = Vec::with_capacity(threshold);
266                            final_events.extend_from_slice(finalize_events);
267                            (vec![], final_events)
268                        }
269                        itertools::EitherOrBoth::Both(initialize_events, finalize_events) => {
270                            let mut init_events = Vec::with_capacity(threshold);
271                            init_events.extend_from_slice(initialize_events);
272                            let mut final_events = Vec::with_capacity(threshold);
273                            final_events.extend_from_slice(finalize_events);
274                            (init_events, final_events)
275                        }
276                    };
277                    tracing::debug!("Got global memory shard number {i}");
278                    let last_init_addr = initialize_events
279                        .last()
280                        .map(|event| event.addr)
281                        .unwrap_or(previous_init_addr);
282                    let last_finalize_addr = finalize_events
283                        .last()
284                        .map(|event| event.addr)
285                        .unwrap_or(previous_finalize_addr);
286                    tracing::debug!("last_init_addr: {last_init_addr}, last_finalize_addr: {last_finalize_addr}");
287                    let last_init_page_idx = previous_init_page_idx;
288                    let last_finalize_page_idx = previous_finalize_page_idx;
289                    // Calculate the range of the shard.
290                    let range = ShardRange {
291                        timestamp_range: (final_state.timestamp, final_state.timestamp),
292                        initialized_address_range: (previous_init_addr, last_init_addr),
293                        finalized_address_range: (previous_finalize_addr, last_finalize_addr),
294                        initialized_page_index_range: (previous_init_page_idx, last_init_page_idx),
295                        finalized_page_index_range: (
296                            previous_finalize_page_idx,
297                            last_finalize_page_idx,
298                        ),
299                        deferred_proof_range: (
300                            num_deferred_proofs as u64,
301                            num_deferred_proofs as u64,
302                        ),
303                    };
304                    let mem_global_shard = GlobalMemoryShard {
305                        final_state,
306                        initialize_events,
307                        finalize_events,
308                        previous_init_addr,
309                        previous_finalize_addr,
310                        previous_init_page_idx,
311                        previous_finalize_page_idx,
312                        last_init_addr,
313                        last_finalize_addr,
314                        last_init_page_idx,
315                        last_finalize_page_idx,
316                    };
317
318                    let data = TraceData::Memory(Box::new(mem_global_shard));
319                    shard_data_tx
320                        .send((range, data))
321                        .map_err(|e| anyhow::anyhow!("failed to send shard data: {}", e))?;
322
323                    previous_init_addr = last_init_addr;
324                    previous_finalize_addr = last_finalize_addr;
325                    previous_init_page_idx = last_init_page_idx;
326                    previous_finalize_page_idx = last_finalize_page_idx;
327                }
328
329                Ok(Some(minimal_executor))
330            }
331        });
332
333        join_set.spawn(
334            async move {
335                let mut shard_join_set = JoinSet::new();
336                while let Some((range, data)) = shard_data_rx.recv().await {
337                    shard_join_set.spawn({
338                        let worker_client = worker_client.clone();
339                        let artifact_client = artifact_client.clone();
340                        let elf_artifact = elf_artifact.clone();
341                        let common_input_artifact = common_input_artifact.clone();
342                        let context = context.clone();
343                        let prove_shard_tx = prove_shard_tx.clone();
344                        async move {
345                            let SpawnProveOutput { proof_data, .. } = create_core_proving_task(
346                                elf_artifact.clone(),
347                                common_input_artifact.clone(),
348                                context.clone(),
349                                range,
350                                data,
351                                worker_client,
352                                artifact_client,
353                            )
354                            .await?;
355
356                            prove_shard_tx
357                                .send(proof_data)
358                                .await
359                                .map_err(|e| anyhow::anyhow!("failed to send task id: {}", e))?;
360                            Ok::<(), TaskError>(())
361                        }
362                        .in_current_span()
363                    });
364                }
365                // Wait for all the shard task to be created
366                while let Some(result) = shard_join_set.join_next().await {
367                    result.map_err(|e| {
368                        anyhow::anyhow!("failed to create a global memory shard task: {}", e)
369                    })??;
370                }
371                Ok(None)
372            }
373            .instrument(tracing::debug_span!("create global memory shards")),
374        );
375
376        // Wait for the tasks to finish
377        while let Some(result) = join_set.join_next().await {
378            let maybe_minimal_executor = result
379                .map_err(|e| anyhow::anyhow!("global memory shards task panicked: {}", e))??;
380            if let Some(mut minimal_executor) = maybe_minimal_executor {
381                if let Some(ref minimal_executor_cache) = minimal_executor_cache {
382                    minimal_executor.reset();
383                    let mut cache = minimal_executor_cache
384                        .lock()
385                        .instrument(tracing::debug_span!("wait for executor cache lock"))
386                        .await;
387                    if cache.is_some() {
388                        tracing::warn!("Unexpected minimal executor cache is not empty");
389                    }
390                    *cache = Some(minimal_executor);
391                }
392            }
393        }
394
395        Ok(())
396    }
397}