1use p3_air::Air;
2use p3_baby_bear::BabyBear;
3use p3_matrix::dense::RowMajorMatrix;
4use p3_uni_stark::SymbolicAirBuilder;
5use serde::{de::DeserializeOwned, Serialize};
6use sp1_core_executor::{ExecutionRecord, Executor, Program, SP1Context};
7use sp1_primitives::io::SP1PublicValues;
8use sp1_stark::{
9 air::MachineAir, baby_bear_poseidon2::BabyBearPoseidon2, Com, CpuProver,
10 DebugConstraintBuilder, InteractionBuilder, MachineProof, MachineProver, MachineRecord,
11 MachineVerificationError, OpeningProof, PcsProverData, ProverConstraintFolder, SP1CoreOpts,
12 StarkGenericConfig, StarkMachine, StarkProvingKey, StarkVerifyingKey, Val,
13 VerifierConstraintFolder,
14};
15
16use crate::{io::SP1Stdin, riscv::RiscvAir, shape::CoreShapeConfig};
17
18use super::prove_core;
19
/// Boxed `Send + Sync` callback type accepted by [`run_malicious_test`] and forwarded to
/// `prove_core` by [`run_test_core`].
///
/// The callback receives a shared reference to the prover `P` and a mutable reference to an
/// [`ExecutionRecord`], and returns a list of `(String, RowMajorMatrix<Val>)` pairs.
// NOTE(review): judging by the name, the returned pairs are presumably tampered traces keyed by
// chip name, used to build failing test cases — confirm against `prove_core`.
pub(crate) type MaliciousTracePVGeneratorType<Val, P> =
    Box<dyn Fn(&P, &mut ExecutionRecord) -> Vec<(String, RowMajorMatrix<Val>)> + Send + Sync>;
