use anyhow::{Context as _, Result};
use clap::ValueEnum;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use sp1_core_machine::io::SP1Stdin;
use sp1_hypercube::{air::ShardRange, SP1PcsProofInner, ShardProof};
use sp1_primitives::{io::SP1PublicValues, SP1GlobalContext};
use sp1_recursion_circuit::machine::{
    SP1CompressWithVKeyWitnessValues, SP1DeferredWitnessValues, SP1NormalizeWitnessValues,
};
pub use sp1_recursion_gnark_ffi::proof::{Groth16Bn254Proof, PlonkBn254Proof};
use std::{
    fs::File,
    io::{BufReader, BufWriter, Write as _},
    path::Path,
};
use thiserror::Error;
13
14#[derive(Serialize, Deserialize, Clone)]
16#[serde(bound(serialize = "P: Serialize"))]
17#[serde(bound(deserialize = "P: DeserializeOwned"))]
18pub struct SP1ProofWithMetadata<P: Clone> {
19 pub proof: P,
20 pub stdin: SP1Stdin,
21 pub public_values: SP1PublicValues,
22 pub cycles: u64,
23}
24
25impl<P: Serialize + DeserializeOwned + Clone> SP1ProofWithMetadata<P> {
26 pub fn save(&self, path: impl AsRef<Path>) -> Result<()> {
27 bincode::serialize_into(File::create(path).expect("failed to open file"), self)
28 .map_err(Into::into)
29 }
30
31 pub fn load(path: impl AsRef<Path>) -> Result<Self> {
32 bincode::deserialize_from(File::open(path).expect("failed to open file"))
33 .map_err(Into::into)
34 }
35}
36
37impl<P: std::fmt::Debug + Clone> std::fmt::Debug for SP1ProofWithMetadata<P> {
38 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
39 f.debug_struct("SP1ProofWithMetadata").field("proof", &self.proof).finish()
40 }
41}
42
/// [`SP1CoreProofData`] (a list of shard proofs) wrapped with proof metadata.
pub type SP1CoreProof = SP1ProofWithMetadata<SP1CoreProofData>;

/// [`SP1ReducedProofData`] (a single shard proof) wrapped with proof metadata.
pub type SP1ReducedProof = SP1ProofWithMetadata<SP1ReducedProofData>;

/// A PLONK BN254 proof wrapped with proof metadata.
pub type SP1PlonkBn254Proof = SP1ProofWithMetadata<SP1PlonkBn254ProofData>;

/// A Groth16 BN254 proof wrapped with proof metadata.
pub type SP1Groth16Bn254Proof = SP1ProofWithMetadata<SP1Groth16Bn254ProofData>;

/// A BN254 proof that may be either PLONK or Groth16 (see [`SP1Bn254ProofData`]), wrapped
/// with proof metadata.
pub type SP1Proof = SP1ProofWithMetadata<SP1Bn254ProofData>;
58
59#[derive(Serialize, Deserialize, Clone)]
60pub struct SP1CoreProofData(pub Vec<ShardProof<SP1GlobalContext, SP1PcsProofInner>>);
61
62#[derive(Serialize, Deserialize, Clone)]
63pub struct SP1ReducedProofData(pub ShardProof<SP1GlobalContext, SP1PcsProofInner>);
64
65#[derive(Serialize, Deserialize, Clone)]
66pub struct SP1PlonkBn254ProofData(pub PlonkBn254Proof);
67
68#[derive(Serialize, Deserialize, Clone)]
69pub struct SP1Groth16Bn254ProofData(pub Groth16Bn254Proof);
70
71#[derive(Serialize, Deserialize, Clone)]
72pub enum SP1Bn254ProofData {
73 Plonk(PlonkBn254Proof),
74 Groth16(Groth16Bn254Proof),
75}
76
77impl SP1Bn254ProofData {
78 pub fn get_proof_system(&self) -> ProofSystem {
79 match self {
80 SP1Bn254ProofData::Plonk(_) => ProofSystem::Plonk,
81 SP1Bn254ProofData::Groth16(_) => ProofSystem::Groth16,
82 }
83 }
84
85 pub fn get_raw_proof(&self) -> &str {
86 match self {
87 SP1Bn254ProofData::Plonk(proof) => &proof.raw_proof,
88 SP1Bn254ProofData::Groth16(proof) => &proof.raw_proof,
89 }
90 }
91}
92
93#[derive(Debug, Default, Clone, ValueEnum, PartialEq, Eq)]
95pub enum ProverMode {
96 #[default]
97 Cpu,
98 Cuda,
99 Network,
100 Mock,
101}
102
/// The proof system of a BN254 proof (see [`SP1Bn254ProofData::get_proof_system`]).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ProofSystem {
    /// The PLONK proof system.
    Plonk,
    /// The Groth16 proof system.
    Groth16,
}

impl ProofSystem {
    /// Returns the display name of the proof system (`"Plonk"` or `"Groth16"`).
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Groth16 => "Groth16",
            Self::Plonk => "Plonk",
        }
    }
}
116}
117
118#[derive(Error, Debug)]
119pub enum SP1RecursionProverError {
120 #[error("Runtime error: {0}")]
121 RuntimeError(String),
122}
123
/// Shorthand for the witness type carried by [`SP1CircuitWitness::Compress`].
pub type SP1CompressWitness = SP1CompressWithVKeyWitnessValues<SP1PcsProofInner>;

/// Input witness for one of the recursion circuits, tagged by circuit kind.
// NOTE(review): the lint is silenced instead of boxing the large variants — presumably to
// avoid a heap allocation per witness; confirm and record the actual reason here.
#[allow(clippy::large_enum_variant)]
pub enum SP1CircuitWitness {
    /// Witness built from normalize witness values.
    Core(SP1NormalizeWitnessValues<SP1GlobalContext, SP1PcsProofInner>),
    /// Witness built from deferred witness values.
    Deferred(SP1DeferredWitnessValues<SP1GlobalContext, SP1PcsProofInner>),
    /// Compress witness (compress-with-vkey values).
    Compress(SP1CompressWitness),
    /// Shrink witness; same payload type as `Compress`.
    Shrink(SP1CompressWithVKeyWitnessValues<SP1PcsProofInner>),
    /// Wrap witness; same payload type as `Compress`.
    Wrap(SP1CompressWithVKeyWitnessValues<SP1PcsProofInner>),
}
134
135impl SP1CircuitWitness {
136 pub fn range(&self) -> ShardRange {
137 match self {
138 SP1CircuitWitness::Core(input) => input.range(),
139 SP1CircuitWitness::Deferred(input) => input.range(),
140 SP1CircuitWitness::Compress(input) => input.compress_val.range(),
141 SP1CircuitWitness::Shrink(_) => {
142 unimplemented!("Shrink witness does not need to have a range")
143 }
144 SP1CircuitWitness::Wrap(_) => {
145 unimplemented!("Wrap witness does not need to have a range")
146 }
147 }
148 }
149}