Skip to main content

sp1_core_machine/utils/
prove.rs

1use std::{borrow::Borrow, collections::BTreeMap, io, sync::Arc};
2
3use crate::executor::trace_chunk;
4use crate::riscv::RiscvAir;
5use hashbrown::HashSet;
6use thiserror::Error;
7
8use slop_algebra::PrimeField32;
9use slop_challenger::IopCtx;
10use sp1_hypercube::{
11    air::{PublicValues, PROOF_NONCE_NUM_WORDS},
12    prover::{AirProver, PcsProof, ProvingKey, SimpleProver},
13    MachineProof, MachineRecord, ShardContext,
14};
15
16use crate::io::SP1Stdin;
17use sp1_core_executor::{SP1CoreOpts, SplitOpts};
18
19use sp1_core_executor::{
20    chunked_memory_init_events,
21    events::{MemoryInitializeFinalizeEvent, MemoryRecord},
22    CompressedMemory, CompressedPages, CycleResult, ExecutionError, ExecutionRecord, Program,
23    SP1Context, SplicedMinimalTrace, SplicingVMEnum,
24};
25use sp1_core_executor_runner::MinimalExecutorRunner;
26use sp1_jit::{MinimalTrace, TraceChunk};
27
28/// Generate execution records from a program and inputs.
29///
30/// This function executes the program, splits execution into shards, and generates
31/// execution records suitable for proving. Returns the records and total cycle count.
32///
33/// This is a test-only function that generates records sequentially for simplicity.
34pub fn generate_records<F>(
35    program: Arc<Program>,
36    stdin: SP1Stdin,
37    opts: SP1CoreOpts,
38    proof_nonce: [u32; PROOF_NONCE_NUM_WORDS],
39) -> Result<(Vec<ExecutionRecord>, u64), SP1CoreProverError>
40where
41    F: PrimeField32,
42{
43    let machine = RiscvAir::<F>::machine();
44    let split_opts = SplitOpts::new(&opts, program.instructions.len(), false);
45
46    // Phase 1: Run MinimalExecutorRunner to generate trace chunks
47    let mut minimal_executor = MinimalExecutorRunner::new(
48        program.clone(),
49        false,
50        Some(opts.minimal_trace_chunk_threshold),
51        opts.memory_limit,
52        opts.trace_chunk_slots,
53    );
54
55    for buf in stdin.buffer {
56        minimal_executor.with_input(&buf);
57    }
58
59    let mut trace_chunks = Vec::new();
60    while let Some(chunk) = minimal_executor.try_execute_chunk()? {
61        // Convert TraceChunkRaw to TraceChunk so we are sure to **own**
62        // the memory. This avoids deadlock situation when shared memory
63        // based chunk is used.
64        let chunk: TraceChunk = chunk.into();
65        trace_chunks.push(chunk);
66    }
67
68    // Phase 2: Splice chunks and trace them to generate records
69    let mut all_records = Vec::new();
70    let mut deferred =
71        ExecutionRecord::new(program.clone(), proof_nonce, opts.global_dependencies_opt);
72    let mut touched_addresses = HashSet::new();
73    let mut touched_pages = HashSet::new();
74
75    for chunk in trace_chunks {
76        // Splice the chunk into shards
77        let spliced_traces = splice_chunk_sequential(
78            program.clone(),
79            chunk,
80            proof_nonce,
81            opts.clone(),
82            &mut touched_addresses,
83            &mut touched_pages,
84        );
85
86        // Trace each spliced chunk to generate execution records
87        for (is_last, spliced) in spliced_traces {
88            let record =
89                ExecutionRecord::new(program.clone(), proof_nonce, opts.global_dependencies_opt);
90            let (done, mut record, final_registers) =
91                trace_chunk::<F>(program.clone(), opts.clone(), spliced, proof_nonce, record)
92                    .map_err(SP1CoreProverError::ExecutionError)?;
93
94            if done {
95                // Insert global memory events for the last record
96                emit_globals(
97                    &minimal_executor,
98                    &mut record,
99                    final_registers,
100                    touched_addresses.clone(),
101                    touched_pages.clone(),
102                );
103            }
104
105            // Handle deferral
106            deferred.append(&mut record.defer(&opts.retained_events_presets));
107            let can_pack = done
108                && record.estimated_trace_area <= split_opts.pack_trace_threshold
109                && deferred.global_memory_initialize_events.len()
110                    <= split_opts.combine_memory_threshold
111                && deferred.global_memory_finalize_events.len()
112                    <= split_opts.combine_memory_threshold
113                && deferred.global_page_prot_initialize_events.len()
114                    <= split_opts.combine_page_prot_threshold
115                && deferred.global_page_prot_finalize_events.len()
116                    <= split_opts.combine_page_prot_threshold;
117            let deferred_records =
118                deferred.split(done || is_last, &mut record, can_pack, &split_opts);
119
120            // Generate dependencies and collect records
121            let mut records = vec![record];
122            records.extend(deferred_records);
123            machine.generate_dependencies(records.iter_mut(), None);
124            all_records.extend(records);
125        }
126    }
127
128    let cycles = minimal_executor.global_clk();
129    Ok((all_records, cycles))
130}
131
132/// Postprocess into an existing [`ExecutionRecord`],
133/// consisting of all the [`MemoryInitializeFinalizeEvent`]s.
134#[tracing::instrument(name = "emit globals", skip_all)]
135pub fn emit_globals(
136    minimal_executor: &MinimalExecutorRunner,
137    record: &mut ExecutionRecord,
138    final_registers: [MemoryRecord; 32],
139    mut touched_addresses: HashSet<u64>,
140    _touched_pages: HashSet<u64>,
141) {
142    // Add all the finalize addresses to the touched addresses.
143    touched_addresses.extend(minimal_executor.program().memory_image.keys().copied());
144
145    record.global_memory_initialize_events.extend(
146        final_registers
147            .iter()
148            .enumerate()
149            .filter(|(_, e)| e.timestamp != 0)
150            .map(|(i, _)| MemoryInitializeFinalizeEvent::initialize(i as u64, 0)),
151    );
152
153    record.global_memory_finalize_events.extend(
154        final_registers.iter().enumerate().filter(|(_, e)| e.timestamp != 0).map(|(i, entry)| {
155            MemoryInitializeFinalizeEvent::finalize(i as u64, entry.value, entry.timestamp)
156        }),
157    );
158
159    let hint_init_events: Vec<MemoryInitializeFinalizeEvent> = minimal_executor
160        .hints()
161        .iter()
162        .flat_map(|(addr, value)| chunked_memory_init_events(*addr, value))
163        .collect::<Vec<_>>();
164    let hint_addrs = hint_init_events.iter().map(|event| event.addr).collect::<HashSet<_>>();
165
166    // Initialize the all the hints written during execution.
167    record.global_memory_initialize_events.extend(hint_init_events);
168
169    // Initialize the memory addresses that were touched during execution.
170    // We don't initialize the memory addresses that were in the program image, since they were
171    // initialized in the MemoryProgram chip.
172    let memory_init_events = touched_addresses
173        .iter()
174        .filter(|addr| !minimal_executor.program().memory_image.contains_key(*addr))
175        .filter(|addr| !hint_addrs.contains(*addr))
176        .map(|addr| MemoryInitializeFinalizeEvent::initialize(*addr, 0));
177    record.global_memory_initialize_events.extend(memory_init_events);
178
179    // Ensure all the hinted addresses are initialized.
180    touched_addresses.extend(hint_addrs);
181
182    // Finalize the memory addresses that were touched during execution.
183    for addr in &touched_addresses {
184        let entry = minimal_executor.get_memory_value(*addr);
185
186        record.global_memory_finalize_events.push(MemoryInitializeFinalizeEvent::finalize(
187            *addr,
188            entry.value,
189            entry.clk,
190        ));
191    }
192}
193
194/// Get set of addresses that were hinted.
195#[must_use]
196pub fn get_hint_event_addrs(minimal_executor: &MinimalExecutorRunner) -> HashSet<u64> {
197    let events = minimal_executor
198        .hints()
199        .iter()
200        .flat_map(|(addr, value)| chunked_memory_init_events(*addr, value))
201        .collect::<Vec<_>>();
202    let hint_event_addrs = events.iter().map(|event| event.addr).collect::<HashSet<_>>();
203
204    hint_event_addrs
205}
206
207/// Prove a program with the given inputs using SimpleProver.
208///
209/// This is a test-only function that proves records sequentially for simplicity. It is
210/// extremely inefficient in both time and space, and should only be used for testing.
211pub async fn prove_core<GC, SC, PC>(
212    prover: &SimpleProver<GC, SC, PC>,
213    pk: Arc<ProvingKey<GC, SC, PC>>,
214    program: Arc<Program>,
215    stdin: SP1Stdin,
216    opts: SP1CoreOpts,
217    context: SP1Context<'static>,
218) -> Result<(MachineProof<GC, PcsProof<GC, SC>>, u64), SP1CoreProverError>
219where
220    GC: IopCtx,
221    SC: ShardContext<GC, Air = RiscvAir<GC::F>>,
222    PC: AirProver<GC, SC>,
223    GC::F: PrimeField32,
224{
225    let (all_records, cycles) =
226        generate_records::<GC::F>(program, stdin, opts, context.proof_nonce)?;
227
228    // Prove records sequentially
229    let mut shard_proofs = BTreeMap::new();
230    for record in all_records {
231        let proof = prover.prove_shard(pk.clone(), record).await;
232        let public_values: &PublicValues<[GC::F; 4], [GC::F; 3], [GC::F; 4], GC::F> =
233            proof.public_values.as_slice().borrow();
234        shard_proofs.insert(
235            (
236                public_values.initial_timestamp,
237                public_values.last_timestamp,
238                public_values.previous_init_addr,
239                public_values.previous_finalize_addr,
240            ),
241            proof,
242        );
243    }
244
245    let shard_proofs = shard_proofs.into_values().collect();
246    let proof = MachineProof { shard_proofs };
247
248    Ok((proof, cycles))
249}
250
251/// Splice a trace chunk into shard-sized pieces sequentially.
252/// Returns a vector of (is_last, spliced_trace) pairs.
253fn splice_chunk_sequential<T: MinimalTrace>(
254    program: Arc<Program>,
255    chunk: T,
256    proof_nonce: [u32; sp1_hypercube::air::PROOF_NONCE_NUM_WORDS],
257    opts: SP1CoreOpts,
258    touched_addresses: &mut HashSet<u64>,
259    touched_pages: &mut HashSet<u64>,
260) -> Vec<(bool, SplicedMinimalTrace<T>)> {
261    let mut result = Vec::new();
262    let mut compressed_touched = CompressedMemory::new();
263    let mut compressed_touched_pages = CompressedPages::new();
264    let mut vm = SplicingVMEnum::new(
265        &chunk,
266        program.clone(),
267        &mut compressed_touched,
268        &mut compressed_touched_pages,
269        proof_nonce,
270        opts,
271    );
272
273    let mut last_splice = SplicedMinimalTrace::new_full_trace(chunk.clone());
274    let start_num_mem_reads = chunk.num_mem_reads();
275
276    loop {
277        match vm.execute().expect("splicing execution failed") {
278            CycleResult::ShardBoundary => {
279                if let Some(spliced) = vm.splice(chunk.clone()) {
280                    last_splice.set_last_clk(vm.clk());
281                    last_splice
282                        .set_last_mem_reads_idx(start_num_mem_reads as usize - vm.mem_reads_len());
283                    let splice_to_emit = std::mem::replace(&mut last_splice, spliced);
284                    result.push((false, splice_to_emit));
285                } else {
286                    last_splice.set_last_clk(vm.clk());
287                    last_splice
288                        .set_last_mem_reads_idx(start_num_mem_reads as usize - vm.mem_reads_len());
289                    result.push((true, last_splice));
290                    break;
291                }
292            }
293            CycleResult::Done(true) => {
294                last_splice.set_last_clk(vm.clk());
295                last_splice.set_last_mem_reads_idx(chunk.num_mem_reads() as usize);
296                result.push((true, last_splice));
297                break;
298            }
299            CycleResult::Done(false) | CycleResult::TraceEnd => {
300                unreachable!("splicing should not return incomplete without shard boundary");
301            }
302        }
303    }
304
305    touched_addresses.extend(compressed_touched.is_set());
306    touched_pages.extend(compressed_touched_pages.is_set());
307    result
308}
309
/// Errors returned by the core proving utilities in this module.
///
/// `Display` is derived through `thiserror`; each message embeds the wrapped error.
#[derive(Error, Debug)]
pub enum SP1CoreProverError {
    /// Executing or tracing the program failed. This is the variant produced by the
    /// `From<ExecutionError>` conversion, so `?` on an [`ExecutionError`] yields it.
    #[error("failed to execute program: {0}")]
    ExecutionError(ExecutionError),
    /// An I/O operation failed.
    // NOTE(review): not constructed anywhere in the visible part of this file.
    #[error("io error: {0}")]
    IoError(io::Error),
    /// A `bincode` (de)serialization step failed.
    // NOTE(review): not constructed anywhere in the visible part of this file.
    #[error("serialization error: {0}")]
    SerializationError(bincode::Error),
}
319
320impl From<ExecutionError> for SP1CoreProverError {
321    fn from(e: ExecutionError) -> SP1CoreProverError {
322        SP1CoreProverError::ExecutionError(e)
323    }
324}