24
25pub fn run_test<P: MachineProver<BabyBearPoseidon2, RiscvAir<BabyBear>>>(
27 mut program: Program,
28 inputs: SP1Stdin,
29) -> Result<SP1PublicValues, MachineVerificationError<BabyBearPoseidon2>> {
30 let shape_config = CoreShapeConfig::<BabyBear>::default();
31 shape_config.fix_preprocessed_shape(&mut program).unwrap();
32
33 let runtime = tracing::debug_span!("runtime.run(...)").in_scope(|| {
34 let mut runtime = Executor::new(program, SP1CoreOpts::default());
35 runtime.maximal_shapes = Some(
36 shape_config
37 .maximal_core_shapes(SP1CoreOpts::default().shard_size.ilog2() as usize)
38 .into_iter()
39 .collect(),
40 );
41 runtime.write_vecs(&inputs.buffer);
42 runtime.run().unwrap();
43 runtime
44 });
45 let public_values = SP1PublicValues::from(&runtime.state.public_values_stream);
46
47 let _ = run_test_core::<P>(runtime, inputs, Some(&shape_config), None)?;
48 Ok(public_values)
49}
50
51pub fn run_malicious_test<P: MachineProver<BabyBearPoseidon2, RiscvAir<BabyBear>>>(
52 mut program: Program,
53 inputs: SP1Stdin,
54 malicious_trace_pv_generator: MaliciousTracePVGeneratorType<BabyBear, P>,
55) -> Result<SP1PublicValues, MachineVerificationError<BabyBearPoseidon2>> {
56 let shape_config = CoreShapeConfig::<BabyBear>::default();
57 shape_config.fix_preprocessed_shape(&mut program).unwrap();
58
59 let runtime = tracing::debug_span!("runtime.run(...)").in_scope(|| {
60 let mut runtime = Executor::new(program, SP1CoreOpts::default());
61 runtime.maximal_shapes = Some(
62 shape_config
63 .maximal_core_shapes(SP1CoreOpts::default().shard_size.ilog2() as usize)
64 .into_iter()
65 .collect(),
66 );
67 runtime.write_vecs(&inputs.buffer);
68 runtime.run().unwrap();
69 runtime
70 });
71 let public_values = SP1PublicValues::from(&runtime.state.public_values_stream);
72
73 let result = run_test_core::<P>(
74 runtime,
75 inputs,
76 Some(&shape_config),
77 Some(malicious_trace_pv_generator),
78 );
79 if let Err(verification_error) = result {
80 Err(verification_error)
81 } else {
82 Ok(public_values)
83 }
84}
85
86#[allow(unused_variables)]
87pub fn run_test_core<P: MachineProver<BabyBearPoseidon2, RiscvAir<BabyBear>>>(
88 runtime: Executor,
89 inputs: SP1Stdin,
90 shape_config: Option<&CoreShapeConfig<BabyBear>>,
91 malicious_trace_pv_generator: Option<MaliciousTracePVGeneratorType<BabyBear, P>>,
92) -> Result<MachineProof<BabyBearPoseidon2>, MachineVerificationError<BabyBearPoseidon2>> {
93 let config = BabyBearPoseidon2::new();
94 let machine = RiscvAir::machine(config);
95 let prover = P::new(machine);
96
97 let (pk, vk) = prover.setup(runtime.program.as_ref());
98 let (proof, output, _) = prove_core(
99 &prover,
100 &pk,
101 &vk,
102 Program::clone(&runtime.program),
103 &inputs,
104 SP1CoreOpts::default(),
105 SP1Context::default(),
106 shape_config,
107 malicious_trace_pv_generator,
108 )
109 .unwrap();
110
111 let config = BabyBearPoseidon2::new();
112 let machine = RiscvAir::machine(config);
113 let (pk, vk) = machine.setup(runtime.program.as_ref());
114 let mut challenger = machine.config().challenger();
115 if let Err(e) = machine.verify(&vk, &proof, &mut challenger) {
116 Err(e)
117 } else {
118 Ok(proof)
119 }
120}
121
122#[allow(unused_variables)]
123pub fn run_test_machine_with_prover<SC, A, P: MachineProver<SC, A>>(
124 prover: &P,
125 records: Vec<A::Record>,
126 pk: P::DeviceProvingKey,
127 vk: StarkVerifyingKey<SC>,
128) -> Result<MachineProof<SC>, MachineVerificationError<SC>>
129where
130 A: MachineAir<SC::Val>
131 + Air<InteractionBuilder<Val<SC>>>
132 + for<'a> Air<VerifierConstraintFolder<'a, SC>>
133 + for<'a> Air<DebugConstraintBuilder<'a, Val<SC>, SC::Challenge>>
134 + Air<SymbolicAirBuilder<SC::Val>>,
135 A::Record: MachineRecord<Config = SP1CoreOpts>,
136 SC: StarkGenericConfig,
137 SC::Val: p3_field::PrimeField32,
138 SC::Challenger: Clone,
139 Com<SC>: Send + Sync,
140 PcsProverData<SC>: Send + Sync + Serialize + DeserializeOwned,
141 OpeningProof<SC>: Send + Sync,
142{
143 let mut challenger = prover.config().challenger();
144 let prove_span = tracing::debug_span!("prove").entered();
145
146 #[cfg(feature = "debug")]
147 prover.machine().debug_constraints(
148 &prover.pk_to_host(&pk),
149 records.clone(),
150 &mut challenger.clone(),
151 );
152
153 let proof = prover.prove(&pk, records, &mut challenger, SP1CoreOpts::default()).unwrap();
154 prove_span.exit();
155 let nb_bytes = bincode::serialize(&proof).unwrap().len();
156
157 let mut challenger = prover.config().challenger();
158 prover.machine().verify(&vk, &proof, &mut challenger)?;
159
160 Ok(proof)
161}
162
163#[allow(unused_variables)]
164pub fn run_test_machine<SC, A>(
165 records: Vec<A::Record>,
166 machine: StarkMachine<SC, A>,
167 pk: StarkProvingKey<SC>,
168 vk: StarkVerifyingKey<SC>,
169) -> Result<MachineProof<SC>, MachineVerificationError<SC>>
170where
171 A: MachineAir<SC::Val>
172 + for<'a> Air<ProverConstraintFolder<'a, SC>>
173 + Air<InteractionBuilder<Val<SC>>>
174 + for<'a> Air<VerifierConstraintFolder<'a, SC>>
175 + for<'a> Air<DebugConstraintBuilder<'a, Val<SC>, SC::Challenge>>
176 + Air<SymbolicAirBuilder<SC::Val>>,
177 A::Record: MachineRecord<Config = SP1CoreOpts>,
178 SC: StarkGenericConfig,
179 SC::Val: p3_field::PrimeField32,
180 SC::Challenger: Clone,
181 Com<SC>: Send + Sync,
182 PcsProverData<SC>: Send + Sync + Serialize + DeserializeOwned,
183 OpeningProof<SC>: Send + Sync,
184{
185 let prover = CpuProver::new(machine);
186 run_test_machine_with_prover::<SC, A, CpuProver<_, _>>(&prover, records, pk, vk)
187}