sp1_core_machine/utils/
prove.rs

1use p3_matrix::dense::RowMajorMatrix;
2use std::{
3    error::Error,
4    fs::File,
5    io::{self, Seek, SeekFrom},
6    str::FromStr,
7    sync::{
8        mpsc::{channel, sync_channel, Sender},
9        Arc, Mutex,
10    },
11    thread::ScopedJoinHandle,
12};
13use web_time::Instant;
14
15use crate::{
16    riscv::RiscvAir,
17    shape::{CoreShapeConfig, Shapeable},
18    utils::test::MaliciousTracePVGeneratorType,
19};
20use p3_maybe_rayon::prelude::*;
21use sp1_stark::{MachineProvingKey, StarkVerifyingKey};
22use thiserror::Error;
23
24use p3_field::PrimeField32;
25use sp1_stark::air::MachineAir;
26
27use crate::{
28    io::SP1Stdin,
29    utils::{chunk_vec, concurrency::TurnBasedSync},
30};
31use sp1_core_executor::{
32    estimator::RecordEstimator,
33    events::{format_table_line, sorted_table_lines},
34    ExecutionState, RiscvAirId,
35};
36
37use sp1_core_executor::{
38    subproof::NoOpSubproofVerifier, ExecutionError, ExecutionRecord, ExecutionReport, Executor,
39    Program, SP1Context,
40};
41use sp1_stark::{
42    air::PublicValues, shape::OrderedShape, Com, MachineProof, MachineProver, MachineRecord,
43    OpeningProof, PcsProverData, SP1CoreOpts, ShardProof, StarkGenericConfig, Val,
44};
45
46#[allow(clippy::too_many_arguments)]
47pub fn prove_core<SC: StarkGenericConfig, P: MachineProver<SC, RiscvAir<SC::Val>>>(
48    prover: &P,
49    pk: &P::DeviceProvingKey,
50    _: &StarkVerifyingKey<SC>,
51    program: Program,
52    stdin: &SP1Stdin,
53    opts: SP1CoreOpts,
54    context: SP1Context,
55    shape_config: Option<&CoreShapeConfig<SC::Val>>,
56    malicious_trace_pv_generator: Option<MaliciousTracePVGeneratorType<SC::Val, P>>,
57) -> Result<(MachineProof<SC>, Vec<u8>, u64), SP1CoreProverError>
58where
59    SC::Val: PrimeField32,
60    SC::Challenger: 'static + Clone + Send,
61    OpeningProof<SC>: Send,
62    Com<SC>: Send + Sync,
63    PcsProverData<SC>: Send + Sync,
64{
65    let (proof_tx, proof_rx) = channel();
66    let (shape_tx, shape_rx) = channel();
67    let (public_values, cycles) = prove_core_stream(
68        prover,
69        pk,
70        program,
71        stdin,
72        opts,
73        context,
74        shape_config,
75        proof_tx,
76        shape_tx,
77        malicious_trace_pv_generator,
78        None,
79    )?;
80
81    let _: Vec<_> = shape_rx.iter().collect();
82    let shard_proofs: Vec<ShardProof<SC>> = proof_rx.iter().collect();
83    let proof = MachineProof { shard_proofs };
84
85    Ok((proof, public_values, cycles))
86}
87
88#[allow(clippy::too_many_arguments)]
89pub fn prove_core_stream<SC: StarkGenericConfig, P: MachineProver<SC, RiscvAir<SC::Val>>>(
90    prover: &P,
91    pk: &P::DeviceProvingKey,
92    program: Program,
93    stdin: &SP1Stdin,
94    opts: SP1CoreOpts,
95    context: SP1Context,
96    shape_config: Option<&CoreShapeConfig<SC::Val>>,
97    proof_tx: Sender<ShardProof<SC>>,
98    shape_and_done_tx: Sender<(OrderedShape, bool)>,
99    malicious_trace_pv_generator: Option<MaliciousTracePVGeneratorType<SC::Val, P>>, /* This is used for failure test cases that generate malicious traces and public values. */
100    gas_calculator: Option<Box<dyn FnOnce(&RecordEstimator) -> Result<u64, Box<dyn Error>> + '_>>,
101) -> Result<(Vec<u8>, u64), SP1CoreProverError>
102where
103    SC::Val: PrimeField32,
104    SC::Challenger: 'static + Clone + Send,
105    OpeningProof<SC>: Send,
106    Com<SC>: Send + Sync,
107    PcsProverData<SC>: Send + Sync,
108{
109    // Setup the runtime.
110    let mut runtime = Box::new(Executor::with_context(program.clone(), opts, context));
111    runtime.maximal_shapes = shape_config.map(|config| {
112        config.maximal_core_shapes(opts.shard_size.ilog2() as usize).into_iter().collect()
113    });
114    runtime.write_vecs(&stdin.buffer);
115    for proof in stdin.proofs.iter() {
116        let (proof, vk) = proof.clone();
117        runtime.write_proof(proof, vk);
118    }
119    // Set the record estimator to collect data for gas calculation.
120    if gas_calculator.is_some() {
121        runtime.record_estimator = Some(Box::default());
122    }
123
124    #[cfg(feature = "debug")]
125    let (all_records_tx, all_records_rx) = std::sync::mpsc::channel::<Vec<ExecutionRecord>>();
126
127    // Need to create an optional reference, because of the `move` below.
128    let malicious_trace_pv_generator: Option<&MaliciousTracePVGeneratorType<SC::Val, P>> =
129        malicious_trace_pv_generator.as_ref();
130
131    // Record the start of the process.
132    let proving_start = Instant::now();
133    let span = tracing::Span::current().clone();
134    std::thread::scope(move |s| {
135        let _span = span.enter();
136
137        // Spawn the checkpoint generator thread.
138        let checkpoint_generator_span = tracing::Span::current().clone();
139        let (checkpoints_tx, checkpoints_rx) =
140            sync_channel::<(usize, File, bool, u64)>(opts.checkpoints_channel_capacity);
141        let checkpoint_generator_handle: ScopedJoinHandle<Result<_, SP1CoreProverError>> =
142            s.spawn(move || {
143                let _span = checkpoint_generator_span.enter();
144                tracing::debug_span!("checkpoint generator").in_scope(|| {
145                    let mut index = 0;
146                    loop {
147                        // Enter the span.
148                        let span = tracing::debug_span!("batch");
149                        let _span = span.enter();
150
151                        // Execute the runtime until we reach a checkpoint.
152                        let (checkpoint, _, done) = runtime
153                            .execute_state(false)
154                            .map_err(SP1CoreProverError::ExecutionError)?;
155
156                        // Save the checkpoint to a temp file.
157                        let mut checkpoint_file =
158                            tempfile::tempfile().map_err(SP1CoreProverError::IoError)?;
159                        checkpoint
160                            .save(&mut checkpoint_file)
161                            .map_err(SP1CoreProverError::IoError)?;
162
163                        // Send the checkpoint.
164                        checkpoints_tx
165                            .send((index, checkpoint_file, done, runtime.state.global_clk))
166                            .unwrap();
167
168                        // If we've reached the final checkpoint, break out of the loop.
169                        if done {
170                            break Ok(runtime);
171                        }
172
173                        // Update the index.
174                        index += 1;
175                    }
176                })
177            });
178
179        // Create the challenger and observe the verifying key.
180        let mut challenger = prover.config().challenger();
181        pk.observe_into(&mut challenger);
182
183        // Spawn the phase 2 record generator thread.
184        let p2_record_gen_sync = Arc::new(TurnBasedSync::new());
185        let p2_trace_gen_sync = Arc::new(TurnBasedSync::new());
186        let (p2_records_and_traces_tx, p2_records_and_traces_rx) =
187            sync_channel::<(Vec<ExecutionRecord>, Vec<Vec<(String, RowMajorMatrix<Val<SC>>)>>)>(
188                opts.records_and_traces_channel_capacity,
189            );
190        let p2_records_and_traces_tx = Arc::new(Mutex::new(p2_records_and_traces_tx));
191
192        let shape_tx = Arc::new(Mutex::new(shape_and_done_tx));
193        let report_aggregate = Arc::new(Mutex::new(ExecutionReport::default()));
194        let state = Arc::new(Mutex::new(PublicValues::<u32, u32>::default().reset()));
195        let deferred = Arc::new(Mutex::new(ExecutionRecord::new(program.clone().into())));
196        let mut p2_record_and_trace_gen_handles = Vec::new();
197        let checkpoints_rx = Arc::new(Mutex::new(checkpoints_rx));
198        for _ in 0..opts.trace_gen_workers {
199            let record_gen_sync = Arc::clone(&p2_record_gen_sync);
200            let trace_gen_sync = Arc::clone(&p2_trace_gen_sync);
201            let records_and_traces_tx = Arc::clone(&p2_records_and_traces_tx);
202            let checkpoints_rx = Arc::clone(&checkpoints_rx);
203
204            let shape_tx = Arc::clone(&shape_tx);
205            let report_aggregate = Arc::clone(&report_aggregate);
206            let state = Arc::clone(&state);
207            let deferred = Arc::clone(&deferred);
208            let program = program.clone();
209            let span = tracing::Span::current().clone();
210
211            #[cfg(feature = "debug")]
212            let all_records_tx = all_records_tx.clone();
213
214            let handle = s.spawn(move || {
215                let _span = span.enter();
216                tracing::debug_span!("phase 2 trace generation").in_scope(|| {
217                    loop {
218                        let received = { checkpoints_rx.lock().unwrap().recv() };
219                        if let Ok((index, mut checkpoint, done, num_cycles)) = received {
220                            let (mut records, report) = tracing::debug_span!("trace checkpoint")
221                                .in_scope(|| {
222                                    trace_checkpoint::<SC>(
223                                        program.clone(),
224                                        &checkpoint,
225                                        opts,
226                                        shape_config,
227                                    )
228                                });
229
230                            // Trace the checkpoint and reconstruct the execution records.
231                            *report_aggregate.lock().unwrap() += report;
232                            checkpoint
233                                .seek(SeekFrom::Start(0))
234                                .expect("failed to seek to start of tempfile");
235
236                            // Wait for our turn to update the state.
237                            record_gen_sync.wait_for_turn(index);
238
239                            // Update the public values & prover state for the shards which contain
240                            // "cpu events".
241                            let mut state = state.lock().unwrap();
242                            for record in records.iter_mut() {
243                                state.shard += 1;
244                                state.execution_shard = record.public_values.execution_shard;
245                                state.start_pc = record.public_values.start_pc;
246                                state.next_pc = record.public_values.next_pc;
247                                if state.committed_value_digest == [0u32; 8] {
248                                    state.committed_value_digest =
249                                        record.public_values.committed_value_digest;
250                                }
251                                if state.deferred_proofs_digest == [0u32; 8] {
252                                    state.deferred_proofs_digest =
253                                        record.public_values.deferred_proofs_digest;
254                                }
255                                record.public_values = *state;
256                            }
257
258                            // Defer events that are too expensive to include in every shard.
259                            let mut deferred = deferred.lock().unwrap();
260                            for record in records.iter_mut() {
261                                deferred.append(&mut record.defer());
262                            }
263
264                            // We combine the memory init/finalize events if they are "small"
265                            // and would affect performance.
266                            let mut shape_fixed_records = if done &&
267                                num_cycles < 1 << 21 &&
268                                deferred.global_memory_initialize_events.len() <
269                                    opts.split_opts.combine_memory_threshold &&
270                                deferred.global_memory_finalize_events.len() <
271                                    opts.split_opts.combine_memory_threshold
272                            {
273                                let mut records_clone = records.clone();
274                                let last_record = records_clone.last_mut();
275                                // See if any deferred shards are ready to be committed to.
276                                let mut deferred =
277                                    deferred.split(done, last_record, opts.split_opts);
278                                tracing::debug!("deferred {} records", deferred.len());
279
280                                // Update the public values & prover state for the shards which do
281                                // not contain "cpu events" before
282                                // committing to them.
283                                if !done {
284                                    state.execution_shard += 1;
285                                }
286                                for record in deferred.iter_mut() {
287                                    state.shard += 1;
288                                    state.previous_init_addr_bits =
289                                        record.public_values.previous_init_addr_bits;
290                                    state.last_init_addr_bits =
291                                        record.public_values.last_init_addr_bits;
292                                    state.previous_finalize_addr_bits =
293                                        record.public_values.previous_finalize_addr_bits;
294                                    state.last_finalize_addr_bits =
295                                        record.public_values.last_finalize_addr_bits;
296                                    state.start_pc = state.next_pc;
297                                    record.public_values = *state;
298                                }
299                                records_clone.append(&mut deferred);
300
301                                // Generate the dependencies.
302                                tracing::debug_span!("generate dependencies", index).in_scope(
303                                    || {
304                                        prover.machine().generate_dependencies(
305                                            &mut records_clone,
306                                            &opts,
307                                            None,
308                                        );
309                                    },
310                                );
311
312                                // Let another worker update the state.
313                                record_gen_sync.advance_turn();
314
315                                // Fix the shape of the records.
316                                let mut fixed_shape = true;
317                                if let Some(shape_config) = shape_config {
318                                    for record in records_clone.iter_mut() {
319                                        if shape_config.fix_shape(record).is_err() {
320                                            fixed_shape = false;
321                                        }
322                                    }
323                                }
324                                fixed_shape.then_some(records_clone)
325                            } else {
326                                None
327                            };
328
329                            if shape_fixed_records.is_none() {
330                                // See if any deferred shards are ready to be committed to.
331                                let mut deferred = deferred.split(done, None, opts.split_opts);
332                                tracing::debug!("deferred {} records", deferred.len());
333
334                                // Update the public values & prover state for the shards which do
335                                // not contain "cpu events" before
336                                // committing to them.
337                                if !done {
338                                    state.execution_shard += 1;
339                                }
340                                for record in deferred.iter_mut() {
341                                    state.shard += 1;
342                                    state.previous_init_addr_bits =
343                                        record.public_values.previous_init_addr_bits;
344                                    state.last_init_addr_bits =
345                                        record.public_values.last_init_addr_bits;
346                                    state.previous_finalize_addr_bits =
347                                        record.public_values.previous_finalize_addr_bits;
348                                    state.last_finalize_addr_bits =
349                                        record.public_values.last_finalize_addr_bits;
350                                    state.start_pc = state.next_pc;
351                                    record.public_values = *state;
352                                }
353                                records.append(&mut deferred);
354
355                                // Generate the dependencies.
356                                tracing::debug_span!("generate dependencies", index).in_scope(
357                                    || {
358                                        prover.machine().generate_dependencies(
359                                            &mut records,
360                                            &opts,
361                                            None,
362                                        );
363                                    },
364                                );
365
366                                // Let another worker update the state.
367                                record_gen_sync.advance_turn();
368
369                                // Fix the shape of the records.
370                                if let Some(shape_config) = shape_config {
371                                    for record in records.iter_mut() {
372                                        shape_config.fix_shape(record).unwrap();
373                                    }
374                                }
375                                shape_fixed_records = Some(records);
376                            }
377
378                            let mut records = shape_fixed_records.unwrap();
379
380                            // Send the shapes to the channel, if necessary.
381                            for record in records.iter() {
382                                let mut heights = vec![];
383                                let chips = prover.shard_chips(record).collect::<Vec<_>>();
384                                if let Some(shape) = record.shape.as_ref() {
385                                    for chip in chips.iter() {
386                                        let id = RiscvAirId::from_str(&chip.name()).unwrap();
387                                        let height = shape.log2_height(&id).unwrap();
388                                        heights.push((chip.name().clone(), height));
389                                    }
390                                    shape_tx
391                                        .lock()
392                                        .unwrap()
393                                        .send((OrderedShape::from_log2_heights(&heights), done))
394                                        .unwrap();
395                                }
396                            }
397
398                            #[cfg(feature = "debug")]
399                            all_records_tx.send(records.clone()).unwrap();
400
401                            let mut main_traces = Vec::new();
402                            if let Some(malicious_trace_pv_generator) = malicious_trace_pv_generator
403                            {
404                                tracing::info_span!("generate main traces", index).in_scope(|| {
405                                    main_traces = records
406                                        .par_iter_mut()
407                                        .map(|record| malicious_trace_pv_generator(prover, record))
408                                        .collect::<Vec<_>>();
409                                });
410                            } else {
411                                tracing::info_span!("generate main traces", index).in_scope(|| {
412                                    main_traces = records
413                                        .par_iter()
414                                        .map(|record| prover.generate_traces(record))
415                                        .collect::<Vec<_>>();
416                                });
417                            }
418
419                            trace_gen_sync.wait_for_turn(index);
420
421                            // Send the records to the phase 2 prover.
422                            let chunked_records = chunk_vec(records, opts.shard_batch_size);
423                            let chunked_main_traces = chunk_vec(main_traces, opts.shard_batch_size);
424                            chunked_records
425                                .into_iter()
426                                .zip(chunked_main_traces.into_iter())
427                                .for_each(|(records, main_traces)| {
428                                    records_and_traces_tx
429                                        .lock()
430                                        .unwrap()
431                                        .send((records, main_traces))
432                                        .unwrap();
433                                });
434
435                            trace_gen_sync.advance_turn();
436                        } else {
437                            break;
438                        }
439                    }
440                })
441            });
442            p2_record_and_trace_gen_handles.push(handle);
443        }
444        drop(p2_records_and_traces_tx);
445        #[cfg(feature = "debug")]
446        drop(all_records_tx);
447
448        // Spawn the phase 2 prover thread.
449        let p2_prover_span = tracing::Span::current().clone();
450        let proof_tx = Arc::new(Mutex::new(proof_tx));
451        let p2_prover_handle = s.spawn(move || {
452            let _span = p2_prover_span.enter();
453            tracing::debug_span!("phase 2 prover").in_scope(|| {
454                for (records, traces) in p2_records_and_traces_rx.into_iter() {
455                    tracing::debug_span!("batch").in_scope(|| {
456                        let span = tracing::Span::current().clone();
457                        let proofs = records
458                            .into_par_iter()
459                            .zip(traces.into_par_iter())
460                            .map(|(record, main_traces)| {
461                                let _span = span.enter();
462
463                                let shard = record.shard();
464                                let before = Instant::now();
465
466                                let main_data = tracing::debug_span!("commit", shard)
467                                    .in_scope(|| prover.commit(&record, main_traces));
468
469                                let proof = tracing::debug_span!("opening", shard).in_scope(|| {
470                                    prover.open(pk, main_data, &mut challenger.clone()).unwrap()
471                                });
472
473                                let elapsed = before.elapsed();
474
475                                // Log the shard heights/shape as well as how long it took to prove.
476                                let debug_shapes = record.shape.as_ref().map(|shape| {
477                                    shape
478                                        .iter()
479                                        .filter_map(|(&k, &v)| (v > 0).then_some((k, v)))
480                                        .collect::<Vec<_>>()
481                                });
482                                tracing::debug!(
483                                    "proving shard {shard} took {} ns. shape: {:?}",
484                                    elapsed.as_nanos(),
485                                    debug_shapes
486                                );
487
488                                #[cfg(debug_assertions)]
489                                {
490                                    if let Some(shape) = record.shape.as_ref() {
491                                        assert_eq!(
492                                            proof.shape(),
493                                            shape
494                                                .clone()
495                                                .into_iter()
496                                                .map(|(k, v)| (k.to_string(), v as usize))
497                                                .collect(),
498                                        );
499                                    }
500                                }
501
502                                rayon::spawn(move || {
503                                    drop(record);
504                                });
505
506                                proof
507                            })
508                            .collect::<Vec<_>>();
509
510                        // Send the batch of proofs to the channel.
511                        let proof_tx = proof_tx.lock().unwrap();
512                        for proof in proofs {
513                            proof_tx.send(proof).unwrap();
514                        }
515                    });
516                }
517            });
518        });
519
520        // Wait until the checkpoint generator handle has fully finished.
521        let runtime = checkpoint_generator_handle.join().unwrap().unwrap();
522        let gas = gas_calculator.map(|calc| calc(runtime.record_estimator.as_ref().unwrap()));
523        let public_values_stream = runtime.state.public_values_stream;
524
525        // Wait until the records and traces have been fully generated for phase 2.
526        p2_record_and_trace_gen_handles.into_iter().for_each(|handle| handle.join().unwrap());
527
528        // Wait until the phase 2 prover has finished.
529        p2_prover_handle.join().unwrap();
530
531        // Log some of the `ExecutionReport` information.
532        let mut report_aggregate = report_aggregate.lock().unwrap();
533        tracing::debug!(
534            "execution report (totals): total_cycles={}, total_syscall_cycles={}, touched_memory_addresses={}",
535            report_aggregate.total_instruction_count(),
536            report_aggregate.total_syscall_count(),
537            report_aggregate.touched_memory_addresses,
538        );
539        match gas {
540            Some(Ok(gas)) => {
541                tracing::debug!("execution report (gas): {}", gas);
542                report_aggregate.gas = Some(gas);
543            }
544            Some(Err(err)) => tracing::error!("Encountered error while calculating gas: {}", err),
545            None => (),
546        }
547
548        // Print the opcode and syscall count tables like `du`: sorted by count (descending) and
549        // with the count in the first column.
550        tracing::debug!("execution report (opcode counts):");
551        let (width, lines) = sorted_table_lines(report_aggregate.opcode_counts.as_ref());
552        for (label, count) in lines {
553            if *count > 0 {
554                tracing::debug!("  {}", format_table_line(&width, &label, count));
555            } else {
556                tracing::debug!("  {}", format_table_line(&width, &label, count));
557            }
558        }
559
560        tracing::debug!("execution report (syscall counts):");
561        let (width, lines) = sorted_table_lines(report_aggregate.syscall_counts.as_ref());
562        for (label, count) in lines {
563            if *count > 0 {
564                tracing::debug!("  {}", format_table_line(&width, &label, count));
565            } else {
566                tracing::debug!("  {}", format_table_line(&width, &label, count));
567            }
568        }
569
570        let cycles = report_aggregate.total_instruction_count();
571
572        // Print the summary.
573        let proving_time = proving_start.elapsed().as_secs_f64();
574        tracing::debug!(
575            "summary: cycles={}, e2e={}s, khz={:.2}",
576            cycles,
577            proving_time,
578            (cycles as f64 / (proving_time * 1000.0) as f64),
579        );
580
581        #[cfg(feature = "debug")]
582        {
583            let all_records = all_records_rx.iter().flatten().collect::<Vec<_>>();
584            let mut challenger = prover.machine().config().challenger();
585            let pk_host = prover.pk_to_host(pk);
586            prover.machine().debug_constraints(&pk_host, all_records, &mut challenger);
587        }
588
589        Ok((public_values_stream, cycles))
590    })
591}
592
593pub fn trace_checkpoint<SC: StarkGenericConfig>(
594    program: Program,
595    file: &File,
596    opts: SP1CoreOpts,
597    shape_config: Option<&CoreShapeConfig<SC::Val>>,
598) -> (Vec<ExecutionRecord>, ExecutionReport)
599where
600    <SC as StarkGenericConfig>::Val: PrimeField32,
601{
602    let noop = NoOpSubproofVerifier;
603
604    let mut reader = std::io::BufReader::new(file);
605    let state: ExecutionState =
606        bincode::deserialize_from(&mut reader).expect("failed to deserialize state");
607    let mut runtime = Executor::recover(program, state, opts);
608    runtime.maximal_shapes = shape_config.map(|config| {
609        config.maximal_core_shapes(opts.shard_size.ilog2() as usize).into_iter().collect()
610    });
611
612    // We already passed the deferred proof verifier when creating checkpoints, so the proofs were
613    // already verified. So here we use a noop verifier to not print any warnings.
614    runtime.subproof_verifier = Some(&noop);
615
616    // Execute from the checkpoint.
617    let (records, done) = runtime.execute_record(true).unwrap();
618
619    let mut records = records.into_iter().map(|r| *r).collect::<Vec<_>>();
620    let pv = records.last().unwrap().public_values;
621
622    // Handle the case where the COMMIT happens across the last two shards.
623    if !done &&
624        (pv.committed_value_digest.iter().any(|v| *v != 0) ||
625            pv.deferred_proofs_digest.iter().any(|v| *v != 0))
626    {
627        // We turn off the `print_report` flag to avoid modifying the report.
628        runtime.print_report = false;
629        let (_, next_pv, _) = runtime.execute_state(true).unwrap();
630        for record in records.iter_mut() {
631            record.public_values.committed_value_digest = next_pv.committed_value_digest;
632            record.public_values.deferred_proofs_digest = next_pv.deferred_proofs_digest;
633        }
634    }
635
636    (records, runtime.report)
637}
638
639#[derive(Error, Debug)]
640pub enum SP1CoreProverError {
641    #[error("failed to execute program: {0}")]
642    ExecutionError(ExecutionError),
643    #[error("io error: {0}")]
644    IoError(io::Error),
645    #[error("serialization error: {0}")]
646    SerializationError(bincode::Error),
647}