1pub mod builder;
6pub mod prove;
7
8use anyhow::Result;
9use prove::CudaProveBuilder;
10use sp1_core_executor::SP1ContextBuilder;
11use sp1_core_machine::io::SP1Stdin;
12use sp1_cuda::{MoongateServer, SP1CudaProver};
13use sp1_prover::{components::CpuProverComponents, SP1Prover};
14
15use crate::{
16 cpu::execute::CpuExecuteBuilder, install::try_install_circuit_artifacts, Prover, SP1Proof,
17 SP1ProofMode, SP1ProofWithPublicValues, SP1ProvingKey, SP1VerifyingKey,
18};
19
20pub struct CudaProver {
22 pub(crate) cpu_prover: SP1Prover<CpuProverComponents>,
23 pub(crate) cuda_prover: SP1CudaProver,
24}
25
26impl CudaProver {
27 pub fn new(prover: SP1Prover, moongate_server: MoongateServer) -> Self {
29 let cuda_prover = SP1CudaProver::new(moongate_server);
30 Self {
31 cpu_prover: prover,
32 cuda_prover: cuda_prover.expect("Failed to initialize CUDA prover"),
33 }
34 }
35
36 pub fn execute<'a>(&'a self, elf: &'a [u8], stdin: &SP1Stdin) -> CpuExecuteBuilder<'a> {
53 CpuExecuteBuilder {
54 prover: &self.cpu_prover,
55 elf,
56 stdin: stdin.clone(),
57 context_builder: SP1ContextBuilder::default(),
58 }
59 }
60
61 pub fn prove<'a>(&'a self, pk: &'a SP1ProvingKey, stdin: &'a SP1Stdin) -> CudaProveBuilder<'a> {
78 CudaProveBuilder { prover: self, mode: SP1ProofMode::Core, pk, stdin: stdin.clone() }
79 }
80
81 pub fn prove_with_cycles(
85 &self,
86 pk: &SP1ProvingKey,
87 stdin: &SP1Stdin,
88 kind: SP1ProofMode,
89 ) -> Result<(SP1ProofWithPublicValues, u64)> {
90 let proof = self.cuda_prover.prove_core_stateless(pk, stdin)?;
92 let cycles = proof.cycles;
94 if kind == SP1ProofMode::Core {
95 let proof_with_pv = SP1ProofWithPublicValues::new(
96 SP1Proof::Core(proof.proof.0),
97 proof.public_values,
98 self.version().to_string(),
99 );
100 return Ok((proof_with_pv, cycles));
101 }
102
103 let deferred_proofs =
105 stdin.proofs.iter().map(|(reduce_proof, _)| reduce_proof.clone()).collect();
106 let public_values = proof.public_values.clone();
107 let reduce_proof = self.cuda_prover.compress(&pk.vk, proof, deferred_proofs)?;
108 if kind == SP1ProofMode::Compressed {
109 let proof_with_pv = SP1ProofWithPublicValues::new(
110 SP1Proof::Compressed(Box::new(reduce_proof)),
111 public_values,
112 self.version().to_string(),
113 );
114 return Ok((proof_with_pv, cycles));
115 }
116
117 let compress_proof = self.cuda_prover.shrink(reduce_proof)?;
119
120 let outer_proof = self.cuda_prover.wrap_bn254(compress_proof)?;
122
123 if kind == SP1ProofMode::Plonk {
124 let plonk_bn254_artifacts = if sp1_prover::build::sp1_dev_mode() {
125 sp1_prover::build::try_build_plonk_bn254_artifacts_dev(
126 &outer_proof.vk,
127 &outer_proof.proof,
128 )
129 } else {
130 try_install_circuit_artifacts("plonk")
131 };
132 let proof = self.cpu_prover.wrap_plonk_bn254(outer_proof, &plonk_bn254_artifacts);
133 let proof_with_pv = SP1ProofWithPublicValues::new(
134 SP1Proof::Plonk(proof),
135 public_values,
136 self.version().to_string(),
137 );
138 return Ok((proof_with_pv, cycles));
139 } else if kind == SP1ProofMode::Groth16 {
140 let groth16_bn254_artifacts = if sp1_prover::build::sp1_dev_mode() {
141 sp1_prover::build::try_build_groth16_bn254_artifacts_dev(
142 &outer_proof.vk,
143 &outer_proof.proof,
144 )
145 } else {
146 try_install_circuit_artifacts("groth16")
147 };
148
149 let proof = self.cpu_prover.wrap_groth16_bn254(outer_proof, &groth16_bn254_artifacts);
150 let proof_with_pv = SP1ProofWithPublicValues::new(
151 SP1Proof::Groth16(proof),
152 public_values,
153 self.version().to_string(),
154 );
155 return Ok((proof_with_pv, cycles));
156 }
157
158 unreachable!()
159 }
160}
161
162impl Prover<CpuProverComponents> for CudaProver {
163 fn setup(&self, elf: &[u8]) -> (SP1ProvingKey, SP1VerifyingKey) {
164 let (pk, vk) = self.cuda_prover.setup(elf).unwrap();
165 (pk, vk)
166 }
167
168 fn inner(&self) -> &SP1Prover<CpuProverComponents> {
169 &self.cpu_prover
170 }
171
172 fn prove(
173 &self,
174 pk: &SP1ProvingKey,
175 stdin: &SP1Stdin,
176 kind: SP1ProofMode,
177 ) -> Result<SP1ProofWithPublicValues> {
178 self.prove_with_cycles(pk, stdin, kind).map(|(p, _)| p)
179 }
180}
181
182impl Default for CudaProver {
183 fn default() -> Self {
184 Self::new(SP1Prover::new(), MoongateServer::default())
185 }
186}