sp1_sdk/cuda/
mod.rs

//! # SP1 CUDA Prover
//!
//! A prover that executes programs on the CPU and proves them with CUDA.
4
5pub mod builder;
6pub mod prove;
7
8use anyhow::Result;
9use prove::CudaProveBuilder;
10use sp1_core_executor::SP1ContextBuilder;
11use sp1_core_machine::io::SP1Stdin;
12use sp1_cuda::{MoongateServer, SP1CudaProver};
13use sp1_prover::{components::CpuProverComponents, SP1Prover};
14
15use crate::{
16    cpu::execute::CpuExecuteBuilder, install::try_install_circuit_artifacts, Prover, SP1Proof,
17    SP1ProofMode, SP1ProofWithPublicValues, SP1ProvingKey, SP1VerifyingKey,
18};
19
20/// A prover that uses the CPU for execution and the CUDA for proving.
21pub struct CudaProver {
22    pub(crate) cpu_prover: SP1Prover<CpuProverComponents>,
23    pub(crate) cuda_prover: SP1CudaProver,
24}
25
26impl CudaProver {
27    /// Creates a new [`CudaProver`].
28    pub fn new(prover: SP1Prover, moongate_server: MoongateServer) -> Self {
29        let cuda_prover = SP1CudaProver::new(moongate_server);
30        Self {
31            cpu_prover: prover,
32            cuda_prover: cuda_prover.expect("Failed to initialize CUDA prover"),
33        }
34    }
35
36    /// Creates a new [`CpuExecuteBuilder`] for simulating the execution of a program on the CPU.
37    ///
38    /// # Details
39    /// The builder is used for both the [`crate::cpu::CpuProver`] and [`crate::CudaProver`] client
40    /// types.
41    ///
42    /// # Example
43    /// ```rust,no_run
44    /// use sp1_sdk::{include_elf, Prover, ProverClient, SP1Stdin};
45    ///
46    /// let elf = &[1, 2, 3];
47    /// let stdin = SP1Stdin::new();
48    ///
49    /// let client = ProverClient::builder().cuda().build();
50    /// let (public_values, execution_report) = client.execute(elf, &stdin).run().unwrap();
51    /// ```
52    pub fn execute<'a>(&'a self, elf: &'a [u8], stdin: &SP1Stdin) -> CpuExecuteBuilder<'a> {
53        CpuExecuteBuilder {
54            prover: &self.cpu_prover,
55            elf,
56            stdin: stdin.clone(),
57            context_builder: SP1ContextBuilder::default(),
58        }
59    }
60
61    /// Creates a new [`CudaProveBuilder`] for proving a program on the CUDA.
62    ///
63    /// # Details
64    /// The builder is used for only the [`crate::CudaProver`] client type.
65    ///
66    /// # Example
67    /// ```rust,no_run
68    /// use sp1_sdk::{include_elf, Prover, ProverClient, SP1Stdin};
69    ///
70    /// let elf = &[1, 2, 3];
71    /// let stdin = SP1Stdin::new();
72    ///
73    /// let client = ProverClient::builder().cuda().build();
74    /// let (pk, vk) = client.setup(elf);
75    /// let proof = client.prove(&pk, &stdin).run().unwrap();
76    /// ```
77    pub fn prove<'a>(&'a self, pk: &'a SP1ProvingKey, stdin: &'a SP1Stdin) -> CudaProveBuilder<'a> {
78        CudaProveBuilder { prover: self, mode: SP1ProofMode::Core, pk, stdin: stdin.clone() }
79    }
80
81    /// Proves the given program on the given input in the given proof mode.
82    ///
83    /// Returns the cycle count in addition to the proof.
84    pub fn prove_with_cycles(
85        &self,
86        pk: &SP1ProvingKey,
87        stdin: &SP1Stdin,
88        kind: SP1ProofMode,
89    ) -> Result<(SP1ProofWithPublicValues, u64)> {
90        // Generate the core proof.
91        let proof = self.cuda_prover.prove_core_stateless(pk, stdin)?;
92        // TODO: Return the prover gas
93        let cycles = proof.cycles;
94        if kind == SP1ProofMode::Core {
95            let proof_with_pv = SP1ProofWithPublicValues::new(
96                SP1Proof::Core(proof.proof.0),
97                proof.public_values,
98                self.version().to_string(),
99            );
100            return Ok((proof_with_pv, cycles));
101        }
102
103        // Generate the compressed proof.
104        let deferred_proofs =
105            stdin.proofs.iter().map(|(reduce_proof, _)| reduce_proof.clone()).collect();
106        let public_values = proof.public_values.clone();
107        let reduce_proof = self.cuda_prover.compress(&pk.vk, proof, deferred_proofs)?;
108        if kind == SP1ProofMode::Compressed {
109            let proof_with_pv = SP1ProofWithPublicValues::new(
110                SP1Proof::Compressed(Box::new(reduce_proof)),
111                public_values,
112                self.version().to_string(),
113            );
114            return Ok((proof_with_pv, cycles));
115        }
116
117        // Generate the shrink proof.
118        let compress_proof = self.cuda_prover.shrink(reduce_proof)?;
119
120        // Genenerate the wrap proof.
121        let outer_proof = self.cuda_prover.wrap_bn254(compress_proof)?;
122
123        if kind == SP1ProofMode::Plonk {
124            let plonk_bn254_artifacts = if sp1_prover::build::sp1_dev_mode() {
125                sp1_prover::build::try_build_plonk_bn254_artifacts_dev(
126                    &outer_proof.vk,
127                    &outer_proof.proof,
128                )
129            } else {
130                try_install_circuit_artifacts("plonk")
131            };
132            let proof = self.cpu_prover.wrap_plonk_bn254(outer_proof, &plonk_bn254_artifacts);
133            let proof_with_pv = SP1ProofWithPublicValues::new(
134                SP1Proof::Plonk(proof),
135                public_values,
136                self.version().to_string(),
137            );
138            return Ok((proof_with_pv, cycles));
139        } else if kind == SP1ProofMode::Groth16 {
140            let groth16_bn254_artifacts = if sp1_prover::build::sp1_dev_mode() {
141                sp1_prover::build::try_build_groth16_bn254_artifacts_dev(
142                    &outer_proof.vk,
143                    &outer_proof.proof,
144                )
145            } else {
146                try_install_circuit_artifacts("groth16")
147            };
148
149            let proof = self.cpu_prover.wrap_groth16_bn254(outer_proof, &groth16_bn254_artifacts);
150            let proof_with_pv = SP1ProofWithPublicValues::new(
151                SP1Proof::Groth16(proof),
152                public_values,
153                self.version().to_string(),
154            );
155            return Ok((proof_with_pv, cycles));
156        }
157
158        unreachable!()
159    }
160}
161
162impl Prover<CpuProverComponents> for CudaProver {
163    fn setup(&self, elf: &[u8]) -> (SP1ProvingKey, SP1VerifyingKey) {
164        let (pk, vk) = self.cuda_prover.setup(elf).unwrap();
165        (pk, vk)
166    }
167
168    fn inner(&self) -> &SP1Prover<CpuProverComponents> {
169        &self.cpu_prover
170    }
171
172    fn prove(
173        &self,
174        pk: &SP1ProvingKey,
175        stdin: &SP1Stdin,
176        kind: SP1ProofMode,
177    ) -> Result<SP1ProofWithPublicValues> {
178        self.prove_with_cycles(pk, stdin, kind).map(|(p, _)| p)
179    }
180}
181
182impl Default for CudaProver {
183    fn default() -> Self {
184        Self::new(SP1Prover::new(), MoongateServer::default())
185    }
186}