Skip to main content

sp1_cuda/
lib.rs

1/// The shared API between the client and server.
2pub mod api;
3
4/// The client that interacts with the CUDA server.
5pub mod client;
6
7/// The proving key type, which is a "remote" reference to a key held by the CUDA server.
8pub mod pk;
9
10/// The server startup logic.
11mod server;
12
13mod error;
14pub use error::CudaClientError;
15
16pub use pk::CudaProvingKey;
17use sp1_core_executor::SP1Context;
18use sp1_core_machine::io::SP1Stdin;
19use sp1_primitives::Elf;
20use sp1_prover::worker::ProofFromNetwork;
21use sp1_prover_types::network_base_types::ProofMode;
22
23use crate::client::CudaClient;
24
25#[derive(Clone)]
26pub struct CudaProver {
27    client: CudaClient,
28}
29
30impl CudaProver {
31    /// Create a new prover, using the 0th CUDA device.
32    pub async fn new() -> Result<Self, CudaClientError> {
33        Ok(Self { client: CudaClient::connect(0).await? })
34    }
35
36    /// Create a new prover, using the given CUDA device.
37    pub async fn new_with_id(cuda_id: u32) -> Result<Self, CudaClientError> {
38        Ok(Self { client: CudaClient::connect(cuda_id).await? })
39    }
40
41    pub async fn setup(&self, elf: Elf) -> Result<CudaProvingKey, CudaClientError> {
42        self.client.setup(elf).await
43    }
44
45    pub async fn prove_with_mode(
46        &self,
47        pk: &CudaProvingKey,
48        stdin: SP1Stdin,
49        context: SP1Context<'static>,
50        mode: ProofMode,
51    ) -> Result<ProofFromNetwork, CudaClientError> {
52        self.client.prove_with_mode(pk, stdin, context, mode).await
53    }
54}