1use std::{borrow::Borrow, collections::BTreeMap, io, sync::Arc};
2
3use crate::executor::trace_chunk;
4use crate::riscv::RiscvAir;
5use hashbrown::HashSet;
6use thiserror::Error;
7
8use slop_algebra::PrimeField32;
9use slop_challenger::IopCtx;
10use sp1_hypercube::{
11 air::{PublicValues, PROOF_NONCE_NUM_WORDS},
12 prover::{AirProver, PcsProof, ProvingKey, SimpleProver},
13 MachineProof, MachineRecord, ShardContext,
14};
15
16use crate::io::SP1Stdin;
17use sp1_core_executor::{SP1CoreOpts, SplitOpts};
18
19use sp1_core_executor::{
20 chunked_memory_init_events,
21 events::{MemoryInitializeFinalizeEvent, MemoryRecord},
22 CompressedMemory, CompressedPages, CycleResult, ExecutionError, ExecutionRecord, Program,
23 SP1Context, SplicedMinimalTrace, SplicingVMEnum,
24};
25use sp1_core_executor_runner::MinimalExecutorRunner;
26use sp1_jit::{MinimalTrace, TraceChunk};
27
28pub fn generate_records<F>(
35 program: Arc<Program>,
36 stdin: SP1Stdin,
37 opts: SP1CoreOpts,
38 proof_nonce: [u32; PROOF_NONCE_NUM_WORDS],
39) -> Result<(Vec<ExecutionRecord>, u64), SP1CoreProverError>
40where
41 F: PrimeField32,
42{
43 let machine = RiscvAir::<F>::machine();
44 let split_opts = SplitOpts::new(&opts, program.instructions.len(), false);
45
46 let mut minimal_executor = MinimalExecutorRunner::new(
48 program.clone(),
49 false,
50 Some(opts.minimal_trace_chunk_threshold),
51 opts.memory_limit,
52 opts.trace_chunk_slots,
53 );
54
55 for buf in stdin.buffer {
56 minimal_executor.with_input(&buf);
57 }
58
59 let mut trace_chunks = Vec::new();
60 while let Some(chunk) = minimal_executor.try_execute_chunk()? {
61 let chunk: TraceChunk = chunk.into();
65 trace_chunks.push(chunk);
66 }
67
68 let mut all_records = Vec::new();
70 let mut deferred =
71 ExecutionRecord::new(program.clone(), proof_nonce, opts.global_dependencies_opt);
72 let mut touched_addresses = HashSet::new();
73 let mut touched_pages = HashSet::new();
74
75 for chunk in trace_chunks {
76 let spliced_traces = splice_chunk_sequential(
78 program.clone(),
79 chunk,
80 proof_nonce,
81 opts.clone(),
82 &mut touched_addresses,
83 &mut touched_pages,
84 );
85
86 for (is_last, spliced) in spliced_traces {
88 let record =
89 ExecutionRecord::new(program.clone(), proof_nonce, opts.global_dependencies_opt);
90 let (done, mut record, final_registers) =
91 trace_chunk::<F>(program.clone(), opts.clone(), spliced, proof_nonce, record)
92 .map_err(SP1CoreProverError::ExecutionError)?;
93
94 if done {
95 emit_globals(
97 &minimal_executor,
98 &mut record,
99 final_registers,
100 touched_addresses.clone(),
101 touched_pages.clone(),
102 );
103 }
104
105 deferred.append(&mut record.defer(&opts.retained_events_presets));
107 let can_pack = done
108 && record.estimated_trace_area <= split_opts.pack_trace_threshold
109 && deferred.global_memory_initialize_events.len()
110 <= split_opts.combine_memory_threshold
111 && deferred.global_memory_finalize_events.len()
112 <= split_opts.combine_memory_threshold
113 && deferred.global_page_prot_initialize_events.len()
114 <= split_opts.combine_page_prot_threshold
115 && deferred.global_page_prot_finalize_events.len()
116 <= split_opts.combine_page_prot_threshold;
117 let deferred_records =
118 deferred.split(done || is_last, &mut record, can_pack, &split_opts);
119
120 let mut records = vec![record];
122 records.extend(deferred_records);
123 machine.generate_dependencies(records.iter_mut(), None);
124 all_records.extend(records);
125 }
126 }
127
128 let cycles = minimal_executor.global_clk();
129 Ok((all_records, cycles))
130}
131
132#[tracing::instrument(name = "emit globals", skip_all)]
135pub fn emit_globals(
136 minimal_executor: &MinimalExecutorRunner,
137 record: &mut ExecutionRecord,
138 final_registers: [MemoryRecord; 32],
139 mut touched_addresses: HashSet<u64>,
140 _touched_pages: HashSet<u64>,
141) {
142 touched_addresses.extend(minimal_executor.program().memory_image.keys().copied());
144
145 record.global_memory_initialize_events.extend(
146 final_registers
147 .iter()
148 .enumerate()
149 .filter(|(_, e)| e.timestamp != 0)
150 .map(|(i, _)| MemoryInitializeFinalizeEvent::initialize(i as u64, 0)),
151 );
152
153 record.global_memory_finalize_events.extend(
154 final_registers.iter().enumerate().filter(|(_, e)| e.timestamp != 0).map(|(i, entry)| {
155 MemoryInitializeFinalizeEvent::finalize(i as u64, entry.value, entry.timestamp)
156 }),
157 );
158
159 let hint_init_events: Vec<MemoryInitializeFinalizeEvent> = minimal_executor
160 .hints()
161 .iter()
162 .flat_map(|(addr, value)| chunked_memory_init_events(*addr, value))
163 .collect::<Vec<_>>();
164 let hint_addrs = hint_init_events.iter().map(|event| event.addr).collect::<HashSet<_>>();
165
166 record.global_memory_initialize_events.extend(hint_init_events);
168
169 let memory_init_events = touched_addresses
173 .iter()
174 .filter(|addr| !minimal_executor.program().memory_image.contains_key(*addr))
175 .filter(|addr| !hint_addrs.contains(*addr))
176 .map(|addr| MemoryInitializeFinalizeEvent::initialize(*addr, 0));
177 record.global_memory_initialize_events.extend(memory_init_events);
178
179 touched_addresses.extend(hint_addrs);
181
182 for addr in &touched_addresses {
184 let entry = minimal_executor.get_memory_value(*addr);
185
186 record.global_memory_finalize_events.push(MemoryInitializeFinalizeEvent::finalize(
187 *addr,
188 entry.value,
189 entry.clk,
190 ));
191 }
192}
193
194#[must_use]
196pub fn get_hint_event_addrs(minimal_executor: &MinimalExecutorRunner) -> HashSet<u64> {
197 let events = minimal_executor
198 .hints()
199 .iter()
200 .flat_map(|(addr, value)| chunked_memory_init_events(*addr, value))
201 .collect::<Vec<_>>();
202 let hint_event_addrs = events.iter().map(|event| event.addr).collect::<HashSet<_>>();
203
204 hint_event_addrs
205}
206
207pub async fn prove_core<GC, SC, PC>(
212 prover: &SimpleProver<GC, SC, PC>,
213 pk: Arc<ProvingKey<GC, SC, PC>>,
214 program: Arc<Program>,
215 stdin: SP1Stdin,
216 opts: SP1CoreOpts,
217 context: SP1Context<'static>,
218) -> Result<(MachineProof<GC, PcsProof<GC, SC>>, u64), SP1CoreProverError>
219where
220 GC: IopCtx,
221 SC: ShardContext<GC, Air = RiscvAir<GC::F>>,
222 PC: AirProver<GC, SC>,
223 GC::F: PrimeField32,
224{
225 let (all_records, cycles) =
226 generate_records::<GC::F>(program, stdin, opts, context.proof_nonce)?;
227
228 let mut shard_proofs = BTreeMap::new();
230 for record in all_records {
231 let proof = prover.prove_shard(pk.clone(), record).await;
232 let public_values: &PublicValues<[GC::F; 4], [GC::F; 3], [GC::F; 4], GC::F> =
233 proof.public_values.as_slice().borrow();
234 shard_proofs.insert(
235 (
236 public_values.initial_timestamp,
237 public_values.last_timestamp,
238 public_values.previous_init_addr,
239 public_values.previous_finalize_addr,
240 ),
241 proof,
242 );
243 }
244
245 let shard_proofs = shard_proofs.into_values().collect();
246 let proof = MachineProof { shard_proofs };
247
248 Ok((proof, cycles))
249}
250
251fn splice_chunk_sequential<T: MinimalTrace>(
254 program: Arc<Program>,
255 chunk: T,
256 proof_nonce: [u32; sp1_hypercube::air::PROOF_NONCE_NUM_WORDS],
257 opts: SP1CoreOpts,
258 touched_addresses: &mut HashSet<u64>,
259 touched_pages: &mut HashSet<u64>,
260) -> Vec<(bool, SplicedMinimalTrace<T>)> {
261 let mut result = Vec::new();
262 let mut compressed_touched = CompressedMemory::new();
263 let mut compressed_touched_pages = CompressedPages::new();
264 let mut vm = SplicingVMEnum::new(
265 &chunk,
266 program.clone(),
267 &mut compressed_touched,
268 &mut compressed_touched_pages,
269 proof_nonce,
270 opts,
271 );
272
273 let mut last_splice = SplicedMinimalTrace::new_full_trace(chunk.clone());
274 let start_num_mem_reads = chunk.num_mem_reads();
275
276 loop {
277 match vm.execute().expect("splicing execution failed") {
278 CycleResult::ShardBoundary => {
279 if let Some(spliced) = vm.splice(chunk.clone()) {
280 last_splice.set_last_clk(vm.clk());
281 last_splice
282 .set_last_mem_reads_idx(start_num_mem_reads as usize - vm.mem_reads_len());
283 let splice_to_emit = std::mem::replace(&mut last_splice, spliced);
284 result.push((false, splice_to_emit));
285 } else {
286 last_splice.set_last_clk(vm.clk());
287 last_splice
288 .set_last_mem_reads_idx(start_num_mem_reads as usize - vm.mem_reads_len());
289 result.push((true, last_splice));
290 break;
291 }
292 }
293 CycleResult::Done(true) => {
294 last_splice.set_last_clk(vm.clk());
295 last_splice.set_last_mem_reads_idx(chunk.num_mem_reads() as usize);
296 result.push((true, last_splice));
297 break;
298 }
299 CycleResult::Done(false) | CycleResult::TraceEnd => {
300 unreachable!("splicing should not return incomplete without shard boundary");
301 }
302 }
303 }
304
305 touched_addresses.extend(compressed_touched.is_set());
306 touched_pages.extend(compressed_touched_pages.is_set());
307 result
308}
309
/// Errors returned by the execution and proving helpers in this module
/// (e.g. `generate_records`, `prove_core`).
// NOTE(review): the wrapped errors are only rendered through the `#[error]` Display message;
// no field is marked `#[source]`/`#[from]`, so `std::error::Error::source()` returns `None`.
#[derive(Error, Debug)]
pub enum SP1CoreProverError {
    /// Executing or tracing the program failed.
    #[error("failed to execute program: {0}")]
    ExecutionError(ExecutionError),
    /// Wraps an [`io::Error`]; not constructed within this module.
    #[error("io error: {0}")]
    IoError(io::Error),
    /// Wraps a `bincode` error; not constructed within this module — presumably used by
    /// callers for (de)serialization failures.
    #[error("serialization error: {0}")]
    SerializationError(bincode::Error),
}
319
320impl From<ExecutionError> for SP1CoreProverError {
321 fn from(e: ExecutionError) -> SP1CoreProverError {
322 SP1CoreProverError::ExecutionError(e)
323 }
324}