1use p3_matrix::dense::RowMajorMatrix;
2use std::{
3 error::Error,
4 fs::File,
5 io::{self, Seek, SeekFrom},
6 str::FromStr,
7 sync::{
8 mpsc::{channel, sync_channel, Sender},
9 Arc, Mutex,
10 },
11 thread::ScopedJoinHandle,
12};
13use web_time::Instant;
14
15use crate::{
16 riscv::RiscvAir,
17 shape::{CoreShapeConfig, Shapeable},
18 utils::test::MaliciousTracePVGeneratorType,
19};
20use p3_maybe_rayon::prelude::*;
21use sp1_stark::{MachineProvingKey, StarkVerifyingKey};
22use thiserror::Error;
23
24use p3_field::PrimeField32;
25use sp1_stark::air::MachineAir;
26
27use crate::{
28 io::SP1Stdin,
29 utils::{chunk_vec, concurrency::TurnBasedSync},
30};
31use sp1_core_executor::{
32 estimator::RecordEstimator,
33 events::{format_table_line, sorted_table_lines},
34 ExecutionState, RiscvAirId,
35};
36
37use sp1_core_executor::{
38 subproof::NoOpSubproofVerifier, ExecutionError, ExecutionRecord, ExecutionReport, Executor,
39 Program, SP1Context,
40};
41use sp1_stark::{
42 air::PublicValues, shape::OrderedShape, Com, MachineProof, MachineProver, MachineRecord,
43 OpeningProof, PcsProverData, SP1CoreOpts, ShardProof, StarkGenericConfig, Val,
44};
45
46#[allow(clippy::too_many_arguments)]
47pub fn prove_core<SC: StarkGenericConfig, P: MachineProver<SC, RiscvAir<SC::Val>>>(
48 prover: &P,
49 pk: &P::DeviceProvingKey,
50 _: &StarkVerifyingKey<SC>,
51 program: Program,
52 stdin: &SP1Stdin,
53 opts: SP1CoreOpts,
54 context: SP1Context,
55 shape_config: Option<&CoreShapeConfig<SC::Val>>,
56 malicious_trace_pv_generator: Option<MaliciousTracePVGeneratorType<SC::Val, P>>,
57) -> Result<(MachineProof<SC>, Vec<u8>, u64), SP1CoreProverError>
58where
59 SC::Val: PrimeField32,
60 SC::Challenger: 'static + Clone + Send,
61 OpeningProof<SC>: Send,
62 Com<SC>: Send + Sync,
63 PcsProverData<SC>: Send + Sync,
64{
65 let (proof_tx, proof_rx) = channel();
66 let (shape_tx, shape_rx) = channel();
67 let (public_values, cycles) = prove_core_stream(
68 prover,
69 pk,
70 program,
71 stdin,
72 opts,
73 context,
74 shape_config,
75 proof_tx,
76 shape_tx,
77 malicious_trace_pv_generator,
78 None,
79 )?;
80
81 let _: Vec<_> = shape_rx.iter().collect();
82 let shard_proofs: Vec<ShardProof<SC>> = proof_rx.iter().collect();
83 let proof = MachineProof { shard_proofs };
84
85 Ok((proof, public_values, cycles))
86}
87
88#[allow(clippy::too_many_arguments)]
89pub fn prove_core_stream<SC: StarkGenericConfig, P: MachineProver<SC, RiscvAir<SC::Val>>>(
90 prover: &P,
91 pk: &P::DeviceProvingKey,
92 program: Program,
93 stdin: &SP1Stdin,
94 opts: SP1CoreOpts,
95 context: SP1Context,
96 shape_config: Option<&CoreShapeConfig<SC::Val>>,
97 proof_tx: Sender<ShardProof<SC>>,
98 shape_and_done_tx: Sender<(OrderedShape, bool)>,
99 malicious_trace_pv_generator: Option<MaliciousTracePVGeneratorType<SC::Val, P>>, gas_calculator: Option<Box<dyn FnOnce(&RecordEstimator) -> Result<u64, Box<dyn Error>> + '_>>,
101) -> Result<(Vec<u8>, u64), SP1CoreProverError>
102where
103 SC::Val: PrimeField32,
104 SC::Challenger: 'static + Clone + Send,
105 OpeningProof<SC>: Send,
106 Com<SC>: Send + Sync,
107 PcsProverData<SC>: Send + Sync,
108{
109 let mut runtime = Box::new(Executor::with_context(program.clone(), opts, context));
111 runtime.maximal_shapes = shape_config.map(|config| {
112 config.maximal_core_shapes(opts.shard_size.ilog2() as usize).into_iter().collect()
113 });
114 runtime.write_vecs(&stdin.buffer);
115 for proof in stdin.proofs.iter() {
116 let (proof, vk) = proof.clone();
117 runtime.write_proof(proof, vk);
118 }
119 if gas_calculator.is_some() {
121 runtime.record_estimator = Some(Box::default());
122 }
123
124 #[cfg(feature = "debug")]
125 let (all_records_tx, all_records_rx) = std::sync::mpsc::channel::<Vec<ExecutionRecord>>();
126
127 let malicious_trace_pv_generator: Option<&MaliciousTracePVGeneratorType<SC::Val, P>> =
129 malicious_trace_pv_generator.as_ref();
130
131 let proving_start = Instant::now();
133 let span = tracing::Span::current().clone();
134 std::thread::scope(move |s| {
135 let _span = span.enter();
136
137 let checkpoint_generator_span = tracing::Span::current().clone();
139 let (checkpoints_tx, checkpoints_rx) =
140 sync_channel::<(usize, File, bool, u64)>(opts.checkpoints_channel_capacity);
141 let checkpoint_generator_handle: ScopedJoinHandle<Result<_, SP1CoreProverError>> =
142 s.spawn(move || {
143 let _span = checkpoint_generator_span.enter();
144 tracing::debug_span!("checkpoint generator").in_scope(|| {
145 let mut index = 0;
146 loop {
147 let span = tracing::debug_span!("batch");
149 let _span = span.enter();
150
151 let (checkpoint, _, done) = runtime
153 .execute_state(false)
154 .map_err(SP1CoreProverError::ExecutionError)?;
155
156 let mut checkpoint_file =
158 tempfile::tempfile().map_err(SP1CoreProverError::IoError)?;
159 checkpoint
160 .save(&mut checkpoint_file)
161 .map_err(SP1CoreProverError::IoError)?;
162
163 checkpoints_tx
165 .send((index, checkpoint_file, done, runtime.state.global_clk))
166 .unwrap();
167
168 if done {
170 break Ok(runtime);
171 }
172
173 index += 1;
175 }
176 })
177 });
178
179 let mut challenger = prover.config().challenger();
181 pk.observe_into(&mut challenger);
182
183 let p2_record_gen_sync = Arc::new(TurnBasedSync::new());
185 let p2_trace_gen_sync = Arc::new(TurnBasedSync::new());
186 let (p2_records_and_traces_tx, p2_records_and_traces_rx) =
187 sync_channel::<(Vec<ExecutionRecord>, Vec<Vec<(String, RowMajorMatrix<Val<SC>>)>>)>(
188 opts.records_and_traces_channel_capacity,
189 );
190 let p2_records_and_traces_tx = Arc::new(Mutex::new(p2_records_and_traces_tx));
191
192 let shape_tx = Arc::new(Mutex::new(shape_and_done_tx));
193 let report_aggregate = Arc::new(Mutex::new(ExecutionReport::default()));
194 let state = Arc::new(Mutex::new(PublicValues::<u32, u32>::default().reset()));
195 let deferred = Arc::new(Mutex::new(ExecutionRecord::new(program.clone().into())));
196 let mut p2_record_and_trace_gen_handles = Vec::new();
197 let checkpoints_rx = Arc::new(Mutex::new(checkpoints_rx));
198 for _ in 0..opts.trace_gen_workers {
199 let record_gen_sync = Arc::clone(&p2_record_gen_sync);
200 let trace_gen_sync = Arc::clone(&p2_trace_gen_sync);
201 let records_and_traces_tx = Arc::clone(&p2_records_and_traces_tx);
202 let checkpoints_rx = Arc::clone(&checkpoints_rx);
203
204 let shape_tx = Arc::clone(&shape_tx);
205 let report_aggregate = Arc::clone(&report_aggregate);
206 let state = Arc::clone(&state);
207 let deferred = Arc::clone(&deferred);
208 let program = program.clone();
209 let span = tracing::Span::current().clone();
210
211 #[cfg(feature = "debug")]
212 let all_records_tx = all_records_tx.clone();
213
214 let handle = s.spawn(move || {
215 let _span = span.enter();
216 tracing::debug_span!("phase 2 trace generation").in_scope(|| {
217 loop {
218 let received = { checkpoints_rx.lock().unwrap().recv() };
219 if let Ok((index, mut checkpoint, done, num_cycles)) = received {
220 let (mut records, report) = tracing::debug_span!("trace checkpoint")
221 .in_scope(|| {
222 trace_checkpoint::<SC>(
223 program.clone(),
224 &checkpoint,
225 opts,
226 shape_config,
227 )
228 });
229
230 *report_aggregate.lock().unwrap() += report;
232 checkpoint
233 .seek(SeekFrom::Start(0))
234 .expect("failed to seek to start of tempfile");
235
236 record_gen_sync.wait_for_turn(index);
238
239 let mut state = state.lock().unwrap();
242 for record in records.iter_mut() {
243 state.shard += 1;
244 state.execution_shard = record.public_values.execution_shard;
245 state.start_pc = record.public_values.start_pc;
246 state.next_pc = record.public_values.next_pc;
247 if state.committed_value_digest == [0u32; 8] {
248 state.committed_value_digest =
249 record.public_values.committed_value_digest;
250 }
251 if state.deferred_proofs_digest == [0u32; 8] {
252 state.deferred_proofs_digest =
253 record.public_values.deferred_proofs_digest;
254 }
255 record.public_values = *state;
256 }
257
258 let mut deferred = deferred.lock().unwrap();
260 for record in records.iter_mut() {
261 deferred.append(&mut record.defer());
262 }
263
264 let mut shape_fixed_records = if done &&
267 num_cycles < 1 << 21 &&
268 deferred.global_memory_initialize_events.len() <
269 opts.split_opts.combine_memory_threshold &&
270 deferred.global_memory_finalize_events.len() <
271 opts.split_opts.combine_memory_threshold
272 {
273 let mut records_clone = records.clone();
274 let last_record = records_clone.last_mut();
275 let mut deferred =
277 deferred.split(done, last_record, opts.split_opts);
278 tracing::debug!("deferred {} records", deferred.len());
279
280 if !done {
284 state.execution_shard += 1;
285 }
286 for record in deferred.iter_mut() {
287 state.shard += 1;
288 state.previous_init_addr_bits =
289 record.public_values.previous_init_addr_bits;
290 state.last_init_addr_bits =
291 record.public_values.last_init_addr_bits;
292 state.previous_finalize_addr_bits =
293 record.public_values.previous_finalize_addr_bits;
294 state.last_finalize_addr_bits =
295 record.public_values.last_finalize_addr_bits;
296 state.start_pc = state.next_pc;
297 record.public_values = *state;
298 }
299 records_clone.append(&mut deferred);
300
301 tracing::debug_span!("generate dependencies", index).in_scope(
303 || {
304 prover.machine().generate_dependencies(
305 &mut records_clone,
306 &opts,
307 None,
308 );
309 },
310 );
311
312 record_gen_sync.advance_turn();
314
315 let mut fixed_shape = true;
317 if let Some(shape_config) = shape_config {
318 for record in records_clone.iter_mut() {
319 if shape_config.fix_shape(record).is_err() {
320 fixed_shape = false;
321 }
322 }
323 }
324 fixed_shape.then_some(records_clone)
325 } else {
326 None
327 };
328
329 if shape_fixed_records.is_none() {
330 let mut deferred = deferred.split(done, None, opts.split_opts);
332 tracing::debug!("deferred {} records", deferred.len());
333
334 if !done {
338 state.execution_shard += 1;
339 }
340 for record in deferred.iter_mut() {
341 state.shard += 1;
342 state.previous_init_addr_bits =
343 record.public_values.previous_init_addr_bits;
344 state.last_init_addr_bits =
345 record.public_values.last_init_addr_bits;
346 state.previous_finalize_addr_bits =
347 record.public_values.previous_finalize_addr_bits;
348 state.last_finalize_addr_bits =
349 record.public_values.last_finalize_addr_bits;
350 state.start_pc = state.next_pc;
351 record.public_values = *state;
352 }
353 records.append(&mut deferred);
354
355 tracing::debug_span!("generate dependencies", index).in_scope(
357 || {
358 prover.machine().generate_dependencies(
359 &mut records,
360 &opts,
361 None,
362 );
363 },
364 );
365
366 record_gen_sync.advance_turn();
368
369 if let Some(shape_config) = shape_config {
371 for record in records.iter_mut() {
372 shape_config.fix_shape(record).unwrap();
373 }
374 }
375 shape_fixed_records = Some(records);
376 }
377
378 let mut records = shape_fixed_records.unwrap();
379
380 for record in records.iter() {
382 let mut heights = vec![];
383 let chips = prover.shard_chips(record).collect::<Vec<_>>();
384 if let Some(shape) = record.shape.as_ref() {
385 for chip in chips.iter() {
386 let id = RiscvAirId::from_str(&chip.name()).unwrap();
387 let height = shape.log2_height(&id).unwrap();
388 heights.push((chip.name().clone(), height));
389 }
390 shape_tx
391 .lock()
392 .unwrap()
393 .send((OrderedShape::from_log2_heights(&heights), done))
394 .unwrap();
395 }
396 }
397
398 #[cfg(feature = "debug")]
399 all_records_tx.send(records.clone()).unwrap();
400
401 let mut main_traces = Vec::new();
402 if let Some(malicious_trace_pv_generator) = malicious_trace_pv_generator
403 {
404 tracing::info_span!("generate main traces", index).in_scope(|| {
405 main_traces = records
406 .par_iter_mut()
407 .map(|record| malicious_trace_pv_generator(prover, record))
408 .collect::<Vec<_>>();
409 });
410 } else {
411 tracing::info_span!("generate main traces", index).in_scope(|| {
412 main_traces = records
413 .par_iter()
414 .map(|record| prover.generate_traces(record))
415 .collect::<Vec<_>>();
416 });
417 }
418
419 trace_gen_sync.wait_for_turn(index);
420
421 let chunked_records = chunk_vec(records, opts.shard_batch_size);
423 let chunked_main_traces = chunk_vec(main_traces, opts.shard_batch_size);
424 chunked_records
425 .into_iter()
426 .zip(chunked_main_traces.into_iter())
427 .for_each(|(records, main_traces)| {
428 records_and_traces_tx
429 .lock()
430 .unwrap()
431 .send((records, main_traces))
432 .unwrap();
433 });
434
435 trace_gen_sync.advance_turn();
436 } else {
437 break;
438 }
439 }
440 })
441 });
442 p2_record_and_trace_gen_handles.push(handle);
443 }
444 drop(p2_records_and_traces_tx);
445 #[cfg(feature = "debug")]
446 drop(all_records_tx);
447
448 let p2_prover_span = tracing::Span::current().clone();
450 let proof_tx = Arc::new(Mutex::new(proof_tx));
451 let p2_prover_handle = s.spawn(move || {
452 let _span = p2_prover_span.enter();
453 tracing::debug_span!("phase 2 prover").in_scope(|| {
454 for (records, traces) in p2_records_and_traces_rx.into_iter() {
455 tracing::debug_span!("batch").in_scope(|| {
456 let span = tracing::Span::current().clone();
457 let proofs = records
458 .into_par_iter()
459 .zip(traces.into_par_iter())
460 .map(|(record, main_traces)| {
461 let _span = span.enter();
462
463 let shard = record.shard();
464 let before = Instant::now();
465
466 let main_data = tracing::debug_span!("commit", shard)
467 .in_scope(|| prover.commit(&record, main_traces));
468
469 let proof = tracing::debug_span!("opening", shard).in_scope(|| {
470 prover.open(pk, main_data, &mut challenger.clone()).unwrap()
471 });
472
473 let elapsed = before.elapsed();
474
475 let debug_shapes = record.shape.as_ref().map(|shape| {
477 shape
478 .iter()
479 .filter_map(|(&k, &v)| (v > 0).then_some((k, v)))
480 .collect::<Vec<_>>()
481 });
482 tracing::debug!(
483 "proving shard {shard} took {} ns. shape: {:?}",
484 elapsed.as_nanos(),
485 debug_shapes
486 );
487
488 #[cfg(debug_assertions)]
489 {
490 if let Some(shape) = record.shape.as_ref() {
491 assert_eq!(
492 proof.shape(),
493 shape
494 .clone()
495 .into_iter()
496 .map(|(k, v)| (k.to_string(), v as usize))
497 .collect(),
498 );
499 }
500 }
501
502 rayon::spawn(move || {
503 drop(record);
504 });
505
506 proof
507 })
508 .collect::<Vec<_>>();
509
510 let proof_tx = proof_tx.lock().unwrap();
512 for proof in proofs {
513 proof_tx.send(proof).unwrap();
514 }
515 });
516 }
517 });
518 });
519
520 let runtime = checkpoint_generator_handle.join().unwrap().unwrap();
522 let gas = gas_calculator.map(|calc| calc(runtime.record_estimator.as_ref().unwrap()));
523 let public_values_stream = runtime.state.public_values_stream;
524
525 p2_record_and_trace_gen_handles.into_iter().for_each(|handle| handle.join().unwrap());
527
528 p2_prover_handle.join().unwrap();
530
531 let mut report_aggregate = report_aggregate.lock().unwrap();
533 tracing::debug!(
534 "execution report (totals): total_cycles={}, total_syscall_cycles={}, touched_memory_addresses={}",
535 report_aggregate.total_instruction_count(),
536 report_aggregate.total_syscall_count(),
537 report_aggregate.touched_memory_addresses,
538 );
539 match gas {
540 Some(Ok(gas)) => {
541 tracing::debug!("execution report (gas): {}", gas);
542 report_aggregate.gas = Some(gas);
543 }
544 Some(Err(err)) => tracing::error!("Encountered error while calculating gas: {}", err),
545 None => (),
546 }
547
548 tracing::debug!("execution report (opcode counts):");
551 let (width, lines) = sorted_table_lines(report_aggregate.opcode_counts.as_ref());
552 for (label, count) in lines {
553 if *count > 0 {
554 tracing::debug!(" {}", format_table_line(&width, &label, count));
555 } else {
556 tracing::debug!(" {}", format_table_line(&width, &label, count));
557 }
558 }
559
560 tracing::debug!("execution report (syscall counts):");
561 let (width, lines) = sorted_table_lines(report_aggregate.syscall_counts.as_ref());
562 for (label, count) in lines {
563 if *count > 0 {
564 tracing::debug!(" {}", format_table_line(&width, &label, count));
565 } else {
566 tracing::debug!(" {}", format_table_line(&width, &label, count));
567 }
568 }
569
570 let cycles = report_aggregate.total_instruction_count();
571
572 let proving_time = proving_start.elapsed().as_secs_f64();
574 tracing::debug!(
575 "summary: cycles={}, e2e={}s, khz={:.2}",
576 cycles,
577 proving_time,
578 (cycles as f64 / (proving_time * 1000.0) as f64),
579 );
580
581 #[cfg(feature = "debug")]
582 {
583 let all_records = all_records_rx.iter().flatten().collect::<Vec<_>>();
584 let mut challenger = prover.machine().config().challenger();
585 let pk_host = prover.pk_to_host(pk);
586 prover.machine().debug_constraints(&pk_host, all_records, &mut challenger);
587 }
588
589 Ok((public_values_stream, cycles))
590 })
591}
592
593pub fn trace_checkpoint<SC: StarkGenericConfig>(
594 program: Program,
595 file: &File,
596 opts: SP1CoreOpts,
597 shape_config: Option<&CoreShapeConfig<SC::Val>>,
598) -> (Vec<ExecutionRecord>, ExecutionReport)
599where
600 <SC as StarkGenericConfig>::Val: PrimeField32,
601{
602 let noop = NoOpSubproofVerifier;
603
604 let mut reader = std::io::BufReader::new(file);
605 let state: ExecutionState =
606 bincode::deserialize_from(&mut reader).expect("failed to deserialize state");
607 let mut runtime = Executor::recover(program, state, opts);
608 runtime.maximal_shapes = shape_config.map(|config| {
609 config.maximal_core_shapes(opts.shard_size.ilog2() as usize).into_iter().collect()
610 });
611
612 runtime.subproof_verifier = Some(&noop);
615
616 let (records, done) = runtime.execute_record(true).unwrap();
618
619 let mut records = records.into_iter().map(|r| *r).collect::<Vec<_>>();
620 let pv = records.last().unwrap().public_values;
621
622 if !done &&
624 (pv.committed_value_digest.iter().any(|v| *v != 0) ||
625 pv.deferred_proofs_digest.iter().any(|v| *v != 0))
626 {
627 runtime.print_report = false;
629 let (_, next_pv, _) = runtime.execute_state(true).unwrap();
630 for record in records.iter_mut() {
631 record.public_values.committed_value_digest = next_pv.committed_value_digest;
632 record.public_values.deferred_proofs_digest = next_pv.deferred_proofs_digest;
633 }
634 }
635
636 (records, runtime.report)
637}
638
639#[derive(Error, Debug)]
640pub enum SP1CoreProverError {
641 #[error("failed to execute program: {0}")]
642 ExecutionError(ExecutionError),
643 #[error("io error: {0}")]
644 IoError(io::Error),
645 #[error("serialization error: {0}")]
646 SerializationError(bincode::Error),
647}