sp1_core_machine/utils/
test.rs

1use p3_air::Air;
2use p3_baby_bear::BabyBear;
3use p3_matrix::dense::RowMajorMatrix;
4use p3_uni_stark::SymbolicAirBuilder;
5use serde::{de::DeserializeOwned, Serialize};
6use sp1_core_executor::{ExecutionRecord, Executor, Program, SP1Context};
7use sp1_primitives::io::SP1PublicValues;
8use sp1_stark::{
9    air::MachineAir, baby_bear_poseidon2::BabyBearPoseidon2, Com, CpuProver,
10    DebugConstraintBuilder, InteractionBuilder, MachineProof, MachineProver, MachineRecord,
11    MachineVerificationError, OpeningProof, PcsProverData, ProverConstraintFolder, SP1CoreOpts,
12    StarkGenericConfig, StarkMachine, StarkProvingKey, StarkVerifyingKey, Val,
13    VerifierConstraintFolder,
14};
15
16use crate::{io::SP1Stdin, riscv::RiscvAir, shape::CoreShapeConfig};
17
18use super::prove_core;
19
20/// This type is the function signature used for malicious trace and public values generators for
21/// failure test cases.
22pub(crate) type MaliciousTracePVGeneratorType<Val, P> =
23    Box<dyn Fn(&P, &mut ExecutionRecord) -> Vec<(String, RowMajorMatrix<Val>)> + Send + Sync>;
24
25/// The canonical entry point for testing a [`Program`] and [`SP1Stdin`] with a [`MachineProver`].
26pub fn run_test<P: MachineProver<BabyBearPoseidon2, RiscvAir<BabyBear>>>(
27    mut program: Program,
28    inputs: SP1Stdin,
29) -> Result<SP1PublicValues, MachineVerificationError<BabyBearPoseidon2>> {
30    let shape_config = CoreShapeConfig::<BabyBear>::default();
31    shape_config.fix_preprocessed_shape(&mut program).unwrap();
32
33    let runtime = tracing::debug_span!("runtime.run(...)").in_scope(|| {
34        let mut runtime = Executor::new(program, SP1CoreOpts::default());
35        runtime.maximal_shapes = Some(
36            shape_config
37                .maximal_core_shapes(SP1CoreOpts::default().shard_size.ilog2() as usize)
38                .into_iter()
39                .collect(),
40        );
41        runtime.write_vecs(&inputs.buffer);
42        runtime.run().unwrap();
43        runtime
44    });
45    let public_values = SP1PublicValues::from(&runtime.state.public_values_stream);
46
47    let _ = run_test_core::<P>(runtime, inputs, Some(&shape_config), None)?;
48    Ok(public_values)
49}
50
51pub fn run_malicious_test<P: MachineProver<BabyBearPoseidon2, RiscvAir<BabyBear>>>(
52    mut program: Program,
53    inputs: SP1Stdin,
54    malicious_trace_pv_generator: MaliciousTracePVGeneratorType<BabyBear, P>,
55) -> Result<SP1PublicValues, MachineVerificationError<BabyBearPoseidon2>> {
56    let shape_config = CoreShapeConfig::<BabyBear>::default();
57    shape_config.fix_preprocessed_shape(&mut program).unwrap();
58
59    let runtime = tracing::debug_span!("runtime.run(...)").in_scope(|| {
60        let mut runtime = Executor::new(program, SP1CoreOpts::default());
61        runtime.maximal_shapes = Some(
62            shape_config
63                .maximal_core_shapes(SP1CoreOpts::default().shard_size.ilog2() as usize)
64                .into_iter()
65                .collect(),
66        );
67        runtime.write_vecs(&inputs.buffer);
68        runtime.run().unwrap();
69        runtime
70    });
71    let public_values = SP1PublicValues::from(&runtime.state.public_values_stream);
72
73    let result = run_test_core::<P>(
74        runtime,
75        inputs,
76        Some(&shape_config),
77        Some(malicious_trace_pv_generator),
78    );
79    if let Err(verification_error) = result {
80        Err(verification_error)
81    } else {
82        Ok(public_values)
83    }
84}
85
86#[allow(unused_variables)]
87pub fn run_test_core<P: MachineProver<BabyBearPoseidon2, RiscvAir<BabyBear>>>(
88    runtime: Executor,
89    inputs: SP1Stdin,
90    shape_config: Option<&CoreShapeConfig<BabyBear>>,
91    malicious_trace_pv_generator: Option<MaliciousTracePVGeneratorType<BabyBear, P>>,
92) -> Result<MachineProof<BabyBearPoseidon2>, MachineVerificationError<BabyBearPoseidon2>> {
93    let config = BabyBearPoseidon2::new();
94    let machine = RiscvAir::machine(config);
95    let prover = P::new(machine);
96
97    let (pk, vk) = prover.setup(runtime.program.as_ref());
98    let (proof, output, _) = prove_core(
99        &prover,
100        &pk,
101        &vk,
102        Program::clone(&runtime.program),
103        &inputs,
104        SP1CoreOpts::default(),
105        SP1Context::default(),
106        shape_config,
107        malicious_trace_pv_generator,
108    )
109    .unwrap();
110
111    let config = BabyBearPoseidon2::new();
112    let machine = RiscvAir::machine(config);
113    let (pk, vk) = machine.setup(runtime.program.as_ref());
114    let mut challenger = machine.config().challenger();
115    if let Err(e) = machine.verify(&vk, &proof, &mut challenger) {
116        Err(e)
117    } else {
118        Ok(proof)
119    }
120}
121
122#[allow(unused_variables)]
123pub fn run_test_machine_with_prover<SC, A, P: MachineProver<SC, A>>(
124    prover: &P,
125    records: Vec<A::Record>,
126    pk: P::DeviceProvingKey,
127    vk: StarkVerifyingKey<SC>,
128) -> Result<MachineProof<SC>, MachineVerificationError<SC>>
129where
130    A: MachineAir<SC::Val>
131        + Air<InteractionBuilder<Val<SC>>>
132        + for<'a> Air<VerifierConstraintFolder<'a, SC>>
133        + for<'a> Air<DebugConstraintBuilder<'a, Val<SC>, SC::Challenge>>
134        + Air<SymbolicAirBuilder<SC::Val>>,
135    A::Record: MachineRecord<Config = SP1CoreOpts>,
136    SC: StarkGenericConfig,
137    SC::Val: p3_field::PrimeField32,
138    SC::Challenger: Clone,
139    Com<SC>: Send + Sync,
140    PcsProverData<SC>: Send + Sync + Serialize + DeserializeOwned,
141    OpeningProof<SC>: Send + Sync,
142{
143    let mut challenger = prover.config().challenger();
144    let prove_span = tracing::debug_span!("prove").entered();
145
146    #[cfg(feature = "debug")]
147    prover.machine().debug_constraints(
148        &prover.pk_to_host(&pk),
149        records.clone(),
150        &mut challenger.clone(),
151    );
152
153    let proof = prover.prove(&pk, records, &mut challenger, SP1CoreOpts::default()).unwrap();
154    prove_span.exit();
155    let nb_bytes = bincode::serialize(&proof).unwrap().len();
156
157    let mut challenger = prover.config().challenger();
158    prover.machine().verify(&vk, &proof, &mut challenger)?;
159
160    Ok(proof)
161}
162
163#[allow(unused_variables)]
164pub fn run_test_machine<SC, A>(
165    records: Vec<A::Record>,
166    machine: StarkMachine<SC, A>,
167    pk: StarkProvingKey<SC>,
168    vk: StarkVerifyingKey<SC>,
169) -> Result<MachineProof<SC>, MachineVerificationError<SC>>
170where
171    A: MachineAir<SC::Val>
172        + for<'a> Air<ProverConstraintFolder<'a, SC>>
173        + Air<InteractionBuilder<Val<SC>>>
174        + for<'a> Air<VerifierConstraintFolder<'a, SC>>
175        + for<'a> Air<DebugConstraintBuilder<'a, Val<SC>, SC::Challenge>>
176        + Air<SymbolicAirBuilder<SC::Val>>,
177    A::Record: MachineRecord<Config = SP1CoreOpts>,
178    SC: StarkGenericConfig,
179    SC::Val: p3_field::PrimeField32,
180    SC::Challenger: Clone,
181    Com<SC>: Send + Sync,
182    PcsProverData<SC>: Send + Sync + Serialize + DeserializeOwned,
183    OpeningProof<SC>: Send + Sync,
184{
185    let prover = CpuProver::new(machine);
186    run_test_machine_with_prover::<SC, A, CpuProver<_, _>>(&prover, records, pk, vk)
187}